Use new encryptRatchetForwardingMessage API added in f3c3a45e

Signed-off-by: Maxime “pep” Buquet <pep@bouah.net>
This commit is contained in:
Maxime “pep” Buquet 2022-03-03 20:00:18 +01:00
parent bf3f5472f7
commit 6ab8bba4f0
Signed by: pep
GPG key ID: DEDA74AEECA9D0F2

View file

@ -184,6 +184,11 @@ class ErroneousPayload(XEP0384):
"""To be raised when the payload is not of the form we expect""" """To be raised when the payload is not of the form we expect"""
class ErroneousParameter(XEP0384):
"""To be raised when parameters to the `encrypt_message` method aren't
used as expected."""
class XEP_0384(BasePlugin): class XEP_0384(BasePlugin):
""" """
@ -527,34 +532,12 @@ class XEP_0384(BasePlugin):
devices = await self._omemo().getDevices(jid.bare) devices = await self._omemo().getDevices(jid.bare)
return devices.get('active', []) return devices.get('active', [])
async def _chain_lengths(self, jid: JID) -> Set[Tuple[int, int]]: async def _should_heartbeat(self, jid: JID, device_id: int, prekey: bool) -> bool:
"""
Gather receiving chain lengths for all devices (active / inactive)
of a JID.
Receiving chain length is used to know when to send a heartbeat to
signal recipients our device is still active and listening. See:
https://xmpp.org/extensions/xep-0384.html#rules
"""
bare = jid.bare
devices = await self._omemo().getDevices(bare)
active = devices.get('active', set())
inactive = devices.get('inactive', set())
devices = active.union(inactive)
lengths: Set[Tuple[int, int]] = set()
for did in devices:
receiving = await self._omemo().receiving_chain_length(bare, did)
lengths.add((did, receiving or 0))
return lengths
async def _should_heartbeat(self, jid: JID, prekey: bool) -> bool:
""" """
Internal helper for :py:func:`XEP_0384.should_heartbeat`. Internal helper for :py:func:`XEP_0384.should_heartbeat`.
Returns whether we should send a heartbeat message for JID. Returns whether we should send a heartbeat message for (JID,
device_id).
We check if the message is a prekey message, in which case we We check if the message is a prekey message, in which case we
assume it's a new session and we want to ACK relatively early. assume it's a new session and we want to ACK relatively early.
@ -564,21 +547,20 @@ class XEP_0384(BasePlugin):
still active. still active.
""" """
receiving_chain_lengths = await self._chain_lengths(jid) length = await self._omemo().receiving_chain_length(jid.bare, device_id)
lengths = map(lambda d_l: d_l[1], receiving_chain_lengths) inactive_session = (length or 0) > self.heartbeat_after
inactive_session = max(lengths, default=0) > self.heartbeat_after
return prekey or inactive_session return prekey or inactive_session
async def should_heartbeat(self, jid: JID, msg: Union[Message, Encrypted]) -> bool: async def should_heartbeat(self, jid: JID, msg: Union[Message, Encrypted]) -> bool:
""" """
Returns whether we should send a heartbeat message for JID. Returns whether we should send a heartbeat message to the sender
See notes about heartbeat in device. See notes about heartbeat in
https://xmpp.org/extensions/xep-0384.html#rules. https://xmpp.org/extensions/xep-0384.html#rules.
This method will return True when a session among all of the This method will return True if this session (to the sender
sessions for this JID is not yet confirmed, or if one of the device) is not yet confirmed, or if it hasn't been answered in a
sessions hasn't been answered in a while. while.
""" """
prekey: bool = False prekey: bool = False
@ -589,6 +571,7 @@ class XEP_0384(BasePlugin):
encrypted = msg['omemo_encrypted'] encrypted = msg['omemo_encrypted']
header = encrypted['header'] header = encrypted['header']
sid = header['sid']
key = header.xml.find("{%s}key[@rid='%s']" % ( key = header.xml.find("{%s}key[@rid='%s']" % (
OMEMO_BASE_NS, self._device_id)) OMEMO_BASE_NS, self._device_id))
# Don't error out. If it's not encrypted to us we don't need to send a # Don't error out. If it's not encrypted to us we don't need to send a
@ -598,9 +581,9 @@ class XEP_0384(BasePlugin):
key = Key(key) key = Key(key)
prekey = key['prekey'] in TRUE_VALUES prekey = key['prekey'] in TRUE_VALUES
return await self._should_heartbeat(jid, prekey) return await self._should_heartbeat(jid, sid, prekey)
async def make_heartbeat(self, jid: JID) -> Message: async def make_heartbeat(self, jid: JID, device_id: int) -> Message:
""" """
Returns a heartbeat message. Returns a heartbeat message.
@ -615,6 +598,7 @@ class XEP_0384(BasePlugin):
recipients=[jid], recipients=[jid],
expect_problems=None, expect_problems=None,
_ignore_trust=True, _ignore_trust=True,
_device_id=device_id,
) )
msg.append(encrypted) msg.append(encrypted)
return msg return msg
@ -724,12 +708,12 @@ class XEP_0384(BasePlugin):
if self.auto_heartbeat: if self.auto_heartbeat:
log.debug('Checking if heartbeat is required. auto_hearbeat enabled.') log.debug('Checking if heartbeat is required. auto_hearbeat enabled.')
should_heartbeat = await self._should_heartbeat(sender, isPrekeyMessage) should_heartbeat = await self._should_heartbeat(sender, sid, isPrekeyMessage)
if should_heartbeat: if should_heartbeat:
log.debug('Decryption: Sending hearbeat to JID %r', jid, device) log.debug('Decryption: Sending hearbeat to %s / %d', jid, sid)
async def send_heartbeat(): async def send_heartbeat():
log.debug('Sending a heartbeat message') log.debug('Sending a heartbeat message')
msg = await self.make_heartbeat(JID(jid)) msg = await self.make_heartbeat(JID(jid), sid)
msg.send() msg.send()
asyncio.ensure_future(send_heartbeat()) asyncio.ensure_future(send_heartbeat())
@ -741,6 +725,7 @@ class XEP_0384(BasePlugin):
recipients: List[JID], recipients: List[JID],
expect_problems: Optional[Dict[JID, List[int]]] = None, expect_problems: Optional[Dict[JID, List[int]]] = None,
_ignore_trust: bool = False, _ignore_trust: bool = False,
_device_id: Optional[int] = None,
) -> Encrypted: ) -> Encrypted:
""" """
Returns an encrypted payload to be placed into a message. Returns an encrypted payload to be placed into a message.
@ -756,6 +741,11 @@ class XEP_0384(BasePlugin):
These are rather technical details to the user and fiddling with These are rather technical details to the user and fiddling with
parameters else than `plaintext` and `recipients` should be rarely parameters else than `plaintext` and `recipients` should be rarely
needed. needed.
The `_device_id` parameter is required in the case of a ratchet
forwarding message. That is, `plaintext` to None, and `_ignore_trust`
to True. If specified, a single recipient JID is required. If not all
these conditions are met, ErroneousParameter will be raised.
""" """
barejids: List[str] = [jid.bare for jid in recipients] barejids: List[str] = [jid.bare for jid in recipients]
@ -781,9 +771,15 @@ class XEP_0384(BasePlugin):
expect_problems=expect_problems, expect_problems=expect_problems,
) )
elif _ignore_trust: elif _ignore_trust:
if not _device_id or len(barejids) != 1:
raise ErroneousParameter
bundle = self.bundles.get(barejids[0], {}).get(_device_id)
if bundle is None:
error = omemo.exceptions.MissingBundleException(barejids[0], _device_id)
raise omemo.exceptions.EncryptionProblemsException([error])
encrypted = await self._omemo().encryptRatchetForwardingMessage( encrypted = await self._omemo().encryptRatchetForwardingMessage(
barejids barejids[0],
self.bundles, bundle,
expect_problems=expect_problems, expect_problems=expect_problems,
) )
else: else: