diff --git a/slixmpp/plugins/xep_0313/mam.py b/slixmpp/plugins/xep_0313/mam.py index e3268e9b..d407e2bf 100644 --- a/slixmpp/plugins/xep_0313/mam.py +++ b/slixmpp/plugins/xep_0313/mam.py @@ -8,9 +8,11 @@ import logging -import slixmpp +from datetime import datetime +from typing import Any, Dict, Callable, Optional, Awaitable + +from slixmpp import JID from slixmpp.stanza import Message, Iq -from slixmpp.exceptions import XMPPError from slixmpp.xmlstream.handler import Collector from slixmpp.xmlstream.matcher import StanzaPath from slixmpp.xmlstream import register_stanza_plugin @@ -41,8 +43,32 @@ class XEP_0313(BasePlugin): register_stanza_plugin(stanza.MAM, self.xmpp['xep_0059'].stanza.Set) register_stanza_plugin(stanza.Fin, self.xmpp['xep_0059'].stanza.Set) - def retrieve(self, jid=None, start=None, end=None, with_jid=None, ifrom=None, - reverse=False, timeout=None, callback=None, iterator=False, rsm=None): + def retrieve( + self, + jid: Optional[JID] = None, + start: Optional[datetime] = None, + end: Optional[datetime] = None, + with_jid: Optional[JID] = None, + ifrom: Optional[JID] = None, + reverse: bool = False, + timeout: int = None, + callback: Callable[[Iq], None] = None, + iterator: bool = False, + rsm: Optional[Dict[str, Any]] = None + ) -> Awaitable: + """ + Send a MAM query and retrieve the results. + + :param JID jid: Entity holding the MAM records + :param datetime start,end: MAM query temporal boundaries + :param JID with_jid: Filter results on this JID + :param JID ifrom: To change the from address of the query + :param bool reverse: Get the results in reverse order + :param int timeout: IQ timeout + :param func callback: Custom callback for handling results + :param bool iterator: Use RSM and iterate over a paginated query + :param dict rsm: RSM custom options + """ iq = self.xmpp.Iq() query_id = iq['id'] @@ -57,11 +83,11 @@ class XEP_0313(BasePlugin): if rsm: for key, value in rsm.items(): iq['mam']['rsm'][key] = str(value) - if key is 'max': + if key == 'max': amount = value - cb_data = {} - def pre_cb(query): + + def pre_cb(query: Iq) -> None: query['mam']['queryid'] = query['id'] collector = Collector( 'MAM_Results_%s' % query_id, @@ -69,7 +95,7 @@ class XEP_0313(BasePlugin): self.xmpp.register_handler(collector) cb_data['collector'] = collector - def post_cb(result): + def post_cb(result: Iq) -> None: results = cb_data['collector'].stop() if result['type'] == 'result': result['mam']['results'] = results @@ -84,7 +110,7 @@ class XEP_0313(BasePlugin): StanzaPath('message/mam_result@queryid=%s' % query_id)) self.xmpp.register_handler(collector) - def wrapped_cb(iq): + def wrapped_cb(iq: Iq) -> None: results = collector.stop() if iq['type'] == 'result': iq['mam']['results'] = results