use std; use std::default::Default; use std::iter::FromIterator; use std::cell::RefCell; use std::rc::Rc; use std::fmt::Write; use std::str::from_utf8; use std::io::{Error, ErrorKind}; use std::collections::HashMap; use std::collections::vec_deque::VecDeque; use tokio_io::codec::{Encoder, Decoder}; use minidom::{Element, Node}; use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind}; use xml5ever::interface::Attribute; use bytes::*; // const NS_XMLNS: &'static str = "http://www.w3.org/2000/xmlns/"; #[derive(Debug)] pub enum Packet { Error(Box), StreamStart(HashMap), Stanza(Element), Text(String), StreamEnd, } struct ParserSink { // Ready stanzas, shared with XMPPCodec queue: Rc>>, // Parsing stack stack: Vec, ns_stack: Vec, String>>, } impl ParserSink { pub fn new(queue: Rc>>) -> Self { ParserSink { queue, stack: vec![], ns_stack: vec![], } } fn push_queue(&self, pkt: Packet) { self.queue.borrow_mut().push_back(pkt); } fn lookup_ns(&self, prefix: &Option) -> Option<&str> { for nss in self.ns_stack.iter().rev() { match nss.get(prefix) { Some(ns) => return Some(ns), None => (), } } None } fn handle_start_tag(&mut self, tag: Tag) { let mut nss = HashMap::new(); let is_prefix_xmlns = |attr: &Attribute| attr.name.prefix.as_ref() .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns")) .unwrap_or(false); for attr in &tag.attrs { match attr.name.local.as_ref() { "xmlns" => { nss.insert(None, attr.value.as_ref().to_owned()); }, prefix if is_prefix_xmlns(attr) => { nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned()); }, _ => (), } } self.ns_stack.push(nss); let el = { let mut el_builder = Element::builder(tag.name.local.as_ref()); match self.lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned())) { Some(el_ns) => el_builder = el_builder.ns(el_ns), None => (), } for attr in &tag.attrs { match attr.name.local.as_ref() { "xmlns" => (), _ if is_prefix_xmlns(attr) => (), _ => { el_builder = el_builder.attr( attr.name.local.as_ref(), attr.value.as_ref() ); }, } } el_builder.build() }; if self.stack.is_empty() { let attrs = HashMap::from_iter( el.attrs() .map(|(name, value)| (name.to_owned(), value.to_owned())) ); self.push_queue(Packet::StreamStart(attrs)); } self.stack.push(el); } fn handle_end_tag(&mut self) { let el = self.stack.pop().unwrap(); self.ns_stack.pop(); match self.stack.len() { // 0 => self.push_queue(Packet::StreamEnd), // 1 => self.push_queue(Packet::Stanza(el)), len => { let parent = &mut self.stack[len - 1]; parent.append_child(el); }, } } } impl TokenSink for ParserSink { fn process_token(&mut self, token: Token) { match token { Token::TagToken(tag) => match tag.kind { TagKind::StartTag => self.handle_start_tag(tag), TagKind::EndTag => self.handle_end_tag(), TagKind::EmptyTag => { self.handle_start_tag(tag); self.handle_end_tag(); }, TagKind::ShortTag => self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, "ShortTag")))), }, Token::CharacterTokens(tendril) => match self.stack.len() { 0 | 1 => self.push_queue(Packet::Text(tendril.into())), len => { let el = &mut self.stack[len - 1]; el.append_text_node(tendril); }, }, Token::EOFToken => self.push_queue(Packet::StreamEnd), Token::ParseError(s) => { println!("ParseError: {:?}", s); self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, (*s).to_owned())))) }, _ => (), } } // fn end(&mut self) { // } } pub struct XMPPCodec { /// Outgoing ns: Option, /// Incoming parser: XmlTokenizer, /// For handling incoming truncated utf8 // TODO: optimize using tendrils? buf: Vec, /// Shared with ParserSink queue: Rc>>, } impl XMPPCodec { pub fn new() -> Self { let queue = Rc::new(RefCell::new((VecDeque::new()))); let sink = ParserSink::new(queue.clone()); // TODO: configure parser? let parser = XmlTokenizer::new(sink, Default::default()); XMPPCodec { ns: None, parser, queue, buf: vec![], } } } impl Decoder for XMPPCodec { type Item = Packet; type Error = Error; fn decode(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { let buf1: Box> = if self.buf.len() > 0 && buf.len() > 0 { let mut prefix = std::mem::replace(&mut self.buf, vec![]); prefix.extend_from_slice(buf.take().as_ref()); Box::new(prefix) } else { Box::new(buf.take()) }; let buf1 = buf1.as_ref().as_ref(); match from_utf8(buf1) { Ok(s) => { if s.len() > 0 { println!("<< {}", s); let tendril = FromIterator::from_iter(s.chars()); self.parser.feed(tendril); } }, // Remedies for truncated utf8 Err(e) if e.valid_up_to() >= buf1.len() - 3 => { // Prepare all the valid data let mut b = BytesMut::with_capacity(e.valid_up_to()); b.put(&buf1[0..e.valid_up_to()]); // Retry let result = self.decode(&mut b); // Keep the tail back in self.buf.extend_from_slice(&buf1[e.valid_up_to()..]); return result; }, Err(e) => { println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1); return Err(Error::new(ErrorKind::InvalidInput, e)); }, } let result = self.queue.borrow_mut().pop_front(); Ok(result) } fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { self.decode(buf) } } impl Encoder for XMPPCodec { type Item = Packet; type Error = Error; fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { Packet::StreamStart(start_attrs) => { let mut buf = String::new(); write!(buf, "\n").unwrap(); print!(">> {}", buf); write!(dst, "{}", buf) .map_err(|e| Error::new(ErrorKind::InvalidInput, e)) }, Packet::Stanza(stanza) => { let root_ns = self.ns.as_ref().map(|s| s.as_ref()); write_element(&stanza, dst, root_ns) .and_then(|_| { println!(">> {:?}", dst); Ok(()) }) .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e))) }, Packet::Text(text) => { write_text(&text, dst) .and_then(|_| { println!(">> {:?}", dst); Ok(()) }) .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e))) }, // TODO: Implement all _ => Ok(()) } } } pub fn write_text(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> { write!(writer, "{}", text) } // TODO: escape everything? pub fn write_element(el: &Element, writer: &mut W, parent_ns: Option<&str>) -> Result<(), std::fmt::Error> { write!(writer, "<")?; write!(writer, "{}", el.name())?; if let Some(ref ns) = el.ns() { if parent_ns.map(|s| s.as_ref()) != el.ns() { write!(writer, " xmlns=\"{}\"", ns)?; } } for (key, value) in el.attrs() { write!(writer, " {}=\"{}\"", key, value)?; } if ! el.nodes().any(|_| true) { write!(writer, " />")?; return Ok(()) } write!(writer, ">")?; for node in el.nodes() { match node { &Node::Element(ref child) => write_element(child, writer, el.ns())?, &Node::Text(ref text) => write_text(text, writer)?, } } write!(writer, "", el.name())?; Ok(()) } /// Copied from RustyXML for now pub fn escape(input: &str) -> String { let mut result = String::with_capacity(input.len()); for c in input.chars() { match c { '&' => result.push_str("&"), '<' => result.push_str("<"), '>' => result.push_str(">"), '\'' => result.push_str("'"), '"' => result.push_str("""), o => result.push(o) } } result } #[cfg(test)] mod tests { use super::*; use bytes::BytesMut; #[test] fn test_stream_start() { let mut c = XMPPCodec::new(); let mut b = BytesMut::with_capacity(1024); b.put(r""); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::StreamStart(_))) => true, _ => false, }); } #[test] fn test_truncated_stanza() { let mut c = XMPPCodec::new(); let mut b = BytesMut::with_capacity(1024); b.put(r""); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::StreamStart(_))) => true, _ => false, }); b.clear(); b.put(r"ß true, _ => false, }); b.clear(); b.put(r">"); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true, _ => false, }); } #[test] fn test_truncated_utf8() { let mut c = XMPPCodec::new(); let mut b = BytesMut::with_capacity(1024); b.put(r""); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::StreamStart(_))) => true, _ => false, }); b.clear(); b.put(&b"\xc3"[..]); let r = c.decode(&mut b); assert!(match r { Ok(None) => true, _ => false, }); b.clear(); b.put(&b"\x9f"[..]); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true, _ => false, }); } }