use std::mem::replace; use futures::{Future, Sink, Poll, Async}; use futures::stream::Stream; use futures::sink; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_tls::*; use native_tls::TlsConnector; use minidom::Element; use jid::Jid; use xmpp_codec::*; use xmpp_stream::*; use stream_start::StreamStart; pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls"; pub struct StartTlsClient { state: StartTlsClientState, jid: Jid, } enum StartTlsClientState { Invalid, SendStartTls(sink::Send>), AwaitProceed(XMPPStream), StartingTls(ConnectAsync), Start(StreamStart>), } impl StartTlsClient { /// Waits for pub fn from_stream(xmpp_stream: XMPPStream) -> Self { let jid = xmpp_stream.jid.clone(); let nonza = Element::builder("starttls") .ns(NS_XMPP_TLS) .build(); let packet = Packet::Stanza(nonza); let send = xmpp_stream.send(packet); StartTlsClient { state: StartTlsClientState::SendStartTls(send), jid, } } } impl Future for StartTlsClient { type Item = XMPPStream>; type Error = String; fn poll(&mut self) -> Poll { let old_state = replace(&mut self.state, StartTlsClientState::Invalid); let mut retry = false; let (new_state, result) = match old_state { StartTlsClientState::SendStartTls(mut send) => match send.poll() { Ok(Async::Ready(xmpp_stream)) => { let new_state = StartTlsClientState::AwaitProceed(xmpp_stream); retry = true; (new_state, Ok(Async::NotReady)) }, Ok(Async::NotReady) => (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)), Err(e) => (StartTlsClientState::SendStartTls(send), Err(format!("{}", e))), }, StartTlsClientState::AwaitProceed(mut xmpp_stream) => match xmpp_stream.poll() { Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) if stanza.name() == "proceed" => { let stream = xmpp_stream.stream.into_inner(); let connect = TlsConnector::builder().unwrap() .build().unwrap() .connect_async(&self.jid.domain, stream); let new_state = StartTlsClientState::StartingTls(connect); retry = true; (new_state, Ok(Async::NotReady)) }, Ok(Async::Ready(value)) => { println!("StartTlsClient ignore {:?}", value); (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)) }, Ok(_) => (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)), Err(e) => (StartTlsClientState::AwaitProceed(xmpp_stream), Err(format!("{}", e))), }, StartTlsClientState::StartingTls(mut connect) => match connect.poll() { Ok(Async::Ready(tls_stream)) => { println!("TLS stream established"); let start = XMPPStream::from_stream(tls_stream, self.jid.clone()); let new_state = StartTlsClientState::Start(start); retry = true; (new_state, Ok(Async::NotReady)) }, Ok(Async::NotReady) => (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)), Err(e) => (StartTlsClientState::StartingTls(connect), Err(format!("{}", e))), }, StartTlsClientState::Start(mut start) => match start.poll() { Ok(Async::Ready(xmpp_stream)) => (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream))), Ok(Async::NotReady) => (StartTlsClientState::Start(start), Ok(Async::NotReady)), Err(e) => (StartTlsClientState::Invalid, Err(format!("{}", e))), }, StartTlsClientState::Invalid => unreachable!(), }; self.state = new_state; if retry { self.poll() } else { result } } }