feat: The Other Dude v9.0.1 — full-featured email system
ci: add GitHub Pages deployment workflow for docs site Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
462
backend/app/services/ca_service.py
Normal file
462
backend/app/services/ca_service.py
Normal file
@@ -0,0 +1,462 @@
|
||||
"""Certificate Authority service — CA generation, device cert signing, lifecycle.
|
||||
|
||||
This module provides the core PKI functionality for the Internal Certificate
|
||||
Authority feature. All functions receive an ``AsyncSession`` and an
|
||||
``encryption_key`` as parameters (no direct Settings access) for testability.
|
||||
|
||||
Security notes:
|
||||
- CA private keys are encrypted with AES-256-GCM before database storage.
|
||||
- PEM key material is NEVER logged.
|
||||
- Device keys are decrypted only when needed for NATS transmission.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import ipaddress
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.certificate import CertificateAuthority, DeviceCertificate
|
||||
from app.services.crypto import (
|
||||
decrypt_credentials_hybrid,
|
||||
encrypt_credentials_transit,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Valid status transitions for the device certificate lifecycle.
|
||||
_VALID_TRANSITIONS: dict[str, set[str]] = {
|
||||
"issued": {"deploying"},
|
||||
"deploying": {"deployed", "issued"}, # issued = rollback on deploy failure
|
||||
"deployed": {"expiring", "revoked", "superseded"},
|
||||
"expiring": {"expired", "revoked", "superseded"},
|
||||
"expired": {"superseded"},
|
||||
"revoked": set(),
|
||||
"superseded": set(),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CA Generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def generate_ca(
|
||||
db: AsyncSession,
|
||||
tenant_id: UUID,
|
||||
common_name: str,
|
||||
validity_years: int,
|
||||
encryption_key: bytes,
|
||||
) -> CertificateAuthority:
|
||||
"""Generate a self-signed root CA for a tenant.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
tenant_id: Tenant UUID — only one CA per tenant.
|
||||
common_name: CN for the CA certificate (e.g., "Portal Root CA").
|
||||
validity_years: How many years the CA cert is valid.
|
||||
encryption_key: 32-byte AES-256-GCM key for encrypting the CA private key.
|
||||
|
||||
Returns:
|
||||
The newly created ``CertificateAuthority`` model instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If the tenant already has a CA.
|
||||
"""
|
||||
# Ensure one CA per tenant
|
||||
existing = await get_ca_for_tenant(db, tenant_id)
|
||||
if existing is not None:
|
||||
raise ValueError(
|
||||
f"Tenant {tenant_id} already has a CA (id={existing.id}). "
|
||||
"Delete the existing CA before creating a new one."
|
||||
)
|
||||
|
||||
# Generate RSA 2048 key pair
|
||||
ca_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
expiry = now + datetime.timedelta(days=365 * validity_years)
|
||||
|
||||
subject = issuer = x509.Name([
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
|
||||
])
|
||||
|
||||
ca_cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(subject)
|
||||
.issuer_name(issuer)
|
||||
.public_key(ca_key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now)
|
||||
.not_valid_after(expiry)
|
||||
.add_extension(
|
||||
x509.BasicConstraints(ca=True, path_length=0), critical=True
|
||||
)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
content_commitment=False,
|
||||
key_encipherment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
key_cert_sign=True,
|
||||
crl_sign=True,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.SubjectKeyIdentifier.from_public_key(ca_key.public_key()),
|
||||
critical=False,
|
||||
)
|
||||
.sign(ca_key, hashes.SHA256())
|
||||
)
|
||||
|
||||
# Serialize public cert to PEM
|
||||
cert_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
|
||||
|
||||
# Serialize private key to PEM, then encrypt with OpenBao Transit
|
||||
key_pem = ca_key.private_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PrivateFormat.PKCS8,
|
||||
serialization.NoEncryption(),
|
||||
).decode("utf-8")
|
||||
encrypted_key_transit = await encrypt_credentials_transit(key_pem, str(tenant_id))
|
||||
|
||||
# Compute SHA-256 fingerprint (colon-separated hex)
|
||||
fingerprint_bytes = ca_cert.fingerprint(hashes.SHA256())
|
||||
fingerprint = ":".join(f"{b:02X}" for b in fingerprint_bytes)
|
||||
|
||||
# Serial number as hex string
|
||||
serial_hex = format(ca_cert.serial_number, "X")
|
||||
|
||||
model = CertificateAuthority(
|
||||
tenant_id=tenant_id,
|
||||
common_name=common_name,
|
||||
cert_pem=cert_pem,
|
||||
encrypted_private_key=b"", # Legacy column kept for schema compat
|
||||
encrypted_private_key_transit=encrypted_key_transit,
|
||||
serial_number=serial_hex,
|
||||
fingerprint_sha256=fingerprint,
|
||||
not_valid_before=now,
|
||||
not_valid_after=expiry,
|
||||
)
|
||||
db.add(model)
|
||||
await db.flush()
|
||||
|
||||
logger.info(
|
||||
"Generated CA for tenant %s: cn=%s fingerprint=%s",
|
||||
tenant_id,
|
||||
common_name,
|
||||
fingerprint,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Device Certificate Signing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def sign_device_cert(
|
||||
db: AsyncSession,
|
||||
ca: CertificateAuthority,
|
||||
device_id: UUID,
|
||||
hostname: str,
|
||||
ip_address: str,
|
||||
validity_days: int,
|
||||
encryption_key: bytes,
|
||||
) -> DeviceCertificate:
|
||||
"""Sign a per-device TLS certificate using the tenant's CA.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
ca: The tenant's CertificateAuthority model instance.
|
||||
device_id: UUID of the device receiving the cert.
|
||||
hostname: Device hostname — used as CN and SAN DNSName.
|
||||
ip_address: Device IP — used as SAN IPAddress.
|
||||
validity_days: Certificate validity in days.
|
||||
encryption_key: 32-byte AES-256-GCM key for encrypting the device private key.
|
||||
|
||||
Returns:
|
||||
The newly created ``DeviceCertificate`` model instance (status='issued').
|
||||
"""
|
||||
# Decrypt CA private key (dual-read: Transit preferred, legacy fallback)
|
||||
ca_key_pem = await decrypt_credentials_hybrid(
|
||||
ca.encrypted_private_key_transit,
|
||||
ca.encrypted_private_key,
|
||||
str(ca.tenant_id),
|
||||
encryption_key,
|
||||
)
|
||||
ca_key = serialization.load_pem_private_key(
|
||||
ca_key_pem.encode("utf-8"), password=None
|
||||
)
|
||||
|
||||
# Load CA certificate for issuer info and AuthorityKeyIdentifier
|
||||
ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8"))
|
||||
|
||||
# Generate device RSA 2048 key
|
||||
device_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
expiry = now + datetime.timedelta(days=validity_days)
|
||||
|
||||
device_cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(
|
||||
x509.Name([
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, hostname),
|
||||
])
|
||||
)
|
||||
.issuer_name(ca_cert.subject)
|
||||
.public_key(device_key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now)
|
||||
.not_valid_after(expiry)
|
||||
.add_extension(
|
||||
x509.BasicConstraints(ca=False, path_length=None), critical=True
|
||||
)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
content_commitment=False,
|
||||
key_encipherment=True,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
key_cert_sign=False,
|
||||
crl_sign=False,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.add_extension(
|
||||
x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.SubjectAlternativeName([
|
||||
x509.IPAddress(ipaddress.ip_address(ip_address)),
|
||||
x509.DNSName(hostname),
|
||||
]),
|
||||
critical=False,
|
||||
)
|
||||
.add_extension(
|
||||
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
|
||||
ca_cert.extensions.get_extension_for_class(
|
||||
x509.SubjectKeyIdentifier
|
||||
).value
|
||||
),
|
||||
critical=False,
|
||||
)
|
||||
.sign(ca_key, hashes.SHA256())
|
||||
)
|
||||
|
||||
# Serialize device cert and key to PEM
|
||||
cert_pem = device_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
|
||||
key_pem = device_key.private_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PrivateFormat.PKCS8,
|
||||
serialization.NoEncryption(),
|
||||
).decode("utf-8")
|
||||
|
||||
# Encrypt device private key via OpenBao Transit
|
||||
encrypted_key_transit = await encrypt_credentials_transit(key_pem, str(ca.tenant_id))
|
||||
|
||||
# Compute fingerprint
|
||||
fingerprint_bytes = device_cert.fingerprint(hashes.SHA256())
|
||||
fingerprint = ":".join(f"{b:02X}" for b in fingerprint_bytes)
|
||||
|
||||
serial_hex = format(device_cert.serial_number, "X")
|
||||
|
||||
model = DeviceCertificate(
|
||||
tenant_id=ca.tenant_id,
|
||||
device_id=device_id,
|
||||
ca_id=ca.id,
|
||||
common_name=hostname,
|
||||
serial_number=serial_hex,
|
||||
fingerprint_sha256=fingerprint,
|
||||
cert_pem=cert_pem,
|
||||
encrypted_private_key=b"", # Legacy column kept for schema compat
|
||||
encrypted_private_key_transit=encrypted_key_transit,
|
||||
not_valid_before=now,
|
||||
not_valid_after=expiry,
|
||||
status="issued",
|
||||
)
|
||||
db.add(model)
|
||||
await db.flush()
|
||||
|
||||
logger.info(
|
||||
"Signed device cert for device %s: cn=%s fingerprint=%s",
|
||||
device_id,
|
||||
hostname,
|
||||
fingerprint,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def get_ca_for_tenant(
|
||||
db: AsyncSession,
|
||||
tenant_id: UUID,
|
||||
) -> CertificateAuthority | None:
|
||||
"""Return the tenant's CA, or None if not yet initialized."""
|
||||
result = await db.execute(
|
||||
select(CertificateAuthority).where(
|
||||
CertificateAuthority.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_device_certs(
|
||||
db: AsyncSession,
|
||||
tenant_id: UUID,
|
||||
device_id: UUID | None = None,
|
||||
) -> list[DeviceCertificate]:
|
||||
"""List device certificates for a tenant.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
tenant_id: Tenant UUID.
|
||||
device_id: If provided, filter to certs for this device only.
|
||||
|
||||
Returns:
|
||||
List of DeviceCertificate models (excludes superseded by default).
|
||||
"""
|
||||
stmt = (
|
||||
select(DeviceCertificate)
|
||||
.where(DeviceCertificate.tenant_id == tenant_id)
|
||||
.where(DeviceCertificate.status != "superseded")
|
||||
)
|
||||
if device_id is not None:
|
||||
stmt = stmt.where(DeviceCertificate.device_id == device_id)
|
||||
stmt = stmt.order_by(DeviceCertificate.created_at.desc())
|
||||
result = await db.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Status Management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def update_cert_status(
|
||||
db: AsyncSession,
|
||||
cert_id: UUID,
|
||||
status: str,
|
||||
deployed_at: datetime.datetime | None = None,
|
||||
) -> DeviceCertificate:
|
||||
"""Update a device certificate's lifecycle status.
|
||||
|
||||
Validates that the transition is allowed by the state machine:
|
||||
issued -> deploying -> deployed -> expiring -> expired
|
||||
\\-> revoked
|
||||
\\-> superseded
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
cert_id: Certificate UUID.
|
||||
status: New status value.
|
||||
deployed_at: Timestamp to set when transitioning to 'deployed'.
|
||||
|
||||
Returns:
|
||||
The updated DeviceCertificate model.
|
||||
|
||||
Raises:
|
||||
ValueError: If the certificate is not found or the transition is invalid.
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
|
||||
)
|
||||
cert = result.scalar_one_or_none()
|
||||
if cert is None:
|
||||
raise ValueError(f"Device certificate {cert_id} not found")
|
||||
|
||||
allowed = _VALID_TRANSITIONS.get(cert.status, set())
|
||||
if status not in allowed:
|
||||
raise ValueError(
|
||||
f"Invalid status transition: {cert.status} -> {status}. "
|
||||
f"Allowed transitions from '{cert.status}': {allowed or 'none'}"
|
||||
)
|
||||
|
||||
cert.status = status
|
||||
cert.updated_at = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if status == "deployed" and deployed_at is not None:
|
||||
cert.deployed_at = deployed_at
|
||||
elif status == "deployed":
|
||||
cert.deployed_at = cert.updated_at
|
||||
|
||||
await db.flush()
|
||||
|
||||
logger.info(
|
||||
"Updated cert %s status to %s",
|
||||
cert_id,
|
||||
status,
|
||||
)
|
||||
return cert
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cert Data for Deployment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def get_cert_for_deploy(
|
||||
db: AsyncSession,
|
||||
cert_id: UUID,
|
||||
encryption_key: bytes,
|
||||
) -> tuple[str, str, str]:
|
||||
"""Retrieve and decrypt certificate data for NATS deployment.
|
||||
|
||||
Returns the device cert PEM, decrypted device key PEM, and the CA cert
|
||||
PEM — everything needed to push to a device via the Go poller.
|
||||
|
||||
Args:
|
||||
db: Async database session.
|
||||
cert_id: Device certificate UUID.
|
||||
encryption_key: 32-byte AES-256-GCM key to decrypt the device private key.
|
||||
|
||||
Returns:
|
||||
Tuple of (cert_pem, key_pem_decrypted, ca_cert_pem).
|
||||
|
||||
Raises:
|
||||
ValueError: If the certificate or its CA is not found.
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
|
||||
)
|
||||
cert = result.scalar_one_or_none()
|
||||
if cert is None:
|
||||
raise ValueError(f"Device certificate {cert_id} not found")
|
||||
|
||||
# Fetch the CA for the ca_cert_pem
|
||||
ca_result = await db.execute(
|
||||
select(CertificateAuthority).where(
|
||||
CertificateAuthority.id == cert.ca_id
|
||||
)
|
||||
)
|
||||
ca = ca_result.scalar_one_or_none()
|
||||
if ca is None:
|
||||
raise ValueError(f"CA {cert.ca_id} not found for certificate {cert_id}")
|
||||
|
||||
# Decrypt device private key (dual-read: Transit preferred, legacy fallback)
|
||||
key_pem = await decrypt_credentials_hybrid(
|
||||
cert.encrypted_private_key_transit,
|
||||
cert.encrypted_private_key,
|
||||
str(cert.tenant_id),
|
||||
encryption_key,
|
||||
)
|
||||
|
||||
return cert.cert_pem, key_pem, ca.cert_pem
|
||||
Reference in New Issue
Block a user