Allow Xmlstream.ca_certs to be an iterable
Signed-off-by: Maxime “pep” Buquet <pep@bouah.net>
This commit is contained in:
parent
834ea8ed74
commit
d733c54518
1 changed files with 25 additions and 3 deletions
|
@ -15,6 +15,7 @@ from typing import (
|
|||
Coroutine,
|
||||
Callable,
|
||||
Iterator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
|
@ -33,7 +34,6 @@ import socket as Socket
|
|||
import ssl
|
||||
import weakref
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from contextlib import contextmanager
|
||||
import xml.etree.ElementTree as ET
|
||||
|
@ -47,6 +47,7 @@ from asyncio import (
|
|||
iscoroutinefunction,
|
||||
wait,
|
||||
)
|
||||
from pathlib import Path
|
||||
|
||||
from slixmpp.types import FilterString
|
||||
from slixmpp.xmlstream.tostring import tostring
|
||||
|
@ -75,6 +76,15 @@ class NotConnectedError(Exception):
|
|||
"""
|
||||
|
||||
|
||||
class InvalidCABundle(Exception):
|
||||
"""
|
||||
Exception raised when the CA Bundle file hasn't been found.
|
||||
"""
|
||||
|
||||
def __init__(self, path: Optional[Path]):
|
||||
self.path = path
|
||||
|
||||
|
||||
_T = TypeVar('_T', str, ElementBase, StanzaBase)
|
||||
|
||||
|
||||
|
@ -162,7 +172,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
#:
|
||||
#: On Mac OS X, certificates in the system keyring will
|
||||
#: be consulted, even if they are not in the provided file.
|
||||
ca_certs: Optional[Path]
|
||||
ca_certs: Optional[Union[Path, Iterable[Path]]]
|
||||
|
||||
#: Path to a file containing a client certificate to use for
|
||||
#: authenticating via SASL EXTERNAL. If set, there must also
|
||||
|
@ -760,8 +770,20 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
log.debug('Loaded cert file %s and key file %s',
|
||||
self.certfile, self.keyfile)
|
||||
if self.ca_certs is not None:
|
||||
ca_cert: Optional[Path] = None
|
||||
if isinstance(self.ca_certs, Path):
|
||||
if self.ca_certs.is_file():
|
||||
ca_cert = self.ca_certs
|
||||
else:
|
||||
for bundle in self.ca_certs:
|
||||
if bundle.is_file():
|
||||
ca_cert = bundle
|
||||
break
|
||||
if ca_cert is None:
|
||||
raise InvalidCABundle(ca_cert)
|
||||
|
||||
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
|
||||
self.ssl_context.load_verify_locations(cafile=self.ca_certs)
|
||||
self.ssl_context.load_verify_locations(cafile=ca_cert)
|
||||
|
||||
return self.ssl_context
|
||||
|
||||
|
|
Loading…
Reference in a new issue