tokio-xmpp client: condense fn connect(), refactor out into stream_features

This commit is contained in:
Astro 2020-03-18 01:12:48 +01:00
parent 4d24e6bebb
commit c13712b158
6 changed files with 100 additions and 81 deletions

View file

@ -13,8 +13,6 @@ use crate::xmpp_codec::Packet;
use crate::xmpp_stream::XMPPStream; use crate::xmpp_stream::XMPPStream;
use crate::{AuthError, Error, ProtocolError}; use crate::{AuthError, Error, ProtocolError};
const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>( pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
mut stream: XMPPStream<S>, mut stream: XMPPStream<S>,
creds: Credentials, creds: Credentials,
@ -28,11 +26,7 @@ pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
let remote_mechs: HashSet<String> = stream let remote_mechs: HashSet<String> = stream
.stream_features .stream_features
.get_child("mechanisms", NS_XMPP_SASL) .sasl_mechanisms()?
.ok_or(AuthError::NoMechanism)?
.children()
.filter(|child| child.is("mechanism", NS_XMPP_SASL))
.map(|mech_el| mech_el.text())
.collect(); .collect();
for local_mech in local_mechs { for local_mech in local_mechs {

View file

@ -10,46 +10,42 @@ use crate::xmpp_codec::Packet;
use crate::xmpp_stream::XMPPStream; use crate::xmpp_stream::XMPPStream;
use crate::{Error, ProtocolError}; use crate::{Error, ProtocolError};
const NS_XMPP_BIND: &str = "urn:ietf:params:xml:ns:xmpp-bind";
const BIND_REQ_ID: &str = "resource-bind"; const BIND_REQ_ID: &str = "resource-bind";
pub async fn bind<S: AsyncRead + AsyncWrite + Unpin>( pub async fn bind<S: AsyncRead + AsyncWrite + Unpin>(
mut stream: XMPPStream<S>, mut stream: XMPPStream<S>,
) -> Result<XMPPStream<S>, Error> { ) -> Result<XMPPStream<S>, Error> {
match stream.stream_features.get_child("bind", NS_XMPP_BIND) { if stream.stream_features.can_bind() {
None => { let resource = if let Jid::Full(jid) = stream.jid.clone() {
// No resource binding available, Some(jid.resource)
// return the (probably // usable) stream immediately } else {
return Ok(stream); None
} };
Some(_) => { let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource));
let resource = if let Jid::Full(jid) = stream.jid.clone() { stream.send_stanza(iq).await?;
Some(jid.resource)
} else {
None
};
let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource));
stream.send_stanza(iq).await?;
loop { loop {
match stream.next().await { match stream.next().await {
Some(Ok(Packet::Stanza(stanza))) => match Iq::try_from(stanza) { Some(Ok(Packet::Stanza(stanza))) => match Iq::try_from(stanza) {
Ok(iq) if iq.id == BIND_REQ_ID => match iq.payload { Ok(iq) if iq.id == BIND_REQ_ID => match iq.payload {
IqType::Result(payload) => { IqType::Result(payload) => {
payload payload
.and_then(|payload| BindResponse::try_from(payload).ok()) .and_then(|payload| BindResponse::try_from(payload).ok())
.map(|bind| stream.jid = bind.into()); .map(|bind| stream.jid = bind.into());
return Ok(stream); return Ok(stream);
} }
_ => return Err(ProtocolError::InvalidBindResponse.into()), _ => return Err(ProtocolError::InvalidBindResponse.into()),
},
_ => {}
}, },
Some(Ok(_)) => {} _ => {}
Some(Err(e)) => return Err(e), },
None => return Err(Error::Disconnected), Some(Ok(_)) => {}
} Some(Err(e)) => return Err(e),
None => return Err(Error::Disconnected),
} }
} }
} else {
// No resource binding available,
// return the (probably // usable) stream immediately
return Ok(stream);
} }
} }

View file

@ -5,7 +5,6 @@ use std::mem::replace;
use std::pin::Pin; use std::pin::Pin;
use std::str::FromStr; use std::str::FromStr;
use std::task::Context; use std::task::Context;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio::task::LocalSet; use tokio::task::LocalSet;
@ -14,13 +13,18 @@ use xmpp_parsers::{Element, Jid, JidParseError};
use super::event::Event; use super::event::Event;
use super::happy_eyeballs::connect; use super::happy_eyeballs::connect;
use super::starttls::{starttls, NS_XMPP_TLS}; use super::starttls::starttls;
use super::xmpp_codec::Packet; use super::xmpp_codec::Packet;
use super::xmpp_stream; use super::xmpp_stream;
use super::{Error, ProtocolError}; use super::{Error, ProtocolError};
mod auth; mod auth;
use auth::auth;
mod bind; mod bind;
use bind::bind;
pub const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
pub const NS_XMPP_BIND: &str = "urn:ietf:params:xml:ns:xmpp-bind";
/// XMPP client connection and state /// XMPP client connection and state
/// ///
@ -79,56 +83,34 @@ impl Client {
let password = password; let password = password;
let domain = idna::domain_to_ascii(&jid.clone().domain()).map_err(|_| Error::Idna)?; let domain = idna::domain_to_ascii(&jid.clone().domain()).map_err(|_| Error::Idna)?;
// TCP connection
let tcp_stream = connect(&domain, Some("_xmpp-client._tcp"), 5222).await?; let tcp_stream = connect(&domain, Some("_xmpp-client._tcp"), 5222).await?;
// Unencryped XMPPStream
let xmpp_stream = let xmpp_stream =
xmpp_stream::XMPPStream::start(tcp_stream, jid, NS_JABBER_CLIENT.to_owned()).await?; xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), NS_JABBER_CLIENT.to_owned()).await?;
let xmpp_stream = if Self::can_starttls(&xmpp_stream) {
Self::starttls(xmpp_stream).await? let xmpp_stream = if xmpp_stream.stream_features.can_starttls() {
// TlsStream
let tls_stream = starttls(xmpp_stream).await?;
// Encrypted XMPPStream
xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), NS_JABBER_CLIENT.to_owned()).await?
} else { } else {
return Err(Error::Protocol(ProtocolError::NoTls)); return Err(Error::Protocol(ProtocolError::NoTls));
}; };
let xmpp_stream = Self::auth(xmpp_stream, username, password).await?;
let xmpp_stream = Self::bind(xmpp_stream).await?;
Ok(xmpp_stream)
}
fn can_starttls<S: AsyncRead + AsyncWrite + Unpin>(
xmpp_stream: &xmpp_stream::XMPPStream<S>,
) -> bool {
xmpp_stream
.stream_features
.get_child("starttls", NS_XMPP_TLS)
.is_some()
}
async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
xmpp_stream: xmpp_stream::XMPPStream<S>,
) -> Result<xmpp_stream::XMPPStream<TlsStream<S>>, Error> {
let jid = xmpp_stream.jid.clone();
let tls_stream = starttls(xmpp_stream).await?;
xmpp_stream::XMPPStream::start(tls_stream, jid, NS_JABBER_CLIENT.to_owned()).await
}
async fn auth<S: AsyncRead + AsyncWrite + Unpin + 'static>(
xmpp_stream: xmpp_stream::XMPPStream<S>,
username: String,
password: String,
) -> Result<xmpp_stream::XMPPStream<S>, Error> {
let jid = xmpp_stream.jid.clone();
let creds = Credentials::default() let creds = Credentials::default()
.with_username(username) .with_username(username)
.with_password(password) .with_password(password)
.with_channel_binding(ChannelBinding::None); .with_channel_binding(ChannelBinding::None);
let stream = auth::auth(xmpp_stream, creds).await?; // Authenticated (unspecified) stream
xmpp_stream::XMPPStream::start(stream, jid, NS_JABBER_CLIENT.to_owned()).await let stream = auth(xmpp_stream, creds).await?;
} // Authenticated XMPPStream
let xmpp_stream = xmpp_stream::XMPPStream::start(stream, jid, NS_JABBER_CLIENT.to_owned()).await?;
async fn bind<S: Unpin + AsyncRead + AsyncWrite>( // XMPPStream bound to user session
stream: xmpp_stream::XMPPStream<S>, let xmpp_stream = bind(xmpp_stream).await?;
) -> Result<xmpp_stream::XMPPStream<S>, Error> { Ok(xmpp_stream)
bind::bind(stream).await
} }
/// Get the client's bound JID (the one reported by the XMPP /// Get the client's bound JID (the one reported by the XMPP

View file

@ -9,6 +9,7 @@ pub use crate::xmpp_codec::Packet;
mod event; mod event;
mod happy_eyeballs; mod happy_eyeballs;
pub mod xmpp_stream; pub mod xmpp_stream;
pub mod stream_features;
pub use crate::event::Event; pub use crate::event::Event;
mod client; mod client;
pub use crate::client::Client; pub use crate::client::Client;

View file

@ -0,0 +1,45 @@
//! Contains wrapper for `<stream:features/>`
use xmpp_parsers::Element;
use crate::starttls::NS_XMPP_TLS;
use crate::client::{NS_XMPP_SASL, NS_XMPP_BIND};
use crate::error::AuthError;
/// Wraps `<stream:features/>`, usually the very first nonza of an
/// XMPPStream.
///
/// TODO: should this rather go into xmpp-parsers, kept in a decoded
/// struct?
pub struct StreamFeatures(pub Element);
impl StreamFeatures {
/// Wrap the nonza
pub fn new(element: Element) -> Self {
StreamFeatures(element)
}
/// Can initiate TLS session with this server?
pub fn can_starttls(&self) -> bool {
self.0
.get_child("starttls", NS_XMPP_TLS)
.is_some()
}
/// Iterate over SASL mechanisms
pub fn sasl_mechanisms<'a>(&'a self) -> Result<impl Iterator<Item=String> + 'a, AuthError> {
Ok(self.0
.get_child("mechanisms", NS_XMPP_SASL)
.ok_or(AuthError::NoMechanism)?
.children()
.filter(|child| child.is("mechanism", NS_XMPP_SASL))
.map(|mech_el| mech_el.text())
)
}
/// Does server support user resource binding?
pub fn can_bind(&self) -> bool {
self.0
.get_child("bind", NS_XMPP_BIND)
.is_some()
}
}

View file

@ -11,6 +11,7 @@ use tokio_util::codec::Framed;
use xmpp_parsers::{Element, Jid}; use xmpp_parsers::{Element, Jid};
use crate::stream_start; use crate::stream_start;
use crate::stream_features::StreamFeatures;
use crate::xmpp_codec::{Packet, XMPPCodec}; use crate::xmpp_codec::{Packet, XMPPCodec};
use crate::Error; use crate::Error;
@ -27,7 +28,7 @@ pub struct XMPPStream<S: AsyncRead + AsyncWrite + Unpin> {
/// Codec instance /// Codec instance
pub stream: Mutex<Framed<S, XMPPCodec>>, pub stream: Mutex<Framed<S, XMPPCodec>>,
/// `<stream:features/>` for XMPP version 1.0 /// `<stream:features/>` for XMPP version 1.0
pub stream_features: Element, pub stream_features: StreamFeatures,
/// Root namespace /// Root namespace
/// ///
/// This is different for either c2s, s2s, or component /// This is different for either c2s, s2s, or component
@ -49,7 +50,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> XMPPStream<S> {
XMPPStream { XMPPStream {
jid, jid,
stream: Mutex::new(stream), stream: Mutex::new(stream),
stream_features, stream_features: StreamFeatures::new(stream_features),
ns, ns,
id, id,
} }