Update omemo lib to 0.13 and asyncio changes

Signed-off-by: Maxime “pep” Buquet <pep@bouah.net>
This commit is contained in:
Maxime “pep” Buquet 2021-12-15 22:52:57 +01:00
parent 080a27e7d8
commit 29bf6e8650
2 changed files with 78 additions and 56 deletions

View file

@ -115,9 +115,11 @@ class EchoBot(ClientXMPP):
return await self.encrypted_reply(mto, mtype, body) return await self.encrypted_reply(mto, mtype, body)
async def cmd_chain_length(self, mto: JID, mtype: str) -> None: async def cmd_chain_length(self, mto: JID, mtype: str) -> None:
chain_length = await self['xep_0384']._chain_lengths(mto)
should_heartbeat = await self['xep_0384'].should_heartbeat(mto)
body = ( body = (
'lengths: %r\n' % self['xep_0384']._chain_lengths(mto) + 'lengths: %r\n' % chain_length +
'should heartbeat: %r' % self['xep_0384'].should_heartbeat(mto) 'should heartbeat: %r' % should_heartbeat
) )
return await self.encrypted_reply(mto, mtype, body) return await self.encrypted_reply(mto, mtype, body)
@ -149,7 +151,7 @@ class EchoBot(ClientXMPP):
try: try:
encrypted = msg['omemo_encrypted'] encrypted = msg['omemo_encrypted']
body = self['xep_0384'].decrypt_message(encrypted, mfrom, allow_untrusted) body = await self['xep_0384'].decrypt_message(encrypted, mfrom, allow_untrusted)
# decrypt_message returns Optional[str]. It is possible to get # decrypt_message returns Optional[str]. It is possible to get
# body-less OMEMO message (see KeyTransportMessages), currently # body-less OMEMO message (see KeyTransportMessages), currently
# used for example to send heartbeats to other devices. # used for example to send heartbeats to other devices.
@ -249,7 +251,7 @@ class EchoBot(ClientXMPP):
# untrusted/undecided barejid, so we need to make a decision here. # untrusted/undecided barejid, so we need to make a decision here.
# This is where you prompt your user to ask what to do. In # This is where you prompt your user to ask what to do. In
# this bot we will automatically trust undecided recipients. # this bot we will automatically trust undecided recipients.
self['xep_0384'].trust(exn.bare_jid, exn.device, exn.ik) await self['xep_0384'].trust(exn.bare_jid, exn.device, exn.ik)
# TODO: catch NoEligibleDevicesException # TODO: catch NoEligibleDevicesException
except EncryptionPrepareException as exn: except EncryptionPrepareException as exn:
# This exception is being raised when the library has tried # This exception is being raised when the library has tried

View file

@ -162,6 +162,9 @@ class MissingOwnKey(XEP0384): pass
class NoAvailableSession(XEP0384): pass class NoAvailableSession(XEP0384): pass
class UninitializedOMEMOSession(XEP0384): pass
class EncryptionPrepareException(XEP0384): class EncryptionPrepareException(XEP0384):
def __init__(self, errors): def __init__(self, errors):
self.errors = errors self.errors = errors
@ -212,6 +215,9 @@ class XEP_0384(BasePlugin):
# Used at startup to prevent publishing device list and bundles multiple times # Used at startup to prevent publishing device list and bundles multiple times
_initial_publish_done = False _initial_publish_done = False
# Initiated once the OMEMO session is created.
__omemo_session: Optional[SessionManager] = None
def plugin_init(self) -> None: def plugin_init(self) -> None:
if not self.backend_loaded: if not self.backend_loaded:
log_str = ("xep_0384 cannot be loaded as the backend omemo library " log_str = ("xep_0384 cannot be loaded as the backend omemo library "
@ -228,35 +234,18 @@ class XEP_0384(BasePlugin):
raise PluginCouldNotLoad("xep_0384 cannot be loaded as there is " raise PluginCouldNotLoad("xep_0384 cannot be loaded as there is "
"no data directory specified.") "no data directory specified.")
storage = self.storage_backend
if self.storage_backend is None:
storage = JSONFileStorage(self.data_dir)
otpkpolicy = self.otpk_policy
bare_jid = self.xmpp.boundjid.bare
self._device_id = _load_device_id(self.data_dir) self._device_id = _load_device_id(self.data_dir)
asyncio.ensure_future(self.session_start_omemo())
try:
self._omemo = SessionManager.create(
storage,
otpkpolicy,
self.omemo_backend,
bare_jid,
self._device_id,
)
except:
log.error("Couldn't load the OMEMO object; ¯\\_(ツ)_/¯")
raise PluginCouldNotLoad
self.xmpp.add_event_handler('session_start', self.session_start) self.xmpp.add_event_handler('session_start', self.session_start)
self.xmpp['xep_0060'].map_node_event(OMEMO_DEVICES_NS, 'omemo_device_list') self.xmpp['xep_0060'].map_node_event(OMEMO_DEVICES_NS, 'omemo_device_list')
self.xmpp.add_event_handler('omemo_device_list_publish', self._receive_device_list) self.xmpp.add_event_handler('omemo_device_list_publish', self._receive_device_list)
# If this plugin is loaded after 'session_start' has fired, we still
# need to publish bundles
if self.xmpp.is_connected and not self._initial_publish_done: if self.xmpp.is_connected and not self._initial_publish_done:
asyncio.ensure_future(self._initial_publish()) asyncio.ensure_future(self._initial_publish())
return None
def plugin_end(self): def plugin_end(self):
if not self.backend_loaded: if not self.backend_loaded:
return return
@ -265,6 +254,34 @@ class XEP_0384(BasePlugin):
self.xmpp.remove_event_handler('omemo_device_list_publish', self._receive_device_list) self.xmpp.remove_event_handler('omemo_device_list_publish', self._receive_device_list)
self.xmpp['xep_0163'].remove_interest(OMEMO_DEVICES_NS) self.xmpp['xep_0163'].remove_interest(OMEMO_DEVICES_NS)
async def session_start_omemo(self):
"""Creates the OMEMO session object"""
storage = self.storage_backend
if self.storage_backend is None:
storage = JSONFileStorage(self.data_dir)
otpkpolicy = self.otpk_policy
bare_jid = self.xmpp.boundjid.bare
try:
self.__omemo_session = await SessionManager.create(
storage,
otpkpolicy,
self.omemo_backend,
bare_jid,
self._device_id,
)
except Exception as exn:
log.error("Couldn't load the OMEMO object; ¯\\_(ツ)_/¯")
raise PluginCouldNotLoad from exn
def _omemo(self) -> SessionManager:
"""Helper method to unguard potentially uninitialized SessionManager"""
if self.__omemo_session is None:
raise UninitializedOMEMOSession
return self.__omemo_session
async def session_start(self, _jid): async def session_start(self, _jid):
await self._initial_publish() await self._initial_publish()
@ -280,8 +297,8 @@ class XEP_0384(BasePlugin):
def my_device_id(self) -> int: def my_device_id(self) -> int:
return self._device_id return self._device_id
def my_fingerprint(self) -> str: async def my_fingerprint(self) -> str:
bundle = self._omemo.public_bundle.serialize(self.omemo_backend) bundle = await self._omemo().public_bundle.serialize(self.omemo_backend)
return fp_from_ik(bundle['ik']) return fp_from_ik(bundle['ik'])
def _set_node_config( def _set_node_config(
@ -331,7 +348,7 @@ class XEP_0384(BasePlugin):
) )
async def _generate_bundle_iq(self, publish_options: bool = True) -> Iq: async def _generate_bundle_iq(self, publish_options: bool = True) -> Iq:
bundle = self._omemo.public_bundle.serialize(self.omemo_backend) bundle = self._omemo().public_bundle.serialize(self.omemo_backend)
jid = self.xmpp.boundjid jid = self.xmpp.boundjid
disco = await self.xmpp['xep_0030'].get_info(jid.bare) disco = await self.xmpp['xep_0030'].get_info(jid.bare)
@ -369,7 +386,7 @@ class XEP_0384(BasePlugin):
return iq return iq
async def _publish_bundle(self) -> None: async def _publish_bundle(self) -> None:
if self._omemo.republish_bundle: if self._omemo().republish_bundle:
iq = await self._generate_bundle_iq() iq = await self._generate_bundle_iq()
try: try:
await iq.send() await iq.send()
@ -408,13 +425,13 @@ class XEP_0384(BasePlugin):
iq = await self.xmpp['xep_0060'].get_items(jid.full, OMEMO_DEVICES_NS) iq = await self.xmpp['xep_0060'].get_items(jid.full, OMEMO_DEVICES_NS)
return await self._read_device_list(jid, iq['pubsub']['items']) return await self._read_device_list(jid, iq['pubsub']['items'])
def _store_device_ids(self, jid: str, items: Union[Items, EventItems]) -> None: async def _store_device_ids(self, jid: str, items: Union[Items, EventItems]) -> None:
"""Store Device list""" """Store Device list"""
device_ids = [] # type: List[int] device_ids = [] # type: List[int]
items = list(items) items = list(items)
if items: if items:
device_ids = [int(d['id']) for d in items[0]['devices']] device_ids = [int(d['id']) for d in items[0]['devices']]
return self._omemo.newDeviceList(str(jid), device_ids) return await self._omemo().newDeviceList(str(jid), device_ids)
def _receive_device_list(self, msg: Message) -> None: def _receive_device_list(self, msg: Message) -> None:
"""Handler for received PEP OMEMO_DEVICES_NS payloads""" """Handler for received PEP OMEMO_DEVICES_NS payloads"""
@ -425,7 +442,7 @@ class XEP_0384(BasePlugin):
async def _read_device_list(self, jid: JID, items: Union[Items, EventItems]) -> None: async def _read_device_list(self, jid: JID, items: Union[Items, EventItems]) -> None:
"""Read items and devices if we need to set the device list again or not""" """Read items and devices if we need to set the device list again or not"""
bare_jid = jid.bare bare_jid = jid.bare
self._store_device_ids(bare_jid, items) await self._store_device_ids(bare_jid, items)
items = list(items) items = list(items)
device_ids = [] device_ids = []
@ -446,15 +463,15 @@ class XEP_0384(BasePlugin):
own_jid.bare, OMEMO_DEVICES_NS, own_jid.bare, OMEMO_DEVICES_NS,
) )
items = iq['pubsub']['items'] items = iq['pubsub']['items']
self._store_device_ids(own_jid.bare, items) await self._store_device_ids(own_jid.bare, items)
except IqError as iq_err: except IqError as iq_err:
if iq_err.condition == "item-not-found": if iq_err.condition == "item-not-found":
self._store_device_ids(own_jid.bare, []) await self._store_device_ids(own_jid.bare, [])
else: else:
return # XXX: Handle this! return # XXX: Handle this!
if device_ids is None: if device_ids is None:
device_ids = self.get_device_list(own_jid) device_ids = await self.get_device_list(own_jid)
devices = [] devices = []
for i in device_ids: for i in device_ids:
@ -502,11 +519,12 @@ class XEP_0384(BasePlugin):
own_jid.bare, OMEMO_DEVICES_NS, payload=payload, own_jid.bare, OMEMO_DEVICES_NS, payload=payload,
) )
def get_device_list(self, jid: JID) -> List[str]: async def get_device_list(self, jid: JID) -> List[str]:
"""Return active device ids. Always contains our own device id.""" """Return active device ids. Always contains our own device id."""
return self._omemo.getDevices(jid.bare).get('active', []) devices = await self._omemo().getDevices(jid.bare)
return devices.get('active', [])
def _chain_lengths(self, jid: JID) -> ChainLengths: async def _chain_lengths(self, jid: JID) -> ChainLengths:
""" """
Gather receiving and sending chain lengths for all devices (active Gather receiving and sending chain lengths for all devices (active
/ inactive) of a JID. / inactive) of a JID.
@ -525,14 +543,14 @@ class XEP_0384(BasePlugin):
# OMEMO library as of 0.12 (9fd7123). # OMEMO library as of 0.12 (9fd7123).
bare = jid.bare bare = jid.bare
devices = self._omemo.getDevices(bare) devices = await self._omemo().getDevices(bare)
active = devices.get('active', set()) active = devices.get('active', set())
inactive = devices.get('inactive', set()) inactive = devices.get('inactive', set())
devices = active.union(inactive) devices = active.union(inactive)
lengths: ChainLengths = {'sending': [], 'receiving': []} lengths: ChainLengths = {'sending': [], 'receiving': []}
for did in devices: for did in devices:
session = self._omemo._SessionManager__loadSession(bare, did) session = self._omemo()._SessionManager__loadSession(bare, did)
if session is None: if session is None:
continue continue
skr = session._DoubleRatchet__skr skr = session._DoubleRatchet__skr
@ -543,7 +561,7 @@ class XEP_0384(BasePlugin):
return lengths return lengths
def _should_heartbeat(self, jid: JID, prekey: bool) -> bool: async def _should_heartbeat(self, jid: JID, prekey: bool) -> bool:
""" """
Internal helper for :py:func:`XEP_0384.should_heartbeat`. Internal helper for :py:func:`XEP_0384.should_heartbeat`.
@ -557,13 +575,14 @@ class XEP_0384(BasePlugin):
still active. still active.
""" """
receiving_chain_lengths = self._chain_lengths(jid).get('receiving', []) chain_lengths = await self._chain_lengths(jid)
receiving_chain_lengths = chain_lengths.get('receiving', [])
lengths = map(lambda d_l: d_l[1], receiving_chain_lengths) lengths = map(lambda d_l: d_l[1], receiving_chain_lengths)
inactive_session = max(lengths, default=0) > self.heartbeat_after inactive_session = max(lengths, default=0) > self.heartbeat_after
return prekey or inactive_session return prekey or inactive_session
def should_heartbeat(self, jid: JID, msg: Union[Message, Encrypted]) -> bool: async def should_heartbeat(self, jid: JID, msg: Union[Message, Encrypted]) -> bool:
""" """
Returns whether we should send a heartbeat message for JID. Returns whether we should send a heartbeat message for JID.
See notes about heartbeat in See notes about heartbeat in
@ -590,7 +609,7 @@ class XEP_0384(BasePlugin):
key = Key(key) key = Key(key)
prekey = key['prekey'] in TRUE_VALUES prekey = key['prekey'] in TRUE_VALUES
return self._should_heartbeat(jid, prekey) return await self._should_heartbeat(jid, prekey)
async def make_heartbeat(self, jid: JID) -> Message: async def make_heartbeat(self, jid: JID) -> Message:
""" """
@ -611,13 +630,13 @@ class XEP_0384(BasePlugin):
msg.append(encrypted) msg.append(encrypted)
return msg return msg
def trust(self, jid: JID, device_id: int, ik: bytes) -> None: async def trust(self, jid: JID, device_id: int, ik: bytes) -> None:
self._omemo.setTrust(jid.bare, device_id, ik, True) await self._omemo().setTrust(jid.bare, device_id, ik, True)
def distrust(self, jid: JID, device_id: int, ik: bytes) -> None: async def distrust(self, jid: JID, device_id: int, ik: bytes) -> None:
self._omemo.setTrust(jid.bare, device_id, ik, False) await self._omemo().setTrust(jid.bare, device_id, ik, False)
def get_trust_for_jid(self, jid: JID) -> Dict[str, List[Optional[Dict[str, Any]]]]: async def get_trust_for_jid(self, jid: JID) -> Dict[str, List[Optional[Dict[str, Any]]]]:
""" """
Fetches trust for JID. The returned dictionary will contain active Fetches trust for JID. The returned dictionary will contain active
and inactive devices. Each of these dict will contain device ids and inactive devices. Each of these dict will contain device ids
@ -638,12 +657,12 @@ class XEP_0384(BasePlugin):
} }
""" """
return self._omemo.getTrustForJID(jid.bare) return await self._omemo().getTrustForJID(jid.bare)
def is_encrypted(self, msg: Message) -> bool: def is_encrypted(self, msg: Message) -> bool:
return msg.xml.find('{%s}encrypted' % OMEMO_BASE_NS) is not None return msg.xml.find('{%s}encrypted' % OMEMO_BASE_NS) is not None
def decrypt_message( async def decrypt_message(
self, self,
encrypted: Encrypted, encrypted: Encrypted,
sender: JID, sender: JID,
@ -676,7 +695,7 @@ class XEP_0384(BasePlugin):
# is passed. We do not implement this yet. # is passed. We do not implement this yet.
try: try:
if payload is None: if payload is None:
self._omemo.decryptRatchetFowardingMessage( await self._omemo().decryptRatchetFowardingMessage(
jid, jid,
sid, sid,
iv, iv,
@ -686,7 +705,7 @@ class XEP_0384(BasePlugin):
) )
body = None body = None
else: else:
body = self._omemo.decryptMessage( body = await self._omemo().decryptMessage(
jid, jid,
sid, sid,
iv, iv,
@ -710,7 +729,8 @@ class XEP_0384(BasePlugin):
finally: finally:
asyncio.ensure_future(self._publish_bundle()) asyncio.ensure_future(self._publish_bundle())
if self.auto_heartbeat and self._should_heartbeat(sender, isPrekeyMessage): should_heartbeat = await self._should_heartbeat(sender, isPrekeyMessage)
if self.auto_heartbeat and should_heartbeat:
async def send_heartbeat(): async def send_heartbeat():
log.debug('Sending a heartbeat message') log.debug('Sending a heartbeat message')
msg = await self.make_heartbeat(JID(jid)) msg = await self.make_heartbeat(JID(jid))
@ -757,20 +777,20 @@ class XEP_0384(BasePlugin):
try: try:
if plaintext is not None: if plaintext is not None:
encrypted = self._omemo.encryptMessage( encrypted = await self._omemo().encryptMessage(
recipients, recipients,
plaintext.encode('utf-8'), plaintext.encode('utf-8'),
self.bundles, self.bundles,
expect_problems=expect_problems, expect_problems=expect_problems,
) )
elif _ignore_trust: elif _ignore_trust:
encrypted = self._omemo.encryptRatchetForwardingMessage( encrypted = await self._omemo().encryptRatchetForwardingMessage(
recipients, recipients,
self.bundles, self.bundles,
expect_problems=expect_problems, expect_problems=expect_problems,
) )
else: else:
encrypted = self._omemo.encryptKeyTransportMessage( encrypted = await self._omemo().encryptKeyTransportMessage(
recipients, recipients,
self.bundles, self.bundles,
expect_problems=expect_problems, expect_problems=expect_problems,