use std::mem::replace; use std::str::FromStr; use futures::{Future, Poll, Async, sink, Stream}; use tokio_io::{AsyncRead, AsyncWrite}; use sasl::common::Credentials; use sasl::common::scram::{Sha1, Sha256}; use sasl::client::Mechanism; use sasl::client::mechanisms::{Scram, Plain, Anonymous}; use minidom::Element; use xmpp_parsers::sasl::{Auth, Challenge, Response, Success, Failure, Mechanism as XMPPMechanism}; use try_from::TryFrom; use xmpp_codec::Packet; use xmpp_stream::XMPPStream; use stream_start::StreamStart; use {Error, AuthError, ProtocolError}; const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; pub struct ClientAuth { state: ClientAuthState, mechanism: Box, } enum ClientAuthState { WaitSend(sink::Send>), WaitRecv(XMPPStream), Start(StreamStart), Invalid, } impl ClientAuth { pub fn new(stream: XMPPStream, creds: Credentials) -> Result { let mechs: Vec> = vec![ Box::new(Scram::::from_credentials(creds.clone()).unwrap()), Box::new(Scram::::from_credentials(creds.clone()).unwrap()), Box::new(Plain::from_credentials(creds).unwrap()), Box::new(Anonymous::new()), ]; let mech_names: Vec = match stream.stream_features.get_child("mechanisms", NS_XMPP_SASL) { None => return Err(AuthError::NoMechanism.into()), Some(mechs) => mechs.children() .filter(|child| child.is("mechanism", NS_XMPP_SASL)) .map(|mech_el| mech_el.text()) .collect(), }; // println!("SASL mechanisms offered: {:?}", mech_names); for mut mech in mechs { let name = mech.name().to_owned(); if mech_names.iter().any(|name1| *name1 == name) { // println!("SASL mechanism selected: {:?}", name); let initial = match mech.initial() { Ok(initial) => initial, Err(e) => return Err(AuthError::Sasl(e).into()), }; let mut this = ClientAuth { state: ClientAuthState::Invalid, mechanism: mech, }; let mechanism = match XMPPMechanism::from_str(&name) { Ok(mechanism) => mechanism, Err(e) => return Err(ProtocolError::Parsers(e).into()), }; this.send( stream, Auth { mechanism, data: initial, } ); return Ok(this); } } Err(AuthError::NoMechanism.into()) } fn send>(&mut self, stream: XMPPStream, nonza: N) { let send = stream.send_stanza(nonza); self.state = ClientAuthState::WaitSend(send); } } impl Future for ClientAuth { type Item = XMPPStream; type Error = Error; fn poll(&mut self) -> Poll { let state = replace(&mut self.state, ClientAuthState::Invalid); match state { ClientAuthState::WaitSend(mut send) => match send.poll() { Ok(Async::Ready(stream)) => { self.state = ClientAuthState::WaitRecv(stream); self.poll() }, Ok(Async::NotReady) => { self.state = ClientAuthState::WaitSend(send); Ok(Async::NotReady) }, Err(e) => Err(e.into()), }, ClientAuthState::WaitRecv(mut stream) => match stream.poll() { Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => { if let Ok(challenge) = Challenge::try_from(stanza.clone()) { let response = self.mechanism.response(&challenge.data) .map_err(AuthError::Sasl)?; self.send(stream, Response { data: response }); self.poll() } else if let Ok(_) = Success::try_from(stanza.clone()) { let start = stream.restart(); self.state = ClientAuthState::Start(start); self.poll() } else if let Ok(failure) = Failure::try_from(stanza) { Err(AuthError::Fail(failure.defined_condition).into()) } else { Ok(Async::NotReady) } } Ok(Async::Ready(_event)) => { // println!("ClientAuth ignore {:?}", _event); Ok(Async::NotReady) }, Ok(_) => { self.state = ClientAuthState::WaitRecv(stream); Ok(Async::NotReady) }, Err(e) => Err(ProtocolError::Parser(e).into()) }, ClientAuthState::Start(mut start) => match start.poll() { Ok(Async::Ready(stream)) => Ok(Async::Ready(stream)), Ok(Async::NotReady) => { self.state = ClientAuthState::Start(start); Ok(Async::NotReady) }, Err(e) => Err(e.into()) }, ClientAuthState::Invalid => unreachable!(), } } }