diff --git a/Cargo.toml b/Cargo.toml index 39de2186..5da50cca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,3 +11,5 @@ bytes = "*" RustyXML = "*" rustls = "*" tokio-rustls = "*" +sasl = "*" +rustc-serialize = "*" diff --git a/examples/echo_bot.rs b/examples/echo_bot.rs index 0cdd6b07..416e98d5 100644 --- a/examples/echo_bot.rs +++ b/examples/echo_bot.rs @@ -33,6 +33,9 @@ fn main() { } else { panic!("No STARTTLS") } + }).map_err(|e| format!("{}", e) + ).and_then(|stream| { + stream.auth("astrobot", "").expect("auth") }).and_then(|stream| { stream.for_each(|event| { match event { @@ -40,7 +43,7 @@ fn main() { _ => println!("!! {:?}", event), } Ok(()) - }) + }).map_err(|e| format!("{}", e)) }); match core.run(client) { Ok(_) => (), diff --git a/src/client_auth.rs b/src/client_auth.rs new file mode 100644 index 00000000..ecb0b78f --- /dev/null +++ b/src/client_auth.rs @@ -0,0 +1,160 @@ +use std::mem::replace; +use futures::*; +use futures::sink; +use tokio_io::{AsyncRead, AsyncWrite}; +use xml; +use sasl::common::Credentials; +use sasl::common::scram::*; +use sasl::client::Mechanism; +use sasl::client::mechanisms::*; +use serialize::base64::{self, ToBase64, FromBase64}; + +use xmpp_codec::*; +use xmpp_stream::*; + +const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; + +pub struct ClientAuth { + state: ClientAuthState, + mechanism: Box, +} + +enum ClientAuthState { + WaitSend(sink::Send>), + WaitRecv(XMPPStream), + Invalid, +} + +impl ClientAuth { + pub fn new(stream: XMPPStream, creds: Credentials) -> Result { + let mechs: Vec> = vec![ + Box::new(Scram::::from_credentials(creds.clone()).unwrap()), + Box::new(Scram::::from_credentials(creds.clone()).unwrap()), + Box::new(Plain::from_credentials(creds).unwrap()), + Box::new(Anonymous::new()), + ]; + + println!("stream_features: {}", stream.stream_features); + let mech_names: Vec = + match stream.stream_features.get_child("mechanisms", Some(NS_XMPP_SASL)) { + None => + return Err("No auth mechanisms".to_owned()), + Some(mechs) => + mechs.get_children("mechanism", Some(NS_XMPP_SASL)) + .map(|mech_el| mech_el.content_str()) + .collect(), + }; + println!("Offered mechanisms: {:?}", mech_names); + + for mut mech in mechs { + let name = mech.name().to_owned(); + if mech_names.iter().any(|name1| *name1 == name) { + println!("Selected mechanism: {:?}", name); + let initial = try!(mech.initial()); + let mut this = ClientAuth { + state: ClientAuthState::Invalid, + mechanism: mech, + }; + this.send( + stream, + "auth", &[("mechanism".to_owned(), name)], + &initial + ); + return Ok(this); + } + } + + Err("No supported SASL mechanism available".to_owned()) + } + + fn send(&mut self, stream: XMPPStream, nonza_name: &str, attrs: &[(String, String)], content: &[u8]) { + let mut nonza = xml::Element::new( + nonza_name.to_owned(), + Some(NS_XMPP_SASL.to_owned()), + attrs.iter() + .map(|&(ref name, ref value)| (name.clone(), None, value.clone())) + .collect() + ); + nonza.text(content.to_base64(base64::URL_SAFE)); + + println!("send {}", nonza); + let send = stream.send(Packet::Stanza(nonza)); + + self.state = ClientAuthState::WaitSend(send); + } +} + +impl Future for ClientAuth { + type Item = XMPPStream; + type Error = String; + + fn poll(&mut self) -> Poll { + let state = replace(&mut self.state, ClientAuthState::Invalid); + + match state { + ClientAuthState::WaitSend(mut send) => + match send.poll() { + Ok(Async::Ready(stream)) => { + println!("send done"); + self.state = ClientAuthState::WaitRecv(stream); + self.poll() + }, + Ok(Async::NotReady) => { + self.state = ClientAuthState::WaitSend(send); + Ok(Async::NotReady) + }, + Err(e) => + Err(format!("{}", e)), + }, + ClientAuthState::WaitRecv(mut stream) => + match stream.poll() { + Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) + if stanza.name == "challenge" + && stanza.ns == Some(NS_XMPP_SASL.to_owned()) => + { + let content = try!( + stanza.content_str() + .from_base64() + .map_err(|e| format!("{}", e)) + ); + let response = try!(self.mechanism.response(&content)); + self.send(stream, "response", &[], &response); + self.poll() + }, + Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) + if stanza.name == "success" + && stanza.ns == Some(NS_XMPP_SASL.to_owned()) => + Ok(Async::Ready(stream)), + Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) + if stanza.name == "failure" + && stanza.ns == Some(NS_XMPP_SASL.to_owned()) => + { + let mut e = None; + for child in &stanza.children { + match child { + &xml::Xml::ElementNode(ref child) => { + e = Some(child.name.clone()); + break + }, + _ => (), + } + } + let e = e.unwrap_or_else(|| "Authentication failure".to_owned()); + Err(e) + }, + Ok(Async::Ready(event)) => { + println!("ClientAuth ignore {:?}", event); + Ok(Async::NotReady) + }, + Ok(_) => { + self.state = ClientAuthState::WaitRecv(stream); + Ok(Async::NotReady) + }, + Err(e) => + Err(format!("{}", e)), + }, + ClientAuthState::Invalid => + unreachable!(), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 9764daea..ceb68424 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,8 @@ extern crate bytes; extern crate xml; extern crate rustls; extern crate tokio_rustls; +extern crate sasl; +extern crate rustc_serialize as serialize; pub mod xmpp_codec; @@ -15,6 +17,8 @@ mod tcp; pub use tcp::*; mod starttls; pub use starttls::*; +mod client_auth; +pub use client_auth::*; // type FullClient = sasl::Client> diff --git a/src/xmpp_stream.rs b/src/xmpp_stream.rs index c4080bb7..e480ffdb 100644 --- a/src/xmpp_stream.rs +++ b/src/xmpp_stream.rs @@ -1,3 +1,4 @@ +use std::default::Default; use std::sync::Arc; use std::collections::HashMap; use futures::*; @@ -5,10 +6,12 @@ use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::codec::Framed; use rustls::ClientConfig; use xml; +use sasl::common::Credentials; use xmpp_codec::*; use stream_start::*; use starttls::{NS_XMPP_TLS, StartTlsClient}; +use client_auth::ClientAuth; pub const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams"; @@ -37,8 +40,16 @@ impl XMPPStream { pub fn starttls(self, arc_config: Arc) -> StartTlsClient { StartTlsClient::from_stream(self, arc_config) } + + pub fn auth(self, username: &str, password: &str) -> Result, String> { + let creds = Credentials::default() + .with_username(username) + .with_password(password); + ClientAuth::new(self, creds) + } } +/// Proxy to self.stream impl Sink for XMPPStream { type SinkItem = as Sink>::SinkItem; type SinkError = as Sink>::SinkError; @@ -52,6 +63,7 @@ impl Sink for XMPPStream { } } +/// Proxy to self.stream impl Stream for XMPPStream { type Item = as Stream>::Item; type Error = as Stream>::Error;