feat(vpn): refactor setup_vpn and sync_wireguard_config for multi-tenant isolation
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -151,34 +151,66 @@ def _get_wg_config_path() -> Path:
|
||||
return Path(os.getenv("WIREGUARD_CONFIG_PATH", "/data/wireguard"))
|
||||
|
||||
|
||||
async def sync_wireguard_config(db: AsyncSession, tenant_id: uuid.UUID) -> None:
|
||||
"""Regenerate wg0.conf from database state and write to shared volume."""
|
||||
config = await get_vpn_config(db, tenant_id)
|
||||
if not config or not config.is_enabled:
|
||||
return
|
||||
async def sync_wireguard_config(db: AsyncSession) -> None:
|
||||
"""Regenerate wg0.conf with ALL tenants' peers and write to shared volume.
|
||||
|
||||
key_bytes = settings.get_encryption_key_bytes()
|
||||
server_private_key = decrypt_credentials(config.server_private_key, key_bytes)
|
||||
Uses AdminAsyncSessionLocal to bypass RLS (must see all tenants).
|
||||
Uses a PostgreSQL advisory lock to prevent concurrent writes.
|
||||
Writes atomically via temp file + rename.
|
||||
"""
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from sqlalchemy import text as sa_text
|
||||
|
||||
result = await db.execute(
|
||||
select(VpnPeer).where(VpnPeer.tenant_id == tenant_id, VpnPeer.is_enabled.is_(True))
|
||||
async with AdminAsyncSessionLocal() as admin_db:
|
||||
# Acquire advisory lock (released when this session closes)
|
||||
await admin_db.execute(sa_text("SELECT pg_advisory_lock(hashtext('wireguard_config'))"))
|
||||
|
||||
try:
|
||||
# Get global server private key
|
||||
private_key_b64, _ = await _get_or_create_global_server_key(admin_db)
|
||||
|
||||
# Query ALL enabled VPN configs (admin session bypasses RLS)
|
||||
configs_result = await admin_db.execute(
|
||||
select(VpnConfig).where(VpnConfig.is_enabled.is_(True)).order_by(VpnConfig.subnet_index)
|
||||
)
|
||||
peers = result.scalars().all()
|
||||
configs = configs_result.scalars().all()
|
||||
|
||||
# Build wg0.conf
|
||||
lines = [
|
||||
"[Interface]",
|
||||
f"Address = {config.server_address}",
|
||||
f"ListenPort = {config.server_port}",
|
||||
f"PrivateKey = {server_private_key}",
|
||||
"Address = 10.10.0.1/16",
|
||||
f"ListenPort = {configs[0].server_port if configs else 51820}",
|
||||
f"PrivateKey = {private_key_b64}",
|
||||
"",
|
||||
]
|
||||
|
||||
key_bytes = settings.get_encryption_key_bytes()
|
||||
total_peers = 0
|
||||
|
||||
for config in configs:
|
||||
# Get tenant name for comment
|
||||
tenant_result = await admin_db.execute(
|
||||
sa_text("SELECT name FROM tenants WHERE id = :tid"),
|
||||
{"tid": config.tenant_id},
|
||||
)
|
||||
tenant_row = tenant_result.fetchone()
|
||||
tenant_name = tenant_row[0] if tenant_row else str(config.tenant_id)
|
||||
|
||||
peers_result = await admin_db.execute(
|
||||
select(VpnPeer).where(
|
||||
VpnPeer.tenant_id == config.tenant_id,
|
||||
VpnPeer.is_enabled.is_(True),
|
||||
)
|
||||
)
|
||||
peers = peers_result.scalars().all()
|
||||
|
||||
if peers:
|
||||
lines.append(f"# --- Tenant: {tenant_name} ({config.subnet}) ---")
|
||||
|
||||
for peer in peers:
|
||||
peer_ip = peer.assigned_ip.split("/")[0] # strip CIDR for AllowedIPs
|
||||
peer_ip = peer.assigned_ip.split("/")[0]
|
||||
allowed_ips = [f"{peer_ip}/32"]
|
||||
if peer.additional_allowed_ips:
|
||||
# Comma-separated additional subnets (e.g. site-to-site routing)
|
||||
extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()]
|
||||
allowed_ips.extend(extra)
|
||||
lines.append("[Peer]")
|
||||
@@ -188,19 +220,28 @@ async def sync_wireguard_config(db: AsyncSession, tenant_id: uuid.UUID) -> None:
|
||||
lines.append(f"PresharedKey = {psk}")
|
||||
lines.append(f"AllowedIPs = {', '.join(allowed_ips)}")
|
||||
lines.append("")
|
||||
total_peers += 1
|
||||
|
||||
# Atomic write: temp file + rename
|
||||
config_dir = _get_wg_config_path()
|
||||
wg_confs_dir = config_dir / "wg_confs"
|
||||
wg_confs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
conf_path = wg_confs_dir / "wg0.conf"
|
||||
conf_path.write_text("\n".join(lines))
|
||||
tmp_path = wg_confs_dir / "wg0.conf.tmp"
|
||||
tmp_path.write_text("\n".join(lines))
|
||||
os.rename(str(tmp_path), str(conf_path))
|
||||
|
||||
# Signal WireGuard container to reload config
|
||||
# Signal WireGuard container to reload
|
||||
reload_flag = wg_confs_dir / ".reload"
|
||||
reload_flag.write_text("1")
|
||||
|
||||
logger.info("wireguard config synced", tenant_id=str(tenant_id), peers=len(peers))
|
||||
logger.info("wireguard_config_synced", event="vpn_audit",
|
||||
tenants=len(configs), peers=total_peers)
|
||||
|
||||
finally:
|
||||
# Release advisory lock explicitly (session-level lock, not xact-level)
|
||||
await admin_db.execute(sa_text("SELECT pg_advisory_unlock(hashtext('wireguard_config'))"))
|
||||
|
||||
|
||||
# ── Live Status ──
|
||||
@@ -246,13 +287,23 @@ async def get_vpn_config(db: AsyncSession, tenant_id: uuid.UUID) -> Optional[Vpn
|
||||
async def setup_vpn(
|
||||
db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None
|
||||
) -> VpnConfig:
|
||||
"""Initialize VPN for a tenant — generates server keys and creates config."""
|
||||
"""Initialize VPN for a tenant — allocates unique subnet, uses global server key."""
|
||||
existing = await get_vpn_config(db, tenant_id)
|
||||
if existing:
|
||||
raise ValueError("VPN already configured for this tenant")
|
||||
|
||||
private_key_b64, public_key_b64 = generate_wireguard_keypair()
|
||||
# Get or create global server keypair
|
||||
_, public_key_b64 = await _get_or_create_global_server_key(db)
|
||||
|
||||
# Allocate unique subnet
|
||||
subnet_index = await _allocate_subnet_index(db)
|
||||
subnet = f"10.10.{subnet_index}.0/24"
|
||||
server_address = f"10.10.{subnet_index}.1/24"
|
||||
|
||||
# Generate a per-tenant key for the deprecated server_private_key column.
|
||||
# This column is NOT NULL and kept for rollback safety. The global key
|
||||
# in system_settings is authoritative; this per-tenant key is unused.
|
||||
private_key_b64, _ = generate_wireguard_keypair()
|
||||
key_bytes = settings.get_encryption_key_bytes()
|
||||
encrypted_private = encrypt_credentials(private_key_b64, key_bytes)
|
||||
|
||||
@@ -260,13 +311,19 @@ async def setup_vpn(
|
||||
tenant_id=tenant_id,
|
||||
server_private_key=encrypted_private,
|
||||
server_public_key=public_key_b64,
|
||||
subnet_index=subnet_index,
|
||||
subnet=subnet,
|
||||
server_address=server_address,
|
||||
endpoint=endpoint,
|
||||
is_enabled=True,
|
||||
)
|
||||
db.add(config)
|
||||
await db.flush()
|
||||
|
||||
await sync_wireguard_config(db, tenant_id)
|
||||
logger.info("vpn_subnet_allocated", event="vpn_audit",
|
||||
tenant_id=str(tenant_id), subnet_index=subnet_index, subnet=subnet)
|
||||
|
||||
await sync_wireguard_config(db)
|
||||
return config
|
||||
|
||||
|
||||
@@ -284,7 +341,7 @@ async def update_vpn_config(
|
||||
config.is_enabled = is_enabled
|
||||
|
||||
await db.flush()
|
||||
await sync_wireguard_config(db, tenant_id)
|
||||
await sync_wireguard_config(db)
|
||||
return config
|
||||
|
||||
|
||||
@@ -330,6 +387,8 @@ async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID,
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("Device is already a VPN peer")
|
||||
|
||||
_validate_additional_allowed_ips(additional_allowed_ips)
|
||||
|
||||
private_key_b64, public_key_b64 = generate_wireguard_keypair()
|
||||
psk = generate_preshared_key()
|
||||
|
||||
@@ -351,7 +410,7 @@ async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID,
|
||||
db.add(peer)
|
||||
await db.flush()
|
||||
|
||||
await sync_wireguard_config(db, tenant_id)
|
||||
await sync_wireguard_config(db)
|
||||
return peer
|
||||
|
||||
|
||||
@@ -366,7 +425,7 @@ async def remove_peer(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID
|
||||
|
||||
await db.delete(peer)
|
||||
await db.flush()
|
||||
await sync_wireguard_config(db, tenant_id)
|
||||
await sync_wireguard_config(db)
|
||||
|
||||
|
||||
async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> dict:
|
||||
@@ -464,7 +523,7 @@ async def onboard_device(
|
||||
db.add(peer)
|
||||
await db.flush()
|
||||
|
||||
await sync_wireguard_config(db, tenant_id)
|
||||
await sync_wireguard_config(db)
|
||||
|
||||
# Generate RouterOS commands
|
||||
endpoint = config.endpoint or "YOUR_SERVER_IP:51820"
|
||||
|
||||
Reference in New Issue
Block a user