From 216e9e28c3094116beccad73c2db8512a0abab24 Mon Sep 17 00:00:00 2001 From: "patchback[bot]" <45432694+patchback[bot]@users.noreply.github.com> Date: Sat, 6 Dec 2025 13:52:41 +0100 Subject: [PATCH] [PR #11258/0ef3eac0 backport][stable-12] iptables_state: get rid of temporary files (#11262) iptables_state: get rid of temporary files (#11258) Get rid of temporary files. (cherry picked from commit 0ef3eac0f484c342ccb06e1cf459db6db41b0144) Co-authored-by: Felix Fontein --- changelogs/fragments/11258-iptables_state.yml | 2 + plugins/modules/iptables_state.py | 91 +++++++++---------- 2 files changed, 47 insertions(+), 46 deletions(-) create mode 100644 changelogs/fragments/11258-iptables_state.yml diff --git a/changelogs/fragments/11258-iptables_state.yml b/changelogs/fragments/11258-iptables_state.yml new file mode 100644 index 0000000000..26195c5045 --- /dev/null +++ b/changelogs/fragments/11258-iptables_state.yml @@ -0,0 +1,2 @@ +bugfixes: + - "iptables_state - refactor code to avoid writing unnecessary temporary files (https://github.com/ansible-collections/community.general/pull/11258)." diff --git a/plugins/modules/iptables_state.py b/plugins/modules/iptables_state.py index a2046c386b..855bb5c38c 100644 --- a/plugins/modules/iptables_state.py +++ b/plugins/modules/iptables_state.py @@ -225,9 +225,6 @@ tables: import re import os import time -import tempfile -import filecmp -import shutil from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.common.text.converters import to_bytes, to_native @@ -260,18 +257,28 @@ def read_state(b_path): return [t for t in text.splitlines() if t != ""] -def write_state(b_path, lines, changed): +def get_file_contents(b_path: bytes) -> bytes | None: + try: + with open(b_path, "rb") as f: + return f.read() + except FileNotFoundError: + return None + + +def write_state(module: AnsibleModule, b_path: bytes, lines: list[str], changed: bool) -> bool: """ Write given contents to the given path, and return changed status. """ - # Populate a temporary file - tmpfd, tmpfile = tempfile.mkstemp() - with os.fdopen(tmpfd, "w") as f: - joined_lines = "\n".join(lines) - f.write(f"{joined_lines}\n") + joined_lines = "\n".join(lines) + content = f"{joined_lines}\n".encode("utf-8") - # Prepare to copy temporary file to the final destination - if not os.path.exists(b_path): + existing_contents = get_file_contents(b_path) + if existing_contents == content: + return changed + + changed = True + + if existing_contents is None: b_destdir = os.path.dirname(b_path) destdir = to_native(b_destdir, errors="surrogate_or_strict") if b_destdir and not os.path.exists(b_destdir) and not module.check_mode: @@ -279,15 +286,11 @@ def write_state(b_path, lines, changed): os.makedirs(b_destdir) except Exception as err: module.fail_json(msg=f"Error creating {destdir}: {err}", initial_state=lines) - changed = True - elif not filecmp.cmp(tmpfile, b_path): - changed = True - - # Do it - if changed and not module.check_mode: + if not module.check_mode: try: - shutil.copyfile(tmpfile, b_path) + with open(b_path, "wb") as f: + f.write(content) except Exception as err: path = to_native(b_path, errors="surrogate_or_strict") module.fail_json(msg=f"Error saving state into {path}: {err}", initial_state=lines) @@ -295,7 +298,7 @@ def write_state(b_path, lines, changed): return changed -def initialize_from_null_state(initializer, initcommand, fallbackcmd, table): +def initialize_from_null_state(module: AnsibleModule, initializer, initcommand, fallbackcmd, table): """ This ensures iptables-state output is suitable for iptables-restore to roll back to it, i.e. iptables-save output is not empty. This also works for the @@ -317,7 +320,7 @@ def initialize_from_null_state(initializer, initcommand, fallbackcmd, table): return rc, out, err -def filter_and_format_state(string): +def filter_and_format_state(module: AnsibleModule, string: str) -> list[str]: """ Remove timestamps to ensure idempotence between runs. Also remove counters by default. And return the result as a list. @@ -329,15 +332,15 @@ def filter_and_format_state(string): return lines -def parse_per_table_state(all_states_dump): +def parse_per_table_state(module: AnsibleModule, all_states_dump) -> dict[str, list[str]]: """ Convert raw iptables-save output into usable datastructure, for reliable comparisons between initial and final states. """ - lines = filter_and_format_state(all_states_dump) - tables = dict() + lines = filter_and_format_state(module, all_states_dump) + tables: dict[str, list[str]] = {} current_table = "" - current_list = list() + current_list: list[str] = [] for line in lines: if re.match(r"^[*](filter|mangle|nat|raw|security)$", line): current_table = line[1:] @@ -345,7 +348,7 @@ def parse_per_table_state(all_states_dump): if line == "COMMIT": tables[current_table] = current_list current_table = "" - current_list = list() + current_list = [] continue if line.startswith("# "): continue @@ -353,9 +356,7 @@ def parse_per_table_state(all_states_dump): return tables -def main(): - global module - +def main() -> None: module = AnsibleModule( argument_spec=dict( path=dict(type="path", required=True), @@ -459,28 +460,30 @@ def main(): for t in TABLES: if f"*{t}" in state_to_restore: if len(stdout) == 0 or f"*{t}" not in stdout.splitlines(): - (rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, t) + (rc, stdout, stderr) = initialize_from_null_state( + module, INITIALIZER, INITCOMMAND, FALLBACKCMD, t + ) elif len(stdout) == 0: - (rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, "filter") + (rc, stdout, stderr) = initialize_from_null_state(module, INITIALIZER, INITCOMMAND, FALLBACKCMD, "filter") elif state == "restored" and f"*{table}" not in state_to_restore: module.fail_json(msg=f"Table {table} to restore not defined in {path}") elif len(stdout) == 0 or f"*{table}" not in stdout.splitlines(): - (rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, table) + (rc, stdout, stderr) = initialize_from_null_state(module, INITIALIZER, INITCOMMAND, FALLBACKCMD, table) - initial_state = filter_and_format_state(stdout) + initial_state = filter_and_format_state(module, stdout) if initial_state is None: module.fail_json(msg="Unable to initialize firewall from NULL state.") # Depending on the value of 'table', initref_state may differ from # initial_state. (rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True) - tables_before = parse_per_table_state(stdout) - initref_state = filter_and_format_state(stdout) + tables_before = parse_per_table_state(module, stdout) + initref_state = filter_and_format_state(module, stdout) if state == "saved": - changed = write_state(b_path, initref_state, changed) + changed = write_state(module, b_path, initref_state, changed) module.exit_json( changed=changed, cmd=cmd, tables=tables_before, initial_state=initial_state, saved=initref_state ) @@ -497,7 +500,7 @@ def main(): if _back is not None: b_back = to_bytes(_back, errors="surrogate_or_strict") - dummy = write_state(b_back, initref_state, changed) + dummy = write_state(module, b_back, initref_state, changed) BACKCOMMAND = list(MAINCOMMAND) BACKCOMMAND.append(_back) @@ -536,12 +539,8 @@ def main(): ) if module.check_mode: - tmpfd, tmpfile = tempfile.mkstemp() - with os.fdopen(tmpfd, "w") as f: - joined_initial_state = "\n".join(initial_state) - f.write(f"{joined_initial_state}\n") - - if filecmp.cmp(tmpfile, b_path): + joined_initial_state = "\n".join(initial_state) + if get_file_contents(b_path) == f"{joined_initial_state}\n".encode("utf-8"): restored_state = initial_state else: restored_state = state_to_restore @@ -572,8 +571,8 @@ def main(): ) (rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True) - restored_state = filter_and_format_state(stdout) - tables_after = parse_per_table_state("\n".join(restored_state)) + restored_state = filter_and_format_state(module, stdout) + tables_after = parse_per_table_state(module, "\n".join(restored_state)) if restored_state not in (initref_state, initial_state): for table_name, table_content in tables_after.items(): if table_name not in tables_before: @@ -608,7 +607,7 @@ def main(): # timeout # * task attribute 'poll' equals 0 # - for dummy in range(_timeout): + for dummy2 in range(_timeout): if os.path.exists(b_back): time.sleep(1) continue @@ -628,7 +627,7 @@ def main(): os.remove(b_back) (rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True) - tables_rollback = parse_per_table_state(stdout) + tables_rollback = parse_per_table_state(module, stdout) msg = f"Failed to confirm state restored from {path} after {_timeout}s. Firewall has been rolled back to its initial state."