diff --git a/slixmpp_omemo/__init__.py b/slixmpp_omemo/__init__.py index f35440d..4825db3 100644 --- a/slixmpp_omemo/__init__.py +++ b/slixmpp_omemo/__init__.py @@ -543,7 +543,7 @@ class XEP_0384(BasePlugin): return lengths - def should_heartbeat(self, jid: JID) -> bool: + def should_heartbeat(self, jid: JID, msg: Union[Message, Encrypted]) -> bool: """ Returns whether we should send a heartbeat message for JID. See notes about heartbeat in @@ -554,9 +554,28 @@ class XEP_0384(BasePlugin): sessions hasn't been answered in a while. """ + new_session: bool = False + inactive_session: bool = False + + # Is the message is a prekey message. If so assume it's a new session + encrypted = msg + if isinstance(msg, Message): + encrypted = msg['omemo_encrypted'] + + header = encrypted['header'] + key = header.xml.find("{%s}key[@rid='%s']" % ( + OMEMO_BASE_NS, self._device_id)) + if key is not None: + key = Key(key) + new_session = key['prekey'] in TRUE_VALUES + + # Otherwise find how many message we haven't replied to in all of the + # sessions associated to this JID. receiving_chain_lengths = self._chain_lengths(jid).get('receiving', []) lengths = map(lambda d_l: d_l[1], receiving_chain_lengths) - return max(lengths, default=0) > self.heartbeat_after + inactive_session = max(lengths, default=0) > self.heartbeat_after + + return new_session or inactive_session async def make_heartbeat(self, jid: JID) -> Message: """ @@ -663,7 +682,7 @@ class XEP_0384(BasePlugin): finally: asyncio.ensure_future(self._publish_bundle()) - if self.auto_heartbeat and self.should_heartbeat(sender): + if self.auto_heartbeat and self.should_heartbeat(sender, encrypted): async def send_heartbeat(): msg = await self.make_heartbeat(JID(jid)) msg.send()