diff --git a/tokio-xmpp/src/xmlstream/initiator.rs b/tokio-xmpp/src/xmlstream/initiator.rs index 76cc4d47..5d6528e0 100644 --- a/tokio-xmpp/src/xmlstream/initiator.rs +++ b/tokio-xmpp/src/xmlstream/initiator.rs @@ -8,6 +8,8 @@ use core::pin::Pin; use std::borrow::Cow; use std::io; +use futures::SinkExt; + use tokio::io::{AsyncBufRead, AsyncWrite}; use xmpp_parsers::stream_features::StreamFeatures; @@ -19,6 +21,27 @@ use super::{ XmlStream, }; +/// Type state for an initiator stream which has not yet sent its stream +/// header. +/// +/// To continue stream setup, call [`send_header`][`Self::send_header`]. +pub struct InitiatingStream(pub(super) RawXmlStream); + +impl InitiatingStream { + /// Send the stream header. + pub async fn send_header( + self, + header: StreamHeader<'_>, + ) -> io::Result> { + let Self(mut stream) = self; + + header.send(Pin::new(&mut stream)).await?; + stream.flush().await?; + let header = StreamHeader::recv(Pin::new(&mut stream)).await?; + Ok(PendingFeaturesRecv { stream, header }) + } +} + /// Type state for an initiator stream which has sent and received the stream /// header. /// diff --git a/tokio-xmpp/src/xmlstream/mod.rs b/tokio-xmpp/src/xmlstream/mod.rs index bac1d3de..b14915ce 100644 --- a/tokio-xmpp/src/xmlstream/mod.rs +++ b/tokio-xmpp/src/xmlstream/mod.rs @@ -41,7 +41,7 @@ use core::pin::Pin; use core::task::{Context, Poll}; use std::io; -use futures::{ready, Sink, SinkExt, Stream}; +use futures::{ready, Sink, Stream}; use tokio::io::{AsyncBufRead, AsyncWrite}; @@ -54,7 +54,7 @@ mod responder; mod tests; use self::common::{RawXmlStream, ReadXsoError, ReadXsoState, StreamHeader}; -pub use self::initiator::PendingFeaturesRecv; +pub use self::initiator::{InitiatingStream, PendingFeaturesRecv}; pub use self::responder::{AcceptedStream, PendingFeaturesSend}; /// Initiate a new stream @@ -70,16 +70,8 @@ pub async fn initiate_stream( stream_ns: &'static str, stream_header: StreamHeader<'_>, ) -> Result, io::Error> { - let mut raw_stream = RawXmlStream::new(io, stream_ns); - stream_header.send(Pin::new(&mut raw_stream)).await?; - raw_stream.flush().await?; - - let header = StreamHeader::recv(Pin::new(&mut raw_stream)).await?; - - Ok(PendingFeaturesRecv { - stream: raw_stream, - header, - }) + let stream = InitiatingStream(RawXmlStream::new(io, stream_ns)); + stream.send_header(stream_header).await } /// Accept a new XML stream as responder @@ -194,8 +186,8 @@ impl XmlStream { impl XmlStream { /// Initiate a stream reset /// - /// The `header` is the new stream header which is sent to the remote - /// party. + /// To actually send the stream header, call + /// [`send_header`][`InitiatingStream::send_header`] on the result. /// /// # Panics /// @@ -205,18 +197,12 @@ impl XmlStream /// /// In addition, attempting to reset a stream which has been closed by /// either side or which has had an I/O error will also cause a panic. - pub async fn initiate_reset( - self, - header: StreamHeader<'_>, - ) -> io::Result> { + pub fn initiate_reset(self) -> InitiatingStream { self.assert_retypable(); let mut stream = self.inner; Pin::new(&mut stream).reset_state(); - header.send(Pin::new(&mut stream)).await?; - stream.flush().await?; - let header = StreamHeader::recv(Pin::new(&mut stream)).await?; - Ok(PendingFeaturesRecv { stream, header }) + InitiatingStream(stream) } /// Anticipate a new stream header sent by the remote party. diff --git a/tokio-xmpp/src/xmlstream/tests.rs b/tokio-xmpp/src/xmlstream/tests.rs index 92ea5fc1..2ac5cb6b 100644 --- a/tokio-xmpp/src/xmlstream/tests.rs +++ b/tokio-xmpp/src/xmlstream/tests.rs @@ -4,7 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -use futures::StreamExt; +use futures::{SinkExt, StreamExt}; use xmpp_parsers::stream_features::StreamFeatures; @@ -185,7 +185,8 @@ async fn test_exchange_data_stream_reset_and_shutdown() { other => panic!("unexpected stream message: {:?}", other), } let stream = stream - .initiate_reset(StreamHeader { + .initiate_reset() + .send_header(StreamHeader { from: Some("client".into()), to: Some("server".into()), id: Some("client-id".into()),