should_heartbeat: also return True on new sessions
And the docstring now reflects the reality again! We're parsing the Encrypted dict again, when we just did it in decrypt_message above, but this function is also part of the API and doing that for them is the least we can do. Maybe there should be an internal function that we can call from decrypt_message, that also gets called by should_heartbeat. Signed-off-by: Maxime “pep” Buquet <pep@bouah.net>
This commit is contained in:
parent
91a04000d7
commit
bb52d93241
1 changed files with 22 additions and 3 deletions
|
@ -543,7 +543,7 @@ class XEP_0384(BasePlugin):
|
||||||
|
|
||||||
return lengths
|
return lengths
|
||||||
|
|
||||||
def should_heartbeat(self, jid: JID) -> bool:
|
def should_heartbeat(self, jid: JID, msg: Union[Message, Encrypted]) -> bool:
|
||||||
"""
|
"""
|
||||||
Returns whether we should send a heartbeat message for JID.
|
Returns whether we should send a heartbeat message for JID.
|
||||||
See notes about heartbeat in
|
See notes about heartbeat in
|
||||||
|
@ -554,9 +554,28 @@ class XEP_0384(BasePlugin):
|
||||||
sessions hasn't been answered in a while.
|
sessions hasn't been answered in a while.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
new_session: bool = False
|
||||||
|
inactive_session: bool = False
|
||||||
|
|
||||||
|
# Is the message is a prekey message. If so assume it's a new session
|
||||||
|
encrypted = msg
|
||||||
|
if isinstance(msg, Message):
|
||||||
|
encrypted = msg['omemo_encrypted']
|
||||||
|
|
||||||
|
header = encrypted['header']
|
||||||
|
key = header.xml.find("{%s}key[@rid='%s']" % (
|
||||||
|
OMEMO_BASE_NS, self._device_id))
|
||||||
|
if key is not None:
|
||||||
|
key = Key(key)
|
||||||
|
new_session = key['prekey'] in TRUE_VALUES
|
||||||
|
|
||||||
|
# Otherwise find how many message we haven't replied to in all of the
|
||||||
|
# sessions associated to this JID.
|
||||||
receiving_chain_lengths = self._chain_lengths(jid).get('receiving', [])
|
receiving_chain_lengths = self._chain_lengths(jid).get('receiving', [])
|
||||||
lengths = map(lambda d_l: d_l[1], receiving_chain_lengths)
|
lengths = map(lambda d_l: d_l[1], receiving_chain_lengths)
|
||||||
return max(lengths, default=0) > self.heartbeat_after
|
inactive_session = max(lengths, default=0) > self.heartbeat_after
|
||||||
|
|
||||||
|
return new_session or inactive_session
|
||||||
|
|
||||||
async def make_heartbeat(self, jid: JID) -> Message:
|
async def make_heartbeat(self, jid: JID) -> Message:
|
||||||
"""
|
"""
|
||||||
|
@ -663,7 +682,7 @@ class XEP_0384(BasePlugin):
|
||||||
finally:
|
finally:
|
||||||
asyncio.ensure_future(self._publish_bundle())
|
asyncio.ensure_future(self._publish_bundle())
|
||||||
|
|
||||||
if self.auto_heartbeat and self.should_heartbeat(sender):
|
if self.auto_heartbeat and self.should_heartbeat(sender, encrypted):
|
||||||
async def send_heartbeat():
|
async def send_heartbeat():
|
||||||
msg = await self.make_heartbeat(JID(jid))
|
msg = await self.make_heartbeat(JID(jid))
|
||||||
msg.send()
|
msg.send()
|
||||||
|
|
Loading…
Reference in a new issue