Allow Xmlstream.ca_certs to be an iterable

Signed-off-by: Maxime “pep” Buquet <pep@bouah.net>
This commit is contained in:
Maxime “pep” Buquet 2021-12-28 19:50:20 +01:00
parent 834ea8ed74
commit d733c54518
Signed by: pep
GPG key ID: DEDA74AEECA9D0F2

View file

@ -15,6 +15,7 @@ from typing import (
Coroutine,
Callable,
Iterator,
Iterable,
List,
Optional,
Set,
@ -33,7 +34,6 @@ import socket as Socket
import ssl
import weakref
import uuid
from pathlib import Path
from contextlib import contextmanager
import xml.etree.ElementTree as ET
@ -47,6 +47,7 @@ from asyncio import (
iscoroutinefunction,
wait,
)
from pathlib import Path
from slixmpp.types import FilterString
from slixmpp.xmlstream.tostring import tostring
@ -75,6 +76,15 @@ class NotConnectedError(Exception):
"""
class InvalidCABundle(Exception):
"""
Exception raised when the CA Bundle file hasn't been found.
"""
def __init__(self, path: Optional[Path]):
self.path = path
_T = TypeVar('_T', str, ElementBase, StanzaBase)
@ -162,7 +172,7 @@ class XMLStream(asyncio.BaseProtocol):
#:
#: On Mac OS X, certificates in the system keyring will
#: be consulted, even if they are not in the provided file.
ca_certs: Optional[Path]
ca_certs: Optional[Union[Path, Iterable[Path]]]
#: Path to a file containing a client certificate to use for
#: authenticating via SASL EXTERNAL. If set, there must also
@ -760,8 +770,20 @@ class XMLStream(asyncio.BaseProtocol):
log.debug('Loaded cert file %s and key file %s',
self.certfile, self.keyfile)
if self.ca_certs is not None:
ca_cert: Optional[Path] = None
if isinstance(self.ca_certs, Path):
if self.ca_certs.is_file():
ca_cert = self.ca_certs
else:
for bundle in self.ca_certs:
if bundle.is_file():
ca_cert = bundle
break
if ca_cert is None:
raise InvalidCABundle(ca_cert)
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
self.ssl_context.load_verify_locations(cafile=self.ca_certs)
self.ssl_context.load_verify_locations(cafile=ca_cert)
return self.ssl_context