xmpp-rs-mirror/tokio-xmpp/src/starttls.rs

115 lines
4.1 KiB
Rust
Raw Normal View History

2017-06-04 22:42:35 +00:00
use futures::sink;
2018-12-18 18:04:31 +00:00
use futures::stream::Stream;
use futures::{Async, Future, Poll, Sink};
use native_tls::TlsConnector as NativeTlsConnector;
use std::mem::replace;
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_tls::{Connect, TlsConnector, TlsStream};
use xmpp_parsers::{Element, Jid};
2017-06-04 22:42:35 +00:00
2018-12-18 17:29:31 +00:00
use crate::xmpp_codec::Packet;
use crate::xmpp_stream::XMPPStream;
use crate::Error;
2017-06-04 22:42:35 +00:00
2018-08-02 17:58:19 +00:00
/// XML namespace of the XMPP STARTTLS elements (RFC 6120, section 5).
///
/// Used as the namespace of the `<starttls/>` request sent by `StartTlsClient::from_stream`.
pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
2017-06-04 22:42:35 +00:00
2018-08-02 17:58:19 +00:00
/// XMPP stream that switches to TLS if available in received features
pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
2017-06-04 22:45:16 +00:00
state: StartTlsClientState<S>,
2017-06-13 23:55:56 +00:00
jid: Jid,
2017-06-04 22:42:35 +00:00
}
enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
2017-06-04 22:42:35 +00:00
Invalid,
2017-06-04 22:45:16 +00:00
SendStartTls(sink::Send<XMPPStream<S>>),
AwaitProceed(XMPPStream<S>),
2018-09-01 19:59:02 +00:00
StartingTls(Connect<S>),
2017-06-04 22:42:35 +00:00
}
impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
2017-06-04 22:42:35 +00:00
/// Waits for <stream:features>
2017-06-06 00:03:38 +00:00
pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
2017-06-13 23:55:56 +00:00
let jid = xmpp_stream.jid.clone();
2017-06-06 00:03:38 +00:00
2018-12-18 18:04:31 +00:00
let nonza = Element::builder("starttls").ns(NS_XMPP_TLS).build();
let packet = Packet::Stanza(nonza);
let send = xmpp_stream.send(packet);
2017-06-04 22:42:35 +00:00
StartTlsClient {
state: StartTlsClientState::SendStartTls(send),
2017-06-13 23:55:56 +00:00
jid,
2017-06-04 22:42:35 +00:00
}
}
}
2017-06-04 22:45:16 +00:00
impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
2017-07-18 23:02:45 +00:00
type Item = TlsStream<S>;
2018-09-06 15:46:06 +00:00
type Error = Error;
2017-06-04 22:42:35 +00:00
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
let mut retry = false;
2018-09-01 19:59:02 +00:00
2017-06-04 22:42:35 +00:00
let (new_state, result) = match old_state {
2018-12-18 18:04:31 +00:00
StartTlsClientState::SendStartTls(mut send) => match send.poll() {
Ok(Async::Ready(xmpp_stream)) => {
let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
retry = true;
(new_state, Ok(Async::NotReady))
}
Ok(Async::NotReady) => {
(StartTlsClientState::SendStartTls(send), Ok(Async::NotReady))
}
Err(e) => (StartTlsClientState::SendStartTls(send), Err(e.into())),
},
StartTlsClientState::AwaitProceed(mut xmpp_stream) => match xmpp_stream.poll() {
Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
if stanza.name() == "proceed" =>
{
let stream = xmpp_stream.stream.into_inner();
let connect =
TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
2019-09-08 19:28:44 +00:00
.connect(&self.jid.clone().domain(), stream);
2018-12-18 18:04:31 +00:00
let new_state = StartTlsClientState::StartingTls(connect);
retry = true;
(new_state, Ok(Async::NotReady))
}
2019-01-26 18:28:57 +00:00
Ok(Async::Ready(_value)) => {
// println!("StartTlsClient ignore {:?}", _value);
2018-12-18 18:04:31 +00:00
(
StartTlsClientState::AwaitProceed(xmpp_stream),
Ok(Async::NotReady),
)
}
Ok(_) => (
StartTlsClientState::AwaitProceed(xmpp_stream),
Ok(Async::NotReady),
),
Err(e) => (
StartTlsClientState::AwaitProceed(xmpp_stream),
Err(Error::Protocol(e.into())),
),
},
StartTlsClientState::StartingTls(mut connect) => match connect.poll() {
Ok(Async::Ready(tls_stream)) => {
(StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream)))
}
Ok(Async::NotReady) => (
StartTlsClientState::StartingTls(connect),
Ok(Async::NotReady),
),
Err(e) => (StartTlsClientState::Invalid, Err(e.into())),
},
StartTlsClientState::Invalid => unreachable!(),
2017-06-04 22:42:35 +00:00
};
self.state = new_state;
if retry {
self.poll()
} else {
result
}
}
}