diff --git a/tokio-xmpp/Cargo.toml b/tokio-xmpp/Cargo.toml index a984e63..1d70da6 100644 --- a/tokio-xmpp/Cargo.toml +++ b/tokio-xmpp/Cargo.toml @@ -14,17 +14,12 @@ edition = "2021" [dependencies] bytes = "1" futures = "0.3" -idna = "0.4" log = "0.4" -native-tls = { version = "0.2", optional = true } tokio = { version = "1", features = ["net", "rt", "rt-multi-thread", "macros"] } -tokio-native-tls = { version = "0.3", optional = true } -tokio-rustls = { version = "0.24", optional = true } tokio-stream = { version = "0.1", features = [] } tokio-util = { version = "0.7", features = ["codec"] } -hickory-resolver = "0.24" -rxml = "0.9.1" webpki-roots = { version = "0.25", optional = true } +rxml = "0.9.1" rand = "^0.8" syntect = { version = "5", optional = true } # same repository dependencies @@ -32,11 +27,21 @@ minidom = { version = "0.15", path = "../minidom" } sasl = { version = "0.5", path = "../sasl" } xmpp-parsers = { version = "0.20", path = "../parsers" } +# these are only needed for starttls ServerConnector support +hickory-resolver = { version = "0.24", optional = true} +idna = { version = "0.4", optional = true} +native-tls = { version = "0.2", optional = true } +tokio-native-tls = { version = "0.3", optional = true } +tokio-rustls = { version = "0.24", optional = true } + [dev-dependencies] env_logger = { version = "0.10", default-features = false, features = ["auto-color", "humantime"] } [features] -default = ["tls-native"] +default = ["starttls-rust"] +starttls = ["hickory-resolver", "idna"] tls-rust = ["tokio-rustls", "webpki-roots"] tls-native = ["tokio-native-tls", "native-tls"] +starttls-native = ["starttls", "tls-native"] +starttls-rust = ["starttls", "tls-rust"] syntax-highlighting = ["syntect"] diff --git a/tokio-xmpp/src/client/async_client.rs b/tokio-xmpp/src/client/async_client.rs index 0990d3d..34171d6 100644 --- a/tokio-xmpp/src/client/async_client.rs +++ b/tokio-xmpp/src/client/async_client.rs @@ -1,23 +1,16 @@ use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream}; -use sasl::common::ChannelBinding; use std::mem::replace; use std::pin::Pin; use std::task::Context; -use tokio::net::TcpStream; use tokio::task::JoinHandle; use xmpp_parsers::{ns, Element, Jid}; -use super::connect::{AsyncReadAndWrite, ServerConnector}; +use super::connect::client_login; +use crate::connect::{AsyncReadAndWrite, ServerConnector}; use crate::event::Event; -use crate::happy_eyeballs::{connect_to_host, connect_with_srv}; -use crate::starttls::starttls; use crate::xmpp_codec::Packet; -use crate::xmpp_stream::{self, add_stanza_id, XMPPStream}; -use crate::{client_login, Error, ProtocolError}; -#[cfg(feature = "tls-native")] -use tokio_native_tls::TlsStream; -#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] -use tokio_rustls::client::TlsStream; +use crate::xmpp_stream::{add_stanza_id, XMPPStream}; +use crate::{Error, ProtocolError}; /// XMPP client connection and state /// @@ -43,76 +36,6 @@ pub struct Config { pub server: C, } -/// XMPP server connection configuration -#[derive(Clone, Debug)] -pub enum ServerConfig { - /// Use SRV record to find server host - UseSrv, - #[allow(unused)] - /// Manually define server host and port - Manual { - /// Server host name - host: String, - /// Server port - port: u16, - }, -} - -impl ServerConnector for ServerConfig { - type Stream = TlsStream; - async fn connect(&self, jid: &Jid) -> Result, Error> { - // TCP connection - let tcp_stream = match self { - ServerConfig::UseSrv => { - connect_with_srv(jid.domain_str(), "_xmpp-client._tcp", 5222).await? - } - ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?, - }; - - // Unencryped XMPPStream - let xmpp_stream = - xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), ns::JABBER_CLIENT.to_owned()) - .await?; - - if xmpp_stream.stream_features.can_starttls() { - // TlsStream - let tls_stream = starttls(xmpp_stream).await?; - // Encrypted XMPPStream - xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), ns::JABBER_CLIENT.to_owned()) - .await - } else { - return Err(Error::Protocol(ProtocolError::NoTls)); - } - } - - fn channel_binding( - #[allow(unused_variables)] stream: &Self::Stream, - ) -> Result { - #[cfg(feature = "tls-native")] - { - log::warn!("tls-native doesn’t support channel binding, please use tls-rust if you want this feature!"); - Ok(ChannelBinding::None) - } - #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] - { - let (_, connection) = stream.get_ref(); - Ok(match connection.protocol_version() { - // TODO: Add support for TLS 1.2 and earlier. - Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => { - let data = vec![0u8; 32]; - let data = connection.export_keying_material( - data, - b"EXPORTER-Channel-Binding", - None, - )?; - ChannelBinding::TlsExporter(data) - } - _ => ChannelBinding::None, - }) - } - } -} - enum ClientState { Invalid, Disconnected, @@ -120,21 +43,6 @@ enum ClientState { Connected(XMPPStream), } -impl Client { - /// Start a new XMPP client - /// - /// Start polling the returned instance so that it will connect - /// and yield events. - pub fn new, P: Into>(jid: J, password: P) -> Self { - let config = Config { - jid: jid.into(), - password: password.into(), - server: ServerConfig::UseSrv, - }; - Self::new_with_config(config) - } -} - impl Client { /// Start a new client given that the JID is already parsed. pub fn new_with_config(config: Config) -> Self { diff --git a/tokio-xmpp/src/client/connect.rs b/tokio-xmpp/src/client/connect.rs index 1302b6a..d34929c 100644 --- a/tokio-xmpp/src/client/connect.rs +++ b/tokio-xmpp/src/client/connect.rs @@ -1,32 +1,11 @@ -use sasl::common::{ChannelBinding, Credentials}; -use tokio::io::{AsyncRead, AsyncWrite}; +use sasl::common::Credentials; use xmpp_parsers::{ns, Jid}; -use super::{auth::auth, bind::bind}; +use crate::client::auth::auth; +use crate::client::bind::bind; +use crate::connect::ServerConnector; use crate::{xmpp_stream::XMPPStream, Error}; -/// trait returned wrapped in XMPPStream by ServerConnector -pub trait AsyncReadAndWrite: AsyncRead + AsyncWrite + Unpin + Send {} -impl AsyncReadAndWrite for T {} - -/// Trait called to connect to an XMPP server, perhaps called multiple times -pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static { - /// The type of Stream this ServerConnector produces - type Stream: AsyncReadAndWrite; - /// This must return the connection ready to login, ie if starttls is involved, after TLS has been started, and then after the impl std::future::Future, Error>> + Send; - - /// Return channel binding data if available - /// do not fail if channel binding is simply unavailable, just return Ok(None) - /// this should only be called after the TLS handshake is finished - fn channel_binding(_stream: &Self::Stream) -> Result { - Ok(ChannelBinding::None) - } -} - /// Log into an XMPP server as a client with a jid+pass /// does channel binding if supported pub async fn client_login( @@ -37,7 +16,7 @@ pub async fn client_login( let username = jid.node_str().unwrap(); let password = password; - let xmpp_stream = server.connect(&jid).await?; + let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT).await?; let channel_binding = C::channel_binding(xmpp_stream.stream.get_ref())?; diff --git a/tokio-xmpp/src/client/mod.rs b/tokio-xmpp/src/client/mod.rs index 4910eec..b664785 100644 --- a/tokio-xmpp/src/client/mod.rs +++ b/tokio-xmpp/src/client/mod.rs @@ -1,6 +1,7 @@ mod auth; mod bind; +pub(crate) mod connect; + pub mod async_client; -pub mod connect; pub mod simple_client; diff --git a/tokio-xmpp/src/client/simple_client.rs b/tokio-xmpp/src/client/simple_client.rs index cd5d4ca..3f2b07c 100644 --- a/tokio-xmpp/src/client/simple_client.rs +++ b/tokio-xmpp/src/client/simple_client.rs @@ -1,13 +1,15 @@ use futures::{sink::SinkExt, Sink, Stream}; use std::pin::Pin; -use std::str::FromStr; use std::task::{Context, Poll}; use tokio_stream::StreamExt; use xmpp_parsers::{ns, Element, Jid}; +use crate::connect::ServerConnector; use crate::xmpp_codec::Packet; use crate::xmpp_stream::{add_stanza_id, XMPPStream}; -use crate::{client_login, AsyncServerConfig, Error, ServerConnector}; +use crate::Error; + +use super::connect::client_login; /// A simple XMPP client connection /// @@ -17,19 +19,6 @@ pub struct Client { stream: XMPPStream, } -impl Client { - /// Start a new XMPP client and wait for a usable session - pub async fn new>(jid: &str, password: P) -> Result { - let jid = Jid::from_str(jid)?; - Self::new_with_jid(jid, password.into()).await - } - - /// Start a new client given that the JID is already parsed. - pub async fn new_with_jid(jid: Jid, password: String) -> Result { - Self::new_with_jid_connector(AsyncServerConfig::UseSrv, jid, password).await - } -} - impl Client { /// Start a new client given that the JID is already parsed. pub async fn new_with_jid_connector( diff --git a/tokio-xmpp/src/component/connect.rs b/tokio-xmpp/src/component/connect.rs new file mode 100644 index 0000000..509172e --- /dev/null +++ b/tokio-xmpp/src/component/connect.rs @@ -0,0 +1,18 @@ +use xmpp_parsers::{ns, Jid}; + +use crate::connect::ServerConnector; +use crate::{xmpp_stream::XMPPStream, Error}; + +use super::auth::auth; + +/// Log into an XMPP server as a client with a jid+pass +pub async fn component_login( + connector: C, + jid: Jid, + password: String, +) -> Result, Error> { + let password = password; + let mut xmpp_stream = connector.connect(&jid, ns::COMPONENT).await?; + auth(&mut xmpp_stream, password).await?; + Ok(xmpp_stream) +} diff --git a/tokio-xmpp/src/component/mod.rs b/tokio-xmpp/src/component/mod.rs index dba79fa..d2b4bfa 100644 --- a/tokio-xmpp/src/component/mod.rs +++ b/tokio-xmpp/src/component/mod.rs @@ -5,53 +5,39 @@ use futures::{sink::SinkExt, task::Poll, Sink, Stream}; use std::pin::Pin; use std::str::FromStr; use std::task::Context; -use tokio::net::TcpStream; use xmpp_parsers::{ns, Element, Jid}; -use super::happy_eyeballs::connect_to_host; +use self::connect::component_login; + use super::xmpp_codec::Packet; -use super::xmpp_stream; use super::Error; +use crate::connect::ServerConnector; use crate::xmpp_stream::add_stanza_id; +use crate::xmpp_stream::XMPPStream; mod auth; +pub(crate) mod connect; + /// Component connection to an XMPP server /// /// This simplifies the `XMPPStream` to a `Stream`/`Sink` of `Element` /// (stanzas). Connection handling however is up to the user. -pub struct Component { +pub struct Component { /// The component's Jabber-Id pub jid: Jid, - stream: XMPPStream, + stream: XMPPStream, } -type XMPPStream = xmpp_stream::XMPPStream; - -impl Component { +impl Component { /// Start a new XMPP component - pub async fn new(jid: &str, password: &str, server: &str, port: u16) -> Result { + pub async fn new(jid: &str, password: &str, connector: C) -> Result { let jid = Jid::from_str(jid)?; let password = password.to_owned(); - let stream = Self::connect(jid.clone(), password, server, port).await?; + let stream = component_login(connector, jid.clone(), password).await?; Ok(Component { jid, stream }) } - async fn connect( - jid: Jid, - password: String, - server: &str, - port: u16, - ) -> Result { - let password = password; - let tcp_stream = connect_to_host(server, port).await?; - let mut xmpp_stream = - xmpp_stream::XMPPStream::start(tcp_stream, jid, ns::COMPONENT_ACCEPT.to_owned()) - .await?; - auth::auth(&mut xmpp_stream, password).await?; - Ok(xmpp_stream) - } - /// Send stanza pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> { self.send(add_stanza_id(stanza, ns::COMPONENT_ACCEPT)).await @@ -63,7 +49,7 @@ impl Component { } } -impl Stream for Component { +impl Stream for Component { type Item = Element; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { @@ -86,7 +72,7 @@ impl Stream for Component { } } -impl Sink for Component { +impl Sink for Component { type Error = Error; fn start_send(mut self: Pin<&mut Self>, item: Element) -> Result<(), Self::Error> { diff --git a/tokio-xmpp/src/connect.rs b/tokio-xmpp/src/connect.rs new file mode 100644 index 0000000..2358aba --- /dev/null +++ b/tokio-xmpp/src/connect.rs @@ -0,0 +1,35 @@ +//! `ServerConnector` provides streams for XMPP clients + +use sasl::common::ChannelBinding; +use tokio::io::{AsyncRead, AsyncWrite}; +use xmpp_parsers::Jid; + +use crate::xmpp_stream::XMPPStream; + +/// trait returned wrapped in XMPPStream by ServerConnector +pub trait AsyncReadAndWrite: AsyncRead + AsyncWrite + Unpin + Send {} +impl AsyncReadAndWrite for T {} + +/// Trait that must be extended by the implementation of ServerConnector +pub trait ServerConnectorError: std::error::Error + Send {} + +/// Trait called to connect to an XMPP server, perhaps called multiple times +pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static { + /// The type of Stream this ServerConnector produces + type Stream: AsyncReadAndWrite; + /// Error type to return + type Error: ServerConnectorError; + /// This must return the connection ready to login, ie if starttls is involved, after TLS has been started, and then after the impl std::future::Future, Self::Error>> + Send; + + /// Return channel binding data if available + /// do not fail if channel binding is simply unavailable, just return Ok(None) + /// this should only be called after the TLS handshake is finished + fn channel_binding(_stream: &Self::Stream) -> Result { + Ok(ChannelBinding::None) + } +} diff --git a/tokio-xmpp/src/error.rs b/tokio-xmpp/src/error.rs index 3869365..d1c5b19 100644 --- a/tokio-xmpp/src/error.rs +++ b/tokio-xmpp/src/error.rs @@ -1,41 +1,26 @@ -use hickory_resolver::{error::ResolveError, proto::error::ProtoError}; -#[cfg(feature = "tls-native")] -use native_tls::Error as TlsError; use sasl::client::MechanismError as SaslMechanismError; use std::borrow::Cow; use std::error::Error as StdError; use std::fmt; use std::io::Error as IoError; use std::str::Utf8Error; -#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] -use tokio_rustls::rustls::client::InvalidDnsNameError; -#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] -use tokio_rustls::rustls::Error as TlsError; use xmpp_parsers::sasl::DefinedCondition as SaslDefinedCondition; use xmpp_parsers::{Error as ParsersError, JidParseError}; +use crate::connect::ServerConnectorError; + /// Top-level error type #[derive(Debug)] pub enum Error { /// I/O error Io(IoError), - /// Error resolving DNS and establishing a connection - Connection(ConnecterError), - /// DNS label conversion error, no details available from module - /// `idna` - Idna, /// Error parsing Jabber-Id JidParse(JidParseError), /// Protocol-level error Protocol(ProtocolError), /// Authentication error Auth(AuthError), - /// TLS error - Tls(TlsError), - #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] - /// DNS name parsing error - DnsNameError(InvalidDnsNameError), /// Connection closed Disconnected, /// Shoud never happen @@ -44,6 +29,8 @@ pub enum Error { Fmt(fmt::Error), /// Utf8 error Utf8(Utf8Error), + /// Error resolving DNS and/or establishing a connection, returned by a ServerConnector impl + Connection(Box), } impl fmt::Display for Error { @@ -51,13 +38,9 @@ impl fmt::Display for Error { match self { Error::Io(e) => write!(fmt, "IO error: {}", e), Error::Connection(e) => write!(fmt, "connection error: {}", e), - Error::Idna => write!(fmt, "IDNA error"), Error::JidParse(e) => write!(fmt, "jid parse error: {}", e), Error::Protocol(e) => write!(fmt, "protocol error: {}", e), Error::Auth(e) => write!(fmt, "authentication error: {}", e), - Error::Tls(e) => write!(fmt, "TLS error: {}", e), - #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] - Error::DnsNameError(e) => write!(fmt, "DNS name error: {}", e), Error::Disconnected => write!(fmt, "disconnected"), Error::InvalidState => write!(fmt, "invalid state"), Error::Fmt(e) => write!(fmt, "Fmt error: {}", e), @@ -74,9 +57,9 @@ impl From for Error { } } -impl From for Error { - fn from(e: ConnecterError) -> Self { - Error::Connection(e) +impl From for Error { + fn from(e: T) -> Self { + Error::Connection(Box::new(e)) } } @@ -98,12 +81,6 @@ impl From for Error { } } -impl From for Error { - fn from(e: TlsError) -> Self { - Error::Tls(e) - } -} - impl From for Error { fn from(e: fmt::Error) -> Self { Error::Fmt(e) @@ -116,13 +93,6 @@ impl From for Error { } } -#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] -impl From for Error { - fn from(e: InvalidDnsNameError) -> Self { - Error::DnsNameError(e) - } -} - /// XML parse error wrapper type #[derive(Debug)] pub struct ParseError(pub Cow<'static, str>); @@ -227,22 +197,3 @@ impl fmt::Display for AuthError { } } } - -/// Error establishing connection -#[derive(Debug)] -pub enum ConnecterError { - /// All attempts failed, no error available - AllFailed, - /// DNS protocol error - Dns(ProtoError), - /// DNS resolution error - Resolve(ResolveError), -} - -impl StdError for ConnecterError {} - -impl std::fmt::Display for ConnecterError { - fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - write!(fmt, "{:?}", self) - } -} diff --git a/tokio-xmpp/src/lib.rs b/tokio-xmpp/src/lib.rs index 359f2b5..1b95b04 100644 --- a/tokio-xmpp/src/lib.rs +++ b/tokio-xmpp/src/lib.rs @@ -5,31 +5,35 @@ #[cfg(all(feature = "tls-native", feature = "tls-rust"))] compile_error!("Both tls-native and tls-rust features can't be enabled at the same time."); -#[cfg(all(not(feature = "tls-native"), not(feature = "tls-rust")))] -compile_error!("One of tls-native and tls-rust features must be enabled."); +#[cfg(all( + feature = "starttls", + not(feature = "tls-native"), + not(feature = "tls-rust") +))] +compile_error!( + "when starttls feature enabled one of tls-native and tls-rust features must be enabled." +); -mod starttls; +#[cfg(feature = "starttls")] +pub mod starttls; mod stream_start; mod xmpp_codec; pub use crate::xmpp_codec::Packet; mod event; pub use event::Event; mod client; -mod happy_eyeballs; +pub mod connect; pub mod stream_features; pub mod xmpp_stream; + pub use client::{ - async_client::{ - Client as AsyncClient, Config as AsyncConfig, ServerConfig as AsyncServerConfig, - }, - connect::{client_login, AsyncReadAndWrite, ServerConnector}, + async_client::{Client as AsyncClient, Config as AsyncConfig}, simple_client::Client as SimpleClient, }; mod component; pub use crate::component::Component; mod error; -pub use crate::error::{AuthError, ConnecterError, Error, ParseError, ProtocolError}; -pub use starttls::starttls; +pub use crate::error::{AuthError, Error, ParseError, ProtocolError}; // Re-exports pub use minidom::Element; diff --git a/tokio-xmpp/src/starttls.rs b/tokio-xmpp/src/starttls.rs deleted file mode 100644 index 73d72f8..0000000 --- a/tokio-xmpp/src/starttls.rs +++ /dev/null @@ -1,85 +0,0 @@ -use futures::{sink::SinkExt, stream::StreamExt}; - -#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] -use { - std::sync::Arc, - tokio_rustls::{ - client::TlsStream, - rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}, - TlsConnector, - }, - webpki_roots, -}; - -#[cfg(feature = "tls-native")] -use { - native_tls::TlsConnector as NativeTlsConnector, - tokio_native_tls::{TlsConnector, TlsStream}, -}; - -use tokio::io::{AsyncRead, AsyncWrite}; -use xmpp_parsers::{ns, Element}; - -use crate::xmpp_codec::Packet; -use crate::xmpp_stream::XMPPStream; -use crate::{Error, ProtocolError}; - -#[cfg(feature = "tls-native")] -async fn get_tls_stream( - xmpp_stream: XMPPStream, -) -> Result, Error> { - let domain = xmpp_stream.jid.domain_str().to_owned(); - let stream = xmpp_stream.into_inner(); - let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) - .connect(&domain, stream) - .await?; - Ok(tls_stream) -} - -#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] -async fn get_tls_stream( - xmpp_stream: XMPPStream, -) -> Result, Error> { - let domain = xmpp_stream.jid.domain_str().to_owned(); - let domain = ServerName::try_from(domain.as_str())?; - let stream = xmpp_stream.into_inner(); - let mut root_store = RootCertStore::empty(); - root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); - let config = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_no_client_auth(); - let tls_stream = TlsConnector::from(Arc::new(config)) - .connect(domain, stream) - .await?; - Ok(tls_stream) -} - -/// Performs `` on an XMPPStream and returns a binary -/// TlsStream. -pub async fn starttls( - mut xmpp_stream: XMPPStream, -) -> Result, Error> { - let nonza = Element::builder("starttls", ns::TLS).build(); - let packet = Packet::Stanza(nonza); - xmpp_stream.send(packet).await?; - - loop { - match xmpp_stream.next().await { - Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break, - Some(Ok(Packet::Text(_))) => {} - Some(Err(e)) => return Err(e.into()), - _ => { - return Err(ProtocolError::NoTls.into()); - } - } - } - - get_tls_stream(xmpp_stream).await -} diff --git a/tokio-xmpp/src/starttls/client.rs b/tokio-xmpp/src/starttls/client.rs new file mode 100644 index 0000000..2a2395c --- /dev/null +++ b/tokio-xmpp/src/starttls/client.rs @@ -0,0 +1,35 @@ +use std::str::FromStr; + +use xmpp_parsers::Jid; + +use crate::{AsyncClient, AsyncConfig, Error, SimpleClient}; + +use super::ServerConfig; + +impl AsyncClient { + /// Start a new XMPP client + /// + /// Start polling the returned instance so that it will connect + /// and yield events. + pub fn new, P: Into>(jid: J, password: P) -> Self { + let config = AsyncConfig { + jid: jid.into(), + password: password.into(), + server: ServerConfig::UseSrv, + }; + Self::new_with_config(config) + } +} + +impl SimpleClient { + /// Start a new XMPP client and wait for a usable session + pub async fn new>(jid: &str, password: P) -> Result { + let jid = Jid::from_str(jid)?; + Self::new_with_jid(jid, password.into()).await + } + + /// Start a new client given that the JID is already parsed. + pub async fn new_with_jid(jid: Jid, password: String) -> Result { + Self::new_with_jid_connector(ServerConfig::UseSrv, jid, password).await + } +} diff --git a/tokio-xmpp/src/starttls/error.rs b/tokio-xmpp/src/starttls/error.rs new file mode 100644 index 0000000..f2db656 --- /dev/null +++ b/tokio-xmpp/src/starttls/error.rs @@ -0,0 +1,105 @@ +use hickory_resolver::{error::ResolveError, proto::error::ProtoError}; +#[cfg(feature = "tls-native")] +use native_tls::Error as TlsError; +use std::borrow::Cow; +use std::error::Error as StdError; +use std::fmt; +#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] +use tokio_rustls::rustls::client::InvalidDnsNameError; +#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] +use tokio_rustls::rustls::Error as TlsError; + +/// Top-level error type +#[derive(Debug)] +pub enum Error { + /// Error resolving DNS and establishing a connection + Connection(ConnectorError), + /// DNS label conversion error, no details available from module + /// `idna` + Idna, + /// TLS error + Tls(TlsError), + #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] + /// DNS name parsing error + DnsNameError(InvalidDnsNameError), + /// tokio-xmpp error + TokioXMPP(crate::error::Error), +} + +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::Connection(e) => write!(fmt, "connection error: {}", e), + Error::Idna => write!(fmt, "IDNA error"), + Error::Tls(e) => write!(fmt, "TLS error: {}", e), + #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] + Error::DnsNameError(e) => write!(fmt, "DNS name error: {}", e), + Error::TokioXMPP(e) => write!(fmt, "TokioXMPP error: {}", e), + } + } +} + +impl StdError for Error {} + +impl From for Error { + fn from(e: crate::error::Error) -> Self { + Error::TokioXMPP(e) + } +} + +impl From for Error { + fn from(e: ConnectorError) -> Self { + Error::Connection(e) + } +} + +impl From for Error { + fn from(e: TlsError) -> Self { + Error::Tls(e) + } +} + +#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] +impl From for Error { + fn from(e: InvalidDnsNameError) -> Self { + Error::DnsNameError(e) + } +} + +/// XML parse error wrapper type +#[derive(Debug)] +pub struct ParseError(pub Cow<'static, str>); + +impl StdError for ParseError { + fn description(&self) -> &str { + self.0.as_ref() + } + fn cause(&self) -> Option<&dyn StdError> { + None + } +} + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Error establishing connection +#[derive(Debug)] +pub enum ConnectorError { + /// All attempts failed, no error available + AllFailed, + /// DNS protocol error + Dns(ProtoError), + /// DNS resolution error + Resolve(ResolveError), +} + +impl StdError for ConnectorError {} + +impl std::fmt::Display for ConnectorError { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + write!(fmt, "{:?}", self) + } +} diff --git a/tokio-xmpp/src/happy_eyeballs.rs b/tokio-xmpp/src/starttls/happy_eyeballs.rs similarity index 81% rename from tokio-xmpp/src/happy_eyeballs.rs rename to tokio-xmpp/src/starttls/happy_eyeballs.rs index abe4827..04879af 100644 --- a/tokio-xmpp/src/happy_eyeballs.rs +++ b/tokio-xmpp/src/starttls/happy_eyeballs.rs @@ -1,4 +1,4 @@ -use crate::{ConnecterError, Error}; +use super::error::{ConnectorError, Error}; use hickory_resolver::{IntoName, TokioAsyncResolver}; use idna; use log::debug; @@ -9,22 +9,24 @@ pub async fn connect_to_host(domain: &str, port: u16) -> Result return Ok(stream), Err(_) => {} } } - Err(Error::Disconnected) + Err(crate::Error::Disconnected.into()) } pub async fn connect_with_srv( @@ -36,14 +38,16 @@ pub async fn connect_with_srv( if let Ok(ip) = ascii_domain.parse() { debug!("Attempting connection to {ip}:{fallback_port}"); - return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?); + return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)) + .await + .map_err(|e| Error::from(crate::Error::Io(e)))?); } - let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(ConnecterError::Resolve)?; + let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(ConnectorError::Resolve)?; let srv_domain = format!("{}.{}.", srv, ascii_domain) .into_name() - .map_err(ConnecterError::Dns)?; + .map_err(ConnectorError::Dns)?; let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok(); match srv_records { @@ -56,7 +60,7 @@ pub async fn connect_with_srv( Err(_) => {} } } - Err(Error::Disconnected) + Err(crate::Error::Disconnected.into()) } None => { // SRV lookup error, retry with hostname diff --git a/tokio-xmpp/src/starttls/mod.rs b/tokio-xmpp/src/starttls/mod.rs new file mode 100644 index 0000000..67ecb25 --- /dev/null +++ b/tokio-xmpp/src/starttls/mod.rs @@ -0,0 +1,168 @@ +//! `starttls::ServerConfig` provides a `ServerConnector` for starttls connections + +use futures::{sink::SinkExt, stream::StreamExt}; + +#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] +use { + std::sync::Arc, + tokio_rustls::{ + client::TlsStream, + rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}, + TlsConnector, + }, + webpki_roots, +}; + +#[cfg(feature = "tls-native")] +use { + native_tls::TlsConnector as NativeTlsConnector, + tokio_native_tls::{TlsConnector, TlsStream}, +}; + +use sasl::common::ChannelBinding; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream, +}; +use xmpp_parsers::{ns, Element, Jid}; + +use crate::{connect::ServerConnector, xmpp_codec::Packet}; +use crate::{connect::ServerConnectorError, xmpp_stream::XMPPStream}; + +use self::error::Error; +use self::happy_eyeballs::{connect_to_host, connect_with_srv}; + +mod client; +mod error; +mod happy_eyeballs; + +/// StartTLS XMPP server connection configuration +#[derive(Clone, Debug)] +pub enum ServerConfig { + /// Use SRV record to find server host + UseSrv, + #[allow(unused)] + /// Manually define server host and port + Manual { + /// Server host name + host: String, + /// Server port + port: u16, + }, +} + +impl ServerConnectorError for Error {} + +impl ServerConnector for ServerConfig { + type Stream = TlsStream; + type Error = Error; + async fn connect(&self, jid: &Jid, ns: &str) -> Result, Error> { + // TCP connection + let tcp_stream = match self { + ServerConfig::UseSrv => { + connect_with_srv(jid.domain_str(), "_xmpp-client._tcp", 5222).await? + } + ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?, + }; + + // Unencryped XMPPStream + let xmpp_stream = XMPPStream::start(tcp_stream, jid.clone(), ns.to_owned()).await?; + + if xmpp_stream.stream_features.can_starttls() { + // TlsStream + let tls_stream = starttls(xmpp_stream).await?; + // Encrypted XMPPStream + Ok(XMPPStream::start(tls_stream, jid.clone(), ns.to_owned()).await?) + } else { + return Err(crate::Error::Protocol(crate::ProtocolError::NoTls).into()); + } + } + + fn channel_binding( + #[allow(unused_variables)] stream: &Self::Stream, + ) -> Result { + #[cfg(feature = "tls-native")] + { + log::warn!("tls-native doesn’t support channel binding, please use tls-rust if you want this feature!"); + Ok(ChannelBinding::None) + } + #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] + { + let (_, connection) = stream.get_ref(); + Ok(match connection.protocol_version() { + // TODO: Add support for TLS 1.2 and earlier. + Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => { + let data = vec![0u8; 32]; + let data = connection.export_keying_material( + data, + b"EXPORTER-Channel-Binding", + None, + )?; + ChannelBinding::TlsExporter(data) + } + _ => ChannelBinding::None, + }) + } + } +} + +#[cfg(feature = "tls-native")] +async fn get_tls_stream( + xmpp_stream: XMPPStream, +) -> Result, Error> { + let domain = xmpp_stream.jid.domain_str().to_owned(); + let stream = xmpp_stream.into_inner(); + let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) + .connect(&domain, stream) + .await?; + Ok(tls_stream) +} + +#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] +async fn get_tls_stream( + xmpp_stream: XMPPStream, +) -> Result, Error> { + let domain = xmpp_stream.jid.domain_str().to_owned(); + let domain = ServerName::try_from(domain.as_str())?; + let stream = xmpp_stream.into_inner(); + let mut root_store = RootCertStore::empty(); + root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + let config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); + let tls_stream = TlsConnector::from(Arc::new(config)) + .connect(domain, stream) + .await + .map_err(|e| Error::from(crate::Error::Io(e)))?; + Ok(tls_stream) +} + +/// Performs `` on an XMPPStream and returns a binary +/// TlsStream. +pub async fn starttls( + mut xmpp_stream: XMPPStream, +) -> Result, Error> { + let nonza = Element::builder("starttls", ns::TLS).build(); + let packet = Packet::Stanza(nonza); + xmpp_stream.send(packet).await?; + + loop { + match xmpp_stream.next().await { + Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break, + Some(Ok(Packet::Text(_))) => {} + Some(Err(e)) => return Err(e.into()), + _ => { + return Err(crate::Error::Protocol(crate::ProtocolError::NoTls).into()); + } + } + } + + get_tls_stream(xmpp_stream).await +} diff --git a/xmpp/Cargo.toml b/xmpp/Cargo.toml index 105a036..e2dc569 100644 --- a/xmpp/Cargo.toml +++ b/xmpp/Cargo.toml @@ -31,7 +31,7 @@ name = "hello_bot" required-features = ["avatars"] [features] -default = ["avatars", "tls-native"] -tls-native = ["tokio-xmpp/tls-native"] -tls-rust = ["tokio-xmpp/tls-rust"] +default = ["avatars", "starttls-rust"] +starttls-native = ["tokio-xmpp/starttls", "tokio-xmpp/tls-native"] +starttls-rust = ["tokio-xmpp/starttls", "tokio-xmpp/tls-rust"] avatars = [] diff --git a/xmpp/src/lib.rs b/xmpp/src/lib.rs index dd9b417..b347866 100644 --- a/xmpp/src/lib.rs +++ b/xmpp/src/lib.rs @@ -7,7 +7,7 @@ #![deny(bare_trait_objects)] pub use tokio_xmpp::parsers; -use tokio_xmpp::{AsyncClient, AsyncServerConfig}; +use tokio_xmpp::AsyncClient; pub use tokio_xmpp::{BareJid, Element, FullJid, Jid}; #[macro_use] extern crate log; @@ -32,7 +32,7 @@ pub use builder::{ClientBuilder, ClientType}; pub use event::Event; pub use feature::ClientFeature; -type TokioXmppClient = AsyncClient; +type TokioXmppClient = AsyncClient; pub type Error = tokio_xmpp::Error; pub type Id = Option;