Files

293 lines
10 KiB
Python

#!python3
# dlitz 2026
import json
import os
import re
import shlex
import subprocess
import tempfile
import textwrap
from collections import defaultdict
from dataclasses import dataclass
from io import StringIO
from pathlib import Path, PosixPath
from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric.types import (
CertificatePublicKeyTypes,
PrivateKeyTypes,
PublicKeyTypes,
)
from cryptography.hazmat.primitives.serialization import (
BestAvailableEncryption,
Encoding,
PrivateFormat,
PublicFormat,
load_pem_private_key,
)
from .connector import Connector
# One summary counter line from the `/certificate import` output, e.g.
# "certificates-imported: 1"; captures the counter name and its value.
# Used line-by-line, so ^/$ anchor each line (MULTILINE); VERBOSE lets the
# pattern be laid out with whitespace.
import_result_regex = re.compile(
    r"""
^
\s*
(?P<name>
certificates-imported
| private-keys-imported
| files-imported
| decryption-failures
| keys-with-no-certificate
)
:\s
(?P<count>\d+)
\s*
$
""",
    re.MULTILINE | re.VERBOSE,
)
# A whole line that is a JSON array (the serialized certificate list).
json_list_regex = re.compile(
    r"^\[.*\]$"
)
# A SHA-256 fingerprint: exactly 64 lowercase hexadecimal digits.
fingerprint_regex = re.compile(
    r"\A[0-9a-f]{64}\Z"
)
class CertInstallError(Exception):
    """Raised when the device's output shows a certificate install/activation failed."""
@dataclass
class InstallInfo:
skid_hex: str # SubjectKeyInfo digest (hexadecimal)
fingerprint_hex: str | None # fingerprint of cert
remote_name: str # remote name
remote_filename: str # remote filename
content: bytes # file content
install_cmd: str # command used to install
class RouterOS:
    """Install and activate TLS certificates on a RouterOS device.

    All device interaction goes through ``self.connector``: this class builds
    RouterOS command lines and file contents, hands them to the connector,
    and validates the command output it gets back.
    """

    # RouterOS script that prints every installed certificate as one JSON list.
    _CERT_LIST_CMD = ":put [:serialize to=json [/certificate print detail as-value]]"

    def __init__(self, connector: Connector):
        self.connector = connector

    def get_certificate_list(self):
        """Return the device's certificate list decoded from its JSON output."""
        result_json = self.connector.invoke_remote_command(
            self._CERT_LIST_CMD, capture=True
        )
        return json.loads(result_json)

    @classmethod
    def cert_skid(cls, cert_obj) -> str:
        """Return the certificate's SubjectKeyIdentifier as uppercase hex.

        The SKID extension must be present; cryptography raises
        ``x509.ExtensionNotFound`` otherwise.
        """
        skid_bin = cert_obj.extensions.get_extension_for_class(
            x509.SubjectKeyIdentifier
        ).value.key_identifier
        return skid_bin.hex().upper()

    @classmethod
    def cert_fingerprint(cls, cert_obj) -> str:
        """Return the certificate's SHA-256 fingerprint as lowercase hex."""
        return cert_obj.fingerprint(hashes.SHA256()).hex()

    def _make_cert_installinfo(self, cert_obj, privkey_obj=None) -> InstallInfo:
        """Build the upload file and import command for one certificate.

        If *privkey_obj* is given, it is appended to the PEM file as PKCS#8
        encrypted with a random one-time passphrase, and that passphrase is
        added to the ``/certificate import`` command.

        Raises:
            ValueError: if *privkey_obj* does not match the certificate.
        """
        # RouterOS indexes certs by their SubjectKeyIdentifier (SKID), so the
        # certificate is named by SKID; the uploaded file is named by fingerprint.
        skid_hex = self.cert_skid(cert_obj)
        fpr_hex = self.cert_fingerprint(cert_obj)
        remote_name = f"skid_{skid_hex}"
        remote_filename = f"fpr_{fpr_hex}.pem"
        content = cert_obj.public_bytes(Encoding.PEM)
        install_cmd = (
            "/certificate import trusted=no no-key-export=yes"
            f' name="{remote_name}" file-name="{remote_filename}"'
        )
        if privkey_obj is not None:
            # Explicit check instead of ``assert`` so it still runs under -O.
            if cert_obj.public_key() != privkey_obj.public_key():
                raise ValueError("private key doesn't match certificate")
            # The key is only ever uploaded in encrypted form; the random
            # passphrase travels in the import command instead of the file.
            passphrase = os.urandom(32).hex()
            encrypted_private_key = privkey_obj.private_bytes(
                Encoding.PEM,
                PrivateFormat.PKCS8,
                BestAvailableEncryption(passphrase.encode()),
            )
            content = content.rstrip(b"\n") + b"\n" + encrypted_private_key
            install_cmd += f' passphrase="{passphrase}"'
        return InstallInfo(
            skid_hex=skid_hex,
            fingerprint_hex=fpr_hex,
            remote_name=remote_name,
            remote_filename=remote_filename,
            content=content,
            install_cmd=install_cmd,
        )

    def _dedup_certs(self, cert_objs) -> list:
        """Drop repeated certificates (same fingerprint), keeping first-seen order.

        Raises:
            ValueError: if two *different* certificates share a SKID; both
                would be imported under the same ``skid_...`` name.
        """
        cert_by_fpr = {}
        for cert_obj in cert_objs:
            cert_by_fpr[self.cert_fingerprint(cert_obj)] = cert_obj
        seen_skids = set()
        for cert_obj in cert_by_fpr.values():
            skid = self.cert_skid(cert_obj)
            if skid in seen_skids:
                raise ValueError(f"skid conflict: {skid!r}")
            seen_skids.add(skid)
        return list(cert_by_fpr.values())

    @staticmethod
    def _split_host_cert(cert_objs, public_key_obj):
        """Split *cert_objs* into (the one cert for *public_key_obj*, the rest).

        Raises:
            ValueError: unless exactly one certificate matches the key.
        """
        host_certs = [c for c in cert_objs if c.public_key() == public_key_obj]
        chain_certs = [c for c in cert_objs if c.public_key() != public_key_obj]
        if len(host_certs) != 1:
            raise ValueError(
                "expected exactly one certificate matching the private key,"
                f" found {len(host_certs)}"
            )
        return host_certs[0], chain_certs

    def install_key_and_certificates(
        self, key: str, cert: str, chain: str | None = None
    ):
        """Upload and import a private key, its certificate, and chain certs.

        Args:
            key: PEM-encoded private key (loaded without a password).
            cert: PEM-encoded certificate(s); exactly one must match *key*.
            chain: optional PEM-encoded additional (chain) certificates.

        Returns:
            ``(fingerprint_hex, host_cert_obj)`` for the certificate that
            matches *key*.

        Raises:
            ValueError: if the supplied certificates are inconsistent.
            CertInstallError: if the device output shows an incomplete import.
            Exceptions raised after the command runs carry notes with the
            remote command line and its raw output.
        """
        private_key_obj = load_pem_private_key(key.encode(), None)
        cert_objs = x509.load_pem_x509_certificates(
            (cert + "\n" + (chain or "")).encode()
        )
        cert_objs = self._dedup_certs(cert_objs)
        host_cert_obj, chain_cert_objs = self._split_host_cert(
            cert_objs, private_key_obj.public_key()
        )
        private_key_skid = self.cert_skid(host_cert_obj)

        install_list = [self._make_cert_installinfo(host_cert_obj, private_key_obj)]
        install_list += [self._make_cert_installinfo(c) for c in chain_cert_objs]

        # One command line: import everything, then dump the certificate list
        # so the result can be verified from the same invocation's output.
        remote_commands = [
            # workaround for race condition where /certificate/import sometimes
            # doesn't see the recently-uploaded file
            ":delay 0.1",
            *(info.install_cmd for info in install_list),
            self._CERT_LIST_CMD,
        ]
        remote_cmdline = "\n".join(remote_commands)
        remote_files = {info.remote_filename: info.content for info in install_list}

        self.connector.create_remote_files(remote_files, "")
        cmd_output = self.connector.invoke_remote_command(
            remote_cmdline, capture=True, text=True
        )
        try:
            self._verify_install(cmd_output, install_list, private_key_skid)
            return self.cert_fingerprint(host_cert_obj), host_cert_obj
        except Exception as exc:
            exc.add_note(f"remote_cmdline={remote_cmdline!r}")
            exc.add_note(f"cmd_output={cmd_output!r}")
            raise

    def _verify_install(self, cmd_output: str, install_list, private_key_skid: str):
        """Check the import command's output against what was uploaded.

        Raises:
            CertInstallError: on decryption failures, a key without a
                certificate, a missing/malformed certificate list, or any
                expected certificate or the private key not being present.
            ValueError: if the output contains an unparseable line.
        """
        import_result_counts, certlist_output = self._parse_cmd_output(cmd_output)
        if import_result_counts["decryption-failures"] != 0:
            raise CertInstallError(
                f"BUG: Private key decryption failed on install; {import_result_counts!r}"
            )
        if import_result_counts["keys-with-no-certificate"] != 0:
            raise CertInstallError(
                f"BUG: No certificate for private key; {import_result_counts!r}"
            )
        if certlist_output is None:
            raise CertInstallError("certificate list JSON missing from command output")

        expected_skids = {info.skid_hex for info in install_list}
        expected_fingerprints = {info.fingerprint_hex for info in install_list} - {None}

        found_fingerprints = set()
        found_skids = set()
        found_private_key_skids = set()
        for cert_detail in certlist_output:
            fingerprint = cert_detail["fingerprint"]
            skid = cert_detail["skid"]
            has_private_key = cert_detail["private-key"]
            # Real checks (not asserts): this is data read back from the device.
            if not (
                isinstance(fingerprint, str)
                and isinstance(skid, str)
                and isinstance(has_private_key, bool)
            ):
                raise CertInstallError(f"unexpected certificate detail: {cert_detail!r}")
            # Normalize case to match cert_skid() (upper) / cert_fingerprint() (lower).
            skid = skid.upper()
            found_fingerprints.add(fingerprint.lower())
            found_skids.add(skid)
            if has_private_key:
                found_private_key_skids.add(skid)

        if private_key_skid not in found_private_key_skids:
            raise CertInstallError(
                f"private-key for skid {private_key_skid} was not installed"
            )
        missing_skids: set = expected_skids - found_skids
        missing_fingerprints: set = expected_fingerprints - found_fingerprints
        if missing_skids or missing_fingerprints:
            raise CertInstallError(
                f"some certs were not installed {missing_skids=!r}, {missing_fingerprints=!r}"
                f",\n{expected_skids=!r}"
                f",\n{found_skids=!r}"
                f",\n{expected_fingerprints=!r}"
                f",\n{found_fingerprints=!r}"
            )

    def use_certificate(self, fingerprint: str):
        """Point the api-ssl and www-ssl services at the cert with *fingerprint*.

        Args:
            fingerprint: 64 lowercase hex digits (SHA-256), e.g. as returned
                by :meth:`install_key_and_certificates`.

        Raises:
            ValueError: if *fingerprint* is not well-formed.
            CertInstallError: if reading back either service's certificate
                fingerprint does not give *fingerprint*.
        """
        # Validating first also keeps the value safe to embed in the script.
        if not fingerprint_regex.fullmatch(fingerprint):
            raise ValueError(f"illegal fingerprint {fingerprint!r}")
        cmds = [
            f'/ip/service set api-ssl,www-ssl certificate=[/certificate find where fingerprint="{fingerprint}"]',
            ":put [:serialize to=json value={[/certificate get [/ip/service get api-ssl certificate] fingerprint],[/certificate get [/ip/service get www-ssl certificate] fingerprint]}]",
        ]
        remote_cmdline = "\n".join(cmds)
        raw_result = self.connector.invoke_remote_command(
            remote_cmdline, text=True, capture=True
        )
        result = json.loads(raw_result)
        # NOTE(review): the [0] index implies the serialized pair arrives wrapped
        # in an outer list — confirm against actual device output.
        # fromhex().hex() normalizes to lowercase hex for comparison.
        api_fingerprint = bytes.fromhex(result[0][0]).hex()
        www_fingerprint = bytes.fromhex(result[0][1]).hex()
        missing = [
            service
            for service, found in (
                ("api-ssl", api_fingerprint),
                ("www-ssl", www_fingerprint),
            )
            if found != fingerprint
        ]
        if missing:
            raise CertInstallError(
                f"certs didn't get installed to {','.join(missing)}:"
                f" {api_fingerprint=!r}"
                f", {www_fingerprint=!r}"
                f", {fingerprint=!r}"
            )

    def _parse_cmd_output(self, cmd_output: str):
        """Split the install command's output into import counters and cert list.

        Returns:
            ``(import_result_counts, certlist_output)``: a ``defaultdict(int)``
            of summed import counters (missing counters read as 0), and the
            decoded JSON certificate list, or ``None`` if no JSON line appeared.

        Raises:
            ValueError: on an unrecognized line or a second JSON list line.
        """
        import_result_counts = defaultdict(int)
        certlist_output = None
        for rawline in StringIO(cmd_output):
            line = rawline.rstrip("\r\n")
            if m := import_result_regex.match(line):
                import_result_counts[m["name"]] += int(m["count"])
            elif json_list_regex.match(line):
                if certlist_output is not None:
                    raise ValueError("certificate list json received twice?")
                certlist_output = json.loads(line)
            elif not line.strip():
                # blank line
                pass
            else:
                raise ValueError(f"unable to parse output line: {line!r}")
        return import_result_counts, certlist_output