mirror of
https://github.com/ansible-collections/community.general.git
synced 2026-02-04 07:51:50 +00:00
Add type hints to action and test plugins and to plugin utils; fix some bugs, and improve input validation (#11167)
* Add type hints to action and test plugins and to plugin utils. Also fix some bugs and add proper input validation. * Combine lines. Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com> * Extend changelog fragment. * Move task_vars initialization up. --------- Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com>
This commit is contained in:
parent
4517b86ed4
commit
19757b3a4c
11 changed files with 194 additions and 96 deletions
|
|
@ -5,6 +5,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import typing as t
|
||||
|
||||
from ansible.plugins.action import ActionBase
|
||||
from ansible.errors import AnsibleActionFail, AnsibleConnectionFailure
|
||||
|
|
@ -20,7 +21,7 @@ class ActionModule(ActionBase):
|
|||
DEFAULT_SUDOABLE = True
|
||||
|
||||
@staticmethod
|
||||
def msg_error__async_and_poll_not_zero(task_poll, task_async, max_timeout):
|
||||
def msg_error__async_and_poll_not_zero(task_poll, task_async, max_timeout) -> str:
|
||||
return (
|
||||
"This module doesn't support async>0 and poll>0 when its 'state' param "
|
||||
"is set to 'restored'. To enable its rollback feature (that needs the "
|
||||
|
|
@ -30,7 +31,7 @@ class ActionModule(ActionBase):
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def msg_warning__no_async_is_no_rollback(task_poll, task_async, max_timeout):
|
||||
def msg_warning__no_async_is_no_rollback(task_poll, task_async, max_timeout) -> str:
|
||||
return (
|
||||
"Attempts to restore iptables state without rollback in case of mistake "
|
||||
"may lead the ansible controller to loose access to the hosts and never "
|
||||
|
|
@ -41,7 +42,7 @@ class ActionModule(ActionBase):
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def msg_warning__async_greater_than_timeout(task_poll, task_async, max_timeout):
|
||||
def msg_warning__async_greater_than_timeout(task_poll, task_async, max_timeout) -> str:
|
||||
return (
|
||||
"You attempt to restore iptables state with rollback in case of mistake, "
|
||||
"but with settings that will lead this rollback to happen AFTER that the "
|
||||
|
|
@ -50,7 +51,9 @@ class ActionModule(ActionBase):
|
|||
f"'ansible_timeout' (={max_timeout}) (recommended)."
|
||||
)
|
||||
|
||||
def _async_result(self, async_status_args, task_vars, timeout):
|
||||
def _async_result(
|
||||
self, async_status_args: dict[str, t.Any], task_vars: dict[str, t.Any], timeout: int
|
||||
) -> dict[str, t.Any]:
|
||||
"""
|
||||
Retrieve results of the asynchronous task, and display them in place of
|
||||
the async wrapper results (those with the ansible_job_id key).
|
||||
|
|
@ -81,10 +84,13 @@ class ActionModule(ActionBase):
|
|||
|
||||
return async_result
|
||||
|
||||
def run(self, tmp=None, task_vars=None):
|
||||
def run(self, tmp: str | None = None, task_vars: dict[str, t.Any] | None = None) -> dict[str, t.Any]:
|
||||
self._supports_check_mode = True
|
||||
self._supports_async = True
|
||||
|
||||
if task_vars is None:
|
||||
task_vars = {}
|
||||
|
||||
result = super().run(tmp, task_vars)
|
||||
del tmp # tmp no longer has any effect
|
||||
|
||||
|
|
|
|||
|
|
@ -7,19 +7,25 @@
|
|||
from __future__ import annotations
|
||||
|
||||
|
||||
import typing as t
|
||||
|
||||
from ansible.errors import AnsibleError, AnsibleConnectionFailure
|
||||
from ansible.module_utils.common.text.converters import to_native, to_text
|
||||
from ansible.module_utils.common.collections import is_string
|
||||
from ansible.plugins.action import ActionBase
|
||||
from ansible.utils.display import Display
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
class Distribution(t.TypedDict):
|
||||
name: str
|
||||
version: str
|
||||
family: str
|
||||
|
||||
|
||||
display = Display()
|
||||
|
||||
|
||||
def fmt(mapping, key):
|
||||
return to_native(mapping[key]).strip()
|
||||
|
||||
|
||||
class TimedOutException(Exception):
|
||||
pass
|
||||
|
||||
|
|
@ -60,25 +66,23 @@ class ActionModule(ActionBase):
|
|||
def delay(self):
|
||||
return self._check_delay("delay", self.DEFAULT_PRE_SHUTDOWN_DELAY)
|
||||
|
||||
def _check_delay(self, key, default):
|
||||
def _check_delay(self, key: str, default: int) -> int:
|
||||
"""Ensure that the value is positive or zero"""
|
||||
value = int(self._task.args.get(key, default))
|
||||
if value < 0:
|
||||
value = 0
|
||||
return value
|
||||
|
||||
def _get_value_from_facts(self, variable_name, distribution, default_value):
|
||||
@staticmethod
|
||||
def _get_value_from_facts(data: dict[str, str], distribution: Distribution, default_value: str) -> str:
|
||||
"""Get dist+version specific args first, then distribution, then family, lastly use default"""
|
||||
attr = getattr(self, variable_name)
|
||||
value = attr.get(
|
||||
return data.get(
|
||||
distribution["name"] + distribution["version"],
|
||||
attr.get(distribution["name"], attr.get(distribution["family"], getattr(self, default_value))),
|
||||
data.get(distribution["name"], data.get(distribution["family"], default_value)),
|
||||
)
|
||||
return value
|
||||
|
||||
def get_distribution(self, task_vars):
|
||||
def get_distribution(self, task_vars: dict[str, t.Any]) -> Distribution:
|
||||
# FIXME: only execute the module if we don't already have the facts we need
|
||||
distribution = {}
|
||||
display.debug(f"{self._task.action}: running setup module to get distribution")
|
||||
module_output = self._execute_module(
|
||||
task_vars=task_vars, module_name="ansible.legacy.setup", module_args={"gather_subset": "min"}
|
||||
|
|
@ -86,20 +90,20 @@ class ActionModule(ActionBase):
|
|||
try:
|
||||
if module_output.get("failed", False):
|
||||
raise AnsibleError(
|
||||
f"Failed to determine system distribution. {fmt(module_output, 'module_stdout')}, {fmt(module_output, 'module_stderr')}"
|
||||
f"Failed to determine system distribution. {to_native(module_output['module_stdout'])}, {to_native(module_output['module_stderr'])}"
|
||||
)
|
||||
distribution["name"] = module_output["ansible_facts"]["ansible_distribution"].lower()
|
||||
distribution["version"] = to_text(
|
||||
module_output["ansible_facts"]["ansible_distribution_version"].split(".")[0]
|
||||
)
|
||||
distribution["family"] = to_text(module_output["ansible_facts"]["ansible_os_family"].lower())
|
||||
distribution: Distribution = {
|
||||
"name": module_output["ansible_facts"]["ansible_distribution"].lower(),
|
||||
"version": to_text(module_output["ansible_facts"]["ansible_distribution_version"].split(".")[0]),
|
||||
"family": to_text(module_output["ansible_facts"]["ansible_os_family"].lower()),
|
||||
}
|
||||
display.debug(f"{self._task.action}: distribution: {distribution}")
|
||||
return distribution
|
||||
except KeyError as ke:
|
||||
raise AnsibleError(f'Failed to get distribution information. Missing "{ke.args[0]}" in output.') from ke
|
||||
|
||||
def get_shutdown_command(self, task_vars, distribution):
|
||||
def find_command(command, find_search_paths):
|
||||
def get_shutdown_command(self, task_vars: dict[str, t.Any], distribution: Distribution) -> str:
|
||||
def find_command(command: str, find_search_paths: list[str]) -> list[str]:
|
||||
display.debug(
|
||||
f'{self._task.action}: running find module looking in {find_search_paths} to get path for "{command}"'
|
||||
)
|
||||
|
|
@ -111,7 +115,7 @@ class ActionModule(ActionBase):
|
|||
)
|
||||
return [x["path"] for x in find_result["files"]]
|
||||
|
||||
shutdown_bin = self._get_value_from_facts("SHUTDOWN_COMMANDS", distribution, "DEFAULT_SHUTDOWN_COMMAND")
|
||||
shutdown_bin = self._get_value_from_facts(self.SHUTDOWN_COMMANDS, distribution, self.DEFAULT_SHUTDOWN_COMMAND)
|
||||
default_search_paths = ["/sbin", "/usr/sbin", "/usr/local/sbin"]
|
||||
search_paths = self._task.args.get("search_paths", default_search_paths)
|
||||
|
||||
|
|
@ -146,7 +150,7 @@ class ActionModule(ActionBase):
|
|||
return f"{full_path[0]} poweroff" # done, since we cannot use args with systemd shutdown
|
||||
|
||||
# systemd case taken care of, here we add args to the command
|
||||
args = self._get_value_from_facts("SHUTDOWN_COMMAND_ARGS", distribution, "DEFAULT_SHUTDOWN_COMMAND_ARGS")
|
||||
args = self._get_value_from_facts(self.SHUTDOWN_COMMAND_ARGS, distribution, self.DEFAULT_SHUTDOWN_COMMAND_ARGS)
|
||||
# Convert seconds to minutes. If less that 60, set it to 0.
|
||||
delay_sec = self.delay
|
||||
shutdown_message = self._task.args.get("msg", self.DEFAULT_SHUTDOWN_MESSAGE)
|
||||
|
|
@ -154,8 +158,8 @@ class ActionModule(ActionBase):
|
|||
af = args.format(delay_sec=delay_sec, delay_min=delay_sec // 60, message=shutdown_message)
|
||||
return f"{full_path[0]} {af}"
|
||||
|
||||
def perform_shutdown(self, task_vars, distribution):
|
||||
result = {}
|
||||
def perform_shutdown(self, task_vars, distribution) -> dict[str, t.Any]:
|
||||
result: dict[str, t.Any] = {}
|
||||
shutdown_result = {}
|
||||
shutdown_command_exec = self.get_shutdown_command(task_vars, distribution)
|
||||
|
||||
|
|
@ -176,7 +180,7 @@ class ActionModule(ActionBase):
|
|||
result["failed"] = True
|
||||
result["shutdown"] = False
|
||||
result["msg"] = (
|
||||
f"Shutdown command failed. Error was {fmt(shutdown_result, 'stdout')}, {fmt(shutdown_result, 'stderr')}"
|
||||
f"Shutdown command failed. Error was {to_native(shutdown_result['stdout'])}, {to_native(shutdown_result['stderr'])}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
|
@ -184,7 +188,7 @@ class ActionModule(ActionBase):
|
|||
result["shutdown_command"] = shutdown_command_exec
|
||||
return result
|
||||
|
||||
def run(self, tmp=None, task_vars=None):
|
||||
def run(self, tmp: str | None = None, task_vars: dict[str, t.Any] | None = None) -> dict[str, t.Any]:
|
||||
self._supports_check_mode = True
|
||||
self._supports_async = True
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue