From 6ab8bba4f0208fef95b1027fb2d9b3a11f18e493 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maxime=20=E2=80=9Cpep=E2=80=9D=20Buquet?= Date: Thu, 3 Mar 2022 20:00:18 +0100 Subject: [PATCH] Use new encryptRatchetForwardingMessage API added in f3c3a45e MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Maxime “pep” Buquet --- slixmpp_omemo/__init__.py | 76 +++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 40 deletions(-) diff --git a/slixmpp_omemo/__init__.py b/slixmpp_omemo/__init__.py index f2820e4..5638e3f 100644 --- a/slixmpp_omemo/__init__.py +++ b/slixmpp_omemo/__init__.py @@ -184,6 +184,11 @@ class ErroneousPayload(XEP0384): """To be raised when the payload is not of the form we expect""" +class ErroneousParameter(XEP0384): + """To be raised when parameters to the `encrypt_message` method aren't + used as expected.""" + + class XEP_0384(BasePlugin): """ @@ -527,34 +532,12 @@ class XEP_0384(BasePlugin): devices = await self._omemo().getDevices(jid.bare) return devices.get('active', []) - async def _chain_lengths(self, jid: JID) -> Set[Tuple[int, int]]: - """ - Gather receiving chain lengths for all devices (active / inactive) - of a JID. - - Receiving chain length is used to know when to send a heartbeat to - signal recipients our device is still active and listening. See: - https://xmpp.org/extensions/xep-0384.html#rules - """ - - bare = jid.bare - devices = await self._omemo().getDevices(bare) - active = devices.get('active', set()) - inactive = devices.get('inactive', set()) - devices = active.union(inactive) - - lengths: Set[Tuple[int, int]] = set() - for did in devices: - receiving = await self._omemo().receiving_chain_length(bare, did) - lengths.add((did, receiving or 0)) - - return lengths - - async def _should_heartbeat(self, jid: JID, prekey: bool) -> bool: + async def _should_heartbeat(self, jid: JID, device_id: int, prekey: bool) -> bool: """ Internal helper for :py:func:`XEP_0384.should_heartbeat`. - Returns whether we should send a heartbeat message for JID. + Returns whether we should send a heartbeat message for (JID, + device_id). We check if the message is a prekey message, in which case we assume it's a new session and we want to ACK relatively early. @@ -564,21 +547,20 @@ class XEP_0384(BasePlugin): still active. """ - receiving_chain_lengths = await self._chain_lengths(jid) - lengths = map(lambda d_l: d_l[1], receiving_chain_lengths) - inactive_session = max(lengths, default=0) > self.heartbeat_after + length = await self._omemo().receiving_chain_length(jid.bare, device_id) + inactive_session = (length or 0) > self.heartbeat_after return prekey or inactive_session async def should_heartbeat(self, jid: JID, msg: Union[Message, Encrypted]) -> bool: """ - Returns whether we should send a heartbeat message for JID. - See notes about heartbeat in + Returns whether we should send a heartbeat message to the sender + device. See notes about heartbeat in https://xmpp.org/extensions/xep-0384.html#rules. - This method will return True when a session among all of the - sessions for this JID is not yet confirmed, or if one of the - sessions hasn't been answered in a while. + This method will return True if this session (to the sender + device) is not yet confirmed, or if it hasn't been answered in a + while. """ prekey: bool = False @@ -589,6 +571,7 @@ class XEP_0384(BasePlugin): encrypted = msg['omemo_encrypted'] header = encrypted['header'] + sid = header['sid'] key = header.xml.find("{%s}key[@rid='%s']" % ( OMEMO_BASE_NS, self._device_id)) # Don't error out. If it's not encrypted to us we don't need to send a @@ -598,9 +581,9 @@ class XEP_0384(BasePlugin): key = Key(key) prekey = key['prekey'] in TRUE_VALUES - return await self._should_heartbeat(jid, prekey) + return await self._should_heartbeat(jid, sid, prekey) - async def make_heartbeat(self, jid: JID) -> Message: + async def make_heartbeat(self, jid: JID, device_id: int) -> Message: """ Returns a heartbeat message. @@ -615,6 +598,7 @@ class XEP_0384(BasePlugin): recipients=[jid], expect_problems=None, _ignore_trust=True, + _device_id=device_id, ) msg.append(encrypted) return msg @@ -724,12 +708,12 @@ class XEP_0384(BasePlugin): if self.auto_heartbeat: log.debug('Checking if heartbeat is required. auto_hearbeat enabled.') - should_heartbeat = await self._should_heartbeat(sender, isPrekeyMessage) + should_heartbeat = await self._should_heartbeat(sender, sid, isPrekeyMessage) if should_heartbeat: - log.debug('Decryption: Sending hearbeat to JID %r', jid, device) + log.debug('Decryption: Sending hearbeat to %s / %d', jid, sid) async def send_heartbeat(): log.debug('Sending a heartbeat message') - msg = await self.make_heartbeat(JID(jid)) + msg = await self.make_heartbeat(JID(jid), sid) msg.send() asyncio.ensure_future(send_heartbeat()) @@ -741,6 +725,7 @@ class XEP_0384(BasePlugin): recipients: List[JID], expect_problems: Optional[Dict[JID, List[int]]] = None, _ignore_trust: bool = False, + _device_id: Optional[int] = None, ) -> Encrypted: """ Returns an encrypted payload to be placed into a message. @@ -756,6 +741,11 @@ class XEP_0384(BasePlugin): These are rather technical details to the user and fiddling with parameters else than `plaintext` and `recipients` should be rarely needed. + + The `_device_id` parameter is required in the case of a ratchet + forwarding message. That is, `plaintext` to None, and `_ignore_trust` + to True. If specified, a single recipient JID is required. If not all + these conditions are met, ErroneousParameter will be raised. """ barejids: List[str] = [jid.bare for jid in recipients] @@ -781,9 +771,15 @@ class XEP_0384(BasePlugin): expect_problems=expect_problems, ) elif _ignore_trust: + if not _device_id or len(barejids) != 1: + raise ErroneousParameter + bundle = self.bundles.get(barejids[0], {}).get(_device_id) + if bundle is None: + error = omemo.exceptions.MissingBundleException(barejids[0], _device_id) + raise omemo.exceptions.EncryptionProblemsException([error]) encrypted = await self._omemo().encryptRatchetForwardingMessage( - barejids - self.bundles, + barejids[0], + bundle, expect_problems=expect_problems, ) else: