"""
RemoteLink WebSocket Relay

Bridges agent (remote machine) ↔ viewer (browser) connections.

Agent connects:  ws://relay:8765/ws/agent?machine_id=&access_key=
Viewer connects: ws://relay:8765/ws/viewer?session_id=&viewer_token=
"""
import asyncio
import io
import json
import logging
import os
import struct
import time as _time
from contextlib import asynccontextmanager
from typing import Optional
from uuid import UUID

import asyncpg
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query, HTTPException
from fastapi.middleware.cors import CORSMiddleware

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger("relay")

# Required: PostgreSQL DSN used to create the asyncpg pool.
DATABASE_URL = os.environ["DATABASE_URL"]

# Comma-separated list of allowed origins for CORS, e.g. "https://app.example.com"
ALLOWED_ORIGINS = [o.strip() for o in os.environ.get("ALLOWED_ORIGINS", "").split(",") if o.strip()]

# Set RECORDING_DIR to enable session recording (JPEG frame archive).
# Empty string (the default) disables recording.
RECORDING_DIR = os.environ.get("RECORDING_DIR", "")  # e.g. "/recordings"

# ── In-memory connection registry ────────────────────────────────────────────
# These live in this process only; they are not shared between workers.

# machine_id (str) → agent WebSocket
agents: dict[str, WebSocket] = {}

# session_id (str) → viewer WebSocket
viewers: dict[str, WebSocket] = {}

# session_id → machine_id the viewer is watching
session_to_machine: dict[str, str] = {}

# session_id → viewer's user role ("admin" | "user")
viewer_roles: dict[str, str] = {}

# session_id → open file handle for recording
recordings: dict[str, io.BufferedWriter] = {}

# Created in lifespan(); None until the app has started.
db_pool: Optional[asyncpg.Pool] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the DB pool on startup and release resources on shutdown.

    The shutdown path runs in ``finally`` so the pool is closed and any open
    recording files are flushed/closed even if shutdown is abnormal.
    """
    global db_pool
    db_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    log.info("Database pool ready")
    try:
        yield
    finally:
        # Close recordings still open (e.g. viewers connected at shutdown) so
        # their buffered frames reach disk.
        for sid, handle in list(recordings.items()):
            try:
                handle.close()
            except OSError as e:
                log.warning(f"Failed to close recording for {sid}: {e}")
        recordings.clear()
        await db_pool.close()


app = FastAPI(lifespan=lifespan)

# A wildcard origin must not be combined with credentials: browsers reject
# that pairing, and Starlette's CORSMiddleware would otherwise echo back any
# Origin when a cookie is present. So credentials are only enabled when
# explicit origins are configured.
# NOTE(review): CORSMiddleware only applies to HTTP requests (/health,
# /status); WebSocket handshakes are not origin-checked by it — confirm
# whether the ws endpoints need their own Origin check.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS or ["*"],
    allow_credentials=bool(ALLOWED_ORIGINS),
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Helpers ───────────────────────────────────────────────────────────────────

async def _fetch_row(query: str, *args) -> Optional[dict]:
    """Run a single-row lookup and return it as a dict, or None if no row.

    Client-supplied identifiers that the driver/server cannot coerce (for
    example a malformed id for a uuid column) make asyncpg raise rather than
    return no row. Those are treated as "not found" so unauthenticated input
    fails closed instead of surfacing as an unhandled server error. Only the
    exception type is logged, to avoid writing credentials into the log.
    """
    try:
        async with db_pool.acquire() as conn:
            row = await conn.fetchrow(query, *args)
    except (ValueError, asyncpg.PostgresError) as e:
        log.warning(f"Credential lookup rejected ({type(e).__name__})")
        return None
    return dict(row) if row else None


async def validate_agent(machine_id: str, access_key: str) -> Optional[dict]:
    """Return the machine row (id, name, user_id) for valid agent credentials, else None."""
    return await _fetch_row(
        "SELECT id, name, user_id FROM machines WHERE id = $1 AND access_key = $2",
        machine_id, access_key,
    )


async def validate_viewer(session_id: str, viewer_token: str) -> Optional[dict]:
    """Return the open session row for valid viewer credentials, else None.

    The row includes ``viewer_role``, which falls back to 'user' when the
    session has no linked user. Sessions with ``ended_at`` set are rejected.
    """
    return await _fetch_row(
        """SELECT s.id, s.machine_id, s.machine_name,
                  COALESCE(u.role, 'user') AS viewer_role
           FROM sessions s
           LEFT JOIN users u ON u.id = s.viewer_user_id
           WHERE s.id = $1 AND s.viewer_token = $2 AND s.ended_at IS NULL""",
        session_id, viewer_token,
    )


async def set_machine_online(machine_id: str, online: bool):
    """Persist the machine's online flag and bump ``last_seen`` to now()."""
    async with db_pool.acquire() as conn:
        await conn.execute(
            "UPDATE machines SET is_online = $1, last_seen = now() WHERE id = $2",
            online, machine_id,
        )


async def send_json(ws: WebSocket, data: dict):
    """Best-effort JSON send: failures (e.g. a closed socket) are logged at debug and ignored."""
    try:
        await ws.send_text(json.dumps(data))
    except Exception as e:
        log.debug(f"send_json failed: {e}")


# ── Recording helpers ─────────────────────────────────────────────────────────

# Per-frame record header: 8-byte big-endian ms timestamp + 4-byte big-endian length.
_FRAME_HEADER = struct.Struct(">QI")


def _start_recording(session_id: str):
    """Open a .remrec file for this session (simple frame archive format).

    No-op if the session is already being recorded (viewer reconnect, agent
    reconnect), so an open handle is never overwritten and leaked. Failures
    are logged, not raised.
    """
    if session_id in recordings:
        return
    try:
        import pathlib
        rec_dir = pathlib.Path(RECORDING_DIR)
        rec_dir.mkdir(parents=True, exist_ok=True)
        ts = _time.strftime("%Y%m%d_%H%M%S")
        # session_id comes from the client's query string: keep only safe
        # characters so it cannot inject path separators into the filename.
        tag = "".join(c for c in session_id if c.isalnum() or c == "-")[:8] or "session"
        path = rec_dir / f"{ts}_{tag}.remrec"
        handle = open(path, "wb")
        try:
            # File header: magic "RLREC" + version 1
            handle.write(b"RLREC\x01")
        except Exception:
            handle.close()
            raise
        recordings[session_id] = handle
        log.info(f"Recording started: {path}")
    except Exception as e:
        log.warning(f"Failed to start recording for {session_id}: {e}")


def _write_frame(session_id: str, frame_data: bytes):
    """Append a JPEG frame with timestamp to the session's recording, if open.

    Record layout: 8-byte timestamp (ms) + 4-byte length + JPEG bytes.
    NOTE(review): this is synchronous file I/O on the event loop; consider
    offloading if recordings slow down frame relay.
    """
    handle = recordings.get(session_id)
    if handle is None:
        return
    try:
        ts = int(_time.time() * 1000)  # milliseconds
        handle.write(_FRAME_HEADER.pack(ts, len(frame_data)))
        handle.write(frame_data)
    except Exception as e:
        log.warning(f"Recording write error: {e}")


def _stop_recording(session_id: str):
    """Close and forget the session's recording file, if any."""
    handle = recordings.pop(session_id, None)
    if handle is None:
        return
    try:
        handle.close()
        log.info(f"Recording stopped: {session_id}")
    except Exception as e:
        log.warning(f"Failed to close recording for {session_id}: {e}")


# ── Agent WebSocket endpoint ──────────────────────────────────────────────────

# Agent JSON message types that are relayed to viewers.
_AGENT_FORWARD_TYPES = frozenset({
    "script_output",
    "monitor_list",
    "clipboard_content",
    "file_chunk",
    "file_list",
    "chat_message",
})


def _sessions_for_machine(machine_id: str) -> list[str]:
    """Return ids of sessions with a connected viewer watching *machine_id*."""
    return [
        sid for sid, mid in session_to_machine.items()
        if mid == machine_id and sid in viewers
    ]


async def _forward_frame(machine_id: str, frame_data: bytes):
    """Send a binary (JPEG) frame to every viewer of *machine_id*, recording it if enabled."""
    for sid in _sessions_for_machine(machine_id):
        ws = viewers.get(sid)
        if ws is None:  # viewer left while an earlier send was awaiting
            continue
        try:
            await ws.send_bytes(frame_data)
        except Exception:
            # Dead viewer socket: stop sending it frames. The viewer endpoint's
            # own cleanup handles roles, recording and stop_stream.
            if viewers.get(sid) is ws:
                viewers.pop(sid, None)
                session_to_machine.pop(sid, None)
            continue
        if RECORDING_DIR:
            _write_frame(sid, frame_data)


async def _handle_agent_message(machine_id: str, text: str):
    """Dispatch one JSON text message received from the agent for *machine_id*.

    Messages that are not JSON objects are ignored rather than tearing down
    the agent connection. A message addressed to a ``session_id`` is delivered
    only to that session, and only if it is watching this machine; it is never
    broadcast to other viewers.
    """
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return
    if not isinstance(data, dict):
        return
    msg_type = data.get("type")
    if not isinstance(msg_type, str):
        return

    if msg_type == "ping":
        await set_machine_online(machine_id, True)
    elif msg_type in _AGENT_FORWARD_TYPES:
        target = data.get("session_id")
        if not target:
            # Unaddressed: broadcast to all viewers watching this machine
            for sid in _sessions_for_machine(machine_id):
                ws = viewers.get(sid)
                if ws is not None:
                    await send_json(ws, data)
        elif isinstance(target, str) and session_to_machine.get(target) == machine_id:
            ws = viewers.get(target)
            if ws is not None:
                await send_json(ws, data)
        # Otherwise the addressed session is gone or belongs to another
        # machine: drop the message instead of leaking it to other viewers.


async def _resume_waiting_viewers(agent_ws: WebSocket, machine_id: str, machine_name: str):
    """Start streams for viewers that were already waiting on this machine.

    Mirrors what viewer_endpoint does when it finds the agent online, so
    viewers that connected first (or survived an agent reconnect) get a
    stream without reconnecting.
    """
    for sid in _sessions_for_machine(machine_id):
        ws = viewers.get(sid)
        if ws is None:
            continue
        await send_json(agent_ws, {"type": "start_stream", "session_id": sid})
        await send_json(ws, {"type": "agent_connected", "machine_name": machine_name})
        if RECORDING_DIR:
            _start_recording(sid)  # no-op if already recording


@app.websocket("/ws/agent")
async def agent_endpoint(
    websocket: WebSocket,
    machine_id: str = Query(...),
    access_key: str = Query(...),
):
    """Relay one agent connection.

    Binary messages are forwarded to the machine's viewers as frames; JSON
    text messages are handled by _handle_agent_message. Registry and
    online-state cleanup only happens if this connection is still the
    registered agent, so a reconnect is not undone by the stale connection.
    """
    machine = await validate_agent(machine_id, access_key)
    if not machine:
        # NOTE(review): closing before accept() is typically turned into an
        # HTTP 403 handshake rejection, so the client may not see code 4001.
        await websocket.close(code=4001, reason="Unauthorized")
        return

    await websocket.accept()
    try:
        # Registered inside the try so a DB failure below still gets cleaned up.
        agents[machine_id] = websocket
        await set_machine_online(machine_id, True)
        log.info(f"Agent connected: {machine['name']} ({machine_id})")
        await _resume_waiting_viewers(websocket, machine_id, machine["name"])

        while True:
            msg = await websocket.receive()
            # Client sent a close frame — exit cleanly
            if msg.get("type") == "websocket.disconnect":
                break
            if msg.get("bytes"):
                await _forward_frame(machine_id, msg["bytes"])
            elif msg.get("text"):
                await _handle_agent_message(machine_id, msg["text"])
    except WebSocketDisconnect:
        pass
    except Exception as e:
        log.warning(f"Agent {machine_id} error: {e}")
    finally:
        # Only tear down if a newer connection for this machine hasn't replaced us.
        if agents.get(machine_id) is websocket:
            agents.pop(machine_id, None)
            try:
                await set_machine_online(machine_id, False)
            except Exception as e:
                log.warning(f"Failed to mark {machine_id} offline: {e}")
            # Notify any connected viewers that the agent disconnected
            for sid in _sessions_for_machine(machine_id):
                ws = viewers.get(sid)
                if ws is not None:
                    await send_json(ws, {"type": "agent_disconnected"})
        log.info(f"Agent disconnected: {machine['name']} ({machine_id})")


# ── Viewer WebSocket endpoint ─────────────────────────────────────────────────

async def _handle_viewer_message(
    websocket: WebSocket, session_id: str, machine_id: str, viewer_role: str, text: str,
):
    """Validate one viewer control event and forward it to the machine's agent.

    Non-object JSON is ignored. The event is always stamped with the
    authenticated session_id, overriding any value the client sent.
    ``exec_script`` is refused unless the viewer's role is "admin".
    """
    try:
        event = json.loads(text)
    except json.JSONDecodeError:
        return
    if not isinstance(event, dict):
        return
    event["session_id"] = session_id

    # exec_script is restricted to admin viewers only
    if event.get("type") == "exec_script" and viewer_role != "admin":
        await send_json(websocket, {
            "type": "error",
            "message": "exec_script requires admin role",
        })
        return

    # Forward control events to the agent
    agent_ws = agents.get(machine_id)
    if agent_ws is not None:
        await send_json(agent_ws, event)


@app.websocket("/ws/viewer")
async def viewer_endpoint(
    websocket: WebSocket,
    session_id: str = Query(...),
    viewer_token: str = Query(...),
):
    """Relay one viewer connection for an open session.

    Cleanup (registry, recording, stop_stream) is skipped if a newer viewer
    connection has re-registered the same session_id, so a reconnect is not
    undone by the stale connection.
    """
    session = await validate_viewer(session_id, viewer_token)
    if not session:
        # NOTE(review): closing before accept() is typically turned into an
        # HTTP 403 handshake rejection, so the client may not see code 4001.
        await websocket.close(code=4001, reason="Session not found or expired")
        return

    machine_id = str(session["machine_id"]) if session["machine_id"] else None
    if not machine_id:
        await websocket.close(code=4002, reason="No machine associated with session")
        return

    viewer_role = session.get("viewer_role", "user")

    await websocket.accept()
    try:
        viewers[session_id] = websocket
        session_to_machine[session_id] = machine_id
        viewer_roles[session_id] = viewer_role
        log.info(f"Viewer connected to session {session_id} (machine {machine_id}, role {viewer_role})")

        agent_ws = agents.get(machine_id)
        if agent_ws is not None:
            # Tell agent to start streaming for this session
            await send_json(agent_ws, {
                "type": "start_stream",
                "session_id": session_id,
            })
            await send_json(websocket, {"type": "agent_connected", "machine_name": session["machine_name"]})
            # Start recording if enabled
            if RECORDING_DIR:
                _start_recording(session_id)
        else:
            await send_json(websocket, {"type": "agent_offline"})

        while True:
            text = await websocket.receive_text()
            await _handle_viewer_message(websocket, session_id, machine_id, viewer_role, text)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        log.warning(f"Viewer {session_id} error: {e}")
    finally:
        # None means the frame relay already dropped this (dead) socket; a
        # different socket means a newer viewer took over this session.
        current = viewers.get(session_id)
        if current is None or current is websocket:
            viewers.pop(session_id, None)
            session_to_machine.pop(session_id, None)
            viewer_roles.pop(session_id, None)
            if RECORDING_DIR:
                _stop_recording(session_id)
            agent_ws = agents.get(machine_id)
            if agent_ws is not None:
                await send_json(agent_ws, {"type": "stop_stream", "session_id": session_id})
        log.info(f"Viewer disconnected from session {session_id}")


# ── Health / status endpoints ─────────────────────────────────────────────────

@app.get("/health")
async def health():
    """Liveness probe with counts of connected agents and viewers (this process only)."""
    return {"status": "ok", "agents": len(agents), "viewers": len(viewers)}


@app.get("/status/{machine_id}")
async def machine_status(machine_id: str):
    """Report whether an agent for *machine_id* is connected to this process."""
    return {"online": machine_id in agents}