From 93fe935edf1af02fbdf6204baed1a04da25afa6a Mon Sep 17 00:00:00 2001 From: Jason Staack Date: Sat, 14 Mar 2026 16:27:35 -0500 Subject: [PATCH] feat(vpn): add global server key helpers, subnet allocation, and allowed-IPs validation Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/app/services/vpn_service.py | 96 +++++++++++++++++++++++++++ backend/tests/unit/test_vpn_subnet.py | 55 +++++++++++++++ 2 files changed, 151 insertions(+) create mode 100644 backend/tests/unit/test_vpn_subnet.py diff --git a/backend/app/services/vpn_service.py b/backend/app/services/vpn_service.py index 947715c..b6a1af8 100644 --- a/backend/app/services/vpn_service.py +++ b/backend/app/services/vpn_service.py @@ -47,6 +47,102 @@ def generate_preshared_key() -> str: return base64.b64encode(os.urandom(32)).decode() +# ── Global Server Key & Subnet Allocation ── + + +async def _get_or_create_global_server_key(db: AsyncSession) -> tuple[str, str]: + """Get (or create on first call) the global WireGuard server keypair. + + Returns (private_key_b64, public_key_b64). Private key is decrypted. + Uses an advisory lock to prevent race conditions on first-time generation. + """ + from sqlalchemy import text as sa_text + + # Advisory lock prevents two simultaneous first-calls from generating different keypairs + await db.execute(sa_text("SELECT pg_advisory_xact_lock(hashtext('vpn_server_keygen'))")) + + result = await db.execute( + sa_text("SELECT key, value, encrypted_value FROM system_settings WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')") + ) + rows = {row[0]: row for row in result.fetchall()} + + if "vpn_server_public_key" in rows and "vpn_server_private_key" in rows: + public_key = rows["vpn_server_public_key"][1] + encrypted_private = rows["vpn_server_private_key"][2] + key_bytes = settings.get_encryption_key_bytes() + private_key = decrypt_credentials(encrypted_private, key_bytes) + return private_key, public_key + + # First call on fresh install — generate and store + private_key_b64, public_key_b64 = generate_wireguard_keypair() + key_bytes = settings.get_encryption_key_bytes() + encrypted_private = encrypt_credentials(private_key_b64, key_bytes) + + await db.execute( + sa_text(""" + INSERT INTO system_settings (key, value, updated_at) + VALUES ('vpn_server_public_key', :pub, now()) + ON CONFLICT (key) DO UPDATE SET value = :pub, updated_at = now() + """), + {"pub": public_key_b64}, + ) + await db.execute( + sa_text(""" + INSERT INTO system_settings (key, value, encrypted_value, updated_at) + VALUES ('vpn_server_private_key', NULL, :enc, now()) + ON CONFLICT (key) DO UPDATE SET encrypted_value = :enc, updated_at = now() + """), + {"enc": encrypted_private}, + ) + await db.flush() + + logger.info("vpn_global_server_keypair_generated", event="vpn_audit") + return private_key_b64, public_key_b64 + + +def _allocate_subnet_index_from_used(used: set[int]) -> int: + """Find the first available subnet index in [1, 255] not in `used`. + + Pure function for unit testing. Raises ValueError if pool exhausted. + """ + for i in range(1, 256): + if i not in used: + return i + raise ValueError("VPN subnet pool exhausted") + + +async def _allocate_subnet_index(db: AsyncSession) -> int: + """Allocate next available subnet_index from the database. + + Uses gap-filling: finds the lowest integer in [1,255] not already used. + The UNIQUE constraint on subnet_index protects against races. + """ + result = await db.execute(select(VpnConfig.subnet_index)) + used = {row[0] for row in result.all()} + return _allocate_subnet_index_from_used(used) + + +_VPN_ADDRESS_SPACE = ipaddress.ip_network("10.10.0.0/16") + + +def _validate_additional_allowed_ips(additional_allowed_ips: str | None) -> None: + """Reject additional_allowed_ips that overlap the VPN address space (10.10.0.0/16).""" + if not additional_allowed_ips: + return + for entry in additional_allowed_ips.split(","): + entry = entry.strip() + if not entry: + continue + try: + network = ipaddress.ip_network(entry, strict=False) + except ValueError: + continue # let WireGuard reject malformed entries + if network.overlaps(_VPN_ADDRESS_SPACE): + raise ValueError( + "Additional allowed IPs must not overlap the VPN address space (10.10.0.0/16)" + ) + + # ── Config File Management ── diff --git a/backend/tests/unit/test_vpn_subnet.py b/backend/tests/unit/test_vpn_subnet.py new file mode 100644 index 0000000..7ace7d2 --- /dev/null +++ b/backend/tests/unit/test_vpn_subnet.py @@ -0,0 +1,55 @@ +"""Unit tests for VPN subnet allocation and allowed-IPs validation.""" + +import pytest +from app.services.vpn_service import _allocate_subnet_index_from_used, _validate_additional_allowed_ips + + +class TestSubnetAllocation: + def test_first_allocation_returns_1(self): + assert _allocate_subnet_index_from_used(set()) == 1 + + def test_sequential_allocation(self): + assert _allocate_subnet_index_from_used({1}) == 2 + assert _allocate_subnet_index_from_used({1, 2}) == 3 + + def test_gap_filling(self): + assert _allocate_subnet_index_from_used({1, 3}) == 2 + assert _allocate_subnet_index_from_used({2, 3}) == 1 + + def test_pool_exhausted(self): + with pytest.raises(ValueError, match="subnet pool exhausted"): + _allocate_subnet_index_from_used(set(range(1, 256))) + + def test_max_allocation(self): + used = set(range(1, 255)) + assert _allocate_subnet_index_from_used(used) == 255 + + +class TestAllowedIpsValidation: + def test_none_is_valid(self): + _validate_additional_allowed_ips(None) + + def test_empty_is_valid(self): + _validate_additional_allowed_ips("") + + def test_non_vpn_subnet_is_valid(self): + _validate_additional_allowed_ips("192.168.1.0/24") + + def test_multiple_non_vpn_subnets_valid(self): + _validate_additional_allowed_ips("192.168.1.0/24, 172.16.0.0/16") + + def test_vpn_subnet_rejected(self): + with pytest.raises(ValueError, match="must not overlap"): + _validate_additional_allowed_ips("10.10.5.0/24") + + def test_vpn_supernet_rejected(self): + with pytest.raises(ValueError, match="must not overlap"): + _validate_additional_allowed_ips("10.10.0.0/16") + + def test_vpn_host_rejected(self): + with pytest.raises(ValueError, match="must not overlap"): + _validate_additional_allowed_ips("10.10.1.5/32") + + def test_mixed_valid_and_invalid_rejected(self): + with pytest.raises(ValueError, match="must not overlap"): + _validate_additional_allowed_ips("192.168.1.0/24, 10.10.2.0/24")