diff --git a/backend/app/main.py b/backend/app/main.py index 5ed1e91..d2a05bd 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -80,6 +80,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: from app.services.firmware_subscriber import start_firmware_subscriber, stop_firmware_subscriber from app.services.metrics_subscriber import start_metrics_subscriber, stop_metrics_subscriber from app.services.nats_subscriber import start_nats_subscriber, stop_nats_subscriber + from app.services.session_audit_subscriber import start_session_audit_subscriber, stop_session_audit_subscriber from app.services.sse_manager import ensure_sse_streams # Configure structured logging FIRST -- before any other startup work @@ -126,6 +127,16 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: error=str(exc), ) + # Start NATS subscriber for SSH session end audit events (separate NATS connection). + session_audit_nc = None + try: + session_audit_nc = await start_session_audit_subscriber() + except Exception as exc: + logger.warning( + "NATS session audit subscriber could not start (API will run without it)", + error=str(exc), + ) + # Ensure NATS streams for SSE event delivery exist (ALERT_EVENTS, OPERATION_EVENTS). # Non-fatal -- API starts without SSE streams; they'll be created on first SSE connection. 
try: @@ -212,6 +223,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: await stop_nats_subscriber(nats_connection) await stop_metrics_subscriber(metrics_nc) await stop_firmware_subscriber(firmware_nc) + await stop_session_audit_subscriber(session_audit_nc) if config_change_nc: await stop_config_change_subscriber() if push_rollback_nc:
diff --git a/backend/app/services/session_audit_subscriber.py b/backend/app/services/session_audit_subscriber.py
new file mode 100644
index 0000000..f8e0638
--- /dev/null
+++ b/backend/app/services/session_audit_subscriber.py
@@ -0,0 +1,264 @@
+"""NATS JetStream subscriber for SSH session end audit events from the Go poller.
+
+Subscribes to audit.session.end.> and writes ssh_session_end audit log entries
+with session duration. Uses the existing self-committing audit service.
+"""
+
+import asyncio
+import json
+import logging
+import uuid
+from datetime import datetime
+from typing import Optional
+
+import nats
+from nats.js import JetStreamContext
+from nats.aio.client import Client as NATSClient
+
+from app.config import settings
+from app.services.audit_service import log_action
+
+logger = logging.getLogger(__name__)
+
+_session_audit_client: Optional[NATSClient] = None
+
+
+class _MalformedSessionEvent(ValueError):
+    """Payload that can never be processed, no matter how often it is redelivered."""
+
+
+def _parse_rfc3339(value: str) -> datetime:
+    """Parse an RFC3339 timestamp such as the Go poller emits.
+
+    Go's time.RFC3339 layout renders UTC as a trailing "Z", which
+    datetime.fromisoformat() only accepts from Python 3.11 onwards, so the
+    suffix is rewritten to the equivalent "+00:00" offset before parsing.
+
+    Raises:
+        ValueError: if the value is not a valid ISO 8601 timestamp.
+    """
+    if value.endswith(("Z", "z")):
+        value = f"{value[:-1]}+00:00"
+    return datetime.fromisoformat(value)
+
+
+def _duration_seconds(start_time, end_time) -> Optional[int]:
+    """Return the whole seconds between two RFC3339 strings, or None.
+
+    None is returned when either value is empty, is not a string, cannot be
+    parsed, or when one timestamp is naive and the other timezone-aware.
+    """
+    if not (isinstance(start_time, str) and isinstance(end_time, str)):
+        return None
+    if not (start_time and end_time):
+        return None
+    try:
+        elapsed = _parse_rfc3339(end_time) - _parse_rfc3339(start_time)
+    except (TypeError, ValueError):
+        return None
+    return int(elapsed.total_seconds())
+
+
+def _parse_session_end(raw):
+    """Decode and validate an audit.session.end payload.
+
+    Returns:
+        A (data, tenant_id, user_id, device_id) tuple: the decoded JSON object
+        and the identifiers as uuid.UUID values (user_id is None when absent).
+
+    Raises:
+        _MalformedSessionEvent: if the payload is not a JSON object, lacks
+            tenant_id or device_id, or carries an identifier that is not a UUID.
+    """
+    try:
+        data = json.loads(raw)
+    except (TypeError, ValueError) as exc:
+        raise _MalformedSessionEvent(f"is not valid JSON ({exc})") from exc
+    if not isinstance(data, dict):
+        raise _MalformedSessionEvent("is not a JSON object")
+
+    tenant_id = data.get("tenant_id")
+    user_id = data.get("user_id")
+    device_id = data.get("device_id")
+    if not tenant_id or not device_id:
+        raise _MalformedSessionEvent("missing tenant_id or device_id")
+
+    try:
+        return (
+            data,
+            uuid.UUID(tenant_id),
+            uuid.UUID(user_id) if user_id else None,
+            uuid.UUID(device_id),
+        )
+    except (AttributeError, TypeError, ValueError) as exc:
+        # uuid.UUID raises ValueError for bad text and AttributeError/TypeError
+        # for non-string JSON values such as numbers or lists.
+        raise _MalformedSessionEvent(f"has an invalid UUID ({exc})") from exc
+
+
+async def _settle(action, label: str) -> None:
+    """Await msg.ack / msg.nak, logging instead of raising if it fails."""
+    try:
+        await action()
+    except Exception as exc:
+        logger.debug("audit.session.end %s failed: %s", label, exc)
+
+
+async def on_session_end(msg) -> None:
+    """Handle an audit.session.end event published by the Go poller.
+
+    Payload (JSON):
+        session_id (str) -- UUID of the SSH session
+        user_id (str) -- UUID of the user
+        tenant_id (str) -- UUID of the owning tenant
+        device_id (str) -- UUID of the device
+        start_time (str) -- RFC3339 session start
+        end_time (str) -- RFC3339 session end
+        source_ip (str) -- client IP address
+        reason (str) -- "normal", "idle_timeout", or "shutdown"
+
+    Malformed payloads are acknowledged and dropped, because redelivering them
+    cannot succeed. Any other failure (for example while writing the audit
+    entry) is NAK'd so that JetStream redelivers the event.
+    """
+    try:
+        data, tenant_id, user_id, device_id = _parse_session_end(msg.data)
+
+        start_time = data.get("start_time", "")
+        end_time = data.get("end_time", "")
+        duration_seconds = _duration_seconds(start_time, end_time)
+
+        await log_action(
+            db=None,  # not used by audit_service internally
+            tenant_id=tenant_id,
+            user_id=user_id,
+            action="ssh_session_end",
+            resource_type="device",
+            resource_id=data["device_id"],
+            device_id=device_id,
+            details={
+                "session_id": data.get("session_id"),
+                "start_time": start_time,
+                "end_time": end_time,
+                "duration_seconds": duration_seconds,
+                "source_ip": data.get("source_ip"),
+                "reason": data.get("reason"),
+            },
+            ip_address=data.get("source_ip"),
+        )
+
+        logger.debug(
+            "audit.session.end processed",
+            extra={
+                "session_id": data.get("session_id"),
+                "device_id": data["device_id"],
+                "duration_seconds": duration_seconds,
+            },
+        )
+        await msg.ack()
+
+    except _MalformedSessionEvent as exc:
+        logger.warning("audit.session.end event %s — skipping", exc)
+        await _settle(msg.ack, "ack")
+    except Exception as exc:
+        logger.error(
+            "Failed to process audit.session.end event: %s",
+            exc,
+            exc_info=True,
+        )
+        await _settle(msg.nak, "nak")
+
+
+async def _subscribe_with_retry(js: JetStreamContext) -> None:
+    """Subscribe to audit.session.end.> with durable consumer, retrying if stream not ready.
+
+    Makes up to six attempts five seconds apart; after the last failure a
+    warning is logged and the function returns without a subscription.
+    """
+    max_attempts = 6  # ~30 seconds at 5s intervals
+    for attempt in range(1, max_attempts + 1):
+        try:
+            await js.subscribe(
+                "audit.session.end.>",
+                cb=on_session_end,
+                durable="api-session-audit-consumer",
+                stream="DEVICE_EVENTS",
+            )
+            logger.info(
+                "NATS: subscribed to audit.session.end.> (durable: api-session-audit-consumer)"
+            )
+            return
+        except Exception as exc:
+            if attempt < max_attempts:
+                logger.warning(
+                    "NATS: stream DEVICE_EVENTS not ready for session audit (attempt %d/%d): %s — retrying in 5s",
+                    attempt,
+                    max_attempts,
+                    exc,
+                )
+                await asyncio.sleep(5)
+            else:
+                logger.warning(
+                    "NATS: giving up on audit.session.end.> after %d attempts: %s — API will run without session audit",
+                    max_attempts,
+                    exc,
+                )
+                return
+
+
+async def start_session_audit_subscriber() -> Optional[NATSClient]:
+    """Connect to NATS and start the audit.session.end.> subscription.
+
+    Returns the NATS connection (must be passed to stop_session_audit_subscriber on shutdown).
+    """
+    global _session_audit_client
+
+    logger.info("NATS session audit: connecting to %s", settings.NATS_URL)
+
+    nc = await nats.connect(
+        settings.NATS_URL,
+        max_reconnect_attempts=-1,
+        reconnect_time_wait=2,
+        error_cb=_on_error,
+        reconnected_cb=_on_reconnected,
+        disconnected_cb=_on_disconnected,
+    )
+
+    logger.info("NATS session audit: connected to %s", settings.NATS_URL)
+
+    js = nc.jetstream()
+    await _subscribe_with_retry(js)
+
+    _session_audit_client = nc
+    return nc
+
+
+async def stop_session_audit_subscriber(nc: Optional[NATSClient]) -> None:
+    """Drain and close the session audit NATS connection gracefully."""
+    if nc is None:
+        return
+    try:
+        logger.info("NATS session audit: draining connection...")
+        await nc.drain()
+        logger.info("NATS session audit: connection closed")
+    except Exception as exc:
+        logger.warning("NATS session audit: error during drain: %s", exc)
+        try:
+            await nc.close()
+        except Exception as close_exc:
+            logger.debug("NATS session audit: error during close: %s", close_exc)
+
+
+async def _on_error(exc: Exception) -> None:
+    """Log an asynchronous error reported by the NATS client."""
+    logger.error("NATS session audit error: %s", exc)
+
+
+async def _on_reconnected() -> None:
+    """Log that the NATS client re-established its connection."""
+    logger.info("NATS session audit: reconnected")
+
+
+async def _on_disconnected() -> None:
+    """Log that the NATS client lost its connection."""
+    logger.warning("NATS session audit: disconnected")
diff --git a/poller/cmd/poller/main.go b/poller/cmd/poller/main.go index 13fa705..528ac7f 100644 --- a/poller/cmd/poller/main.go +++ b/poller/cmd/poller/main.go @@ -220,7 +220,7 @@ func main() { // ----------------------------------------------------------------------- // Initialize SSH relay server and HTTP listener // ----------------------------------------------------------------------- - sshServer := sshrelay.NewServer(redisClient, credentialCache, deviceStore, sshrelay.Config{ + sshServer := sshrelay.NewServer(redisClient, credentialCache, deviceStore, publisher, sshrelay.Config{ IdleTimeout: time.Duration(cfg.SSHIdleTimeout) * time.Second, MaxSessions: cfg.MaxSessions, MaxPerUser: cfg.SSHMaxPerUser, diff --git
a/poller/internal/bus/publisher.go b/poller/internal/bus/publisher.go index b2afbc0..0288dfe 100644 --- a/poller/internal/bus/publisher.go +++ b/poller/internal/bus/publisher.go @@ -124,6 +124,7 @@ func NewPublisher(natsURL string) (*Publisher, error) { "config.changed.>", "config.push.rollback.>", "config.push.alert.>", + "audit.session.end.>", }, MaxAge: 24 * time.Hour, })
@@ -306,6 +307,47 @@ func (p *Publisher) PublishPushAlert(ctx context.Context, event PushAlertEvent)
 	return nil
 }
 
+// SessionEndEvent is the payload published to NATS JetStream when an SSH
+// relay session ends. The backend subscribes to audit.session.end.> and
+// writes an audit log entry with the session duration.
+type SessionEndEvent struct {
+	SessionID string `json:"session_id"` // also the final token of the publish subject
+	UserID    string `json:"user_id"`
+	TenantID  string `json:"tenant_id"`
+	DeviceID  string `json:"device_id"`
+	StartTime string `json:"start_time"` // RFC3339
+	EndTime   string `json:"end_time"`   // RFC3339
+	SourceIP  string `json:"source_ip"`
+	Reason    string `json:"reason"` // "normal", "idle_timeout", "shutdown"
+}
+
+// PublishSessionEnd publishes an SSH session end event to NATS JetStream.
+//
+// The event is JSON-encoded and published on audit.session.end.<SessionID>.
+// It returns a wrapped error if marshalling or publishing fails; on success
+// it emits a debug log line and returns nil.
+func (p *Publisher) PublishSessionEnd(ctx context.Context, event SessionEndEvent) error {
+	data, err := json.Marshal(event)
+	if err != nil {
+		return fmt.Errorf("marshalling session end event: %w", err)
+	}
+
+	subject := fmt.Sprintf("audit.session.end.%s", event.SessionID)
+
+	_, err = p.js.Publish(ctx, subject, data)
+	if err != nil {
+		return fmt.Errorf("publishing to %s: %w", subject, err)
+	}
+
+	slog.Debug("published session end event",
+		"session_id", event.SessionID,
+		"device_id", event.DeviceID,
+		"subject", subject,
+	)
+
+	return nil
+}
+
 // Conn returns the raw NATS connection for use by other components
 // (e.g., CmdResponder for request-reply subscriptions).
func (p *Publisher) Conn() *nats.Conn { diff --git a/poller/internal/sshrelay/server.go b/poller/internal/sshrelay/server.go index 38475da..7815ece 100644 --- a/poller/internal/sshrelay/server.go +++ b/poller/internal/sshrelay/server.go @@ -11,6 +11,7 @@ import ( "time" "github.com/google/uuid" + "github.com/mikrotik-portal/poller/internal/bus" "github.com/mikrotik-portal/poller/internal/store" "github.com/mikrotik-portal/poller/internal/vault" "github.com/redis/go-redis/v9" @@ -35,6 +36,7 @@ type Server struct { redis *redis.Client credCache *vault.CredentialCache deviceStore *store.DeviceStore + publisher *bus.Publisher sessions map[string]*Session mu sync.Mutex idleTime time.Duration @@ -53,12 +55,13 @@ type Config struct { } // NewServer creates and starts a new SSH relay server. -func NewServer(rc *redis.Client, cc *vault.CredentialCache, ds *store.DeviceStore, cfg Config) *Server { +func NewServer(rc *redis.Client, cc *vault.CredentialCache, ds *store.DeviceStore, pub *bus.Publisher, cfg Config) *Server { ctx, cancel := context.WithCancel(context.Background()) s := &Server{ redis: rc, credCache: cc, deviceStore: ds, + publisher: pub, sessions: make(map[string]*Session), idleTime: cfg.IdleTimeout, maxSessions: cfg.MaxSessions, @@ -255,12 +258,15 @@ func (s *Server) handleSSH(w http.ResponseWriter, r *http.Request) { delete(s.sessions, sess.ID) s.mu.Unlock() - duration := time.Since(sess.StartTime) + endTime := time.Now() + duration := endTime.Sub(sess.StartTime) slog.Info("ssh session ended", "session_id", sess.ID, "device_id", payload.DeviceID, "duration", duration.String(), ) + + s.publishSessionEnd(sess, endTime, "normal") } // validateToken performs a Redis GETDEL to atomically consume a single-use token. 
@@ -331,6 +337,43 @@ func (s *Server) cleanupIdle() { for _, sess := range toCancel { slog.Info("ssh session idle timeout", "session_id", sess.ID) sess.cancel()
+		// NOTE(review): handleSSH also publishes a "normal" end event when it
+		// tears a session down. If sess.cancel() causes that handler to
+		// return, this session is reported twice with conflicting reasons.
+		s.publishSessionEnd(sess, time.Now(), "idle_timeout")
+	}
+}
+
+// publishSessionEnd publishes an audit.session.end event via NATS JetStream.
+// It is a no-op when the server has no publisher. The publish is performed
+// synchronously under a 5-second context timeout; a failure is logged and
+// otherwise ignored, so callers never see an error from it.
+func (s *Server) publishSessionEnd(sess *Session, endTime time.Time, reason string) {
+	if s.publisher == nil {
+		return
+	}
+
+	// Use a fresh background context rather than one tied to the session —
+	// presumably because the session's own context is already cancelled here.
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	event := bus.SessionEndEvent{
+		SessionID: sess.ID,
+		UserID:    sess.UserID,
+		TenantID:  sess.TenantID,
+		DeviceID:  sess.DeviceID,
+		StartTime: sess.StartTime.Format(time.RFC3339),
+		EndTime:   endTime.Format(time.RFC3339),
+		SourceIP:  sess.SourceIP,
+		Reason:    reason,
+	}
+
+	if err := s.publisher.PublishSessionEnd(ctx, event); err != nil {
+		slog.Error("failed to publish session end event",
+			"session_id", sess.ID,
+			"error", err,
+		)
 } }