Files
the-other-dude/backend/app/services/ca_service.py
Jason Staack 06a41ca9bf fix(lint): resolve all ruff lint errors
Add ruff config to exclude alembic E402, SQLAlchemy F821, and pre-existing
E501 line-length issues. Auto-fix 69 unused imports and 2 f-strings without
placeholders. Manually fix 8 unused variables. Apply ruff format to 127 files.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-14 22:17:50 -05:00

458 lines
15 KiB
Python

"""Certificate Authority service — CA generation, device cert signing, lifecycle.
This module provides the core PKI functionality for the Internal Certificate
Authority feature. All functions receive an ``AsyncSession`` and an
``encryption_key`` as parameters (no direct Settings access) for testability.
Security notes:
- Private keys are encrypted before database storage (OpenBao Transit for new rows; legacy AES-256-GCM ciphertext remains readable as a fallback).
- PEM key material is NEVER logged.
- Device keys are decrypted only when needed for NATS transmission.
"""
from __future__ import annotations
import datetime
import ipaddress
import logging
from uuid import UUID
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.certificate import CertificateAuthority, DeviceCertificate
from app.services.crypto import (
decrypt_credentials_hybrid,
encrypt_credentials_transit,
)
logger = logging.getLogger(__name__)

# Device-certificate lifecycle state machine.  Each key is a current status;
# its value is the set of statuses the certificate may move to next.  A status
# mapped to an empty set is terminal.
_VALID_TRANSITIONS: dict[str, set[str]] = dict(
    issued={"deploying"},
    # Falling back to "issued" is the rollback path after a failed deploy.
    deploying={"deployed", "issued"},
    deployed={"expiring", "revoked", "superseded"},
    expiring={"expired", "revoked", "superseded"},
    expired={"superseded"},
    revoked=set(),
    superseded=set(),
)
# ---------------------------------------------------------------------------
# CA Generation
# ---------------------------------------------------------------------------
async def generate_ca(
    db: AsyncSession,
    tenant_id: UUID,
    common_name: str,
    validity_years: int,
    encryption_key: bytes,
) -> CertificateAuthority:
    """Generate a self-signed root CA for a tenant and stage it in the session.

    A fresh RSA-2048 key pair is created, a self-signed CA certificate
    (``BasicConstraints(ca=True, path_length=0)``, SHA-256 signature) is built
    from it, and a ``CertificateAuthority`` row is added to ``db`` and flushed.
    The session is not committed here.

    Args:
        db: Async database session.
        tenant_id: Tenant UUID — only one CA per tenant.
        common_name: CN for the CA certificate (e.g., "Portal Root CA").
        validity_years: Lifetime of the CA certificate in years; must be >= 1.
            A year is counted as 365 days.
        encryption_key: Not used by this function. The private key is encrypted
            through ``encrypt_credentials_transit``; the parameter is kept so
            the signature matches the other service functions.

    Returns:
        The newly created ``CertificateAuthority`` model instance.

    Raises:
        ValueError: If ``validity_years`` is less than 1, or if the tenant
            already has a CA.
    """
    # Fail fast on a useless lifetime: 0 would produce a certificate that is
    # already expired, and a negative value would only be rejected by the
    # certificate builder after the (expensive) RSA key generation.
    if validity_years < 1:
        raise ValueError(f"validity_years must be a positive integer, got {validity_years!r}")

    # Ensure one CA per tenant
    existing = await get_ca_for_tenant(db, tenant_id)
    if existing is not None:
        raise ValueError(
            f"Tenant {tenant_id} already has a CA (id={existing.id}). "
            "Delete the existing CA before creating a new one."
        )

    # Generate RSA 2048 key pair
    ca_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    ca_public_key = ca_key.public_key()

    # X.509 validity timestamps have whole-second resolution; dropping the
    # microseconds keeps the values stored on the model identical to the ones
    # encoded in the certificate.
    now = datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
    expiry = now + datetime.timedelta(days=365 * validity_years)

    # Self-signed: subject and issuer are the same name.
    subject = issuer = x509.Name(
        [
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
            x509.NameAttribute(NameOID.COMMON_NAME, common_name),
        ]
    )
    ca_cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(ca_public_key)
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(expiry)
        # path_length=0: this CA may sign end-entity certs but no sub-CAs.
        .add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True)
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        # sign_device_cert reads this extension to build the device
        # certificate's AuthorityKeyIdentifier.
        .add_extension(
            x509.SubjectKeyIdentifier.from_public_key(ca_public_key),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )

    # Serialize public cert to PEM
    cert_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")

    # Serialize private key to unencrypted PKCS#8 PEM, then encrypt it via the
    # Transit helper before it is stored. The plaintext PEM is never logged.
    key_pem = ca_key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.PKCS8,
        serialization.NoEncryption(),
    ).decode("utf-8")
    encrypted_key_transit = await encrypt_credentials_transit(key_pem, str(tenant_id))

    # SHA-256 fingerprint as colon-separated upper-case hex
    fingerprint = ":".join(f"{b:02X}" for b in ca_cert.fingerprint(hashes.SHA256()))

    # Serial number as upper-case hex string
    serial_hex = format(ca_cert.serial_number, "X")

    model = CertificateAuthority(
        tenant_id=tenant_id,
        common_name=common_name,
        cert_pem=cert_pem,
        encrypted_private_key=b"",  # Legacy column kept for schema compat
        encrypted_private_key_transit=encrypted_key_transit,
        serial_number=serial_hex,
        fingerprint_sha256=fingerprint,
        not_valid_before=now,
        not_valid_after=expiry,
    )
    db.add(model)
    await db.flush()

    logger.info(
        "Generated CA for tenant %s: cn=%s fingerprint=%s",
        tenant_id,
        common_name,
        fingerprint,
    )
    return model
# ---------------------------------------------------------------------------
# Device Certificate Signing
# ---------------------------------------------------------------------------
async def sign_device_cert(
    db: AsyncSession,
    ca: CertificateAuthority,
    device_id: UUID,
    hostname: str,
    ip_address: str,
    validity_days: int,
    encryption_key: bytes,
) -> DeviceCertificate:
    """Sign a per-device TLS server certificate using the tenant's CA.

    A fresh RSA-2048 key pair is generated for the device, a certificate is
    signed with the CA's private key (SHA-256), and a ``DeviceCertificate``
    row with ``status='issued'`` is added to ``db`` and flushed. The session
    is not committed here.

    Args:
        db: Async database session.
        ca: The tenant's CertificateAuthority model instance.
        device_id: UUID of the device receiving the cert.
        hostname: Device hostname — used as CN and SAN DNSName.
        ip_address: Device IP (IPv4 or IPv6 literal) — used as SAN IPAddress.
        validity_days: Certificate validity in days; must be >= 1.
        encryption_key: Key handed to ``decrypt_credentials_hybrid`` for
            reading the CA private key. The new device key itself is
            encrypted through ``encrypt_credentials_transit``.

    Returns:
        The newly created ``DeviceCertificate`` model instance (status='issued').

    Raises:
        ValueError: If ``validity_days`` is less than 1 or ``ip_address`` is
            not a valid IP address.
    """
    # Validate the cheap inputs before any secret material is decrypted or any
    # key is generated, so bad requests fail without touching the CA key.
    if validity_days < 1:
        raise ValueError(f"validity_days must be a positive integer, got {validity_days!r}")
    san_ip = ipaddress.ip_address(ip_address)  # raises ValueError if malformed

    # Decrypt CA private key (dual-read: Transit preferred, legacy fallback)
    ca_key_pem = await decrypt_credentials_hybrid(
        ca.encrypted_private_key_transit,
        ca.encrypted_private_key,
        str(ca.tenant_id),
        encryption_key,
    )
    ca_key = serialization.load_pem_private_key(ca_key_pem.encode("utf-8"), password=None)

    # Load CA certificate for issuer info and AuthorityKeyIdentifier
    ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8"))
    ca_ski = ca_cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier).value

    # Generate device RSA 2048 key
    device_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    # X.509 validity timestamps have whole-second resolution; dropping the
    # microseconds keeps the values stored on the model identical to the ones
    # encoded in the certificate.
    # NOTE(review): the requested lifetime is not capped at the CA's own
    # expiry — confirm whether callers guarantee that.
    now = datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
    expiry = now + datetime.timedelta(days=validity_days)

    device_cert = (
        x509.CertificateBuilder()
        .subject_name(
            x509.Name(
                [
                    x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
                    x509.NameAttribute(NameOID.COMMON_NAME, hostname),
                ]
            )
        )
        .issuer_name(ca_cert.subject)
        .public_key(device_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(expiry)
        # End-entity certificate: must not be usable as a CA.
        .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=True,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=False,
                crl_sign=False,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        .add_extension(
            x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]),
            critical=False,
        )
        # Clients may connect by IP or by name, so both go into the SAN.
        .add_extension(
            x509.SubjectAlternativeName(
                [
                    x509.IPAddress(san_ip),
                    x509.DNSName(hostname),
                ]
            ),
            critical=False,
        )
        # Link the device cert to the issuing CA via the CA's key identifier.
        .add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(ca_ski),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )

    # Serialize device cert and key to PEM
    cert_pem = device_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
    key_pem = device_key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.PKCS8,
        serialization.NoEncryption(),
    ).decode("utf-8")

    # Encrypt device private key via the Transit helper before storage
    encrypted_key_transit = await encrypt_credentials_transit(key_pem, str(ca.tenant_id))

    # SHA-256 fingerprint as colon-separated upper-case hex
    fingerprint = ":".join(f"{b:02X}" for b in device_cert.fingerprint(hashes.SHA256()))
    serial_hex = format(device_cert.serial_number, "X")

    model = DeviceCertificate(
        tenant_id=ca.tenant_id,
        device_id=device_id,
        ca_id=ca.id,
        common_name=hostname,
        serial_number=serial_hex,
        fingerprint_sha256=fingerprint,
        cert_pem=cert_pem,
        encrypted_private_key=b"",  # Legacy column kept for schema compat
        encrypted_private_key_transit=encrypted_key_transit,
        not_valid_before=now,
        not_valid_after=expiry,
        status="issued",
    )
    db.add(model)
    await db.flush()

    logger.info(
        "Signed device cert for device %s: cn=%s fingerprint=%s",
        device_id,
        hostname,
        fingerprint,
    )
    return model
# ---------------------------------------------------------------------------
# Queries
# ---------------------------------------------------------------------------
async def get_ca_for_tenant(
    db: AsyncSession,
    tenant_id: UUID,
) -> CertificateAuthority | None:
    """Look up the CertificateAuthority row whose tenant_id matches.

    Returns the matching model, or ``None`` when the tenant has no CA yet.
    """
    belongs_to_tenant = CertificateAuthority.tenant_id == tenant_id
    query = select(CertificateAuthority).where(belongs_to_tenant)
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
async def get_device_certs(
    db: AsyncSession,
    tenant_id: UUID,
    device_id: UUID | None = None,
) -> list[DeviceCertificate]:
    """Fetch a tenant's device certificates, newest first.

    Certificates whose status is ``"superseded"`` are always left out.

    Args:
        db: Async database session.
        tenant_id: Tenant whose certificates are listed.
        device_id: Optional device UUID; when given, only that device's
            certificates are returned.

    Returns:
        A list of ``DeviceCertificate`` models ordered by ``created_at``
        descending.
    """
    # Collect the filter criteria first; they are AND-ed together by where().
    criteria = [
        DeviceCertificate.tenant_id == tenant_id,
        DeviceCertificate.status != "superseded",
    ]
    if device_id is not None:
        criteria.append(DeviceCertificate.device_id == device_id)

    query = (
        select(DeviceCertificate)
        .where(*criteria)
        .order_by(DeviceCertificate.created_at.desc())
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
# ---------------------------------------------------------------------------
# Status Management
# ---------------------------------------------------------------------------
async def update_cert_status(
    db: AsyncSession,
    cert_id: UUID,
    status: str,
    deployed_at: datetime.datetime | None = None,
) -> DeviceCertificate:
    """Move a device certificate to a new lifecycle status.

    The certificate is loaded by id and the requested status is checked
    against ``_VALID_TRANSITIONS`` for its current status. On success the
    status and ``updated_at`` (current UTC time) are written and the session
    is flushed; the session is not committed here.

    Args:
        db: Async database session.
        cert_id: Certificate UUID.
        status: Requested new status.
        deployed_at: Only consulted when ``status`` is ``"deployed"``; if it is
            omitted, the new ``updated_at`` value is used instead.

    Returns:
        The updated ``DeviceCertificate`` model.

    Raises:
        ValueError: If no certificate has that id, or the current status does
            not permit moving to ``status``.
    """
    lookup = select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
    cert = (await db.execute(lookup)).scalar_one_or_none()
    if cert is None:
        raise ValueError(f"Device certificate {cert_id} not found")

    # An unknown current status has no outgoing transitions.
    allowed = _VALID_TRANSITIONS.get(cert.status, set())
    if status not in allowed:
        raise ValueError(
            f"Invalid status transition: {cert.status} -> {status}. "
            f"Allowed transitions from '{cert.status}': {allowed or 'none'}"
        )

    cert.status = status
    cert.updated_at = datetime.datetime.now(datetime.timezone.utc)
    if status == "deployed":
        # Prefer the caller's timestamp; otherwise reuse the update time.
        cert.deployed_at = cert.updated_at if deployed_at is None else deployed_at

    await db.flush()
    logger.info("Updated cert %s status to %s", cert_id, status)
    return cert
# ---------------------------------------------------------------------------
# Cert Data for Deployment
# ---------------------------------------------------------------------------
async def get_cert_for_deploy(
    db: AsyncSession,
    cert_id: UUID,
    encryption_key: bytes,
) -> tuple[str, str, str]:
    """Collect the PEM material needed to deploy a device certificate.

    Loads the device certificate and its issuing CA, decrypts the device's
    private key, and returns all three PEM strings together.

    Args:
        db: Async database session.
        cert_id: Device certificate UUID.
        encryption_key: Key passed to ``decrypt_credentials_hybrid`` for
            reading the device private key.

    Returns:
        ``(cert_pem, key_pem_decrypted, ca_cert_pem)``.

    Raises:
        ValueError: If the certificate, or the CA it references, does not
            exist.
    """
    cert_query = select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
    cert = (await db.execute(cert_query)).scalar_one_or_none()
    if cert is None:
        raise ValueError(f"Device certificate {cert_id} not found")

    # The issuing CA is only needed for its public certificate PEM.
    ca_query = select(CertificateAuthority).where(CertificateAuthority.id == cert.ca_id)
    ca = (await db.execute(ca_query)).scalar_one_or_none()
    if ca is None:
        raise ValueError(f"CA {cert.ca_id} not found for certificate {cert_id}")

    # Dual-read: the Transit ciphertext is preferred, the legacy column is the
    # fallback.
    key_pem = await decrypt_credentials_hybrid(
        cert.encrypted_private_key_transit,
        cert.encrypted_private_key,
        str(cert.tenant_id),
        encryption_key,
    )
    return cert.cert_pem, key_pem, ca.cert_pem