diff --git a/backend/app/routers/remote_access.py b/backend/app/routers/remote_access.py index 578ac5d..b89e195 100644 --- a/backend/app/routers/remote_access.py +++ b/backend/app/routers/remote_access.py @@ -29,6 +29,7 @@ from app.schemas.remote_access import ( TunnelStatusItem, WinboxSessionResponse, ) +from app.middleware.rate_limit import limiter from app.services.audit_service import log_action from sqlalchemy import select @@ -102,6 +103,7 @@ async def _check_tenant_access(current_user: CurrentUser, tenant_id: uuid.UUID, summary="Open a WinBox tunnel to the device", dependencies=[Depends(require_operator_or_above)], ) +@limiter.limit("10/minute") async def open_winbox_session( tenant_id: uuid.UUID, device_id: uuid.UUID, @@ -176,6 +178,7 @@ async def open_winbox_session( summary="Create a single-use SSH WebSocket session token", dependencies=[Depends(require_operator_or_above)], ) +@limiter.limit("10/minute") async def open_ssh_session( tenant_id: uuid.UUID, device_id: uuid.UUID, @@ -194,6 +197,13 @@ async def open_ssh_session( await _get_device(db, tenant_id, device_id) source_ip = _source_ip(request) + # TODO(defense-in-depth): No API-side SSH session count check is performed here. + # SSH session limits (per-user, per-device, global) are enforced at the poller/SSH + # relay level on WebSocket connect. There is currently no NATS subject that exposes + # SSH session counts to the API (tunnel.status.list only covers WinBox tunnels). + # When such a subject is added, query it here before issuing the token and raise + # HTTPException(429) if limits are exceeded, providing earlier feedback to the client. + try: await log_action( db, tenant_id, current_user.user_id, "ssh_session_open", diff --git a/poller/internal/sshrelay/server.go b/poller/internal/sshrelay/server.go index 79e98c1..38475da 100644 --- a/poller/internal/sshrelay/server.go +++ b/poller/internal/sshrelay/server.go @@ -6,6 +6,7 @@ import ( "fmt" "log/slog" "net/http" + "strings" "sync" "time" @@ -123,8 +124,15 @@ func (s *Server) handleSSH(w http.ResponseWriter, r *http.Request) { } ws.SetReadLimit(1 << 20) - // Extract source IP (nginx sets X-Real-IP) + // Extract source IP (nginx sets X-Real-IP, fall back to X-Forwarded-For then RemoteAddr) sourceIP := r.Header.Get("X-Real-IP") + if sourceIP == "" { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Use last entry (closest proxy) + parts := strings.Split(xff, ",") + sourceIP = strings.TrimSpace(parts[len(parts)-1]) + } + } if sourceIP == "" { sourceIP = r.RemoteAddr } @@ -137,6 +145,13 @@ func (s *Server) handleSSH(w http.ResponseWriter, r *http.Request) { return } + // Verify device belongs to the tenant in the token + if dev.TenantID != payload.TenantID { + slog.Warn("ssh: tenant mismatch", "device_tenant", dev.TenantID, "token_tenant", payload.TenantID) + ws.Close(websocket.StatusPolicyViolation, "unauthorized") + return + } + // Decrypt credentials — GetCredentials returns (username, password, error) username, password, err := s.credCache.GetCredentials( dev.ID,