use futures::sink; use futures::stream::Stream; use futures::{Async, Future, Poll, Sink}; use native_tls::TlsConnector as NativeTlsConnector; use std::mem::replace; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_tls::{Connect, TlsConnector, TlsStream}; use xmpp_parsers::{Element, Jid}; use crate::xmpp_codec::Packet; use crate::xmpp_stream::XMPPStream; use crate::Error; /// XMPP TLS XML namespace pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls"; /// XMPP stream that switches to TLS if available in received features pub struct StartTlsClient { state: StartTlsClientState, jid: Jid, } enum StartTlsClientState { Invalid, SendStartTls(sink::Send>), AwaitProceed(XMPPStream), StartingTls(Connect), } impl StartTlsClient { /// Waits for pub fn from_stream(xmpp_stream: XMPPStream) -> Self { let jid = xmpp_stream.jid.clone(); let nonza = Element::builder("starttls").ns(NS_XMPP_TLS).build(); let packet = Packet::Stanza(nonza); let send = xmpp_stream.send(packet); StartTlsClient { state: StartTlsClientState::SendStartTls(send), jid, } } } impl Future for StartTlsClient { type Item = TlsStream; type Error = Error; fn poll(&mut self) -> Poll { let old_state = replace(&mut self.state, StartTlsClientState::Invalid); let mut retry = false; let (new_state, result) = match old_state { StartTlsClientState::SendStartTls(mut send) => match send.poll() { Ok(Async::Ready(xmpp_stream)) => { let new_state = StartTlsClientState::AwaitProceed(xmpp_stream); retry = true; (new_state, Ok(Async::NotReady)) } Ok(Async::NotReady) => { (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)) } Err(e) => (StartTlsClientState::SendStartTls(send), Err(e.into())), }, StartTlsClientState::AwaitProceed(mut xmpp_stream) => match xmpp_stream.poll() { Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) if stanza.name() == "proceed" => { let stream = xmpp_stream.stream.into_inner(); let connect = TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) .connect(&self.jid.clone().domain(), stream); let new_state = StartTlsClientState::StartingTls(connect); retry = true; (new_state, Ok(Async::NotReady)) } Ok(Async::Ready(_value)) => { // println!("StartTlsClient ignore {:?}", _value); ( StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady), ) } Ok(_) => ( StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady), ), Err(e) => ( StartTlsClientState::AwaitProceed(xmpp_stream), Err(Error::Protocol(e.into())), ), }, StartTlsClientState::StartingTls(mut connect) => match connect.poll() { Ok(Async::Ready(tls_stream)) => { (StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream))) } Ok(Async::NotReady) => ( StartTlsClientState::StartingTls(connect), Ok(Async::NotReady), ), Err(e) => (StartTlsClientState::Invalid, Err(e.into())), }, StartTlsClientState::Invalid => unreachable!(), }; self.state = new_state; if retry { self.poll() } else { result } } }