From 09745829f19975455d5a8551e34f1671873776b2 Mon Sep 17 00:00:00 2001 From: Emmanuel Gil Peyrot Date: Tue, 25 Feb 2020 23:31:21 +0100 Subject: [PATCH 1/2] client: Remove Result from Mechanism::initial(). --- sasl/src/client/mechanisms/plain.rs | 4 ++-- sasl/src/client/mechanisms/scram.rs | 8 ++++---- sasl/src/client/mod.rs | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sasl/src/client/mechanisms/plain.rs b/sasl/src/client/mechanisms/plain.rs index 1c5bd1a3..08036e60 100644 --- a/sasl/src/client/mechanisms/plain.rs +++ b/sasl/src/client/mechanisms/plain.rs @@ -39,12 +39,12 @@ impl Mechanism for Plain { } } - fn initial(&mut self) -> Result, String> { + fn initial(&mut self) -> Vec { let mut auth = Vec::new(); auth.push(0); auth.extend(self.username.bytes()); auth.push(0); auth.extend(self.password.bytes()); - Ok(auth) + auth } } diff --git a/sasl/src/client/mechanisms/scram.rs b/sasl/src/client/mechanisms/scram.rs index a3bd35c9..f3c1d30a 100644 --- a/sasl/src/client/mechanisms/scram.rs +++ b/sasl/src/client/mechanisms/scram.rs @@ -93,7 +93,7 @@ impl Mechanism for Scram { } } - fn initial(&mut self) -> Result, String> { + fn initial(&mut self) -> Vec { let mut gs2_header = Vec::new(); gs2_header.extend(self.channel_binding.header()); let mut bare = Vec::new(); @@ -108,7 +108,7 @@ impl Mechanism for Scram { initial_message: bare, gs2_header: gs2_header, }; - Ok(data) + data } fn response(&mut self, challenge: &[u8]) -> Result, String> { @@ -206,7 +206,7 @@ mod tests { let server_final = b"v=rmF9pqV8S7suAoZWja4dJRkFsKQ="; let mut mechanism = Scram::::new_with_nonce(username, password, client_nonce.to_owned()); - let init = mechanism.initial().unwrap(); + let init = mechanism.initial(); assert_eq!( String::from_utf8(init.clone()).unwrap(), String::from_utf8(client_init[..].to_owned()).unwrap() @@ -231,7 +231,7 @@ mod tests { let server_final = b"v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="; let mut mechanism = Scram::::new_with_nonce(username, password, client_nonce.to_owned()); - let init = mechanism.initial().unwrap(); + let init = mechanism.initial(); assert_eq!( String::from_utf8(init.clone()).unwrap(), String::from_utf8(client_init[..].to_owned()).unwrap() diff --git a/sasl/src/client/mod.rs b/sasl/src/client/mod.rs index 2acf9cff..93463f65 100644 --- a/sasl/src/client/mod.rs +++ b/sasl/src/client/mod.rs @@ -11,8 +11,8 @@ pub trait Mechanism { Self: Sized; /// Provides initial payload of the SASL mechanism. - fn initial(&mut self) -> Result, String> { - Ok(Vec::new()) + fn initial(&mut self) -> Vec { + Vec::new() } /// Creates a response to the SASL challenge. From 7fd692346471d0321ab6f7bb0a685162dcbff49f Mon Sep 17 00:00:00 2001 From: Emmanuel Gil Peyrot Date: Fri, 15 May 2020 13:48:27 +0200 Subject: [PATCH 2/2] Use error structs for errors instead of plain strings. --- sasl/src/client/mechanisms/anonymous.rs | 6 +- sasl/src/client/mechanisms/plain.rs | 8 +- sasl/src/client/mechanisms/scram.rs | 32 +++--- sasl/src/client/mod.rs | 82 ++++++++++++- sasl/src/common/scram.rs | 77 ++++++++----- sasl/src/lib.rs | 43 +++++-- sasl/src/secret.rs | 15 ++- sasl/src/server/mechanisms/plain.rs | 16 +-- sasl/src/server/mechanisms/scram.rs | 31 ++--- sasl/src/server/mod.rs | 146 +++++++++++++++++++++++- 10 files changed, 357 insertions(+), 99 deletions(-) diff --git a/sasl/src/client/mechanisms/anonymous.rs b/sasl/src/client/mechanisms/anonymous.rs index 45361377..96b236a6 100644 --- a/sasl/src/client/mechanisms/anonymous.rs +++ b/sasl/src/client/mechanisms/anonymous.rs @@ -1,6 +1,6 @@ //! Provides the SASL "ANONYMOUS" mechanism. -use crate::client::Mechanism; +use crate::client::{Mechanism, MechanismError}; use crate::common::{Credentials, Secret}; /// A struct for the SASL ANONYMOUS mechanism. @@ -21,11 +21,11 @@ impl Mechanism for Anonymous { "ANONYMOUS" } - fn from_credentials(credentials: Credentials) -> Result { + fn from_credentials(credentials: Credentials) -> Result { if let Secret::None = credentials.secret { Ok(Anonymous) } else { - Err("the anonymous sasl mechanism requires no credentials".to_owned()) + Err(MechanismError::AnonymousRequiresNoCredentials) } } } diff --git a/sasl/src/client/mechanisms/plain.rs b/sasl/src/client/mechanisms/plain.rs index 08036e60..bc08fd85 100644 --- a/sasl/src/client/mechanisms/plain.rs +++ b/sasl/src/client/mechanisms/plain.rs @@ -1,6 +1,6 @@ //! Provides the SASL "PLAIN" mechanism. -use crate::client::Mechanism; +use crate::client::{Mechanism, MechanismError}; use crate::common::{Credentials, Identity, Password, Secret}; /// A struct for the SASL PLAIN mechanism. @@ -27,15 +27,15 @@ impl Mechanism for Plain { "PLAIN" } - fn from_credentials(credentials: Credentials) -> Result { + fn from_credentials(credentials: Credentials) -> Result { if let Secret::Password(Password::Plain(password)) = credentials.secret { if let Identity::Username(username) = credentials.identity { Ok(Plain::new(username, password)) } else { - Err("PLAIN requires a username".to_owned()) + Err(MechanismError::PlainRequiresUsername) } } else { - Err("PLAIN requires a plaintext password".to_owned()) + Err(MechanismError::PlainRequiresPlaintextPassword) } } diff --git a/sasl/src/client/mechanisms/scram.rs b/sasl/src/client/mechanisms/scram.rs index f3c1d30a..10f828d8 100644 --- a/sasl/src/client/mechanisms/scram.rs +++ b/sasl/src/client/mechanisms/scram.rs @@ -2,7 +2,7 @@ use base64; -use crate::client::Mechanism; +use crate::client::{Mechanism, MechanismError}; use crate::common::scram::{generate_nonce, ScramProvider}; use crate::common::{parse_frame, xor, ChannelBinding, Credentials, Identity, Password, Secret}; @@ -80,16 +80,16 @@ impl Mechanism for Scram { &self.name } - fn from_credentials(credentials: Credentials) -> Result, String> { + fn from_credentials(credentials: Credentials) -> Result, MechanismError> { if let Secret::Password(password) = credentials.secret { if let Identity::Username(username) = credentials.identity { Scram::new(username, password, credentials.channel_binding) - .map_err(|_| "can't generate nonce".to_owned()) + .map_err(|_| MechanismError::CannotGenerateNonce) } else { - Err("SCRAM requires a username".to_owned()) + Err(MechanismError::ScramRequiresUsername) } } else { - Err("SCRAM requires a password".to_owned()) + Err(MechanismError::ScramRequiresPassword) } } @@ -111,7 +111,7 @@ impl Mechanism for Scram { data } - fn response(&mut self, challenge: &[u8]) -> Result, String> { + fn response(&mut self, challenge: &[u8]) -> Result, MechanismError> { let next_state; let ret; match self.state { @@ -120,13 +120,13 @@ impl Mechanism for Scram { ref gs2_header, } => { let frame = - parse_frame(challenge).map_err(|_| "can't decode challenge".to_owned())?; + parse_frame(challenge).map_err(|_| MechanismError::CannotDecodeChallenge)?; let server_nonce = frame.get("r"); let salt = frame.get("s").and_then(|v| base64::decode(v).ok()); let iterations = frame.get("i").and_then(|v| v.parse().ok()); - let server_nonce = server_nonce.ok_or_else(|| "no server nonce".to_owned())?; - let salt = salt.ok_or_else(|| "no server salt".to_owned())?; - let iterations = iterations.ok_or_else(|| "no server iterations".to_owned())?; + let server_nonce = server_nonce.ok_or_else(|| MechanismError::NoServerNonce)?; + let salt = salt.ok_or_else(|| MechanismError::NoServerSalt)?; + let iterations = iterations.ok_or_else(|| MechanismError::NoServerIterations)?; // TODO: SASLprep let mut client_final_message_bare = Vec::new(); client_final_message_bare.extend(b"c="); @@ -159,15 +159,15 @@ impl Mechanism for Scram { ret = client_final_message; } _ => { - return Err("not in the right state to receive this response".to_owned()); + return Err(MechanismError::InvalidState); } } self.state = next_state; Ok(ret) } - fn success(&mut self, data: &[u8]) -> Result<(), String> { - let frame = parse_frame(data).map_err(|_| "can't decode success response".to_owned())?; + fn success(&mut self, data: &[u8]) -> Result<(), MechanismError> { + let frame = parse_frame(data).map_err(|_| MechanismError::CannotDecodeSuccessResponse)?; match self.state { ScramState::GotServerData { ref server_signature, @@ -176,13 +176,13 @@ impl Mechanism for Scram { if sig == *server_signature { Ok(()) } else { - Err("invalid signature in success response".to_owned()) + Err(MechanismError::InvalidSignatureInSuccessResponse) } } else { - Err("no signature in success response".to_owned()) + Err(MechanismError::NoSignatureInSuccessResponse) } } - _ => Err("not in the right state to get a success response".to_owned()), + _ => Err(MechanismError::InvalidState), } } } diff --git a/sasl/src/client/mod.rs b/sasl/src/client/mod.rs index 93463f65..621de878 100644 --- a/sasl/src/client/mod.rs +++ b/sasl/src/client/mod.rs @@ -1,4 +1,80 @@ +use crate::common::scram::DeriveError; use crate::common::Credentials; +use hmac::crypto_mac::InvalidKeyLength; +use std::fmt; + +#[derive(Debug, PartialEq)] +pub enum MechanismError { + AnonymousRequiresNoCredentials, + + PlainRequiresUsername, + PlainRequiresPlaintextPassword, + + CannotGenerateNonce, + ScramRequiresUsername, + ScramRequiresPassword, + + CannotDecodeChallenge, + NoServerNonce, + NoServerSalt, + NoServerIterations, + DeriveError(DeriveError), + InvalidKeyLength(InvalidKeyLength), + InvalidState, + + CannotDecodeSuccessResponse, + InvalidSignatureInSuccessResponse, + NoSignatureInSuccessResponse, +} + +impl From for MechanismError { + fn from(err: DeriveError) -> MechanismError { + MechanismError::DeriveError(err) + } +} + +impl From for MechanismError { + fn from(err: InvalidKeyLength) -> MechanismError { + MechanismError::InvalidKeyLength(err) + } +} + +impl fmt::Display for MechanismError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!( + fmt, + "{}", + match self { + MechanismError::AnonymousRequiresNoCredentials => + "ANONYMOUS mechanism requires no credentials", + + MechanismError::PlainRequiresUsername => "PLAIN requires a username", + MechanismError::PlainRequiresPlaintextPassword => + "PLAIN requires a plaintext password", + + MechanismError::CannotGenerateNonce => "can't generate nonce", + MechanismError::ScramRequiresUsername => "SCRAM requires a username", + MechanismError::ScramRequiresPassword => "SCRAM requires a password", + + MechanismError::CannotDecodeChallenge => "can't decode challenge", + MechanismError::NoServerNonce => "no server nonce", + MechanismError::NoServerSalt => "no server salt", + MechanismError::NoServerIterations => "no server iterations", + MechanismError::DeriveError(err) => return write!(fmt, "derive error: {}", err), + MechanismError::InvalidKeyLength(err) => + return write!(fmt, "invalid key length: {}", err), + MechanismError::InvalidState => "not in the right state to receive this response", + + MechanismError::CannotDecodeSuccessResponse => "can't decode success response", + MechanismError::InvalidSignatureInSuccessResponse => + "invalid signature in success response", + MechanismError::NoSignatureInSuccessResponse => "no signature in success response", + } + ) + } +} + +impl std::error::Error for MechanismError {} /// A trait which defines SASL mechanisms. pub trait Mechanism { @@ -6,7 +82,7 @@ pub trait Mechanism { fn name(&self) -> &str; /// Creates this mechanism from `Credentials`. - fn from_credentials(credentials: Credentials) -> Result + fn from_credentials(credentials: Credentials) -> Result where Self: Sized; @@ -16,12 +92,12 @@ pub trait Mechanism { } /// Creates a response to the SASL challenge. - fn response(&mut self, _challenge: &[u8]) -> Result, String> { + fn response(&mut self, _challenge: &[u8]) -> Result, MechanismError> { Ok(Vec::new()) } /// Verifies the server success response, if there is one. - fn success(&mut self, _data: &[u8]) -> Result<(), String> { + fn success(&mut self, _data: &[u8]) -> Result<(), MechanismError> { Ok(()) } } diff --git a/sasl/src/common/scram.rs b/sasl/src/common/scram.rs index 860e441d..f65a89e9 100644 --- a/sasl/src/common/scram.rs +++ b/sasl/src/common/scram.rs @@ -1,4 +1,4 @@ -use hmac::{Hmac, Mac}; +use hmac::{crypto_mac::InvalidKeyLength, Hmac, Mac}; use pbkdf2::pbkdf2; use rand_os::{ rand_core::{Error as RngError, RngCore}, @@ -21,6 +21,29 @@ pub fn generate_nonce() -> Result { Ok(base64::encode(&data)) } +#[derive(Debug, PartialEq)] +pub enum DeriveError { + IncompatibleHashingMethod(String, String), + IncorrectSalt, + IncompatibleIterationCount(usize, usize), +} + +impl std::fmt::Display for DeriveError { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + DeriveError::IncompatibleHashingMethod(one, two) => { + write!(fmt, "incompatible hashing method, {} is not {}", one, two) + } + DeriveError::IncorrectSalt => write!(fmt, "incorrect salt"), + DeriveError::IncompatibleIterationCount(one, two) => { + write!(fmt, "incompatible iteration count, {} is not {}", one, two) + } + } + } +} + +impl std::error::Error for DeriveError {} + /// A trait which defines the needed methods for SCRAM. pub trait ScramProvider { /// The kind of secret this `ScramProvider` requires. @@ -33,10 +56,10 @@ pub trait ScramProvider { fn hash(data: &[u8]) -> Vec; /// A function which performs an HMAC using the hash function. - fn hmac(data: &[u8], key: &[u8]) -> Result, String>; + fn hmac(data: &[u8], key: &[u8]) -> Result, InvalidKeyLength>; /// A function which does PBKDF2 key derivation using the hash function. - fn derive(data: &Password, salt: &[u8], iterations: usize) -> Result, String>; + fn derive(data: &Password, salt: &[u8], iterations: usize) -> Result, DeriveError>; } /// A `ScramProvider` which provides SCRAM-SHA-1 and SCRAM-SHA-1-PLUS @@ -56,12 +79,9 @@ impl ScramProvider for Sha1 { vec } - fn hmac(data: &[u8], key: &[u8]) -> Result, String> { + fn hmac(data: &[u8], key: &[u8]) -> Result, InvalidKeyLength> { type HmacSha1 = Hmac; - let mut mac = match HmacSha1::new_varkey(key) { - Ok(mac) => mac, - Err(err) => return Err(format!("{}", err)), - }; + let mut mac = HmacSha1::new_varkey(key)?; mac.input(data); let result = mac.result(); let mut vec = Vec::with_capacity(Sha1_hash::output_size()); @@ -69,7 +89,7 @@ impl ScramProvider for Sha1 { Ok(vec) } - fn derive(password: &Password, salt: &[u8], iterations: usize) -> Result, String> { + fn derive(password: &Password, salt: &[u8], iterations: usize) -> Result, DeriveError> { match *password { Password::Plain(ref plain) => { let mut result = vec![0; 20]; @@ -83,17 +103,16 @@ impl ScramProvider for Sha1 { ref data, } => { if method != Self::name() { - Err(format!( - "incompatible hashing method, {} is not {}", - method, - Self::name() + Err(DeriveError::IncompatibleHashingMethod( + method.to_string(), + Self::name().to_string(), )) } else if my_salt == &salt { - Err(format!("incorrect salt")) + Err(DeriveError::IncorrectSalt) } else if my_iterations == iterations { - Err(format!( - "incompatible iteration count, {} is not {}", - my_iterations, iterations + Err(DeriveError::IncompatibleIterationCount( + my_iterations, + iterations, )) } else { Ok(data.to_vec()) @@ -120,12 +139,9 @@ impl ScramProvider for Sha256 { vec } - fn hmac(data: &[u8], key: &[u8]) -> Result, String> { + fn hmac(data: &[u8], key: &[u8]) -> Result, InvalidKeyLength> { type HmacSha256 = Hmac; - let mut mac = match HmacSha256::new_varkey(key) { - Ok(mac) => mac, - Err(err) => return Err(format!("{}", err)), - }; + let mut mac = HmacSha256::new_varkey(key)?; mac.input(data); let result = mac.result(); let mut vec = Vec::with_capacity(Sha256_hash::output_size()); @@ -133,7 +149,7 @@ impl ScramProvider for Sha256 { Ok(vec) } - fn derive(password: &Password, salt: &[u8], iterations: usize) -> Result, String> { + fn derive(password: &Password, salt: &[u8], iterations: usize) -> Result, DeriveError> { match *password { Password::Plain(ref plain) => { let mut result = vec![0; 32]; @@ -147,17 +163,16 @@ impl ScramProvider for Sha256 { ref data, } => { if method != Self::name() { - Err(format!( - "incompatible hashing method, {} is not {}", - method, - Self::name() + Err(DeriveError::IncompatibleHashingMethod( + method.to_string(), + Self::name().to_string(), )) } else if my_salt == &salt { - Err(format!("incorrect salt")) + Err(DeriveError::IncorrectSalt) } else if my_iterations == iterations { - Err(format!( - "incompatible iteration count, {} is not {}", - my_iterations, iterations + Err(DeriveError::IncompatibleIterationCount( + my_iterations, + iterations, )) } else { Ok(data.to_vec()) diff --git a/sasl/src/lib.rs b/sasl/src/lib.rs index 2a64e762..9684814c 100644 --- a/sasl/src/lib.rs +++ b/sasl/src/lib.rs @@ -17,7 +17,7 @@ //! //! let mut mechanism = Plain::from_credentials(creds).unwrap(); //! -//! let initial_data = mechanism.initial().unwrap(); +//! let initial_data = mechanism.initial(); //! //! assert_eq!(initial_data, b"\0user\0pencil"); //! ``` @@ -28,8 +28,9 @@ //! #[macro_use] extern crate sasl; //! //! use sasl::server::{Validator, Provider, Mechanism as ServerMechanism, Response}; +//! use sasl::server::{ValidatorError, ProviderError, MechanismError as ServerMechanismError}; //! use sasl::server::mechanisms::{Plain as ServerPlain, Scram as ServerScram}; -//! use sasl::client::Mechanism as ClientMechanism; +//! use sasl::client::{Mechanism as ClientMechanism, MechanismError as ClientMechanismError}; //! use sasl::client::mechanisms::{Plain as ClientPlain, Scram as ClientScram}; //! use sasl::common::{Identity, Credentials, Password, ChannelBinding}; //! use sasl::common::scram::{ScramProvider, Sha1, Sha256}; @@ -43,13 +44,13 @@ //! struct MyValidator; //! //! impl Validator for MyValidator { -//! fn validate(&self, identity: &Identity, value: &secret::Plain) -> Result<(), String> { +//! fn validate(&self, identity: &Identity, value: &secret::Plain) -> Result<(), ValidatorError> { //! let &secret::Plain(ref password) = value; //! if identity != &Identity::Username(USERNAME.to_owned()) { -//! Err("authentication failed".to_owned()) +//! Err(ValidatorError::AuthenticationFailed) //! } //! else if password != PASSWORD { -//! Err("authentication failed".to_owned()) +//! Err(ValidatorError::AuthenticationFailed) //! } //! else { //! Ok(()) @@ -58,9 +59,9 @@ //! } //! //! impl Provider for MyValidator { -//! fn provide(&self, identity: &Identity) -> Result { +//! fn provide(&self, identity: &Identity) -> Result { //! if identity != &Identity::Username(USERNAME.to_owned()) { -//! Err("authentication failed".to_owned()) +//! Err(ProviderError::AuthenticationFailed) //! } //! else { //! let digest = sasl::common::scram::Sha1::derive @@ -79,9 +80,9 @@ //! impl_validator_using_provider!(MyValidator, secret::Pbkdf2Sha1); //! //! impl Provider for MyValidator { -//! fn provide(&self, identity: &Identity) -> Result { +//! fn provide(&self, identity: &Identity) -> Result { //! if identity != &Identity::Username(USERNAME.to_owned()) { -//! Err("authentication failed".to_owned()) +//! Err(ProviderError::AuthenticationFailed) //! } //! else { //! let digest = sasl::common::scram::Sha256::derive @@ -99,10 +100,28 @@ //! //! impl_validator_using_provider!(MyValidator, secret::Pbkdf2Sha256); //! -//! fn finish(cm: &mut CM, sm: &mut SM) -> Result +//! #[derive(Debug, PartialEq)] +//! enum MechanismError { +//! Client(ClientMechanismError), +//! Server(ServerMechanismError), +//! } +//! +//! impl From for MechanismError { +//! fn from(err: ClientMechanismError) -> MechanismError { +//! MechanismError::Client(err) +//! } +//! } +//! +//! impl From for MechanismError { +//! fn from(err: ServerMechanismError) -> MechanismError { +//! MechanismError::Server(err) +//! } +//! } +//! +//! fn finish(cm: &mut CM, sm: &mut SM) -> Result //! where CM: ClientMechanism, //! SM: ServerMechanism { -//! let init = cm.initial()?; +//! let init = cm.initial(); //! println!("C: {}", String::from_utf8_lossy(&init)); //! let mut resp = sm.respond(&init)?; //! loop { @@ -133,7 +152,7 @@ //! assert_eq!(mech.respond(b"\0user\0pencil"), Ok(expected_response)); //! //! let mut mech = ServerPlain::new(MyValidator); -//! assert_eq!(mech.respond(b"\0user\0marker"), Err("authentication failed".to_owned())); +//! assert_eq!(mech.respond(b"\0user\0marker"), Err(ServerMechanismError::ValidatorError(ValidatorError::AuthenticationFailed))); //! //! let creds = Credentials::default() //! .with_username(USERNAME) diff --git a/sasl/src/secret.rs b/sasl/src/secret.rs index 31f42dd9..05b5850b 100644 --- a/sasl/src/secret.rs +++ b/sasl/src/secret.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "scram")] +use crate::common::scram::DeriveError; + pub trait Secret {} pub trait Pbkdf2Secret { @@ -20,7 +23,11 @@ pub struct Pbkdf2Sha1 { impl Pbkdf2Sha1 { #[cfg(feature = "scram")] - pub fn derive(password: &str, salt: &[u8], iterations: usize) -> Result { + pub fn derive( + password: &str, + salt: &[u8], + iterations: usize, + ) -> Result { use crate::common::scram::{ScramProvider, Sha1}; use crate::common::Password; let digest = Sha1::derive(&Password::Plain(password.to_owned()), salt, iterations)?; @@ -55,7 +62,11 @@ pub struct Pbkdf2Sha256 { impl Pbkdf2Sha256 { #[cfg(feature = "scram")] - pub fn derive(password: &str, salt: &[u8], iterations: usize) -> Result { + pub fn derive( + password: &str, + salt: &[u8], + iterations: usize, + ) -> Result { use crate::common::scram::{ScramProvider, Sha256}; use crate::common::Password; let digest = Sha256::derive(&Password::Plain(password.to_owned()), salt, iterations)?; diff --git a/sasl/src/server/mechanisms/plain.rs b/sasl/src/server/mechanisms/plain.rs index 8df0e76f..79d090f6 100644 --- a/sasl/src/server/mechanisms/plain.rs +++ b/sasl/src/server/mechanisms/plain.rs @@ -1,6 +1,6 @@ use crate::common::Identity; use crate::secret; -use crate::server::{Mechanism, Response, Validator}; +use crate::server::{Mechanism, MechanismError, Response, Validator}; pub struct Plain> { validator: V, @@ -19,19 +19,19 @@ impl> Mechanism for Plain { "PLAIN" } - fn respond(&mut self, payload: &[u8]) -> Result { + fn respond(&mut self, payload: &[u8]) -> Result { let mut sp = payload.split(|&b| b == 0); sp.next(); let username = sp .next() - .ok_or_else(|| "no username specified".to_owned())?; - let username = - String::from_utf8(username.to_vec()).map_err(|_| "error decoding username")?; + .ok_or_else(|| MechanismError::NoUsernameSpecified)?; + let username = String::from_utf8(username.to_vec()) + .map_err(|_| MechanismError::ErrorDecodingUsername)?; let password = sp .next() - .ok_or_else(|| "no password specified".to_owned())?; - let password = - String::from_utf8(password.to_vec()).map_err(|_| "error decoding password")?; + .ok_or_else(|| MechanismError::NoPasswordSpecified)?; + let password = String::from_utf8(password.to_vec()) + .map_err(|_| MechanismError::ErrorDecodingPassword)?; let ident = Identity::Username(username); self.validator.validate(&ident, &secret::Plain(password))?; Ok(Response::Success(ident, Vec::new())) diff --git a/sasl/src/server/mechanisms/scram.rs b/sasl/src/server/mechanisms/scram.rs index a53e97ed..41d7ac7f 100644 --- a/sasl/src/server/mechanisms/scram.rs +++ b/sasl/src/server/mechanisms/scram.rs @@ -6,7 +6,7 @@ use crate::common::scram::{generate_nonce, ScramProvider}; use crate::common::{parse_frame, xor, ChannelBinding, Identity}; use crate::secret; use crate::secret::Pbkdf2Secret; -use crate::server::{Mechanism, Provider, Response}; +use crate::server::{Mechanism, MechanismError, Provider, Response}; enum ScramState { Init, @@ -61,7 +61,7 @@ where &self.name } - fn respond(&mut self, payload: &[u8]) -> Result { + fn respond(&mut self, payload: &[u8]) -> Result { let next_state; let ret; match self.state { @@ -82,7 +82,7 @@ where } } if commas < 2 { - return Err("failed to decode message".to_owned()); + return Err(MechanismError::FailedToDecodeMessage); } let gs2_header = payload[..idx].to_vec(); let rest = payload[idx..].to_vec(); @@ -92,29 +92,29 @@ where // Not supported. if gs2_header[0] != 0x79 { // ord("y") - return Err("channel binding not supported".to_owned()); + return Err(MechanismError::ChannelBindingNotSupported); } } ref other => { // Supported. if gs2_header[0] == 0x79 { // ord("y") - return Err("channel binding is supported".to_owned()); + return Err(MechanismError::ChannelBindingIsSupported); } else if !other.supports("tls-unique") { // TODO: grab the data - return Err("channel binding mechanism incorrect".to_owned()); + return Err(MechanismError::ChannelBindingMechanismIncorrect); } } } let frame = - parse_frame(&rest).map_err(|_| "can't decode initial message".to_owned())?; - let username = frame.get("n").ok_or_else(|| "no username".to_owned())?; + parse_frame(&rest).map_err(|_| MechanismError::CannotDecodeInitialMessage)?; + let username = frame.get("n").ok_or_else(|| MechanismError::NoUsername)?; let identity = Identity::Username(username.to_owned()); - let client_nonce = frame.get("r").ok_or_else(|| "no nonce".to_owned())?; + let client_nonce = frame.get("r").ok_or_else(|| MechanismError::NoNonce)?; let mut server_nonce = String::new(); server_nonce += client_nonce; server_nonce += - &generate_nonce().map_err(|_| "failed to generate nonce".to_owned())?; + &generate_nonce().map_err(|_| MechanismError::FailedToGenerateNonce)?; let pbkdf2 = self.provider.provide(&identity)?; let mut buf = Vec::new(); buf.extend(b"r="); @@ -141,7 +141,8 @@ where ref initial_client_message, ref initial_server_message, } => { - let frame = parse_frame(payload).map_err(|_| "can't decode response".to_owned())?; + let frame = + parse_frame(payload).map_err(|_| MechanismError::CannotDecodeResponse)?; let mut cb_data: Vec = Vec::new(); cb_data.extend(gs2_header); cb_data.extend(self.channel_binding.data()); @@ -161,11 +162,11 @@ where let stored_key = S::hash(&client_key); let client_signature = S::hmac(&auth_message, &stored_key)?; let client_proof = xor(&client_key, &client_signature); - let sent_proof = frame.get("p").ok_or_else(|| "no proof".to_owned())?; + let sent_proof = frame.get("p").ok_or_else(|| MechanismError::NoProof)?; let sent_proof = - base64::decode(sent_proof).map_err(|_| "can't decode proof".to_owned())?; + base64::decode(sent_proof).map_err(|_| MechanismError::CannotDecodeProof)?; if client_proof != sent_proof { - return Err("authentication failed".to_owned()); + return Err(MechanismError::AuthenticationFailed); } let server_signature = S::hmac(&auth_message, &server_key)?; let mut buf = Vec::new(); @@ -175,7 +176,7 @@ where next_state = ScramState::Done; } ScramState::Done => { - return Err("sasl session is already over".to_owned()); + return Err(MechanismError::SaslSessionAlreadyOver); } } self.state = next_state; diff --git a/sasl/src/server/mod.rs b/sasl/src/server/mod.rs index 020e88b2..c144235c 100644 --- a/sasl/src/server/mod.rs +++ b/sasl/src/server/mod.rs @@ -1,5 +1,7 @@ +use crate::common::scram::DeriveError; use crate::common::Identity; use crate::secret::Secret; +use std::fmt; #[macro_export] macro_rules! impl_validator_using_provider { @@ -9,11 +11,11 @@ macro_rules! impl_validator_using_provider { &self, identity: &$crate::common::Identity, value: &$secret, - ) -> Result<(), String> { + ) -> Result<(), ValidatorError> { if &(self as &$crate::server::Provider<$secret>).provide(identity)? == value { Ok(()) } else { - Err("authentication failure".to_owned()) + Err(ValidatorError::AuthenticationFailed) } } } @@ -21,16 +23,150 @@ macro_rules! impl_validator_using_provider { } pub trait Provider: Validator { - fn provide(&self, identity: &Identity) -> Result; + fn provide(&self, identity: &Identity) -> Result; } pub trait Validator { - fn validate(&self, identity: &Identity, value: &S) -> Result<(), String>; + fn validate(&self, identity: &Identity, value: &S) -> Result<(), ValidatorError>; +} + +#[derive(Debug, PartialEq)] +pub enum ProviderError { + AuthenticationFailed, + DeriveError(DeriveError), +} + +#[derive(Debug, PartialEq)] +pub enum ValidatorError { + AuthenticationFailed, + ProviderError(ProviderError), +} + +#[derive(Debug, PartialEq)] +pub enum MechanismError { + NoUsernameSpecified, + ErrorDecodingUsername, + NoPasswordSpecified, + ErrorDecodingPassword, + ValidatorError(ValidatorError), + + FailedToDecodeMessage, + ChannelBindingNotSupported, + ChannelBindingIsSupported, + ChannelBindingMechanismIncorrect, + CannotDecodeInitialMessage, + NoUsername, + NoNonce, + FailedToGenerateNonce, + ProviderError(ProviderError), + + CannotDecodeResponse, + InvalidKeyLength(hmac::crypto_mac::InvalidKeyLength), + NoProof, + CannotDecodeProof, + AuthenticationFailed, + SaslSessionAlreadyOver, +} + +impl From for ProviderError { + fn from(err: DeriveError) -> ProviderError { + ProviderError::DeriveError(err) + } +} + +impl From for ValidatorError { + fn from(err: ProviderError) -> ValidatorError { + ValidatorError::ProviderError(err) + } +} + +impl From for MechanismError { + fn from(err: ProviderError) -> MechanismError { + MechanismError::ProviderError(err) + } +} + +impl From for MechanismError { + fn from(err: ValidatorError) -> MechanismError { + MechanismError::ValidatorError(err) + } +} + +impl From for MechanismError { + fn from(err: hmac::crypto_mac::InvalidKeyLength) -> MechanismError { + MechanismError::InvalidKeyLength(err) + } +} + +impl fmt::Display for ProviderError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "provider error") + } +} + +impl fmt::Display for ValidatorError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "validator error") + } +} + +impl fmt::Display for MechanismError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match self { + MechanismError::NoUsernameSpecified => write!(fmt, "no username specified"), + MechanismError::ErrorDecodingUsername => write!(fmt, "error decoding username"), + MechanismError::NoPasswordSpecified => write!(fmt, "no password specified"), + MechanismError::ErrorDecodingPassword => write!(fmt, "error decoding password"), + MechanismError::ValidatorError(err) => write!(fmt, "validator error: {}", err), + + MechanismError::FailedToDecodeMessage => write!(fmt, "failed to decode message"), + MechanismError::ChannelBindingNotSupported => { + write!(fmt, "channel binding not supported") + } + MechanismError::ChannelBindingIsSupported => { + write!(fmt, "channel binding is supported") + } + MechanismError::ChannelBindingMechanismIncorrect => { + write!(fmt, "channel binding mechanism is incorrect") + } + MechanismError::CannotDecodeInitialMessage => { + write!(fmt, "can’t decode initial message") + } + MechanismError::NoUsername => write!(fmt, "no username"), + MechanismError::NoNonce => write!(fmt, "no nonce"), + MechanismError::FailedToGenerateNonce => write!(fmt, "failed to generate nonce"), + MechanismError::ProviderError(err) => write!(fmt, "provider error: {}", err), + + MechanismError::CannotDecodeResponse => write!(fmt, "can’t decode response"), + MechanismError::InvalidKeyLength(err) => write!(fmt, "invalid key length: {}", err), + MechanismError::NoProof => write!(fmt, "no proof"), + MechanismError::CannotDecodeProof => write!(fmt, "can’t decode proof"), + MechanismError::AuthenticationFailed => write!(fmt, "authentication failed"), + MechanismError::SaslSessionAlreadyOver => write!(fmt, "SASL session already over"), + } + } +} + +impl Error for ProviderError {} + +impl Error for ValidatorError {} + +use std::error::Error; +impl Error for MechanismError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + MechanismError::ValidatorError(err) => Some(err), + MechanismError::ProviderError(err) => Some(err), + // TODO: figure out how to enable the std feature on this crate. + //MechanismError::InvalidKeyLength(err) => Some(err), + _ => None, + } + } } pub trait Mechanism { fn name(&self) -> &str; - fn respond(&mut self, payload: &[u8]) -> Result; + fn respond(&mut self, payload: &[u8]) -> Result; } #[derive(Debug, Clone, PartialEq, Eq)]