From c7f6a28d89e0ec2733aa014df07ce497f2562c54 Mon Sep 17 00:00:00 2001 From: Felix Fontein Date: Mon, 1 Dec 2025 20:40:06 +0100 Subject: [PATCH] Add basic typing for module_utils (#11222) * Add basic typing for module_utils. * Apply some suggestions. Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com> * Make pass again. * Add more types as suggested. * Normalize extra imports. * Add more type hints. * Improve typing. * Add changelog fragment. * Reduce changelog. * Apply suggestions from code review. Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com> * Fix typo. * Cleanup. * Improve types and make type checking happy. * Let's see whether older Pythons barf on this. * Revert "Let's see whether older Pythons barf on this." This reverts commit 9973af3dbe8e87e83a7c8f27fc44f2b47f79c304. * Add noqa. --------- Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com> --- changelogs/fragments/11222-typing.yml | 3 + plugins/module_utils/_filelock.py | 22 ++- plugins/module_utils/alicloud_ecs.py | 23 ++- plugins/module_utils/android_sdkmanager.py | 33 ++-- plugins/module_utils/btrfs.py | 198 +++++++++++---------- plugins/module_utils/cmd_runner.py | 53 ++++-- plugins/module_utils/cmd_runner_fmt.py | 52 ++++-- plugins/module_utils/consul.py | 5 +- plugins/module_utils/csv.py | 21 ++- plugins/module_utils/database.py | 6 +- plugins/module_utils/datetime.py | 6 +- plugins/module_utils/deps.py | 34 ++-- plugins/module_utils/dimensiondata.py | 3 +- plugins/module_utils/django.py | 17 +- plugins/module_utils/gandi_livedns_api.py | 6 +- plugins/module_utils/gconftool2.py | 7 +- plugins/module_utils/gio_mime.py | 11 +- plugins/module_utils/gitlab.py | 25 +-- plugins/module_utils/heroku.py | 10 +- plugins/module_utils/homebrew.py | 15 +- plugins/module_utils/hwc_utils.py | 41 ++--- plugins/module_utils/ibm_sa_utils.py | 14 +- plugins/module_utils/ilo_redfish_utils.py | 28 +-- plugins/module_utils/influxdb.py | 12 +- plugins/module_utils/ipa.py | 26 +-- plugins/module_utils/jenkins.py | 2 +- plugins/module_utils/ldap.py | 18 +- plugins/module_utils/linode.py | 2 +- plugins/module_utils/locale_gen.py | 9 +- plugins/module_utils/lxd.py | 36 ++-- plugins/module_utils/manageiq.py | 33 ++-- plugins/module_utils/ocapi_utils.py | 6 +- plugins/module_utils/onepassword.py | 6 +- plugins/module_utils/online.py | 10 +- plugins/module_utils/pacemaker.py | 9 +- plugins/module_utils/pipx.py | 8 +- plugins/module_utils/pkg_req.py | 11 +- plugins/module_utils/puppet.py | 13 +- plugins/module_utils/python_runner.py | 22 ++- plugins/module_utils/redfish_utils.py | 36 ++-- plugins/module_utils/redis.py | 31 ++-- plugins/module_utils/rundeck.py | 14 +- plugins/module_utils/scaleway.py | 24 ++- plugins/module_utils/snap.py | 9 +- plugins/module_utils/ssh.py | 2 +- plugins/module_utils/systemd.py | 6 +- plugins/module_utils/univention_umc.py | 6 +- plugins/module_utils/utm_utils.py | 7 +- plugins/module_utils/vardict.py | 88 +++++---- plugins/module_utils/vexata.py | 23 ++- plugins/module_utils/wdc_redfish_utils.py | 8 +- plugins/module_utils/xdg_mime.py | 11 +- plugins/module_utils/xenserver.py | 44 ++--- plugins/module_utils/xfconf.py | 9 +- plugins/modules/lxd_storage_pool_info.py | 10 +- plugins/modules/lxd_storage_volume_info.py | 10 +- 56 files changed, 725 insertions(+), 469 deletions(-) create mode 100644 changelogs/fragments/11222-typing.yml diff --git a/changelogs/fragments/11222-typing.yml b/changelogs/fragments/11222-typing.yml new file mode 100644 index 0000000000..0edaf44c47 --- /dev/null +++ b/changelogs/fragments/11222-typing.yml @@ -0,0 +1,3 @@ +bugfixes: + - "_filelock module utils - add type hints. Fix bug if ``set_lock()`` is called with ``lock_timeout=None`` (https://github.com/ansible-collections/community.general/pull/11222)." + - "gitlab module utils - add type hints. Pass API version to python-gitlab as string and not as integer (https://github.com/ansible-collections/community.general/pull/11222)." diff --git a/plugins/module_utils/_filelock.py b/plugins/module_utils/_filelock.py index 73cbccf8e8..3821bba4a2 100644 --- a/plugins/module_utils/_filelock.py +++ b/plugins/module_utils/_filelock.py @@ -11,9 +11,13 @@ import os import stat import time import fcntl +import typing as t from contextlib import contextmanager +if t.TYPE_CHECKING: + from io import TextIOWrapper + class LockTimeout(Exception): pass @@ -27,11 +31,13 @@ class FileLock: unwanted and/or unexpected behaviour """ - def __init__(self): - self.lockfd = None + def __init__(self) -> None: + self.lockfd: TextIOWrapper | None = None @contextmanager - def lock_file(self, path, tmpdir, lock_timeout=None): + def lock_file( + self, path: os.PathLike, tmpdir: os.PathLike, lock_timeout: int | float | None = None + ) -> t.Generator[None]: """ Context for lock acquisition """ @@ -41,7 +47,9 @@ class FileLock: finally: self.unlock() - def set_lock(self, path, tmpdir, lock_timeout=None): + def set_lock( + self, path: os.PathLike, tmpdir: os.PathLike, lock_timeout: int | float | None = None + ) -> t.Literal[True]: """ Create a lock file based on path with flock to prevent other processes using given path. @@ -61,13 +69,13 @@ class FileLock: self.lockfd = open(lock_path, "w") - if lock_timeout <= 0: + if lock_timeout is not None and lock_timeout <= 0: fcntl.flock(self.lockfd, fcntl.LOCK_EX | fcntl.LOCK_NB) os.chmod(lock_path, stat.S_IWRITE | stat.S_IREAD) return True if lock_timeout: - e_secs = 0 + e_secs: float = 0 while e_secs < lock_timeout: try: fcntl.flock(self.lockfd, fcntl.LOCK_EX | fcntl.LOCK_NB) @@ -86,7 +94,7 @@ class FileLock: return True - def unlock(self): + def unlock(self) -> t.Literal[True]: """ Make sure lock file is available for everyone and Unlock the file descriptor locked by set_lock diff --git a/plugins/module_utils/alicloud_ecs.py b/plugins/module_utils/alicloud_ecs.py index c21c2261f7..fb9c1bfb9d 100644 --- a/plugins/module_utils/alicloud_ecs.py +++ b/plugins/module_utils/alicloud_ecs.py @@ -14,8 +14,13 @@ from __future__ import annotations import os import json import traceback +import typing as t + from ansible.module_utils.basic import env_fallback +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + try: import footmark import footmark.ecs @@ -200,7 +205,7 @@ def get_profile(params): return params -def ecs_connect(module): +def ecs_connect(module: AnsibleModule): """Return an ecs connection""" ecs_params = get_profile(module.params) # If we have a region specified, connect to its endpoint. @@ -214,7 +219,7 @@ def ecs_connect(module): return ecs -def slb_connect(module): +def slb_connect(module: AnsibleModule): """Return an slb connection""" slb_params = get_profile(module.params) # If we have a region specified, connect to its endpoint. @@ -228,7 +233,7 @@ def slb_connect(module): return slb -def dns_connect(module): +def dns_connect(module: AnsibleModule): """Return an dns connection""" dns_params = get_profile(module.params) # If we have a region specified, connect to its endpoint. @@ -242,7 +247,7 @@ def dns_connect(module): return dns -def vpc_connect(module): +def vpc_connect(module: AnsibleModule): """Return an vpc connection""" vpc_params = get_profile(module.params) # If we have a region specified, connect to its endpoint. @@ -256,7 +261,7 @@ def vpc_connect(module): return vpc -def rds_connect(module): +def rds_connect(module: AnsibleModule): """Return an rds connection""" rds_params = get_profile(module.params) # If we have a region specified, connect to its endpoint. @@ -270,7 +275,7 @@ def rds_connect(module): return rds -def ess_connect(module): +def ess_connect(module: AnsibleModule): """Return an ess connection""" ess_params = get_profile(module.params) # If we have a region specified, connect to its endpoint. @@ -284,7 +289,7 @@ def ess_connect(module): return ess -def sts_connect(module): +def sts_connect(module: AnsibleModule): """Return an sts connection""" sts_params = get_profile(module.params) # If we have a region specified, connect to its endpoint. @@ -298,7 +303,7 @@ def sts_connect(module): return sts -def ram_connect(module): +def ram_connect(module: AnsibleModule): """Return an ram connection""" ram_params = get_profile(module.params) # If we have a region specified, connect to its endpoint. @@ -312,7 +317,7 @@ def ram_connect(module): return ram -def market_connect(module): +def market_connect(module: AnsibleModule): """Return an market connection""" market_params = get_profile(module.params) # If we have a region specified, connect to its endpoint. diff --git a/plugins/module_utils/android_sdkmanager.py b/plugins/module_utils/android_sdkmanager.py index 0bdcfb7ac3..8396179f2d 100644 --- a/plugins/module_utils/android_sdkmanager.py +++ b/plugins/module_utils/android_sdkmanager.py @@ -6,22 +6,27 @@ from __future__ import annotations import re +import typing as t from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + __state_map = {"present": "--install", "absent": "--uninstall"} # sdkmanager --help 2>&1 | grep -A 2 -- --channel __channel_map = {"stable": 0, "beta": 1, "dev": 2, "canary": 3} -def __map_channel(channel_name): +def __map_channel(channel_name: str) -> int: if channel_name not in __channel_map: raise ValueError(f"Unknown channel name '{channel_name}'") return __channel_map[channel_name] -def sdkmanager_runner(module, **kwargs): +def sdkmanager_runner(module: AnsibleModule, **kwargs) -> CmdRunner: return CmdRunner( module, command="sdkmanager", @@ -40,18 +45,18 @@ def sdkmanager_runner(module, **kwargs): class Package: - def __init__(self, name): + def __init__(self, name: str) -> None: self.name = name - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: if not isinstance(other, Package): return True return self.name != other.name - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, Package): return False @@ -78,20 +83,20 @@ class AndroidSdkManager: r"the packages they depend on were not accepted" ) - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.runner = sdkmanager_runner(module) - def get_installed_packages(self): + def get_installed_packages(self) -> set[Package]: with self.runner("installed sdk_root channel") as ctx: rc, stdout, stderr = ctx.run() return self._parse_packages(stdout, self._RE_INSTALLED_PACKAGES_HEADER, self._RE_INSTALLED_PACKAGE) - def get_updatable_packages(self): + def get_updatable_packages(self) -> set[Package]: with self.runner("list newer sdk_root channel") as ctx: rc, stdout, stderr = ctx.run() return self._parse_packages(stdout, self._RE_UPDATABLE_PACKAGES_HEADER, self._RE_UPDATABLE_PACKAGE) - def apply_packages_changes(self, packages, accept_licenses=False): + def apply_packages_changes(self, packages: list[Package], accept_licenses: bool = False) -> tuple[int, str, str]: """Install or delete packages, depending on the `module.vars.state` parameter""" if len(packages) == 0: return 0, "", "" @@ -113,7 +118,7 @@ class AndroidSdkManager: return rc, stdout, stderr return 0, "", "" - def _try_parse_stderr(self, stderr): + def _try_parse_stderr(self, stderr: str) -> None: data = stderr.splitlines() for line in data: unknown_package_regex = self._RE_UNKNOWN_PACKAGE.match(line) @@ -122,15 +127,15 @@ class AndroidSdkManager: raise SdkManagerException(f"Unknown package {package}") @staticmethod - def _parse_packages(stdout, header_regexp, row_regexp): + def _parse_packages(stdout: str, header_regexp: re.Pattern, row_regexp: re.Pattern) -> set[Package]: data = stdout.splitlines() - section_found = False + section_found: bool = False packages = set() for line in data: if not section_found: - section_found = header_regexp.match(line) + section_found = bool(header_regexp.match(line)) continue else: p = row_regexp.match(line) diff --git a/plugins/module_utils/btrfs.py b/plugins/module_utils/btrfs.py index f99b185894..ef5f53c870 100644 --- a/plugins/module_utils/btrfs.py +++ b/plugins/module_utils/btrfs.py @@ -7,6 +7,10 @@ from __future__ import annotations from ansible.module_utils.common.text.converters import to_bytes import re import os +import typing as t + +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule def normalize_subvolume_path(path): @@ -28,11 +32,11 @@ class BtrfsCommands: Provides access to a subset of the Btrfs command line """ - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.__module = module - self.__btrfs = self.__module.get_bin_path("btrfs", required=True) + self.__btrfs: str = self.__module.get_bin_path("btrfs", required=True) - def filesystem_show(self): + def filesystem_show(self) -> list[dict[str, t.Any]]: command = f"{self.__btrfs} filesystem show -d" result = self.__module.run_command(command, check_rc=True) stdout = [x.strip() for x in result[1].splitlines()] @@ -43,14 +47,16 @@ class BtrfsCommands: current = self.__parse_filesystem(line) filesystems.append(current) elif line.startswith("devid"): + if current is None: + raise ValueError("Found 'devid' line without previous 'Label' line") current["devices"].append(self.__parse_filesystem_device(line)) return filesystems - def __parse_filesystem(self, line): + def __parse_filesystem(self, line) -> dict[str, t.Any]: label = re.sub(r"\s*uuid:.*$", "", re.sub(r"^Label:\s*", "", line)) id = re.sub(r"^.*uuid:\s*", "", line) - filesystem = {} + filesystem: dict[str, t.Any] = {} filesystem["label"] = label.strip("'") if label != "none" else None filesystem["uuid"] = id filesystem["devices"] = [] @@ -59,44 +65,44 @@ class BtrfsCommands: filesystem["default_subvolid"] = None return filesystem - def __parse_filesystem_device(self, line): + def __parse_filesystem_device(self, line: str) -> str: return re.sub(r"^.*path\s", "", line) - def subvolumes_list(self, filesystem_path): + def subvolumes_list(self, filesystem_path: str) -> list[dict[str, t.Any]]: command = f"{self.__btrfs} subvolume list -tap {filesystem_path}" result = self.__module.run_command(command, check_rc=True) stdout = [x.split("\t") for x in result[1].splitlines()] - subvolumes = [{"id": 5, "parent": None, "path": "/"}] + subvolumes: list[dict[str, t.Any]] = [{"id": 5, "parent": None, "path": "/"}] if len(stdout) > 2: subvolumes.extend([self.__parse_subvolume_list_record(x) for x in stdout[2:]]) return subvolumes - def __parse_subvolume_list_record(self, item): + def __parse_subvolume_list_record(self, item: list[str]) -> dict[str, t.Any]: return { "id": int(item[0]), "parent": int(item[2]), "path": normalize_subvolume_path(item[5]), } - def subvolume_get_default(self, filesystem_path): + def subvolume_get_default(self, filesystem_path: str) -> int: command = [self.__btrfs, "subvolume", "get-default", to_bytes(filesystem_path)] result = self.__module.run_command(command, check_rc=True) # ID [n] ... return int(result[1].strip().split()[1]) - def subvolume_set_default(self, filesystem_path, subvolume_id): + def subvolume_set_default(self, filesystem_path: str, subvolume_id: int) -> None: command = [self.__btrfs, "subvolume", "set-default", str(subvolume_id), to_bytes(filesystem_path)] self.__module.run_command(command, check_rc=True) - def subvolume_create(self, subvolume_path): + def subvolume_create(self, subvolume_path: str) -> None: command = [self.__btrfs, "subvolume", "create", to_bytes(subvolume_path)] self.__module.run_command(command, check_rc=True) - def subvolume_snapshot(self, snapshot_source, snapshot_destination): + def subvolume_snapshot(self, snapshot_source: str, snapshot_destination: str) -> None: command = [self.__btrfs, "subvolume", "snapshot", to_bytes(snapshot_source), to_bytes(snapshot_destination)] self.__module.run_command(command, check_rc=True) - def subvolume_delete(self, subvolume_path): + def subvolume_delete(self, subvolume_path: str) -> None: command = [self.__btrfs, "subvolume", "delete", to_bytes(subvolume_path)] self.__module.run_command(command, check_rc=True) @@ -106,12 +112,12 @@ class BtrfsInfoProvider: Utility providing details of the currently available btrfs filesystems """ - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.__module = module self.__btrfs_api = BtrfsCommands(module) - self.__findmnt_path = self.__module.get_bin_path("findmnt", required=True) + self.__findmnt_path: str = self.__module.get_bin_path("findmnt", required=True) - def get_filesystems(self): + def get_filesystems(self) -> list[dict[str, t.Any]]: filesystems = self.__btrfs_api.filesystem_show() mountpoints = self.__find_mountpoints() for filesystem in filesystems: @@ -126,20 +132,22 @@ class BtrfsInfoProvider: return filesystems - def get_mountpoints(self, filesystem_devices): + def get_mountpoints(self, filesystem_devices: list[str]) -> list[dict[str, t.Any]]: mountpoints = self.__find_mountpoints() return self.__filter_mountpoints_for_devices(mountpoints, filesystem_devices) - def get_subvolumes(self, filesystem_path): + def get_subvolumes(self, filesystem_path) -> list[dict[str, t.Any]]: return self.__btrfs_api.subvolumes_list(filesystem_path) - def get_default_subvolume_id(self, filesystem_path): + def get_default_subvolume_id(self, filesystem_path) -> int: return self.__btrfs_api.subvolume_get_default(filesystem_path) - def __filter_mountpoints_for_devices(self, mountpoints, devices): + def __filter_mountpoints_for_devices( + self, mountpoints: list[dict[str, t.Any]], devices: list[str] + ) -> list[dict[str, t.Any]]: return [m for m in mountpoints if (m["device"] in devices)] - def __find_mountpoints(self): + def __find_mountpoints(self) -> list[dict[str, t.Any]]: command = f"{self.__findmnt_path} -t btrfs -nvP" result = self.__module.run_command(command) mountpoints = [] @@ -150,7 +158,7 @@ class BtrfsInfoProvider: mountpoints.append(mountpoint) return mountpoints - def __parse_mountpoint_pairs(self, line): + def __parse_mountpoint_pairs(self, line) -> dict[str, t.Any]: pattern = re.compile( r'^TARGET="(?P.*)"\s+SOURCE="(?P.*)"\s+FSTYPE="(?P.*)"\s+OPTIONS="(?P.*)"\s*$' ) @@ -164,13 +172,13 @@ class BtrfsInfoProvider: "subvolid": self.__extract_mount_subvolid(groups["options"]), } else: - raise BtrfsModuleException(f"Failed to parse findmnt result for line: '{line}'") + raise BtrfsModuleException(f"Failed to parse findmnt result for line: {line!r}") - def __extract_mount_subvolid(self, mount_options): + def __extract_mount_subvolid(self, mount_options: str) -> int: for option in mount_options.split(","): if option.startswith("subvolid="): return int(option[len("subvolid=") :]) - raise BtrfsModuleException(f"Failed to find subvolid for mountpoint in options '{mount_options}'") + raise BtrfsModuleException(f"Failed to find subvolid for mountpoint in options {mount_options!r}") class BtrfsSubvolume: @@ -178,39 +186,38 @@ class BtrfsSubvolume: Wrapper class providing convenience methods for inspection of a btrfs subvolume """ - def __init__(self, filesystem, subvolume_id): + def __init__(self, filesystem: BtrfsFilesystem, subvolume_id: int): self.__filesystem = filesystem self.__subvolume_id = subvolume_id - def get_filesystem(self): + def get_filesystem(self) -> BtrfsFilesystem: return self.__filesystem - def is_mounted(self): + def is_mounted(self) -> bool: mountpoints = self.get_mountpoints() return mountpoints is not None and len(mountpoints) > 0 - def is_filesystem_root(self): + def is_filesystem_root(self) -> bool: return self.__subvolume_id == 5 - def is_filesystem_default(self): + def is_filesystem_default(self) -> bool: return self.__filesystem.default_subvolid == self.__subvolume_id - def get_mounted_path(self): + def get_mounted_path(self) -> str | None: mountpoints = self.get_mountpoints() if mountpoints is not None and len(mountpoints) > 0: return mountpoints[0] - elif self.parent is not None: + if self.parent is not None: parent = self.__filesystem.get_subvolume_by_id(self.parent) - parent_path = parent.get_mounted_path() + parent_path = parent.get_mounted_path() if parent else None if parent_path is not None: - return parent_path + os.path.sep + self.name - else: - return None + return f"{parent_path}{os.path.sep}{self.name}" + return None - def get_mountpoints(self): + def get_mountpoints(self) -> list[str]: return self.__filesystem.get_mountpoints_by_subvolume_id(self.__subvolume_id) - def get_child_relative_path(self, absolute_child_path): + def get_child_relative_path(self, absolute_child_path: str) -> str: """ Get the relative path from this subvolume to the named child subvolume. The provided parameter is expected to be normalized as by normalize_subvolume_path. @@ -222,19 +229,21 @@ class BtrfsSubvolume: else: raise BtrfsModuleException(f"Path '{absolute_child_path}' doesn't start with '{path}'") - def get_parent_subvolume(self): + def get_parent_subvolume(self) -> BtrfsSubvolume | None: parent_id = self.parent return self.__filesystem.get_subvolume_by_id(parent_id) if parent_id is not None else None - def get_child_subvolumes(self): + def get_child_subvolumes(self) -> list[BtrfsSubvolume]: return self.__filesystem.get_subvolume_children(self.__subvolume_id) @property - def __info(self): - return self.__filesystem.get_subvolume_info_for_id(self.__subvolume_id) + def __info(self) -> dict[str, t.Any]: + result = self.__filesystem.get_subvolume_info_for_id(self.__subvolume_id) + # assert result is not None + return result # type: ignore @property - def id(self): + def id(self) -> int: return self.__subvolume_id @property @@ -242,7 +251,7 @@ class BtrfsSubvolume: return self.path.split("/").pop() @property - def path(self): + def path(self) -> str: return self.__info["path"] @property @@ -255,105 +264,105 @@ class BtrfsFilesystem: Wrapper class providing convenience methods for inspection of a btrfs filesystem """ - def __init__(self, info, provider, module): + def __init__(self, info: dict[str, t.Any], provider: BtrfsInfoProvider, module: AnsibleModule) -> None: self.__provider = provider # constant for module execution - self.__uuid = info["uuid"] - self.__label = info["label"] - self.__devices = info["devices"] + self.__uuid: str = info["uuid"] + self.__label: str = info["label"] + self.__devices: list[str] = info["devices"] # refreshable - self.__default_subvolid = info["default_subvolid"] if "default_subvolid" in info else None + self.__default_subvolid: int | None = info["default_subvolid"] if "default_subvolid" in info else None self.__update_mountpoints(info["mountpoints"] if "mountpoints" in info else []) self.__update_subvolumes(info["subvolumes"] if "subvolumes" in info else []) @property - def uuid(self): + def uuid(self) -> str: return self.__uuid @property - def label(self): + def label(self) -> str: return self.__label @property - def default_subvolid(self): + def default_subvolid(self) -> int | None: return self.__default_subvolid @property - def devices(self): + def devices(self) -> list[str]: return list(self.__devices) - def refresh(self): + def refresh(self) -> None: self.refresh_mountpoints() self.refresh_subvolumes() self.refresh_default_subvolume() - def refresh_mountpoints(self): + def refresh_mountpoints(self) -> None: mountpoints = self.__provider.get_mountpoints(list(self.__devices)) self.__update_mountpoints(mountpoints) - def __update_mountpoints(self, mountpoints): - self.__mountpoints = dict() + def __update_mountpoints(self, mountpoints: list[dict[str, t.Any]]) -> None: + self.__mountpoints: dict[int, list[str]] = dict() for i in mountpoints: - subvolid = i["subvolid"] - mountpoint = i["mountpoint"] + subvolid: int = i["subvolid"] + mountpoint: str = i["mountpoint"] if subvolid not in self.__mountpoints: self.__mountpoints[subvolid] = [] self.__mountpoints[subvolid].append(mountpoint) - def refresh_subvolumes(self): + def refresh_subvolumes(self) -> None: filesystem_path = self.get_any_mountpoint() if filesystem_path is not None: subvolumes = self.__provider.get_subvolumes(filesystem_path) self.__update_subvolumes(subvolumes) - def __update_subvolumes(self, subvolumes): + def __update_subvolumes(self, subvolumes: list[dict[str, t.Any]]) -> None: # TODO strategy for retaining information on deleted subvolumes? - self.__subvolumes = dict() + self.__subvolumes: dict[int, dict[str, t.Any]] = dict() for subvolume in subvolumes: self.__subvolumes[subvolume["id"]] = subvolume - def refresh_default_subvolume(self): + def refresh_default_subvolume(self) -> None: filesystem_path = self.get_any_mountpoint() if filesystem_path is not None: self.__default_subvolid = self.__provider.get_default_subvolume_id(filesystem_path) - def contains_device(self, device): + def contains_device(self, device: str) -> bool: return device in self.__devices - def contains_subvolume(self, subvolume): + def contains_subvolume(self, subvolume: str) -> bool: return self.get_subvolume_by_name(subvolume) is not None - def get_subvolume_by_id(self, subvolume_id): + def get_subvolume_by_id(self, subvolume_id: int) -> BtrfsSubvolume | None: return BtrfsSubvolume(self, subvolume_id) if subvolume_id in self.__subvolumes else None - def get_subvolume_info_for_id(self, subvolume_id): + def get_subvolume_info_for_id(self, subvolume_id: int) -> dict[str, t.Any] | None: return self.__subvolumes[subvolume_id] if subvolume_id in self.__subvolumes else None - def get_subvolume_by_name(self, subvolume): + def get_subvolume_by_name(self, subvolume: str) -> BtrfsSubvolume | None: for subvolume_info in self.__subvolumes.values(): if subvolume_info["path"] == subvolume: return BtrfsSubvolume(self, subvolume_info["id"]) return None - def get_any_mountpoint(self): + def get_any_mountpoint(self) -> str | None: for subvol_mountpoints in self.__mountpoints.values(): if len(subvol_mountpoints) > 0: return subvol_mountpoints[0] # maybe error? return None - def get_any_mounted_subvolume(self): + def get_any_mounted_subvolume(self) -> BtrfsSubvolume | None: for subvolid, subvol_mountpoints in self.__mountpoints.items(): if len(subvol_mountpoints) > 0: return self.get_subvolume_by_id(subvolid) return None - def get_mountpoints_by_subvolume_id(self, subvolume_id): + def get_mountpoints_by_subvolume_id(self, subvolume_id: int) -> list[str]: return self.__mountpoints[subvolume_id] if subvolume_id in self.__mountpoints else [] - def get_nearest_subvolume(self, subvolume): + def get_nearest_subvolume(self, subvolume: str) -> BtrfsSubvolume: """Return the identified subvolume if existing, else the closest matching parent""" subvolumes_by_path = self.__get_subvolumes_by_path() while len(subvolume) > 1: @@ -364,30 +373,31 @@ class BtrfsFilesystem: return BtrfsSubvolume(self, 5) - def get_mountpath_as_child(self, subvolume_name): + def get_mountpath_as_child(self, subvolume_name: str) -> str: """Find a path to the target subvolume through a mounted ancestor""" nearest = self.get_nearest_subvolume(subvolume_name) + nearest_or_none: BtrfsSubvolume | None = nearest if nearest.path == subvolume_name: - nearest = nearest.get_parent_subvolume() - if nearest is None or nearest.get_mounted_path() is None: + nearest_or_none = nearest.get_parent_subvolume() + if nearest_or_none is None or nearest_or_none.get_mounted_path() is None: raise BtrfsModuleException(f"Failed to find a path '{subvolume_name}' through a mounted parent subvolume") else: - return nearest.get_mounted_path() + os.path.sep + nearest.get_child_relative_path(subvolume_name) + return f"{nearest_or_none.get_mounted_path()}{os.path.sep}{nearest_or_none.get_child_relative_path(subvolume_name)}" - def get_subvolume_children(self, subvolume_id): + def get_subvolume_children(self, subvolume_id: int) -> list[BtrfsSubvolume]: return [BtrfsSubvolume(self, x["id"]) for x in self.__subvolumes.values() if x["parent"] == subvolume_id] - def __get_subvolumes_by_path(self): + def __get_subvolumes_by_path(self) -> dict[str, dict[str, t.Any]]: result = {} for s in self.__subvolumes.values(): path = s["path"] result[path] = s return result - def is_mounted(self): + def is_mounted(self) -> bool: return self.__mountpoints is not None and len(self.__mountpoints) > 0 - def get_summary(self): + def get_summary(self) -> dict[str, t.Any]: subvolumes = [] sources = self.__subvolumes.values() if self.__subvolumes is not None else [] for subvolume in sources: @@ -415,17 +425,19 @@ class BtrfsFilesystemsProvider: Provides methods to query available btrfs filesystems """ - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.__module = module self.__provider = BtrfsInfoProvider(module) - self.__filesystems = None + self.__filesystems: dict[str, BtrfsFilesystem] | None = None - def get_matching_filesystem(self, criteria): + def get_matching_filesystem(self, criteria: dict[str, t.Any]) -> BtrfsFilesystem: if criteria["device"] is not None: criteria["device"] = os.path.realpath(criteria["device"]) self.__check_init() - matching = [f for f in self.__filesystems.values() if self.__filesystem_matches_criteria(f, criteria)] + # assert self.__filesystems is not None # TODO + self_filesystems: dict[str, BtrfsFilesystem] = self.__filesystems # type: ignore + matching = [f for f in self_filesystems.values() if self.__filesystem_matches_criteria(f, criteria)] if len(matching) == 1: return matching[0] else: @@ -433,26 +445,30 @@ class BtrfsFilesystemsProvider: f"Found {len(matching)} filesystems matching criteria uuid={criteria['uuid']} label={criteria['label']} device={criteria['device']}" ) - def __filesystem_matches_criteria(self, filesystem, criteria): + def __filesystem_matches_criteria(self, filesystem: BtrfsFilesystem, criteria: dict[str, t.Any]): return ( (criteria["uuid"] is None or filesystem.uuid == criteria["uuid"]) and (criteria["label"] is None or filesystem.label == criteria["label"]) and (criteria["device"] is None or filesystem.contains_device(criteria["device"])) ) - def get_filesystem_for_device(self, device): + def get_filesystem_for_device(self, device: str) -> BtrfsFilesystem | None: real_device = os.path.realpath(device) self.__check_init() - for fs in self.__filesystems.values(): + # assert self.__filesystems is not None # TODO + self_filesystems: dict[str, BtrfsFilesystem] = self.__filesystems # type: ignore + for fs in self_filesystems.values(): if fs.contains_device(real_device): return fs return None - def get_filesystems(self): + def get_filesystems(self) -> list[BtrfsFilesystem]: self.__check_init() - return list(self.__filesystems.values()) + # assert self.__filesystems is not None # TODO + self_filesystems: dict[str, BtrfsFilesystem] = self.__filesystems # type: ignore + return list(self_filesystems.values()) - def __check_init(self): + def __check_init(self) -> None: if self.__filesystems is None: self.__filesystems = dict() for f in self.__provider.get_filesystems(): diff --git a/plugins/module_utils/cmd_runner.py b/plugins/module_utils/cmd_runner.py index 44bd9fc6ec..e7ed48cb3a 100644 --- a/plugins/module_utils/cmd_runner.py +++ b/plugins/module_utils/cmd_runner.py @@ -5,11 +5,19 @@ from __future__ import annotations import os +import typing as t from ansible.module_utils.common.collections import is_sequence from ansible.module_utils.common.locale import get_best_parsable_locale from ansible_collections.community.general.plugins.module_utils import cmd_runner_fmt +if t.TYPE_CHECKING: + from collections.abc import Callable, Mapping, Sequence + from ansible.module_utils.basic import AnsibleModule + from ansible_collections.community.general.plugins.module_utils.cmd_runner_fmt import ArgFormatType + + ArgFormatter = t.Union[ArgFormatType, cmd_runner_fmt._ArgFormat] # noqa: UP007 + def _ensure_list(value): return list(value) if is_sequence(value) else [value] @@ -24,7 +32,7 @@ class CmdRunnerException(Exception): class MissingArgumentFormat(CmdRunnerException): - def __init__(self, arg, args_order, args_formats): + def __init__(self, arg, args_order: tuple[str, ...], args_formats) -> None: self.args_order = args_order self.arg = arg self.args_formats = args_formats @@ -37,7 +45,7 @@ class MissingArgumentFormat(CmdRunnerException): class MissingArgumentValue(CmdRunnerException): - def __init__(self, args_order, arg): + def __init__(self, args_order: tuple[str, ...], arg) -> None: self.args_order = args_order self.arg = arg @@ -72,19 +80,19 @@ class CmdRunner: """ @staticmethod - def _prepare_args_order(order): - return tuple(order) if is_sequence(order) else tuple(order.split()) + def _prepare_args_order(order: str | Sequence[str]) -> tuple[str, ...]: + return tuple(order) if is_sequence(order) else tuple(order.split()) # type: ignore def __init__( self, - module, + module: AnsibleModule, command, - arg_formats=None, - default_args_order=(), - check_rc=False, - force_lang="C", - path_prefix=None, - environ_update=None, + arg_formats: Mapping[str, ArgFormatter] | None = None, + default_args_order: str | Sequence[str] = (), + check_rc: bool = False, + force_lang: str = "C", + path_prefix: Sequence[str] | None = None, + environ_update: dict[str, str] | None = None, ): self.module = module self.command = _ensure_list(command) @@ -117,10 +125,17 @@ class CmdRunner: ) @property - def binary(self): + def binary(self) -> str: return self.command[0] - def __call__(self, args_order=None, output_process=None, check_mode_skip=False, check_mode_return=None, **kwargs): + def __call__( + self, + args_order: str | Sequence[str] | None = None, + output_process: Callable[[int, str, str], t.Any] | None = None, + check_mode_skip: bool = False, + check_mode_return: t.Any | None = None, + **kwargs, + ): if output_process is None: output_process = _process_as_is if args_order is None: @@ -146,7 +161,15 @@ class CmdRunner: class _CmdRunnerContext: - def __init__(self, runner, args_order, output_process, check_mode_skip, check_mode_return, **kwargs): + def __init__( + self, + runner: CmdRunner, + args_order: tuple[str, ...], + output_process: Callable[[int, str, str], t.Any], + check_mode_skip: bool, + check_mode_return: t.Any, + **kwargs, + ) -> None: self.runner = runner self.args_order = tuple(args_order) self.output_process = output_process @@ -204,7 +227,7 @@ class _CmdRunnerContext: return self.results_processed @property - def run_info(self): + def run_info(self) -> dict[str, t.Any]: return dict( check_rc=self.check_rc, environ_update=self.environ_update, diff --git a/plugins/module_utils/cmd_runner_fmt.py b/plugins/module_utils/cmd_runner_fmt.py index 535a012947..b91e2ef92d 100644 --- a/plugins/module_utils/cmd_runner_fmt.py +++ b/plugins/module_utils/cmd_runner_fmt.py @@ -11,36 +11,46 @@ from functools import wraps from ansible.module_utils.common.collections import is_sequence if t.TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Mapping, Sequence - ArgFormatType = Callable[[t.Any], list[str]] + ArgFormatType = Callable[[t.Any], Sequence[t.Any]] + _T = t.TypeVar("_T") -def _ensure_list(value): - return list(value) if is_sequence(value) else [value] +def _ensure_list(value: _T | Sequence[_T]) -> list[_T]: + return list(value) if is_sequence(value) else [value] # type: ignore # TODO need type assertion for is_sequence class _ArgFormat: - def __init__(self, func, ignore_none=True, ignore_missing_value=False): + def __init__( + self, + func: ArgFormatType, + ignore_none: bool | None = True, + ignore_missing_value: bool = False, + ) -> None: self.func = func self.ignore_none = ignore_none self.ignore_missing_value = ignore_missing_value - def __call__(self, value): + def __call__(self, value: t.Any | None) -> list[str]: ignore_none = self.ignore_none if self.ignore_none is not None else True if value is None and ignore_none: return [] f = self.func return [str(x) for x in f(value)] - def __str__(self): + def __str__(self) -> str: return f"" - def __repr__(self): + def __repr__(self) -> str: return str(self) -def as_bool(args_true, args_false=None, ignore_none=None): +def as_bool( + args_true: Sequence[t.Any] | t.Any, + args_false: Sequence[t.Any] | t.Any | None = None, + ignore_none: bool | None = None, +) -> _ArgFormat: if args_false is not None: if ignore_none is None: ignore_none = False @@ -51,24 +61,24 @@ def as_bool(args_true, args_false=None, ignore_none=None): ) -def as_bool_not(args): +def as_bool_not(args: Sequence[t.Any] | t.Any) -> _ArgFormat: return as_bool([], args, ignore_none=False) -def as_optval(arg, ignore_none=None): +def as_optval(arg, ignore_none: bool | None = None) -> _ArgFormat: return _ArgFormat(lambda value: [f"{arg}{value}"], ignore_none=ignore_none) -def as_opt_val(arg, ignore_none=None): +def as_opt_val(arg: str, ignore_none: bool | None = None) -> _ArgFormat: return _ArgFormat(lambda value: [arg, value], ignore_none=ignore_none) -def as_opt_eq_val(arg, ignore_none=None): +def as_opt_eq_val(arg: str, ignore_none: bool | None = None) -> _ArgFormat: return _ArgFormat(lambda value: [f"{arg}={value}"], ignore_none=ignore_none) -def as_list(ignore_none=None, min_len=0, max_len=None): - def func(value): +def as_list(ignore_none: bool | None = None, min_len: int = 0, max_len: int | None = None) -> _ArgFormat: + def func(value: t.Any) -> list[t.Any]: value = _ensure_list(value) if len(value) < min_len: raise ValueError(f"Parameter must have at least {min_len} element(s)") @@ -79,17 +89,21 @@ def as_list(ignore_none=None, min_len=0, max_len=None): return _ArgFormat(func, ignore_none=ignore_none) -def as_fixed(*args): +def as_fixed(*args: t.Any) -> _ArgFormat: if len(args) == 1 and is_sequence(args[0]): args = args[0] return _ArgFormat(lambda value: _ensure_list(args), ignore_none=False, ignore_missing_value=True) -def as_func(func, ignore_none=None): +def as_func(func: ArgFormatType, ignore_none: bool | None = None) -> _ArgFormat: return _ArgFormat(func, ignore_none=ignore_none) -def as_map(_map, default=None, ignore_none=None): +def as_map( + _map: Mapping[t.Any, Sequence[t.Any] | t.Any], + default: Sequence[t.Any] | t.Any | None = None, + ignore_none: bool | None = None, +) -> _ArgFormat: if default is None: default = [] return _ArgFormat(lambda value: _ensure_list(_map.get(value, default)), ignore_none=ignore_none) @@ -126,5 +140,5 @@ def stack(fmt): return wrapper -def is_argformat(fmt): +def is_argformat(fmt: object) -> t.TypeGuard[_ArgFormat]: return isinstance(fmt, _ArgFormat) diff --git a/plugins/module_utils/consul.py b/plugins/module_utils/consul.py index 1d405d20ae..f3f18bacf6 100644 --- a/plugins/module_utils/consul.py +++ b/plugins/module_utils/consul.py @@ -14,6 +14,9 @@ from urllib.parse import urlencode from ansible.module_utils.urls import open_url +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + def get_consul_url(configuration): return f"{configuration.scheme}://{configuration.host}:{configuration.port}/v1" @@ -120,7 +123,7 @@ class _ConsulModule: operational_attributes: set[str] = set() params: dict[str, t.Any] = {} - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self._module = module self.params = _normalize_params(module.params, module.argument_spec) self.api_params = { diff --git a/plugins/module_utils/csv.py b/plugins/module_utils/csv.py index 3a6194fd01..99c03ad825 100644 --- a/plugins/module_utils/csv.py +++ b/plugins/module_utils/csv.py @@ -6,11 +6,26 @@ from __future__ import annotations import csv +import typing as t from io import StringIO from ansible.module_utils.common.text.converters import to_native +if t.TYPE_CHECKING: + from collections.abc import Sequence + + class DialectParamsOrNone(t.TypedDict): + delimiter: t.NotRequired[str | None] + doublequote: t.NotRequired[bool | None] + escapechar: t.NotRequired[str | None] + lineterminator: t.NotRequired[str | None] + quotechar: t.NotRequired[str | None] + quoting: t.NotRequired[int | None] + skipinitialspace: t.NotRequired[bool | None] + strict: t.NotRequired[bool | None] + + class CustomDialectFailureError(Exception): pass @@ -22,7 +37,7 @@ class DialectNotAvailableError(Exception): CSVError = csv.Error -def initialize_dialect(dialect, **kwargs): +def initialize_dialect(dialect: str, **kwargs: t.Unpack[DialectParamsOrNone]) -> str: # Add Unix dialect from Python 3 class unix_dialect(csv.Dialect): """Describe the usual properties of Unix-generated CSV files.""" @@ -43,7 +58,7 @@ def initialize_dialect(dialect, **kwargs): dialect_params = {k: v for k, v in kwargs.items() if v is not None} if dialect_params: try: - csv.register_dialect("custom", dialect, **dialect_params) + csv.register_dialect("custom", dialect, **dialect_params) # type: ignore except TypeError as e: raise CustomDialectFailureError(f"Unable to create custom dialect: {e}") from e dialect = "custom" @@ -51,7 +66,7 @@ def initialize_dialect(dialect, **kwargs): return dialect -def read_csv(data, dialect, fieldnames=None): +def read_csv(data: str, dialect: str, fieldnames: Sequence[str] | None = None) -> csv.DictReader: BOM = "\ufeff" data = to_native(data, errors="surrogate_or_strict") if data.startswith(BOM): diff --git a/plugins/module_utils/database.py b/plugins/module_utils/database.py index 0557ebe731..7cb0983b4e 100644 --- a/plugins/module_utils/database.py +++ b/plugins/module_utils/database.py @@ -14,6 +14,10 @@ from __future__ import annotations import re +import typing as t + +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule # Input patterns for is_input_dangerous function: @@ -162,7 +166,7 @@ def is_input_dangerous(string): return any(pattern.search(string) for pattern in (PATTERN_1, PATTERN_2, PATTERN_3)) -def check_input(module, *args): +def check_input(module: AnsibleModule, *args) -> None: """Wrapper for is_input_dangerous function.""" needs_to_check = args diff --git a/plugins/module_utils/datetime.py b/plugins/module_utils/datetime.py index f6690b6d4a..bf0aff23ea 100644 --- a/plugins/module_utils/datetime.py +++ b/plugins/module_utils/datetime.py @@ -8,15 +8,15 @@ from __future__ import annotations import datetime as _datetime -def ensure_timezone_info(value): +def ensure_timezone_info(value: _datetime.datetime) -> _datetime.datetime: if value.tzinfo is not None: return value return value.astimezone(_datetime.timezone.utc) -def fromtimestamp(value): +def fromtimestamp(value: int | float) -> _datetime.datetime: return _datetime.datetime.fromtimestamp(value, tz=_datetime.timezone.utc) -def now(): +def now() -> _datetime.datetime: return _datetime.datetime.now(tz=_datetime.timezone.utc) diff --git a/plugins/module_utils/deps.py b/plugins/module_utils/deps.py index 63330bfed7..31e6dd7f2f 100644 --- a/plugins/module_utils/deps.py +++ b/plugins/module_utils/deps.py @@ -7,56 +7,60 @@ from __future__ import annotations import traceback +import typing as t from contextlib import contextmanager from ansible.module_utils.basic import missing_required_lib +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule -_deps = dict() + +_deps: dict[str, _Dependency] = dict() class _Dependency: _states = ["pending", "failure", "success"] - def __init__(self, name, reason=None, url=None, msg=None): + def __init__(self, name: str, reason: str | None = None, url: str | None = None, msg: str | None = None) -> None: self.name = name self.reason = reason self.url = url self.msg = msg self.state = 0 - self.trace = None - self.exc = None + self.trace: str | None = None + self.exc: Exception | None = None - def succeed(self): + def succeed(self) -> None: self.state = 2 - def fail(self, exc, trace): + def fail(self, exc: Exception, trace: str) -> None: self.state = 1 self.exc = exc self.trace = trace @property - def message(self): + def message(self) -> str: if self.msg: return str(self.msg) else: return missing_required_lib(self.name, reason=self.reason, url=self.url) @property - def failed(self): + def failed(self) -> bool: return self.state == 1 - def validate(self, module): + def validate(self, module: AnsibleModule) -> None: if self.failed: module.fail_json(msg=self.message, exception=self.trace) - def __str__(self): + def __str__(self) -> str: return f"" @contextmanager -def declare(name, *args, **kwargs): +def declare(name: str, *args, **kwargs) -> t.Generator[_Dependency]: dep = _Dependency(name, *args, **kwargs) try: yield dep @@ -68,7 +72,7 @@ def declare(name, *args, **kwargs): _deps[name] = dep -def _select_names(spec): +def _select_names(spec: str | None) -> list[str]: dep_names = sorted(_deps) if spec: @@ -86,14 +90,14 @@ def _select_names(spec): return dep_names -def validate(module, spec=None): +def validate(module: AnsibleModule, spec: str | None = None) -> None: for dep in _select_names(spec): _deps[dep].validate(module) -def failed(spec=None): +def failed(spec: str | None = None) -> bool: return any(_deps[d].failed for d in _select_names(spec)) -def clear(): +def clear() -> None: _deps.clear() diff --git a/plugins/module_utils/dimensiondata.py b/plugins/module_utils/dimensiondata.py index 0a7a06b78e..605e7a9c75 100644 --- a/plugins/module_utils/dimensiondata.py +++ b/plugins/module_utils/dimensiondata.py @@ -24,7 +24,6 @@ import os import re import traceback -# (TODO: remove AnsibleModule from next line!) from ansible.module_utils.basic import AnsibleModule, missing_required_lib # noqa: F401, pylint: disable=unused-import from os.path import expanduser from uuid import UUID @@ -57,7 +56,7 @@ class DimensionDataModule: The base class containing common functionality used by Dimension Data modules for Ansible. """ - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: """ Create a new DimensionDataModule. diff --git a/plugins/module_utils/django.py b/plugins/module_utils/django.py index 0f807e673a..0e3f258baf 100644 --- a/plugins/module_utils/django.py +++ b/plugins/module_utils/django.py @@ -12,7 +12,8 @@ from ansible_collections.community.general.plugins.module_utils.python_runner im from ansible_collections.community.general.plugins.module_utils.module_helper import ModuleHelper if t.TYPE_CHECKING: - from .cmd_runner_fmt import ArgFormatType + from ansible.module_utils.basic import AnsibleModule + from ansible_collections.community.general.plugins.module_utils.cmd_runner import ArgFormatter django_std_args = dict( @@ -36,7 +37,7 @@ _pks = dict( primary_keys=dict(type="list", elements="str"), ) -_django_std_arg_fmts: dict[str, ArgFormatType] = dict( +_django_std_arg_fmts: dict[str, ArgFormatter] = dict( all=cmd_runner_fmt.as_bool("--all"), app=cmd_runner_fmt.as_opt_val("--app"), apps=cmd_runner_fmt.as_list(), @@ -81,7 +82,7 @@ _args_menu = dict( class _DjangoRunner(PythonRunner): - def __init__(self, module, arg_formats=None, **kwargs): + def __init__(self, module: AnsibleModule, arg_formats=None, **kwargs) -> None: arg_fmts = dict(arg_formats) if arg_formats else {} arg_fmts.update(_django_std_arg_fmts) @@ -108,12 +109,12 @@ class _DjangoRunner(PythonRunner): class DjangoModuleHelper(ModuleHelper): module = {} django_admin_cmd: str | None = None - arg_formats: dict[str, ArgFormatType] = {} + arg_formats: dict[str, ArgFormatter] = {} django_admin_arg_order: tuple[str, ...] | str = () _django_args: list[str] = [] _check_mode_arg: str = "" - def __init__(self): + def __init__(self) -> None: self.module["argument_spec"], self.arg_formats = self._build_args( self.module.get("argument_spec", {}), self.arg_formats, *(["std"] + self._django_args) ) @@ -122,9 +123,9 @@ class DjangoModuleHelper(ModuleHelper): self.vars.command = self.django_admin_cmd @staticmethod - def _build_args(arg_spec, arg_format, *names): - res_arg_spec = {} - res_arg_fmts = {} + def _build_args(arg_spec, arg_format, *names) -> tuple[dict[str, t.Any], dict[str, ArgFormatter]]: + res_arg_spec: dict[str, t.Any] = {} + res_arg_fmts: dict[str, ArgFormatter] = {} for name in names: args, fmts = _args_menu[name] res_arg_spec = dict_merge(res_arg_spec, args) diff --git a/plugins/module_utils/gandi_livedns_api.py b/plugins/module_utils/gandi_livedns_api.py index 1b60b1b087..44368cc274 100644 --- a/plugins/module_utils/gandi_livedns_api.py +++ b/plugins/module_utils/gandi_livedns_api.py @@ -5,9 +5,13 @@ from __future__ import annotations import json +import typing as t from ansible.module_utils.urls import fetch_url +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + class GandiLiveDNSAPI: api_endpoint = "https://api.gandi.net/v5/livedns" @@ -21,7 +25,7 @@ class GandiLiveDNSAPI: attribute_map = {"record": "rrset_name", "type": "rrset_type", "ttl": "rrset_ttl", "values": "rrset_values"} - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.module = module self.api_key = module.params["api_key"] self.personal_access_token = module.params["personal_access_token"] diff --git a/plugins/module_utils/gconftool2.py b/plugins/module_utils/gconftool2.py index 9eafa553fd..c483ba3791 100644 --- a/plugins/module_utils/gconftool2.py +++ b/plugins/module_utils/gconftool2.py @@ -4,8 +4,13 @@ from __future__ import annotations +import typing as t + from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + _state_map = { "present": "--set", @@ -14,7 +19,7 @@ _state_map = { } -def gconftool2_runner(module, **kwargs): +def gconftool2_runner(module: AnsibleModule, **kwargs) -> CmdRunner: return CmdRunner( module, command="gconftool-2", diff --git a/plugins/module_utils/gio_mime.py b/plugins/module_utils/gio_mime.py index e6987c6e7d..4bd3a76857 100644 --- a/plugins/module_utils/gio_mime.py +++ b/plugins/module_utils/gio_mime.py @@ -4,10 +4,15 @@ from __future__ import annotations +import typing as t + from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule -def gio_mime_runner(module, **kwargs): + +def gio_mime_runner(module: AnsibleModule, **kwargs) -> CmdRunner: return CmdRunner( module, command=["gio"], @@ -21,8 +26,8 @@ def gio_mime_runner(module, **kwargs): ) -def gio_mime_get(runner, mime_type): - def process(rc, out, err): +def gio_mime_get(runner: CmdRunner, mime_type) -> str | None: + def process(rc, out, err) -> str | None: if err.startswith("No default applications for"): return None out = out.splitlines()[0] diff --git a/plugins/module_utils/gitlab.py b/plugins/module_utils/gitlab.py index 2222b07703..16528dce57 100644 --- a/plugins/module_utils/gitlab.py +++ b/plugins/module_utils/gitlab.py @@ -5,18 +5,19 @@ from __future__ import annotations +import traceback import typing as t +from urllib.parse import urljoin from ansible.module_utils.basic import missing_required_lib from ansible_collections.community.general.plugins.module_utils.version import LooseVersion -from urllib.parse import urljoin - -import traceback +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule -def _determine_list_all_kwargs(version) -> dict[str, t.Any]: +def _determine_list_all_kwargs(version: str) -> dict[str, t.Any]: gitlab_version = LooseVersion(version) if gitlab_version >= LooseVersion("4.0.0"): # 4.0.0 removed 'as_list' @@ -42,7 +43,7 @@ except Exception: list_all_kwargs = {} -def auth_argument_spec(spec=None): +def auth_argument_spec(spec: dict[str, t.Any] | None = None) -> dict[str, t.Any]: arg_spec = dict( ca_path=dict(type="str"), api_token=dict(type="str", no_log=True), @@ -76,7 +77,7 @@ def find_group(gitlab_instance, identifier): return group -def ensure_gitlab_package(module, min_version=None): +def ensure_gitlab_package(module: AnsibleModule, min_version=None) -> None: if not HAS_GITLAB_PACKAGE: module.fail_json( msg=missing_required_lib("python-gitlab", url="https://python-gitlab.readthedocs.io/en/stable/"), @@ -92,7 +93,7 @@ def ensure_gitlab_package(module, min_version=None): ) -def gitlab_authentication(module, min_version=None): +def gitlab_authentication(module: AnsibleModule, min_version=None) -> gitlab.Gitlab: ensure_gitlab_package(module, min_version=min_version) gitlab_url = module.params["api_url"] @@ -121,7 +122,7 @@ def gitlab_authentication(module, min_version=None): private_token=gitlab_token, oauth_token=gitlab_oauth_token, job_token=gitlab_job_token, - api_version=4, + api_version="4", ) gitlab_instance.auth() except (gitlab.exceptions.GitlabAuthenticationError, gitlab.exceptions.GitlabGetError) as e: @@ -137,7 +138,7 @@ def gitlab_authentication(module, min_version=None): return gitlab_instance -def filter_returned_variables(gitlab_variables): +def filter_returned_variables(gitlab_variables) -> list[dict[str, t.Any]]: # pop properties we don't know existing_variables = [dict(x.attributes) for x in gitlab_variables] KNOWN = [ @@ -158,9 +159,11 @@ def filter_returned_variables(gitlab_variables): return existing_variables -def vars_to_variables(vars, module): +def vars_to_variables( + vars: dict[str, str | int | float | dict[str, t.Any]], module: AnsibleModule +) -> list[dict[str, t.Any]]: # transform old vars to new variables structure - variables = list() + variables = [] for item, value in vars.items(): if isinstance(value, (str, int, float)): variables.append( diff --git a/plugins/module_utils/heroku.py b/plugins/module_utils/heroku.py index ed9f8c59cd..1ac213872e 100644 --- a/plugins/module_utils/heroku.py +++ b/plugins/module_utils/heroku.py @@ -5,9 +5,13 @@ from __future__ import annotations import traceback +import typing as t from ansible.module_utils.basic import env_fallback, missing_required_lib +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + HAS_HEROKU = False HEROKU_IMP_ERR = None try: @@ -19,17 +23,17 @@ except ImportError: class HerokuHelper: - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.module = module self.check_lib() self.api_key = module.params["api_key"] - def check_lib(self): + def check_lib(self) -> None: if not HAS_HEROKU: self.module.fail_json(msg=missing_required_lib("heroku3"), exception=HEROKU_IMP_ERR) @staticmethod - def heroku_argument_spec(): + def heroku_argument_spec() -> dict[str, t.Any]: return dict( api_key=dict(fallback=(env_fallback, ["HEROKU_API_KEY", "TF_VAR_HEROKU_API_KEY"]), type="str", no_log=True) ) diff --git a/plugins/module_utils/homebrew.py b/plugins/module_utils/homebrew.py index 2908bee72e..06ca5fb4f3 100644 --- a/plugins/module_utils/homebrew.py +++ b/plugins/module_utils/homebrew.py @@ -7,9 +7,13 @@ from __future__ import annotations import os import re +import typing as t + +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule -def _create_regex_group_complement(s): +def _create_regex_group_complement(s: str) -> re.Pattern: lines = (line.strip() for line in s.split("\n") if line.strip()) chars = [_f for _f in (line.split("#")[0].strip() for line in lines) if _f] group = rf"[^{''.join(chars)}]" @@ -52,7 +56,7 @@ class HomebrewValidate: # class validations -------------------------------------------- {{{ @classmethod - def valid_path(cls, path): + def valid_path(cls, path: list[str] | str) -> bool: """ `path` must be one of: - list of paths @@ -77,7 +81,7 @@ class HomebrewValidate: return all(cls.valid_brew_path(path_) for path_ in paths) @classmethod - def valid_brew_path(cls, brew_path): + def valid_brew_path(cls, brew_path: str | None) -> bool: """ `brew_path` must be one of: - None @@ -95,7 +99,7 @@ class HomebrewValidate: return isinstance(brew_path, str) and not cls.INVALID_BREW_PATH_REGEX.search(brew_path) @classmethod - def valid_package(cls, package): + def valid_package(cls, package: str | None) -> bool: """A valid package is either None or alphanumeric.""" if package is None: @@ -104,8 +108,7 @@ class HomebrewValidate: return isinstance(package, str) and not cls.INVALID_PACKAGE_REGEX.search(package) -def parse_brew_path(module): - # type: (...) -> str +def parse_brew_path(module: AnsibleModule) -> str: """Attempt to find the Homebrew executable path. Requires: diff --git a/plugins/module_utils/hwc_utils.py b/plugins/module_utils/hwc_utils.py index dfb9069d28..64ce4b59ea 100644 --- a/plugins/module_utils/hwc_utils.py +++ b/plugins/module_utils/hwc_utils.py @@ -7,6 +7,7 @@ from __future__ import annotations import re import time import traceback +import typing as t THIRD_LIBRARIES_IMP_ERR = None try: @@ -24,37 +25,37 @@ from ansible.module_utils.common.text.converters import to_text class HwcModuleException(Exception): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__() self._message = message - def __str__(self): + def __str__(self) -> str: return f"[HwcClientException] message={self._message}" class HwcClientException(Exception): - def __init__(self, code, message): + def __init__(self, code: int, message: str) -> None: super().__init__() self._code = code self._message = message - def __str__(self): + def __str__(self) -> str: msg = f" code={self._code}," if self._code != 0 else "" return f"[HwcClientException]{msg} message={self._message}" class HwcClientException404(HwcClientException): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(404, message) - def __str__(self): + def __str__(self) -> str: return f"[HwcClientException404] message={self._message}" def session_method_wrapper(f): - def _wrap(self, url, *args, **kwargs): + def _wrap(self, url: str, *args, **kwargs): try: url = self.endpoint + url r = f(self, url, *args, **kwargs) @@ -91,7 +92,7 @@ def session_method_wrapper(f): class _ServiceClient: - def __init__(self, client, endpoint, product): + def __init__(self, client, endpoint: str, product): self._client = client self._endpoint = endpoint self._default_header = { @@ -100,30 +101,30 @@ class _ServiceClient: } @property - def endpoint(self): + def endpoint(self) -> str: return self._endpoint @endpoint.setter - def endpoint(self, e): + def endpoint(self, e: str) -> None: self._endpoint = e @session_method_wrapper - def get(self, url, body=None, header=None, timeout=None): + def get(self, url: str, body=None, header: dict[str, t.Any] | None = None, timeout=None): return self._client.get(url, json=body, timeout=timeout, headers=self._header(header)) @session_method_wrapper - def post(self, url, body=None, header=None, timeout=None): + def post(self, url: str, body=None, header: dict[str, t.Any] | None = None, timeout=None): return self._client.post(url, json=body, timeout=timeout, headers=self._header(header)) @session_method_wrapper - def delete(self, url, body=None, header=None, timeout=None): + def delete(self, url: str, body=None, header: dict[str, t.Any] | None = None, timeout=None): return self._client.delete(url, json=body, timeout=timeout, headers=self._header(header)) @session_method_wrapper - def put(self, url, body=None, header=None, timeout=None): + def put(self, url: str, body=None, header: dict[str, t.Any] | None = None, timeout=None): return self._client.put(url, json=body, timeout=timeout, headers=self._header(header)) - def _header(self, header): + def _header(self, header: dict[str, t.Any] | None) -> dict[str, t.Any]: if header and isinstance(header, dict): for k, v in self._default_header.items(): if k not in header: @@ -135,18 +136,18 @@ class _ServiceClient: class Config: - def __init__(self, module, product): + def __init__(self, module: AnsibleModule, product) -> None: self._project_client = None self._domain_client = None self._module = module self._product = product - self._endpoints = {} + self._endpoints: dict[str, t.Any] = {} self._validate() self._gen_provider_client() @property - def module(self): + def module(self) -> AnsibleModule: return self._module def client(self, region, service_type, service_level): @@ -380,7 +381,7 @@ def navigate_value(data, index, array_index=None): return d -def build_path(module, path, kv=None): +def build_path(module: AnsibleModule, path, kv=None): if kv is None: kv = dict() @@ -400,7 +401,7 @@ def build_path(module, path, kv=None): return path.format(**v) -def get_region(module): +def get_region(module: AnsibleModule): if module.params["region"]: return module.params["region"] diff --git a/plugins/module_utils/ibm_sa_utils.py b/plugins/module_utils/ibm_sa_utils.py index c6c583e9cf..df339ec6fb 100644 --- a/plugins/module_utils/ibm_sa_utils.py +++ b/plugins/module_utils/ibm_sa_utils.py @@ -7,10 +7,14 @@ from __future__ import annotations import traceback +import typing as t from functools import wraps from ansible.module_utils.basic import missing_required_lib +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + PYXCLI_INSTALLED = True PYXCLI_IMP_ERR = None try: @@ -49,7 +53,7 @@ def xcli_wrapper(func): """Catch xcli errors and return a proper message""" @wraps(func) - def wrapper(module, *args, **kwargs): + def wrapper(module: AnsibleModule, *args, **kwargs): try: return func(module, *args, **kwargs) except errors.CommandExecutionError as e: @@ -59,7 +63,7 @@ def xcli_wrapper(func): @xcli_wrapper -def connect_ssl(module): +def connect_ssl(module: AnsibleModule): endpoints = module.params["endpoints"] username = module.params["username"] password = module.params["password"] @@ -72,7 +76,7 @@ def connect_ssl(module): module.fail_json(msg=f"Connection with Spectrum Accelerate system has failed: {e}.") -def spectrum_accelerate_spec(): +def spectrum_accelerate_spec() -> dict[str, t.Any]: """Return arguments spec for AnsibleModule""" return dict( endpoints=dict(required=True), @@ -82,7 +86,7 @@ def spectrum_accelerate_spec(): @xcli_wrapper -def execute_pyxcli_command(module, xcli_command, xcli_client): +def execute_pyxcli_command(module: AnsibleModule, xcli_command, xcli_client): pyxcli_args = build_pyxcli_command(module.params) getattr(xcli_client.cmd, xcli_command)(**(pyxcli_args)) return True @@ -99,6 +103,6 @@ def build_pyxcli_command(fields): return pyxcli_args -def is_pyxcli_installed(module): +def is_pyxcli_installed(module: AnsibleModule) -> None: if not PYXCLI_INSTALLED: module.fail_json(msg=missing_required_lib("pyxcli"), exception=PYXCLI_IMP_ERR) diff --git a/plugins/module_utils/ilo_redfish_utils.py b/plugins/module_utils/ilo_redfish_utils.py index c76477d3e0..64f746989e 100644 --- a/plugins/module_utils/ilo_redfish_utils.py +++ b/plugins/module_utils/ilo_redfish_utils.py @@ -4,13 +4,15 @@ from __future__ import annotations -from ansible_collections.community.general.plugins.module_utils.redfish_utils import RedfishUtils import time +import typing as t + +from ansible_collections.community.general.plugins.module_utils.redfish_utils import RedfishUtils class iLORedfishUtils(RedfishUtils): - def get_ilo_sessions(self): - result = {} + def get_ilo_sessions(self) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} # listing all users has always been slower than other operations, why? session_list = [] sessions_results = [] @@ -48,8 +50,8 @@ class iLORedfishUtils(RedfishUtils): result["ret"] = True return result - def set_ntp_server(self, mgr_attributes): - result = {} + def set_ntp_server(self, mgr_attributes) -> dict[str, t.Any]: + result: dict[str, t.Any] = {} setkey = mgr_attributes["mgr_attr_name"] nic_info = self.get_manager_ethernet_uri() @@ -60,7 +62,7 @@ class iLORedfishUtils(RedfishUtils): return response result["ret"] = True data = response["data"] - payload = {"DHCPv4": {"UseNTPServers": ""}} + payload: dict[str, t.Any] = {"DHCPv4": {"UseNTPServers": ""}} if data["DHCPv4"]["UseNTPServers"]: payload["DHCPv4"]["UseNTPServers"] = False @@ -97,7 +99,7 @@ class iLORedfishUtils(RedfishUtils): return {"ret": True, "changed": True, "msg": f"Modified {mgr_attributes['mgr_attr_name']}"} - def set_time_zone(self, attr): + def set_time_zone(self, attr) -> dict[str, t.Any]: key = attr["mgr_attr_name"] uri = f"{self.manager_uri}DateTime/" @@ -124,7 +126,7 @@ class iLORedfishUtils(RedfishUtils): return {"ret": True, "changed": True, "msg": f"Modified {attr['mgr_attr_name']}"} - def set_dns_server(self, attr): + def set_dns_server(self, attr) -> dict[str, t.Any]: key = attr["mgr_attr_name"] nic_info = self.get_manager_ethernet_uri() uri = nic_info["nic_addr"] @@ -148,7 +150,7 @@ class iLORedfishUtils(RedfishUtils): return {"ret": True, "changed": True, "msg": f"Modified {attr['mgr_attr_name']}"} - def set_domain_name(self, attr): + def set_domain_name(self, attr) -> dict[str, t.Any]: key = attr["mgr_attr_name"] nic_info = self.get_manager_ethernet_uri() @@ -160,7 +162,7 @@ class iLORedfishUtils(RedfishUtils): data = response["data"] - payload = {"DHCPv4": {"UseDomainName": ""}} + payload: dict[str, t.Any] = {"DHCPv4": {"UseDomainName": ""}} if data["DHCPv4"]["UseDomainName"]: payload["DHCPv4"]["UseDomainName"] = False @@ -185,7 +187,7 @@ class iLORedfishUtils(RedfishUtils): return response return {"ret": True, "changed": True, "msg": f"Modified {attr['mgr_attr_name']}"} - def set_wins_registration(self, mgrattr): + def set_wins_registration(self, mgrattr) -> dict[str, t.Any]: Key = mgrattr["mgr_attr_name"] nic_info = self.get_manager_ethernet_uri() @@ -198,7 +200,7 @@ class iLORedfishUtils(RedfishUtils): return response return {"ret": True, "changed": True, "msg": f"Modified {mgrattr['mgr_attr_name']}"} - def get_server_poststate(self): + def get_server_poststate(self) -> dict[str, t.Any]: # Get server details response = self.get_request(self.root_uri + self.systems_uri) if not response["ret"]: @@ -210,7 +212,7 @@ class iLORedfishUtils(RedfishUtils): else: return {"ret": True, "server_poststate": server_data["Oem"]["Hp"]["PostState"]} - def wait_for_ilo_reboot_completion(self, polling_interval=60, max_polling_time=1800): + def wait_for_ilo_reboot_completion(self, polling_interval=60, max_polling_time=1800) -> dict[str, t.Any]: # This method checks if OOB controller reboot is completed time.sleep(10) diff --git a/plugins/module_utils/influxdb.py b/plugins/module_utils/influxdb.py index 6cd864782d..3a9677f7c4 100644 --- a/plugins/module_utils/influxdb.py +++ b/plugins/module_utils/influxdb.py @@ -5,11 +5,15 @@ from __future__ import annotations import traceback +import typing as t from ansible.module_utils.basic import missing_required_lib from ansible_collections.community.general.plugins.module_utils.version import LooseVersion +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + REQUESTS_IMP_ERR = None try: import requests.exceptions # noqa: F401, pylint: disable=unused-import @@ -32,7 +36,7 @@ except ImportError: class InfluxDb: - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.module = module self.params = self.module.params self.check_lib() @@ -43,7 +47,7 @@ class InfluxDb: self.password = self.params["password"] self.database_name = self.params.get("database_name") - def check_lib(self): + def check_lib(self) -> None: if not HAS_REQUESTS: self.module.fail_json(msg=missing_required_lib("requests"), exception=REQUESTS_IMP_ERR) @@ -51,7 +55,7 @@ class InfluxDb: self.module.fail_json(msg=missing_required_lib("influxdb"), exception=INFLUXDB_IMP_ERR) @staticmethod - def influxdb_argument_spec(): + def influxdb_argument_spec() -> dict[str, t.Any]: return dict( hostname=dict(type="str", default="localhost"), port=dict(type="int", default=8086), @@ -67,7 +71,7 @@ class InfluxDb: udp_port=dict(type="int", default=4444), ) - def connect_to_influxdb(self): + def connect_to_influxdb(self) -> InfluxDBClient: args = dict( host=self.hostname, port=self.port, diff --git a/plugins/module_utils/ipa.py b/plugins/module_utils/ipa.py index ef8e11bb62..ed8c180f8c 100644 --- a/plugins/module_utils/ipa.py +++ b/plugins/module_utils/ipa.py @@ -13,17 +13,21 @@ from __future__ import annotations import json import os +import re import socket import uuid - -import re -from ansible.module_utils.common.text.converters import to_bytes, to_text -from ansible.module_utils.urls import fetch_url, HAS_GSSAPI -from ansible.module_utils.basic import env_fallback, AnsibleFallbackNotFound +import typing as t from urllib.parse import quote +from ansible.module_utils.basic import env_fallback, AnsibleFallbackNotFound +from ansible.module_utils.common.text.converters import to_bytes, to_text +from ansible.module_utils.urls import fetch_url, HAS_GSSAPI -def _env_then_dns_fallback(*args, **kwargs): +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + +def _env_then_dns_fallback(*args, **kwargs) -> str: """Load value from environment or DNS in that order""" try: result = env_fallback(*args, **kwargs) @@ -41,7 +45,7 @@ def _env_then_dns_fallback(*args, **kwargs): class IPAClient: - def __init__(self, module, host, port, protocol): + def __init__(self, module: AnsibleModule, host, port, protocol): self.host = host self.port = port self.protocol = protocol @@ -50,10 +54,10 @@ class IPAClient: self.timeout = module.params.get("ipa_timeout") self.use_gssapi = False - def get_base_url(self): + def get_base_url(self) -> str: return f"{self.protocol}://{self.host}/ipa" - def get_json_url(self): + def get_json_url(self) -> str: return f"{self.get_base_url()}/session/json" def login(self, username, password): @@ -98,7 +102,7 @@ class IPAClient: {"referer": self.get_base_url(), "Content-Type": "application/json", "Accept": "application/json"} ) - def _fail(self, msg, e): + def _fail(self, msg: str, e) -> t.NoReturn: if "message" in e: err_string = e.get("message") else: @@ -205,7 +209,7 @@ class IPAClient: return changed -def ipa_argument_spec(): +def ipa_argument_spec() -> dict[str, t.Any]: return dict( ipa_prot=dict(type="str", default="https", choices=["http", "https"], fallback=(env_fallback, ["IPA_PROT"])), ipa_host=dict(type="str", default="ipa.example.com", fallback=(_env_then_dns_fallback, ["IPA_HOST"])), diff --git a/plugins/module_utils/jenkins.py b/plugins/module_utils/jenkins.py index 9c9f16c969..810128dab8 100644 --- a/plugins/module_utils/jenkins.py +++ b/plugins/module_utils/jenkins.py @@ -10,7 +10,7 @@ import os import time -def download_updates_file(updates_expiration): +def download_updates_file(updates_expiration: int | float) -> tuple[str, bool]: updates_filename = "jenkins-plugin-cache.json" updates_dir = os.path.expanduser("~/.ansible/tmp") updates_file = os.path.join(updates_dir, updates_filename) diff --git a/plugins/module_utils/ldap.py b/plugins/module_utils/ldap.py index 430e51b1d2..41139d1882 100644 --- a/plugins/module_utils/ldap.py +++ b/plugins/module_utils/ldap.py @@ -9,6 +9,10 @@ from __future__ import annotations import re import traceback +import typing as t + +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule try: import ldap @@ -26,7 +30,7 @@ except ImportError: HAS_LDAP = False -def gen_specs(**specs): +def gen_specs(**specs: t.Any) -> dict[str, t.Any]: specs.update( { "bind_dn": dict(), @@ -47,12 +51,12 @@ def gen_specs(**specs): return specs -def ldap_required_together(): +def ldap_required_together() -> list[list[str]]: return [["client_cert", "client_key"]] class LdapGeneric: - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: # Shortcuts self.module = module self.bind_dn = self.module.params["bind_dn"] @@ -76,11 +80,11 @@ class LdapGeneric: else: self.dn = self.module.params["dn"] - def fail(self, msg, exn): + def fail(self, msg: str, exn: str | Exception) -> t.NoReturn: self.module.fail_json(msg=msg, details=f"{exn}", exception=traceback.format_exc()) - def _find_dn(self): - dn = self.module.params["dn"] + def _find_dn(self) -> str: + dn: str = self.module.params["dn"] explode_dn = ldap.dn.explode_dn(dn) @@ -130,7 +134,7 @@ class LdapGeneric: return connection - def _xorder_dn(self): + def _xorder_dn(self) -> bool: # match X_ORDERed DNs regex = r".+\{\d+\}.+" explode_dn = ldap.dn.explode_dn(self.module.params["dn"]) diff --git a/plugins/module_utils/linode.py b/plugins/module_utils/linode.py index b5279c85c1..95eeef8aa6 100644 --- a/plugins/module_utils/linode.py +++ b/plugins/module_utils/linode.py @@ -14,6 +14,6 @@ from __future__ import annotations from ansible.module_utils.ansible_release import __version__ as ansible_version -def get_user_agent(module): +def get_user_agent(module: str) -> str: """Retrieve a user-agent to send with LinodeClient requests.""" return f"Ansible-{module}/{ansible_version}" diff --git a/plugins/module_utils/locale_gen.py b/plugins/module_utils/locale_gen.py index 22661ae261..f43337f08b 100644 --- a/plugins/module_utils/locale_gen.py +++ b/plugins/module_utils/locale_gen.py @@ -4,10 +4,15 @@ from __future__ import annotations +import typing as t + from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule -def locale_runner(module): + +def locale_runner(module: AnsibleModule) -> CmdRunner: runner = CmdRunner( module, command=["locale", "-a"], @@ -16,7 +21,7 @@ def locale_runner(module): return runner -def locale_gen_runner(module): +def locale_gen_runner(module: AnsibleModule) -> CmdRunner: runner = CmdRunner( module, command="locale-gen", diff --git a/plugins/module_utils/lxd.py b/plugins/module_utils/lxd.py index 95644ff860..7513acd69e 100644 --- a/plugins/module_utils/lxd.py +++ b/plugins/module_utils/lxd.py @@ -6,10 +6,11 @@ from __future__ import annotations import http.client as http_client +import json import os import socket import ssl -import json +import typing as t from urllib.parse import urlparse from ansible.module_utils.urls import generic_urlparse @@ -20,7 +21,7 @@ HTTPSConnection = http_client.HTTPSConnection class UnixHTTPConnection(HTTPConnection): - def __init__(self, path): + def __init__(self, path: str) -> None: HTTPConnection.__init__(self, "localhost") self.path = path @@ -31,33 +32,34 @@ class UnixHTTPConnection(HTTPConnection): class LXDClientException(Exception): - def __init__(self, msg, **kwargs): + def __init__(self, msg: str, **kwargs) -> None: self.msg = msg self.kwargs = kwargs class LXDClient: def __init__( - self, url, key_file=None, cert_file=None, debug=False, server_cert_file=None, server_check_hostname=True - ): + self, + url: str, + key_file: str | None = None, + cert_file: str | None = None, + debug: bool = False, + server_cert_file: str | None = None, + server_check_hostname: bool = True, + ) -> None: """LXD Client. :param url: The URL of the LXD server. (e.g. unix:/var/lib/lxd/unix.socket or https://127.0.0.1) - :type url: ``str`` :param key_file: The path of the client certificate key file. - :type key_file: ``str`` :param cert_file: The path of the client certificate file. - :type cert_file: ``str`` :param debug: The debug flag. The request and response are stored in logs when debug is true. - :type debug: ``bool`` :param server_cert_file: The path of the server certificate file. - :type server_cert_file: ``str`` :param server_check_hostname: Whether to check the server's hostname as part of TLS verification. - :type debug: ``bool`` """ self.url = url self.debug = debug - self.logs = [] + self.logs: list[dict[str, t.Any]] = [] + self.connection: UnixHTTPConnection | HTTPSConnection if url.startswith("https:"): self.cert_file = cert_file self.key_file = key_file @@ -67,7 +69,7 @@ class LXDClient: # Check that the received cert is signed by the provided server_cert_file ctx.load_verify_locations(cafile=server_cert_file) ctx.check_hostname = server_check_hostname - ctx.load_cert_chain(cert_file, keyfile=key_file) + ctx.load_cert_chain(cert_file, keyfile=key_file) # type: ignore # TODO! self.connection = HTTPSConnection(parts.get("netloc"), context=ctx) elif url.startswith("unix:"): unix_socket_path = url[len("unix:") :] @@ -75,7 +77,7 @@ class LXDClient: else: raise LXDClientException("URL scheme must be unix: or https:") - def do(self, method, url, body_json=None, ok_error_codes=None, timeout=None, wait_for_container=None): + def do(self, method: str, url: str, body_json=None, ok_error_codes=None, timeout=None, wait_for_container=None): resp_json = self._send_request(method, url, body_json=body_json, ok_error_codes=ok_error_codes, timeout=timeout) if resp_json["type"] == "async": url = f"{resp_json['operation']}/wait" @@ -91,7 +93,7 @@ class LXDClient: body_json = {"type": "client", "password": trust_password} return self._send_request("POST", "/1.0/certificates", body_json=body_json) - def _send_request(self, method, url, body_json=None, ok_error_codes=None, timeout=None): + def _send_request(self, method: str, url: str, body_json=None, ok_error_codes=None, timeout=None): try: body = json.dumps(body_json) self.connection.request(method, url, body=body) @@ -133,9 +135,9 @@ class LXDClient: return err -def default_key_file(): +def default_key_file() -> str: return os.path.expanduser("~/.config/lxc/client.key") -def default_cert_file(): +def default_cert_file() -> str: return os.path.expanduser("~/.config/lxc/client.crt") diff --git a/plugins/module_utils/manageiq.py b/plugins/module_utils/manageiq.py index 707dcf0601..3936079f21 100644 --- a/plugins/module_utils/manageiq.py +++ b/plugins/module_utils/manageiq.py @@ -15,9 +15,13 @@ from __future__ import annotations import os import traceback +import typing as t from ansible.module_utils.basic import missing_required_lib +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + CLIENT_IMP_ERR = None try: from manageiq_client.api import ManageIQClient @@ -28,7 +32,7 @@ except ImportError: HAS_CLIENT = False -def manageiq_argument_spec(): +def manageiq_argument_spec() -> dict[str, t.Any]: options = dict( url=dict(default=os.environ.get("MIQ_URL", None)), username=dict(default=os.environ.get("MIQ_USERNAME", None)), @@ -43,27 +47,28 @@ def manageiq_argument_spec(): ) -def check_client(module): +def check_client(module: AnsibleModule) -> None: if not HAS_CLIENT: module.fail_json(msg=missing_required_lib("manageiq-client"), exception=CLIENT_IMP_ERR) -def validate_connection_params(module): - params = module.params["manageiq_connection"] +def validate_connection_params(module: AnsibleModule) -> dict[str, t.Any]: + params: dict[str, t.Any] = module.params["manageiq_connection"] error_str = "missing required argument: manageiq_connection[{}]" - url = params["url"] - token = params["token"] - username = params["username"] - password = params["password"] + url: str | None = params["url"] + token: str | None = params["token"] + username: str | None = params["username"] + password: str | None = params["password"] if (url and username and password) or (url and token): return params for arg in ["url", "username", "password"]: if params[arg] in (None, ""): module.fail_json(msg=error_str.format(arg)) + raise AssertionError("should be unreachable") -def manageiq_entities(): +def manageiq_entities() -> dict[str, str]: return { "provider": "providers", "host": "hosts", @@ -87,7 +92,7 @@ class ManageIQ: class encapsulating ManageIQ API client. """ - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: # handle import errors check_client(module) @@ -111,7 +116,7 @@ class ManageIQ: self.module.fail_json(msg=f"failed to open connection ({url}): {e}") @property - def module(self): + def module(self) -> AnsibleModule: """Ansible module module Returns: @@ -120,7 +125,7 @@ class ManageIQ: return self._module @property - def api_url(self): + def api_url(self) -> str: """Base ManageIQ API Returns: @@ -192,7 +197,7 @@ class ManageIQPolicies: Object to execute policies management operations of manageiq resources. """ - def __init__(self, manageiq, resource_type, resource_id): + def __init__(self, manageiq: ManageIQ, resource_type, resource_id): self.manageiq = manageiq self.module = self.manageiq.module @@ -330,7 +335,7 @@ class ManageIQTags: Object to execute tags management operations of manageiq resources. """ - def __init__(self, manageiq, resource_type, resource_id): + def __init__(self, manageiq: ManageIQ, resource_type, resource_id): self.manageiq = manageiq self.module = self.manageiq.module diff --git a/plugins/module_utils/ocapi_utils.py b/plugins/module_utils/ocapi_utils.py index 747e18c366..11b5787cfb 100644 --- a/plugins/module_utils/ocapi_utils.py +++ b/plugins/module_utils/ocapi_utils.py @@ -8,12 +8,16 @@ from __future__ import annotations import json import os import uuid +import typing as t from urllib.error import URLError, HTTPError from urllib.parse import urlparse from ansible.module_utils.urls import open_url from ansible.module_utils.common.text.converters import to_native +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + GET_HEADERS = {"accept": "application/json"} PUT_HEADERS = {"content-type": "application/json", "accept": "application/json"} @@ -24,7 +28,7 @@ HEALTH_OK = 5 class OcapiUtils: - def __init__(self, creds, base_uri, proxy_slot_number, timeout, module): + def __init__(self, creds, base_uri, proxy_slot_number, timeout, module: AnsibleModule) -> None: self.root_uri = base_uri self.proxy_slot_number = proxy_slot_number self.creds = creds diff --git a/plugins/module_utils/onepassword.py b/plugins/module_utils/onepassword.py index 6567d4bd1c..7eabbc7eff 100644 --- a/plugins/module_utils/onepassword.py +++ b/plugins/module_utils/onepassword.py @@ -14,11 +14,11 @@ class OnePasswordConfig: "~/.config/.op/config", ) - def __init__(self): + def __init__(self) -> None: self._config_file_path = "" @property - def config_file_path(self): + def config_file_path(self) -> str | None: if self._config_file_path: return self._config_file_path @@ -27,3 +27,5 @@ class OnePasswordConfig: if os.path.exists(realpath): self._config_file_path = realpath return self._config_file_path + + return None diff --git a/plugins/module_utils/online.py b/plugins/module_utils/online.py index bf5bcb4725..adf5f66ae4 100644 --- a/plugins/module_utils/online.py +++ b/plugins/module_utils/online.py @@ -6,12 +6,16 @@ from __future__ import annotations import json import sys +import typing as t from ansible.module_utils.basic import env_fallback from ansible.module_utils.urls import fetch_url +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule -def online_argument_spec(): + +def online_argument_spec() -> dict[str, t.Any]: return dict( api_token=dict( required=True, @@ -28,7 +32,7 @@ def online_argument_spec(): class OnlineException(Exception): - def __init__(self, message): + def __init__(self, message: str) -> None: self.message = message @@ -60,7 +64,7 @@ class Response: class Online: - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.module = module self.headers = { "Authorization": f"Bearer {self.module.params.get('api_token')}", diff --git a/plugins/module_utils/pacemaker.py b/plugins/module_utils/pacemaker.py index b2085eff4a..01f675f9f6 100644 --- a/plugins/module_utils/pacemaker.py +++ b/plugins/module_utils/pacemaker.py @@ -5,9 +5,14 @@ from __future__ import annotations import re +import typing as t from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + _state_map = { "present": "create", "absent": "remove", @@ -46,7 +51,7 @@ def fmt_resource_argument(value): return ["--group" if value["argument_action"] == "group" else value["argument_action"]] + value["argument_option"] -def get_pacemaker_maintenance_mode(runner): +def get_pacemaker_maintenance_mode(runner: CmdRunner) -> bool: with runner("cli_action config") as ctx: rc, out, err = ctx.run(cli_action="property") maint_mode_re = re.compile(r"maintenance-mode.*true", re.IGNORECASE) @@ -54,7 +59,7 @@ def get_pacemaker_maintenance_mode(runner): return bool(maintenance_mode_output) -def pacemaker_runner(module, **kwargs): +def pacemaker_runner(module: AnsibleModule, **kwargs) -> CmdRunner: runner_command = ["pcs"] runner = CmdRunner( module, diff --git a/plugins/module_utils/pipx.py b/plugins/module_utils/pipx.py index f4bfe26c55..093a251ac2 100644 --- a/plugins/module_utils/pipx.py +++ b/plugins/module_utils/pipx.py @@ -4,12 +4,14 @@ from __future__ import annotations - import json - +import typing as t from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + pipx_common_argspec = { "global": dict(type="bool", default=False), @@ -36,7 +38,7 @@ _state_map = dict( ) -def pipx_runner(module, command, **kwargs): +def pipx_runner(module: AnsibleModule, command, **kwargs) -> CmdRunner: arg_formats = dict( state=cmd_runner_fmt.as_map(_state_map), name=cmd_runner_fmt.as_list(), diff --git a/plugins/module_utils/pkg_req.py b/plugins/module_utils/pkg_req.py index 8b3297e6e2..06a44cd5cf 100644 --- a/plugins/module_utils/pkg_req.py +++ b/plugins/module_utils/pkg_req.py @@ -4,8 +4,13 @@ from __future__ import annotations +import typing as t + from ansible_collections.community.general.plugins.module_utils import deps +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + with deps.declare("packaging"): from packaging.requirements import Requirement @@ -13,11 +18,11 @@ with deps.declare("packaging"): class PackageRequirement: - def __init__(self, module, name): + def __init__(self, module: AnsibleModule, name: str) -> None: self.module = module self.parsed_name, self.requirement = self._parse_spec(name) - def _parse_spec(self, name): + def _parse_spec(self, name: str) -> tuple[str, Requirement | None]: """ Parse a package name that may include version specifiers using PEP 508. Returns a tuple of (name, requirement) where requirement is of type packaging.requirements.Requirement and it may be None. @@ -49,7 +54,7 @@ class PackageRequirement: except Exception as e: raise ValueError(f"Invalid package specification for '{name}': {e}") from e - def matches_version(self, version): + def matches_version(self, version: str): """ Check if a version string fulfills a version specifier. diff --git a/plugins/module_utils/puppet.py b/plugins/module_utils/puppet.py index 50c59ee142..a43eeace78 100644 --- a/plugins/module_utils/puppet.py +++ b/plugins/module_utils/puppet.py @@ -4,29 +4,32 @@ from __future__ import annotations - import os +import typing as t from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + _PUPPET_PATH_PREFIX = ["/opt/puppetlabs/bin"] -def get_facter_dir(): +def get_facter_dir() -> str: if os.getuid() == 0: return "/etc/facter/facts.d" else: return os.path.expanduser("~/.facter/facts.d") -def _puppet_cmd(module): +def _puppet_cmd(module: AnsibleModule) -> str | None: return module.get_bin_path("puppet", False, _PUPPET_PATH_PREFIX) # If the `timeout` CLI command feature is removed, # Then we could add this as a fixed param to `puppet_runner` -def ensure_agent_enabled(module): +def ensure_agent_enabled(module: AnsibleModule) -> None: runner = CmdRunner( module, command="puppet", @@ -44,7 +47,7 @@ def ensure_agent_enabled(module): module.fail_json(msg="Puppet agent state could not be determined.") -def puppet_runner(module): +def puppet_runner(module: AnsibleModule) -> CmdRunner: # Keeping backward compatibility, allow for running with the `timeout` CLI command. # If this can be replaced with ansible `timeout` parameter in playbook, # then this function could be removed. diff --git a/plugins/module_utils/python_runner.py b/plugins/module_utils/python_runner.py index aed806b175..c8780174db 100644 --- a/plugins/module_utils/python_runner.py +++ b/plugins/module_utils/python_runner.py @@ -5,31 +5,35 @@ from __future__ import annotations import os +import typing as t from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, _ensure_list +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + class PythonRunner(CmdRunner): def __init__( self, - module, + module: AnsibleModule, command, arg_formats=None, default_args_order=(), - check_rc=False, - force_lang="C", - path_prefix=None, - environ_update=None, - python="python", - venv=None, - ): + check_rc: bool = False, + force_lang: str = "C", + path_prefix: list[str] | None = None, + environ_update: dict[str, str] | None = None, + python: str = "python", + venv: str | None = None, + ) -> None: self.python = python self.venv = venv self.has_venv = venv is not None if os.path.isabs(python) or "/" in python: self.python = python - elif self.has_venv: + elif venv is not None: if path_prefix is None: path_prefix = [] path_prefix.append(os.path.join(venv, "bin")) diff --git a/plugins/module_utils/redfish_utils.py b/plugins/module_utils/redfish_utils.py index c9f44b7d0b..5904aa9eb4 100644 --- a/plugins/module_utils/redfish_utils.py +++ b/plugins/module_utils/redfish_utils.py @@ -10,6 +10,8 @@ import os import random import string import time +import typing as t + from ansible.module_utils.urls import open_url from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.converters import to_text @@ -17,6 +19,10 @@ from ansible.module_utils.common.text.converters import to_bytes from urllib.error import URLError, HTTPError from urllib.parse import urlparse +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + GET_HEADERS = {"accept": "application/json", "OData-Version": "4.0"} POST_HEADERS = {"content-type": "application/json", "accept": "application/json", "OData-Version": "4.0"} PATCH_HEADERS = {"content-type": "application/json", "accept": "application/json", "OData-Version": "4.0"} @@ -49,15 +55,15 @@ REDFISH_COMMON_ARGUMENT_SPEC = { class RedfishUtils: def __init__( self, - creds, - root_uri, + creds: dict[str, str], + root_uri: str, timeout, - module, + module: AnsibleModule, resource_id=None, - data_modification=False, - strip_etag_quotes=False, - ciphers=None, - ): + data_modification: bool = False, + strip_etag_quotes: bool = False, + ciphers: str | None = None, + ) -> None: self.root_uri = root_uri self.creds = creds self.timeout = timeout @@ -73,7 +79,7 @@ class RedfishUtils: self.validate_certs = module.params.get("validate_certs", False) self.ca_path = module.params.get("ca_path") - def _auth_params(self, headers): + def _auth_params(self, headers: dict[str, str]) -> tuple[str | None, str | None, bool]: """ Return tuple of required authentication params based on the presence of a token in the self.creds dict. If using a token, set the @@ -151,7 +157,7 @@ class RedfishUtils: resp["msg"] = f"Properties in {uri} are already set" return resp - def _request(self, uri, **kwargs): + def _request(self, uri: str, **kwargs): kwargs.setdefault("validate_certs", self.validate_certs) kwargs.setdefault("follow_redirects", "all") kwargs.setdefault("use_proxy", True) @@ -163,7 +169,9 @@ class RedfishUtils: return resp, headers # The following functions are to send GET/POST/PATCH/DELETE requests - def get_request(self, uri, override_headers=None, allow_no_resp=False, timeout=None): + def get_request( + self, uri: str, override_headers: dict[str, str] | None = None, allow_no_resp: bool = False, timeout=None + ): req_headers = dict(GET_HEADERS) if override_headers: req_headers.update(override_headers) @@ -206,7 +214,7 @@ class RedfishUtils: return {"ret": False, "msg": f"Failed GET request to '{uri}': '{e}'"} return {"ret": True, "data": data, "headers": headers, "resp": resp} - def post_request(self, uri, pyld, multipart=False): + def post_request(self, uri: str, pyld, multipart: bool = False): req_headers = dict(POST_HEADERS) username, password, basic_auth = self._auth_params(req_headers) try: @@ -251,7 +259,7 @@ class RedfishUtils: return {"ret": False, "msg": f"Failed POST request to '{uri}': '{e}'"} return {"ret": True, "data": data, "headers": headers, "resp": resp} - def patch_request(self, uri, pyld, check_pyld=False): + def patch_request(self, uri: str, pyld, check_pyld: bool = False): req_headers = dict(PATCH_HEADERS) r = self.get_request(uri) if r["ret"]: @@ -303,7 +311,7 @@ class RedfishUtils: return {"ret": False, "changed": False, "msg": f"Failed PATCH request to '{uri}': '{e}'"} return {"ret": True, "changed": True, "resp": resp, "msg": f"Modified {uri}"} - def put_request(self, uri, pyld): + def put_request(self, uri: str, pyld): req_headers = dict(PUT_HEADERS) r = self.get_request(uri) if r["ret"]: @@ -341,7 +349,7 @@ class RedfishUtils: return {"ret": False, "msg": f"Failed PUT request to '{uri}': '{e}'"} return {"ret": True, "resp": resp} - def delete_request(self, uri, pyld=None): + def delete_request(self, uri: str, pyld=None): req_headers = dict(DELETE_HEADERS) username, password, basic_auth = self._auth_params(req_headers) try: diff --git a/plugins/module_utils/redis.py b/plugins/module_utils/redis.py index dc00330b90..615dcd48cd 100644 --- a/plugins/module_utils/redis.py +++ b/plugins/module_utils/redis.py @@ -5,11 +5,15 @@ from __future__ import annotations -from ansible.module_utils.basic import missing_required_lib - import traceback +import typing as t -REDIS_IMP_ERR = None +from ansible.module_utils.basic import missing_required_lib, AnsibleModule + +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + +REDIS_IMP_ERR: str | None = None try: from redis import Redis from redis import __version__ as redis_version @@ -20,30 +24,30 @@ except ImportError: REDIS_IMP_ERR = traceback.format_exc() HAS_REDIS_PACKAGE = False +CERTIFI_IMPORT_ERROR: str | None = None try: import certifi HAS_CERTIFI_PACKAGE = True - CERTIFI_IMPORT_ERROR = None except ImportError: CERTIFI_IMPORT_ERROR = traceback.format_exc() HAS_CERTIFI_PACKAGE = False -def fail_imports(module, needs_certifi=True): - errors = [] - traceback = [] +def fail_imports(module: AnsibleModule, needs_certifi: bool = True) -> None: + errors: list[str] = [] + traceback: list[str] = [] if not HAS_REDIS_PACKAGE: errors.append(missing_required_lib("redis")) - traceback.append(REDIS_IMP_ERR) + traceback.append(REDIS_IMP_ERR) # type: ignore if not HAS_CERTIFI_PACKAGE and needs_certifi: errors.append(missing_required_lib("certifi")) - traceback.append(CERTIFI_IMPORT_ERROR) + traceback.append(CERTIFI_IMPORT_ERROR) # type: ignore if errors: module.fail_json(msg="\n".join(errors), traceback="\n".join(traceback)) -def redis_auth_argument_spec(tls_default=True): +def redis_auth_argument_spec(tls_default: bool = True) -> dict[str, t.Any]: return dict( login_host=dict( type="str", @@ -60,7 +64,7 @@ def redis_auth_argument_spec(tls_default=True): ) -def redis_auth_params(module): +def redis_auth_params(module: AnsibleModule) -> dict[str, t.Any]: login_host = module.params["login_host"] login_user = module.params["login_user"] login_password = module.params["login_password"] @@ -92,13 +96,12 @@ def redis_auth_params(module): class RedisAnsible: """Base class for Redis module""" - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.module = module self.connection = self._connect() - def _connect(self): + def _connect(self) -> Redis: try: return Redis(**redis_auth_params(self.module)) except Exception as e: self.module.fail_json(msg=f"{e}") - return None diff --git a/plugins/module_utils/rundeck.py b/plugins/module_utils/rundeck.py index a9a213446e..996ab698aa 100644 --- a/plugins/module_utils/rundeck.py +++ b/plugins/module_utils/rundeck.py @@ -6,11 +6,15 @@ from __future__ import annotations import json import traceback +import typing as t from ansible.module_utils.urls import fetch_url, url_argument_spec +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule -def api_argument_spec(): + +def api_argument_spec() -> dict[str, t.Any]: """ Creates an argument spec that can be used with any module that will be requesting content via Rundeck API @@ -27,7 +31,13 @@ def api_argument_spec(): return api_argument_spec -def api_request(module, endpoint, data=None, method="GET", content_type="application/json"): +def api_request( + module: AnsibleModule, + endpoint: str, + data: t.Any | None = None, + method: str = "GET", + content_type: str = "application/json", +) -> tuple[t.Any, dict[str, t.Any]]: """Manages Rundeck API requests via HTTP(S) :arg module: The AnsibleModule (used to get url, api_version, api_token, etc). diff --git a/plugins/module_utils/scaleway.py b/plugins/module_utils/scaleway.py index 5df64b6fdd..c7663d7b16 100644 --- a/plugins/module_utils/scaleway.py +++ b/plugins/module_utils/scaleway.py @@ -10,6 +10,7 @@ import sys import datetime import time import traceback +import typing as t from urllib.parse import urlencode from ansible.module_utils.basic import env_fallback, missing_required_lib @@ -19,6 +20,10 @@ from ansible_collections.community.general.plugins.module_utils.datetime import now, ) +if t.TYPE_CHECKING: + from collections.abc import Iterable + from ansible.module_utils.basic import AnsibleModule + SCALEWAY_SECRET_IMP_ERR: str | None = None try: from passlib.hash import argon2 @@ -29,7 +34,7 @@ except Exception: HAS_SCALEWAY_SECRET_PACKAGE = False -def scaleway_argument_spec(): +def scaleway_argument_spec() -> dict[str, t.Any]: return dict( api_token=dict( required=True, @@ -59,7 +64,7 @@ def payload_from_object(scw_object): class ScalewayException(Exception): - def __init__(self, message): + def __init__(self, message: str) -> None: self.message = message @@ -70,7 +75,7 @@ R_LINK_HEADER = r"""<[^>]+>;\srel="(first|previous|next|last)" R_RELATION = r'[^>]+)>; rel="(?Pfirst|previous|next|last)"' -def parse_pagination_link(header): +def parse_pagination_link(header: str) -> dict[str, str]: if not re.match(R_LINK_HEADER, header, re.VERBOSE): raise ScalewayException("Scaleway API answered with an invalid Link pagination header") else: @@ -86,7 +91,7 @@ def parse_pagination_link(header): return parsed_relations -def filter_sensitive_attributes(container, attributes): +def filter_sensitive_attributes(container: dict[str, t.Any], attributes: Iterable[str]) -> dict[str, t.Any]: """ WARNING: This function is effectively private, **do not use it**! It will be removed or renamed once changing its name no longer triggers a pylint bug. @@ -99,7 +104,7 @@ def filter_sensitive_attributes(container, attributes): class SecretVariables: @staticmethod - def ensure_scaleway_secret_package(module): + def ensure_scaleway_secret_package(module: AnsibleModule) -> None: if not HAS_SCALEWAY_SECRET_PACKAGE: module.fail_json( msg=missing_required_lib("passlib[argon2]", url="https://passlib.readthedocs.io/en/stable/"), @@ -169,7 +174,7 @@ class Response: class Scaleway: - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: self.module = module self.headers = { "X-Auth-Token": self.module.params.get("api_token"), @@ -224,8 +229,9 @@ class Scaleway: return Response(resp, info) @staticmethod - def get_user_agent_string(module): - return f"ansible {module.ansible_version} Python {sys.version.split(' ', 1)[0]}" + def get_user_agent_string(module: AnsibleModule) -> str: + ansible_version = module.ansible_version # type: ignore # For some reason this isn't documented in AnsibleModule + return f"ansible {ansible_version} Python {sys.version.split(' ', 1)[0]}" def get(self, path, data=None, headers=None, params=None): return self.send(method="GET", path=path, data=data, headers=headers, params=params) @@ -245,7 +251,7 @@ class Scaleway: def update(self, path, data=None, headers=None, params=None): return self.send(method="UPDATE", path=path, data=data, headers=headers, params=params) - def warn(self, x): + def warn(self, x) -> None: self.module.warn(str(x)) def fetch_state(self, resource): diff --git a/plugins/module_utils/snap.py b/plugins/module_utils/snap.py index b0b41c7bd3..f14b7a5315 100644 --- a/plugins/module_utils/snap.py +++ b/plugins/module_utils/snap.py @@ -4,8 +4,13 @@ from __future__ import annotations +import typing as t + from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + _alias_state_map = dict( present="alias", @@ -22,7 +27,7 @@ _state_map = dict( ) -def snap_runner(module, **kwargs): +def snap_runner(module: AnsibleModule, **kwargs) -> CmdRunner: runner = CmdRunner( module, "snap", @@ -47,7 +52,7 @@ def snap_runner(module, **kwargs): return runner -def get_version(runner): +def get_version(runner: CmdRunner) -> dict[str, list[str]]: with runner("version") as ctx: rc, out, err = ctx.run() return dict(x.split() for x in out.splitlines() if len(x.split()) == 2) diff --git a/plugins/module_utils/ssh.py b/plugins/module_utils/ssh.py index 83a390e4dd..b2af22212d 100644 --- a/plugins/module_utils/ssh.py +++ b/plugins/module_utils/ssh.py @@ -11,7 +11,7 @@ from __future__ import annotations import os -def determine_config_file(user, config_file): +def determine_config_file(user: str | None, config_file: str | None) -> str: if user: config_file = os.path.join(os.path.expanduser(f"~{user}"), ".ssh", "config") elif config_file is None: diff --git a/plugins/module_utils/systemd.py b/plugins/module_utils/systemd.py index 533ce6e729..27f12cf869 100644 --- a/plugins/module_utils/systemd.py +++ b/plugins/module_utils/systemd.py @@ -4,11 +4,15 @@ from __future__ import annotations +import typing as t from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule -def systemd_runner(module, command, **kwargs): + +def systemd_runner(module: AnsibleModule, command, **kwargs) -> CmdRunner: arg_formats = dict( version=cmd_runner_fmt.as_fixed("--version"), list_units=cmd_runner_fmt.as_fixed(["list-units", "--no-pager"]), diff --git a/plugins/module_utils/univention_umc.py b/plugins/module_utils/univention_umc.py index 54f1654827..887e7a6492 100644 --- a/plugins/module_utils/univention_umc.py +++ b/plugins/module_utils/univention_umc.py @@ -165,7 +165,7 @@ def ldap_search(filter, base=None, attr=None): uldap().lo.lo.abandon(msgid) -def module_by_name(module_name_): +def module_by_name(module_name_: str): """Returns an initialized UMC module, identified by the given name. The module is a module specification according to the udm commandline. @@ -202,7 +202,7 @@ def get_umc_admin_objects(): return univention.admin.objects -def umc_module_for_add(module, container_dn, superordinate=None): +def umc_module_for_add(module: str, container_dn, superordinate=None): """Returns an UMC module object prepared for creating a new entry. The module is a module specification according to the udm commandline. @@ -226,7 +226,7 @@ def umc_module_for_add(module, container_dn, superordinate=None): return obj -def umc_module_for_edit(module, object_dn, superordinate=None): +def umc_module_for_edit(module: str, object_dn, superordinate=None): """Returns an UMC module object prepared for editing an existing entry. The module is a module specification according to the udm commandline. diff --git a/plugins/module_utils/utm_utils.py b/plugins/module_utils/utm_utils.py index b3c6810215..443ee1a178 100644 --- a/plugins/module_utils/utm_utils.py +++ b/plugins/module_utils/utm_utils.py @@ -12,18 +12,19 @@ from __future__ import annotations import json +import typing as t from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.urls import fetch_url class UTMModuleConfigurationError(Exception): - def __init__(self, msg, **args): + def __init__(self, msg: str, **args): super().__init__(self, msg) self.msg = msg self.module_fail_args = args - def do_fail(self, module): + def do_fail(self, module: AnsibleModule) -> t.NoReturn: module.fail_json(msg=self.msg, other=self.module_fail_args) @@ -75,7 +76,7 @@ class UTMModule(AnsibleModule): class UTM: - def __init__(self, module, endpoint, change_relevant_keys, info_only=False): + def __init__(self, module: UTMModule, endpoint, change_relevant_keys, info_only=False): """ Initialize UTM Class :param module: The Ansible module diff --git a/plugins/module_utils/vardict.py b/plugins/module_utils/vardict.py index 195ec4d847..da01f6fc09 100644 --- a/plugins/module_utils/vardict.py +++ b/plugins/module_utils/vardict.py @@ -6,33 +6,41 @@ from __future__ import annotations import copy +import typing as t class _Variable: NOTHING = object() - def __init__(self, diff=False, output=True, change=None, fact=False, verbosity=0): + def __init__( + self, + diff: bool = False, + output: bool = True, + change: bool | None = None, + fact: bool = False, + verbosity: int = 0, + ): self.init = False - self.initial_value = None - self.value = None + self.initial_value: t.Any = None + self.value: t.Any = None - self.diff = None - self._change = None - self.output = None - self.fact = None - self._verbosity = None + self.diff: bool = None # type: ignore # will be changed in set_meta() call + self._change: bool | None = None + self.output: bool = None # type: ignore # will be changed in set_meta() call + self.fact: bool = None # type: ignore # will be changed in set_meta() call + self._verbosity: int = None # type: ignore # will be changed in set_meta() call self.set_meta(output=output, diff=diff, change=change, fact=fact, verbosity=verbosity) - def getchange(self): + def getchange(self) -> bool: return self.diff if self._change is None else self._change - def setchange(self, value): + def setchange(self, value: bool | None) -> None: self._change = value - def getverbosity(self): + def getverbosity(self) -> int: return self._verbosity - def setverbosity(self, v): + def setverbosity(self, v: int) -> None: if not (0 <= v <= 4): raise ValueError("verbosity must be an int in the range 0 to 4") self._verbosity = v @@ -40,7 +48,15 @@ class _Variable: change = property(getchange, setchange) verbosity = property(getverbosity, setverbosity) - def set_meta(self, output=None, diff=None, change=None, fact=None, initial_value=NOTHING, verbosity=None): + def set_meta( + self, + output: bool | None = None, + diff: bool | None = None, + change: bool | None = None, + fact: bool | None = None, + initial_value: t.Any = NOTHING, + verbosity: int | None = None, + ) -> None: """Set the metadata for the variable Args: @@ -64,7 +80,7 @@ class _Variable: if verbosity is not None: self.verbosity = verbosity - def as_dict(self, meta_only=False): + def as_dict(self, meta_only: bool = False) -> dict[str, t.Any]: d = { "diff": self.diff, "change": self.change, @@ -77,27 +93,27 @@ class _Variable: d["value"] = self.value return d - def set_value(self, value): + def set_value(self, value: t.Any) -> t.Self: if not self.init: self.initial_value = copy.deepcopy(value) self.init = True self.value = value return self - def is_visible(self, verbosity): + def is_visible(self, verbosity: int) -> bool: return self.verbosity <= verbosity @property - def has_changed(self): + def has_changed(self) -> bool: return self.change and (self.initial_value != self.value) @property - def diff_result(self): + def diff_result(self) -> dict[str, t.Any] | None: if self.diff and self.has_changed: return {"before": self.initial_value, "after": self.value} - return + return None - def __str__(self): + def __str__(self) -> str: return ( f"" @@ -119,34 +135,34 @@ class VarDict: "as_dict", ) - def __init__(self): - self.__vars__ = dict() + def __init__(self) -> None: + self.__vars__: dict[str, _Variable] = dict() - def __getitem__(self, item): + def __getitem__(self, item: str): return self.__vars__[item].value - def __setitem__(self, key, value): + def __setitem__(self, key: str, value) -> None: self.set(key, value) - def __getattr__(self, item): + def __getattr__(self, item: str): try: return self.__vars__[item].value except KeyError: return getattr(super(), item) - def __setattr__(self, key, value): + def __setattr__(self, key: str, value) -> None: if key == "__vars__": super().__setattr__(key, value) else: self.set(key, value) - def _var(self, name): + def _var(self, name: str) -> _Variable: return self.__vars__[name] - def var(self, name): + def var(self, name: str) -> dict[str, t.Any]: return self._var(name).as_dict() - def set_meta(self, name, **kwargs): + def set_meta(self, name: str, **kwargs): """Set the metadata for the variable Args: @@ -160,10 +176,10 @@ class VarDict: """ self._var(name).set_meta(**kwargs) - def get_meta(self, name): + def get_meta(self, name: str) -> dict[str, t.Any]: return self._var(name).as_dict(meta_only=True) - def set(self, name, value, **kwargs): + def set(self, name: str, value, **kwargs) -> None: """Set the value and optionally metadata for a variable. The variable is not required to exist prior to calling `set`. For details on the accepted metada see the documentation for method `set_meta`. @@ -185,10 +201,10 @@ class VarDict: var.set_value(value) self.__vars__[name] = var - def output(self, verbosity=0): + def output(self, verbosity: int = 0) -> dict[str, t.Any]: return {n: v.value for n, v in self.__vars__.items() if v.output and v.is_visible(verbosity)} - def diff(self, verbosity=0): + def diff(self, verbosity: int = 0) -> dict[str, dict[str, t.Any]] | None: diff_results = [ (n, v.diff_result) for n, v in self.__vars__.items() if v.diff_result and v.is_visible(verbosity) ] @@ -198,13 +214,13 @@ class VarDict: return {"before": before, "after": after} return None - def facts(self, verbosity=0): + def facts(self, verbosity: int = 0) -> dict[str, t.Any] | None: facts_result = {n: v.value for n, v in self.__vars__.items() if v.fact and v.is_visible(verbosity)} return facts_result if facts_result else None @property - def has_changed(self): + def has_changed(self) -> bool: return any(var.has_changed for var in self.__vars__.values()) - def as_dict(self): + def as_dict(self) -> dict[str, t.Any]: return {name: var.value for name, var in self.__vars__.items()} diff --git a/plugins/module_utils/vexata.py b/plugins/module_utils/vexata.py index 1ea0ecf17c..b55b67f113 100644 --- a/plugins/module_utils/vexata.py +++ b/plugins/module_utils/vexata.py @@ -5,6 +5,7 @@ from __future__ import annotations +import typing as t HAS_VEXATAPI = True try: @@ -14,10 +15,14 @@ except ImportError: from ansible.module_utils.basic import env_fallback +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + + VXOS_VERSION = None -def get_version(iocs_json): +def get_version(iocs_json) -> tuple[int, ...]: if not iocs_json: raise Exception("Invalid IOC json") active = next((x for x in iocs_json if x["mgmtRole"]), None) @@ -31,7 +36,7 @@ def get_version(iocs_json): return tuple(ver) -def get_array(module): +def get_array(module: AnsibleModule): """Return storage array object or fail""" global VXOS_VERSION array = module.params["array"] @@ -60,7 +65,7 @@ def get_array(module): module.fail_json(msg=f"Vexata API access failed: {e}") -def argument_spec(): +def argument_spec() -> dict[str, t.Any]: """Return standard base dictionary used for the argument_spec argument in AnsibleModule""" return dict( array=dict(type="str", required=True), @@ -70,20 +75,20 @@ def argument_spec(): ) -def required_together(): +def required_together() -> list[list[str]]: """Return the default list used for the required_together argument to AnsibleModule""" return [["user", "password"]] -def size_to_MiB(size): +def size_to_MiB(size: str) -> int: """Convert a '[MGT]' string to MiB, return -1 on error.""" quant = size[:-1] exponent = size[-1] if not quant.isdigit() or exponent not in "MGT": return -1 - quant = int(quant) + quant_int = int(quant) if exponent == "G": - quant <<= 10 + quant_int <<= 10 elif exponent == "T": - quant <<= 20 - return quant + quant_int <<= 20 + return quant_int diff --git a/plugins/module_utils/wdc_redfish_utils.py b/plugins/module_utils/wdc_redfish_utils.py index f27102d61b..56ec7537d6 100644 --- a/plugins/module_utils/wdc_redfish_utils.py +++ b/plugins/module_utils/wdc_redfish_utils.py @@ -5,15 +5,19 @@ from __future__ import annotations import datetime +import os import re import time import tarfile -import os +import typing as t from urllib.parse import urlparse, urlunparse from ansible.module_utils.urls import fetch_file from ansible_collections.community.general.plugins.module_utils.redfish_utils import RedfishUtils +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + class WdcRedfishUtils(RedfishUtils): """Extension to RedfishUtils to support WDC enclosures.""" @@ -41,7 +45,7 @@ class WdcRedfishUtils(RedfishUtils): CHASSIS_LOCATE = "#Chassis.Locate" CHASSIS_POWER_MODE = "#Chassis.PowerMode" - def __init__(self, creds, root_uris, timeout, module, resource_id, data_modification): + def __init__(self, creds, root_uris, timeout, module: AnsibleModule, resource_id, data_modification) -> None: super().__init__( creds=creds, root_uri=root_uris[0], diff --git a/plugins/module_utils/xdg_mime.py b/plugins/module_utils/xdg_mime.py index 220d9f9391..4d48406551 100644 --- a/plugins/module_utils/xdg_mime.py +++ b/plugins/module_utils/xdg_mime.py @@ -5,10 +5,15 @@ from __future__ import annotations +import typing as t + from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule -def xdg_mime_runner(module, **kwargs): + +def xdg_mime_runner(module: AnsibleModule, **kwargs) -> CmdRunner: return CmdRunner( module, command=["xdg-mime"], @@ -23,8 +28,8 @@ def xdg_mime_runner(module, **kwargs): ) -def xdg_mime_get(runner, mime_type): - def process(rc, out, err): +def xdg_mime_get(runner: CmdRunner, mime_type) -> str | None: + def process(rc, out, err) -> str | None: if not out.strip(): return None out = out.splitlines()[0] diff --git a/plugins/module_utils/xenserver.py b/plugins/module_utils/xenserver.py index c51cef85db..bf9b4f0a9b 100644 --- a/plugins/module_utils/xenserver.py +++ b/plugins/module_utils/xenserver.py @@ -9,6 +9,7 @@ import atexit import time import re import traceback +import typing as t XENAPI_IMP_ERR = None try: @@ -22,8 +23,11 @@ except ImportError: from ansible.module_utils.basic import env_fallback, missing_required_lib from ansible.module_utils.ansible_release import __version__ as ANSIBLE_VERSION +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule -def xenserver_common_argument_spec(): + +def xenserver_common_argument_spec() -> dict[str, t.Any]: return dict( hostname=dict( type="str", @@ -41,7 +45,7 @@ def xenserver_common_argument_spec(): ) -def xapi_to_module_vm_power_state(power_state): +def xapi_to_module_vm_power_state(power_state: str) -> str | None: """Maps XAPI VM power states to module VM power states.""" module_power_state_map = { "running": "poweredon", @@ -53,7 +57,7 @@ def xapi_to_module_vm_power_state(power_state): return module_power_state_map.get(power_state) -def module_to_xapi_vm_power_state(power_state): +def module_to_xapi_vm_power_state(power_state: str) -> str | None: """Maps module VM power states to XAPI VM power states.""" vm_power_state_map = { "poweredon": "running", @@ -67,7 +71,7 @@ def module_to_xapi_vm_power_state(power_state): return vm_power_state_map.get(power_state) -def is_valid_ip_addr(ip_addr): +def is_valid_ip_addr(ip_addr: str) -> bool: """Validates given string as IPv4 address for given string. Args: @@ -93,7 +97,7 @@ def is_valid_ip_addr(ip_addr): return True -def is_valid_ip_netmask(ip_netmask): +def is_valid_ip_netmask(ip_netmask: str) -> bool: """Validates given string as IPv4 netmask. Args: @@ -125,7 +129,7 @@ def is_valid_ip_netmask(ip_netmask): return True -def is_valid_ip_prefix(ip_prefix): +def is_valid_ip_prefix(ip_prefix: str) -> bool: """Validates given string as IPv4 prefix. Args: @@ -142,7 +146,7 @@ def is_valid_ip_prefix(ip_prefix): return not (ip_prefix_int < 0 or ip_prefix_int > 32) -def ip_prefix_to_netmask(ip_prefix, skip_check=False): +def ip_prefix_to_netmask(ip_prefix: str, skip_check: bool = False) -> str: """Converts IPv4 prefix to netmask. Args: @@ -165,7 +169,7 @@ def ip_prefix_to_netmask(ip_prefix, skip_check=False): return "" -def ip_netmask_to_prefix(ip_netmask, skip_check=False): +def ip_netmask_to_prefix(ip_netmask: str, skip_check: bool = False) -> str: """Converts IPv4 netmask to prefix. Args: @@ -188,7 +192,7 @@ def ip_netmask_to_prefix(ip_netmask, skip_check=False): return "" -def is_valid_ip6_addr(ip6_addr): +def is_valid_ip6_addr(ip6_addr: str) -> bool: """Validates given string as IPv6 address. Args: @@ -222,7 +226,7 @@ def is_valid_ip6_addr(ip6_addr): return all(ip6_addr_hextet_regex.match(ip6_addr_hextet) for ip6_addr_hextet in ip6_addr_split) -def is_valid_ip6_prefix(ip6_prefix): +def is_valid_ip6_prefix(ip6_prefix: str) -> bool: """Validates given string as IPv6 prefix. Args: @@ -239,7 +243,7 @@ def is_valid_ip6_prefix(ip6_prefix): return not (ip6_prefix_int < 0 or ip6_prefix_int > 128) -def get_object_ref(module, name, uuid=None, obj_type="VM", fail=True, msg_prefix=""): +def get_object_ref(module: AnsibleModule, name, uuid=None, obj_type="VM", fail=True, msg_prefix=""): """Finds and returns a reference to arbitrary XAPI object. An object is searched by using either name (name_label) or UUID @@ -305,7 +309,7 @@ def get_object_ref(module, name, uuid=None, obj_type="VM", fail=True, msg_prefix return obj_ref -def gather_vm_params(module, vm_ref): +def gather_vm_params(module: AnsibleModule, vm_ref): """Gathers all VM parameters available in XAPI database. Args: @@ -395,7 +399,7 @@ def gather_vm_params(module, vm_ref): return vm_params -def gather_vm_facts(module, vm_params): +def gather_vm_facts(module: AnsibleModule, vm_params): """Gathers VM facts. Args: @@ -502,7 +506,7 @@ def gather_vm_facts(module, vm_params): return vm_facts -def set_vm_power_state(module, vm_ref, power_state, timeout=300): +def set_vm_power_state(module: AnsibleModule, vm_ref, power_state, timeout=300): """Controls VM power state. Args: @@ -608,7 +612,7 @@ def set_vm_power_state(module, vm_ref, power_state, timeout=300): return (state_changed, vm_power_state_resulting) -def wait_for_task(module, task_ref, timeout=300): +def wait_for_task(module: AnsibleModule, task_ref, timeout=300): """Waits for async XAPI task to finish. Args: @@ -667,7 +671,7 @@ def wait_for_task(module, task_ref, timeout=300): return result -def wait_for_vm_ip_address(module, vm_ref, timeout=300): +def wait_for_vm_ip_address(module: AnsibleModule, vm_ref, timeout=300): """Waits for VM to acquire an IP address. Args: @@ -730,7 +734,7 @@ def wait_for_vm_ip_address(module, vm_ref, timeout=300): return vm_guest_metrics -def get_xenserver_version(module): +def get_xenserver_version(module: AnsibleModule): """Returns XenServer version. Args: @@ -758,10 +762,10 @@ def get_xenserver_version(module): class XAPI: """Class for XAPI session management.""" - _xapi_session = None + _xapi_session: t.Any | None = None @classmethod - def connect(cls, module, disconnect_atexit=True): + def connect(cls, module: AnsibleModule, disconnect_atexit=True): """Establishes XAPI connection and returns session reference. If no existing session is available, establishes a new one @@ -837,7 +841,7 @@ class XenServerObject: minor version. """ - def __init__(self, module): + def __init__(self, module: AnsibleModule) -> None: """Inits XenServerObject using common module parameters. Args: diff --git a/plugins/module_utils/xfconf.py b/plugins/module_utils/xfconf.py index 2903af62cf..fb5a00df25 100644 --- a/plugins/module_utils/xfconf.py +++ b/plugins/module_utils/xfconf.py @@ -4,9 +4,14 @@ from __future__ import annotations +import typing as t + from ansible.module_utils.parsing.convert_bool import boolean from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt +if t.TYPE_CHECKING: + from ansible.module_utils.basic import AnsibleModule + @cmd_runner_fmt.unpack_args def _values_fmt(values, value_types): @@ -18,7 +23,7 @@ def _values_fmt(values, value_types): return result -def xfconf_runner(module, **kwargs): +def xfconf_runner(module: AnsibleModule, **kwargs) -> CmdRunner: runner = CmdRunner( module, command="xfconf-query", @@ -37,7 +42,7 @@ def xfconf_runner(module, **kwargs): return runner -def get_xfconf_version(runner): +def get_xfconf_version(runner: CmdRunner) -> str: with runner("version") as ctx: rc, out, err = ctx.run() return out.splitlines()[0].split()[1] diff --git a/plugins/modules/lxd_storage_pool_info.py b/plugins/modules/lxd_storage_pool_info.py index a3ed200088..d615d79a8a 100644 --- a/plugins/modules/lxd_storage_pool_info.py +++ b/plugins/modules/lxd_storage_pool_info.py @@ -183,6 +183,7 @@ logs: """ import os +import typing as t from urllib.parse import urlencode from ansible.module_utils.basic import AnsibleModule @@ -246,19 +247,16 @@ class LXDStoragePoolInfo: self.trust_password = self.module.params["trust_password"] - def _fail_from_lxd_exception(self, exception: LXDClientException) -> None: + def _fail_from_lxd_exception(self, exception: LXDClientException) -> t.NoReturn: """Build failure parameters from LXDClientException and fail. :param exception: The LXDClientException instance :type exception: LXDClientException """ - fail_params = { - "msg": exception.msg, - "changed": False, - } + fail_params = {} if self.client.debug and "logs" in exception.kwargs: fail_params["logs"] = exception.kwargs["logs"] - self.module.fail_json(**fail_params) + self.module.fail_json(msg=exception.msg, changed=False, **fail_params) def _build_url(self, endpoint: str) -> str: """Build URL with project parameter if specified.""" diff --git a/plugins/modules/lxd_storage_volume_info.py b/plugins/modules/lxd_storage_volume_info.py index e4abcef995..d243eaa0a7 100644 --- a/plugins/modules/lxd_storage_volume_info.py +++ b/plugins/modules/lxd_storage_volume_info.py @@ -188,6 +188,7 @@ logs: """ import os +import typing as t from urllib.parse import quote, urlencode from ansible.module_utils.basic import AnsibleModule @@ -252,19 +253,16 @@ class LXDStorageVolumeInfo: self.trust_password = self.module.params["trust_password"] - def _fail_from_lxd_exception(self, exception: LXDClientException) -> None: + def _fail_from_lxd_exception(self, exception: LXDClientException) -> t.NoReturn: """Build failure parameters from LXDClientException and fail. :param exception: The LXDClientException instance :type exception: LXDClientException """ - fail_params = { - "msg": exception.msg, - "changed": False, - } + fail_params = {} if self.client.debug and "logs" in exception.kwargs: fail_params["logs"] = exception.kwargs["logs"] - self.module.fail_json(**fail_params) + self.module.fail_json(msg=exception.msg, changed=False, **fail_params) def _build_url(self, endpoint: str) -> str: """Build URL with project parameter if specified."""