slixmpp.util: type things
Fix a bug in the SASL implementation as well. (some special chars would make things crash instead of being escaped)
This commit is contained in:
parent
b1411d8ed7
commit
ef06429941
4 changed files with 89 additions and 88 deletions
|
@ -1,4 +1,3 @@
|
|||
|
||||
# Slixmpp: The Slick XMPP Library
|
||||
# Copyright (C) 2018 Emmanuel Gil Peyrot
|
||||
# This file is part of Slixmpp.
|
||||
|
@ -6,8 +5,11 @@
|
|||
import os
|
||||
import logging
|
||||
|
||||
from typing import Callable, Optional, Any
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Cache:
|
||||
def retrieve(self, key):
|
||||
raise NotImplementedError
|
||||
|
@ -16,7 +18,8 @@ class Cache:
|
|||
raise NotImplementedError
|
||||
|
||||
def remove(self, key):
|
||||
raise NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PerJidCache:
|
||||
def retrieve_by_jid(self, jid, key):
|
||||
|
@ -28,6 +31,7 @@ class PerJidCache:
|
|||
def remove_by_jid(self, jid, key):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MemoryCache(Cache):
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
@ -44,6 +48,7 @@ class MemoryCache(Cache):
|
|||
del self.cache[key]
|
||||
return True
|
||||
|
||||
|
||||
class MemoryPerJidCache(PerJidCache):
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
@ -65,14 +70,15 @@ class MemoryPerJidCache(PerJidCache):
|
|||
del cache[key]
|
||||
return True
|
||||
|
||||
|
||||
class FileSystemStorage:
|
||||
def __init__(self, encode, decode, binary):
|
||||
def __init__(self, encode: Optional[Callable[[Any], str]], decode: Optional[Callable[[str], Any]], binary: bool):
|
||||
self.encode = encode if encode is not None else lambda x: x
|
||||
self.decode = decode if decode is not None else lambda x: x
|
||||
self.read = 'rb' if binary else 'r'
|
||||
self.write = 'wb' if binary else 'w'
|
||||
|
||||
def _retrieve(self, directory, key):
|
||||
def _retrieve(self, directory: str, key: str):
|
||||
filename = os.path.join(directory, key.replace('/', '_'))
|
||||
try:
|
||||
with open(filename, self.read) as cache_file:
|
||||
|
@ -86,7 +92,7 @@ class FileSystemStorage:
|
|||
log.debug('Removing %s entry', key)
|
||||
self._remove(directory, key)
|
||||
|
||||
def _store(self, directory, key, value):
|
||||
def _store(self, directory: str, key: str, value):
|
||||
filename = os.path.join(directory, key.replace('/', '_'))
|
||||
try:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
@ -99,7 +105,7 @@ class FileSystemStorage:
|
|||
except Exception:
|
||||
log.debug('Failed to encode %s to cache:', key, exc_info=True)
|
||||
|
||||
def _remove(self, directory, key):
|
||||
def _remove(self, directory: str, key: str):
|
||||
filename = os.path.join(directory, key.replace('/', '_'))
|
||||
try:
|
||||
os.remove(filename)
|
||||
|
@ -108,8 +114,9 @@ class FileSystemStorage:
|
|||
return False
|
||||
return True
|
||||
|
||||
|
||||
class FileSystemCache(Cache, FileSystemStorage):
|
||||
def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False):
|
||||
def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False):
|
||||
FileSystemStorage.__init__(self, encode, decode, binary)
|
||||
self.base_dir = os.path.join(directory, cache_type)
|
||||
|
||||
|
@ -122,8 +129,9 @@ class FileSystemCache(Cache, FileSystemStorage):
|
|||
def remove(self, key):
|
||||
return self._remove(self.base_dir, key)
|
||||
|
||||
|
||||
class FileSystemPerJidCache(PerJidCache, FileSystemStorage):
|
||||
def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False):
|
||||
def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False):
|
||||
FileSystemStorage.__init__(self, encode, decode, binary)
|
||||
self.base_dir = os.path.join(directory, cache_type)
|
||||
|
||||
|
|
|
@ -2,15 +2,19 @@ import builtins
|
|||
import sys
|
||||
import hashlib
|
||||
|
||||
from typing import Optional, Union, Callable, List
|
||||
|
||||
def unicode(text):
|
||||
bytes_ = builtins.bytes # alias the stdlib type but ew
|
||||
|
||||
|
||||
def unicode(text: Union[bytes_, str]) -> str:
|
||||
if not isinstance(text, str):
|
||||
return text.decode('utf-8')
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
def bytes(text):
|
||||
def bytes(text: Optional[Union[str, bytes_]]) -> bytes_:
|
||||
"""
|
||||
Convert Unicode text to UTF-8 encoded bytes.
|
||||
|
||||
|
@ -34,7 +38,7 @@ def bytes(text):
|
|||
return builtins.bytes(text, encoding='utf-8')
|
||||
|
||||
|
||||
def quote(text):
|
||||
def quote(text: Union[str, bytes_]) -> bytes_:
|
||||
"""
|
||||
Enclose in quotes and escape internal slashes and double quotes.
|
||||
|
||||
|
@ -44,7 +48,7 @@ def quote(text):
|
|||
return b'"' + text.replace(b'\\', b'\\\\').replace(b'"', b'\\"') + b'"'
|
||||
|
||||
|
||||
def num_to_bytes(num):
|
||||
def num_to_bytes(num: int) -> bytes_:
|
||||
"""
|
||||
Convert an integer into a four byte sequence.
|
||||
|
||||
|
@ -58,21 +62,21 @@ def num_to_bytes(num):
|
|||
return bval
|
||||
|
||||
|
||||
def bytes_to_num(bval):
|
||||
def bytes_to_num(bval: bytes_) -> int:
|
||||
"""
|
||||
Convert a four byte sequence to an integer.
|
||||
|
||||
:param bytes bval: A four byte sequence to turn into an integer.
|
||||
"""
|
||||
num = 0
|
||||
num += ord(bval[0] << 24)
|
||||
num += ord(bval[1] << 16)
|
||||
num += ord(bval[2] << 8)
|
||||
num += ord(bval[3])
|
||||
num += (bval[0] << 24)
|
||||
num += (bval[1] << 16)
|
||||
num += (bval[2] << 8)
|
||||
num += (bval[3])
|
||||
return num
|
||||
|
||||
|
||||
def XOR(x, y):
|
||||
def XOR(x: bytes_, y: bytes_) -> bytes_:
|
||||
"""
|
||||
Return the results of an XOR operation on two equal length byte strings.
|
||||
|
||||
|
@ -85,7 +89,7 @@ def XOR(x, y):
|
|||
return builtins.bytes([a ^ b for a, b in zip(x, y)])
|
||||
|
||||
|
||||
def hash(name):
|
||||
def hash(name: str) -> Optional[Callable]:
|
||||
"""
|
||||
Return a hash function implementing the given algorithm.
|
||||
|
||||
|
@ -102,7 +106,7 @@ def hash(name):
|
|||
return None
|
||||
|
||||
|
||||
def hashes():
|
||||
def hashes() -> List[str]:
|
||||
"""
|
||||
Return a list of available hashing algorithms.
|
||||
|
||||
|
@ -115,28 +119,3 @@ def hashes():
|
|||
t += ['MD2']
|
||||
hashes = ['SHA-' + h[3:] for h in dir(hashlib) if h.startswith('sha')]
|
||||
return t + hashes
|
||||
|
||||
|
||||
def setdefaultencoding(encoding):
|
||||
"""
|
||||
Set the current default string encoding used by the Unicode implementation.
|
||||
|
||||
Actually calls sys.setdefaultencoding under the hood - see the docs for that
|
||||
for more details. This method exists only as a way to call find/call it
|
||||
even after it has been 'deleted' when the site module is executed.
|
||||
|
||||
:param string encoding: An encoding name, compatible with sys.setdefaultencoding
|
||||
"""
|
||||
func = getattr(sys, 'setdefaultencoding', None)
|
||||
if func is None:
|
||||
import gc
|
||||
import types
|
||||
for obj in gc.get_objects():
|
||||
if (isinstance(obj, types.BuiltinFunctionType)
|
||||
and obj.__name__ == 'setdefaultencoding'):
|
||||
func = obj
|
||||
break
|
||||
if func is None:
|
||||
raise RuntimeError("Could not find setdefaultencoding")
|
||||
sys.setdefaultencoding = func
|
||||
return func(encoding)
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
# slixmpp.util.sasl.client
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# This module was originally based on Dave Cridland's Suelta library.
|
||||
|
@ -6,9 +5,11 @@
|
|||
# :copryight: (c) 2004-2013 David Alan Cridland
|
||||
# :copyright: (c) 2013 Nathanael C. Fritz, Lance J.T. Stout
|
||||
# :license: MIT, see LICENSE for more details
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import stringprep
|
||||
|
||||
from typing import Iterable, Set, Callable, Dict, Any, Optional, Type
|
||||
from slixmpp.util import hashes, bytes, stringprep_profiles
|
||||
|
||||
|
||||
|
@ -16,11 +17,11 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
#: Global registry mapping mechanism names to implementation classes.
|
||||
MECHANISMS = {}
|
||||
MECHANISMS: Dict[str, Type[Mech]] = {}
|
||||
|
||||
|
||||
#: Global registry mapping mechanism names to security scores.
|
||||
MECH_SEC_SCORES = {}
|
||||
MECH_SEC_SCORES: Dict[str, int] = {}
|
||||
|
||||
|
||||
#: The SASLprep profile of stringprep used to validate simple username
|
||||
|
@ -45,9 +46,10 @@ saslprep = stringprep_profiles.create(
|
|||
unassigned=[stringprep.in_table_a1])
|
||||
|
||||
|
||||
def sasl_mech(score):
|
||||
def sasl_mech(score: int):
|
||||
sec_score = score
|
||||
def register(mech):
|
||||
|
||||
def register(mech: Type[Mech]):
|
||||
n = 0
|
||||
mech.score = sec_score
|
||||
if mech.use_hashes:
|
||||
|
@ -99,9 +101,9 @@ class Mech(object):
|
|||
score = -1
|
||||
use_hashes = False
|
||||
channel_binding = False
|
||||
required_credentials = set()
|
||||
optional_credentials = set()
|
||||
security = set()
|
||||
required_credentials: Set[str] = set()
|
||||
optional_credentials: Set[str] = set()
|
||||
security: Set[str] = set()
|
||||
|
||||
def __init__(self, name, credentials, security_settings):
|
||||
self.credentials = credentials
|
||||
|
@ -118,7 +120,14 @@ class Mech(object):
|
|||
return b''
|
||||
|
||||
|
||||
def choose(mech_list, credentials, security_settings, limit=None, min_mech=None):
|
||||
CredentialsCallback = Callable[[Iterable[str], Iterable[str]], Dict[str, Any]]
|
||||
SecurityCallback = Callable[[Iterable[str]], Dict[str, Any]]
|
||||
|
||||
|
||||
def choose(mech_list: Iterable[Type[Mech]], credentials: CredentialsCallback,
|
||||
security_settings: SecurityCallback,
|
||||
limit: Optional[Iterable[Type[Mech]]] = None,
|
||||
min_mech: Optional[str] = None) -> Mech:
|
||||
available_mechs = set(MECHANISMS.keys())
|
||||
if limit is None:
|
||||
limit = set(mech_list)
|
||||
|
@ -130,7 +139,10 @@ def choose(mech_list, credentials, security_settings, limit=None, min_mech=None)
|
|||
mech_list = mech_list.intersection(limit)
|
||||
available_mechs = available_mechs.intersection(mech_list)
|
||||
|
||||
best_score = MECH_SEC_SCORES.get(min_mech, -1)
|
||||
if min_mech is None:
|
||||
best_score = -1
|
||||
else:
|
||||
best_score = MECH_SEC_SCORES.get(min_mech, -1)
|
||||
best_mech = None
|
||||
for name in available_mechs:
|
||||
if name in MECH_SEC_SCORES:
|
||||
|
|
|
@ -11,6 +11,9 @@ import hmac
|
|||
import random
|
||||
|
||||
from base64 import b64encode, b64decode
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
bytes_ = bytes
|
||||
|
||||
from slixmpp.util import bytes, hash, XOR, quote, num_to_bytes
|
||||
from slixmpp.util.sasl.client import sasl_mech, Mech, \
|
||||
|
@ -63,7 +66,7 @@ class PLAIN(Mech):
|
|||
if not self.security_settings['encrypted_plain']:
|
||||
raise SASLCancelled('PLAIN with encryption')
|
||||
|
||||
def process(self, challenge=b''):
|
||||
def process(self, challenge: bytes_ = b'') -> bytes_:
|
||||
authzid = self.credentials['authzid']
|
||||
authcid = self.credentials['username']
|
||||
password = self.credentials['password']
|
||||
|
@ -148,7 +151,7 @@ class CRAM(Mech):
|
|||
required_credentials = {'username', 'password'}
|
||||
security = {'encrypted', 'unencrypted_cram'}
|
||||
|
||||
def setup(self, name):
|
||||
def setup(self, name: str):
|
||||
self.hash_name = name[5:]
|
||||
self.hash = hash(self.hash_name)
|
||||
if self.hash is None:
|
||||
|
@ -157,14 +160,14 @@ class CRAM(Mech):
|
|||
if not self.security_settings['unencrypted_cram']:
|
||||
raise SASLCancelled('Unecrypted CRAM-%s' % self.hash_name)
|
||||
|
||||
def process(self, challenge=b''):
|
||||
def process(self, challenge: bytes_ = b'') -> Optional[bytes_]:
|
||||
if not challenge:
|
||||
return None
|
||||
|
||||
username = self.credentials['username']
|
||||
password = self.credentials['password']
|
||||
|
||||
mac = hmac.HMAC(key=password, digestmod=self.hash)
|
||||
mac = hmac.HMAC(key=password, digestmod=self.hash) # type: ignore
|
||||
mac.update(challenge)
|
||||
|
||||
return username + b' ' + bytes(mac.hexdigest())
|
||||
|
@ -201,43 +204,42 @@ class SCRAM(Mech):
|
|||
def HMAC(self, key, msg):
|
||||
return hmac.HMAC(key=key, msg=msg, digestmod=self.hash).digest()
|
||||
|
||||
def Hi(self, text, salt, iterations):
|
||||
text = bytes(text)
|
||||
ui1 = self.HMAC(text, salt + b'\0\0\0\01')
|
||||
def Hi(self, text: str, salt: bytes_, iterations: int):
|
||||
text_enc = bytes(text)
|
||||
ui1 = self.HMAC(text_enc, salt + b'\0\0\0\01')
|
||||
ui = ui1
|
||||
for i in range(iterations - 1):
|
||||
ui1 = self.HMAC(text, ui1)
|
||||
ui1 = self.HMAC(text_enc, ui1)
|
||||
ui = XOR(ui, ui1)
|
||||
return ui
|
||||
|
||||
def H(self, text):
|
||||
def H(self, text: str) -> bytes_:
|
||||
return self.hash(text).digest()
|
||||
|
||||
def saslname(self, value):
|
||||
value = value.decode("utf-8")
|
||||
escaped = []
|
||||
def saslname(self, value_b: bytes_) -> bytes_:
|
||||
value = value_b.decode("utf-8")
|
||||
escaped: List[str] = []
|
||||
for char in value:
|
||||
if char == ',':
|
||||
escaped += b'=2C'
|
||||
escaped.append('=2C')
|
||||
elif char == '=':
|
||||
escaped += b'=3D'
|
||||
escaped.append('=3D')
|
||||
else:
|
||||
escaped += char
|
||||
escaped.append(char)
|
||||
return "".join(escaped).encode("utf-8")
|
||||
|
||||
def parse(self, challenge):
|
||||
def parse(self, challenge: bytes_) -> Dict[bytes_, bytes_]:
|
||||
items = {}
|
||||
for key, value in [item.split(b'=', 1) for item in challenge.split(b',')]:
|
||||
items[key] = value
|
||||
return items
|
||||
|
||||
def process(self, challenge=b''):
|
||||
def process(self, challenge: bytes_ = b''):
|
||||
steps = [self.process_1, self.process_2, self.process_3]
|
||||
return steps[self.step](challenge)
|
||||
|
||||
def process_1(self, challenge):
|
||||
def process_1(self, challenge: bytes_) -> bytes_:
|
||||
self.step = 1
|
||||
data = {}
|
||||
|
||||
self.cnonce = bytes(('%s' % random.random())[2:])
|
||||
|
||||
|
@ -263,7 +265,7 @@ class SCRAM(Mech):
|
|||
|
||||
return self.client_first_message
|
||||
|
||||
def process_2(self, challenge):
|
||||
def process_2(self, challenge: bytes_) -> bytes_:
|
||||
self.step = 2
|
||||
|
||||
data = self.parse(challenge)
|
||||
|
@ -304,7 +306,7 @@ class SCRAM(Mech):
|
|||
|
||||
return client_final_message
|
||||
|
||||
def process_3(self, challenge):
|
||||
def process_3(self, challenge: bytes_) -> bytes_:
|
||||
data = self.parse(challenge)
|
||||
verifier = data.get(b'v', None)
|
||||
error = data.get(b'e', 'Unknown error')
|
||||
|
@ -345,17 +347,16 @@ class DIGEST(Mech):
|
|||
self.cnonce = b''
|
||||
self.nonce_count = 1
|
||||
|
||||
def parse(self, challenge=b''):
|
||||
data = {}
|
||||
def parse(self, challenge: bytes_ = b''):
|
||||
data: Dict[str, bytes_] = {}
|
||||
var_name = b''
|
||||
var_value = b''
|
||||
|
||||
# States: var, new_var, end, quote, escaped_quote
|
||||
state = 'var'
|
||||
|
||||
|
||||
for char in challenge:
|
||||
char = bytes([char])
|
||||
for char_int in challenge:
|
||||
char = bytes_([char_int])
|
||||
|
||||
if state == 'var':
|
||||
if char.isspace():
|
||||
|
@ -401,14 +402,14 @@ class DIGEST(Mech):
|
|||
state = 'var'
|
||||
return data
|
||||
|
||||
def MAC(self, key, seq, msg):
|
||||
def MAC(self, key: bytes_, seq: int, msg: bytes_) -> bytes_:
|
||||
mac = hmac.HMAC(key=key, digestmod=self.hash)
|
||||
seqnum = num_to_bytes(seq)
|
||||
mac.update(seqnum)
|
||||
mac.update(msg)
|
||||
return mac.digest()[:10] + b'\x00\x01' + seqnum
|
||||
|
||||
def A1(self):
|
||||
def A1(self) -> bytes_:
|
||||
username = self.credentials['username']
|
||||
password = self.credentials['password']
|
||||
authzid = self.credentials['authzid']
|
||||
|
@ -423,13 +424,13 @@ class DIGEST(Mech):
|
|||
|
||||
return bytes(a1)
|
||||
|
||||
def A2(self, prefix=b''):
|
||||
def A2(self, prefix: bytes_ = b'') -> bytes_:
|
||||
a2 = prefix + b':' + self.digest_uri()
|
||||
if self.qop in (b'auth-int', b'auth-conf'):
|
||||
a2 += b':00000000000000000000000000000000'
|
||||
return bytes(a2)
|
||||
|
||||
def response(self, prefix=b''):
|
||||
def response(self, prefix: bytes_ = b'') -> bytes_:
|
||||
nc = bytes('%08x' % self.nonce_count)
|
||||
|
||||
a1 = bytes(self.hash(self.A1()).hexdigest().lower())
|
||||
|
@ -439,7 +440,7 @@ class DIGEST(Mech):
|
|||
|
||||
return bytes(self.hash(a1 + b':' + s).hexdigest().lower())
|
||||
|
||||
def digest_uri(self):
|
||||
def digest_uri(self) -> bytes_:
|
||||
serv_type = self.credentials['service']
|
||||
serv_name = self.credentials['service-name']
|
||||
host = self.credentials['host']
|
||||
|
@ -449,7 +450,7 @@ class DIGEST(Mech):
|
|||
uri += b'/' + serv_name
|
||||
return uri
|
||||
|
||||
def respond(self):
|
||||
def respond(self) -> bytes_:
|
||||
data = {
|
||||
'username': quote(self.credentials['username']),
|
||||
'authzid': quote(self.credentials['authzid']),
|
||||
|
@ -469,7 +470,7 @@ class DIGEST(Mech):
|
|||
resp += b',' + bytes(key) + b'=' + bytes(value)
|
||||
return resp[1:]
|
||||
|
||||
def process(self, challenge=b''):
|
||||
def process(self, challenge: bytes_ = b'') -> Optional[bytes_]:
|
||||
if not challenge:
|
||||
if self.cnonce and self.nonce and self.nonce_count and self.qop:
|
||||
self.nonce_count += 1
|
||||
|
@ -480,6 +481,7 @@ class DIGEST(Mech):
|
|||
if 'rspauth' in data:
|
||||
if data['rspauth'] != self.response():
|
||||
raise SASLMutualAuthFailed()
|
||||
return None
|
||||
else:
|
||||
self.nonce_count = 1
|
||||
self.cnonce = bytes('%s' % random.random())[2:]
|
||||
|
|
Loading…
Reference in a new issue