diff --git a/src/error.rs b/src/error.rs index dfe3faed..1646ffff 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,10 +6,12 @@ use openssl::ssl::HandshakeError; use openssl::error::ErrorStack; use xml::reader::Error as XmlError; +use xml::writer::Error as EmitterError; #[derive(Debug)] pub enum Error { XmlError(XmlError), + EmitterError(EmitterError), IoError(io::Error), HandshakeError(HandshakeError), OpenSslErrorStack(ErrorStack), @@ -22,6 +24,12 @@ impl From for Error { } } +impl From for Error { + fn from(err: EmitterError) -> Error { + Error::EmitterError(err) + } +} + impl From for Error { fn from(err: io::Error) -> Error { Error::IoError(err) diff --git a/src/transport.rs b/src/transport.rs index 16c88e16..bf899168 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -3,7 +3,10 @@ use std::io; use std::net::{SocketAddr, TcpStream}; -use xml::reader::{EventReader, XmlEvent}; +use xml::reader::{EventReader, XmlEvent as XmlReaderEvent}; +use xml::writer::{EventWriter, XmlEvent as XmlWriterEvent}; + +use std::sync::{Arc, Mutex}; use ns; @@ -11,8 +14,50 @@ use error::Error; use openssl::ssl::{SslMethod, SslConnectorBuilder, SslStream}; +pub trait Transport { + fn write_event<'a, E: Into>>(&mut self, event: E) -> Result<(), Error>; + fn read_event(&mut self) -> Result; +} + +struct LockedWrite(Arc>); + +impl io::Write for LockedWrite { + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut inner = self.0.lock().unwrap(); // TODO: make safer + inner.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + let mut inner = self.0.lock().unwrap(); // TODO: make safer + inner.flush() + } +} + +struct LockedRead(Arc>); + +impl io::Read for LockedRead { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut inner = self.0.lock().unwrap(); // TODO: make safer + inner.read(buf) + } +} + pub struct SslTransport { - inner: SslStream, + inner: Arc>>, // TODO: this feels rather ugly + reader: EventReader>>, // TODO: especially feels ugly because + // this read would keep the lock + // held very long (potentially) + writer: EventWriter>>, +} + +impl Transport for SslTransport { + fn write_event<'a, E: Into>>(&mut self, event: E) -> Result<(), Error> { + Ok(self.writer.write(event)?) + } + + fn read_event(&mut self) -> Result { + Ok(self.reader.next()?) + } } impl SslTransport { @@ -26,7 +71,7 @@ impl SslTransport { let mut parser = EventReader::new(stream); loop { // TODO: possibly a timeout? match parser.next()? { - XmlEvent::StartElement { name, namespace, .. } => { + XmlReaderEvent::StartElement { name, namespace, .. } => { if let Some(ns) = name.namespace { if ns == ns::TLS && name.local_name == "proceed" { break; @@ -41,9 +86,13 @@ impl SslTransport { } let stream = parser.into_inner(); let ssl_connector = SslConnectorBuilder::new(SslMethod::tls())?.build(); - let ssl_stream = ssl_connector.connect(host, stream)?; + let ssl_stream = Arc::new(Mutex::new(ssl_connector.connect(host, stream)?)); + let reader = EventReader::new(LockedRead(ssl_stream.clone())); + let writer = EventWriter::new(LockedWrite(ssl_stream.clone())); Ok(SslTransport { - inner: ssl_stream + inner: ssl_stream, + reader: reader, + writer: writer, }) } }