diff --git a/examples/client.rs b/examples/client.rs index e144433..7a6e2f9 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -4,15 +4,17 @@ use xmpp::jid::Jid; use xmpp::client::ClientBuilder; use xmpp::plugins::messaging::{MessagingPlugin, MessageEvent}; use xmpp::plugins::presence::{PresencePlugin, Show}; +use xmpp::sasl::mechanisms::Plain; use std::env; fn main() { let jid: Jid = env::var("JID").unwrap().parse().unwrap(); - let mut client = ClientBuilder::new(jid).connect().unwrap(); + let mut client = ClientBuilder::new(jid.clone()).connect().unwrap(); client.register_plugin(MessagingPlugin::new()); client.register_plugin(PresencePlugin::new()); - client.connect_plain(&env::var("PASS").unwrap()).unwrap(); + let pass = env::var("PASS").unwrap(); + client.connect(&mut Plain::new(jid.node.clone().expect("JID requires a node"), pass)).unwrap(); client.plugin::().set_presence(Show::Available, None).unwrap(); loop { let event = client.next_event().unwrap(); diff --git a/src/client.rs b/src/client.rs index f8cd6fa..b5f8f4e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,7 +6,6 @@ use plugin::{Plugin, PluginProxyBinding}; use event::AbstractEvent; use connection::{Connection, C2S}; use sasl::SaslMechanism; -use sasl::mechanisms::Plain as SaslPlain; use base64; @@ -122,8 +121,8 @@ impl Client { Ok(()) } - /// Connects using SASL plain authentication. - pub fn connect_plain(&mut self, password: &str) -> Result<(), Error> { + /// Connects using the specified SASL mechanism. + pub fn connect(&mut self, mechanism: &mut S) -> Result<(), Error> { // TODO: this is very ugly loop { let e = self.transport.read_event().unwrap(); @@ -150,19 +149,36 @@ impl Client { self.transport.write_element(&elem)?; } else { - let name = self.jid.node.clone().expect("JID has no node"); - let mut plain = SaslPlain::new(name, password.to_owned()); - let auth = plain.initial(); - let elem = Element::builder("auth") - .text(base64::encode(&auth)) - .ns(ns::SASL) - .attr("mechanism", "PLAIN") - .build(); + let auth = mechanism.initial(); + let mut elem = Element::builder("auth") + .ns(ns::SASL) + .attr("mechanism", "PLAIN") + .build(); + if !auth.is_empty() { + elem.append_text_node(base64::encode(&auth)); + } self.transport.write_element(&elem)?; - did_sasl = true; } } + else if n.is("challenge", ns::SASL) { + let text = n.text(); + let challenge = if text == "" { + Vec::new() + } + else { + base64::decode(&text)? + }; + let response = mechanism.response(&challenge); + let mut elem = Element::builder("response") + .ns(ns::SASL) + .build(); + if !response.is_empty() { + elem.append_text_node(base64::encode(&response)); + } + self.transport.write_element(&elem)?; + } else if n.is("success", ns::SASL) { + did_sasl = true; self.transport.reset_stream(); C2S::init(&mut self.transport, &self.jid.domain, "after_sasl")?; loop { diff --git a/src/error.rs b/src/error.rs index d0f35df..1d8e373 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,6 +12,8 @@ use xml::writer::Error as EmitterError; use minidom::Error as MinidomError; +use base64::Base64Error; + /// An error which wraps a bunch of errors from different crates and the stdlib. #[derive(Debug)] pub enum Error { @@ -21,6 +23,7 @@ pub enum Error { HandshakeError(HandshakeError), OpenSslErrorStack(ErrorStack), MinidomError(MinidomError), + Base64Error(Base64Error), StreamError, EndOfDocument, } @@ -60,3 +63,9 @@ impl From for Error { Error::MinidomError(err) } } + +impl From for Error { + fn from(err: Base64Error) -> Error { + Error::Base64Error(err) + } +} diff --git a/src/sasl.rs b/src/sasl.rs index aa04b49..6a047ce 100644 --- a/src/sasl.rs +++ b/src/sasl.rs @@ -10,7 +10,7 @@ pub trait SaslMechanism { } /// Creates a response to the SASL challenge. - fn respond(&mut self, _challenge: &[u8]) -> Vec { + fn response(&mut self, _challenge: &[u8]) -> Vec { Vec::new() } }