decorators: make decorators work with coroutines
Tried the least ugly solution I could thing of.
This commit is contained in:
parent
f5ad5199ae
commit
695b2ee09a
1 changed files with 67 additions and 34 deletions
|
@ -3,11 +3,13 @@ Module containing various decorators
|
|||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
from typing import (
|
||||
cast,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
TypeVar,
|
||||
|
@ -21,6 +23,37 @@ if TYPE_CHECKING:
|
|||
|
||||
T = TypeVar('T', bound=Callable[..., Any])
|
||||
|
||||
BeforeFunc = Callable[[List[Any], Dict[str, Any]], Any]
|
||||
AfterFunc = Callable[[List[Any], Dict[str, Any]], Any]
|
||||
|
||||
def wrap_generic(func: Callable, before: BeforeFunc=None, after: AfterFunc=None):
|
||||
"""
|
||||
Generic wrapper which can both wrap coroutines and normal functions.
|
||||
"""
|
||||
def wrap(*args, **kwargs):
|
||||
args = list(args)
|
||||
if before is not None:
|
||||
result = before(args, kwargs)
|
||||
if result is not None:
|
||||
return result
|
||||
result = func(*args, **kwargs)
|
||||
if after is not None:
|
||||
result = after(result, args, kwargs)
|
||||
return result
|
||||
|
||||
async def awrap(*args, **kwargs):
|
||||
args = list(args)
|
||||
if before is not None:
|
||||
result = before(args, kwargs)
|
||||
if result is not None:
|
||||
return result
|
||||
result = await func(*args, **kwargs)
|
||||
if after is not None:
|
||||
result = after(result, args, kwargs)
|
||||
return result
|
||||
if iscoroutinefunction(func):
|
||||
return awrap
|
||||
return wrap
|
||||
|
||||
|
||||
class RefreshWrapper:
|
||||
|
@ -32,12 +65,12 @@ class RefreshWrapper:
|
|||
Decorator to refresh the UI if the wrapped function
|
||||
returns True
|
||||
"""
|
||||
|
||||
def wrap(*args: Any, **kwargs: Any) -> Any:
|
||||
ret = func(*args, **kwargs)
|
||||
if self.core and ret:
|
||||
def after(result: Any, args, kwargs) -> Any:
|
||||
if self.core and result:
|
||||
self.core.refresh_window()
|
||||
return ret
|
||||
return result
|
||||
|
||||
wrap = wrap_generic(func, after=after)
|
||||
|
||||
return cast(T, wrap)
|
||||
|
||||
|
@ -45,13 +78,12 @@ class RefreshWrapper:
|
|||
"""
|
||||
Decorator that refreshs the UI no matter what after the function
|
||||
"""
|
||||
|
||||
def wrap(*args: Any, **kwargs: Any) -> Any:
|
||||
ret = func(*args, **kwargs)
|
||||
def after(result: Any, args, kwargs) -> Any:
|
||||
if self.core:
|
||||
self.core.refresh_window()
|
||||
return ret
|
||||
return result
|
||||
|
||||
wrap = wrap_generic(func, after=after)
|
||||
return cast(T, wrap)
|
||||
|
||||
def update(self, func: T) -> T:
|
||||
|
@ -59,12 +91,11 @@ class RefreshWrapper:
|
|||
Decorator that only updates the screen
|
||||
"""
|
||||
|
||||
def wrap(*args: Any, **kwargs: Any) -> Any:
|
||||
ret = func(*args, **kwargs)
|
||||
def after(result: Any, args, kwargs) -> Any:
|
||||
if self.core:
|
||||
self.core.doupdate()
|
||||
return ret
|
||||
|
||||
return result
|
||||
wrap = wrap_generic(func, after=after)
|
||||
return cast(T, wrap)
|
||||
|
||||
|
||||
|
@ -82,21 +113,18 @@ class CommandArgParser:
|
|||
"""Just call the function with a single string, which is the original string
|
||||
untouched
|
||||
"""
|
||||
|
||||
def wrap(self: Any, args: Any, *a: Any, **kw: Any) -> Any:
|
||||
return func(self, args, *a, **kw)
|
||||
|
||||
return cast(T, wrap)
|
||||
return func
|
||||
|
||||
@staticmethod
|
||||
def ignored(func: T) -> T:
|
||||
"""
|
||||
Call the function without any argument
|
||||
Call the function without textual arguments
|
||||
"""
|
||||
def before(args: List[Any], kwargs: Dict[Any, Any]) -> None:
|
||||
if len(args) >= 2:
|
||||
del args[1]
|
||||
|
||||
def wrap(self: Any, args: Any = None, *a: Any, **kw: Any) -> Any:
|
||||
return func(self, *a, **kw)
|
||||
|
||||
wrap = wrap_generic(func, before=before)
|
||||
return cast(T, wrap)
|
||||
|
||||
@staticmethod
|
||||
|
@ -149,14 +177,16 @@ class CommandArgParser:
|
|||
default_args_outer = defaults or []
|
||||
|
||||
def first(func: T) -> T:
|
||||
def second(self: Any, args: str, *a: Any, **kw: Any) -> Any:
|
||||
def before(args: List, kwargs: Dict[str, Any]) -> Any:
|
||||
default_args = default_args_outer
|
||||
if args and args.strip():
|
||||
split_args = common.shell_split(args)
|
||||
cmdargs = args[1]
|
||||
if cmdargs and cmdargs.strip():
|
||||
split_args = common.shell_split(cmdargs)
|
||||
else:
|
||||
split_args = []
|
||||
if len(split_args) < mandatory:
|
||||
return func(self, None, *a, **kw)
|
||||
args[1] = None
|
||||
return
|
||||
res, split_args = split_args[:mandatory], split_args[
|
||||
mandatory:]
|
||||
if optional == -1:
|
||||
|
@ -171,22 +201,25 @@ class CommandArgParser:
|
|||
res += default_args
|
||||
if split_args and res and not ignore_trailing_arguments:
|
||||
res[-1] += " " + " ".join(split_args)
|
||||
return func(self, res, *a, **kw)
|
||||
|
||||
return cast(T, second)
|
||||
args[1] = res
|
||||
return
|
||||
wrap = wrap_generic(func, before=before)
|
||||
return cast(T, wrap)
|
||||
return first
|
||||
|
||||
|
||||
command_args_parser = CommandArgParser()
|
||||
|
||||
|
||||
def deny_anonymous(func: Callable) -> Callable:
|
||||
"""Decorator to disable commands when using an anonymous account."""
|
||||
def wrap(self: RosterInfoTab, *args: Any, **kwargs: Any) -> Any:
|
||||
if self.core.xmpp.anon:
|
||||
return self.core.information(
|
||||
|
||||
def before(args: Any, kwargs: Any) -> Any:
|
||||
core = args[0].core
|
||||
if core.xmpp.anon:
|
||||
core.information(
|
||||
'This command is not available for anonymous accounts.',
|
||||
'Info'
|
||||
)
|
||||
return func(self, *args, **kwargs)
|
||||
return False
|
||||
wrap = wrap_generic(func, before=before)
|
||||
return cast(T, wrap)
|
||||
|
|
Loading…
Reference in a new issue