feat: The Other Dude v9.0.1 — full-featured email system
ci: add GitHub Pages deployment workflow for docs site Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
0
backend/tests/unit/__init__.py
Normal file
0
backend/tests/unit/__init__.py
Normal file
76
backend/tests/unit/test_api_key_service.py
Normal file
76
backend/tests/unit/test_api_key_service.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Unit tests for API key service.
|
||||
|
||||
Tests cover:
|
||||
- Key generation format (mktp_ prefix, sufficient length)
|
||||
- Key hashing (SHA-256 hex digest, 64 chars)
|
||||
- Scope validation against allowed list
|
||||
- Key prefix extraction
|
||||
|
||||
These are pure function tests -- no database or async required.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
|
||||
from app.services.api_key_service import (
|
||||
ALLOWED_SCOPES,
|
||||
generate_raw_key,
|
||||
hash_key,
|
||||
)
|
||||
|
||||
|
||||
class TestKeyGeneration:
|
||||
"""Tests for API key generation."""
|
||||
|
||||
def test_key_starts_with_prefix(self):
|
||||
key = generate_raw_key()
|
||||
assert key.startswith("mktp_")
|
||||
|
||||
def test_key_has_sufficient_length(self):
|
||||
"""Key should be mktp_ + at least 32 chars of randomness."""
|
||||
key = generate_raw_key()
|
||||
assert len(key) >= 37 # "mktp_" (5) + 32
|
||||
|
||||
def test_key_uniqueness(self):
|
||||
"""Two generated keys should never be the same."""
|
||||
key1 = generate_raw_key()
|
||||
key2 = generate_raw_key()
|
||||
assert key1 != key2
|
||||
|
||||
|
||||
class TestKeyHashing:
|
||||
"""Tests for SHA-256 key hashing."""
|
||||
|
||||
def test_hash_produces_64_char_hex(self):
|
||||
key = "mktp_test1234567890abcdef"
|
||||
h = hash_key(key)
|
||||
assert len(h) == 64
|
||||
assert all(c in "0123456789abcdef" for c in h)
|
||||
|
||||
def test_hash_is_sha256(self):
|
||||
key = "mktp_test1234567890abcdef"
|
||||
expected = hashlib.sha256(key.encode()).hexdigest()
|
||||
assert hash_key(key) == expected
|
||||
|
||||
def test_hash_deterministic(self):
|
||||
key = generate_raw_key()
|
||||
assert hash_key(key) == hash_key(key)
|
||||
|
||||
def test_different_keys_different_hashes(self):
|
||||
key1 = generate_raw_key()
|
||||
key2 = generate_raw_key()
|
||||
assert hash_key(key1) != hash_key(key2)
|
||||
|
||||
|
||||
class TestAllowedScopes:
|
||||
"""Tests for scope definitions."""
|
||||
|
||||
def test_allowed_scopes_contains_expected(self):
|
||||
expected = {
|
||||
"devices:read",
|
||||
"devices:write",
|
||||
"config:read",
|
||||
"config:write",
|
||||
"alerts:read",
|
||||
"firmware:write",
|
||||
}
|
||||
assert expected == ALLOWED_SCOPES
|
||||
75
backend/tests/unit/test_audit_service.py
Normal file
75
backend/tests/unit/test_audit_service.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Unit tests for the audit service and model.
|
||||
|
||||
Tests cover:
|
||||
- AuditLog model can be imported
|
||||
- log_action function signature is correct
|
||||
- Audit logs router is importable with expected endpoints
|
||||
- CSV export endpoint exists
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAuditLogModel:
|
||||
"""Tests for the AuditLog ORM model."""
|
||||
|
||||
def test_model_importable(self):
|
||||
from app.models.audit_log import AuditLog
|
||||
assert AuditLog.__tablename__ == "audit_logs"
|
||||
|
||||
def test_model_has_required_columns(self):
|
||||
from app.models.audit_log import AuditLog
|
||||
mapper = AuditLog.__table__.columns
|
||||
expected_columns = {
|
||||
"id", "tenant_id", "user_id", "action",
|
||||
"resource_type", "resource_id", "device_id",
|
||||
"details", "ip_address", "created_at",
|
||||
}
|
||||
actual_columns = {c.name for c in mapper}
|
||||
assert expected_columns.issubset(actual_columns), (
|
||||
f"Missing columns: {expected_columns - actual_columns}"
|
||||
)
|
||||
|
||||
def test_model_exported_from_init(self):
|
||||
from app.models import AuditLog
|
||||
assert AuditLog.__tablename__ == "audit_logs"
|
||||
|
||||
|
||||
class TestAuditService:
|
||||
"""Tests for the audit service log_action function."""
|
||||
|
||||
def test_log_action_importable(self):
|
||||
from app.services.audit_service import log_action
|
||||
assert callable(log_action)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_action_does_not_raise_on_db_error(self):
|
||||
"""log_action must swallow exceptions so it never breaks the caller."""
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.execute = AsyncMock(side_effect=Exception("DB down"))
|
||||
|
||||
# Should NOT raise even though the DB call fails
|
||||
await log_action(
|
||||
db=mock_db,
|
||||
tenant_id=uuid.uuid4(),
|
||||
user_id=uuid.uuid4(),
|
||||
action="test_action",
|
||||
)
|
||||
|
||||
|
||||
class TestAuditRouter:
|
||||
"""Tests for the audit logs router."""
|
||||
|
||||
def test_router_importable(self):
|
||||
from app.routers.audit_logs import router
|
||||
assert router is not None
|
||||
|
||||
def test_router_has_audit_logs_endpoint(self):
|
||||
from app.routers.audit_logs import router
|
||||
paths = [route.path for route in router.routes]
|
||||
assert "/audit-logs" in paths or any("/audit-logs" in p for p in paths)
|
||||
169
backend/tests/unit/test_auth.py
Normal file
169
backend/tests/unit/test_auth.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Unit tests for the JWT authentication service.
|
||||
|
||||
Tests cover:
|
||||
- Password hashing and verification (bcrypt)
|
||||
- JWT access token creation and validation
|
||||
- JWT refresh token creation and validation
|
||||
- Token rejection for wrong type, expired, invalid, missing subject
|
||||
|
||||
These are pure function tests -- no database or async required.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from jose import jwt
|
||||
|
||||
from app.services.auth import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
verify_token,
|
||||
)
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class TestPasswordHashing:
|
||||
"""Tests for bcrypt password hashing."""
|
||||
|
||||
def test_hash_returns_different_string(self):
|
||||
password = "test-password-123!"
|
||||
hashed = hash_password(password)
|
||||
assert hashed != password
|
||||
|
||||
def test_hash_verify_roundtrip(self):
|
||||
password = "test-password-123!"
|
||||
hashed = hash_password(password)
|
||||
assert verify_password(password, hashed) is True
|
||||
|
||||
def test_verify_rejects_wrong_password(self):
|
||||
hashed = hash_password("correct-password")
|
||||
assert verify_password("wrong-password", hashed) is False
|
||||
|
||||
def test_hash_uses_unique_salts(self):
|
||||
"""Each hash should be different even for the same password (random salt)."""
|
||||
hash1 = hash_password("same-password")
|
||||
hash2 = hash_password("same-password")
|
||||
assert hash1 != hash2
|
||||
|
||||
def test_verify_both_hashes_valid(self):
|
||||
"""Both unique hashes should verify against the original password."""
|
||||
password = "same-password"
|
||||
hash1 = hash_password(password)
|
||||
hash2 = hash_password(password)
|
||||
assert verify_password(password, hash1) is True
|
||||
assert verify_password(password, hash2) is True
|
||||
|
||||
|
||||
class TestAccessToken:
|
||||
"""Tests for JWT access token creation and validation."""
|
||||
|
||||
def test_create_and_verify_roundtrip(self):
|
||||
user_id = uuid.uuid4()
|
||||
tenant_id = uuid.uuid4()
|
||||
token = create_access_token(user_id=user_id, tenant_id=tenant_id, role="admin")
|
||||
payload = verify_token(token, expected_type="access")
|
||||
|
||||
assert payload["sub"] == str(user_id)
|
||||
assert payload["tenant_id"] == str(tenant_id)
|
||||
assert payload["role"] == "admin"
|
||||
assert payload["type"] == "access"
|
||||
|
||||
def test_super_admin_null_tenant(self):
|
||||
user_id = uuid.uuid4()
|
||||
token = create_access_token(user_id=user_id, tenant_id=None, role="super_admin")
|
||||
payload = verify_token(token, expected_type="access")
|
||||
|
||||
assert payload["sub"] == str(user_id)
|
||||
assert payload["tenant_id"] is None
|
||||
assert payload["role"] == "super_admin"
|
||||
|
||||
def test_contains_expiry(self):
|
||||
token = create_access_token(
|
||||
user_id=uuid.uuid4(), tenant_id=uuid.uuid4(), role="viewer"
|
||||
)
|
||||
payload = verify_token(token, expected_type="access")
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
|
||||
|
||||
class TestRefreshToken:
|
||||
"""Tests for JWT refresh token creation and validation."""
|
||||
|
||||
def test_create_and_verify_roundtrip(self):
|
||||
user_id = uuid.uuid4()
|
||||
token = create_refresh_token(user_id=user_id)
|
||||
payload = verify_token(token, expected_type="refresh")
|
||||
|
||||
assert payload["sub"] == str(user_id)
|
||||
assert payload["type"] == "refresh"
|
||||
|
||||
def test_refresh_token_has_no_tenant_or_role(self):
|
||||
token = create_refresh_token(user_id=uuid.uuid4())
|
||||
payload = verify_token(token, expected_type="refresh")
|
||||
|
||||
# Refresh tokens intentionally omit tenant_id and role
|
||||
assert "tenant_id" not in payload
|
||||
assert "role" not in payload
|
||||
|
||||
|
||||
class TestTokenRejection:
|
||||
"""Tests for JWT token validation failure cases."""
|
||||
|
||||
def test_rejects_wrong_type(self):
|
||||
"""Access token should not verify as refresh, and vice versa."""
|
||||
access_token = create_access_token(
|
||||
user_id=uuid.uuid4(), tenant_id=uuid.uuid4(), role="admin"
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_token(access_token, expected_type="refresh")
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_rejects_expired_token(self):
|
||||
"""Manually craft an expired token and verify it is rejected."""
|
||||
expired_payload = {
|
||||
"sub": str(uuid.uuid4()),
|
||||
"type": "access",
|
||||
"exp": datetime.now(UTC) - timedelta(hours=1),
|
||||
"iat": datetime.now(UTC) - timedelta(hours=2),
|
||||
}
|
||||
expired_token = jwt.encode(
|
||||
expired_payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_token(expired_token, expected_type="access")
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_rejects_invalid_token(self):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_token("not-a-valid-jwt", expected_type="access")
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_rejects_wrong_signing_key(self):
|
||||
"""Token signed with a different key should be rejected."""
|
||||
payload = {
|
||||
"sub": str(uuid.uuid4()),
|
||||
"type": "access",
|
||||
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||
}
|
||||
wrong_key_token = jwt.encode(payload, "wrong-secret-key", algorithm="HS256")
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_token(wrong_key_token, expected_type="access")
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_rejects_missing_subject(self):
|
||||
"""Token without 'sub' claim should be rejected."""
|
||||
no_sub_payload = {
|
||||
"type": "access",
|
||||
"exp": datetime.now(UTC) + timedelta(hours=1),
|
||||
}
|
||||
no_sub_token = jwt.encode(
|
||||
no_sub_payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_token(no_sub_token, expected_type="access")
|
||||
assert exc_info.value.status_code == 401
|
||||
126
backend/tests/unit/test_crypto.py
Normal file
126
backend/tests/unit/test_crypto.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Unit tests for the credential encryption/decryption service.
|
||||
|
||||
Tests cover:
|
||||
- Encryption/decryption round-trip with valid key
|
||||
- Random nonce ensures different ciphertext per encryption
|
||||
- Wrong key rejection (InvalidTag)
|
||||
- Invalid key length rejection (ValueError)
|
||||
- Unicode and JSON payload handling
|
||||
- Tampered ciphertext detection
|
||||
|
||||
These are pure function tests -- no database or async required.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from cryptography.exceptions import InvalidTag
|
||||
|
||||
from app.services.crypto import decrypt_credentials, encrypt_credentials
|
||||
|
||||
|
||||
class TestEncryptDecryptRoundTrip:
|
||||
"""Tests for successful encryption/decryption cycles."""
|
||||
|
||||
def test_basic_roundtrip(self):
|
||||
key = os.urandom(32)
|
||||
plaintext = "secret-password"
|
||||
ciphertext = encrypt_credentials(plaintext, key)
|
||||
result = decrypt_credentials(ciphertext, key)
|
||||
assert result == plaintext
|
||||
|
||||
def test_json_credentials_roundtrip(self):
|
||||
"""The actual use case: encrypting JSON credential objects."""
|
||||
key = os.urandom(32)
|
||||
creds = json.dumps({"username": "admin", "password": "RouterOS!123"})
|
||||
ciphertext = encrypt_credentials(creds, key)
|
||||
result = decrypt_credentials(ciphertext, key)
|
||||
parsed = json.loads(result)
|
||||
assert parsed["username"] == "admin"
|
||||
assert parsed["password"] == "RouterOS!123"
|
||||
|
||||
def test_unicode_roundtrip(self):
|
||||
key = os.urandom(32)
|
||||
plaintext = "password-with-unicode-\u00e9\u00e8\u00ea"
|
||||
ciphertext = encrypt_credentials(plaintext, key)
|
||||
result = decrypt_credentials(ciphertext, key)
|
||||
assert result == plaintext
|
||||
|
||||
def test_empty_string_roundtrip(self):
|
||||
key = os.urandom(32)
|
||||
ciphertext = encrypt_credentials("", key)
|
||||
result = decrypt_credentials(ciphertext, key)
|
||||
assert result == ""
|
||||
|
||||
def test_long_payload_roundtrip(self):
|
||||
"""Ensure large payloads work (e.g., SSH keys in credentials)."""
|
||||
key = os.urandom(32)
|
||||
plaintext = "x" * 10000
|
||||
ciphertext = encrypt_credentials(plaintext, key)
|
||||
result = decrypt_credentials(ciphertext, key)
|
||||
assert result == plaintext
|
||||
|
||||
|
||||
class TestNonceRandomness:
|
||||
"""Tests that encryption uses random nonces."""
|
||||
|
||||
def test_different_ciphertext_each_time(self):
|
||||
"""Two encryptions of the same plaintext should produce different ciphertext
|
||||
because a random 12-byte nonce is generated each time."""
|
||||
key = os.urandom(32)
|
||||
plaintext = "same-plaintext"
|
||||
ct1 = encrypt_credentials(plaintext, key)
|
||||
ct2 = encrypt_credentials(plaintext, key)
|
||||
assert ct1 != ct2
|
||||
|
||||
def test_both_decrypt_correctly(self):
|
||||
"""Both different ciphertexts should decrypt to the same plaintext."""
|
||||
key = os.urandom(32)
|
||||
plaintext = "same-plaintext"
|
||||
ct1 = encrypt_credentials(plaintext, key)
|
||||
ct2 = encrypt_credentials(plaintext, key)
|
||||
assert decrypt_credentials(ct1, key) == plaintext
|
||||
assert decrypt_credentials(ct2, key) == plaintext
|
||||
|
||||
|
||||
class TestDecryptionFailures:
|
||||
"""Tests for proper rejection of invalid inputs."""
|
||||
|
||||
def test_wrong_key_raises_invalid_tag(self):
|
||||
key1 = os.urandom(32)
|
||||
key2 = os.urandom(32)
|
||||
ciphertext = encrypt_credentials("secret", key1)
|
||||
with pytest.raises(InvalidTag):
|
||||
decrypt_credentials(ciphertext, key2)
|
||||
|
||||
def test_tampered_ciphertext_raises_invalid_tag(self):
|
||||
"""Flipping a byte in the ciphertext should cause authentication failure."""
|
||||
key = os.urandom(32)
|
||||
ciphertext = bytearray(encrypt_credentials("secret", key))
|
||||
# Flip a byte in the encrypted portion (after the 12-byte nonce)
|
||||
ciphertext[15] ^= 0xFF
|
||||
with pytest.raises(InvalidTag):
|
||||
decrypt_credentials(bytes(ciphertext), key)
|
||||
|
||||
|
||||
class TestKeyValidation:
|
||||
"""Tests for encryption key length validation."""
|
||||
|
||||
def test_short_key_encrypt_raises(self):
|
||||
with pytest.raises(ValueError, match="32 bytes"):
|
||||
encrypt_credentials("test", os.urandom(16))
|
||||
|
||||
def test_long_key_encrypt_raises(self):
|
||||
with pytest.raises(ValueError, match="32 bytes"):
|
||||
encrypt_credentials("test", os.urandom(64))
|
||||
|
||||
def test_short_key_decrypt_raises(self):
|
||||
key = os.urandom(32)
|
||||
ciphertext = encrypt_credentials("test", key)
|
||||
with pytest.raises(ValueError, match="32 bytes"):
|
||||
decrypt_credentials(ciphertext, os.urandom(16))
|
||||
|
||||
def test_empty_key_raises(self):
|
||||
with pytest.raises(ValueError, match="32 bytes"):
|
||||
encrypt_credentials("test", b"")
|
||||
121
backend/tests/unit/test_maintenance_windows.py
Normal file
121
backend/tests/unit/test_maintenance_windows.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Unit tests for maintenance window model, router schemas, and alert suppression.
|
||||
|
||||
Tests cover:
|
||||
- MaintenanceWindow ORM model imports and field definitions
|
||||
- MaintenanceWindowCreate/Update/Response Pydantic schema validation
|
||||
- Alert evaluator _is_device_in_maintenance integration
|
||||
- Router registration in main app
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
class TestMaintenanceWindowModel:
|
||||
"""Test that the MaintenanceWindow ORM model is importable and has correct fields."""
|
||||
|
||||
def test_model_importable(self):
|
||||
from app.models.maintenance_window import MaintenanceWindow
|
||||
assert MaintenanceWindow.__tablename__ == "maintenance_windows"
|
||||
|
||||
def test_model_exported_from_init(self):
|
||||
from app.models import MaintenanceWindow
|
||||
assert MaintenanceWindow.__tablename__ == "maintenance_windows"
|
||||
|
||||
def test_model_has_required_columns(self):
|
||||
from app.models.maintenance_window import MaintenanceWindow
|
||||
mapper = MaintenanceWindow.__mapper__
|
||||
column_names = {c.key for c in mapper.columns}
|
||||
expected = {
|
||||
"id", "tenant_id", "name", "device_ids",
|
||||
"start_at", "end_at", "suppress_alerts",
|
||||
"notes", "created_by", "created_at", "updated_at",
|
||||
}
|
||||
assert expected.issubset(column_names), f"Missing columns: {expected - column_names}"
|
||||
|
||||
|
||||
class TestMaintenanceWindowSchemas:
|
||||
"""Test Pydantic schemas for request/response validation."""
|
||||
|
||||
def test_create_schema_valid(self):
|
||||
from app.routers.maintenance_windows import MaintenanceWindowCreate
|
||||
data = MaintenanceWindowCreate(
|
||||
name="Nightly update",
|
||||
device_ids=["abc-123"],
|
||||
start_at=datetime.now(timezone.utc),
|
||||
end_at=datetime.now(timezone.utc) + timedelta(hours=2),
|
||||
suppress_alerts=True,
|
||||
notes="Scheduled maintenance",
|
||||
)
|
||||
assert data.name == "Nightly update"
|
||||
assert data.suppress_alerts is True
|
||||
|
||||
def test_create_schema_defaults(self):
|
||||
from app.routers.maintenance_windows import MaintenanceWindowCreate
|
||||
data = MaintenanceWindowCreate(
|
||||
name="Quick reboot",
|
||||
device_ids=[],
|
||||
start_at=datetime.now(timezone.utc),
|
||||
end_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
)
|
||||
assert data.suppress_alerts is True # default
|
||||
assert data.notes is None
|
||||
|
||||
def test_update_schema_partial(self):
|
||||
from app.routers.maintenance_windows import MaintenanceWindowUpdate
|
||||
data = MaintenanceWindowUpdate(name="Updated name")
|
||||
assert data.name == "Updated name"
|
||||
assert data.device_ids is None # all optional
|
||||
|
||||
def test_response_schema(self):
|
||||
from app.routers.maintenance_windows import MaintenanceWindowResponse
|
||||
data = MaintenanceWindowResponse(
|
||||
id="abc",
|
||||
tenant_id="def",
|
||||
name="Test",
|
||||
device_ids=["x"],
|
||||
start_at=datetime.now(timezone.utc).isoformat(),
|
||||
end_at=datetime.now(timezone.utc).isoformat(),
|
||||
suppress_alerts=True,
|
||||
notes=None,
|
||||
created_by="ghi",
|
||||
created_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
assert data.id == "abc"
|
||||
|
||||
|
||||
class TestRouterRegistration:
|
||||
"""Test that the maintenance_windows router is properly registered."""
|
||||
|
||||
def test_router_importable(self):
|
||||
from app.routers.maintenance_windows import router
|
||||
assert router is not None
|
||||
|
||||
def test_router_has_routes(self):
|
||||
from app.routers.maintenance_windows import router
|
||||
paths = [r.path for r in router.routes]
|
||||
assert any("maintenance-windows" in p for p in paths)
|
||||
|
||||
def test_main_app_includes_router(self):
|
||||
try:
|
||||
from app.main import app
|
||||
except ImportError:
|
||||
pytest.skip("app.main requires full dependencies (prometheus, etc.)")
|
||||
route_paths = [r.path for r in app.routes]
|
||||
route_paths_str = " ".join(route_paths)
|
||||
assert "maintenance-windows" in route_paths_str
|
||||
|
||||
|
||||
class TestAlertEvaluatorMaintenance:
|
||||
"""Test that alert_evaluator has maintenance window check capability."""
|
||||
|
||||
def test_maintenance_cache_exists(self):
|
||||
from app.services import alert_evaluator
|
||||
assert hasattr(alert_evaluator, "_maintenance_cache")
|
||||
|
||||
def test_is_device_in_maintenance_function_exists(self):
|
||||
from app.services.alert_evaluator import _is_device_in_maintenance
|
||||
assert callable(_is_device_in_maintenance)
|
||||
231
backend/tests/unit/test_security.py
Normal file
231
backend/tests/unit/test_security.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""Unit tests for security hardening.
|
||||
|
||||
Tests cover:
|
||||
- Production startup validation (insecure defaults rejection)
|
||||
- Security headers middleware (per-environment header behavior)
|
||||
|
||||
These are pure function/middleware tests -- no database or async required
|
||||
for startup validation, async only for middleware tests.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.config import KNOWN_INSECURE_DEFAULTS, validate_production_settings
|
||||
|
||||
|
||||
class TestStartupValidation:
|
||||
"""Tests for validate_production_settings()."""
|
||||
|
||||
def _make_settings(self, **kwargs):
|
||||
"""Create a mock settings object with given field values."""
|
||||
defaults = {
|
||||
"ENVIRONMENT": "dev",
|
||||
"JWT_SECRET_KEY": "change-this-in-production-use-a-long-random-string",
|
||||
"CREDENTIAL_ENCRYPTION_KEY": "LLLjnfBZTSycvL2U07HDSxUeTtLxb9cZzryQl0R9E4w=",
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return SimpleNamespace(**defaults)
|
||||
|
||||
def test_production_rejects_insecure_jwt_secret(self):
|
||||
"""Production with default JWT secret must exit."""
|
||||
settings = self._make_settings(
|
||||
ENVIRONMENT="production",
|
||||
JWT_SECRET_KEY=KNOWN_INSECURE_DEFAULTS["JWT_SECRET_KEY"][0],
|
||||
)
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
validate_production_settings(settings)
|
||||
assert exc_info.value.code == 1
|
||||
|
||||
def test_production_rejects_insecure_encryption_key(self):
|
||||
"""Production with default encryption key must exit."""
|
||||
settings = self._make_settings(
|
||||
ENVIRONMENT="production",
|
||||
JWT_SECRET_KEY="a-real-secure-jwt-secret-that-is-long-enough",
|
||||
CREDENTIAL_ENCRYPTION_KEY=KNOWN_INSECURE_DEFAULTS["CREDENTIAL_ENCRYPTION_KEY"][0],
|
||||
)
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
validate_production_settings(settings)
|
||||
assert exc_info.value.code == 1
|
||||
|
||||
def test_dev_allows_insecure_defaults(self):
|
||||
"""Dev environment allows insecure defaults without error."""
|
||||
settings = self._make_settings(
|
||||
ENVIRONMENT="dev",
|
||||
JWT_SECRET_KEY=KNOWN_INSECURE_DEFAULTS["JWT_SECRET_KEY"][0],
|
||||
CREDENTIAL_ENCRYPTION_KEY=KNOWN_INSECURE_DEFAULTS["CREDENTIAL_ENCRYPTION_KEY"][0],
|
||||
)
|
||||
# Should NOT raise
|
||||
validate_production_settings(settings)
|
||||
|
||||
def test_production_allows_secure_values(self):
|
||||
"""Production with non-default secrets should pass."""
|
||||
settings = self._make_settings(
|
||||
ENVIRONMENT="production",
|
||||
JWT_SECRET_KEY="a-real-secure-jwt-secret-that-is-long-enough-for-production",
|
||||
CREDENTIAL_ENCRYPTION_KEY="dGhpcyBpcyBhIHNlY3VyZSBrZXkgdGhhdCBpcw==",
|
||||
)
|
||||
# Should NOT raise
|
||||
validate_production_settings(settings)
|
||||
|
||||
|
||||
class TestSecurityHeadersMiddleware:
|
||||
"""Tests for SecurityHeadersMiddleware."""
|
||||
|
||||
@pytest.fixture
|
||||
def prod_app(self):
|
||||
"""Create a minimal FastAPI app with security middleware in production mode."""
|
||||
from fastapi import FastAPI
|
||||
from app.middleware.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(SecurityHeadersMiddleware, environment="production")
|
||||
|
||||
@app.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"status": "ok"}
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def dev_app(self):
|
||||
"""Create a minimal FastAPI app with security middleware in dev mode."""
|
||||
from fastapi import FastAPI
|
||||
from app.middleware.security_headers import SecurityHeadersMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(SecurityHeadersMiddleware, environment="dev")
|
||||
|
||||
@app.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"status": "ok"}
|
||||
|
||||
return app
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_production_includes_hsts(self, prod_app):
|
||||
"""Production responses must include HSTS header."""
|
||||
import httpx
|
||||
|
||||
transport = httpx.ASGITransport(app=prod_app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/test")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["strict-transport-security"] == "max-age=31536000; includeSubDomains"
|
||||
assert response.headers["x-content-type-options"] == "nosniff"
|
||||
assert response.headers["x-frame-options"] == "DENY"
|
||||
assert response.headers["cache-control"] == "no-store"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dev_excludes_hsts(self, dev_app):
|
||||
"""Dev responses must NOT include HSTS (breaks plain HTTP)."""
|
||||
import httpx
|
||||
|
||||
transport = httpx.ASGITransport(app=dev_app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/test")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "strict-transport-security" not in response.headers
|
||||
assert response.headers["x-content-type-options"] == "nosniff"
|
||||
assert response.headers["x-frame-options"] == "DENY"
|
||||
assert response.headers["cache-control"] == "no-store"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_csp_header_present_production(self, prod_app):
|
||||
"""Production responses must include CSP header."""
|
||||
import httpx
|
||||
|
||||
transport = httpx.ASGITransport(app=prod_app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/test")
|
||||
|
||||
assert "content-security-policy" in response.headers
|
||||
csp = response.headers["content-security-policy"]
|
||||
assert "default-src 'self'" in csp
|
||||
assert "script-src" in csp
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_csp_header_present_dev(self, dev_app):
|
||||
"""Dev responses must include CSP header."""
|
||||
import httpx
|
||||
|
||||
transport = httpx.ASGITransport(app=dev_app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/test")
|
||||
|
||||
assert "content-security-policy" in response.headers
|
||||
csp = response.headers["content-security-policy"]
|
||||
assert "default-src 'self'" in csp
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_csp_production_blocks_inline_scripts(self, prod_app):
|
||||
"""Production CSP must block inline scripts (no unsafe-inline in script-src)."""
|
||||
import httpx
|
||||
|
||||
transport = httpx.ASGITransport(app=prod_app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/test")
|
||||
|
||||
csp = response.headers["content-security-policy"]
|
||||
# Extract the script-src directive value
|
||||
script_src = [d for d in csp.split(";") if "script-src" in d][0]
|
||||
assert "'unsafe-inline'" not in script_src
|
||||
assert "'unsafe-eval'" not in script_src
|
||||
assert "'self'" in script_src
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_csp_dev_allows_unsafe_inline(self, dev_app):
|
||||
"""Dev CSP must allow unsafe-inline and unsafe-eval for Vite HMR."""
|
||||
import httpx
|
||||
|
||||
transport = httpx.ASGITransport(app=dev_app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/test")
|
||||
|
||||
csp = response.headers["content-security-policy"]
|
||||
script_src = [d for d in csp.split(";") if "script-src" in d][0]
|
||||
assert "'unsafe-inline'" in script_src
|
||||
assert "'unsafe-eval'" in script_src
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_csp_production_allows_inline_styles(self, prod_app):
|
||||
"""Production CSP must allow unsafe-inline for styles (Tailwind, Framer Motion, Radix)."""
|
||||
import httpx
|
||||
|
||||
transport = httpx.ASGITransport(app=prod_app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/test")
|
||||
|
||||
csp = response.headers["content-security-policy"]
|
||||
style_src = [d for d in csp.split(";") if "style-src" in d][0]
|
||||
assert "'unsafe-inline'" in style_src
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_csp_allows_websocket_connections(self, prod_app):
|
||||
"""CSP must allow wss: and ws: for SSE/WebSocket connections."""
|
||||
import httpx
|
||||
|
||||
transport = httpx.ASGITransport(app=prod_app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/test")
|
||||
|
||||
csp = response.headers["content-security-policy"]
|
||||
connect_src = [d for d in csp.split(";") if "connect-src" in d][0]
|
||||
assert "wss:" in connect_src
|
||||
assert "ws:" in connect_src
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_csp_frame_ancestors_none(self, prod_app):
|
||||
"""CSP must include frame-ancestors 'none' (anti-clickjacking)."""
|
||||
import httpx
|
||||
|
||||
transport = httpx.ASGITransport(app=prod_app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/test")
|
||||
|
||||
csp = response.headers["content-security-policy"]
|
||||
assert "frame-ancestors 'none'" in csp
|
||||
Reference in New Issue
Block a user