diff --git a/src/xmpp_codec.rs b/src/xmpp_codec.rs index e6a1bd2c..808f2389 100644 --- a/src/xmpp_codec.rs +++ b/src/xmpp_codec.rs @@ -9,8 +9,9 @@ use std::io::{Error, ErrorKind}; use std::collections::HashMap; use std::collections::vec_deque::VecDeque; use tokio_io::codec::{Encoder, Decoder}; -use minidom::Element; +use minidom::{Element, Node}; use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind}; +use xml5ever::interface::Attribute; use bytes::*; // const NS_XMLNS: &'static str = "http://www.w3.org/2000/xmlns/"; @@ -25,10 +26,11 @@ pub enum Packet { } struct ParserSink { - // Ready stanzas + // Ready stanzas, shared with XMPPCodec queue: Rc>>, // Parsing stack stack: Vec, + ns_stack: Vec, String>>, } impl ParserSink { @@ -36,21 +38,79 @@ impl ParserSink { ParserSink { queue, stack: vec![], + ns_stack: vec![], } } fn push_queue(&self, pkt: Packet) { - println!("push: {:?}", pkt); self.queue.borrow_mut().push_back(pkt); } + fn lookup_ns(&self, prefix: &Option) -> Option<&str> { + for nss in self.ns_stack.iter().rev() { + match nss.get(prefix) { + Some(ns) => return Some(ns), + None => (), + } + } + + None + } + fn handle_start_tag(&mut self, tag: Tag) { - let el = tag_to_element(&tag); + let mut nss = HashMap::new(); + let is_prefix_xmlns = |attr: &Attribute| attr.name.prefix.as_ref() + .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns")) + .unwrap_or(false); + for attr in &tag.attrs { + match attr.name.local.as_ref() { + "xmlns" => { + nss.insert(None, attr.value.as_ref().to_owned()); + }, + prefix if is_prefix_xmlns(attr) => { + nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned()); + }, + _ => (), + } + } + self.ns_stack.push(nss); + + let el = { + let mut el_builder = Element::builder(tag.name.local.as_ref()); + match self.lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned())) { + Some(el_ns) => el_builder = el_builder.ns(el_ns), + None => (), + } + for attr in &tag.attrs { + match attr.name.local.as_ref() { + "xmlns" => (), + _ if is_prefix_xmlns(attr) => (), + _ => { + el_builder = el_builder.attr( + attr.name.local.as_ref(), + attr.value.as_ref() + ); + }, + } + } + el_builder.build() + }; + + if self.stack.is_empty() { + let attrs = HashMap::from_iter( + el.attrs() + .map(|(name, value)| (name.to_owned(), value.to_owned())) + ); + self.push_queue(Packet::StreamStart(attrs)); + } + self.stack.push(el); } fn handle_end_tag(&mut self) { let el = self.stack.pop().unwrap(); + self.ns_stack.pop(); + match self.stack.len() { // 0 => @@ -66,22 +126,8 @@ impl ParserSink { } } -fn tag_to_element(tag: &Tag) -> Element { - let el_builder = Element::builder(tag.name.local.as_ref()) - .ns(tag.name.ns.as_ref()); - let el_builder = tag.attrs.iter().fold( - el_builder, - |el_builder, attr| el_builder.attr( - attr.name.local.as_ref(), - attr.value.as_ref() - ) - ); - el_builder.build() -} - impl TokenSink for ParserSink { fn process_token(&mut self, token: Token) { - println!("Token: {:?}", token); match token { Token::TagToken(tag) => match tag.kind { TagKind::StartTag => @@ -119,9 +165,14 @@ impl TokenSink for ParserSink { } pub struct XMPPCodec { + /// Outgoing + ns: Option, + /// Incoming parser: XmlTokenizer, - // For handling truncated utf8 + /// For handling incoming truncated utf8 + // TODO: optimize using tendrils? buf: Vec, + /// Shared with ParserSink queue: Rc>>, } @@ -132,6 +183,7 @@ impl XMPPCodec { // TODO: configure parser? let parser = XmlTokenizer::new(sink, Default::default()); XMPPCodec { + ns: None, parser, queue, buf: vec![], @@ -144,7 +196,6 @@ impl Decoder for XMPPCodec { type Error = Error; fn decode(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { - println!("decode {} bytes", buf.len()); let buf1: Box> = if self.buf.len() > 0 && buf.len() > 0 { let mut prefix = std::mem::replace(&mut self.buf, vec![]); @@ -196,32 +247,39 @@ impl Encoder for XMPPCodec { type Error = Error; fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { - println!("encode {:?}", item); match item { Packet::StreamStart(start_attrs) => { let mut buf = String::new(); write!(buf, "\n").unwrap(); print!(">> {}", buf); write!(dst, "{}", buf) - .map_err(|_| Error::from(ErrorKind::InvalidInput)) + .map_err(|e| Error::new(ErrorKind::InvalidInput, e)) }, Packet::Stanza(stanza) => { - println!(">> {:?}", stanza); - let mut root_ns = None; // TODO - stanza.write_to_inner(&mut dst.clone().writer(), &mut root_ns) - .map_err(|_| Error::from(ErrorKind::InvalidInput)) + let root_ns = self.ns.as_ref().map(|s| s.as_ref()); + write_element(&stanza, dst, root_ns) + .and_then(|_| { + println!(">> {:?}", dst); + Ok(()) + }) + .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e))) }, Packet::Text(text) => { - let escaped = escape(&text); - println!(">> {}", escaped); - write!(dst, "{}", escaped) - .map_err(|_| Error::from(ErrorKind::InvalidInput)) + write_text(&text, dst) + .and_then(|_| { + println!(">> {:?}", dst); + Ok(()) + }) + .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e))) }, // TODO: Implement all _ => Ok(()) @@ -229,6 +287,45 @@ impl Encoder for XMPPCodec { } } +pub fn write_text(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> { + write!(writer, "{}", text) +} + +// TODO: escape everything? +pub fn write_element(el: &Element, writer: &mut W, parent_ns: Option<&str>) -> Result<(), std::fmt::Error> { + write!(writer, "<")?; + write!(writer, "{}", el.name())?; + + if let Some(ref ns) = el.ns() { + if parent_ns.map(|s| s.as_ref()) != el.ns() { + write!(writer, " xmlns=\"{}\"", ns)?; + } + } + + for (key, value) in el.attrs() { + write!(writer, " {}=\"{}\"", key, value)?; + } + + if ! el.nodes().any(|_| true) { + write!(writer, " />")?; + return Ok(()) + } + + write!(writer, ">")?; + + for node in el.nodes() { + match node { + &Node::Element(ref child) => + write_element(child, writer, el.ns())?, + &Node::Text(ref text) => + write_text(text, writer)?, + } + } + + write!(writer, "", el.name())?; + Ok(()) +} + /// Copied from RustyXML for now pub fn escape(input: &str) -> String { let mut result = String::with_capacity(input.len());