Fix TLS with python 3.7

Use the "new" sslproto API instead of the deprecated TLS API.
Also remove the unused "socket" parameter in XMLStream.__init__.
This commit is contained in:
mathieui 2018-08-07 23:20:38 +02:00
parent a9abed6151
commit 7738a01311
No known key found for this signature in database
GPG key ID: C59F84CEEFD616E3
3 changed files with 25 additions and 32 deletions

View file

@ -6,9 +6,6 @@
See the file LICENSE for copying permission. See the file LICENSE for copying permission.
""" """
import asyncio
if hasattr(asyncio, 'sslproto'): # no ssl proto: very old asyncio = no need for this
asyncio.sslproto._is_sslproto_available=lambda: False
import logging import logging
logging.getLogger(__name__).addHandler(logging.NullHandler()) logging.getLogger(__name__).addHandler(logging.NullHandler())

View file

@ -12,7 +12,7 @@ from slixmpp.stanza import StreamFeatures
from slixmpp.xmlstream import register_stanza_plugin from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.plugins import BasePlugin from slixmpp.plugins import BasePlugin
from slixmpp.xmlstream.matcher import MatchXPath from slixmpp.xmlstream.matcher import MatchXPath
from slixmpp.xmlstream.handler import Callback from slixmpp.xmlstream.handler import CoroutineCallback
from slixmpp.features.feature_starttls import stanza from slixmpp.features.feature_starttls import stanza
@ -28,7 +28,7 @@ class FeatureSTARTTLS(BasePlugin):
def plugin_init(self): def plugin_init(self):
self.xmpp.register_handler( self.xmpp.register_handler(
Callback('STARTTLS Proceed', CoroutineCallback('STARTTLS Proceed',
MatchXPath(stanza.Proceed.tag_name()), MatchXPath(stanza.Proceed.tag_name()),
self._handle_starttls_proceed, self._handle_starttls_proceed,
instream=True)) instream=True))
@ -58,8 +58,8 @@ class FeatureSTARTTLS(BasePlugin):
self.xmpp.send(features['starttls']) self.xmpp.send(features['starttls'])
return True return True
def _handle_starttls_proceed(self, proceed): async def _handle_starttls_proceed(self, proceed):
"""Restart the XML stream when TLS is accepted.""" """Restart the XML stream when TLS is accepted."""
log.debug("Starting TLS") log.debug("Starting TLS")
if self.xmpp.start_tls(): if await self.xmpp.start_tls():
self.xmpp.features.add('starttls') self.xmpp.features.add('starttls')

View file

@ -64,12 +64,12 @@ class XMLStream(asyncio.BaseProtocol):
:param int port: The port to use for the connection. Defaults to 0. :param int port: The port to use for the connection. Defaults to 0.
""" """
def __init__(self, socket=None, host='', port=0): def __init__(self, host='', port=0):
# The asyncio.Transport object provided by the connection_made() # The asyncio.Transport object provided by the connection_made()
# callback when we are connected # callback when we are connected
self.transport = None self.transport = None
# The socket the is used internally by the transport object # The socket that is used internally by the transport object
self.socket = None self.socket = None
self.connect_loop_wait = 0 self.connect_loop_wait = 0
@ -354,7 +354,10 @@ class XMLStream(asyncio.BaseProtocol):
""" """
self.event(self.event_when_connected) self.event(self.event_when_connected)
self.transport = transport self.transport = transport
self.socket = self.transport.get_extra_info("socket") self.socket = self.transport.get_extra_info(
"ssl_object",
default=self.transport.get_extra_info("socket")
)
self.init_parser() self.init_parser()
self.send_raw(self.stream_header) self.send_raw(self.stream_header)
self.dns_answers = None self.dns_answers = None
@ -527,36 +530,29 @@ class XMLStream(asyncio.BaseProtocol):
return self.ssl_context return self.ssl_context
def start_tls(self): async def start_tls(self):
"""Perform handshakes for TLS. """Perform handshakes for TLS.
If the handshake is successful, the XML stream will need If the handshake is successful, the XML stream will need
to be restarted. to be restarted.
""" """
self.event_when_connected = "tls_success" self.event_when_connected = "tls_success"
ssl_context = self.get_ssl_context() ssl_context = self.get_ssl_context()
ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=ssl_context, try:
sock=self.socket, transp = await self.loop.start_tls(self.transport, self, ssl_context)
server_hostname=self.default_domain) except ssl.SSLError as e:
async def ssl_coro(): log.debug('SSL: Unable to connect', exc_info=True)
try: log.error('CERT: Invalid certificate trust chain.')
transp, prot = await ssl_connect_routine if not self.event_handled('ssl_invalid_chain'):
except ssl.SSLError as e: self.disconnect()
log.debug('SSL: Unable to connect', exc_info=True)
log.error('CERT: Invalid certificate trust chain.')
if not self.event_handled('ssl_invalid_chain'):
self.disconnect()
else:
self.event('ssl_invalid_chain', e)
else: else:
# Workaround for a regression in 3.4 where ssl_object was not set. self.event('ssl_invalid_chain', e)
der_cert = transp.get_extra_info("ssl_object", return False
default=transp.get_extra_info("socket")).getpeercert(True) der_cert = transp.get_extra_info("ssl_object").getpeercert(True)
pem_cert = ssl.DER_cert_to_PEM_cert(der_cert) pem_cert = ssl.DER_cert_to_PEM_cert(der_cert)
self.event('ssl_cert', pem_cert) self.event('ssl_cert', pem_cert)
self.connection_made(transp)
asyncio.ensure_future(ssl_coro()) return True
def _start_keepalive(self, event): def _start_keepalive(self, event):
"""Begin sending whitespace periodically to keep the connection alive. """Begin sending whitespace periodically to keep the connection alive.