feat: The Other Dude v9.0.1 — full-featured email system
ci: add GitHub Pages deployment workflow for docs site Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
16
backend/tests/conftest.py
Normal file
16
backend/tests/conftest.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Shared test fixtures for the backend test suite.
|
||||
|
||||
Phase 7: Minimal fixtures for unit tests (no database, no async).
|
||||
Phase 10: Integration test fixtures added in tests/integration/conftest.py.
|
||||
|
||||
Pytest marker registration and shared configuration lives here.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_configure(config):
    """Register the project's custom pytest markers with the config object."""
    marker_spec = (
        "markers",
        "integration: marks tests as integration tests requiring PostgreSQL",
    )
    config.addinivalue_line(*marker_spec)
|
||||
2
backend/tests/integration/__init__.py
Normal file
2
backend/tests/integration/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Integration tests for TOD backend.
|
||||
# Run against real PostgreSQL+TimescaleDB via docker-compose.
|
||||
439
backend/tests/integration/conftest.py
Normal file
439
backend/tests/integration/conftest.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""
|
||||
Integration test fixtures for the TOD backend.
|
||||
|
||||
Provides:
|
||||
- Database engines (admin + app_user) pointing at real PostgreSQL+TimescaleDB
|
||||
- Per-test session fixtures with transaction rollback for isolation
|
||||
- app_session_factory for RLS multi-tenant tests (creates sessions with tenant context)
|
||||
- FastAPI test client with dependency overrides
|
||||
- Entity factory fixtures (tenants, users, devices)
|
||||
- Auth helper for getting login tokens
|
||||
|
||||
All fixtures use the existing docker-compose PostgreSQL instance.
|
||||
Set TEST_DATABASE_URL / TEST_APP_USER_DATABASE_URL env vars to override defaults.
|
||||
|
||||
Event loop strategy: All async fixtures are function-scoped to avoid the
|
||||
pytest-asyncio 0.26 session/function loop mismatch. Engine creation and DB
|
||||
setup use synchronous subprocess calls (Alembic) and module-level singletons.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TEST_DATABASE_URL = os.environ.get(
|
||||
"TEST_DATABASE_URL",
|
||||
"postgresql+asyncpg://postgres:postgres@localhost:5432/mikrotik_test",
|
||||
)
|
||||
TEST_APP_USER_DATABASE_URL = os.environ.get(
|
||||
"TEST_APP_USER_DATABASE_URL",
|
||||
"postgresql+asyncpg://app_user:app_password@localhost:5432/mikrotik_test",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# One-time database setup (runs once per session via autouse sync fixture)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DB_SETUP_DONE = False
|
||||
|
||||
|
||||
def _ensure_database_setup():
    """Synchronous one-time DB setup: run Alembic migrations for the test DB.

    Idempotent via the module-level ``_DB_SETUP_DONE`` flag. The flag is set
    only *after* the migration succeeds, so a failed migration can be retried
    on the next call instead of being silently skipped (the original code set
    the flag before running Alembic).

    Raises:
        RuntimeError: if the Alembic subprocess exits non-zero.
    """
    global _DB_SETUP_DONE
    if _DB_SETUP_DONE:
        return

    # tests/integration/conftest.py -> tests/ -> backend/
    backend_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    env = os.environ.copy()
    env["DATABASE_URL"] = TEST_DATABASE_URL

    # Run Alembic migrations via subprocess (handles DB creation and schema);
    # a subprocess keeps this fully synchronous — no event loop involvement.
    result = subprocess.run(
        [sys.executable, "-m", "alembic", "upgrade", "head"],
        capture_output=True,
        text=True,
        cwd=backend_dir,
        env=env,
    )
    if result.returncode != 0:
        raise RuntimeError(f"Alembic migration failed:\n{result.stderr}")

    # Only mark setup complete once the migration actually succeeded.
    _DB_SETUP_DONE = True
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def setup_database():
    """Session-scoped sync fixture: ensures DB schema is ready.

    Autouse, so it runs once before any integration test in the session,
    delegating to _ensure_database_setup() (Alembic via subprocess).
    No teardown work is needed after the yield.
    """
    _ensure_database_setup()
    yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Engine fixtures (function-scoped to stay on same event loop as tests)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def admin_engine():
    """Yield a superuser engine (bypasses RLS); disposed at teardown.

    Function-scoped so the engine always lives on the same event loop as
    the test using it. A tiny pool is sufficient for a single test.
    """
    engine_opts = dict(echo=False, pool_pre_ping=True, pool_size=2, max_overflow=3)
    engine = create_async_engine(TEST_DATABASE_URL, **engine_opts)
    yield engine
    await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def app_engine():
    """Yield an app_user engine (RLS enforced); disposed at teardown.

    Function-scoped to keep the engine on the test's own event loop.
    """
    engine_opts = dict(echo=False, pool_pre_ping=True, pool_size=2, max_overflow=3)
    engine = create_async_engine(TEST_APP_USER_DATABASE_URL, **engine_opts)
    yield engine
    await engine.dispose()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Function-scoped session fixtures (fresh per test)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def admin_session(admin_engine) -> AsyncGenerator[AsyncSession, None]:
    """Yield an admin AsyncSession inside an always-rolled-back transaction.

    The rollback in the finally block guarantees that nothing a test writes
    through this session persists, so tests cannot leak state into each other.
    """
    async with admin_engine.connect() as conn:
        outer_tx = await conn.begin()
        db = AsyncSession(bind=conn, expire_on_commit=False)
        try:
            yield db
        finally:
            await outer_tx.rollback()
            await db.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def app_session(app_engine) -> AsyncGenerator[AsyncSession, None]:
    """Yield an app_user AsyncSession (RLS enforced) on a rolled-back transaction.

    The tenant GUC is RESET before the test so no context leaks in from a
    pooled connection. Callers must call set_tenant_context() before querying.
    """
    async with app_engine.connect() as conn:
        outer_tx = await conn.begin()
        db = AsyncSession(bind=conn, expire_on_commit=False)
        # Clear any tenant context left over on the pooled connection.
        await db.execute(text("RESET app.current_tenant"))
        try:
            yield db
        finally:
            await outer_tx.rollback()
            await db.close()
|
||||
|
||||
|
||||
@pytest.fixture
def app_session_factory(app_engine):
    """Return an async-context-manager factory for app_user (RLS) sessions.

    Each session produced gets its own connection and transaction, rolled
    back on exit so nothing persists. Passing tenant_id pre-sets the RLS
    context via set_tenant_context().

    Usage:
        async with app_session_factory(tenant_id=str(tenant.id)) as session:
            result = await session.execute(select(Device))
    """
    from app.database import set_tenant_context

    @asynccontextmanager
    async def _make(tenant_id: str | None = None):
        async with app_engine.connect() as conn:
            tx = await conn.begin()
            db = AsyncSession(bind=conn, expire_on_commit=False)
            # Start from a clean slate, then optionally pin the tenant.
            await db.execute(text("RESET app.current_tenant"))
            if tenant_id:
                await set_tenant_context(db, tenant_id)
            try:
                yield db
            finally:
                await tx.rollback()
                await db.close()

    return _make
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FastAPI test app and HTTP client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def test_app(admin_engine, app_engine):
    """Create a FastAPI app instance with test database dependency overrides.

    - get_db uses app_engine (non-superuser, RLS enforced) so tenant
      isolation is tested correctly at the API level.
    - get_admin_db uses admin_engine (superuser) for auth/bootstrap routes.
    - Disables lifespan to skip migrations, NATS, and scheduler startup.

    Yields the configured app; dependency overrides are cleared at teardown.
    """
    from fastapi import FastAPI

    from app.database import get_admin_db, get_db

    # Create a minimal app without lifespan
    app = FastAPI(lifespan=None)

    # Import and mount all routers (same as main app)
    from app.routers.alerts import router as alerts_router
    from app.routers.auth import router as auth_router
    from app.routers.config_backups import router as config_router
    from app.routers.config_editor import router as config_editor_router
    from app.routers.device_groups import router as device_groups_router
    from app.routers.device_tags import router as device_tags_router
    from app.routers.devices import router as devices_router
    from app.routers.firmware import router as firmware_router
    from app.routers.metrics import router as metrics_router
    from app.routers.templates import router as templates_router
    from app.routers.tenants import router as tenants_router
    from app.routers.users import router as users_router

    app.include_router(auth_router, prefix="/api")
    app.include_router(tenants_router, prefix="/api")
    app.include_router(users_router, prefix="/api")
    app.include_router(devices_router, prefix="/api")
    app.include_router(device_groups_router, prefix="/api")
    app.include_router(device_tags_router, prefix="/api")
    app.include_router(metrics_router, prefix="/api")
    app.include_router(config_router, prefix="/api")
    app.include_router(firmware_router, prefix="/api")
    app.include_router(alerts_router, prefix="/api")
    app.include_router(config_editor_router, prefix="/api")
    app.include_router(templates_router, prefix="/api")

    # Register rate limiter (auth endpoints use @limiter.limit)
    from app.middleware.rate_limit import setup_rate_limiting
    setup_rate_limiting(app)

    # Create test session factories (one per engine / privilege level)
    test_admin_session_factory = async_sessionmaker(
        admin_engine, class_=AsyncSession, expire_on_commit=False
    )
    test_app_session_factory = async_sessionmaker(
        app_engine, class_=AsyncSession, expire_on_commit=False
    )

    # get_db uses app_engine (RLS enforced) -- tenant context is set
    # by get_current_user dependency via set_tenant_context()
    async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
        async with test_app_session_factory() as session:
            try:
                yield session
                await session.commit()
            except Exception:
                await session.rollback()
                raise

    # get_admin_db uses admin engine (superuser) for auth/bootstrap
    async def override_get_admin_db() -> AsyncGenerator[AsyncSession, None]:
        async with test_admin_session_factory() as session:
            try:
                yield session
                await session.commit()
            except Exception:
                await session.rollback()
                raise

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_admin_db] = override_get_admin_db

    yield app

    # Teardown: drop overrides so the app object can't leak test wiring.
    app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def client(test_app) -> AsyncGenerator[AsyncClient, None]:
    """Yield an in-process HTTP client bound to the test app via ASGI.

    Before each test, the slowapi rate-limit storage (Redis DB 1) is flushed
    so requests made by earlier tests cannot trigger spurious 429 responses.
    """
    import redis

    try:
        # Rate limiter uses Redis DB 1 (see app/middleware/rate_limit.py)
        limiter_store = redis.Redis(host="localhost", port=6379, db=1)
        limiter_store.flushdb()
        limiter_store.close()
    except Exception:
        pass  # Redis not available -- skip clearing

    async with AsyncClient(
        transport=ASGITransport(app=test_app), base_url="http://test"
    ) as ac:
        yield ac
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entity factory fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def create_test_tenant():
    """Factory fixture: build and flush a Tenant through the given session."""

    async def _create(
        session: AsyncSession,
        name: str | None = None,
    ):
        from app.models.tenant import Tenant

        tenant = Tenant(name=name or f"test-tenant-{uuid.uuid4().hex[:8]}")
        session.add(tenant)
        # Flush (not commit) so server-side defaults like the id are populated
        # while leaving transaction control to the caller.
        await session.flush()
        return tenant

    return _create
|
||||
|
||||
|
||||
@pytest.fixture
def create_test_user():
    """Factory fixture: build and flush a User through the given session."""

    async def _create(
        session: AsyncSession,
        tenant_id: uuid.UUID | None,
        email: str | None = None,
        password: str = "TestPass123!",
        role: str = "tenant_admin",
        name: str = "Test User",
    ):
        from app.models.user import User
        from app.services.auth import hash_password

        record = User(
            email=email or f"test-{uuid.uuid4().hex[:8]}@example.com",
            hashed_password=hash_password(password),
            name=name,
            role=role,
            tenant_id=tenant_id,
            is_active=True,
        )
        session.add(record)
        # Flush so the generated id is available; caller decides when to commit.
        await session.flush()
        return record

    return _create
|
||||
|
||||
|
||||
@pytest.fixture
def create_test_device():
    """Factory fixture: build and flush a Device through the given session."""

    async def _create(
        session: AsyncSession,
        tenant_id: uuid.UUID,
        hostname: str | None = None,
        ip_address: str | None = None,
        status: str = "online",
    ):
        from app.models.device import Device

        if hostname is None:
            hostname = f"router-{uuid.uuid4().hex[:8]}"
        if ip_address is None:
            # Pseudo-random 10.0.x.y address; collisions are unlikely per test.
            ip_address = f"10.0.{uuid.uuid4().int % 256}.{uuid.uuid4().int % 256}"

        device = Device(
            tenant_id=tenant_id,
            hostname=hostname,
            ip_address=ip_address,
            api_port=8728,
            api_ssl_port=8729,
            status=status,
        )
        session.add(device)
        # Flush so the generated id is available; caller decides when to commit.
        await session.flush()
        return device

    return _create
|
||||
|
||||
|
||||
@pytest.fixture
def auth_headers_factory(client, create_test_tenant, create_test_user):
    """Factory to create authenticated headers for a test user.

    Creates a tenant + user, logs in via the test client, and returns
    the Authorization headers dict ready for use in subsequent requests.
    """

    async def _create(
        admin_session: AsyncSession,
        email: str | None = None,
        password: str = "TestPass123!",
        role: str = "tenant_admin",
        tenant_name: str | None = None,
        existing_tenant_id: uuid.UUID | None = None,
    ) -> dict[str, Any]:
        """Create user, login, return headers + tenant/user info."""
        # Reuse the caller-supplied tenant, or create a fresh one.
        if existing_tenant_id:
            tenant_id = existing_tenant_id
        else:
            tenant = await create_test_tenant(admin_session, name=tenant_name)
            tenant_id = tenant.id

        user = await create_test_user(
            admin_session,
            tenant_id=tenant_id,
            email=email,
            password=password,
            role=role,
        )
        # Commit so the login endpoint (which uses a separate session) can
        # actually see the newly created user.
        await admin_session.commit()

        user_email = user.email

        # Login via the API
        login_resp = await client.post(
            "/api/auth/login",
            json={"email": user_email, "password": password},
        )
        assert login_resp.status_code == 200, f"Login failed: {login_resp.text}"
        tokens = login_resp.json()

        return {
            "headers": {"Authorization": f"Bearer {tokens['access_token']}"},
            "access_token": tokens["access_token"],
            "refresh_token": tokens.get("refresh_token"),
            "tenant_id": str(tenant_id),
            "user_id": str(user.id),
            "user_email": user_email,
        }

    return _create
|
||||
275
backend/tests/integration/test_alerts_api.py
Normal file
275
backend/tests/integration/test_alerts_api.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Integration tests for the Alerts API endpoints.
|
||||
|
||||
Tests exercise:
|
||||
- GET /api/tenants/{tenant_id}/alert-rules -- list rules
|
||||
- POST /api/tenants/{tenant_id}/alert-rules -- create rule
|
||||
- PUT /api/tenants/{tenant_id}/alert-rules/{rule_id} -- update rule
|
||||
- DELETE /api/tenants/{tenant_id}/alert-rules/{rule_id} -- delete rule
|
||||
- PATCH /api/tenants/{tenant_id}/alert-rules/{rule_id}/toggle
|
||||
- GET /api/tenants/{tenant_id}/alerts -- list events
|
||||
- GET /api/tenants/{tenant_id}/alerts/active-count -- active count
|
||||
- GET /api/tenants/{tenant_id}/devices/{device_id}/alerts -- device alerts
|
||||
|
||||
All tests run against real PostgreSQL.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
# Baseline payload accepted by POST /alert-rules; tests copy this dict and
# override individual fields (e.g. a unique "name") per case.
VALID_ALERT_RULE = {
    "name": "High CPU Alert",
    "metric": "cpu_load",
    "operator": "gt",
    "threshold": 90.0,
    "duration_polls": 3,
    "severity": "warning",
    "enabled": True,
    "channel_ids": [],
}
|
||||
|
||||
|
||||
class TestAlertRulesCRUD:
    """Alert rules CRUD endpoints.

    Each test provisions its own tenant + authenticated user via
    auth_headers_factory, so tests are independent of each other.
    """

    async def test_list_alert_rules_empty(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """GET /api/tenants/{tenant_id}/alert-rules returns 200 with empty list."""
        auth = await auth_headers_factory(admin_session)
        tenant_id = auth["tenant_id"]

        resp = await client.get(
            f"/api/tenants/{tenant_id}/alert-rules",
            headers=auth["headers"],
        )
        assert resp.status_code == 200
        data = resp.json()
        assert isinstance(data, list)

    async def test_create_alert_rule(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """POST /api/tenants/{tenant_id}/alert-rules creates a rule."""
        # operator role is required for write access to alert rules
        auth = await auth_headers_factory(admin_session, role="operator")
        tenant_id = auth["tenant_id"]

        # Unique name avoids collisions with rules created by other tests.
        rule_data = {**VALID_ALERT_RULE, "name": f"CPU Alert {uuid.uuid4().hex[:6]}"}

        resp = await client.post(
            f"/api/tenants/{tenant_id}/alert-rules",
            json=rule_data,
            headers=auth["headers"],
        )
        assert resp.status_code == 201
        data = resp.json()
        assert data["name"] == rule_data["name"]
        assert data["metric"] == "cpu_load"
        assert data["operator"] == "gt"
        assert data["threshold"] == 90.0
        assert data["severity"] == "warning"
        assert "id" in data

    async def test_update_alert_rule(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """PUT /api/tenants/{tenant_id}/alert-rules/{rule_id} updates a rule."""
        auth = await auth_headers_factory(admin_session, role="operator")
        tenant_id = auth["tenant_id"]

        # Create a rule first
        rule_data = {**VALID_ALERT_RULE, "name": f"Update Test {uuid.uuid4().hex[:6]}"}
        create_resp = await client.post(
            f"/api/tenants/{tenant_id}/alert-rules",
            json=rule_data,
            headers=auth["headers"],
        )
        assert create_resp.status_code == 201
        rule_id = create_resp.json()["id"]

        # Update it
        updated_data = {**rule_data, "threshold": 95.0, "severity": "critical"}
        update_resp = await client.put(
            f"/api/tenants/{tenant_id}/alert-rules/{rule_id}",
            json=updated_data,
            headers=auth["headers"],
        )
        assert update_resp.status_code == 200
        data = update_resp.json()
        assert data["threshold"] == 95.0
        assert data["severity"] == "critical"

    async def test_delete_alert_rule(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """DELETE /api/tenants/{tenant_id}/alert-rules/{rule_id} deletes a rule."""
        auth = await auth_headers_factory(admin_session, role="operator")
        tenant_id = auth["tenant_id"]

        # Create a non-default rule
        rule_data = {**VALID_ALERT_RULE, "name": f"Delete Test {uuid.uuid4().hex[:6]}"}
        create_resp = await client.post(
            f"/api/tenants/{tenant_id}/alert-rules",
            json=rule_data,
            headers=auth["headers"],
        )
        assert create_resp.status_code == 201
        rule_id = create_resp.json()["id"]

        # Delete it
        del_resp = await client.delete(
            f"/api/tenants/{tenant_id}/alert-rules/{rule_id}",
            headers=auth["headers"],
        )
        assert del_resp.status_code == 204

    async def test_toggle_alert_rule(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """PATCH toggle flips the enabled state of a rule."""
        auth = await auth_headers_factory(admin_session, role="operator")
        tenant_id = auth["tenant_id"]

        # Create a rule (enabled=True)
        rule_data = {**VALID_ALERT_RULE, "name": f"Toggle Test {uuid.uuid4().hex[:6]}"}
        create_resp = await client.post(
            f"/api/tenants/{tenant_id}/alert-rules",
            json=rule_data,
            headers=auth["headers"],
        )
        assert create_resp.status_code == 201
        rule_id = create_resp.json()["id"]

        # Toggle it
        toggle_resp = await client.patch(
            f"/api/tenants/{tenant_id}/alert-rules/{rule_id}/toggle",
            headers=auth["headers"],
        )
        assert toggle_resp.status_code == 200
        data = toggle_resp.json()
        assert data["enabled"] is False  # Was True, toggled to False

    async def test_create_alert_rule_invalid_metric(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """POST with invalid metric returns 422."""
        auth = await auth_headers_factory(admin_session, role="operator")
        tenant_id = auth["tenant_id"]

        rule_data = {**VALID_ALERT_RULE, "metric": "invalid_metric"}
        resp = await client.post(
            f"/api/tenants/{tenant_id}/alert-rules",
            json=rule_data,
            headers=auth["headers"],
        )
        assert resp.status_code == 422

    async def test_create_alert_rule_viewer_forbidden(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """POST as viewer returns 403."""
        # viewer role has read-only access; writes must be rejected
        auth = await auth_headers_factory(admin_session, role="viewer")
        tenant_id = auth["tenant_id"]

        resp = await client.post(
            f"/api/tenants/{tenant_id}/alert-rules",
            json=VALID_ALERT_RULE,
            headers=auth["headers"],
        )
        assert resp.status_code == 403
|
||||
|
||||
|
||||
class TestAlertEvents:
    """Alert events listing endpoints (read-only; no events are seeded,
    so responses are checked for shape rather than content)."""

    async def test_list_alerts_empty(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """GET /api/tenants/{tenant_id}/alerts returns 200 with paginated empty response."""
        auth = await auth_headers_factory(admin_session)
        tenant_id = auth["tenant_id"]

        resp = await client.get(
            f"/api/tenants/{tenant_id}/alerts",
            headers=auth["headers"],
        )
        assert resp.status_code == 200
        data = resp.json()
        # Paginated envelope: items list + total count
        assert "items" in data
        assert "total" in data
        assert data["total"] >= 0
        assert isinstance(data["items"], list)

    async def test_active_alert_count(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """GET active-count returns count of firing alerts."""
        auth = await auth_headers_factory(admin_session)
        tenant_id = auth["tenant_id"]

        resp = await client.get(
            f"/api/tenants/{tenant_id}/alerts/active-count",
            headers=auth["headers"],
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "count" in data
        assert isinstance(data["count"], int)
        assert data["count"] >= 0

    async def test_device_alerts_empty(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """GET /api/tenants/{tenant_id}/devices/{device_id}/alerts returns paginated response."""
        # Create the tenant explicitly so the device can be attached to it,
        # then authenticate a user against that same tenant.
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(
            admin_session, existing_tenant_id=tenant.id
        )
        tenant_id = auth["tenant_id"]
        device = await create_test_device(admin_session, tenant.id)
        # Commit so the API (separate session) can see the device.
        await admin_session.commit()

        resp = await client.get(
            f"/api/tenants/{tenant_id}/devices/{device.id}/alerts",
            headers=auth["headers"],
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "items" in data
        assert "total" in data
|
||||
302
backend/tests/integration/test_auth_api.py
Normal file
302
backend/tests/integration/test_auth_api.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""
|
||||
Auth API endpoint integration tests (TEST-04 partial).
|
||||
|
||||
Tests auth endpoints end-to-end against real PostgreSQL:
|
||||
- POST /api/auth/login (success, wrong password, nonexistent user)
|
||||
- POST /api/auth/refresh (token refresh flow)
|
||||
- GET /api/auth/me (current user info)
|
||||
- Protected endpoint access without/with invalid token
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
|
||||
from app.models.tenant import Tenant
|
||||
from app.models.user import User
|
||||
from app.services.auth import hash_password
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
from tests.integration.conftest import TEST_DATABASE_URL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _admin_commit(url, callback):
    """Open a fresh admin connection, run ``callback(session)``, commit, close.

    A new engine is created per call so the connection always lives on the
    current event loop. The session is now explicitly closed and the engine
    disposed in ``finally`` blocks, so a failing callback no longer leaks
    the connection/engine (the original only disposed on the success path).

    Args:
        url: async database URL for the admin (superuser) role.
        callback: async function taking an AsyncSession; its return value
            is propagated to the caller.

    Returns:
        Whatever ``callback`` returned.
    """
    engine = create_async_engine(url, echo=False)
    try:
        async with engine.connect() as conn:
            session = AsyncSession(bind=conn, expire_on_commit=False)
            try:
                result = await callback(session)
                await session.commit()
            finally:
                await session.close()
    finally:
        await engine.dispose()
    return result
|
||||
|
||||
|
||||
async def _admin_cleanup(url, *table_names):
    """Delete all rows from the specified tables via a fresh admin engine.

    The engine is now disposed in a ``finally`` block so a failing DELETE
    (e.g. due to FK ordering) no longer leaks the engine/connection.

    Args:
        url: async database URL for the admin (superuser) role.
        *table_names: table names, in FK-safe deletion order.

    NOTE: table names are interpolated into SQL; acceptable here only
    because callers pass hard-coded literals from this test module.
    """
    from sqlalchemy import text

    engine = create_async_engine(url, echo=False)
    try:
        async with engine.connect() as conn:
            for table in table_names:
                await conn.execute(text(f"DELETE FROM {table}"))
            await conn.commit()
    finally:
        await engine.dispose()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: Login success
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_login_success(client, admin_engine):
    """POST /api/auth/login with correct credentials returns 200 and tokens.

    Seeds a tenant + active user directly via the admin engine (separate
    from the app's session) and cleans the tables up afterwards.
    """
    uid = uuid.uuid4().hex[:6]

    async def setup(session):
        tenant = Tenant(name=f"auth-login-{uid}")
        session.add(tenant)
        await session.flush()

        user = User(
            email=f"auth-login-{uid}@example.com",
            hashed_password=hash_password("SecurePass123!"),
            name="Auth Test User",
            role="tenant_admin",
            tenant_id=tenant.id,
            is_active=True,
        )
        session.add(user)
        await session.flush()
        return {"email": user.email, "tenant_id": str(tenant.id)}

    data = await _admin_commit(TEST_DATABASE_URL, setup)

    try:
        resp = await client.post(
            "/api/auth/login",
            json={"email": data["email"], "password": "SecurePass123!"},
        )
        assert resp.status_code == 200, f"Login failed: {resp.text}"

        body = resp.json()
        assert "access_token" in body
        assert "refresh_token" in body
        assert body["token_type"] == "bearer"
        assert len(body["access_token"]) > 0
        assert len(body["refresh_token"]) > 0

        # Verify httpOnly cookie is set
        cookies = resp.cookies
        # Cookie may or may not appear in httpx depending on secure flag
        # Just verify the response contains Set-Cookie header
        set_cookie = resp.headers.get("set-cookie", "")
        assert "access_token" in set_cookie or len(body["access_token"]) > 0
    finally:
        # Delete users before tenants to respect the FK dependency.
        await _admin_cleanup(TEST_DATABASE_URL, "users", "tenants")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: Login with wrong password
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_login_wrong_password(client, admin_engine):
    """POST /api/auth/login with wrong password returns 401.

    Seeds a real user with a known password, then attempts login with a
    different password and expects a generic "Invalid credentials" error.
    """
    uid = uuid.uuid4().hex[:6]

    async def setup(session):
        tenant = Tenant(name=f"auth-wrongpw-{uid}")
        session.add(tenant)
        await session.flush()

        user = User(
            email=f"auth-wrongpw-{uid}@example.com",
            hashed_password=hash_password("CorrectPass123!"),
            name="Wrong PW User",
            role="tenant_admin",
            tenant_id=tenant.id,
            is_active=True,
        )
        session.add(user)
        await session.flush()
        return {"email": user.email}

    data = await _admin_commit(TEST_DATABASE_URL, setup)

    try:
        resp = await client.post(
            "/api/auth/login",
            json={"email": data["email"], "password": "WrongPassword!"},
        )
        assert resp.status_code == 401
        assert "Invalid credentials" in resp.json()["detail"]
    finally:
        # Delete users before tenants to respect the FK dependency.
        await _admin_cleanup(TEST_DATABASE_URL, "users", "tenants")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: Login with nonexistent user
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_login_nonexistent_user(client):
    """POST /api/auth/login with email that doesn't exist returns 401."""
    # Randomized address guarantees the account cannot exist.
    missing_email = f"doesnotexist-{uuid.uuid4().hex[:6]}@example.com"
    resp = await client.post(
        "/api/auth/login",
        json={"email": missing_email, "password": "Anything!"},
    )
    payload = resp.json()
    assert resp.status_code == 401
    assert "Invalid credentials" in payload["detail"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 4: Token refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_token_refresh(client, admin_engine):
    """POST /api/auth/refresh with valid refresh token returns new tokens.

    Flow:
        1. Seed a tenant and an active user directly in the database.
        2. Log in through the API to obtain an access/refresh token pair.
        3. Exchange the refresh token for a new pair.
        4. Confirm the new access token authenticates against /api/auth/me.
    """
    uid = uuid.uuid4().hex[:6]

    async def setup(session):
        tenant = Tenant(name=f"auth-refresh-{uid}")
        session.add(tenant)
        await session.flush()

        user = User(
            email=f"auth-refresh-{uid}@example.com",
            hashed_password=hash_password("RefreshPass123!"),
            name="Refresh User",
            role="tenant_admin",
            tenant_id=tenant.id,
            is_active=True,
        )
        session.add(user)
        await session.flush()
        return {"email": user.email}

    data = await _admin_commit(TEST_DATABASE_URL, setup)

    try:
        # Login first to get a refresh token.
        login_resp = await client.post(
            "/api/auth/login",
            json={"email": data["email"], "password": "RefreshPass123!"},
        )
        assert login_resp.status_code == 200
        refresh_token = login_resp.json()["refresh_token"]

        # Use the refresh token to obtain a new token pair.
        refresh_resp = await client.post(
            "/api/auth/refresh",
            json={"refresh_token": refresh_token},
        )
        assert refresh_resp.status_code == 200

        new_tokens = refresh_resp.json()
        assert "access_token" in new_tokens
        assert "refresh_token" in new_tokens
        assert new_tokens["token_type"] == "bearer"
        # Tokens minted within the same second may be byte-identical to the
        # originals, so only assert that non-empty tokens were returned.
        assert len(new_tokens["access_token"]) > 0
        assert len(new_tokens["refresh_token"]) > 0

        # The refreshed access token must work for /me.
        me_resp = await client.get(
            "/api/auth/me",
            headers={"Authorization": f"Bearer {new_tokens['access_token']}"},
        )
        assert me_resp.status_code == 200
        assert me_resp.json()["email"] == data["email"]
    finally:
        await _admin_cleanup(TEST_DATABASE_URL, "users", "tenants")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 5: Get current user
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_get_current_user(client, admin_engine):
    """GET /api/auth/me with valid token returns current user info."""
    suffix = uuid.uuid4().hex[:6]

    async def seed(session):
        # Seed a tenant plus an active tenant_admin user to log in as.
        tenant = Tenant(name=f"auth-me-{suffix}")
        session.add(tenant)
        await session.flush()

        user = User(
            email=f"auth-me-{suffix}@example.com",
            hashed_password=hash_password("MePass123!"),
            name="Me User",
            role="tenant_admin",
            tenant_id=tenant.id,
            is_active=True,
        )
        session.add(user)
        await session.flush()
        return {
            "email": user.email,
            "tenant_id": str(tenant.id),
            "user_id": str(user.id),
        }

    seeded = await _admin_commit(TEST_DATABASE_URL, seed)

    try:
        # Authenticate to obtain an access token.
        login_resp = await client.post(
            "/api/auth/login",
            json={"email": seeded["email"], "password": "MePass123!"},
        )
        assert login_resp.status_code == 200
        access_token = login_resp.json()["access_token"]

        # Fetch the current-user profile with that token.
        me_resp = await client.get(
            "/api/auth/me",
            headers={"Authorization": f"Bearer {access_token}"},
        )
        assert me_resp.status_code == 200

        profile = me_resp.json()
        assert profile["email"] == seeded["email"]
        assert profile["name"] == "Me User"
        assert profile["role"] == "tenant_admin"
        assert profile["tenant_id"] == seeded["tenant_id"]
        assert profile["id"] == seeded["user_id"]
    finally:
        await _admin_cleanup(TEST_DATABASE_URL, "users", "tenants")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 6: Protected endpoint without token
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_protected_endpoint_without_token(client):
    """GET /api/tenants/{id}/devices without auth headers returns 401."""
    # A random tenant id is fine: auth is rejected before tenant lookup.
    response = await client.get(f"/api/tenants/{uuid.uuid4()}/devices")
    assert response.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 7: Protected endpoint with invalid token
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_protected_endpoint_with_invalid_token(client):
    """GET /api/tenants/{id}/devices with invalid Bearer token returns 401."""
    bogus_auth = {"Authorization": "Bearer totally-invalid-jwt-token"}
    response = await client.get(
        f"/api/tenants/{uuid.uuid4()}/devices",
        headers=bogus_auth,
    )
    assert response.status_code == 401
|
||||
149
backend/tests/integration/test_config_api.py
Normal file
149
backend/tests/integration/test_config_api.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Integration tests for the Config Backup API endpoints.
|
||||
|
||||
Tests exercise:
|
||||
- GET /api/tenants/{tenant_id}/devices/{device_id}/config/backups
|
||||
- GET /api/tenants/{tenant_id}/devices/{device_id}/config/schedules
|
||||
- PUT /api/tenants/{tenant_id}/devices/{device_id}/config/schedules
|
||||
|
||||
POST /backups (trigger) and POST /restore require actual RouterOS connections
|
||||
and git store, so we only test that the endpoints exist and respond appropriately.
|
||||
|
||||
All tests run against real PostgreSQL.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
class TestConfigBackups:
    """Config backup listing and schedule endpoints."""

    async def test_list_config_backups_empty(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """GET config backups for a device with no backups returns 200 + empty list."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
        device = await create_test_device(admin_session, tenant.id)
        await admin_session.commit()

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/devices/{device.id}/config/backups",
            headers=auth["headers"],
        )
        assert response.status_code == 200
        payload = response.json()
        assert isinstance(payload, list)
        assert len(payload) == 0

    async def test_get_backup_schedule_default(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """GET schedule returns synthetic default when no schedule configured."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
        device = await create_test_device(admin_session, tenant.id)
        await admin_session.commit()

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/devices/{device.id}/config/schedules",
            headers=auth["headers"],
        )
        assert response.status_code == 200
        schedule = response.json()
        # No device-specific schedule exists yet, so the API synthesizes one.
        assert schedule["is_default"] is True
        assert schedule["cron_expression"] == "0 2 * * *"
        assert schedule["enabled"] is True

    async def test_update_backup_schedule(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """PUT schedule creates/updates device-specific backup schedule."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(
            admin_session, existing_tenant_id=tenant.id, role="operator"
        )
        device = await create_test_device(admin_session, tenant.id)
        await admin_session.commit()

        response = await client.put(
            f"/api/tenants/{auth['tenant_id']}/devices/{device.id}/config/schedules",
            json={"cron_expression": "0 3 * * 1", "enabled": True},  # Monday at 3am
            headers=auth["headers"],
        )
        assert response.status_code == 200
        schedule = response.json()
        assert schedule["cron_expression"] == "0 3 * * 1"
        assert schedule["enabled"] is True
        assert schedule["is_default"] is False
        assert schedule["device_id"] == str(device.id)

    async def test_backup_endpoints_respond(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """Config backup router responds (not 404) for expected paths."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
        device = await create_test_device(admin_session, tenant.id)
        await admin_session.commit()

        base = f"/api/tenants/{auth['tenant_id']}/devices/{device.id}/config"
        # Both the list-backups and get-schedule routes must be registered.
        for path in (f"{base}/backups", f"{base}/schedules"):
            response = await client.get(path, headers=auth["headers"])
            assert response.status_code != 404

    async def test_config_backups_unauthenticated(self, client):
        """GET config backups without auth returns 401."""
        url = f"/api/tenants/{uuid.uuid4()}/devices/{uuid.uuid4()}/config/backups"
        response = await client.get(url)
        assert response.status_code == 401
|
||||
227
backend/tests/integration/test_devices_api.py
Normal file
227
backend/tests/integration/test_devices_api.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
Integration tests for the Device CRUD API endpoints.
|
||||
|
||||
Tests exercise /api/tenants/{tenant_id}/devices/* endpoints against
|
||||
real PostgreSQL+TimescaleDB with full auth + RLS enforcement.
|
||||
|
||||
All tests are independent and create their own test data.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.fixture
def _unique_suffix():
    """Provide a short (8-char) random hex suffix for unique test data."""
    token = uuid.uuid4().hex
    return token[:8]
|
||||
|
||||
|
||||
class TestDevicesCRUD:
    """Device list, create, get, update, delete endpoints."""

    async def test_list_devices_empty(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """GET /api/tenants/{tenant_id}/devices returns 200 with empty list."""
        auth = await auth_headers_factory(admin_session)

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/devices",
            headers=auth["headers"],
        )
        assert response.status_code == 200
        payload = response.json()
        assert payload["items"] == []
        assert payload["total"] == 0

    async def test_create_device(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """POST /api/tenants/{tenant_id}/devices creates a device and returns 201."""
        auth = await auth_headers_factory(admin_session, role="operator")

        new_device = {
            "hostname": f"test-router-{uuid.uuid4().hex[:8]}",
            "ip_address": "192.168.88.1",
            "api_port": 8728,
            "api_ssl_port": 8729,
            "username": "admin",
            "password": "admin123",
        }

        response = await client.post(
            f"/api/tenants/{auth['tenant_id']}/devices",
            json=new_device,
            headers=auth["headers"],
        )
        # create_device does TCP probe -- may fail in test env without real device.
        # Accept either 201 (success) or 502/422 (connectivity check failure).
        if response.status_code == 201:
            payload = response.json()
            assert payload["hostname"] == new_device["hostname"]
            assert payload["ip_address"] == new_device["ip_address"]
            assert "id" in payload
            # Credentials must never leak into the response body.
            assert "password" not in payload
            assert "username" not in payload
            assert "encrypted_credentials" not in payload

    async def test_get_device(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """GET /api/tenants/{tenant_id}/devices/{device_id} returns correct device."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
        device = await create_test_device(admin_session, tenant.id)
        await admin_session.commit()

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/devices/{device.id}",
            headers=auth["headers"],
        )
        assert response.status_code == 200
        payload = response.json()
        assert payload["id"] == str(device.id)
        assert payload["hostname"] == device.hostname
        assert payload["ip_address"] == device.ip_address

    async def test_update_device(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """PUT /api/tenants/{tenant_id}/devices/{device_id} updates device fields."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(
            admin_session, existing_tenant_id=tenant.id, role="operator"
        )
        device = await create_test_device(
            admin_session, tenant.id, hostname="old-hostname"
        )
        await admin_session.commit()

        new_hostname = f"new-hostname-{uuid.uuid4().hex[:8]}"
        response = await client.put(
            f"/api/tenants/{auth['tenant_id']}/devices/{device.id}",
            json={"hostname": new_hostname},
            headers=auth["headers"],
        )
        assert response.status_code == 200
        assert response.json()["hostname"] == new_hostname

    async def test_delete_device(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """DELETE /api/tenants/{tenant_id}/devices/{device_id} removes the device."""
        tenant = await create_test_tenant(admin_session)
        # Deletion requires tenant_admin or above.
        auth = await auth_headers_factory(
            admin_session, existing_tenant_id=tenant.id, role="tenant_admin"
        )
        device = await create_test_device(admin_session, tenant.id)
        await admin_session.commit()

        device_url = f"/api/tenants/{auth['tenant_id']}/devices/{device.id}"
        response = await client.delete(device_url, headers=auth["headers"])
        assert response.status_code == 204

        # A follow-up GET must now 404.
        followup = await client.get(device_url, headers=auth["headers"])
        assert followup.status_code == 404

    async def test_list_devices_with_status_filter(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """GET /api/tenants/{tenant_id}/devices?status=online returns filtered results."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)

        # Seed one device per status so the filter has something to exclude.
        await create_test_device(
            admin_session, tenant.id, hostname="online-device", status="online"
        )
        await create_test_device(
            admin_session, tenant.id, hostname="offline-device", status="offline"
        )
        await admin_session.commit()

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/devices?status=online",
            headers=auth["headers"],
        )
        assert response.status_code == 200
        payload = response.json()
        assert payload["total"] >= 1
        assert all(entry["status"] == "online" for entry in payload["items"])

    async def test_get_device_not_found(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """GET /api/tenants/{tenant_id}/devices/{nonexistent} returns 404."""
        auth = await auth_headers_factory(admin_session)

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/devices/{uuid.uuid4()}",
            headers=auth["headers"],
        )
        assert response.status_code == 404

    async def test_list_devices_unauthenticated(self, client):
        """GET /api/tenants/{tenant_id}/devices without auth returns 401."""
        response = await client.get(f"/api/tenants/{uuid.uuid4()}/devices")
        assert response.status_code == 401
|
||||
183
backend/tests/integration/test_firmware_api.py
Normal file
183
backend/tests/integration/test_firmware_api.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Integration tests for the Firmware API endpoints.
|
||||
|
||||
Tests exercise:
|
||||
- GET /api/firmware/versions -- list firmware versions (global)
|
||||
- GET /api/tenants/{tenant_id}/firmware/overview -- firmware overview per tenant
|
||||
- GET /api/tenants/{tenant_id}/firmware/upgrades -- list upgrade jobs
|
||||
- PATCH /api/tenants/{tenant_id}/devices/{device_id}/preferred-channel
|
||||
|
||||
Upgrade endpoints (POST .../upgrade, .../mass-upgrade) require actual RouterOS
|
||||
connections and NATS, so we verify the endpoint exists and handles missing
|
||||
services gracefully. Download/cache endpoints require super_admin.
|
||||
|
||||
All tests run against real PostgreSQL.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
class TestFirmwareVersions:
    """Firmware version listing endpoints."""

    async def test_list_firmware_versions(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """GET /api/firmware/versions returns 200 with list (may be empty)."""
        auth = await auth_headers_factory(admin_session)

        response = await client.get(
            "/api/firmware/versions",
            headers=auth["headers"],
        )
        assert response.status_code == 200
        assert isinstance(response.json(), list)

    async def test_list_firmware_versions_with_filters(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """GET /api/firmware/versions with filters returns 200."""
        auth = await auth_headers_factory(admin_session)

        response = await client.get(
            "/api/firmware/versions",
            params={"architecture": "arm", "channel": "stable"},
            headers=auth["headers"],
        )
        assert response.status_code == 200
        assert isinstance(response.json(), list)
|
||||
|
||||
|
||||
class TestFirmwareOverview:
    """Tenant-scoped firmware overview."""

    async def test_firmware_overview(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """GET /api/tenants/{tenant_id}/firmware/overview returns 200."""
        auth = await auth_headers_factory(admin_session)

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/firmware/overview",
            headers=auth["headers"],
        )
        # May return 200 or 500 if firmware_service depends on external state;
        # at minimum the route must be registered (not 404).
        assert response.status_code != 404
|
||||
|
||||
|
||||
class TestPreferredChannel:
    """Device preferred firmware channel endpoint."""

    async def test_set_device_preferred_channel(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """PATCH preferred channel updates the device firmware channel preference."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(
            admin_session, existing_tenant_id=tenant.id, role="operator"
        )
        device = await create_test_device(admin_session, tenant.id)
        await admin_session.commit()

        response = await client.patch(
            f"/api/tenants/{auth['tenant_id']}/devices/{device.id}/preferred-channel",
            json={"preferred_channel": "long-term"},
            headers=auth["headers"],
        )
        assert response.status_code == 200
        payload = response.json()
        assert payload["preferred_channel"] == "long-term"
        assert payload["status"] == "ok"

    async def test_set_invalid_preferred_channel(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """PATCH with invalid channel returns 422."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(
            admin_session, existing_tenant_id=tenant.id, role="operator"
        )
        device = await create_test_device(admin_session, tenant.id)
        await admin_session.commit()

        response = await client.patch(
            f"/api/tenants/{auth['tenant_id']}/devices/{device.id}/preferred-channel",
            json={"preferred_channel": "invalid"},
            headers=auth["headers"],
        )
        assert response.status_code == 422
|
||||
|
||||
|
||||
class TestUpgradeJobs:
    """Upgrade job listing endpoints."""

    async def test_list_upgrade_jobs_empty(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """GET /api/tenants/{tenant_id}/firmware/upgrades returns paginated response."""
        auth = await auth_headers_factory(admin_session)

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/firmware/upgrades",
            headers=auth["headers"],
        )
        assert response.status_code == 200
        payload = response.json()
        # Paginated envelope: items list plus non-negative total.
        assert "items" in payload
        assert "total" in payload
        assert isinstance(payload["items"], list)
        assert payload["total"] >= 0

    async def test_get_upgrade_job_not_found(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """GET /api/tenants/{tenant_id}/firmware/upgrades/{fake_id} returns 404."""
        auth = await auth_headers_factory(admin_session)

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/firmware/upgrades/{uuid.uuid4()}",
            headers=auth["headers"],
        )
        assert response.status_code == 404

    async def test_firmware_unauthenticated(self, client):
        """GET firmware versions without auth returns 401."""
        response = await client.get("/api/firmware/versions")
        assert response.status_code == 401
|
||||
323
backend/tests/integration/test_monitoring_api.py
Normal file
323
backend/tests/integration/test_monitoring_api.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
Integration tests for the Monitoring / Metrics API endpoints.
|
||||
|
||||
Tests exercise:
|
||||
- /api/tenants/{tenant_id}/devices/{device_id}/metrics/health
|
||||
- /api/tenants/{tenant_id}/devices/{device_id}/metrics/interfaces
|
||||
- /api/tenants/{tenant_id}/devices/{device_id}/metrics/interfaces/list
|
||||
- /api/tenants/{tenant_id}/devices/{device_id}/metrics/wireless
|
||||
- /api/tenants/{tenant_id}/devices/{device_id}/metrics/wireless/latest
|
||||
- /api/tenants/{tenant_id}/devices/{device_id}/metrics/sparkline
|
||||
- /api/tenants/{tenant_id}/fleet/summary
|
||||
- /api/fleet/summary (super_admin only)
|
||||
|
||||
All tests run against real PostgreSQL+TimescaleDB.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
class TestHealthMetrics:
    """Device health metrics endpoints."""

    async def test_get_device_health_metrics_empty(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """GET health metrics for a device with no data returns 200 + empty list."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
        device = await create_test_device(admin_session, tenant.id)
        await admin_session.commit()

        # Query the most recent hour; no rows were inserted for this device.
        now = datetime.now(timezone.utc)
        window = {
            "start": (now - timedelta(hours=1)).isoformat(),
            "end": now.isoformat(),
        }

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/devices/{device.id}/metrics/health",
            params=window,
            headers=auth["headers"],
        )
        assert response.status_code == 200
        payload = response.json()
        assert isinstance(payload, list)
        assert len(payload) == 0

    async def test_get_device_health_metrics_with_data(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """GET health metrics returns bucketed data when rows exist."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
        device = await create_test_device(admin_session, tenant.id)
        await admin_session.flush()

        # Insert five metric rows at 5-minute intervals via the admin session.
        insert_stmt = text(
            "INSERT INTO health_metrics "
            "(device_id, time, cpu_load, free_memory, total_memory, "
            "free_disk, total_disk, temperature) "
            "VALUES (:device_id, :ts, :cpu, :free_mem, :total_mem, "
            ":free_disk, :total_disk, :temp)"
        )
        now = datetime.now(timezone.utc)
        for step in range(5):
            await admin_session.execute(
                insert_stmt,
                {
                    "device_id": str(device.id),
                    "ts": now - timedelta(minutes=step * 5),
                    "cpu": 30 + step * 5,
                    "free_mem": 500000000,
                    "total_mem": 1000000000,
                    "free_disk": 200000000,
                    "total_disk": 500000000,
                    "temp": 45,
                },
            )
        await admin_session.commit()

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/devices/{device.id}/metrics/health",
            params={
                "start": (now - timedelta(hours=1)).isoformat(),
                "end": now.isoformat(),
            },
            headers=auth["headers"],
        )
        assert response.status_code == 200
        payload = response.json()
        assert isinstance(payload, list)
        assert len(payload) > 0
        # Each time bucket carries the expected aggregate fields.
        for bucket in payload:
            assert "bucket" in bucket
            assert "avg_cpu" in bucket
|
||||
|
||||
|
||||
class TestInterfaceMetrics:
    """Interface traffic metrics endpoints."""

    async def test_get_interface_metrics_empty(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """GET interface metrics for device with no data returns 200 + empty list."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
        device = await create_test_device(admin_session, tenant.id)
        await admin_session.commit()

        now = datetime.now(timezone.utc)
        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/devices/{device.id}/metrics/interfaces",
            params={
                "start": (now - timedelta(hours=1)).isoformat(),
                "end": now.isoformat(),
            },
            headers=auth["headers"],
        )
        assert response.status_code == 200
        assert isinstance(response.json(), list)

    async def test_get_interface_list_empty(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """GET interface list for device with no data returns 200 + empty list."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
        device = await create_test_device(admin_session, tenant.id)
        await admin_session.commit()

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/devices/{device.id}/metrics/interfaces/list",
            headers=auth["headers"],
        )
        assert response.status_code == 200
        assert isinstance(response.json(), list)
|
||||
|
||||
|
||||
class TestSparkline:
    """Sparkline endpoint."""

    async def test_sparkline_empty(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """GET sparkline for device with no data returns 200 + empty list."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
        device = await create_test_device(admin_session, tenant.id)
        await admin_session.commit()

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/devices/{device.id}/metrics/sparkline",
            headers=auth["headers"],
        )
        assert response.status_code == 200
        assert isinstance(response.json(), list)
|
||||
|
||||
|
||||
class TestFleetSummary:
    """Fleet summary endpoints."""

    async def test_fleet_summary_empty(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_tenant,
    ):
        """GET /api/tenants/{tenant_id}/fleet/summary returns 200 with empty fleet."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/fleet/summary",
            headers=auth["headers"],
        )
        assert response.status_code == 200
        assert isinstance(response.json(), list)

    async def test_fleet_summary_with_devices(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """GET fleet summary returns device data when devices exist."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)

        await create_test_device(admin_session, tenant.id, hostname="fleet-dev-1")
        await create_test_device(admin_session, tenant.id, hostname="fleet-dev-2")
        await admin_session.commit()

        response = await client.get(
            f"/api/tenants/{auth['tenant_id']}/fleet/summary",
            headers=auth["headers"],
        )
        assert response.status_code == 200
        payload = response.json()
        assert isinstance(payload, list)
        assert len(payload) >= 2
        # Both seeded devices must appear in the summary.
        hostnames = {entry["hostname"] for entry in payload}
        assert "fleet-dev-1" in hostnames
        assert "fleet-dev-2" in hostnames

    async def test_fleet_summary_unauthenticated(self, client):
        """GET fleet summary without auth returns 401."""
        response = await client.get(f"/api/tenants/{uuid.uuid4()}/fleet/summary")
        assert response.status_code == 401
|
||||
|
||||
|
||||
class TestWirelessMetrics:
    """Wireless metrics endpoints."""

    async def test_wireless_metrics_empty(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """GET wireless metrics for device with no data returns 200 + empty list."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(
            admin_session, existing_tenant_id=tenant.id
        )
        tenant_id = auth["tenant_id"]
        device = await create_test_device(admin_session, tenant.id)
        # Commit so the API (separate DB session) can see the fixtures.
        await admin_session.commit()

        # Query an explicit one-hour window ending now (UTC, timezone-aware).
        now = datetime.now(timezone.utc)
        resp = await client.get(
            f"/api/tenants/{tenant_id}/devices/{device.id}/metrics/wireless",
            params={
                "start": (now - timedelta(hours=1)).isoformat(),
                "end": now.isoformat(),
            },
            headers=auth["headers"],
        )
        assert resp.status_code == 200
        assert isinstance(resp.json(), list)

    async def test_wireless_latest_empty(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """GET wireless latest for device with no data returns 200 + empty list."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(
            admin_session, existing_tenant_id=tenant.id
        )
        tenant_id = auth["tenant_id"]
        device = await create_test_device(admin_session, tenant.id)
        await admin_session.commit()

        resp = await client.get(
            f"/api/tenants/{tenant_id}/devices/{device.id}/metrics/wireless/latest",
            headers=auth["headers"],
        )
        assert resp.status_code == 200
        assert isinstance(resp.json(), list)
|
||||
437
backend/tests/integration/test_rls_isolation.py
Normal file
437
backend/tests/integration/test_rls_isolation.py
Normal file
@@ -0,0 +1,437 @@
|
||||
"""
|
||||
RLS (Row Level Security) tenant isolation integration tests.
|
||||
|
||||
Verifies that PostgreSQL RLS policies correctly isolate tenant data:
|
||||
- Tenant A cannot see Tenant B's devices, alerts, or device groups
|
||||
- Tenant A cannot insert data into Tenant B's namespace
|
||||
- super_admin context sees all tenants
|
||||
- API-level isolation matches DB-level isolation
|
||||
|
||||
These tests commit real data to PostgreSQL and use the app_user engine
|
||||
(which enforces RLS) to validate isolation. Each test creates unique
|
||||
entity names to avoid collisions and cleans up via admin engine.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
|
||||
from app.database import set_tenant_context
|
||||
from app.models.alert import AlertRule
|
||||
from app.models.device import Device, DeviceGroup
|
||||
from app.models.tenant import Tenant
|
||||
from app.models.user import User
|
||||
from app.services.auth import hash_password
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
# Use the same test DB URLs as conftest
|
||||
from tests.integration.conftest import TEST_APP_USER_DATABASE_URL, TEST_DATABASE_URL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers: create and commit entities, and cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _admin_commit(url, callback):
    """Open a fresh admin connection, run ``callback(session)``, commit, close.

    Args:
        url: Async SQLAlchemy database URL (admin/superuser credentials).
        callback: Async callable receiving an ``AsyncSession``; its return
            value is handed back to the caller after the commit.

    Returns:
        Whatever ``callback`` returned.
    """
    engine = create_async_engine(url, echo=False)
    try:
        async with engine.connect() as conn:
            session = AsyncSession(bind=conn, expire_on_commit=False)
            result = await callback(session)
            await session.commit()
        return result
    finally:
        # Dispose even when callback/commit raises, so a failing test
        # doesn't leak a connection pool for the rest of the suite.
        await engine.dispose()
|
||||
|
||||
|
||||
async def _app_query(url, tenant_id, model_class):
    """Open a fresh app_user connection, set tenant context, query model, close.

    Args:
        url: Async SQLAlchemy URL for the RLS-enforced ``app_user`` role.
        tenant_id: Tenant whose RLS context to activate before querying.
        model_class: ORM model to ``SELECT`` all rows of.

    Returns:
        All rows of ``model_class`` visible under RLS for ``tenant_id``.
    """
    engine = create_async_engine(url, echo=False)
    try:
        async with engine.connect() as conn:
            session = AsyncSession(bind=conn, expire_on_commit=False)
            await set_tenant_context(session, tenant_id)
            result = await session.execute(select(model_class))
            rows = result.scalars().all()
        return rows
    finally:
        # Dispose even if the query raises — avoids leaking engines
        # across tests when an RLS assertion path fails mid-query.
        await engine.dispose()
|
||||
|
||||
|
||||
async def _admin_cleanup(url, *table_names):
    """Delete all rows from the given tables via the admin engine.

    Pass child tables before parent tables so foreign keys are satisfied
    (e.g. ``"devices", "tenants"``). Table names come from test code only;
    they are interpolated into SQL and must never be user-supplied.
    """
    engine = create_async_engine(url, echo=False)
    try:
        async with engine.connect() as conn:
            for table in table_names:
                await conn.execute(text(f"DELETE FROM {table}"))
            await conn.commit()
    finally:
        # Dispose even if a DELETE fails (e.g. FK violation) so cleanup
        # errors don't additionally leak connection pools.
        await engine.dispose()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: Tenant A cannot see Tenant B devices
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_tenant_a_cannot_see_tenant_b_devices():
    """Tenant A app_user session only returns Tenant A devices."""
    # Short random suffix keeps entity names unique across runs/collisions.
    uid = uuid.uuid4().hex[:6]

    # Create tenants + one device each via the admin (RLS-bypassing) engine.
    async def setup(session):
        ta = Tenant(name=f"rls-dev-ta-{uid}")
        tb = Tenant(name=f"rls-dev-tb-{uid}")
        session.add_all([ta, tb])
        await session.flush()

        da = Device(
            tenant_id=ta.id, hostname=f"rls-ra-{uid}", ip_address="10.1.1.1",
            api_port=8728, api_ssl_port=8729, status="online",
        )
        db = Device(
            tenant_id=tb.id, hostname=f"rls-rb-{uid}", ip_address="10.1.1.2",
            api_port=8728, api_ssl_port=8729, status="online",
        )
        session.add_all([da, db])
        await session.flush()
        return {"ta_id": str(ta.id), "tb_id": str(tb.id)}

    ids = await _admin_commit(TEST_DATABASE_URL, setup)

    try:
        # Query as Tenant A: must see exactly its own device.
        devices_a = await _app_query(TEST_APP_USER_DATABASE_URL, ids["ta_id"], Device)
        assert len(devices_a) == 1
        assert devices_a[0].hostname == f"rls-ra-{uid}"

        # Query as Tenant B: symmetric check.
        devices_b = await _app_query(TEST_APP_USER_DATABASE_URL, ids["tb_id"], Device)
        assert len(devices_b) == 1
        assert devices_b[0].hostname == f"rls-rb-{uid}"
    finally:
        # Children before parents to satisfy FKs.
        await _admin_cleanup(TEST_DATABASE_URL, "devices", "tenants")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: Tenant A cannot see Tenant B alerts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_tenant_a_cannot_see_tenant_b_alerts():
    """Tenant A only sees its own alert rules."""
    # Short random suffix keeps entity names unique across runs.
    uid = uuid.uuid4().hex[:6]

    async def setup(session):
        ta = Tenant(name=f"rls-alrt-ta-{uid}")
        tb = Tenant(name=f"rls-alrt-tb-{uid}")
        session.add_all([ta, tb])
        await session.flush()

        ra = AlertRule(
            tenant_id=ta.id, name=f"CPU Alert A {uid}",
            metric="cpu_load", operator=">", threshold=90.0, severity="warning",
        )
        rb = AlertRule(
            tenant_id=tb.id, name=f"CPU Alert B {uid}",
            metric="cpu_load", operator=">", threshold=85.0, severity="critical",
        )
        session.add_all([ra, rb])
        await session.flush()
        return {"ta_id": str(ta.id), "tb_id": str(tb.id)}

    ids = await _admin_commit(TEST_DATABASE_URL, setup)

    try:
        # RLS-enforced query as Tenant A: only its own rule is visible.
        rules_a = await _app_query(TEST_APP_USER_DATABASE_URL, ids["ta_id"], AlertRule)
        assert len(rules_a) == 1
        assert rules_a[0].name == f"CPU Alert A {uid}"
    finally:
        await _admin_cleanup(TEST_DATABASE_URL, "alert_rules", "tenants")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: Tenant A cannot see Tenant B device groups
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_tenant_a_cannot_see_tenant_b_device_groups():
    """Tenant A only sees its own device groups."""
    # Short random suffix keeps entity names unique across runs.
    uid = uuid.uuid4().hex[:6]

    async def setup(session):
        ta = Tenant(name=f"rls-grp-ta-{uid}")
        tb = Tenant(name=f"rls-grp-tb-{uid}")
        session.add_all([ta, tb])
        await session.flush()

        ga = DeviceGroup(tenant_id=ta.id, name=f"Group A {uid}")
        gb = DeviceGroup(tenant_id=tb.id, name=f"Group B {uid}")
        session.add_all([ga, gb])
        await session.flush()
        return {"ta_id": str(ta.id), "tb_id": str(tb.id)}

    ids = await _admin_commit(TEST_DATABASE_URL, setup)

    try:
        # RLS-enforced query as Tenant A: only Group A is visible.
        groups_a = await _app_query(TEST_APP_USER_DATABASE_URL, ids["ta_id"], DeviceGroup)
        assert len(groups_a) == 1
        assert groups_a[0].name == f"Group A {uid}"
    finally:
        await _admin_cleanup(TEST_DATABASE_URL, "device_groups", "tenants")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 4: Tenant A cannot insert device into Tenant B
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_tenant_a_cannot_insert_device_into_tenant_b():
    """Inserting a device with tenant_b's ID while in tenant_a context should fail or be invisible.

    RLS implementations may either reject the INSERT outright (WITH CHECK
    violation) or accept it but hide the row from tenant_a's context; both
    outcomes are acceptable, so the test accepts either.
    """
    uid = uuid.uuid4().hex[:6]

    async def setup(session):
        ta = Tenant(name=f"rls-ins-ta-{uid}")
        tb = Tenant(name=f"rls-ins-tb-{uid}")
        session.add_all([ta, tb])
        await session.flush()
        return {"ta_id": str(ta.id), "tb_id": str(tb.id)}

    ids = await _admin_commit(TEST_DATABASE_URL, setup)

    try:
        engine = create_async_engine(TEST_APP_USER_DATABASE_URL, echo=False)
        try:
            async with engine.connect() as conn:
                session = AsyncSession(bind=conn, expire_on_commit=False)
                await set_tenant_context(session, ids["ta_id"])

                # Attempt to insert a device carrying tenant_b's tenant_id
                # while the session runs under tenant_a's RLS context.
                device = Device(
                    tenant_id=uuid.UUID(ids["tb_id"]),
                    hostname=f"evil-device-{uid}",
                    ip_address="10.99.99.99",
                    api_port=8728,
                    api_ssl_port=8729,
                    status="online",
                )
                session.add(device)

                # RLS policy should prevent this -- either by raising an error
                # on flush or by making the row invisible after insert.
                try:
                    await session.flush()
                    # Insert succeeded: verify the device is NOT visible.
                    result = await session.execute(select(Device))
                    visible = result.scalars().all()
                    cross_tenant = [
                        d for d in visible if d.hostname == f"evil-device-{uid}"
                    ]
                    assert len(cross_tenant) == 0, (
                        "Cross-tenant device should not be visible to tenant_a"
                    )
                except AssertionError:
                    # The visibility check failed: a genuine RLS leak.
                    # Must propagate — previously this was swallowed by the
                    # broad except below, making the test unable to fail.
                    raise
                except Exception:
                    # RLS violation raised on flush -- expected behavior.
                    pass
        finally:
            # Dispose even when an assertion fails, so the engine isn't leaked.
            await engine.dispose()
    finally:
        await _admin_cleanup(TEST_DATABASE_URL, "devices", "tenants")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 5: super_admin sees all tenants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_super_admin_sees_all_tenants():
    """super_admin bypasses RLS via admin engine (superuser) and sees all devices.

    The RLS policy does NOT have a special 'super_admin' tenant context.
    Instead, super_admin users use the admin engine (PostgreSQL superuser)
    which bypasses all RLS policies entirely.
    """
    uid = uuid.uuid4().hex[:6]

    async def setup(session):
        ta = Tenant(name=f"rls-sa-ta-{uid}")
        tb = Tenant(name=f"rls-sa-tb-{uid}")
        session.add_all([ta, tb])
        await session.flush()

        da = Device(
            tenant_id=ta.id, hostname=f"sa-ra-{uid}", ip_address="10.2.1.1",
            api_port=8728, api_ssl_port=8729, status="online",
        )
        db = Device(
            tenant_id=tb.id, hostname=f"sa-rb-{uid}", ip_address="10.2.1.2",
            api_port=8728, api_ssl_port=8729, status="online",
        )
        session.add_all([da, db])
        await session.flush()
        return {"ta_id": str(ta.id), "tb_id": str(tb.id)}

    ids = await _admin_commit(TEST_DATABASE_URL, setup)

    try:
        # super_admin uses the admin engine (superuser) which bypasses RLS.
        engine = create_async_engine(TEST_DATABASE_URL, echo=False)
        try:
            async with engine.connect() as conn:
                session = AsyncSession(bind=conn, expire_on_commit=False)
                result = await session.execute(select(Device))
                devices = result.scalars().all()
        finally:
            # Dispose even if the query raises — previously a failure here
            # leaked the engine for the rest of the suite.
            await engine.dispose()

        # Admin engine (superuser) should see devices from both tenants.
        hostnames = {d.hostname for d in devices}
        assert f"sa-ra-{uid}" in hostnames, "admin engine should see tenant_a device"
        assert f"sa-rb-{uid}" in hostnames, "admin engine should see tenant_b device"

        # Contrast: app_user engine scoped to tenant_a sees only tenant_a.
        devices_a = await _app_query(TEST_APP_USER_DATABASE_URL, ids["ta_id"], Device)
        hostnames_a = {d.hostname for d in devices_a}
        assert f"sa-ra-{uid}" in hostnames_a
        assert f"sa-rb-{uid}" not in hostnames_a
    finally:
        await _admin_cleanup(TEST_DATABASE_URL, "devices", "tenants")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 6: API-level RLS isolation (devices endpoint)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_api_rls_isolation_devices_endpoint(client, admin_engine):
    """Each user only sees their own tenant's devices via the API."""
    # Short random suffix keeps emails/hostnames unique across runs.
    uid = uuid.uuid4().hex[:6]

    # Create data via admin engine (committed so the API's own session sees it).
    async def setup(session):
        ta = Tenant(name=f"api-rls-ta-{uid}")
        tb = Tenant(name=f"api-rls-tb-{uid}")
        session.add_all([ta, tb])
        await session.flush()

        ua = User(
            email=f"api-ua-{uid}@example.com",
            hashed_password=hash_password("TestPass123!"),
            name="User A", role="tenant_admin",
            tenant_id=ta.id, is_active=True,
        )
        ub = User(
            email=f"api-ub-{uid}@example.com",
            hashed_password=hash_password("TestPass123!"),
            name="User B", role="tenant_admin",
            tenant_id=tb.id, is_active=True,
        )
        session.add_all([ua, ub])
        await session.flush()

        da = Device(
            tenant_id=ta.id, hostname=f"api-ra-{uid}", ip_address="10.3.1.1",
            api_port=8728, api_ssl_port=8729, status="online",
        )
        db = Device(
            tenant_id=tb.id, hostname=f"api-rb-{uid}", ip_address="10.3.1.2",
            api_port=8728, api_ssl_port=8729, status="online",
        )
        session.add_all([da, db])
        await session.flush()
        return {
            "ta_id": str(ta.id), "tb_id": str(tb.id),
            "ua_email": ua.email, "ub_email": ub.email,
        }

    ids = await _admin_commit(TEST_DATABASE_URL, setup)

    try:
        # Login as user A
        login_a = await client.post(
            "/api/auth/login",
            json={"email": ids["ua_email"], "password": "TestPass123!"},
        )
        assert login_a.status_code == 200, f"Login A failed: {login_a.text}"
        token_a = login_a.json()["access_token"]

        # Login as user B
        login_b = await client.post(
            "/api/auth/login",
            json={"email": ids["ub_email"], "password": "TestPass123!"},
        )
        assert login_b.status_code == 200, f"Login B failed: {login_b.text}"
        token_b = login_b.json()["access_token"]

        # User A lists devices for tenant A: sees own device, never B's.
        resp_a = await client.get(
            f"/api/tenants/{ids['ta_id']}/devices",
            headers={"Authorization": f"Bearer {token_a}"},
        )
        assert resp_a.status_code == 200
        hostnames_a = [d["hostname"] for d in resp_a.json()["items"]]
        assert f"api-ra-{uid}" in hostnames_a
        assert f"api-rb-{uid}" not in hostnames_a

        # User B lists devices for tenant B: symmetric check.
        resp_b = await client.get(
            f"/api/tenants/{ids['tb_id']}/devices",
            headers={"Authorization": f"Bearer {token_b}"},
        )
        assert resp_b.status_code == 200
        hostnames_b = [d["hostname"] for d in resp_b.json()["items"]]
        assert f"api-rb-{uid}" in hostnames_b
        assert f"api-ra-{uid}" not in hostnames_b
    finally:
        # Children before parents to satisfy FKs.
        await _admin_cleanup(TEST_DATABASE_URL, "devices", "users", "tenants")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 7: API-level cross-tenant device access
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def test_api_rls_isolation_cross_tenant_device_access(client, admin_engine):
    """Accessing another tenant's endpoint returns 403 (tenant access check)."""
    # Short random suffix keeps emails/hostnames unique across runs.
    uid = uuid.uuid4().hex[:6]

    async def setup(session):
        ta = Tenant(name=f"api-xt-ta-{uid}")
        tb = Tenant(name=f"api-xt-tb-{uid}")
        session.add_all([ta, tb])
        await session.flush()

        # Only tenant A gets a user; tenant B only gets a device.
        ua = User(
            email=f"api-xt-ua-{uid}@example.com",
            hashed_password=hash_password("TestPass123!"),
            name="User A", role="tenant_admin",
            tenant_id=ta.id, is_active=True,
        )
        session.add(ua)
        await session.flush()

        db = Device(
            tenant_id=tb.id, hostname=f"api-xt-rb-{uid}", ip_address="10.4.1.1",
            api_port=8728, api_ssl_port=8729, status="online",
        )
        session.add(db)
        await session.flush()
        return {
            "ta_id": str(ta.id), "tb_id": str(tb.id),
            "ua_email": ua.email, "db_id": str(db.id),
        }

    ids = await _admin_commit(TEST_DATABASE_URL, setup)

    try:
        # Login as user A
        login_a = await client.post(
            "/api/auth/login",
            json={"email": ids["ua_email"], "password": "TestPass123!"},
        )
        assert login_a.status_code == 200
        token_a = login_a.json()["access_token"]

        # User A tries to access tenant B's devices endpoint.
        resp = await client.get(
            f"/api/tenants/{ids['tb_id']}/devices",
            headers={"Authorization": f"Bearer {token_a}"},
        )
        # Should be 403 -- the tenant access check rejects the request
        # before any data is queried (not an empty 200).
        assert resp.status_code == 403
    finally:
        await _admin_cleanup(TEST_DATABASE_URL, "devices", "users", "tenants")
|
||||
322
backend/tests/integration/test_templates_api.py
Normal file
322
backend/tests/integration/test_templates_api.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
Integration tests for the Config Templates API endpoints.
|
||||
|
||||
Tests exercise:
|
||||
- GET /api/tenants/{tenant_id}/templates -- list templates
|
||||
- POST /api/tenants/{tenant_id}/templates -- create template
|
||||
- GET /api/tenants/{tenant_id}/templates/{id} -- get template
|
||||
- PUT /api/tenants/{tenant_id}/templates/{id} -- update template
|
||||
- DELETE /api/tenants/{tenant_id}/templates/{id} -- delete template
|
||||
- POST /api/tenants/{tenant_id}/templates/{id}/preview -- preview rendered template
|
||||
|
||||
Push endpoints (POST .../push) require actual RouterOS connections, so we
|
||||
only test the preview endpoint which only needs a database device record.
|
||||
|
||||
All tests run against real PostgreSQL.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
TEMPLATE_CONTENT = """/ip address add address={{ ip_address }}/24 interface=ether1
|
||||
/system identity set name={{ hostname }}
|
||||
"""
|
||||
|
||||
TEMPLATE_VARIABLES = [
|
||||
{"name": "ip_address", "type": "ip", "default": "192.168.1.1"},
|
||||
{"name": "hostname", "type": "string", "default": "router"},
|
||||
]
|
||||
|
||||
|
||||
class TestTemplatesCRUD:
    """Template list, create, get, update, delete endpoints."""

    async def test_list_templates_empty(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """GET /api/tenants/{tenant_id}/templates returns 200 with empty list."""
        auth = await auth_headers_factory(admin_session)
        tenant_id = auth["tenant_id"]

        resp = await client.get(
            f"/api/tenants/{tenant_id}/templates",
            headers=auth["headers"],
        )
        assert resp.status_code == 200
        data = resp.json()
        assert isinstance(data, list)

    async def test_create_template(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """POST /api/tenants/{tenant_id}/templates creates a template."""
        # "operator" role: template writes require more than read-only access.
        auth = await auth_headers_factory(admin_session, role="operator")
        tenant_id = auth["tenant_id"]

        template_data = {
            "name": f"Test Template {uuid.uuid4().hex[:6]}",
            "description": "A test config template",
            "content": TEMPLATE_CONTENT,
            "variables": TEMPLATE_VARIABLES,
            "tags": ["test", "integration"],
        }

        resp = await client.post(
            f"/api/tenants/{tenant_id}/templates",
            json=template_data,
            headers=auth["headers"],
        )
        assert resp.status_code == 201
        data = resp.json()
        assert data["name"] == template_data["name"]
        assert data["description"] == "A test config template"
        assert "id" in data
        assert "content" in data
        assert data["content"] == TEMPLATE_CONTENT
        # variable_count is derived server-side from the variables list.
        assert data["variable_count"] == 2
        # Compare as sets: tag order is not part of the contract.
        assert set(data["tags"]) == {"test", "integration"}

    async def test_get_template(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """GET /api/tenants/{tenant_id}/templates/{id} returns full template with content."""
        auth = await auth_headers_factory(admin_session, role="operator")
        tenant_id = auth["tenant_id"]

        # Create first
        create_data = {
            "name": f"Get Test {uuid.uuid4().hex[:6]}",
            "content": TEMPLATE_CONTENT,
            "variables": TEMPLATE_VARIABLES,
            "tags": [],
        }
        create_resp = await client.post(
            f"/api/tenants/{tenant_id}/templates",
            json=create_data,
            headers=auth["headers"],
        )
        assert create_resp.status_code == 201
        template_id = create_resp.json()["id"]

        # Get it
        resp = await client.get(
            f"/api/tenants/{tenant_id}/templates/{template_id}",
            headers=auth["headers"],
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["id"] == template_id
        assert data["content"] == TEMPLATE_CONTENT
        assert "variables" in data
        assert len(data["variables"]) == 2

    async def test_update_template(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """PUT /api/tenants/{tenant_id}/templates/{id} updates template content."""
        auth = await auth_headers_factory(admin_session, role="operator")
        tenant_id = auth["tenant_id"]

        # Create first
        create_data = {
            "name": f"Update Test {uuid.uuid4().hex[:6]}",
            "content": TEMPLATE_CONTENT,
            "variables": TEMPLATE_VARIABLES,
            "tags": ["original"],
        }
        create_resp = await client.post(
            f"/api/tenants/{tenant_id}/templates",
            json=create_data,
            headers=auth["headers"],
        )
        assert create_resp.status_code == 201
        template_id = create_resp.json()["id"]

        # Update it: new content, fewer variables, replaced tags.
        updated_content = "/system identity set name={{ hostname }}-updated\n"
        update_data = {
            "name": create_data["name"],
            "content": updated_content,
            "variables": [{"name": "hostname", "type": "string"}],
            "tags": ["updated"],
        }
        resp = await client.put(
            f"/api/tenants/{tenant_id}/templates/{template_id}",
            json=update_data,
            headers=auth["headers"],
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["content"] == updated_content
        assert data["variable_count"] == 1
        assert "updated" in data["tags"]

    async def test_delete_template(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """DELETE /api/tenants/{tenant_id}/templates/{id} removes the template."""
        auth = await auth_headers_factory(admin_session, role="operator")
        tenant_id = auth["tenant_id"]

        # Create first
        create_data = {
            "name": f"Delete Test {uuid.uuid4().hex[:6]}",
            "content": "/system identity set name=test\n",
            "variables": [],
            "tags": [],
        }
        create_resp = await client.post(
            f"/api/tenants/{tenant_id}/templates",
            json=create_data,
            headers=auth["headers"],
        )
        assert create_resp.status_code == 201
        template_id = create_resp.json()["id"]

        # Delete it
        resp = await client.delete(
            f"/api/tenants/{tenant_id}/templates/{template_id}",
            headers=auth["headers"],
        )
        assert resp.status_code == 204

        # Verify it's gone
        get_resp = await client.get(
            f"/api/tenants/{tenant_id}/templates/{template_id}",
            headers=auth["headers"],
        )
        assert get_resp.status_code == 404

    async def test_get_template_not_found(
        self,
        client,
        auth_headers_factory,
        admin_session,
    ):
        """GET non-existent template returns 404."""
        auth = await auth_headers_factory(admin_session)
        tenant_id = auth["tenant_id"]
        fake_id = str(uuid.uuid4())

        resp = await client.get(
            f"/api/tenants/{tenant_id}/templates/{fake_id}",
            headers=auth["headers"],
        )
        assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestTemplatePreview:
    """Template preview endpoint (renders a template against a device)."""

    async def test_template_preview(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """POST /api/tenants/{tenant_id}/templates/{id}/preview renders template for device."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(
            admin_session, existing_tenant_id=tenant.id, role="operator"
        )
        tenant_id = auth["tenant_id"]

        # Create a device: preview substitutes its attributes (e.g. hostname)
        # into the template context.
        device = await create_test_device(
            admin_session, tenant.id, hostname="preview-router", ip_address="10.0.1.1"
        )
        # Commit so the API (separate DB session) can see the device.
        await admin_session.commit()

        # Create template referencing the built-in {{ hostname }} variable.
        template_data = {
            "name": f"Preview Test {uuid.uuid4().hex[:6]}",
            "content": "/system identity set name={{ hostname }}\n",
            "variables": [],
            "tags": [],
        }
        create_resp = await client.post(
            f"/api/tenants/{tenant_id}/templates",
            json=template_data,
            headers=auth["headers"],
        )
        assert create_resp.status_code == 201
        template_id = create_resp.json()["id"]

        # Preview it
        preview_resp = await client.post(
            f"/api/tenants/{tenant_id}/templates/{template_id}/preview",
            json={"device_id": str(device.id), "variables": {}},
            headers=auth["headers"],
        )
        assert preview_resp.status_code == 200
        data = preview_resp.json()
        assert "rendered" in data
        # The device hostname must appear in the rendered output.
        assert "preview-router" in data["rendered"]
        assert data["device_hostname"] == "preview-router"

    async def test_template_preview_with_variables(
        self,
        client,
        auth_headers_factory,
        admin_session,
        create_test_device,
        create_test_tenant,
    ):
        """Preview with custom variables renders them into the template."""
        tenant = await create_test_tenant(admin_session)
        auth = await auth_headers_factory(
            admin_session, existing_tenant_id=tenant.id, role="operator"
        )
        tenant_id = auth["tenant_id"]

        device = await create_test_device(admin_session, tenant.id)
        await admin_session.commit()

        template_data = {
            "name": f"VarPreview {uuid.uuid4().hex[:6]}",
            "content": "/ip address add address={{ custom_ip }}/24 interface=ether1\n",
            "variables": [{"name": "custom_ip", "type": "ip", "default": "192.168.1.1"}],
            "tags": [],
        }
        create_resp = await client.post(
            f"/api/tenants/{tenant_id}/templates",
            json=template_data,
            headers=auth["headers"],
        )
        assert create_resp.status_code == 201
        template_id = create_resp.json()["id"]

        # Supplied variable value should override the declared default.
        preview_resp = await client.post(
            f"/api/tenants/{tenant_id}/templates/{template_id}/preview",
            json={"device_id": str(device.id), "variables": {"custom_ip": "10.10.10.1"}},
            headers=auth["headers"],
        )
        assert preview_resp.status_code == 200
        data = preview_resp.json()
        assert "10.10.10.1" in data["rendered"]

    async def test_templates_unauthenticated(self, client):
        """GET templates without auth returns 401."""
        tenant_id = str(uuid.uuid4())
        resp = await client.get(f"/api/tenants/{tenant_id}/templates")
        assert resp.status_code == 401
|
||||
42
backend/tests/test_backup_scheduler.py
Normal file
42
backend/tests/test_backup_scheduler.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Tests for dynamic backup scheduling."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from app.services.backup_scheduler import (
|
||||
build_schedule_map,
|
||||
_cron_to_trigger,
|
||||
)
|
||||
|
||||
|
||||
def test_cron_to_trigger_parses_standard_cron():
    """A standard daily expression ('0 2 * * *', 02:00) parses to a trigger."""
    assert _cron_to_trigger("0 2 * * *") is not None
|
||||
|
||||
|
||||
def test_cron_to_trigger_parses_every_6_hours():
    """A step expression ('0 */6 * * *') also parses into a trigger."""
    assert _cron_to_trigger("0 */6 * * *") is not None
|
||||
|
||||
|
||||
def test_cron_to_trigger_invalid_returns_none():
    """Garbage input yields None so callers can fall back to a default."""
    assert _cron_to_trigger("not a cron") is None
|
||||
|
||||
|
||||
def test_build_schedule_map_groups_by_cron():
    """Devices sharing a cron expression should be grouped together.

    build_schedule_map is called synchronously and its result subscripted
    immediately, so this is a plain sync test — the previous async def +
    @pytest.mark.asyncio wrapper added event-loop overhead for no await.
    """
    schedules = [
        MagicMock(device_id="dev1", tenant_id="t1", cron_expression="0 2 * * *", enabled=True),
        MagicMock(device_id="dev2", tenant_id="t1", cron_expression="0 2 * * *", enabled=True),
        MagicMock(device_id="dev3", tenant_id="t2", cron_expression="0 6 * * *", enabled=True),
    ]
    schedule_map = build_schedule_map(schedules)
    assert "0 2 * * *" in schedule_map
    assert "0 6 * * *" in schedule_map
    assert len(schedule_map["0 2 * * *"]) == 2
    assert len(schedule_map["0 6 * * *"]) == 1
|
||||
55
backend/tests/test_config_change_subscriber.py
Normal file
55
backend/tests/test_config_change_subscriber.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Tests for config change NATS subscriber."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from app.services.config_change_subscriber import handle_config_changed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_triggers_backup_on_config_change():
    """A config-change event outside the dedup window triggers a backup."""
    change_event = {
        "device_id": str(uuid4()),
        "tenant_id": str(uuid4()),
        "old_timestamp": "2026-03-07 11:00:00",
        "new_timestamp": "2026-03-07 12:00:00",
    }

    # Dedup check reports no recent backup, so the handler must proceed.
    backup_patch = patch(
        "app.services.config_change_subscriber.backup_service.run_backup",
        new_callable=AsyncMock,
    )
    dedup_patch = patch(
        "app.services.config_change_subscriber._last_backup_within_dedup_window",
        new_callable=AsyncMock,
        return_value=False,
    )
    with backup_patch as mock_backup, dedup_patch:
        await handle_config_changed(change_event)

    mock_backup.assert_called_once()
    assert mock_backup.call_args[1]["trigger_type"] == "config-change"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_skips_backup_within_dedup_window():
    """Should skip backup if last backup was < 5 minutes ago."""
    change_event = {
        "device_id": str(uuid4()),
        "tenant_id": str(uuid4()),
        "old_timestamp": "2026-03-07 11:00:00",
        "new_timestamp": "2026-03-07 12:00:00",
    }

    # Dedup check reports a recent backup, so the handler must bail out.
    backup_patch = patch(
        "app.services.config_change_subscriber.backup_service.run_backup",
        new_callable=AsyncMock,
    )
    dedup_patch = patch(
        "app.services.config_change_subscriber._last_backup_within_dedup_window",
        new_callable=AsyncMock,
        return_value=True,
    )
    with backup_patch as mock_backup, dedup_patch:
        await handle_config_changed(change_event)

    mock_backup.assert_not_called()
|
||||
82
backend/tests/test_config_checkpoint.py
Normal file
82
backend/tests/test_config_checkpoint.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Tests for config checkpoint endpoint."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestCheckpointEndpointExists:
    """Verify the checkpoint route is registered on the config_backups router."""

    def test_router_has_checkpoint_route(self):
        """At least one registered route path must contain 'checkpoint'."""
        from app.routers.config_backups import router

        paths = [route.path for route in router.routes]
        matches = [p for p in paths if "checkpoint" in p]
        assert matches, f"No checkpoint route found. Routes: {paths}"

    def test_checkpoint_route_is_post(self):
        """The first checkpoint route must accept the POST method."""
        from app.routers.config_backups import router

        checkpoint_routes = [
            r for r in router.routes
            if hasattr(r, "path") and "checkpoint" in r.path
        ]
        if not checkpoint_routes:
            pytest.fail("No checkpoint route found")
        route = checkpoint_routes[0]
        assert "POST" in route.methods, (
            f"Checkpoint route should be POST, got {route.methods}"
        )
|
||||
|
||||
|
||||
class TestCheckpointFunction:
    """Test the create_checkpoint handler logic."""

    @pytest.mark.asyncio
    async def test_checkpoint_calls_backup_service_with_checkpoint_trigger(self):
        """create_checkpoint should call backup_service.run_backup with trigger_type='checkpoint'."""
        from app.routers.config_backups import create_checkpoint

        # Canned backup result the patched run_backup will return.
        mock_result = {
            "commit_sha": "abc1234",
            "trigger_type": "checkpoint",
            "lines_added": 100,
            "lines_removed": 0,
        }

        mock_db = AsyncMock()
        mock_user = MagicMock()

        tenant_id = uuid.uuid4()
        device_id = uuid.uuid4()

        mock_request = MagicMock()

        # Patch out the backup call, the tenant-access guard, and the rate
        # limiter so the handler body runs in isolation.
        with patch(
            "app.routers.config_backups.backup_service.run_backup",
            new_callable=AsyncMock,
            return_value=mock_result,
        ) as mock_backup, patch(
            "app.routers.config_backups._check_tenant_access",
            new_callable=AsyncMock,
        ), patch(
            "app.routers.config_backups.limiter.enabled",
            False,
        ):
            result = await create_checkpoint(
                request=mock_request,
                tenant_id=tenant_id,
                device_id=device_id,
                db=mock_db,
                current_user=mock_user,
            )

        # The handler must pass the result through and forward string-form
        # IDs plus the 'checkpoint' trigger to the backup service.
        assert result["trigger_type"] == "checkpoint"
        assert result["commit_sha"] == "abc1234"
        mock_backup.assert_called_once_with(
            device_id=str(device_id),
            tenant_id=str(tenant_id),
            trigger_type="checkpoint",
            db_session=mock_db,
        )
|
||||
120
backend/tests/test_push_recovery.py
Normal file
120
backend/tests/test_push_recovery.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Tests for stale push operation recovery on API startup."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from app.services.restore_service import recover_stale_push_operations
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_recovery_commits_reachable_device_with_scheduler():
    """If device is reachable and panic-revert scheduler exists, delete it and commit.

    Fix: mock_session.execute was assigned twice; the first assignment
    (second side_effect a bare MagicMock) was dead code and is removed.
    """
    push_op = MagicMock()
    push_op.id = uuid4()
    push_op.device_id = uuid4()
    push_op.tenant_id = uuid4()
    push_op.status = "pending_verification"
    push_op.scheduler_name = "mikrotik-portal-panic-revert"
    push_op.started_at = datetime.now(timezone.utc) - timedelta(minutes=10)

    device = MagicMock()
    device.ip_address = "192.168.1.1"
    device.api_port = 8729
    device.ssh_port = 22

    mock_session = AsyncMock()
    # First execute() call returns the stale-ops query, second the device query.
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [push_op]
    dev_result = MagicMock()
    dev_result.scalar_one_or_none.return_value = device
    mock_session.execute = AsyncMock(side_effect=[mock_result, dev_result])

    with patch(
        "app.services.restore_service._check_reachability",
        new_callable=AsyncMock,
        return_value=True,
    ), patch(
        "app.services.restore_service._remove_panic_scheduler",
        new_callable=AsyncMock,
        return_value=True,
    ), patch(
        "app.services.restore_service._update_push_op_status",
        new_callable=AsyncMock,
    ) as mock_update, patch(
        "app.services.restore_service._publish_push_progress",
        new_callable=AsyncMock,
    ), patch(
        "app.services.crypto.decrypt_credentials_hybrid",
        new_callable=AsyncMock,
        return_value='{"username": "admin", "password": "test123"}',
    ), patch(
        "app.services.restore_service.settings",
    ):
        await recover_stale_push_operations(mock_session)

    mock_update.assert_called_once()
    call_args = mock_update.call_args
    # Status may be passed positionally or as the new_status keyword.
    assert call_args[0][1] == "committed" or call_args[1].get("new_status") == "committed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_recovery_marks_unreachable_device_failed():
    """If device is unreachable, mark operation as failed."""
    stale_op = MagicMock()
    stale_op.id = uuid4()
    stale_op.device_id = uuid4()
    stale_op.tenant_id = uuid4()
    stale_op.status = "pending_verification"
    stale_op.scheduler_name = "mikrotik-portal-panic-revert"
    stale_op.started_at = datetime.now(timezone.utc) - timedelta(minutes=10)

    target_device = MagicMock()
    target_device.ip_address = "192.168.1.1"

    session = AsyncMock()
    # execute() is called twice: stale-ops query, then device lookup.
    stale_result = MagicMock()
    stale_result.scalars.return_value.all.return_value = [stale_op]
    device_result = MagicMock()
    device_result.scalar_one_or_none.return_value = target_device
    session.execute = AsyncMock(side_effect=[stale_result, device_result])

    reachability_patch = patch(
        "app.services.restore_service._check_reachability",
        new_callable=AsyncMock,
        return_value=False,
    )
    status_patch = patch(
        "app.services.restore_service._update_push_op_status",
        new_callable=AsyncMock,
    )
    progress_patch = patch(
        "app.services.restore_service._publish_push_progress",
        new_callable=AsyncMock,
    )
    creds_patch = patch(
        "app.services.crypto.decrypt_credentials_hybrid",
        new_callable=AsyncMock,
        return_value='{"username": "admin", "password": "test123"}',
    )
    settings_patch = patch("app.services.restore_service.settings")

    with reachability_patch, status_patch as mock_update, progress_patch, \
            creds_patch, settings_patch:
        await recover_stale_push_operations(session)

    mock_update.assert_called_once()
    call_args = mock_update.call_args
    assert call_args[0][1] == "failed" or call_args[1].get("new_status") == "failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_recovery_skips_recent_ops():
    """Operations less than 5 minutes old should not be recovered (still in progress)."""
    session = AsyncMock()
    empty_result = MagicMock()
    # The SQL query itself filters by age, so recent ops never come back.
    empty_result.scalars.return_value.all.return_value = []
    session.execute = AsyncMock(return_value=empty_result)

    # Must complete cleanly without touching any operation.
    await recover_stale_push_operations(session)
|
||||
156
backend/tests/test_push_rollback_subscriber.py
Normal file
156
backend/tests/test_push_rollback_subscriber.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Tests for push rollback NATS subscriber."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from app.services.push_rollback_subscriber import (
|
||||
handle_push_rollback,
|
||||
handle_push_alert,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rollback_triggers_restore():
    """Push rollback should call restore_config with the pre-push commit SHA."""
    rollback_event = {
        "device_id": str(uuid4()),
        "tenant_id": str(uuid4()),
        "push_operation_id": str(uuid4()),
        "pre_push_commit_sha": "abc1234",
    }

    # Fake async-context-manager session factory.
    db_session = AsyncMock()
    session_cm = AsyncMock()
    session_cm.__aenter__ = AsyncMock(return_value=db_session)
    session_cm.__aexit__ = AsyncMock(return_value=False)

    restore_patch = patch(
        "app.services.push_rollback_subscriber.restore_service.restore_config",
        new_callable=AsyncMock,
        return_value={"status": "committed"},
    )
    session_patch = patch(
        "app.services.push_rollback_subscriber.AdminAsyncSessionLocal",
        return_value=session_cm,
    )
    with restore_patch as mock_restore, session_patch:
        await handle_push_rollback(rollback_event)

    mock_restore.assert_called_once()
    kwargs = mock_restore.call_args[1]
    assert kwargs["device_id"] == rollback_event["device_id"]
    assert kwargs["tenant_id"] == rollback_event["tenant_id"]
    assert kwargs["commit_sha"] == "abc1234"
    assert kwargs["db_session"] is db_session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rollback_missing_fields_skips():
    """Rollback with missing fields should log warning and return."""
    incomplete_event = {"device_id": str(uuid4())}  # missing tenant_id and commit_sha

    restore_patch = patch(
        "app.services.push_rollback_subscriber.restore_service.restore_config",
        new_callable=AsyncMock,
    )
    with restore_patch as mock_restore:
        await handle_push_rollback(incomplete_event)

    mock_restore.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rollback_failure_creates_alert():
    """When restore_config raises, an alert should be created."""
    failing_event = {
        "device_id": str(uuid4()),
        "tenant_id": str(uuid4()),
        "pre_push_commit_sha": "abc1234",
    }

    # Fake async-context-manager session factory.
    db_session = AsyncMock()
    session_cm = AsyncMock()
    session_cm.__aenter__ = AsyncMock(return_value=db_session)
    session_cm.__aexit__ = AsyncMock(return_value=False)

    restore_patch = patch(
        "app.services.push_rollback_subscriber.restore_service.restore_config",
        new_callable=AsyncMock,
        side_effect=RuntimeError("SSH failed"),
    )
    session_patch = patch(
        "app.services.push_rollback_subscriber.AdminAsyncSessionLocal",
        return_value=session_cm,
    )
    alert_patch = patch(
        "app.services.push_rollback_subscriber._create_push_alert",
        new_callable=AsyncMock,
    )
    with restore_patch, session_patch, alert_patch as mock_alert:
        await handle_push_rollback(failing_event)

    mock_alert.assert_called_once_with(
        failing_event["device_id"],
        failing_event["tenant_id"],
        "template (auto-rollback failed)",
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_alert_creates_alert_record():
    """Editor push alert should create a high-priority alert."""
    alert_event = {
        "device_id": str(uuid4()),
        "tenant_id": str(uuid4()),
        "push_type": "editor",
    }

    alert_patch = patch(
        "app.services.push_rollback_subscriber._create_push_alert",
        new_callable=AsyncMock,
    )
    with alert_patch as mock_alert:
        await handle_push_alert(alert_event)

    mock_alert.assert_called_once_with(
        alert_event["device_id"],
        alert_event["tenant_id"],
        "editor",
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_alert_missing_fields_skips():
    """Alert with missing fields should skip."""
    incomplete_event = {"device_id": str(uuid4())}  # missing tenant_id

    alert_patch = patch(
        "app.services.push_rollback_subscriber._create_push_alert",
        new_callable=AsyncMock,
    )
    with alert_patch as mock_alert:
        await handle_push_alert(incomplete_event)

    mock_alert.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_alert_defaults_to_editor_push_type():
    """Alert without push_type should default to 'editor'."""
    # Deliberately omits the push_type key.
    alert_event = {
        "device_id": str(uuid4()),
        "tenant_id": str(uuid4()),
    }

    alert_patch = patch(
        "app.services.push_rollback_subscriber._create_push_alert",
        new_callable=AsyncMock,
    )
    with alert_patch as mock_alert:
        await handle_push_alert(alert_event)

    mock_alert.assert_called_once_with(
        alert_event["device_id"],
        alert_event["tenant_id"],
        "editor",
    )
|
||||
211
backend/tests/test_restore_preview.py
Normal file
211
backend/tests/test_restore_preview.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Tests for the preview-restore endpoint."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestPreviewRestoreEndpointExists:
    """Verify the preview-restore route is registered on the config_backups router."""

    def test_router_has_preview_restore_route(self):
        """At least one registered route path must contain 'preview-restore'."""
        from app.routers.config_backups import router

        paths = [route.path for route in router.routes]
        matches = [p for p in paths if "preview-restore" in p]
        assert matches, f"No preview-restore route found. Routes: {paths}"

    def test_preview_restore_route_is_post(self):
        """The first preview-restore route must accept the POST method."""
        from app.routers.config_backups import router

        preview_routes = [
            r for r in router.routes
            if hasattr(r, "path") and "preview-restore" in r.path
        ]
        if not preview_routes:
            pytest.fail("No preview-restore route found")
        route = preview_routes[0]
        assert "POST" in route.methods, (
            f"preview-restore route should be POST, got {route.methods}"
        )
|
||||
|
||||
|
||||
class TestPreviewRestoreFunction:
    """Test the preview_restore handler logic."""

    @pytest.mark.asyncio
    async def test_preview_returns_impact_analysis(self):
        """preview_restore should return diff, categories, warnings, validation."""
        from app.routers.config_backups import preview_restore, RestoreRequest

        tenant_id = uuid.uuid4()
        device_id = uuid.uuid4()

        # Live config on the device vs. the backed-up config being restored.
        current_export = "/ip address\nadd address=192.168.1.1/24 interface=ether1\n"
        target_export = "/ip address\nadd address=10.0.0.1/24 interface=ether1\n"

        mock_db = AsyncMock()
        mock_user = MagicMock()
        mock_request = MagicMock()
        body = RestoreRequest(commit_sha="abc1234")

        # Mock device query result
        mock_device = MagicMock()
        mock_device.ip_address = "192.168.88.1"
        mock_device.encrypted_credentials_transit = "vault:v1:abc"
        mock_device.encrypted_credentials = None
        mock_device.tenant_id = tenant_id

        mock_scalar = MagicMock()
        mock_scalar.scalar_one_or_none.return_value = mock_device
        mock_db.execute.return_value = mock_scalar

        # Patch out access checks, rate limiting, git reads, the live export
        # capture, and credential decryption so only the handler logic runs.
        with patch(
            "app.routers.config_backups._check_tenant_access",
            new_callable=AsyncMock,
        ), patch(
            "app.routers.config_backups.limiter.enabled",
            False,
        ), patch(
            "app.routers.config_backups.git_store.read_file",
            return_value=target_export.encode(),
        ), patch(
            "app.routers.config_backups.backup_service.capture_export",
            new_callable=AsyncMock,
            return_value=current_export,
        ), patch(
            "app.routers.config_backups.decrypt_credentials_hybrid",
            new_callable=AsyncMock,
            return_value='{"username": "admin", "password": "pass"}',
        ), patch(
            "app.routers.config_backups.settings",
        ):
            result = await preview_restore(
                request=mock_request,
                tenant_id=tenant_id,
                device_id=device_id,
                body=body,
                db=mock_db,
                current_user=mock_user,
            )

        # The impact-analysis payload must carry all four sections.
        assert "diff" in result
        assert "categories" in result
        assert "warnings" in result
        assert "validation" in result
        # Both exports have /ip address with different commands
        assert isinstance(result["categories"], list)
        assert isinstance(result["diff"], dict)
        assert "added" in result["diff"]
        assert "removed" in result["diff"]

    @pytest.mark.asyncio
    async def test_preview_falls_back_to_latest_backup_when_device_unreachable(self):
        """When live capture fails, preview should fall back to the latest backup."""
        from app.routers.config_backups import preview_restore, RestoreRequest

        tenant_id = uuid.uuid4()
        device_id = uuid.uuid4()

        current_export = "/ip address\nadd address=192.168.1.1/24 interface=ether1\n"
        target_export = "/ip address\nadd address=10.0.0.1/24 interface=ether1\n"

        mock_db = AsyncMock()
        mock_user = MagicMock()
        mock_request = MagicMock()
        body = RestoreRequest(commit_sha="abc1234")

        # Mock device query result
        mock_device = MagicMock()
        mock_device.ip_address = "192.168.88.1"
        mock_device.encrypted_credentials_transit = "vault:v1:abc"
        mock_device.encrypted_credentials = None
        mock_device.tenant_id = tenant_id

        # First call: device query, second call: latest backup query
        mock_device_result = MagicMock()
        mock_device_result.scalar_one_or_none.return_value = mock_device

        mock_latest_run = MagicMock()
        mock_latest_run.commit_sha = "latest123"
        mock_backup_result = MagicMock()
        mock_backup_result.scalar_one_or_none.return_value = mock_latest_run

        mock_db.execute.side_effect = [mock_device_result, mock_backup_result]

        # Serve different file contents depending on which commit is read:
        # the restore target vs. the latest-backup fallback.
        def mock_read_file(tid, sha, did, filename):
            if sha == "abc1234":
                return target_export.encode()
            elif sha == "latest123":
                return current_export.encode()
            return b""

        with patch(
            "app.routers.config_backups._check_tenant_access",
            new_callable=AsyncMock,
        ), patch(
            "app.routers.config_backups.limiter.enabled",
            False,
        ), patch(
            "app.routers.config_backups.git_store.read_file",
            side_effect=mock_read_file,
        ), patch(
            # Live capture raises, forcing the fallback path under test.
            "app.routers.config_backups.backup_service.capture_export",
            new_callable=AsyncMock,
            side_effect=ConnectionError("Device unreachable"),
        ), patch(
            "app.routers.config_backups.decrypt_credentials_hybrid",
            new_callable=AsyncMock,
            return_value='{"username": "admin", "password": "pass"}',
        ), patch(
            "app.routers.config_backups.settings",
        ):
            result = await preview_restore(
                request=mock_request,
                tenant_id=tenant_id,
                device_id=device_id,
                body=body,
                db=mock_db,
                current_user=mock_user,
            )

        # Even on the fallback path the full payload shape is preserved.
        assert "diff" in result
        assert "categories" in result
        assert "warnings" in result
        assert "validation" in result

    @pytest.mark.asyncio
    async def test_preview_404_when_backup_not_found(self):
        """preview_restore should return 404 when the target backup doesn't exist."""
        from app.routers.config_backups import preview_restore, RestoreRequest
        from fastapi import HTTPException

        tenant_id = uuid.uuid4()
        device_id = uuid.uuid4()

        mock_db = AsyncMock()
        mock_user = MagicMock()
        mock_request = MagicMock()
        body = RestoreRequest(commit_sha="nonexistent")

        # git_store raising KeyError models a missing commit in the repo.
        with patch(
            "app.routers.config_backups._check_tenant_access",
            new_callable=AsyncMock,
        ), patch(
            "app.routers.config_backups.limiter.enabled",
            False,
        ), patch(
            "app.routers.config_backups.git_store.read_file",
            side_effect=KeyError("not found"),
        ):
            with pytest.raises(HTTPException) as exc_info:
                await preview_restore(
                    request=mock_request,
                    tenant_id=tenant_id,
                    device_id=device_id,
                    body=body,
                    db=mock_db,
                    current_user=mock_user,
                )

        assert exc_info.value.status_code == 404
|
||||
106
backend/tests/test_rsc_parser.py
Normal file
106
backend/tests/test_rsc_parser.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Tests for RouterOS RSC export parser."""
|
||||
|
||||
import pytest
|
||||
from app.services.rsc_parser import parse_rsc, validate_rsc, compute_impact
|
||||
|
||||
|
||||
SAMPLE_EXPORT = """\
|
||||
# 2026-03-07 12:00:00 by RouterOS 7.16.2
|
||||
# software id = ABCD-1234
|
||||
#
|
||||
# model = RB750Gr3
|
||||
/interface bridge
|
||||
add name=bridge1
|
||||
/ip address
|
||||
add address=192.168.88.1/24 interface=ether1 network=192.168.88.0
|
||||
add address=10.0.0.1/24 interface=bridge1 network=10.0.0.0
|
||||
/ip firewall filter
|
||||
add action=accept chain=input comment="allow established" \\
|
||||
connection-state=established,related
|
||||
add action=drop chain=input in-interface-list=WAN
|
||||
/ip dns
|
||||
set servers=8.8.8.8,8.8.4.4
|
||||
/system identity
|
||||
set name=test-router
|
||||
"""
|
||||
|
||||
|
||||
class TestParseRsc:
    """Behavior of parse_rsc on a representative RouterOS export."""

    def test_extracts_categories(self):
        """Every section header in the sample becomes a category path."""
        parsed = parse_rsc(SAMPLE_EXPORT)
        paths = [category["path"] for category in parsed["categories"]]
        for expected in (
            "/interface bridge",
            "/ip address",
            "/ip firewall filter",
            "/ip dns",
            "/system identity",
        ):
            assert expected in paths

    def test_counts_commands_per_category(self):
        """add/set commands are tallied per category."""
        by_path = {c["path"]: c for c in parse_rsc(SAMPLE_EXPORT)["categories"]}
        assert by_path["/ip address"]["adds"] == 2
        assert by_path["/ip address"]["sets"] == 0
        assert by_path["/ip firewall filter"]["adds"] == 2
        assert by_path["/ip dns"]["sets"] == 1
        assert by_path["/system identity"]["sets"] == 1

    def test_handles_continuation_lines(self):
        """A backslash-continued command still counts as a single add."""
        by_path = {c["path"]: c for c in parse_rsc(SAMPLE_EXPORT)["categories"]}
        # The firewall filter has a continuation line — should still count as 2 adds
        assert by_path["/ip firewall filter"]["adds"] == 2

    def test_ignores_comments_and_blank_lines(self):
        """Header comments must not become categories."""
        parsed = parse_rsc(SAMPLE_EXPORT)
        paths = [category["path"] for category in parsed["categories"]]
        assert "#" not in paths

    def test_empty_input(self):
        """An empty export parses to an empty category list."""
        assert parse_rsc("")["categories"] == []
|
||||
|
||||
|
||||
class TestValidateRsc:
    """Behavior of validate_rsc on valid and malformed exports."""

    def test_valid_export_passes(self):
        """A clean export validates with no errors."""
        outcome = validate_rsc(SAMPLE_EXPORT)
        assert outcome["valid"] is True
        assert outcome["errors"] == []

    def test_unbalanced_quotes_detected(self):
        """A dangling double quote is reported as an error."""
        broken = '/system identity\nset name="missing-end-quote\n'
        outcome = validate_rsc(broken)
        assert outcome["valid"] is False
        assert any("quote" in err.lower() for err in outcome["errors"])

    def test_truncated_continuation_detected(self):
        """A trailing backslash with no following line is reported."""
        broken = '/ip address\nadd address=192.168.1.1/24 \\\n'
        outcome = validate_rsc(broken)
        assert outcome["valid"] is False
        assert any(
            "truncat" in err.lower() or "continuation" in err.lower()
            for err in outcome["errors"]
        )
|
||||
|
||||
|
||||
class TestComputeImpact:
    """Behavior of compute_impact risk/warning classification."""

    def test_high_risk_for_firewall_input(self):
        """Changing an input-chain firewall rule is rated high risk."""
        before = '/ip firewall filter\nadd action=accept chain=input\n'
        after = '/ip firewall filter\nadd action=drop chain=input\n'
        impact = compute_impact(parse_rsc(before), parse_rsc(after))
        assert any(category["risk"] == "high" for category in impact["categories"])

    def test_high_risk_for_ip_address_changes(self):
        """Changing an IP address is at least medium risk."""
        before = '/ip address\nadd address=192.168.1.1/24 interface=ether1\n'
        after = '/ip address\nadd address=10.0.0.1/24 interface=ether1\n'
        impact = compute_impact(parse_rsc(before), parse_rsc(after))
        addr_category = next(
            category for category in impact["categories"]
            if category["path"] == "/ip address"
        )
        assert addr_category["risk"] in ("high", "medium")

    def test_warnings_for_management_access(self):
        """Dropping SSH (port 22) on the input chain produces a warning."""
        target = '/ip firewall filter\nadd action=drop chain=input protocol=tcp dst-port=22\n'
        impact = compute_impact(parse_rsc(""), parse_rsc(target))
        assert len(impact["warnings"]) > 0

    def test_no_changes_no_warnings(self):
        """Identical configs yield no warnings (or only 'none' risks)."""
        unchanged = '/ip dns\nset servers=8.8.8.8\n'
        impact = compute_impact(parse_rsc(unchanged), parse_rsc(unchanged))
        assert impact["warnings"] == [] or all(
            category["risk"] == "none" for category in impact["categories"]
        )
|
||||
128
backend/tests/test_srp_interop.py
Normal file
128
backend/tests/test_srp_interop.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""SRP-6a interop verification.
|
||||
|
||||
Uses srptools to perform a complete SRP handshake with fixed inputs,
|
||||
then prints all intermediate hex values. The TypeScript SRP client
|
||||
(frontend/src/lib/crypto/srp.ts) can be verified against these
|
||||
known-good values to catch encoding mismatches.
|
||||
|
||||
Run standalone:
|
||||
cd backend && python -m tests.test_srp_interop
|
||||
|
||||
Or via pytest:
|
||||
cd backend && python -m pytest tests/test_srp_interop.py -v
|
||||
"""
|
||||
|
||||
from srptools import SRPContext, SRPClientSession, SRPServerSession
|
||||
from srptools.constants import PRIME_2048, PRIME_2048_GEN
|
||||
|
||||
|
||||
# Fixed test inputs — deliberately constant so the printed intermediate
# values can serve as reference data for the TypeScript client interop test.
EMAIL = "test@example.com"
PASSWORD = "test-password"
|
||||
|
||||
|
||||
def test_srp_roundtrip():
    """Verify srptools produces a successful handshake end-to-end.

    This test ensures the server-side library completes a full SRP
    handshake without errors. The printed intermediate values serve as
    reference data for the TypeScript client interop test.

    Fix: the two banner prints used f-strings with no placeholders (F541);
    the extraneous f prefixes are removed (rendered output unchanged).
    """
    # Step 1: Registration -- compute salt + verifier (needs password in context)
    context = SRPContext(EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN)
    username, verifier, salt = context.get_user_data_triplet()

    print("\n--- SRP Interop Reference Values ---")
    print(f"email (I): {EMAIL}")
    print(f"salt (s): {salt}")
    print(f"verifier (v): {verifier[:64]}... (len={len(verifier)})")

    # Step 2: Server init -- generate B (server only needs verifier, no password)
    server_context = SRPContext(EMAIL, prime=PRIME_2048, generator=PRIME_2048_GEN)
    server_session = SRPServerSession(server_context, verifier)
    server_public = server_session.public

    print(f"server_public (B): {server_public[:64]}... (len={len(server_public)})")

    # Step 3: Client init -- generate A (client needs password for proof)
    client_context = SRPContext(EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN)
    client_session = SRPClientSession(client_context)
    client_public = client_session.public

    print(f"client_public (A): {client_public[:64]}... (len={len(client_public)})")

    # Step 4: Client processes B
    client_session.process(server_public, salt)

    # Step 5: Server processes A
    server_session.process(client_public, salt)

    # Step 6: Client generates proof M1
    client_proof = client_session.key_proof

    print(f"client_proof (M1): {client_proof}")

    # Step 7: Server verifies M1 and generates M2
    server_session.verify_proof(client_proof)
    server_proof = server_session.key_proof_hash

    print(f"server_proof (M2): {server_proof}")

    # Step 8: Client verifies M2
    client_session.verify_proof(server_proof)

    # Step 9: Verify session keys match
    assert client_session.key == server_session.key, (
        f"Session key mismatch: client={client_session.key[:32]}... "
        f"server={server_session.key[:32]}..."
    )

    print(f"session_key (K): {client_session.key[:64]}... (len={len(client_session.key)})")
    print("--- Handshake PASSED ---\n")
|
||||
|
||||
|
||||
def test_srp_bad_proof_rejected():
    """Verify that an incorrect M1 proof is rejected by the server.

    Runs a normal SRP exchange up to the proof step, then feeds the server a
    tampered client proof and asserts that verification raises.
    """
    context = SRPContext(EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN)
    _, verifier, salt = context.get_user_data_triplet()

    server_context = SRPContext(EMAIL, prime=PRIME_2048, generator=PRIME_2048_GEN)
    server_session = SRPServerSession(server_context, verifier)

    client_context = SRPContext(EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN)
    client_session = SRPClientSession(client_context)

    client_session.process(server_session.public, salt)
    server_session.process(client_session.public, salt)

    # Tamper with proof: an all-zero hex string can never be a valid M1.
    bad_proof = "00" * 32

    # BUG FIX: the previous version placed `assert False` inside the try
    # block, so the AssertionError it raised was swallowed by the broad
    # `except Exception: pass` handler and the test could never fail, even
    # if the server accepted the bad proof. Track rejection explicitly.
    rejected = False
    try:
        server_session.verify_proof(bad_proof)
    except Exception:
        rejected = True  # Expected: bad proof rejected
    assert rejected, "Server should have rejected bad proof"
|
||||
|
||||
|
||||
def test_srp_deterministic_verifier():
    """Verify that the same salt + identity produce consistent verifiers."""
    verifiers = []
    salts = []
    # Same email + password, two independently created contexts.
    for _ in range(2):
        ctx = SRPContext(EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN)
        _, verifier, salt = ctx.get_user_data_triplet()
        verifiers.append(verifier)
        salts.append(salt)

    # srptools generates random salt each time, so verifiers will differ.
    # But the output format is consistent.
    assert len(verifiers[0]) > 0
    assert len(verifiers[1]) > 0
    assert len(salts[0]) == len(salts[1]), "Salt lengths should be consistent"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the interop checks in order when executed as a plain script.
    for check in (
        test_srp_roundtrip,
        test_srp_bad_proof_rejected,
        test_srp_deterministic_verifier,
    ):
        check()
    print("All SRP interop tests passed.")
|
||||
0
backend/tests/unit/__init__.py
Normal file
0
backend/tests/unit/__init__.py
Normal file
76
backend/tests/unit/test_api_key_service.py
Normal file
76
backend/tests/unit/test_api_key_service.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Unit tests for API key service.
|
||||
|
||||
Tests cover:
|
||||
- Key generation format (mktp_ prefix, sufficient length)
|
||||
- Key hashing (SHA-256 hex digest, 64 chars)
|
||||
- Scope validation against allowed list
|
||||
- Key prefix extraction
|
||||
|
||||
These are pure function tests -- no database or async required.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
|
||||
from app.services.api_key_service import (
|
||||
ALLOWED_SCOPES,
|
||||
generate_raw_key,
|
||||
hash_key,
|
||||
)
|
||||
|
||||
|
||||
class TestKeyGeneration:
    """Tests for API key generation."""

    def test_key_starts_with_prefix(self):
        """Every generated key carries the mktp_ namespace prefix."""
        assert generate_raw_key().startswith("mktp_")

    def test_key_has_sufficient_length(self):
        """Key should be mktp_ + at least 32 chars of randomness."""
        assert len(generate_raw_key()) >= 37  # "mktp_" (5) + 32

    def test_key_uniqueness(self):
        """Two generated keys should never be the same."""
        first = generate_raw_key()
        second = generate_raw_key()
        assert first != second
|
||||
|
||||
|
||||
class TestKeyHashing:
    """Tests for SHA-256 key hashing."""

    def test_hash_produces_64_char_hex(self):
        """A hashed key is a 64-character lowercase hex digest."""
        digest = hash_key("mktp_test1234567890abcdef")
        assert len(digest) == 64
        assert set(digest) <= set("0123456789abcdef")

    def test_hash_is_sha256(self):
        """hash_key must match a direct SHA-256 of the UTF-8 key bytes."""
        raw = "mktp_test1234567890abcdef"
        assert hash_key(raw) == hashlib.sha256(raw.encode()).hexdigest()

    def test_hash_deterministic(self):
        """Hashing the same key twice yields the same digest."""
        raw = generate_raw_key()
        assert hash_key(raw) == hash_key(raw)

    def test_different_keys_different_hashes(self):
        """Distinct keys must not collide on their digests."""
        assert hash_key(generate_raw_key()) != hash_key(generate_raw_key())
|
||||
|
||||
|
||||
class TestAllowedScopes:
    """Tests for scope definitions."""

    def test_allowed_scopes_contains_expected(self):
        """ALLOWED_SCOPES must be exactly the documented read/write scope set."""
        assert ALLOWED_SCOPES == {
            "devices:read",
            "devices:write",
            "config:read",
            "config:write",
            "alerts:read",
            "firmware:write",
        }
|
||||
75
backend/tests/unit/test_audit_service.py
Normal file
75
backend/tests/unit/test_audit_service.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Unit tests for the audit service and model.
|
||||
|
||||
Tests cover:
|
||||
- AuditLog model can be imported
|
||||
- log_action function signature is correct
|
||||
- Audit logs router is importable with expected endpoints
|
||||
- CSV export endpoint exists
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAuditLogModel:
    """Tests for the AuditLog ORM model."""

    def test_model_importable(self):
        """The model module imports and maps to the expected table."""
        from app.models.audit_log import AuditLog

        assert AuditLog.__tablename__ == "audit_logs"

    def test_model_has_required_columns(self):
        """All audit columns the service writes must exist on the table."""
        from app.models.audit_log import AuditLog

        required = {
            "id", "tenant_id", "user_id", "action",
            "resource_type", "resource_id", "device_id",
            "details", "ip_address", "created_at",
        }
        present = {column.name for column in AuditLog.__table__.columns}
        assert required.issubset(present), (
            f"Missing columns: {required - present}"
        )

    def test_model_exported_from_init(self):
        """The model is re-exported from the package root."""
        from app.models import AuditLog

        assert AuditLog.__tablename__ == "audit_logs"
|
||||
|
||||
|
||||
class TestAuditService:
    """Tests for the audit service log_action function."""

    def test_log_action_importable(self):
        """log_action is exposed as a callable by the service module."""
        from app.services.audit_service import log_action

        assert callable(log_action)

    @pytest.mark.asyncio
    async def test_log_action_does_not_raise_on_db_error(self):
        """log_action must swallow exceptions so it never breaks the caller."""
        from app.services.audit_service import log_action

        failing_db = AsyncMock()
        failing_db.execute = AsyncMock(side_effect=Exception("DB down"))

        # Should NOT raise even though the DB call fails
        await log_action(
            db=failing_db,
            tenant_id=uuid.uuid4(),
            user_id=uuid.uuid4(),
            action="test_action",
        )
|
||||
|
||||
|
||||
class TestAuditRouter:
    """Tests for the audit logs router."""

    def test_router_importable(self):
        """The router module imports and exposes a router object."""
        from app.routers.audit_logs import router

        assert router is not None

    def test_router_has_audit_logs_endpoint(self):
        """At least one registered route path must reference /audit-logs."""
        from app.routers.audit_logs import router

        paths = [route.path for route in router.routes]
        # Simplified: the original exact-match clause (`"/audit-logs" in paths`)
        # was redundant — any exact match also satisfies the substring scan.
        assert any("/audit-logs" in path for path in paths)
|
||||
169
backend/tests/unit/test_auth.py
Normal file
169
backend/tests/unit/test_auth.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Unit tests for the JWT authentication service.
|
||||
|
||||
Tests cover:
|
||||
- Password hashing and verification (bcrypt)
|
||||
- JWT access token creation and validation
|
||||
- JWT refresh token creation and validation
|
||||
- Token rejection for wrong type, expired, invalid, missing subject
|
||||
|
||||
These are pure function tests -- no database or async required.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from jose import jwt
|
||||
|
||||
from app.services.auth import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
verify_token,
|
||||
)
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class TestPasswordHashing:
    """Tests for bcrypt password hashing."""

    def test_hash_returns_different_string(self):
        """The digest must never equal the plaintext."""
        plain = "test-password-123!"
        assert hash_password(plain) != plain

    def test_hash_verify_roundtrip(self):
        """A freshly hashed password verifies against its own digest."""
        plain = "test-password-123!"
        assert verify_password(plain, hash_password(plain)) is True

    def test_verify_rejects_wrong_password(self):
        """A different plaintext must not verify."""
        digest = hash_password("correct-password")
        assert verify_password("wrong-password", digest) is False

    def test_hash_uses_unique_salts(self):
        """Each hash should be different even for the same password (random salt)."""
        assert hash_password("same-password") != hash_password("same-password")

    def test_verify_both_hashes_valid(self):
        """Both unique hashes should verify against the original password."""
        plain = "same-password"
        for digest in (hash_password(plain), hash_password(plain)):
            assert verify_password(plain, digest) is True
|
||||
|
||||
|
||||
class TestAccessToken:
    """Tests for JWT access token creation and validation."""

    def test_create_and_verify_roundtrip(self):
        """A created token round-trips through verify_token with all claims."""
        user_id = uuid.uuid4()
        tenant_id = uuid.uuid4()
        payload = verify_token(
            create_access_token(user_id=user_id, tenant_id=tenant_id, role="admin"),
            expected_type="access",
        )

        assert payload["sub"] == str(user_id)
        assert payload["tenant_id"] == str(tenant_id)
        assert payload["role"] == "admin"
        assert payload["type"] == "access"

    def test_super_admin_null_tenant(self):
        """Super admins carry an explicit null tenant_id in the token."""
        user_id = uuid.uuid4()
        payload = verify_token(
            create_access_token(user_id=user_id, tenant_id=None, role="super_admin"),
            expected_type="access",
        )

        assert payload["sub"] == str(user_id)
        assert payload["tenant_id"] is None
        assert payload["role"] == "super_admin"

    def test_contains_expiry(self):
        """Tokens must embed both expiry and issued-at claims."""
        token = create_access_token(
            user_id=uuid.uuid4(), tenant_id=uuid.uuid4(), role="viewer"
        )
        payload = verify_token(token, expected_type="access")
        for claim in ("exp", "iat"):
            assert claim in payload
|
||||
|
||||
|
||||
class TestRefreshToken:
    """Tests for JWT refresh token creation and validation."""

    def test_create_and_verify_roundtrip(self):
        """A refresh token round-trips with the correct subject and type."""
        user_id = uuid.uuid4()
        payload = verify_token(
            create_refresh_token(user_id=user_id), expected_type="refresh"
        )

        assert payload["sub"] == str(user_id)
        assert payload["type"] == "refresh"

    def test_refresh_token_has_no_tenant_or_role(self):
        """Refresh tokens intentionally omit tenant_id and role."""
        payload = verify_token(
            create_refresh_token(user_id=uuid.uuid4()), expected_type="refresh"
        )

        for claim in ("tenant_id", "role"):
            assert claim not in payload
|
||||
|
||||
|
||||
class TestTokenRejection:
    """Tests for JWT token validation failure cases."""

    @staticmethod
    def _assert_rejected(token, expected_type):
        # Every rejection path must surface as an HTTP 401.
        with pytest.raises(HTTPException) as exc_info:
            verify_token(token, expected_type=expected_type)
        assert exc_info.value.status_code == 401

    def test_rejects_wrong_type(self):
        """Access token should not verify as refresh, and vice versa."""
        access_token = create_access_token(
            user_id=uuid.uuid4(), tenant_id=uuid.uuid4(), role="admin"
        )
        self._assert_rejected(access_token, "refresh")

    def test_rejects_expired_token(self):
        """Manually craft an expired token and verify it is rejected."""
        now = datetime.now(UTC)
        expired_token = jwt.encode(
            {
                "sub": str(uuid.uuid4()),
                "type": "access",
                "exp": now - timedelta(hours=1),
                "iat": now - timedelta(hours=2),
            },
            settings.JWT_SECRET_KEY,
            algorithm=settings.JWT_ALGORITHM,
        )
        self._assert_rejected(expired_token, "access")

    def test_rejects_invalid_token(self):
        """Garbage that is not a JWT at all must be rejected."""
        self._assert_rejected("not-a-valid-jwt", "access")

    def test_rejects_wrong_signing_key(self):
        """Token signed with a different key should be rejected."""
        wrong_key_token = jwt.encode(
            {
                "sub": str(uuid.uuid4()),
                "type": "access",
                "exp": datetime.now(UTC) + timedelta(hours=1),
            },
            "wrong-secret-key",
            algorithm="HS256",
        )
        self._assert_rejected(wrong_key_token, "access")

    def test_rejects_missing_subject(self):
        """Token without 'sub' claim should be rejected."""
        no_sub_token = jwt.encode(
            {
                "type": "access",
                "exp": datetime.now(UTC) + timedelta(hours=1),
            },
            settings.JWT_SECRET_KEY,
            algorithm=settings.JWT_ALGORITHM,
        )
        self._assert_rejected(no_sub_token, "access")
|
||||
126
backend/tests/unit/test_crypto.py
Normal file
126
backend/tests/unit/test_crypto.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Unit tests for the credential encryption/decryption service.
|
||||
|
||||
Tests cover:
|
||||
- Encryption/decryption round-trip with valid key
|
||||
- Random nonce ensures different ciphertext per encryption
|
||||
- Wrong key rejection (InvalidTag)
|
||||
- Invalid key length rejection (ValueError)
|
||||
- Unicode and JSON payload handling
|
||||
- Tampered ciphertext detection
|
||||
|
||||
These are pure function tests -- no database or async required.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from cryptography.exceptions import InvalidTag
|
||||
|
||||
from app.services.crypto import decrypt_credentials, encrypt_credentials
|
||||
|
||||
|
||||
class TestEncryptDecryptRoundTrip:
    """Tests for successful encryption/decryption cycles."""

    @staticmethod
    def _roundtrip(plaintext):
        # Encrypt then decrypt under a fresh random 32-byte key.
        key = os.urandom(32)
        return decrypt_credentials(encrypt_credentials(plaintext, key), key)

    def test_basic_roundtrip(self):
        """A simple string survives encrypt -> decrypt unchanged."""
        assert self._roundtrip("secret-password") == "secret-password"

    def test_json_credentials_roundtrip(self):
        """The actual use case: encrypting JSON credential objects."""
        creds = json.dumps({"username": "admin", "password": "RouterOS!123"})
        parsed = json.loads(self._roundtrip(creds))
        assert parsed["username"] == "admin"
        assert parsed["password"] == "RouterOS!123"

    def test_unicode_roundtrip(self):
        """Non-ASCII characters survive the round trip."""
        plaintext = "password-with-unicode-\u00e9\u00e8\u00ea"
        assert self._roundtrip(plaintext) == plaintext

    def test_empty_string_roundtrip(self):
        """The empty string is a valid payload."""
        assert self._roundtrip("") == ""

    def test_long_payload_roundtrip(self):
        """Ensure large payloads work (e.g., SSH keys in credentials)."""
        plaintext = "x" * 10000
        assert self._roundtrip(plaintext) == plaintext
|
||||
|
||||
|
||||
class TestNonceRandomness:
    """Tests that encryption uses random nonces."""

    def test_different_ciphertext_each_time(self):
        """Two encryptions of the same plaintext should produce different ciphertext
        because a random 12-byte nonce is generated each time."""
        key = os.urandom(32)
        first = encrypt_credentials("same-plaintext", key)
        second = encrypt_credentials("same-plaintext", key)
        assert first != second

    def test_both_decrypt_correctly(self):
        """Both different ciphertexts should decrypt to the same plaintext."""
        key = os.urandom(32)
        message = "same-plaintext"
        for ciphertext in (
            encrypt_credentials(message, key),
            encrypt_credentials(message, key),
        ):
            assert decrypt_credentials(ciphertext, key) == message
|
||||
|
||||
|
||||
class TestDecryptionFailures:
    """Tests for proper rejection of invalid inputs."""

    def test_wrong_key_raises_invalid_tag(self):
        """Decrypting with a key other than the encryption key must fail."""
        encrypt_key = os.urandom(32)
        other_key = os.urandom(32)
        ciphertext = encrypt_credentials("secret", encrypt_key)
        with pytest.raises(InvalidTag):
            decrypt_credentials(ciphertext, other_key)

    def test_tampered_ciphertext_raises_invalid_tag(self):
        """Flipping a byte in the ciphertext should cause authentication failure."""
        key = os.urandom(32)
        mutable = bytearray(encrypt_credentials("secret", key))
        # Flip a byte in the encrypted portion (after the 12-byte nonce)
        mutable[15] ^= 0xFF
        with pytest.raises(InvalidTag):
            decrypt_credentials(bytes(mutable), key)
|
||||
|
||||
|
||||
class TestKeyValidation:
    """Tests for encryption key length validation."""

    # All key-length failures must mention the required size.
    _LENGTH_ERROR = "32 bytes"

    def test_short_key_encrypt_raises(self):
        """A 16-byte key is rejected on encrypt."""
        with pytest.raises(ValueError, match=self._LENGTH_ERROR):
            encrypt_credentials("test", os.urandom(16))

    def test_long_key_encrypt_raises(self):
        """A 64-byte key is rejected on encrypt."""
        with pytest.raises(ValueError, match=self._LENGTH_ERROR):
            encrypt_credentials("test", os.urandom(64))

    def test_short_key_decrypt_raises(self):
        """A short key is rejected on decrypt even for valid ciphertext."""
        good_key = os.urandom(32)
        ciphertext = encrypt_credentials("test", good_key)
        with pytest.raises(ValueError, match=self._LENGTH_ERROR):
            decrypt_credentials(ciphertext, os.urandom(16))

    def test_empty_key_raises(self):
        """An empty key is rejected on encrypt."""
        with pytest.raises(ValueError, match=self._LENGTH_ERROR):
            encrypt_credentials("test", b"")
|
||||
121
backend/tests/unit/test_maintenance_windows.py
Normal file
121
backend/tests/unit/test_maintenance_windows.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Unit tests for maintenance window model, router schemas, and alert suppression.
|
||||
|
||||
Tests cover:
|
||||
- MaintenanceWindow ORM model imports and field definitions
|
||||
- MaintenanceWindowCreate/Update/Response Pydantic schema validation
|
||||
- Alert evaluator _is_device_in_maintenance integration
|
||||
- Router registration in main app
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
class TestMaintenanceWindowModel:
    """Test that the MaintenanceWindow ORM model is importable and has correct fields."""

    def test_model_importable(self):
        """The model module imports and maps to the expected table."""
        from app.models.maintenance_window import MaintenanceWindow

        assert MaintenanceWindow.__tablename__ == "maintenance_windows"

    def test_model_exported_from_init(self):
        """The model is re-exported from the package root."""
        from app.models import MaintenanceWindow

        assert MaintenanceWindow.__tablename__ == "maintenance_windows"

    def test_model_has_required_columns(self):
        """All columns the router and evaluator read must exist on the model."""
        from app.models.maintenance_window import MaintenanceWindow

        present = {col.key for col in MaintenanceWindow.__mapper__.columns}
        required = {
            "id", "tenant_id", "name", "device_ids",
            "start_at", "end_at", "suppress_alerts",
            "notes", "created_by", "created_at", "updated_at",
        }
        assert required.issubset(present), f"Missing columns: {required - present}"
|
||||
|
||||
|
||||
class TestMaintenanceWindowSchemas:
    """Test Pydantic schemas for request/response validation."""

    @staticmethod
    def _window(hours):
        # A (start, end) pair spanning *hours* from now, in UTC.
        start = datetime.now(timezone.utc)
        return start, start + timedelta(hours=hours)

    def test_create_schema_valid(self):
        """A fully populated create payload validates and keeps its values."""
        from app.routers.maintenance_windows import MaintenanceWindowCreate

        start_at, end_at = self._window(2)
        payload = MaintenanceWindowCreate(
            name="Nightly update",
            device_ids=["abc-123"],
            start_at=start_at,
            end_at=end_at,
            suppress_alerts=True,
            notes="Scheduled maintenance",
        )
        assert payload.name == "Nightly update"
        assert payload.suppress_alerts is True

    def test_create_schema_defaults(self):
        """Omitted optional fields fall back to the schema defaults."""
        from app.routers.maintenance_windows import MaintenanceWindowCreate

        start_at, end_at = self._window(1)
        payload = MaintenanceWindowCreate(
            name="Quick reboot",
            device_ids=[],
            start_at=start_at,
            end_at=end_at,
        )
        assert payload.suppress_alerts is True  # default
        assert payload.notes is None

    def test_update_schema_partial(self):
        """The update schema accepts partial payloads (all fields optional)."""
        from app.routers.maintenance_windows import MaintenanceWindowUpdate

        payload = MaintenanceWindowUpdate(name="Updated name")
        assert payload.name == "Updated name"
        assert payload.device_ids is None  # all optional

    def test_response_schema(self):
        """A complete response payload validates with string IDs and ISO timestamps."""
        from app.routers.maintenance_windows import MaintenanceWindowResponse

        now_iso = datetime.now(timezone.utc).isoformat()
        payload = MaintenanceWindowResponse(
            id="abc",
            tenant_id="def",
            name="Test",
            device_ids=["x"],
            start_at=now_iso,
            end_at=now_iso,
            suppress_alerts=True,
            notes=None,
            created_by="ghi",
            created_at=now_iso,
        )
        assert payload.id == "abc"
|
||||
|
||||
|
||||
class TestRouterRegistration:
    """Test that the maintenance_windows router is properly registered."""

    def test_router_importable(self):
        """The router module imports and exposes a router object."""
        from app.routers.maintenance_windows import router

        assert router is not None

    def test_router_has_routes(self):
        """At least one route path must reference maintenance-windows."""
        from app.routers.maintenance_windows import router

        assert any("maintenance-windows" in route.path for route in router.routes)

    def test_main_app_includes_router(self):
        """The main application mounts the maintenance-windows routes."""
        try:
            from app.main import app
        except ImportError:
            pytest.skip("app.main requires full dependencies (prometheus, etc.)")
        joined_paths = " ".join(route.path for route in app.routes)
        assert "maintenance-windows" in joined_paths
|
||||
|
||||
|
||||
class TestAlertEvaluatorMaintenance:
    """Test that alert_evaluator has maintenance window check capability."""

    def test_maintenance_cache_exists(self):
        """The evaluator module exposes a _maintenance_cache attribute."""
        from app.services import alert_evaluator

        cache_present = hasattr(alert_evaluator, "_maintenance_cache")
        assert cache_present

    def test_is_device_in_maintenance_function_exists(self):
        """The maintenance-check helper is importable and callable."""
        from app.services.alert_evaluator import _is_device_in_maintenance

        assert callable(_is_device_in_maintenance)
|
||||
231
backend/tests/unit/test_security.py
Normal file
231
backend/tests/unit/test_security.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""Unit tests for security hardening.
|
||||
|
||||
Tests cover:
|
||||
- Production startup validation (insecure defaults rejection)
|
||||
- Security headers middleware (per-environment header behavior)
|
||||
|
||||
These are pure function/middleware tests -- no database or async required
|
||||
for startup validation, async only for middleware tests.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.config import KNOWN_INSECURE_DEFAULTS, validate_production_settings
|
||||
|
||||
|
||||
class TestStartupValidation:
    """Tests for validate_production_settings()."""

    def _make_settings(self, **overrides):
        """Create a mock settings object with given field values."""
        values = {
            "ENVIRONMENT": "dev",
            "JWT_SECRET_KEY": "change-this-in-production-use-a-long-random-string",
            "CREDENTIAL_ENCRYPTION_KEY": "LLLjnfBZTSycvL2U07HDSxUeTtLxb9cZzryQl0R9E4w=",
        }
        values.update(overrides)
        return SimpleNamespace(**values)

    def _assert_rejected(self, settings):
        # Validation must abort the process with exit code 1.
        with pytest.raises(SystemExit) as exc_info:
            validate_production_settings(settings)
        assert exc_info.value.code == 1

    def test_production_rejects_insecure_jwt_secret(self):
        """Production with default JWT secret must exit."""
        self._assert_rejected(
            self._make_settings(
                ENVIRONMENT="production",
                JWT_SECRET_KEY=KNOWN_INSECURE_DEFAULTS["JWT_SECRET_KEY"][0],
            )
        )

    def test_production_rejects_insecure_encryption_key(self):
        """Production with default encryption key must exit."""
        self._assert_rejected(
            self._make_settings(
                ENVIRONMENT="production",
                JWT_SECRET_KEY="a-real-secure-jwt-secret-that-is-long-enough",
                CREDENTIAL_ENCRYPTION_KEY=KNOWN_INSECURE_DEFAULTS["CREDENTIAL_ENCRYPTION_KEY"][0],
            )
        )

    def test_dev_allows_insecure_defaults(self):
        """Dev environment allows insecure defaults without error."""
        settings = self._make_settings(
            ENVIRONMENT="dev",
            JWT_SECRET_KEY=KNOWN_INSECURE_DEFAULTS["JWT_SECRET_KEY"][0],
            CREDENTIAL_ENCRYPTION_KEY=KNOWN_INSECURE_DEFAULTS["CREDENTIAL_ENCRYPTION_KEY"][0],
        )
        # Should NOT raise
        validate_production_settings(settings)

    def test_production_allows_secure_values(self):
        """Production with non-default secrets should pass."""
        settings = self._make_settings(
            ENVIRONMENT="production",
            JWT_SECRET_KEY="a-real-secure-jwt-secret-that-is-long-enough-for-production",
            CREDENTIAL_ENCRYPTION_KEY="dGhpcyBpcyBhIHNlY3VyZSBrZXkgdGhhdCBpcw==",
        )
        # Should NOT raise
        validate_production_settings(settings)
|
||||
|
||||
|
||||
class TestSecurityHeadersMiddleware:
    """Tests for SecurityHeadersMiddleware."""

    @staticmethod
    def _build_app(environment):
        """Create a minimal FastAPI app with security middleware in the given mode."""
        from fastapi import FastAPI
        from app.middleware.security_headers import SecurityHeadersMiddleware

        app = FastAPI()
        app.add_middleware(SecurityHeadersMiddleware, environment=environment)

        @app.get("/test")
        async def test_endpoint():
            return {"status": "ok"}

        return app

    @pytest.fixture
    def prod_app(self):
        """Minimal app with security middleware in production mode."""
        return self._build_app("production")

    @pytest.fixture
    def dev_app(self):
        """Minimal app with security middleware in dev mode."""
        return self._build_app("dev")

    @staticmethod
    async def _get(app):
        """Issue an in-process GET /test against *app* and return the response."""
        import httpx

        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            return await client.get("/test")

    @staticmethod
    def _csp_directive(response, name):
        """Extract the first CSP directive containing *name* from the response."""
        csp = response.headers["content-security-policy"]
        return [d for d in csp.split(";") if name in d][0]

    @pytest.mark.asyncio
    async def test_production_includes_hsts(self, prod_app):
        """Production responses must include HSTS header."""
        response = await self._get(prod_app)

        assert response.status_code == 200
        assert response.headers["strict-transport-security"] == "max-age=31536000; includeSubDomains"
        assert response.headers["x-content-type-options"] == "nosniff"
        assert response.headers["x-frame-options"] == "DENY"
        assert response.headers["cache-control"] == "no-store"

    @pytest.mark.asyncio
    async def test_dev_excludes_hsts(self, dev_app):
        """Dev responses must NOT include HSTS (breaks plain HTTP)."""
        response = await self._get(dev_app)

        assert response.status_code == 200
        assert "strict-transport-security" not in response.headers
        assert response.headers["x-content-type-options"] == "nosniff"
        assert response.headers["x-frame-options"] == "DENY"
        assert response.headers["cache-control"] == "no-store"

    @pytest.mark.asyncio
    async def test_csp_header_present_production(self, prod_app):
        """Production responses must include CSP header."""
        response = await self._get(prod_app)

        assert "content-security-policy" in response.headers
        csp = response.headers["content-security-policy"]
        assert "default-src 'self'" in csp
        assert "script-src" in csp

    @pytest.mark.asyncio
    async def test_csp_header_present_dev(self, dev_app):
        """Dev responses must include CSP header."""
        response = await self._get(dev_app)

        assert "content-security-policy" in response.headers
        assert "default-src 'self'" in response.headers["content-security-policy"]

    @pytest.mark.asyncio
    async def test_csp_production_blocks_inline_scripts(self, prod_app):
        """Production CSP must block inline scripts (no unsafe-inline in script-src)."""
        response = await self._get(prod_app)

        script_src = self._csp_directive(response, "script-src")
        assert "'unsafe-inline'" not in script_src
        assert "'unsafe-eval'" not in script_src
        assert "'self'" in script_src

    @pytest.mark.asyncio
    async def test_csp_dev_allows_unsafe_inline(self, dev_app):
        """Dev CSP must allow unsafe-inline and unsafe-eval for Vite HMR."""
        response = await self._get(dev_app)

        script_src = self._csp_directive(response, "script-src")
        assert "'unsafe-inline'" in script_src
        assert "'unsafe-eval'" in script_src

    @pytest.mark.asyncio
    async def test_csp_production_allows_inline_styles(self, prod_app):
        """Production CSP must allow unsafe-inline for styles (Tailwind, Framer Motion, Radix)."""
        response = await self._get(prod_app)

        assert "'unsafe-inline'" in self._csp_directive(response, "style-src")

    @pytest.mark.asyncio
    async def test_csp_allows_websocket_connections(self, prod_app):
        """CSP must allow wss: and ws: for SSE/WebSocket connections."""
        response = await self._get(prod_app)

        connect_src = self._csp_directive(response, "connect-src")
        assert "wss:" in connect_src
        assert "ws:" in connect_src

    @pytest.mark.asyncio
    async def test_csp_frame_ancestors_none(self, prod_app):
        """CSP must include frame-ancestors 'none' (anti-clickjacking)."""
        response = await self._get(prod_app)

        assert "frame-ancestors 'none'" in response.headers["content-security-policy"]
|
||||
Reference in New Issue
Block a user