feat(vpn): refactor setup_vpn and sync_wireguard_config for multi-tenant isolation
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -151,34 +151,66 @@ def _get_wg_config_path() -> Path:
|
|||||||
return Path(os.getenv("WIREGUARD_CONFIG_PATH", "/data/wireguard"))
|
return Path(os.getenv("WIREGUARD_CONFIG_PATH", "/data/wireguard"))
|
||||||
|
|
||||||
|
|
||||||
async def sync_wireguard_config(db: AsyncSession, tenant_id: uuid.UUID) -> None:
|
async def sync_wireguard_config(db: AsyncSession) -> None:
|
||||||
"""Regenerate wg0.conf from database state and write to shared volume."""
|
"""Regenerate wg0.conf with ALL tenants' peers and write to shared volume.
|
||||||
config = await get_vpn_config(db, tenant_id)
|
|
||||||
if not config or not config.is_enabled:
|
|
||||||
return
|
|
||||||
|
|
||||||
key_bytes = settings.get_encryption_key_bytes()
|
Uses AdminAsyncSessionLocal to bypass RLS (must see all tenants).
|
||||||
server_private_key = decrypt_credentials(config.server_private_key, key_bytes)
|
Uses a PostgreSQL advisory lock to prevent concurrent writes.
|
||||||
|
Writes atomically via temp file + rename.
|
||||||
|
"""
|
||||||
|
from app.database import AdminAsyncSessionLocal
|
||||||
|
from sqlalchemy import text as sa_text
|
||||||
|
|
||||||
result = await db.execute(
|
async with AdminAsyncSessionLocal() as admin_db:
|
||||||
select(VpnPeer).where(VpnPeer.tenant_id == tenant_id, VpnPeer.is_enabled.is_(True))
|
# Acquire advisory lock (released when this session closes)
|
||||||
|
await admin_db.execute(sa_text("SELECT pg_advisory_lock(hashtext('wireguard_config'))"))
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get global server private key
|
||||||
|
private_key_b64, _ = await _get_or_create_global_server_key(admin_db)
|
||||||
|
|
||||||
|
# Query ALL enabled VPN configs (admin session bypasses RLS)
|
||||||
|
configs_result = await admin_db.execute(
|
||||||
|
select(VpnConfig).where(VpnConfig.is_enabled.is_(True)).order_by(VpnConfig.subnet_index)
|
||||||
)
|
)
|
||||||
peers = result.scalars().all()
|
configs = configs_result.scalars().all()
|
||||||
|
|
||||||
# Build wg0.conf
|
# Build wg0.conf
|
||||||
lines = [
|
lines = [
|
||||||
"[Interface]",
|
"[Interface]",
|
||||||
f"Address = {config.server_address}",
|
"Address = 10.10.0.1/16",
|
||||||
f"ListenPort = {config.server_port}",
|
f"ListenPort = {configs[0].server_port if configs else 51820}",
|
||||||
f"PrivateKey = {server_private_key}",
|
f"PrivateKey = {private_key_b64}",
|
||||||
"",
|
"",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
key_bytes = settings.get_encryption_key_bytes()
|
||||||
|
total_peers = 0
|
||||||
|
|
||||||
|
for config in configs:
|
||||||
|
# Get tenant name for comment
|
||||||
|
tenant_result = await admin_db.execute(
|
||||||
|
sa_text("SELECT name FROM tenants WHERE id = :tid"),
|
||||||
|
{"tid": config.tenant_id},
|
||||||
|
)
|
||||||
|
tenant_row = tenant_result.fetchone()
|
||||||
|
tenant_name = tenant_row[0] if tenant_row else str(config.tenant_id)
|
||||||
|
|
||||||
|
peers_result = await admin_db.execute(
|
||||||
|
select(VpnPeer).where(
|
||||||
|
VpnPeer.tenant_id == config.tenant_id,
|
||||||
|
VpnPeer.is_enabled.is_(True),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
peers = peers_result.scalars().all()
|
||||||
|
|
||||||
|
if peers:
|
||||||
|
lines.append(f"# --- Tenant: {tenant_name} ({config.subnet}) ---")
|
||||||
|
|
||||||
for peer in peers:
|
for peer in peers:
|
||||||
peer_ip = peer.assigned_ip.split("/")[0] # strip CIDR for AllowedIPs
|
peer_ip = peer.assigned_ip.split("/")[0]
|
||||||
allowed_ips = [f"{peer_ip}/32"]
|
allowed_ips = [f"{peer_ip}/32"]
|
||||||
if peer.additional_allowed_ips:
|
if peer.additional_allowed_ips:
|
||||||
# Comma-separated additional subnets (e.g. site-to-site routing)
|
|
||||||
extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()]
|
extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()]
|
||||||
allowed_ips.extend(extra)
|
allowed_ips.extend(extra)
|
||||||
lines.append("[Peer]")
|
lines.append("[Peer]")
|
||||||
@@ -188,19 +220,28 @@ async def sync_wireguard_config(db: AsyncSession, tenant_id: uuid.UUID) -> None:
|
|||||||
lines.append(f"PresharedKey = {psk}")
|
lines.append(f"PresharedKey = {psk}")
|
||||||
lines.append(f"AllowedIPs = {', '.join(allowed_ips)}")
|
lines.append(f"AllowedIPs = {', '.join(allowed_ips)}")
|
||||||
lines.append("")
|
lines.append("")
|
||||||
|
total_peers += 1
|
||||||
|
|
||||||
|
# Atomic write: temp file + rename
|
||||||
config_dir = _get_wg_config_path()
|
config_dir = _get_wg_config_path()
|
||||||
wg_confs_dir = config_dir / "wg_confs"
|
wg_confs_dir = config_dir / "wg_confs"
|
||||||
wg_confs_dir.mkdir(parents=True, exist_ok=True)
|
wg_confs_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
conf_path = wg_confs_dir / "wg0.conf"
|
conf_path = wg_confs_dir / "wg0.conf"
|
||||||
conf_path.write_text("\n".join(lines))
|
tmp_path = wg_confs_dir / "wg0.conf.tmp"
|
||||||
|
tmp_path.write_text("\n".join(lines))
|
||||||
|
os.rename(str(tmp_path), str(conf_path))
|
||||||
|
|
||||||
# Signal WireGuard container to reload config
|
# Signal WireGuard container to reload
|
||||||
reload_flag = wg_confs_dir / ".reload"
|
reload_flag = wg_confs_dir / ".reload"
|
||||||
reload_flag.write_text("1")
|
reload_flag.write_text("1")
|
||||||
|
|
||||||
logger.info("wireguard config synced", tenant_id=str(tenant_id), peers=len(peers))
|
logger.info("wireguard_config_synced", event="vpn_audit",
|
||||||
|
tenants=len(configs), peers=total_peers)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Release advisory lock explicitly (session-level lock, not xact-level)
|
||||||
|
await admin_db.execute(sa_text("SELECT pg_advisory_unlock(hashtext('wireguard_config'))"))
|
||||||
|
|
||||||
|
|
||||||
# ── Live Status ──
|
# ── Live Status ──
|
||||||
@@ -246,13 +287,23 @@ async def get_vpn_config(db: AsyncSession, tenant_id: uuid.UUID) -> Optional[Vpn
|
|||||||
async def setup_vpn(
|
async def setup_vpn(
|
||||||
db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None
|
db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None
|
||||||
) -> VpnConfig:
|
) -> VpnConfig:
|
||||||
"""Initialize VPN for a tenant — generates server keys and creates config."""
|
"""Initialize VPN for a tenant — allocates unique subnet, uses global server key."""
|
||||||
existing = await get_vpn_config(db, tenant_id)
|
existing = await get_vpn_config(db, tenant_id)
|
||||||
if existing:
|
if existing:
|
||||||
raise ValueError("VPN already configured for this tenant")
|
raise ValueError("VPN already configured for this tenant")
|
||||||
|
|
||||||
private_key_b64, public_key_b64 = generate_wireguard_keypair()
|
# Get or create global server keypair
|
||||||
|
_, public_key_b64 = await _get_or_create_global_server_key(db)
|
||||||
|
|
||||||
|
# Allocate unique subnet
|
||||||
|
subnet_index = await _allocate_subnet_index(db)
|
||||||
|
subnet = f"10.10.{subnet_index}.0/24"
|
||||||
|
server_address = f"10.10.{subnet_index}.1/24"
|
||||||
|
|
||||||
|
# Generate a per-tenant key for the deprecated server_private_key column.
|
||||||
|
# This column is NOT NULL and kept for rollback safety. The global key
|
||||||
|
# in system_settings is authoritative; this per-tenant key is unused.
|
||||||
|
private_key_b64, _ = generate_wireguard_keypair()
|
||||||
key_bytes = settings.get_encryption_key_bytes()
|
key_bytes = settings.get_encryption_key_bytes()
|
||||||
encrypted_private = encrypt_credentials(private_key_b64, key_bytes)
|
encrypted_private = encrypt_credentials(private_key_b64, key_bytes)
|
||||||
|
|
||||||
@@ -260,13 +311,19 @@ async def setup_vpn(
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
server_private_key=encrypted_private,
|
server_private_key=encrypted_private,
|
||||||
server_public_key=public_key_b64,
|
server_public_key=public_key_b64,
|
||||||
|
subnet_index=subnet_index,
|
||||||
|
subnet=subnet,
|
||||||
|
server_address=server_address,
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
is_enabled=True,
|
is_enabled=True,
|
||||||
)
|
)
|
||||||
db.add(config)
|
db.add(config)
|
||||||
await db.flush()
|
await db.flush()
|
||||||
|
|
||||||
await sync_wireguard_config(db, tenant_id)
|
logger.info("vpn_subnet_allocated", event="vpn_audit",
|
||||||
|
tenant_id=str(tenant_id), subnet_index=subnet_index, subnet=subnet)
|
||||||
|
|
||||||
|
await sync_wireguard_config(db)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
@@ -284,7 +341,7 @@ async def update_vpn_config(
|
|||||||
config.is_enabled = is_enabled
|
config.is_enabled = is_enabled
|
||||||
|
|
||||||
await db.flush()
|
await db.flush()
|
||||||
await sync_wireguard_config(db, tenant_id)
|
await sync_wireguard_config(db)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
@@ -330,6 +387,8 @@ async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID,
|
|||||||
if existing.scalar_one_or_none():
|
if existing.scalar_one_or_none():
|
||||||
raise ValueError("Device is already a VPN peer")
|
raise ValueError("Device is already a VPN peer")
|
||||||
|
|
||||||
|
_validate_additional_allowed_ips(additional_allowed_ips)
|
||||||
|
|
||||||
private_key_b64, public_key_b64 = generate_wireguard_keypair()
|
private_key_b64, public_key_b64 = generate_wireguard_keypair()
|
||||||
psk = generate_preshared_key()
|
psk = generate_preshared_key()
|
||||||
|
|
||||||
@@ -351,7 +410,7 @@ async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID,
|
|||||||
db.add(peer)
|
db.add(peer)
|
||||||
await db.flush()
|
await db.flush()
|
||||||
|
|
||||||
await sync_wireguard_config(db, tenant_id)
|
await sync_wireguard_config(db)
|
||||||
return peer
|
return peer
|
||||||
|
|
||||||
|
|
||||||
@@ -366,7 +425,7 @@ async def remove_peer(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID
|
|||||||
|
|
||||||
await db.delete(peer)
|
await db.delete(peer)
|
||||||
await db.flush()
|
await db.flush()
|
||||||
await sync_wireguard_config(db, tenant_id)
|
await sync_wireguard_config(db)
|
||||||
|
|
||||||
|
|
||||||
async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> dict:
|
async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> dict:
|
||||||
@@ -464,7 +523,7 @@ async def onboard_device(
|
|||||||
db.add(peer)
|
db.add(peer)
|
||||||
await db.flush()
|
await db.flush()
|
||||||
|
|
||||||
await sync_wireguard_config(db, tenant_id)
|
await sync_wireguard_config(db)
|
||||||
|
|
||||||
# Generate RouterOS commands
|
# Generate RouterOS commands
|
||||||
endpoint = config.endpoint or "YOUR_SERVER_IP:51820"
|
endpoint = config.endpoint or "YOUR_SERVER_IP:51820"
|
||||||
|
|||||||
Reference in New Issue
Block a user