diff --git a/tokio-xmpp/src/client/async_client.rs b/tokio-xmpp/src/client/async_client.rs index e8ead197..9f75c96e 100644 --- a/tokio-xmpp/src/client/async_client.rs +++ b/tokio-xmpp/src/client/async_client.rs @@ -1,5 +1,4 @@ use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream}; -use idna; use sasl::common::{ChannelBinding, Credentials}; use std::mem::replace; use std::pin::Pin; @@ -14,7 +13,7 @@ use xmpp_parsers::{ns, Element, Jid, JidParseError}; use super::auth::auth; use super::bind::bind; use crate::event::Event; -use crate::happy_eyeballs::connect; +use crate::happy_eyeballs::{connect_to_host, connect_with_srv}; use crate::starttls::starttls; use crate::xmpp_codec::Packet; use crate::xmpp_stream; @@ -33,12 +32,22 @@ pub struct Client { // TODO: tls_required=true } +/// XMPP server connection configuration +#[derive(Clone)] +pub enum ServerConfig { + UseSrv, + #[allow(unused)] + Manual { + host: String, + port: u16, + }, +} + /// XMMPP client configuration pub struct Config { jid: Jid, password: String, - server: String, - port: u16, + server: ServerConfig, } type XMPPStream = xmpp_stream::XMPPStream>; @@ -60,8 +69,7 @@ impl Client { let config = Config { jid: jid.clone(), password: password.into(), - server: jid.clone().domain(), - port: 5222, + server: ServerConfig::UseSrv, }; let client = Self::new_with_config(config); Ok(client) @@ -72,7 +80,6 @@ impl Client { let local = LocalSet::new(); let connect = local.spawn_local(Self::connect( config.server.clone(), - config.port, config.jid.clone(), config.password.clone(), )); @@ -92,17 +99,20 @@ impl Client { } async fn connect( - server: String, - port: u16, + server: ServerConfig, jid: Jid, password: String, ) -> Result { let username = jid.clone().node().unwrap(); let password = password; - let domain = idna::domain_to_ascii(&server).map_err(|_| Error::Idna)?; // TCP connection - let tcp_stream = connect(&domain, Some("_xmpp-client._tcp"), port).await?; + let tcp_stream = match server { + ServerConfig::UseSrv => { + connect_with_srv(&jid.clone().domain(), Some("_xmpp-client._tcp"), 5222).await? + } + ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), port).await?, + }; // Unencryped XMPPStream let xmpp_stream = @@ -186,7 +196,6 @@ impl Stream for Client { let mut local = LocalSet::new(); let connect = local.spawn_local(Self::connect( self.config.server.clone(), - self.config.port, self.config.jid.clone(), self.config.password.clone(), )); diff --git a/tokio-xmpp/src/client/simple_client.rs b/tokio-xmpp/src/client/simple_client.rs index e7a58150..4f90ca27 100644 --- a/tokio-xmpp/src/client/simple_client.rs +++ b/tokio-xmpp/src/client/simple_client.rs @@ -11,7 +11,7 @@ use xmpp_parsers::{ns, Element, Jid}; use super::auth::auth; use super::bind::bind; -use crate::happy_eyeballs::connect; +use crate::happy_eyeballs::connect_with_srv; use crate::starttls::starttls; use crate::xmpp_codec::Packet; use crate::xmpp_stream; @@ -47,7 +47,7 @@ impl Client { let domain = idna::domain_to_ascii(&jid.clone().domain()).map_err(|_| Error::Idna)?; // TCP connection - let tcp_stream = connect(&domain, Some("_xmpp-client._tcp"), 5222).await?; + let tcp_stream = connect_with_srv(&domain, Some("_xmpp-client._tcp"), 5222).await?; // Unencryped XMPPStream let xmpp_stream = diff --git a/tokio-xmpp/src/component/mod.rs b/tokio-xmpp/src/component/mod.rs index 3e2ab479..6ac1371f 100644 --- a/tokio-xmpp/src/component/mod.rs +++ b/tokio-xmpp/src/component/mod.rs @@ -8,7 +8,7 @@ use std::task::Context; use tokio::net::TcpStream; use xmpp_parsers::{ns, Element, Jid}; -use super::happy_eyeballs::connect; +use super::happy_eyeballs::connect_to_host; use super::xmpp_codec::Packet; use super::xmpp_stream; use super::Error; @@ -43,7 +43,7 @@ impl Component { port: u16, ) -> Result { let password = password; - let tcp_stream = connect(server, None, port).await?; + let tcp_stream = connect_to_host(server, port).await?; let mut xmpp_stream = xmpp_stream::XMPPStream::start(tcp_stream, jid, ns::COMPONENT_ACCEPT.to_owned()) .await?; diff --git a/tokio-xmpp/src/happy_eyeballs.rs b/tokio-xmpp/src/happy_eyeballs.rs index a687ca5f..5ad09729 100644 --- a/tokio-xmpp/src/happy_eyeballs.rs +++ b/tokio-xmpp/src/happy_eyeballs.rs @@ -1,15 +1,20 @@ use crate::{ConnecterError, Error}; +use idna; use std::net::SocketAddr; use tokio::net::TcpStream; use trust_dns_resolver::{IntoName, TokioAsyncResolver}; -async fn connect_to_host( - resolver: &TokioAsyncResolver, - host: &str, - port: u16, -) -> Result { +pub async fn connect_to_host(domain: &str, port: u16) -> Result { + let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| Error::Idna)?; + + if let Ok(ip) = ascii_domain.parse() { + return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?); + } + + let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(ConnecterError::Resolve)?; + let ips = resolver - .lookup_ip(host) + .lookup_ip(ascii_domain) .await .map_err(ConnecterError::Resolve)?; for ip in ips.iter() { @@ -21,12 +26,14 @@ async fn connect_to_host( Err(Error::Disconnected) } -pub async fn connect( +pub async fn connect_with_srv( domain: &str, srv: Option<&str>, fallback_port: u16, ) -> Result { - if let Ok(ip) = domain.parse() { + let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| Error::Idna)?; + + if let Ok(ip) = ascii_domain.parse() { return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?); } @@ -34,7 +41,7 @@ pub async fn connect( let srv_records = match srv { Some(srv) => { - let srv_domain = format!("{}.{}.", srv, domain) + let srv_domain = format!("{}.{}.", srv, ascii_domain) .into_name() .map_err(ConnecterError::Dns)?; resolver.srv_lookup(srv_domain).await.ok() @@ -46,7 +53,7 @@ pub async fn connect( Some(lookup) => { // TODO: sort lookup records by priority/weight for srv in lookup.iter() { - match connect_to_host(&resolver, &srv.target().to_ascii(), srv.port()).await { + match connect_to_host(&srv.target().to_ascii(), srv.port()).await { Ok(stream) => return Ok(stream), Err(_) => {} } @@ -55,7 +62,7 @@ pub async fn connect( } None => { // SRV lookup error, retry with hostname - connect_to_host(&resolver, domain, fallback_port).await + connect_to_host(domain, fallback_port).await } } }