# File: mtik-cert-pusher/mtik_cert_pusher/connector.py
# (file-viewer metadata: 133 lines, 3.8 KiB, Python)
#!python3
# dlitz 2026
import json
import os
import re
import shlex
import subprocess
import tempfile
from pathlib import Path, PosixPath
# Basenames accepted by SSHConnector.create_remote_files for the files it
# stages locally and uploads. A name must start with a letter, digit or
# underscore, followed by any mix of letters, digits, underscores, dots and
# hyphens. The trailing optional ``.ext`` group is already covered by the
# preceding character class, so it does not change what matches. Since "/"
# is excluded and the first character can be neither "." nor "-", a valid
# name cannot point outside the staging directory or look like an option.
remote_name_validation_regex = re.compile(
    pattern=r"""
\A
[A-Za-z0-9_]
[A-Za-z0-9_\.\-]*
(?:
\.
[A-Za-z0-9_]+
)?
\Z
""",
    flags=re.VERBOSE,
)
def _private_opener(file, flags):
    """Opener for the built-in :func:`open` that creates owner-only files.

    Pass as ``open(path, mode, opener=_private_opener)``. A newly created
    file gets permission bits ``0o600``, further reduced by the process
    umask, so it is readable and writable only by its owner.

    Returns the raw OS file descriptor produced by :func:`os.open`.
    """
    owner_only_mode = 0o600
    encoded_path = os.fsencode(file)
    return os.open(encoded_path, flags, owner_only_mode)
class Connector:
pass
class SSHConnector(Connector):
ssh_executable = "ssh"
scp_executable = "scp"
temporary_directory = "/dev/shm"
common_args = ["-oBatchMode=yes", "-oControlMaster=no"]
def __init__(
self,
host: str,
*,
user: str | None = None,
port: int | None = None,
ssh_config_path: str | None = None,
extra_ssh_options=None,
extra_ssh_args=None,
extra_scp_args=None,
):
self.ssh_host = host
self.ssh_user = user
self.ssh_port = port
self.ssh_config_path = ssh_config_path
if extra_ssh_options is None:
extra_ssh_options = {}
assert isinstance(extra_ssh_options, dict), "extra_ssh_options should be a dict"
self.extra_ssh_options = extra_ssh_options
self.extra_ssh_args = extra_ssh_args or ()
self.extra_scp_args = extra_scp_args or ()
def _ssh_option_args(self) -> list:
result = []
if self.ssh_config_path:
result.extend(["-F", str(self.ssh_config_path)])
if self.ssh_user:
result.append(f"-oUser={self.ssh_user}")
if self.ssh_port:
result.append(f"-oPort={self.ssh_port:d}")
for k, v in self.extra_ssh_options.items():
assert "=" not in k
assert k
result.append(f"-o{k}={v}")
return result
def _ssh_args(self, args, /) -> list:
return [
self.ssh_executable,
*self.common_args,
*self._ssh_option_args(),
*self.extra_ssh_args,
*args,
]
def _scp_args(self, args, /) -> list:
return [
self.scp_executable,
*self.common_args,
*self._ssh_option_args(),
*self.extra_scp_args,
*args,
]
def invoke_remote_command(
self, cmdline: str, text: bool = False, capture: bool = False
) -> str:
cmd = self._ssh_args([self.ssh_host, cmdline])
# print("running: ", shlex.join(cmd))
if capture:
return subprocess.check_output(cmd, text=text)
subprocess.run(cmd, check=True)
def create_remote_files(self, content_by_name: dict, remote_directory: str):
if not content_by_name:
raise ValueError("require at least one file to copy")
with tempfile.TemporaryDirectory(
dir=self.temporary_directory, prefix="mtik-connector-tmp"
) as td:
tempfile_paths = []
# Write the files to a temporary directory
for remote_name, content in content_by_name.items():
assert isinstance(remote_name, str)
if not remote_name_validation_regex.fullmatch(remote_name):
raise ValueError(f"illegal remote filename: {remote_name!r}")
tempfile_path = Path(td, remote_name)
with open(tempfile_path, "wb", opener=_private_opener) as outfile:
outfile.write(content)
tempfile_paths.append(tempfile_path)
# Copy them in a single scp command
cmd = self._scp_args(
[
*(str(p) for p in tempfile_paths),
f"{self.ssh_host}:{remote_directory}",
]
)
# print("running: ", shlex.join(cmd))
subprocess.run(cmd, check=True)