diff --git a/poezio/mam.py b/poezio/mam.py index 50dad4a3..05275975 100644 --- a/poezio/mam.py +++ b/poezio/mam.py @@ -6,34 +6,43 @@ XEP-0313: Message Archive Management(MAM). """ +import asyncio +import logging import random from datetime import datetime, timedelta, timezone from hashlib import md5 -from typing import Optional, Callable +from typing import ( + AsyncIterable, + Callable, + Dict, + List, + Optional, +) -from slixmpp import JID +from slixmpp import JID, Message as SMessage from slixmpp.exceptions import IqError, IqTimeout from poezio.theming import get_theme from poezio import tabs from poezio import xhtml, colors from poezio.config import config -from poezio.text_buffer import TextBuffer -from poezio.ui.types import Message +from poezio.text_buffer import TextBuffer, HistoryGap +from poezio.ui.types import BaseMessage, Message +log = logging.getLogger(__name__) + class DiscoInfoException(Exception): pass class MAMQueryException(Exception): pass class NoMAMSupportException(Exception): pass -def add_line( - tab, - text_buffer: TextBuffer, +def make_line( + tab: tabs.Tab, text: str, time: datetime, nick: str, - top: bool, - ) -> None: + identifier: str = '', + ) -> Message: """Adds a textual entry in the TextBuffer""" # Convert to local timezone @@ -61,39 +70,40 @@ def add_line( color = xhtml.colors.get(color) color = (color, -1) else: - nick = nick.split('/')[0] - color = get_theme().COLOR_OWN_NICK - text_buffer.add_message( - Message( - txt=text, - time=time, - nickname=nick, - nick_color=color, - history=True, - user=None, - top=top, - ) + if nick.split('/')[0] == tab.core.xmpp.boundjid.bare: + color = get_theme().COLOR_OWN_NICK + else: + color = get_theme().COLOR_REMOTE_USER + nick = tab.get_nick() + return Message( + txt=text, + identifier=identifier, + time=time, + nickname=nick, + nick_color=color, + history=True, + user=None, ) -async def query( +async def get_mam_iterator( core, groupchat: bool, remote_jid: JID, amount: int, - reverse: bool, + reverse: bool = True, start: Optional[datetime] = None, end: Optional[datetime] = None, before: Optional[str] = None, - callback: Optional[Callable] = None, - ) -> None: + ) -> AsyncIterable[Message]: + """Get an async iterator for this mam query""" try: query_jid = remote_jid if groupchat else JID(core.xmpp.boundjid.bare) iq = await core.xmpp.plugin['xep_0030'].get_info(jid=query_jid) except (IqError, IqTimeout): - raise DiscoInfoException + raise DiscoInfoException() if 'urn:xmpp:mam:2' not in iq['disco_info'].get_features(): - raise NoMAMSupportException + raise NoMAMSupportException() args = { 'iterator': True, @@ -105,64 +115,66 @@ async def query( else: args['with_jid'] = remote_jid - args['rsm'] = {'max': amount} - if reverse: - if before is not None: - args['rsm']['before'] = before - else: - args['end'] = end - else: - args['rsm']['start'] = start - if before is not None: - args['rsm']['end'] = end - try: - results = core.xmpp['xep_0313'].retrieve(**args) - except (IqError, IqTimeout): - raise MAMQueryException - if callback is not None: - callback(results) - - return results + if amount > 0: + args['rsm'] = {'max': amount} + args['start'] = start + args['end'] = end + return core.xmpp['xep_0313'].retrieve(**args) -async def add_messages_to_buffer(tab, top: bool, results, amount: int) -> bool: - """Prepends or appends messages to the tab text_buffer""" +def _parse_message(msg: SMessage) -> Dict: + """Parse info inside a MAM forwarded message""" + forwarded = msg['mam_result']['forwarded'] + message = forwarded['stanza'] + return { + 'time': forwarded['delay']['stamp'], + 'nick': str(message['from']), + 'text': message['body'], + 'identifier': message['origin-id'] + } + +async def retrieve_messages(tab: tabs.Tab, + results: AsyncIterable[SMessage], + amount: int = 100) -> List[Message]: + """Run the MAM query and put messages in order""" text_buffer = tab._text_buffer msg_count = 0 msgs = [] - async for rsm in results: - if top: + to_add = [] + last_stanza_id = tab.last_stanza_id + try: + async for rsm in results: for msg in rsm['mam']['results']: if msg['mam_result']['forwarded']['stanza'] \ - .xml.find('{%s}%s' % ('jabber:client', 'body')) is not None: - msgs.append(msg) - if msg_count == amount: - tab.core.refresh_window() - return False + .xml.find('{%s}%s' % ('jabber:client', 'body')) is not None: + args = _parse_message(msg) + msgs.append(make_line(tab, **args)) + for msg in reversed(msgs): + to_add.append(msg) msg_count += 1 - msgs.reverse() - for msg in msgs: - forwarded = msg['mam_result']['forwarded'] - timestamp = forwarded['delay']['stamp'] - message = forwarded['stanza'] - tab.last_stanza_id = msg['mam_result']['id'] - nick = str(message['from']) - add_line(tab, text_buffer, message['body'], timestamp, nick, top) - else: - for msg in rsm['mam']['results']: - forwarded = msg['mam_result']['forwarded'] - timestamp = forwarded['delay']['stamp'] - message = forwarded['stanza'] - nick = str(message['from']) - add_line(tab, text_buffer, message['body'], timestamp, nick, top) - tab.core.refresh_window() - return False + if msg_count == amount: + to_add.reverse() + return to_add + msgs = [] + to_add.reverse() + return to_add + except (IqError, IqTimeout) as exc: + log.debug('Unable to complete MAM query: %s', exc, exc_info=True) + raise MAMQueryException('Query interrupted') -async def fetch_history(tab, end: Optional[datetime] = None, amount: Optional[int] = None): +async def fetch_history(tab: tabs.Tab, + start: Optional[datetime] = None, + end: Optional[datetime] = None, + amount: Optional[int] = None) -> None: remote_jid = tab.jid - before = tab.last_stanza_id + if not end: + for msg in tab._text_buffer.messages: + if isinstance(msg, Message): + end = msg.time + end -= timedelta(microseconds=1) + break if end is None: end = datetime.now() tzone = datetime.now().astimezone().tzinfo @@ -170,38 +182,74 @@ async def fetch_history(tab, end: Optional[datetime] = None, amount: Optional[in end = end.replace(tzinfo=None) end = datetime.strftime(end, '%Y-%m-%dT%H:%M:%SZ') - if amount >= 100: - amount = 99 + if start is not None: + start = start.replace(tzinfo=tzone).astimezone(tz=timezone.utc) + start = start.replace(tzinfo=None) + start = datetime.strftime(start, '%Y-%m-%dT%H:%M:%SZ') - groupchat = isinstance(tab, tabs.MucTab) - - results = await query( - tab.core, - groupchat, - remote_jid, - amount, - reverse=True, + mam_iterator = await get_mam_iterator( + core=tab.core, + groupchat=isinstance(tab, tabs.MucTab), + remote_jid=remote_jid, + amount=amount, end=end, - before=before, + start=start, + reverse=True, ) - query_status = await add_messages_to_buffer(tab, True, results, amount) - tab.query_status = query_status + return await retrieve_messages(tab, mam_iterator, amount) +async def fill_missing_history(tab: tabs.Tab, gap: HistoryGap) -> None: + start = gap.last_timestamp_before_leave + end = gap.first_timestamp_after_join + if start: + start = start + timedelta(seconds=1) + if end: + end = end - timedelta(seconds=1) + try: + messages = await fetch_history(tab, start=start, end=end, amount=999) + tab._text_buffer.add_history_messages(messages, gap=gap) + tab.core.refresh_window() + except (NoMAMSupportException, MAMQueryException, DiscoInfoException): + return + finally: + tab.query_status = False -async def on_tab_open(tab) -> None: +async def on_new_tab_open(tab: tabs.Tab) -> None: + """Called when opening a new tab""" amount = 2 * tab.text_win.height end = datetime.now() - tab.query_status = True for message in tab._text_buffer.messages: - time = message.time - if time < end: - end = time - end = end + timedelta(seconds=-1) + if isinstance(message, Message) and message.time < end: + end = message.time + break + end = end - timedelta(microseconds=1) try: - await fetch_history(tab, end=end, amount=amount) + messages = await fetch_history(tab, end=end, amount=amount) + tab._text_buffer.add_history_messages(messages) except (NoMAMSupportException, MAMQueryException, DiscoInfoException): - tab.query_status = False return None + finally: + tab.query_status = False + + +def schedule_tab_open(tab: tabs.Tab) -> None: + """Set the query status and schedule a MAM query""" + tab.query_status = True + asyncio.ensure_future(on_tab_open(tab)) + + +async def on_tab_open(tab: tabs.Tab) -> None: + gap = tab._text_buffer.find_last_gap_muc() + if gap is not None: + await fill_missing_history(tab, gap) + else: + await on_new_tab_open(tab) + + +def schedule_scroll_up(tab: tabs.Tab) -> None: + """Set query status and schedule a scroll up""" + tab.query_status = True + asyncio.ensure_future(on_scroll_up(tab)) async def on_scroll_up(tab) -> None: @@ -212,22 +260,22 @@ async def on_scroll_up(tab) -> None: # join if not already available. total, pos, height = len(tw.built_lines), tw.pos, tw.height rest = (total - pos) // height - # Not resetting the state of query_status here, it is changed only after the - # query is complete (in fetch_history) - # This is done to stop message repetition, eg: if the user presses PageUp continuously. - tab.query_status = True if rest > 1: + tab.query_status = False return None try: # XXX: Do we want to fetch a possibly variable number of messages? # (InfoTab changes height depending on the type of messages, see # `information_buffer_popup_on`). - await fetch_history(tab, amount=height) + messages = await fetch_history(tab, amount=height) + tab._text_buffer.add_history_messages(messages) except NoMAMSupportException: tab.core.information('MAM not supported for %r' % tab.jid, 'Info') return None except (MAMQueryException, DiscoInfoException): tab.core.information('An error occured when fetching MAM for %r' % tab.jid, 'Error') return None + finally: + tab.query_status = False diff --git a/poezio/tabs/basetabs.py b/poezio/tabs/basetabs.py index fbb0c4cf..490363f0 100644 --- a/poezio/tabs/basetabs.py +++ b/poezio/tabs/basetabs.py @@ -32,7 +32,6 @@ from typing import ( ) from poezio import ( - mam, poopt, timed_events, xhtml, @@ -926,7 +925,8 @@ class ChatTab(Tab): def on_scroll_up(self): if not self.query_status: - asyncio.ensure_future(mam.on_scroll_up(tab=self)) + from poezio import mam + mam.schedule_scroll_up(tab=self) return self.text_win.scroll_up(self.text_win.height - 1) def on_scroll_down(self): diff --git a/poezio/tabs/muctab.py b/poezio/tabs/muctab.py index 7b0f8a42..edf80bb6 100644 --- a/poezio/tabs/muctab.py +++ b/poezio/tabs/muctab.py @@ -170,7 +170,7 @@ class MucTab(ChatTab): status=status.message, show=status.show, seconds=seconds) - asyncio.ensure_future(mam.on_tab_open(self)) + mam.schedule_tab_open(self) def leave_room(self, message: str): if self.joined: