#!python3
# dlitz 2026

import json
import os
import re
import shlex
import subprocess
import tempfile
import textwrap
from collections import defaultdict
from dataclasses import dataclass
from io import StringIO
from pathlib import Path, PosixPath

from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric.types import (
    CertificatePublicKeyTypes,
    PrivateKeyTypes,
    PublicKeyTypes,
)
from cryptography.hazmat.primitives.serialization import (
    BestAvailableEncryption,
    Encoding,
    PrivateFormat,
    PublicFormat,
    load_pem_private_key,
)

from .connector import Connector

# Matches the summary lines emitted by `/certificate import`, e.g.
# "certificates-imported: 1".  RouterOS._parse_cmd_output() reads the
# named groups as m["name"] and m["count"], so both groups must be named.
import_result_regex = re.compile(
    r"""
    ^
    \s*
    (?P<name>
        certificates-imported
        | private-keys-imported
        | files-imported
        | decryption-failures
        | keys-with-no-certificate
    )
    :\s
    (?P<count>\d+)
    \s*
    $
    """,
    re.M | re.X,
)

# A whole output line that looks like a JSON array (the serialized cert list).
json_list_regex = re.compile(r"^\[.*\]$")

# A SHA-256 fingerprint as 64 lowercase hex digits.
fingerprint_regex = re.compile(r"\A[0-9a-f]{64}\Z")


class CertInstallError(Exception):
    """Raised when the router does not end up with the expected certificates."""


@dataclass
class InstallInfo:
    """Everything needed to upload and import one certificate (and optional key)."""

    skid_hex: str  # SubjectKeyIdentifier digest (uppercase hexadecimal)
    fingerprint_hex: str | None  # SHA-256 fingerprint of cert (lowercase hex)
    remote_name: str  # certificate name on the router
    remote_filename: str  # filename of the uploaded PEM file on the router
    content: bytes  # file content: PEM cert, optionally + encrypted PEM key
    install_cmd: str  # command used to install


class RouterOS:
    """Install and select TLS certificates on a RouterOS device via a Connector."""

    def __init__(self, connector: Connector):
        self.connector = connector

    def get_certificate_list(self):
        """Return the decoded JSON of `/certificate print detail as-value`.

        Presumably a list of per-certificate dicts (the install path indexes
        entries by "fingerprint", "skid" and "private-key").
        """
        cmdline = ":put [:serialize to=json [/certificate print detail as-value]]"
        result_json = self.connector.invoke_remote_command(cmdline, capture=True)
        return json.loads(result_json)

    @classmethod
    def cert_skid(cls, cert_obj) -> str:
        """Return the cert's SubjectKeyIdentifier as uppercase hex.

        Raises x509.ExtensionNotFound if the cert has no SKID extension.
        """
        skid_bin = cert_obj.extensions.get_extension_for_class(
            x509.SubjectKeyIdentifier
        ).value.key_identifier
        return skid_bin.hex().upper()

    @classmethod
    def cert_fingerprint(cls, cert_obj) -> str:
        """Return the cert's SHA-256 fingerprint as lowercase hex."""
        return cert_obj.fingerprint(hashes.SHA256()).hex()

    def _make_cert_installinfo(self, cert_obj, privkey_obj=None) -> InstallInfo:
        """Build the file content and import command for one certificate.

        If ``privkey_obj`` is given, it is appended to the PEM file as a
        PKCS#8 key encrypted with a random one-time passphrase, and that
        passphrase is passed on the import command line.

        :raises ValueError: if ``privkey_obj`` does not match ``cert_obj``.
        """
        # RouterOS indexes certs by their SubjectKeyIdentifier (SKID), so we
        # use that for naming.
        skid_hex = self.cert_skid(cert_obj)
        fpr_hex = self.cert_fingerprint(cert_obj)
        remote_name = f"skid_{skid_hex}"
        remote_filename = f"fpr_{fpr_hex}.pem"
        content = cert_obj.public_bytes(Encoding.PEM)
        passphrase = None
        if privkey_obj is not None:
            # Explicit check rather than `assert` so it survives `python -O`.
            if cert_obj.public_key() != privkey_obj.public_key():
                raise ValueError("private key doesn't match certificate")
            passphrase = os.urandom(32).hex()
            encrypted_private_key = privkey_obj.private_bytes(
                Encoding.PEM,
                PrivateFormat.PKCS8,
                BestAvailableEncryption(passphrase.encode()),
            )
            content = content.rstrip(b"\n") + b"\n" + encrypted_private_key
        install_cmd = (
            "/certificate import trusted=no no-key-export=yes"
            f' name="{remote_name}" file-name="{remote_filename}"'
        )
        if passphrase is not None:
            install_cmd += f' passphrase="{passphrase}"'
        return InstallInfo(
            skid_hex=skid_hex,
            fingerprint_hex=fpr_hex,
            remote_name=remote_name,
            remote_filename=remote_filename,
            content=content,
            install_cmd=install_cmd,
        )

    def _dedup_certs(self, cert_objs) -> list:
        """Drop duplicate certs (same SHA-256 fingerprint), keeping first order.

        :raises ValueError: if two *different* certificates share a SKID,
            since RouterOS names imported certs by SKID and they would collide.
        """
        cert_by_fpr = {}
        for cert_obj in cert_objs:
            cert_by_fpr[self.cert_fingerprint(cert_obj)] = cert_obj

        # Look for duplicate skids in different certs.
        fpr_by_skid = {}
        for fpr, cert_obj in cert_by_fpr.items():
            skid = self.cert_skid(cert_obj)
            if skid in fpr_by_skid:
                raise ValueError(f"skid conflict: {skid!r}")
            fpr_by_skid[skid] = fpr
        return list(cert_by_fpr.values())

    def install_key_and_certificates(
        self, key: str, cert: str, chain: str | None = None
    ):
        """Upload and import a private key, its certificate, and chain certs.

        :param key: PEM private key (loaded without a password).
        :param cert: PEM certificate(s); one must match ``key``.
        :param chain: optional additional PEM certificates.
        :returns: ``(fingerprint_hex, host_cert_obj)`` of the cert matching ``key``.
        :raises ValueError: if not exactly one cert matches ``key``, or two
            different certs share a SKID.
        :raises CertInstallError: if the router reports import failures or the
            expected certs/key are absent afterwards.  Errors raised after the
            remote command ran carry notes with the command line and output.
        """
        if not chain:
            chain = ""

        private_key_obj = load_pem_private_key(key.encode(), None)
        cert_objs = x509.load_pem_x509_certificates((cert + "\n" + chain).encode())
        cert_objs = self._dedup_certs(cert_objs)

        # Find the certificate whose public key matches our private key.
        public_key_obj = private_key_obj.public_key()
        host_cert_objs = [c for c in cert_objs if c.public_key() == public_key_obj]
        if len(host_cert_objs) != 1:
            raise ValueError(
                "expected exactly one certificate matching the private key,"
                f" found {len(host_cert_objs)}"
            )
        (host_cert_obj,) = host_cert_objs
        chain_cert_objs = [c for c in cert_objs if c.public_key() != public_key_obj]
        private_key_skid = self.cert_skid(host_cert_obj)

        # Build the list of InstallInfo objects, and sets of what we're
        # expecting to find.
        install_list = [self._make_cert_installinfo(host_cert_obj, private_key_obj)]
        install_list += [self._make_cert_installinfo(c) for c in chain_cert_objs]
        expected_skids = {info.skid_hex for info in install_list}
        expected_fingerprints = {info.fingerprint_hex for info in install_list} - {None}

        # Merge the commands into a single command-line, and add a command to
        # get a list of installed certs at the end.
        remote_commands = [
            # workaround for race condition where /certificate/import
            # sometimes doesn't see the recently-uploaded file
            ":delay 0.1",
            *(info.install_cmd for info in install_list),
            ":put [:serialize to=json [/certificate print detail as-value]]",
        ]
        remote_cmdline = "\n".join(remote_commands)

        # Upload the files and run the command.
        remote_files = {info.remote_filename: info.content for info in install_list}
        self.connector.create_remote_files(remote_files, "")
        cmd_output = self.connector.invoke_remote_command(
            remote_cmdline, capture=True, text=True
        )

        try:
            import_result_counts, certlist_output = self._parse_cmd_output(cmd_output)
            self._verify_install(
                import_result_counts,
                certlist_output,
                private_key_skid=private_key_skid,
                expected_skids=expected_skids,
                expected_fingerprints=expected_fingerprints,
            )
            return self.cert_fingerprint(host_cert_obj), host_cert_obj
        except Exception as exc:
            # Annotate (not swallow) so failures show what was sent/received.
            exc.add_note(f"remote_cmdline={remote_cmdline!r}")
            exc.add_note(f"cmd_output={cmd_output!r}")
            raise

    def _verify_install(
        self,
        import_result_counts,
        certlist_output,
        *,
        private_key_skid: str,
        expected_skids: set,
        expected_fingerprints: set,
    ) -> None:
        """Raise CertInstallError unless the router output shows a full install."""
        if import_result_counts.get("decryption-failures", 0) != 0:
            raise CertInstallError(
                f"BUG: Private key decryption failed on install; {import_result_counts!r}"
            )
        if import_result_counts.get("keys-with-no-certificate", 0) != 0:
            raise CertInstallError(
                f"BUG: No certificate for private key; {import_result_counts!r}"
            )

        found_fingerprints = set()
        found_skids = set()
        found_private_key_skids = set()
        for cert_detail in certlist_output:
            fingerprint = cert_detail.get("fingerprint")
            skid = cert_detail.get("skid")
            has_private_key = cert_detail.get("private-key")
            # Validate router-supplied data explicitly (asserts vanish under -O).
            if not (
                isinstance(fingerprint, str)
                and isinstance(skid, str)
                and isinstance(has_private_key, bool)
            ):
                raise CertInstallError(
                    f"unexpected certificate entry from router: {cert_detail!r}"
                )
            skid = skid.upper()
            found_fingerprints.add(fingerprint.lower())
            found_skids.add(skid)
            if has_private_key:
                found_private_key_skids.add(skid)

        if private_key_skid not in found_private_key_skids:
            raise CertInstallError(
                f"private-key for skid {private_key_skid} was not installed"
            )

        missing_skids: set = expected_skids - found_skids
        missing_fingerprints: set = expected_fingerprints - found_fingerprints
        if missing_skids or missing_fingerprints:
            raise CertInstallError(
                f"some certs were not installed {missing_skids=!r}, {missing_fingerprints=!r}"
                f",\n{expected_skids=!r}"
                f",\n{found_skids=!r}"
                f",\n{expected_fingerprints=!r}"
                f",\n{found_fingerprints=!r}"
            )

    def use_certificate(self, fingerprint: str):
        """Point the api-ssl and www-ssl services at the cert with ``fingerprint``.

        :param fingerprint: SHA-256 fingerprint as 64 lowercase hex digits.
        :raises ValueError: if ``fingerprint`` is malformed.
        :raises CertInstallError: if, read back afterwards, either service is
            not using the requested certificate.
        """
        if not fingerprint_regex.fullmatch(fingerprint):
            raise ValueError(f"illegal fingerprint {fingerprint!r}")
        cmds = [
            f'/ip/service set api-ssl,www-ssl certificate=[/certificate find where fingerprint="{fingerprint}"]',
            f':put [:serialize to=json value={{[/certificate get [/ip/service get api-ssl certificate] fingerprint],[/certificate get [/ip/service get www-ssl certificate] fingerprint]}}]',
        ]
        remote_cmdline = "\n".join(cmds)
        raw_result = self.connector.invoke_remote_command(
            remote_cmdline, text=True, capture=True
        )
        result = json.loads(raw_result)
        # NOTE(review): the read-back is indexed as result[0][0] / result[0][1],
        # i.e. a nested array is expected here -- confirm against RouterOS output.
        # bytes.fromhex(...).hex() validates the hex and normalizes to lowercase.
        api_fingerprint = bytes.fromhex(result[0][0]).hex()
        www_fingerprint = bytes.fromhex(result[0][1]).hex()

        missing = []
        if api_fingerprint != fingerprint:
            missing.append("api-ssl")
        if www_fingerprint != fingerprint:
            missing.append("www-ssl")
        if missing:
            raise CertInstallError(
                f"certs didn't get installed to {','.join(missing)}:"
                f" {api_fingerprint=!r}"
                f", {www_fingerprint=!r}"
                f", {fingerprint=!r}"
            )

    def _parse_cmd_output(self, cmd_output: str):
        """Split install command output into import counters and the cert list.

        :returns: ``(import_result_counts, certlist_output)``: a
            ``defaultdict(int)`` keyed by import-result name (summed over all
            imports), and the decoded JSON certificate list.
        :raises ValueError: on an unrecognized line, a repeated certificate
            list, or a missing certificate list.
        """
        import_result_counts = defaultdict(int)
        certlist_output = None
        for rawline in StringIO(cmd_output):
            line = rawline.rstrip("\r\n")
            if m := import_result_regex.match(line):
                import_result_counts[m["name"]] += int(m["count"])
            elif json_list_regex.match(line):
                if certlist_output is not None:
                    raise ValueError("certificate list json received twice?")
                certlist_output = json.loads(line)
            elif not line.strip():
                pass  # blank line
            else:
                raise ValueError(f"unable to parse output line: {line!r}")
        if certlist_output is None:
            raise ValueError("no certificate list json found in command output")
        return import_result_counts, certlist_output