From 0b5f6cb0a8e86c2e0b76d5f7ba1ae4f2ad801ebc Mon Sep 17 00:00:00 2001 From: mathieui Date: Fri, 30 Apr 2021 18:40:33 +0200 Subject: [PATCH 1/2] xmlstream: fix slow tasks scheduling - wrong attribute used - some mistakes in the slow tasks function --- slixmpp/xmlstream/xmlstream.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 7a94bf50..ab9b781d 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -1053,11 +1053,13 @@ class XMLStream(asyncio.BaseProtocol): """ data = await task self.__slow_tasks.remove(task) - for filter in self.__filters['out']: + if data is None: + return + for filter in self.__filters['out'][:]: if filter in already_used: continue if iscoroutinefunction(filter): - data = await task + data = await filter(data) else: data = filter(data) if data is None: @@ -1093,7 +1095,7 @@ class XMLStream(asyncio.BaseProtocol): timeout=1, ) if pending: - self.slow_tasks.append(task) + self.__slow_tasks.append(task) asyncio.ensure_future( self._continue_slow_send( task, @@ -1101,7 +1103,9 @@ class XMLStream(asyncio.BaseProtocol): ), loop=self.loop, ) - raise Exception("Slow coro, rescheduling") + raise ContinueQueue( + "Slow coroutine, rescheduling filters" + ) data = task.result() else: data = filter(data) From aaab58d229379492e4a2d5c9972db793662a29da Mon Sep 17 00:00:00 2001 From: mathieui Date: Fri, 30 Apr 2021 19:28:12 +0200 Subject: [PATCH 2/2] itests: add a simple slow filter test --- itests/test_slow_filters.py | 49 +++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 itests/test_slow_filters.py diff --git a/itests/test_slow_filters.py b/itests/test_slow_filters.py new file mode 100644 index 00000000..254a6b03 --- /dev/null +++ b/itests/test_slow_filters.py @@ -0,0 +1,49 @@ +import asyncio +import unittest +from slixmpp.test.integration import SlixIntegration +from slixmpp import Message + + +class TestSlowFilter(SlixIntegration): + async def asyncSetUp(self): + await super().asyncSetUp() + self.add_client( + self.envjid('CI_ACCOUNT1'), + self.envstr('CI_ACCOUNT1_PASSWORD'), + ) + self.add_client( + self.envjid('CI_ACCOUNT2'), + self.envstr('CI_ACCOUNT2_PASSWORD'), + ) + await self.connect_clients() + + async def test_filters(self): + """Make sure filters work""" + def add_a(stanza): + if isinstance(stanza, Message): + stanza['body'] = stanza['body'] + ' a' + return stanza + + async def add_b(stanza): + if isinstance(stanza, Message): + stanza['body'] = stanza['body'] + ' b' + return stanza + + async def add_c_wait(stanza): + if isinstance(stanza, Message): + await asyncio.sleep(2) + stanza['body'] = stanza['body'] + ' c' + return stanza + self.clients[0].add_filter('out', add_a) + self.clients[0].add_filter('out', add_b) + self.clients[0].add_filter('out', add_c_wait) + body = 'Msg body' + msg = self.clients[0].make_message( + mto=self.clients[1].boundjid, mbody=body, + ) + msg.send() + message = await self.clients[1].wait_until('message') + self.assertEqual(message['body'], body + ' a b c') + + +suite = unittest.TestLoader().loadTestsFromTestCase(TestSlowFilter)