tokio-xmpp: Refactor to provide channel-binding for ktls
Instead of having a second method to fetch channel-binding from the TlsStream, do it directly in the connect() method, since after that we don’t have enough information to fetch it any longer when using ktls.
This commit is contained in:
parent
7991cef904
commit
897c2abe0b
5 changed files with 61 additions and 77 deletions
|
@ -112,11 +112,9 @@ pub async fn client_auth<C: ServerConnector>(
|
|||
let username = jid.node().unwrap().as_str();
|
||||
let password = password;
|
||||
|
||||
let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?;
|
||||
let (xmpp_stream, channel_binding) = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?;
|
||||
let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
|
||||
|
||||
let channel_binding = C::channel_binding(xmpp_stream.get_stream())?;
|
||||
|
||||
let creds = Credentials::default()
|
||||
.with_username(username)
|
||||
.with_password(password)
|
||||
|
|
|
@ -16,7 +16,7 @@ pub async fn component_login<C: ServerConnector>(
|
|||
timeouts: Timeouts,
|
||||
) -> Result<XmppStream<C::Stream>, Error> {
|
||||
let password = password;
|
||||
let mut stream = connector.connect(&jid, ns::COMPONENT, timeouts).await?;
|
||||
let (mut stream, _) = connector.connect(&jid, ns::COMPONENT, timeouts).await?;
|
||||
let header = stream.take_header();
|
||||
let mut stream = stream.skip_features();
|
||||
let stream_id = match header.id {
|
||||
|
|
|
@ -37,12 +37,7 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
|
|||
jid: &Jid,
|
||||
ns: &'static str,
|
||||
timeouts: Timeouts,
|
||||
) -> impl std::future::Future<Output = Result<PendingFeaturesRecv<Self::Stream>, Error>> + Send;
|
||||
|
||||
/// Return channel binding data if available
|
||||
/// do not fail if channel binding is simply unavailable, just return Ok(None)
|
||||
/// this should only be called after the TLS handshake is finished
|
||||
fn channel_binding(_stream: &Self::Stream) -> Result<ChannelBinding, Error> {
|
||||
Ok(ChannelBinding::None)
|
||||
}
|
||||
) -> impl std::future::Future<
|
||||
Output = Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error>,
|
||||
> + Send;
|
||||
}
|
||||
|
|
|
@ -82,7 +82,7 @@ impl ServerConnector for StartTlsServerConnector {
|
|||
jid: &Jid,
|
||||
ns: &'static str,
|
||||
timeouts: Timeouts,
|
||||
) -> Result<PendingFeaturesRecv<Self::Stream>, Error> {
|
||||
) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
|
||||
let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
|
||||
|
||||
// Unencryped XmppStream
|
||||
|
@ -101,9 +101,11 @@ impl ServerConnector for StartTlsServerConnector {
|
|||
|
||||
if features.can_starttls() {
|
||||
// TlsStream
|
||||
let tls_stream = starttls(xmpp_stream, jid.domain().as_str()).await?;
|
||||
let (tls_stream, channel_binding) =
|
||||
starttls(xmpp_stream, jid.domain().as_str()).await?;
|
||||
// Encrypted XmppStream
|
||||
Ok(initiate_stream(
|
||||
Ok((
|
||||
initiate_stream(
|
||||
tokio::io::BufStream::new(tls_stream),
|
||||
ns,
|
||||
StreamHeader {
|
||||
|
@ -113,66 +115,37 @@ impl ServerConnector for StartTlsServerConnector {
|
|||
},
|
||||
timeouts,
|
||||
)
|
||||
.await?)
|
||||
.await?,
|
||||
channel_binding,
|
||||
))
|
||||
} else {
|
||||
Err(crate::Error::Protocol(ProtocolError::NoTls).into())
|
||||
}
|
||||
}
|
||||
|
||||
fn channel_binding(
|
||||
#[allow(unused_variables)] stream: &Self::Stream,
|
||||
) -> Result<sasl::common::ChannelBinding, Error> {
|
||||
#[cfg(feature = "tls-native")]
|
||||
{
|
||||
log::warn!("tls-native doesn’t support channel binding, please use tls-rust if you want this feature!");
|
||||
Ok(ChannelBinding::None)
|
||||
}
|
||||
#[cfg(all(feature = "tls-rust-ktls", not(feature = "tls-native")))]
|
||||
{
|
||||
log::warn!("Kernel TLS doesn’t support channel binding yet, we would have to extract the secrets in the rustls TlsStream before converting it into a KtlsStream.");
|
||||
Ok(ChannelBinding::None)
|
||||
}
|
||||
#[cfg(all(
|
||||
feature = "tls-rust",
|
||||
not(feature = "tls-native"),
|
||||
not(feature = "tls-rust-ktls")
|
||||
))]
|
||||
{
|
||||
let (_, connection) = stream.get_ref().get_ref();
|
||||
Ok(match connection.protocol_version() {
|
||||
// TODO: Add support for TLS 1.2 and earlier.
|
||||
Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
|
||||
let data = vec![0u8; 32];
|
||||
let data = connection
|
||||
.export_keying_material(data, b"EXPORTER-Channel-Binding", None)
|
||||
.map_err(|e| StartTlsError::Tls(e))?;
|
||||
ChannelBinding::TlsExporter(data)
|
||||
}
|
||||
_ => ChannelBinding::None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls-native")]
|
||||
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
xmpp_stream: XmppStream<BufStream<S>>,
|
||||
domain: &str,
|
||||
) -> Result<TlsStream<S>, Error> {
|
||||
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
|
||||
let domain = domain.to_owned();
|
||||
let stream = xmpp_stream.into_inner().into_inner();
|
||||
let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
|
||||
.connect(&domain, stream)
|
||||
.await
|
||||
.map_err(|e| StartTlsError::Tls(e))?;
|
||||
Ok(tls_stream)
|
||||
log::warn!(
|
||||
"tls-native doesn’t support channel binding, please use tls-rust if you want this feature!"
|
||||
);
|
||||
Ok((tls_stream, ChannelBinding::None))
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
|
||||
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
|
||||
xmpp_stream: XmppStream<BufStream<S>>,
|
||||
domain: &str,
|
||||
) -> Result<TlsStream<S>, Error> {
|
||||
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
|
||||
let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?;
|
||||
let stream = xmpp_stream.into_inner().into_inner();
|
||||
let mut root_store = RootCertStore::empty();
|
||||
|
@ -197,11 +170,26 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
|
|||
.connect(domain, stream)
|
||||
.await
|
||||
.map_err(|e| Error::from(crate::Error::Io(e)))?;
|
||||
|
||||
// Extract the channel-binding information before we hand the stream over to ktls.
|
||||
let (_, connection) = tls_stream.get_ref();
|
||||
let channel_binding = match connection.protocol_version() {
|
||||
// TODO: Add support for TLS 1.2 and earlier.
|
||||
Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
|
||||
let data = vec![0u8; 32];
|
||||
let data = connection
|
||||
.export_keying_material(data, b"EXPORTER-Channel-Binding", None)
|
||||
.map_err(|e| StartTlsError::Tls(e))?;
|
||||
ChannelBinding::TlsExporter(data)
|
||||
}
|
||||
_ => ChannelBinding::None,
|
||||
};
|
||||
|
||||
#[cfg(feature = "tls-rust-ktls")]
|
||||
let tls_stream = ktls::config_ktls_client(tls_stream)
|
||||
.await
|
||||
.map_err(StartTlsError::KtlsError)?;
|
||||
Ok(tls_stream)
|
||||
Ok((tls_stream, channel_binding))
|
||||
}
|
||||
|
||||
/// Performs `<starttls/>` on an XmppStream and returns a binary
|
||||
|
@ -209,7 +197,7 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
|
|||
pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
|
||||
mut stream: XmppStream<BufStream<S>>,
|
||||
domain: &str,
|
||||
) -> Result<TlsStream<S>, Error> {
|
||||
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
|
||||
stream
|
||||
.send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
|
||||
Request,
|
||||
|
|
|
@ -5,7 +5,7 @@ use std::borrow::Cow;
|
|||
use tokio::{io::BufStream, net::TcpStream};
|
||||
|
||||
use crate::{
|
||||
connect::{DnsConfig, ServerConnector},
|
||||
connect::{ChannelBinding, DnsConfig, ServerConnector},
|
||||
xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts},
|
||||
Client, Component, Error,
|
||||
};
|
||||
|
@ -37,9 +37,10 @@ impl ServerConnector for TcpServerConnector {
|
|||
jid: &xmpp_parsers::jid::Jid,
|
||||
ns: &'static str,
|
||||
timeouts: Timeouts,
|
||||
) -> Result<PendingFeaturesRecv<Self::Stream>, Error> {
|
||||
) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
|
||||
let stream = BufStream::new(self.0.resolve().await?);
|
||||
Ok(initiate_stream(
|
||||
Ok((
|
||||
initiate_stream(
|
||||
stream,
|
||||
ns,
|
||||
StreamHeader {
|
||||
|
@ -49,6 +50,8 @@ impl ServerConnector for TcpServerConnector {
|
|||
},
|
||||
timeouts,
|
||||
)
|
||||
.await?)
|
||||
.await?,
|
||||
ChannelBinding::None,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue