Add better certificate handling.

Certificate host names are now matched (using DNS, SRV, XMPPAddr, and
Common Name), along with expiration check.

Scheduled event to reset the stream once the server's cert expires.

Handle invalid cert trust chains gracefully now.
This commit is contained in:
Lance Stout 2012-05-22 03:56:06 -07:00
parent 678e529efc
commit f49311ef9e
6 changed files with 245 additions and 13 deletions

View file

@ -112,7 +112,7 @@ setup(
license = 'MIT',
platforms = [ 'any' ],
packages = packages,
requires = [ 'dnspython' ],
requires = [ 'dnspython', 'pyasn1', 'pyasn1_modules' ],
classifiers = CLASSIFIERS,
cmdclass = {'test': TestCommand}
)

View file

@ -68,6 +68,7 @@ class BaseXMPP(XMLStream):
#: The JabberID (JID) used by this connection.
self.boundjid = JID(jid)
self._expected_server_name = self.boundjid.host
#: A dictionary mapping plugin names to plugins.
self.plugin = PluginManager(self)

View file

@ -149,6 +149,8 @@ class ClientXMPP(BaseXMPP):
address = (self.boundjid.host, 5222)
self.dns_service = 'xmpp-client'
self._expected_server_name = self.boundjid.host
return XMLStream.connect(self, address[0], address[1],
use_tls=use_tls, use_ssl=use_ssl,
reattempt=reattempt)

View file

@ -101,6 +101,9 @@ class ComponentXMPP(BaseXMPP):
host = self.server_host
if port is None:
port = self.server_port
self.server_name = self.boundjid.host
log.debug("Connecting to %s:%s", host, port)
return XMLStream.connect(self, host=host, port=port,
use_ssl=use_ssl,

173
sleekxmpp/xmlstream/cert.py Normal file
View file

@ -0,0 +1,173 @@
import logging
from datetime import datetime, timedelta
try:
from pyasn1.codec.der import decoder, encoder
from pyasn1.type.univ import Any, ObjectIdentifier, OctetString
from pyasn1.type.char import BMPString, IA5String, UTF8String
from pyasn1.type.useful import GeneralizedTime
from pyasn1_modules.rfc2459 import Certificate, DirectoryString, SubjectAltName, GeneralNames, GeneralName
from pyasn1_modules.rfc2459 import id_ce_subjectAltName as SUBJECT_ALT_NAME
from pyasn1_modules.rfc2459 import id_at_commonName as COMMON_NAME
XMPP_ADDR = ObjectIdentifier('1.3.6.1.5.5.7.8.5')
SRV_NAME = ObjectIdentifier('1.3.6.1.5.5.7.8.7')
HAVE_PYASN1 = True
except ImportError:
HAVE_PYASN1 = False
log = logging.getLogger(__name__)
class CertificateError(Exception):
pass
def decode_str(data):
encoding = 'utf-16-be' if isinstance(data, BMPString) else 'utf-8'
return bytes(data).decode(encoding)
def extract_names(raw_cert):
results = {'CN': set(),
'DNS': set(),
'SRV': set(),
'URI': set(),
'XMPPAddr': set()}
cert = decoder.decode(raw_cert, asn1Spec=Certificate())[0]
tbs = cert.getComponentByName('tbsCertificate')
subject = tbs.getComponentByName('subject')
extensions = tbs.getComponentByName('extensions')
# Extract the CommonName(s) from the cert.
for rdnss in subject:
for rdns in rdnss:
for name in rdns:
oid = name.getComponentByName('type')
value = name.getComponentByName('value')
if oid != COMMON_NAME:
continue
value = decoder.decode(value, asn1Spec=DirectoryString())[0]
value = decode_str(value.getComponent())
results['CN'].add(value)
# Extract the Subject Alternate Names (DNS, SRV, URI, XMPPAddr)
for extension in extensions:
oid = extension.getComponentByName('extnID')
if oid != SUBJECT_ALT_NAME:
continue
value = decoder.decode(extension.getComponentByName('extnValue'),
asn1Spec=OctetString())[0]
sa_names = decoder.decode(value, asn1Spec=SubjectAltName())[0]
for name in sa_names:
name_type = name.getName()
if name_type == 'dNSName':
results['DNS'].add(decode_str(name.getComponent()))
if name_type == 'uniformResourceIdentifier':
value = decode_str(name.getComponent())
if value.startswith('xmpp:'):
results['URI'].add(value[5:])
elif name_type == 'otherName':
name = name.getComponent()
oid = name.getComponentByName('type-id')
value = name.getComponentByName('value')
if oid == XMPP_ADDR:
value = decoder.decode(value, asn1Spec=UTF8String())[0]
results['XMPPAddr'].add(decode_str(value))
elif oid == SRV_NAME:
value = decoder.decode(value, asn1Spec=IA5String())[0]
results['SRV'].add(decode_str(value))
return results
def extract_dates(raw_cert):
if not HAVE_PYASN1:
log.warning("Could not find pyasn1 module. " + \
"SSL certificate expiration COULD NOT BE VERIFIED.")
return None, None
cert = decoder.decode(raw_cert, asn1Spec=Certificate())[0]
tbs = cert.getComponentByName('tbsCertificate')
validity = tbs.getComponentByName('validity')
not_before = validity.getComponentByName('notBefore')
not_before = str(not_before.getComponent())
not_after = validity.getComponentByName('notAfter')
not_after = str(not_after.getComponent())
if isinstance(not_before, GeneralizedTime):
not_before = datetime.strptime(not_before, '%Y%m%d%H%M%SZ')
else:
not_before = datetime.strptime(not_before, '%y%m%d%H%M%SZ')
if isinstance(not_after, GeneralizedTime):
not_after = datetime.strptime(not_after, '%Y%m%d%H%M%SZ')
else:
not_after = datetime.strptime(not_after, '%y%m%d%H%M%SZ')
return not_before, not_after
def get_ttl(raw_cert):
not_before, not_after = extract_dates(raw_cert)
if not_after is None:
return None
return not_after - datetime.utcnow()
def verify(expected, raw_cert):
if not HAVE_PYASN1:
log.warning("Could not find pyasn1 module. " + \
"SSL certificate COULD NOT BE VERIFIED.")
return
not_before, not_after = extract_dates(raw_cert)
cert_names = extract_names(raw_cert)
now = datetime.utcnow()
if not_before > now:
raise CertificateError(
'Certificate has not entered its valid date range.')
if not_after <= now:
raise CertificateError(
'Certificate has expired.')
expected_wild = expected[expected.index('.'):]
expected_srv = '_xmpp-client.%s' % expected
for name in cert_names['XMPPAddr']:
if name == expected:
return True
for name in cert_names['SRV']:
if name == expected_srv or name == expected:
return True
for name in cert_names['DNS']:
if name == expected:
return True
if name.startswith('*'):
name_wild = name[name.index('.'):]
if expected_wild == name_wild:
return True
for name in cert_names['URI']:
if name == expected:
return True
for name in cert_names['CN']:
if name == expected:
return True
raise CertificateError(
'Could not match certficate against hostname: %s' % expected)

View file

@ -35,7 +35,7 @@ from xml.parsers.expat import ExpatError
import sleekxmpp
from sleekxmpp.thirdparty.statemachine import StateMachine
from sleekxmpp.xmlstream import Scheduler, tostring
from sleekxmpp.xmlstream import Scheduler, tostring, cert
from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase
from sleekxmpp.xmlstream.handler import Waiter, XMLCallback
from sleekxmpp.xmlstream.matcher import MatchXMLMask
@ -181,6 +181,9 @@ class XMLStream(object):
#: The domain to try when querying DNS records.
self.default_domain = ''
#: The expected name of the server, for validation.
self._expected_server_name = ''
#: The desired, or actual, address of the connected server.
self.address = (host, int(port))
@ -313,8 +316,9 @@ class XMLStream(object):
self.dns_service = None
self.add_event_handler('connected', self._handle_connected)
self.add_event_handler('session_start', self._start_keepalive)
self.add_event_handler('disconnected', self._end_keepalive)
self.add_event_handler('session_start', self._start_keepalive)
self.add_event_handler('session_start', self._cert_expiration)
def use_signals(self, signals=None):
"""Register signal handlers for ``SIGHUP`` and ``SIGTERM``.
@ -500,10 +504,17 @@ class XMLStream(object):
self.socket.connect(self.address)
if self.use_ssl and self.ssl_support:
cert = self.socket.getpeercert(binary_form=True)
cert = ssl.DER_cert_to_PEM_cert(cert)
log.debug('CERT: %s', cert)
self.event('ssl_cert', cert, direct=True)
self._der_cert = self.socket.getpeercert(binary_form=True)
pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert)
log.debug('CERT: %s', pem_cert)
self.event('ssl_cert', pem_cert, direct=True)
try:
cert.verify(self._expected_server_name, self._der_cert)
except cert.CertificateError as err:
log.error(err.message)
self.event('ssl_invalid_cert', cert, direct=True)
self.disconnect(send_close=False)
self.set_socket(self.socket, ignore=True)
#this event is where you should set your application state
@ -767,12 +778,27 @@ class XMLStream(object):
self.socket.socket = ssl_socket
else:
self.socket = ssl_socket
self.socket.do_handshake()
cert = self.socket.getpeercert(binary_form=True)
cert = ssl.DER_cert_to_PEM_cert(cert)
log.debug('CERT: %s', cert)
self.event('ssl_cert', cert, direct=True)
try:
self.socket.do_handshake()
except:
log.error('CERT: Invalid certificate trust chain.')
self.event('ssl_invalid_chain', direct=True)
self.disconnect(self.auto_reconnect, send_close=False)
return False
self._der_cert = self.socket.getpeercert(binary_form=True)
pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert)
log.debug('CERT: %s', pem_cert)
self.event('ssl_cert', pem_cert, direct=True)
try:
cert.verify(self._expected_server_name, self._der_cert)
except cert.CertificateError as err:
log.error(err.message)
self.event('ssl_invalid_cert', cert, direct=True)
if not self.event_handled('ssl_invalid_cert'):
self.disconnect(self.auto_reconnect, send_close=False)
self.set_socket(self.socket)
return True
@ -780,6 +806,26 @@ class XMLStream(object):
log.warning("Tried to enable TLS, but ssl module not found.")
return False
def _cert_expiration(self, event):
"""Schedule an event for when the TLS certificate expires."""
def restart():
log.warn("The server certificate has expired. Restarting.")
self.reconnect()
cert_ttl = cert.get_ttl(self._der_cert)
if cert_ttl is None:
return
if cert_ttl.days < 0:
log.warn('CERT: Certificate has expired.')
restart()
log.info('CERT: Time until certificate expiration: %s' % cert_ttl)
self.schedule('Certificate Expiration',
cert_ttl.seconds,
restart)
def _start_keepalive(self, event):
"""Begin sending whitespace periodically to keep the connection alive.
@ -1298,9 +1344,16 @@ class XMLStream(object):
except (Socket.error, ssl.SSLError) as serr:
self.event('socket_error', serr, direct=True)
log.error('Socket Error #%s: %s', serr.errno, serr.strerror)
except ValueError as e:
msg = e.message if hasattr(e, 'message') else e.args[0]
if 'I/O operation on closed file' in msg:
log.error('Can not read from closed socket.')
else:
self.exception(e)
except Exception as e:
if not self.stop.is_set():
log.exception('Connection error.')
log.error('Connection error.')
self.exception(e)
if not shutdown and not self.stop.is_set() \