""" RemoteLink WebSocket Relay Bridges agent (remote machine) ↔ viewer (browser) connections. Agent connects: ws://relay:8765/ws/agent?machine_id=&access_key= Viewer connects: ws://relay:8765/ws/viewer?session_id=&viewer_token= """ import asyncio import json import logging import os from contextlib import asynccontextmanager from typing import Optional from uuid import UUID import asyncpg from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query, HTTPException from fastapi.middleware.cors import CORSMiddleware logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") log = logging.getLogger("relay") DATABASE_URL = os.environ["DATABASE_URL"] # ── In-memory connection registry ──────────────────────────────────────────── # machine_id (str) → WebSocket agents: dict[str, WebSocket] = {} # session_id (str) → WebSocket viewers: dict[str, WebSocket] = {} # session_id → machine_id session_to_machine: dict[str, str] = {} db_pool: Optional[asyncpg.Pool] = None @asynccontextmanager async def lifespan(app: FastAPI): global db_pool db_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10) log.info("Database pool ready") yield await db_pool.close() app = FastAPI(lifespan=lifespan) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ── Helpers ─────────────────────────────────────────────────────────────────── async def validate_agent(machine_id: str, access_key: str) -> Optional[dict]: async with db_pool.acquire() as conn: row = await conn.fetchrow( "SELECT id, name, user_id FROM machines WHERE id = $1 AND access_key = $2", machine_id, access_key, ) return dict(row) if row else None async def validate_viewer(session_id: str, viewer_token: str) -> Optional[dict]: async with db_pool.acquire() as conn: row = await conn.fetchrow( """SELECT id, machine_id, machine_name FROM sessions WHERE id = $1 AND viewer_token = $2 AND ended_at IS NULL""", session_id, viewer_token, ) return dict(row) if row else None async def set_machine_online(machine_id: str, online: bool): async with db_pool.acquire() as conn: await conn.execute( "UPDATE machines SET is_online = $1, last_seen = now() WHERE id = $2", online, machine_id, ) async def send_json(ws: WebSocket, data: dict): try: await ws.send_text(json.dumps(data)) except Exception: pass # ── Agent WebSocket endpoint ────────────────────────────────────────────────── @app.websocket("/ws/agent") async def agent_endpoint( websocket: WebSocket, machine_id: str = Query(...), access_key: str = Query(...), ): machine = await validate_agent(machine_id, access_key) if not machine: await websocket.close(code=4001, reason="Unauthorized") return await websocket.accept() agents[machine_id] = websocket await set_machine_online(machine_id, True) log.info(f"Agent connected: {machine['name']} ({machine_id})") try: while True: msg = await websocket.receive() # Client sent a close frame — exit cleanly if msg.get("type") == "websocket.disconnect": break if "bytes" in msg and msg["bytes"]: # Binary = JPEG frame → forward to all viewers watching this machine frame_data = msg["bytes"] for sid, mid in list(session_to_machine.items()): if mid == machine_id and sid in viewers: try: await viewers[sid].send_bytes(frame_data) except Exception: viewers.pop(sid, None) session_to_machine.pop(sid, None) elif "text" in msg and msg["text"]: # JSON message from agent (script output, status, etc.) try: data = json.loads(msg["text"]) # Forward script output to the relevant viewer if data.get("type") == "script_output": session_id = data.get("session_id") if session_id and session_id in viewers: await send_json(viewers[session_id], data) elif data.get("type") == "ping": await set_machine_online(machine_id, True) except json.JSONDecodeError: pass except WebSocketDisconnect: pass except Exception as e: log.warning(f"Agent {machine_id} error: {e}") finally: agents.pop(machine_id, None) await set_machine_online(machine_id, False) # Notify any connected viewers that the agent disconnected for sid, mid in list(session_to_machine.items()): if mid == machine_id and sid in viewers: await send_json(viewers[sid], {"type": "agent_disconnected"}) log.info(f"Agent disconnected: {machine['name']} ({machine_id})") # ── Viewer WebSocket endpoint ───────────────────────────────────────────────── @app.websocket("/ws/viewer") async def viewer_endpoint( websocket: WebSocket, session_id: str = Query(...), viewer_token: str = Query(...), ): session = await validate_viewer(session_id, viewer_token) if not session: await websocket.close(code=4001, reason="Session not found or expired") return machine_id = str(session["machine_id"]) if session["machine_id"] else None if not machine_id: await websocket.close(code=4002, reason="No machine associated with session") return await websocket.accept() viewers[session_id] = websocket session_to_machine[session_id] = machine_id log.info(f"Viewer connected to session {session_id} (machine {machine_id})") if machine_id in agents: # Tell agent to start streaming for this session await send_json(agents[machine_id], { "type": "start_stream", "session_id": session_id, }) await send_json(websocket, {"type": "agent_connected", "machine_name": session["machine_name"]}) else: await send_json(websocket, {"type": "agent_offline"}) try: while True: text = await websocket.receive_text() try: event = json.loads(text) event["session_id"] = session_id # Forward control events to the agent if machine_id in agents: await send_json(agents[machine_id], event) except json.JSONDecodeError: pass except WebSocketDisconnect: pass except Exception as e: log.warning(f"Viewer {session_id} error: {e}") finally: viewers.pop(session_id, None) session_to_machine.pop(session_id, None) if machine_id in agents: await send_json(agents[machine_id], {"type": "stop_stream", "session_id": session_id}) log.info(f"Viewer disconnected from session {session_id}") # ── Health / status endpoints ───────────────────────────────────────────────── @app.get("/health") async def health(): return {"status": "ok", "agents": len(agents), "viewers": len(viewers)} @app.get("/status/{machine_id}") async def machine_status(machine_id: str): return {"online": machine_id in agents}