Add AsyncServerConnector to AsyncClient to be able to support any stream

Unfortunately API breaking unless we do some export mangling
This commit is contained in:
moparisthebest 2023-12-29 01:09:33 -05:00
parent 3d9bdd6fe2
commit 3cab603a4c
No known key found for this signature in database
GPG key ID: 88C93BFE27BC8229
6 changed files with 86 additions and 57 deletions

View file

@ -5,10 +5,6 @@ use std::pin::Pin;
use std::task::Context;
use tokio::net::TcpStream;
use tokio::task::JoinHandle;
#[cfg(feature = "tls-native")]
use tokio_native_tls::TlsStream;
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
use tokio_rustls::client::TlsStream;
use xmpp_parsers::{ns, Element, Jid};
use super::auth::auth;
@ -17,8 +13,12 @@ use crate::event::Event;
use crate::happy_eyeballs::{connect_to_host, connect_with_srv};
use crate::starttls::starttls;
use crate::xmpp_codec::Packet;
use crate::xmpp_stream::{self, add_stanza_id};
use crate::xmpp_stream::{self, add_stanza_id, XMPPStream};
use crate::{Error, ProtocolError};
#[cfg(feature = "tls-native")]
use tokio_native_tls::TlsStream;
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
use tokio_rustls::client::TlsStream;
/// XMPP client connection and state
///
@ -26,13 +26,35 @@ use crate::{Error, ProtocolError};
///
/// This implements the `futures` crate's [`Stream`](#impl-Stream) and
/// [`Sink`](#impl-Sink<Packet>) traits.
pub struct Client {
config: Config,
state: ClientState,
pub struct Client<C: ServerConnector> {
config: Config<C>,
state: ClientState<C::Stream>,
reconnect: bool,
// TODO: tls_required=true
}
/// XMPP client configuration
#[derive(Clone, Debug)]
pub struct Config<C> {
/// jid of the account
pub jid: Jid,
/// password of the account
pub password: String,
/// server configuration for the account
pub server: C,
}
/// Trait called to connect to an XMPP server, perhaps called multiple times
pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
/// The type of Stream this ServerConnector produces
type Stream: AsyncReadAndWrite;
/// This must return the connection ready to login, ie if starttls is involved, after TLS has been started, and then after the <stream headers are exchanged
fn connect(
&self,
jid: &Jid,
) -> impl std::future::Future<Output = Result<XMPPStream<Self::Stream>, Error>> + Send;
}
/// XMPP server connection configuration
#[derive(Clone, Debug)]
pub enum ServerConfig {
@ -48,27 +70,46 @@ pub enum ServerConfig {
},
}
/// XMPP client configuration
#[derive(Clone, Debug)]
pub struct Config {
/// jid of the account
pub jid: Jid,
/// password of the account
pub password: String,
/// server configuration for the account
pub server: ServerConfig,
impl ServerConnector for ServerConfig {
type Stream = TlsStream<TcpStream>;
async fn connect(&self, jid: &Jid) -> Result<XMPPStream<Self::Stream>, Error> {
// TCP connection
let tcp_stream = match self {
ServerConfig::UseSrv => {
connect_with_srv(jid.domain_str(), "_xmpp-client._tcp", 5222).await?
}
ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?,
};
// Unencryped XMPPStream
let xmpp_stream =
xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
.await?;
if xmpp_stream.stream_features.can_starttls() {
// TlsStream
let tls_stream = starttls(xmpp_stream).await?;
// Encrypted XMPPStream
xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
.await
} else {
return Err(Error::Protocol(ProtocolError::NoTls));
}
}
}
type XMPPStream = xmpp_stream::XMPPStream<TlsStream<TcpStream>>;
/// trait used by XMPPStream type
pub trait AsyncReadAndWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send {}
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send> AsyncReadAndWrite for T {}
enum ClientState {
enum ClientState<S: AsyncReadAndWrite> {
Invalid,
Disconnected,
Connecting(JoinHandle<Result<XMPPStream, Error>>),
Connected(XMPPStream),
Connecting(JoinHandle<Result<XMPPStream<S>, Error>>),
Connected(XMPPStream<S>),
}
impl Client {
impl Client<ServerConfig> {
/// Start a new XMPP client
///
/// Start polling the returned instance so that it will connect
@ -81,9 +122,11 @@ impl Client {
};
Self::new_with_config(config)
}
}
impl<C: ServerConnector> Client<C> {
/// Start a new client given that the JID is already parsed.
pub fn new_with_config(config: Config) -> Self {
pub fn new_with_config(config: Config<C>) -> Self {
let connect = tokio::spawn(Self::connect(
config.server.clone(),
config.jid.clone(),
@ -105,35 +148,14 @@ impl Client {
}
async fn connect(
server: ServerConfig,
server: C,
jid: Jid,
password: String,
) -> Result<XMPPStream, Error> {
) -> Result<XMPPStream<C::Stream>, Error> {
let username = jid.node_str().unwrap();
let password = password;
// TCP connection
let tcp_stream = match server {
ServerConfig::UseSrv => {
connect_with_srv(jid.domain_str(), "_xmpp-client._tcp", 5222).await?
}
ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), port).await?,
};
// Unencryped XMPPStream
let xmpp_stream =
xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
.await?;
let xmpp_stream = if xmpp_stream.stream_features.can_starttls() {
// TlsStream
let tls_stream = starttls(xmpp_stream).await?;
// Encrypted XMPPStream
xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
.await?
} else {
return Err(Error::Protocol(ProtocolError::NoTls));
};
let xmpp_stream = server.connect(&jid).await?;
let creds = Credentials::default()
.with_username(username)
@ -180,7 +202,7 @@ impl Client {
///
/// In an `async fn` you may want to use this with `use
/// futures::stream::StreamExt;`
impl Stream for Client {
impl<C: ServerConnector> Stream for Client<C> {
type Item = Event;
/// Low-level read on the XMPP stream, allowing the underlying
@ -297,7 +319,7 @@ impl Stream for Client {
/// Outgoing XMPP packets
///
/// See `send_stanza()` for an `async fn`
impl Sink<Packet> for Client {
impl<C: ServerConnector> Sink<Packet> for Client<C> {
type Error = Error;
fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {

View file

@ -19,8 +19,11 @@ mod happy_eyeballs;
pub mod stream_features;
pub mod xmpp_stream;
pub use client::{
async_client::Client as AsyncClient, async_client::Config as AsyncConfig,
async_client::ServerConfig as AsyncServerConfig, simple_client::Client as SimpleClient,
async_client::{
AsyncReadAndWrite, Client as AsyncClient, Config as AsyncConfig,
ServerConfig as AsyncServerConfig, ServerConnector as AsyncServerConnector,
},
simple_client::Client as SimpleClient,
};
mod component;
pub use crate::component::Component;

View file

@ -21,7 +21,7 @@ log = "0.4"
reqwest = { version = "0.11.8", features = ["stream"] }
tokio-util = { version = "0.7", features = ["codec"] }
# same repository dependencies
tokio-xmpp = { version = "3.4", path = "../tokio-xmpp" }
tokio-xmpp = { version = "3.4", path = "../tokio-xmpp", default-features = false }
[dev-dependencies]
env_logger = { version = "0.10", default-features = false, features = ["auto-color", "humantime"] }
@ -31,5 +31,7 @@ name = "hello_bot"
required-features = ["avatars"]
[features]
default = ["avatars"]
default = ["avatars", "tls-native"]
tls-native = ["tokio-xmpp/tls-native"]
tls-rust = ["tokio-xmpp/tls-rust"]
avatars = []

View file

@ -8,10 +8,9 @@ use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
pub use tokio_xmpp::parsers;
use tokio_xmpp::parsers::{disco::DiscoInfoResult, message::MessageType};
use tokio_xmpp::AsyncClient as TokioXmppClient;
pub use tokio_xmpp::{BareJid, Element, FullJid, Jid};
use crate::{event_loop, message, muc, upload, Error, Event, RoomNick};
use crate::{event_loop, message, muc, upload, Error, Event, RoomNick, TokioXmppClient};
pub struct Agent {
pub(crate) client: TokioXmppClient,

View file

@ -10,10 +10,10 @@ use tokio_xmpp::{
disco::{DiscoInfoResult, Feature, Identity},
ns,
},
AsyncClient as TokioXmppClient, BareJid, Jid,
BareJid, Jid,
};
use crate::{Agent, ClientFeature};
use crate::{Agent, ClientFeature, TokioXmppClient};
#[derive(Debug)]
pub enum ClientType {

View file

@ -7,6 +7,7 @@
#![deny(bare_trait_objects)]
pub use tokio_xmpp::parsers;
use tokio_xmpp::{AsyncClient, AsyncServerConfig};
pub use tokio_xmpp::{BareJid, Element, FullJid, Jid};
#[macro_use]
extern crate log;
@ -31,6 +32,8 @@ pub use builder::{ClientBuilder, ClientType};
pub use event::Event;
pub use feature::ClientFeature;
type TokioXmppClient = AsyncClient<AsyncServerConfig>;
pub type Error = tokio_xmpp::Error;
pub type Id = Option<String>;
pub type RoomNick = String;