From 52c60229e33a82b85ba345e564c3855c08c2ccb5 Mon Sep 17 00:00:00 2001 From: Astro Date: Tue, 6 Jun 2017 01:38:48 +0200 Subject: [PATCH] client_auth: add stream restart --- src/client_auth.rs | 19 ++++++++++++++++++- src/xmpp_stream.rs | 7 +++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/client_auth.rs b/src/client_auth.rs index ecb0b78f..5ae57b76 100644 --- a/src/client_auth.rs +++ b/src/client_auth.rs @@ -11,6 +11,7 @@ use serialize::base64::{self, ToBase64, FromBase64}; use xmpp_codec::*; use xmpp_stream::*; +use stream_start::*; const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; @@ -22,6 +23,7 @@ pub struct ClientAuth { enum ClientAuthState { WaitSend(sink::Send>), WaitRecv(XMPPStream), + Start(StreamStart), Invalid, } @@ -124,7 +126,11 @@ impl Future for ClientAuth { Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) if stanza.name == "success" && stanza.ns == Some(NS_XMPP_SASL.to_owned()) => - Ok(Async::Ready(stream)), + { + let start = stream.restart(); + self.state = ClientAuthState::Start(start); + self.poll() + }, Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) if stanza.name == "failure" && stanza.ns == Some(NS_XMPP_SASL.to_owned()) => @@ -153,6 +159,17 @@ impl Future for ClientAuth { Err(e) => Err(format!("{}", e)), }, + ClientAuthState::Start(mut start) => + match start.poll() { + Ok(Async::Ready(stream)) => + Ok(Async::Ready(stream)), + Ok(Async::NotReady) => { + self.state = ClientAuthState::Start(start); + Ok(Async::NotReady) + }, + Err(e) => + Err(format!("{}", e)), + }, ClientAuthState::Invalid => unreachable!(), } diff --git a/src/xmpp_stream.rs b/src/xmpp_stream.rs index e480ffdb..56c778c4 100644 --- a/src/xmpp_stream.rs +++ b/src/xmpp_stream.rs @@ -31,6 +31,13 @@ impl XMPPStream { self.stream.into_inner() } + pub fn restart(self) -> StreamStart { + let to = self.stream_attrs.get("from") + .map(|s| s.to_owned()) + .unwrap_or_else(|| "".to_owned()); + Self::from_stream(self.into_inner(), to.clone()) + } + pub fn can_starttls(&self) -> bool { self.stream_features .get_child("starttls", Some(NS_XMPP_TLS))