- WebSocket relay service (FastAPI) bridges agents and viewers
- Python agent with screen capture (mss), input control (pynput), script execution, and auto-reconnect
- Windows service wrapper, PyInstaller spec, NSIS installer for silent mass deployment (RemoteLink-Setup.exe /S /SERVER= /ENROLL=)
- Enrollment token system: admin generates tokens, agents self-register
- Real WebSocket viewer replaces simulated canvas
- Linux agent binary served from /downloads/remotelink-agent-linux
- DB migration 0002: viewer_token on sessions, enrollment_tokens table
- Sign-up pages cleaned up (invite-only redirect)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
221 lines
7.8 KiB
Python
221 lines
7.8 KiB
Python
"""
|
|
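RemoteLink WebSocket Relay

Bridges agent (remote machine) ↔ viewer (browser) connections: binary frames
from an agent are forwarded to the viewers watching that machine, and JSON
control events from a viewer are forwarded to its machine's agent.

Agent connects:  ws://relay:8765/ws/agent?machine_id=<uuid>&access_key=<hex>
Viewer connects: ws://relay:8765/ws/viewer?session_id=<uuid>&viewer_token=<uuid>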
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
from typing import Optional
|
|
from uuid import UUID
|
|
|
|
import asyncpg
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.INFO)
log = logging.getLogger("relay")

# Postgres connection string. Read eagerly: importing this module without the
# variable set raises KeyError immediately rather than failing at first query.
DATABASE_URL = os.environ["DATABASE_URL"]

# ── In-memory connection registry ────────────────────────────────────────────
# Plain module-level dicts, so the state is per-process: running several
# worker processes would split agents and viewers across separate registries.

# machine_id (str) → the connected agent's WebSocket
agents: dict[str, WebSocket] = {}

# session_id (str) → the connected viewer's WebSocket (at most one per session)
viewers: dict[str, WebSocket] = {}

# session_id (str) → machine_id (str) that the session is watching
session_to_machine: dict[str, str] = {}

# Shared asyncpg pool; created by lifespan() on startup, None until then.
db_pool: Optional[asyncpg.Pool] = None
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the shared database pool for the lifetime of the application.

    On startup, creates the asyncpg pool and stores it in the module-global
    ``db_pool`` used by the DB helpers. On shutdown, closes it.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    global db_pool
    db_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    log.info("Database pool ready")
    try:
        yield
    finally:
        # An exception (including cancellation) raised while the app runs is
        # re-raised at the ``yield``; without ``finally`` the pool would be
        # left open in that case.
        await db_pool.close()
|
|
|
|
|
|
app = FastAPI(lifespan=lifespan)

# Allowed browser origins for the HTTP endpoints, as a comma-separated list in
# CORS_ALLOW_ORIGINS. Unset (or empty) keeps the previous behaviour: any origin.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # CORS forbids a wildcard origin together with credentials; Starlette
    # copes by echoing back whatever Origin the caller sent, which would give
    # every site credentialed access. Nothing in this relay reads cookies, so
    # only enable credentials when an explicit origin list is configured.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
async def validate_agent(machine_id: str, access_key: str) -> Optional[dict]:
    """Check an agent's claimed machine id against its access key.

    Args:
        machine_id: Machine id from the agent's query string; documented as a
            UUID in the module docstring.
        access_key: The machine's secret from the agent's query string.

    Returns:
        A dict with the machine's ``id``, ``name`` and ``user_id`` when a
        ``machines`` row matches both values, otherwise ``None``. A
        ``machine_id`` that is not a valid UUID returns ``None`` without
        querying the database.
    """
    # If machines.id is a uuid column (assumed — TODO confirm against the
    # schema), asyncpg raises while encoding a malformed string instead of
    # matching no row, which would crash the WebSocket handshake. Reject such
    # input up front and pass the canonical UUID text form to the query.
    try:
        canonical_id = str(UUID(machine_id))
    except ValueError:
        return None

    async with db_pool.acquire() as conn:
        row = await conn.fetchrow(
            "SELECT id, name, user_id FROM machines WHERE id = $1 AND access_key = $2",
            canonical_id, access_key,
        )
    return dict(row) if row else None
|
|
|
|
|
|
async def validate_viewer(session_id: str, viewer_token: str) -> Optional[dict]:
    """Check a viewer's credentials against a session that has not ended.

    Args:
        session_id: Session id from the viewer's query string; documented as
            a UUID in the module docstring.
        viewer_token: Per-session token from the query string; also
            documented as a UUID.

    Returns:
        A dict with the session's ``id``, ``machine_id`` and ``machine_name``
        when a ``sessions`` row matches both values and has ``ended_at IS
        NULL``, otherwise ``None``. Values that are not valid UUIDs return
        ``None`` without querying the database.
    """
    # If these columns are uuid-typed (assumed — TODO confirm against the
    # migrations), asyncpg raises while encoding a malformed string instead
    # of matching no row. Reject such input early and query with canonical
    # UUID text.
    try:
        canonical_session = str(UUID(session_id))
        canonical_token = str(UUID(viewer_token))
    except ValueError:
        return None

    async with db_pool.acquire() as conn:
        row = await conn.fetchrow(
            """SELECT id, machine_id, machine_name
               FROM sessions
               WHERE id = $1 AND viewer_token = $2 AND ended_at IS NULL""",
            canonical_session, canonical_token,
        )
    return dict(row) if row else None
|
|
|
|
|
|
async def set_machine_online(machine_id: str, online: bool):
    """Record whether a machine's agent is connected and refresh ``last_seen``.

    Args:
        machine_id: Id of the ``machines`` row to update.
        online: Value written to ``machines.is_online``.
    """
    sql = "UPDATE machines SET is_online = $1, last_seen = now() WHERE id = $2"
    async with db_pool.acquire() as connection:
        await connection.execute(sql, online, machine_id)
|
|
|
|
|
|
async def send_json(ws: WebSocket, data: dict) -> bool:
    """Serialize ``data`` to JSON and send it as a text frame, best effort.

    Errors are never raised to the caller, because the peer may already be
    gone and callers use this for fire-and-forget notifications. They are
    logged at DEBUG level so failures (including payloads that cannot be
    serialized) still leave a trace.

    Args:
        ws: Destination WebSocket.
        data: JSON-serializable payload.

    Returns:
        True if serialization and ``send_text`` completed without raising,
        False otherwise. Existing callers may ignore the result.
    """
    try:
        await ws.send_text(json.dumps(data))
    except Exception as exc:
        log.debug("send_json failed: %s", exc)
        return False
    return True
|
|
|
|
|
|
# ── Agent WebSocket endpoint ──────────────────────────────────────────────────
|
|
|
|
async def _relay_agent_frame(machine_key: str, frame: bytes) -> None:
    """Forward one binary frame from an agent to every viewer of that machine."""
    for sid, mid in list(session_to_machine.items()):
        if mid != machine_key:
            continue
        viewer = viewers.get(sid)
        if viewer is None:
            continue
        try:
            await viewer.send_bytes(frame)
        except Exception:
            # Stop sending to a broken viewer, but only evict it if it is
            # still the registered socket: a replacement may have connected
            # for the same session while we were awaiting.
            if viewers.get(sid) is viewer:
                viewers.pop(sid, None)
                session_to_machine.pop(sid, None)


async def _handle_agent_message(machine_key: str, text: str) -> None:
    """Act on one JSON text message (script output, ping, ...) from an agent."""
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        log.debug("Ignoring non-JSON message from agent %s", machine_key)
        return
    if not isinstance(data, dict):
        # Valid JSON that is not an object (list, string, number) has no
        # "type"; calling .get() on it would raise and drop the agent.
        return

    msg_type = data.get("type")
    if msg_type == "script_output":
        session_id = data.get("session_id")
        # Only deliver to a session that actually watches this machine, so
        # an agent cannot push output into another machine's viewer.
        if isinstance(session_id, str) and session_to_machine.get(session_id) == machine_key:
            viewer = viewers.get(session_id)
            if viewer is not None:
                await send_json(viewer, data)
    elif msg_type == "ping":
        await set_machine_online(machine_key, True)


@app.websocket("/ws/agent")
async def agent_endpoint(
    websocket: WebSocket,
    machine_id: str = Query(...),
    access_key: str = Query(...),
):
    """Authenticate an agent, register it, and relay its traffic until it leaves.

    Binary messages are forwarded to every viewer watching this machine; text
    messages are parsed as JSON and handled by ``_handle_agent_message``.
    Closes with code 4001 if ``machine_id``/``access_key`` do not validate.
    """
    machine = await validate_agent(machine_id, access_key)
    if not machine:
        await websocket.close(code=4001, reason="Unauthorized")
        return

    # Key the registry by str() of the DB id, the same form viewer_endpoint
    # stores in session_to_machine, so lookups match even if the agent sent
    # a differently formatted id (e.g. upper-case hex).
    machine_key = str(machine["id"])

    await websocket.accept()
    agents[machine_key] = websocket
    await set_machine_online(machine_key, True)
    log.info(f"Agent connected: {machine['name']} ({machine_key})")
    # NOTE(review): viewers that connected while this machine was offline
    # were told "agent_offline" and are not sent start_stream from here;
    # confirm whether the agent/viewer resume streaming on their own.

    try:
        while True:
            msg = await websocket.receive()
            # receive() reports a disconnect as a message rather than raising
            # WebSocketDisconnect; calling it again would raise RuntimeError.
            if msg["type"] == "websocket.disconnect":
                break
            if msg.get("bytes"):
                await _relay_agent_frame(machine_key, msg["bytes"])
            elif msg.get("text"):
                await _handle_agent_message(machine_key, msg["text"])
    except WebSocketDisconnect:
        pass
    except Exception as e:
        log.warning(f"Agent {machine_key} error: {e}")
    finally:
        # An auto-reconnecting agent may already have registered a newer
        # socket under the same key; only tear down state that is still ours.
        if agents.get(machine_key) is websocket:
            agents.pop(machine_key, None)
            # Notify viewers before the DB write so a database error cannot
            # skip the notification.
            for sid, mid in list(session_to_machine.items()):
                if mid == machine_key and sid in viewers:
                    await send_json(viewers[sid], {"type": "agent_disconnected"})
            await set_machine_online(machine_key, False)
        log.info(f"Agent disconnected: {machine['name']} ({machine_key})")
|
|
|
|
|
|
# ── Viewer WebSocket endpoint ─────────────────────────────────────────────────
|
|
|
|
@app.websocket("/ws/viewer")
async def viewer_endpoint(
    websocket: WebSocket,
    session_id: str = Query(...),
    viewer_token: str = Query(...),
):
    """Authenticate a viewer and forward its control events to the session's agent.

    Closes with code 4001 if the session/token pair does not validate, or
    4002 if the session has no machine. On connect, asks the agent (if
    present) to start streaming for this session. On disconnect, asks it to
    stop, unless a newer viewer has taken over the session.
    """
    session = await validate_viewer(session_id, viewer_token)
    if not session:
        await websocket.close(code=4001, reason="Session not found or expired")
        return

    machine_id = str(session["machine_id"]) if session["machine_id"] else None
    if not machine_id:
        await websocket.close(code=4002, reason="No machine associated with session")
        return

    await websocket.accept()
    viewers[session_id] = websocket
    session_to_machine[session_id] = machine_id
    log.info(f"Viewer connected to session {session_id} (machine {machine_id})")

    agent = agents.get(machine_id)
    if agent is not None:
        # Tell agent to start streaming for this session
        await send_json(agent, {
            "type": "start_stream",
            "session_id": session_id,
        })
        await send_json(websocket, {"type": "agent_connected", "machine_name": session["machine_name"]})
    else:
        await send_json(websocket, {"type": "agent_offline"})

    try:
        while True:
            text = await websocket.receive_text()
            try:
                event = json.loads(text)
            except json.JSONDecodeError:
                continue
            if not isinstance(event, dict):
                # Non-object JSON cannot carry a session id; ignore it instead
                # of letting the item assignment below raise and end the session.
                continue
            # Always overwrite with the server-side session id so a viewer can
            # only address its own session.
            event["session_id"] = session_id
            agent = agents.get(machine_id)
            if agent is not None:
                await send_json(agent, event)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        log.warning(f"Viewer {session_id} error: {e}")
    finally:
        # If a newer viewer registered for this session, leave its state and
        # the agent's stream alone. None means our entry was already dropped
        # (e.g. after a failed frame send) and we still owe the stop_stream.
        current = viewers.get(session_id)
        if current is None or current is websocket:
            viewers.pop(session_id, None)
            session_to_machine.pop(session_id, None)
            agent = agents.get(machine_id)
            if agent is not None:
                await send_json(agent, {"type": "stop_stream", "session_id": session_id})
        log.info(f"Viewer disconnected from session {session_id}")
|
|
|
|
|
|
# ── Health / status endpoints ─────────────────────────────────────────────────
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe that also reports how many sockets this relay holds."""
    connected_agents = len(agents)
    connected_viewers = len(viewers)
    return {"status": "ok", "agents": connected_agents, "viewers": connected_viewers}
|
|
|
|
|
|
@app.get("/status/{machine_id}")
async def machine_status(machine_id: str):
    """Report whether an agent for ``machine_id`` is connected to this relay.

    Only the in-memory ``agents`` registry of this process is consulted.
    NOTE(review): there is no authentication here, so anyone who knows a
    machine id can probe its online state.
    """
    is_connected = machine_id in agents
    return {"online": is_connected}
|