1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2026-02-04 07:51:50 +00:00

Add type hints to action and test plugins and to plugin utils; fix some bugs, and improve input validation (#11167)

* Add type hints to action and test plugins and to plugin utils. Also fix some bugs and add proper input validation.

* Combine lines.

Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com>

* Extend changelog fragment.

* Move task_vars initialization up.

---------

Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com>
This commit is contained in:
Felix Fontein 2025-11-22 22:52:21 +01:00 committed by GitHub
parent 4517b86ed4
commit 19757b3a4c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 194 additions and 96 deletions

View file

@ -0,0 +1,15 @@
minor_changes:
- "iptables_state action plugin - add type hints (https://github.com/ansible-collections/community.general/pull/11167)."
- "shutdown action plugin - add type hints (https://github.com/ansible-collections/community.general/pull/11167)."
- "ansible_type plugin utils - add type hints (https://github.com/ansible-collections/community.general/pull/11167)."
- "keys_filter.py plugin utils - add type hints (https://github.com/ansible-collections/community.general/pull/11167)."
- "unsafe.py plugin utils - add type hints (https://github.com/ansible-collections/community.general/pull/11167)."
- "a_module test plugin - add proper parameter checking and type hints (https://github.com/ansible-collections/community.general/pull/11167)."
- "ansible_type test plugin - add type hints (https://github.com/ansible-collections/community.general/pull/11167)."
- "fqdn_valid test plugin - add proper parameter checking, and add type hints (https://github.com/ansible-collections/community.general/pull/11167)."
bugfixes:
- "keys_filter.py plugin utils - fixed requirements check so that other sequences than lists and strings are checked,
and corrected broken formatting during error reporting
(https://github.com/ansible-collections/community.general/pull/11167)."
- "ansible_type test plugin - fix parameter checking (https://github.com/ansible-collections/community.general/pull/11167)."
- "ansible_type plugin utils - avoid potential concatenation of non-strings when ``alias`` has non-string values (https://github.com/ansible-collections/community.general/pull/11167)."

View file

@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
import time import time
import typing as t
from ansible.plugins.action import ActionBase from ansible.plugins.action import ActionBase
from ansible.errors import AnsibleActionFail, AnsibleConnectionFailure from ansible.errors import AnsibleActionFail, AnsibleConnectionFailure
@ -20,7 +21,7 @@ class ActionModule(ActionBase):
DEFAULT_SUDOABLE = True DEFAULT_SUDOABLE = True
@staticmethod @staticmethod
def msg_error__async_and_poll_not_zero(task_poll, task_async, max_timeout): def msg_error__async_and_poll_not_zero(task_poll, task_async, max_timeout) -> str:
return ( return (
"This module doesn't support async>0 and poll>0 when its 'state' param " "This module doesn't support async>0 and poll>0 when its 'state' param "
"is set to 'restored'. To enable its rollback feature (that needs the " "is set to 'restored'. To enable its rollback feature (that needs the "
@ -30,7 +31,7 @@ class ActionModule(ActionBase):
) )
@staticmethod @staticmethod
def msg_warning__no_async_is_no_rollback(task_poll, task_async, max_timeout): def msg_warning__no_async_is_no_rollback(task_poll, task_async, max_timeout) -> str:
return ( return (
"Attempts to restore iptables state without rollback in case of mistake " "Attempts to restore iptables state without rollback in case of mistake "
"may lead the ansible controller to loose access to the hosts and never " "may lead the ansible controller to loose access to the hosts and never "
@ -41,7 +42,7 @@ class ActionModule(ActionBase):
) )
@staticmethod @staticmethod
def msg_warning__async_greater_than_timeout(task_poll, task_async, max_timeout): def msg_warning__async_greater_than_timeout(task_poll, task_async, max_timeout) -> str:
return ( return (
"You attempt to restore iptables state with rollback in case of mistake, " "You attempt to restore iptables state with rollback in case of mistake, "
"but with settings that will lead this rollback to happen AFTER that the " "but with settings that will lead this rollback to happen AFTER that the "
@ -50,7 +51,9 @@ class ActionModule(ActionBase):
f"'ansible_timeout' (={max_timeout}) (recommended)." f"'ansible_timeout' (={max_timeout}) (recommended)."
) )
def _async_result(self, async_status_args, task_vars, timeout): def _async_result(
self, async_status_args: dict[str, t.Any], task_vars: dict[str, t.Any], timeout: int
) -> dict[str, t.Any]:
""" """
Retrieve results of the asynchronous task, and display them in place of Retrieve results of the asynchronous task, and display them in place of
the async wrapper results (those with the ansible_job_id key). the async wrapper results (those with the ansible_job_id key).
@ -81,10 +84,13 @@ class ActionModule(ActionBase):
return async_result return async_result
def run(self, tmp=None, task_vars=None): def run(self, tmp: str | None = None, task_vars: dict[str, t.Any] | None = None) -> dict[str, t.Any]:
self._supports_check_mode = True self._supports_check_mode = True
self._supports_async = True self._supports_async = True
if task_vars is None:
task_vars = {}
result = super().run(tmp, task_vars) result = super().run(tmp, task_vars)
del tmp # tmp no longer has any effect del tmp # tmp no longer has any effect

View file

@ -7,19 +7,25 @@
from __future__ import annotations from __future__ import annotations
import typing as t
from ansible.errors import AnsibleError, AnsibleConnectionFailure from ansible.errors import AnsibleError, AnsibleConnectionFailure
from ansible.module_utils.common.text.converters import to_native, to_text from ansible.module_utils.common.text.converters import to_native, to_text
from ansible.module_utils.common.collections import is_string from ansible.module_utils.common.collections import is_string
from ansible.plugins.action import ActionBase from ansible.plugins.action import ActionBase
from ansible.utils.display import Display from ansible.utils.display import Display
if t.TYPE_CHECKING:
class Distribution(t.TypedDict):
name: str
version: str
family: str
display = Display() display = Display()
def fmt(mapping, key):
return to_native(mapping[key]).strip()
class TimedOutException(Exception): class TimedOutException(Exception):
pass pass
@ -60,25 +66,23 @@ class ActionModule(ActionBase):
def delay(self): def delay(self):
return self._check_delay("delay", self.DEFAULT_PRE_SHUTDOWN_DELAY) return self._check_delay("delay", self.DEFAULT_PRE_SHUTDOWN_DELAY)
def _check_delay(self, key, default): def _check_delay(self, key: str, default: int) -> int:
"""Ensure that the value is positive or zero""" """Ensure that the value is positive or zero"""
value = int(self._task.args.get(key, default)) value = int(self._task.args.get(key, default))
if value < 0: if value < 0:
value = 0 value = 0
return value return value
def _get_value_from_facts(self, variable_name, distribution, default_value): @staticmethod
def _get_value_from_facts(data: dict[str, str], distribution: Distribution, default_value: str) -> str:
"""Get dist+version specific args first, then distribution, then family, lastly use default""" """Get dist+version specific args first, then distribution, then family, lastly use default"""
attr = getattr(self, variable_name) return data.get(
value = attr.get(
distribution["name"] + distribution["version"], distribution["name"] + distribution["version"],
attr.get(distribution["name"], attr.get(distribution["family"], getattr(self, default_value))), data.get(distribution["name"], data.get(distribution["family"], default_value)),
) )
return value
def get_distribution(self, task_vars): def get_distribution(self, task_vars: dict[str, t.Any]) -> Distribution:
# FIXME: only execute the module if we don't already have the facts we need # FIXME: only execute the module if we don't already have the facts we need
distribution = {}
display.debug(f"{self._task.action}: running setup module to get distribution") display.debug(f"{self._task.action}: running setup module to get distribution")
module_output = self._execute_module( module_output = self._execute_module(
task_vars=task_vars, module_name="ansible.legacy.setup", module_args={"gather_subset": "min"} task_vars=task_vars, module_name="ansible.legacy.setup", module_args={"gather_subset": "min"}
@ -86,20 +90,20 @@ class ActionModule(ActionBase):
try: try:
if module_output.get("failed", False): if module_output.get("failed", False):
raise AnsibleError( raise AnsibleError(
f"Failed to determine system distribution. {fmt(module_output, 'module_stdout')}, {fmt(module_output, 'module_stderr')}" f"Failed to determine system distribution. {to_native(module_output['module_stdout'])}, {to_native(module_output['module_stderr'])}"
) )
distribution["name"] = module_output["ansible_facts"]["ansible_distribution"].lower() distribution: Distribution = {
distribution["version"] = to_text( "name": module_output["ansible_facts"]["ansible_distribution"].lower(),
module_output["ansible_facts"]["ansible_distribution_version"].split(".")[0] "version": to_text(module_output["ansible_facts"]["ansible_distribution_version"].split(".")[0]),
) "family": to_text(module_output["ansible_facts"]["ansible_os_family"].lower()),
distribution["family"] = to_text(module_output["ansible_facts"]["ansible_os_family"].lower()) }
display.debug(f"{self._task.action}: distribution: {distribution}") display.debug(f"{self._task.action}: distribution: {distribution}")
return distribution return distribution
except KeyError as ke: except KeyError as ke:
raise AnsibleError(f'Failed to get distribution information. Missing "{ke.args[0]}" in output.') from ke raise AnsibleError(f'Failed to get distribution information. Missing "{ke.args[0]}" in output.') from ke
def get_shutdown_command(self, task_vars, distribution): def get_shutdown_command(self, task_vars: dict[str, t.Any], distribution: Distribution) -> str:
def find_command(command, find_search_paths): def find_command(command: str, find_search_paths: list[str]) -> list[str]:
display.debug( display.debug(
f'{self._task.action}: running find module looking in {find_search_paths} to get path for "{command}"' f'{self._task.action}: running find module looking in {find_search_paths} to get path for "{command}"'
) )
@ -111,7 +115,7 @@ class ActionModule(ActionBase):
) )
return [x["path"] for x in find_result["files"]] return [x["path"] for x in find_result["files"]]
shutdown_bin = self._get_value_from_facts("SHUTDOWN_COMMANDS", distribution, "DEFAULT_SHUTDOWN_COMMAND") shutdown_bin = self._get_value_from_facts(self.SHUTDOWN_COMMANDS, distribution, self.DEFAULT_SHUTDOWN_COMMAND)
default_search_paths = ["/sbin", "/usr/sbin", "/usr/local/sbin"] default_search_paths = ["/sbin", "/usr/sbin", "/usr/local/sbin"]
search_paths = self._task.args.get("search_paths", default_search_paths) search_paths = self._task.args.get("search_paths", default_search_paths)
@ -146,7 +150,7 @@ class ActionModule(ActionBase):
return f"{full_path[0]} poweroff" # done, since we cannot use args with systemd shutdown return f"{full_path[0]} poweroff" # done, since we cannot use args with systemd shutdown
# systemd case taken care of, here we add args to the command # systemd case taken care of, here we add args to the command
args = self._get_value_from_facts("SHUTDOWN_COMMAND_ARGS", distribution, "DEFAULT_SHUTDOWN_COMMAND_ARGS") args = self._get_value_from_facts(self.SHUTDOWN_COMMAND_ARGS, distribution, self.DEFAULT_SHUTDOWN_COMMAND_ARGS)
# Convert seconds to minutes. If less that 60, set it to 0. # Convert seconds to minutes. If less that 60, set it to 0.
delay_sec = self.delay delay_sec = self.delay
shutdown_message = self._task.args.get("msg", self.DEFAULT_SHUTDOWN_MESSAGE) shutdown_message = self._task.args.get("msg", self.DEFAULT_SHUTDOWN_MESSAGE)
@ -154,8 +158,8 @@ class ActionModule(ActionBase):
af = args.format(delay_sec=delay_sec, delay_min=delay_sec // 60, message=shutdown_message) af = args.format(delay_sec=delay_sec, delay_min=delay_sec // 60, message=shutdown_message)
return f"{full_path[0]} {af}" return f"{full_path[0]} {af}"
def perform_shutdown(self, task_vars, distribution): def perform_shutdown(self, task_vars, distribution) -> dict[str, t.Any]:
result = {} result: dict[str, t.Any] = {}
shutdown_result = {} shutdown_result = {}
shutdown_command_exec = self.get_shutdown_command(task_vars, distribution) shutdown_command_exec = self.get_shutdown_command(task_vars, distribution)
@ -176,7 +180,7 @@ class ActionModule(ActionBase):
result["failed"] = True result["failed"] = True
result["shutdown"] = False result["shutdown"] = False
result["msg"] = ( result["msg"] = (
f"Shutdown command failed. Error was {fmt(shutdown_result, 'stdout')}, {fmt(shutdown_result, 'stderr')}" f"Shutdown command failed. Error was {to_native(shutdown_result['stdout'])}, {to_native(shutdown_result['stderr'])}"
) )
return result return result
@ -184,7 +188,7 @@ class ActionModule(ActionBase):
result["shutdown_command"] = shutdown_command_exec result["shutdown_command"] = shutdown_command_exec
return result return result
def run(self, tmp=None, task_vars=None): def run(self, tmp: str | None = None, task_vars: dict[str, t.Any] | None = None) -> dict[str, t.Any]:
self._supports_check_mode = True self._supports_check_mode = True
self._supports_async = True self._supports_async = True

View file

@ -4,9 +4,11 @@
from __future__ import annotations from __future__ import annotations
from ansible.errors import AnsibleFilterError import typing as t
from collections.abc import Mapping from collections.abc import Mapping
from ansible.errors import AnsibleFilterError
try: try:
# Introduced with Data Tagging (https://github.com/ansible/ansible/pull/84621): # Introduced with Data Tagging (https://github.com/ansible/ansible/pull/84621):
from ansible.module_utils.datatag import native_type_name as _native_type_name from ansible.module_utils.datatag import native_type_name as _native_type_name
@ -16,7 +18,7 @@ except ImportError:
HAS_NATIVE_TYPE_NAME = False HAS_NATIVE_TYPE_NAME = False
def _atype(data, alias, *, use_native_type: bool = False): def _atype(data: t.Any, alias: Mapping, *, use_native_type: bool = False) -> str:
""" """
Returns the name of the type class. Returns the name of the type class.
""" """
@ -30,10 +32,10 @@ def _atype(data, alias, *, use_native_type: bool = False):
data_type = "dict" data_type = "dict"
elif data_type == "_AnsibleLazyTemplateList": elif data_type == "_AnsibleLazyTemplateList":
data_type = "list" data_type = "list"
return alias.get(data_type, data_type) return str(alias.get(data_type, data_type))
def _ansible_type(data, alias, *, use_native_type: bool = False): def _ansible_type(data: t.Any, alias: t.Any, *, use_native_type: bool = False) -> str:
""" """
Returns the Ansible data type. Returns the Ansible data type.
""" """
@ -42,21 +44,20 @@ def _ansible_type(data, alias, *, use_native_type: bool = False):
alias = {} alias = {}
if not isinstance(alias, Mapping): if not isinstance(alias, Mapping):
msg = "The argument alias must be a dictionary. %s is %s" raise AnsibleFilterError(f"The argument alias must be a dictionary. {alias!r} is {type(alias)}")
raise AnsibleFilterError(msg % (alias, type(alias)))
data_type = _atype(data, alias, use_native_type=use_native_type) data_type = _atype(data, alias, use_native_type=use_native_type)
if data_type == "list" and len(data) > 0: if data_type == "list" and len(data) > 0:
items = [_atype(i, alias, use_native_type=use_native_type) for i in data] items = {_atype(i, alias, use_native_type=use_native_type) for i in data}
items_type = "|".join(sorted(set(items))) items_type = "|".join(sorted(items))
return f"{data_type}[{items_type}]" return f"{data_type}[{items_type}]"
if data_type == "dict" and len(data) > 0: if data_type == "dict" and len(data) > 0:
keys = [_atype(i, alias, use_native_type=use_native_type) for i in data.keys()] keys = {_atype(i, alias, use_native_type=use_native_type) for i in data.keys()}
vals = [_atype(i, alias, use_native_type=use_native_type) for i in data.values()] vals = {_atype(i, alias, use_native_type=use_native_type) for i in data.values()}
keys_type = "|".join(sorted(set(keys))) keys_type = "|".join(sorted(keys))
vals_type = "|".join(sorted(set(vals))) vals_type = "|".join(sorted(vals))
return f"{data_type}[{keys_type}, {vals_type}]" return f"{data_type}[{keys_type}, {vals_type}]"
return data_type return data_type

View file

@ -6,12 +6,17 @@
from __future__ import annotations from __future__ import annotations
import re import re
import typing as t
from ansible.errors import AnsibleFilterError
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from ansible.errors import AnsibleFilterError
from ansible.module_utils.common.collections import is_sequence
def _keys_filter_params(data, matching_parameter):
def _keys_filter_params(
data: t.Any, matching_parameter: t.Any
) -> tuple[Sequence[Mapping[str, t.Any]], t.Literal["equal", "starts_with", "ends_with", "regex"]]:
"""test parameters: """test parameters:
* data must be a list of dictionaries. All keys must be strings. * data must be a list of dictionaries. All keys must be strings.
* matching_parameter is member of a list. * matching_parameter is member of a list.
@ -21,27 +26,27 @@ def _keys_filter_params(data, matching_parameter):
ml = ["equal", "starts_with", "ends_with", "regex"] ml = ["equal", "starts_with", "ends_with", "regex"]
if not isinstance(data, Sequence): if not isinstance(data, Sequence):
msg = "First argument must be a list. %s is %s" msg = f"First argument must be a list. {data!r} is {type(data)}"
raise AnsibleFilterError(msg % (data, type(data))) raise AnsibleFilterError(msg)
for elem in data: for elem in data:
if not isinstance(elem, Mapping): if not isinstance(elem, Mapping):
msg = "The data items must be dictionaries. %s is %s" msg = f"The data items must be dictionaries. {elem} is {type(elem)}"
raise AnsibleFilterError(msg % (elem, type(elem))) raise AnsibleFilterError(msg)
for elem in data: for elem in data:
if not all(isinstance(item, str) for item in elem.keys()): if not all(isinstance(item, str) for item in elem.keys()):
msg = "Top level keys must be strings. keys: %s" msg = f"Top level keys must be strings. keys: {list(elem.keys())}"
raise AnsibleFilterError(msg % elem.keys()) raise AnsibleFilterError(msg)
if mp not in ml: if mp not in ml:
msg = "The matching_parameter must be one of %s. matching_parameter=%s" msg = f"The matching_parameter must be one of {ml}. matching_parameter={mp!r}"
raise AnsibleFilterError(msg % (ml, mp)) raise AnsibleFilterError(msg)
return return data, mp
def _keys_filter_target_str(target, matching_parameter): def _keys_filter_target_str(target: t.Any, matching_parameter: t.Any) -> tuple[str, ...] | re.Pattern:
""" """
Test: Test:
* target is a non-empty string or list. * target is a non-empty string or list.
@ -54,18 +59,18 @@ def _keys_filter_target_str(target, matching_parameter):
""" """
if not isinstance(target, Sequence): if not isinstance(target, Sequence):
msg = "The target must be a string or a list. target is %s." msg = f"The target must be a string or a list. target is {type(target)}."
raise AnsibleFilterError(msg % type(target)) raise AnsibleFilterError(msg)
if len(target) == 0: if len(target) == 0:
msg = "The target can't be empty." msg = "The target can't be empty."
raise AnsibleFilterError(msg) raise AnsibleFilterError(msg)
if isinstance(target, list): if is_sequence(target):
for elem in target: for elem in target:
if not isinstance(elem, str): if not isinstance(elem, str):
msg = "The target items must be strings. %s is %s" msg = f"The target items must be strings. {elem!r} is {type(elem)}"
raise AnsibleFilterError(msg % (elem, type(elem))) raise AnsibleFilterError(msg)
if matching_parameter == "regex": if matching_parameter == "regex":
if isinstance(target, str): if isinstance(target, str):
@ -77,19 +82,19 @@ def _keys_filter_target_str(target, matching_parameter):
else: else:
r = target[0] r = target[0]
try: try:
tt = re.compile(r) return re.compile(r)
except re.error as e: except re.error as e:
msg = "The target must be a valid regex if matching_parameter=regex. target is %s" msg = f"The target must be a valid regex if matching_parameter=regex. target is {r}"
raise AnsibleFilterError(msg % r) from e raise AnsibleFilterError(msg) from e
elif isinstance(target, str): elif isinstance(target, str):
tt = (target,) return (target,)
else: else:
tt = tuple(set(target)) return tuple(set(target))
return tt
def _keys_filter_target_dict(target, matching_parameter): def _keys_filter_target_dict(
target: t.Any, matching_parameter: t.Any
) -> list[tuple[str, str]] | list[tuple[re.Pattern, str]]:
""" """
Test: Test:
* target is a list of dictionaries with attributes 'after' and 'before'. * target is a list of dictionaries with attributes 'after' and 'before'.
@ -101,8 +106,8 @@ def _keys_filter_target_dict(target, matching_parameter):
""" """
if not isinstance(target, list): if not isinstance(target, list):
msg = "The target must be a list. target is %s." msg = f"The target must be a list. target is {target!r} of type {type(target)}."
raise AnsibleFilterError(msg % (target, type(target))) raise AnsibleFilterError(msg)
if len(target) == 0: if len(target) == 0:
msg = "The target can't be empty." msg = "The target can't be empty."
@ -110,25 +115,25 @@ def _keys_filter_target_dict(target, matching_parameter):
for elem in target: for elem in target:
if not isinstance(elem, Mapping): if not isinstance(elem, Mapping):
msg = "The target items must be dictionaries. %s is %s" msg = f"The target items must be dictionaries. {elem!r}%s is {type(elem)}"
raise AnsibleFilterError(msg % (elem, type(elem))) raise AnsibleFilterError(msg)
if not all(k in elem for k in ("before", "after")): if not all(k in elem for k in ("before", "after")):
msg = "All dictionaries in target must include attributes: after, before." msg = "All dictionaries in target must include attributes: after, before."
raise AnsibleFilterError(msg) raise AnsibleFilterError(msg)
if not isinstance(elem["before"], str): if not isinstance(elem["before"], str):
msg = "The attributes before must be strings. %s is %s" msg = f"The attributes before must be strings. {elem['before']!r} is {type(elem['before'])}"
raise AnsibleFilterError(msg % (elem["before"], type(elem["before"]))) raise AnsibleFilterError(msg)
if not isinstance(elem["after"], str): if not isinstance(elem["after"], str):
msg = "The attributes after must be strings. %s is %s" msg = f"The attributes after must be strings. {elem['after']!r} is {type(elem['after'])}"
raise AnsibleFilterError(msg % (elem["after"], type(elem["after"]))) raise AnsibleFilterError(msg)
before = [d["before"] for d in target] before: list[str] = [d["before"] for d in target]
after = [d["after"] for d in target] after: list[str] = [d["after"] for d in target]
if matching_parameter == "regex": if matching_parameter == "regex":
try: try:
tr = map(re.compile, before) tr = map(re.compile, before)
tz = list(zip(tr, after)) return list(zip(tr, after))
except re.error as e: except re.error as e:
msg = ( msg = (
"The attributes before must be valid regex if matching_parameter=regex." "The attributes before must be valid regex if matching_parameter=regex."
@ -136,6 +141,4 @@ def _keys_filter_target_dict(target, matching_parameter):
) )
raise AnsibleFilterError(msg % before) from e raise AnsibleFilterError(msg % before) from e
else: else:
tz = list(zip(before, after)) return list(zip(before, after))
return tz

View file

@ -5,8 +5,9 @@
from __future__ import annotations from __future__ import annotations
import re import re
import typing as t
from collections.abc import Mapping, Set from collections.abc import Mapping, Sequence, Set
from ansible.module_utils.common.collections import is_sequence from ansible.module_utils.common.collections import is_sequence
from ansible.utils.unsafe_proxy import ( from ansible.utils.unsafe_proxy import (
AnsibleUnsafe, AnsibleUnsafe,
@ -17,14 +18,54 @@ _RE_TEMPLATE_CHARS = re.compile("[{}]")
_RE_TEMPLATE_CHARS_BYTES = re.compile(b"[{}]") _RE_TEMPLATE_CHARS_BYTES = re.compile(b"[{}]")
def make_unsafe(value): @t.overload
def make_unsafe(value: None) -> None: ...
@t.overload
def make_unsafe(value: Mapping) -> dict: ...
@t.overload
def make_unsafe(value: Set) -> set: ...
@t.overload
def make_unsafe(value: tuple) -> tuple: ...
@t.overload
def make_unsafe(value: list) -> list: ...
@t.overload
def make_unsafe(value: Sequence) -> Sequence: ...
@t.overload
def make_unsafe(value: str) -> str: ...
@t.overload
def make_unsafe(value: bool) -> bool: ...
@t.overload
def make_unsafe(value: int) -> int: ...
@t.overload
def make_unsafe(value: float) -> float: ...
def make_unsafe(value: t.Any) -> t.Any:
if value is None or isinstance(value, AnsibleUnsafe): if value is None or isinstance(value, AnsibleUnsafe):
return value return value
if isinstance(value, Mapping): if isinstance(value, Mapping):
return {make_unsafe(key): make_unsafe(val) for key, val in value.items()} return {make_unsafe(key): make_unsafe(val) for key, val in value.items()}
elif isinstance(value, Set): elif isinstance(value, Set):
return set(make_unsafe(elt) for elt in value) return {make_unsafe(elt) for elt in value}
elif is_sequence(value): elif is_sequence(value):
return type(value)(make_unsafe(elt) for elt in value) return type(value)(make_unsafe(elt) for elt in value)
elif isinstance(value, bytes): elif isinstance(value, bytes):

View file

@ -38,7 +38,11 @@ _value:
type: boolean type: boolean
""" """
import typing as t
from collections.abc import Callable
from ansible.plugins.loader import action_loader, module_loader from ansible.plugins.loader import action_loader, module_loader
from ansible.errors import AnsibleFilterError
try: try:
from ansible.errors import AnsiblePluginRemovedError from ansible.errors import AnsiblePluginRemovedError
@ -46,12 +50,14 @@ except ImportError:
AnsiblePluginRemovedError = Exception # type: ignore AnsiblePluginRemovedError = Exception # type: ignore
def a_module(term): def a_module(term: t.Any) -> bool:
""" """
Example: Example:
- 'community.general.ufw' is community.general.a_module - 'community.general.ufw' is community.general.a_module
- 'community.general.does_not_exist' is not community.general.a_module - 'community.general.does_not_exist' is not community.general.a_module
""" """
if not isinstance(term, str):
raise AnsibleFilterError(f"Parameter must be a string, got {term!r} of type {type(term)}")
try: try:
for loader in (action_loader, module_loader): for loader in (action_loader, module_loader):
data = loader.find_plugin(term) data = loader.find_plugin(term)
@ -65,7 +71,7 @@ def a_module(term):
class TestModule: class TestModule:
"""Ansible jinja2 tests""" """Ansible jinja2 tests"""
def tests(self): def tests(self) -> dict[str, Callable]:
return { return {
"a_module": a_module, "a_module": a_module,
} }

View file

@ -222,19 +222,21 @@ _value:
type: bool type: bool
""" """
from collections.abc import Sequence import typing as t
from collections.abc import Callable, Sequence
from ansible.errors import AnsibleFilterError from ansible.errors import AnsibleFilterError
from ansible_collections.community.general.plugins.plugin_utils.ansible_type import _ansible_type from ansible_collections.community.general.plugins.plugin_utils.ansible_type import _ansible_type
def ansible_type(data, dtype, alias=None): def ansible_type(data: t.Any, dtype: t.Any, alias: t.Any = None) -> bool:
"""Validates data type""" """Validates data type"""
if not isinstance(dtype, Sequence): if not isinstance(dtype, Sequence):
msg = "The argument dtype must be a string or a list. dtype is %s." msg = f"The argument dtype must be a string or a list. dtype is {dtype!r} of type {type(dtype)}."
raise AnsibleFilterError(msg % (dtype, type(dtype))) raise AnsibleFilterError(msg)
data_types: Sequence
if isinstance(dtype, str): if isinstance(dtype, str):
data_types = [dtype] data_types = [dtype]
else: else:
@ -245,5 +247,5 @@ def ansible_type(data, dtype, alias=None):
class TestModule: class TestModule:
def tests(self): def tests(self) -> dict[str, Callable]:
return {"ansible_type": ansible_type} return {"ansible_type": ansible_type}

View file

@ -63,7 +63,10 @@ _value:
type: bool type: bool
""" """
from ansible.errors import AnsibleError import typing as t
from collections.abc import Callable
from ansible.errors import AnsibleFilterError
ANOTHER_LIBRARY_IMPORT_ERROR: ImportError | None ANOTHER_LIBRARY_IMPORT_ERROR: ImportError | None
try: try:
@ -74,7 +77,7 @@ else:
ANOTHER_LIBRARY_IMPORT_ERROR = None ANOTHER_LIBRARY_IMPORT_ERROR = None
def fqdn_valid(name, min_labels=1, allow_underscores=False): def fqdn_valid(name: t.Any, min_labels: t.Any = 1, allow_underscores: t.Any = False) -> bool:
""" """
Example: Example:
- 'srv.example.com' is community.general.fqdn_valid - 'srv.example.com' is community.general.fqdn_valid
@ -82,7 +85,22 @@ def fqdn_valid(name, min_labels=1, allow_underscores=False):
""" """
if ANOTHER_LIBRARY_IMPORT_ERROR: if ANOTHER_LIBRARY_IMPORT_ERROR:
raise AnsibleError("Python package fqdn must be installed to use this test.") from ANOTHER_LIBRARY_IMPORT_ERROR raise AnsibleFilterError(
"Python package fqdn must be installed to use this test."
) from ANOTHER_LIBRARY_IMPORT_ERROR
if not isinstance(name, str):
raise AnsibleFilterError(f"The name parameter must be a string, got {name!r} of type {type(name)}")
if not isinstance(min_labels, int):
raise AnsibleFilterError(
f"The min_labels parameter must be an integer, got {min_labels!r} of type {type(min_labels)}"
)
if not isinstance(allow_underscores, bool):
raise AnsibleFilterError(
f"The allow_underscores parameter must be a boolean, got {allow_underscores!r} of type {type(allow_underscores)}"
)
fobj = FQDN(name, min_labels=min_labels, allow_underscores=allow_underscores) fobj = FQDN(name, min_labels=min_labels, allow_underscores=allow_underscores)
return fobj.is_valid return fobj.is_valid
@ -93,7 +111,7 @@ class TestModule:
https://pypi.org/project/fqdn/ https://pypi.org/project/fqdn/
""" """
def tests(self): def tests(self) -> dict[str, Callable]:
return { return {
"fqdn_valid": fqdn_valid, "fqdn_valid": fqdn_valid,
} }

View file

@ -9,4 +9,5 @@ plugins/modules/rhevm.py validate-modules:parameter-state-invalid-choice
plugins/modules/udm_user.py import-3.11 # Uses deprecated stdlib library 'crypt' plugins/modules/udm_user.py import-3.11 # Uses deprecated stdlib library 'crypt'
plugins/modules/udm_user.py import-3.12 # Uses deprecated stdlib library 'crypt' plugins/modules/udm_user.py import-3.12 # Uses deprecated stdlib library 'crypt'
plugins/modules/xfconf.py validate-modules:return-syntax-error plugins/modules/xfconf.py validate-modules:return-syntax-error
plugins/plugin_utils/unsafe.py pep8:E704
tests/unit/plugins/modules/test_gio_mime.yaml no-smart-quotes tests/unit/plugins/modules/test_gio_mime.yaml no-smart-quotes

View file

@ -9,4 +9,5 @@ plugins/modules/rhevm.py validate-modules:parameter-state-invalid-choice
plugins/modules/udm_user.py import-3.11 # Uses deprecated stdlib library 'crypt' plugins/modules/udm_user.py import-3.11 # Uses deprecated stdlib library 'crypt'
plugins/modules/udm_user.py import-3.12 # Uses deprecated stdlib library 'crypt' plugins/modules/udm_user.py import-3.12 # Uses deprecated stdlib library 'crypt'
plugins/modules/xfconf.py validate-modules:return-syntax-error plugins/modules/xfconf.py validate-modules:return-syntax-error
plugins/plugin_utils/unsafe.py pep8:E704
tests/unit/plugins/modules/test_gio_mime.yaml no-smart-quotes tests/unit/plugins/modules/test_gio_mime.yaml no-smart-quotes