simplify the API regarding authentication

This commit is contained in:
lumi 2017-02-25 06:49:13 +01:00
parent a0685e2dc6
commit 6579ce6563
6 changed files with 99 additions and 19 deletions

View file

@ -4,19 +4,18 @@ use xmpp::jid::Jid;
use xmpp::client::ClientBuilder; use xmpp::client::ClientBuilder;
use xmpp::plugins::messaging::{MessagingPlugin, MessageEvent}; use xmpp::plugins::messaging::{MessagingPlugin, MessageEvent};
use xmpp::plugins::presence::{PresencePlugin, Show}; use xmpp::plugins::presence::{PresencePlugin, Show};
use xmpp::sasl::mechanisms::{Scram, Sha1};
use std::env; use std::env;
fn main() { fn main() {
let jid: Jid = env::var("JID").unwrap().parse().unwrap(); let jid: Jid = env::var("JID").unwrap().parse().unwrap();
let mut client = ClientBuilder::new(jid.clone()).connect().unwrap(); let pass = env::var("PASS").unwrap();
let mut client = ClientBuilder::new(jid.clone())
.password(pass)
.connect()
.unwrap();
client.register_plugin(MessagingPlugin::new()); client.register_plugin(MessagingPlugin::new());
client.register_plugin(PresencePlugin::new()); client.register_plugin(PresencePlugin::new());
let pass = env::var("PASS").unwrap();
let name = jid.node.clone().expect("JID requires a node");
client.connect(&mut Scram::<Sha1>::new(name, pass).unwrap()).unwrap();
client.bind().unwrap();
client.plugin::<PresencePlugin>().set_presence(Show::Available, None).unwrap(); client.plugin::<PresencePlugin>().set_presence(Show::Available, None).unwrap();
loop { loop {
let event = client.next_event().unwrap(); let event = client.next_event().unwrap();

View file

@ -5,7 +5,8 @@ use ns;
use plugin::{Plugin, PluginProxyBinding}; use plugin::{Plugin, PluginProxyBinding};
use event::AbstractEvent; use event::AbstractEvent;
use connection::{Connection, C2S}; use connection::{Connection, C2S};
use sasl::SaslMechanism; use sasl::{SaslMechanism, SaslCredentials, SaslSecret};
use sasl::mechanisms::{Plain, Scram, Sha1, Sha256};
use base64; use base64;
@ -15,15 +16,18 @@ use xml::reader::XmlEvent as ReaderEvent;
use std::sync::mpsc::{Receiver, channel}; use std::sync::mpsc::{Receiver, channel};
use std::collections::HashSet;
/// Struct that should be moved somewhere else and cleaned up. /// Struct that should be moved somewhere else and cleaned up.
#[derive(Debug)] #[derive(Debug)]
pub struct StreamFeatures { pub struct StreamFeatures {
pub sasl_mechanisms: Option<Vec<String>>, pub sasl_mechanisms: Option<HashSet<String>>,
} }
/// A builder for `Client`s. /// A builder for `Client`s.
pub struct ClientBuilder { pub struct ClientBuilder {
jid: Jid, jid: Jid,
credentials: Option<SaslCredentials>,
host: Option<String>, host: Option<String>,
port: u16, port: u16,
} }
@ -33,6 +37,7 @@ impl ClientBuilder {
pub fn new(jid: Jid) -> ClientBuilder { pub fn new(jid: Jid) -> ClientBuilder {
ClientBuilder { ClientBuilder {
jid: jid, jid: jid,
credentials: None,
host: None, host: None,
port: 5222, port: 5222,
} }
@ -50,21 +55,35 @@ impl ClientBuilder {
self self
} }
/// Sets the password to use.
pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
self.credentials = Some(SaslCredentials {
username: self.jid.node.clone().expect("JID has no node"),
secret: SaslSecret::Password(password.into()),
channel_binding: None,
});
self
}
/// Connects to the server and returns a `Client` when succesful. /// Connects to the server and returns a `Client` when succesful.
pub fn connect(self) -> Result<Client, Error> { pub fn connect(self) -> Result<Client, Error> {
let host = &self.host.unwrap_or(self.jid.domain.clone()); let host = &self.host.unwrap_or(self.jid.domain.clone());
// TODO: channel binding
let mut transport = SslTransport::connect(host, self.port)?; let mut transport = SslTransport::connect(host, self.port)?;
C2S::init(&mut transport, &self.jid.domain, "before_sasl")?; C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
let (sender_out, sender_in) = channel(); let (sender_out, sender_in) = channel();
let (dispatcher_out, dispatcher_in) = channel(); let (dispatcher_out, dispatcher_in) = channel();
Ok(Client { let mut client = Client {
jid: self.jid, jid: self.jid,
transport: transport, transport: transport,
plugins: Vec::new(), plugins: Vec::new(),
binding: PluginProxyBinding::new(sender_out, dispatcher_out), binding: PluginProxyBinding::new(sender_out, dispatcher_out),
sender_in: sender_in, sender_in: sender_in,
dispatcher_in: dispatcher_in, dispatcher_in: dispatcher_in,
}) };
client.connect(self.credentials.expect("can't connect without credentials"))?;
client.bind()?;
Ok(client)
} }
} }
@ -127,9 +146,23 @@ impl Client {
Ok(()) Ok(())
} }
/// Connects and authenticates using the specified SASL mechanism. fn connect(&mut self, credentials: SaslCredentials) -> Result<(), Error> {
pub fn connect<S: SaslMechanism>(&mut self, mechanism: &mut S) -> Result<(), Error> { let features = self.wait_for_features()?;
self.wait_for_features()?; let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
// TODO: better way for selecting these, enabling anonymous auth
let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256") {
Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
}
else if ms.contains("SCRAM-SHA-1") {
Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
}
else if ms.contains("PLAIN") {
Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?)
}
else {
return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned())));
};
let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?; let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
let mut elem = Element::builder("auth") let mut elem = Element::builder("auth")
.ns(ns::SASL) .ns(ns::SASL)
@ -180,7 +213,7 @@ impl Client {
} }
} }
pub fn bind(&mut self) -> Result<(), Error> { fn bind(&mut self) -> Result<(), Error> {
let mut elem = Element::builder("iq") let mut elem = Element::builder("iq")
.attr("id", "bind") .attr("id", "bind")
.attr("type", "set") .attr("type", "set")
@ -223,9 +256,9 @@ impl Client {
sasl_mechanisms: None, sasl_mechanisms: None,
}; };
if let Some(ms) = n.get_child("mechanisms", ns::SASL) { if let Some(ms) = n.get_child("mechanisms", ns::SASL) {
let mut res = Vec::new(); let mut res = HashSet::new();
for cld in ms.children() { for cld in ms.children() {
res.push(cld.text()); res.insert(cld.text());
} }
features.sasl_mechanisms = Some(res); features.sasl_mechanisms = Some(res);
} }

View file

@ -1,6 +1,6 @@
//! Provides the SASL "ANONYMOUS" mechanism. //! Provides the SASL "ANONYMOUS" mechanism.
use sasl::SaslMechanism; use sasl::{SaslMechanism, SaslCredentials, SaslSecret};
pub struct Anonymous; pub struct Anonymous;
@ -12,4 +12,13 @@ impl Anonymous {
impl SaslMechanism for Anonymous { impl SaslMechanism for Anonymous {
fn name(&self) -> &str { "ANONYMOUS" } fn name(&self) -> &str { "ANONYMOUS" }
fn from_credentials(credentials: SaslCredentials) -> Result<Anonymous, String> {
if let SaslSecret::None = credentials.secret {
Ok(Anonymous)
}
else {
Err("the anonymous sasl mechanism requires no credentials".to_owned())
}
}
} }

View file

@ -1,6 +1,6 @@
//! Provides the SASL "PLAIN" mechanism. //! Provides the SASL "PLAIN" mechanism.
use sasl::SaslMechanism; use sasl::{SaslMechanism, SaslCredentials, SaslSecret};
pub struct Plain { pub struct Plain {
username: String, username: String,
@ -19,6 +19,15 @@ impl Plain {
impl SaslMechanism for Plain { impl SaslMechanism for Plain {
fn name(&self) -> &str { "PLAIN" } fn name(&self) -> &str { "PLAIN" }
fn from_credentials(credentials: SaslCredentials) -> Result<Plain, String> {
if let SaslSecret::Password(password) = credentials.secret {
Ok(Plain::new(credentials.username, password))
}
else {
Err("PLAIN requires a password".to_owned())
}
}
fn initial(&mut self) -> Result<Vec<u8>, String> { fn initial(&mut self) -> Result<Vec<u8>, String> {
let mut auth = Vec::new(); let mut auth = Vec::new();
auth.push(0); auth.push(0);

View file

@ -2,7 +2,7 @@
use base64; use base64;
use sasl::SaslMechanism; use sasl::{SaslMechanism, SaslCredentials, SaslSecret};
use error::Error; use error::Error;
@ -172,6 +172,22 @@ impl<S: ScramProvider> SaslMechanism for Scram<S> {
&self.name &self.name
} }
fn from_credentials(credentials: SaslCredentials) -> Result<Scram<S>, String> {
if let SaslSecret::Password(password) = credentials.secret {
if let Some(binding) = credentials.channel_binding {
Scram::new_with_channel_binding(credentials.username, password, binding)
.map_err(|_| "can't generate nonce".to_owned())
}
else {
Scram::new(credentials.username, password)
.map_err(|_| "can't generate nonce".to_owned())
}
}
else {
Err("SCRAM requires a password".to_owned())
}
}
fn initial(&mut self) -> Result<Vec<u8>, String> { fn initial(&mut self) -> Result<Vec<u8>, String> {
let mut gs2_header = Vec::new(); let mut gs2_header = Vec::new();
if let Some(_) = self.channel_binding { if let Some(_) = self.channel_binding {

View file

@ -1,9 +1,23 @@
//! Provides the `SaslMechanism` trait and some implementations. //! Provides the `SaslMechanism` trait and some implementations.
pub struct SaslCredentials {
pub username: String,
pub secret: SaslSecret,
pub channel_binding: Option<Vec<u8>>,
}
pub enum SaslSecret {
None,
Password(String),
}
pub trait SaslMechanism { pub trait SaslMechanism {
/// The name of the mechanism. /// The name of the mechanism.
fn name(&self) -> &str; fn name(&self) -> &str;
/// Creates this mechanism from `SaslCredentials`.
fn from_credentials(credentials: SaslCredentials) -> Result<Self, String> where Self: Sized;
/// Provides initial payload of the SASL mechanism. /// Provides initial payload of the SASL mechanism.
fn initial(&mut self) -> Result<Vec<u8>, String> { fn initial(&mut self) -> Result<Vec<u8>, String> {
Ok(Vec::new()) Ok(Vec::new())