From af9e47f0a8ea231e273d7d3f1a0c038fd9d0663b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maxime=20=E2=80=9Cpep=E2=80=9D=20Buquet?= Date: Wed, 19 Apr 2023 13:21:03 +0200 Subject: [PATCH] ScanElement: Propagate context MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Maxime “pep” Buquet --- src/element.rs | 151 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 111 insertions(+), 40 deletions(-) diff --git a/src/element.rs b/src/element.rs index 25c781c..65cf125 100644 --- a/src/element.rs +++ b/src/element.rs @@ -56,6 +56,8 @@ pub static DEFAULT_NS: &str = "jabber:client"; /// Namespace used for scansion attributes pub static SCANSION_NS: &str = "https://matthewwild.co.uk/projects/scansion"; +pub type Context = HashMap; + /// Strict Comparison marker #[derive(Debug)] struct StrictComparison; @@ -71,21 +73,24 @@ enum NodeType { } #[derive(Debug, Clone)] -struct ScanNode { +struct ScanNode<'a> { node: Node, + context: Option<&'a Context>, } -impl ScanNode { - fn new(node: Node) -> ScanNode { - ScanNode { node } +impl<'a> ScanNode<'a> { + fn new(node: Node, context: Option<&'a Context>) -> ScanNode { + ScanNode { node, context } } } -impl PartialEq for ScanNode { +impl<'a> PartialEq for ScanNode<'a> { fn eq(&self, other: &Node) -> bool { match (&self.node, other) { (Node::Text(text1), Node::Text(text2)) => text1 == text2, - (Node::Element(elem1), Node::Element(elem2)) => ScanElement::new(&elem1) == elem2, + (Node::Element(elem1), Node::Element(elem2)) => { + ScanElement::new(&elem1).with_context(self.context) == elem2 + } _ => false, } } @@ -131,24 +136,49 @@ fn filter_whitespace_nodes(nodes: Vec) -> Vec { } #[derive(Debug)] -struct ScanNodes { +struct ScanNodes<'a, T: Debug> { nodes: Vec, + context: Option<&'a Context>, _strict: PhantomData, } -impl ScanNodes { - fn new(nodes: Vec) -> ScanNodes { +impl<'a> ScanNodes<'a, NonStrictComparison> { + fn new(nodes: Vec) -> ScanNodes<'a, NonStrictComparison> { Self { nodes, + context: None, + _strict: PhantomData, + } + } + + fn new_with_context( + nodes: Vec, + context: Option<&'a Context>, + ) -> ScanNodes<'a, NonStrictComparison> { + Self { + nodes, + context, _strict: PhantomData, } } } -impl ScanNodes { - fn new_strict(nodes: Vec) -> ScanNodes { +impl<'a> ScanNodes<'a, StrictComparison> { + fn new_strict(nodes: Vec) -> ScanNodes<'a, StrictComparison> { Self { nodes, + context: None, + _strict: PhantomData, + } + } + + fn new_strict_with_context( + nodes: Vec, + context: Option<&'a Context>, + ) -> ScanNodes<'a, StrictComparison> { + Self { + nodes, + context, _strict: PhantomData, } } @@ -157,11 +187,11 @@ impl ScanNodes { /// Tags with mixed significant text and children tags aren't valid in XMPP, so we know we can /// remove them. Text leaves are compared as is. When comparing strictly, elements must be exactly the /// same. -impl PartialEq> for ScanNodes { +impl<'a> PartialEq> for ScanNodes<'a, StrictComparison> { fn eq(&self, other: &Vec) -> bool { let filtered_self = filter_whitespace_nodes(self.nodes.clone()) .into_iter() - .map(ScanNode::new) + .map(|node| ScanNode::new(node, self.context)) .collect::>(); let filtered_other = filter_whitespace_nodes(other.clone()); @@ -172,7 +202,7 @@ impl PartialEq> for ScanNodes { /// Tags with mixed significant text and children tags aren't valid in XMPP, so we know we can /// remove them. Text leaves are compared as is. When doing non-strict comparison, the target /// element must have all attributes and children of the test element but it can have more. -impl PartialEq> for ScanNodes { +impl<'a> PartialEq> for ScanNodes<'a, NonStrictComparison> { fn eq(&self, other: &Vec) -> bool { let filtered_other = filter_whitespace_nodes(other.clone()); @@ -180,7 +210,7 @@ impl PartialEq> for ScanNodes { .into_iter() // Maps nodes to their comparison result .fold(true, |res, node| { - let scan = ScanNode::new(node); + let scan = ScanNode::new(node, self.context); res && filtered_other .iter() .find(|onode| &&scan == onode) @@ -197,7 +227,7 @@ impl PartialEq> for ScanNodes { #[derive(Debug, Clone)] pub struct ScanElement<'a, 'b> { elem: &'a Element, - context: Option<&'b HashMap>, + context: Option<&'b Context>, } impl<'a, 'b> Deref for ScanElement<'a, 'b> { @@ -218,10 +248,10 @@ impl<'a> ScanElement<'a, 'static> { } impl<'a, 'b> ScanElement<'a, 'b> { - pub fn with_context(self, context: &'b HashMap) -> ScanElement<'a, 'b> { + pub fn with_context(self, context: Option<&'b Context>) -> ScanElement<'a, 'b> { Self { elem: self.elem, - context: Some(context), + context, } } } @@ -247,24 +277,24 @@ impl<'a, 'b> PartialEq<&Element> for ScanElement<'a, 'b> { // Parse variables. If parsing fails, continue attr comparison. // If context isn't set, skip this and continue attr comparison. if let Ok((_, var)) = parse_variable(val.into()) && - let Some(context) = self.context { - let res = match var { - VariableAttr::FullJid(name) => match context.get(&name) { - Some(Client { jid, .. }) => String::from(jid.clone()), - _ => return false, - }, - VariableAttr::BareJid(name) => match context.get(&name) { - Some(Client { jid, .. }) => String::from(BareJid::from(jid.clone())), - _ => return false, - }, - }; + let Some(context) = self.context { + let res = match var { + VariableAttr::FullJid(name) => match context.get(&name) { + Some(Client { jid, .. }) => String::from(jid.clone()), + _ => return false, + }, + VariableAttr::BareJid(name) => match context.get(&name) { + Some(Client { jid, .. }) => String::from(BareJid::from(jid.clone())), + _ => return false, + }, + }; - if let Some(oval) = other.attr(attr) && res == oval { - continue; - } else { - return false; - } - } + if let Some(oval) = other.attr(attr) && res == oval { + continue; + } else { + return false; + } + } match (attr, other.attr(attr)) { (attr, _) if attr == "scansion:strict" => continue, @@ -287,10 +317,14 @@ impl<'a, 'b> PartialEq<&Element> for ScanElement<'a, 'b> { _ => (), } - let nodes = ScanNodes::new_strict(self.elem.nodes().cloned().collect()); + let nodes = ScanNodes::new_strict_with_context( + self.elem.nodes().cloned().collect(), + self.context, + ); nodes == onodes } else { - let nodes = ScanNodes::new(self.elem.nodes().cloned().collect()); + let nodes = + ScanNodes::new_with_context(self.elem.nodes().cloned().collect(), self.context); nodes == onodes } } else { @@ -577,12 +611,14 @@ mod tests { #[test] fn variables_from_context() { let louise = Client::new(Jid::from_str("louise@example.com").unwrap(), "passwd"); + let rosa_phone = Client::new(Jid::from_str("rosa@example.com/phone").unwrap(), "passwd"); - let clients = { + let clients = Some({ let mut tmp = HashMap::new(); tmp.insert(String::from("louise"), louise); + tmp.insert(String::from("rosa's phone"), rosa_phone); tmp - }; + }); let elem1: Element = "" .parse() @@ -590,7 +626,42 @@ mod tests { let elem2: Element = "" .parse() .unwrap(); - let scan1 = ScanElement::new(&elem1).with_context(&clients); + let scan1 = ScanElement::new(&elem1).with_context(clients.as_ref()); + + assert_eq!(scan1, &elem2); + + let elem3: Element = "" + .parse() + .unwrap(); + let elem4: Element = "" + .parse() + .unwrap(); + let scan3 = ScanElement::new(&elem3).with_context(clients.as_ref()); + + assert_eq!(scan3, &elem4); + } + + #[test] + fn variables_propagate_context() { + let louise = Client::new( + Jid::from_str("louise@example.com/device1").unwrap(), + "passwd", + ); + + let clients = Some({ + let mut tmp = HashMap::new(); + tmp.insert(String::from("louise"), louise); + tmp + }); + + let elem1: Element = "" + .parse() + .unwrap(); + let elem2: Element = + "" + .parse() + .unwrap(); + let scan1 = ScanElement::new(&elem1).with_context(clients.as_ref()); assert_eq!(scan1, &elem2); }