1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2026-02-04 07:51:50 +00:00
community.general/plugins/module_utils/cmd_runner.py
Felix Fontein c7f6a28d89
Add basic typing for module_utils (#11222)
* Add basic typing for module_utils.

* Apply some suggestions.

Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com>

* Make pass again.

* Add more types as suggested.

* Normalize extra imports.

* Add more type hints.

* Improve typing.

* Add changelog fragment.

* Reduce changelog.

* Apply suggestions from code review.

Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com>

* Fix typo.

* Cleanup.

* Improve types and make type checking happy.

* Let's see whether older Pythons barf on this.

* Revert "Let's see whether older Pythons barf on this."

This reverts commit 9973af3dbe.

* Add noqa.

---------

Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com>
2025-12-01 20:40:06 +01:00

248 lines
8.3 KiB
Python

# Copyright (c) 2022, Alexei Znamensky <russoz@gmail.com>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
from __future__ import annotations
import os
import typing as t
from ansible.module_utils.common.collections import is_sequence
from ansible.module_utils.common.locale import get_best_parsable_locale
from ansible_collections.community.general.plugins.module_utils import cmd_runner_fmt
if t.TYPE_CHECKING:
from collections.abc import Callable, Mapping, Sequence
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils.cmd_runner_fmt import ArgFormatType
ArgFormatter = t.Union[ArgFormatType, cmd_runner_fmt._ArgFormat] # noqa: UP007
def _ensure_list(value):
    """Return *value* as a new list.

    If ``is_sequence`` reports *value* as a sequence, its items are copied
    into a fresh list. Otherwise *value* becomes the only element of a
    one-item list. Strings fall into the second case: elsewhere in this file
    a string order is split rather than treated as a sequence.
    """
    if not is_sequence(value):
        return [value]
    return list(value)
def _process_as_is(rc, out, err):
return rc, out, err
class CmdRunnerException(Exception):
    """Common base class for the exceptions defined in this module."""
class MissingArgumentFormat(CmdRunnerException):
    """Raised when an argument in the requested order has no registered formatter.

    :param arg: name of the argument lacking a formatter.
    :param args_order: the full argument order that was requested.
    :param args_formats: what formatters are known; ``CmdRunner`` passes a
        tuple of the argument names that do have a formatter.
    """

    def __init__(self, arg, args_order: tuple[str, ...], args_formats) -> None:
        self.arg = arg
        self.args_formats = args_formats
        self.args_order = args_order

    def __repr__(self):
        arg, order, formats = self.arg, self.args_order, self.args_formats
        return f"MissingArgumentFormat({arg!r}, {order!r}, {formats!r})"

    def __str__(self):
        arg, order, formats = self.arg, self.args_order, self.args_formats
        return f"Cannot find format for parameter {arg} {order} in: {formats}"
class MissingArgumentValue(CmdRunnerException):
    """Raised when an argument in the order has no value and its formatter requires one.

    :param args_order: the full argument order being processed.
    :param arg: the argument whose value could not be found.
    """

    def __init__(self, args_order: tuple[str, ...], arg) -> None:
        self.arg = arg
        self.args_order = args_order

    def __repr__(self):
        order, arg = self.args_order, self.arg
        return f"MissingArgumentValue({order!r}, {arg!r})"

    def __str__(self):
        arg, order = self.arg, self.args_order
        return f"Cannot find value for parameter {arg} in {order}"
class FormatError(CmdRunnerException):
    """Raised when an argument formatter fails while formatting a value.

    :param name: name of the argument being formatted.
    :param value: the value that was passed to the formatter.
    :param args_formats: the formatter that failed. Despite the plural name,
        ``_CmdRunnerContext.run`` passes the single formatter for ``name``.
    :param exc: the original exception raised by the formatter; it is also
        chained as ``__cause__`` by the raiser.
    """

    def __init__(self, name, value, args_formats, exc):
        self.name = name
        self.value = value
        self.args_formats = args_formats
        self.exc = exc
        # Forward the constructor arguments so ``self.args`` matches this
        # signature. copy.copy() and pickle rebuild exceptions as
        # ``type(self)(*self.args)``; with an empty ``args`` they fail with a
        # TypeError about missing positional arguments.
        super().__init__(name, value, args_formats, exc)

    def __repr__(self):
        return f"FormatError({self.name!r}, {self.value!r}, {self.args_formats!r}, {self.exc!r})"

    def __str__(self):
        return f"Failed to format parameter {self.name} with value {self.value}: {self.exc}"
class CmdRunner:
    """
    Wrapper for ``AnsibleModule.run_command()``.

    It aims to provide a reusable runner with consistent argument formatting
    and sensible defaults.
    """

    @staticmethod
    def _prepare_args_order(order: str | Sequence[str]) -> tuple[str, ...]:
        """Turn an argument order into a tuple of names.

        A sequence is converted as-is; anything else, i.e. a string, is split
        on whitespace.
        """
        return tuple(order) if is_sequence(order) else tuple(order.split())  # type: ignore

    def __init__(
        self,
        module: AnsibleModule,
        command,
        arg_formats: Mapping[str, ArgFormatter] | None = None,
        default_args_order: str | Sequence[str] = (),
        check_rc: bool = False,
        force_lang: str = "C",
        path_prefix: Sequence[str] | None = None,
        environ_update: dict[str, str] | None = None,
    ):
        """Create a runner.

        :param module: the module whose ``get_bin_path`` resolves the command.
            Contexts later use its ``params``, ``check_mode`` and
            ``run_command``.
        :param command: the command as a single value or a sequence. If its
            first element is neither absolute nor contains ``/``, it is
            resolved with ``module.get_bin_path(..., required=True)``.
        :param arg_formats: mapping of argument name to formatter. Entries
            that are not already argument-format objects are wrapped with
            ``cmd_runner_fmt.as_func(..., ignore_none=True)``.
        :param default_args_order: argument order used when a call does not
            give one, as a whitespace-separated string or a sequence.
        :param check_rc: default ``check_rc`` for ``run_command``; a context
            may override it.
        :param force_lang: locale exported as ``LANGUAGE``/``LC_ALL``.
            ``"auto"`` picks one via ``get_best_parsable_locale`` and falls
            back to ``"C"`` if that raises ``RuntimeWarning``.
        :param path_prefix: extra directories searched when resolving the
            command.
        :param environ_update: extra environment variables for every run.
            The runner keeps its own copy.
        """
        self.module = module
        self.command = _ensure_list(command)
        self.default_args_order = self._prepare_args_order(default_args_order)
        if arg_formats is None:
            arg_formats = {}
        self.arg_formats = {}
        for fmt_name, fmt in arg_formats.items():
            # Normalise non-ArgFormat entries (e.g. plain callables) so every
            # formatter exposes the same interface (run() reads
            # ``ignore_missing_value`` on each one).
            if not cmd_runner_fmt.is_argformat(fmt):
                fmt = cmd_runner_fmt.as_func(func=fmt, ignore_none=True)
            self.arg_formats[fmt_name] = fmt
        self.check_rc = check_rc
        if force_lang == "auto":
            try:
                self.force_lang = get_best_parsable_locale(module)
            except RuntimeWarning:
                self.force_lang = "C"
        else:
            self.force_lang = force_lang
        self.path_prefix = path_prefix
        # Store a private copy. Keeping a reference would let the caller's dict
        # be mutated through the runner (and vice versa).
        self.environ_update = dict(environ_update) if environ_update else {}
        _cmd = self.command[0]
        self.command[0] = (
            _cmd
            if (os.path.isabs(_cmd) or "/" in _cmd)
            else module.get_bin_path(_cmd, opt_dirs=path_prefix, required=True)
        )

    @property
    def binary(self) -> str:
        """The (possibly resolved) executable, i.e. the first command element."""
        return self.command[0]

    def __call__(
        self,
        args_order: str | Sequence[str] | None = None,
        output_process: Callable[[int, str, str], t.Any] | None = None,
        check_mode_skip: bool = False,
        check_mode_return: t.Any | None = None,
        **kwargs,
    ):
        """Create a ``_CmdRunnerContext`` for one invocation.

        :param args_order: argument names to render, in order. Defaults to
            ``default_args_order``.
        :param output_process: called with ``(rc, out, err)``; its result is
            what ``run()`` returns. Defaults to returning the triple unchanged.
        :param check_mode_skip: if true and the module is in check mode,
            ``run()`` skips execution.
        :param check_mode_return: value ``run()`` returns when it skips.
        :param kwargs: passed through to ``module.run_command()``.
        :raises MissingArgumentFormat: if a name in the order has no formatter.
        """
        if output_process is None:
            output_process = _process_as_is
        if args_order is None:
            args_order = self.default_args_order
        args_order = self._prepare_args_order(args_order)
        for p in args_order:
            if p not in self.arg_formats:
                raise MissingArgumentFormat(p, args_order, tuple(self.arg_formats.keys()))
        return _CmdRunnerContext(
            runner=self,
            args_order=args_order,
            output_process=output_process,
            check_mode_skip=check_mode_skip,
            check_mode_return=check_mode_return,
            **kwargs,
        )

    def has_arg_format(self, arg):
        """Return whether a formatter is registered for ``arg``."""
        return arg in self.arg_formats

    # not decided whether to keep it or not, but if deprecating it will happen in a farther future.
    context = __call__
class _CmdRunnerContext:
def __init__(
self,
runner: CmdRunner,
args_order: tuple[str, ...],
output_process: Callable[[int, str, str], t.Any],
check_mode_skip: bool,
check_mode_return: t.Any,
**kwargs,
) -> None:
self.runner = runner
self.args_order = tuple(args_order)
self.output_process = output_process
self.check_mode_skip = check_mode_skip
self.check_mode_return = check_mode_return
self.run_command_args = dict(kwargs)
self.environ_update = runner.environ_update
self.environ_update.update(self.run_command_args.get("environ_update", {}))
if runner.force_lang:
self.environ_update.update(
{
"LANGUAGE": runner.force_lang,
"LC_ALL": runner.force_lang,
}
)
self.run_command_args["environ_update"] = self.environ_update
if "check_rc" not in self.run_command_args:
self.run_command_args["check_rc"] = runner.check_rc
self.check_rc = self.run_command_args["check_rc"]
self.cmd = None
self.results_rc = None
self.results_out = None
self.results_err = None
self.results_processed = None
def run(self, **kwargs):
runner = self.runner
module = self.runner.module
self.cmd = list(runner.command)
self.context_run_args = dict(kwargs)
named_args = dict(module.params)
named_args.update(kwargs)
for arg_name in self.args_order:
value = None
try:
if arg_name in named_args:
value = named_args[arg_name]
elif not runner.arg_formats[arg_name].ignore_missing_value:
raise MissingArgumentValue(self.args_order, arg_name)
self.cmd.extend(runner.arg_formats[arg_name](value))
except MissingArgumentValue:
raise
except Exception as e:
raise FormatError(arg_name, value, runner.arg_formats[arg_name], e) from e
if self.check_mode_skip and module.check_mode:
return self.check_mode_return
results = module.run_command(self.cmd, **self.run_command_args)
self.results_rc, self.results_out, self.results_err = results
self.results_processed = self.output_process(*results)
return self.results_processed
@property
def run_info(self) -> dict[str, t.Any]:
return dict(
check_rc=self.check_rc,
environ_update=self.environ_update,
args_order=self.args_order,
cmd=self.cmd,
run_command_args=self.run_command_args,
context_run_args=self.context_run_args,
results_rc=self.results_rc,
results_out=self.results_out,
results_err=self.results_err,
results_processed=self.results_processed,
)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
return False