Add dns
feature for DNS stuff (not just in starttls)
This commit is contained in:
parent
d706b318c3
commit
97698b4d1e
8 changed files with 145 additions and 96 deletions
|
@ -41,7 +41,7 @@ tokio-xmpp = { path = ".", features = ["insecure-tcp"]}
|
|||
|
||||
[features]
|
||||
default = ["starttls-rust"]
|
||||
starttls = ["hickory-resolver", "idna"]
|
||||
starttls = ["dns"]
|
||||
tls-rust = ["tokio-rustls", "webpki-roots"]
|
||||
tls-native = ["tokio-native-tls", "native-tls"]
|
||||
starttls-native = ["starttls", "tls-native"]
|
||||
|
@ -50,6 +50,8 @@ insecure-tcp = []
|
|||
syntax-highlighting = ["syntect"]
|
||||
# Enable serde support in jid crate
|
||||
serde = [ "xmpp-parsers/serde" ]
|
||||
# Required by starttls, and used by insecure-tcp by default
|
||||
dns = [ "hickory-resolver", "idna" ]
|
||||
|
||||
[lints.rust]
|
||||
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(xmpprs_doc_build)'] }
|
||||
|
|
|
@ -10,9 +10,11 @@ use super::connect::client_login;
|
|||
use crate::connect::{AsyncReadAndWrite, ServerConnector};
|
||||
use crate::error::{Error, ProtocolError};
|
||||
use crate::event::Event;
|
||||
#[cfg(feature = "starttls")]
|
||||
use crate::starttls::ServerConfig;
|
||||
use crate::xmpp_codec::Packet;
|
||||
use crate::xmpp_stream::{add_stanza_id, XMPPStream};
|
||||
#[cfg(feature = "starttls")]
|
||||
use crate::AsyncConfig;
|
||||
|
||||
/// XMPP client connection and state
|
||||
|
@ -46,6 +48,7 @@ enum ClientState<S: AsyncReadAndWrite> {
|
|||
Connected(XMPPStream<S>),
|
||||
}
|
||||
|
||||
#[cfg(feature = "starttls")]
|
||||
impl Client<ServerConfig> {
|
||||
/// Start a new XMPP client using StartTLS transport and autoreconnect
|
||||
///
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
use futures::{sink::SinkExt, Sink, Stream};
|
||||
use minidom::Element;
|
||||
use std::pin::Pin;
|
||||
#[cfg(feature = "starttls")]
|
||||
use std::str::FromStr;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio_stream::StreamExt;
|
||||
use xmpp_parsers::{jid::Jid, ns, stream_features::StreamFeatures};
|
||||
|
||||
use crate::connect::ServerConnector;
|
||||
#[cfg(feature = "starttls")]
|
||||
use crate::starttls::ServerConfig;
|
||||
use crate::xmpp_codec::Packet;
|
||||
use crate::xmpp_stream::{add_stanza_id, XMPPStream};
|
||||
|
@ -22,6 +24,7 @@ pub struct Client<C: ServerConnector> {
|
|||
stream: XMPPStream<C::Stream>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "starttls")]
|
||||
impl Client<ServerConfig> {
|
||||
/// Start a new XMPP client and wait for a usable session
|
||||
pub async fn new<P: Into<String>>(jid: &str, password: P) -> Result<Self, Error> {
|
||||
|
|
|
@ -1,7 +1,17 @@
|
|||
//! `ServerConnector` provides streams for XMPP clients
|
||||
|
||||
#[cfg(feature = "dns")]
|
||||
use futures::{future::select_ok, FutureExt};
|
||||
#[cfg(feature = "dns")]
|
||||
use hickory_resolver::{
|
||||
config::LookupIpStrategy, name_server::TokioConnectionProvider, IntoName, TokioAsyncResolver,
|
||||
};
|
||||
#[cfg(feature = "dns")]
|
||||
use log::debug;
|
||||
use sasl::common::ChannelBinding;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use xmpp_parsers::jid::Jid;
|
||||
|
||||
use crate::xmpp_stream::XMPPStream;
|
||||
|
@ -32,3 +42,78 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
|
|||
Ok(ChannelBinding::None)
|
||||
}
|
||||
}
|
||||
|
||||
/// A simple wrapper to build [`TcpStream`]
|
||||
pub struct Tcp;
|
||||
|
||||
impl Tcp {
|
||||
/// Connect directly to an IP/Port combo
|
||||
pub async fn connect(ip: IpAddr, port: u16) -> Result<TcpStream, Error> {
|
||||
Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?)
|
||||
}
|
||||
|
||||
/// Connect over TCP, resolving A/AAAA records (happy eyeballs)
|
||||
#[cfg(feature = "dns")]
|
||||
pub async fn resolve(domain: &str, port: u16) -> Result<TcpStream, Error> {
|
||||
let ascii_domain = idna::domain_to_ascii(&domain)?;
|
||||
|
||||
if let Ok(ip) = ascii_domain.parse() {
|
||||
return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
|
||||
}
|
||||
|
||||
let (config, mut options) = hickory_resolver::system_conf::read_system_conf()?;
|
||||
options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
|
||||
let resolver = TokioAsyncResolver::new(config, options, TokioConnectionProvider::default());
|
||||
|
||||
let ips = resolver.lookup_ip(ascii_domain).await?;
|
||||
|
||||
// Happy Eyeballs: connect to all records in parallel, return the
|
||||
// first to succeed
|
||||
select_ok(
|
||||
ips.into_iter()
|
||||
.map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
|
||||
)
|
||||
.await
|
||||
.map(|(result, _)| result)
|
||||
.map_err(|_| Error::Disconnected)
|
||||
}
|
||||
|
||||
/// Connect over TCP, resolving SRV records
|
||||
#[cfg(feature = "dns")]
|
||||
pub async fn resolve_with_srv(
|
||||
domain: &str,
|
||||
srv: &str,
|
||||
fallback_port: u16,
|
||||
) -> Result<TcpStream, Error> {
|
||||
let ascii_domain = idna::domain_to_ascii(&domain)?;
|
||||
|
||||
if let Ok(ip) = ascii_domain.parse() {
|
||||
debug!("Attempting connection to {ip}:{fallback_port}");
|
||||
return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
|
||||
}
|
||||
|
||||
let resolver = TokioAsyncResolver::tokio_from_system_conf()?;
|
||||
|
||||
let srv_domain = format!("{}.{}.", srv, ascii_domain).into_name()?;
|
||||
let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
|
||||
|
||||
match srv_records {
|
||||
Some(lookup) => {
|
||||
// TODO: sort lookup records by priority/weight
|
||||
for srv in lookup.iter() {
|
||||
debug!("Attempting connection to {srv_domain} {srv}");
|
||||
match Self::resolve(&srv.target().to_ascii(), srv.port()).await {
|
||||
Ok(stream) => return Ok(stream),
|
||||
Err(_) => {}
|
||||
}
|
||||
}
|
||||
Err(Error::Disconnected)
|
||||
}
|
||||
None => {
|
||||
// SRV lookup error, retry with hostname
|
||||
debug!("Attempting connection to {domain}:{fallback_port}");
|
||||
Self::resolve(domain, fallback_port).await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
#[cfg(feature = "dns")]
|
||||
use hickory_resolver::{
|
||||
error::ResolveError as DnsResolveError, proto::error::ProtoError as DnsProtoError,
|
||||
};
|
||||
use sasl::client::MechanismError as SaslMechanismError;
|
||||
use std::error::Error as StdError;
|
||||
use std::fmt;
|
||||
|
@ -28,8 +32,18 @@ pub enum Error {
|
|||
Fmt(fmt::Error),
|
||||
/// Utf8 error
|
||||
Utf8(Utf8Error),
|
||||
/// Error resolving DNS and/or establishing a connection, returned by a ServerConnector impl
|
||||
/// Error specific to ServerConnector impl
|
||||
Connection(Box<dyn ServerConnectorError>),
|
||||
/// DNS protocol error
|
||||
#[cfg(feature = "dns")]
|
||||
Dns(DnsProtoError),
|
||||
/// DNS resolution error
|
||||
#[cfg(feature = "dns")]
|
||||
Resolve(DnsResolveError),
|
||||
/// DNS label conversion error, no details available from module
|
||||
/// `idna`
|
||||
#[cfg(feature = "dns")]
|
||||
Idna,
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
|
@ -44,6 +58,12 @@ impl fmt::Display for Error {
|
|||
Error::InvalidState => write!(fmt, "invalid state"),
|
||||
Error::Fmt(e) => write!(fmt, "Fmt error: {}", e),
|
||||
Error::Utf8(e) => write!(fmt, "Utf8 error: {}", e),
|
||||
#[cfg(feature = "dns")]
|
||||
Error::Dns(e) => write!(fmt, "{:?}", e),
|
||||
#[cfg(feature = "dns")]
|
||||
Error::Resolve(e) => write!(fmt, "{:?}", e),
|
||||
#[cfg(feature = "dns")]
|
||||
Error::Idna => write!(fmt, "IDNA error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -92,6 +112,27 @@ impl From<Utf8Error> for Error {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "dns")]
|
||||
impl From<idna::Errors> for Error {
|
||||
fn from(_e: idna::Errors) -> Self {
|
||||
Error::Idna
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "dns")]
|
||||
impl From<DnsResolveError> for Error {
|
||||
fn from(e: DnsResolveError) -> Error {
|
||||
Error::Resolve(e)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "dns")]
|
||||
impl From<DnsProtoError> for Error {
|
||||
fn from(e: DnsProtoError) -> Error {
|
||||
Error::Dns(e)
|
||||
}
|
||||
}
|
||||
|
||||
/// XMPP protocol-level error
|
||||
#[derive(Debug)]
|
||||
pub enum ProtocolError {
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
//! StartTLS ServerConnector Error
|
||||
|
||||
use hickory_resolver::{error::ResolveError, proto::error::ProtoError};
|
||||
#[cfg(feature = "tls-native")]
|
||||
use native_tls::Error as TlsError;
|
||||
use std::error::Error as StdError;
|
||||
|
@ -15,13 +14,6 @@ use super::ServerConnectorError;
|
|||
/// StartTLS ServerConnector Error
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
/// DNS protocol error
|
||||
Dns(ProtoError),
|
||||
/// DNS resolution error
|
||||
Resolve(ResolveError),
|
||||
/// DNS label conversion error, no details available from module
|
||||
/// `idna`
|
||||
Idna,
|
||||
/// TLS error
|
||||
Tls(TlsError),
|
||||
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
|
||||
|
@ -34,9 +26,6 @@ impl ServerConnectorError for Error {}
|
|||
impl fmt::Display for Error {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Self::Dns(e) => write!(fmt, "{:?}", e),
|
||||
Self::Resolve(e) => write!(fmt, "{:?}", e),
|
||||
Self::Idna => write!(fmt, "IDNA error"),
|
||||
Self::Tls(e) => write!(fmt, "TLS error: {}", e),
|
||||
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
|
||||
Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
|
||||
|
|
|
@ -1,75 +0,0 @@
|
|||
use super::error::Error as StartTlsError;
|
||||
use crate::Error;
|
||||
use futures::{future::select_ok, FutureExt};
|
||||
use hickory_resolver::{
|
||||
config::LookupIpStrategy, name_server::TokioConnectionProvider, IntoName, TokioAsyncResolver,
|
||||
};
|
||||
use log::debug;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
pub async fn connect_to_host(domain: &str, port: u16) -> Result<TcpStream, Error> {
|
||||
let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| StartTlsError::Idna)?;
|
||||
|
||||
if let Ok(ip) = ascii_domain.parse() {
|
||||
return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
|
||||
}
|
||||
|
||||
let (config, mut options) =
|
||||
hickory_resolver::system_conf::read_system_conf().map_err(StartTlsError::Resolve)?;
|
||||
options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
|
||||
let resolver = TokioAsyncResolver::new(config, options, TokioConnectionProvider::default());
|
||||
|
||||
let ips = resolver
|
||||
.lookup_ip(ascii_domain)
|
||||
.await
|
||||
.map_err(StartTlsError::Resolve)?;
|
||||
// Happy Eyeballs: connect to all records in parallel, return the
|
||||
// first to succeed
|
||||
select_ok(
|
||||
ips.into_iter()
|
||||
.map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
|
||||
)
|
||||
.await
|
||||
.map(|(result, _)| result)
|
||||
.map_err(|_| crate::Error::Disconnected)
|
||||
}
|
||||
|
||||
pub async fn connect_with_srv(
|
||||
domain: &str,
|
||||
srv: &str,
|
||||
fallback_port: u16,
|
||||
) -> Result<TcpStream, Error> {
|
||||
let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| StartTlsError::Idna)?;
|
||||
|
||||
if let Ok(ip) = ascii_domain.parse() {
|
||||
debug!("Attempting connection to {ip}:{fallback_port}");
|
||||
return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
|
||||
}
|
||||
|
||||
let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(StartTlsError::Resolve)?;
|
||||
|
||||
let srv_domain = format!("{}.{}.", srv, ascii_domain)
|
||||
.into_name()
|
||||
.map_err(StartTlsError::Dns)?;
|
||||
let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
|
||||
|
||||
match srv_records {
|
||||
Some(lookup) => {
|
||||
// TODO: sort lookup records by priority/weight
|
||||
for srv in lookup.iter() {
|
||||
debug!("Attempting connection to {srv_domain} {srv}");
|
||||
match connect_to_host(&srv.target().to_ascii(), srv.port()).await {
|
||||
Ok(stream) => return Ok(stream),
|
||||
Err(_) => {}
|
||||
}
|
||||
}
|
||||
Err(crate::Error::Disconnected.into())
|
||||
}
|
||||
None => {
|
||||
// SRV lookup error, retry with hostname
|
||||
debug!("Attempting connection to {domain}:{fallback_port}");
|
||||
connect_to_host(domain, fallback_port).await
|
||||
}
|
||||
}
|
||||
}
|
|
@ -27,16 +27,17 @@ use tokio::{
|
|||
};
|
||||
use xmpp_parsers::{jid::Jid, ns};
|
||||
|
||||
use crate::error::ProtocolError;
|
||||
use crate::Error;
|
||||
use crate::{connect::ServerConnector, xmpp_codec::Packet, AsyncClient, SimpleClient};
|
||||
use crate::{connect::ServerConnectorError, xmpp_stream::XMPPStream};
|
||||
use crate::{
|
||||
connect::{ServerConnector, ServerConnectorError, Tcp},
|
||||
error::{Error, ProtocolError},
|
||||
xmpp_codec::Packet,
|
||||
xmpp_stream::XMPPStream,
|
||||
AsyncClient, SimpleClient,
|
||||
};
|
||||
|
||||
use self::error::Error as StartTlsError;
|
||||
use self::happy_eyeballs::{connect_to_host, connect_with_srv};
|
||||
|
||||
pub mod error;
|
||||
mod happy_eyeballs;
|
||||
|
||||
/// AsyncClient that connects over StartTls
|
||||
pub type StartTlsAsyncClient = AsyncClient<ServerConfig>;
|
||||
|
@ -64,9 +65,9 @@ impl ServerConnector for ServerConfig {
|
|||
// TCP connection
|
||||
let tcp_stream = match self {
|
||||
ServerConfig::UseSrv => {
|
||||
connect_with_srv(jid.domain().as_str(), "_xmpp-client._tcp", 5222).await?
|
||||
Tcp::resolve_with_srv(jid.domain().as_str(), "_xmpp-client._tcp", 5222).await?
|
||||
}
|
||||
ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?,
|
||||
ServerConfig::Manual { host, port } => Tcp::resolve(host.as_str(), *port).await?,
|
||||
};
|
||||
|
||||
// Unencryped XMPPStream
|
||||
|
|
Loading…
Reference in a new issue