From 4cfe4f842967f145bb3fa973a1e0d3a8820c63e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20Sch=C3=A4fer?= Date: Sun, 18 Aug 2024 17:40:39 +0200 Subject: [PATCH] xmlstream: implement simple timeout logic This allows to detect and handle dying streams without getting stuck forever. Timeouts are always wrong, though, so we put the burden of choosing the right values (mostly) on the creator of a stream. --- tokio-xmpp/Cargo.toml | 1 + tokio-xmpp/examples/echo_component.rs | 4 +- tokio-xmpp/examples/echo_server.rs | 3 +- tokio-xmpp/src/client/login.rs | 7 +- tokio-xmpp/src/client/mod.rs | 22 ++- tokio-xmpp/src/client/stream.rs | 1 + tokio-xmpp/src/component/login.rs | 5 +- tokio-xmpp/src/component/mod.rs | 26 ++- tokio-xmpp/src/connect/mod.rs | 3 +- tokio-xmpp/src/connect/starttls.rs | 5 +- tokio-xmpp/src/connect/tcp.rs | 4 +- tokio-xmpp/src/xmlstream/common.rs | 254 ++++++++++++++++++++++---- tokio-xmpp/src/xmlstream/initiator.rs | 19 +- tokio-xmpp/src/xmlstream/mod.rs | 28 ++- tokio-xmpp/src/xmlstream/tests.rs | 143 ++++++++++++++- xmpp/src/builder.rs | 20 +- 16 files changed, 469 insertions(+), 76 deletions(-) diff --git a/tokio-xmpp/Cargo.toml b/tokio-xmpp/Cargo.toml index 49e9b9f5..1a07dd35 100644 --- a/tokio-xmpp/Cargo.toml +++ b/tokio-xmpp/Cargo.toml @@ -39,6 +39,7 @@ tokio-rustls = { version = "0.26", optional = true } [dev-dependencies] env_logger = { version = "0.11", default-features = false, features = ["auto-color", "humantime"] } # this is needed for echo-component example +tokio = { version = "1", features = ["test-util"] } tokio-xmpp = { path = ".", features = ["insecure-tcp"]} [features] diff --git a/tokio-xmpp/examples/echo_component.rs b/tokio-xmpp/examples/echo_component.rs index 70ffa538..6c549349 100644 --- a/tokio-xmpp/examples/echo_component.rs +++ b/tokio-xmpp/examples/echo_component.rs @@ -31,9 +31,7 @@ async fn main() { // If you don't need a custom server but default localhost:5347, you can use // Component::new() directly - let mut component = Component::new_plaintext(jid, password, server) - .await - .unwrap(); + let mut component = Component::new(jid, password).await.unwrap(); // Make the two interfaces for sending and receiving independent // of each other so we can move one into a closure. diff --git a/tokio-xmpp/examples/echo_server.rs b/tokio-xmpp/examples/echo_server.rs index dbf2eb05..6cbb75cd 100644 --- a/tokio-xmpp/examples/echo_server.rs +++ b/tokio-xmpp/examples/echo_server.rs @@ -2,7 +2,7 @@ use futures::{SinkExt, StreamExt}; use tokio::{self, io, net::TcpSocket}; use tokio_xmpp::parsers::stream_features::StreamFeatures; -use tokio_xmpp::xmlstream::{accept_stream, StreamHeader}; +use tokio_xmpp::xmlstream::{accept_stream, StreamHeader, Timeouts}; #[tokio::main] async fn main() -> Result<(), io::Error> { @@ -19,6 +19,7 @@ async fn main() -> Result<(), io::Error> { let stream = accept_stream( tokio::io::BufStream::new(stream), tokio_xmpp::parsers::ns::DEFAULT_NS, + Timeouts::default(), ) .await?; let stream = stream.send_header(StreamHeader::default()).await?; diff --git a/tokio-xmpp/src/client/login.rs b/tokio-xmpp/src/client/login.rs index 14cb0293..90aa0a34 100644 --- a/tokio-xmpp/src/client/login.rs +++ b/tokio-xmpp/src/client/login.rs @@ -19,7 +19,9 @@ use crate::{ client::bind::bind, connect::ServerConnector, error::{AuthError, Error, ProtocolError}, - xmlstream::{xmpp::XmppStreamElement, InitiatingStream, ReadError, StreamHeader, XmppStream}, + xmlstream::{ + xmpp::XmppStreamElement, InitiatingStream, ReadError, StreamHeader, Timeouts, XmppStream, + }, }; pub async fn auth( @@ -107,11 +109,12 @@ pub async fn client_login( server: C, jid: Jid, password: String, + timeouts: Timeouts, ) -> Result<(Option, StreamFeatures, XmppStream), Error> { let username = jid.node().unwrap().as_str(); let password = password; - let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT).await?; + let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?; let (features, xmpp_stream) = xmpp_stream.recv_features().await?; let channel_binding = C::channel_binding(xmpp_stream.get_stream())?; diff --git a/tokio-xmpp/src/client/mod.rs b/tokio-xmpp/src/client/mod.rs index 13fcf6f2..0197ccd3 100644 --- a/tokio-xmpp/src/client/mod.rs +++ b/tokio-xmpp/src/client/mod.rs @@ -5,6 +5,7 @@ use crate::{ client::{login::client_login, stream::ClientState}, connect::ServerConnector, error::Error, + xmlstream::Timeouts, Stanza, }; @@ -30,6 +31,7 @@ pub struct Client { password: String, connector: C, state: ClientState, + timeouts: Timeouts, reconnect: bool, // TODO: tls_required=true } @@ -95,6 +97,7 @@ impl Client { jid.clone(), password, DnsConfig::srv(&jid.domain().to_string(), "_xmpp-client._tcp", 5222), + Timeouts::default(), ); client.set_reconnect(true); client @@ -105,8 +108,14 @@ impl Client { jid: J, password: P, dns_config: DnsConfig, + timeouts: Timeouts, ) -> Self { - Self::new_with_connector(jid, password, StartTlsServerConnector::from(dns_config)) + Self::new_with_connector( + jid, + password, + StartTlsServerConnector::from(dns_config), + timeouts, + ) } } @@ -117,8 +126,14 @@ impl Client { jid: J, password: P, dns_config: DnsConfig, + timeouts: Timeouts, ) -> Self { - Self::new_with_connector(jid, password, TcpServerConnector::from(dns_config)) + Self::new_with_connector( + jid, + password, + TcpServerConnector::from(dns_config), + timeouts, + ) } } @@ -128,6 +143,7 @@ impl Client { jid: J, password: P, connector: C, + timeouts: Timeouts, ) -> Self { let jid = jid.into(); let password = password.into(); @@ -136,6 +152,7 @@ impl Client { connector.clone(), jid.clone(), password.clone(), + timeouts, )); let client = Client { jid, @@ -143,6 +160,7 @@ impl Client { connector, state: ClientState::Connecting(connect), reconnect: false, + timeouts, }; client } diff --git a/tokio-xmpp/src/client/stream.rs b/tokio-xmpp/src/client/stream.rs index 7b19ecb8..8120a870 100644 --- a/tokio-xmpp/src/client/stream.rs +++ b/tokio-xmpp/src/client/stream.rs @@ -56,6 +56,7 @@ impl Stream for Client { self.connector.clone(), self.jid.clone(), self.password.clone(), + self.timeouts, )); self.state = ClientState::Connecting(connect); self.poll_next(cx) diff --git a/tokio-xmpp/src/component/login.rs b/tokio-xmpp/src/component/login.rs index 427ef89a..33b743e4 100644 --- a/tokio-xmpp/src/component/login.rs +++ b/tokio-xmpp/src/component/login.rs @@ -6,16 +6,17 @@ use xmpp_parsers::{component::Handshake, jid::Jid, ns}; use crate::component::ServerConnector; use crate::error::{AuthError, Error}; -use crate::xmlstream::{ReadError, XmppStream, XmppStreamElement}; +use crate::xmlstream::{ReadError, Timeouts, XmppStream, XmppStreamElement}; /// Log into an XMPP server as a client with a jid+pass pub async fn component_login( connector: C, jid: Jid, password: String, + timeouts: Timeouts, ) -> Result, Error> { let password = password; - let mut stream = connector.connect(&jid, ns::COMPONENT).await?; + let mut stream = connector.connect(&jid, ns::COMPONENT, timeouts).await?; let header = stream.take_header(); let mut stream = stream.skip_features(); let stream_id = match header.id { diff --git a/tokio-xmpp/src/component/mod.rs b/tokio-xmpp/src/component/mod.rs index efd58a40..d041fe46 100644 --- a/tokio-xmpp/src/component/mod.rs +++ b/tokio-xmpp/src/component/mod.rs @@ -6,8 +6,10 @@ use std::str::FromStr; use xmpp_parsers::jid::Jid; use crate::{ - component::login::component_login, connect::ServerConnector, xmlstream::XmppStream, Error, - Stanza, + component::login::component_login, + connect::ServerConnector, + xmlstream::{Timeouts, XmppStream}, + Error, Stanza, }; #[cfg(any(feature = "starttls", feature = "insecure-tcp"))] @@ -46,7 +48,13 @@ impl Component { /// Start a new XMPP component over plaintext TCP to localhost:5347 #[cfg(feature = "insecure-tcp")] pub async fn new(jid: &str, password: &str) -> Result { - Self::new_plaintext(jid, password, DnsConfig::addr("127.0.0.1:5347")).await + Self::new_plaintext( + jid, + password, + DnsConfig::addr("127.0.0.1:5347"), + Timeouts::tight(), + ) + .await } /// Start a new XMPP component over plaintext TCP @@ -55,8 +63,15 @@ impl Component { jid: &str, password: &str, dns_config: DnsConfig, + timeouts: Timeouts, ) -> Result { - Component::new_with_connector(jid, password, TcpServerConnector::from(dns_config)).await + Component::new_with_connector( + jid, + password, + TcpServerConnector::from(dns_config), + timeouts, + ) + .await } } @@ -69,10 +84,11 @@ impl Component { jid: &str, password: &str, connector: C, + timeouts: Timeouts, ) -> Result { let jid = Jid::from_str(jid)?; let password = password.to_owned(); - let stream = component_login(connector, jid.clone(), password).await?; + let stream = component_login(connector, jid.clone(), password, timeouts).await?; Ok(Component { jid, stream }) } } diff --git a/tokio-xmpp/src/connect/mod.rs b/tokio-xmpp/src/connect/mod.rs index 937b0402..ca1ae61f 100644 --- a/tokio-xmpp/src/connect/mod.rs +++ b/tokio-xmpp/src/connect/mod.rs @@ -4,7 +4,7 @@ use sasl::common::ChannelBinding; use tokio::io::{AsyncBufRead, AsyncWrite}; use xmpp_parsers::jid::Jid; -use crate::xmlstream::PendingFeaturesRecv; +use crate::xmlstream::{PendingFeaturesRecv, Timeouts}; use crate::Error; #[cfg(feature = "starttls")] @@ -36,6 +36,7 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static { &self, jid: &Jid, ns: &'static str, + timeouts: Timeouts, ) -> impl std::future::Future, Error>> + Send; /// Return channel binding data if available diff --git a/tokio-xmpp/src/connect/starttls.rs b/tokio-xmpp/src/connect/starttls.rs index 1d0d26e4..7df476c1 100644 --- a/tokio-xmpp/src/connect/starttls.rs +++ b/tokio-xmpp/src/connect/starttls.rs @@ -44,7 +44,7 @@ use crate::{ connect::{DnsConfig, ServerConnector, ServerConnectorError}, error::{Error, ProtocolError}, xmlstream::{ - initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, XmppStream, + initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream, XmppStreamElement, }, Client, @@ -70,6 +70,7 @@ impl ServerConnector for StartTlsServerConnector { &self, jid: &Jid, ns: &'static str, + timeouts: Timeouts, ) -> Result, Error> { let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?); @@ -82,6 +83,7 @@ impl ServerConnector for StartTlsServerConnector { from: None, id: None, }, + timeouts, ) .await?; let (features, xmpp_stream) = xmpp_stream.recv_features().await?; @@ -98,6 +100,7 @@ impl ServerConnector for StartTlsServerConnector { from: None, id: None, }, + timeouts, ) .await?) } else { diff --git a/tokio-xmpp/src/connect/tcp.rs b/tokio-xmpp/src/connect/tcp.rs index 89a9e12d..474e25b7 100644 --- a/tokio-xmpp/src/connect/tcp.rs +++ b/tokio-xmpp/src/connect/tcp.rs @@ -6,7 +6,7 @@ use tokio::{io::BufStream, net::TcpStream}; use crate::{ connect::{DnsConfig, ServerConnector}, - xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader}, + xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts}, Client, Component, Error, }; @@ -35,6 +35,7 @@ impl ServerConnector for TcpServerConnector { &self, jid: &xmpp_parsers::jid::Jid, ns: &'static str, + timeouts: Timeouts, ) -> Result, Error> { let stream = BufStream::new(self.0.resolve().await?); Ok(initiate_stream( @@ -45,6 +46,7 @@ impl ServerConnector for TcpServerConnector { from: None, id: None, }, + timeouts, ) .await?) } diff --git a/tokio-xmpp/src/xmlstream/common.rs b/tokio-xmpp/src/xmlstream/common.rs index 2ea0526b..ff509824 100644 --- a/tokio-xmpp/src/xmlstream/common.rs +++ b/tokio-xmpp/src/xmlstream/common.rs @@ -7,6 +7,7 @@ use core::future::Future; use core::pin::Pin; use core::task::{Context, Poll}; +use core::time::Duration; use std::borrow::Cow; use std::io; @@ -14,7 +15,10 @@ use futures::{ready, Sink, SinkExt, Stream, StreamExt}; use bytes::{Buf, BytesMut}; -use tokio::io::{AsyncBufRead, AsyncWrite}; +use tokio::{ + io::{AsyncBufRead, AsyncWrite}, + time::Instant, +}; use xso::{ exports::rxml::{self, writer::TrackNamespace, xml_ncname, Event, Namespace}, @@ -25,6 +29,129 @@ use super::capture::{log_enabled, log_recv, log_send, CaptureBufRead}; use xmpp_parsers::ns::STREAM as XML_STREAM_NS; +/// Configuration for timeouts on an XML stream. +/// +/// The defaults are tuned toward common desktop/laptop use and may not hold +/// up to extreme conditions (arctic sattelite link, mobile internet on a +/// train in Brandenburg, Germany, and similar) and may be inefficient in +/// other conditions (stable server link, localhost communication). +#[derive(Debug, Clone, Copy)] +pub struct Timeouts { + /// Maximum silence time before a + /// [`ReadError::SoftTimeout`][`super::ReadError::SoftTimeout`] is + /// returned. + /// + /// Soft timeouts are not fatal, but they must be handled by user code so + /// that more data is read after at most [`Self::response_timeout`], + /// starting from the moment the soft timeout is returned. + pub read_timeout: Duration, + + /// Maximum silence after a soft timeout. + /// + /// If the stream is silent for longer than this time after a soft timeout + /// has been emitted, a hard [`TimedOut`][`std::io::ErrorKind::TimedOut`] + /// I/O error is returned and the stream is to be considered dead. + pub response_timeout: Duration, +} + +impl Default for Timeouts { + fn default() -> Self { + Self { + read_timeout: Duration::new(300, 0), + response_timeout: Duration::new(300, 0), + } + } +} + +impl Timeouts { + /// Tight timeouts suitable for communicating on a fast LAN or localhost. + pub fn tight() -> Self { + Self { + read_timeout: Duration::new(60, 0), + response_timeout: Duration::new(15, 0), + } + } + + fn data_to_soft(&self) -> Duration { + self.read_timeout + } + + fn soft_to_warn(&self) -> Duration { + self.response_timeout / 2 + } + + fn warn_to_hard(&self) -> Duration { + self.response_timeout / 2 + } +} + +#[derive(Clone, Copy)] +enum TimeoutLevel { + Soft, + Warn, + Hard, +} + +#[derive(Debug)] +pub(super) enum RawError { + Io(io::Error), + SoftTimeout, +} + +impl From for RawError { + fn from(other: io::Error) -> Self { + Self::Io(other) + } +} + +struct TimeoutState { + /// Configuration for the timeouts. + timeouts: Timeouts, + + /// Level of the next timeout which will trip. + level: TimeoutLevel, + + /// Sleep timer used for read timeouts. + // NOTE: even though we pretend we could deal with an !Unpin + // RawXmlStream, we really can't: box_stream for example needs it, + // but also all the typestate around the initial stream setup needs + // to be able to move the stream around. + deadline: Pin>, +} + +impl TimeoutState { + fn new(timeouts: Timeouts) -> Self { + Self { + deadline: Box::pin(tokio::time::sleep(timeouts.data_to_soft())), + level: TimeoutLevel::Soft, + timeouts, + } + } + + fn poll(&mut self, cx: &mut Context) -> Poll { + ready!(self.deadline.as_mut().poll(cx)); + // Deadline elapsed! + let to_return = self.level; + let (next_level, next_duration) = match self.level { + TimeoutLevel::Soft => (TimeoutLevel::Warn, self.timeouts.soft_to_warn()), + TimeoutLevel::Warn => (TimeoutLevel::Hard, self.timeouts.warn_to_hard()), + // Something short-ish so that we fire this over and over until + // someone finally kills the stream for good. + TimeoutLevel::Hard => (TimeoutLevel::Hard, Duration::new(1, 0)), + }; + self.level = next_level; + self.deadline.as_mut().reset(Instant::now() + next_duration); + Poll::Ready(to_return) + } + + fn reset(&mut self) { + self.level = TimeoutLevel::Soft; + self.deadline + .as_mut() + .reset((Instant::now() + self.timeouts.data_to_soft()).into()); + } +} + pin_project_lite::pin_project! { // NOTE: due to limitations of pin_project_lite, the field comments are // no doc comments. Luckily, this struct is only `pub(super)` anyway. @@ -37,6 +164,8 @@ pin_project_lite::pin_project! { // The writer used for serialising data. writer: rxml::writer::Encoder, + timeouts: TimeoutState, + // The default namespace to declare on the stream header. stream_ns: &'static str, @@ -112,7 +241,7 @@ impl RawXmlStream { writer } - pub(super) fn new(io: Io, stream_ns: &'static str) -> Self { + pub(super) fn new(io: Io, stream_ns: &'static str, timeouts: Timeouts) -> Self { let parser = rxml::Parser::default(); let mut io = CaptureBufRead::wrap(io); if log_enabled() { @@ -121,6 +250,7 @@ impl RawXmlStream { Self { parser: rxml::AsyncReader::wrap(io, parser), writer: Self::new_writer(stream_ns), + timeouts: TimeoutState::new(timeouts), tx_buffer_logged: 0, stream_ns, tx_buffer: BytesMut::new(), @@ -189,18 +319,34 @@ impl RawXmlStream { } impl Stream for RawXmlStream { - type Item = Result; + type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); loop { - return Poll::Ready( - match ready!(this.parser.as_mut().poll_read(cx)).transpose() { - // Skip the XML declaration, nobody wants to hear about that. - Some(Ok(rxml::Event::XmlDeclaration(_, _))) => continue, - other => other, - }, - ); + match this.parser.as_mut().poll_read(cx) { + Poll::Pending => (), + Poll::Ready(v) => { + this.timeouts.reset(); + match v.transpose() { + // Skip the XML declaration, nobody wants to hear about that. + Some(Ok(rxml::Event::XmlDeclaration(_, _))) => continue, + other => return Poll::Ready(other.map(|x| x.map_err(RawError::Io))), + } + } + }; + + // poll_read returned pending... what do the timeouts have to say? + match ready!(this.timeouts.poll(cx)) { + TimeoutLevel::Soft => return Poll::Ready(Some(Err(RawError::SoftTimeout))), + TimeoutLevel::Warn => (), + TimeoutLevel::Hard => { + return Poll::Ready(Some(Err(RawError::Io(io::Error::new( + io::ErrorKind::TimedOut, + "read and response timeouts elapsed", + ))))) + } + } } } } @@ -312,6 +458,20 @@ pub(super) enum ReadXsoError { /// not well-formed document. Hard(io::Error), + /// The underlying stream signalled a soft read timeout before a child + /// element could be read. + /// + /// Note that soft timeouts which are triggered in the middle of receiving + /// an element are converted to hard timeouts (i.e. I/O errors). + /// + /// This masking is intentional, because: + /// - Returning a [`Self::SoftTimeout`] from the middle of parsing is not + /// possible without complicating the API. + /// - There is no reason why the remote side should interrupt sending data + /// in the middle of an element except if it or the transport has failed + /// fatally. + SoftTimeout, + /// A parse error occurred. /// /// The XML structure was well-formed, but the data contained did not @@ -324,19 +484,6 @@ pub(super) enum ReadXsoError { Parse(xso::error::Error), } -impl From for io::Error { - fn from(other: ReadXsoError) -> Self { - match other { - ReadXsoError::Hard(v) => v, - ReadXsoError::Parse(e) => io::Error::new(io::ErrorKind::InvalidData, e), - ReadXsoError::Footer => io::Error::new( - io::ErrorKind::UnexpectedEof, - "element footer while waiting for XSO element start", - ), - } - } -} - impl From for ReadXsoError { fn from(other: io::Error) -> Self { Self::Hard(other) @@ -425,13 +572,13 @@ impl ReadXsoState { .parser_pinned() .set_text_buffering(text_buffering); - let ev = ready!(source.as_mut().poll_next(cx)).transpose()?; + let ev = ready!(source.as_mut().poll_next(cx)).transpose(); match self { ReadXsoState::PreData => { log::trace!("ReadXsoState::PreData ev = {:?}", ev); match ev { - Some(rxml::Event::XmlDeclaration(_, _)) => (), - Some(rxml::Event::Text(_, data)) => { + Ok(Some(rxml::Event::XmlDeclaration(_, _))) => (), + Ok(Some(rxml::Event::Text(_, data))) => { if xso::is_xml_whitespace(data.as_bytes()) { log::trace!("Received {} bytes of whitespace", data.len()); source.as_mut().stream_pinned().discard_capture(); @@ -445,18 +592,18 @@ impl ReadXsoState { .into())); } } - Some(rxml::Event::StartElement(_, name, attrs)) => { + Ok(Some(rxml::Event::StartElement(_, name, attrs))) => { *self = ReadXsoState::Parsing( as FromXml>::from_events(name, attrs) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?, ); } // Amounts to EOF, as we expect to start on the stream level. - Some(rxml::Event::EndElement(_)) => { + Ok(Some(rxml::Event::EndElement(_))) => { *self = ReadXsoState::Done; return Poll::Ready(Err(ReadXsoError::Footer)); } - None => { + Ok(None) => { *self = ReadXsoState::Done; return Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidData, @@ -464,17 +611,42 @@ impl ReadXsoState { ) .into())); } + Err(RawError::SoftTimeout) => { + *self = ReadXsoState::Done; + return Poll::Ready(Err(ReadXsoError::SoftTimeout)); + } + Err(RawError::Io(e)) => { + *self = ReadXsoState::Done; + return Poll::Ready(Err(ReadXsoError::Hard(e))); + } } } ReadXsoState::Parsing(builder) => { log::trace!("ReadXsoState::Parsing ev = {:?}", ev); - let Some(ev) = ev else { - *self = ReadXsoState::Done; - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "eof during XSO parsing", - ) - .into())); + let ev = match ev { + Ok(Some(ev)) => ev, + Ok(None) => { + *self = ReadXsoState::Done; + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "eof during XSO parsing", + ) + .into())); + } + Err(RawError::Io(e)) => { + *self = ReadXsoState::Done; + return Poll::Ready(Err(e.into())); + } + Err(RawError::SoftTimeout) => { + // See also [`ReadXsoError::SoftTimeout`] for why + // we mask the SoftTimeout condition here. + *self = ReadXsoState::Done; + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::TimedOut, + "read timeout during XSO parsing", + ) + .into())); + } }; match builder.feed(ev) { @@ -622,8 +794,10 @@ impl StreamHeader<'static> { mut stream: Pin<&mut RawXmlStream>, ) -> io::Result { loop { - match stream.as_mut().next().await.transpose()? { - Some(Event::StartElement(_, (ns, name), mut attrs)) => { + match stream.as_mut().next().await { + Some(Err(RawError::Io(e))) => return Err(e), + Some(Err(RawError::SoftTimeout)) => (), + Some(Ok(Event::StartElement(_, (ns, name), mut attrs))) => { if ns != XML_STREAM_NS || name != "stream" { return Err(io::Error::new( io::ErrorKind::InvalidData, @@ -666,7 +840,7 @@ impl StreamHeader<'static> { id: id.map(Cow::Owned), }); } - Some(Event::Text(_, _)) | Some(Event::EndElement(_)) => { + Some(Ok(Event::Text(_, _))) | Some(Ok(Event::EndElement(_))) => { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "unexpected content before stream header", @@ -674,7 +848,7 @@ impl StreamHeader<'static> { } // We cannot loop infinitely here because the XML parser will // prevent more than one XML declaration from being parsed. - Some(Event::XmlDeclaration(_, _)) => (), + Some(Ok(Event::XmlDeclaration(_, _))) => (), None => { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, diff --git a/tokio-xmpp/src/xmlstream/initiator.rs b/tokio-xmpp/src/xmlstream/initiator.rs index 0829c9a0..4682548d 100644 --- a/tokio-xmpp/src/xmlstream/initiator.rs +++ b/tokio-xmpp/src/xmlstream/initiator.rs @@ -17,7 +17,7 @@ use xmpp_parsers::stream_features::StreamFeatures; use xso::{AsXml, FromXml}; use super::{ - common::{RawXmlStream, ReadXso, StreamHeader}, + common::{RawXmlStream, ReadXso, ReadXsoError, StreamHeader}, XmlStream, }; @@ -80,7 +80,22 @@ impl PendingFeaturesRecv { mut stream, header: _, } = self; - let features = ReadXso::read_from(Pin::new(&mut stream)).await?; + let features = loop { + match ReadXso::read_from(Pin::new(&mut stream)).await { + Ok(v) => break v, + Err(ReadXsoError::SoftTimeout) => (), + Err(ReadXsoError::Hard(e)) => return Err(e), + Err(ReadXsoError::Parse(e)) => { + return Err(io::Error::new(io::ErrorKind::InvalidData, e)) + } + Err(ReadXsoError::Footer) => { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected stream footer", + )) + } + } + }; Ok((features, XmlStream::wrap(stream))) } diff --git a/tokio-xmpp/src/xmlstream/mod.rs b/tokio-xmpp/src/xmlstream/mod.rs index a4d88a6a..5e414966 100644 --- a/tokio-xmpp/src/xmlstream/mod.rs +++ b/tokio-xmpp/src/xmlstream/mod.rs @@ -77,8 +77,8 @@ mod responder; mod tests; pub(crate) mod xmpp; -pub use self::common::StreamHeader; -use self::common::{RawXmlStream, ReadXsoError, ReadXsoState}; +use self::common::{RawError, RawXmlStream, ReadXsoError, ReadXsoState}; +pub use self::common::{StreamHeader, Timeouts}; pub use self::initiator::{InitiatingStream, PendingFeaturesRecv}; pub use self::responder::{AcceptedStream, PendingFeaturesSend}; pub use self::xmpp::XmppStreamElement; @@ -129,8 +129,9 @@ pub async fn initiate_stream( io: Io, stream_ns: &'static str, stream_header: StreamHeader<'_>, + timeouts: Timeouts, ) -> Result, io::Error> { - let stream = InitiatingStream(RawXmlStream::new(io, stream_ns)); + let stream = InitiatingStream(RawXmlStream::new(io, stream_ns, timeouts)); stream.send_header(stream_header).await } @@ -144,8 +145,9 @@ pub async fn initiate_stream( pub async fn accept_stream( io: Io, stream_ns: &'static str, + timeouts: Timeouts, ) -> Result, io::Error> { - let mut stream = RawXmlStream::new(io, stream_ns); + let mut stream = RawXmlStream::new(io, stream_ns, timeouts); let header = StreamHeader::recv(Pin::new(&mut stream)).await?; Ok(AcceptedStream { stream, header }) } @@ -319,14 +321,21 @@ impl Stream for XmlStream; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); + let mut this = self.project(); let result = match this.read_state.as_mut() { None => { // awaiting eof. - return match ready!(this.inner.poll_next(cx)) { - None => Poll::Ready(None), - Some(Ok(_)) => unreachable!("xml parser allowed data after stream footer"), - Some(Err(e)) => Poll::Ready(Some(Err(ReadError::HardError(e)))), + return loop { + match ready!(this.inner.as_mut().poll_next(cx)) { + None => break Poll::Ready(None), + Some(Ok(_)) => unreachable!("xml parser allowed data after stream footer"), + Some(Err(RawError::Io(e))) => { + break Poll::Ready(Some(Err(ReadError::HardError(e)))) + } + // Swallow soft timeout, we don't want the user to trigger + // anything here. + Some(Err(RawError::SoftTimeout)) => continue, + } }; } Some(read_state) => ready!(read_state.poll_advance(this.inner, cx)), @@ -341,6 +350,7 @@ impl Stream for XmlStream Poll::Ready(Some(Err(ReadError::SoftTimeout))), }; *this.read_state = Some(ReadXsoState::default()); result diff --git a/tokio-xmpp/src/xmlstream/tests.rs b/tokio-xmpp/src/xmlstream/tests.rs index 2270547c..44c0256f 100644 --- a/tokio-xmpp/src/xmlstream/tests.rs +++ b/tokio-xmpp/src/xmlstream/tests.rs @@ -4,6 +4,8 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +use std::time::Duration; + use futures::{SinkExt, StreamExt}; use xmpp_parsers::stream_features::StreamFeatures; @@ -29,12 +31,18 @@ async fn test_initiate_accept_stream() { to: Some("server".into()), id: Some("client-id".into()), }, + Timeouts::tight(), ) .await?; Ok::<_, io::Error>(stream.take_header()) }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; assert_eq!(stream.header().from.unwrap(), "client"); assert_eq!(stream.header().to.unwrap(), "server"); assert_eq!(stream.header().id.unwrap(), "client-id"); @@ -61,13 +69,19 @@ async fn test_exchange_stream_features() { tokio::io::BufStream::new(lhs), "jabber:client", StreamHeader::default(), + Timeouts::tight(), ) .await?; let (features, _) = stream.recv_features::().await?; Ok::<_, io::Error>(features) }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; let stream = stream.send_header(StreamHeader::default()).await?; stream .send_features::(&StreamFeatures::default()) @@ -88,6 +102,7 @@ async fn test_exchange_data() { tokio::io::BufStream::new(lhs), "jabber:client", StreamHeader::default(), + Timeouts::tight(), ) .await?; let (_, mut stream) = stream.recv_features::().await?; @@ -104,7 +119,12 @@ async fn test_exchange_data() { }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; let stream = stream.send_header(StreamHeader::default()).await?; let mut stream = stream .send_features::(&StreamFeatures::default()) @@ -134,6 +154,7 @@ async fn test_clean_shutdown() { tokio::io::BufStream::new(lhs), "jabber:client", StreamHeader::default(), + Timeouts::tight(), ) .await?; let (_, mut stream) = stream.recv_features::().await?; @@ -146,7 +167,12 @@ async fn test_clean_shutdown() { }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; let stream = stream.send_header(StreamHeader::default()).await?; let mut stream = stream .send_features::(&StreamFeatures::default()) @@ -172,6 +198,7 @@ async fn test_exchange_data_stream_reset_and_shutdown() { tokio::io::BufStream::new(lhs), "jabber:client", StreamHeader::default(), + Timeouts::tight(), ) .await?; let (_, mut stream) = stream.recv_features::().await?; @@ -215,7 +242,12 @@ async fn test_exchange_data_stream_reset_and_shutdown() { }); let responder = tokio::spawn(async move { - let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?; + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + Timeouts::tight(), + ) + .await?; let stream = stream.send_header(StreamHeader::default()).await?; let mut stream = stream .send_features::(&StreamFeatures::default()) @@ -262,3 +294,104 @@ async fn test_exchange_data_stream_reset_and_shutdown() { responder.await.unwrap().expect("responder failed"); initiator.await.unwrap().expect("initiator failed"); } + +#[tokio::test(start_paused = true)] +async fn test_emits_soft_timeout_after_silence() { + let (lhs, rhs) = tokio::io::duplex(65536); + + let client_timeouts = Timeouts { + read_timeout: Duration::new(300, 0), + response_timeout: Duration::new(15, 0), + }; + + // We do want to trigger only one set of timeouts, so we set the server + // timeouts much longer than the client timeouts + let server_timeouts = Timeouts { + read_timeout: Duration::new(900, 0), + response_timeout: Duration::new(15, 0), + }; + + let initiator = tokio::spawn(async move { + let stream = initiate_stream( + tokio::io::BufStream::new(lhs), + "jabber:client", + StreamHeader::default(), + client_timeouts, + ) + .await?; + let (_, mut stream) = stream.recv_features::().await?; + stream + .send(&Data { + contents: "hello".to_owned(), + }) + .await?; + match stream.next().await { + Some(Ok(Data { contents })) => assert_eq!(contents, "world!"), + other => panic!("unexpected stream message: {:?}", other), + } + // Here we prove that the stream doesn't see any data and also does + // not see the SoftTimeout too early. + // (Well, not exactly a proof: We only check until half of the read + // timeout, because that was easy to write and I deem it good enough.) + match tokio::time::timeout(client_timeouts.read_timeout / 2, stream.next()).await { + Err(_) => (), + Ok(ev) => panic!("early stream message (before soft timeout): {:?}", ev), + }; + // Now the next thing that happens is the soft timeout ... + match stream.next().await { + Some(Err(ReadError::SoftTimeout)) => (), + other => panic!("unexpected stream message: {:?}", other), + } + // Another check that the there is some time between soft and hard + // timeout. + match tokio::time::timeout(client_timeouts.response_timeout / 3, stream.next()).await { + Err(_) => (), + Ok(ev) => { + panic!("early stream message (before hard timeout): {:?}", ev); + } + }; + // ... and thereafter the hard timeout in form of an I/O error. + match stream.next().await { + Some(Err(ReadError::HardError(e))) if e.kind() == io::ErrorKind::TimedOut => (), + other => panic!("unexpected stream message: {:?}", other), + } + Ok::<_, io::Error>(()) + }); + + let responder = tokio::spawn(async move { + let stream = accept_stream( + tokio::io::BufStream::new(rhs), + "jabber:client", + server_timeouts, + ) + .await?; + let stream = stream.send_header(StreamHeader::default()).await?; + let mut stream = stream + .send_features::(&StreamFeatures::default()) + .await?; + stream + .send(&Data { + contents: "world!".to_owned(), + }) + .await?; + match stream.next().await { + Some(Ok(Data { contents })) => assert_eq!(contents, "hello"), + other => panic!("unexpected stream message: {:?}", other), + } + match stream.next().await { + Some(Err(ReadError::HardError(e))) if e.kind() == io::ErrorKind::InvalidData => { + match e.downcast::() { + // the initiator closes the stream by dropping it once the + // timeout trips, so we get a hard eof here. + Ok(rxml::Error::InvalidEof(_)) => (), + other => panic!("unexpected error: {:?}", other), + } + } + other => panic!("unexpected stream message: {:?}", other), + } + Ok::<_, io::Error>(()) + }); + + responder.await.unwrap().expect("responder failed"); + initiator.await.unwrap().expect("initiator failed"); +} diff --git a/xmpp/src/builder.rs b/xmpp/src/builder.rs index 8b0e769d..f8f29ce3 100644 --- a/xmpp/src/builder.rs +++ b/xmpp/src/builder.rs @@ -15,6 +15,7 @@ use tokio_xmpp::{ disco::{DiscoInfoResult, Feature, Identity}, ns, }, + xmlstream::Timeouts, Client as TokioXmppClient, }; @@ -51,6 +52,7 @@ pub struct ClientBuilder<'a, C: ServerConnector> { disco: (ClientType, String), features: Vec, resource: Option, + timeouts: Timeouts, } #[cfg(any(feature = "starttls-rust", feature = "starttls-native"))] @@ -80,6 +82,7 @@ impl ClientBuilder<'_, C> { disco: (ClientType::default(), String::from("tokio-xmpp")), features: vec![], resource: None, + timeouts: Timeouts::default(), } } @@ -109,6 +112,15 @@ impl ClientBuilder<'_, C> { self } + /// Configure the timeouts used. + /// + /// See [`Timeouts`] for more information on the semantics and the + /// defaults (which are used unless you call this method). + pub fn set_timeouts(mut self, timeouts: Timeouts) -> Self { + self.timeouts = timeouts; + self + } + pub fn enable_feature(mut self, feature: ClientFeature) -> Self { self.features.push(feature); self @@ -146,8 +158,12 @@ impl ClientBuilder<'_, C> { self.jid.clone().into() }; - let mut client = - TokioXmppClient::new_with_connector(jid, self.password, self.server_connector.clone()); + let mut client = TokioXmppClient::new_with_connector( + jid, + self.password, + self.server_connector.clone(), + self.timeouts, + ); client.set_reconnect(true); self.build_impl(client) }