diff --git a/src/lib.rs b/src/lib.rs index ca596a6..ea726d5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,3 +6,5 @@ pub mod transport; pub mod error; pub mod jid; pub mod client; + +mod locked_io; diff --git a/src/locked_io.rs b/src/locked_io.rs new file mode 100644 index 0000000..fc78d63 --- /dev/null +++ b/src/locked_io.rs @@ -0,0 +1,37 @@ +use std::io; +use std::io::prelude::*; + +use std::sync::{Arc, Mutex}; + +pub struct LockedIO(Arc>); + +impl LockedIO { + pub fn from(inner: Arc>) -> LockedIO { + LockedIO(inner) + } +} + +impl Clone for LockedIO { + fn clone(&self) -> LockedIO { + LockedIO(self.0.clone()) + } +} + +impl io::Write for LockedIO { + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut inner = self.0.lock().unwrap(); // TODO: make safer + inner.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + let mut inner = self.0.lock().unwrap(); // TODO: make safer + inner.flush() + } +} + +impl io::Read for LockedIO { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut inner = self.0.lock().unwrap(); // TODO: make safer + inner.read(buf) + } +} diff --git a/src/transport.rs b/src/transport.rs index 8016f87..f016412 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -1,5 +1,4 @@ use std::io::prelude::*; -use std::io; use std::net::TcpStream; @@ -10,6 +9,8 @@ use std::sync::{Arc, Mutex}; use ns; +use locked_io::LockedIO; + use error::Error; use openssl::ssl::{SslMethod, SslConnectorBuilder, SslStream}; @@ -19,27 +20,6 @@ pub trait Transport { fn read_event(&mut self) -> Result; } -struct LockedIO(Arc>); - -impl io::Write for LockedIO { - fn write(&mut self, buf: &[u8]) -> io::Result { - let mut inner = self.0.lock().unwrap(); // TODO: make safer - inner.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - let mut inner = self.0.lock().unwrap(); // TODO: make safer - inner.flush() - } -} - -impl io::Read for LockedIO { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut inner = self.0.lock().unwrap(); // TODO: make safer - inner.read(buf) - } -} - pub struct SslTransport { inner: Arc>>, // TODO: this feels rather ugly reader: EventReader>>, // TODO: especially feels ugly because @@ -85,8 +65,9 @@ impl SslTransport { let stream = parser.into_inner(); let ssl_connector = SslConnectorBuilder::new(SslMethod::tls())?.build(); let ssl_stream = Arc::new(Mutex::new(ssl_connector.connect(host, stream)?)); - let reader = EventReader::new(LockedIO(ssl_stream.clone())); - let writer = EventWriter::new(LockedIO(ssl_stream.clone())); + let locked_io = LockedIO::from(ssl_stream.clone()); + let reader = EventReader::new(locked_io.clone()); + let writer = EventWriter::new(locked_io); Ok(SslTransport { inner: ssl_stream, reader: reader,