diff --git a/tokio-xmpp/Cargo.toml b/tokio-xmpp/Cargo.toml index 49e9b9f5..1a07dd35 100644 --- a/tokio-xmpp/Cargo.toml +++ b/tokio-xmpp/Cargo.toml @@ -39,6 +39,7 @@ tokio-rustls = { version = "0.26", optional = true } [dev-dependencies] env_logger = { version = "0.11", default-features = false, features = ["auto-color", "humantime"] } # this is needed for echo-component example +tokio = { version = "1", features = ["test-util"] } tokio-xmpp = { path = ".", features = ["insecure-tcp"]} [features] diff --git a/tokio-xmpp/examples/echo_component.rs b/tokio-xmpp/examples/echo_component.rs index 70ffa538..6c549349 100644 --- a/tokio-xmpp/examples/echo_component.rs +++ b/tokio-xmpp/examples/echo_component.rs @@ -31,9 +31,7 @@ async fn main() { // If you don't need a custom server but default localhost:5347, you can use // Component::new() directly - let mut component = Component::new_plaintext(jid, password, server) - .await - .unwrap(); + let mut component = Component::new(jid, password).await.unwrap(); // Make the two interfaces for sending and receiving independent // of each other so we can move one into a closure. diff --git a/tokio-xmpp/examples/echo_server.rs b/tokio-xmpp/examples/echo_server.rs index dbf2eb05..6cbb75cd 100644 --- a/tokio-xmpp/examples/echo_server.rs +++ b/tokio-xmpp/examples/echo_server.rs @@ -2,7 +2,7 @@ use futures::{SinkExt, StreamExt}; use tokio::{self, io, net::TcpSocket}; use tokio_xmpp::parsers::stream_features::StreamFeatures; -use tokio_xmpp::xmlstream::{accept_stream, StreamHeader}; +use tokio_xmpp::xmlstream::{accept_stream, StreamHeader, Timeouts}; #[tokio::main] async fn main() -> Result<(), io::Error> { @@ -19,6 +19,7 @@ async fn main() -> Result<(), io::Error> { let stream = accept_stream( tokio::io::BufStream::new(stream), tokio_xmpp::parsers::ns::DEFAULT_NS, + Timeouts::default(), ) .await?; let stream = stream.send_header(StreamHeader::default()).await?; diff --git a/tokio-xmpp/src/client/login.rs b/tokio-xmpp/src/client/login.rs index 14cb0293..90aa0a34 100644 --- a/tokio-xmpp/src/client/login.rs +++ b/tokio-xmpp/src/client/login.rs @@ -19,7 +19,9 @@ use crate::{ client::bind::bind, connect::ServerConnector, error::{AuthError, Error, ProtocolError}, - xmlstream::{xmpp::XmppStreamElement, InitiatingStream, ReadError, StreamHeader, XmppStream}, + xmlstream::{ + xmpp::XmppStreamElement, InitiatingStream, ReadError, StreamHeader, Timeouts, XmppStream, + }, }; pub async fn auth( @@ -107,11 +109,12 @@ pub async fn client_login( server: C, jid: Jid, password: String, + timeouts: Timeouts, ) -> Result<(Option, StreamFeatures, XmppStream), Error> { let username = jid.node().unwrap().as_str(); let password = password; - let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT).await?; + let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?; let (features, xmpp_stream) = xmpp_stream.recv_features().await?; let channel_binding = C::channel_binding(xmpp_stream.get_stream())?; diff --git a/tokio-xmpp/src/client/mod.rs b/tokio-xmpp/src/client/mod.rs index 13fcf6f2..0197ccd3 100644 --- a/tokio-xmpp/src/client/mod.rs +++ b/tokio-xmpp/src/client/mod.rs @@ -5,6 +5,7 @@ use crate::{ client::{login::client_login, stream::ClientState}, connect::ServerConnector, error::Error, + xmlstream::Timeouts, Stanza, }; @@ -30,6 +31,7 @@ pub struct Client { password: String, connector: C, state: ClientState, + timeouts: Timeouts, reconnect: bool, // TODO: tls_required=true } @@ -95,6 +97,7 @@ impl Client { jid.clone(), password, DnsConfig::srv(&jid.domain().to_string(), "_xmpp-client._tcp", 5222), + Timeouts::default(), ); client.set_reconnect(true); client @@ -105,8 +108,14 @@ impl Client { jid: J, password: P, dns_config: DnsConfig, + timeouts: Timeouts, ) -> Self { - Self::new_with_connector(jid, password, StartTlsServerConnector::from(dns_config)) + Self::new_with_connector( + jid, + password, + StartTlsServerConnector::from(dns_config), + timeouts, + ) } } @@ -117,8 +126,14 @@ impl Client { jid: J, password: P, dns_config: DnsConfig, + timeouts: Timeouts, ) -> Self { - Self::new_with_connector(jid, password, TcpServerConnector::from(dns_config)) + Self::new_with_connector( + jid, + password, + TcpServerConnector::from(dns_config), + timeouts, + ) } } @@ -128,6 +143,7 @@ impl Client { jid: J, password: P, connector: C, + timeouts: Timeouts, ) -> Self { let jid = jid.into(); let password = password.into(); @@ -136,6 +152,7 @@ impl Client { connector.clone(), jid.clone(), password.clone(), + timeouts, )); let client = Client { jid, @@ -143,6 +160,7 @@ impl Client { connector, state: ClientState::Connecting(connect), reconnect: false, + timeouts, }; client } diff --git a/tokio-xmpp/src/client/stream.rs b/tokio-xmpp/src/client/stream.rs index 7b19ecb8..8120a870 100644 --- a/tokio-xmpp/src/client/stream.rs +++ b/tokio-xmpp/src/client/stream.rs @@ -56,6 +56,7 @@ impl Stream for Client { self.connector.clone(), self.jid.clone(), self.password.clone(), + self.timeouts, )); self.state = ClientState::Connecting(connect); self.poll_next(cx) diff --git a/tokio-xmpp/src/component/login.rs b/tokio-xmpp/src/component/login.rs index 427ef89a..33b743e4 100644 --- a/tokio-xmpp/src/component/login.rs +++ b/tokio-xmpp/src/component/login.rs @@ -6,16 +6,17 @@ use xmpp_parsers::{component::Handshake, jid::Jid, ns}; use crate::component::ServerConnector; use crate::error::{AuthError, Error}; -use crate::xmlstream::{ReadError, XmppStream, XmppStreamElement}; +use crate::xmlstream::{ReadError, Timeouts, XmppStream, XmppStreamElement}; /// Log into an XMPP server as a client with a jid+pass pub async fn component_login( connector: C, jid: Jid, password: String, + timeouts: Timeouts, ) -> Result, Error> { let password = password; - let mut stream = connector.connect(&jid, ns::COMPONENT).await?; + let mut stream = connector.connect(&jid, ns::COMPONENT, timeouts).await?; let header = stream.take_header(); let mut stream = stream.skip_features(); let stream_id = match header.id { diff --git a/tokio-xmpp/src/component/mod.rs b/tokio-xmpp/src/component/mod.rs index efd58a40..d041fe46 100644 --- a/tokio-xmpp/src/component/mod.rs +++ b/tokio-xmpp/src/component/mod.rs @@ -6,8 +6,10 @@ use std::str::FromStr; use xmpp_parsers::jid::Jid; use crate::{ - component::login::component_login, connect::ServerConnector, xmlstream::XmppStream, Error, - Stanza, + component::login::component_login, + connect::ServerConnector, + xmlstream::{Timeouts, XmppStream}, + Error, Stanza, }; #[cfg(any(feature = "starttls", feature = "insecure-tcp"))] @@ -46,7 +48,13 @@ impl Component { /// Start a new XMPP component over plaintext TCP to localhost:5347 #[cfg(feature = "insecure-tcp")] pub async fn new(jid: &str, password: &str) -> Result { - Self::new_plaintext(jid, password, DnsConfig::addr("127.0.0.1:5347")).await + Self::new_plaintext( + jid, + password, + DnsConfig::addr("127.0.0.1:5347"), + Timeouts::tight(), + ) + .await } /// Start a new XMPP component over plaintext TCP @@ -55,8 +63,15 @@ impl Component { jid: &str, password: &str, dns_config: DnsConfig, + timeouts: Timeouts, ) -> Result { - Component::new_with_connector(jid, password, TcpServerConnector::from(dns_config)).await + Component::new_with_connector( + jid, + password, + TcpServerConnector::from(dns_config), + timeouts, + ) + .await } } @@ -69,10 +84,11 @@ impl Component { jid: &str, password: &str, connector: C, + timeouts: Timeouts, ) -> Result { let jid = Jid::from_str(jid)?; let password = password.to_owned(); - let stream = component_login(connector, jid.clone(), password).await?; + let stream = component_login(connector, jid.clone(), password, timeouts).await?; Ok(Component { jid, stream }) } } diff --git a/tokio-xmpp/src/connect/mod.rs b/tokio-xmpp/src/connect/mod.rs index 937b0402..ca1ae61f 100644 --- a/tokio-xmpp/src/connect/mod.rs +++ b/tokio-xmpp/src/connect/mod.rs @@ -4,7 +4,7 @@ use sasl::common::ChannelBinding; use tokio::io::{AsyncBufRead, AsyncWrite}; use xmpp_parsers::jid::Jid; -use crate::xmlstream::PendingFeaturesRecv; +use crate::xmlstream::{PendingFeaturesRecv, Timeouts}; use crate::Error; #[cfg(feature = "starttls")] @@ -36,6 +36,7 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static { &self, jid: &Jid, ns: &'static str, + timeouts: Timeouts, ) -> impl std::future::Future, Error>> + Send; /// Return channel binding data if available diff --git a/tokio-xmpp/src/connect/starttls.rs b/tokio-xmpp/src/connect/starttls.rs index 1d0d26e4..7df476c1 100644 --- a/tokio-xmpp/src/connect/starttls.rs +++ b/tokio-xmpp/src/connect/starttls.rs @@ -44,7 +44,7 @@ use crate::{ connect::{DnsConfig, ServerConnector, ServerConnectorError}, error::{Error, ProtocolError}, xmlstream::{ - initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, XmppStream, + initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream, XmppStreamElement, }, Client, @@ -70,6 +70,7 @@ impl ServerConnector for StartTlsServerConnector { &self, jid: &Jid, ns: &'static str, + timeouts: Timeouts, ) -> Result, Error> { let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?); @@ -82,6 +83,7 @@ impl ServerConnector for StartTlsServerConnector { from: None, id: None, }, + timeouts, ) .await?; let (features, xmpp_stream) = xmpp_stream.recv_features().await?; @@ -98,6 +100,7 @@ impl ServerConnector for StartTlsServerConnector { from: None, id: None, }, + timeouts, ) .await?) } else { diff --git a/tokio-xmpp/src/connect/tcp.rs b/tokio-xmpp/src/connect/tcp.rs index 89a9e12d..474e25b7 100644 --- a/tokio-xmpp/src/connect/tcp.rs +++ b/tokio-xmpp/src/connect/tcp.rs @@ -6,7 +6,7 @@ use tokio::{io::BufStream, net::TcpStream}; use crate::{ connect::{DnsConfig, ServerConnector}, - xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader}, + xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts}, Client, Component, Error, }; @@ -35,6 +35,7 @@ impl ServerConnector for TcpServerConnector { &self, jid: &xmpp_parsers::jid::Jid, ns: &'static str, + timeouts: Timeouts, ) -> Result, Error> { let stream = BufStream::new(self.0.resolve().await?); Ok(initiate_stream( @@ -45,6 +46,7 @@ impl ServerConnector for TcpServerConnector { from: None, id: None, }, + timeouts, ) .await?) } diff --git a/tokio-xmpp/src/xmlstream/common.rs b/tokio-xmpp/src/xmlstream/common.rs index 2ea0526b..ff509824 100644 --- a/tokio-xmpp/src/xmlstream/common.rs +++ b/tokio-xmpp/src/xmlstream/common.rs @@ -7,6 +7,7 @@ use core::future::Future; use core::pin::Pin; use core::task::{Context, Poll}; +use core::time::Duration; use std::borrow::Cow; use std::io; @@ -14,7 +15,10 @@ use futures::{ready, Sink, SinkExt, Stream, StreamExt}; use bytes::{Buf, BytesMut}; -use tokio::io::{AsyncBufRead, AsyncWrite}; +use tokio::{ + io::{AsyncBufRead, AsyncWrite}, + time::Instant, +}; use xso::{ exports::rxml::{self, writer::TrackNamespace, xml_ncname, Event, Namespace}, @@ -25,6 +29,129 @@ use super::capture::{log_enabled, log_recv, log_send, CaptureBufRead}; use xmpp_parsers::ns::STREAM as XML_STREAM_NS; +/// Configuration for timeouts on an XML stream. +/// +/// The defaults are tuned toward common desktop/laptop use and may not hold +/// up to extreme conditions (arctic sattelite link, mobile internet on a +/// train in Brandenburg, Germany, and similar) and may be inefficient in +/// other conditions (stable server link, localhost communication). +#[derive(Debug, Clone, Copy)] +pub struct Timeouts { + /// Maximum silence time before a + /// [`ReadError::SoftTimeout`][`super::ReadError::SoftTimeout`] is + /// returned. + /// + /// Soft timeouts are not fatal, but they must be handled by user code so + /// that more data is read after at most [`Self::response_timeout`], + /// starting from the moment the soft timeout is returned. + pub read_timeout: Duration, + + /// Maximum silence after a soft timeout. + /// + /// If the stream is silent for longer than this time after a soft timeout + /// has been emitted, a hard [`TimedOut`][`std::io::ErrorKind::TimedOut`] + /// I/O error is returned and the stream is to be considered dead. + pub response_timeout: Duration, +} + +impl Default for Timeouts { + fn default() -> Self { + Self { + read_timeout: Duration::new(300, 0), + response_timeout: Duration::new(300, 0), + } + } +} + +impl Timeouts { + /// Tight timeouts suitable for communicating on a fast LAN or localhost. + pub fn tight() -> Self { + Self { + read_timeout: Duration::new(60, 0), + response_timeout: Duration::new(15, 0), + } + } + + fn data_to_soft(&self) -> Duration { + self.read_timeout + } + + fn soft_to_warn(&self) -> Duration { + self.response_timeout / 2 + } + + fn warn_to_hard(&self) -> Duration { + self.response_timeout / 2 + } +} + +#[derive(Clone, Copy)] +enum TimeoutLevel { + Soft, + Warn, + Hard, +} + +#[derive(Debug)] +pub(super) enum RawError { + Io(io::Error), + SoftTimeout, +} + +impl From for RawError { + fn from(other: io::Error) -> Self { + Self::Io(other) + } +} + +struct TimeoutState { + /// Configuration for the timeouts. + timeouts: Timeouts, + + /// Level of the next timeout which will trip. + level: TimeoutLevel, + + /// Sleep timer used for read timeouts. + // NOTE: even though we pretend we could deal with an !Unpin + // RawXmlStream, we really can't: box_stream for example needs it, + // but also all the typestate around the initial stream setup needs + // to be able to move the stream around. + deadline: Pin>, +} + +impl TimeoutState { + fn new(timeouts: Timeouts) -> Self { + Self { + deadline: Box::pin(tokio::time::sleep(timeouts.data_to_soft())), + level: TimeoutLevel::Soft, + timeouts, + } + } + + fn poll(&mut self, cx: &mut Context) -> Poll { + ready!(self.deadline.as_mut().poll(cx)); + // Deadline elapsed! + let to_return = self.level; + let (next_level, next_duration) = match self.level { + TimeoutLevel::Soft => (TimeoutLevel::Warn, self.timeouts.soft_to_warn()), + TimeoutLevel::Warn => (TimeoutLevel::Hard, self.timeouts.warn_to_hard()), + // Something short-ish so that we fire this over and over until + // someone finally kills the stream for good. + TimeoutLevel::Hard => (TimeoutLevel::Hard, Duration::new(1, 0)), + }; + self.level = next_level; + self.deadline.as_mut().reset(Instant::now() + next_duration); + Poll::Ready(to_return) + } + + fn reset(&mut self) { + self.level = TimeoutLevel::Soft; + self.deadline + .as_mut() + .reset((Instant::now() + self.timeouts.data_to_soft()).into()); + } +} + pin_project_lite::pin_project! { // NOTE: due to limitations of pin_project_lite, the field comments are // no doc comments. Luckily, this struct is only `pub(super)` anyway. @@ -37,6 +164,8 @@ pin_project_lite::pin_project! { // The writer used for serialising data. writer: rxml::writer::Encoder, + timeouts: TimeoutState, + // The default namespace to declare on the stream header. stream_ns: &'static str, @@ -112,7 +241,7 @@ impl RawXmlStream { writer } - pub(super) fn new(io: Io, stream_ns: &'static str) -> Self { + pub(super) fn new(io: Io, stream_ns: &'static str, timeouts: Timeouts) -> Self { let parser = rxml::Parser::default(); let mut io = CaptureBufRead::wrap(io); if log_enabled() { @@ -121,6 +250,7 @@ impl RawXmlStream { Self { parser: rxml::AsyncReader::wrap(io, parser), writer: Self::new_writer(stream_ns), + timeouts: TimeoutState::new(timeouts), tx_buffer_logged: 0, stream_ns, tx_buffer: BytesMut::new(), @@ -189,18 +319,34 @@ impl RawXmlStream { } impl Stream for RawXmlStream { - type Item = Result; + type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); loop { - return Poll::Ready( - match ready!(this.parser.as_mut().poll_read(cx)).transpose() { - // Skip the XML declaration, nobody wants to hear about that. - Some(Ok(rxml::Event::XmlDeclaration(_, _))) => continue, - other => other, - }, - ); + match this.parser.as_mut().poll_read(cx) { + Poll::Pending => (), + Poll::Ready(v) => { + this.timeouts.reset(); + match v.transpose() { + // Skip the XML declaration, nobody wants to hear about that. + Some(Ok(rxml::Event::XmlDeclaration(_, _))) => continue, + other => return Poll::Ready(other.map(|x| x.map_err(RawError::Io))), + } + } + }; + + // poll_read returned pending... what do the timeouts have to say? + match ready!(this.timeouts.poll(cx)) { + TimeoutLevel::Soft => return Poll::Ready(Some(Err(RawError::SoftTimeout))), + TimeoutLevel::Warn => (), + TimeoutLevel::Hard => { + return Poll::Ready(Some(Err(RawError::Io(io::Error::new( + io::ErrorKind::TimedOut, + "read and response timeouts elapsed", + ))))) + } + } } } } @@ -312,6 +458,20 @@ pub(super) enum ReadXsoError { /// not well-formed document. Hard(io::Error), + /// The underlying stream signalled a soft read timeout before a child + /// element could be read. + /// + /// Note that soft timeouts which are triggered in the middle of receiving + /// an element are converted to hard timeouts (i.e. I/O errors). + /// + /// This masking is intentional, because: + /// - Returning a [`Self::SoftTimeout`] from the middle of parsing is not + /// possible without complicating the API. + /// - There is no reason why the remote side should interrupt sending data + /// in the middle of an element except if it or the transport has failed + /// fatally. + SoftTimeout, + /// A parse error occurred. /// /// The XML structure was well-formed, but the data contained did not @@ -324,19 +484,6 @@ pub(super) enum ReadXsoError { Parse(xso::error::Error), } -impl From for io::Error { - fn from(other: ReadXsoError) -> Self { - match other { - ReadXsoError::Hard(v) => v, - ReadXsoError::Parse(e) => io::Error::new(io::ErrorKind::InvalidData, e), - ReadXsoError::Footer => io::Error::new( - io::ErrorKind::UnexpectedEof, - "element footer while waiting for XSO element start", - ), - } - } -} - impl From for ReadXsoError { fn from(other: io::Error) -> Self { Self::Hard(other) @@ -425,13 +572,13 @@ impl ReadXsoState { .parser_pinned() .set_text_buffering(text_buffering); - let ev = ready!(source.as_mut().poll_next(cx)).transpose()?; + let ev = ready!(source.as_mut().poll_next(cx)).transpose(); match self { ReadXsoState::PreData => { log::trace!("ReadXsoState::PreData ev = {:?}", ev); match ev { - Some(rxml::Event::XmlDeclaration(_, _)) => (), - Some(rxml::Event::Text(_, data)) => { + Ok(Some(rxml::Event::XmlDeclaration(_, _))) => (), + Ok(Some(rxml::Event::Text(_, data))) => { if xso::is_xml_whitespace(data.as_bytes()) { log::trace!("Received {} bytes of whitespace", data.len()); source.as_mut().stream_pinned().discard_capture(); @@ -445,18 +592,18 @@ impl ReadXsoState { .into())); } } - Some(rxml::Event::StartElement(_, name, attrs)) => { + Ok(Some(rxml::Event::StartElement(_, name, attrs))) => { *self = ReadXsoState::Parsing( as FromXml>::from_events(name, attrs) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?, ); } // Amounts to EOF, as we expect to start on the stream level. - Some(rxml::Event::EndElement(_)) => { + Ok(Some(rxml::Event::EndElement(_))) => { *self = ReadXsoState::Done; return Poll::Ready(Err(ReadXsoError::Footer)); } - None => { + Ok(None) => { *self = ReadXsoState::Done; return Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidData, @@ -464,17 +611,42 @@ impl ReadXsoState { ) .into())); } + Err(RawError::SoftTimeout) => { + *self = ReadXsoState::Done; + return Poll::Ready(Err(ReadXsoError::SoftTimeout)); + } + Err(RawError::Io(e)) => { + *self = ReadXsoState::Done; + return Poll::Ready(Err(ReadXsoError::Hard(e))); + } } } ReadXsoState::Parsing(builder) => { log::trace!("ReadXsoState::Parsing ev = {:?}", ev); - let Some(ev) = ev else { - *self = ReadXsoState::Done; - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "eof during XSO parsing", - ) - .into())); + let ev = match ev { + Ok(Some(ev)) => ev, + Ok(None) => { + *self = ReadXsoState::Done; + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "eof during XSO parsing", + ) + .into())); + } + Err(RawError::Io(e)) => { + *self = ReadXsoState::Done; + return Poll::Ready(Err(e.into())); + } + Err(RawError::SoftTimeout) => { + // See also [`ReadXsoError::SoftTimeout`] for why + // we mask the SoftTimeout condition here. + *self = ReadXsoState::Done; + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::TimedOut, + "read timeout during XSO parsing", + ) + .into())); + } }; match builder.feed(ev) { @@ -622,8 +794,10 @@ impl StreamHeader<'static> { mut stream: Pin<&mut RawXmlStream>, ) -> io::Result { loop { - match stream.as_mut().next().await.transpose()? { - Some(Event::StartElement(_, (ns, name), mut attrs)) => { + match stream.as_mut().next().await { + Some(Err(RawError::Io(e))) => return Err(e), + Some(Err(RawError::SoftTimeout)) => (), + Some(Ok(Event::StartElement(_, (ns, name), mut attrs))) => { if ns != XML_STREAM_NS || name != "stream" { return Err(io::Error::new( io::ErrorKind::InvalidData, @@ -666,7 +840,7 @@ impl StreamHeader<'static> { id: id.map(Cow::Owned), }); } - Some(Event::Text(_, _)) | Some(Event::EndElement(_)) => { + Some(Ok(Event::Text(_, _))) | Some(Ok(Event::EndElement(_))) => { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "unexpected content before stream header", @@ -674,7 +848,7 @@ impl StreamHeader<'static> { } // We cannot loop infinitely here because the XML parser will // prevent more than one XML declaration from being parsed. - Some(Event::XmlDeclaration(_, _)) => (), + Some(Ok(Event::XmlDeclaration(_, _))) => (), None => { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, diff --git a/tokio-xmpp/src/xmlstream/initiator.rs b/tokio-xmpp/src/xmlstream/initiator.rs index 0829c9a0..4682548d 100644 --- a/tokio-xmpp/src/xmlstream/initiator.rs +++ b/tokio-xmpp/src/xmlstream/initiator.rs @@ -17,7 +17,7 @@ use xmpp_parsers::stream_features::StreamFeatures; use xso::{AsXml, FromXml}; use super::{ - common::{RawXmlStream, ReadXso, StreamHeader}, + common::{RawXmlStream, ReadXso, ReadXsoError, StreamHeader}, XmlStream, }; @@ -80,7 +80,22 @@ impl PendingFeaturesRecv { mut stream, header: _, } = self; - let features = ReadXso::read_from(Pin::new(&mut stream)).await?; + let features = loop { + match ReadXso::read_from(Pin::new(&mut stream)).await { + Ok(v) => break v, + Err(ReadXsoError::SoftTimeout) => (), + Err(ReadXsoError::Hard(e)) => return Err(e), + Err(ReadXsoError::Parse(e)) => { + return Err(io::Error::new(io::ErrorKind::InvalidData, e)) + } + Err(ReadXsoError::Footer) => { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected stream footer", + )) + } + } + }; Ok((features, XmlStream::wrap(stream))) } diff --git a/tokio-xmpp/src/xmlstream/mod.rs b/tokio-xmpp/src/xmlstream/mod.rs index a4d88a6a..5e414966 100644 --- a/tokio-xmpp/src/xmlstream/mod.rs +++ b/tokio-xmpp/src/xmlstream/mod.rs @@ -77,8 +77,8 @@ mod responder; mod tests; pub(crate) mod xmpp; -pub use self::common::StreamHeader; -use self::common::{RawXmlStream, ReadXsoError, ReadXsoState}; +use self::common::{RawError, RawXmlStream, ReadXsoError, ReadXsoState}; +pub use self::common::{StreamHeader, Timeouts}; pub use self::initiator::{InitiatingStream, PendingFeaturesRecv}; pub use self::responder::{AcceptedStream, PendingFeaturesSend}; pub use self::xmpp::XmppStreamElement; @@ -129,8 +129,9 @@ pub async fn initiate_stream( io: Io, stream_ns: &'static str, stream_header: StreamHeader<'_>, + timeouts: Timeouts, ) -> Result, io::Error> { - let stream = InitiatingStream(RawXmlStream::new(io, stream_ns)); + let stream = InitiatingStream(RawXmlStream::new(io, stream_ns, timeouts)); stream.send_header(stream_header).await } @@ -144,8 +145,9 @@ pub async fn initiate_stream( pub async fn accept_stream( io: Io, stream_ns: &'static str, + timeouts: Timeouts, ) -> Result, io::Error> { - let mut stream = RawXmlStream::new(io, stream_ns); + let mut stream = RawXmlStream::new(io, stream_ns, timeouts); let header = StreamHeader::recv(Pin::new(&mut stream)).await?; Ok(AcceptedStream { stream, header }) } @@ -319,14 +321,21 @@ impl Stream for XmlStream; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); + let mut this = self.project(); let result = match this.read_state.as_mut() { None => { // awaiting eof. - return match ready!(this.inner.poll_next(cx)) { - None => Poll::Ready(None), - Some(Ok(_)) => unreachable!("xml parser allowed data after stream footer"), - Some(Err(e)) => Poll::Ready(Some(Err(ReadError::HardError(e)))), + return loop { + match ready!(this.inner.as_mut().poll_next(cx)) { + None => break Poll::Ready(None), + Some(Ok(_)) => unreachable!("xml parser allowed data after stream footer"), + Some(Err(RawError::Io(e))) => { + break Poll::Ready(Some(Err(ReadError::HardError(e)))) + } + // Swallow soft timeout, we don't want the user to trigger + // anything here. + Some(Err(RawError::SoftTimeout)) => continue, + } }; } Some(read_state) => ready!(read_state.poll_advance(this.inner, cx)), @@ -341,6 +350,7 @@ impl Stream for XmlStream Poll::Ready(Some(Err(ReadError::SoftTimeout))), }; *this.read_state = Some(ReadXsoState::default()); result diff --git a/tokio-xmpp/src/xmlstream/tests.rs b/tokio-xmpp/src/xmlstream/tests.rs index 2270547c..44c0256f 100644 --- a/tokio-xmpp/src/xmlstream/tests.rs +++ b/tokio-xmpp/src/xmlstream/tests.rs @@ -4,6 +4,8 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +use std::time::Duration; + use futures::{SinkExt, StreamExt}; use xmpp_parsers::stream_features::StreamFeatures; @@ -29,12 +31,18 @@ async fn test_initiate_accept_stream() { to: Some("server".into()), id: Some("client-id".into()), }, + Timeouts::tight(), ) .await?; Ok::<_, io::Error>(stream.take_header()) }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; assert_eq!(stream.header().from.unwrap(), "client"); assert_eq!(stream.header().to.unwrap(), "server"); assert_eq!(stream.header().id.unwrap(), "client-id"); @@ -61,13 +69,19 @@ async fn test_exchange_stream_features() { tokio::io::BufStream::new(lhs), "jabber:client", StreamHeader::default(), + Timeouts::tight(), ) .await?; let (features, _) = stream.recv_features::().await?; Ok::<_, io::Error>(features) }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; let stream = stream.send_header(StreamHeader::default()).await?; stream .send_features::(&StreamFeatures::default()) @@ -88,6 +102,7 @@ async fn test_exchange_data() { tokio::io::BufStream::new(lhs), "jabber:client", StreamHeader::default(), + Timeouts::tight(), ) .await?; let (_, mut stream) = stream.recv_features::().await?; @@ -104,7 +119,12 @@ async fn test_exchange_data() { }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; let stream = stream.send_header(StreamHeader::default()).await?; let mut stream = stream .send_features::(&StreamFeatures::default()) @@ -134,6 +154,7 @@ async fn test_clean_shutdown() { tokio::io::BufStream::new(lhs), "jabber:client", StreamHeader::default(), + Timeouts::tight(), ) .await?; let (_, mut stream) = stream.recv_features::().await?; @@ -146,7 +167,12 @@ async fn test_clean_shutdown() { }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; let stream = stream.send_header(StreamHeader::default()).await?; let mut stream = stream .send_features::(&StreamFeatures::default()) @@ -172,6 +198,7 @@ async fn test_exchange_data_stream_reset_and_shutdown() { tokio::io::BufStream::new(lhs), "jabber:client", StreamHeader::default(), + Timeouts::tight(), ) .await?; let (_, mut stream) = stream.recv_features::().await?; @@ -215,7 +242,12 @@ async fn test_exchange_data_stream_reset_and_shutdown() { }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; let stream = stream.send_header(StreamHeader::default()).await?; let mut stream = stream .send_features::(&StreamFeatures::default()) @@ -262,3 +294,104 @@ async fn test_exchange_data_stream_reset_and_shutdown() { responder.await.unwrap().expect("responder failed"); initiator.await.unwrap().expect("initiator failed"); } + +#[tokio::test(start_paused = true)] +async fn test_emits_soft_timeout_after_silence() { + let (lhs, rhs) = tokio::io::duplex(65536); + + let client_timeouts = Timeouts { + read_timeout: Duration::new(300, 0), + response_timeout: Duration::new(15, 0), + }; + + // We do want to trigger only one set of timeouts, so we set the server + // timeouts much longer than the client timeouts + let server_timeouts = Timeouts { + read_timeout: Duration::new(900, 0), + response_timeout: Duration::new(15, 0), + }; + + let initiator = tokio::spawn(async move { + let stream = initiate_stream( + tokio::io::BufStream::new(lhs), + "jabber:client", + StreamHeader::default(), + client_timeouts, + ) + .await?; + let (_, mut stream) = stream.recv_features::().await?; + stream + .send(&Data { + contents: "hello".to_owned(), + }) + .await?; + match stream.next().await { + Some(Ok(Data { contents })) => assert_eq!(contents, "world!"), + other => panic!("unexpected stream message: {:?}", other), + } + // Here we prove that the stream doesn't see any data and also does + // not see the SoftTimeout too early. + // (Well, not exactly a proof: We only check until half of the read + // timeout, because that was easy to write and I deem it good enough.) + match tokio::time::timeout(client_timeouts.read_timeout / 2, stream.next()).await { + Err(_) => (), + Ok(ev) => panic!("early stream message (before soft timeout): {:?}", ev), + }; + // Now the next thing that happens is the soft timeout ... + match stream.next().await { + Some(Err(ReadError::SoftTimeout)) => (), + other => panic!("unexpected stream message: {:?}", other), + } + // Another check that the there is some time between soft and hard + // timeout. + match tokio::time::timeout(client_timeouts.response_timeout / 3, stream.next()).await { + Err(_) => (), + Ok(ev) => { + panic!("early stream message (before hard timeout): {:?}", ev); + } + }; + // ... and thereafter the hard timeout in form of an I/O error. + match stream.next().await { + Some(Err(ReadError::HardError(e))) if e.kind() == io::ErrorKind::TimedOut => (), + other => panic!("unexpected stream message: {:?}", other), + } + Ok::<_, io::Error>(()) + }); + + let responder = tokio::spawn(async move { + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + server_timeouts, + ) + .await?; + let stream = stream.send_header(StreamHeader::default()).await?; + let mut stream = stream + .send_features::(&StreamFeatures::default()) + .await?; + stream + .send(&Data { + contents: "world!".to_owned(), + }) + .await?; + match stream.next().await { + Some(Ok(Data { contents })) => assert_eq!(contents, "hello"), + other => panic!("unexpected stream message: {:?}", other), + } + match stream.next().await { + Some(Err(ReadError::HardError(e))) if e.kind() == io::ErrorKind::InvalidData => { + match e.downcast::() { + // the initiator closes the stream by dropping it once the + // timeout trips, so we get a hard eof here. + Ok(rxml::Error::InvalidEof(_)) => (), + other => panic!("unexpected error: {:?}", other), + } + } + other => panic!("unexpected stream message: {:?}", other), + } + Ok::<_, io::Error>(()) + }); + + responder.await.unwrap().expect("responder failed"); + initiator.await.unwrap().expect("initiator failed"); +} diff --git a/xmpp/src/builder.rs b/xmpp/src/builder.rs index 8b0e769d..f8f29ce3 100644 --- a/xmpp/src/builder.rs +++ b/xmpp/src/builder.rs @@ -15,6 +15,7 @@ use tokio_xmpp::{ disco::{DiscoInfoResult, Feature, Identity}, ns, }, + xmlstream::Timeouts, Client as TokioXmppClient, }; @@ -51,6 +52,7 @@ pub struct ClientBuilder<'a, C: ServerConnector> { disco: (ClientType, String), features: Vec, resource: Option, + timeouts: Timeouts, } #[cfg(any(feature = "starttls-rust", feature = "starttls-native"))] @@ -80,6 +82,7 @@ impl ClientBuilder<'_, C> { disco: (ClientType::default(), String::from("tokio-xmpp")), features: vec![], resource: None, + timeouts: Timeouts::default(), } } @@ -109,6 +112,15 @@ impl ClientBuilder<'_, C> { self } + /// Configure the timeouts used. + /// + /// See [`Timeouts`] for more information on the semantics and the + /// defaults (which are used unless you call this method). + pub fn set_timeouts(mut self, timeouts: Timeouts) -> Self { + self.timeouts = timeouts; + self + } + pub fn enable_feature(mut self, feature: ClientFeature) -> Self { self.features.push(feature); self @@ -146,8 +158,12 @@ impl ClientBuilder<'_, C> { self.jid.clone().into() }; - let mut client = - TokioXmppClient::new_with_connector(jid, self.password, self.server_connector.clone()); + let mut client = TokioXmppClient::new_with_connector( + jid, + self.password, + self.server_connector.clone(), + self.timeouts, + ); client.set_reconnect(true); self.build_impl(client) }