starttls works

This commit is contained in:
Astro 2017-06-05 00:42:35 +02:00
parent 482bf77955
commit a618acd6d6
6 changed files with 178 additions and 4 deletions

View file

@ -9,3 +9,5 @@ tokio-core = "*"
tokio-io = "*"
bytes = "*"
RustyXML = "*"
rustls = "*"
tokio-rustls = "*"

View file

@ -1,10 +1,15 @@
extern crate futures;
extern crate tokio_core;
extern crate tokio_xmpp;
extern crate rustls;
use std::sync::Arc;
use std::io::BufReader;
use std::fs::File;
use tokio_core::reactor::Core;
use futures::{Future, Stream};
use tokio_xmpp::{Packet, TcpClient};
use tokio_xmpp::{Packet, TcpClient, StartTlsClient};
use rustls::ClientConfig;
fn main() {
use std::net::ToSocketAddrs;
@ -12,10 +17,16 @@ fn main() {
.to_socket_addrs().unwrap()
.next().unwrap();
let mut config = ClientConfig::new();
let mut certfile = BufReader::new(File::open("/usr/share/ca-certificates/CAcert/root.crt").unwrap());
config.root_store.add_pem_file(&mut certfile).unwrap();
let arc_config = Arc::new(config);
let mut core = Core::new().unwrap();
let client = TcpClient::connect(
&addr,
&core.handle()
).and_then(|stream| StartTlsClient::from_stream(stream, arc_config)
).and_then(|stream| {
stream.for_each(|event| {
match event {
@ -25,5 +36,11 @@ fn main() {
Ok(())
})
});
core.run(client).unwrap();
match core.run(client) {
Ok(_) => (),
Err(e) => {
println!("Fatal: {}", e);
()
}
}
}

View file

@ -4,12 +4,16 @@ extern crate tokio_core;
extern crate tokio_io;
extern crate bytes;
extern crate xml;
extern crate rustls;
extern crate tokio_rustls;
mod xmpp_codec;
pub use xmpp_codec::*;
mod tcp;
pub use tcp::*;
mod starttls;
pub use starttls::*;
// type FullClient = sasl::Client<StartTLS<TCPConnection>>

142
src/starttls.rs Normal file
View file

@ -0,0 +1,142 @@
use std::mem::replace;
use std::io::{Error, ErrorKind};
use std::sync::Arc;
use futures::{Future, Sink, Poll, Async};
use futures::stream::Stream;
use futures::sink;
use tokio_core::net::TcpStream;
use rustls::*;
use tokio_rustls::*;
use xml;
use super::{XMPPStream, XMPPCodec, Packet};
const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
pub struct StartTlsClient {
state: StartTlsClientState,
arc_config: Arc<ClientConfig>,
}
enum StartTlsClientState {
Invalid,
AwaitFeatures(XMPPStream<TcpStream>),
SendStartTls(sink::Send<XMPPStream<TcpStream>>),
AwaitProceed(XMPPStream<TcpStream>),
StartingTls(ConnectAsync<TcpStream>),
}
impl StartTlsClient {
/// Waits for <stream:features>
pub fn from_stream(xmpp_stream: XMPPStream<TcpStream>, arc_config: Arc<ClientConfig>) -> Self {
StartTlsClient {
state: StartTlsClientState::AwaitFeatures(xmpp_stream),
arc_config: arc_config,
}
}
}
// TODO: eval <stream:features>, check ns
impl Future for StartTlsClient {
type Item = XMPPStream<TlsStream<TcpStream, ClientSession>>;
type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
let mut retry = false;
let (new_state, result) = match old_state {
StartTlsClientState::AwaitFeatures(mut xmpp_stream) =>
match xmpp_stream.poll() {
Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
if stanza.name == "features"
&& stanza.ns == Some(NS_XMPP_STREAM.to_owned())
=>
{
println!("Got features: {}", stanza);
match stanza.get_child("starttls", Some(NS_XMPP_TLS)) {
None =>
(StartTlsClientState::Invalid, Err(Error::from(ErrorKind::InvalidData))),
Some(_) => {
let nonza = xml::Element::new(
"starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()),
vec![]
);
println!("send {}", nonza);
let packet = Packet::Stanza(nonza);
let send = xmpp_stream.send(packet);
let new_state = StartTlsClientState::SendStartTls(send);
retry = true;
(new_state, Ok(Async::NotReady))
},
}
},
Ok(Async::Ready(value)) => {
println!("StartTlsClient ignore {:?}", value);
(StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady))
},
Ok(_) =>
(StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady)),
Err(e) =>
(StartTlsClientState::AwaitFeatures(xmpp_stream), Err(e)),
},
StartTlsClientState::SendStartTls(mut send) =>
match send.poll() {
Ok(Async::Ready(xmpp_stream)) => {
println!("starttls sent");
let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
retry = true;
(new_state, Ok(Async::NotReady))
},
Ok(Async::NotReady) =>
(StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
Err(e) =>
(StartTlsClientState::SendStartTls(send), Err(e)),
},
StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
match xmpp_stream.poll() {
Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
if stanza.name == "proceed" =>
{
println!("* proceed *");
let stream = xmpp_stream.into_inner();
let connect = self.arc_config.connect_async("spaceboyz.net", stream);
let new_state = StartTlsClientState::StartingTls(connect);
retry = true;
(new_state, Ok(Async::NotReady))
},
Ok(Async::Ready(value)) => {
println!("StartTlsClient ignore {:?}", value);
(StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady))
},
Ok(_) =>
(StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
Err(e) =>
(StartTlsClientState::AwaitProceed(xmpp_stream), Err(e)),
},
StartTlsClientState::StartingTls(mut connect) =>
match connect.poll() {
Ok(Async::Ready(tls_stream)) => {
println!("Got a TLS stream!");
let xmpp_stream = XMPPCodec::frame_stream(tls_stream);
(StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream)))
},
Ok(Async::NotReady) =>
(StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
Err(e) =>
(StartTlsClientState::StartingTls(connect), Err(e)),
},
StartTlsClientState::Invalid =>
unreachable!(),
};
self.state = new_state;
if retry {
self.poll()
} else {
result
}
}
}

View file

@ -5,7 +5,6 @@ use futures::{Future, Sink, Poll, Async};
use futures::stream::Stream;
use futures::sink;
use tokio_core::reactor::Handle;
use tokio_io::AsyncRead;
use tokio_core::net::{TcpStream, TcpStreamNew};
use super::{XMPPStream, XMPPCodec, Packet};
@ -53,7 +52,7 @@ impl Future for TcpClient {
let (new_state, result) = match self.state {
TcpClientState::Connecting(ref mut tcp_stream_new) => {
let tcp_stream = try_ready!(tcp_stream_new.poll());
let xmpp_stream = AsyncRead::framed(tcp_stream, XMPPCodec::new());
let xmpp_stream = XMPPCodec::frame_stream(tcp_stream);
let send = xmpp_stream.send(Packet::StreamStart);
let new_state = TcpClientState::SendStart(send);
(new_state, Ok(Async::NotReady))

View file

@ -3,6 +3,7 @@ use std::fmt::Write;
use std::str::from_utf8;
use std::io::{Error, ErrorKind};
use std::collections::HashMap;
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_io::codec::{Framed, Encoder, Decoder};
use xml;
use bytes::*;
@ -67,6 +68,12 @@ impl XMPPCodec {
root: None,
}
}
pub fn frame_stream<S>(stream: S) -> Framed<S, XMPPCodec>
where S: AsyncRead + AsyncWrite
{
AsyncRead::framed(stream, XMPPCodec::new())
}
}
impl Decoder for XMPPCodec {
@ -146,6 +153,9 @@ impl Encoder for XMPPCodec {
NS_CLIENT, NS_STREAMS)
.map_err(|_| Error::from(ErrorKind::WriteZero))
},
Packet::Stanza(stanza) =>
write!(dst, "{}", stanza)
.map_err(|_| Error::from(ErrorKind::InvalidInput)),
// TODO: Implement all
_ => Ok(())
}