From 897c2abe0b085a88011aaa756f98bdd48fc8db28 Mon Sep 17 00:00:00 2001 From: Emmanuel Gil Peyrot Date: Wed, 18 Dec 2024 18:17:39 +0100 Subject: [PATCH] tokio-xmpp: Refactor to provide channel-binding for ktls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of having a second method to fetch channel-binding from the TlsStream, do it directly in the connect() method, since after that we don’t have enough information to fetch it any longer when using ktls. --- tokio-xmpp/src/client/login.rs | 4 +- tokio-xmpp/src/component/login.rs | 2 +- tokio-xmpp/src/connect/mod.rs | 11 +--- tokio-xmpp/src/connect/starttls.rs | 92 +++++++++++++----------------- tokio-xmpp/src/connect/tcp.rs | 29 +++++----- 5 files changed, 61 insertions(+), 77 deletions(-) diff --git a/tokio-xmpp/src/client/login.rs b/tokio-xmpp/src/client/login.rs index d232de4d..a1bc87ed 100644 --- a/tokio-xmpp/src/client/login.rs +++ b/tokio-xmpp/src/client/login.rs @@ -112,11 +112,9 @@ pub async fn client_auth( let username = jid.node().unwrap().as_str(); let password = password; - let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?; + let (xmpp_stream, channel_binding) = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?; let (features, xmpp_stream) = xmpp_stream.recv_features().await?; - let channel_binding = C::channel_binding(xmpp_stream.get_stream())?; - let creds = Credentials::default() .with_username(username) .with_password(password) diff --git a/tokio-xmpp/src/component/login.rs b/tokio-xmpp/src/component/login.rs index 33b743e4..be8b3b4d 100644 --- a/tokio-xmpp/src/component/login.rs +++ b/tokio-xmpp/src/component/login.rs @@ -16,7 +16,7 @@ pub async fn component_login( timeouts: Timeouts, ) -> Result, Error> { let password = password; - let mut stream = connector.connect(&jid, ns::COMPONENT, timeouts).await?; + let (mut stream, _) = connector.connect(&jid, ns::COMPONENT, timeouts).await?; let header = stream.take_header(); let mut stream = stream.skip_features(); let stream_id = match header.id { diff --git a/tokio-xmpp/src/connect/mod.rs b/tokio-xmpp/src/connect/mod.rs index ca1ae61f..0fba5b0d 100644 --- a/tokio-xmpp/src/connect/mod.rs +++ b/tokio-xmpp/src/connect/mod.rs @@ -37,12 +37,7 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static { jid: &Jid, ns: &'static str, timeouts: Timeouts, - ) -> impl std::future::Future, Error>> + Send; - - /// Return channel binding data if available - /// do not fail if channel binding is simply unavailable, just return Ok(None) - /// this should only be called after the TLS handshake is finished - fn channel_binding(_stream: &Self::Stream) -> Result { - Ok(ChannelBinding::None) - } + ) -> impl std::future::Future< + Output = Result<(PendingFeaturesRecv, ChannelBinding), Error>, + > + Send; } diff --git a/tokio-xmpp/src/connect/starttls.rs b/tokio-xmpp/src/connect/starttls.rs index f8c33790..067c7e24 100644 --- a/tokio-xmpp/src/connect/starttls.rs +++ b/tokio-xmpp/src/connect/starttls.rs @@ -82,7 +82,7 @@ impl ServerConnector for StartTlsServerConnector { jid: &Jid, ns: &'static str, timeouts: Timeouts, - ) -> Result, Error> { + ) -> Result<(PendingFeaturesRecv, ChannelBinding), Error> { let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?); // Unencryped XmppStream @@ -101,78 +101,51 @@ impl ServerConnector for StartTlsServerConnector { if features.can_starttls() { // TlsStream - let tls_stream = starttls(xmpp_stream, jid.domain().as_str()).await?; + let (tls_stream, channel_binding) = + starttls(xmpp_stream, jid.domain().as_str()).await?; // Encrypted XmppStream - Ok(initiate_stream( - tokio::io::BufStream::new(tls_stream), - ns, - StreamHeader { - to: Some(Cow::Borrowed(jid.domain().as_str())), - from: None, - id: None, - }, - timeouts, - ) - .await?) + Ok(( + initiate_stream( + tokio::io::BufStream::new(tls_stream), + ns, + StreamHeader { + to: Some(Cow::Borrowed(jid.domain().as_str())), + from: None, + id: None, + }, + timeouts, + ) + .await?, + channel_binding, + )) } else { Err(crate::Error::Protocol(ProtocolError::NoTls).into()) } } - - fn channel_binding( - #[allow(unused_variables)] stream: &Self::Stream, - ) -> Result { - #[cfg(feature = "tls-native")] - { - log::warn!("tls-native doesn’t support channel binding, please use tls-rust if you want this feature!"); - Ok(ChannelBinding::None) - } - #[cfg(all(feature = "tls-rust-ktls", not(feature = "tls-native")))] - { - log::warn!("Kernel TLS doesn’t support channel binding yet, we would have to extract the secrets in the rustls TlsStream before converting it into a KtlsStream."); - Ok(ChannelBinding::None) - } - #[cfg(all( - feature = "tls-rust", - not(feature = "tls-native"), - not(feature = "tls-rust-ktls") - ))] - { - let (_, connection) = stream.get_ref().get_ref(); - Ok(match connection.protocol_version() { - // TODO: Add support for TLS 1.2 and earlier. - Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => { - let data = vec![0u8; 32]; - let data = connection - .export_keying_material(data, b"EXPORTER-Channel-Binding", None) - .map_err(|e| StartTlsError::Tls(e))?; - ChannelBinding::TlsExporter(data) - } - _ => ChannelBinding::None, - }) - } - } } #[cfg(feature = "tls-native")] async fn get_tls_stream( xmpp_stream: XmppStream>, domain: &str, -) -> Result, Error> { +) -> Result<(TlsStream, ChannelBinding), Error> { let domain = domain.to_owned(); let stream = xmpp_stream.into_inner().into_inner(); let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) .connect(&domain, stream) .await .map_err(|e| StartTlsError::Tls(e))?; - Ok(tls_stream) + log::warn!( + "tls-native doesn’t support channel binding, please use tls-rust if you want this feature!" + ); + Ok((tls_stream, ChannelBinding::None)) } #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] async fn get_tls_stream( xmpp_stream: XmppStream>, domain: &str, -) -> Result, Error> { +) -> Result<(TlsStream, ChannelBinding), Error> { let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?; let stream = xmpp_stream.into_inner().into_inner(); let mut root_store = RootCertStore::empty(); @@ -197,11 +170,26 @@ async fn get_tls_stream( .connect(domain, stream) .await .map_err(|e| Error::from(crate::Error::Io(e)))?; + + // Extract the channel-binding information before we hand the stream over to ktls. + let (_, connection) = tls_stream.get_ref(); + let channel_binding = match connection.protocol_version() { + // TODO: Add support for TLS 1.2 and earlier. + Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => { + let data = vec![0u8; 32]; + let data = connection + .export_keying_material(data, b"EXPORTER-Channel-Binding", None) + .map_err(|e| StartTlsError::Tls(e))?; + ChannelBinding::TlsExporter(data) + } + _ => ChannelBinding::None, + }; + #[cfg(feature = "tls-rust-ktls")] let tls_stream = ktls::config_ktls_client(tls_stream) .await .map_err(StartTlsError::KtlsError)?; - Ok(tls_stream) + Ok((tls_stream, channel_binding)) } /// Performs `` on an XmppStream and returns a binary @@ -209,7 +197,7 @@ async fn get_tls_stream( pub async fn starttls( mut stream: XmppStream>, domain: &str, -) -> Result, Error> { +) -> Result<(TlsStream, ChannelBinding), Error> { stream .send(&XmppStreamElement::Starttls(starttls::Nonza::Request( Request, diff --git a/tokio-xmpp/src/connect/tcp.rs b/tokio-xmpp/src/connect/tcp.rs index aa7bc2c8..af4722f6 100644 --- a/tokio-xmpp/src/connect/tcp.rs +++ b/tokio-xmpp/src/connect/tcp.rs @@ -5,7 +5,7 @@ use std::borrow::Cow; use tokio::{io::BufStream, net::TcpStream}; use crate::{ - connect::{DnsConfig, ServerConnector}, + connect::{ChannelBinding, DnsConfig, ServerConnector}, xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts}, Client, Component, Error, }; @@ -37,18 +37,21 @@ impl ServerConnector for TcpServerConnector { jid: &xmpp_parsers::jid::Jid, ns: &'static str, timeouts: Timeouts, - ) -> Result, Error> { + ) -> Result<(PendingFeaturesRecv, ChannelBinding), Error> { let stream = BufStream::new(self.0.resolve().await?); - Ok(initiate_stream( - stream, - ns, - StreamHeader { - to: Some(Cow::Borrowed(jid.domain().as_str())), - from: None, - id: None, - }, - timeouts, - ) - .await?) + Ok(( + initiate_stream( + stream, + ns, + StreamHeader { + to: Some(Cow::Borrowed(jid.domain().as_str())), + from: None, + id: None, + }, + timeouts, + ) + .await?, + ChannelBinding::None, + )) } }