Files
the-other-dude/backend/app/services/session_audit_subscriber.py
Jason Staack acf1790bed feat: add audit.session.end NATS pipeline for SSH session tracking
Poller publishes session end events via JetStream when SSH sessions
close (normal disconnect or idle timeout). Backend subscribes with a
durable consumer and writes ssh_session_end audit log entries with
duration.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-12 16:07:10 -05:00

188 lines
5.9 KiB
Python

"""NATS JetStream subscriber for SSH session end audit events from the Go poller.
Subscribes to audit.session.end.> and writes ssh_session_end audit log entries
with session duration. Uses the existing self-committing audit service.
"""
import asyncio
import json
import logging
import uuid
from datetime import datetime
from typing import Optional
import nats
from nats.js import JetStreamContext
from nats.aio.client import Client as NATSClient
from app.config import settings
from app.services.audit_service import log_action
logger = logging.getLogger(__name__)
_session_audit_client: Optional[NATSClient] = None
async def on_session_end(msg) -> None:
"""Handle an audit.session.end event published by the Go poller.
Payload (JSON):
session_id (str) -- UUID of the SSH session
user_id (str) -- UUID of the user
tenant_id (str) -- UUID of the owning tenant
device_id (str) -- UUID of the device
start_time (str) -- RFC3339 session start
end_time (str) -- RFC3339 session end
source_ip (str) -- client IP address
reason (str) -- "normal", "idle_timeout", or "shutdown"
"""
try:
data = json.loads(msg.data)
tenant_id = data.get("tenant_id")
user_id = data.get("user_id")
device_id = data.get("device_id")
if not tenant_id or not device_id:
logger.warning("audit.session.end event missing tenant_id or device_id — skipping")
await msg.ack()
return
start_time = data.get("start_time", "")
end_time = data.get("end_time", "")
duration_seconds = None
if start_time and end_time:
try:
t0 = datetime.fromisoformat(start_time)
t1 = datetime.fromisoformat(end_time)
duration_seconds = int((t1 - t0).total_seconds())
except (ValueError, TypeError):
pass
await log_action(
db=None, # not used by audit_service internally
tenant_id=uuid.UUID(tenant_id),
user_id=uuid.UUID(user_id) if user_id else None,
action="ssh_session_end",
resource_type="device",
resource_id=device_id,
device_id=uuid.UUID(device_id),
details={
"session_id": data.get("session_id"),
"start_time": start_time,
"end_time": end_time,
"duration_seconds": duration_seconds,
"source_ip": data.get("source_ip"),
"reason": data.get("reason"),
},
ip_address=data.get("source_ip"),
)
logger.debug(
"audit.session.end processed",
extra={
"session_id": data.get("session_id"),
"device_id": device_id,
"duration_seconds": duration_seconds,
},
)
await msg.ack()
except Exception as exc:
logger.error(
"Failed to process audit.session.end event: %s",
exc,
exc_info=True,
)
try:
await msg.nak()
except Exception:
pass
async def _subscribe_with_retry(js: JetStreamContext) -> None:
"""Subscribe to audit.session.end.> with durable consumer, retrying if stream not ready."""
max_attempts = 6 # ~30 seconds at 5s intervals
for attempt in range(1, max_attempts + 1):
try:
await js.subscribe(
"audit.session.end.>",
cb=on_session_end,
durable="api-session-audit-consumer",
stream="DEVICE_EVENTS",
)
logger.info(
"NATS: subscribed to audit.session.end.> (durable: api-session-audit-consumer)"
)
return
except Exception as exc:
if attempt < max_attempts:
logger.warning(
"NATS: stream DEVICE_EVENTS not ready for session audit (attempt %d/%d): %s — retrying in 5s",
attempt,
max_attempts,
exc,
)
await asyncio.sleep(5)
else:
logger.warning(
"NATS: giving up on audit.session.end.> after %d attempts: %s — API will run without session audit",
max_attempts,
exc,
)
return
async def start_session_audit_subscriber() -> Optional[NATSClient]:
"""Connect to NATS and start the audit.session.end.> subscription.
Returns the NATS connection (must be passed to stop_session_audit_subscriber on shutdown).
"""
global _session_audit_client
logger.info("NATS session audit: connecting to %s", settings.NATS_URL)
nc = await nats.connect(
settings.NATS_URL,
max_reconnect_attempts=-1,
reconnect_time_wait=2,
error_cb=_on_error,
reconnected_cb=_on_reconnected,
disconnected_cb=_on_disconnected,
)
logger.info("NATS session audit: connected to %s", settings.NATS_URL)
js = nc.jetstream()
await _subscribe_with_retry(js)
_session_audit_client = nc
return nc
async def stop_session_audit_subscriber(nc: Optional[NATSClient]) -> None:
"""Drain and close the session audit NATS connection gracefully."""
if nc is None:
return
try:
logger.info("NATS session audit: draining connection...")
await nc.drain()
logger.info("NATS session audit: connection closed")
except Exception as exc:
logger.warning("NATS session audit: error during drain: %s", exc)
try:
await nc.close()
except Exception:
pass
async def _on_error(exc: Exception) -> None:
logger.error("NATS session audit error: %s", exc)
async def _on_reconnected() -> None:
logger.info("NATS session audit: reconnected")
async def _on_disconnected() -> None:
logger.warning("NATS session audit: disconnected")