use futures::{sink::SinkExt, stream::StreamExt}; #[cfg(feature = "tls-rust")] use { std::convert::TryFrom, std::sync::Arc, tokio_rustls::{ client::TlsStream, rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}, TlsConnector, }, webpki_roots, }; #[cfg(feature = "tls-native")] use { native_tls::TlsConnector as NativeTlsConnector, tokio_native_tls::{TlsConnector, TlsStream}, }; use tokio::io::{AsyncRead, AsyncWrite}; use xmpp_parsers::{ns, Element}; use crate::xmpp_codec::Packet; use crate::xmpp_stream::XMPPStream; use crate::{Error, ProtocolError}; #[cfg(feature = "tls-native")] async fn get_tls_stream( xmpp_stream: XMPPStream, ) -> Result, Error> { let domain = &xmpp_stream.jid.clone().domain(); let stream = xmpp_stream.into_inner(); let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) .connect(&domain, stream) .await?; Ok(tls_stream) } #[cfg(feature = "tls-rust")] async fn get_tls_stream( xmpp_stream: XMPPStream, ) -> Result, Error> { let domain = &xmpp_stream.jid.clone().domain(); let domain = ServerName::try_from(domain.as_str()).unwrap(); let stream = xmpp_stream.into_inner(); let mut root_store = RootCertStore::empty(); root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { OwnedTrustAnchor::from_subject_spki_name_constraints( ta.subject, ta.spki, ta.name_constraints, ) })); let config = ClientConfig::builder() .with_safe_defaults() .with_root_certificates(root_store) .with_no_client_auth(); let tls_stream = TlsConnector::from(Arc::new(config)) .connect(domain, stream) .await?; Ok(tls_stream) } /// Performs `` on an XMPPStream and returns a binary /// TlsStream. pub async fn starttls( mut xmpp_stream: XMPPStream, ) -> Result, Error> { let nonza = Element::builder("starttls", ns::TLS).build(); let packet = Packet::Stanza(nonza); xmpp_stream.send(packet).await?; loop { match xmpp_stream.next().await { Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break, Some(Ok(Packet::Text(_))) => {} Some(Err(e)) => return Err(e.into()), _ => { return Err(ProtocolError::NoTls.into()); } } } get_tls_stream(xmpp_stream).await }