diff --git a/slixmpp/__init__.py b/slixmpp/__init__.py index 46804bf5..0730cc60 100644 --- a/slixmpp/__init__.py +++ b/slixmpp/__init__.py @@ -6,9 +6,6 @@ See the file LICENSE for copying permission. """ -import asyncio -if hasattr(asyncio, 'sslproto'): # no ssl proto: very old asyncio = no need for this - asyncio.sslproto._is_sslproto_available=lambda: False import logging logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/slixmpp/features/feature_starttls/starttls.py b/slixmpp/features/feature_starttls/starttls.py index d472dad7..7e3af992 100644 --- a/slixmpp/features/feature_starttls/starttls.py +++ b/slixmpp/features/feature_starttls/starttls.py @@ -12,7 +12,7 @@ from slixmpp.stanza import StreamFeatures from slixmpp.xmlstream import register_stanza_plugin from slixmpp.plugins import BasePlugin from slixmpp.xmlstream.matcher import MatchXPath -from slixmpp.xmlstream.handler import Callback +from slixmpp.xmlstream.handler import CoroutineCallback from slixmpp.features.feature_starttls import stanza @@ -28,7 +28,7 @@ class FeatureSTARTTLS(BasePlugin): def plugin_init(self): self.xmpp.register_handler( - Callback('STARTTLS Proceed', + CoroutineCallback('STARTTLS Proceed', MatchXPath(stanza.Proceed.tag_name()), self._handle_starttls_proceed, instream=True)) @@ -58,8 +58,8 @@ class FeatureSTARTTLS(BasePlugin): self.xmpp.send(features['starttls']) return True - def _handle_starttls_proceed(self, proceed): + async def _handle_starttls_proceed(self, proceed): """Restart the XML stream when TLS is accepted.""" log.debug("Starting TLS") - if self.xmpp.start_tls(): + if await self.xmpp.start_tls(): self.xmpp.features.add('starttls') diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index d5dce586..1fa07b0c 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -64,12 +64,12 @@ class XMLStream(asyncio.BaseProtocol): :param int port: The port to use for the connection. Defaults to 0. """ - def __init__(self, socket=None, host='', port=0): + def __init__(self, host='', port=0): # The asyncio.Transport object provided by the connection_made() # callback when we are connected self.transport = None - # The socket the is used internally by the transport object + # The socket that is used internally by the transport object self.socket = None self.connect_loop_wait = 0 @@ -354,7 +354,10 @@ class XMLStream(asyncio.BaseProtocol): """ self.event(self.event_when_connected) self.transport = transport - self.socket = self.transport.get_extra_info("socket") + self.socket = self.transport.get_extra_info( + "ssl_object", + default=self.transport.get_extra_info("socket") + ) self.init_parser() self.send_raw(self.stream_header) self.dns_answers = None @@ -527,36 +530,29 @@ class XMLStream(asyncio.BaseProtocol): return self.ssl_context - def start_tls(self): + async def start_tls(self): """Perform handshakes for TLS. If the handshake is successful, the XML stream will need to be restarted. """ self.event_when_connected = "tls_success" - ssl_context = self.get_ssl_context() - ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=ssl_context, - sock=self.socket, - server_hostname=self.default_domain) - async def ssl_coro(): - try: - transp, prot = await ssl_connect_routine - except ssl.SSLError as e: - log.debug('SSL: Unable to connect', exc_info=True) - log.error('CERT: Invalid certificate trust chain.') - if not self.event_handled('ssl_invalid_chain'): - self.disconnect() - else: - self.event('ssl_invalid_chain', e) + try: + transp = await self.loop.start_tls(self.transport, self, ssl_context) + except ssl.SSLError as e: + log.debug('SSL: Unable to connect', exc_info=True) + log.error('CERT: Invalid certificate trust chain.') + if not self.event_handled('ssl_invalid_chain'): + self.disconnect() else: - # Workaround for a regression in 3.4 where ssl_object was not set. - der_cert = transp.get_extra_info("ssl_object", - default=transp.get_extra_info("socket")).getpeercert(True) - pem_cert = ssl.DER_cert_to_PEM_cert(der_cert) - self.event('ssl_cert', pem_cert) - - asyncio.ensure_future(ssl_coro()) + self.event('ssl_invalid_chain', e) + return False + der_cert = transp.get_extra_info("ssl_object").getpeercert(True) + pem_cert = ssl.DER_cert_to_PEM_cert(der_cert) + self.event('ssl_cert', pem_cert) + self.connection_made(transp) + return True def _start_keepalive(self, event): """Begin sending whitespace periodically to keep the connection alive.