1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2026-02-03 23:41:51 +00:00

iptables_state: get rid of temporary files (#11258)

Get rid of temporary files.
This commit is contained in:
Felix Fontein 2025-12-06 13:40:59 +01:00 committed by GitHub
parent 3d25aac978
commit 0ef3eac0f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 47 additions and 46 deletions

View file

@ -225,9 +225,6 @@ tables:
import re
import os
import time
import tempfile
import filecmp
import shutil
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.text.converters import to_bytes, to_native
@ -260,18 +257,28 @@ def read_state(b_path):
return [t for t in text.splitlines() if t != ""]
def write_state(b_path, lines, changed):
def get_file_contents(b_path: bytes) -> bytes | None:
try:
with open(b_path, "rb") as f:
return f.read()
except FileNotFoundError:
return None
def write_state(module: AnsibleModule, b_path: bytes, lines: list[str], changed: bool) -> bool:
"""
Write given contents to the given path, and return changed status.
"""
# Populate a temporary file
tmpfd, tmpfile = tempfile.mkstemp()
with os.fdopen(tmpfd, "w") as f:
joined_lines = "\n".join(lines)
f.write(f"{joined_lines}\n")
joined_lines = "\n".join(lines)
content = f"{joined_lines}\n".encode("utf-8")
# Prepare to copy temporary file to the final destination
if not os.path.exists(b_path):
existing_contents = get_file_contents(b_path)
if existing_contents == content:
return changed
changed = True
if existing_contents is None:
b_destdir = os.path.dirname(b_path)
destdir = to_native(b_destdir, errors="surrogate_or_strict")
if b_destdir and not os.path.exists(b_destdir) and not module.check_mode:
@ -279,15 +286,11 @@ def write_state(b_path, lines, changed):
os.makedirs(b_destdir)
except Exception as err:
module.fail_json(msg=f"Error creating {destdir}: {err}", initial_state=lines)
changed = True
elif not filecmp.cmp(tmpfile, b_path):
changed = True
# Do it
if changed and not module.check_mode:
if not module.check_mode:
try:
shutil.copyfile(tmpfile, b_path)
with open(b_path, "wb") as f:
f.write(content)
except Exception as err:
path = to_native(b_path, errors="surrogate_or_strict")
module.fail_json(msg=f"Error saving state into {path}: {err}", initial_state=lines)
@ -295,7 +298,7 @@ def write_state(b_path, lines, changed):
return changed
def initialize_from_null_state(initializer, initcommand, fallbackcmd, table):
def initialize_from_null_state(module: AnsibleModule, initializer, initcommand, fallbackcmd, table):
"""
This ensures iptables-state output is suitable for iptables-restore to roll
back to it, i.e. iptables-save output is not empty. This also works for the
@ -317,7 +320,7 @@ def initialize_from_null_state(initializer, initcommand, fallbackcmd, table):
return rc, out, err
def filter_and_format_state(string):
def filter_and_format_state(module: AnsibleModule, string: str) -> list[str]:
"""
Remove timestamps to ensure idempotence between runs. Also remove counters
by default. And return the result as a list.
@ -329,15 +332,15 @@ def filter_and_format_state(string):
return lines
def parse_per_table_state(all_states_dump):
def parse_per_table_state(module: AnsibleModule, all_states_dump) -> dict[str, list[str]]:
"""
Convert raw iptables-save output into usable datastructure, for reliable
comparisons between initial and final states.
"""
lines = filter_and_format_state(all_states_dump)
tables = dict()
lines = filter_and_format_state(module, all_states_dump)
tables: dict[str, list[str]] = {}
current_table = ""
current_list = list()
current_list: list[str] = []
for line in lines:
if re.match(r"^[*](filter|mangle|nat|raw|security)$", line):
current_table = line[1:]
@ -345,7 +348,7 @@ def parse_per_table_state(all_states_dump):
if line == "COMMIT":
tables[current_table] = current_list
current_table = ""
current_list = list()
current_list = []
continue
if line.startswith("# "):
continue
@ -353,9 +356,7 @@ def parse_per_table_state(all_states_dump):
return tables
def main():
global module
def main() -> None:
module = AnsibleModule(
argument_spec=dict(
path=dict(type="path", required=True),
@ -459,28 +460,30 @@ def main():
for t in TABLES:
if f"*{t}" in state_to_restore:
if len(stdout) == 0 or f"*{t}" not in stdout.splitlines():
(rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, t)
(rc, stdout, stderr) = initialize_from_null_state(
module, INITIALIZER, INITCOMMAND, FALLBACKCMD, t
)
elif len(stdout) == 0:
(rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, "filter")
(rc, stdout, stderr) = initialize_from_null_state(module, INITIALIZER, INITCOMMAND, FALLBACKCMD, "filter")
elif state == "restored" and f"*{table}" not in state_to_restore:
module.fail_json(msg=f"Table {table} to restore not defined in {path}")
elif len(stdout) == 0 or f"*{table}" not in stdout.splitlines():
(rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, table)
(rc, stdout, stderr) = initialize_from_null_state(module, INITIALIZER, INITCOMMAND, FALLBACKCMD, table)
initial_state = filter_and_format_state(stdout)
initial_state = filter_and_format_state(module, stdout)
if initial_state is None:
module.fail_json(msg="Unable to initialize firewall from NULL state.")
# Depending on the value of 'table', initref_state may differ from
# initial_state.
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
tables_before = parse_per_table_state(stdout)
initref_state = filter_and_format_state(stdout)
tables_before = parse_per_table_state(module, stdout)
initref_state = filter_and_format_state(module, stdout)
if state == "saved":
changed = write_state(b_path, initref_state, changed)
changed = write_state(module, b_path, initref_state, changed)
module.exit_json(
changed=changed, cmd=cmd, tables=tables_before, initial_state=initial_state, saved=initref_state
)
@ -497,7 +500,7 @@ def main():
if _back is not None:
b_back = to_bytes(_back, errors="surrogate_or_strict")
dummy = write_state(b_back, initref_state, changed)
dummy = write_state(module, b_back, initref_state, changed)
BACKCOMMAND = list(MAINCOMMAND)
BACKCOMMAND.append(_back)
@ -536,12 +539,8 @@ def main():
)
if module.check_mode:
tmpfd, tmpfile = tempfile.mkstemp()
with os.fdopen(tmpfd, "w") as f:
joined_initial_state = "\n".join(initial_state)
f.write(f"{joined_initial_state}\n")
if filecmp.cmp(tmpfile, b_path):
joined_initial_state = "\n".join(initial_state)
if get_file_contents(b_path) == f"{joined_initial_state}\n".encode("utf-8"):
restored_state = initial_state
else:
restored_state = state_to_restore
@ -572,8 +571,8 @@ def main():
)
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
restored_state = filter_and_format_state(stdout)
tables_after = parse_per_table_state("\n".join(restored_state))
restored_state = filter_and_format_state(module, stdout)
tables_after = parse_per_table_state(module, "\n".join(restored_state))
if restored_state not in (initref_state, initial_state):
for table_name, table_content in tables_after.items():
if table_name not in tables_before:
@ -608,7 +607,7 @@ def main():
# timeout
# * task attribute 'poll' equals 0
#
for dummy in range(_timeout):
for dummy2 in range(_timeout):
if os.path.exists(b_back):
time.sleep(1)
continue
@ -628,7 +627,7 @@ def main():
os.remove(b_back)
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
tables_rollback = parse_per_table_state(stdout)
tables_rollback = parse_per_table_state(module, stdout)
msg = f"Failed to confirm state restored from {path} after {_timeout}s. Firewall has been rolled back to its initial state."