Use ServerConfig enum for tokio-xmpp client config

And expose connect_to_host from happy_eyeballs to let clients explicitly
choose to use SRV or not. (Rename connect to connect_with_srv)
This commit is contained in:
Paul Fariello 2020-12-06 15:23:29 +01:00
parent 08e58e44b1
commit 7b4a6e3ace
4 changed files with 43 additions and 27 deletions

View file

@ -1,5 +1,4 @@
use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream}; use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream};
use idna;
use sasl::common::{ChannelBinding, Credentials}; use sasl::common::{ChannelBinding, Credentials};
use std::mem::replace; use std::mem::replace;
use std::pin::Pin; use std::pin::Pin;
@ -14,7 +13,7 @@ use xmpp_parsers::{ns, Element, Jid, JidParseError};
use super::auth::auth; use super::auth::auth;
use super::bind::bind; use super::bind::bind;
use crate::event::Event; use crate::event::Event;
use crate::happy_eyeballs::connect; use crate::happy_eyeballs::{connect_to_host, connect_with_srv};
use crate::starttls::starttls; use crate::starttls::starttls;
use crate::xmpp_codec::Packet; use crate::xmpp_codec::Packet;
use crate::xmpp_stream; use crate::xmpp_stream;
@ -33,12 +32,22 @@ pub struct Client {
// TODO: tls_required=true // TODO: tls_required=true
} }
/// XMPP server connection configuration
#[derive(Clone)]
pub enum ServerConfig {
UseSrv,
#[allow(unused)]
Manual {
host: String,
port: u16,
},
}
/// XMMPP client configuration /// XMMPP client configuration
pub struct Config { pub struct Config {
jid: Jid, jid: Jid,
password: String, password: String,
server: String, server: ServerConfig,
port: u16,
} }
type XMPPStream = xmpp_stream::XMPPStream<TlsStream<TcpStream>>; type XMPPStream = xmpp_stream::XMPPStream<TlsStream<TcpStream>>;
@ -60,8 +69,7 @@ impl Client {
let config = Config { let config = Config {
jid: jid.clone(), jid: jid.clone(),
password: password.into(), password: password.into(),
server: jid.clone().domain(), server: ServerConfig::UseSrv,
port: 5222,
}; };
let client = Self::new_with_config(config); let client = Self::new_with_config(config);
Ok(client) Ok(client)
@ -72,7 +80,6 @@ impl Client {
let local = LocalSet::new(); let local = LocalSet::new();
let connect = local.spawn_local(Self::connect( let connect = local.spawn_local(Self::connect(
config.server.clone(), config.server.clone(),
config.port,
config.jid.clone(), config.jid.clone(),
config.password.clone(), config.password.clone(),
)); ));
@ -92,17 +99,20 @@ impl Client {
} }
async fn connect( async fn connect(
server: String, server: ServerConfig,
port: u16,
jid: Jid, jid: Jid,
password: String, password: String,
) -> Result<XMPPStream, Error> { ) -> Result<XMPPStream, Error> {
let username = jid.clone().node().unwrap(); let username = jid.clone().node().unwrap();
let password = password; let password = password;
let domain = idna::domain_to_ascii(&server).map_err(|_| Error::Idna)?;
// TCP connection // TCP connection
let tcp_stream = connect(&domain, Some("_xmpp-client._tcp"), port).await?; let tcp_stream = match server {
ServerConfig::UseSrv => {
connect_with_srv(&jid.clone().domain(), Some("_xmpp-client._tcp"), 5222).await?
}
ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), port).await?,
};
// Unencryped XMPPStream // Unencryped XMPPStream
let xmpp_stream = let xmpp_stream =
@ -186,7 +196,6 @@ impl Stream for Client {
let mut local = LocalSet::new(); let mut local = LocalSet::new();
let connect = local.spawn_local(Self::connect( let connect = local.spawn_local(Self::connect(
self.config.server.clone(), self.config.server.clone(),
self.config.port,
self.config.jid.clone(), self.config.jid.clone(),
self.config.password.clone(), self.config.password.clone(),
)); ));

View file

@ -11,7 +11,7 @@ use xmpp_parsers::{ns, Element, Jid};
use super::auth::auth; use super::auth::auth;
use super::bind::bind; use super::bind::bind;
use crate::happy_eyeballs::connect; use crate::happy_eyeballs::connect_with_srv;
use crate::starttls::starttls; use crate::starttls::starttls;
use crate::xmpp_codec::Packet; use crate::xmpp_codec::Packet;
use crate::xmpp_stream; use crate::xmpp_stream;
@ -47,7 +47,7 @@ impl Client {
let domain = idna::domain_to_ascii(&jid.clone().domain()).map_err(|_| Error::Idna)?; let domain = idna::domain_to_ascii(&jid.clone().domain()).map_err(|_| Error::Idna)?;
// TCP connection // TCP connection
let tcp_stream = connect(&domain, Some("_xmpp-client._tcp"), 5222).await?; let tcp_stream = connect_with_srv(&domain, Some("_xmpp-client._tcp"), 5222).await?;
// Unencryped XMPPStream // Unencryped XMPPStream
let xmpp_stream = let xmpp_stream =

View file

@ -8,7 +8,7 @@ use std::task::Context;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use xmpp_parsers::{ns, Element, Jid}; use xmpp_parsers::{ns, Element, Jid};
use super::happy_eyeballs::connect; use super::happy_eyeballs::connect_to_host;
use super::xmpp_codec::Packet; use super::xmpp_codec::Packet;
use super::xmpp_stream; use super::xmpp_stream;
use super::Error; use super::Error;
@ -43,7 +43,7 @@ impl Component {
port: u16, port: u16,
) -> Result<XMPPStream, Error> { ) -> Result<XMPPStream, Error> {
let password = password; let password = password;
let tcp_stream = connect(server, None, port).await?; let tcp_stream = connect_to_host(server, port).await?;
let mut xmpp_stream = let mut xmpp_stream =
xmpp_stream::XMPPStream::start(tcp_stream, jid, ns::COMPONENT_ACCEPT.to_owned()) xmpp_stream::XMPPStream::start(tcp_stream, jid, ns::COMPONENT_ACCEPT.to_owned())
.await?; .await?;

View file

@ -1,15 +1,20 @@
use crate::{ConnecterError, Error}; use crate::{ConnecterError, Error};
use idna;
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use trust_dns_resolver::{IntoName, TokioAsyncResolver}; use trust_dns_resolver::{IntoName, TokioAsyncResolver};
async fn connect_to_host( pub async fn connect_to_host(domain: &str, port: u16) -> Result<TcpStream, Error> {
resolver: &TokioAsyncResolver, let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| Error::Idna)?;
host: &str,
port: u16, if let Ok(ip) = ascii_domain.parse() {
) -> Result<TcpStream, Error> { return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
}
let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(ConnecterError::Resolve)?;
let ips = resolver let ips = resolver
.lookup_ip(host) .lookup_ip(ascii_domain)
.await .await
.map_err(ConnecterError::Resolve)?; .map_err(ConnecterError::Resolve)?;
for ip in ips.iter() { for ip in ips.iter() {
@ -21,12 +26,14 @@ async fn connect_to_host(
Err(Error::Disconnected) Err(Error::Disconnected)
} }
pub async fn connect( pub async fn connect_with_srv(
domain: &str, domain: &str,
srv: Option<&str>, srv: Option<&str>,
fallback_port: u16, fallback_port: u16,
) -> Result<TcpStream, Error> { ) -> Result<TcpStream, Error> {
if let Ok(ip) = domain.parse() { let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| Error::Idna)?;
if let Ok(ip) = ascii_domain.parse() {
return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?); return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
} }
@ -34,7 +41,7 @@ pub async fn connect(
let srv_records = match srv { let srv_records = match srv {
Some(srv) => { Some(srv) => {
let srv_domain = format!("{}.{}.", srv, domain) let srv_domain = format!("{}.{}.", srv, ascii_domain)
.into_name() .into_name()
.map_err(ConnecterError::Dns)?; .map_err(ConnecterError::Dns)?;
resolver.srv_lookup(srv_domain).await.ok() resolver.srv_lookup(srv_domain).await.ok()
@ -46,7 +53,7 @@ pub async fn connect(
Some(lookup) => { Some(lookup) => {
// TODO: sort lookup records by priority/weight // TODO: sort lookup records by priority/weight
for srv in lookup.iter() { for srv in lookup.iter() {
match connect_to_host(&resolver, &srv.target().to_ascii(), srv.port()).await { match connect_to_host(&srv.target().to_ascii(), srv.port()).await {
Ok(stream) => return Ok(stream), Ok(stream) => return Ok(stream),
Err(_) => {} Err(_) => {}
} }
@ -55,7 +62,7 @@ pub async fn connect(
} }
None => { None => {
// SRV lookup error, retry with hostname // SRV lookup error, retry with hostname
connect_to_host(&resolver, domain, fallback_port).await connect_to_host(domain, fallback_port).await
} }
} }
} }