fix: address spec compliance gaps - tenant check, XFF fallback, rate limiting

- Gap 1: Add tenant ID verification after device lookup in SSH relay handleSSH,
  closing cross-tenant token reuse vulnerability
- Gap 2: Add X-Forwarded-For fallback (last entry) when X-Real-IP is absent in
  SSH relay source IP extraction; import strings package
- Gap 3: Add @limiter.limit("10/minute") to POST /winbox-session and POST
  /ssh-session using existing slowapi pattern from app.middleware.rate_limit
- Gap 4: Add TODO comment in open_ssh_session explaining that SSH session count
  enforcement is at the poller level; no NATS subject exists yet for API-side
  pre-check

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Jason Staack
2026-03-12 15:51:14 -05:00
parent a4e1c78744
commit 7aaaeaa1d1
2 changed files with 26 additions and 1 deletions

View File

@@ -29,6 +29,7 @@ from app.schemas.remote_access import (
TunnelStatusItem, TunnelStatusItem,
WinboxSessionResponse, WinboxSessionResponse,
) )
from app.middleware.rate_limit import limiter
from app.services.audit_service import log_action from app.services.audit_service import log_action
from sqlalchemy import select from sqlalchemy import select
@@ -102,6 +103,7 @@ async def _check_tenant_access(current_user: CurrentUser, tenant_id: uuid.UUID,
summary="Open a WinBox tunnel to the device", summary="Open a WinBox tunnel to the device",
dependencies=[Depends(require_operator_or_above)], dependencies=[Depends(require_operator_or_above)],
) )
@limiter.limit("10/minute")
async def open_winbox_session( async def open_winbox_session(
tenant_id: uuid.UUID, tenant_id: uuid.UUID,
device_id: uuid.UUID, device_id: uuid.UUID,
@@ -176,6 +178,7 @@ async def open_winbox_session(
summary="Create a single-use SSH WebSocket session token", summary="Create a single-use SSH WebSocket session token",
dependencies=[Depends(require_operator_or_above)], dependencies=[Depends(require_operator_or_above)],
) )
@limiter.limit("10/minute")
async def open_ssh_session( async def open_ssh_session(
tenant_id: uuid.UUID, tenant_id: uuid.UUID,
device_id: uuid.UUID, device_id: uuid.UUID,
@@ -194,6 +197,13 @@ async def open_ssh_session(
await _get_device(db, tenant_id, device_id) await _get_device(db, tenant_id, device_id)
source_ip = _source_ip(request) source_ip = _source_ip(request)
# TODO(defense-in-depth): No API-side SSH session count check is performed here.
# SSH session limits (per-user, per-device, global) are enforced at the poller/SSH
# relay level on WebSocket connect. There is currently no NATS subject that exposes
# SSH session counts to the API (tunnel.status.list only covers WinBox tunnels).
# When such a subject is added, query it here before issuing the token and raise
# HTTPException(429) if limits are exceeded, providing earlier feedback to the client.
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "ssh_session_open", db, tenant_id, current_user.user_id, "ssh_session_open",

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"net/http" "net/http"
"strings"
"sync" "sync"
"time" "time"
@@ -123,8 +124,15 @@ func (s *Server) handleSSH(w http.ResponseWriter, r *http.Request) {
} }
ws.SetReadLimit(1 << 20) ws.SetReadLimit(1 << 20)
// Extract source IP (nginx sets X-Real-IP) // Extract source IP (nginx sets X-Real-IP, fall back to X-Forwarded-For then RemoteAddr)
sourceIP := r.Header.Get("X-Real-IP") sourceIP := r.Header.Get("X-Real-IP")
if sourceIP == "" {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Use last entry (closest proxy)
parts := strings.Split(xff, ",")
sourceIP = strings.TrimSpace(parts[len(parts)-1])
}
}
if sourceIP == "" { if sourceIP == "" {
sourceIP = r.RemoteAddr sourceIP = r.RemoteAddr
} }
@@ -137,6 +145,13 @@ func (s *Server) handleSSH(w http.ResponseWriter, r *http.Request) {
return return
} }
// Verify device belongs to the tenant in the token
if dev.TenantID != payload.TenantID {
slog.Warn("ssh: tenant mismatch", "device_tenant", dev.TenantID, "token_tenant", payload.TenantID)
ws.Close(websocket.StatusPolicyViolation, "unauthorized")
return
}
// Decrypt credentials — GetCredentials returns (username, password, error) // Decrypt credentials — GetCredentials returns (username, password, error)
username, password, err := s.credCache.GetCredentials( username, password, err := s.credCache.GetCredentials(
dev.ID, dev.ID,