Learn about chain length: new should_heartbeat method

With this commit, slixmpp-omemo now reads the ratchet chain length,
(both receiving and sending), that we should track to know when to send
a heartbeat message.

This allows us to signal other devices that we are still active and
listening. Some clients will stop encrypting to us if we haven't replied
for a certain number of messages.

The current 0384 spec (0.7) says we should send a heartbeat message at
least once this number goes over 53 (fair dice roll). It doesn't say
when a client may/should stop encrypting to us, or what it should do at
all once we go over 53.

Signed-off-by: Maxime “pep” Buquet <pep@bouah.net>
This commit is contained in:
Maxime “pep” Buquet 2021-07-13 20:29:14 +02:00
parent 38701075e9
commit af33cd41e5
Signed by: pep
GPG key ID: DEDA74AEECA9D0F2
2 changed files with 64 additions and 0 deletions

View file

@ -90,6 +90,8 @@ class EchoBot(ClientXMPP):
await self.cmd_verbose(mto, mtype)
elif cmd == 'error':
await self.cmd_error(mto, mtype)
elif cmd == 'chain_length':
await self.cmd_chain_length(mto, mtype)
return None
@ -112,6 +114,13 @@ class EchoBot(ClientXMPP):
body = '''Debug level set to 'error'.'''
return await self.encrypted_reply(mto, mtype, body)
async def cmd_chain_length(self, mto: JID, mtype: str) -> None:
body = (
'lengths: %r\n' % self['xep_0384']._chain_lengths(mto) +
'should heartbeat: %r' % self['xep_0384'].should_heartbeat(mto)
)
return await self.encrypted_reply(mto, mtype, body)
def message_handler(self, msg: Message) -> None:
asyncio.ensure_future(self.message(msg))

View file

@ -13,6 +13,10 @@ import logging
from typing import Any, Dict, List, Optional, Set, Tuple, Union
# Not available in Python 3.7, and slixmpp already imports the right things
# for me
from slixmpp.types import TypedDict
import os
import json
import base64
@ -57,6 +61,11 @@ PUBLISH_OPTIONS_NODE = 'http://jabber.org/protocol/pubsub#publish-options'
PUBSUB_ERRORS = 'http://jabber.org/protocol/pubsub#errors'
class ChainLengths(TypedDict):
receiving: List[Tuple[int, int]]
sending: List[Tuple[int, int]]
def b64enc(data: bytes) -> str:
return base64.b64encode(bytes(bytearray(data))).decode('ASCII')
@ -188,6 +197,8 @@ class XEP_0384(BasePlugin):
'storage_backend': None,
'otpk_policy': DefaultOTPKPolicy,
'omemo_backend': SignalBackend,
'heartbeat_after': 53,
# TODO: 'drop_inactive_after': 300,
}
backend_loaded = HAS_OMEMO and HAS_OMEMO_BACKEND
@ -481,6 +492,50 @@ class XEP_0384(BasePlugin):
"""Return active device ids. Always contains our own device id."""
return self._omemo.getDevices(jid.bare).get('active', [])
def _chain_lengths(self, jid: JID) -> ChainLengths:
"""
Gather receiving and sending chain lengths for all devices (active
/ inactive) of a JID.
Receiving chain length is used to know when to send a heartbeat to
signal recipients our device is still active and listening. See:
https://xmpp.org/extensions/xep-0384.html#rules
Sending chain length is used on the other side when a device
hasn't been sending us messages and seems inactive.
# XXX: Only the receiving part is used in this library for the
# moment.
"""
# XXX: This method uses APIs that haven't been made public yet in the
# OMEMO library as of 0.12 (9fd7123).
bare = jid.bare
devices = self._omemo.getDevices(bare)
active = devices.get('active', set())
inactive = devices.get('inactive', set())
devices = active.union(inactive)
lengths: ChainLengths = {'sending': [], 'receiving': []}
for did in devices:
session = self._omemo._SessionManager__loadSession(bare, did)
skr = session._DoubleRatchet__skr
lengths['sending'].append((did, skr.sending_chain_length))
lengths['receiving'].append((did, skr.receiving_chain_length))
return lengths
def should_heartbeat(self, jid: JID) -> bool:
"""
Returns whether we should send a heartbeat message for JID.
See notes about heartbeat in
https://xmpp.org/extensions/xep-0384.html#rules.
"""
receiving_chain_lengths = self._chain_lengths(jid).get('receiving', [])
lengths = map(lambda d_l: d_l[1], receiving_chain_lengths)
return max(lengths, default=0) > self.heartbeat_after
def trust(self, jid: JID, device_id: int, ik: bytes) -> None:
self._omemo.setTrust(jid.bare, device_id, ik, True)