From c7a0a092d44d77dea531cb310a2fd925c31e42d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maxime=20=E2=80=9Cpep=E2=80=9D=20Buquet?= Date: Sat, 17 Jul 2021 02:54:57 +0200 Subject: [PATCH] should_heartbeat: also take into account unacked sessions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit receiving_chain_length would sometimes be None (thanks python strict typing) causing the thing to fail. When this is the case, I assume this means the session hasn't been confirmed from our side yet and it would be good to ACK it. (To be confirmed with people who know, in progress). Signed-off-by: Maxime “pep” Buquet --- slixmpp_omemo/__init__.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/slixmpp_omemo/__init__.py b/slixmpp_omemo/__init__.py index 841a46a..e736017 100644 --- a/slixmpp_omemo/__init__.py +++ b/slixmpp_omemo/__init__.py @@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union # Not available in Python 3.7, and slixmpp already imports the right things # for me from slixmpp.types import TypedDict +from functools import reduce import os import json @@ -533,10 +534,12 @@ class XEP_0384(BasePlugin): for did in devices: session = self._omemo._SessionManager__loadSession(bare, did) if session is None: - break + continue skr = session._DoubleRatchet__skr - lengths['sending'].append((did, skr.sending_chain_length)) - lengths['receiving'].append((did, skr.receiving_chain_length)) + sending = skr.sending_chain_length or -1 + receiving = skr.receiving_chain_length or -1 + lengths['sending'].append((did, sending)) + lengths['receiving'].append((did, receiving)) return lengths @@ -545,11 +548,16 @@ class XEP_0384(BasePlugin): Returns whether we should send a heartbeat message for JID. See notes about heartbeat in https://xmpp.org/extensions/xep-0384.html#rules. + + This method will return True when a session among all of the + sessions for this JID is not yet confirmed, or if one of the + sessions hasn't been answered in a while. """ receiving_chain_lengths = self._chain_lengths(jid).get('receiving', []) lengths = map(lambda d_l: d_l[1], receiving_chain_lengths) - return max(lengths, default=0) > self.heartbeat_after + min_length = reduce(lambda x, d_l: min(x, d_l[1]), receiving_chain_lengths, 0) == -1 + return min_length or max(lengths, default=0) > self.heartbeat_after async def make_heartbeat(self, jid: JID) -> Message: """