api: fix typing

This commit is contained in:
mathieui 2021-07-05 20:09:59 +02:00
parent ea7f7d8119
commit fe1a325aa7

View file

@ -21,7 +21,7 @@ class APIWrapper(object):
if name not in self.api.settings: if name not in self.api.settings:
self.api.settings[name] = {} self.api.settings[name] = {}
def __getattr__(self, attr): def __getattr__(self, attr: str):
"""Curry API management commands with the API name.""" """Curry API management commands with the API name."""
if attr == 'name': if attr == 'name':
return self.name return self.name
@ -33,13 +33,13 @@ class APIWrapper(object):
return register(handler, self.name, op, jid, node, default) return register(handler, self.name, op, jid, node, default)
return partial return partial
elif attr == 'register_default': elif attr == 'register_default':
def partial(handler, op, jid=None, node=None): def partial1(handler, op, jid=None, node=None):
return getattr(self.api, attr)(handler, self.name, op) return getattr(self.api, attr)(handler, self.name, op)
return partial return partial1
elif attr in ('run', 'restore_default', 'unregister'): elif attr in ('run', 'restore_default', 'unregister'):
def partial(*args, **kwargs): def partial2(*args, **kwargs):
return getattr(self.api, attr)(self.name, *args, **kwargs) return getattr(self.api, attr)(self.name, *args, **kwargs)
return partial return partial2
return None return None
def __getitem__(self, attr): def __getitem__(self, attr):
@ -82,7 +82,7 @@ class APIRegistry(object):
"""Return a wrapper object that targets a specific API.""" """Return a wrapper object that targets a specific API."""
return APIWrapper(self, ctype) return APIWrapper(self, ctype)
def purge(self, ctype: str): def purge(self, ctype: str) -> None:
"""Remove all information for a given API.""" """Remove all information for a given API."""
del self.settings[ctype] del self.settings[ctype]
del self._handler_defaults[ctype] del self._handler_defaults[ctype]
@ -131,22 +131,23 @@ class APIRegistry(object):
jid = JID(jid) jid = JID(jid)
elif jid == JID(''): elif jid == JID(''):
jid = self.xmpp.boundjid jid = self.xmpp.boundjid
assert jid is not None
if node is None: if node is None:
node = '' node = ''
if self.xmpp.is_component: if self.xmpp.is_component:
if self.settings[ctype].get('component_bare', False): if self.settings[ctype].get('component_bare', False):
jid = jid.bare jid_str = jid.bare
else: else:
jid = jid.full jid_str = jid.full
else: else:
if self.settings[ctype].get('client_bare', False): if self.settings[ctype].get('client_bare', False):
jid = jid.bare jid_str = jid.bare
else: else:
jid = jid.full jid_str = jid.full
jid = JID(jid) jid = JID(jid_str)
handler = self._handlers[ctype][op]['node'].get((jid, node), None) handler = self._handlers[ctype][op]['node'].get((jid, node), None)
if handler is None: if handler is None:
@ -167,8 +168,11 @@ class APIRegistry(object):
# To preserve backward compatibility, drop the ifrom # To preserve backward compatibility, drop the ifrom
# parameter for existing handlers that don't understand it. # parameter for existing handlers that don't understand it.
return handler(jid, node, args) return handler(jid, node, args)
future = Future()
future.set_result(None)
return future
def register(self, handler: APIHandler, ctype: str, op: str, def register(self, handler: Optional[APIHandler], ctype: str, op: str,
jid: Optional[JID] = None, node: Optional[str] = None, jid: Optional[JID] = None, node: Optional[str] = None,
default: bool = False): default: bool = False):
"""Register an API callback, with JID+node specificity. """Register an API callback, with JID+node specificity.