Fix TLS with python 3.7

Use the "new" sslproto API instead of the deprecated TLS API.
Also remove the unused "socket" parameter in XMLStream.__init__.
This commit is contained in:
mathieui 2018-08-07 23:20:38 +02:00
parent a9abed6151
commit 7738a01311
No known key found for this signature in database
GPG key ID: C59F84CEEFD616E3
3 changed files with 25 additions and 32 deletions

View file

@ -6,9 +6,6 @@
See the file LICENSE for copying permission.
"""
import asyncio
if hasattr(asyncio, 'sslproto'): # no ssl proto: very old asyncio = no need for this
asyncio.sslproto._is_sslproto_available=lambda: False
import logging
logging.getLogger(__name__).addHandler(logging.NullHandler())

View file

@ -12,7 +12,7 @@ from slixmpp.stanza import StreamFeatures
from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.plugins import BasePlugin
from slixmpp.xmlstream.matcher import MatchXPath
from slixmpp.xmlstream.handler import Callback
from slixmpp.xmlstream.handler import CoroutineCallback
from slixmpp.features.feature_starttls import stanza
@ -28,7 +28,7 @@ class FeatureSTARTTLS(BasePlugin):
def plugin_init(self):
self.xmpp.register_handler(
Callback('STARTTLS Proceed',
CoroutineCallback('STARTTLS Proceed',
MatchXPath(stanza.Proceed.tag_name()),
self._handle_starttls_proceed,
instream=True))
@ -58,8 +58,8 @@ class FeatureSTARTTLS(BasePlugin):
self.xmpp.send(features['starttls'])
return True
def _handle_starttls_proceed(self, proceed):
async def _handle_starttls_proceed(self, proceed):
"""Restart the XML stream when TLS is accepted."""
log.debug("Starting TLS")
if self.xmpp.start_tls():
if await self.xmpp.start_tls():
self.xmpp.features.add('starttls')

View file

@ -64,12 +64,12 @@ class XMLStream(asyncio.BaseProtocol):
:param int port: The port to use for the connection. Defaults to 0.
"""
def __init__(self, socket=None, host='', port=0):
def __init__(self, host='', port=0):
# The asyncio.Transport object provided by the connection_made()
# callback when we are connected
self.transport = None
# The socket the is used internally by the transport object
# The socket that is used internally by the transport object
self.socket = None
self.connect_loop_wait = 0
@ -354,7 +354,10 @@ class XMLStream(asyncio.BaseProtocol):
"""
self.event(self.event_when_connected)
self.transport = transport
self.socket = self.transport.get_extra_info("socket")
self.socket = self.transport.get_extra_info(
"ssl_object",
default=self.transport.get_extra_info("socket")
)
self.init_parser()
self.send_raw(self.stream_header)
self.dns_answers = None
@ -527,36 +530,29 @@ class XMLStream(asyncio.BaseProtocol):
return self.ssl_context
def start_tls(self):
async def start_tls(self):
"""Perform handshakes for TLS.
If the handshake is successful, the XML stream will need
to be restarted.
"""
self.event_when_connected = "tls_success"
ssl_context = self.get_ssl_context()
ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=ssl_context,
sock=self.socket,
server_hostname=self.default_domain)
async def ssl_coro():
try:
transp, prot = await ssl_connect_routine
except ssl.SSLError as e:
log.debug('SSL: Unable to connect', exc_info=True)
log.error('CERT: Invalid certificate trust chain.')
if not self.event_handled('ssl_invalid_chain'):
self.disconnect()
else:
self.event('ssl_invalid_chain', e)
try:
transp = await self.loop.start_tls(self.transport, self, ssl_context)
except ssl.SSLError as e:
log.debug('SSL: Unable to connect', exc_info=True)
log.error('CERT: Invalid certificate trust chain.')
if not self.event_handled('ssl_invalid_chain'):
self.disconnect()
else:
# Workaround for a regression in 3.4 where ssl_object was not set.
der_cert = transp.get_extra_info("ssl_object",
default=transp.get_extra_info("socket")).getpeercert(True)
pem_cert = ssl.DER_cert_to_PEM_cert(der_cert)
self.event('ssl_cert', pem_cert)
asyncio.ensure_future(ssl_coro())
self.event('ssl_invalid_chain', e)
return False
der_cert = transp.get_extra_info("ssl_object").getpeercert(True)
pem_cert = ssl.DER_cert_to_PEM_cert(der_cert)
self.event('ssl_cert', pem_cert)
self.connection_made(transp)
return True
def _start_keepalive(self, event):
"""Begin sending whitespace periodically to keep the connection alive.