Merge branch 'rsm-fixes' into 'master'

XEP-0059 (RSM) - Some fixes

See merge request poezio/slixmpp!145
This commit is contained in:
mathieui 2021-03-09 19:25:26 +01:00
commit 1289cf575c
3 changed files with 168 additions and 87 deletions

View file

@ -8,6 +8,9 @@ XEP-0059: Result Set Management
:members: :members:
:exclude-members: session_bind, plugin_init, plugin_end :exclude-members: session_bind, plugin_init, plugin_end
.. autoclass:: ResultIterator
:members:
:member-order: bysource
Stanza elements Stanza elements
--------------- ---------------

View file

@ -5,9 +5,16 @@
# See the file LICENSE for copying permission. # See the file LICENSE for copying permission.
import logging import logging
import slixmpp from collections.abc import AsyncIterator
from slixmpp import Iq from typing import (
from slixmpp.plugins import BasePlugin, register_plugin Any,
Callable,
Dict,
Optional,
)
from slixmpp.stanza import Iq
from slixmpp.plugins import BasePlugin
from slixmpp.xmlstream import register_stanza_plugin from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.plugins.xep_0059 import stanza, Set from slixmpp.plugins.xep_0059 import stanza, Set
from slixmpp.exceptions import XMPPError from slixmpp.exceptions import XMPPError
@ -16,41 +23,73 @@ from slixmpp.exceptions import XMPPError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class ResultIterator: class ResultIterator(AsyncIterator):
""" """
An iterator for Result Set Management An iterator for Result Set Management
"""
def __init__(self, query, interface, results='substanzas', amount=10,
start=None, reverse=False, recv_interface=None,
pre_cb=None, post_cb=None):
"""
Arguments:
query -- The template query
interface -- The substanza of the query to send, for example disco_items
recv_interface -- The substanza of the query to receive, for example disco_items
results -- The query stanza's interface which provides a
countable list of query results.
amount -- The max amounts of items to request per iteration
start -- From which item id to start
reverse -- If True, page backwards through the results
pre_cb -- Callback to run before sending the stanza
post_cb -- Callback to run after receiving the reply
Example: Example:
.. code-block:: python
q = Iq() q = Iq()
q['to'] = 'pubsub.example.com' q['to'] = 'pubsub.example.com'
q['disco_items']['node'] = 'blog' q['disco_items']['node'] = 'blog'
for i in ResultIterator(q, 'disco_items', '10'): async for i in ResultIterator(q, 'disco_items', '10'):
print i['disco_items']['items'] print(i['disco_items']['items'])
"""
#: Template for the RSM query
query: Iq
#: Substanza of the query to send, e.g. "disco_items"
interface: str
#: Stanza interface on the query results providing the retrieved
#: elements (used to count them)
results: str
#: From which item id to start
start: Optional[str]
#: Amount of elements to retrieve for each page
amount: int
#: If True, page backwards through the results
reverse: bool
#: Callback to run before sending the stanza
pre_cb: Optional[Callable[[Iq], None]]
#: Callback to run after receiving the reply
post_cb: Optional[Callable[[Iq], None]]
#: Optional dict of Iq options (timeout, etc…) for Iq.send()
iq_options: Dict[str, Any]
def __init__(self, query: Iq, interface: str, results: str = 'substanzas',
amount: int = 10,
start: Optional[str] = None, reverse: bool = False,
recv_interface: Optional[str] = None,
pre_cb: Optional[Callable[[Iq], None]] = None,
post_cb: Optional[Callable[[Iq], None]] = None,
iq_options: Optional[Dict[str, Any]] = None):
"""
:param query: The template query
:param interface: The substanza of the query to send, for example
disco_items
:param recv_interface: The substanza of the query to receive, for
example disco_items
:param results: The query stanza's interface which provides a
countable list of query results.
:param amount: The max amounts of items to request per iteration
:param start: From which item id to start
:param reverse: If True, page backwards through the results
:param pre_cb: Callback to run before sending the stanza
:param post_cb: Callback to run after receiving the reply
:param iq_options: Optional dict of parameters for Iq.send
""" """
self.query = query self.query = query
self.amount = amount self.amount = amount
self.start = start self.start = start
if iq_options is None:
self.iq_options = {}
else:
self.iq_options = iq_options
self.interface = interface self.interface = interface
if recv_interface: if recv_interface is not None:
self.recv_interface = recv_interface self.recv_interface = recv_interface
else: else:
self.recv_interface = interface self.recv_interface = interface
@ -63,10 +102,10 @@ class ResultIterator:
def __aiter__(self): def __aiter__(self):
return self return self
async def __anext__(self): async def __anext__(self) -> Iq:
return await self.next() return await self.next()
async def next(self): async def next(self) -> Iq:
""" """
Return the next page of results from a query. Return the next page of results from a query.
@ -76,20 +115,21 @@ class ResultIterator:
""" """
if self._stop: if self._stop:
raise StopAsyncIteration raise StopAsyncIteration
if self.query[self.interface]['rsm']['before'] is None:
self.query[self.interface]['rsm']['before'] = self.reverse
self.query['id'] = self.query.stream.new_id() self.query['id'] = self.query.stream.new_id()
self.query[self.interface]['rsm']['max'] = str(self.amount) self.query[self.interface]['rsm']['max'] = str(self.amount)
if self.start and self.reverse: if self.start:
if self.reverse:
self.query[self.interface]['rsm']['before'] = self.start self.query[self.interface]['rsm']['before'] = self.start
elif self.start: else:
self.query[self.interface]['rsm']['after'] = self.start self.query[self.interface]['rsm']['after'] = self.start
elif self.reverse:
self.query[self.interface]['rsm']['before'] = True
try: try:
if self.pre_cb: if self.pre_cb:
self.pre_cb(self.query) self.pre_cb(self.query)
r = await self.query.send() r = await self.query.send(**self.iq_options)
if not r[self.recv_interface]['rsm']['first'] and \ if not r[self.recv_interface]['rsm']['first'] and \
not r[self.recv_interface]['rsm']['last']: not r[self.recv_interface]['rsm']['last']:
@ -118,7 +158,7 @@ class ResultIterator:
class XEP_0059(BasePlugin): class XEP_0059(BasePlugin):
""" """
XEP-0050: Result Set Management XEP-0059: Result Set Management
""" """
name = 'xep_0059' name = 'xep_0059'
@ -139,34 +179,40 @@ class XEP_0059(BasePlugin):
def session_bind(self, jid): def session_bind(self, jid):
self.xmpp['xep_0030'].add_feature(Set.namespace) self.xmpp['xep_0030'].add_feature(Set.namespace)
def iterate(self, stanza, interface, results='substanzas', amount=10, reverse=False, def iterate(self, stanza: Iq, interface: str, results: str = 'substanzas',
recv_interface=None, pre_cb=None, post_cb=None): amount: int = 10, reverse: bool = False,
recv_interface: Optional[str] = None,
pre_cb: Optional[Callable[[Iq], None]] = None,
post_cb: Optional[Callable[[Iq], None]] = None,
iq_options: Optional[Dict[str, Any]] = None
) -> ResultIterator:
""" """
Create a new result set iterator for a given stanza query. Create a new result set iterator for a given stanza query.
Arguments: :param stanza: A stanza object to serve as a template for
stanza -- A stanza object to serve as a template for
queries made each iteration. For example, a queries made each iteration. For example, a
basic disco#items query. basic disco#items query.
interface -- The name of the substanza to which the :param interface: The name of the substanza to which the
result set management stanza should be result set management stanza should be
appended in the query stanza. For example, appended in the query stanza. For example,
for disco#items queries the interface for disco#items queries the interface
'disco_items' should be used. 'disco_items' should be used.
recv_interface -- The name of the substanza from which the :param recv_interface: The name of the substanza from which the
result set management stanza should be result set management stanza should be
read in the result stanza. If unspecified, read in the result stanza. If unspecified,
it will be set to the same value as the it will be set to the same value as the
``interface`` parameter. ``interface`` parameter.
pre_cb -- Callback to run before sending each stanza e.g. :param pre_cb: Callback to run before sending each stanza e.g.
setting the MAM queryid and starting a stanza setting the MAM queryid and starting a stanza
collector. collector.
post_cb -- Callback to run after receiving each stanza e.g. :param post_cb: Callback to run after receiving each stanza e.g.
stopping a MAM stanza collector in order to stopping a MAM stanza collector in order to
gather results. gather results.
results -- The name of the interface containing the :param results: The name of the interface containing the
query results (typically just 'substanzas'). query results (typically just 'substanzas').
:param iq_options: Optional dict of parameters for Iq.send
""" """
return ResultIterator(stanza, interface, results, amount, reverse=reverse, return ResultIterator(stanza, interface, results, amount,
recv_interface=recv_interface, pre_cb=pre_cb, reverse=reverse, recv_interface=recv_interface,
post_cb=post_cb) pre_cb=pre_cb, post_cb=post_cb,
iq_options=iq_options)

View file

@ -512,30 +512,28 @@ class TestStreamDisco(SlixTest):
self.assertEqual(results, items, self.assertEqual(results, items,
"Unexpected items: %s" % results) "Unexpected items: %s" % results)
''' def testGetItemsIterators(self):
def testGetItemsIterator(self):
"""Test interaction between XEP-0030 and XEP-0059 plugins.""" """Test interaction between XEP-0030 and XEP-0059 plugins."""
iteration_finished = []
raised_exceptions = [] jids_found = set()
self.stream_start(mode='client', self.stream_start(mode='client',
plugins=['xep_0030', 'xep_0059']) plugins=['xep_0030', 'xep_0059'])
results = self.xmpp['xep_0030'].get_items(jid='foo@localhost', async def run_test():
iterator = await self.xmpp['xep_0030'].get_items(
jid='foo@localhost',
node='bar', node='bar',
iterator=True) iterator=True
results.amount = 10 )
iterator.amount = 10
def run_test(): async for page in iterator:
try: for item in page['disco_items']['items']:
results.next() jids_found.add(item[0])
except StopIteration: iteration_finished.append(True)
raised_exceptions.append(True)
t = threading.Thread(name="get_items_iterator",
target=run_test)
t.start()
test_run = self.xmpp.wrap(run_test())
self.wait_()
self.send(""" self.send("""
<iq id="2" type="get" to="foo@localhost"> <iq id="2" type="get" to="foo@localhost">
<query xmlns="http://jabber.org/protocol/disco#items" <query xmlns="http://jabber.org/protocol/disco#items"
@ -549,17 +547,51 @@ class TestStreamDisco(SlixTest):
self.recv(""" self.recv("""
<iq id="2" type="result" to="tester@localhost"> <iq id="2" type="result" to="tester@localhost">
<query xmlns="http://jabber.org/protocol/disco#items"> <query xmlns="http://jabber.org/protocol/disco#items">
<item jid="a@b" node="1"/>
<item jid="b@b" node="2"/>
<item jid="c@b" node="3"/>
<item jid="d@b" node="4"/>
<item jid="e@b" node="5"/>
<set xmlns="http://jabber.org/protocol/rsm"> <set xmlns="http://jabber.org/protocol/rsm">
<first index='0'>a@b</first>
<last>e@b</last>
<count>10</count>
</set> </set>
</query> </query>
</iq> </iq>
""") """)
self.wait_()
t.join() self.send("""
<iq id="3" type="get" to="foo@localhost">
self.assertEqual(raised_exceptions, [True], <query xmlns="http://jabber.org/protocol/disco#items"
"StopIteration was not raised: %s" % raised_exceptions) node="bar">
''' <set xmlns="http://jabber.org/protocol/rsm">
<max>10</max>
<after>e@b</after>
</set>
</query>
</iq>
""")
self.recv("""
<iq id="3" type="result" to="tester@localhost">
<query xmlns="http://jabber.org/protocol/disco#items">
<item jid="f@b" node="6"/>
<item jid="g@b" node="7"/>
<item jid="h@b" node="8"/>
<item jid="i@b" node="9"/>
<item jid="j@b" node="10"/>
<set xmlns="http://jabber.org/protocol/rsm">
<first index='5'>f@b</first>
<last>j@b</last>
<count>10</count>
</set>
</query>
</iq>
""")
expected_jids = {'%s@b' % i for i in 'abcdefghij'}
self.run_coro(test_run)
self.assertEqual(expected_jids, jids_found)
self.assertEqual(iteration_finished, [True])
suite = unittest.TestLoader().loadTestsFromTestCase(TestStreamDisco) suite = unittest.TestLoader().loadTestsFromTestCase(TestStreamDisco)