# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is based on
# the config parser from here: https://github.com/emre/storm/blob/master/storm/parsers/ssh_config_parser.py
# Copyright (C) <2013>
# SPDX-License-Identifier: MIT

from __future__ import annotations

import os
import re
import traceback
from operator import itemgetter

PARAMIKO_IMPORT_ERROR: str | None
try:
    from paramiko.config import SSHConfig
except ImportError:
    SSHConfig = object  # type: ignore
    HAS_PARAMIKO = False
    PARAMIKO_IMPORT_ERROR = traceback.format_exc()
else:
    HAS_PARAMIKO = True
    PARAMIKO_IMPORT_ERROR = None

# Split an ssh_config line into keyword and value. As in OpenSSH, the keyword
# ends at the first whitespace or "=", and is followed by optional whitespace
# and at most one "=". Everything after that is the value, so values may
# themselves contain "=" (e.g. "SetEnv FOO=bar", "ProxyCommand ssh -o A=b").
_KEYWORD_VALUE_RE = re.compile(r"^([^\s=]+)\s*(?:=\s*|\s)(.*)$")

# Options that may legitimately appear several times in one Host block; their
# values are accumulated into a list instead of "first one wins".
_MULTI_VALUE_KEYS = ("identityfile", "localforward", "remoteforward")

# Item types in the parsed data that are not host blocks.
_NON_HOST_TYPES = ("comment", "empty_line")


class StormConfig(SSHConfig):
    """SSH config reader that records comments, blank lines and their order.

    Every line of the input ends up in ``self._config`` (see ``parse``) so
    that ``ConfigParser.dump`` can write the file back with its layout intact.
    """

    def __init__(self):
        super().__init__()
        # paramiko's SSHConfig.__init__ creates ``_config``; the ``object``
        # fallback used when paramiko is not installed does not.
        if not hasattr(self, "_config"):
            self._config = []

    def parse(self, file_obj):
        """
        Read an OpenSSH config from the given file object.

        Appends to ``self._config``:

        * ``{"type": "comment"|"empty_line", "value": line, "host": "", "order": n}``
          for comment and blank lines;
        * ``{"host": [patterns], "config": {...}, "type": "entry", "order": n}``
          for each ``Host`` block. Options appearing before the first ``Host``
          line go into an implicit ``{"host": ["*"], "config": {...}}`` entry,
          which is always emitted (possibly with an empty config).

        Keys are lower-cased and values are stripped. ``IdentityFile``,
        ``LocalForward`` and ``RemoteForward`` accumulate into lists; for any
        other key only the first occurrence in a block is kept.

        @param file_obj: a file-like object to read the config file from
        @type file_obj: file
        @raise Exception: if a line has a keyword but no separator/value.
        """
        order = 1
        host = {
            "host": ["*"],
            "config": {},
        }

        for line in file_obj:
            line = line.rstrip("\n").lstrip()

            if line == "" or line.startswith("#"):
                # Keep blank lines and comments so they can be written back.
                self._config.append(
                    {
                        "type": "comment" if line else "empty_line",
                        "value": line,
                        "host": "",
                        "order": order,
                    }
                )
                order += 1
                continue

            match = _KEYWORD_VALUE_RE.match(line)
            if match is None:
                raise Exception(f"Unparsable line: {line!r}")
            key = match.group(1).lower()
            # strip() also drops trailing blanks and a CR from CRLF files.
            value = match.group(2).strip()

            if key == "host":
                self._config.append(host)
                host = {"host": value.split(), "config": {}, "type": "entry", "order": order}
                order += 1
            elif key in _MULTI_VALUE_KEYS:
                host["config"].setdefault(key, []).append(value)
            elif key not in host["config"]:
                host["config"][key] = value

        self._config.append(host)


class ConfigParser:
    """
    Config parser for ~/.ssh/config files.

    ``config_data`` holds one dict per comment, blank line or host block.
    Host blocks look like ``{"host": str, "options": dict, "type": "entry",
    "order": int}``. ``dump`` renders the list back to text sorted by ``order``.
    """

    def __init__(self, ssh_config_file=None):
        if not ssh_config_file:
            ssh_config_file = self.get_default_ssh_config_file()

        self.defaults = {}
        self.ssh_config_file = ssh_config_file

        if not os.path.exists(self.ssh_config_file):
            config_dir = os.path.dirname(self.ssh_config_file)
            # dirname is "" for a bare file name; os.makedirs("") would raise.
            if config_dir and not os.path.exists(config_dir):
                os.makedirs(config_dir)
            with open(self.ssh_config_file, "w+"):
                pass
            os.chmod(self.ssh_config_file, 0o600)

        self.config_data = []

    def get_default_ssh_config_file(self):
        """Return the expanded path of the current user's ``~/.ssh/config``."""
        return os.path.expanduser("~/.ssh/config")

    def load(self):
        """Parse ``ssh_config_file`` and append its items to ``config_data``.

        Options of every ``Host *`` block are merged into ``defaults``. Host
        blocks with no options are not kept; this drops the implicit leading
        ``Host *`` entry that ``StormConfig.parse`` always emits.

        Returns ``config_data``.
        """
        config = StormConfig()
        with open(self.ssh_config_file) as fd:
            config.parse(fd)

        for entry in config.__dict__.get("_config"):
            if entry.get("host") == ["*"]:
                self.defaults.update(entry.get("config"))

            if entry.get("type") in _NON_HOST_TYPES:
                self.config_data.append(entry)
                continue

            if not entry.get("config"):
                continue

            self.config_data.append(
                {
                    "host": " ".join(entry["host"]),
                    "options": entry.get("config"),
                    "type": "entry",
                    "order": entry.get("order", 0),
                }
            )

        return self.config_data

    def add_host(self, host, options):
        """Append a host block ordered after every existing item. Returns self."""
        self.config_data.append(
            {
                "host": host,
                "options": options,
                # "type" keeps the new block visible to search_host().
                "type": "entry",
                "order": self.get_last_index() + 1,
            }
        )
        return self

    def update_host(self, host, options, use_regex=False):
        """Merge ``options`` into every host block named ``host``.

        With ``use_regex`` true, ``host`` is a pattern tried with ``re.match``
        against each block's host string. An optional ``"deleted_fields"`` key
        in ``options`` lists option names to remove from each matching block
        (missing ones are ignored) before the merge; it is not stored itself.
        The caller's ``options`` dict is not modified. Returns self.
        """
        options = dict(options)
        deleted_fields = options.pop("deleted_fields", ())

        for host_entry in self.config_data:
            if host_entry.get("type") in _NON_HOST_TYPES:
                continue
            entry_host = host_entry.get("host")
            if entry_host == host or (use_regex and re.match(host, entry_host)):
                entry_options = host_entry["options"]
                for deleted_field in deleted_fields:
                    entry_options.pop(deleted_field, None)
                entry_options.update(options)
        return self

    def search_host(self, search_string):
        """Return host blocks whose host name or option values contain ``search_string``.

        Only items with ``type == "entry"`` are searched, and the ``*`` host is
        skipped. List values are joined with spaces before matching.
        """
        results = []
        for host_entry in self.config_data:
            if host_entry.get("type") != "entry":
                continue
            if host_entry.get("host") == "*":
                continue

            parts = [host_entry.get("host")]
            for value in host_entry.get("options").values():
                parts.append(" ".join(value) if isinstance(value, list) else str(value))

            if search_string in " ".join(parts):
                results.append(host_entry)

        return results

    def delete_host(self, host):
        """Remove every host block named exactly ``host``. Returns self.

        @raise ValueError: if no block matched.
        """
        remaining = [
            item
            for item in self.config_data
            if item.get("type") in _NON_HOST_TYPES or item.get("host") != host
        ]
        if len(remaining) == len(self.config_data):
            raise ValueError("No host found")
        # Filter in place instead of deleting during iteration, which skipped
        # the element following each deletion.
        self.config_data[:] = remaining
        return self

    def delete_all_hosts(self):
        """Drop all items (comments included) and rewrite the file. Returns self."""
        self.config_data = []
        self.write_to_ssh_config()
        return self

    def dump(self):
        """Render ``config_data`` as ssh_config text, or return None if it is empty.

        Sorts ``config_data`` in place by ``order`` (every item needs one).
        """
        if not self.config_data:
            return None

        self.config_data = sorted(self.config_data, key=itemgetter("order"))

        chunks = []
        for item in self.config_data:
            if item.get("type") in _NON_HOST_TYPES:
                chunks.append(f"{item.get('value')}\n")
                continue

            chunks.append(f"Host {item.get('host')}\n")
            for key, value in item.get("options").items():
                values = value if isinstance(value, list) else [value]
                chunks.extend(f"    {key} {value_}\n" for value_ in values)

        return "".join(chunks)

    def write_to_ssh_config(self):
        """Write ``dump()`` to ``ssh_config_file`` (empty file if no data). Returns self."""
        # Render before opening: opening in write mode truncates the file,
        # so a failure while rendering must not happen after that.
        data = self.dump()
        with open(self.ssh_config_file, "w") as f:
            if data:
                f.write(data)
        return self

    def get_last_index(self):
        """Return the largest truthy ``order`` in ``config_data``, or 0."""
        orders = [item.get("order") for item in self.config_data if item.get("order")]
        return max(orders, default=0)