should_heartbeat: also take into account unacked sessions

receiving_chain_length would sometimes be None (thanks python strict
typing) causing the thing to fail.

When this is the case, I assume this means the session hasn't been
confirmed from our side yet and it would be good to ACK it. (To be
confirmed with people who know, in progress).

Signed-off-by: Maxime “pep” Buquet <pep@bouah.net>
This commit is contained in:
Maxime “pep” Buquet 2021-07-17 02:54:57 +02:00
parent a2a287ee5d
commit c7a0a092d4

View file

@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
# Not available in Python 3.7, and slixmpp already imports the right things # Not available in Python 3.7, and slixmpp already imports the right things
# for me # for me
from slixmpp.types import TypedDict from slixmpp.types import TypedDict
from functools import reduce
import os import os
import json import json
@ -533,10 +534,12 @@ class XEP_0384(BasePlugin):
for did in devices: for did in devices:
session = self._omemo._SessionManager__loadSession(bare, did) session = self._omemo._SessionManager__loadSession(bare, did)
if session is None: if session is None:
break continue
skr = session._DoubleRatchet__skr skr = session._DoubleRatchet__skr
lengths['sending'].append((did, skr.sending_chain_length)) sending = skr.sending_chain_length or -1
lengths['receiving'].append((did, skr.receiving_chain_length)) receiving = skr.receiving_chain_length or -1
lengths['sending'].append((did, sending))
lengths['receiving'].append((did, receiving))
return lengths return lengths
@ -545,11 +548,16 @@ class XEP_0384(BasePlugin):
Returns whether we should send a heartbeat message for JID. Returns whether we should send a heartbeat message for JID.
See notes about heartbeat in See notes about heartbeat in
https://xmpp.org/extensions/xep-0384.html#rules. https://xmpp.org/extensions/xep-0384.html#rules.
This method will return True when a session among all of the
sessions for this JID is not yet confirmed, or if one of the
sessions hasn't been answered in a while.
""" """
receiving_chain_lengths = self._chain_lengths(jid).get('receiving', []) receiving_chain_lengths = self._chain_lengths(jid).get('receiving', [])
lengths = map(lambda d_l: d_l[1], receiving_chain_lengths) lengths = map(lambda d_l: d_l[1], receiving_chain_lengths)
return max(lengths, default=0) > self.heartbeat_after min_length = reduce(lambda x, d_l: min(x, d_l[1]), receiving_chain_lengths, 0) == -1
return min_length or max(lengths, default=0) > self.heartbeat_after
async def make_heartbeat(self, jid: JID) -> Message: async def make_heartbeat(self, jid: JID) -> Message:
""" """