diff --git a/tokio-xmpp/Cargo.toml b/tokio-xmpp/Cargo.toml index a8ca13b9..6e6eeee3 100644 --- a/tokio-xmpp/Cargo.toml +++ b/tokio-xmpp/Cargo.toml @@ -16,19 +16,24 @@ bytes = "1" futures = "0.3" idna = "0.2" log = "0.4" -native-tls = "0.2" +native-tls = { version = "0.2", optional = true } sasl = "0.5" tokio = { version = "1", features = ["net", "rt", "rt-multi-thread", "macros"] } -tokio-util = { version = "0.6", features = ["codec"] } +tokio-native-tls = { version = "0.3", optional = true } +tokio-rustls = { version = "0.22", optional = true } tokio-stream = { version = "0.1", features = [] } -tokio-tls = { package = "tokio-native-tls", version = "0.3" } -trust-dns-resolver = "0.20" +tokio-util = { version = "0.6", features = ["codec"] } trust-dns-proto = "0.20" +trust-dns-resolver = "0.20" xml5ever = "0.16" xmpp-parsers = "0.18" +webpki = { version = "0.21", optional = true } [build-dependencies] rustc_version = "0.3" [features] +default = ["tls-native"] +tls-rust = ["tokio-rustls", "webpki"] +tls-native = ["tokio-native-tls", "native-tls"] serde = ["xmpp-parsers/serde"] diff --git a/tokio-xmpp/src/client/async_client.rs b/tokio-xmpp/src/client/async_client.rs index 79a7aa07..e2fdb102 100644 --- a/tokio-xmpp/src/client/async_client.rs +++ b/tokio-xmpp/src/client/async_client.rs @@ -7,7 +7,10 @@ use std::task::Context; use tokio::net::TcpStream; use tokio::task::JoinHandle; use tokio::task::LocalSet; -use tokio_tls::TlsStream; +#[cfg(feature = "tls-native")] +use tokio_native_tls::TlsStream; +#[cfg(feature = "tls-rust")] +use tokio_rustls::client::TlsStream; use xmpp_parsers::{ns, Element, Jid, JidParseError}; use super::auth::auth; diff --git a/tokio-xmpp/src/client/simple_client.rs b/tokio-xmpp/src/client/simple_client.rs index a2d666a8..4b69045b 100644 --- a/tokio-xmpp/src/client/simple_client.rs +++ b/tokio-xmpp/src/client/simple_client.rs @@ -5,8 +5,11 @@ use std::pin::Pin; use std::str::FromStr; use std::task::{Context, Poll}; use tokio::net::TcpStream; +#[cfg(feature = "tls-native")] +use tokio_native_tls::TlsStream; +#[cfg(feature = "tls-rust")] +use tokio_rustls::client::TlsStream; use tokio_stream::StreamExt; -use tokio_tls::TlsStream; use xmpp_parsers::{ns, Element, Jid}; use super::auth::auth; diff --git a/tokio-xmpp/src/error.rs b/tokio-xmpp/src/error.rs index 7a03037e..42d1f9ea 100644 --- a/tokio-xmpp/src/error.rs +++ b/tokio-xmpp/src/error.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "tls-native")] use native_tls::Error as TlsError; use sasl::client::MechanismError as SaslMechanismError; use std::borrow::Cow; @@ -5,6 +6,8 @@ use std::error::Error as StdError; use std::fmt; use std::io::Error as IoError; use std::str::Utf8Error; +#[cfg(feature = "tls-rust")] +use tokio_rustls::rustls::TLSError as TlsError; use trust_dns_proto::error::ProtoError; use trust_dns_resolver::error::ResolveError; diff --git a/tokio-xmpp/src/starttls.rs b/tokio-xmpp/src/starttls.rs index 1515d050..26985e0d 100644 --- a/tokio-xmpp/src/starttls.rs +++ b/tokio-xmpp/src/starttls.rs @@ -1,13 +1,49 @@ use futures::{sink::SinkExt, stream::StreamExt}; +#[cfg(feature = "tls-rust")] +use idna; +#[cfg(feature = "tls-native")] use native_tls::TlsConnector as NativeTlsConnector; +#[cfg(feature = "tls-rust")] +use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_tls::{TlsConnector, TlsStream}; +#[cfg(feature = "tls-native")] +use tokio_native_tls::{TlsConnector, TlsStream}; +#[cfg(feature = "tls-rust")] +use tokio_rustls::{client::TlsStream, rustls::ClientConfig, TlsConnector}; +#[cfg(feature = "tls-rust")] +use webpki::DNSNameRef; use xmpp_parsers::{ns, Element}; use crate::xmpp_codec::Packet; use crate::xmpp_stream::XMPPStream; use crate::{Error, ProtocolError}; +#[cfg(feature = "tls-native")] +async fn get_tls_stream( + xmpp_stream: XMPPStream, +) -> Result, Error> { + let domain = &xmpp_stream.jid.clone().domain(); + let stream = xmpp_stream.into_inner(); + let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) + .connect(&domain, stream) + .await?; + Ok(tls_stream) +} + +#[cfg(feature = "tls-rust")] +async fn get_tls_stream( + xmpp_stream: XMPPStream, +) -> Result, Error> { + let domain = &xmpp_stream.jid.clone().domain(); + let ascii_domain = idna::domain_to_ascii(domain).map_err(|_| Error::Idna)?; + let domain = DNSNameRef::try_from_ascii_str(&ascii_domain).unwrap(); + let stream = xmpp_stream.into_inner(); + let tls_stream = TlsConnector::from(Arc::new(ClientConfig::new())) + .connect(domain, stream) + .await?; + Ok(tls_stream) +} + /// Performs `` on an XMPPStream and returns a binary /// TlsStream. pub async fn starttls( @@ -28,11 +64,5 @@ pub async fn starttls( } } - let domain = xmpp_stream.jid.clone().domain(); - let stream = xmpp_stream.into_inner(); - let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) - .connect(&domain, stream) - .await?; - - Ok(tls_stream) + get_tls_stream(xmpp_stream).await }