xmpp-rs/tokio-xmpp/src/starttls.rs
2020-03-06 18:01:31 +01:00

39 lines
1.3 KiB
Rust

use futures::{sink::SinkExt, stream::StreamExt};
use native_tls::TlsConnector as NativeTlsConnector;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tls::{TlsConnector, TlsStream};
use xmpp_parsers::Element;
use crate::xmpp_codec::Packet;
use crate::xmpp_stream::XMPPStream;
use crate::{Error, ProtocolError};
/// XML namespace qualifying the STARTTLS negotiation elements
/// (`<starttls/>`, `<proceed/>`, `<failure/>`) defined by RFC 6120.
pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
/// Negotiates STARTTLS on an established XMPP stream and upgrades it to TLS.
///
/// Sends a `<starttls/>` nonza qualified by [`NS_XMPP_TLS`] and then reads the
/// server's reply. Text packets arriving between elements are skipped. Once a
/// `<proceed/>` element in the TLS namespace is received, the underlying
/// transport is taken out of `xmpp_stream` and a TLS handshake is performed
/// using the domain part of the stream's JID as the expected server name.
///
/// # Errors
///
/// - [`ProtocolError::NoTls`] if the server replies with anything other than
///   `<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>` (for example
///   `<failure/>`), or if the stream ends before a reply arrives.
/// - Any error produced by the XMPP stream while sending the request or
///   reading the reply.
/// - A TLS error if the native TLS connector cannot be built or the TLS
///   handshake fails.
pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
    mut xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
    let nonza = Element::builder("starttls").ns(NS_XMPP_TLS).build();
    xmpp_stream.send(Packet::Stanza(nonza)).await?;

    loop {
        match xmpp_stream.next().await {
            // The go-ahead must be <proceed/> qualified by the TLS namespace;
            // checking the name alone would accept a foreign-namespace element.
            Some(Ok(Packet::Stanza(ref stanza))) if stanza.is("proceed", NS_XMPP_TLS) => break,
            // Text between elements (e.g. whitespace) carries no meaning here.
            Some(Ok(Packet::Text(_))) => {}
            Some(Err(e)) => return Err(e.into()),
            // <failure/>, any other element, or end of stream: TLS not granted.
            _ => return Err(ProtocolError::NoTls.into()),
        }
    }

    // `domain()` presumably takes the JID by value, hence the clone before
    // `into_inner()` consumes the whole stream.
    let domain = xmpp_stream.jid.clone().domain();
    let stream = xmpp_stream.into_inner();
    // Building the connector can fail (e.g. the platform TLS backend cannot be
    // initialised); report it through the same error path as the handshake
    // instead of panicking.
    let connector = TlsConnector::from(NativeTlsConnector::builder().build()?);
    let tls_stream = connector.connect(&domain, stream).await?;
    Ok(tls_stream)
}