diff --git a/examples/echo_client.py b/examples/echo_client.py index d775bf0..a4ae84e 100644 --- a/examples/echo_client.py +++ b/examples/echo_client.py @@ -90,6 +90,8 @@ class EchoBot(ClientXMPP): await self.cmd_verbose(mto, mtype) elif cmd == 'error': await self.cmd_error(mto, mtype) + elif cmd == 'chain_length': + await self.cmd_chain_length(mto, mtype) return None @@ -112,6 +114,13 @@ class EchoBot(ClientXMPP): body = '''Debug level set to 'error'.''' return await self.encrypted_reply(mto, mtype, body) + async def cmd_chain_length(self, mto: JID, mtype: str) -> None: + body = ( + 'lengths: %r\n' % self['xep_0384']._chain_lengths(mto) + + 'should heartbeat: %r' % self['xep_0384'].should_heartbeat(mto) + ) + return await self.encrypted_reply(mto, mtype, body) + def message_handler(self, msg: Message) -> None: asyncio.ensure_future(self.message(msg)) diff --git a/slixmpp_omemo/__init__.py b/slixmpp_omemo/__init__.py index 6d33e3c..47c30c5 100644 --- a/slixmpp_omemo/__init__.py +++ b/slixmpp_omemo/__init__.py @@ -13,6 +13,10 @@ import logging from typing import Any, Dict, List, Optional, Set, Tuple, Union +# Not available in Python 3.7, and slixmpp already imports the right things +# for me +from slixmpp.types import TypedDict + import os import json import base64 @@ -57,6 +61,11 @@ PUBLISH_OPTIONS_NODE = 'http://jabber.org/protocol/pubsub#publish-options' PUBSUB_ERRORS = 'http://jabber.org/protocol/pubsub#errors' +class ChainLengths(TypedDict): + receiving: List[Tuple[int, int]] + sending: List[Tuple[int, int]] + + def b64enc(data: bytes) -> str: return base64.b64encode(bytes(bytearray(data))).decode('ASCII') @@ -188,6 +197,8 @@ class XEP_0384(BasePlugin): 'storage_backend': None, 'otpk_policy': DefaultOTPKPolicy, 'omemo_backend': SignalBackend, + 'heartbeat_after': 53, + # TODO: 'drop_inactive_after': 300, } backend_loaded = HAS_OMEMO and HAS_OMEMO_BACKEND @@ -481,6 +492,50 @@ class XEP_0384(BasePlugin): """Return active device ids. Always contains our own device id.""" return self._omemo.getDevices(jid.bare).get('active', []) + def _chain_lengths(self, jid: JID) -> ChainLengths: + """ + Gather receiving and sending chain lengths for all devices (active + / inactive) of a JID. + + Receiving chain length is used to know when to send a heartbeat to + signal recipients our device is still active and listening. See: + https://xmpp.org/extensions/xep-0384.html#rules + + Sending chain length is used on the other side when a device + hasn't been sending us messages and seems inactive. + + # XXX: Only the receiving part is used in this library for the + # moment. + """ + # XXX: This method uses APIs that haven't been made public yet in the + # OMEMO library as of 0.12 (9fd7123). + + bare = jid.bare + devices = self._omemo.getDevices(bare) + active = devices.get('active', set()) + inactive = devices.get('inactive', set()) + devices = active.union(inactive) + + lengths: ChainLengths = {'sending': [], 'receiving': []} + for did in devices: + session = self._omemo._SessionManager__loadSession(bare, did) + skr = session._DoubleRatchet__skr + lengths['sending'].append((did, skr.sending_chain_length)) + lengths['receiving'].append((did, skr.receiving_chain_length)) + + return lengths + + def should_heartbeat(self, jid: JID) -> bool: + """ + Returns whether we should send a heartbeat message for JID. + See notes about heartbeat in + https://xmpp.org/extensions/xep-0384.html#rules. + """ + + receiving_chain_lengths = self._chain_lengths(jid).get('receiving', []) + lengths = map(lambda d_l: d_l[1], receiving_chain_lengths) + return max(lengths, default=0) > self.heartbeat_after + def trust(self, jid: JID, device_id: int, ik: bytes) -> None: self._omemo.setTrust(jid.bare, device_id, ik, True)