diff --git a/tokio-xmpp/Cargo.toml b/tokio-xmpp/Cargo.toml index 08818c8c..92039e71 100644 --- a/tokio-xmpp/Cargo.toml +++ b/tokio-xmpp/Cargo.toml @@ -41,7 +41,7 @@ tokio-xmpp = { path = ".", features = ["insecure-tcp"]} [features] default = ["starttls-rust"] -starttls = ["hickory-resolver", "idna"] +starttls = ["dns"] tls-rust = ["tokio-rustls", "webpki-roots"] tls-native = ["tokio-native-tls", "native-tls"] starttls-native = ["starttls", "tls-native"] @@ -50,6 +50,8 @@ insecure-tcp = [] syntax-highlighting = ["syntect"] # Enable serde support in jid crate serde = [ "xmpp-parsers/serde" ] +# Required by starttls, and used by insecure-tcp by default +dns = [ "hickory-resolver", "idna" ] [lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(xmpprs_doc_build)'] } diff --git a/tokio-xmpp/src/client/async_client.rs b/tokio-xmpp/src/client/async_client.rs index 38c6ca44..de2937dc 100644 --- a/tokio-xmpp/src/client/async_client.rs +++ b/tokio-xmpp/src/client/async_client.rs @@ -10,9 +10,11 @@ use super::connect::client_login; use crate::connect::{AsyncReadAndWrite, ServerConnector}; use crate::error::{Error, ProtocolError}; use crate::event::Event; +#[cfg(feature = "starttls")] use crate::starttls::ServerConfig; use crate::xmpp_codec::Packet; use crate::xmpp_stream::{add_stanza_id, XMPPStream}; +#[cfg(feature = "starttls")] use crate::AsyncConfig; /// XMPP client connection and state @@ -46,6 +48,7 @@ enum ClientState { Connected(XMPPStream), } +#[cfg(feature = "starttls")] impl Client { /// Start a new XMPP client using StartTLS transport and autoreconnect /// diff --git a/tokio-xmpp/src/client/simple_client.rs b/tokio-xmpp/src/client/simple_client.rs index d57e82e0..feb122e3 100644 --- a/tokio-xmpp/src/client/simple_client.rs +++ b/tokio-xmpp/src/client/simple_client.rs @@ -1,12 +1,14 @@ use futures::{sink::SinkExt, Sink, Stream}; use minidom::Element; use std::pin::Pin; +#[cfg(feature = "starttls")] use std::str::FromStr; use std::task::{Context, Poll}; use tokio_stream::StreamExt; use xmpp_parsers::{jid::Jid, ns, stream_features::StreamFeatures}; use crate::connect::ServerConnector; +#[cfg(feature = "starttls")] use crate::starttls::ServerConfig; use crate::xmpp_codec::Packet; use crate::xmpp_stream::{add_stanza_id, XMPPStream}; @@ -22,6 +24,7 @@ pub struct Client { stream: XMPPStream, } +#[cfg(feature = "starttls")] impl Client { /// Start a new XMPP client and wait for a usable session pub async fn new>(jid: &str, password: P) -> Result { diff --git a/tokio-xmpp/src/connect.rs b/tokio-xmpp/src/connect.rs index d72b3676..532dcf19 100644 --- a/tokio-xmpp/src/connect.rs +++ b/tokio-xmpp/src/connect.rs @@ -1,7 +1,17 @@ //! `ServerConnector` provides streams for XMPP clients +#[cfg(feature = "dns")] +use futures::{future::select_ok, FutureExt}; +#[cfg(feature = "dns")] +use hickory_resolver::{ + config::LookupIpStrategy, name_server::TokioConnectionProvider, IntoName, TokioAsyncResolver, +}; +#[cfg(feature = "dns")] +use log::debug; use sasl::common::ChannelBinding; +use std::net::{IpAddr, SocketAddr}; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; use xmpp_parsers::jid::Jid; use crate::xmpp_stream::XMPPStream; @@ -32,3 +42,78 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static { Ok(ChannelBinding::None) } } + +/// A simple wrapper to build [`TcpStream`] +pub struct Tcp; + +impl Tcp { + /// Connect directly to an IP/Port combo + pub async fn connect(ip: IpAddr, port: u16) -> Result { + Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?) + } + + /// Connect over TCP, resolving A/AAAA records (happy eyeballs) + #[cfg(feature = "dns")] + pub async fn resolve(domain: &str, port: u16) -> Result { + let ascii_domain = idna::domain_to_ascii(&domain)?; + + if let Ok(ip) = ascii_domain.parse() { + return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?); + } + + let (config, mut options) = hickory_resolver::system_conf::read_system_conf()?; + options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6; + let resolver = TokioAsyncResolver::new(config, options, TokioConnectionProvider::default()); + + let ips = resolver.lookup_ip(ascii_domain).await?; + + // Happy Eyeballs: connect to all records in parallel, return the + // first to succeed + select_ok( + ips.into_iter() + .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()), + ) + .await + .map(|(result, _)| result) + .map_err(|_| Error::Disconnected) + } + + /// Connect over TCP, resolving SRV records + #[cfg(feature = "dns")] + pub async fn resolve_with_srv( + domain: &str, + srv: &str, + fallback_port: u16, + ) -> Result { + let ascii_domain = idna::domain_to_ascii(&domain)?; + + if let Ok(ip) = ascii_domain.parse() { + debug!("Attempting connection to {ip}:{fallback_port}"); + return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?); + } + + let resolver = TokioAsyncResolver::tokio_from_system_conf()?; + + let srv_domain = format!("{}.{}.", srv, ascii_domain).into_name()?; + let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok(); + + match srv_records { + Some(lookup) => { + // TODO: sort lookup records by priority/weight + for srv in lookup.iter() { + debug!("Attempting connection to {srv_domain} {srv}"); + match Self::resolve(&srv.target().to_ascii(), srv.port()).await { + Ok(stream) => return Ok(stream), + Err(_) => {} + } + } + Err(Error::Disconnected) + } + None => { + // SRV lookup error, retry with hostname + debug!("Attempting connection to {domain}:{fallback_port}"); + Self::resolve(domain, fallback_port).await + } + } + } +} diff --git a/tokio-xmpp/src/error.rs b/tokio-xmpp/src/error.rs index f5135fe2..80a14962 100644 --- a/tokio-xmpp/src/error.rs +++ b/tokio-xmpp/src/error.rs @@ -1,3 +1,7 @@ +#[cfg(feature = "dns")] +use hickory_resolver::{ + error::ResolveError as DnsResolveError, proto::error::ProtoError as DnsProtoError, +}; use sasl::client::MechanismError as SaslMechanismError; use std::error::Error as StdError; use std::fmt; @@ -28,8 +32,18 @@ pub enum Error { Fmt(fmt::Error), /// Utf8 error Utf8(Utf8Error), - /// Error resolving DNS and/or establishing a connection, returned by a ServerConnector impl + /// Error specific to ServerConnector impl Connection(Box), + /// DNS protocol error + #[cfg(feature = "dns")] + Dns(DnsProtoError), + /// DNS resolution error + #[cfg(feature = "dns")] + Resolve(DnsResolveError), + /// DNS label conversion error, no details available from module + /// `idna` + #[cfg(feature = "dns")] + Idna, } impl fmt::Display for Error { @@ -44,6 +58,12 @@ impl fmt::Display for Error { Error::InvalidState => write!(fmt, "invalid state"), Error::Fmt(e) => write!(fmt, "Fmt error: {}", e), Error::Utf8(e) => write!(fmt, "Utf8 error: {}", e), + #[cfg(feature = "dns")] + Error::Dns(e) => write!(fmt, "{:?}", e), + #[cfg(feature = "dns")] + Error::Resolve(e) => write!(fmt, "{:?}", e), + #[cfg(feature = "dns")] + Error::Idna => write!(fmt, "IDNA error"), } } } @@ -92,6 +112,27 @@ impl From for Error { } } +#[cfg(feature = "dns")] +impl From for Error { + fn from(_e: idna::Errors) -> Self { + Error::Idna + } +} + +#[cfg(feature = "dns")] +impl From for Error { + fn from(e: DnsResolveError) -> Error { + Error::Resolve(e) + } +} + +#[cfg(feature = "dns")] +impl From for Error { + fn from(e: DnsProtoError) -> Error { + Error::Dns(e) + } +} + /// XMPP protocol-level error #[derive(Debug)] pub enum ProtocolError { diff --git a/tokio-xmpp/src/starttls/error.rs b/tokio-xmpp/src/starttls/error.rs index 5de28217..824bde23 100644 --- a/tokio-xmpp/src/starttls/error.rs +++ b/tokio-xmpp/src/starttls/error.rs @@ -1,6 +1,5 @@ //! StartTLS ServerConnector Error -use hickory_resolver::{error::ResolveError, proto::error::ProtoError}; #[cfg(feature = "tls-native")] use native_tls::Error as TlsError; use std::error::Error as StdError; @@ -15,13 +14,6 @@ use super::ServerConnectorError; /// StartTLS ServerConnector Error #[derive(Debug)] pub enum Error { - /// DNS protocol error - Dns(ProtoError), - /// DNS resolution error - Resolve(ResolveError), - /// DNS label conversion error, no details available from module - /// `idna` - Idna, /// TLS error Tls(TlsError), #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] @@ -34,9 +26,6 @@ impl ServerConnectorError for Error {} impl fmt::Display for Error { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { match self { - Self::Dns(e) => write!(fmt, "{:?}", e), - Self::Resolve(e) => write!(fmt, "{:?}", e), - Self::Idna => write!(fmt, "IDNA error"), Self::Tls(e) => write!(fmt, "TLS error: {}", e), #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e), diff --git a/tokio-xmpp/src/starttls/happy_eyeballs.rs b/tokio-xmpp/src/starttls/happy_eyeballs.rs deleted file mode 100644 index 0ce5c8ea..00000000 --- a/tokio-xmpp/src/starttls/happy_eyeballs.rs +++ /dev/null @@ -1,75 +0,0 @@ -use super::error::Error as StartTlsError; -use crate::Error; -use futures::{future::select_ok, FutureExt}; -use hickory_resolver::{ - config::LookupIpStrategy, name_server::TokioConnectionProvider, IntoName, TokioAsyncResolver, -}; -use log::debug; -use std::net::SocketAddr; -use tokio::net::TcpStream; - -pub async fn connect_to_host(domain: &str, port: u16) -> Result { - let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| StartTlsError::Idna)?; - - if let Ok(ip) = ascii_domain.parse() { - return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?); - } - - let (config, mut options) = - hickory_resolver::system_conf::read_system_conf().map_err(StartTlsError::Resolve)?; - options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6; - let resolver = TokioAsyncResolver::new(config, options, TokioConnectionProvider::default()); - - let ips = resolver - .lookup_ip(ascii_domain) - .await - .map_err(StartTlsError::Resolve)?; - // Happy Eyeballs: connect to all records in parallel, return the - // first to succeed - select_ok( - ips.into_iter() - .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()), - ) - .await - .map(|(result, _)| result) - .map_err(|_| crate::Error::Disconnected) -} - -pub async fn connect_with_srv( - domain: &str, - srv: &str, - fallback_port: u16, -) -> Result { - let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| StartTlsError::Idna)?; - - if let Ok(ip) = ascii_domain.parse() { - debug!("Attempting connection to {ip}:{fallback_port}"); - return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?); - } - - let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(StartTlsError::Resolve)?; - - let srv_domain = format!("{}.{}.", srv, ascii_domain) - .into_name() - .map_err(StartTlsError::Dns)?; - let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok(); - - match srv_records { - Some(lookup) => { - // TODO: sort lookup records by priority/weight - for srv in lookup.iter() { - debug!("Attempting connection to {srv_domain} {srv}"); - match connect_to_host(&srv.target().to_ascii(), srv.port()).await { - Ok(stream) => return Ok(stream), - Err(_) => {} - } - } - Err(crate::Error::Disconnected.into()) - } - None => { - // SRV lookup error, retry with hostname - debug!("Attempting connection to {domain}:{fallback_port}"); - connect_to_host(domain, fallback_port).await - } - } -} diff --git a/tokio-xmpp/src/starttls/mod.rs b/tokio-xmpp/src/starttls/mod.rs index 0e1e9cb4..4824afb9 100644 --- a/tokio-xmpp/src/starttls/mod.rs +++ b/tokio-xmpp/src/starttls/mod.rs @@ -27,16 +27,17 @@ use tokio::{ }; use xmpp_parsers::{jid::Jid, ns}; -use crate::error::ProtocolError; -use crate::Error; -use crate::{connect::ServerConnector, xmpp_codec::Packet, AsyncClient, SimpleClient}; -use crate::{connect::ServerConnectorError, xmpp_stream::XMPPStream}; +use crate::{ + connect::{ServerConnector, ServerConnectorError, Tcp}, + error::{Error, ProtocolError}, + xmpp_codec::Packet, + xmpp_stream::XMPPStream, + AsyncClient, SimpleClient, +}; use self::error::Error as StartTlsError; -use self::happy_eyeballs::{connect_to_host, connect_with_srv}; pub mod error; -mod happy_eyeballs; /// AsyncClient that connects over StartTls pub type StartTlsAsyncClient = AsyncClient; @@ -64,9 +65,9 @@ impl ServerConnector for ServerConfig { // TCP connection let tcp_stream = match self { ServerConfig::UseSrv => { - connect_with_srv(jid.domain().as_str(), "_xmpp-client._tcp", 5222).await? + Tcp::resolve_with_srv(jid.domain().as_str(), "_xmpp-client._tcp", 5222).await? } - ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?, + ServerConfig::Manual { host, port } => Tcp::resolve(host.as_str(), *port).await?, }; // Unencryped XMPPStream