Files
the-other-dude/backend/app/services/vpn_service.py
2026-03-14 16:27:35 -05:00

489 lines
17 KiB
Python

"""WireGuard VPN management service.
Handles key generation, peer management, config file sync, and RouterOS command generation.
"""
import base64
import ipaddress
import json
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
import structlog
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
from cryptography.hazmat.primitives.serialization import (
Encoding,
NoEncryption,
PrivateFormat,
PublicFormat,
)
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.models.device import Device
from app.models.vpn import VpnConfig, VpnPeer
from app.services.crypto import decrypt_credentials, encrypt_credentials, encrypt_credentials_transit
logger = structlog.get_logger(__name__)
# ── Key Generation ──
def generate_wireguard_keypair() -> tuple[str, str]:
    """Create a fresh X25519 keypair for WireGuard.

    Returns:
        ``(private_key_b64, public_key_b64)``: the raw private and public key
        bytes, each standard-base64 encoded.
    """
    secret = X25519PrivateKey.generate()
    raw_private = secret.private_bytes(
        encoding=Encoding.Raw,
        format=PrivateFormat.Raw,
        encryption_algorithm=NoEncryption(),
    )
    raw_public = secret.public_key().public_bytes(
        encoding=Encoding.Raw,
        format=PublicFormat.Raw,
    )
    private_b64 = base64.b64encode(raw_private).decode()
    public_b64 = base64.b64encode(raw_public).decode()
    return private_b64, public_b64
def generate_preshared_key() -> str:
    """Return a new WireGuard preshared key: 32 bytes from ``os.urandom``, base64-encoded."""
    random_bytes = os.urandom(32)
    encoded = base64.b64encode(random_bytes)
    return encoded.decode()
# ── Global Server Key & Subnet Allocation ──
async def _get_or_create_global_server_key(db: AsyncSession) -> tuple[str, str]:
    """Return the global WireGuard server keypair, generating it on first use.

    Both halves live in ``system_settings``:

    * the public key is in ``value`` under ``vpn_server_public_key``;
    * the encrypted private key is in ``encrypted_value`` under
      ``vpn_server_private_key``.

    If either row is missing, a new pair is generated and both rows are upserted
    so they always match.

    A transaction-scoped Postgres advisory lock serialises concurrent callers.
    Two simultaneous first calls therefore cannot generate (and store) different
    keypairs.

    Args:
        db: Active async session. New rows are flushed, not committed.

    Returns:
        ``(private_key_b64, public_key_b64)`` with the private key decrypted.
    """
    from sqlalchemy import text as sa_text

    # Held until the caller's transaction ends (pg_advisory_xact_lock).
    await db.execute(sa_text("SELECT pg_advisory_xact_lock(hashtext('vpn_server_keygen'))"))

    result = await db.execute(
        sa_text(
            "SELECT key, value, encrypted_value FROM system_settings "
            "WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')"
        )
    )
    rows = {row[0]: row for row in result.fetchall()}
    key_bytes = settings.get_encryption_key_bytes()

    if "vpn_server_public_key" in rows and "vpn_server_private_key" in rows:
        public_key = rows["vpn_server_public_key"][1]
        encrypted_private = rows["vpn_server_private_key"][2]
        return decrypt_credentials(encrypted_private, key_bytes), public_key

    # First call on a fresh install: generate and store.
    private_key_b64, public_key_b64 = generate_wireguard_keypair()
    encrypted_private = encrypt_credentials(private_key_b64, key_bytes)

    await db.execute(
        sa_text("""
            INSERT INTO system_settings (key, value, updated_at)
            VALUES ('vpn_server_public_key', :pub, now())
            ON CONFLICT (key) DO UPDATE SET value = :pub, updated_at = now()
        """),
        {"pub": public_key_b64},
    )
    await db.execute(
        sa_text("""
            INSERT INTO system_settings (key, value, encrypted_value, updated_at)
            VALUES ('vpn_server_private_key', NULL, :enc, now())
            ON CONFLICT (key) DO UPDATE SET encrypted_value = :enc, updated_at = now()
        """),
        {"enc": encrypted_private},
    )
    await db.flush()

    # structlog binds the positional message to the ``event`` parameter. Passing
    # ``event=`` again as a keyword raises TypeError ("multiple values for
    # argument 'event'"), so the audit tag goes under a different key.
    logger.info("vpn_global_server_keypair_generated", category="vpn_audit")
    return private_key_b64, public_key_b64
def _allocate_subnet_index_from_used(used: set[int]) -> int:
"""Find the first available subnet index in [1, 255] not in `used`.
Pure function for unit testing. Raises ValueError if pool exhausted.
"""
for i in range(1, 256):
if i not in used:
return i
raise ValueError("VPN subnet pool exhausted")
async def _allocate_subnet_index(db: AsyncSession) -> int:
    """Pick the next free ``subnet_index`` from the ``VpnConfig`` rows in the DB.

    Gap-filling: returns the lowest index in [1, 255] that no config uses yet.

    There is no locking here. Concurrent callers can pick the same index; this
    relies on a UNIQUE constraint on ``subnet_index`` (declared on the model,
    not visible here) to reject the loser.
    """
    taken = (await db.execute(select(VpnConfig.subnet_index))).scalars().all()
    return _allocate_subnet_index_from_used(set(taken))
_VPN_ADDRESS_SPACE = ipaddress.ip_network("10.10.0.0/16")
def _validate_additional_allowed_ips(additional_allowed_ips: str | None) -> None:
"""Reject additional_allowed_ips that overlap the VPN address space (10.10.0.0/16)."""
if not additional_allowed_ips:
return
for entry in additional_allowed_ips.split(","):
entry = entry.strip()
if not entry:
continue
try:
network = ipaddress.ip_network(entry, strict=False)
except ValueError:
continue # let WireGuard reject malformed entries
if network.overlaps(_VPN_ADDRESS_SPACE):
raise ValueError(
"Additional allowed IPs must not overlap the VPN address space (10.10.0.0/16)"
)
# ── Config File Management ──
def _get_wg_config_path() -> Path:
"""Return the path to the shared WireGuard config directory."""
return Path(os.getenv("WIREGUARD_CONFIG_PATH", "/data/wireguard"))
def _write_private_file(path: Path, content: str) -> None:
    """Atomically replace *path* with *content*, readable by the owner only.

    The data is first written to a temp file in the same directory.
    ``tempfile.mkstemp`` creates that file with mode 0600. It is then
    ``os.replace``-d over *path*, so:

    * a concurrent reader never sees a half-written file;
    * the secrets inside are never group- or world-readable.
    """
    import tempfile

    fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=f".{path.name}.", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(content)
        os.replace(tmp_name, path)
    except BaseException:
        Path(tmp_name).unlink(missing_ok=True)
        raise


async def sync_wireguard_config(db: AsyncSession, tenant_id: uuid.UUID) -> None:
    """Regenerate wg0.conf from database state and write it to the shared volume.

    It writes ``<config dir>/wg_confs/wg0.conf``, containing:

    * the tenant's ``[Interface]`` section;
    * one ``[Peer]`` section per enabled peer.

    It then writes ``wg_confs/.reload`` to signal the WireGuard container to
    reload. It does nothing when the tenant has no VPN config or the config is
    disabled.

    NOTE(review): two points to confirm against the deployment model.

    * ``wg0.conf`` is a single fixed path rebuilt from ONE tenant's config and
      peers, so syncing one tenant replaces whatever another tenant last wrote.
    * The early return for a disabled config leaves the previously written
      peers in place.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config or not config.is_enabled:
        return

    key_bytes = settings.get_encryption_key_bytes()
    server_private_key = decrypt_credentials(config.server_private_key, key_bytes)

    result = await db.execute(
        select(VpnPeer).where(VpnPeer.tenant_id == tenant_id, VpnPeer.is_enabled.is_(True))
    )
    peers = result.scalars().all()

    lines = [
        "[Interface]",
        f"Address = {config.server_address}",
        f"ListenPort = {config.server_port}",
        f"PrivateKey = {server_private_key}",
        "",
    ]
    for peer in peers:
        # The peer's own tunnel address, as a single-host route.
        peer_ip = peer.assigned_ip.split("/")[0]
        allowed_ips = [f"{peer_ip}/32"]
        if peer.additional_allowed_ips:
            # Comma-separated additional subnets (e.g. site-to-site routing).
            allowed_ips.extend(
                s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()
            )
        lines.append("[Peer]")
        lines.append(f"PublicKey = {peer.peer_public_key}")
        if peer.preshared_key:
            psk = decrypt_credentials(peer.preshared_key, key_bytes)
            lines.append(f"PresharedKey = {psk}")
        lines.append(f"AllowedIPs = {', '.join(allowed_ips)}")
        lines.append("")

    wg_confs_dir = _get_wg_config_path() / "wg_confs"
    wg_confs_dir.mkdir(parents=True, exist_ok=True)
    # wg0.conf holds the server private key and every preshared key. Write it
    # atomically (the container may read it at any moment) and owner-only.
    _write_private_file(wg_confs_dir / "wg0.conf", "\n".join(lines))

    # Signal WireGuard container to reload config.
    (wg_confs_dir / ".reload").write_text("1")
    logger.info("wireguard config synced", tenant_id=str(tenant_id), peers=len(peers))
# ── Live Status ──
def read_wg_status() -> dict[str, dict]:
"""Read live WireGuard peer status from the shared volume.
The WireGuard container writes wg_status.json every 15 seconds
with output from `wg show wg0 dump`. Returns a dict keyed by
peer public key with handshake timestamp and transfer stats.
"""
status_path = _get_wg_config_path() / "wg_status.json"
if not status_path.exists():
return {}
try:
data = json.loads(status_path.read_text())
return {entry["public_key"]: entry for entry in data}
except (json.JSONDecodeError, KeyError, OSError):
return {}
def get_peer_handshake(wg_status: dict[str, dict], public_key: str) -> Optional[datetime]:
    """Return a peer's last handshake as an aware UTC datetime, if it ever had one.

    *wg_status* is the mapping produced by ``read_wg_status``. ``None`` is
    returned when the peer is absent, has an empty entry, or its
    ``last_handshake`` is missing, falsy or not positive.
    """
    peer_status = wg_status.get(public_key) or {}
    last_seen = peer_status.get("last_handshake", 0)
    if not last_seen or not last_seen > 0:
        return None
    return datetime.fromtimestamp(last_seen, tz=timezone.utc)
# ── CRUD Operations ──
async def get_vpn_config(db: AsyncSession, tenant_id: uuid.UUID) -> Optional[VpnConfig]:
    """Fetch the tenant's VPN config row, or ``None`` if VPN was never set up."""
    stmt = select(VpnConfig).where(VpnConfig.tenant_id == tenant_id)
    return (await db.execute(stmt)).scalar_one_or_none()
async def setup_vpn(
    db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None
) -> VpnConfig:
    """Create the tenant's VPN config with a freshly generated server keypair.

    The private key is stored encrypted and the config starts enabled. The new
    row is flushed (not committed), and ``wg0.conf`` is regenerated before
    returning.

    Raises:
        ValueError: the tenant already has a VPN config.
    """
    if await get_vpn_config(db, tenant_id):
        raise ValueError("VPN already configured for this tenant")

    server_private_b64, server_public_b64 = generate_wireguard_keypair()
    sealed_private = encrypt_credentials(server_private_b64, settings.get_encryption_key_bytes())

    vpn_config = VpnConfig(
        tenant_id=tenant_id,
        server_private_key=sealed_private,
        server_public_key=server_public_b64,
        endpoint=endpoint,
        is_enabled=True,
    )
    db.add(vpn_config)
    await db.flush()
    await sync_wireguard_config(db, tenant_id)
    return vpn_config
async def update_vpn_config(
    db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None, is_enabled: Optional[bool] = None
) -> VpnConfig:
    """Patch the tenant's VPN settings; an argument left as ``None`` is unchanged.

    Flushes, regenerates ``wg0.conf``, and returns the updated config.

    Raises:
        ValueError: the tenant has no VPN config yet.
    """
    vpn_config = await get_vpn_config(db, tenant_id)
    if not vpn_config:
        raise ValueError("VPN not configured for this tenant")

    requested = {"endpoint": endpoint, "is_enabled": is_enabled}
    for attribute, new_value in requested.items():
        if new_value is not None:
            setattr(vpn_config, attribute, new_value)

    await db.flush()
    await sync_wireguard_config(db, tenant_id)
    return vpn_config
async def get_peers(db: AsyncSession, tenant_id: uuid.UUID) -> list[VpnPeer]:
    """Return every VPN peer of the tenant, ordered by ``created_at``."""
    stmt = (
        select(VpnPeer)
        .where(VpnPeer.tenant_id == tenant_id)
        .order_by(VpnPeer.created_at)
    )
    rows = await db.execute(stmt)
    return [peer for peer in rows.scalars()]
async def _next_available_ip(db: AsyncSession, tenant_id: uuid.UUID, config: VpnConfig) -> str:
    """Allocate the lowest free host address in the tenant's VPN subnet.

    The following are never handed out:

    * the subnet's first usable host (e.g. ``.1`` of ``10.10.0.0/24``);
    * ``config.server_address``;
    * any address already assigned to one of the tenant's peers.

    Returns:
        The address in CIDR form, e.g. ``"10.10.0.2/24"``. The prefix length
        comes from ``config.subnet`` rather than being hard-coded to ``/24``.

    Raises:
        ValueError: no free host address is left in the subnet.
    """
    network = ipaddress.ip_network(config.subnet, strict=False)

    result = await db.execute(select(VpnPeer.assigned_ip).where(VpnPeer.tenant_id == tenant_id))
    used_ips = {row[0].split("/")[0] for row in result.all()}
    used_ips.add(config.server_address.split("/")[0])  # exclude server IP

    # iter(): hosts() may return a list (e.g. for a /32) rather than an iterator.
    hosts = iter(network.hosts())
    next(hosts, None)  # skip the first host (server)
    for host in hosts:
        if str(host) not in used_ips:
            return f"{host}/{network.prefixlen}"
    raise ValueError("No available IPs in VPN subnet")
async def add_peer(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    additional_allowed_ips: Optional[str] = None,
) -> VpnPeer:
    """Add an existing device of the tenant as a VPN peer.

    Steps:

    * generate the peer keypair and a preshared key (both stored encrypted);
    * assign the next free tunnel address;
    * flush, and regenerate ``wg0.conf``.

    Args:
        additional_allowed_ips: Optional comma-separated extra networks routed to
            this peer (e.g. a LAN behind it). They must not overlap
            10.10.0.0/16.

    Raises:
        ValueError: VPN not configured, device not found in this tenant, device
            already a peer, extra networks overlapping the VPN address space,
            or no free address left in the subnet.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config:
        raise ValueError("VPN not configured — enable VPN first")

    # An extra route inside the VPN address space would let this peer capture
    # traffic meant for other tunnel addresses; reject it before storing.
    _validate_additional_allowed_ips(additional_allowed_ips)

    device = await db.execute(select(Device).where(Device.id == device_id, Device.tenant_id == tenant_id))
    if not device.scalar_one_or_none():
        raise ValueError("Device not found")

    existing = await db.execute(select(VpnPeer).where(VpnPeer.device_id == device_id))
    if existing.scalar_one_or_none():
        raise ValueError("Device is already a VPN peer")

    private_key_b64, public_key_b64 = generate_wireguard_keypair()
    psk = generate_preshared_key()
    key_bytes = settings.get_encryption_key_bytes()
    encrypted_private = encrypt_credentials(private_key_b64, key_bytes)
    encrypted_psk = encrypt_credentials(psk, key_bytes)
    assigned_ip = await _next_available_ip(db, tenant_id, config)

    peer = VpnPeer(
        tenant_id=tenant_id,
        device_id=device_id,
        peer_private_key=encrypted_private,
        peer_public_key=public_key_b64,
        preshared_key=encrypted_psk,
        assigned_ip=assigned_ip,
        additional_allowed_ips=additional_allowed_ips,
    )
    db.add(peer)
    await db.flush()
    await sync_wireguard_config(db, tenant_id)
    return peer
async def remove_peer(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> None:
    """Delete one of the tenant's VPN peers, flush, and regenerate ``wg0.conf``.

    Raises:
        ValueError: no peer with *peer_id* exists in this tenant.
    """
    lookup = select(VpnPeer).where(VpnPeer.id == peer_id, VpnPeer.tenant_id == tenant_id)
    doomed = (await db.execute(lookup)).scalar_one_or_none()
    if not doomed:
        raise ValueError("Peer not found")

    await db.delete(doomed)
    await db.flush()
    await sync_wireguard_config(db, tenant_id)
def _split_endpoint(endpoint: str, default_port: int | str) -> tuple[str, str]:
    """Split a ``host[:port]`` endpoint into ``(host, port)`` strings.

    Bracketed IPv6 literals (``[2001:db8::1]:51820``) are supported. When no
    port is given, *default_port* is used. A bare IPv6 literal such as
    ``2001:db8::1`` is treated as host-only.
    """
    endpoint = endpoint.strip()
    if endpoint.startswith("["):
        host, _, rest = endpoint[1:].partition("]")
        port = rest[1:] if rest.startswith(":") else ""
    else:
        host, sep, port = endpoint.rpartition(":")
        if not sep or ":" in host:
            host, port = endpoint, ""
    return host, port or str(default_port)


def _routeros_setup_commands(
    *,
    private_key: str,
    server_public_key: str,
    endpoint: str,
    default_port: int | str,
    allowed_address: str,
    assigned_ip: str,
    psk: Optional[str],
) -> list[str]:
    """Build the RouterOS commands that bring up the ``wg-portal`` tunnel.

    1. Create the ``wg-portal`` interface (listen port 13231) with the peer's
       private key.
    2. Add the server as its peer: endpoint, *allowed_address*, and a 25 s
       keepalive, plus the preshared key when *psk* is given.
    3. Assign the peer's tunnel address to the interface.
    """
    host, port = _split_endpoint(endpoint, default_port)
    peer_command = (
        f'/interface wireguard peers add interface=wg-portal public-key="{server_public_key}" '
        f"endpoint-address={host} endpoint-port={port} "
        f"allowed-address={allowed_address} persistent-keepalive=25"
    )
    if psk:
        peer_command += f' preshared-key="{psk}"'
    return [
        f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key}"',
        peer_command,
        f"/ip address add address={assigned_ip} interface=wg-portal",
    ]


async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> dict:
    """Get the full config for a peer, including its private key, for device setup.

    The returned dict contains:

    * the decrypted peer private key and the peer public key;
    * the tunnel address;
    * the server public key and endpoint;
    * the subnet to route through the tunnel;
    * ready-to-paste RouterOS commands.

    ``"YOUR_SERVER_IP:51820"`` stands in for the endpoint when none is
    configured.

    Raises:
        ValueError: VPN not configured, or no such peer in this tenant.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config:
        raise ValueError("VPN not configured")

    result = await db.execute(
        select(VpnPeer).where(VpnPeer.id == peer_id, VpnPeer.tenant_id == tenant_id)
    )
    peer = result.scalar_one_or_none()
    if not peer:
        raise ValueError("Peer not found")

    key_bytes = settings.get_encryption_key_bytes()
    private_key = decrypt_credentials(peer.peer_private_key, key_bytes)
    psk = decrypt_credentials(peer.preshared_key, key_bytes) if peer.preshared_key else None
    endpoint = config.endpoint or "YOUR_SERVER_IP:51820"

    routeros_commands = _routeros_setup_commands(
        private_key=private_key,
        server_public_key=config.server_public_key,
        endpoint=endpoint,
        # Used only when the endpoint has no explicit port.
        default_port=config.server_port or 51820,
        allowed_address=config.subnet,
        assigned_ip=peer.assigned_ip,
        psk=psk,
    )
    return {
        "peer_private_key": private_key,
        "peer_public_key": peer.peer_public_key,
        "assigned_ip": peer.assigned_ip,
        "server_public_key": config.server_public_key,
        "server_endpoint": endpoint,
        "allowed_ips": config.subnet,
        "routeros_commands": routeros_commands,
    }


async def onboard_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    hostname: str,
    username: str,
    password: str,
) -> dict:
    """Create a device and its VPN peer together; return IDs and RouterOS commands.

    Unlike regular device creation, this skips TCP connectivity checks because
    the VPN tunnel isn't up yet. Instead:

    * the device's ``ip_address`` is set to its VPN-assigned address;
    * its credentials are stored via ``encrypt_credentials_transit``;
    * everything is flushed in the caller's session (not committed here).

    Raises:
        ValueError: VPN not configured for the tenant, or no free address left.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config:
        raise ValueError("VPN not configured — enable VPN first")

    # Allocate the tunnel address first; it doubles as the device's IP.
    assigned_ip = await _next_available_ip(db, tenant_id, config)
    vpn_ip_no_cidr = assigned_ip.split("/")[0]

    credentials_json = json.dumps({"username": username, "password": password})
    transit_ciphertext = await encrypt_credentials_transit(credentials_json, str(tenant_id))
    device = Device(
        tenant_id=tenant_id,
        hostname=hostname,
        ip_address=vpn_ip_no_cidr,
        api_port=8728,
        api_ssl_port=8729,
        encrypted_credentials_transit=transit_ciphertext,
        status="unknown",
    )
    db.add(device)
    await db.flush()  # so device.id is populated before the peer references it

    private_key_b64, public_key_b64 = generate_wireguard_keypair()
    psk = generate_preshared_key()
    key_bytes = settings.get_encryption_key_bytes()
    peer = VpnPeer(
        tenant_id=tenant_id,
        device_id=device.id,
        peer_private_key=encrypt_credentials(private_key_b64, key_bytes),
        peer_public_key=public_key_b64,
        preshared_key=encrypt_credentials(psk, key_bytes),
        assigned_ip=assigned_ip,
    )
    db.add(peer)
    await db.flush()
    await sync_wireguard_config(db, tenant_id)

    # The plaintext keys are still in hand; no need to decrypt what was just stored.
    routeros_commands = _routeros_setup_commands(
        private_key=private_key_b64,
        server_public_key=config.server_public_key,
        endpoint=config.endpoint or "YOUR_SERVER_IP:51820",
        default_port=config.server_port or 51820,
        allowed_address=config.subnet,
        assigned_ip=assigned_ip,
        psk=psk,
    )
    return {
        "device_id": device.id,
        "peer_id": peer.id,
        "hostname": hostname,
        "assigned_ip": assigned_ip,
        "routeros_commands": routeros_commands,
    }