Record the current connection attempt in a future and allow cancellation

It does not make sense to have competing connection attempts, as the
XMLStream class is not designed for this. On slow and unpredictable
networks, it means we could have two c2s connections opened, leading to
mayhem.
This commit is contained in:
mathieui 2017-11-23 00:00:37 +01:00
parent 80b9cd43b1
commit eab8c265f4
No known key found for this signature in database
GPG key ID: C59F84CEEFD616E3

View file

@ -204,6 +204,9 @@ class XMLStream(asyncio.BaseProtocol):
#: We use an ID prefix to ensure that all ID values are unique. #: We use an ID prefix to ensure that all ID values are unique.
self._id_prefix = '%s-' % uuid.uuid4() self._id_prefix = '%s-' % uuid.uuid4()
# Current connection attempt (Future)
self._current_connection_attempt = None
#: A list of DNS results that have not yet been tried. #: A list of DNS results that have not yet been tried.
self.dns_answers = None self.dns_answers = None
@ -265,6 +268,7 @@ class XMLStream(asyncio.BaseProtocol):
localhost localhost
""" """
self.cancel_connection_attempt()
if host and port: if host and port:
self.address = (host, int(port)) self.address = (host, int(port))
try: try:
@ -281,7 +285,7 @@ class XMLStream(asyncio.BaseProtocol):
self.disable_starttls = disable_starttls self.disable_starttls = disable_starttls
self.event("connecting") self.event("connecting")
asyncio.async(self._connect_routine()) self._current_connection_attempt = asyncio.async(self._connect_routine())
@asyncio.coroutine @asyncio.coroutine
def _connect_routine(self): def _connect_routine(self):
@ -310,6 +314,7 @@ class XMLStream(asyncio.BaseProtocol):
self.address[1], self.address[1],
ssl=ssl_context, ssl=ssl_context,
server_hostname=self.default_domain if self.use_ssl else None) server_hostname=self.default_domain if self.use_ssl else None)
self.connect_loop_wait = 0
except Socket.gaierror as e: except Socket.gaierror as e:
self.event('connection_failed', self.event('connection_failed',
'No DNS record available for %s' % self.default_domain) 'No DNS record available for %s' % self.default_domain)
@ -317,9 +322,7 @@ class XMLStream(asyncio.BaseProtocol):
log.debug('Connection failed: %s', e) log.debug('Connection failed: %s', e)
self.event("connection_failed", e) self.event("connection_failed", e)
self.connect_loop_wait = self.connect_loop_wait * 2 + 1 self.connect_loop_wait = self.connect_loop_wait * 2 + 1
asyncio.async(self._connect_routine()) self._current_connection_attempt = asyncio.async(self._connect_routine())
else:
self.connect_loop_wait = 0
def process(self, *, forever=True, timeout=None): def process(self, *, forever=True, timeout=None):
"""Process all the available XMPP events (receiving or sending data on the """Process all the available XMPP events (receiving or sending data on the
@ -431,6 +434,17 @@ class XMLStream(asyncio.BaseProtocol):
self.transport = None self.transport = None
self.socket = None self.socket = None
def cancel_connection_attempt(self):
"""
Immediatly cancel the current create_connection() Future.
This is useful when a client using slixmpp tries to connect
on flaky networks, where sometimes a connection just gets lost
and it needs to reconnect while the attempt is still ongoing.
"""
if self._current_connection_attempt:
self._current_connection_attempt.cancel()
self._current_connection_attempt = None
def disconnect(self, wait=2.0): def disconnect(self, wait=2.0):
"""Close the XML stream and wait for an acknowldgement from the server for """Close the XML stream and wait for an acknowldgement from the server for
at most `wait` seconds. After the given number of seconds has at most `wait` seconds. After the given number of seconds has
@ -444,6 +458,7 @@ class XMLStream(asyncio.BaseProtocol):
:param wait: Time to wait for a response from the server. :param wait: Time to wait for a response from the server.
""" """
self.cancel_connection_attempt()
if self.transport: if self.transport:
self.send_raw(self.stream_footer) self.send_raw(self.stream_footer)
self.schedule('Disconnect wait', wait, self.schedule('Disconnect wait', wait,
@ -453,6 +468,7 @@ class XMLStream(asyncio.BaseProtocol):
""" """
Forcibly close the connection Forcibly close the connection
""" """
self.cancel_connection_attempt()
if self.transport: if self.transport:
self.transport.close() self.transport.close()
self.transport.abort() self.transport.abort()