diff --git a/changelogs/fragments/11257-typing.yml b/changelogs/fragments/11257-typing.yml new file mode 100644 index 0000000000..9f8077420a --- /dev/null +++ b/changelogs/fragments/11257-typing.yml @@ -0,0 +1,2 @@ +minor_changes: + - "ModuleHelper module utils - refactor some internals to allow better type checking (https://github.com/ansible-collections/community.general/pull/11257)." diff --git a/plugins/module_utils/mh/base.py b/plugins/module_utils/mh/base.py index 8d8017855a..b80805530f 100644 --- a/plugins/module_utils/mh/base.py +++ b/plugins/module_utils/mh/base.py @@ -14,7 +14,9 @@ from ansible_collections.community.general.plugins.module_utils.mh.exceptions im class ModuleHelperBase: - module: dict[str, t.Any] | None = None # TODO: better spec using t.TypedDict + # The type of module should be AnsibleModule, not something else. + # TODO: Rename the property of type dict[str, t.Any] | None to something like module_spec instead + module: dict[str, t.Any] | AnsibleModule | None = None # TODO: better spec using t.TypedDict ModuleHelperException = _MHE _delegated_to_module: tuple[str, ...] = ( "check_mode", @@ -23,23 +25,30 @@ class ModuleHelperBase: "deprecate", "debug", ) + _module: AnsibleModule # TODO: remove once module has proper type - def __init__(self, module=None): + def __init__(self, module: AnsibleModule | dict[str, t.Any] | None = None) -> None: self._changed = False if module: self.module = module if not isinstance(self.module, AnsibleModule): - self.module = AnsibleModule(**self.module) + if self.module is None: + raise TypeError("module or module spec must be provided") + module = AnsibleModule(**self.module) + self.module = module # type: ignore + self._module = module + else: + self._module = self.module @property - def diff_mode(self): - return self.module._diff + def diff_mode(self) -> bool: + return self._module._diff @property - def verbosity(self): - return self.module._verbosity + def verbosity(self) -> int: + return self._module._verbosity def do_raise(self, *args, **kwargs) -> t.NoReturn: raise _MHE(*args, **kwargs) diff --git a/plugins/module_utils/mh/deco.py b/plugins/module_utils/mh/deco.py index e05492b66a..30d7c0213a 100644 --- a/plugins/module_utils/mh/deco.py +++ b/plugins/module_utils/mh/deco.py @@ -6,15 +6,25 @@ from __future__ import annotations import traceback +import typing as t from functools import wraps from ansible_collections.community.general.plugins.module_utils.mh.exceptions import ModuleHelperException +if t.TYPE_CHECKING: + from collections.abc import Callable -def cause_changes(when=None): - def deco(func): + from .base import ModuleHelperBase + + P = t.ParamSpec("P") + S = t.TypeVar("S", bound=ModuleHelperBase) + T = t.TypeVar("T") + + +def cause_changes(when=None) -> Callable[[Callable[t.Concatenate[S, P], T]], Callable[t.Concatenate[S, P], None]]: + def deco(func: Callable[t.Concatenate[S, P], T]) -> Callable[t.Concatenate[S, P], None]: @wraps(func) - def wrapper(self, *args, **kwargs): + def wrapper(self: S, *args: P.args, **kwargs: P.kwargs) -> None: try: func(self, *args, **kwargs) if when == "success": @@ -32,11 +42,11 @@ def cause_changes(when=None): return deco -def module_fails_on_exception(func): +def module_fails_on_exception(func: Callable[t.Concatenate[S, P], T]) -> Callable[t.Concatenate[S, P], None]: conflict_list = ("msg", "exception", "output", "vars", "changed") @wraps(func) - def wrapper(self, *args, **kwargs): + def wrapper(self: S, *args: P.args, **kwargs: P.kwargs) -> None: def fix_key(k): return k if k not in conflict_list else f"_{k}" @@ -51,36 +61,57 @@ def module_fails_on_exception(func): self.update_output(e.update_output) # patchy solution to resolve conflict with output variables output = fix_var_conflicts(self.output) - self.module.fail_json( + self._module.fail_json( msg=e.msg, exception=traceback.format_exc(), output=self.output, vars=self.vars.output(), **output ) except Exception as e: # patchy solution to resolve conflict with output variables output = fix_var_conflicts(self.output) msg = f"Module failed with exception: {str(e).strip()}" - self.module.fail_json( + self._module.fail_json( msg=msg, exception=traceback.format_exc(), output=self.output, vars=self.vars.output(), **output ) return wrapper -def check_mode_skip(func): +def check_mode_skip(func: Callable[t.Concatenate[S, P], T]) -> Callable[t.Concatenate[S, P], T | None]: @wraps(func) - def wrapper(self, *args, **kwargs): - if not self.module.check_mode: + def wrapper(self: S, *args: P.args, **kwargs: P.kwargs) -> T | None: + if not self._module.check_mode: return func(self, *args, **kwargs) + return None return wrapper -def check_mode_skip_returns(callable=None, value=None): - def deco(func): +@t.overload +def check_mode_skip_returns( + callable: Callable[t.Concatenate[S, P], T], value: T | None = None +) -> Callable[[Callable[t.Concatenate[S, P], T]], Callable[t.Concatenate[S, P], T]]: ... + + +@t.overload +def check_mode_skip_returns( + callable: None, value: T +) -> Callable[[Callable[t.Concatenate[S, P], T]], Callable[t.Concatenate[S, P], T]]: ... + + +@t.overload +def check_mode_skip_returns( + callable: None = None, *, value: T +) -> Callable[[Callable[t.Concatenate[S, P], T]], Callable[t.Concatenate[S, P], T]]: ... + + +def check_mode_skip_returns( + callable: Callable[t.Concatenate[S, P], T] | None = None, value: T | None = None +) -> Callable[[Callable[t.Concatenate[S, P], T]], Callable[t.Concatenate[S, P], T]]: + def deco(func: Callable[t.Concatenate[S, P], T]) -> Callable[t.Concatenate[S, P], T]: if callable is not None: @wraps(func) - def wrapper_callable(self, *args, **kwargs): - if self.module.check_mode: + def wrapper_callable(self: S, *args: P.args, **kwargs: P.kwargs) -> T: + if self._module.check_mode: return callable(self, *args, **kwargs) return func(self, *args, **kwargs) @@ -89,9 +120,9 @@ def check_mode_skip_returns(callable=None, value=None): else: @wraps(func) - def wrapper_value(self, *args, **kwargs): - if self.module.check_mode: - return value + def wrapper_value(self: S, *args: P.args, **kwargs: P.kwargs) -> T: + if self._module.check_mode: + return value # type: ignore # must be of type T due to the overloads return func(self, *args, **kwargs) return wrapper_value diff --git a/plugins/module_utils/mh/module_helper.py b/plugins/module_utils/mh/module_helper.py index b41cf48639..44460f31d7 100644 --- a/plugins/module_utils/mh/module_helper.py +++ b/plugins/module_utils/mh/module_helper.py @@ -31,7 +31,7 @@ class ModuleHelper(DeprecateAttrsMixin, ModuleHelperBase): super().__init__(module) self.vars = VarDict() - for name, value in self.module.params.items(): # type: ignore[union-attr] + for name, value in self._module.params.items(): self.vars.set( name, value, diff --git a/tests/sanity/ignore-2.17.txt b/tests/sanity/ignore-2.17.txt index b96f5d20fc..159e775828 100644 --- a/tests/sanity/ignore-2.17.txt +++ b/tests/sanity/ignore-2.17.txt @@ -8,5 +8,6 @@ plugins/modules/rhevm.py validate-modules:parameter-state-invalid-choice plugins/modules/udm_user.py import-3.11 # Uses deprecated stdlib library 'crypt' plugins/modules/udm_user.py import-3.12 # Uses deprecated stdlib library 'crypt' plugins/modules/xfconf.py validate-modules:return-syntax-error +plugins/module_utils/mh/deco.py pep8:E704 plugins/plugin_utils/unsafe.py pep8:E704 tests/unit/plugins/modules/test_gio_mime.yaml no-smart-quotes diff --git a/tests/sanity/ignore-2.18.txt b/tests/sanity/ignore-2.18.txt index b96f5d20fc..159e775828 100644 --- a/tests/sanity/ignore-2.18.txt +++ b/tests/sanity/ignore-2.18.txt @@ -8,5 +8,6 @@ plugins/modules/rhevm.py validate-modules:parameter-state-invalid-choice plugins/modules/udm_user.py import-3.11 # Uses deprecated stdlib library 'crypt' plugins/modules/udm_user.py import-3.12 # Uses deprecated stdlib library 'crypt' plugins/modules/xfconf.py validate-modules:return-syntax-error +plugins/module_utils/mh/deco.py pep8:E704 plugins/plugin_utils/unsafe.py pep8:E704 tests/unit/plugins/modules/test_gio_mime.yaml no-smart-quotes