diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 327de7db..f7c0dea0 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -204,6 +204,9 @@ class XMLStream(asyncio.BaseProtocol): #: We use an ID prefix to ensure that all ID values are unique. self._id_prefix = '%s-' % uuid.uuid4() + # Current connection attempt (Future) + self._current_connection_attempt = None + #: A list of DNS results that have not yet been tried. self.dns_answers = None @@ -265,6 +268,7 @@ class XMLStream(asyncio.BaseProtocol): localhost """ + self.cancel_connection_attempt() if host and port: self.address = (host, int(port)) try: @@ -281,7 +285,7 @@ class XMLStream(asyncio.BaseProtocol): self.disable_starttls = disable_starttls self.event("connecting") - asyncio.async(self._connect_routine()) + self._current_connection_attempt = asyncio.async(self._connect_routine()) @asyncio.coroutine def _connect_routine(self): @@ -310,6 +314,7 @@ class XMLStream(asyncio.BaseProtocol): self.address[1], ssl=ssl_context, server_hostname=self.default_domain if self.use_ssl else None) + self.connect_loop_wait = 0 except Socket.gaierror as e: self.event('connection_failed', 'No DNS record available for %s' % self.default_domain) @@ -317,9 +322,7 @@ class XMLStream(asyncio.BaseProtocol): log.debug('Connection failed: %s', e) self.event("connection_failed", e) self.connect_loop_wait = self.connect_loop_wait * 2 + 1 - asyncio.async(self._connect_routine()) - else: - self.connect_loop_wait = 0 + self._current_connection_attempt = asyncio.async(self._connect_routine()) def process(self, *, forever=True, timeout=None): """Process all the available XMPP events (receiving or sending data on the @@ -431,6 +434,17 @@ class XMLStream(asyncio.BaseProtocol): self.transport = None self.socket = None + def cancel_connection_attempt(self): + """ + Immediatly cancel the current create_connection() Future. + This is useful when a client using slixmpp tries to connect + on flaky networks, where sometimes a connection just gets lost + and it needs to reconnect while the attempt is still ongoing. + """ + if self._current_connection_attempt: + self._current_connection_attempt.cancel() + self._current_connection_attempt = None + def disconnect(self, wait=2.0): """Close the XML stream and wait for an acknowldgement from the server for at most `wait` seconds. After the given number of seconds has @@ -444,6 +458,7 @@ class XMLStream(asyncio.BaseProtocol): :param wait: Time to wait for a response from the server. """ + self.cancel_connection_attempt() if self.transport: self.send_raw(self.stream_footer) self.schedule('Disconnect wait', wait, @@ -453,6 +468,7 @@ class XMLStream(asyncio.BaseProtocol): """ Forcibly close the connection """ + self.cancel_connection_attempt() if self.transport: self.transport.close() self.transport.abort()