Files
the-other-dude/backend/app/services/config_snapshot_subscriber.py
Jason Staack eb76343d04 feat(05-01): wire diff generation into snapshot subscriber
- Add RETURNING id to snapshot INSERT for new_snapshot_id capture
- Call generate_and_store_diff after successful commit (best-effort)
- Outer try/except safety net ensures snapshot ack never blocked by diff
- Update subscriber tests to mock diff service

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-12 22:32:40 -05:00

269 lines
8.8 KiB
Python

"""NATS JetStream subscriber for config snapshot ingestion from the Go poller.
Consumes config.snapshot.> messages, deduplicates by SHA256 hash,
encrypts config text via OpenBao Transit, and persists new snapshots
to the router_config_snapshots table.
Plaintext config is NEVER stored in PostgreSQL and NEVER logged.
"""
import asyncio
import json
import logging
import time
from datetime import datetime, timezone
from typing import Any, Optional
from prometheus_client import Counter, Histogram
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError, OperationalError
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.services.config_diff_service import generate_and_store_diff
from app.services.openbao_service import OpenBaoTransitService
logger = logging.getLogger(__name__)
# --- Prometheus metrics ---
# Registered on the default registry at import time; incremented by
# handle_config_snapshot below.
config_snapshot_ingested_total = Counter(
    "config_snapshot_ingested_total",
    "Total config snapshots successfully ingested",
)
config_snapshot_dedup_skipped_total = Counter(
    "config_snapshot_dedup_skipped_total",
    "Total config snapshots skipped due to deduplication",
)
# error_type label values used by the handler: "malformed",
# "encrypt_unavailable", "orphan_device", "db_error".
config_snapshot_errors_total = Counter(
    "config_snapshot_errors_total",
    "Total config snapshot ingestion errors",
    ["error_type"],
)
config_snapshot_ingestion_duration_seconds = Histogram(
    "config_snapshot_ingestion_duration_seconds",
    "Time to process a config snapshot message",
)
# --- Module state ---
# Live NATS connection; set by start_config_snapshot_subscriber() and
# drained/cleared by stop_config_snapshot_subscriber().
_nc: Optional[Any] = None
async def handle_config_snapshot(msg) -> None:
    """Handle a config.snapshot.> message from the Go poller.

    1. Parse JSON payload; malformed -> ack + discard
    2. Dedup check against latest hash for device
    3. Encrypt via OpenBao Transit; failure -> nak (NATS retry)
    4. INSERT new RouterConfigSnapshot row
    5. FK violation (orphan device) -> ack + discard
    6. Transient DB error -> nak for retry

    Ack/nak semantics: ack permanently discards the message, so it is used
    for failures that cannot succeed on redelivery; nak requests JetStream
    redelivery and is reserved for plausibly-transient failures (Transit
    outage, DB blips). The plaintext config exists only in memory here —
    the stored row holds Transit ciphertext, and neither the plaintext nor
    its hash is ever logged.
    """
    start_time = time.monotonic()
    # --- 1. Parse ---
    try:
        data = json.loads(msg.data)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        # A malformed payload can never succeed on redelivery: ack + drop.
        logger.warning("Malformed config snapshot message (bad JSON): %s", exc)
        config_snapshot_errors_total.labels(error_type="malformed").inc()
        await msg.ack()
        return
    device_id = data.get("device_id")
    tenant_id = data.get("tenant_id")
    sha256_hash = data.get("sha256_hash")
    config_text = data.get("config_text")
    # Validate required fields
    if not device_id or not tenant_id or not sha256_hash or not config_text:
        # Only identifiers are logged — never the hash or config body.
        logger.warning(
            "Config snapshot message missing required fields (device_id=%s, tenant_id=%s)",
            device_id,
            tenant_id,
        )
        config_snapshot_errors_total.labels(error_type="malformed").inc()
        await msg.ack()
        return
    # collected_at is an ISO-8601 string; the Z->+00:00 replace covers a
    # trailing "Z" suffix (datetime.fromisoformat rejects it before
    # Python 3.11). Missing or unparseable values fall back to now(UTC).
    collected_at_raw = data.get("collected_at")
    try:
        collected_at = datetime.fromisoformat(
            collected_at_raw.replace("Z", "+00:00")
        ) if collected_at_raw else datetime.now(timezone.utc)
    except (ValueError, AttributeError):
        collected_at = datetime.now(timezone.utc)
    async with AdminAsyncSessionLocal() as session:
        # --- 2. Dedup check ---
        # Only the MOST RECENT snapshot's hash is compared, so a config
        # that changes and later reverts is still recorded (A -> B -> A
        # produces three rows). NOTE(review): a DB error on this SELECT
        # is not caught — it propagates, leaving the message neither
        # ack'd nor nak'd, so redelivery depends on the consumer's
        # ack-wait timeout (unlike the INSERT path, which naks
        # explicitly). Confirm that asymmetry is intended.
        result = await session.execute(
            text(
                "SELECT sha256_hash FROM router_config_snapshots "
                "WHERE device_id = CAST(:device_id AS uuid) "
                "ORDER BY collected_at DESC LIMIT 1"
            ),
            {"device_id": device_id},
        )
        latest_hash = result.scalar_one_or_none()
        if latest_hash == sha256_hash:
            logger.debug(
                "Duplicate config snapshot skipped for device %s",
                device_id,
            )
            config_snapshot_dedup_skipped_total.inc()
            await msg.ack()
            return
        # --- 3. Encrypt via OpenBao Transit ---
        openbao = OpenBaoTransitService()
        try:
            # tenant_id is passed to Transit — presumably selecting a
            # per-tenant encryption key; verify against OpenBaoTransitService.
            encrypted_text = await openbao.encrypt(
                tenant_id, config_text.encode("utf-8")
            )
        except Exception as exc:
            # Transit unavailability is retryable: nak for redelivery.
            logger.warning(
                "Transit encrypt failed for device %s tenant %s: %s",
                device_id,
                tenant_id,
                exc,
            )
            config_snapshot_errors_total.labels(error_type="encrypt_unavailable").inc()
            await msg.nak()
            return
        finally:
            # Always release the OpenBao client, on success and on failure.
            await openbao.close()
        # --- 4. INSERT new snapshot ---
        try:
            insert_result = await session.execute(
                text(
                    "INSERT INTO router_config_snapshots "
                    "(device_id, tenant_id, config_text, sha256_hash, collected_at) "
                    "VALUES (CAST(:device_id AS uuid), CAST(:tenant_id AS uuid), "
                    ":config_text, :sha256_hash, :collected_at) "
                    "RETURNING id"
                ),
                {
                    "device_id": device_id,
                    "tenant_id": tenant_id,
                    "config_text": encrypted_text,  # ciphertext only — never plaintext
                    "sha256_hash": sha256_hash,
                    "collected_at": collected_at,
                },
            )
            # RETURNING id captures the new row's PK for diff generation.
            new_snapshot_id = insert_result.scalar_one()
            await session.commit()
        except IntegrityError:
            # --- 5. FK violation: referenced device row doesn't exist ---
            # Redelivery cannot fix an orphan reference: ack + discard.
            logger.warning(
                "Orphan device_id %s (FK constraint violation) — discarding snapshot",
                device_id,
            )
            config_snapshot_errors_total.labels(error_type="orphan_device").inc()
            await session.rollback()
            await msg.ack()
            return
        except OperationalError as exc:
            # --- 6. Transient DB error: retryable, nak for redelivery ---
            logger.warning(
                "Transient DB error storing snapshot for device %s: %s",
                device_id,
                exc,
            )
            config_snapshot_errors_total.labels(error_type="db_error").inc()
            await session.rollback()
            await msg.nak()
            return
        # --- Diff generation (best-effort) ---
        # Runs only after a successful commit, so the snapshot is durable
        # regardless of diff outcome; any diff failure is logged and
        # swallowed so the ack below is never blocked.
        try:
            await generate_and_store_diff(device_id, tenant_id, str(new_snapshot_id), session)
        except Exception as exc:
            logger.warning(
                "Diff generation failed for device %s (non-fatal): %s",
                device_id, exc,
            )
        logger.info(
            "Config snapshot stored for device %s tenant %s",
            device_id,
            tenant_id,
        )
        config_snapshot_ingested_total.inc()
        duration = time.monotonic() - start_time
        config_snapshot_ingestion_duration_seconds.observe(duration)
        await msg.ack()
async def _subscribe_with_retry(js) -> None:
    """Bind the durable consumer for config.snapshot.>, retrying while the
    DEVICE_EVENTS stream comes up.

    Makes up to six attempts, sleeping 5s between them; gives up with a
    warning (rather than raising) if the stream never becomes available.
    """
    total_tries = 6
    attempt = 0
    while attempt < total_tries:
        attempt += 1
        try:
            await js.subscribe(
                "config.snapshot.>",
                cb=handle_config_snapshot,
                durable="config_snapshot_ingest",
                stream="DEVICE_EVENTS",
                manual_ack=True,
            )
        except Exception as exc:
            if attempt == total_tries:
                # Out of attempts — log and bail without raising.
                logger.warning(
                    "NATS: giving up on config.snapshot.> after %d attempts: %s",
                    total_tries,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready for config.snapshot (attempt %d/%d): %s — retrying in 5s",
                attempt,
                total_tries,
                exc,
            )
            await asyncio.sleep(5)
        else:
            # Subscription established — done.
            logger.info(
                "NATS: subscribed to config.snapshot.> (durable: config_snapshot_ingest)"
            )
            return
async def start_config_snapshot_subscriber() -> Optional[Any]:
    """Open the NATS connection and begin consuming config.snapshot.>.

    Stores the live connection in module state and returns it so the
    caller can manage shutdown.
    """
    # Imported lazily so the module can be loaded without nats installed.
    import nats

    global _nc

    logger.info("NATS config-snapshot: connecting to %s", settings.NATS_URL)
    conn = await nats.connect(
        settings.NATS_URL,
        max_reconnect_attempts=-1,  # reconnect forever
        reconnect_time_wait=2,
    )
    _nc = conn
    logger.info("NATS config-snapshot: connected")
    await _subscribe_with_retry(conn.jetstream())
    return conn
async def stop_config_snapshot_subscriber() -> None:
    """Drain the NATS connection gracefully and clear module state."""
    global _nc
    conn = _nc
    if not conn:
        # Nothing was started (or it was already stopped).
        return
    try:
        logger.info("NATS config-snapshot: draining connection...")
        await conn.drain()
        logger.info("NATS config-snapshot: connection closed")
    except Exception as exc:
        logger.warning("NATS config-snapshot: error during drain: %s", exc)
        # Drain failed — fall back to a hard close, ignoring any error.
        try:
            await conn.close()
        except Exception:
            pass
    finally:
        _nc = None