slixmpp.util: type things

Fix a bug in the SASL implementation as well. (some special chars would
make things crash instead of being escaped)
This commit is contained in:
mathieui 2021-04-03 19:12:59 +02:00
parent b1411d8ed7
commit ef06429941
4 changed files with 89 additions and 88 deletions

View file

@ -1,4 +1,3 @@
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2018 Emmanuel Gil Peyrot
# This file is part of Slixmpp.
@ -6,8 +5,11 @@
import os
import logging
from typing import Callable, Optional, Any
log = logging.getLogger(__name__)
class Cache:
def retrieve(self, key):
raise NotImplementedError
@ -16,7 +18,8 @@ class Cache:
raise NotImplementedError
def remove(self, key):
raise NotImplemented
raise NotImplementedError
class PerJidCache:
def retrieve_by_jid(self, jid, key):
@ -28,6 +31,7 @@ class PerJidCache:
def remove_by_jid(self, jid, key):
raise NotImplementedError
class MemoryCache(Cache):
def __init__(self):
self.cache = {}
@ -44,6 +48,7 @@ class MemoryCache(Cache):
del self.cache[key]
return True
class MemoryPerJidCache(PerJidCache):
def __init__(self):
self.cache = {}
@ -65,14 +70,15 @@ class MemoryPerJidCache(PerJidCache):
del cache[key]
return True
class FileSystemStorage:
def __init__(self, encode, decode, binary):
def __init__(self, encode: Optional[Callable[[Any], str]], decode: Optional[Callable[[str], Any]], binary: bool):
self.encode = encode if encode is not None else lambda x: x
self.decode = decode if decode is not None else lambda x: x
self.read = 'rb' if binary else 'r'
self.write = 'wb' if binary else 'w'
def _retrieve(self, directory, key):
def _retrieve(self, directory: str, key: str):
filename = os.path.join(directory, key.replace('/', '_'))
try:
with open(filename, self.read) as cache_file:
@ -86,7 +92,7 @@ class FileSystemStorage:
log.debug('Removing %s entry', key)
self._remove(directory, key)
def _store(self, directory, key, value):
def _store(self, directory: str, key: str, value):
filename = os.path.join(directory, key.replace('/', '_'))
try:
os.makedirs(directory, exist_ok=True)
@ -99,7 +105,7 @@ class FileSystemStorage:
except Exception:
log.debug('Failed to encode %s to cache:', key, exc_info=True)
def _remove(self, directory, key):
def _remove(self, directory: str, key: str):
filename = os.path.join(directory, key.replace('/', '_'))
try:
os.remove(filename)
@ -108,8 +114,9 @@ class FileSystemStorage:
return False
return True
class FileSystemCache(Cache, FileSystemStorage):
def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False):
def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False):
FileSystemStorage.__init__(self, encode, decode, binary)
self.base_dir = os.path.join(directory, cache_type)
@ -122,8 +129,9 @@ class FileSystemCache(Cache, FileSystemStorage):
def remove(self, key):
return self._remove(self.base_dir, key)
class FileSystemPerJidCache(PerJidCache, FileSystemStorage):
def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False):
def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False):
FileSystemStorage.__init__(self, encode, decode, binary)
self.base_dir = os.path.join(directory, cache_type)

View file

@ -2,15 +2,19 @@ import builtins
import sys
import hashlib
from typing import Optional, Union, Callable, List
def unicode(text):
bytes_ = builtins.bytes # alias the stdlib type but ew
def unicode(text: Union[bytes_, str]) -> str:
if not isinstance(text, str):
return text.decode('utf-8')
else:
return text
def bytes(text):
def bytes(text: Optional[Union[str, bytes_]]) -> bytes_:
"""
Convert Unicode text to UTF-8 encoded bytes.
@ -34,7 +38,7 @@ def bytes(text):
return builtins.bytes(text, encoding='utf-8')
def quote(text):
def quote(text: Union[str, bytes_]) -> bytes_:
"""
Enclose in quotes and escape internal slashes and double quotes.
@ -44,7 +48,7 @@ def quote(text):
return b'"' + text.replace(b'\\', b'\\\\').replace(b'"', b'\\"') + b'"'
def num_to_bytes(num):
def num_to_bytes(num: int) -> bytes_:
"""
Convert an integer into a four byte sequence.
@ -58,21 +62,21 @@ def num_to_bytes(num):
return bval
def bytes_to_num(bval):
def bytes_to_num(bval: bytes_) -> int:
"""
Convert a four byte sequence to an integer.
:param bytes bval: A four byte sequence to turn into an integer.
"""
num = 0
num += ord(bval[0] << 24)
num += ord(bval[1] << 16)
num += ord(bval[2] << 8)
num += ord(bval[3])
num += (bval[0] << 24)
num += (bval[1] << 16)
num += (bval[2] << 8)
num += (bval[3])
return num
def XOR(x, y):
def XOR(x: bytes_, y: bytes_) -> bytes_:
"""
Return the results of an XOR operation on two equal length byte strings.
@ -85,7 +89,7 @@ def XOR(x, y):
return builtins.bytes([a ^ b for a, b in zip(x, y)])
def hash(name):
def hash(name: str) -> Optional[Callable]:
"""
Return a hash function implementing the given algorithm.
@ -102,7 +106,7 @@ def hash(name):
return None
def hashes():
def hashes() -> List[str]:
"""
Return a list of available hashing algorithms.
@ -115,28 +119,3 @@ def hashes():
t += ['MD2']
hashes = ['SHA-' + h[3:] for h in dir(hashlib) if h.startswith('sha')]
return t + hashes
def setdefaultencoding(encoding):
"""
Set the current default string encoding used by the Unicode implementation.
Actually calls sys.setdefaultencoding under the hood - see the docs for that
for more details. This method exists only as a way to call find/call it
even after it has been 'deleted' when the site module is executed.
:param string encoding: An encoding name, compatible with sys.setdefaultencoding
"""
func = getattr(sys, 'setdefaultencoding', None)
if func is None:
import gc
import types
for obj in gc.get_objects():
if (isinstance(obj, types.BuiltinFunctionType)
and obj.__name__ == 'setdefaultencoding'):
func = obj
break
if func is None:
raise RuntimeError("Could not find setdefaultencoding")
sys.setdefaultencoding = func
return func(encoding)

View file

@ -1,4 +1,3 @@
# slixmpp.util.sasl.client
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# This module was originally based on Dave Cridland's Suelta library.
@ -6,9 +5,11 @@
# :copryight: (c) 2004-2013 David Alan Cridland
# :copyright: (c) 2013 Nathanael C. Fritz, Lance J.T. Stout
# :license: MIT, see LICENSE for more details
from __future__ import annotations
import logging
import stringprep
from typing import Iterable, Set, Callable, Dict, Any, Optional, Type
from slixmpp.util import hashes, bytes, stringprep_profiles
@ -16,11 +17,11 @@ log = logging.getLogger(__name__)
#: Global registry mapping mechanism names to implementation classes.
MECHANISMS = {}
MECHANISMS: Dict[str, Type[Mech]] = {}
#: Global registry mapping mechanism names to security scores.
MECH_SEC_SCORES = {}
MECH_SEC_SCORES: Dict[str, int] = {}
#: The SASLprep profile of stringprep used to validate simple username
@ -45,9 +46,10 @@ saslprep = stringprep_profiles.create(
unassigned=[stringprep.in_table_a1])
def sasl_mech(score):
def sasl_mech(score: int):
sec_score = score
def register(mech):
def register(mech: Type[Mech]):
n = 0
mech.score = sec_score
if mech.use_hashes:
@ -99,9 +101,9 @@ class Mech(object):
score = -1
use_hashes = False
channel_binding = False
required_credentials = set()
optional_credentials = set()
security = set()
required_credentials: Set[str] = set()
optional_credentials: Set[str] = set()
security: Set[str] = set()
def __init__(self, name, credentials, security_settings):
self.credentials = credentials
@ -118,7 +120,14 @@ class Mech(object):
return b''
def choose(mech_list, credentials, security_settings, limit=None, min_mech=None):
CredentialsCallback = Callable[[Iterable[str], Iterable[str]], Dict[str, Any]]
SecurityCallback = Callable[[Iterable[str]], Dict[str, Any]]
def choose(mech_list: Iterable[Type[Mech]], credentials: CredentialsCallback,
security_settings: SecurityCallback,
limit: Optional[Iterable[Type[Mech]]] = None,
min_mech: Optional[str] = None) -> Mech:
available_mechs = set(MECHANISMS.keys())
if limit is None:
limit = set(mech_list)
@ -130,7 +139,10 @@ def choose(mech_list, credentials, security_settings, limit=None, min_mech=None)
mech_list = mech_list.intersection(limit)
available_mechs = available_mechs.intersection(mech_list)
best_score = MECH_SEC_SCORES.get(min_mech, -1)
if min_mech is None:
best_score = -1
else:
best_score = MECH_SEC_SCORES.get(min_mech, -1)
best_mech = None
for name in available_mechs:
if name in MECH_SEC_SCORES:

View file

@ -11,6 +11,9 @@ import hmac
import random
from base64 import b64encode, b64decode
from typing import List, Dict, Optional
bytes_ = bytes
from slixmpp.util import bytes, hash, XOR, quote, num_to_bytes
from slixmpp.util.sasl.client import sasl_mech, Mech, \
@ -63,7 +66,7 @@ class PLAIN(Mech):
if not self.security_settings['encrypted_plain']:
raise SASLCancelled('PLAIN with encryption')
def process(self, challenge=b''):
def process(self, challenge: bytes_ = b'') -> bytes_:
authzid = self.credentials['authzid']
authcid = self.credentials['username']
password = self.credentials['password']
@ -148,7 +151,7 @@ class CRAM(Mech):
required_credentials = {'username', 'password'}
security = {'encrypted', 'unencrypted_cram'}
def setup(self, name):
def setup(self, name: str):
self.hash_name = name[5:]
self.hash = hash(self.hash_name)
if self.hash is None:
@ -157,14 +160,14 @@ class CRAM(Mech):
if not self.security_settings['unencrypted_cram']:
raise SASLCancelled('Unecrypted CRAM-%s' % self.hash_name)
def process(self, challenge=b''):
def process(self, challenge: bytes_ = b'') -> Optional[bytes_]:
if not challenge:
return None
username = self.credentials['username']
password = self.credentials['password']
mac = hmac.HMAC(key=password, digestmod=self.hash)
mac = hmac.HMAC(key=password, digestmod=self.hash) # type: ignore
mac.update(challenge)
return username + b' ' + bytes(mac.hexdigest())
@ -201,43 +204,42 @@ class SCRAM(Mech):
def HMAC(self, key, msg):
return hmac.HMAC(key=key, msg=msg, digestmod=self.hash).digest()
def Hi(self, text, salt, iterations):
text = bytes(text)
ui1 = self.HMAC(text, salt + b'\0\0\0\01')
def Hi(self, text: str, salt: bytes_, iterations: int):
text_enc = bytes(text)
ui1 = self.HMAC(text_enc, salt + b'\0\0\0\01')
ui = ui1
for i in range(iterations - 1):
ui1 = self.HMAC(text, ui1)
ui1 = self.HMAC(text_enc, ui1)
ui = XOR(ui, ui1)
return ui
def H(self, text):
def H(self, text: str) -> bytes_:
return self.hash(text).digest()
def saslname(self, value):
value = value.decode("utf-8")
escaped = []
def saslname(self, value_b: bytes_) -> bytes_:
value = value_b.decode("utf-8")
escaped: List[str] = []
for char in value:
if char == ',':
escaped += b'=2C'
escaped.append('=2C')
elif char == '=':
escaped += b'=3D'
escaped.append('=3D')
else:
escaped += char
escaped.append(char)
return "".join(escaped).encode("utf-8")
def parse(self, challenge):
def parse(self, challenge: bytes_) -> Dict[bytes_, bytes_]:
items = {}
for key, value in [item.split(b'=', 1) for item in challenge.split(b',')]:
items[key] = value
return items
def process(self, challenge=b''):
def process(self, challenge: bytes_ = b''):
steps = [self.process_1, self.process_2, self.process_3]
return steps[self.step](challenge)
def process_1(self, challenge):
def process_1(self, challenge: bytes_) -> bytes_:
self.step = 1
data = {}
self.cnonce = bytes(('%s' % random.random())[2:])
@ -263,7 +265,7 @@ class SCRAM(Mech):
return self.client_first_message
def process_2(self, challenge):
def process_2(self, challenge: bytes_) -> bytes_:
self.step = 2
data = self.parse(challenge)
@ -304,7 +306,7 @@ class SCRAM(Mech):
return client_final_message
def process_3(self, challenge):
def process_3(self, challenge: bytes_) -> bytes_:
data = self.parse(challenge)
verifier = data.get(b'v', None)
error = data.get(b'e', 'Unknown error')
@ -345,17 +347,16 @@ class DIGEST(Mech):
self.cnonce = b''
self.nonce_count = 1
def parse(self, challenge=b''):
data = {}
def parse(self, challenge: bytes_ = b''):
data: Dict[str, bytes_] = {}
var_name = b''
var_value = b''
# States: var, new_var, end, quote, escaped_quote
state = 'var'
for char in challenge:
char = bytes([char])
for char_int in challenge:
char = bytes_([char_int])
if state == 'var':
if char.isspace():
@ -401,14 +402,14 @@ class DIGEST(Mech):
state = 'var'
return data
def MAC(self, key, seq, msg):
def MAC(self, key: bytes_, seq: int, msg: bytes_) -> bytes_:
mac = hmac.HMAC(key=key, digestmod=self.hash)
seqnum = num_to_bytes(seq)
mac.update(seqnum)
mac.update(msg)
return mac.digest()[:10] + b'\x00\x01' + seqnum
def A1(self):
def A1(self) -> bytes_:
username = self.credentials['username']
password = self.credentials['password']
authzid = self.credentials['authzid']
@ -423,13 +424,13 @@ class DIGEST(Mech):
return bytes(a1)
def A2(self, prefix=b''):
def A2(self, prefix: bytes_ = b'') -> bytes_:
a2 = prefix + b':' + self.digest_uri()
if self.qop in (b'auth-int', b'auth-conf'):
a2 += b':00000000000000000000000000000000'
return bytes(a2)
def response(self, prefix=b''):
def response(self, prefix: bytes_ = b'') -> bytes_:
nc = bytes('%08x' % self.nonce_count)
a1 = bytes(self.hash(self.A1()).hexdigest().lower())
@ -439,7 +440,7 @@ class DIGEST(Mech):
return bytes(self.hash(a1 + b':' + s).hexdigest().lower())
def digest_uri(self):
def digest_uri(self) -> bytes_:
serv_type = self.credentials['service']
serv_name = self.credentials['service-name']
host = self.credentials['host']
@ -449,7 +450,7 @@ class DIGEST(Mech):
uri += b'/' + serv_name
return uri
def respond(self):
def respond(self) -> bytes_:
data = {
'username': quote(self.credentials['username']),
'authzid': quote(self.credentials['authzid']),
@ -469,7 +470,7 @@ class DIGEST(Mech):
resp += b',' + bytes(key) + b'=' + bytes(value)
return resp[1:]
def process(self, challenge=b''):
def process(self, challenge: bytes_ = b'') -> Optional[bytes_]:
if not challenge:
if self.cnonce and self.nonce and self.nonce_count and self.qop:
self.nonce_count += 1
@ -480,6 +481,7 @@ class DIGEST(Mech):
if 'rspauth' in data:
if data['rspauth'] != self.response():
raise SASLMutualAuthFailed()
return None
else:
self.nonce_count = 1
self.cnonce = bytes('%s' % random.random())[2:]