tokio-xmpp: Refactor to provide channel-binding for ktls
Instead of having a second method to fetch channel-binding from the TlsStream, do it directly in the connect() method, since after that we don’t have enough information to fetch it any longer when using ktls.
This commit is contained in:
parent
7991cef904
commit
897c2abe0b
5 changed files with 61 additions and 77 deletions
|
@ -112,11 +112,9 @@ pub async fn client_auth<C: ServerConnector>(
|
||||||
let username = jid.node().unwrap().as_str();
|
let username = jid.node().unwrap().as_str();
|
||||||
let password = password;
|
let password = password;
|
||||||
|
|
||||||
let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?;
|
let (xmpp_stream, channel_binding) = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?;
|
||||||
let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
|
let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
|
||||||
|
|
||||||
let channel_binding = C::channel_binding(xmpp_stream.get_stream())?;
|
|
||||||
|
|
||||||
let creds = Credentials::default()
|
let creds = Credentials::default()
|
||||||
.with_username(username)
|
.with_username(username)
|
||||||
.with_password(password)
|
.with_password(password)
|
||||||
|
|
|
@ -16,7 +16,7 @@ pub async fn component_login<C: ServerConnector>(
|
||||||
timeouts: Timeouts,
|
timeouts: Timeouts,
|
||||||
) -> Result<XmppStream<C::Stream>, Error> {
|
) -> Result<XmppStream<C::Stream>, Error> {
|
||||||
let password = password;
|
let password = password;
|
||||||
let mut stream = connector.connect(&jid, ns::COMPONENT, timeouts).await?;
|
let (mut stream, _) = connector.connect(&jid, ns::COMPONENT, timeouts).await?;
|
||||||
let header = stream.take_header();
|
let header = stream.take_header();
|
||||||
let mut stream = stream.skip_features();
|
let mut stream = stream.skip_features();
|
||||||
let stream_id = match header.id {
|
let stream_id = match header.id {
|
||||||
|
|
|
@ -37,12 +37,7 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
|
||||||
jid: &Jid,
|
jid: &Jid,
|
||||||
ns: &'static str,
|
ns: &'static str,
|
||||||
timeouts: Timeouts,
|
timeouts: Timeouts,
|
||||||
) -> impl std::future::Future<Output = Result<PendingFeaturesRecv<Self::Stream>, Error>> + Send;
|
) -> impl std::future::Future<
|
||||||
|
Output = Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error>,
|
||||||
/// Return channel binding data if available
|
> + Send;
|
||||||
/// do not fail if channel binding is simply unavailable, just return Ok(None)
|
|
||||||
/// this should only be called after the TLS handshake is finished
|
|
||||||
fn channel_binding(_stream: &Self::Stream) -> Result<ChannelBinding, Error> {
|
|
||||||
Ok(ChannelBinding::None)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -82,7 +82,7 @@ impl ServerConnector for StartTlsServerConnector {
|
||||||
jid: &Jid,
|
jid: &Jid,
|
||||||
ns: &'static str,
|
ns: &'static str,
|
||||||
timeouts: Timeouts,
|
timeouts: Timeouts,
|
||||||
) -> Result<PendingFeaturesRecv<Self::Stream>, Error> {
|
) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
|
||||||
let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
|
let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
|
||||||
|
|
||||||
// Unencryped XmppStream
|
// Unencryped XmppStream
|
||||||
|
@ -101,78 +101,51 @@ impl ServerConnector for StartTlsServerConnector {
|
||||||
|
|
||||||
if features.can_starttls() {
|
if features.can_starttls() {
|
||||||
// TlsStream
|
// TlsStream
|
||||||
let tls_stream = starttls(xmpp_stream, jid.domain().as_str()).await?;
|
let (tls_stream, channel_binding) =
|
||||||
|
starttls(xmpp_stream, jid.domain().as_str()).await?;
|
||||||
// Encrypted XmppStream
|
// Encrypted XmppStream
|
||||||
Ok(initiate_stream(
|
Ok((
|
||||||
tokio::io::BufStream::new(tls_stream),
|
initiate_stream(
|
||||||
ns,
|
tokio::io::BufStream::new(tls_stream),
|
||||||
StreamHeader {
|
ns,
|
||||||
to: Some(Cow::Borrowed(jid.domain().as_str())),
|
StreamHeader {
|
||||||
from: None,
|
to: Some(Cow::Borrowed(jid.domain().as_str())),
|
||||||
id: None,
|
from: None,
|
||||||
},
|
id: None,
|
||||||
timeouts,
|
},
|
||||||
)
|
timeouts,
|
||||||
.await?)
|
)
|
||||||
|
.await?,
|
||||||
|
channel_binding,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
Err(crate::Error::Protocol(ProtocolError::NoTls).into())
|
Err(crate::Error::Protocol(ProtocolError::NoTls).into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn channel_binding(
|
|
||||||
#[allow(unused_variables)] stream: &Self::Stream,
|
|
||||||
) -> Result<sasl::common::ChannelBinding, Error> {
|
|
||||||
#[cfg(feature = "tls-native")]
|
|
||||||
{
|
|
||||||
log::warn!("tls-native doesn’t support channel binding, please use tls-rust if you want this feature!");
|
|
||||||
Ok(ChannelBinding::None)
|
|
||||||
}
|
|
||||||
#[cfg(all(feature = "tls-rust-ktls", not(feature = "tls-native")))]
|
|
||||||
{
|
|
||||||
log::warn!("Kernel TLS doesn’t support channel binding yet, we would have to extract the secrets in the rustls TlsStream before converting it into a KtlsStream.");
|
|
||||||
Ok(ChannelBinding::None)
|
|
||||||
}
|
|
||||||
#[cfg(all(
|
|
||||||
feature = "tls-rust",
|
|
||||||
not(feature = "tls-native"),
|
|
||||||
not(feature = "tls-rust-ktls")
|
|
||||||
))]
|
|
||||||
{
|
|
||||||
let (_, connection) = stream.get_ref().get_ref();
|
|
||||||
Ok(match connection.protocol_version() {
|
|
||||||
// TODO: Add support for TLS 1.2 and earlier.
|
|
||||||
Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
|
|
||||||
let data = vec![0u8; 32];
|
|
||||||
let data = connection
|
|
||||||
.export_keying_material(data, b"EXPORTER-Channel-Binding", None)
|
|
||||||
.map_err(|e| StartTlsError::Tls(e))?;
|
|
||||||
ChannelBinding::TlsExporter(data)
|
|
||||||
}
|
|
||||||
_ => ChannelBinding::None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "tls-native")]
|
#[cfg(feature = "tls-native")]
|
||||||
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
|
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
|
||||||
xmpp_stream: XmppStream<BufStream<S>>,
|
xmpp_stream: XmppStream<BufStream<S>>,
|
||||||
domain: &str,
|
domain: &str,
|
||||||
) -> Result<TlsStream<S>, Error> {
|
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
|
||||||
let domain = domain.to_owned();
|
let domain = domain.to_owned();
|
||||||
let stream = xmpp_stream.into_inner().into_inner();
|
let stream = xmpp_stream.into_inner().into_inner();
|
||||||
let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
|
let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
|
||||||
.connect(&domain, stream)
|
.connect(&domain, stream)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| StartTlsError::Tls(e))?;
|
.map_err(|e| StartTlsError::Tls(e))?;
|
||||||
Ok(tls_stream)
|
log::warn!(
|
||||||
|
"tls-native doesn’t support channel binding, please use tls-rust if you want this feature!"
|
||||||
|
);
|
||||||
|
Ok((tls_stream, ChannelBinding::None))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
|
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
|
||||||
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
|
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
|
||||||
xmpp_stream: XmppStream<BufStream<S>>,
|
xmpp_stream: XmppStream<BufStream<S>>,
|
||||||
domain: &str,
|
domain: &str,
|
||||||
) -> Result<TlsStream<S>, Error> {
|
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
|
||||||
let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?;
|
let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?;
|
||||||
let stream = xmpp_stream.into_inner().into_inner();
|
let stream = xmpp_stream.into_inner().into_inner();
|
||||||
let mut root_store = RootCertStore::empty();
|
let mut root_store = RootCertStore::empty();
|
||||||
|
@ -197,11 +170,26 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
|
||||||
.connect(domain, stream)
|
.connect(domain, stream)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| Error::from(crate::Error::Io(e)))?;
|
.map_err(|e| Error::from(crate::Error::Io(e)))?;
|
||||||
|
|
||||||
|
// Extract the channel-binding information before we hand the stream over to ktls.
|
||||||
|
let (_, connection) = tls_stream.get_ref();
|
||||||
|
let channel_binding = match connection.protocol_version() {
|
||||||
|
// TODO: Add support for TLS 1.2 and earlier.
|
||||||
|
Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
|
||||||
|
let data = vec![0u8; 32];
|
||||||
|
let data = connection
|
||||||
|
.export_keying_material(data, b"EXPORTER-Channel-Binding", None)
|
||||||
|
.map_err(|e| StartTlsError::Tls(e))?;
|
||||||
|
ChannelBinding::TlsExporter(data)
|
||||||
|
}
|
||||||
|
_ => ChannelBinding::None,
|
||||||
|
};
|
||||||
|
|
||||||
#[cfg(feature = "tls-rust-ktls")]
|
#[cfg(feature = "tls-rust-ktls")]
|
||||||
let tls_stream = ktls::config_ktls_client(tls_stream)
|
let tls_stream = ktls::config_ktls_client(tls_stream)
|
||||||
.await
|
.await
|
||||||
.map_err(StartTlsError::KtlsError)?;
|
.map_err(StartTlsError::KtlsError)?;
|
||||||
Ok(tls_stream)
|
Ok((tls_stream, channel_binding))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Performs `<starttls/>` on an XmppStream and returns a binary
|
/// Performs `<starttls/>` on an XmppStream and returns a binary
|
||||||
|
@ -209,7 +197,7 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
|
||||||
pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
|
pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
|
||||||
mut stream: XmppStream<BufStream<S>>,
|
mut stream: XmppStream<BufStream<S>>,
|
||||||
domain: &str,
|
domain: &str,
|
||||||
) -> Result<TlsStream<S>, Error> {
|
) -> Result<(TlsStream<S>, ChannelBinding), Error> {
|
||||||
stream
|
stream
|
||||||
.send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
|
.send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
|
||||||
Request,
|
Request,
|
||||||
|
|
|
@ -5,7 +5,7 @@ use std::borrow::Cow;
|
||||||
use tokio::{io::BufStream, net::TcpStream};
|
use tokio::{io::BufStream, net::TcpStream};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
connect::{DnsConfig, ServerConnector},
|
connect::{ChannelBinding, DnsConfig, ServerConnector},
|
||||||
xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts},
|
xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts},
|
||||||
Client, Component, Error,
|
Client, Component, Error,
|
||||||
};
|
};
|
||||||
|
@ -37,18 +37,21 @@ impl ServerConnector for TcpServerConnector {
|
||||||
jid: &xmpp_parsers::jid::Jid,
|
jid: &xmpp_parsers::jid::Jid,
|
||||||
ns: &'static str,
|
ns: &'static str,
|
||||||
timeouts: Timeouts,
|
timeouts: Timeouts,
|
||||||
) -> Result<PendingFeaturesRecv<Self::Stream>, Error> {
|
) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
|
||||||
let stream = BufStream::new(self.0.resolve().await?);
|
let stream = BufStream::new(self.0.resolve().await?);
|
||||||
Ok(initiate_stream(
|
Ok((
|
||||||
stream,
|
initiate_stream(
|
||||||
ns,
|
stream,
|
||||||
StreamHeader {
|
ns,
|
||||||
to: Some(Cow::Borrowed(jid.domain().as_str())),
|
StreamHeader {
|
||||||
from: None,
|
to: Some(Cow::Borrowed(jid.domain().as_str())),
|
||||||
id: None,
|
from: None,
|
||||||
},
|
id: None,
|
||||||
timeouts,
|
},
|
||||||
)
|
timeouts,
|
||||||
.await?)
|
)
|
||||||
|
.await?,
|
||||||
|
ChannelBinding::None,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue