diff --git a/slixmpp/plugins/xep_0198/stream_management.py b/slixmpp/plugins/xep_0198/stream_management.py index 0200646a..1344235a 100644 --- a/slixmpp/plugins/xep_0198/stream_management.py +++ b/slixmpp/plugins/xep_0198/stream_management.py @@ -174,6 +174,9 @@ class XEP_0198(BasePlugin): def send_ack(self): """Send the current ack count to the server.""" + if not self.xmpp.transport: + log.debug('Disconnected: not sending ack') + return ack = stanza.Ack(self.xmpp) ack['h'] = self.handled self.xmpp.send_raw(str(ack)) @@ -198,20 +201,7 @@ class XEP_0198(BasePlugin): # We've already negotiated stream management, # so no need to do it again. return False - if not self.sm_id: - if 'bind' in self.xmpp.features: - enable = stanza.Enable(self.xmpp) - enable['resume'] = self.allow_resume - enable.send() - log.debug("enabling SM") - - waiter = Waiter('enabled_or_failed', - MatchMany([ - MatchXPath(stanza.Enabled.tag_name()), - MatchXPath(stanza.Failed.tag_name())])) - self.xmpp.register_handler(waiter) - result = await waiter.wait() - elif self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features: + if self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features: resume = stanza.Resume(self.xmpp) resume['h'] = self.handled resume['previd'] = self.sm_id @@ -229,6 +219,19 @@ class XEP_0198(BasePlugin): result = await waiter.wait() if result is not None and result.name == 'resumed': return True + self.xmpp.event("session_end") + if 'bind' in self.xmpp.features: + enable = stanza.Enable(self.xmpp) + enable['resume'] = self.allow_resume + enable.send() + log.debug("enabling SM") + + waiter = Waiter('enabled_or_failed', + MatchMany([ + MatchXPath(stanza.Enabled.tag_name()), + MatchXPath(stanza.Failed.tag_name())])) + self.xmpp.register_handler(waiter) + result = await waiter.wait() return False def _handle_enabled(self, stanza): diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 6b890729..5074aa8c 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -12,7 +12,15 @@ :license: MIT, see LICENSE for more details """ -from typing import Optional, Set, Callable, Any +from typing import ( + Any, + Callable, + Iterable, + List, + Optional, + Set, + Union, +) import functools import logging @@ -21,7 +29,7 @@ import ssl import weakref import uuid -from asyncio import iscoroutinefunction, wait +from asyncio import iscoroutinefunction, wait, Future import xml.etree.ElementTree as ET @@ -224,12 +232,13 @@ class XMLStream(asyncio.BaseProtocol): self.disconnect_reason = None #: An asyncio Future being done when the stream is disconnected. - self.disconnected = asyncio.Future() + self.disconnected: Future = Future() self.add_event_handler('disconnected', self._remove_schedules) self.add_event_handler('session_start', self._start_keepalive) - - self._run_filters = None + + self._run_out_filters: Optional[Future] = None + self.__slow_tasks: List[Future] = [] @property def loop(self): @@ -250,6 +259,12 @@ class XMLStream(asyncio.BaseProtocol): """ return uuid.uuid4().hex + def _set_disconnected_future(self): + """Set the self.disconnected future on disconnect""" + if not self.disconnected.done(): + self.disconnected.set_result(True) + self.disconnected = asyncio.Future() + def connect(self, host='', port=0, use_ssl=False, force_starttls=True, disable_starttls=False): """Create a new socket and connect to the server. @@ -272,8 +287,8 @@ class XMLStream(asyncio.BaseProtocol): localhost """ - if self._run_filters is None: - self._run_filters = asyncio.ensure_future( + if self._run_out_filters is None or self._run_out_filters.done(): + self._run_out_filters = asyncio.ensure_future( self.run_filters(), loop=self.loop, ) @@ -418,10 +433,10 @@ class XMLStream(asyncio.BaseProtocol): if self.xml_depth == 0: # The stream's root element has closed, # terminating the stream. - self.end_session_on_disconnect = True log.debug("End of stream received") self.disconnect_reason = "End of stream" self.abort() + return elif self.xml_depth == 1: # A stanza is an XML element that is a direct child of # the root element, hence the check of depth == 1 @@ -463,11 +478,11 @@ class XMLStream(asyncio.BaseProtocol): self.parser = None self.transport = None self.socket = None - if self._run_filters: - self._run_filters.cancel() # Fire the events after cleanup if self.end_session_on_disconnect: + self._reset_sendq() self.event('session_end') + self._set_disconnected_future() self.event("disconnected", self.disconnect_reason or exception and exception.strerror) def cancel_connection_attempt(self): @@ -480,10 +495,8 @@ class XMLStream(asyncio.BaseProtocol): if self._current_connection_attempt: self._current_connection_attempt.cancel() self._current_connection_attempt = None - if self._run_filters: - self._run_filters.cancel() - def disconnect(self, wait: float = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> None: + def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future: """Close the XML stream and wait for an acknowldgement from the server for at most `wait` seconds. After the given number of seconds has passed without a response from the server, or when the server @@ -491,10 +504,13 @@ class XMLStream(asyncio.BaseProtocol): called. If wait is 0.0, this will call abort() directly without closing the stream. - Does nothing if we are not connected. + Does nothing but trigger the disconnected event if we are not connected. :param wait: Time to wait for a response from the server. - + :param reason: An optional reason for the disconnect. + :param ignore_send_queue: Boolean to toggle if we want to ignore + the in-flight stanzas and disconnect immediately. + :return: A future that ends when all code involved in the disconnect has ended """ # Compat: docs/getting_started/sendlogout.rst has been promoting # `disconnect(wait=True)` for ages. This doesn't mean anything to the @@ -504,50 +520,75 @@ class XMLStream(asyncio.BaseProtocol): wait = 2.0 if self.transport: + self.disconnect_reason = reason if self.waiting_queue.empty() or ignore_send_queue: - self.disconnect_reason = reason self.cancel_connection_attempt() - if wait > 0.0: - self.send_raw(self.stream_footer) - self.schedule('Disconnect wait', wait, - self.abort, repeat=False) + return asyncio.ensure_future( + self._end_stream_wait(wait, reason=reason), + loop=self.loop, + ) else: - asyncio.ensure_future( + return asyncio.ensure_future( self._consume_send_queue_before_disconnecting(reason, wait), loop=self.loop, ) else: + self._set_disconnected_future() self.event("disconnected", reason) + future = Future() + future.set_result(None) + return future async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float): """Wait until the send queue is empty before disconnecting""" - await self.waiting_queue.join() + try: + await asyncio.wait_for( + self.waiting_queue.join(), + wait, + loop=self.loop + ) + except asyncio.TimeoutError: + wait = 0 # we already consumed the timeout self.disconnect_reason = reason - self.cancel_connection_attempt() - if wait > 0.0: + await self._end_stream_wait(wait) + + async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None): + """ + Run abort() if we do not received the disconnected event + after a waiting time. + + :param wait: The waiting time (defaults to 2) + """ + try: self.send_raw(self.stream_footer) - self.schedule('Disconnect wait', wait, - self.abort, repeat=False) + await self.wait_until('disconnected', wait) + except asyncio.TimeoutError: + self.abort() + except NotConnectedError: + # We are not connected when sending the end of stream + # that means the disconnect has already been handled + pass def abort(self): """ Forcibly close the connection """ - self.cancel_connection_attempt() if self.transport: + self.cancel_connection_attempt() self.transport.close() self.transport.abort() self.event("killed") - self.disconnected.set_result(True) - self.disconnected = asyncio.Future() - self.event("disconnected", self.disconnect_reason) def reconnect(self, wait=2.0, reason="Reconnecting"): """Calls disconnect(), and once we are disconnected (after the timeout, or when the server acknowledgement is received), call connect() """ log.debug("reconnecting...") - self.add_event_handler('disconnected', lambda event: self.connect(), disposable=True) + async def handler(event): + # We yield here to allow synchronous handlers to work first + await asyncio.sleep(0, loop=self.loop) + self.connect() + self.add_event_handler('disconnected', handler, disposable=True) self.disconnect(wait, reason) def configure_socket(self): @@ -655,7 +696,6 @@ class XMLStream(asyncio.BaseProtocol): def _remove_schedules(self, event): """Remove some schedules that become pointless when disconnected""" self.cancel_schedule('Whitespace Keepalive') - self.cancel_schedule('Disconnect wait') def start_stream_handler(self, xml): """Perform any initialization actions, such as handshakes, @@ -833,7 +873,7 @@ class XMLStream(asyncio.BaseProtocol): """ log.debug("Event triggered: %s", name) - handlers = self.__event_handlers.get(name, []) + handlers = self.__event_handlers.get(name, [])[:] for handler in handlers: handler_callback, disposable = handler old_exception = getattr(data, 'exception', None) @@ -941,6 +981,18 @@ class XMLStream(asyncio.BaseProtocol): """ return xml + def _reset_sendq(self): + """Clear sending tasks on session end""" + # Cancel all pending slow send tasks + log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks)) + for slow_task in self.__slow_tasks: + slow_task.cancel() + self.__slow_tasks.clear() + # Purge pending stanzas + while not self.waiting_queue.empty(): + discarded = self.waiting_queue.get_nowait() + log.debug('Discarded stanza: %s', discarded) + async def _continue_slow_send( self, task: asyncio.Task, @@ -954,6 +1006,7 @@ class XMLStream(asyncio.BaseProtocol): :param set already_used: Filters already used on this outgoing stanza """ data = await task + self.__slow_tasks.remove(task) for filter in self.__filters['out']: if filter in already_used: continue @@ -975,7 +1028,6 @@ class XMLStream(asyncio.BaseProtocol): else: self.send_raw(data) - async def run_filters(self): """ Background loop that processes stanzas to send. @@ -995,11 +1047,13 @@ class XMLStream(asyncio.BaseProtocol): timeout=1, ) if pending: + self.slow_tasks.append(task) asyncio.ensure_future( self._continue_slow_send( task, already_run_filters - ) + ), + loop=self.loop, ) raise Exception("Slow coro, rescheduling") data = task.result() @@ -1142,9 +1196,15 @@ class XMLStream(asyncio.BaseProtocol): :param int timeout: Timeout """ fut = asyncio.Future() + def result_handler(event_data): + if not fut.done(): + fut.set_result(event_data) + else: + log.debug("Future registered on event '%s' was alredy done", event) + self.add_event_handler( event, - fut.set_result, + result_handler, disposable=True, ) - return await asyncio.wait_for(fut, timeout) + return await asyncio.wait_for(fut, timeout, loop=self.loop)