Files
remotelink-docker/relay/main.py
monoadmin 27673daa63 Fix critical and high security issues
- exec_script: relay now enforces admin role before forwarding to agent
- relay CORS: restrict allow_origins via ALLOWED_ORIGINS env var (docker-compose passes app URL)
- session-code: replace Math.random() with crypto.randomInt, add per-key rate limit (10 req/min)
- sessions GET: fix IDOR — users can only read their own sessions (admins see all)
- signal API: validate session ownership on create; enforce ownerUserId on all subsequent actions

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 23:38:03 -07:00

242 lines
9.0 KiB
Python

"""
RemoteLink WebSocket Relay
Bridges agent (remote machine) ↔ viewer (browser) connections.
Agent connects: ws://relay:8765/ws/agent?machine_id=<uuid>&access_key=<hex>
Viewer connects: ws://relay:8765/ws/viewer?session_id=<uuid>&viewer_token=<uuid>
"""
import asyncio
import json
import logging
import os
from contextlib import asynccontextmanager
from typing import Optional
from uuid import UUID
import asyncpg
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query, HTTPException
from fastapi.middleware.cors import CORSMiddleware
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger("relay")

# Required database DSN; a missing variable raises KeyError at import time.
DATABASE_URL = os.environ["DATABASE_URL"]

# Comma-separated list of allowed origins for CORS, e.g. "https://app.example.com".
# Entries are whitespace-trimmed and blanks are dropped, so an unset/empty
# variable produces an empty list.
ALLOWED_ORIGINS = [
    origin
    for origin in map(str.strip, os.environ.get("ALLOWED_ORIGINS", "").split(","))
    if origin
]

# ── In-memory connection registry ────────────────────────────────────────────
# Plain module-level dicts, so they are local to this process.
agents: dict[str, WebSocket] = {}        # machine_id -> agent socket
viewers: dict[str, WebSocket] = {}       # session_id -> viewer socket
session_to_machine: dict[str, str] = {}  # session_id -> machine_id
viewer_roles: dict[str, str] = {}        # session_id -> "admin" | "user"

# asyncpg connection pool; created and closed by `lifespan`.
db_pool: Optional[asyncpg.Pool] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the asyncpg pool for the lifetime of the application.

    The pool is created before the app starts serving and stored in the
    module-level ``db_pool``. It is closed when the context exits, including
    when the exit happens through an exception raised at the ``yield``.
    """
    global db_pool
    db_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    log.info("Database pool ready")
    try:
        yield
    finally:
        # In a finally so the pool's connections are released on every exit path.
        await db_pool.close()
app = FastAPI(lifespan=lifespan)

# Credentials are only allowed for an explicit origin allowlist. With the "*"
# fallback plus allow_credentials=True, Starlette reflects the requesting
# Origin on cookie-bearing requests. That would let any site make credentialed
# calls and defeat the ALLOWED_ORIGINS restriction.
if not ALLOWED_ORIGINS:
    log.warning("ALLOWED_ORIGINS is not set; allowing all origins without credentials")
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS if ALLOWED_ORIGINS else ["*"],
    allow_credentials=bool(ALLOWED_ORIGINS),
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Helpers ───────────────────────────────────────────────────────────────────
async def validate_agent(machine_id: str, access_key: str) -> Optional[dict]:
    """Look up the machine whose id and access key both match.

    Returns the row as a dict with keys ``id``, ``name`` and ``user_id``,
    or None when no machine matches these credentials.
    """
    query = "SELECT id, name, user_id FROM machines WHERE id = $1 AND access_key = $2"
    async with db_pool.acquire() as conn:
        record = await conn.fetchrow(query, machine_id, access_key)
    if not record:
        return None
    return dict(record)
async def validate_viewer(session_id: str, viewer_token: str) -> Optional[dict]:
    """Resolve a still-open session from its id and viewer token.

    Only sessions whose ``ended_at`` is NULL match. The returned dict carries
    ``id``, ``machine_id``, ``machine_name`` and ``viewer_role``. The role is
    taken from the user referenced by ``viewer_user_id`` and falls back to
    ``'user'`` when it is missing (no such user, or a NULL role).
    Returns None when nothing matches.
    """
    sql = """SELECT s.id, s.machine_id, s.machine_name, COALESCE(u.role, 'user') AS viewer_role
FROM sessions s
LEFT JOIN users u ON u.id = s.viewer_user_id
WHERE s.id = $1 AND s.viewer_token = $2 AND s.ended_at IS NULL"""
    async with db_pool.acquire() as conn:
        record = await conn.fetchrow(sql, session_id, viewer_token)
    if not record:
        return None
    return dict(record)
async def set_machine_online(machine_id: str, online: bool):
    """Persist a machine's ``is_online`` flag and bump its ``last_seen`` to now()."""
    # Pool.execute borrows a connection for this single statement and releases it.
    await db_pool.execute(
        "UPDATE machines SET is_online = $1, last_seen = now() WHERE id = $2",
        online,
        machine_id,
    )
async def send_json(ws: WebSocket, data: dict):
    """Serialize ``data`` as JSON and send it as a text frame, best-effort.

    Any failure (serialization, or a closed/broken socket) is swallowed, so
    callers can notify peers without guarding each send.
    """
    try:
        payload = json.dumps(data)
        await ws.send_text(payload)
    except Exception:
        # Deliberately silent: delivery is best-effort.
        return None
# ── Agent WebSocket endpoint ──────────────────────────────────────────────────
async def _relay_frame_to_viewers(machine_id: str, frame: bytes) -> None:
    """Forward one binary frame from ``machine_id``'s agent to every viewer of it.

    A viewer whose send fails is unregistered, unless a newer connection has
    already taken over that session id.
    """
    for sid, mid in list(session_to_machine.items()):
        if mid != machine_id:
            continue
        viewer_ws = viewers.get(sid)
        if viewer_ws is None:
            continue
        try:
            await viewer_ws.send_bytes(frame)
        except Exception:
            if viewers.get(sid) is viewer_ws:
                del viewers[sid]
                session_to_machine.pop(sid, None)
                viewer_roles.pop(sid, None)


async def _handle_agent_message(machine_id: str, text: str) -> None:
    """Act on one JSON text message received from ``machine_id``'s agent.

    ``script_output`` is delivered to the viewer named by its ``session_id``,
    but only when that session is bound to this same machine. Without that
    check, any authenticated agent could push messages into other users'
    viewers. ``ping`` refreshes the machine's online flag. Malformed JSON,
    non-object payloads and other message types are ignored.
    """
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return
    if not isinstance(data, dict):
        return
    msg_type = data.get("type")
    if msg_type == "script_output":
        session_id = data.get("session_id")
        if isinstance(session_id, str) and session_to_machine.get(session_id) == machine_id:
            viewer_ws = viewers.get(session_id)
            if viewer_ws is not None:
                await send_json(viewer_ws, data)
    elif msg_type == "ping":
        await set_machine_online(machine_id, True)


@app.websocket("/ws/agent")
async def agent_endpoint(
    websocket: WebSocket,
    machine_id: str = Query(...),
    access_key: str = Query(...),
):
    """Authenticate an agent by machine id + access key and relay its traffic.

    Invalid credentials close the socket with code 4001 before it is accepted.
    Binary messages are forwarded to all viewers of this machine. Text messages
    go to ``_handle_agent_message``. On exit the machine is deregistered,
    marked offline and its viewers receive ``agent_disconnected``, unless a
    newer connection for the same machine has replaced this one.
    """
    machine = await validate_agent(machine_id, access_key)
    if not machine:
        await websocket.close(code=4001, reason="Unauthorized")
        return

    await websocket.accept()
    # A newer agent connection for the same machine overwrites the older entry.
    agents[machine_id] = websocket
    try:
        # Inside the try so a DB failure here still runs the cleanup below.
        await set_machine_online(machine_id, True)
        log.info(f"Agent connected: {machine['name']} ({machine_id})")
        while True:
            msg = await websocket.receive()
            # Client sent a close frame — exit cleanly
            if msg.get("type") == "websocket.disconnect":
                break
            frame = msg.get("bytes")
            text = msg.get("text")
            if frame:
                await _relay_frame_to_viewers(machine_id, frame)
            elif text:
                await _handle_agent_message(machine_id, text)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        log.warning(f"Agent {machine_id} error: {e}")
    finally:
        # Skip teardown if a newer connection now owns this machine id;
        # otherwise we would deregister it and mark a live machine offline.
        if agents.get(machine_id) is websocket:
            del agents[machine_id]
            try:
                await set_machine_online(machine_id, False)
            except Exception as e:
                log.warning(f"Agent {machine_id}: could not mark offline: {e}")
            # Notify any connected viewers that the agent disconnected
            for sid, mid in list(session_to_machine.items()):
                viewer_ws = viewers.get(sid)
                if mid == machine_id and viewer_ws is not None:
                    await send_json(viewer_ws, {"type": "agent_disconnected"})
        log.info(f"Agent disconnected: {machine['name']} ({machine_id})")
# ── Viewer WebSocket endpoint ─────────────────────────────────────────────────
@app.websocket("/ws/viewer")
async def viewer_endpoint(
    websocket: WebSocket,
    session_id: str = Query(...),
    viewer_token: str = Query(...),
):
    """Attach a browser viewer to an open session and relay its control events.

    An unknown or ended session closes with 4001. A session without a machine
    closes with 4002. When the machine's agent is connected, it is told to
    ``start_stream`` and the viewer gets ``agent_connected``; otherwise the
    viewer gets ``agent_offline``.

    Every JSON-object text message is stamped with this session's id and
    forwarded to the agent. ``exec_script`` is only forwarded for admin viewers.
    On exit the session is deregistered and the agent told to ``stop_stream``,
    unless a newer connection for the same session has taken over.
    """
    session = await validate_viewer(session_id, viewer_token)
    if not session:
        await websocket.close(code=4001, reason="Session not found or expired")
        return
    machine_id = str(session["machine_id"]) if session["machine_id"] else None
    if not machine_id:
        await websocket.close(code=4002, reason="No machine associated with session")
        return
    # Role validated for *this* connection; used for the exec_script gate below.
    viewer_role = session.get("viewer_role", "user")

    await websocket.accept()
    viewers[session_id] = websocket
    session_to_machine[session_id] = machine_id
    viewer_roles[session_id] = viewer_role
    log.info(f"Viewer connected to session {session_id} (machine {machine_id}, role {viewer_role})")

    agent_ws = agents.get(machine_id)
    if agent_ws is not None:
        # Tell agent to start streaming for this session
        await send_json(agent_ws, {
            "type": "start_stream",
            "session_id": session_id,
        })
        await send_json(websocket, {"type": "agent_connected", "machine_name": session["machine_name"]})
    else:
        await send_json(websocket, {"type": "agent_offline"})

    try:
        while True:
            text = await websocket.receive_text()
            try:
                event = json.loads(text)
            except json.JSONDecodeError:
                continue
            if not isinstance(event, dict):
                # Valid JSON but not an object: nothing we can stamp or forward.
                continue
            # Always stamp the authenticated session id; never trust the client's.
            event["session_id"] = session_id
            # exec_script is restricted to admin viewers only
            if event.get("type") == "exec_script" and viewer_role != "admin":
                await send_json(websocket, {
                    "type": "error",
                    "message": "exec_script requires admin role",
                })
                continue
            # Forward control events to the agent
            agent_ws = agents.get(machine_id)
            if agent_ws is not None:
                await send_json(agent_ws, event)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        log.warning(f"Viewer {session_id} error: {e}")
    finally:
        # Skip teardown if a newer viewer has taken over this session id (e.g. a
        # reconnect). Tearing down would drop its registration and stop its
        # stream. A None entry means the relay already dropped our socket.
        current = viewers.get(session_id)
        if current is None or current is websocket:
            viewers.pop(session_id, None)
            session_to_machine.pop(session_id, None)
            viewer_roles.pop(session_id, None)
            agent_ws = agents.get(machine_id)
            if agent_ws is not None:
                await send_json(agent_ws, {"type": "stop_stream", "session_id": session_id})
        log.info(f"Viewer disconnected from session {session_id}")
# ── Health / status endpoints ─────────────────────────────────────────────────
@app.get("/health")
async def health():
    """Liveness probe that also reports how many agents and viewers are connected."""
    agent_count = len(agents)
    viewer_count = len(viewers)
    return {"status": "ok", "agents": agent_count, "viewers": viewer_count}
@app.get("/status/{machine_id}")
async def machine_status(machine_id: str):
    """Report whether an agent for ``machine_id`` is connected to this relay process.

    NOTE(review): no authentication is applied here; confirm that exposing
    per-machine presence to any caller is intended.
    """
    is_connected = machine_id in agents
    return {"online": is_connected}