From c32a38874c00e1befdbbb2cd92e6f3501410255f Mon Sep 17 00:00:00 2001 From: Astro Date: Mon, 5 Jun 2017 02:50:22 +0200 Subject: [PATCH] refactor into stream_start + xmpp_stream --- examples/echo_bot.rs | 10 ++++- src/lib.rs | 5 ++- src/starttls.rs | 79 ++++++++++++++------------------- src/stream_start.rs | 102 +++++++++++++++++++++++++++++++++++++++++++ src/tcp.rs | 52 ++++------------------ src/xmpp_codec.rs | 46 +++++++++---------- src/xmpp_stream.rs | 62 ++++++++++++++++++++++++++ 7 files changed, 239 insertions(+), 117 deletions(-) create mode 100644 src/stream_start.rs create mode 100644 src/xmpp_stream.rs diff --git a/examples/echo_bot.rs b/examples/echo_bot.rs index 6933e1a2..0cdd6b07 100644 --- a/examples/echo_bot.rs +++ b/examples/echo_bot.rs @@ -8,7 +8,8 @@ use std::io::BufReader; use std::fs::File; use tokio_core::reactor::Core; use futures::{Future, Stream}; -use tokio_xmpp::{Packet, TcpClient, StartTlsClient}; +use tokio_xmpp::TcpClient; +use tokio_xmpp::xmpp_codec::Packet; use rustls::ClientConfig; fn main() { @@ -26,8 +27,13 @@ fn main() { let client = TcpClient::connect( &addr, &core.handle() - ).and_then(|stream| StartTlsClient::from_stream(stream, arc_config) ).and_then(|stream| { + if stream.can_starttls() { + stream.starttls(arc_config) + } else { + panic!("No STARTTLS") + } + }).and_then(|stream| { stream.for_each(|event| { match event { Packet::Stanza(el) => println!("<< {}", el), diff --git a/src/lib.rs b/src/lib.rs index bcf66651..9764daea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,8 +8,9 @@ extern crate rustls; extern crate tokio_rustls; -mod xmpp_codec; -pub use xmpp_codec::*; +pub mod xmpp_codec; +pub mod xmpp_stream; +mod stream_start; mod tcp; pub use tcp::*; mod starttls; diff --git a/src/starttls.rs b/src/starttls.rs index 59946da3..b0fdc92f 100644 --- a/src/starttls.rs +++ b/src/starttls.rs @@ -1,5 +1,5 @@ use std::mem::replace; -use std::io::{Error, ErrorKind}; +use std::io::Error; use std::sync::Arc; use futures::{Future, Sink, Poll, Async}; use futures::stream::Stream; @@ -9,36 +9,44 @@ use rustls::*; use tokio_rustls::*; use xml; -use super::{XMPPStream, XMPPCodec, Packet}; +use xmpp_codec::*; +use xmpp_stream::*; +use stream_start::StreamStart; -const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams"; -const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls"; +pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls"; -pub struct StartTlsClient { +pub struct StartTlsClient { state: StartTlsClientState, arc_config: Arc, } -enum StartTlsClientState { +enum StartTlsClientState { Invalid, - AwaitFeatures(XMPPStream), SendStartTls(sink::Send>), AwaitProceed(XMPPStream), StartingTls(ConnectAsync), + Start(StreamStart>), } -impl StartTlsClient { +impl StartTlsClient { /// Waits for pub fn from_stream(xmpp_stream: XMPPStream, arc_config: Arc) -> Self { + let nonza = xml::Element::new( + "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()), + vec![] + ); + println!("send {}", nonza); + let packet = Packet::Stanza(nonza); + let send = xmpp_stream.send(packet); + StartTlsClient { - state: StartTlsClientState::AwaitFeatures(xmpp_stream), + state: StartTlsClientState::SendStartTls(send), arc_config: arc_config, } } } -// TODO: eval , check ns impl Future for StartTlsClient { type Item = XMPPStream>; type Error = Error; @@ -48,40 +56,6 @@ impl Future for StartTlsClient { let mut retry = false; let (new_state, result) = match old_state { - StartTlsClientState::AwaitFeatures(mut xmpp_stream) => - match xmpp_stream.poll() { - Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) - if stanza.name == "features" - && stanza.ns == Some(NS_XMPP_STREAM.to_owned()) - => - { - println!("Got features: {}", stanza); - match stanza.get_child("starttls", Some(NS_XMPP_TLS)) { - None => - (StartTlsClientState::Invalid, Err(Error::from(ErrorKind::InvalidData))), - Some(_) => { - let nonza = xml::Element::new( - "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()), - vec![] - ); - println!("send {}", nonza); - let packet = Packet::Stanza(nonza); - let send = xmpp_stream.send(packet); - let new_state = StartTlsClientState::SendStartTls(send); - retry = true; - (new_state, Ok(Async::NotReady)) - }, - } - }, - Ok(Async::Ready(value)) => { - println!("StartTlsClient ignore {:?}", value); - (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady)) - }, - Ok(_) => - (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady)), - Err(e) => - (StartTlsClientState::AwaitFeatures(xmpp_stream), Err(e)), - }, StartTlsClientState::SendStartTls(mut send) => match send.poll() { Ok(Async::Ready(xmpp_stream)) => { @@ -109,7 +83,7 @@ impl Future for StartTlsClient { }, Ok(Async::Ready(value)) => { println!("StartTlsClient ignore {:?}", value); - (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady)) + (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)) }, Ok(_) => (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)), @@ -120,14 +94,25 @@ impl Future for StartTlsClient { match connect.poll() { Ok(Async::Ready(tls_stream)) => { println!("Got a TLS stream!"); - let xmpp_stream = XMPPCodec::frame_stream(tls_stream); - (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream))) + let start = XMPPStream::from_stream(tls_stream, "spaceboyz.net".to_owned()); + let new_state = StartTlsClientState::Start(start); + retry = true; + (new_state, Ok(Async::NotReady)) }, Ok(Async::NotReady) => (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)), Err(e) => (StartTlsClientState::StartingTls(connect), Err(e)), }, + StartTlsClientState::Start(mut start) => + match start.poll() { + Ok(Async::Ready(xmpp_stream)) => + (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream))), + Ok(Async::NotReady) => + (StartTlsClientState::Start(start), Ok(Async::NotReady)), + Err(e) => + (StartTlsClientState::Invalid, Err(e)), + }, StartTlsClientState::Invalid => unreachable!(), }; diff --git a/src/stream_start.rs b/src/stream_start.rs new file mode 100644 index 00000000..480178eb --- /dev/null +++ b/src/stream_start.rs @@ -0,0 +1,102 @@ +use std::mem::replace; +use std::io::{Error, ErrorKind}; +use std::collections::HashMap; +use futures::*; +use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_io::codec::Framed; + +use xmpp_codec::*; +use xmpp_stream::*; + +const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams"; + +pub struct StreamStart { + state: StreamStartState, +} + +enum StreamStartState { + SendStart(sink::Send>), + RecvStart(Framed), + RecvFeatures(Framed, HashMap), + Invalid, +} + +impl StreamStart { + pub fn from_stream(stream: Framed, to: String) -> Self { + let attrs = [("to".to_owned(), to), + ("version".to_owned(), "1.0".to_owned()), + ("xmlns".to_owned(), "jabber:client".to_owned()), + ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()), + ].iter().cloned().collect(); + let send = stream.send(Packet::StreamStart(attrs)); + + StreamStart { + state: StreamStartState::SendStart(send), + } + } +} + +impl Future for StreamStart { + type Item = XMPPStream; + type Error = Error; + + fn poll(&mut self) -> Poll { + let old_state = replace(&mut self.state, StreamStartState::Invalid); + let mut retry = false; + + let (new_state, result) = match old_state { + StreamStartState::SendStart(mut send) => + match send.poll() { + Ok(Async::Ready(stream)) => { + retry = true; + (StreamStartState::RecvStart(stream), Ok(Async::NotReady)) + }, + Ok(Async::NotReady) => + (StreamStartState::SendStart(send), Ok(Async::NotReady)), + Err(e) => + (StreamStartState::Invalid, Err(e)), + }, + StreamStartState::RecvStart(mut stream) => + match stream.poll() { + Ok(Async::Ready(Some(Packet::StreamStart(stream_attrs)))) => { + retry = true; + // TODO: skip RecvFeatures for version < 1.0 + (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady)) + }, + Ok(Async::Ready(_)) => + return Err(Error::from(ErrorKind::InvalidData)), + Ok(Async::NotReady) => + (StreamStartState::RecvStart(stream), Ok(Async::NotReady)), + Err(e) => + return Err(e), + }, + StreamStartState::RecvFeatures(mut stream, stream_attrs) => + match stream.poll() { + Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => + if stanza.name == "features" + && stanza.ns == Some(NS_XMPP_STREAM.to_owned()) { + (StreamStartState::Invalid, Ok(Async::Ready(XMPPStream { stream, stream_attrs, stream_features: stanza }))) + } else { + (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady)) + }, + Ok(Async::Ready(item)) => { + println!("StreamStart skip {:?}", item); + (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady)) + }, + Ok(Async::NotReady) => + (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady)), + Err(e) => + return Err(e), + }, + StreamStartState::Invalid => + unreachable!(), + }; + + self.state = new_state; + if retry { + self.poll() + } else { + result + } + } +} diff --git a/src/tcp.rs b/src/tcp.rs index 29608ea6..57fb0b1f 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -1,40 +1,22 @@ -use std::fmt; use std::net::SocketAddr; -use std::io::{Error, ErrorKind}; -use futures::{Future, Sink, Poll, Async}; -use futures::stream::Stream; -use futures::sink; +use std::io::Error; +use futures::{Future, Poll, Async}; use tokio_core::reactor::Handle; use tokio_core::net::{TcpStream, TcpStreamNew}; -use super::{XMPPStream, XMPPCodec, Packet}; +use xmpp_stream::*; +use stream_start::StreamStart; - -#[derive(Debug)] pub struct TcpClient { state: TcpClientState, } enum TcpClientState { Connecting(TcpStreamNew), - SendStart(sink::Send>), - RecvStart(Option>), + Start(StreamStart), Established, } -impl fmt::Debug for TcpClientState { - fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { - let s = match *self { - TcpClientState::Connecting(_) => "Connecting", - TcpClientState::SendStart(_) => "SendStart", - TcpClientState::RecvStart(_) => "RecvStart", - TcpClientState::Established => "Established", - }; - try!(write!(fmt, "{}", s)); - Ok(()) - } -} - impl TcpClient { pub fn connect(addr: &SocketAddr, handle: &Handle) -> Self { let tcp_stream_new = TcpStream::connect(addr, handle); @@ -52,27 +34,12 @@ impl Future for TcpClient { let (new_state, result) = match self.state { TcpClientState::Connecting(ref mut tcp_stream_new) => { let tcp_stream = try_ready!(tcp_stream_new.poll()); - let xmpp_stream = XMPPCodec::frame_stream(tcp_stream); - let send = xmpp_stream.send(Packet::StreamStart); - let new_state = TcpClientState::SendStart(send); + let start = XMPPStream::from_stream(tcp_stream, "spaceboyz.net".to_owned()); + let new_state = TcpClientState::Start(start); (new_state, Ok(Async::NotReady)) }, - TcpClientState::SendStart(ref mut send) => { - let xmpp_stream = try_ready!(send.poll()); - let new_state = TcpClientState::RecvStart(Some(xmpp_stream)); - (new_state, Ok(Async::NotReady)) - }, - TcpClientState::RecvStart(ref mut opt_xmpp_stream) => { - let mut xmpp_stream = opt_xmpp_stream.take().unwrap(); - match xmpp_stream.poll() { - Ok(Async::Ready(Some(Packet::StreamStart))) => println!("Recv start!"), - Ok(Async::Ready(_)) => return Err(Error::from(ErrorKind::InvalidData)), - Ok(Async::NotReady) => { - *opt_xmpp_stream = Some(xmpp_stream); - return Ok(Async::NotReady); - }, - Err(e) => return Err(e) - }; + TcpClientState::Start(ref mut start) => { + let xmpp_stream = try_ready!(start.poll()); let new_state = TcpClientState::Established; (new_state, Ok(Async::Ready(xmpp_stream))) }, @@ -80,7 +47,6 @@ impl Future for TcpClient { unreachable!(), }; - println!("Next state: {:?}", new_state); self.state = new_state; match result { // by polling again, we register new future diff --git a/src/xmpp_codec.rs b/src/xmpp_codec.rs index 38c593e4..8595dfca 100644 --- a/src/xmpp_codec.rs +++ b/src/xmpp_codec.rs @@ -3,18 +3,17 @@ use std::fmt::Write; use std::str::from_utf8; use std::io::{Error, ErrorKind}; use std::collections::HashMap; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_io::codec::{Framed, Encoder, Decoder}; +use tokio_io::codec::{Encoder, Decoder}; use xml; use bytes::*; const NS_XMLNS: &'static str = "http://www.w3.org/2000/xmlns/"; -const NS_STREAMS: &'static str = "http://etherx.jabber.org/streams"; -const NS_CLIENT: &'static str = "jabber:client"; + +pub type Attributes = HashMap<(String, Option), String>; struct XMPPRoot { builder: xml::ElementBuilder, - pub attributes: HashMap<(String, Option), String>, + pub attributes: Attributes, } impl XMPPRoot { @@ -49,13 +48,11 @@ impl XMPPRoot { #[derive(Debug)] pub enum Packet { Error(Box), - StreamStart, + StreamStart(HashMap), Stanza(xml::Element), StreamEnd, } -pub type XMPPStream = Framed; - pub struct XMPPCodec { parser: xml::Parser, root: Option, @@ -68,12 +65,6 @@ impl XMPPCodec { root: None, } } - - pub fn frame_stream(stream: S) -> Framed - where S: AsyncRead + AsyncWrite - { - AsyncRead::framed(stream, XMPPCodec::new()) - } } impl Decoder for XMPPCodec { @@ -97,8 +88,12 @@ impl Decoder for XMPPCodec { // Expecting match event { Ok(xml::Event::ElementStart(start_tag)) => { + let mut attrs: HashMap = HashMap::new(); + for (&(ref name, _), value) in &start_tag.attributes { + attrs.insert(name.to_owned(), value.to_owned()); + } + result = Some(Packet::StreamStart(attrs)); self.root = Some(XMPPRoot::new(start_tag)); - result = Some(Packet::StreamStart); break }, Err(e) => { @@ -146,18 +141,23 @@ impl Encoder for XMPPCodec { fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { - Packet::StreamStart => { - write!(dst, - "\n -\n", - NS_CLIENT, NS_STREAMS) - .map_err(|_| Error::from(ErrorKind::WriteZero)) + Packet::StreamStart(start_attrs) => { + let mut buf = String::new(); + write!(buf, "\n").unwrap(); + + println!("Encode start to {}", buf); + write!(dst, "{}", buf) }, Packet::Stanza(stanza) => - write!(dst, "{}", stanza) - .map_err(|_| Error::from(ErrorKind::InvalidInput)), + write!(dst, "{}", stanza), // TODO: Implement all _ => Ok(()) } + .map_err(|_| Error::from(ErrorKind::InvalidInput)) } } diff --git a/src/xmpp_stream.rs b/src/xmpp_stream.rs new file mode 100644 index 00000000..c4080bb7 --- /dev/null +++ b/src/xmpp_stream.rs @@ -0,0 +1,62 @@ +use std::sync::Arc; +use std::collections::HashMap; +use futures::*; +use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_io::codec::Framed; +use rustls::ClientConfig; +use xml; + +use xmpp_codec::*; +use stream_start::*; +use starttls::{NS_XMPP_TLS, StartTlsClient}; + +pub const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams"; + +pub struct XMPPStream { + pub stream: Framed, + pub stream_attrs: HashMap, + pub stream_features: xml::Element, +} + +impl XMPPStream { + pub fn from_stream(stream: S, to: String) -> StreamStart { + let xmpp_stream = AsyncRead::framed(stream, XMPPCodec::new()); + StreamStart::from_stream(xmpp_stream, to) + } + + pub fn into_inner(self) -> S { + self.stream.into_inner() + } + + pub fn can_starttls(&self) -> bool { + self.stream_features + .get_child("starttls", Some(NS_XMPP_TLS)) + .is_some() + } + + pub fn starttls(self, arc_config: Arc) -> StartTlsClient { + StartTlsClient::from_stream(self, arc_config) + } +} + +impl Sink for XMPPStream { + type SinkItem = as Sink>::SinkItem; + type SinkError = as Sink>::SinkError; + + fn start_send(&mut self, item: Self::SinkItem) -> StartSend { + self.stream.start_send(item) + } + + fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { + self.stream.poll_complete() + } +} + +impl Stream for XMPPStream { + type Item = as Stream>::Item; + type Error = as Stream>::Error; + + fn poll(&mut self) -> Poll, Self::Error> { + self.stream.poll() + } +}