feat: The Other Dude v9.0.1 — full-featured email system

ci: add GitHub Pages deployment workflow for docs site

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Jason Staack
2026-03-08 17:46:37 -05:00
commit b840047e19
511 changed files with 106948 additions and 0 deletions

View File

View File

@@ -0,0 +1,76 @@
"""Unit tests for API key service.
Tests cover:
- Key generation format (mktp_ prefix, sufficient length)
- Key hashing (SHA-256 hex digest, 64 chars)
- Scope validation against allowed list
- Key prefix extraction
These are pure function tests -- no database or async required.
"""
import hashlib
from app.services.api_key_service import (
ALLOWED_SCOPES,
generate_raw_key,
hash_key,
)
class TestKeyGeneration:
"""Tests for API key generation."""
def test_key_starts_with_prefix(self):
key = generate_raw_key()
assert key.startswith("mktp_")
def test_key_has_sufficient_length(self):
"""Key should be mktp_ + at least 32 chars of randomness."""
key = generate_raw_key()
assert len(key) >= 37 # "mktp_" (5) + 32
def test_key_uniqueness(self):
"""Two generated keys should never be the same."""
key1 = generate_raw_key()
key2 = generate_raw_key()
assert key1 != key2
class TestKeyHashing:
"""Tests for SHA-256 key hashing."""
def test_hash_produces_64_char_hex(self):
key = "mktp_test1234567890abcdef"
h = hash_key(key)
assert len(h) == 64
assert all(c in "0123456789abcdef" for c in h)
def test_hash_is_sha256(self):
key = "mktp_test1234567890abcdef"
expected = hashlib.sha256(key.encode()).hexdigest()
assert hash_key(key) == expected
def test_hash_deterministic(self):
key = generate_raw_key()
assert hash_key(key) == hash_key(key)
def test_different_keys_different_hashes(self):
key1 = generate_raw_key()
key2 = generate_raw_key()
assert hash_key(key1) != hash_key(key2)
class TestAllowedScopes:
"""Tests for scope definitions."""
def test_allowed_scopes_contains_expected(self):
expected = {
"devices:read",
"devices:write",
"config:read",
"config:write",
"alerts:read",
"firmware:write",
}
assert expected == ALLOWED_SCOPES

View File

@@ -0,0 +1,75 @@
"""Unit tests for the audit service and model.
Tests cover:
- AuditLog model can be imported
- log_action function signature is correct
- Audit logs router is importable with expected endpoints
- CSV export endpoint exists
"""
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
class TestAuditLogModel:
"""Tests for the AuditLog ORM model."""
def test_model_importable(self):
from app.models.audit_log import AuditLog
assert AuditLog.__tablename__ == "audit_logs"
def test_model_has_required_columns(self):
from app.models.audit_log import AuditLog
mapper = AuditLog.__table__.columns
expected_columns = {
"id", "tenant_id", "user_id", "action",
"resource_type", "resource_id", "device_id",
"details", "ip_address", "created_at",
}
actual_columns = {c.name for c in mapper}
assert expected_columns.issubset(actual_columns), (
f"Missing columns: {expected_columns - actual_columns}"
)
def test_model_exported_from_init(self):
from app.models import AuditLog
assert AuditLog.__tablename__ == "audit_logs"
class TestAuditService:
"""Tests for the audit service log_action function."""
def test_log_action_importable(self):
from app.services.audit_service import log_action
assert callable(log_action)
@pytest.mark.asyncio
async def test_log_action_does_not_raise_on_db_error(self):
"""log_action must swallow exceptions so it never breaks the caller."""
from app.services.audit_service import log_action
mock_db = AsyncMock()
mock_db.execute = AsyncMock(side_effect=Exception("DB down"))
# Should NOT raise even though the DB call fails
await log_action(
db=mock_db,
tenant_id=uuid.uuid4(),
user_id=uuid.uuid4(),
action="test_action",
)
class TestAuditRouter:
"""Tests for the audit logs router."""
def test_router_importable(self):
from app.routers.audit_logs import router
assert router is not None
def test_router_has_audit_logs_endpoint(self):
from app.routers.audit_logs import router
paths = [route.path for route in router.routes]
assert "/audit-logs" in paths or any("/audit-logs" in p for p in paths)

View File

@@ -0,0 +1,169 @@
"""Unit tests for the JWT authentication service.
Tests cover:
- Password hashing and verification (bcrypt)
- JWT access token creation and validation
- JWT refresh token creation and validation
- Token rejection for wrong type, expired, invalid, missing subject
These are pure function tests -- no database or async required.
"""
import uuid
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
import pytest
from fastapi import HTTPException
from jose import jwt
from app.services.auth import (
create_access_token,
create_refresh_token,
hash_password,
verify_password,
verify_token,
)
from app.config import settings
class TestPasswordHashing:
"""Tests for bcrypt password hashing."""
def test_hash_returns_different_string(self):
password = "test-password-123!"
hashed = hash_password(password)
assert hashed != password
def test_hash_verify_roundtrip(self):
password = "test-password-123!"
hashed = hash_password(password)
assert verify_password(password, hashed) is True
def test_verify_rejects_wrong_password(self):
hashed = hash_password("correct-password")
assert verify_password("wrong-password", hashed) is False
def test_hash_uses_unique_salts(self):
"""Each hash should be different even for the same password (random salt)."""
hash1 = hash_password("same-password")
hash2 = hash_password("same-password")
assert hash1 != hash2
def test_verify_both_hashes_valid(self):
"""Both unique hashes should verify against the original password."""
password = "same-password"
hash1 = hash_password(password)
hash2 = hash_password(password)
assert verify_password(password, hash1) is True
assert verify_password(password, hash2) is True
class TestAccessToken:
"""Tests for JWT access token creation and validation."""
def test_create_and_verify_roundtrip(self):
user_id = uuid.uuid4()
tenant_id = uuid.uuid4()
token = create_access_token(user_id=user_id, tenant_id=tenant_id, role="admin")
payload = verify_token(token, expected_type="access")
assert payload["sub"] == str(user_id)
assert payload["tenant_id"] == str(tenant_id)
assert payload["role"] == "admin"
assert payload["type"] == "access"
def test_super_admin_null_tenant(self):
user_id = uuid.uuid4()
token = create_access_token(user_id=user_id, tenant_id=None, role="super_admin")
payload = verify_token(token, expected_type="access")
assert payload["sub"] == str(user_id)
assert payload["tenant_id"] is None
assert payload["role"] == "super_admin"
def test_contains_expiry(self):
token = create_access_token(
user_id=uuid.uuid4(), tenant_id=uuid.uuid4(), role="viewer"
)
payload = verify_token(token, expected_type="access")
assert "exp" in payload
assert "iat" in payload
class TestRefreshToken:
"""Tests for JWT refresh token creation and validation."""
def test_create_and_verify_roundtrip(self):
user_id = uuid.uuid4()
token = create_refresh_token(user_id=user_id)
payload = verify_token(token, expected_type="refresh")
assert payload["sub"] == str(user_id)
assert payload["type"] == "refresh"
def test_refresh_token_has_no_tenant_or_role(self):
token = create_refresh_token(user_id=uuid.uuid4())
payload = verify_token(token, expected_type="refresh")
# Refresh tokens intentionally omit tenant_id and role
assert "tenant_id" not in payload
assert "role" not in payload
class TestTokenRejection:
"""Tests for JWT token validation failure cases."""
def test_rejects_wrong_type(self):
"""Access token should not verify as refresh, and vice versa."""
access_token = create_access_token(
user_id=uuid.uuid4(), tenant_id=uuid.uuid4(), role="admin"
)
with pytest.raises(HTTPException) as exc_info:
verify_token(access_token, expected_type="refresh")
assert exc_info.value.status_code == 401
def test_rejects_expired_token(self):
"""Manually craft an expired token and verify it is rejected."""
expired_payload = {
"sub": str(uuid.uuid4()),
"type": "access",
"exp": datetime.now(UTC) - timedelta(hours=1),
"iat": datetime.now(UTC) - timedelta(hours=2),
}
expired_token = jwt.encode(
expired_payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
)
with pytest.raises(HTTPException) as exc_info:
verify_token(expired_token, expected_type="access")
assert exc_info.value.status_code == 401
def test_rejects_invalid_token(self):
with pytest.raises(HTTPException) as exc_info:
verify_token("not-a-valid-jwt", expected_type="access")
assert exc_info.value.status_code == 401
def test_rejects_wrong_signing_key(self):
"""Token signed with a different key should be rejected."""
payload = {
"sub": str(uuid.uuid4()),
"type": "access",
"exp": datetime.now(UTC) + timedelta(hours=1),
}
wrong_key_token = jwt.encode(payload, "wrong-secret-key", algorithm="HS256")
with pytest.raises(HTTPException) as exc_info:
verify_token(wrong_key_token, expected_type="access")
assert exc_info.value.status_code == 401
def test_rejects_missing_subject(self):
"""Token without 'sub' claim should be rejected."""
no_sub_payload = {
"type": "access",
"exp": datetime.now(UTC) + timedelta(hours=1),
}
no_sub_token = jwt.encode(
no_sub_payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
)
with pytest.raises(HTTPException) as exc_info:
verify_token(no_sub_token, expected_type="access")
assert exc_info.value.status_code == 401

View File

@@ -0,0 +1,126 @@
"""Unit tests for the credential encryption/decryption service.
Tests cover:
- Encryption/decryption round-trip with valid key
- Random nonce ensures different ciphertext per encryption
- Wrong key rejection (InvalidTag)
- Invalid key length rejection (ValueError)
- Unicode and JSON payload handling
- Tampered ciphertext detection
These are pure function tests -- no database or async required.
"""
import json
import os
import pytest
from cryptography.exceptions import InvalidTag
from app.services.crypto import decrypt_credentials, encrypt_credentials
class TestEncryptDecryptRoundTrip:
"""Tests for successful encryption/decryption cycles."""
def test_basic_roundtrip(self):
key = os.urandom(32)
plaintext = "secret-password"
ciphertext = encrypt_credentials(plaintext, key)
result = decrypt_credentials(ciphertext, key)
assert result == plaintext
def test_json_credentials_roundtrip(self):
"""The actual use case: encrypting JSON credential objects."""
key = os.urandom(32)
creds = json.dumps({"username": "admin", "password": "RouterOS!123"})
ciphertext = encrypt_credentials(creds, key)
result = decrypt_credentials(ciphertext, key)
parsed = json.loads(result)
assert parsed["username"] == "admin"
assert parsed["password"] == "RouterOS!123"
def test_unicode_roundtrip(self):
key = os.urandom(32)
plaintext = "password-with-unicode-\u00e9\u00e8\u00ea"
ciphertext = encrypt_credentials(plaintext, key)
result = decrypt_credentials(ciphertext, key)
assert result == plaintext
def test_empty_string_roundtrip(self):
key = os.urandom(32)
ciphertext = encrypt_credentials("", key)
result = decrypt_credentials(ciphertext, key)
assert result == ""
def test_long_payload_roundtrip(self):
"""Ensure large payloads work (e.g., SSH keys in credentials)."""
key = os.urandom(32)
plaintext = "x" * 10000
ciphertext = encrypt_credentials(plaintext, key)
result = decrypt_credentials(ciphertext, key)
assert result == plaintext
class TestNonceRandomness:
"""Tests that encryption uses random nonces."""
def test_different_ciphertext_each_time(self):
"""Two encryptions of the same plaintext should produce different ciphertext
because a random 12-byte nonce is generated each time."""
key = os.urandom(32)
plaintext = "same-plaintext"
ct1 = encrypt_credentials(plaintext, key)
ct2 = encrypt_credentials(plaintext, key)
assert ct1 != ct2
def test_both_decrypt_correctly(self):
"""Both different ciphertexts should decrypt to the same plaintext."""
key = os.urandom(32)
plaintext = "same-plaintext"
ct1 = encrypt_credentials(plaintext, key)
ct2 = encrypt_credentials(plaintext, key)
assert decrypt_credentials(ct1, key) == plaintext
assert decrypt_credentials(ct2, key) == plaintext
class TestDecryptionFailures:
"""Tests for proper rejection of invalid inputs."""
def test_wrong_key_raises_invalid_tag(self):
key1 = os.urandom(32)
key2 = os.urandom(32)
ciphertext = encrypt_credentials("secret", key1)
with pytest.raises(InvalidTag):
decrypt_credentials(ciphertext, key2)
def test_tampered_ciphertext_raises_invalid_tag(self):
"""Flipping a byte in the ciphertext should cause authentication failure."""
key = os.urandom(32)
ciphertext = bytearray(encrypt_credentials("secret", key))
# Flip a byte in the encrypted portion (after the 12-byte nonce)
ciphertext[15] ^= 0xFF
with pytest.raises(InvalidTag):
decrypt_credentials(bytes(ciphertext), key)
class TestKeyValidation:
"""Tests for encryption key length validation."""
def test_short_key_encrypt_raises(self):
with pytest.raises(ValueError, match="32 bytes"):
encrypt_credentials("test", os.urandom(16))
def test_long_key_encrypt_raises(self):
with pytest.raises(ValueError, match="32 bytes"):
encrypt_credentials("test", os.urandom(64))
def test_short_key_decrypt_raises(self):
key = os.urandom(32)
ciphertext = encrypt_credentials("test", key)
with pytest.raises(ValueError, match="32 bytes"):
decrypt_credentials(ciphertext, os.urandom(16))
def test_empty_key_raises(self):
with pytest.raises(ValueError, match="32 bytes"):
encrypt_credentials("test", b"")

View File

@@ -0,0 +1,121 @@
"""Unit tests for maintenance window model, router schemas, and alert suppression.
Tests cover:
- MaintenanceWindow ORM model imports and field definitions
- MaintenanceWindowCreate/Update/Response Pydantic schema validation
- Alert evaluator _is_device_in_maintenance integration
- Router registration in main app
"""
import uuid
from datetime import datetime, timezone, timedelta
import pytest
from pydantic import ValidationError
class TestMaintenanceWindowModel:
"""Test that the MaintenanceWindow ORM model is importable and has correct fields."""
def test_model_importable(self):
from app.models.maintenance_window import MaintenanceWindow
assert MaintenanceWindow.__tablename__ == "maintenance_windows"
def test_model_exported_from_init(self):
from app.models import MaintenanceWindow
assert MaintenanceWindow.__tablename__ == "maintenance_windows"
def test_model_has_required_columns(self):
from app.models.maintenance_window import MaintenanceWindow
mapper = MaintenanceWindow.__mapper__
column_names = {c.key for c in mapper.columns}
expected = {
"id", "tenant_id", "name", "device_ids",
"start_at", "end_at", "suppress_alerts",
"notes", "created_by", "created_at", "updated_at",
}
assert expected.issubset(column_names), f"Missing columns: {expected - column_names}"
class TestMaintenanceWindowSchemas:
"""Test Pydantic schemas for request/response validation."""
def test_create_schema_valid(self):
from app.routers.maintenance_windows import MaintenanceWindowCreate
data = MaintenanceWindowCreate(
name="Nightly update",
device_ids=["abc-123"],
start_at=datetime.now(timezone.utc),
end_at=datetime.now(timezone.utc) + timedelta(hours=2),
suppress_alerts=True,
notes="Scheduled maintenance",
)
assert data.name == "Nightly update"
assert data.suppress_alerts is True
def test_create_schema_defaults(self):
from app.routers.maintenance_windows import MaintenanceWindowCreate
data = MaintenanceWindowCreate(
name="Quick reboot",
device_ids=[],
start_at=datetime.now(timezone.utc),
end_at=datetime.now(timezone.utc) + timedelta(hours=1),
)
assert data.suppress_alerts is True # default
assert data.notes is None
def test_update_schema_partial(self):
from app.routers.maintenance_windows import MaintenanceWindowUpdate
data = MaintenanceWindowUpdate(name="Updated name")
assert data.name == "Updated name"
assert data.device_ids is None # all optional
def test_response_schema(self):
from app.routers.maintenance_windows import MaintenanceWindowResponse
data = MaintenanceWindowResponse(
id="abc",
tenant_id="def",
name="Test",
device_ids=["x"],
start_at=datetime.now(timezone.utc).isoformat(),
end_at=datetime.now(timezone.utc).isoformat(),
suppress_alerts=True,
notes=None,
created_by="ghi",
created_at=datetime.now(timezone.utc).isoformat(),
)
assert data.id == "abc"
class TestRouterRegistration:
"""Test that the maintenance_windows router is properly registered."""
def test_router_importable(self):
from app.routers.maintenance_windows import router
assert router is not None
def test_router_has_routes(self):
from app.routers.maintenance_windows import router
paths = [r.path for r in router.routes]
assert any("maintenance-windows" in p for p in paths)
def test_main_app_includes_router(self):
try:
from app.main import app
except ImportError:
pytest.skip("app.main requires full dependencies (prometheus, etc.)")
route_paths = [r.path for r in app.routes]
route_paths_str = " ".join(route_paths)
assert "maintenance-windows" in route_paths_str
class TestAlertEvaluatorMaintenance:
"""Test that alert_evaluator has maintenance window check capability."""
def test_maintenance_cache_exists(self):
from app.services import alert_evaluator
assert hasattr(alert_evaluator, "_maintenance_cache")
def test_is_device_in_maintenance_function_exists(self):
from app.services.alert_evaluator import _is_device_in_maintenance
assert callable(_is_device_in_maintenance)

View File

@@ -0,0 +1,231 @@
"""Unit tests for security hardening.
Tests cover:
- Production startup validation (insecure defaults rejection)
- Security headers middleware (per-environment header behavior)
These are pure function/middleware tests -- no database or async required
for startup validation, async only for middleware tests.
"""
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from app.config import KNOWN_INSECURE_DEFAULTS, validate_production_settings
class TestStartupValidation:
"""Tests for validate_production_settings()."""
def _make_settings(self, **kwargs):
"""Create a mock settings object with given field values."""
defaults = {
"ENVIRONMENT": "dev",
"JWT_SECRET_KEY": "change-this-in-production-use-a-long-random-string",
"CREDENTIAL_ENCRYPTION_KEY": "LLLjnfBZTSycvL2U07HDSxUeTtLxb9cZzryQl0R9E4w=",
}
defaults.update(kwargs)
return SimpleNamespace(**defaults)
def test_production_rejects_insecure_jwt_secret(self):
"""Production with default JWT secret must exit."""
settings = self._make_settings(
ENVIRONMENT="production",
JWT_SECRET_KEY=KNOWN_INSECURE_DEFAULTS["JWT_SECRET_KEY"][0],
)
with pytest.raises(SystemExit) as exc_info:
validate_production_settings(settings)
assert exc_info.value.code == 1
def test_production_rejects_insecure_encryption_key(self):
"""Production with default encryption key must exit."""
settings = self._make_settings(
ENVIRONMENT="production",
JWT_SECRET_KEY="a-real-secure-jwt-secret-that-is-long-enough",
CREDENTIAL_ENCRYPTION_KEY=KNOWN_INSECURE_DEFAULTS["CREDENTIAL_ENCRYPTION_KEY"][0],
)
with pytest.raises(SystemExit) as exc_info:
validate_production_settings(settings)
assert exc_info.value.code == 1
def test_dev_allows_insecure_defaults(self):
"""Dev environment allows insecure defaults without error."""
settings = self._make_settings(
ENVIRONMENT="dev",
JWT_SECRET_KEY=KNOWN_INSECURE_DEFAULTS["JWT_SECRET_KEY"][0],
CREDENTIAL_ENCRYPTION_KEY=KNOWN_INSECURE_DEFAULTS["CREDENTIAL_ENCRYPTION_KEY"][0],
)
# Should NOT raise
validate_production_settings(settings)
def test_production_allows_secure_values(self):
"""Production with non-default secrets should pass."""
settings = self._make_settings(
ENVIRONMENT="production",
JWT_SECRET_KEY="a-real-secure-jwt-secret-that-is-long-enough-for-production",
CREDENTIAL_ENCRYPTION_KEY="dGhpcyBpcyBhIHNlY3VyZSBrZXkgdGhhdCBpcw==",
)
# Should NOT raise
validate_production_settings(settings)
class TestSecurityHeadersMiddleware:
"""Tests for SecurityHeadersMiddleware."""
@pytest.fixture
def prod_app(self):
"""Create a minimal FastAPI app with security middleware in production mode."""
from fastapi import FastAPI
from app.middleware.security_headers import SecurityHeadersMiddleware
app = FastAPI()
app.add_middleware(SecurityHeadersMiddleware, environment="production")
@app.get("/test")
async def test_endpoint():
return {"status": "ok"}
return app
@pytest.fixture
def dev_app(self):
"""Create a minimal FastAPI app with security middleware in dev mode."""
from fastapi import FastAPI
from app.middleware.security_headers import SecurityHeadersMiddleware
app = FastAPI()
app.add_middleware(SecurityHeadersMiddleware, environment="dev")
@app.get("/test")
async def test_endpoint():
return {"status": "ok"}
return app
@pytest.mark.asyncio
async def test_production_includes_hsts(self, prod_app):
"""Production responses must include HSTS header."""
import httpx
transport = httpx.ASGITransport(app=prod_app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/test")
assert response.status_code == 200
assert response.headers["strict-transport-security"] == "max-age=31536000; includeSubDomains"
assert response.headers["x-content-type-options"] == "nosniff"
assert response.headers["x-frame-options"] == "DENY"
assert response.headers["cache-control"] == "no-store"
@pytest.mark.asyncio
async def test_dev_excludes_hsts(self, dev_app):
"""Dev responses must NOT include HSTS (breaks plain HTTP)."""
import httpx
transport = httpx.ASGITransport(app=dev_app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/test")
assert response.status_code == 200
assert "strict-transport-security" not in response.headers
assert response.headers["x-content-type-options"] == "nosniff"
assert response.headers["x-frame-options"] == "DENY"
assert response.headers["cache-control"] == "no-store"
@pytest.mark.asyncio
async def test_csp_header_present_production(self, prod_app):
"""Production responses must include CSP header."""
import httpx
transport = httpx.ASGITransport(app=prod_app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/test")
assert "content-security-policy" in response.headers
csp = response.headers["content-security-policy"]
assert "default-src 'self'" in csp
assert "script-src" in csp
@pytest.mark.asyncio
async def test_csp_header_present_dev(self, dev_app):
"""Dev responses must include CSP header."""
import httpx
transport = httpx.ASGITransport(app=dev_app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/test")
assert "content-security-policy" in response.headers
csp = response.headers["content-security-policy"]
assert "default-src 'self'" in csp
@pytest.mark.asyncio
async def test_csp_production_blocks_inline_scripts(self, prod_app):
"""Production CSP must block inline scripts (no unsafe-inline in script-src)."""
import httpx
transport = httpx.ASGITransport(app=prod_app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/test")
csp = response.headers["content-security-policy"]
# Extract the script-src directive value
script_src = [d for d in csp.split(";") if "script-src" in d][0]
assert "'unsafe-inline'" not in script_src
assert "'unsafe-eval'" not in script_src
assert "'self'" in script_src
@pytest.mark.asyncio
async def test_csp_dev_allows_unsafe_inline(self, dev_app):
"""Dev CSP must allow unsafe-inline and unsafe-eval for Vite HMR."""
import httpx
transport = httpx.ASGITransport(app=dev_app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/test")
csp = response.headers["content-security-policy"]
script_src = [d for d in csp.split(";") if "script-src" in d][0]
assert "'unsafe-inline'" in script_src
assert "'unsafe-eval'" in script_src
@pytest.mark.asyncio
async def test_csp_production_allows_inline_styles(self, prod_app):
"""Production CSP must allow unsafe-inline for styles (Tailwind, Framer Motion, Radix)."""
import httpx
transport = httpx.ASGITransport(app=prod_app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/test")
csp = response.headers["content-security-policy"]
style_src = [d for d in csp.split(";") if "style-src" in d][0]
assert "'unsafe-inline'" in style_src
@pytest.mark.asyncio
async def test_csp_allows_websocket_connections(self, prod_app):
"""CSP must allow wss: and ws: for SSE/WebSocket connections."""
import httpx
transport = httpx.ASGITransport(app=prod_app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/test")
csp = response.headers["content-security-policy"]
connect_src = [d for d in csp.split(";") if "connect-src" in d][0]
assert "wss:" in connect_src
assert "ws:" in connect_src
@pytest.mark.asyncio
async def test_csp_frame_ancestors_none(self, prod_app):
"""CSP must include frame-ancestors 'none' (anti-clickjacking)."""
import httpx
transport = httpx.ASGITransport(app=prod_app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/test")
csp = response.headers["content-security-policy"]
assert "frame-ancestors 'none'" in csp