should_heartbeat: also return True on new sessions

And the docstring now reflects the reality again!

We're parsing the Encrypted dict again, when we just did it in
decrypt_message above, but this function is also part of the API and
doing that for them is the least we can do.

Maybe there should be an internal function that we can call from
decrypt_message, that also gets called by should_heartbeat.

Signed-off-by: Maxime “pep” Buquet <pep@bouah.net>
This commit is contained in:
Maxime “pep” Buquet 2021-07-17 18:58:31 +02:00
parent 91a04000d7
commit bb52d93241
Signed by: pep
GPG key ID: DEDA74AEECA9D0F2

View file

@ -543,7 +543,7 @@ class XEP_0384(BasePlugin):
return lengths
def should_heartbeat(self, jid: JID) -> bool:
def should_heartbeat(self, jid: JID, msg: Union[Message, Encrypted]) -> bool:
"""
Returns whether we should send a heartbeat message for JID.
See notes about heartbeat in
@ -554,9 +554,28 @@ class XEP_0384(BasePlugin):
sessions hasn't been answered in a while.
"""
new_session: bool = False
inactive_session: bool = False
# Is the message is a prekey message. If so assume it's a new session
encrypted = msg
if isinstance(msg, Message):
encrypted = msg['omemo_encrypted']
header = encrypted['header']
key = header.xml.find("{%s}key[@rid='%s']" % (
OMEMO_BASE_NS, self._device_id))
if key is not None:
key = Key(key)
new_session = key['prekey'] in TRUE_VALUES
# Otherwise find how many message we haven't replied to in all of the
# sessions associated to this JID.
receiving_chain_lengths = self._chain_lengths(jid).get('receiving', [])
lengths = map(lambda d_l: d_l[1], receiving_chain_lengths)
return max(lengths, default=0) > self.heartbeat_after
inactive_session = max(lengths, default=0) > self.heartbeat_after
return new_session or inactive_session
async def make_heartbeat(self, jid: JID) -> Message:
"""
@ -663,7 +682,7 @@ class XEP_0384(BasePlugin):
finally:
asyncio.ensure_future(self._publish_bundle())
if self.auto_heartbeat and self.should_heartbeat(sender):
if self.auto_heartbeat and self.should_heartbeat(sender, encrypted):
async def send_heartbeat():
msg = await self.make_heartbeat(JID(jid))
msg.send()