tokio-xmpp: Refactor to provide channel-binding for ktls

Instead of having a second method to fetch channel-binding from the
TlsStream, do it directly in the connect() method, since after that we
don’t have enough information to fetch it any longer when using ktls.
This commit is contained in:
Emmanuel Gil Peyrot 2024-12-18 18:17:39 +01:00 committed by xmpp ftw
parent 7991cef904
commit 897c2abe0b
5 changed files with 61 additions and 77 deletions

View file

@ -112,11 +112,9 @@ pub async fn client_auth<C: ServerConnector>(
let username = jid.node().unwrap().as_str();
let password = password;
let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?;
let (xmpp_stream, channel_binding) = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?;
let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
let channel_binding = C::channel_binding(xmpp_stream.get_stream())?;
let creds = Credentials::default()
.with_username(username)
.with_password(password)

View file

@ -16,7 +16,7 @@ pub async fn component_login<C: ServerConnector>(
timeouts: Timeouts,
) -> Result<XmppStream<C::Stream>, Error> {
let password = password;
let mut stream = connector.connect(&jid, ns::COMPONENT, timeouts).await?;
let (mut stream, _) = connector.connect(&jid, ns::COMPONENT, timeouts).await?;
let header = stream.take_header();
let mut stream = stream.skip_features();
let stream_id = match header.id {

View file

@ -37,12 +37,7 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
jid: &Jid,
ns: &'static str,
timeouts: Timeouts,
) -> impl std::future::Future<Output = Result<PendingFeaturesRecv<Self::Stream>, Error>> + Send;
/// Return channel binding data if available
/// do not fail if channel binding is simply unavailable, just return Ok(None)
/// this should only be called after the TLS handshake is finished
fn channel_binding(_stream: &Self::Stream) -> Result<ChannelBinding, Error> {
Ok(ChannelBinding::None)
}
) -> impl std::future::Future<
Output = Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error>,
> + Send;
}

View file

@ -82,7 +82,7 @@ impl ServerConnector for StartTlsServerConnector {
jid: &Jid,
ns: &'static str,
timeouts: Timeouts,
) -> Result<PendingFeaturesRecv<Self::Stream>, Error> {
) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
// Unencryped XmppStream
@ -101,9 +101,11 @@ impl ServerConnector for StartTlsServerConnector {
if features.can_starttls() {
// TlsStream
let tls_stream = starttls(xmpp_stream, jid.domain().as_str()).await?;
let (tls_stream, channel_binding) =
starttls(xmpp_stream, jid.domain().as_str()).await?;
// Encrypted XmppStream
Ok(initiate_stream(
Ok((
initiate_stream(
tokio::io::BufStream::new(tls_stream),
ns,
StreamHeader {
@ -113,66 +115,37 @@ impl ServerConnector for StartTlsServerConnector {
},
timeouts,
)
.await?)
.await?,
channel_binding,
))
} else {
Err(crate::Error::Protocol(ProtocolError::NoTls).into())
}
}
fn channel_binding(
#[allow(unused_variables)] stream: &Self::Stream,
) -> Result<sasl::common::ChannelBinding, Error> {
#[cfg(feature = "tls-native")]
{
log::warn!("tls-native doesnt support channel binding, please use tls-rust if you want this feature!");
Ok(ChannelBinding::None)
}
#[cfg(all(feature = "tls-rust-ktls", not(feature = "tls-native")))]
{
log::warn!("Kernel TLS doesnt support channel binding yet, we would have to extract the secrets in the rustls TlsStream before converting it into a KtlsStream.");
Ok(ChannelBinding::None)
}
#[cfg(all(
feature = "tls-rust",
not(feature = "tls-native"),
not(feature = "tls-rust-ktls")
))]
{
let (_, connection) = stream.get_ref().get_ref();
Ok(match connection.protocol_version() {
// TODO: Add support for TLS 1.2 and earlier.
Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
let data = vec![0u8; 32];
let data = connection
.export_keying_material(data, b"EXPORTER-Channel-Binding", None)
.map_err(|e| StartTlsError::Tls(e))?;
ChannelBinding::TlsExporter(data)
}
_ => ChannelBinding::None,
})
}
}
}
#[cfg(feature = "tls-native")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
xmpp_stream: XmppStream<BufStream<S>>,
domain: &str,
) -> Result<TlsStream<S>, Error> {
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
let domain = domain.to_owned();
let stream = xmpp_stream.into_inner().into_inner();
let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
.connect(&domain, stream)
.await
.map_err(|e| StartTlsError::Tls(e))?;
Ok(tls_stream)
log::warn!(
"tls-native doesnt support channel binding, please use tls-rust if you want this feature!"
);
Ok((tls_stream, ChannelBinding::None))
}
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
xmpp_stream: XmppStream<BufStream<S>>,
domain: &str,
) -> Result<TlsStream<S>, Error> {
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?;
let stream = xmpp_stream.into_inner().into_inner();
let mut root_store = RootCertStore::empty();
@ -197,11 +170,26 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
.connect(domain, stream)
.await
.map_err(|e| Error::from(crate::Error::Io(e)))?;
// Extract the channel-binding information before we hand the stream over to ktls.
let (_, connection) = tls_stream.get_ref();
let channel_binding = match connection.protocol_version() {
// TODO: Add support for TLS 1.2 and earlier.
Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
let data = vec![0u8; 32];
let data = connection
.export_keying_material(data, b"EXPORTER-Channel-Binding", None)
.map_err(|e| StartTlsError::Tls(e))?;
ChannelBinding::TlsExporter(data)
}
_ => ChannelBinding::None,
};
#[cfg(feature = "tls-rust-ktls")]
let tls_stream = ktls::config_ktls_client(tls_stream)
.await
.map_err(StartTlsError::KtlsError)?;
Ok(tls_stream)
Ok((tls_stream, channel_binding))
}
/// Performs `<starttls/>` on an XmppStream and returns a binary
@ -209,7 +197,7 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
mut stream: XmppStream<BufStream<S>>,
domain: &str,
) -> Result<TlsStream<S>, Error> {
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
stream
.send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
Request,

View file

@ -5,7 +5,7 @@ use std::borrow::Cow;
use tokio::{io::BufStream, net::TcpStream};
use crate::{
connect::{DnsConfig, ServerConnector},
connect::{ChannelBinding, DnsConfig, ServerConnector},
xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts},
Client, Component, Error,
};
@ -37,9 +37,10 @@ impl ServerConnector for TcpServerConnector {
jid: &xmpp_parsers::jid::Jid,
ns: &'static str,
timeouts: Timeouts,
) -> Result<PendingFeaturesRecv<Self::Stream>, Error> {
) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
let stream = BufStream::new(self.0.resolve().await?);
Ok(initiate_stream(
Ok((
initiate_stream(
stream,
ns,
StreamHeader {
@ -49,6 +50,8 @@ impl ServerConnector for TcpServerConnector {
},
timeouts,
)
.await?)
.await?,
ChannelBinding::None,
))
}
}