1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2026-04-13 15:35:08 +00:00

Add basic typing for module_utils.

This commit is contained in:
Felix Fontein 2025-11-26 22:37:16 +01:00
parent fb2f34ba85
commit e74033eac1
48 changed files with 209 additions and 142 deletions

View file

@ -11,9 +11,13 @@ import os
import stat
import time
import fcntl
import typing as t
from contextlib import contextmanager
if t.TYPE_CHECKING:
from io import TextIOWrapper
class LockTimeout(Exception):
pass
@ -27,11 +31,13 @@ class FileLock:
unwanted and/or unexpected behaviour
"""
def __init__(self):
self.lockfd = None
def __init__(self) -> None:
self.lockfd: TextIOWrapper | None = None
@contextmanager
def lock_file(self, path, tmpdir, lock_timeout=None):
def lock_file(
self, path: os.PathLike, tmpdir: os.PathLike, lock_timeout: int | float | None = None
) -> t.Generator[None]:
"""
Context for lock acquisition
"""
@ -41,7 +47,9 @@ class FileLock:
finally:
self.unlock()
def set_lock(self, path, tmpdir, lock_timeout=None):
def set_lock(
self, path: os.PathLike, tmpdir: os.PathLike, lock_timeout: int | float | None = None
) -> t.Literal[True]:
"""
Create a lock file based on path with flock to prevent other processes
using given path.
@ -61,13 +69,13 @@ class FileLock:
self.lockfd = open(lock_path, "w")
if lock_timeout <= 0:
if lock_timeout is not None and lock_timeout <= 0:
fcntl.flock(self.lockfd, fcntl.LOCK_EX | fcntl.LOCK_NB)
os.chmod(lock_path, stat.S_IWRITE | stat.S_IREAD)
return True
if lock_timeout:
e_secs = 0
e_secs: float = 0
while e_secs < lock_timeout:
try:
fcntl.flock(self.lockfd, fcntl.LOCK_EX | fcntl.LOCK_NB)
@ -86,7 +94,7 @@ class FileLock:
return True
def unlock(self):
def unlock(self) -> t.Literal[True]:
"""
Make sure lock file is available for everyone and Unlock the file descriptor
locked by set_lock

View file

@ -14,7 +14,7 @@ from __future__ import annotations
import os
import json
import traceback
from ansible.module_utils.basic import env_fallback
from ansible.module_utils.basic import env_fallback, AnsibleModule
try:
import footmark
@ -200,7 +200,7 @@ def get_profile(params):
return params
def ecs_connect(module):
def ecs_connect(module: AnsibleModule):
"""Return an ecs connection"""
ecs_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -214,7 +214,7 @@ def ecs_connect(module):
return ecs
def slb_connect(module):
def slb_connect(module: AnsibleModule):
"""Return an slb connection"""
slb_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -228,7 +228,7 @@ def slb_connect(module):
return slb
def dns_connect(module):
def dns_connect(module: AnsibleModule):
"""Return an dns connection"""
dns_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -242,7 +242,7 @@ def dns_connect(module):
return dns
def vpc_connect(module):
def vpc_connect(module: AnsibleModule):
"""Return an vpc connection"""
vpc_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -256,7 +256,7 @@ def vpc_connect(module):
return vpc
def rds_connect(module):
def rds_connect(module: AnsibleModule):
"""Return an rds connection"""
rds_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -270,7 +270,7 @@ def rds_connect(module):
return rds
def ess_connect(module):
def ess_connect(module: AnsibleModule):
"""Return an ess connection"""
ess_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -284,7 +284,7 @@ def ess_connect(module):
return ess
def sts_connect(module):
def sts_connect(module: AnsibleModule):
"""Return an sts connection"""
sts_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -298,7 +298,7 @@ def sts_connect(module):
return sts
def ram_connect(module):
def ram_connect(module: AnsibleModule):
"""Return an ram connection"""
ram_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -312,7 +312,7 @@ def ram_connect(module):
return ram
def market_connect(module):
def market_connect(module: AnsibleModule):
"""Return an market connection"""
market_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.

View file

@ -7,6 +7,8 @@ from __future__ import annotations
import re
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
__state_map = {"present": "--install", "absent": "--uninstall"}
@ -21,7 +23,7 @@ def __map_channel(channel_name):
return __channel_map[channel_name]
def sdkmanager_runner(module, **kwargs):
def sdkmanager_runner(module: AnsibleModule, **kwargs):
return CmdRunner(
module,
command="sdkmanager",
@ -78,7 +80,7 @@ class AndroidSdkManager:
r"the packages they depend on were not accepted"
)
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.runner = sdkmanager_runner(module)
def get_installed_packages(self):

View file

@ -4,6 +4,7 @@
from __future__ import annotations
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes
import re
import os
@ -28,9 +29,9 @@ class BtrfsCommands:
Provides access to a subset of the Btrfs command line
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.__module = module
self.__btrfs = self.__module.get_bin_path("btrfs", required=True)
self.__btrfs: str = self.__module.get_bin_path("btrfs", required=True)
def filesystem_show(self):
command = f"{self.__btrfs} filesystem show -d"
@ -106,10 +107,10 @@ class BtrfsInfoProvider:
Utility providing details of the currently available btrfs filesystems
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.__module = module
self.__btrfs_api = BtrfsCommands(module)
self.__findmnt_path = self.__module.get_bin_path("findmnt", required=True)
self.__findmnt_path: str = self.__module.get_bin_path("findmnt", required=True)
def get_filesystems(self):
filesystems = self.__btrfs_api.filesystem_show()
@ -255,7 +256,7 @@ class BtrfsFilesystem:
Wrapper class providing convenience methods for inspection of a btrfs filesystem
"""
def __init__(self, info, provider, module):
def __init__(self, info, provider, module: AnsibleModule) -> None:
self.__provider = provider
# constant for module execution
@ -415,7 +416,7 @@ class BtrfsFilesystemsProvider:
Provides methods to query available btrfs filesystems
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.__module = module
self.__provider = BtrfsInfoProvider(module)
self.__filesystems = None

View file

@ -6,6 +6,7 @@ from __future__ import annotations
import os
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.collections import is_sequence
from ansible.module_utils.common.locale import get_best_parsable_locale
from ansible_collections.community.general.plugins.module_utils import cmd_runner_fmt
@ -77,7 +78,7 @@ class CmdRunner:
def __init__(
self,
module,
module: AnsibleModule,
command,
arg_formats=None,
default_args_order=(),

View file

@ -12,6 +12,7 @@ import typing as t
from urllib import error as urllib_error
from urllib.parse import urlencode
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.urls import open_url
@ -120,7 +121,7 @@ class _ConsulModule:
operational_attributes: set[str] = set()
params: dict[str, t.Any] = {}
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self._module = module
self.params = _normalize_params(module.params, module.argument_spec)
self.api_params = {

View file

@ -6,11 +6,26 @@
from __future__ import annotations
import csv
import typing as t
from io import StringIO
from ansible.module_utils.common.text.converters import to_native
if t.TYPE_CHECKING:
from collections.abc import Sequence
class DialectParamsOrNone(t.TypedDict):
delimiter: t.NotRequired[str | None]
doublequote: t.NotRequired[bool | None]
escapechar: t.NotRequired[str | None]
lineterminator: t.NotRequired[str | None]
quotechar: t.NotRequired[str | None]
quoting: t.NotRequired[int | None]
skipinitialspace: t.NotRequired[bool | None]
strict: t.NotRequired[bool | None]
class CustomDialectFailureError(Exception):
pass
@ -22,7 +37,7 @@ class DialectNotAvailableError(Exception):
CSVError = csv.Error
def initialize_dialect(dialect, **kwargs):
def initialize_dialect(dialect: str, **kwargs: t.Unpack[DialectParamsOrNone]) -> str:
# Add Unix dialect from Python 3
class unix_dialect(csv.Dialect):
"""Describe the usual properties of Unix-generated CSV files."""
@ -43,7 +58,7 @@ def initialize_dialect(dialect, **kwargs):
dialect_params = {k: v for k, v in kwargs.items() if v is not None}
if dialect_params:
try:
csv.register_dialect("custom", dialect, **dialect_params)
csv.register_dialect("custom", dialect, **dialect_params) # type: ignore
except TypeError as e:
raise CustomDialectFailureError(f"Unable to create custom dialect: {e}") from e
dialect = "custom"
@ -51,7 +66,7 @@ def initialize_dialect(dialect, **kwargs):
return dialect
def read_csv(data, dialect, fieldnames=None):
def read_csv(data: str, dialect: str, fieldnames: Sequence[str] | None = None) -> csv.DictReader:
BOM = "\ufeff"
data = to_native(data, errors="surrogate_or_strict")
if data.startswith(BOM):

View file

@ -15,6 +15,8 @@ from __future__ import annotations
import re
from ansible.module_utils.basic import AnsibleModule
# Input patterns for is_input_dangerous function:
#
@ -162,7 +164,7 @@ def is_input_dangerous(string):
return any(pattern.search(string) for pattern in (PATTERN_1, PATTERN_2, PATTERN_3))
def check_input(module, *args):
def check_input(module: AnsibleModule, *args) -> None:
"""Wrapper for is_input_dangerous function."""
needs_to_check = args

View file

@ -8,15 +8,15 @@ from __future__ import annotations
import datetime as _datetime
def ensure_timezone_info(value):
def ensure_timezone_info(value: _datetime.datetime) -> _datetime.datetime:
if value.tzinfo is not None:
return value
return value.astimezone(_datetime.timezone.utc)
def fromtimestamp(value):
def fromtimestamp(value: int | float) -> _datetime.datetime:
return _datetime.datetime.fromtimestamp(value, tz=_datetime.timezone.utc)
def now():
def now() -> _datetime.datetime:
return _datetime.datetime.now(tz=_datetime.timezone.utc)

View file

@ -9,7 +9,7 @@ from __future__ import annotations
import traceback
from contextlib import contextmanager
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.basic import missing_required_lib, AnsibleModule
_deps = dict()
@ -47,7 +47,7 @@ class _Dependency:
def failed(self):
return self.state == 1
def validate(self, module):
def validate(self, module: AnsibleModule) -> None:
if self.failed:
module.fail_json(msg=self.message, exception=self.trace)
@ -86,14 +86,14 @@ def _select_names(spec):
return dep_names
def validate(module, spec=None):
def validate(module: AnsibleModule, spec=None) -> None:
for dep in _select_names(spec):
_deps[dep].validate(module)
def failed(spec=None):
def failed(spec=None) -> bool:
return any(_deps[d].failed for d in _select_names(spec))
def clear():
def clear() -> None:
_deps.clear()

View file

@ -24,7 +24,6 @@ import os
import re
import traceback
# (TODO: remove AnsibleModule from next line!)
from ansible.module_utils.basic import AnsibleModule, missing_required_lib # noqa: F401, pylint: disable=unused-import
from os.path import expanduser
from uuid import UUID
@ -57,7 +56,7 @@ class DimensionDataModule:
The base class containing common functionality used by Dimension Data modules for Ansible.
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
"""
Create a new DimensionDataModule.

View file

@ -6,6 +6,7 @@ from __future__ import annotations
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.dict_transformations import dict_merge
from ansible_collections.community.general.plugins.module_utils import cmd_runner_fmt
from ansible_collections.community.general.plugins.module_utils.python_runner import PythonRunner
@ -81,7 +82,7 @@ _args_menu = dict(
class _DjangoRunner(PythonRunner):
def __init__(self, module, arg_formats=None, **kwargs):
def __init__(self, module: AnsibleModule, arg_formats=None, **kwargs) -> None:
arg_fmts = dict(arg_formats) if arg_formats else {}
arg_fmts.update(_django_std_arg_fmts)

View file

@ -6,6 +6,7 @@ from __future__ import annotations
import json
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.urls import fetch_url
@ -21,7 +22,7 @@ class GandiLiveDNSAPI:
attribute_map = {"record": "rrset_name", "type": "rrset_type", "ttl": "rrset_ttl", "values": "rrset_values"}
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.api_key = module.params["api_key"]
self.personal_access_token = module.params["personal_access_token"]

View file

@ -4,6 +4,8 @@
from __future__ import annotations
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
@ -14,7 +16,7 @@ _state_map = {
}
def gconftool2_runner(module, **kwargs):
def gconftool2_runner(module: AnsibleModule, **kwargs) -> CmdRunner:
return CmdRunner(
module,
command="gconftool-2",

View file

@ -4,10 +4,12 @@
from __future__ import annotations
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
def gio_mime_runner(module, **kwargs):
def gio_mime_runner(module: AnsibleModule, **kwargs):
return CmdRunner(
module,
command=["gio"],

View file

@ -7,7 +7,7 @@ from __future__ import annotations
import typing as t
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.basic import missing_required_lib, AnsibleModule
from ansible_collections.community.general.plugins.module_utils.version import LooseVersion
@ -76,7 +76,7 @@ def find_group(gitlab_instance, identifier):
return group
def ensure_gitlab_package(module, min_version=None):
def ensure_gitlab_package(module: AnsibleModule, min_version=None) -> None:
if not HAS_GITLAB_PACKAGE:
module.fail_json(
msg=missing_required_lib("python-gitlab", url="https://python-gitlab.readthedocs.io/en/stable/"),
@ -92,7 +92,7 @@ def ensure_gitlab_package(module, min_version=None):
)
def gitlab_authentication(module, min_version=None):
def gitlab_authentication(module: AnsibleModule, min_version=None) -> gitlab.Gitlab:
ensure_gitlab_package(module, min_version=min_version)
gitlab_url = module.params["api_url"]
@ -121,7 +121,7 @@ def gitlab_authentication(module, min_version=None):
private_token=gitlab_token,
oauth_token=gitlab_oauth_token,
job_token=gitlab_job_token,
api_version=4,
api_version="4",
)
gitlab_instance.auth()
except (gitlab.exceptions.GitlabAuthenticationError, gitlab.exceptions.GitlabGetError) as e:
@ -158,7 +158,7 @@ def filter_returned_variables(gitlab_variables):
return existing_variables
def vars_to_variables(vars, module):
def vars_to_variables(vars, module: AnsibleModule):
# transform old vars to new variables structure
variables = list()
for item, value in vars.items():

View file

@ -6,7 +6,7 @@ from __future__ import annotations
import traceback
from ansible.module_utils.basic import env_fallback, missing_required_lib
from ansible.module_utils.basic import env_fallback, missing_required_lib, AnsibleModule
HAS_HEROKU = False
HEROKU_IMP_ERR = None
@ -19,7 +19,7 @@ except ImportError:
class HerokuHelper:
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.check_lib()
self.api_key = module.params["api_key"]

View file

@ -7,9 +7,13 @@ from __future__ import annotations
import os
import re
import typing as t
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def _create_regex_group_complement(s):
def _create_regex_group_complement(s: str) -> re.Pattern:
lines = (line.strip() for line in s.split("\n") if line.strip())
chars = [_f for _f in (line.split("#")[0].strip() for line in lines) if _f]
group = rf"[^{''.join(chars)}]"
@ -52,7 +56,7 @@ class HomebrewValidate:
# class validations -------------------------------------------- {{{
@classmethod
def valid_path(cls, path):
def valid_path(cls, path: list[str] | str) -> bool:
"""
`path` must be one of:
- list of paths
@ -77,7 +81,7 @@ class HomebrewValidate:
return all(cls.valid_brew_path(path_) for path_ in paths)
@classmethod
def valid_brew_path(cls, brew_path):
def valid_brew_path(cls, brew_path: str | None) -> bool:
"""
`brew_path` must be one of:
- None
@ -95,7 +99,7 @@ class HomebrewValidate:
return isinstance(brew_path, str) and not cls.INVALID_BREW_PATH_REGEX.search(brew_path)
@classmethod
def valid_package(cls, package):
def valid_package(cls, package: str | None) -> bool:
"""A valid package is either None or alphanumeric."""
if package is None:
@ -104,8 +108,7 @@ class HomebrewValidate:
return isinstance(package, str) and not cls.INVALID_PACKAGE_REGEX.search(package)
def parse_brew_path(module):
# type: (...) -> str
def parse_brew_path(module: AnsibleModule) -> str:
"""Attempt to find the Homebrew executable path.
Requires:

View file

@ -7,6 +7,7 @@ from __future__ import annotations
import re
import time
import traceback
import typing as t
THIRD_LIBRARIES_IMP_ERR = None
try:
@ -135,18 +136,18 @@ class _ServiceClient:
class Config:
def __init__(self, module, product):
def __init__(self, module: AnsibleModule, product):
self._project_client = None
self._domain_client = None
self._module = module
self._product = product
self._endpoints = {}
self._endpoints: dict[str, t.Any] = {}
self._validate()
self._gen_provider_client()
@property
def module(self):
def module(self) -> AnsibleModule:
return self._module
def client(self, region, service_type, service_level):
@ -380,7 +381,7 @@ def navigate_value(data, index, array_index=None):
return d
def build_path(module, path, kv=None):
def build_path(module: AnsibleModule, path, kv=None):
if kv is None:
kv = dict()
@ -400,7 +401,7 @@ def build_path(module, path, kv=None):
return path.format(**v)
def get_region(module):
def get_region(module: AnsibleModule):
if module.params["region"]:
return module.params["region"]

View file

@ -9,7 +9,7 @@ from __future__ import annotations
import traceback
from functools import wraps
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.basic import missing_required_lib, AnsibleModule
PYXCLI_INSTALLED = True
PYXCLI_IMP_ERR = None
@ -49,7 +49,7 @@ def xcli_wrapper(func):
"""Catch xcli errors and return a proper message"""
@wraps(func)
def wrapper(module, *args, **kwargs):
def wrapper(module: AnsibleModule, *args, **kwargs):
try:
return func(module, *args, **kwargs)
except errors.CommandExecutionError as e:
@ -59,7 +59,7 @@ def xcli_wrapper(func):
@xcli_wrapper
def connect_ssl(module):
def connect_ssl(module: AnsibleModule):
endpoints = module.params["endpoints"]
username = module.params["username"]
password = module.params["password"]
@ -82,7 +82,7 @@ def spectrum_accelerate_spec():
@xcli_wrapper
def execute_pyxcli_command(module, xcli_command, xcli_client):
def execute_pyxcli_command(module: AnsibleModule, xcli_command, xcli_client):
pyxcli_args = build_pyxcli_command(module.params)
getattr(xcli_client.cmd, xcli_command)(**(pyxcli_args))
return True
@ -99,6 +99,6 @@ def build_pyxcli_command(fields):
return pyxcli_args
def is_pyxcli_installed(module):
def is_pyxcli_installed(module: AnsibleModule):
if not PYXCLI_INSTALLED:
module.fail_json(msg=missing_required_lib("pyxcli"), exception=PYXCLI_IMP_ERR)

View file

@ -6,7 +6,7 @@ from __future__ import annotations
import traceback
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.basic import missing_required_lib, AnsibleModule
from ansible_collections.community.general.plugins.module_utils.version import LooseVersion
@ -32,7 +32,7 @@ except ImportError:
class InfluxDb:
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.params = self.module.params
self.check_lib()

View file

@ -19,7 +19,7 @@ import uuid
import re
from ansible.module_utils.common.text.converters import to_bytes, to_text
from ansible.module_utils.urls import fetch_url, HAS_GSSAPI
from ansible.module_utils.basic import env_fallback, AnsibleFallbackNotFound
from ansible.module_utils.basic import env_fallback, AnsibleFallbackNotFound, AnsibleModule
from urllib.parse import quote
@ -41,7 +41,7 @@ def _env_then_dns_fallback(*args, **kwargs):
class IPAClient:
def __init__(self, module, host, port, protocol):
def __init__(self, module: AnsibleModule, host, port, protocol):
self.host = host
self.port = port
self.protocol = protocol

View file

@ -10,6 +10,8 @@ from __future__ import annotations
import re
import traceback
from ansible.module_utils.basic import AnsibleModule
try:
import ldap
import ldap.dn
@ -52,7 +54,7 @@ def ldap_required_together():
class LdapGeneric:
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
# Shortcuts
self.module = module
self.bind_dn = self.module.params["bind_dn"]

View file

@ -14,6 +14,6 @@ from __future__ import annotations
from ansible.module_utils.ansible_release import __version__ as ansible_version
def get_user_agent(module):
def get_user_agent(module: str) -> str:
"""Retrieve a user-agent to send with LinodeClient requests."""
return f"Ansible-{module}/{ansible_version}"

View file

@ -4,10 +4,12 @@
from __future__ import annotations
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
def locale_runner(module):
def locale_runner(module: AnsibleModule) -> CmdRunner:
runner = CmdRunner(
module,
command=["locale", "-a"],
@ -16,7 +18,7 @@ def locale_runner(module):
return runner
def locale_gen_runner(module):
def locale_gen_runner(module: AnsibleModule) -> CmdRunner:
runner = CmdRunner(
module,
command="locale-gen",

View file

@ -15,8 +15,9 @@ from __future__ import annotations
import os
import traceback
import typing as t
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.basic import missing_required_lib, AnsibleModule
CLIENT_IMP_ERR = None
try:
@ -43,24 +44,25 @@ def manageiq_argument_spec():
)
def check_client(module):
def check_client(module: AnsibleModule) -> None:
if not HAS_CLIENT:
module.fail_json(msg=missing_required_lib("manageiq-client"), exception=CLIENT_IMP_ERR)
def validate_connection_params(module):
params = module.params["manageiq_connection"]
def validate_connection_params(module: AnsibleModule) -> dict[str, t.Any]:
params: dict[str, t.Any] = module.params["manageiq_connection"]
error_str = "missing required argument: manageiq_connection[{}]"
url = params["url"]
token = params["token"]
username = params["username"]
password = params["password"]
url: str | None = params["url"]
token: str | None = params["token"]
username: str | None = params["username"]
password: str | None = params["password"]
if (url and username and password) or (url and token):
return params
for arg in ["url", "username", "password"]:
if params[arg] in (None, ""):
module.fail_json(msg=error_str.format(arg))
raise AssertionError("should be unreachable")
def manageiq_entities():
@ -87,7 +89,7 @@ class ManageIQ:
class encapsulating ManageIQ API client.
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
# handle import errors
check_client(module)
@ -111,7 +113,7 @@ class ManageIQ:
self.module.fail_json(msg=f"failed to open connection ({url}): {e}")
@property
def module(self):
def module(self) -> AnsibleModule:
"""Ansible module module
Returns:
@ -192,7 +194,7 @@ class ManageIQPolicies:
Object to execute policies management operations of manageiq resources.
"""
def __init__(self, manageiq, resource_type, resource_id):
def __init__(self, manageiq: ManageIQ, resource_type, resource_id):
self.manageiq = manageiq
self.module = self.manageiq.module
@ -330,7 +332,7 @@ class ManageIQTags:
Object to execute tags management operations of manageiq resources.
"""
def __init__(self, manageiq, resource_type, resource_id):
def __init__(self, manageiq: ManageIQ, resource_type, resource_id):
self.manageiq = manageiq
self.module = self.manageiq.module

View file

@ -12,6 +12,7 @@ from urllib.error import URLError, HTTPError
from urllib.parse import urlparse
from ansible.module_utils.urls import open_url
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_native
@ -24,7 +25,7 @@ HEALTH_OK = 5
class OcapiUtils:
def __init__(self, creds, base_uri, proxy_slot_number, timeout, module):
def __init__(self, creds, base_uri, proxy_slot_number, timeout, module: AnsibleModule):
self.root_uri = base_uri
self.proxy_slot_number = proxy_slot_number
self.creds = creds

View file

@ -14,11 +14,11 @@ class OnePasswordConfig:
"~/.config/.op/config",
)
def __init__(self):
def __init__(self) -> None:
self._config_file_path = ""
@property
def config_file_path(self):
def config_file_path(self) -> str | None:
if self._config_file_path:
return self._config_file_path
@ -27,3 +27,5 @@ class OnePasswordConfig:
if os.path.exists(realpath):
self._config_file_path = realpath
return self._config_file_path
return None

View file

@ -7,7 +7,7 @@ from __future__ import annotations
import json
import sys
from ansible.module_utils.basic import env_fallback
from ansible.module_utils.basic import env_fallback, AnsibleModule
from ansible.module_utils.urls import fetch_url
@ -60,7 +60,7 @@ class Response:
class Online:
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.headers = {
"Authorization": f"Bearer {self.module.params.get('api_token')}",

View file

@ -6,6 +6,8 @@ from __future__ import annotations
import re
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
_state_map = {
@ -46,7 +48,7 @@ def fmt_resource_argument(value):
return ["--group" if value["argument_action"] == "group" else value["argument_action"]] + value["argument_option"]
def get_pacemaker_maintenance_mode(runner):
def get_pacemaker_maintenance_mode(runner: CmdRunner) -> bool:
with runner("cli_action config") as ctx:
rc, out, err = ctx.run(cli_action="property")
maint_mode_re = re.compile(r"maintenance-mode.*true", re.IGNORECASE)
@ -54,7 +56,7 @@ def get_pacemaker_maintenance_mode(runner):
return bool(maintenance_mode_output)
def pacemaker_runner(module, **kwargs):
def pacemaker_runner(module: AnsibleModule, **kwargs) -> CmdRunner:
runner_command = ["pcs"]
runner = CmdRunner(
module,

View file

@ -7,6 +7,7 @@ from __future__ import annotations
import json
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
@ -36,7 +37,7 @@ _state_map = dict(
)
def pipx_runner(module, command, **kwargs):
def pipx_runner(module: AnsibleModule, command, **kwargs) -> CmdRunner:
arg_formats = dict(
state=cmd_runner_fmt.as_map(_state_map),
name=cmd_runner_fmt.as_list(),

View file

@ -4,6 +4,8 @@
from __future__ import annotations
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils import deps
@ -13,7 +15,7 @@ with deps.declare("packaging"):
class PackageRequirement:
def __init__(self, module, name):
def __init__(self, module: AnsibleModule, name) -> None:
self.module = module
self.parsed_name, self.requirement = self._parse_spec(name)

View file

@ -7,26 +7,28 @@ from __future__ import annotations
import os
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
_PUPPET_PATH_PREFIX = ["/opt/puppetlabs/bin"]
def get_facter_dir():
def get_facter_dir() -> str:
if os.getuid() == 0:
return "/etc/facter/facts.d"
else:
return os.path.expanduser("~/.facter/facts.d")
def _puppet_cmd(module):
def _puppet_cmd(module: AnsibleModule) -> str | None:
return module.get_bin_path("puppet", False, _PUPPET_PATH_PREFIX)
# If the `timeout` CLI command feature is removed,
# Then we could add this as a fixed param to `puppet_runner`
def ensure_agent_enabled(module):
def ensure_agent_enabled(module: AnsibleModule) -> None:
runner = CmdRunner(
module,
command="puppet",
@ -44,7 +46,7 @@ def ensure_agent_enabled(module):
module.fail_json(msg="Puppet agent state could not be determined.")
def puppet_runner(module):
def puppet_runner(module: AnsibleModule) -> CmdRunner:
# Keeping backward compatibility, allow for running with the `timeout` CLI command.
# If this can be replaced with ansible `timeout` parameter in playbook,
# then this function could be removed.

View file

@ -6,13 +6,15 @@ from __future__ import annotations
import os
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, _ensure_list
class PythonRunner(CmdRunner):
def __init__(
self,
module,
module: AnsibleModule,
command,
arg_formats=None,
default_args_order=(),

View file

@ -10,6 +10,7 @@ import os
import random
import string
import time
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.urls import open_url
from ansible.module_utils.common.text.converters import to_native
from ansible.module_utils.common.text.converters import to_text
@ -52,7 +53,7 @@ class RedfishUtils:
creds,
root_uri,
timeout,
module,
module: AnsibleModule,
resource_id=None,
data_modification=False,
strip_etag_quotes=False,

View file

@ -5,11 +5,11 @@
from __future__ import annotations
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.basic import missing_required_lib, AnsibleModule
import traceback
REDIS_IMP_ERR = None
REDIS_IMP_ERR: str | None = None
try:
from redis import Redis
from redis import __version__ as redis_version
@ -20,25 +20,25 @@ except ImportError:
REDIS_IMP_ERR = traceback.format_exc()
HAS_REDIS_PACKAGE = False
CERTIFI_IMPORT_ERROR: str | None = None
try:
import certifi
HAS_CERTIFI_PACKAGE = True
CERTIFI_IMPORT_ERROR = None
except ImportError:
CERTIFI_IMPORT_ERROR = traceback.format_exc()
HAS_CERTIFI_PACKAGE = False
def fail_imports(module, needs_certifi=True):
errors = []
traceback = []
def fail_imports(module: AnsibleModule, needs_certifi: bool = True) -> None:
errors: list[str] = []
traceback: list[str] = []
if not HAS_REDIS_PACKAGE:
errors.append(missing_required_lib("redis"))
traceback.append(REDIS_IMP_ERR)
traceback.append(REDIS_IMP_ERR) # type: ignore
if not HAS_CERTIFI_PACKAGE and needs_certifi:
errors.append(missing_required_lib("certifi"))
traceback.append(CERTIFI_IMPORT_ERROR)
traceback.append(CERTIFI_IMPORT_ERROR) # type: ignore
if errors:
module.fail_json(msg="\n".join(errors), traceback="\n".join(traceback))
@ -60,7 +60,7 @@ def redis_auth_argument_spec(tls_default=True):
)
def redis_auth_params(module):
def redis_auth_params(module: AnsibleModule):
login_host = module.params["login_host"]
login_user = module.params["login_user"]
login_password = module.params["login_password"]
@ -92,7 +92,7 @@ def redis_auth_params(module):
class RedisAnsible:
"""Base class for Redis module"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.connection = self._connect()

View file

@ -7,6 +7,7 @@ from __future__ import annotations
import json
import traceback
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.urls import fetch_url, url_argument_spec
@ -27,7 +28,7 @@ def api_argument_spec():
return api_argument_spec
def api_request(module, endpoint, data=None, method="GET", content_type="application/json"):
def api_request(module: AnsibleModule, endpoint, data=None, method="GET", content_type="application/json"):
"""Manages Rundeck API requests via HTTP(S)
:arg module: The AnsibleModule (used to get url, api_version, api_token, etc).

View file

@ -12,7 +12,7 @@ import time
import traceback
from urllib.parse import urlencode
from ansible.module_utils.basic import env_fallback, missing_required_lib
from ansible.module_utils.basic import env_fallback, missing_required_lib, AnsibleModule
from ansible.module_utils.urls import fetch_url
from ansible_collections.community.general.plugins.module_utils.datetime import (
@ -169,7 +169,7 @@ class Response:
class Scaleway:
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.headers = {
"X-Auth-Token": self.module.params.get("api_token"),

View file

@ -4,6 +4,8 @@
from __future__ import annotations
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
@ -22,7 +24,7 @@ _state_map = dict(
)
def snap_runner(module, **kwargs):
def snap_runner(module: AnsibleModule, **kwargs) -> CmdRunner:
runner = CmdRunner(
module,
"snap",
@ -47,7 +49,7 @@ def snap_runner(module, **kwargs):
return runner
def get_version(runner):
def get_version(runner: CmdRunner):
with runner("version") as ctx:
rc, out, err = ctx.run()
return dict(x.split() for x in out.splitlines() if len(x.split()) == 2)

View file

@ -11,7 +11,7 @@ from __future__ import annotations
import os
def determine_config_file(user, config_file):
def determine_config_file(user: str | None, config_file: str | None) -> str:
if user:
config_file = os.path.join(os.path.expanduser(f"~{user}"), ".ssh", "config")
elif config_file is None:

View file

@ -4,11 +4,12 @@
from __future__ import annotations
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
def systemd_runner(module, command, **kwargs):
def systemd_runner(module: AnsibleModule, command, **kwargs) -> CmdRunner:
arg_formats = dict(
version=cmd_runner_fmt.as_fixed("--version"),
list_units=cmd_runner_fmt.as_fixed(["list-units", "--no-pager"]),

View file

@ -165,7 +165,7 @@ def ldap_search(filter, base=None, attr=None):
uldap().lo.lo.abandon(msgid)
def module_by_name(module_name_):
def module_by_name(module_name_: str):
"""Returns an initialized UMC module, identified by the given name.
The module is a module specification according to the udm commandline.
@ -202,7 +202,7 @@ def get_umc_admin_objects():
return univention.admin.objects
def umc_module_for_add(module, container_dn, superordinate=None):
def umc_module_for_add(module: str, container_dn, superordinate=None):
"""Returns an UMC module object prepared for creating a new entry.
The module is a module specification according to the udm commandline.
@ -226,7 +226,7 @@ def umc_module_for_add(module, container_dn, superordinate=None):
return obj
def umc_module_for_edit(module, object_dn, superordinate=None):
def umc_module_for_edit(module: str, object_dn, superordinate=None):
"""Returns an UMC module object prepared for editing an existing entry.
The module is a module specification according to the udm commandline.

View file

@ -75,7 +75,7 @@ class UTMModule(AnsibleModule):
class UTM:
def __init__(self, module, endpoint, change_relevant_keys, info_only=False):
def __init__(self, module: UTMModule, endpoint, change_relevant_keys, info_only=False):
"""
Initialize UTM Class
:param module: The Ansible module

View file

@ -12,7 +12,7 @@ try:
except ImportError:
HAS_VEXATAPI = False
from ansible.module_utils.basic import env_fallback
from ansible.module_utils.basic import env_fallback, AnsibleModule
VXOS_VERSION = None
@ -31,7 +31,7 @@ def get_version(iocs_json):
return tuple(ver)
def get_array(module):
def get_array(module: AnsibleModule):
"""Return storage array object or fail"""
global VXOS_VERSION
array = module.params["array"]

View file

@ -11,6 +11,7 @@ import tarfile
import os
from urllib.parse import urlparse, urlunparse
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.urls import fetch_file
from ansible_collections.community.general.plugins.module_utils.redfish_utils import RedfishUtils
@ -41,7 +42,7 @@ class WdcRedfishUtils(RedfishUtils):
CHASSIS_LOCATE = "#Chassis.Locate"
CHASSIS_POWER_MODE = "#Chassis.PowerMode"
def __init__(self, creds, root_uris, timeout, module, resource_id, data_modification):
def __init__(self, creds, root_uris, timeout, module: AnsibleModule, resource_id, data_modification):
super().__init__(
creds=creds,
root_uri=root_uris[0],

View file

@ -5,10 +5,12 @@
from __future__ import annotations
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
def xdg_mime_runner(module, **kwargs):
def xdg_mime_runner(module: AnsibleModule, **kwargs) -> CmdRunner:
return CmdRunner(
module,
command=["xdg-mime"],
@ -23,7 +25,7 @@ def xdg_mime_runner(module, **kwargs):
)
def xdg_mime_get(runner, mime_type):
def xdg_mime_get(runner: CmdRunner, mime_type):
def process(rc, out, err):
if not out.strip():
return None

View file

@ -9,6 +9,7 @@ import atexit
import time
import re
import traceback
import typing as t
XENAPI_IMP_ERR = None
try:
@ -19,7 +20,7 @@ except ImportError:
HAS_XENAPI = False
XENAPI_IMP_ERR = traceback.format_exc()
from ansible.module_utils.basic import env_fallback, missing_required_lib
from ansible.module_utils.basic import env_fallback, missing_required_lib, AnsibleModule
from ansible.module_utils.ansible_release import __version__ as ANSIBLE_VERSION
@ -239,7 +240,7 @@ def is_valid_ip6_prefix(ip6_prefix):
return not (ip6_prefix_int < 0 or ip6_prefix_int > 128)
def get_object_ref(module, name, uuid=None, obj_type="VM", fail=True, msg_prefix=""):
def get_object_ref(module: AnsibleModule, name, uuid=None, obj_type="VM", fail=True, msg_prefix=""):
"""Finds and returns a reference to arbitrary XAPI object.
An object is searched by using either name (name_label) or UUID
@ -305,7 +306,7 @@ def get_object_ref(module, name, uuid=None, obj_type="VM", fail=True, msg_prefix
return obj_ref
def gather_vm_params(module, vm_ref):
def gather_vm_params(module: AnsibleModule, vm_ref):
"""Gathers all VM parameters available in XAPI database.
Args:
@ -395,7 +396,7 @@ def gather_vm_params(module, vm_ref):
return vm_params
def gather_vm_facts(module, vm_params):
def gather_vm_facts(module: AnsibleModule, vm_params):
"""Gathers VM facts.
Args:
@ -502,7 +503,7 @@ def gather_vm_facts(module, vm_params):
return vm_facts
def set_vm_power_state(module, vm_ref, power_state, timeout=300):
def set_vm_power_state(module: AnsibleModule, vm_ref, power_state, timeout=300):
"""Controls VM power state.
Args:
@ -608,7 +609,7 @@ def set_vm_power_state(module, vm_ref, power_state, timeout=300):
return (state_changed, vm_power_state_resulting)
def wait_for_task(module, task_ref, timeout=300):
def wait_for_task(module: AnsibleModule, task_ref, timeout=300):
"""Waits for async XAPI task to finish.
Args:
@ -667,7 +668,7 @@ def wait_for_task(module, task_ref, timeout=300):
return result
def wait_for_vm_ip_address(module, vm_ref, timeout=300):
def wait_for_vm_ip_address(module: AnsibleModule, vm_ref, timeout=300):
"""Waits for VM to acquire an IP address.
Args:
@ -730,7 +731,7 @@ def wait_for_vm_ip_address(module, vm_ref, timeout=300):
return vm_guest_metrics
def get_xenserver_version(module):
def get_xenserver_version(module: AnsibleModule):
"""Returns XenServer version.
Args:
@ -758,10 +759,10 @@ def get_xenserver_version(module):
class XAPI:
"""Class for XAPI session management."""
_xapi_session = None
_xapi_session: t.Any | None = None
@classmethod
def connect(cls, module, disconnect_atexit=True):
def connect(cls, module: AnsibleModule, disconnect_atexit=True):
"""Establishes XAPI connection and returns session reference.
If no existing session is available, establishes a new one
@ -837,7 +838,7 @@ class XenServerObject:
minor version.
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
"""Inits XenServerObject using common module parameters.
Args:

View file

@ -4,6 +4,7 @@
from __future__ import annotations
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.parsing.convert_bool import boolean
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
@ -18,7 +19,7 @@ def _values_fmt(values, value_types):
return result
def xfconf_runner(module, **kwargs):
def xfconf_runner(module: AnsibleModule, **kwargs) -> CmdRunner:
runner = CmdRunner(
module,
command="xfconf-query",
@ -37,7 +38,7 @@ def xfconf_runner(module, **kwargs):
return runner
def get_xfconf_version(runner):
def get_xfconf_version(runner: CmdRunner) -> str:
with runner("version") as ctx:
rc, out, err = ctx.run()
return out.splitlines()[0].split()[1]