- exec_script: relay now enforces admin role before forwarding to agent - relay CORS: restrict allow_origins via ALLOWED_ORIGINS env var (docker-compose passes app URL) - session-code: replace Math.random() with crypto.randomInt, add per-key rate limit (10 req/min) - sessions GET: fix IDOR — users can only read their own sessions (admins see all) - signal API: validate session ownership on create; enforce ownerUserId on all subsequent actions Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
242 lines
9.0 KiB
Python
242 lines
9.0 KiB
Python
"""
|
|
RemoteLink WebSocket Relay
|
|
Bridges agent (remote machine) ↔ viewer (browser) connections.
|
|
|
|
Agent connects: ws://relay:8765/ws/agent?machine_id=<uuid>&access_key=<hex>
|
|
Viewer connects: ws://relay:8765/ws/viewer?session_id=<uuid>&viewer_token=<uuid>
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
from typing import Optional
|
|
from uuid import UUID
|
|
|
|
import asyncpg
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
log = logging.getLogger("relay")

# Required setting: Postgres DSN for the asyncpg pool. Indexing (not .get) means
# a missing variable fails at import time with KeyError.
DATABASE_URL = os.environ["DATABASE_URL"]

# Optional comma-separated CORS allow-list, e.g. "https://app.example.com".
# Whitespace around each entry is stripped and empty entries are discarded.
ALLOWED_ORIGINS = list(
    filter(None, map(str.strip, os.environ.get("ALLOWED_ORIGINS", "").split(",")))
)

# ── In-memory connection registry ────────────────────────────────────────────
# These maps exist only in this process's memory.

# machine_id (str) → the agent's WebSocket
agents: dict[str, WebSocket] = {}
# session_id (str) → the viewer's WebSocket
viewers: dict[str, WebSocket] = {}
# session_id → machine_id the session is attached to
session_to_machine: dict[str, str] = {}
# session_id → viewer's user role ("admin" | "user")
viewer_roles: dict[str, str] = {}

# Shared asyncpg pool; stays None until the app lifespan creates it.
db_pool: Optional[asyncpg.Pool] = None
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the module-level asyncpg pool for the lifetime of the application.

    The pool is created before the app starts serving and is closed on shutdown.
    The close sits in ``finally`` so it also runs when the application body
    exits with an exception (a statement placed after a bare ``yield`` would be
    skipped in that case).

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    global db_pool
    db_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    log.info("Database pool ready")
    try:
        yield
    finally:
        await db_pool.close()
|
|
|
|
|
|
app = FastAPI(lifespan=lifespan)

# With no ALLOWED_ORIGINS configured the relay falls back to a wildcard origin.
# Starlette's CORSMiddleware, when given a wildcard together with
# allow_credentials=True, reflects the caller's Origin. That would let any site
# make credentialed requests. Credentials are therefore only enabled for an
# explicit allow-list; the WebSocket auth here travels in query parameters, not
# cookies.
if not ALLOWED_ORIGINS:
    log.warning("ALLOWED_ORIGINS is not set; CORS allows any origin without credentials")

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS or ["*"],
    allow_credentials=bool(ALLOWED_ORIGINS),
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
async def validate_agent(machine_id: str, access_key: str) -> Optional[dict]:
    """Look up the machine an agent claims to be and check its access key.

    Args:
        machine_id: Machine id from the agent's query string; must parse as a UUID.
            It is passed to the query in canonical ``str(UUID(...))`` form.
        access_key: Secret compared against ``machines.access_key``.

    Returns:
        A dict with ``id``, ``name`` and ``user_id`` for the matching machine.
        Returns ``None`` when the id is malformed, the key is empty, or no row matches.
    """
    try:
        canonical_id = str(UUID(machine_id))
    except ValueError:
        # Reject before touching the database. Presumably ``machines.id`` is
        # uuid-typed, in which case a malformed id makes the driver raise
        # instead of simply not matching.
        return None
    if not access_key:
        return None  # an empty secret must never authenticate
    async with db_pool.acquire() as conn:
        row = await conn.fetchrow(
            "SELECT id, name, user_id FROM machines WHERE id = $1 AND access_key = $2",
            canonical_id, access_key,
        )
    return dict(row) if row else None
|
|
|
|
|
|
async def validate_viewer(session_id: str, viewer_token: str) -> Optional[dict]:
    """Resolve an open session from the viewer's credentials.

    Args:
        session_id: Session id from the query string; must parse as a UUID.
        viewer_token: Per-session viewer secret; must parse as a UUID.

    Returns:
        A dict with ``id``, ``machine_id``, ``machine_name`` and ``viewer_role``
        for a session that matches both values and has ``ended_at IS NULL``.
        Returns ``None`` when either value is malformed or nothing matches.
        ``viewer_role`` is ``'user'`` when the LEFT JOIN yields no role.
    """
    try:
        # Canonicalise so the query gets a form the (presumably uuid-typed)
        # columns accept; malformed values are rejected without a DB round-trip.
        canonical_session = str(UUID(session_id))
        canonical_token = str(UUID(viewer_token))
    except ValueError:
        return None
    async with db_pool.acquire() as conn:
        row = await conn.fetchrow(
            """SELECT s.id, s.machine_id, s.machine_name, COALESCE(u.role, 'user') AS viewer_role
               FROM sessions s
               LEFT JOIN users u ON u.id = s.viewer_user_id
               WHERE s.id = $1 AND s.viewer_token = $2 AND s.ended_at IS NULL""",
            canonical_session, canonical_token,
        )
    return dict(row) if row else None
|
|
|
|
|
|
async def set_machine_online(machine_id: str, online: bool):
    """Persist a machine's connection state.

    Writes ``is_online`` for the machine and sets ``last_seen`` to the
    database's ``now()``.
    """
    query = "UPDATE machines SET is_online = $1, last_seen = now() WHERE id = $2"
    async with db_pool.acquire() as connection:
        await connection.execute(query, online, machine_id)
|
|
|
|
|
|
async def send_json(ws: WebSocket, data: dict):
    """Best-effort: serialise *data* as JSON and send it as a text frame.

    Any failure, whether in serialisation or in the send itself, is
    deliberately swallowed. This lets callers fan out without guarding each
    send.
    """
    try:
        payload = json.dumps(data)
        await ws.send_text(payload)
    except Exception:
        return
|
|
|
|
|
|
# ── Agent WebSocket endpoint ──────────────────────────────────────────────────
|
|
|
|
def _agent_viewer_sessions(machine_id: str) -> list[str]:
    """Return ids of sessions bound to *machine_id* that currently have a viewer socket."""
    return [
        sid for sid, mid in session_to_machine.items()
        if mid == machine_id and sid in viewers
    ]


async def _agent_resume_viewers(agent_ws: WebSocket, machine_id: str, machine_name: str) -> None:
    """Ask a freshly connected agent to stream to viewers that were already waiting for it."""
    for sid in _agent_viewer_sessions(machine_id):
        await send_json(agent_ws, {"type": "start_stream", "session_id": sid})
        viewer_ws = viewers.get(sid)
        if viewer_ws is not None:
            await send_json(viewer_ws, {"type": "agent_connected", "machine_name": machine_name})


async def _agent_forward_frame(machine_id: str, frame: bytes) -> None:
    """Send one binary frame to every viewer of *machine_id*, forgetting viewers whose send fails."""
    for sid in _agent_viewer_sessions(machine_id):
        viewer_ws = viewers.get(sid)
        if viewer_ws is None:
            continue
        try:
            await viewer_ws.send_bytes(frame)
        except Exception:
            # Drop only the socket that failed: a reconnecting viewer may have
            # registered under the same session id while we were awaiting.
            if viewers.get(sid) is viewer_ws:
                viewers.pop(sid, None)
                session_to_machine.pop(sid, None)


async def _agent_handle_text(machine_id: str, text: str) -> None:
    """Dispatch one JSON text message received from the agent of *machine_id*."""
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return  # malformed messages are ignored
    if not isinstance(data, dict):
        return  # only JSON objects carry a "type"
    msg_type = data.get("type")
    if msg_type == "script_output":
        session_id = data.get("session_id")
        # Deliver only to sessions bound to this machine, so one agent cannot
        # inject output into a viewer of a different machine.
        if isinstance(session_id, str) and session_to_machine.get(session_id) == machine_id:
            viewer_ws = viewers.get(session_id)
            if viewer_ws is not None:
                await send_json(viewer_ws, data)
    elif msg_type == "ping":
        await set_machine_online(machine_id, True)


@app.websocket("/ws/agent")
async def agent_endpoint(
    websocket: WebSocket,
    machine_id: str = Query(...),
    access_key: str = Query(...),
):
    """Serve one agent connection and relay its traffic to viewers of the same machine.

    Query parameters:
        machine_id: The machine's UUID. It is canonicalised with ``str(UUID(...))``
            so registry keys match the ``str(session["machine_id"])`` values the
            viewer endpoint stores. A value that does not parse closes with 4001.
        access_key: Secret checked by ``validate_agent``. A mismatch closes with 4001.

    While connected:
        * Binary messages are forwarded to every viewer bound to this machine.
        * Text messages are parsed as JSON. ``script_output`` goes to the named
          session only if it belongs to this machine; ``ping`` refreshes the
          online flag.
        * Viewers already waiting when the agent connects are re-attached:
          the agent gets ``start_stream`` and the viewer gets ``agent_connected``.

    On exit, the socket is unregistered, viewers get ``agent_disconnected`` and
    the machine is marked offline. This teardown is skipped if a newer
    connection for the same machine has already replaced this one.
    """
    try:
        machine_id = str(UUID(machine_id))
    except ValueError:
        await websocket.close(code=4001, reason="Unauthorized")
        return

    machine = await validate_agent(machine_id, access_key)
    if not machine:
        await websocket.close(code=4001, reason="Unauthorized")
        return

    await websocket.accept()
    agents[machine_id] = websocket
    try:
        # Inside the try so a failing DB update still unregisters the socket below.
        await set_machine_online(machine_id, True)
        log.info(f"Agent connected: {machine['name']} ({machine_id})")
        await _agent_resume_viewers(websocket, machine_id, machine["name"])

        while True:
            msg = await websocket.receive()

            # Client sent a close frame — exit cleanly
            if msg.get("type") == "websocket.disconnect":
                break

            if msg.get("bytes"):
                # Binary = JPEG frame → forward to all viewers watching this machine
                await _agent_forward_frame(machine_id, msg["bytes"])
            elif msg.get("text"):
                # JSON message from agent (script output, status, etc.)
                await _agent_handle_text(machine_id, msg["text"])

    except WebSocketDisconnect:
        pass
    except Exception as e:
        log.warning(f"Agent {machine_id} error: {e}")
    finally:
        # Only the socket still registered for this machine tears down shared state.
        if agents.get(machine_id) is websocket:
            agents.pop(machine_id, None)
            # Notify any connected viewers that the agent disconnected
            for sid in _agent_viewer_sessions(machine_id):
                viewer_ws = viewers.get(sid)
                if viewer_ws is not None:
                    await send_json(viewer_ws, {"type": "agent_disconnected"})
            try:
                await set_machine_online(machine_id, False)
            except Exception as e:
                log.warning(f"Agent {machine_id}: failed to mark offline: {e}")
        log.info(f"Agent disconnected: {machine['name']} ({machine_id})")
|
|
|
|
|
|
# ── Viewer WebSocket endpoint ─────────────────────────────────────────────────
|
|
|
|
@app.websocket("/ws/viewer")
async def viewer_endpoint(
    websocket: WebSocket,
    session_id: str = Query(...),
    viewer_token: str = Query(...),
):
    """Serve one viewer of an open session and relay its control events to the agent.

    The socket closes with 4001 when the session/token pair does not resolve to
    an open session. It closes with 4002 when the session has no machine.

    Each JSON object the viewer sends is stamped with this ``session_id``. It is
    then forwarded to the machine's agent if one is connected. ``exec_script``
    events are forwarded only when the viewer's role is ``"admin"``. Messages
    that are not JSON objects are ignored.

    On exit, the registry entries are removed and the agent gets ``stop_stream``.
    Both steps are skipped when a newer socket has taken over the same session id.
    """
    session = await validate_viewer(session_id, viewer_token)
    if not session:
        await websocket.close(code=4001, reason="Session not found or expired")
        return

    machine_id = str(session["machine_id"]) if session["machine_id"] else None
    if not machine_id:
        await websocket.close(code=4002, reason="No machine associated with session")
        return

    viewer_role = session.get("viewer_role", "user")

    await websocket.accept()
    viewers[session_id] = websocket
    session_to_machine[session_id] = machine_id
    viewer_roles[session_id] = viewer_role
    log.info(f"Viewer connected to session {session_id} (machine {machine_id}, role {viewer_role})")

    agent_ws = agents.get(machine_id)
    if agent_ws is not None:
        # Tell agent to start streaming for this session
        await send_json(agent_ws, {"type": "start_stream", "session_id": session_id})
        await send_json(websocket, {"type": "agent_connected", "machine_name": session["machine_name"]})
    else:
        await send_json(websocket, {"type": "agent_offline"})

    try:
        while True:
            text = await websocket.receive_text()
            try:
                event = json.loads(text)
            except json.JSONDecodeError:
                continue
            if not isinstance(event, dict):
                # Lists/strings/numbers cannot be stamped with a session id and
                # are not control events; ignore them instead of dropping the viewer.
                continue
            event["session_id"] = session_id
            # exec_script is restricted to admin viewers only
            if event.get("type") == "exec_script" and viewer_roles.get(session_id) != "admin":
                await send_json(websocket, {
                    "type": "error",
                    "message": "exec_script requires admin role",
                })
                continue
            # Forward control events to the agent
            agent_ws = agents.get(machine_id)
            if agent_ws is not None:
                await send_json(agent_ws, event)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        log.warning(f"Viewer {session_id} error: {e}")
    finally:
        # If a newer viewer socket now owns this session id, leave its state
        # alone. An entry already removed elsewhere (e.g. after a failed frame
        # send) still gets its remaining state cleaned up here.
        current = viewers.get(session_id)
        if current is None or current is websocket:
            viewers.pop(session_id, None)
            session_to_machine.pop(session_id, None)
            viewer_roles.pop(session_id, None)
            agent_ws = agents.get(machine_id)
            if agent_ws is not None:
                await send_json(agent_ws, {"type": "stop_stream", "session_id": session_id})
        log.info(f"Viewer disconnected from session {session_id}")
|
|
|
|
|
|
# ── Health / status endpoints ─────────────────────────────────────────────────
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe reporting how many agent and viewer sockets are registered."""
    agent_count, viewer_count = len(agents), len(viewers)
    return {"status": "ok", "agents": agent_count, "viewers": viewer_count}
|
|
|
|
|
|
@app.get("/status/{machine_id}")
async def machine_status(machine_id: str):
    """Report whether an agent socket is currently registered for *machine_id*.

    Only the in-memory registry is consulted; the database ``is_online`` flag is
    not read. The lookup uses the path value exactly as given.
    NOTE(review): this endpoint is unauthenticated, so any caller can probe ids.
    """
    is_connected = machine_id in agents
    return {"online": is_connected}
|