From f49311ef9ee76c2e4cce402e377867eff308aca0 Mon Sep 17 00:00:00 2001 From: Lance Stout Date: Tue, 22 May 2012 03:56:06 -0700 Subject: [PATCH] Add better certificate handling. Certificate host names are now matched (using DNS, SRV, XMPPAddr, and Common Name), along with expiration check. Scheduled event to reset the stream once the server's cert expires. Handle invalid cert trust chains gracefully now. --- setup.py | 2 +- sleekxmpp/basexmpp.py | 1 + sleekxmpp/clientxmpp.py | 2 + sleekxmpp/componentxmpp.py | 3 + sleekxmpp/xmlstream/cert.py | 173 +++++++++++++++++++++++++++++++ sleekxmpp/xmlstream/xmlstream.py | 77 +++++++++++--- 6 files changed, 245 insertions(+), 13 deletions(-) create mode 100644 sleekxmpp/xmlstream/cert.py diff --git a/setup.py b/setup.py index 4bb77bd9..de89021b 100755 --- a/setup.py +++ b/setup.py @@ -112,7 +112,7 @@ setup( license = 'MIT', platforms = [ 'any' ], packages = packages, - requires = [ 'dnspython' ], + requires = [ 'dnspython', 'pyasn1', 'pyasn1_modules' ], classifiers = CLASSIFIERS, cmdclass = {'test': TestCommand} ) diff --git a/sleekxmpp/basexmpp.py b/sleekxmpp/basexmpp.py index 1c835460..63e3339c 100644 --- a/sleekxmpp/basexmpp.py +++ b/sleekxmpp/basexmpp.py @@ -68,6 +68,7 @@ class BaseXMPP(XMLStream): #: The JabberID (JID) used by this connection. self.boundjid = JID(jid) + self._expected_server_name = self.boundjid.host #: A dictionary mapping plugin names to plugins. self.plugin = PluginManager(self) diff --git a/sleekxmpp/clientxmpp.py b/sleekxmpp/clientxmpp.py index e77e6ce2..94ced031 100644 --- a/sleekxmpp/clientxmpp.py +++ b/sleekxmpp/clientxmpp.py @@ -149,6 +149,8 @@ class ClientXMPP(BaseXMPP): address = (self.boundjid.host, 5222) self.dns_service = 'xmpp-client' + self._expected_server_name = self.boundjid.host + return XMLStream.connect(self, address[0], address[1], use_tls=use_tls, use_ssl=use_ssl, reattempt=reattempt) diff --git a/sleekxmpp/componentxmpp.py b/sleekxmpp/componentxmpp.py index df23c2f6..348a08e0 100644 --- a/sleekxmpp/componentxmpp.py +++ b/sleekxmpp/componentxmpp.py @@ -101,6 +101,9 @@ class ComponentXMPP(BaseXMPP): host = self.server_host if port is None: port = self.server_port + + self.server_name = self.boundjid.host + log.debug("Connecting to %s:%s", host, port) return XMLStream.connect(self, host=host, port=port, use_ssl=use_ssl, diff --git a/sleekxmpp/xmlstream/cert.py b/sleekxmpp/xmlstream/cert.py new file mode 100644 index 00000000..0f58d4ed --- /dev/null +++ b/sleekxmpp/xmlstream/cert.py @@ -0,0 +1,173 @@ +import logging +from datetime import datetime, timedelta + + +try: + from pyasn1.codec.der import decoder, encoder + from pyasn1.type.univ import Any, ObjectIdentifier, OctetString + from pyasn1.type.char import BMPString, IA5String, UTF8String + from pyasn1.type.useful import GeneralizedTime + from pyasn1_modules.rfc2459 import Certificate, DirectoryString, SubjectAltName, GeneralNames, GeneralName + from pyasn1_modules.rfc2459 import id_ce_subjectAltName as SUBJECT_ALT_NAME + from pyasn1_modules.rfc2459 import id_at_commonName as COMMON_NAME + + + XMPP_ADDR = ObjectIdentifier('1.3.6.1.5.5.7.8.5') + SRV_NAME = ObjectIdentifier('1.3.6.1.5.5.7.8.7') + + HAVE_PYASN1 = True +except ImportError: + HAVE_PYASN1 = False + + +log = logging.getLogger(__name__) + + +class CertificateError(Exception): + pass + + +def decode_str(data): + encoding = 'utf-16-be' if isinstance(data, BMPString) else 'utf-8' + return bytes(data).decode(encoding) + + +def extract_names(raw_cert): + results = {'CN': set(), + 'DNS': set(), + 'SRV': set(), + 'URI': set(), + 'XMPPAddr': set()} + + cert = decoder.decode(raw_cert, asn1Spec=Certificate())[0] + tbs = cert.getComponentByName('tbsCertificate') + subject = tbs.getComponentByName('subject') + extensions = tbs.getComponentByName('extensions') + + # Extract the CommonName(s) from the cert. + for rdnss in subject: + for rdns in rdnss: + for name in rdns: + oid = name.getComponentByName('type') + value = name.getComponentByName('value') + + if oid != COMMON_NAME: + continue + + value = decoder.decode(value, asn1Spec=DirectoryString())[0] + value = decode_str(value.getComponent()) + results['CN'].add(value) + + # Extract the Subject Alternate Names (DNS, SRV, URI, XMPPAddr) + for extension in extensions: + oid = extension.getComponentByName('extnID') + if oid != SUBJECT_ALT_NAME: + continue + + value = decoder.decode(extension.getComponentByName('extnValue'), + asn1Spec=OctetString())[0] + sa_names = decoder.decode(value, asn1Spec=SubjectAltName())[0] + for name in sa_names: + name_type = name.getName() + if name_type == 'dNSName': + results['DNS'].add(decode_str(name.getComponent())) + if name_type == 'uniformResourceIdentifier': + value = decode_str(name.getComponent()) + if value.startswith('xmpp:'): + results['URI'].add(value[5:]) + elif name_type == 'otherName': + name = name.getComponent() + + oid = name.getComponentByName('type-id') + value = name.getComponentByName('value') + + if oid == XMPP_ADDR: + value = decoder.decode(value, asn1Spec=UTF8String())[0] + results['XMPPAddr'].add(decode_str(value)) + elif oid == SRV_NAME: + value = decoder.decode(value, asn1Spec=IA5String())[0] + results['SRV'].add(decode_str(value)) + + return results + + +def extract_dates(raw_cert): + if not HAVE_PYASN1: + log.warning("Could not find pyasn1 module. " + \ + "SSL certificate expiration COULD NOT BE VERIFIED.") + return None, None + + cert = decoder.decode(raw_cert, asn1Spec=Certificate())[0] + tbs = cert.getComponentByName('tbsCertificate') + validity = tbs.getComponentByName('validity') + + not_before = validity.getComponentByName('notBefore') + not_before = str(not_before.getComponent()) + + not_after = validity.getComponentByName('notAfter') + not_after = str(not_after.getComponent()) + + if isinstance(not_before, GeneralizedTime): + not_before = datetime.strptime(not_before, '%Y%m%d%H%M%SZ') + else: + not_before = datetime.strptime(not_before, '%y%m%d%H%M%SZ') + + if isinstance(not_after, GeneralizedTime): + not_after = datetime.strptime(not_after, '%Y%m%d%H%M%SZ') + else: + not_after = datetime.strptime(not_after, '%y%m%d%H%M%SZ') + + return not_before, not_after + + +def get_ttl(raw_cert): + not_before, not_after = extract_dates(raw_cert) + if not_after is None: + return None + return not_after - datetime.utcnow() + + +def verify(expected, raw_cert): + if not HAVE_PYASN1: + log.warning("Could not find pyasn1 module. " + \ + "SSL certificate COULD NOT BE VERIFIED.") + return + + not_before, not_after = extract_dates(raw_cert) + cert_names = extract_names(raw_cert) + + now = datetime.utcnow() + + if not_before > now: + raise CertificateError( + 'Certificate has not entered its valid date range.') + + if not_after <= now: + raise CertificateError( + 'Certificate has expired.') + + expected_wild = expected[expected.index('.'):] + expected_srv = '_xmpp-client.%s' % expected + + for name in cert_names['XMPPAddr']: + if name == expected: + return True + for name in cert_names['SRV']: + if name == expected_srv or name == expected: + return True + for name in cert_names['DNS']: + if name == expected: + return True + if name.startswith('*'): + name_wild = name[name.index('.'):] + if expected_wild == name_wild: + return True + for name in cert_names['URI']: + if name == expected: + return True + for name in cert_names['CN']: + if name == expected: + return True + + raise CertificateError( + 'Could not match certficate against hostname: %s' % expected) diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index daa1af1a..56177556 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -35,7 +35,7 @@ from xml.parsers.expat import ExpatError import sleekxmpp from sleekxmpp.thirdparty.statemachine import StateMachine -from sleekxmpp.xmlstream import Scheduler, tostring +from sleekxmpp.xmlstream import Scheduler, tostring, cert from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase from sleekxmpp.xmlstream.handler import Waiter, XMLCallback from sleekxmpp.xmlstream.matcher import MatchXMLMask @@ -181,6 +181,9 @@ class XMLStream(object): #: The domain to try when querying DNS records. self.default_domain = '' + + #: The expected name of the server, for validation. + self._expected_server_name = '' #: The desired, or actual, address of the connected server. self.address = (host, int(port)) @@ -313,8 +316,9 @@ class XMLStream(object): self.dns_service = None self.add_event_handler('connected', self._handle_connected) - self.add_event_handler('session_start', self._start_keepalive) self.add_event_handler('disconnected', self._end_keepalive) + self.add_event_handler('session_start', self._start_keepalive) + self.add_event_handler('session_start', self._cert_expiration) def use_signals(self, signals=None): """Register signal handlers for ``SIGHUP`` and ``SIGTERM``. @@ -500,10 +504,17 @@ class XMLStream(object): self.socket.connect(self.address) if self.use_ssl and self.ssl_support: - cert = self.socket.getpeercert(binary_form=True) - cert = ssl.DER_cert_to_PEM_cert(cert) - log.debug('CERT: %s', cert) - self.event('ssl_cert', cert, direct=True) + self._der_cert = self.socket.getpeercert(binary_form=True) + pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert) + log.debug('CERT: %s', pem_cert) + + self.event('ssl_cert', pem_cert, direct=True) + try: + cert.verify(self._expected_server_name, self._der_cert) + except cert.CertificateError as err: + log.error(err.message) + self.event('ssl_invalid_cert', cert, direct=True) + self.disconnect(send_close=False) self.set_socket(self.socket, ignore=True) #this event is where you should set your application state @@ -767,12 +778,27 @@ class XMLStream(object): self.socket.socket = ssl_socket else: self.socket = ssl_socket - self.socket.do_handshake() - cert = self.socket.getpeercert(binary_form=True) - cert = ssl.DER_cert_to_PEM_cert(cert) - log.debug('CERT: %s', cert) - self.event('ssl_cert', cert, direct=True) + try: + self.socket.do_handshake() + except: + log.error('CERT: Invalid certificate trust chain.') + self.event('ssl_invalid_chain', direct=True) + self.disconnect(self.auto_reconnect, send_close=False) + return False + + self._der_cert = self.socket.getpeercert(binary_form=True) + pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert) + log.debug('CERT: %s', pem_cert) + self.event('ssl_cert', pem_cert, direct=True) + + try: + cert.verify(self._expected_server_name, self._der_cert) + except cert.CertificateError as err: + log.error(err.message) + self.event('ssl_invalid_cert', cert, direct=True) + if not self.event_handled('ssl_invalid_cert'): + self.disconnect(self.auto_reconnect, send_close=False) self.set_socket(self.socket) return True @@ -780,6 +806,26 @@ class XMLStream(object): log.warning("Tried to enable TLS, but ssl module not found.") return False + def _cert_expiration(self, event): + """Schedule an event for when the TLS certificate expires.""" + + def restart(): + log.warn("The server certificate has expired. Restarting.") + self.reconnect() + + cert_ttl = cert.get_ttl(self._der_cert) + if cert_ttl is None: + return + + if cert_ttl.days < 0: + log.warn('CERT: Certificate has expired.') + restart() + + log.info('CERT: Time until certificate expiration: %s' % cert_ttl) + self.schedule('Certificate Expiration', + cert_ttl.seconds, + restart) + def _start_keepalive(self, event): """Begin sending whitespace periodically to keep the connection alive. @@ -1298,9 +1344,16 @@ class XMLStream(object): except (Socket.error, ssl.SSLError) as serr: self.event('socket_error', serr, direct=True) log.error('Socket Error #%s: %s', serr.errno, serr.strerror) + except ValueError as e: + msg = e.message if hasattr(e, 'message') else e.args[0] + + if 'I/O operation on closed file' in msg: + log.error('Can not read from closed socket.') + else: + self.exception(e) except Exception as e: if not self.stop.is_set(): - log.exception('Connection error.') + log.error('Connection error.') self.exception(e) if not shutdown and not self.stop.is_set() \