feat: add audit.session.end NATS pipeline for SSH session tracking

Poller publishes session end events via JetStream when SSH sessions
close (normal disconnect or idle timeout). Backend subscribes with a
durable consumer and writes ssh_session_end audit log entries with
duration.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Jason Staack
2026-03-12 16:07:10 -05:00
parent 7aaaeaa1d1
commit acf1790bed
5 changed files with 276 additions and 3 deletions

View File

@@ -80,6 +80,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
from app.services.firmware_subscriber import start_firmware_subscriber, stop_firmware_subscriber
from app.services.metrics_subscriber import start_metrics_subscriber, stop_metrics_subscriber
from app.services.nats_subscriber import start_nats_subscriber, stop_nats_subscriber
from app.services.session_audit_subscriber import start_session_audit_subscriber, stop_session_audit_subscriber
from app.services.sse_manager import ensure_sse_streams
# Configure structured logging FIRST -- before any other startup work
@@ -126,6 +127,16 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
error=str(exc),
)
# Start NATS subscriber for SSH session end audit events (separate NATS connection).
session_audit_nc = None
try:
session_audit_nc = await start_session_audit_subscriber()
except Exception as exc:
logger.warning(
"NATS session audit subscriber could not start (API will run without it)",
error=str(exc),
)
# Ensure NATS streams for SSE event delivery exist (ALERT_EVENTS, OPERATION_EVENTS).
# Non-fatal -- API starts without SSE streams; they'll be created on first SSE connection.
try:
@@ -212,6 +223,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
await stop_nats_subscriber(nats_connection)
await stop_metrics_subscriber(metrics_nc)
await stop_firmware_subscriber(firmware_nc)
await stop_session_audit_subscriber(session_audit_nc)
if config_change_nc:
await stop_config_change_subscriber()
if push_rollback_nc:

View File

@@ -0,0 +1,187 @@
"""NATS JetStream subscriber for SSH session end audit events from the Go poller.
Subscribes to audit.session.end.> and writes ssh_session_end audit log entries
with session duration. Uses the existing self-committing audit service.
"""
import asyncio
import json
import logging
import uuid
from datetime import datetime
from typing import Optional
import nats
from nats.js import JetStreamContext
from nats.aio.client import Client as NATSClient
from app.config import settings
from app.services.audit_service import log_action
logger = logging.getLogger(__name__)
_session_audit_client: Optional[NATSClient] = None
async def on_session_end(msg) -> None:
"""Handle an audit.session.end event published by the Go poller.
Payload (JSON):
session_id (str) -- UUID of the SSH session
user_id (str) -- UUID of the user
tenant_id (str) -- UUID of the owning tenant
device_id (str) -- UUID of the device
start_time (str) -- RFC3339 session start
end_time (str) -- RFC3339 session end
source_ip (str) -- client IP address
reason (str) -- "normal", "idle_timeout", or "shutdown"
"""
try:
data = json.loads(msg.data)
tenant_id = data.get("tenant_id")
user_id = data.get("user_id")
device_id = data.get("device_id")
if not tenant_id or not device_id:
logger.warning("audit.session.end event missing tenant_id or device_id — skipping")
await msg.ack()
return
start_time = data.get("start_time", "")
end_time = data.get("end_time", "")
duration_seconds = None
if start_time and end_time:
try:
t0 = datetime.fromisoformat(start_time)
t1 = datetime.fromisoformat(end_time)
duration_seconds = int((t1 - t0).total_seconds())
except (ValueError, TypeError):
pass
await log_action(
db=None, # not used by audit_service internally
tenant_id=uuid.UUID(tenant_id),
user_id=uuid.UUID(user_id) if user_id else None,
action="ssh_session_end",
resource_type="device",
resource_id=device_id,
device_id=uuid.UUID(device_id),
details={
"session_id": data.get("session_id"),
"start_time": start_time,
"end_time": end_time,
"duration_seconds": duration_seconds,
"source_ip": data.get("source_ip"),
"reason": data.get("reason"),
},
ip_address=data.get("source_ip"),
)
logger.debug(
"audit.session.end processed",
extra={
"session_id": data.get("session_id"),
"device_id": device_id,
"duration_seconds": duration_seconds,
},
)
await msg.ack()
except Exception as exc:
logger.error(
"Failed to process audit.session.end event: %s",
exc,
exc_info=True,
)
try:
await msg.nak()
except Exception:
pass
async def _subscribe_with_retry(js: JetStreamContext) -> None:
"""Subscribe to audit.session.end.> with durable consumer, retrying if stream not ready."""
max_attempts = 6 # ~30 seconds at 5s intervals
for attempt in range(1, max_attempts + 1):
try:
await js.subscribe(
"audit.session.end.>",
cb=on_session_end,
durable="api-session-audit-consumer",
stream="DEVICE_EVENTS",
)
logger.info(
"NATS: subscribed to audit.session.end.> (durable: api-session-audit-consumer)"
)
return
except Exception as exc:
if attempt < max_attempts:
logger.warning(
"NATS: stream DEVICE_EVENTS not ready for session audit (attempt %d/%d): %s — retrying in 5s",
attempt,
max_attempts,
exc,
)
await asyncio.sleep(5)
else:
logger.warning(
"NATS: giving up on audit.session.end.> after %d attempts: %s — API will run without session audit",
max_attempts,
exc,
)
return
async def start_session_audit_subscriber() -> Optional[NATSClient]:
"""Connect to NATS and start the audit.session.end.> subscription.
Returns the NATS connection (must be passed to stop_session_audit_subscriber on shutdown).
"""
global _session_audit_client
logger.info("NATS session audit: connecting to %s", settings.NATS_URL)
nc = await nats.connect(
settings.NATS_URL,
max_reconnect_attempts=-1,
reconnect_time_wait=2,
error_cb=_on_error,
reconnected_cb=_on_reconnected,
disconnected_cb=_on_disconnected,
)
logger.info("NATS session audit: connected to %s", settings.NATS_URL)
js = nc.jetstream()
await _subscribe_with_retry(js)
_session_audit_client = nc
return nc
async def stop_session_audit_subscriber(nc: Optional[NATSClient]) -> None:
"""Drain and close the session audit NATS connection gracefully."""
if nc is None:
return
try:
logger.info("NATS session audit: draining connection...")
await nc.drain()
logger.info("NATS session audit: connection closed")
except Exception as exc:
logger.warning("NATS session audit: error during drain: %s", exc)
try:
await nc.close()
except Exception:
pass
async def _on_error(exc: Exception) -> None:
logger.error("NATS session audit error: %s", exc)
async def _on_reconnected() -> None:
logger.info("NATS session audit: reconnected")
async def _on_disconnected() -> None:
logger.warning("NATS session audit: disconnected")

View File

@@ -220,7 +220,7 @@ func main() {
// -----------------------------------------------------------------------
// Initialize SSH relay server and HTTP listener
// -----------------------------------------------------------------------
sshServer := sshrelay.NewServer(redisClient, credentialCache, deviceStore, sshrelay.Config{
sshServer := sshrelay.NewServer(redisClient, credentialCache, deviceStore, publisher, sshrelay.Config{
IdleTimeout: time.Duration(cfg.SSHIdleTimeout) * time.Second,
MaxSessions: cfg.SSHMaxSessions,
MaxPerUser: cfg.SSHMaxPerUser,

View File

@@ -124,6 +124,7 @@ func NewPublisher(natsURL string) (*Publisher, error) {
"config.changed.>",
"config.push.rollback.>",
"config.push.alert.>",
"audit.session.end.>",
},
MaxAge: 24 * time.Hour,
})
@@ -306,6 +307,43 @@ func (p *Publisher) PublishPushAlert(ctx context.Context, event PushAlertEvent)
return nil
}
// SessionEndEvent is the payload published to NATS JetStream when an SSH
// relay session ends. The backend subscribes to audit.session.end.> and
// writes an audit log entry with the session duration.
type SessionEndEvent struct {
SessionID string `json:"session_id"`
UserID string `json:"user_id"`
TenantID string `json:"tenant_id"`
DeviceID string `json:"device_id"`
StartTime string `json:"start_time"` // RFC3339
EndTime string `json:"end_time"` // RFC3339
SourceIP string `json:"source_ip"`
Reason string `json:"reason"` // "normal", "idle_timeout", "shutdown"
}
// PublishSessionEnd publishes an SSH session end event to NATS JetStream.
func (p *Publisher) PublishSessionEnd(ctx context.Context, event SessionEndEvent) error {
data, err := json.Marshal(event)
if err != nil {
return fmt.Errorf("marshalling session end event: %w", err)
}
subject := fmt.Sprintf("audit.session.end.%s", event.SessionID)
_, err = p.js.Publish(ctx, subject, data)
if err != nil {
return fmt.Errorf("publishing to %s: %w", subject, err)
}
slog.Debug("published session end event",
"session_id", event.SessionID,
"device_id", event.DeviceID,
"subject", subject,
)
return nil
}
// Conn returns the raw NATS connection for use by other components
// (e.g., CmdResponder for request-reply subscriptions).
func (p *Publisher) Conn() *nats.Conn {

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/mikrotik-portal/poller/internal/bus"
"github.com/mikrotik-portal/poller/internal/store"
"github.com/mikrotik-portal/poller/internal/vault"
"github.com/redis/go-redis/v9"
@@ -35,6 +36,7 @@ type Server struct {
redis *redis.Client
credCache *vault.CredentialCache
deviceStore *store.DeviceStore
publisher *bus.Publisher
sessions map[string]*Session
mu sync.Mutex
idleTime time.Duration
@@ -53,12 +55,13 @@ type Config struct {
}
// NewServer creates and starts a new SSH relay server.
func NewServer(rc *redis.Client, cc *vault.CredentialCache, ds *store.DeviceStore, cfg Config) *Server {
func NewServer(rc *redis.Client, cc *vault.CredentialCache, ds *store.DeviceStore, pub *bus.Publisher, cfg Config) *Server {
ctx, cancel := context.WithCancel(context.Background())
s := &Server{
redis: rc,
credCache: cc,
deviceStore: ds,
publisher: pub,
sessions: make(map[string]*Session),
idleTime: cfg.IdleTimeout,
maxSessions: cfg.MaxSessions,
@@ -255,12 +258,15 @@ func (s *Server) handleSSH(w http.ResponseWriter, r *http.Request) {
delete(s.sessions, sess.ID)
s.mu.Unlock()
duration := time.Since(sess.StartTime)
endTime := time.Now()
duration := endTime.Sub(sess.StartTime)
slog.Info("ssh session ended",
"session_id", sess.ID,
"device_id", payload.DeviceID,
"duration", duration.String(),
)
s.publishSessionEnd(sess, endTime, "normal")
}
// validateToken performs a Redis GETDEL to atomically consume a single-use token.
@@ -331,6 +337,36 @@ func (s *Server) cleanupIdle() {
for _, sess := range toCancel {
slog.Info("ssh session idle timeout", "session_id", sess.ID)
sess.cancel()
s.publishSessionEnd(sess, time.Now(), "idle_timeout")
}
}
// publishSessionEnd publishes an audit.session.end event via NATS JetStream.
// Errors are logged but never block session cleanup.
func (s *Server) publishSessionEnd(sess *Session, endTime time.Time, reason string) {
if s.publisher == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
event := bus.SessionEndEvent{
SessionID: sess.ID,
UserID: sess.UserID,
TenantID: sess.TenantID,
DeviceID: sess.DeviceID,
StartTime: sess.StartTime.Format(time.RFC3339),
EndTime: endTime.Format(time.RFC3339),
SourceIP: sess.SourceIP,
Reason: reason,
}
if err := s.publisher.PublishSessionEnd(ctx, event); err != nil {
slog.Error("failed to publish session end event",
"session_id", sess.ID,
"error", err,
)
}
}