diff --git a/Cargo.toml b/Cargo.toml index 23d59db..39de218 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,3 +9,5 @@ tokio-core = "*" tokio-io = "*" bytes = "*" RustyXML = "*" +rustls = "*" +tokio-rustls = "*" diff --git a/examples/echo_bot.rs b/examples/echo_bot.rs index f1e4aff..6933e1a 100644 --- a/examples/echo_bot.rs +++ b/examples/echo_bot.rs @@ -1,10 +1,15 @@ extern crate futures; extern crate tokio_core; extern crate tokio_xmpp; +extern crate rustls; +use std::sync::Arc; +use std::io::BufReader; +use std::fs::File; use tokio_core::reactor::Core; use futures::{Future, Stream}; -use tokio_xmpp::{Packet, TcpClient}; +use tokio_xmpp::{Packet, TcpClient, StartTlsClient}; +use rustls::ClientConfig; fn main() { use std::net::ToSocketAddrs; @@ -12,10 +17,16 @@ fn main() { .to_socket_addrs().unwrap() .next().unwrap(); + let mut config = ClientConfig::new(); + let mut certfile = BufReader::new(File::open("/usr/share/ca-certificates/CAcert/root.crt").unwrap()); + config.root_store.add_pem_file(&mut certfile).unwrap(); + let arc_config = Arc::new(config); + let mut core = Core::new().unwrap(); let client = TcpClient::connect( &addr, &core.handle() + ).and_then(|stream| StartTlsClient::from_stream(stream, arc_config) ).and_then(|stream| { stream.for_each(|event| { match event { @@ -25,5 +36,11 @@ fn main() { Ok(()) }) }); - core.run(client).unwrap(); + match core.run(client) { + Ok(_) => (), + Err(e) => { + println!("Fatal: {}", e); + () + } + } } diff --git a/src/lib.rs b/src/lib.rs index 20ded2e..bcf6665 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,12 +4,16 @@ extern crate tokio_core; extern crate tokio_io; extern crate bytes; extern crate xml; +extern crate rustls; +extern crate tokio_rustls; mod xmpp_codec; pub use xmpp_codec::*; mod tcp; pub use tcp::*; +mod starttls; +pub use starttls::*; // type FullClient = sasl::Client> diff --git a/src/starttls.rs b/src/starttls.rs new file mode 100644 index 0000000..b14982a --- /dev/null +++ b/src/starttls.rs @@ -0,0 +1,142 @@ +use std::mem::replace; +use std::io::{Error, ErrorKind}; +use std::sync::Arc; +use futures::{Future, Sink, Poll, Async}; +use futures::stream::Stream; +use futures::sink; +use tokio_core::net::TcpStream; +use rustls::*; +use tokio_rustls::*; +use xml; + +use super::{XMPPStream, XMPPCodec, Packet}; + + +const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams"; +const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls"; + +pub struct StartTlsClient { + state: StartTlsClientState, + arc_config: Arc, +} + +enum StartTlsClientState { + Invalid, + AwaitFeatures(XMPPStream), + SendStartTls(sink::Send>), + AwaitProceed(XMPPStream), + StartingTls(ConnectAsync), +} + +impl StartTlsClient { + /// Waits for + pub fn from_stream(xmpp_stream: XMPPStream, arc_config: Arc) -> Self { + StartTlsClient { + state: StartTlsClientState::AwaitFeatures(xmpp_stream), + arc_config: arc_config, + } + } +} + +// TODO: eval , check ns +impl Future for StartTlsClient { + type Item = XMPPStream>; + type Error = Error; + + fn poll(&mut self) -> Poll { + let old_state = replace(&mut self.state, StartTlsClientState::Invalid); + let mut retry = false; + + let (new_state, result) = match old_state { + StartTlsClientState::AwaitFeatures(mut xmpp_stream) => + match xmpp_stream.poll() { + Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) + if stanza.name == "features" + && stanza.ns == Some(NS_XMPP_STREAM.to_owned()) + => + { + println!("Got features: {}", stanza); + match stanza.get_child("starttls", Some(NS_XMPP_TLS)) { + None => + (StartTlsClientState::Invalid, Err(Error::from(ErrorKind::InvalidData))), + Some(_) => { + let nonza = xml::Element::new( + "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()), + vec![] + ); + println!("send {}", nonza); + let packet = Packet::Stanza(nonza); + let send = xmpp_stream.send(packet); + let new_state = StartTlsClientState::SendStartTls(send); + retry = true; + (new_state, Ok(Async::NotReady)) + }, + } + }, + Ok(Async::Ready(value)) => { + println!("StartTlsClient ignore {:?}", value); + (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady)) + }, + Ok(_) => + (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady)), + Err(e) => + (StartTlsClientState::AwaitFeatures(xmpp_stream), Err(e)), + }, + StartTlsClientState::SendStartTls(mut send) => + match send.poll() { + Ok(Async::Ready(xmpp_stream)) => { + println!("starttls sent"); + let new_state = StartTlsClientState::AwaitProceed(xmpp_stream); + retry = true; + (new_state, Ok(Async::NotReady)) + }, + Ok(Async::NotReady) => + (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)), + Err(e) => + (StartTlsClientState::SendStartTls(send), Err(e)), + }, + StartTlsClientState::AwaitProceed(mut xmpp_stream) => + match xmpp_stream.poll() { + Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) + if stanza.name == "proceed" => + { + println!("* proceed *"); + let stream = xmpp_stream.into_inner(); + let connect = self.arc_config.connect_async("spaceboyz.net", stream); + let new_state = StartTlsClientState::StartingTls(connect); + retry = true; + (new_state, Ok(Async::NotReady)) + }, + Ok(Async::Ready(value)) => { + println!("StartTlsClient ignore {:?}", value); + (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady)) + }, + Ok(_) => + (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)), + Err(e) => + (StartTlsClientState::AwaitProceed(xmpp_stream), Err(e)), + }, + StartTlsClientState::StartingTls(mut connect) => + match connect.poll() { + Ok(Async::Ready(tls_stream)) => { + println!("Got a TLS stream!"); + let xmpp_stream = XMPPCodec::frame_stream(tls_stream); + (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream))) + }, + Ok(Async::NotReady) => + (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)), + Err(e) => + (StartTlsClientState::StartingTls(connect), Err(e)), + }, + StartTlsClientState::Invalid => + unreachable!(), + }; + + self.state = new_state; + if retry { + self.poll() + } else { + result + } + } +} diff --git a/src/tcp.rs b/src/tcp.rs index 01b9db0..29608ea 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -5,7 +5,6 @@ use futures::{Future, Sink, Poll, Async}; use futures::stream::Stream; use futures::sink; use tokio_core::reactor::Handle; -use tokio_io::AsyncRead; use tokio_core::net::{TcpStream, TcpStreamNew}; use super::{XMPPStream, XMPPCodec, Packet}; @@ -53,7 +52,7 @@ impl Future for TcpClient { let (new_state, result) = match self.state { TcpClientState::Connecting(ref mut tcp_stream_new) => { let tcp_stream = try_ready!(tcp_stream_new.poll()); - let xmpp_stream = AsyncRead::framed(tcp_stream, XMPPCodec::new()); + let xmpp_stream = XMPPCodec::frame_stream(tcp_stream); let send = xmpp_stream.send(Packet::StreamStart); let new_state = TcpClientState::SendStart(send); (new_state, Ok(Async::NotReady)) diff --git a/src/xmpp_codec.rs b/src/xmpp_codec.rs index 28c0dfb..38c593e 100644 --- a/src/xmpp_codec.rs +++ b/src/xmpp_codec.rs @@ -3,6 +3,7 @@ use std::fmt::Write; use std::str::from_utf8; use std::io::{Error, ErrorKind}; use std::collections::HashMap; +use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::codec::{Framed, Encoder, Decoder}; use xml; use bytes::*; @@ -67,6 +68,12 @@ impl XMPPCodec { root: None, } } + + pub fn frame_stream(stream: S) -> Framed + where S: AsyncRead + AsyncWrite + { + AsyncRead::framed(stream, XMPPCodec::new()) + } } impl Decoder for XMPPCodec { @@ -146,6 +153,9 @@ impl Encoder for XMPPCodec { NS_CLIENT, NS_STREAMS) .map_err(|_| Error::from(ErrorKind::WriteZero)) }, + Packet::Stanza(stanza) => + write!(dst, "{}", stanza) + .map_err(|_| Error::from(ErrorKind::InvalidInput)), // TODO: Implement all _ => Ok(()) }