1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2026-02-04 07:51:50 +00:00
community.general/tests/unit/plugins/lookup/test_tss.py
Felix Fontein 236b9c0e04
Sort imports with ruff check --fix (#11400)
Sort imports with ruff check --fix.
2026-01-09 07:40:58 +01:00

112 lines
3.4 KiB
Python

# Copyright (c) 2020, Adam Migus <adam@migus.org>
# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt)
# SPDX-License-Identifier: GPL-3.0-or-later
# Make coding more python3-ish
from __future__ import annotations
from unittest import TestCase
from unittest.mock import (
DEFAULT,
MagicMock,
patch,
)
from ansible.plugins.loader import lookup_loader
from ansible_collections.community.general.plugins.lookup import tss
# Dotted import path of the lookup plugin module under test; patch targets
# in this file are built relative to it.
TSS_IMPORT_PATH = "ansible_collections.community.general.plugins.lookup.tss"


def make_absolute(name: str) -> str:
    """Return *name* qualified with the tss lookup module's import path.

    The result is ``TSS_IMPORT_PATH`` and *name* joined by a dot, which is
    the form ``patch()`` expects as its target string.
    """
    qualified_name = f"{TSS_IMPORT_PATH}.{name}"
    return qualified_name
class SecretServerError(Exception):
    """Local exception exposing a ``message`` attribute.

    Tests in this file patch it into the tss lookup module under the name
    ``SecretServerError`` and raise it from a fake server.

    Args:
        message: Text stored on ``message`` and forwarded to ``Exception``.
            Defaults to the empty string so the class can still be raised
            without arguments.
    """

    def __init__(self, message=""):
        # Initialise the Exception base as well, so ``args`` and ``str()``
        # agree with the ``message`` attribute.
        super().__init__(message)
        self.message = message
class MockSecretServer(MagicMock):
    """MagicMock-based fake server whose secret lookups always succeed."""

    # Canned payload handed back for every lookup, whatever path is requested.
    RESPONSE = '{"foo": "bar"}'

    def get_secret_json(self, path):
        """Return the canned ``RESPONSE``; *path* is ignored."""
        canned = self.RESPONSE
        return canned
class MockFaultySecretServer(MagicMock):
    """MagicMock-based fake server whose secret lookups always fail."""

    def get_secret_json(self, path):
        """Raise ``SecretServerError`` regardless of *path*."""
        raise SecretServerError()
@patch(make_absolute("SecretServer"), MockSecretServer())
class TestTSSClient(TestCase):
    """Tests for ``tss.TSSClient.from_params`` with ``SecretServer`` replaced by a mock."""

    def setUp(self):
        """Start every test with all server parameters set to empty strings."""
        option_names = (
            "base_url",
            "username",
            "domain",
            "password",
            "api_path_uri",
            "token_path_uri",
        )
        self.server_params = dict.fromkeys(option_names, "")

    def test_from_params(self):
        """Check the client class chosen with and without ``HAS_TSS_AUTHORIZER``."""
        domain_override = {"domain": "foo"}

        # Flag off: a v0 client is expected, and a non-empty domain must raise.
        with patch(make_absolute("HAS_TSS_AUTHORIZER"), False):
            self.assert_client_version("v0")

            with patch.dict(self.server_params, domain_override), self.assertRaises(tss.AnsibleError):
                self._get_client()

        # Flag on with both authorizer classes mocked: a v1 client is expected
        # with or without a domain.
        with patch.multiple(
            TSS_IMPORT_PATH,
            HAS_TSS_AUTHORIZER=True,
            PasswordGrantAuthorizer=DEFAULT,
            DomainPasswordGrantAuthorizer=DEFAULT,
        ):
            self.assert_client_version("v1")

            with patch.dict(self.server_params, domain_override):
                self.assert_client_version("v1")

    def assert_client_version(self, version):
        """Assert that a newly built client is an instance of the class for *version*.

        *version* is ``"v0"`` or ``"v1"``, mapping to ``tss.TSSClientV0`` and
        ``tss.TSSClientV1`` respectively.
        """
        expected_types = {"v0": tss.TSSClientV0, "v1": tss.TSSClientV1}
        self.assertIsInstance(self._get_client(), expected_types[version])

    def _get_client(self):
        """Build a client from the current ``self.server_params``."""
        params = self.server_params
        return tss.TSSClient.from_params(**params)
class TestLookupModule(TestCase):
    """Tests for running the ``community.general.tss`` lookup plugin."""

    # Terms the tests expect the lookup to accept, and terms they expect it
    # to reject with AnsibleOptionsError.
    VALID_TERMS = [1]
    INVALID_TERMS = ["foo"]

    def setUp(self):
        """Load the lookup plugin through Ansible's lookup loader."""
        self.lookup = lookup_loader.get("community.general.tss")

    @patch.multiple(TSS_IMPORT_PATH, HAS_TSS_SDK=False, SecretServer=MockSecretServer)
    def test_missing_sdk(self):
        """With ``HAS_TSS_SDK`` false, running the lookup must raise ``AnsibleError``."""
        with self.assertRaises(tss.AnsibleError):
            self._run_lookup(self.VALID_TERMS)

    @patch.multiple(TSS_IMPORT_PATH, HAS_TSS_SDK=True, SecretServerError=SecretServerError)
    def test_get_secret_json(self):
        """Cover the success path, invalid terms, and a failing server."""
        server_target = make_absolute("SecretServer")

        with patch(server_target, MockSecretServer):
            # Working server: the canned response comes back as a one-item list.
            self.assertListEqual([MockSecretServer.RESPONSE], self._run_lookup(self.VALID_TERMS))

            # Working server, but the terms themselves are rejected.
            with self.assertRaises(tss.AnsibleOptionsError):
                self._run_lookup(self.INVALID_TERMS)

        # Faulty server: its error must surface as AnsibleError.
        with patch(server_target, MockFaultySecretServer), self.assertRaises(tss.AnsibleError):
            self._run_lookup(self.VALID_TERMS)

    def _run_lookup(self, terms, variables=None, **kwargs):
        """Run the lookup on *terms* and return its result.

        A falsy *variables* is replaced by an empty list. When no keyword
        options are given, dummy ``base_url``, ``username`` and ``password``
        values are supplied instead.
        """
        if not variables:
            variables = []
        if not kwargs:
            kwargs = {"base_url": "dummy", "username": "dummy", "password": "dummy"}

        return self.lookup.run(terms, variables, **kwargs)