feat: The Other Dude v9.0.1 — full-featured email system

ci: add GitHub Pages deployment workflow for docs site

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Jason Staack
2026-03-08 17:46:37 -05:00
commit b840047e19
511 changed files with 106948 additions and 0 deletions

1
backend/app/__init__.py Normal file
View File

@@ -0,0 +1 @@
# TOD Backend

177
backend/app/config.py Normal file
View File

@@ -0,0 +1,177 @@
"""Application configuration using Pydantic Settings."""
import base64
import sys
from functools import lru_cache
from typing import Optional
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
# Known insecure default values that MUST NOT be used in non-dev environments.
# If any of these are detected in production/staging, the app refuses to start
# (see validate_production_settings below). Maps Settings field name -> list of
# forbidden values for that field.
KNOWN_INSECURE_DEFAULTS: dict[str, list[str]] = {
    "JWT_SECRET_KEY": [
        "change-this-in-production-use-a-long-random-string",
        "dev-jwt-secret-change-in-production",
        "CHANGE_ME_IN_PRODUCTION",
    ],
    "CREDENTIAL_ENCRYPTION_KEY": [
        # The published dev key from the repo's default config.
        "LLLjnfBZTSycvL2U07HDSxUeTtLxb9cZzryQl0R9E4w=",
        "CHANGE_ME_IN_PRODUCTION",
    ],
    "OPENBAO_TOKEN": [
        "dev-openbao-token",
        "CHANGE_ME_IN_PRODUCTION",
    ],
}
def validate_production_settings(settings: "Settings") -> None:
    """Refuse to start when a non-dev environment uses a known-insecure default.

    Called once during app startup (via get_settings). Prints a clear error to
    stderr and exits with code 1 when any field listed in
    KNOWN_INSECURE_DEFAULTS still carries a dev/placeholder value.
    """
    # Dev environments are allowed to run with the shipped defaults.
    if settings.ENVIRONMENT == "dev":
        return
    for name, bad_values in KNOWN_INSECURE_DEFAULTS.items():
        current = getattr(settings, name, None)
        if current not in bad_values:
            continue
        message = (
            f"FATAL: {name} uses a known insecure default in '{settings.ENVIRONMENT}' environment.\n"
            "Generate a secure value and set it in your .env.prod file.\n"
            "For JWT_SECRET_KEY: python -c \"import secrets; print(secrets.token_urlsafe(64))\"\n"
            "For CREDENTIAL_ENCRYPTION_KEY: python -c \"import secrets, base64; print(base64.b64encode(secrets.token_bytes(32)).decode())\""
        )
        print(message, file=sys.stderr)
        sys.exit(1)
class Settings(BaseSettings):
    """Application configuration loaded from environment variables / .env.

    Every field ships with a dev-safe default; production deployments must
    override the secret-bearing ones (enforced by validate_production_settings).
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )
    # Environment (dev | staging | production)
    ENVIRONMENT: str = "dev"
    # Database (asyncpg driver, superuser — used for migrations/admin work)
    DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/mikrotik"
    # Sync URL used by Alembic only
    SYNC_DATABASE_URL: str = "postgresql+psycopg2://postgres:postgres@localhost:5432/mikrotik"
    # App user for RLS enforcement (cannot bypass RLS)
    APP_USER_DATABASE_URL: str = "postgresql+asyncpg://app_user:app_password@localhost:5432/mikrotik"
    # Database connection pool sizing (app pool vs. admin pool)
    DB_POOL_SIZE: int = 20
    DB_MAX_OVERFLOW: int = 40
    DB_ADMIN_POOL_SIZE: int = 10
    DB_ADMIN_MAX_OVERFLOW: int = 20
    # Redis
    REDIS_URL: str = "redis://localhost:6379/0"
    # NATS JetStream
    NATS_URL: str = "nats://localhost:4222"
    # JWT configuration
    JWT_SECRET_KEY: str = "change-this-in-production-use-a-long-random-string"
    JWT_ALGORITHM: str = "HS256"
    JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15
    JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7
    # Credential encryption key — must be 32 bytes, base64-encoded in env.
    # Generate with: python -c "import secrets, base64; print(base64.b64encode(secrets.token_bytes(32)).decode())"
    CREDENTIAL_ENCRYPTION_KEY: str = "LLLjnfBZTSycvL2U07HDSxUeTtLxb9cZzryQl0R9E4w="
    # OpenBao Transit (KMS for per-tenant credential encryption)
    OPENBAO_ADDR: str = "http://localhost:8200"
    OPENBAO_TOKEN: str = "dev-openbao-token"
    # First admin bootstrap (both must be set for bootstrap to run)
    FIRST_ADMIN_EMAIL: Optional[str] = None
    FIRST_ADMIN_PASSWORD: Optional[str] = None
    # CORS origins (comma-separated; parsed by get_cors_origins)
    CORS_ORIGINS: str = "http://localhost:3000,http://localhost:5173,http://localhost:8080"
    # Git store — PVC mount for bare git repos (one per tenant).
    # In production: /data/git-store (Kubernetes PVC ReadWriteMany).
    # In local dev: ./git-store (relative to cwd, created on first use).
    GIT_STORE_PATH: str = "./git-store"
    # WireGuard config path — shared volume with the WireGuard container
    WIREGUARD_CONFIG_PATH: str = "/data/wireguard"
    # Firmware cache
    FIRMWARE_CACHE_DIR: str = "/data/firmware-cache"  # PVC mount path
    FIRMWARE_CHECK_INTERVAL_HOURS: int = 24  # How often to check for new versions
    # SMTP settings for transactional email (password reset, etc.)
    SMTP_HOST: str = "localhost"
    SMTP_PORT: int = 587
    SMTP_USER: Optional[str] = None
    SMTP_PASSWORD: Optional[str] = None
    SMTP_USE_TLS: bool = False
    SMTP_FROM_ADDRESS: str = "noreply@mikrotik-portal.local"
    # Password reset
    PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: int = 30
    APP_BASE_URL: str = "http://localhost:5173"
    # App settings
    APP_NAME: str = "TOD - The Other Dude"
    APP_VERSION: str = "0.1.0"
    DEBUG: bool = False

    @field_validator("CREDENTIAL_ENCRYPTION_KEY")
    @classmethod
    def validate_encryption_key(cls, v: str) -> str:
        """Ensure the key decodes to exactly 32 bytes.

        Note: CHANGE_ME_IN_PRODUCTION is allowed through this validator
        because it fails the base64 length check. The production safety
        check in validate_production_settings() catches it separately.
        """
        if v == "CHANGE_ME_IN_PRODUCTION":
            # Allow the placeholder through field validation -- the production
            # safety check will reject it in non-dev environments.
            return v
        try:
            key_bytes = base64.b64decode(v)
            if len(key_bytes) != 32:
                raise ValueError(
                    f"CREDENTIAL_ENCRYPTION_KEY must decode to exactly 32 bytes, got {len(key_bytes)}"
                )
        except Exception as e:
            # Wrap both decode errors and the length error in one clear message.
            raise ValueError(f"Invalid CREDENTIAL_ENCRYPTION_KEY: {e}") from e
        return v

    def get_encryption_key_bytes(self) -> bytes:
        """Return the encryption key as raw bytes (base64-decoded)."""
        return base64.b64decode(self.CREDENTIAL_ENCRYPTION_KEY)

    def get_cors_origins(self) -> list[str]:
        """Return CORS origins as a list, trimming whitespace and empty entries."""
        return [origin.strip() for origin in self.CORS_ORIGINS.split(",") if origin.strip()]
@lru_cache()
def get_settings() -> Settings:
    """Build, validate, and cache the process-wide Settings instance.

    The lru_cache makes this a singleton: the production-safety check runs
    exactly once, at first access, before the app accepts requests.
    """
    instance = Settings()
    validate_production_settings(instance)
    return instance


# Module-level singleton imported throughout the app.
settings = get_settings()

114
backend/app/database.py Normal file
View File

@@ -0,0 +1,114 @@
"""Database engine, session factory, and dependency injection."""
import uuid
from collections.abc import AsyncGenerator
from typing import Optional
from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase
from app.config import settings
class Base(DeclarativeBase):
    """Declarative base class all SQLAlchemy ORM models inherit from."""
    pass
# Primary engine using postgres superuser (for migrations/admin).
# pool_pre_ping validates connections before use to survive DB restarts.
engine = create_async_engine(
    settings.DATABASE_URL,
    echo=settings.DEBUG,
    pool_pre_ping=True,
    pool_size=settings.DB_ADMIN_POOL_SIZE,
    max_overflow=settings.DB_ADMIN_MAX_OVERFLOW,
)
# App user engine (enforces RLS — no superuser bypass)
app_engine = create_async_engine(
    settings.APP_USER_DATABASE_URL,
    echo=settings.DEBUG,
    pool_pre_ping=True,
    pool_size=settings.DB_POOL_SIZE,
    max_overflow=settings.DB_MAX_OVERFLOW,
)
# Session factory for the app_user connection (RLS enforced).
# expire_on_commit=False keeps ORM objects usable after commit (async pattern).
AsyncSessionLocal = async_sessionmaker(
    app_engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)
# Admin session factory (for bootstrap/migrations only — bypasses RLS)
AdminAsyncSessionLocal = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Dependency that yields an async database session using app_user (RLS enforced).

    Commits on successful request completion; rolls back (and re-raises) when
    the request handler raised. The tenant context (SET LOCAL app.current_tenant)
    must be set by tenant_context middleware before any tenant-scoped queries.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            # Commit inside the try so a commit failure also triggers rollback.
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            # Explicit close; also happens via the context manager, harmless twice.
            await session.close()
async def get_admin_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Dependency that yields an admin database session (bypasses RLS).

    USE ONLY for bootstrap operations and internal system tasks.
    Same commit-on-success / rollback-on-error lifecycle as get_db.
    """
    async with AdminAsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()
async def set_tenant_context(session: AsyncSession, tenant_id: Optional[str]) -> None:
    """Set the PostgreSQL session variable that drives RLS policy checks.

    Must run before any tenant-scoped query. Uses SET LOCAL, so the value is
    automatically discarded at transaction end.

    Args:
        session: Active async session on the RLS-enforced engine.
        tenant_id: Tenant UUID string, the special value "super_admin", or
            None/empty for no tenant context.

    Raises:
        ValueError: when tenant_id is neither "super_admin" nor a valid UUID
            (guards the f-string SQL below against injection — SET LOCAL
            cannot take bound parameters in PostgreSQL).
    """
    if not tenant_id:
        # No tenant: empty string matches no tenant's rows. Super-admin
        # callers use the admin engine, which bypasses RLS entirely.
        await session.execute(text("SET LOCAL app.current_tenant = ''"))
        return
    if tenant_id != "super_admin":
        # Validate before interpolation — the only injection-safe option here.
        try:
            uuid.UUID(tenant_id)
        except ValueError:
            raise ValueError(f"Invalid tenant_id format: {tenant_id!r}")
    await session.execute(text(f"SET LOCAL app.current_tenant = '{tenant_id}'"))

View File

@@ -0,0 +1,81 @@
"""Structured logging configuration for the FastAPI backend.
Uses structlog with two rendering modes:
- Dev mode (ENVIRONMENT=dev or DEBUG=true): colored console output
- Prod mode: machine-parseable JSON output
Must be called once during app startup (in lifespan), NOT at module import time,
so tests can override the configuration.
"""
import logging
import os
import structlog
def configure_logging() -> None:
    """Configure structlog for the FastAPI application.

    Dev mode: colored console output with human-readable formatting.
    Prod mode: JSON output with machine-parseable fields.

    Must be called once during app startup (in lifespan), NOT at module import
    time, so tests can override the configuration.

    NOTE(review): mode selection reads os.environ directly rather than the
    pydantic Settings object — presumably to avoid an import cycle; confirm
    the two stay in sync.
    """
    is_dev = os.getenv("ENVIRONMENT", "dev") == "dev"
    # Default level: DEBUG in dev, INFO otherwise; LOG_LEVEL env overrides.
    log_level_name = os.getenv("LOG_LEVEL", "debug" if is_dev else "info").upper()
    log_level = getattr(logging, log_level_name, logging.INFO)
    # Processors shared by both structlog-native and stdlib-originated records.
    shared_processors: list[structlog.types.Processor] = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.UnicodeDecoder(),
    ]
    if is_dev:
        renderer = structlog.dev.ConsoleRenderer()
    else:
        renderer = structlog.processors.JSONRenderer()
    structlog.configure(
        processors=[
            *shared_processors,
            # Hand the event dict off to the stdlib formatter configured below.
            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
        ],
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )
    # Capture stdlib loggers (uvicorn, SQLAlchemy, alembic) into structlog pipeline
    formatter = structlog.stdlib.ProcessorFormatter(
        processors=[
            structlog.stdlib.ProcessorFormatter.remove_processors_meta,
            renderer,
        ],
    )
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    root_logger = logging.getLogger()
    # Replace any pre-existing handlers so every record goes through structlog.
    root_logger.handlers.clear()
    root_logger.addHandler(handler)
    root_logger.setLevel(log_level)
    # Quiet down noisy libraries in dev
    if is_dev:
        logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
def get_logger(name: str | None = None) -> structlog.stdlib.BoundLogger:
    """Return a structlog bound logger for *name*.

    Use this instead of logging.getLogger() throughout the application so all
    log lines flow through the processor pipeline set up in configure_logging().
    """
    return structlog.get_logger(name)

330
backend/app/main.py Normal file
View File

@@ -0,0 +1,330 @@
"""FastAPI application entry point."""
from contextlib import asynccontextmanager
from typing import AsyncGenerator
import structlog
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse
from app.config import settings
from app.logging_config import configure_logging
from app.middleware.rate_limit import setup_rate_limiting
from app.middleware.request_id import RequestIDMiddleware
from app.middleware.security_headers import SecurityHeadersMiddleware
from app.observability import check_health_ready, setup_instrumentator
logger = structlog.get_logger(__name__)
async def run_migrations() -> None:
    """Run Alembic migrations ("alembic upgrade head") on startup.

    Fix: the original used blocking subprocess.run() inside an async function,
    which stalls the event loop for the entire migration. This version uses an
    asyncio subprocess so the loop stays responsive.

    Raises:
        RuntimeError: if the migration process exits non-zero (stderr included).
    """
    import asyncio
    import os
    import sys

    # Run from the backend package root (parent of app/) so alembic.ini is found.
    backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    proc = await asyncio.create_subprocess_exec(
        sys.executable,
        "-m",
        "alembic",
        "upgrade",
        "head",
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        cwd=backend_dir,
    )
    _stdout, stderr = await proc.communicate()
    if proc.returncode != 0:
        err_text = stderr.decode(errors="replace")
        logger.error("migration failed", stderr=err_text)
        raise RuntimeError(f"Database migration failed: {err_text}")
    logger.info("migrations applied successfully")
async def bootstrap_first_admin() -> None:
    """Create the first super_admin user if no users exist.

    No-op when FIRST_ADMIN_EMAIL/PASSWORD are unset or when any user already
    exists. Uses the admin session so the existence check is not filtered by RLS.
    """
    if not settings.FIRST_ADMIN_EMAIL or not settings.FIRST_ADMIN_PASSWORD:
        logger.info("FIRST_ADMIN_EMAIL/PASSWORD not set, skipping bootstrap")
        return
    # Imports are local to avoid import cycles at module load time.
    from sqlalchemy import select
    from app.database import AdminAsyncSessionLocal
    from app.models.user import User, UserRole
    from app.services.auth import hash_password

    async with AdminAsyncSessionLocal() as session:
        # Check if any users exist (bypass RLS with admin session)
        result = await session.execute(select(User).limit(1))
        existing_user = result.scalar_one_or_none()
        if existing_user:
            logger.info("users already exist, skipping first admin bootstrap")
            return
        # Create the first super_admin with bcrypt password.
        # must_upgrade_auth=True triggers the SRP registration flow on first login.
        admin = User(
            email=settings.FIRST_ADMIN_EMAIL,
            hashed_password=hash_password(settings.FIRST_ADMIN_PASSWORD),
            name="Super Admin",
            role=UserRole.SUPER_ADMIN.value,
            tenant_id=None,  # super_admin has no tenant
            is_active=True,
            must_upgrade_auth=True,
        )
        session.add(admin)
        await session.commit()
        logger.info("created first super_admin", email=settings.FIRST_ADMIN_EMAIL)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application lifespan: startup (before yield) and shutdown (after yield).

    Startup order matters: logging -> migrations -> admin bootstrap -> NATS
    subscribers -> schedulers -> OpenBao provisioning -> recovery tasks.
    Every subsystem except migrations/bootstrap is best-effort: failures are
    logged and the API still starts.
    """
    from app.services.backup_scheduler import start_backup_scheduler, stop_backup_scheduler
    from app.services.firmware_subscriber import start_firmware_subscriber, stop_firmware_subscriber
    from app.services.metrics_subscriber import start_metrics_subscriber, stop_metrics_subscriber
    from app.services.nats_subscriber import start_nats_subscriber, stop_nats_subscriber
    from app.services.sse_manager import ensure_sse_streams

    # Configure structured logging FIRST -- before any other startup work
    configure_logging()
    logger.info("starting TOD API")
    # Run database migrations
    await run_migrations()
    # Bootstrap first admin user
    await bootstrap_first_admin()
    # Start NATS subscriber for device status events.
    # Wrapped in try/except so NATS failure doesn't prevent API startup --
    # allows running the API locally without NATS during frontend development.
    nats_connection = None
    try:
        nats_connection = await start_nats_subscriber()
    except Exception as exc:
        logger.warning(
            "NATS status subscriber could not start (API will run without it)",
            error=str(exc),
        )
    # Start NATS subscriber for device metrics events (separate NATS connection).
    # Same pattern -- failure is non-fatal so the API starts without full NATS stack.
    metrics_nc = None
    try:
        metrics_nc = await start_metrics_subscriber()
    except Exception as exc:
        logger.warning(
            "NATS metrics subscriber could not start (API will run without it)",
            error=str(exc),
        )
    # Start NATS subscriber for device firmware events (separate NATS connection).
    firmware_nc = None
    try:
        firmware_nc = await start_firmware_subscriber()
    except Exception as exc:
        logger.warning(
            "NATS firmware subscriber could not start (API will run without it)",
            error=str(exc),
        )
    # Ensure NATS streams for SSE event delivery exist (ALERT_EVENTS, OPERATION_EVENTS).
    # Non-fatal -- API starts without SSE streams; they'll be created on first SSE connection.
    try:
        await ensure_sse_streams()
    except Exception as exc:
        logger.warning(
            "SSE NATS streams could not be created (SSE will retry on connection)",
            error=str(exc),
        )
    # Start APScheduler for automated nightly config backups.
    # Non-fatal -- API starts and serves requests even without the scheduler.
    try:
        await start_backup_scheduler()
    except Exception as exc:
        logger.warning("backup scheduler could not start", error=str(exc))
    # Register daily firmware version check (3am UTC) on the same scheduler.
    try:
        from app.services.firmware_service import schedule_firmware_checks
        schedule_firmware_checks()
    except Exception as exc:
        logger.warning("firmware check scheduler could not start", error=str(exc))
    # Provision OpenBao Transit keys for existing tenants and migrate legacy credentials.
    # Non-blocking: if OpenBao is unavailable, the dual-read path handles fallback.
    if settings.OPENBAO_ADDR:
        try:
            from app.database import AdminAsyncSessionLocal
            from app.services.key_service import provision_existing_tenants
            async with AdminAsyncSessionLocal() as openbao_session:
                counts = await provision_existing_tenants(openbao_session)
                logger.info(
                    "openbao tenant provisioning complete",
                    **{k: v for k, v in counts.items()},
                )
        except Exception as exc:
            logger.warning(
                "openbao tenant provisioning failed (will retry on next restart)",
                error=str(exc),
            )
    # Recover stale push operations from previous API instance
    try:
        from app.services.restore_service import recover_stale_push_operations
        from app.database import AdminAsyncSessionLocal as _AdminSession
        async with _AdminSession() as session:
            await recover_stale_push_operations(session)
        logger.info("push operation recovery check complete")
    except Exception as e:
        logger.error("push operation recovery failed (non-fatal): %s", e)
    # Config change subscriber (event-driven backups)
    config_change_nc = None
    try:
        from app.services.config_change_subscriber import (
            start_config_change_subscriber,
            stop_config_change_subscriber,
        )
        config_change_nc = await start_config_change_subscriber()
    except Exception as e:
        logger.error("Config change subscriber failed to start (non-fatal): %s", e)
    # Push rollback/alert subscriber
    push_rollback_nc = None
    try:
        from app.services.push_rollback_subscriber import (
            start_push_rollback_subscriber,
            stop_push_rollback_subscriber,
        )
        push_rollback_nc = await start_push_rollback_subscriber()
    except Exception as e:
        logger.error("Push rollback subscriber failed to start (non-fatal): %s", e)
    logger.info("startup complete, ready to serve requests")
    yield
    # Shutdown: stop schedulers/subscribers in rough reverse order of startup.
    logger.info("shutting down TOD API")
    await stop_backup_scheduler()
    await stop_nats_subscriber(nats_connection)
    await stop_metrics_subscriber(metrics_nc)
    await stop_firmware_subscriber(firmware_nc)
    if config_change_nc:
        await stop_config_change_subscriber()
    if push_rollback_nc:
        await stop_push_rollback_subscriber()
    # Dispose database engine connections to release all pooled connections cleanly.
    from app.database import app_engine, engine
    await app_engine.dispose()
    await engine.dispose()
    logger.info("database connections closed")
def create_app() -> FastAPI:
    """Build and configure the FastAPI application.

    Wires up middleware (CORS, security headers, rate limiting, request IDs),
    mounts every API router under /api, registers health probes, and attaches
    Prometheus instrumentation. API docs are exposed only in dev.
    """
    app = FastAPI(
        title=settings.APP_NAME,
        version=settings.APP_VERSION,
        description="The Other Dude — Fleet Management API",
        docs_url="/docs" if settings.ENVIRONMENT == "dev" else None,
        redoc_url="/redoc" if settings.ENVIRONMENT == "dev" else None,
        lifespan=lifespan,
    )

    # Starlette processes middleware in LIFO order (last added = first to run).
    # We want: Request -> RequestID -> CORS -> Route handler
    # So add CORS first, then RequestID (it will wrap CORS).
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.get_cors_origins(),
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
        allow_headers=["Authorization", "Content-Type", "X-Request-ID"],
    )
    app.add_middleware(SecurityHeadersMiddleware, environment=settings.ENVIRONMENT)
    setup_rate_limiting(app)  # Register 429 exception handler (no middleware added)
    app.add_middleware(RequestIDMiddleware)

    # Routers (imported here to avoid import cycles at module load time).
    from app.routers.alerts import router as alerts_router
    from app.routers.auth import router as auth_router
    from app.routers.sse import router as sse_router
    from app.routers.config_backups import router as config_router
    from app.routers.config_editor import router as config_editor_router
    from app.routers.device_groups import router as device_groups_router
    from app.routers.device_tags import router as device_tags_router
    from app.routers.devices import router as devices_router
    from app.routers.firmware import router as firmware_router
    from app.routers.metrics import router as metrics_router
    from app.routers.events import router as events_router
    from app.routers.clients import router as clients_router
    from app.routers.device_logs import router as device_logs_router
    from app.routers.templates import router as templates_router
    from app.routers.tenants import router as tenants_router
    from app.routers.reports import router as reports_router
    from app.routers.topology import router as topology_router
    from app.routers.users import router as users_router
    from app.routers.audit_logs import router as audit_logs_router
    from app.routers.api_keys import router as api_keys_router
    from app.routers.maintenance_windows import router as maintenance_windows_router
    from app.routers.vpn import router as vpn_router
    from app.routers.certificates import router as certificates_router
    from app.routers.transparency import router as transparency_router
    from app.routers.settings import router as settings_router

    # All routers mounted under the plain /api prefix, in registration order.
    standard_routers = [
        auth_router,
        tenants_router,
        users_router,
        devices_router,
        device_groups_router,
        device_tags_router,
        metrics_router,
        config_router,
        firmware_router,
        alerts_router,
        config_editor_router,
        events_router,
        device_logs_router,
        templates_router,
        clients_router,
        topology_router,
        sse_router,
        audit_logs_router,
        reports_router,
        api_keys_router,
        maintenance_windows_router,
        vpn_router,
    ]
    for api_router in standard_routers:
        app.include_router(api_router, prefix="/api")
    # Certificates carries its own prefix and tags; the last two follow it
    # to preserve the original registration order.
    app.include_router(certificates_router, prefix="/api/certificates", tags=["certificates"])
    app.include_router(transparency_router, prefix="/api")
    app.include_router(settings_router, prefix="/api")

    # Health check endpoints
    @app.get("/health", tags=["health"])
    async def health_check() -> dict:
        """Liveness probe -- returns 200 if the process is alive."""
        return {"status": "ok", "version": settings.APP_VERSION}

    @app.get("/health/ready", tags=["health"])
    async def health_ready() -> JSONResponse:
        """Readiness probe -- returns 200 only when PostgreSQL, Redis, and NATS are healthy."""
        result = await check_health_ready()
        status_code = 200 if result["status"] == "healthy" else 503
        return JSONResponse(content=result, status_code=status_code)

    @app.get("/api/health", tags=["health"])
    async def api_health_check() -> dict:
        """Backward-compatible health endpoint under /api prefix."""
        return {"status": "ok", "version": settings.APP_VERSION}

    # Prometheus metrics instrumentation -- MUST be after routers so all routes are captured
    setup_instrumentator(app)
    return app


app = create_app()

View File

@@ -0,0 +1 @@
"""FastAPI middleware and dependencies for auth, tenant context, and RBAC."""

View File

@@ -0,0 +1,48 @@
"""Rate limiting middleware using slowapi with Redis backend.
Per-route rate limits only -- no global limits to avoid blocking the
Go poller, NATS subscribers, and health check endpoints.
Rate limit data uses Redis DB 1 (separate from app data in DB 0).
"""
from fastapi import FastAPI
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from app.config import settings
def _get_redis_url() -> str:
"""Return Redis URL pointing to DB 1 for rate limit storage.
Keeps rate limit counters separate from application data in DB 0.
"""
url = settings.REDIS_URL
if url.endswith("/0"):
return url[:-2] + "/1"
# If no DB specified or different DB, append /1
if url.rstrip("/").split("/")[-1].isdigit():
# Replace existing DB number
parts = url.rsplit("/", 1)
return parts[0] + "/1"
return url.rstrip("/") + "/1"
# Module-level limiter shared by all routes; per-route limits are applied
# via @limiter.limit() decorators. No default (global) limits on purpose.
limiter = Limiter(
    key_func=get_remote_address,  # rate-limit key = client IP
    storage_uri=_get_redis_url(),  # counters live in Redis DB 1
    default_limits=[],  # No global limits -- per-route only
)
def setup_rate_limiting(app: FastAPI) -> None:
    """Register the rate limiter on the FastAPI app.

    This sets app.state.limiter (required by slowapi) and registers
    the 429 exception handler. It does NOT add middleware -- the
    @limiter.limit() decorators handle actual limiting per-route.
    """
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

View File

@@ -0,0 +1,186 @@
"""
Role-Based Access Control (RBAC) middleware.
Provides dependency factories for enforcing role-based access control
on FastAPI routes. Roles are hierarchical:
super_admin > tenant_admin > operator > viewer
Role permissions per plan TENANT-04/05/06:
- viewer: GET endpoints only (read-only)
- operator: GET + device/config management endpoints
- tenant_admin: full access within their tenant
- super_admin: full access across all tenants
"""
from typing import Callable
from fastapi import Depends, HTTPException, Request, status
from fastapi.params import Depends as DependsClass
from app.middleware.tenant_context import CurrentUser, get_current_user
# Role hierarchy (higher value = more privilege). Unknown roles map to -1
# via _get_role_level, so they always fail minimum-role checks.
# api_key is at operator level for RBAC checks; fine-grained access controlled by scopes.
ROLE_HIERARCHY = {
    "viewer": 0,
    "api_key": 1,
    "operator": 1,
    "tenant_admin": 2,
    "super_admin": 3,
}
def _get_role_level(role: str) -> int:
    """Map a role string to its numeric privilege level; -1 for unknown roles."""
    try:
        return ROLE_HIERARCHY[role]
    except KeyError:
        return -1
def require_role(*allowed_roles: str) -> Callable:
    """
    Build a FastAPI dependency that admits only the listed roles.

    Usage:
        @router.post("/items", dependencies=[Depends(require_role("tenant_admin", "super_admin"))])

    Args:
        *allowed_roles: Role strings permitted to access the endpoint.

    Returns:
        Async dependency returning the current user, or raising 403 when the
        user's role is not in allowed_roles.
    """
    allowed_list = ", ".join(allowed_roles)

    async def _check_role(
        current_user: CurrentUser = Depends(get_current_user),
    ) -> CurrentUser:
        if current_user.role in allowed_roles:
            return current_user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access denied. Required roles: {allowed_list}. "
            f"Your role: {current_user.role}",
        )

    return _check_role
def require_min_role(min_role: str) -> Callable:
    """
    Build a dependency admitting any role at or above the given level.

    Usage:
        @router.get("/items", dependencies=[Depends(require_min_role("operator"))])
        # Allows: operator, tenant_admin, super_admin
        # Denies: viewer
    """
    # Resolve the threshold once, at dependency-construction time.
    min_level = _get_role_level(min_role)

    async def _check_min_role(
        current_user: CurrentUser = Depends(get_current_user),
    ) -> CurrentUser:
        if _get_role_level(current_user.role) >= min_level:
            return current_user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access denied. Minimum required role: {min_role}. "
            f"Your role: {current_user.role}",
        )

    return _check_min_role
def require_write_access() -> Callable:
    """
    Build a dependency enforcing the viewer read-only restriction.

    Viewers are denied on POST/PUT/PATCH/DELETE endpoints; all other roles
    (and viewers on safe methods) pass through. Attach to mutating routes.
    """
    mutating_methods = frozenset({"POST", "PUT", "PATCH", "DELETE"})

    async def _deny_viewer_writes(
        request: Request,
        current_user: CurrentUser = Depends(get_current_user),
    ) -> CurrentUser:
        if current_user.role == "viewer" and request.method in mutating_methods:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Viewers have read-only access. "
                "Contact your administrator to request elevated permissions.",
            )
        return current_user

    return _deny_viewer_writes
def require_scope(scope: str) -> DependsClass:
    """Build a Depends() that checks API key scopes.

    No-op for regular users (JWT auth) -- scopes only apply to API keys.
    For API key users: the required scope must be present in the key's scopes.

    Returns a Depends() instance so it can be used in dependency lists:
        @router.get("/items", dependencies=[require_scope("devices:read")])

    Args:
        scope: Required scope string (e.g. "devices:read", "config:write").

    Raises:
        HTTPException: 403 when the API key is missing the required scope.
    """

    async def _check_scope(
        current_user: CurrentUser = Depends(get_current_user),
    ) -> CurrentUser:
        if current_user.role != "api_key":
            # JWT-authenticated users are not scope-restricted.
            return current_user
        granted = current_user.scopes or []
        if scope not in granted:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"API key missing required scope: {scope}",
            )
        return current_user

    return Depends(_check_scope)
# Pre-built convenience dependencies
async def require_super_admin(
    current_user: CurrentUser = Depends(get_current_user),
) -> CurrentUser:
    """Admit only the super_admin role (portal-wide admin)."""
    if current_user.role == "super_admin":
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Access denied. Super admin role required.",
    )
async def require_tenant_admin_or_above(
    current_user: CurrentUser = Depends(get_current_user),
) -> CurrentUser:
    """Admit tenant_admin or super_admin; reject everyone else with 403."""
    if current_user.role in {"tenant_admin", "super_admin"}:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Access denied. Tenant admin or higher role required.",
    )
async def require_operator_or_above(
    current_user: CurrentUser = Depends(get_current_user),
) -> CurrentUser:
    """Admit operator, tenant_admin, or super_admin; reject others with 403."""
    if current_user.role in {"operator", "tenant_admin", "super_admin"}:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Access denied. Operator or higher role required.",
    )
async def require_authenticated(
    current_user: CurrentUser = Depends(get_current_user),
) -> CurrentUser:
    """Require any authenticated user (viewer and above).

    Pure pass-through: get_current_user raises for unauthenticated requests,
    so reaching this body already implies a valid user.
    """
    return current_user

View File

@@ -0,0 +1,67 @@
"""Request ID middleware for structured logging context.
Generates or extracts a request ID for every incoming request and binds it
(along with tenant_id from JWT) to structlog's contextvars so that all log
lines emitted during the request include these correlation fields.
"""
import uuid
import structlog
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Bind request_id and tenant_id into structlog's contextvars per request.

    Every log line emitted while handling the request carries these fields;
    the request ID is also echoed back in the X-Request-ID response header.
    """

    async def dispatch(self, request: Request, call_next):
        # CRITICAL: wipe context left over from the previous request on this
        # worker, otherwise correlation fields would leak across requests.
        structlog.contextvars.clear_contextvars()
        # Honor a caller-supplied X-Request-ID; otherwise mint a fresh UUID.
        incoming_id = request.headers.get("X-Request-ID")
        request_id = incoming_id if incoming_id is not None else str(uuid.uuid4())
        # Bind correlation fields -- all subsequent log calls include them.
        structlog.contextvars.bind_contextvars(
            request_id=request_id,
            tenant_id=self._extract_tenant_id(request),
        )
        response: Response = await call_next(request)
        response.headers["X-Request-ID"] = request_id
        return response

    def _extract_tenant_id(self, request: Request) -> str | None:
        """Best-effort tenant_id from the JWT: cookie first, then Bearer header.

        Returns None when no decodable token is present -- expected for
        unauthenticated endpoints like /login.
        """
        token = request.cookies.get("access_token")
        if not token:
            bearer = request.headers.get("Authorization", "")
            token = bearer[7:] if bearer.startswith("Bearer ") else None
        if not token:
            return None
        try:
            from jose import jwt as jose_jwt
            from app.config import settings
            claims = jose_jwt.decode(
                token,
                settings.JWT_SECRET_KEY,
                algorithms=[settings.JWT_ALGORITHM],
            )
            return claims.get("tenant_id")
        except Exception:
            # Invalid/expired token: treat as anonymous rather than failing.
            return None

View File

@@ -0,0 +1,79 @@
"""Security response headers middleware.
Adds standard security headers to all API responses:
- X-Content-Type-Options: nosniff (prevent MIME sniffing)
- X-Frame-Options: DENY (prevent clickjacking)
- Referrer-Policy: strict-origin-when-cross-origin
- Cache-Control: no-store (prevent browser caching of API responses)
- Strict-Transport-Security (HSTS, production only -- breaks plain HTTP dev)
- Content-Security-Policy (strict in production, relaxed for dev HMR)
CSP directives:
- script-src 'self' (production) blocks inline scripts -- XSS mitigation
- style-src 'unsafe-inline' required for Tailwind, Framer Motion, Radix, Sonner
- connect-src includes wss:/ws: for SSE and WebSocket connections
- Dev mode adds 'unsafe-inline' and 'unsafe-eval' for Vite HMR
"""
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
# Production CSP: strict — inline scripts are blocked (XSS mitigation);
# 'unsafe-inline' styles remain for Tailwind/Radix/etc., and wss:/ws: is
# allowed for SSE/WebSocket connections.
_CSP_PRODUCTION = (
    "default-src 'self'; "
    "script-src 'self'; "
    "style-src 'self' 'unsafe-inline'; "
    "img-src 'self' data: blob:; "
    "font-src 'self'; "
    "connect-src 'self' wss: ws:; "
    "worker-src 'self'; "
    "frame-ancestors 'none'; "
    "base-uri 'self'; "
    "form-action 'self'"
)
# Development CSP: relaxed so Vite HMR works — inline/eval scripts and
# localhost websocket connections are permitted.
_CSP_DEV = (
    "default-src 'self'; "
    "script-src 'self' 'unsafe-inline' 'unsafe-eval'; "
    "style-src 'self' 'unsafe-inline'; "
    "img-src 'self' data: blob:; "
    "font-src 'self'; "
    "connect-src 'self' http://localhost:* ws://localhost:* wss:; "
    "worker-src 'self' blob:; "
    "frame-ancestors 'none'; "
    "base-uri 'self'; "
    "form-action 'self'"
)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach standard security headers to every outgoing API response.

    Always sets nosniff, frame-deny, referrer-policy, and no-store cache
    headers; the Content-Security-Policy is environment-aware, and HSTS is
    emitted only outside dev (it would break plain-HTTP development).
    """

    def __init__(self, app, environment: str = "dev"):
        super().__init__(app)
        # Any environment other than the literal "dev" counts as production.
        self.is_production = environment != "dev"

    async def dispatch(self, request: Request, call_next) -> Response:
        response = await call_next(request)
        headers = response.headers
        # Headers applied in every environment.
        headers["X-Content-Type-Options"] = "nosniff"
        headers["X-Frame-Options"] = "DENY"
        headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        headers["Cache-Control"] = "no-store"
        if self.is_production:
            headers["Content-Security-Policy"] = _CSP_PRODUCTION
            # HSTS only in production — dev runs over plain HTTP.
            headers["Strict-Transport-Security"] = (
                "max-age=31536000; includeSubDomains"
            )
        else:
            headers["Content-Security-Policy"] = _CSP_DEV
        return response

View File

@@ -0,0 +1,177 @@
"""
Tenant context middleware and current user dependency.
Extracts JWT from Authorization header (Bearer token) or httpOnly cookie,
validates it, and provides current user context for request handlers.
For tenant-scoped users: sets SET LOCAL app.current_tenant on the DB session.
For super_admin: uses special 'super_admin' context that grants cross-tenant access.
"""
import uuid
from typing import Annotated, Optional
from fastapi import Cookie, Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db, set_tenant_context
from app.services.auth import verify_token
# Optional HTTP Bearer scheme (won't raise 403 automatically — we handle auth ourselves)
bearer_scheme = HTTPBearer(auto_error=False)
class CurrentUser:
    """Authenticated principal resolved from a JWT or an API key.

    Carries the identifiers and role that downstream authorization checks
    need. ``scopes`` is only populated for API-key principals; JWT users
    leave it as None.
    """

    def __init__(
        self,
        user_id: uuid.UUID,
        tenant_id: Optional[uuid.UUID],
        role: str,
        scopes: Optional[list[str]] = None,
    ) -> None:
        self.user_id = user_id
        self.tenant_id = tenant_id
        self.role = role
        self.scopes = scopes

    @property
    def is_super_admin(self) -> bool:
        """True when the principal holds the cross-tenant super_admin role."""
        return self.role == "super_admin"

    @property
    def is_api_key(self) -> bool:
        """True when authentication happened via an API key rather than a JWT."""
        return self.role == "api_key"

    def __repr__(self) -> str:
        fields = (self.user_id, self.role, self.tenant_id)
        return "<CurrentUser user_id=%s role=%s tenant_id=%s>" % fields
def _extract_token(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials],
    access_token: Optional[str],
) -> Optional[str]:
    """Pull the JWT out of the request, preferring the Authorization header.

    Returns the bearer token from the Authorization header when one is
    present, otherwise the httpOnly ``access_token`` cookie value,
    otherwise None.
    """
    if credentials and credentials.scheme.lower() == "bearer":
        return credentials.credentials
    # Fall back to the cookie; an empty cookie value counts as absent.
    return access_token if access_token else None
async def get_current_user(
    request: Request,
    credentials: Annotated[Optional[HTTPAuthorizationCredentials], Depends(bearer_scheme)] = None,
    access_token: Annotated[Optional[str], Cookie()] = None,
    db: AsyncSession = Depends(get_db),
) -> CurrentUser:
    """
    FastAPI dependency that extracts and validates the current user.

    Accepts either a Bearer token (Authorization header, preferred) or the
    httpOnly ``access_token`` cookie — see ``_extract_token``. Two
    credential formats are supported:

    - API keys (``mktp_`` prefix), validated via the api_key_service.
    - JWT access tokens, validated via ``verify_token``.

    On success the tenant context is set on the request-scoped DB session
    so row-level security is enforced; super_admin uses the special
    'super_admin' context that grants cross-tenant access.

    Raises:
        HTTPException 401: missing token, invalid/revoked API key, invalid
            JWT payload, or a non-super_admin token without a tenant.
            All 401 responses carry WWW-Authenticate per RFC 7235.
    """
    token = _extract_token(request, credentials, access_token)
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # API key authentication: detect mktp_ prefix and validate via api_key_service
    if token.startswith("mktp_"):
        # Lazy import — presumably avoids a circular import at module load; confirm.
        from app.services.api_key_service import validate_api_key

        key_data = await validate_api_key(token)
        if not key_data:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid, expired, or revoked API key",
                headers={"WWW-Authenticate": "Bearer"},
            )
        tenant_id = key_data["tenant_id"]
        # Set tenant context on the request-scoped DB session for RLS
        await set_tenant_context(db, str(tenant_id))
        return CurrentUser(
            user_id=key_data["user_id"],
            tenant_id=tenant_id,
            role="api_key",
            scopes=key_data["scopes"],
        )
    # Decode and validate the JWT (verify_token raises on bad/expired tokens)
    payload = verify_token(token, expected_type="access")
    user_id_str = payload.get("sub")
    tenant_id_str = payload.get("tenant_id")
    role = payload.get("role")
    if not user_id_str or not role:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token payload",
            headers={"WWW-Authenticate": "Bearer"},
        )
    try:
        user_id = uuid.UUID(user_id_str)
    except ValueError:
        # RFC 7235 requires WWW-Authenticate on 401 responses — keep this
        # consistent with the other 401s raised above.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token payload",
            headers={"WWW-Authenticate": "Bearer"},
        ) from None
    tenant_id: Optional[uuid.UUID] = None
    if tenant_id_str:
        try:
            tenant_id = uuid.UUID(tenant_id_str)
        except ValueError:
            # A malformed tenant_id is treated as absent; the role check
            # below decides whether that is fatal.
            pass
    # Set the tenant context on the database session for RLS enforcement
    if role == "super_admin":
        # super_admin uses special context that grants cross-tenant access
        await set_tenant_context(db, "super_admin")
    elif tenant_id:
        await set_tenant_context(db, str(tenant_id))
    else:
        # Non-super_admin without tenant — deny access
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token: no tenant context",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return CurrentUser(
        user_id=user_id,
        tenant_id=tenant_id,
        role=role,
    )
async def get_optional_current_user(
    request: Request,
    credentials: Annotated[Optional[HTTPAuthorizationCredentials], Depends(bearer_scheme)] = None,
    access_token: Annotated[Optional[str], Cookie()] = None,
    db: AsyncSession = Depends(get_db),
) -> Optional[CurrentUser]:
    """Like get_current_user, but yields None for unauthenticated requests."""
    try:
        user = await get_current_user(request, credentials, access_token, db)
    except HTTPException:
        return None
    return user

View File

@@ -0,0 +1,35 @@
"""SQLAlchemy ORM models."""
from app.models.tenant import Tenant
from app.models.user import User, UserRole
from app.models.device import Device, DeviceGroup, DeviceTag, DeviceGroupMembership, DeviceTagAssignment, DeviceStatus
from app.models.alert import AlertRule, NotificationChannel, AlertRuleChannel, AlertEvent
from app.models.firmware import FirmwareVersion, FirmwareUpgradeJob
from app.models.config_template import ConfigTemplate, ConfigTemplateTag, TemplatePushJob
from app.models.audit_log import AuditLog
from app.models.maintenance_window import MaintenanceWindow
from app.models.api_key import ApiKey
__all__ = [
"Tenant",
"User",
"UserRole",
"Device",
"DeviceGroup",
"DeviceTag",
"DeviceGroupMembership",
"DeviceTagAssignment",
"DeviceStatus",
"AlertRule",
"NotificationChannel",
"AlertRuleChannel",
"AlertEvent",
"FirmwareVersion",
"FirmwareUpgradeJob",
"ConfigTemplate",
"ConfigTemplateTag",
"TemplatePushJob",
"AuditLog",
"MaintenanceWindow",
"ApiKey",
]

177
backend/app/models/alert.py Normal file
View File

@@ -0,0 +1,177 @@
"""Alert system ORM models: rules, notification channels, and alert events."""
import uuid
from datetime import datetime
from sqlalchemy import (
Boolean,
DateTime,
ForeignKey,
Integer,
LargeBinary,
Numeric,
Text,
func,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class AlertRule(Base):
    """Configurable alert threshold rule.
    Rules can be tenant-wide (device_id=NULL), device-specific, or group-scoped.
    When a metric breaches the threshold for duration_polls consecutive polls,
    an alert fires.
    """
    __tablename__ = "alert_rules"
    # Surrogate PK: uuid4 for ORM inserts, gen_random_uuid() for raw SQL.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Owning tenant; rows are removed with the tenant (CASCADE).
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    # NULL = tenant-wide rule; otherwise scoped to a single device.
    device_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=True,
    )
    # Optional group scope; survives group deletion as NULL (SET NULL).
    group_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("device_groups.id", ondelete="SET NULL"),
        nullable=True,
    )
    name: Mapped[str] = mapped_column(Text, nullable=False)
    # Metric identifier to evaluate — free-form text, not constrained by the DB.
    metric: Mapped[str] = mapped_column(Text, nullable=False)
    # Comparison operator string; semantics enforced by the evaluator, not here.
    operator: Mapped[str] = mapped_column(Text, nullable=False)
    threshold: Mapped[float] = mapped_column(Numeric, nullable=False)
    # Consecutive breaching polls required before the alert fires (default 1).
    duration_polls: Mapped[int] = mapped_column(Integer, nullable=False, default=1, server_default="1")
    # Severity label — free-form Text; allowed values presumably validated upstream (confirm).
    severity: Mapped[str] = mapped_column(Text, nullable=False)
    enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, server_default="true")
    # Marks built-in default rules — presumably seeded per tenant; verify in seeding code.
    is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false")
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    def __repr__(self) -> str:
        return f"<AlertRule id={self.id} name={self.name!r} metric={self.metric}>"
class NotificationChannel(Base):
    """Email, webhook, or Slack notification destination.

    All destination fields are nullable — presumably only the fields
    matching channel_type are populated for a given row (confirm against
    the channel service).
    """
    __tablename__ = "notification_channels"
    # Surrogate PK: uuid4 for ORM inserts, gen_random_uuid() for raw SQL.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Owning tenant; rows are removed with the tenant (CASCADE).
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    name: Mapped[str] = mapped_column(Text, nullable=False)
    channel_type: Mapped[str] = mapped_column(Text, nullable=False) # "email", "webhook", or "slack"
    # SMTP fields (email channels)
    smtp_host: Mapped[str | None] = mapped_column(Text, nullable=True)
    smtp_port: Mapped[int | None] = mapped_column(Integer, nullable=True)
    smtp_user: Mapped[str | None] = mapped_column(Text, nullable=True)
    smtp_password: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True) # AES-256-GCM encrypted
    smtp_use_tls: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false")
    from_address: Mapped[str | None] = mapped_column(Text, nullable=True)
    to_address: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Webhook fields
    webhook_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Slack fields
    slack_webhook_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    # OpenBao Transit ciphertext (dual-write migration) — set alongside
    # smtp_password while credentials migrate to Transit encryption.
    smtp_password_transit: Mapped[str | None] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    def __repr__(self) -> str:
        return f"<NotificationChannel id={self.id} name={self.name!r} type={self.channel_type}>"
class AlertRuleChannel(Base):
    """Many-to-many association between alert rules and notification channels.

    Uses a composite primary key (rule_id, channel_id) with no surrogate id;
    deleting either side removes the link via CASCADE.
    """
    __tablename__ = "alert_rule_channels"
    rule_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("alert_rules.id", ondelete="CASCADE"),
        primary_key=True,
    )
    channel_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("notification_channels.id", ondelete="CASCADE"),
        primary_key=True,
    )
class AlertEvent(Base):
    """Record of an alert firing, resolving, or flapping.
    rule_id is NULL for system-level alerts (e.g., device offline).
    """
    __tablename__ = "alert_events"
    # Surrogate PK: uuid4 for ORM inserts, gen_random_uuid() for raw SQL.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # SET NULL keeps the event for history even after its rule is deleted.
    rule_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("alert_rules.id", ondelete="SET NULL"),
        nullable=True,
    )
    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=False,
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    status: Mapped[str] = mapped_column(Text, nullable=False) # "firing", "resolved", "flapping"
    severity: Mapped[str] = mapped_column(Text, nullable=False)
    # Metric/value/threshold are NULL-able — presumably absent for
    # system-level alerts that have no rule; confirm against the evaluator.
    metric: Mapped[str | None] = mapped_column(Text, nullable=True)
    value: Mapped[float | None] = mapped_column(Numeric, nullable=True)
    threshold: Mapped[float | None] = mapped_column(Numeric, nullable=True)
    message: Mapped[str | None] = mapped_column(Text, nullable=True)
    is_flapping: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false")
    # Acknowledgement bookkeeping; acknowledged_by survives user deletion as NULL.
    acknowledged_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    acknowledged_by: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("users.id", ondelete="SET NULL"),
        nullable=True,
    )
    # NOTE(review): looks like notifications are suppressed until this
    # timestamp — confirm against the notification dispatcher.
    silenced_until: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    fired_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # NULL while the alert is still active.
    resolved_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    def __repr__(self) -> str:
        return f"<AlertEvent id={self.id} status={self.status} severity={self.severity}>"

View File

@@ -0,0 +1,60 @@
"""API key ORM model for tenant-scoped programmatic access."""
import uuid
from datetime import datetime
from typing import Optional
from sqlalchemy import DateTime, ForeignKey, Text, func
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class ApiKey(Base):
    """Tracks API keys for programmatic access to the portal.
    Keys are stored as SHA-256 hashes (never plaintext).
    Scoped permissions limit what each key can do.
    Revocation is soft-delete (sets revoked_at, row preserved for audit).
    """
    __tablename__ = "api_keys"
    # Surrogate PK: uuid4 for ORM inserts, gen_random_uuid() for raw SQL.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    # The user the key acts on behalf of; key dies with the user (CASCADE).
    user_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
    )
    name: Mapped[str] = mapped_column(Text, nullable=False)
    # Non-secret leading fragment of the key — presumably shown in the UI
    # for identification; confirm against the key-management endpoints.
    key_prefix: Mapped[str] = mapped_column(Text, nullable=False)
    # SHA-256 hash of the full key (per class docstring); unique lookup column.
    key_hash: Mapped[str] = mapped_column(Text, nullable=False, unique=True)
    # Scoped permissions as a JSON array; defaults to no scopes.
    scopes: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'::jsonb")
    # NULL = key never expires.
    expires_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    last_used_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # Soft-delete marker: non-NULL means the key is revoked (row kept for audit).
    revoked_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    def __repr__(self) -> str:
        return f"<ApiKey id={self.id} name={self.name} prefix={self.key_prefix}>"

View File

@@ -0,0 +1,59 @@
"""Audit log model for centralized audit trail."""
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import DateTime, ForeignKey, String, Text, func
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class AuditLog(Base):
    """Records all auditable actions in the system (config changes, CRUD, auth events)."""
    __tablename__ = "audit_logs"
    # Surrogate PK: uuid4 for ORM inserts, gen_random_uuid() for raw SQL.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Indexed: audit queries are presumably tenant-scoped; confirm query patterns.
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # SET NULL preserves the audit row after the acting user is deleted.
    user_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("users.id", ondelete="SET NULL"),
        nullable=True,
    )
    action: Mapped[str] = mapped_column(String(100), nullable=False)
    # Free-form subject descriptors — resource_id is Text-like so it can hold
    # non-UUID identifiers as well.
    resource_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
    resource_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
    device_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="SET NULL"),
        nullable=True,
    )
    # Structured event payload; defaults to an empty JSON object.
    details: Mapped[dict[str, Any]] = mapped_column(
        JSONB,
        nullable=False,
        server_default="{}",
    )
    # Transit-encrypted details JSON (vault:v1:...) — set when details are encrypted
    encrypted_details: Mapped[str | None] = mapped_column(Text, nullable=True)
    # 45 chars accommodates the longest IPv6 textual form (incl. IPv4-mapped).
    ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    def __repr__(self) -> str:
        return f"<AuditLog id={self.id} action={self.action!r} tenant_id={self.tenant_id}>"

View File

@@ -0,0 +1,140 @@
"""Certificate Authority and Device Certificate ORM models.
Supports the Internal Certificate Authority feature:
- CertificateAuthority: one per tenant, stores encrypted CA private key + public cert
- DeviceCertificate: per-device signed certificate with lifecycle status tracking
"""
import uuid
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, LargeBinary, String, Text, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class CertificateAuthority(Base):
    """Per-tenant root Certificate Authority.
    Each tenant has at most one CA. The CA private key is encrypted with
    AES-256-GCM before storage (using the same pattern as device credentials).
    The public cert_pem is not sensitive and can be distributed freely.
    """
    __tablename__ = "certificate_authorities"
    # Surrogate PK: uuid4 for ORM inserts, gen_random_uuid() for raw SQL.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # unique=True enforces the "at most one CA per tenant" invariant at DB level.
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        unique=True,
    )
    common_name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Public certificate in PEM form — safe to distribute.
    cert_pem: Mapped[str] = mapped_column(Text, nullable=False)
    # AES-256-GCM ciphertext of the CA private key (per class docstring).
    encrypted_private_key: Mapped[bytes] = mapped_column(
        LargeBinary, nullable=False
    )
    serial_number: Mapped[str] = mapped_column(String(64), nullable=False)
    # 95 chars fits colon-separated SHA-256 hex (64 digits + 31 colons) — TODO confirm format.
    fingerprint_sha256: Mapped[str] = mapped_column(String(95), nullable=False)
    # Certificate validity window.
    not_valid_before: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    not_valid_after: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    # OpenBao Transit ciphertext (dual-write migration)
    encrypted_private_key_transit: Mapped[str | None] = mapped_column(
        Text, nullable=True
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    def __repr__(self) -> str:
        return (
            f"<CertificateAuthority id={self.id} "
            f"cn={self.common_name!r} tenant={self.tenant_id}>"
        )
class DeviceCertificate(Base):
    """Per-device TLS certificate signed by the tenant's CA.
    Status lifecycle:
        issued -> deploying -> deployed -> expiring -> expired
                         \\-> revoked
                         \\-> superseded (when rotated)
    """
    __tablename__ = "device_certificates"
    # Surrogate PK: uuid4 for ORM inserts, gen_random_uuid() for raw SQL.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=False,
    )
    # The issuing CA; certificates are removed with their CA (CASCADE).
    ca_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("certificate_authorities.id", ondelete="CASCADE"),
        nullable=False,
    )
    common_name: Mapped[str] = mapped_column(String(255), nullable=False)
    serial_number: Mapped[str] = mapped_column(String(64), nullable=False)
    # 95 chars fits colon-separated SHA-256 hex (64 digits + 31 colons) — TODO confirm format.
    fingerprint_sha256: Mapped[str] = mapped_column(String(95), nullable=False)
    cert_pem: Mapped[str] = mapped_column(Text, nullable=False)
    # Encrypted device private key (see CertificateAuthority for the pattern).
    encrypted_private_key: Mapped[bytes] = mapped_column(
        LargeBinary, nullable=False
    )
    # Certificate validity window.
    not_valid_before: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    not_valid_after: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    # OpenBao Transit ciphertext (dual-write migration)
    encrypted_private_key_transit: Mapped[str | None] = mapped_column(
        Text, nullable=True
    )
    # Lifecycle status (see class docstring); new rows start as 'issued'.
    status: Mapped[str] = mapped_column(
        String(20), nullable=False, server_default="issued"
    )
    deployed_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # NOTE(review): no onupdate=func.now() here — updated_at must be
    # maintained by application code; confirm (ConfigTemplate does use onupdate).
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    def __repr__(self) -> str:
        return (
            f"<DeviceCertificate id={self.id} "
            f"cn={self.common_name!r} status={self.status}>"
        )

View File

@@ -0,0 +1,178 @@
"""SQLAlchemy models for config backup tables."""
import uuid
from datetime import datetime
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, LargeBinary, SmallInteger, String, Text, UniqueConstraint, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
class ConfigBackupRun(Base):
    """Metadata for a single config backup run.
    The actual config content (export.rsc and backup.bin) lives in the tenant's
    bare git repository at GIT_STORE_PATH/{tenant_id}.git. This table provides
    the timeline view and per-run metadata without duplicating file content in
    PostgreSQL.
    """
    __tablename__ = "config_backup_runs"
    # Surrogate PK: uuid4 for ORM inserts, gen_random_uuid() for raw SQL.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Indexed: the timeline view is presumably queried per device; confirm.
    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Git commit hash in the tenant's bare repo where this backup is stored.
    commit_sha: Mapped[str] = mapped_column(Text, nullable=False)
    # Trigger type: 'scheduled' | 'manual' | 'pre-restore' | 'checkpoint' | 'config-change'
    trigger_type: Mapped[str] = mapped_column(String(20), nullable=False)
    # Lines added/removed vs the prior export.rsc for this device.
    # NULL for the first backup (no prior version to diff against).
    lines_added: Mapped[int | None] = mapped_column(Integer, nullable=True)
    lines_removed: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Encryption metadata: NULL=plaintext, 1=client-side AES-GCM, 2=OpenBao Transit
    encryption_tier: Mapped[int | None] = mapped_column(SmallInteger, nullable=True)
    # 12-byte AES-GCM nonce for Tier 1 (client-side) backups; NULL for plaintext/Transit
    encryption_nonce: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
    # When the backup run was recorded (DB clock).
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    def __repr__(self) -> str:
        return (
            f"<ConfigBackupRun id={self.id} device_id={self.device_id} "
            f"trigger={self.trigger_type!r} sha={self.commit_sha[:8]!r}>"
        )
class ConfigBackupSchedule(Base):
    """Per-tenant default and per-device override backup schedule config.
    A row with device_id=NULL is the tenant-level default (daily at 2am).
    A row with a specific device_id overrides the tenant default for that device.
    """
    __tablename__ = "config_backup_schedules"
    # At most one schedule row per (tenant, device) pair — including the
    # device_id=NULL tenant default.
    __table_args__ = (
        UniqueConstraint("tenant_id", "device_id", name="uq_backup_schedule_tenant_device"),
    )
    # Surrogate PK: uuid4 for ORM inserts, gen_random_uuid() for raw SQL.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # NULL = tenant-level default schedule; non-NULL = device-specific override.
    device_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=True,
    )
    # Standard cron expression (5 fields). Default: daily at 2am UTC.
    cron_expression: Mapped[str] = mapped_column(
        String(100),
        nullable=False,
        default="0 2 * * *",
        server_default="0 2 * * *",
    )
    enabled: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=True,
        server_default="TRUE",
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    def __repr__(self) -> str:
        scope = f"device={self.device_id}" if self.device_id else f"tenant={self.tenant_id}"
        return f"<ConfigBackupSchedule {scope} cron={self.cron_expression!r} enabled={self.enabled}>"
class ConfigPushOperation(Base):
    """Tracks pending two-phase config push operations for panic-revert recovery.
    Before pushing a config, a row is inserted with status='pending_verification'.
    If the API pod restarts during the 60-second verification window, the startup
    handler checks this table and either commits (deletes the RouterOS scheduler
    job) or marks the operation as 'failed'. This prevents the panic-revert
    scheduler from firing and reverting a successful push after an API restart.
    See Pitfall 6 in 04-RESEARCH.md for the full failure scenario.
    """
    __tablename__ = "config_push_operations"
    # Surrogate PK: uuid4 for ORM inserts, gen_random_uuid() for raw SQL.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Indexed: recovery scans presumably look up pending operations per device; confirm.
    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Git commit SHA we'd revert to if the push fails.
    pre_push_commit_sha: Mapped[str] = mapped_column(Text, nullable=False)
    # RouterOS scheduler job name created on the device for panic-revert.
    scheduler_name: Mapped[str] = mapped_column(String(255), nullable=False)
    # 'pending_verification' | 'committed' | 'reverted' | 'failed'
    status: Mapped[str] = mapped_column(
        String(30),
        nullable=False,
        default="pending_verification",
        server_default="pending_verification",
    )
    started_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # NULL while the operation is still in its verification window.
    completed_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
    )
    def __repr__(self) -> str:
        return (
            f"<ConfigPushOperation id={self.id} device_id={self.device_id} "
            f"status={self.status!r}>"
        )

View File

@@ -0,0 +1,153 @@
"""Config template, template tag, and template push job models."""
import uuid
from datetime import datetime
from sqlalchemy import (
DateTime,
Float,
ForeignKey,
String,
Text,
UniqueConstraint,
func,
)
from sqlalchemy.dialects.postgresql import JSON, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
class ConfigTemplate(Base):
    """Reusable device configuration template, unique per (tenant, name).

    ``content`` holds the template body and ``variables`` its substitution
    parameters — presumably a list of variable descriptors consumed by the
    rendering service; confirm the exact schema there.
    """
    __tablename__ = "config_templates"
    # Template names are unique within a tenant, not globally.
    __table_args__ = (
        UniqueConstraint("tenant_id", "name", name="uq_config_templates_tenant_name"),
    )
    # Surrogate PK: uuid4 for ORM inserts, gen_random_uuid() for raw SQL.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(Text, nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    content: Mapped[str] = mapped_column(Text, nullable=False)
    variables: Mapped[list] = mapped_column(JSON, nullable=False, default=list, server_default="[]")
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # onupdate keeps updated_at current automatically on ORM updates.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )
    # Relationships
    tenant: Mapped["Tenant"] = relationship("Tenant") # type: ignore[name-defined]
    # Tags are owned by the template: orphaned/removed tags are deleted.
    tags: Mapped[list["ConfigTemplateTag"]] = relationship(
        "ConfigTemplateTag", back_populates="template", cascade="all, delete-orphan"
    )
    def __repr__(self) -> str:
        return f"<ConfigTemplate id={self.id} name={self.name!r} tenant_id={self.tenant_id}>"
class ConfigTemplateTag(Base):
    """Tag attached to a config template; names are unique per template."""
    __tablename__ = "config_template_tags"
    __table_args__ = (
        UniqueConstraint("template_id", "name", name="uq_config_template_tags_template_name"),
    )
    # Surrogate PK: uuid4 for ORM inserts, gen_random_uuid() for raw SQL.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Denormalized tenant reference — presumably for RLS/tenant scoping; confirm.
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    name: Mapped[str] = mapped_column(String(100), nullable=False)
    # Tags die with their template (CASCADE).
    template_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("config_templates.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Relationships
    template: Mapped["ConfigTemplate"] = relationship(
        "ConfigTemplate", back_populates="tags"
    )
    def __repr__(self) -> str:
        return f"<ConfigTemplateTag id={self.id} name={self.name!r} template_id={self.template_id}>"
class TemplatePushJob(Base):
    """A single rendered-template push to one device.

    ``rendered_content`` snapshots the fully rendered config so the job
    remains reproducible even if the source template is later edited or
    deleted (template_id goes NULL via SET NULL).
    """
    __tablename__ = "template_push_jobs"
    # Surrogate PK: uuid4 for ORM inserts, gen_random_uuid() for raw SQL.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    # SET NULL keeps the job record after its template is deleted.
    template_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("config_templates.id", ondelete="SET NULL"),
        nullable=True,
    )
    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=False,
    )
    # Groups jobs belonging to one multi-device rollout — no FK, presumably
    # a client-generated batch id; confirm against the rollout service.
    rollout_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        nullable=True,
    )
    rendered_content: Mapped[str] = mapped_column(Text, nullable=False)
    # Job state; new rows start as 'pending'. Allowed values enforced by
    # application code, not the DB.
    status: Mapped[str] = mapped_column(
        Text,
        nullable=False,
        default="pending",
        server_default="pending",
    )
    # Backup commit taken before the push — used to roll back on failure; confirm.
    pre_push_backup_sha: Mapped[str | None] = mapped_column(Text, nullable=True)
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
    started_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    completed_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # Relationships
    template: Mapped["ConfigTemplate | None"] = relationship("ConfigTemplate")
    device: Mapped["Device"] = relationship("Device") # type: ignore[name-defined]
    def __repr__(self) -> str:
        return f"<TemplatePushJob id={self.id} status={self.status!r} device_id={self.device_id}>"

View File

@@ -0,0 +1,214 @@
"""Device, DeviceGroup, DeviceTag, and membership models."""
import uuid
from datetime import datetime
from enum import Enum
from sqlalchemy import (
Boolean,
DateTime,
Float,
ForeignKey,
Integer,
LargeBinary,
String,
Text,
UniqueConstraint,
func,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
class DeviceStatus(str, Enum):
    """Device connection status.

    Persisted on Device.status as the plain string value (String(20)
    column), not as a native database enum.
    """

    UNKNOWN = "unknown"
    ONLINE = "online"
    OFFLINE = "offline"
class Device(Base):
    """A managed RouterOS device belonging to a tenant.

    Hostnames are unique within a tenant. Credentials are stored only in
    encrypted form: an AES-256-GCM blob and/or an OpenBao Transit
    ciphertext (dual-write during migration).
    """

    __tablename__ = "devices"
    __table_args__ = (
        UniqueConstraint("tenant_id", "hostname", name="uq_devices_tenant_hostname"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    hostname: Mapped[str] = mapped_column(String(255), nullable=False)
    ip_address: Mapped[str] = mapped_column(String(45), nullable=False)  # IPv4 or IPv6
    # RouterOS API ports (plain and TLS); RouterOS defaults 8728/8729.
    api_port: Mapped[int] = mapped_column(Integer, default=8728, nullable=False)
    api_ssl_port: Mapped[int] = mapped_column(Integer, default=8729, nullable=False)
    # Hardware/software inventory — nullable until known.
    model: Mapped[str | None] = mapped_column(String(255), nullable=True)
    serial_number: Mapped[str | None] = mapped_column(String(255), nullable=True)
    firmware_version: Mapped[str | None] = mapped_column(String(100), nullable=True)
    routeros_version: Mapped[str | None] = mapped_column(String(100), nullable=True)
    routeros_major_version: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Last-seen health snapshot — nullable until known.
    uptime_seconds: Mapped[int | None] = mapped_column(Integer, nullable=True)
    last_cpu_load: Mapped[int | None] = mapped_column(Integer, nullable=True)
    last_memory_used_pct: Mapped[int | None] = mapped_column(Integer, nullable=True)
    architecture: Mapped[str | None] = mapped_column(Text, nullable=True)  # CPU arch (arm, arm64, mipsbe, etc.)
    preferred_channel: Mapped[str] = mapped_column(
        Text, default="stable", server_default="stable", nullable=False
    )  # Firmware release channel
    last_seen: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    # AES-256-GCM encrypted credentials (username + password JSON)
    encrypted_credentials: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
    # OpenBao Transit ciphertext (dual-write migration)
    encrypted_credentials_transit: Mapped[str | None] = mapped_column(Text, nullable=True)
    latitude: Mapped[float | None] = mapped_column(Float, nullable=True)
    longitude: Mapped[float | None] = mapped_column(Float, nullable=True)
    # Stored as the DeviceStatus string value; no DB-level enum.
    status: Mapped[str] = mapped_column(
        String(20),
        default=DeviceStatus.UNKNOWN.value,
        nullable=False,
    )
    tls_mode: Mapped[str] = mapped_column(
        String(20),
        default="auto",
        server_default="auto",
        nullable=False,
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    # Relationships
    tenant: Mapped["Tenant"] = relationship("Tenant", back_populates="devices")  # type: ignore[name-defined]
    group_memberships: Mapped[list["DeviceGroupMembership"]] = relationship(
        "DeviceGroupMembership", back_populates="device", cascade="all, delete-orphan"
    )
    tag_assignments: Mapped[list["DeviceTagAssignment"]] = relationship(
        "DeviceTagAssignment", back_populates="device", cascade="all, delete-orphan"
    )

    def __repr__(self) -> str:
        return f"<Device id={self.id} hostname={self.hostname!r} tenant_id={self.tenant_id}>"
class DeviceGroup(Base):
    """Named grouping of devices within a tenant.

    Group names are unique per tenant. Devices join groups via
    DeviceGroupMembership rows.
    """

    __tablename__ = "device_groups"
    __table_args__ = (
        UniqueConstraint("tenant_id", "name", name="uq_device_groups_tenant_name"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    preferred_channel: Mapped[str] = mapped_column(
        Text, default="stable", server_default="stable", nullable=False
    )  # Firmware release channel for the group
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    # Relationships
    tenant: Mapped["Tenant"] = relationship("Tenant", back_populates="device_groups")  # type: ignore[name-defined]
    memberships: Mapped[list["DeviceGroupMembership"]] = relationship(
        "DeviceGroupMembership", back_populates="group", cascade="all, delete-orphan"
    )

    def __repr__(self) -> str:
        return f"<DeviceGroup id={self.id} name={self.name!r} tenant_id={self.tenant_id}>"
class DeviceTag(Base):
    """Free-form label for devices; tag names are unique per tenant.

    Devices are tagged via DeviceTagAssignment rows.
    """

    __tablename__ = "device_tags"
    __table_args__ = (
        UniqueConstraint("tenant_id", "name", name="uq_device_tags_tenant_name"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(String(100), nullable=False)
    color: Mapped[str | None] = mapped_column(String(7), nullable=True)  # hex color e.g. #FF5733

    # Relationships
    tenant: Mapped["Tenant"] = relationship("Tenant", back_populates="device_tags")  # type: ignore[name-defined]
    assignments: Mapped[list["DeviceTagAssignment"]] = relationship(
        "DeviceTagAssignment", back_populates="tag", cascade="all, delete-orphan"
    )

    def __repr__(self) -> str:
        return f"<DeviceTag id={self.id} name={self.name!r} tenant_id={self.tenant_id}>"
class DeviceGroupMembership(Base):
    """Association row linking a device to a device group.

    Composite primary key (device_id, group_id); both FKs cascade on
    delete so rows disappear with either side.
    """

    __tablename__ = "device_group_memberships"

    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        primary_key=True,
    )
    group_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("device_groups.id", ondelete="CASCADE"),
        primary_key=True,
    )

    # Relationships
    device: Mapped["Device"] = relationship("Device", back_populates="group_memberships")
    group: Mapped["DeviceGroup"] = relationship("DeviceGroup", back_populates="memberships")
class DeviceTagAssignment(Base):
    """Association row linking a device to a tag.

    Composite primary key (device_id, tag_id); both FKs cascade on
    delete so rows disappear with either side.
    """

    __tablename__ = "device_tag_assignments"

    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        primary_key=True,
    )
    tag_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("device_tags.id", ondelete="CASCADE"),
        primary_key=True,
    )

    # Relationships
    device: Mapped["Device"] = relationship("Device", back_populates="tag_assignments")
    tag: Mapped["DeviceTag"] = relationship("DeviceTag", back_populates="assignments")

View File

@@ -0,0 +1,102 @@
"""Firmware version tracking and upgrade job ORM models."""
import uuid
from datetime import datetime
from sqlalchemy import (
BigInteger,
Boolean,
DateTime,
Integer,
Text,
UniqueConstraint,
func,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import ForeignKey
from app.database import Base
class FirmwareVersion(Base):
    """Cached firmware version from MikroTik download server or poller discovery.

    Not tenant-scoped — firmware versions are global data shared across all tenants.
    """

    __tablename__ = "firmware_versions"
    __table_args__ = (
        # NOTE(review): unnamed constraint, unlike the named uq_* constraints
        # in sibling models — relies on the dialect's auto-generated name.
        UniqueConstraint("architecture", "channel", "version"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    architecture: Mapped[str] = mapped_column(Text, nullable=False)
    channel: Mapped[str] = mapped_column(Text, nullable=False)  # "stable", "long-term", "testing"
    version: Mapped[str] = mapped_column(Text, nullable=False)
    # Upstream download URL for the .npk package.
    npk_url: Mapped[str] = mapped_column(Text, nullable=False)
    # Local mirror path once downloaded; NULL until fetched.
    npk_local_path: Mapped[str | None] = mapped_column(Text, nullable=True)
    npk_size_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
    # When this entry was last refreshed from the upstream source.
    checked_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        return f"<FirmwareVersion arch={self.architecture} ch={self.channel} ver={self.version}>"
class FirmwareUpgradeJob(Base):
    """Tracks a firmware upgrade operation for a single device.

    Multiple jobs can share a rollout_group_id for mass upgrades.
    """

    __tablename__ = "firmware_upgrade_jobs"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=False,
    )
    # Groups jobs belonging to the same mass-upgrade rollout; plain UUID, no FK.
    rollout_group_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        nullable=True,
    )
    target_version: Mapped[str] = mapped_column(Text, nullable=False)
    architecture: Mapped[str] = mapped_column(Text, nullable=False)
    channel: Mapped[str] = mapped_column(Text, nullable=False)
    # Lifecycle state; starts as "pending" (other states not visible here).
    status: Mapped[str] = mapped_column(
        Text, nullable=False, default="pending", server_default="pending"
    )
    # Checksum of the backup taken before upgrading — presumably for
    # rollback; confirm against the upgrade service.
    pre_upgrade_backup_sha: Mapped[str | None] = mapped_column(Text, nullable=True)
    scheduled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Operator acknowledgement flag for major-version jumps.
    confirmed_major_upgrade: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False, server_default="false"
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        return f"<FirmwareUpgradeJob id={self.id} status={self.status} target={self.target_version}>"

View File

@@ -0,0 +1,134 @@
"""Key set and key access log models for zero-knowledge architecture."""
import uuid
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, Integer, LargeBinary, Text, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
class UserKeySet(Base):
    """Encrypted key bundle for a user.

    Stores the RSA private key (wrapped by AUK), tenant vault key
    (wrapped by AUK), RSA public key, and key derivation salts.
    One key set per user (UNIQUE on user_id).
    """

    __tablename__ = "user_key_sets"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    user_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        unique=True,  # one key set per user
    )
    tenant_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=True,  # NULL for super_admin
    )
    # RSA private key, encrypted; nonce stored separately.
    encrypted_private_key: Mapped[bytes] = mapped_column(
        LargeBinary, nullable=False
    )
    private_key_nonce: Mapped[bytes] = mapped_column(
        LargeBinary, nullable=False
    )
    # Tenant vault key, encrypted; nonce stored separately.
    encrypted_vault_key: Mapped[bytes] = mapped_column(
        LargeBinary, nullable=False
    )
    vault_key_nonce: Mapped[bytes] = mapped_column(
        LargeBinary, nullable=False
    )
    # RSA public key (stored unencrypted).
    public_key: Mapped[bytes] = mapped_column(
        LargeBinary, nullable=False
    )
    # Key-derivation parameters; defaults set at the DB level.
    pbkdf2_iterations: Mapped[int] = mapped_column(
        Integer,
        server_default=func.literal_column("650000"),
        nullable=False,
    )
    pbkdf2_salt: Mapped[bytes] = mapped_column(
        LargeBinary, nullable=False
    )
    hkdf_salt: Mapped[bytes] = mapped_column(
        LargeBinary, nullable=False
    )
    # Incremented on key rotation (per the rotation workflow — confirm).
    key_version: Mapped[int] = mapped_column(
        Integer,
        server_default=func.literal_column("1"),
        nullable=False,
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # NOTE(review): unlike sibling models, no onupdate=func.now() here —
    # updated_at will not refresh on ORM updates unless set elsewhere; confirm.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    # Relationships
    user: Mapped["User"] = relationship("User")  # type: ignore[name-defined]
    tenant: Mapped["Tenant | None"] = relationship("Tenant")  # type: ignore[name-defined]

    def __repr__(self) -> str:
        return f"<UserKeySet id={self.id} user_id={self.user_id} version={self.key_version}>"
class KeyAccessLog(Base):
    """Immutable audit trail for key operations.

    Append-only: INSERT+SELECT only, no UPDATE/DELETE via RLS.
    """

    __tablename__ = "key_access_log"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    # SET NULL preserves the audit row even after the actor is deleted.
    user_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("users.id", ondelete="SET NULL"),
        nullable=True,
    )
    action: Mapped[str] = mapped_column(Text, nullable=False)
    resource_type: Mapped[str | None] = mapped_column(Text, nullable=True)
    resource_id: Mapped[str | None] = mapped_column(Text, nullable=True)
    key_version: Mapped[int | None] = mapped_column(Integer, nullable=True)
    ip_address: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Phase 29 extensions for device credential access tracking
    # NOTE(review): no ondelete on this FK (unlike the CASCADE/SET NULL FKs
    # above) — deleting a referenced device will be rejected by the DB
    # unless log rows are handled first; confirm this is intended.
    device_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id"),
        nullable=True,
    )
    # Free-text operator justification for credential access.
    justification: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Request correlation id for tracing related log entries.
    correlation_id: Mapped[str | None] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        return f"<KeyAccessLog id={self.id} action={self.action!r}>"

View File

@@ -0,0 +1,74 @@
"""Maintenance window ORM model for scheduled maintenance periods.
Maintenance windows allow operators to define time periods during which
alerts are suppressed for specific devices (or all devices in a tenant).
"""
import uuid
from datetime import datetime
from sqlalchemy import Boolean, DateTime, ForeignKey, Text, VARCHAR, func
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class MaintenanceWindow(Base):
    """Scheduled maintenance window with optional alert suppression.

    device_ids is a JSONB array of device UUID strings.
    An empty array means "all devices in tenant".
    """

    __tablename__ = "maintenance_windows"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    name: Mapped[str] = mapped_column(VARCHAR(200), nullable=False)
    # JSONB array of device UUID strings; [] targets the whole tenant.
    device_ids: Mapped[list] = mapped_column(
        JSONB,
        nullable=False,
        server_default="'[]'::jsonb",
    )
    start_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
    )
    end_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
    )
    # When True, alerts for targeted devices are suppressed for the window.
    suppress_alerts: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=True,
        server_default="true",
    )
    notes: Mapped[str | None] = mapped_column(Text, nullable=True)
    # SET NULL preserves the window if the creating user is deleted.
    created_by: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("users.id", ondelete="SET NULL"),
        nullable=True,
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # NOTE(review): no onupdate=func.now() — updated_at will not refresh on
    # ORM updates unless set elsewhere; confirm.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        return f"<MaintenanceWindow id={self.id} name={self.name!r}>"

View File

@@ -0,0 +1,49 @@
"""Tenant model — represents an MSP client organization."""
import uuid
from datetime import datetime
from sqlalchemy import DateTime, LargeBinary, Integer, String, Text, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
class Tenant(Base):
    """An MSP client organization — the root of all tenant-scoped data."""

    __tablename__ = "tenants"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Tenant names are globally unique.
    name: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    contact_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    # Zero-knowledge key management (Phase 28+29); all nullable until the
    # tenant is migrated to the key-management scheme.
    encrypted_vault_key: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
    vault_key_version: Mapped[int | None] = mapped_column(Integer, nullable=True)
    openbao_key_name: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Relationships — passive_deletes=True lets the DB ON DELETE CASCADE handle cleanup
    users: Mapped[list["User"]] = relationship("User", back_populates="tenant", passive_deletes=True)  # type: ignore[name-defined]
    devices: Mapped[list["Device"]] = relationship("Device", back_populates="tenant", passive_deletes=True)  # type: ignore[name-defined]
    device_groups: Mapped[list["DeviceGroup"]] = relationship("DeviceGroup", back_populates="tenant", passive_deletes=True)  # type: ignore[name-defined]
    device_tags: Mapped[list["DeviceTag"]] = relationship("DeviceTag", back_populates="tenant", passive_deletes=True)  # type: ignore[name-defined]

    def __repr__(self) -> str:
        return f"<Tenant id={self.id} name={self.name!r}>"

View File

@@ -0,0 +1,74 @@
"""User model with role-based access control."""
import uuid
from datetime import datetime
from enum import Enum
from sqlalchemy import Boolean, DateTime, ForeignKey, LargeBinary, SmallInteger, String, func, text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
class UserRole(str, Enum):
    """User roles with increasing privilege levels.

    Listed here from highest to lowest privilege; persisted on User.role
    as the plain string value, not a native DB enum.
    """

    SUPER_ADMIN = "super_admin"
    TENANT_ADMIN = "tenant_admin"
    OPERATOR = "operator"
    VIEWER = "viewer"
class User(Base):
    """Application user.

    tenant_id is NULL for super_admin users (portal-wide role). Supports
    two auth schemes during migration: legacy bcrypt (auth_version=1,
    hashed_password set) and SRP (auth_version=2, srp_* columns set).
    """

    __tablename__ = "users"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
    # Nullable: SRP-only users have no bcrypt hash.
    hashed_password: Mapped[str | None] = mapped_column(String(255), nullable=True)
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Stored as the UserRole string value; least-privileged default.
    role: Mapped[str] = mapped_column(
        String(50),
        nullable=False,
        default=UserRole.VIEWER.value,
    )
    # tenant_id is nullable for super_admin users (portal-wide role)
    tenant_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=True,
        index=True,
    )
    # SRP zero-knowledge authentication columns (nullable during migration period)
    srp_salt: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
    srp_verifier: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
    auth_version: Mapped[int] = mapped_column(
        SmallInteger, server_default=text("1"), nullable=False
    )  # 1=bcrypt legacy, 2=SRP
    must_upgrade_auth: Mapped[bool] = mapped_column(
        Boolean, server_default=text("false"), nullable=False
    )  # True for bcrypt users who need SRP upgrade
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    last_login: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    # Relationships
    tenant: Mapped["Tenant | None"] = relationship("Tenant", back_populates="users")  # type: ignore[name-defined]

    def __repr__(self) -> str:
        return f"<User id={self.id} email={self.email!r} role={self.role!r}>"

85
backend/app/models/vpn.py Normal file
View File

@@ -0,0 +1,85 @@
"""VPN configuration and peer models for WireGuard management."""
import uuid
from datetime import datetime
from typing import Optional
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, LargeBinary, String, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class VpnConfig(Base):
    """Per-tenant WireGuard server configuration (one row per tenant)."""

    __tablename__ = "vpn_config"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        unique=True,  # one config per tenant
    )
    # NOTE(review): stored as LargeBinary — whether the private key is
    # encrypted at rest is not visible from this model; confirm in the
    # VPN service layer.
    server_private_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
    # WireGuard public keys are 32 bytes base64-encoded (44 chars) — fits in 64.
    server_public_key: Mapped[str] = mapped_column(String(64), nullable=False)
    subnet: Mapped[str] = mapped_column(String(32), nullable=False, server_default="10.10.0.0/24")
    server_port: Mapped[int] = mapped_column(Integer, nullable=False, server_default="51820")
    server_address: Mapped[str] = mapped_column(String(32), nullable=False, server_default="10.10.0.1/24")
    # Externally reachable endpoint (host[:port]); NULL until configured.
    endpoint: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
    is_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="false")
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False, onupdate=func.now()
    )

    # Peers are queried separately via tenant_id — no ORM relationship needed
class VpnPeer(Base):
    """WireGuard peer representing a device's VPN connection.

    At most one peer per device (UNIQUE on device_id).
    """

    __tablename__ = "vpn_peers"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=False,
        unique=True,  # one VPN peer per device
    )
    # NOTE(review): stored as LargeBinary — whether the private key is
    # encrypted at rest is not visible from this model; confirm in the
    # VPN service layer.
    peer_private_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
    peer_public_key: Mapped[str] = mapped_column(String(64), nullable=False)
    # Optional WireGuard preshared key for this peer.
    preshared_key: Mapped[Optional[bytes]] = mapped_column(LargeBinary, nullable=True)
    # IP assigned to the peer inside the tenant's VPN subnet.
    assigned_ip: Mapped[str] = mapped_column(String(32), nullable=False)
    # Extra AllowedIPs entries beyond the assigned address, if any.
    additional_allowed_ips: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
    is_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="true")
    last_handshake: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False, onupdate=func.now()
    )

    # Config is queried separately via tenant_id — no ORM relationship needed

View File

@@ -0,0 +1,140 @@
"""Prometheus metrics and health check infrastructure.
Provides:
- setup_instrumentator(): Configures Prometheus auto-instrumentation for FastAPI
- check_health_ready(): Verifies PostgreSQL, Redis, and NATS connectivity for readiness probes
"""
import asyncio
import time
import structlog
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator
logger = structlog.get_logger(__name__)
def setup_instrumentator(app: FastAPI) -> Instrumentator:
    """Mount Prometheus auto-instrumentation on the given FastAPI app.

    Every HTTP endpoint gets request counters, duration histograms, and an
    in-progress gauge, labelled by handler *template* (bounded cardinality).
    Health and metrics endpoints themselves are excluded. The /metrics
    endpoint is exposed at root level and kept out of the OpenAPI schema.

    Call this AFTER all routers are included so every route is captured.

    Returns:
        The configured Instrumentator instance.
    """
    metrics = Instrumentator(
        should_group_status_codes=False,
        should_ignore_untemplated=True,
        excluded_handlers=["/health", "/health/ready", "/metrics", "/api/health"],
        should_respect_env_var=False,
    )
    metrics.instrument(app)
    metrics.expose(app, include_in_schema=False, should_gzip=True)
    logger.info("prometheus instrumentation enabled", endpoint="/metrics")
    return metrics
async def check_health_ready() -> dict:
    """Check readiness by verifying all critical dependencies.

    Probes PostgreSQL, Redis, and NATS concurrently; each probe applies its
    own 5-second timeout, so the endpoint's worst-case latency is the
    slowest single probe (~5s) rather than the 15s sum of the sequential
    version. Probes are independent, so concurrency does not change results.

    Returns:
        dict with "status" ("healthy"|"unhealthy"), "version", and "checks"
        containing per-dependency results.
    """
    from app.config import settings

    # Each _check_* helper catches its own exceptions and returns a status
    # dict, so gather() cannot raise here.
    postgres_result, redis_result, nats_result = await asyncio.gather(
        _check_postgres(),
        _check_redis(settings.REDIS_URL),
        _check_nats(settings.NATS_URL),
    )
    checks: dict[str, dict] = {
        "postgres": postgres_result,
        "redis": redis_result,
        "nats": nats_result,
    }
    all_healthy = all(result["status"] == "up" for result in checks.values())
    return {
        "status": "healthy" if all_healthy else "unhealthy",
        "version": settings.APP_VERSION,
        "checks": checks,
    }
async def _check_postgres() -> dict:
    """Probe PostgreSQL through the admin engine (SELECT 1, 5s timeout)."""
    began = time.monotonic()
    try:
        from sqlalchemy import text
        from app.database import engine

        async with engine.connect() as conn:
            await asyncio.wait_for(conn.execute(text("SELECT 1")), timeout=5.0)
    except Exception as exc:
        logger.warning("health check: postgres failed", error=str(exc))
        return {
            "status": "down",
            "latency_ms": round((time.monotonic() - began) * 1000),
            "error": str(exc),
        }
    return {
        "status": "up",
        "latency_ms": round((time.monotonic() - began) * 1000),
        "error": None,
    }
async def _check_redis(redis_url: str) -> dict:
    """Probe Redis with a PING (5s connect and ping timeouts)."""
    began = time.monotonic()
    try:
        import redis.asyncio as aioredis

        client = aioredis.from_url(redis_url, socket_connect_timeout=5)
        try:
            await asyncio.wait_for(client.ping(), timeout=5.0)
        finally:
            # Always release the connection, even if the ping failed.
            await client.aclose()
    except Exception as exc:
        logger.warning("health check: redis failed", error=str(exc))
        return {
            "status": "down",
            "latency_ms": round((time.monotonic() - began) * 1000),
            "error": str(exc),
        }
    return {
        "status": "up",
        "latency_ms": round((time.monotonic() - began) * 1000),
        "error": None,
    }
async def _check_nats(nats_url: str) -> dict:
    """Probe NATS by opening a connection (5s timeout), then draining it."""
    began = time.monotonic()
    try:
        import nats

        nc = await asyncio.wait_for(nats.connect(nats_url), timeout=5.0)
        try:
            await nc.drain()
        except Exception:
            # Drain failure is irrelevant — connecting already proved liveness.
            pass
    except Exception as exc:
        logger.warning("health check: nats failed", error=str(exc))
        return {
            "status": "down",
            "latency_ms": round((time.monotonic() - began) * 1000),
            "error": str(exc),
        }
    return {
        "status": "up",
        "latency_ms": round((time.monotonic() - began) * 1000),
        "error": None,
    }

View File

@@ -0,0 +1 @@
"""FastAPI routers for all API endpoints."""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,172 @@
"""API key management endpoints.
Tenant-scoped routes under /api/tenants/{tenant_id}/api-keys:
- List all keys (active + revoked)
- Create new key (returns plaintext once)
- Revoke key (soft delete)
RBAC: tenant_admin or above for all operations.
RLS enforced via get_db() (app_user engine with tenant context).
"""
import uuid
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, ConfigDict
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db, set_tenant_context
from app.middleware.rbac import require_min_role
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.services.api_key_service import (
ALLOWED_SCOPES,
create_api_key,
list_api_keys,
revoke_api_key,
)
router = APIRouter(tags=["api-keys"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Verify the current user is allowed to access the given tenant.

    Super admins may act on any tenant (the RLS tenant context is pinned
    explicitly for them); other users must belong to the tenant, else 403.
    """
    if current_user.is_super_admin:
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )
# ---------------------------------------------------------------------------
# Request/response schemas
# ---------------------------------------------------------------------------
class ApiKeyCreate(BaseModel):
    """Request body for creating an API key."""

    model_config = ConfigDict(extra="forbid")  # reject unknown fields

    name: str
    scopes: list[str]  # validated against ALLOWED_SCOPES in the endpoint
    expires_at: Optional[datetime] = None  # None presumably means no expiry — confirm in api_key_service
class ApiKeyResponse(BaseModel):
    """Public view of an API key; never includes the secret itself."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    name: str
    key_prefix: str  # short identifying prefix of the key
    scopes: list[str]
    # Timestamps are ISO-8601 strings (see create_key's serialization).
    expires_at: Optional[str] = None
    last_used_at: Optional[str] = None
    created_at: str
    revoked_at: Optional[str] = None  # set once the key is revoked
class ApiKeyCreateResponse(ApiKeyResponse):
    """Extended response that includes the plaintext key (shown once)."""

    key: str  # full plaintext key — returned only at creation time
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.get("/tenants/{tenant_id}/api-keys", response_model=list[ApiKeyResponse])
async def list_keys(
    tenant_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
) -> list[dict]:
    """List all API keys (active and revoked) for a tenant."""
    await _check_tenant_access(current_user, tenant_id, db)
    records = await list_api_keys(db, tenant_id)
    # The response model wants string ids; stringify each record's UUID.
    return [{**record, "id": str(record["id"])} for record in records]
@router.post(
    "/tenants/{tenant_id}/api-keys",
    response_model=ApiKeyCreateResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_key(
    tenant_id: uuid.UUID,
    body: ApiKeyCreate,
    db: AsyncSession = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
) -> dict:
    """Create an API key for the tenant.

    The plaintext key appears only in this response and is never returned
    again. Scopes must be a non-empty subset of ALLOWED_SCOPES.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    # Reject any scope outside the allowed set.
    unknown = set(body.scopes) - ALLOWED_SCOPES
    if unknown:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid scopes: {', '.join(sorted(unknown))}. "
            f"Allowed: {', '.join(sorted(ALLOWED_SCOPES))}",
        )
    if not body.scopes:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="At least one scope is required.",
        )
    created = await create_api_key(
        db=db,
        tenant_id=tenant_id,
        user_id=current_user.user_id,
        name=body.name,
        scopes=body.scopes,
        expires_at=body.expires_at,
    )
    # Datetimes serialize to ISO-8601; last-used/revoked start empty.
    return {
        "id": str(created["id"]),
        "name": created["name"],
        "key_prefix": created["key_prefix"],
        "key": created["key"],
        "scopes": created["scopes"],
        "expires_at": created["expires_at"].isoformat() if created["expires_at"] else None,
        "last_used_at": None,
        "created_at": created["created_at"].isoformat() if created["created_at"] else None,
        "revoked_at": None,
    }
@router.delete("/tenants/{tenant_id}/api-keys/{key_id}", status_code=status.HTTP_200_OK)
async def revoke_key(
    tenant_id: uuid.UUID,
    key_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
) -> dict:
    """Revoke an API key (soft delete -- sets revoked_at timestamp)."""
    await _check_tenant_access(current_user, tenant_id, db)
    # revoke_api_key returns False when the key is missing or already revoked.
    if not await revoke_api_key(db, tenant_id, key_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="API key not found or already revoked.",
        )
    return {"status": "revoked", "key_id": str(key_id)}

View File

@@ -0,0 +1,294 @@
"""Audit log API endpoints.
Tenant-scoped routes under /api/tenants/{tenant_id}/ for:
- Paginated, filterable audit log listing
- CSV export of audit logs
RLS enforced via get_db() (app_user engine with tenant context).
RBAC: operator and above can view audit logs.
Phase 30: Audit log details are encrypted at rest via Transit (Tier 2).
When encrypted_details is set, the router decrypts via Transit on-demand
and returns the plaintext details in the response. Structural fields
(action, resource_type, timestamp, ip_address) are always plaintext.
"""
import asyncio
import csv
import io
import json
import logging
import uuid
from datetime import datetime
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy import and_, func, select, text
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db, set_tenant_context
from app.middleware.tenant_context import CurrentUser, get_current_user
# Module logger -- used for decrypt-failure warnings below.
logger = logging.getLogger(__name__)
router = APIRouter(tags=["audit-logs"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _check_tenant_access(
current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
"""Verify the current user is allowed to access the given tenant."""
if current_user.is_super_admin:
await set_tenant_context(db, str(tenant_id))
elif current_user.tenant_id != tenant_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to this tenant",
)
def _require_operator(current_user: CurrentUser) -> None:
"""Raise 403 if user does not have at least operator role."""
allowed = {"super_admin", "admin", "operator"}
if current_user.role not in allowed:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="At least operator role required to view audit logs.",
)
async def _decrypt_audit_details(
encrypted_details: str | None,
plaintext_details: dict[str, Any] | None,
tenant_id: str,
) -> dict[str, Any]:
"""Decrypt encrypted audit log details via Transit, falling back to plaintext.
Priority:
1. If encrypted_details is set, decrypt via Transit and parse as JSON.
2. If decryption fails, return plaintext details as fallback.
3. If neither available, return empty dict.
"""
if encrypted_details:
try:
from app.services.crypto import decrypt_data_transit
decrypted_json = await decrypt_data_transit(encrypted_details, tenant_id)
return json.loads(decrypted_json)
except Exception:
logger.warning(
"Failed to decrypt audit details for tenant %s, using plaintext fallback",
tenant_id,
exc_info=True,
)
# Fall through to plaintext
return plaintext_details if plaintext_details else {}
async def _decrypt_details_batch(
    rows: list[Any],
    tenant_id: str,
) -> list[dict[str, Any]]:
    """Decrypt the details of many audit rows concurrently.

    A semaphore caps in-flight Transit calls at 10 so a large export cannot
    overwhelm OpenBao. Rows without ciphertext resolve to their plaintext
    details immediately. Result order matches *rows*.
    """
    gate = asyncio.Semaphore(10)  # max concurrent Transit decryptions

    async def _one(row: Any) -> dict[str, Any]:
        async with gate:
            return await _decrypt_audit_details(
                row.get("encrypted_details"),
                row.get("details"),
                tenant_id,
            )

    tasks = [_one(row) for row in rows]
    return list(await asyncio.gather(*tasks))
# ---------------------------------------------------------------------------
# Response models
# ---------------------------------------------------------------------------
class AuditLogItem(BaseModel):
    """One audit-log entry as returned by the list endpoint."""

    id: str
    user_email: Optional[str] = None  # joined from users table; may be null
    action: str
    resource_type: Optional[str] = None
    resource_id: Optional[str] = None
    device_name: Optional[str] = None  # joined device hostname; may be null
    details: dict[str, Any] = {}  # decrypted (or plaintext-fallback) details; pydantic copies this default per instance
    ip_address: Optional[str] = None
    created_at: str  # ISO-8601 timestamp
class AuditLogResponse(BaseModel):
    """Paginated audit-log listing."""

    items: list[AuditLogItem]
    total: int  # total matching rows (before pagination)
    page: int
    per_page: int
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/audit-logs",
    response_model=AuditLogResponse,
    summary="List audit logs with pagination and filters",
)
async def list_audit_logs(
    tenant_id: uuid.UUID,
    page: int = Query(default=1, ge=1),
    per_page: int = Query(default=50, ge=1, le=100),
    action: Optional[str] = Query(default=None),
    user_id: Optional[uuid.UUID] = Query(default=None),
    device_id: Optional[uuid.UUID] = Query(default=None),
    date_from: Optional[datetime] = Query(default=None),
    date_to: Optional[datetime] = Query(default=None),
    format: Optional[str] = Query(default=None, description="Set to 'csv' for CSV export"),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> Any:
    """List (or export) audit logs for a tenant.

    Requires at least the operator role. With ``format=csv`` the full
    (unpaginated) filtered result set is returned as a CSV attachment;
    otherwise a paginated ``AuditLogResponse`` is returned. Encrypted
    detail payloads are decrypted via Transit before being returned.
    """
    _require_operator(current_user)
    await _check_tenant_access(current_user, tenant_id, db)
    # Build filter conditions using parameterized text fragments
    conditions = [text("a.tenant_id = :tenant_id")]
    params: dict[str, Any] = {"tenant_id": str(tenant_id)}
    if action:
        conditions.append(text("a.action = :action"))
        params["action"] = action
    if user_id:
        conditions.append(text("a.user_id = :user_id"))
        params["user_id"] = str(user_id)
    if device_id:
        conditions.append(text("a.device_id = :device_id"))
        params["device_id"] = str(device_id)
    if date_from:
        conditions.append(text("a.created_at >= :date_from"))
        params["date_from"] = date_from.isoformat()
    if date_to:
        conditions.append(text("a.created_at <= :date_to"))
        params["date_to"] = date_to.isoformat()
    where_clause = and_(*conditions)
    # Shared SELECT columns for data queries
    _data_columns = text(
        "a.id, u.email AS user_email, a.action, a.resource_type, "
        "a.resource_id, d.hostname AS device_name, a.details, "
        "a.encrypted_details, a.ip_address, a.created_at"
    )
    _data_from = text(
        "audit_logs a "
        "LEFT JOIN users u ON a.user_id = u.id "
        "LEFT JOIN devices d ON a.device_id = d.id"
    )
    # Count total. All filter conditions reference only the "a" alias, so
    # the user/device joins are not needed for the count.
    count_result = await db.execute(
        select(func.count()).select_from(text("audit_logs a")).where(where_clause),
        params,
    )
    total = count_result.scalar() or 0
    # CSV export -- no pagination limit
    if format == "csv":
        result = await db.execute(
            select(_data_columns)
            .select_from(_data_from)
            .where(where_clause)
            .order_by(text("a.created_at DESC")),
            params,
        )
        all_rows = result.mappings().all()
        # Decrypt encrypted details concurrently
        decrypted_details = await _decrypt_details_batch(
            all_rows, str(tenant_id)
        )
        output = io.StringIO()
        writer = csv.writer(output)
        writer.writerow([
            "ID", "User Email", "Action", "Resource Type",
            "Resource ID", "Device", "Details", "IP Address", "Timestamp",
        ])
        # Rows and their decrypted details are zipped positionally.
        for row, details in zip(all_rows, decrypted_details):
            details_str = json.dumps(details) if details else "{}"
            writer.writerow([
                str(row["id"]),
                row["user_email"] or "",
                row["action"],
                row["resource_type"] or "",
                row["resource_id"] or "",
                row["device_name"] or "",
                details_str,
                row["ip_address"] or "",
                str(row["created_at"]),
            ])
        output.seek(0)
        # The whole CSV is already in memory; stream it as one chunk.
        return StreamingResponse(
            iter([output.getvalue()]),
            media_type="text/csv",
            headers={"Content-Disposition": "attachment; filename=audit-logs.csv"},
        )
    # Paginated query
    offset = (page - 1) * per_page
    params["limit"] = per_page
    params["offset"] = offset
    result = await db.execute(
        select(_data_columns)
        .select_from(_data_from)
        .where(where_clause)
        .order_by(text("a.created_at DESC"))
        .limit(per_page)
        .offset(offset),
        params,
    )
    rows = result.mappings().all()
    # Decrypt encrypted details concurrently (skips rows without encrypted_details)
    decrypted_details = await _decrypt_details_batch(rows, str(tenant_id))
    items = [
        AuditLogItem(
            id=str(row["id"]),
            user_email=row["user_email"],
            action=row["action"],
            resource_type=row["resource_type"],
            resource_id=row["resource_id"],
            device_name=row["device_name"],
            details=details,
            ip_address=row["ip_address"],
            created_at=row["created_at"].isoformat() if row["created_at"] else "",
        )
        for row, details in zip(rows, decrypted_details)
    ]
    return AuditLogResponse(
        items=items,
        total=total,
        page=page,
        per_page=per_page,
    )

1052
backend/app/routers/auth.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,763 @@
"""Certificate Authority management API endpoints.
Provides the full certificate lifecycle for tenant CAs:
- CA initialization and info retrieval
- Per-device certificate signing
- Certificate deployment via NATS to Go poller (SFTP + RouterOS import)
- Bulk deployment across multiple devices
- Certificate rotation and revocation
RLS enforced via get_db() (app_user engine with tenant context).
RBAC: viewer = read-only (GET); tenant_admin and above = mutating actions.
"""
import json
import logging
import uuid
from datetime import datetime, timezone
import nats
import nats.aio.client
import nats.errors
import structlog
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from fastapi.responses import PlainTextResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import get_db, set_tenant_context
from app.middleware.rate_limit import limiter
from app.middleware.rbac import require_min_role
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.models.certificate import CertificateAuthority, DeviceCertificate
from app.models.device import Device
from app.schemas.certificate import (
BulkCertDeployRequest,
CACreateRequest,
CAResponse,
CertDeployResponse,
CertSignRequest,
DeviceCertResponse,
)
from app.services.audit_service import log_action
from app.services.ca_service import (
generate_ca,
get_ca_for_tenant,
get_cert_for_deploy,
get_device_certs,
sign_device_cert,
update_cert_status,
)
# Structured logger for certificate lifecycle events (ids logged as fields).
logger = structlog.get_logger(__name__)
router = APIRouter(tags=["certificates"])
# Module-level NATS connection for cert deployment (lazy initialized)
_nc: nats.aio.client.Client | None = None
async def _get_nats() -> nats.aio.client.Client:
    """Return the shared NATS connection, (re)connecting when needed.

    The connection is cached in the module-level ``_nc`` and re-established
    if it has been closed.
    """
    global _nc
    needs_connect = _nc is None or _nc.is_closed
    if needs_connect:
        _nc = await nats.connect(settings.NATS_URL)
        logger.info("Certificate NATS connection established")
    return _nc
async def _deploy_cert_via_nats(
    device_id: str,
    cert_pem: str,
    key_pem: str,
    cert_name: str,
    ssh_port: int = 22,
) -> dict:
    """Ask the Go poller (via NATS request/reply) to deploy a certificate.

    Args:
        device_id: Target device UUID string.
        cert_pem: PEM-encoded device certificate.
        key_pem: PEM-encoded device private key (decrypted).
        cert_name: Name for the cert on the device (e.g., "portal-device-cert").
        ssh_port: SSH port for SFTP upload (default 22).

    Returns:
        Dict with success, cert_name_on_device, and error fields. Never
        raises: timeouts and transport errors become {"success": False, ...}.
    """
    connection = await _get_nats()
    request_body = {
        "device_id": device_id,
        "cert_pem": cert_pem,
        "key_pem": key_pem,
        "cert_name": cert_name,
        "ssh_port": ssh_port,
    }
    try:
        # Generous timeout: SFTP upload plus RouterOS import can be slow.
        reply = await connection.request(
            f"cert.deploy.{device_id}",
            json.dumps(request_body).encode(),
            timeout=60.0,
        )
        return json.loads(reply.data)
    except nats.errors.TimeoutError:
        return {
            "success": False,
            "error": "Certificate deployment timed out -- device may be offline or unreachable",
        }
    except Exception as exc:
        logger.error("NATS cert deploy request failed", device_id=device_id, error=str(exc))
        return {"success": False, "error": str(exc)}
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _get_device_for_tenant(
    db: AsyncSession, device_id: uuid.UUID, current_user: CurrentUser
) -> Device:
    """Fetch a device by id, raising 404 when it is not visible.

    NOTE(review): despite the original docstring, no explicit tenant check
    is performed here -- the query relies on RLS (tenant-scoped session) to
    hide other tenants' devices, and ``current_user`` is currently unused.
    Confirm the RLS context is always set before this helper is called.
    """
    result = await db.execute(
        select(Device).where(Device.id == device_id)
    )
    device = result.scalar_one_or_none()
    if device is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Device {device_id} not found",
        )
    return device
async def _get_tenant_id(
current_user: CurrentUser,
db: AsyncSession,
tenant_id_override: uuid.UUID | None = None,
) -> uuid.UUID:
"""Extract tenant_id from the current user, handling super_admin.
Super admins must provide tenant_id_override (from query param).
Regular users use their own tenant_id.
"""
if current_user.is_super_admin:
if tenant_id_override is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Super admin must provide tenant_id query parameter.",
)
# Set RLS context for the selected tenant
await set_tenant_context(db, str(tenant_id_override))
return tenant_id_override
if current_user.tenant_id is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No tenant context available.",
)
return current_user.tenant_id
async def _get_cert_with_tenant_check(
    db: AsyncSession, cert_id: uuid.UUID, tenant_id: uuid.UUID
) -> DeviceCertificate:
    """Load a device certificate, raising 404 unless it belongs to *tenant_id*.

    RLS already scopes the query, but the tenant is re-checked defensively;
    a tenant mismatch gets the same 404 as a missing certificate.
    """
    lookup = await db.execute(
        select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
    )
    cert = lookup.scalar_one_or_none()
    # Treat "exists but wrong tenant" exactly like "missing".
    if cert is None or cert.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Certificate {cert_id} not found",
        )
    return cert
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post(
    "/ca",
    response_model=CAResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Initialize a Certificate Authority for the tenant",
)
@limiter.limit("5/minute")
async def create_ca(
    request: Request,
    body: CACreateRequest,
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
    db: AsyncSession = Depends(get_db),
) -> CAResponse:
    """Generate a self-signed root CA for the tenant.

    Each tenant may have at most one CA. Returns 409 if a CA already exists.
    An encryption key is passed to ``generate_ca`` -- presumably to encrypt
    the CA private key at rest; confirm in ca_service.
    """
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)
    # Check if CA already exists
    existing = await get_ca_for_tenant(db, tenant_id)
    if existing is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Tenant already has a Certificate Authority. Delete it before creating a new one.",
        )
    ca = await generate_ca(
        db,
        tenant_id,
        body.common_name,
        body.validity_years,
        settings.get_encryption_key_bytes(),
    )
    # Audit logging is best-effort: a logging failure must not fail the request.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "ca_create",
            resource_type="certificate_authority", resource_id=str(ca.id),
            details={"common_name": body.common_name, "validity_years": body.validity_years},
        )
    except Exception:
        pass
    logger.info("CA created", tenant_id=str(tenant_id), ca_id=str(ca.id))
    return CAResponse.model_validate(ca)
@router.get(
    "/ca",
    response_model=CAResponse,
    summary="Get tenant CA information",
)
async def get_ca(
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> CAResponse:
    """Return the tenant's CA public information (no private key)."""
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)
    authority = await get_ca_for_tenant(db, tenant_id)
    if authority is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No Certificate Authority configured for this tenant.",
        )
    return CAResponse.model_validate(authority)
@router.get(
    "/ca/pem",
    response_class=PlainTextResponse,
    summary="Download the CA public certificate (PEM)",
)
async def get_ca_pem(
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> PlainTextResponse:
    """Return the CA's public certificate in PEM format.

    Users can import this into their trust store to validate device connections.
    """
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)
    authority = await get_ca_for_tenant(db, tenant_id)
    if authority is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No Certificate Authority configured for this tenant.",
        )
    # Serve the public cert as a downloadable .pem attachment.
    download_headers = {"Content-Disposition": "attachment; filename=portal-ca.pem"}
    return PlainTextResponse(
        content=authority.cert_pem,
        media_type="application/x-pem-file",
        headers=download_headers,
    )
@router.post(
    "/sign",
    response_model=DeviceCertResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Sign a certificate for a device",
)
@limiter.limit("20/minute")
async def sign_cert(
    request: Request,
    body: CertSignRequest,
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
    db: AsyncSession = Depends(get_db),
) -> DeviceCertResponse:
    """Sign a per-device TLS certificate using the tenant's CA.

    The device must belong to the tenant. The cert uses CN=hostname, SAN=IP+DNS.
    Raises 404 when the device is missing or the tenant has no CA.
    """
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)
    # Verify device belongs to tenant (RLS enforces, but also get device data)
    device = await _get_device_for_tenant(db, body.device_id, current_user)
    # Get tenant CA
    ca = await get_ca_for_tenant(db, tenant_id)
    if ca is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No Certificate Authority configured. Initialize a CA first.",
        )
    cert = await sign_device_cert(
        db,
        ca,
        body.device_id,
        device.hostname,
        device.ip_address,
        body.validity_days,
        settings.get_encryption_key_bytes(),
    )
    # Best-effort audit logging -- never fail the request over it.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "cert_sign",
            resource_type="device_certificate", resource_id=str(cert.id),
            device_id=body.device_id,
            details={"hostname": device.hostname, "validity_days": body.validity_days},
        )
    except Exception:
        pass
    logger.info("Device cert signed", device_id=str(body.device_id), cert_id=str(cert.id))
    return DeviceCertResponse.model_validate(cert)
@router.post(
    "/{cert_id}/deploy",
    response_model=CertDeployResponse,
    summary="Deploy a signed certificate to a device",
)
@limiter.limit("20/minute")
async def deploy_cert(
    request: Request,
    cert_id: uuid.UUID,
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
    db: AsyncSession = Depends(get_db),
) -> CertDeployResponse:
    """Deploy a signed certificate to a device via NATS/SFTP.

    The Go poller receives the cert, uploads it via SFTP, imports it,
    and assigns it to the api-ssl service on the RouterOS device.

    Status flow: issued -> deploying -> deployed, rolling back to "issued"
    on failure. A remote failure is reported as success=False with HTTP
    200, so callers must inspect the ``success`` field.
    """
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)
    cert = await _get_cert_with_tenant_check(db, cert_id, tenant_id)
    # Update status to deploying
    try:
        await update_cert_status(db, cert_id, "deploying")
    except ValueError as e:
        # Invalid status transition (e.g. already deploying/revoked).
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=str(e),
        )
    # Get decrypted cert data for deployment
    try:
        cert_pem, key_pem, _ca_cert_pem = await get_cert_for_deploy(
            db, cert_id, settings.get_encryption_key_bytes()
        )
    except ValueError as e:
        # Rollback status
        await update_cert_status(db, cert_id, "issued")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to prepare cert for deployment: {e}",
        )
    # Flush DB changes before NATS call so deploying status is persisted
    await db.flush()
    # Send deployment command via NATS
    result = await _deploy_cert_via_nats(
        device_id=str(cert.device_id),
        cert_pem=cert_pem,
        key_pem=key_pem,
        cert_name="portal-device-cert",
    )
    if result.get("success"):
        # Update cert status to deployed
        await update_cert_status(db, cert_id, "deployed")
        # Update device tls_mode to portal_ca
        device_result = await db.execute(
            select(Device).where(Device.id == cert.device_id)
        )
        device = device_result.scalar_one_or_none()
        if device is not None:
            device.tls_mode = "portal_ca"
        # Best-effort audit logging -- never fail the request over it.
        try:
            await log_action(
                db, tenant_id, current_user.user_id, "cert_deploy",
                resource_type="device_certificate", resource_id=str(cert_id),
                device_id=cert.device_id,
                details={"cert_name_on_device": result.get("cert_name_on_device")},
            )
        except Exception:
            pass
        logger.info(
            "Certificate deployed successfully",
            cert_id=str(cert_id),
            device_id=str(cert.device_id),
            cert_name_on_device=result.get("cert_name_on_device"),
        )
        return CertDeployResponse(
            success=True,
            device_id=cert.device_id,
            cert_name_on_device=result.get("cert_name_on_device"),
        )
    else:
        # Rollback status to issued
        await update_cert_status(db, cert_id, "issued")
        logger.warning(
            "Certificate deployment failed",
            cert_id=str(cert_id),
            device_id=str(cert.device_id),
            error=result.get("error"),
        )
        return CertDeployResponse(
            success=False,
            device_id=cert.device_id,
            error=result.get("error"),
        )
@router.post(
    "/deploy/bulk",
    response_model=list[CertDeployResponse],
    summary="Bulk deploy certificates to multiple devices",
)
@limiter.limit("5/minute")
async def bulk_deploy(
    request: Request,
    body: BulkCertDeployRequest,
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
    db: AsyncSession = Depends(get_db),
) -> list[CertDeployResponse]:
    """Deploy certificates to multiple devices sequentially.

    For each device: signs a cert if none exists (status=issued), then deploys.
    Sequential deployment per project patterns (no concurrent NATS calls).
    Per-device failures are captured in the result list; the endpoint only
    fails outright when no CA is configured.
    """
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)
    # Get tenant CA
    ca = await get_ca_for_tenant(db, tenant_id)
    if ca is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No Certificate Authority configured. Initialize a CA first.",
        )
    results: list[CertDeployResponse] = []
    for device_id in body.device_ids:
        try:
            # Get device info
            device = await _get_device_for_tenant(db, device_id, current_user)
            # Check if device already has an issued cert
            existing_certs = await get_device_certs(db, tenant_id, device_id)
            issued_cert = None
            for c in existing_certs:
                if c.status == "issued":
                    issued_cert = c
                    break
            # Sign a new cert if none exists in issued state
            if issued_cert is None:
                issued_cert = await sign_device_cert(
                    db,
                    ca,
                    device_id,
                    device.hostname,
                    device.ip_address,
                    730,  # Default 2 years
                    settings.get_encryption_key_bytes(),
                )
                await db.flush()
            # Deploy the cert
            await update_cert_status(db, issued_cert.id, "deploying")
            cert_pem, key_pem, _ca_cert_pem = await get_cert_for_deploy(
                db, issued_cert.id, settings.get_encryption_key_bytes()
            )
            # Persist the "deploying" status before the (slow) NATS round-trip.
            await db.flush()
            result = await _deploy_cert_via_nats(
                device_id=str(device_id),
                cert_pem=cert_pem,
                key_pem=key_pem,
                cert_name="portal-device-cert",
            )
            if result.get("success"):
                await update_cert_status(db, issued_cert.id, "deployed")
                device.tls_mode = "portal_ca"
                results.append(CertDeployResponse(
                    success=True,
                    device_id=device_id,
                    cert_name_on_device=result.get("cert_name_on_device"),
                ))
            else:
                # Remote deployment failed: roll the cert back to issued.
                await update_cert_status(db, issued_cert.id, "issued")
                results.append(CertDeployResponse(
                    success=False,
                    device_id=device_id,
                    error=result.get("error"),
                ))
        except HTTPException as e:
            # e.g. device not found -- record and continue with the next device.
            results.append(CertDeployResponse(
                success=False,
                device_id=device_id,
                error=e.detail,
            ))
        except Exception as e:
            results.append(CertDeployResponse(
                success=False,
                device_id=device_id,
                error=str(e),
            ))
            logger.error("Bulk deploy error", device_id=str(device_id), error=str(e))
        # (results order matches body.device_ids)
    # Best-effort summary audit entry for the whole batch.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "cert_bulk_deploy",
            resource_type="device_certificate",
            details={
                "device_count": len(body.device_ids),
                "successful": sum(1 for r in results if r.success),
                "failed": sum(1 for r in results if not r.success),
            },
        )
    except Exception:
        pass
    return results
@router.get(
    "/devices",
    response_model=list[DeviceCertResponse],
    summary="List device certificates",
)
async def list_device_certs(
    device_id: uuid.UUID | None = Query(None, description="Filter by device ID"),
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[DeviceCertResponse]:
    """List the tenant's device certificates.

    Optionally filter by device_id. Excludes superseded certs.
    """
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)
    rows = await get_device_certs(db, tenant_id, device_id)
    return [DeviceCertResponse.model_validate(row) for row in rows]
@router.post(
    "/{cert_id}/revoke",
    response_model=DeviceCertResponse,
    summary="Revoke a device certificate",
)
@limiter.limit("5/minute")
async def revoke_cert(
    request: Request,
    cert_id: uuid.UUID,
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
    db: AsyncSession = Depends(get_db),
) -> DeviceCertResponse:
    """Revoke a device certificate and reset the device TLS mode to insecure.

    Raises 409 when the certificate cannot transition to "revoked".
    """
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)
    cert = await _get_cert_with_tenant_check(db, cert_id, tenant_id)
    try:
        updated_cert = await update_cert_status(db, cert_id, "revoked")
    except ValueError as e:
        # Invalid status transition for this certificate.
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=str(e),
        )
    # Reset device tls_mode to insecure
    device_result = await db.execute(
        select(Device).where(Device.id == cert.device_id)
    )
    device = device_result.scalar_one_or_none()
    if device is not None:
        device.tls_mode = "insecure"
    # Best-effort audit logging -- never fail the request over it.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "cert_revoke",
            resource_type="device_certificate", resource_id=str(cert_id),
            device_id=cert.device_id,
        )
    except Exception:
        pass
    logger.info("Certificate revoked", cert_id=str(cert_id), device_id=str(cert.device_id))
    return DeviceCertResponse.model_validate(updated_cert)
@router.post(
    "/{cert_id}/rotate",
    response_model=CertDeployResponse,
    summary="Rotate a device certificate",
)
@limiter.limit("5/minute")
async def rotate_cert(
    request: Request,
    cert_id: uuid.UUID,
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
    db: AsyncSession = Depends(get_db),
) -> CertDeployResponse:
    """Rotate a device certificate: supersede the old cert, sign a new one, and deploy it.

    This is equivalent to: mark old cert as superseded, sign new cert, deploy new cert.

    NOTE(review): if the deploy step fails, the new cert rolls back to
    "issued" but the old cert remains "superseded" -- the device is left
    with no deployed certificate. Confirm this is the intended behavior.
    """
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)
    old_cert = await _get_cert_with_tenant_check(db, cert_id, tenant_id)
    # Get the device for hostname/IP
    device_result = await db.execute(
        select(Device).where(Device.id == old_cert.device_id)
    )
    device = device_result.scalar_one_or_none()
    if device is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Device {old_cert.device_id} not found",
        )
    # Get tenant CA
    ca = await get_ca_for_tenant(db, tenant_id)
    if ca is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No Certificate Authority configured.",
        )
    # Mark old cert as superseded
    try:
        await update_cert_status(db, cert_id, "superseded")
    except ValueError as e:
        # Invalid status transition (e.g. cert already revoked/superseded).
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=str(e),
        )
    # Sign new cert
    new_cert = await sign_device_cert(
        db,
        ca,
        old_cert.device_id,
        device.hostname,
        device.ip_address,
        730,  # Default 2 years
        settings.get_encryption_key_bytes(),
    )
    await db.flush()
    # Deploy new cert
    await update_cert_status(db, new_cert.id, "deploying")
    cert_pem, key_pem, _ca_cert_pem = await get_cert_for_deploy(
        db, new_cert.id, settings.get_encryption_key_bytes()
    )
    # Persist the "deploying" status before the (slow) NATS round-trip.
    await db.flush()
    result = await _deploy_cert_via_nats(
        device_id=str(old_cert.device_id),
        cert_pem=cert_pem,
        key_pem=key_pem,
        cert_name="portal-device-cert",
    )
    if result.get("success"):
        await update_cert_status(db, new_cert.id, "deployed")
        device.tls_mode = "portal_ca"
        # Best-effort audit logging -- never fail the request over it.
        try:
            await log_action(
                db, tenant_id, current_user.user_id, "cert_rotate",
                resource_type="device_certificate", resource_id=str(new_cert.id),
                device_id=old_cert.device_id,
                details={
                    "old_cert_id": str(cert_id),
                    "cert_name_on_device": result.get("cert_name_on_device"),
                },
            )
        except Exception:
            pass
        logger.info(
            "Certificate rotated successfully",
            old_cert_id=str(cert_id),
            new_cert_id=str(new_cert.id),
            device_id=str(old_cert.device_id),
        )
        return CertDeployResponse(
            success=True,
            device_id=old_cert.device_id,
            cert_name_on_device=result.get("cert_name_on_device"),
        )
    else:
        # Rollback: mark new cert as issued (deploy failed)
        await update_cert_status(db, new_cert.id, "issued")
        logger.warning(
            "Certificate rotation deploy failed",
            new_cert_id=str(new_cert.id),
            device_id=str(old_cert.device_id),
            error=result.get("error"),
        )
        return CertDeployResponse(
            success=False,
            device_id=old_cert.device_id,
            error=result.get("error"),
        )

View File

@@ -0,0 +1,297 @@
"""
Client device discovery API endpoint.
Fetches ARP, DHCP lease, and wireless registration data from a RouterOS device
via the NATS command proxy, merges by MAC address, and returns a unified client list.
All routes are tenant-scoped under:
/api/tenants/{tenant_id}/devices/{device_id}/clients
RLS is enforced via get_db() (app_user engine with tenant context).
RBAC: viewer and above (read-only operation).
"""
import asyncio
import uuid
from datetime import datetime, timezone
from typing import Any
import structlog
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.middleware.rbac import require_min_role
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.models.device import Device
from app.services import routeros_proxy
logger = structlog.get_logger(__name__)
router = APIRouter(tags=["clients"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Ensure *current_user* may operate on *tenant_id*.

    Super admins may access any tenant; their DB session is re-pointed at the
    target tenant so row-level security applies to the right rows.  Everyone
    else must belong to the requested tenant or a 403 is raised.
    """
    if current_user.is_super_admin:
        # Super admins carry no tenant restriction — pin RLS to the target.
        from app.database import set_tenant_context

        await set_tenant_context(db, str(tenant_id))
    elif current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied: you do not belong to this tenant.",
        )
async def _check_device_online(
    db: AsyncSession, device_id: uuid.UUID
) -> Device:
    """Fetch the device and confirm it is online, returning it.

    Raises 404 when the device does not exist (or is hidden by RLS) and 409
    when it exists but is not currently connected.
    """
    query = select(Device).where(Device.id == device_id)  # type: ignore[arg-type]
    device = (await db.execute(query)).scalar_one_or_none()
    if device is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Device {device_id} not found",
        )
    if device.status != "online":
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Device is offline -- client discovery requires a live connection.",
        )
    return device
# ---------------------------------------------------------------------------
# MAC-address merge logic
# ---------------------------------------------------------------------------
def _normalize_mac(mac: str) -> str:
"""Normalize a MAC address to uppercase colon-separated format."""
return mac.strip().upper().replace("-", ":")
def _merge_client_data(
    arp_data: list[dict[str, Any]],
    dhcp_data: list[dict[str, Any]],
    wireless_data: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Combine ARP, DHCP lease, and wireless registration rows by MAC.

    ARP rows form the base records.  DHCP leases contribute hostnames (and
    yield stale records for leases with no ARP entry).  Wireless
    registrations contribute signal/rate/uptime metrics and mark the client
    as wireless.
    """
    # Index the two enrichment sources by normalized MAC.  A later row with
    # the same MAC overwrites an earlier one, as in the original build order.
    dhcp_by_mac: dict[str, dict[str, Any]] = {}
    for lease in dhcp_data:
        raw = lease.get("mac-address") or lease.get("active-mac-address", "")
        if raw:
            dhcp_by_mac[_normalize_mac(raw)] = lease

    wireless_by_mac: dict[str, dict[str, Any]] = {}
    for reg in wireless_data:
        raw = reg.get("mac-address", "")
        if raw:
            wireless_by_mac[_normalize_mac(raw)] = reg

    def _apply_wireless(record: dict[str, Any], mac: str) -> None:
        # Copy wireless metrics onto *record* when a registration exists.
        reg = wireless_by_mac.get(mac)
        if reg:
            record["signal_strength"] = reg.get("signal-strength") or None
            record["tx_rate"] = reg.get("tx-rate") or None
            record["rx_rate"] = reg.get("rx-rate") or None
            record["uptime"] = reg.get("uptime") or None

    merged: list[dict[str, Any]] = []
    seen: set[str] = set()

    # Pass 1: ARP entries are the authoritative base records.
    for entry in arp_data:
        raw = entry.get("mac-address", "")
        if not raw:
            continue
        mac = _normalize_mac(raw)
        if mac in seen:
            continue
        seen.add(mac)

        # ARP's "complete" flag distinguishes reachable vs stale neighbors.
        reachable = entry.get("complete", "true").lower() == "true"
        record: dict[str, Any] = {
            "mac": mac,
            "ip": entry.get("address", ""),
            "interface": entry.get("interface", ""),
            "hostname": None,
            "status": "reachable" if reachable else "stale",
            "signal_strength": None,
            "tx_rate": None,
            "rx_rate": None,
            "uptime": None,
            "is_wireless": mac in wireless_by_mac,
        }

        lease = dhcp_by_mac.get(mac)
        if lease:
            record["hostname"] = lease.get("host-name") or None
            if lease.get("status", ""):
                record["dhcp_status"] = lease.get("status", "")
        _apply_wireless(record, mac)
        merged.append(record)

    # Pass 2: DHCP-only leases (no ARP match -- e.g. expired leases).
    for mac, lease in dhcp_by_mac.items():
        if mac in seen:
            continue
        seen.add(mac)
        record = {
            "mac": mac,
            "ip": lease.get("active-address") or lease.get("address", ""),
            "interface": lease.get("active-server") or "",
            "hostname": lease.get("host-name") or None,
            "status": "stale",  # no ARP entry => not actively reachable
            "signal_strength": None,
            "tx_rate": None,
            "rx_rate": None,
            "uptime": None,
            "is_wireless": mac in wireless_by_mac,
        }
        _apply_wireless(record, mac)
        merged.append(record)

    return merged
# ---------------------------------------------------------------------------
# Endpoint
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/clients",
    summary="List connected client devices (ARP + DHCP + wireless)",
)
async def list_clients(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Discover all client devices connected to a MikroTik device.

    Fetches the ARP table, DHCP server leases, and wireless registration
    table concurrently, then merges them by MAC address into one client
    list.  Only the ARP fetch is mandatory; DHCP and wireless failures
    degrade gracefully (the device may run no DHCP server or have no
    wireless interfaces).
    """
    await _check_tenant_access(current_user, tenant_id, db)
    await _check_device_online(db, device_id)
    dev = str(device_id)

    # Issue all three RouterOS reads concurrently; collect exceptions per
    # task instead of letting the first failure cancel the rest.
    arp_result, dhcp_result, wireless_result = await asyncio.gather(
        routeros_proxy.execute_command(dev, "/ip/arp/print"),
        routeros_proxy.execute_command(dev, "/ip/dhcp-server/lease/print"),
        routeros_proxy.execute_command(
            dev, "/interface/wireless/registration-table/print"
        ),
        return_exceptions=True,
    )

    # ARP is the core data source — any failure is a 502.
    if isinstance(arp_result, Exception):
        logger.error("ARP fetch exception", device_id=dev, error=str(arp_result))
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Failed to fetch ARP table: {arp_result}",
        )
    if not arp_result.get("success"):
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=arp_result.get("error", "Failed to fetch ARP table"),
        )
    arp_data: list[dict[str, Any]] = arp_result.get("data", [])

    def _best_effort(result: Any, exc_msg: str, fail_msg: str) -> list[dict[str, Any]]:
        # Optional sources log a warning and fall back to empty data.
        if isinstance(result, Exception):
            logger.warning(exc_msg, device_id=dev, error=str(result))
            return []
        if not result.get("success"):
            logger.warning(fail_msg, device_id=dev, error=result.get("error"))
            return []
        return result.get("data", [])

    dhcp_data = _best_effort(
        dhcp_result,
        "DHCP fetch exception (continuing without DHCP data)",
        "DHCP fetch failed (continuing without DHCP data)",
    )
    wireless_data = _best_effort(
        wireless_result,
        "Wireless fetch exception (device may not have wireless interfaces)",
        "Wireless fetch failed (device may not have wireless interfaces)",
    )

    # Unify the three sources by MAC address.
    clients = _merge_client_data(arp_data, dhcp_data, wireless_data)
    logger.info(
        "client_discovery_complete",
        device_id=dev,
        tenant_id=str(tenant_id),
        arp_count=len(arp_data),
        dhcp_count=len(dhcp_data),
        wireless_count=len(wireless_data),
        merged_count=len(clients),
    )
    return {
        "clients": clients,
        "device_id": dev,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

View File

@@ -0,0 +1,745 @@
"""
Config backup API endpoints.
All routes are tenant-scoped under:
/api/tenants/{tenant_id}/devices/{device_id}/config/
Provides:
- GET /backups — list backup timeline
- POST /backups — trigger manual backup
- POST /checkpoint — create a checkpoint (restore point)
- GET /backups/{sha}/export — retrieve export.rsc text
- GET /backups/{sha}/binary — download backup.bin
- POST /preview-restore — preview impact analysis before restore
- POST /restore — restore a config version (two-phase panic-revert)
- POST /emergency-rollback — rollback to most recent pre-push backup
- GET /schedules — view effective backup schedule
- PUT /schedules — create/update device-specific schedule override
RLS is enforced via get_db() (app_user engine with tenant context).
RBAC: viewer = read-only (GET); operator and above = write (POST/PUT).
"""
import asyncio
import logging
import uuid
from datetime import timezone, datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import Response
from pydantic import BaseModel, ConfigDict
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.middleware.rate_limit import limiter
from app.middleware.rbac import require_min_role, require_scope
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.models.config_backup import ConfigBackupRun, ConfigBackupSchedule
from app.config import settings
from app.models.device import Device
from app.services import backup_service, git_store
from app.services import restore_service
from app.services.crypto import decrypt_credentials_hybrid
from app.services.rsc_parser import parse_rsc, validate_rsc, compute_impact
logger = logging.getLogger(__name__)
router = APIRouter(tags=["config-backups"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """
    Verify the current user is allowed to access the given tenant.

    Super admins may access any tenant; their DB session is re-pointed at the
    target tenant so row-level security applies to the right rows.  All other
    roles must belong to the requested tenant or a 403 is raised.
    """
    if current_user.is_super_admin:
        # Super admins carry no tenant restriction — pin RLS to the target.
        from app.database import set_tenant_context

        await set_tenant_context(db, str(tenant_id))
    elif current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied: you do not belong to this tenant.",
        )
# ---------------------------------------------------------------------------
# Request/Response schemas
# ---------------------------------------------------------------------------
class RestoreRequest(BaseModel):
    """Request body naming the backup version to restore or preview."""

    model_config = ConfigDict(extra="forbid")  # reject unknown fields
    # Git commit SHA of the backup version to restore / preview.
    commit_sha: str
class ScheduleUpdate(BaseModel):
    """Request body for creating/updating a device backup schedule."""

    model_config = ConfigDict(extra="forbid")  # reject unknown fields
    # Cron expression controlling when backups run, e.g. "0 2 * * *".
    cron_expression: str
    # Whether the schedule is active.
    enabled: bool
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/config/backups",
    summary="List backup timeline for a device",
    dependencies=[require_scope("config:read")],
)
async def list_backups(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """Return the backup timeline for a device, newest first.

    Each item carries: run id, git commit SHA, trigger type, diff line
    counters, encryption tier, and creation timestamp.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    timeline_query = (
        select(ConfigBackupRun)
        .where(
            ConfigBackupRun.device_id == device_id,  # type: ignore[arg-type]
            ConfigBackupRun.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
        .order_by(ConfigBackupRun.created_at.desc())
    )
    runs = (await db.execute(timeline_query)).scalars().all()
    return [
        {
            "id": str(run.id),
            "commit_sha": run.commit_sha,
            "trigger_type": run.trigger_type,
            "lines_added": run.lines_added,
            "lines_removed": run.lines_removed,
            "encryption_tier": run.encryption_tier,
            "created_at": run.created_at.isoformat(),
        }
        for run in runs
    ]
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config/backups",
    summary="Trigger a manual config backup",
    status_code=status.HTTP_201_CREATED,
    dependencies=[require_scope("config:write")],
)
@limiter.limit("20/minute")
async def trigger_backup(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Trigger an immediate manual backup for a device.

    Captures export.rsc and backup.bin via SSH, commits to the tenant's git
    store, and records a ConfigBackupRun with trigger_type='manual'.

    Returns the backup metadata dict; raises 404 for an unknown device and
    502 when the capture/commit fails.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    try:
        backup_meta = await backup_service.run_backup(
            device_id=str(device_id),
            tenant_id=str(tenant_id),
            trigger_type="manual",
            db_session=db,
        )
    except ValueError as exc:
        # The service signals "not found" conditions via ValueError.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(exc),
        ) from exc
    except Exception as exc:
        logger.error(
            "Manual backup failed for device %s tenant %s: %s",
            device_id,
            tenant_id,
            exc,
        )
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Backup failed: {exc}",
        ) from exc
    return backup_meta
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config/checkpoint",
    summary="Create a checkpoint (restore point) of the current config",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("5/minute")
async def create_checkpoint(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Create a checkpoint (restore point) of the current device config.

    Identical to a manual backup except the run is tagged
    trigger_type='checkpoint'.  Operators create checkpoints before risky
    changes so they have a named restore point to roll back to.

    Raises 404 for an unknown device, 502 when the capture fails.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    try:
        checkpoint_meta = await backup_service.run_backup(
            device_id=str(device_id),
            tenant_id=str(tenant_id),
            trigger_type="checkpoint",
            db_session=db,
        )
    except ValueError as exc:
        # The service signals "not found" conditions via ValueError.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(exc),
        ) from exc
    except Exception as exc:
        logger.error(
            "Checkpoint backup failed for device %s tenant %s: %s",
            device_id,
            tenant_id,
            exc,
        )
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Checkpoint failed: {exc}",
        ) from exc
    return checkpoint_meta
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/config/backups/{commit_sha}/export",
    summary="Get export.rsc text for a specific backup",
    response_class=Response,
    dependencies=[require_scope("config:read")],
)
async def get_export(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    commit_sha: str,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> Response:
    """Return the raw /export compact text for a specific backup version.

    For encrypted backups (encryption_tier != NULL), the Transit ciphertext
    stored in git is decrypted on-demand before returning plaintext.
    Legacy plaintext backups (encryption_tier = NULL) are returned as-is.

    Raises 404 when the commit/file is missing, 500 when decryption fails.
    Content-Type: text/plain
    """
    await _check_tenant_access(current_user, tenant_id, db)
    # get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()
    try:
        # git_store is synchronous — run it off the event loop thread.
        content_bytes = await loop.run_in_executor(
            None,
            git_store.read_file,
            str(tenant_id),
            commit_sha,
            str(device_id),
            "export.rsc",
        )
    except KeyError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Backup version not found: {exc}",
        ) from exc
    # Check if this backup is encrypted — decrypt via Transit if so.
    # Scope by tenant as well as device, matching the other backup queries.
    result = await db.execute(
        select(ConfigBackupRun).where(
            ConfigBackupRun.commit_sha == commit_sha,
            ConfigBackupRun.device_id == device_id,
            ConfigBackupRun.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
    )
    backup_run = result.scalar_one_or_none()
    if backup_run and backup_run.encryption_tier:
        try:
            from app.services.crypto import decrypt_data_transit

            plaintext = await decrypt_data_transit(
                content_bytes.decode("utf-8"), str(tenant_id)
            )
            content_bytes = plaintext.encode("utf-8")
        except Exception as dec_err:
            logger.error(
                "Failed to decrypt export for device %s sha %s: %s",
                device_id, commit_sha, dec_err,
            )
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to decrypt backup content",
            ) from dec_err
    return Response(content=content_bytes, media_type="text/plain")
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/config/backups/{commit_sha}/binary",
    summary="Download backup.bin for a specific backup",
    response_class=Response,
    dependencies=[require_scope("config:read")],
)
async def get_binary(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    commit_sha: str,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> Response:
    """Download the RouterOS binary backup file for a specific backup version.

    For encrypted backups, the Transit ciphertext is decrypted and the
    base64-encoded binary is decoded back to raw bytes before returning.
    Legacy plaintext backups are returned as-is.

    Raises 404 when the commit/file is missing, 500 when decryption fails.
    Content-Type: application/octet-stream (attachment download).
    """
    await _check_tenant_access(current_user, tenant_id, db)
    # get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()
    try:
        # git_store is synchronous — run it off the event loop thread.
        content_bytes = await loop.run_in_executor(
            None,
            git_store.read_file,
            str(tenant_id),
            commit_sha,
            str(device_id),
            "backup.bin",
        )
    except KeyError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Backup version not found: {exc}",
        ) from exc
    # Check if this backup is encrypted — decrypt via Transit if so.
    # Scope by tenant as well as device, matching the other backup queries.
    result = await db.execute(
        select(ConfigBackupRun).where(
            ConfigBackupRun.commit_sha == commit_sha,
            ConfigBackupRun.device_id == device_id,
            ConfigBackupRun.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
    )
    backup_run = result.scalar_one_or_none()
    if backup_run and backup_run.encryption_tier:
        try:
            import base64 as b64
            from app.services.crypto import decrypt_data_transit

            # Transit ciphertext -> base64-encoded binary -> raw bytes
            b64_plaintext = await decrypt_data_transit(
                content_bytes.decode("utf-8"), str(tenant_id)
            )
            content_bytes = b64.b64decode(b64_plaintext)
        except Exception as dec_err:
            logger.error(
                "Failed to decrypt binary backup for device %s sha %s: %s",
                device_id, commit_sha, dec_err,
            )
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to decrypt backup content",
            ) from dec_err
    return Response(
        content=content_bytes,
        media_type="application/octet-stream",
        headers={
            "Content-Disposition": f'attachment; filename="backup-{commit_sha[:8]}.bin"'
        },
    )
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config/preview-restore",
    summary="Preview the impact of restoring a config backup",
    dependencies=[require_scope("config:read")],
)
@limiter.limit("20/minute")
async def preview_restore(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: RestoreRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Preview the impact of restoring a config backup before executing.

    Reads the target config from the git backup, fetches the current config
    from the live device (falling back to the latest backup if unreachable),
    and returns a diff with categories, risk levels, warnings, and validation.

    Raises 404 when the target backup export does not exist.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    # get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()

    # 1. Read target export from git (sync I/O -> thread pool).
    try:
        target_bytes = await loop.run_in_executor(
            None,
            git_store.read_file,
            str(tenant_id),
            body.commit_sha,
            str(device_id),
            "export.rsc",
        )
    except KeyError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Backup export not found: {exc}",
        ) from exc
    target_text = target_bytes.decode("utf-8", errors="replace")

    # 2. Get current export from device (live) or fallback to latest backup.
    current_text = ""
    try:
        result = await db.execute(
            select(Device).where(Device.id == device_id)  # type: ignore[arg-type]
        )
        device = result.scalar_one_or_none()
        if device and (device.encrypted_credentials_transit or device.encrypted_credentials):
            key = settings.get_encryption_key_bytes()
            creds_json = await decrypt_credentials_hybrid(
                device.encrypted_credentials_transit,
                device.encrypted_credentials,
                str(tenant_id),
                key,
            )
            import json
            creds = json.loads(creds_json)
            current_text = await backup_service.capture_export(
                device.ip_address,
                username=creds.get("username", "admin"),
                password=creds.get("password", ""),
            )
    except Exception:
        # Deliberately broad: a preview must still work for an offline device,
        # so any live-fetch failure falls back to the latest git backup.
        logger.debug(
            "Live export failed for device %s, falling back to latest backup",
            device_id,
        )
        latest = await db.execute(
            select(ConfigBackupRun)
            .where(
                ConfigBackupRun.device_id == device_id,  # type: ignore[arg-type]
                ConfigBackupRun.tenant_id == tenant_id,  # type: ignore[arg-type]
            )
            .order_by(ConfigBackupRun.created_at.desc())
            .limit(1)
        )
        latest_run = latest.scalar_one_or_none()
        if latest_run:
            try:
                current_bytes = await loop.run_in_executor(
                    None,
                    git_store.read_file,
                    str(tenant_id),
                    latest_run.commit_sha,
                    str(device_id),
                    "export.rsc",
                )
                current_text = current_bytes.decode("utf-8", errors="replace")
            except Exception:
                current_text = ""

    # 3. Parse both configs and compute the impact analysis.
    current_parsed = parse_rsc(current_text)
    target_parsed = parse_rsc(target_text)
    validation = validate_rsc(target_text)
    impact = compute_impact(current_parsed, target_parsed)
    return {
        "diff": impact["diff"],
        "categories": impact["categories"],
        "warnings": impact["warnings"],
        "validation": validation,
    }
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config/restore",
    summary="Restore a config version (two-phase push with panic-revert)",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("5/minute")
async def restore_config_endpoint(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: RestoreRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Restore a device config to a specific backup version.

    Delegates to the restore service, which implements a two-phase push
    with panic-revert:

    1. Pre-backup is taken on device (mandatory before any push)
    2. RouterOS scheduler is installed as safety net (auto-reverts if unreachable)
    3. Config is pushed via /import
    4. Wait 60s for config to settle
    5. Reachability check — remove scheduler if device is reachable
    6. Return committed/reverted/failed status

    Returns: {"status": str, "message": str, "pre_backup_sha": str}
    Raises 404 for an unknown device/backup, 502 when the push fails.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    try:
        outcome = await restore_service.restore_config(
            device_id=str(device_id),
            tenant_id=str(tenant_id),
            commit_sha=body.commit_sha,
            db_session=db,
        )
    except ValueError as exc:
        # The service signals "not found" conditions via ValueError.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(exc),
        ) from exc
    except Exception as exc:
        logger.error(
            "Restore failed for device %s tenant %s commit %s: %s",
            device_id,
            tenant_id,
            body.commit_sha,
            exc,
        )
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Restore failed: {exc}",
        ) from exc
    return outcome
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config/emergency-rollback",
    summary="Emergency rollback to most recent pre-push backup",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("5/minute")
async def emergency_rollback(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Emergency rollback: restore the most recent pre-push backup.

    Used when a device goes offline after a config push.  Locates the
    newest 'pre-restore', 'checkpoint', or 'pre-template-push' backup and
    restores it via the two-phase panic-revert process.

    Raises 404 when no suitable backup exists, 502 when the restore fails.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    # Newest safety backup taken before a push, for this tenant + device.
    rollback_query = (
        select(ConfigBackupRun)
        .where(
            ConfigBackupRun.device_id == device_id,  # type: ignore[arg-type]
            ConfigBackupRun.tenant_id == tenant_id,  # type: ignore[arg-type]
            ConfigBackupRun.trigger_type.in_(
                ["pre-restore", "checkpoint", "pre-template-push"]
            ),
        )
        .order_by(ConfigBackupRun.created_at.desc())
        .limit(1)
    )
    backup = (await db.execute(rollback_query)).scalar_one_or_none()
    if backup is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No pre-push backup found for rollback",
        )

    try:
        restore_result = await restore_service.restore_config(
            device_id=str(device_id),
            tenant_id=str(tenant_id),
            commit_sha=backup.commit_sha,
            db_session=db,
        )
    except ValueError as exc:
        # The service signals "not found" conditions via ValueError.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(exc),
        ) from exc
    except Exception as exc:
        logger.error(
            "Emergency rollback failed for device %s tenant %s: %s",
            device_id,
            tenant_id,
            exc,
        )
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Emergency rollback failed: {exc}",
        ) from exc

    # Echo the restore result plus which snapshot was rolled back to.
    return {
        **restore_result,
        "rolled_back_to": backup.commit_sha,
        "rolled_back_to_date": backup.created_at.isoformat(),
    }
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/config/schedules",
    summary="Get effective backup schedule for a device",
    dependencies=[require_scope("config:read")],
)
async def get_schedule(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Return the effective backup schedule for a device.

    Resolution order: device-specific override first, then the tenant-level
    default, then a synthetic fallback (daily at 02:00 UTC, enabled).
    """
    await _check_tenant_access(current_user, tenant_id, db)

    async def _lookup(device_clause: Any) -> ConfigBackupSchedule | None:
        # Fetch one schedule row for this tenant matching the device clause.
        result = await db.execute(
            select(ConfigBackupSchedule).where(
                ConfigBackupSchedule.tenant_id == tenant_id,  # type: ignore[arg-type]
                device_clause,
            )
        )
        return result.scalar_one_or_none()

    # Device-specific override wins over the tenant-wide default.
    schedule = await _lookup(ConfigBackupSchedule.device_id == device_id)  # type: ignore[arg-type]
    if schedule is None:
        schedule = await _lookup(ConfigBackupSchedule.device_id.is_(None))  # type: ignore[union-attr]
    if schedule is None:
        # Nothing configured anywhere — report the synthetic default.
        return {
            "id": None,
            "tenant_id": str(tenant_id),
            "device_id": str(device_id),
            "cron_expression": "0 2 * * *",
            "enabled": True,
            "is_default": True,
        }

    return {
        "id": str(schedule.id),
        "tenant_id": str(schedule.tenant_id),
        "device_id": str(schedule.device_id) if schedule.device_id else None,
        "cron_expression": schedule.cron_expression,
        "enabled": schedule.enabled,
        "is_default": schedule.device_id is None,
    }
@router.put(
    "/tenants/{tenant_id}/devices/{device_id}/config/schedules",
    summary="Create or update the device-specific backup schedule",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("20/minute")
async def update_schedule(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: ScheduleUpdate,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Create or update the device-specific backup schedule override.

    Upserts the per-device row: creates it when absent, otherwise updates
    cron_expression and enabled in place.  The backup scheduler is
    hot-reloaded so the change takes effect immediately.

    Returns the updated schedule.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    # Look for an existing device-specific override.
    result = await db.execute(
        select(ConfigBackupSchedule).where(
            ConfigBackupSchedule.tenant_id == tenant_id,  # type: ignore[arg-type]
            ConfigBackupSchedule.device_id == device_id,  # type: ignore[arg-type]
        )
    )
    schedule = result.scalar_one_or_none()

    if schedule is not None:
        # Update the existing override in place.
        schedule.cron_expression = body.cron_expression
        schedule.enabled = body.enabled
    else:
        # First override for this device — create it.
        schedule = ConfigBackupSchedule(
            tenant_id=tenant_id,
            device_id=device_id,
            cron_expression=body.cron_expression,
            enabled=body.enabled,
        )
        db.add(schedule)
    await db.flush()

    # Hot-reload the scheduler so changes take effect immediately.
    from app.services.backup_scheduler import on_schedule_change

    await on_schedule_change(tenant_id, device_id)

    return {
        "id": str(schedule.id),
        "tenant_id": str(schedule.tenant_id),
        "device_id": str(schedule.device_id),
        "cron_expression": schedule.cron_expression,
        "enabled": schedule.enabled,
        "is_default": False,
    }

View File

@@ -0,0 +1,371 @@
"""
Dynamic RouterOS config editor API endpoints.
All routes are tenant-scoped under:
/api/tenants/{tenant_id}/devices/{device_id}/config-editor/
Proxies commands to the Go poller's CmdResponder via the RouterOS proxy service.
Provides:
- GET /browse -- browse a RouterOS menu path
- POST /add -- add a new entry
- POST /set -- edit an existing entry
- POST /remove -- delete an entry
- POST /execute -- execute an arbitrary CLI command
RLS is enforced via get_db() (app_user engine with tenant context).
RBAC: viewer = read-only (GET browse); operator and above = write (POST).
"""
import uuid
import structlog
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel, ConfigDict
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.middleware.rate_limit import limiter
from app.middleware.rbac import require_min_role, require_scope
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.models.device import Device
from app.security.command_blocklist import check_command_safety, check_path_safety
from app.services import routeros_proxy
from app.services.audit_service import log_action
logger = structlog.get_logger(__name__)
audit_logger = structlog.get_logger("audit")
router = APIRouter(tags=["config-editor"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Authorize *current_user* for *tenant_id* and pin the RLS context.

    Super admins may act on any tenant; everyone else must belong to the
    requested tenant or a 403 is raised.  In all allowed cases the DB
    session's tenant context is (re-)set to the target tenant.
    """
    from app.database import set_tenant_context

    if not current_user.is_super_admin and current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied: you do not belong to this tenant.",
        )
    # Super admin targeting any tenant, or a regular user on their own tenant.
    await set_tenant_context(db, str(tenant_id))
async def _check_device_online(
    db: AsyncSession, device_id: uuid.UUID
) -> Device:
    """Fetch the device and confirm it is online, returning it.

    Raises 404 when the device does not exist (or is hidden by RLS) and 409
    when it exists but is not currently connected.
    """
    query = select(Device).where(Device.id == device_id)  # type: ignore[arg-type]
    device = (await db.execute(query)).scalar_one_or_none()
    if device is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Device {device_id} not found",
        )
    if device.status != "online":
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Device is offline \u2014 config editor requires a live connection.",
        )
    return device
# ---------------------------------------------------------------------------
# Request schemas
# ---------------------------------------------------------------------------
class AddEntryRequest(BaseModel):
    """Payload for creating a new entry under a RouterOS menu path."""
    # Reject unknown fields so typos surface as 422s instead of being dropped.
    model_config = ConfigDict(extra="forbid")
    # RouterOS menu path to add under, e.g. "/ip/firewall/filter".
    path: str
    # RouterOS property-name -> value pairs for the new entry.
    properties: dict[str, str]
class SetEntryRequest(BaseModel):
    """Payload for updating an existing entry at a RouterOS menu path."""
    # Reject unknown fields so typos surface as 422s instead of being dropped.
    model_config = ConfigDict(extra="forbid")
    # RouterOS menu path containing the entry.
    path: str
    entry_id: str | None = None  # Optional for singleton paths (e.g. /ip/dns)
    # Property-name -> value pairs to set on the entry.
    properties: dict[str, str]
class RemoveEntryRequest(BaseModel):
    """Payload for deleting an entry from a RouterOS menu path."""
    # Reject unknown fields so typos surface as 422s instead of being dropped.
    model_config = ConfigDict(extra="forbid")
    # RouterOS menu path containing the entry.
    path: str
    # RouterOS internal ID of the entry to remove (required — no singleton form).
    entry_id: str
class ExecuteRequest(BaseModel):
    """Payload for running an arbitrary RouterOS CLI command."""
    # Reject unknown fields so typos surface as 422s instead of being dropped.
    model_config = ConfigDict(extra="forbid")
    # Raw CLI command text; screened by check_command_safety before execution.
    command: str
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/config-editor/browse",
    summary="Browse a RouterOS menu path",
    dependencies=[require_scope("config:read")],
)
async def browse_menu(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    path: str = Query("/interface", description="RouterOS menu path to browse"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Browse a RouterOS menu path and return all entries at that path.

    Flow: tenant access check (sets RLS context) -> device must be online ->
    path screened against the blocklist -> read proxied to the device.
    Raises HTTP 502 when the proxy reports failure; emits an audit event
    on success.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    await _check_device_online(db, device_id)
    check_path_safety(path)
    result = await routeros_proxy.browse_menu(str(device_id), path)
    if not result.get("success"):
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=result.get("error", "Failed to browse menu path"),
        )
    audit_logger.info(
        "routeros_config_browsed",
        device_id=str(device_id),
        tenant_id=str(tenant_id),
        user_id=str(current_user.user_id),
        # Include the role for consistency with the write endpoints' audit events.
        user_role=current_user.role,
        path=path,
    )
    return {
        "success": True,
        "entries": result.get("data", []),
        "error": None,
        "path": path,
    }
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config-editor/add",
    summary="Add a new entry to a RouterOS menu path",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("20/minute")
async def add_entry(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: AddEntryRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Create a new entry at the requested RouterOS menu path.

    Verifies tenant access and device liveness, screens the path against
    the write blocklist, then proxies the add to the device. Raises HTTP
    502 when the proxy reports failure.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    await _check_device_online(db, device_id)
    check_path_safety(body.path, write=True)

    proxy_result = await routeros_proxy.add_entry(
        str(device_id), body.path, body.properties
    )
    if not proxy_result.get("success"):
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=proxy_result.get("error", "Failed to add entry"),
        )

    # Structured audit trail (log stream).
    audit_logger.info(
        "routeros_config_added",
        device_id=str(device_id),
        tenant_id=str(tenant_id),
        user_id=str(current_user.user_id),
        user_role=current_user.role,
        path=body.path,
        success=proxy_result.get("success", False),
    )
    # Best-effort persistent audit record; never fails the request.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "config_add",
            resource_type="config", resource_id=str(device_id),
            device_id=device_id,
            details={"path": body.path, "properties": body.properties},
        )
    except Exception:
        pass
    return proxy_result
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config-editor/set",
    summary="Edit an existing entry in a RouterOS menu path",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("20/minute")
async def set_entry(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: SetEntryRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Update an existing entry's properties on the device.

    ``body.entry_id`` may be None for singleton paths (e.g. /ip/dns).
    Verifies tenant access and device liveness, screens the path against
    the write blocklist, then proxies the update. Raises HTTP 502 when
    the proxy reports failure.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    await _check_device_online(db, device_id)
    check_path_safety(body.path, write=True)

    proxy_result = await routeros_proxy.update_entry(
        str(device_id), body.path, body.entry_id, body.properties
    )
    if not proxy_result.get("success"):
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=proxy_result.get("error", "Failed to update entry"),
        )

    # Structured audit trail (log stream).
    audit_logger.info(
        "routeros_config_modified",
        device_id=str(device_id),
        tenant_id=str(tenant_id),
        user_id=str(current_user.user_id),
        user_role=current_user.role,
        path=body.path,
        entry_id=body.entry_id,
        success=proxy_result.get("success", False),
    )
    # Best-effort persistent audit record; never fails the request.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "config_set",
            resource_type="config", resource_id=str(device_id),
            device_id=device_id,
            details={"path": body.path, "entry_id": body.entry_id, "properties": body.properties},
        )
    except Exception:
        pass
    return proxy_result
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config-editor/remove",
    summary="Delete an entry from a RouterOS menu path",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("5/minute")
async def remove_entry(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: RemoveEntryRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Delete an entry from a RouterOS menu path.

    Verifies tenant access and device liveness, screens the path against
    the write blocklist, then proxies the removal. Raises HTTP 502 when
    the proxy reports failure. Rate-limited more tightly than the other
    write endpoints since deletes are destructive.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    await _check_device_online(db, device_id)
    check_path_safety(body.path, write=True)

    proxy_result = await routeros_proxy.remove_entry(
        str(device_id), body.path, body.entry_id
    )
    if not proxy_result.get("success"):
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=proxy_result.get("error", "Failed to remove entry"),
        )

    # Structured audit trail (log stream).
    audit_logger.info(
        "routeros_config_removed",
        device_id=str(device_id),
        tenant_id=str(tenant_id),
        user_id=str(current_user.user_id),
        user_role=current_user.role,
        path=body.path,
        entry_id=body.entry_id,
        success=proxy_result.get("success", False),
    )
    # Best-effort persistent audit record; never fails the request.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "config_remove",
            resource_type="config", resource_id=str(device_id),
            device_id=device_id,
            details={"path": body.path, "entry_id": body.entry_id},
        )
    except Exception:
        pass
    return proxy_result
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config-editor/execute",
    summary="Execute an arbitrary RouterOS CLI command",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("20/minute")
async def execute_command(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: ExecuteRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Run an arbitrary RouterOS CLI command on the device.

    The command text is screened by check_command_safety before being
    proxied; tenant access and device liveness are verified first.
    Raises HTTP 502 when the proxy reports failure.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    await _check_device_online(db, device_id)
    check_command_safety(body.command)

    proxy_result = await routeros_proxy.execute_cli(str(device_id), body.command)
    if not proxy_result.get("success"):
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=proxy_result.get("error", "Failed to execute command"),
        )

    # Structured audit trail (log stream) — records the full command text.
    audit_logger.info(
        "routeros_command_executed",
        device_id=str(device_id),
        tenant_id=str(tenant_id),
        user_id=str(current_user.user_id),
        user_role=current_user.role,
        command=body.command,
        success=proxy_result.get("success", False),
    )
    # Best-effort persistent audit record; never fails the request.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "config_execute",
            resource_type="config", resource_id=str(device_id),
            device_id=device_id,
            details={"command": body.command},
        )
    except Exception:
        pass
    return proxy_result

View File

@@ -0,0 +1,94 @@
"""
Device group management API endpoints.
Routes: /api/tenants/{tenant_id}/device-groups
RBAC:
- viewer: GET (read-only)
- operator: POST, PUT (write)
- tenant_admin/admin: DELETE
"""
import uuid
from fastapi import APIRouter, Depends, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.middleware.rbac import require_operator_or_above, require_tenant_admin_or_above
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.routers.devices import _check_tenant_access
from app.schemas.device import DeviceGroupCreate, DeviceGroupResponse, DeviceGroupUpdate
from app.services import device as device_service
router = APIRouter(tags=["device-groups"])
@router.get(
    "/tenants/{tenant_id}/device-groups",
    response_model=list[DeviceGroupResponse],
    summary="List device groups",
)
async def list_groups(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[DeviceGroupResponse]:
    """Return every device group belonging to the tenant (viewer and above)."""
    await _check_tenant_access(current_user, tenant_id, db)
    groups = await device_service.get_groups(db=db, tenant_id=tenant_id)
    return groups
@router.post(
    "/tenants/{tenant_id}/device-groups",
    response_model=DeviceGroupResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Create a device group",
    dependencies=[Depends(require_operator_or_above)],
)
async def create_group(
    tenant_id: uuid.UUID,
    data: DeviceGroupCreate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceGroupResponse:
    """Create a device group for the tenant (operator and above)."""
    await _check_tenant_access(current_user, tenant_id, db)
    group = await device_service.create_group(
        db=db, tenant_id=tenant_id, data=data
    )
    return group
@router.put(
    "/tenants/{tenant_id}/device-groups/{group_id}",
    response_model=DeviceGroupResponse,
    summary="Update a device group",
    dependencies=[Depends(require_operator_or_above)],
)
async def update_group(
    tenant_id: uuid.UUID,
    group_id: uuid.UUID,
    data: DeviceGroupUpdate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceGroupResponse:
    """Modify an existing device group (operator and above)."""
    await _check_tenant_access(current_user, tenant_id, db)
    group = await device_service.update_group(
        db=db, tenant_id=tenant_id, group_id=group_id, data=data
    )
    return group
@router.delete(
    "/tenants/{tenant_id}/device-groups/{group_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a device group",
    dependencies=[Depends(require_tenant_admin_or_above)],
)
async def delete_group(
    tenant_id: uuid.UUID,
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Remove a device group from the tenant (tenant_admin and above)."""
    await _check_tenant_access(current_user, tenant_id, db)
    await device_service.delete_group(
        db=db, tenant_id=tenant_id, group_id=group_id
    )

View File

@@ -0,0 +1,150 @@
"""
Device syslog fetch endpoint via NATS RouterOS proxy.
Provides:
- GET /tenants/{tenant_id}/devices/{device_id}/logs -- fetch device log entries
RLS enforced via get_db() (app_user engine with tenant context).
RBAC: viewer and above can read logs.
"""
import uuid
import structlog
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.middleware.rbac import require_min_role
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.services import routeros_proxy
logger = structlog.get_logger(__name__)
router = APIRouter(tags=["device-logs"])
# ---------------------------------------------------------------------------
# Helpers (same pattern as config_editor.py)
# ---------------------------------------------------------------------------
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Verify the current user is allowed to access the given tenant."""
    if current_user.is_super_admin:
        # Super admins may read any tenant; pin the RLS context to the target.
        from app.database import set_tenant_context
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied: you do not belong to this tenant.",
        )
    # NOTE(review): unlike config_editor._check_tenant_access, this variant
    # does not call set_tenant_context for regular users — presumably get_db()
    # already established the RLS context from the authenticated user; confirm.
async def _check_device_exists(
    db: AsyncSession, device_id: uuid.UUID
) -> None:
    """Raise HTTP 404 unless a device row with *device_id* exists.

    Log reads do not require the device to be online, so only existence
    is checked here (contrast with the config editor's online check).
    """
    from sqlalchemy import select

    from app.models.device import Device

    row = (
        await db.execute(select(Device).where(Device.id == device_id))
    ).scalar_one_or_none()
    if row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Device {device_id} not found",
        )
# ---------------------------------------------------------------------------
# Response model
# ---------------------------------------------------------------------------
class LogEntry(BaseModel):
    """One RouterOS syslog entry as returned by /log/print."""
    # Timestamp string exactly as reported by the device.
    time: str
    # RouterOS log topics string (device-formatted; typically comma-separated).
    topics: str
    # Log message text.
    message: str
class LogsResponse(BaseModel):
    """Response envelope for the device logs endpoint."""
    # Entries after any server-side search filtering, device order preserved.
    logs: list[LogEntry]
    # The queried device's UUID as a string.
    device_id: str
    # Number of entries in `logs` (post-filter count, not the device total).
    count: int
# ---------------------------------------------------------------------------
# Endpoint
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/logs",
    response_model=LogsResponse,
    summary="Fetch device syslog entries via RouterOS API",
    dependencies=[Depends(require_min_role("viewer"))],
)
async def get_device_logs(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    limit: int = Query(default=100, ge=1, le=500),
    topic: str | None = Query(default=None, description="Filter by log topic"),
    search: str | None = Query(default=None, description="Search in message/topics"),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> LogsResponse:
    """Fetch device log entries via the RouterOS /log/print command.

    The topic filter is pushed down to the device; the free-text search is
    applied locally (case-insensitive) over message and topics. Raises
    HTTP 502 when the proxy reports failure.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    await _check_device_exists(db, device_id)

    # Assemble RouterOS API words: a result cap plus an optional topic query.
    command_args = [f"=count={limit}"]
    if topic:
        command_args.append(f"?topics={topic}")

    result = await routeros_proxy.execute_command(
        str(device_id), "/log/print", args=command_args, timeout=15.0
    )
    if not result.get("success"):
        error_msg = result.get("error", "Unknown error fetching logs")
        logger.warning(
            "failed to fetch device logs",
            device_id=str(device_id),
            error=error_msg,
        )
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Failed to fetch device logs: {error_msg}",
        )

    # Lowercase the needle once; None means no search filter requested.
    needle = search.lower() if search else None
    logs: list[LogEntry] = []
    for raw in result.get("data", []):
        entry = LogEntry(
            time=raw.get("time", ""),
            topics=raw.get("topics", ""),
            message=raw.get("message", ""),
        )
        if needle is not None and (
            needle not in entry.message.lower()
            and needle not in entry.topics.lower()
        ):
            continue
        logs.append(entry)

    return LogsResponse(
        logs=logs,
        device_id=str(device_id),
        count=len(logs),
    )

View File

@@ -0,0 +1,94 @@
"""
Device tag management API endpoints.
Routes: /api/tenants/{tenant_id}/device-tags
RBAC:
- viewer: GET (read-only)
- operator: POST, PUT (write)
- tenant_admin/admin: DELETE
"""
import uuid
from fastapi import APIRouter, Depends, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.middleware.rbac import require_operator_or_above, require_tenant_admin_or_above
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.routers.devices import _check_tenant_access
from app.schemas.device import DeviceTagCreate, DeviceTagResponse, DeviceTagUpdate
from app.services import device as device_service
router = APIRouter(tags=["device-tags"])
@router.get(
    "/tenants/{tenant_id}/device-tags",
    response_model=list[DeviceTagResponse],
    summary="List device tags",
)
async def list_tags(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[DeviceTagResponse]:
    """Return every device tag belonging to the tenant (viewer and above)."""
    await _check_tenant_access(current_user, tenant_id, db)
    tags = await device_service.get_tags(db=db, tenant_id=tenant_id)
    return tags
@router.post(
    "/tenants/{tenant_id}/device-tags",
    response_model=DeviceTagResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Create a device tag",
    dependencies=[Depends(require_operator_or_above)],
)
async def create_tag(
    tenant_id: uuid.UUID,
    data: DeviceTagCreate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceTagResponse:
    """Create a device tag for the tenant (operator and above)."""
    await _check_tenant_access(current_user, tenant_id, db)
    tag = await device_service.create_tag(db=db, tenant_id=tenant_id, data=data)
    return tag
@router.put(
    "/tenants/{tenant_id}/device-tags/{tag_id}",
    response_model=DeviceTagResponse,
    summary="Update a device tag",
    dependencies=[Depends(require_operator_or_above)],
)
async def update_tag(
    tenant_id: uuid.UUID,
    tag_id: uuid.UUID,
    data: DeviceTagUpdate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceTagResponse:
    """Modify an existing device tag (operator and above)."""
    await _check_tenant_access(current_user, tenant_id, db)
    tag = await device_service.update_tag(
        db=db, tenant_id=tenant_id, tag_id=tag_id, data=data
    )
    return tag
@router.delete(
    "/tenants/{tenant_id}/device-tags/{tag_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a device tag",
    dependencies=[Depends(require_tenant_admin_or_above)],
)
async def delete_tag(
    tenant_id: uuid.UUID,
    tag_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Remove a device tag from the tenant (tenant_admin and above)."""
    await _check_tenant_access(current_user, tenant_id, db)
    await device_service.delete_tag(db=db, tenant_id=tenant_id, tag_id=tag_id)

View File

@@ -0,0 +1,452 @@
"""
Device management API endpoints.
All routes are tenant-scoped under /api/tenants/{tenant_id}/devices.
RLS is enforced via PostgreSQL — the app_user engine automatically filters
cross-tenant data based on the SET LOCAL app.current_tenant context set by
get_current_user dependency.
RBAC:
- viewer: GET (read-only)
- operator: POST, PUT (write)
- admin/tenant_admin: DELETE
"""
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import get_db
from app.middleware.rate_limit import limiter
from app.services.audit_service import log_action
from app.middleware.rbac import (
require_min_role,
require_operator_or_above,
require_scope,
require_tenant_admin_or_above,
)
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.schemas.device import (
BulkAddRequest,
BulkAddResult,
DeviceCreate,
DeviceListResponse,
DeviceResponse,
DeviceUpdate,
SubnetScanRequest,
SubnetScanResponse,
)
from app.services import device as device_service
from app.services.scanner import scan_subnet
router = APIRouter(tags=["devices"])
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """
    Verify the current user is allowed to access the given tenant.

    Super admins may target any tenant; the RLS tenant context is
    re-pointed at the target so row-level security permits the operation.
    Every other role must already belong to the tenant — a mismatch
    yields HTTP 403, and no context change is made for members here.
    """
    if current_user.is_super_admin:
        from app.database import set_tenant_context

        await set_tenant_context(db, str(tenant_id))
    elif current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied: you do not belong to this tenant.",
        )
# ---------------------------------------------------------------------------
# Device CRUD
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/devices",
    response_model=DeviceListResponse,
    summary="List devices with pagination and filtering",
    dependencies=[require_scope("devices:read")],
)
async def list_devices(
    tenant_id: uuid.UUID,
    page: int = Query(1, ge=1, description="Page number (1-based)"),
    page_size: int = Query(25, ge=1, le=100, description="Items per page (1-100)"),
    status_filter: Optional[str] = Query(None, alias="status"),
    search: Optional[str] = Query(None, description="Text search on hostname or IP"),
    tag_id: Optional[uuid.UUID] = Query(None),
    group_id: Optional[uuid.UUID] = Query(None),
    sort_by: str = Query("created_at", description="Field to sort by"),
    sort_order: str = Query("desc", description="asc or desc"),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceListResponse:
    """Return a page of the tenant's devices with filtering and sorting applied."""
    await _check_tenant_access(current_user, tenant_id, db)
    devices, total_count = await device_service.get_devices(
        db=db,
        tenant_id=tenant_id,
        page=page,
        page_size=page_size,
        status=status_filter,
        search=search,
        tag_id=tag_id,
        group_id=group_id,
        sort_by=sort_by,
        sort_order=sort_order,
    )
    return DeviceListResponse(
        items=devices,
        total=total_count,
        page=page,
        page_size=page_size,
    )
@router.post(
    "/tenants/{tenant_id}/devices",
    response_model=DeviceResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Add a device (validates TCP connectivity first)",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("20/minute")
async def create_device(
    request: Request,
    tenant_id: uuid.UUID,
    data: DeviceCreate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceResponse:
    """
    Create a new device. Requires operator role or above.

    The device IP/port is TCP-probed before the record is saved.
    Credentials are encrypted with AES-256-GCM before storage and never
    returned in responses.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    device = await device_service.create_device(
        db=db,
        tenant_id=tenant_id,
        data=data,
        encryption_key=settings.get_encryption_key_bytes(),
    )
    # Best-effort audit record; never fails the request.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "device_create",
            resource_type="device", resource_id=str(device.id),
            details={"hostname": data.hostname, "ip_address": data.ip_address},
            ip_address=request.client.host if request.client else None,
        )
    except Exception:
        pass
    return device
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}",
    response_model=DeviceResponse,
    summary="Get a single device",
    dependencies=[require_scope("devices:read")],
)
async def get_device(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceResponse:
    """Return one device's details (viewer role and above)."""
    await _check_tenant_access(current_user, tenant_id, db)
    device = await device_service.get_device(
        db=db, tenant_id=tenant_id, device_id=device_id
    )
    return device
@router.put(
    "/tenants/{tenant_id}/devices/{device_id}",
    response_model=DeviceResponse,
    summary="Update a device",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("20/minute")
async def update_device(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    data: DeviceUpdate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceResponse:
    """Update fields on an existing device (operator role and above)."""
    await _check_tenant_access(current_user, tenant_id, db)
    device = await device_service.update_device(
        db=db,
        tenant_id=tenant_id,
        device_id=device_id,
        data=data,
        encryption_key=settings.get_encryption_key_bytes(),
    )
    # Best-effort audit record of only the fields the caller actually sent.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "device_update",
            resource_type="device", resource_id=str(device_id),
            device_id=device_id,
            details={"changes": data.model_dump(exclude_unset=True)},
            ip_address=request.client.host if request.client else None,
        )
    except Exception:
        pass
    return device
@router.delete(
    "/tenants/{tenant_id}/devices/{device_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a device",
    dependencies=[Depends(require_tenant_admin_or_above), require_scope("devices:write")],
)
@limiter.limit("5/minute")
async def delete_device(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Hard-delete a device. Requires tenant_admin or above."""
    await _check_tenant_access(current_user, tenant_id, db)
    # Best-effort audit record. NOTE(review): written *before* the delete —
    # presumably because the audit row references device_id and must be
    # inserted while the device row still exists; confirm. Side effect: a
    # "device_delete" entry is recorded even when the delete below fails.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "device_delete",
            resource_type="device", resource_id=str(device_id),
            device_id=device_id,
            ip_address=request.client.host if request.client else None,
        )
    except Exception:
        pass
    await device_service.delete_device(db=db, tenant_id=tenant_id, device_id=device_id)
# ---------------------------------------------------------------------------
# Subnet scan and bulk add
# ---------------------------------------------------------------------------
@router.post(
    "/tenants/{tenant_id}/devices/scan",
    response_model=SubnetScanResponse,
    summary="Scan a subnet for MikroTik devices",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("5/minute")
async def scan_devices(
    request: Request,
    tenant_id: uuid.UUID,
    data: SubnetScanRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> SubnetScanResponse:
    """
    Scan a CIDR subnet for hosts with open RouterOS API ports (8728/8729).

    Returns a list of discovered IPs for the user to review and selectively
    import — does NOT automatically add devices.

    Requires operator role or above.
    """
    import ipaddress

    # Use the shared helper instead of an inline check so super admins also
    # get the RLS tenant context set — without it, the best-effort audit
    # insert below is silently rejected by row-level security for them.
    await _check_tenant_access(current_user, tenant_id, db)

    discovered = await scan_subnet(data.cidr)

    network = ipaddress.ip_network(data.cidr, strict=False)
    # Exclude the network and broadcast addresses, except for /31 and /32
    # networks, which have no such reserved addresses.
    if network.num_addresses > 2:
        total_scanned = network.num_addresses - 2
    else:
        total_scanned = network.num_addresses

    # Audit log the scan (fire-and-forget — never breaks the response)
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "subnet_scan",
            resource_type="network", resource_id=data.cidr,
            details={
                "cidr": data.cidr,
                "devices_found": len(discovered),
                "ip": request.client.host if request.client else None,
            },
            ip_address=request.client.host if request.client else None,
        )
    except Exception:
        pass
    return SubnetScanResponse(
        cidr=data.cidr,
        discovered=discovered,
        total_scanned=total_scanned,
        total_discovered=len(discovered),
    )
@router.post(
    "/tenants/{tenant_id}/devices/bulk-add",
    response_model=BulkAddResult,
    status_code=status.HTTP_201_CREATED,
    summary="Bulk-add devices from scan results",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("5/minute")
async def bulk_add_devices(
    request: Request,
    tenant_id: uuid.UUID,
    data: BulkAddRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> BulkAddResult:
    """
    Add multiple devices at once from scan results.

    Per-device credentials take precedence over shared credentials.
    Devices that fail connectivity checks or validation are reported in
    `failed`; the loop continues, so one bad device never aborts the batch.

    Requires operator role or above.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    added = []
    failed = []
    # Fetch the encryption key once; reused for every device in the batch.
    encryption_key = settings.get_encryption_key_bytes()
    for dev_data in data.devices:
        # Resolve credentials: per-device first, then shared
        username = dev_data.username or data.shared_username
        password = dev_data.password or data.shared_password
        if not username or not password:
            failed.append({
                "ip_address": dev_data.ip_address,
                "error": "No credentials provided (set per-device or shared credentials)",
            })
            continue
        # Hostname falls back to the IP when the scan supplied no name.
        create_data = DeviceCreate(
            hostname=dev_data.hostname or dev_data.ip_address,
            ip_address=dev_data.ip_address,
            api_port=dev_data.api_port,
            api_ssl_port=dev_data.api_ssl_port,
            username=username,
            password=password,
        )
        try:
            device = await device_service.create_device(
                db=db,
                tenant_id=tenant_id,
                data=create_data,
                encryption_key=encryption_key,
            )
            added.append(device)
            # Best-effort audit record per adopted device; never aborts the batch.
            try:
                await log_action(
                    db, tenant_id, current_user.user_id, "device_adopt",
                    resource_type="device", resource_id=str(device.id),
                    details={"hostname": create_data.hostname, "ip_address": create_data.ip_address},
                    ip_address=request.client.host if request.client else None,
                )
            except Exception:
                pass
        # HTTPException first: carries a structured detail from the service layer.
        except HTTPException as exc:
            failed.append({"ip_address": dev_data.ip_address, "error": exc.detail})
        except Exception as exc:
            failed.append({"ip_address": dev_data.ip_address, "error": str(exc)})
    return BulkAddResult(added=added, failed=failed)
# ---------------------------------------------------------------------------
# Group assignment
# ---------------------------------------------------------------------------
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/groups/{group_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Add device to a group",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("20/minute")
async def add_device_to_group(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Attach a device to a group within the tenant (operator and above)."""
    await _check_tenant_access(current_user, tenant_id, db)
    await device_service.assign_device_to_group(
        db, tenant_id, device_id, group_id
    )
@router.delete(
    "/tenants/{tenant_id}/devices/{device_id}/groups/{group_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Remove device from a group",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("5/minute")
async def remove_device_from_group(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Detach a device from a group within the tenant (operator and above)."""
    await _check_tenant_access(current_user, tenant_id, db)
    await device_service.remove_device_from_group(
        db, tenant_id, device_id, group_id
    )
# ---------------------------------------------------------------------------
# Tag assignment
# ---------------------------------------------------------------------------
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/tags/{tag_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Add tag to a device",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("20/minute")
async def add_tag_to_device(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    tag_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Attach a tag to a device within the tenant (operator and above)."""
    await _check_tenant_access(current_user, tenant_id, db)
    await device_service.assign_tag_to_device(
        db, tenant_id, device_id, tag_id
    )
@router.delete(
    "/tenants/{tenant_id}/devices/{device_id}/tags/{tag_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Remove tag from a device",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("5/minute")
async def remove_tag_from_device(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    tag_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Detach a tag from a device within the tenant (operator and above)."""
    await _check_tenant_access(current_user, tenant_id, db)
    await device_service.remove_tag_from_device(
        db, tenant_id, device_id, tag_id
    )

View File

@@ -0,0 +1,164 @@
"""Unified events timeline API endpoint.
Provides a single GET endpoint that unions alert events, device status changes,
and config backup runs into a unified timeline for the dashboard.
RLS enforced via get_db() (app_user engine with tenant context).
"""
import logging
import uuid
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db, set_tenant_context
from app.middleware.tenant_context import CurrentUser, get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(tags=["events"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Ensure *current_user* may act on *tenant_id*, raising 403 otherwise.

    Super admins may reach any tenant; for them the session's RLS tenant
    context is switched explicitly. Regular users pass only for their own
    tenant.
    """
    if current_user.is_super_admin:
        # Super admin: pin the RLS session to the requested tenant.
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )
# ---------------------------------------------------------------------------
# Unified events endpoint
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/events",
    summary="List unified events (alerts, status changes, config backups)",
)
async def list_events(
    tenant_id: uuid.UUID,
    limit: int = Query(50, ge=1, le=200, description="Max events to return"),
    event_type: Optional[str] = Query(
        None,
        description="Filter by event type: alert, status_change, config_backup",
    ),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """Return a unified list of recent events across alerts, device status, and config backups.

    Each enabled source contributes up to `limit` rows; the merged list is
    sorted by timestamp descending and truncated to `limit` (default 50).
    RLS automatically filters to the tenant's data via the app_user session,
    which is why the queries below carry no explicit tenant predicate.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    if event_type and event_type not in ("alert", "status_change", "config_backup"):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="event_type must be one of: alert, status_change, config_backup",
        )
    events: list[dict[str, Any]] = []
    # 1. Alert events — fired alerts joined to their device for the hostname.
    if not event_type or event_type == "alert":
        alert_result = await db.execute(
            text("""
                SELECT ae.id, ae.status, ae.severity, ae.metric, ae.message,
                       ae.fired_at, ae.device_id, d.hostname
                FROM alert_events ae
                LEFT JOIN devices d ON d.id = ae.device_id
                ORDER BY ae.fired_at DESC
                LIMIT :limit
            """),
            {"limit": limit},
        )
        for row in alert_result.fetchall():
            # Columns by position: 0=id, 1=status, 2=severity, 3=metric,
            # 4=message, 5=fired_at, 6=device_id, 7=hostname.
            alert_status = row[1] or "firing"
            metric = row[3] or "unknown"
            events.append({
                "id": str(row[0]),
                "event_type": "alert",
                "severity": row[2],
                "title": f"{alert_status}: {metric}",
                "description": row[4] or f"Alert {alert_status} for {metric}",
                "device_hostname": row[7],
                "device_id": str(row[6]) if row[6] else None,
                "timestamp": row[5].isoformat() if row[5] else None,
            })
    # 2. Device status changes (inferred from current status + last_seen).
    # These are synthetic events: ids are "status-<device_id>", not real rows.
    if not event_type or event_type == "status_change":
        status_result = await db.execute(
            text("""
                SELECT d.id, d.hostname, d.status, d.last_seen
                FROM devices d
                WHERE d.last_seen IS NOT NULL
                ORDER BY d.last_seen DESC
                LIMIT :limit
            """),
            {"limit": limit},
        )
        for row in status_result.fetchall():
            device_status = row[2] or "unknown"
            hostname = row[1] or "Unknown device"
            # Anything other than "online" is surfaced as a warning.
            severity = "info" if device_status == "online" else "warning"
            events.append({
                "id": f"status-{row[0]}",
                "event_type": "status_change",
                "severity": severity,
                "title": f"Device {device_status}",
                "description": f"{hostname} is now {device_status}",
                "device_hostname": hostname,
                "device_id": str(row[0]),
                "timestamp": row[3].isoformat() if row[3] else None,
            })
    # 3. Config backup runs, newest first.
    if not event_type or event_type == "config_backup":
        backup_result = await db.execute(
            text("""
                SELECT cbr.id, cbr.trigger_type, cbr.created_at,
                       cbr.device_id, d.hostname
                FROM config_backup_runs cbr
                LEFT JOIN devices d ON d.id = cbr.device_id
                ORDER BY cbr.created_at DESC
                LIMIT :limit
            """),
            {"limit": limit},
        )
        for row in backup_result.fetchall():
            trigger_type = row[1] or "manual"
            hostname = row[4] or "Unknown device"
            events.append({
                "id": str(row[0]),
                "event_type": "config_backup",
                "severity": "info",
                "title": "Config backup",
                "description": f"{trigger_type} backup completed for {hostname}",
                "device_hostname": hostname,
                "device_id": str(row[3]) if row[3] else None,
                "timestamp": row[2].isoformat() if row[2] else None,
            })
    # Sort merged events newest-first by ISO timestamp string (None -> "",
    # which sorts last under reverse=True), then apply the final limit.
    # NOTE(review): string order matches chronological order only if all
    # timestamps use the same UTC-offset format — confirm column types agree.
    events.sort(
        key=lambda e: e["timestamp"] or "",
        reverse=True,
    )
    return events[:limit]

View File

@@ -0,0 +1,712 @@
"""Firmware API endpoints for version overview, cache management, preferred channel,
and firmware upgrade orchestration.
Tenant-scoped routes under /api/tenants/{tenant_id}/firmware/*.
Global routes under /api/firmware/* for version listing and admin actions.
"""
import asyncio
import uuid
from datetime import datetime
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel, ConfigDict
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db, set_tenant_context
from app.middleware.rate_limit import limiter
from app.middleware.rbac import require_scope
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.services.audit_service import log_action
router = APIRouter(tags=["firmware"])
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Ensure *current_user* may act on *tenant_id*, raising 403 otherwise.

    Super admins may reach any tenant; for them the session's RLS tenant
    context is switched explicitly. Regular users pass only for their own
    tenant.
    """
    if current_user.is_super_admin:
        # Super admin: pin the RLS session to the requested tenant.
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )
class PreferredChannelRequest(BaseModel):
    """Body for the device/group preferred-channel PATCH endpoints."""
    model_config = ConfigDict(extra="forbid")
    # Validated by the endpoints against: "stable", "long-term", "testing".
    preferred_channel: str
class FirmwareDownloadRequest(BaseModel):
    """Body for POST /firmware/download — identifies one NPK to cache."""
    model_config = ConfigDict(extra="forbid")
    # e.g. RouterOS architecture / release channel / version string; passed
    # through unchanged to firmware_service.download_firmware.
    architecture: str
    channel: str
    version: str
# =========================================================================
# TENANT-SCOPED ENDPOINTS
# =========================================================================
@router.get(
    "/tenants/{tenant_id}/firmware/overview",
    summary="Get firmware status for all devices in tenant",
    dependencies=[require_scope("firmware:write")],
)
async def get_firmware_overview(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Return the tenant-wide firmware status summary from firmware_service.

    NOTE(review): this read-only GET is gated on the "firmware:write" scope,
    as are the other firmware GETs below — confirm a read scope was not
    intended.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    # Local import — presumably avoids an import cycle with the service
    # layer; confirm before hoisting to module level.
    from app.services.firmware_service import get_firmware_overview as _get_overview
    return await _get_overview(str(tenant_id))
@router.patch(
    "/tenants/{tenant_id}/devices/{device_id}/preferred-channel",
    summary="Set preferred firmware channel for a device",
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("20/minute")
async def set_device_preferred_channel(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: PreferredChannelRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, str]:
    """Update one device's preferred firmware channel.

    Returns 422 for an unknown channel name and 404 when the UPDATE matches
    no row.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    if body.preferred_channel not in ("stable", "long-term", "testing"):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="preferred_channel must be one of: stable, long-term, testing",
        )
    # NOTE(review): the UPDATE carries no tenant_id predicate — assumes RLS
    # on devices restricts visibility to this tenant; confirm.
    result = await db.execute(
        text("""
            UPDATE devices SET preferred_channel = :channel, updated_at = NOW()
            WHERE id = :device_id
            RETURNING id
        """),
        {"channel": body.preferred_channel, "device_id": str(device_id)},
    )
    if not result.fetchone():
        raise HTTPException(status_code=404, detail="Device not found")
    await db.commit()
    return {"status": "ok", "preferred_channel": body.preferred_channel}
@router.patch(
    "/tenants/{tenant_id}/device-groups/{group_id}/preferred-channel",
    summary="Set preferred firmware channel for a device group",
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("20/minute")
async def set_group_preferred_channel(
    request: Request,
    tenant_id: uuid.UUID,
    group_id: uuid.UUID,
    body: PreferredChannelRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, str]:
    """Update a device group's preferred firmware channel.

    Mirrors the per-device endpoint above: 422 for an unknown channel,
    404 when the UPDATE matches no row.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    if body.preferred_channel not in ("stable", "long-term", "testing"):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="preferred_channel must be one of: stable, long-term, testing",
        )
    # NOTE(review): no tenant_id predicate and (unlike the device UPDATE) no
    # updated_at bump — assumes RLS scoping; confirm both are intentional.
    result = await db.execute(
        text("""
            UPDATE device_groups SET preferred_channel = :channel
            WHERE id = :group_id
            RETURNING id
        """),
        {"channel": body.preferred_channel, "group_id": str(group_id)},
    )
    if not result.fetchone():
        raise HTTPException(status_code=404, detail="Device group not found")
    await db.commit()
    return {"status": "ok", "preferred_channel": body.preferred_channel}
# =========================================================================
# GLOBAL ENDPOINTS (firmware versions are not tenant-scoped)
# =========================================================================
@router.get(
    "/firmware/versions",
    summary="List all known firmware versions from cache",
)
async def list_firmware_versions(
    architecture: Optional[str] = Query(None),
    channel: Optional[str] = Query(None),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """Return every cached firmware version, optionally narrowed by
    architecture and/or channel.

    Firmware versions are global rather than tenant-scoped, so no tenant
    check is performed here; authentication alone gates access.
    """
    # Collect static predicates; user-supplied values are always bound as
    # parameters, never interpolated into the SQL text.
    clauses: list[str] = []
    bind: dict[str, Any] = {}
    if architecture:
        clauses.append("architecture = :arch")
        bind["arch"] = architecture
    if channel:
        clauses.append("channel = :channel")
        bind["channel"] = channel
    where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
    result = await db.execute(
        text(f"""
            SELECT id, architecture, channel, version, npk_url,
                npk_local_path, npk_size_bytes, checked_at
            FROM firmware_versions
            {where}
            ORDER BY architecture, channel, checked_at DESC
        """),
        bind,
    )
    versions: list[dict[str, Any]] = []
    for row in result.fetchall():
        versions.append({
            "id": str(row[0]),
            "architecture": row[1],
            "channel": row[2],
            "version": row[3],
            "npk_url": row[4],
            "npk_local_path": row[5],
            "npk_size_bytes": row[6],
            "checked_at": row[7].isoformat() if row[7] else None,
        })
    return versions
@router.post(
    "/firmware/check",
    summary="Trigger immediate firmware version check (super admin only)",
)
async def trigger_firmware_check(
    current_user: CurrentUser = Depends(get_current_user),
) -> dict[str, Any]:
    """Run the firmware version discovery synchronously and return results.

    Super admin only. The request blocks until check_latest_versions()
    finishes.
    """
    if not current_user.is_super_admin:
        raise HTTPException(status_code=403, detail="Super admin only")
    # Local import — presumably avoids an import cycle with the service
    # layer; confirm before hoisting to module level.
    from app.services.firmware_service import check_latest_versions
    results = await check_latest_versions()
    return {"status": "ok", "versions_discovered": len(results), "versions": results}
@router.get(
    "/firmware/cache",
    summary="List locally cached NPK files (super admin only)",
)
async def list_firmware_cache(
    current_user: CurrentUser = Depends(get_current_user),
) -> list[dict[str, Any]]:
    """Return metadata for NPK files in the local firmware cache.

    Super admin only; shape of each entry is defined by
    firmware_service.get_cached_firmware.
    """
    if not current_user.is_super_admin:
        raise HTTPException(status_code=403, detail="Super admin only")
    from app.services.firmware_service import get_cached_firmware
    return await get_cached_firmware()
@router.post(
    "/firmware/download",
    summary="Download a specific NPK to local cache (super admin only)",
)
async def download_firmware(
    body: FirmwareDownloadRequest,
    current_user: CurrentUser = Depends(get_current_user),
) -> dict[str, str]:
    """Fetch one NPK into the local cache and return its filesystem path.

    Super admin only. Blocks until the download completes; errors surface
    as whatever firmware_service.download_firmware raises.
    """
    if not current_user.is_super_admin:
        raise HTTPException(status_code=403, detail="Super admin only")
    from app.services.firmware_service import download_firmware as _download
    path = await _download(body.architecture, body.channel, body.version)
    return {"status": "ok", "path": path}
# =========================================================================
# UPGRADE ENDPOINTS
# =========================================================================
class UpgradeRequest(BaseModel):
    """Body for a single-device firmware upgrade request."""
    model_config = ConfigDict(extra="forbid")
    # UUID string of the target device (cast to uuid in SQL).
    device_id: str
    target_version: str
    # Required, but an empty string triggers a DB lookup of the device's
    # recorded architecture in the endpoint.
    architecture: str
    channel: str = "stable"
    # Must be True to allow a major-version jump (enforced downstream).
    confirmed_major_upgrade: bool = False
    scheduled_at: Optional[str] = None  # ISO datetime or None for immediate
class MassUpgradeRequest(BaseModel):
    """Body for a multi-device (rollout) firmware upgrade request."""
    model_config = ConfigDict(extra="forbid")
    # UUID strings of the target devices; one job row is created per device.
    device_ids: list[str]
    target_version: str
    channel: str = "stable"
    confirmed_major_upgrade: bool = False
    # ISO datetime string, or None to start the rollout immediately.
    scheduled_at: Optional[str] = None
@router.post(
    "/tenants/{tenant_id}/firmware/upgrade",
    summary="Start or schedule a single device firmware upgrade",
    status_code=status.HTTP_202_ACCEPTED,
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("20/minute")
async def start_firmware_upgrade(
    request: Request,
    tenant_id: uuid.UUID,
    body: UpgradeRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Create a firmware upgrade job for one device and start (or schedule) it.

    Returns 202 with the new job id; the upgrade itself runs asynchronously
    in upgrade_service and is tracked via the firmware_upgrade_jobs row.
    Raises 403 for viewers, 422 for malformed ids/timestamps or unknown
    device architecture.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    if current_user.role == "viewer":
        raise HTTPException(403, "Viewers cannot initiate upgrades")
    # Validate inputs up front: previously a malformed device_id or
    # scheduled_at only blew up (500) AFTER the job row had been committed.
    try:
        device_uuid = uuid.UUID(body.device_id)
    except ValueError:
        raise HTTPException(422, "device_id must be a valid UUID")
    scheduled_dt: Optional[datetime] = None
    if body.scheduled_at:
        try:
            scheduled_dt = datetime.fromisoformat(body.scheduled_at)
        except ValueError:
            raise HTTPException(422, "scheduled_at must be an ISO-8601 datetime")
    # Fall back to the device's recorded architecture when the request sends
    # an empty string (the field is a required str, so only "" reaches this).
    architecture = body.architecture
    if not architecture:
        dev_result = await db.execute(
            text("SELECT architecture FROM devices WHERE id = CAST(:id AS uuid)"),
            {"id": str(device_uuid)},
        )
        dev_row = dev_result.fetchone()
        if not dev_row or not dev_row[0]:
            raise HTTPException(422, "Device architecture unknown — cannot upgrade")
        architecture = dev_row[0]
    # Persist the job first so the scheduler/worker can find it by id.
    job_id = str(uuid.uuid4())
    await db.execute(
        text("""
            INSERT INTO firmware_upgrade_jobs
                (id, tenant_id, device_id, target_version, architecture, channel,
                 status, confirmed_major_upgrade, scheduled_at)
            VALUES
                (CAST(:id AS uuid), CAST(:tenant_id AS uuid), CAST(:device_id AS uuid),
                 :target_version, :architecture, :channel,
                 :status, :confirmed, :scheduled_at)
        """),
        {
            "id": job_id,
            "tenant_id": str(tenant_id),
            "device_id": str(device_uuid),
            "target_version": body.target_version,
            "architecture": architecture,
            "channel": body.channel,
            "status": "scheduled" if body.scheduled_at else "pending",
            "confirmed": body.confirmed_major_upgrade,
            "scheduled_at": body.scheduled_at,
        },
    )
    await db.commit()
    # Hand off: either register the future run or kick it off immediately
    # on the running event loop.
    if scheduled_dt is not None:
        from app.services.upgrade_service import schedule_upgrade
        schedule_upgrade(job_id, scheduled_dt)
    else:
        from app.services.upgrade_service import start_upgrade
        asyncio.create_task(start_upgrade(job_id))
    # Audit logging is best-effort: a logging failure must not fail the
    # already-accepted upgrade request.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "firmware_upgrade",
            resource_type="firmware", resource_id=job_id,
            device_id=device_uuid,
            details={"target_version": body.target_version, "channel": body.channel},
        )
    except Exception:
        pass
    return {"status": "accepted", "job_id": job_id}
@router.post(
    "/tenants/{tenant_id}/firmware/mass-upgrade",
    summary="Start or schedule a mass firmware upgrade for multiple devices",
    status_code=status.HTTP_202_ACCEPTED,
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("5/minute")
async def start_mass_firmware_upgrade(
    request: Request,
    tenant_id: uuid.UUID,
    body: MassUpgradeRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Create one upgrade job per device under a shared rollout group and
    start (or schedule) the rollout.

    Returns 202 with the rollout group id and the per-device jobs. Raises
    403 for viewers and 422 for an empty device list or malformed schedule.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    if current_user.role == "viewer":
        raise HTTPException(403, "Viewers cannot initiate upgrades")
    # Validate up front: an empty rollout would create no jobs yet still
    # invoke the upgrade service, and a malformed scheduled_at previously
    # only failed (500) after all job rows had been committed.
    if not body.device_ids:
        raise HTTPException(422, "device_ids must not be empty")
    scheduled_dt: Optional[datetime] = None
    if body.scheduled_at:
        try:
            scheduled_dt = datetime.fromisoformat(body.scheduled_at)
        except ValueError:
            raise HTTPException(422, "scheduled_at must be an ISO-8601 datetime")
    rollout_group_id = str(uuid.uuid4())
    jobs = []
    for device_id in body.device_ids:
        # Architecture is looked up per device: one rollout may mix models.
        dev_result = await db.execute(
            text("SELECT architecture FROM devices WHERE id = CAST(:id AS uuid)"),
            {"id": device_id},
        )
        dev_row = dev_result.fetchone()
        architecture = dev_row[0] if dev_row and dev_row[0] else "unknown"
        job_id = str(uuid.uuid4())
        await db.execute(
            text("""
                INSERT INTO firmware_upgrade_jobs
                    (id, tenant_id, device_id, rollout_group_id,
                     target_version, architecture, channel,
                     status, confirmed_major_upgrade, scheduled_at)
                VALUES
                    (CAST(:id AS uuid), CAST(:tenant_id AS uuid),
                     CAST(:device_id AS uuid), CAST(:group_id AS uuid),
                     :target_version, :architecture, :channel,
                     :status, :confirmed, :scheduled_at)
            """),
            {
                "id": job_id,
                "tenant_id": str(tenant_id),
                "device_id": device_id,
                "group_id": rollout_group_id,
                "target_version": body.target_version,
                "architecture": architecture,
                "channel": body.channel,
                "status": "scheduled" if body.scheduled_at else "pending",
                "confirmed": body.confirmed_major_upgrade,
                "scheduled_at": body.scheduled_at,
            },
        )
        jobs.append({"job_id": job_id, "device_id": device_id, "architecture": architecture})
    await db.commit()
    # Hand off: register the future run or start the rollout immediately.
    if scheduled_dt is not None:
        from app.services.upgrade_service import schedule_mass_upgrade
        schedule_mass_upgrade(rollout_group_id, scheduled_dt)
    else:
        from app.services.upgrade_service import start_mass_upgrade
        asyncio.create_task(start_mass_upgrade(rollout_group_id))
    return {
        "status": "accepted",
        "rollout_group_id": rollout_group_id,
        "jobs": jobs,
    }
@router.get(
    "/tenants/{tenant_id}/firmware/upgrades",
    summary="List firmware upgrade jobs for tenant",
    dependencies=[require_scope("firmware:write")],
)
async def list_upgrade_jobs(
    tenant_id: uuid.UUID,
    upgrade_status: Optional[str] = Query(None, alias="status"),
    device_id: Optional[str] = Query(None),
    rollout_group_id: Optional[str] = Query(None),
    page: int = Query(1, ge=1),
    per_page: int = Query(50, ge=1, le=200),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """List firmware upgrade jobs with optional status/device/rollout
    filters and page-based pagination.

    Returns {"items", "total", "page", "per_page"}. No explicit tenant
    predicate is applied — assumes RLS scopes firmware_upgrade_jobs to the
    tenant; confirm.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    # "1=1" keeps the WHERE clause syntactically valid with no filters.
    # Filter values are bound parameters; only static predicates are
    # interpolated into the SQL text.
    filters = ["1=1"]
    params: dict[str, Any] = {}
    if upgrade_status:
        filters.append("j.status = :status")
        params["status"] = upgrade_status
    if device_id:
        filters.append("j.device_id = CAST(:device_id AS uuid)")
        params["device_id"] = device_id
    if rollout_group_id:
        filters.append("j.rollout_group_id = CAST(:group_id AS uuid)")
        params["group_id"] = rollout_group_id
    where = " AND ".join(filters)
    offset = (page - 1) * per_page
    # Total count runs with the same filters for pagination metadata.
    count_result = await db.execute(
        text(f"SELECT COUNT(*) FROM firmware_upgrade_jobs j WHERE {where}"),
        params,
    )
    total = count_result.scalar() or 0
    result = await db.execute(
        text(f"""
            SELECT j.id, j.device_id, j.rollout_group_id,
                   j.target_version, j.architecture, j.channel,
                   j.status, j.pre_upgrade_backup_sha, j.scheduled_at,
                   j.started_at, j.completed_at, j.error_message,
                   j.confirmed_major_upgrade, j.created_at,
                   d.hostname AS device_hostname
            FROM firmware_upgrade_jobs j
            LEFT JOIN devices d ON d.id = j.device_id
            WHERE {where}
            ORDER BY j.created_at DESC
            LIMIT :limit OFFSET :offset
        """),
        {**params, "limit": per_page, "offset": offset},
    )
    items = [
        {
            "id": str(row[0]),
            "device_id": str(row[1]),
            "rollout_group_id": str(row[2]) if row[2] else None,
            "target_version": row[3],
            "architecture": row[4],
            "channel": row[5],
            "status": row[6],
            "pre_upgrade_backup_sha": row[7],
            "scheduled_at": row[8].isoformat() if row[8] else None,
            "started_at": row[9].isoformat() if row[9] else None,
            "completed_at": row[10].isoformat() if row[10] else None,
            "error_message": row[11],
            "confirmed_major_upgrade": row[12],
            "created_at": row[13].isoformat() if row[13] else None,
            "device_hostname": row[14],
        }
        for row in result.fetchall()
    ]
    return {"items": items, "total": total, "page": page, "per_page": per_page}
@router.get(
    "/tenants/{tenant_id}/firmware/upgrades/{job_id}",
    summary="Get single upgrade job detail",
    dependencies=[require_scope("firmware:write")],
)
async def get_upgrade_job(
    tenant_id: uuid.UUID,
    job_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Return one upgrade job (same field layout as the list endpoint).

    404 when the job is not visible. NOTE(review): the lookup is by job id
    only — assumes RLS restricts firmware_upgrade_jobs to the tenant in the
    path; confirm.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    result = await db.execute(
        text("""
            SELECT j.id, j.device_id, j.rollout_group_id,
                   j.target_version, j.architecture, j.channel,
                   j.status, j.pre_upgrade_backup_sha, j.scheduled_at,
                   j.started_at, j.completed_at, j.error_message,
                   j.confirmed_major_upgrade, j.created_at,
                   d.hostname AS device_hostname
            FROM firmware_upgrade_jobs j
            LEFT JOIN devices d ON d.id = j.device_id
            WHERE j.id = CAST(:job_id AS uuid)
        """),
        {"job_id": str(job_id)},
    )
    row = result.fetchone()
    if not row:
        raise HTTPException(404, "Upgrade job not found")
    return {
        "id": str(row[0]),
        "device_id": str(row[1]),
        "rollout_group_id": str(row[2]) if row[2] else None,
        "target_version": row[3],
        "architecture": row[4],
        "channel": row[5],
        "status": row[6],
        "pre_upgrade_backup_sha": row[7],
        "scheduled_at": row[8].isoformat() if row[8] else None,
        "started_at": row[9].isoformat() if row[9] else None,
        "completed_at": row[10].isoformat() if row[10] else None,
        "error_message": row[11],
        "confirmed_major_upgrade": row[12],
        "created_at": row[13].isoformat() if row[13] else None,
        "device_hostname": row[14],
    }
@router.get(
    "/tenants/{tenant_id}/firmware/rollouts/{rollout_group_id}",
    summary="Get mass rollout status with all jobs",
    dependencies=[require_scope("firmware:write")],
)
async def get_rollout_status(
    tenant_id: uuid.UUID,
    rollout_group_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Return aggregate progress for a mass rollout plus its per-device jobs.

    Summary counts (completed/failed/paused/pending) are derived from the
    job rows; `current_device` is the hostname (or device id) of the first
    job found in an active state, or None when nothing is in flight.
    404 when the rollout group has no jobs.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    result = await db.execute(
        text("""
            SELECT j.id, j.device_id, j.status, j.target_version,
                   j.architecture, j.error_message, j.started_at,
                   j.completed_at, d.hostname
            FROM firmware_upgrade_jobs j
            LEFT JOIN devices d ON d.id = j.device_id
            WHERE j.rollout_group_id = CAST(:group_id AS uuid)
            ORDER BY j.created_at ASC
        """),
        {"group_id": str(rollout_group_id)},
    )
    rows = result.fetchall()
    if not rows:
        raise HTTPException(404, "Rollout group not found")
    # Compute summary counts from job status (column index 2).
    total = len(rows)
    completed = sum(1 for r in rows if r[2] == "completed")
    failed = sum(1 for r in rows if r[2] == "failed")
    paused = sum(1 for r in rows if r[2] == "paused")
    pending = sum(1 for r in rows if r[2] in ("pending", "scheduled"))
    # Find the currently running device: first job in any in-flight state
    # (creation order, which matches the rollout's sequential processing —
    # TODO confirm the service processes jobs in created_at order).
    active_statuses = {"downloading", "uploading", "rebooting", "verifying"}
    current_device = None
    for r in rows:
        if r[2] in active_statuses:
            current_device = r[8] or str(r[1])
            break
    jobs = [
        {
            "id": str(r[0]),
            "device_id": str(r[1]),
            "status": r[2],
            "target_version": r[3],
            "architecture": r[4],
            "error_message": r[5],
            "started_at": r[6].isoformat() if r[6] else None,
            "completed_at": r[7].isoformat() if r[7] else None,
            "device_hostname": r[8],
        }
        for r in rows
    ]
    return {
        "rollout_group_id": str(rollout_group_id),
        "total": total,
        "completed": completed,
        "failed": failed,
        "paused": paused,
        "pending": pending,
        "current_device": current_device,
        "jobs": jobs,
    }
@router.post(
    "/tenants/{tenant_id}/firmware/upgrades/{job_id}/cancel",
    summary="Cancel a scheduled or pending upgrade",
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("20/minute")
async def cancel_upgrade_endpoint(
    request: Request,
    tenant_id: uuid.UUID,
    job_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, str]:
    """Cancel an upgrade job via upgrade_service.

    403 for viewers; state validation (e.g. whether the job is still
    cancellable) is delegated to upgrade_service.cancel_upgrade.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    if current_user.role == "viewer":
        raise HTTPException(403, "Viewers cannot cancel upgrades")
    from app.services.upgrade_service import cancel_upgrade
    await cancel_upgrade(str(job_id))
    return {"status": "ok", "message": "Upgrade cancelled"}
@router.post(
    "/tenants/{tenant_id}/firmware/upgrades/{job_id}/retry",
    summary="Retry a failed upgrade",
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("20/minute")
async def retry_upgrade_endpoint(
    request: Request,
    tenant_id: uuid.UUID,
    job_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, str]:
    """Re-run a failed upgrade job via upgrade_service.

    403 for viewers; job-state validation is delegated to
    upgrade_service.retry_failed_upgrade.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    if current_user.role == "viewer":
        raise HTTPException(403, "Viewers cannot retry upgrades")
    from app.services.upgrade_service import retry_failed_upgrade
    await retry_failed_upgrade(str(job_id))
    return {"status": "ok", "message": "Upgrade retry started"}
@router.post(
    "/tenants/{tenant_id}/firmware/rollouts/{rollout_group_id}/resume",
    summary="Resume a paused mass rollout",
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("20/minute")
async def resume_rollout_endpoint(
    request: Request,
    tenant_id: uuid.UUID,
    rollout_group_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, str]:
    """Resume a paused rollout via upgrade_service.

    403 for viewers; rollout-state validation is delegated to
    upgrade_service.resume_mass_upgrade.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    if current_user.role == "viewer":
        raise HTTPException(403, "Viewers cannot resume rollouts")
    from app.services.upgrade_service import resume_mass_upgrade
    await resume_mass_upgrade(str(rollout_group_id))
    return {"status": "ok", "message": "Rollout resumed"}
@router.post(
    "/tenants/{tenant_id}/firmware/rollouts/{rollout_group_id}/abort",
    summary="Abort remaining devices in a paused rollout",
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("5/minute")
async def abort_rollout_endpoint(
    request: Request,
    tenant_id: uuid.UUID,
    rollout_group_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Abort the not-yet-processed devices of a rollout via upgrade_service.

    403 for viewers. Returns the number of jobs aborted as reported by
    upgrade_service.abort_mass_upgrade.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    if current_user.role == "viewer":
        raise HTTPException(403, "Viewers cannot abort rollouts")
    from app.services.upgrade_service import abort_mass_upgrade
    aborted = await abort_mass_upgrade(str(rollout_group_id))
    return {"status": "ok", "aborted_count": aborted}

View File

@@ -0,0 +1,309 @@
"""Maintenance windows API endpoints.
Tenant-scoped routes under /api/tenants/{tenant_id}/ for:
- Maintenance window CRUD (list, create, update, delete)
- Filterable by status: upcoming, active, past
RLS enforced via get_db() (app_user engine with tenant context).
RBAC: operator and above for all operations.
"""
import json
import logging
import uuid
from datetime import datetime
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel, ConfigDict
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db, set_tenant_context
from app.middleware.rate_limit import limiter
from app.middleware.tenant_context import CurrentUser, get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(tags=["maintenance-windows"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Ensure *current_user* may act on *tenant_id*, raising 403 otherwise.

    Super admins may reach any tenant; for them the session's RLS tenant
    context is switched explicitly. Regular users pass only for their own
    tenant.
    """
    if current_user.is_super_admin:
        # Super admin: pin the RLS session to the requested tenant.
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )
def _require_operator(current_user: CurrentUser) -> None:
    """Reject read-only users.

    Raises a 403 for the "viewer" role; any other role counts as
    operator-or-above and passes silently.
    """
    if current_user.role != "viewer":
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Requires at least operator role.",
    )
# ---------------------------------------------------------------------------
# Request/response schemas
# ---------------------------------------------------------------------------
class MaintenanceWindowCreate(BaseModel):
    """Body for POST /maintenance-windows."""
    model_config = ConfigDict(extra="forbid")
    name: str
    # Device UUID strings; stored as a jsonb array. Empty means no devices.
    device_ids: list[str] = []
    # Window bounds; the endpoint rejects end_at <= start_at.
    start_at: datetime
    end_at: datetime
    # When True, alerting is suppressed for the window's devices.
    suppress_alerts: bool = True
    notes: Optional[str] = None
class MaintenanceWindowUpdate(BaseModel):
    """Body for PUT /maintenance-windows/{id}; all fields optional —
    only the fields present are written (partial update)."""
    model_config = ConfigDict(extra="forbid")
    name: Optional[str] = None
    device_ids: Optional[list[str]] = None
    start_at: Optional[datetime] = None
    end_at: Optional[datetime] = None
    suppress_alerts: Optional[bool] = None
    notes: Optional[str] = None
class MaintenanceWindowResponse(BaseModel):
    """Serialized maintenance window (datetimes as ISO strings).

    NOTE(review): not referenced by the endpoints visible in this module,
    which return plain dicts — possibly intended as a `response_model`;
    confirm before removing.
    """
    model_config = ConfigDict(extra="forbid")
    id: str
    tenant_id: str
    name: str
    device_ids: list[str]
    start_at: str
    end_at: str
    suppress_alerts: bool
    notes: Optional[str] = None
    created_by: Optional[str] = None
    created_at: str
# ---------------------------------------------------------------------------
# CRUD endpoints
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/maintenance-windows",
    summary="List maintenance windows for tenant",
)
async def list_maintenance_windows(
    tenant_id: uuid.UUID,
    window_status: Optional[str] = Query(None, alias="status"),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """List the tenant's maintenance windows, newest start time first.

    The optional ?status= filter maps to time-based predicates relative to
    NOW(): "active" (in progress), "upcoming" (not started), "past" (ended).
    Any other value is silently ignored and all windows are returned.
    Requires operator or above; RLS scopes rows to the tenant (per the
    module docstring).
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    # "1=1" keeps the WHERE clause valid when no status filter applies.
    filters = ["1=1"]
    # No user-supplied values are interpolated; params stays empty because
    # the predicates only reference NOW().
    params: dict[str, Any] = {}
    if window_status == "active":
        filters.append("mw.start_at <= NOW() AND mw.end_at >= NOW()")
    elif window_status == "upcoming":
        filters.append("mw.start_at > NOW()")
    elif window_status == "past":
        filters.append("mw.end_at < NOW()")
    where = " AND ".join(filters)
    result = await db.execute(
        text(f"""
            SELECT mw.id, mw.tenant_id, mw.name, mw.device_ids,
                   mw.start_at, mw.end_at, mw.suppress_alerts,
                   mw.notes, mw.created_by, mw.created_at
            FROM maintenance_windows mw
            WHERE {where}
            ORDER BY mw.start_at DESC
        """),
        params,
    )
    return [
        {
            "id": str(row[0]),
            "tenant_id": str(row[1]),
            "name": row[2],
            # jsonb column; defensively coerce non-list payloads to [].
            "device_ids": row[3] if isinstance(row[3], list) else [],
            "start_at": row[4].isoformat() if row[4] else None,
            "end_at": row[5].isoformat() if row[5] else None,
            "suppress_alerts": row[6],
            "notes": row[7],
            "created_by": str(row[8]) if row[8] else None,
            "created_at": row[9].isoformat() if row[9] else None,
        }
        for row in result.fetchall()
    ]
@router.post(
    "/tenants/{tenant_id}/maintenance-windows",
    summary="Create maintenance window",
    status_code=status.HTTP_201_CREATED,
)
@limiter.limit("20/minute")
async def create_maintenance_window(
    request: Request,
    tenant_id: uuid.UUID,
    body: MaintenanceWindowCreate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Create a maintenance window and return its representation.

    Requires operator or above; rejects windows whose end is not strictly
    after their start (422).
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    if body.end_at <= body.start_at:
        raise HTTPException(422, "end_at must be after start_at")
    window_id = str(uuid.uuid4())
    # RETURNING created_at lets the response echo the DB-assigned creation
    # timestamp instead of approximating it with the app server's clock
    # (previously datetime.utcnow(), which could disagree with what the
    # list endpoint later reports).
    result = await db.execute(
        text("""
            INSERT INTO maintenance_windows
                (id, tenant_id, name, device_ids, start_at, end_at,
                 suppress_alerts, notes, created_by)
            VALUES
                (CAST(:id AS uuid), CAST(:tenant_id AS uuid),
                 :name, CAST(:device_ids AS jsonb), :start_at, :end_at,
                 :suppress_alerts, :notes, CAST(:created_by AS uuid))
            RETURNING created_at
        """),
        {
            "id": window_id,
            "tenant_id": str(tenant_id),
            "name": body.name,
            "device_ids": json.dumps(body.device_ids),
            "start_at": body.start_at,
            "end_at": body.end_at,
            "suppress_alerts": body.suppress_alerts,
            "notes": body.notes,
            "created_by": str(current_user.user_id),
        },
    )
    row = result.fetchone()
    await db.commit()
    # Fall back to the app clock if the column carries no DB default.
    created_at = (
        row[0].isoformat() if row and row[0] else datetime.utcnow().isoformat()
    )
    return {
        "id": window_id,
        "tenant_id": str(tenant_id),
        "name": body.name,
        "device_ids": body.device_ids,
        "start_at": body.start_at.isoformat(),
        "end_at": body.end_at.isoformat(),
        "suppress_alerts": body.suppress_alerts,
        "notes": body.notes,
        "created_by": str(current_user.user_id),
        "created_at": created_at,
    }
@router.put(
    "/tenants/{tenant_id}/maintenance-windows/{window_id}",
    summary="Update maintenance window",
)
@limiter.limit("20/minute")
async def update_maintenance_window(
    request: Request,
    tenant_id: uuid.UUID,
    window_id: uuid.UUID,
    body: MaintenanceWindowUpdate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Partially update a maintenance window; only fields present in the
    body are written.

    Requires operator or above. 404 when the window is not visible; 422
    when both bounds are supplied and end_at is not after start_at.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    # Mirror the create endpoint's invariant when both bounds are supplied.
    # (A single-bound update cannot be cross-checked here without fetching
    # the stored row — the DB may still end up with an inverted window if
    # only one side is moved; consider a CHECK constraint.)
    if (
        body.start_at is not None
        and body.end_at is not None
        and body.end_at <= body.start_at
    ):
        raise HTTPException(422, "end_at must be after start_at")
    # Build a dynamic SET clause for the partial update; user values are
    # always bound parameters, only static column assignments are
    # interpolated.
    set_parts: list[str] = ["updated_at = NOW()"]
    params: dict[str, Any] = {"window_id": str(window_id)}
    if body.name is not None:
        set_parts.append("name = :name")
        params["name"] = body.name
    if body.device_ids is not None:
        set_parts.append("device_ids = CAST(:device_ids AS jsonb)")
        params["device_ids"] = json.dumps(body.device_ids)
    if body.start_at is not None:
        set_parts.append("start_at = :start_at")
        params["start_at"] = body.start_at
    if body.end_at is not None:
        set_parts.append("end_at = :end_at")
        params["end_at"] = body.end_at
    if body.suppress_alerts is not None:
        set_parts.append("suppress_alerts = :suppress_alerts")
        params["suppress_alerts"] = body.suppress_alerts
    if body.notes is not None:
        set_parts.append("notes = :notes")
        params["notes"] = body.notes
    set_clause = ", ".join(set_parts)
    result = await db.execute(
        text(f"""
            UPDATE maintenance_windows
            SET {set_clause}
            WHERE id = CAST(:window_id AS uuid)
            RETURNING id, tenant_id, name, device_ids, start_at, end_at,
                      suppress_alerts, notes, created_by, created_at
        """),
        params,
    )
    row = result.fetchone()
    if not row:
        raise HTTPException(404, "Maintenance window not found")
    await db.commit()
    return {
        "id": str(row[0]),
        "tenant_id": str(row[1]),
        "name": row[2],
        "device_ids": row[3] if isinstance(row[3], list) else [],
        "start_at": row[4].isoformat() if row[4] else None,
        "end_at": row[5].isoformat() if row[5] else None,
        "suppress_alerts": row[6],
        "notes": row[7],
        "created_by": str(row[8]) if row[8] else None,
        "created_at": row[9].isoformat() if row[9] else None,
    }
@router.delete(
    "/tenants/{tenant_id}/maintenance-windows/{window_id}",
    summary="Delete maintenance window",
    status_code=status.HTTP_204_NO_CONTENT,
)
@limiter.limit("5/minute")
async def delete_maintenance_window(
    request: Request,
    tenant_id: uuid.UUID,
    window_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Delete a maintenance window (operator+ only); 404 if absent.

    The DELETE is scoped to both the window id and the path tenant as
    defense in depth on top of RLS.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    result = await db.execute(
        text(
            "DELETE FROM maintenance_windows "
            "WHERE id = CAST(:id AS uuid) "
            "AND tenant_id = CAST(:tenant_id AS uuid) "
            "RETURNING id"
        ),
        {"id": str(window_id), "tenant_id": str(tenant_id)},
    )
    if not result.fetchone():
        raise HTTPException(404, "Maintenance window not found")
    await db.commit()

View File

@@ -0,0 +1,414 @@
"""
Metrics API endpoints for querying TimescaleDB hypertables.
All device-scoped routes are tenant-scoped under
/api/tenants/{tenant_id}/devices/{device_id}/metrics/*.
Fleet summary endpoints are under /api/tenants/{tenant_id}/fleet/summary
and /api/fleet/summary (super_admin cross-tenant).
RLS is enforced via get_db() — the app_user engine applies tenant filtering
automatically based on the SET LOCAL app.current_tenant context.
All endpoints require authentication (get_current_user) and enforce
tenant access via _check_tenant_access.
"""
import uuid
from datetime import datetime, timedelta
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.middleware.tenant_context import CurrentUser, get_current_user
router = APIRouter(tags=["metrics"])
def _bucket_for_range(start: datetime, end: datetime) -> timedelta:
"""
Select an appropriate time_bucket size based on the requested time range.
Shorter ranges get finer granularity; longer ranges get coarser buckets
to keep result sets manageable.
Returns a timedelta because asyncpg requires a Python timedelta (not a
string interval literal) when binding the first argument of time_bucket().
"""
delta = end - start
hours = delta.total_seconds() / 3600
if hours <= 1:
return timedelta(minutes=1)
elif hours <= 6:
return timedelta(minutes=5)
elif hours <= 24:
return timedelta(minutes=15)
elif hours <= 168: # 7 days
return timedelta(hours=1)
elif hours <= 720: # 30 days
return timedelta(hours=6)
else:
return timedelta(days=1)
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """
    Ensure current_user may operate on tenant_id.

    Non-super-admin users must belong to the tenant, otherwise 403.
    super_admin may access any tenant; the DB tenant context is re-pointed
    at the target tenant so RLS permits the operation.
    """
    if not current_user.is_super_admin:
        if current_user.tenant_id != tenant_id:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied: you do not belong to this tenant.",
            )
        return
    # super_admin: re-point the RLS tenant context at the target tenant.
    from app.database import set_tenant_context

    await set_tenant_context(db, str(tenant_id))
# ---------------------------------------------------------------------------
# Health metrics
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/metrics/health",
    summary="Time-bucketed health metrics (CPU, memory, disk, temperature)",
)
async def device_health_metrics(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    start: datetime = Query(..., description="Start of time range (ISO format)"),
    end: datetime = Query(..., description="End of time range (ISO format)"),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """
    Return time-bucketed CPU, memory, disk, and temperature metrics
    for a device over [start, end). The bucket size adapts automatically
    to the requested range via _bucket_for_range.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    query_params = {
        "bucket": _bucket_for_range(start, end),
        "device_id": str(device_id),
        "start": start,
        "end": end,
    }
    # Memory/disk utilisation is derived from free vs. total, guarded
    # against division by zero when totals are unreported (0).
    res = await db.execute(
        text("""
            SELECT
                time_bucket(:bucket, time) AS bucket,
                avg(cpu_load)::smallint AS avg_cpu,
                max(cpu_load)::smallint AS max_cpu,
                avg(CASE WHEN total_memory > 0
                    THEN round((1 - free_memory::float / total_memory) * 100)
                    ELSE NULL END)::smallint AS avg_mem_pct,
                avg(CASE WHEN total_disk > 0
                    THEN round((1 - free_disk::float / total_disk) * 100)
                    ELSE NULL END)::smallint AS avg_disk_pct,
                avg(temperature)::smallint AS avg_temp
            FROM health_metrics
            WHERE device_id = :device_id
              AND time >= :start AND time < :end
            GROUP BY bucket
            ORDER BY bucket ASC
        """),
        query_params,
    )
    return [dict(record) for record in res.mappings().all()]
# ---------------------------------------------------------------------------
# Interface traffic metrics
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/metrics/interfaces",
    summary="Time-bucketed interface bandwidth metrics (bps from cumulative byte deltas)",
)
async def device_interface_metrics(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    start: datetime = Query(..., description="Start of time range (ISO format)"),
    end: datetime = Query(..., description="End of time range (ISO format)"),
    interface: Optional[str] = Query(None, description="Filter to a specific interface name"),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """
    Return time-bucketed interface traffic metrics for a device.

    Bandwidth (bps) is computed from raw cumulative byte counters using
    SQL LAG() window functions — no poller-side state is required.
    Counter wraps (rx_bytes < prev_rx) are treated as NULL to avoid
    incorrect spikes.

    Returns one row per (bucket, interface) with avg/max rx and tx bps.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    bucket = _bucket_for_range(start, end)
    # Build interface filter clause conditionally.
    # The interface name is passed as a bind parameter — never interpolated
    # into the SQL string — so this is safe from SQL injection. Only the
    # fixed literal clause text is f-string-inserted below.
    interface_filter = "AND interface = :interface" if interface else ""
    # Pipeline:
    #   ordered  — pairs each sample with its predecessor per interface
    #              (LAG) and the elapsed seconds between them (dt);
    #   with_bps — converts byte deltas to bits/sec, NULLing out counter
    #              wraps (current < previous) and non-positive dt, and
    #              drops each interface's first sample (prev_rx IS NULL);
    #   outer    — buckets and aggregates, skipping NULL rx_bps rows.
    sql = f"""
        WITH ordered AS (
            SELECT
                time,
                interface,
                rx_bytes,
                tx_bytes,
                LAG(rx_bytes) OVER (PARTITION BY interface ORDER BY time) AS prev_rx,
                LAG(tx_bytes) OVER (PARTITION BY interface ORDER BY time) AS prev_tx,
                EXTRACT(EPOCH FROM time - LAG(time) OVER (PARTITION BY interface ORDER BY time)) AS dt
            FROM interface_metrics
            WHERE device_id = :device_id
              AND time >= :start AND time < :end
              {interface_filter}
        ),
        with_bps AS (
            SELECT
                time,
                interface,
                rx_bytes,
                tx_bytes,
                CASE WHEN rx_bytes >= prev_rx AND dt > 0
                    THEN ((rx_bytes - prev_rx) * 8 / dt)::bigint
                    ELSE NULL END AS rx_bps,
                CASE WHEN tx_bytes >= prev_tx AND dt > 0
                    THEN ((tx_bytes - prev_tx) * 8 / dt)::bigint
                    ELSE NULL END AS tx_bps
            FROM ordered
            WHERE prev_rx IS NOT NULL
        )
        SELECT
            time_bucket(:bucket, time) AS bucket,
            interface,
            avg(rx_bps)::bigint AS avg_rx_bps,
            avg(tx_bps)::bigint AS avg_tx_bps,
            max(rx_bps)::bigint AS max_rx_bps,
            max(tx_bps)::bigint AS max_tx_bps
        FROM with_bps
        WHERE rx_bps IS NOT NULL
        GROUP BY bucket, interface
        ORDER BY interface, bucket ASC
    """
    params: dict[str, Any] = {
        "bucket": bucket,
        "device_id": str(device_id),
        "start": start,
        "end": end,
    }
    # Only bind :interface when the filter clause actually references it.
    if interface:
        params["interface"] = interface
    result = await db.execute(text(sql), params)
    rows = result.mappings().all()
    return [dict(row) for row in rows]
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/metrics/interfaces/list",
    summary="List distinct interface names for a device",
)
async def device_interface_list(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[str]:
    """Return the sorted set of interface names ever recorded for a device."""
    await _check_tenant_access(current_user, tenant_id, db)
    res = await db.execute(
        text("""
            SELECT DISTINCT interface
            FROM interface_metrics
            WHERE device_id = :device_id
            ORDER BY interface
        """),
        {"device_id": str(device_id)},
    )
    return [name for name in res.scalars()]
# ---------------------------------------------------------------------------
# Wireless metrics
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/metrics/wireless",
    summary="Time-bucketed wireless metrics (clients, signal, CCQ)",
)
async def device_wireless_metrics(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    start: datetime = Query(..., description="Start of time range (ISO format)"),
    end: datetime = Query(..., description="End of time range (ISO format)"),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """
    Return time-bucketed wireless metrics per interface for a device
    over [start, end), with the bucket size chosen by _bucket_for_range.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    bind_values = {
        "bucket": _bucket_for_range(start, end),
        "device_id": str(device_id),
        "start": start,
        "end": end,
    }
    res = await db.execute(
        text("""
            SELECT
                time_bucket(:bucket, time) AS bucket,
                interface,
                avg(client_count)::smallint AS avg_clients,
                max(client_count)::smallint AS max_clients,
                avg(avg_signal)::smallint AS avg_signal,
                avg(ccq)::smallint AS avg_ccq,
                max(frequency) AS frequency
            FROM wireless_metrics
            WHERE device_id = :device_id
              AND time >= :start AND time < :end
            GROUP BY bucket, interface
            ORDER BY interface, bucket ASC
        """),
        bind_values,
    )
    return [dict(record) for record in res.mappings().all()]
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/metrics/wireless/latest",
    summary="Latest wireless stats per interface (not time-bucketed)",
)
async def device_wireless_latest(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """
    Return the newest wireless reading for each interface of a device.

    DISTINCT ON (interface) combined with ORDER BY time DESC keeps
    exactly one (the latest) row per interface.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    res = await db.execute(
        text("""
            SELECT DISTINCT ON (interface)
                interface, client_count, avg_signal, ccq, frequency, time
            FROM wireless_metrics
            WHERE device_id = :device_id
            ORDER BY interface, time DESC
        """),
        {"device_id": str(device_id)},
    )
    return [dict(record) for record in res.mappings().all()]
# ---------------------------------------------------------------------------
# Sparkline
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/metrics/sparkline",
    summary="Last 12 health readings for sparkline display",
)
async def device_sparkline(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """
    Return the last 12 CPU readings for the fleet-table sparkline.

    The inner query grabs the 12 newest samples; the outer query flips
    them back into chronological (ascending) order for plotting.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    res = await db.execute(
        text("""
            SELECT cpu_load, time
            FROM (
                SELECT cpu_load, time
                FROM health_metrics
                WHERE device_id = :device_id
                ORDER BY time DESC
                LIMIT 12
            ) sub
            ORDER BY time ASC
        """),
        {"device_id": str(device_id)},
    )
    return [dict(record) for record in res.mappings().all()]
# ---------------------------------------------------------------------------
# Fleet summary
# ---------------------------------------------------------------------------
_FLEET_SUMMARY_SQL = """
SELECT
d.id, d.hostname, d.ip_address, d.status, d.model, d.last_seen,
d.uptime_seconds, d.last_cpu_load, d.last_memory_used_pct,
d.latitude, d.longitude,
d.tenant_id, t.name AS tenant_name
FROM devices d
JOIN tenants t ON d.tenant_id = t.id
ORDER BY t.name, d.hostname
"""
@router.get(
    "/tenants/{tenant_id}/fleet/summary",
    summary="Fleet summary for a tenant (latest metrics per device)",
)
async def fleet_summary(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """
    Return per-device fleet summary rows for a single tenant.

    Reads the denormalized devices table (not the hypertables) for speed.
    The query itself has no tenant filter — RLS restricts rows to the
    tenant context established by _check_tenant_access.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    res = await db.execute(text(_FLEET_SUMMARY_SQL))
    return [dict(record) for record in res.mappings().all()]
@router.get(
    "/fleet/summary",
    summary="Cross-tenant fleet summary (super_admin only)",
)
async def fleet_summary_all(
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """
    Return the fleet summary across ALL tenants in one query.

    super_admin only. The super_admin RLS policy exposes every tenant's
    rows, so the shared _FLEET_SUMMARY_SQL works unchanged — this avoids
    the N+1 pattern of fetching per-tenant summaries in a loop.
    """
    if current_user.role != "super_admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Super admin required",
        )
    res = await db.execute(text(_FLEET_SUMMARY_SQL))
    return [dict(record) for record in res.mappings().all()]

View File

@@ -0,0 +1,146 @@
"""Report generation API endpoint.
POST /api/tenants/{tenant_id}/reports/generate
Generates PDF or CSV reports for device inventory, metrics summary,
alert history, and change log.
RLS enforced via get_db() (app_user engine with tenant context).
RBAC: require at least operator role.
"""
import uuid
from datetime import datetime
from enum import Enum
from typing import Optional
import structlog
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db, set_tenant_context
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.services.report_service import generate_report
logger = structlog.get_logger(__name__)
router = APIRouter(tags=["reports"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Verify the current user is allowed to access the given tenant.

    super_admin gets the DB tenant context re-pointed at the target
    tenant (so RLS permits the queries); all other roles must belong
    to the tenant or a 403 is raised.
    """
    if current_user.is_super_admin:
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )
def _require_operator(current_user: CurrentUser) -> None:
    """Reject viewers — report generation needs operator role or above."""
    if current_user.role != "viewer":
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Reports require at least operator role.",
    )
# ---------------------------------------------------------------------------
# Request schema
# ---------------------------------------------------------------------------
class ReportType(str, Enum):
    """Report kinds accepted by POST /reports/generate."""

    device_inventory = "device_inventory"  # the only type needing no date range
    metrics_summary = "metrics_summary"
    alert_history = "alert_history"
    change_log = "change_log"
class ReportFormat(str, Enum):
    """Output formats supported by the report generator."""

    pdf = "pdf"
    csv = "csv"
class ReportRequest(BaseModel):
    """Request body for POST /tenants/{tenant_id}/reports/generate."""

    model_config = ConfigDict(extra="forbid")
    type: ReportType
    date_from: Optional[datetime] = None  # required unless type == device_inventory
    date_to: Optional[datetime] = None    # required unless type == device_inventory
    format: ReportFormat = ReportFormat.pdf
# ---------------------------------------------------------------------------
# Endpoint
# ---------------------------------------------------------------------------
@router.post(
    "/tenants/{tenant_id}/reports/generate",
    summary="Generate a report (PDF or CSV)",
    response_class=StreamingResponse,
)
async def generate_report_endpoint(
    tenant_id: uuid.UUID,
    body: ReportRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> StreamingResponse:
    """Generate and download a report as PDF or CSV.

    - device_inventory: no date range required
    - metrics_summary, alert_history, change_log: date_from and date_to required

    Raises:
        HTTPException 422: missing or inverted date range for time-based reports.
        HTTPException 500: report generation failed (original cause chained).
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    # Validate date range for time-based reports
    if body.type != ReportType.device_inventory:
        if not body.date_from or not body.date_to:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"date_from and date_to are required for {body.type.value} reports.",
            )
        if body.date_from > body.date_to:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="date_from must be before date_to.",
            )
    try:
        file_bytes, content_type, filename = await generate_report(
            db=db,
            tenant_id=tenant_id,
            report_type=body.type.value,
            date_from=body.date_from,
            date_to=body.date_to,
            fmt=body.format.value,
        )
    except Exception as exc:
        logger.error("report_generation_failed", error=str(exc), report_type=body.type.value)
        # Chain the cause so the real traceback survives in logs.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Report generation failed: {str(exc)}",
        ) from exc
    import io

    # BUG FIX: the Content-Disposition previously hard-coded "(unknown)"
    # instead of interpolating the filename returned by generate_report
    # (the variable was unused), so every download had a bogus name.
    return StreamingResponse(
        io.BytesIO(file_bytes),
        media_type=content_type,
        headers={
            "Content-Disposition": f'attachment; filename="{filename}"',
            "Content-Length": str(len(file_bytes)),
        },
    )

View File

@@ -0,0 +1,155 @@
"""System settings router — global SMTP configuration.
Super-admin only. Stores SMTP settings in system_settings table with
Transit encryption for passwords. Falls back to .env values.
"""
import logging
from typing import Optional
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.middleware.rbac import require_role
from app.services.email_service import SMTPConfig, send_test_email, test_smtp_connection
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/settings", tags=["settings"])
# system_settings keys managed by the SMTP endpoints — used both to read
# the effective config and as the set of keys written on update.
SMTP_KEYS = [
    "smtp_host",
    "smtp_port",
    "smtp_user",
    "smtp_password",
    "smtp_use_tls",
    "smtp_from_address",
    "smtp_provider",
]
class SMTPSettingsUpdate(BaseModel):
    """PUT /settings/smtp request body (full SMTP configuration)."""

    smtp_host: str
    smtp_port: int = 587  # 587 = mail submission port
    smtp_user: Optional[str] = None
    smtp_password: Optional[str] = None  # None/omitted = keep the stored password
    smtp_use_tls: bool = False
    smtp_from_address: str = "noreply@example.com"
    smtp_provider: str = "custom"
class SMTPTestRequest(BaseModel):
    """POST /settings/smtp/test request body.

    Every SMTP field is optional; unset fields fall back to the saved
    configuration. `to` is the recipient of the test email.
    """

    to: str
    smtp_host: Optional[str] = None
    smtp_port: Optional[int] = None
    smtp_user: Optional[str] = None
    smtp_password: Optional[str] = None
    smtp_use_tls: Optional[bool] = None
    smtp_from_address: Optional[str] = None
async def _get_system_settings(keys: list[str]) -> dict:
    """Fetch the requested keys from system_settings as a key→value dict.

    Uses the admin session (system settings are global, not tenant-scoped).
    Keys with no stored row are simply absent from the result.
    """
    async with AdminAsyncSessionLocal() as session:
        rows = await session.execute(
            text("SELECT key, value FROM system_settings WHERE key = ANY(:keys)"),
            {"keys": keys},
        )
        stored: dict = {}
        for key, value in rows.fetchall():
            stored[key] = value
        return stored
async def _set_system_settings(updates: dict, user_id: str) -> None:
    """Upsert each key/value pair into system_settings.

    Values are persisted as strings (None stays NULL); updated_by and
    updated_at are recorded on every write. A single commit covers all
    upserts.
    """
    upsert_sql = text("""
        INSERT INTO system_settings (key, value, updated_by, updated_at)
        VALUES (:key, :value, CAST(:user_id AS uuid), now())
        ON CONFLICT (key) DO UPDATE
            SET value = :value, updated_by = CAST(:user_id AS uuid), updated_at = now()
    """)
    async with AdminAsyncSessionLocal() as session:
        for key, value in updates.items():
            await session.execute(
                upsert_sql,
                {
                    "key": key,
                    "value": str(value) if value is not None else None,
                    "user_id": user_id,
                },
            )
        await session.commit()
async def get_smtp_config() -> SMTPConfig:
    """Build the effective SMTP config: DB values first, .env fallback.

    Each field independently falls back to the corresponding settings.*
    value when absent/empty in system_settings. smtp_use_tls is stored
    as a string and compared case-insensitively against "true".

    NOTE(review): the password is read back verbatim from system_settings —
    no decryption step happens here, despite the module docstring's
    mention of Transit encryption. Confirm where (de)encryption occurs.
    """
    db_settings = await _get_system_settings(SMTP_KEYS)
    return SMTPConfig(
        host=db_settings.get("smtp_host") or settings.SMTP_HOST,
        port=int(db_settings.get("smtp_port") or settings.SMTP_PORT),
        user=db_settings.get("smtp_user") or settings.SMTP_USER,
        password=db_settings.get("smtp_password") or settings.SMTP_PASSWORD,
        use_tls=(db_settings.get("smtp_use_tls") or str(settings.SMTP_USE_TLS)).lower() == "true",
        from_address=db_settings.get("smtp_from_address") or settings.SMTP_FROM_ADDRESS,
    )
@router.get("/smtp")
async def get_smtp_settings(user=Depends(require_role("super_admin"))):
    """Get current global SMTP configuration. Password is redacted.

    Only the boolean smtp_password_set is exposed, never the secret
    itself. "source" reports whether the host came from the database
    or from environment fallback.
    """
    db_settings = await _get_system_settings(SMTP_KEYS)
    return {
        "smtp_host": db_settings.get("smtp_host") or settings.SMTP_HOST,
        "smtp_port": int(db_settings.get("smtp_port") or settings.SMTP_PORT),
        "smtp_user": db_settings.get("smtp_user") or settings.SMTP_USER or "",
        "smtp_use_tls": (db_settings.get("smtp_use_tls") or str(settings.SMTP_USE_TLS)).lower() == "true",
        "smtp_from_address": db_settings.get("smtp_from_address") or settings.SMTP_FROM_ADDRESS,
        "smtp_provider": db_settings.get("smtp_provider") or "custom",
        "smtp_password_set": bool(db_settings.get("smtp_password") or settings.SMTP_PASSWORD),
        "source": "database" if db_settings.get("smtp_host") else "environment",
    }
@router.put("/smtp")
async def update_smtp_settings(
    data: SMTPSettingsUpdate,
    user=Depends(require_role("super_admin")),
):
    """Update global SMTP configuration.

    Writes each field to system_settings. The password is written only
    when explicitly provided, so omitting it keeps the stored one.

    NOTE(review): the password value is persisted as-is by
    _set_system_settings; the module docstring claims Transit encryption —
    confirm it is applied at the storage layer, otherwise this stores
    the secret in plaintext.
    """
    updates = {
        "smtp_host": data.smtp_host,
        "smtp_port": str(data.smtp_port),
        "smtp_user": data.smtp_user,
        "smtp_use_tls": str(data.smtp_use_tls).lower(),
        "smtp_from_address": data.smtp_from_address,
        "smtp_provider": data.smtp_provider,
    }
    # Only overwrite the stored password when the client actually sent one.
    if data.smtp_password is not None:
        updates["smtp_password"] = data.smtp_password
    await _set_system_settings(updates, str(user.id))
    return {"status": "ok"}
@router.post("/smtp/test")
async def test_smtp(
    data: SMTPTestRequest,
    user=Depends(require_role("super_admin")),
):
    """Test SMTP connectivity and optionally send a test email.

    Builds an effective config by overlaying any request-supplied values
    on top of the saved/global config, probes the connection, and — if
    that succeeds and a recipient was given — sends a test message.
    """
    fallback = await get_smtp_config()

    def pick(override, saved_value):
        # Explicit None test so falsy overrides (False, 0) still win.
        return override if override is not None else saved_value

    config = SMTPConfig(
        host=data.smtp_host or fallback.host,
        port=pick(data.smtp_port, fallback.port),
        user=pick(data.smtp_user, fallback.user),
        password=pick(data.smtp_password, fallback.password),
        use_tls=pick(data.smtp_use_tls, fallback.use_tls),
        from_address=data.smtp_from_address or fallback.from_address,
    )
    conn_result = await test_smtp_connection(config)
    if not conn_result["success"]:
        return conn_result
    if data.to:
        return await send_test_email(data.to, config)
    return conn_result

141
backend/app/routers/sse.py Normal file
View File

@@ -0,0 +1,141 @@
"""SSE streaming endpoint for real-time event delivery.
Provides a Server-Sent Events endpoint per tenant that streams device status,
alert, config push, and firmware progress events in real time. Authentication
is via a short-lived, single-use exchange token (obtained from POST /auth/sse-token)
to avoid exposing the full JWT in query parameters.
"""
import asyncio
import json
import uuid
from typing import AsyncGenerator, Optional
import redis.asyncio as aioredis
import structlog
from fastapi import APIRouter, HTTPException, Query, Request, status
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from app.services.sse_manager import SSEConnectionManager
logger = structlog.get_logger(__name__)
router = APIRouter(tags=["sse"])
# ─── Redis for SSE token validation ───────────────────────────────────────────
_redis: aioredis.Redis | None = None
async def _get_sse_redis() -> aioredis.Redis:
    """Return the module-level SSE Redis client, creating it on first use."""
    global _redis
    if _redis is not None:
        return _redis
    # Deferred import avoids a config dependency at module import time.
    from app.config import settings

    _redis = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
    return _redis
async def _validate_sse_token(token: str) -> dict:
    """Validate and consume a short-lived SSE exchange token via Redis.

    Tokens are single-use: GETDEL fetches and removes the key atomically,
    so a replayed token always fails.

    Args:
        token: SSE exchange token string (from query param).

    Returns:
        Dict with user_id, tenant_id, and role.

    Raises:
        HTTPException 401: If the token is invalid, expired, or already used.
    """
    client = await _get_sse_redis()
    payload = await client.getdel(f"sse_token:{token}")
    if payload:
        return json.loads(payload)
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or expired SSE token",
    )
@router.get(
    "/tenants/{tenant_id}/events/stream",
    summary="SSE event stream for real-time tenant events",
    response_class=EventSourceResponse,
)
async def event_stream(
    request: Request,
    tenant_id: uuid.UUID,
    token: str = Query(..., description="Short-lived SSE exchange token (from POST /auth/sse-token)"),
) -> EventSourceResponse:
    """Stream real-time events for a tenant via Server-Sent Events.

    Event types: device_status, alert_fired, alert_resolved, config_push,
    firmware_progress, metric_update.
    Supports Last-Event-ID header for reconnection replay.
    Sends heartbeat comments every 15 seconds on idle connections.

    Raises:
        HTTPException 401: invalid/expired/already-used exchange token.
        HTTPException 403: token's tenant does not match the path tenant.
    """
    # Validate exchange token from query parameter (single-use, 30s TTL)
    user_context = await _validate_sse_token(token)
    user_role = user_context.get("role", "")
    user_tenant_id = user_context.get("tenant_id")
    user_id = user_context.get("user_id", "")
    # Authorization: user must belong to the requested tenant or be super_admin
    if user_role != "super_admin" and (user_tenant_id is None or str(user_tenant_id) != str(tenant_id)):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not authorized for this tenant",
        )
    # super_admin receives events from ALL tenants (tenant_id filter = None)
    filter_tenant_id: Optional[str] = None if user_role == "super_admin" else str(tenant_id)
    # Generate unique connection ID
    connection_id = f"sse-{uuid.uuid4().hex[:12]}"
    # Check for Last-Event-ID header (reconnection replay)
    last_event_id = request.headers.get("Last-Event-ID")
    logger.info(
        "sse.stream_requested",
        connection_id=connection_id,
        tenant_id=str(tenant_id),
        user_id=user_id,
        role=user_role,
        last_event_id=last_event_id,
    )
    # NOTE(review): a fresh SSEConnectionManager is constructed per request —
    # presumably the class shares state (singleton/class-level registry),
    # otherwise events published elsewhere would never reach this queue.
    # Confirm against app/services/sse_manager.py.
    manager = SSEConnectionManager()
    queue = await manager.connect(
        connection_id=connection_id,
        tenant_id=filter_tenant_id,
        last_event_id=last_event_id,
    )
    async def event_generator() -> AsyncGenerator[ServerSentEvent, None]:
        """Yield SSE events from the queue with 15s heartbeat on idle."""
        try:
            while True:
                try:
                    # Wait up to 15s for the next event from the manager queue.
                    event = await asyncio.wait_for(queue.get(), timeout=15.0)
                    yield ServerSentEvent(
                        data=event["data"],
                        event=event["event"],
                        id=event["id"],
                    )
                except asyncio.TimeoutError:
                    # Send heartbeat comment to keep connection alive
                    yield ServerSentEvent(comment="heartbeat")
                except asyncio.CancelledError:
                    # Client disconnected / shutdown — exit loop, run cleanup.
                    break
        finally:
            # NOTE(review): disconnect() is called without connection_id —
            # confirm the manager tracks this connection internally.
            await manager.disconnect()
            logger.info("sse.stream_closed", connection_id=connection_id)
    return EventSourceResponse(event_generator())

View File

@@ -0,0 +1,613 @@
"""
Config template CRUD, preview, and push API endpoints.
All routes are tenant-scoped under:
/api/tenants/{tenant_id}/templates/
Provides:
- GET /templates -- list templates (optional tag filter)
- POST /templates -- create a template
- GET /templates/{id} -- get single template
- PUT /templates/{id} -- update a template
- DELETE /templates/{id} -- delete a template
- POST /templates/{id}/preview -- preview rendered template for a device
- POST /templates/{id}/push -- push template to devices (sequential rollout)
- GET /templates/push-status/{rollout_id} -- poll push progress
RLS is enforced via get_db() (app_user engine with tenant context).
RBAC: viewer = read (GET/preview); operator and above = write (POST/PUT/DELETE/push).
"""
import asyncio
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel, ConfigDict
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.database import get_db
from app.middleware.rate_limit import limiter
from app.middleware.rbac import require_min_role, require_scope
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.models.config_template import ConfigTemplate, ConfigTemplateTag, TemplatePushJob
from app.models.device import Device
from app.services import template_service
logger = logging.getLogger(__name__)
router = APIRouter(tags=["templates"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Verify the current user is allowed to access the given tenant.

    Non-super-admin callers must belong to the tenant (403 otherwise);
    super_admin gets the DB tenant context re-pointed at the target
    tenant so RLS permits the operation.
    """
    if not current_user.is_super_admin:
        if current_user.tenant_id != tenant_id:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied: you do not belong to this tenant.",
            )
        return
    from app.database import set_tenant_context

    await set_tenant_context(db, str(tenant_id))
def _serialize_template(template: ConfigTemplate, include_content: bool = False) -> dict:
"""Serialize a ConfigTemplate to a response dict."""
result: dict[str, Any] = {
"id": str(template.id),
"name": template.name,
"description": template.description,
"tags": [tag.name for tag in template.tags],
"variable_count": len(template.variables) if template.variables else 0,
"created_at": template.created_at.isoformat(),
"updated_at": template.updated_at.isoformat(),
}
if include_content:
result["content"] = template.content
result["variables"] = template.variables or []
return result
# ---------------------------------------------------------------------------
# Request/Response schemas
# ---------------------------------------------------------------------------
class VariableDef(BaseModel):
    """Definition of a single template variable."""

    model_config = ConfigDict(extra="forbid")
    name: str
    type: str = "string"  # string | ip | integer | boolean | subnet
    default: Optional[str] = None
    description: Optional[str] = None
class TemplateCreateRequest(BaseModel):
    """POST body for creating a config template."""

    model_config = ConfigDict(extra="forbid")
    name: str
    description: Optional[str] = None
    content: str  # raw Jinja2 template source
    # Pydantic deep-copies mutable defaults per instance, so [] is safe here.
    variables: list[VariableDef] = []
    tags: list[str] = []
class TemplateUpdateRequest(BaseModel):
    """PUT body for updating a config template (full replacement)."""

    model_config = ConfigDict(extra="forbid")
    name: str
    description: Optional[str] = None
    content: str  # raw Jinja2 template source
    variables: list[VariableDef] = []
    tags: list[str] = []
class PreviewRequest(BaseModel):
    """Request payload for rendering a template preview against one device."""
    model_config = ConfigDict(extra="forbid")
    device_id: str
    # Caller-supplied values keyed by variable name; declared defaults fill gaps.
    variables: dict[str, str] = {}
class PushRequest(BaseModel):
    """Request payload for pushing a rendered template to one or more devices."""
    model_config = ConfigDict(extra="forbid")
    device_ids: list[str]
    # Shared variable values applied to every device in the rollout.
    variables: dict[str, str] = {}
# ---------------------------------------------------------------------------
# CRUD endpoints
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/templates",
    summary="List config templates",
    dependencies=[require_scope("config:read")],
)
async def list_templates(
    tenant_id: uuid.UUID,
    tag: Optional[str] = Query(None, description="Filter by tag name"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[dict]:
    """Return every config template of the tenant, most recently updated first.

    When *tag* is given, only templates carrying that tag are returned.
    Content and variable bodies are omitted; use the detail endpoint for those.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    stmt = (
        select(ConfigTemplate)
        .options(selectinload(ConfigTemplate.tags))
        .where(ConfigTemplate.tenant_id == tenant_id)  # type: ignore[arg-type]
        .order_by(ConfigTemplate.updated_at.desc())
    )
    if tag:
        # Restrict to templates that own a matching tag row in this tenant.
        tagged_ids = select(ConfigTemplateTag.template_id).where(
            ConfigTemplateTag.name == tag,
            ConfigTemplateTag.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
        stmt = stmt.where(ConfigTemplate.id.in_(tagged_ids))  # type: ignore[attr-defined]

    rows = (await db.execute(stmt)).scalars().all()
    return [_serialize_template(row) for row in rows]
@router.post(
    "/tenants/{tenant_id}/templates",
    summary="Create a config template",
    status_code=status.HTTP_201_CREATED,
    dependencies=[require_scope("config:write")],
)
@limiter.limit("20/minute")
async def create_template(
    request: Request,
    tenant_id: uuid.UUID,
    body: TemplateCreateRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Create a new config template with Jinja2 content and variable definitions.

    Variables referenced in the content but not declared in the request are
    auto-added as plain string variables (a warning is logged), so preview/push
    validation knows about every placeholder.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    # Compare declared variables against those actually used in the content.
    detected = template_service.extract_variables(body.content)
    provided_names = {v.name for v in body.variables}
    unmatched = set(detected) - provided_names

    variables = [v.model_dump() for v in body.variables]
    if unmatched:
        logger.warning(
            "Template '%s' has undeclared variables: %s (auto-adding as string type)",
            body.name, unmatched,
        )
        # BUG FIX: the warning above always promised "auto-adding", but the
        # undeclared variables were previously dropped. Persist them as
        # string-typed definitions so validation and serialization see them.
        for var_name in sorted(unmatched):
            variables.append(
                {"name": var_name, "type": "string", "default": None, "description": None}
            )

    # Create template
    template = ConfigTemplate(
        tenant_id=tenant_id,
        name=body.name,
        description=body.description,
        content=body.content,
        variables=variables,
    )
    db.add(template)
    await db.flush()  # populate the generated primary key before tagging

    # Create one tag row per tag name, scoped to the tenant.
    for tag_name in body.tags:
        db.add(
            ConfigTemplateTag(
                tenant_id=tenant_id,
                name=tag_name,
                template_id=template.id,
            )
        )
    await db.flush()

    # Re-query so the tags relationship is eagerly loaded for serialization.
    result = await db.execute(
        select(ConfigTemplate)
        .options(selectinload(ConfigTemplate.tags))
        .where(ConfigTemplate.id == template.id)  # type: ignore[arg-type]
    )
    return _serialize_template(result.scalar_one(), include_content=True)
@router.get(
    "/tenants/{tenant_id}/templates/{template_id}",
    summary="Get a single config template",
    dependencies=[require_scope("config:read")],
)
async def get_template(
    tenant_id: uuid.UUID,
    template_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Get a config template with full content, variables, and tags."""
    await _check_tenant_access(current_user, tenant_id, db)

    stmt = (
        select(ConfigTemplate)
        .options(selectinload(ConfigTemplate.tags))
        .where(
            ConfigTemplate.id == template_id,  # type: ignore[arg-type]
            ConfigTemplate.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
    )
    found = (await db.execute(stmt)).scalar_one_or_none()
    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Template {template_id} not found",
        )
    return _serialize_template(found, include_content=True)
@router.put(
    "/tenants/{tenant_id}/templates/{template_id}",
    summary="Update a config template",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("20/minute")
async def update_template(
    request: Request,
    tenant_id: uuid.UUID,
    template_id: uuid.UUID,
    body: TemplateUpdateRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Update an existing config template.

    Replaces name/description/content/variables wholesale and swaps the tag
    set (old tag rows deleted, new ones inserted in the same transaction).
    Returns the template serialized with content included. Raises 404 when
    the template does not exist in this tenant.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    result = await db.execute(
        select(ConfigTemplate)
        .options(selectinload(ConfigTemplate.tags))
        .where(
            ConfigTemplate.id == template_id,  # type: ignore[arg-type]
            ConfigTemplate.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
    )
    template = result.scalar_one_or_none()
    if template is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Template {template_id} not found",
        )
    # Update scalar fields in place (PUT semantics: full replacement).
    template.name = body.name
    template.description = body.description
    template.content = body.content
    template.variables = [v.model_dump() for v in body.variables]
    # Replace tags: bulk-delete the old rows, then insert the new set.
    await db.execute(
        delete(ConfigTemplateTag).where(
            ConfigTemplateTag.template_id == template_id  # type: ignore[arg-type]
        )
    )
    for tag_name in body.tags:
        tag = ConfigTemplateTag(
            tenant_id=tenant_id,
            name=tag_name,
            template_id=template.id,
        )
        db.add(tag)
    await db.flush()
    # Re-query with fresh tags so the response reflects the new tag set
    # (the previously loaded `tags` collection is stale after the bulk delete).
    result = await db.execute(
        select(ConfigTemplate)
        .options(selectinload(ConfigTemplate.tags))
        .where(ConfigTemplate.id == template.id)  # type: ignore[arg-type]
    )
    template = result.scalar_one()
    return _serialize_template(template, include_content=True)
@router.delete(
    "/tenants/{tenant_id}/templates/{template_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a config template",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("5/minute")
async def delete_template(
    request: Request,
    tenant_id: uuid.UUID,
    template_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Delete a config template. Tags are cascade-deleted. Push jobs are SET NULL."""
    await _check_tenant_access(current_user, tenant_id, db)

    stmt = select(ConfigTemplate).where(
        ConfigTemplate.id == template_id,  # type: ignore[arg-type]
        ConfigTemplate.tenant_id == tenant_id,  # type: ignore[arg-type]
    )
    target = (await db.execute(stmt)).scalar_one_or_none()
    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Template {template_id} not found",
        )
    await db.delete(target)
# ---------------------------------------------------------------------------
# Preview & Push endpoints
# ---------------------------------------------------------------------------
@router.post(
    "/tenants/{tenant_id}/templates/{template_id}/preview",
    summary="Preview template rendered for a specific device",
    dependencies=[require_scope("config:read")],
)
async def preview_template(
    tenant_id: uuid.UUID,
    template_id: uuid.UUID,
    body: PreviewRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Render a template with device context and custom variables for preview.

    Returns the rendered text plus the device hostname. Raises 404 when the
    template or device is not found in this tenant, and 422 on variable
    validation or rendering errors.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    # Load template (tenant-scoped)
    result = await db.execute(
        select(ConfigTemplate).where(
            ConfigTemplate.id == template_id,  # type: ignore[arg-type]
            ConfigTemplate.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
    )
    template = result.scalar_one_or_none()
    if template is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Template {template_id} not found",
        )
    # Load device. BUG FIX: previously filtered by id only, relying on RLS
    # alone for tenant isolation; filter by tenant explicitly as well
    # (defense in depth, and correct behavior for super_admin sessions).
    result = await db.execute(
        select(Device).where(
            Device.id == body.device_id,  # type: ignore[arg-type]
            Device.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
    )
    device = result.scalar_one_or_none()
    if device is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Device {body.device_id} not found",
        )
    # Fill defaults and validate provided values against declared types.
    if template.variables:
        for var_def in template.variables:
            var_name = var_def.get("name", "")
            var_type = var_def.get("type", "string")
            value = body.variables.get(var_name)
            if value is None:
                # Use the declared default if available; variables with
                # neither value nor default are left for the renderer, which
                # may fail with a 422 below.
                default = var_def.get("default")
                if default is not None:
                    body.variables[var_name] = default
                continue
            error = template_service.validate_variable(var_name, value, var_type)
            if error:
                raise HTTPException(
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                    detail=error,
                )
    # Render with the device context exposed to the template.
    try:
        rendered = template_service.render_template(
            template.content,
            {
                "hostname": device.hostname,
                "ip_address": device.ip_address,
                "model": device.model,
            },
            body.variables,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Template rendering failed: {exc}",
        )
    return {
        "rendered": rendered,
        "device_hostname": device.hostname,
    }
@router.post(
    "/tenants/{tenant_id}/templates/{template_id}/push",
    summary="Push template to devices (sequential rollout with panic-revert)",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("5/minute")
async def push_template(
    request: Request,
    tenant_id: uuid.UUID,
    template_id: uuid.UUID,
    body: PushRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Start a template push to one or more devices.

    Creates push jobs for each device and starts a background sequential rollout.
    Returns the rollout_id for status polling.

    Raises:
        404: template or any device not found in this tenant.
        400: empty device list.
        422: variable validation failure or per-device rendering failure.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    # Load template (tenant-scoped)
    result = await db.execute(
        select(ConfigTemplate).where(
            ConfigTemplate.id == template_id,  # type: ignore[arg-type]
            ConfigTemplate.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
    )
    template = result.scalar_one_or_none()
    if template is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Template {template_id} not found",
        )
    if not body.device_ids:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="At least one device_id is required",
        )
    # Fill defaults and validate provided variable values.
    if template.variables:
        for var_def in template.variables:
            var_name = var_def.get("name", "")
            var_type = var_def.get("type", "string")
            value = body.variables.get(var_name)
            if value is None:
                default = var_def.get("default")
                if default is not None:
                    body.variables[var_name] = default
                continue
            error = template_service.validate_variable(var_name, value, var_type)
            if error:
                raise HTTPException(
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                    detail=error,
                )
    rollout_id = uuid.uuid4()
    jobs_created = []
    for device_id_str in body.device_ids:
        # Load device to render template per-device.
        # BUG FIX: lookup previously filtered by id only, relying on RLS
        # alone; scope explicitly to the tenant (defense in depth).
        result = await db.execute(
            select(Device).where(
                Device.id == device_id_str,  # type: ignore[arg-type]
                Device.tenant_id == tenant_id,  # type: ignore[arg-type]
            )
        )
        device = result.scalar_one_or_none()
        if device is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Device {device_id_str} not found",
            )
        # Render template with this device's context so hostname/IP
        # substitutions are device-specific.
        try:
            rendered = template_service.render_template(
                template.content,
                {
                    "hostname": device.hostname,
                    "ip_address": device.ip_address,
                    "model": device.model,
                },
                body.variables,
            )
        except Exception as exc:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"Template rendering failed for device {device.hostname}: {exc}",
            )
        # Create push job
        job = TemplatePushJob(
            tenant_id=tenant_id,
            template_id=template_id,
            device_id=device.id,
            rollout_id=rollout_id,
            rendered_content=rendered,
            status="pending",
        )
        db.add(job)
        # BUG FIX: flush BEFORE reading job.id — with a server-generated
        # primary key the id is None until the INSERT is flushed, which
        # previously put job_id "None" in the response.
        await db.flush()
        jobs_created.append({
            "job_id": str(job.id),
            "device_id": str(device.id),
            "device_hostname": device.hostname,
        })
    # Start background push task.
    # NOTE(review): the task may start before this request's session commits,
    # so the worker could briefly not see the pending jobs — confirm
    # push_to_devices retries/waits, or that get_db commits first.
    asyncio.create_task(template_service.push_to_devices(str(rollout_id)))
    return {
        "rollout_id": str(rollout_id),
        "jobs": jobs_created,
    }
@router.get(
    "/tenants/{tenant_id}/templates/push-status/{rollout_id}",
    summary="Poll push progress for a rollout",
    dependencies=[require_scope("config:read")],
)
async def push_status(
    tenant_id: uuid.UUID,
    rollout_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Return all push job statuses for a rollout with device hostnames."""
    await _check_tenant_access(current_user, tenant_id, db)

    stmt = (
        select(TemplatePushJob, Device.hostname)
        .join(Device, TemplatePushJob.device_id == Device.id)  # type: ignore[arg-type]
        .where(
            TemplatePushJob.rollout_id == rollout_id,  # type: ignore[arg-type]
            TemplatePushJob.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
        .order_by(TemplatePushJob.created_at.asc())
    )
    job_rows = (await db.execute(stmt)).all()

    jobs = [
        {
            "device_id": str(job.device_id),
            "hostname": hostname,
            "status": job.status,
            "error_message": job.error_message,
            "started_at": job.started_at.isoformat() if job.started_at else None,
            "completed_at": job.completed_at.isoformat() if job.completed_at else None,
        }
        for job, hostname in job_rows
    ]
    return {"rollout_id": str(rollout_id), "jobs": jobs}

View File

@@ -0,0 +1,367 @@
"""
Tenant management endpoints.
GET /api/tenants — list tenants (super_admin: all; tenant_admin: own only)
POST /api/tenants — create tenant (super_admin only)
GET /api/tenants/{id} — get tenant detail
PUT /api/tenants/{id} — update tenant (super_admin only)
DELETE /api/tenants/{id} — delete tenant (super_admin only)
"""
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy import func, select, text
from sqlalchemy.ext.asyncio import AsyncSession
from app.middleware.rate_limit import limiter
from app.database import get_admin_db, get_db
from app.middleware.rbac import require_super_admin, require_tenant_admin_or_above
from app.middleware.tenant_context import CurrentUser
from app.models.device import Device
from app.models.tenant import Tenant
from app.models.user import User
from app.schemas.tenant import TenantCreate, TenantResponse, TenantUpdate
router = APIRouter(prefix="/tenants", tags=["tenants"])
async def _get_tenant_response(
    tenant: Tenant,
    db: AsyncSession,
) -> TenantResponse:
    """Assemble a TenantResponse, attaching user and device counts."""
    users = await db.execute(
        select(func.count(User.id)).where(User.tenant_id == tenant.id)
    )
    devices = await db.execute(
        select(func.count(Device.id)).where(Device.tenant_id == tenant.id)
    )
    return TenantResponse(
        id=tenant.id,
        name=tenant.name,
        description=tenant.description,
        contact_email=tenant.contact_email,
        user_count=users.scalar_one() or 0,
        device_count=devices.scalar_one() or 0,
        created_at=tenant.created_at,
    )
@router.get("", response_model=list[TenantResponse], summary="List tenants")
async def list_tenants(
current_user: CurrentUser = Depends(require_tenant_admin_or_above),
db: AsyncSession = Depends(get_admin_db),
) -> list[TenantResponse]:
"""
List tenants.
- super_admin: sees all tenants
- tenant_admin: sees only their own tenant
"""
if current_user.is_super_admin:
result = await db.execute(select(Tenant).order_by(Tenant.name))
tenants = result.scalars().all()
else:
if not current_user.tenant_id:
return []
result = await db.execute(
select(Tenant).where(Tenant.id == current_user.tenant_id)
)
tenants = result.scalars().all()
return [await _get_tenant_response(tenant, db) for tenant in tenants]
@router.post("", response_model=TenantResponse, status_code=status.HTTP_201_CREATED, summary="Create a tenant")
@limiter.limit("20/minute")
async def create_tenant(
request: Request,
data: TenantCreate,
current_user: CurrentUser = Depends(require_super_admin),
db: AsyncSession = Depends(get_admin_db),
) -> TenantResponse:
"""Create a new tenant (super_admin only)."""
# Check for name uniqueness
existing = await db.execute(select(Tenant).where(Tenant.name == data.name))
if existing.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Tenant with name '{data.name}' already exists",
)
tenant = Tenant(name=data.name, description=data.description, contact_email=data.contact_email)
db.add(tenant)
await db.commit()
await db.refresh(tenant)
# Seed default alert rules for new tenant
default_rules = [
("High CPU Usage", "cpu_load", "gt", 90, 5, "warning"),
("High Memory Usage", "memory_used_pct", "gt", 90, 5, "warning"),
("High Disk Usage", "disk_used_pct", "gt", 85, 3, "warning"),
("Device Offline", "device_offline", "eq", 1, 1, "critical"),
]
for name, metric, operator, threshold, duration, sev in default_rules:
await db.execute(text("""
INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default)
VALUES (gen_random_uuid(), CAST(:tenant_id AS uuid), :name, :metric, :operator, :threshold, :duration, :severity, TRUE, TRUE)
"""), {
"tenant_id": str(tenant.id), "name": name, "metric": metric,
"operator": operator, "threshold": threshold, "duration": duration, "severity": sev,
})
await db.commit()
# Seed starter config templates for new tenant
await _seed_starter_templates(db, tenant.id)
await db.commit()
# Provision OpenBao Transit key for the new tenant (non-blocking)
try:
from app.config import settings
from app.services.key_service import provision_tenant_key
if settings.OPENBAO_ADDR:
await provision_tenant_key(db, tenant.id)
await db.commit()
except Exception as exc:
import logging
logging.getLogger(__name__).warning(
"OpenBao key provisioning failed for tenant %s (will be provisioned on next startup): %s",
tenant.id,
exc,
)
return await _get_tenant_response(tenant, db)
@router.get("/{tenant_id}", response_model=TenantResponse, summary="Get tenant detail")
async def get_tenant(
tenant_id: uuid.UUID,
current_user: CurrentUser = Depends(require_tenant_admin_or_above),
db: AsyncSession = Depends(get_admin_db),
) -> TenantResponse:
"""Get tenant detail. Tenant admins can only view their own tenant."""
# Enforce tenant_admin can only see their own tenant
if not current_user.is_super_admin and current_user.tenant_id != tenant_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to this tenant",
)
result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
tenant = result.scalar_one_or_none()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Tenant not found",
)
return await _get_tenant_response(tenant, db)
@router.put("/{tenant_id}", response_model=TenantResponse, summary="Update a tenant")
@limiter.limit("20/minute")
async def update_tenant(
request: Request,
tenant_id: uuid.UUID,
data: TenantUpdate,
current_user: CurrentUser = Depends(require_super_admin),
db: AsyncSession = Depends(get_admin_db),
) -> TenantResponse:
"""Update tenant (super_admin only)."""
result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
tenant = result.scalar_one_or_none()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Tenant not found",
)
if data.name is not None:
# Check name uniqueness
name_check = await db.execute(
select(Tenant).where(Tenant.name == data.name, Tenant.id != tenant_id)
)
if name_check.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Tenant with name '{data.name}' already exists",
)
tenant.name = data.name
if data.description is not None:
tenant.description = data.description
if data.contact_email is not None:
tenant.contact_email = data.contact_email
await db.commit()
await db.refresh(tenant)
return await _get_tenant_response(tenant, db)
@router.delete("/{tenant_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete a tenant")
@limiter.limit("5/minute")
async def delete_tenant(
request: Request,
tenant_id: uuid.UUID,
current_user: CurrentUser = Depends(require_super_admin),
db: AsyncSession = Depends(get_admin_db),
) -> None:
"""Delete tenant (super_admin only). Cascades to all users and devices."""
result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
tenant = result.scalar_one_or_none()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Tenant not found",
)
await db.delete(tenant)
await db.commit()
# ---------------------------------------------------------------------------
# Starter template seeding
# ---------------------------------------------------------------------------
# Each entry: UI-visible name/description, Jinja2 RouterOS `content`, and
# typed `variables` consumed by the template validator. `{{ device.* }}`
# placeholders are filled from the target device's context at render time.
_STARTER_TEMPLATES = [
    # Full SOHO/branch bootstrap: bridge, WAN DHCP client, LAN DHCP server,
    # DNS, NAT masquerade, and a default-deny firewall.
    {
        "name": "Basic Router",
        "description": "Complete SOHO/branch router setup: WAN on ether1, LAN bridge, DHCP, DNS, NAT, basic firewall",
        "content": """/interface bridge add name=bridge-lan comment="LAN bridge"
/interface bridge port add bridge=bridge-lan interface=ether2
/interface bridge port add bridge=bridge-lan interface=ether3
/interface bridge port add bridge=bridge-lan interface=ether4
/interface bridge port add bridge=bridge-lan interface=ether5
# WAN — DHCP client on ether1
/ip dhcp-client add interface={{ wan_interface }} disabled=no comment="WAN uplink"
# LAN address
/ip address add address={{ lan_gateway }}/{{ lan_cidr }} interface=bridge-lan
# DNS
/ip dns set servers={{ dns_servers }} allow-remote-requests=yes
# DHCP server for LAN
/ip pool add name=lan-pool ranges={{ dhcp_start }}-{{ dhcp_end }}
/ip dhcp-server network add address={{ lan_network }}/{{ lan_cidr }} gateway={{ lan_gateway }} dns-server={{ lan_gateway }}
/ip dhcp-server add name=lan-dhcp interface=bridge-lan address-pool=lan-pool disabled=no
# NAT masquerade
/ip firewall nat add chain=srcnat out-interface={{ wan_interface }} action=masquerade
# Firewall — input chain
/ip firewall filter
add chain=input connection-state=established,related action=accept
add chain=input connection-state=invalid action=drop
add chain=input in-interface={{ wan_interface }} action=drop comment="Drop all other WAN input"
# Firewall — forward chain
add chain=forward connection-state=established,related action=accept
add chain=forward connection-state=invalid action=drop
add chain=forward in-interface=bridge-lan out-interface={{ wan_interface }} action=accept comment="Allow LAN to WAN"
add chain=forward action=drop comment="Drop everything else"
# NTP
/system ntp client set enabled=yes servers={{ ntp_server }}
# Identity
/system identity set name={{ device.hostname }}""",
        "variables": [
            {"name": "wan_interface", "type": "string", "default": "ether1", "description": "WAN-facing interface"},
            {"name": "lan_gateway", "type": "ip", "default": "192.168.88.1", "description": "LAN gateway IP"},
            {"name": "lan_cidr", "type": "integer", "default": "24", "description": "LAN subnet mask bits"},
            {"name": "lan_network", "type": "ip", "default": "192.168.88.0", "description": "LAN network address"},
            {"name": "dhcp_start", "type": "ip", "default": "192.168.88.100", "description": "DHCP pool start"},
            {"name": "dhcp_end", "type": "ip", "default": "192.168.88.254", "description": "DHCP pool end"},
            {"name": "dns_servers", "type": "string", "default": "8.8.8.8,8.8.4.4", "description": "Upstream DNS servers"},
            {"name": "ntp_server", "type": "string", "default": "pool.ntp.org", "description": "NTP server"},
        ],
    },
    # Stand-alone firewall ruleset: blocks management from WAN, allows LAN out.
    {
        "name": "Basic Firewall",
        "description": "Standard firewall ruleset with WAN protection and LAN forwarding",
        "content": """/ip firewall filter
add chain=input connection-state=established,related action=accept
add chain=input connection-state=invalid action=drop
add chain=input in-interface={{ wan_interface }} protocol=tcp dst-port=8291 action=drop comment="Block Winbox from WAN"
add chain=input in-interface={{ wan_interface }} protocol=tcp dst-port=22 action=drop comment="Block SSH from WAN"
add chain=forward connection-state=established,related action=accept
add chain=forward connection-state=invalid action=drop
add chain=forward src-address={{ allowed_network }} action=accept
add chain=forward action=drop""",
        "variables": [
            {"name": "wan_interface", "type": "string", "default": "ether1", "description": "WAN-facing interface"},
            {"name": "allowed_network", "type": "subnet", "default": "192.168.88.0/24", "description": "Allowed source network"},
        ],
    },
    # Minimal DHCP server: pool + network + server binding.
    {
        "name": "DHCP Server Setup",
        "description": "Configure DHCP server with address pool, DNS, and gateway",
        "content": """/ip pool add name=dhcp-pool ranges={{ pool_start }}-{{ pool_end }}
/ip dhcp-server network add address={{ gateway }}/24 gateway={{ gateway }} dns-server={{ dns_server }}
/ip dhcp-server add name=dhcp1 interface={{ interface }} address-pool=dhcp-pool disabled=no""",
        "variables": [
            {"name": "pool_start", "type": "ip", "default": "192.168.88.100", "description": "DHCP pool start address"},
            {"name": "pool_end", "type": "ip", "default": "192.168.88.254", "description": "DHCP pool end address"},
            {"name": "gateway", "type": "ip", "default": "192.168.88.1", "description": "Default gateway"},
            {"name": "dns_server", "type": "ip", "default": "8.8.8.8", "description": "DNS server address"},
            {"name": "interface", "type": "string", "default": "bridge-lan", "description": "Interface to serve DHCP on"},
        ],
    },
    # WPA2-PSK access point on wlan1.
    {
        "name": "Wireless AP Config",
        "description": "Configure wireless access point with WPA2 security",
        "content": """/interface wireless security-profiles add name=portal-wpa2 mode=dynamic-keys authentication-types=wpa2-psk wpa2-pre-shared-key={{ password }}
/interface wireless set wlan1 mode=ap-bridge ssid={{ ssid }} security-profile=portal-wpa2 frequency={{ frequency }} channel-width={{ channel_width }} disabled=no""",
        "variables": [
            {"name": "ssid", "type": "string", "default": "MikroTik-AP", "description": "Wireless network name"},
            {"name": "password", "type": "string", "default": "", "description": "WPA2 pre-shared key (min 8 characters)"},
            {"name": "frequency", "type": "integer", "default": "2412", "description": "Wireless frequency in MHz"},
            {"name": "channel_width", "type": "string", "default": "20/40mhz-XX", "description": "Channel width setting"},
        ],
    },
    # Day-zero hardening: identity, NTP, DNS, disable insecure services.
    {
        "name": "Initial Device Setup",
        "description": "Set device identity, NTP, DNS, and disable unused services",
        "content": """/system identity set name={{ device.hostname }}
/system ntp client set enabled=yes servers={{ ntp_server }}
/ip dns set servers={{ dns_servers }} allow-remote-requests=no
/ip service disable telnet,ftp,www,api-ssl
/ip service set ssh port=22
/ip service set winbox port=8291""",
        "variables": [
            {"name": "ntp_server", "type": "ip", "default": "pool.ntp.org", "description": "NTP server address"},
            {"name": "dns_servers", "type": "string", "default": "8.8.8.8,8.8.4.4", "description": "Comma-separated DNS servers"},
        ],
    },
]
async def _seed_starter_templates(db, tenant_id) -> None:
    """Insert starter config templates for a newly created tenant.

    The caller is responsible for committing; this only issues the INSERTs.
    """
    import json as _json

    for spec in _STARTER_TEMPLATES:
        params = {
            "tid": str(tenant_id),
            "name": spec["name"],
            "desc": spec["description"],
            "content": spec["content"],
            "vars": _json.dumps(spec["variables"]),
        }
        await db.execute(text("""
            INSERT INTO config_templates (id, tenant_id, name, description, content, variables)
            VALUES (gen_random_uuid(), CAST(:tid AS uuid), :name, :desc, :content, CAST(:vars AS jsonb))
        """), params)

View File

@@ -0,0 +1,374 @@
"""
Network topology inference endpoint.
Endpoint: GET /api/tenants/{tenant_id}/topology
Builds a topology graph of managed devices by:
1. Querying all devices for the tenant (via RLS)
2. Fetching /ip/neighbor tables from online devices via NATS
3. Matching neighbor addresses to known devices
4. Falling back to shared /24 subnet inference when neighbor data is unavailable
5. Caching results in Redis with 5-minute TTL
"""
import asyncio
import ipaddress
import json
import logging
import uuid
from typing import Any
import redis.asyncio as aioredis
import structlog
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import get_db, set_tenant_context
from app.middleware.rbac import require_min_role
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.models.device import Device
from app.models.vpn import VpnPeer
from app.services import routeros_proxy
logger = structlog.get_logger(__name__)
router = APIRouter(tags=["topology"])
# ---------------------------------------------------------------------------
# Redis connection (lazy initialized, same pattern as routeros_proxy NATS)
# ---------------------------------------------------------------------------
# Module-level singleton connection; created lazily on first use.
_redis: aioredis.Redis | None = None
TOPOLOGY_CACHE_TTL = 300  # 5 minutes
async def _get_redis() -> aioredis.Redis:
    """Get or create a Redis connection for topology caching.

    Opens one module-wide connection (decode_responses=True so cached JSON
    round-trips as str) and reuses it for all subsequent calls.
    """
    global _redis
    if _redis is None:
        _redis = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
        logger.info("Topology Redis connection established")
    return _redis
# ---------------------------------------------------------------------------
# Response schemas
# ---------------------------------------------------------------------------
class TopologyNode(BaseModel):
    """One managed device rendered as a graph node."""
    # Device UUID as a string; referenced by TopologyEdge.source/target.
    id: str
    hostname: str
    ip: str
    # Device status string from the Device model — exact values not visible
    # here; presumably online/offline (TODO confirm against Device).
    status: str
    model: str | None
    # Human-readable uptime (e.g. "2d 3h 4m") or None when unknown.
    uptime: str | None
class TopologyEdge(BaseModel):
    """An undirected link between two device nodes (ids match TopologyNode.id)."""
    source: str
    target: str
    # Interface name from neighbor discovery, or "shared subnet" for inferred links.
    label: str
class TopologyResponse(BaseModel):
    """Full topology graph returned by GET /tenants/{tenant_id}/topology."""
    nodes: list[TopologyNode]
    edges: list[TopologyEdge]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _check_tenant_access(
current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
"""Verify the current user is allowed to access the given tenant."""
if current_user.is_super_admin:
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied: you do not belong to this tenant.",
)
def _format_uptime(seconds: int | None) -> str | None:
"""Convert uptime seconds to a human-readable string."""
if seconds is None:
return None
days = seconds // 86400
hours = (seconds % 86400) // 3600
minutes = (seconds % 3600) // 60
if days > 0:
return f"{days}d {hours}h {minutes}m"
if hours > 0:
return f"{hours}h {minutes}m"
return f"{minutes}m"
def _get_subnet_key(ip_str: str) -> str | None:
"""Return the /24 network key for an IPv4 address, or None if invalid."""
try:
addr = ipaddress.ip_address(ip_str)
if isinstance(addr, ipaddress.IPv4Address):
network = ipaddress.ip_network(f"{ip_str}/24", strict=False)
return str(network)
except ValueError:
pass
return None
def _build_edges_from_neighbors(
    neighbor_data: dict[str, list[dict[str, Any]]],
    ip_to_device: dict[str, str],
) -> list[TopologyEdge]:
    """Build topology edges from neighbor discovery results.

    Each neighbor entry is matched against known device IPs; self-links and
    duplicate A<->B pairs are dropped.

    Args:
        neighbor_data: Mapping of device_id -> list of neighbor entries.
        ip_to_device: Mapping of IP address -> device_id for known devices.

    Returns:
        De-duplicated list of topology edges.
    """
    linked: set[tuple[str, str]] = set()
    edges: list[TopologyEdge] = []
    for source_id, entries in neighbor_data.items():
        for entry in entries:
            # RouterOS reports the neighbor IP under 'address' (or 'address4').
            ip = entry.get("address") or entry.get("address4", "")
            if not ip:
                continue
            target_id = ip_to_device.get(ip)
            if target_id is None or target_id == source_id:
                continue
            # Collapse bidirectional sightings (A->B and B->A) into one edge.
            key = tuple(sorted((source_id, target_id)))
            if key in linked:
                continue
            linked.add(key)
            edges.append(
                TopologyEdge(
                    source=source_id,
                    target=target_id,
                    label=entry.get("interface", "neighbor"),
                )
            )
    return edges
def _build_edges_from_subnets(
    devices: list[Device],
    existing_connected: set[tuple[str, str]],
) -> list[TopologyEdge]:
    """Infer edges from shared /24 subnets for devices without neighbor data.

    Only adds subnet-based edges for device pairs that are NOT already in
    ``existing_connected``. NOTE: mutates ``existing_connected`` in place,
    recording every pair it connects.
    """
    # Bucket device ids by their /24 network (IPv4 only).
    groups: dict[str, list[str]] = {}
    for dev in devices:
        key = _get_subnet_key(dev.ip_address)
        if key:
            groups.setdefault(key, []).append(str(dev.id))
    inferred: list[TopologyEdge] = []
    for members in groups.values():
        if len(members) < 2:
            continue
        # Fully connect every pair sharing the subnet.
        for idx, src in enumerate(members):
            for tgt in members[idx + 1 :]:
                pair = tuple(sorted((src, tgt)))
                if pair in existing_connected:
                    continue
                existing_connected.add(pair)
                inferred.append(
                    TopologyEdge(
                        source=src,
                        target=tgt,
                        label="shared subnet",
                    )
                )
    return inferred
# ---------------------------------------------------------------------------
# Endpoint
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/topology",
    response_model=TopologyResponse,
    summary="Get network topology for a tenant",
)
async def get_topology(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> TopologyResponse:
    """Build and return a network topology graph for the given tenant.

    The topology is inferred from three sources, applied in order:
    1. LLDP/CDP/MNDP neighbor discovery on online devices
    2. WireGuard VPN peer records (hub-and-spoke star edges)
    3. Shared /24 subnet fallback for still-unconnected device pairs

    Results are cached in Redis with a 5-minute TTL; Redis failures are
    non-fatal and simply force a fresh computation.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    cache_key = f"topology:{tenant_id}"
    # Check Redis cache first (best-effort: any error degrades to a rebuild).
    try:
        rd = await _get_redis()
        cached = await rd.get(cache_key)
        if cached:
            data = json.loads(cached)
            return TopologyResponse(**data)
    except Exception as exc:
        logger.warning("Redis cache read failed, computing topology fresh", error=str(exc))
    # Fetch all devices for tenant (RLS enforced via get_db)
    result = await db.execute(
        select(
            Device.id,
            Device.hostname,
            Device.ip_address,
            Device.status,
            Device.model,
            Device.uptime_seconds,
        )
    )
    rows = result.all()
    if not rows:
        # No devices: return an empty graph (nothing worth caching).
        return TopologyResponse(nodes=[], edges=[])
    # Build nodes plus the lookup tables the edge builders need.
    # (Dropped the previously unused `devices_by_id` accumulator.)
    nodes: list[TopologyNode] = []
    ip_to_device: dict[str, str] = {}
    online_device_ids: list[str] = []
    for row in rows:
        device_id = str(row.id)
        nodes.append(
            TopologyNode(
                id=device_id,
                hostname=row.hostname,
                ip=row.ip_address,
                status=row.status,
                model=row.model,
                uptime=_format_uptime(row.uptime_seconds),
            )
        )
        ip_to_device[row.ip_address] = device_id
        if row.status == "online":
            online_device_ids.append(device_id)
    # Fetch neighbor tables from online devices in parallel. One device
    # failing must not sink the whole request (return_exceptions=True).
    neighbor_data: dict[str, list[dict[str, Any]]] = {}
    if online_device_ids:
        tasks = [
            routeros_proxy.execute_command(
                device_id, "/ip/neighbor/print", timeout=10.0
            )
            for device_id in online_device_ids
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for device_id, res in zip(online_device_ids, results):
            if isinstance(res, Exception):
                logger.warning(
                    "Neighbor fetch failed",
                    device_id=device_id,
                    error=str(res),
                )
                continue
            if isinstance(res, dict) and res.get("success") and res.get("data"):
                neighbor_data[device_id] = res["data"]
    # Build edges from neighbor discovery
    neighbor_edges = _build_edges_from_neighbors(neighbor_data, ip_to_device)
    # Track connected pairs so the later stages don't duplicate edges.
    connected_pairs: set[tuple[str, str]] = set()
    for edge in neighbor_edges:
        connected_pairs.add(tuple(sorted([edge.source, edge.target])))
    # VPN-based edges: query WireGuard peers to infer hub-spoke topology.
    # VPN peers all connect to the same WireGuard server. The gateway device
    # is the managed device NOT in the VPN peers list (it's the server, not a
    # client). If found, create star edges from gateway to each VPN peer device.
    vpn_edges: list[TopologyEdge] = []
    vpn_peer_device_ids: set[str] = set()
    try:
        peer_result = await db.execute(
            select(VpnPeer.device_id).where(VpnPeer.is_enabled.is_(True))
        )
        vpn_peer_device_ids = {str(row[0]) for row in peer_result.all()}
        if vpn_peer_device_ids:
            # Gateway = managed devices NOT in VPN peers (typically the Core router)
            all_device_ids = {str(row.id) for row in rows}
            gateway_ids = all_device_ids - vpn_peer_device_ids
            # Pick the gateway that's online (prefer online devices)
            gateway_id = None
            for gid in gateway_ids:
                if gid in online_device_ids:
                    gateway_id = gid
                    break
            if not gateway_id and gateway_ids:
                gateway_id = next(iter(gateway_ids))
            if gateway_id:
                for peer_device_id in vpn_peer_device_ids:
                    edge_key = tuple(sorted([gateway_id, peer_device_id]))
                    if edge_key not in connected_pairs:
                        vpn_edges.append(
                            TopologyEdge(
                                source=gateway_id,
                                target=peer_device_id,
                                label="vpn tunnel",
                            )
                        )
                        connected_pairs.add(edge_key)
    except Exception as exc:
        # VPN enrichment is optional — log and continue with what we have.
        logger.warning("VPN edge detection failed", error=str(exc))
    # Fallback: infer connections from shared /24 subnets.
    # Needs full Device objects (ip_address attr) for subnet analysis.
    device_result = await db.execute(select(Device))
    all_devices = list(device_result.scalars().all())
    subnet_edges = _build_edges_from_subnets(all_devices, connected_pairs)
    all_edges = neighbor_edges + vpn_edges + subnet_edges
    topology = TopologyResponse(nodes=nodes, edges=all_edges)
    # Cache result in Redis (best-effort; failures are logged, not raised).
    try:
        rd = await _get_redis()
        await rd.set(cache_key, topology.model_dump_json(), ex=TOPOLOGY_CACHE_TTL)
    except Exception as exc:
        logger.warning("Redis cache write failed", error=str(exc))
    return topology

View File

@@ -0,0 +1,391 @@
"""Transparency log API endpoints.
Tenant-scoped routes under /api/tenants/{tenant_id}/ for:
- Paginated, filterable key access transparency log listing
- Transparency log statistics (total events, last 24h, unique devices, justification breakdown)
- CSV export of transparency logs
RLS enforced via get_db() (app_user engine with tenant context).
RBAC: admin and above can view transparency logs (tenant_admin or super_admin).
Phase 31: Data Access Transparency Dashboard - TRUST-01, TRUST-02
Shows tenant admins every KMS credential access event for their tenant.
"""
import csv
import io
import logging
import uuid
from datetime import datetime
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy import and_, func, select, text
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db, set_tenant_context
from app.middleware.tenant_context import CurrentUser, get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(tags=["transparency"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Verify the current user is allowed to access the given tenant.

    Super admins may view any tenant; the RLS context is pinned to the
    requested tenant for them. Everyone else must belong to the tenant.
    """
    if current_user.is_super_admin:
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )
def _require_admin(current_user: CurrentUser) -> None:
    """Raise 403 unless the user holds at least an admin-level role.

    Transparency data is sensitive operational intelligence --
    only tenant_admin and super_admin can view it.
    """
    if current_user.role in ("super_admin", "admin", "tenant_admin"):
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="At least admin role required to view transparency logs.",
    )
# ---------------------------------------------------------------------------
# Response models
# ---------------------------------------------------------------------------
class TransparencyLogItem(BaseModel):
    """One KMS credential access event, joined with display fields.

    Optional fields are None when the joined user/device row no longer
    exists or the column was NULL in key_access_log.
    """
    id: str
    action: str
    device_name: Optional[str] = None  # devices.hostname via LEFT JOIN
    device_id: Optional[str] = None
    justification: Optional[str] = None
    operator_email: Optional[str] = None  # users.email via LEFT JOIN
    correlation_id: Optional[str] = None
    resource_type: Optional[str] = None
    resource_id: Optional[str] = None
    ip_address: Optional[str] = None
    created_at: str  # ISO-8601 timestamp; empty string when missing
class TransparencyLogResponse(BaseModel):
    """One page of transparency log items plus pagination metadata."""
    items: list[TransparencyLogItem]
    total: int  # total matching rows, not just this page
    page: int
    per_page: int
class TransparencyStats(BaseModel):
    """Aggregate statistics over a tenant's key-access log."""
    total_events: int
    events_last_24h: int
    unique_devices: int
    justification_breakdown: dict[str, int]  # justification label -> event count
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.get(
    "/tenants/{tenant_id}/transparency-logs",
    response_model=TransparencyLogResponse,
    summary="List KMS credential access events for tenant",
)
async def list_transparency_logs(
    tenant_id: uuid.UUID,
    page: int = Query(default=1, ge=1),
    per_page: int = Query(default=50, ge=1, le=100),
    device_id: Optional[uuid.UUID] = Query(default=None),
    justification: Optional[str] = Query(default=None),
    action: Optional[str] = Query(default=None),
    date_from: Optional[datetime] = Query(default=None),
    date_to: Optional[datetime] = Query(default=None),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> Any:
    """Return one paginated, filterable page of key-access events.

    All filters are optional and ANDed together. Requires at least an
    admin-level role; the explicit tenant_id condition (plus RLS via
    get_db) scopes rows to the requested tenant.
    """
    _require_admin(current_user)
    await _check_tenant_access(current_user, tenant_id, db)
    # Build filter conditions using parameterized text fragments
    # (bind params only — never string interpolation — so user-supplied
    # filter values cannot inject SQL).
    conditions = [text("k.tenant_id = :tenant_id")]
    params: dict[str, Any] = {"tenant_id": str(tenant_id)}
    if device_id:
        conditions.append(text("k.device_id = :device_id"))
        params["device_id"] = str(device_id)
    if justification:
        conditions.append(text("k.justification = :justification"))
        params["justification"] = justification
    if action:
        conditions.append(text("k.action = :action"))
        params["action"] = action
    if date_from:
        conditions.append(text("k.created_at >= :date_from"))
        params["date_from"] = date_from.isoformat()
    if date_to:
        conditions.append(text("k.created_at <= :date_to"))
        params["date_to"] = date_to.isoformat()
    where_clause = and_(*conditions)
    # Shared SELECT columns for data queries
    _data_columns = text(
        "k.id, k.action, d.hostname AS device_name, "
        "k.device_id, k.justification, u.email AS operator_email, "
        "k.correlation_id, k.resource_type, k.resource_id, "
        "k.ip_address, k.created_at"
    )
    # LEFT JOINs so events survive deleted users/devices (names become NULL).
    _data_from = text(
        "key_access_log k "
        "LEFT JOIN users u ON k.user_id = u.id "
        "LEFT JOIN devices d ON k.device_id = d.id"
    )
    # Count total — no joins needed since all filters touch only k.* columns.
    count_result = await db.execute(
        select(func.count())
        .select_from(text("key_access_log k"))
        .where(where_clause),
        params,
    )
    total = count_result.scalar() or 0
    # Paginated query
    offset = (page - 1) * per_page
    # NOTE(review): pagination is applied via .limit()/.offset() below; the
    # :limit/:offset bind params set here look unused — confirm before removing.
    params["limit"] = per_page
    params["offset"] = offset
    result = await db.execute(
        select(_data_columns)
        .select_from(_data_from)
        .where(where_clause)
        .order_by(text("k.created_at DESC"))
        .limit(per_page)
        .offset(offset),
        params,
    )
    rows = result.mappings().all()
    items = [
        TransparencyLogItem(
            id=str(row["id"]),
            action=row["action"],
            device_name=row["device_name"],
            device_id=str(row["device_id"]) if row["device_id"] else None,
            justification=row["justification"],
            operator_email=row["operator_email"],
            correlation_id=row["correlation_id"],
            resource_type=row["resource_type"],
            resource_id=row["resource_id"],
            ip_address=row["ip_address"],
            created_at=row["created_at"].isoformat() if row["created_at"] else "",
        )
        for row in rows
    ]
    return TransparencyLogResponse(
        items=items,
        total=total,
        page=page,
        per_page=per_page,
    )
@router.get(
    "/tenants/{tenant_id}/transparency-logs/stats",
    response_model=TransparencyStats,
    summary="Get transparency log statistics",
)
async def get_transparency_stats(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> TransparencyStats:
    """Aggregate the tenant's key-access log into dashboard statistics.

    Returns the all-time event count, the trailing-24h count, the number
    of distinct devices involved, and a per-justification breakdown
    (NULL justifications grouped under the 'system' label).
    """
    _require_admin(current_user)
    await _check_tenant_access(current_user, tenant_id, db)
    params: dict[str, Any] = {"tenant_id": str(tenant_id)}
    # Total events
    total_result = await db.execute(
        select(func.count())
        .select_from(text("key_access_log"))
        .where(text("tenant_id = :tenant_id")),
        params,
    )
    total_events = total_result.scalar() or 0
    # Events in last 24 hours — window computed database-side via NOW(),
    # so it uses the DB clock, not the app server's.
    last_24h_result = await db.execute(
        select(func.count())
        .select_from(text("key_access_log"))
        .where(
            and_(
                text("tenant_id = :tenant_id"),
                text("created_at >= NOW() - INTERVAL '24 hours'"),
            )
        ),
        params,
    )
    events_last_24h = last_24h_result.scalar() or 0
    # Unique devices (events with no device are excluded)
    unique_devices_result = await db.execute(
        select(func.count(text("DISTINCT device_id")))
        .select_from(text("key_access_log"))
        .where(
            and_(
                text("tenant_id = :tenant_id"),
                text("device_id IS NOT NULL"),
            )
        ),
        params,
    )
    unique_devices = unique_devices_result.scalar() or 0
    # Justification breakdown
    breakdown_result = await db.execute(
        select(
            text("COALESCE(justification, 'system') AS justification_label"),
            func.count().label("count"),
        )
        .select_from(text("key_access_log"))
        .where(text("tenant_id = :tenant_id"))
        .group_by(text("justification_label")),
        params,
    )
    justification_breakdown: dict[str, int] = {}
    for row in breakdown_result.mappings().all():
        justification_breakdown[row["justification_label"]] = row["count"]
    return TransparencyStats(
        total_events=total_events,
        events_last_24h=events_last_24h,
        unique_devices=unique_devices,
        justification_breakdown=justification_breakdown,
    )
@router.get(
    "/tenants/{tenant_id}/transparency-logs/export",
    summary="Export transparency logs as CSV",
)
async def export_transparency_logs(
    tenant_id: uuid.UUID,
    device_id: Optional[uuid.UUID] = Query(default=None),
    justification: Optional[str] = Query(default=None),
    action: Optional[str] = Query(default=None),
    date_from: Optional[datetime] = Query(default=None),
    date_to: Optional[datetime] = Query(default=None),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> StreamingResponse:
    """Export the (optionally filtered) transparency log as a CSV download.

    Applies the same filters as the list endpoint but without pagination.
    NOTE(review): all rows are materialized in memory before the response
    is built — the StreamingResponse wraps a single pre-built string, so
    very large logs could be memory-heavy; confirm expected volumes.
    """
    _require_admin(current_user)
    await _check_tenant_access(current_user, tenant_id, db)
    # Build filter conditions (parameterized text fragments — no SQL injection)
    conditions = [text("k.tenant_id = :tenant_id")]
    params: dict[str, Any] = {"tenant_id": str(tenant_id)}
    if device_id:
        conditions.append(text("k.device_id = :device_id"))
        params["device_id"] = str(device_id)
    if justification:
        conditions.append(text("k.justification = :justification"))
        params["justification"] = justification
    if action:
        conditions.append(text("k.action = :action"))
        params["action"] = action
    if date_from:
        conditions.append(text("k.created_at >= :date_from"))
        params["date_from"] = date_from.isoformat()
    if date_to:
        conditions.append(text("k.created_at <= :date_to"))
        params["date_to"] = date_to.isoformat()
    where_clause = and_(*conditions)
    _data_columns = text(
        "k.id, k.action, d.hostname AS device_name, "
        "k.device_id, k.justification, u.email AS operator_email, "
        "k.correlation_id, k.resource_type, k.resource_id, "
        "k.ip_address, k.created_at"
    )
    # LEFT JOINs so events survive deleted users/devices (names become NULL).
    _data_from = text(
        "key_access_log k "
        "LEFT JOIN users u ON k.user_id = u.id "
        "LEFT JOIN devices d ON k.device_id = d.id"
    )
    result = await db.execute(
        select(_data_columns)
        .select_from(_data_from)
        .where(where_clause)
        .order_by(text("k.created_at DESC")),
        params,
    )
    all_rows = result.mappings().all()
    # Serialize to CSV in memory; NULLs become empty cells.
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow([
        "ID",
        "Action",
        "Device Name",
        "Device ID",
        "Justification",
        "Operator Email",
        "Correlation ID",
        "Resource Type",
        "Resource ID",
        "IP Address",
        "Timestamp",
    ])
    for row in all_rows:
        writer.writerow([
            str(row["id"]),
            row["action"],
            row["device_name"] or "",
            str(row["device_id"]) if row["device_id"] else "",
            row["justification"] or "",
            row["operator_email"] or "",
            row["correlation_id"] or "",
            row["resource_type"] or "",
            row["resource_id"] or "",
            row["ip_address"] or "",
            str(row["created_at"]),
        ])
    output.seek(0)
    return StreamingResponse(
        iter([output.getvalue()]),
        media_type="text/csv",
        headers={
            "Content-Disposition": "attachment; filename=transparency-logs.csv"
        },
    )

View File

@@ -0,0 +1,231 @@
"""
User management endpoints (scoped to tenant).
GET /api/tenants/{tenant_id}/users — list users in tenant
POST /api/tenants/{tenant_id}/users — create user in tenant
GET /api/tenants/{tenant_id}/users/{id} — get user detail
PUT /api/tenants/{tenant_id}/users/{id} — update user
DELETE /api/tenants/{tenant_id}/users/{id} — deactivate user
"""
import uuid
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.middleware.rate_limit import limiter
from app.database import get_admin_db
from app.middleware.rbac import require_tenant_admin_or_above
from app.middleware.tenant_context import CurrentUser
from app.models.tenant import Tenant
from app.models.user import User, UserRole
from app.schemas.user import UserCreate, UserResponse, UserUpdate
from app.services.auth import hash_password
router = APIRouter(prefix="/tenants", tags=["users"])
async def _check_tenant_access(
    tenant_id: uuid.UUID,
    current_user: CurrentUser,
    db: AsyncSession,
) -> Tenant:
    """Return the tenant if it exists and the current user may access it.

    super_admin can access any tenant; tenant_admin only their own.
    Raises 403 for cross-tenant access, 404 when the tenant is missing.
    """
    allowed = current_user.is_super_admin or current_user.tenant_id == tenant_id
    if not allowed:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )
    lookup = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
    tenant = lookup.scalar_one_or_none()
    if tenant is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tenant not found",
        )
    return tenant
@router.get("/{tenant_id}/users", response_model=list[UserResponse], summary="List users in tenant")
async def list_users(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_tenant_admin_or_above),
    db: AsyncSession = Depends(get_admin_db),
) -> list[UserResponse]:
    """List users in a tenant, ordered by display name.

    - super_admin: can list users in any tenant
    - tenant_admin: can only list users in their own tenant
    """
    await _check_tenant_access(tenant_id, current_user, db)
    query = (
        select(User)
        .where(User.tenant_id == tenant_id)
        .order_by(User.name)
    )
    result = await db.execute(query)
    return [UserResponse.model_validate(u) for u in result.scalars().all()]
@router.post(
    "/{tenant_id}/users",
    response_model=UserResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Create a user in tenant",
)
@limiter.limit("20/minute")
async def create_user(
    request: Request,
    tenant_id: uuid.UUID,
    data: UserCreate,
    current_user: CurrentUser = Depends(require_tenant_admin_or_above),
    db: AsyncSession = Depends(get_admin_db),
) -> UserResponse:
    """
    Create a user within a tenant.
    - super_admin: can create users in any tenant
    - tenant_admin: can only create users in their own tenant
    - No email invitation flow — admin creates accounts with temporary passwords

    Raises 409 when the email is already taken (globally unique).
    """
    # Local import keeps the module header untouched; sqlalchemy is already
    # a dependency of this file.
    from sqlalchemy.exc import IntegrityError

    await _check_tenant_access(tenant_id, current_user, db)
    # Check email uniqueness (global, not per-tenant). This pre-check gives
    # a friendly 409, but it is racy under concurrency — the commit below
    # therefore also maps a unique-constraint violation to the same 409.
    existing = await db.execute(
        select(User).where(User.email == data.email.lower())
    )
    if existing.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="A user with this email already exists",
        )
    user = User(
        email=data.email.lower(),  # normalized at rest so lookups match
        hashed_password=hash_password(data.password),
        name=data.name,
        role=data.role.value,
        tenant_id=tenant_id,
        is_active=True,
        # New accounts must upgrade auth on first login — presumably the
        # SRP registration flow; confirm against the auth service.
        must_upgrade_auth=True,
    )
    db.add(user)
    try:
        await db.commit()
    except IntegrityError as exc:
        # Lost the race against a concurrent create with the same email.
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="A user with this email already exists",
        ) from exc
    await db.refresh(user)
    return UserResponse.model_validate(user)
@router.get("/{tenant_id}/users/{user_id}", response_model=UserResponse, summary="Get user detail")
async def get_user(
    tenant_id: uuid.UUID,
    user_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_tenant_admin_or_above),
    db: AsyncSession = Depends(get_admin_db),
) -> UserResponse:
    """Fetch a single user scoped to the tenant; 404 if not found there."""
    await _check_tenant_access(tenant_id, current_user, db)
    lookup = await db.execute(
        select(User).where(User.id == user_id, User.tenant_id == tenant_id)
    )
    target = lookup.scalar_one_or_none()
    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found",
        )
    return UserResponse.model_validate(target)
@router.put("/{tenant_id}/users/{user_id}", response_model=UserResponse, summary="Update a user")
@limiter.limit("20/minute")
async def update_user(
    request: Request,
    tenant_id: uuid.UUID,
    user_id: uuid.UUID,
    data: UserUpdate,
    current_user: CurrentUser = Depends(require_tenant_admin_or_above),
    db: AsyncSession = Depends(get_admin_db),
) -> UserResponse:
    """
    Update user attributes (name, role, is_active).
    Partial update: fields left as None in the payload are unchanged.
    Role assignment is editable by admins.
    """
    await _check_tenant_access(tenant_id, current_user, db)
    lookup = await db.execute(
        select(User).where(User.id == user_id, User.tenant_id == tenant_id)
    )
    target = lookup.scalar_one_or_none()
    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found",
        )
    # Apply only the fields the caller actually provided.
    if data.name is not None:
        target.name = data.name
    if data.role is not None:
        target.role = data.role.value
    if data.is_active is not None:
        target.is_active = data.is_active
    await db.commit()
    await db.refresh(target)
    return UserResponse.model_validate(target)
@router.delete("/{tenant_id}/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Deactivate a user")
@limiter.limit("5/minute")
async def deactivate_user(
    request: Request,
    tenant_id: uuid.UUID,
    user_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_tenant_admin_or_above),
    db: AsyncSession = Depends(get_admin_db),
) -> None:
    """
    Deactivate a user (soft delete — sets is_active=False).
    The row is kept so the audit trail survives; login is blocked.
    """
    await _check_tenant_access(tenant_id, current_user, db)
    lookup = await db.execute(
        select(User).where(User.id == user_id, User.tenant_id == tenant_id)
    )
    target = lookup.scalar_one_or_none()
    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found",
        )
    # Admins may not lock themselves out by deactivating their own account.
    if target.id == current_user.user_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot deactivate your own account",
        )
    target.is_active = False
    await db.commit()

236
backend/app/routers/vpn.py Normal file
View File

@@ -0,0 +1,236 @@
"""WireGuard VPN API endpoints.
Tenant-scoped routes under /api/tenants/{tenant_id}/vpn/ for:
- VPN setup (enable WireGuard for tenant)
- VPN config management (update endpoint, enable/disable)
- Peer management (add device, remove, get config)
RLS enforced via get_db() (app_user engine with tenant context).
RBAC: operator and above for all operations.
"""
import uuid
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db, set_tenant_context
from app.middleware.rate_limit import limiter
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.models.device import Device
from app.schemas.vpn import (
VpnConfigResponse,
VpnConfigUpdate,
VpnOnboardRequest,
VpnOnboardResponse,
VpnPeerConfig,
VpnPeerCreate,
VpnPeerResponse,
VpnSetupRequest,
)
from app.services import vpn_service
router = APIRouter(tags=["vpn"])
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Allow super admins (pinning RLS to the tenant); otherwise require ownership."""
    if current_user.is_super_admin:
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
def _require_operator(current_user: CurrentUser) -> None:
    """Reject read-only (viewer) users; every other role may mutate VPN state."""
    if current_user.role != "viewer":
        return
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Operator role required")
# ── VPN Config ──
@router.get("/tenants/{tenant_id}/vpn", response_model=VpnConfigResponse | None)
async def get_vpn_config(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return this tenant's VPN configuration (with peer count), or None if unset."""
    await _check_tenant_access(current_user, tenant_id, db)
    config = await vpn_service.get_vpn_config(db, tenant_id)
    if config is None:
        return None
    response = VpnConfigResponse.model_validate(config)
    response.peer_count = len(await vpn_service.get_peers(db, tenant_id))
    return response
@router.post("/tenants/{tenant_id}/vpn", response_model=VpnConfigResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("20/minute")
async def setup_vpn(
    request: Request,
    tenant_id: uuid.UUID,
    body: VpnSetupRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Enable VPN for this tenant — generates server keys.

    A service-level ValueError is mapped to 409 Conflict.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    try:
        config = await vpn_service.setup_vpn(db, tenant_id, endpoint=body.endpoint)
    except ValueError as e:
        # Chain the cause so the original traceback survives in logs (B904).
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
    return VpnConfigResponse.model_validate(config)
@router.patch("/tenants/{tenant_id}/vpn", response_model=VpnConfigResponse)
@limiter.limit("20/minute")
async def update_vpn_config(
    request: Request,
    tenant_id: uuid.UUID,
    body: VpnConfigUpdate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update VPN settings (endpoint, enable/disable).

    A service-level ValueError is mapped to 404 Not Found.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    try:
        config = await vpn_service.update_vpn_config(
            db, tenant_id, endpoint=body.endpoint, is_enabled=body.is_enabled
        )
    except ValueError as e:
        # Chain the cause so the original traceback survives in logs (B904).
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
    peers = await vpn_service.get_peers(db, tenant_id)
    resp = VpnConfigResponse.model_validate(config)
    resp.peer_count = len(peers)
    return resp
# ── VPN Peers ──
@router.get("/tenants/{tenant_id}/vpn/peers", response_model=list[VpnPeerResponse])
async def list_peers(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List all VPN peers for this tenant, enriched with device info and live handshakes."""
    await _check_tenant_access(current_user, tenant_id, db)
    peers = await vpn_service.get_peers(db, tenant_id)
    # Bulk-load every device referenced by the peers in a single query.
    device_map = {}
    referenced_ids = [p.device_id for p in peers]
    if referenced_ids:
        device_rows = await db.execute(select(Device).where(Device.id.in_(referenced_ids)))
        device_map = {d.id: d for d in device_rows.scalars().all()}
    # Snapshot live WireGuard state once, shared across all peers.
    wg_status = vpn_service.read_wg_status()
    out = []
    for peer in peers:
        item = VpnPeerResponse.model_validate(peer)
        dev = device_map.get(peer.device_id)
        if dev:
            item.device_hostname = dev.hostname
            item.device_ip = dev.ip_address
        # Prefer the live handshake reported by the WireGuard container.
        handshake = vpn_service.get_peer_handshake(wg_status, peer.peer_public_key)
        if handshake:
            item.last_handshake = handshake
        out.append(item)
    return out
@router.post("/tenants/{tenant_id}/vpn/peers", response_model=VpnPeerResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("20/minute")
async def add_peer(
    request: Request,
    tenant_id: uuid.UUID,
    body: VpnPeerCreate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Add a device as a VPN peer.

    A service-level ValueError is mapped to 409 Conflict. The response is
    enriched with the device's hostname and IP when the device is found.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    try:
        peer = await vpn_service.add_peer(
            db,
            tenant_id,
            body.device_id,
            additional_allowed_ips=body.additional_allowed_ips,
        )
    except ValueError as e:
        # Chain the cause so the original traceback survives in logs (B904).
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
    # Enrich with device info
    result = await db.execute(select(Device).where(Device.id == peer.device_id))
    device = result.scalar_one_or_none()
    resp = VpnPeerResponse.model_validate(peer)
    if device:
        resp.device_hostname = device.hostname
        resp.device_ip = device.ip_address
    return resp
@router.post("/tenants/{tenant_id}/vpn/peers/onboard", response_model=VpnOnboardResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("10/minute")
async def onboard_device(
    request: Request,
    tenant_id: uuid.UUID,
    body: VpnOnboardRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create device + VPN peer in one step. Returns RouterOS commands for tunnel setup.

    A service-level ValueError is mapped to 409 Conflict.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    try:
        result = await vpn_service.onboard_device(
            db, tenant_id,
            hostname=body.hostname,
            username=body.username,
            password=body.password,
        )
    except ValueError as e:
        # Chain the cause so the original traceback survives in logs (B904).
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
    return VpnOnboardResponse(**result)
@router.delete("/tenants/{tenant_id}/vpn/peers/{peer_id}", status_code=status.HTTP_204_NO_CONTENT)
@limiter.limit("5/minute")
async def remove_peer(
    request: Request,
    tenant_id: uuid.UUID,
    peer_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Remove a VPN peer.

    A service-level ValueError is mapped to 404 Not Found.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    try:
        await vpn_service.remove_peer(db, tenant_id, peer_id)
    except ValueError as e:
        # Chain the cause so the original traceback survives in logs (B904).
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
@router.get("/tenants/{tenant_id}/vpn/peers/{peer_id}/config", response_model=VpnPeerConfig)
async def get_peer_device_config(
    tenant_id: uuid.UUID,
    peer_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get the full config for a peer — includes private key and RouterOS commands.

    SECURITY: the response carries the peer's private key, which is why this
    read endpoint requires operator (not just viewer) access.
    A service-level ValueError is mapped to 404 Not Found.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    try:
        config = await vpn_service.get_peer_config(db, tenant_id, peer_id)
    except ValueError as e:
        # Chain the cause so the original traceback survives in logs (B904).
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
    return VpnPeerConfig(**config)

View File

@@ -0,0 +1,18 @@
"""Pydantic schemas for request/response validation."""
from app.schemas.auth import LoginRequest, TokenResponse, RefreshRequest, UserMeResponse
from app.schemas.tenant import TenantCreate, TenantResponse, TenantUpdate
from app.schemas.user import UserCreate, UserResponse, UserUpdate
__all__ = [
"LoginRequest",
"TokenResponse",
"RefreshRequest",
"UserMeResponse",
"TenantCreate",
"TenantResponse",
"TenantUpdate",
"UserCreate",
"UserResponse",
"UserUpdate",
]

123
backend/app/schemas/auth.py Normal file
View File

@@ -0,0 +1,123 @@
"""Authentication request/response schemas."""
import uuid
from typing import Optional
from pydantic import BaseModel, EmailStr
class LoginRequest(BaseModel):
    """Email/password login request body."""
    email: EmailStr
    password: str
class TokenResponse(BaseModel):
    """JWT pair issued on a successful login or refresh."""
    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    auth_upgrade_required: bool = False  # True when bcrypt user needs SRP registration
class RefreshRequest(BaseModel):
    """Request body carrying the refresh token to exchange for new tokens."""
    refresh_token: str
class UserMeResponse(BaseModel):
    """Profile of the authenticated user (tenant_id is None for super admins — confirm)."""
    id: uuid.UUID
    email: str
    name: str
    role: str
    tenant_id: Optional[uuid.UUID] = None
    auth_version: int = 1  # presumably 1 = bcrypt, higher = SRP — verify against auth service
    model_config = {"from_attributes": True}
class ChangePasswordRequest(BaseModel):
    """Password-change payload; SRP users also rotate their key bundle."""
    current_password: str
    new_password: str
    # SRP users must provide re-derived credentials
    new_srp_salt: Optional[str] = None
    new_srp_verifier: Optional[str] = None
    # Re-wrapped key bundle (SRP users re-encrypt with new AUK)
    encrypted_private_key: Optional[str] = None
    private_key_nonce: Optional[str] = None
    encrypted_vault_key: Optional[str] = None
    vault_key_nonce: Optional[str] = None
    public_key: Optional[str] = None
    pbkdf2_salt: Optional[str] = None
    hkdf_salt: Optional[str] = None
class ForgotPasswordRequest(BaseModel):
    """Request body to start a password-reset flow for the given email."""
    email: EmailStr
class ResetPasswordRequest(BaseModel):
    """Completes a password reset using the emailed token."""
    token: str
    new_password: str
class MessageResponse(BaseModel):
    """Generic single-message response envelope."""
    message: str
# --- SRP Zero-Knowledge Authentication Schemas ---
class SRPInitRequest(BaseModel):
    """Step 1 request: client sends email to begin SRP handshake."""
    email: EmailStr
class SRPInitResponse(BaseModel):
    """Step 1 response: server returns ephemeral B and key derivation salts."""
    salt: str  # hex-encoded SRP salt
    server_public: str  # hex-encoded server ephemeral B
    session_id: str  # Redis session key nonce
    pbkdf2_salt: str  # base64-encoded, from user_key_sets (needed for 2SKD before SRP verify)
    hkdf_salt: str  # base64-encoded, from user_key_sets (needed for 2SKD before SRP verify)
class SRPVerifyRequest(BaseModel):
"""Step 2 request: client sends proof M1 to complete handshake."""
email: EmailStr
session_id: str
client_public: str # hex-encoded client ephemeral A
client_proof: str # hex-encoded client proof M1
class SRPVerifyResponse(BaseModel):
"""Step 2 response: server returns tokens and proof M2."""
access_token: str
refresh_token: str
token_type: str = "bearer"
server_proof: str # hex-encoded server proof M2
encrypted_key_set: Optional[dict] = None # Key bundle for client-side decryption
class SRPRegisterRequest(BaseModel):
"""Used during registration to store SRP verifier and key set."""
srp_salt: str # hex-encoded
srp_verifier: str # hex-encoded
encrypted_private_key: str # base64-encoded
private_key_nonce: str # base64-encoded
encrypted_vault_key: str # base64-encoded
vault_key_nonce: str # base64-encoded
public_key: str # base64-encoded
pbkdf2_salt: str # base64-encoded
hkdf_salt: str # base64-encoded
# --- Account Self-Service Schemas ---
class DeleteAccountRequest(BaseModel):
"""Request body for account self-deletion. User must type 'DELETE' to confirm."""
confirmation: str # Must be "DELETE" to confirm
class DeleteAccountResponse(BaseModel):
"""Response after successful account deletion."""
message: str
deleted: bool

View File

@@ -0,0 +1,78 @@
"""Pydantic request/response schemas for the Internal Certificate Authority."""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, ConfigDict
# ---------------------------------------------------------------------------
# Request schemas
# ---------------------------------------------------------------------------


class CACreateRequest(BaseModel):
    """Generate a new root CA for the tenant."""

    common_name: str = "Portal Root CA"
    validity_years: int = 10  # CAs default to a 10-year lifetime


class CertSignRequest(BaseModel):
    """Sign a per-device certificate using the tenant CA."""

    device_id: UUID
    validity_days: int = 730  # device certs default to 2 years


class BulkCertDeployRequest(BaseModel):
    """Deploy certificates to multiple devices in one call."""

    device_ids: list[UUID]


# ---------------------------------------------------------------------------
# Response schemas
# ---------------------------------------------------------------------------


class CAResponse(BaseModel):
    """Public details of a tenant's Certificate Authority (no private key)."""

    id: UUID
    tenant_id: UUID
    common_name: str
    fingerprint_sha256: str
    serial_number: str
    not_valid_before: datetime
    not_valid_after: datetime
    created_at: datetime

    model_config = ConfigDict(from_attributes=True)


class DeviceCertResponse(BaseModel):
    """Public details of a device certificate (no private key)."""

    id: UUID
    tenant_id: UUID
    device_id: UUID
    ca_id: UUID
    common_name: str
    fingerprint_sha256: str
    serial_number: str
    not_valid_before: datetime
    not_valid_after: datetime
    status: str
    deployed_at: datetime | None
    created_at: datetime
    updated_at: datetime

    model_config = ConfigDict(from_attributes=True)


class CertDeployResponse(BaseModel):
    """Outcome of a single device certificate deployment attempt."""

    success: bool
    device_id: UUID
    cert_name_on_device: str | None = None
    error: str | None = None

View File

@@ -0,0 +1,271 @@
"""Pydantic schemas for Device, DeviceGroup, and DeviceTag endpoints."""
import uuid
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, field_validator
# ---------------------------------------------------------------------------
# Device schemas
# ---------------------------------------------------------------------------


class DeviceCreate(BaseModel):
    """Payload for registering a new device."""

    hostname: str
    ip_address: str
    api_port: int = 8728
    api_ssl_port: int = 8729
    username: str
    password: str


class DeviceUpdate(BaseModel):
    """Partial device update; omitted fields are left unchanged."""

    hostname: Optional[str] = None
    ip_address: Optional[str] = None
    api_port: Optional[int] = None
    api_ssl_port: Optional[int] = None
    username: Optional[str] = None
    password: Optional[str] = None
    latitude: Optional[float] = None
    longitude: Optional[float] = None
    tls_mode: Optional[str] = None

    @field_validator("tls_mode")
    @classmethod
    def validate_tls_mode(cls, v: Optional[str]) -> Optional[str]:
        """Accept only known TLS modes; None passes through untouched."""
        allowed = {"auto", "insecure", "plain", "portal_ca"}
        if v is not None and v not in allowed:
            raise ValueError(f"tls_mode must be one of: {', '.join(sorted(allowed))}")
        return v
class DeviceTagRef(BaseModel):
    """Minimal tag info embedded in device responses."""

    id: uuid.UUID
    name: str
    color: Optional[str] = None

    model_config = {"from_attributes": True}


class DeviceGroupRef(BaseModel):
    """Minimal group info embedded in device responses."""

    id: uuid.UUID
    name: str

    model_config = {"from_attributes": True}


class DeviceResponse(BaseModel):
    """Device response schema — credential fields are deliberately absent."""

    id: uuid.UUID
    hostname: str
    ip_address: str
    api_port: int
    api_ssl_port: int
    model: Optional[str] = None
    serial_number: Optional[str] = None
    firmware_version: Optional[str] = None
    routeros_version: Optional[str] = None
    routeros_major_version: Optional[int] = None
    uptime_seconds: Optional[int] = None
    last_seen: Optional[datetime] = None
    latitude: Optional[float] = None
    longitude: Optional[float] = None
    status: str
    tls_mode: str = "auto"
    tags: list[DeviceTagRef] = []
    groups: list[DeviceGroupRef] = []
    created_at: datetime

    model_config = {"from_attributes": True}


class DeviceListResponse(BaseModel):
    """Paginated device list response."""

    items: list[DeviceResponse]
    total: int
    page: int
    page_size: int
# ---------------------------------------------------------------------------
# Subnet scan schemas
# ---------------------------------------------------------------------------


class SubnetScanRequest(BaseModel):
    """Request body for a subnet scan."""

    cidr: str

    @field_validator("cidr")
    @classmethod
    def validate_cidr(cls, v: str) -> str:
        """Validate CIDR notation and restrict scans to RFC 1918 IPv4 ranges.

        Fix: the previous check used ``network.is_private``, which also
        admits non-RFC-1918 space (loopback 127.0.0.0/8, link-local
        169.254.0.0/16, IPv6 ULA, ...) even though the error message
        promises RFC 1918 only. The network must now actually fall inside
        one of the three RFC 1918 blocks.

        Raises:
            ValueError: On malformed CIDR, a non-RFC-1918 range, or a range
                larger than /20 (4096 addresses).
        """
        import ipaddress
        try:
            network = ipaddress.ip_network(v, strict=False)
        except ValueError as e:
            raise ValueError(f"Invalid CIDR notation: {e}") from e
        # Only allow true RFC 1918 ranges: 10/8, 172.16/12, 192.168/16.
        rfc1918 = (
            ipaddress.ip_network("10.0.0.0/8"),
            ipaddress.ip_network("172.16.0.0/12"),
            ipaddress.ip_network("192.168.0.0/16"),
        )
        # version guard first: subnet_of() raises TypeError across IP versions.
        if network.version != 4 or not any(network.subnet_of(block) for block in rfc1918):
            raise ValueError(
                "Only private IP ranges can be scanned (RFC 1918: "
                "10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16)"
            )
        # Reject ranges larger than /20 (4096 IPs) to prevent abuse.
        if network.num_addresses > 4096:
            raise ValueError(
                f"CIDR range too large ({network.num_addresses} addresses). "
                "Maximum allowed: /20 (4096 addresses)."
            )
        return v


class SubnetScanResult(BaseModel):
    """A single discovered host from a subnet scan."""

    ip_address: str
    hostname: Optional[str] = None
    api_port_open: bool = False
    api_ssl_port_open: bool = False


class SubnetScanResponse(BaseModel):
    """Response for a subnet scan operation."""

    cidr: str
    discovered: list[SubnetScanResult]
    total_scanned: int
    total_discovered: int
# ---------------------------------------------------------------------------
# Bulk add from scan
# ---------------------------------------------------------------------------


class BulkDeviceAdd(BaseModel):
    """One device entry within a bulk-add request."""

    ip_address: str
    hostname: Optional[str] = None
    api_port: int = 8728
    api_ssl_port: int = 8729
    username: Optional[str] = None
    password: Optional[str] = None


class BulkAddRequest(BaseModel):
    """Bulk-add devices selected from a scan result.

    Devices that do not carry their own credentials fall back to
    shared_username / shared_password.
    """

    devices: list[BulkDeviceAdd]
    shared_username: Optional[str] = None
    shared_password: Optional[str] = None


class BulkAddResult(BaseModel):
    """Summary result of a bulk-add operation."""

    added: list[DeviceResponse]
    failed: list[dict]  # each entry: {ip_address, error}
# ---------------------------------------------------------------------------
# DeviceGroup schemas
# ---------------------------------------------------------------------------


class DeviceGroupCreate(BaseModel):
    """Schema for creating a device group."""

    name: str
    description: Optional[str] = None


class DeviceGroupUpdate(BaseModel):
    """Schema for updating a device group."""

    name: Optional[str] = None
    description: Optional[str] = None


class DeviceGroupResponse(BaseModel):
    """Device group response schema."""

    id: uuid.UUID
    name: str
    description: Optional[str] = None
    device_count: int = 0
    created_at: datetime

    model_config = {"from_attributes": True}


# ---------------------------------------------------------------------------
# DeviceTag schemas
# ---------------------------------------------------------------------------


def _validate_hex_color(v: Optional[str]) -> Optional[str]:
    """Shared color check for tag schemas.

    Previously duplicated verbatim in DeviceTagCreate and DeviceTagUpdate;
    factored out so the two validators cannot drift apart.

    Raises:
        ValueError: If the value is not a 6-digit hex color like ``#FF5733``.
    """
    if v is None:
        return v
    import re
    if not re.match(r"^#[0-9A-Fa-f]{6}$", v):
        raise ValueError("Color must be a valid 6-digit hex color (e.g. #FF5733)")
    return v


class DeviceTagCreate(BaseModel):
    """Schema for creating a device tag."""

    name: str
    color: Optional[str] = None

    @field_validator("color")
    @classmethod
    def validate_color(cls, v: Optional[str]) -> Optional[str]:
        """Validate hex color format if provided."""
        return _validate_hex_color(v)


class DeviceTagUpdate(BaseModel):
    """Schema for updating a device tag."""

    name: Optional[str] = None
    color: Optional[str] = None

    @field_validator("color")
    @classmethod
    def validate_color(cls, v: Optional[str]) -> Optional[str]:
        """Validate hex color format if provided."""
        return _validate_hex_color(v)


class DeviceTagResponse(BaseModel):
    """Device tag response schema."""

    id: uuid.UUID
    name: str
    color: Optional[str] = None

    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,31 @@
"""Tenant request/response schemas."""
import uuid
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
class TenantCreate(BaseModel):
    """Payload for creating a tenant."""

    name: str
    description: Optional[str] = None
    contact_email: Optional[str] = None


class TenantUpdate(BaseModel):
    """Partial tenant update; omitted fields stay unchanged."""

    name: Optional[str] = None
    description: Optional[str] = None
    contact_email: Optional[str] = None


class TenantResponse(BaseModel):
    """Tenant details plus aggregate user and device counts."""

    id: uuid.UUID
    name: str
    description: Optional[str] = None
    contact_email: Optional[str] = None
    user_count: int = 0
    device_count: int = 0
    created_at: datetime

    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,53 @@
"""User request/response schemas."""
import uuid
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, EmailStr, field_validator
from app.models.user import UserRole
class UserCreate(BaseModel):
    """Payload for creating a user within a tenant."""

    name: str
    email: EmailStr
    password: str
    role: UserRole = UserRole.VIEWER

    @field_validator("password")
    @classmethod
    def validate_password(cls, v: str) -> str:
        """Enforce the minimum password length."""
        if len(v) < 8:
            raise ValueError("Password must be at least 8 characters")
        return v

    @field_validator("role")
    @classmethod
    def validate_role(cls, v: UserRole) -> UserRole:
        """Restrict creatable roles; super_admin is provisioned via a separate flow.

        Fix: the allowed roles were held in a ``set``, so the error message
        listed them in nondeterministic (hash-dependent) order. A tuple keeps
        membership semantics identical while making the message stable.
        """
        allowed_tenant_roles = (UserRole.TENANT_ADMIN, UserRole.OPERATOR, UserRole.VIEWER)
        if v not in allowed_tenant_roles:
            raise ValueError(
                f"Role must be one of: {', '.join(r.value for r in allowed_tenant_roles)}"
            )
        return v


class UserResponse(BaseModel):
    """User details returned by the API."""

    id: uuid.UUID
    name: str
    email: str
    role: str
    tenant_id: Optional[uuid.UUID] = None
    is_active: bool
    last_login: Optional[datetime] = None
    created_at: datetime

    model_config = {"from_attributes": True}


class UserUpdate(BaseModel):
    """Partial user update; omitted fields stay unchanged."""

    name: Optional[str] = None
    role: Optional[UserRole] = None
    is_active: Optional[bool] = None

View File

@@ -0,0 +1,91 @@
"""Pydantic schemas for WireGuard VPN management."""
import uuid
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
# ── VPN Config (server-side) ──


class VpnSetupRequest(BaseModel):
    """Enable VPN for a tenant."""

    # Public hostname:port — if blank, devices must be configured manually.
    endpoint: Optional[str] = None


class VpnConfigResponse(BaseModel):
    """VPN server configuration (never exposes the private key)."""

    model_config = {"from_attributes": True}

    id: uuid.UUID
    tenant_id: uuid.UUID
    server_public_key: str
    subnet: str
    server_port: int
    server_address: str
    endpoint: Optional[str]
    is_enabled: bool
    peer_count: int = 0
    created_at: datetime


class VpnConfigUpdate(BaseModel):
    """Update VPN configuration."""

    endpoint: Optional[str] = None
    is_enabled: Optional[bool] = None


# ── VPN Peers ──


class VpnPeerCreate(BaseModel):
    """Add a device as a VPN peer."""

    device_id: uuid.UUID
    # Comma-separated subnets for site-to-site routing.
    additional_allowed_ips: Optional[str] = None


class VpnPeerResponse(BaseModel):
    """VPN peer info (never exposes the private key)."""

    model_config = {"from_attributes": True}

    id: uuid.UUID
    device_id: uuid.UUID
    device_hostname: str = ""
    device_ip: str = ""
    peer_public_key: str
    assigned_ip: str
    is_enabled: bool
    last_handshake: Optional[datetime]
    created_at: datetime


# ── VPN Onboarding (combined device + peer creation) ──


class VpnOnboardRequest(BaseModel):
    """Combined device creation + VPN peer onboarding."""

    hostname: str
    username: str
    password: str


class VpnOnboardResponse(BaseModel):
    """Onboarding result — device, peer, and RouterOS commands."""

    device_id: uuid.UUID
    peer_id: uuid.UUID
    hostname: str
    assigned_ip: str
    routeros_commands: list[str]


class VpnPeerConfig(BaseModel):
    """Full peer config for display/export.

    Includes the peer private key, which the client needs once for initial
    device setup.
    """

    peer_private_key: str
    peer_public_key: str
    assigned_ip: str
    server_public_key: str
    server_endpoint: str
    allowed_ips: str
    routeros_commands: list[str]

View File

View File

@@ -0,0 +1,95 @@
"""Dangerous RouterOS command and path blocklist.
Prevents destructive or sensitive operations from being executed through
the config editor. Commands and paths are checked via case-insensitive
prefix matching against known-dangerous entries.
To extend: add strings to DANGEROUS_COMMANDS, BROWSE_BLOCKED_PATHS,
or WRITE_BLOCKED_PATHS.
"""
from fastapi import HTTPException, status
# CLI commands blocked from the execute endpoint.
# Matched as case-insensitive prefixes (e.g., "/user" blocks "/user/print" too).
# NOTE: plain prefix matching also blocks longer commands that merely share a
# prefix (e.g. "/user" blocks a hypothetical "/users") — deliberate overblocking.
DANGEROUS_COMMANDS: list[str] = [
    "/system/reset-configuration",
    "/system/shutdown",
    "/system/reboot",
    "/system/backup",
    "/system/license",
    "/user",
    "/password",
    "/certificate",
    "/radius",
    "/export",
    "/import",
]

# Paths blocked from ALL operations including browse (truly dangerous to read).
BROWSE_BLOCKED_PATHS: list[str] = [
    "system/reset-configuration",
    "system/shutdown",
    "system/reboot",
    "system/backup",
    "system/license",
    "password",
]

# Paths blocked from write operations (add/set/remove) but readable via browse.
WRITE_BLOCKED_PATHS: list[str] = [
    "user",
    "certificate",
    "radius",
]
def check_command_safety(command: str) -> None:
    """Reject dangerous CLI commands with HTTP 403.

    The command is normalized (strip + lowercase) and compared against
    DANGEROUS_COMMANDS by prefix.

    Raises:
        HTTPException: 403 if the command matches a dangerous prefix.
    """
    candidate = command.strip().lower()
    matched = next((p for p in DANGEROUS_COMMANDS if candidate.startswith(p)), None)
    if matched is not None:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=(
                f"Command blocked: '{command}' matches dangerous prefix '{matched}'. "
                f"This operation is not allowed through the config editor."
            ),
        )
def check_path_safety(path: str, *, write: bool = False) -> None:
    """Reject dangerous menu paths with HTTP 403.

    The path is normalized (strip + lowercase + lstrip '/') and compared by
    prefix against BROWSE_BLOCKED_PATHS, plus WRITE_BLOCKED_PATHS when the
    caller intends a mutating operation.

    Args:
        path: The RouterOS menu path to check.
        write: True for add/set/remove; False for read-only browse.

    Raises:
        HTTPException: 403 if the path matches a blocked prefix.
    """
    candidate = path.strip().lower().lstrip("/")
    prefixes = list(BROWSE_BLOCKED_PATHS)
    if write:
        prefixes += WRITE_BLOCKED_PATHS
    for prefix in prefixes:
        if candidate.startswith(prefix):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=(
                    f"Path blocked: '{path}' matches dangerous prefix '{prefix}'. "
                    f"This operation is not allowed through the config editor."
                ),
            )

View File

@@ -0,0 +1 @@
"""Backend services — auth, crypto, and business logic."""

View File

@@ -0,0 +1,240 @@
"""Account self-service operations: deletion and data export.
Provides GDPR/CCPA-compliant account deletion with full PII erasure
and data portability export (Article 20).
All queries use raw SQL via text() with admin sessions (bypass RLS)
since these are cross-table operations on the authenticated user's data.
"""
import hashlib
import uuid
from datetime import UTC, datetime
from typing import Any
import structlog
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import AdminAsyncSessionLocal
from app.services.audit_service import log_action
logger = structlog.get_logger("account_service")
async def delete_user_account(
    db: AsyncSession,
    user_id: uuid.UUID,
    tenant_id: uuid.UUID | None,
    user_email: str,
) -> dict[str, Any]:
    """Hard-delete a user account with full PII erasure.

    Steps (order matters — the receipt must exist before the row is gone):
    1. Create a deletion receipt audit log (persisted via separate session)
    2. Anonymize PII in existing audit_logs for this user
    3. Hard-delete the user row (CASCADE handles related tables)
    4. Best-effort session invalidation via Redis

    Args:
        db: Admin async session (bypasses RLS).
        user_id: UUID of the user to delete.
        tenant_id: Tenant UUID (None for super_admin).
        user_email: User's email (needed for audit hash before deletion).

    Returns:
        Dict with deleted=True and user_id on success.
    """
    # Super-admins have no tenant; fall back to the all-zero UUID so the
    # audit receipt still carries a tenant_id value.
    effective_tenant_id = tenant_id or uuid.UUID(int=0)
    # Only a SHA-256 hash of the email survives in the receipt — the address
    # itself must not outlive the deletion.
    email_hash = hashlib.sha256(user_email.encode()).hexdigest()

    # ── 1. Pre-deletion audit receipt (separate session so it persists) ────
    try:
        async with AdminAsyncSessionLocal() as audit_db:
            await log_action(
                audit_db,
                tenant_id=effective_tenant_id,
                user_id=user_id,
                action="account_deleted",
                resource_type="user",
                resource_id=str(user_id),
                details={
                    "deleted_user_id": str(user_id),
                    "email_hash": email_hash,
                    "deletion_type": "self_service",
                    "deleted_at": datetime.now(UTC).isoformat(),
                },
            )
            await audit_db.commit()
    except Exception:
        # A failed receipt is logged but must not block the deletion itself.
        logger.warning(
            "deletion_receipt_failed",
            user_id=str(user_id),
            exc_info=True,
        )

    # ── 2. Anonymize PII in audit_logs for this user ─────────────────────
    # Strip PII keys from details JSONB (email, name, user_email, user_name)
    await db.execute(
        text(
            "UPDATE audit_logs "
            "SET details = details - 'email' - 'name' - 'user_email' - 'user_name' "
            "WHERE user_id = :user_id"
        ),
        {"user_id": user_id},
    )
    # Null out encrypted_details (may contain encrypted PII)
    await db.execute(
        text(
            "UPDATE audit_logs "
            "SET encrypted_details = NULL "
            "WHERE user_id = :user_id"
        ),
        {"user_id": user_id},
    )

    # ── 3. Hard delete user row ──────────────────────────────────────────
    # CASCADE handles: user_key_sets, api_keys, password_reset_tokens
    # SET NULL handles: audit_logs.user_id, key_access_log.user_id,
    #   maintenance_windows.created_by, alert_events.acknowledged_by
    await db.execute(
        text("DELETE FROM users WHERE id = :user_id"),
        {"user_id": user_id},
    )
    await db.commit()

    # ── 4. Best-effort Redis session invalidation ────────────────────────
    try:
        import redis.asyncio as aioredis
        from app.config import settings
        from app.services.auth import revoke_user_tokens
        r = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
        await revoke_user_tokens(r, str(user_id))
        await r.aclose()
    except Exception:
        # JWT expires in 15 min anyway; not critical
        logger.debug("redis_session_invalidation_skipped", user_id=str(user_id))

    logger.info("account_deleted", user_id=str(user_id), email_hash=email_hash)
    return {"deleted": True, "user_id": str(user_id)}
async def export_user_data(
    db: AsyncSession,
    user_id: uuid.UUID,
    tenant_id: uuid.UUID | None,
) -> dict[str, Any]:
    """Assemble all data held about a user for GDPR Art. 20 data portability.

    Args:
        db: Admin async session (bypasses RLS).
        user_id: UUID of the user whose data to export.
        tenant_id: Tenant UUID (None for super_admin).

    Returns:
        Envelope dict with export_date, format_version, and the user's
        profile, API key metadata, audit logs, and key access log entries.
    """
    def _iso(ts: Any) -> Any:
        # Serialize a timestamp column value, passing falsy values through as None.
        return ts.isoformat() if ts else None

    # ── User profile ─────────────────────────────────────────────────────
    result = await db.execute(
        text(
            "SELECT id, email, name, role, tenant_id, "
            "created_at, last_login, auth_version "
            "FROM users WHERE id = :user_id"
        ),
        {"user_id": user_id},
    )
    profile_row = result.mappings().first()
    profile: dict[str, Any] = {}
    if profile_row:
        profile = {
            "id": str(profile_row["id"]),
            "email": profile_row["email"],
            "name": profile_row["name"],
            "role": profile_row["role"],
            "tenant_id": str(profile_row["tenant_id"]) if profile_row["tenant_id"] else None,
            "created_at": _iso(profile_row["created_at"]),
            "last_login": _iso(profile_row["last_login"]),
            "auth_version": profile_row["auth_version"],
        }

    # ── API keys (key_hash deliberately excluded) ────────────────────────
    result = await db.execute(
        text(
            "SELECT id, name, key_prefix, scopes, created_at, "
            "expires_at, revoked_at, last_used_at "
            "FROM api_keys WHERE user_id = :user_id "
            "ORDER BY created_at DESC"
        ),
        {"user_id": user_id},
    )
    api_keys = [
        {
            "id": str(row["id"]),
            "name": row["name"],
            "key_prefix": row["key_prefix"],
            "scopes": row["scopes"],
            "created_at": _iso(row["created_at"]),
            "expires_at": _iso(row["expires_at"]),
            "revoked_at": _iso(row["revoked_at"]),
            "last_used_at": _iso(row["last_used_at"]),
        }
        for row in result.mappings().all()
    ]

    # ── Audit logs (limit 1000, most recent first) ───────────────────────
    result = await db.execute(
        text(
            "SELECT id, action, resource_type, resource_id, "
            "details, ip_address, created_at "
            "FROM audit_logs WHERE user_id = :user_id "
            "ORDER BY created_at DESC LIMIT 1000"
        ),
        {"user_id": user_id},
    )
    audit_logs = [
        {
            "id": str(row["id"]),
            "action": row["action"],
            "resource_type": row["resource_type"],
            "resource_id": row["resource_id"],
            "details": row["details"] if row["details"] else {},
            "ip_address": row["ip_address"],
            "created_at": _iso(row["created_at"]),
        }
        for row in result.mappings().all()
    ]

    # ── Key access log (limit 1000, most recent first) ───────────────────
    result = await db.execute(
        text(
            "SELECT id, action, resource_type, ip_address, created_at "
            "FROM key_access_log WHERE user_id = :user_id "
            "ORDER BY created_at DESC LIMIT 1000"
        ),
        {"user_id": user_id},
    )
    key_access_entries = [
        {
            "id": str(row["id"]),
            "action": row["action"],
            "resource_type": row["resource_type"],
            "ip_address": row["ip_address"],
            "created_at": _iso(row["created_at"]),
        }
        for row in result.mappings().all()
    ]

    return {
        "export_date": datetime.now(UTC).isoformat(),
        "format_version": "1.0",
        "user": profile,
        "api_keys": api_keys,
        "audit_logs": audit_logs,
        "key_access_log": key_access_entries,
    }

View File

@@ -0,0 +1,723 @@
"""Alert rule evaluation engine with Redis breach counters and flap detection.
Entry points:
- evaluate(device_id, tenant_id, metric_type, data): called from metrics_subscriber
- evaluate_offline(device_id, tenant_id): called from nats_subscriber on device offline
- evaluate_online(device_id, tenant_id): called from nats_subscriber on device online
Uses Redis for:
- Consecutive breach counting (alert:breach:{device_id}:{rule_id})
- Flap detection (alert:flap:{device_id}:{rule_id} sorted set)
Uses AdminAsyncSessionLocal for all DB operations (runs cross-tenant in NATS handlers).
"""
import asyncio
import logging
import time
from datetime import datetime, timezone
from typing import Any
import redis.asyncio as aioredis
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.services.event_publisher import publish_event
logger = logging.getLogger(__name__)

# Module-level Redis client, lazily initialized by _get_redis().
_redis_client: aioredis.Redis | None = None

# Module-level rule cache: {tenant_id: (rules_list, fetched_at_timestamp)}
_rule_cache: dict[str, tuple[list[dict], float]] = {}
_CACHE_TTL_SECONDS = 60

# Module-level maintenance window cache: {tenant_id: (active_windows_list, fetched_at_timestamp)}
# Each window: {"device_ids": [...], "suppress_alerts": True}
_maintenance_cache: dict[str, tuple[list[dict], float]] = {}
_MAINTENANCE_CACHE_TTL = 30  # 30 seconds
async def _get_redis() -> aioredis.Redis:
    """Return the shared module-level Redis client, creating it on first use."""
    global _redis_client
    client = _redis_client
    if client is None:
        client = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
        _redis_client = client
    return client
async def _get_active_maintenance_windows(tenant_id: str) -> list[dict]:
    """Fetch active maintenance windows for a tenant, with 30s cache.

    Only windows with suppress_alerts = true that are currently in progress
    (start_at <= NOW() <= end_at) are returned.

    Returns:
        List of {"device_ids": [...], "suppress_alerts": bool} dicts. An
        empty device_ids list is interpreted by callers as "all devices".
    """
    now = time.time()
    cached = _maintenance_cache.get(tenant_id)
    # Serve from the in-process cache while it is fresher than the TTL.
    if cached and (now - cached[1]) < _MAINTENANCE_CACHE_TTL:
        return cached[0]
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT device_ids, suppress_alerts
                FROM maintenance_windows
                WHERE tenant_id = CAST(:tenant_id AS uuid)
                  AND suppress_alerts = true
                  AND start_at <= NOW()
                  AND end_at >= NOW()
            """),
            {"tenant_id": tenant_id},
        )
        rows = result.fetchall()
    windows = [
        {
            # Guard: treat any non-list device_ids value as "no explicit list".
            "device_ids": row[0] if isinstance(row[0], list) else [],
            "suppress_alerts": row[1],
        }
        for row in rows
    ]
    _maintenance_cache[tenant_id] = (windows, now)
    return windows
async def _is_device_in_maintenance(tenant_id: str, device_id: str) -> bool:
    """Check if a device is under active maintenance with alert suppression.

    A window with an empty device_ids list covers every device in the tenant.

    Returns:
        True if at least one active suppressing window covers this device.
    """
    windows = await _get_active_maintenance_windows(tenant_id)
    return any(
        not window["device_ids"] or device_id in window["device_ids"]
        for window in windows
    )
async def _get_rules_for_tenant(tenant_id: str) -> list[dict]:
    """Fetch enabled alert rules for a tenant, with 60s in-process cache.

    Returns:
        List of rule dicts with stringified UUIDs (device_id/group_id may be
        None for tenant-wide rules) and float thresholds.
    """
    now = time.time()
    cached = _rule_cache.get(tenant_id)
    # Serve from the in-process cache while it is fresher than the TTL.
    if cached and (now - cached[1]) < _CACHE_TTL_SECONDS:
        return cached[0]
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, tenant_id, device_id, group_id, name, metric,
                       operator, threshold, duration_polls, severity
                FROM alert_rules
                WHERE tenant_id = CAST(:tenant_id AS uuid) AND enabled = TRUE
            """),
            {"tenant_id": tenant_id},
        )
        rows = result.fetchall()
    rules = [
        {
            "id": str(row[0]),
            "tenant_id": str(row[1]),
            "device_id": str(row[2]) if row[2] else None,
            "group_id": str(row[3]) if row[3] else None,
            "name": row[4],
            "metric": row[5],
            "operator": row[6],
            # threshold may come back as Decimal; normalize to float.
            "threshold": float(row[7]),
            "duration_polls": row[8],
            "severity": row[9],
        }
        for row in rows
    ]
    _rule_cache[tenant_id] = (rules, now)
    return rules
def _check_threshold(value: float, operator: str, threshold: float) -> bool:
    """Return True when *value* breaches *threshold* under *operator*.

    Supported operators: "gt", "lt", "gte", "lte"; any other operator
    string yields False.
    """
    comparisons = {
        "gt": value > threshold,
        "lt": value < threshold,
        "gte": value >= threshold,
        "lte": value <= threshold,
    }
    return comparisons.get(operator, False)
def _extract_metrics(metric_type: str, data: dict) -> dict[str, float]:
    """Extract metric name -> value pairs from a NATS metrics event.

    For ``health`` events: cpu_load, temperature, and the derived
    memory_used_pct / disk_used_pct percentages.
    For ``wireless`` events, aggregated across interfaces: worst (minimum)
    signal_strength and ccq, summed client_count.

    Fix: the old wireless loop tested ``key != "avg_signal"`` — always true,
    since key iterates over ("signal_strength", "ccq", "client_count") — and
    then re-read the value in a second branch. Collapsed into a single
    source lookup per key; aggregation behavior is unchanged.

    Args:
        metric_type: Event category ("health" or "wireless").
        data: Raw event payload.

    Returns:
        Dict of metric name to float value; missing or unparseable values
        are silently skipped.
    """
    metrics: dict[str, float] = {}

    if metric_type == "health":
        health = data.get("health", {})
        for key in ("cpu_load", "temperature"):
            val = health.get(key)
            if val is not None and val != "":
                try:
                    metrics[key] = float(val)
                except (ValueError, TypeError):
                    pass

        def _used_pct(free_raw: Any, total_raw: Any) -> float | None:
            # Derive "percent used" from free/total counters; None if unusable.
            if free_raw is None or total_raw is None:
                return None
            try:
                total = float(total_raw)
                free = float(free_raw)
            except (ValueError, TypeError):
                return None
            if total <= 0:
                return None
            return round((1.0 - free / total) * 100, 1)

        mem_pct = _used_pct(health.get("free_memory"), health.get("total_memory"))
        if mem_pct is not None:
            metrics["memory_used_pct"] = mem_pct
        disk_pct = _used_pct(health.get("free_disk"), health.get("total_disk"))
        if disk_pct is not None:
            metrics["disk_used_pct"] = disk_pct

    elif metric_type == "wireless":
        for wif in data.get("wireless", []):
            for key in ("signal_strength", "ccq", "client_count"):
                # Pollers report signal strength under the "avg_signal" field.
                raw = wif.get("avg_signal") if key == "signal_strength" else wif.get(key)
                if raw is None or raw == "":
                    continue
                try:
                    fval = float(raw)
                except (ValueError, TypeError):
                    continue
                if key not in metrics:
                    metrics[key] = fval
                elif key == "client_count":
                    metrics[key] += fval  # total clients across interfaces
                else:
                    metrics[key] = min(metrics[key], fval)  # worst signal / worst CCQ

    # TODO: Interface bandwidth alerting (rx_bps/tx_bps) requires stateful delta
    # computation between consecutive poll values. Deferred for now — the alert_rules
    # table supports these metric types, but evaluation is skipped.
    return metrics
async def _increment_breach(
    r: aioredis.Redis, device_id: str, rule_id: str, required_polls: int
) -> bool:
    """Bump the consecutive-breach counter; True once the duration is met."""
    counter_key = f"alert:breach:{device_id}:{rule_id}"
    breaches = await r.incr(counter_key)
    # Counter self-expires ((required_polls + 2) minutes) once breaches stop.
    await r.expire(counter_key, (required_polls + 2) * 60)
    return breaches >= required_polls


async def _reset_breach(r: aioredis.Redis, device_id: str, rule_id: str) -> None:
    """Clear the consecutive-breach counter when the metric returns to normal."""
    await r.delete(f"alert:breach:{device_id}:{rule_id}")
async def _check_flapping(r: aioredis.Redis, device_id: str, rule_id: str) -> bool:
    """Record one alert state transition and report whether it is flapping.

    Flapping means >= 5 transitions within the last 10 minutes, tracked in a
    per-device+rule Redis sorted set with timestamps as scores.
    """
    flap_key = f"alert:flap:{device_id}:{rule_id}"
    moment = time.time()
    # Record this transition, then prune entries older than the window.
    await r.zadd(flap_key, {str(moment): moment})
    await r.zremrangebyscore(flap_key, "-inf", moment - 600)
    # Keep the key from living forever if transitions stop entirely.
    await r.expire(flap_key, 1200)
    transitions = await r.zcard(flap_key)
    return transitions >= 5
async def _get_device_groups(device_id: str) -> list[str]:
    """Return the group IDs (as strings) the device is a member of."""
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("SELECT group_id FROM device_group_memberships WHERE device_id = CAST(:device_id AS uuid)"),
            {"device_id": device_id},
        )
        return [str(row[0]) for row in result.fetchall()]


async def _has_open_alert(device_id: str, rule_id: str | None, metric: str | None = None) -> bool:
    """Check whether an open (firing, unresolved) alert already exists.

    Rule-based alerts are matched on device+rule. Rule-less alerts are
    matched on device+metric with rule_id IS NULL (metric defaults to
    "offline" when not given).

    Args:
        device_id: Device UUID string.
        rule_id: Alert rule UUID string, or None for rule-less alerts.
        metric: Metric name used only when rule_id is None.

    Returns:
        True if a matching firing, unresolved alert event exists.
    """
    async with AdminAsyncSessionLocal() as session:
        if rule_id:
            result = await session.execute(
                text("""
                    SELECT 1 FROM alert_events
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id = CAST(:rule_id AS uuid)
                      AND status = 'firing' AND resolved_at IS NULL
                    LIMIT 1
                """),
                {"device_id": device_id, "rule_id": rule_id},
            )
        else:
            result = await session.execute(
                text("""
                    SELECT 1 FROM alert_events
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id IS NULL
                      AND metric = :metric AND status = 'firing' AND resolved_at IS NULL
                    LIMIT 1
                """),
                {"device_id": device_id, "metric": metric or "offline"},
            )
        return result.fetchone() is not None
async def _create_alert_event(
    device_id: str,
    tenant_id: str,
    rule_id: str | None,
    status: str,
    severity: str,
    metric: str | None,
    value: float | None,
    threshold: float | None,
    message: str | None,
    is_flapping: bool = False,
) -> dict:
    """Insert an alert_events row and publish the matching NATS event.

    Args:
        device_id: Device UUID string.
        tenant_id: Tenant UUID string.
        rule_id: Alert rule UUID string, or None for rule-less alerts.
        status: Event status; 'resolved' rows get resolved_at = NOW()
            directly in the INSERT.
        severity: Severity label stored on the event.
        metric: Metric name, or None.
        value: Observed metric value, or None.
        threshold: Threshold that was breached, or None.
        message: Human-readable alert text, or None.
        is_flapping: Marks the event as part of a flapping series.

    Returns:
        Dict of the inserted event's fields (id stringified, None if the
        INSERT returned no row).
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                INSERT INTO alert_events
                    (id, device_id, tenant_id, rule_id, status, severity, metric,
                     value, threshold, message, is_flapping, fired_at,
                     resolved_at)
                VALUES
                    (gen_random_uuid(), CAST(:device_id AS uuid), CAST(:tenant_id AS uuid),
                     :rule_id, :status, :severity, :metric,
                     :value, :threshold, :message, :is_flapping, NOW(),
                     CASE WHEN :status = 'resolved' THEN NOW() ELSE NULL END)
                RETURNING id, fired_at
            """),
            {
                "device_id": device_id,
                "tenant_id": tenant_id,
                "rule_id": rule_id,
                "status": status,
                "severity": severity,
                "metric": metric,
                "value": value,
                "threshold": threshold,
                "message": message,
                "is_flapping": is_flapping,
            },
        )
        row = result.fetchone()
        await session.commit()
    alert_data = {
        "id": str(row[0]) if row else None,
        "device_id": device_id,
        "tenant_id": tenant_id,
        "rule_id": rule_id,
        "status": status,
        "severity": severity,
        "metric": metric,
        "value": value,
        "threshold": threshold,
        "message": message,
        "is_flapping": is_flapping,
    }
    # Publish real-time event to NATS for the SSE pipeline (fire-and-forget).
    if status in ("firing", "flapping"):
        await publish_event(f"alert.fired.{tenant_id}", {
            "event_type": "alert_fired",
            "tenant_id": tenant_id,
            "device_id": device_id,
            "alert_event_id": alert_data["id"],
            "severity": severity,
            "metric": metric,
            "current_value": value,
            "threshold": threshold,
            "message": message,
            "is_flapping": is_flapping,
            "fired_at": datetime.now(timezone.utc).isoformat(),
        })
    elif status == "resolved":
        await publish_event(f"alert.resolved.{tenant_id}", {
            "event_type": "alert_resolved",
            "tenant_id": tenant_id,
            "device_id": device_id,
            "alert_event_id": alert_data["id"],
            "severity": severity,
            "metric": metric,
            "message": message,
            "resolved_at": datetime.now(timezone.utc).isoformat(),
        })
    return alert_data
async def _resolve_alert(device_id: str, rule_id: str | None, metric: str | None = None) -> None:
    """Resolve an open alert by setting resolved_at.

    Mirrors _has_open_alert's two lookup modes: rule-based (rule_id given)
    or legacy/system (rule_id IS NULL, matched by metric — default "offline").
    Resolves every matching 'firing' row, not just the newest.
    """
    async with AdminAsyncSessionLocal() as session:
        if rule_id:
            await session.execute(
                text("""
                    UPDATE alert_events SET resolved_at = NOW(), status = 'resolved'
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id = CAST(:rule_id AS uuid)
                    AND status = 'firing' AND resolved_at IS NULL
                """),
                {"device_id": device_id, "rule_id": rule_id},
            )
        else:
            await session.execute(
                text("""
                    UPDATE alert_events SET resolved_at = NOW(), status = 'resolved'
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id IS NULL
                    AND metric = :metric AND status = 'firing' AND resolved_at IS NULL
                """),
                {"device_id": device_id, "metric": metric or "offline"},
            )
        await session.commit()
async def _get_channels_for_tenant(tenant_id: str) -> list[dict]:
    """Return every notification channel configured for a tenant as dicts."""
    fields = (
        "id", "name", "channel_type", "smtp_host", "smtp_port", "smtp_user",
        "smtp_password", "smtp_use_tls", "from_address", "to_address",
        "webhook_url", "smtp_password_transit", "slack_webhook_url", "tenant_id",
    )
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, name, channel_type, smtp_host, smtp_port, smtp_user,
                       smtp_password, smtp_use_tls, from_address, to_address,
                       webhook_url, smtp_password_transit, slack_webhook_url, tenant_id
                FROM notification_channels
                WHERE tenant_id = CAST(:tenant_id AS uuid)
            """),
            {"tenant_id": tenant_id},
        )
        channels = []
        for row in result.fetchall():
            channel = dict(zip(fields, row))
            # UUID columns are serialized to strings for downstream consumers.
            channel["id"] = str(channel["id"])
            channel["tenant_id"] = str(channel["tenant_id"]) if channel["tenant_id"] else None
            channels.append(channel)
        return channels
async def _get_channels_for_rule(rule_id: str) -> list[dict]:
    """Return the notification channels explicitly linked to an alert rule."""
    fields = (
        "id", "name", "channel_type", "smtp_host", "smtp_port", "smtp_user",
        "smtp_password", "smtp_use_tls", "from_address", "to_address",
        "webhook_url", "smtp_password_transit", "slack_webhook_url", "tenant_id",
    )
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT nc.id, nc.name, nc.channel_type, nc.smtp_host, nc.smtp_port,
                       nc.smtp_user, nc.smtp_password, nc.smtp_use_tls,
                       nc.from_address, nc.to_address, nc.webhook_url,
                       nc.smtp_password_transit, nc.slack_webhook_url, nc.tenant_id
                FROM notification_channels nc
                JOIN alert_rule_channels arc ON arc.channel_id = nc.id
                WHERE arc.rule_id = CAST(:rule_id AS uuid)
            """),
            {"rule_id": rule_id},
        )
        channels = []
        for row in result.fetchall():
            channel = dict(zip(fields, row))
            # UUID columns are serialized to strings for downstream consumers.
            channel["id"] = str(channel["id"])
            channel["tenant_id"] = str(channel["tenant_id"]) if channel["tenant_id"] else None
            channels.append(channel)
        return channels
async def _dispatch_async(alert_event: dict, channels: list[dict], device_hostname: str) -> None:
    """Fire-and-forget notification dispatch; failures are logged, never raised."""
    try:
        # Imported lazily (inside the try) so an import failure is also
        # swallowed rather than crashing the evaluation task.
        from app.services.notification_service import dispatch_notifications
        await dispatch_notifications(alert_event, channels, device_hostname)
    except Exception as exc:
        logger.warning("Notification dispatch failed: %s", exc)
async def _get_device_hostname(device_id: str) -> str:
    """Resolve a device's hostname; falls back to the raw id when unknown."""
    async with AdminAsyncSessionLocal() as session:
        row = (
            await session.execute(
                text("SELECT hostname FROM devices WHERE id = CAST(:device_id AS uuid)"),
                {"device_id": device_id},
            )
        ).fetchone()
    return device_id if row is None else row[0]
async def evaluate(
    device_id: str,
    tenant_id: str,
    metric_type: str,
    data: dict[str, Any],
) -> None:
    """Evaluate alert rules for incoming device metrics.

    Called from metrics_subscriber after the metric DB write. For each rule
    that applies to this device and covers one of the extracted metrics:
      - breaching: count consecutive breaches in Redis; once the rule's
        duration_polls threshold is reached, open a 'firing' (or 'flapping')
        alert and dispatch notifications (suppressed while flapping).
      - not breaching: reset the breach counter and resolve any open alert.

    Rule precedence: a device-specific rule on a metric shadows tenant-wide
    rules for the same metric; group rules apply when the device is a member.
    """
    # Maintenance windows suppress all evaluation for the device.
    if await _is_device_in_maintenance(tenant_id, device_id):
        logger.debug(
            "Alert suppressed by maintenance window for device %s tenant %s",
            device_id, tenant_id,
        )
        return
    rules = await _get_rules_for_tenant(tenant_id)
    if not rules:
        return
    metrics = _extract_metrics(metric_type, data)
    if not metrics:
        return
    r = await _get_redis()
    device_groups = await _get_device_groups(device_id)
    # Build a set of metrics that have device-specific rules — used below so
    # device-level rules shadow tenant-wide rules on the same metric.
    device_specific_metrics: set[str] = set()
    for rule in rules:
        if rule["device_id"] == device_id:
            device_specific_metrics.add(rule["metric"])
    for rule in rules:
        rule_metric = rule["metric"]
        if rule_metric not in metrics:
            continue
        # Check if rule applies to this device
        applies = False
        if rule["device_id"] == device_id:
            applies = True
        elif rule["device_id"] is None and rule["group_id"] is None:
            # Tenant-wide rule — skip if device-specific rule exists for same metric
            if rule_metric in device_specific_metrics:
                continue
            applies = True
        elif rule["group_id"] and rule["group_id"] in device_groups:
            applies = True
        if not applies:
            continue
        value = metrics[rule_metric]
        breaching = _check_threshold(value, rule["operator"], rule["threshold"])
        if breaching:
            reached = await _increment_breach(r, device_id, rule["id"], rule["duration_polls"])
            if reached:
                # Already firing? Don't create duplicate open alerts.
                if await _has_open_alert(device_id, rule["id"]):
                    continue
                # Check flapping (this also records a state transition in Redis).
                is_flapping = await _check_flapping(r, device_id, rule["id"])
                hostname = await _get_device_hostname(device_id)
                message = f"{rule['name']}: {rule_metric} = {value} (threshold: {rule['operator']} {rule['threshold']})"
                alert_event = await _create_alert_event(
                    device_id=device_id,
                    tenant_id=tenant_id,
                    rule_id=rule["id"],
                    status="flapping" if is_flapping else "firing",
                    severity=rule["severity"],
                    metric=rule_metric,
                    value=value,
                    threshold=rule["threshold"],
                    message=message,
                    is_flapping=is_flapping,
                )
                if is_flapping:
                    # Flapping alerts are recorded in the DB but never notified.
                    logger.info(
                        "Alert %s for device %s is flapping — notifications suppressed",
                        rule["name"], device_id,
                    )
                else:
                    channels = await _get_channels_for_rule(rule["id"])
                    if channels:
                        # Fire-and-forget so SMTP/webhook latency can't block evaluation.
                        asyncio.create_task(_dispatch_async(alert_event, channels, hostname))
        else:
            # Not breaching — reset counter and check for open alert to resolve
            await _reset_breach(r, device_id, rule["id"])
            if await _has_open_alert(device_id, rule["id"]):
                # Check flapping before resolving (records this transition too).
                is_flapping = await _check_flapping(r, device_id, rule["id"])
                await _resolve_alert(device_id, rule["id"])
                hostname = await _get_device_hostname(device_id)
                message = f"Resolved: {rule['name']}: {rule_metric} = {value}"
                resolved_event = await _create_alert_event(
                    device_id=device_id,
                    tenant_id=tenant_id,
                    rule_id=rule["id"],
                    status="resolved",
                    severity=rule["severity"],
                    metric=rule_metric,
                    value=value,
                    threshold=rule["threshold"],
                    message=message,
                    is_flapping=is_flapping,
                )
                if not is_flapping:
                    channels = await _get_channels_for_rule(rule["id"])
                    if channels:
                        asyncio.create_task(_dispatch_async(resolved_event, channels, hostname))
async def _get_offline_rule(tenant_id: str) -> dict | None:
    """Fetch the tenant's default 'device_offline' alert rule, or None."""
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, enabled FROM alert_rules
                WHERE tenant_id = CAST(:tenant_id AS uuid)
                AND metric = 'device_offline' AND is_default = TRUE
                LIMIT 1
            """),
            {"tenant_id": tenant_id},
        )
        row = result.fetchone()
    if row is None:
        return None
    return {"id": str(row[0]), "enabled": row[1]}
async def evaluate_offline(device_id: str, tenant_id: str) -> None:
    """Create a critical alert when a device goes offline.

    Uses the tenant's device_offline default rule if it exists and is enabled.
    Falls back to system-level alert (rule_id=NULL) for backward compatibility.

    Suppressed entirely during maintenance windows, when the rule is disabled,
    or when a matching alert is already open (no duplicate offline alerts).
    """
    if await _is_device_in_maintenance(tenant_id, device_id):
        logger.debug(
            "Offline alert suppressed by maintenance window for device %s",
            device_id,
        )
        return
    rule = await _get_offline_rule(tenant_id)
    rule_id = rule["id"] if rule else None
    # If rule exists but is disabled, skip alert creation (user opted out)
    if rule and not rule["enabled"]:
        return
    # De-duplicate: rule-based and legacy (rule_id NULL) alerts are tracked separately.
    if rule_id:
        if await _has_open_alert(device_id, rule_id):
            return
    else:
        if await _has_open_alert(device_id, None, "offline"):
            return
    hostname = await _get_device_hostname(device_id)
    message = f"Device {hostname} is offline"
    alert_event = await _create_alert_event(
        device_id=device_id,
        tenant_id=tenant_id,
        rule_id=rule_id,
        status="firing",
        severity="critical",
        metric="offline",
        value=None,
        threshold=None,
        message=message,
    )
    # Use rule-linked channels if available, otherwise tenant-wide channels
    if rule_id:
        channels = await _get_channels_for_rule(rule_id)
        if not channels:
            channels = await _get_channels_for_tenant(tenant_id)
    else:
        channels = await _get_channels_for_tenant(tenant_id)
    if channels:
        # Fire-and-forget so connection-state handling isn't blocked by dispatch.
        asyncio.create_task(_dispatch_async(alert_event, channels, hostname))
async def evaluate_online(device_id: str, tenant_id: str) -> None:
    """Resolve the offline alert when a device comes back online.

    Mirrors evaluate_offline: prefers the tenant's device_offline default
    rule, falling back to the legacy system-level alert (rule_id NULL).
    Returns early (no resolved event, no notifications) when there is no
    matching open alert.
    """
    rule = await _get_offline_rule(tenant_id)
    rule_id = rule["id"] if rule else None
    if rule_id:
        if not await _has_open_alert(device_id, rule_id):
            return
        await _resolve_alert(device_id, rule_id)
    else:
        if not await _has_open_alert(device_id, None, "offline"):
            return
        await _resolve_alert(device_id, None, "offline")
    hostname = await _get_device_hostname(device_id)
    message = f"Device {hostname} is back online"
    # Record a 'resolved' event row (also publishes alert.resolved to NATS).
    resolved_event = await _create_alert_event(
        device_id=device_id,
        tenant_id=tenant_id,
        rule_id=rule_id,
        status="resolved",
        severity="critical",
        metric="offline",
        value=None,
        threshold=None,
        message=message,
    )
    # Use rule-linked channels if available, otherwise tenant-wide channels.
    if rule_id:
        channels = await _get_channels_for_rule(rule_id)
        if not channels:
            channels = await _get_channels_for_tenant(tenant_id)
    else:
        channels = await _get_channels_for_tenant(tenant_id)
    if channels:
        asyncio.create_task(_dispatch_async(resolved_event, channels, hostname))

View File

@@ -0,0 +1,190 @@
"""API key generation, validation, and management service.
Keys use the mktp_ prefix for easy identification in logs.
Storage uses SHA-256 hash -- the plaintext key is never persisted.
Validation uses AdminAsyncSessionLocal since it runs before tenant context is set.
"""
import hashlib
import json
import secrets
import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import text
from app.database import AdminAsyncSessionLocal
# Allowed scopes for API keys
ALLOWED_SCOPES: set[str] = {
"devices:read",
"devices:write",
"config:read",
"config:write",
"alerts:read",
"firmware:write",
}
def generate_raw_key() -> str:
    """Generate a raw API key: the 'mktp_' marker plus 32 bytes of URL-safe randomness."""
    return "mktp_" + secrets.token_urlsafe(32)
def hash_key(raw_key: str) -> str:
    """Return the SHA-256 hex digest used to store and compare API keys."""
    digest = hashlib.sha256(raw_key.encode())
    return digest.hexdigest()
async def create_api_key(
    db,
    tenant_id: uuid.UUID,
    user_id: uuid.UUID,
    name: str,
    scopes: list[str],
    expires_at: Optional[datetime] = None,
) -> dict:
    """Create a new API key.

    Args:
        db: async session; the INSERT is committed here.
        tenant_id / user_id: owners of the key.
        name: human-readable label.
        scopes: list of scope strings, stored as jsonb.
            NOTE(review): not validated against ALLOWED_SCOPES in this
            function — presumably enforced at the API layer; confirm.
        expires_at: optional expiry (None = never expires).

    Returns dict with:
        - key: the plaintext key (shown once, never again)
        - id: the key UUID
        - key_prefix: first 9 chars of the key (e.g. "mktp_abc1")
        plus name, scopes, expires_at, created_at.
    """
    raw_key = generate_raw_key()
    # Only the SHA-256 hash is persisted; the raw key lives in memory only.
    key_hash_value = hash_key(raw_key)
    key_prefix = raw_key[:9]  # "mktp_" + first 4 random chars
    result = await db.execute(
        text("""
            INSERT INTO api_keys (tenant_id, user_id, name, key_prefix, key_hash, scopes, expires_at)
            VALUES (:tenant_id, :user_id, :name, :key_prefix, :key_hash, CAST(:scopes AS jsonb), :expires_at)
            RETURNING id, created_at
        """),
        {
            "tenant_id": str(tenant_id),
            "user_id": str(user_id),
            "name": name,
            "key_prefix": key_prefix,
            "key_hash": key_hash_value,
            "scopes": json.dumps(scopes),
            "expires_at": expires_at,
        },
    )
    row = result.fetchone()
    await db.commit()
    return {
        "key": raw_key,
        "id": row.id,
        "key_prefix": key_prefix,
        "name": name,
        "scopes": scopes,
        "expires_at": expires_at,
        "created_at": row.created_at,
    }
async def validate_api_key(raw_key: str) -> Optional[dict]:
    """Validate an API key and return context if valid.

    Uses AdminAsyncSessionLocal since this runs before tenant context is set.

    Args:
        raw_key: plaintext key as presented by the client.

    Returns:
        Dict with tenant_id, user_id, scopes, key_id on success.
        None for unknown, revoked, or expired keys — indistinguishable
        to the caller, which avoids leaking key state.

    Side effects:
        Updates last_used_at (and commits) on successful validation.
    """
    # Lookup is by SHA-256 hash — the plaintext never reaches the DB.
    key_hash_value = hash_key(raw_key)
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, tenant_id, user_id, scopes, expires_at, revoked_at
                FROM api_keys
                WHERE key_hash = :key_hash
            """),
            {"key_hash": key_hash_value},
        )
        row = result.fetchone()
        if not row:
            return None
        # Check revoked
        if row.revoked_at is not None:
            return None
        # Check expired.
        # NOTE(review): compared against an aware UTC datetime — assumes the
        # expires_at column is timestamptz (aware); a naive value would raise
        # TypeError here. Confirm the column type.
        if row.expires_at is not None and row.expires_at <= datetime.now(timezone.utc):
            return None
        # Update last_used_at
        await session.execute(
            text("""
                UPDATE api_keys SET last_used_at = now()
                WHERE id = :key_id
            """),
            {"key_id": str(row.id)},
        )
        await session.commit()
        return {
            "tenant_id": row.tenant_id,
            "user_id": row.user_id,
            "scopes": row.scopes if row.scopes else [],
            "key_id": row.id,
        }
async def list_api_keys(db, tenant_id: uuid.UUID) -> list[dict]:
    """List all API keys for a tenant (active and revoked).

    Plaintext keys are never stored, so only key_prefix is available for
    masked display; datetimes are serialized to ISO-8601 strings (or None).
    """
    def _iso(dt):
        # Serialize an optional datetime to ISO-8601, preserving None.
        return dt.isoformat() if dt else None

    result = await db.execute(
        text("""
            SELECT id, name, key_prefix, scopes, expires_at, last_used_at,
                   created_at, revoked_at, user_id
            FROM api_keys
            WHERE tenant_id = :tenant_id
            ORDER BY created_at DESC
        """),
        {"tenant_id": str(tenant_id)},
    )
    keys = []
    for row in result.fetchall():
        keys.append({
            "id": row.id,
            "name": row.name,
            "key_prefix": row.key_prefix,
            "scopes": row.scopes if row.scopes else [],
            "expires_at": _iso(row.expires_at),
            "last_used_at": _iso(row.last_used_at),
            "created_at": _iso(row.created_at),
            "revoked_at": _iso(row.revoked_at),
            "user_id": str(row.user_id),
        })
    return keys
async def revoke_api_key(db, tenant_id: uuid.UUID, key_id: uuid.UUID) -> bool:
    """Revoke an API key by stamping revoked_at = now().

    The tenant_id filter prevents cross-tenant revocation; the
    revoked_at IS NULL filter makes the call idempotent.

    Returns:
        True when an active key was revoked, False if not found or
        already revoked.
    """
    result = await db.execute(
        text("""
            UPDATE api_keys
            SET revoked_at = now()
            WHERE id = :key_id AND tenant_id = :tenant_id AND revoked_at IS NULL
            RETURNING id
        """),
        {"key_id": str(key_id), "tenant_id": str(tenant_id)},
    )
    revoked_row = result.fetchone()
    await db.commit()
    return revoked_row is not None

View File

@@ -0,0 +1,92 @@
"""Centralized audit logging service.
Provides a fire-and-forget ``log_action`` coroutine that inserts a row into
the ``audit_logs`` table. Uses raw SQL INSERT (not ORM) for minimal overhead.
The function is wrapped in a try/except so that a logging failure **never**
breaks the parent operation.
Phase 30: When details are non-empty, they are encrypted via OpenBao Transit
(per-tenant data key) and stored in encrypted_details. The plaintext details
column is set to '{}' for column compatibility. If Transit encryption fails
(e.g., OpenBao unavailable), details are stored in plaintext as a fallback.
"""
import uuid
from typing import Any, Optional
import structlog
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
logger = structlog.get_logger("audit")
async def log_action(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    user_id: uuid.UUID,
    action: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    device_id: Optional[uuid.UUID] = None,
    details: Optional[dict[str, Any]] = None,
    ip_address: Optional[str] = None,
) -> None:
    """Insert a row into audit_logs. Swallows all exceptions on failure.

    Args:
        db: caller's session — the INSERT joins the caller's transaction.
            NOTE(review): no commit is issued here; persistence relies on
            the caller committing its session. Confirm all call sites do.
        tenant_id / user_id: actor context.
        action: short action identifier.
        resource_type / resource_id / device_id: optional subject references.
        details: optional structured payload; Transit-encrypted when possible,
            else stored as plaintext jsonb (fallback — see module docstring).
        ip_address: optional client address.
    """
    try:
        import json as _json
        details_dict = details or {}
        details_json = _json.dumps(details_dict)
        encrypted_details: Optional[str] = None
        # Attempt Transit encryption for non-empty details
        if details_dict:
            try:
                from app.services.crypto import encrypt_data_transit
                encrypted_details = await encrypt_data_transit(
                    details_json, str(tenant_id)
                )
                # Encryption succeeded — clear plaintext details (the details
                # column still gets '{}' for schema compatibility).
                details_json = _json.dumps({})
            except Exception:
                # Transit unavailable — fall back to plaintext details
                logger.warning(
                    "audit_transit_encryption_failed",
                    action=action,
                    tenant_id=str(tenant_id),
                    exc_info=True,
                )
                # Keep details_json as-is (plaintext fallback)
                encrypted_details = None
        await db.execute(
            text(
                "INSERT INTO audit_logs "
                "(tenant_id, user_id, action, resource_type, resource_id, "
                "device_id, details, encrypted_details, ip_address) "
                "VALUES (:tenant_id, :user_id, :action, :resource_type, "
                ":resource_id, :device_id, CAST(:details AS jsonb), "
                ":encrypted_details, :ip_address)"
            ),
            {
                "tenant_id": str(tenant_id),
                "user_id": str(user_id),
                "action": action,
                "resource_type": resource_type,
                "resource_id": resource_id,
                "device_id": str(device_id) if device_id else None,
                "details": details_json,
                "encrypted_details": encrypted_details,
                "ip_address": ip_address,
            },
        )
    except Exception:
        # Audit logging must never break the parent operation.
        logger.warning(
            "audit_log_insert_failed",
            action=action,
            tenant_id=str(tenant_id),
            exc_info=True,
        )

View File

@@ -0,0 +1,154 @@
"""
JWT authentication service.
Handles password hashing, JWT token creation, token verification,
and token revocation via Redis.
"""
import time
import uuid
from datetime import UTC, datetime, timedelta
from typing import Optional
import bcrypt
from fastapi import HTTPException, status
from jose import JWTError, jwt
from redis.asyncio import Redis
from app.config import settings
TOKEN_REVOCATION_PREFIX = "token_revoked:"
def hash_password(password: str) -> str:
    """Hash a plaintext password with bcrypt and a freshly generated salt.

    DEPRECATED: Used only by password reset (temporary bcrypt hash for
    upgrade flow) and bootstrap_first_admin. Remove post-v6.0.
    """
    salt = bcrypt.gensalt()
    hashed = bcrypt.hashpw(password.encode(), salt)
    return hashed.decode()
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored bcrypt hash.

    DEPRECATED: Used only by the one-time SRP upgrade flow (login with
    must_upgrade_auth=True) and anti-enumeration dummy calls. Remove post-v6.0.
    """
    candidate = plain_password.encode()
    stored = hashed_password.encode()
    return bcrypt.checkpw(candidate, stored)
def create_access_token(
    user_id: uuid.UUID,
    tenant_id: Optional[uuid.UUID],
    role: str,
) -> str:
    """Issue a short-lived JWT access token.

    Claims:
        sub: user UUID (subject)
        tenant_id: tenant UUID or None for super_admin
        role: user's role string
        type: "access" (distinguishes from refresh tokens)
        iat/exp: issue time and expiry derived from
            settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
    """
    issued_at = datetime.now(UTC)
    claims = {
        "sub": str(user_id),
        "tenant_id": str(tenant_id) if tenant_id else None,
        "role": role,
        "type": "access",
        "iat": issued_at,
        "exp": issued_at + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES),
    }
    return jwt.encode(claims, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
def create_refresh_token(user_id: uuid.UUID) -> str:
    """Issue a long-lived JWT refresh token.

    Claims:
        sub: user UUID (subject)
        type: "refresh" (distinguishes from access tokens)
        iat/exp: issue time and expiry derived from
            settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
    """
    issued_at = datetime.now(UTC)
    claims = {
        "sub": str(user_id),
        "type": "refresh",
        "iat": issued_at,
        "exp": issued_at + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS),
    }
    return jwt.encode(claims, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
def verify_token(token: str, expected_type: str = "access") -> dict:
    """Decode and validate a JWT token.

    Args:
        token: JWT string to validate.
        expected_type: "access" or "refresh".

    Returns:
        The decoded payload (sub, tenant_id, role, type, exp, iat).

    Raises:
        HTTPException 401: invalid signature, expired, wrong type, or
        missing subject — all collapsed into one generic message so the
        client learns nothing about which check failed.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(
            token,
            settings.JWT_SECRET_KEY,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except JWTError:
        raise unauthorized
    # Reject wrong token kinds (access vs refresh) and payloads without a subject.
    if payload.get("type") != expected_type:
        raise unauthorized
    if not payload.get("sub"):
        raise unauthorized
    return payload
async def revoke_user_tokens(redis: Redis, user_id: str) -> None:
    """Mark all tokens for a user as revoked by storing the current timestamp.

    Any refresh token issued before this timestamp will be rejected by
    is_token_revoked(). The marker's TTL matches the maximum refresh token
    lifetime, after which every gated token has expired anyway.

    Args:
        redis: async Redis client.
        user_id: user UUID as a string.
    """
    key = f"{TOKEN_REVOCATION_PREFIX}{user_id}"
    # Derive the TTL from settings instead of a hard-coded 7 days so it stays
    # in sync with create_refresh_token's expiry: if the configured lifetime
    # were raised, a fixed TTL would let old tokens outlive their revocation
    # marker.
    ttl_seconds = settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 3600
    await redis.set(key, str(time.time()), ex=ttl_seconds)
async def is_token_revoked(redis: Redis, user_id: str, issued_at: float) -> bool:
    """Return True when the token predates the user's revocation timestamp.

    No revocation marker in Redis means nothing has been revoked.
    """
    marker = await redis.get(f"{TOKEN_REVOCATION_PREFIX}{user_id}")
    if marker is None:
        return False
    return issued_at < float(marker)

View File

@@ -0,0 +1,197 @@
"""Dynamic backup scheduler — reads cron schedules from DB, manages APScheduler jobs."""
import logging
from typing import Optional
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from app.database import AdminAsyncSessionLocal
from app.models.config_backup import ConfigBackupSchedule
from app.models.device import Device
from app.services import backup_service
from sqlalchemy import select
logger = logging.getLogger(__name__)
_scheduler: Optional[AsyncIOScheduler] = None
# System default: 2am UTC daily
DEFAULT_CRON = "0 2 * * *"
def _cron_to_trigger(cron_expr: str) -> Optional[CronTrigger]:
    """Parse a 5-field cron expression into a UTC CronTrigger.

    Returns None when the expression has the wrong field count or any
    field is rejected by APScheduler.
    """
    try:
        fields = cron_expr.strip().split()
        if len(fields) != 5:
            return None
        return CronTrigger(
            minute=fields[0],
            hour=fields[1],
            day=fields[2],
            month=fields[3],
            day_of_week=fields[4],
            timezone="UTC",
        )
    except Exception as e:
        logger.warning("Invalid cron expression '%s': %s", cron_expr, e)
        return None
def build_schedule_map(schedules: list) -> dict[str, list[dict]]:
    """Group enabled device schedules by their cron expression.

    Disabled schedules are dropped; a missing/empty cron_expression falls
    back to DEFAULT_CRON.

    Returns: {cron_expression: [{device_id, tenant_id}, ...]}
    """
    grouped: dict[str, list[dict]] = {}
    for sched in schedules:
        if not sched.enabled:
            continue
        cron = sched.cron_expression or DEFAULT_CRON
        grouped.setdefault(cron, []).append({
            "device_id": str(sched.device_id),
            "tenant_id": str(sched.tenant_id),
        })
    return grouped
async def _run_scheduled_backups(devices: list[dict]) -> None:
    """Run backups for a list of devices. Each failure is isolated.

    Args:
        devices: [{"device_id": str, "tenant_id": str}, ...] as produced
            by build_schedule_map().

    Each device gets its own admin session so one failed backup cannot
    poison the next; errors are logged and counted, never raised.
    """
    success_count = 0
    failure_count = 0
    for dev_info in devices:
        try:
            # Fresh session per device: isolates failures and commits.
            async with AdminAsyncSessionLocal() as session:
                await backup_service.run_backup(
                    device_id=dev_info["device_id"],
                    tenant_id=dev_info["tenant_id"],
                    trigger_type="scheduled",
                    db_session=session,
                )
                await session.commit()
                logger.info("Scheduled backup OK: device %s", dev_info["device_id"])
                success_count += 1
        except Exception as e:
            logger.error(
                "Scheduled backup FAILED: device %s: %s",
                dev_info["device_id"], e,
            )
            failure_count += 1
    logger.info(
        "Backup batch complete — %d succeeded, %d failed",
        success_count, failure_count,
    )
async def _load_effective_schedules() -> list:
    """Load all effective schedules from DB.

    Resolution order per device:
      1. device-specific ConfigBackupSchedule row,
      2. tenant default row (device_id is NULL),
      3. system default (DEFAULT_CRON, enabled=True).

    Returns a flat list of SimpleNamespace(device_id, tenant_id,
    cron_expression, enabled) objects, shaped for build_schedule_map().
    """
    from types import SimpleNamespace
    async with AdminAsyncSessionLocal() as session:
        # Get all devices
        dev_result = await session.execute(select(Device))
        devices = dev_result.scalars().all()
        # Get all schedules
        sched_result = await session.execute(select(ConfigBackupSchedule))
        schedules = sched_result.scalars().all()
        # Index: device-specific and tenant defaults
        device_schedules = {}  # device_id -> schedule
        tenant_defaults = {}  # tenant_id -> schedule
        for s in schedules:
            if s.device_id:
                device_schedules[str(s.device_id)] = s
            else:
                # A row with no device_id acts as the tenant-wide default.
                tenant_defaults[str(s.tenant_id)] = s
        effective = []
        for dev in devices:
            dev_id = str(dev.id)
            tenant_id = str(dev.tenant_id)
            if dev_id in device_schedules:
                sched = device_schedules[dev_id]
            elif tenant_id in tenant_defaults:
                sched = tenant_defaults[tenant_id]
            else:
                # No schedule configured — use system default
                sched = None
            effective.append(SimpleNamespace(
                device_id=dev_id,
                tenant_id=tenant_id,
                cron_expression=sched.cron_expression if sched else DEFAULT_CRON,
                enabled=sched.enabled if sched else True,
            ))
        return effective
async def sync_schedules() -> None:
    """Reload all schedules from DB and reconfigure APScheduler jobs.

    Idempotent: removes every existing backup_cron_* job, then re-adds one
    job per distinct cron expression, batching all devices that share it.
    Jobs with other id prefixes (e.g. firmware checks) are left alone.
    No-op when the scheduler has not been started.
    """
    global _scheduler
    if not _scheduler:
        return
    # Remove all existing backup jobs (keep other jobs like firmware check)
    for job in _scheduler.get_jobs():
        if job.id.startswith("backup_cron_"):
            job.remove()
    schedules = await _load_effective_schedules()
    schedule_map = build_schedule_map(schedules)
    for cron_expr, devices in schedule_map.items():
        trigger = _cron_to_trigger(cron_expr)
        if not trigger:
            # Invalid expression: fall back to the system default trigger.
            # The job id still embeds the original (invalid) string, so
            # distinct bad expressions don't collide into one job.
            logger.warning("Skipping invalid cron '%s', using default", cron_expr)
            trigger = _cron_to_trigger(DEFAULT_CRON)
        job_id = f"backup_cron_{cron_expr.replace(' ', '_')}"
        _scheduler.add_job(
            _run_scheduled_backups,
            trigger=trigger,
            args=[devices],
            id=job_id,
            name=f"Backup: {cron_expr} ({len(devices)} devices)",
            # max_instances=1: skip a run if the previous batch is still going.
            max_instances=1,
            replace_existing=True,
        )
        logger.info("Scheduled %d devices with cron '%s'", len(devices), cron_expr)
async def on_schedule_change(tenant_id: str, device_id: str) -> None:
    """Called when a schedule is created/updated via API. Hot-reloads all schedules.

    The tenant/device args are used for logging only — the resync always
    reloads every schedule from the DB.
    """
    logger.info("Schedule changed for tenant=%s device=%s, resyncing", tenant_id, device_id)
    await sync_schedules()
async def start_backup_scheduler() -> None:
    """Start the APScheduler and load initial schedules from DB.

    Replaces any previous module-level scheduler reference; call
    stop_backup_scheduler() first if one is already running.
    """
    global _scheduler
    _scheduler = AsyncIOScheduler(timezone="UTC")
    _scheduler.start()
    await sync_schedules()
    logger.info("Backup scheduler started with dynamic schedules")
async def stop_backup_scheduler() -> None:
    """Gracefully shut down the scheduler and drop the module-level handle."""
    global _scheduler
    if _scheduler is None:
        return
    # wait=False: don't block shutdown on in-flight backup jobs.
    _scheduler.shutdown(wait=False)
    _scheduler = None
    logger.info("Backup scheduler stopped")

View File

@@ -0,0 +1,378 @@
"""SSH-based config capture service for RouterOS devices.
This service handles:
1. capture_export() — SSH to device, run /export compact, return stdout text
2. capture_binary_backup() — SSH to device, trigger /system backup save, SFTP-download result
3. run_backup() — Orchestrate a full backup: capture + git commit + DB record
All functions are async (asyncssh is asyncio-native).
Security policy:
known_hosts=None is intentional — RouterOS devices use self-signed SSH host keys
that change on reset or key regeneration. This mirrors InsecureSkipVerify=true
used in the poller's TLS connection. The threat model accepts device impersonation
risk in exchange for operational simplicity (no pre-enrollment of host keys needed).
See Pitfall 2 in 04-RESEARCH.md.
pygit2 calls are synchronous C bindings and MUST be wrapped in run_in_executor.
See Pitfall 3 in 04-RESEARCH.md.
Phase 30: ALL backups (manual, scheduled, pre-restore) are encrypted via OpenBao
Transit (Tier 2) before git commit. The server retains decrypt capability for
on-demand viewing. Raw files in git are ciphertext; the API decrypts on GET.
"""
import asyncio
import base64
import io
import json
import logging
from datetime import datetime, timezone
import asyncssh
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import AdminAsyncSessionLocal, set_tenant_context
from app.models.config_backup import ConfigBackupRun
from app.models.device import Device
from app.services import git_store
from app.services.crypto import decrypt_credentials_hybrid
logger = logging.getLogger(__name__)
# Fixed backup file name on device flash — overwrites on each run so files
# don't accumulate. See Pitfall 4 in 04-RESEARCH.md.
_BACKUP_NAME = "portal-backup"
async def capture_export(
    ip: str,
    port: int = 22,
    username: str = "",
    password: str = "",
) -> str:
    """SSH to a RouterOS device and capture /export compact output.

    Args:
        ip: Device IP address.
        port: SSH port (default 22; RouterOS default is 22).
        username: SSH login username.
        password: SSH login password.

    Returns:
        The raw RSC text from /export compact (may include RouterOS header line).

    Raises:
        asyncssh.Error: On SSH connection or command execution failure.
    """
    async with asyncssh.connect(
        ip,
        port=port,
        username=username,
        password=password,
        known_hosts=None,  # RouterOS self-signed host keys — see module docstring
        connect_timeout=30,
    ) as conn:
        # check=True: a non-zero exit status raises instead of returning
        # partial/empty output.
        result = await conn.run("/export compact", check=True)
        return result.stdout
async def capture_binary_backup(
    ip: str,
    port: int = 22,
    username: str = "",
    password: str = "",
) -> bytes:
    """SSH to a RouterOS device, create a binary backup, SFTP-download it, then clean up.

    Uses a fixed backup name ({_BACKUP_NAME}.backup) so the file overwrites
    on subsequent runs, preventing flash storage accumulation.

    The cleanup (removing the file from device flash) runs in a try/finally
    block so cleanup failures don't mask the actual backup error but are
    logged for observability. See Pitfall 4 in 04-RESEARCH.md.

    Args:
        ip: Device IP address.
        port: SSH port (default 22).
        username: SSH login username.
        password: SSH login password.

    Returns:
        Raw bytes of the binary backup file.

    Raises:
        asyncssh.Error: On SSH connection, command, or SFTP failure.
    """
    async with asyncssh.connect(
        ip,
        port=port,
        username=username,
        password=password,
        known_hosts=None,
        connect_timeout=30,
    ) as conn:
        # Step 1: Trigger backup creation on device flash.
        # dont-encrypt=yes: per module docstring, encryption is applied
        # server-side (OpenBao Transit) before the git commit.
        await conn.run(
            f"/system backup save name={_BACKUP_NAME} dont-encrypt=yes",
            check=True,
        )
        buf = io.BytesIO()
        try:
            # Step 2: SFTP-download the backup file.
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(f"{_BACKUP_NAME}.backup", "rb") as f:
                    buf.write(await f.read())
        finally:
            # Step 3: Remove backup file from device flash (best-effort cleanup).
            try:
                await conn.run(f"/file remove {_BACKUP_NAME}.backup", check=True)
            except Exception as cleanup_err:
                logger.warning(
                    "Failed to remove backup file from device %s: %s",
                    ip,
                    cleanup_err,
                )
        return buf.getvalue()
async def run_backup(
    device_id: str,
    tenant_id: str,
    trigger_type: str,
    db_session: AsyncSession | None = None,
) -> dict:
    """Orchestrate a full config backup for a device.

    Steps:
      1. Load device from DB (ip_address, encrypted credentials).
      2. Decrypt credentials via crypto.decrypt_credentials_hybrid()
         (Transit preferred, legacy AES fallback).
      3. Capture /export compact and binary backup concurrently via asyncio.gather().
      4. Compute line delta vs the most recent export.rsc in git (first backup
         counts all lines as added).
      5. Commit both files to the tenant's bare git repo (run_in_executor for pygit2).
      6. Insert ConfigBackupRun record with commit SHA, trigger type, line deltas.
      7. Return summary dict.

    Args:
        device_id: Device UUID as string.
        tenant_id: Tenant UUID as string.
        trigger_type: 'scheduled' | 'manual' | 'pre-restore' | 'config-change'
        db_session: Optional AsyncSession with RLS context already set.
            If None, uses AdminAsyncSessionLocal (for scheduler context).
            NOTE: when provided, the device query filters by device_id only
            and relies on RLS for tenant isolation — do NOT pass an admin
            (RLS-bypassing) session here.

    Returns:
        Dict: {"commit_sha": str, "trigger_type": str, "lines_added": int|None, "lines_removed": int|None}

    Raises:
        ValueError: If device not found or missing credentials.
        asyncssh.Error: On SSH/SFTP failure.
    """
    # FIX: asyncio.get_event_loop() is deprecated inside coroutines (3.10+)
    # and may create a stray loop; get_running_loop() is the correct call and
    # is guaranteed to return the loop this coroutine runs on.
    loop = asyncio.get_running_loop()
    ts = datetime.now(timezone.utc).isoformat()
    # -----------------------------------------------------------------------
    # Step 1: Load device from DB
    # -----------------------------------------------------------------------
    if db_session is not None:
        session = db_session
        should_close = False
    else:
        # Scheduler context: use admin session (cross-tenant; RLS bypassed)
        session = AdminAsyncSessionLocal()
        should_close = True
    try:
        from sqlalchemy import select
        if should_close:
            # Admin session doesn't have RLS context — filter by tenant explicitly.
            result = await session.execute(
                select(Device).where(
                    Device.id == device_id,  # type: ignore[arg-type]
                    Device.tenant_id == tenant_id,  # type: ignore[arg-type]
                )
            )
        else:
            # RLS-scoped session: tenant isolation enforced by Postgres policy.
            result = await session.execute(
                select(Device).where(Device.id == device_id)  # type: ignore[arg-type]
            )
        device = result.scalar_one_or_none()
        if device is None:
            raise ValueError(f"Device {device_id!r} not found for tenant {tenant_id!r}")
        if not device.encrypted_credentials_transit and not device.encrypted_credentials:
            raise ValueError(
                f"Device {device_id!r} has no stored credentials — cannot perform backup"
            )
        # -----------------------------------------------------------------------
        # Step 2: Decrypt credentials (dual-read: Transit preferred, legacy fallback)
        # -----------------------------------------------------------------------
        key = settings.get_encryption_key_bytes()
        creds_json = await decrypt_credentials_hybrid(
            device.encrypted_credentials_transit,
            device.encrypted_credentials,
            str(device.tenant_id),
            key,
        )
        creds = json.loads(creds_json)
        ssh_username = creds.get("username", "")
        ssh_password = creds.get("password", "")
        ip = device.ip_address
        hostname = device.hostname or ip
        # -----------------------------------------------------------------------
        # Step 3: Capture export and binary backup concurrently
        # -----------------------------------------------------------------------
        logger.info(
            "Starting %s backup for device %s (%s) tenant %s",
            trigger_type,
            hostname,
            ip,
            tenant_id,
        )
        export_text, binary_backup = await asyncio.gather(
            capture_export(ip, username=ssh_username, password=ssh_password),
            capture_binary_backup(ip, username=ssh_username, password=ssh_password),
        )
        # -----------------------------------------------------------------------
        # Step 4: Compute line delta vs prior version
        # -----------------------------------------------------------------------
        lines_added: int | None = None
        lines_removed: int | None = None
        prior_commits = await loop.run_in_executor(
            None, git_store.list_device_commits, tenant_id, device_id
        )
        if prior_commits:
            try:
                prior_export_bytes = await loop.run_in_executor(
                    None, git_store.read_file, tenant_id, prior_commits[0]["sha"], device_id, "export.rsc"
                )
                prior_text = prior_export_bytes.decode("utf-8", errors="replace")
                lines_added, lines_removed = await loop.run_in_executor(
                    None, git_store.compute_line_delta, prior_text, export_text
                )
            except Exception as delta_err:
                logger.warning(
                    "Failed to compute line delta for device %s: %s",
                    device_id,
                    delta_err,
                )
                # Keep lines_added/lines_removed as None on error — non-fatal
        else:
            # First backup: all lines are "added", none removed
            all_lines = len(export_text.splitlines())
            lines_added = all_lines
            lines_removed = 0
        # -----------------------------------------------------------------------
        # Step 5: Encrypt ALL backups via Transit (Tier 2: OpenBao Transit)
        # -----------------------------------------------------------------------
        encryption_tier: int | None = None
        git_export_content = export_text
        git_binary_content = binary_backup
        try:
            from app.services.crypto import encrypt_data_transit
            encrypted_export = await encrypt_data_transit(
                export_text, tenant_id
            )
            encrypted_binary = await encrypt_data_transit(
                base64.b64encode(binary_backup).decode(), tenant_id
            )
            # Transit ciphertext is text — store directly in git
            git_export_content = encrypted_export
            git_binary_content = encrypted_binary.encode("utf-8")
            encryption_tier = 2
            logger.info(
                "Tier 2 Transit encryption applied for %s backup of device %s",
                trigger_type,
                device_id,
            )
        except Exception as enc_err:
            # Transit unavailable — fall back to plaintext (non-fatal)
            logger.warning(
                "Transit encryption failed for %s backup of device %s, "
                "storing plaintext: %s",
                trigger_type,
                device_id,
                enc_err,
            )
            # Keep encryption_tier = None (plaintext fallback)
        # -----------------------------------------------------------------------
        # Step 6: Commit to git (wrapped in run_in_executor — pygit2 is sync C bindings)
        # -----------------------------------------------------------------------
        commit_message = (
            f"{trigger_type}: {hostname} ({ip}) at {ts}"
        )
        commit_sha = await loop.run_in_executor(
            None,
            git_store.commit_backup,
            tenant_id,
            device_id,
            git_export_content,
            git_binary_content,
            commit_message,
        )
        logger.info(
            "Committed backup for device %s to git SHA %s (tier=%s)",
            device_id,
            commit_sha[:8],
            encryption_tier,
        )
        # -----------------------------------------------------------------------
        # Step 7: Insert ConfigBackupRun record
        # -----------------------------------------------------------------------
        if not should_close:
            # RLS-scoped session from API context — record directly
            backup_run = ConfigBackupRun(
                device_id=device.id,
                tenant_id=device.tenant_id,
                commit_sha=commit_sha,
                trigger_type=trigger_type,
                lines_added=lines_added,
                lines_removed=lines_removed,
                encryption_tier=encryption_tier,
            )
            session.add(backup_run)
            await session.flush()
        else:
            # Admin session — set tenant context before insert so RLS policy is satisfied
            async with AdminAsyncSessionLocal() as admin_session:
                await set_tenant_context(admin_session, str(device.tenant_id))
                backup_run = ConfigBackupRun(
                    device_id=device.id,
                    tenant_id=device.tenant_id,
                    commit_sha=commit_sha,
                    trigger_type=trigger_type,
                    lines_added=lines_added,
                    lines_removed=lines_removed,
                    encryption_tier=encryption_tier,
                )
                admin_session.add(backup_run)
                await admin_session.commit()
        return {
            "commit_sha": commit_sha,
            "trigger_type": trigger_type,
            "lines_added": lines_added,
            "lines_removed": lines_removed,
        }
    finally:
        if should_close:
            await session.close()

View File

@@ -0,0 +1,462 @@
"""Certificate Authority service — CA generation, device cert signing, lifecycle.
This module provides the core PKI functionality for the Internal Certificate
Authority feature. All functions receive an ``AsyncSession`` and an
``encryption_key`` as parameters (no direct Settings access) for testability.
Security notes:
- CA private keys are encrypted with AES-256-GCM before database storage.
- PEM key material is NEVER logged.
- Device keys are decrypted only when needed for NATS transmission.
"""
from __future__ import annotations
import datetime
import ipaddress
import logging
from uuid import UUID
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.certificate import CertificateAuthority, DeviceCertificate
from app.services.crypto import (
decrypt_credentials_hybrid,
encrypt_credentials_transit,
)
logger = logging.getLogger(__name__)
# Valid status transitions for the device certificate lifecycle.
_VALID_TRANSITIONS: dict[str, set[str]] = {
"issued": {"deploying"},
"deploying": {"deployed", "issued"}, # issued = rollback on deploy failure
"deployed": {"expiring", "revoked", "superseded"},
"expiring": {"expired", "revoked", "superseded"},
"expired": {"superseded"},
"revoked": set(),
"superseded": set(),
}
# ---------------------------------------------------------------------------
# CA Generation
# ---------------------------------------------------------------------------
async def generate_ca(
    db: AsyncSession,
    tenant_id: UUID,
    common_name: str,
    validity_years: int,
    encryption_key: bytes,
) -> CertificateAuthority:
    """Generate a self-signed root CA for a tenant.

    The CA private key is serialized to PKCS8 PEM and encrypted via OpenBao
    Transit (per-tenant key) before storage. The legacy AES column is written
    as an empty bytestring purely for schema compatibility.

    Args:
        db: Async database session.
        tenant_id: Tenant UUID — only one CA per tenant.
        common_name: CN for the CA certificate (e.g., "Portal Root CA").
        validity_years: How many years the CA cert is valid.
        encryption_key: Legacy 32-byte AES key. NOTE(review): not referenced
            anywhere in this function — new writes go through Transit; the
            parameter is kept for call-site compatibility.

    Returns:
        The newly created ``CertificateAuthority`` model instance (flushed to
        the session but not committed).

    Raises:
        ValueError: If the tenant already has a CA.
    """
    # Ensure one CA per tenant
    existing = await get_ca_for_tenant(db, tenant_id)
    if existing is not None:
        raise ValueError(
            f"Tenant {tenant_id} already has a CA (id={existing.id}). "
            "Delete the existing CA before creating a new one."
        )
    # Generate RSA 2048 key pair
    ca_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    now = datetime.datetime.now(datetime.timezone.utc)
    expiry = now + datetime.timedelta(days=365 * validity_years)
    # Self-signed root: subject and issuer are the same name.
    subject = issuer = x509.Name([
        x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
        x509.NameAttribute(NameOID.COMMON_NAME, common_name),
    ])
    ca_cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(ca_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(expiry)
        # path_length=0: this CA may sign end-entity certs only — no sub-CAs.
        .add_extension(
            x509.BasicConstraints(ca=True, path_length=0), critical=True
        )
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        .add_extension(
            x509.SubjectKeyIdentifier.from_public_key(ca_key.public_key()),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )
    # Serialize public cert to PEM
    cert_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
    # Serialize private key to PEM, then encrypt with OpenBao Transit.
    # The PEM key material is never logged (see module docstring).
    key_pem = ca_key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.PKCS8,
        serialization.NoEncryption(),
    ).decode("utf-8")
    encrypted_key_transit = await encrypt_credentials_transit(key_pem, str(tenant_id))
    # Compute SHA-256 fingerprint (colon-separated hex)
    fingerprint_bytes = ca_cert.fingerprint(hashes.SHA256())
    fingerprint = ":".join(f"{b:02X}" for b in fingerprint_bytes)
    # Serial number as hex string
    serial_hex = format(ca_cert.serial_number, "X")
    model = CertificateAuthority(
        tenant_id=tenant_id,
        common_name=common_name,
        cert_pem=cert_pem,
        encrypted_private_key=b"",  # Legacy column kept for schema compat
        encrypted_private_key_transit=encrypted_key_transit,
        serial_number=serial_hex,
        fingerprint_sha256=fingerprint,
        not_valid_before=now,
        not_valid_after=expiry,
    )
    db.add(model)
    await db.flush()
    logger.info(
        "Generated CA for tenant %s: cn=%s fingerprint=%s",
        tenant_id,
        common_name,
        fingerprint,
    )
    return model
# ---------------------------------------------------------------------------
# Device Certificate Signing
# ---------------------------------------------------------------------------
async def sign_device_cert(
    db: AsyncSession,
    ca: CertificateAuthority,
    device_id: UUID,
    hostname: str,
    ip_address: str,
    validity_days: int,
    encryption_key: bytes,
) -> DeviceCertificate:
    """Sign a per-device TLS server certificate using the tenant's CA.

    The device private key is freshly generated here and encrypted via
    OpenBao Transit before storage; the legacy AES column is written as an
    empty bytestring for schema compatibility.

    Args:
        db: Async database session.
        ca: The tenant's CertificateAuthority model instance.
        device_id: UUID of the device receiving the cert.
        hostname: Device hostname — used as CN and SAN DNSName.
        ip_address: Device IP — used as SAN IPAddress.
        validity_days: Certificate validity in days.
        encryption_key: Legacy 32-byte AES key, used ONLY as the fallback when
            decrypting the CA private key (dual-read). It is NOT used to
            encrypt the new device key, which goes through Transit.

    Returns:
        The newly created ``DeviceCertificate`` model instance (status='issued',
        flushed to the session but not committed).
    """
    # Decrypt CA private key (dual-read: Transit preferred, legacy fallback)
    ca_key_pem = await decrypt_credentials_hybrid(
        ca.encrypted_private_key_transit,
        ca.encrypted_private_key,
        str(ca.tenant_id),
        encryption_key,
    )
    ca_key = serialization.load_pem_private_key(
        ca_key_pem.encode("utf-8"), password=None
    )
    # Load CA certificate for issuer info and AuthorityKeyIdentifier
    ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8"))
    # Generate device RSA 2048 key
    device_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    now = datetime.datetime.now(datetime.timezone.utc)
    expiry = now + datetime.timedelta(days=validity_days)
    device_cert = (
        x509.CertificateBuilder()
        .subject_name(
            x509.Name([
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
                x509.NameAttribute(NameOID.COMMON_NAME, hostname),
            ])
        )
        # Issuer must match the CA cert's subject exactly for chain validation.
        .issuer_name(ca_cert.subject)
        .public_key(device_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(expiry)
        # End-entity certificate: ca=False, no path length.
        .add_extension(
            x509.BasicConstraints(ca=False, path_length=None), critical=True
        )
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=True,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=False,
                crl_sign=False,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        # TLS server authentication only.
        .add_extension(
            x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]),
            critical=False,
        )
        # SANs: clients may connect by IP or hostname.
        .add_extension(
            x509.SubjectAlternativeName([
                x509.IPAddress(ipaddress.ip_address(ip_address)),
                x509.DNSName(hostname),
            ]),
            critical=False,
        )
        # AKI derived from the CA's SKI links this cert to its issuer.
        .add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                ca_cert.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier
                ).value
            ),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )
    # Serialize device cert and key to PEM
    cert_pem = device_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
    key_pem = device_key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.PKCS8,
        serialization.NoEncryption(),
    ).decode("utf-8")
    # Encrypt device private key via OpenBao Transit
    encrypted_key_transit = await encrypt_credentials_transit(key_pem, str(ca.tenant_id))
    # Compute fingerprint
    fingerprint_bytes = device_cert.fingerprint(hashes.SHA256())
    fingerprint = ":".join(f"{b:02X}" for b in fingerprint_bytes)
    serial_hex = format(device_cert.serial_number, "X")
    model = DeviceCertificate(
        tenant_id=ca.tenant_id,
        device_id=device_id,
        ca_id=ca.id,
        common_name=hostname,
        serial_number=serial_hex,
        fingerprint_sha256=fingerprint,
        cert_pem=cert_pem,
        encrypted_private_key=b"",  # Legacy column kept for schema compat
        encrypted_private_key_transit=encrypted_key_transit,
        not_valid_before=now,
        not_valid_after=expiry,
        status="issued",
    )
    db.add(model)
    await db.flush()
    logger.info(
        "Signed device cert for device %s: cn=%s fingerprint=%s",
        device_id,
        hostname,
        fingerprint,
    )
    return model
# ---------------------------------------------------------------------------
# Queries
# ---------------------------------------------------------------------------
async def get_ca_for_tenant(
    db: AsyncSession,
    tenant_id: UUID,
) -> CertificateAuthority | None:
    """Fetch the single CA belonging to *tenant_id*, or None when absent."""
    stmt = select(CertificateAuthority).where(
        CertificateAuthority.tenant_id == tenant_id
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
async def get_device_certs(
    db: AsyncSession,
    tenant_id: UUID,
    device_id: UUID | None = None,
) -> list[DeviceCertificate]:
    """List a tenant's device certificates, newest first.

    Superseded certificates are always filtered out. Pass *device_id* to
    narrow the listing to a single device.

    Args:
        db: Async database session.
        tenant_id: Tenant UUID.
        device_id: If provided, filter to certs for this device only.

    Returns:
        List of DeviceCertificate models ordered by created_at descending.
    """
    conditions = [
        DeviceCertificate.tenant_id == tenant_id,
        DeviceCertificate.status != "superseded",
    ]
    if device_id is not None:
        conditions.append(DeviceCertificate.device_id == device_id)
    stmt = (
        select(DeviceCertificate)
        .where(*conditions)
        .order_by(DeviceCertificate.created_at.desc())
    )
    result = await db.execute(stmt)
    return list(result.scalars().all())
# ---------------------------------------------------------------------------
# Status Management
# ---------------------------------------------------------------------------
async def update_cert_status(
    db: AsyncSession,
    cert_id: UUID,
    status: str,
    deployed_at: datetime.datetime | None = None,
) -> DeviceCertificate:
    """Move a device certificate through its lifecycle state machine.

    The allowed transitions (enforced via _VALID_TRANSITIONS) are:

        issued -> deploying -> deployed -> expiring -> expired
                                  \\-> revoked
                                  \\-> superseded

    Args:
        db: Async database session.
        cert_id: Certificate UUID.
        status: New status value.
        deployed_at: Timestamp to set when transitioning to 'deployed';
            defaults to "now" if omitted on that transition.

    Returns:
        The updated DeviceCertificate model.

    Raises:
        ValueError: If the certificate is not found or the transition is invalid.
    """
    lookup = await db.execute(
        select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
    )
    cert = lookup.scalar_one_or_none()
    if cert is None:
        raise ValueError(f"Device certificate {cert_id} not found")
    # Validate against the state machine BEFORE mutating anything.
    allowed = _VALID_TRANSITIONS.get(cert.status, set())
    if status not in allowed:
        raise ValueError(
            f"Invalid status transition: {cert.status} -> {status}. "
            f"Allowed transitions from '{cert.status}': {allowed or 'none'}"
        )
    stamp = datetime.datetime.now(datetime.timezone.utc)
    cert.status = status
    cert.updated_at = stamp
    if status == "deployed":
        # Caller-supplied timestamp wins; otherwise record the update time.
        cert.deployed_at = deployed_at if deployed_at is not None else stamp
    await db.flush()
    logger.info(
        "Updated cert %s status to %s",
        cert_id,
        status,
    )
    return cert
# ---------------------------------------------------------------------------
# Cert Data for Deployment
# ---------------------------------------------------------------------------
async def get_cert_for_deploy(
    db: AsyncSession,
    cert_id: UUID,
    encryption_key: bytes,
) -> tuple[str, str, str]:
    """Fetch and decrypt everything needed to push a cert to a device.

    Args:
        db: Async database session.
        cert_id: Device certificate UUID.
        encryption_key: 32-byte legacy AES key used as the fallback when
            decrypting the device private key.

    Returns:
        Tuple of (device cert PEM, decrypted device key PEM, CA cert PEM).

    Raises:
        ValueError: If the certificate or its CA is not found.
    """
    cert_rows = await db.execute(
        select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
    )
    cert = cert_rows.scalar_one_or_none()
    if cert is None:
        raise ValueError(f"Device certificate {cert_id} not found")
    # The CA record supplies the trust-anchor PEM for the device.
    ca_rows = await db.execute(
        select(CertificateAuthority).where(
            CertificateAuthority.id == cert.ca_id
        )
    )
    ca = ca_rows.scalar_one_or_none()
    if ca is None:
        raise ValueError(f"CA {cert.ca_id} not found for certificate {cert_id}")
    # Dual-read decrypt of the device private key (Transit preferred).
    key_pem = await decrypt_credentials_hybrid(
        cert.encrypted_private_key_transit,
        cert.encrypted_private_key,
        str(cert.tenant_id),
        encryption_key,
    )
    return cert.cert_pem, key_pem, ca.cert_pem

View File

@@ -0,0 +1,118 @@
"""NATS subscriber for config change events from the Go poller.
Triggers automatic backups when out-of-band config changes are detected,
with 5-minute deduplication to prevent rapid-fire backups.
"""
import json
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Optional
from sqlalchemy import select
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.models.config_backup import ConfigBackupRun
from app.services import backup_service
logger = logging.getLogger(__name__)
DEDUP_WINDOW_MINUTES = 5
_nc: Optional[Any] = None
async def _last_backup_within_dedup_window(device_id: str) -> bool:
    """Return True when this device already has a backup inside the dedup window."""
    window_start = datetime.now(timezone.utc) - timedelta(minutes=DEDUP_WINDOW_MINUTES)
    stmt = (
        select(ConfigBackupRun)
        .where(
            ConfigBackupRun.device_id == device_id,
            ConfigBackupRun.created_at > window_start,
        )
        .limit(1)
    )
    async with AdminAsyncSessionLocal() as session:
        found = await session.execute(stmt)
        return found.scalar_one_or_none() is not None
async def handle_config_changed(event: dict) -> None:
    """Handle a config change event from the poller and trigger a backup.

    A DEDUP_WINDOW_MINUTES window suppresses rapid-fire backups when the
    poller reports several changes in quick succession.

    Args:
        event: Decoded NATS payload; must contain 'device_id' and 'tenant_id'.
            Missing fields are logged and the event is dropped.
    """
    device_id = event.get("device_id")
    tenant_id = event.get("tenant_id")
    if not device_id or not tenant_id:
        logger.warning("Config change event missing device_id or tenant_id: %s", event)
        return
    # Dedup check
    if await _last_backup_within_dedup_window(device_id):
        logger.info(
            "Config change on device %s — skipping backup (within %dm dedup window)",
            device_id, DEDUP_WINDOW_MINUTES,
        )
        return
    logger.info(
        "Config change detected on device %s (tenant %s): %s -> %s",
        device_id, tenant_id,
        event.get("old_timestamp", "?"),
        event.get("new_timestamp", "?"),
    )
    try:
        # BUGFIX: previously an AdminAsyncSessionLocal session was passed as
        # db_session. run_backup treats a provided session as RLS-scoped and
        # queries by device_id only — with an admin (RLS-bypassing) session
        # that drops the tenant filter entirely. Passing no session makes
        # run_backup use its scheduler path, which filters by device_id AND
        # tenant_id and commits its own ConfigBackupRun record.
        await backup_service.run_backup(
            device_id=device_id,
            tenant_id=tenant_id,
            trigger_type="config-change",
        )
        logger.info("Config-change backup completed for device %s", device_id)
    except Exception as e:
        logger.error("Config-change backup failed for device %s: %s", device_id, e)
async def _on_message(msg) -> None:
    """NATS callback for config.changed.> — decode, handle, ack; nak on failure."""
    try:
        await handle_config_changed(json.loads(msg.data.decode()))
        await msg.ack()
    except Exception as exc:
        logger.error("Error handling config change message: %s", exc)
        await msg.nak()
async def start_config_change_subscriber() -> Optional[Any]:
    """Connect to NATS and subscribe to config.changed.> events.

    Returns the NATS connection on success, or None when the connection or
    subscription fails (startup continues without the subscriber).
    """
    import nats
    global _nc
    try:
        logger.info("NATS config-change: connecting to %s", settings.NATS_URL)
        connection = await nats.connect(settings.NATS_URL)
        _nc = connection
        jetstream = connection.jetstream()
        # Durable pull of config change events; acked manually in _on_message.
        await jetstream.subscribe(
            "config.changed.>",
            cb=_on_message,
            durable="api-config-change-consumer",
            stream="DEVICE_EVENTS",
            manual_ack=True,
        )
        logger.info("Config change subscriber started")
        return connection
    except Exception as exc:
        logger.error("Failed to start config change subscriber: %s", exc)
        return None
async def stop_config_change_subscriber() -> None:
    """Drain and release the NATS connection, if one is active."""
    global _nc
    if not _nc:
        return
    await _nc.drain()
    _nc = None

View File

@@ -0,0 +1,183 @@
"""
Credential encryption/decryption with dual-read (OpenBao Transit + legacy AES-256-GCM).
This module provides two encryption paths:
1. Legacy (sync): AES-256-GCM with static CREDENTIAL_ENCRYPTION_KEY — used for fallback reads.
2. Transit (async): OpenBao Transit per-tenant keys — used for all new writes.
The dual-read pattern:
- New writes always use OpenBao Transit (encrypt_credentials_transit).
- Reads prefer Transit ciphertext, falling back to legacy (decrypt_credentials_hybrid).
- Legacy functions are preserved for backward compatibility during migration.
Security properties:
- AES-256-GCM provides authenticated encryption (confidentiality + integrity)
- A unique 12-byte random nonce is generated per legacy encryption operation
- OpenBao Transit keys are AES-256-GCM96, managed entirely by OpenBao
- Ciphertext format: "vault:v1:..." for Transit, raw bytes for legacy
"""
import os
def encrypt_credentials(plaintext: str, key: bytes) -> bytes:
    """Encrypt *plaintext* with AES-256-GCM under *key*.

    Args:
        plaintext: The credential string to encrypt (e.g., JSON with
            username/password).
        key: 32-byte encryption key.

    Returns:
        bytes: 12-byte random nonce followed by the ciphertext (the 16-byte
        GCM tag is appended by the library).

    Raises:
        ValueError: If key is not exactly 32 bytes.
    """
    if len(key) != 32:
        raise ValueError(f"Key must be exactly 32 bytes, got {len(key)}")
    from cryptography.hazmat.primitives.ciphers.aead import AESGCM
    nonce = os.urandom(12)  # fresh 96-bit nonce for every encryption
    sealed = AESGCM(key).encrypt(nonce, plaintext.encode("utf-8"), None)
    # Layout on disk: nonce || ciphertext+tag
    return nonce + sealed
def decrypt_credentials(ciphertext: bytes, key: bytes) -> str:
    """Decrypt AES-256-GCM output produced by ``encrypt_credentials``.

    Args:
        ciphertext: nonce (12 bytes) + encrypted data + GCM tag, as returned
            by encrypt_credentials.
        key: 32-byte encryption key (must match the key used for encryption).

    Returns:
        str: The original plaintext string.

    Raises:
        ValueError: If key is not exactly 32 bytes.
        cryptography.exceptions.InvalidTag: If authentication fails
            (tampered data or wrong key).
    """
    if len(key) != 32:
        raise ValueError(f"Key must be exactly 32 bytes, got {len(key)}")
    from cryptography.hazmat.primitives.ciphers.aead import AESGCM
    # Split the stored layout: nonce || ciphertext+tag
    nonce, sealed = ciphertext[:12], ciphertext[12:]
    return AESGCM(key).decrypt(nonce, sealed, None).decode("utf-8")
# ---------------------------------------------------------------------------
# OpenBao Transit functions (async, per-tenant keys)
# ---------------------------------------------------------------------------
async def encrypt_credentials_transit(plaintext: str, tenant_id: str) -> str:
    """Encrypt a credential string via OpenBao Transit.

    Args:
        plaintext: The credential string to encrypt.
        tenant_id: Tenant UUID string for key lookup.

    Returns:
        Transit ciphertext string (vault:v1:base64...).
    """
    from app.services.openbao_service import get_openbao_service
    return await get_openbao_service().encrypt(tenant_id, plaintext.encode("utf-8"))
async def decrypt_credentials_transit(ciphertext: str, tenant_id: str) -> str:
    """Decrypt an OpenBao Transit credential ciphertext.

    Args:
        ciphertext: Transit ciphertext (vault:v1:...).
        tenant_id: Tenant UUID string for key lookup.

    Returns:
        Decrypted plaintext string.
    """
    from app.services.openbao_service import get_openbao_service
    raw = await get_openbao_service().decrypt(tenant_id, ciphertext)
    return raw.decode("utf-8")
# ---------------------------------------------------------------------------
# OpenBao Transit data encryption (async, per-tenant _data keys — Phase 30)
# ---------------------------------------------------------------------------
async def encrypt_data_transit(plaintext: str, tenant_id: str) -> str:
    """Encrypt non-credential data via OpenBao Transit (per-tenant data key).

    Used for audit log details, config backups, and reports. Data keys are
    distinct from credential keys (tenant_{uuid}_data vs tenant_{uuid}).

    Args:
        plaintext: The data string to encrypt.
        tenant_id: Tenant UUID string for data key lookup.

    Returns:
        Transit ciphertext string (vault:v1:base64...).
    """
    from app.services.openbao_service import get_openbao_service
    return await get_openbao_service().encrypt_data(tenant_id, plaintext.encode("utf-8"))
async def decrypt_data_transit(ciphertext: str, tenant_id: str) -> str:
    """Decrypt an OpenBao Transit data ciphertext (per-tenant data key).

    Args:
        ciphertext: Transit ciphertext (vault:v1:...).
        tenant_id: Tenant UUID string for data key lookup.

    Returns:
        Decrypted plaintext string.
    """
    from app.services.openbao_service import get_openbao_service
    raw = await get_openbao_service().decrypt_data(tenant_id, ciphertext)
    return raw.decode("utf-8")
async def decrypt_credentials_hybrid(
    transit_ciphertext: str | None,
    legacy_ciphertext: bytes | None,
    tenant_id: str,
    legacy_key: bytes,
) -> str:
    """Dual-read decrypt: Transit ciphertext wins, legacy AES-GCM is fallback.

    Args:
        transit_ciphertext: OpenBao Transit ciphertext (vault:v1:...) or None.
        legacy_ciphertext: Legacy AES-256-GCM bytes (nonce+ciphertext+tag) or None.
        tenant_id: Tenant UUID string for Transit key lookup.
        legacy_key: 32-byte legacy encryption key for fallback.

    Returns:
        Decrypted plaintext string.

    Raises:
        ValueError: If neither ciphertext is available.
    """
    # Transit ciphertexts are identified by their "vault:v<N>:" prefix.
    if transit_ciphertext and transit_ciphertext.startswith("vault:v"):
        return await decrypt_credentials_transit(transit_ciphertext, tenant_id)
    if legacy_ciphertext:
        return decrypt_credentials(legacy_ciphertext, legacy_key)
    raise ValueError("No credentials available (both transit and legacy are empty)")

View File

@@ -0,0 +1,670 @@
"""
Device service — business logic for device CRUD, credential encryption, groups, and tags.
All functions operate via the app_user engine (RLS enforced).
Tenant isolation is handled automatically by PostgreSQL RLS policies
(SET LOCAL app.current_tenant is set by the get_current_user dependency before
this layer is called).
Credential policy:
- Credentials are always stored as AES-256-GCM encrypted JSON blobs.
- Credentials are NEVER returned in any public-facing response.
- Re-encryption happens only when a new password is explicitly provided in an update.
"""
import asyncio
import json
import uuid
from typing import Optional
from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models.device import (
Device,
DeviceGroup,
DeviceGroupMembership,
DeviceTag,
DeviceTagAssignment,
)
from app.schemas.device import (
BulkAddRequest,
BulkAddResult,
DeviceCreate,
DeviceGroupCreate,
DeviceGroupResponse,
DeviceGroupUpdate,
DeviceResponse,
DeviceTagCreate,
DeviceTagResponse,
DeviceTagUpdate,
DeviceUpdate,
)
from app.config import settings
from app.services.crypto import (
decrypt_credentials,
decrypt_credentials_hybrid,
encrypt_credentials,
encrypt_credentials_transit,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _tcp_reachable(ip: str, port: int, timeout: float = 3.0) -> bool:
"""Return True if a TCP connection to ip:port succeeds within timeout."""
try:
_, writer = await asyncio.wait_for(
asyncio.open_connection(ip, port), timeout=timeout
)
writer.close()
try:
await writer.wait_closed()
except Exception:
pass
return True
except Exception:
return False
def _build_device_response(device: Device) -> DeviceResponse:
    """Map an ORM Device (relationships pre-loaded) to a DeviceResponse.

    Tag and group refs come from the eagerly-loaded relationship collections.
    Credentials are deliberately never part of the response.
    """
    from app.schemas.device import DeviceGroupRef, DeviceTagRef
    tag_refs: list = []
    for assignment in device.tag_assignments:
        tag_refs.append(
            DeviceTagRef(
                id=assignment.tag.id,
                name=assignment.tag.name,
                color=assignment.tag.color,
            )
        )
    group_refs: list = []
    for membership in device.group_memberships:
        group_refs.append(
            DeviceGroupRef(
                id=membership.group.id,
                name=membership.group.name,
            )
        )
    return DeviceResponse(
        id=device.id,
        hostname=device.hostname,
        ip_address=device.ip_address,
        api_port=device.api_port,
        api_ssl_port=device.api_ssl_port,
        model=device.model,
        serial_number=device.serial_number,
        firmware_version=device.firmware_version,
        routeros_version=device.routeros_version,
        uptime_seconds=device.uptime_seconds,
        last_seen=device.last_seen,
        latitude=device.latitude,
        longitude=device.longitude,
        status=device.status,
        tls_mode=device.tls_mode,
        tags=tag_refs,
        groups=group_refs,
        created_at=device.created_at,
    )
def _device_with_relations():
    """Return a select(Device) with tag and group relationships eagerly loaded."""
    tag_loader = selectinload(Device.tag_assignments).selectinload(
        DeviceTagAssignment.tag
    )
    group_loader = selectinload(Device.group_memberships).selectinload(
        DeviceGroupMembership.group
    )
    return select(Device).options(tag_loader, group_loader)
# ---------------------------------------------------------------------------
# Device CRUD
# ---------------------------------------------------------------------------
async def create_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    data: DeviceCreate,
    encryption_key: bytes,
) -> DeviceResponse:
    """
    Create a new device.

    - Validates TCP connectivity (api_port or api_ssl_port must be reachable).
    - Encrypts credentials before storage.
    - Status set to "unknown" until the Go poller runs a full auth check (Phase 2).

    Args:
        db: Tenant-scoped async session (RLS restricts row visibility).
        tenant_id: Owning tenant; stored on the new row and used as the
            Transit encryption context.
        data: Validated device payload (hostname, IP, ports, credentials).
        encryption_key: NOTE(review): unused in this body — credentials are
            encrypted via OpenBao Transit below; presumably retained for
            interface compatibility with the legacy Fernet path. TODO confirm.

    Returns:
        DeviceResponse for the new row (tags/groups empty; credentials excluded).

    Raises:
        HTTPException 422: when neither API port answers a TCP probe.
    """
    # Test connectivity before accepting the device: at least one of the
    # plain-API or SSL-API ports must accept a TCP connection.
    api_reachable = await _tcp_reachable(data.ip_address, data.api_port)
    ssl_reachable = await _tcp_reachable(data.ip_address, data.api_ssl_port)
    if not api_reachable and not ssl_reachable:
        from fastapi import HTTPException, status
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(
                f"Cannot reach {data.ip_address} on port {data.api_port} "
                f"(RouterOS API) or {data.api_ssl_port} (RouterOS SSL API). "
                "Verify the IP address and that the RouterOS API is enabled."
            ),
        )
    # Encrypt credentials via OpenBao Transit (new writes go through Transit)
    credentials_json = json.dumps({"username": data.username, "password": data.password})
    transit_ciphertext = await encrypt_credentials_transit(
        credentials_json, str(tenant_id)
    )
    device = Device(
        tenant_id=tenant_id,
        hostname=data.hostname,
        ip_address=data.ip_address,
        api_port=data.api_port,
        api_ssl_port=data.api_ssl_port,
        encrypted_credentials_transit=transit_ciphertext,
        status="unknown",
    )
    db.add(device)
    await db.flush()  # Get the ID without committing
    await db.refresh(device)
    # Re-query with relationships loaded so _build_device_response can read
    # tag_assignments/group_memberships without triggering lazy loads.
    result = await db.execute(
        _device_with_relations().where(Device.id == device.id)
    )
    device = result.scalar_one()
    return _build_device_response(device)
async def get_devices(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    page: int = 1,
    page_size: int = 25,
    status: Optional[str] = None,
    search: Optional[str] = None,
    tag_id: Optional[uuid.UUID] = None,
    group_id: Optional[uuid.UUID] = None,
    sort_by: str = "created_at",
    sort_order: str = "desc",
) -> tuple[list[DeviceResponse], int]:
    """
    Return a paginated list of devices with optional filtering and sorting.
    Returns (items, total_count).
    RLS automatically scopes this to the caller's tenant.

    Args:
        db: Tenant-scoped async session.
        tenant_id: Caller's tenant. NOTE(review): not referenced in the query —
            tenant scoping appears to rely entirely on RLS; confirm.
        page / page_size: 1-based pagination window.
        status: Exact device-status filter when truthy.
        search: Case-insensitive substring match on hostname or IP address.
        tag_id / group_id: Restrict to devices carrying the tag / in the group.
        sort_by: One of the allow-listed columns below; unknown values fall
            back to created_at.
        sort_order: "asc" (case-insensitive) for ascending; anything else
            sorts descending.
    """
    base_q = _device_with_relations()
    # Filtering
    if status:
        base_q = base_q.where(Device.status == status)
    if search:
        pattern = f"%{search}%"
        base_q = base_q.where(
            or_(
                Device.hostname.ilike(pattern),
                Device.ip_address.ilike(pattern),
            )
        )
    if tag_id:
        # Semi-join via IN-subquery keeps one row per device.
        base_q = base_q.where(
            Device.id.in_(
                select(DeviceTagAssignment.device_id).where(
                    DeviceTagAssignment.tag_id == tag_id
                )
            )
        )
    if group_id:
        base_q = base_q.where(
            Device.id.in_(
                select(DeviceGroupMembership.device_id).where(
                    DeviceGroupMembership.group_id == group_id
                )
            )
        )
    # Count total before pagination (filters applied, offset/limit not yet)
    count_q = select(func.count()).select_from(base_q.subquery())
    total_result = await db.execute(count_q)
    total = total_result.scalar_one()
    # Sorting — allow-list prevents arbitrary column injection via sort_by.
    allowed_sort_cols = {
        "created_at": Device.created_at,
        "hostname": Device.hostname,
        "ip_address": Device.ip_address,
        "status": Device.status,
        "last_seen": Device.last_seen,
    }
    sort_col = allowed_sort_cols.get(sort_by, Device.created_at)
    if sort_order.lower() == "asc":
        base_q = base_q.order_by(sort_col.asc())
    else:
        base_q = base_q.order_by(sort_col.desc())
    # Pagination
    offset = (page - 1) * page_size
    base_q = base_q.offset(offset).limit(page_size)
    result = await db.execute(base_q)
    devices = result.scalars().all()
    return [_build_device_response(d) for d in devices], total
async def get_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
) -> DeviceResponse:
    """Fetch one device by primary key; 404 when not visible to the caller."""
    from fastapi import HTTPException, status

    row = await db.execute(_device_with_relations().where(Device.id == device_id))
    found = row.scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    return _build_device_response(found)
async def update_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    data: DeviceUpdate,
    encryption_key: bytes,
) -> DeviceResponse:
    """
    Update device fields. Re-encrypts credentials only if password is provided.

    Credential handling:
    - password provided -> full re-encrypt via Transit (username taken from the
      payload, or recovered from the existing blob).
    - only username provided -> decrypt existing blob, swap username,
      re-encrypt via Transit; silently kept as-is if decryption fails.
    On any credential change the legacy Fernet column is cleared (Transit is
    canonical) and the Go poller is notified over NATS (fire-and-forget).

    Args:
        encryption_key: NOTE(review): unused — decryption uses
            settings.get_encryption_key_bytes() instead; confirm whether the
            parameter can be dropped from the interface.

    Raises:
        HTTPException 404: when the device is not visible to the caller.
    """
    from fastapi import HTTPException, status
    result = await db.execute(
        _device_with_relations().where(Device.id == device_id)
    )
    device = result.scalar_one_or_none()
    if not device:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    # Update scalar fields — only fields present in the PATCH payload.
    # NOTE(review): None cannot be used to clear latitude/longitude with this
    # pattern; confirm that is intended.
    if data.hostname is not None:
        device.hostname = data.hostname
    if data.ip_address is not None:
        device.ip_address = data.ip_address
    if data.api_port is not None:
        device.api_port = data.api_port
    if data.api_ssl_port is not None:
        device.api_ssl_port = data.api_ssl_port
    if data.latitude is not None:
        device.latitude = data.latitude
    if data.longitude is not None:
        device.longitude = data.longitude
    if data.tls_mode is not None:
        device.tls_mode = data.tls_mode
    # Re-encrypt credentials if new ones are provided
    credentials_changed = False
    if data.password is not None:
        # Decrypt existing to get current username if no new username given
        current_username: str = data.username or ""
        if not current_username and (device.encrypted_credentials_transit or device.encrypted_credentials):
            try:
                # Hybrid path: tries Transit ciphertext first, falls back to
                # the legacy Fernet column.
                existing_json = await decrypt_credentials_hybrid(
                    device.encrypted_credentials_transit,
                    device.encrypted_credentials,
                    str(device.tenant_id),
                    settings.get_encryption_key_bytes(),
                )
                existing = json.loads(existing_json)
                current_username = existing.get("username", "")
            except Exception:
                # Unrecoverable blob — proceed with an empty username rather
                # than failing the whole update.
                current_username = ""
        credentials_json = json.dumps({
            "username": data.username if data.username is not None else current_username,
            "password": data.password,
        })
        # New writes go through Transit
        device.encrypted_credentials_transit = await encrypt_credentials_transit(
            credentials_json, str(device.tenant_id)
        )
        device.encrypted_credentials = None  # Clear legacy (Transit is canonical)
        credentials_changed = True
    elif data.username is not None and (device.encrypted_credentials_transit or device.encrypted_credentials):
        # Only username changed — update it without changing the password
        try:
            existing_json = await decrypt_credentials_hybrid(
                device.encrypted_credentials_transit,
                device.encrypted_credentials,
                str(device.tenant_id),
                settings.get_encryption_key_bytes(),
            )
            existing = json.loads(existing_json)
            existing["username"] = data.username
            # Re-encrypt via Transit
            device.encrypted_credentials_transit = await encrypt_credentials_transit(
                json.dumps(existing), str(device.tenant_id)
            )
            device.encrypted_credentials = None
            credentials_changed = True
        except Exception:
            pass  # Keep existing encrypted blob if decryption fails
    await db.flush()
    await db.refresh(device)
    # Notify poller to invalidate cached credentials (fire-and-forget via NATS)
    if credentials_changed:
        try:
            from app.services.event_publisher import publish_event
            await publish_event(
                f"device.credential_changed.{device_id}",
                {"device_id": str(device_id), "tenant_id": str(tenant_id)},
            )
        except Exception:
            pass  # Never fail the update due to NATS issues
    # Re-query with relationships for the response payload.
    result2 = await db.execute(
        _device_with_relations().where(Device.id == device_id)
    )
    device = result2.scalar_one()
    return _build_device_response(device)
async def delete_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
) -> None:
    """Hard-delete a device (v1 — no soft delete for devices); 404 if absent."""
    from fastapi import HTTPException, status

    found = (
        await db.execute(select(Device).where(Device.id == device_id))
    ).scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    await db.delete(found)
    await db.flush()
# ---------------------------------------------------------------------------
# Group / Tag assignment
# ---------------------------------------------------------------------------
async def assign_device_to_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    group_id: uuid.UUID,
) -> None:
    """Idempotently link a device to a group; 404 when either side is missing.

    RLS scopes both lookups to the caller's tenant.
    """
    from fastapi import HTTPException, status

    if await db.get(Device, device_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    if await db.get(DeviceGroup, group_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")
    # Re-linking an already-linked pair is a no-op.
    if await db.get(DeviceGroupMembership, (device_id, group_id)) is None:
        db.add(DeviceGroupMembership(device_id=device_id, group_id=group_id))
        await db.flush()
async def remove_device_from_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    group_id: uuid.UUID,
) -> None:
    """Unlink a device from a group; 404 when the membership does not exist."""
    from fastapi import HTTPException, status

    link = await db.get(DeviceGroupMembership, (device_id, group_id))
    if link is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Device is not in this group",
        )
    await db.delete(link)
    await db.flush()
async def assign_tag_to_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    tag_id: uuid.UUID,
) -> None:
    """Idempotently attach a tag to a device; 404 when either is missing."""
    from fastapi import HTTPException, status

    if await db.get(Device, device_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    if await db.get(DeviceTag, tag_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tag not found")
    # Re-tagging an already-tagged device is a no-op.
    if await db.get(DeviceTagAssignment, (device_id, tag_id)) is None:
        db.add(DeviceTagAssignment(device_id=device_id, tag_id=tag_id))
        await db.flush()
async def remove_tag_from_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    tag_id: uuid.UUID,
) -> None:
    """Detach a tag from a device; 404 when the assignment does not exist."""
    from fastapi import HTTPException, status

    link = await db.get(DeviceTagAssignment, (device_id, tag_id))
    if link is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tag is not assigned to this device",
        )
    await db.delete(link)
    await db.flush()
# ---------------------------------------------------------------------------
# DeviceGroup CRUD
# ---------------------------------------------------------------------------
async def create_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    data: DeviceGroupCreate,
) -> DeviceGroupResponse:
    """Insert a new device group and return its API representation.

    A freshly created group necessarily has zero members.
    """
    row = DeviceGroup(
        tenant_id=tenant_id,
        name=data.name,
        description=data.description,
    )
    db.add(row)
    await db.flush()
    await db.refresh(row)
    return DeviceGroupResponse(
        id=row.id,
        name=row.name,
        description=row.description,
        device_count=0,
        created_at=row.created_at,
    )
async def get_groups(
    db: AsyncSession,
    tenant_id: uuid.UUID,
) -> list[DeviceGroupResponse]:
    """List every device group visible to the current tenant with member counts."""
    rows = await db.execute(
        select(DeviceGroup).options(selectinload(DeviceGroup.memberships))
    )
    responses: list[DeviceGroupResponse] = []
    for grp in rows.scalars().all():
        responses.append(
            DeviceGroupResponse(
                id=grp.id,
                name=grp.name,
                description=grp.description,
                device_count=len(grp.memberships),
                created_at=grp.created_at,
            )
        )
    return responses
async def update_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    group_id: uuid.UUID,
    data: DeviceGroupUpdate,
) -> DeviceGroupResponse:
    """Patch a group's name and/or description; 404 when the group is missing."""
    from fastapi import HTTPException, status

    query = (
        select(DeviceGroup)
        .options(selectinload(DeviceGroup.memberships))
        .where(DeviceGroup.id == group_id)
    )
    group = (await db.execute(query)).scalar_one_or_none()
    if group is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")
    if data.name is not None:
        group.name = data.name
    if data.description is not None:
        group.description = data.description
    await db.flush()
    await db.refresh(group)
    # Re-fetch with memberships loaded for an accurate device_count.
    group = (await db.execute(query)).scalar_one()
    return DeviceGroupResponse(
        id=group.id,
        name=group.name,
        description=group.description,
        device_count=len(group.memberships),
        created_at=group.created_at,
    )
async def delete_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    group_id: uuid.UUID,
) -> None:
    """Delete a device group; 404 when it is not visible to the caller."""
    from fastapi import HTTPException, status

    group = await db.get(DeviceGroup, group_id)
    if group is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")
    await db.delete(group)
    await db.flush()
# ---------------------------------------------------------------------------
# DeviceTag CRUD
# ---------------------------------------------------------------------------
async def create_tag(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    data: DeviceTagCreate,
) -> DeviceTagResponse:
    """Insert a new device tag and return its API representation."""
    row = DeviceTag(tenant_id=tenant_id, name=data.name, color=data.color)
    db.add(row)
    await db.flush()
    await db.refresh(row)
    return DeviceTagResponse(id=row.id, name=row.name, color=row.color)
async def get_tags(
    db: AsyncSession,
    tenant_id: uuid.UUID,
) -> list[DeviceTagResponse]:
    """List every device tag visible to the current tenant (RLS scopes the query)."""
    rows = (await db.execute(select(DeviceTag))).scalars().all()
    responses: list[DeviceTagResponse] = []
    for row in rows:
        responses.append(DeviceTagResponse(id=row.id, name=row.name, color=row.color))
    return responses
async def update_tag(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    tag_id: uuid.UUID,
    data: DeviceTagUpdate,
) -> DeviceTagResponse:
    """Patch a tag's name and/or color; 404 when the tag is not visible."""
    from fastapi import HTTPException, status

    tag = await db.get(DeviceTag, tag_id)
    if tag is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tag not found")
    # Only overwrite fields actually present in the PATCH payload.
    for attr in ("name", "color"):
        value = getattr(data, attr)
        if value is not None:
            setattr(tag, attr, value)
    await db.flush()
    await db.refresh(tag)
    return DeviceTagResponse(id=tag.id, name=tag.name, color=tag.color)
async def delete_tag(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    tag_id: uuid.UUID,
) -> None:
    """Delete a device tag; 404 when it is not visible to the caller."""
    from fastapi import HTTPException, status

    tag = await db.get(DeviceTag, tag_id)
    if tag is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tag not found")
    await db.delete(tag)
    await db.flush()

View File

@@ -0,0 +1,124 @@
"""Unified email sending service.
All email sending (system emails, alert notifications) goes through this module.
Supports TLS, STARTTLS, and plain SMTP. Handles Transit + legacy Fernet password decryption.
"""
import logging
from email.message import EmailMessage
from typing import Optional
import aiosmtplib
logger = logging.getLogger(__name__)
class SMTPConfig:
    """Settings needed to open an SMTP session.

    Attributes mirror the constructor arguments. `use_tls` selects implicit
    TLS; callers derive STARTTLS behavior from `use_tls` and `port`.
    """

    def __init__(
        self,
        host: str,
        port: int = 587,
        user: Optional[str] = None,
        password: Optional[str] = None,
        use_tls: bool = False,
        from_address: str = "noreply@example.com",
    ):
        # Connection endpoint.
        self.host = host
        self.port = port
        # Optional AUTH credentials; falsy values disable login.
        self.user = user
        self.password = password
        # Implicit-TLS flag (wrapped socket from the start).
        self.use_tls = use_tls
        # Envelope / From-header address.
        self.from_address = from_address
async def send_email(
    to: str,
    subject: str,
    html: str,
    plain_text: str,
    smtp_config: SMTPConfig,
) -> None:
    """Send a multipart (plain + HTML) email via SMTP.

    Args:
        to: Recipient email address.
        subject: Email subject line.
        html: HTML body.
        plain_text: Plain text fallback body.
        smtp_config: SMTP connection settings.

    Raises:
        aiosmtplib.SMTPException: On SMTP connection or send failure.
    """
    message = EmailMessage()
    message["Subject"] = subject
    message["From"] = smtp_config.from_address
    message["To"] = to
    message.set_content(plain_text)
    message.add_alternative(html, subtype="html")

    implicit_tls = smtp_config.use_tls
    # Attempt STARTTLS only when not already on implicit TLS, and never on
    # port 25 (plain relay).
    starttls = (not implicit_tls) and smtp_config.port != 25
    await aiosmtplib.send(
        message,
        hostname=smtp_config.host,
        port=smtp_config.port,
        username=smtp_config.user or None,
        password=smtp_config.password or None,
        use_tls=implicit_tls,
        start_tls=starttls,
    )
async def test_smtp_connection(smtp_config: SMTPConfig) -> dict:
    """Verify SMTP connectivity (and login, when credentials are set) without sending.

    Returns:
        dict with "success" bool and "message" string; never raises.
    """
    try:
        client = aiosmtplib.SMTP(
            hostname=smtp_config.host,
            port=smtp_config.port,
            use_tls=smtp_config.use_tls,
            # Same STARTTLS derivation as send_email: skip on implicit TLS
            # and on port 25.
            start_tls=(not smtp_config.use_tls) and smtp_config.port != 25,
        )
        await client.connect()
        if smtp_config.user and smtp_config.password:
            await client.login(smtp_config.user, smtp_config.password)
        await client.quit()
    except Exception as exc:
        return {"success": False, "message": str(exc)}
    return {"success": True, "message": "SMTP connection successful"}
async def send_test_email(to: str, smtp_config: SMTPConfig) -> dict:
    """Send a test email to verify the full SMTP flow.

    Args:
        to: Recipient address for the test message.
        smtp_config: SMTP connection settings to exercise.

    Returns:
        dict with "success" bool and "message" string; errors are captured
        rather than raised so the caller can surface them to the UI.
    """
    # Inline HTML body (dark theme matching the app's palette).
    html = """
    <div style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; max-width: 600px; margin: 0 auto;">
        <div style="background: #0f172a; padding: 24px; border-radius: 8px 8px 0 0;">
            <h2 style="color: #38bdf8; margin: 0;">TOD &mdash; Email Test</h2>
        </div>
        <div style="background: #1e293b; padding: 24px; border-radius: 0 0 8px 8px; color: #e2e8f0;">
            <p>This is a test email from The Other Dude.</p>
            <p>If you're reading this, your SMTP configuration is working correctly.</p>
            <p style="color: #94a3b8; font-size: 13px; margin-top: 24px;">
                Sent from TOD Fleet Management
            </p>
        </div>
    </div>
    """
    plain = "TOD — Email Test\n\nThis is a test email from The Other Dude.\nIf you're reading this, your SMTP configuration is working correctly."
    try:
        await send_email(to, "TOD — Test Email", html, plain, smtp_config)
        return {"success": True, "message": f"Test email sent to {to}"}
    except Exception as e:
        return {"success": False, "message": str(e)}

View File

@@ -0,0 +1,54 @@
"""Emergency Kit PDF template generation.
Generates an Emergency Kit PDF containing the user's email and sign-in URL
but NOT the Secret Key. The Secret Key placeholder is filled client-side
so that the server never sees it.
Uses Jinja2 + WeasyPrint following the same pattern as the reports service.
"""
import asyncio
from datetime import UTC, datetime
from pathlib import Path
from jinja2 import Environment, FileSystemLoader
from app.config import settings
TEMPLATE_DIR = Path(__file__).parent.parent.parent / "templates"
async def generate_emergency_kit_template(
    email: str,
) -> bytes:
    """Generate the Emergency Kit PDF template WITHOUT the Secret Key.

    The Secret Key placeholder is filled client-side so the server never
    sees the key itself.

    Args:
        email: The user's email address to display in the PDF.

    Returns:
        PDF bytes ready for a streaming response.
    """
    jinja_env = Environment(
        loader=FileSystemLoader(str(TEMPLATE_DIR)),
        autoescape=True,
    )
    rendered = jinja_env.get_template("emergency_kit.html").render(
        email=email,
        signin_url=settings.APP_BASE_URL,
        date=datetime.now(UTC).strftime("%Y-%m-%d"),
        secret_key_placeholder="[Download complete -- your Secret Key will be inserted by your browser]",
    )
    # WeasyPrint is CPU-bound; run it in a worker thread so the event loop
    # is not blocked during PDF generation.
    from weasyprint import HTML

    return await asyncio.to_thread(
        lambda: HTML(string=rendered).write_pdf()
    )

View File

@@ -0,0 +1,52 @@
"""Fire-and-forget NATS JetStream event publisher for real-time SSE pipeline.
Provides a shared lazy NATS connection and publish helper used by:
- alert_evaluator.py (alert.fired.{tenant_id}, alert.resolved.{tenant_id})
- restore_service.py (config.push.{tenant_id}.{device_id})
- upgrade_service.py (firmware.progress.{tenant_id}.{device_id})
All publishes are fire-and-forget: errors are logged but never propagate
to the caller. A NATS outage must never block alert evaluation, config
push, or firmware upgrade operations.
"""
import json
import logging
from typing import Any
import nats
import nats.aio.client
from app.config import settings
logger = logging.getLogger(__name__)
# Module-level NATS connection (lazy initialized, reused across publishes)
_nc: nats.aio.client.Client | None = None
async def _get_nats() -> nats.aio.client.Client:
    """Return the shared publisher connection, (re)connecting when needed."""
    global _nc
    # Reuse the live connection; reconnect when absent or closed.
    if _nc is not None and not _nc.is_closed:
        return _nc
    _nc = await nats.connect(settings.NATS_URL)
    logger.info("Event publisher NATS connection established")
    return _nc
async def publish_event(subject: str, payload: dict[str, Any]) -> None:
    """Publish a JSON event to a NATS JetStream subject (fire-and-forget).

    Args:
        subject: NATS subject, e.g. "alert.fired.{tenant_id}".
        payload: Dict that will be JSON-serialized as the message body.

    Never raises — every exception is caught and logged as a warning so a
    NATS outage cannot block the calling operation.
    """
    try:
        connection = await _get_nats()
        stream = connection.jetstream()
        await stream.publish(subject, json.dumps(payload).encode())
        logger.debug("Published event to %s", subject)
    except Exception as exc:
        logger.warning("Failed to publish event to %s: %s", subject, exc)

View File

@@ -0,0 +1,303 @@
"""Firmware version cache service and NPK downloader.
Responsibilities:
- check_latest_versions(): fetch latest RouterOS versions from download.mikrotik.com
- download_firmware(): download NPK packages to local PVC cache
- get_firmware_overview(): return fleet firmware status for a tenant
- schedule_firmware_checks(): register daily firmware check job with APScheduler
Version discovery comes from two sources:
1. Go poller runs /system/package/update per device (rate-limited to once/day)
and publishes via NATS -> firmware_subscriber processes these events
2. check_latest_versions() fetches LATEST.7 / LATEST.6 from download.mikrotik.com
"""
import logging
import os
from pathlib import Path
import httpx
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
logger = logging.getLogger(__name__)
# Architectures supported by RouterOS v7 and v6
_V7_ARCHITECTURES = ["arm", "arm64", "mipsbe", "mmips", "smips", "tile", "ppc", "x86"]
_V6_ARCHITECTURES = ["mipsbe", "mmips", "smips", "tile", "ppc", "x86"]
# Version source files on download.mikrotik.com
_VERSION_SOURCES = [
("LATEST.7", "stable", 7),
("LATEST.7long", "long-term", 7),
("LATEST.6", "stable", 6),
("LATEST.6long", "long-term", 6),
]
async def check_latest_versions() -> list[dict]:
    """Fetch latest RouterOS versions from download.mikrotik.com.

    Checks LATEST.7, LATEST.7long, LATEST.6, and LATEST.6long files for
    version strings, then upserts into firmware_versions table for each
    architecture/channel combination.

    Per-source failures (HTTP errors, network issues, malformed bodies) are
    logged and skipped so one bad source cannot abort the whole check.

    Returns:
        List of discovered version dicts with keys
        architecture/channel/version/npk_url.
    """
    results: list[dict] = []
    async with httpx.AsyncClient(timeout=30.0) as client:
        for channel_file, channel, major in _VERSION_SOURCES:
            try:
                resp = await client.get(
                    f"https://download.mikrotik.com/routeros/{channel_file}"
                )
                if resp.status_code != 200:
                    logger.warning(
                        "MikroTik version check returned %d for %s",
                        resp.status_code, channel_file,
                    )
                    continue
                version = resp.text.strip()
                # Sanity check: a real version string starts with a digit.
                if not version or not version[0].isdigit():
                    logger.warning("Invalid version string from %s: %r", channel_file, version)
                    continue
                # v6 supports fewer CPU architectures than v7.
                architectures = _V7_ARCHITECTURES if major == 7 else _V6_ARCHITECTURES
                for arch in architectures:
                    npk_url = (
                        f"https://download.mikrotik.com/routeros/"
                        f"{version}/routeros-{version}-{arch}.npk"
                    )
                    results.append({
                        "architecture": arch,
                        "channel": channel,
                        "version": version,
                        "npk_url": npk_url,
                    })
            except Exception as e:
                logger.warning("Failed to check %s: %s", channel_file, e)
    # Upsert into firmware_versions table (admin session bypasses RLS).
    if results:
        async with AdminAsyncSessionLocal() as session:
            for r in results:
                await session.execute(
                    text("""
                        INSERT INTO firmware_versions (id, architecture, channel, version, npk_url, checked_at)
                        VALUES (gen_random_uuid(), :arch, :channel, :version, :npk_url, NOW())
                        ON CONFLICT (architecture, channel, version) DO UPDATE SET checked_at = NOW()
                    """),
                    {
                        "arch": r["architecture"],
                        "channel": r["channel"],
                        "version": r["version"],
                        "npk_url": r["npk_url"],
                    },
                )
            await session.commit()
    logger.info("Firmware version check complete — %d versions discovered", len(results))
    return results
async def download_firmware(architecture: str, channel: str, version: str) -> str:
    """Download an NPK package into the local firmware cache.

    Args:
        architecture: RouterOS CPU architecture (e.g. "arm64", "mipsbe").
        channel: Release channel ("stable", "long-term") — used only for the
            firmware_versions bookkeeping update.
        version: RouterOS version string (e.g. "7.15.3").

    Returns:
        The local file path as a string. The download is skipped when the
        file already exists and is non-empty.

    Raises:
        httpx.HTTPStatusError: when the download request returns an error status.
    """
    cache_dir = Path(settings.FIRMWARE_CACHE_DIR) / version
    cache_dir.mkdir(parents=True, exist_ok=True)
    filename = f"routeros-{version}-{architecture}.npk"
    local_path = cache_dir / filename
    # BUG FIX: the URL previously ended in a literal "(unknown)" placeholder
    # instead of the package filename, so every uncached download failed.
    npk_url = f"https://download.mikrotik.com/routeros/{version}/{filename}"
    # Check if already cached (non-empty file is treated as complete)
    if local_path.exists() and local_path.stat().st_size > 0:
        logger.info("Firmware already cached: %s", local_path)
        return str(local_path)
    logger.info("Downloading firmware: %s", npk_url)
    async with httpx.AsyncClient(timeout=300.0) as client:
        async with client.stream("GET", npk_url) as response:
            response.raise_for_status()
            # Stream to disk in 64 KiB chunks to bound memory use.
            with open(local_path, "wb") as f:
                async for chunk in response.aiter_bytes(chunk_size=65536):
                    f.write(chunk)
    file_size = local_path.stat().st_size
    logger.info("Firmware downloaded: %s (%d bytes)", local_path, file_size)
    # Update firmware_versions table with local path and size
    async with AdminAsyncSessionLocal() as session:
        await session.execute(
            text("""
                UPDATE firmware_versions
                SET npk_local_path = :path, npk_size_bytes = :size
                WHERE architecture = :arch AND channel = :channel AND version = :version
            """),
            {
                "path": str(local_path),
                "size": file_size,
                "arch": architecture,
                "channel": channel,
                "version": version,
            },
        )
        await session.commit()
    return str(local_path)
async def get_firmware_overview(tenant_id: str) -> dict:
    """Return fleet firmware status for a tenant.

    Returns devices grouped by firmware version, annotated with up-to-date
    status based on the latest known version for each device's architecture
    and preferred channel.

    Uses the admin session (bypasses RLS) with an explicit tenant_id filter.

    Returns:
        dict with keys "devices" (flat list), "version_groups" (devices
        bucketed by installed version), and "summary" (counts).
    """
    async with AdminAsyncSessionLocal() as session:
        # Get all devices for tenant
        devices_result = await session.execute(
            text("""
                SELECT id, hostname, ip_address, routeros_version, architecture,
                       preferred_channel, routeros_major_version,
                       serial_number, firmware_version, model
                FROM devices
                WHERE tenant_id = CAST(:tenant_id AS uuid)
                ORDER BY hostname
            """),
            {"tenant_id": tenant_id},
        )
        devices = devices_result.fetchall()
        # Get latest firmware versions per architecture/channel.
        # DISTINCT ON + checked_at DESC picks the most recently confirmed row.
        versions_result = await session.execute(
            text("""
                SELECT DISTINCT ON (architecture, channel)
                       architecture, channel, version, npk_url
                FROM firmware_versions
                ORDER BY architecture, channel, checked_at DESC
            """)
        )
        latest_versions = {
            (row[0], row[1]): {"version": row[2], "npk_url": row[3]}
            for row in versions_result.fetchall()
        }
    # Build per-device status
    device_list = []
    version_groups: dict[str, list] = {}
    summary = {"total": 0, "up_to_date": 0, "outdated": 0, "unknown": 0}
    for dev in devices:
        # Positional columns follow the SELECT list above.
        dev_id = str(dev[0])
        hostname = dev[1]
        current_version = dev[3]
        arch = dev[4]
        channel = dev[5] or "stable"
        latest = latest_versions.get((arch, channel)) if arch else None
        latest_version = latest["version"] if latest else None
        is_up_to_date = False
        if not current_version or not arch:
            summary["unknown"] += 1
        elif latest_version and current_version == latest_version:
            is_up_to_date = True
            summary["up_to_date"] += 1
        else:
            # NOTE(review): this branch also counts devices whose arch/channel
            # has NO known latest version as "outdated" — confirm intended.
            summary["outdated"] += 1
        summary["total"] += 1
        dev_info = {
            "id": dev_id,
            "hostname": hostname,
            "ip_address": dev[2],
            "routeros_version": current_version,
            "architecture": arch,
            "latest_version": latest_version,
            "channel": channel,
            "is_up_to_date": is_up_to_date,
            "serial_number": dev[7],
            "firmware_version": dev[8],
            "model": dev[9],
        }
        device_list.append(dev_info)
        # Group by version
        ver_key = current_version or "unknown"
        if ver_key not in version_groups:
            version_groups[ver_key] = []
        version_groups[ver_key].append(dev_info)
    # Build version groups with is_latest flag
    groups = []
    for ver, devs in sorted(version_groups.items()):
        # A version is "latest" if it matches the latest for any arch/channel combo
        is_latest = any(
            v["version"] == ver for v in latest_versions.values()
        )
        groups.append({
            "version": ver,
            "count": len(devs),
            "is_latest": is_latest,
            "devices": devs,
        })
    return {
        "devices": device_list,
        "version_groups": groups,
        "summary": summary,
    }
async def get_cached_firmware() -> list[dict]:
    """Enumerate locally cached NPK files.

    Returns:
        List of dicts with path, version (cache subdirectory name),
        filename, and size_bytes; empty when the cache dir is absent.
    """
    root = Path(settings.FIRMWARE_CACHE_DIR)
    if not root.exists():
        return []
    entries: list[dict] = []
    # Layout: {cache_dir}/{version}/{*.npk}
    for version_dir in sorted(root.iterdir()):
        if not version_dir.is_dir():
            continue
        for npk in sorted(version_dir.iterdir()):
            if npk.suffix == ".npk":
                entries.append({
                    "path": str(npk),
                    "version": version_dir.name,
                    "filename": npk.name,
                    "size_bytes": npk.stat().st_size,
                })
    return entries
def schedule_firmware_checks() -> None:
    """Register the daily firmware version check with APScheduler.

    Called from FastAPI lifespan startup; schedules check_latest_versions()
    at 03:00 UTC every day on the shared backup scheduler.
    """
    from apscheduler.triggers.cron import CronTrigger
    from app.services.backup_scheduler import backup_scheduler

    backup_scheduler.add_job(
        check_latest_versions,
        trigger=CronTrigger(hour=3, minute=0, timezone="UTC"),
        id="firmware_version_check",
        name="Check for new RouterOS firmware versions",
        # Never overlap runs; re-registering on restart replaces the old job.
        max_instances=1,
        replace_existing=True,
    )
    logger.info("Firmware version check scheduled — daily at 3am UTC")

View File

@@ -0,0 +1,206 @@
"""NATS JetStream subscriber for device firmware events from the Go poller.
Subscribes to device.firmware.> and:
1. Updates devices.routeros_version and devices.architecture from poller data
2. Upserts firmware_versions table with latest version per architecture/channel
Uses AdminAsyncSessionLocal (superuser bypass RLS) so firmware data from any
tenant can be written without setting app.current_tenant.
"""
import asyncio
import json
import logging
from typing import Optional
import nats
from nats.js import JetStreamContext
from nats.aio.client import Client as NATSClient
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
logger = logging.getLogger(__name__)
_firmware_client: Optional[NATSClient] = None
async def on_device_firmware(msg) -> None:
    """Handle a device.firmware event published by the Go poller.

    Payload (JSON):
        device_id (str)          -- UUID of the device
        tenant_id (str)          -- UUID of the owning tenant
        installed_version (str)  -- currently installed RouterOS version
        latest_version (str)     -- latest available version (may be empty)
        channel (str)            -- firmware channel ("stable", "long-term")
        status (str)             -- "New version is available", etc.
        architecture (str)       -- CPU architecture (arm, arm64, mipsbe, etc.)

    Ack semantics: messages missing device_id are ack'ed (dropping them is
    final); processing failures are nak'ed so JetStream redelivers.
    """
    try:
        data = json.loads(msg.data)
        device_id = data.get("device_id")
        tenant_id = data.get("tenant_id")
        architecture = data.get("architecture")
        installed_version = data.get("installed_version")
        latest_version = data.get("latest_version")
        channel = data.get("channel", "stable")
        if not device_id:
            # Unroutable event — ack so it is not redelivered forever.
            logger.warning("device.firmware event missing device_id — skipping")
            await msg.ack()
            return
        async with AdminAsyncSessionLocal() as session:
            # Update device routeros_version and architecture from poller data.
            # COALESCE keeps the existing value when the event field is null.
            if architecture or installed_version:
                await session.execute(
                    text("""
                        UPDATE devices
                        SET routeros_version = COALESCE(:installed_ver, routeros_version),
                            architecture = COALESCE(:architecture, architecture),
                            updated_at = NOW()
                        WHERE id = CAST(:device_id AS uuid)
                    """),
                    {
                        "installed_ver": installed_version,
                        "architecture": architecture,
                        "device_id": device_id,
                    },
                )
            # Upsert firmware_versions if we got latest version info
            if latest_version and architecture:
                npk_url = (
                    f"https://download.mikrotik.com/routeros/"
                    f"{latest_version}/routeros-{latest_version}-{architecture}.npk"
                )
                await session.execute(
                    text("""
                        INSERT INTO firmware_versions (id, architecture, channel, version, npk_url, checked_at)
                        VALUES (gen_random_uuid(), :arch, :channel, :version, :url, NOW())
                        ON CONFLICT (architecture, channel, version) DO UPDATE SET checked_at = NOW()
                    """),
                    {
                        "arch": architecture,
                        "channel": channel,
                        "version": latest_version,
                        "url": npk_url,
                    },
                )
            await session.commit()
        logger.debug(
            "device.firmware processed",
            extra={
                "device_id": device_id,
                "architecture": architecture,
                "installed": installed_version,
                "latest": latest_version,
            },
        )
        await msg.ack()
    except Exception as exc:
        logger.error(
            "Failed to process device.firmware event: %s",
            exc,
            exc_info=True,
        )
        # Negative-ack so JetStream redelivers; best-effort only.
        try:
            await msg.nak()
        except Exception:
            pass
async def _subscribe_with_retry(js: JetStreamContext) -> None:
    """Subscribe to device.firmware.> with a durable consumer, retrying until the stream exists.

    The DEVICE_EVENTS stream is created by the poller; on cold start the API
    may come up first, so we poll for up to ~30 seconds before giving up.
    """
    max_attempts = 6  # ~30 seconds at 5s intervals
    attempt = 0
    while attempt < max_attempts:
        attempt += 1
        try:
            await js.subscribe(
                "device.firmware.>",
                cb=on_device_firmware,
                durable="api-firmware-consumer",
                stream="DEVICE_EVENTS",
            )
        except Exception as exc:
            if attempt >= max_attempts:
                # Retries exhausted — degrade gracefully instead of crashing startup.
                logger.warning(
                    "NATS: giving up on device.firmware.> after %d attempts: %s — API will run without firmware updates",
                    max_attempts,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready for firmware (attempt %d/%d): %s — retrying in 5s",
                attempt,
                max_attempts,
                exc,
            )
            await asyncio.sleep(5)
        else:
            logger.info(
                "NATS: subscribed to device.firmware.> (durable: api-firmware-consumer)"
            )
            return
async def start_firmware_subscriber() -> Optional[NATSClient]:
    """Connect to NATS and start the device.firmware.> subscription.

    Uses a separate NATS connection from the status and metrics subscribers.

    Returns:
        The NATS connection; pass it to stop_firmware_subscriber on shutdown.

    Raises:
        On fatal connection errors after retry exhaustion.
    """
    global _firmware_client
    logger.info("NATS firmware: connecting to %s", settings.NATS_URL)
    conn = await nats.connect(
        settings.NATS_URL,
        max_reconnect_attempts=-1,  # keep reconnecting indefinitely
        reconnect_time_wait=2,
        error_cb=_on_error,
        reconnected_cb=_on_reconnected,
        disconnected_cb=_on_disconnected,
    )
    logger.info("NATS firmware: connected to %s", settings.NATS_URL)
    # Subscription failures are non-fatal (see _subscribe_with_retry).
    await _subscribe_with_retry(conn.jetstream())
    _firmware_client = conn
    return conn
async def stop_firmware_subscriber(nc: Optional[NATSClient]) -> None:
    """Drain and close the firmware NATS connection gracefully."""
    if nc is None:
        return
    try:
        logger.info("NATS firmware: draining connection...")
        await nc.drain()
    except Exception as exc:
        # Drain failed — fall back to a hard close, ignoring further errors.
        logger.warning("NATS firmware: error during drain: %s", exc)
        try:
            await nc.close()
        except Exception:
            pass
    else:
        logger.info("NATS firmware: connection closed")
async def _on_error(exc: Exception) -> None:
    # nats-py error callback: log only; the client's reconnect loop handles recovery.
    logger.error("NATS firmware error: %s", exc)
async def _on_reconnected() -> None:
    # nats-py reconnect callback — informational only.
    logger.info("NATS firmware: reconnected")
async def _on_disconnected() -> None:
    # nats-py disconnect callback — informational only; reconnect is automatic.
    logger.warning("NATS firmware: disconnected")

View File

@@ -0,0 +1,296 @@
"""pygit2-based git store for versioned config backup storage.
All functions in this module are synchronous (pygit2 is C bindings over libgit2).
Callers running in an async context MUST wrap calls in:
loop.run_in_executor(None, func, *args)
or:
asyncio.get_event_loop().run_in_executor(None, func, *args)
See Pitfall 3 in 04-RESEARCH.md — blocking pygit2 in async context stalls
the event loop and causes timeouts for other concurrent requests.
Git layout:
{GIT_STORE_PATH}/{tenant_id}.git/ <- bare repo per tenant
objects/ refs/ HEAD <- standard bare git structure
{device_id}/ <- device subtree
export.rsc <- text export (/export compact)
backup.bin <- binary system backup
"""
import difflib
import threading
from pathlib import Path
from typing import Optional
import pygit2
from app.config import settings
# =========================================================================
# Per-tenant mutex to prevent TreeBuilder race condition (Pitfall 5 in RESEARCH.md).
# Two simultaneous backups for different devices in the same tenant repo would
# each read HEAD, build their own device subtrees, and write conflicting root
# trees. The second commit would lose the first's device subtree.
# Lock scope is the entire tenant repo — not just the device.
# =========================================================================
_tenant_locks: dict[str, threading.Lock] = {}
_tenant_locks_guard = threading.Lock()
def _get_tenant_lock(tenant_id: str) -> threading.Lock:
"""Return (creating if needed) the per-tenant commit lock."""
with _tenant_locks_guard:
if tenant_id not in _tenant_locks:
_tenant_locks[tenant_id] = threading.Lock()
return _tenant_locks[tenant_id]
# =========================================================================
# PUBLIC API
# =========================================================================
def get_or_create_repo(tenant_id: str) -> pygit2.Repository:
    """Open the tenant's bare git repo, creating it on first use.

    The repo lives at {GIT_STORE_PATH}/{tenant_id}.git. The parent directory
    is created if it does not exist.

    Args:
        tenant_id: Tenant UUID as string.

    Returns:
        An open pygit2.Repository instance (bare).
    """
    store_root = Path(settings.GIT_STORE_PATH)
    store_root.mkdir(parents=True, exist_ok=True)
    repo_dir = store_root / f"{tenant_id}.git"
    if not repo_dir.exists():
        # First backup for this tenant — initialize a fresh bare repository.
        return pygit2.init_repository(str(repo_dir), bare=True)
    return pygit2.Repository(str(repo_dir))
def commit_backup(
    tenant_id: str,
    device_id: str,
    export_text: str,
    binary_backup: bytes,
    message: str,
) -> str:
    """Write a backup pair (export.rsc + backup.bin) as a git commit.

    Creates or updates the device subdirectory in the tenant's bare repo.
    Preserves other devices' subdirectories by merging the device subtree
    into the existing root tree.

    Per-tenant locking (threading.Lock) prevents the TreeBuilder race
    condition when two devices in the same tenant back up concurrently.

    Args:
        tenant_id: Tenant UUID as string.
        device_id: Device UUID as string (becomes a subdirectory in the repo).
        export_text: Text output of /export compact.
        binary_backup: Raw bytes from /system backup save.
        message: Commit message (format: "{trigger}: {hostname} ({ip}) at {ts}").

    Returns:
        The hex commit SHA string (40 characters).
    """
    lock = _get_tenant_lock(tenant_id)
    with lock:
        repo = get_or_create_repo(tenant_id)
        # Create blobs from content
        export_oid = repo.create_blob(export_text.encode("utf-8"))
        binary_oid = repo.create_blob(binary_backup)
        # Build device subtree: {device_id}/export.rsc and {device_id}/backup.bin
        device_builder = repo.TreeBuilder()
        device_builder.insert("export.rsc", export_oid, pygit2.GIT_FILEMODE_BLOB)
        device_builder.insert("backup.bin", binary_oid, pygit2.GIT_FILEMODE_BLOB)
        device_tree_oid = device_builder.write()
        # Merge device subtree into root tree, preserving all other device subtrees.
        # If the repo has no commits yet, start with an empty root tree.
        root_ref = repo.references.get("refs/heads/main")
        parent_commit: Optional[pygit2.Commit] = None
        if root_ref is not None:
            try:
                # Seed the root TreeBuilder from HEAD's tree so sibling devices'
                # subtrees carry over into the new commit.
                parent_commit = repo.get(root_ref.target)
                root_builder = repo.TreeBuilder(parent_commit.tree)
            except Exception:
                # Unreadable HEAD — start from an empty root tree rather than
                # failing the backup entirely.
                root_builder = repo.TreeBuilder()
        else:
            root_builder = repo.TreeBuilder()
        # Inserting under an existing name replaces this device's old subtree.
        root_builder.insert(device_id, device_tree_oid, pygit2.GIT_FILEMODE_TREE)
        root_tree_oid = root_builder.write()
        # Author signature — no real identity, portal service account
        author = pygit2.Signature("The Other Dude", "backup@tod.local")
        parents = [root_ref.target] if root_ref is not None else []
        commit_oid = repo.create_commit(
            "refs/heads/main",
            author,
            author,
            message,
            root_tree_oid,
            parents,
        )
        return str(commit_oid)
def read_file(
    tenant_id: str,
    commit_sha: str,
    device_id: str,
    filename: str,
) -> bytes:
    """Read a file blob from a specific backup commit.

    Navigates the tree: root -> device_id subtree -> filename.

    Args:
        tenant_id: Tenant UUID as string.
        commit_sha: Full or abbreviated git commit SHA.
        device_id: Device UUID as string (subdirectory name in the repo).
        filename: File to read: "export.rsc" or "backup.bin".

    Returns:
        Raw bytes of the file content.

    Raises:
        KeyError: If device_id subtree or filename does not exist in commit.
        pygit2.GitError: If commit_sha is not found.
    """
    repo = get_or_create_repo(tenant_id)
    commit_obj = repo.get(commit_sha)
    if commit_obj is None:
        raise KeyError(f"Commit {commit_sha!r} not found in tenant {tenant_id!r} repo")
    # Walk: root tree -> device subtree -> file blob.
    device_tree = repo.get(commit_obj.tree[device_id].id)
    file_blob = repo.get(device_tree[filename].id)
    return file_blob.data
def list_device_commits(
    tenant_id: str,
    device_id: str,
) -> list[dict]:
    """Walk commit history and return commits that touched the device subtree.

    Walks commits newest-first and keeps only commits in which the device's
    subtree was introduced or changed relative to the first parent.

    Args:
        tenant_id: Tenant UUID as string.
        device_id: Device UUID as string.

    Returns:
        List of dicts (newest first):
            [{"sha": str, "message": str, "timestamp": int}, ...]
        Empty list if no commits or device has never been backed up.
    """
    repo = get_or_create_repo(tenant_id)
    # Resolve refs/heads/main explicitly — repo.head would resolve the default
    # branch name (often 'master'), which is wrong for this store.
    main_ref = repo.references.get("refs/heads/main")
    if main_ref is None:
        # Fresh repo with no commits at all.
        return []
    history: list[dict] = []
    for commit in repo.walk(main_ref.target, pygit2.GIT_SORT_TIME):
        try:
            entry = commit.tree[device_id]
        except KeyError:
            # Device has no subtree in this commit at all — skip.
            continue
        # Drop commits where the device subtree is identical to the first
        # parent's; otherwise every backup of any sibling device in the same
        # tenant would appear in this device's history.
        if commit.parents:
            try:
                if commit.parents[0].tree[device_id].id == entry.id:
                    continue
            except KeyError:
                # Device absent in the parent — this is its first backup commit.
                pass
        history.append({
            "sha": str(commit.id),
            "message": commit.message.strip(),
            "timestamp": commit.commit_time,
        })
    return history
def compute_line_delta(old_text: str, new_text: str) -> tuple[int, int]:
    """Compute (lines_added, lines_removed) between two text versions.

    Folds difflib.SequenceMatcher opcodes into two counters instead of
    generating a full unified diff, which is cheaper for large config files.

    For the first backup (no prior version), pass old_text="" to get
    (total_lines, 0) as the delta.

    Args:
        old_text: Previous export.rsc content (empty string for first backup).
        new_text: New export.rsc content.

    Returns:
        Tuple of (lines_added, lines_removed).
    """
    before = old_text.splitlines() if old_text else []
    after = new_text.splitlines() if new_text else []
    # Degenerate cases bypass the matcher entirely.
    if not before:
        return (len(after), 0) if after else (0, 0)
    if not after:
        return (0, len(before))
    matcher = difflib.SequenceMatcher(None, before, after, autojunk=False)
    added = 0
    removed = 0
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        # "replace" contributes to both counters; "equal" to neither.
        if tag in ("replace", "insert"):
            added += j2 - j1
        if tag in ("replace", "delete"):
            removed += i2 - i1
    return (added, removed)

View File

@@ -0,0 +1,324 @@
"""Key hierarchy management service for zero-knowledge architecture.
Provides CRUD operations for encrypted key bundles (UserKeySet),
append-only audit logging (KeyAccessLog), and OpenBao Transit
tenant key provisioning with credential migration.
"""
import logging
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.key_set import KeyAccessLog, UserKeySet
logger = logging.getLogger(__name__)
async def store_user_key_set(
    db: AsyncSession,
    user_id: UUID,
    tenant_id: UUID | None,
    encrypted_private_key: bytes,
    private_key_nonce: bytes,
    encrypted_vault_key: bytes,
    vault_key_nonce: bytes,
    public_key: bytes,
    pbkdf2_salt: bytes,
    hkdf_salt: bytes,
    pbkdf2_iterations: int = 650000,
) -> UserKeySet:
    """Store encrypted key bundle during registration.

    Creates a new UserKeySet for the user. Each user has exactly one
    key set (UNIQUE constraint on user_id).

    Args:
        db: Async database session.
        user_id: The user's UUID.
        tenant_id: The user's tenant UUID (None for super_admin).
        encrypted_private_key: RSA private key wrapped by AUK (AES-GCM).
        private_key_nonce: 12-byte AES-GCM nonce for private key.
        encrypted_vault_key: Tenant vault key wrapped by user's public key.
        vault_key_nonce: 12-byte AES-GCM nonce for vault key.
        public_key: RSA-2048 public key in SPKI format.
        pbkdf2_salt: 32-byte salt for PBKDF2 key derivation.
        hkdf_salt: 32-byte salt for HKDF Secret Key derivation.
        pbkdf2_iterations: PBKDF2 iteration count (default 650000).

    Returns:
        The created UserKeySet instance.
    """
    from sqlalchemy import delete

    # Drop any stale key set first (e.g. from a failed prior upgrade attempt)
    # so the UNIQUE constraint on user_id cannot be violated.
    await db.execute(delete(UserKeySet).where(UserKeySet.user_id == user_id))

    bundle_fields = dict(
        user_id=user_id,
        tenant_id=tenant_id,
        encrypted_private_key=encrypted_private_key,
        private_key_nonce=private_key_nonce,
        encrypted_vault_key=encrypted_vault_key,
        vault_key_nonce=vault_key_nonce,
        public_key=public_key,
        pbkdf2_salt=pbkdf2_salt,
        hkdf_salt=hkdf_salt,
        pbkdf2_iterations=pbkdf2_iterations,
    )
    bundle = UserKeySet(**bundle_fields)
    db.add(bundle)
    await db.flush()
    return bundle
async def get_user_key_set(
    db: AsyncSession, user_id: UUID
) -> UserKeySet | None:
    """Retrieve encrypted key bundle for login response.

    Args:
        db: Async database session.
        user_id: The user's UUID.

    Returns:
        The UserKeySet if found, None otherwise.
    """
    query = select(UserKeySet).where(UserKeySet.user_id == user_id)
    return (await db.execute(query)).scalar_one_or_none()
async def log_key_access(
    db: AsyncSession,
    tenant_id: UUID,
    user_id: UUID | None,
    action: str,
    resource_type: str | None = None,
    resource_id: str | None = None,
    key_version: int | None = None,
    ip_address: str | None = None,
    device_id: UUID | None = None,
    justification: str | None = None,
    correlation_id: str | None = None,
) -> None:
    """Append to immutable key_access_log.

    This table is append-only (INSERT+SELECT only via RLS policy).
    No UPDATE or DELETE is permitted.

    Args:
        db: Async database session.
        tenant_id: The tenant UUID for RLS isolation.
        user_id: The user who performed the action (None for system ops).
        action: Action description (e.g., 'create_key_set', 'decrypt_vault_key').
        resource_type: Optional resource type being accessed.
        resource_id: Optional resource identifier.
        key_version: Optional key version involved.
        ip_address: Optional client IP address.
        device_id: Optional device UUID for credential access tracking.
        justification: Optional justification for the access (e.g., 'api_backup').
        correlation_id: Optional correlation ID for request tracing.
    """
    db.add(
        KeyAccessLog(
            tenant_id=tenant_id,
            user_id=user_id,
            action=action,
            resource_type=resource_type,
            resource_id=resource_id,
            key_version=key_version,
            ip_address=ip_address,
            device_id=device_id,
            justification=justification,
            correlation_id=correlation_id,
        )
    )
    # Flush immediately so constraint violations surface at the call site.
    await db.flush()
# ---------------------------------------------------------------------------
# OpenBao Transit tenant key provisioning and credential migration
# ---------------------------------------------------------------------------
async def provision_tenant_key(db: AsyncSession, tenant_id: UUID) -> str:
    """Provision an OpenBao Transit key for a tenant and update the tenant record.

    Idempotent: if the key already exists in OpenBao, it's a no-op on the
    OpenBao side. The tenant record is always updated with the key name.

    Args:
        db: Async database session (admin engine, no RLS).
        tenant_id: Tenant UUID.

    Returns:
        The key name (tenant_{uuid}).
    """
    from app.models.tenant import Tenant
    from app.services.openbao_service import get_openbao_service

    key_name = f"tenant_{tenant_id}"
    await get_openbao_service().create_tenant_key(str(tenant_id))

    # Record the key name on the tenant row (if the tenant still exists).
    tenant = (
        await db.execute(select(Tenant).where(Tenant.id == tenant_id))
    ).scalar_one_or_none()
    if tenant is not None:
        tenant.openbao_key_name = key_name
        await db.flush()

    logger.info(
        "Provisioned OpenBao Transit key for tenant %s (key=%s)",
        tenant_id,
        key_name,
    )
    return key_name
async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
    """Re-encrypt all legacy credentials for a tenant from AES-256-GCM to Transit.

    Migrates device credentials, CA private keys, device cert private keys,
    and notification channel secrets. Already-migrated items are skipped —
    each query filters on the *_transit column being NULL or empty.

    Per-item failures are logged and counted; one bad row does not abort
    the rest of the migration.

    Args:
        db: Async database session (admin engine, no RLS).
        tenant_id: Tenant UUID.

    Returns:
        Dict with counts: {"devices": N, "cas": N, "certs": N, "channels": N, "errors": N}
    """
    from app.config import settings
    from app.models.alert import NotificationChannel
    from app.models.certificate import CertificateAuthority, DeviceCertificate
    from app.models.device import Device
    from app.services.crypto import decrypt_credentials
    from app.services.openbao_service import get_openbao_service
    openbao = get_openbao_service()
    # Legacy AES-256-GCM key used to decrypt old ciphertext before
    # re-encrypting through the tenant's Transit key.
    legacy_key = settings.get_encryption_key_bytes()
    tid = str(tenant_id)
    counts = {"devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0}
    # --- Migrate device credentials ---
    result = await db.execute(
        select(Device).where(
            Device.tenant_id == tenant_id,
            Device.encrypted_credentials.isnot(None),
            (Device.encrypted_credentials_transit.is_(None) | (Device.encrypted_credentials_transit == "")),
        )
    )
    for device in result.scalars().all():
        try:
            plaintext = decrypt_credentials(device.encrypted_credentials, legacy_key)
            device.encrypted_credentials_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
            counts["devices"] += 1
        except Exception as e:
            logger.error("Failed to migrate device %s credentials: %s", device.id, e)
            counts["errors"] += 1
    # --- Migrate CA private keys ---
    result = await db.execute(
        select(CertificateAuthority).where(
            CertificateAuthority.tenant_id == tenant_id,
            CertificateAuthority.encrypted_private_key.isnot(None),
            (CertificateAuthority.encrypted_private_key_transit.is_(None) | (CertificateAuthority.encrypted_private_key_transit == "")),
        )
    )
    for ca in result.scalars().all():
        try:
            plaintext = decrypt_credentials(ca.encrypted_private_key, legacy_key)
            ca.encrypted_private_key_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
            counts["cas"] += 1
        except Exception as e:
            logger.error("Failed to migrate CA %s private key: %s", ca.id, e)
            counts["errors"] += 1
    # --- Migrate device cert private keys ---
    result = await db.execute(
        select(DeviceCertificate).where(
            DeviceCertificate.tenant_id == tenant_id,
            DeviceCertificate.encrypted_private_key.isnot(None),
            (DeviceCertificate.encrypted_private_key_transit.is_(None) | (DeviceCertificate.encrypted_private_key_transit == "")),
        )
    )
    for cert in result.scalars().all():
        try:
            plaintext = decrypt_credentials(cert.encrypted_private_key, legacy_key)
            cert.encrypted_private_key_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
            counts["certs"] += 1
        except Exception as e:
            logger.error("Failed to migrate cert %s private key: %s", cert.id, e)
            counts["errors"] += 1
    # --- Migrate notification channel secrets ---
    # Channels are fetched unfiltered; the per-secret checks below decide
    # whether anything actually needs migrating.
    result = await db.execute(
        select(NotificationChannel).where(
            NotificationChannel.tenant_id == tenant_id,
        )
    )
    for ch in result.scalars().all():
        migrated_any = False
        try:
            # SMTP password
            if ch.smtp_password and not ch.smtp_password_transit:
                plaintext = decrypt_credentials(ch.smtp_password, legacy_key)
                ch.smtp_password_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
                migrated_any = True
            # A channel counts once no matter how many of its secrets moved.
            if migrated_any:
                counts["channels"] += 1
        except Exception as e:
            logger.error("Failed to migrate channel %s secrets: %s", ch.id, e)
            counts["errors"] += 1
    # Single flush at the end — all ORM mutations above go out together.
    await db.flush()
    logger.info(
        "Tenant %s credential migration complete: %s",
        tenant_id,
        counts,
    )
    return counts
async def provision_existing_tenants(db: AsyncSession) -> dict:
    """Provision OpenBao Transit keys for all existing tenants and migrate credentials.

    Called on app startup to ensure all tenants have Transit keys.
    Idempotent -- running multiple times is safe (already-migrated items are skipped).

    Args:
        db: Async database session (admin engine, no RLS).

    Returns:
        Summary dict with total counts across all tenants.
    """
    from app.models.tenant import Tenant

    tenants = (await db.execute(select(Tenant))).scalars().all()
    total = {"tenants": len(tenants), "devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0}
    for tenant in tenants:
        try:
            await provision_tenant_key(db, tenant.id)
            per_tenant = await migrate_tenant_credentials(db, tenant.id)
            # Fold this tenant's counters into the fleet-wide totals.
            for bucket in ("devices", "cas", "certs", "channels", "errors"):
                total[bucket] += per_tenant[bucket]
        except Exception as e:
            logger.error("Failed to provision/migrate tenant %s: %s", tenant.id, e)
            total["errors"] += 1
    await db.commit()
    logger.info("Existing tenant provisioning complete: %s", total)
    return total

View File

@@ -0,0 +1,346 @@
"""NATS JetStream subscriber for device metrics events.
Subscribes to device.metrics.> and inserts into TimescaleDB hypertables:
- interface_metrics — per-interface rx/tx byte counters
- health_metrics — CPU, memory, disk, temperature per device
- wireless_metrics — per-wireless-interface aggregated client stats
Also maintains denormalized last_cpu_load and last_memory_used_pct columns
on the devices table for efficient fleet table display.
Uses AdminAsyncSessionLocal (superuser bypass RLS) so metrics from any tenant
can be written without setting app.current_tenant.
"""
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Optional
import nats
from nats.js import JetStreamContext
from nats.aio.client import Client as NATSClient
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
logger = logging.getLogger(__name__)
_metrics_client: Optional[NATSClient] = None
# =============================================================================
# INSERT HANDLERS
# =============================================================================
def _parse_timestamp(val: str | None) -> datetime:
"""Parse an ISO 8601 / RFC 3339 timestamp string into a datetime object."""
if not val:
return datetime.now(timezone.utc)
try:
return datetime.fromisoformat(val.replace("Z", "+00:00"))
except (ValueError, AttributeError):
return datetime.now(timezone.utc)
async def _insert_health_metrics(session, data: dict) -> None:
    """Insert a health metrics event into health_metrics and update devices."""
    health = data.get("health")
    if not health:
        logger.warning("health metrics event missing 'health' field — skipping")
        return
    device_id = data.get("device_id")
    tenant_id = data.get("tenant_id")
    collected_at = _parse_timestamp(data.get("collected_at"))

    def _to_int(raw) -> int | None:
        # Empty strings and malformed values map to NULL.
        if not raw:
            return None
        try:
            return int(raw)
        except (ValueError, TypeError):
            return None

    metrics = {
        name: _to_int(health.get(name))
        for name in (
            "cpu_load",
            "free_memory",
            "total_memory",
            "free_disk",
            "total_disk",
            "temperature",
        )
    }
    await session.execute(
        text("""
            INSERT INTO health_metrics
                (time, device_id, tenant_id, cpu_load, free_memory, total_memory,
                 free_disk, total_disk, temperature)
            VALUES
                (:time, :device_id, :tenant_id, :cpu_load, :free_memory, :total_memory,
                 :free_disk, :total_disk, :temperature)
        """),
        {"time": collected_at, "device_id": device_id, "tenant_id": tenant_id, **metrics},
    )
    # Maintain denormalized columns on devices for the fleet table display.
    # The memory percentage is computed in Python to avoid asyncpg type ambiguity.
    free_memory = metrics["free_memory"]
    total_memory = metrics["total_memory"]
    mem_pct = None
    if total_memory and total_memory > 0 and free_memory is not None:
        mem_pct = round((1.0 - free_memory / total_memory) * 100)
    await session.execute(
        text("""
            UPDATE devices SET
                last_cpu_load = COALESCE(:cpu_load, last_cpu_load),
                last_memory_used_pct = COALESCE(:mem_pct, last_memory_used_pct),
                updated_at = NOW()
            WHERE id = CAST(:device_id AS uuid)
        """),
        {
            "cpu_load": metrics["cpu_load"],
            "mem_pct": mem_pct,
            "device_id": device_id,
        },
    )
async def _insert_interface_metrics(session, data: dict) -> None:
    """Insert per-interface traffic counters into interface_metrics."""
    interfaces = data.get("interfaces")
    if not interfaces:
        # Device may have no interfaces (unlikely but safe to skip).
        return
    device_id = data.get("device_id")
    tenant_id = data.get("tenant_id")
    collected_at = _parse_timestamp(data.get("collected_at"))
    # Statement is loop-invariant — build it once.
    insert_stmt = text("""
        INSERT INTO interface_metrics
            (time, device_id, tenant_id, interface, rx_bytes, tx_bytes, rx_bps, tx_bps)
        VALUES
            (:time, :device_id, :tenant_id, :interface, :rx_bytes, :tx_bytes, NULL, NULL)
    """)
    for nic in interfaces:
        await session.execute(
            insert_stmt,
            {
                "time": collected_at,
                "device_id": device_id,
                "tenant_id": tenant_id,
                "interface": nic.get("name"),
                "rx_bytes": nic.get("rx_bytes"),
                "tx_bytes": nic.get("tx_bytes"),
            },
        )
async def _insert_wireless_metrics(session, data: dict) -> None:
    """Insert per-wireless-interface aggregated client stats into wireless_metrics."""
    wireless = data.get("wireless")
    if not wireless:
        # Device may have no wireless interfaces.
        return
    device_id = data.get("device_id")
    tenant_id = data.get("tenant_id")
    collected_at = _parse_timestamp(data.get("collected_at"))
    # Statement is loop-invariant — build it once.
    insert_stmt = text("""
        INSERT INTO wireless_metrics
            (time, device_id, tenant_id, interface, client_count, avg_signal, ccq, frequency)
        VALUES
            (:time, :device_id, :tenant_id, :interface,
             :client_count, :avg_signal, :ccq, :frequency)
    """)
    for wlan in wireless:
        await session.execute(
            insert_stmt,
            {
                "time": collected_at,
                "device_id": device_id,
                "tenant_id": tenant_id,
                "interface": wlan.get("interface"),
                "client_count": wlan.get("client_count"),
                "avg_signal": wlan.get("avg_signal"),
                "ccq": wlan.get("ccq"),
                "frequency": wlan.get("frequency"),
            },
        )
# =============================================================================
# MAIN MESSAGE HANDLER
# =============================================================================
async def on_device_metrics(msg) -> None:
    """Handle a device.metrics event published by the Go poller.

    Dispatches to the appropriate insert handler based on the 'type' field:
      - "health"     -> _insert_health_metrics + update devices
      - "interfaces" -> _insert_interface_metrics
      - "wireless"   -> _insert_wireless_metrics

    On success, acknowledges the message. On error, NAKs so NATS can redeliver.
    """
    try:
        data = json.loads(msg.data)
        metric_type = data.get("type")
        device_id = data.get("device_id")
        if not metric_type or not device_id:
            # Malformed payloads are acked, not NAKed — redelivery can't fix them.
            logger.warning(
                "device.metrics event missing 'type' or 'device_id' — skipping"
            )
            await msg.ack()
            return
        async with AdminAsyncSessionLocal() as session:
            if metric_type == "health":
                await _insert_health_metrics(session, data)
            elif metric_type == "interfaces":
                await _insert_interface_metrics(session, data)
            elif metric_type == "wireless":
                await _insert_wireless_metrics(session, data)
            else:
                # Unknown types are acked and dropped for the same reason.
                logger.warning("Unknown metric type '%s' — skipping", metric_type)
                await msg.ack()
                return
            await session.commit()
        # Alert evaluation — non-fatal; metric write is the primary operation
        try:
            from app.services import alert_evaluator
            await alert_evaluator.evaluate(
                device_id=device_id,
                tenant_id=data.get("tenant_id", ""),
                metric_type=metric_type,
                data=data,
            )
        except Exception as eval_err:
            logger.warning("Alert evaluation failed for device %s: %s", device_id, eval_err)
        logger.debug(
            "device.metrics processed",
            extra={"device_id": device_id, "type": metric_type},
        )
        # Ack only after the commit (and best-effort alert eval) succeeded.
        await msg.ack()
    except Exception as exc:
        logger.error(
            "Failed to process device.metrics event: %s",
            exc,
            exc_info=True,
        )
        try:
            await msg.nak()
        except Exception:
            pass  # If NAK also fails, NATS will redeliver after ack_wait
# =============================================================================
# SUBSCRIPTION SETUP
# =============================================================================
async def _subscribe_with_retry(js: JetStreamContext) -> None:
    """Subscribe to device.metrics.> with a durable consumer, retrying until the stream exists.

    The DEVICE_EVENTS stream is created by the poller; on cold start the API
    may come up first, so we poll for up to ~30 seconds before giving up.
    """
    max_attempts = 6  # ~30 seconds at 5s intervals
    attempt = 0
    while attempt < max_attempts:
        attempt += 1
        try:
            await js.subscribe(
                "device.metrics.>",
                cb=on_device_metrics,
                durable="api-metrics-consumer",
                stream="DEVICE_EVENTS",
            )
        except Exception as exc:
            if attempt >= max_attempts:
                # Retries exhausted — degrade gracefully instead of crashing startup.
                logger.warning(
                    "NATS: giving up on device.metrics.> after %d attempts: %s — API will run without metrics ingestion",
                    max_attempts,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready for metrics (attempt %d/%d): %s — retrying in 5s",
                attempt,
                max_attempts,
                exc,
            )
            await asyncio.sleep(5)
        else:
            logger.info(
                "NATS: subscribed to device.metrics.> (durable: api-metrics-consumer)"
            )
            return
async def start_metrics_subscriber() -> Optional[NATSClient]:
    """Connect to NATS and start the device.metrics.> subscription.

    Uses a separate NATS connection from the status subscriber — simpler and
    NATS handles multiple connections per client efficiently.

    Returns:
        The NATS connection; pass it to stop_metrics_subscriber on shutdown.

    Raises:
        On fatal connection errors after retry exhaustion.
    """
    global _metrics_client
    logger.info("NATS metrics: connecting to %s", settings.NATS_URL)
    conn = await nats.connect(
        settings.NATS_URL,
        max_reconnect_attempts=-1,  # keep reconnecting indefinitely
        reconnect_time_wait=2,
        error_cb=_on_error,
        reconnected_cb=_on_reconnected,
        disconnected_cb=_on_disconnected,
    )
    logger.info("NATS metrics: connected to %s", settings.NATS_URL)
    # Subscription failures are non-fatal (see _subscribe_with_retry).
    await _subscribe_with_retry(conn.jetstream())
    _metrics_client = conn
    return conn
async def stop_metrics_subscriber(nc: Optional[NATSClient]) -> None:
    """Drain and close the metrics NATS connection gracefully."""
    if nc is None:
        return
    try:
        logger.info("NATS metrics: draining connection...")
        await nc.drain()
    except Exception as exc:
        # Drain failed — fall back to a hard close, ignoring further errors.
        logger.warning("NATS metrics: error during drain: %s", exc)
        try:
            await nc.close()
        except Exception:
            pass
    else:
        logger.info("NATS metrics: connection closed")
async def _on_error(exc: Exception) -> None:
    # nats-py error callback: log only; the client's reconnect loop handles recovery.
    logger.error("NATS metrics error: %s", exc)
async def _on_reconnected() -> None:
    # nats-py reconnect callback — informational only.
    logger.info("NATS metrics: reconnected")
async def _on_disconnected() -> None:
    # nats-py disconnect callback — informational only; reconnect is automatic.
    logger.warning("NATS metrics: disconnected")

View File

@@ -0,0 +1,231 @@
"""NATS JetStream subscriber for device status events from the Go poller.
Subscribes to device.status.> and updates device records in PostgreSQL.
This is a system-level process that needs to update devices across all tenants,
so it uses the admin engine (bypasses RLS).
"""
import asyncio
import json
import logging
import re
from datetime import datetime, timezone
from typing import Optional
import nats
from nats.js import JetStreamContext
from nats.aio.client import Client as NATSClient
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
logger = logging.getLogger(__name__)
_nats_client: Optional[NATSClient] = None
# Regex for RouterOS uptime strings like "42d14h23m15s", "14h23m15s", "23m15s", "3w2d"
_UPTIME_RE = re.compile(r"(?:(\d+)w)?(?:(\d+)d)?(?:(\d+)h)?(?:(\d+)m)?(?:(\d+)s)?")
def _parse_uptime(raw: str) -> int | None:
"""Parse a RouterOS uptime string into total seconds."""
if not raw:
return None
m = _UPTIME_RE.fullmatch(raw)
if not m:
return None
weeks = int(m.group(1) or 0)
days = int(m.group(2) or 0)
hours = int(m.group(3) or 0)
minutes = int(m.group(4) or 0)
seconds = int(m.group(5) or 0)
total = weeks * 604800 + days * 86400 + hours * 3600 + minutes * 60 + seconds
return total if total > 0 else None
async def on_device_status(msg) -> None:
    """Handle a device.status event published by the Go poller.

    Updates the device row (status, version/hardware facts, last_seen) and
    triggers alert evaluation on offline/online transitions. Uses the admin
    engine session because the poller reports devices across all tenants
    (bypasses RLS).

    Payload (JSON):
        device_id (str) — UUID of the device
        tenant_id (str) — UUID of the owning tenant
        status (str) — "online" or "offline"
        routeros_version (str | None) — e.g. "7.16.2"
        major_version (int | None) — e.g. 7
        board_name (str | None) — e.g. "RB4011iGS+5HacQ2HnD"
        last_seen (str | None) — ISO-8601 timestamp

    ACKs on success and on malformed payloads (redelivery cannot fix those);
    NAKs on processing errors so JetStream redelivers after ack_wait.
    """
    try:
        data = json.loads(msg.data)
        device_id = data.get("device_id")
        status = data.get("status")
        routeros_version = data.get("routeros_version")
        major_version = data.get("major_version")
        board_name = data.get("board_name")
        last_seen_raw = data.get("last_seen")
        # "or None" normalizes empty strings so COALESCE below keeps prior values.
        serial_number = data.get("serial_number") or None
        firmware_version = data.get("firmware_version") or None
        uptime_seconds = _parse_uptime(data.get("uptime", ""))
        if not device_id or not status:
            # ACK rather than NAK: a malformed payload will never become valid.
            logger.warning("Received device.status event with missing device_id or status — skipping")
            await msg.ack()
            return
        # Parse timestamp in Python — asyncpg needs datetime objects, not strings
        last_seen_dt = None
        if last_seen_raw:
            try:
                last_seen_dt = datetime.fromisoformat(last_seen_raw.replace("Z", "+00:00"))
            except (ValueError, AttributeError):
                # Unparseable timestamp: fall back to "now" rather than dropping the update.
                last_seen_dt = datetime.now(timezone.utc)
        async with AdminAsyncSessionLocal() as session:
            # COALESCE keeps existing column values when the poller omits a field,
            # so a status-only event doesn't wipe previously-learned hardware facts.
            await session.execute(
                text(
                    """
                    UPDATE devices SET
                        status = :status,
                        routeros_version = COALESCE(:routeros_version, routeros_version),
                        routeros_major_version = COALESCE(:major_version, routeros_major_version),
                        model = COALESCE(:board_name, model),
                        serial_number = COALESCE(:serial_number, serial_number),
                        firmware_version = COALESCE(:firmware_version, firmware_version),
                        uptime_seconds = COALESCE(:uptime_seconds, uptime_seconds),
                        last_seen = COALESCE(:last_seen, last_seen),
                        updated_at = NOW()
                    WHERE id = CAST(:device_id AS uuid)
                    """
                ),
                {
                    "status": status,
                    "routeros_version": routeros_version,
                    "major_version": major_version,
                    "board_name": board_name,
                    "serial_number": serial_number,
                    "firmware_version": firmware_version,
                    "uptime_seconds": uptime_seconds,
                    "last_seen": last_seen_dt,
                    "device_id": device_id,
                },
            )
            await session.commit()
        # Alert evaluation for offline/online status changes — non-fatal
        try:
            # Imported lazily to avoid a circular import at module load time
            # (assumption — TODO confirm against app.services layout).
            from app.services import alert_evaluator
            if status == "offline":
                await alert_evaluator.evaluate_offline(device_id, data.get("tenant_id", ""))
            elif status == "online":
                await alert_evaluator.evaluate_online(device_id, data.get("tenant_id", ""))
        except Exception as e:
            logger.warning("Alert evaluation failed for device %s status=%s: %s", device_id, status, e)
        logger.info(
            "Device status updated",
            extra={
                "device_id": device_id,
                "status": status,
                "routeros_version": routeros_version,
            },
        )
        await msg.ack()
    except Exception as exc:
        logger.error(
            "Failed to process device.status event: %s",
            exc,
            exc_info=True,
        )
        try:
            await msg.nak()
        except Exception:
            pass  # If NAK also fails, NATS will redeliver after ack_wait
async def _subscribe_with_retry(js: JetStreamContext) -> None:
    """Subscribe to device.status.> with a durable consumer, retrying while the
    DEVICE_EVENTS stream comes up. Gives up (non-fatally) after ~30 seconds.
    """
    max_attempts = 6  # ~30 seconds at 5s intervals
    attempt = 0
    while attempt < max_attempts:
        attempt += 1
        try:
            await js.subscribe(
                "device.status.>",
                cb=on_device_status,
                durable="api-status-consumer",
                stream="DEVICE_EVENTS",
            )
        except Exception as exc:
            if attempt >= max_attempts:
                # Out of retries: degrade gracefully instead of crashing the API.
                logger.warning(
                    "NATS: giving up on device.status.> after %d attempts: %s — API will run without real-time status updates",
                    max_attempts,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready (attempt %d/%d): %s — retrying in 5s",
                attempt,
                max_attempts,
                exc,
            )
            await asyncio.sleep(5)
        else:
            logger.info("NATS: subscribed to device.status.> (durable: api-status-consumer)")
            return
async def start_nats_subscriber() -> Optional[NATSClient]:
    """Connect to NATS and start the device.status.> subscription.

    Returns the NATS connection (must be passed to stop_nats_subscriber on
    shutdown). Raises on fatal connection errors; subscription failures are
    non-fatal (handled inside _subscribe_with_retry).
    """
    global _nats_client
    logger.info("NATS: connecting to %s", settings.NATS_URL)
    nc = await nats.connect(
        settings.NATS_URL,
        max_reconnect_attempts=-1,  # reconnect forever (pod-to-pod transient failures)
        reconnect_time_wait=2,
        error_cb=_on_error,
        reconnected_cb=_on_reconnected,
        disconnected_cb=_on_disconnected,
    )
    logger.info("NATS: connected to %s", settings.NATS_URL)
    js = nc.jetstream()
    # Best-effort: logs and gives up internally if the stream never appears.
    await _subscribe_with_retry(js)
    # Keep a module-level reference for introspection/cleanup.
    _nats_client = nc
    return nc
async def stop_nats_subscriber(nc: Optional[NATSClient]) -> None:
    """Gracefully shut down the device-status NATS connection.

    Drains in-flight messages first; if draining fails, falls back to a
    hard close and swallows secondary errors. A None connection is a no-op.
    """
    if nc is None:
        return
    try:
        logger.info("NATS: draining connection...")
        await nc.drain()
    except Exception as exc:
        logger.warning("NATS: error during drain: %s", exc)
        # Drain failed -- force-close the socket; ignore errors on the way down.
        try:
            await nc.close()
        except Exception:
            pass
    else:
        logger.info("NATS: connection closed")
async def _on_error(exc: Exception) -> None:
    """NATS client callback: log errors surfaced by the status connection."""
    logger.error("NATS error: %s", exc)


async def _on_reconnected() -> None:
    """NATS client callback: log a successful reconnect of the status connection."""
    logger.info("NATS: reconnected")


async def _on_disconnected() -> None:
    """NATS client callback: log a (possibly transient) disconnect of the status connection."""
    logger.warning("NATS: disconnected")

View File

@@ -0,0 +1,256 @@
"""Email and webhook notification delivery for alert events.
Best-effort delivery: failures are logged but never raised.
Each dispatch is wrapped in try/except so one failing channel
doesn't prevent delivery to other channels.
"""
import logging
from typing import Any
import httpx
logger = logging.getLogger(__name__)
async def dispatch_notifications(
alert_event: dict[str, Any],
channels: list[dict[str, Any]],
device_hostname: str,
) -> None:
"""Send notifications for an alert event to all provided channels.
Args:
alert_event: Dict with alert event fields (status, severity, metric, etc.)
channels: List of notification channel dicts
device_hostname: Human-readable device name for messages
"""
for channel in channels:
try:
if channel["channel_type"] == "email":
await _send_email(channel, alert_event, device_hostname)
elif channel["channel_type"] == "webhook":
await _send_webhook(channel, alert_event, device_hostname)
elif channel["channel_type"] == "slack":
await _send_slack(channel, alert_event, device_hostname)
else:
logger.warning("Unknown channel type: %s", channel["channel_type"])
except Exception as e:
logger.warning(
"Notification delivery failed for channel %s (%s): %s",
channel.get("name"), channel.get("channel_type"), e,
)
async def _send_email(channel: dict, alert_event: dict, device_hostname: str) -> None:
    """Send alert notification email using per-channel SMTP config.

    Builds an HTML + plain-text body from the alert event, decrypts the
    channel's SMTP password (OpenBao Transit first, legacy Fernet fallback),
    then delivers via email_service.send_email. Raises on delivery failure
    (dispatch_notifications catches per-channel).
    """
    # Imported lazily — presumably to avoid import cycles; verify against module layout.
    from app.services.email_service import SMTPConfig, send_email
    # Event fields arrive under two naming schemes (evaluator payloads vs. the
    # test payload built in send_test_notification), hence the paired fallbacks.
    severity = alert_event.get("severity", "warning")
    status = alert_event.get("status", "firing")
    rule_name = alert_event.get("rule_name") or alert_event.get("message", "Unknown Rule")
    metric = alert_event.get("metric_name") or alert_event.get("metric", "")
    value = alert_event.get("current_value") or alert_event.get("value", "")
    threshold = alert_event.get("threshold", "")
    # Header color keyed by severity; unknown severities fall back to the info blue.
    severity_colors = {
        "critical": "#ef4444",
        "warning": "#f59e0b",
        "info": "#38bdf8",
    }
    color = severity_colors.get(severity, "#38bdf8")
    status_label = "RESOLVED" if status == "resolved" else "FIRING"
    html = f"""
    <div style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; max-width: 600px; margin: 0 auto;">
      <div style="background: {color}; padding: 16px 24px; border-radius: 8px 8px 0 0;">
        <h2 style="color: #fff; margin: 0;">[{status_label}] {rule_name}</h2>
      </div>
      <div style="background: #1e293b; padding: 24px; border-radius: 0 0 8px 8px; color: #e2e8f0;">
        <table style="width: 100%; border-collapse: collapse;">
          <tr><td style="padding: 8px 0; color: #94a3b8;">Device</td><td style="padding: 8px 0;">{device_hostname}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Severity</td><td style="padding: 8px 0;">{severity.upper()}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Metric</td><td style="padding: 8px 0;">{metric}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Value</td><td style="padding: 8px 0;">{value}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Threshold</td><td style="padding: 8px 0;">{threshold}</td></tr>
        </table>
        <p style="color: #64748b; font-size: 12px; margin-top: 24px;">
          TOD — Fleet Management for MikroTik RouterOS
        </p>
      </div>
    </div>
    """
    # Plain-text alternative for clients that don't render HTML.
    plain = (
        f"[{status_label}] {rule_name}\n\n"
        f"Device: {device_hostname}\n"
        f"Severity: {severity}\n"
        f"Metric: {metric}\n"
        f"Value: {value}\n"
        f"Threshold: {threshold}\n"
    )
    # Decrypt SMTP password (Transit first, then legacy Fernet)
    smtp_password = None
    transit_cipher = channel.get("smtp_password_transit")
    legacy_cipher = channel.get("smtp_password")
    tenant_id = channel.get("tenant_id")
    if transit_cipher and tenant_id:
        try:
            from app.services.kms_service import decrypt_transit
            smtp_password = await decrypt_transit(transit_cipher, tenant_id)
        except Exception:
            logger.warning("Transit decryption failed for channel %s, trying legacy", channel.get("id"))
    if not smtp_password and legacy_cipher:
        try:
            from app.config import settings as app_settings
            from cryptography.fernet import Fernet
            # DB drivers may hand the ciphertext back as a memoryview; Fernet wants bytes.
            raw = bytes(legacy_cipher) if isinstance(legacy_cipher, memoryview) else legacy_cipher
            f = Fernet(app_settings.CREDENTIAL_ENCRYPTION_KEY.encode())
            smtp_password = f.decrypt(raw).decode()
        except Exception:
            # Both decryption paths failed; proceed with password=None
            # (send_email's behavior in that case is defined in email_service).
            logger.warning("Legacy decryption failed for channel %s", channel.get("id"))
    config = SMTPConfig(
        host=channel.get("smtp_host", "localhost"),
        port=channel.get("smtp_port", 587),
        user=channel.get("smtp_user"),
        password=smtp_password,
        use_tls=channel.get("smtp_use_tls", False),
        from_address=channel.get("from_address") or "alerts@mikrotik-portal.local",
    )
    to = channel.get("to_address")
    subject = f"[TOD {status_label}] {rule_name}{device_hostname}"
    await send_email(to, subject, html, plain, config)
async def _send_webhook(
    channel: dict[str, Any],
    alert_event: dict[str, Any],
    device_hostname: str,
) -> None:
    """Send alert notification to a webhook URL (Slack-compatible JSON)."""
    # Bail out early when the channel is misconfigured -- nothing to deliver to.
    webhook_url = channel.get("webhook_url", "")
    if not webhook_url:
        logger.warning("Webhook channel %s has no URL configured", channel.get("name"))
        return
    severity = alert_event.get("severity", "info")
    message_text = alert_event.get("message", "")
    payload = {
        "alert_name": message_text,
        "severity": severity,
        "status": alert_event.get("status", "firing"),
        "device": device_hostname,
        "device_id": alert_event.get("device_id"),
        "metric": alert_event.get("metric"),
        "value": alert_event.get("value"),
        "threshold": alert_event.get("threshold"),
        "timestamp": str(alert_event.get("fired_at", "")),
        "text": f"[{severity.upper()}] {device_hostname}: {message_text}",
    }
    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.post(webhook_url, json=payload)
        logger.info(
            "Webhook notification sent to %s — status %d",
            webhook_url, response.status_code,
        )
async def _send_slack(
    channel: dict[str, Any],
    alert_event: dict[str, Any],
    device_hostname: str,
) -> None:
    """Send alert notification to Slack via incoming webhook with Block Kit formatting.

    Builds a header + field sections from the event, wraps them in a single
    color-coded attachment, and POSTs to the channel's slack_webhook_url.
    Raises on HTTP transport failure (dispatch_notifications catches per-channel).
    """
    severity = alert_event.get("severity", "info").upper()
    status = alert_event.get("status", "firing")
    metric = alert_event.get("metric", "unknown")
    message_text = alert_event.get("message", "")
    value = alert_event.get("value")
    threshold = alert_event.get("threshold")
    # Attachment sidebar color by severity; unknown severities get neutral gray.
    color = {"CRITICAL": "#dc2626", "WARNING": "#f59e0b", "INFO": "#3b82f6"}.get(severity, "#6b7280")
    status_label = "RESOLVED" if status == "resolved" else status
    blocks = [
        {
            "type": "header",
            "text": {"type": "plain_text", "text": f"{'' if status == 'resolved' else '🚨'} [{severity}] {status_label.upper()}"},
        },
        {
            "type": "section",
            "fields": [
                {"type": "mrkdwn", "text": f"*Device:*\n{device_hostname}"},
                {"type": "mrkdwn", "text": f"*Metric:*\n{metric}"},
            ],
        },
    ]
    # Value/threshold section only when at least one is present (Slack rejects empty fields).
    if value is not None or threshold is not None:
        fields = []
        if value is not None:
            fields.append({"type": "mrkdwn", "text": f"*Value:*\n{value}"})
        if threshold is not None:
            fields.append({"type": "mrkdwn", "text": f"*Threshold:*\n{threshold}"})
        blocks.append({"type": "section", "fields": fields})
    if message_text:
        blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": f"*Message:*\n{message_text}"}})
    blocks.append({"type": "context", "elements": [{"type": "mrkdwn", "text": "TOD Alert System"}]})
    slack_url = channel.get("slack_webhook_url", "")
    if not slack_url:
        logger.warning("Slack channel %s has no webhook URL configured", channel.get("name"))
        return
    payload = {"attachments": [{"color": color, "blocks": blocks}]}
    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.post(slack_url, json=payload)
        logger.info("Slack notification sent — status %d", response.status_code)
async def send_test_notification(channel: dict[str, Any]) -> bool:
"""Send a test notification through a channel to verify configuration.
Args:
channel: Notification channel dict with all config fields
Returns:
True on success
Raises:
Exception on delivery failure (caller handles)
"""
test_event = {
"status": "test",
"severity": "info",
"metric": "test",
"value": None,
"threshold": None,
"message": "Test notification from TOD",
"device_id": "00000000-0000-0000-0000-000000000000",
"fired_at": "",
}
if channel["channel_type"] == "email":
await _send_email(channel, test_event, "Test Device")
elif channel["channel_type"] == "webhook":
await _send_webhook(channel, test_event, "Test Device")
elif channel["channel_type"] == "slack":
await _send_slack(channel, test_event, "Test Device")
else:
raise ValueError(f"Unknown channel type: {channel['channel_type']}")
return True

View File

@@ -0,0 +1,174 @@
"""
OpenBao Transit secrets engine client for per-tenant envelope encryption.
Provides encrypt/decrypt operations via OpenBao's HTTP API. Each tenant gets
a dedicated Transit key (tenant_{uuid}) for AES-256-GCM encryption. The key
material never leaves OpenBao -- the application only sees ciphertext.
Ciphertext format: "vault:v1:base64..." (compatible with Vault Transit format)
"""
import base64
import logging
from typing import Optional
import httpx
from app.config import settings
logger = logging.getLogger(__name__)
class OpenBaoTransitService:
    """Async client for OpenBao Transit secrets engine.

    Two key families per tenant, both AES-256-GCM:
      - tenant_{uuid}        — credential encryption
      - tenant_{uuid}_data   — data encryption (Phase 30)
    Key material never leaves OpenBao; this client only exchanges
    base64 plaintext for "vault:v1:..." ciphertext over HTTP.
    """

    def __init__(self, addr: str | None = None, token: str | None = None):
        # Fall back to app settings when not explicitly injected (tests inject).
        self.addr = addr or settings.OPENBAO_ADDR
        self.token = token or settings.OPENBAO_TOKEN
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the lazily-created shared HTTP client, recreating it if closed."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                base_url=self.addr,
                headers={"X-Vault-Token": self.token},
                timeout=5.0,
            )
        return self._client

    async def close(self) -> None:
        """Close the underlying HTTP client (safe to call repeatedly)."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    async def create_tenant_key(self, tenant_id: str) -> None:
        """Create Transit encryption keys for a tenant (credential + data). Idempotent."""
        client = await self._get_client()
        # Credential key: tenant_{uuid}
        key_name = f"tenant_{tenant_id}"
        resp = await client.post(
            f"/v1/transit/keys/{key_name}",
            json={"type": "aes256-gcm96"},
        )
        # Transit returns 200/204 for both "created" and "already exists";
        # anything else is a real error.
        if resp.status_code not in (200, 204):
            resp.raise_for_status()
        logger.info("OpenBao Transit key ensured", extra={"key_name": key_name})
        # Data key: tenant_{uuid}_data (Phase 30)
        await self.create_tenant_data_key(tenant_id)

    async def encrypt(self, tenant_id: str, plaintext: bytes) -> str:
        """Encrypt plaintext via Transit engine. Returns ciphertext string."""
        client = await self._get_client()
        key_name = f"tenant_{tenant_id}"
        resp = await client.post(
            f"/v1/transit/encrypt/{key_name}",
            # Transit's API takes plaintext base64-encoded, never raw bytes.
            json={"plaintext": base64.b64encode(plaintext).decode()},
        )
        resp.raise_for_status()
        ciphertext = resp.json()["data"]["ciphertext"]
        return ciphertext  # "vault:v1:..."

    async def decrypt(self, tenant_id: str, ciphertext: str) -> bytes:
        """Decrypt Transit ciphertext ("vault:v1:..."). Returns plaintext bytes."""
        client = await self._get_client()
        key_name = f"tenant_{tenant_id}"
        resp = await client.post(
            f"/v1/transit/decrypt/{key_name}",
            json={"ciphertext": ciphertext},
        )
        resp.raise_for_status()
        plaintext_b64 = resp.json()["data"]["plaintext"]
        return base64.b64decode(plaintext_b64)

    async def key_exists(self, tenant_id: str) -> bool:
        """Check if a credential Transit key exists for a tenant."""
        client = await self._get_client()
        key_name = f"tenant_{tenant_id}"
        resp = await client.get(f"/v1/transit/keys/{key_name}")
        return resp.status_code == 200

    # ------------------------------------------------------------------
    # Data encryption keys (tenant_{uuid}_data) — Phase 30
    # ------------------------------------------------------------------
    async def create_tenant_data_key(self, tenant_id: str) -> None:
        """Create a Transit data encryption key for a tenant. Idempotent.

        Data keys use the suffix '_data' to separate them from credential keys.
        Key naming: tenant_{uuid}_data (vs tenant_{uuid} for credentials).
        """
        client = await self._get_client()
        key_name = f"tenant_{tenant_id}_data"
        resp = await client.post(
            f"/v1/transit/keys/{key_name}",
            json={"type": "aes256-gcm96"},
        )
        if resp.status_code not in (200, 204):
            resp.raise_for_status()
        logger.info("OpenBao Transit data key ensured", extra={"key_name": key_name})

    async def ensure_tenant_data_key(self, tenant_id: str) -> None:
        """Ensure a data encryption key exists for a tenant. Idempotent.

        Checks existence first and creates if missing. Safe to call on every
        encrypt operation (fast path: single GET to check existence).
        """
        client = await self._get_client()
        key_name = f"tenant_{tenant_id}_data"
        resp = await client.get(f"/v1/transit/keys/{key_name}")
        if resp.status_code != 200:
            await self.create_tenant_data_key(tenant_id)

    async def encrypt_data(self, tenant_id: str, plaintext: bytes) -> str:
        """Encrypt data via Transit using per-tenant data key.

        Uses the tenant_{uuid}_data key (separate from credential key).

        Args:
            tenant_id: Tenant UUID string.
            plaintext: Raw bytes to encrypt.
        Returns:
            Transit ciphertext string (vault:v1:...).
        """
        client = await self._get_client()
        key_name = f"tenant_{tenant_id}_data"
        resp = await client.post(
            f"/v1/transit/encrypt/{key_name}",
            json={"plaintext": base64.b64encode(plaintext).decode()},
        )
        resp.raise_for_status()
        return resp.json()["data"]["ciphertext"]

    async def decrypt_data(self, tenant_id: str, ciphertext: str) -> bytes:
        """Decrypt Transit data ciphertext using per-tenant data key.

        Args:
            tenant_id: Tenant UUID string.
            ciphertext: Transit ciphertext (vault:v1:...).
        Returns:
            Decrypted plaintext bytes.
        """
        client = await self._get_client()
        key_name = f"tenant_{tenant_id}_data"
        resp = await client.post(
            f"/v1/transit/decrypt/{key_name}",
            json={"ciphertext": ciphertext},
        )
        resp.raise_for_status()
        plaintext_b64 = resp.json()["data"]["plaintext"]
        return base64.b64decode(plaintext_b64)
# Module-level singleton (lazily constructed; shares one pooled HTTP client)
_openbao_service: Optional[OpenBaoTransitService] = None


def get_openbao_service() -> OpenBaoTransitService:
    """Return module-level OpenBao Transit service singleton.

    Built on first call with settings-derived addr/token; subsequent calls
    reuse the same instance.
    """
    global _openbao_service
    if _openbao_service is None:
        _openbao_service = OpenBaoTransitService()
    return _openbao_service

View File

@@ -0,0 +1,141 @@
"""NATS subscribers for push rollback (auto) and push alert (manual).
- config.push.rollback.> -> auto-restore for template pushes
- config.push.alert.> -> create alert for editor pushes
"""
import json
import logging
from typing import Any, Optional
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.models.alert import AlertEvent
from app.services import restore_service
logger = logging.getLogger(__name__)
_nc: Optional[Any] = None
async def _create_push_alert(device_id: str, tenant_id: str, push_type: str) -> None:
    """Create a critical, firing AlertEvent for a device that went offline after a config push.

    Uses the admin session: this runs from a system-level NATS consumer that
    must write across tenants (bypasses RLS). push_type is interpolated into
    the alert message (e.g. "editor", "template (auto-rollback failed)").
    """
    async with AdminAsyncSessionLocal() as session:
        alert = AlertEvent(
            device_id=device_id,
            tenant_id=tenant_id,
            status="firing",
            severity="critical",
            message=f"Device went offline after config {push_type} — rollback available",
        )
        session.add(alert)
        await session.commit()
    logger.info("Created push alert for device %s (type=%s)", device_id, push_type)
async def handle_push_rollback(event: dict) -> None:
    """Auto-rollback: restore device to its pre-push config.

    Expects event keys: device_id, tenant_id, pre_push_commit_sha. Events
    missing any of them are logged and dropped. If the restore itself fails,
    falls back to raising a critical alert so an operator can intervene.
    """
    device_id = event.get("device_id")
    tenant_id = event.get("tenant_id")
    commit_sha = event.get("pre_push_commit_sha")
    if not all([device_id, tenant_id, commit_sha]):
        logger.warning("Push rollback event missing fields: %s", event)
        return
    logger.warning(
        "AUTO-ROLLBACK: Device %s offline after template push, restoring to %s",
        device_id,
        commit_sha,
    )
    try:
        async with AdminAsyncSessionLocal() as session:
            result = await restore_service.restore_config(
                device_id=device_id,
                tenant_id=tenant_id,
                commit_sha=commit_sha,
                db_session=session,
            )
            await session.commit()
        logger.info(
            "Auto-rollback result for device %s: %s",
            device_id,
            result.get("status"),
        )
    except Exception as e:
        logger.error("Auto-rollback failed for device %s: %s", device_id, e)
        # Rollback could not be applied — escalate to a human via alert.
        await _create_push_alert(device_id, tenant_id, "template (auto-rollback failed)")
async def handle_push_alert(event: dict) -> None:
    """Alert: create notification for device offline after editor push."""
    device_id = event.get("device_id")
    tenant_id = event.get("tenant_id")
    if not (device_id and tenant_id):
        logger.warning("Push alert event missing fields: %s", event)
        return
    await _create_push_alert(device_id, tenant_id, event.get("push_type", "editor"))
async def _on_rollback_message(msg) -> None:
    """NATS message handler for config.push.rollback.> subjects.

    Decodes the JSON payload and delegates to handle_push_rollback.
    ACKs on success; NAKs on failure so JetStream redelivers. The NAK is
    guarded so a dead connection can't raise out of the callback (matches
    the on_device_status handler's pattern).
    """
    try:
        event = json.loads(msg.data.decode())
        await handle_push_rollback(event)
        await msg.ack()
    except Exception as e:
        logger.error("Error handling rollback message: %s", e)
        try:
            await msg.nak()
        except Exception:
            pass  # If NAK also fails, NATS will redeliver after ack_wait
async def _on_alert_message(msg) -> None:
    """NATS message handler for config.push.alert.> subjects.

    Decodes the JSON payload and delegates to handle_push_alert.
    ACKs on success; NAKs on failure so JetStream redelivers. The NAK is
    guarded so a dead connection can't raise out of the callback (matches
    the on_device_status handler's pattern).
    """
    try:
        event = json.loads(msg.data.decode())
        await handle_push_alert(event)
        await msg.ack()
    except Exception as e:
        logger.error("Error handling push alert message: %s", e)
        try:
            await msg.nak()
        except Exception:
            pass  # If NAK also fails, NATS will redeliver after ack_wait
async def start_push_rollback_subscriber() -> Optional[Any]:
    """Connect to NATS and subscribe to push rollback/alert events.

    Creates two durable JetStream consumers on the DEVICE_EVENTS stream:
      - config.push.rollback.> -> auto-restore for template pushes
      - config.push.alert.>    -> create alert for editor pushes

    Returns the NATS connection, or None on failure (best-effort startup:
    the API keeps running without push-rollback handling).
    """
    # Imported here so merely importing this module doesn't require nats.
    import nats
    global _nc
    try:
        logger.info("NATS push-rollback: connecting to %s", settings.NATS_URL)
        _nc = await nats.connect(settings.NATS_URL)
        js = _nc.jetstream()
        await js.subscribe(
            "config.push.rollback.>",
            cb=_on_rollback_message,
            durable="api-push-rollback-consumer",
            stream="DEVICE_EVENTS",
            manual_ack=True,  # handlers ACK/NAK explicitly after processing
        )
        await js.subscribe(
            "config.push.alert.>",
            cb=_on_alert_message,
            durable="api-push-alert-consumer",
            stream="DEVICE_EVENTS",
            manual_ack=True,
        )
        logger.info("Push rollback/alert subscriber started")
        return _nc
    except Exception as e:
        logger.error("Failed to start push rollback subscriber: %s", e)
        return None
async def stop_push_rollback_subscriber() -> None:
    """Gracefully close the push rollback/alert NATS connection.

    Drains in-flight messages; if draining fails, logs and falls back to a
    hard close. The module-level reference is always cleared (previously a
    drain error left a broken connection cached), matching the behavior of
    the sibling stop_nats_subscriber/stop_metrics_subscriber helpers.
    """
    global _nc
    if _nc is None:
        return
    try:
        await _nc.drain()
    except Exception as exc:
        logger.warning("Push rollback subscriber: error during drain: %s", exc)
        try:
            await _nc.close()
        except Exception:
            pass
    finally:
        # Always drop the reference so a restart builds a fresh connection.
        _nc = None

View File

@@ -0,0 +1,70 @@
"""Track recent config pushes in Redis for poller-aware rollback.
When a device goes offline shortly after a push, the poller checks these
keys and triggers rollback (template/restore) or alert (editor).
Redis key format: push:recent:{device_id}
TTL: 300 seconds (5 minutes)
"""
import json
import logging
from typing import Optional
import redis.asyncio as redis
from app.config import settings
logger = logging.getLogger(__name__)
PUSH_TTL_SECONDS = 300 # 5 minutes
_redis: Optional[redis.Redis] = None
async def _get_redis() -> redis.Redis:
    """Return the lazily-created module-level Redis client (built from settings.REDIS_URL)."""
    global _redis
    if _redis is None:
        _redis = redis.from_url(settings.REDIS_URL)
    return _redis
async def record_push(
    device_id: str,
    tenant_id: str,
    push_type: str,
    push_operation_id: str = "",
    pre_push_commit_sha: str = "",
) -> None:
    """Record a recent config push in Redis.

    Writes a JSON blob under push:recent:{device_id} with a 5-minute TTL so
    the poller can correlate a subsequent offline transition with this push
    and decide between rollback (template/restore) and alert (editor).

    Args:
        device_id: UUID of the device.
        tenant_id: UUID of the tenant.
        push_type: 'template' (auto-rollback) or 'editor' (alert only) or 'restore'.
        push_operation_id: ID of the ConfigPushOperation row.
        pre_push_commit_sha: Git SHA of the pre-push backup (for rollback).
    """
    r = await _get_redis()
    key = f"push:recent:{device_id}"
    value = json.dumps({
        "device_id": device_id,
        "tenant_id": tenant_id,
        "push_type": push_type,
        "push_operation_id": push_operation_id,
        "pre_push_commit_sha": pre_push_commit_sha,
    })
    # ex= gives the key an expiry so stale pushes age out automatically.
    await r.set(key, value, ex=PUSH_TTL_SECONDS)
    logger.debug(
        "Recorded push for device %s (type=%s, TTL=%ds)",
        device_id,
        push_type,
        PUSH_TTL_SECONDS,
    )
async def clear_push(device_id: str) -> None:
    """Clear the push tracking key (e.g., after successful commit).

    Deleting a missing key is a no-op, so this is safe to call unconditionally.
    """
    r = await _get_redis()
    await r.delete(f"push:recent:{device_id}")
    logger.debug("Cleared push tracking for device %s", device_id)

View File

@@ -0,0 +1,572 @@
"""Report generation service.
Generates PDF (via Jinja2 + weasyprint) and CSV reports for:
- Device inventory
- Metrics summary
- Alert history
- Change log (audit_logs if available, else config_backups fallback)
Phase 30 NOTE: Reports are currently ephemeral (generated on-demand per request,
never stored at rest). DATAENC-03 requires "report content is encrypted before
storage." Since no report storage exists yet, encryption will be applied when
report caching/storage is added. The generation pipeline is Transit-ready --
wrap the file_bytes with encrypt_data_transit() before any future INSERT.
"""
import csv
import io
import os
import time
from datetime import datetime
from typing import Any, Optional
from uuid import UUID
import structlog
from jinja2 import Environment, FileSystemLoader
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
logger = structlog.get_logger(__name__)
# Jinja2 environment pointing at the templates directory
_TEMPLATE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "templates")
_jinja_env = Environment(
loader=FileSystemLoader(_TEMPLATE_DIR),
autoescape=True,
)
async def generate_report(
db: AsyncSession,
tenant_id: UUID,
report_type: str,
date_from: Optional[datetime],
date_to: Optional[datetime],
fmt: str = "pdf",
) -> tuple[bytes, str, str]:
"""Generate a report and return (file_bytes, content_type, filename).
Args:
db: RLS-enforced async session (tenant context already set).
tenant_id: Tenant UUID for scoping.
report_type: One of device_inventory, metrics_summary, alert_history, change_log.
date_from: Start date for time-ranged reports.
date_to: End date for time-ranged reports.
fmt: Output format -- "pdf" or "csv".
Returns:
Tuple of (file_bytes, content_type, filename).
"""
start = time.monotonic()
# Fetch tenant name for the header
tenant_name = await _get_tenant_name(db, tenant_id)
# Dispatch to the appropriate handler
handlers = {
"device_inventory": _device_inventory,
"metrics_summary": _metrics_summary,
"alert_history": _alert_history,
"change_log": _change_log,
}
handler = handlers[report_type]
template_data = await handler(db, tenant_id, date_from, date_to)
# Common template context
generated_at = datetime.utcnow().strftime("%Y-%m-%d %H:%M UTC")
base_context = {
"tenant_name": tenant_name,
"generated_at": generated_at,
}
timestamp_str = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
if fmt == "csv":
file_bytes = _render_csv(report_type, template_data)
content_type = "text/csv; charset=utf-8"
filename = f"{report_type}_{timestamp_str}.csv"
else:
file_bytes = _render_pdf(report_type, {**base_context, **template_data})
content_type = "application/pdf"
filename = f"{report_type}_{timestamp_str}.pdf"
elapsed = time.monotonic() - start
logger.info(
"report_generated",
report_type=report_type,
format=fmt,
tenant_id=str(tenant_id),
size_bytes=len(file_bytes),
elapsed_seconds=round(elapsed, 2),
)
return file_bytes, content_type, filename
# ---------------------------------------------------------------------------
# Tenant name helper
# ---------------------------------------------------------------------------
async def _get_tenant_name(db: AsyncSession, tenant_id: UUID) -> str:
    """Fetch the tenant name by ID, falling back to "Unknown Tenant" when no row matches."""
    result = await db.execute(
        text("SELECT name FROM tenants WHERE id = CAST(:tid AS uuid)"),
        {"tid": str(tenant_id)},
    )
    row = result.fetchone()
    return row[0] if row else "Unknown Tenant"
# ---------------------------------------------------------------------------
# Report type handlers
# ---------------------------------------------------------------------------
async def _device_inventory(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Gather device inventory data (all devices, with group names and status tallies).

    Note: the query has no explicit tenant filter — per generate_report's
    contract, `db` is an RLS-enforced session with tenant context already set.
    date_from/date_to are accepted for signature parity but unused here.
    """
    result = await db.execute(
        text("""
            SELECT d.hostname, d.ip_address, d.model, d.routeros_version,
                   d.status, d.last_seen, d.uptime_seconds,
                   COALESCE(
                       (SELECT string_agg(dg.name, ', ')
                        FROM device_group_memberships dgm
                        JOIN device_groups dg ON dg.id = dgm.group_id
                        WHERE dgm.device_id = d.id),
                       ''
                   ) AS groups
            FROM devices d
            ORDER BY d.hostname ASC
        """)
    )
    rows = result.fetchall()
    devices = []
    online_count = 0
    offline_count = 0
    unknown_count = 0
    for row in rows:
        # Column order: hostname(0), ip(1), model(2), version(3), status(4),
        # last_seen(5), uptime_seconds(6), groups(7).
        status = row[4]
        if status == "online":
            online_count += 1
        elif status == "offline":
            offline_count += 1
        else:
            unknown_count += 1
        uptime_str = _format_uptime(row[6]) if row[6] else None
        last_seen_str = row[5].strftime("%Y-%m-%d %H:%M") if row[5] else None
        devices.append({
            "hostname": row[0],
            "ip_address": row[1],
            "model": row[2],
            "routeros_version": row[3],
            "status": status,
            "last_seen": last_seen_str,
            "uptime": uptime_str,
            # Empty-string group list is normalized to None for the template.
            "groups": row[7] if row[7] else None,
        })
    return {
        "report_title": "Device Inventory",
        "devices": devices,
        "total_devices": len(devices),
        "online_count": online_count,
        "offline_count": offline_count,
        "unknown_count": unknown_count,
    }
async def _metrics_summary(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Gather metrics summary data grouped by device, ordered by busiest CPU first.

    Memory/disk utilization is computed in SQL as percent-used, guarded
    against division by zero via CASE WHEN total > 0. Tenant scoping relies
    on the RLS-enforced session (see generate_report).
    """
    result = await db.execute(
        text("""
            SELECT d.hostname,
                   AVG(hm.cpu_load) AS avg_cpu,
                   MAX(hm.cpu_load) AS peak_cpu,
                   AVG(CASE WHEN hm.total_memory > 0
                       THEN 100.0 * (hm.total_memory - hm.free_memory) / hm.total_memory
                       END) AS avg_mem,
                   MAX(CASE WHEN hm.total_memory > 0
                       THEN 100.0 * (hm.total_memory - hm.free_memory) / hm.total_memory
                       END) AS peak_mem,
                   AVG(CASE WHEN hm.total_disk > 0
                       THEN 100.0 * (hm.total_disk - hm.free_disk) / hm.total_disk
                       END) AS avg_disk,
                   AVG(hm.temperature) AS avg_temp,
                   COUNT(*) AS data_points
            FROM health_metrics hm
            JOIN devices d ON d.id = hm.device_id
            WHERE hm.time >= :date_from
              AND hm.time <= :date_to
            GROUP BY d.id, d.hostname
            ORDER BY avg_cpu DESC NULLS LAST
        """),
        {
            "date_from": date_from,
            "date_to": date_to,
        },
    )
    rows = result.fetchall()
    devices = []
    for row in rows:
        # SQL numerics may come back as Decimal; coerce to float for templates.
        devices.append({
            "hostname": row[0],
            "avg_cpu": float(row[1]) if row[1] is not None else None,
            "peak_cpu": float(row[2]) if row[2] is not None else None,
            "avg_mem": float(row[3]) if row[3] is not None else None,
            "peak_mem": float(row[4]) if row[4] is not None else None,
            "avg_disk": float(row[5]) if row[5] is not None else None,
            "avg_temp": float(row[6]) if row[6] is not None else None,
            "data_points": row[7],
        })
    return {
        "report_title": "Metrics Summary",
        "devices": devices,
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
async def _alert_history(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Gather alert history data for the date range, plus severity tallies and MTTR.

    MTTR (mean time to resolve) is averaged over alerts that have both
    fired_at and resolved_at; unresolved alerts contribute no duration.
    Tenant scoping relies on the RLS-enforced session (see generate_report).
    """
    result = await db.execute(
        text("""
            SELECT ae.fired_at, ae.resolved_at, ae.severity, ae.status,
                   ae.message, d.hostname,
                   EXTRACT(EPOCH FROM (ae.resolved_at - ae.fired_at)) AS duration_secs
            FROM alert_events ae
            LEFT JOIN devices d ON d.id = ae.device_id
            WHERE ae.fired_at >= :date_from
              AND ae.fired_at <= :date_to
            ORDER BY ae.fired_at DESC
        """),
        {
            "date_from": date_from,
            "date_to": date_to,
        },
    )
    rows = result.fetchall()
    alerts = []
    critical_count = 0
    warning_count = 0
    info_count = 0
    resolved_durations: list[float] = []
    for row in rows:
        # Column order: fired_at(0), resolved_at(1), severity(2), status(3),
        # message(4), hostname(5), duration_secs(6).
        severity = row[2]
        if severity == "critical":
            critical_count += 1
        elif severity == "warning":
            warning_count += 1
        else:
            # Anything that is neither critical nor warning is tallied as info.
            info_count += 1
        duration_secs = float(row[6]) if row[6] is not None else None
        if duration_secs is not None:
            resolved_durations.append(duration_secs)
        alerts.append({
            "fired_at": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
            "hostname": row[5],
            "severity": severity,
            "status": row[3],
            "message": row[4],
            "duration": _format_duration(duration_secs) if duration_secs is not None else None,
        })
    mttr_minutes = None
    mttr_display = None
    if resolved_durations:
        avg_secs = sum(resolved_durations) / len(resolved_durations)
        mttr_minutes = round(avg_secs / 60, 1)
        mttr_display = _format_duration(avg_secs)
    return {
        "report_title": "Alert History",
        "alerts": alerts,
        "total_alerts": len(alerts),
        "critical_count": critical_count,
        "warning_count": warning_count,
        "info_count": info_count,
        "mttr_minutes": mttr_minutes,
        "mttr_display": mttr_display,
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
async def _change_log(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Gather change log data -- try audit_logs table first, fall back to config_backups."""
    # audit_logs may not exist yet (migration 17-01 may not have run);
    # choose the data source at runtime.
    if await _table_exists(db, "audit_logs"):
        return await _change_log_from_audit(db, date_from, date_to)
    return await _change_log_from_backups(db, date_from, date_to)
async def _table_exists(db: AsyncSession, table_name: str) -> bool:
    """Return True if *table_name* exists in the public schema."""
    exists_query = text("""
        SELECT EXISTS (
            SELECT 1 FROM information_schema.tables
            WHERE table_schema = 'public' AND table_name = :table_name
        )
    """)
    result = await db.execute(exists_query, {"table_name": table_name})
    return bool(result.scalar())
async def _change_log_from_audit(
    db: AsyncSession,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Build change log from audit_logs table.

    Args:
        db: Session with tenant RLS context already applied.
        date_from: Inclusive lower bound for al.created_at.
        date_to: Inclusive upper bound for al.created_at.

    Returns:
        Template context dict with entries, total count, data source label,
        and formatted date-range strings.
    """
    result = await db.execute(
        text("""
            SELECT al.created_at, u.name AS user_name, al.action,
                   d.hostname, al.resource_type,
                   al.details
            FROM audit_logs al
            LEFT JOIN users u ON u.id = al.user_id
            LEFT JOIN devices d ON d.id = al.device_id
            WHERE al.created_at >= :date_from
            AND al.created_at <= :date_to
            ORDER BY al.created_at DESC
        """),
        {
            "date_from": date_from,
            "date_to": date_to,
        },
    )
    rows = result.fetchall()
    entries = []
    for row in rows:
        entries.append({
            "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
            "user": row[1],
            "action": row[2],
            "device": row[3],
            # Bug fix: column 5 is al.details and column 4 is al.resource_type.
            # The previous `row[4] or row[5]` preferred resource_type, so the
            # actual details payload was effectively never shown.
            "details": row[5] or row[4] or "",
        })
    return {
        "report_title": "Change Log",
        "entries": entries,
        "total_entries": len(entries),
        "data_source": "Audit Logs",
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
async def _change_log_from_backups(
    db: AsyncSession,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Build change log from config_backups + alert_events as fallback."""
    # Treat config backups as change events.
    backup_result = await db.execute(
        text("""
            SELECT cb.created_at, 'system' AS user_name, 'config_backup' AS action,
                   d.hostname, cb.trigger_type AS details
            FROM config_backups cb
            JOIN devices d ON d.id = cb.device_id
            WHERE cb.created_at >= :date_from
            AND cb.created_at <= :date_to
        """),
        {
            "date_from": date_from,
            "date_to": date_to,
        },
    )
    # Treat alert events as change events.
    alert_result = await db.execute(
        text("""
            SELECT ae.fired_at, 'system' AS user_name,
                   ae.severity || '_alert' AS action,
                   d.hostname, ae.message AS details
            FROM alert_events ae
            LEFT JOIN devices d ON d.id = ae.device_id
            WHERE ae.fired_at >= :date_from
            AND ae.fired_at <= :date_to
        """),
        {
            "date_from": date_from,
            "date_to": date_to,
        },
    )

    # Both result sets share the same column layout, so one shaping pass
    # handles them (backups first, then alerts — preserved by the stable sort).
    entries = [
        {
            "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
            "user": row[1],
            "action": row[2],
            "device": row[3],
            "details": row[4] or "",
        }
        for row in (*backup_result.fetchall(), *alert_result.fetchall())
    ]
    # "%Y-%m-%d %H:%M" strings sort chronologically, so a string sort works.
    entries.sort(key=lambda e: e["timestamp"], reverse=True)

    return {
        "report_title": "Change Log",
        "entries": entries,
        "total_entries": len(entries),
        "data_source": "Backups + Alerts",
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
# ---------------------------------------------------------------------------
# Rendering helpers
# ---------------------------------------------------------------------------
def _render_pdf(report_type: str, context: dict[str, Any]) -> bytes:
    """Render HTML template and convert to PDF via weasyprint."""
    # Heavy import deferred until a PDF is actually requested.
    import weasyprint

    html_document = _jinja_env.get_template(f"reports/{report_type}.html").render(**context)
    return weasyprint.HTML(string=html_document).write_pdf()
def _render_csv(report_type: str, data: dict[str, Any]) -> bytes:
"""Render report data as CSV bytes."""
output = io.StringIO()
writer = csv.writer(output)
if report_type == "device_inventory":
writer.writerow([
"Hostname", "IP Address", "Model", "RouterOS Version",
"Status", "Last Seen", "Uptime", "Groups",
])
for d in data.get("devices", []):
writer.writerow([
d["hostname"], d["ip_address"], d["model"] or "",
d["routeros_version"] or "", d["status"],
d["last_seen"] or "", d["uptime"] or "",
d["groups"] or "",
])
elif report_type == "metrics_summary":
writer.writerow([
"Hostname", "Avg CPU %", "Peak CPU %", "Avg Memory %",
"Peak Memory %", "Avg Disk %", "Avg Temp", "Data Points",
])
for d in data.get("devices", []):
writer.writerow([
d["hostname"],
f"{d['avg_cpu']:.1f}" if d["avg_cpu"] is not None else "",
f"{d['peak_cpu']:.1f}" if d["peak_cpu"] is not None else "",
f"{d['avg_mem']:.1f}" if d["avg_mem"] is not None else "",
f"{d['peak_mem']:.1f}" if d["peak_mem"] is not None else "",
f"{d['avg_disk']:.1f}" if d["avg_disk"] is not None else "",
f"{d['avg_temp']:.1f}" if d["avg_temp"] is not None else "",
d["data_points"],
])
elif report_type == "alert_history":
writer.writerow([
"Timestamp", "Device", "Severity", "Message", "Status", "Duration",
])
for a in data.get("alerts", []):
writer.writerow([
a["fired_at"], a["hostname"] or "", a["severity"],
a["message"] or "", a["status"], a["duration"] or "",
])
elif report_type == "change_log":
writer.writerow([
"Timestamp", "User", "Action", "Device", "Details",
])
for e in data.get("entries", []):
writer.writerow([
e["timestamp"], e["user"] or "", e["action"],
e["device"] or "", e["details"] or "",
])
return output.getvalue().encode("utf-8")
# ---------------------------------------------------------------------------
# Formatting utilities
# ---------------------------------------------------------------------------
def _format_uptime(seconds: int) -> str:
"""Format uptime seconds as human-readable string."""
days = seconds // 86400
hours = (seconds % 86400) // 3600
minutes = (seconds % 3600) // 60
if days > 0:
return f"{days}d {hours}h {minutes}m"
elif hours > 0:
return f"{hours}h {minutes}m"
else:
return f"{minutes}m"
def _format_duration(seconds: float) -> str:
"""Format a duration in seconds as a human-readable string."""
if seconds < 60:
return f"{int(seconds)}s"
elif seconds < 3600:
return f"{int(seconds // 60)}m {int(seconds % 60)}s"
elif seconds < 86400:
hours = int(seconds // 3600)
mins = int((seconds % 3600) // 60)
return f"{hours}h {mins}m"
else:
days = int(seconds // 86400)
hours = int((seconds % 86400) // 3600)
return f"{days}d {hours}h"

View File

@@ -0,0 +1,599 @@
"""Two-phase config push with panic-revert safety for RouterOS devices.
This module implements the critical safety mechanism for config restoration:
Phase 1 — Push:
1. Pre-backup (mandatory) — snapshot current config before any changes
2. Install panic-revert RouterOS scheduler — auto-reverts if device becomes
unreachable (the scheduler fires after 90s and loads the pre-push backup)
3. Push the target config via SSH /import
Phase 2 — Verification (60s settle window):
4. Wait 60s for config to settle (scheduled processes restart, etc.)
5. Reachability check via asyncssh
6a. Reachable — remove panic-revert scheduler; mark operation committed
6b. Unreachable — RouterOS is auto-reverting; mark operation reverted
Pitfall 6 handling:
If the API pod restarts during the 60s window, the config_push_operations
row with status='pending_verification' serves as the recovery signal.
On startup, recover_stale_push_operations() resolves any stale rows.
Security policy:
known_hosts=None — RouterOS self-signed host keys; mirrors InsecureSkipVerify
used in the poller's TLS connection. See Pitfall 2 in 04-RESEARCH.md.
"""
import asyncio
import json
import logging
from datetime import datetime, timedelta, timezone
import asyncssh
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import set_tenant_context, AdminAsyncSessionLocal
from app.models.config_backup import ConfigPushOperation
from app.models.device import Device
from app.services import backup_service, git_store
from app.services.event_publisher import publish_event
from app.services.push_tracker import record_push, clear_push
logger = logging.getLogger(__name__)
# Name of the panic-revert scheduler installed on the RouterOS device
_PANIC_REVERT_SCHEDULER = "mikrotik-portal-panic-revert"
# Name of the pre-push binary backup saved on device flash
_PRE_PUSH_BACKUP = "portal-pre-push"
# Name of the RSC file used for /import on device
_RESTORE_RSC = "portal-restore.rsc"
async def _publish_push_progress(
    tenant_id: str,
    device_id: str,
    stage: str,
    message: str,
    push_op_id: str | None = None,
    error: str | None = None,
) -> None:
    """Publish config push progress event to NATS (fire-and-forget)."""
    event = {
        "event_type": "config_push",
        "tenant_id": tenant_id,
        "device_id": device_id,
        "stage": stage,
        "message": message,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "push_operation_id": push_op_id,
    }
    # Only attach the error field when there is a non-empty error string.
    if error:
        event["error"] = error
    await publish_event(f"config.push.{tenant_id}.{device_id}", event)
async def restore_config(
    device_id: str,
    tenant_id: str,
    commit_sha: str,
    db_session: AsyncSession,
) -> dict:
    """Restore a device config to a specific backup version via two-phase push.

    Args:
        device_id: Device UUID as string.
        tenant_id: Tenant UUID as string.
        commit_sha: Git commit SHA of the backup version to restore.
        db_session: AsyncSession with RLS context already set (from API endpoint).

    Returns:
        {
            "status": "committed" | "reverted" | "failed",
            "message": str,
            "pre_backup_sha": str,
        }

    Raises:
        ValueError: If device not found, missing credentials, or the backup
            version cannot be read.
        Exception: On SSH failure during push phase (reverted status logged).
    """
    # get_running_loop() is the supported call inside a coroutine;
    # get_event_loop() is deprecated for this use since Python 3.10.
    loop = asyncio.get_running_loop()

    # ------------------------------------------------------------------
    # Step 1: Load device from DB and decrypt credentials
    # ------------------------------------------------------------------
    from sqlalchemy import select
    result = await db_session.execute(
        select(Device).where(Device.id == device_id)  # type: ignore[arg-type]
    )
    device = result.scalar_one_or_none()
    if device is None:
        raise ValueError(f"Device {device_id!r} not found")
    if not device.encrypted_credentials_transit and not device.encrypted_credentials:
        raise ValueError(
            f"Device {device_id!r} has no stored credentials — cannot perform restore"
        )
    key = settings.get_encryption_key_bytes()
    from app.services.crypto import decrypt_credentials_hybrid
    creds_json = await decrypt_credentials_hybrid(
        device.encrypted_credentials_transit,
        device.encrypted_credentials,
        str(device.tenant_id),
        key,
    )
    creds = json.loads(creds_json)
    ssh_username = creds.get("username", "")
    ssh_password = creds.get("password", "")
    ip = device.ip_address
    hostname = device.hostname or ip

    # Publish "started" progress event
    await _publish_push_progress(tenant_id, device_id, "started", f"Config restore started for {hostname}")

    # ------------------------------------------------------------------
    # Step 2: Read the target export.rsc from the backup commit
    # ------------------------------------------------------------------
    try:
        export_bytes = await loop.run_in_executor(
            None,
            git_store.read_file,
            tenant_id,
            commit_sha,
            device_id,
            "export.rsc",
        )
    except Exception as exc:
        # `Exception` already covers KeyError — the previous
        # `except (KeyError, Exception)` tuple was redundant.
        raise ValueError(
            f"Backup version {commit_sha!r} not found for device {device_id!r}: {exc}"
        ) from exc
    export_text = export_bytes.decode("utf-8", errors="replace")

    # ------------------------------------------------------------------
    # Step 3: Mandatory pre-backup before push
    # ------------------------------------------------------------------
    await _publish_push_progress(tenant_id, device_id, "backing_up", f"Creating pre-restore backup for {hostname}")
    logger.info(
        "Starting pre-restore backup for device %s (%s) before pushing commit %s",
        hostname,
        ip,
        commit_sha[:8],
    )
    pre_backup_result = await backup_service.run_backup(
        device_id=device_id,
        tenant_id=tenant_id,
        trigger_type="pre-restore",
        db_session=db_session,
    )
    pre_backup_sha = pre_backup_result["commit_sha"]
    logger.info("Pre-restore backup complete: %s", pre_backup_sha[:8])

    # ------------------------------------------------------------------
    # Step 4: Record push operation (pending_verification for recovery)
    # ------------------------------------------------------------------
    push_op = ConfigPushOperation(
        device_id=device.id,
        tenant_id=device.tenant_id,
        pre_push_commit_sha=pre_backup_sha,
        scheduler_name=_PANIC_REVERT_SCHEDULER,
        status="pending_verification",
    )
    db_session.add(push_op)
    await db_session.flush()
    push_op_id = push_op.id
    logger.info(
        "Push op %s in pending_verification — if API restarts, "
        "recover_stale_push_operations() will resolve on next startup",
        push_op.id,
    )

    # ------------------------------------------------------------------
    # Step 5: SSH to device — install panic-revert, push config
    # ------------------------------------------------------------------
    push_op_id_str = str(push_op_id)
    await _publish_push_progress(tenant_id, device_id, "pushing", f"Pushing config to {hostname}", push_op_id=push_op_id_str)
    logger.info(
        "Pushing config to device %s (%s): installing panic-revert scheduler and uploading config",
        hostname,
        ip,
    )
    try:
        async with asyncssh.connect(
            ip,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,  # RouterOS self-signed host keys — see module docstring
            connect_timeout=30,
        ) as conn:
            # 5a: Create binary backup on device as revert point
            await conn.run(
                f"/system backup save name={_PRE_PUSH_BACKUP} dont-encrypt=yes",
                check=True,
            )
            logger.debug("Pre-push binary backup saved on device as %s.backup", _PRE_PUSH_BACKUP)

            # 5b: Install panic-revert RouterOS scheduler
            # The scheduler fires after 90s on startup and loads the pre-push backup.
            # This is the safety net: if the device becomes unreachable after push,
            # RouterOS will auto-revert to the known-good config on the next reboot
            # or after 90s of uptime.
            await conn.run(
                f"/system scheduler add "
                f'name="{_PANIC_REVERT_SCHEDULER}" '
                f"interval=90s "
                f'on-event=":delay 0; /system backup load name={_PRE_PUSH_BACKUP}" '
                f"start-time=startup",
                check=True,
            )
            logger.debug("Panic-revert scheduler installed on device")

            # 5c: Upload export.rsc via SFTP and /import it.
            # SFTP is the reliable way to place a config file on RouterOS 7;
            # line-by-line /file writes are RouterOS 6-only and fragile.
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(_RESTORE_RSC, "wb") as f:
                    await f.write(export_text.encode("utf-8"))
            logger.debug("Uploaded %s to device flash", _RESTORE_RSC)

            # /import the config file
            import_result = await conn.run(
                f"/import file={_RESTORE_RSC}",
                check=False,  # Don't raise on non-zero exit — import may succeed with warnings
            )
            logger.info(
                "Config import result for device %s: exit_status=%s stdout=%r",
                hostname,
                import_result.exit_status,
                (import_result.stdout or "")[:200],
            )

            # Clean up the uploaded RSC file (best-effort)
            try:
                await conn.run(f"/file remove {_RESTORE_RSC}", check=True)
            except Exception as cleanup_err:
                logger.warning(
                    "Failed to clean up %s from device %s: %s",
                    _RESTORE_RSC,
                    ip,
                    cleanup_err,
                )
    except Exception as push_err:
        logger.error(
            "SSH push phase failed for device %s (%s): %s",
            hostname,
            ip,
            push_err,
        )
        # Update push operation to failed
        await _update_push_op_status(push_op_id, "failed", db_session)
        await _publish_push_progress(
            tenant_id, device_id, "failed",
            f"Config push failed for {hostname}: {push_err}",
            push_op_id=push_op_id_str, error=str(push_err),
        )
        return {
            "status": "failed",
            "message": f"Config push failed during SSH phase: {push_err}",
            "pre_backup_sha": pre_backup_sha,
        }

    # Record push in Redis so the poller can detect post-push offline events
    await record_push(
        device_id=device_id,
        tenant_id=tenant_id,
        push_type="restore",
        push_operation_id=push_op_id_str,
        pre_push_commit_sha=pre_backup_sha,
    )

    # ------------------------------------------------------------------
    # Step 6: Wait 60s for config to settle
    # ------------------------------------------------------------------
    await _publish_push_progress(tenant_id, device_id, "settling", f"Config pushed to {hostname} — waiting 60s for settle", push_op_id=push_op_id_str)
    logger.info(
        "Config pushed to device %s — waiting 60s for config to settle",
        hostname,
    )
    await asyncio.sleep(60)

    # ------------------------------------------------------------------
    # Step 7: Reachability check
    # ------------------------------------------------------------------
    await _publish_push_progress(tenant_id, device_id, "verifying", f"Verifying device {hostname} reachability", push_op_id=push_op_id_str)
    reachable = await _check_reachability(ip, ssh_username, ssh_password)
    if reachable:
        # ------------------------------------------------------------------
        # Step 8a: Device is reachable — remove panic-revert scheduler + cleanup
        # ------------------------------------------------------------------
        logger.info("Device %s (%s) is reachable after push — committing", hostname, ip)
        try:
            async with asyncssh.connect(
                ip,
                port=22,
                username=ssh_username,
                password=ssh_password,
                known_hosts=None,
                connect_timeout=30,
            ) as conn:
                # Remove the panic-revert scheduler
                await conn.run(
                    f'/system scheduler remove "{_PANIC_REVERT_SCHEDULER}"',
                    check=False,  # Non-fatal if already removed
                )
                # Clean up the pre-push binary backup from device flash
                await conn.run(
                    f"/file remove {_PRE_PUSH_BACKUP}.backup",
                    check=False,  # Non-fatal if already removed
                )
        except Exception as cleanup_err:
            # Cleanup failure is non-fatal — scheduler will eventually fire but
            # the backup is now the correct config, so it's acceptable.
            logger.warning(
                "Failed to clean up panic-revert scheduler/backup on device %s: %s",
                hostname,
                cleanup_err,
            )
        await _update_push_op_status(push_op_id, "committed", db_session)
        await clear_push(device_id)
        await _publish_push_progress(tenant_id, device_id, "committed", f"Config restored successfully on {hostname}", push_op_id=push_op_id_str)
        return {
            "status": "committed",
            "message": "Config restored successfully",
            "pre_backup_sha": pre_backup_sha,
        }
    else:
        # ------------------------------------------------------------------
        # Step 8b: Device unreachable — RouterOS is auto-reverting via scheduler
        # ------------------------------------------------------------------
        logger.warning(
            "Device %s (%s) is unreachable after push — RouterOS panic-revert scheduler "
            "will auto-revert to %s.backup",
            hostname,
            ip,
            _PRE_PUSH_BACKUP,
        )
        await _update_push_op_status(push_op_id, "reverted", db_session)
        await _publish_push_progress(
            tenant_id, device_id, "reverted",
            f"Device {hostname} unreachable — auto-reverting via panic-revert scheduler",
            push_op_id=push_op_id_str,
        )
        return {
            "status": "reverted",
            "message": (
                "Device unreachable after push; RouterOS is auto-reverting "
                "via panic-revert scheduler"
            ),
            "pre_backup_sha": pre_backup_sha,
        }
async def _check_reachability(ip: str, username: str, password: str) -> bool:
    """Check if a RouterOS device is reachable via SSH.

    Connects and runs "/system identity print". Any connection or command
    failure (including the 30s connect timeout) counts as unreachable — the
    panic-revert scheduler on the device handles that case.

    Uses asyncssh (not the poller's binary API) to avoid a circular import.

    Args:
        ip: Device IP address.
        username: SSH username.
        password: SSH password.

    Returns:
        True if reachable, False if unreachable.
    """
    try:
        async with asyncssh.connect(
            ip,
            port=22,
            username=username,
            password=password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            identity = await conn.run("/system identity print", check=True)
            logger.debug("Reachability check OK for %s: %r", ip, identity.stdout[:50])
            return True
    except Exception as exc:
        logger.info("Device %s unreachable after push: %s", ip, exc)
        return False
async def _update_push_op_status(
    push_op_id,
    new_status: str,
    db_session: AsyncSession,
) -> None:
    """Update the status and completed_at of a ConfigPushOperation row.

    Args:
        push_op_id: UUID of the ConfigPushOperation row.
        new_status: New status value ('committed' | 'reverted' | 'failed').
        db_session: Database session (must already have tenant context set).

    Note:
        Does not commit — the caller (endpoint) owns the transaction.
    """
    # Fix: the previous local import also pulled in `select`, which was unused.
    from sqlalchemy import update
    await db_session.execute(
        update(ConfigPushOperation)
        .where(ConfigPushOperation.id == push_op_id)  # type: ignore[arg-type]
        .values(
            status=new_status,
            completed_at=datetime.now(timezone.utc),
        )
    )
async def _remove_panic_scheduler(
    ip: str, username: str, password: str, scheduler_name: str
) -> bool:
    """SSH to device and remove the panic-revert scheduler. Returns True if removed."""
    try:
        async with asyncssh.connect(
            ip,
            username=username,
            password=password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            # Probe first: if the scheduler is already gone, the device
            # reverted on its own and there is nothing to remove.
            listing = await conn.run(
                f'/system scheduler print where name="{scheduler_name}"',
                check=False,
            )
            if scheduler_name not in listing.stdout:
                return False
            await conn.run(
                f'/system scheduler remove [find name="{scheduler_name}"]',
                check=False,
            )
            # Best-effort cleanup of the pre-push backup file as well.
            await conn.run(
                f'/file remove [find name="{_PRE_PUSH_BACKUP}.backup"]',
                check=False,
            )
            return True
    except Exception as e:
        logger.error("Failed to remove panic scheduler from %s: %s", ip, e)
        return False
async def recover_stale_push_operations(db_session: AsyncSession) -> None:
    """Recover stale pending_verification push operations on API startup.

    Scans for operations older than 5 minutes that are still pending.
    For each, checks device reachability and resolves the operation:
    reachable devices are marked committed (removing the panic-revert
    scheduler if still installed); unreachable devices are marked failed.

    Commits the session at the end.
    """
    # Cleanup: ConfigPushOperation and Device are already imported at module
    # level, so only the names that are genuinely local-only are imported here.
    from sqlalchemy import select
    from app.services.crypto import decrypt_credentials_hybrid
    cutoff = datetime.now(timezone.utc) - timedelta(minutes=5)
    result = await db_session.execute(
        select(ConfigPushOperation).where(
            ConfigPushOperation.status == "pending_verification",
            ConfigPushOperation.started_at < cutoff,
        )
    )
    stale_ops = result.scalars().all()
    if not stale_ops:
        logger.info("No stale push operations to recover")
        return
    logger.warning("Found %d stale push operations to recover", len(stale_ops))
    key = settings.get_encryption_key_bytes()
    for op in stale_ops:
        try:
            # Load device
            dev_result = await db_session.execute(
                select(Device).where(Device.id == op.device_id)
            )
            device = dev_result.scalar_one_or_none()
            if not device:
                logger.error("Device %s not found for stale op %s", op.device_id, op.id)
                await _update_push_op_status(op.id, "failed", db_session)
                continue
            # Decrypt credentials
            creds_json = await decrypt_credentials_hybrid(
                device.encrypted_credentials_transit,
                device.encrypted_credentials,
                str(op.tenant_id),
                key,
            )
            creds = json.loads(creds_json)
            ssh_username = creds.get("username", "admin")
            ssh_password = creds.get("password", "")
            # Check reachability
            reachable = await _check_reachability(
                device.ip_address, ssh_username, ssh_password
            )
            if reachable:
                # Try to remove scheduler (if still there, push was good)
                removed = await _remove_panic_scheduler(
                    device.ip_address,
                    ssh_username,
                    ssh_password,
                    op.scheduler_name,
                )
                if removed:
                    logger.info("Recovery: committed op %s (scheduler removed)", op.id)
                else:
                    # Scheduler already gone — device may have reverted
                    logger.warning(
                        "Recovery: op %s — scheduler gone, device may have reverted. "
                        "Marking committed (device is reachable).",
                        op.id,
                    )
                await _update_push_op_status(op.id, "committed", db_session)
                await _publish_push_progress(
                    str(op.tenant_id),
                    str(op.device_id),
                    "committed",
                    "Recovered after API restart",
                    push_op_id=str(op.id),
                )
            else:
                logger.warning(
                    "Recovery: device %s unreachable, marking op %s failed",
                    op.device_id,
                    op.id,
                )
                await _update_push_op_status(op.id, "failed", db_session)
                await _publish_push_progress(
                    str(op.tenant_id),
                    str(op.device_id),
                    "failed",
                    "Device unreachable during recovery after API restart",
                    push_op_id=str(op.id),
                )
        except Exception as e:
            logger.error("Recovery failed for op %s: %s", op.id, e)
            await _update_push_op_status(op.id, "failed", db_session)
    await db_session.commit()

View File

@@ -0,0 +1,165 @@
"""RouterOS command proxy via NATS request-reply.
Sends command requests to the Go poller's CmdResponder subscription
(device.cmd.{device_id}) and returns structured RouterOS API response data.
Used by:
- Config editor API (browse menu paths, add/edit/delete entries)
- Template push service (execute rendered template commands)
"""
import json
import logging
from typing import Any
import nats
import nats.aio.client
from app.config import settings
logger = logging.getLogger(__name__)
# Module-level NATS connection (lazy initialized)
_nc: nats.aio.client.Client | None = None
async def _get_nats() -> nats.aio.client.Client:
    """Return the shared NATS connection, (re)connecting when needed."""
    global _nc
    needs_connect = _nc is None or _nc.is_closed
    if needs_connect:
        _nc = await nats.connect(settings.NATS_URL)
        logger.info("RouterOS proxy NATS connection established")
    return _nc
async def execute_command(
    device_id: str,
    command: str,
    args: list[str] | None = None,
    timeout: float = 15.0,
) -> dict[str, Any]:
    """Execute a RouterOS API command on a device via the Go poller.

    Args:
        device_id: UUID string of the target device.
        command: Full RouterOS API path, e.g. "/ip/address/print".
        args: Optional list of RouterOS API args, e.g. ["=.proplist=.id,address"].
        timeout: NATS request timeout in seconds (default 15s).

    Returns:
        {"success": bool, "data": list[dict], "error": str|None}
    """
    nc = await _get_nats()
    payload = json.dumps({
        "device_id": device_id,
        "command": command,
        "args": args or [],
    }).encode()
    try:
        # Request-reply against the poller's CmdResponder subscription;
        # decoding stays inside the try so a malformed reply is reported
        # as an error dict rather than raised.
        reply = await nc.request(f"device.cmd.{device_id}", payload, timeout=timeout)
        return json.loads(reply.data)
    except nats.errors.TimeoutError:
        return {
            "success": False,
            "data": [],
            "error": "Device command timed out — device may be offline or unreachable",
        }
    except Exception as exc:
        logger.error("NATS request failed for device %s: %s", device_id, exc)
        return {"success": False, "data": [], "error": str(exc)}
async def browse_menu(device_id: str, path: str) -> dict[str, Any]:
    """Browse a RouterOS menu path and return all entries.

    Args:
        device_id: Device UUID string.
        path: RouterOS menu path, e.g. "/ip/address" or "/interface".

    Returns:
        {"success": bool, "data": list[dict], "error": str|None}
    """
    return await execute_command(device_id, f"{path}/print")
async def add_entry(
    device_id: str, path: str, properties: dict[str, str]
) -> dict[str, Any]:
    """Add a new entry to a RouterOS menu path.

    Args:
        device_id: Device UUID.
        path: Menu path, e.g. "/ip/address".
        properties: Key-value pairs for the new entry.

    Returns:
        Command response dict.
    """
    prop_args = [f"={key}={value}" for key, value in properties.items()]
    return await execute_command(device_id, f"{path}/add", prop_args)
async def update_entry(
    device_id: str, path: str, entry_id: str | None, properties: dict[str, str]
) -> dict[str, Any]:
    """Update an existing entry in a RouterOS menu path.

    Args:
        device_id: Device UUID.
        path: Menu path.
        entry_id: RouterOS .id value (e.g. "*1"). None for singleton paths.
        properties: Key-value pairs to update.

    Returns:
        Command response dict.
    """
    set_args: list[str] = []
    if entry_id:
        # Target a specific row; singleton paths take no .id argument.
        set_args.append(f"=.id={entry_id}")
    set_args.extend(f"={key}={value}" for key, value in properties.items())
    return await execute_command(device_id, f"{path}/set", set_args)
async def remove_entry(
    device_id: str, path: str, entry_id: str
) -> dict[str, Any]:
    """Remove an entry from a RouterOS menu path.

    Args:
        device_id: Device UUID.
        path: Menu path.
        entry_id: RouterOS .id value.

    Returns:
        Command response dict.
    """
    return await execute_command(device_id, f"{path}/remove", [f"=.id={entry_id}"])
async def execute_cli(device_id: str, cli_command: str) -> dict[str, Any]:
    """Execute an arbitrary RouterOS CLI command.

    For commands that don't follow the standard /path/action pattern.
    The command is sent as-is to the RouterOS API.

    Args:
        device_id: Device UUID.
        cli_command: Full CLI command string.

    Returns:
        Command response dict.
    """
    return await execute_command(device_id, cli_command)
async def close() -> None:
    """Close the NATS connection. Called on application shutdown."""
    global _nc
    if _nc is None or _nc.is_closed:
        return
    # drain() flushes pending messages before closing the connection.
    await _nc.drain()
    _nc = None
    logger.info("RouterOS proxy NATS connection closed")

View File

@@ -0,0 +1,220 @@
"""RouterOS RSC export parser — extracts categories, validates syntax, computes impact."""
import re
import logging
from typing import Any
logger = logging.getLogger(__name__)
# RouterOS menu paths whose modification can disrupt connectivity or lock out
# management access; used during impact analysis to flag risky changes.
HIGH_RISK_PATHS = {
    "/ip address", "/ip route", "/ip firewall filter", "/ip firewall nat",
    "/interface", "/interface bridge", "/interface vlan",
    "/system identity", "/ip service", "/ip ssh", "/user",
}
# (compiled pattern, human-readable warning) pairs; a config line matching a
# pattern contributes its warning to the impact report (all case-insensitive).
MANAGEMENT_PATTERNS = [
    (re.compile(r"chain=input.*dst-port=(22|8291|8728|8729|443|80)", re.I),
     "Modifies firewall rules for management ports (SSH/WinBox/API/Web)"),
    (re.compile(r"chain=input.*action=drop", re.I),
     "Adds drop rule on input chain — may block management access"),
    (re.compile(r"/ip service", re.I),
     "Modifies IP services — may disable API/SSH/WinBox access"),
    (re.compile(r"/user.*set.*password", re.I),
     "Changes user password — may affect automated access"),
]
def _join_continuation_lines(text: str) -> list[str]:
"""Join lines ending with \\ into single logical lines."""
lines = text.split("\n")
joined: list[str] = []
buf = ""
for line in lines:
stripped = line.rstrip()
if stripped.endswith("\\"):
buf += stripped[:-1].rstrip() + " "
else:
if buf:
buf += stripped
joined.append(buf)
buf = ""
else:
joined.append(stripped)
if buf:
joined.append(buf + " <<TRUNCATED>>")
return joined
def parse_rsc(text: str) -> dict[str, Any]:
"""Parse a RouterOS /export compact output.
Returns a dict with a "categories" list, each containing:
- path: the RouterOS command path (e.g. "/ip address")
- adds: count of "add" commands
- sets: count of "set" commands
- removes: count of "remove" commands
- commands: list of command strings under this path
"""
lines = _join_continuation_lines(text)
categories: dict[str, dict] = {}
current_path: str | None = None
for line in lines:
line = line.strip()
if not line or line.startswith("#"):
continue
if line.startswith("/"):
# Could be just a path header, or a path followed by a command
parts = line.split(None, 1)
if len(parts) == 1:
# Pure path header like "/interface bridge"
current_path = parts[0]
else:
# Check if second part starts with a known command verb
cmd_check = parts[1].strip().split(None, 1)
if cmd_check and cmd_check[0] in ("add", "set", "remove", "print", "enable", "disable"):
current_path = parts[0]
line = parts[1].strip()
else:
# The whole line is a path (e.g. "/ip firewall filter")
current_path = line
continue
if current_path and current_path not in categories:
categories[current_path] = {
"path": current_path,
"adds": 0,
"sets": 0,
"removes": 0,
"commands": [],
}
if len(parts) == 1:
continue
if current_path is None:
continue
if current_path not in categories:
categories[current_path] = {
"path": current_path,
"adds": 0,
"sets": 0,
"removes": 0,
"commands": [],
}
cat = categories[current_path]
cat["commands"].append(line)
if line.startswith("add ") or line.startswith("add\t"):
cat["adds"] += 1
elif line.startswith("set "):
cat["sets"] += 1
elif line.startswith("remove "):
cat["removes"] += 1
return {"categories": list(categories.values())}
def validate_rsc(text: str) -> dict[str, Any]:
    """Validate RSC export syntax.

    Checks for:
    - Unbalanced quotes (indicates truncation or corruption)
    - Trailing continuation lines (indicates truncated export)

    Returns dict with "valid" (bool) and "errors" (list of strings).
    """
    problems: list[str] = []
    # Toggle parity for every line whose unescaped-quote count is odd;
    # finishing in the "open" state means a quote was never closed.
    open_quote = False
    for raw in text.split("\n"):
        trimmed = raw.rstrip()
        if trimmed.endswith("\\"):
            trimmed = trimmed[:-1]
        # Each escaped \" contributes one " to the raw count; subtract them.
        unescaped = trimmed.count('"') - trimmed.count('\\"')
        if unescaped % 2 != 0:
            open_quote = not open_quote
    if open_quote:
        problems.append("Unbalanced quote detected — file may be truncated")
    # A final continuation backslash means the export was cut mid-command.
    tail = text.rstrip().split("\n")
    if tail and tail[-1].rstrip().endswith("\\"):
        problems.append("File ends with continuation line (\\) — truncated export")
    return {"valid": len(problems) == 0, "errors": problems}
def compute_impact(
    current_parsed: dict[str, Any],
    target_parsed: dict[str, Any],
) -> dict[str, Any]:
    """Compare current vs target parsed RSC and compute impact analysis.

    Returns dict with:
    - categories: list of per-path diffs with risk levels
    - warnings: list of human-readable warning strings
    - diff: summary counts (added, removed, modified)
    """
    blank = {"adds": 0, "sets": 0, "removes": 0, "commands": []}
    by_path_current = {c["path"]: c for c in current_parsed["categories"]}
    by_path_target = {c["path"]: c for c in target_parsed["categories"]}
    per_path: list[dict] = []
    warnings: list[str] = []
    added_total = 0
    removed_total = 0
    modified_total = 0
    for path in sorted(set(by_path_current) | set(by_path_target)):
        curr_cmds = set(by_path_current.get(path, blank).get("commands", []))
        tgt_cmds = set(by_path_target.get(path, blank).get("commands", []))
        n_added = len(tgt_cmds - curr_cmds)
        n_removed = len(curr_cmds - tgt_cmds)
        added_total += n_added
        removed_total += n_removed
        if n_added or n_removed:
            risk = "high" if path in HIGH_RISK_PATHS else "low"
        else:
            risk = "none"
        per_path.append({
            "path": path,
            "adds": n_added,
            "removes": n_removed,
            "risk": risk,
        })
    # Scan every target command for management-plane patterns.
    all_target_cmds = "\n".join(
        cmd for cat in target_parsed["categories"] for cmd in cat.get("commands", [])
    )
    for pattern, message in MANAGEMENT_PATTERNS:
        if pattern.search(all_target_cmds):
            warnings.append(message)
    # Warn about removed IP addresses
    if "/ip address" in by_path_current and "/ip address" in by_path_target:
        dropped = set(by_path_current["/ip address"].get("commands", [])) - set(
            by_path_target["/ip address"].get("commands", [])
        )
        if dropped:
            warnings.append(
                f"Removes {len(dropped)} IP address(es) — verify none are management interfaces"
            )
    return {
        "categories": per_path,
        "warnings": warnings,
        "diff": {
            "added": added_total,
            "removed": removed_total,
            "modified": modified_total,
        },
    }

View File

@@ -0,0 +1,124 @@
"""
Subnet scanner for MikroTik device discovery.
Scans a CIDR range by attempting TCP connections to RouterOS API ports
(8728 and 8729) with configurable concurrency limits and timeouts.
Security constraints:
- CIDR range limited to /20 or smaller (4096 IPs maximum)
- Maximum 50 concurrent connections to prevent network flooding
- 2-second timeout per connection attempt
"""
import asyncio
import ipaddress
import socket
from typing import Optional
from app.schemas.device import SubnetScanResult
# Maximum concurrency for TCP probes
_MAX_CONCURRENT = 50
# Timeout (seconds) per TCP connection attempt
_TCP_TIMEOUT = 2.0
# RouterOS API port
_API_PORT = 8728
# RouterOS SSL API port
_SSL_PORT = 8729
async def _probe_host(
    semaphore: asyncio.Semaphore,
    ip_str: str,
) -> Optional[SubnetScanResult]:
    """
    Probe a single IP for RouterOS API ports.
    Returns a SubnetScanResult if either port is open, None otherwise.
    """
    async with semaphore:
        # Probe both API ports concurrently under the same semaphore slot.
        api_open, ssl_open = await asyncio.gather(
            _tcp_connect(ip_str, _API_PORT),
            _tcp_connect(ip_str, _SSL_PORT),
            return_exceptions=False,
        )
        if not (api_open or ssl_open):
            return None
        # Best-effort reverse DNS; a failed lookup just leaves hostname None.
        resolved_name = await _reverse_dns(ip_str)
        return SubnetScanResult(
            ip_address=ip_str,
            hostname=resolved_name,
            api_port_open=api_open,
            api_ssl_port_open=ssl_open,
        )
async def _tcp_connect(ip: str, port: int) -> bool:
    """Return True if a TCP connection to ip:port succeeds within _TCP_TIMEOUT."""
    try:
        opened = await asyncio.wait_for(
            asyncio.open_connection(ip, port),
            timeout=_TCP_TIMEOUT,
        )
        writer = opened[1]
        writer.close()
        try:
            await writer.wait_closed()
        except Exception:
            # Close errors don't matter -- the connection already succeeded.
            pass
        return True
    except Exception:
        # Refused, unreachable, or timed out -- all count as "closed".
        return False
async def _reverse_dns(ip: str) -> Optional[str]:
    """Attempt a reverse DNS lookup. Returns None on failure."""
    try:
        # gethostbyaddr blocks, so run it in the default executor with
        # a hard 1.5s cap to keep scans snappy.
        event_loop = asyncio.get_running_loop()
        lookup = event_loop.run_in_executor(None, socket.gethostbyaddr, ip)
        name, _aliases, _addresses = await asyncio.wait_for(lookup, timeout=1.5)
    except Exception:
        return None
    return name
async def scan_subnet(cidr: str) -> list[SubnetScanResult]:
    """
    Scan a CIDR range for hosts with open RouterOS API ports.

    Args:
        cidr: CIDR notation string, e.g. "192.168.1.0/24".
            Must be /20 or smaller (validated by SubnetScanRequest).

    Returns:
        List of SubnetScanResult for each host with at least one open API port.

    Raises:
        ValueError: If CIDR is malformed or too large.
    """
    try:
        network = ipaddress.ip_network(cidr, strict=False)
    except ValueError as e:
        raise ValueError(f"Invalid CIDR: {e}") from e
    if network.num_addresses > 4096:
        raise ValueError(
            f"CIDR range too large ({network.num_addresses} addresses). "
            "Maximum allowed is /20 (4096 addresses)."
        )
    # .hosts() omits network/broadcast addresses; /31 and /32 networks have
    # no such distinction, so probe every address in those cases.
    targets = list(network.hosts()) if network.num_addresses > 2 else list(network)
    gate = asyncio.Semaphore(_MAX_CONCURRENT)
    probes = [_probe_host(gate, str(addr)) for addr in targets]
    outcomes = await asyncio.gather(*probes, return_exceptions=False)
    # Drop hosts that had neither API port open.
    return [hit for hit in outcomes if hit is not None]

View File

@@ -0,0 +1,113 @@
"""SRP-6a server-side authentication service.
Wraps the srptools library for the two-step SRP handshake.
All functions are async, using asyncio.to_thread() because
srptools operations are CPU-bound and synchronous.
"""
import asyncio
import hashlib
from srptools import SRPContext, SRPServerSession
from srptools.constants import PRIME_2048, PRIME_2048_GEN
# Client uses Web Crypto SHA-256 — server must match.
# srptools defaults to SHA-1 which would cause proof mismatch.
_SRP_HASH = hashlib.sha256
async def create_srp_verifier(
    salt_hex: str, verifier_hex: str
) -> tuple[bytes, bytes]:
    """Convert client-provided hex salt and verifier to bytes for storage.

    The client computes v = g^x mod N using 2SKD-derived SRP-x.
    The server stores the verifier directly and never computes x
    from the password.

    Returns:
        Tuple of (salt_bytes, verifier_bytes) ready for database storage.
    """
    salt_bytes = bytes.fromhex(salt_hex)
    verifier_bytes = bytes.fromhex(verifier_hex)
    return salt_bytes, verifier_bytes
async def srp_init(
    email: str, srp_verifier_hex: str
) -> tuple[str, str]:
    """SRP Step 1: Generate server ephemeral (B) and private key (b).

    Args:
        email: User email (SRP identity I).
        srp_verifier_hex: Hex-encoded SRP verifier from database.

    Returns:
        Tuple of (server_public_hex, server_private_hex).
        Caller stores server_private in Redis with 60s TTL.

    Raises:
        ValueError: If SRP initialization fails for any reason.
    """
    def _make_session() -> tuple[str, str]:
        # CPU-bound srptools work runs in a thread; see module docstring.
        ctx = SRPContext(
            email, prime=PRIME_2048, generator=PRIME_2048_GEN,
            hash_func=_SRP_HASH,
        )
        session = SRPServerSession(ctx, srp_verifier_hex)
        return session.public, session.private

    try:
        return await asyncio.to_thread(_make_session)
    except Exception as e:
        raise ValueError(f"SRP initialization failed: {e}") from e
async def srp_verify(
    email: str,
    srp_verifier_hex: str,
    server_private: str,
    client_public: str,
    client_proof: str,
    srp_salt_hex: str,
) -> tuple[bool, str | None]:
    """SRP Step 2: Verify client proof M1, return server proof M2.

    Args:
        email: User email (SRP identity I).
        srp_verifier_hex: Hex-encoded SRP verifier from database.
        server_private: Server private ephemeral from Redis session.
        client_public: Hex-encoded client public ephemeral A.
        client_proof: Hex-encoded client proof M1.
        srp_salt_hex: Hex-encoded SRP salt.

    Returns:
        Tuple of (is_valid, server_proof_hex_or_none).
        If valid, server_proof is M2 for the client to verify.

    Raises:
        ValueError: If the SRP computation fails (malformed inputs or
            library errors). A mere proof mismatch is not an error; it
            yields (False, None).
    """
    def _verify() -> tuple[bool, str | None]:
        # Local import keeps the module's top-level import block unchanged.
        import hmac

        context = SRPContext(
            email, prime=PRIME_2048, generator=PRIME_2048_GEN,
            hash_func=_SRP_HASH,
        )
        server_session = SRPServerSession(
            context, srp_verifier_hex, private=server_private
        )
        _key, _key_proof, _key_proof_hash = server_session.process(client_public, srp_salt_hex)
        # srptools verify_proof has a Python 3 bug: hexlify() returns bytes
        # but client_proof is str, so bytes == str is always False.
        # Compare manually with consistent types.
        server_m1 = _key_proof if isinstance(_key_proof, str) else _key_proof.decode('ascii')
        # Constant-time comparison: M1 is secret-derived, so avoid leaking a
        # timing side-channel through ordinary string equality.
        if not hmac.compare_digest(client_proof.lower(), server_m1.lower()):
            return False, None
        # Return M2 (key_proof_hash), also fixing the bytes/str issue
        m2 = _key_proof_hash if isinstance(_key_proof_hash, str) else _key_proof_hash.decode('ascii')
        return True, m2

    try:
        return await asyncio.to_thread(_verify)
    except Exception as e:
        raise ValueError(f"SRP verification failed: {e}") from e

View File

@@ -0,0 +1,311 @@
"""SSE Connection Manager -- bridges NATS JetStream to per-client asyncio queues.
Each SSE client gets its own NATS connection with ephemeral consumers.
Events are tenant-filtered and placed onto an asyncio.Queue that the
SSE router drains via EventSourceResponse.
"""
import asyncio
import json
from typing import Optional
import nats
import structlog
from nats.js.api import ConsumerConfig, DeliverPolicy, StreamConfig
from app.config import settings
logger = structlog.get_logger(__name__)
# Subjects per stream for SSE subscriptions
# Note: config.push.* subjects live in DEVICE_EVENTS (created by Go poller)
_DEVICE_EVENT_SUBJECTS = [
"device.status.>",
"device.metrics.>",
"config.push.rollback.>",
"config.push.alert.>",
]
_ALERT_EVENT_SUBJECTS = ["alert.fired.>", "alert.resolved.>"]
_OPERATION_EVENT_SUBJECTS = ["firmware.progress.>"]
def _map_subject_to_event_type(subject: str) -> str:
    """Map a NATS subject prefix to an SSE event type string."""
    # Ordered prefix table; first match wins, unknown subjects fall through.
    prefix_table = (
        ("device.status.", "device_status"),
        ("device.metrics.", "metric_update"),
        ("alert.fired.", "alert_fired"),
        ("alert.resolved.", "alert_resolved"),
        ("config.push.", "config_push"),
        ("firmware.progress.", "firmware_progress"),
    )
    for prefix, event_type in prefix_table:
        if subject.startswith(prefix):
            return event_type
    return "unknown"
async def ensure_sse_streams() -> None:
    """Create ALERT_EVENTS and OPERATION_EVENTS NATS streams if they don't exist.

    Called once during app startup so the streams are ready before any
    SSE connection or event publisher needs them. Idempotent -- uses
    add_stream which acts as create-or-update.
    """
    nc = None
    try:
        nc = await nats.connect(settings.NATS_URL)
        js = nc.jetstream()
        # Both streams share the same shape; only name/subjects differ.
        stream_specs = (
            ("ALERT_EVENTS", ["alert.fired.>", "alert.resolved.>"]),
            ("OPERATION_EVENTS", ["firmware.progress.>"]),
        )
        for stream_name, stream_subjects in stream_specs:
            await js.add_stream(
                StreamConfig(
                    name=stream_name,
                    subjects=stream_subjects,
                    max_age=3600,  # 1 hour retention
                )
            )
            logger.info("nats.stream.ensured", stream=stream_name)
    except Exception as exc:
        logger.warning("sse.streams.ensure_failed", error=str(exc))
        raise
    finally:
        if nc:
            try:
                await nc.close()
            except Exception:
                pass
class SSEConnectionManager:
    """Manages a single SSE client's lifecycle: NATS connection, subscriptions, and event queue."""

    def __init__(self) -> None:
        self._nc: Optional[nats.aio.client.Client] = None
        self._subscriptions: list = []
        self._queue: Optional[asyncio.Queue] = None
        self._tenant_id: Optional[str] = None
        self._connection_id: Optional[str] = None
        # Strong reference to the pump task: the event loop keeps only weak
        # references to running tasks, so without this the pump could be
        # garbage-collected mid-flight.
        self._pump_task: Optional[asyncio.Task] = None

    async def connect(
        self,
        connection_id: str,
        tenant_id: Optional[str],
        last_event_id: Optional[str] = None,
    ) -> asyncio.Queue:
        """Set up NATS subscriptions and return an asyncio.Queue for SSE events.

        Args:
            connection_id: Unique identifier for this SSE connection.
            tenant_id: Tenant UUID string to filter events. None for super_admin
                (receives events from all tenants).
            last_event_id: NATS stream sequence number from the Last-Event-ID header.
                If provided, replay starts from sequence + 1.

        Returns:
            asyncio.Queue that the SSE generator should drain.
        """
        self._connection_id = connection_id
        self._tenant_id = tenant_id
        self._queue = asyncio.Queue(maxsize=256)
        self._nc = await nats.connect(
            settings.NATS_URL,
            max_reconnect_attempts=5,
            reconnect_time_wait=2,
        )
        js = self._nc.jetstream()
        logger.info(
            "sse.connecting",
            connection_id=connection_id,
            tenant_id=tenant_id,
            last_event_id=last_event_id,
        )
        consumer_cfg = self._build_consumer_config(last_event_id)
        # DEVICE_EVENTS is created by the Go poller, so no lazy creation here.
        for subject in _DEVICE_EVENT_SUBJECTS:
            try:
                sub = await js.subscribe(
                    subject,
                    stream="DEVICE_EVENTS",
                    config=consumer_cfg,
                )
                self._subscriptions.append(sub)
            except Exception as exc:
                logger.warning(
                    "sse.subscribe_failed",
                    subject=subject,
                    stream="DEVICE_EVENTS",
                    error=str(exc),
                )
        # ALERT_EVENTS / OPERATION_EVENTS may not exist yet (startup race);
        # create them lazily on "stream not found".
        await self._subscribe_lazy(js, _ALERT_EVENT_SUBJECTS, "ALERT_EVENTS", consumer_cfg)
        await self._subscribe_lazy(js, _OPERATION_EVENT_SUBJECTS, "OPERATION_EVENTS", consumer_cfg)
        # Keep the task referenced (see __init__) so it survives GC and can
        # be cancelled in disconnect().
        self._pump_task = asyncio.create_task(self._pump_messages())
        logger.info(
            "sse.connected",
            connection_id=connection_id,
            subscription_count=len(self._subscriptions),
        )
        return self._queue

    @staticmethod
    def _build_consumer_config(last_event_id: Optional[str]) -> ConsumerConfig:
        """Replay from Last-Event-ID + 1 when it parses as an int, else deliver only new messages."""
        if last_event_id is not None:
            try:
                start_seq = int(last_event_id) + 1
                return ConsumerConfig(
                    deliver_policy=DeliverPolicy.BY_START_SEQUENCE,
                    opt_start_seq=start_seq,
                )
            except (ValueError, TypeError):
                pass
        return ConsumerConfig(deliver_policy=DeliverPolicy.NEW)

    async def _subscribe_lazy(
        self,
        js,
        subjects: list[str],
        stream_name: str,
        consumer_cfg: ConsumerConfig,
    ) -> None:
        """Subscribe to each subject, creating the stream on 'stream not found'.

        Handles the startup race where an SSE client connects before
        ensure_sse_streams() has created the stream. Subscribe failures are
        logged but never abort the connection -- the client just receives
        fewer event types.
        """
        for subject in subjects:
            try:
                sub = await js.subscribe(subject, stream=stream_name, config=consumer_cfg)
                self._subscriptions.append(sub)
            except Exception as exc:
                if "stream not found" not in str(exc):
                    logger.warning("sse.subscribe_failed", subject=subject, stream=stream_name, error=str(exc))
                    continue
                try:
                    await js.add_stream(StreamConfig(
                        name=stream_name,
                        subjects=list(subjects),
                        max_age=3600,
                    ))
                    sub = await js.subscribe(subject, stream=stream_name, config=consumer_cfg)
                    self._subscriptions.append(sub)
                    logger.info("sse.stream_created_lazily", stream=stream_name)
                except Exception as retry_exc:
                    logger.warning("sse.subscribe_failed", subject=subject, stream=stream_name, error=str(retry_exc))

    async def _pump_messages(self) -> None:
        """Read messages from all NATS push subscriptions and push them onto the asyncio queue.

        Uses next_msg with a short timeout so we can interleave across
        subscriptions without blocking. Runs until the NATS connection is
        closed or drained, or the task is cancelled by disconnect().
        """
        while self._nc and self._nc.is_connected:
            for sub in self._subscriptions:
                try:
                    msg = await sub.next_msg(timeout=0.5)
                    await self._handle_message(msg)
                except nats.errors.TimeoutError:
                    # No messages available on this subscription -- move on
                    continue
                except Exception as exc:
                    if self._nc and self._nc.is_connected:
                        logger.warning(
                            "sse.pump_error",
                            connection_id=self._connection_id,
                            error=str(exc),
                        )
                    break
            # Brief yield to avoid tight-looping
            await asyncio.sleep(0.1)

    async def _handle_message(self, msg) -> None:
        """Parse a NATS message, apply tenant filter, and enqueue as SSE event."""
        try:
            data = json.loads(msg.data)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Unparseable payloads are acked and dropped.
            await msg.ack()
            return
        # Tenant filtering: skip messages not matching this connection's tenant
        if self._tenant_id is not None:
            msg_tenant = data.get("tenant_id", "")
            if str(msg_tenant) != self._tenant_id:
                await msg.ack()
                return
        event_type = _map_subject_to_event_type(msg.subject)
        # Extract NATS stream sequence for Last-Event-ID support
        seq_id = "0"
        if msg.metadata and msg.metadata.sequence:
            seq_id = str(msg.metadata.sequence.stream)
        sse_event = {
            "event": event_type,
            "data": json.dumps(data),
            "id": seq_id,
        }
        try:
            self._queue.put_nowait(sse_event)
        except asyncio.QueueFull:
            # Drop rather than block: a slow SSE client must not stall the pump.
            logger.warning(
                "sse.queue_full",
                connection_id=self._connection_id,
                dropped_event=event_type,
            )
        await msg.ack()

    async def disconnect(self) -> None:
        """Cancel the pump task, unsubscribe from all NATS subscriptions, and close the connection."""
        logger.info("sse.disconnecting", connection_id=self._connection_id)
        # Stop the pump first so it does not race with the unsubscribes below.
        if self._pump_task is not None:
            self._pump_task.cancel()
            try:
                await self._pump_task
            except (asyncio.CancelledError, Exception):
                pass
            self._pump_task = None
        for sub in self._subscriptions:
            try:
                await sub.unsubscribe()
            except Exception:
                pass
        self._subscriptions.clear()
        if self._nc:
            try:
                await self._nc.drain()
            except Exception:
                try:
                    await self._nc.close()
                except Exception:
                    pass
            self._nc = None
        logger.info("sse.disconnected", connection_id=self._connection_id)

View File

@@ -0,0 +1,480 @@
"""Config template service: Jinja2 rendering, variable extraction, and multi-device push.
Provides:
- extract_variables: Parse template content to find all undeclared Jinja2 variables
- render_template: Render a template with device context and custom variables
- validate_variable: Type-check a variable value against its declared type
- push_to_devices: Sequential multi-device push with pause-on-failure
- push_single_device: Two-phase panic-revert push for a single device
The push logic follows the same two-phase pattern as restore_service but uses
separate scheduler and file names to avoid conflicts with restore operations.
"""
import asyncio
import io
import ipaddress
import json
import logging
import uuid
from datetime import datetime, timezone
import asyncssh
from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy import select, text
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.models.config_template import TemplatePushJob
from app.models.device import Device
logger = logging.getLogger(__name__)
# Sandboxed Jinja2 environment prevents template injection
_env = SandboxedEnvironment()
# Names used on the RouterOS device during template push
_PANIC_REVERT_SCHEDULER = "mikrotik-portal-template-revert"
_PRE_PUSH_BACKUP = "portal-template-pre-push"
_TEMPLATE_RSC = "portal-template.rsc"
# ---------------------------------------------------------------------------
# Variable extraction & rendering
# ---------------------------------------------------------------------------
def extract_variables(template_content: str) -> list[str]:
    """Extract all undeclared variables from a Jinja2 template.

    Returns a sorted list of variable names, excluding the built-in 'device'
    variable which is auto-populated at render time.
    """
    parsed = _env.parse(template_content)
    undeclared = meta.find_undeclared_variables(parsed)
    # 'device' is injected automatically at render time, not user-provided.
    return sorted(name for name in undeclared if name != "device")
def render_template(
    template_content: str,
    device: dict,
    custom_variables: dict[str, str],
) -> str:
    """Render a Jinja2 template with device context and custom variables.

    The 'device' variable is auto-populated from the device dict.
    Custom variables are user-provided at push time.
    Uses SandboxedEnvironment to prevent template injection.

    Args:
        template_content: Jinja2 template string.
        device: Device info dict with keys: hostname, ip_address, model.
        custom_variables: User-supplied variable values.

    Returns:
        Rendered template string.

    Raises:
        jinja2.TemplateSyntaxError: If template has syntax errors.
        jinja2.UndefinedError: If required variables are missing.
    """
    render_vars: dict = {
        "device": {
            "hostname": device.get("hostname", ""),
            "ip": device.get("ip_address", ""),
            "model": device.get("model", ""),
        },
    }
    # Custom variables are applied last so they win on a key collision,
    # matching the original {"device": ..., **custom_variables} merge order.
    render_vars.update(custom_variables)
    compiled = _env.from_string(template_content)
    return compiled.render(render_vars)
def validate_variable(name: str, value: str, var_type: str) -> str | None:
"""Validate a variable value against its declared type.
Returns None on success, or an error message string on failure.
"""
if var_type == "string":
return None # any string is valid
elif var_type == "ip":
try:
ipaddress.ip_address(value)
return None
except ValueError:
return f"'{name}' must be a valid IP address"
elif var_type == "subnet":
try:
ipaddress.ip_network(value, strict=False)
return None
except ValueError:
return f"'{name}' must be a valid subnet (e.g., 192.168.1.0/24)"
elif var_type == "integer":
try:
int(value)
return None
except ValueError:
return f"'{name}' must be an integer"
elif var_type == "boolean":
if value.lower() in ("true", "false", "yes", "no", "1", "0"):
return None
return f"'{name}' must be a boolean (true/false)"
return None # unknown type, allow
# ---------------------------------------------------------------------------
# Multi-device push orchestration
# ---------------------------------------------------------------------------
async def push_to_devices(rollout_id: str) -> dict:
    """Execute sequential template push for all jobs in a rollout.

    Processes devices one at a time. If any device fails or reverts,
    remaining jobs stay pending (paused). Follows the same pattern as
    firmware upgrade_service.start_mass_upgrade.

    This runs as a background task (asyncio.create_task) after the
    API creates the push jobs and returns the rollout_id.
    """
    try:
        summary = await _run_push_rollout(rollout_id)
    except Exception as exc:
        # Background task: never let an exception escape; report a failure.
        logger.error(
            "Uncaught exception in template push rollout %s: %s",
            rollout_id, exc, exc_info=True,
        )
        summary = {"completed": 0, "failed": 1, "pending": 0}
    return summary
async def _run_push_rollout(rollout_id: str) -> dict:
    """Internal rollout implementation.

    Loads every push job in the rollout (ordered by creation time) and pushes
    devices one at a time. On the first failure/revert the loop stops, leaving
    the remaining jobs pending (the rollout is "paused").

    Returns:
        Summary dict {"completed": int, "failed": int, "pending": int}.
        "completed" includes jobs already committed before this run.
    """
    # Load all jobs for this rollout
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id::text, j.status, d.hostname
                FROM template_push_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.rollout_id = CAST(:rollout_id AS uuid)
                ORDER BY j.created_at ASC
            """),
            {"rollout_id": rollout_id},
        )
        jobs = result.fetchall()
    if not jobs:
        logger.warning("No jobs found for template push rollout %s", rollout_id)
        return {"completed": 0, "failed": 0, "pending": 0}
    pending_total = sum(1 for _, status, _ in jobs if status == "pending")
    completed = 0
    processed_pending = 0  # pending jobs actually attempted during this run
    failed = False
    for job_id, current_status, hostname in jobs:
        if current_status != "pending":
            # Skip jobs already in a terminal state; count prior successes.
            if current_status == "committed":
                completed += 1
            continue
        logger.info(
            "Template push rollout %s: pushing to device %s (job %s)",
            rollout_id, hostname, job_id,
        )
        await push_single_device(job_id)
        processed_pending += 1
        # Check resulting status
        async with AdminAsyncSessionLocal() as session:
            result = await session.execute(
                text("SELECT status FROM template_push_jobs WHERE id = CAST(:id AS uuid)"),
                {"id": job_id},
            )
            row = result.fetchone()
        if row and row[0] == "committed":
            completed += 1
        elif row and row[0] in ("failed", "reverted"):
            failed = True
            logger.error(
                "Template push rollout %s paused: device %s %s",
                rollout_id, hostname, row[0],
            )
            break
    # Pending = snapshot of pending jobs minus the ones attempted this run.
    # (Previously `completed` -- which also counts jobs committed before this
    # run -- was subtracted, producing a wrong/negative remainder that was
    # only masked by max(0, ...).)
    remaining = pending_total - processed_pending
    return {
        "completed": completed,
        "failed": 1 if failed else 0,
        "pending": max(0, remaining),
    }
async def push_single_device(job_id: str) -> None:
    """Push rendered template content to a single device.

    Implements the two-phase panic-revert pattern:
    1. Pre-backup (mandatory)
    2. Install panic-revert scheduler on device
    3. Write template content as RSC file via SFTP
    4. /import the RSC file
    5. Wait 60s for config to settle
    6. Reachability check -> committed or reverted

    All errors are caught and recorded in the job row.
    """
    try:
        await _run_single_push(job_id)
    except Exception as exc:
        # Last-resort safety net: record the failure instead of raising
        # out of a background task.
        logger.error(
            "Uncaught exception in template push job %s: %s",
            job_id, exc, exc_info=True,
        )
        failure_note = f"Unexpected error: {exc}"
        await _update_job(job_id, status="failed", error_message=failure_note)
async def _run_single_push(job_id: str) -> None:
    """Internal single-device push implementation.

    Sequence (mirrors push_single_device's docstring):
      1. Load job + device row (AdminAsyncSessionLocal -- background task,
         no request context, so RLS is bypassed).
      2. Mark the job "pushing".
      3. Decrypt device credentials (Transit ciphertext preferred, legacy
         column as fallback).
      4. Take a mandatory pre-push backup; abort the push if it fails.
      5. SSH in: save an on-device binary backup, install a panic-revert
         scheduler, SFTP the rendered RSC, then /import it.
      6. Sleep 60s, then re-check SSH reachability.
      7. Reachable -> remove the revert scheduler and mark "committed";
         unreachable -> the device reverts itself, mark "reverted".

    Every failure path records its reason on the job row via _update_job;
    nothing is raised to the caller.
    """
    # Step 1: Load job and device info
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id, j.device_id, j.tenant_id, j.rendered_content,
                       d.ip_address, d.hostname, d.encrypted_credentials,
                       d.encrypted_credentials_transit
                FROM template_push_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.id = CAST(:job_id AS uuid)
            """),
            {"job_id": job_id},
        )
        row = result.fetchone()
    if not row:
        logger.error("Template push job %s not found", job_id)
        return
    (
        _, device_id, tenant_id, rendered_content,
        ip_address, hostname, encrypted_credentials,
        encrypted_credentials_transit,
    ) = row
    # UUID columns come back as UUID objects; downstream helpers expect str.
    device_id = str(device_id)
    tenant_id = str(tenant_id)
    # Fall back to the IP for log readability when hostname is unset.
    hostname = hostname or ip_address
    # Step 2: Update status to pushing
    await _update_job(job_id, status="pushing", started_at=datetime.now(timezone.utc))
    # Step 3: Decrypt credentials (dual-read: Transit preferred, legacy fallback)
    if not encrypted_credentials_transit and not encrypted_credentials:
        await _update_job(job_id, status="failed", error_message="Device has no stored credentials")
        return
    try:
        from app.services.crypto import decrypt_credentials_hybrid
        key = settings.get_encryption_key_bytes()
        creds_json = await decrypt_credentials_hybrid(
            encrypted_credentials_transit, encrypted_credentials, tenant_id, key,
        )
        creds = json.loads(creds_json)
        ssh_username = creds.get("username", "")
        ssh_password = creds.get("password", "")
    except Exception as cred_err:
        await _update_job(
            job_id, status="failed",
            error_message=f"Failed to decrypt credentials: {cred_err}",
        )
        return
    # Step 4: Mandatory pre-push backup
    logger.info("Running mandatory pre-push backup for device %s (%s)", hostname, ip_address)
    try:
        from app.services import backup_service
        backup_result = await backup_service.run_backup(
            device_id=device_id,
            tenant_id=tenant_id,
            trigger_type="pre-template-push",
        )
        # Record the backup commit so a manual restore can target it later.
        backup_sha = backup_result["commit_sha"]
        await _update_job(job_id, pre_push_backup_sha=backup_sha)
        logger.info("Pre-push backup complete: %s", backup_sha[:8])
    except Exception as backup_err:
        # The backup is the safety net -- without it we refuse to push.
        logger.error("Pre-push backup failed for %s: %s", hostname, backup_err)
        await _update_job(
            job_id, status="failed",
            error_message=f"Pre-push backup failed: {backup_err}",
        )
        return
    # Step 5: SSH to device - install panic-revert, push config
    logger.info(
        "Pushing template to device %s (%s): installing panic-revert and uploading config",
        hostname, ip_address,
    )
    try:
        # NOTE(review): known_hosts=None disables SSH host-key verification
        # (MITM exposure) -- confirm this is an accepted risk for managed
        # devices on a trusted management network.
        async with asyncssh.connect(
            ip_address,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            # 5a: Create binary backup on device as revert point
            await conn.run(
                f"/system backup save name={_PRE_PUSH_BACKUP} dont-encrypt=yes",
                check=True,
            )
            logger.debug("Pre-push binary backup saved on device as %s.backup", _PRE_PUSH_BACKUP)
            # 5b: Install panic-revert RouterOS scheduler
            # The scheduler fires every 90s and reloads the pre-push backup;
            # it is removed in step 8a only if the device stays reachable.
            await conn.run(
                f"/system scheduler add "
                f'name="{_PANIC_REVERT_SCHEDULER}" '
                f"interval=90s "
                f'on-event=":delay 0; /system backup load name={_PRE_PUSH_BACKUP}" '
                f"start-time=startup",
                check=True,
            )
            logger.debug("Panic-revert scheduler installed on device")
            # 5c: Upload rendered template as RSC file via SFTP
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(_TEMPLATE_RSC, "wb") as f:
                    await f.write(rendered_content.encode("utf-8"))
            logger.debug("Uploaded %s to device flash", _TEMPLATE_RSC)
            # 5d: /import the config file
            # check=False: a failed import must not raise here -- the
            # reachability check in step 7 decides commit vs revert.
            import_result = await conn.run(
                f"/import file={_TEMPLATE_RSC}",
                check=False,
            )
            logger.info(
                "Template import result for device %s: exit_status=%s stdout=%r",
                hostname, import_result.exit_status,
                (import_result.stdout or "")[:200],
            )
            # 5e: Clean up the uploaded RSC file (best-effort)
            try:
                await conn.run(f"/file remove {_TEMPLATE_RSC}", check=True)
            except Exception as cleanup_err:
                logger.warning(
                    "Failed to clean up %s from device %s: %s",
                    _TEMPLATE_RSC, ip_address, cleanup_err,
                )
    except Exception as push_err:
        logger.error(
            "SSH push phase failed for device %s (%s): %s",
            hostname, ip_address, push_err,
        )
        await _update_job(
            job_id, status="failed",
            error_message=f"Config push failed during SSH phase: {push_err}",
        )
        return
    # Step 6: Wait 60s for config to settle
    logger.info("Template pushed to device %s - waiting 60s for config to settle", hostname)
    await asyncio.sleep(60)
    # Step 7: Reachability check
    reachable = await _check_reachability(ip_address, ssh_username, ssh_password)
    if reachable:
        # Step 8a: Device is reachable - remove panic-revert scheduler + cleanup
        logger.info("Device %s (%s) is reachable after push - committing", hostname, ip_address)
        try:
            async with asyncssh.connect(
                ip_address, port=22,
                username=ssh_username, password=ssh_password,
                known_hosts=None, connect_timeout=30,
            ) as conn:
                # check=False: cleanup is best-effort; a leftover scheduler or
                # backup file must not fail an otherwise successful push.
                await conn.run(
                    f'/system scheduler remove "{_PANIC_REVERT_SCHEDULER}"',
                    check=False,
                )
                await conn.run(
                    f"/file remove {_PRE_PUSH_BACKUP}.backup",
                    check=False,
                )
        except Exception as cleanup_err:
            logger.warning(
                "Failed to clean up panic-revert scheduler/backup on device %s: %s",
                hostname, cleanup_err,
            )
        await _update_job(
            job_id, status="committed",
            completed_at=datetime.now(timezone.utc),
        )
    else:
        # Step 8b: Device unreachable - RouterOS is auto-reverting
        logger.warning(
            "Device %s (%s) is unreachable after push - panic-revert scheduler "
            "will auto-revert to %s.backup",
            hostname, ip_address, _PRE_PUSH_BACKUP,
        )
        await _update_job(
            job_id, status="reverted",
            error_message="Device unreachable after push; auto-reverted via panic-revert scheduler",
            completed_at=datetime.now(timezone.utc),
        )
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
async def _check_reachability(ip: str, username: str, password: str) -> bool:
    """Check if a RouterOS device is reachable via SSH."""
    try:
        async with asyncssh.connect(
            ip, port=22,
            username=username, password=password,
            known_hosts=None, connect_timeout=30,
        ) as conn:
            # Any successful command proves the control plane is up.
            identity = await conn.run("/system identity print", check=True)
            logger.debug("Reachability check OK for %s: %r", ip, identity.stdout[:50])
            return True
    except Exception as exc:
        logger.info("Device %s unreachable after push: %s", ip, exc)
        return False
async def _update_job(job_id: str, **kwargs) -> None:
    """Update TemplatePushJob fields via raw SQL (background task, no RLS)."""
    # Columns that may be explicitly reset to NULL when passed as None.
    nullable_columns = ("error_message", "started_at", "completed_at", "pre_push_backup_sha")
    assignments: list[str] = []
    params: dict = {"job_id": job_id}
    for column, value in kwargs.items():
        if value is None and column in nullable_columns:
            assignments.append(f"{column} = NULL")
            continue
        placeholder = f"v_{column}"
        assignments.append(f"{column} = :{placeholder}")
        params[placeholder] = value
    if not assignments:
        return
    async with AdminAsyncSessionLocal() as session:
        await session.execute(
            text(f"""
                UPDATE template_push_jobs
                SET {', '.join(assignments)}
                WHERE id = CAST(:job_id AS uuid)
            """),
            params,
        )
        await session.commit()

View File

@@ -0,0 +1,564 @@
"""Firmware upgrade orchestration service.
Handles single-device and mass firmware upgrades with:
- Mandatory pre-upgrade config backup
- NPK download and SFTP upload to device
- Reboot trigger and reconnect polling
- Post-upgrade version verification
- Sequential mass rollout with pause-on-failure
- Scheduled upgrades via APScheduler DateTrigger
All DB operations use AdminAsyncSessionLocal to bypass RLS since upgrade
jobs may span multiple tenants and run in background asyncio tasks.
"""
import asyncio
import io
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
import asyncssh
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.services.event_publisher import publish_event
logger = logging.getLogger(__name__)
# Maximum time to wait for a device to reconnect after reboot (seconds)
_RECONNECT_TIMEOUT = 300 # 5 minutes
_RECONNECT_POLL_INTERVAL = 15 # seconds
_INITIAL_WAIT = 60 # Wait before first reconnect attempt (boot cycle)
async def start_upgrade(job_id: str) -> None:
    """Run one device firmware upgrade end to end.

    Status lifecycle: pending -> downloading -> uploading -> rebooting ->
    verifying -> completed/failed.

    Designed to be launched as a background asyncio task or APScheduler job:
    it never propagates exceptions — any uncaught failure is recorded on the
    FirmwareUpgradeJob row instead.
    """
    try:
        await _run_upgrade(job_id)
    except Exception as err:  # top-level task boundary: record, never raise
        logger.error(
            "Uncaught exception in firmware upgrade %s: %s",
            job_id, err, exc_info=True,
        )
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Unexpected error: {err}",
        )
async def _publish_upgrade_progress(
    tenant_id: str,
    device_id: str,
    job_id: str,
    stage: str,
    target_version: str,
    message: str,
    error: str | None = None,
) -> None:
    """Emit a firmware-upgrade progress event on NATS (fire-and-forget).

    The event carries the current stage plus a human-readable message; a
    non-empty *error* string is attached under the "error" key.
    """
    event = {
        "event_type": "firmware_progress",
        "tenant_id": tenant_id,
        "device_id": device_id,
        "job_id": job_id,
        "stage": stage,
        "target_version": target_version,
        "message": message,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    if error:
        event["error"] = error
    await publish_event(f"firmware.progress.{tenant_id}.{device_id}", event)
async def _run_upgrade(job_id: str) -> None:
    """Internal upgrade implementation (wrapped by start_upgrade).

    Pipeline: load job row -> status guard -> major-version confirmation
    check -> mandatory pre-upgrade backup -> NPK download -> SFTP upload ->
    reboot -> reconnect polling -> version verification. Every failure path
    marks the job failed with a readable error_message, publishes a progress
    event, and returns; nothing propagates to the caller.
    """
    # Step 1: Load job
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id, j.device_id, j.tenant_id, j.target_version,
                       j.architecture, j.channel, j.status, j.confirmed_major_upgrade,
                       d.ip_address, d.hostname, d.encrypted_credentials,
                       d.routeros_version, d.encrypted_credentials_transit
                FROM firmware_upgrade_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.id = CAST(:job_id AS uuid)
            """),
            {"job_id": job_id},
        )
        row = result.fetchone()
    if not row:
        logger.error("Upgrade job %s not found", job_id)
        return
    (
        _, device_id, tenant_id, target_version,
        architecture, channel, status, confirmed_major,
        ip_address, hostname, encrypted_credentials,
        current_version, encrypted_credentials_transit,
    ) = row
    device_id = str(device_id)
    tenant_id = str(tenant_id)
    # Fall back to the IP for log/event messages when no hostname is stored.
    hostname = hostname or ip_address
    # Skip if already running or completed
    if status not in ("pending", "scheduled"):
        logger.info("Upgrade job %s already in status %s — skipping", job_id, status)
        return
    logger.info(
        "Starting firmware upgrade for %s (%s): %s -> %s",
        hostname, ip_address, current_version, target_version,
    )
    # Step 2: Update status to downloading
    await _update_job(job_id, status="downloading", started_at=datetime.now(timezone.utc))
    await _publish_upgrade_progress(tenant_id, device_id, job_id, "downloading", target_version, f"Downloading firmware {target_version} for {hostname}")
    # Step 3: Check major version upgrade confirmation
    # Compares the leading dotted component (e.g. "7" in "7.17"); crossing a
    # major boundary requires the operator to have set confirmed_major_upgrade.
    if current_version and target_version:
        current_major = current_version.split(".")[0] if current_version else ""
        target_major = target_version.split(".")[0]
        if current_major != target_major and not confirmed_major:
            await _update_job(
                job_id,
                status="failed",
                error_message="Major version upgrade requires explicit confirmation",
            )
            await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Major version upgrade requires explicit confirmation for {hostname}", error="Major version upgrade requires explicit confirmation")
            return
    # Step 4: Mandatory config backup
    logger.info("Running mandatory pre-upgrade backup for %s", hostname)
    try:
        from app.services import backup_service
        backup_result = await backup_service.run_backup(
            device_id=device_id,
            tenant_id=tenant_id,
            trigger_type="pre-upgrade",
        )
        backup_sha = backup_result["commit_sha"]
        await _update_job(job_id, pre_upgrade_backup_sha=backup_sha)
        logger.info("Pre-upgrade backup complete: %s", backup_sha[:8])
    except Exception as backup_err:
        logger.error("Pre-upgrade backup failed for %s: %s", hostname, backup_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Pre-upgrade backup failed: {backup_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Pre-upgrade backup failed for {hostname}", error=str(backup_err))
        return
    # Step 5: Download NPK
    logger.info("Downloading firmware %s for %s/%s", target_version, architecture, channel)
    try:
        from app.services.firmware_service import download_firmware
        npk_path = await download_firmware(architecture, channel, target_version)
        logger.info("Firmware cached at %s", npk_path)
    except Exception as dl_err:
        logger.error("Firmware download failed: %s", dl_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Firmware download failed: {dl_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Firmware download failed for {hostname}", error=str(dl_err))
        return
    # Step 6: Upload NPK to device via SFTP
    await _update_job(job_id, status="uploading")
    await _publish_upgrade_progress(tenant_id, device_id, job_id, "uploading", target_version, f"Uploading firmware to {hostname}")
    # Decrypt device credentials (dual-read: Transit preferred, legacy fallback)
    if not encrypted_credentials_transit and not encrypted_credentials:
        await _update_job(job_id, status="failed", error_message="Device has no stored credentials")
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"No stored credentials for {hostname}", error="Device has no stored credentials")
        return
    try:
        from app.services.crypto import decrypt_credentials_hybrid
        key = settings.get_encryption_key_bytes()
        creds_json = await decrypt_credentials_hybrid(
            encrypted_credentials_transit, encrypted_credentials, tenant_id, key,
        )
        creds = json.loads(creds_json)
        ssh_username = creds.get("username", "")
        ssh_password = creds.get("password", "")
    except Exception as cred_err:
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Failed to decrypt credentials: {cred_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Failed to decrypt credentials for {hostname}", error=str(cred_err))
        return
    try:
        # NPK image is read fully into memory for a single SFTP write.
        npk_data = Path(npk_path).read_bytes()
        npk_filename = Path(npk_path).name
        async with asyncssh.connect(
            ip_address,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            async with conn.start_sftp_client() as sftp:
                # Written to the device filesystem root — presumably RouterOS
                # installs any *.npk found there on next boot (TODO confirm).
                async with sftp.open(f"/{npk_filename}", "wb") as f:
                    await f.write(npk_data)
        logger.info("Uploaded %s to %s", npk_filename, hostname)
    except Exception as upload_err:
        logger.error("NPK upload failed for %s: %s", hostname, upload_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"NPK upload failed: {upload_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"NPK upload failed for {hostname}", error=str(upload_err))
        return
    # Step 7: Trigger reboot
    await _update_job(job_id, status="rebooting")
    await _publish_upgrade_progress(tenant_id, device_id, job_id, "rebooting", target_version, f"Rebooting {hostname} for firmware install")
    try:
        async with asyncssh.connect(
            ip_address,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            # RouterOS will install NPK on boot
            await conn.run("/system reboot", check=False)
            logger.info("Reboot command sent to %s", hostname)
    except Exception as reboot_err:
        # Device may drop connection during reboot — this is expected
        logger.info("Device %s dropped connection after reboot command (expected): %s", hostname, reboot_err)
    # Step 8: Wait for reconnect
    # Give the device a full boot cycle before the first probe, then poll.
    logger.info("Waiting %ds before polling %s for reconnect", _INITIAL_WAIT, hostname)
    await asyncio.sleep(_INITIAL_WAIT)
    reconnected = False
    elapsed = 0
    while elapsed < _RECONNECT_TIMEOUT:
        if await _check_ssh_reachable(ip_address, ssh_username, ssh_password):
            reconnected = True
            break
        await asyncio.sleep(_RECONNECT_POLL_INTERVAL)
        elapsed += _RECONNECT_POLL_INTERVAL
    if not reconnected:
        logger.error("Device %s did not reconnect within %ds", hostname, _RECONNECT_TIMEOUT)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Device did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes after reboot",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Device {hostname} did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes", error="Reconnect timeout")
        return
    # Step 9: Verify upgrade
    await _update_job(job_id, status="verifying")
    await _publish_upgrade_progress(tenant_id, device_id, job_id, "verifying", target_version, f"Verifying firmware version on {hostname}")
    try:
        actual_version = await _get_device_version(ip_address, ssh_username, ssh_password)
        # Substring match: device reports e.g. "7.17 (stable)" for target "7.17".
        if actual_version and target_version in actual_version:
            logger.info(
                "Firmware upgrade verified for %s: %s",
                hostname, actual_version,
            )
            await _update_job(
                job_id,
                status="completed",
                completed_at=datetime.now(timezone.utc),
            )
            await _publish_upgrade_progress(tenant_id, device_id, job_id, "completed", target_version, f"Firmware upgrade to {target_version} completed on {hostname}")
        else:
            logger.error(
                "Version mismatch for %s: expected %s, got %s",
                hostname, target_version, actual_version,
            )
            await _update_job(
                job_id,
                status="failed",
                error_message=f"Expected {target_version} but got {actual_version}",
            )
            await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Version mismatch on {hostname}: expected {target_version}, got {actual_version}", error=f"Expected {target_version} but got {actual_version}")
    except Exception as verify_err:
        logger.error("Post-upgrade verification failed for %s: %s", hostname, verify_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Post-upgrade verification failed: {verify_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Post-upgrade verification failed for {hostname}", error=str(verify_err))
async def start_mass_upgrade(rollout_group_id: str) -> dict:
    """Execute a sequential mass firmware upgrade for a rollout group.

    Jobs are processed strictly one at a time in creation order. When a
    device fails, processing stops and every remaining pending/scheduled job
    in the group is flipped to 'paused' so an operator can investigate
    (see resume_mass_upgrade / abort_mass_upgrade).

    Args:
        rollout_group_id: UUID string identifying the rollout group.

    Returns:
        Summary dict: ``completed`` count, ``failed`` (0 or 1),
        ``failed_device`` hostname (absent when no jobs exist), and
        ``paused`` count.
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id, j.status, d.hostname
                FROM firmware_upgrade_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.rollout_group_id = CAST(:group_id AS uuid)
                ORDER BY j.created_at ASC
            """),
            {"group_id": rollout_group_id},
        )
        jobs = result.fetchall()
    if not jobs:
        logger.warning("No jobs found for rollout group %s", rollout_group_id)
        return {"completed": 0, "failed": 0, "paused": 0}
    completed = 0
    failed_device = None
    for job_id, current_status, hostname in jobs:
        job_id_str = str(job_id)
        # Only process pending/scheduled jobs
        if current_status not in ("pending", "scheduled"):
            # Jobs finished in an earlier run still count toward the summary.
            if current_status == "completed":
                completed += 1
            continue
        logger.info("Mass rollout: upgrading device %s (job %s)", hostname, job_id_str)
        await start_upgrade(job_id_str)
        # Check if it completed or failed
        # (start_upgrade never raises — the outcome lives in the job row.)
        async with AdminAsyncSessionLocal() as session:
            result = await session.execute(
                text("SELECT status FROM firmware_upgrade_jobs WHERE id = CAST(:id AS uuid)"),
                {"id": job_id_str},
            )
            row = result.fetchone()
            if row and row[0] == "completed":
                completed += 1
            elif row and row[0] == "failed":
                failed_device = hostname
                logger.error("Mass rollout paused: %s failed", hostname)
                break
    # Pause remaining jobs if one failed
    paused = 0
    if failed_device:
        async with AdminAsyncSessionLocal() as session:
            result = await session.execute(
                text("""
                    UPDATE firmware_upgrade_jobs
                    SET status = 'paused'
                    WHERE rollout_group_id = CAST(:group_id AS uuid)
                    AND status IN ('pending', 'scheduled')
                    RETURNING id
                """),
                {"group_id": rollout_group_id},
            )
            paused = len(result.fetchall())
            await session.commit()
    return {
        "completed": completed,
        "failed": 1 if failed_device else 0,
        "failed_device": failed_device,
        "paused": paused,
    }
def schedule_upgrade(job_id: str, scheduled_at: datetime) -> None:
    """Register a one-shot APScheduler job that runs the upgrade at *scheduled_at*.

    Uses the shared backup scheduler instance; re-scheduling the same job id
    replaces the previous entry.
    """
    from app.services.backup_scheduler import backup_scheduler

    scheduler_id = f"fw_upgrade_{job_id}"
    backup_scheduler.add_job(
        start_upgrade,
        trigger="date",
        run_date=scheduled_at,
        args=[job_id],
        id=scheduler_id,
        name=f"Firmware upgrade {job_id}",
        max_instances=1,
        replace_existing=True,
    )
    logger.info("Scheduled firmware upgrade %s for %s", job_id, scheduled_at)
def schedule_mass_upgrade(rollout_group_id: str, scheduled_at: datetime) -> None:
    """Register a one-shot APScheduler job that starts the mass rollout later.

    Uses the shared backup scheduler instance; re-scheduling the same rollout
    group replaces the previous entry.
    """
    from app.services.backup_scheduler import backup_scheduler

    scheduler_id = f"fw_mass_upgrade_{rollout_group_id}"
    backup_scheduler.add_job(
        start_mass_upgrade,
        trigger="date",
        run_date=scheduled_at,
        args=[rollout_group_id],
        id=scheduler_id,
        name=f"Mass firmware upgrade {rollout_group_id}",
        max_instances=1,
        replace_existing=True,
    )
    logger.info("Scheduled mass firmware upgrade %s for %s", rollout_group_id, scheduled_at)
async def cancel_upgrade(job_id: str) -> None:
    """Cancel a scheduled or pending upgrade job.

    Drops the APScheduler entry when one exists, then marks the job row as
    failed with an operator-cancelled message.
    """
    from app.services.backup_scheduler import backup_scheduler

    scheduler_id = f"fw_upgrade_{job_id}"
    try:
        backup_scheduler.remove_job(scheduler_id)
    except Exception:
        # No scheduler entry under that id — the job was merely pending.
        pass
    await _update_job(
        job_id,
        status="failed",
        error_message="Cancelled by operator",
        completed_at=datetime.now(timezone.utc),
    )
    logger.info("Upgrade job %s cancelled", job_id)
async def retry_failed_upgrade(job_id: str) -> None:
    """Reset a failed upgrade job back to pending and launch a new attempt.

    Clears the error and timing fields, then fires start_upgrade as a
    background task.
    """
    await _update_job(
        job_id,
        status="pending",
        error_message=None,
        started_at=None,
        completed_at=None,
    )
    logger.info("Retrying upgrade job %s", job_id)
    asyncio.create_task(start_upgrade(job_id))
async def resume_mass_upgrade(rollout_group_id: str) -> None:
    """Resume a paused mass rollout from where it left off.

    Flips every ``paused`` job in the rollout group back to ``pending``,
    commits, then restarts the sequential processor as a background task.

    Args:
        rollout_group_id: UUID string identifying the rollout group.
    """
    async with AdminAsyncSessionLocal() as session:
        # Result intentionally discarded — only the status flip matters.
        await session.execute(
            text("""
                UPDATE firmware_upgrade_jobs
                SET status = 'pending'
                WHERE rollout_group_id = CAST(:group_id AS uuid)
                AND status = 'paused'
            """),
            {"group_id": rollout_group_id},
        )
        await session.commit()
    asyncio.create_task(start_mass_upgrade(rollout_group_id))
    logger.info("Resuming mass rollout %s", rollout_group_id)
async def abort_mass_upgrade(rollout_group_id: str) -> int:
    """Mark every not-yet-finished job in a paused rollout as failed.

    Returns:
        Number of jobs that were aborted.
    """
    async with AdminAsyncSessionLocal() as session:
        outcome = await session.execute(
            text("""
                UPDATE firmware_upgrade_jobs
                SET status = 'failed',
                    error_message = 'Aborted by operator',
                    completed_at = NOW()
                WHERE rollout_group_id = CAST(:group_id AS uuid)
                AND status IN ('pending', 'scheduled', 'paused')
                RETURNING id
            """),
            {"group_id": rollout_group_id},
        )
        aborted_count = len(outcome.fetchall())
        await session.commit()
    logger.info("Aborted %d remaining jobs in rollout %s", aborted_count, rollout_group_id)
    return aborted_count
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
async def _update_job(job_id: str, **kwargs) -> None:
    """Persist field updates on a FirmwareUpgradeJob row via raw SQL.

    Column names come from the kwargs keys and are interpolated into the
    statement text, so only trusted internal callers may use this helper;
    values are always bound as parameters. ``None`` for a known-nullable
    column emits a literal NULL assignment.
    """
    nullable_columns = ("error_message", "started_at", "completed_at")
    assignments: list[str] = []
    params: dict = {"job_id": job_id}
    for column, value in kwargs.items():
        if value is None and column in nullable_columns:
            assignments.append(f"{column} = NULL")
            continue
        placeholder = f"v_{column}"
        assignments.append(f"{column} = :{placeholder}")
        params[placeholder] = value
    if not assignments:
        return
    async with AdminAsyncSessionLocal() as session:
        await session.execute(
            text(f"""
                UPDATE firmware_upgrade_jobs
                SET {', '.join(assignments)}
                WHERE id = CAST(:job_id AS uuid)
            """),
            params,
        )
        await session.commit()
async def _check_ssh_reachable(ip: str, username: str, password: str) -> bool:
    """Return True when an SSH session to *ip* can run a trivial RouterOS command.

    Uses a short 15-second connect timeout since this is called repeatedly
    while polling for a device to come back after reboot.
    """
    try:
        async with asyncssh.connect(
            ip,
            port=22,
            username=username,
            password=password,
            known_hosts=None,
            connect_timeout=15,
        ) as conn:
            await conn.run("/system identity print", check=True)
    except Exception:
        return False
    return True
async def _get_device_version(ip: str, username: str, password: str) -> str:
    """Fetch the RouterOS version string reported by a device.

    Runs ``/system resource print`` over SSH and returns the value of the
    first output line containing "version" (e.g. "7.17 (stable)"), or ""
    when no such line is present.
    """
    async with asyncssh.connect(
        ip,
        port=22,
        username=username,
        password=password,
        known_hosts=None,
        connect_timeout=30,
    ) as conn:
        output = await conn.run("/system resource print", check=True)
        for raw_line in output.stdout.splitlines():
            if "version" not in raw_line.lower():
                continue
            _, colon, value = raw_line.partition(":")
            if colon:
                return value.strip()
    return ""

View File

@@ -0,0 +1,392 @@
"""WireGuard VPN management service.
Handles key generation, peer management, config file sync, and RouterOS command generation.
"""
import base64
import ipaddress
import json
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
import structlog
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
from cryptography.hazmat.primitives.serialization import (
Encoding,
NoEncryption,
PrivateFormat,
PublicFormat,
)
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.models.device import Device
from app.models.vpn import VpnConfig, VpnPeer
from app.services.crypto import decrypt_credentials, encrypt_credentials, encrypt_credentials_transit
logger = structlog.get_logger(__name__)
# ── Key Generation ──
def generate_wireguard_keypair() -> tuple[str, str]:
    """Create a fresh WireGuard X25519 keypair.

    Returns:
        Tuple of (private_key_b64, public_key_b64), both base64-encoded raw
        32-byte keys as WireGuard expects.
    """
    secret = X25519PrivateKey.generate()
    raw_private = secret.private_bytes(Encoding.Raw, PrivateFormat.Raw, NoEncryption())
    raw_public = secret.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw)
    b64 = base64.b64encode
    return b64(raw_private).decode(), b64(raw_public).decode()
def generate_preshared_key() -> str:
    """Return a new WireGuard preshared key: 32 CSPRNG bytes, base64-encoded."""
    raw = os.urandom(32)
    return base64.b64encode(raw).decode()
# ── Config File Management ──
def _get_wg_config_path() -> Path:
    """Directory shared with the WireGuard container (env-overridable).

    Honors the WIREGUARD_CONFIG_PATH environment variable, defaulting to
    /data/wireguard.
    """
    configured = os.getenv("WIREGUARD_CONFIG_PATH", "/data/wireguard")
    return Path(configured)
async def sync_wireguard_config(db: AsyncSession, tenant_id: uuid.UUID) -> None:
    """Regenerate wg0.conf from database state and write to shared volume.

    No-ops when the tenant has no VPN config or VPN is disabled. Builds one
    [Peer] section per enabled peer, writes the file under
    ``<config_path>/wg_confs/wg0.conf``, then drops a ``.reload`` flag file
    to signal the WireGuard container to re-apply the config.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config or not config.is_enabled:
        return
    key_bytes = settings.get_encryption_key_bytes()
    # Server private key is stored encrypted at rest; decrypted only to
    # render the config file.
    server_private_key = decrypt_credentials(config.server_private_key, key_bytes)
    result = await db.execute(
        select(VpnPeer).where(VpnPeer.tenant_id == tenant_id, VpnPeer.is_enabled.is_(True))
    )
    peers = result.scalars().all()
    # Build wg0.conf
    lines = [
        "[Interface]",
        f"Address = {config.server_address}",
        f"ListenPort = {config.server_port}",
        f"PrivateKey = {server_private_key}",
        "",
    ]
    for peer in peers:
        peer_ip = peer.assigned_ip.split("/")[0]  # strip CIDR for AllowedIPs
        # Peer may only source traffic from its own /32 tunnel address...
        allowed_ips = [f"{peer_ip}/32"]
        if peer.additional_allowed_ips:
            # Comma-separated additional subnets (e.g. site-to-site routing)
            extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()]
            allowed_ips.extend(extra)
        lines.append("[Peer]")
        lines.append(f"PublicKey = {peer.peer_public_key}")
        if peer.preshared_key:
            psk = decrypt_credentials(peer.preshared_key, key_bytes)
            lines.append(f"PresharedKey = {psk}")
        lines.append(f"AllowedIPs = {', '.join(allowed_ips)}")
        lines.append("")
    config_dir = _get_wg_config_path()
    wg_confs_dir = config_dir / "wg_confs"
    wg_confs_dir.mkdir(parents=True, exist_ok=True)
    conf_path = wg_confs_dir / "wg0.conf"
    # Plaintext private keys land on disk here; the volume is shared with
    # the WireGuard container.
    conf_path.write_text("\n".join(lines))
    # Signal WireGuard container to reload config
    reload_flag = wg_confs_dir / ".reload"
    reload_flag.write_text("1")
    logger.info("wireguard config synced", tenant_id=str(tenant_id), peers=len(peers))
# ── Live Status ──
def read_wg_status() -> dict[str, dict]:
    """Load the live peer-status snapshot written by the WireGuard container.

    The container refreshes ``wg_status.json`` (derived from
    ``wg show wg0 dump``) roughly every 15 seconds. Returns a mapping of
    peer public key -> status entry, or {} when the file is missing or
    cannot be parsed.
    """
    status_file = _get_wg_config_path() / "wg_status.json"
    if not status_file.exists():
        return {}
    try:
        entries = json.loads(status_file.read_text())
        return {item["public_key"]: item for item in entries}
    except (json.JSONDecodeError, KeyError, OSError):
        return {}
def get_peer_handshake(wg_status: dict[str, dict], public_key: str) -> Optional[datetime]:
    """Return the last-handshake time (UTC) for *public_key*, or None.

    None is returned when the peer is absent from the status snapshot or
    has never completed a handshake (timestamp missing or <= 0).
    """
    record = wg_status.get(public_key)
    if not record:
        return None
    epoch_seconds = record.get("last_handshake", 0)
    if not epoch_seconds or epoch_seconds <= 0:
        return None
    return datetime.fromtimestamp(epoch_seconds, tz=timezone.utc)
# ── CRUD Operations ──
async def get_vpn_config(db: AsyncSession, tenant_id: uuid.UUID) -> Optional[VpnConfig]:
    """Fetch the tenant's VpnConfig row, or None when VPN was never set up."""
    query = select(VpnConfig).where(VpnConfig.tenant_id == tenant_id)
    return (await db.execute(query)).scalar_one_or_none()
async def setup_vpn(
    db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None
) -> VpnConfig:
    """Bootstrap WireGuard for a tenant: generate server keys, persist config.

    The server private key is encrypted before storage; the initial config
    is enabled immediately and wg0.conf is written out.

    Raises:
        ValueError: when the tenant already has a VPN config.
    """
    if await get_vpn_config(db, tenant_id) is not None:
        raise ValueError("VPN already configured for this tenant")
    priv_b64, pub_b64 = generate_wireguard_keypair()
    enc_key = settings.get_encryption_key_bytes()
    vpn_config = VpnConfig(
        tenant_id=tenant_id,
        server_private_key=encrypt_credentials(priv_b64, enc_key),
        server_public_key=pub_b64,
        endpoint=endpoint,
        is_enabled=True,
    )
    db.add(vpn_config)
    await db.flush()
    await sync_wireguard_config(db, tenant_id)
    return vpn_config
async def update_vpn_config(
    db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None, is_enabled: Optional[bool] = None
) -> VpnConfig:
    """Apply partial updates (endpoint and/or enabled flag) to a tenant's VPN config.

    Fields left as None are untouched. The wg0.conf on disk is re-synced
    after the update.

    Raises:
        ValueError: when the tenant has no VPN config yet.
    """
    vpn_config = await get_vpn_config(db, tenant_id)
    if vpn_config is None:
        raise ValueError("VPN not configured for this tenant")
    if endpoint is not None:
        vpn_config.endpoint = endpoint
    if is_enabled is not None:
        vpn_config.is_enabled = is_enabled
    await db.flush()
    await sync_wireguard_config(db, tenant_id)
    return vpn_config
async def get_peers(db: AsyncSession, tenant_id: uuid.UUID) -> list[VpnPeer]:
    """Return the tenant's VPN peers, oldest first."""
    stmt = (
        select(VpnPeer)
        .where(VpnPeer.tenant_id == tenant_id)
        .order_by(VpnPeer.created_at)
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
async def _next_available_ip(db: AsyncSession, tenant_id: uuid.UUID, config: VpnConfig) -> str:
    """Allocate the lowest free host address in the tenant's VPN subnet.

    Skips the first host (reserved for the server) and every address already
    assigned to a peer or to the server itself.

    Args:
        db: Tenant-scoped session.
        tenant_id: Tenant whose peers are consulted.
        config: The tenant's VpnConfig (provides subnet and server address).

    Returns:
        The address as "a.b.c.d/<prefixlen>" using the subnet's real prefix.

    Raises:
        ValueError: when the subnet has no free host addresses left.
    """
    # Parse subnet: e.g. "10.10.0.0/24" → start from .2 (server is .1)
    network = ipaddress.ip_network(config.subnet, strict=False)
    hosts = list(network.hosts())
    # Get already assigned IPs
    result = await db.execute(select(VpnPeer.assigned_ip).where(VpnPeer.tenant_id == tenant_id))
    used_ips = {row[0].split("/")[0] for row in result.all()}
    used_ips.add(config.server_address.split("/")[0])  # exclude server IP
    for host in hosts[1:]:  # skip .1 (server)
        if str(host) not in used_ips:
            # Bug fix: use the subnet's actual prefix length rather than a
            # hard-coded /24, so non-/24 VPN subnets get correct CIDR suffixes.
            return f"{host}/{network.prefixlen}"
    raise ValueError("No available IPs in VPN subnet")
async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID, additional_allowed_ips: Optional[str] = None) -> VpnPeer:
    """Enroll an existing device as a WireGuard peer.

    Generates a dedicated keypair and preshared key (both encrypted before
    storage), allocates the next free tunnel IP, persists the peer, and
    rewrites wg0.conf.

    Raises:
        ValueError: when VPN is not configured, the device does not exist,
            or the device is already enrolled as a peer.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config:
        raise ValueError("VPN not configured — enable VPN first")
    # Check device exists
    device_lookup = await db.execute(select(Device).where(Device.id == device_id, Device.tenant_id == tenant_id))
    if not device_lookup.scalar_one_or_none():
        raise ValueError("Device not found")
    # Check if already a peer
    duplicate = await db.execute(select(VpnPeer).where(VpnPeer.device_id == device_id))
    if duplicate.scalar_one_or_none():
        raise ValueError("Device is already a VPN peer")
    priv_b64, pub_b64 = generate_wireguard_keypair()
    fresh_psk = generate_preshared_key()
    enc_key = settings.get_encryption_key_bytes()
    tunnel_ip = await _next_available_ip(db, tenant_id, config)
    peer = VpnPeer(
        tenant_id=tenant_id,
        device_id=device_id,
        peer_private_key=encrypt_credentials(priv_b64, enc_key),
        peer_public_key=pub_b64,
        preshared_key=encrypt_credentials(fresh_psk, enc_key),
        assigned_ip=tunnel_ip,
        additional_allowed_ips=additional_allowed_ips,
    )
    db.add(peer)
    await db.flush()
    await sync_wireguard_config(db, tenant_id)
    return peer
async def remove_peer(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> None:
    """Delete a VPN peer and push the updated wg0.conf.

    Raises:
        ValueError: when no peer with *peer_id* belongs to the tenant.
    """
    lookup = await db.execute(
        select(VpnPeer).where(VpnPeer.id == peer_id, VpnPeer.tenant_id == tenant_id)
    )
    target = lookup.scalar_one_or_none()
    if target is None:
        raise ValueError("Peer not found")
    await db.delete(target)
    await db.flush()
    await sync_wireguard_config(db, tenant_id)
async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> dict:
    """Get the full config for a peer — includes the decrypted private key.

    Intended for one-time device setup: the response contains plaintext key
    material plus ready-to-paste RouterOS commands.

    Args:
        db: Tenant-scoped session.
        tenant_id: Tenant owning the peer.
        peer_id: Peer row id.

    Returns:
        Dict with key material, addressing, endpoint, and RouterOS commands.

    Raises:
        ValueError: when VPN is not configured or the peer does not exist.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config:
        raise ValueError("VPN not configured")
    result = await db.execute(
        select(VpnPeer).where(VpnPeer.id == peer_id, VpnPeer.tenant_id == tenant_id)
    )
    peer = result.scalar_one_or_none()
    if not peer:
        raise ValueError("Peer not found")
    key_bytes = settings.get_encryption_key_bytes()
    private_key = decrypt_credentials(peer.peer_private_key, key_bytes)
    psk = decrypt_credentials(peer.preshared_key, key_bytes) if peer.preshared_key else None
    # Placeholder keeps the generated commands usable as a template even
    # before the operator has configured a public endpoint.
    endpoint = config.endpoint or "YOUR_SERVER_IP:51820"
    # NOTE(review): endpoint.split(":") assumes "host:port" — an IPv6 literal
    # endpoint would need bracket-aware parsing; confirm endpoints are IPv4/DNS.
    routeros_commands = [
        f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key}"',
        f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" '
        f'endpoint-address={endpoint.split(":")[0]} endpoint-port={endpoint.split(":")[-1]} '
        f'allowed-address={config.subnet} persistent-keepalive=25'
        + (f' preshared-key="{psk}"' if psk else ""),
        f"/ip address add address={peer.assigned_ip} interface=wg-portal",
    ]
    return {
        "peer_private_key": private_key,
        "peer_public_key": peer.peer_public_key,
        "assigned_ip": peer.assigned_ip,
        "server_public_key": config.server_public_key,
        "server_endpoint": endpoint,
        "allowed_ips": config.subnet,
        "routeros_commands": routeros_commands,
    }
async def onboard_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    hostname: str,
    username: str,
    password: str,
) -> dict:
    """Create device + VPN peer in one transaction.

    Unlike regular device creation, this skips TCP connectivity checks
    because the VPN tunnel isn't up yet. The device IP is set to the
    VPN-assigned tunnel address, so the portal manages the device over
    the tunnel once it comes up.

    Args:
        db: Tenant-scoped session.
        tenant_id: Tenant to onboard the device into.
        hostname: Display name for the new device.
        username: RouterOS credential, stored encrypted via Transit.
        password: RouterOS credential, stored encrypted via Transit.

    Returns:
        Dict with device_id, peer_id, hostname, assigned_ip, and the
        RouterOS commands to paste on the device.

    Raises:
        ValueError: when the tenant has no VPN configured, or no free
            tunnel IP remains.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config:
        raise ValueError("VPN not configured — enable VPN first")
    # Allocate VPN IP before creating device
    assigned_ip = await _next_available_ip(db, tenant_id, config)
    vpn_ip_no_cidr = assigned_ip.split("/")[0]
    # Create device with VPN IP (skip TCP check — tunnel not up yet)
    credentials_json = json.dumps({"username": username, "password": password})
    transit_ciphertext = await encrypt_credentials_transit(credentials_json, str(tenant_id))
    device = Device(
        tenant_id=tenant_id,
        hostname=hostname,
        ip_address=vpn_ip_no_cidr,
        api_port=8728,
        api_ssl_port=8729,
        encrypted_credentials_transit=transit_ciphertext,
        status="unknown",
    )
    db.add(device)
    # Flush to obtain device.id before linking the peer row to it.
    await db.flush()
    # Create VPN peer linked to this device
    private_key_b64, public_key_b64 = generate_wireguard_keypair()
    psk = generate_preshared_key()
    key_bytes = settings.get_encryption_key_bytes()
    encrypted_private = encrypt_credentials(private_key_b64, key_bytes)
    encrypted_psk = encrypt_credentials(psk, key_bytes)
    peer = VpnPeer(
        tenant_id=tenant_id,
        device_id=device.id,
        peer_private_key=encrypted_private,
        peer_public_key=public_key_b64,
        preshared_key=encrypted_psk,
        assigned_ip=assigned_ip,
    )
    db.add(peer)
    await db.flush()
    await sync_wireguard_config(db, tenant_id)
    # Generate RouterOS commands
    # NOTE(review): endpoint.split(":") assumes "host:port"; IPv6 literal
    # endpoints would need bracket-aware parsing — confirm endpoints are IPv4/DNS.
    endpoint = config.endpoint or "YOUR_SERVER_IP:51820"
    psk_decrypted = decrypt_credentials(encrypted_psk, key_bytes)
    routeros_commands = [
        f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key_b64}"',
        f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" '
        f'endpoint-address={endpoint.split(":")[0]} endpoint-port={endpoint.split(":")[-1]} '
        f'allowed-address={config.subnet} persistent-keepalive=25'
        f' preshared-key="{psk_decrypted}"',
        f"/ip address add address={assigned_ip} interface=wg-portal",
    ]
    return {
        "device_id": device.id,
        "peer_id": peer.id,
        "hostname": hostname,
        "assigned_ip": assigned_ip,
        "routeros_commands": routeros_commands,
    }

View File

@@ -0,0 +1,66 @@
{% extends "reports/base.html" %}
{# Alert history report body: a summary row (total alerts, counts by
   severity, optional average MTTR) followed by a table of alert events
   for the selected date range. Expects: date_from, date_to, total_alerts,
   critical_count, warning_count, info_count, mttr_minutes/mttr_display,
   and a list of alert rows. #}
{% block content %}
<div class="date-range">
    Report period: {{ date_from }} to {{ date_to }}
</div>
<div class="summary-box">
    <div class="summary-stat">
        <div class="value">{{ total_alerts }}</div>
        <div class="label">Total Alerts</div>
    </div>
    <div class="summary-stat">
        <div class="value" style="color: #991B1B;">{{ critical_count }}</div>
        <div class="label">Critical</div>
    </div>
    <div class="summary-stat">
        <div class="value" style="color: #92400E;">{{ warning_count }}</div>
        <div class="label">Warning</div>
    </div>
    <div class="summary-stat">
        <div class="value" style="color: #1E40AF;">{{ info_count }}</div>
        <div class="label">Info</div>
    </div>
    {# MTTR stat rendered only when resolution data exists for the period #}
    {% if mttr_minutes is not none %}
    <div class="summary-stat">
        <div class="value">{{ mttr_display }}</div>
        <div class="label">Avg MTTR</div>
    </div>
    {% endif %}
</div>
{% if alerts %}
<h2 class="section-title">Alert Events</h2>
<table>
    <thead>
        <tr>
            <th>Timestamp</th>
            <th>Device</th>
            <th>Severity</th>
            <th>Message</th>
            <th>Status</th>
            <th>Duration</th>
        </tr>
    </thead>
    <tbody>
        {% for alert in alerts %}
        <tr>
            <td>{{ alert.fired_at }}</td>
            <td style="font-weight: 600;">{{ alert.hostname or '-' }}</td>
            <td>
                <span class="badge badge-{{ alert.severity }}">{{ alert.severity | upper }}</span>
            </td>
            <td>{{ alert.message or '-' }}</td>
            <td>
                <span class="badge badge-{{ alert.status }}">{{ alert.status | upper }}</span>
            </td>
            <td>{{ alert.duration or '-' }}</td>
        </tr>
        {% endfor %}
    </tbody>
</table>
{% else %}
<div class="no-data">No alerts found for the selected date range.</div>
{% endif %}
{% endblock %}

View File

@@ -0,0 +1,208 @@
{# Base layout for all TOD PDF reports.
   Child templates extend this and fill the `content` block; they may
   also override `page_size` (default A4 portrait) and `extra_styles`.
   The @page at-rule (running page footer with page counters) is CSS
   Paged Media — targeted at a print/PDF renderer, not a browser.
   Context variables used: report_title, tenant_name, generated_at. #}
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>{{ report_title }} - TOD</title>
    <style>
        {# Page geometry and running footer for the PDF engine.
           Bottom margin (25mm) leaves room for the footer content. #}
        @page {
            size: {% block page_size %}A4 portrait{% endblock %};
            margin: 20mm 15mm 25mm 15mm;
            @bottom-center {
                content: "Page " counter(page) " of " counter(pages);
                font-size: 9px;
                color: #64748B;
            }
            @bottom-right {
                content: "Generated by TOD";
                font-size: 9px;
                color: #64748B;
            }
        }
        * {
            margin: 0;
            padding: 0;
            box-sizing: border-box;
        }
        body {
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
            font-size: 11px;
            color: #1E293B;
            line-height: 1.5;
        }
        {# Dark full-bleed header band: negative margins cancel the
           @page side margins, then padding restores the 15mm gutters
           so header text aligns with the body content. #}
        .report-header {
            background: #1E293B;
            color: #FFFFFF;
            padding: 16px 20px;
            margin: -20mm -15mm 20px -15mm;
            padding-left: 15mm;
            padding-right: 15mm;
            display: flex;
            justify-content: space-between;
            align-items: center;
        }
        .report-header .brand {
            display: flex;
            align-items: center;
            gap: 10px;
        }
        .report-header .brand .logo {
            width: 28px;
            height: 28px;
            background: #38BDF8;
            border-radius: 6px;
            display: flex;
            align-items: center;
            justify-content: center;
            font-weight: 700;
            font-size: 16px;
            color: #0F172A;
        }
        .report-header .brand .name {
            font-size: 14px;
            font-weight: 600;
        }
        .report-header .meta {
            text-align: right;
            font-size: 10px;
            color: #94A3B8;
        }
        .report-header .meta .title {
            font-size: 16px;
            font-weight: 600;
            color: #FFFFFF;
            margin-bottom: 2px;
        }
        .report-body {
            padding: 0;
        }
        .section-title {
            font-size: 13px;
            font-weight: 600;
            color: #0F172A;
            border-bottom: 2px solid #38BDF8;
            padding-bottom: 6px;
            margin-bottom: 12px;
            margin-top: 20px;
        }
        .section-title:first-child {
            margin-top: 0;
        }
        {# Shared table styling used by every child report's data table #}
        table {
            width: 100%;
            border-collapse: collapse;
            margin-bottom: 16px;
            font-size: 10px;
        }
        thead th {
            background: #F1F5F9;
            color: #475569;
            font-weight: 600;
            text-align: left;
            padding: 8px 10px;
            border-bottom: 2px solid #E2E8F0;
            font-size: 9px;
            text-transform: uppercase;
            letter-spacing: 0.5px;
        }
        tbody td {
            padding: 7px 10px;
            border-bottom: 1px solid #F1F5F9;
            vertical-align: top;
        }
        tbody tr:nth-child(even) {
            background: #FAFBFC;
        }
        .summary-box {
            background: #F8FAFC;
            border: 1px solid #E2E8F0;
            border-radius: 6px;
            padding: 12px 16px;
            margin-bottom: 16px;
            display: flex;
            gap: 24px;
        }
        .summary-stat {
            text-align: center;
        }
        .summary-stat .value {
            font-size: 20px;
            font-weight: 700;
            color: #0F172A;
        }
        .summary-stat .label {
            font-size: 9px;
            color: #64748B;
            text-transform: uppercase;
            letter-spacing: 0.5px;
        }
        {# Badge variants: child templates build class names dynamically
           as badge-{{ value }}, so the suffixes below are the complete
           set of expected status/severity values. #}
        .badge {
            display: inline-block;
            padding: 2px 8px;
            border-radius: 10px;
            font-size: 9px;
            font-weight: 600;
        }
        .badge-online { background: #DCFCE7; color: #166534; }
        .badge-offline { background: #FEE2E2; color: #991B1B; }
        .badge-unknown { background: #F1F5F9; color: #475569; }
        .badge-critical { background: #FEE2E2; color: #991B1B; }
        .badge-warning { background: #FEF3C7; color: #92400E; }
        .badge-info { background: #DBEAFE; color: #1E40AF; }
        .badge-firing { background: #FEE2E2; color: #991B1B; }
        .badge-resolved { background: #DCFCE7; color: #166534; }
        .badge-acknowledged { background: #DBEAFE; color: #1E40AF; }
        .date-range {
            font-size: 11px;
            color: #64748B;
            margin-bottom: 12px;
        }
        .no-data {
            text-align: center;
            padding: 40px 20px;
            color: #94A3B8;
            font-size: 12px;
        }
        {# Hook for child templates to inject report-specific CSS #}
        {% block extra_styles %}{% endblock %}
    </style>
</head>
<body>
    <div class="report-header">
        <div class="brand">
            {# NOTE(review): logo glyph "M" does not match the "TOD"
               brand name beside it — possibly a leftover from an
               earlier product name; confirm intended before changing. #}
            <div class="logo">M</div>
            <div class="name">TOD - The Other Dude</div>
        </div>
        <div class="meta">
            <div class="title">{{ report_title }}</div>
            <div>{{ tenant_name }} &bull; Generated {{ generated_at }}</div>
        </div>
    </div>
    <div class="report-body">
        {% block content %}{% endblock %}
    </div>
</body>
</html>

View File

@@ -0,0 +1,46 @@
{% extends "reports/base.html" %}
{# Change-log report template.
   Renders: the reporting period, a two-stat summary (total change count
   and the data source label), and a table of individual change entries.
   Context variables used: date_from, date_to, total_entries,
   data_source, entries (iterable of objects with timestamp, user,
   action, device, details). #}
{% block content %}
<div class="date-range">
    Report period: {{ date_from }} to {{ date_to }}
</div>
<div class="summary-box">
    <div class="summary-stat">
        <div class="value">{{ total_entries }}</div>
        <div class="label">Total Changes</div>
    </div>
    <div class="summary-stat">
        {# data_source is a caller-supplied display label; presumably it
           names where the change log was pulled from — confirm values
           against the report service. #}
        <div class="value">{{ data_source }}</div>
        <div class="label">Data Source</div>
    </div>
</div>
{% if entries %}
<h2 class="section-title">Change Log</h2>
<table>
    <thead>
        <tr>
            <th>Timestamp</th>
            <th>User</th>
            <th>Action</th>
            <th>Device</th>
            <th>Details</th>
        </tr>
    </thead>
    <tbody>
        {# One row per change; optional fields fall back to '-' #}
        {% for entry in entries %}
        <tr>
            <td>{{ entry.timestamp }}</td>
            <td>{{ entry.user or '-' }}</td>
            <td>{{ entry.action }}</td>
            <td style="font-weight: 600;">{{ entry.device or '-' }}</td>
            <td>{{ entry.details or '-' }}</td>
        </tr>
        {% endfor %}
    </tbody>
</table>
{% else %}
<div class="no-data">No changes found for the selected date range.</div>
{% endif %}
{% endblock %}

View File

@@ -0,0 +1,59 @@
{% extends "reports/base.html" %}
{# Device status (inventory) report template.
   Renders: a summary strip of device counts by status and a wide
   per-device inventory table. Overrides page_size to landscape to fit
   the eight-column table.
   Context variables used: total_devices, online_count, offline_count,
   unknown_count, devices (iterable of objects with hostname,
   ip_address, model, routeros_version, status, last_seen, uptime,
   groups). #}
{% block page_size %}A4 landscape{% endblock %}
{% block content %}
{# Headline counts; inline colors mirror the badge palette defined in base.html #}
<div class="summary-box">
    <div class="summary-stat">
        <div class="value">{{ total_devices }}</div>
        <div class="label">Total Devices</div>
    </div>
    <div class="summary-stat">
        <div class="value" style="color: #166534;">{{ online_count }}</div>
        <div class="label">Online</div>
    </div>
    <div class="summary-stat">
        <div class="value" style="color: #991B1B;">{{ offline_count }}</div>
        <div class="label">Offline</div>
    </div>
    <div class="summary-stat">
        <div class="value" style="color: #475569;">{{ unknown_count }}</div>
        <div class="label">Unknown</div>
    </div>
</div>
{% if devices %}
<table>
    <thead>
        <tr>
            <th>Hostname</th>
            <th>IP Address</th>
            <th>Model</th>
            <th>RouterOS</th>
            <th>Status</th>
            <th>Last Seen</th>
            <th>Uptime</th>
            <th>Groups</th>
        </tr>
    </thead>
    <tbody>
        {# One row per device; status feeds badge-{{ status }}, so values
           must be among online/offline/unknown declared in base.html. #}
        {% for device in devices %}
        <tr>
            <td style="font-weight: 600;">{{ device.hostname }}</td>
            <td>{{ device.ip_address }}</td>
            <td>{{ device.model or '-' }}</td>
            <td>{{ device.routeros_version or '-' }}</td>
            <td>
                <span class="badge badge-{{ device.status }}">{{ device.status | upper }}</span>
            </td>
            <td>{{ device.last_seen or '-' }}</td>
            <td>{{ device.uptime or '-' }}</td>
            <td>{{ device.groups or '-' }}</td>
        </tr>
        {% endfor %}
    </tbody>
</table>
{% else %}
<div class="no-data">No devices found for this tenant.</div>
{% endif %}
{% endblock %}

Some files were not shown because too many files have changed in this diff Show More