293 lines
10 KiB
Python
293 lines
10 KiB
Python
#!python3
|
|
# dlitz 2026
|
|
|
|
import json
|
|
import os
|
|
import re
|
|
import shlex
|
|
import subprocess
|
|
import tempfile
|
|
import textwrap
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from io import StringIO
|
|
from pathlib import Path, PosixPath
|
|
|
|
from cryptography import x509
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives.asymmetric.types import (
|
|
CertificatePublicKeyTypes,
|
|
PrivateKeyTypes,
|
|
PublicKeyTypes,
|
|
)
|
|
from cryptography.hazmat.primitives.serialization import (
|
|
BestAvailableEncryption,
|
|
Encoding,
|
|
PrivateFormat,
|
|
PublicFormat,
|
|
load_pem_private_key,
|
|
)
|
|
|
|
from .connector import Connector
|
|
|
|
# Matches one summary counter line printed by ``/certificate import``, for
# example "certificates-imported: 1".  Group ``name`` is the counter name and
# group ``count`` its decimal value; leading/trailing whitespace is tolerated.
import_result_regex = re.compile(
    r"""
    ^
    \s*
    (?P<name>
        certificates-imported
      | private-keys-imported
      | files-imported
      | decryption-failures
      | keys-with-no-certificate
    )
    :\s
    (?P<count>\d+)
    \s*
    $
    """,
    re.MULTILINE | re.VERBOSE,
)

# A whole line that looks like a JSON array (used to spot the serialized
# certificate list in command output).
json_list_regex = re.compile(r"^\[.*\]$")

# A SHA-256 fingerprint written as exactly 64 lower-case hex digits.
fingerprint_regex = re.compile(r"\A[0-9a-f]{64}\Z")
|
|
|
|
|
|
class CertInstallError(Exception):
    """Raised when a certificate could not be installed or activated on the device."""
|
|
|
|
|
|
@dataclass
|
|
class InstallInfo:
|
|
skid_hex: str # SubjectKeyInfo digest (hexadecimal)
|
|
fingerprint_hex: str | None # fingerprint of cert
|
|
remote_name: str # remote name
|
|
remote_filename: str # remote filename
|
|
content: bytes # file content
|
|
install_cmd: str # command used to install
|
|
|
|
|
|
class RouterOS:
    """Install and activate TLS certificates on a RouterOS device.

    All device interaction goes through ``self.connector``: commands are run
    with ``invoke_remote_command()`` and files are uploaded with
    ``create_remote_files()``.
    """

    def __init__(self, connector: Connector):
        # Transport used for every remote command and file upload.
        self.connector = connector

    def get_certificate_list(self):
        """Return the device's certificate table decoded from JSON.

        The device serializes ``/certificate print detail as-value`` to JSON,
        and the result is parsed with ``json.loads``.
        """
        cmdline = ":put [:serialize to=json [/certificate print detail as-value]]"
        result_json = self.connector.invoke_remote_command(cmdline, capture=True)
        return json.loads(result_json)

    @classmethod
    def cert_skid(cls, cert_obj) -> str:
        """Return the certificate's SubjectKeyIdentifier as upper-case hex.

        Raises whatever ``get_extension_for_class`` raises when the
        certificate has no SubjectKeyIdentifier extension.
        """
        skid_bin = cert_obj.extensions.get_extension_for_class(
            x509.SubjectKeyIdentifier
        ).value.key_identifier
        return skid_bin.hex().upper()

    @classmethod
    def cert_fingerprint(cls, cert_obj) -> str:
        """Return the certificate's SHA-256 fingerprint as lower-case hex."""
        return cert_obj.fingerprint(hashes.SHA256()).hex()

    def _make_cert_installinfo(self, cert_obj, privkey_obj=None) -> InstallInfo:
        """Build the upload file and import command for one certificate.

        The device object is named ``skid_<SKID>`` and the file is uploaded as
        ``fpr_<sha256>.pem``.  When *privkey_obj* is given, it is appended to
        the PEM as PKCS#8 encrypted with a random one-time passphrase, and
        that passphrase is passed to ``/certificate import``.
        """
        # RouterOS indexes certs by their SubjectKeyIdentifier (SKID), so use
        # that for naming.
        skid_hex = self.cert_skid(cert_obj)
        fpr_hex = self.cert_fingerprint(cert_obj)

        remote_name = f"skid_{skid_hex}"
        remote_filename = f"fpr_{fpr_hex}.pem"

        content = cert_obj.public_bytes(Encoding.PEM)
        passphrase = None
        if privkey_obj is not None:
            assert (
                cert_obj.public_key() == privkey_obj.public_key()
            ), "private key doesn't match certificate"

            # Fresh random passphrase so the uploaded file never contains the
            # private key in the clear.
            passphrase = os.urandom(32).hex()

            encrypted_private_key = privkey_obj.private_bytes(
                Encoding.PEM,
                PrivateFormat.PKCS8,
                BestAvailableEncryption(passphrase.encode()),
            )

            content = content.rstrip(b"\n") + b"\n" + encrypted_private_key

        install_cmd = f'/certificate import trusted=no no-key-export=yes name="{remote_name}" file-name="{remote_filename}"'
        if passphrase is not None:
            install_cmd += f' passphrase="{passphrase}"'

        return InstallInfo(
            skid_hex=skid_hex,
            fingerprint_hex=fpr_hex,
            remote_name=remote_name,
            remote_filename=remote_filename,
            content=content,
            install_cmd=install_cmd,
        )

    def _dedup_certs(self, cert_objs) -> list:
        """Drop duplicate certificates and reject distinct certs sharing a SKID.

        Certificates with equal SHA-256 fingerprints are treated as duplicates,
        and first-seen order is preserved.  Because certificates are named by
        SKID on the device, two *different* certificates with the same SKID
        cannot both be installed.

        Raises:
            ValueError: if two distinct certificates share a SKID.
        """
        # Dedup certificates by fingerprint.
        cert_by_fpr = {}
        for cert_obj in cert_objs:
            cert_by_fpr[self.cert_fingerprint(cert_obj)] = cert_obj

        # Look for duplicate SKIDs in different certs.  Each SKID must be
        # recorded once seen, otherwise the conflict check can never trigger.
        seen_skids = set()
        for cert_obj in cert_by_fpr.values():
            skid = self.cert_skid(cert_obj)
            if skid in seen_skids:
                raise ValueError(f"skid conflict: {skid!r}")
            seen_skids.add(skid)

        return list(cert_by_fpr.values())

    @staticmethod
    def _split_host_cert(cert_objs, public_key_obj):
        """Return ``(host_cert, other_certs)`` split by *public_key_obj*.

        Raises ValueError unless exactly one certificate carries the key.
        """
        host_certs = [c for c in cert_objs if c.public_key() == public_key_obj]
        if len(host_certs) != 1:
            raise ValueError(
                "expected exactly one certificate matching the private key,"
                f" found {len(host_certs)}"
            )
        chain_certs = [c for c in cert_objs if c.public_key() != public_key_obj]
        return host_certs[0], chain_certs

    def _verify_install(
        self,
        import_result_counts,
        certlist_output,
        private_key_skid,
        expected_skids,
        expected_fingerprints,
    ):
        """Check the parsed import output; raise CertInstallError on failure."""
        # import_result_counts is a defaultdict(int): absent counters read 0.
        if import_result_counts["decryption-failures"] != 0:
            raise CertInstallError(
                f"BUG: Private key decryption failed on install; {import_result_counts!r}"
            )
        if import_result_counts["keys-with-no-certificate"] != 0:
            raise CertInstallError(
                f"BUG: No certificate for private key; {import_result_counts!r}"
            )

        if certlist_output is None:
            raise CertInstallError("certificate list JSON not found in command output")

        found_fingerprints = set()
        found_skids = set()
        found_private_key_skids = set()
        for cert_detail in certlist_output:
            assert isinstance(cert_detail["fingerprint"], str)
            assert isinstance(cert_detail["skid"], str)
            assert isinstance(cert_detail["private-key"], bool)
            # Normalize case to match cert_skid() (upper) and
            # cert_fingerprint() (lower).
            skid = cert_detail["skid"].upper()
            found_fingerprints.add(cert_detail["fingerprint"].lower())
            found_skids.add(skid)
            if cert_detail["private-key"]:
                found_private_key_skids.add(skid)

        if private_key_skid not in found_private_key_skids:
            raise CertInstallError(
                f"private-key for skid {private_key_skid} was not installed"
            )
        missing_skids: set = expected_skids - found_skids
        missing_fingerprints: set = expected_fingerprints - found_fingerprints
        if missing_skids or missing_fingerprints:
            raise CertInstallError(
                f"some certs were not installed {missing_skids=!r}, {missing_fingerprints=!r}"
                f",\n{expected_skids=!r}"
                f",\n{found_skids=!r}"
                f",\n{expected_fingerprints=!r}"
                f",\n{found_fingerprints=!r}"
            )

    def install_key_and_certificates(
        self, key: str, cert: str, chain: str | None = None
    ):
        """Upload and import a private key with its certificate and chain.

        Args:
            key: PEM-encoded, unencrypted private key.
            cert: PEM certificate(s); one must match ``key``.
            chain: optional extra PEM certificates (e.g. intermediates).

        Returns:
            ``(fingerprint_hex, host_cert_obj)`` for the certificate matching
            ``key``.

        Raises:
            ValueError: if not exactly one certificate matches ``key``, or two
                distinct certificates share a SKID.
            CertInstallError: if the device output shows the import failed.
            Any exception raised after the remote command ran gets the command
            line and its output attached as notes.
        """
        if not chain:
            chain = ""
        private_key_obj = load_pem_private_key(key.encode(), None)
        cert_objs = x509.load_pem_x509_certificates((cert + "\n" + chain).encode())
        cert_objs = self._dedup_certs(cert_objs)

        # Find the certificate that carries our private key's public key.
        host_cert_obj, chain_cert_objs = self._split_host_cert(
            cert_objs, private_key_obj.public_key()
        )
        private_key_skid = self.cert_skid(host_cert_obj)

        # Build the list of InstallInfo objects, and sets of what we expect
        # to find on the device afterwards.
        install_list = [self._make_cert_installinfo(host_cert_obj, private_key_obj)]
        install_list += [
            self._make_cert_installinfo(cert_obj) for cert_obj in chain_cert_objs
        ]
        expected_skids = {info.skid_hex for info in install_list}
        expected_fingerprints = {info.fingerprint_hex for info in install_list} - {None}

        # Merge the commands into a single command line, ending with a
        # command that prints the installed certs.
        remote_commands = [
            # workaround for race condition where /certificate/import
            # sometimes doesn't see the recently-uploaded file
            ":delay 0.1",
            *(info.install_cmd for info in install_list),
            ":put [:serialize to=json [/certificate print detail as-value]]",
        ]
        remote_cmdline = "\n".join(remote_commands)

        remote_files = {info.remote_filename: info.content for info in install_list}

        # Upload the files and run the command.
        self.connector.create_remote_files(remote_files, "")
        cmd_output = self.connector.invoke_remote_command(
            remote_cmdline, capture=True, text=True
        )
        try:
            import_result_counts, certlist_output = self._parse_cmd_output(cmd_output)
            self._verify_install(
                import_result_counts,
                certlist_output,
                private_key_skid,
                expected_skids,
                expected_fingerprints,
            )
            return self.cert_fingerprint(host_cert_obj), host_cert_obj
        except Exception as exc:
            exc.add_note(f"remote_cmdline={remote_cmdline!r}")
            exc.add_note(f"cmd_output={cmd_output!r}")
            raise

    def use_certificate(self, fingerprint: str):
        """Point the api-ssl and www-ssl services at the cert with *fingerprint*.

        *fingerprint* must be 64 lower-case hex digits.  After the change, the
        services' certificate fingerprints are read back and compared.

        Raises:
            ValueError: if *fingerprint* is malformed.
            CertInstallError: if either service does not report the new cert.
        """
        if not fingerprint_regex.fullmatch(fingerprint):
            raise ValueError(f"illegal fingerprint {fingerprint!r}")
        cmds = [
            f'/ip/service set api-ssl,www-ssl certificate=[/certificate find where fingerprint="{fingerprint}"]',
            ":put [:serialize to=json value={[/certificate get [/ip/service get api-ssl certificate] fingerprint],[/certificate get [/ip/service get www-ssl certificate] fingerprint]}]",
        ]
        remote_cmdline = "\n".join(cmds)
        raw_result = self.connector.invoke_remote_command(
            remote_cmdline, text=True, capture=True
        )
        result = json.loads(raw_result)
        # NOTE(review): assumes the JSON decodes as [[api_fpr, www_fpr]];
        # confirm against real device output.  fromhex()+hex() normalizes the
        # value to lower-case hex for comparison.
        api_fingerprint = bytes.fromhex(result[0][0]).hex()
        www_fingerprint = bytes.fromhex(result[0][1]).hex()
        missing = []
        if api_fingerprint != fingerprint:
            missing.append("api-ssl")
        if www_fingerprint != fingerprint:
            missing.append("www-ssl")
        if missing:
            raise CertInstallError(
                f"certs didn't get installed to {','.join(missing)}:"
                f" {api_fingerprint=!r}"
                f", {www_fingerprint=!r}"
                f", {fingerprint=!r}"
            )

    def _parse_cmd_output(self, cmd_output: str):
        """Split install command output into ``(counts, certificate_list)``.

        Counter lines from ``/certificate import`` are summed per counter name
        into a ``defaultdict(int)``.  The single JSON-array line is decoded as
        the certificate list, which is ``None`` if no such line appears.

        Raises:
            ValueError: on a non-blank line that matches neither form.
            AssertionError: if two JSON-array lines are seen.
        """
        import_result_counts = defaultdict(int)
        certlist_output = None

        for rawline in StringIO(cmd_output):
            line = rawline.rstrip("\r\n")

            if m := import_result_regex.match(line):
                import_result_counts[m["name"]] += int(m["count"])
            elif json_list_regex.match(line):
                assert certlist_output is None, "certificate list json received twice?"
                certlist_output = json.loads(line)
            elif not line.strip():
                # blank line
                pass
            else:
                raise ValueError(f"unable to parse output line: {line!r}")
        return import_result_counts, certlist_output
|