diff --git a/poezio/decorators.py b/poezio/decorators.py index 6a853446..4b5ef1dc 100644 --- a/poezio/decorators.py +++ b/poezio/decorators.py @@ -3,11 +3,13 @@ Module containing various decorators """ from __future__ import annotations +from asyncio import iscoroutinefunction from typing import ( cast, Any, Callable, + Dict, List, Optional, TypeVar, @@ -21,6 +23,37 @@ if TYPE_CHECKING: T = TypeVar('T', bound=Callable[..., Any]) +BeforeFunc = Callable[[List[Any], Dict[str, Any]], Any] +AfterFunc = Callable[[List[Any], Dict[str, Any]], Any] + +def wrap_generic(func: Callable, before: BeforeFunc=None, after: AfterFunc=None): + """ + Generic wrapper which can both wrap coroutines and normal functions. + """ + def wrap(*args, **kwargs): + args = list(args) + if before is not None: + result = before(args, kwargs) + if result is not None: + return result + result = func(*args, **kwargs) + if after is not None: + result = after(result, args, kwargs) + return result + + async def awrap(*args, **kwargs): + args = list(args) + if before is not None: + result = before(args, kwargs) + if result is not None: + return result + result = await func(*args, **kwargs) + if after is not None: + result = after(result, args, kwargs) + return result + if iscoroutinefunction(func): + return awrap + return wrap class RefreshWrapper: @@ -32,12 +65,12 @@ class RefreshWrapper: Decorator to refresh the UI if the wrapped function returns True """ - - def wrap(*args: Any, **kwargs: Any) -> Any: - ret = func(*args, **kwargs) - if self.core and ret: + def after(result: Any, args, kwargs) -> Any: + if self.core and result: self.core.refresh_window() - return ret + return result + + wrap = wrap_generic(func, after=after) return cast(T, wrap) @@ -45,13 +78,12 @@ class RefreshWrapper: """ Decorator that refreshs the UI no matter what after the function """ - - def wrap(*args: Any, **kwargs: Any) -> Any: - ret = func(*args, **kwargs) + def after(result: Any, args, kwargs) -> Any: if self.core: self.core.refresh_window() - return ret + return result + wrap = wrap_generic(func, after=after) return cast(T, wrap) def update(self, func: T) -> T: @@ -59,12 +91,11 @@ class RefreshWrapper: Decorator that only updates the screen """ - def wrap(*args: Any, **kwargs: Any) -> Any: - ret = func(*args, **kwargs) + def after(result: Any, args, kwargs) -> Any: if self.core: self.core.doupdate() - return ret - + return result + wrap = wrap_generic(func, after=after) return cast(T, wrap) @@ -82,21 +113,18 @@ class CommandArgParser: """Just call the function with a single string, which is the original string untouched """ - - def wrap(self: Any, args: Any, *a: Any, **kw: Any) -> Any: - return func(self, args, *a, **kw) - - return cast(T, wrap) + return func @staticmethod def ignored(func: T) -> T: """ - Call the function without any argument + Call the function without textual arguments """ + def before(args: List[Any], kwargs: Dict[Any, Any]) -> None: + if len(args) >= 2: + del args[1] - def wrap(self: Any, args: Any = None, *a: Any, **kw: Any) -> Any: - return func(self, *a, **kw) - + wrap = wrap_generic(func, before=before) return cast(T, wrap) @staticmethod @@ -149,14 +177,16 @@ class CommandArgParser: default_args_outer = defaults or [] def first(func: T) -> T: - def second(self: Any, args: str, *a: Any, **kw: Any) -> Any: + def before(args: List, kwargs: Dict[str, Any]) -> Any: default_args = default_args_outer - if args and args.strip(): - split_args = common.shell_split(args) + cmdargs = args[1] + if cmdargs and cmdargs.strip(): + split_args = common.shell_split(cmdargs) else: split_args = [] if len(split_args) < mandatory: - return func(self, None, *a, **kw) + args[1] = None + return res, split_args = split_args[:mandatory], split_args[ mandatory:] if optional == -1: @@ -171,22 +201,25 @@ class CommandArgParser: res += default_args if split_args and res and not ignore_trailing_arguments: res[-1] += " " + " ".join(split_args) - return func(self, res, *a, **kw) - - return cast(T, second) + args[1] = res + return + wrap = wrap_generic(func, before=before) + return cast(T, wrap) return first - command_args_parser = CommandArgParser() def deny_anonymous(func: Callable) -> Callable: """Decorator to disable commands when using an anonymous account.""" - def wrap(self: RosterInfoTab, *args: Any, **kwargs: Any) -> Any: - if self.core.xmpp.anon: - return self.core.information( + + def before(args: Any, kwargs: Any) -> Any: + core = args[0].core + if core.xmpp.anon: + core.information( 'This command is not available for anonymous accounts.', 'Info' ) - return func(self, *args, **kwargs) + return False + wrap = wrap_generic(func, before=before) return cast(T, wrap)