diff --git a/plugin.py b/plugin.py index c76718b..f793717 100644 --- a/plugin.py +++ b/plugin.py @@ -16,7 +16,7 @@ import base64 import asyncio from slixmpp.plugins.xep_0384.stanza import OMEMO_BASE_NS from slixmpp.plugins.xep_0384.stanza import OMEMO_DEVICES_NS, OMEMO_BUNDLES_NS -from slixmpp.plugins.xep_0384.stanza import Devices, Device, Encrypted, Key, PreKeyPublic +from slixmpp.plugins.xep_0384.stanza import Bundle, Devices, Device, Encrypted, Key, PreKeyPublic from slixmpp.plugins.base import BasePlugin, register_plugin from slixmpp.exceptions import IqError from slixmpp.stanza import Message, Iq @@ -26,8 +26,10 @@ log = logging.getLogger(__name__) HAS_OMEMO = True try: - from omemo import SessionManager + from omemo.exceptions import MissingBundleException + from omemo import SessionManager, ExtendedPublicBundle from omemo.util import generateDeviceID + from omemo.backends import Backend from omemo_backend_signal import BACKEND as SignalBackend from slixmpp.plugins.xep_0384.storage import SyncFileStorage from slixmpp.plugins.xep_0384.otpkpolicy import KeepingOTPKPolicy @@ -158,6 +160,30 @@ class XEP_0384(BasePlugin): iq = self._generate_bundle_iq() await iq.send() + async def _fetch_bundle(self, jid: str, device_id: int) -> Union[None, ExtendedPublicBundle]: + node = '%s:%d' % (OMEMO_BUNDLES_NS, device_id) + iq = await self.xmpp['xep_0060'].get_items(jid, node) + bundle = iq['pubsub']['items']['item']['bundle'] + + return self._parse_bundle(self._omemo_backend, bundle) + + def _parse_bundle(self, backend: Backend, bundle: Bundle) -> ExtendedPublicBundle: + ik = b64dec(bundle['identityKey']['value'].strip()) + spk = { + 'id': int(bundle['signedPreKeyPublic']['signedPreKeyId']), + 'key': b64dec(bundle['signedPreKeyPublic']['value'].strip()), + } + spk_signature = b64dec(bundle['signedPreKeySignature']['value'].strip()) + + otpks = [] + for prekey in bundle['prekeys']: + otpks.append({ + 'id': int(prekey['preKeyId']), + 'key': b64dec(prekey['value'].strip()), + }) + + return ExtendedPublicBundle.parse(backend, ik, spk, spk_signature, otpks) + def _store_device_ids(self, jid: str, items) -> None: device_ids = [] # type: List[int] for item in items: @@ -282,7 +308,11 @@ class XEP_0384(BasePlugin): break for (exn, key, val) in errors: - pass + if isinstance(exn, MissingBundleException): + bundle = await self._fetch_bundle(key, val) + if bundle is not None: + devices = bundles.setdefault(key, {}) + devices[val] = bundle break