feat(03-01): implement config snapshot subscriber with dedup and encryption

- NATS subscriber for config.snapshot.> on DEVICE_EVENTS stream
- Dedup by SHA256 hash against latest snapshot per device
- OpenBao Transit encryption before INSERT (plaintext never stored)
- Malformed/orphan messages acked and discarded safely
- Transit failure causes nak for NATS retry
- Prometheus metrics: ingested, dedup_skipped, errors, duration
- All 6 unit tests pass

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Jason Staack
2026-03-12 21:47:07 -05:00
parent 9d8274158a
commit 3ab9f27d49

View File

@@ -0,0 +1,256 @@
"""NATS JetStream subscriber for config snapshot ingestion from the Go poller.
Consumes config.snapshot.> messages, deduplicates by SHA256 hash,
encrypts config text via OpenBao Transit, and persists new snapshots
to the router_config_snapshots table.
Plaintext config is NEVER stored in PostgreSQL and NEVER logged.
"""
import asyncio
import json
import logging
import time
from datetime import datetime, timezone
from typing import Any, Optional
from prometheus_client import Counter, Histogram
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError, OperationalError
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.services.openbao_service import OpenBaoTransitService
logger = logging.getLogger(__name__)
# --- Prometheus metrics ---
# All collectors are created at import time; prometheus_client registers
# them with its default registry as a side effect of construction.

# Incremented once per snapshot row that is committed.
config_snapshot_ingested_total = Counter(
    name="config_snapshot_ingested_total",
    documentation="Total config snapshots successfully ingested",
)

# Incremented when the incoming hash equals the latest stored hash.
config_snapshot_dedup_skipped_total = Counter(
    name="config_snapshot_dedup_skipped_total",
    documentation="Total config snapshots skipped due to deduplication",
)

# Failure counter; ``error_type`` names the failure class reported by the
# message handler (e.g. "malformed", "db_error").
config_snapshot_errors_total = Counter(
    name="config_snapshot_errors_total",
    documentation="Total config snapshot ingestion errors",
    labelnames=["error_type"],
)

# Wall-clock seconds spent handling a message.
config_snapshot_ingestion_duration_seconds = Histogram(
    name="config_snapshot_ingestion_duration_seconds",
    documentation="Time to process a config snapshot message",
)

# --- Module state ---
# NATS connection owned by this module: set by the start function and
# reset to None by the stop function.
_nc: Optional[Any] = None
async def handle_config_snapshot(msg) -> None:
    """Handle one ``config.snapshot.>`` message from the Go poller.

    Processing steps:

    1. Parse the JSON payload. Invalid JSON, a payload that is not a JSON
       object, or a required field that is missing, empty or not a string
       -> ack + discard (counted as ``malformed``).
    2. Dedup: compare ``sha256_hash`` with the most recently collected
       stored hash for the device; equal -> ack + skip.
    3. Encrypt ``config_text`` via OpenBao Transit; any failure -> nak so
       NATS redelivers the message.
    4. INSERT the encrypted snapshot row and commit.
    5. ``IntegrityError`` on INSERT (treated as an orphan device) ->
       ack + discard.
    6. ``OperationalError`` during the dedup query or the INSERT -> nak
       for retry.

    The ingestion-duration histogram is observed only for snapshots that
    are actually stored. The config text itself is never passed to the
    logger by this function; only ids and exception text are.

    Args:
        msg: Message delivered by the NATS subscription. The handler reads
            ``msg.data`` and awaits ``msg.ack()`` / ``msg.nak()``.
    """
    start_time = time.monotonic()

    try:
        data = json.loads(msg.data)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        logger.warning("Malformed config snapshot message (bad JSON): %s", exc)
        config_snapshot_errors_total.labels(error_type="malformed").inc()
        await msg.ack()
        return

    # json.loads also accepts lists, strings, numbers, ...; those have no
    # .get(), and an uncaught AttributeError here would leave the message
    # un-acked and redelivered indefinitely.
    if not isinstance(data, dict):
        logger.warning(
            "Malformed config snapshot message (expected JSON object, got %s)",
            type(data).__name__,
        )
        config_snapshot_errors_total.labels(error_type="malformed").inc()
        await msg.ack()
        return

    device_id = data.get("device_id")
    tenant_id = data.get("tenant_id")
    sha256_hash = data.get("sha256_hash")
    config_text = data.get("config_text")

    # Every required field must be a non-empty string. A wrong-typed value
    # (e.g. numeric config_text) would otherwise fail inside the Transit
    # try-block below, be misreported as an encryption outage and be nak'd
    # on every redelivery even though retrying can never succeed.
    required_fields = (device_id, tenant_id, sha256_hash, config_text)
    if not all(isinstance(value, str) and value for value in required_fields):
        logger.warning(
            "Config snapshot message missing required fields (device_id=%s, tenant_id=%s)",
            device_id,
            tenant_id,
        )
        config_snapshot_errors_total.labels(error_type="malformed").inc()
        await msg.ack()
        return

    # collected_at is optional: a missing, non-string or unparseable value
    # falls back to the current UTC time. A trailing "Z" is rewritten to an
    # explicit offset before parsing.
    collected_at_raw = data.get("collected_at")
    try:
        collected_at = (
            datetime.fromisoformat(collected_at_raw.replace("Z", "+00:00"))
            if collected_at_raw
            else datetime.now(timezone.utc)
        )
    except (ValueError, AttributeError):
        collected_at = datetime.now(timezone.utc)

    async with AdminAsyncSessionLocal() as session:
        # --- Dedup check ---
        try:
            result = await session.execute(
                text(
                    "SELECT sha256_hash FROM router_config_snapshots "
                    "WHERE device_id = CAST(:device_id AS uuid) "
                    "ORDER BY collected_at DESC LIMIT 1"
                ),
                {"device_id": device_id},
            )
        except OperationalError as exc:
            # Same policy as a transient failure on INSERT: ask NATS to
            # redeliver instead of letting the exception escape the handler.
            # Nothing has been written yet; leaving the ``async with``
            # block closes the session.
            logger.warning(
                "Transient DB error checking latest snapshot for device %s: %s",
                device_id,
                exc,
            )
            config_snapshot_errors_total.labels(error_type="db_error").inc()
            await msg.nak()
            return

        latest_hash = result.scalar_one_or_none()
        if latest_hash == sha256_hash:
            logger.debug(
                "Duplicate config snapshot skipped for device %s",
                device_id,
            )
            config_snapshot_dedup_skipped_total.inc()
            await msg.ack()
            return

        # --- Encrypt via OpenBao Transit ---
        openbao = OpenBaoTransitService()
        try:
            encrypted_text = await openbao.encrypt(
                tenant_id, config_text.encode("utf-8")
            )
        except Exception as exc:
            # Broad on purpose: whatever the reason, the snapshot must not
            # be stored unencrypted, so hand the message back for retry.
            logger.warning(
                "Transit encrypt failed for device %s tenant %s: %s",
                device_id,
                tenant_id,
                exc,
            )
            config_snapshot_errors_total.labels(error_type="encrypt_unavailable").inc()
            await msg.nak()
            return
        finally:
            await openbao.close()

        # --- INSERT new snapshot ---
        try:
            await session.execute(
                text(
                    "INSERT INTO router_config_snapshots "
                    "(device_id, tenant_id, config_text, sha256_hash, collected_at) "
                    "VALUES (CAST(:device_id AS uuid), CAST(:tenant_id AS uuid), "
                    ":config_text, :sha256_hash, :collected_at)"
                ),
                {
                    "device_id": device_id,
                    "tenant_id": tenant_id,
                    # Only the Transit ciphertext is bound to config_text.
                    "config_text": encrypted_text,
                    "sha256_hash": sha256_hash,
                    "collected_at": collected_at,
                },
            )
            await session.commit()
        except IntegrityError:
            # NOTE(review): every IntegrityError is treated as a missing
            # device row; any other integrity violation on this table would
            # be discarded under the same label - confirm against the schema.
            logger.warning(
                "Orphan device_id %s (FK constraint violation) — discarding snapshot",
                device_id,
            )
            config_snapshot_errors_total.labels(error_type="orphan_device").inc()
            await session.rollback()
            await msg.ack()
            return
        except OperationalError as exc:
            logger.warning(
                "Transient DB error storing snapshot for device %s: %s",
                device_id,
                exc,
            )
            config_snapshot_errors_total.labels(error_type="db_error").inc()
            await session.rollback()
            await msg.nak()
            return

    logger.info(
        "Config snapshot stored for device %s tenant %s",
        device_id,
        tenant_id,
    )
    config_snapshot_ingested_total.inc()
    duration = time.monotonic() - start_time
    config_snapshot_ingestion_duration_seconds.observe(duration)
    await msg.ack()
async def _subscribe_with_retry(js) -> None:
    """Create the durable ``config.snapshot.>`` subscription, retrying on failure.

    Makes up to six attempts and sleeps five seconds after each failed
    attempt except the last. When every attempt fails the function logs a
    warning and returns normally; it does not raise.

    Args:
        js: JetStream context on which ``subscribe`` is awaited.
    """
    max_attempts = 6
    attempt = 0
    while attempt < max_attempts:
        attempt += 1
        try:
            await js.subscribe(
                "config.snapshot.>",
                cb=handle_config_snapshot,
                durable="config_snapshot_ingest",
                stream="DEVICE_EVENTS",
                manual_ack=True,
            )
            logger.info(
                "NATS: subscribed to config.snapshot.> (durable: config_snapshot_ingest)"
            )
            return
        except Exception as exc:
            if attempt >= max_attempts:
                # Out of attempts: report and stop without raising.
                logger.warning(
                    "NATS: giving up on config.snapshot.> after %d attempts: %s",
                    max_attempts,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready for config.snapshot (attempt %d/%d): %s — retrying in 5s",
                attempt,
                max_attempts,
                exc,
            )
            await asyncio.sleep(5)
async def start_config_snapshot_subscriber() -> Optional[Any]:
    """Open the NATS connection and register the config-snapshot subscription.

    The connection is stored in the module-level ``_nc`` so the stop
    function can drain it later, and is also returned to the caller.

    Returns:
        The value of ``_nc`` once the subscription attempt has finished.
    """
    import nats

    global _nc

    nats_url = settings.NATS_URL
    logger.info("NATS config-snapshot: connecting to %s", nats_url)

    connect_options = {
        "max_reconnect_attempts": -1,
        "reconnect_time_wait": 2,
    }
    _nc = await nats.connect(nats_url, **connect_options)
    logger.info("NATS config-snapshot: connected")

    await _subscribe_with_retry(_nc.jetstream())
    return _nc
async def stop_config_snapshot_subscriber() -> None:
    """Drain and close the module's NATS connection, if there is one.

    Shutdown is best-effort and never raises for ``Exception`` subclasses:
    a failed drain is logged as a warning and followed by a plain
    ``close()``; a failure of that fallback close is logged at debug level.
    The module-level ``_nc`` is reset to ``None`` in every case.
    """
    global _nc
    if not _nc:
        return

    try:
        logger.info("NATS config-snapshot: draining connection...")
        await _nc.drain()
        logger.info("NATS config-snapshot: connection closed")
    except Exception as exc:
        logger.warning("NATS config-snapshot: error during drain: %s", exc)
        try:
            await _nc.close()
        except Exception as close_exc:
            # Still best-effort, but leave a trace instead of silently
            # discarding the failure.
            logger.debug(
                "NATS config-snapshot: error during close after failed drain: %s",
                close_exc,
            )
    finally:
        _nc = None