mirror of
https://gitlab.com/xmpp-rs/xmpp-rs.git
synced 2024-07-12 22:21:53 +00:00
restructure auth code
This commit is contained in:
parent
bbadf75c01
commit
ce039d767e
4 changed files with 70 additions and 90 deletions
|
@ -1,37 +1,29 @@
|
|||
use futures::{sink, Async, Future, Poll, Stream};
|
||||
use std::mem::replace;
|
||||
use std::str::FromStr;
|
||||
use futures::{sink, Async, Future, Poll, Stream, future::{ok, err, IntoFuture}};
|
||||
use minidom::Element;
|
||||
use sasl::client::mechanisms::{Anonymous, Plain, Scram};
|
||||
use sasl::client::Mechanism;
|
||||
use sasl::common::scram::{Sha1, Sha256};
|
||||
use sasl::common::Credentials;
|
||||
use std::mem::replace;
|
||||
use std::str::FromStr;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use try_from::TryFrom;
|
||||
use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success};
|
||||
|
||||
use crate::stream_start::StreamStart;
|
||||
use crate::xmpp_codec::Packet;
|
||||
use crate::xmpp_stream::XMPPStream;
|
||||
use crate::{AuthError, Error, ProtocolError};
|
||||
|
||||
const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
|
||||
|
||||
pub struct ClientAuth<S: AsyncWrite> {
|
||||
state: ClientAuthState<S>,
|
||||
mechanism: Box<Mechanism>,
|
||||
pub struct ClientAuth<S: AsyncRead + AsyncWrite> {
|
||||
future: Box<Future<Item = XMPPStream<S>, Error = Error>>,
|
||||
}
|
||||
|
||||
enum ClientAuthState<S: AsyncWrite> {
|
||||
WaitSend(sink::Send<XMPPStream<S>>),
|
||||
WaitRecv(XMPPStream<S>),
|
||||
Start(StreamStart<S>),
|
||||
Invalid,
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite> ClientAuth<S> {
|
||||
impl<S: AsyncRead + AsyncWrite + 'static> ClientAuth<S> {
|
||||
pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> {
|
||||
let mechs: Vec<Box<Mechanism>> = vec![
|
||||
// TODO: Box::new(|| …
|
||||
Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
|
||||
Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
|
||||
Box::new(Plain::from_credentials(creds).unwrap()),
|
||||
|
@ -46,36 +38,74 @@ impl<S: AsyncWrite> ClientAuth<S> {
|
|||
.filter(|child| child.is("mechanism", NS_XMPP_SASL))
|
||||
.map(|mech_el| mech_el.text())
|
||||
.collect();
|
||||
// TODO: iter instead of collect()
|
||||
// println!("SASL mechanisms offered: {:?}", mech_names);
|
||||
|
||||
for mut mech in mechs {
|
||||
let name = mech.name().to_owned();
|
||||
for mut mechanism in mechs {
|
||||
let name = mechanism.name().to_owned();
|
||||
if mech_names.iter().any(|name1| *name1 == name) {
|
||||
// println!("SASL mechanism selected: {:?}", name);
|
||||
let initial = mech.initial().map_err(AuthError::Sasl)?;
|
||||
let mut this = ClientAuth {
|
||||
state: ClientAuthState::Invalid,
|
||||
mechanism: mech,
|
||||
};
|
||||
let mechanism = XMPPMechanism::from_str(&name).map_err(ProtocolError::Parsers)?;
|
||||
this.send(
|
||||
stream,
|
||||
Auth {
|
||||
mechanism,
|
||||
data: initial,
|
||||
},
|
||||
);
|
||||
return Ok(this);
|
||||
let initial = mechanism.initial().map_err(AuthError::Sasl)?;
|
||||
let mechanism_name = XMPPMechanism::from_str(&name).map_err(ProtocolError::Parsers)?;
|
||||
|
||||
let send_initial = Box::new(stream.send_stanza(Auth {
|
||||
mechanism: mechanism_name,
|
||||
data: initial,
|
||||
}))
|
||||
.map_err(Error::Io);
|
||||
let future = Box::new(send_initial.and_then(
|
||||
|stream| Self::handle_challenge(stream, mechanism)
|
||||
).and_then(
|
||||
|stream| stream.restart()
|
||||
));
|
||||
return Ok(ClientAuth {
|
||||
future,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Err(AuthError::NoMechanism)?
|
||||
}
|
||||
|
||||
fn send<N: Into<Element>>(&mut self, stream: XMPPStream<S>, nonza: N) {
|
||||
let send = stream.send_stanza(nonza);
|
||||
|
||||
self.state = ClientAuthState::WaitSend(send);
|
||||
fn handle_challenge(stream: XMPPStream<S>, mut mechanism: Box<Mechanism>) -> Box<Future<Item = XMPPStream<S>, Error = Error>> {
|
||||
Box::new(
|
||||
stream.into_future()
|
||||
.map_err(|(e, _stream)| e.into())
|
||||
.and_then(|(stanza, stream)| {
|
||||
match stanza {
|
||||
Some(Packet::Stanza(stanza)) => {
|
||||
if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
|
||||
let response = mechanism
|
||||
.response(&challenge.data);
|
||||
Box::new(
|
||||
response
|
||||
.map_err(|e| AuthError::Sasl(e).into())
|
||||
.into_future()
|
||||
.and_then(|response| {
|
||||
// Send response and loop
|
||||
stream.send_stanza(Response { data: response })
|
||||
.map_err(Error::Io)
|
||||
.and_then(|stream| Self::handle_challenge(stream, mechanism))
|
||||
})
|
||||
)
|
||||
} else if let Ok(_) = Success::try_from(stanza.clone()) {
|
||||
Box::new(ok(stream))
|
||||
} else if let Ok(failure) = Failure::try_from(stanza.clone()) {
|
||||
Box::new(err(Error::Auth(AuthError::Fail(failure.defined_condition))))
|
||||
} else {
|
||||
// ignore and loop
|
||||
println!("Ignore: {:?}", stanza);
|
||||
Self::handle_challenge(stream, mechanism)
|
||||
}
|
||||
}
|
||||
Some(_) => {
|
||||
// ignore and loop
|
||||
Self::handle_challenge(stream, mechanism)
|
||||
}
|
||||
None => Box::new(err(Error::Disconnected))
|
||||
}
|
||||
})
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -84,58 +114,6 @@ impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
|
|||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
let state = replace(&mut self.state, ClientAuthState::Invalid);
|
||||
|
||||
match state {
|
||||
ClientAuthState::WaitSend(mut send) => match send.poll() {
|
||||
Ok(Async::Ready(stream)) => {
|
||||
self.state = ClientAuthState::WaitRecv(stream);
|
||||
self.poll()
|
||||
}
|
||||
Ok(Async::NotReady) => {
|
||||
self.state = ClientAuthState::WaitSend(send);
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
Err(e) => Err(e)?,
|
||||
},
|
||||
ClientAuthState::WaitRecv(mut stream) => match stream.poll() {
|
||||
Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
|
||||
if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
|
||||
let response = self
|
||||
.mechanism
|
||||
.response(&challenge.data)
|
||||
.map_err(AuthError::Sasl)?;
|
||||
self.send(stream, Response { data: response });
|
||||
self.poll()
|
||||
} else if let Ok(_) = Success::try_from(stanza.clone()) {
|
||||
let start = stream.restart();
|
||||
self.state = ClientAuthState::Start(start);
|
||||
self.poll()
|
||||
} else if let Ok(failure) = Failure::try_from(stanza) {
|
||||
Err(AuthError::Fail(failure.defined_condition))?
|
||||
} else {
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
}
|
||||
Ok(Async::Ready(_event)) => {
|
||||
// println!("ClientAuth ignore {:?}", _event);
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
Ok(_) => {
|
||||
self.state = ClientAuthState::WaitRecv(stream);
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
Err(e) => Err(ProtocolError::Parser(e))?,
|
||||
},
|
||||
ClientAuthState::Start(mut start) => match start.poll() {
|
||||
Ok(Async::Ready(stream)) => Ok(Async::Ready(stream)),
|
||||
Ok(Async::NotReady) => {
|
||||
self.state = ClientAuthState::Start(start);
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
},
|
||||
ClientAuthState::Invalid => unreachable!(),
|
||||
}
|
||||
self.future.poll()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -104,7 +104,7 @@ impl Client {
|
|||
StartTlsClient::from_stream(stream)
|
||||
}
|
||||
|
||||
fn auth<S: AsyncRead + AsyncWrite>(
|
||||
fn auth<S: AsyncRead + AsyncWrite + 'static>(
|
||||
stream: xmpp_stream::XMPPStream<S>,
|
||||
username: String,
|
||||
password: String,
|
||||
|
|
|
@ -26,6 +26,8 @@ pub enum Error {
|
|||
Auth(AuthError),
|
||||
/// TLS error
|
||||
Tls(TlsError),
|
||||
/// Connection closed
|
||||
Disconnected,
|
||||
/// Shoud never happen
|
||||
InvalidState,
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ use xml5ever::interface::Attribute;
|
|||
use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer};
|
||||
|
||||
/// Anything that can be sent or received on an XMPP/XML stream
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum Packet {
|
||||
/// `<stream:stream>` start tag
|
||||
StreamStart(HashMap<String, String>),
|
||||
|
|
Loading…
Reference in a new issue