tokio-xmpp: Refactor to provide channel-binding for ktls

Instead of having a second method to fetch channel-binding from the
TlsStream, do it directly in the connect() method, since after that we
don’t have enough information to fetch it any longer when using ktls.
This commit is contained in:
Emmanuel Gil Peyrot 2024-12-18 18:17:39 +01:00 committed by xmpp ftw
parent 7991cef904
commit 897c2abe0b
5 changed files with 61 additions and 77 deletions

View file

@ -112,11 +112,9 @@ pub async fn client_auth<C: ServerConnector>(
let username = jid.node().unwrap().as_str(); let username = jid.node().unwrap().as_str();
let password = password; let password = password;
let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?; let (xmpp_stream, channel_binding) = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?;
let (features, xmpp_stream) = xmpp_stream.recv_features().await?; let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
let channel_binding = C::channel_binding(xmpp_stream.get_stream())?;
let creds = Credentials::default() let creds = Credentials::default()
.with_username(username) .with_username(username)
.with_password(password) .with_password(password)

View file

@ -16,7 +16,7 @@ pub async fn component_login<C: ServerConnector>(
timeouts: Timeouts, timeouts: Timeouts,
) -> Result<XmppStream<C::Stream>, Error> { ) -> Result<XmppStream<C::Stream>, Error> {
let password = password; let password = password;
let mut stream = connector.connect(&jid, ns::COMPONENT, timeouts).await?; let (mut stream, _) = connector.connect(&jid, ns::COMPONENT, timeouts).await?;
let header = stream.take_header(); let header = stream.take_header();
let mut stream = stream.skip_features(); let mut stream = stream.skip_features();
let stream_id = match header.id { let stream_id = match header.id {

View file

@ -37,12 +37,7 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
jid: &Jid, jid: &Jid,
ns: &'static str, ns: &'static str,
timeouts: Timeouts, timeouts: Timeouts,
) -> impl std::future::Future<Output = Result<PendingFeaturesRecv<Self::Stream>, Error>> + Send; ) -> impl std::future::Future<
Output = Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error>,
/// Return channel binding data if available > + Send;
/// do not fail if channel binding is simply unavailable, just return Ok(None)
/// this should only be called after the TLS handshake is finished
fn channel_binding(_stream: &Self::Stream) -> Result<ChannelBinding, Error> {
Ok(ChannelBinding::None)
}
} }

View file

@ -82,7 +82,7 @@ impl ServerConnector for StartTlsServerConnector {
jid: &Jid, jid: &Jid,
ns: &'static str, ns: &'static str,
timeouts: Timeouts, timeouts: Timeouts,
) -> Result<PendingFeaturesRecv<Self::Stream>, Error> { ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?); let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
// Unencryped XmppStream // Unencryped XmppStream
@ -101,78 +101,51 @@ impl ServerConnector for StartTlsServerConnector {
if features.can_starttls() { if features.can_starttls() {
// TlsStream // TlsStream
let tls_stream = starttls(xmpp_stream, jid.domain().as_str()).await?; let (tls_stream, channel_binding) =
starttls(xmpp_stream, jid.domain().as_str()).await?;
// Encrypted XmppStream // Encrypted XmppStream
Ok(initiate_stream( Ok((
tokio::io::BufStream::new(tls_stream), initiate_stream(
ns, tokio::io::BufStream::new(tls_stream),
StreamHeader { ns,
to: Some(Cow::Borrowed(jid.domain().as_str())), StreamHeader {
from: None, to: Some(Cow::Borrowed(jid.domain().as_str())),
id: None, from: None,
}, id: None,
timeouts, },
) timeouts,
.await?) )
.await?,
channel_binding,
))
} else { } else {
Err(crate::Error::Protocol(ProtocolError::NoTls).into()) Err(crate::Error::Protocol(ProtocolError::NoTls).into())
} }
} }
fn channel_binding(
#[allow(unused_variables)] stream: &Self::Stream,
) -> Result<sasl::common::ChannelBinding, Error> {
#[cfg(feature = "tls-native")]
{
log::warn!("tls-native doesnt support channel binding, please use tls-rust if you want this feature!");
Ok(ChannelBinding::None)
}
#[cfg(all(feature = "tls-rust-ktls", not(feature = "tls-native")))]
{
log::warn!("Kernel TLS doesnt support channel binding yet, we would have to extract the secrets in the rustls TlsStream before converting it into a KtlsStream.");
Ok(ChannelBinding::None)
}
#[cfg(all(
feature = "tls-rust",
not(feature = "tls-native"),
not(feature = "tls-rust-ktls")
))]
{
let (_, connection) = stream.get_ref().get_ref();
Ok(match connection.protocol_version() {
// TODO: Add support for TLS 1.2 and earlier.
Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
let data = vec![0u8; 32];
let data = connection
.export_keying_material(data, b"EXPORTER-Channel-Binding", None)
.map_err(|e| StartTlsError::Tls(e))?;
ChannelBinding::TlsExporter(data)
}
_ => ChannelBinding::None,
})
}
}
} }
#[cfg(feature = "tls-native")] #[cfg(feature = "tls-native")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>( async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
xmpp_stream: XmppStream<BufStream<S>>, xmpp_stream: XmppStream<BufStream<S>>,
domain: &str, domain: &str,
) -> Result<TlsStream<S>, Error> { ) -> Result<(TlsStream<S>, ChannelBinding), Error> {
let domain = domain.to_owned(); let domain = domain.to_owned();
let stream = xmpp_stream.into_inner().into_inner(); let stream = xmpp_stream.into_inner().into_inner();
let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
.connect(&domain, stream) .connect(&domain, stream)
.await .await
.map_err(|e| StartTlsError::Tls(e))?; .map_err(|e| StartTlsError::Tls(e))?;
Ok(tls_stream) log::warn!(
"tls-native doesnt support channel binding, please use tls-rust if you want this feature!"
);
Ok((tls_stream, ChannelBinding::None))
} }
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>( async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
xmpp_stream: XmppStream<BufStream<S>>, xmpp_stream: XmppStream<BufStream<S>>,
domain: &str, domain: &str,
) -> Result<TlsStream<S>, Error> { ) -> Result<(TlsStream<S>, ChannelBinding), Error> {
let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?; let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?;
let stream = xmpp_stream.into_inner().into_inner(); let stream = xmpp_stream.into_inner().into_inner();
let mut root_store = RootCertStore::empty(); let mut root_store = RootCertStore::empty();
@ -197,11 +170,26 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
.connect(domain, stream) .connect(domain, stream)
.await .await
.map_err(|e| Error::from(crate::Error::Io(e)))?; .map_err(|e| Error::from(crate::Error::Io(e)))?;
// Extract the channel-binding information before we hand the stream over to ktls.
let (_, connection) = tls_stream.get_ref();
let channel_binding = match connection.protocol_version() {
// TODO: Add support for TLS 1.2 and earlier.
Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
let data = vec![0u8; 32];
let data = connection
.export_keying_material(data, b"EXPORTER-Channel-Binding", None)
.map_err(|e| StartTlsError::Tls(e))?;
ChannelBinding::TlsExporter(data)
}
_ => ChannelBinding::None,
};
#[cfg(feature = "tls-rust-ktls")] #[cfg(feature = "tls-rust-ktls")]
let tls_stream = ktls::config_ktls_client(tls_stream) let tls_stream = ktls::config_ktls_client(tls_stream)
.await .await
.map_err(StartTlsError::KtlsError)?; .map_err(StartTlsError::KtlsError)?;
Ok(tls_stream) Ok((tls_stream, channel_binding))
} }
/// Performs `<starttls/>` on an XmppStream and returns a binary /// Performs `<starttls/>` on an XmppStream and returns a binary
@ -209,7 +197,7 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>( pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
mut stream: XmppStream<BufStream<S>>, mut stream: XmppStream<BufStream<S>>,
domain: &str, domain: &str,
) -> Result<TlsStream<S>, Error> { ) -> Result<(TlsStream<S>, ChannelBinding), Error> {
stream stream
.send(&XmppStreamElement::Starttls(starttls::Nonza::Request( .send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
Request, Request,

View file

@ -5,7 +5,7 @@ use std::borrow::Cow;
use tokio::{io::BufStream, net::TcpStream}; use tokio::{io::BufStream, net::TcpStream};
use crate::{ use crate::{
connect::{DnsConfig, ServerConnector}, connect::{ChannelBinding, DnsConfig, ServerConnector},
xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts}, xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts},
Client, Component, Error, Client, Component, Error,
}; };
@ -37,18 +37,21 @@ impl ServerConnector for TcpServerConnector {
jid: &xmpp_parsers::jid::Jid, jid: &xmpp_parsers::jid::Jid,
ns: &'static str, ns: &'static str,
timeouts: Timeouts, timeouts: Timeouts,
) -> Result<PendingFeaturesRecv<Self::Stream>, Error> { ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
let stream = BufStream::new(self.0.resolve().await?); let stream = BufStream::new(self.0.resolve().await?);
Ok(initiate_stream( Ok((
stream, initiate_stream(
ns, stream,
StreamHeader { ns,
to: Some(Cow::Borrowed(jid.domain().as_str())), StreamHeader {
from: None, to: Some(Cow::Borrowed(jid.domain().as_str())),
id: None, from: None,
}, id: None,
timeouts, },
) timeouts,
.await?) )
.await?,
ChannelBinding::None,
))
} }
} }