diff --git a/src/client_auth.rs b/src/client_auth.rs index ecb0b78..5ae57b7 100644 --- a/src/client_auth.rs +++ b/src/client_auth.rs @@ -11,6 +11,7 @@ use serialize::base64::{self, ToBase64, FromBase64}; use xmpp_codec::*; use xmpp_stream::*; +use stream_start::*; const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; @@ -22,6 +23,7 @@ pub struct ClientAuth { enum ClientAuthState { WaitSend(sink::Send>), WaitRecv(XMPPStream), + Start(StreamStart), Invalid, } @@ -124,7 +126,11 @@ impl Future for ClientAuth { Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) if stanza.name == "success" && stanza.ns == Some(NS_XMPP_SASL.to_owned()) => - Ok(Async::Ready(stream)), + { + let start = stream.restart(); + self.state = ClientAuthState::Start(start); + self.poll() + }, Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) if stanza.name == "failure" && stanza.ns == Some(NS_XMPP_SASL.to_owned()) => @@ -153,6 +159,17 @@ impl Future for ClientAuth { Err(e) => Err(format!("{}", e)), }, + ClientAuthState::Start(mut start) => + match start.poll() { + Ok(Async::Ready(stream)) => + Ok(Async::Ready(stream)), + Ok(Async::NotReady) => { + self.state = ClientAuthState::Start(start); + Ok(Async::NotReady) + }, + Err(e) => + Err(format!("{}", e)), + }, ClientAuthState::Invalid => unreachable!(), } diff --git a/src/xmpp_stream.rs b/src/xmpp_stream.rs index e480ffd..56c778c 100644 --- a/src/xmpp_stream.rs +++ b/src/xmpp_stream.rs @@ -31,6 +31,13 @@ impl XMPPStream { self.stream.into_inner() } + pub fn restart(self) -> StreamStart { + let to = self.stream_attrs.get("from") + .map(|s| s.to_owned()) + .unwrap_or_else(|| "".to_owned()); + Self::from_stream(self.into_inner(), to.clone()) + } + pub fn can_starttls(&self) -> bool { self.stream_features .get_child("starttls", Some(NS_XMPP_TLS))