auth: clarify + optimize

This commit is contained in:
Astro 2018-12-20 21:17:56 +01:00
parent 11cc7f183a
commit 78f74c6338

View file

@ -1,5 +1,6 @@
use std::mem::replace; use std::mem::replace;
use std::str::FromStr; use std::str::FromStr;
use std::collections::HashSet;
use futures::{sink, Async, Future, Poll, Stream, future::{ok, err, IntoFuture}}; use futures::{sink, Async, Future, Poll, Stream, future::{ok, err, IntoFuture}};
use minidom::Element; use minidom::Element;
use sasl::client::mechanisms::{Anonymous, Plain, Scram}; use sasl::client::mechanisms::{Anonymous, Plain, Scram};
@ -22,15 +23,14 @@ pub struct ClientAuth<S: AsyncRead + AsyncWrite> {
impl<S: AsyncRead + AsyncWrite + 'static> ClientAuth<S> { impl<S: AsyncRead + AsyncWrite + 'static> ClientAuth<S> {
pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> { pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> {
let mechs: Vec<Box<Mechanism>> = vec![ let local_mechs: Vec<Box<Fn() -> Box<Mechanism>>> = vec![
// TODO: Box::new(|| … Box::new(|| Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap())),
Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()), Box::new(|| Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap())),
Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()), Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())),
Box::new(Plain::from_credentials(creds).unwrap()), Box::new(|| Box::new(Anonymous::new())),
Box::new(Anonymous::new()),
]; ];
let mech_names: Vec<String> = stream let remote_mechs: HashSet<String> = stream
.stream_features .stream_features
.get_child("mechanisms", NS_XMPP_SASL) .get_child("mechanisms", NS_XMPP_SASL)
.ok_or(AuthError::NoMechanism)? .ok_or(AuthError::NoMechanism)?
@ -38,15 +38,12 @@ impl<S: AsyncRead + AsyncWrite + 'static> ClientAuth<S> {
.filter(|child| child.is("mechanism", NS_XMPP_SASL)) .filter(|child| child.is("mechanism", NS_XMPP_SASL))
.map(|mech_el| mech_el.text()) .map(|mech_el| mech_el.text())
.collect(); .collect();
// TODO: iter instead of collect()
// println!("SASL mechanisms offered: {:?}", mech_names);
for mut mechanism in mechs { for local_mech in local_mechs {
let name = mechanism.name().to_owned(); let mut mechanism = local_mech();
if mech_names.iter().any(|name1| *name1 == name) { if remote_mechs.contains(mechanism.name()) {
// println!("SASL mechanism selected: {:?}", name);
let initial = mechanism.initial().map_err(AuthError::Sasl)?; let initial = mechanism.initial().map_err(AuthError::Sasl)?;
let mechanism_name = XMPPMechanism::from_str(&name).map_err(ProtocolError::Parsers)?; let mechanism_name = XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?;
let send_initial = Box::new(stream.send_stanza(Auth { let send_initial = Box::new(stream.send_stanza(Auth {
mechanism: mechanism_name, mechanism: mechanism_name,