Add dns feature for DNS stuff (not just in starttls)

This commit is contained in:
xmppftw 2024-08-05 15:09:59 +02:00
parent d706b318c3
commit 97698b4d1e
8 changed files with 145 additions and 96 deletions

View file

@ -41,7 +41,7 @@ tokio-xmpp = { path = ".", features = ["insecure-tcp"]}
[features]
default = ["starttls-rust"]
starttls = ["hickory-resolver", "idna"]
starttls = ["dns"]
tls-rust = ["tokio-rustls", "webpki-roots"]
tls-native = ["tokio-native-tls", "native-tls"]
starttls-native = ["starttls", "tls-native"]
@ -50,6 +50,8 @@ insecure-tcp = []
syntax-highlighting = ["syntect"]
# Enable serde support in jid crate
serde = [ "xmpp-parsers/serde" ]
# Required by starttls, and used by insecure-tcp by default
dns = [ "hickory-resolver", "idna" ]
[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(xmpprs_doc_build)'] }

View file

@ -10,9 +10,11 @@ use super::connect::client_login;
use crate::connect::{AsyncReadAndWrite, ServerConnector};
use crate::error::{Error, ProtocolError};
use crate::event::Event;
#[cfg(feature = "starttls")]
use crate::starttls::ServerConfig;
use crate::xmpp_codec::Packet;
use crate::xmpp_stream::{add_stanza_id, XMPPStream};
#[cfg(feature = "starttls")]
use crate::AsyncConfig;
/// XMPP client connection and state
@ -46,6 +48,7 @@ enum ClientState<S: AsyncReadAndWrite> {
Connected(XMPPStream<S>),
}
#[cfg(feature = "starttls")]
impl Client<ServerConfig> {
/// Start a new XMPP client using StartTLS transport and autoreconnect
///

View file

@ -1,12 +1,14 @@
use futures::{sink::SinkExt, Sink, Stream};
use minidom::Element;
use std::pin::Pin;
#[cfg(feature = "starttls")]
use std::str::FromStr;
use std::task::{Context, Poll};
use tokio_stream::StreamExt;
use xmpp_parsers::{jid::Jid, ns, stream_features::StreamFeatures};
use crate::connect::ServerConnector;
#[cfg(feature = "starttls")]
use crate::starttls::ServerConfig;
use crate::xmpp_codec::Packet;
use crate::xmpp_stream::{add_stanza_id, XMPPStream};
@ -22,6 +24,7 @@ pub struct Client<C: ServerConnector> {
stream: XMPPStream<C::Stream>,
}
#[cfg(feature = "starttls")]
impl Client<ServerConfig> {
/// Start a new XMPP client and wait for a usable session
pub async fn new<P: Into<String>>(jid: &str, password: P) -> Result<Self, Error> {

View file

@ -1,7 +1,17 @@
//! `ServerConnector` provides streams for XMPP clients
#[cfg(feature = "dns")]
use futures::{future::select_ok, FutureExt};
#[cfg(feature = "dns")]
use hickory_resolver::{
config::LookupIpStrategy, name_server::TokioConnectionProvider, IntoName, TokioAsyncResolver,
};
#[cfg(feature = "dns")]
use log::debug;
use sasl::common::ChannelBinding;
use std::net::{IpAddr, SocketAddr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use xmpp_parsers::jid::Jid;
use crate::xmpp_stream::XMPPStream;
@ -32,3 +42,78 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
Ok(ChannelBinding::None)
}
}
/// A simple wrapper to build [`TcpStream`]
pub struct Tcp;
impl Tcp {
/// Connect directly to an IP/Port combo
pub async fn connect(ip: IpAddr, port: u16) -> Result<TcpStream, Error> {
Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?)
}
/// Connect over TCP, resolving A/AAAA records (happy eyeballs)
#[cfg(feature = "dns")]
pub async fn resolve(domain: &str, port: u16) -> Result<TcpStream, Error> {
let ascii_domain = idna::domain_to_ascii(&domain)?;
if let Ok(ip) = ascii_domain.parse() {
return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
}
let (config, mut options) = hickory_resolver::system_conf::read_system_conf()?;
options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
let resolver = TokioAsyncResolver::new(config, options, TokioConnectionProvider::default());
let ips = resolver.lookup_ip(ascii_domain).await?;
// Happy Eyeballs: connect to all records in parallel, return the
// first to succeed
select_ok(
ips.into_iter()
.map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
)
.await
.map(|(result, _)| result)
.map_err(|_| Error::Disconnected)
}
/// Connect over TCP, resolving SRV records
#[cfg(feature = "dns")]
pub async fn resolve_with_srv(
domain: &str,
srv: &str,
fallback_port: u16,
) -> Result<TcpStream, Error> {
let ascii_domain = idna::domain_to_ascii(&domain)?;
if let Ok(ip) = ascii_domain.parse() {
debug!("Attempting connection to {ip}:{fallback_port}");
return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
}
let resolver = TokioAsyncResolver::tokio_from_system_conf()?;
let srv_domain = format!("{}.{}.", srv, ascii_domain).into_name()?;
let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
match srv_records {
Some(lookup) => {
// TODO: sort lookup records by priority/weight
for srv in lookup.iter() {
debug!("Attempting connection to {srv_domain} {srv}");
match Self::resolve(&srv.target().to_ascii(), srv.port()).await {
Ok(stream) => return Ok(stream),
Err(_) => {}
}
}
Err(Error::Disconnected)
}
None => {
// SRV lookup error, retry with hostname
debug!("Attempting connection to {domain}:{fallback_port}");
Self::resolve(domain, fallback_port).await
}
}
}
}

View file

@ -1,3 +1,7 @@
#[cfg(feature = "dns")]
use hickory_resolver::{
error::ResolveError as DnsResolveError, proto::error::ProtoError as DnsProtoError,
};
use sasl::client::MechanismError as SaslMechanismError;
use std::error::Error as StdError;
use std::fmt;
@ -28,8 +32,18 @@ pub enum Error {
Fmt(fmt::Error),
/// Utf8 error
Utf8(Utf8Error),
/// Error resolving DNS and/or establishing a connection, returned by a ServerConnector impl
/// Error specific to ServerConnector impl
Connection(Box<dyn ServerConnectorError>),
/// DNS protocol error
#[cfg(feature = "dns")]
Dns(DnsProtoError),
/// DNS resolution error
#[cfg(feature = "dns")]
Resolve(DnsResolveError),
/// DNS label conversion error, no details available from module
/// `idna`
#[cfg(feature = "dns")]
Idna,
}
impl fmt::Display for Error {
@ -44,6 +58,12 @@ impl fmt::Display for Error {
Error::InvalidState => write!(fmt, "invalid state"),
Error::Fmt(e) => write!(fmt, "Fmt error: {}", e),
Error::Utf8(e) => write!(fmt, "Utf8 error: {}", e),
#[cfg(feature = "dns")]
Error::Dns(e) => write!(fmt, "{:?}", e),
#[cfg(feature = "dns")]
Error::Resolve(e) => write!(fmt, "{:?}", e),
#[cfg(feature = "dns")]
Error::Idna => write!(fmt, "IDNA error"),
}
}
}
@ -92,6 +112,27 @@ impl From<Utf8Error> for Error {
}
}
#[cfg(feature = "dns")]
impl From<idna::Errors> for Error {
fn from(_e: idna::Errors) -> Self {
Error::Idna
}
}
#[cfg(feature = "dns")]
impl From<DnsResolveError> for Error {
fn from(e: DnsResolveError) -> Error {
Error::Resolve(e)
}
}
#[cfg(feature = "dns")]
impl From<DnsProtoError> for Error {
fn from(e: DnsProtoError) -> Error {
Error::Dns(e)
}
}
/// XMPP protocol-level error
#[derive(Debug)]
pub enum ProtocolError {

View file

@ -1,6 +1,5 @@
//! StartTLS ServerConnector Error
use hickory_resolver::{error::ResolveError, proto::error::ProtoError};
#[cfg(feature = "tls-native")]
use native_tls::Error as TlsError;
use std::error::Error as StdError;
@ -15,13 +14,6 @@ use super::ServerConnectorError;
/// StartTLS ServerConnector Error
#[derive(Debug)]
pub enum Error {
/// DNS protocol error
Dns(ProtoError),
/// DNS resolution error
Resolve(ResolveError),
/// DNS label conversion error, no details available from module
/// `idna`
Idna,
/// TLS error
Tls(TlsError),
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
@ -34,9 +26,6 @@ impl ServerConnectorError for Error {}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Dns(e) => write!(fmt, "{:?}", e),
Self::Resolve(e) => write!(fmt, "{:?}", e),
Self::Idna => write!(fmt, "IDNA error"),
Self::Tls(e) => write!(fmt, "TLS error: {}", e),
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),

View file

@ -1,75 +0,0 @@
use super::error::Error as StartTlsError;
use crate::Error;
use futures::{future::select_ok, FutureExt};
use hickory_resolver::{
config::LookupIpStrategy, name_server::TokioConnectionProvider, IntoName, TokioAsyncResolver,
};
use log::debug;
use std::net::SocketAddr;
use tokio::net::TcpStream;
pub async fn connect_to_host(domain: &str, port: u16) -> Result<TcpStream, Error> {
let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| StartTlsError::Idna)?;
if let Ok(ip) = ascii_domain.parse() {
return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
}
let (config, mut options) =
hickory_resolver::system_conf::read_system_conf().map_err(StartTlsError::Resolve)?;
options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
let resolver = TokioAsyncResolver::new(config, options, TokioConnectionProvider::default());
let ips = resolver
.lookup_ip(ascii_domain)
.await
.map_err(StartTlsError::Resolve)?;
// Happy Eyeballs: connect to all records in parallel, return the
// first to succeed
select_ok(
ips.into_iter()
.map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
)
.await
.map(|(result, _)| result)
.map_err(|_| crate::Error::Disconnected)
}
pub async fn connect_with_srv(
domain: &str,
srv: &str,
fallback_port: u16,
) -> Result<TcpStream, Error> {
let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| StartTlsError::Idna)?;
if let Ok(ip) = ascii_domain.parse() {
debug!("Attempting connection to {ip}:{fallback_port}");
return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
}
let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(StartTlsError::Resolve)?;
let srv_domain = format!("{}.{}.", srv, ascii_domain)
.into_name()
.map_err(StartTlsError::Dns)?;
let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
match srv_records {
Some(lookup) => {
// TODO: sort lookup records by priority/weight
for srv in lookup.iter() {
debug!("Attempting connection to {srv_domain} {srv}");
match connect_to_host(&srv.target().to_ascii(), srv.port()).await {
Ok(stream) => return Ok(stream),
Err(_) => {}
}
}
Err(crate::Error::Disconnected.into())
}
None => {
// SRV lookup error, retry with hostname
debug!("Attempting connection to {domain}:{fallback_port}");
connect_to_host(domain, fallback_port).await
}
}
}

View file

@ -27,16 +27,17 @@ use tokio::{
};
use xmpp_parsers::{jid::Jid, ns};
use crate::error::ProtocolError;
use crate::Error;
use crate::{connect::ServerConnector, xmpp_codec::Packet, AsyncClient, SimpleClient};
use crate::{connect::ServerConnectorError, xmpp_stream::XMPPStream};
use crate::{
connect::{ServerConnector, ServerConnectorError, Tcp},
error::{Error, ProtocolError},
xmpp_codec::Packet,
xmpp_stream::XMPPStream,
AsyncClient, SimpleClient,
};
use self::error::Error as StartTlsError;
use self::happy_eyeballs::{connect_to_host, connect_with_srv};
pub mod error;
mod happy_eyeballs;
/// AsyncClient that connects over StartTls
pub type StartTlsAsyncClient = AsyncClient<ServerConfig>;
@ -64,9 +65,9 @@ impl ServerConnector for ServerConfig {
// TCP connection
let tcp_stream = match self {
ServerConfig::UseSrv => {
connect_with_srv(jid.domain().as_str(), "_xmpp-client._tcp", 5222).await?
Tcp::resolve_with_srv(jid.domain().as_str(), "_xmpp-client._tcp", 5222).await?
}
ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?,
ServerConfig::Manual { host, port } => Tcp::resolve(host.as_str(), *port).await?,
};
// Unencryped XMPPStream