134 lines
3.8 KiB
Python
134 lines
3.8 KiB
Python
#!python3
|
|
# dlitz 2026
|
|
|
|
import json
|
|
import os
|
|
import re
|
|
import shlex
|
|
import subprocess
|
|
import tempfile
|
|
from pathlib import Path, PosixPath
|
|
|
|
# Whitelist for file names that may be created on the remote side.
# A name must start with an ASCII letter, digit or underscore; the remaining
# characters may additionally be "." or "-".  Path separators, whitespace and
# a leading "." or "-" are therefore rejected.
remote_name_validation_regex = re.compile(
    r"""
    \A
    [A-Za-z0-9_]
    [A-Za-z0-9_\.\-]*
    (?:
        \.
        [A-Za-z0-9_]+
    )?
    \Z
    """,
    flags=re.VERBOSE,
)
|
|
|
|
|
|
def _private_opener(file, flags):
    """Open *file* with owner-only (0o600) permission bits.

    Intended for use as the ``opener=`` argument of the built-in ``open()``:
    it receives the path and the open flags and returns a raw file
    descriptor obtained from ``os.open``.
    """
    encoded_path = os.fsencode(file)
    return os.open(encoded_path, flags, 0o600)
|
|
|
|
|
|
class Connector:
|
|
pass
|
|
|
|
|
|
class SSHConnector(Connector):
|
|
|
|
ssh_executable = "ssh"
|
|
scp_executable = "scp"
|
|
temporary_directory = "/dev/shm"
|
|
common_args = ["-oBatchMode=yes", "-oControlMaster=no"]
|
|
|
|
def __init__(
|
|
self,
|
|
host: str,
|
|
*,
|
|
user: str | None = None,
|
|
port: int | None = None,
|
|
ssh_config_path: str | None = None,
|
|
extra_ssh_options=None,
|
|
extra_ssh_args=None,
|
|
extra_scp_args=None,
|
|
):
|
|
self.ssh_host = host
|
|
self.ssh_user = user
|
|
self.ssh_port = port
|
|
self.ssh_config_path = ssh_config_path
|
|
if extra_ssh_options is None:
|
|
extra_ssh_options = {}
|
|
assert isinstance(extra_ssh_options, dict), "extra_ssh_options should be a dict"
|
|
self.extra_ssh_options = extra_ssh_options
|
|
self.extra_ssh_args = extra_ssh_args or ()
|
|
self.extra_scp_args = extra_scp_args or ()
|
|
|
|
def _ssh_option_args(self) -> list:
|
|
result = []
|
|
if self.ssh_config_path:
|
|
result.extend(["-F", self.ssh_config_path])
|
|
if self.ssh_user:
|
|
result.append("-oUser={self.ssh_user}")
|
|
if self.ssh_port:
|
|
result.append("-oPort={self.ssh_port:d}")
|
|
for k, v in self.extra_ssh_options.items():
|
|
assert "=" not in k
|
|
assert k
|
|
result.append(f"-o{k}={v}")
|
|
return result
|
|
|
|
def _ssh_args(self, args, /) -> list:
|
|
return [
|
|
self.ssh_executable,
|
|
*self.common_args,
|
|
*self._ssh_option_args(),
|
|
*self.extra_ssh_args,
|
|
*args,
|
|
]
|
|
|
|
def _scp_args(self, args, /) -> list:
|
|
return [
|
|
self.scp_executable,
|
|
*self.common_args,
|
|
*self._ssh_option_args(),
|
|
*self.extra_scp_args,
|
|
*args,
|
|
]
|
|
|
|
def invoke_remote_command(
|
|
self, cmdline: str, text: bool = False, capture: bool = False
|
|
) -> str:
|
|
cmd = self._ssh_args([self.ssh_host, cmdline])
|
|
# print("running: ", shlex.join(cmd))
|
|
if capture:
|
|
return subprocess.check_output(cmd, text=text)
|
|
subprocess.run(cmd, check=True)
|
|
|
|
def create_remote_files(self, content_by_name: dict, remote_directory: str):
|
|
if not content_by_name:
|
|
raise ValueError("require at least one file to copy")
|
|
with tempfile.TemporaryDirectory(
|
|
dir=self.temporary_directory, prefix="mtik-connector-tmp"
|
|
) as td:
|
|
tempfile_paths = []
|
|
|
|
# Write the files to a temporary directory
|
|
for remote_name, content in content_by_name.items():
|
|
assert isinstance(remote_name, str)
|
|
if not remote_name_validation_regex.fullmatch(remote_name):
|
|
raise ValueError(f"illegal remote filename: {remote_name!r}")
|
|
tempfile_path = Path(td, remote_name)
|
|
with open(tempfile_path, "wb", opener=_private_opener) as outfile:
|
|
outfile.write(content)
|
|
tempfile_paths.append(tempfile_path)
|
|
|
|
# Copy them in a single scp command
|
|
cmd = self._scp_args(
|
|
[
|
|
"-q",
|
|
*(str(p) for p in tempfile_paths),
|
|
f"{self.ssh_host}:{remote_directory}",
|
|
]
|
|
)
|
|
# print("running: ", shlex.join(cmd))
|
|
subprocess.run(cmd, check=True)
|