client_auth: add stream restart

This commit is contained in:
Astro 2017-06-06 01:38:48 +02:00
parent f8de49569f
commit 52c60229e3
2 changed files with 25 additions and 1 deletions

View file

@ -11,6 +11,7 @@ use serialize::base64::{self, ToBase64, FromBase64};
use xmpp_codec::*; use xmpp_codec::*;
use xmpp_stream::*; use xmpp_stream::*;
use stream_start::*;
const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
@ -22,6 +23,7 @@ pub struct ClientAuth<S: AsyncWrite> {
enum ClientAuthState<S: AsyncWrite> { enum ClientAuthState<S: AsyncWrite> {
WaitSend(sink::Send<XMPPStream<S>>), WaitSend(sink::Send<XMPPStream<S>>),
WaitRecv(XMPPStream<S>), WaitRecv(XMPPStream<S>),
Start(StreamStart<S>),
Invalid, Invalid,
} }
@ -124,7 +126,11 @@ impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
if stanza.name == "success" if stanza.name == "success"
&& stanza.ns == Some(NS_XMPP_SASL.to_owned()) => && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
Ok(Async::Ready(stream)), {
let start = stream.restart();
self.state = ClientAuthState::Start(start);
self.poll()
},
Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
if stanza.name == "failure" if stanza.name == "failure"
&& stanza.ns == Some(NS_XMPP_SASL.to_owned()) => && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
@ -153,6 +159,17 @@ impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
Err(e) => Err(e) =>
Err(format!("{}", e)), Err(format!("{}", e)),
}, },
ClientAuthState::Start(mut start) =>
match start.poll() {
Ok(Async::Ready(stream)) =>
Ok(Async::Ready(stream)),
Ok(Async::NotReady) => {
self.state = ClientAuthState::Start(start);
Ok(Async::NotReady)
},
Err(e) =>
Err(format!("{}", e)),
},
ClientAuthState::Invalid => ClientAuthState::Invalid =>
unreachable!(), unreachable!(),
} }

View file

@ -31,6 +31,13 @@ impl<S: AsyncRead + AsyncWrite> XMPPStream<S> {
self.stream.into_inner() self.stream.into_inner()
} }
pub fn restart(self) -> StreamStart<S> {
let to = self.stream_attrs.get("from")
.map(|s| s.to_owned())
.unwrap_or_else(|| "".to_owned());
Self::from_stream(self.into_inner(), to.clone())
}
pub fn can_starttls(&self) -> bool { pub fn can_starttls(&self) -> bool {
self.stream_features self.stream_features
.get_child("starttls", Some(NS_XMPP_TLS)) .get_child("starttls", Some(NS_XMPP_TLS))