diff --git a/tokio-xmpp/src/client/auth.rs b/tokio-xmpp/src/client/auth.rs index 28a60dc..d317bba 100644 --- a/tokio-xmpp/src/client/auth.rs +++ b/tokio-xmpp/src/client/auth.rs @@ -13,8 +13,6 @@ use crate::xmpp_codec::Packet; use crate::xmpp_stream::XMPPStream; use crate::{AuthError, Error, ProtocolError}; -const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; - pub async fn auth( mut stream: XMPPStream, creds: Credentials, @@ -28,11 +26,7 @@ pub async fn auth( let remote_mechs: HashSet = stream .stream_features - .get_child("mechanisms", NS_XMPP_SASL) - .ok_or(AuthError::NoMechanism)? - .children() - .filter(|child| child.is("mechanism", NS_XMPP_SASL)) - .map(|mech_el| mech_el.text()) + .sasl_mechanisms()? .collect(); for local_mech in local_mechs { diff --git a/tokio-xmpp/src/client/bind.rs b/tokio-xmpp/src/client/bind.rs index 6331c94..386b4b4 100644 --- a/tokio-xmpp/src/client/bind.rs +++ b/tokio-xmpp/src/client/bind.rs @@ -10,46 +10,42 @@ use crate::xmpp_codec::Packet; use crate::xmpp_stream::XMPPStream; use crate::{Error, ProtocolError}; -const NS_XMPP_BIND: &str = "urn:ietf:params:xml:ns:xmpp-bind"; const BIND_REQ_ID: &str = "resource-bind"; pub async fn bind( mut stream: XMPPStream, ) -> Result, Error> { - match stream.stream_features.get_child("bind", NS_XMPP_BIND) { - None => { - // No resource binding available, - // return the (probably // usable) stream immediately - return Ok(stream); - } - Some(_) => { - let resource = if let Jid::Full(jid) = stream.jid.clone() { - Some(jid.resource) - } else { - None - }; - let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource)); - stream.send_stanza(iq).await?; + if stream.stream_features.can_bind() { + let resource = if let Jid::Full(jid) = stream.jid.clone() { + Some(jid.resource) + } else { + None + }; + let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource)); + stream.send_stanza(iq).await?; - loop { - match stream.next().await { - Some(Ok(Packet::Stanza(stanza))) => match Iq::try_from(stanza) { - Ok(iq) if iq.id == BIND_REQ_ID => match iq.payload { - IqType::Result(payload) => { - payload - .and_then(|payload| BindResponse::try_from(payload).ok()) - .map(|bind| stream.jid = bind.into()); - return Ok(stream); - } - _ => return Err(ProtocolError::InvalidBindResponse.into()), - }, - _ => {} + loop { + match stream.next().await { + Some(Ok(Packet::Stanza(stanza))) => match Iq::try_from(stanza) { + Ok(iq) if iq.id == BIND_REQ_ID => match iq.payload { + IqType::Result(payload) => { + payload + .and_then(|payload| BindResponse::try_from(payload).ok()) + .map(|bind| stream.jid = bind.into()); + return Ok(stream); + } + _ => return Err(ProtocolError::InvalidBindResponse.into()), }, - Some(Ok(_)) => {} - Some(Err(e)) => return Err(e), - None => return Err(Error::Disconnected), - } + _ => {} + }, + Some(Ok(_)) => {} + Some(Err(e)) => return Err(e), + None => return Err(Error::Disconnected), } } + } else { + // No resource binding available, + // return the (probably // usable) stream immediately + return Ok(stream); } } diff --git a/tokio-xmpp/src/client/mod.rs b/tokio-xmpp/src/client/mod.rs index 29e8396..96ca288 100644 --- a/tokio-xmpp/src/client/mod.rs +++ b/tokio-xmpp/src/client/mod.rs @@ -5,7 +5,6 @@ use std::mem::replace; use std::pin::Pin; use std::str::FromStr; use std::task::Context; -use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio::task::JoinHandle; use tokio::task::LocalSet; @@ -14,13 +13,18 @@ use xmpp_parsers::{Element, Jid, JidParseError}; use super::event::Event; use super::happy_eyeballs::connect; -use super::starttls::{starttls, NS_XMPP_TLS}; +use super::starttls::starttls; use super::xmpp_codec::Packet; use super::xmpp_stream; use super::{Error, ProtocolError}; mod auth; +use auth::auth; mod bind; +use bind::bind; + +pub const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; +pub const NS_XMPP_BIND: &str = "urn:ietf:params:xml:ns:xmpp-bind"; /// XMPP client connection and state /// @@ -79,56 +83,34 @@ impl Client { let password = password; let domain = idna::domain_to_ascii(&jid.clone().domain()).map_err(|_| Error::Idna)?; + // TCP connection let tcp_stream = connect(&domain, Some("_xmpp-client._tcp"), 5222).await?; + // Unencryped XMPPStream let xmpp_stream = - xmpp_stream::XMPPStream::start(tcp_stream, jid, NS_JABBER_CLIENT.to_owned()).await?; - let xmpp_stream = if Self::can_starttls(&xmpp_stream) { - Self::starttls(xmpp_stream).await? + xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), NS_JABBER_CLIENT.to_owned()).await?; + + let xmpp_stream = if xmpp_stream.stream_features.can_starttls() { + // TlsStream + let tls_stream = starttls(xmpp_stream).await?; + // Encrypted XMPPStream + xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), NS_JABBER_CLIENT.to_owned()).await? } else { return Err(Error::Protocol(ProtocolError::NoTls)); }; - let xmpp_stream = Self::auth(xmpp_stream, username, password).await?; - let xmpp_stream = Self::bind(xmpp_stream).await?; - Ok(xmpp_stream) - } - - fn can_starttls( - xmpp_stream: &xmpp_stream::XMPPStream, - ) -> bool { - xmpp_stream - .stream_features - .get_child("starttls", NS_XMPP_TLS) - .is_some() - } - - async fn starttls( - xmpp_stream: xmpp_stream::XMPPStream, - ) -> Result>, Error> { - let jid = xmpp_stream.jid.clone(); - let tls_stream = starttls(xmpp_stream).await?; - xmpp_stream::XMPPStream::start(tls_stream, jid, NS_JABBER_CLIENT.to_owned()).await - } - - async fn auth( - xmpp_stream: xmpp_stream::XMPPStream, - username: String, - password: String, - ) -> Result, Error> { - let jid = xmpp_stream.jid.clone(); let creds = Credentials::default() .with_username(username) .with_password(password) .with_channel_binding(ChannelBinding::None); - let stream = auth::auth(xmpp_stream, creds).await?; - xmpp_stream::XMPPStream::start(stream, jid, NS_JABBER_CLIENT.to_owned()).await - } + // Authenticated (unspecified) stream + let stream = auth(xmpp_stream, creds).await?; + // Authenticated XMPPStream + let xmpp_stream = xmpp_stream::XMPPStream::start(stream, jid, NS_JABBER_CLIENT.to_owned()).await?; - async fn bind( - stream: xmpp_stream::XMPPStream, - ) -> Result, Error> { - bind::bind(stream).await + // XMPPStream bound to user session + let xmpp_stream = bind(xmpp_stream).await?; + Ok(xmpp_stream) } /// Get the client's bound JID (the one reported by the XMPP diff --git a/tokio-xmpp/src/lib.rs b/tokio-xmpp/src/lib.rs index 54d4c6f..9a8081a 100644 --- a/tokio-xmpp/src/lib.rs +++ b/tokio-xmpp/src/lib.rs @@ -9,6 +9,7 @@ pub use crate::xmpp_codec::Packet; mod event; mod happy_eyeballs; pub mod xmpp_stream; +pub mod stream_features; pub use crate::event::Event; mod client; pub use crate::client::Client; diff --git a/tokio-xmpp/src/stream_features.rs b/tokio-xmpp/src/stream_features.rs new file mode 100644 index 0000000..32bbd16 --- /dev/null +++ b/tokio-xmpp/src/stream_features.rs @@ -0,0 +1,45 @@ +//! Contains wrapper for `` + +use xmpp_parsers::Element; +use crate::starttls::NS_XMPP_TLS; +use crate::client::{NS_XMPP_SASL, NS_XMPP_BIND}; +use crate::error::AuthError; + +/// Wraps ``, usually the very first nonza of an +/// XMPPStream. +/// +/// TODO: should this rather go into xmpp-parsers, kept in a decoded +/// struct? +pub struct StreamFeatures(pub Element); + +impl StreamFeatures { + /// Wrap the nonza + pub fn new(element: Element) -> Self { + StreamFeatures(element) + } + + /// Can initiate TLS session with this server? + pub fn can_starttls(&self) -> bool { + self.0 + .get_child("starttls", NS_XMPP_TLS) + .is_some() + } + + /// Iterate over SASL mechanisms + pub fn sasl_mechanisms<'a>(&'a self) -> Result + 'a, AuthError> { + Ok(self.0 + .get_child("mechanisms", NS_XMPP_SASL) + .ok_or(AuthError::NoMechanism)? + .children() + .filter(|child| child.is("mechanism", NS_XMPP_SASL)) + .map(|mech_el| mech_el.text()) + ) + } + + /// Does server support user resource binding? + pub fn can_bind(&self) -> bool { + self.0 + .get_child("bind", NS_XMPP_BIND) + .is_some() + } +} diff --git a/tokio-xmpp/src/xmpp_stream.rs b/tokio-xmpp/src/xmpp_stream.rs index 733d342..4cc59ed 100644 --- a/tokio-xmpp/src/xmpp_stream.rs +++ b/tokio-xmpp/src/xmpp_stream.rs @@ -11,6 +11,7 @@ use tokio_util::codec::Framed; use xmpp_parsers::{Element, Jid}; use crate::stream_start; +use crate::stream_features::StreamFeatures; use crate::xmpp_codec::{Packet, XMPPCodec}; use crate::Error; @@ -27,7 +28,7 @@ pub struct XMPPStream { /// Codec instance pub stream: Mutex>, /// `` for XMPP version 1.0 - pub stream_features: Element, + pub stream_features: StreamFeatures, /// Root namespace /// /// This is different for either c2s, s2s, or component @@ -49,7 +50,7 @@ impl XMPPStream { XMPPStream { jid, stream: Mutex::new(stream), - stream_features, + stream_features: StreamFeatures::new(stream_features), ns, id, }