From 6e5f86632f78f40589f90d8af4a00b794bfc6d19 Mon Sep 17 00:00:00 2001 From: Astro Date: Fri, 14 Jul 2017 01:58:25 +0200 Subject: [PATCH] xmpp_codec: add remedies for truncated utf8 --- src/xmpp_codec.rs | 113 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 110 insertions(+), 3 deletions(-) diff --git a/src/xmpp_codec.rs b/src/xmpp_codec.rs index 2551cc4..1069372 100644 --- a/src/xmpp_codec.rs +++ b/src/xmpp_codec.rs @@ -56,6 +56,7 @@ pub enum Packet { pub struct XMPPCodec { parser: xml::Parser, root: Option, + buf: Vec, } impl XMPPCodec { @@ -63,6 +64,7 @@ impl XMPPCodec { XMPPCodec { parser: xml::Parser::new(), root: None, + buf: vec![], } } } @@ -72,15 +74,40 @@ impl Decoder for XMPPCodec { type Error = Error; fn decode(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { - match from_utf8(buf.take().as_ref()) { + let buf1: Box> = + if self.buf.len() > 0 && buf.len() > 0 { + let mut prefix = std::mem::replace(&mut self.buf, vec![]); + prefix.extend_from_slice(buf.take().as_ref()); + Box::new(prefix) + } else { + Box::new(buf.take()) + }; + let buf1 = buf1.as_ref().as_ref(); + match from_utf8(buf1) { Ok(s) => { if s.len() > 0 { println!("<< {}", s); self.parser.feed_str(s); } }, - Err(e) => - return Err(Error::new(ErrorKind::InvalidInput, e)), + // Remedies for truncated utf8 + Err(e) if e.valid_up_to() >= buf1.len() - 3 => { + // Prepare all the valid data + let mut b = BytesMut::with_capacity(e.valid_up_to()); + b.put(&buf1[0..e.valid_up_to()]); + + // Retry + let result = self.decode(&mut b); + + // Keep the tail back in + self.buf.extend_from_slice(&buf1[e.valid_up_to()..]); + + return result; + }, + Err(e) => { + println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1); + return Err(Error::new(ErrorKind::InvalidInput, e)); + }, } let mut new_root: Option = None; @@ -171,3 +198,83 @@ impl Encoder for XMPPCodec { .map_err(|_| Error::from(ErrorKind::InvalidInput)) } } + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn test_stream_start() { + let mut c = XMPPCodec::new(); + let mut b = BytesMut::with_capacity(1024); + b.put(r""); + let r = c.decode(&mut b); + assert!(match r { + Ok(Some(Packet::StreamStart(_))) => true, + _ => false, + }); + } + + #[test] + fn test_truncated_stanza() { + let mut c = XMPPCodec::new(); + let mut b = BytesMut::with_capacity(1024); + b.put(r""); + let r = c.decode(&mut b); + assert!(match r { + Ok(Some(Packet::StreamStart(_))) => true, + _ => false, + }); + + b.clear(); + b.put(r"ß true, + _ => false, + }); + + b.clear(); + b.put(r">"); + let r = c.decode(&mut b); + assert!(match r { + Ok(Some(Packet::Stanza(ref el))) + if el.name == "test" + && el.content_str() == "ß" + => true, + _ => false, + }); + } + + #[test] + fn test_truncated_utf8() { + let mut c = XMPPCodec::new(); + let mut b = BytesMut::with_capacity(1024); + b.put(r""); + let r = c.decode(&mut b); + assert!(match r { + Ok(Some(Packet::StreamStart(_))) => true, + _ => false, + }); + + b.clear(); + b.put(&b"\xc3"[..]); + let r = c.decode(&mut b); + assert!(match r { + Ok(None) => true, + _ => false, + }); + + b.clear(); + b.put(&b"\x9f"[..]); + let r = c.decode(&mut b); + assert!(match r { + Ok(Some(Packet::Stanza(ref el))) + if el.name == "test" + && el.content_str() == "ß" + => true, + _ => false, + }); + } +}