diff --git a/examples/client.rs b/examples/client.rs index 36ed4af3..9e12a4a3 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -4,19 +4,18 @@ use xmpp::jid::Jid; use xmpp::client::ClientBuilder; use xmpp::plugins::messaging::{MessagingPlugin, MessageEvent}; use xmpp::plugins::presence::{PresencePlugin, Show}; -use xmpp::sasl::mechanisms::{Scram, Sha1}; use std::env; fn main() { let jid: Jid = env::var("JID").unwrap().parse().unwrap(); - let mut client = ClientBuilder::new(jid.clone()).connect().unwrap(); + let pass = env::var("PASS").unwrap(); + let mut client = ClientBuilder::new(jid.clone()) + .password(pass) + .connect() + .unwrap(); client.register_plugin(MessagingPlugin::new()); client.register_plugin(PresencePlugin::new()); - let pass = env::var("PASS").unwrap(); - let name = jid.node.clone().expect("JID requires a node"); - client.connect(&mut Scram::::new(name, pass).unwrap()).unwrap(); - client.bind().unwrap(); client.plugin::().set_presence(Show::Available, None).unwrap(); loop { let event = client.next_event().unwrap(); diff --git a/src/client.rs b/src/client.rs index 475cd2e5..fff412e6 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,7 +5,8 @@ use ns; use plugin::{Plugin, PluginProxyBinding}; use event::AbstractEvent; use connection::{Connection, C2S}; -use sasl::SaslMechanism; +use sasl::{SaslMechanism, SaslCredentials, SaslSecret}; +use sasl::mechanisms::{Plain, Scram, Sha1, Sha256}; use base64; @@ -15,15 +16,18 @@ use xml::reader::XmlEvent as ReaderEvent; use std::sync::mpsc::{Receiver, channel}; +use std::collections::HashSet; + /// Struct that should be moved somewhere else and cleaned up. #[derive(Debug)] pub struct StreamFeatures { - pub sasl_mechanisms: Option>, + pub sasl_mechanisms: Option>, } /// A builder for `Client`s. pub struct ClientBuilder { jid: Jid, + credentials: Option, host: Option, port: u16, } @@ -33,6 +37,7 @@ impl ClientBuilder { pub fn new(jid: Jid) -> ClientBuilder { ClientBuilder { jid: jid, + credentials: None, host: None, port: 5222, } @@ -50,21 +55,35 @@ impl ClientBuilder { self } + /// Sets the password to use. + pub fn password>(mut self, password: P) -> ClientBuilder { + self.credentials = Some(SaslCredentials { + username: self.jid.node.clone().expect("JID has no node"), + secret: SaslSecret::Password(password.into()), + channel_binding: None, + }); + self + } + /// Connects to the server and returns a `Client` when succesful. pub fn connect(self) -> Result { let host = &self.host.unwrap_or(self.jid.domain.clone()); + // TODO: channel binding let mut transport = SslTransport::connect(host, self.port)?; C2S::init(&mut transport, &self.jid.domain, "before_sasl")?; let (sender_out, sender_in) = channel(); let (dispatcher_out, dispatcher_in) = channel(); - Ok(Client { + let mut client = Client { jid: self.jid, transport: transport, plugins: Vec::new(), binding: PluginProxyBinding::new(sender_out, dispatcher_out), sender_in: sender_in, dispatcher_in: dispatcher_in, - }) + }; + client.connect(self.credentials.expect("can't connect without credentials"))?; + client.bind()?; + Ok(client) } } @@ -127,9 +146,23 @@ impl Client { Ok(()) } - /// Connects and authenticates using the specified SASL mechanism. - pub fn connect(&mut self, mechanism: &mut S) -> Result<(), Error> { - self.wait_for_features()?; + fn connect(&mut self, credentials: SaslCredentials) -> Result<(), Error> { + let features = self.wait_for_features()?; + let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?; + fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) } + // TODO: better way for selecting these, enabling anonymous auth + let mut mechanism: Box = if ms.contains("SCRAM-SHA-256") { + Box::new(Scram::::from_credentials(credentials).map_err(wrap_err)?) + } + else if ms.contains("SCRAM-SHA-1") { + Box::new(Scram::::from_credentials(credentials).map_err(wrap_err)?) + } + else if ms.contains("PLAIN") { + Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?) + } + else { + return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned()))); + }; let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?; let mut elem = Element::builder("auth") .ns(ns::SASL) @@ -180,7 +213,7 @@ impl Client { } } - pub fn bind(&mut self) -> Result<(), Error> { + fn bind(&mut self) -> Result<(), Error> { let mut elem = Element::builder("iq") .attr("id", "bind") .attr("type", "set") @@ -223,9 +256,9 @@ impl Client { sasl_mechanisms: None, }; if let Some(ms) = n.get_child("mechanisms", ns::SASL) { - let mut res = Vec::new(); + let mut res = HashSet::new(); for cld in ms.children() { - res.push(cld.text()); + res.insert(cld.text()); } features.sasl_mechanisms = Some(res); } diff --git a/src/sasl/mechanisms/anonymous.rs b/src/sasl/mechanisms/anonymous.rs index f54dd1fe..32538e5b 100644 --- a/src/sasl/mechanisms/anonymous.rs +++ b/src/sasl/mechanisms/anonymous.rs @@ -1,6 +1,6 @@ //! Provides the SASL "ANONYMOUS" mechanism. -use sasl::SaslMechanism; +use sasl::{SaslMechanism, SaslCredentials, SaslSecret}; pub struct Anonymous; @@ -12,4 +12,13 @@ impl Anonymous { impl SaslMechanism for Anonymous { fn name(&self) -> &str { "ANONYMOUS" } + + fn from_credentials(credentials: SaslCredentials) -> Result { + if let SaslSecret::None = credentials.secret { + Ok(Anonymous) + } + else { + Err("the anonymous sasl mechanism requires no credentials".to_owned()) + } + } } diff --git a/src/sasl/mechanisms/plain.rs b/src/sasl/mechanisms/plain.rs index 651d224f..b3739a87 100644 --- a/src/sasl/mechanisms/plain.rs +++ b/src/sasl/mechanisms/plain.rs @@ -1,6 +1,6 @@ //! Provides the SASL "PLAIN" mechanism. -use sasl::SaslMechanism; +use sasl::{SaslMechanism, SaslCredentials, SaslSecret}; pub struct Plain { username: String, @@ -19,6 +19,15 @@ impl Plain { impl SaslMechanism for Plain { fn name(&self) -> &str { "PLAIN" } + fn from_credentials(credentials: SaslCredentials) -> Result { + if let SaslSecret::Password(password) = credentials.secret { + Ok(Plain::new(credentials.username, password)) + } + else { + Err("PLAIN requires a password".to_owned()) + } + } + fn initial(&mut self) -> Result, String> { let mut auth = Vec::new(); auth.push(0); diff --git a/src/sasl/mechanisms/scram.rs b/src/sasl/mechanisms/scram.rs index f6f6d865..41fb2c52 100644 --- a/src/sasl/mechanisms/scram.rs +++ b/src/sasl/mechanisms/scram.rs @@ -2,7 +2,7 @@ use base64; -use sasl::SaslMechanism; +use sasl::{SaslMechanism, SaslCredentials, SaslSecret}; use error::Error; @@ -172,6 +172,22 @@ impl SaslMechanism for Scram { &self.name } + fn from_credentials(credentials: SaslCredentials) -> Result, String> { + if let SaslSecret::Password(password) = credentials.secret { + if let Some(binding) = credentials.channel_binding { + Scram::new_with_channel_binding(credentials.username, password, binding) + .map_err(|_| "can't generate nonce".to_owned()) + } + else { + Scram::new(credentials.username, password) + .map_err(|_| "can't generate nonce".to_owned()) + } + } + else { + Err("SCRAM requires a password".to_owned()) + } + } + fn initial(&mut self) -> Result, String> { let mut gs2_header = Vec::new(); if let Some(_) = self.channel_binding { diff --git a/src/sasl/mod.rs b/src/sasl/mod.rs index 55ee130d..25d4d3a0 100644 --- a/src/sasl/mod.rs +++ b/src/sasl/mod.rs @@ -1,9 +1,23 @@ //! Provides the `SaslMechanism` trait and some implementations. +pub struct SaslCredentials { + pub username: String, + pub secret: SaslSecret, + pub channel_binding: Option>, +} + +pub enum SaslSecret { + None, + Password(String), +} + pub trait SaslMechanism { /// The name of the mechanism. fn name(&self) -> &str; + /// Creates this mechanism from `SaslCredentials`. + fn from_credentials(credentials: SaslCredentials) -> Result where Self: Sized; + /// Provides initial payload of the SASL mechanism. fn initial(&mut self) -> Result, String> { Ok(Vec::new())