DNS/TLS deps are now optional, component now also uses ServerConnector

This commit is contained in:
moparisthebest 2023-12-30 22:08:37 -05:00
parent e784b15402
commit 733d005f51
No known key found for this signature in database
GPG key ID: 88C93BFE27BC8229
17 changed files with 440 additions and 337 deletions

View file

@ -14,17 +14,12 @@ edition = "2021"
[dependencies]
bytes = "1"
futures = "0.3"
idna = "0.4"
log = "0.4"
native-tls = { version = "0.2", optional = true }
tokio = { version = "1", features = ["net", "rt", "rt-multi-thread", "macros"] }
tokio-native-tls = { version = "0.3", optional = true }
tokio-rustls = { version = "0.24", optional = true }
tokio-stream = { version = "0.1", features = [] }
tokio-util = { version = "0.7", features = ["codec"] }
hickory-resolver = "0.24"
rxml = "0.9.1"
webpki-roots = { version = "0.25", optional = true }
rxml = "0.9.1"
rand = "^0.8"
syntect = { version = "5", optional = true }
# same repository dependencies
@ -32,11 +27,21 @@ minidom = { version = "0.15", path = "../minidom" }
sasl = { version = "0.5", path = "../sasl" }
xmpp-parsers = { version = "0.20", path = "../parsers" }
# these are only needed for starttls ServerConnector support
hickory-resolver = { version = "0.24", optional = true}
idna = { version = "0.4", optional = true}
native-tls = { version = "0.2", optional = true }
tokio-native-tls = { version = "0.3", optional = true }
tokio-rustls = { version = "0.24", optional = true }
[dev-dependencies]
env_logger = { version = "0.10", default-features = false, features = ["auto-color", "humantime"] }
[features]
default = ["tls-native"]
default = ["starttls-rust"]
starttls = ["hickory-resolver", "idna"]
tls-rust = ["tokio-rustls", "webpki-roots"]
tls-native = ["tokio-native-tls", "native-tls"]
starttls-native = ["starttls", "tls-native"]
starttls-rust = ["starttls", "tls-rust"]
syntax-highlighting = ["syntect"]

View file

@ -1,23 +1,16 @@
use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream};
use sasl::common::ChannelBinding;
use std::mem::replace;
use std::pin::Pin;
use std::task::Context;
use tokio::net::TcpStream;
use tokio::task::JoinHandle;
use xmpp_parsers::{ns, Element, Jid};
use super::connect::{AsyncReadAndWrite, ServerConnector};
use super::connect::client_login;
use crate::connect::{AsyncReadAndWrite, ServerConnector};
use crate::event::Event;
use crate::happy_eyeballs::{connect_to_host, connect_with_srv};
use crate::starttls::starttls;
use crate::xmpp_codec::Packet;
use crate::xmpp_stream::{self, add_stanza_id, XMPPStream};
use crate::{client_login, Error, ProtocolError};
#[cfg(feature = "tls-native")]
use tokio_native_tls::TlsStream;
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
use tokio_rustls::client::TlsStream;
use crate::xmpp_stream::{add_stanza_id, XMPPStream};
use crate::{Error, ProtocolError};
/// XMPP client connection and state
///
@ -43,76 +36,6 @@ pub struct Config<C> {
pub server: C,
}
/// XMPP server connection configuration
#[derive(Clone, Debug)]
pub enum ServerConfig {
/// Use SRV record to find server host
UseSrv,
#[allow(unused)]
/// Manually define server host and port
Manual {
/// Server host name
host: String,
/// Server port
port: u16,
},
}
impl ServerConnector for ServerConfig {
type Stream = TlsStream<TcpStream>;
async fn connect(&self, jid: &Jid) -> Result<XMPPStream<Self::Stream>, Error> {
// TCP connection
let tcp_stream = match self {
ServerConfig::UseSrv => {
connect_with_srv(jid.domain_str(), "_xmpp-client._tcp", 5222).await?
}
ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?,
};
// Unencryped XMPPStream
let xmpp_stream =
xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
.await?;
if xmpp_stream.stream_features.can_starttls() {
// TlsStream
let tls_stream = starttls(xmpp_stream).await?;
// Encrypted XMPPStream
xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
.await
} else {
return Err(Error::Protocol(ProtocolError::NoTls));
}
}
fn channel_binding(
#[allow(unused_variables)] stream: &Self::Stream,
) -> Result<sasl::common::ChannelBinding, Error> {
#[cfg(feature = "tls-native")]
{
log::warn!("tls-native doesnt support channel binding, please use tls-rust if you want this feature!");
Ok(ChannelBinding::None)
}
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
{
let (_, connection) = stream.get_ref();
Ok(match connection.protocol_version() {
// TODO: Add support for TLS 1.2 and earlier.
Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
let data = vec![0u8; 32];
let data = connection.export_keying_material(
data,
b"EXPORTER-Channel-Binding",
None,
)?;
ChannelBinding::TlsExporter(data)
}
_ => ChannelBinding::None,
})
}
}
}
enum ClientState<S: AsyncReadAndWrite> {
Invalid,
Disconnected,
@ -120,21 +43,6 @@ enum ClientState<S: AsyncReadAndWrite> {
Connected(XMPPStream<S>),
}
impl Client<ServerConfig> {
/// Start a new XMPP client
///
/// Start polling the returned instance so that it will connect
/// and yield events.
pub fn new<J: Into<Jid>, P: Into<String>>(jid: J, password: P) -> Self {
let config = Config {
jid: jid.into(),
password: password.into(),
server: ServerConfig::UseSrv,
};
Self::new_with_config(config)
}
}
impl<C: ServerConnector> Client<C> {
/// Start a new client given that the JID is already parsed.
pub fn new_with_config(config: Config<C>) -> Self {

View file

@ -1,32 +1,11 @@
use sasl::common::{ChannelBinding, Credentials};
use tokio::io::{AsyncRead, AsyncWrite};
use sasl::common::Credentials;
use xmpp_parsers::{ns, Jid};
use super::{auth::auth, bind::bind};
use crate::client::auth::auth;
use crate::client::bind::bind;
use crate::connect::ServerConnector;
use crate::{xmpp_stream::XMPPStream, Error};
/// trait returned wrapped in XMPPStream by ServerConnector
pub trait AsyncReadAndWrite: AsyncRead + AsyncWrite + Unpin + Send {}
impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncReadAndWrite for T {}
/// Trait called to connect to an XMPP server, perhaps called multiple times
pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
/// The type of Stream this ServerConnector produces
type Stream: AsyncReadAndWrite;
/// This must return the connection ready to login, ie if starttls is involved, after TLS has been started, and then after the <stream headers are exchanged
fn connect(
&self,
jid: &Jid,
) -> impl std::future::Future<Output = Result<XMPPStream<Self::Stream>, Error>> + Send;
/// Return channel binding data if available
/// do not fail if channel binding is simply unavailable, just return Ok(None)
/// this should only be called after the TLS handshake is finished
fn channel_binding(_stream: &Self::Stream) -> Result<ChannelBinding, Error> {
Ok(ChannelBinding::None)
}
}
/// Log into an XMPP server as a client with a jid+pass
/// does channel binding if supported
pub async fn client_login<C: ServerConnector>(
@ -37,7 +16,7 @@ pub async fn client_login<C: ServerConnector>(
let username = jid.node_str().unwrap();
let password = password;
let xmpp_stream = server.connect(&jid).await?;
let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT).await?;
let channel_binding = C::channel_binding(xmpp_stream.stream.get_ref())?;

View file

@ -1,6 +1,7 @@
mod auth;
mod bind;
pub(crate) mod connect;
pub mod async_client;
pub mod connect;
pub mod simple_client;

View file

@ -1,13 +1,15 @@
use futures::{sink::SinkExt, Sink, Stream};
use std::pin::Pin;
use std::str::FromStr;
use std::task::{Context, Poll};
use tokio_stream::StreamExt;
use xmpp_parsers::{ns, Element, Jid};
use crate::connect::ServerConnector;
use crate::xmpp_codec::Packet;
use crate::xmpp_stream::{add_stanza_id, XMPPStream};
use crate::{client_login, AsyncServerConfig, Error, ServerConnector};
use crate::Error;
use super::connect::client_login;
/// A simple XMPP client connection
///
@ -17,19 +19,6 @@ pub struct Client<C: ServerConnector> {
stream: XMPPStream<C::Stream>,
}
impl Client<AsyncServerConfig> {
/// Start a new XMPP client and wait for a usable session
pub async fn new<P: Into<String>>(jid: &str, password: P) -> Result<Self, Error> {
let jid = Jid::from_str(jid)?;
Self::new_with_jid(jid, password.into()).await
}
/// Start a new client given that the JID is already parsed.
pub async fn new_with_jid(jid: Jid, password: String) -> Result<Self, Error> {
Self::new_with_jid_connector(AsyncServerConfig::UseSrv, jid, password).await
}
}
impl<C: ServerConnector> Client<C> {
/// Start a new client given that the JID is already parsed.
pub async fn new_with_jid_connector(

View file

@ -0,0 +1,18 @@
use xmpp_parsers::{ns, Jid};
use crate::connect::ServerConnector;
use crate::{xmpp_stream::XMPPStream, Error};
use super::auth::auth;
/// Log into an XMPP server as a client with a jid+pass
pub async fn component_login<C: ServerConnector>(
connector: C,
jid: Jid,
password: String,
) -> Result<XMPPStream<C::Stream>, Error> {
let password = password;
let mut xmpp_stream = connector.connect(&jid, ns::COMPONENT).await?;
auth(&mut xmpp_stream, password).await?;
Ok(xmpp_stream)
}

View file

@ -5,53 +5,39 @@ use futures::{sink::SinkExt, task::Poll, Sink, Stream};
use std::pin::Pin;
use std::str::FromStr;
use std::task::Context;
use tokio::net::TcpStream;
use xmpp_parsers::{ns, Element, Jid};
use super::happy_eyeballs::connect_to_host;
use self::connect::component_login;
use super::xmpp_codec::Packet;
use super::xmpp_stream;
use super::Error;
use crate::connect::ServerConnector;
use crate::xmpp_stream::add_stanza_id;
use crate::xmpp_stream::XMPPStream;
mod auth;
pub(crate) mod connect;
/// Component connection to an XMPP server
///
/// This simplifies the `XMPPStream` to a `Stream`/`Sink` of `Element`
/// (stanzas). Connection handling however is up to the user.
pub struct Component {
pub struct Component<C: ServerConnector> {
/// The component's Jabber-Id
pub jid: Jid,
stream: XMPPStream,
stream: XMPPStream<C::Stream>,
}
type XMPPStream = xmpp_stream::XMPPStream<TcpStream>;
impl Component {
impl<C: ServerConnector> Component<C> {
/// Start a new XMPP component
pub async fn new(jid: &str, password: &str, server: &str, port: u16) -> Result<Self, Error> {
pub async fn new(jid: &str, password: &str, connector: C) -> Result<Self, Error> {
let jid = Jid::from_str(jid)?;
let password = password.to_owned();
let stream = Self::connect(jid.clone(), password, server, port).await?;
let stream = component_login(connector, jid.clone(), password).await?;
Ok(Component { jid, stream })
}
async fn connect(
jid: Jid,
password: String,
server: &str,
port: u16,
) -> Result<XMPPStream, Error> {
let password = password;
let tcp_stream = connect_to_host(server, port).await?;
let mut xmpp_stream =
xmpp_stream::XMPPStream::start(tcp_stream, jid, ns::COMPONENT_ACCEPT.to_owned())
.await?;
auth::auth(&mut xmpp_stream, password).await?;
Ok(xmpp_stream)
}
/// Send stanza
pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> {
self.send(add_stanza_id(stanza, ns::COMPONENT_ACCEPT)).await
@ -63,7 +49,7 @@ impl Component {
}
}
impl Stream for Component {
impl<C: ServerConnector> Stream for Component<C> {
type Item = Element;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
@ -86,7 +72,7 @@ impl Stream for Component {
}
}
impl Sink<Element> for Component {
impl<C: ServerConnector> Sink<Element> for Component<C> {
type Error = Error;
fn start_send(mut self: Pin<&mut Self>, item: Element) -> Result<(), Self::Error> {

35
tokio-xmpp/src/connect.rs Normal file
View file

@ -0,0 +1,35 @@
//! `ServerConnector` provides streams for XMPP clients
use sasl::common::ChannelBinding;
use tokio::io::{AsyncRead, AsyncWrite};
use xmpp_parsers::Jid;
use crate::xmpp_stream::XMPPStream;
/// trait returned wrapped in XMPPStream by ServerConnector
pub trait AsyncReadAndWrite: AsyncRead + AsyncWrite + Unpin + Send {}
impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncReadAndWrite for T {}
/// Trait that must be extended by the implementation of ServerConnector
pub trait ServerConnectorError: std::error::Error + Send {}
/// Trait called to connect to an XMPP server, perhaps called multiple times
pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
/// The type of Stream this ServerConnector produces
type Stream: AsyncReadAndWrite;
/// Error type to return
type Error: ServerConnectorError;
/// This must return the connection ready to login, ie if starttls is involved, after TLS has been started, and then after the <stream headers are exchanged
fn connect(
&self,
jid: &Jid,
ns: &str,
) -> impl std::future::Future<Output = Result<XMPPStream<Self::Stream>, Self::Error>> + Send;
/// Return channel binding data if available
/// do not fail if channel binding is simply unavailable, just return Ok(None)
/// this should only be called after the TLS handshake is finished
fn channel_binding(_stream: &Self::Stream) -> Result<ChannelBinding, Self::Error> {
Ok(ChannelBinding::None)
}
}

View file

@ -1,41 +1,26 @@
use hickory_resolver::{error::ResolveError, proto::error::ProtoError};
#[cfg(feature = "tls-native")]
use native_tls::Error as TlsError;
use sasl::client::MechanismError as SaslMechanismError;
use std::borrow::Cow;
use std::error::Error as StdError;
use std::fmt;
use std::io::Error as IoError;
use std::str::Utf8Error;
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
use tokio_rustls::rustls::client::InvalidDnsNameError;
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
use tokio_rustls::rustls::Error as TlsError;
use xmpp_parsers::sasl::DefinedCondition as SaslDefinedCondition;
use xmpp_parsers::{Error as ParsersError, JidParseError};
use crate::connect::ServerConnectorError;
/// Top-level error type
#[derive(Debug)]
pub enum Error {
/// I/O error
Io(IoError),
/// Error resolving DNS and establishing a connection
Connection(ConnecterError),
/// DNS label conversion error, no details available from module
/// `idna`
Idna,
/// Error parsing Jabber-Id
JidParse(JidParseError),
/// Protocol-level error
Protocol(ProtocolError),
/// Authentication error
Auth(AuthError),
/// TLS error
Tls(TlsError),
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
/// DNS name parsing error
DnsNameError(InvalidDnsNameError),
/// Connection closed
Disconnected,
/// Shoud never happen
@ -44,6 +29,8 @@ pub enum Error {
Fmt(fmt::Error),
/// Utf8 error
Utf8(Utf8Error),
/// Error resolving DNS and/or establishing a connection, returned by a ServerConnector impl
Connection(Box<dyn ServerConnectorError>),
}
impl fmt::Display for Error {
@ -51,13 +38,9 @@ impl fmt::Display for Error {
match self {
Error::Io(e) => write!(fmt, "IO error: {}", e),
Error::Connection(e) => write!(fmt, "connection error: {}", e),
Error::Idna => write!(fmt, "IDNA error"),
Error::JidParse(e) => write!(fmt, "jid parse error: {}", e),
Error::Protocol(e) => write!(fmt, "protocol error: {}", e),
Error::Auth(e) => write!(fmt, "authentication error: {}", e),
Error::Tls(e) => write!(fmt, "TLS error: {}", e),
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
Error::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
Error::Disconnected => write!(fmt, "disconnected"),
Error::InvalidState => write!(fmt, "invalid state"),
Error::Fmt(e) => write!(fmt, "Fmt error: {}", e),
@ -74,9 +57,9 @@ impl From<IoError> for Error {
}
}
impl From<ConnecterError> for Error {
fn from(e: ConnecterError) -> Self {
Error::Connection(e)
impl<T: ServerConnectorError + 'static> From<T> for Error {
fn from(e: T) -> Self {
Error::Connection(Box::new(e))
}
}
@ -98,12 +81,6 @@ impl From<AuthError> for Error {
}
}
impl From<TlsError> for Error {
fn from(e: TlsError) -> Self {
Error::Tls(e)
}
}
impl From<fmt::Error> for Error {
fn from(e: fmt::Error) -> Self {
Error::Fmt(e)
@ -116,13 +93,6 @@ impl From<Utf8Error> for Error {
}
}
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
impl From<InvalidDnsNameError> for Error {
fn from(e: InvalidDnsNameError) -> Self {
Error::DnsNameError(e)
}
}
/// XML parse error wrapper type
#[derive(Debug)]
pub struct ParseError(pub Cow<'static, str>);
@ -227,22 +197,3 @@ impl fmt::Display for AuthError {
}
}
}
/// Error establishing connection
#[derive(Debug)]
pub enum ConnecterError {
/// All attempts failed, no error available
AllFailed,
/// DNS protocol error
Dns(ProtoError),
/// DNS resolution error
Resolve(ResolveError),
}
impl StdError for ConnecterError {}
impl std::fmt::Display for ConnecterError {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(fmt, "{:?}", self)
}
}

View file

@ -5,31 +5,35 @@
#[cfg(all(feature = "tls-native", feature = "tls-rust"))]
compile_error!("Both tls-native and tls-rust features can't be enabled at the same time.");
#[cfg(all(not(feature = "tls-native"), not(feature = "tls-rust")))]
compile_error!("One of tls-native and tls-rust features must be enabled.");
#[cfg(all(
feature = "starttls",
not(feature = "tls-native"),
not(feature = "tls-rust")
))]
compile_error!(
"when starttls feature enabled one of tls-native and tls-rust features must be enabled."
);
mod starttls;
#[cfg(feature = "starttls")]
pub mod starttls;
mod stream_start;
mod xmpp_codec;
pub use crate::xmpp_codec::Packet;
mod event;
pub use event::Event;
mod client;
mod happy_eyeballs;
pub mod connect;
pub mod stream_features;
pub mod xmpp_stream;
pub use client::{
async_client::{
Client as AsyncClient, Config as AsyncConfig, ServerConfig as AsyncServerConfig,
},
connect::{client_login, AsyncReadAndWrite, ServerConnector},
async_client::{Client as AsyncClient, Config as AsyncConfig},
simple_client::Client as SimpleClient,
};
mod component;
pub use crate::component::Component;
mod error;
pub use crate::error::{AuthError, ConnecterError, Error, ParseError, ProtocolError};
pub use starttls::starttls;
pub use crate::error::{AuthError, Error, ParseError, ProtocolError};
// Re-exports
pub use minidom::Element;

View file

@ -1,85 +0,0 @@
use futures::{sink::SinkExt, stream::StreamExt};
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
use {
std::sync::Arc,
tokio_rustls::{
client::TlsStream,
rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName},
TlsConnector,
},
webpki_roots,
};
#[cfg(feature = "tls-native")]
use {
native_tls::TlsConnector as NativeTlsConnector,
tokio_native_tls::{TlsConnector, TlsStream},
};
use tokio::io::{AsyncRead, AsyncWrite};
use xmpp_parsers::{ns, Element};
use crate::xmpp_codec::Packet;
use crate::xmpp_stream::XMPPStream;
use crate::{Error, ProtocolError};
#[cfg(feature = "tls-native")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
let domain = xmpp_stream.jid.domain_str().to_owned();
let stream = xmpp_stream.into_inner();
let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
.connect(&domain, stream)
.await?;
Ok(tls_stream)
}
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
let domain = xmpp_stream.jid.domain_str().to_owned();
let domain = ServerName::try_from(domain.as_str())?;
let stream = xmpp_stream.into_inner();
let mut root_store = RootCertStore::empty();
root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let tls_stream = TlsConnector::from(Arc::new(config))
.connect(domain, stream)
.await?;
Ok(tls_stream)
}
/// Performs `<starttls/>` on an XMPPStream and returns a binary
/// TlsStream.
pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
mut xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
let nonza = Element::builder("starttls", ns::TLS).build();
let packet = Packet::Stanza(nonza);
xmpp_stream.send(packet).await?;
loop {
match xmpp_stream.next().await {
Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break,
Some(Ok(Packet::Text(_))) => {}
Some(Err(e)) => return Err(e.into()),
_ => {
return Err(ProtocolError::NoTls.into());
}
}
}
get_tls_stream(xmpp_stream).await
}

View file

@ -0,0 +1,35 @@
use std::str::FromStr;
use xmpp_parsers::Jid;
use crate::{AsyncClient, AsyncConfig, Error, SimpleClient};
use super::ServerConfig;
impl AsyncClient<ServerConfig> {
/// Start a new XMPP client
///
/// Start polling the returned instance so that it will connect
/// and yield events.
pub fn new<J: Into<Jid>, P: Into<String>>(jid: J, password: P) -> Self {
let config = AsyncConfig {
jid: jid.into(),
password: password.into(),
server: ServerConfig::UseSrv,
};
Self::new_with_config(config)
}
}
impl SimpleClient<ServerConfig> {
/// Start a new XMPP client and wait for a usable session
pub async fn new<P: Into<String>>(jid: &str, password: P) -> Result<Self, Error> {
let jid = Jid::from_str(jid)?;
Self::new_with_jid(jid, password.into()).await
}
/// Start a new client given that the JID is already parsed.
pub async fn new_with_jid(jid: Jid, password: String) -> Result<Self, Error> {
Self::new_with_jid_connector(ServerConfig::UseSrv, jid, password).await
}
}

View file

@ -0,0 +1,105 @@
use hickory_resolver::{error::ResolveError, proto::error::ProtoError};
#[cfg(feature = "tls-native")]
use native_tls::Error as TlsError;
use std::borrow::Cow;
use std::error::Error as StdError;
use std::fmt;
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
use tokio_rustls::rustls::client::InvalidDnsNameError;
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
use tokio_rustls::rustls::Error as TlsError;
/// Top-level error type
#[derive(Debug)]
pub enum Error {
/// Error resolving DNS and establishing a connection
Connection(ConnectorError),
/// DNS label conversion error, no details available from module
/// `idna`
Idna,
/// TLS error
Tls(TlsError),
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
/// DNS name parsing error
DnsNameError(InvalidDnsNameError),
/// tokio-xmpp error
TokioXMPP(crate::error::Error),
}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::Connection(e) => write!(fmt, "connection error: {}", e),
Error::Idna => write!(fmt, "IDNA error"),
Error::Tls(e) => write!(fmt, "TLS error: {}", e),
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
Error::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
Error::TokioXMPP(e) => write!(fmt, "TokioXMPP error: {}", e),
}
}
}
impl StdError for Error {}
impl From<crate::error::Error> for Error {
fn from(e: crate::error::Error) -> Self {
Error::TokioXMPP(e)
}
}
impl From<ConnectorError> for Error {
fn from(e: ConnectorError) -> Self {
Error::Connection(e)
}
}
impl From<TlsError> for Error {
fn from(e: TlsError) -> Self {
Error::Tls(e)
}
}
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
impl From<InvalidDnsNameError> for Error {
fn from(e: InvalidDnsNameError) -> Self {
Error::DnsNameError(e)
}
}
/// XML parse error wrapper type
#[derive(Debug)]
pub struct ParseError(pub Cow<'static, str>);
impl StdError for ParseError {
fn description(&self) -> &str {
self.0.as_ref()
}
fn cause(&self) -> Option<&dyn StdError> {
None
}
}
impl fmt::Display for ParseError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// Error establishing connection
#[derive(Debug)]
pub enum ConnectorError {
/// All attempts failed, no error available
AllFailed,
/// DNS protocol error
Dns(ProtoError),
/// DNS resolution error
Resolve(ResolveError),
}
impl StdError for ConnectorError {}
impl std::fmt::Display for ConnectorError {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(fmt, "{:?}", self)
}
}

View file

@ -1,4 +1,4 @@
use crate::{ConnecterError, Error};
use super::error::{ConnectorError, Error};
use hickory_resolver::{IntoName, TokioAsyncResolver};
use idna;
use log::debug;
@ -9,22 +9,24 @@ pub async fn connect_to_host(domain: &str, port: u16) -> Result<TcpStream, Error
let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| Error::Idna)?;
if let Ok(ip) = ascii_domain.parse() {
return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
return Ok(TcpStream::connect(&SocketAddr::new(ip, port))
.await
.map_err(|e| Error::from(crate::Error::Io(e)))?);
}
let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(ConnecterError::Resolve)?;
let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(ConnectorError::Resolve)?;
let ips = resolver
.lookup_ip(ascii_domain)
.await
.map_err(ConnecterError::Resolve)?;
.map_err(ConnectorError::Resolve)?;
for ip in ips.iter() {
match TcpStream::connect(&SocketAddr::new(ip, port)).await {
Ok(stream) => return Ok(stream),
Err(_) => {}
}
}
Err(Error::Disconnected)
Err(crate::Error::Disconnected.into())
}
pub async fn connect_with_srv(
@ -36,14 +38,16 @@ pub async fn connect_with_srv(
if let Ok(ip) = ascii_domain.parse() {
debug!("Attempting connection to {ip}:{fallback_port}");
return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port))
.await
.map_err(|e| Error::from(crate::Error::Io(e)))?);
}
let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(ConnecterError::Resolve)?;
let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(ConnectorError::Resolve)?;
let srv_domain = format!("{}.{}.", srv, ascii_domain)
.into_name()
.map_err(ConnecterError::Dns)?;
.map_err(ConnectorError::Dns)?;
let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
match srv_records {
@ -56,7 +60,7 @@ pub async fn connect_with_srv(
Err(_) => {}
}
}
Err(Error::Disconnected)
Err(crate::Error::Disconnected.into())
}
None => {
// SRV lookup error, retry with hostname

View file

@ -0,0 +1,168 @@
//! `starttls::ServerConfig` provides a `ServerConnector` for starttls connections
use futures::{sink::SinkExt, stream::StreamExt};
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
use {
std::sync::Arc,
tokio_rustls::{
client::TlsStream,
rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName},
TlsConnector,
},
webpki_roots,
};
#[cfg(feature = "tls-native")]
use {
native_tls::TlsConnector as NativeTlsConnector,
tokio_native_tls::{TlsConnector, TlsStream},
};
use sasl::common::ChannelBinding;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
};
use xmpp_parsers::{ns, Element, Jid};
use crate::{connect::ServerConnector, xmpp_codec::Packet};
use crate::{connect::ServerConnectorError, xmpp_stream::XMPPStream};
use self::error::Error;
use self::happy_eyeballs::{connect_to_host, connect_with_srv};
mod client;
mod error;
mod happy_eyeballs;
/// StartTLS XMPP server connection configuration
#[derive(Clone, Debug)]
pub enum ServerConfig {
/// Use SRV record to find server host
UseSrv,
#[allow(unused)]
/// Manually define server host and port
Manual {
/// Server host name
host: String,
/// Server port
port: u16,
},
}
impl ServerConnectorError for Error {}
impl ServerConnector for ServerConfig {
type Stream = TlsStream<TcpStream>;
type Error = Error;
async fn connect(&self, jid: &Jid, ns: &str) -> Result<XMPPStream<Self::Stream>, Error> {
// TCP connection
let tcp_stream = match self {
ServerConfig::UseSrv => {
connect_with_srv(jid.domain_str(), "_xmpp-client._tcp", 5222).await?
}
ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?,
};
// Unencryped XMPPStream
let xmpp_stream = XMPPStream::start(tcp_stream, jid.clone(), ns.to_owned()).await?;
if xmpp_stream.stream_features.can_starttls() {
// TlsStream
let tls_stream = starttls(xmpp_stream).await?;
// Encrypted XMPPStream
Ok(XMPPStream::start(tls_stream, jid.clone(), ns.to_owned()).await?)
} else {
return Err(crate::Error::Protocol(crate::ProtocolError::NoTls).into());
}
}
fn channel_binding(
#[allow(unused_variables)] stream: &Self::Stream,
) -> Result<sasl::common::ChannelBinding, Error> {
#[cfg(feature = "tls-native")]
{
log::warn!("tls-native doesnt support channel binding, please use tls-rust if you want this feature!");
Ok(ChannelBinding::None)
}
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
{
let (_, connection) = stream.get_ref();
Ok(match connection.protocol_version() {
// TODO: Add support for TLS 1.2 and earlier.
Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
let data = vec![0u8; 32];
let data = connection.export_keying_material(
data,
b"EXPORTER-Channel-Binding",
None,
)?;
ChannelBinding::TlsExporter(data)
}
_ => ChannelBinding::None,
})
}
}
}
#[cfg(feature = "tls-native")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
let domain = xmpp_stream.jid.domain_str().to_owned();
let stream = xmpp_stream.into_inner();
let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
.connect(&domain, stream)
.await?;
Ok(tls_stream)
}
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
let domain = xmpp_stream.jid.domain_str().to_owned();
let domain = ServerName::try_from(domain.as_str())?;
let stream = xmpp_stream.into_inner();
let mut root_store = RootCertStore::empty();
root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let tls_stream = TlsConnector::from(Arc::new(config))
.connect(domain, stream)
.await
.map_err(|e| Error::from(crate::Error::Io(e)))?;
Ok(tls_stream)
}
/// Performs `<starttls/>` on an XMPPStream and returns a binary
/// TlsStream.
pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
mut xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
let nonza = Element::builder("starttls", ns::TLS).build();
let packet = Packet::Stanza(nonza);
xmpp_stream.send(packet).await?;
loop {
match xmpp_stream.next().await {
Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break,
Some(Ok(Packet::Text(_))) => {}
Some(Err(e)) => return Err(e.into()),
_ => {
return Err(crate::Error::Protocol(crate::ProtocolError::NoTls).into());
}
}
}
get_tls_stream(xmpp_stream).await
}

View file

@ -31,7 +31,7 @@ name = "hello_bot"
required-features = ["avatars"]
[features]
default = ["avatars", "tls-native"]
tls-native = ["tokio-xmpp/tls-native"]
tls-rust = ["tokio-xmpp/tls-rust"]
default = ["avatars", "starttls-rust"]
starttls-native = ["tokio-xmpp/starttls", "tokio-xmpp/tls-native"]
starttls-rust = ["tokio-xmpp/starttls", "tokio-xmpp/tls-rust"]
avatars = []

View file

@ -7,7 +7,7 @@
#![deny(bare_trait_objects)]
pub use tokio_xmpp::parsers;
use tokio_xmpp::{AsyncClient, AsyncServerConfig};
use tokio_xmpp::AsyncClient;
pub use tokio_xmpp::{BareJid, Element, FullJid, Jid};
#[macro_use]
extern crate log;
@ -32,7 +32,7 @@ pub use builder::{ClientBuilder, ClientType};
pub use event::Event;
pub use feature::ClientFeature;
type TokioXmppClient = AsyncClient<AsyncServerConfig>;
type TokioXmppClient = AsyncClient<tokio_xmpp::starttls::ServerConfig>;
pub type Error = tokio_xmpp::Error;
pub type Id = Option<String>;