Add AsyncServerConnector to AsyncClient to be able to support any stream

Unfortunately API breaking unless we do some export mangling
This commit is contained in:
moparisthebest 2023-12-29 01:09:33 -05:00
parent 3d9bdd6fe2
commit 3cab603a4c
No known key found for this signature in database
GPG key ID: 88C93BFE27BC8229
6 changed files with 86 additions and 57 deletions

View file

@ -5,10 +5,6 @@ use std::pin::Pin;
use std::task::Context; use std::task::Context;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
#[cfg(feature = "tls-native")]
use tokio_native_tls::TlsStream;
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
use tokio_rustls::client::TlsStream;
use xmpp_parsers::{ns, Element, Jid}; use xmpp_parsers::{ns, Element, Jid};
use super::auth::auth; use super::auth::auth;
@ -17,8 +13,12 @@ use crate::event::Event;
use crate::happy_eyeballs::{connect_to_host, connect_with_srv}; use crate::happy_eyeballs::{connect_to_host, connect_with_srv};
use crate::starttls::starttls; use crate::starttls::starttls;
use crate::xmpp_codec::Packet; use crate::xmpp_codec::Packet;
use crate::xmpp_stream::{self, add_stanza_id}; use crate::xmpp_stream::{self, add_stanza_id, XMPPStream};
use crate::{Error, ProtocolError}; use crate::{Error, ProtocolError};
#[cfg(feature = "tls-native")]
use tokio_native_tls::TlsStream;
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
use tokio_rustls::client::TlsStream;
/// XMPP client connection and state /// XMPP client connection and state
/// ///
@ -26,13 +26,35 @@ use crate::{Error, ProtocolError};
/// ///
/// This implements the `futures` crate's [`Stream`](#impl-Stream) and /// This implements the `futures` crate's [`Stream`](#impl-Stream) and
/// [`Sink`](#impl-Sink<Packet>) traits. /// [`Sink`](#impl-Sink<Packet>) traits.
pub struct Client { pub struct Client<C: ServerConnector> {
config: Config, config: Config<C>,
state: ClientState, state: ClientState<C::Stream>,
reconnect: bool, reconnect: bool,
// TODO: tls_required=true // TODO: tls_required=true
} }
/// XMPP client configuration
#[derive(Clone, Debug)]
pub struct Config<C> {
/// jid of the account
pub jid: Jid,
/// password of the account
pub password: String,
/// server configuration for the account
pub server: C,
}
/// Trait called to connect to an XMPP server, perhaps called multiple times
pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
/// The type of Stream this ServerConnector produces
type Stream: AsyncReadAndWrite;
/// This must return the connection ready to login, ie if starttls is involved, after TLS has been started, and then after the <stream headers are exchanged
fn connect(
&self,
jid: &Jid,
) -> impl std::future::Future<Output = Result<XMPPStream<Self::Stream>, Error>> + Send;
}
/// XMPP server connection configuration /// XMPP server connection configuration
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum ServerConfig { pub enum ServerConfig {
@ -48,27 +70,46 @@ pub enum ServerConfig {
}, },
} }
/// XMPP client configuration impl ServerConnector for ServerConfig {
#[derive(Clone, Debug)] type Stream = TlsStream<TcpStream>;
pub struct Config { async fn connect(&self, jid: &Jid) -> Result<XMPPStream<Self::Stream>, Error> {
/// jid of the account // TCP connection
pub jid: Jid, let tcp_stream = match self {
/// password of the account ServerConfig::UseSrv => {
pub password: String, connect_with_srv(jid.domain_str(), "_xmpp-client._tcp", 5222).await?
/// server configuration for the account }
pub server: ServerConfig, ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?,
};
// Unencryped XMPPStream
let xmpp_stream =
xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
.await?;
if xmpp_stream.stream_features.can_starttls() {
// TlsStream
let tls_stream = starttls(xmpp_stream).await?;
// Encrypted XMPPStream
xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
.await
} else {
return Err(Error::Protocol(ProtocolError::NoTls));
}
}
} }
type XMPPStream = xmpp_stream::XMPPStream<TlsStream<TcpStream>>; /// trait used by XMPPStream type
pub trait AsyncReadAndWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send {}
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send> AsyncReadAndWrite for T {}
enum ClientState { enum ClientState<S: AsyncReadAndWrite> {
Invalid, Invalid,
Disconnected, Disconnected,
Connecting(JoinHandle<Result<XMPPStream, Error>>), Connecting(JoinHandle<Result<XMPPStream<S>, Error>>),
Connected(XMPPStream), Connected(XMPPStream<S>),
} }
impl Client { impl Client<ServerConfig> {
/// Start a new XMPP client /// Start a new XMPP client
/// ///
/// Start polling the returned instance so that it will connect /// Start polling the returned instance so that it will connect
@ -81,9 +122,11 @@ impl Client {
}; };
Self::new_with_config(config) Self::new_with_config(config)
} }
}
impl<C: ServerConnector> Client<C> {
/// Start a new client given that the JID is already parsed. /// Start a new client given that the JID is already parsed.
pub fn new_with_config(config: Config) -> Self { pub fn new_with_config(config: Config<C>) -> Self {
let connect = tokio::spawn(Self::connect( let connect = tokio::spawn(Self::connect(
config.server.clone(), config.server.clone(),
config.jid.clone(), config.jid.clone(),
@ -105,35 +148,14 @@ impl Client {
} }
async fn connect( async fn connect(
server: ServerConfig, server: C,
jid: Jid, jid: Jid,
password: String, password: String,
) -> Result<XMPPStream, Error> { ) -> Result<XMPPStream<C::Stream>, Error> {
let username = jid.node_str().unwrap(); let username = jid.node_str().unwrap();
let password = password; let password = password;
// TCP connection let xmpp_stream = server.connect(&jid).await?;
let tcp_stream = match server {
ServerConfig::UseSrv => {
connect_with_srv(jid.domain_str(), "_xmpp-client._tcp", 5222).await?
}
ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), port).await?,
};
// Unencryped XMPPStream
let xmpp_stream =
xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
.await?;
let xmpp_stream = if xmpp_stream.stream_features.can_starttls() {
// TlsStream
let tls_stream = starttls(xmpp_stream).await?;
// Encrypted XMPPStream
xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
.await?
} else {
return Err(Error::Protocol(ProtocolError::NoTls));
};
let creds = Credentials::default() let creds = Credentials::default()
.with_username(username) .with_username(username)
@ -180,7 +202,7 @@ impl Client {
/// ///
/// In an `async fn` you may want to use this with `use /// In an `async fn` you may want to use this with `use
/// futures::stream::StreamExt;` /// futures::stream::StreamExt;`
impl Stream for Client { impl<C: ServerConnector> Stream for Client<C> {
type Item = Event; type Item = Event;
/// Low-level read on the XMPP stream, allowing the underlying /// Low-level read on the XMPP stream, allowing the underlying
@ -297,7 +319,7 @@ impl Stream for Client {
/// Outgoing XMPP packets /// Outgoing XMPP packets
/// ///
/// See `send_stanza()` for an `async fn` /// See `send_stanza()` for an `async fn`
impl Sink<Packet> for Client { impl<C: ServerConnector> Sink<Packet> for Client<C> {
type Error = Error; type Error = Error;
fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> { fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {

View file

@ -19,8 +19,11 @@ mod happy_eyeballs;
pub mod stream_features; pub mod stream_features;
pub mod xmpp_stream; pub mod xmpp_stream;
pub use client::{ pub use client::{
async_client::Client as AsyncClient, async_client::Config as AsyncConfig, async_client::{
async_client::ServerConfig as AsyncServerConfig, simple_client::Client as SimpleClient, AsyncReadAndWrite, Client as AsyncClient, Config as AsyncConfig,
ServerConfig as AsyncServerConfig, ServerConnector as AsyncServerConnector,
},
simple_client::Client as SimpleClient,
}; };
mod component; mod component;
pub use crate::component::Component; pub use crate::component::Component;

View file

@ -21,7 +21,7 @@ log = "0.4"
reqwest = { version = "0.11.8", features = ["stream"] } reqwest = { version = "0.11.8", features = ["stream"] }
tokio-util = { version = "0.7", features = ["codec"] } tokio-util = { version = "0.7", features = ["codec"] }
# same repository dependencies # same repository dependencies
tokio-xmpp = { version = "3.4", path = "../tokio-xmpp" } tokio-xmpp = { version = "3.4", path = "../tokio-xmpp", default-features = false }
[dev-dependencies] [dev-dependencies]
env_logger = { version = "0.10", default-features = false, features = ["auto-color", "humantime"] } env_logger = { version = "0.10", default-features = false, features = ["auto-color", "humantime"] }
@ -31,5 +31,7 @@ name = "hello_bot"
required-features = ["avatars"] required-features = ["avatars"]
[features] [features]
default = ["avatars"] default = ["avatars", "tls-native"]
tls-native = ["tokio-xmpp/tls-native"]
tls-rust = ["tokio-xmpp/tls-rust"]
avatars = [] avatars = []

View file

@ -8,10 +8,9 @@ use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
pub use tokio_xmpp::parsers; pub use tokio_xmpp::parsers;
use tokio_xmpp::parsers::{disco::DiscoInfoResult, message::MessageType}; use tokio_xmpp::parsers::{disco::DiscoInfoResult, message::MessageType};
use tokio_xmpp::AsyncClient as TokioXmppClient;
pub use tokio_xmpp::{BareJid, Element, FullJid, Jid}; pub use tokio_xmpp::{BareJid, Element, FullJid, Jid};
use crate::{event_loop, message, muc, upload, Error, Event, RoomNick}; use crate::{event_loop, message, muc, upload, Error, Event, RoomNick, TokioXmppClient};
pub struct Agent { pub struct Agent {
pub(crate) client: TokioXmppClient, pub(crate) client: TokioXmppClient,

View file

@ -10,10 +10,10 @@ use tokio_xmpp::{
disco::{DiscoInfoResult, Feature, Identity}, disco::{DiscoInfoResult, Feature, Identity},
ns, ns,
}, },
AsyncClient as TokioXmppClient, BareJid, Jid, BareJid, Jid,
}; };
use crate::{Agent, ClientFeature}; use crate::{Agent, ClientFeature, TokioXmppClient};
#[derive(Debug)] #[derive(Debug)]
pub enum ClientType { pub enum ClientType {

View file

@ -7,6 +7,7 @@
#![deny(bare_trait_objects)] #![deny(bare_trait_objects)]
pub use tokio_xmpp::parsers; pub use tokio_xmpp::parsers;
use tokio_xmpp::{AsyncClient, AsyncServerConfig};
pub use tokio_xmpp::{BareJid, Element, FullJid, Jid}; pub use tokio_xmpp::{BareJid, Element, FullJid, Jid};
#[macro_use] #[macro_use]
extern crate log; extern crate log;
@ -31,6 +32,8 @@ pub use builder::{ClientBuilder, ClientType};
pub use event::Event; pub use event::Event;
pub use feature::ClientFeature; pub use feature::ClientFeature;
type TokioXmppClient = AsyncClient<AsyncServerConfig>;
pub type Error = tokio_xmpp::Error; pub type Error = tokio_xmpp::Error;
pub type Id = Option<String>; pub type Id = Option<String>;
pub type RoomNick = String; pub type RoomNick = String;