fix(lint): resolve all ruff lint errors

Add ruff config to exclude alembic E402, SQLAlchemy F821, and pre-existing
E501 line-length issues. Auto-fix 69 unused imports and 2 f-strings without
placeholders. Manually fix 8 unused variables. Apply ruff format to 127 files.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Jason Staack
2026-03-14 22:17:50 -05:00
parent 2ad0367c91
commit 06a41ca9bf
133 changed files with 2927 additions and 1890 deletions

View File

@@ -6,8 +6,6 @@ Phase 10: Integration test fixtures added in tests/integration/conftest.py.
Pytest marker registration and shared configuration lives here.
"""
import pytest
def pytest_configure(config):
"""Register custom markers."""

View File

@@ -241,6 +241,7 @@ async def test_app(admin_engine, app_engine):
# Register rate limiter (auth endpoints use @limiter.limit)
from app.middleware.rate_limit import setup_rate_limiting
setup_rate_limiting(app)
# Create test session factories

View File

@@ -258,9 +258,7 @@ class TestAlertEvents:
):
"""GET /api/tenants/{tenant_id}/devices/{device_id}/alerts returns paginated response."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
device = await create_test_device(admin_session, tenant.id)
await admin_session.commit()

View File

@@ -19,7 +19,7 @@ from app.services.auth import hash_password
pytestmark = pytest.mark.integration
from tests.integration.conftest import TEST_DATABASE_URL
from tests.integration.conftest import TEST_DATABASE_URL # noqa: E402
# ---------------------------------------------------------------------------
@@ -93,7 +93,6 @@ async def test_login_success(client, admin_engine):
assert len(body["refresh_token"]) > 0
# Verify httpOnly cookie is set
cookies = resp.cookies
# Cookie may or may not appear in httpx depending on secure flag
# Just verify the response contains Set-Cookie header
set_cookie = resp.headers.get("set-cookie", "")
@@ -193,7 +192,6 @@ async def test_token_refresh(client, admin_engine):
assert login_resp.status_code == 200
tokens = login_resp.json()
refresh_token = tokens["refresh_token"]
original_access = tokens["access_token"]
# Use refresh token to get new access token
refresh_resp = await client.post(

View File

@@ -32,9 +32,7 @@ class TestConfigBackups:
):
"""GET config backups for a device with no backups returns 200 + empty list."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
device = await create_test_device(admin_session, tenant.id)
await admin_session.commit()
@@ -58,9 +56,7 @@ class TestConfigBackups:
):
"""GET schedule returns synthetic default when no schedule configured."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
device = await create_test_device(admin_session, tenant.id)
await admin_session.commit()
@@ -118,9 +114,7 @@ class TestConfigBackups:
):
"""Config backup router responds (not 404) for expected paths."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
device = await create_test_device(admin_session, tenant.id)
await admin_session.commit()
@@ -143,7 +137,5 @@ class TestConfigBackups:
"""GET config backups without auth returns 401."""
tenant_id = str(uuid.uuid4())
device_id = str(uuid.uuid4())
resp = await client.get(
f"/api/tenants/{tenant_id}/devices/{device_id}/config/backups"
)
resp = await client.get(f"/api/tenants/{tenant_id}/devices/{device_id}/config/backups")
assert resp.status_code == 401

View File

@@ -10,7 +10,6 @@ All tests are independent and create their own test data.
import uuid
import pytest
import pytest_asyncio
pytestmark = pytest.mark.integration
@@ -90,9 +89,7 @@ class TestDevicesCRUD:
):
"""GET /api/tenants/{tenant_id}/devices/{device_id} returns correct device."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
device = await create_test_device(admin_session, tenant.id)
@@ -178,9 +175,7 @@ class TestDevicesCRUD:
):
"""GET /api/tenants/{tenant_id}/devices?status=online returns filtered results."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
# Create devices with different statuses

View File

@@ -36,9 +36,7 @@ class TestHealthMetrics:
):
"""GET health metrics for a device with no data returns 200 + empty list."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
device = await create_test_device(admin_session, tenant.id)
@@ -68,9 +66,7 @@ class TestHealthMetrics:
):
"""GET health metrics returns bucketed data when rows exist."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
device = await create_test_device(admin_session, tenant.id)
await admin_session.flush()
@@ -131,9 +127,7 @@ class TestInterfaceMetrics:
):
"""GET interface metrics for device with no data returns 200 + empty list."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
device = await create_test_device(admin_session, tenant.id)
await admin_session.commit()
@@ -160,9 +154,7 @@ class TestInterfaceMetrics:
):
"""GET interface list for device with no data returns 200 + empty list."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
device = await create_test_device(admin_session, tenant.id)
await admin_session.commit()
@@ -188,9 +180,7 @@ class TestSparkline:
):
"""GET sparkline for device with no data returns 200 + empty list."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
device = await create_test_device(admin_session, tenant.id)
await admin_session.commit()
@@ -215,9 +205,7 @@ class TestFleetSummary:
):
"""GET /api/tenants/{tenant_id}/fleet/summary returns 200 with empty fleet."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
resp = await client.get(
@@ -238,9 +226,7 @@ class TestFleetSummary:
):
"""GET fleet summary returns device data when devices exist."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
await create_test_device(admin_session, tenant.id, hostname="fleet-dev-1")
@@ -279,9 +265,7 @@ class TestWirelessMetrics:
):
"""GET wireless metrics for device with no data returns 200 + empty list."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
device = await create_test_device(admin_session, tenant.id)
await admin_session.commit()
@@ -308,9 +292,7 @@ class TestWirelessMetrics:
):
"""GET wireless latest for device with no data returns 200 + empty list."""
tenant = await create_test_tenant(admin_session)
auth = await auth_headers_factory(
admin_session, existing_tenant_id=tenant.id
)
auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id)
tenant_id = auth["tenant_id"]
device = await create_test_device(admin_session, tenant.id)
await admin_session.commit()

View File

@@ -28,7 +28,7 @@ from app.services.auth import hash_password
pytestmark = pytest.mark.integration
# Use the same test DB URLs as conftest
from tests.integration.conftest import TEST_APP_USER_DATABASE_URL, TEST_DATABASE_URL
from tests.integration.conftest import TEST_APP_USER_DATABASE_URL, TEST_DATABASE_URL # noqa: E402
# ---------------------------------------------------------------------------
@@ -86,12 +86,20 @@ async def test_tenant_a_cannot_see_tenant_b_devices():
await session.flush()
da = Device(
tenant_id=ta.id, hostname=f"rls-ra-{uid}", ip_address="10.1.1.1",
api_port=8728, api_ssl_port=8729, status="online",
tenant_id=ta.id,
hostname=f"rls-ra-{uid}",
ip_address="10.1.1.1",
api_port=8728,
api_ssl_port=8729,
status="online",
)
db = Device(
tenant_id=tb.id, hostname=f"rls-rb-{uid}", ip_address="10.1.1.2",
api_port=8728, api_ssl_port=8729, status="online",
tenant_id=tb.id,
hostname=f"rls-rb-{uid}",
ip_address="10.1.1.2",
api_port=8728,
api_ssl_port=8729,
status="online",
)
session.add_all([da, db])
await session.flush()
@@ -129,12 +137,20 @@ async def test_tenant_a_cannot_see_tenant_b_alerts():
await session.flush()
ra = AlertRule(
tenant_id=ta.id, name=f"CPU Alert A {uid}",
metric="cpu_load", operator=">", threshold=90.0, severity="warning",
tenant_id=ta.id,
name=f"CPU Alert A {uid}",
metric="cpu_load",
operator=">",
threshold=90.0,
severity="warning",
)
rb = AlertRule(
tenant_id=tb.id, name=f"CPU Alert B {uid}",
metric="cpu_load", operator=">", threshold=85.0, severity="critical",
tenant_id=tb.id,
name=f"CPU Alert B {uid}",
metric="cpu_load",
operator=">",
threshold=85.0,
severity="critical",
)
session.add_all([ra, rb])
await session.flush()
@@ -256,12 +272,20 @@ async def test_super_admin_sees_all_tenants():
await session.flush()
da = Device(
tenant_id=ta.id, hostname=f"sa-ra-{uid}", ip_address="10.2.1.1",
api_port=8728, api_ssl_port=8729, status="online",
tenant_id=ta.id,
hostname=f"sa-ra-{uid}",
ip_address="10.2.1.1",
api_port=8728,
api_ssl_port=8729,
status="online",
)
db = Device(
tenant_id=tb.id, hostname=f"sa-rb-{uid}", ip_address="10.2.1.2",
api_port=8728, api_ssl_port=8729, status="online",
tenant_id=tb.id,
hostname=f"sa-rb-{uid}",
ip_address="10.2.1.2",
api_port=8728,
api_ssl_port=8729,
status="online",
)
session.add_all([da, db])
await session.flush()
@@ -311,31 +335,45 @@ async def test_api_rls_isolation_devices_endpoint(client, admin_engine):
ua = User(
email=f"api-ua-{uid}@example.com",
hashed_password=hash_password("TestPass123!"),
name="User A", role="tenant_admin",
tenant_id=ta.id, is_active=True,
name="User A",
role="tenant_admin",
tenant_id=ta.id,
is_active=True,
)
ub = User(
email=f"api-ub-{uid}@example.com",
hashed_password=hash_password("TestPass123!"),
name="User B", role="tenant_admin",
tenant_id=tb.id, is_active=True,
name="User B",
role="tenant_admin",
tenant_id=tb.id,
is_active=True,
)
session.add_all([ua, ub])
await session.flush()
da = Device(
tenant_id=ta.id, hostname=f"api-ra-{uid}", ip_address="10.3.1.1",
api_port=8728, api_ssl_port=8729, status="online",
tenant_id=ta.id,
hostname=f"api-ra-{uid}",
ip_address="10.3.1.1",
api_port=8728,
api_ssl_port=8729,
status="online",
)
db = Device(
tenant_id=tb.id, hostname=f"api-rb-{uid}", ip_address="10.3.1.2",
api_port=8728, api_ssl_port=8729, status="online",
tenant_id=tb.id,
hostname=f"api-rb-{uid}",
ip_address="10.3.1.2",
api_port=8728,
api_ssl_port=8729,
status="online",
)
session.add_all([da, db])
await session.flush()
return {
"ta_id": str(ta.id), "tb_id": str(tb.id),
"ua_email": ua.email, "ub_email": ub.email,
"ta_id": str(ta.id),
"tb_id": str(tb.id),
"ua_email": ua.email,
"ub_email": ub.email,
}
ids = await _admin_commit(TEST_DATABASE_URL, setup)
@@ -398,21 +436,29 @@ async def test_api_rls_isolation_cross_tenant_device_access(client, admin_engine
ua = User(
email=f"api-xt-ua-{uid}@example.com",
hashed_password=hash_password("TestPass123!"),
name="User A", role="tenant_admin",
tenant_id=ta.id, is_active=True,
name="User A",
role="tenant_admin",
tenant_id=ta.id,
is_active=True,
)
session.add(ua)
await session.flush()
db = Device(
tenant_id=tb.id, hostname=f"api-xt-rb-{uid}", ip_address="10.4.1.1",
api_port=8728, api_ssl_port=8729, status="online",
tenant_id=tb.id,
hostname=f"api-xt-rb-{uid}",
ip_address="10.4.1.1",
api_port=8728,
api_ssl_port=8729,
status="online",
)
session.add(db)
await session.flush()
return {
"ta_id": str(ta.id), "tb_id": str(tb.id),
"ua_email": ua.email, "db_id": str(db.id),
"ta_id": str(ta.id),
"tb_id": str(tb.id),
"ua_email": ua.email,
"db_id": str(db.id),
}
ids = await _admin_commit(TEST_DATABASE_URL, setup)

View File

@@ -5,22 +5,16 @@ tenant deletion cleanup, and allowed-IPs validation.
"""
import os
import uuid
from unittest.mock import AsyncMock, patch
import pytest
import pytest_asyncio
from sqlalchemy import select, text
from sqlalchemy import select
from app.models.vpn import VpnConfig, VpnPeer
from app.services.vpn_service import (
add_peer,
get_peer_config,
get_vpn_config,
remove_peer,
setup_vpn,
sync_wireguard_config,
_get_wg_config_path,
)
pytestmark = pytest.mark.integration
@@ -213,9 +207,8 @@ class TestTenantDeletion:
# Delete tenant 2
from app.models.tenant import Tenant
result = await admin_session.execute(
select(Tenant).where(Tenant.id == t2.id)
)
result = await admin_session.execute(select(Tenant).where(Tenant.id == t2.id))
tenant_obj = result.scalar_one()
await admin_session.delete(tenant_obj)
await admin_session.flush()

View File

@@ -62,24 +62,31 @@ async def test_snapshot_created_audit_event():
mock_log_action = AsyncMock()
with patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
return_value=mock_ctx,
), patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=mock_openbao,
), patch(
"app.services.config_snapshot_subscriber.generate_and_store_diff",
new_callable=AsyncMock,
), patch(
"app.services.config_snapshot_subscriber.log_action",
mock_log_action,
with (
patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
return_value=mock_ctx,
),
patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=mock_openbao,
),
patch(
"app.services.config_snapshot_subscriber.generate_and_store_diff",
new_callable=AsyncMock,
),
patch(
"app.services.config_snapshot_subscriber.log_action",
mock_log_action,
),
):
await handle_config_snapshot(msg)
# log_action should have been called with config_snapshot_created
actions = [call.kwargs.get("action", call.args[4] if len(call.args) > 4 else None)
for call in mock_log_action.call_args_list]
actions = [
call.kwargs.get("action", call.args[4] if len(call.args) > 4 else None)
for call in mock_log_action.call_args_list
]
assert "config_snapshot_created" in actions
@@ -103,21 +110,27 @@ async def test_snapshot_skipped_duplicate_audit_event():
mock_log_action = AsyncMock()
with patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
return_value=mock_ctx,
), patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=AsyncMock(),
), patch(
"app.services.config_snapshot_subscriber.log_action",
mock_log_action,
with (
patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
return_value=mock_ctx,
),
patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=AsyncMock(),
),
patch(
"app.services.config_snapshot_subscriber.log_action",
mock_log_action,
),
):
await handle_config_snapshot(msg)
# log_action should have been called with config_snapshot_skipped_duplicate
actions = [call.kwargs.get("action", call.args[4] if len(call.args) > 4 else None)
for call in mock_log_action.call_args_list]
actions = [
call.kwargs.get("action", call.args[4] if len(call.args) > 4 else None)
for call in mock_log_action.call_args_list
]
assert "config_snapshot_skipped_duplicate" in actions
@@ -150,22 +163,28 @@ async def test_diff_generated_audit_event():
mock_session.commit = AsyncMock()
mock_openbao = AsyncMock()
mock_openbao.decrypt = AsyncMock(side_effect=[
old_config.encode("utf-8"),
new_config.encode("utf-8"),
])
mock_openbao.decrypt = AsyncMock(
side_effect=[
old_config.encode("utf-8"),
new_config.encode("utf-8"),
]
)
mock_log_action = AsyncMock()
with patch(
"app.services.config_diff_service.OpenBaoTransitService",
return_value=mock_openbao,
), patch(
"app.services.config_diff_service.parse_diff_changes",
return_value=[],
), patch(
"app.services.audit_service.log_action",
mock_log_action,
with (
patch(
"app.services.config_diff_service.OpenBaoTransitService",
return_value=mock_openbao,
),
patch(
"app.services.config_diff_service.parse_diff_changes",
return_value=[],
),
patch(
"app.services.audit_service.log_action",
mock_log_action,
),
):
await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session)
@@ -217,13 +236,21 @@ async def test_manual_trigger_audit_event():
original_enabled = limiter.enabled
limiter.enabled = False
try:
with patch.object(
cb_module, "_get_nats", return_value=mock_nc,
), patch.object(
cb_module, "_check_tenant_access", new_callable=AsyncMock,
), patch(
"app.services.audit_service.log_action",
mock_log_action,
with (
patch.object(
cb_module,
"_get_nats",
return_value=mock_nc,
),
patch.object(
cb_module,
"_check_tenant_access",
new_callable=AsyncMock,
),
patch(
"app.services.audit_service.log_action",
mock_log_action,
),
):
result = await cb_module.trigger_config_snapshot(
request=mock_request,

View File

@@ -1,7 +1,7 @@
"""Tests for dynamic backup scheduling."""
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from unittest.mock import MagicMock
from app.services.backup_scheduler import (
build_schedule_map,

View File

@@ -4,8 +4,6 @@ Tests the parse_diff_changes function that extracts structured RouterOS
component changes from unified diffs.
"""
import pytest
from app.services.config_change_parser import parse_diff_changes

View File

@@ -1,8 +1,7 @@
"""Tests for config change NATS subscriber."""
import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch, MagicMock
from unittest.mock import AsyncMock, patch
from uuid import uuid4
from app.services.config_change_subscriber import handle_config_changed
@@ -18,13 +17,16 @@ async def test_triggers_backup_on_config_change():
"new_timestamp": "2026-03-07 12:00:00",
}
with patch(
"app.services.config_change_subscriber.backup_service.run_backup",
new_callable=AsyncMock,
) as mock_backup, patch(
"app.services.config_change_subscriber._last_backup_within_dedup_window",
new_callable=AsyncMock,
return_value=False,
with (
patch(
"app.services.config_change_subscriber.backup_service.run_backup",
new_callable=AsyncMock,
) as mock_backup,
patch(
"app.services.config_change_subscriber._last_backup_within_dedup_window",
new_callable=AsyncMock,
return_value=False,
),
):
await handle_config_changed(event)
@@ -42,13 +44,16 @@ async def test_skips_backup_within_dedup_window():
"new_timestamp": "2026-03-07 12:00:00",
}
with patch(
"app.services.config_change_subscriber.backup_service.run_backup",
new_callable=AsyncMock,
) as mock_backup, patch(
"app.services.config_change_subscriber._last_backup_within_dedup_window",
new_callable=AsyncMock,
return_value=True,
with (
patch(
"app.services.config_change_subscriber.backup_service.run_backup",
new_callable=AsyncMock,
) as mock_backup,
patch(
"app.services.config_change_subscriber._last_backup_within_dedup_window",
new_callable=AsyncMock,
return_value=True,
),
):
await handle_config_changed(event)

View File

@@ -13,9 +13,7 @@ class TestCheckpointEndpointExists:
from app.routers.config_backups import router
paths = [r.path for r in router.routes]
assert any("checkpoint" in p for p in paths), (
f"No checkpoint route found. Routes: {paths}"
)
assert any("checkpoint" in p for p in paths), f"No checkpoint route found. Routes: {paths}"
def test_checkpoint_route_is_post(self):
from app.routers.config_backups import router
@@ -53,16 +51,20 @@ class TestCheckpointFunction:
mock_request = MagicMock()
with patch(
"app.routers.config_backups.backup_service.run_backup",
new_callable=AsyncMock,
return_value=mock_result,
) as mock_backup, patch(
"app.routers.config_backups._check_tenant_access",
new_callable=AsyncMock,
), patch(
"app.routers.config_backups.limiter.enabled",
False,
with (
patch(
"app.routers.config_backups.backup_service.run_backup",
new_callable=AsyncMock,
return_value=mock_result,
) as mock_backup,
patch(
"app.routers.config_backups._check_tenant_access",
new_callable=AsyncMock,
),
patch(
"app.routers.config_backups.limiter.enabled",
False,
),
):
result = await create_checkpoint(
request=mock_request,

View File

@@ -4,9 +4,8 @@ Tests the generate_and_store_diff function with mocked DB sessions
and OpenBao Transit service.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, call
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
@@ -51,17 +50,22 @@ async def test_diff_generated_and_stored():
mock_session.commit = AsyncMock()
mock_openbao = AsyncMock()
mock_openbao.decrypt = AsyncMock(side_effect=[
old_config.encode("utf-8"),
new_config.encode("utf-8"),
])
mock_openbao.decrypt = AsyncMock(
side_effect=[
old_config.encode("utf-8"),
new_config.encode("utf-8"),
]
)
with patch(
"app.services.config_diff_service.OpenBaoTransitService",
return_value=mock_openbao,
), patch(
"app.services.config_diff_service.parse_diff_changes",
return_value=[],
with (
patch(
"app.services.config_diff_service.OpenBaoTransitService",
return_value=mock_openbao,
),
patch(
"app.services.config_diff_service.parse_diff_changes",
return_value=[],
),
):
await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session)
@@ -178,17 +182,22 @@ async def test_line_counts_correct():
mock_session.commit = AsyncMock()
mock_openbao = AsyncMock()
mock_openbao.decrypt = AsyncMock(side_effect=[
old_config.encode("utf-8"),
new_config.encode("utf-8"),
])
mock_openbao.decrypt = AsyncMock(
side_effect=[
old_config.encode("utf-8"),
new_config.encode("utf-8"),
]
)
with patch(
"app.services.config_diff_service.OpenBaoTransitService",
return_value=mock_openbao,
), patch(
"app.services.config_diff_service.parse_diff_changes",
return_value=[],
with (
patch(
"app.services.config_diff_service.OpenBaoTransitService",
return_value=mock_openbao,
),
patch(
"app.services.config_diff_service.parse_diff_changes",
return_value=[],
),
):
await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session)
@@ -222,10 +231,12 @@ async def test_empty_diff_skips_insert():
mock_session.commit = AsyncMock()
mock_openbao = AsyncMock()
mock_openbao.decrypt = AsyncMock(side_effect=[
same_config.encode("utf-8"),
same_config.encode("utf-8"),
])
mock_openbao.decrypt = AsyncMock(
side_effect=[
same_config.encode("utf-8"),
same_config.encode("utf-8"),
]
)
with patch(
"app.services.config_diff_service.OpenBaoTransitService",
@@ -271,22 +282,31 @@ async def test_change_parser_called_and_changes_stored():
mock_session.commit = AsyncMock()
mock_openbao = AsyncMock()
mock_openbao.decrypt = AsyncMock(side_effect=[
old_config.encode("utf-8"),
new_config.encode("utf-8"),
])
mock_openbao.decrypt = AsyncMock(
side_effect=[
old_config.encode("utf-8"),
new_config.encode("utf-8"),
]
)
mock_changes = [
{"component": "ip/firewall/filter", "summary": "Added 1 firewall filter rule", "raw_line": "+add chain=forward action=drop"},
{
"component": "ip/firewall/filter",
"summary": "Added 1 firewall filter rule",
"raw_line": "+add chain=forward action=drop",
},
]
with patch(
"app.services.config_diff_service.OpenBaoTransitService",
return_value=mock_openbao,
), patch(
"app.services.config_diff_service.parse_diff_changes",
return_value=mock_changes,
) as mock_parser:
with (
patch(
"app.services.config_diff_service.OpenBaoTransitService",
return_value=mock_openbao,
),
patch(
"app.services.config_diff_service.parse_diff_changes",
return_value=mock_changes,
) as mock_parser,
):
await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session)
# parse_diff_changes called with the diff text
@@ -332,17 +352,22 @@ async def test_change_parser_error_does_not_block_diff():
mock_session.commit = AsyncMock()
mock_openbao = AsyncMock()
mock_openbao.decrypt = AsyncMock(side_effect=[
old_config.encode("utf-8"),
new_config.encode("utf-8"),
])
mock_openbao.decrypt = AsyncMock(
side_effect=[
old_config.encode("utf-8"),
new_config.encode("utf-8"),
]
)
with patch(
"app.services.config_diff_service.OpenBaoTransitService",
return_value=mock_openbao,
), patch(
"app.services.config_diff_service.parse_diff_changes",
side_effect=Exception("Parser exploded"),
with (
patch(
"app.services.config_diff_service.OpenBaoTransitService",
return_value=mock_openbao,
),
patch(
"app.services.config_diff_service.parse_diff_changes",
side_effect=Exception("Parser exploded"),
),
):
# Should NOT raise
await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session)

View File

@@ -10,7 +10,9 @@ from uuid import uuid4
from datetime import datetime, timezone
def _make_change_row(change_id, component, summary, created_at, diff_id, lines_added, lines_removed, snapshot_id):
def _make_change_row(
change_id, component, summary, created_at, diff_id, lines_added, lines_removed, snapshot_id
):
"""Create a mock row matching the JOIN query result."""
row = MagicMock()
row._mapping = {
@@ -41,7 +43,9 @@ async def test_returns_formatted_entries():
mock_session = AsyncMock()
result_mock = MagicMock()
result_mock.fetchall.return_value = [
_make_change_row(change_id, "ip/firewall/filter", "Added 1 rule", ts, diff_id, 3, 1, snapshot_id),
_make_change_row(
change_id, "ip/firewall/filter", "Added 1 rule", ts, diff_id, 3, 1, snapshot_id
),
]
mock_session.execute = AsyncMock(return_value=result_mock)

View File

@@ -56,15 +56,19 @@ async def test_new_snapshot_encrypted_and_stored():
mock_openbao = AsyncMock()
mock_openbao.encrypt.return_value = "vault:v1:encrypted_data"
with patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
return_value=mock_ctx,
), patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=mock_openbao,
), patch(
"app.services.config_snapshot_subscriber.generate_and_store_diff",
new_callable=AsyncMock,
with (
patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
return_value=mock_ctx,
),
patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=mock_openbao,
),
patch(
"app.services.config_snapshot_subscriber.generate_and_store_diff",
new_callable=AsyncMock,
),
):
await handle_config_snapshot(msg)
@@ -102,12 +106,15 @@ async def test_duplicate_snapshot_skipped():
mock_openbao = AsyncMock()
with patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
return_value=mock_ctx,
), patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=mock_openbao,
with (
patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
return_value=mock_ctx,
),
patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=mock_openbao,
),
):
await handle_config_snapshot(msg)
@@ -141,12 +148,15 @@ async def test_transit_encrypt_failure_causes_nak():
mock_openbao = AsyncMock()
mock_openbao.encrypt.side_effect = Exception("Transit unavailable")
with patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
return_value=mock_ctx,
), patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=mock_openbao,
with (
patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
return_value=mock_ctx,
),
patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=mock_openbao,
),
):
await handle_config_snapshot(msg)
@@ -168,11 +178,14 @@ async def test_malformed_message_acked_and_discarded():
mock_openbao = AsyncMock()
with patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
) as mock_session_cls, patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=mock_openbao,
with (
patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
),
patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=mock_openbao,
),
):
await handle_config_snapshot(msg)
@@ -209,12 +222,15 @@ async def test_orphan_device_acked_and_discarded():
mock_openbao = AsyncMock()
mock_openbao.encrypt.return_value = "vault:v1:encrypted_data"
with patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
return_value=mock_ctx,
), patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=mock_openbao,
with (
patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
return_value=mock_ctx,
),
patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=mock_openbao,
),
):
await handle_config_snapshot(msg)
@@ -245,15 +261,19 @@ async def test_first_snapshot_for_device_always_stored():
mock_openbao = AsyncMock()
mock_openbao.encrypt.return_value = "vault:v1:first_snapshot_encrypted"
with patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
return_value=mock_ctx,
), patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=mock_openbao,
), patch(
"app.services.config_snapshot_subscriber.generate_and_store_diff",
new_callable=AsyncMock,
with (
patch(
"app.services.config_snapshot_subscriber.AdminAsyncSessionLocal",
return_value=mock_ctx,
),
patch(
"app.services.config_snapshot_subscriber.OpenBaoTransitService",
return_value=mock_openbao,
),
patch(
"app.services.config_snapshot_subscriber.generate_and_store_diff",
new_callable=AsyncMock,
),
):
await handle_config_snapshot(msg)

View File

@@ -15,7 +15,6 @@ from unittest.mock import AsyncMock, MagicMock
import nats.errors
from fastapi import HTTPException, status
from sqlalchemy import select
# ---------------------------------------------------------------------------
@@ -118,11 +117,13 @@ def _mock_db(device_exists: bool):
async def test_trigger_success_returns_201():
"""POST with operator role returns 201 with status and sha256_hash."""
sha256 = "b" * 64
nc = _mock_nats_reply({
"status": "success",
"sha256_hash": sha256,
"message": "Config snapshot collected",
})
nc = _mock_nats_reply(
{
"status": "success",
"sha256_hash": sha256,
"message": "Config snapshot collected",
}
)
db = _mock_db(device_exists=True)
result = await _simulate_trigger(nats_conn=nc, db_session=db)
@@ -156,10 +157,12 @@ async def test_trigger_nats_timeout_returns_504():
@pytest.mark.asyncio
async def test_trigger_poller_failure_returns_502():
"""Poller failure reply returns 502."""
nc = _mock_nats_reply({
"status": "failed",
"error": "SSH connection refused",
})
nc = _mock_nats_reply(
{
"status": "failed",
"error": "SSH connection refused",
}
)
db = _mock_db(device_exists=True)
with pytest.raises(HTTPException) as exc_info:
@@ -184,10 +187,12 @@ async def test_trigger_device_not_found_returns_404():
@pytest.mark.asyncio
async def test_trigger_locked_returns_409():
"""Lock contention returns 409 Conflict."""
nc = _mock_nats_reply({
"status": "locked",
"message": "backup already in progress",
})
nc = _mock_nats_reply(
{
"status": "locked",
"message": "backup already in progress",
}
)
db = _mock_db(device_exists=True)
with pytest.raises(HTTPException) as exc_info:

View File

@@ -35,26 +35,33 @@ async def test_recovery_commits_reachable_device_with_scheduler():
dev_result.scalar_one_or_none.return_value = device
mock_session.execute = AsyncMock(side_effect=[mock_result, dev_result])
with patch(
"app.services.restore_service._check_reachability",
new_callable=AsyncMock,
return_value=True,
), patch(
"app.services.restore_service._remove_panic_scheduler",
new_callable=AsyncMock,
return_value=True,
), patch(
"app.services.restore_service._update_push_op_status",
new_callable=AsyncMock,
) as mock_update, patch(
"app.services.restore_service._publish_push_progress",
new_callable=AsyncMock,
), patch(
"app.services.crypto.decrypt_credentials_hybrid",
new_callable=AsyncMock,
return_value='{"username": "admin", "password": "test123"}',
), patch(
"app.services.restore_service.settings",
with (
patch(
"app.services.restore_service._check_reachability",
new_callable=AsyncMock,
return_value=True,
),
patch(
"app.services.restore_service._remove_panic_scheduler",
new_callable=AsyncMock,
return_value=True,
),
patch(
"app.services.restore_service._update_push_op_status",
new_callable=AsyncMock,
) as mock_update,
patch(
"app.services.restore_service._publish_push_progress",
new_callable=AsyncMock,
),
patch(
"app.services.crypto.decrypt_credentials_hybrid",
new_callable=AsyncMock,
return_value='{"username": "admin", "password": "test123"}',
),
patch(
"app.services.restore_service.settings",
),
):
await recover_stale_push_operations(mock_session)
@@ -84,22 +91,28 @@ async def test_recovery_marks_unreachable_device_failed():
dev_result.scalar_one_or_none.return_value = device
mock_session.execute = AsyncMock(side_effect=[mock_result, dev_result])
with patch(
"app.services.restore_service._check_reachability",
new_callable=AsyncMock,
return_value=False,
), patch(
"app.services.restore_service._update_push_op_status",
new_callable=AsyncMock,
) as mock_update, patch(
"app.services.restore_service._publish_push_progress",
new_callable=AsyncMock,
), patch(
"app.services.crypto.decrypt_credentials_hybrid",
new_callable=AsyncMock,
return_value='{"username": "admin", "password": "test123"}',
), patch(
"app.services.restore_service.settings",
with (
patch(
"app.services.restore_service._check_reachability",
new_callable=AsyncMock,
return_value=False,
),
patch(
"app.services.restore_service._update_push_op_status",
new_callable=AsyncMock,
) as mock_update,
patch(
"app.services.restore_service._publish_push_progress",
new_callable=AsyncMock,
),
patch(
"app.services.crypto.decrypt_credentials_hybrid",
new_callable=AsyncMock,
return_value='{"username": "admin", "password": "test123"}',
),
patch(
"app.services.restore_service.settings",
),
):
await recover_stale_push_operations(mock_session)

View File

@@ -1,7 +1,7 @@
"""Tests for push rollback NATS subscriber."""
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from unittest.mock import AsyncMock, patch
from uuid import uuid4
from app.services.push_rollback_subscriber import (

View File

@@ -60,25 +60,32 @@ class TestPreviewRestoreFunction:
mock_scalar.scalar_one_or_none.return_value = mock_device
mock_db.execute.return_value = mock_scalar
with patch(
"app.routers.config_backups._check_tenant_access",
new_callable=AsyncMock,
), patch(
"app.routers.config_backups.limiter.enabled",
False,
), patch(
"app.routers.config_backups.git_store.read_file",
return_value=target_export.encode(),
), patch(
"app.routers.config_backups.backup_service.capture_export",
new_callable=AsyncMock,
return_value=current_export,
), patch(
"app.routers.config_backups.decrypt_credentials_hybrid",
new_callable=AsyncMock,
return_value='{"username": "admin", "password": "pass"}',
), patch(
"app.routers.config_backups.settings",
with (
patch(
"app.routers.config_backups._check_tenant_access",
new_callable=AsyncMock,
),
patch(
"app.routers.config_backups.limiter.enabled",
False,
),
patch(
"app.routers.config_backups.git_store.read_file",
return_value=target_export.encode(),
),
patch(
"app.routers.config_backups.backup_service.capture_export",
new_callable=AsyncMock,
return_value=current_export,
),
patch(
"app.routers.config_backups.decrypt_credentials_hybrid",
new_callable=AsyncMock,
return_value='{"username": "admin", "password": "pass"}',
),
patch(
"app.routers.config_backups.settings",
),
):
result = await preview_restore(
request=mock_request,
@@ -140,25 +147,32 @@ class TestPreviewRestoreFunction:
return current_export.encode()
return b""
with patch(
"app.routers.config_backups._check_tenant_access",
new_callable=AsyncMock,
), patch(
"app.routers.config_backups.limiter.enabled",
False,
), patch(
"app.routers.config_backups.git_store.read_file",
side_effect=mock_read_file,
), patch(
"app.routers.config_backups.backup_service.capture_export",
new_callable=AsyncMock,
side_effect=ConnectionError("Device unreachable"),
), patch(
"app.routers.config_backups.decrypt_credentials_hybrid",
new_callable=AsyncMock,
return_value='{"username": "admin", "password": "pass"}',
), patch(
"app.routers.config_backups.settings",
with (
patch(
"app.routers.config_backups._check_tenant_access",
new_callable=AsyncMock,
),
patch(
"app.routers.config_backups.limiter.enabled",
False,
),
patch(
"app.routers.config_backups.git_store.read_file",
side_effect=mock_read_file,
),
patch(
"app.routers.config_backups.backup_service.capture_export",
new_callable=AsyncMock,
side_effect=ConnectionError("Device unreachable"),
),
patch(
"app.routers.config_backups.decrypt_credentials_hybrid",
new_callable=AsyncMock,
return_value='{"username": "admin", "password": "pass"}',
),
patch(
"app.routers.config_backups.settings",
),
):
result = await preview_restore(
request=mock_request,
@@ -188,15 +202,19 @@ class TestPreviewRestoreFunction:
mock_request = MagicMock()
body = RestoreRequest(commit_sha="nonexistent")
with patch(
"app.routers.config_backups._check_tenant_access",
new_callable=AsyncMock,
), patch(
"app.routers.config_backups.limiter.enabled",
False,
), patch(
"app.routers.config_backups.git_store.read_file",
side_effect=KeyError("not found"),
with (
patch(
"app.routers.config_backups._check_tenant_access",
new_callable=AsyncMock,
),
patch(
"app.routers.config_backups.limiter.enabled",
False,
),
patch(
"app.routers.config_backups.git_store.read_file",
side_effect=KeyError("not found"),
),
):
with pytest.raises(HTTPException) as exc_info:
await preview_restore(

View File

@@ -23,12 +23,15 @@ async def test_cleanup_deletes_expired_snapshots():
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"app.services.retention_service.AdminAsyncSessionLocal",
return_value=mock_ctx,
), patch(
"app.services.retention_service.settings",
) as mock_settings:
with (
patch(
"app.services.retention_service.AdminAsyncSessionLocal",
return_value=mock_ctx,
),
patch(
"app.services.retention_service.settings",
) as mock_settings,
):
mock_settings.CONFIG_RETENTION_DAYS = 90
count = await cleanup_expired_snapshots()
@@ -60,12 +63,15 @@ async def test_cleanup_keeps_snapshots_within_retention_window():
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"app.services.retention_service.AdminAsyncSessionLocal",
return_value=mock_ctx,
), patch(
"app.services.retention_service.settings",
) as mock_settings:
with (
patch(
"app.services.retention_service.AdminAsyncSessionLocal",
return_value=mock_ctx,
),
patch(
"app.services.retention_service.settings",
) as mock_settings,
):
mock_settings.CONFIG_RETENTION_DAYS = 90
count = await cleanup_expired_snapshots()
@@ -87,12 +93,15 @@ async def test_cleanup_returns_deleted_count():
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"app.services.retention_service.AdminAsyncSessionLocal",
return_value=mock_ctx,
), patch(
"app.services.retention_service.settings",
) as mock_settings:
with (
patch(
"app.services.retention_service.AdminAsyncSessionLocal",
return_value=mock_ctx,
),
patch(
"app.services.retention_service.settings",
) as mock_settings,
):
mock_settings.CONFIG_RETENTION_DAYS = 30
count = await cleanup_expired_snapshots()
@@ -114,12 +123,15 @@ async def test_cleanup_handles_empty_table():
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"app.services.retention_service.AdminAsyncSessionLocal",
return_value=mock_ctx,
), patch(
"app.services.retention_service.settings",
) as mock_settings:
with (
patch(
"app.services.retention_service.AdminAsyncSessionLocal",
return_value=mock_ctx,
),
patch(
"app.services.retention_service.settings",
) as mock_settings,
):
mock_settings.CONFIG_RETENTION_DAYS = 90
count = await cleanup_expired_snapshots()

View File

@@ -1,6 +1,5 @@
"""Tests for RouterOS RSC export parser."""
import pytest
from app.services.rsc_parser import parse_rsc, validate_rsc, compute_impact
@@ -74,7 +73,7 @@ class TestValidateRsc:
assert any("quote" in e.lower() for e in result["errors"])
def test_truncated_continuation_detected(self):
bad = '/ip address\nadd address=192.168.1.1/24 \\\n'
bad = "/ip address\nadd address=192.168.1.1/24 \\\n"
result = validate_rsc(bad)
assert result["valid"] is False
assert any("truncat" in e.lower() or "continuation" in e.lower() for e in result["errors"])
@@ -82,25 +81,25 @@ class TestValidateRsc:
class TestComputeImpact:
def test_high_risk_for_firewall_input(self):
current = '/ip firewall filter\nadd action=accept chain=input\n'
target = '/ip firewall filter\nadd action=drop chain=input\n'
current = "/ip firewall filter\nadd action=accept chain=input\n"
target = "/ip firewall filter\nadd action=drop chain=input\n"
result = compute_impact(parse_rsc(current), parse_rsc(target))
assert any(c["risk"] == "high" for c in result["categories"])
def test_high_risk_for_ip_address_changes(self):
current = '/ip address\nadd address=192.168.1.1/24 interface=ether1\n'
target = '/ip address\nadd address=10.0.0.1/24 interface=ether1\n'
current = "/ip address\nadd address=192.168.1.1/24 interface=ether1\n"
target = "/ip address\nadd address=10.0.0.1/24 interface=ether1\n"
result = compute_impact(parse_rsc(current), parse_rsc(target))
ip_cat = next(c for c in result["categories"] if c["path"] == "/ip address")
assert ip_cat["risk"] in ("high", "medium")
def test_warnings_for_management_access(self):
current = ""
target = '/ip firewall filter\nadd action=drop chain=input protocol=tcp dst-port=22\n'
target = "/ip firewall filter\nadd action=drop chain=input protocol=tcp dst-port=22\n"
result = compute_impact(parse_rsc(current), parse_rsc(target))
assert len(result["warnings"]) > 0
def test_no_changes_no_warnings(self):
same = '/ip dns\nset servers=8.8.8.8\n'
same = "/ip dns\nset servers=8.8.8.8\n"
result = compute_impact(parse_rsc(same), parse_rsc(same))
assert result["warnings"] == [] or all(c["risk"] == "none" for c in result["categories"])

View File

@@ -32,7 +32,7 @@ def test_srp_roundtrip():
context = SRPContext(EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN)
username, verifier, salt = context.get_user_data_triplet()
print(f"\n--- SRP Interop Reference Values ---")
print("\n--- SRP Interop Reference Values ---")
print(f"email (I): {EMAIL}")
print(f"salt (s): {salt}")
print(f"verifier (v): {verifier[:64]}... (len={len(verifier)})")
@@ -45,7 +45,9 @@ def test_srp_roundtrip():
print(f"server_public (B): {server_public[:64]}... (len={len(server_public)})")
# Step 3: Client init -- generate A (client needs password for proof)
client_context = SRPContext(EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN)
client_context = SRPContext(
EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN
)
client_session = SRPClientSession(client_context)
client_public = client_session.public
@@ -78,7 +80,7 @@ def test_srp_roundtrip():
)
print(f"session_key (K): {client_session.key[:64]}... (len={len(client_session.key)})")
print(f"--- Handshake PASSED ---\n")
print("--- Handshake PASSED ---\n")
def test_srp_bad_proof_rejected():
@@ -89,7 +91,9 @@ def test_srp_bad_proof_rejected():
server_context = SRPContext(EMAIL, prime=PRIME_2048, generator=PRIME_2048_GEN)
server_session = SRPServerSession(server_context, verifier)
client_context = SRPContext(EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN)
client_context = SRPContext(
EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN
)
client_session = SRPClientSession(client_context)
client_session.process(server_session.public, salt)

View File

@@ -8,7 +8,7 @@ Tests cover:
"""
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock
import pytest
@@ -18,15 +18,24 @@ class TestAuditLogModel:
def test_model_importable(self):
from app.models.audit_log import AuditLog
assert AuditLog.__tablename__ == "audit_logs"
def test_model_has_required_columns(self):
from app.models.audit_log import AuditLog
mapper = AuditLog.__table__.columns
expected_columns = {
"id", "tenant_id", "user_id", "action",
"resource_type", "resource_id", "device_id",
"details", "ip_address", "created_at",
"id",
"tenant_id",
"user_id",
"action",
"resource_type",
"resource_id",
"device_id",
"details",
"ip_address",
"created_at",
}
actual_columns = {c.name for c in mapper}
assert expected_columns.issubset(actual_columns), (
@@ -35,6 +44,7 @@ class TestAuditLogModel:
def test_model_exported_from_init(self):
from app.models import AuditLog
assert AuditLog.__tablename__ == "audit_logs"
@@ -43,6 +53,7 @@ class TestAuditService:
def test_log_action_importable(self):
from app.services.audit_service import log_action
assert callable(log_action)
@pytest.mark.asyncio
@@ -67,9 +78,11 @@ class TestAuditRouter:
def test_router_importable(self):
from app.routers.audit_logs import router
assert router is not None
def test_router_has_audit_logs_endpoint(self):
from app.routers.audit_logs import router
paths = [route.path for route in router.routes]
assert "/audit-logs" in paths or any("/audit-logs" in p for p in paths)

View File

@@ -11,7 +11,6 @@ These are pure function tests -- no database or async required.
import uuid
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
import pytest
from fastapi import HTTPException
@@ -83,9 +82,7 @@ class TestAccessToken:
assert payload["role"] == "super_admin"
def test_contains_expiry(self):
token = create_access_token(
user_id=uuid.uuid4(), tenant_id=uuid.uuid4(), role="viewer"
)
token = create_access_token(user_id=uuid.uuid4(), tenant_id=uuid.uuid4(), role="viewer")
payload = verify_token(token, expected_type="access")
assert "exp" in payload
assert "iat" in payload

View File

@@ -3,53 +3,65 @@
Verifies STOR-01 (table/column structure) and STOR-05 (config_text stores ciphertext).
"""
import uuid
from sqlalchemy import String, Text
from sqlalchemy.dialects.postgresql import UUID
def test_router_config_snapshot_importable():
"""RouterConfigSnapshot can be imported from app.models."""
from app.models import RouterConfigSnapshot
assert RouterConfigSnapshot is not None
def test_router_config_diff_importable():
"""RouterConfigDiff can be imported from app.models."""
from app.models import RouterConfigDiff
assert RouterConfigDiff is not None
def test_router_config_change_importable():
"""RouterConfigChange can be imported from app.models."""
from app.models import RouterConfigChange
assert RouterConfigChange is not None
def test_snapshot_tablename():
"""RouterConfigSnapshot.__tablename__ is correct."""
from app.models import RouterConfigSnapshot
assert RouterConfigSnapshot.__tablename__ == "router_config_snapshots"
def test_diff_tablename():
"""RouterConfigDiff.__tablename__ is correct."""
from app.models import RouterConfigDiff
assert RouterConfigDiff.__tablename__ == "router_config_diffs"
def test_change_tablename():
"""RouterConfigChange.__tablename__ is correct."""
from app.models import RouterConfigChange
assert RouterConfigChange.__tablename__ == "router_config_changes"
def test_snapshot_columns():
"""RouterConfigSnapshot has all required columns."""
from app.models import RouterConfigSnapshot
table = RouterConfigSnapshot.__table__
expected = {"id", "device_id", "tenant_id", "config_text", "sha256_hash", "collected_at", "created_at"}
expected = {
"id",
"device_id",
"tenant_id",
"config_text",
"sha256_hash",
"collected_at",
"created_at",
}
actual = {c.name for c in table.columns}
assert expected.issubset(actual), f"Missing columns: {expected - actual}"
@@ -57,10 +69,18 @@ def test_snapshot_columns():
def test_diff_columns():
"""RouterConfigDiff has all required columns."""
from app.models import RouterConfigDiff
table = RouterConfigDiff.__table__
expected = {
"id", "device_id", "tenant_id", "old_snapshot_id", "new_snapshot_id",
"diff_text", "lines_added", "lines_removed", "created_at",
"id",
"device_id",
"tenant_id",
"old_snapshot_id",
"new_snapshot_id",
"diff_text",
"lines_added",
"lines_removed",
"created_at",
}
actual = {c.name for c in table.columns}
assert expected.issubset(actual), f"Missing columns: {expected - actual}"
@@ -69,10 +89,17 @@ def test_diff_columns():
def test_change_columns():
"""RouterConfigChange has all required columns."""
from app.models import RouterConfigChange
table = RouterConfigChange.__table__
expected = {
"id", "diff_id", "device_id", "tenant_id",
"component", "summary", "raw_line", "created_at",
"id",
"diff_id",
"device_id",
"tenant_id",
"component",
"summary",
"raw_line",
"created_at",
}
actual = {c.name for c in table.columns}
assert expected.issubset(actual), f"Missing columns: {expected - actual}"
@@ -81,6 +108,7 @@ def test_change_columns():
def test_snapshot_config_text_is_text_type():
"""config_text column type is Text (documents Transit ciphertext contract)."""
from app.models import RouterConfigSnapshot
col = RouterConfigSnapshot.__table__.c.config_text
assert isinstance(col.type, Text), f"Expected Text, got {type(col.type)}"
@@ -88,6 +116,7 @@ def test_snapshot_config_text_is_text_type():
def test_snapshot_sha256_hash_is_string_64():
"""sha256_hash column type is String(64) for plaintext hash deduplication."""
from app.models import RouterConfigSnapshot
col = RouterConfigSnapshot.__table__.c.sha256_hash
assert isinstance(col.type, String), f"Expected String, got {type(col.type)}"
assert col.type.length == 64, f"Expected length 64, got {col.type.length}"

View File

@@ -7,11 +7,9 @@ Tests cover:
- Router registration in main app
"""
import uuid
from datetime import datetime, timezone, timedelta
import pytest
from pydantic import ValidationError
class TestMaintenanceWindowModel:
@@ -19,20 +17,31 @@ class TestMaintenanceWindowModel:
def test_model_importable(self):
from app.models.maintenance_window import MaintenanceWindow
assert MaintenanceWindow.__tablename__ == "maintenance_windows"
def test_model_exported_from_init(self):
from app.models import MaintenanceWindow
assert MaintenanceWindow.__tablename__ == "maintenance_windows"
def test_model_has_required_columns(self):
from app.models.maintenance_window import MaintenanceWindow
mapper = MaintenanceWindow.__mapper__
column_names = {c.key for c in mapper.columns}
expected = {
"id", "tenant_id", "name", "device_ids",
"start_at", "end_at", "suppress_alerts",
"notes", "created_by", "created_at", "updated_at",
"id",
"tenant_id",
"name",
"device_ids",
"start_at",
"end_at",
"suppress_alerts",
"notes",
"created_by",
"created_at",
"updated_at",
}
assert expected.issubset(column_names), f"Missing columns: {expected - column_names}"
@@ -42,6 +51,7 @@ class TestMaintenanceWindowSchemas:
def test_create_schema_valid(self):
from app.routers.maintenance_windows import MaintenanceWindowCreate
data = MaintenanceWindowCreate(
name="Nightly update",
device_ids=["abc-123"],
@@ -55,6 +65,7 @@ class TestMaintenanceWindowSchemas:
def test_create_schema_defaults(self):
from app.routers.maintenance_windows import MaintenanceWindowCreate
data = MaintenanceWindowCreate(
name="Quick reboot",
device_ids=[],
@@ -66,12 +77,14 @@ class TestMaintenanceWindowSchemas:
def test_update_schema_partial(self):
from app.routers.maintenance_windows import MaintenanceWindowUpdate
data = MaintenanceWindowUpdate(name="Updated name")
assert data.name == "Updated name"
assert data.device_ids is None # all optional
def test_response_schema(self):
from app.routers.maintenance_windows import MaintenanceWindowResponse
data = MaintenanceWindowResponse(
id="abc",
tenant_id="def",
@@ -92,10 +105,12 @@ class TestRouterRegistration:
def test_router_importable(self):
from app.routers.maintenance_windows import router
assert router is not None
def test_router_has_routes(self):
from app.routers.maintenance_windows import router
paths = [r.path for r in router.routes]
assert any("maintenance-windows" in p for p in paths)
@@ -114,8 +129,10 @@ class TestAlertEvaluatorMaintenance:
def test_maintenance_cache_exists(self):
from app.services import alert_evaluator
assert hasattr(alert_evaluator, "_maintenance_cache")
def test_is_device_in_maintenance_function_exists(self):
from app.services.alert_evaluator import _is_device_in_maintenance
assert callable(_is_device_in_maintenance)

View File

@@ -9,7 +9,6 @@ for startup validation, async only for middleware tests.
"""
from types import SimpleNamespace
from unittest.mock import patch
import pytest
@@ -114,7 +113,9 @@ class TestSecurityHeadersMiddleware:
response = await client.get("/test")
assert response.status_code == 200
assert response.headers["strict-transport-security"] == "max-age=31536000; includeSubDomains"
assert (
response.headers["strict-transport-security"] == "max-age=31536000; includeSubDomains"
)
assert response.headers["x-content-type-options"] == "nosniff"
assert response.headers["x-frame-options"] == "DENY"
assert response.headers["cache-control"] == "no-store"

View File

@@ -1,7 +1,10 @@
"""Unit tests for VPN subnet allocation and allowed-IPs validation."""
import pytest
from app.services.vpn_service import _allocate_subnet_index_from_used, _validate_additional_allowed_ips
from app.services.vpn_service import (
_allocate_subnet_index_from_used,
_validate_additional_allowed_ips,
)
class TestSubnetAllocation: