diff --git a/slixmpp_omemo/__init__.py b/slixmpp_omemo/__init__.py index 4825db3..dafd807 100644 --- a/slixmpp_omemo/__init__.py +++ b/slixmpp_omemo/__init__.py @@ -543,6 +543,26 @@ class XEP_0384(BasePlugin): return lengths + def _should_heartbeat(self, jid: JID, prekey: bool) -> bool: + """ + Internal helper for :py:func:`XEP_0384.should_heartbeat`. + + Returns whether we should send a heartbeat message for JID. + + We check if the message is a prekey message, in which case we + assume it's a new session and we want to ACK relatively early. + + Otherwise we look at the number of messages since we have last + replied and if above a certain threshold we notify them that we're + still active. + """ + + receiving_chain_lengths = self._chain_lengths(jid).get('receiving', []) + lengths = map(lambda d_l: d_l[1], receiving_chain_lengths) + inactive_session = max(lengths, default=0) > self.heartbeat_after + + return prekey or inactive_session + def should_heartbeat(self, jid: JID, msg: Union[Message, Encrypted]) -> bool: """ Returns whether we should send a heartbeat message for JID. @@ -554,10 +574,9 @@ class XEP_0384(BasePlugin): sessions hasn't been answered in a while. """ - new_session: bool = False - inactive_session: bool = False + prekey: bool = False - # Is the message is a prekey message. If so assume it's a new session + # Get prekey information from message encrypted = msg if isinstance(msg, Message): encrypted = msg['omemo_encrypted'] @@ -565,17 +584,13 @@ class XEP_0384(BasePlugin): header = encrypted['header'] key = header.xml.find("{%s}key[@rid='%s']" % ( OMEMO_BASE_NS, self._device_id)) + # Don't error out. If it's not encrypted to us we don't need to send a + # heartbeat. if key is not None: key = Key(key) - new_session = key['prekey'] in TRUE_VALUES + prekey = key['prekey'] in TRUE_VALUES - # Otherwise find how many message we haven't replied to in all of the - # sessions associated to this JID. - receiving_chain_lengths = self._chain_lengths(jid).get('receiving', []) - lengths = map(lambda d_l: d_l[1], receiving_chain_lengths) - inactive_session = max(lengths, default=0) > self.heartbeat_after - - return new_session or inactive_session + return self._should_heartbeat(jid, prekey) async def make_heartbeat(self, jid: JID) -> Message: """ @@ -682,8 +697,9 @@ class XEP_0384(BasePlugin): finally: asyncio.ensure_future(self._publish_bundle()) - if self.auto_heartbeat and self.should_heartbeat(sender, encrypted): + if self.auto_heartbeat and self._should_heartbeat(sender, isPrekeyMessage): async def send_heartbeat(): + log.debug('Sending a heartbeat message') msg = await self.make_heartbeat(JID(jid)) msg.send() asyncio.ensure_future(send_heartbeat())