Files
mtik-cert-pusher/mtik_cert_pusher/ssl_util.py
2026-03-19 13:45:14 -06:00

181 lines
6.2 KiB
Python

#!python3
# dlitz 2025-2026
import contextlib
import fcntl
import os
import re
import subprocess
from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization import (
BestAvailableEncryption,
Encoding,
KeySerializationEncryption,
PrivateFormat,
load_pem_private_key,
load_pem_public_key,
pkcs12,
)
def pipe_with_buffered_data(data, *, text=False, encoding=None, errors=None):
    """Return a file open for reading a pipe that already holds *data*.

    All of *data* is written into the pipe's kernel buffer and the write end
    is closed, so reading the returned file yields *data* followed by EOF.
    Nobody reads the pipe while it is being filled, so when *data* would not
    fit in the current pipe buffer the buffer is enlarged with
    ``F_SETPIPE_SZ`` first; otherwise the write would block forever.
    ``F_GETPIPE_SZ``/``F_SETPIPE_SZ`` are Linux-specific.  Intended for small
    payloads (the commented-out ``SSLUtil.export_pkcs12`` used it to hand a
    passphrase to ``openssl`` through an inherited descriptor).

    :param data: ``str`` when *text* is true, ``bytes`` otherwise.
    :param text: open both pipe ends in text mode.
    :param encoding: text encoding used for both pipe ends (text mode only).
    :param errors: encoding error handler for both pipe ends (text mode only).
    :returns: the read end of the pipe; the caller is responsible for closing it.
    :raises OSError: if the pipe buffer cannot be made large enough.
    """
    # Forward encoding/errors so the pipe uses the codec the caller asked for.
    rfile, wfile = _open_pipes(text=text, encoding=encoding, errors=errors)
    try:
        with wfile:  # closing the write end lets the reader see EOF
            if text:
                # Size the buffer in encoded bytes, using the same codec the
                # text wrapper will use for the actual write.
                bdata = data.encode(encoding=wfile.encoding, errors=wfile.errors)
            else:
                bdata = data
            wfd = wfile.fileno()
            pipe_buffer_size = fcntl.fcntl(wfd, fcntl.F_GETPIPE_SZ)
            if pipe_buffer_size < len(bdata):
                pipe_buffer_size = fcntl.fcntl(wfd, fcntl.F_SETPIPE_SZ, len(bdata))
            if pipe_buffer_size < len(bdata):
                # Explicit check rather than ``assert``: under ``python -O`` an
                # assert disappears and the write below would deadlock.
                raise OSError(
                    f"pipe buffer too small: {pipe_buffer_size} < {len(bdata)} bytes"
                )
            wfile.write(data)
    except BaseException:
        rfile.close()
        raise
    return rfile
def _open_pipes(*, text=False, encoding=None, errors=None):
rfd, wfd = os.pipe()
rfile = wfile = None
try:
rfile = open(rfd, "r" if text else "rb", encoding=encoding, errors=errors)
wfile = open(wfd, "w" if text else "wb", encoding=encoding, errors=errors)
return rfile, wfile
except:
if rfile is not None:
rfile.close()
else:
os.close(rfd)
if wfile is not None:
wfile.close()
else:
os.close(wfd)
raise
class ResultParseError(Exception):
    """Exception type for a result that could not be parsed.

    Meaning inferred from the name only: nothing in this module raises it,
    so confirm the intended use against its callers.
    """
class SSLUtil:
    """Certificate and private-key helpers built on the ``cryptography`` package.

    PEM inputs are accepted as ``str`` and converted with ``str.encode()``
    before parsing; parse errors from ``cryptography`` propagate unchanged.
    """

    # OpenSSL command-line program; referenced only by the disabled
    # ``export_pkcs12`` implementation kept (commented out) below.
    openssl_prog = "openssl"

    @staticmethod
    def _load_pem_cert(pem_cert: str):
        """Parse a single PEM-encoded certificate given as ``str``."""
        return x509.load_pem_x509_certificate(pem_cert.encode())

    @staticmethod
    def _skid_from_public_key_obj(pubkey_obj) -> bytes:
        """Return the SubjectKeyIdentifier bytes derived from *pubkey_obj*.

        ``SubjectKeyIdentifier.from_public_key`` takes the SHA-1 digest of the
        key's subjectPublicKey bit string, so the result is 20 bytes long.
        """
        result = x509.SubjectKeyIdentifier.from_public_key(pubkey_obj).key_identifier
        assert isinstance(result, bytes)
        assert len(result) == 20
        return result

    def cert_fingerprint_sha256(self, pem_cert: str) -> bytes:
        """Return the SHA-256 fingerprint of the PEM certificate, as 32 raw bytes."""
        result = self._load_pem_cert(pem_cert).fingerprint(hashes.SHA256())
        assert isinstance(result, bytes) and len(result) == 32, (result,)
        return result

    def cert_serial(self, cert: str) -> int:
        """Return the serial number of the PEM certificate, as an integer.

        Might be negative.
        """
        result = self._load_pem_cert(cert).serial_number
        assert isinstance(result, int), ("serial number not integer?", result)
        return result

    def cert_skid(self, pem_cert: str) -> bytes:
        """Return the SubjectKeyIdentifier of the PEM certificate, as bytes.

        The identifier is computed from the certificate's public key; any
        SubjectKeyIdentifier extension in the certificate is not consulted.
        """
        cert_obj = self._load_pem_cert(pem_cert)
        return self._skid_from_public_key_obj(cert_obj.public_key())

    def skid_from_pubkey(self, pubkey: str) -> bytes:
        """Return the SubjectKeyIdentifier for a PEM-encoded public key."""
        return self._skid_from_public_key_obj(load_pem_public_key(pubkey.encode()))

    def skid_from_private_key(self, private_key: str) -> bytes:
        """Return the SubjectKeyIdentifier for an unencrypted PEM private key.

        The identifier is derived from the key's public half.
        """
        # ``password`` is a required argument of load_pem_private_key(); the
        # original call omitted it and always raised TypeError.  None means
        # the key must be unencrypted, as in
        # create_pkcs12_from_key_and_certificates below.
        private_key_obj = load_pem_private_key(private_key.encode(), password=None)
        return self._skid_from_public_key_obj(private_key_obj.public_key())

    def create_pkcs12_from_key_and_certificates(
        self,
        *,
        name: str | None = None,
        key: str,
        cert: str,
        chain: str | None = None,
        passphrase: str,
    ) -> bytes:
        """Bundle a private key, certificate and optional chain into PKCS#12.

        :param name: friendly name stored in the bundle; defaults to
            ``"SKID:<hex> FPR:<hex>"``, built from the SKID of the
            certificate's public key and its SHA-256 fingerprint.
        :param key: unencrypted PEM private key.
        :param cert: PEM certificate to bundle with *key*.
        :param chain: PEM text holding additional CA certificates; ``None``,
            empty or whitespace-only means no CAs are included.
        :param passphrase: password protecting the generated bundle (see
            ``pkcs12_encryption_algorithm``).
        :returns: the serialized PKCS#12 data (non-empty bytes).
        """
        private_key_obj = load_pem_private_key(key.encode(), password=None)
        cert_obj = self._load_pem_cert(cert)
        # A blank chain is treated like an absent one instead of being handed
        # to the PEM parser.
        if chain and chain.strip():
            chain_objs = x509.load_pem_x509_certificates(chain.encode())
        else:
            chain_objs = []
        if name is None:
            skid = self._skid_from_public_key_obj(cert_obj.public_key()).hex()
            fingerprint = cert_obj.fingerprint(hashes.SHA256()).hex()
            name = f"SKID:{skid} FPR:{fingerprint}"
        result = pkcs12.serialize_key_and_certificates(
            name=name.encode(),
            key=private_key_obj,
            cert=cert_obj,
            cas=chain_objs,
            encryption_algorithm=self.pkcs12_encryption_algorithm(passphrase.encode()),
        )
        assert isinstance(result, bytes)
        assert result
        return result

    # NOTE: disabled alternative that shells out to ``openssl pkcs12 -export``
    # and feeds the passphrase through a pre-filled pipe; kept for reference.
    # def export_pkcs12(self, privkey_data, cert_data, chain_data, passphrase):
    #     assert re.search(
    #         r"^-----BEGIN(?: (.*))? PRIVATE KEY-----\n", privkey_data, re.M
    #     )
    #     assert re.search(r"^-----BEGIN CERTIFICATE-----\n", cert_data, re.M)
    #     assert "PRIVATE KEY" not in cert_data
    #     assert chain_data is None or "PRIVATE KEY" not in chain_data
    #
    #     fullchain_data = cert_data + "\n" + (chain_data or "") + "\n"
    #
    #     all_data = privkey_data + "\n" + fullchain_data
    #
    #     with pipe_with_buffered_data(passphrase, text=True) as passphrase_r:
    #         cmd = [
    #             self.openssl_prog,
    #             "pkcs12",
    #             "-export",
    #             "-passout",
    #             f"fd:{passphrase_r.fileno():d}",
    #             # "-macalg", "SHA256",
    #             # "-keypbe", "AES-256-CBC",
    #             # "-certpbe", "NONE",
    #         ]
    #         return subprocess.check_output(
    #             cmd, pass_fds=[passphrase_r.fileno()], input=all_data.encode()
    #         )

    def pkcs12_encryption_algorithm(
        self, passphrase: bytes
    ) -> KeySerializationEncryption:
        """Return the PKCS#12 encryption settings for *passphrase*.

        Keys and certificates are encrypted with PBES2 (SHA-256 / AES-256-CBC),
        the key derivation runs 20000 rounds, and the MAC uses SHA-256.
        """
        # NOTE(review): these parameters are presumably chosen to match what
        # the target device's PKCS#12 importer accepts — confirm before
        # changing them.
        return (
            PrivateFormat.PKCS12.encryption_builder()
            .key_cert_algorithm(pkcs12.PBES.PBESv2SHA256AndAES256CBC)
            .kdf_rounds(20000)
            .hmac_hash(hashes.SHA256())
            .build(passphrase)
        )