From 4266368a98958daf4cd6c88dc6fe93ed80d4b850 Mon Sep 17 00:00:00 2001 From: xmppftw Date: Wed, 21 Jun 2023 18:30:25 +0200 Subject: [PATCH] JIDs now have typed and stringy methods for node/domain/resource access Jid now has typed with_resource and stringy with_resource_str Jid now has is_full, is_bare --- jid/src/lib.rs | 152 +++++++++++++++++++++---- jid/src/parts.rs | 44 ++++++- tokio-xmpp/src/client/async_client.rs | 4 +- tokio-xmpp/src/client/bind.rs | 2 +- tokio-xmpp/src/client/simple_client.rs | 4 +- tokio-xmpp/src/starttls.rs | 2 +- tokio-xmpp/src/stream_start.rs | 2 +- xmpp/src/lib.rs | 6 +- 8 files changed, 177 insertions(+), 39 deletions(-) diff --git a/jid/src/lib.rs b/jid/src/lib.rs index 8159b76..74a2a3f 100644 --- a/jid/src/lib.rs +++ b/jid/src/lib.rs @@ -34,7 +34,6 @@ use core::num::NonZeroU16; use std::convert::TryFrom; use std::fmt; use std::str::FromStr; -use stringprep::resourceprep; #[cfg(feature = "serde")] use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; @@ -103,9 +102,9 @@ impl Jid { /// # fn main() -> Result<(), Error> { /// let jid = Jid::new("node@domain/resource")?; /// - /// assert_eq!(jid.node(), Some("node")); - /// assert_eq!(jid.domain(), "domain"); - /// assert_eq!(jid.resource(), Some("resource")); + /// assert_eq!(jid.node_str(), Some("node")); + /// assert_eq!(jid.domain_str(), "domain"); + /// assert_eq!(jid.resource_str(), Some("resource")); /// # Ok(()) /// # } /// ``` @@ -133,22 +132,51 @@ impl Jid { } } - /// The optional node part of the JID. - pub fn node(&self) -> Option<&str> { + /// The optional node part of the JID, as a [`NodePart`] + pub fn node(&self) -> Option { + match self { + Jid::Bare(BareJid { inner }) | Jid::Full(FullJid { inner }) => { + inner.node().map(|s| NodePart::new_unchecked(s)) + } + } + } + + /// The optional node part of the JID, as a stringy reference + pub fn node_str(&self) -> Option<&str> { match self { Jid::Bare(BareJid { inner }) | Jid::Full(FullJid { inner }) => inner.node(), } } - /// The domain part of the JID. - pub fn domain(&self) -> &str { + /// The domain part of the JID, as a [`DomainPart`] + pub fn domain(&self) -> DomainPart { + match self { + Jid::Bare(BareJid { inner }) | Jid::Full(FullJid { inner }) => { + DomainPart::new_unchecked(inner.domain()) + } + } + } + + /// The domain part of the JID, as a stringy reference + pub fn domain_str(&self) -> &str { match self { Jid::Bare(BareJid { inner }) | Jid::Full(FullJid { inner }) => inner.domain(), } } - /// The optional resource part of the JID. - pub fn resource(&self) -> Option<&str> { + /// The optional resource part of the JID, as a [`ResourcePart`]. It is guaranteed to be present + /// when the JID is a Full variant, which you can check with [`Jid::is_full`]. + pub fn resource(&self) -> Option { + match self { + Jid::Bare(BareJid { inner }) | Jid::Full(FullJid { inner }) => { + inner.resource().map(|s| ResourcePart::new_unchecked(s)) + } + } + } + + /// The optional resource of the Jabber ID. It is guaranteed to be present when the JID is + /// a Full variant, which you can check with [`Jid::is_full`]. + pub fn resource_str(&self) -> Option<&str> { match self { Jid::Bare(BareJid { inner }) | Jid::Full(FullJid { inner }) => inner.resource(), } @@ -169,6 +197,19 @@ impl Jid { Jid::Bare(jid) => jid, } } + + /// Checks if the JID contains a [`FullJid`] + pub fn is_full(&self) -> bool { + match self { + Self::Full(_) => true, + Self::Bare(_) => false, + } + } + + /// Checks if the JID contains a [`BareJid`] + pub fn is_bare(&self) -> bool { + !self.is_full() + } } impl TryFrom for FullJid { @@ -488,22 +529,23 @@ impl BareJid { self.inner.domain() } - /// Constructs a [`FullJid`] from the bare JID, by specifying a `resource`. + /// Constructs a [`BareJid`] from the bare JID, by specifying a [`ResourcePart`]. + /// If you'd like to specify a stringy resource, use [`BareJid::with_resource_str`] instead. /// /// # Examples /// /// ``` - /// use jid::BareJid; + /// use jid::{BareJid, ResourcePart}; /// + /// let resource = ResourcePart::new("resource").unwrap(); /// let bare = BareJid::new("node@domain").unwrap(); - /// let full = bare.with_resource("resource").unwrap(); + /// let full = bare.with_resource(&resource); /// /// assert_eq!(full.node(), Some("node")); /// assert_eq!(full.domain(), "domain"); /// assert_eq!(full.resource(), "resource"); /// ``` - pub fn with_resource(&self, resource: &str) -> Result { - let resource = resourceprep(resource).map_err(|_| Error::ResourcePrep)?; + pub fn with_resource(&self, resource: &ResourcePart) -> FullJid { let slash = NonZeroU16::new(self.inner.normalized.len() as u16); let normalized = format!("{}/{resource}", self.inner.normalized); let inner = InnerJid { @@ -511,7 +553,28 @@ impl BareJid { at: self.inner.at, slash, }; - Ok(FullJid { inner }) + + FullJid { inner } + } + + /// Constructs a [`FullJid`] from the bare JID, by specifying a stringy `resource`. + /// If your resource has already been parsed into a [`ResourcePart`], use [`BareJid::with_resource`]. + /// + /// # Examples + /// + /// ``` + /// use jid::BareJid; + /// + /// let bare = BareJid::new("node@domain").unwrap(); + /// let full = bare.with_resource_str("resource").unwrap(); + /// + /// assert_eq!(full.node(), Some("node")); + /// assert_eq!(full.domain(), "domain"); + /// assert_eq!(full.resource(), "resource"); + /// ``` + pub fn with_resource_str(&self, resource: &str) -> Result { + let resource = ResourcePart::new(resource)?; + Ok(self.with_resource(&resource)) } } @@ -634,24 +697,51 @@ mod tests { } #[test] - fn bare_to_full_jid() { + fn bare_to_full_jid_str() { assert_eq!( - BareJid::new("a@b.c").unwrap().with_resource("d").unwrap(), + BareJid::new("a@b.c") + .unwrap() + .with_resource_str("d") + .unwrap(), FullJid::new("a@b.c/d").unwrap() ); } #[test] - fn node_from_jid() { + fn bare_to_full_jid() { assert_eq!( - Jid::Full(FullJid::new("a@b.c/d").unwrap()).node(), - Some("a"), - ); + BareJid::new("a@b.c") + .unwrap() + .with_resource(&ResourcePart::new("d").unwrap()), + FullJid::new("a@b.c/d").unwrap() + ) + } + + #[test] + fn node_from_jid() { + let jid = Jid::new("a@b.c/d").unwrap(); + + assert_eq!(jid.node_str(), Some("a"),); + + assert_eq!(jid.node(), Some(NodePart::new("a").unwrap())); } #[test] fn domain_from_jid() { - assert_eq!(Jid::Bare(BareJid::new("a@b.c").unwrap()).domain(), "b.c"); + let jid = Jid::new("a@b.c").unwrap(); + + assert_eq!(jid.domain_str(), "b.c"); + + assert_eq!(jid.domain(), DomainPart::new("b.c").unwrap()); + } + + #[test] + fn resource_from_jid() { + let jid = Jid::new("a@b.c/d").unwrap(); + + assert_eq!(jid.resource_str(), Some("d"),); + + assert_eq!(jid.resource(), Some(ResourcePart::new("d").unwrap())); } #[test] @@ -772,4 +862,20 @@ mod tests { let equiv = FullJid::new("test@☃.com/TestTM").unwrap(); assert_eq!(full, equiv); } + + #[test] + fn jid_from_parts() { + let node = NodePart::new("node").unwrap(); + let domain = DomainPart::new("domain").unwrap(); + let resource = ResourcePart::new("resource").unwrap(); + + let jid = Jid::from_parts(Some(&node), &domain, Some(&resource)); + assert_eq!(jid, Jid::new("node@domain/resource").unwrap()); + + let barejid = BareJid::from_parts(Some(&node), &domain); + assert_eq!(barejid, BareJid::new("node@domain").unwrap()); + + let fulljid = FullJid::from_parts(Some(&node), &domain, &resource); + assert_eq!(fulljid, FullJid::new("node@domain/resource").unwrap()); + } } diff --git a/jid/src/parts.rs b/jid/src/parts.rs index 62b9556..ee33b7d 100644 --- a/jid/src/parts.rs +++ b/jid/src/parts.rs @@ -1,10 +1,8 @@ use stringprep::{nameprep, nodeprep, resourceprep}; -use crate::Error; +use std::fmt; -/// The [`NodePart`] is the optional part before the (optional) `@` in any [`Jid`], whether [`BareJid`] or [`FullJid`]. -#[derive(Clone, Debug, PartialEq, Hash, PartialOrd)] -pub struct NodePart(pub(crate) String); +use crate::Error; fn length_check(len: usize, error_empty: Error, error_too_long: Error) -> Result<(), Error> { if len == 0 { @@ -16,6 +14,10 @@ fn length_check(len: usize, error_empty: Error, error_too_long: Error) -> Result } } +/// The [`NodePart`] is the optional part before the (optional) `@` in any [`Jid`], whether [`BareJid`] or [`FullJid`]. +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct NodePart(pub(crate) String); + impl NodePart { /// Build a new [`NodePart`] from a string slice. Will fail in case of stringprep validation error. pub fn new(s: &str) -> Result { @@ -23,10 +25,20 @@ impl NodePart { length_check(node.len(), Error::NodeEmpty, Error::NodeTooLong)?; Ok(NodePart(node.to_string())) } + + pub(crate) fn new_unchecked(s: &str) -> NodePart { + NodePart(s.to_string()) + } +} + +impl fmt::Display for NodePart { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } } /// The [`DomainPart`] is the part between the (optional) `@` and the (optional) `/` in any [`Jid`], whether [`BareJid`] or [`FullJid`]. -#[derive(Clone, Debug, PartialEq, Hash, PartialOrd)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct DomainPart(pub(crate) String); impl DomainPart { @@ -36,10 +48,20 @@ impl DomainPart { length_check(domain.len(), Error::DomainEmpty, Error::DomainTooLong)?; Ok(DomainPart(domain.to_string())) } + + pub(crate) fn new_unchecked(s: &str) -> DomainPart { + DomainPart(s.to_string()) + } +} + +impl fmt::Display for DomainPart { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } } /// The [`ResourcePart`] is the optional part after the `/` in a [`Jid`]. It is mandatory in [`FullJid`]. -#[derive(Clone, Debug, PartialEq, Hash, PartialOrd)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ResourcePart(pub(crate) String); impl ResourcePart { @@ -49,4 +71,14 @@ impl ResourcePart { length_check(resource.len(), Error::ResourceEmpty, Error::ResourceTooLong)?; Ok(ResourcePart(resource.to_string())) } + + pub(crate) fn new_unchecked(s: &str) -> ResourcePart { + ResourcePart(s.to_string()) + } +} + +impl fmt::Display for ResourcePart { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } } diff --git a/tokio-xmpp/src/client/async_client.rs b/tokio-xmpp/src/client/async_client.rs index cb0f17f..18334a4 100644 --- a/tokio-xmpp/src/client/async_client.rs +++ b/tokio-xmpp/src/client/async_client.rs @@ -109,13 +109,13 @@ impl Client { jid: Jid, password: String, ) -> Result { - let username = jid.node().unwrap(); + let username = jid.node_str().unwrap(); let password = password; // TCP connection let tcp_stream = match server { ServerConfig::UseSrv => { - connect_with_srv(jid.domain(), "_xmpp-client._tcp", 5222).await? + connect_with_srv(jid.domain_str(), "_xmpp-client._tcp", 5222).await? } ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), port).await?, }; diff --git a/tokio-xmpp/src/client/bind.rs b/tokio-xmpp/src/client/bind.rs index 8d6cb3f..9172c12 100644 --- a/tokio-xmpp/src/client/bind.rs +++ b/tokio-xmpp/src/client/bind.rs @@ -17,7 +17,7 @@ pub async fn bind( if stream.stream_features.can_bind() { let resource = stream .jid - .resource() + .resource_str() .and_then(|resource| Some(resource.to_owned())); let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource)); stream.send_stanza(iq).await?; diff --git a/tokio-xmpp/src/client/simple_client.rs b/tokio-xmpp/src/client/simple_client.rs index 10c1b3b..a2fe69e 100644 --- a/tokio-xmpp/src/client/simple_client.rs +++ b/tokio-xmpp/src/client/simple_client.rs @@ -50,9 +50,9 @@ impl Client { } async fn connect(jid: Jid, password: String) -> Result { - let username = jid.node().unwrap(); + let username = jid.node_str().unwrap(); let password = password; - let domain = idna::domain_to_ascii(&jid.clone().domain()).map_err(|_| Error::Idna)?; + let domain = idna::domain_to_ascii(&jid.clone().domain_str()).map_err(|_| Error::Idna)?; // TCP connection let tcp_stream = connect_with_srv(&domain, "_xmpp-client._tcp", 5222).await?; diff --git a/tokio-xmpp/src/starttls.rs b/tokio-xmpp/src/starttls.rs index 4d5af3a..c355149 100644 --- a/tokio-xmpp/src/starttls.rs +++ b/tokio-xmpp/src/starttls.rs @@ -29,7 +29,7 @@ use crate::{Error, ProtocolError}; async fn get_tls_stream( xmpp_stream: XMPPStream, ) -> Result, Error> { - let domain = xmpp_stream.jid.domain().to_owned(); + let domain = xmpp_stream.jid.domain_str().to_owned(); let stream = xmpp_stream.into_inner(); let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) .connect(&domain, stream) diff --git a/tokio-xmpp/src/stream_start.rs b/tokio-xmpp/src/stream_start.rs index 563d7c2..06763bf 100644 --- a/tokio-xmpp/src/stream_start.rs +++ b/tokio-xmpp/src/stream_start.rs @@ -16,7 +16,7 @@ pub async fn start( ns: String, ) -> Result, Error> { let attrs = [ - ("to".to_owned(), jid.domain().to_owned()), + ("to".to_owned(), jid.domain_str().to_owned()), ("version".to_owned(), "1.0".to_owned()), ("xmlns".to_owned(), ns.clone()), ("xmlns:stream".to_owned(), ns::STREAM.to_owned()), diff --git a/xmpp/src/lib.rs b/xmpp/src/lib.rs index 16b7bb8..d9096e6 100644 --- a/xmpp/src/lib.rs +++ b/xmpp/src/lib.rs @@ -180,7 +180,7 @@ impl ClientBuilder<'_> { pub fn build(self) -> Agent { let jid: Jid = if let Some(resource) = &self.resource { - self.jid.with_resource(resource).unwrap().into() + self.jid.with_resource_str(resource).unwrap().into() } else { self.jid.clone().into() }; @@ -233,7 +233,7 @@ impl Agent { } let nick = nick.unwrap_or_else(|| self.default_nick.read().unwrap().clone()); - let room_jid = room.with_resource(&nick).unwrap(); + let room_jid = room.with_resource_str(&nick).unwrap(); let mut presence = Presence::new(PresenceType::None).with_to(room_jid); presence.add_payload(muc); presence.set_status(String::from(lang), String::from(status)); @@ -262,7 +262,7 @@ impl Agent { lang: &str, text: &str, ) { - let recipient: Jid = room.with_resource(&recipient).unwrap().into(); + let recipient: Jid = room.with_resource_str(&recipient).unwrap().into(); let mut message = Message::new(recipient).with_payload(MucUser::new()); message.type_ = MessageType::Chat; message