MAM: many changes

- Fix color & nicks in one to one chats
- Make poezio-facing functions "schedules" to avoid races on tab query
  state
- Rename functions
- Use a different behavior when filling a history gap and populating a
  new tab in a MUC
This commit is contained in:
mathieui 2020-05-22 01:36:13 +02:00
parent 54339ee7e0
commit d174e1fa35
3 changed files with 150 additions and 102 deletions

View file

@ -6,34 +6,43 @@
XEP-0313: Message Archive Management(MAM). XEP-0313: Message Archive Management(MAM).
""" """
import asyncio
import logging
import random import random
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from hashlib import md5 from hashlib import md5
from typing import Optional, Callable from typing import (
AsyncIterable,
Callable,
Dict,
List,
Optional,
)
from slixmpp import JID from slixmpp import JID, Message as SMessage
from slixmpp.exceptions import IqError, IqTimeout from slixmpp.exceptions import IqError, IqTimeout
from poezio.theming import get_theme from poezio.theming import get_theme
from poezio import tabs from poezio import tabs
from poezio import xhtml, colors from poezio import xhtml, colors
from poezio.config import config from poezio.config import config
from poezio.text_buffer import TextBuffer from poezio.text_buffer import TextBuffer, HistoryGap
from poezio.ui.types import Message from poezio.ui.types import BaseMessage, Message
log = logging.getLogger(__name__)
class DiscoInfoException(Exception): pass class DiscoInfoException(Exception): pass
class MAMQueryException(Exception): pass class MAMQueryException(Exception): pass
class NoMAMSupportException(Exception): pass class NoMAMSupportException(Exception): pass
def add_line( def make_line(
tab, tab: tabs.Tab,
text_buffer: TextBuffer,
text: str, text: str,
time: datetime, time: datetime,
nick: str, nick: str,
top: bool, identifier: str = '',
) -> None: ) -> Message:
"""Adds a textual entry in the TextBuffer""" """Adds a textual entry in the TextBuffer"""
# Convert to local timezone # Convert to local timezone
@ -61,39 +70,40 @@ def add_line(
color = xhtml.colors.get(color) color = xhtml.colors.get(color)
color = (color, -1) color = (color, -1)
else: else:
nick = nick.split('/')[0] if nick.split('/')[0] == tab.core.xmpp.boundjid.bare:
color = get_theme().COLOR_OWN_NICK color = get_theme().COLOR_OWN_NICK
text_buffer.add_message( else:
Message( color = get_theme().COLOR_REMOTE_USER
txt=text, nick = tab.get_nick()
time=time, return Message(
nickname=nick, txt=text,
nick_color=color, identifier=identifier,
history=True, time=time,
user=None, nickname=nick,
top=top, nick_color=color,
) history=True,
user=None,
) )
async def query( async def get_mam_iterator(
core, core,
groupchat: bool, groupchat: bool,
remote_jid: JID, remote_jid: JID,
amount: int, amount: int,
reverse: bool, reverse: bool = True,
start: Optional[datetime] = None, start: Optional[datetime] = None,
end: Optional[datetime] = None, end: Optional[datetime] = None,
before: Optional[str] = None, before: Optional[str] = None,
callback: Optional[Callable] = None, ) -> AsyncIterable[Message]:
) -> None: """Get an async iterator for this mam query"""
try: try:
query_jid = remote_jid if groupchat else JID(core.xmpp.boundjid.bare) query_jid = remote_jid if groupchat else JID(core.xmpp.boundjid.bare)
iq = await core.xmpp.plugin['xep_0030'].get_info(jid=query_jid) iq = await core.xmpp.plugin['xep_0030'].get_info(jid=query_jid)
except (IqError, IqTimeout): except (IqError, IqTimeout):
raise DiscoInfoException raise DiscoInfoException()
if 'urn:xmpp:mam:2' not in iq['disco_info'].get_features(): if 'urn:xmpp:mam:2' not in iq['disco_info'].get_features():
raise NoMAMSupportException raise NoMAMSupportException()
args = { args = {
'iterator': True, 'iterator': True,
@ -105,64 +115,66 @@ async def query(
else: else:
args['with_jid'] = remote_jid args['with_jid'] = remote_jid
args['rsm'] = {'max': amount} if amount > 0:
if reverse: args['rsm'] = {'max': amount}
if before is not None: args['start'] = start
args['rsm']['before'] = before args['end'] = end
else: return core.xmpp['xep_0313'].retrieve(**args)
args['end'] = end
else:
args['rsm']['start'] = start
if before is not None:
args['rsm']['end'] = end
try:
results = core.xmpp['xep_0313'].retrieve(**args)
except (IqError, IqTimeout):
raise MAMQueryException
if callback is not None:
callback(results)
return results
async def add_messages_to_buffer(tab, top: bool, results, amount: int) -> bool: def _parse_message(msg: SMessage) -> Dict:
"""Prepends or appends messages to the tab text_buffer""" """Parse info inside a MAM forwarded message"""
forwarded = msg['mam_result']['forwarded']
message = forwarded['stanza']
return {
'time': forwarded['delay']['stamp'],
'nick': str(message['from']),
'text': message['body'],
'identifier': message['origin-id']
}
async def retrieve_messages(tab: tabs.Tab,
results: AsyncIterable[SMessage],
amount: int = 100) -> List[Message]:
"""Run the MAM query and put messages in order"""
text_buffer = tab._text_buffer text_buffer = tab._text_buffer
msg_count = 0 msg_count = 0
msgs = [] msgs = []
async for rsm in results: to_add = []
if top: last_stanza_id = tab.last_stanza_id
try:
async for rsm in results:
for msg in rsm['mam']['results']: for msg in rsm['mam']['results']:
if msg['mam_result']['forwarded']['stanza'] \ if msg['mam_result']['forwarded']['stanza'] \
.xml.find('{%s}%s' % ('jabber:client', 'body')) is not None: .xml.find('{%s}%s' % ('jabber:client', 'body')) is not None:
msgs.append(msg) args = _parse_message(msg)
if msg_count == amount: msgs.append(make_line(tab, **args))
tab.core.refresh_window() for msg in reversed(msgs):
return False to_add.append(msg)
msg_count += 1 msg_count += 1
msgs.reverse() if msg_count == amount:
for msg in msgs: to_add.reverse()
forwarded = msg['mam_result']['forwarded'] return to_add
timestamp = forwarded['delay']['stamp'] msgs = []
message = forwarded['stanza'] to_add.reverse()
tab.last_stanza_id = msg['mam_result']['id'] return to_add
nick = str(message['from']) except (IqError, IqTimeout) as exc:
add_line(tab, text_buffer, message['body'], timestamp, nick, top) log.debug('Unable to complete MAM query: %s', exc, exc_info=True)
else: raise MAMQueryException('Query interrupted')
for msg in rsm['mam']['results']:
forwarded = msg['mam_result']['forwarded']
timestamp = forwarded['delay']['stamp']
message = forwarded['stanza']
nick = str(message['from'])
add_line(tab, text_buffer, message['body'], timestamp, nick, top)
tab.core.refresh_window()
return False
async def fetch_history(tab, end: Optional[datetime] = None, amount: Optional[int] = None): async def fetch_history(tab: tabs.Tab,
start: Optional[datetime] = None,
end: Optional[datetime] = None,
amount: Optional[int] = None) -> None:
remote_jid = tab.jid remote_jid = tab.jid
before = tab.last_stanza_id if not end:
for msg in tab._text_buffer.messages:
if isinstance(msg, Message):
end = msg.time
end -= timedelta(microseconds=1)
break
if end is None: if end is None:
end = datetime.now() end = datetime.now()
tzone = datetime.now().astimezone().tzinfo tzone = datetime.now().astimezone().tzinfo
@ -170,38 +182,74 @@ async def fetch_history(tab, end: Optional[datetime] = None, amount: Optional[in
end = end.replace(tzinfo=None) end = end.replace(tzinfo=None)
end = datetime.strftime(end, '%Y-%m-%dT%H:%M:%SZ') end = datetime.strftime(end, '%Y-%m-%dT%H:%M:%SZ')
if amount >= 100: if start is not None:
amount = 99 start = start.replace(tzinfo=tzone).astimezone(tz=timezone.utc)
start = start.replace(tzinfo=None)
start = datetime.strftime(start, '%Y-%m-%dT%H:%M:%SZ')
groupchat = isinstance(tab, tabs.MucTab) mam_iterator = await get_mam_iterator(
core=tab.core,
results = await query( groupchat=isinstance(tab, tabs.MucTab),
tab.core, remote_jid=remote_jid,
groupchat, amount=amount,
remote_jid,
amount,
reverse=True,
end=end, end=end,
before=before, start=start,
reverse=True,
) )
query_status = await add_messages_to_buffer(tab, True, results, amount) return await retrieve_messages(tab, mam_iterator, amount)
tab.query_status = query_status
async def fill_missing_history(tab: tabs.Tab, gap: HistoryGap) -> None:
start = gap.last_timestamp_before_leave
end = gap.first_timestamp_after_join
if start:
start = start + timedelta(seconds=1)
if end:
end = end - timedelta(seconds=1)
try:
messages = await fetch_history(tab, start=start, end=end, amount=999)
tab._text_buffer.add_history_messages(messages, gap=gap)
tab.core.refresh_window()
except (NoMAMSupportException, MAMQueryException, DiscoInfoException):
return
finally:
tab.query_status = False
async def on_tab_open(tab) -> None: async def on_new_tab_open(tab: tabs.Tab) -> None:
"""Called when opening a new tab"""
amount = 2 * tab.text_win.height amount = 2 * tab.text_win.height
end = datetime.now() end = datetime.now()
tab.query_status = True
for message in tab._text_buffer.messages: for message in tab._text_buffer.messages:
time = message.time if isinstance(message, Message) and message.time < end:
if time < end: end = message.time
end = time break
end = end + timedelta(seconds=-1) end = end - timedelta(microseconds=1)
try: try:
await fetch_history(tab, end=end, amount=amount) messages = await fetch_history(tab, end=end, amount=amount)
tab._text_buffer.add_history_messages(messages)
except (NoMAMSupportException, MAMQueryException, DiscoInfoException): except (NoMAMSupportException, MAMQueryException, DiscoInfoException):
tab.query_status = False
return None return None
finally:
tab.query_status = False
def schedule_tab_open(tab: tabs.Tab) -> None:
"""Set the query status and schedule a MAM query"""
tab.query_status = True
asyncio.ensure_future(on_tab_open(tab))
async def on_tab_open(tab: tabs.Tab) -> None:
gap = tab._text_buffer.find_last_gap_muc()
if gap is not None:
await fill_missing_history(tab, gap)
else:
await on_new_tab_open(tab)
def schedule_scroll_up(tab: tabs.Tab) -> None:
"""Set query status and schedule a scroll up"""
tab.query_status = True
asyncio.ensure_future(on_scroll_up(tab))
async def on_scroll_up(tab) -> None: async def on_scroll_up(tab) -> None:
@ -212,22 +260,22 @@ async def on_scroll_up(tab) -> None:
# join if not already available. # join if not already available.
total, pos, height = len(tw.built_lines), tw.pos, tw.height total, pos, height = len(tw.built_lines), tw.pos, tw.height
rest = (total - pos) // height rest = (total - pos) // height
# Not resetting the state of query_status here, it is changed only after the
# query is complete (in fetch_history)
# This is done to stop message repetition, eg: if the user presses PageUp continuously.
tab.query_status = True
if rest > 1: if rest > 1:
tab.query_status = False
return None return None
try: try:
# XXX: Do we want to fetch a possibly variable number of messages? # XXX: Do we want to fetch a possibly variable number of messages?
# (InfoTab changes height depending on the type of messages, see # (InfoTab changes height depending on the type of messages, see
# `information_buffer_popup_on`). # `information_buffer_popup_on`).
await fetch_history(tab, amount=height) messages = await fetch_history(tab, amount=height)
tab._text_buffer.add_history_messages(messages)
except NoMAMSupportException: except NoMAMSupportException:
tab.core.information('MAM not supported for %r' % tab.jid, 'Info') tab.core.information('MAM not supported for %r' % tab.jid, 'Info')
return None return None
except (MAMQueryException, DiscoInfoException): except (MAMQueryException, DiscoInfoException):
tab.core.information('An error occured when fetching MAM for %r' % tab.jid, 'Error') tab.core.information('An error occured when fetching MAM for %r' % tab.jid, 'Error')
return None return None
finally:
tab.query_status = False

View file

@ -32,7 +32,6 @@ from typing import (
) )
from poezio import ( from poezio import (
mam,
poopt, poopt,
timed_events, timed_events,
xhtml, xhtml,
@ -926,7 +925,8 @@ class ChatTab(Tab):
def on_scroll_up(self): def on_scroll_up(self):
if not self.query_status: if not self.query_status:
asyncio.ensure_future(mam.on_scroll_up(tab=self)) from poezio import mam
mam.schedule_scroll_up(tab=self)
return self.text_win.scroll_up(self.text_win.height - 1) return self.text_win.scroll_up(self.text_win.height - 1)
def on_scroll_down(self): def on_scroll_down(self):

View file

@ -170,7 +170,7 @@ class MucTab(ChatTab):
status=status.message, status=status.message,
show=status.show, show=status.show,
seconds=seconds) seconds=seconds)
asyncio.ensure_future(mam.on_tab_open(self)) mam.schedule_tab_open(self)
def leave_room(self, message: str): def leave_room(self, message: str):
if self.joined: if self.joined: