# Provenance: this module was introduced by commit
# 3ab9f27d493d45bd83a088b14e05cef5b15002f8 (Jason Staack, 2026-03-12),
# "feat(03-01): implement config snapshot subscriber with dedup and
# encryption", as backend/app/services/config_snapshot_subscriber.py.
"""NATS JetStream subscriber for config snapshot ingestion from the Go poller.

Consumes config.snapshot.> messages, deduplicates by SHA256 hash,
encrypts config text via OpenBao Transit, and persists new snapshots
to the router_config_snapshots table.

Plaintext config is NEVER stored in PostgreSQL and NEVER logged.
"""

import asyncio
import json
import logging
import time
import uuid
from datetime import datetime, timezone
from typing import Any, Optional

from prometheus_client import Counter, Histogram
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError, OperationalError

from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.services.openbao_service import OpenBaoTransitService

logger = logging.getLogger(__name__)

# --- Prometheus metrics ---

config_snapshot_ingested_total = Counter(
    "config_snapshot_ingested_total",
    "Total config snapshots successfully ingested",
)
config_snapshot_dedup_skipped_total = Counter(
    "config_snapshot_dedup_skipped_total",
    "Total config snapshots skipped due to deduplication",
)
config_snapshot_errors_total = Counter(
    "config_snapshot_errors_total",
    "Total config snapshot ingestion errors",
    ["error_type"],
)
config_snapshot_ingestion_duration_seconds = Histogram(
    "config_snapshot_ingestion_duration_seconds",
    "Time to process a config snapshot message",
)

# --- Module state ---

# NATS connection owned by this subscriber; set by
# start_config_snapshot_subscriber() and cleared by
# stop_config_snapshot_subscriber().
_nc: Optional[Any] = None


def _is_uuid(value: Any) -> bool:
    """Return True when *value* is a string that parses as a UUID."""
    if not isinstance(value, str):
        return False
    try:
        uuid.UUID(value)
    except ValueError:
        return False
    return True


def _parse_collected_at(raw: Any) -> datetime:
    """Return a timezone-aware collection timestamp.

    Falls back to the current UTC time when *raw* is missing, is not a
    string, or is not a parseable ISO-8601 timestamp.
    """
    if not raw or not isinstance(raw, str):
        return datetime.now(timezone.utc)
    try:
        # A trailing "Z" is rewritten to an explicit offset before parsing.
        parsed = datetime.fromisoformat(raw.replace("Z", "+00:00"))
    except ValueError:
        return datetime.now(timezone.utc)
    if parsed.tzinfo is None:
        # NOTE(review): offset-less timestamps are assumed to be UTC so that
        # naive and aware values are never mixed — confirm against the poller.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


async def _discard_malformed(msg, message: str, *args: Any) -> None:
    """Log, count and ack a message that can never be processed.

    Acking (rather than nak'ing) prevents a permanently bad payload from
    being redelivered.
    """
    logger.warning(message, *args)
    config_snapshot_errors_total.labels(error_type="malformed").inc()
    await msg.ack()


async def handle_config_snapshot(msg) -> None:
    """Handle a config.snapshot.> message from the Go poller.

    1. Parse JSON payload; malformed (bad JSON, non-object payload, missing
       or wrongly typed fields, non-UUID ids) -> ack + discard
    2. Dedup check against latest hash for device
    3. Encrypt via OpenBao Transit; failure -> nak (NATS retry)
    4. INSERT new RouterConfigSnapshot row
    5. FK violation (orphan device) -> ack + discard
    6. Transient DB error (dedup SELECT or INSERT) -> nak for retry

    The ingestion-duration histogram is observed only for snapshots that
    are actually stored.
    """
    start_time = time.monotonic()

    try:
        data = json.loads(msg.data)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        await _discard_malformed(
            msg, "Malformed config snapshot message (bad JSON): %s", exc
        )
        return

    # Valid JSON that is not an object (list, string, number, ...) has no
    # fields to read; treat it like any other malformed payload.
    if not isinstance(data, dict):
        await _discard_malformed(
            msg,
            "Malformed config snapshot message (expected JSON object, got %s)",
            type(data).__name__,
        )
        return

    device_id = data.get("device_id")
    tenant_id = data.get("tenant_id")
    sha256_hash = data.get("sha256_hash")
    config_text = data.get("config_text")

    # Validate required fields
    if not device_id or not tenant_id or not sha256_hash or not config_text:
        await _discard_malformed(
            msg,
            "Config snapshot message missing required fields (device_id=%s, tenant_id=%s)",
            device_id,
            tenant_id,
        )
        return

    # Wrong types or non-UUID ids would fail later (in CAST(... AS uuid) or
    # in config_text.encode()) on every redelivery, so reject them here.
    if not (
        _is_uuid(device_id)
        and _is_uuid(tenant_id)
        and isinstance(sha256_hash, str)
        and isinstance(config_text, str)
    ):
        await _discard_malformed(
            msg,
            "Config snapshot message has invalid field values (device_id=%s, tenant_id=%s)",
            device_id,
            tenant_id,
        )
        return

    collected_at = _parse_collected_at(data.get("collected_at"))

    async with AdminAsyncSessionLocal() as session:
        # --- Dedup check ---
        try:
            result = await session.execute(
                text(
                    "SELECT sha256_hash FROM router_config_snapshots "
                    "WHERE device_id = CAST(:device_id AS uuid) "
                    "ORDER BY collected_at DESC LIMIT 1"
                ),
                {"device_id": device_id},
            )
        except OperationalError as exc:
            # This is the first statement to touch the database, so
            # connection-level failures surface here; retry via NATS.
            logger.warning(
                "Transient DB error checking latest snapshot for device %s: %s",
                device_id,
                exc,
            )
            config_snapshot_errors_total.labels(error_type="db_error").inc()
            await msg.nak()
            return
        latest_hash = result.scalar_one_or_none()

        if latest_hash == sha256_hash:
            logger.debug(
                "Duplicate config snapshot skipped for device %s",
                device_id,
            )
            config_snapshot_dedup_skipped_total.inc()
            await msg.ack()
            return

        # --- Encrypt via OpenBao Transit ---
        openbao = OpenBaoTransitService()
        try:
            encrypted_text = await openbao.encrypt(
                tenant_id, config_text.encode("utf-8")
            )
        except Exception as exc:
            # Any encryption failure means the snapshot must not be stored;
            # nak so NATS redelivers it later.
            logger.warning(
                "Transit encrypt failed for device %s tenant %s: %s",
                device_id,
                tenant_id,
                exc,
            )
            config_snapshot_errors_total.labels(
                error_type="encrypt_unavailable"
            ).inc()
            await msg.nak()
            return
        finally:
            await openbao.close()

        # --- INSERT new snapshot ---
        try:
            await session.execute(
                text(
                    "INSERT INTO router_config_snapshots "
                    "(device_id, tenant_id, config_text, sha256_hash, collected_at) "
                    "VALUES (CAST(:device_id AS uuid), CAST(:tenant_id AS uuid), "
                    ":config_text, :sha256_hash, :collected_at)"
                ),
                {
                    "device_id": device_id,
                    "tenant_id": tenant_id,
                    # Only the Transit ciphertext is ever written.
                    "config_text": encrypted_text,
                    "sha256_hash": sha256_hash,
                    "collected_at": collected_at,
                },
            )
            await session.commit()
        except IntegrityError:
            # NOTE(review): every IntegrityError is treated as a missing
            # device (FK violation); confirm no other constraint can fire.
            logger.warning(
                "Orphan device_id %s (FK constraint violation) — discarding snapshot",
                device_id,
            )
            config_snapshot_errors_total.labels(error_type="orphan_device").inc()
            await session.rollback()
            await msg.ack()
            return
        except OperationalError as exc:
            logger.warning(
                "Transient DB error storing snapshot for device %s: %s",
                device_id,
                exc,
            )
            config_snapshot_errors_total.labels(error_type="db_error").inc()
            await session.rollback()
            await msg.nak()
            return

    logger.info(
        "Config snapshot stored for device %s tenant %s",
        device_id,
        tenant_id,
    )
    config_snapshot_ingested_total.inc()
    duration = time.monotonic() - start_time
    config_snapshot_ingestion_duration_seconds.observe(duration)
    await msg.ack()


async def _subscribe_with_retry(
    js, *, max_attempts: int = 6, retry_delay: float = 5.0
) -> None:
    """Subscribe to config.snapshot.> with a durable consumer, retrying on failure.

    Args:
        js: JetStream context used to create the subscription.
        max_attempts: Total number of subscribe attempts before giving up.
        retry_delay: Seconds to sleep between failed attempts.

    Failures are logged, never raised: after the final attempt the function
    returns without a subscription.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            await js.subscribe(
                "config.snapshot.>",
                cb=handle_config_snapshot,
                durable="config_snapshot_ingest",
                stream="DEVICE_EVENTS",
                manual_ack=True,
            )
        except Exception as exc:
            if attempt >= max_attempts:
                logger.warning(
                    "NATS: giving up on config.snapshot.> after %d attempts: %s",
                    max_attempts,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready for config.snapshot "
                "(attempt %d/%d): %s — retrying in %gs",
                attempt,
                max_attempts,
                exc,
                retry_delay,
            )
            await asyncio.sleep(retry_delay)
        else:
            logger.info(
                "NATS: subscribed to config.snapshot.> (durable: config_snapshot_ingest)"
            )
            return


async def start_config_snapshot_subscriber() -> Optional[Any]:
    """Connect to NATS and start the config.snapshot.> subscription.

    Returns the NATS connection for shutdown management. The connection is
    returned even when the subscription could not be established; that
    case is only logged by the retry helper.
    """
    import nats

    global _nc

    logger.info("NATS config-snapshot: connecting to %s", settings.NATS_URL)
    _nc = await nats.connect(
        settings.NATS_URL,
        max_reconnect_attempts=-1,
        reconnect_time_wait=2,
    )
    logger.info("NATS config-snapshot: connected")

    js = _nc.jetstream()
    await _subscribe_with_retry(js)

    return _nc


async def stop_config_snapshot_subscriber() -> None:
    """Drain and close the NATS connection gracefully.

    Does nothing when no connection is held. If draining fails, a plain
    close is attempted as a best-effort fallback; the module-level
    connection reference is cleared in every case.
    """
    global _nc
    if _nc is None:
        return
    try:
        logger.info("NATS config-snapshot: draining connection...")
        await _nc.drain()
        logger.info("NATS config-snapshot: connection closed")
    except Exception as exc:
        logger.warning("NATS config-snapshot: error during drain: %s", exc)
        try:
            await _nc.close()
        except Exception as close_exc:
            # Best effort only: shutdown must not fail because of this.
            logger.debug(
                "NATS config-snapshot: error during close: %s", close_exc
            )
    finally:
        _nc = None