
[PR #11222/c7f6a28d backport][stable-12] Add basic typing for module_utils (#11243)

Add basic typing for module_utils (#11222)

* Add basic typing for module_utils.

* Apply some suggestions.



* Make pass again.

* Add more types as suggested.

* Normalize extra imports.

* Add more type hints.

* Improve typing.

* Add changelog fragment.

* Reduce changelog.

* Apply suggestions from code review.



* Fix typo.

* Cleanup.

* Improve types and make type checking happy.

* Let's see whether older Pythons barf on this.

* Revert "Let's see whether older Pythons barf on this."

This reverts commit 9973af3dbe.

* Add noqa.

---------


(cherry picked from commit c7f6a28d89)

Co-authored-by: Felix Fontein <felix@fontein.de>
Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com>
patchback[bot] 2025-12-01 21:16:37 +01:00 committed by GitHub
parent a2c7f9f89a
commit 377a599372
56 changed files with 725 additions and 469 deletions

View file

@ -0,0 +1,3 @@
bugfixes:
- "_filelock module utils - add type hints. Fix bug if ``set_lock()`` is called with ``lock_timeout=None`` (https://github.com/ansible-collections/community.general/pull/11222)."
- "gitlab module utils - add type hints. Pass API version to python-gitlab as string and not as integer (https://github.com/ansible-collections/community.general/pull/11222)."

View file

@ -11,9 +11,13 @@ import os
import stat
import time
import fcntl
import typing as t
from contextlib import contextmanager
if t.TYPE_CHECKING:
from io import TextIOWrapper
class LockTimeout(Exception):
pass
@ -27,11 +31,13 @@ class FileLock:
unwanted and/or unexpected behaviour
"""
def __init__(self):
self.lockfd = None
def __init__(self) -> None:
self.lockfd: TextIOWrapper | None = None
@contextmanager
def lock_file(self, path, tmpdir, lock_timeout=None):
def lock_file(
self, path: os.PathLike, tmpdir: os.PathLike, lock_timeout: int | float | None = None
) -> t.Generator[None]:
"""
Context for lock acquisition
"""
@ -41,7 +47,9 @@ class FileLock:
finally:
self.unlock()
def set_lock(self, path, tmpdir, lock_timeout=None):
def set_lock(
self, path: os.PathLike, tmpdir: os.PathLike, lock_timeout: int | float | None = None
) -> t.Literal[True]:
"""
Create a lock file based on path with flock to prevent other processes
using given path.
@ -61,13 +69,13 @@ class FileLock:
self.lockfd = open(lock_path, "w")
if lock_timeout <= 0:
if lock_timeout is not None and lock_timeout <= 0:
fcntl.flock(self.lockfd, fcntl.LOCK_EX | fcntl.LOCK_NB)
os.chmod(lock_path, stat.S_IWRITE | stat.S_IREAD)
return True
if lock_timeout:
e_secs = 0
e_secs: float = 0
while e_secs < lock_timeout:
try:
fcntl.flock(self.lockfd, fcntl.LOCK_EX | fcntl.LOCK_NB)
@ -86,7 +94,7 @@ class FileLock:
return True
def unlock(self):
def unlock(self) -> t.Literal[True]:
"""
Make sure lock file is available for everyone and Unlock the file descriptor
locked by set_lock
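
The ``if t.TYPE_CHECKING:`` block above (repeated in most files of this commit) imports names that are needed only for annotations, so the runtime import graph stays unchanged; together with ``from __future__ import annotations`` the hints are never evaluated when the module runs. A generic sketch of the pattern, independent of this collection:

from __future__ import annotations

import typing as t

if t.TYPE_CHECKING:
    # Only type checkers execute this import.
    from io import TextIOWrapper


class Holder:
    def __init__(self) -> None:
        # The annotation is not evaluated at runtime, so the guarded
        # import above is sufficient.
        self.fd: TextIOWrapper | None = None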

View file

@ -14,8 +14,13 @@ from __future__ import annotations
import os
import json
import traceback
import typing as t
from ansible.module_utils.basic import env_fallback
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
try:
import footmark
import footmark.ecs
@ -200,7 +205,7 @@ def get_profile(params):
return params
def ecs_connect(module):
def ecs_connect(module: AnsibleModule):
"""Return an ecs connection"""
ecs_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -214,7 +219,7 @@ def ecs_connect(module):
return ecs
def slb_connect(module):
def slb_connect(module: AnsibleModule):
"""Return an slb connection"""
slb_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -228,7 +233,7 @@ def slb_connect(module):
return slb
def dns_connect(module):
def dns_connect(module: AnsibleModule):
"""Return an dns connection"""
dns_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -242,7 +247,7 @@ def dns_connect(module):
return dns
def vpc_connect(module):
def vpc_connect(module: AnsibleModule):
"""Return an vpc connection"""
vpc_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -256,7 +261,7 @@ def vpc_connect(module):
return vpc
def rds_connect(module):
def rds_connect(module: AnsibleModule):
"""Return an rds connection"""
rds_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -270,7 +275,7 @@ def rds_connect(module):
return rds
def ess_connect(module):
def ess_connect(module: AnsibleModule):
"""Return an ess connection"""
ess_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -284,7 +289,7 @@ def ess_connect(module):
return ess
def sts_connect(module):
def sts_connect(module: AnsibleModule):
"""Return an sts connection"""
sts_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -298,7 +303,7 @@ def sts_connect(module):
return sts
def ram_connect(module):
def ram_connect(module: AnsibleModule):
"""Return an ram connection"""
ram_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.
@ -312,7 +317,7 @@ def ram_connect(module):
return ram
def market_connect(module):
def market_connect(module: AnsibleModule):
"""Return an market connection"""
market_params = get_profile(module.params)
# If we have a region specified, connect to its endpoint.

View file

@ -6,22 +6,27 @@ from __future__ import annotations
import re
import typing as t
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
__state_map = {"present": "--install", "absent": "--uninstall"}
# sdkmanager --help 2>&1 | grep -A 2 -- --channel
__channel_map = {"stable": 0, "beta": 1, "dev": 2, "canary": 3}
def __map_channel(channel_name):
def __map_channel(channel_name: str) -> int:
if channel_name not in __channel_map:
raise ValueError(f"Unknown channel name '{channel_name}'")
return __channel_map[channel_name]
def sdkmanager_runner(module, **kwargs):
def sdkmanager_runner(module: AnsibleModule, **kwargs) -> CmdRunner:
return CmdRunner(
module,
command="sdkmanager",
@ -40,18 +45,18 @@ def sdkmanager_runner(module, **kwargs):
class Package:
def __init__(self, name):
def __init__(self, name: str) -> None:
self.name = name
def __hash__(self):
def __hash__(self) -> int:
return hash(self.name)
def __ne__(self, other):
def __ne__(self, other: object) -> bool:
if not isinstance(other, Package):
return True
return self.name != other.name
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, Package):
return False
@ -78,20 +83,20 @@ class AndroidSdkManager:
r"the packages they depend on were not accepted"
)
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.runner = sdkmanager_runner(module)
def get_installed_packages(self):
def get_installed_packages(self) -> set[Package]:
with self.runner("installed sdk_root channel") as ctx:
rc, stdout, stderr = ctx.run()
return self._parse_packages(stdout, self._RE_INSTALLED_PACKAGES_HEADER, self._RE_INSTALLED_PACKAGE)
def get_updatable_packages(self):
def get_updatable_packages(self) -> set[Package]:
with self.runner("list newer sdk_root channel") as ctx:
rc, stdout, stderr = ctx.run()
return self._parse_packages(stdout, self._RE_UPDATABLE_PACKAGES_HEADER, self._RE_UPDATABLE_PACKAGE)
def apply_packages_changes(self, packages, accept_licenses=False):
def apply_packages_changes(self, packages: list[Package], accept_licenses: bool = False) -> tuple[int, str, str]:
"""Install or delete packages, depending on the `module.vars.state` parameter"""
if len(packages) == 0:
return 0, "", ""
@ -113,7 +118,7 @@ class AndroidSdkManager:
return rc, stdout, stderr
return 0, "", ""
def _try_parse_stderr(self, stderr):
def _try_parse_stderr(self, stderr: str) -> None:
data = stderr.splitlines()
for line in data:
unknown_package_regex = self._RE_UNKNOWN_PACKAGE.match(line)
@ -122,15 +127,15 @@ class AndroidSdkManager:
raise SdkManagerException(f"Unknown package {package}")
@staticmethod
def _parse_packages(stdout, header_regexp, row_regexp):
def _parse_packages(stdout: str, header_regexp: re.Pattern, row_regexp: re.Pattern) -> set[Package]:
data = stdout.splitlines()
section_found = False
section_found: bool = False
packages = set()
for line in data:
if not section_found:
section_found = header_regexp.match(line)
section_found = bool(header_regexp.match(line))
continue
else:
p = row_regexp.match(line)

View file

@ -7,6 +7,10 @@ from __future__ import annotations
from ansible.module_utils.common.text.converters import to_bytes
import re
import os
import typing as t
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def normalize_subvolume_path(path):
@ -28,11 +32,11 @@ class BtrfsCommands:
Provides access to a subset of the Btrfs command line
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.__module = module
self.__btrfs = self.__module.get_bin_path("btrfs", required=True)
self.__btrfs: str = self.__module.get_bin_path("btrfs", required=True)
def filesystem_show(self):
def filesystem_show(self) -> list[dict[str, t.Any]]:
command = f"{self.__btrfs} filesystem show -d"
result = self.__module.run_command(command, check_rc=True)
stdout = [x.strip() for x in result[1].splitlines()]
@ -43,14 +47,16 @@ class BtrfsCommands:
current = self.__parse_filesystem(line)
filesystems.append(current)
elif line.startswith("devid"):
if current is None:
raise ValueError("Found 'devid' line without previous 'Label' line")
current["devices"].append(self.__parse_filesystem_device(line))
return filesystems
def __parse_filesystem(self, line):
def __parse_filesystem(self, line) -> dict[str, t.Any]:
label = re.sub(r"\s*uuid:.*$", "", re.sub(r"^Label:\s*", "", line))
id = re.sub(r"^.*uuid:\s*", "", line)
filesystem = {}
filesystem: dict[str, t.Any] = {}
filesystem["label"] = label.strip("'") if label != "none" else None
filesystem["uuid"] = id
filesystem["devices"] = []
@ -59,44 +65,44 @@ class BtrfsCommands:
filesystem["default_subvolid"] = None
return filesystem
def __parse_filesystem_device(self, line):
def __parse_filesystem_device(self, line: str) -> str:
return re.sub(r"^.*path\s", "", line)
def subvolumes_list(self, filesystem_path):
def subvolumes_list(self, filesystem_path: str) -> list[dict[str, t.Any]]:
command = f"{self.__btrfs} subvolume list -tap {filesystem_path}"
result = self.__module.run_command(command, check_rc=True)
stdout = [x.split("\t") for x in result[1].splitlines()]
subvolumes = [{"id": 5, "parent": None, "path": "/"}]
subvolumes: list[dict[str, t.Any]] = [{"id": 5, "parent": None, "path": "/"}]
if len(stdout) > 2:
subvolumes.extend([self.__parse_subvolume_list_record(x) for x in stdout[2:]])
return subvolumes
def __parse_subvolume_list_record(self, item):
def __parse_subvolume_list_record(self, item: list[str]) -> dict[str, t.Any]:
return {
"id": int(item[0]),
"parent": int(item[2]),
"path": normalize_subvolume_path(item[5]),
}
def subvolume_get_default(self, filesystem_path):
def subvolume_get_default(self, filesystem_path: str) -> int:
command = [self.__btrfs, "subvolume", "get-default", to_bytes(filesystem_path)]
result = self.__module.run_command(command, check_rc=True)
# ID [n] ...
return int(result[1].strip().split()[1])
def subvolume_set_default(self, filesystem_path, subvolume_id):
def subvolume_set_default(self, filesystem_path: str, subvolume_id: int) -> None:
command = [self.__btrfs, "subvolume", "set-default", str(subvolume_id), to_bytes(filesystem_path)]
self.__module.run_command(command, check_rc=True)
def subvolume_create(self, subvolume_path):
def subvolume_create(self, subvolume_path: str) -> None:
command = [self.__btrfs, "subvolume", "create", to_bytes(subvolume_path)]
self.__module.run_command(command, check_rc=True)
def subvolume_snapshot(self, snapshot_source, snapshot_destination):
def subvolume_snapshot(self, snapshot_source: str, snapshot_destination: str) -> None:
command = [self.__btrfs, "subvolume", "snapshot", to_bytes(snapshot_source), to_bytes(snapshot_destination)]
self.__module.run_command(command, check_rc=True)
def subvolume_delete(self, subvolume_path):
def subvolume_delete(self, subvolume_path: str) -> None:
command = [self.__btrfs, "subvolume", "delete", to_bytes(subvolume_path)]
self.__module.run_command(command, check_rc=True)
@ -106,12 +112,12 @@ class BtrfsInfoProvider:
Utility providing details of the currently available btrfs filesystems
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.__module = module
self.__btrfs_api = BtrfsCommands(module)
self.__findmnt_path = self.__module.get_bin_path("findmnt", required=True)
self.__findmnt_path: str = self.__module.get_bin_path("findmnt", required=True)
def get_filesystems(self):
def get_filesystems(self) -> list[dict[str, t.Any]]:
filesystems = self.__btrfs_api.filesystem_show()
mountpoints = self.__find_mountpoints()
for filesystem in filesystems:
@ -126,20 +132,22 @@ class BtrfsInfoProvider:
return filesystems
def get_mountpoints(self, filesystem_devices):
def get_mountpoints(self, filesystem_devices: list[str]) -> list[dict[str, t.Any]]:
mountpoints = self.__find_mountpoints()
return self.__filter_mountpoints_for_devices(mountpoints, filesystem_devices)
def get_subvolumes(self, filesystem_path):
def get_subvolumes(self, filesystem_path) -> list[dict[str, t.Any]]:
return self.__btrfs_api.subvolumes_list(filesystem_path)
def get_default_subvolume_id(self, filesystem_path):
def get_default_subvolume_id(self, filesystem_path) -> int:
return self.__btrfs_api.subvolume_get_default(filesystem_path)
def __filter_mountpoints_for_devices(self, mountpoints, devices):
def __filter_mountpoints_for_devices(
self, mountpoints: list[dict[str, t.Any]], devices: list[str]
) -> list[dict[str, t.Any]]:
return [m for m in mountpoints if (m["device"] in devices)]
def __find_mountpoints(self):
def __find_mountpoints(self) -> list[dict[str, t.Any]]:
command = f"{self.__findmnt_path} -t btrfs -nvP"
result = self.__module.run_command(command)
mountpoints = []
@ -150,7 +158,7 @@ class BtrfsInfoProvider:
mountpoints.append(mountpoint)
return mountpoints
def __parse_mountpoint_pairs(self, line):
def __parse_mountpoint_pairs(self, line) -> dict[str, t.Any]:
pattern = re.compile(
r'^TARGET="(?P<target>.*)"\s+SOURCE="(?P<source>.*)"\s+FSTYPE="(?P<fstype>.*)"\s+OPTIONS="(?P<options>.*)"\s*$'
)
@ -164,13 +172,13 @@ class BtrfsInfoProvider:
"subvolid": self.__extract_mount_subvolid(groups["options"]),
}
else:
raise BtrfsModuleException(f"Failed to parse findmnt result for line: '{line}'")
raise BtrfsModuleException(f"Failed to parse findmnt result for line: {line!r}")
def __extract_mount_subvolid(self, mount_options):
def __extract_mount_subvolid(self, mount_options: str) -> int:
for option in mount_options.split(","):
if option.startswith("subvolid="):
return int(option[len("subvolid=") :])
raise BtrfsModuleException(f"Failed to find subvolid for mountpoint in options '{mount_options}'")
raise BtrfsModuleException(f"Failed to find subvolid for mountpoint in options {mount_options!r}")
class BtrfsSubvolume:
@ -178,39 +186,38 @@ class BtrfsSubvolume:
Wrapper class providing convenience methods for inspection of a btrfs subvolume
"""
def __init__(self, filesystem, subvolume_id):
def __init__(self, filesystem: BtrfsFilesystem, subvolume_id: int):
self.__filesystem = filesystem
self.__subvolume_id = subvolume_id
def get_filesystem(self):
def get_filesystem(self) -> BtrfsFilesystem:
return self.__filesystem
def is_mounted(self):
def is_mounted(self) -> bool:
mountpoints = self.get_mountpoints()
return mountpoints is not None and len(mountpoints) > 0
def is_filesystem_root(self):
def is_filesystem_root(self) -> bool:
return self.__subvolume_id == 5
def is_filesystem_default(self):
def is_filesystem_default(self) -> bool:
return self.__filesystem.default_subvolid == self.__subvolume_id
def get_mounted_path(self):
def get_mounted_path(self) -> str | None:
mountpoints = self.get_mountpoints()
if mountpoints is not None and len(mountpoints) > 0:
return mountpoints[0]
elif self.parent is not None:
if self.parent is not None:
parent = self.__filesystem.get_subvolume_by_id(self.parent)
parent_path = parent.get_mounted_path()
parent_path = parent.get_mounted_path() if parent else None
if parent_path is not None:
return parent_path + os.path.sep + self.name
else:
return None
return f"{parent_path}{os.path.sep}{self.name}"
return None
def get_mountpoints(self):
def get_mountpoints(self) -> list[str]:
return self.__filesystem.get_mountpoints_by_subvolume_id(self.__subvolume_id)
def get_child_relative_path(self, absolute_child_path):
def get_child_relative_path(self, absolute_child_path: str) -> str:
"""
Get the relative path from this subvolume to the named child subvolume.
The provided parameter is expected to be normalized as by normalize_subvolume_path.
@ -222,19 +229,21 @@ class BtrfsSubvolume:
else:
raise BtrfsModuleException(f"Path '{absolute_child_path}' doesn't start with '{path}'")
def get_parent_subvolume(self):
def get_parent_subvolume(self) -> BtrfsSubvolume | None:
parent_id = self.parent
return self.__filesystem.get_subvolume_by_id(parent_id) if parent_id is not None else None
def get_child_subvolumes(self):
def get_child_subvolumes(self) -> list[BtrfsSubvolume]:
return self.__filesystem.get_subvolume_children(self.__subvolume_id)
@property
def __info(self):
return self.__filesystem.get_subvolume_info_for_id(self.__subvolume_id)
def __info(self) -> dict[str, t.Any]:
result = self.__filesystem.get_subvolume_info_for_id(self.__subvolume_id)
# assert result is not None
return result # type: ignore
@property
def id(self):
def id(self) -> int:
return self.__subvolume_id
@property
@ -242,7 +251,7 @@ class BtrfsSubvolume:
return self.path.split("/").pop()
@property
def path(self):
def path(self) -> str:
return self.__info["path"]
@property
@ -255,105 +264,105 @@ class BtrfsFilesystem:
Wrapper class providing convenience methods for inspection of a btrfs filesystem
"""
def __init__(self, info, provider, module):
def __init__(self, info: dict[str, t.Any], provider: BtrfsInfoProvider, module: AnsibleModule) -> None:
self.__provider = provider
# constant for module execution
self.__uuid = info["uuid"]
self.__label = info["label"]
self.__devices = info["devices"]
self.__uuid: str = info["uuid"]
self.__label: str = info["label"]
self.__devices: list[str] = info["devices"]
# refreshable
self.__default_subvolid = info["default_subvolid"] if "default_subvolid" in info else None
self.__default_subvolid: int | None = info["default_subvolid"] if "default_subvolid" in info else None
self.__update_mountpoints(info["mountpoints"] if "mountpoints" in info else [])
self.__update_subvolumes(info["subvolumes"] if "subvolumes" in info else [])
@property
def uuid(self):
def uuid(self) -> str:
return self.__uuid
@property
def label(self):
def label(self) -> str:
return self.__label
@property
def default_subvolid(self):
def default_subvolid(self) -> int | None:
return self.__default_subvolid
@property
def devices(self):
def devices(self) -> list[str]:
return list(self.__devices)
def refresh(self):
def refresh(self) -> None:
self.refresh_mountpoints()
self.refresh_subvolumes()
self.refresh_default_subvolume()
def refresh_mountpoints(self):
def refresh_mountpoints(self) -> None:
mountpoints = self.__provider.get_mountpoints(list(self.__devices))
self.__update_mountpoints(mountpoints)
def __update_mountpoints(self, mountpoints):
self.__mountpoints = dict()
def __update_mountpoints(self, mountpoints: list[dict[str, t.Any]]) -> None:
self.__mountpoints: dict[int, list[str]] = dict()
for i in mountpoints:
subvolid = i["subvolid"]
mountpoint = i["mountpoint"]
subvolid: int = i["subvolid"]
mountpoint: str = i["mountpoint"]
if subvolid not in self.__mountpoints:
self.__mountpoints[subvolid] = []
self.__mountpoints[subvolid].append(mountpoint)
def refresh_subvolumes(self):
def refresh_subvolumes(self) -> None:
filesystem_path = self.get_any_mountpoint()
if filesystem_path is not None:
subvolumes = self.__provider.get_subvolumes(filesystem_path)
self.__update_subvolumes(subvolumes)
def __update_subvolumes(self, subvolumes):
def __update_subvolumes(self, subvolumes: list[dict[str, t.Any]]) -> None:
# TODO strategy for retaining information on deleted subvolumes?
self.__subvolumes = dict()
self.__subvolumes: dict[int, dict[str, t.Any]] = dict()
for subvolume in subvolumes:
self.__subvolumes[subvolume["id"]] = subvolume
def refresh_default_subvolume(self):
def refresh_default_subvolume(self) -> None:
filesystem_path = self.get_any_mountpoint()
if filesystem_path is not None:
self.__default_subvolid = self.__provider.get_default_subvolume_id(filesystem_path)
def contains_device(self, device):
def contains_device(self, device: str) -> bool:
return device in self.__devices
def contains_subvolume(self, subvolume):
def contains_subvolume(self, subvolume: str) -> bool:
return self.get_subvolume_by_name(subvolume) is not None
def get_subvolume_by_id(self, subvolume_id):
def get_subvolume_by_id(self, subvolume_id: int) -> BtrfsSubvolume | None:
return BtrfsSubvolume(self, subvolume_id) if subvolume_id in self.__subvolumes else None
def get_subvolume_info_for_id(self, subvolume_id):
def get_subvolume_info_for_id(self, subvolume_id: int) -> dict[str, t.Any] | None:
return self.__subvolumes[subvolume_id] if subvolume_id in self.__subvolumes else None
def get_subvolume_by_name(self, subvolume):
def get_subvolume_by_name(self, subvolume: str) -> BtrfsSubvolume | None:
for subvolume_info in self.__subvolumes.values():
if subvolume_info["path"] == subvolume:
return BtrfsSubvolume(self, subvolume_info["id"])
return None
def get_any_mountpoint(self):
def get_any_mountpoint(self) -> str | None:
for subvol_mountpoints in self.__mountpoints.values():
if len(subvol_mountpoints) > 0:
return subvol_mountpoints[0]
# maybe error?
return None
def get_any_mounted_subvolume(self):
def get_any_mounted_subvolume(self) -> BtrfsSubvolume | None:
for subvolid, subvol_mountpoints in self.__mountpoints.items():
if len(subvol_mountpoints) > 0:
return self.get_subvolume_by_id(subvolid)
return None
def get_mountpoints_by_subvolume_id(self, subvolume_id):
def get_mountpoints_by_subvolume_id(self, subvolume_id: int) -> list[str]:
return self.__mountpoints[subvolume_id] if subvolume_id in self.__mountpoints else []
def get_nearest_subvolume(self, subvolume):
def get_nearest_subvolume(self, subvolume: str) -> BtrfsSubvolume:
"""Return the identified subvolume if existing, else the closest matching parent"""
subvolumes_by_path = self.__get_subvolumes_by_path()
while len(subvolume) > 1:
@ -364,30 +373,31 @@ class BtrfsFilesystem:
return BtrfsSubvolume(self, 5)
def get_mountpath_as_child(self, subvolume_name):
def get_mountpath_as_child(self, subvolume_name: str) -> str:
"""Find a path to the target subvolume through a mounted ancestor"""
nearest = self.get_nearest_subvolume(subvolume_name)
nearest_or_none: BtrfsSubvolume | None = nearest
if nearest.path == subvolume_name:
nearest = nearest.get_parent_subvolume()
if nearest is None or nearest.get_mounted_path() is None:
nearest_or_none = nearest.get_parent_subvolume()
if nearest_or_none is None or nearest_or_none.get_mounted_path() is None:
raise BtrfsModuleException(f"Failed to find a path '{subvolume_name}' through a mounted parent subvolume")
else:
return nearest.get_mounted_path() + os.path.sep + nearest.get_child_relative_path(subvolume_name)
return f"{nearest_or_none.get_mounted_path()}{os.path.sep}{nearest_or_none.get_child_relative_path(subvolume_name)}"
def get_subvolume_children(self, subvolume_id):
def get_subvolume_children(self, subvolume_id: int) -> list[BtrfsSubvolume]:
return [BtrfsSubvolume(self, x["id"]) for x in self.__subvolumes.values() if x["parent"] == subvolume_id]
def __get_subvolumes_by_path(self):
def __get_subvolumes_by_path(self) -> dict[str, dict[str, t.Any]]:
result = {}
for s in self.__subvolumes.values():
path = s["path"]
result[path] = s
return result
def is_mounted(self):
def is_mounted(self) -> bool:
return self.__mountpoints is not None and len(self.__mountpoints) > 0
def get_summary(self):
def get_summary(self) -> dict[str, t.Any]:
subvolumes = []
sources = self.__subvolumes.values() if self.__subvolumes is not None else []
for subvolume in sources:
@ -415,17 +425,19 @@ class BtrfsFilesystemsProvider:
Provides methods to query available btrfs filesystems
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.__module = module
self.__provider = BtrfsInfoProvider(module)
self.__filesystems = None
self.__filesystems: dict[str, BtrfsFilesystem] | None = None
def get_matching_filesystem(self, criteria):
def get_matching_filesystem(self, criteria: dict[str, t.Any]) -> BtrfsFilesystem:
if criteria["device"] is not None:
criteria["device"] = os.path.realpath(criteria["device"])
self.__check_init()
matching = [f for f in self.__filesystems.values() if self.__filesystem_matches_criteria(f, criteria)]
# assert self.__filesystems is not None # TODO
self_filesystems: dict[str, BtrfsFilesystem] = self.__filesystems # type: ignore
matching = [f for f in self_filesystems.values() if self.__filesystem_matches_criteria(f, criteria)]
if len(matching) == 1:
return matching[0]
else:
@ -433,26 +445,30 @@ class BtrfsFilesystemsProvider:
f"Found {len(matching)} filesystems matching criteria uuid={criteria['uuid']} label={criteria['label']} device={criteria['device']}"
)
def __filesystem_matches_criteria(self, filesystem, criteria):
def __filesystem_matches_criteria(self, filesystem: BtrfsFilesystem, criteria: dict[str, t.Any]):
return (
(criteria["uuid"] is None or filesystem.uuid == criteria["uuid"])
and (criteria["label"] is None or filesystem.label == criteria["label"])
and (criteria["device"] is None or filesystem.contains_device(criteria["device"]))
)
def get_filesystem_for_device(self, device):
def get_filesystem_for_device(self, device: str) -> BtrfsFilesystem | None:
real_device = os.path.realpath(device)
self.__check_init()
for fs in self.__filesystems.values():
# assert self.__filesystems is not None # TODO
self_filesystems: dict[str, BtrfsFilesystem] = self.__filesystems # type: ignore
for fs in self_filesystems.values():
if fs.contains_device(real_device):
return fs
return None
def get_filesystems(self):
def get_filesystems(self) -> list[BtrfsFilesystem]:
self.__check_init()
return list(self.__filesystems.values())
# assert self.__filesystems is not None # TODO
self_filesystems: dict[str, BtrfsFilesystem] = self.__filesystems # type: ignore
return list(self_filesystems.values())
def __check_init(self):
def __check_init(self) -> None:
if self.__filesystems is None:
self.__filesystems = dict()
for f in self.__provider.get_filesystems():

View file

@ -5,11 +5,19 @@
from __future__ import annotations
import os
import typing as t
from ansible.module_utils.common.collections import is_sequence
from ansible.module_utils.common.locale import get_best_parsable_locale
from ansible_collections.community.general.plugins.module_utils import cmd_runner_fmt
if t.TYPE_CHECKING:
from collections.abc import Callable, Mapping, Sequence
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils.cmd_runner_fmt import ArgFormatType
ArgFormatter = t.Union[ArgFormatType, cmd_runner_fmt._ArgFormat] # noqa: UP007
def _ensure_list(value):
return list(value) if is_sequence(value) else [value]
@ -24,7 +32,7 @@ class CmdRunnerException(Exception):
class MissingArgumentFormat(CmdRunnerException):
def __init__(self, arg, args_order, args_formats):
def __init__(self, arg, args_order: tuple[str, ...], args_formats) -> None:
self.args_order = args_order
self.arg = arg
self.args_formats = args_formats
@ -37,7 +45,7 @@ class MissingArgumentFormat(CmdRunnerException):
class MissingArgumentValue(CmdRunnerException):
def __init__(self, args_order, arg):
def __init__(self, args_order: tuple[str, ...], arg) -> None:
self.args_order = args_order
self.arg = arg
@ -72,19 +80,19 @@ class CmdRunner:
"""
@staticmethod
def _prepare_args_order(order):
return tuple(order) if is_sequence(order) else tuple(order.split())
def _prepare_args_order(order: str | Sequence[str]) -> tuple[str, ...]:
return tuple(order) if is_sequence(order) else tuple(order.split()) # type: ignore
def __init__(
self,
module,
module: AnsibleModule,
command,
arg_formats=None,
default_args_order=(),
check_rc=False,
force_lang="C",
path_prefix=None,
environ_update=None,
arg_formats: Mapping[str, ArgFormatter] | None = None,
default_args_order: str | Sequence[str] = (),
check_rc: bool = False,
force_lang: str = "C",
path_prefix: Sequence[str] | None = None,
environ_update: dict[str, str] | None = None,
):
self.module = module
self.command = _ensure_list(command)
@ -117,10 +125,17 @@ class CmdRunner:
)
@property
def binary(self):
def binary(self) -> str:
return self.command[0]
def __call__(self, args_order=None, output_process=None, check_mode_skip=False, check_mode_return=None, **kwargs):
def __call__(
self,
args_order: str | Sequence[str] | None = None,
output_process: Callable[[int, str, str], t.Any] | None = None,
check_mode_skip: bool = False,
check_mode_return: t.Any | None = None,
**kwargs,
):
if output_process is None:
output_process = _process_as_is
if args_order is None:
@ -146,7 +161,15 @@ class CmdRunner:
class _CmdRunnerContext:
def __init__(self, runner, args_order, output_process, check_mode_skip, check_mode_return, **kwargs):
def __init__(
self,
runner: CmdRunner,
args_order: tuple[str, ...],
output_process: Callable[[int, str, str], t.Any],
check_mode_skip: bool,
check_mode_return: t.Any,
**kwargs,
) -> None:
self.runner = runner
self.args_order = tuple(args_order)
self.output_process = output_process
@ -204,7 +227,7 @@ class _CmdRunnerContext:
return self.results_processed
@property
def run_info(self):
def run_info(self) -> dict[str, t.Any]:
return dict(
check_rc=self.check_rc,
environ_update=self.environ_update,

View file

@ -11,36 +11,46 @@ from functools import wraps
from ansible.module_utils.common.collections import is_sequence
if t.TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Callable, Mapping, Sequence
ArgFormatType = Callable[[t.Any], list[str]]
ArgFormatType = Callable[[t.Any], Sequence[t.Any]]
_T = t.TypeVar("_T")
def _ensure_list(value):
return list(value) if is_sequence(value) else [value]
def _ensure_list(value: _T | Sequence[_T]) -> list[_T]:
return list(value) if is_sequence(value) else [value] # type: ignore # TODO need type assertion for is_sequence
class _ArgFormat:
def __init__(self, func, ignore_none=True, ignore_missing_value=False):
def __init__(
self,
func: ArgFormatType,
ignore_none: bool | None = True,
ignore_missing_value: bool = False,
) -> None:
self.func = func
self.ignore_none = ignore_none
self.ignore_missing_value = ignore_missing_value
def __call__(self, value):
def __call__(self, value: t.Any | None) -> list[str]:
ignore_none = self.ignore_none if self.ignore_none is not None else True
if value is None and ignore_none:
return []
f = self.func
return [str(x) for x in f(value)]
def __str__(self):
def __str__(self) -> str:
return f"<ArgFormat: func={self.func}, ignore_none={self.ignore_none}, ignore_missing_value={self.ignore_missing_value}>"
def __repr__(self):
def __repr__(self) -> str:
return str(self)
def as_bool(args_true, args_false=None, ignore_none=None):
def as_bool(
args_true: Sequence[t.Any] | t.Any,
args_false: Sequence[t.Any] | t.Any | None = None,
ignore_none: bool | None = None,
) -> _ArgFormat:
if args_false is not None:
if ignore_none is None:
ignore_none = False
@ -51,24 +61,24 @@ def as_bool(args_true, args_false=None, ignore_none=None):
)
def as_bool_not(args):
def as_bool_not(args: Sequence[t.Any] | t.Any) -> _ArgFormat:
return as_bool([], args, ignore_none=False)
def as_optval(arg, ignore_none=None):
def as_optval(arg, ignore_none: bool | None = None) -> _ArgFormat:
return _ArgFormat(lambda value: [f"{arg}{value}"], ignore_none=ignore_none)
def as_opt_val(arg, ignore_none=None):
def as_opt_val(arg: str, ignore_none: bool | None = None) -> _ArgFormat:
return _ArgFormat(lambda value: [arg, value], ignore_none=ignore_none)
def as_opt_eq_val(arg, ignore_none=None):
def as_opt_eq_val(arg: str, ignore_none: bool | None = None) -> _ArgFormat:
return _ArgFormat(lambda value: [f"{arg}={value}"], ignore_none=ignore_none)
def as_list(ignore_none=None, min_len=0, max_len=None):
def func(value):
def as_list(ignore_none: bool | None = None, min_len: int = 0, max_len: int | None = None) -> _ArgFormat:
def func(value: t.Any) -> list[t.Any]:
value = _ensure_list(value)
if len(value) < min_len:
raise ValueError(f"Parameter must have at least {min_len} element(s)")
@ -79,17 +89,21 @@ def as_list(ignore_none=None, min_len=0, max_len=None):
return _ArgFormat(func, ignore_none=ignore_none)
def as_fixed(*args):
def as_fixed(*args: t.Any) -> _ArgFormat:
if len(args) == 1 and is_sequence(args[0]):
args = args[0]
return _ArgFormat(lambda value: _ensure_list(args), ignore_none=False, ignore_missing_value=True)
def as_func(func, ignore_none=None):
def as_func(func: ArgFormatType, ignore_none: bool | None = None) -> _ArgFormat:
return _ArgFormat(func, ignore_none=ignore_none)
def as_map(_map, default=None, ignore_none=None):
def as_map(
_map: Mapping[t.Any, Sequence[t.Any] | t.Any],
default: Sequence[t.Any] | t.Any | None = None,
ignore_none: bool | None = None,
) -> _ArgFormat:
if default is None:
default = []
return _ArgFormat(lambda value: _ensure_list(_map.get(value, default)), ignore_none=ignore_none)
@ -126,5 +140,5 @@ def stack(fmt):
return wrapper
def is_argformat(fmt):
def is_argformat(fmt: object) -> t.TypeGuard[_ArgFormat]:
return isinstance(fmt, _ArgFormat)
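
For readers skimming these two hunks: ``CmdRunner`` wraps ``module.run_command()`` behind named arguments, and ``cmd_runner_fmt`` provides the formatters that turn each value into command-line tokens. A hedged usage sketch built from the functions visible above (``mytool`` and its options are hypothetical, not part of the collection):

from ansible_collections.community.general.plugins.module_utils.cmd_runner import (
    CmdRunner,
    cmd_runner_fmt,
)


def mytool_runner(module):
    # Each arg_formats key becomes a named argument; the formatter decides
    # how the value is rendered on the command line.
    return CmdRunner(
        module,
        command="mytool",
        arg_formats=dict(
            state=cmd_runner_fmt.as_map({"present": "--install", "absent": "--remove"}),
            force=cmd_runner_fmt.as_bool("--force"),
            name=cmd_runner_fmt.as_list(),
        ),
    )


# Inside a module, mirroring AndroidSdkManager above:
#   with mytool_runner(module)("state force name") as ctx:
#       rc, out, err = ctx.run(state="present", force=True, name=["pkg1", "pkg2"])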

View file

@ -14,6 +14,9 @@ from urllib.parse import urlencode
from ansible.module_utils.urls import open_url
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def get_consul_url(configuration):
return f"{configuration.scheme}://{configuration.host}:{configuration.port}/v1"
@ -120,7 +123,7 @@ class _ConsulModule:
operational_attributes: set[str] = set()
params: dict[str, t.Any] = {}
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self._module = module
self.params = _normalize_params(module.params, module.argument_spec)
self.api_params = {

View file

@ -6,11 +6,26 @@
from __future__ import annotations
import csv
import typing as t
from io import StringIO
from ansible.module_utils.common.text.converters import to_native
if t.TYPE_CHECKING:
from collections.abc import Sequence
class DialectParamsOrNone(t.TypedDict):
delimiter: t.NotRequired[str | None]
doublequote: t.NotRequired[bool | None]
escapechar: t.NotRequired[str | None]
lineterminator: t.NotRequired[str | None]
quotechar: t.NotRequired[str | None]
quoting: t.NotRequired[int | None]
skipinitialspace: t.NotRequired[bool | None]
strict: t.NotRequired[bool | None]
class CustomDialectFailureError(Exception):
pass
@ -22,7 +37,7 @@ class DialectNotAvailableError(Exception):
CSVError = csv.Error
def initialize_dialect(dialect, **kwargs):
def initialize_dialect(dialect: str, **kwargs: t.Unpack[DialectParamsOrNone]) -> str:
# Add Unix dialect from Python 3
class unix_dialect(csv.Dialect):
"""Describe the usual properties of Unix-generated CSV files."""
@ -43,7 +58,7 @@ def initialize_dialect(dialect, **kwargs):
dialect_params = {k: v for k, v in kwargs.items() if v is not None}
if dialect_params:
try:
csv.register_dialect("custom", dialect, **dialect_params)
csv.register_dialect("custom", dialect, **dialect_params) # type: ignore
except TypeError as e:
raise CustomDialectFailureError(f"Unable to create custom dialect: {e}") from e
dialect = "custom"
@ -51,7 +66,7 @@ def initialize_dialect(dialect, **kwargs):
return dialect
def read_csv(data, dialect, fieldnames=None):
def read_csv(data: str, dialect: str, fieldnames: Sequence[str] | None = None) -> csv.DictReader:
BOM = "\ufeff"
data = to_native(data, errors="surrogate_or_strict")
if data.startswith(BOM):
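
The ``DialectParamsOrNone`` TypedDict combined with ``t.Unpack`` above gives the ``**kwargs`` of ``initialize_dialect()`` per-key types while still allowing callers to pass ``None`` for unused options. A minimal generic sketch of that typing technique (requires Python 3.11+ for ``NotRequired``/``Unpack`` in the standard library; the names are illustrative):

import typing as t


class WindowOptions(t.TypedDict):
    # Every key is optional and may explicitly be None.
    width: t.NotRequired[int | None]
    title: t.NotRequired[str | None]


def open_window(name: str, **kwargs: t.Unpack[WindowOptions]) -> str:
    # Drop keys whose value is None, as initialize_dialect() does above.
    options = {k: v for k, v in kwargs.items() if v is not None}
    return f"{name}: {options}"


print(open_window("main", width=800, title=None))  # -> main: {'width': 800}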

View file

@ -14,6 +14,10 @@
from __future__ import annotations
import re
import typing as t
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
# Input patterns for is_input_dangerous function:
@ -162,7 +166,7 @@ def is_input_dangerous(string):
return any(pattern.search(string) for pattern in (PATTERN_1, PATTERN_2, PATTERN_3))
def check_input(module, *args):
def check_input(module: AnsibleModule, *args) -> None:
"""Wrapper for is_input_dangerous function."""
needs_to_check = args

View file

@ -8,15 +8,15 @@ from __future__ import annotations
import datetime as _datetime
def ensure_timezone_info(value):
def ensure_timezone_info(value: _datetime.datetime) -> _datetime.datetime:
if value.tzinfo is not None:
return value
return value.astimezone(_datetime.timezone.utc)
def fromtimestamp(value):
def fromtimestamp(value: int | float) -> _datetime.datetime:
return _datetime.datetime.fromtimestamp(value, tz=_datetime.timezone.utc)
def now():
def now() -> _datetime.datetime:
return _datetime.datetime.now(tz=_datetime.timezone.utc)

View file

@ -7,56 +7,60 @@ from __future__ import annotations
import traceback
import typing as t
from contextlib import contextmanager
from ansible.module_utils.basic import missing_required_lib
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
_deps = dict()
_deps: dict[str, _Dependency] = dict()
class _Dependency:
_states = ["pending", "failure", "success"]
def __init__(self, name, reason=None, url=None, msg=None):
def __init__(self, name: str, reason: str | None = None, url: str | None = None, msg: str | None = None) -> None:
self.name = name
self.reason = reason
self.url = url
self.msg = msg
self.state = 0
self.trace = None
self.exc = None
self.trace: str | None = None
self.exc: Exception | None = None
def succeed(self):
def succeed(self) -> None:
self.state = 2
def fail(self, exc, trace):
def fail(self, exc: Exception, trace: str) -> None:
self.state = 1
self.exc = exc
self.trace = trace
@property
def message(self):
def message(self) -> str:
if self.msg:
return str(self.msg)
else:
return missing_required_lib(self.name, reason=self.reason, url=self.url)
@property
def failed(self):
def failed(self) -> bool:
return self.state == 1
def validate(self, module):
def validate(self, module: AnsibleModule) -> None:
if self.failed:
module.fail_json(msg=self.message, exception=self.trace)
def __str__(self):
def __str__(self) -> str:
return f"<dependency: {self.name} [{self._states[self.state]}]>"
@contextmanager
def declare(name, *args, **kwargs):
def declare(name: str, *args, **kwargs) -> t.Generator[_Dependency]:
dep = _Dependency(name, *args, **kwargs)
try:
yield dep
@ -68,7 +72,7 @@ def declare(name, *args, **kwargs):
_deps[name] = dep
def _select_names(spec):
def _select_names(spec: str | None) -> list[str]:
dep_names = sorted(_deps)
if spec:
@ -86,14 +90,14 @@ def _select_names(spec):
return dep_names
def validate(module, spec=None):
def validate(module: AnsibleModule, spec: str | None = None) -> None:
for dep in _select_names(spec):
_deps[dep].validate(module)
def failed(spec=None):
def failed(spec: str | None = None) -> bool:
return any(_deps[d].failed for d in _select_names(spec))
def clear():
def clear() -> None:
_deps.clear()
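
This ``deps`` helper replaces the usual try/except import boilerplate: ``declare()`` records at import time whether a third-party import succeeded, and ``validate()`` later fails the module with a ``missing_required_lib()`` message if it did not. A usage sketch based on the functions shown in this hunk (the library name is only an example):

from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils import deps

# At import time: record success or failure instead of letting the ImportError escape.
with deps.declare("requests", url="https://pypi.org/project/requests/"):
    import requests  # noqa: F401


def main():
    module = AnsibleModule(argument_spec={})
    # Fails the module with the recorded traceback if the import above failed.
    deps.validate(module)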

View file

@ -24,7 +24,6 @@ import os
import re
import traceback
# (TODO: remove AnsibleModule from next line!)
from ansible.module_utils.basic import AnsibleModule, missing_required_lib # noqa: F401, pylint: disable=unused-import
from os.path import expanduser
from uuid import UUID
@ -57,7 +56,7 @@ class DimensionDataModule:
The base class containing common functionality used by Dimension Data modules for Ansible.
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
"""
Create a new DimensionDataModule.

View file

@ -12,7 +12,8 @@ from ansible_collections.community.general.plugins.module_utils.python_runner im
from ansible_collections.community.general.plugins.module_utils.module_helper import ModuleHelper
if t.TYPE_CHECKING:
from .cmd_runner_fmt import ArgFormatType
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.general.plugins.module_utils.cmd_runner import ArgFormatter
django_std_args = dict(
@ -36,7 +37,7 @@ _pks = dict(
primary_keys=dict(type="list", elements="str"),
)
_django_std_arg_fmts: dict[str, ArgFormatType] = dict(
_django_std_arg_fmts: dict[str, ArgFormatter] = dict(
all=cmd_runner_fmt.as_bool("--all"),
app=cmd_runner_fmt.as_opt_val("--app"),
apps=cmd_runner_fmt.as_list(),
@ -81,7 +82,7 @@ _args_menu = dict(
class _DjangoRunner(PythonRunner):
def __init__(self, module, arg_formats=None, **kwargs):
def __init__(self, module: AnsibleModule, arg_formats=None, **kwargs) -> None:
arg_fmts = dict(arg_formats) if arg_formats else {}
arg_fmts.update(_django_std_arg_fmts)
@ -108,12 +109,12 @@ class _DjangoRunner(PythonRunner):
class DjangoModuleHelper(ModuleHelper):
module = {}
django_admin_cmd: str | None = None
arg_formats: dict[str, ArgFormatType] = {}
arg_formats: dict[str, ArgFormatter] = {}
django_admin_arg_order: tuple[str, ...] | str = ()
_django_args: list[str] = []
_check_mode_arg: str = ""
def __init__(self):
def __init__(self) -> None:
self.module["argument_spec"], self.arg_formats = self._build_args(
self.module.get("argument_spec", {}), self.arg_formats, *(["std"] + self._django_args)
)
@ -122,9 +123,9 @@ class DjangoModuleHelper(ModuleHelper):
self.vars.command = self.django_admin_cmd
@staticmethod
def _build_args(arg_spec, arg_format, *names):
res_arg_spec = {}
res_arg_fmts = {}
def _build_args(arg_spec, arg_format, *names) -> tuple[dict[str, t.Any], dict[str, ArgFormatter]]:
res_arg_spec: dict[str, t.Any] = {}
res_arg_fmts: dict[str, ArgFormatter] = {}
for name in names:
args, fmts = _args_menu[name]
res_arg_spec = dict_merge(res_arg_spec, args)

View file

@ -5,9 +5,13 @@
from __future__ import annotations
import json
import typing as t
from ansible.module_utils.urls import fetch_url
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
class GandiLiveDNSAPI:
api_endpoint = "https://api.gandi.net/v5/livedns"
@ -21,7 +25,7 @@ class GandiLiveDNSAPI:
attribute_map = {"record": "rrset_name", "type": "rrset_type", "ttl": "rrset_ttl", "values": "rrset_values"}
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.api_key = module.params["api_key"]
self.personal_access_token = module.params["personal_access_token"]

View file

@ -4,8 +4,13 @@
from __future__ import annotations
import typing as t
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
_state_map = {
"present": "--set",
@ -14,7 +19,7 @@ _state_map = {
}
def gconftool2_runner(module, **kwargs):
def gconftool2_runner(module: AnsibleModule, **kwargs) -> CmdRunner:
return CmdRunner(
module,
command="gconftool-2",

View file

@ -4,10 +4,15 @@
from __future__ import annotations
import typing as t
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def gio_mime_runner(module, **kwargs):
def gio_mime_runner(module: AnsibleModule, **kwargs) -> CmdRunner:
return CmdRunner(
module,
command=["gio"],
@ -21,8 +26,8 @@ def gio_mime_runner(module, **kwargs):
)
def gio_mime_get(runner, mime_type):
def process(rc, out, err):
def gio_mime_get(runner: CmdRunner, mime_type) -> str | None:
def process(rc, out, err) -> str | None:
if err.startswith("No default applications for"):
return None
out = out.splitlines()[0]

View file

@ -5,18 +5,19 @@
from __future__ import annotations
import traceback
import typing as t
from urllib.parse import urljoin
from ansible.module_utils.basic import missing_required_lib
from ansible_collections.community.general.plugins.module_utils.version import LooseVersion
from urllib.parse import urljoin
import traceback
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def _determine_list_all_kwargs(version) -> dict[str, t.Any]:
def _determine_list_all_kwargs(version: str) -> dict[str, t.Any]:
gitlab_version = LooseVersion(version)
if gitlab_version >= LooseVersion("4.0.0"):
# 4.0.0 removed 'as_list'
@ -42,7 +43,7 @@ except Exception:
list_all_kwargs = {}
def auth_argument_spec(spec=None):
def auth_argument_spec(spec: dict[str, t.Any] | None = None) -> dict[str, t.Any]:
arg_spec = dict(
ca_path=dict(type="str"),
api_token=dict(type="str", no_log=True),
@ -76,7 +77,7 @@ def find_group(gitlab_instance, identifier):
return group
def ensure_gitlab_package(module, min_version=None):
def ensure_gitlab_package(module: AnsibleModule, min_version=None) -> None:
if not HAS_GITLAB_PACKAGE:
module.fail_json(
msg=missing_required_lib("python-gitlab", url="https://python-gitlab.readthedocs.io/en/stable/"),
@ -92,7 +93,7 @@ def ensure_gitlab_package(module, min_version=None):
)
def gitlab_authentication(module, min_version=None):
def gitlab_authentication(module: AnsibleModule, min_version=None) -> gitlab.Gitlab:
ensure_gitlab_package(module, min_version=min_version)
gitlab_url = module.params["api_url"]
@ -121,7 +122,7 @@ def gitlab_authentication(module, min_version=None):
private_token=gitlab_token,
oauth_token=gitlab_oauth_token,
job_token=gitlab_job_token,
api_version=4,
api_version="4",
)
gitlab_instance.auth()
except (gitlab.exceptions.GitlabAuthenticationError, gitlab.exceptions.GitlabGetError) as e:
@ -137,7 +138,7 @@ def gitlab_authentication(module, min_version=None):
return gitlab_instance
def filter_returned_variables(gitlab_variables):
def filter_returned_variables(gitlab_variables) -> list[dict[str, t.Any]]:
# pop properties we don't know
existing_variables = [dict(x.attributes) for x in gitlab_variables]
KNOWN = [
@ -158,9 +159,11 @@ def filter_returned_variables(gitlab_variables):
return existing_variables
def vars_to_variables(vars, module):
def vars_to_variables(
vars: dict[str, str | int | float | dict[str, t.Any]], module: AnsibleModule
) -> list[dict[str, t.Any]]:
# transform old vars to new variables structure
variables = list()
variables = []
for item, value in vars.items():
if isinstance(value, (str, int, float)):
variables.append(
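
The second changelog entry corresponds to the ``api_version="4"`` change in ``gitlab_authentication()`` above: python-gitlab's constructor signature expects the API version as a string (its default is ``"4"``), so passing the integer ``4`` relied on implicit coercion and trips type checking. A hedged sketch of the corrected constructor call (URL and token are placeholders):

import gitlab

gl = gitlab.Gitlab(
    "https://gitlab.example.com",   # placeholder URL
    private_token="REDACTED",       # placeholder token
    api_version="4",                # string, matching python-gitlab's signature
)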

View file

@ -5,9 +5,13 @@
from __future__ import annotations
import traceback
import typing as t
from ansible.module_utils.basic import env_fallback, missing_required_lib
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
HAS_HEROKU = False
HEROKU_IMP_ERR = None
try:
@ -19,17 +23,17 @@ except ImportError:
class HerokuHelper:
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.check_lib()
self.api_key = module.params["api_key"]
def check_lib(self):
def check_lib(self) -> None:
if not HAS_HEROKU:
self.module.fail_json(msg=missing_required_lib("heroku3"), exception=HEROKU_IMP_ERR)
@staticmethod
def heroku_argument_spec():
def heroku_argument_spec() -> dict[str, t.Any]:
return dict(
api_key=dict(fallback=(env_fallback, ["HEROKU_API_KEY", "TF_VAR_HEROKU_API_KEY"]), type="str", no_log=True)
)

View file

@ -7,9 +7,13 @@ from __future__ import annotations
import os
import re
import typing as t
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def _create_regex_group_complement(s):
def _create_regex_group_complement(s: str) -> re.Pattern:
lines = (line.strip() for line in s.split("\n") if line.strip())
chars = [_f for _f in (line.split("#")[0].strip() for line in lines) if _f]
group = rf"[^{''.join(chars)}]"
@ -52,7 +56,7 @@ class HomebrewValidate:
# class validations -------------------------------------------- {{{
@classmethod
def valid_path(cls, path):
def valid_path(cls, path: list[str] | str) -> bool:
"""
`path` must be one of:
- list of paths
@ -77,7 +81,7 @@ class HomebrewValidate:
return all(cls.valid_brew_path(path_) for path_ in paths)
@classmethod
def valid_brew_path(cls, brew_path):
def valid_brew_path(cls, brew_path: str | None) -> bool:
"""
`brew_path` must be one of:
- None
@ -95,7 +99,7 @@ class HomebrewValidate:
return isinstance(brew_path, str) and not cls.INVALID_BREW_PATH_REGEX.search(brew_path)
@classmethod
def valid_package(cls, package):
def valid_package(cls, package: str | None) -> bool:
"""A valid package is either None or alphanumeric."""
if package is None:
@ -104,8 +108,7 @@ class HomebrewValidate:
return isinstance(package, str) and not cls.INVALID_PACKAGE_REGEX.search(package)
def parse_brew_path(module):
# type: (...) -> str
def parse_brew_path(module: AnsibleModule) -> str:
"""Attempt to find the Homebrew executable path.
Requires:

View file

@ -7,6 +7,7 @@ from __future__ import annotations
import re
import time
import traceback
import typing as t
THIRD_LIBRARIES_IMP_ERR = None
try:
@ -24,37 +25,37 @@ from ansible.module_utils.common.text.converters import to_text
class HwcModuleException(Exception):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__()
self._message = message
def __str__(self):
def __str__(self) -> str:
return f"[HwcClientException] message={self._message}"
class HwcClientException(Exception):
def __init__(self, code, message):
def __init__(self, code: int, message: str) -> None:
super().__init__()
self._code = code
self._message = message
def __str__(self):
def __str__(self) -> str:
msg = f" code={self._code}," if self._code != 0 else ""
return f"[HwcClientException]{msg} message={self._message}"
class HwcClientException404(HwcClientException):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(404, message)
def __str__(self):
def __str__(self) -> str:
return f"[HwcClientException404] message={self._message}"
def session_method_wrapper(f):
def _wrap(self, url, *args, **kwargs):
def _wrap(self, url: str, *args, **kwargs):
try:
url = self.endpoint + url
r = f(self, url, *args, **kwargs)
@ -91,7 +92,7 @@ def session_method_wrapper(f):
class _ServiceClient:
def __init__(self, client, endpoint, product):
def __init__(self, client, endpoint: str, product):
self._client = client
self._endpoint = endpoint
self._default_header = {
@ -100,30 +101,30 @@ class _ServiceClient:
}
@property
def endpoint(self):
def endpoint(self) -> str:
return self._endpoint
@endpoint.setter
def endpoint(self, e):
def endpoint(self, e: str) -> None:
self._endpoint = e
@session_method_wrapper
def get(self, url, body=None, header=None, timeout=None):
def get(self, url: str, body=None, header: dict[str, t.Any] | None = None, timeout=None):
return self._client.get(url, json=body, timeout=timeout, headers=self._header(header))
@session_method_wrapper
def post(self, url, body=None, header=None, timeout=None):
def post(self, url: str, body=None, header: dict[str, t.Any] | None = None, timeout=None):
return self._client.post(url, json=body, timeout=timeout, headers=self._header(header))
@session_method_wrapper
def delete(self, url, body=None, header=None, timeout=None):
def delete(self, url: str, body=None, header: dict[str, t.Any] | None = None, timeout=None):
return self._client.delete(url, json=body, timeout=timeout, headers=self._header(header))
@session_method_wrapper
def put(self, url, body=None, header=None, timeout=None):
def put(self, url: str, body=None, header: dict[str, t.Any] | None = None, timeout=None):
return self._client.put(url, json=body, timeout=timeout, headers=self._header(header))
def _header(self, header):
def _header(self, header: dict[str, t.Any] | None) -> dict[str, t.Any]:
if header and isinstance(header, dict):
for k, v in self._default_header.items():
if k not in header:
@ -135,18 +136,18 @@ class _ServiceClient:
class Config:
def __init__(self, module, product):
def __init__(self, module: AnsibleModule, product) -> None:
self._project_client = None
self._domain_client = None
self._module = module
self._product = product
self._endpoints = {}
self._endpoints: dict[str, t.Any] = {}
self._validate()
self._gen_provider_client()
@property
def module(self):
def module(self) -> AnsibleModule:
return self._module
def client(self, region, service_type, service_level):
@ -380,7 +381,7 @@ def navigate_value(data, index, array_index=None):
return d
def build_path(module, path, kv=None):
def build_path(module: AnsibleModule, path, kv=None):
if kv is None:
kv = dict()
@ -400,7 +401,7 @@ def build_path(module, path, kv=None):
return path.format(**v)
def get_region(module):
def get_region(module: AnsibleModule):
if module.params["region"]:
return module.params["region"]

View file

@ -7,10 +7,14 @@
from __future__ import annotations
import traceback
import typing as t
from functools import wraps
from ansible.module_utils.basic import missing_required_lib
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
PYXCLI_INSTALLED = True
PYXCLI_IMP_ERR = None
try:
@ -49,7 +53,7 @@ def xcli_wrapper(func):
"""Catch xcli errors and return a proper message"""
@wraps(func)
def wrapper(module, *args, **kwargs):
def wrapper(module: AnsibleModule, *args, **kwargs):
try:
return func(module, *args, **kwargs)
except errors.CommandExecutionError as e:
@ -59,7 +63,7 @@ def xcli_wrapper(func):
@xcli_wrapper
def connect_ssl(module):
def connect_ssl(module: AnsibleModule):
endpoints = module.params["endpoints"]
username = module.params["username"]
password = module.params["password"]
@ -72,7 +76,7 @@ def connect_ssl(module):
module.fail_json(msg=f"Connection with Spectrum Accelerate system has failed: {e}.")
def spectrum_accelerate_spec():
def spectrum_accelerate_spec() -> dict[str, t.Any]:
"""Return arguments spec for AnsibleModule"""
return dict(
endpoints=dict(required=True),
@ -82,7 +86,7 @@ def spectrum_accelerate_spec():
@xcli_wrapper
def execute_pyxcli_command(module, xcli_command, xcli_client):
def execute_pyxcli_command(module: AnsibleModule, xcli_command, xcli_client):
pyxcli_args = build_pyxcli_command(module.params)
getattr(xcli_client.cmd, xcli_command)(**(pyxcli_args))
return True
@ -99,6 +103,6 @@ def build_pyxcli_command(fields):
return pyxcli_args
def is_pyxcli_installed(module):
def is_pyxcli_installed(module: AnsibleModule) -> None:
if not PYXCLI_INSTALLED:
module.fail_json(msg=missing_required_lib("pyxcli"), exception=PYXCLI_IMP_ERR)

View file

@ -4,13 +4,15 @@
from __future__ import annotations
from ansible_collections.community.general.plugins.module_utils.redfish_utils import RedfishUtils
import time
import typing as t
from ansible_collections.community.general.plugins.module_utils.redfish_utils import RedfishUtils
class iLORedfishUtils(RedfishUtils):
def get_ilo_sessions(self):
result = {}
def get_ilo_sessions(self) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
# listing all users has always been slower than other operations, why?
session_list = []
sessions_results = []
@ -48,8 +50,8 @@ class iLORedfishUtils(RedfishUtils):
result["ret"] = True
return result
def set_ntp_server(self, mgr_attributes):
result = {}
def set_ntp_server(self, mgr_attributes) -> dict[str, t.Any]:
result: dict[str, t.Any] = {}
setkey = mgr_attributes["mgr_attr_name"]
nic_info = self.get_manager_ethernet_uri()
@ -60,7 +62,7 @@ class iLORedfishUtils(RedfishUtils):
return response
result["ret"] = True
data = response["data"]
payload = {"DHCPv4": {"UseNTPServers": ""}}
payload: dict[str, t.Any] = {"DHCPv4": {"UseNTPServers": ""}}
if data["DHCPv4"]["UseNTPServers"]:
payload["DHCPv4"]["UseNTPServers"] = False
@ -97,7 +99,7 @@ class iLORedfishUtils(RedfishUtils):
return {"ret": True, "changed": True, "msg": f"Modified {mgr_attributes['mgr_attr_name']}"}
def set_time_zone(self, attr):
def set_time_zone(self, attr) -> dict[str, t.Any]:
key = attr["mgr_attr_name"]
uri = f"{self.manager_uri}DateTime/"
@ -124,7 +126,7 @@ class iLORedfishUtils(RedfishUtils):
return {"ret": True, "changed": True, "msg": f"Modified {attr['mgr_attr_name']}"}
def set_dns_server(self, attr):
def set_dns_server(self, attr) -> dict[str, t.Any]:
key = attr["mgr_attr_name"]
nic_info = self.get_manager_ethernet_uri()
uri = nic_info["nic_addr"]
@ -148,7 +150,7 @@ class iLORedfishUtils(RedfishUtils):
return {"ret": True, "changed": True, "msg": f"Modified {attr['mgr_attr_name']}"}
def set_domain_name(self, attr):
def set_domain_name(self, attr) -> dict[str, t.Any]:
key = attr["mgr_attr_name"]
nic_info = self.get_manager_ethernet_uri()
@ -160,7 +162,7 @@ class iLORedfishUtils(RedfishUtils):
data = response["data"]
payload = {"DHCPv4": {"UseDomainName": ""}}
payload: dict[str, t.Any] = {"DHCPv4": {"UseDomainName": ""}}
if data["DHCPv4"]["UseDomainName"]:
payload["DHCPv4"]["UseDomainName"] = False
@ -185,7 +187,7 @@ class iLORedfishUtils(RedfishUtils):
return response
return {"ret": True, "changed": True, "msg": f"Modified {attr['mgr_attr_name']}"}
def set_wins_registration(self, mgrattr):
def set_wins_registration(self, mgrattr) -> dict[str, t.Any]:
Key = mgrattr["mgr_attr_name"]
nic_info = self.get_manager_ethernet_uri()
@ -198,7 +200,7 @@ class iLORedfishUtils(RedfishUtils):
return response
return {"ret": True, "changed": True, "msg": f"Modified {mgrattr['mgr_attr_name']}"}
def get_server_poststate(self):
def get_server_poststate(self) -> dict[str, t.Any]:
# Get server details
response = self.get_request(self.root_uri + self.systems_uri)
if not response["ret"]:
@ -210,7 +212,7 @@ class iLORedfishUtils(RedfishUtils):
else:
return {"ret": True, "server_poststate": server_data["Oem"]["Hp"]["PostState"]}
def wait_for_ilo_reboot_completion(self, polling_interval=60, max_polling_time=1800):
def wait_for_ilo_reboot_completion(self, polling_interval=60, max_polling_time=1800) -> dict[str, t.Any]:
# This method checks if OOB controller reboot is completed
time.sleep(10)

View file

@ -5,11 +5,15 @@
from __future__ import annotations
import traceback
import typing as t
from ansible.module_utils.basic import missing_required_lib
from ansible_collections.community.general.plugins.module_utils.version import LooseVersion
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
REQUESTS_IMP_ERR = None
try:
import requests.exceptions # noqa: F401, pylint: disable=unused-import
@ -32,7 +36,7 @@ except ImportError:
class InfluxDb:
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.params = self.module.params
self.check_lib()
@ -43,7 +47,7 @@ class InfluxDb:
self.password = self.params["password"]
self.database_name = self.params.get("database_name")
def check_lib(self):
def check_lib(self) -> None:
if not HAS_REQUESTS:
self.module.fail_json(msg=missing_required_lib("requests"), exception=REQUESTS_IMP_ERR)
@ -51,7 +55,7 @@ class InfluxDb:
self.module.fail_json(msg=missing_required_lib("influxdb"), exception=INFLUXDB_IMP_ERR)
@staticmethod
def influxdb_argument_spec():
def influxdb_argument_spec() -> dict[str, t.Any]:
return dict(
hostname=dict(type="str", default="localhost"),
port=dict(type="int", default=8086),
@ -67,7 +71,7 @@ class InfluxDb:
udp_port=dict(type="int", default=4444),
)
def connect_to_influxdb(self):
def connect_to_influxdb(self) -> InfluxDBClient:
args = dict(
host=self.hostname,
port=self.port,

View file

@ -13,17 +13,21 @@ from __future__ import annotations
import json
import os
import re
import socket
import uuid
import re
from ansible.module_utils.common.text.converters import to_bytes, to_text
from ansible.module_utils.urls import fetch_url, HAS_GSSAPI
from ansible.module_utils.basic import env_fallback, AnsibleFallbackNotFound
import typing as t
from urllib.parse import quote
from ansible.module_utils.basic import env_fallback, AnsibleFallbackNotFound
from ansible.module_utils.common.text.converters import to_bytes, to_text
from ansible.module_utils.urls import fetch_url, HAS_GSSAPI
def _env_then_dns_fallback(*args, **kwargs):
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def _env_then_dns_fallback(*args, **kwargs) -> str:
"""Load value from environment or DNS in that order"""
try:
result = env_fallback(*args, **kwargs)
@ -41,7 +45,7 @@ def _env_then_dns_fallback(*args, **kwargs):
class IPAClient:
def __init__(self, module, host, port, protocol):
def __init__(self, module: AnsibleModule, host, port, protocol):
self.host = host
self.port = port
self.protocol = protocol
@ -50,10 +54,10 @@ class IPAClient:
self.timeout = module.params.get("ipa_timeout")
self.use_gssapi = False
def get_base_url(self):
def get_base_url(self) -> str:
return f"{self.protocol}://{self.host}/ipa"
def get_json_url(self):
def get_json_url(self) -> str:
return f"{self.get_base_url()}/session/json"
def login(self, username, password):
@ -98,7 +102,7 @@ class IPAClient:
{"referer": self.get_base_url(), "Content-Type": "application/json", "Accept": "application/json"}
)
def _fail(self, msg, e):
def _fail(self, msg: str, e) -> t.NoReturn:
if "message" in e:
err_string = e.get("message")
else:
@ -205,7 +209,7 @@ class IPAClient:
return changed
def ipa_argument_spec():
def ipa_argument_spec() -> dict[str, t.Any]:
return dict(
ipa_prot=dict(type="str", default="https", choices=["http", "https"], fallback=(env_fallback, ["IPA_PROT"])),
ipa_host=dict(type="str", default="ipa.example.com", fallback=(_env_then_dns_fallback, ["IPA_HOST"])),

View file

@ -10,7 +10,7 @@ import os
import time
def download_updates_file(updates_expiration):
def download_updates_file(updates_expiration: int | float) -> tuple[str, bool]:
updates_filename = "jenkins-plugin-cache.json"
updates_dir = os.path.expanduser("~/.ansible/tmp")
updates_file = os.path.join(updates_dir, updates_filename)

View file

@ -9,6 +9,10 @@ from __future__ import annotations
import re
import traceback
import typing as t
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
try:
import ldap
@ -26,7 +30,7 @@ except ImportError:
HAS_LDAP = False
def gen_specs(**specs):
def gen_specs(**specs: t.Any) -> dict[str, t.Any]:
specs.update(
{
"bind_dn": dict(),
@ -47,12 +51,12 @@ def gen_specs(**specs):
return specs
def ldap_required_together():
def ldap_required_together() -> list[list[str]]:
return [["client_cert", "client_key"]]
class LdapGeneric:
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
# Shortcuts
self.module = module
self.bind_dn = self.module.params["bind_dn"]
@ -76,11 +80,11 @@ class LdapGeneric:
else:
self.dn = self.module.params["dn"]
def fail(self, msg, exn):
def fail(self, msg: str, exn: str | Exception) -> t.NoReturn:
self.module.fail_json(msg=msg, details=f"{exn}", exception=traceback.format_exc())
def _find_dn(self):
dn = self.module.params["dn"]
def _find_dn(self) -> str:
dn: str = self.module.params["dn"]
explode_dn = ldap.dn.explode_dn(dn)
@ -130,7 +134,7 @@ class LdapGeneric:
return connection
def _xorder_dn(self):
def _xorder_dn(self) -> bool:
# match X_ORDERed DNs
regex = r".+\{\d+\}.+"
explode_dn = ldap.dn.explode_dn(self.module.params["dn"])

View file

@ -14,6 +14,6 @@ from __future__ import annotations
from ansible.module_utils.ansible_release import __version__ as ansible_version
def get_user_agent(module):
def get_user_agent(module: str) -> str:
"""Retrieve a user-agent to send with LinodeClient requests."""
return f"Ansible-{module}/{ansible_version}"

View file

@ -4,10 +4,15 @@
from __future__ import annotations
import typing as t
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def locale_runner(module):
def locale_runner(module: AnsibleModule) -> CmdRunner:
runner = CmdRunner(
module,
command=["locale", "-a"],
@ -16,7 +21,7 @@ def locale_runner(module):
return runner
def locale_gen_runner(module):
def locale_gen_runner(module: AnsibleModule) -> CmdRunner:
runner = CmdRunner(
module,
command="locale-gen",

View file

@ -6,10 +6,11 @@ from __future__ import annotations
import http.client as http_client
import json
import os
import socket
import ssl
import json
import typing as t
from urllib.parse import urlparse
from ansible.module_utils.urls import generic_urlparse
@ -20,7 +21,7 @@ HTTPSConnection = http_client.HTTPSConnection
class UnixHTTPConnection(HTTPConnection):
def __init__(self, path):
def __init__(self, path: str) -> None:
HTTPConnection.__init__(self, "localhost")
self.path = path
@ -31,33 +32,34 @@ class UnixHTTPConnection(HTTPConnection):
class LXDClientException(Exception):
def __init__(self, msg, **kwargs):
def __init__(self, msg: str, **kwargs) -> None:
self.msg = msg
self.kwargs = kwargs
class LXDClient:
def __init__(
self, url, key_file=None, cert_file=None, debug=False, server_cert_file=None, server_check_hostname=True
):
self,
url: str,
key_file: str | None = None,
cert_file: str | None = None,
debug: bool = False,
server_cert_file: str | None = None,
server_check_hostname: bool = True,
) -> None:
"""LXD Client.
:param url: The URL of the LXD server. (e.g. unix:/var/lib/lxd/unix.socket or https://127.0.0.1)
:type url: ``str``
:param key_file: The path of the client certificate key file.
:type key_file: ``str``
:param cert_file: The path of the client certificate file.
:type cert_file: ``str``
:param debug: The debug flag. The request and response are stored in logs when debug is true.
:type debug: ``bool``
:param server_cert_file: The path of the server certificate file.
:type server_cert_file: ``str``
:param server_check_hostname: Whether to check the server's hostname as part of TLS verification.
:type server_check_hostname: ``bool``
"""
self.url = url
self.debug = debug
self.logs = []
self.logs: list[dict[str, t.Any]] = []
self.connection: UnixHTTPConnection | HTTPSConnection
if url.startswith("https:"):
self.cert_file = cert_file
self.key_file = key_file
@ -67,7 +69,7 @@ class LXDClient:
# Check that the received cert is signed by the provided server_cert_file
ctx.load_verify_locations(cafile=server_cert_file)
ctx.check_hostname = server_check_hostname
ctx.load_cert_chain(cert_file, keyfile=key_file)
ctx.load_cert_chain(cert_file, keyfile=key_file) # type: ignore # TODO!
self.connection = HTTPSConnection(parts.get("netloc"), context=ctx)
elif url.startswith("unix:"):
unix_socket_path = url[len("unix:") :]
@ -75,7 +77,7 @@ class LXDClient:
else:
raise LXDClientException("URL scheme must be unix: or https:")
def do(self, method, url, body_json=None, ok_error_codes=None, timeout=None, wait_for_container=None):
def do(self, method: str, url: str, body_json=None, ok_error_codes=None, timeout=None, wait_for_container=None):
resp_json = self._send_request(method, url, body_json=body_json, ok_error_codes=ok_error_codes, timeout=timeout)
if resp_json["type"] == "async":
url = f"{resp_json['operation']}/wait"
@ -91,7 +93,7 @@ class LXDClient:
body_json = {"type": "client", "password": trust_password}
return self._send_request("POST", "/1.0/certificates", body_json=body_json)
def _send_request(self, method, url, body_json=None, ok_error_codes=None, timeout=None):
def _send_request(self, method: str, url: str, body_json=None, ok_error_codes=None, timeout=None):
try:
body = json.dumps(body_json)
self.connection.request(method, url, body=body)
@ -133,9 +135,9 @@ class LXDClient:
return err
def default_key_file():
def default_key_file() -> str:
return os.path.expanduser("~/.config/lxc/client.key")
def default_cert_file():
def default_cert_file() -> str:
return os.path.expanduser("~/.config/lxc/client.crt")

View file

@ -15,9 +15,13 @@ from __future__ import annotations
import os
import traceback
import typing as t
from ansible.module_utils.basic import missing_required_lib
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
CLIENT_IMP_ERR = None
try:
from manageiq_client.api import ManageIQClient
@ -28,7 +32,7 @@ except ImportError:
HAS_CLIENT = False
def manageiq_argument_spec():
def manageiq_argument_spec() -> dict[str, t.Any]:
options = dict(
url=dict(default=os.environ.get("MIQ_URL", None)),
username=dict(default=os.environ.get("MIQ_USERNAME", None)),
@ -43,27 +47,28 @@ def manageiq_argument_spec():
)
def check_client(module):
def check_client(module: AnsibleModule) -> None:
if not HAS_CLIENT:
module.fail_json(msg=missing_required_lib("manageiq-client"), exception=CLIENT_IMP_ERR)
def validate_connection_params(module):
params = module.params["manageiq_connection"]
def validate_connection_params(module: AnsibleModule) -> dict[str, t.Any]:
params: dict[str, t.Any] = module.params["manageiq_connection"]
error_str = "missing required argument: manageiq_connection[{}]"
url = params["url"]
token = params["token"]
username = params["username"]
password = params["password"]
url: str | None = params["url"]
token: str | None = params["token"]
username: str | None = params["username"]
password: str | None = params["password"]
if (url and username and password) or (url and token):
return params
for arg in ["url", "username", "password"]:
if params[arg] in (None, ""):
module.fail_json(msg=error_str.format(arg))
raise AssertionError("should be unreachable")
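Put differently, validate_connection_params() accepts either username/password or token authentication, always together with a URL. A hedged sketch of the two accepted shapes (all values are placeholders):
# URL plus username and password
params = {"url": "https://manageiq.example.com", "username": "admin", "password": "secret", "token": None}
# URL plus token
params = {"url": "https://manageiq.example.com", "token": "<api token>", "username": None, "password": None}
# Anything else fails with "missing required argument: manageiq_connection[...]"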
def manageiq_entities():
def manageiq_entities() -> dict[str, str]:
return {
"provider": "providers",
"host": "hosts",
@ -87,7 +92,7 @@ class ManageIQ:
class encapsulating ManageIQ API client.
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
# handle import errors
check_client(module)
@ -111,7 +116,7 @@ class ManageIQ:
self.module.fail_json(msg=f"failed to open connection ({url}): {e}")
@property
def module(self):
def module(self) -> AnsibleModule:
"""Ansible module module
Returns:
@ -120,7 +125,7 @@ class ManageIQ:
return self._module
@property
def api_url(self):
def api_url(self) -> str:
"""Base ManageIQ API
Returns:
@ -192,7 +197,7 @@ class ManageIQPolicies:
Object to execute policies management operations of manageiq resources.
"""
def __init__(self, manageiq, resource_type, resource_id):
def __init__(self, manageiq: ManageIQ, resource_type, resource_id):
self.manageiq = manageiq
self.module = self.manageiq.module
@ -330,7 +335,7 @@ class ManageIQTags:
Object to execute tags management operations of manageiq resources.
"""
def __init__(self, manageiq, resource_type, resource_id):
def __init__(self, manageiq: ManageIQ, resource_type, resource_id):
self.manageiq = manageiq
self.module = self.manageiq.module

View file

@ -8,12 +8,16 @@ from __future__ import annotations
import json
import os
import uuid
import typing as t
from urllib.error import URLError, HTTPError
from urllib.parse import urlparse
from ansible.module_utils.urls import open_url
from ansible.module_utils.common.text.converters import to_native
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
GET_HEADERS = {"accept": "application/json"}
PUT_HEADERS = {"content-type": "application/json", "accept": "application/json"}
@ -24,7 +28,7 @@ HEALTH_OK = 5
class OcapiUtils:
def __init__(self, creds, base_uri, proxy_slot_number, timeout, module):
def __init__(self, creds, base_uri, proxy_slot_number, timeout, module: AnsibleModule) -> None:
self.root_uri = base_uri
self.proxy_slot_number = proxy_slot_number
self.creds = creds

View file

@ -14,11 +14,11 @@ class OnePasswordConfig:
"~/.config/.op/config",
)
def __init__(self):
def __init__(self) -> None:
self._config_file_path = ""
@property
def config_file_path(self):
def config_file_path(self) -> str | None:
if self._config_file_path:
return self._config_file_path
@ -27,3 +27,5 @@ class OnePasswordConfig:
if os.path.exists(realpath):
self._config_file_path = realpath
return self._config_file_path
return None

View file

@ -6,12 +6,16 @@ from __future__ import annotations
import json
import sys
import typing as t
from ansible.module_utils.basic import env_fallback
from ansible.module_utils.urls import fetch_url
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def online_argument_spec():
def online_argument_spec() -> dict[str, t.Any]:
return dict(
api_token=dict(
required=True,
@ -28,7 +32,7 @@ def online_argument_spec():
class OnlineException(Exception):
def __init__(self, message):
def __init__(self, message: str) -> None:
self.message = message
@ -60,7 +64,7 @@ class Response:
class Online:
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.headers = {
"Authorization": f"Bearer {self.module.params.get('api_token')}",

View file

@ -5,9 +5,14 @@
from __future__ import annotations
import re
import typing as t
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
_state_map = {
"present": "create",
"absent": "remove",
@ -46,7 +51,7 @@ def fmt_resource_argument(value):
return ["--group" if value["argument_action"] == "group" else value["argument_action"]] + value["argument_option"]
def get_pacemaker_maintenance_mode(runner):
def get_pacemaker_maintenance_mode(runner: CmdRunner) -> bool:
with runner("cli_action config") as ctx:
rc, out, err = ctx.run(cli_action="property")
maint_mode_re = re.compile(r"maintenance-mode.*true", re.IGNORECASE)
@ -54,7 +59,7 @@ def get_pacemaker_maintenance_mode(runner):
return bool(maintenance_mode_output)
def pacemaker_runner(module, **kwargs):
def pacemaker_runner(module: AnsibleModule, **kwargs) -> CmdRunner:
runner_command = ["pcs"]
runner = CmdRunner(
module,

View file

@ -4,12 +4,14 @@
from __future__ import annotations
import json
import typing as t
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
pipx_common_argspec = {
"global": dict(type="bool", default=False),
@ -36,7 +38,7 @@ _state_map = dict(
)
def pipx_runner(module, command, **kwargs):
def pipx_runner(module: AnsibleModule, command, **kwargs) -> CmdRunner:
arg_formats = dict(
state=cmd_runner_fmt.as_map(_state_map),
name=cmd_runner_fmt.as_list(),

View file

@ -4,8 +4,13 @@
from __future__ import annotations
import typing as t
from ansible_collections.community.general.plugins.module_utils import deps
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
with deps.declare("packaging"):
from packaging.requirements import Requirement
@ -13,11 +18,11 @@ with deps.declare("packaging"):
class PackageRequirement:
def __init__(self, module, name):
def __init__(self, module: AnsibleModule, name: str) -> None:
self.module = module
self.parsed_name, self.requirement = self._parse_spec(name)
def _parse_spec(self, name):
def _parse_spec(self, name: str) -> tuple[str, Requirement | None]:
"""
Parse a package name that may include version specifiers using PEP 508.
Returns a tuple of (name, requirement) where requirement is of type packaging.requirements.Requirement and it may be None.
@ -49,7 +54,7 @@ class PackageRequirement:
except Exception as e:
raise ValueError(f"Invalid package specification for '{name}': {e}") from e
def matches_version(self, version):
def matches_version(self, version: str):
"""
Check if a version string fulfills a version specifier.

View file

@ -4,29 +4,32 @@
from __future__ import annotations
import os
import typing as t
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
_PUPPET_PATH_PREFIX = ["/opt/puppetlabs/bin"]
def get_facter_dir():
def get_facter_dir() -> str:
if os.getuid() == 0:
return "/etc/facter/facts.d"
else:
return os.path.expanduser("~/.facter/facts.d")
def _puppet_cmd(module):
def _puppet_cmd(module: AnsibleModule) -> str | None:
return module.get_bin_path("puppet", False, _PUPPET_PATH_PREFIX)
# If the `timeout` CLI command feature is removed,
# Then we could add this as a fixed param to `puppet_runner`
def ensure_agent_enabled(module):
def ensure_agent_enabled(module: AnsibleModule) -> None:
runner = CmdRunner(
module,
command="puppet",
@ -44,7 +47,7 @@ def ensure_agent_enabled(module):
module.fail_json(msg="Puppet agent state could not be determined.")
def puppet_runner(module):
def puppet_runner(module: AnsibleModule) -> CmdRunner:
# Keeping backward compatibility, allow for running with the `timeout` CLI command.
# If this can be replaced with ansible `timeout` parameter in playbook,
# then this function could be removed.

View file

@ -5,31 +5,35 @@
from __future__ import annotations
import os
import typing as t
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, _ensure_list
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
class PythonRunner(CmdRunner):
def __init__(
self,
module,
module: AnsibleModule,
command,
arg_formats=None,
default_args_order=(),
check_rc=False,
force_lang="C",
path_prefix=None,
environ_update=None,
python="python",
venv=None,
):
check_rc: bool = False,
force_lang: str = "C",
path_prefix: list[str] | None = None,
environ_update: dict[str, str] | None = None,
python: str = "python",
venv: str | None = None,
) -> None:
self.python = python
self.venv = venv
self.has_venv = venv is not None
if os.path.isabs(python) or "/" in python:
self.python = python
elif self.has_venv:
elif venv is not None:
if path_prefix is None:
path_prefix = []
path_prefix.append(os.path.join(venv, "bin"))

View file

@ -10,6 +10,8 @@ import os
import random
import string
import time
import typing as t
from ansible.module_utils.urls import open_url
from ansible.module_utils.common.text.converters import to_native
from ansible.module_utils.common.text.converters import to_text
@ -17,6 +19,10 @@ from ansible.module_utils.common.text.converters import to_bytes
from urllib.error import URLError, HTTPError
from urllib.parse import urlparse
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
GET_HEADERS = {"accept": "application/json", "OData-Version": "4.0"}
POST_HEADERS = {"content-type": "application/json", "accept": "application/json", "OData-Version": "4.0"}
PATCH_HEADERS = {"content-type": "application/json", "accept": "application/json", "OData-Version": "4.0"}
@ -49,15 +55,15 @@ REDFISH_COMMON_ARGUMENT_SPEC = {
class RedfishUtils:
def __init__(
self,
creds,
root_uri,
creds: dict[str, str],
root_uri: str,
timeout,
module,
module: AnsibleModule,
resource_id=None,
data_modification=False,
strip_etag_quotes=False,
ciphers=None,
):
data_modification: bool = False,
strip_etag_quotes: bool = False,
ciphers: str | None = None,
) -> None:
self.root_uri = root_uri
self.creds = creds
self.timeout = timeout
@ -73,7 +79,7 @@ class RedfishUtils:
self.validate_certs = module.params.get("validate_certs", False)
self.ca_path = module.params.get("ca_path")
def _auth_params(self, headers):
def _auth_params(self, headers: dict[str, str]) -> tuple[str | None, str | None, bool]:
"""
Return tuple of required authentication params based on the presence
of a token in the self.creds dict. If using a token, set the
@ -151,7 +157,7 @@ class RedfishUtils:
resp["msg"] = f"Properties in {uri} are already set"
return resp
def _request(self, uri, **kwargs):
def _request(self, uri: str, **kwargs):
kwargs.setdefault("validate_certs", self.validate_certs)
kwargs.setdefault("follow_redirects", "all")
kwargs.setdefault("use_proxy", True)
@ -163,7 +169,9 @@ class RedfishUtils:
return resp, headers
# The following functions are to send GET/POST/PATCH/DELETE requests
def get_request(self, uri, override_headers=None, allow_no_resp=False, timeout=None):
def get_request(
self, uri: str, override_headers: dict[str, str] | None = None, allow_no_resp: bool = False, timeout=None
):
req_headers = dict(GET_HEADERS)
if override_headers:
req_headers.update(override_headers)
@ -206,7 +214,7 @@ class RedfishUtils:
return {"ret": False, "msg": f"Failed GET request to '{uri}': '{e}'"}
return {"ret": True, "data": data, "headers": headers, "resp": resp}
def post_request(self, uri, pyld, multipart=False):
def post_request(self, uri: str, pyld, multipart: bool = False):
req_headers = dict(POST_HEADERS)
username, password, basic_auth = self._auth_params(req_headers)
try:
@ -251,7 +259,7 @@ class RedfishUtils:
return {"ret": False, "msg": f"Failed POST request to '{uri}': '{e}'"}
return {"ret": True, "data": data, "headers": headers, "resp": resp}
def patch_request(self, uri, pyld, check_pyld=False):
def patch_request(self, uri: str, pyld, check_pyld: bool = False):
req_headers = dict(PATCH_HEADERS)
r = self.get_request(uri)
if r["ret"]:
@ -303,7 +311,7 @@ class RedfishUtils:
return {"ret": False, "changed": False, "msg": f"Failed PATCH request to '{uri}': '{e}'"}
return {"ret": True, "changed": True, "resp": resp, "msg": f"Modified {uri}"}
def put_request(self, uri, pyld):
def put_request(self, uri: str, pyld):
req_headers = dict(PUT_HEADERS)
r = self.get_request(uri)
if r["ret"]:
@ -341,7 +349,7 @@ class RedfishUtils:
return {"ret": False, "msg": f"Failed PUT request to '{uri}': '{e}'"}
return {"ret": True, "resp": resp}
def delete_request(self, uri, pyld=None):
def delete_request(self, uri: str, pyld=None):
req_headers = dict(DELETE_HEADERS)
username, password, basic_auth = self._auth_params(req_headers)
try:

View file

@ -5,11 +5,15 @@
from __future__ import annotations
from ansible.module_utils.basic import missing_required_lib
import traceback
import typing as t
REDIS_IMP_ERR = None
from ansible.module_utils.basic import missing_required_lib, AnsibleModule
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
REDIS_IMP_ERR: str | None = None
try:
from redis import Redis
from redis import __version__ as redis_version
@ -20,30 +24,30 @@ except ImportError:
REDIS_IMP_ERR = traceback.format_exc()
HAS_REDIS_PACKAGE = False
CERTIFI_IMPORT_ERROR: str | None = None
try:
import certifi
HAS_CERTIFI_PACKAGE = True
CERTIFI_IMPORT_ERROR = None
except ImportError:
CERTIFI_IMPORT_ERROR = traceback.format_exc()
HAS_CERTIFI_PACKAGE = False
def fail_imports(module, needs_certifi=True):
errors = []
traceback = []
def fail_imports(module: AnsibleModule, needs_certifi: bool = True) -> None:
errors: list[str] = []
traceback: list[str] = []
if not HAS_REDIS_PACKAGE:
errors.append(missing_required_lib("redis"))
traceback.append(REDIS_IMP_ERR)
traceback.append(REDIS_IMP_ERR) # type: ignore
if not HAS_CERTIFI_PACKAGE and needs_certifi:
errors.append(missing_required_lib("certifi"))
traceback.append(CERTIFI_IMPORT_ERROR)
traceback.append(CERTIFI_IMPORT_ERROR) # type: ignore
if errors:
module.fail_json(msg="\n".join(errors), traceback="\n".join(traceback))
def redis_auth_argument_spec(tls_default=True):
def redis_auth_argument_spec(tls_default: bool = True) -> dict[str, t.Any]:
return dict(
login_host=dict(
type="str",
@ -60,7 +64,7 @@ def redis_auth_argument_spec(tls_default=True):
)
def redis_auth_params(module):
def redis_auth_params(module: AnsibleModule) -> dict[str, t.Any]:
login_host = module.params["login_host"]
login_user = module.params["login_user"]
login_password = module.params["login_password"]
@ -92,13 +96,12 @@ def redis_auth_params(module):
class RedisAnsible:
"""Base class for Redis module"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.connection = self._connect()
def _connect(self):
def _connect(self) -> Redis:
try:
return Redis(**redis_auth_params(self.module))
except Exception as e:
self.module.fail_json(msg=f"{e}")
return None

View file

@ -6,11 +6,15 @@ from __future__ import annotations
import json
import traceback
import typing as t
from ansible.module_utils.urls import fetch_url, url_argument_spec
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def api_argument_spec():
def api_argument_spec() -> dict[str, t.Any]:
"""
Creates an argument spec that can be used with any module
that will be requesting content via the Rundeck API
@ -27,7 +31,13 @@ def api_argument_spec():
return api_argument_spec
def api_request(module, endpoint, data=None, method="GET", content_type="application/json"):
def api_request(
module: AnsibleModule,
endpoint: str,
data: t.Any | None = None,
method: str = "GET",
content_type: str = "application/json",
) -> tuple[t.Any, dict[str, t.Any]]:
"""Manages Rundeck API requests via HTTP(S)
:arg module: The AnsibleModule (used to get url, api_version, api_token, etc).

View file

@ -10,6 +10,7 @@ import sys
import datetime
import time
import traceback
import typing as t
from urllib.parse import urlencode
from ansible.module_utils.basic import env_fallback, missing_required_lib
@ -19,6 +20,10 @@ from ansible_collections.community.general.plugins.module_utils.datetime import
now,
)
if t.TYPE_CHECKING:
from collections.abc import Iterable
from ansible.module_utils.basic import AnsibleModule
SCALEWAY_SECRET_IMP_ERR: str | None = None
try:
from passlib.hash import argon2
@ -29,7 +34,7 @@ except Exception:
HAS_SCALEWAY_SECRET_PACKAGE = False
def scaleway_argument_spec():
def scaleway_argument_spec() -> dict[str, t.Any]:
return dict(
api_token=dict(
required=True,
@ -59,7 +64,7 @@ def payload_from_object(scw_object):
class ScalewayException(Exception):
def __init__(self, message):
def __init__(self, message: str) -> None:
self.message = message
@ -70,7 +75,7 @@ R_LINK_HEADER = r"""<[^>]+>;\srel="(first|previous|next|last)"
R_RELATION = r'</?(?P<target_IRI>[^>]+)>; rel="(?P<relation>first|previous|next|last)"'
def parse_pagination_link(header):
def parse_pagination_link(header: str) -> dict[str, str]:
if not re.match(R_LINK_HEADER, header, re.VERBOSE):
raise ScalewayException("Scaleway API answered with an invalid Link pagination header")
else:
@ -86,7 +91,7 @@ def parse_pagination_link(header):
return parsed_relations
def filter_sensitive_attributes(container, attributes):
def filter_sensitive_attributes(container: dict[str, t.Any], attributes: Iterable[str]) -> dict[str, t.Any]:
"""
WARNING: This function is effectively private, **do not use it**!
It will be removed or renamed once changing its name no longer triggers a pylint bug.
@ -99,7 +104,7 @@ def filter_sensitive_attributes(container, attributes):
class SecretVariables:
@staticmethod
def ensure_scaleway_secret_package(module):
def ensure_scaleway_secret_package(module: AnsibleModule) -> None:
if not HAS_SCALEWAY_SECRET_PACKAGE:
module.fail_json(
msg=missing_required_lib("passlib[argon2]", url="https://passlib.readthedocs.io/en/stable/"),
@ -169,7 +174,7 @@ class Response:
class Scaleway:
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
self.module = module
self.headers = {
"X-Auth-Token": self.module.params.get("api_token"),
@ -224,8 +229,9 @@ class Scaleway:
return Response(resp, info)
@staticmethod
def get_user_agent_string(module):
return f"ansible {module.ansible_version} Python {sys.version.split(' ', 1)[0]}"
def get_user_agent_string(module: AnsibleModule) -> str:
ansible_version = module.ansible_version # type: ignore # For some reason this isn't documented in AnsibleModule
return f"ansible {ansible_version} Python {sys.version.split(' ', 1)[0]}"
def get(self, path, data=None, headers=None, params=None):
return self.send(method="GET", path=path, data=data, headers=headers, params=params)
@ -245,7 +251,7 @@ class Scaleway:
def update(self, path, data=None, headers=None, params=None):
return self.send(method="UPDATE", path=path, data=data, headers=headers, params=params)
def warn(self, x):
def warn(self, x) -> None:
self.module.warn(str(x))
def fetch_state(self, resource):

View file

@ -4,8 +4,13 @@
from __future__ import annotations
import typing as t
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
_alias_state_map = dict(
present="alias",
@ -22,7 +27,7 @@ _state_map = dict(
)
def snap_runner(module, **kwargs):
def snap_runner(module: AnsibleModule, **kwargs) -> CmdRunner:
runner = CmdRunner(
module,
"snap",
@ -47,7 +52,7 @@ def snap_runner(module, **kwargs):
return runner
def get_version(runner):
def get_version(runner: CmdRunner) -> dict[str, list[str]]:
with runner("version") as ctx:
rc, out, err = ctx.run()
return dict(x.split() for x in out.splitlines() if len(x.split()) == 2)
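get_version() just splits the two-column lines of `snap version` into a dict. On a typical host the result looks roughly like this (illustrative output, values differ per system):
# $ snap version
# snap    2.63
# snapd   2.63
# series  16
get_version(runner)  # {"snap": "2.63", "snapd": "2.63", "series": "16", ...}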

View file

@ -11,7 +11,7 @@ from __future__ import annotations
import os
def determine_config_file(user, config_file):
def determine_config_file(user: str | None, config_file: str | None) -> str:
if user:
config_file = os.path.join(os.path.expanduser(f"~{user}"), ".ssh", "config")
elif config_file is None:
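A quick, hedged illustration of determine_config_file() (assuming the user's home directory is /home/alice; the explicit-path case follows from the truncated elif above):
determine_config_file("alice", None)                 # "/home/alice/.ssh/config"
determine_config_file(None, "/etc/ssh/ssh_config")   # keeps the explicit path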

View file

@ -4,11 +4,15 @@
from __future__ import annotations
import typing as t
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def systemd_runner(module, command, **kwargs):
def systemd_runner(module: AnsibleModule, command, **kwargs) -> CmdRunner:
arg_formats = dict(
version=cmd_runner_fmt.as_fixed("--version"),
list_units=cmd_runner_fmt.as_fixed(["list-units", "--no-pager"]),

View file

@ -165,7 +165,7 @@ def ldap_search(filter, base=None, attr=None):
uldap().lo.lo.abandon(msgid)
def module_by_name(module_name_):
def module_by_name(module_name_: str):
"""Returns an initialized UMC module, identified by the given name.
The module is a module specification according to the udm commandline.
@ -202,7 +202,7 @@ def get_umc_admin_objects():
return univention.admin.objects
def umc_module_for_add(module, container_dn, superordinate=None):
def umc_module_for_add(module: str, container_dn, superordinate=None):
"""Returns an UMC module object prepared for creating a new entry.
The module is a module specification according to the udm commandline.
@ -226,7 +226,7 @@ def umc_module_for_add(module, container_dn, superordinate=None):
return obj
def umc_module_for_edit(module, object_dn, superordinate=None):
def umc_module_for_edit(module: str, object_dn, superordinate=None):
"""Returns an UMC module object prepared for editing an existing entry.
The module is a module specification according to the udm commandline.

View file

@ -12,18 +12,19 @@
from __future__ import annotations
import json
import typing as t
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.urls import fetch_url
class UTMModuleConfigurationError(Exception):
def __init__(self, msg, **args):
def __init__(self, msg: str, **args):
super().__init__(self, msg)
self.msg = msg
self.module_fail_args = args
def do_fail(self, module):
def do_fail(self, module: AnsibleModule) -> t.NoReturn:
module.fail_json(msg=self.msg, other=self.module_fail_args)
@ -75,7 +76,7 @@ class UTMModule(AnsibleModule):
class UTM:
def __init__(self, module, endpoint, change_relevant_keys, info_only=False):
def __init__(self, module: UTMModule, endpoint, change_relevant_keys, info_only=False):
"""
Initialize UTM Class
:param module: The Ansible module

View file

@ -6,33 +6,41 @@
from __future__ import annotations
import copy
import typing as t
class _Variable:
NOTHING = object()
def __init__(self, diff=False, output=True, change=None, fact=False, verbosity=0):
def __init__(
self,
diff: bool = False,
output: bool = True,
change: bool | None = None,
fact: bool = False,
verbosity: int = 0,
):
self.init = False
self.initial_value = None
self.value = None
self.initial_value: t.Any = None
self.value: t.Any = None
self.diff = None
self._change = None
self.output = None
self.fact = None
self._verbosity = None
self.diff: bool = None # type: ignore # will be changed in set_meta() call
self._change: bool | None = None
self.output: bool = None # type: ignore # will be changed in set_meta() call
self.fact: bool = None # type: ignore # will be changed in set_meta() call
self._verbosity: int = None # type: ignore # will be changed in set_meta() call
self.set_meta(output=output, diff=diff, change=change, fact=fact, verbosity=verbosity)
def getchange(self):
def getchange(self) -> bool:
return self.diff if self._change is None else self._change
def setchange(self, value):
def setchange(self, value: bool | None) -> None:
self._change = value
def getverbosity(self):
def getverbosity(self) -> int:
return self._verbosity
def setverbosity(self, v):
def setverbosity(self, v: int) -> None:
if not (0 <= v <= 4):
raise ValueError("verbosity must be an int in the range 0 to 4")
self._verbosity = v
@ -40,7 +48,15 @@ class _Variable:
change = property(getchange, setchange)
verbosity = property(getverbosity, setverbosity)
def set_meta(self, output=None, diff=None, change=None, fact=None, initial_value=NOTHING, verbosity=None):
def set_meta(
self,
output: bool | None = None,
diff: bool | None = None,
change: bool | None = None,
fact: bool | None = None,
initial_value: t.Any = NOTHING,
verbosity: int | None = None,
) -> None:
"""Set the metadata for the variable
Args:
@ -64,7 +80,7 @@ class _Variable:
if verbosity is not None:
self.verbosity = verbosity
def as_dict(self, meta_only=False):
def as_dict(self, meta_only: bool = False) -> dict[str, t.Any]:
d = {
"diff": self.diff,
"change": self.change,
@ -77,27 +93,27 @@ class _Variable:
d["value"] = self.value
return d
def set_value(self, value):
def set_value(self, value: t.Any) -> t.Self:
if not self.init:
self.initial_value = copy.deepcopy(value)
self.init = True
self.value = value
return self
def is_visible(self, verbosity):
def is_visible(self, verbosity: int) -> bool:
return self.verbosity <= verbosity
@property
def has_changed(self):
def has_changed(self) -> bool:
return self.change and (self.initial_value != self.value)
@property
def diff_result(self):
def diff_result(self) -> dict[str, t.Any] | None:
if self.diff and self.has_changed:
return {"before": self.initial_value, "after": self.value}
return
return None
def __str__(self):
def __str__(self) -> str:
return (
f"<Variable: value={self.value!r}, initial={self.initial_value!r}, diff={self.diff}, "
f"output={self.output}, change={self.change}, verbosity={self.verbosity}>"
@ -119,34 +135,34 @@ class VarDict:
"as_dict",
)
def __init__(self):
self.__vars__ = dict()
def __init__(self) -> None:
self.__vars__: dict[str, _Variable] = dict()
def __getitem__(self, item):
def __getitem__(self, item: str):
return self.__vars__[item].value
def __setitem__(self, key, value):
def __setitem__(self, key: str, value) -> None:
self.set(key, value)
def __getattr__(self, item):
def __getattr__(self, item: str):
try:
return self.__vars__[item].value
except KeyError:
return getattr(super(), item)
def __setattr__(self, key, value):
def __setattr__(self, key: str, value) -> None:
if key == "__vars__":
super().__setattr__(key, value)
else:
self.set(key, value)
def _var(self, name):
def _var(self, name: str) -> _Variable:
return self.__vars__[name]
def var(self, name):
def var(self, name: str) -> dict[str, t.Any]:
return self._var(name).as_dict()
def set_meta(self, name, **kwargs):
def set_meta(self, name: str, **kwargs):
"""Set the metadata for the variable
Args:
@ -160,10 +176,10 @@ class VarDict:
"""
self._var(name).set_meta(**kwargs)
def get_meta(self, name):
def get_meta(self, name: str) -> dict[str, t.Any]:
return self._var(name).as_dict(meta_only=True)
def set(self, name, value, **kwargs):
def set(self, name: str, value, **kwargs) -> None:
"""Set the value and optionally metadata for a variable. The variable is not required to exist prior to calling `set`.
For details on the accepted metadata, see the documentation for method `set_meta`.
@ -185,10 +201,10 @@ class VarDict:
var.set_value(value)
self.__vars__[name] = var
def output(self, verbosity=0):
def output(self, verbosity: int = 0) -> dict[str, t.Any]:
return {n: v.value for n, v in self.__vars__.items() if v.output and v.is_visible(verbosity)}
def diff(self, verbosity=0):
def diff(self, verbosity: int = 0) -> dict[str, dict[str, t.Any]] | None:
diff_results = [
(n, v.diff_result) for n, v in self.__vars__.items() if v.diff_result and v.is_visible(verbosity)
]
@ -198,13 +214,13 @@ class VarDict:
return {"before": before, "after": after}
return None
def facts(self, verbosity=0):
def facts(self, verbosity: int = 0) -> dict[str, t.Any] | None:
facts_result = {n: v.value for n, v in self.__vars__.items() if v.fact and v.is_visible(verbosity)}
return facts_result if facts_result else None
@property
def has_changed(self):
def has_changed(self) -> bool:
return any(var.has_changed for var in self.__vars__.values())
def as_dict(self):
def as_dict(self) -> dict[str, t.Any]:
return {name: var.value for name, var in self.__vars__.items()}

View file

@ -5,6 +5,7 @@
from __future__ import annotations
import typing as t
HAS_VEXATAPI = True
try:
@ -14,10 +15,14 @@ except ImportError:
from ansible.module_utils.basic import env_fallback
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
VXOS_VERSION = None
def get_version(iocs_json):
def get_version(iocs_json) -> tuple[int, ...]:
if not iocs_json:
raise Exception("Invalid IOC json")
active = next((x for x in iocs_json if x["mgmtRole"]), None)
@ -31,7 +36,7 @@ def get_version(iocs_json):
return tuple(ver)
def get_array(module):
def get_array(module: AnsibleModule):
"""Return storage array object or fail"""
global VXOS_VERSION
array = module.params["array"]
@ -60,7 +65,7 @@ def get_array(module):
module.fail_json(msg=f"Vexata API access failed: {e}")
def argument_spec():
def argument_spec() -> dict[str, t.Any]:
"""Return standard base dictionary used for the argument_spec argument in AnsibleModule"""
return dict(
array=dict(type="str", required=True),
@ -70,20 +75,20 @@ def argument_spec():
)
def required_together():
def required_together() -> list[list[str]]:
"""Return the default list used for the required_together argument to AnsibleModule"""
return [["user", "password"]]
def size_to_MiB(size):
def size_to_MiB(size: str) -> int:
"""Convert a '<integer>[MGT]' string to MiB, return -1 on error."""
quant = size[:-1]
exponent = size[-1]
if not quant.isdigit() or exponent not in "MGT":
return -1
quant = int(quant)
quant_int = int(quant)
if exponent == "G":
quant <<= 10
quant_int <<= 10
elif exponent == "T":
quant <<= 20
return quant
quant_int <<= 20
return quant_int
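Worked examples for the conversion above, derived directly from the shift arithmetic:
size_to_MiB("512M")  # 512       (MiB passthrough)
size_to_MiB("4G")    # 4096      (4 << 10)
size_to_MiB("2T")    # 2097152   (2 << 20)
size_to_MiB("12K")   # -1        (unsupported unit)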

View file

@ -5,15 +5,19 @@
from __future__ import annotations
import datetime
import os
import re
import time
import tarfile
import os
import typing as t
from urllib.parse import urlparse, urlunparse
from ansible.module_utils.urls import fetch_file
from ansible_collections.community.general.plugins.module_utils.redfish_utils import RedfishUtils
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
class WdcRedfishUtils(RedfishUtils):
"""Extension to RedfishUtils to support WDC enclosures."""
@ -41,7 +45,7 @@ class WdcRedfishUtils(RedfishUtils):
CHASSIS_LOCATE = "#Chassis.Locate"
CHASSIS_POWER_MODE = "#Chassis.PowerMode"
def __init__(self, creds, root_uris, timeout, module, resource_id, data_modification):
def __init__(self, creds, root_uris, timeout, module: AnsibleModule, resource_id, data_modification) -> None:
super().__init__(
creds=creds,
root_uri=root_uris[0],

View file

@ -5,10 +5,15 @@
from __future__ import annotations
import typing as t
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def xdg_mime_runner(module, **kwargs):
def xdg_mime_runner(module: AnsibleModule, **kwargs) -> CmdRunner:
return CmdRunner(
module,
command=["xdg-mime"],
@ -23,8 +28,8 @@ def xdg_mime_runner(module, **kwargs):
)
def xdg_mime_get(runner, mime_type):
def process(rc, out, err):
def xdg_mime_get(runner: CmdRunner, mime_type) -> str | None:
def process(rc, out, err) -> str | None:
if not out.strip():
return None
out = out.splitlines()[0]

View file

@ -9,6 +9,7 @@ import atexit
import time
import re
import traceback
import typing as t
XENAPI_IMP_ERR = None
try:
@ -22,8 +23,11 @@ except ImportError:
from ansible.module_utils.basic import env_fallback, missing_required_lib
from ansible.module_utils.ansible_release import __version__ as ANSIBLE_VERSION
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
def xenserver_common_argument_spec():
def xenserver_common_argument_spec() -> dict[str, t.Any]:
return dict(
hostname=dict(
type="str",
@ -41,7 +45,7 @@ def xenserver_common_argument_spec():
)
def xapi_to_module_vm_power_state(power_state):
def xapi_to_module_vm_power_state(power_state: str) -> str | None:
"""Maps XAPI VM power states to module VM power states."""
module_power_state_map = {
"running": "poweredon",
@ -53,7 +57,7 @@ def xapi_to_module_vm_power_state(power_state):
return module_power_state_map.get(power_state)
def module_to_xapi_vm_power_state(power_state):
def module_to_xapi_vm_power_state(power_state: str) -> str | None:
"""Maps module VM power states to XAPI VM power states."""
vm_power_state_map = {
"poweredon": "running",
@ -67,7 +71,7 @@ def module_to_xapi_vm_power_state(power_state):
return vm_power_state_map.get(power_state)
def is_valid_ip_addr(ip_addr):
def is_valid_ip_addr(ip_addr: str) -> bool:
"""Validates given string as IPv4 address for given string.
Args:
@ -93,7 +97,7 @@ def is_valid_ip_addr(ip_addr):
return True
def is_valid_ip_netmask(ip_netmask):
def is_valid_ip_netmask(ip_netmask: str) -> bool:
"""Validates given string as IPv4 netmask.
Args:
@ -125,7 +129,7 @@ def is_valid_ip_netmask(ip_netmask):
return True
def is_valid_ip_prefix(ip_prefix):
def is_valid_ip_prefix(ip_prefix: str) -> bool:
"""Validates given string as IPv4 prefix.
Args:
@ -142,7 +146,7 @@ def is_valid_ip_prefix(ip_prefix):
return not (ip_prefix_int < 0 or ip_prefix_int > 32)
def ip_prefix_to_netmask(ip_prefix, skip_check=False):
def ip_prefix_to_netmask(ip_prefix: str, skip_check: bool = False) -> str:
"""Converts IPv4 prefix to netmask.
Args:
@ -165,7 +169,7 @@ def ip_prefix_to_netmask(ip_prefix, skip_check=False):
return ""
def ip_netmask_to_prefix(ip_netmask, skip_check=False):
def ip_netmask_to_prefix(ip_netmask: str, skip_check: bool = False) -> str:
"""Converts IPv4 netmask to prefix.
Args:
@ -188,7 +192,7 @@ def ip_netmask_to_prefix(ip_netmask, skip_check=False):
return ""
def is_valid_ip6_addr(ip6_addr):
def is_valid_ip6_addr(ip6_addr: str) -> bool:
"""Validates given string as IPv6 address.
Args:
@ -222,7 +226,7 @@ def is_valid_ip6_addr(ip6_addr):
return all(ip6_addr_hextet_regex.match(ip6_addr_hextet) for ip6_addr_hextet in ip6_addr_split)
def is_valid_ip6_prefix(ip6_prefix):
def is_valid_ip6_prefix(ip6_prefix: str) -> bool:
"""Validates given string as IPv6 prefix.
Args:
@ -239,7 +243,7 @@ def is_valid_ip6_prefix(ip6_prefix):
return not (ip6_prefix_int < 0 or ip6_prefix_int > 128)
def get_object_ref(module, name, uuid=None, obj_type="VM", fail=True, msg_prefix=""):
def get_object_ref(module: AnsibleModule, name, uuid=None, obj_type="VM", fail=True, msg_prefix=""):
"""Finds and returns a reference to arbitrary XAPI object.
An object is searched by using either name (name_label) or UUID
@ -305,7 +309,7 @@ def get_object_ref(module, name, uuid=None, obj_type="VM", fail=True, msg_prefix
return obj_ref
def gather_vm_params(module, vm_ref):
def gather_vm_params(module: AnsibleModule, vm_ref):
"""Gathers all VM parameters available in XAPI database.
Args:
@ -395,7 +399,7 @@ def gather_vm_params(module, vm_ref):
return vm_params
def gather_vm_facts(module, vm_params):
def gather_vm_facts(module: AnsibleModule, vm_params):
"""Gathers VM facts.
Args:
@ -502,7 +506,7 @@ def gather_vm_facts(module, vm_params):
return vm_facts
def set_vm_power_state(module, vm_ref, power_state, timeout=300):
def set_vm_power_state(module: AnsibleModule, vm_ref, power_state, timeout=300):
"""Controls VM power state.
Args:
@ -608,7 +612,7 @@ def set_vm_power_state(module, vm_ref, power_state, timeout=300):
return (state_changed, vm_power_state_resulting)
def wait_for_task(module, task_ref, timeout=300):
def wait_for_task(module: AnsibleModule, task_ref, timeout=300):
"""Waits for async XAPI task to finish.
Args:
@ -667,7 +671,7 @@ def wait_for_task(module, task_ref, timeout=300):
return result
def wait_for_vm_ip_address(module, vm_ref, timeout=300):
def wait_for_vm_ip_address(module: AnsibleModule, vm_ref, timeout=300):
"""Waits for VM to acquire an IP address.
Args:
@ -730,7 +734,7 @@ def wait_for_vm_ip_address(module, vm_ref, timeout=300):
return vm_guest_metrics
def get_xenserver_version(module):
def get_xenserver_version(module: AnsibleModule):
"""Returns XenServer version.
Args:
@ -758,10 +762,10 @@ def get_xenserver_version(module):
class XAPI:
"""Class for XAPI session management."""
_xapi_session = None
_xapi_session: t.Any | None = None
@classmethod
def connect(cls, module, disconnect_atexit=True):
def connect(cls, module: AnsibleModule, disconnect_atexit=True):
"""Establishes XAPI connection and returns session reference.
If no existing session is available, establishes a new one
@ -837,7 +841,7 @@ class XenServerObject:
minor version.
"""
def __init__(self, module):
def __init__(self, module: AnsibleModule) -> None:
"""Inits XenServerObject using common module parameters.
Args:

View file

@ -4,9 +4,14 @@
from __future__ import annotations
import typing as t
from ansible.module_utils.parsing.convert_bool import boolean
from ansible_collections.community.general.plugins.module_utils.cmd_runner import CmdRunner, cmd_runner_fmt
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
@cmd_runner_fmt.unpack_args
def _values_fmt(values, value_types):
@ -18,7 +23,7 @@ def _values_fmt(values, value_types):
return result
def xfconf_runner(module, **kwargs):
def xfconf_runner(module: AnsibleModule, **kwargs) -> CmdRunner:
runner = CmdRunner(
module,
command="xfconf-query",
@ -37,7 +42,7 @@ def xfconf_runner(module, **kwargs):
return runner
def get_xfconf_version(runner):
def get_xfconf_version(runner: CmdRunner) -> str:
with runner("version") as ctx:
rc, out, err = ctx.run()
return out.splitlines()[0].split()[1]
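get_xfconf_version() returns the second whitespace-separated field of the first output line; for example (output format assumed from a 4.x release of xfconf-query):
# $ xfconf-query --version
# xfconf-query 4.18.1
get_xfconf_version(runner)  # "4.18.1"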

View file

@ -183,6 +183,7 @@ logs:
"""
import os
import typing as t
from urllib.parse import urlencode
from ansible.module_utils.basic import AnsibleModule
@ -246,19 +247,16 @@ class LXDStoragePoolInfo:
self.trust_password = self.module.params["trust_password"]
def _fail_from_lxd_exception(self, exception: LXDClientException) -> None:
def _fail_from_lxd_exception(self, exception: LXDClientException) -> t.NoReturn:
"""Build failure parameters from LXDClientException and fail.
:param exception: The LXDClientException instance
:type exception: LXDClientException
"""
fail_params = {
"msg": exception.msg,
"changed": False,
}
fail_params = {}
if self.client.debug and "logs" in exception.kwargs:
fail_params["logs"] = exception.kwargs["logs"]
self.module.fail_json(**fail_params)
self.module.fail_json(msg=exception.msg, changed=False, **fail_params)
def _build_url(self, endpoint: str) -> str:
"""Build URL with project parameter if specified."""

View file

@ -188,6 +188,7 @@ logs:
"""
import os
import typing as t
from urllib.parse import quote, urlencode
from ansible.module_utils.basic import AnsibleModule
@ -252,19 +253,16 @@ class LXDStorageVolumeInfo:
self.trust_password = self.module.params["trust_password"]
def _fail_from_lxd_exception(self, exception: LXDClientException) -> None:
def _fail_from_lxd_exception(self, exception: LXDClientException) -> t.NoReturn:
"""Build failure parameters from LXDClientException and fail.
:param exception: The LXDClientException instance
:type exception: LXDClientException
"""
fail_params = {
"msg": exception.msg,
"changed": False,
}
fail_params = {}
if self.client.debug and "logs" in exception.kwargs:
fail_params["logs"] = exception.kwargs["logs"]
self.module.fail_json(**fail_params)
self.module.fail_json(msg=exception.msg, changed=False, **fail_params)
def _build_url(self, endpoint: str) -> str:
"""Build URL with project parameter if specified."""