diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 5b245e11..a67c337d 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -12,7 +12,15 @@ :license: MIT, see LICENSE for more details """ -from typing import Optional, Set, Callable, Any +from typing import ( + Any, + Callable, + Iterable, + List, + Optional, + Set, + Union, +) import functools import logging @@ -21,7 +29,7 @@ import ssl import weakref import uuid -from asyncio import iscoroutinefunction, wait +from asyncio import iscoroutinefunction, wait, Future import xml.etree.ElementTree as ET @@ -228,8 +236,9 @@ class XMLStream(asyncio.BaseProtocol): self.add_event_handler('disconnected', self._remove_schedules) self.add_event_handler('session_start', self._start_keepalive) - + self._run_filters = None + self.__slow_tasks: List[Future] = [] @property def loop(self): @@ -465,6 +474,7 @@ class XMLStream(asyncio.BaseProtocol): self.socket = None # Fire the events after cleanup if self.end_session_on_disconnect: + self._reset_sendq() self.event('session_end') self.event("disconnected", self.disconnect_reason or exception and exception.strerror) @@ -937,6 +947,18 @@ class XMLStream(asyncio.BaseProtocol): """ return xml + def _reset_sendq(self): + """Clear sending tasks on session end""" + # Cancel all pending slow send tasks + log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks)) + for slow_task in self.__slow_tasks: + slow_task.cancel() + self.__slow_tasks.clear() + # Purge pending stanzas + while not self.waiting_queue.empty(): + discarded = self.waiting_queue.get_nowait() + log.debug('Discarded stanza: %s', discarded) + async def _continue_slow_send( self, task: asyncio.Task, @@ -950,6 +972,7 @@ class XMLStream(asyncio.BaseProtocol): :param set already_used: Filters already used on this outgoing stanza """ data = await task + self.__slow_tasks.remove(task) for filter in self.__filters['out']: if filter in already_used: continue @@ -990,6 +1013,7 @@ class XMLStream(asyncio.BaseProtocol): timeout=1, ) if pending: + self.slow_tasks.append(task) asyncio.ensure_future( self._continue_slow_send( task,