mirror of
https://github.com/ansible-collections/community.general.git
synced 2026-02-03 23:41:51 +00:00
iptables_state: get rid of temporary files (#11258)
Get rid of temporary files.
This commit is contained in:
parent
3d25aac978
commit
0ef3eac0f4
2 changed files with 47 additions and 46 deletions
2
changelogs/fragments/11258-iptables_state.yml
Normal file
2
changelogs/fragments/11258-iptables_state.yml
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
bugfixes:
|
||||
- "iptables_state - refactor code to avoid writing unnecessary temporary files (https://github.com/ansible-collections/community.general/pull/11258)."
|
||||
|
|
@ -225,9 +225,6 @@ tables:
|
|||
import re
|
||||
import os
|
||||
import time
|
||||
import tempfile
|
||||
import filecmp
|
||||
import shutil
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.common.text.converters import to_bytes, to_native
|
||||
|
|
@ -260,18 +257,28 @@ def read_state(b_path):
|
|||
return [t for t in text.splitlines() if t != ""]
|
||||
|
||||
|
||||
def write_state(b_path, lines, changed):
|
||||
def get_file_contents(b_path: bytes) -> bytes | None:
|
||||
try:
|
||||
with open(b_path, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
def write_state(module: AnsibleModule, b_path: bytes, lines: list[str], changed: bool) -> bool:
|
||||
"""
|
||||
Write given contents to the given path, and return changed status.
|
||||
"""
|
||||
# Populate a temporary file
|
||||
tmpfd, tmpfile = tempfile.mkstemp()
|
||||
with os.fdopen(tmpfd, "w") as f:
|
||||
joined_lines = "\n".join(lines)
|
||||
f.write(f"{joined_lines}\n")
|
||||
joined_lines = "\n".join(lines)
|
||||
content = f"{joined_lines}\n".encode("utf-8")
|
||||
|
||||
# Prepare to copy temporary file to the final destination
|
||||
if not os.path.exists(b_path):
|
||||
existing_contents = get_file_contents(b_path)
|
||||
if existing_contents == content:
|
||||
return changed
|
||||
|
||||
changed = True
|
||||
|
||||
if existing_contents is None:
|
||||
b_destdir = os.path.dirname(b_path)
|
||||
destdir = to_native(b_destdir, errors="surrogate_or_strict")
|
||||
if b_destdir and not os.path.exists(b_destdir) and not module.check_mode:
|
||||
|
|
@ -279,15 +286,11 @@ def write_state(b_path, lines, changed):
|
|||
os.makedirs(b_destdir)
|
||||
except Exception as err:
|
||||
module.fail_json(msg=f"Error creating {destdir}: {err}", initial_state=lines)
|
||||
changed = True
|
||||
|
||||
elif not filecmp.cmp(tmpfile, b_path):
|
||||
changed = True
|
||||
|
||||
# Do it
|
||||
if changed and not module.check_mode:
|
||||
if not module.check_mode:
|
||||
try:
|
||||
shutil.copyfile(tmpfile, b_path)
|
||||
with open(b_path, "wb") as f:
|
||||
f.write(content)
|
||||
except Exception as err:
|
||||
path = to_native(b_path, errors="surrogate_or_strict")
|
||||
module.fail_json(msg=f"Error saving state into {path}: {err}", initial_state=lines)
|
||||
|
|
@ -295,7 +298,7 @@ def write_state(b_path, lines, changed):
|
|||
return changed
|
||||
|
||||
|
||||
def initialize_from_null_state(initializer, initcommand, fallbackcmd, table):
|
||||
def initialize_from_null_state(module: AnsibleModule, initializer, initcommand, fallbackcmd, table):
|
||||
"""
|
||||
This ensures iptables-state output is suitable for iptables-restore to roll
|
||||
back to it, i.e. iptables-save output is not empty. This also works for the
|
||||
|
|
@ -317,7 +320,7 @@ def initialize_from_null_state(initializer, initcommand, fallbackcmd, table):
|
|||
return rc, out, err
|
||||
|
||||
|
||||
def filter_and_format_state(string):
|
||||
def filter_and_format_state(module: AnsibleModule, string: str) -> list[str]:
|
||||
"""
|
||||
Remove timestamps to ensure idempotence between runs. Also remove counters
|
||||
by default. And return the result as a list.
|
||||
|
|
@ -329,15 +332,15 @@ def filter_and_format_state(string):
|
|||
return lines
|
||||
|
||||
|
||||
def parse_per_table_state(all_states_dump):
|
||||
def parse_per_table_state(module: AnsibleModule, all_states_dump) -> dict[str, list[str]]:
|
||||
"""
|
||||
Convert raw iptables-save output into usable datastructure, for reliable
|
||||
comparisons between initial and final states.
|
||||
"""
|
||||
lines = filter_and_format_state(all_states_dump)
|
||||
tables = dict()
|
||||
lines = filter_and_format_state(module, all_states_dump)
|
||||
tables: dict[str, list[str]] = {}
|
||||
current_table = ""
|
||||
current_list = list()
|
||||
current_list: list[str] = []
|
||||
for line in lines:
|
||||
if re.match(r"^[*](filter|mangle|nat|raw|security)$", line):
|
||||
current_table = line[1:]
|
||||
|
|
@ -345,7 +348,7 @@ def parse_per_table_state(all_states_dump):
|
|||
if line == "COMMIT":
|
||||
tables[current_table] = current_list
|
||||
current_table = ""
|
||||
current_list = list()
|
||||
current_list = []
|
||||
continue
|
||||
if line.startswith("# "):
|
||||
continue
|
||||
|
|
@ -353,9 +356,7 @@ def parse_per_table_state(all_states_dump):
|
|||
return tables
|
||||
|
||||
|
||||
def main():
|
||||
global module
|
||||
|
||||
def main() -> None:
|
||||
module = AnsibleModule(
|
||||
argument_spec=dict(
|
||||
path=dict(type="path", required=True),
|
||||
|
|
@ -459,28 +460,30 @@ def main():
|
|||
for t in TABLES:
|
||||
if f"*{t}" in state_to_restore:
|
||||
if len(stdout) == 0 or f"*{t}" not in stdout.splitlines():
|
||||
(rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, t)
|
||||
(rc, stdout, stderr) = initialize_from_null_state(
|
||||
module, INITIALIZER, INITCOMMAND, FALLBACKCMD, t
|
||||
)
|
||||
elif len(stdout) == 0:
|
||||
(rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, "filter")
|
||||
(rc, stdout, stderr) = initialize_from_null_state(module, INITIALIZER, INITCOMMAND, FALLBACKCMD, "filter")
|
||||
|
||||
elif state == "restored" and f"*{table}" not in state_to_restore:
|
||||
module.fail_json(msg=f"Table {table} to restore not defined in {path}")
|
||||
|
||||
elif len(stdout) == 0 or f"*{table}" not in stdout.splitlines():
|
||||
(rc, stdout, stderr) = initialize_from_null_state(INITIALIZER, INITCOMMAND, FALLBACKCMD, table)
|
||||
(rc, stdout, stderr) = initialize_from_null_state(module, INITIALIZER, INITCOMMAND, FALLBACKCMD, table)
|
||||
|
||||
initial_state = filter_and_format_state(stdout)
|
||||
initial_state = filter_and_format_state(module, stdout)
|
||||
if initial_state is None:
|
||||
module.fail_json(msg="Unable to initialize firewall from NULL state.")
|
||||
|
||||
# Depending on the value of 'table', initref_state may differ from
|
||||
# initial_state.
|
||||
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
|
||||
tables_before = parse_per_table_state(stdout)
|
||||
initref_state = filter_and_format_state(stdout)
|
||||
tables_before = parse_per_table_state(module, stdout)
|
||||
initref_state = filter_and_format_state(module, stdout)
|
||||
|
||||
if state == "saved":
|
||||
changed = write_state(b_path, initref_state, changed)
|
||||
changed = write_state(module, b_path, initref_state, changed)
|
||||
module.exit_json(
|
||||
changed=changed, cmd=cmd, tables=tables_before, initial_state=initial_state, saved=initref_state
|
||||
)
|
||||
|
|
@ -497,7 +500,7 @@ def main():
|
|||
|
||||
if _back is not None:
|
||||
b_back = to_bytes(_back, errors="surrogate_or_strict")
|
||||
dummy = write_state(b_back, initref_state, changed)
|
||||
dummy = write_state(module, b_back, initref_state, changed)
|
||||
BACKCOMMAND = list(MAINCOMMAND)
|
||||
BACKCOMMAND.append(_back)
|
||||
|
||||
|
|
@ -536,12 +539,8 @@ def main():
|
|||
)
|
||||
|
||||
if module.check_mode:
|
||||
tmpfd, tmpfile = tempfile.mkstemp()
|
||||
with os.fdopen(tmpfd, "w") as f:
|
||||
joined_initial_state = "\n".join(initial_state)
|
||||
f.write(f"{joined_initial_state}\n")
|
||||
|
||||
if filecmp.cmp(tmpfile, b_path):
|
||||
joined_initial_state = "\n".join(initial_state)
|
||||
if get_file_contents(b_path) == f"{joined_initial_state}\n".encode("utf-8"):
|
||||
restored_state = initial_state
|
||||
else:
|
||||
restored_state = state_to_restore
|
||||
|
|
@ -572,8 +571,8 @@ def main():
|
|||
)
|
||||
|
||||
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
|
||||
restored_state = filter_and_format_state(stdout)
|
||||
tables_after = parse_per_table_state("\n".join(restored_state))
|
||||
restored_state = filter_and_format_state(module, stdout)
|
||||
tables_after = parse_per_table_state(module, "\n".join(restored_state))
|
||||
if restored_state not in (initref_state, initial_state):
|
||||
for table_name, table_content in tables_after.items():
|
||||
if table_name not in tables_before:
|
||||
|
|
@ -608,7 +607,7 @@ def main():
|
|||
# timeout
|
||||
# * task attribute 'poll' equals 0
|
||||
#
|
||||
for dummy in range(_timeout):
|
||||
for dummy2 in range(_timeout):
|
||||
if os.path.exists(b_back):
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
|
@ -628,7 +627,7 @@ def main():
|
|||
os.remove(b_back)
|
||||
|
||||
(rc, stdout, stderr) = module.run_command(SAVECOMMAND, check_rc=True)
|
||||
tables_rollback = parse_per_table_state(stdout)
|
||||
tables_rollback = parse_per_table_state(module, stdout)
|
||||
|
||||
msg = f"Failed to confirm state restored from {path} after {_timeout}s. Firewall has been rolled back to its initial state."
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue