diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index ac0fc256..7376d56d 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -493,7 +493,8 @@ class XMLStream(object): ssl_socket = ssl.wrap_socket(self.socket, ca_certs=self.ca_certs, - cert_reqs=cert_policy) + cert_reqs=cert_policy, + do_handshake_on_connect=False) if hasattr(self.socket, 'socket'): # We are using a testing socket, so preserve the top @@ -510,6 +511,17 @@ class XMLStream(object): log.debug("Connecting to %s:%s", domain, self.address[1]) self.socket.connect(self.address) + try: + self.socket.do_handshake() + except: + log.error('CERT: Invalid certificate trust chain.') + if not self.event_handled('ssl_invalid_chain'): + self.disconnect(self.auto_reconnect, send_close=False) + else: + self.event('ssl_invalid_chain', direct=True) + return False + + if self.use_ssl and self.ssl_support: self._der_cert = self.socket.getpeercert(binary_form=True) pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert) @@ -520,8 +532,10 @@ class XMLStream(object): cert.verify(self._expected_server_name, self._der_cert) except cert.CertificateError as err: log.error(err.message) - self.event('ssl_invalid_cert', cert, direct=True) - self.disconnect(send_close=False) + if not self.event_handled('ssl_invalid_cert'): + self.disconnect(send_close=False) + else: + self.event('ssl_invalid_cert', cert, direct=True) self.set_socket(self.socket, ignore=True) #this event is where you should set your application state @@ -790,8 +804,10 @@ class XMLStream(object): self.socket.do_handshake() except: log.error('CERT: Invalid certificate trust chain.') - self.event('ssl_invalid_chain', direct=True) - self.disconnect(self.auto_reconnect, send_close=False) + if not self.event_handled('ssl_invalid_chain'): + self.disconnect(self.auto_reconnect, send_close=False) + else: + self.event('ssl_invalid_chain', direct=True) return False self._der_cert = self.socket.getpeercert(binary_form=True) @@ -803,9 +819,10 @@ class XMLStream(object): cert.verify(self._expected_server_name, self._der_cert) except cert.CertificateError as err: log.error(err.message) - self.event('ssl_invalid_cert', cert, direct=True) if not self.event_handled('ssl_invalid_cert'): self.disconnect(self.auto_reconnect, send_close=False) + else: + self.event('ssl_invalid_cert', cert, direct=True) self.set_socket(self.socket) return True @@ -820,8 +837,12 @@ class XMLStream(object): return def restart(): - log.warn("The server certificate has expired. Restarting.") - self.reconnect() + if not self.event_handled('ssl_expired_cert'): + log.warn("The server certificate has expired. Restarting.") + self.reconnect() + else: + pem_cert = ssl.DER_cert_to_PEM_cert(self._der_cert) + self.event('ssl_expired_cert', pem_cert) cert_ttl = cert.get_ttl(self._der_cert) if cert_ttl is None: