feat(vpn): refactor setup_vpn and sync_wireguard_config for multi-tenant isolation

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Jason Staack
2026-03-14 16:30:13 -05:00
parent 93fe935edf
commit 5e70890d76

View File

@@ -151,34 +151,66 @@ def _get_wg_config_path() -> Path:
return Path(os.getenv("WIREGUARD_CONFIG_PATH", "/data/wireguard")) return Path(os.getenv("WIREGUARD_CONFIG_PATH", "/data/wireguard"))
async def sync_wireguard_config(db: AsyncSession, tenant_id: uuid.UUID) -> None: async def sync_wireguard_config(db: AsyncSession) -> None:
"""Regenerate wg0.conf from database state and write to shared volume.""" """Regenerate wg0.conf with ALL tenants' peers and write to shared volume.
config = await get_vpn_config(db, tenant_id)
if not config or not config.is_enabled:
return
key_bytes = settings.get_encryption_key_bytes() Uses AdminAsyncSessionLocal to bypass RLS (must see all tenants).
server_private_key = decrypt_credentials(config.server_private_key, key_bytes) Uses a PostgreSQL advisory lock to prevent concurrent writes.
Writes atomically via temp file + rename.
"""
from app.database import AdminAsyncSessionLocal
from sqlalchemy import text as sa_text
result = await db.execute( async with AdminAsyncSessionLocal() as admin_db:
select(VpnPeer).where(VpnPeer.tenant_id == tenant_id, VpnPeer.is_enabled.is_(True)) # Acquire advisory lock (released when this session closes)
await admin_db.execute(sa_text("SELECT pg_advisory_lock(hashtext('wireguard_config'))"))
try:
# Get global server private key
private_key_b64, _ = await _get_or_create_global_server_key(admin_db)
# Query ALL enabled VPN configs (admin session bypasses RLS)
configs_result = await admin_db.execute(
select(VpnConfig).where(VpnConfig.is_enabled.is_(True)).order_by(VpnConfig.subnet_index)
) )
peers = result.scalars().all() configs = configs_result.scalars().all()
# Build wg0.conf # Build wg0.conf
lines = [ lines = [
"[Interface]", "[Interface]",
f"Address = {config.server_address}", "Address = 10.10.0.1/16",
f"ListenPort = {config.server_port}", f"ListenPort = {configs[0].server_port if configs else 51820}",
f"PrivateKey = {server_private_key}", f"PrivateKey = {private_key_b64}",
"", "",
] ]
key_bytes = settings.get_encryption_key_bytes()
total_peers = 0
for config in configs:
# Get tenant name for comment
tenant_result = await admin_db.execute(
sa_text("SELECT name FROM tenants WHERE id = :tid"),
{"tid": config.tenant_id},
)
tenant_row = tenant_result.fetchone()
tenant_name = tenant_row[0] if tenant_row else str(config.tenant_id)
peers_result = await admin_db.execute(
select(VpnPeer).where(
VpnPeer.tenant_id == config.tenant_id,
VpnPeer.is_enabled.is_(True),
)
)
peers = peers_result.scalars().all()
if peers:
lines.append(f"# --- Tenant: {tenant_name} ({config.subnet}) ---")
for peer in peers: for peer in peers:
peer_ip = peer.assigned_ip.split("/")[0] # strip CIDR for AllowedIPs peer_ip = peer.assigned_ip.split("/")[0]
allowed_ips = [f"{peer_ip}/32"] allowed_ips = [f"{peer_ip}/32"]
if peer.additional_allowed_ips: if peer.additional_allowed_ips:
# Comma-separated additional subnets (e.g. site-to-site routing)
extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()] extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()]
allowed_ips.extend(extra) allowed_ips.extend(extra)
lines.append("[Peer]") lines.append("[Peer]")
@@ -188,19 +220,28 @@ async def sync_wireguard_config(db: AsyncSession, tenant_id: uuid.UUID) -> None:
lines.append(f"PresharedKey = {psk}") lines.append(f"PresharedKey = {psk}")
lines.append(f"AllowedIPs = {', '.join(allowed_ips)}") lines.append(f"AllowedIPs = {', '.join(allowed_ips)}")
lines.append("") lines.append("")
total_peers += 1
# Atomic write: temp file + rename
config_dir = _get_wg_config_path() config_dir = _get_wg_config_path()
wg_confs_dir = config_dir / "wg_confs" wg_confs_dir = config_dir / "wg_confs"
wg_confs_dir.mkdir(parents=True, exist_ok=True) wg_confs_dir.mkdir(parents=True, exist_ok=True)
conf_path = wg_confs_dir / "wg0.conf" conf_path = wg_confs_dir / "wg0.conf"
conf_path.write_text("\n".join(lines)) tmp_path = wg_confs_dir / "wg0.conf.tmp"
tmp_path.write_text("\n".join(lines))
os.rename(str(tmp_path), str(conf_path))
# Signal WireGuard container to reload config # Signal WireGuard container to reload
reload_flag = wg_confs_dir / ".reload" reload_flag = wg_confs_dir / ".reload"
reload_flag.write_text("1") reload_flag.write_text("1")
logger.info("wireguard config synced", tenant_id=str(tenant_id), peers=len(peers)) logger.info("wireguard_config_synced", event="vpn_audit",
tenants=len(configs), peers=total_peers)
finally:
# Release advisory lock explicitly (session-level lock, not xact-level)
await admin_db.execute(sa_text("SELECT pg_advisory_unlock(hashtext('wireguard_config'))"))
# ── Live Status ── # ── Live Status ──
@@ -246,13 +287,23 @@ async def get_vpn_config(db: AsyncSession, tenant_id: uuid.UUID) -> Optional[Vpn
async def setup_vpn( async def setup_vpn(
db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None
) -> VpnConfig: ) -> VpnConfig:
"""Initialize VPN for a tenant — generates server keys and creates config.""" """Initialize VPN for a tenant — allocates unique subnet, uses global server key."""
existing = await get_vpn_config(db, tenant_id) existing = await get_vpn_config(db, tenant_id)
if existing: if existing:
raise ValueError("VPN already configured for this tenant") raise ValueError("VPN already configured for this tenant")
private_key_b64, public_key_b64 = generate_wireguard_keypair() # Get or create global server keypair
_, public_key_b64 = await _get_or_create_global_server_key(db)
# Allocate unique subnet
subnet_index = await _allocate_subnet_index(db)
subnet = f"10.10.{subnet_index}.0/24"
server_address = f"10.10.{subnet_index}.1/24"
# Generate a per-tenant key for the deprecated server_private_key column.
# This column is NOT NULL and kept for rollback safety. The global key
# in system_settings is authoritative; this per-tenant key is unused.
private_key_b64, _ = generate_wireguard_keypair()
key_bytes = settings.get_encryption_key_bytes() key_bytes = settings.get_encryption_key_bytes()
encrypted_private = encrypt_credentials(private_key_b64, key_bytes) encrypted_private = encrypt_credentials(private_key_b64, key_bytes)
@@ -260,13 +311,19 @@ async def setup_vpn(
tenant_id=tenant_id, tenant_id=tenant_id,
server_private_key=encrypted_private, server_private_key=encrypted_private,
server_public_key=public_key_b64, server_public_key=public_key_b64,
subnet_index=subnet_index,
subnet=subnet,
server_address=server_address,
endpoint=endpoint, endpoint=endpoint,
is_enabled=True, is_enabled=True,
) )
db.add(config) db.add(config)
await db.flush() await db.flush()
await sync_wireguard_config(db, tenant_id) logger.info("vpn_subnet_allocated", event="vpn_audit",
tenant_id=str(tenant_id), subnet_index=subnet_index, subnet=subnet)
await sync_wireguard_config(db)
return config return config
@@ -284,7 +341,7 @@ async def update_vpn_config(
config.is_enabled = is_enabled config.is_enabled = is_enabled
await db.flush() await db.flush()
await sync_wireguard_config(db, tenant_id) await sync_wireguard_config(db)
return config return config
@@ -330,6 +387,8 @@ async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID,
if existing.scalar_one_or_none(): if existing.scalar_one_or_none():
raise ValueError("Device is already a VPN peer") raise ValueError("Device is already a VPN peer")
_validate_additional_allowed_ips(additional_allowed_ips)
private_key_b64, public_key_b64 = generate_wireguard_keypair() private_key_b64, public_key_b64 = generate_wireguard_keypair()
psk = generate_preshared_key() psk = generate_preshared_key()
@@ -351,7 +410,7 @@ async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID,
db.add(peer) db.add(peer)
await db.flush() await db.flush()
await sync_wireguard_config(db, tenant_id) await sync_wireguard_config(db)
return peer return peer
@@ -366,7 +425,7 @@ async def remove_peer(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID
await db.delete(peer) await db.delete(peer)
await db.flush() await db.flush()
await sync_wireguard_config(db, tenant_id) await sync_wireguard_config(db)
async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> dict: async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> dict:
@@ -464,7 +523,7 @@ async def onboard_device(
db.add(peer) db.add(peer)
await db.flush() await db.flush()
await sync_wireguard_config(db, tenant_id) await sync_wireguard_config(db)
# Generate RouterOS commands # Generate RouterOS commands
endpoint = config.endpoint or "YOUR_SERVER_IP:51820" endpoint = config.endpoint or "YOUR_SERVER_IP:51820"