ci: add GitHub Pages deployment workflow for docs site Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
325 lines
11 KiB
Python
325 lines
11 KiB
Python
"""Key hierarchy management service for zero-knowledge architecture.
|
|
|
|
Provides CRUD operations for encrypted key bundles (UserKeySet),
|
|
append-only audit logging (KeyAccessLog), and OpenBao Transit
|
|
tenant key provisioning with credential migration.
|
|
"""
|
|
|
|
import logging
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.key_set import KeyAccessLog, UserKeySet
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def store_user_key_set(
|
|
db: AsyncSession,
|
|
user_id: UUID,
|
|
tenant_id: UUID | None,
|
|
encrypted_private_key: bytes,
|
|
private_key_nonce: bytes,
|
|
encrypted_vault_key: bytes,
|
|
vault_key_nonce: bytes,
|
|
public_key: bytes,
|
|
pbkdf2_salt: bytes,
|
|
hkdf_salt: bytes,
|
|
pbkdf2_iterations: int = 650000,
|
|
) -> UserKeySet:
|
|
"""Store encrypted key bundle during registration.
|
|
|
|
Creates a new UserKeySet for the user. Each user has exactly one
|
|
key set (UNIQUE constraint on user_id).
|
|
|
|
Args:
|
|
db: Async database session.
|
|
user_id: The user's UUID.
|
|
tenant_id: The user's tenant UUID (None for super_admin).
|
|
encrypted_private_key: RSA private key wrapped by AUK (AES-GCM).
|
|
private_key_nonce: 12-byte AES-GCM nonce for private key.
|
|
encrypted_vault_key: Tenant vault key wrapped by user's public key.
|
|
vault_key_nonce: 12-byte AES-GCM nonce for vault key.
|
|
public_key: RSA-2048 public key in SPKI format.
|
|
pbkdf2_salt: 32-byte salt for PBKDF2 key derivation.
|
|
hkdf_salt: 32-byte salt for HKDF Secret Key derivation.
|
|
pbkdf2_iterations: PBKDF2 iteration count (default 650000).
|
|
|
|
Returns:
|
|
The created UserKeySet instance.
|
|
"""
|
|
# Remove any existing key set (e.g. from a failed prior upgrade attempt)
|
|
from sqlalchemy import delete
|
|
await db.execute(delete(UserKeySet).where(UserKeySet.user_id == user_id))
|
|
|
|
key_set = UserKeySet(
|
|
user_id=user_id,
|
|
tenant_id=tenant_id,
|
|
encrypted_private_key=encrypted_private_key,
|
|
private_key_nonce=private_key_nonce,
|
|
encrypted_vault_key=encrypted_vault_key,
|
|
vault_key_nonce=vault_key_nonce,
|
|
public_key=public_key,
|
|
pbkdf2_salt=pbkdf2_salt,
|
|
hkdf_salt=hkdf_salt,
|
|
pbkdf2_iterations=pbkdf2_iterations,
|
|
)
|
|
db.add(key_set)
|
|
await db.flush()
|
|
return key_set
|
|
|
|
|
|
async def get_user_key_set(
|
|
db: AsyncSession, user_id: UUID
|
|
) -> UserKeySet | None:
|
|
"""Retrieve encrypted key bundle for login response.
|
|
|
|
Args:
|
|
db: Async database session.
|
|
user_id: The user's UUID.
|
|
|
|
Returns:
|
|
The UserKeySet if found, None otherwise.
|
|
"""
|
|
result = await db.execute(
|
|
select(UserKeySet).where(UserKeySet.user_id == user_id)
|
|
)
|
|
return result.scalar_one_or_none()
|
|
|
|
|
|
async def log_key_access(
|
|
db: AsyncSession,
|
|
tenant_id: UUID,
|
|
user_id: UUID | None,
|
|
action: str,
|
|
resource_type: str | None = None,
|
|
resource_id: str | None = None,
|
|
key_version: int | None = None,
|
|
ip_address: str | None = None,
|
|
device_id: UUID | None = None,
|
|
justification: str | None = None,
|
|
correlation_id: str | None = None,
|
|
) -> None:
|
|
"""Append to immutable key_access_log.
|
|
|
|
This table is append-only (INSERT+SELECT only via RLS policy).
|
|
No UPDATE or DELETE is permitted.
|
|
|
|
Args:
|
|
db: Async database session.
|
|
tenant_id: The tenant UUID for RLS isolation.
|
|
user_id: The user who performed the action (None for system ops).
|
|
action: Action description (e.g., 'create_key_set', 'decrypt_vault_key').
|
|
resource_type: Optional resource type being accessed.
|
|
resource_id: Optional resource identifier.
|
|
key_version: Optional key version involved.
|
|
ip_address: Optional client IP address.
|
|
device_id: Optional device UUID for credential access tracking.
|
|
justification: Optional justification for the access (e.g., 'api_backup').
|
|
correlation_id: Optional correlation ID for request tracing.
|
|
"""
|
|
log_entry = KeyAccessLog(
|
|
tenant_id=tenant_id,
|
|
user_id=user_id,
|
|
action=action,
|
|
resource_type=resource_type,
|
|
resource_id=resource_id,
|
|
key_version=key_version,
|
|
ip_address=ip_address,
|
|
device_id=device_id,
|
|
justification=justification,
|
|
correlation_id=correlation_id,
|
|
)
|
|
db.add(log_entry)
|
|
await db.flush()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# OpenBao Transit tenant key provisioning and credential migration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def provision_tenant_key(db: AsyncSession, tenant_id: UUID) -> str:
|
|
"""Provision an OpenBao Transit key for a tenant and update the tenant record.
|
|
|
|
Idempotent: if the key already exists in OpenBao, it's a no-op on the
|
|
OpenBao side. The tenant record is always updated with the key name.
|
|
|
|
Args:
|
|
db: Async database session (admin engine, no RLS).
|
|
tenant_id: Tenant UUID.
|
|
|
|
Returns:
|
|
The key name (tenant_{uuid}).
|
|
"""
|
|
from app.models.tenant import Tenant
|
|
from app.services.openbao_service import get_openbao_service
|
|
|
|
openbao = get_openbao_service()
|
|
key_name = f"tenant_{tenant_id}"
|
|
|
|
await openbao.create_tenant_key(str(tenant_id))
|
|
|
|
# Update tenant record with key name
|
|
result = await db.execute(
|
|
select(Tenant).where(Tenant.id == tenant_id)
|
|
)
|
|
tenant = result.scalar_one_or_none()
|
|
if tenant:
|
|
tenant.openbao_key_name = key_name
|
|
await db.flush()
|
|
|
|
logger.info(
|
|
"Provisioned OpenBao Transit key for tenant %s (key=%s)",
|
|
tenant_id,
|
|
key_name,
|
|
)
|
|
return key_name
|
|
|
|
|
|
async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
|
|
"""Re-encrypt all legacy credentials for a tenant from AES-256-GCM to Transit.
|
|
|
|
Migrates device credentials, CA private keys, device cert private keys,
|
|
and notification channel secrets. Already-migrated items are skipped.
|
|
|
|
Args:
|
|
db: Async database session (admin engine, no RLS).
|
|
tenant_id: Tenant UUID.
|
|
|
|
Returns:
|
|
Dict with counts: {"devices": N, "cas": N, "certs": N, "channels": N, "errors": N}
|
|
"""
|
|
from app.config import settings
|
|
from app.models.alert import NotificationChannel
|
|
from app.models.certificate import CertificateAuthority, DeviceCertificate
|
|
from app.models.device import Device
|
|
from app.services.crypto import decrypt_credentials
|
|
from app.services.openbao_service import get_openbao_service
|
|
|
|
openbao = get_openbao_service()
|
|
legacy_key = settings.get_encryption_key_bytes()
|
|
tid = str(tenant_id)
|
|
|
|
counts = {"devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0}
|
|
|
|
# --- Migrate device credentials ---
|
|
result = await db.execute(
|
|
select(Device).where(
|
|
Device.tenant_id == tenant_id,
|
|
Device.encrypted_credentials.isnot(None),
|
|
(Device.encrypted_credentials_transit.is_(None) | (Device.encrypted_credentials_transit == "")),
|
|
)
|
|
)
|
|
for device in result.scalars().all():
|
|
try:
|
|
plaintext = decrypt_credentials(device.encrypted_credentials, legacy_key)
|
|
device.encrypted_credentials_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
|
|
counts["devices"] += 1
|
|
except Exception as e:
|
|
logger.error("Failed to migrate device %s credentials: %s", device.id, e)
|
|
counts["errors"] += 1
|
|
|
|
# --- Migrate CA private keys ---
|
|
result = await db.execute(
|
|
select(CertificateAuthority).where(
|
|
CertificateAuthority.tenant_id == tenant_id,
|
|
CertificateAuthority.encrypted_private_key.isnot(None),
|
|
(CertificateAuthority.encrypted_private_key_transit.is_(None) | (CertificateAuthority.encrypted_private_key_transit == "")),
|
|
)
|
|
)
|
|
for ca in result.scalars().all():
|
|
try:
|
|
plaintext = decrypt_credentials(ca.encrypted_private_key, legacy_key)
|
|
ca.encrypted_private_key_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
|
|
counts["cas"] += 1
|
|
except Exception as e:
|
|
logger.error("Failed to migrate CA %s private key: %s", ca.id, e)
|
|
counts["errors"] += 1
|
|
|
|
# --- Migrate device cert private keys ---
|
|
result = await db.execute(
|
|
select(DeviceCertificate).where(
|
|
DeviceCertificate.tenant_id == tenant_id,
|
|
DeviceCertificate.encrypted_private_key.isnot(None),
|
|
(DeviceCertificate.encrypted_private_key_transit.is_(None) | (DeviceCertificate.encrypted_private_key_transit == "")),
|
|
)
|
|
)
|
|
for cert in result.scalars().all():
|
|
try:
|
|
plaintext = decrypt_credentials(cert.encrypted_private_key, legacy_key)
|
|
cert.encrypted_private_key_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
|
|
counts["certs"] += 1
|
|
except Exception as e:
|
|
logger.error("Failed to migrate cert %s private key: %s", cert.id, e)
|
|
counts["errors"] += 1
|
|
|
|
# --- Migrate notification channel secrets ---
|
|
result = await db.execute(
|
|
select(NotificationChannel).where(
|
|
NotificationChannel.tenant_id == tenant_id,
|
|
)
|
|
)
|
|
for ch in result.scalars().all():
|
|
migrated_any = False
|
|
try:
|
|
# SMTP password
|
|
if ch.smtp_password and not ch.smtp_password_transit:
|
|
plaintext = decrypt_credentials(ch.smtp_password, legacy_key)
|
|
ch.smtp_password_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
|
|
migrated_any = True
|
|
if migrated_any:
|
|
counts["channels"] += 1
|
|
except Exception as e:
|
|
logger.error("Failed to migrate channel %s secrets: %s", ch.id, e)
|
|
counts["errors"] += 1
|
|
|
|
await db.flush()
|
|
|
|
logger.info(
|
|
"Tenant %s credential migration complete: %s",
|
|
tenant_id,
|
|
counts,
|
|
)
|
|
return counts
|
|
|
|
|
|
async def provision_existing_tenants(db: AsyncSession) -> dict:
|
|
"""Provision OpenBao Transit keys for all existing tenants and migrate credentials.
|
|
|
|
Called on app startup to ensure all tenants have Transit keys.
|
|
Idempotent -- running multiple times is safe (already-migrated items are skipped).
|
|
|
|
Args:
|
|
db: Async database session (admin engine, no RLS).
|
|
|
|
Returns:
|
|
Summary dict with total counts across all tenants.
|
|
"""
|
|
from app.models.tenant import Tenant
|
|
|
|
result = await db.execute(select(Tenant))
|
|
tenants = result.scalars().all()
|
|
|
|
total = {"tenants": len(tenants), "devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0}
|
|
|
|
for tenant in tenants:
|
|
try:
|
|
await provision_tenant_key(db, tenant.id)
|
|
counts = await migrate_tenant_credentials(db, tenant.id)
|
|
total["devices"] += counts["devices"]
|
|
total["cas"] += counts["cas"]
|
|
total["certs"] += counts["certs"]
|
|
total["channels"] += counts["channels"]
|
|
total["errors"] += counts["errors"]
|
|
except Exception as e:
|
|
logger.error("Failed to provision/migrate tenant %s: %s", tenant.id, e)
|
|
total["errors"] += 1
|
|
|
|
await db.commit()
|
|
|
|
logger.info("Existing tenant provisioning complete: %s", total)
|
|
return total
|