diff --git a/setup.py b/setup.py index 4bb77bd9..de89021b 100755 --- a/setup.py +++ b/setup.py @@ -112,7 +112,7 @@ setup( license = 'MIT', platforms = [ 'any' ], packages = packages, - requires = [ 'dnspython' ], + requires = [ 'dnspython', 'pyasn1', 'pyasn1_modules' ], classifiers = CLASSIFIERS, cmdclass = {'test': TestCommand} ) diff --git a/sleekxmpp/basexmpp.py b/sleekxmpp/basexmpp.py index 1c835460..63e3339c 100644 --- a/sleekxmpp/basexmpp.py +++ b/sleekxmpp/basexmpp.py @@ -68,6 +68,7 @@ class BaseXMPP(XMLStream): #: The JabberID (JID) used by this connection. self.boundjid = JID(jid) + self._expected_server_name = self.boundjid.host #: A dictionary mapping plugin names to plugins. self.plugin = PluginManager(self) diff --git a/sleekxmpp/clientxmpp.py b/sleekxmpp/clientxmpp.py index e77e6ce2..94ced031 100644 --- a/sleekxmpp/clientxmpp.py +++ b/sleekxmpp/clientxmpp.py @@ -149,6 +149,8 @@ class ClientXMPP(BaseXMPP): address = (self.boundjid.host, 5222) self.dns_service = 'xmpp-client' + self._expected_server_name = self.boundjid.host + return XMLStream.connect(self, address[0], address[1], use_tls=use_tls, use_ssl=use_ssl, reattempt=reattempt) diff --git a/sleekxmpp/componentxmpp.py b/sleekxmpp/componentxmpp.py index df23c2f6..348a08e0 100644 --- a/sleekxmpp/componentxmpp.py +++ b/sleekxmpp/componentxmpp.py @@ -101,6 +101,9 @@ class ComponentXMPP(BaseXMPP): host = self.server_host if port is None: port = self.server_port + + self.server_name = self.boundjid.host + log.debug("Connecting to %s:%s", host, port) return XMLStream.connect(self, host=host, port=port, use_ssl=use_ssl, diff --git a/sleekxmpp/xmlstream/cert.py b/sleekxmpp/xmlstream/cert.py new file mode 100644 index 00000000..0f58d4ed --- /dev/null +++ b/sleekxmpp/xmlstream/cert.py @@ -0,0 +1,173 @@ +import logging +from datetime import datetime, timedelta + + +try: + from pyasn1.codec.der import decoder, encoder + from pyasn1.type.univ import Any, ObjectIdentifier, OctetString + from pyasn1.type.char import BMPString, IA5String, UTF8String + from pyasn1.type.useful import GeneralizedTime + from pyasn1_modules.rfc2459 import Certificate, DirectoryString, SubjectAltName, GeneralNames, GeneralName + from pyasn1_modules.rfc2459 import id_ce_subjectAltName as SUBJECT_ALT_NAME + from pyasn1_modules.rfc2459 import id_at_commonName as COMMON_NAME + + + XMPP_ADDR = ObjectIdentifier('1.3.6.1.5.5.7.8.5') + SRV_NAME = ObjectIdentifier('1.3.6.1.5.5.7.8.7') + + HAVE_PYASN1 = True +except ImportError: + HAVE_PYASN1 = False + + +log = logging.getLogger(__name__) + + +class CertificateError(Exception): + pass + + +def decode_str(data): + encoding = 'utf-16-be' if isinstance(data, BMPString) else 'utf-8' + return bytes(data).decode(encoding) + + +def extract_names(raw_cert): + results = {'CN': set(), + 'DNS': set(), + 'SRV': set(), + 'URI': set(), + 'XMPPAddr': set()} + + cert = decoder.decode(raw_cert, asn1Spec=Certificate())[0] + tbs = cert.getComponentByName('tbsCertificate') + subject = tbs.getComponentByName('subject') + extensions = tbs.getComponentByName('extensions') + + # Extract the CommonName(s) from the cert. + for rdnss in subject: + for rdns in rdnss: + for name in rdns: + oid = name.getComponentByName('type') + value = name.getComponentByName('value') + + if oid != COMMON_NAME: + continue + + value = decoder.decode(value, asn1Spec=DirectoryString())[0] + value = decode_str(value.getComponent()) + results['CN'].add(value) + + # Extract the Subject Alternate Names (DNS, SRV, URI, XMPPAddr) + for extension in extensions: + oid = extension.getComponentByName('extnID') + if oid != SUBJECT_ALT_NAME: + continue + + value = decoder.decode(extension.getComponentByName('extnValue'), + asn1Spec=OctetString())[0] + sa_names = decoder.decode(value, asn1Spec=SubjectAltName())[0] + for name in sa_names: + name_type = name.getName() + if name_type == 'dNSName': + results['DNS'].add(decode_str(name.getComponent())) + if name_type == 'uniformResourceIdentifier': + value = decode_str(name.getComponent()) + if value.startswith('xmpp:'): + results['URI'].add(value[5:]) + elif name_type == 'otherName': + name = name.getComponent() + + oid = name.getComponentByName('type-id') + value = name.getComponentByName('value') + + if oid == XMPP_ADDR: + value = decoder.decode(value, asn1Spec=UTF8String())[0] + results['XMPPAddr'].add(decode_str(value)) + elif oid == SRV_NAME: + value = decoder.decode(value, asn1Spec=IA5String())[0] + results['SRV'].add(decode_str(value)) + + return results + + +def extract_dates(raw_cert): + if not HAVE_PYASN1: + log.warning("Could not find pyasn1 module. " + \ + "SSL certificate expiration COULD NOT BE VERIFIED.") + return None, None + + cert = decoder.decode(raw_cert, asn1Spec=Certificate())[0] + tbs = cert.getComponentByName('tbsCertificate') + validity = tbs.getComponentByName('validity') + + not_before = validity.getComponentByName('notBefore') + not_before = str(not_before.getComponent()) + + not_after = validity.getComponentByName('notAfter') + not_after = str(not_after.getComponent()) + + if isinstance(not_before, GeneralizedTime): + not_before = datetime.strptime(not_before, '%Y%m%d%H%M%SZ') + else: + not_before = datetime.strptime(not_before, '%y%m%d%H%M%SZ') + + if isinstance(not_after, GeneralizedTime): + not_after = datetime.strptime(not_after, '%Y%m%d%H%M%SZ') + else: + not_after = datetime.strptime(not_after, '%y%m%d%H%M%SZ') + + return not_before, not_after + + +def get_ttl(raw_cert): + not_before, not_after = extract_dates(raw_cert) + if not_after is None: + return None + return not_after - datetime.utcnow() + + +def verify(expected, raw_cert): + if not HAVE_PYASN1: + log.warning("Could not find pyasn1 module. " + \ + "SSL certificate COULD NOT BE VERIFIED.") + return + + not_before, not_after = extract_dates(raw_cert) + cert_names = extract_names(raw_cert) + + now = datetime.utcnow() + + if not_before > now: + raise CertificateError( + 'Certificate has not entered its valid date range.') + + if not_after <= now: + raise CertificateError( + 'Certificate has expired.') + + expected_wild = expected[expected.index('.'):] + expected_srv = '_xmpp-client.%s' % expected + + for name in cert_names['XMPPAddr']: + if name == expected: + return True + for name in cert_names['SRV']: + if name == expected_srv or name == expected: + return True + for name in cert_names['DNS']: + if name == expected: + return True + if name.startswith('*'): + name_wild = name[name.index('.'):] + if expected_wild == name_wild: + return True + for name in cert_names['URI']: + if name == expected: + return True + for name in cert_names['CN']: + if name == expected: + return True + + raise CertificateError( + 'Could not match certficate against hostname: %s' % expected) diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index daa1af1a..56177556 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -35,7 +35,7 @@ from xml.parsers.expat import ExpatError import sleekxmpp from sleekxmpp.thirdparty.statemachine import StateMachine -from sleekxmpp.xmlstream import Scheduler, tostring +from sleekxmpp.xmlstream import Scheduler, tostring, cert from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase from sleekxmpp.xmlstream.handler import Waiter, XMLCallback from sleekxmpp.xmlstream.matcher import MatchXMLMask @@ -181,6 +181,9 @@ class XMLStream(object): #: The domain to try when querying DNS records. self.default_domain = '' + + #: The expected name of the server, for validation. + self._expected_server_name = '' #: The desired, or actual, address of the connected server. self.address = (host, int(port)) @@ -313,8 +316,9 @@ class XMLStream(object): self.dns_service = None self.add_event_handler('connected', self._handle_connected) - self.add_event_handler('session_start', self._start_keepalive) self.add_event_handler('disconnected', self._end_keepalive) + self.add_event_handler('session_start', self._start_keepalive) + self.add_event_handler('session_start', self._cert_expiration) def use_signals(self, signals=None): """Register signal handlers for ``SIGHUP`` and ``SIGTERM``. @@ -500,10 +504,17 @@ class XMLStream(object): self.socket.connect(self.address) if self.use_ssl and self.ssl_support: - cert = self.socket.getpeercert(binary_form=True) - cert = ssl.DER_cert_to_PEM_cert(cert) - log.debug('CERT: %s', cert) - self.event('ssl_cert', cert, direct=True) + self._der_cert = self.socket.getpeercert(binary_form=True) + pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert) + log.debug('CERT: %s', pem_cert) + + self.event('ssl_cert', pem_cert, direct=True) + try: + cert.verify(self._expected_server_name, self._der_cert) + except cert.CertificateError as err: + log.error(err.message) + self.event('ssl_invalid_cert', cert, direct=True) + self.disconnect(send_close=False) self.set_socket(self.socket, ignore=True) #this event is where you should set your application state @@ -767,12 +778,27 @@ class XMLStream(object): self.socket.socket = ssl_socket else: self.socket = ssl_socket - self.socket.do_handshake() - cert = self.socket.getpeercert(binary_form=True) - cert = ssl.DER_cert_to_PEM_cert(cert) - log.debug('CERT: %s', cert) - self.event('ssl_cert', cert, direct=True) + try: + self.socket.do_handshake() + except: + log.error('CERT: Invalid certificate trust chain.') + self.event('ssl_invalid_chain', direct=True) + self.disconnect(self.auto_reconnect, send_close=False) + return False + + self._der_cert = self.socket.getpeercert(binary_form=True) + pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert) + log.debug('CERT: %s', pem_cert) + self.event('ssl_cert', pem_cert, direct=True) + + try: + cert.verify(self._expected_server_name, self._der_cert) + except cert.CertificateError as err: + log.error(err.message) + self.event('ssl_invalid_cert', cert, direct=True) + if not self.event_handled('ssl_invalid_cert'): + self.disconnect(self.auto_reconnect, send_close=False) self.set_socket(self.socket) return True @@ -780,6 +806,26 @@ class XMLStream(object): log.warning("Tried to enable TLS, but ssl module not found.") return False + def _cert_expiration(self, event): + """Schedule an event for when the TLS certificate expires.""" + + def restart(): + log.warn("The server certificate has expired. Restarting.") + self.reconnect() + + cert_ttl = cert.get_ttl(self._der_cert) + if cert_ttl is None: + return + + if cert_ttl.days < 0: + log.warn('CERT: Certificate has expired.') + restart() + + log.info('CERT: Time until certificate expiration: %s' % cert_ttl) + self.schedule('Certificate Expiration', + cert_ttl.seconds, + restart) + def _start_keepalive(self, event): """Begin sending whitespace periodically to keep the connection alive. @@ -1298,9 +1344,16 @@ class XMLStream(object): except (Socket.error, ssl.SSLError) as serr: self.event('socket_error', serr, direct=True) log.error('Socket Error #%s: %s', serr.errno, serr.strerror) + except ValueError as e: + msg = e.message if hasattr(e, 'message') else e.args[0] + + if 'I/O operation on closed file' in msg: + log.error('Can not read from closed socket.') + else: + self.exception(e) except Exception as e: if not self.stop.is_set(): - log.exception('Connection error.') + log.error('Connection error.') self.exception(e) if not shutdown and not self.stop.is_set() \