1
0
Fork 0
mirror of https://github.com/ansible-collections/community.general.git synced 2026-02-04 07:51:50 +00:00
community.general/plugins/module_utils/database.py
Felix Fontein c7f6a28d89
Add basic typing for module_utils (#11222)
* Add basic typing for module_utils.

* Apply some suggestions.

Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com>

* Make pass again.

* Add more types as suggested.

* Normalize extra imports.

* Add more type hints.

* Improve typing.

* Add changelog fragment.

* Reduce changelog.

* Apply suggestions from code review.

Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com>

* Fix typo.

* Cleanup.

* Improve types and make type checking happy.

* Let's see whether older Pythons barf on this.

* Revert "Let's see whether older Pythons barf on this."

This reverts commit 9973af3dbe.

* Add noqa.

---------

Co-authored-by: Alexei Znamensky <103110+russoz@users.noreply.github.com>
2025-12-01 20:40:06 +01:00

194 lines
6.5 KiB
Python

# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# Copyright (c) 2014, Toshio Kuratomi <tkuratomi@ansible.com>
#
# Simplified BSD License (see LICENSES/BSD-2-Clause.txt or https://opensource.org/licenses/BSD-2-Clause)
# SPDX-License-Identifier: BSD-2-Clause
# This module utils is deprecated and will be removed in community.general 13.0.0
from __future__ import annotations
import re
import typing as t
if t.TYPE_CHECKING:
from ansible.module_utils.basic import AnsibleModule
# Input patterns for is_input_dangerous function. Every pattern is applied
# with search(), so it may match anywhere in the checked string. None of the
# patterns uses re.DOTALL, so "." does not cross a newline: both parts of a
# pattern have to appear on the same line to match.
#
# 1. A single or double quote followed, later on the same line, by '--'
#    (the SQL line-comment marker). The quote must come before the '--'.
PATTERN_1 = re.compile(r"(\'|\").*--")
# 2. UNION, INTERSECT or EXCEPT followed by SELECT (case-insensitive).
PATTERN_2 = re.compile(r"(UNION|INTERSECT|EXCEPT).*SELECT", re.IGNORECASE)
# 3. ';' followed by any of the listed statement keywords (case-insensitive).
PATTERN_3 = re.compile(r";.*(SELECT|UPDATE|INSERT|DELETE|DROP|TRUNCATE|ALTER)", re.IGNORECASE)
class SQLParseError(Exception):
    """Raised when an SQL identifier cannot be parsed or has too many levels."""


class UnclosedQuoteError(SQLParseError):
    """Raised when a quoted identifier has no terminating quote character."""
# Lookup tables: identifier type -> the maximum number of dot levels that may
# be used to spell that identifier. For example, a PostgreSQL column may be
# written with up to 4 levels: database.schema.table.column
_PG_IDENTIFIER_TO_DOT_LEVEL = {
    "database": 1,
    "schema": 2,
    "table": 3,
    "column": 4,
    "role": 1,
    "tablespace": 1,
    "sequence": 3,
    "publication": 1,
}
_MYSQL_IDENTIFIER_TO_DOT_LEVEL = {
    "database": 1,
    "table": 2,
    "column": 3,
    "role": 1,
    "vars": 1,
}
def _find_end_quote(identifier, quote_char):
    """Return the index in *identifier* of the quote that closes it.

    *identifier* is the text that follows an opening quote. A quote character
    immediately followed by a second quote character is treated as an escaped
    quote and skipped; the first quote that is not doubled in this way is the
    closing one.

    :param identifier: text to scan (without the opening quote)
    :param quote_char: the quote character to look for
    :returns: zero-based index of the closing quote within *identifier*
    :raises UnclosedQuoteError: if no closing quote is found
    """
    search_from = 0
    length = len(identifier)
    while True:
        try:
            position = identifier.index(quote_char, search_from)
        except ValueError as err:
            raise UnclosedQuoteError from err
        following = position + 1
        # A quote at the very end, or one not followed by another quote,
        # terminates the identifier.
        if following >= length or identifier[following] != quote_char:
            return position
        # Doubled quote: step over both characters and keep scanning.
        search_from = position + 2
def _identifier_parse(identifier, quote_char):
    """Split a dotted identifier into a list of quoted fragments.

    A fragment that already starts with *quote_char* and has a matching
    closing quote is kept verbatim. Anything else is split on its first dot;
    each unquoted fragment has embedded quote characters doubled and is then
    wrapped in *quote_char*. An unquoted identifier with no dot, or whose
    first dot is its first or last character, is quoted as a single fragment.

    :param identifier: the (possibly dotted, possibly quoted) identifier
    :param quote_char: quote character of the SQL dialect
    :returns: list of quoted fragments, outermost level first
    :raises SQLParseError: if *identifier* is empty, or if a quoted fragment
        is followed by something other than a dot
    """
    if not identifier:
        raise SQLParseError("Identifier name unspecified or unquoted trailing dot")

    def _quoted(name):
        # Escape embedded quotes by doubling them, then wrap in quotes.
        escaped = name.replace(quote_char, quote_char * 2)
        return f"{quote_char}{escaped}{quote_char}"

    if identifier.startswith(quote_char):
        try:
            closing = _find_end_quote(identifier[1:], quote_char=quote_char) + 1
        except UnclosedQuoteError:
            # No closing quote: handle the text as an unquoted identifier below.
            pass
        else:
            if closing >= len(identifier) - 1:
                # The closing quote is the last character: one quoted fragment.
                return [identifier]
            if identifier[closing + 1] != ".":
                raise SQLParseError("User escaped identifiers must escape extra quotes")
            head = identifier[: closing + 1]
            remainder = identifier[closing + 2 :]
            fragments = _identifier_parse(remainder, quote_char)
            fragments.insert(0, head)
            return fragments

    separator = identifier.find(".")
    # No dot at all (-1), a leading dot (0) or a trailing dot: quote everything.
    if separator <= 0 or separator >= len(identifier) - 1:
        return [_quoted(identifier)]

    fragments = _identifier_parse(identifier[separator + 1 :], quote_char)
    fragments.insert(0, _quoted(identifier[:separator]))
    return fragments
def pg_quote_identifier(identifier, id_type):
    """Quote a PostgreSQL identifier, validating its number of levels.

    The identifier is split into double-quoted fragments and the number of
    fragments is compared with the limit registered for *id_type* in
    _PG_IDENTIFIER_TO_DOT_LEVEL.

    :param identifier: identifier to quote, optionally dotted
    :param id_type: a key of _PG_IDENTIFIER_TO_DOT_LEVEL (e.g. "table")
    :returns: the quoted fragments joined with "."
    :raises SQLParseError: if parsing fails or there are more fragments than
        the limit for *id_type*
    :raises KeyError: if *id_type* is not a known identifier type
    """
    fragments = _identifier_parse(identifier, quote_char='"')
    max_levels = _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]
    if len(fragments) <= max_levels:
        return ".".join(fragments)
    raise SQLParseError(f"PostgreSQL does not support {id_type} with more than {max_levels} dots")
def mysql_quote_identifier(identifier, id_type):
    """Quote a MySQL identifier, validating its number of dots.

    The identifier is split into backtick-quoted fragments. The number of
    fragments minus one is compared with the limit registered for *id_type*
    in _MYSQL_IDENTIFIER_TO_DOT_LEVEL. A fragment that is exactly a quoted
    asterisk is emitted as a bare ``*``.

    :param identifier: identifier to quote, optionally dotted
    :param id_type: a key of _MYSQL_IDENTIFIER_TO_DOT_LEVEL (e.g. "table")
    :returns: the quoted fragments joined with "."
    :raises SQLParseError: if parsing fails or the limit is exceeded
    :raises KeyError: if *id_type* is not a known identifier type
    """
    fragments = _identifier_parse(identifier, quote_char="`")
    max_dots = _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]
    if len(fragments) - 1 > max_dots:
        raise SQLParseError(f"MySQL does not support {id_type} with more than {max_dots} dots")
    # The wildcard must stay unquoted, so turn `*` back into *.
    return ".".join("*" if fragment == "`*`" else fragment for fragment in fragments)
def is_input_dangerous(string):
    """Tell whether the given string looks like an SQL injection attempt.

    The string is searched with PATTERN_1, PATTERN_2 and PATTERN_3; a match
    of any of them marks the input as dangerous. Empty or otherwise falsy
    input is never considered dangerous.

    Note: use this function only when you can't use psycopg2's
    cursor.execute method parametrized (typically with DDL queries).

    :param string: the text to check
    :returns: True if any pattern matches, False otherwise
    """
    if not string:
        return False
    for pattern in (PATTERN_1, PATTERN_2, PATTERN_3):
        if pattern.search(string):
            return True
    return False
def check_input(module: AnsibleModule, *args) -> None:
    """Fail the module if any argument is flagged by is_input_dangerous().

    Every positional argument is normalised and checked as follows:

    * a list is checked item by item;
    * ``None`` and booleans are skipped;
    * strings are checked as they are;
    * any other value is converted with ``str()`` before being checked.

    Items inside a list are normalised in the same way as top-level
    arguments, so a list holding a boolean or a number no longer raises
    ``TypeError`` from the regular expression search.

    :param module: the module object; its ``fail_json`` method is called
        with a message listing all dangerous values, if there are any
    :param args: the values to check
    """
    dangerous_elements = []
    for elem in args:
        # Treat a scalar as a one-item list so both cases share one code path.
        candidates = elem if isinstance(elem, list) else [elem]
        for candidate in candidates:
            if candidate is None or isinstance(candidate, bool):
                continue
            # The patterns only accept text, and the message below joins
            # strings, so stringify anything that is not already a str.
            text = candidate if isinstance(candidate, str) else str(candidate)
            if is_input_dangerous(text):
                dangerous_elements.append(text)
    if dangerous_elements:
        module.fail_json(msg=f"Passed input '{', '.join(dangerous_elements)}' is potentially dangerous")