feat(vpn): add global server key helpers, subnet allocation, and allowed-IPs validation

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Jason Staack
2026-03-14 16:27:35 -05:00
parent 593323d277
commit 93fe935edf
2 changed files with 151 additions and 0 deletions

View File

@@ -47,6 +47,102 @@ def generate_preshared_key() -> str:
return base64.b64encode(os.urandom(32)).decode()
# ── Global Server Key & Subnet Allocation ──
async def _get_or_create_global_server_key(db: AsyncSession) -> tuple[str, str]:
"""Get (or create on first call) the global WireGuard server keypair.
Returns (private_key_b64, public_key_b64). Private key is decrypted.
Uses an advisory lock to prevent race conditions on first-time generation.
"""
from sqlalchemy import text as sa_text
# Advisory lock prevents two simultaneous first-calls from generating different keypairs
await db.execute(sa_text("SELECT pg_advisory_xact_lock(hashtext('vpn_server_keygen'))"))
result = await db.execute(
sa_text("SELECT key, value, encrypted_value FROM system_settings WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')")
)
rows = {row[0]: row for row in result.fetchall()}
if "vpn_server_public_key" in rows and "vpn_server_private_key" in rows:
public_key = rows["vpn_server_public_key"][1]
encrypted_private = rows["vpn_server_private_key"][2]
key_bytes = settings.get_encryption_key_bytes()
private_key = decrypt_credentials(encrypted_private, key_bytes)
return private_key, public_key
# First call on fresh install — generate and store
private_key_b64, public_key_b64 = generate_wireguard_keypair()
key_bytes = settings.get_encryption_key_bytes()
encrypted_private = encrypt_credentials(private_key_b64, key_bytes)
await db.execute(
sa_text("""
INSERT INTO system_settings (key, value, updated_at)
VALUES ('vpn_server_public_key', :pub, now())
ON CONFLICT (key) DO UPDATE SET value = :pub, updated_at = now()
"""),
{"pub": public_key_b64},
)
await db.execute(
sa_text("""
INSERT INTO system_settings (key, value, encrypted_value, updated_at)
VALUES ('vpn_server_private_key', NULL, :enc, now())
ON CONFLICT (key) DO UPDATE SET encrypted_value = :enc, updated_at = now()
"""),
{"enc": encrypted_private},
)
await db.flush()
logger.info("vpn_global_server_keypair_generated", event="vpn_audit")
return private_key_b64, public_key_b64
def _allocate_subnet_index_from_used(used: set[int]) -> int:
"""Find the first available subnet index in [1, 255] not in `used`.
Pure function for unit testing. Raises ValueError if pool exhausted.
"""
for i in range(1, 256):
if i not in used:
return i
raise ValueError("VPN subnet pool exhausted")
async def _allocate_subnet_index(db: AsyncSession) -> int:
"""Allocate next available subnet_index from the database.
Uses gap-filling: finds the lowest integer in [1,255] not already used.
The UNIQUE constraint on subnet_index protects against races.
"""
result = await db.execute(select(VpnConfig.subnet_index))
used = {row[0] for row in result.all()}
return _allocate_subnet_index_from_used(used)
_VPN_ADDRESS_SPACE = ipaddress.ip_network("10.10.0.0/16")
def _validate_additional_allowed_ips(additional_allowed_ips: str | None) -> None:
"""Reject additional_allowed_ips that overlap the VPN address space (10.10.0.0/16)."""
if not additional_allowed_ips:
return
for entry in additional_allowed_ips.split(","):
entry = entry.strip()
if not entry:
continue
try:
network = ipaddress.ip_network(entry, strict=False)
except ValueError:
continue # let WireGuard reject malformed entries
if network.overlaps(_VPN_ADDRESS_SPACE):
raise ValueError(
"Additional allowed IPs must not overlap the VPN address space (10.10.0.0/16)"
)
# ── Config File Management ──

View File

@@ -0,0 +1,55 @@
"""Unit tests for VPN subnet allocation and allowed-IPs validation."""
import pytest
from app.services.vpn_service import _allocate_subnet_index_from_used, _validate_additional_allowed_ips
class TestSubnetAllocation:
def test_first_allocation_returns_1(self):
assert _allocate_subnet_index_from_used(set()) == 1
def test_sequential_allocation(self):
assert _allocate_subnet_index_from_used({1}) == 2
assert _allocate_subnet_index_from_used({1, 2}) == 3
def test_gap_filling(self):
assert _allocate_subnet_index_from_used({1, 3}) == 2
assert _allocate_subnet_index_from_used({2, 3}) == 1
def test_pool_exhausted(self):
with pytest.raises(ValueError, match="subnet pool exhausted"):
_allocate_subnet_index_from_used(set(range(1, 256)))
def test_max_allocation(self):
used = set(range(1, 255))
assert _allocate_subnet_index_from_used(used) == 255
class TestAllowedIpsValidation:
def test_none_is_valid(self):
_validate_additional_allowed_ips(None)
def test_empty_is_valid(self):
_validate_additional_allowed_ips("")
def test_non_vpn_subnet_is_valid(self):
_validate_additional_allowed_ips("192.168.1.0/24")
def test_multiple_non_vpn_subnets_valid(self):
_validate_additional_allowed_ips("192.168.1.0/24, 172.16.0.0/16")
def test_vpn_subnet_rejected(self):
with pytest.raises(ValueError, match="must not overlap"):
_validate_additional_allowed_ips("10.10.5.0/24")
def test_vpn_supernet_rejected(self):
with pytest.raises(ValueError, match="must not overlap"):
_validate_additional_allowed_ips("10.10.0.0/16")
def test_vpn_host_rejected(self):
with pytest.raises(ValueError, match="must not overlap"):
_validate_additional_allowed_ips("10.10.1.5/32")
def test_mixed_valid_and_invalid_rejected(self):
with pytest.raises(ValueError, match="must not overlap"):
_validate_additional_allowed_ips("192.168.1.0/24, 10.10.2.0/24")