xmpp-rs/src/starttls.rs

132 lines
5 KiB
Rust
Raw Normal View History

2017-06-04 22:42:35 +00:00
use std::mem::replace;
use futures::{Future, Sink, Poll, Async};
use futures::stream::Stream;
use futures::sink;
2017-06-04 22:45:16 +00:00
use tokio_io::{AsyncRead, AsyncWrite};
2017-06-06 00:03:38 +00:00
use tokio_tls::*;
use native_tls::TlsConnector;
2017-06-04 22:42:35 +00:00
use xml;
use xmpp_codec::*;
use xmpp_stream::*;
use stream_start::StreamStart;
/// XML namespace carried by the `<starttls/>` request that opens TLS
/// negotiation (the STARTTLS namespace of RFC 6120).
pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
2017-06-04 22:45:16 +00:00
state: StartTlsClientState<S>,
2017-06-06 00:03:38 +00:00
domain: String,
2017-06-04 22:42:35 +00:00
}
enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
2017-06-04 22:42:35 +00:00
Invalid,
2017-06-04 22:45:16 +00:00
SendStartTls(sink::Send<XMPPStream<S>>),
AwaitProceed(XMPPStream<S>),
StartingTls(ConnectAsync<S>),
2017-06-06 00:03:38 +00:00
Start(StreamStart<TlsStream<S>>),
2017-06-04 22:42:35 +00:00
}
impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
2017-06-04 22:42:35 +00:00
/// Waits for <stream:features>
2017-06-06 00:03:38 +00:00
pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
let domain = xmpp_stream.stream_attrs.get("from")
.map(|s| s.to_owned())
.unwrap_or_else(|| String::new());
let nonza = xml::Element::new(
"starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()),
vec![]
);
println!("send {}", nonza);
let packet = Packet::Stanza(nonza);
let send = xmpp_stream.send(packet);
2017-06-04 22:42:35 +00:00
StartTlsClient {
state: StartTlsClientState::SendStartTls(send),
2017-06-06 00:03:38 +00:00
domain,
2017-06-04 22:42:35 +00:00
}
}
}
2017-06-04 22:45:16 +00:00
impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
2017-06-06 00:03:38 +00:00
type Item = XMPPStream<TlsStream<S>>;
type Error = String;
2017-06-04 22:42:35 +00:00
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
let mut retry = false;
let (new_state, result) = match old_state {
StartTlsClientState::SendStartTls(mut send) =>
match send.poll() {
Ok(Async::Ready(xmpp_stream)) => {
println!("starttls sent");
let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
retry = true;
(new_state, Ok(Async::NotReady))
},
Ok(Async::NotReady) =>
(StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
Err(e) =>
2017-06-06 00:03:38 +00:00
(StartTlsClientState::SendStartTls(send), Err(format!("{}", e))),
2017-06-04 22:42:35 +00:00
},
StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
match xmpp_stream.poll() {
Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
if stanza.name == "proceed" =>
{
println!("* proceed *");
let stream = xmpp_stream.into_inner();
2017-06-06 00:03:38 +00:00
let connect = TlsConnector::builder().unwrap()
.build().unwrap()
.connect_async(&self.domain, stream);
2017-06-04 22:42:35 +00:00
let new_state = StartTlsClientState::StartingTls(connect);
retry = true;
(new_state, Ok(Async::NotReady))
},
Ok(Async::Ready(value)) => {
println!("StartTlsClient ignore {:?}", value);
(StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
2017-06-04 22:42:35 +00:00
},
Ok(_) =>
(StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
Err(e) =>
2017-06-06 00:03:38 +00:00
(StartTlsClientState::AwaitProceed(xmpp_stream), Err(format!("{}", e))),
2017-06-04 22:42:35 +00:00
},
StartTlsClientState::StartingTls(mut connect) =>
match connect.poll() {
Ok(Async::Ready(tls_stream)) => {
println!("Got a TLS stream!");
2017-06-06 00:03:38 +00:00
let start = XMPPStream::from_stream(tls_stream, self.domain.clone());
let new_state = StartTlsClientState::Start(start);
retry = true;
(new_state, Ok(Async::NotReady))
2017-06-04 22:42:35 +00:00
},
Ok(Async::NotReady) =>
(StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
Err(e) =>
2017-06-06 00:03:38 +00:00
(StartTlsClientState::StartingTls(connect), Err(format!("{}", e))),
2017-06-04 22:42:35 +00:00
},
StartTlsClientState::Start(mut start) =>
match start.poll() {
Ok(Async::Ready(xmpp_stream)) =>
(StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream))),
Ok(Async::NotReady) =>
(StartTlsClientState::Start(start), Ok(Async::NotReady)),
Err(e) =>
2017-06-06 00:03:38 +00:00
(StartTlsClientState::Invalid, Err(format!("{}", e))),
},
2017-06-04 22:42:35 +00:00
StartTlsClientState::Invalid =>
unreachable!(),
};
self.state = new_state;
if retry {
self.poll()
} else {
result
}
}
}