xmpp-rs-mirror/tokio-xmpp/src/starttls.rs

87 lines
2.6 KiB
Rust
Raw Normal View History

2020-03-05 00:25:24 +00:00
use futures::{sink::SinkExt, stream::StreamExt};
2021-02-15 19:45:58 +00:00
#[cfg(feature = "tls-rust")]
2021-02-15 19:45:58 +00:00
use {
std::convert::TryFrom,
2021-02-15 19:45:58 +00:00
std::sync::Arc,
tokio_rustls::{
client::TlsStream,
rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName},
TlsConnector,
},
2021-02-15 19:45:58 +00:00
webpki_roots,
};
#[cfg(feature = "tls-native")]
2021-02-15 19:45:58 +00:00
use {
native_tls::TlsConnector as NativeTlsConnector,
tokio_native_tls::{TlsConnector, TlsStream},
};
2020-03-05 00:25:24 +00:00
use tokio::io::{AsyncRead, AsyncWrite};
use xmpp_parsers::{ns, Element};
2017-06-04 22:42:35 +00:00
2018-12-18 17:29:31 +00:00
use crate::xmpp_codec::Packet;
use crate::xmpp_stream::XMPPStream;
2020-03-05 00:25:24 +00:00
use crate::{Error, ProtocolError};
2017-06-04 22:42:35 +00:00
/// Upgrades the plain stream wrapped by `xmpp_stream` to TLS using the
/// `native-tls` backend.
///
/// The handshake is performed against the domain part of the stream's JID.
///
/// # Errors
///
/// Returns an error if the native TLS connector cannot be built, or if the
/// TLS handshake fails.
#[cfg(feature = "tls-native")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
    xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
    // The JID is cloned before `domain()` is called; presumably `domain()`
    // takes the JID by value — confirm against the jid crate version in use.
    let domain = xmpp_stream.jid.clone().domain();
    let stream = xmpp_stream.into_inner();
    // Building the connector is fallible; report the failure through the
    // same error conversion already used for the handshake instead of panicking.
    let connector = TlsConnector::from(NativeTlsConnector::builder().build()?);
    let tls_stream = connector.connect(&domain, stream).await?;
    Ok(tls_stream)
}
/// Upgrades the plain stream wrapped by `xmpp_stream` to TLS using `rustls`.
///
/// Server certificates are checked against the trust anchors in
/// `webpki_roots::TLS_SERVER_ROOTS`, for the domain part of the stream's JID.
/// No client certificate is presented.
///
/// # Errors
///
/// Returns an I/O error of kind `InvalidInput` if the JID's domain is not a
/// valid server name. Also propagates any error from the TLS handshake.
#[cfg(feature = "tls-rust")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
    xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
    let domain = xmpp_stream.jid.clone().domain();
    // The domain originates from the user-supplied JID; an unparsable name is
    // an input error, not an invariant violation, so it must not panic.
    let domain = ServerName::try_from(domain.as_str())
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
    let stream = xmpp_stream.into_inner();

    let mut root_store = RootCertStore::empty();
    root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
        OwnedTrustAnchor::from_subject_spki_name_constraints(
            ta.subject,
            ta.spki,
            ta.name_constraints,
        )
    }));
    let config = ClientConfig::builder()
        .with_safe_defaults()
        .with_root_certificates(root_store)
        .with_no_client_auth();

    let tls_stream = TlsConnector::from(Arc::new(config))
        .connect(domain, stream)
        .await?;
    Ok(tls_stream)
}
2020-03-15 23:34:46 +00:00
/// Performs `<starttls/>` on an XMPPStream and returns a binary
/// TlsStream.
2020-03-05 00:25:24 +00:00
pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
mut xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
let nonza = Element::builder("starttls", ns::TLS).build();
2020-03-05 00:25:24 +00:00
let packet = Packet::Stanza(nonza);
xmpp_stream.send(packet).await?;
loop {
match xmpp_stream.next().await {
Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break,
Some(Ok(Packet::Text(_))) => {}
Some(Err(e)) => return Err(e.into()),
_ => {
return Err(ProtocolError::NoTls.into());
}
2017-06-04 22:42:35 +00:00
}
}
2018-09-01 19:59:02 +00:00
get_tls_stream(xmpp_stream).await
2017-06-04 22:42:35 +00:00
}