From bb52d93241717ae65c3e6c5331dc03c4393c8a91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maxime=20=E2=80=9Cpep=E2=80=9D=20Buquet?= Date: Sat, 17 Jul 2021 18:58:31 +0200 Subject: [PATCH] should_heartbeat: also return True on new sessions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit And the docstring now reflects the reality again! We're parsing the Encrypted dict again, when we just did it in decrypt_message above, but this function is also part of the API and doing that for them is the least we can do. Maybe there should be an internal function that we can call from decrypt_message, that also gets called by should_heartbeat. Signed-off-by: Maxime “pep” Buquet --- slixmpp_omemo/__init__.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/slixmpp_omemo/__init__.py b/slixmpp_omemo/__init__.py index f35440d..4825db3 100644 --- a/slixmpp_omemo/__init__.py +++ b/slixmpp_omemo/__init__.py @@ -543,7 +543,7 @@ class XEP_0384(BasePlugin): return lengths - def should_heartbeat(self, jid: JID) -> bool: + def should_heartbeat(self, jid: JID, msg: Union[Message, Encrypted]) -> bool: """ Returns whether we should send a heartbeat message for JID. See notes about heartbeat in @@ -554,9 +554,28 @@ class XEP_0384(BasePlugin): sessions hasn't been answered in a while. """ + new_session: bool = False + inactive_session: bool = False + + # Is the message is a prekey message. If so assume it's a new session + encrypted = msg + if isinstance(msg, Message): + encrypted = msg['omemo_encrypted'] + + header = encrypted['header'] + key = header.xml.find("{%s}key[@rid='%s']" % ( + OMEMO_BASE_NS, self._device_id)) + if key is not None: + key = Key(key) + new_session = key['prekey'] in TRUE_VALUES + + # Otherwise find how many message we haven't replied to in all of the + # sessions associated to this JID. receiving_chain_lengths = self._chain_lengths(jid).get('receiving', []) lengths = map(lambda d_l: d_l[1], receiving_chain_lengths) - return max(lengths, default=0) > self.heartbeat_after + inactive_session = max(lengths, default=0) > self.heartbeat_after + + return new_session or inactive_session async def make_heartbeat(self, jid: JID) -> Message: """ @@ -663,7 +682,7 @@ class XEP_0384(BasePlugin): finally: asyncio.ensure_future(self._publish_bundle()) - if self.auto_heartbeat and self.should_heartbeat(sender): + if self.auto_heartbeat and self.should_heartbeat(sender, encrypted): async def send_heartbeat(): msg = await self.make_heartbeat(JID(jid)) msg.send()