diff --git a/src/client/auth.rs b/src/client/auth.rs index eadceb0..51bcaa1 100644 --- a/src/client/auth.rs +++ b/src/client/auth.rs @@ -1,37 +1,29 @@ -use futures::{sink, Async, Future, Poll, Stream}; +use std::mem::replace; +use std::str::FromStr; +use futures::{sink, Async, Future, Poll, Stream, future::{ok, err, IntoFuture}}; use minidom::Element; use sasl::client::mechanisms::{Anonymous, Plain, Scram}; use sasl::client::Mechanism; use sasl::common::scram::{Sha1, Sha256}; use sasl::common::Credentials; -use std::mem::replace; -use std::str::FromStr; use tokio_io::{AsyncRead, AsyncWrite}; use try_from::TryFrom; use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success}; -use crate::stream_start::StreamStart; use crate::xmpp_codec::Packet; use crate::xmpp_stream::XMPPStream; use crate::{AuthError, Error, ProtocolError}; const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; -pub struct ClientAuth { - state: ClientAuthState, - mechanism: Box, +pub struct ClientAuth { + future: Box, Error = Error>>, } -enum ClientAuthState { - WaitSend(sink::Send>), - WaitRecv(XMPPStream), - Start(StreamStart), - Invalid, -} - -impl ClientAuth { +impl ClientAuth { pub fn new(stream: XMPPStream, creds: Credentials) -> Result { let mechs: Vec> = vec![ + // TODO: Box::new(|| … Box::new(Scram::::from_credentials(creds.clone()).unwrap()), Box::new(Scram::::from_credentials(creds.clone()).unwrap()), Box::new(Plain::from_credentials(creds).unwrap()), @@ -46,36 +38,74 @@ impl ClientAuth { .filter(|child| child.is("mechanism", NS_XMPP_SASL)) .map(|mech_el| mech_el.text()) .collect(); + // TODO: iter instead of collect() // println!("SASL mechanisms offered: {:?}", mech_names); - for mut mech in mechs { - let name = mech.name().to_owned(); + for mut mechanism in mechs { + let name = mechanism.name().to_owned(); if mech_names.iter().any(|name1| *name1 == name) { // println!("SASL mechanism selected: {:?}", name); - let initial = mech.initial().map_err(AuthError::Sasl)?; - let mut this = ClientAuth { - state: ClientAuthState::Invalid, - mechanism: mech, - }; - let mechanism = XMPPMechanism::from_str(&name).map_err(ProtocolError::Parsers)?; - this.send( - stream, - Auth { - mechanism, - data: initial, - }, - ); - return Ok(this); + let initial = mechanism.initial().map_err(AuthError::Sasl)?; + let mechanism_name = XMPPMechanism::from_str(&name).map_err(ProtocolError::Parsers)?; + + let send_initial = Box::new(stream.send_stanza(Auth { + mechanism: mechanism_name, + data: initial, + })) + .map_err(Error::Io); + let future = Box::new(send_initial.and_then( + |stream| Self::handle_challenge(stream, mechanism) + ).and_then( + |stream| stream.restart() + )); + return Ok(ClientAuth { + future, + }); } } Err(AuthError::NoMechanism)? } - fn send>(&mut self, stream: XMPPStream, nonza: N) { - let send = stream.send_stanza(nonza); - - self.state = ClientAuthState::WaitSend(send); + fn handle_challenge(stream: XMPPStream, mut mechanism: Box) -> Box, Error = Error>> { + Box::new( + stream.into_future() + .map_err(|(e, _stream)| e.into()) + .and_then(|(stanza, stream)| { + match stanza { + Some(Packet::Stanza(stanza)) => { + if let Ok(challenge) = Challenge::try_from(stanza.clone()) { + let response = mechanism + .response(&challenge.data); + Box::new( + response + .map_err(|e| AuthError::Sasl(e).into()) + .into_future() + .and_then(|response| { + // Send response and loop + stream.send_stanza(Response { data: response }) + .map_err(Error::Io) + .and_then(|stream| Self::handle_challenge(stream, mechanism)) + }) + ) + } else if let Ok(_) = Success::try_from(stanza.clone()) { + Box::new(ok(stream)) + } else if let Ok(failure) = Failure::try_from(stanza.clone()) { + Box::new(err(Error::Auth(AuthError::Fail(failure.defined_condition)))) + } else { + // ignore and loop + println!("Ignore: {:?}", stanza); + Self::handle_challenge(stream, mechanism) + } + } + Some(_) => { + // ignore and loop + Self::handle_challenge(stream, mechanism) + } + None => Box::new(err(Error::Disconnected)) + } + }) + ) } } @@ -84,58 +114,6 @@ impl Future for ClientAuth { type Error = Error; fn poll(&mut self) -> Poll { - let state = replace(&mut self.state, ClientAuthState::Invalid); - - match state { - ClientAuthState::WaitSend(mut send) => match send.poll() { - Ok(Async::Ready(stream)) => { - self.state = ClientAuthState::WaitRecv(stream); - self.poll() - } - Ok(Async::NotReady) => { - self.state = ClientAuthState::WaitSend(send); - Ok(Async::NotReady) - } - Err(e) => Err(e)?, - }, - ClientAuthState::WaitRecv(mut stream) => match stream.poll() { - Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => { - if let Ok(challenge) = Challenge::try_from(stanza.clone()) { - let response = self - .mechanism - .response(&challenge.data) - .map_err(AuthError::Sasl)?; - self.send(stream, Response { data: response }); - self.poll() - } else if let Ok(_) = Success::try_from(stanza.clone()) { - let start = stream.restart(); - self.state = ClientAuthState::Start(start); - self.poll() - } else if let Ok(failure) = Failure::try_from(stanza) { - Err(AuthError::Fail(failure.defined_condition))? - } else { - Ok(Async::NotReady) - } - } - Ok(Async::Ready(_event)) => { - // println!("ClientAuth ignore {:?}", _event); - Ok(Async::NotReady) - } - Ok(_) => { - self.state = ClientAuthState::WaitRecv(stream); - Ok(Async::NotReady) - } - Err(e) => Err(ProtocolError::Parser(e))?, - }, - ClientAuthState::Start(mut start) => match start.poll() { - Ok(Async::Ready(stream)) => Ok(Async::Ready(stream)), - Ok(Async::NotReady) => { - self.state = ClientAuthState::Start(start); - Ok(Async::NotReady) - } - Err(e) => Err(e), - }, - ClientAuthState::Invalid => unreachable!(), - } + self.future.poll() } } diff --git a/src/client/mod.rs b/src/client/mod.rs index 9e62373..3767013 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -104,7 +104,7 @@ impl Client { StartTlsClient::from_stream(stream) } - fn auth( + fn auth( stream: xmpp_stream::XMPPStream, username: String, password: String, diff --git a/src/error.rs b/src/error.rs index 26d7944..e7fad54 100644 --- a/src/error.rs +++ b/src/error.rs @@ -26,6 +26,8 @@ pub enum Error { Auth(AuthError), /// TLS error Tls(TlsError), + /// Connection closed + Disconnected, /// Shoud never happen InvalidState, } diff --git a/src/xmpp_codec.rs b/src/xmpp_codec.rs index e8c4e05..9c59bac 100644 --- a/src/xmpp_codec.rs +++ b/src/xmpp_codec.rs @@ -19,7 +19,7 @@ use xml5ever::interface::Attribute; use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer}; /// Anything that can be sent or received on an XMPP/XML stream -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum Packet { /// `` start tag StreamStart(HashMap),