diff --git a/backend/app/services/vpn_service.py b/backend/app/services/vpn_service.py index b6a1af8..35349ee 100644 --- a/backend/app/services/vpn_service.py +++ b/backend/app/services/vpn_service.py @@ -151,56 +151,97 @@ def _get_wg_config_path() -> Path: return Path(os.getenv("WIREGUARD_CONFIG_PATH", "/data/wireguard")) -async def sync_wireguard_config(db: AsyncSession, tenant_id: uuid.UUID) -> None: - """Regenerate wg0.conf from database state and write to shared volume.""" - config = await get_vpn_config(db, tenant_id) - if not config or not config.is_enabled: - return +async def sync_wireguard_config(db: AsyncSession) -> None: + """Regenerate wg0.conf with ALL tenants' peers and write to shared volume. - key_bytes = settings.get_encryption_key_bytes() - server_private_key = decrypt_credentials(config.server_private_key, key_bytes) + Uses AdminAsyncSessionLocal to bypass RLS (must see all tenants). + Uses a PostgreSQL advisory lock to prevent concurrent writes. + Writes atomically via temp file + rename. + """ + from app.database import AdminAsyncSessionLocal + from sqlalchemy import text as sa_text - result = await db.execute( - select(VpnPeer).where(VpnPeer.tenant_id == tenant_id, VpnPeer.is_enabled.is_(True)) - ) - peers = result.scalars().all() + async with AdminAsyncSessionLocal() as admin_db: + # Acquire advisory lock (released when this session closes) + await admin_db.execute(sa_text("SELECT pg_advisory_lock(hashtext('wireguard_config'))")) - # Build wg0.conf - lines = [ - "[Interface]", - f"Address = {config.server_address}", - f"ListenPort = {config.server_port}", - f"PrivateKey = {server_private_key}", - "", - ] + try: + # Get global server private key + private_key_b64, _ = await _get_or_create_global_server_key(admin_db) - for peer in peers: - peer_ip = peer.assigned_ip.split("/")[0] # strip CIDR for AllowedIPs - allowed_ips = [f"{peer_ip}/32"] - if peer.additional_allowed_ips: - # Comma-separated additional subnets (e.g. site-to-site routing) - extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()] - allowed_ips.extend(extra) - lines.append("[Peer]") - lines.append(f"PublicKey = {peer.peer_public_key}") - if peer.preshared_key: - psk = decrypt_credentials(peer.preshared_key, key_bytes) - lines.append(f"PresharedKey = {psk}") - lines.append(f"AllowedIPs = {', '.join(allowed_ips)}") - lines.append("") + # Query ALL enabled VPN configs (admin session bypasses RLS) + configs_result = await admin_db.execute( + select(VpnConfig).where(VpnConfig.is_enabled.is_(True)).order_by(VpnConfig.subnet_index) + ) + configs = configs_result.scalars().all() - config_dir = _get_wg_config_path() - wg_confs_dir = config_dir / "wg_confs" - wg_confs_dir.mkdir(parents=True, exist_ok=True) + # Build wg0.conf + lines = [ + "[Interface]", + "Address = 10.10.0.1/16", + f"ListenPort = {configs[0].server_port if configs else 51820}", + f"PrivateKey = {private_key_b64}", + "", + ] - conf_path = wg_confs_dir / "wg0.conf" - conf_path.write_text("\n".join(lines)) + key_bytes = settings.get_encryption_key_bytes() + total_peers = 0 - # Signal WireGuard container to reload config - reload_flag = wg_confs_dir / ".reload" - reload_flag.write_text("1") + for config in configs: + # Get tenant name for comment + tenant_result = await admin_db.execute( + sa_text("SELECT name FROM tenants WHERE id = :tid"), + {"tid": config.tenant_id}, + ) + tenant_row = tenant_result.fetchone() + tenant_name = tenant_row[0] if tenant_row else str(config.tenant_id) - logger.info("wireguard config synced", tenant_id=str(tenant_id), peers=len(peers)) + peers_result = await admin_db.execute( + select(VpnPeer).where( + VpnPeer.tenant_id == config.tenant_id, + VpnPeer.is_enabled.is_(True), + ) + ) + peers = peers_result.scalars().all() + + if peers: + lines.append(f"# --- Tenant: {tenant_name} ({config.subnet}) ---") + + for peer in peers: + peer_ip = peer.assigned_ip.split("/")[0] + allowed_ips = [f"{peer_ip}/32"] + if peer.additional_allowed_ips: + extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()] + allowed_ips.extend(extra) + lines.append("[Peer]") + lines.append(f"PublicKey = {peer.peer_public_key}") + if peer.preshared_key: + psk = decrypt_credentials(peer.preshared_key, key_bytes) + lines.append(f"PresharedKey = {psk}") + lines.append(f"AllowedIPs = {', '.join(allowed_ips)}") + lines.append("") + total_peers += 1 + + # Atomic write: temp file + rename + config_dir = _get_wg_config_path() + wg_confs_dir = config_dir / "wg_confs" + wg_confs_dir.mkdir(parents=True, exist_ok=True) + + conf_path = wg_confs_dir / "wg0.conf" + tmp_path = wg_confs_dir / "wg0.conf.tmp" + tmp_path.write_text("\n".join(lines)) + os.rename(str(tmp_path), str(conf_path)) + + # Signal WireGuard container to reload + reload_flag = wg_confs_dir / ".reload" + reload_flag.write_text("1") + + logger.info("wireguard_config_synced", event="vpn_audit", + tenants=len(configs), peers=total_peers) + + finally: + # Release advisory lock explicitly (session-level lock, not xact-level) + await admin_db.execute(sa_text("SELECT pg_advisory_unlock(hashtext('wireguard_config'))")) # ── Live Status ── @@ -246,13 +287,23 @@ async def get_vpn_config(db: AsyncSession, tenant_id: uuid.UUID) -> Optional[Vpn async def setup_vpn( db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None ) -> VpnConfig: - """Initialize VPN for a tenant — generates server keys and creates config.""" + """Initialize VPN for a tenant — allocates unique subnet, uses global server key.""" existing = await get_vpn_config(db, tenant_id) if existing: raise ValueError("VPN already configured for this tenant") - private_key_b64, public_key_b64 = generate_wireguard_keypair() + # Get or create global server keypair + _, public_key_b64 = await _get_or_create_global_server_key(db) + # Allocate unique subnet + subnet_index = await _allocate_subnet_index(db) + subnet = f"10.10.{subnet_index}.0/24" + server_address = f"10.10.{subnet_index}.1/24" + + # Generate a per-tenant key for the deprecated server_private_key column. + # This column is NOT NULL and kept for rollback safety. The global key + # in system_settings is authoritative; this per-tenant key is unused. + private_key_b64, _ = generate_wireguard_keypair() key_bytes = settings.get_encryption_key_bytes() encrypted_private = encrypt_credentials(private_key_b64, key_bytes) @@ -260,13 +311,19 @@ async def setup_vpn( tenant_id=tenant_id, server_private_key=encrypted_private, server_public_key=public_key_b64, + subnet_index=subnet_index, + subnet=subnet, + server_address=server_address, endpoint=endpoint, is_enabled=True, ) db.add(config) await db.flush() - await sync_wireguard_config(db, tenant_id) + logger.info("vpn_subnet_allocated", event="vpn_audit", + tenant_id=str(tenant_id), subnet_index=subnet_index, subnet=subnet) + + await sync_wireguard_config(db) return config @@ -284,7 +341,7 @@ async def update_vpn_config( config.is_enabled = is_enabled await db.flush() - await sync_wireguard_config(db, tenant_id) + await sync_wireguard_config(db) return config @@ -330,6 +387,8 @@ async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID, if existing.scalar_one_or_none(): raise ValueError("Device is already a VPN peer") + _validate_additional_allowed_ips(additional_allowed_ips) + private_key_b64, public_key_b64 = generate_wireguard_keypair() psk = generate_preshared_key() @@ -351,7 +410,7 @@ async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID, db.add(peer) await db.flush() - await sync_wireguard_config(db, tenant_id) + await sync_wireguard_config(db) return peer @@ -366,7 +425,7 @@ async def remove_peer(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID await db.delete(peer) await db.flush() - await sync_wireguard_config(db, tenant_id) + await sync_wireguard_config(db) async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> dict: @@ -464,7 +523,7 @@ async def onboard_device( db.add(peer) await db.flush() - await sync_wireguard_config(db, tenant_id) + await sync_wireguard_config(db) # Generate RouterOS commands endpoint = config.endpoint or "YOUR_SERVER_IP:51820"