diff --git a/src/client.rs b/src/client.rs index 70bf176..45da512 100644 --- a/src/client.rs +++ b/src/client.rs @@ -121,46 +121,21 @@ impl Client { Ok(()) } - /// Connects using the specified SASL mechanism. + /// Connects and authenticates using the specified SASL mechanism. pub fn connect(&mut self, mechanism: &mut S) -> Result<(), Error> { - // TODO: this is very ugly - loop { - let e = self.transport.read_event().unwrap(); - match e { - ReaderEvent::StartElement { .. } => { - break; - }, - _ => (), - } + self.wait_for_features()?; + let auth = mechanism.initial(); + let mut elem = Element::builder("auth") + .ns(ns::SASL) + .attr("mechanism", S::name()) + .build(); + if !auth.is_empty() { + elem.append_text_node(base64::encode(&auth)); } - let mut did_sasl = false; + self.transport.write_element(&elem)?; loop { - let n = self.transport.read_element().unwrap(); - if n.is("features", ns::STREAM) { - if did_sasl { - let mut elem = Element::builder("iq") - .attr("id", "bind") - .attr("type", "set") - .build(); - let bind = Element::builder("bind") - .ns(ns::BIND) - .build(); - elem.append_child(bind); - self.transport.write_element(&elem)?; - } - else { - let auth = mechanism.initial(); - let mut elem = Element::builder("auth") - .ns(ns::SASL) - .attr("mechanism", "PLAIN") - .build(); - if !auth.is_empty() { - elem.append_text_node(base64::encode(&auth)); - } - self.transport.write_element(&elem)?; - } - } - else if n.is("challenge", ns::SASL) { + let n = self.transport.read_element()?; + if n.is("challenge", ns::SASL) { let text = n.text(); let challenge = if text == "" { Vec::new() @@ -178,27 +153,53 @@ impl Client { self.transport.write_element(&elem)?; } else if n.is("success", ns::SASL) { - did_sasl = true; self.transport.reset_stream(); C2S::init(&mut self.transport, &self.jid.domain, "after_sasl")?; - loop { - let e = self.transport.read_event()?; - match e { - ReaderEvent::StartElement { .. } => { - break; - }, - _ => (), - } - } + return self.bind(); } else if n.is("failure", ns::SASL) { let msg = n.text(); let inner = if msg == "" { None } else { Some(msg) }; return Err(Error::SaslError(inner)); } - else if n.is("iq", ns::CLIENT) && n.has_child("bind", ns::BIND) { + } + } + + fn bind(&mut self) -> Result<(), Error> { + self.wait_for_features(); + let mut elem = Element::builder("iq") + .attr("id", "bind") + .attr("type", "set") + .build(); + let bind = Element::builder("bind") + .ns(ns::BIND) + .build(); + elem.append_child(bind); + self.transport.write_element(&elem)?; + loop { + let n = self.transport.read_element()?; + if n.is("iq", ns::CLIENT) && n.has_child("bind", ns::BIND) { return Ok(()); } } } + + fn wait_for_features(&mut self) -> Result { + // TODO: this is very ugly + loop { + let e = self.transport.read_event()?; + match e { + ReaderEvent::StartElement { .. } => { + break; + }, + _ => (), + } + } + loop { + let n = self.transport.read_element()?; + if n.is("features", ns::STREAM) { + return Ok(n); + } + } + } }