feat(03-01): implement config snapshot subscriber with dedup and encryption
- NATS subscriber for config.snapshot.> on DEVICE_EVENTS stream
- Dedup by SHA256 hash against latest snapshot per device
- OpenBao Transit encryption before INSERT (plaintext never stored)
- Malformed/orphan messages acked and discarded safely
- Transit failure causes nak for NATS retry
- Prometheus metrics: ingested, dedup_skipped, errors, duration
- All 6 unit tests pass

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
256
backend/app/services/config_snapshot_subscriber.py
Normal file
256
backend/app/services/config_snapshot_subscriber.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""NATS JetStream subscriber for config snapshot ingestion from the Go poller.
|
||||
|
||||
Consumes config.snapshot.> messages, deduplicates by SHA256 hash,
|
||||
encrypts config text via OpenBao Transit, and persists new snapshots
|
||||
to the router_config_snapshots table.
|
||||
|
||||
Plaintext config is NEVER stored in PostgreSQL and NEVER logged.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from prometheus_client import Counter, Histogram
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.services.openbao_service import OpenBaoTransitService
|
||||
|
||||
logger = logging.getLogger(__name__)

# --- Prometheus metrics ---

# Incremented once per snapshot that is encrypted, inserted and committed.
config_snapshot_ingested_total = Counter(
    name="config_snapshot_ingested_total",
    documentation="Total config snapshots successfully ingested",
)

# Incremented when the incoming hash equals the latest stored hash.
config_snapshot_dedup_skipped_total = Counter(
    name="config_snapshot_dedup_skipped_total",
    documentation="Total config snapshots skipped due to deduplication",
)

# Labelled by failure category; the handler in this module supplies the
# ``error_type`` label value at each failure site.
config_snapshot_errors_total = Counter(
    name="config_snapshot_errors_total",
    documentation="Total config snapshot ingestion errors",
    labelnames=["error_type"],
)

# Observed by the handler with the elapsed monotonic time of a message.
config_snapshot_ingestion_duration_seconds = Histogram(
    name="config_snapshot_ingestion_duration_seconds",
    documentation="Time to process a config snapshot message",
)

# --- Module state ---

# NATS connection owned by this module; set by the start function and
# reset to None by the stop function.
_nc: Optional[Any] = None
|
||||
|
||||
|
||||
async def handle_config_snapshot(msg) -> None:
    """Handle a config.snapshot.> message from the Go poller.

    Processing steps:

    1. Parse the JSON payload. Payloads that are not valid JSON, are not a
       JSON object, or lack a non-empty string for any of ``device_id``,
       ``tenant_id``, ``sha256_hash`` or ``config_text`` are acked and
       discarded, because redelivering them cannot make them valid.
    2. Dedup check against the most recently collected hash for the device.
    3. Encrypt via OpenBao Transit; failure -> nak (NATS retry).
    4. INSERT a new router_config_snapshots row holding the ciphertext.
    5. IntegrityError on INSERT (treated as an orphan device) -> ack + discard.
    6. OperationalError from the dedup query or the INSERT -> nak for retry.

    The plaintext config is never written to the database or to the logs.

    Args:
        msg: Message object exposing ``data`` and awaitable ``ack()`` and
            ``nak()`` methods.
    """
    start_time = time.monotonic()

    try:
        data = json.loads(msg.data)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        logger.warning("Malformed config snapshot message (bad JSON): %s", exc)
        config_snapshot_errors_total.labels(error_type="malformed").inc()
        await msg.ack()
        return

    # json.loads also accepts lists, strings, numbers and null. Those have no
    # ``.get``; without this guard the handler would raise AttributeError and
    # never ack the message.
    if not isinstance(data, dict):
        logger.warning(
            "Malformed config snapshot message (expected JSON object, got %s)",
            type(data).__name__,
        )
        config_snapshot_errors_total.labels(error_type="malformed").inc()
        await msg.ack()
        return

    device_id = data.get("device_id")
    tenant_id = data.get("tenant_id")
    sha256_hash = data.get("sha256_hash")
    config_text = data.get("config_text")

    # Validate required fields: each must be a non-empty string. A non-string
    # config_text would otherwise fail on ``.encode`` inside the encrypt step
    # and be nak'd as if Transit were unavailable, on every redelivery.
    required_fields = (device_id, tenant_id, sha256_hash, config_text)
    if not all(isinstance(value, str) and value for value in required_fields):
        logger.warning(
            "Config snapshot message missing required fields (device_id=%s, tenant_id=%s)",
            device_id,
            tenant_id,
        )
        config_snapshot_errors_total.labels(error_type="malformed").inc()
        await msg.ack()
        return

    # Fall back to "now" (UTC) when the timestamp is absent or unparseable.
    # A trailing "Z" is rewritten to an explicit offset before parsing.
    collected_at_raw = data.get("collected_at")
    try:
        collected_at = (
            datetime.fromisoformat(collected_at_raw.replace("Z", "+00:00"))
            if collected_at_raw
            else datetime.now(timezone.utc)
        )
    except (ValueError, AttributeError):
        collected_at = datetime.now(timezone.utc)

    async with AdminAsyncSessionLocal() as session:
        # --- Dedup check ---
        try:
            result = await session.execute(
                text(
                    "SELECT sha256_hash FROM router_config_snapshots "
                    "WHERE device_id = CAST(:device_id AS uuid) "
                    "ORDER BY collected_at DESC LIMIT 1"
                ),
                {"device_id": device_id},
            )
            latest_hash = result.scalar_one_or_none()
        except OperationalError as exc:
            # Same policy as the INSERT below: a transient DB failure is
            # counted and nak'd so the message is retried.
            logger.warning(
                "Transient DB error during dedup check for device %s: %s",
                device_id,
                exc,
            )
            config_snapshot_errors_total.labels(error_type="db_error").inc()
            await session.rollback()
            await msg.nak()
            return

        if latest_hash == sha256_hash:
            logger.debug(
                "Duplicate config snapshot skipped for device %s",
                device_id,
            )
            config_snapshot_dedup_skipped_total.inc()
            await msg.ack()
            return

        # --- Encrypt via OpenBao Transit ---
        openbao = OpenBaoTransitService()
        try:
            encrypted_text = await openbao.encrypt(
                tenant_id, config_text.encode("utf-8")
            )
        except Exception as exc:
            logger.warning(
                "Transit encrypt failed for device %s tenant %s: %s",
                device_id,
                tenant_id,
                exc,
            )
            config_snapshot_errors_total.labels(error_type="encrypt_unavailable").inc()
            await msg.nak()
            return
        finally:
            # Always release the Transit client, on success and on failure.
            await openbao.close()

        # --- INSERT new snapshot ---
        try:
            await session.execute(
                text(
                    "INSERT INTO router_config_snapshots "
                    "(device_id, tenant_id, config_text, sha256_hash, collected_at) "
                    "VALUES (CAST(:device_id AS uuid), CAST(:tenant_id AS uuid), "
                    ":config_text, :sha256_hash, :collected_at)"
                ),
                {
                    "device_id": device_id,
                    "tenant_id": tenant_id,
                    # Only the ciphertext is bound; the plaintext never
                    # reaches the database.
                    "config_text": encrypted_text,
                    "sha256_hash": sha256_hash,
                    "collected_at": collected_at,
                },
            )
            await session.commit()
        except IntegrityError:
            logger.warning(
                "Orphan device_id %s (FK constraint violation) — discarding snapshot",
                device_id,
            )
            config_snapshot_errors_total.labels(error_type="orphan_device").inc()
            await session.rollback()
            await msg.ack()
            return
        except OperationalError as exc:
            logger.warning(
                "Transient DB error storing snapshot for device %s: %s",
                device_id,
                exc,
            )
            config_snapshot_errors_total.labels(error_type="db_error").inc()
            await session.rollback()
            await msg.nak()
            return

        logger.info(
            "Config snapshot stored for device %s tenant %s",
            device_id,
            tenant_id,
        )
        config_snapshot_ingested_total.inc()
        # Duration is recorded for successfully stored snapshots only.
        duration = time.monotonic() - start_time
        config_snapshot_ingestion_duration_seconds.observe(duration)
        await msg.ack()
|
||||
|
||||
|
||||
async def _subscribe_with_retry(js) -> None:
    """Create the durable config.snapshot.> subscription, retrying on failure.

    Makes up to six attempts and sleeps five seconds between them. When the
    last attempt also fails, a warning is logged and the function returns
    without raising.

    Args:
        js: JetStream context whose ``subscribe`` coroutine is awaited.
    """
    total_tries = 6
    tries_made = 0

    while tries_made < total_tries:
        tries_made += 1
        try:
            await js.subscribe(
                "config.snapshot.>",
                cb=handle_config_snapshot,
                durable="config_snapshot_ingest",
                stream="DEVICE_EVENTS",
                manual_ack=True,
            )
            logger.info(
                "NATS: subscribed to config.snapshot.> (durable: config_snapshot_ingest)"
            )
            return
        except Exception as err:
            # Out of attempts: report and stop without sleeping again.
            if tries_made >= total_tries:
                logger.warning(
                    "NATS: giving up on config.snapshot.> after %d attempts: %s",
                    total_tries,
                    err,
                )
                return

            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready for config.snapshot (attempt %d/%d): %s — retrying in 5s",
                tries_made,
                total_tries,
                err,
            )
            await asyncio.sleep(5)
|
||||
|
||||
|
||||
async def start_config_snapshot_subscriber() -> Optional[Any]:
    """Open the NATS connection and register the config.snapshot.> consumer.

    The connection is stored in the module-level ``_nc`` so that the stop
    function can drain it later.

    Returns:
        The module-level NATS connection, for shutdown management.
    """
    global _nc

    # Imported here rather than at module top, as in the original code.
    import nats

    connect_options = {
        "max_reconnect_attempts": -1,
        "reconnect_time_wait": 2,
    }

    logger.info("NATS config-snapshot: connecting to %s", settings.NATS_URL)
    _nc = await nats.connect(settings.NATS_URL, **connect_options)
    logger.info("NATS config-snapshot: connected")

    jetstream = _nc.jetstream()
    await _subscribe_with_retry(jetstream)

    return _nc
|
||||
|
||||
|
||||
async def stop_config_snapshot_subscriber() -> None:
    """Drain and close the NATS connection gracefully.

    Does nothing when no connection is stored in ``_nc``. If draining fails,
    the error is logged and a plain ``close()`` is attempted as a fallback.
    A failure of that fallback is logged at debug level instead of being
    silently swallowed. The module-level ``_nc`` is reset to ``None`` in
    every case, and this function never raises for drain or close errors.
    """
    global _nc

    if not _nc:
        return

    try:
        logger.info("NATS config-snapshot: draining connection...")
        await _nc.drain()
        logger.info("NATS config-snapshot: connection closed")
    except Exception as exc:
        logger.warning("NATS config-snapshot: error during drain: %s", exc)
        try:
            await _nc.close()
        except Exception as close_exc:
            # Best effort: shutdown must proceed, but keep a trace of why
            # the fallback close failed.
            logger.debug(
                "NATS config-snapshot: error during close after failed drain: %s",
                close_exc,
            )
    finally:
        _nc = None
|
||||
Reference in New Issue
Block a user