Files
remotelink-docker/relay/main.py
monoadmin c361922005 Fix agent crash on missing DISPLAY and relay disconnect error
- Auto-detect DISPLAY on Linux by scanning /tmp/.X11-unix/ sockets,
  falling back to 'w' output, then :0 — runs before mss/pynput import
- ScreenCapture no longer raises on init failure; agent stays connected
  and notifies the viewer with an error message if capture unavailable
- stream_frames skips None frames instead of crashing the WebSocket
- Relay: check for websocket.disconnect message type to avoid
  'Cannot call receive once a disconnect message has been received'

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 16:38:58 -07:00

225 lines
8.0 KiB
Python

"""
RemoteLink WebSocket Relay
Bridges agent (remote machine) ↔ viewer (browser) connections.
Agent connects: ws://relay:8765/ws/agent?machine_id=<uuid>&access_key=<hex>
Viewer connects: ws://relay:8765/ws/viewer?session_id=<uuid>&viewer_token=<uuid>
"""
import asyncio
import json
import logging
import os
from contextlib import asynccontextmanager
from typing import Optional
from uuid import UUID
import asyncpg
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query, HTTPException
from fastapi.middleware.cors import CORSMiddleware
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
log = logging.getLogger("relay")

# Mandatory setting: a missing variable fails fast with KeyError at import time.
DATABASE_URL = os.environ["DATABASE_URL"]

# ── In-memory connection registry ────────────────────────────────────────────
# These maps live only in this process; they are filled and emptied by the
# WebSocket endpoints below.
agents: dict[str, WebSocket] = dict()  # machine_id → agent socket
viewers: dict[str, WebSocket] = dict()  # session_id → viewer socket
session_to_machine: dict[str, str] = dict()  # session_id → machine_id being watched

# Created by lifespan() on startup; None until then.
db_pool: Optional[asyncpg.Pool] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the asyncpg connection pool for the lifetime of the application.

    Creates the module-level ``db_pool`` before the app starts serving and
    closes it on shutdown. The close sits in a ``finally`` so that it also
    runs when the lifespan context is exited with an exception, instead of
    leaking pooled connections.
    """
    global db_pool
    db_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    log.info("Database pool ready")
    try:
        yield
    finally:
        await db_pool.close()
app = FastAPI(lifespan=lifespan)

# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site issue credentialed requests to the HTTP routes — consider pinning an
# explicit origin list once the viewer's host is known.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# ── Helpers ───────────────────────────────────────────────────────────────────
async def validate_agent(machine_id: str, access_key: str) -> Optional[dict]:
    """Authenticate an agent against the ``machines`` table.

    Returns the matching row (``id``, ``name``, ``user_id``) as a plain dict,
    or None when no machine has this id/access-key pair.
    """
    query = "SELECT id, name, user_id FROM machines WHERE id = $1 AND access_key = $2"
    async with db_pool.acquire() as conn:
        record = await conn.fetchrow(query, machine_id, access_key)
    if not record:
        return None
    return dict(record)
async def validate_viewer(session_id: str, viewer_token: str) -> Optional[dict]:
    """Look up an open viewing session by id and viewer token.

    Only sessions whose ``ended_at`` is NULL qualify. Returns the row's
    ``id``, ``machine_id`` and ``machine_name`` as a dict, or None.
    """
    query = (
        "SELECT id, machine_id, machine_name\n"
        "FROM sessions\n"
        "WHERE id = $1 AND viewer_token = $2 AND ended_at IS NULL"
    )
    async with db_pool.acquire() as conn:
        record = await conn.fetchrow(query, session_id, viewer_token)
    if not record:
        return None
    return dict(record)
async def set_machine_online(machine_id: str, online: bool):
    """Store the machine's online flag and stamp ``last_seen`` with now()."""
    statement = "UPDATE machines SET is_online = $1, last_seen = now() WHERE id = $2"
    async with db_pool.acquire() as conn:
        await conn.execute(statement, online, machine_id)
async def send_json(ws: WebSocket, data: dict):
    """Send *data* as a JSON text frame on *ws*, best-effort.

    Any failure (peer already gone, payload not JSON-serialisable, ...) is not
    raised to the caller. The endpoints call this while fanning messages out
    to peers that may have disconnected, and one dead peer must not abort the
    caller's loop. Failures are logged at DEBUG instead of vanishing silently.
    """
    try:
        await ws.send_text(json.dumps(data))
    except Exception as exc:
        log.debug("send_json to %r failed: %s", ws, exc)
# ── Agent WebSocket endpoint ──────────────────────────────────────────────────
def _viewer_sessions(machine_id: str) -> list[str]:
    """Return ids of connected viewer sessions currently watching *machine_id*.

    A list snapshot is returned so callers may ``await`` while iterating
    without being affected by concurrent registry changes.
    """
    return [
        sid
        for sid, mid in session_to_machine.items()
        if mid == machine_id and sid in viewers
    ]


async def _forward_frame(machine_id: str, frame: bytes) -> None:
    """Send one binary frame to every viewer of *machine_id*.

    A viewer whose send fails is removed from the registry, unless a newer
    socket has already been registered under the same session id.
    """
    for sid in _viewer_sessions(machine_id):
        viewer_ws = viewers.get(sid)
        if viewer_ws is None:  # went away while an earlier send was awaited
            continue
        try:
            await viewer_ws.send_bytes(frame)
        except Exception:
            if viewers.get(sid) is viewer_ws:
                viewers.pop(sid, None)
                session_to_machine.pop(sid, None)


async def _handle_agent_text(machine_id: str, text: str) -> None:
    """Handle one text message from the agent serving *machine_id*.

    Malformed JSON and JSON values that are not objects are ignored.
    ``script_output`` is delivered only to a viewer session that is watching
    this same machine; ``ping`` refreshes the machine's online flag.
    """
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return
    if not isinstance(data, dict):
        # A bare list/string would make data.get() raise and drop the agent.
        return
    msg_type = data.get("type")
    if msg_type == "script_output":
        session_id = data.get("session_id")
        # The agent chooses the target session itself, so confirm the session
        # watches *this* machine; otherwise an agent could inject output into
        # a viewer of a different machine.
        if (
            isinstance(session_id, str)
            and session_to_machine.get(session_id) == machine_id
            and session_id in viewers
        ):
            await send_json(viewers[session_id], data)
    elif msg_type == "ping":
        await set_machine_online(machine_id, True)


@app.websocket("/ws/agent")
async def agent_endpoint(
    websocket: WebSocket,
    machine_id: str = Query(...),
    access_key: str = Query(...),
):
    """Accept an agent connection and relay its traffic to viewers.

    The agent is authenticated with ``machine_id`` + ``access_key`` before
    the socket is accepted; failures are closed with code 4001. While
    connected:

    * binary messages (JPEG frames) are forwarded to every viewer whose
      session maps to this machine;
    * text messages are handled as JSON (``script_output`` / ``ping``).

    On (re)connect, viewers already waiting on this machine are resumed:
    the agent gets ``start_stream`` and the viewer gets ``agent_connected``.
    On disconnect, the machine is marked offline and its viewers receive
    ``agent_disconnected``, unless a newer connection for the same machine
    has already replaced this one.
    """
    # Canonicalise so registry keys match the str(UUID) form that the viewer
    # endpoint derives from the sessions table, and reject malformed ids
    # before they reach the database.
    try:
        machine_id = str(UUID(machine_id))
    except ValueError:
        await websocket.close(code=4001, reason="Unauthorized")
        return

    machine = await validate_agent(machine_id, access_key)
    if not machine:
        await websocket.close(code=4001, reason="Unauthorized")
        return

    await websocket.accept()
    agents[machine_id] = websocket
    # Taken in the same step as registering (no await in between). Any viewer
    # registered before this point either found no agent or was served by a
    # connection this one replaces, so it needs a fresh start_stream.
    waiting = _viewer_sessions(machine_id)
    try:
        await set_machine_online(machine_id, True)
        log.info(f"Agent connected: {machine['name']} ({machine_id})")

        for sid in waiting:
            viewer_ws = viewers.get(sid)
            if viewer_ws is None:
                continue
            await send_json(websocket, {"type": "start_stream", "session_id": sid})
            await send_json(viewer_ws, {"type": "agent_connected", "machine_name": machine["name"]})

        while True:
            msg = await websocket.receive()
            # Client sent a close frame — exit cleanly instead of calling
            # receive() again after a disconnect message.
            if msg.get("type") == "websocket.disconnect":
                break
            if msg.get("bytes"):
                await _forward_frame(machine_id, msg["bytes"])
            elif msg.get("text"):
                await _handle_agent_text(machine_id, msg["text"])
    except WebSocketDisconnect:
        pass
    except Exception as e:
        log.warning(f"Agent {machine_id} error: {e}")
    finally:
        # Only tear down state we still own: if the same machine reconnected,
        # the newer socket is registered and must stay online.
        if agents.get(machine_id) is websocket:
            agents.pop(machine_id, None)
            await set_machine_online(machine_id, False)
            for sid in _viewer_sessions(machine_id):
                viewer_ws = viewers.get(sid)
                if viewer_ws is not None:
                    await send_json(viewer_ws, {"type": "agent_disconnected"})
        log.info(f"Agent disconnected: {machine['name']} ({machine_id})")
# ── Viewer WebSocket endpoint ─────────────────────────────────────────────────
@app.websocket("/ws/viewer")
async def viewer_endpoint(
    websocket: WebSocket,
    session_id: str = Query(...),
    viewer_token: str = Query(...),
):
    """Accept a browser viewer for an open session and relay its control events.

    The session must exist, match ``viewer_token`` and not be ended (else
    close 4001), and must reference a machine (else close 4002). After
    accepting:

    * the viewer is told ``agent_connected`` (and the agent ``start_stream``)
      or ``agent_offline``;
    * each JSON-object text message is stamped with this ``session_id`` and
      forwarded to the machine's agent if one is connected. Malformed JSON,
      non-object JSON and binary frames are ignored.

    On disconnect, the viewer is unregistered and the agent gets
    ``stop_stream``, unless a newer socket has taken over this session id.
    """
    session = await validate_viewer(session_id, viewer_token)
    if not session:
        await websocket.close(code=4001, reason="Session not found or expired")
        return
    machine_id = str(session["machine_id"]) if session["machine_id"] else None
    if not machine_id:
        await websocket.close(code=4002, reason="No machine associated with session")
        return

    await websocket.accept()
    viewers[session_id] = websocket
    session_to_machine[session_id] = machine_id
    log.info(f"Viewer connected to session {session_id} (machine {machine_id})")
    try:
        agent_ws = agents.get(machine_id)
        if agent_ws is not None:
            # Tell agent to start streaming for this session
            await send_json(agent_ws, {
                "type": "start_stream",
                "session_id": session_id,
            })
            await send_json(websocket, {"type": "agent_connected", "machine_name": session["machine_name"]})
        else:
            await send_json(websocket, {"type": "agent_offline"})

        while True:
            msg = await websocket.receive()
            if msg.get("type") == "websocket.disconnect":
                break
            text = msg.get("text")
            if not text:
                # Binary or empty frames carry no control event.
                continue
            try:
                event = json.loads(text)
            except json.JSONDecodeError:
                continue
            if not isinstance(event, dict):
                # Item assignment below would raise TypeError on a list/string.
                continue
            # Stamp the authenticated session id so a viewer cannot address
            # another session through the agent.
            event["session_id"] = session_id
            agent_ws = agents.get(machine_id)
            if agent_ws is not None:
                await send_json(agent_ws, event)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        log.warning(f"Viewer {session_id} error: {e}")
    finally:
        # Skip teardown if another socket has re-registered this session id;
        # otherwise we would unregister it and stop its stream.
        current = viewers.get(session_id)
        if current is None or current is websocket:
            viewers.pop(session_id, None)
            session_to_machine.pop(session_id, None)
            agent_ws = agents.get(machine_id)
            if agent_ws is not None:
                await send_json(agent_ws, {"type": "stop_stream", "session_id": session_id})
        log.info(f"Viewer disconnected from session {session_id}")
# ── Health / status endpoints ─────────────────────────────────────────────────
@app.get("/health")
async def health():
    """Liveness probe: report how many agents and viewers this process holds."""
    counts = {"agents": len(agents), "viewers": len(viewers)}
    return {"status": "ok", **counts}
@app.get("/status/{machine_id}")
async def machine_status(machine_id: str):
    """Report whether an agent for *machine_id* is connected to this relay process."""
    is_connected = machine_id in agents
    return {"online": is_connected}