fix(ci): switch to commit-and-cleanup test isolation

Replace savepoint/shared-connection approach with real commits and
table cleanup in teardown. This ensures test data is visible to API
endpoint sessions without connection sharing deadlocks.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Jason Staack
2026-03-15 06:19:12 -05:00
parent d30c4ab522
commit eb60b219b8

View File

@@ -3,18 +3,19 @@ Integration test fixtures for the TOD backend.
Provides:
- Database engines (admin + app_user) pointing at real PostgreSQL+TimescaleDB - Database engines (admin + app_user) pointing at real PostgreSQL+TimescaleDB
- Per-test session fixtures with transaction rollback for isolation - Per-test session with commit-and-cleanup (no savepoint tricks)
- app_session_factory for RLS multi-tenant tests (creates sessions with tenant context) - app_session_factory for RLS multi-tenant tests
- FastAPI test client with dependency overrides - FastAPI test client with dependency overrides
- Entity factory fixtures (tenants, users, devices) - Entity factory fixtures (tenants, users, devices)
- Auth helper for getting login tokens - Auth helper that mints JWTs directly
All fixtures use the existing docker-compose PostgreSQL instance.
Set TEST_DATABASE_URL / TEST_APP_USER_DATABASE_URL env vars to override defaults.
Event loop strategy: All async fixtures are function-scoped to avoid the Test isolation strategy: tests commit data to the real database, then
pytest-asyncio 0.26 session/function loop mismatch. Engine creation and DB clean up all test-created rows in teardown. This avoids the savepoint
setup use synchronous subprocess calls (Alembic) and module-level singletons. visibility issues that break FK checks when API endpoints use separate
DB sessions.
""" """
import os import os
@@ -31,6 +32,7 @@ from httpx import ASGITransport, AsyncClient
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.ext.asyncio import ( from sqlalchemy.ext.asyncio import (
AsyncSession, AsyncSession,
async_sessionmaker,
create_async_engine, create_async_engine,
) )
@@ -64,16 +66,9 @@ def _ensure_database_setup():
backend_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) backend_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
env = os.environ.copy() env = os.environ.copy()
# Ensure DATABASE_URL points at the test database, not the dev/prod URL
# hardcoded in alembic.ini. alembic/env.py reads this variable and overrides
# sqlalchemy.url before opening any connection.
env["DATABASE_URL"] = TEST_DATABASE_URL env["DATABASE_URL"] = TEST_DATABASE_URL
# Migration 029 (VPN tenant isolation) encrypts a WireGuard server private key
# and requires CREDENTIAL_ENCRYPTION_KEY. Provide the dev default if the
# environment does not already supply it (CI always sets this explicitly).
env.setdefault("CREDENTIAL_ENCRYPTION_KEY", "LLLjnfBZTSycvL2U07HDSxUeTtLxb9cZzryQl0R9E4w=") env.setdefault("CREDENTIAL_ENCRYPTION_KEY", "LLLjnfBZTSycvL2U07HDSxUeTtLxb9cZzryQl0R9E4w=")
# Run Alembic migrations via subprocess (handles DB creation and schema)
result = subprocess.run( result = subprocess.run(
[sys.executable, "-m", "alembic", "upgrade", "head"], [sys.executable, "-m", "alembic", "upgrade", "head"],
capture_output=True, capture_output=True,
@@ -99,13 +94,9 @@ def setup_database():
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def admin_engine(): async def admin_engine():
"""Admin engine (superuser) -- bypasses RLS. """Admin engine (superuser) -- bypasses RLS."""
Created fresh per-test to avoid event loop issues.
pool_size=2 since each test only needs a few connections.
"""
engine = create_async_engine( engine = create_async_engine(
TEST_DATABASE_URL, echo=False, pool_pre_ping=True, pool_size=2, max_overflow=3 TEST_DATABASE_URL, echo=False, pool_pre_ping=True, pool_size=5, max_overflow=5
) )
yield engine yield engine
await engine.dispose() await engine.dispose()
@@ -113,46 +104,77 @@ async def admin_engine():
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def app_engine(): async def app_engine():
"""App-user engine -- RLS enforced. """App-user engine -- RLS enforced."""
Created fresh per-test to avoid event loop issues.
"""
engine = create_async_engine( engine = create_async_engine(
TEST_APP_USER_DATABASE_URL, echo=False, pool_pre_ping=True, pool_size=2, max_overflow=3 TEST_APP_USER_DATABASE_URL, echo=False, pool_pre_ping=True, pool_size=5, max_overflow=5
) )
yield engine yield engine
await engine.dispose() await engine.dispose()
# ---------------------------------------------------------------------------
# Tables to clean up after each test (reverse FK order)
# ---------------------------------------------------------------------------
# Tables emptied in teardown by the admin_session fixture, which runs a plain
# `DELETE FROM <table>` for each entry in order. The ordering is significant:
# child tables appear before the parent tables they reference via foreign keys
# (e.g. "vpn_peers" before "vpn_config", "devices" before "tenants", with
# "tenants" last), so deletes never violate FK constraints.
# NOTE(review): any new table added by a migration must also be added here, in
# the correct position, or its rows will leak between tests.
_CLEANUP_TABLES = [
    "key_access_log",
    "user_key_sets",
    "device_certificates",
    "certificate_authorities",
    "template_push_jobs",
    "config_template_tags",
    "config_templates",
    "config_push_operations",
    "config_backup_schedules",
    "config_backup_runs",
    "router_config_changes",
    "router_config_diffs",
    "router_config_snapshots",
    "firmware_upgrade_jobs",
    "alert_rule_channels",
    "alert_events",
    "alert_rules",
    "notification_channels",
    "maintenance_windows",
    "vpn_peers",
    "vpn_config",
    "device_tag_assignments",
    "device_group_memberships",
    "device_tags",
    "device_groups",
    "api_keys",
    "audit_logs",
    "devices",
    "invites",
    "user_tenants",
    "users",
    "tenants",
]
# ---------------------------------------------------------------------------
# Session fixtures (commit-and-cleanup, no savepoints)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def admin_conn(admin_engine): async def admin_session(admin_engine) -> AsyncGenerator[AsyncSession, None]:
"""Shared admin connection with transaction rollback. """Per-test admin session that commits to disk.
Both admin_session and test_app bind to the same connection so that Data is visible to all connections (including API endpoint sessions).
data created in the test (via admin_session) is visible to API Cleanup deletes all rows from test tables after the test.
endpoints (via get_db / get_admin_db overrides).
""" """
conn = await admin_engine.connect() session = AsyncSession(admin_engine, expire_on_commit=False)
trans = await conn.begin()
try:
yield conn
finally:
await trans.rollback()
await conn.close()
@pytest_asyncio.fixture
async def admin_session(admin_conn) -> AsyncGenerator[AsyncSession, None]:
"""Per-test admin session sharing the admin_conn transaction."""
session = AsyncSession(bind=admin_conn, expire_on_commit=False)
try: try:
yield session yield session
finally: finally:
# Clean up all test data in reverse FK order
for table in _CLEANUP_TABLES:
try:
await session.execute(text(f"DELETE FROM {table}"))
except Exception:
pass # Table might not exist in some migration states
await session.commit()
await session.close() await session.close()
@@ -165,7 +187,6 @@ async def app_session(app_engine) -> AsyncGenerator[AsyncSession, None]:
async with app_engine.connect() as conn: async with app_engine.connect() as conn:
trans = await conn.begin() trans = await conn.begin()
session = AsyncSession(bind=conn, expire_on_commit=False) session = AsyncSession(bind=conn, expire_on_commit=False)
# Reset tenant context
await session.execute(text("RESET app.current_tenant")) await session.execute(text("RESET app.current_tenant"))
try: try:
yield session yield session
@@ -180,10 +201,6 @@ def app_session_factory(app_engine):
Each session gets its own connection and transaction (rolled back on exit). Each session gets its own connection and transaction (rolled back on exit).
Caller can pass tenant_id to auto-set RLS context. Caller can pass tenant_id to auto-set RLS context.
Usage:
async with app_session_factory(tenant_id=str(tenant.id)) as session:
result = await session.execute(select(Device))
""" """
from app.database import set_tenant_context from app.database import set_tenant_context
@@ -192,7 +209,6 @@ def app_session_factory(app_engine):
async with app_engine.connect() as conn: async with app_engine.connect() as conn:
trans = await conn.begin() trans = await conn.begin()
session = AsyncSession(bind=conn, expire_on_commit=False) session = AsyncSession(bind=conn, expire_on_commit=False)
# Reset tenant context to prevent leakage
await session.execute(text("RESET app.current_tenant")) await session.execute(text("RESET app.current_tenant"))
if tenant_id: if tenant_id:
await set_tenant_context(session, tenant_id) await set_tenant_context(session, tenant_id)
@@ -211,21 +227,19 @@ def app_session_factory(app_engine):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_app(admin_conn, app_engine): async def test_app(admin_engine, app_engine):
"""Create a FastAPI app instance with test database dependency overrides. """Create a FastAPI app instance with test database dependency overrides.
Both get_db and get_admin_db bind to admin_conn (the shared connection Both get_db and get_admin_db create independent sessions from the admin
from admin_conn fixture). This means data created via admin_session engine. Since admin_session commits data to disk, it is visible to
is visible to API endpoints, and everything rolls back after the test. these sessions without needing shared connections.
""" """
from fastapi import FastAPI from fastapi import FastAPI
from app.database import get_admin_db, get_db from app.database import get_admin_db, get_db
# Create a minimal app without lifespan
app = FastAPI(lifespan=None) app = FastAPI(lifespan=None)
# Import and mount all routers (same as main app)
from app.routers.alerts import router as alerts_router from app.routers.alerts import router as alerts_router
from app.routers.auth import router as auth_router from app.routers.auth import router as auth_router
from app.routers.config_backups import router as config_router from app.routers.config_backups import router as config_router
@@ -254,26 +268,31 @@ async def test_app(admin_conn, app_engine):
app.include_router(templates_router, prefix="/api") app.include_router(templates_router, prefix="/api")
app.include_router(vpn_router, prefix="/api") app.include_router(vpn_router, prefix="/api")
# Register rate limiter (auth endpoints use @limiter.limit)
from app.middleware.rate_limit import setup_rate_limiting from app.middleware.rate_limit import setup_rate_limiting
setup_rate_limiting(app) setup_rate_limiting(app)
# API endpoints bind to the same shared connection as admin_session test_session_factory = async_sessionmaker(
# so test-created data is visible across the transaction. admin_engine, class_=AsyncSession, expire_on_commit=False
)
async def override_get_db() -> AsyncGenerator[AsyncSession, None]: async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
session = AsyncSession(bind=admin_conn, expire_on_commit=False) async with test_session_factory() as session:
try: try:
yield session yield session
finally: await session.commit()
await session.close() except Exception:
await session.rollback()
raise
async def override_get_admin_db() -> AsyncGenerator[AsyncSession, None]: async def override_get_admin_db() -> AsyncGenerator[AsyncSession, None]:
session = AsyncSession(bind=admin_conn, expire_on_commit=False) async with test_session_factory() as session:
try: try:
yield session yield session
finally: await session.commit()
await session.close() except Exception:
await session.rollback()
raise
app.dependency_overrides[get_db] = override_get_db app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_admin_db] = override_get_admin_db app.dependency_overrides[get_admin_db] = override_get_admin_db
@@ -293,12 +312,11 @@ async def client(test_app) -> AsyncGenerator[AsyncClient, None]:
import redis import redis
try: try:
# Rate limiter uses Redis DB 1 (see app/middleware/rate_limit.py)
r = redis.Redis(host="localhost", port=6379, db=1) r = redis.Redis(host="localhost", port=6379, db=1)
r.flushdb() r.flushdb()
r.close() r.close()
except Exception: except Exception:
pass # Redis not available -- skip clearing pass
transport = ASGITransport(app=test_app) transport = ASGITransport(app=test_app)
async with AsyncClient(transport=transport, base_url="http://test") as ac: async with AsyncClient(transport=transport, base_url="http://test") as ac:
@@ -394,12 +412,8 @@ def create_test_device():
def auth_headers_factory(create_test_tenant, create_test_user): def auth_headers_factory(create_test_tenant, create_test_user):
"""Factory to create authenticated headers for a test user. """Factory to create authenticated headers for a test user.
Creates a tenant + user, generates a JWT directly (no HTTP login Creates a tenant + user, commits to disk, then mints a JWT directly.
round-trip), and returns the Authorization headers dict. The commit ensures data is visible to API endpoint sessions.
We mint the token directly rather than going through /api/auth/login
because the test admin_session uses a savepoint transaction that is
invisible to the login endpoint's own DB session.
""" """
async def _create( async def _create(
@@ -410,7 +424,7 @@ def auth_headers_factory(create_test_tenant, create_test_user):
tenant_name: str | None = None, tenant_name: str | None = None,
existing_tenant_id: uuid.UUID | None = None, existing_tenant_id: uuid.UUID | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Create user, mint JWT, return headers + tenant/user info.""" """Create user, commit, mint JWT, return headers + tenant/user info."""
from app.services.auth import create_access_token from app.services.auth import create_access_token
if existing_tenant_id: if existing_tenant_id:
@@ -426,7 +440,8 @@ def auth_headers_factory(create_test_tenant, create_test_user):
password=password, password=password,
role=role, role=role,
) )
await admin_session.flush() # Commit to disk so API endpoint sessions can see this data
await admin_session.commit()
access_token = create_access_token( access_token = create_access_token(
user_id=user.id, user_id=user.id,