1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2026-02-04 07:51:50 +00:00

Add type hints to action and test plugins and to plugin utils; fix some bugs, and improve input validation (#11167)

* Add type hints to action and test plugins and to plugin utils. Also fix some bugs and add proper input validation.

* Combine lines.

Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com>

* Extend changelog fragment.

* Move task_vars initialization up.

---------

Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com>
This commit is contained in:
Felix Fontein 2025-11-22 22:52:21 +01:00 committed by GitHub
parent 4517b86ed4
commit 19757b3a4c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 194 additions and 96 deletions

View file

@ -5,6 +5,7 @@
from __future__ import annotations
import time
import typing as t
from ansible.plugins.action import ActionBase
from ansible.errors import AnsibleActionFail, AnsibleConnectionFailure
@ -20,7 +21,7 @@ class ActionModule(ActionBase):
DEFAULT_SUDOABLE = True
@staticmethod
def msg_error__async_and_poll_not_zero(task_poll, task_async, max_timeout):
def msg_error__async_and_poll_not_zero(task_poll, task_async, max_timeout) -> str:
return (
"This module doesn't support async>0 and poll>0 when its 'state' param "
"is set to 'restored'. To enable its rollback feature (that needs the "
@ -30,7 +31,7 @@ class ActionModule(ActionBase):
)
@staticmethod
def msg_warning__no_async_is_no_rollback(task_poll, task_async, max_timeout):
def msg_warning__no_async_is_no_rollback(task_poll, task_async, max_timeout) -> str:
return (
"Attempts to restore iptables state without rollback in case of mistake "
"may lead the ansible controller to loose access to the hosts and never "
@ -41,7 +42,7 @@ class ActionModule(ActionBase):
)
@staticmethod
def msg_warning__async_greater_than_timeout(task_poll, task_async, max_timeout):
def msg_warning__async_greater_than_timeout(task_poll, task_async, max_timeout) -> str:
return (
"You attempt to restore iptables state with rollback in case of mistake, "
"but with settings that will lead this rollback to happen AFTER that the "
@ -50,7 +51,9 @@ class ActionModule(ActionBase):
f"'ansible_timeout' (={max_timeout}) (recommended)."
)
def _async_result(self, async_status_args, task_vars, timeout):
def _async_result(
self, async_status_args: dict[str, t.Any], task_vars: dict[str, t.Any], timeout: int
) -> dict[str, t.Any]:
"""
Retrieve results of the asynchronous task, and display them in place of
the async wrapper results (those with the ansible_job_id key).
@ -81,10 +84,13 @@ class ActionModule(ActionBase):
return async_result
def run(self, tmp=None, task_vars=None):
def run(self, tmp: str | None = None, task_vars: dict[str, t.Any] | None = None) -> dict[str, t.Any]:
self._supports_check_mode = True
self._supports_async = True
if task_vars is None:
task_vars = {}
result = super().run(tmp, task_vars)
del tmp # tmp no longer has any effect

View file

@ -7,19 +7,25 @@
from __future__ import annotations
import typing as t
from ansible.errors import AnsibleError, AnsibleConnectionFailure
from ansible.module_utils.common.text.converters import to_native, to_text
from ansible.module_utils.common.collections import is_string
from ansible.plugins.action import ActionBase
from ansible.utils.display import Display
if t.TYPE_CHECKING:
class Distribution(t.TypedDict):
name: str
version: str
family: str
display = Display()
def fmt(mapping, key):
return to_native(mapping[key]).strip()
class TimedOutException(Exception):
pass
@ -60,25 +66,23 @@ class ActionModule(ActionBase):
def delay(self):
return self._check_delay("delay", self.DEFAULT_PRE_SHUTDOWN_DELAY)
def _check_delay(self, key, default):
def _check_delay(self, key: str, default: int) -> int:
"""Ensure that the value is positive or zero"""
value = int(self._task.args.get(key, default))
if value < 0:
value = 0
return value
def _get_value_from_facts(self, variable_name, distribution, default_value):
@staticmethod
def _get_value_from_facts(data: dict[str, str], distribution: Distribution, default_value: str) -> str:
"""Get dist+version specific args first, then distribution, then family, lastly use default"""
attr = getattr(self, variable_name)
value = attr.get(
return data.get(
distribution["name"] + distribution["version"],
attr.get(distribution["name"], attr.get(distribution["family"], getattr(self, default_value))),
data.get(distribution["name"], data.get(distribution["family"], default_value)),
)
return value
def get_distribution(self, task_vars):
def get_distribution(self, task_vars: dict[str, t.Any]) -> Distribution:
# FIXME: only execute the module if we don't already have the facts we need
distribution = {}
display.debug(f"{self._task.action}: running setup module to get distribution")
module_output = self._execute_module(
task_vars=task_vars, module_name="ansible.legacy.setup", module_args={"gather_subset": "min"}
@ -86,20 +90,20 @@ class ActionModule(ActionBase):
try:
if module_output.get("failed", False):
raise AnsibleError(
f"Failed to determine system distribution. {fmt(module_output, 'module_stdout')}, {fmt(module_output, 'module_stderr')}"
f"Failed to determine system distribution. {to_native(module_output['module_stdout'])}, {to_native(module_output['module_stderr'])}"
)
distribution["name"] = module_output["ansible_facts"]["ansible_distribution"].lower()
distribution["version"] = to_text(
module_output["ansible_facts"]["ansible_distribution_version"].split(".")[0]
)
distribution["family"] = to_text(module_output["ansible_facts"]["ansible_os_family"].lower())
distribution: Distribution = {
"name": module_output["ansible_facts"]["ansible_distribution"].lower(),
"version": to_text(module_output["ansible_facts"]["ansible_distribution_version"].split(".")[0]),
"family": to_text(module_output["ansible_facts"]["ansible_os_family"].lower()),
}
display.debug(f"{self._task.action}: distribution: {distribution}")
return distribution
except KeyError as ke:
raise AnsibleError(f'Failed to get distribution information. Missing "{ke.args[0]}" in output.') from ke
def get_shutdown_command(self, task_vars, distribution):
def find_command(command, find_search_paths):
def get_shutdown_command(self, task_vars: dict[str, t.Any], distribution: Distribution) -> str:
def find_command(command: str, find_search_paths: list[str]) -> list[str]:
display.debug(
f'{self._task.action}: running find module looking in {find_search_paths} to get path for "{command}"'
)
@ -111,7 +115,7 @@ class ActionModule(ActionBase):
)
return [x["path"] for x in find_result["files"]]
shutdown_bin = self._get_value_from_facts("SHUTDOWN_COMMANDS", distribution, "DEFAULT_SHUTDOWN_COMMAND")
shutdown_bin = self._get_value_from_facts(self.SHUTDOWN_COMMANDS, distribution, self.DEFAULT_SHUTDOWN_COMMAND)
default_search_paths = ["/sbin", "/usr/sbin", "/usr/local/sbin"]
search_paths = self._task.args.get("search_paths", default_search_paths)
@ -146,7 +150,7 @@ class ActionModule(ActionBase):
return f"{full_path[0]} poweroff" # done, since we cannot use args with systemd shutdown
# systemd case taken care of, here we add args to the command
args = self._get_value_from_facts("SHUTDOWN_COMMAND_ARGS", distribution, "DEFAULT_SHUTDOWN_COMMAND_ARGS")
args = self._get_value_from_facts(self.SHUTDOWN_COMMAND_ARGS, distribution, self.DEFAULT_SHUTDOWN_COMMAND_ARGS)
# Convert seconds to minutes. If less that 60, set it to 0.
delay_sec = self.delay
shutdown_message = self._task.args.get("msg", self.DEFAULT_SHUTDOWN_MESSAGE)
@ -154,8 +158,8 @@ class ActionModule(ActionBase):
af = args.format(delay_sec=delay_sec, delay_min=delay_sec // 60, message=shutdown_message)
return f"{full_path[0]} {af}"
def perform_shutdown(self, task_vars, distribution):
result = {}
def perform_shutdown(self, task_vars, distribution) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
shutdown_result = {}
shutdown_command_exec = self.get_shutdown_command(task_vars, distribution)
@ -176,7 +180,7 @@ class ActionModule(ActionBase):
result["failed"] = True
result["shutdown"] = False
result["msg"] = (
f"Shutdown command failed. Error was {fmt(shutdown_result, 'stdout')}, {fmt(shutdown_result, 'stderr')}"
f"Shutdown command failed. Error was {to_native(shutdown_result['stdout'])}, {to_native(shutdown_result['stderr'])}"
)
return result
@ -184,7 +188,7 @@ class ActionModule(ActionBase):
result["shutdown_command"] = shutdown_command_exec
return result
def run(self, tmp=None, task_vars=None):
def run(self, tmp: str | None = None, task_vars: dict[str, t.Any] | None = None) -> dict[str, t.Any]:
self._supports_check_mode = True
self._supports_async = True