use std::mem::replace; use futures::{Future, Sink, Poll, Async}; use futures::stream::Stream; use futures::sink; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_tls::*; use native_tls::TlsConnector; use xml; use xmpp_codec::*; use xmpp_stream::*; use stream_start::StreamStart; pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls"; pub struct StartTlsClient { state: StartTlsClientState, domain: String, } enum StartTlsClientState { Invalid, SendStartTls(sink::Send>), AwaitProceed(XMPPStream), StartingTls(ConnectAsync), Start(StreamStart>), } impl StartTlsClient { /// Waits for pub fn from_stream(xmpp_stream: XMPPStream) -> Self { let domain = xmpp_stream.stream_attrs.get("from") .map(|s| s.to_owned()) .unwrap_or_else(|| String::new()); let nonza = xml::Element::new( "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()), vec![] ); println!("send {}", nonza); let packet = Packet::Stanza(nonza); let send = xmpp_stream.send(packet); StartTlsClient { state: StartTlsClientState::SendStartTls(send), domain, } } } impl Future for StartTlsClient { type Item = XMPPStream>; type Error = String; fn poll(&mut self) -> Poll { let old_state = replace(&mut self.state, StartTlsClientState::Invalid); let mut retry = false; let (new_state, result) = match old_state { StartTlsClientState::SendStartTls(mut send) => match send.poll() { Ok(Async::Ready(xmpp_stream)) => { println!("starttls sent"); let new_state = StartTlsClientState::AwaitProceed(xmpp_stream); retry = true; (new_state, Ok(Async::NotReady)) }, Ok(Async::NotReady) => (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)), Err(e) => (StartTlsClientState::SendStartTls(send), Err(format!("{}", e))), }, StartTlsClientState::AwaitProceed(mut xmpp_stream) => match xmpp_stream.poll() { Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) if stanza.name == "proceed" => { println!("* proceed *"); let stream = xmpp_stream.into_inner(); let connect = TlsConnector::builder().unwrap() .build().unwrap() .connect_async(&self.domain, stream); let new_state = StartTlsClientState::StartingTls(connect); retry = true; (new_state, Ok(Async::NotReady)) }, Ok(Async::Ready(value)) => { println!("StartTlsClient ignore {:?}", value); (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)) }, Ok(_) => (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)), Err(e) => (StartTlsClientState::AwaitProceed(xmpp_stream), Err(format!("{}", e))), }, StartTlsClientState::StartingTls(mut connect) => match connect.poll() { Ok(Async::Ready(tls_stream)) => { println!("Got a TLS stream!"); let start = XMPPStream::from_stream(tls_stream, self.domain.clone()); let new_state = StartTlsClientState::Start(start); retry = true; (new_state, Ok(Async::NotReady)) }, Ok(Async::NotReady) => (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)), Err(e) => (StartTlsClientState::StartingTls(connect), Err(format!("{}", e))), }, StartTlsClientState::Start(mut start) => match start.poll() { Ok(Async::Ready(xmpp_stream)) => (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream))), Ok(Async::NotReady) => (StartTlsClientState::Start(start), Ok(Async::NotReady)), Err(e) => (StartTlsClientState::Invalid, Err(format!("{}", e))), }, StartTlsClientState::Invalid => unreachable!(), }; self.state = new_state; if retry { self.poll() } else { result } } }