diff --git a/src/element.rs b/src/element.rs index 1f33020..7d72315 100644 --- a/src/element.rs +++ b/src/element.rs @@ -38,10 +38,16 @@ //! //! ``` +use std::collections::HashMap; use std::ops::Deref; use std::fmt::Debug; use std::marker::PhantomData; +use crate::{Client, ClientName}; +use crate::types::VariableAttr; +use crate::parsers::parse_variable; + +use jid::BareJid; use minidom::{Element, Node}; /// Namespaces used for Client entities @@ -193,25 +199,32 @@ impl PartialEq> for ScanNodes { /// changes the way the comparison is done. /// Also uses the custom ScanNode implementation. #[derive(Debug, Clone)] -pub struct ScanElement<'a> { +pub struct ScanElement<'a, 'b> { elem: &'a Element, + context: Option<&'b HashMap>, } -impl<'a> Deref for ScanElement<'a> { - type Target = Element; +impl<'a, 'b> Deref for ScanElement<'a, 'b> { + type Target = Element; fn deref(&self) -> &Self::Target { &self.elem } } -impl<'a> ScanElement<'a> { - pub fn new(elem: &'a Element) -> ScanElement { - Self { elem } - } +impl<'a> ScanElement<'a, 'static> { + pub fn new(elem: &'a Element) -> ScanElement { + Self { elem, context: None } + } } -impl<'a> PartialEq<&Element> for ScanElement<'a> { +impl <'a, 'b> ScanElement<'a, 'b> { + pub fn with_context(self, context: &'b HashMap) -> ScanElement<'a, 'b> { + Self { elem: self.elem, context: Some(context) } + } +} + +impl<'a, 'b> PartialEq<&Element> for ScanElement<'a, 'b> { fn eq(&self, other: &&Element) -> bool { let self_ns = self.elem.ns(); if self.elem.name() == other.name() && @@ -231,6 +244,28 @@ impl<'a> PartialEq<&Element> for ScanElement<'a> { continue; } + // Parse variables. If parsing fails, continue attr comparison. + // If context isn't set, skip this and continue attr comparison. + if let Ok((_, var)) = parse_variable(val.into()) && + let Some(context) = self.context { + let res = match var { + VariableAttr::FullJid(name) => match context.get(&name) { + Some(Client { jid, .. }) => String::from(jid.clone()), + _ => return false, + }, + VariableAttr::BareJid(name) => match context.get(&name) { + Some(Client { jid, .. }) => String::from(BareJid::from(jid.clone())), + _ => return false, + }, + }; + + if let Some(oval) = other.attr(attr) && res == oval { + continue; + } else { + return false; + } + } + match (attr, other.attr(attr)) { (attr, _) if attr == "scansion:strict" => continue, (_, None) => return false, @@ -493,4 +528,22 @@ mod tests { assert_eq!(scan1, &elem2); } + + #[test] + fn variables_from_context() { + let louise = Client::new(Jid::from_str("louise@example.com").unwrap(), "passwd"); + + let clients = { + let mut tmp = HashMap::new(); + tmp.insert(String::from("louise"), louise); + tmp + }; + + let elem1: Element = "" + .parse().unwrap(); + let elem2: Element = "".parse().unwrap(); + let scan1 = ScanElement::new(&elem1).with_context(&clients); + + assert_eq!(scan1, &elem2); + } } diff --git a/src/lib.rs b/src/lib.rs index 7c77136..71f9f2b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,10 +4,12 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +#![feature(let_chains)] + pub mod element; pub mod parsers; pub mod types; pub use element::ScanElement; pub use parsers::parse_spec; -pub use types::{Action, Client, Metadata, Spec}; +pub use types::{Action, Client, ClientName, Metadata, Spec}; diff --git a/src/parsers.rs b/src/parsers.rs index 343ee4b..fbc4795 100644 --- a/src/parsers.rs +++ b/src/parsers.rs @@ -4,7 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -use crate::types::{Action, Client, ClientName, Metadata, Spec}; +use crate::types::{Action, Client, ClientName, Metadata, VariableAttr, Spec}; use std::collections::HashMap; use std::str::FromStr; @@ -279,6 +279,25 @@ pub fn parse_spec(i: &str) -> Result { }) } +pub fn parse_variable(s: Span) -> IResult { + let (s, (_, name, attr, _)) = tuple(( + tag("${"), take_until_tags(vec![ + "'s full JID", + "'s JID", + ].into_iter(), + "}", + ), + alt((tag("'s full JID"), tag("'s JID"))), + tag("}"), + ))(s)?; + + Ok((s, match *attr.fragment() { + "'s full JID" => VariableAttr::FullJid(name.to_string()), + "'s JID" => VariableAttr::BareJid(name.to_string()), + _ => unreachable!(), + })) +} + #[cfg(test)] mod tests { use super::*; @@ -602,4 +621,29 @@ louise receives: }) ); } + + #[test] + fn parse_variable_attr() { + let buf1: Span = "${louise's full JID}".into(); + let buf2: Span = "${louise's JID}".into(); + let buf3: Span = "${louise's JID".into(); + + assert_eq!( + parse_variable(buf1).unwrap().1, + VariableAttr::FullJid(String::from("louise")), + ); + + assert_eq!( + parse_variable(buf2).unwrap().1, + VariableAttr::BareJid(String::from("louise")), + ); + + match parse_variable(buf3) { + Err(nom::Err::Error(nom::error::Error { input, .. })) => { + assert_eq!(input.location_offset(), 2); + assert_eq!(input.location_line(), 1); + } + err => panic!("Expected Err, found: {err:?}"), + } + } } diff --git a/src/types.rs b/src/types.rs index e163016..6bb5877 100644 --- a/src/types.rs +++ b/src/types.rs @@ -8,6 +8,12 @@ use std::collections::HashMap; use jid::Jid; +#[derive(Debug, PartialEq)] +pub enum VariableAttr { + FullJid(ClientName), + BareJid(ClientName), +} + #[derive(Debug, Clone, PartialEq)] pub struct Metadata { pub title: String,