feat: The Other Dude v9.0.1 — full-featured email system
ci: add GitHub Pages deployment workflow for docs site Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
1
backend/app/__init__.py
Normal file
1
backend/app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# TOD Backend
|
||||
177
backend/app/config.py
Normal file
177
backend/app/config.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""Application configuration using Pydantic Settings."""
|
||||
|
||||
import base64
|
||||
import sys
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
# Insecure placeholder secrets that ship with the repo for local development.
# Startup in any environment other than "dev" refuses these values outright.
KNOWN_INSECURE_DEFAULTS: dict[str, list[str]] = {
    "JWT_SECRET_KEY": [
        "change-this-in-production-use-a-long-random-string",
        "dev-jwt-secret-change-in-production",
        "CHANGE_ME_IN_PRODUCTION",
    ],
    "CREDENTIAL_ENCRYPTION_KEY": [
        "LLLjnfBZTSycvL2U07HDSxUeTtLxb9cZzryQl0R9E4w=",
        "CHANGE_ME_IN_PRODUCTION",
    ],
    "OPENBAO_TOKEN": [
        "dev-openbao-token",
        "CHANGE_ME_IN_PRODUCTION",
    ],
}


def validate_production_settings(settings: "Settings") -> None:
    """Refuse to run outside dev with any known insecure default secret.

    Called during app startup. On the first offending field, prints a
    remediation message to stderr and exits the process with code 1.
    """
    if settings.ENVIRONMENT == "dev":
        return

    offending = next(
        (
            name
            for name, bad_values in KNOWN_INSECURE_DEFAULTS.items()
            if getattr(settings, name, None) in bad_values
        ),
        None,
    )
    if offending is None:
        return

    print(
        f"FATAL: {offending} uses a known insecure default in '{settings.ENVIRONMENT}' environment.\n"
        "Generate a secure value and set it in your .env.prod file.\n"
        "For JWT_SECRET_KEY: python -c \"import secrets; print(secrets.token_urlsafe(64))\"\n"
        "For CREDENTIAL_ENCRYPTION_KEY: python -c \"import secrets, base64; print(base64.b64encode(secrets.token_bytes(32)).decode())\"",
        file=sys.stderr,
    )
    sys.exit(1)
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Central application configuration, loaded from env vars / .env file.

    Defaults are development-friendly placeholders; validate_production_settings()
    rejects the insecure ones in any non-dev environment.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )

    # Environment (dev | staging | production)
    ENVIRONMENT: str = "dev"

    # Database (async driver, used by the application at runtime)
    DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/mikrotik"
    # Sync URL used by Alembic only
    SYNC_DATABASE_URL: str = "postgresql+psycopg2://postgres:postgres@localhost:5432/mikrotik"

    # App user for RLS enforcement (cannot bypass RLS)
    APP_USER_DATABASE_URL: str = "postgresql+asyncpg://app_user:app_password@localhost:5432/mikrotik"

    # Database connection pool sizes (app pool is larger than the admin pool
    # because request traffic flows through the app_user engine)
    DB_POOL_SIZE: int = 20
    DB_MAX_OVERFLOW: int = 40
    DB_ADMIN_POOL_SIZE: int = 10
    DB_ADMIN_MAX_OVERFLOW: int = 20

    # Redis
    REDIS_URL: str = "redis://localhost:6379/0"

    # NATS JetStream
    NATS_URL: str = "nats://localhost:4222"

    # JWT configuration
    JWT_SECRET_KEY: str = "change-this-in-production-use-a-long-random-string"
    JWT_ALGORITHM: str = "HS256"
    JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15
    JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7

    # Credential encryption key — must be 32 bytes, base64-encoded in env
    # Generate with: python -c "import secrets, base64; print(base64.b64encode(secrets.token_bytes(32)).decode())"
    CREDENTIAL_ENCRYPTION_KEY: str = "LLLjnfBZTSycvL2U07HDSxUeTtLxb9cZzryQl0R9E4w="

    # OpenBao Transit (KMS for per-tenant credential encryption)
    OPENBAO_ADDR: str = "http://localhost:8200"
    OPENBAO_TOKEN: str = "dev-openbao-token"

    # First admin bootstrap (both must be set for bootstrap to run)
    FIRST_ADMIN_EMAIL: Optional[str] = None
    FIRST_ADMIN_PASSWORD: Optional[str] = None

    # CORS origins (comma-separated)
    CORS_ORIGINS: str = "http://localhost:3000,http://localhost:5173,http://localhost:8080"

    # Git store — PVC mount for bare git repos (one per tenant).
    # In production: /data/git-store (Kubernetes PVC ReadWriteMany).
    # In local dev: ./git-store (relative to cwd, created on first use).
    GIT_STORE_PATH: str = "./git-store"

    # WireGuard config path — shared volume with the WireGuard container
    WIREGUARD_CONFIG_PATH: str = "/data/wireguard"

    # Firmware cache
    FIRMWARE_CACHE_DIR: str = "/data/firmware-cache"  # PVC mount path
    FIRMWARE_CHECK_INTERVAL_HOURS: int = 24  # How often to check for new versions

    # SMTP settings for transactional email (password reset, etc.)
    SMTP_HOST: str = "localhost"
    SMTP_PORT: int = 587
    SMTP_USER: Optional[str] = None
    SMTP_PASSWORD: Optional[str] = None
    SMTP_USE_TLS: bool = False
    SMTP_FROM_ADDRESS: str = "noreply@mikrotik-portal.local"

    # Password reset
    PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: int = 30
    APP_BASE_URL: str = "http://localhost:5173"

    # App settings
    APP_NAME: str = "TOD - The Other Dude"
    APP_VERSION: str = "0.1.0"
    DEBUG: bool = False

    @field_validator("CREDENTIAL_ENCRYPTION_KEY")
    @classmethod
    def validate_encryption_key(cls, v: str) -> str:
        """Ensure the key decodes to exactly 32 bytes.

        Note: CHANGE_ME_IN_PRODUCTION is allowed through this validator;
        the production safety check in validate_production_settings()
        rejects it separately in non-dev environments.
        """
        if v == "CHANGE_ME_IN_PRODUCTION":
            # Allow the placeholder through field validation -- the production
            # safety check will reject it in non-dev environments.
            return v
        # Keep the try narrow: only the decode itself should be re-wrapped.
        # Previously the 32-byte length ValueError was raised inside this
        # try and caught by its own except, producing a doubly-prefixed
        # "Invalid CREDENTIAL_ENCRYPTION_KEY: CREDENTIAL_ENCRYPTION_KEY
        # must decode..." message.
        try:
            key_bytes = base64.b64decode(v)
        except Exception as e:
            raise ValueError(f"Invalid CREDENTIAL_ENCRYPTION_KEY: {e}") from e
        # NOTE(review): b64decode without validate=True silently skips
        # non-alphabet characters, so a malformed-but-decodable key passes
        # if it still yields 32 bytes — confirm whether strict validation
        # is wanted here.
        if len(key_bytes) != 32:
            raise ValueError(
                f"CREDENTIAL_ENCRYPTION_KEY must decode to exactly 32 bytes, got {len(key_bytes)}"
            )
        return v

    def get_encryption_key_bytes(self) -> bytes:
        """Return the encryption key as raw bytes (base64-decoded)."""
        return base64.b64decode(self.CREDENTIAL_ENCRYPTION_KEY)

    def get_cors_origins(self) -> list[str]:
        """Return CORS origins as a list, trimming whitespace and dropping blanks."""
        return [origin.strip() for origin in self.CORS_ORIGINS.split(",") if origin.strip()]
|
||||
|
||||
|
||||
@lru_cache()
def get_settings() -> Settings:
    """Build, validate, and cache the application settings.

    Thanks to lru_cache the environment is read and the production safety
    check runs exactly once per process, before the app accepts requests.
    """
    loaded = Settings()
    validate_production_settings(loaded)
    return loaded


# Eagerly-resolved singleton imported throughout the application.
settings = get_settings()
|
||||
114
backend/app/database.py
Normal file
114
backend/app/database.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Database engine, session factory, and dependency injection."""
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Base class for all SQLAlchemy ORM models."""
    # All models inherit this base so they share a single MetaData registry.
    pass
|
||||
|
||||
|
||||
# Primary engine using postgres superuser (for migrations/admin).
# pool_pre_ping recycles stale connections transparently before use.
engine = create_async_engine(
    settings.DATABASE_URL,
    echo=settings.DEBUG,
    pool_pre_ping=True,
    pool_size=settings.DB_ADMIN_POOL_SIZE,
    max_overflow=settings.DB_ADMIN_MAX_OVERFLOW,
)

# App user engine (enforces RLS — no superuser bypass).
# This is the engine request handlers use via get_db().
app_engine = create_async_engine(
    settings.APP_USER_DATABASE_URL,
    echo=settings.DEBUG,
    pool_pre_ping=True,
    pool_size=settings.DB_POOL_SIZE,
    max_overflow=settings.DB_MAX_OVERFLOW,
)

# Session factory for the app_user connection (RLS enforced).
# expire_on_commit=False keeps ORM objects readable after commit.
AsyncSessionLocal = async_sessionmaker(
    app_engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)

# Admin session factory (for bootstrap/migrations only) — bypasses RLS.
AdminAsyncSessionLocal = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Dependency that yields an async database session using app_user (RLS enforced).

    Commits after the consuming handler finishes without raising; rolls back
    (and re-raises) on any exception propagated through the yield.

    The tenant context (SET LOCAL app.current_tenant) must be set by
    tenant_context middleware before any tenant-scoped queries.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            # Reached only if the request handler completed successfully.
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            # NOTE(review): close() is also performed by the async context
            # manager's exit; this explicit call is redundant but harmless.
            await session.close()
|
||||
|
||||
|
||||
async def get_admin_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Dependency that yields an admin database session (bypasses RLS).
    USE ONLY for bootstrap operations and internal system tasks.

    Same transaction discipline as get_db(): commit on success, rollback
    and re-raise on any exception from the consumer.
    """
    async with AdminAsyncSessionLocal() as session:
        try:
            yield session
            # Reached only if the consumer completed successfully.
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            # NOTE(review): redundant with the context manager's exit, but harmless.
            await session.close()
|
||||
|
||||
|
||||
async def set_tenant_context(session: AsyncSession, tenant_id: Optional[str]) -> None:
    """
    Set the PostgreSQL session variable that activates RLS policies.

    Must run before any tenant-scoped query. SET LOCAL scopes the value to
    the current transaction, so the context resets at transaction end.

    Raises:
        ValueError: if tenant_id is neither 'super_admin' nor a valid UUID.
    """
    if not tenant_id:
        # For super_admin users: set empty string which will not match any
        # tenant. The super_admin uses the admin engine which bypasses RLS.
        await session.execute(text("SET LOCAL app.current_tenant = ''"))
        return

    # SET LOCAL cannot use parameterized queries in PostgreSQL, so the value
    # is interpolated. 'super_admin' is a sanctioned sentinel for cross-tenant
    # access; everything else must parse as a UUID to prevent SQL injection.
    if tenant_id != "super_admin":
        try:
            uuid.UUID(tenant_id)
        except ValueError:
            raise ValueError(f"Invalid tenant_id format: {tenant_id!r}")
    await session.execute(text(f"SET LOCAL app.current_tenant = '{tenant_id}'"))
||||
81
backend/app/logging_config.py
Normal file
81
backend/app/logging_config.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Structured logging configuration for the FastAPI backend.
|
||||
|
||||
Uses structlog with two rendering modes:
|
||||
- Dev mode (ENVIRONMENT=dev or DEBUG=true): colored console output
|
||||
- Prod mode: machine-parseable JSON output
|
||||
|
||||
Must be called once during app startup (in lifespan), NOT at module import time,
|
||||
so tests can override the configuration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import structlog
|
||||
|
||||
|
||||
def configure_logging() -> None:
    """Configure structlog for the FastAPI application.

    ENVIRONMENT=dev: colored, human-readable console output.
    Anything else: machine-parseable JSON lines.

    Must be called once during app startup (in lifespan), NOT at module
    import time, so tests can override the configuration.
    """
    dev_mode = os.getenv("ENVIRONMENT", "dev") == "dev"
    level_name = os.getenv("LOG_LEVEL", "debug" if dev_mode else "info").upper()
    level = getattr(logging, level_name, logging.INFO)

    # Processor order matters: each stage feeds the next.
    base_processors: list[structlog.types.Processor] = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.UnicodeDecoder(),
    ]

    renderer = (
        structlog.dev.ConsoleRenderer()
        if dev_mode
        else structlog.processors.JSONRenderer()
    )

    structlog.configure(
        processors=[
            *base_processors,
            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
        ],
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )

    # Route stdlib loggers (uvicorn, SQLAlchemy, alembic) through the same
    # renderer so all output shares one format.
    stdlib_formatter = structlog.stdlib.ProcessorFormatter(
        processors=[
            structlog.stdlib.ProcessorFormatter.remove_processors_meta,
            renderer,
        ],
    )

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(stdlib_formatter)

    root = logging.getLogger()
    root.handlers.clear()
    root.addHandler(stream_handler)
    root.setLevel(level)

    # uvicorn access logs are too chatty for local development.
    if dev_mode:
        logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def get_logger(name: str | None = None) -> structlog.stdlib.BoundLogger:
    """Get a structlog bound logger.

    Args:
        name: Logger name, typically __name__; None selects the default logger.

    Use this instead of logging.getLogger() throughout the application so
    records flow through the structlog pipeline set up by configure_logging().
    """
    return structlog.get_logger(name)
|
||||
330
backend/app/main.py
Normal file
330
backend/app/main.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""FastAPI application entry point."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import structlog
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from app.config import settings
|
||||
from app.logging_config import configure_logging
|
||||
from app.middleware.rate_limit import setup_rate_limiting
|
||||
from app.middleware.request_id import RequestIDMiddleware
|
||||
from app.middleware.security_headers import SecurityHeadersMiddleware
|
||||
from app.observability import check_health_ready, setup_instrumentator
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
async def run_migrations() -> None:
    """Apply Alembic migrations (upgrade to head) during startup.

    Runs alembic in a subprocess from the backend directory and raises
    RuntimeError with the captured stderr if the upgrade fails.
    """
    import os
    import subprocess
    import sys

    backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    proc = subprocess.run(
        [sys.executable, "-m", "alembic", "upgrade", "head"],
        capture_output=True,
        text=True,
        cwd=backend_dir,
    )
    if proc.returncode != 0:
        logger.error("migration failed", stderr=proc.stderr)
        raise RuntimeError(f"Database migration failed: {proc.stderr}")
    logger.info("migrations applied successfully")
|
||||
|
||||
|
||||
async def bootstrap_first_admin() -> None:
    """Seed the very first super_admin account when the users table is empty."""
    if not (settings.FIRST_ADMIN_EMAIL and settings.FIRST_ADMIN_PASSWORD):
        logger.info("FIRST_ADMIN_EMAIL/PASSWORD not set, skipping bootstrap")
        return

    from sqlalchemy import select

    from app.database import AdminAsyncSessionLocal
    from app.models.user import User, UserRole
    from app.services.auth import hash_password

    async with AdminAsyncSessionLocal() as session:
        # The admin session bypasses RLS, so this sees users across all tenants.
        any_user = (await session.execute(select(User).limit(1))).scalar_one_or_none()
        if any_user:
            logger.info("users already exist, skipping first admin bootstrap")
            return

        # Bcrypt password for now; must_upgrade_auth=True triggers the SRP
        # registration flow on the admin's first login.
        session.add(
            User(
                email=settings.FIRST_ADMIN_EMAIL,
                hashed_password=hash_password(settings.FIRST_ADMIN_PASSWORD),
                name="Super Admin",
                role=UserRole.SUPER_ADMIN.value,
                tenant_id=None,  # super_admin has no tenant
                is_active=True,
                must_upgrade_auth=True,
            )
        )
        await session.commit()
        logger.info("created first super_admin", email=settings.FIRST_ADMIN_EMAIL)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application lifespan: run migrations and bootstrap on startup.

    Startup order: logging -> migrations -> first-admin bootstrap -> NATS
    subscribers / schedulers / OpenBao provisioning (each best-effort).
    Everything after `yield` is the shutdown sequence.
    """
    from app.services.backup_scheduler import start_backup_scheduler, stop_backup_scheduler
    from app.services.firmware_subscriber import start_firmware_subscriber, stop_firmware_subscriber
    from app.services.metrics_subscriber import start_metrics_subscriber, stop_metrics_subscriber
    from app.services.nats_subscriber import start_nats_subscriber, stop_nats_subscriber
    from app.services.sse_manager import ensure_sse_streams

    # Configure structured logging FIRST -- before any other startup work
    configure_logging()

    logger.info("starting TOD API")

    # Run database migrations (awaited: the API must not serve on an
    # out-of-date schema)
    await run_migrations()

    # Bootstrap first admin user
    await bootstrap_first_admin()

    # Start NATS subscriber for device status events.
    # Wrapped in try/except so NATS failure doesn't prevent API startup --
    # allows running the API locally without NATS during frontend development.
    nats_connection = None
    try:
        nats_connection = await start_nats_subscriber()
    except Exception as exc:
        logger.warning(
            "NATS status subscriber could not start (API will run without it)",
            error=str(exc),
        )

    # Start NATS subscriber for device metrics events (separate NATS connection).
    # Same pattern -- failure is non-fatal so the API starts without full NATS stack.
    metrics_nc = None
    try:
        metrics_nc = await start_metrics_subscriber()
    except Exception as exc:
        logger.warning(
            "NATS metrics subscriber could not start (API will run without it)",
            error=str(exc),
        )

    # Start NATS subscriber for device firmware events (separate NATS connection).
    firmware_nc = None
    try:
        firmware_nc = await start_firmware_subscriber()
    except Exception as exc:
        logger.warning(
            "NATS firmware subscriber could not start (API will run without it)",
            error=str(exc),
        )

    # Ensure NATS streams for SSE event delivery exist (ALERT_EVENTS, OPERATION_EVENTS).
    # Non-fatal -- API starts without SSE streams; they'll be created on first SSE connection.
    try:
        await ensure_sse_streams()
    except Exception as exc:
        logger.warning(
            "SSE NATS streams could not be created (SSE will retry on connection)",
            error=str(exc),
        )

    # Start APScheduler for automated nightly config backups.
    # Non-fatal -- API starts and serves requests even without the scheduler.
    try:
        await start_backup_scheduler()
    except Exception as exc:
        logger.warning("backup scheduler could not start", error=str(exc))

    # Register daily firmware version check (3am UTC) on the same scheduler.
    try:
        from app.services.firmware_service import schedule_firmware_checks

        schedule_firmware_checks()
    except Exception as exc:
        logger.warning("firmware check scheduler could not start", error=str(exc))

    # Provision OpenBao Transit keys for existing tenants and migrate legacy credentials.
    # Non-blocking: if OpenBao is unavailable, the dual-read path handles fallback.
    if settings.OPENBAO_ADDR:
        try:
            from app.database import AdminAsyncSessionLocal
            from app.services.key_service import provision_existing_tenants

            async with AdminAsyncSessionLocal() as openbao_session:
                counts = await provision_existing_tenants(openbao_session)
                logger.info(
                    "openbao tenant provisioning complete",
                    **{k: v for k, v in counts.items()},
                )
        except Exception as exc:
            logger.warning(
                "openbao tenant provisioning failed (will retry on next restart)",
                error=str(exc),
            )

    # Recover stale push operations left behind by a previous API instance.
    try:
        from app.services.restore_service import recover_stale_push_operations
        from app.database import AdminAsyncSessionLocal as _AdminSession

        async with _AdminSession() as session:
            await recover_stale_push_operations(session)
        logger.info("push operation recovery check complete")
    except Exception as e:
        logger.error("push operation recovery failed (non-fatal): %s", e)

    # Config change subscriber (event-driven backups)
    config_change_nc = None
    try:
        from app.services.config_change_subscriber import (
            start_config_change_subscriber,
            stop_config_change_subscriber,
        )
        config_change_nc = await start_config_change_subscriber()
    except Exception as e:
        logger.error("Config change subscriber failed to start (non-fatal): %s", e)

    # Push rollback/alert subscriber
    push_rollback_nc = None
    try:
        from app.services.push_rollback_subscriber import (
            start_push_rollback_subscriber,
            stop_push_rollback_subscriber,
        )
        push_rollback_nc = await start_push_rollback_subscriber()
    except Exception as e:
        logger.error("Push rollback subscriber failed to start (non-fatal): %s", e)

    logger.info("startup complete, ready to serve requests")
    yield

    # Shutdown
    logger.info("shutting down TOD API")
    await stop_backup_scheduler()
    await stop_nats_subscriber(nats_connection)
    await stop_metrics_subscriber(metrics_nc)
    await stop_firmware_subscriber(firmware_nc)
    # The stop_* names below exist only if the matching import inside the
    # startup try block succeeded; the truthiness guards cover that case.
    if config_change_nc:
        await stop_config_change_subscriber()
    if push_rollback_nc:
        await stop_push_rollback_subscriber()

    # Dispose database engine connections to release all pooled connections cleanly.
    from app.database import app_engine, engine

    await app_engine.dispose()
    await engine.dispose()
    logger.info("database connections closed")
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
    """Create and configure the FastAPI application.

    Wires middleware, routers, health endpoints, and Prometheus
    instrumentation. Interactive API docs are exposed only in dev.
    """
    app = FastAPI(
        title=settings.APP_NAME,
        version=settings.APP_VERSION,
        description="The Other Dude — Fleet Management API",
        docs_url="/docs" if settings.ENVIRONMENT == "dev" else None,
        redoc_url="/redoc" if settings.ENVIRONMENT == "dev" else None,
        lifespan=lifespan,
    )

    # Starlette processes middleware in LIFO order (last added = first to run).
    # Effective run order here: RequestID -> SecurityHeaders -> CORS -> route
    # handler, which is why CORS is added first and RequestID last.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.get_cors_origins(),
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
        allow_headers=["Authorization", "Content-Type", "X-Request-ID"],
    )
    app.add_middleware(SecurityHeadersMiddleware, environment=settings.ENVIRONMENT)
    setup_rate_limiting(app)  # Register 429 exception handler (no middleware added)
    app.add_middleware(RequestIDMiddleware)

    # Include routers. NOTE(review): imported inside the function rather than
    # at module top -- presumably to avoid import cycles; confirm before moving.
    from app.routers.alerts import router as alerts_router
    from app.routers.auth import router as auth_router
    from app.routers.sse import router as sse_router
    from app.routers.config_backups import router as config_router
    from app.routers.config_editor import router as config_editor_router
    from app.routers.device_groups import router as device_groups_router
    from app.routers.device_tags import router as device_tags_router
    from app.routers.devices import router as devices_router
    from app.routers.firmware import router as firmware_router
    from app.routers.metrics import router as metrics_router
    from app.routers.events import router as events_router
    from app.routers.clients import router as clients_router
    from app.routers.device_logs import router as device_logs_router
    from app.routers.templates import router as templates_router
    from app.routers.tenants import router as tenants_router
    from app.routers.reports import router as reports_router
    from app.routers.topology import router as topology_router
    from app.routers.users import router as users_router
    from app.routers.audit_logs import router as audit_logs_router
    from app.routers.api_keys import router as api_keys_router
    from app.routers.maintenance_windows import router as maintenance_windows_router
    from app.routers.vpn import router as vpn_router
    from app.routers.certificates import router as certificates_router
    from app.routers.transparency import router as transparency_router
    from app.routers.settings import router as settings_router

    app.include_router(auth_router, prefix="/api")
    app.include_router(tenants_router, prefix="/api")
    app.include_router(users_router, prefix="/api")
    app.include_router(devices_router, prefix="/api")
    app.include_router(device_groups_router, prefix="/api")
    app.include_router(device_tags_router, prefix="/api")
    app.include_router(metrics_router, prefix="/api")
    app.include_router(config_router, prefix="/api")
    app.include_router(firmware_router, prefix="/api")
    app.include_router(alerts_router, prefix="/api")
    app.include_router(config_editor_router, prefix="/api")
    app.include_router(events_router, prefix="/api")
    app.include_router(device_logs_router, prefix="/api")
    app.include_router(templates_router, prefix="/api")
    app.include_router(clients_router, prefix="/api")
    app.include_router(topology_router, prefix="/api")
    app.include_router(sse_router, prefix="/api")
    app.include_router(audit_logs_router, prefix="/api")
    app.include_router(reports_router, prefix="/api")
    app.include_router(api_keys_router, prefix="/api")
    app.include_router(maintenance_windows_router, prefix="/api")
    app.include_router(vpn_router, prefix="/api")
    app.include_router(certificates_router, prefix="/api/certificates", tags=["certificates"])
    app.include_router(transparency_router, prefix="/api")
    app.include_router(settings_router, prefix="/api")

    # Health check endpoints
    @app.get("/health", tags=["health"])
    async def health_check() -> dict:
        """Liveness probe -- returns 200 if the process is alive."""
        return {"status": "ok", "version": settings.APP_VERSION}

    @app.get("/health/ready", tags=["health"])
    async def health_ready() -> JSONResponse:
        """Readiness probe -- returns 200 only when PostgreSQL, Redis, and NATS are healthy."""
        result = await check_health_ready()
        status_code = 200 if result["status"] == "healthy" else 503
        return JSONResponse(content=result, status_code=status_code)

    @app.get("/api/health", tags=["health"])
    async def api_health_check() -> dict:
        """Backward-compatible health endpoint under /api prefix."""
        return {"status": "ok", "version": settings.APP_VERSION}

    # Prometheus metrics instrumentation -- MUST be after routers so all routes are captured
    setup_instrumentator(app)

    return app


# Module-level ASGI application instance.
app = create_app()
|
||||
1
backend/app/middleware/__init__.py
Normal file
1
backend/app/middleware/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""FastAPI middleware and dependencies for auth, tenant context, and RBAC."""
|
||||
48
backend/app/middleware/rate_limit.py
Normal file
48
backend/app/middleware/rate_limit.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Rate limiting middleware using slowapi with Redis backend.
|
||||
|
||||
Per-route rate limits only -- no global limits to avoid blocking the
|
||||
Go poller, NATS subscribers, and health check endpoints.
|
||||
|
||||
Rate limit data uses Redis DB 1 (separate from app data in DB 0).
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
def _get_redis_url() -> str:
    """Return the configured Redis URL rewritten to point at DB 1.

    Rate-limit counters live in DB 1, isolated from application data in DB 0.
    """
    url = settings.REDIS_URL
    # Fast path for the common default "...:6379/0".
    if url.endswith("/0"):
        return url[:-2] + "/1"
    stripped = url.rstrip("/")
    if stripped.split("/")[-1].isdigit():
        # A different DB number is present -- swap it for /1.
        return url.rsplit("/", 1)[0] + "/1"
    # No DB component at all -- append one.
    return stripped + "/1"
|
||||
|
||||
|
||||
# Shared limiter instance keyed by client IP address. default_limits is empty
# on purpose: only routes decorated with @limiter.limit(...) are rate limited,
# so the Go poller, NATS subscribers, and health checks are never throttled.
limiter = Limiter(
    key_func=get_remote_address,
    storage_uri=_get_redis_url(),
    default_limits=[],  # No global limits -- per-route only
)
|
||||
|
||||
|
||||
def setup_rate_limiting(app: FastAPI) -> None:
    """Register the rate limiter on the FastAPI app.

    Sets app.state.limiter (slowapi reads the limiter from there) and
    registers the handler that turns RateLimitExceeded into a 429 response.
    It does NOT add middleware -- the @limiter.limit() decorators applied to
    individual routes perform the actual limiting.
    """
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
186
backend/app/middleware/rbac.py
Normal file
186
backend/app/middleware/rbac.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
Role-Based Access Control (RBAC) middleware.
|
||||
|
||||
Provides dependency factories for enforcing role-based access control
|
||||
on FastAPI routes. Roles are hierarchical:
|
||||
|
||||
super_admin > tenant_admin > operator > viewer
|
||||
|
||||
Role permissions per plan TENANT-04/05/06:
|
||||
- viewer: GET endpoints only (read-only)
|
||||
- operator: GET + device/config management endpoints
|
||||
- tenant_admin: full access within their tenant
|
||||
- super_admin: full access across all tenants
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.params import Depends as DependsClass
|
||||
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
|
||||
# Role hierarchy (higher index = more privilege)
|
||||
# api_key is at operator level for RBAC checks; fine-grained access controlled by scopes.
|
||||
ROLE_HIERARCHY = {
|
||||
"viewer": 0,
|
||||
"api_key": 1,
|
||||
"operator": 1,
|
||||
"tenant_admin": 2,
|
||||
"super_admin": 3,
|
||||
}
|
||||
|
||||
|
||||
def _get_role_level(role: str) -> int:
|
||||
"""Return numeric privilege level for a role string."""
|
||||
return ROLE_HIERARCHY.get(role, -1)
|
||||
|
||||
|
||||
def require_role(*allowed_roles: str) -> Callable:
    """Build a dependency that admits only the listed roles.

    Usage:
        @router.post("/items", dependencies=[Depends(require_role("tenant_admin", "super_admin"))])

    Args:
        *allowed_roles: Exact role names permitted on the endpoint.

    Returns:
        An async dependency returning the CurrentUser, or raising 403
        when the caller's role is not in the allowed set.
    """
    async def _check(
        current_user: CurrentUser = Depends(get_current_user),
    ) -> CurrentUser:
        # Exact-match check (no hierarchy) -- the caller lists every
        # acceptable role explicitly.
        if current_user.role in allowed_roles:
            return current_user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access denied. Required roles: {', '.join(allowed_roles)}. "
            f"Your role: {current_user.role}",
        )

    return _check
|
||||
|
||||
|
||||
def require_min_role(min_role: str) -> Callable:
    """Build a dependency admitting *min_role* and every role above it.

    Usage:
        @router.get("/items", dependencies=[Depends(require_min_role("operator"))])
        # operator / tenant_admin / super_admin pass; viewer is rejected
    """
    # Resolve the floor once at registration time, not per request.
    required_level = _get_role_level(min_role)

    async def _check(
        current_user: CurrentUser = Depends(get_current_user),
    ) -> CurrentUser:
        if _get_role_level(current_user.role) >= required_level:
            return current_user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access denied. Minimum required role: {min_role}. "
            f"Your role: {current_user.role}",
        )

    return _check
|
||||
|
||||
|
||||
def require_write_access() -> Callable:
    """Build a dependency that blocks viewers from mutating endpoints.

    Attach to POST/PUT/PATCH/DELETE routes: viewers receive a 403 while
    every other role passes through unchanged. Non-mutating methods are
    never blocked, regardless of role.
    """
    mutating_methods = ("POST", "PUT", "PATCH", "DELETE")

    async def _check(
        request: Request,
        current_user: CurrentUser = Depends(get_current_user),
    ) -> CurrentUser:
        if current_user.role == "viewer" and request.method in mutating_methods:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Viewers have read-only access. "
                "Contact your administrator to request elevated permissions.",
            )
        return current_user

    return _check
|
||||
|
||||
|
||||
def require_scope(scope: str) -> DependsClass:
    """Return a Depends() instance that enforces an API-key scope.

    JWT-authenticated users pass through untouched -- scopes apply only
    to API keys. An API-key caller must carry *scope* in its scope list.

    Usable directly in a dependency list:
        @router.get("/items", dependencies=[require_scope("devices:read")])

    Args:
        scope: Required scope string (e.g. "devices:read", "config:write").

    Raises:
        HTTPException 403 if the API key is missing the required scope.
    """
    async def _verify(
        current_user: CurrentUser = Depends(get_current_user),
    ) -> CurrentUser:
        # Regular (JWT) principals are never scope-checked.
        if current_user.role != "api_key":
            return current_user
        granted = current_user.scopes or []
        if scope in granted:
            return current_user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"API key missing required scope: {scope}",
        )

    return Depends(_verify)
|
||||
|
||||
|
||||
# Pre-built convenience dependencies
|
||||
|
||||
async def require_super_admin(
    current_user: CurrentUser = Depends(get_current_user),
) -> CurrentUser:
    """Admit only the portal-wide super_admin role."""
    if current_user.role == "super_admin":
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Access denied. Super admin role required.",
    )
|
||||
|
||||
|
||||
async def require_tenant_admin_or_above(
    current_user: CurrentUser = Depends(get_current_user),
) -> CurrentUser:
    """Admit tenant_admin and super_admin; reject everything below."""
    if current_user.role in ("tenant_admin", "super_admin"):
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Access denied. Tenant admin or higher role required.",
    )
|
||||
|
||||
|
||||
async def require_operator_or_above(
    current_user: CurrentUser = Depends(get_current_user),
) -> CurrentUser:
    """Admit operator, tenant_admin, and super_admin; reject viewer."""
    if current_user.role in ("operator", "tenant_admin", "super_admin"):
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Access denied. Operator or higher role required.",
    )
|
||||
|
||||
|
||||
async def require_authenticated(
    current_user: CurrentUser = Depends(get_current_user),
) -> CurrentUser:
    """Admit any authenticated principal -- no role restriction.

    Authentication itself is enforced by get_current_user (raises 401);
    this wrapper just exposes it as a named convenience dependency.
    """
    return current_user
|
||||
67
backend/app/middleware/request_id.py
Normal file
67
backend/app/middleware/request_id.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Request ID middleware for structured logging context.
|
||||
|
||||
Generates or extracts a request ID for every incoming request and binds it
|
||||
(along with tenant_id from JWT) to structlog's contextvars so that all log
|
||||
lines emitted during the request include these correlation fields.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import structlog
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Bind request_id and tenant_id to structlog's contextvars per request.

    Every log line emitted while the request is handled then carries these
    correlation fields, and the request ID is echoed back to the client in
    the X-Request-ID response header.
    """

    async def dispatch(self, request: Request, call_next):
        # CRITICAL: wipe whatever the previous request left behind --
        # contextvars would otherwise leak between requests.
        structlog.contextvars.clear_contextvars()

        # Honor a client-supplied X-Request-ID; mint a UUID otherwise.
        incoming = request.headers.get("X-Request-ID")
        request_id = incoming if incoming is not None else str(uuid.uuid4())

        # tenant_id is best-effort (None on unauthenticated requests).
        structlog.contextvars.bind_contextvars(
            request_id=request_id,
            tenant_id=self._extract_tenant_id(request),
        )

        response: Response = await call_next(request)
        response.headers["X-Request-ID"] = request_id
        return response

    def _extract_tenant_id(self, request: Request) -> str | None:
        """Best-effort extraction of tenant_id from the JWT.

        Checks the access_token cookie first, then a Bearer Authorization
        header. Returns None when no valid token is present -- expected for
        unauthenticated endpoints like /login.
        """
        token = request.cookies.get("access_token")
        if not token:
            auth_header = request.headers.get("Authorization", "")
            if auth_header.startswith("Bearer "):
                token = auth_header.removeprefix("Bearer ")

        if not token:
            return None

        try:
            from jose import jwt as jose_jwt

            from app.config import settings

            claims = jose_jwt.decode(
                token,
                settings.JWT_SECRET_KEY,
                algorithms=[settings.JWT_ALGORITHM],
            )
            return claims.get("tenant_id")
        except Exception:
            # Invalid or expired token: this is only logging metadata, so
            # degrade to "no tenant" rather than failing the request.
            return None
|
||||
79
backend/app/middleware/security_headers.py
Normal file
79
backend/app/middleware/security_headers.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Security response headers middleware.
|
||||
|
||||
Adds standard security headers to all API responses:
|
||||
- X-Content-Type-Options: nosniff (prevent MIME sniffing)
|
||||
- X-Frame-Options: DENY (prevent clickjacking)
|
||||
- Referrer-Policy: strict-origin-when-cross-origin
|
||||
- Cache-Control: no-store (prevent browser caching of API responses)
|
||||
- Strict-Transport-Security (HSTS, production only -- breaks plain HTTP dev)
|
||||
- Content-Security-Policy (strict in production, relaxed for dev HMR)
|
||||
|
||||
CSP directives:
|
||||
- script-src 'self' (production) blocks inline scripts -- XSS mitigation
|
||||
- style-src 'unsafe-inline' required for Tailwind, Framer Motion, Radix, Sonner
|
||||
- connect-src includes wss:/ws: for SSE and WebSocket connections
|
||||
- Dev mode adds 'unsafe-inline' and 'unsafe-eval' for Vite HMR
|
||||
"""
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
# Production CSP: strict -- no inline scripts allowed
_CSP_PRODUCTION = "; ".join([
    "default-src 'self'",
    "script-src 'self'",  # no 'unsafe-inline' -- primary XSS mitigation
    "style-src 'self' 'unsafe-inline'",  # inline styles required by the UI stack (see module docstring)
    "img-src 'self' data: blob:",
    "font-src 'self'",
    "connect-src 'self' wss: ws:",  # SSE / WebSocket connections
    "worker-src 'self'",
    "frame-ancestors 'none'",  # clickjacking defence (complements X-Frame-Options)
    "base-uri 'self'",
    "form-action 'self'",
])

# Development CSP: relaxed for Vite HMR (hot module replacement)
_CSP_DEV = "; ".join([
    "default-src 'self'",
    "script-src 'self' 'unsafe-inline' 'unsafe-eval'",  # needed by Vite HMR
    "style-src 'self' 'unsafe-inline'",
    "img-src 'self' data: blob:",
    "font-src 'self'",
    "connect-src 'self' http://localhost:* ws://localhost:* wss:",  # local dev servers
    "worker-src 'self' blob:",
    "frame-ancestors 'none'",
    "base-uri 'self'",
    "form-action 'self'",
])
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach standard security headers to every response.

    CSP and HSTS are environment-aware: production gets the strict CSP
    plus HSTS, while dev gets the relaxed CSP and no HSTS (HSTS would
    break plain-HTTP local development).
    """

    def __init__(self, app, environment: str = "dev"):
        super().__init__(app)
        # Anything other than "dev" is treated as production-grade.
        self.is_production = environment != "dev"

    async def dispatch(self, request: Request, call_next) -> Response:
        response = await call_next(request)
        headers = response.headers

        # Unconditional hardening headers.
        headers["X-Content-Type-Options"] = "nosniff"
        headers["X-Frame-Options"] = "DENY"
        headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        headers["Cache-Control"] = "no-store"

        # Environment-aware Content-Security-Policy.
        headers["Content-Security-Policy"] = (
            _CSP_PRODUCTION if self.is_production else _CSP_DEV
        )

        # HSTS only in production (plain HTTP in dev would be blocked).
        if self.is_production:
            headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"

        return response
||||
177
backend/app/middleware/tenant_context.py
Normal file
177
backend/app/middleware/tenant_context.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
Tenant context middleware and current user dependency.
|
||||
|
||||
Extracts JWT from Authorization header (Bearer token) or httpOnly cookie,
|
||||
validates it, and provides current user context for request handlers.
|
||||
|
||||
For tenant-scoped users: sets SET LOCAL app.current_tenant on the DB session.
|
||||
For super_admin: uses special 'super_admin' context that grants cross-tenant access.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import Cookie, Depends, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db, set_tenant_context
|
||||
from app.services.auth import verify_token
|
||||
|
||||
# Optional HTTP Bearer scheme: auto_error=False means a missing/invalid
# Authorization header yields None instead of an automatic 403 -- the
# 401 handling is done explicitly in get_current_user.
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
class CurrentUser:
    """Authenticated principal resolved from a JWT or an API key."""

    def __init__(
        self,
        user_id: uuid.UUID,
        tenant_id: Optional[uuid.UUID],
        role: str,
        scopes: Optional[list[str]] = None,
    ) -> None:
        self.user_id = user_id
        # None only for principals without a tenant (see get_current_user).
        self.tenant_id = tenant_id
        self.role = role
        # Scope list is supplied for API-key principals; None otherwise.
        self.scopes = scopes

    @property
    def is_super_admin(self) -> bool:
        """Whether this principal is the cross-tenant portal admin."""
        return self.role == "super_admin"

    @property
    def is_api_key(self) -> bool:
        """Whether this principal authenticated via an API key."""
        return self.role == "api_key"

    def __repr__(self) -> str:
        return f"<CurrentUser user_id={self.user_id} role={self.role} tenant_id={self.tenant_id}>"
|
||||
|
||||
|
||||
def _extract_token(
|
||||
request: Request,
|
||||
credentials: Optional[HTTPAuthorizationCredentials],
|
||||
access_token: Optional[str],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Extract JWT token from Authorization header or httpOnly cookie.
|
||||
|
||||
Priority: Authorization header > cookie.
|
||||
"""
|
||||
if credentials and credentials.scheme.lower() == "bearer":
|
||||
return credentials.credentials
|
||||
|
||||
if access_token:
|
||||
return access_token
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_current_user(
    request: Request,
    credentials: Annotated[Optional[HTTPAuthorizationCredentials], Depends(bearer_scheme)] = None,
    access_token: Annotated[Optional[str], Cookie()] = None,
    db: AsyncSession = Depends(get_db),
) -> CurrentUser:
    """
    FastAPI dependency that extracts and validates the current user.

    Credential sources, in priority order: Bearer token (Authorization
    header), then the httpOnly access_token cookie. Tokens with the
    ``mktp_`` prefix are treated as API keys instead of JWTs.

    Also sets the tenant context on the request-scoped database session
    for RLS enforcement: super_admin gets the special cross-tenant
    context, everyone else gets their own tenant.

    Raises:
        HTTPException 401: If no token is provided, the token/API key is
            invalid, or a non-super_admin token carries no usable tenant_id.
    """
    token = _extract_token(request, credentials, access_token)

    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # API key authentication: detect mktp_ prefix and validate via api_key_service
    if token.startswith("mktp_"):
        # NOTE(review): import is function-scoped in the original --
        # presumably to avoid an import cycle; confirm before hoisting.
        from app.services.api_key_service import validate_api_key

        key_data = await validate_api_key(token)
        if not key_data:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid, expired, or revoked API key",
                headers={"WWW-Authenticate": "Bearer"},
            )

        tenant_id = key_data["tenant_id"]
        # Set tenant context on the request-scoped DB session for RLS
        await set_tenant_context(db, str(tenant_id))

        # API keys always act at the "api_key" role; fine-grained access
        # is governed by the key's scope list.
        return CurrentUser(
            user_id=key_data["user_id"],
            tenant_id=tenant_id,
            role="api_key",
            scopes=key_data["scopes"],
        )

    # Decode and validate the JWT (expected_type guards against other
    # token kinds, e.g. refresh tokens, being used here)
    payload = verify_token(token, expected_type="access")

    user_id_str = payload.get("sub")
    tenant_id_str = payload.get("tenant_id")
    role = payload.get("role")

    # "sub" and "role" are mandatory claims for this app's tokens.
    if not user_id_str or not role:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token payload",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        user_id = uuid.UUID(user_id_str)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token payload",
        )

    tenant_id: Optional[uuid.UUID] = None
    if tenant_id_str:
        try:
            tenant_id = uuid.UUID(tenant_id_str)
        except ValueError:
            # Malformed tenant_id is treated as absent; non-super_admin
            # callers are then rejected below.
            pass

    # Set the tenant context on the database session for RLS enforcement
    if role == "super_admin":
        # super_admin uses special context that grants cross-tenant access
        await set_tenant_context(db, "super_admin")
    elif tenant_id:
        await set_tenant_context(db, str(tenant_id))
    else:
        # Non-super_admin without tenant -- deny access
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token: no tenant context",
        )

    return CurrentUser(
        user_id=user_id,
        tenant_id=tenant_id,
        role=role,
    )
|
||||
|
||||
|
||||
async def get_optional_current_user(
    request: Request,
    credentials: Annotated[Optional[HTTPAuthorizationCredentials], Depends(bearer_scheme)] = None,
    access_token: Annotated[Optional[str], Cookie()] = None,
    db: AsyncSession = Depends(get_db),
) -> Optional[CurrentUser]:
    """Like get_current_user, but yields None for unauthenticated callers."""
    try:
        user = await get_current_user(request, credentials, access_token, db)
    except HTTPException:
        # Any auth failure (missing, invalid, expired) maps to "no user".
        return None
    return user
|
||||
35
backend/app/models/__init__.py
Normal file
35
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""SQLAlchemy ORM models."""
|
||||
|
||||
from app.models.tenant import Tenant
|
||||
from app.models.user import User, UserRole
|
||||
from app.models.device import Device, DeviceGroup, DeviceTag, DeviceGroupMembership, DeviceTagAssignment, DeviceStatus
|
||||
from app.models.alert import AlertRule, NotificationChannel, AlertRuleChannel, AlertEvent
|
||||
from app.models.firmware import FirmwareVersion, FirmwareUpgradeJob
|
||||
from app.models.config_template import ConfigTemplate, ConfigTemplateTag, TemplatePushJob
|
||||
from app.models.audit_log import AuditLog
|
||||
from app.models.maintenance_window import MaintenanceWindow
|
||||
from app.models.api_key import ApiKey
|
||||
|
||||
# Public API of the models package -- keep in sync with the imports above.
__all__ = [
    "Tenant",
    "User",
    "UserRole",
    "Device",
    "DeviceGroup",
    "DeviceTag",
    "DeviceGroupMembership",
    "DeviceTagAssignment",
    "DeviceStatus",
    "AlertRule",
    "NotificationChannel",
    "AlertRuleChannel",
    "AlertEvent",
    "FirmwareVersion",
    "FirmwareUpgradeJob",
    "ConfigTemplate",
    "ConfigTemplateTag",
    "TemplatePushJob",
    "AuditLog",
    "MaintenanceWindow",
    "ApiKey",
]
|
||||
177
backend/app/models/alert.py
Normal file
177
backend/app/models/alert.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""Alert system ORM models: rules, notification channels, and alert events."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
LargeBinary,
|
||||
Numeric,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class AlertRule(Base):
    """Configurable alert threshold rule.

    Rules can be tenant-wide (device_id=NULL), device-specific, or group-scoped.
    When a metric breaches the threshold for duration_polls consecutive polls,
    an alert fires.
    """
    __tablename__ = "alert_rules"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Owning tenant; rules are removed with the tenant (CASCADE).
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    # NULL => tenant-wide rule rather than a single-device rule.
    device_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=True,
    )
    # Optional device-group scope; survives group deletion as NULL.
    group_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("device_groups.id", ondelete="SET NULL"),
        nullable=True,
    )
    name: Mapped[str] = mapped_column(Text, nullable=False)
    # Metric identifier this rule evaluates.
    metric: Mapped[str] = mapped_column(Text, nullable=False)
    # Comparison operator applied between the metric value and threshold.
    operator: Mapped[str] = mapped_column(Text, nullable=False)
    threshold: Mapped[float] = mapped_column(Numeric, nullable=False)
    # Consecutive breaching polls required before the alert fires.
    duration_polls: Mapped[int] = mapped_column(Integer, nullable=False, default=1, server_default="1")
    severity: Mapped[str] = mapped_column(Text, nullable=False)
    enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, server_default="true")
    # NOTE(review): presumably marks system-seeded default rules -- confirm
    # against the code that creates rules with is_default=True.
    is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false")
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        return f"<AlertRule id={self.id} name={self.name!r} metric={self.metric}>"
|
||||
|
||||
|
||||
class NotificationChannel(Base):
    """Email, webhook, or Slack notification destination.

    channel_type selects which field family below is used:
    the SMTP fields for "email", webhook_url for "webhook", and
    slack_webhook_url for "slack".
    """
    __tablename__ = "notification_channels"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    name: Mapped[str] = mapped_column(Text, nullable=False)
    channel_type: Mapped[str] = mapped_column(Text, nullable=False)  # "email", "webhook", or "slack"
    # SMTP fields (email channels)
    smtp_host: Mapped[str | None] = mapped_column(Text, nullable=True)
    smtp_port: Mapped[int | None] = mapped_column(Integer, nullable=True)
    smtp_user: Mapped[str | None] = mapped_column(Text, nullable=True)
    smtp_password: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)  # AES-256-GCM encrypted
    smtp_use_tls: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false")
    from_address: Mapped[str | None] = mapped_column(Text, nullable=True)
    to_address: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Webhook fields
    webhook_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Slack fields
    slack_webhook_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    # OpenBao Transit ciphertext (dual-write migration): written alongside
    # smtp_password while secrets migrate to Transit encryption.
    smtp_password_transit: Mapped[str | None] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        return f"<NotificationChannel id={self.id} name={self.name!r} type={self.channel_type}>"
|
||||
|
||||
|
||||
class AlertRuleChannel(Base):
    """Many-to-many association between alert rules and notification channels.

    Pure join table: composite primary key, no payload columns. Rows are
    removed automatically when either side is deleted (CASCADE).
    """
    __tablename__ = "alert_rule_channels"

    rule_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("alert_rules.id", ondelete="CASCADE"),
        primary_key=True,
    )
    channel_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("notification_channels.id", ondelete="CASCADE"),
        primary_key=True,
    )
|
||||
|
||||
|
||||
class AlertEvent(Base):
    """Record of an alert firing, resolving, or flapping.

    rule_id is NULL for system-level alerts (e.g., device offline).
    """
    __tablename__ = "alert_events"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Originating rule; becomes NULL if the rule is later deleted
    # (SET NULL), and is NULL from the start for system-level alerts.
    rule_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("alert_rules.id", ondelete="SET NULL"),
        nullable=True,
    )
    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=False,
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    status: Mapped[str] = mapped_column(Text, nullable=False)  # "firing", "resolved", "flapping"
    severity: Mapped[str] = mapped_column(Text, nullable=False)
    # Optional snapshot of the breaching metric at fire time.
    metric: Mapped[str | None] = mapped_column(Text, nullable=True)
    value: Mapped[float | None] = mapped_column(Numeric, nullable=True)
    threshold: Mapped[float | None] = mapped_column(Numeric, nullable=True)
    message: Mapped[str | None] = mapped_column(Text, nullable=True)
    is_flapping: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false")
    # Acknowledgement bookkeeping; acknowledged_by survives user deletion
    # as NULL (SET NULL).
    acknowledged_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    acknowledged_by: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("users.id", ondelete="SET NULL"),
        nullable=True,
    )
    # NOTE(review): presumably suppresses notifications until this time --
    # confirm against the alerting service.
    silenced_until: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    fired_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    resolved_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)

    def __repr__(self) -> str:
        return f"<AlertEvent id={self.id} status={self.status} severity={self.severity}>"
|
||||
60
backend/app/models/api_key.py
Normal file
60
backend/app/models/api_key.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""API key ORM model for tenant-scoped programmatic access."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Text, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class ApiKey(Base):
    """Tracks API keys for programmatic access to the portal.

    Keys are stored as SHA-256 hashes (never plaintext).
    Scoped permissions limit what each key can do.
    Revocation is soft-delete (sets revoked_at, row preserved for audit).
    """

    __tablename__ = "api_keys"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Tenant the key is scoped to; removed with the tenant (CASCADE).
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    # User who owns the key; removed with the user (CASCADE).
    user_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
    )
    name: Mapped[str] = mapped_column(Text, nullable=False)
    # Short displayable prefix of the key (the full key is never stored).
    key_prefix: Mapped[str] = mapped_column(Text, nullable=False)
    # SHA-256 hash of the full key; unique so a key maps to one row.
    key_hash: Mapped[str] = mapped_column(Text, nullable=False, unique=True)
    # JSON array of scope strings granted to this key.
    scopes: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'::jsonb")
    # NULL => never expires.
    expires_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    last_used_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # Soft-delete marker: set on revocation, row kept for the audit trail.
    revoked_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    def __repr__(self) -> str:
        return f"<ApiKey id={self.id} name={self.name} prefix={self.key_prefix}>"
|
||||
59
backend/app/models/audit_log.py
Normal file
59
backend/app/models/audit_log.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Audit log model for centralized audit trail."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, String, Text, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class AuditLog(Base):
    """Records all auditable actions in the system (config changes, CRUD, auth events)."""

    __tablename__ = "audit_logs"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Indexed: audit queries are tenant-scoped.
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Acting user; kept as NULL after user deletion so the entry survives.
    user_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("users.id", ondelete="SET NULL"),
        nullable=True,
    )
    action: Mapped[str] = mapped_column(String(100), nullable=False)
    # Free-form resource identification (type + id of the affected object).
    resource_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
    resource_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
    device_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="SET NULL"),
        nullable=True,
    )
    # Structured event payload; defaults to an empty JSON object.
    details: Mapped[dict[str, Any]] = mapped_column(
        JSONB,
        nullable=False,
        server_default="{}",
    )
    # Transit-encrypted details JSON (vault:v1:...) -- set when details are encrypted
    encrypted_details: Mapped[str | None] = mapped_column(Text, nullable=True)
    # 45 chars fits the longest IPv6 textual form.
    ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        return f"<AuditLog id={self.id} action={self.action!r} tenant_id={self.tenant_id}>"
|
||||
140
backend/app/models/certificate.py
Normal file
140
backend/app/models/certificate.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Certificate Authority and Device Certificate ORM models.
|
||||
|
||||
Supports the Internal Certificate Authority feature:
|
||||
- CertificateAuthority: one per tenant, stores encrypted CA private key + public cert
|
||||
- DeviceCertificate: per-device signed certificate with lifecycle status tracking
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, LargeBinary, String, Text, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class CertificateAuthority(Base):
    """Per-tenant root Certificate Authority.

    Each tenant has at most one CA. The CA private key is encrypted with
    AES-256-GCM before storage (using the same pattern as device credentials).
    The public cert_pem is not sensitive and can be distributed freely.
    """

    __tablename__ = "certificate_authorities"

    # Surrogate PK: client-side uuid4 default plus server-side gen_random_uuid()
    # so rows created outside the ORM still get an id.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # unique=True enforces the at-most-one-CA-per-tenant invariant at the DB level.
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        unique=True,
    )
    common_name: Mapped[str] = mapped_column(String(255), nullable=False)
    # PEM-encoded public certificate (not sensitive).
    cert_pem: Mapped[str] = mapped_column(Text, nullable=False)
    # AES-256-GCM ciphertext of the CA private key (see class docstring).
    encrypted_private_key: Mapped[bytes] = mapped_column(
        LargeBinary, nullable=False
    )
    serial_number: Mapped[str] = mapped_column(String(64), nullable=False)
    # SHA-256 fingerprint; 95 chars — presumably colon-separated hex (64 + 31 colons).
    fingerprint_sha256: Mapped[str] = mapped_column(String(95), nullable=False)
    not_valid_before: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    not_valid_after: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    # OpenBao Transit ciphertext (dual-write migration)
    encrypted_private_key_transit: Mapped[str | None] = mapped_column(
        Text, nullable=True
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        return (
            f"<CertificateAuthority id={self.id} "
            f"cn={self.common_name!r} tenant={self.tenant_id}>"
        )
|
||||
|
||||
|
||||
class DeviceCertificate(Base):
    """Per-device TLS certificate signed by the tenant's CA.

    Status lifecycle:
        issued -> deploying -> deployed -> expiring -> expired
                           \\-> revoked
                           \\-> superseded (when rotated)
    """

    __tablename__ = "device_certificates"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=False,
    )
    # Issuing CA; CASCADE means deleting a CA removes its issued certificates.
    ca_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("certificate_authorities.id", ondelete="CASCADE"),
        nullable=False,
    )
    common_name: Mapped[str] = mapped_column(String(255), nullable=False)
    serial_number: Mapped[str] = mapped_column(String(64), nullable=False)
    # SHA-256 fingerprint; 95 chars — presumably colon-separated hex.
    fingerprint_sha256: Mapped[str] = mapped_column(String(95), nullable=False)
    # PEM-encoded public certificate (not sensitive).
    cert_pem: Mapped[str] = mapped_column(Text, nullable=False)
    # AES-256-GCM ciphertext of the device certificate's private key.
    encrypted_private_key: Mapped[bytes] = mapped_column(
        LargeBinary, nullable=False
    )
    not_valid_before: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    not_valid_after: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    # OpenBao Transit ciphertext (dual-write migration)
    encrypted_private_key_transit: Mapped[str | None] = mapped_column(
        Text, nullable=True
    )
    # One of the lifecycle states listed in the class docstring.
    status: Mapped[str] = mapped_column(
        String(20), nullable=False, server_default="issued"
    )
    # Set when the certificate reaches the 'deployed' state; NULL before that.
    deployed_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # NOTE(review): no onupdate=func.now() here, unlike other models' updated_at —
    # confirm whether updates are expected to refresh this column via app code.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        return (
            f"<DeviceCertificate id={self.id} "
            f"cn={self.common_name!r} status={self.status}>"
        )
|
||||
178
backend/app/models/config_backup.py
Normal file
178
backend/app/models/config_backup.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""SQLAlchemy models for config backup tables."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, LargeBinary, SmallInteger, String, Text, UniqueConstraint, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class ConfigBackupRun(Base):
    """Metadata for a single config backup run.

    The actual config content (export.rsc and backup.bin) lives in the tenant's
    bare git repository at GIT_STORE_PATH/{tenant_id}.git. This table provides
    the timeline view and per-run metadata without duplicating file content in
    PostgreSQL.
    """

    __tablename__ = "config_backup_runs"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Git commit hash in the tenant's bare repo where this backup is stored.
    commit_sha: Mapped[str] = mapped_column(Text, nullable=False)
    # Trigger type: 'scheduled' | 'manual' | 'pre-restore' | 'checkpoint' | 'config-change'
    trigger_type: Mapped[str] = mapped_column(String(20), nullable=False)
    # Lines added/removed vs the prior export.rsc for this device.
    # NULL for the first backup (no prior version to diff against).
    lines_added: Mapped[int | None] = mapped_column(Integer, nullable=True)
    lines_removed: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Encryption metadata: NULL=plaintext, 1=client-side AES-GCM, 2=OpenBao Transit
    encryption_tier: Mapped[int | None] = mapped_column(SmallInteger, nullable=True)
    # 12-byte AES-GCM nonce for Tier 1 (client-side) backups; NULL for plaintext/Transit
    encryption_nonce: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        # Short SHA (8 chars) keeps the repr compact in logs.
        return (
            f"<ConfigBackupRun id={self.id} device_id={self.device_id} "
            f"trigger={self.trigger_type!r} sha={self.commit_sha[:8]!r}>"
        )
|
||||
|
||||
|
||||
class ConfigBackupSchedule(Base):
    """Per-tenant default and per-device override backup schedule config.

    A row with device_id=NULL is the tenant-level default (daily at 2am).
    A row with a specific device_id overrides the tenant default for that device.
    """

    __tablename__ = "config_backup_schedules"
    # At most one schedule row per (tenant, device) pair; the device_id=NULL
    # default row is also covered by this constraint.
    __table_args__ = (
        UniqueConstraint("tenant_id", "device_id", name="uq_backup_schedule_tenant_device"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # NULL = tenant-level default schedule; non-NULL = device-specific override.
    device_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=True,
    )
    # Standard cron expression (5 fields). Default: daily at 2am UTC.
    cron_expression: Mapped[str] = mapped_column(
        String(100),
        nullable=False,
        default="0 2 * * *",
        server_default="0 2 * * *",
    )
    enabled: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=True,
        server_default="TRUE",
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        # Show device scope when this is an override, else the owning tenant.
        scope = f"device={self.device_id}" if self.device_id else f"tenant={self.tenant_id}"
        return f"<ConfigBackupSchedule {scope} cron={self.cron_expression!r} enabled={self.enabled}>"
|
||||
|
||||
|
||||
class ConfigPushOperation(Base):
    """Tracks pending two-phase config push operations for panic-revert recovery.

    Before pushing a config, a row is inserted with status='pending_verification'.
    If the API pod restarts during the 60-second verification window, the startup
    handler checks this table and either commits (deletes the RouterOS scheduler
    job) or marks the operation as 'failed'. This prevents the panic-revert
    scheduler from firing and reverting a successful push after an API restart.

    See Pitfall 6 in 04-RESEARCH.md for the full failure scenario.
    """

    __tablename__ = "config_push_operations"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Git commit SHA we'd revert to if the push fails.
    pre_push_commit_sha: Mapped[str] = mapped_column(Text, nullable=False)
    # RouterOS scheduler job name created on the device for panic-revert.
    scheduler_name: Mapped[str] = mapped_column(String(255), nullable=False)
    # 'pending_verification' | 'committed' | 'reverted' | 'failed'
    status: Mapped[str] = mapped_column(
        String(30),
        nullable=False,
        default="pending_verification",
        server_default="pending_verification",
    )
    started_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # NULL while the operation is still in 'pending_verification'.
    completed_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
    )

    def __repr__(self) -> str:
        return (
            f"<ConfigPushOperation id={self.id} device_id={self.device_id} "
            f"status={self.status!r}>"
        )
|
||||
153
backend/app/models/config_template.py
Normal file
153
backend/app/models/config_template.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Config template, template tag, and template push job models."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSON, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class ConfigTemplate(Base):
    """Reusable RouterOS config template, unique by name within a tenant.

    ``content`` holds the raw template text; ``variables`` is a JSON list
    describing the template's substitutable variables (shape defined by the
    template service — confirm against the renderer).
    """

    __tablename__ = "config_templates"
    __table_args__ = (
        UniqueConstraint("tenant_id", "name", name="uq_config_templates_tenant_name"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(Text, nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Raw (unrendered) template body.
    content: Mapped[str] = mapped_column(Text, nullable=False)
    # JSON list of variable definitions; defaults to [] both client- and server-side.
    variables: Mapped[list] = mapped_column(JSON, nullable=False, default=list, server_default="[]")
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    # Relationships
    tenant: Mapped["Tenant"] = relationship("Tenant")  # type: ignore[name-defined]
    # Deleting a template deletes its tags (delete-orphan cascade).
    tags: Mapped[list["ConfigTemplateTag"]] = relationship(
        "ConfigTemplateTag", back_populates="template", cascade="all, delete-orphan"
    )

    def __repr__(self) -> str:
        return f"<ConfigTemplate id={self.id} name={self.name!r} tenant_id={self.tenant_id}>"
|
||||
|
||||
|
||||
class ConfigTemplateTag(Base):
    """Tag attached to a config template; unique by name per template."""

    __tablename__ = "config_template_tags"
    __table_args__ = (
        UniqueConstraint("template_id", "name", name="uq_config_template_tags_template_name"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Denormalized tenant scope (the template also carries tenant_id).
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    name: Mapped[str] = mapped_column(String(100), nullable=False)
    template_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("config_templates.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # Relationships
    template: Mapped["ConfigTemplate"] = relationship(
        "ConfigTemplate", back_populates="tags"
    )

    def __repr__(self) -> str:
        return f"<ConfigTemplateTag id={self.id} name={self.name!r} template_id={self.template_id}>"
|
||||
|
||||
|
||||
class TemplatePushJob(Base):
    """One rendered-template push to a single device.

    ``rendered_content`` snapshots the fully rendered config at push time, so
    the job remains meaningful even if the source template is later deleted
    (template_id is SET NULL on delete). Jobs sharing a ``rollout_id``
    presumably belong to one bulk rollout — confirm against the rollout service.
    """

    __tablename__ = "template_push_jobs"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    # Source template; nulled out (not cascaded) if the template is deleted.
    template_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("config_templates.id", ondelete="SET NULL"),
        nullable=True,
    )
    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=False,
    )
    # Groups jobs created by one bulk rollout; no FK — plain correlation id.
    rollout_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        nullable=True,
    )
    # Snapshot of the rendered config pushed to the device.
    rendered_content: Mapped[str] = mapped_column(Text, nullable=False)
    status: Mapped[str] = mapped_column(
        Text,
        nullable=False,
        default="pending",
        server_default="pending",
    )
    # Backup commit taken before the push, for revert.
    pre_push_backup_sha: Mapped[str | None] = mapped_column(Text, nullable=True)
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
    started_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    completed_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    # Relationships
    template: Mapped["ConfigTemplate | None"] = relationship("ConfigTemplate")
    device: Mapped["Device"] = relationship("Device")  # type: ignore[name-defined]

    def __repr__(self) -> str:
        return f"<TemplatePushJob id={self.id} status={self.status!r} device_id={self.device_id}>"
|
||||
214
backend/app/models/device.py
Normal file
214
backend/app/models/device.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Device, DeviceGroup, DeviceTag, and membership models."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
LargeBinary,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class DeviceStatus(str, Enum):
    """Device connection status.

    Inherits ``str`` so members compare equal to their raw values, matching
    the plain-string ``Device.status`` column.
    """
    UNKNOWN = "unknown"
    ONLINE = "online"
    OFFLINE = "offline"
|
||||
|
||||
|
||||
class Device(Base):
    """A managed RouterOS device, unique by hostname within a tenant."""

    __tablename__ = "devices"
    __table_args__ = (
        UniqueConstraint("tenant_id", "hostname", name="uq_devices_tenant_hostname"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    hostname: Mapped[str] = mapped_column(String(255), nullable=False)
    ip_address: Mapped[str] = mapped_column(String(45), nullable=False)  # IPv4 or IPv6
    # RouterOS API ports: 8728 plain, 8729 TLS (RouterOS defaults).
    api_port: Mapped[int] = mapped_column(Integer, default=8728, nullable=False)
    api_ssl_port: Mapped[int] = mapped_column(Integer, default=8729, nullable=False)
    # Hardware/software facts discovered from the device; NULL until polled.
    model: Mapped[str | None] = mapped_column(String(255), nullable=True)
    serial_number: Mapped[str | None] = mapped_column(String(255), nullable=True)
    firmware_version: Mapped[str | None] = mapped_column(String(100), nullable=True)
    routeros_version: Mapped[str | None] = mapped_column(String(100), nullable=True)
    routeros_major_version: Mapped[int | None] = mapped_column(Integer, nullable=True)
    uptime_seconds: Mapped[int | None] = mapped_column(Integer, nullable=True)
    last_cpu_load: Mapped[int | None] = mapped_column(Integer, nullable=True)
    last_memory_used_pct: Mapped[int | None] = mapped_column(Integer, nullable=True)
    architecture: Mapped[str | None] = mapped_column(Text, nullable=True)  # CPU arch (arm, arm64, mipsbe, etc.)
    preferred_channel: Mapped[str] = mapped_column(
        Text, default="stable", server_default="stable", nullable=False
    )  # Firmware release channel
    last_seen: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    # AES-256-GCM encrypted credentials (username + password JSON)
    encrypted_credentials: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
    # OpenBao Transit ciphertext (dual-write migration)
    encrypted_credentials_transit: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Optional map coordinates for the device's physical location.
    latitude: Mapped[float | None] = mapped_column(Float, nullable=True)
    longitude: Mapped[float | None] = mapped_column(Float, nullable=True)
    # Stored as a plain string; values come from DeviceStatus.
    status: Mapped[str] = mapped_column(
        String(20),
        default=DeviceStatus.UNKNOWN.value,
        nullable=False,
    )
    tls_mode: Mapped[str] = mapped_column(
        String(20),
        default="auto",
        server_default="auto",
        nullable=False,
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    # Relationships
    tenant: Mapped["Tenant"] = relationship("Tenant", back_populates="devices")  # type: ignore[name-defined]
    group_memberships: Mapped[list["DeviceGroupMembership"]] = relationship(
        "DeviceGroupMembership", back_populates="device", cascade="all, delete-orphan"
    )
    tag_assignments: Mapped[list["DeviceTagAssignment"]] = relationship(
        "DeviceTagAssignment", back_populates="device", cascade="all, delete-orphan"
    )

    def __repr__(self) -> str:
        return f"<Device id={self.id} hostname={self.hostname!r} tenant_id={self.tenant_id}>"
|
||||
|
||||
|
||||
class DeviceGroup(Base):
    """Named grouping of devices, unique by name within a tenant."""

    __tablename__ = "device_groups"
    __table_args__ = (
        UniqueConstraint("tenant_id", "name", name="uq_device_groups_tenant_name"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    preferred_channel: Mapped[str] = mapped_column(
        Text, default="stable", server_default="stable", nullable=False
    )  # Firmware release channel for the group
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    # Relationships
    tenant: Mapped["Tenant"] = relationship("Tenant", back_populates="device_groups")  # type: ignore[name-defined]
    memberships: Mapped[list["DeviceGroupMembership"]] = relationship(
        "DeviceGroupMembership", back_populates="group", cascade="all, delete-orphan"
    )

    def __repr__(self) -> str:
        return f"<DeviceGroup id={self.id} name={self.name!r} tenant_id={self.tenant_id}>"
|
||||
|
||||
|
||||
class DeviceTag(Base):
    """Free-form label for devices, unique by name within a tenant."""

    __tablename__ = "device_tags"
    __table_args__ = (
        UniqueConstraint("tenant_id", "name", name="uq_device_tags_tenant_name"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(String(100), nullable=False)
    color: Mapped[str | None] = mapped_column(String(7), nullable=True)  # hex color e.g. #FF5733

    # Relationships
    tenant: Mapped["Tenant"] = relationship("Tenant", back_populates="device_tags")  # type: ignore[name-defined]
    assignments: Mapped[list["DeviceTagAssignment"]] = relationship(
        "DeviceTagAssignment", back_populates="tag", cascade="all, delete-orphan"
    )

    def __repr__(self) -> str:
        return f"<DeviceTag id={self.id} name={self.name!r} tenant_id={self.tenant_id}>"
|
||||
|
||||
|
||||
class DeviceGroupMembership(Base):
    """Association row linking a device to a group (composite PK)."""

    __tablename__ = "device_group_memberships"

    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        primary_key=True,
    )
    group_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("device_groups.id", ondelete="CASCADE"),
        primary_key=True,
    )

    # Relationships
    device: Mapped["Device"] = relationship("Device", back_populates="group_memberships")
    group: Mapped["DeviceGroup"] = relationship("DeviceGroup", back_populates="memberships")
|
||||
|
||||
|
||||
class DeviceTagAssignment(Base):
    """Association row linking a device to a tag (composite PK)."""

    __tablename__ = "device_tag_assignments"

    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        primary_key=True,
    )
    tag_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("device_tags.id", ondelete="CASCADE"),
        primary_key=True,
    )

    # Relationships
    device: Mapped["Device"] = relationship("Device", back_populates="tag_assignments")
    tag: Mapped["DeviceTag"] = relationship("DeviceTag", back_populates="assignments")
|
||||
102
backend/app/models/firmware.py
Normal file
102
backend/app/models/firmware.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Firmware version tracking and upgrade job ORM models."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
DateTime,
|
||||
Integer,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import ForeignKey
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class FirmwareVersion(Base):
    """Cached firmware version from MikroTik download server or poller discovery.

    Not tenant-scoped — firmware versions are global data shared across all tenants.
    """
    __tablename__ = "firmware_versions"
    # NOTE(review): this UniqueConstraint is unnamed, unlike the named
    # constraints elsewhere in these models — consider naming it for
    # migration stability.
    __table_args__ = (
        UniqueConstraint("architecture", "channel", "version"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    architecture: Mapped[str] = mapped_column(Text, nullable=False)
    channel: Mapped[str] = mapped_column(Text, nullable=False)  # "stable", "long-term", "testing"
    version: Mapped[str] = mapped_column(Text, nullable=False)
    # Upstream download URL for the .npk package.
    npk_url: Mapped[str] = mapped_column(Text, nullable=False)
    # Local cache path once the package has been downloaded; NULL otherwise.
    npk_local_path: Mapped[str | None] = mapped_column(Text, nullable=True)
    npk_size_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
    # When this version entry was last refreshed from upstream.
    checked_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        return f"<FirmwareVersion arch={self.architecture} ch={self.channel} ver={self.version}>"
|
||||
|
||||
|
||||
class FirmwareUpgradeJob(Base):
    """Tracks a firmware upgrade operation for a single device.

    Multiple jobs can share a rollout_group_id for mass upgrades.
    """
    __tablename__ = "firmware_upgrade_jobs"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=False,
    )
    # Correlates jobs in one mass-upgrade rollout; no FK — plain grouping id.
    rollout_group_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        nullable=True,
    )
    target_version: Mapped[str] = mapped_column(Text, nullable=False)
    architecture: Mapped[str] = mapped_column(Text, nullable=False)
    channel: Mapped[str] = mapped_column(Text, nullable=False)
    status: Mapped[str] = mapped_column(
        Text, nullable=False, default="pending", server_default="pending"
    )
    # Backup commit taken before the upgrade, for rollback.
    pre_upgrade_backup_sha: Mapped[str | None] = mapped_column(Text, nullable=True)
    scheduled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Explicit user confirmation flag for major-version jumps.
    confirmed_major_upgrade: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False, server_default="false"
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        return f"<FirmwareUpgradeJob id={self.id} status={self.status} target={self.target_version}>"
|
||||
134
backend/app/models/key_set.py
Normal file
134
backend/app/models/key_set.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Key set and key access log models for zero-knowledge architecture."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Integer, LargeBinary, Text, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class UserKeySet(Base):
|
||||
"""Encrypted key bundle for a user.
|
||||
|
||||
Stores the RSA private key (wrapped by AUK), tenant vault key
|
||||
(wrapped by AUK), RSA public key, and key derivation salts.
|
||||
One key set per user (UNIQUE on user_id).
|
||||
"""
|
||||
|
||||
__tablename__ = "user_key_sets"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
server_default=func.gen_random_uuid(),
|
||||
)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
)
|
||||
tenant_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("tenants.id", ondelete="CASCADE"),
|
||||
nullable=True, # NULL for super_admin
|
||||
)
|
||||
encrypted_private_key: Mapped[bytes] = mapped_column(
|
||||
LargeBinary, nullable=False
|
||||
)
|
||||
private_key_nonce: Mapped[bytes] = mapped_column(
|
||||
LargeBinary, nullable=False
|
||||
)
|
||||
encrypted_vault_key: Mapped[bytes] = mapped_column(
|
||||
LargeBinary, nullable=False
|
||||
)
|
||||
vault_key_nonce: Mapped[bytes] = mapped_column(
|
||||
LargeBinary, nullable=False
|
||||
)
|
||||
public_key: Mapped[bytes] = mapped_column(
|
||||
LargeBinary, nullable=False
|
||||
)
|
||||
pbkdf2_iterations: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
server_default=func.literal_column("650000"),
|
||||
nullable=False,
|
||||
)
|
||||
pbkdf2_salt: Mapped[bytes] = mapped_column(
|
||||
LargeBinary, nullable=False
|
||||
)
|
||||
hkdf_salt: Mapped[bytes] = mapped_column(
|
||||
LargeBinary, nullable=False
|
||||
)
|
||||
key_version: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
server_default=func.literal_column("1"),
|
||||
nullable=False,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship("User") # type: ignore[name-defined]
|
||||
tenant: Mapped["Tenant | None"] = relationship("Tenant") # type: ignore[name-defined]
|
||||
|
||||
    def __repr__(self) -> str:
        """Return a concise debug representation (id, owner, key version)."""
        return f"<UserKeySet id={self.id} user_id={self.user_id} version={self.key_version}>"
|
||||
|
||||
|
||||
class KeyAccessLog(Base):
    """Immutable audit trail for key operations.

    Append-only: INSERT+SELECT only, no UPDATE/DELETE via RLS.
    """

    __tablename__ = "key_access_log"

    # Surrogate PK: generated client-side by uuid4, or server-side by
    # gen_random_uuid() for rows inserted outside the ORM.
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Owning tenant; log rows are removed when the tenant is deleted.
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    # Acting user; set to NULL (not deleted) when the user is removed,
    # so the audit trail survives user deletion.
    user_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("users.id", ondelete="SET NULL"),
        nullable=True,
    )
    action: Mapped[str] = mapped_column(Text, nullable=False)
    resource_type: Mapped[str | None] = mapped_column(Text, nullable=True)
    resource_id: Mapped[str | None] = mapped_column(Text, nullable=True)
    key_version: Mapped[int | None] = mapped_column(Integer, nullable=True)
    ip_address: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Phase 29 extensions for device credential access tracking
    device_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id"),
        nullable=True,
    )
    justification: Mapped[str | None] = mapped_column(Text, nullable=True)
    correlation_id: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Insertion timestamp, assigned by the database. No updated_at:
    # rows are never updated (append-only table).
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        """Return a concise debug representation."""
        return f"<KeyAccessLog id={self.id} action={self.action!r}>"
|
||||
74
backend/app/models/maintenance_window.py
Normal file
74
backend/app/models/maintenance_window.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Maintenance window ORM model for scheduled maintenance periods.
|
||||
|
||||
Maintenance windows allow operators to define time periods during which
|
||||
alerts are suppressed for specific devices (or all devices in a tenant).
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Text, VARCHAR, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class MaintenanceWindow(Base):
    """Scheduled maintenance window with optional alert suppression.

    device_ids is a JSONB array of device UUID strings.
    An empty array means "all devices in tenant".
    """

    __tablename__ = "maintenance_windows"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Owning tenant; windows are removed when the tenant is deleted.
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
    )
    name: Mapped[str] = mapped_column(VARCHAR(200), nullable=False)
    # JSONB array of device UUID strings; empty array = all devices in tenant.
    device_ids: Mapped[list] = mapped_column(
        JSONB,
        nullable=False,
        server_default="'[]'::jsonb",
    )
    start_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
    )
    end_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
    )
    suppress_alerts: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=True,
        server_default="true",
    )
    notes: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Creator; set to NULL (not deleted) when the user is removed.
    created_by: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("users.id", ondelete="SET NULL"),
        nullable=True,
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # Fix: added onupdate=func.now() so updated_at is refreshed on every ORM
    # update, matching the Tenant and User models; previously it only ever
    # held the insertion timestamp.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    def __repr__(self) -> str:
        """Return a concise debug representation."""
        return f"<MaintenanceWindow id={self.id} name={self.name!r}>"
|
||||
49
backend/app/models/tenant.py
Normal file
49
backend/app/models/tenant.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Tenant model — represents an MSP client organization."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, LargeBinary, Integer, String, Text, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Tenant(Base):
    """An MSP client organization; root of all tenant-scoped data."""

    __tablename__ = "tenants"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Tenant names are globally unique and indexed for lookup.
    name: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    contact_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # Refreshed on every ORM update via onupdate.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    # Zero-knowledge key management (Phase 28+29)
    encrypted_vault_key: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
    vault_key_version: Mapped[int | None] = mapped_column(Integer, nullable=True)
    openbao_key_name: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Relationships — passive_deletes=True lets the DB ON DELETE CASCADE handle cleanup
    users: Mapped[list["User"]] = relationship("User", back_populates="tenant", passive_deletes=True)  # type: ignore[name-defined]
    devices: Mapped[list["Device"]] = relationship("Device", back_populates="tenant", passive_deletes=True)  # type: ignore[name-defined]
    device_groups: Mapped[list["DeviceGroup"]] = relationship("DeviceGroup", back_populates="tenant", passive_deletes=True)  # type: ignore[name-defined]
    device_tags: Mapped[list["DeviceTag"]] = relationship("DeviceTag", back_populates="tenant", passive_deletes=True)  # type: ignore[name-defined]

    def __repr__(self) -> str:
        """Return a concise debug representation."""
        return f"<Tenant id={self.id} name={self.name!r}>"
|
||||
74
backend/app/models/user.py
Normal file
74
backend/app/models/user.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""User model with role-based access control."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, LargeBinary, SmallInteger, String, func, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class UserRole(str, Enum):
    """User roles with increasing privilege levels."""

    SUPER_ADMIN = "super_admin"    # portal-wide role (tenant_id is NULL)
    TENANT_ADMIN = "tenant_admin"  # admin within a single tenant
    OPERATOR = "operator"
    VIEWER = "viewer"              # default role for new users
|
||||
|
||||
|
||||
class User(Base):
    """Portal user with role-based access control."""

    __tablename__ = "users"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
    # Nullable: presumably SRP-only users carry no bcrypt hash -- TODO confirm.
    hashed_password: Mapped[str | None] = mapped_column(String(255), nullable=True)
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Stored as a plain string; values come from the UserRole enum.
    role: Mapped[str] = mapped_column(
        String(50),
        nullable=False,
        default=UserRole.VIEWER.value,
    )
    # tenant_id is nullable for super_admin users (portal-wide role)
    tenant_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=True,
        index=True,
    )
    # SRP zero-knowledge authentication columns (nullable during migration period)
    srp_salt: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
    srp_verifier: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
    auth_version: Mapped[int] = mapped_column(
        SmallInteger, server_default=text("1"), nullable=False
    )  # 1=bcrypt legacy, 2=SRP
    must_upgrade_auth: Mapped[bool] = mapped_column(
        Boolean, server_default=text("false"), nullable=False
    )  # True for bcrypt users who need SRP upgrade

    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    last_login: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # Refreshed on every ORM update via onupdate.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    # Relationships
    tenant: Mapped["Tenant | None"] = relationship("Tenant", back_populates="users")  # type: ignore[name-defined]

    def __repr__(self) -> str:
        """Return a concise debug representation."""
        return f"<User id={self.id} email={self.email!r} role={self.role!r}>"
|
||||
85
backend/app/models/vpn.py
Normal file
85
backend/app/models/vpn.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""VPN configuration and peer models for WireGuard management."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, LargeBinary, String, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class VpnConfig(Base):
    """Per-tenant WireGuard server configuration."""

    __tablename__ = "vpn_config"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    # Exactly one config per tenant (unique constraint).
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        unique=True,
    )
    # Stored as raw bytes; presumably encrypted at rest -- TODO confirm.
    server_private_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
    server_public_key: Mapped[str] = mapped_column(String(64), nullable=False)
    subnet: Mapped[str] = mapped_column(String(32), nullable=False, server_default="10.10.0.0/24")
    server_port: Mapped[int] = mapped_column(Integer, nullable=False, server_default="51820")
    server_address: Mapped[str] = mapped_column(String(32), nullable=False, server_default="10.10.0.1/24")
    # Public endpoint (host:port) clients connect to; optional until configured.
    endpoint: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
    is_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="false")
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False, onupdate=func.now()
    )

    # Peers are queried separately via tenant_id — no ORM relationship needed
|
||||
|
||||
# Peers are queried separately via tenant_id — no ORM relationship needed
|
||||
|
||||
|
||||
class VpnPeer(Base):
    """WireGuard peer representing a device's VPN connection."""

    __tablename__ = "vpn_peers"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        server_default=func.gen_random_uuid(),
    )
    tenant_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Exactly one peer per device (unique constraint).
    device_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("devices.id", ondelete="CASCADE"),
        nullable=False,
        unique=True,
    )
    # Stored as raw bytes; presumably encrypted at rest -- TODO confirm.
    peer_private_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
    peer_public_key: Mapped[str] = mapped_column(String(64), nullable=False)
    preshared_key: Mapped[Optional[bytes]] = mapped_column(LargeBinary, nullable=True)
    # IP assigned to this peer inside the tenant's VPN subnet.
    assigned_ip: Mapped[str] = mapped_column(String(32), nullable=False)
    additional_allowed_ips: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
    is_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="true")
    last_handshake: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False, onupdate=func.now()
    )

    # Config is queried separately via tenant_id — no ORM relationship needed
|
||||
|
||||
# Config is queried separately via tenant_id — no ORM relationship needed
|
||||
140
backend/app/observability.py
Normal file
140
backend/app/observability.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Prometheus metrics and health check infrastructure.
|
||||
|
||||
Provides:
|
||||
- setup_instrumentator(): Configures Prometheus auto-instrumentation for FastAPI
|
||||
- check_health_ready(): Verifies PostgreSQL, Redis, and NATS connectivity for readiness probes
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import structlog
|
||||
from fastapi import FastAPI
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
def setup_instrumentator(app: FastAPI) -> Instrumentator:
    """Mount Prometheus auto-instrumentation on *app* and expose /metrics.

    Every HTTP endpoint gets request counters, duration histograms, and an
    in-progress gauge. Handler labels use route templates (bounded
    cardinality); untemplated paths are ignored and health/metrics endpoints
    are excluded. Must run AFTER all routers are included so every route is
    captured.

    Returns:
        The configured Instrumentator instance.
    """
    inst = Instrumentator(
        should_group_status_codes=False,
        should_ignore_untemplated=True,
        should_respect_env_var=False,
        excluded_handlers=["/health", "/health/ready", "/metrics", "/api/health"],
    )
    inst.instrument(app)
    # Mounted at root level (not under /api prefix), hidden from OpenAPI.
    inst.expose(app, include_in_schema=False, should_gzip=True)
    logger.info("prometheus instrumentation enabled", endpoint="/metrics")
    return inst
|
||||
|
||||
|
||||
async def check_health_ready() -> dict:
    """Check readiness by verifying all critical dependencies.

    PostgreSQL, Redis, and NATS are probed concurrently (each helper applies
    its own 5-second timeout), so the worst-case latency is a single probe
    timeout rather than the previous sequential sum of all three.

    Returns:
        dict with "status" ("healthy"|"unhealthy"), "version", and "checks"
        containing per-dependency results.
    """
    from app.config import settings

    # The probes are independent and each catches its own exceptions,
    # returning a {"status", "latency_ms", "error"} dict — safe to gather.
    postgres_result, redis_result, nats_result = await asyncio.gather(
        _check_postgres(),
        _check_redis(settings.REDIS_URL),
        _check_nats(settings.NATS_URL),
    )
    checks: dict[str, dict] = {
        "postgres": postgres_result,
        "redis": redis_result,
        "nats": nats_result,
    }
    all_healthy = all(c["status"] == "up" for c in checks.values())

    return {
        "status": "healthy" if all_healthy else "unhealthy",
        "version": settings.APP_VERSION,
        "checks": checks,
    }
|
||||
|
||||
|
||||
async def _check_postgres() -> dict:
    """Verify PostgreSQL connectivity via the admin engine."""
    t0 = time.monotonic()
    try:
        from sqlalchemy import text

        from app.database import engine

        async with engine.connect() as conn:
            await asyncio.wait_for(conn.execute(text("SELECT 1")), timeout=5.0)
    except Exception as exc:
        elapsed_ms = round((time.monotonic() - t0) * 1000)
        logger.warning("health check: postgres failed", error=str(exc))
        return {"status": "down", "latency_ms": elapsed_ms, "error": str(exc)}
    elapsed_ms = round((time.monotonic() - t0) * 1000)
    return {"status": "up", "latency_ms": elapsed_ms, "error": None}
|
||||
|
||||
|
||||
async def _check_redis(redis_url: str) -> dict:
    """Verify Redis connectivity."""
    t0 = time.monotonic()
    try:
        # Imported lazily inside the try so an unavailable redis package is
        # reported as "down" rather than crashing the health endpoint.
        import redis.asyncio as aioredis

        client = aioredis.from_url(redis_url, socket_connect_timeout=5)
        try:
            await asyncio.wait_for(client.ping(), timeout=5.0)
        finally:
            await client.aclose()
    except Exception as exc:
        elapsed_ms = round((time.monotonic() - t0) * 1000)
        logger.warning("health check: redis failed", error=str(exc))
        return {"status": "down", "latency_ms": elapsed_ms, "error": str(exc)}
    elapsed_ms = round((time.monotonic() - t0) * 1000)
    return {"status": "up", "latency_ms": elapsed_ms, "error": None}
|
||||
|
||||
|
||||
async def _check_nats(nats_url: str) -> dict:
    """Verify NATS connectivity."""
    t0 = time.monotonic()
    try:
        import nats

        nc = await asyncio.wait_for(nats.connect(nats_url), timeout=5.0)
        try:
            await nc.drain()
        except Exception:
            # Best-effort shutdown; connectivity was already proven by connect.
            pass
    except Exception as exc:
        elapsed_ms = round((time.monotonic() - t0) * 1000)
        logger.warning("health check: nats failed", error=str(exc))
        return {"status": "down", "latency_ms": elapsed_ms, "error": str(exc)}
    elapsed_ms = round((time.monotonic() - t0) * 1000)
    return {"status": "up", "latency_ms": elapsed_ms, "error": None}
|
||||
1
backend/app/routers/__init__.py
Normal file
1
backend/app/routers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""FastAPI routers for all API endpoints."""
|
||||
1088
backend/app/routers/alerts.py
Normal file
1088
backend/app/routers/alerts.py
Normal file
File diff suppressed because it is too large
Load Diff
172
backend/app/routers/api_keys.py
Normal file
172
backend/app/routers/api_keys.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""API key management endpoints.
|
||||
|
||||
Tenant-scoped routes under /api/tenants/{tenant_id}/api-keys:
|
||||
- List all keys (active + revoked)
|
||||
- Create new key (returns plaintext once)
|
||||
- Revoke key (soft delete)
|
||||
|
||||
RBAC: tenant_admin or above for all operations.
|
||||
RLS enforced via get_db() (app_user engine with tenant context).
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db, set_tenant_context
|
||||
from app.middleware.rbac import require_min_role
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
from app.services.api_key_service import (
|
||||
ALLOWED_SCOPES,
|
||||
create_api_key,
|
||||
list_api_keys,
|
||||
revoke_api_key,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["api-keys"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Verify the current user is allowed to access the given tenant.

    Super admins may access any tenant; the RLS tenant context is set
    explicitly for them. All other users must belong to the tenant.
    """
    if current_user.is_super_admin:
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request/response schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ApiKeyCreate(BaseModel):
    """Request body for creating an API key."""

    # extra="forbid": unknown fields produce a validation error instead of
    # being silently ignored.
    model_config = ConfigDict(extra="forbid")
    name: str
    scopes: list[str]  # validated against ALLOWED_SCOPES in the endpoint
    expires_at: Optional[datetime] = None  # presumably None = never expires -- TODO confirm
|
||||
|
||||
|
||||
class ApiKeyResponse(BaseModel):
    """API key metadata returned to clients (no plaintext key material)."""

    model_config = ConfigDict(from_attributes=True)
    id: str
    name: str
    key_prefix: str  # non-secret prefix identifying the key in listings
    scopes: list[str]
    expires_at: Optional[str] = None
    last_used_at: Optional[str] = None
    created_at: str
    revoked_at: Optional[str] = None  # set once the key is revoked (soft delete)
|
||||
|
||||
|
||||
class ApiKeyCreateResponse(ApiKeyResponse):
    """Extended response that includes the plaintext key (shown once)."""

    # The only place the plaintext key ever appears; it is not retrievable later.
    key: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/tenants/{tenant_id}/api-keys", response_model=list[ApiKeyResponse])
async def list_keys(
    tenant_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
) -> list[dict]:
    """List all API keys for a tenant."""
    await _check_tenant_access(current_user, tenant_id, db)
    records = await list_api_keys(db, tenant_id)
    # UUID ids must be stringified for the response model.
    return [{**record, "id": str(record["id"])} for record in records]
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/api-keys",
    response_model=ApiKeyCreateResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_key(
    tenant_id: uuid.UUID,
    body: ApiKeyCreate,
    db: AsyncSession = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
) -> dict:
    """Create a new API key. The plaintext key is returned only once."""
    await _check_tenant_access(current_user, tenant_id, db)

    # Every requested scope must be in the allowed set.
    unknown = set(body.scopes) - ALLOWED_SCOPES
    if unknown:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid scopes: {', '.join(sorted(unknown))}. "
            f"Allowed: {', '.join(sorted(ALLOWED_SCOPES))}",
        )
    if not body.scopes:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="At least one scope is required.",
        )

    created = await create_api_key(
        db=db,
        tenant_id=tenant_id,
        user_id=current_user.user_id,
        name=body.name,
        scopes=body.scopes,
        expires_at=body.expires_at,
    )

    expires = created["expires_at"]
    created_ts = created["created_at"]
    return {
        "id": str(created["id"]),
        "name": created["name"],
        "key_prefix": created["key_prefix"],
        "key": created["key"],  # plaintext — shown exactly once
        "scopes": created["scopes"],
        "expires_at": expires.isoformat() if expires else None,
        "last_used_at": None,
        "created_at": created_ts.isoformat() if created_ts else None,
        "revoked_at": None,
    }
|
||||
|
||||
|
||||
@router.delete("/tenants/{tenant_id}/api-keys/{key_id}", status_code=status.HTTP_200_OK)
async def revoke_key(
    tenant_id: uuid.UUID,
    key_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
) -> dict:
    """Revoke an API key (soft delete -- sets revoked_at timestamp)."""
    await _check_tenant_access(current_user, tenant_id, db)

    # revoke_api_key returns False when the key is missing or already revoked.
    if not await revoke_api_key(db, tenant_id, key_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="API key not found or already revoked.",
        )
    return {"status": "revoked", "key_id": str(key_id)}
|
||||
294
backend/app/routers/audit_logs.py
Normal file
294
backend/app/routers/audit_logs.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Audit log API endpoints.
|
||||
|
||||
Tenant-scoped routes under /api/tenants/{tenant_id}/ for:
|
||||
- Paginated, filterable audit log listing
|
||||
- CSV export of audit logs
|
||||
|
||||
RLS enforced via get_db() (app_user engine with tenant context).
|
||||
RBAC: operator and above can view audit logs.
|
||||
|
||||
Phase 30: Audit log details are encrypted at rest via Transit (Tier 2).
|
||||
When encrypted_details is set, the router decrypts via Transit on-demand
|
||||
and returns the plaintext details in the response. Structural fields
|
||||
(action, resource_type, timestamp, ip_address) are always plaintext.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, func, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db, set_tenant_context
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["audit-logs"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Verify the current user is allowed to access the given tenant.

    Super admins may access any tenant (RLS context set explicitly);
    everyone else must belong to the requested tenant.
    """
    if current_user.is_super_admin:
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )
|
||||
|
||||
|
||||
def _require_operator(current_user: CurrentUser) -> None:
    """Raise 403 if user does not have at least operator role.

    Fix: the allowed set previously contained "admin", a role this system
    never issues (UserRole defines super_admin / tenant_admin / operator /
    viewer), so tenant admins were wrongly denied access to audit logs.
    "tenant_admin" is now accepted; "admin" is retained for backward
    compatibility in case any existing principal still carries it.
    """
    allowed = {"super_admin", "tenant_admin", "admin", "operator"}
    if current_user.role not in allowed:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="At least operator role required to view audit logs.",
        )
|
||||
|
||||
|
||||
async def _decrypt_audit_details(
    encrypted_details: str | None,
    plaintext_details: dict[str, Any] | None,
    tenant_id: str,
) -> dict[str, Any]:
    """Decrypt encrypted audit log details via Transit, falling back to plaintext.

    Priority:
    1. If encrypted_details is set, decrypt via Transit and parse as JSON.
    2. If decryption fails, return plaintext details as fallback.
    3. If neither available, return empty dict.
    """
    fallback = plaintext_details if plaintext_details else {}
    if not encrypted_details:
        return fallback
    try:
        from app.services.crypto import decrypt_data_transit

        decrypted_json = await decrypt_data_transit(encrypted_details, tenant_id)
        return json.loads(decrypted_json)
    except Exception:
        logger.warning(
            "Failed to decrypt audit details for tenant %s, using plaintext fallback",
            tenant_id,
            exc_info=True,
        )
        return fallback
|
||||
|
||||
|
||||
async def _decrypt_details_batch(
    rows: list[Any],
    tenant_id: str,
) -> list[dict[str, Any]]:
    """Decrypt encrypted_details for a batch of audit log rows concurrently.

    A semaphore caps concurrent Transit calls at 10 so a large page cannot
    overwhelm OpenBao. Rows without encrypted_details fall back to their
    plaintext details inside _decrypt_audit_details.
    """
    gate = asyncio.Semaphore(10)

    async def _decrypt_one(row: Any) -> dict[str, Any]:
        async with gate:
            return await _decrypt_audit_details(
                row.get("encrypted_details"),
                row.get("details"),
                tenant_id,
            )

    results = await asyncio.gather(*(_decrypt_one(row) for row in rows))
    return list(results)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AuditLogItem(BaseModel):
    """One audit log entry as returned by the listing endpoint."""

    id: str
    user_email: Optional[str] = None  # joined from users; None if no matching user
    action: str
    resource_type: Optional[str] = None
    resource_id: Optional[str] = None
    device_name: Optional[str] = None  # joined device hostname
    # Decrypted (or plaintext-fallback) details. Pydantic deep-copies field
    # defaults, so the shared {} literal is safe here.
    details: dict[str, Any] = {}
    ip_address: Optional[str] = None
    created_at: str
|
||||
|
||||
|
||||
class AuditLogResponse(BaseModel):
    """Paginated audit log listing."""

    items: list[AuditLogItem]
    total: int     # total matching rows across all pages
    page: int      # 1-based page number
    per_page: int
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/audit-logs",
    response_model=AuditLogResponse,
    summary="List audit logs with pagination and filters",
)
async def list_audit_logs(
    tenant_id: uuid.UUID,
    page: int = Query(default=1, ge=1),
    per_page: int = Query(default=50, ge=1, le=100),
    action: Optional[str] = Query(default=None),
    user_id: Optional[uuid.UUID] = Query(default=None),
    device_id: Optional[uuid.UUID] = Query(default=None),
    date_from: Optional[datetime] = Query(default=None),
    date_to: Optional[datetime] = Query(default=None),
    format: Optional[str] = Query(default=None, description="Set to 'csv' for CSV export"),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> Any:
    """List a tenant's audit logs, paginated, optionally exported as CSV.

    All supplied filters (action, user_id, device_id, date range) are
    combined with AND. When ``format == "csv"`` the FULL filtered result
    set is returned as a CSV attachment and the pagination parameters are
    ignored. Encrypted detail payloads are decrypted with bounded
    concurrency before serialization.

    Returns an ``AuditLogResponse`` (JSON) or a ``StreamingResponse``
    (CSV), hence the ``-> Any`` return annotation.
    """
    # Role + tenant checks; for super admins this also switches the RLS
    # tenant context on the session.
    _require_operator(current_user)
    await _check_tenant_access(current_user, tenant_id, db)

    # Build filter conditions using parameterized text fragments.
    # Values are always passed via bind params -- never interpolated --
    # so user input cannot inject SQL.
    conditions = [text("a.tenant_id = :tenant_id")]
    params: dict[str, Any] = {"tenant_id": str(tenant_id)}

    if action:
        conditions.append(text("a.action = :action"))
        params["action"] = action

    if user_id:
        conditions.append(text("a.user_id = :user_id"))
        params["user_id"] = str(user_id)

    if device_id:
        conditions.append(text("a.device_id = :device_id"))
        params["device_id"] = str(device_id)

    if date_from:
        conditions.append(text("a.created_at >= :date_from"))
        params["date_from"] = date_from.isoformat()

    if date_to:
        conditions.append(text("a.created_at <= :date_to"))
        params["date_to"] = date_to.isoformat()

    where_clause = and_(*conditions)

    # Shared SELECT columns for data queries (not used by the COUNT query,
    # which only needs the base table -- the conditions reference "a." alone).
    _data_columns = text(
        "a.id, u.email AS user_email, a.action, a.resource_type, "
        "a.resource_id, d.hostname AS device_name, a.details, "
        "a.encrypted_details, a.ip_address, a.created_at"
    )
    _data_from = text(
        "audit_logs a "
        "LEFT JOIN users u ON a.user_id = u.id "
        "LEFT JOIN devices d ON a.device_id = d.id"
    )

    # Count total matching rows (for pagination metadata).
    count_result = await db.execute(
        select(func.count()).select_from(text("audit_logs a")).where(where_clause),
        params,
    )
    total = count_result.scalar() or 0

    # CSV export -- no pagination limit; the whole filtered set is
    # materialized in memory before streaming.
    if format == "csv":
        result = await db.execute(
            select(_data_columns)
            .select_from(_data_from)
            .where(where_clause)
            .order_by(text("a.created_at DESC")),
            params,
        )
        all_rows = result.mappings().all()

        # Decrypt encrypted details concurrently
        decrypted_details = await _decrypt_details_batch(
            all_rows, str(tenant_id)
        )

        output = io.StringIO()
        writer = csv.writer(output)
        writer.writerow([
            "ID", "User Email", "Action", "Resource Type",
            "Resource ID", "Device", "Details", "IP Address", "Timestamp",
        ])
        # zip pairs each row with its decrypted details (same order as rows).
        for row, details in zip(all_rows, decrypted_details):
            details_str = json.dumps(details) if details else "{}"
            writer.writerow([
                str(row["id"]),
                row["user_email"] or "",
                row["action"],
                row["resource_type"] or "",
                row["resource_id"] or "",
                row["device_name"] or "",
                details_str,
                row["ip_address"] or "",
                str(row["created_at"]),
            ])

        output.seek(0)
        return StreamingResponse(
            iter([output.getvalue()]),
            media_type="text/csv",
            headers={"Content-Disposition": "attachment; filename=audit-logs.csv"},
        )

    # Paginated JSON path.
    offset = (page - 1) * per_page
    params["limit"] = per_page
    params["offset"] = offset

    result = await db.execute(
        select(_data_columns)
        .select_from(_data_from)
        .where(where_clause)
        .order_by(text("a.created_at DESC"))
        .limit(per_page)
        .offset(offset),
        params,
    )
    rows = result.mappings().all()

    # Decrypt encrypted details concurrently (skips rows without encrypted_details)
    decrypted_details = await _decrypt_details_batch(rows, str(tenant_id))

    items = [
        AuditLogItem(
            id=str(row["id"]),
            user_email=row["user_email"],
            action=row["action"],
            resource_type=row["resource_type"],
            resource_id=row["resource_id"],
            device_name=row["device_name"],
            details=details,
            ip_address=row["ip_address"],
            created_at=row["created_at"].isoformat() if row["created_at"] else "",
        )
        for row, details in zip(rows, decrypted_details)
    ]

    return AuditLogResponse(
        items=items,
        total=total,
        page=page,
        per_page=per_page,
    )
|
||||
1052
backend/app/routers/auth.py
Normal file
1052
backend/app/routers/auth.py
Normal file
File diff suppressed because it is too large
Load Diff
763
backend/app/routers/certificates.py
Normal file
763
backend/app/routers/certificates.py
Normal file
@@ -0,0 +1,763 @@
|
||||
"""Certificate Authority management API endpoints.
|
||||
|
||||
Provides the full certificate lifecycle for tenant CAs:
|
||||
- CA initialization and info retrieval
|
||||
- Per-device certificate signing
|
||||
- Certificate deployment via NATS to Go poller (SFTP + RouterOS import)
|
||||
- Bulk deployment across multiple devices
|
||||
- Certificate rotation and revocation
|
||||
|
||||
RLS enforced via get_db() (app_user engine with tenant context).
|
||||
RBAC: viewer = read-only (GET); tenant_admin and above = mutating actions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import nats
|
||||
import nats.aio.client
|
||||
import nats.errors
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db, set_tenant_context
|
||||
from app.middleware.rate_limit import limiter
|
||||
from app.middleware.rbac import require_min_role
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
from app.models.certificate import CertificateAuthority, DeviceCertificate
|
||||
from app.models.device import Device
|
||||
from app.schemas.certificate import (
|
||||
BulkCertDeployRequest,
|
||||
CACreateRequest,
|
||||
CAResponse,
|
||||
CertDeployResponse,
|
||||
CertSignRequest,
|
||||
DeviceCertResponse,
|
||||
)
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.ca_service import (
|
||||
generate_ca,
|
||||
get_ca_for_tenant,
|
||||
get_cert_for_deploy,
|
||||
get_device_certs,
|
||||
sign_device_cert,
|
||||
update_cert_status,
|
||||
)
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
router = APIRouter(tags=["certificates"])
|
||||
|
||||
# Module-level NATS connection for cert deployment (lazy initialized)
|
||||
_nc: nats.aio.client.Client | None = None
|
||||
|
||||
|
||||
async def _get_nats() -> nats.aio.client.Client:
    """Return the shared NATS client used for certificate deployment.

    The module-level connection is created lazily on first use and
    re-established whenever the previous one has been closed.
    """
    global _nc
    # Fast path: an open connection already exists.
    if _nc is not None and not _nc.is_closed:
        return _nc
    _nc = await nats.connect(settings.NATS_URL)
    logger.info("Certificate NATS connection established")
    return _nc
|
||||
|
||||
|
||||
async def _deploy_cert_via_nats(
    device_id: str,
    cert_pem: str,
    key_pem: str,
    cert_name: str,
    ssh_port: int = 22,
) -> dict:
    """Ask the Go poller, via NATS request/reply, to install a device cert.

    Args:
        device_id: Target device UUID string.
        cert_pem: PEM-encoded device certificate.
        key_pem: PEM-encoded device private key (decrypted).
        cert_name: Name for the cert on the device (e.g., "portal-device-cert").
        ssh_port: SSH port for SFTP upload (default 22).

    Returns:
        Dict with success, cert_name_on_device, and error fields.
        Failures are reported in-band (success=False) rather than raised.
    """
    nc = await _get_nats()

    message = {
        "device_id": device_id,
        "cert_pem": cert_pem,
        "key_pem": key_pem,
        "cert_name": cert_name,
        "ssh_port": ssh_port,
    }
    subject = f"cert.deploy.{device_id}"

    try:
        # Generous timeout: the poller performs SFTP upload + RouterOS import.
        reply = await nc.request(subject, json.dumps(message).encode(), timeout=60.0)
        return json.loads(reply.data)
    except nats.errors.TimeoutError:
        return {
            "success": False,
            "error": "Certificate deployment timed out -- device may be offline or unreachable",
        }
    except Exception as exc:
        logger.error("NATS cert deploy request failed", device_id=device_id, error=str(exc))
        return {"success": False, "error": str(exc)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _get_device_for_tenant(
    db: AsyncSession, device_id: uuid.UUID, current_user: CurrentUser
) -> Device:
    """Fetch a device by ID, raising 404 when it is not visible.

    Tenant isolation relies on the session's RLS context (see get_db),
    so a device owned by another tenant simply is not returned here.
    """
    lookup = await db.execute(
        select(Device).where(Device.id == device_id)
    )
    device = lookup.scalar_one_or_none()
    if device is not None:
        return device
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Device {device_id} not found",
    )
|
||||
|
||||
|
||||
async def _get_tenant_id(
    current_user: CurrentUser,
    db: AsyncSession,
    tenant_id_override: uuid.UUID | None = None,
) -> uuid.UUID:
    """Resolve the effective tenant for the current request.

    Super admins operate across tenants, so they must pass an explicit
    tenant (query param); the RLS context is switched to it here.
    Regular users are pinned to their own tenant_id.
    """
    if not current_user.is_super_admin:
        if current_user.tenant_id is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="No tenant context available.",
            )
        return current_user.tenant_id

    if tenant_id_override is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Super admin must provide tenant_id query parameter.",
        )
    # Point row-level security at the tenant the super admin selected.
    await set_tenant_context(db, str(tenant_id_override))
    return tenant_id_override
|
||||
|
||||
|
||||
async def _get_cert_with_tenant_check(
    db: AsyncSession, cert_id: uuid.UUID, tenant_id: uuid.UUID
) -> DeviceCertificate:
    """Load a device certificate, 404-ing unless it belongs to *tenant_id*.

    A cross-tenant certificate is reported as 404 (not 403) so callers
    cannot probe for the existence of other tenants' certificates. RLS
    should already hide foreign rows; the explicit tenant comparison is
    defense in depth.
    """
    lookup = await db.execute(
        select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
    )
    cert = lookup.scalar_one_or_none()
    if cert is None or cert.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Certificate {cert_id} not found",
        )
    return cert
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
    "/ca",
    response_model=CAResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Initialize a Certificate Authority for the tenant",
)
@limiter.limit("5/minute")
async def create_ca(
    request: Request,
    body: CACreateRequest,
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
    db: AsyncSession = Depends(get_db),
) -> CAResponse:
    """Generate a self-signed root CA for the tenant.

    Each tenant may have at most one CA. Returns 409 if a CA already exists.
    Rate-limited to 5 requests/minute; requires tenant_admin or higher.
    """
    # Resolve the effective tenant (handles super_admin override + RLS).
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)

    # Check if CA already exists
    existing = await get_ca_for_tenant(db, tenant_id)
    if existing is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Tenant already has a Certificate Authority. Delete it before creating a new one.",
        )

    # The app encryption key is handed to the service -- presumably used
    # to encrypt the CA private key at rest; confirm in ca_service.
    ca = await generate_ca(
        db,
        tenant_id,
        body.common_name,
        body.validity_years,
        settings.get_encryption_key_bytes(),
    )

    # Best-effort audit logging: failure to record the audit entry must
    # not fail the CA creation itself.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "ca_create",
            resource_type="certificate_authority", resource_id=str(ca.id),
            details={"common_name": body.common_name, "validity_years": body.validity_years},
        )
    except Exception:
        pass

    logger.info("CA created", tenant_id=str(tenant_id), ca_id=str(ca.id))
    return CAResponse.model_validate(ca)
|
||||
|
||||
|
||||
@router.get(
    "/ca",
    response_model=CAResponse,
    summary="Get tenant CA information",
)
async def get_ca(
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> CAResponse:
    """Fetch the tenant's CA record (public information, no private key).

    Raises 404 when the tenant has not initialized a CA yet.
    """
    resolved_tenant = await _get_tenant_id(current_user, db, tenant_id)
    ca = await get_ca_for_tenant(db, resolved_tenant)
    if ca is not None:
        return CAResponse.model_validate(ca)
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="No Certificate Authority configured for this tenant.",
    )
|
||||
|
||||
|
||||
@router.get(
    "/ca/pem",
    response_class=PlainTextResponse,
    summary="Download the CA public certificate (PEM)",
)
async def get_ca_pem(
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> PlainTextResponse:
    """Serve the tenant CA's public certificate as a downloadable PEM file.

    Only the public certificate (ca.cert_pem) is returned -- no private
    key material. Users can import it into their trust store to validate
    device connections.
    """
    resolved_tenant = await _get_tenant_id(current_user, db, tenant_id)
    ca = await get_ca_for_tenant(db, resolved_tenant)
    if ca is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No Certificate Authority configured for this tenant.",
        )
    download_headers = {"Content-Disposition": "attachment; filename=portal-ca.pem"}
    return PlainTextResponse(
        content=ca.cert_pem,
        media_type="application/x-pem-file",
        headers=download_headers,
    )
|
||||
|
||||
|
||||
@router.post(
    "/sign",
    response_model=DeviceCertResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Sign a certificate for a device",
)
@limiter.limit("20/minute")
async def sign_cert(
    request: Request,
    body: CertSignRequest,
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
    db: AsyncSession = Depends(get_db),
) -> DeviceCertResponse:
    """Sign a per-device TLS certificate using the tenant's CA.

    The device must belong to the tenant. The cert uses CN=hostname, SAN=IP+DNS.
    Raises 404 when the device is not visible to the tenant or when no CA
    has been initialized yet.
    """
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)

    # Verify device belongs to tenant (RLS enforces, but also get device data)
    device = await _get_device_for_tenant(db, body.device_id, current_user)

    # Get tenant CA
    ca = await get_ca_for_tenant(db, tenant_id)
    if ca is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No Certificate Authority configured. Initialize a CA first.",
        )

    # The encryption key is passed so the service can protect the new
    # cert's private key at rest -- confirm exact usage in ca_service.
    cert = await sign_device_cert(
        db,
        ca,
        body.device_id,
        device.hostname,
        device.ip_address,
        body.validity_days,
        settings.get_encryption_key_bytes(),
    )

    # Best-effort audit logging; never fail the signing on audit errors.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "cert_sign",
            resource_type="device_certificate", resource_id=str(cert.id),
            device_id=body.device_id,
            details={"hostname": device.hostname, "validity_days": body.validity_days},
        )
    except Exception:
        pass

    logger.info("Device cert signed", device_id=str(body.device_id), cert_id=str(cert.id))
    return DeviceCertResponse.model_validate(cert)
|
||||
|
||||
|
||||
@router.post(
    "/{cert_id}/deploy",
    response_model=CertDeployResponse,
    summary="Deploy a signed certificate to a device",
)
@limiter.limit("20/minute")
async def deploy_cert(
    request: Request,
    cert_id: uuid.UUID,
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
    db: AsyncSession = Depends(get_db),
) -> CertDeployResponse:
    """Deploy a signed certificate to a device via NATS/SFTP.

    The Go poller receives the cert, uploads it via SFTP, imports it,
    and assigns it to the api-ssl service on the RouterOS device.

    Status lifecycle here: -> "deploying" -> "deployed" on success;
    rolled back to "issued" on any failure. Deployment failures are
    returned in-band (success=False), not raised.
    """
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)
    cert = await _get_cert_with_tenant_check(db, cert_id, tenant_id)

    # Update status to deploying
    try:
        await update_cert_status(db, cert_id, "deploying")
    except ValueError as e:
        # Invalid state transition for the cert's current status.
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=str(e),
        )

    # Get decrypted cert data for deployment
    try:
        cert_pem, key_pem, _ca_cert_pem = await get_cert_for_deploy(
            db, cert_id, settings.get_encryption_key_bytes()
        )
    except ValueError as e:
        # Rollback status
        await update_cert_status(db, cert_id, "issued")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to prepare cert for deployment: {e}",
        )

    # Flush DB changes before NATS call so deploying status is persisted
    await db.flush()

    # Send deployment command via NATS
    result = await _deploy_cert_via_nats(
        device_id=str(cert.device_id),
        cert_pem=cert_pem,
        key_pem=key_pem,
        cert_name="portal-device-cert",
    )

    if result.get("success"):
        # Update cert status to deployed
        await update_cert_status(db, cert_id, "deployed")

        # Update device tls_mode to portal_ca
        device_result = await db.execute(
            select(Device).where(Device.id == cert.device_id)
        )
        device = device_result.scalar_one_or_none()
        if device is not None:
            device.tls_mode = "portal_ca"

        # Best-effort audit logging.
        try:
            await log_action(
                db, tenant_id, current_user.user_id, "cert_deploy",
                resource_type="device_certificate", resource_id=str(cert_id),
                device_id=cert.device_id,
                details={"cert_name_on_device": result.get("cert_name_on_device")},
            )
        except Exception:
            pass

        logger.info(
            "Certificate deployed successfully",
            cert_id=str(cert_id),
            device_id=str(cert.device_id),
            cert_name_on_device=result.get("cert_name_on_device"),
        )

        return CertDeployResponse(
            success=True,
            device_id=cert.device_id,
            cert_name_on_device=result.get("cert_name_on_device"),
        )
    else:
        # Rollback status to issued
        await update_cert_status(db, cert_id, "issued")

        logger.warning(
            "Certificate deployment failed",
            cert_id=str(cert_id),
            device_id=str(cert.device_id),
            error=result.get("error"),
        )

        return CertDeployResponse(
            success=False,
            device_id=cert.device_id,
            error=result.get("error"),
        )
|
||||
|
||||
|
||||
@router.post(
    "/deploy/bulk",
    response_model=list[CertDeployResponse],
    summary="Bulk deploy certificates to multiple devices",
)
@limiter.limit("5/minute")
async def bulk_deploy(
    request: Request,
    body: BulkCertDeployRequest,
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
    db: AsyncSession = Depends(get_db),
) -> list[CertDeployResponse]:
    """Deploy certificates to multiple devices sequentially.

    For each device: signs a cert if none exists (status=issued), then deploys.
    Sequential deployment per project patterns (no concurrent NATS calls).
    One result per requested device, in order; a failure for one device is
    recorded in its result and does not abort the remaining deployments.
    """
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)

    # Get tenant CA
    ca = await get_ca_for_tenant(db, tenant_id)
    if ca is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No Certificate Authority configured. Initialize a CA first.",
        )

    results: list[CertDeployResponse] = []

    for device_id in body.device_ids:
        try:
            # Get device info
            device = await _get_device_for_tenant(db, device_id, current_user)

            # Check if device already has an issued cert
            existing_certs = await get_device_certs(db, tenant_id, device_id)
            issued_cert = None
            for c in existing_certs:
                if c.status == "issued":
                    issued_cert = c
                    break

            # Sign a new cert if none exists in issued state
            if issued_cert is None:
                issued_cert = await sign_device_cert(
                    db,
                    ca,
                    device_id,
                    device.hostname,
                    device.ip_address,
                    730,  # Default 2 years
                    settings.get_encryption_key_bytes(),
                )
                await db.flush()

            # Deploy the cert
            await update_cert_status(db, issued_cert.id, "deploying")

            cert_pem, key_pem, _ca_cert_pem = await get_cert_for_deploy(
                db, issued_cert.id, settings.get_encryption_key_bytes()
            )

            # Persist the "deploying" status before the blocking NATS call.
            await db.flush()

            result = await _deploy_cert_via_nats(
                device_id=str(device_id),
                cert_pem=cert_pem,
                key_pem=key_pem,
                cert_name="portal-device-cert",
            )

            if result.get("success"):
                await update_cert_status(db, issued_cert.id, "deployed")
                device.tls_mode = "portal_ca"

                results.append(CertDeployResponse(
                    success=True,
                    device_id=device_id,
                    cert_name_on_device=result.get("cert_name_on_device"),
                ))
            else:
                # Roll back so the cert can be retried later.
                await update_cert_status(db, issued_cert.id, "issued")
                results.append(CertDeployResponse(
                    success=False,
                    device_id=device_id,
                    error=result.get("error"),
                ))

        except HTTPException as e:
            # E.g. device not found -- record the error and continue.
            results.append(CertDeployResponse(
                success=False,
                device_id=device_id,
                error=e.detail,
            ))
        except Exception as e:
            logger.error("Bulk deploy error", device_id=str(device_id), error=str(e))
            results.append(CertDeployResponse(
                success=False,
                device_id=device_id,
                error=str(e),
            ))

    # Best-effort audit logging of the whole batch.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "cert_bulk_deploy",
            resource_type="device_certificate",
            details={
                "device_count": len(body.device_ids),
                "successful": sum(1 for r in results if r.success),
                "failed": sum(1 for r in results if not r.success),
            },
        )
    except Exception:
        pass

    return results
|
||||
|
||||
|
||||
@router.get(
    "/devices",
    response_model=list[DeviceCertResponse],
    summary="List device certificates",
)
async def list_device_certs(
    device_id: uuid.UUID | None = Query(None, description="Filter by device ID"),
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[DeviceCertResponse]:
    """List the tenant's device certificates.

    An optional ``device_id`` narrows the result to a single device.
    Superseded certificates are excluded (per the ``get_device_certs``
    service contract).
    """
    resolved_tenant = await _get_tenant_id(current_user, db, tenant_id)
    certs = await get_device_certs(db, resolved_tenant, device_id)
    responses: list[DeviceCertResponse] = []
    for cert in certs:
        responses.append(DeviceCertResponse.model_validate(cert))
    return responses
|
||||
|
||||
|
||||
@router.post(
    "/{cert_id}/revoke",
    response_model=DeviceCertResponse,
    summary="Revoke a device certificate",
)
@limiter.limit("5/minute")
async def revoke_cert(
    request: Request,
    cert_id: uuid.UUID,
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
    db: AsyncSession = Depends(get_db),
) -> DeviceCertResponse:
    """Revoke a device certificate and reset the device TLS mode to insecure.

    Raises 409 when the certificate cannot transition to "revoked" from
    its current status. Note: this only updates database state -- no
    command is sent to the device here.
    """
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)
    cert = await _get_cert_with_tenant_check(db, cert_id, tenant_id)

    try:
        updated_cert = await update_cert_status(db, cert_id, "revoked")
    except ValueError as e:
        # Invalid state transition for the cert's current status.
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=str(e),
        )

    # Reset device tls_mode to insecure
    device_result = await db.execute(
        select(Device).where(Device.id == cert.device_id)
    )
    device = device_result.scalar_one_or_none()
    if device is not None:
        device.tls_mode = "insecure"

    # Best-effort audit logging.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "cert_revoke",
            resource_type="device_certificate", resource_id=str(cert_id),
            device_id=cert.device_id,
        )
    except Exception:
        pass

    logger.info("Certificate revoked", cert_id=str(cert_id), device_id=str(cert.device_id))
    return DeviceCertResponse.model_validate(updated_cert)
|
||||
|
||||
|
||||
@router.post(
    "/{cert_id}/rotate",
    response_model=CertDeployResponse,
    summary="Rotate a device certificate",
)
@limiter.limit("5/minute")
async def rotate_cert(
    request: Request,
    cert_id: uuid.UUID,
    tenant_id: uuid.UUID | None = Query(None, description="Tenant ID (required for super_admin)"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("tenant_admin")),
    db: AsyncSession = Depends(get_db),
) -> CertDeployResponse:
    """Rotate a device certificate: supersede the old cert, sign a new one, and deploy it.

    This is equivalent to: mark old cert as superseded, sign new cert, deploy new cert.

    NOTE(review): if the deployment of the new cert fails, the old cert
    stays "superseded" while the new one is rolled back to "issued" -- the
    old cert's status is not restored here.
    """
    tenant_id = await _get_tenant_id(current_user, db, tenant_id)
    old_cert = await _get_cert_with_tenant_check(db, cert_id, tenant_id)

    # Get the device for hostname/IP
    device_result = await db.execute(
        select(Device).where(Device.id == old_cert.device_id)
    )
    device = device_result.scalar_one_or_none()
    if device is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Device {old_cert.device_id} not found",
        )

    # Get tenant CA
    ca = await get_ca_for_tenant(db, tenant_id)
    if ca is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No Certificate Authority configured.",
        )

    # Mark old cert as superseded
    try:
        await update_cert_status(db, cert_id, "superseded")
    except ValueError as e:
        # Invalid state transition (e.g. cert already revoked).
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=str(e),
        )

    # Sign new cert
    new_cert = await sign_device_cert(
        db,
        ca,
        old_cert.device_id,
        device.hostname,
        device.ip_address,
        730,  # Default 2 years
        settings.get_encryption_key_bytes(),
    )
    await db.flush()

    # Deploy new cert
    await update_cert_status(db, new_cert.id, "deploying")

    cert_pem, key_pem, _ca_cert_pem = await get_cert_for_deploy(
        db, new_cert.id, settings.get_encryption_key_bytes()
    )

    # Persist status changes before the blocking NATS round-trip.
    await db.flush()

    result = await _deploy_cert_via_nats(
        device_id=str(old_cert.device_id),
        cert_pem=cert_pem,
        key_pem=key_pem,
        cert_name="portal-device-cert",
    )

    if result.get("success"):
        await update_cert_status(db, new_cert.id, "deployed")
        device.tls_mode = "portal_ca"

        # Best-effort audit logging.
        try:
            await log_action(
                db, tenant_id, current_user.user_id, "cert_rotate",
                resource_type="device_certificate", resource_id=str(new_cert.id),
                device_id=old_cert.device_id,
                details={
                    "old_cert_id": str(cert_id),
                    "cert_name_on_device": result.get("cert_name_on_device"),
                },
            )
        except Exception:
            pass

        logger.info(
            "Certificate rotated successfully",
            old_cert_id=str(cert_id),
            new_cert_id=str(new_cert.id),
            device_id=str(old_cert.device_id),
        )

        return CertDeployResponse(
            success=True,
            device_id=old_cert.device_id,
            cert_name_on_device=result.get("cert_name_on_device"),
        )
    else:
        # Rollback: mark new cert as issued (deploy failed)
        await update_cert_status(db, new_cert.id, "issued")

        logger.warning(
            "Certificate rotation deploy failed",
            new_cert_id=str(new_cert.id),
            device_id=str(old_cert.device_id),
            error=result.get("error"),
        )

        return CertDeployResponse(
            success=False,
            device_id=old_cert.device_id,
            error=result.get("error"),
        )
|
||||
297
backend/app/routers/clients.py
Normal file
297
backend/app/routers/clients.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Client device discovery API endpoint.
|
||||
|
||||
Fetches ARP, DHCP lease, and wireless registration data from a RouterOS device
|
||||
via the NATS command proxy, merges by MAC address, and returns a unified client list.
|
||||
|
||||
All routes are tenant-scoped under:
|
||||
/api/tenants/{tenant_id}/devices/{device_id}/clients
|
||||
|
||||
RLS is enforced via get_db() (app_user engine with tenant context).
|
||||
RBAC: viewer and above (read-only operation).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.middleware.rbac import require_min_role
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
from app.models.device import Device
|
||||
from app.services import routeros_proxy
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
router = APIRouter(tags=["clients"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Verify the current user is allowed to access the given tenant."""
    if current_user.is_super_admin:
        # Super admins may cross tenants: re-point the DB tenant context
        # (RLS) at the target tenant instead of their own.
        from app.database import set_tenant_context

        await set_tenant_context(db, str(tenant_id))
        return

    if current_user.tenant_id == tenant_id:
        return

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Access denied: you do not belong to this tenant.",
    )
||||
|
||||
|
||||
async def _check_device_online(
    db: AsyncSession, device_id: uuid.UUID
) -> Device:
    """Fetch the device and ensure it is currently online.

    Raises 404 when no row matches and 409 when the device exists but is
    not in the "online" state. Returns the Device row otherwise.
    """
    query = select(Device).where(Device.id == device_id)  # type: ignore[arg-type]
    found = (await db.execute(query)).scalar_one_or_none()

    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Device {device_id} not found",
        )

    if found.status != "online":
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Device is offline -- client discovery requires a live connection.",
        )

    return found
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MAC-address merge logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _normalize_mac(mac: str) -> str:
|
||||
"""Normalize a MAC address to uppercase colon-separated format."""
|
||||
return mac.strip().upper().replace("-", ":")
|
||||
|
||||
|
||||
def _merge_client_data(
|
||||
arp_data: list[dict[str, Any]],
|
||||
dhcp_data: list[dict[str, Any]],
|
||||
wireless_data: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Merge ARP, DHCP lease, and wireless registration data by MAC address.
|
||||
|
||||
ARP entries are the base. DHCP enriches with hostname. Wireless enriches
|
||||
with signal/tx/rx/uptime and marks the client as wireless.
|
||||
"""
|
||||
# Index DHCP leases by MAC
|
||||
dhcp_by_mac: dict[str, dict[str, Any]] = {}
|
||||
for lease in dhcp_data:
|
||||
mac_raw = lease.get("mac-address") or lease.get("active-mac-address", "")
|
||||
if mac_raw:
|
||||
dhcp_by_mac[_normalize_mac(mac_raw)] = lease
|
||||
|
||||
# Index wireless registrations by MAC
|
||||
wireless_by_mac: dict[str, dict[str, Any]] = {}
|
||||
for reg in wireless_data:
|
||||
mac_raw = reg.get("mac-address", "")
|
||||
if mac_raw:
|
||||
wireless_by_mac[_normalize_mac(mac_raw)] = reg
|
||||
|
||||
# Track which MACs we've already processed (from ARP)
|
||||
seen_macs: set[str] = set()
|
||||
clients: list[dict[str, Any]] = []
|
||||
|
||||
# Start with ARP entries as base
|
||||
for entry in arp_data:
|
||||
mac_raw = entry.get("mac-address", "")
|
||||
if not mac_raw:
|
||||
continue
|
||||
mac = _normalize_mac(mac_raw)
|
||||
if mac in seen_macs:
|
||||
continue
|
||||
seen_macs.add(mac)
|
||||
|
||||
# Determine status: ARP complete flag or dynamic flag
|
||||
is_complete = entry.get("complete", "true").lower() == "true"
|
||||
arp_status = "reachable" if is_complete else "stale"
|
||||
|
||||
client: dict[str, Any] = {
|
||||
"mac": mac,
|
||||
"ip": entry.get("address", ""),
|
||||
"interface": entry.get("interface", ""),
|
||||
"hostname": None,
|
||||
"status": arp_status,
|
||||
"signal_strength": None,
|
||||
"tx_rate": None,
|
||||
"rx_rate": None,
|
||||
"uptime": None,
|
||||
"is_wireless": False,
|
||||
}
|
||||
|
||||
# Enrich with DHCP data
|
||||
dhcp = dhcp_by_mac.get(mac)
|
||||
if dhcp:
|
||||
client["hostname"] = dhcp.get("host-name") or None
|
||||
dhcp_status = dhcp.get("status", "")
|
||||
if dhcp_status:
|
||||
client["dhcp_status"] = dhcp_status
|
||||
|
||||
# Enrich with wireless data
|
||||
wireless = wireless_by_mac.get(mac)
|
||||
if wireless:
|
||||
client["is_wireless"] = True
|
||||
client["signal_strength"] = wireless.get("signal-strength") or None
|
||||
client["tx_rate"] = wireless.get("tx-rate") or None
|
||||
client["rx_rate"] = wireless.get("rx-rate") or None
|
||||
client["uptime"] = wireless.get("uptime") or None
|
||||
|
||||
clients.append(client)
|
||||
|
||||
# Also include DHCP-only entries (no ARP match -- e.g. expired leases)
|
||||
for mac, lease in dhcp_by_mac.items():
|
||||
if mac in seen_macs:
|
||||
continue
|
||||
seen_macs.add(mac)
|
||||
|
||||
client = {
|
||||
"mac": mac,
|
||||
"ip": lease.get("active-address") or lease.get("address", ""),
|
||||
"interface": lease.get("active-server") or "",
|
||||
"hostname": lease.get("host-name") or None,
|
||||
"status": "stale", # No ARP entry = not actively reachable
|
||||
"signal_strength": None,
|
||||
"tx_rate": None,
|
||||
"rx_rate": None,
|
||||
"uptime": None,
|
||||
"is_wireless": mac in wireless_by_mac,
|
||||
}
|
||||
|
||||
wireless = wireless_by_mac.get(mac)
|
||||
if wireless:
|
||||
client["signal_strength"] = wireless.get("signal-strength") or None
|
||||
client["tx_rate"] = wireless.get("tx-rate") or None
|
||||
client["rx_rate"] = wireless.get("rx-rate") or None
|
||||
client["uptime"] = wireless.get("uptime") or None
|
||||
|
||||
clients.append(client)
|
||||
|
||||
return clients
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/clients",
    summary="List connected client devices (ARP + DHCP + wireless)",
)
async def list_clients(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Discover all client devices connected to a MikroTik device.

    Fetches ARP table, DHCP server leases, and wireless registration table
    in parallel, then merges by MAC address into a unified client list.

    Partial-failure semantics:
    - Wireless fetch failure is non-fatal (device may not have wireless interfaces).
    - DHCP fetch failure is non-fatal (device may not run a DHCP server).
    - ARP fetch failure is fatal (core data source) -> HTTP 502.

    Returns a dict with ``clients`` (merged records), ``device_id`` (str),
    and an ISO-8601 UTC ``timestamp``.

    Raises:
        HTTPException: 403 (wrong tenant), 404 (unknown device),
            409 (device offline), 502 (ARP fetch failed).
    """
    await _check_tenant_access(current_user, tenant_id, db)
    await _check_device_online(db, device_id)

    device_id_str = str(device_id)

    # Fetch all three sources in parallel. return_exceptions=True ensures a
    # failure in one fetch does not cancel the others; each result is
    # inspected individually below.
    arp_result, dhcp_result, wireless_result = await asyncio.gather(
        routeros_proxy.execute_command(device_id_str, "/ip/arp/print"),
        routeros_proxy.execute_command(device_id_str, "/ip/dhcp-server/lease/print"),
        routeros_proxy.execute_command(
            device_id_str, "/interface/wireless/registration-table/print"
        ),
        return_exceptions=True,
    )

    # ARP is required -- if it failed, return 502
    if isinstance(arp_result, Exception):
        logger.error("ARP fetch exception", device_id=device_id_str, error=str(arp_result))
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Failed to fetch ARP table: {arp_result}",
        )
    if not arp_result.get("success"):
        # Proxy responded but reported a command failure.
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=arp_result.get("error", "Failed to fetch ARP table"),
        )

    arp_data: list[dict[str, Any]] = arp_result.get("data", [])

    # DHCP is optional -- log warning and continue with empty data
    dhcp_data: list[dict[str, Any]] = []
    if isinstance(dhcp_result, Exception):
        logger.warning(
            "DHCP fetch exception (continuing without DHCP data)",
            device_id=device_id_str,
            error=str(dhcp_result),
        )
    elif not dhcp_result.get("success"):
        logger.warning(
            "DHCP fetch failed (continuing without DHCP data)",
            device_id=device_id_str,
            error=dhcp_result.get("error"),
        )
    else:
        dhcp_data = dhcp_result.get("data", [])

    # Wireless is optional -- many devices have no wireless interfaces
    wireless_data: list[dict[str, Any]] = []
    if isinstance(wireless_result, Exception):
        logger.warning(
            "Wireless fetch exception (device may not have wireless interfaces)",
            device_id=device_id_str,
            error=str(wireless_result),
        )
    elif not wireless_result.get("success"):
        logger.warning(
            "Wireless fetch failed (device may not have wireless interfaces)",
            device_id=device_id_str,
            error=wireless_result.get("error"),
        )
    else:
        wireless_data = wireless_result.get("data", [])

    # Merge by MAC address
    clients = _merge_client_data(arp_data, dhcp_data, wireless_data)

    logger.info(
        "client_discovery_complete",
        device_id=device_id_str,
        tenant_id=str(tenant_id),
        arp_count=len(arp_data),
        dhcp_count=len(dhcp_data),
        wireless_count=len(wireless_data),
        merged_count=len(clients),
    )

    return {
        "clients": clients,
        "device_id": device_id_str,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
|
||||
745
backend/app/routers/config_backups.py
Normal file
745
backend/app/routers/config_backups.py
Normal file
@@ -0,0 +1,745 @@
|
||||
"""
|
||||
Config backup API endpoints.
|
||||
|
||||
All routes are tenant-scoped under:
|
||||
/api/tenants/{tenant_id}/devices/{device_id}/config/
|
||||
|
||||
Provides:
|
||||
- GET /backups — list backup timeline
|
||||
- POST /backups — trigger manual backup
|
||||
- POST /checkpoint — create a checkpoint (restore point)
|
||||
- GET /backups/{sha}/export — retrieve export.rsc text
|
||||
- GET /backups/{sha}/binary — download backup.bin
|
||||
- POST /preview-restore — preview impact analysis before restore
|
||||
- POST /restore — restore a config version (two-phase panic-revert)
|
||||
- POST /emergency-rollback — rollback to most recent pre-push backup
|
||||
- GET /schedules — view effective backup schedule
|
||||
- PUT /schedules — create/update device-specific schedule override
|
||||
|
||||
RLS is enforced via get_db() (app_user engine with tenant context).
|
||||
RBAC: viewer = read-only (GET); operator and above = write (POST/PUT).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import timezone, datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.middleware.rate_limit import limiter
|
||||
from app.middleware.rbac import require_min_role, require_scope
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
from app.models.config_backup import ConfigBackupRun, ConfigBackupSchedule
|
||||
from app.config import settings
|
||||
from app.models.device import Device
|
||||
from app.services import backup_service, git_store
|
||||
from app.services import restore_service
|
||||
from app.services.crypto import decrypt_credentials_hybrid
|
||||
from app.services.rsc_parser import parse_rsc, validate_rsc, compute_impact
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["config-backups"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """
    Verify the current user is allowed to access the given tenant.

    - super_admin can access any tenant — re-sets DB tenant context to target tenant.
    - All other roles must match their own tenant_id.
    """
    if not current_user.is_super_admin:
        if current_user.tenant_id != tenant_id:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied: you do not belong to this tenant.",
            )
        return

    # Super admin: re-point the RLS tenant context at the target tenant.
    from app.database import set_tenant_context

    await set_tenant_context(db, str(tenant_id))
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request/Response schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RestoreRequest(BaseModel):
    """Request body for restore/preview-restore endpoints."""

    # Reject unknown fields so typos in request bodies fail loudly.
    model_config = ConfigDict(extra="forbid")
    # Git commit SHA identifying the backup version to restore/preview.
    commit_sha: str
||||
|
||||
|
||||
class ScheduleUpdate(BaseModel):
    """Request body for creating/updating a device backup schedule."""

    # Reject unknown fields so typos in request bodies fail loudly.
    model_config = ConfigDict(extra="forbid")
    # Standard 5-field cron expression controlling when backups run.
    cron_expression: str
    # Whether the schedule is active.
    enabled: bool
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/config/backups",
    summary="List backup timeline for a device",
    dependencies=[require_scope("config:read")],
)
async def list_backups(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """Return the backup timeline for a device, newest first.

    Each entry includes: id, commit_sha, trigger_type, lines_added,
    lines_removed, encryption_tier, and created_at (ISO-8601).
    """
    await _check_tenant_access(current_user, tenant_id, db)

    timeline_query = (
        select(ConfigBackupRun)
        .where(
            ConfigBackupRun.device_id == device_id,  # type: ignore[arg-type]
            ConfigBackupRun.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
        .order_by(ConfigBackupRun.created_at.desc())
    )
    rows = (await db.execute(timeline_query)).scalars().all()

    timeline: list[dict[str, Any]] = []
    for run in rows:
        timeline.append(
            {
                "id": str(run.id),
                "commit_sha": run.commit_sha,
                "trigger_type": run.trigger_type,
                "lines_added": run.lines_added,
                "lines_removed": run.lines_removed,
                "encryption_tier": run.encryption_tier,
                "created_at": run.created_at.isoformat(),
            }
        )
    return timeline
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config/backups",
    summary="Trigger a manual config backup",
    status_code=status.HTTP_201_CREATED,
    dependencies=[require_scope("config:write")],
)
@limiter.limit("20/minute")
async def trigger_backup(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Trigger an immediate manual backup for a device.

    Captures export.rsc and backup.bin via SSH, commits to the tenant's
    git store, and records a ConfigBackupRun with trigger_type='manual'.
    Returns the backup metadata dict.

    Raises:
        HTTPException: 404 if the service rejects the device/tenant,
            502 for any other backup failure.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    try:
        backup_meta = await backup_service.run_backup(
            device_id=str(device_id),
            tenant_id=str(tenant_id),
            trigger_type="manual",
            db_session=db,
        )
    except ValueError as exc:
        # The service raises ValueError for an unknown device/tenant.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
        ) from exc
    except Exception as exc:
        logger.error(
            "Manual backup failed for device %s tenant %s: %s",
            device_id,
            tenant_id,
            exc,
        )
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Backup failed: {exc}",
        ) from exc

    return backup_meta
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config/checkpoint",
    summary="Create a checkpoint (restore point) of the current config",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("5/minute")
async def create_checkpoint(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Create a checkpoint (restore point) of the current device config.

    Identical to a manual backup but tagged trigger_type='checkpoint'.
    Checkpoints are named restore points operators create before risky
    changes so they can easily roll back.

    Raises:
        HTTPException: 404 if the service rejects the device/tenant,
            502 for any other failure.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    try:
        checkpoint_meta = await backup_service.run_backup(
            device_id=str(device_id),
            tenant_id=str(tenant_id),
            trigger_type="checkpoint",
            db_session=db,
        )
    except ValueError as exc:
        # The service raises ValueError for an unknown device/tenant.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
        ) from exc
    except Exception as exc:
        logger.error(
            "Checkpoint backup failed for device %s tenant %s: %s",
            device_id,
            tenant_id,
            exc,
        )
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Checkpoint failed: {exc}",
        ) from exc

    return checkpoint_meta
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/config/backups/{commit_sha}/export",
    summary="Get export.rsc text for a specific backup",
    response_class=Response,
    dependencies=[require_scope("config:read")],
)
async def get_export(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    commit_sha: str,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> Response:
    """Return the raw /export compact text for a specific backup version.

    For encrypted backups (encryption_tier != NULL), the Transit ciphertext
    stored in git is decrypted on-demand before returning plaintext.
    Legacy plaintext backups (encryption_tier = NULL) are returned as-is.

    Raises 404 if the commit/device/file is unknown to the git store and
    500 if an encrypted backup cannot be decrypted.

    Content-Type: text/plain
    """
    await _check_tenant_access(current_user, tenant_id, db)

    try:
        # git_store.read_file is blocking; run it in a worker thread so the
        # event loop stays responsive. asyncio.to_thread replaces the
        # deprecated get_event_loop()/run_in_executor(None, ...) pattern.
        content_bytes = await asyncio.to_thread(
            git_store.read_file,
            str(tenant_id),
            commit_sha,
            str(device_id),
            "export.rsc",
        )
    except KeyError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Backup version not found: {exc}",
        ) from exc

    # Check if this backup is encrypted — decrypt via Transit if so
    result = await db.execute(
        select(ConfigBackupRun).where(
            ConfigBackupRun.commit_sha == commit_sha,
            ConfigBackupRun.device_id == device_id,
        )
    )
    backup_run = result.scalar_one_or_none()
    if backup_run and backup_run.encryption_tier:
        try:
            from app.services.crypto import decrypt_data_transit

            plaintext = await decrypt_data_transit(
                content_bytes.decode("utf-8"), str(tenant_id)
            )
            content_bytes = plaintext.encode("utf-8")
        except Exception as dec_err:
            logger.error(
                "Failed to decrypt export for device %s sha %s: %s",
                device_id, commit_sha, dec_err,
            )
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to decrypt backup content",
            ) from dec_err

    return Response(content=content_bytes, media_type="text/plain")
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/config/backups/{commit_sha}/binary",
    summary="Download backup.bin for a specific backup",
    response_class=Response,
    dependencies=[require_scope("config:read")],
)
async def get_binary(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    commit_sha: str,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> Response:
    """Download the RouterOS binary backup file for a specific backup version.

    For encrypted backups, the Transit ciphertext is decrypted and the
    base64-encoded binary is decoded back to raw bytes before returning.
    Legacy plaintext backups are returned as-is.

    Raises 404 if the commit/device/file is unknown to the git store and
    500 if an encrypted backup cannot be decrypted.

    Content-Type: application/octet-stream (attachment download).
    """
    await _check_tenant_access(current_user, tenant_id, db)

    try:
        # git_store.read_file is blocking; run it in a worker thread so the
        # event loop stays responsive. asyncio.to_thread replaces the
        # deprecated get_event_loop()/run_in_executor(None, ...) pattern.
        content_bytes = await asyncio.to_thread(
            git_store.read_file,
            str(tenant_id),
            commit_sha,
            str(device_id),
            "backup.bin",
        )
    except KeyError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Backup version not found: {exc}",
        ) from exc

    # Check if this backup is encrypted — decrypt via Transit if so
    result = await db.execute(
        select(ConfigBackupRun).where(
            ConfigBackupRun.commit_sha == commit_sha,
            ConfigBackupRun.device_id == device_id,
        )
    )
    backup_run = result.scalar_one_or_none()
    if backup_run and backup_run.encryption_tier:
        try:
            import base64 as b64

            from app.services.crypto import decrypt_data_transit

            # Transit ciphertext -> base64-encoded binary -> raw bytes
            b64_plaintext = await decrypt_data_transit(
                content_bytes.decode("utf-8"), str(tenant_id)
            )
            content_bytes = b64.b64decode(b64_plaintext)
        except Exception as dec_err:
            logger.error(
                "Failed to decrypt binary backup for device %s sha %s: %s",
                device_id, commit_sha, dec_err,
            )
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to decrypt backup content",
            ) from dec_err

    return Response(
        content=content_bytes,
        media_type="application/octet-stream",
        headers={
            "Content-Disposition": f'attachment; filename="backup-{commit_sha[:8]}.bin"'
        },
    )
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config/preview-restore",
    summary="Preview the impact of restoring a config backup",
    dependencies=[require_scope("config:read")],
)
@limiter.limit("20/minute")
async def preview_restore(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: RestoreRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Preview the impact of restoring a config backup before executing.

    Reads the target config from the git backup, fetches the current config
    from the live device (falling back to the latest backup if unreachable),
    and returns a diff with categories, risk levels, warnings, and validation.

    Raises:
        HTTPException: 404 if the requested commit is not in the git store.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    # 1. Read target export from git. Blocking I/O runs in a worker thread;
    #    asyncio.to_thread replaces the deprecated get_event_loop() pattern.
    try:
        target_bytes = await asyncio.to_thread(
            git_store.read_file,
            str(tenant_id),
            body.commit_sha,
            str(device_id),
            "export.rsc",
        )
    except KeyError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Backup export not found: {exc}",
        ) from exc

    target_text = target_bytes.decode("utf-8", errors="replace")

    # 2. Get current export from device (live) or fallback to latest backup.
    #    The broad except is deliberate best-effort: any live-fetch failure
    #    (offline device, bad credentials, decryption error) degrades to the
    #    latest stored backup instead of failing the preview.
    current_text = ""
    try:
        result = await db.execute(
            select(Device).where(Device.id == device_id)  # type: ignore[arg-type]
        )
        device = result.scalar_one_or_none()
        if device and (device.encrypted_credentials_transit or device.encrypted_credentials):
            key = settings.get_encryption_key_bytes()
            creds_json = await decrypt_credentials_hybrid(
                device.encrypted_credentials_transit,
                device.encrypted_credentials,
                str(tenant_id),
                key,
            )
            import json
            creds = json.loads(creds_json)
            current_text = await backup_service.capture_export(
                device.ip_address,
                username=creds.get("username", "admin"),
                password=creds.get("password", ""),
            )
    except Exception:
        # Fallback to latest backup in git
        logger.debug(
            "Live export failed for device %s, falling back to latest backup",
            device_id,
        )
        latest = await db.execute(
            select(ConfigBackupRun)
            .where(
                ConfigBackupRun.device_id == device_id,  # type: ignore[arg-type]
            )
            .order_by(ConfigBackupRun.created_at.desc())
            .limit(1)
        )
        latest_run = latest.scalar_one_or_none()
        if latest_run:
            try:
                current_bytes = await asyncio.to_thread(
                    git_store.read_file,
                    str(tenant_id),
                    latest_run.commit_sha,
                    str(device_id),
                    "export.rsc",
                )
                current_text = current_bytes.decode("utf-8", errors="replace")
            except Exception:
                # No readable fallback either: diff against an empty config.
                current_text = ""

    # 3. Parse and analyze
    current_parsed = parse_rsc(current_text)
    target_parsed = parse_rsc(target_text)
    validation = validate_rsc(target_text)
    impact = compute_impact(current_parsed, target_parsed)

    return {
        "diff": impact["diff"],
        "categories": impact["categories"],
        "warnings": impact["warnings"],
        "validation": validation,
    }
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config/restore",
    summary="Restore a config version (two-phase push with panic-revert)",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("5/minute")
async def restore_config_endpoint(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: RestoreRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Restore a device config to a specific backup version.

    Implements two-phase push with panic-revert:
    1. Pre-backup is taken on device (mandatory before any push)
    2. RouterOS scheduler is installed as safety net (auto-reverts if unreachable)
    3. Config is pushed via /import
    4. Wait 60s for config to settle
    5. Reachability check — remove scheduler if device is reachable
    6. Return committed/reverted/failed status

    Returns: {"status": str, "message": str, "pre_backup_sha": str}
    """
    await _check_tenant_access(current_user, tenant_id, db)

    try:
        restore_outcome = await restore_service.restore_config(
            device_id=str(device_id),
            tenant_id=str(tenant_id),
            commit_sha=body.commit_sha,
            db_session=db,
        )
    except ValueError as exc:
        # The service raises ValueError for unknown device/tenant/commit.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
        ) from exc
    except Exception as exc:
        logger.error(
            "Restore failed for device %s tenant %s commit %s: %s",
            device_id,
            tenant_id,
            body.commit_sha,
            exc,
        )
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Restore failed: {exc}",
        ) from exc

    return restore_outcome
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config/emergency-rollback",
    summary="Emergency rollback to most recent pre-push backup",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("5/minute")
async def emergency_rollback(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Emergency rollback: restore the most recent pre-push backup.

    Used when a device goes offline after a config push. Picks the newest
    backup tagged 'pre-restore', 'checkpoint', or 'pre-template-push' and
    restores it via the two-phase panic-revert process.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    rollback_query = (
        select(ConfigBackupRun)
        .where(
            ConfigBackupRun.device_id == device_id,  # type: ignore[arg-type]
            ConfigBackupRun.tenant_id == tenant_id,  # type: ignore[arg-type]
            ConfigBackupRun.trigger_type.in_(
                ["pre-restore", "checkpoint", "pre-template-push"]
            ),
        )
        .order_by(ConfigBackupRun.created_at.desc())
        .limit(1)
    )
    backup = (await db.execute(rollback_query)).scalar_one_or_none()
    if backup is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No pre-push backup found for rollback",
        )

    try:
        restore_result = await restore_service.restore_config(
            device_id=str(device_id),
            tenant_id=str(tenant_id),
            commit_sha=backup.commit_sha,
            db_session=db,
        )
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
        ) from exc
    except Exception as exc:
        logger.error(
            "Emergency rollback failed for device %s tenant %s: %s",
            device_id,
            tenant_id,
            exc,
        )
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Emergency rollback failed: {exc}",
        ) from exc

    return {
        **restore_result,
        "rolled_back_to": backup.commit_sha,
        "rolled_back_to_date": backup.created_at.isoformat(),
    }
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/config/schedules",
    summary="Get effective backup schedule for a device",
    dependencies=[require_scope("config:read")],
)
async def get_schedule(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Resolve and return the backup schedule that applies to a device.

    Resolution order: device-specific override, then the tenant-wide
    default, then a synthetic built-in default (daily at 02:00 UTC,
    enabled) when nothing is configured at all.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    async def _lookup(device_filter):
        # One tenant-scoped schedule query, varying only the device clause.
        rows = await db.execute(
            select(ConfigBackupSchedule).where(
                ConfigBackupSchedule.tenant_id == tenant_id,  # type: ignore[arg-type]
                device_filter,
            )
        )
        return rows.scalar_one_or_none()

    # Device-specific override wins over the tenant-level default.
    schedule = await _lookup(ConfigBackupSchedule.device_id == device_id)  # type: ignore[arg-type]
    if schedule is None:
        schedule = await _lookup(ConfigBackupSchedule.device_id.is_(None))  # type: ignore[union-attr]

    if schedule is None:
        # Nothing configured — synthesize the built-in default.
        return {
            "id": None,
            "tenant_id": str(tenant_id),
            "device_id": str(device_id),
            "cron_expression": "0 2 * * *",
            "enabled": True,
            "is_default": True,
        }

    return {
        "id": str(schedule.id),
        "tenant_id": str(schedule.tenant_id),
        "device_id": str(schedule.device_id) if schedule.device_id else None,
        "cron_expression": schedule.cron_expression,
        "enabled": schedule.enabled,
        # A row without a device_id is the tenant-level default.
        "is_default": schedule.device_id is None,
    }
|
||||
|
||||
|
||||
@router.put(
    "/tenants/{tenant_id}/devices/{device_id}/config/schedules",
    summary="Create or update the device-specific backup schedule",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("20/minute")
async def update_schedule(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: ScheduleUpdate,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Upsert the device-specific backup schedule override.

    Updates cron_expression/enabled on an existing override, or creates
    the override row when none exists, then hot-reloads the scheduler.
    Returns the resulting schedule.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    existing = await db.execute(
        select(ConfigBackupSchedule).where(
            ConfigBackupSchedule.tenant_id == tenant_id,  # type: ignore[arg-type]
            ConfigBackupSchedule.device_id == device_id,  # type: ignore[arg-type]
        )
    )
    schedule = existing.scalar_one_or_none()

    if schedule is not None:
        # Mutate the existing override in place.
        schedule.cron_expression = body.cron_expression
        schedule.enabled = body.enabled
    else:
        schedule = ConfigBackupSchedule(
            tenant_id=tenant_id,
            device_id=device_id,
            cron_expression=body.cron_expression,
            enabled=body.enabled,
        )
        db.add(schedule)

    await db.flush()

    # Hot-reload the scheduler so the change takes effect immediately.
    # (Imported lazily — presumably to avoid a circular import; confirm.)
    from app.services.backup_scheduler import on_schedule_change
    await on_schedule_change(tenant_id, device_id)

    return {
        "id": str(schedule.id),
        "tenant_id": str(schedule.tenant_id),
        "device_id": str(schedule.device_id),
        "cron_expression": schedule.cron_expression,
        "enabled": schedule.enabled,
        "is_default": False,
    }
|
||||
371
backend/app/routers/config_editor.py
Normal file
371
backend/app/routers/config_editor.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""
|
||||
Dynamic RouterOS config editor API endpoints.
|
||||
|
||||
All routes are tenant-scoped under:
|
||||
/api/tenants/{tenant_id}/devices/{device_id}/config-editor/
|
||||
|
||||
Proxies commands to the Go poller's CmdResponder via the RouterOS proxy service.
|
||||
|
||||
Provides:
|
||||
- GET /browse -- browse a RouterOS menu path
|
||||
- POST /add -- add a new entry
|
||||
- POST /set -- edit an existing entry
|
||||
- POST /remove -- delete an entry
|
||||
- POST /execute -- execute an arbitrary CLI command
|
||||
|
||||
RLS is enforced via get_db() (app_user engine with tenant context).
|
||||
RBAC: viewer = read-only (GET browse); operator and above = write (POST).
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import structlog
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.middleware.rate_limit import limiter
|
||||
from app.middleware.rbac import require_min_role, require_scope
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
from app.models.device import Device
|
||||
from app.security.command_blocklist import check_command_safety, check_path_safety
|
||||
from app.services import routeros_proxy
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Operational logger for this module.
logger = structlog.get_logger(__name__)
# Dedicated logger so config-editor actions land in the audit log stream.
audit_logger = structlog.get_logger("audit")

# All routes below are tenant-scoped (see module docstring for the prefix).
router = APIRouter(tags=["config-editor"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Authorize *current_user* for *tenant_id* and set the RLS context.

    Super-admins may target any tenant; all other users must belong to it
    (HTTP 403 otherwise). On success the DB session's tenant context is
    pointed at *tenant_id* so row-level security permits the operation.
    """
    from app.database import set_tenant_context

    if not current_user.is_super_admin and current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied: you do not belong to this tenant.",
        )
    # Both super-admins and same-tenant users get the RLS context set here.
    await set_tenant_context(db, str(tenant_id))
|
||||
|
||||
|
||||
async def _check_device_online(
    db: AsyncSession, device_id: uuid.UUID
) -> Device:
    """Fetch the device and require it to be online.

    Raises 404 when the device is unknown (or hidden by RLS) and 409 when
    its status is not "online"; otherwise returns the Device row.
    """
    lookup = await db.execute(
        select(Device).where(Device.id == device_id)  # type: ignore[arg-type]
    )
    device = lookup.scalar_one_or_none()
    if device is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Device {device_id} not found",
        )
    if device.status != "online":
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Device is offline \u2014 config editor requires a live connection.",
        )
    return device
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AddEntryRequest(BaseModel):
    """Body for the /add endpoint: create one entry at a RouterOS menu path."""

    # Reject unknown fields so client typos fail loudly.
    model_config = ConfigDict(extra="forbid")
    # RouterOS menu path to add under.
    path: str
    # Property name -> value pairs for the new entry.
    properties: dict[str, str]
|
||||
|
||||
|
||||
class SetEntryRequest(BaseModel):
    """Body for the /set endpoint: edit an existing entry's properties."""

    # Reject unknown fields so client typos fail loudly.
    model_config = ConfigDict(extra="forbid")
    # RouterOS menu path containing the entry.
    path: str
    entry_id: str | None = None  # Optional for singleton paths (e.g. /ip/dns)
    # Property name -> new value pairs to apply.
    properties: dict[str, str]
|
||||
|
||||
|
||||
class RemoveEntryRequest(BaseModel):
    """Body for the /remove endpoint: delete one entry from a menu path."""

    # Reject unknown fields so client typos fail loudly.
    model_config = ConfigDict(extra="forbid")
    # RouterOS menu path containing the entry.
    path: str
    # Identifier of the entry to delete (required — no singleton form).
    entry_id: str
|
||||
|
||||
|
||||
class ExecuteRequest(BaseModel):
    """Body for the /execute endpoint: run an arbitrary RouterOS CLI command."""

    # Reject unknown fields so client typos fail loudly.
    model_config = ConfigDict(extra="forbid")
    # Raw CLI command; vetted by check_command_safety before execution.
    command: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/config-editor/browse",
    summary="Browse a RouterOS menu path",
    dependencies=[require_scope("config:read")],
)
async def browse_menu(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    path: str = Query("/interface", description="RouterOS menu path to browse"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Return every entry found at *path* on the target device.

    Requires tenant membership, an online device, and a path that passes
    the safety blocklist. Raises 502 when the proxy reports failure.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    await _check_device_online(db, device_id)
    check_path_safety(path)

    proxied = await routeros_proxy.browse_menu(str(device_id), path)
    if not proxied.get("success"):
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=proxied.get("error", "Failed to browse menu path"),
        )

    # Record the (read-only) browse in the audit stream.
    audit_logger.info(
        "routeros_config_browsed",
        device_id=str(device_id),
        tenant_id=str(tenant_id),
        user_id=str(current_user.user_id),
        path=path,
    )

    return {
        "success": True,
        "entries": proxied.get("data", []),
        "error": None,
        "path": path,
    }
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config-editor/add",
    summary="Add a new entry to a RouterOS menu path",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("20/minute")
async def add_entry(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: AddEntryRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Add a new entry to a RouterOS menu path with the given properties.

    Requires tenant membership, an online device, and a write-safe path.
    Raises 502 when the proxy reports failure. DB audit persistence is
    best-effort: a failure there is logged but does not fail the request.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    await _check_device_online(db, device_id)
    check_path_safety(body.path, write=True)

    result = await routeros_proxy.add_entry(str(device_id), body.path, body.properties)

    if not result.get("success"):
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=result.get("error", "Failed to add entry"),
        )

    audit_logger.info(
        "routeros_config_added",
        device_id=str(device_id),
        tenant_id=str(tenant_id),
        user_id=str(current_user.user_id),
        user_role=current_user.role,
        path=body.path,
        success=result.get("success", False),
    )

    try:
        await log_action(
            db, tenant_id, current_user.user_id, "config_add",
            resource_type="config", resource_id=str(device_id),
            device_id=device_id,
            details={"path": body.path, "properties": body.properties},
        )
    except Exception as exc:
        # Best-effort audit persistence: keep the API response intact, but
        # surface the failure to operators (was a silent `pass`).
        logger.warning("audit log_action failed", action="config_add", error=str(exc))

    return result
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config-editor/set",
    summary="Edit an existing entry in a RouterOS menu path",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("20/minute")
async def set_entry(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: SetEntryRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Update an existing entry's properties on the device.

    ``body.entry_id`` may be None for singleton paths. Raises 502 when the
    proxy reports failure. DB audit persistence is best-effort: a failure
    there is logged but does not fail the request.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    await _check_device_online(db, device_id)
    check_path_safety(body.path, write=True)

    result = await routeros_proxy.update_entry(
        str(device_id), body.path, body.entry_id, body.properties
    )

    if not result.get("success"):
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=result.get("error", "Failed to update entry"),
        )

    audit_logger.info(
        "routeros_config_modified",
        device_id=str(device_id),
        tenant_id=str(tenant_id),
        user_id=str(current_user.user_id),
        user_role=current_user.role,
        path=body.path,
        entry_id=body.entry_id,
        success=result.get("success", False),
    )

    try:
        await log_action(
            db, tenant_id, current_user.user_id, "config_set",
            resource_type="config", resource_id=str(device_id),
            device_id=device_id,
            details={"path": body.path, "entry_id": body.entry_id, "properties": body.properties},
        )
    except Exception as exc:
        # Best-effort audit persistence: keep the API response intact, but
        # surface the failure to operators (was a silent `pass`).
        logger.warning("audit log_action failed", action="config_set", error=str(exc))

    return result
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config-editor/remove",
    summary="Delete an entry from a RouterOS menu path",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("5/minute")
async def remove_entry(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: RemoveEntryRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Remove an entry from a RouterOS menu path.

    Deliberately rate-limited tighter (5/minute) than the other write
    endpoints. Raises 502 when the proxy reports failure. DB audit
    persistence is best-effort: a failure there is logged, not raised.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    await _check_device_online(db, device_id)
    check_path_safety(body.path, write=True)

    result = await routeros_proxy.remove_entry(
        str(device_id), body.path, body.entry_id
    )

    if not result.get("success"):
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=result.get("error", "Failed to remove entry"),
        )

    audit_logger.info(
        "routeros_config_removed",
        device_id=str(device_id),
        tenant_id=str(tenant_id),
        user_id=str(current_user.user_id),
        user_role=current_user.role,
        path=body.path,
        entry_id=body.entry_id,
        success=result.get("success", False),
    )

    try:
        await log_action(
            db, tenant_id, current_user.user_id, "config_remove",
            resource_type="config", resource_id=str(device_id),
            device_id=device_id,
            details={"path": body.path, "entry_id": body.entry_id},
        )
    except Exception as exc:
        # Best-effort audit persistence: keep the API response intact, but
        # surface the failure to operators (was a silent `pass`).
        logger.warning("audit log_action failed", action="config_remove", error=str(exc))

    return result
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/config-editor/execute",
    summary="Execute an arbitrary RouterOS CLI command",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("20/minute")
async def execute_command(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: ExecuteRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Execute an arbitrary RouterOS CLI command on the device.

    The command is vetted by check_command_safety (blocklist) first.
    Raises 502 when the proxy reports failure. DB audit persistence is
    best-effort: a failure there is logged but does not fail the request.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    await _check_device_online(db, device_id)
    check_command_safety(body.command)

    result = await routeros_proxy.execute_cli(str(device_id), body.command)

    if not result.get("success"):
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=result.get("error", "Failed to execute command"),
        )

    audit_logger.info(
        "routeros_command_executed",
        device_id=str(device_id),
        tenant_id=str(tenant_id),
        user_id=str(current_user.user_id),
        user_role=current_user.role,
        command=body.command,
        success=result.get("success", False),
    )

    try:
        await log_action(
            db, tenant_id, current_user.user_id, "config_execute",
            resource_type="config", resource_id=str(device_id),
            device_id=device_id,
            details={"command": body.command},
        )
    except Exception as exc:
        # Best-effort audit persistence: keep the API response intact, but
        # surface the failure to operators (was a silent `pass`).
        logger.warning("audit log_action failed", action="config_execute", error=str(exc))

    return result
|
||||
94
backend/app/routers/device_groups.py
Normal file
94
backend/app/routers/device_groups.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Device group management API endpoints.
|
||||
|
||||
Routes: /api/tenants/{tenant_id}/device-groups
|
||||
|
||||
RBAC:
|
||||
- viewer: GET (read-only)
|
||||
- operator: POST, PUT (write)
|
||||
- tenant_admin/admin: DELETE
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.middleware.rbac import require_operator_or_above, require_tenant_admin_or_above
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
from app.routers.devices import _check_tenant_access
|
||||
from app.schemas.device import DeviceGroupCreate, DeviceGroupResponse, DeviceGroupUpdate
|
||||
from app.services import device as device_service
|
||||
|
||||
# Tenant-scoped device-group routes; mounted by the app under /api.
router = APIRouter(tags=["device-groups"])
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/device-groups",
    response_model=list[DeviceGroupResponse],
    summary="List device groups",
)
async def list_groups(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[DeviceGroupResponse]:
    """Return every device group in the tenant. Viewer role and above."""
    await _check_tenant_access(current_user, tenant_id, db)
    groups = await device_service.get_groups(db=db, tenant_id=tenant_id)
    return groups
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/device-groups",
    response_model=DeviceGroupResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Create a device group",
    dependencies=[Depends(require_operator_or_above)],
)
async def create_group(
    tenant_id: uuid.UUID,
    data: DeviceGroupCreate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceGroupResponse:
    """Create a device group in the tenant. Operator role or above."""
    await _check_tenant_access(current_user, tenant_id, db)
    created = await device_service.create_group(db=db, tenant_id=tenant_id, data=data)
    return created
|
||||
|
||||
|
||||
@router.put(
    "/tenants/{tenant_id}/device-groups/{group_id}",
    response_model=DeviceGroupResponse,
    summary="Update a device group",
    dependencies=[Depends(require_operator_or_above)],
)
async def update_group(
    tenant_id: uuid.UUID,
    group_id: uuid.UUID,
    data: DeviceGroupUpdate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceGroupResponse:
    """Apply *data* to an existing device group. Operator role or above."""
    await _check_tenant_access(current_user, tenant_id, db)
    updated = await device_service.update_group(
        db=db, tenant_id=tenant_id, group_id=group_id, data=data
    )
    return updated
|
||||
|
||||
|
||||
@router.delete(
    "/tenants/{tenant_id}/device-groups/{group_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a device group",
    dependencies=[Depends(require_tenant_admin_or_above)],
)
async def delete_group(
    tenant_id: uuid.UUID,
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Delete a device group (204 on success). Tenant-admin or above."""
    await _check_tenant_access(current_user, tenant_id, db)
    await device_service.delete_group(db=db, tenant_id=tenant_id, group_id=group_id)
|
||||
150
backend/app/routers/device_logs.py
Normal file
150
backend/app/routers/device_logs.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Device syslog fetch endpoint via NATS RouterOS proxy.
|
||||
|
||||
Provides:
|
||||
- GET /tenants/{tenant_id}/devices/{device_id}/logs -- fetch device log entries
|
||||
|
||||
RLS enforced via get_db() (app_user engine with tenant context).
|
||||
RBAC: viewer and above can read logs.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.middleware.rbac import require_min_role
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
from app.services import routeros_proxy
|
||||
|
||||
# Operational logger for this module.
logger = structlog.get_logger(__name__)

# Tenant-scoped device-log routes; mounted by the app under /api.
router = APIRouter(tags=["device-logs"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers (same pattern as config_editor.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Authorize *current_user* for *tenant_id* (HTTP 403 on mismatch).

    Super-admins get their DB session's RLS context re-pointed at the
    target tenant. NOTE(review): unlike config_editor's variant, this does
    not re-set the tenant context for regular users — presumably handled
    by get_db()/get_current_user upstream; confirm.
    """
    if current_user.is_super_admin:
        from app.database import set_tenant_context

        await set_tenant_context(db, str(tenant_id))
    elif current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied: you do not belong to this tenant.",
        )
|
||||
|
||||
|
||||
async def _check_device_exists(
    db: AsyncSession, device_id: uuid.UUID
) -> None:
    """Raise 404 unless the device exists (online status not required)."""
    from sqlalchemy import select
    from app.models.device import Device

    row = await db.execute(
        select(Device).where(Device.id == device_id)
    )
    if row.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Device {device_id} not found",
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class LogEntry(BaseModel):
    """One syslog line as returned by RouterOS /log/print."""

    # Device-local timestamp string, passed through verbatim.
    time: str
    # Comma-separated RouterOS topic list, passed through verbatim.
    topics: str
    # Log message text.
    message: str
|
||||
|
||||
|
||||
class LogsResponse(BaseModel):
    """Envelope for the /logs endpoint."""

    # Entries surviving the optional search filter, device order preserved.
    logs: list[LogEntry]
    # Stringified device UUID the logs were fetched from.
    device_id: str
    # Number of entries in `logs` (after filtering).
    count: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/logs",
    response_model=LogsResponse,
    summary="Fetch device syslog entries via RouterOS API",
    dependencies=[Depends(require_min_role("viewer"))],
)
async def get_device_logs(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    limit: int = Query(default=100, ge=1, le=500),
    topic: str | None = Query(default=None, description="Filter by log topic"),
    search: str | None = Query(default=None, description="Search in message/topics"),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> LogsResponse:
    """Fetch device log entries via the RouterOS /log/print command.

    *topic* is pushed down to the device as a RouterOS query argument;
    *search* is applied here, case-insensitively, over message and topics.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    await _check_device_exists(db, device_id)

    # Assemble the RouterOS command arguments.
    command_args = [f"=count={limit}"]
    if topic:
        command_args.append(f"?topics={topic}")

    result = await routeros_proxy.execute_command(
        str(device_id), "/log/print", args=command_args, timeout=15.0
    )

    if not result.get("success"):
        error_msg = result.get("error", "Unknown error fetching logs")
        logger.warning(
            "failed to fetch device logs",
            device_id=str(device_id),
            error=error_msg,
        )
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Failed to fetch device logs: {error_msg}",
        )

    # Parse entries, dropping any that fail the optional search filter.
    needle = search.lower() if search else None
    entries: list[LogEntry] = []
    for raw in result.get("data", []):
        parsed = LogEntry(
            time=raw.get("time", ""),
            topics=raw.get("topics", ""),
            message=raw.get("message", ""),
        )
        if needle is not None and (
            needle not in parsed.message.lower()
            and needle not in parsed.topics.lower()
        ):
            continue
        entries.append(parsed)

    return LogsResponse(
        logs=entries,
        device_id=str(device_id),
        count=len(entries),
    )
|
||||
94
backend/app/routers/device_tags.py
Normal file
94
backend/app/routers/device_tags.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Device tag management API endpoints.
|
||||
|
||||
Routes: /api/tenants/{tenant_id}/device-tags
|
||||
|
||||
RBAC:
|
||||
- viewer: GET (read-only)
|
||||
- operator: POST, PUT (write)
|
||||
- tenant_admin/admin: DELETE
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.middleware.rbac import require_operator_or_above, require_tenant_admin_or_above
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
from app.routers.devices import _check_tenant_access
|
||||
from app.schemas.device import DeviceTagCreate, DeviceTagResponse, DeviceTagUpdate
|
||||
from app.services import device as device_service
|
||||
|
||||
# Tenant-scoped device-tag routes; mounted by the app under /api.
router = APIRouter(tags=["device-tags"])
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/device-tags",
    response_model=list[DeviceTagResponse],
    summary="List device tags",
)
async def list_tags(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[DeviceTagResponse]:
    """Return every device tag in the tenant. Viewer role and above."""
    await _check_tenant_access(current_user, tenant_id, db)
    tags = await device_service.get_tags(db=db, tenant_id=tenant_id)
    return tags
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/device-tags",
    response_model=DeviceTagResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Create a device tag",
    dependencies=[Depends(require_operator_or_above)],
)
async def create_tag(
    tenant_id: uuid.UUID,
    data: DeviceTagCreate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceTagResponse:
    """Create a device tag in the tenant. Operator role or above."""
    await _check_tenant_access(current_user, tenant_id, db)
    created = await device_service.create_tag(db=db, tenant_id=tenant_id, data=data)
    return created
|
||||
|
||||
|
||||
@router.put(
    "/tenants/{tenant_id}/device-tags/{tag_id}",
    response_model=DeviceTagResponse,
    summary="Update a device tag",
    dependencies=[Depends(require_operator_or_above)],
)
async def update_tag(
    tenant_id: uuid.UUID,
    tag_id: uuid.UUID,
    data: DeviceTagUpdate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceTagResponse:
    """Apply *data* to an existing device tag. Operator role or above."""
    await _check_tenant_access(current_user, tenant_id, db)
    updated = await device_service.update_tag(
        db=db, tenant_id=tenant_id, tag_id=tag_id, data=data
    )
    return updated
|
||||
|
||||
|
||||
@router.delete(
    "/tenants/{tenant_id}/device-tags/{tag_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a device tag",
    dependencies=[Depends(require_tenant_admin_or_above)],
)
async def delete_tag(
    tenant_id: uuid.UUID,
    tag_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Delete a device tag (204 on success). Tenant-admin or above."""
    await _check_tenant_access(current_user, tenant_id, db)
    await device_service.delete_tag(db=db, tenant_id=tenant_id, tag_id=tag_id)
|
||||
452
backend/app/routers/devices.py
Normal file
452
backend/app/routers/devices.py
Normal file
@@ -0,0 +1,452 @@
|
||||
"""
|
||||
Device management API endpoints.
|
||||
|
||||
All routes are tenant-scoped under /api/tenants/{tenant_id}/devices.
|
||||
RLS is enforced via PostgreSQL — the app_user engine automatically filters
|
||||
cross-tenant data based on the SET LOCAL app.current_tenant context set by
|
||||
get_current_user dependency.
|
||||
|
||||
RBAC:
|
||||
- viewer: GET (read-only)
|
||||
- operator: POST, PUT (write)
|
||||
- admin/tenant_admin: DELETE
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.middleware.rate_limit import limiter
|
||||
from app.services.audit_service import log_action
|
||||
from app.middleware.rbac import (
|
||||
require_min_role,
|
||||
require_operator_or_above,
|
||||
require_scope,
|
||||
require_tenant_admin_or_above,
|
||||
)
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
from app.schemas.device import (
|
||||
BulkAddRequest,
|
||||
BulkAddResult,
|
||||
DeviceCreate,
|
||||
DeviceListResponse,
|
||||
DeviceResponse,
|
||||
DeviceUpdate,
|
||||
SubnetScanRequest,
|
||||
SubnetScanResponse,
|
||||
)
|
||||
from app.services import device as device_service
|
||||
from app.services.scanner import scan_subnet
|
||||
|
||||
# Tenant-scoped device routes; mounted by the app under /api.
router = APIRouter(tags=["devices"])
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Authorize *current_user* for *tenant_id*.

    Super-admins may target any tenant; their DB session's RLS context is
    re-pointed at that tenant so the operation is permitted. All other
    users must already belong to the tenant and get HTTP 403 otherwise
    (their RLS context is presumed set upstream by get_current_user).
    """
    if current_user.is_super_admin:
        from app.database import set_tenant_context

        # Switch the session's RLS context to the target tenant.
        await set_tenant_context(db, str(tenant_id))
    elif current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied: you do not belong to this tenant.",
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Device CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices",
    response_model=DeviceListResponse,
    summary="List devices with pagination and filtering",
    dependencies=[require_scope("devices:read")],
)
async def list_devices(
    tenant_id: uuid.UUID,
    page: int = Query(1, ge=1, description="Page number (1-based)"),
    page_size: int = Query(25, ge=1, le=100, description="Items per page (1-100)"),
    status_filter: Optional[str] = Query(None, alias="status"),
    search: Optional[str] = Query(None, description="Text search on hostname or IP"),
    tag_id: Optional[uuid.UUID] = Query(None),
    group_id: Optional[uuid.UUID] = Query(None),
    sort_by: str = Query("created_at", description="Field to sort by"),
    sort_order: str = Query("desc", description="asc or desc"),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceListResponse:
    """Return one page of the tenant's devices with filtering and sorting applied."""
    await _check_tenant_access(current_user, tenant_id, db)

    devices, device_count = await device_service.get_devices(
        db=db,
        tenant_id=tenant_id,
        page=page,
        page_size=page_size,
        status=status_filter,
        search=search,
        tag_id=tag_id,
        group_id=group_id,
        sort_by=sort_by,
        sort_order=sort_order,
    )
    return DeviceListResponse(
        items=devices,
        total=device_count,
        page=page,
        page_size=page_size,
    )
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/devices",
    response_model=DeviceResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Add a device (validates TCP connectivity first)",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("20/minute")
async def create_device(
    request: Request,
    tenant_id: uuid.UUID,
    data: DeviceCreate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceResponse:
    """
    Create a new device. Requires operator role or above.

    The device IP/port is TCP-probed before the record is saved.
    Credentials are encrypted with AES-256-GCM before storage and never returned.
    Audit logging is best-effort and never fails the request.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    result = await device_service.create_device(
        db=db,
        tenant_id=tenant_id,
        data=data,
        encryption_key=settings.get_encryption_key_bytes(),
    )
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "device_create",
            resource_type="device", resource_id=str(result.id),
            details={"hostname": data.hostname, "ip_address": data.ip_address},
            ip_address=request.client.host if request.client else None,
        )
    except Exception:
        # Best-effort audit: surface the failure in logs instead of silently
        # swallowing it, but never break the successful create response.
        logging.getLogger(__name__).warning(
            "Audit log write failed for device_create", exc_info=True
        )
    return result
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}",
    response_model=DeviceResponse,
    summary="Get a single device",
    dependencies=[require_scope("devices:read")],
)
async def get_device(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceResponse:
    """Return details for one device. Viewer role and above."""
    await _check_tenant_access(current_user, tenant_id, db)
    device = await device_service.get_device(
        db=db, tenant_id=tenant_id, device_id=device_id
    )
    return device
|
||||
|
||||
|
||||
@router.put(
    "/tenants/{tenant_id}/devices/{device_id}",
    response_model=DeviceResponse,
    summary="Update a device",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("20/minute")
async def update_device(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    data: DeviceUpdate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> DeviceResponse:
    """
    Update device fields. Requires operator role or above.

    Only fields explicitly set in the request body are recorded in the audit
    trail (model_dump(exclude_unset=True)). Audit logging is best-effort.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    result = await device_service.update_device(
        db=db,
        tenant_id=tenant_id,
        device_id=device_id,
        data=data,
        encryption_key=settings.get_encryption_key_bytes(),
    )
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "device_update",
            resource_type="device", resource_id=str(device_id),
            device_id=device_id,
            details={"changes": data.model_dump(exclude_unset=True)},
            ip_address=request.client.host if request.client else None,
        )
    except Exception:
        # Best-effort audit: log instead of silently swallowing the failure.
        logging.getLogger(__name__).warning(
            "Audit log write failed for device_update", exc_info=True
        )
    return result
|
||||
|
||||
|
||||
@router.delete(
    "/tenants/{tenant_id}/devices/{device_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a device",
    dependencies=[Depends(require_tenant_admin_or_above), require_scope("devices:write")],
)
@limiter.limit("5/minute")
async def delete_device(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """
    Hard-delete a device. Requires tenant_admin or above.

    The audit record is written before the delete (so the device id is still
    valid under any FK constraints); audit logging is best-effort.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "device_delete",
            resource_type="device", resource_id=str(device_id),
            device_id=device_id,
            ip_address=request.client.host if request.client else None,
        )
    except Exception:
        # Best-effort audit: log instead of silently swallowing the failure.
        logging.getLogger(__name__).warning(
            "Audit log write failed for device_delete", exc_info=True
        )
    await device_service.delete_device(db=db, tenant_id=tenant_id, device_id=device_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subnet scan and bulk add
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/devices/scan",
    response_model=SubnetScanResponse,
    summary="Scan a subnet for MikroTik devices",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("5/minute")
async def scan_devices(
    request: Request,
    tenant_id: uuid.UUID,
    data: SubnetScanRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> SubnetScanResponse:
    """
    Scan a CIDR subnet for hosts with open RouterOS API ports (8728/8729).

    Returns a list of discovered IPs for the user to review and selectively
    import — does NOT automatically add devices.

    Requires operator role or above.
    """
    # Use the shared helper instead of an inline membership check: the inline
    # version previously skipped the super-admin RLS context switch, which
    # could make the audit insert below fail under RLS for super admins.
    await _check_tenant_access(current_user, tenant_id, db)

    discovered = await scan_subnet(data.cidr)

    network = ipaddress.ip_network(data.cidr, strict=False)
    # Exclude network/broadcast addresses, except for /31 and /32 networks
    # which have no such reserved addresses.
    if network.num_addresses > 2:
        total_scanned = network.num_addresses - 2
    else:
        total_scanned = network.num_addresses

    # Audit log the scan (best-effort — never breaks the response).
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "subnet_scan",
            resource_type="network", resource_id=data.cidr,
            details={
                "cidr": data.cidr,
                "devices_found": len(discovered),
                "ip": request.client.host if request.client else None,
            },
            ip_address=request.client.host if request.client else None,
        )
    except Exception:
        logging.getLogger(__name__).warning(
            "Audit log write failed for subnet_scan", exc_info=True
        )

    return SubnetScanResponse(
        cidr=data.cidr,
        discovered=discovered,
        total_scanned=total_scanned,
        total_discovered=len(discovered),
    )
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/devices/bulk-add",
    response_model=BulkAddResult,
    status_code=status.HTTP_201_CREATED,
    summary="Bulk-add devices from scan results",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("5/minute")
async def bulk_add_devices(
    request: Request,
    tenant_id: uuid.UUID,
    data: BulkAddRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> BulkAddResult:
    """
    Add multiple devices at once from scan results.

    Per-device credentials take precedence over shared credentials.
    Devices that fail connectivity checks or validation are reported in `failed`.
    Requires operator role or above.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    added = []
    failed = []
    encryption_key = settings.get_encryption_key_bytes()

    for dev_data in data.devices:
        # Resolve credentials: per-device first, then shared.
        username = dev_data.username or data.shared_username
        password = dev_data.password or data.shared_password

        if not username or not password:
            failed.append({
                "ip_address": dev_data.ip_address,
                "error": "No credentials provided (set per-device or shared credentials)",
            })
            continue

        create_data = DeviceCreate(
            hostname=dev_data.hostname or dev_data.ip_address,
            ip_address=dev_data.ip_address,
            api_port=dev_data.api_port,
            api_ssl_port=dev_data.api_ssl_port,
            username=username,
            password=password,
        )

        try:
            device = await device_service.create_device(
                db=db,
                tenant_id=tenant_id,
                data=create_data,
                encryption_key=encryption_key,
            )
            added.append(device)
            try:
                await log_action(
                    db, tenant_id, current_user.user_id, "device_adopt",
                    resource_type="device", resource_id=str(device.id),
                    details={"hostname": create_data.hostname, "ip_address": create_data.ip_address},
                    ip_address=request.client.host if request.client else None,
                )
            except Exception:
                # Best-effort audit: log the failure but keep adopting devices.
                logging.getLogger(__name__).warning(
                    "Audit log write failed for device_adopt", exc_info=True
                )
        except HTTPException as exc:
            # create_device raises HTTPException for validation/connectivity
            # failures; report per device rather than aborting the batch.
            failed.append({"ip_address": dev_data.ip_address, "error": exc.detail})
        except Exception as exc:
            failed.append({"ip_address": dev_data.ip_address, "error": str(exc)})

    return BulkAddResult(added=added, failed=failed)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group assignment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/groups/{group_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Add device to a group",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("20/minute")
async def add_device_to_group(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Attach a device to a group. Operator role or above."""
    await _check_tenant_access(current_user, tenant_id, db)
    await device_service.assign_device_to_group(
        db, tenant_id, device_id, group_id
    )
|
||||
|
||||
|
||||
@router.delete(
    "/tenants/{tenant_id}/devices/{device_id}/groups/{group_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Remove device from a group",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("5/minute")
async def remove_device_from_group(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Detach a device from a group. Operator role or above."""
    await _check_tenant_access(current_user, tenant_id, db)
    await device_service.remove_device_from_group(
        db, tenant_id, device_id, group_id
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tag assignment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/devices/{device_id}/tags/{tag_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Add tag to a device",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("20/minute")
async def add_tag_to_device(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    tag_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Attach a tag to a device. Operator role or above."""
    await _check_tenant_access(current_user, tenant_id, db)
    await device_service.assign_tag_to_device(
        db, tenant_id, device_id, tag_id
    )
|
||||
|
||||
|
||||
@router.delete(
    "/tenants/{tenant_id}/devices/{device_id}/tags/{tag_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Remove tag from a device",
    dependencies=[Depends(require_operator_or_above), require_scope("devices:write")],
)
@limiter.limit("5/minute")
async def remove_tag_from_device(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    tag_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Detach a tag from a device. Operator role or above."""
    await _check_tenant_access(current_user, tenant_id, db)
    await device_service.remove_tag_from_device(
        db, tenant_id, device_id, tag_id
    )
|
||||
164
backend/app/routers/events.py
Normal file
164
backend/app/routers/events.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Unified events timeline API endpoint.
|
||||
|
||||
Provides a single GET endpoint that unions alert events, device status changes,
|
||||
and config backup runs into a unified timeline for the dashboard.
|
||||
|
||||
RLS enforced via get_db() (app_user engine with tenant context).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db, set_tenant_context
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["events"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Raise 403 unless *current_user* may access *tenant_id*.

    Super admins are always allowed; their DB session's RLS tenant context is
    re-pointed at the target tenant so the queries that follow are permitted.
    """
    if current_user.is_super_admin:
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unified events endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/events",
    summary="List unified events (alerts, status changes, config backups)",
)
async def list_events(
    tenant_id: uuid.UUID,
    limit: int = Query(50, ge=1, le=200, description="Max events to return"),
    event_type: Optional[str] = Query(
        None,
        description="Filter by event type: alert, status_change, config_backup",
    ),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """Return a unified list of recent events across alerts, device status, and config backups.

    Events are ordered by timestamp descending, limited to `limit` (default 50).
    RLS automatically filters to the tenant's data via the app_user session.

    Each of the three sub-queries fetches up to `limit` rows; the merged list
    is re-sorted by timestamp and truncated to `limit` at the end. The raw SQL
    below carries no tenant filter — isolation relies entirely on the RLS
    context set by _check_tenant_access / get_db.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    # Reject unknown filter values explicitly rather than silently returning
    # an empty timeline.
    if event_type and event_type not in ("alert", "status_change", "config_backup"):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="event_type must be one of: alert, status_change, config_backup",
        )

    events: list[dict[str, Any]] = []

    # 1. Alert events — one timeline entry per alert_events row.
    if not event_type or event_type == "alert":
        alert_result = await db.execute(
            text("""
                SELECT ae.id, ae.status, ae.severity, ae.metric, ae.message,
                       ae.fired_at, ae.device_id, d.hostname
                FROM alert_events ae
                LEFT JOIN devices d ON d.id = ae.device_id
                ORDER BY ae.fired_at DESC
                LIMIT :limit
            """),
            {"limit": limit},
        )
        for row in alert_result.fetchall():
            # Positional columns: 0=id, 1=status, 2=severity, 3=metric,
            # 4=message, 5=fired_at, 6=device_id, 7=hostname.
            alert_status = row[1] or "firing"
            metric = row[3] or "unknown"
            events.append({
                "id": str(row[0]),
                "event_type": "alert",
                "severity": row[2],
                "title": f"{alert_status}: {metric}",
                "description": row[4] or f"Alert {alert_status} for {metric}",
                "device_hostname": row[7],
                "device_id": str(row[6]) if row[6] else None,
                "timestamp": row[5].isoformat() if row[5] else None,
            })

    # 2. Device status changes (inferred from current status + last_seen).
    # NOTE(review): this synthesizes one event per device from its *current*
    # state — it is not a history of transitions; verify that is intended.
    if not event_type or event_type == "status_change":
        status_result = await db.execute(
            text("""
                SELECT d.id, d.hostname, d.status, d.last_seen
                FROM devices d
                WHERE d.last_seen IS NOT NULL
                ORDER BY d.last_seen DESC
                LIMIT :limit
            """),
            {"limit": limit},
        )
        for row in status_result.fetchall():
            # Positional columns: 0=id, 1=hostname, 2=status, 3=last_seen.
            device_status = row[2] or "unknown"
            hostname = row[1] or "Unknown device"
            # Anything other than "online" is surfaced as a warning.
            severity = "info" if device_status == "online" else "warning"
            events.append({
                # Synthetic id (no dedicated status-event row exists).
                "id": f"status-{row[0]}",
                "event_type": "status_change",
                "severity": severity,
                "title": f"Device {device_status}",
                "description": f"{hostname} is now {device_status}",
                "device_hostname": hostname,
                "device_id": str(row[0]),
                "timestamp": row[3].isoformat() if row[3] else None,
            })

    # 3. Config backup runs — always reported as severity "info".
    if not event_type or event_type == "config_backup":
        backup_result = await db.execute(
            text("""
                SELECT cbr.id, cbr.trigger_type, cbr.created_at,
                       cbr.device_id, d.hostname
                FROM config_backup_runs cbr
                LEFT JOIN devices d ON d.id = cbr.device_id
                ORDER BY cbr.created_at DESC
                LIMIT :limit
            """),
            {"limit": limit},
        )
        for row in backup_result.fetchall():
            # Positional columns: 0=id, 1=trigger_type, 2=created_at,
            # 3=device_id, 4=hostname.
            trigger_type = row[1] or "manual"
            hostname = row[4] or "Unknown device"
            events.append({
                "id": str(row[0]),
                "event_type": "config_backup",
                "severity": "info",
                "title": "Config backup",
                "description": f"{trigger_type} backup completed for {hostname}",
                "device_hostname": hostname,
                "device_id": str(row[3]) if row[3] else None,
                "timestamp": row[2].isoformat() if row[2] else None,
            })

    # Sort all events by timestamp descending, then apply final limit.
    # Sorting compares ISO-8601 strings; None timestamps become "" and sort
    # last under reverse=True. Assumes all timestamps share one format
    # (all-aware or all-naive) so string order matches time order — TODO confirm.
    events.sort(
        key=lambda e: e["timestamp"] or "",
        reverse=True,
    )

    return events[:limit]
|
||||
712
backend/app/routers/firmware.py
Normal file
712
backend/app/routers/firmware.py
Normal file
@@ -0,0 +1,712 @@
|
||||
"""Firmware API endpoints for version overview, cache management, preferred channel,
|
||||
and firmware upgrade orchestration.
|
||||
|
||||
Tenant-scoped routes under /api/tenants/{tenant_id}/firmware/*.
|
||||
Global routes under /api/firmware/* for version listing and admin actions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db, set_tenant_context
|
||||
from app.middleware.rate_limit import limiter
|
||||
from app.middleware.rbac import require_scope
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
router = APIRouter(tags=["firmware"])
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Raise 403 unless *current_user* may access *tenant_id*.

    Super admins are always allowed; their DB session's RLS tenant context is
    re-pointed at the target tenant so the queries that follow are permitted.
    """
    if current_user.is_super_admin:
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )
|
||||
|
||||
|
||||
class PreferredChannelRequest(BaseModel):
    """Body for the preferred-channel PATCH endpoints (device and group)."""

    # Reject unknown fields so typos surface as 422s.
    model_config = ConfigDict(extra="forbid")
    # Validated in the handlers, not here; accepted values below.
    preferred_channel: str  # "stable", "long-term", "testing"
|
||||
|
||||
|
||||
class FirmwareDownloadRequest(BaseModel):
    """Body for POST /firmware/download — identifies one NPK to cache locally."""

    # Reject unknown fields so typos surface as 422s.
    model_config = ConfigDict(extra="forbid")
    architecture: str  # RouterOS architecture string, e.g. passed to download_firmware
    channel: str  # release channel; presumably "stable"/"long-term"/"testing" — not validated here
    version: str  # exact firmware version to download
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# TENANT-SCOPED ENDPOINTS
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/firmware/overview",
    summary="Get firmware status for all devices in tenant",
    # NOTE(review): read-only endpoint gated by the "firmware:write" scope —
    # confirm whether a read scope was intended instead.
    dependencies=[require_scope("firmware:write")],
)
async def get_firmware_overview(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Return the firmware status overview for every device in the tenant.

    The DB session is used only for the tenant-access check; the overview
    itself is produced by firmware_service from the tenant id.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    # Function-scope import — presumably to avoid an import cycle; confirm.
    from app.services.firmware_service import get_firmware_overview as _get_overview
    return await _get_overview(str(tenant_id))
|
||||
|
||||
|
||||
@router.patch(
    "/tenants/{tenant_id}/devices/{device_id}/preferred-channel",
    summary="Set preferred firmware channel for a device",
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("20/minute")
async def set_device_preferred_channel(
    request: Request,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    body: PreferredChannelRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, str]:
    """Update a single device's preferred firmware channel."""
    await _check_tenant_access(current_user, tenant_id, db)

    if body.preferred_channel not in ("stable", "long-term", "testing"):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="preferred_channel must be one of: stable, long-term, testing",
        )

    # No explicit tenant filter: cross-tenant rows are hidden by RLS on the
    # app_user session (context set above).
    update_result = await db.execute(
        text("""
            UPDATE devices SET preferred_channel = :channel, updated_at = NOW()
            WHERE id = :device_id
            RETURNING id
        """),
        {"channel": body.preferred_channel, "device_id": str(device_id)},
    )
    if update_result.fetchone() is None:
        raise HTTPException(status_code=404, detail="Device not found")
    await db.commit()
    return {"status": "ok", "preferred_channel": body.preferred_channel}
|
||||
|
||||
|
||||
@router.patch(
    "/tenants/{tenant_id}/device-groups/{group_id}/preferred-channel",
    summary="Set preferred firmware channel for a device group",
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("20/minute")
async def set_group_preferred_channel(
    request: Request,
    tenant_id: uuid.UUID,
    group_id: uuid.UUID,
    body: PreferredChannelRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, str]:
    """Update a device group's preferred firmware channel."""
    await _check_tenant_access(current_user, tenant_id, db)

    if body.preferred_channel not in ("stable", "long-term", "testing"):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="preferred_channel must be one of: stable, long-term, testing",
        )

    # No explicit tenant filter: cross-tenant rows are hidden by RLS on the
    # app_user session (context set above).
    update_result = await db.execute(
        text("""
            UPDATE device_groups SET preferred_channel = :channel
            WHERE id = :group_id
            RETURNING id
        """),
        {"channel": body.preferred_channel, "group_id": str(group_id)},
    )
    if update_result.fetchone() is None:
        raise HTTPException(status_code=404, detail="Device group not found")
    await db.commit()
    return {"status": "ok", "preferred_channel": body.preferred_channel}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# GLOBAL ENDPOINTS (firmware versions are not tenant-scoped)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@router.get(
    "/firmware/versions",
    summary="List all known firmware versions from cache",
)
async def list_firmware_versions(
    architecture: Optional[str] = Query(None),
    channel: Optional[str] = Query(None),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """Return cached firmware versions, optionally filtered by arch/channel."""
    conditions: list[str] = []
    params: dict[str, Any] = {}

    if architecture:
        conditions.append("architecture = :arch")
        params["arch"] = architecture
    if channel:
        conditions.append("channel = :channel")
        params["channel"] = channel

    # The WHERE fragments are fixed strings; user input travels only through
    # bound parameters, so the f-string below is safe.
    where_clause = ""
    if conditions:
        where_clause = "WHERE " + " AND ".join(conditions)

    result = await db.execute(
        text(f"""
            SELECT id, architecture, channel, version, npk_url,
                   npk_local_path, npk_size_bytes, checked_at
            FROM firmware_versions
            {where_clause}
            ORDER BY architecture, channel, checked_at DESC
        """),
        params,
    )

    versions: list[dict[str, Any]] = []
    for row in result.fetchall():
        versions.append({
            "id": str(row[0]),
            "architecture": row[1],
            "channel": row[2],
            "version": row[3],
            "npk_url": row[4],
            "npk_local_path": row[5],
            "npk_size_bytes": row[6],
            "checked_at": row[7].isoformat() if row[7] else None,
        })
    return versions
|
||||
|
||||
|
||||
@router.post(
    "/firmware/check",
    summary="Trigger immediate firmware version check (super admin only)",
)
async def trigger_firmware_check(
    current_user: CurrentUser = Depends(get_current_user),
) -> dict[str, Any]:
    """Run firmware version discovery now and report what was found."""
    if not current_user.is_super_admin:
        raise HTTPException(status_code=403, detail="Super admin only")

    from app.services.firmware_service import check_latest_versions

    discovered = await check_latest_versions()
    return {
        "status": "ok",
        "versions_discovered": len(discovered),
        "versions": discovered,
    }
|
||||
|
||||
|
||||
@router.get(
    "/firmware/cache",
    summary="List locally cached NPK files (super admin only)",
)
async def list_firmware_cache(
    current_user: CurrentUser = Depends(get_current_user),
) -> list[dict[str, Any]]:
    """Return metadata for every locally cached NPK file."""
    if not current_user.is_super_admin:
        raise HTTPException(status_code=403, detail="Super admin only")

    from app.services.firmware_service import get_cached_firmware

    return await get_cached_firmware()
|
||||
|
||||
|
||||
@router.post(
    "/firmware/download",
    summary="Download a specific NPK to local cache (super admin only)",
)
async def download_firmware(
    body: FirmwareDownloadRequest,
    current_user: CurrentUser = Depends(get_current_user),
) -> dict[str, str]:
    """Fetch the requested NPK into the local cache and return its path."""
    if not current_user.is_super_admin:
        raise HTTPException(status_code=403, detail="Super admin only")

    from app.services.firmware_service import download_firmware as _download

    local_path = await _download(body.architecture, body.channel, body.version)
    return {"status": "ok", "path": local_path}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# UPGRADE ENDPOINTS
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class UpgradeRequest(BaseModel):
    """Body for the single-device firmware upgrade endpoint."""

    # Reject unknown fields so typos surface as 422s.
    model_config = ConfigDict(extra="forbid")
    device_id: str  # device UUID as string; cast to uuid in SQL
    target_version: str
    architecture: str  # may be empty — handler falls back to the device's stored architecture
    channel: str = "stable"
    # Safety gate — presumably the UI must set this for cross-major upgrades; confirm.
    confirmed_major_upgrade: bool = False
    scheduled_at: Optional[str] = None  # ISO datetime or None for immediate
|
||||
|
||||
|
||||
class MassUpgradeRequest(BaseModel):
    """Body for the mass (multi-device) firmware upgrade endpoint."""

    # Reject unknown fields so typos surface as 422s.
    model_config = ConfigDict(extra="forbid")
    device_ids: list[str]  # device UUIDs as strings
    target_version: str
    channel: str = "stable"
    # Safety gate — presumably required for cross-major upgrades; confirm.
    confirmed_major_upgrade: bool = False
    scheduled_at: Optional[str] = None  # ISO datetime or None for immediate
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/firmware/upgrade",
    summary="Start or schedule a single device firmware upgrade",
    status_code=status.HTTP_202_ACCEPTED,
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("20/minute")
async def start_firmware_upgrade(
    request: Request,
    tenant_id: uuid.UUID,
    body: UpgradeRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Create a firmware upgrade job for a single device.

    The job either starts immediately as a background task or is handed
    to the scheduler for ``body.scheduled_at``. Responds 202 with the
    new job id.

    Raises:
        HTTPException: 403 for viewers; 422 when ``scheduled_at`` is not
            valid ISO-8601 or the device architecture cannot be resolved.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    if current_user.role == "viewer":
        raise HTTPException(403, "Viewers cannot initiate upgrades")

    # Parse the schedule time BEFORE inserting the job. Previously
    # datetime.fromisoformat() ran only after commit, so a malformed
    # scheduled_at raised ValueError -> 500 and left an orphaned
    # "scheduled" job row behind.
    scheduled_dt: Optional[datetime] = None
    if body.scheduled_at:
        try:
            scheduled_dt = datetime.fromisoformat(body.scheduled_at)
        except ValueError:
            raise HTTPException(422, "scheduled_at must be a valid ISO-8601 datetime")

    # Look up device architecture if not provided by the caller.
    architecture = body.architecture
    if not architecture:
        dev_result = await db.execute(
            text("SELECT architecture FROM devices WHERE id = CAST(:id AS uuid)"),
            {"id": body.device_id},
        )
        dev_row = dev_result.fetchone()
        if not dev_row or not dev_row[0]:
            raise HTTPException(422, "Device architecture unknown — cannot upgrade")
        architecture = dev_row[0]

    # Create the upgrade job row.
    job_id = str(uuid.uuid4())
    await db.execute(
        text("""
            INSERT INTO firmware_upgrade_jobs
                (id, tenant_id, device_id, target_version, architecture, channel,
                 status, confirmed_major_upgrade, scheduled_at)
            VALUES
                (CAST(:id AS uuid), CAST(:tenant_id AS uuid), CAST(:device_id AS uuid),
                 :target_version, :architecture, :channel,
                 :status, :confirmed, :scheduled_at)
        """),
        {
            "id": job_id,
            "tenant_id": str(tenant_id),
            "device_id": body.device_id,
            "target_version": body.target_version,
            "architecture": architecture,
            "channel": body.channel,
            "status": "scheduled" if scheduled_dt else "pending",
            "confirmed": body.confirmed_major_upgrade,
            "scheduled_at": body.scheduled_at,
        },
    )
    await db.commit()

    # Hand off to the scheduler, or start immediately in the background.
    # Imports are lazy to avoid import cycles with the service layer.
    if scheduled_dt:
        from app.services.upgrade_service import schedule_upgrade
        schedule_upgrade(job_id, scheduled_dt)
    else:
        from app.services.upgrade_service import start_upgrade
        asyncio.create_task(start_upgrade(job_id))

    # Best-effort audit log: a logging failure must not fail the request.
    try:
        await log_action(
            db, tenant_id, current_user.user_id, "firmware_upgrade",
            resource_type="firmware", resource_id=job_id,
            device_id=uuid.UUID(body.device_id),
            details={"target_version": body.target_version, "channel": body.channel},
        )
    except Exception:
        pass

    return {"status": "accepted", "job_id": job_id}
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/firmware/mass-upgrade",
    summary="Start or schedule a mass firmware upgrade for multiple devices",
    status_code=status.HTTP_202_ACCEPTED,
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("5/minute")
async def start_mass_firmware_upgrade(
    request: Request,
    tenant_id: uuid.UUID,
    body: MassUpgradeRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Create one upgrade job per device under a shared rollout group.

    Architecture is resolved per device from the devices table. The
    rollout either starts immediately in the background or is scheduled.

    Raises:
        HTTPException: 403 for viewers; 422 for an empty device list or
            a malformed ``scheduled_at``.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    if current_user.role == "viewer":
        raise HTTPException(403, "Viewers cannot initiate upgrades")

    # Reject an empty rollout up front. Previously this silently created
    # an empty rollout group and returned 202 with zero jobs.
    if not body.device_ids:
        raise HTTPException(422, "device_ids must not be empty")

    # Validate the schedule time BEFORE inserting any job rows so a bad
    # ISO string cannot leave committed orphan jobs behind (was a 500
    # raised only after commit).
    scheduled_dt: Optional[datetime] = None
    if body.scheduled_at:
        try:
            scheduled_dt = datetime.fromisoformat(body.scheduled_at)
        except ValueError:
            raise HTTPException(422, "scheduled_at must be a valid ISO-8601 datetime")

    rollout_group_id = str(uuid.uuid4())
    jobs = []

    for device_id in body.device_ids:
        # Resolve architecture per device; fall back to "unknown" so the
        # job row is still created for the service layer to handle.
        dev_result = await db.execute(
            text("SELECT architecture FROM devices WHERE id = CAST(:id AS uuid)"),
            {"id": device_id},
        )
        dev_row = dev_result.fetchone()
        architecture = dev_row[0] if dev_row and dev_row[0] else "unknown"

        job_id = str(uuid.uuid4())
        await db.execute(
            text("""
                INSERT INTO firmware_upgrade_jobs
                    (id, tenant_id, device_id, rollout_group_id,
                     target_version, architecture, channel,
                     status, confirmed_major_upgrade, scheduled_at)
                VALUES
                    (CAST(:id AS uuid), CAST(:tenant_id AS uuid),
                     CAST(:device_id AS uuid), CAST(:group_id AS uuid),
                     :target_version, :architecture, :channel,
                     :status, :confirmed, :scheduled_at)
            """),
            {
                "id": job_id,
                "tenant_id": str(tenant_id),
                "device_id": device_id,
                "group_id": rollout_group_id,
                "target_version": body.target_version,
                "architecture": architecture,
                "channel": body.channel,
                "status": "scheduled" if scheduled_dt else "pending",
                "confirmed": body.confirmed_major_upgrade,
                "scheduled_at": body.scheduled_at,
            },
        )
        jobs.append({"job_id": job_id, "device_id": device_id, "architecture": architecture})

    await db.commit()

    # Hand off to the scheduler, or start the rollout in the background.
    # NOTE(review): unlike the single-device endpoint, no audit log entry
    # is written for mass upgrades — confirm this is intentional.
    if scheduled_dt:
        from app.services.upgrade_service import schedule_mass_upgrade
        schedule_mass_upgrade(rollout_group_id, scheduled_dt)
    else:
        from app.services.upgrade_service import start_mass_upgrade
        asyncio.create_task(start_mass_upgrade(rollout_group_id))

    return {
        "status": "accepted",
        "rollout_group_id": rollout_group_id,
        "jobs": jobs,
    }
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/firmware/upgrades",
    summary="List firmware upgrade jobs for tenant",
    dependencies=[require_scope("firmware:write")],
)
async def list_upgrade_jobs(
    tenant_id: uuid.UUID,
    upgrade_status: Optional[str] = Query(None, alias="status"),
    device_id: Optional[str] = Query(None),
    rollout_group_id: Optional[str] = Query(None),
    page: int = Query(1, ge=1),
    per_page: int = Query(50, ge=1, le=200),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Return a paginated list of firmware upgrade jobs.

    Optional filters: job status (query param ``status``), device id,
    and rollout group id. Results are newest-first by ``created_at``.
    Tenant isolation is presumably enforced by row-level security via
    get_db() — there is no explicit tenant_id predicate here; confirm.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    # WHERE clause is assembled from fixed string fragments only; all
    # user-supplied values are bound parameters, so the f-string SQL
    # below is not injectable.
    filters = ["1=1"]
    params: dict[str, Any] = {}

    if upgrade_status:
        filters.append("j.status = :status")
        params["status"] = upgrade_status
    if device_id:
        filters.append("j.device_id = CAST(:device_id AS uuid)")
        params["device_id"] = device_id
    if rollout_group_id:
        filters.append("j.rollout_group_id = CAST(:group_id AS uuid)")
        params["group_id"] = rollout_group_id

    where = " AND ".join(filters)
    offset = (page - 1) * per_page

    # Total count for pagination metadata (same filters, no limit).
    count_result = await db.execute(
        text(f"SELECT COUNT(*) FROM firmware_upgrade_jobs j WHERE {where}"),
        params,
    )
    total = count_result.scalar() or 0

    result = await db.execute(
        text(f"""
            SELECT j.id, j.device_id, j.rollout_group_id,
                   j.target_version, j.architecture, j.channel,
                   j.status, j.pre_upgrade_backup_sha, j.scheduled_at,
                   j.started_at, j.completed_at, j.error_message,
                   j.confirmed_major_upgrade, j.created_at,
                   d.hostname AS device_hostname
            FROM firmware_upgrade_jobs j
            LEFT JOIN devices d ON d.id = j.device_id
            WHERE {where}
            ORDER BY j.created_at DESC
            LIMIT :limit OFFSET :offset
        """),
        {**params, "limit": per_page, "offset": offset},
    )

    # Positional row indices follow the SELECT column order above.
    items = [
        {
            "id": str(row[0]),
            "device_id": str(row[1]),
            "rollout_group_id": str(row[2]) if row[2] else None,
            "target_version": row[3],
            "architecture": row[4],
            "channel": row[5],
            "status": row[6],
            "pre_upgrade_backup_sha": row[7],
            "scheduled_at": row[8].isoformat() if row[8] else None,
            "started_at": row[9].isoformat() if row[9] else None,
            "completed_at": row[10].isoformat() if row[10] else None,
            "error_message": row[11],
            "confirmed_major_upgrade": row[12],
            "created_at": row[13].isoformat() if row[13] else None,
            "device_hostname": row[14],
        }
        for row in result.fetchall()
    ]

    return {"items": items, "total": total, "page": page, "per_page": per_page}
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/firmware/upgrades/{job_id}",
    summary="Get single upgrade job detail",
    dependencies=[require_scope("firmware:write")],
)
async def get_upgrade_job(
    tenant_id: uuid.UUID,
    job_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Return full detail for one firmware upgrade job.

    The query filters only on job id; tenant isolation is presumably
    enforced by row-level security via get_db() — confirm, since there
    is no explicit tenant_id predicate.

    Raises:
        HTTPException: 404 if no job with ``job_id`` is visible.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    result = await db.execute(
        text("""
            SELECT j.id, j.device_id, j.rollout_group_id,
                   j.target_version, j.architecture, j.channel,
                   j.status, j.pre_upgrade_backup_sha, j.scheduled_at,
                   j.started_at, j.completed_at, j.error_message,
                   j.confirmed_major_upgrade, j.created_at,
                   d.hostname AS device_hostname
            FROM firmware_upgrade_jobs j
            LEFT JOIN devices d ON d.id = j.device_id
            WHERE j.id = CAST(:job_id AS uuid)
        """),
        {"job_id": str(job_id)},
    )
    row = result.fetchone()
    if not row:
        raise HTTPException(404, "Upgrade job not found")

    # Positional row indices follow the SELECT column order above.
    return {
        "id": str(row[0]),
        "device_id": str(row[1]),
        "rollout_group_id": str(row[2]) if row[2] else None,
        "target_version": row[3],
        "architecture": row[4],
        "channel": row[5],
        "status": row[6],
        "pre_upgrade_backup_sha": row[7],
        "scheduled_at": row[8].isoformat() if row[8] else None,
        "started_at": row[9].isoformat() if row[9] else None,
        "completed_at": row[10].isoformat() if row[10] else None,
        "error_message": row[11],
        "confirmed_major_upgrade": row[12],
        "created_at": row[13].isoformat() if row[13] else None,
        "device_hostname": row[14],
    }
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/firmware/rollouts/{rollout_group_id}",
    summary="Get mass rollout status with all jobs",
    dependencies=[require_scope("firmware:write")],
)
async def get_rollout_status(
    tenant_id: uuid.UUID,
    rollout_group_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Return the aggregate status of a mass rollout plus each job.

    Summary counters (completed/failed/paused/pending) are computed in
    Python over the fetched rows, and the hostname (or device id) of the
    first job in an in-flight state is reported as ``current_device``.

    Raises:
        HTTPException: 404 when the rollout group has no jobs.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    result = await db.execute(
        text("""
            SELECT j.id, j.device_id, j.status, j.target_version,
                   j.architecture, j.error_message, j.started_at,
                   j.completed_at, d.hostname
            FROM firmware_upgrade_jobs j
            LEFT JOIN devices d ON d.id = j.device_id
            WHERE j.rollout_group_id = CAST(:group_id AS uuid)
            ORDER BY j.created_at ASC
        """),
        {"group_id": str(rollout_group_id)},
    )
    rows = result.fetchall()

    if not rows:
        raise HTTPException(404, "Rollout group not found")

    # Compute summary counters; r[2] is the job status column.
    total = len(rows)
    completed = sum(1 for r in rows if r[2] == "completed")
    failed = sum(1 for r in rows if r[2] == "failed")
    paused = sum(1 for r in rows if r[2] == "paused")
    pending = sum(1 for r in rows if r[2] in ("pending", "scheduled"))

    # Find the device currently being upgraded (first row in an
    # in-flight status); prefer hostname, fall back to the device id.
    active_statuses = {"downloading", "uploading", "rebooting", "verifying"}
    current_device = None
    for r in rows:
        if r[2] in active_statuses:
            current_device = r[8] or str(r[1])
            break

    # Positional indices follow the SELECT column order above.
    jobs = [
        {
            "id": str(r[0]),
            "device_id": str(r[1]),
            "status": r[2],
            "target_version": r[3],
            "architecture": r[4],
            "error_message": r[5],
            "started_at": r[6].isoformat() if r[6] else None,
            "completed_at": r[7].isoformat() if r[7] else None,
            "device_hostname": r[8],
        }
        for r in rows
    ]

    return {
        "rollout_group_id": str(rollout_group_id),
        "total": total,
        "completed": completed,
        "failed": failed,
        "paused": paused,
        "pending": pending,
        "current_device": current_device,
        "jobs": jobs,
    }
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/firmware/upgrades/{job_id}/cancel",
    summary="Cancel a scheduled or pending upgrade",
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("20/minute")
async def cancel_upgrade_endpoint(
    request: Request,
    tenant_id: uuid.UUID,
    job_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, str]:
    """Cancel an upgrade job that has not started executing yet."""
    await _check_tenant_access(current_user, tenant_id, db)

    if current_user.role == "viewer":
        raise HTTPException(403, "Viewers cannot cancel upgrades")

    # Lazy import avoids a module-level cycle with the service layer.
    from app.services.upgrade_service import cancel_upgrade as _cancel

    await _cancel(str(job_id))
    return {"status": "ok", "message": "Upgrade cancelled"}
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/firmware/upgrades/{job_id}/retry",
    summary="Retry a failed upgrade",
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("20/minute")
async def retry_upgrade_endpoint(
    request: Request,
    tenant_id: uuid.UUID,
    job_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, str]:
    """Restart a previously failed upgrade job."""
    await _check_tenant_access(current_user, tenant_id, db)

    if current_user.role == "viewer":
        raise HTTPException(403, "Viewers cannot retry upgrades")

    # Lazy import avoids a module-level cycle with the service layer.
    from app.services.upgrade_service import retry_failed_upgrade as _retry

    await _retry(str(job_id))
    return {"status": "ok", "message": "Upgrade retry started"}
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/firmware/rollouts/{rollout_group_id}/resume",
    summary="Resume a paused mass rollout",
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("20/minute")
async def resume_rollout_endpoint(
    request: Request,
    tenant_id: uuid.UUID,
    rollout_group_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, str]:
    """Continue a mass rollout that was paused."""
    await _check_tenant_access(current_user, tenant_id, db)

    if current_user.role == "viewer":
        raise HTTPException(403, "Viewers cannot resume rollouts")

    # Lazy import avoids a module-level cycle with the service layer.
    from app.services.upgrade_service import resume_mass_upgrade as _resume

    await _resume(str(rollout_group_id))
    return {"status": "ok", "message": "Rollout resumed"}
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/firmware/rollouts/{rollout_group_id}/abort",
    summary="Abort remaining devices in a paused rollout",
    dependencies=[require_scope("firmware:write")],
)
@limiter.limit("5/minute")
async def abort_rollout_endpoint(
    request: Request,
    tenant_id: uuid.UUID,
    rollout_group_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Abort all not-yet-upgraded devices in a rollout group.

    Returns the number of jobs that were aborted.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    if current_user.role == "viewer":
        raise HTTPException(403, "Viewers cannot abort rollouts")

    # Lazy import avoids a module-level cycle with the service layer.
    from app.services.upgrade_service import abort_mass_upgrade as _abort

    aborted_count = await _abort(str(rollout_group_id))
    return {"status": "ok", "aborted_count": aborted_count}
|
||||
309
backend/app/routers/maintenance_windows.py
Normal file
309
backend/app/routers/maintenance_windows.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""Maintenance windows API endpoints.
|
||||
|
||||
Tenant-scoped routes under /api/tenants/{tenant_id}/ for:
|
||||
- Maintenance window CRUD (list, create, update, delete)
|
||||
- Filterable by status: upcoming, active, past
|
||||
|
||||
RLS enforced via get_db() (app_user engine with tenant context).
|
||||
RBAC: operator and above for all operations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db, set_tenant_context
|
||||
from app.middleware.rate_limit import limiter
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["maintenance-windows"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Verify the current user is allowed to access the given tenant.

    Super admins may act on any tenant; for them the session's RLS
    tenant context is re-pointed at the target tenant. Everyone else
    must belong to the tenant, otherwise a 403 is raised.
    """
    if current_user.is_super_admin:
        # Re-point RLS so queries in this request see the target tenant.
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id == tenant_id:
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Access denied to this tenant",
    )
|
||||
|
||||
|
||||
def _require_operator(current_user: CurrentUser) -> None:
    """Raise 403 unless the user holds at least the operator role.

    Only the "viewer" role is rejected; every other role is assumed to
    be operator-or-above.
    """
    if current_user.role != "viewer":
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Requires at least operator role.",
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request/response schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MaintenanceWindowCreate(BaseModel):
    """Payload for creating a maintenance window."""

    # Reject unknown fields so client typos fail loudly.
    model_config = ConfigDict(extra="forbid")
    name: str
    # Devices covered by the window (UUID strings); defaults to none.
    device_ids: list[str] = []
    start_at: datetime
    end_at: datetime
    # When True, alerting for covered devices is suppressed during the
    # window (presumably consumed by the alerting pipeline — confirm).
    suppress_alerts: bool = True
    notes: Optional[str] = None
|
||||
|
||||
|
||||
class MaintenanceWindowUpdate(BaseModel):
    """Partial-update payload: only non-None fields are applied."""

    # Reject unknown fields so client typos fail loudly.
    model_config = ConfigDict(extra="forbid")
    name: Optional[str] = None
    device_ids: Optional[list[str]] = None
    start_at: Optional[datetime] = None
    end_at: Optional[datetime] = None
    suppress_alerts: Optional[bool] = None
    notes: Optional[str] = None
|
||||
|
||||
|
||||
class MaintenanceWindowResponse(BaseModel):
    """Serialized maintenance window as returned by the API.

    Timestamps are ISO-8601 strings; ids are stringified UUIDs.
    """

    model_config = ConfigDict(extra="forbid")
    id: str
    tenant_id: str
    name: str
    device_ids: list[str]
    start_at: str
    end_at: str
    suppress_alerts: bool
    notes: Optional[str] = None
    created_by: Optional[str] = None
    created_at: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CRUD endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/maintenance-windows",
    summary="List maintenance windows for tenant",
)
async def list_maintenance_windows(
    tenant_id: uuid.UUID,
    window_status: Optional[str] = Query(None, alias="status"),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """List maintenance windows, optionally filtered by lifecycle state.

    ``status`` may be "active", "upcoming", or "past" (relative to the
    DB clock via NOW()); any other value returns all windows. Tenant
    isolation is presumably enforced by RLS via get_db() — confirm,
    since there is no explicit tenant_id predicate. Requires operator.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)

    # WHERE clause assembled from fixed literals only — no user input is
    # interpolated, so the f-string SQL below is not injectable.
    filters = ["1=1"]
    params: dict[str, Any] = {}

    if window_status == "active":
        filters.append("mw.start_at <= NOW() AND mw.end_at >= NOW()")
    elif window_status == "upcoming":
        filters.append("mw.start_at > NOW()")
    elif window_status == "past":
        filters.append("mw.end_at < NOW()")

    where = " AND ".join(filters)

    result = await db.execute(
        text(f"""
            SELECT mw.id, mw.tenant_id, mw.name, mw.device_ids,
                   mw.start_at, mw.end_at, mw.suppress_alerts,
                   mw.notes, mw.created_by, mw.created_at
            FROM maintenance_windows mw
            WHERE {where}
            ORDER BY mw.start_at DESC
        """),
        params,
    )

    # Positional indices follow the SELECT column order above.
    return [
        {
            "id": str(row[0]),
            "tenant_id": str(row[1]),
            "name": row[2],
            # device_ids is stored as jsonb; guard against non-list values.
            "device_ids": row[3] if isinstance(row[3], list) else [],
            "start_at": row[4].isoformat() if row[4] else None,
            "end_at": row[5].isoformat() if row[5] else None,
            "suppress_alerts": row[6],
            "notes": row[7],
            "created_by": str(row[8]) if row[8] else None,
            "created_at": row[9].isoformat() if row[9] else None,
        }
        for row in result.fetchall()
    ]
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/maintenance-windows",
    summary="Create maintenance window",
    status_code=status.HTTP_201_CREATED,
)
@limiter.limit("20/minute")
async def create_maintenance_window(
    request: Request,
    tenant_id: uuid.UUID,
    body: MaintenanceWindowCreate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Create a maintenance window and return its full representation.

    Requires at least the operator role; ``end_at`` must be strictly
    after ``start_at``.

    Raises:
        HTTPException: 403 for viewers, 422 for an inverted time range.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)

    if body.end_at <= body.start_at:
        raise HTTPException(422, "end_at must be after start_at")

    window_id = str(uuid.uuid4())

    # RETURNING created_at yields the authoritative DB timestamp. The
    # previous code fabricated created_at with naive datetime.utcnow(),
    # which could disagree with the stored row (and utcnow() is
    # deprecated in modern Python).
    result = await db.execute(
        text("""
            INSERT INTO maintenance_windows
                (id, tenant_id, name, device_ids, start_at, end_at,
                 suppress_alerts, notes, created_by)
            VALUES
                (CAST(:id AS uuid), CAST(:tenant_id AS uuid),
                 :name, CAST(:device_ids AS jsonb), :start_at, :end_at,
                 :suppress_alerts, :notes, CAST(:created_by AS uuid))
            RETURNING created_at
        """),
        {
            "id": window_id,
            "tenant_id": str(tenant_id),
            "name": body.name,
            "device_ids": json.dumps(body.device_ids),
            "start_at": body.start_at,
            "end_at": body.end_at,
            "suppress_alerts": body.suppress_alerts,
            "notes": body.notes,
            "created_by": str(current_user.user_id),
        },
    )
    created_at = result.scalar()
    await db.commit()

    return {
        "id": window_id,
        "tenant_id": str(tenant_id),
        "name": body.name,
        "device_ids": body.device_ids,
        "start_at": body.start_at.isoformat(),
        "end_at": body.end_at.isoformat(),
        "suppress_alerts": body.suppress_alerts,
        "notes": body.notes,
        "created_by": str(current_user.user_id),
        "created_at": created_at.isoformat() if created_at else None,
    }
|
||||
|
||||
|
||||
@router.put(
    "/tenants/{tenant_id}/maintenance-windows/{window_id}",
    summary="Update maintenance window",
)
@limiter.limit("20/minute")
async def update_maintenance_window(
    request: Request,
    tenant_id: uuid.UUID,
    window_id: uuid.UUID,
    body: MaintenanceWindowUpdate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, Any]:
    """Partially update a maintenance window; only supplied fields change.

    Raises:
        HTTPException: 403 for viewers, 404 for an unknown window, 422
            when both timestamps are supplied and end_at <= start_at.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)

    # Cross-field validation, matching the create endpoint. Previously an
    # update could invert the window unnoticed. When only one of the two
    # timestamps is supplied we cannot validate against the stored value
    # without an extra read, so that case is left to the caller.
    if (
        body.start_at is not None
        and body.end_at is not None
        and body.end_at <= body.start_at
    ):
        raise HTTPException(422, "end_at must be after start_at")

    # Build a dynamic SET clause for partial updates. Column fragments
    # are fixed literals; all values are bound parameters, so the
    # f-string SQL below is not injectable.
    set_parts: list[str] = ["updated_at = NOW()"]
    params: dict[str, Any] = {"window_id": str(window_id)}

    if body.name is not None:
        set_parts.append("name = :name")
        params["name"] = body.name
    if body.device_ids is not None:
        set_parts.append("device_ids = CAST(:device_ids AS jsonb)")
        params["device_ids"] = json.dumps(body.device_ids)
    if body.start_at is not None:
        set_parts.append("start_at = :start_at")
        params["start_at"] = body.start_at
    if body.end_at is not None:
        set_parts.append("end_at = :end_at")
        params["end_at"] = body.end_at
    if body.suppress_alerts is not None:
        set_parts.append("suppress_alerts = :suppress_alerts")
        params["suppress_alerts"] = body.suppress_alerts
    if body.notes is not None:
        set_parts.append("notes = :notes")
        params["notes"] = body.notes

    set_clause = ", ".join(set_parts)

    result = await db.execute(
        text(f"""
            UPDATE maintenance_windows
            SET {set_clause}
            WHERE id = CAST(:window_id AS uuid)
            RETURNING id, tenant_id, name, device_ids, start_at, end_at,
                      suppress_alerts, notes, created_by, created_at
        """),
        params,
    )
    row = result.fetchone()
    if not row:
        raise HTTPException(404, "Maintenance window not found")
    await db.commit()

    # Positional indices follow the RETURNING column order above.
    return {
        "id": str(row[0]),
        "tenant_id": str(row[1]),
        "name": row[2],
        "device_ids": row[3] if isinstance(row[3], list) else [],
        "start_at": row[4].isoformat() if row[4] else None,
        "end_at": row[5].isoformat() if row[5] else None,
        "suppress_alerts": row[6],
        "notes": row[7],
        "created_by": str(row[8]) if row[8] else None,
        "created_at": row[9].isoformat() if row[9] else None,
    }
|
||||
|
||||
|
||||
@router.delete(
    "/tenants/{tenant_id}/maintenance-windows/{window_id}",
    summary="Delete maintenance window",
    status_code=status.HTTP_204_NO_CONTENT,
)
@limiter.limit("5/minute")
async def delete_maintenance_window(
    request: Request,
    tenant_id: uuid.UUID,
    window_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Delete a maintenance window; responds 204 with no body."""
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)

    # RETURNING id lets us distinguish "deleted" from "did not exist".
    delete_stmt = text(
        "DELETE FROM maintenance_windows WHERE id = CAST(:id AS uuid) RETURNING id"
    )
    result = await db.execute(delete_stmt, {"id": str(window_id)})
    deleted_row = result.fetchone()
    if deleted_row is None:
        raise HTTPException(404, "Maintenance window not found")
    await db.commit()
|
||||
414
backend/app/routers/metrics.py
Normal file
414
backend/app/routers/metrics.py
Normal file
@@ -0,0 +1,414 @@
|
||||
"""
|
||||
Metrics API endpoints for querying TimescaleDB hypertables.
|
||||
|
||||
All device-scoped routes are tenant-scoped under
|
||||
/api/tenants/{tenant_id}/devices/{device_id}/metrics/*.
|
||||
Fleet summary endpoints are under /api/tenants/{tenant_id}/fleet/summary
|
||||
and /api/fleet/summary (super_admin cross-tenant).
|
||||
|
||||
RLS is enforced via get_db() — the app_user engine applies tenant filtering
|
||||
automatically based on the SET LOCAL app.current_tenant context.
|
||||
|
||||
All endpoints require authentication (get_current_user) and enforce
|
||||
tenant access via _check_tenant_access.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
|
||||
router = APIRouter(tags=["metrics"])
|
||||
|
||||
|
||||
def _bucket_for_range(start: datetime, end: datetime) -> timedelta:
|
||||
"""
|
||||
Select an appropriate time_bucket size based on the requested time range.
|
||||
|
||||
Shorter ranges get finer granularity; longer ranges get coarser buckets
|
||||
to keep result sets manageable.
|
||||
|
||||
Returns a timedelta because asyncpg requires a Python timedelta (not a
|
||||
string interval literal) when binding the first argument of time_bucket().
|
||||
"""
|
||||
delta = end - start
|
||||
hours = delta.total_seconds() / 3600
|
||||
if hours <= 1:
|
||||
return timedelta(minutes=1)
|
||||
elif hours <= 6:
|
||||
return timedelta(minutes=5)
|
||||
elif hours <= 24:
|
||||
return timedelta(minutes=15)
|
||||
elif hours <= 168: # 7 days
|
||||
return timedelta(hours=1)
|
||||
elif hours <= 720: # 30 days
|
||||
return timedelta(hours=6)
|
||||
else:
|
||||
return timedelta(days=1)
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Verify the current user is allowed to access the given tenant.

    - super_admin can access any tenant; the session's RLS tenant
      context is re-pointed at the target tenant.
    - All other roles must match their own tenant_id or a 403 is raised.
    """
    if current_user.is_super_admin:
        # Re-point the RLS tenant context so the target tenant's rows
        # become visible to this request.
        from app.database import set_tenant_context

        await set_tenant_context(db, str(tenant_id))
    elif current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied: you do not belong to this tenant.",
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/metrics/health",
    summary="Time-bucketed health metrics (CPU, memory, disk, temperature)",
)
async def device_health_metrics(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    start: datetime = Query(..., description="Start of time range (ISO format)"),
    end: datetime = Query(..., description="End of time range (ISO format)"),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """
    Return time-bucketed CPU, memory, disk, and temperature metrics for a device.

    Bucket size adapts automatically to the requested time range
    (see _bucket_for_range). Each row contains: bucket timestamp,
    avg/max CPU, average memory/disk utilisation as whole percents,
    and average temperature. Rows are ordered oldest-first; the range
    is half-open: start inclusive, end exclusive.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    # timedelta bucket — asyncpg cannot bind a string interval literal
    # for time_bucket()'s first argument (see _bucket_for_range).
    bucket = _bucket_for_range(start, end)

    result = await db.execute(
        text("""
            SELECT
                time_bucket(:bucket, time) AS bucket,
                avg(cpu_load)::smallint AS avg_cpu,
                max(cpu_load)::smallint AS max_cpu,
                avg(CASE WHEN total_memory > 0
                    THEN round((1 - free_memory::float / total_memory) * 100)
                    ELSE NULL END)::smallint AS avg_mem_pct,
                avg(CASE WHEN total_disk > 0
                    THEN round((1 - free_disk::float / total_disk) * 100)
                    ELSE NULL END)::smallint AS avg_disk_pct,
                avg(temperature)::smallint AS avg_temp
            FROM health_metrics
            WHERE device_id = :device_id
              AND time >= :start AND time < :end
            GROUP BY bucket
            ORDER BY bucket ASC
        """),
        {"bucket": bucket, "device_id": str(device_id), "start": start, "end": end},
    )
    # mappings() yields dict-like rows keyed by the SELECT aliases.
    rows = result.mappings().all()
    return [dict(row) for row in rows]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Interface traffic metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/metrics/interfaces",
    summary="Time-bucketed interface bandwidth metrics (bps from cumulative byte deltas)",
)
async def device_interface_metrics(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    start: datetime = Query(..., description="Start of time range (ISO format)"),
    end: datetime = Query(..., description="End of time range (ISO format)"),
    interface: Optional[str] = Query(None, description="Filter to a specific interface name"),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """
    Return time-bucketed interface traffic metrics for a device.

    Bandwidth (bps) is computed from raw cumulative byte counters using
    SQL LAG() window functions — no poller-side state is required.
    Counter wraps (rx_bytes < prev_rx) are treated as NULL to avoid
    incorrect spikes.

    Returns:
        One dict per (bucket, interface): avg/max rx_bps and tx_bps.

    Raises:
        HTTPException 403: if the user may not access this tenant.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    bucket = _bucket_for_range(start, end)

    # Build interface filter clause conditionally.
    # The interface name is passed as a bind parameter — never interpolated
    # into the SQL string — so this is safe from SQL injection; only the
    # fixed literal below is ever inserted via the f-string.
    interface_filter = "AND interface = :interface" if interface else ""

    # Three stages:
    #   ordered  — raw samples plus previous sample per interface (LAG)
    #   with_bps — per-sample bps from byte deltas; wrap or dt<=0 -> NULL
    #   final    — time_bucket aggregation of avg/max bps per interface
    sql = f"""
        WITH ordered AS (
            SELECT
                time,
                interface,
                rx_bytes,
                tx_bytes,
                LAG(rx_bytes) OVER (PARTITION BY interface ORDER BY time) AS prev_rx,
                LAG(tx_bytes) OVER (PARTITION BY interface ORDER BY time) AS prev_tx,
                EXTRACT(EPOCH FROM time - LAG(time) OVER (PARTITION BY interface ORDER BY time)) AS dt
            FROM interface_metrics
            WHERE device_id = :device_id
              AND time >= :start AND time < :end
              {interface_filter}
        ),
        with_bps AS (
            SELECT
                time,
                interface,
                rx_bytes,
                tx_bytes,
                CASE WHEN rx_bytes >= prev_rx AND dt > 0
                    THEN ((rx_bytes - prev_rx) * 8 / dt)::bigint
                    ELSE NULL END AS rx_bps,
                CASE WHEN tx_bytes >= prev_tx AND dt > 0
                    THEN ((tx_bytes - prev_tx) * 8 / dt)::bigint
                    ELSE NULL END AS tx_bps
            FROM ordered
            WHERE prev_rx IS NOT NULL
        )
        SELECT
            time_bucket(:bucket, time) AS bucket,
            interface,
            avg(rx_bps)::bigint AS avg_rx_bps,
            avg(tx_bps)::bigint AS avg_tx_bps,
            max(rx_bps)::bigint AS max_rx_bps,
            max(tx_bps)::bigint AS max_tx_bps
        FROM with_bps
        WHERE rx_bps IS NOT NULL
        GROUP BY bucket, interface
        ORDER BY interface, bucket ASC
    """

    params: dict[str, Any] = {
        "bucket": bucket,
        "device_id": str(device_id),
        "start": start,
        "end": end,
    }
    # Only bind :interface when the filter clause actually references it.
    if interface:
        params["interface"] = interface

    result = await db.execute(text(sql), params)
    rows = result.mappings().all()
    return [dict(row) for row in rows]
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/metrics/interfaces/list",
    summary="List distinct interface names for a device",
)
async def device_interface_list(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[str]:
    """Return distinct interface names seen in interface_metrics for a device.

    Used by the UI to populate the interface filter dropdown.

    Raises:
        HTTPException 403: if the user may not access this tenant.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    result = await db.execute(
        text("""
            SELECT DISTINCT interface
            FROM interface_metrics
            WHERE device_id = :device_id
            ORDER BY interface
        """),
        {"device_id": str(device_id)},
    )
    rows = result.scalars().all()
    return list(rows)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wireless metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/metrics/wireless",
    summary="Time-bucketed wireless metrics (clients, signal, CCQ)",
)
async def device_wireless_metrics(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    start: datetime = Query(..., description="Start of time range (ISO format)"),
    end: datetime = Query(..., description="End of time range (ISO format)"),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """Return time-bucketed wireless metrics per interface for a device.

    Returns one dict per (bucket, interface) with avg/max client counts,
    average signal, average CCQ, and the bucket's max frequency.

    Raises:
        HTTPException 403: if the user may not access this tenant.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    # Adaptive bucket width for the requested span.
    bucket = _bucket_for_range(start, end)

    result = await db.execute(
        text("""
            SELECT
                time_bucket(:bucket, time) AS bucket,
                interface,
                avg(client_count)::smallint AS avg_clients,
                max(client_count)::smallint AS max_clients,
                avg(avg_signal)::smallint AS avg_signal,
                avg(ccq)::smallint AS avg_ccq,
                max(frequency) AS frequency
            FROM wireless_metrics
            WHERE device_id = :device_id
              AND time >= :start AND time < :end
            GROUP BY bucket, interface
            ORDER BY interface, bucket ASC
        """),
        {"bucket": bucket, "device_id": str(device_id), "start": start, "end": end},
    )
    rows = result.mappings().all()
    return [dict(row) for row in rows]
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/metrics/wireless/latest",
    summary="Latest wireless stats per interface (not time-bucketed)",
)
async def device_wireless_latest(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """Return the most recent wireless reading per interface for a device.

    Uses PostgreSQL DISTINCT ON (interface) with ORDER BY time DESC to pick
    exactly one — the newest — row per interface.

    Raises:
        HTTPException 403: if the user may not access this tenant.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    result = await db.execute(
        text("""
            SELECT DISTINCT ON (interface)
                interface, client_count, avg_signal, ccq, frequency, time
            FROM wireless_metrics
            WHERE device_id = :device_id
            ORDER BY interface, time DESC
        """),
        {"device_id": str(device_id)},
    )
    rows = result.mappings().all()
    return [dict(row) for row in rows]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sparkline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/devices/{device_id}/metrics/sparkline",
    summary="Last 12 health readings for sparkline display",
)
async def device_sparkline(
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """
    Return the last 12 CPU readings (in chronological order) for sparkline
    display in the fleet table.

    Raises:
        HTTPException 403: if the user may not access this tenant.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    # Inner query takes the 12 newest rows (time DESC LIMIT 12); the outer
    # query re-sorts them ascending so the sparkline reads left-to-right.
    result = await db.execute(
        text("""
            SELECT cpu_load, time
            FROM (
                SELECT cpu_load, time
                FROM health_metrics
                WHERE device_id = :device_id
                ORDER BY time DESC
                LIMIT 12
            ) sub
            ORDER BY time ASC
        """),
        {"device_id": str(device_id)},
    )
    rows = result.mappings().all()
    return [dict(row) for row in rows]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fleet summary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FLEET_SUMMARY_SQL = """
|
||||
SELECT
|
||||
d.id, d.hostname, d.ip_address, d.status, d.model, d.last_seen,
|
||||
d.uptime_seconds, d.last_cpu_load, d.last_memory_used_pct,
|
||||
d.latitude, d.longitude,
|
||||
d.tenant_id, t.name AS tenant_name
|
||||
FROM devices d
|
||||
JOIN tenants t ON d.tenant_id = t.id
|
||||
ORDER BY t.name, d.hostname
|
||||
"""
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/fleet/summary",
    summary="Fleet summary for a tenant (latest metrics per device)",
)
async def fleet_summary(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """
    Return fleet summary for a single tenant.

    Queries the devices table (not hypertables) for speed.
    RLS filters to only devices belonging to the tenant automatically —
    the shared SQL carries no WHERE tenant_id clause.

    Raises:
        HTTPException 403: if the user may not access this tenant.
    """
    # Also sets the RLS tenant context for super admins.
    await _check_tenant_access(current_user, tenant_id, db)

    result = await db.execute(text(_FLEET_SUMMARY_SQL))
    rows = result.mappings().all()
    return [dict(row) for row in rows]
|
||||
|
||||
|
||||
@router.get(
    "/fleet/summary",
    summary="Cross-tenant fleet summary (super_admin only)",
)
async def fleet_summary_all(
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[dict[str, Any]]:
    """
    Return fleet summary across ALL tenants.

    Requires super_admin role. The RLS policy for super_admin returns all
    rows across all tenants, so the same shared SQL query works without
    modification — this avoids the N+1 problem of fetching per-tenant
    summaries in a loop.

    Raises:
        HTTPException 403: if the caller is not a super_admin.
    """
    # Guard clause: only super admins may see every tenant's fleet.
    if current_user.role != "super_admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Super admin required",
        )

    summary_rows = (await db.execute(text(_FLEET_SUMMARY_SQL))).mappings().all()
    return [dict(summary_row) for summary_row in summary_rows]
|
||||
146
backend/app/routers/reports.py
Normal file
146
backend/app/routers/reports.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Report generation API endpoint.
|
||||
|
||||
POST /api/tenants/{tenant_id}/reports/generate
|
||||
Generates PDF or CSV reports for device inventory, metrics summary,
|
||||
alert history, and change log.
|
||||
|
||||
RLS enforced via get_db() (app_user engine with tenant context).
|
||||
RBAC: require at least operator role.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db, set_tenant_context
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
from app.services.report_service import generate_report
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
router = APIRouter(tags=["reports"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Verify the current user is allowed to access the given tenant.

    Super admins may act on any tenant; their DB session is pointed at the
    target tenant so RLS-scoped queries see that tenant's rows. Regular
    users must belong to the requested tenant.

    Raises:
        HTTPException 403: if a non-super-admin targets a different tenant.
    """
    if current_user.is_super_admin:
        # Switch the session's RLS tenant context to the requested tenant.
        await set_tenant_context(db, str(tenant_id))
    elif current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )
|
||||
|
||||
|
||||
def _require_operator(current_user: CurrentUser) -> None:
    """Reject viewer-level users: report generation needs operator or above.

    Raises:
        HTTPException 403: when the user's role is exactly "viewer".
    """
    # Any role other than "viewer" passes through unchanged.
    if current_user.role != "viewer":
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Reports require at least operator role.",
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ReportType(str, Enum):
    """Report categories the generator supports."""

    device_inventory = "device_inventory"  # current device list; no date range needed
    metrics_summary = "metrics_summary"    # requires date_from/date_to
    alert_history = "alert_history"        # requires date_from/date_to
    change_log = "change_log"              # requires date_from/date_to
|
||||
|
||||
|
||||
class ReportFormat(str, Enum):
    """Supported report output formats."""

    pdf = "pdf"
    csv = "csv"
|
||||
|
||||
|
||||
class ReportRequest(BaseModel):
    """Body for POST /tenants/{tenant_id}/reports/generate.

    Unknown fields are rejected (extra="forbid").
    """

    model_config = ConfigDict(extra="forbid")

    type: ReportType
    date_from: Optional[datetime] = None  # required unless type == device_inventory
    date_to: Optional[datetime] = None    # required unless type == device_inventory
    format: ReportFormat = ReportFormat.pdf
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/reports/generate",
    summary="Generate a report (PDF or CSV)",
    response_class=StreamingResponse,
)
async def generate_report_endpoint(
    tenant_id: uuid.UUID,
    body: ReportRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> StreamingResponse:
    """Generate and download a report as PDF or CSV.

    - device_inventory: no date range required
    - metrics_summary, alert_history, change_log: date_from and date_to required

    Raises:
        HTTPException 403: viewer role, or access denied to the tenant.
        HTTPException 422: missing or inverted date range for time-based reports.
        HTTPException 500: report generation failed.
    """
    import io

    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)

    # Validate date range for time-based reports.
    if body.type != ReportType.device_inventory:
        if not body.date_from or not body.date_to:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"date_from and date_to are required for {body.type.value} reports.",
            )
        if body.date_from > body.date_to:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="date_from must be before date_to.",
            )

    try:
        file_bytes, content_type, filename = await generate_report(
            db=db,
            tenant_id=tenant_id,
            report_type=body.type.value,
            date_from=body.date_from,
            date_to=body.date_to,
            fmt=body.format.value,
        )
    except Exception as exc:
        logger.error("report_generation_failed", error=str(exc), report_type=body.type.value)
        # Chain the cause so the original traceback survives in logs/debugging.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Report generation failed: {str(exc)}",
        ) from exc

    # BUG FIX: the filename returned by generate_report was previously
    # dropped and a hardcoded placeholder sent instead; interpolate it so
    # browsers save the download under the report's real name.
    return StreamingResponse(
        io.BytesIO(file_bytes),
        media_type=content_type,
        headers={
            "Content-Disposition": f'attachment; filename="{filename}"',
            "Content-Length": str(len(file_bytes)),
        },
    )
|
||||
155
backend/app/routers/settings.py
Normal file
155
backend/app/routers/settings.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""System settings router — global SMTP configuration.
|
||||
|
||||
Super-admin only. Stores SMTP settings in system_settings table with
|
||||
Transit encryption for passwords. Falls back to .env values.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.middleware.rbac import require_role
|
||||
from app.services.email_service import SMTPConfig, send_test_email, test_smtp_connection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/settings", tags=["settings"])
|
||||
|
||||
# system_settings keys that together form the global SMTP configuration.
# Read as a group by _get_system_settings / get_smtp_config.
SMTP_KEYS = [
    "smtp_host",
    "smtp_port",
    "smtp_user",
    "smtp_password",
    "smtp_use_tls",
    "smtp_from_address",
    "smtp_provider",
]
|
||||
|
||||
|
||||
class SMTPSettingsUpdate(BaseModel):
    """Body for PUT /settings/smtp (full SMTP configuration update)."""

    smtp_host: str
    smtp_port: int = 587
    smtp_user: Optional[str] = None
    smtp_password: Optional[str] = None  # None = leave the stored password unchanged
    smtp_use_tls: bool = False
    smtp_from_address: str = "noreply@example.com"
    smtp_provider: str = "custom"
|
||||
|
||||
|
||||
class SMTPTestRequest(BaseModel):
    """Body for POST /settings/smtp/test.

    All SMTP fields are optional overrides; any field left None falls back
    to the saved configuration (see test_smtp).
    """

    to: str  # recipient for the test email
    smtp_host: Optional[str] = None
    smtp_port: Optional[int] = None
    smtp_user: Optional[str] = None
    smtp_password: Optional[str] = None
    smtp_use_tls: Optional[bool] = None
    smtp_from_address: Optional[str] = None
|
||||
|
||||
|
||||
async def _get_system_settings(keys: list[str]) -> dict:
    """Read settings from system_settings table.

    Uses the admin session (no tenant/RLS context) since system settings
    are global, not tenant-scoped.

    Args:
        keys: Setting keys to fetch.

    Returns:
        Mapping of key -> stored value, containing only the keys that exist.
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("SELECT key, value FROM system_settings WHERE key = ANY(:keys)"),
            {"keys": keys},
        )
        return {row[0]: row[1] for row in result.fetchall()}
|
||||
|
||||
|
||||
async def _set_system_settings(updates: dict, user_id: str) -> None:
    """Upsert settings into system_settings table.

    One INSERT ... ON CONFLICT per key; all upserts share a single commit.
    Values are coerced with str(); None is stored as SQL NULL.

    NOTE(review): values — including smtp_password — are stored here as-is;
    the module docstring mentions Transit encryption for passwords, but no
    encryption is applied in this function. Confirm where (if anywhere)
    encryption happens.

    Args:
        updates: key -> value pairs to persist.
        user_id: UUID string recorded in updated_by for auditing.
    """
    async with AdminAsyncSessionLocal() as session:
        for key, value in updates.items():
            await session.execute(
                text("""
                    INSERT INTO system_settings (key, value, updated_by, updated_at)
                    VALUES (:key, :value, CAST(:user_id AS uuid), now())
                    ON CONFLICT (key) DO UPDATE
                    SET value = :value, updated_by = CAST(:user_id AS uuid), updated_at = now()
                """),
                {"key": key, "value": str(value) if value is not None else None, "user_id": user_id},
            )
        await session.commit()
|
||||
|
||||
|
||||
async def get_smtp_config() -> SMTPConfig:
    """Get SMTP config from system_settings, falling back to .env.

    Each field independently prefers the database value; empty/missing DB
    values fall through to the corresponding settings.* default via `or`.
    """
    db_settings = await _get_system_settings(SMTP_KEYS)

    return SMTPConfig(
        host=db_settings.get("smtp_host") or settings.SMTP_HOST,
        port=int(db_settings.get("smtp_port") or settings.SMTP_PORT),
        user=db_settings.get("smtp_user") or settings.SMTP_USER,
        password=db_settings.get("smtp_password") or settings.SMTP_PASSWORD,
        # Stored as a string; a DB value of "false" is truthy as a string, so
        # the DB value (not the env default) is what gets compared to "true".
        use_tls=(db_settings.get("smtp_use_tls") or str(settings.SMTP_USE_TLS)).lower() == "true",
        from_address=db_settings.get("smtp_from_address") or settings.SMTP_FROM_ADDRESS,
    )
|
||||
|
||||
|
||||
@router.get("/smtp")
|
||||
async def get_smtp_settings(user=Depends(require_role("super_admin"))):
|
||||
"""Get current global SMTP configuration. Password is redacted."""
|
||||
db_settings = await _get_system_settings(SMTP_KEYS)
|
||||
|
||||
return {
|
||||
"smtp_host": db_settings.get("smtp_host") or settings.SMTP_HOST,
|
||||
"smtp_port": int(db_settings.get("smtp_port") or settings.SMTP_PORT),
|
||||
"smtp_user": db_settings.get("smtp_user") or settings.SMTP_USER or "",
|
||||
"smtp_use_tls": (db_settings.get("smtp_use_tls") or str(settings.SMTP_USE_TLS)).lower() == "true",
|
||||
"smtp_from_address": db_settings.get("smtp_from_address") or settings.SMTP_FROM_ADDRESS,
|
||||
"smtp_provider": db_settings.get("smtp_provider") or "custom",
|
||||
"smtp_password_set": bool(db_settings.get("smtp_password") or settings.SMTP_PASSWORD),
|
||||
"source": "database" if db_settings.get("smtp_host") else "environment",
|
||||
}
|
||||
|
||||
|
||||
@router.put("/smtp")
|
||||
async def update_smtp_settings(
|
||||
data: SMTPSettingsUpdate,
|
||||
user=Depends(require_role("super_admin")),
|
||||
):
|
||||
"""Update global SMTP configuration."""
|
||||
updates = {
|
||||
"smtp_host": data.smtp_host,
|
||||
"smtp_port": str(data.smtp_port),
|
||||
"smtp_user": data.smtp_user,
|
||||
"smtp_use_tls": str(data.smtp_use_tls).lower(),
|
||||
"smtp_from_address": data.smtp_from_address,
|
||||
"smtp_provider": data.smtp_provider,
|
||||
}
|
||||
if data.smtp_password is not None:
|
||||
updates["smtp_password"] = data.smtp_password
|
||||
|
||||
await _set_system_settings(updates, str(user.id))
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/smtp/test")
|
||||
async def test_smtp(
|
||||
data: SMTPTestRequest,
|
||||
user=Depends(require_role("super_admin")),
|
||||
):
|
||||
"""Test SMTP connection and optionally send a test email."""
|
||||
# Use provided values or fall back to saved config
|
||||
saved = await get_smtp_config()
|
||||
config = SMTPConfig(
|
||||
host=data.smtp_host or saved.host,
|
||||
port=data.smtp_port if data.smtp_port is not None else saved.port,
|
||||
user=data.smtp_user if data.smtp_user is not None else saved.user,
|
||||
password=data.smtp_password if data.smtp_password is not None else saved.password,
|
||||
use_tls=data.smtp_use_tls if data.smtp_use_tls is not None else saved.use_tls,
|
||||
from_address=data.smtp_from_address or saved.from_address,
|
||||
)
|
||||
|
||||
conn_result = await test_smtp_connection(config)
|
||||
if not conn_result["success"]:
|
||||
return conn_result
|
||||
|
||||
if data.to:
|
||||
return await send_test_email(data.to, config)
|
||||
|
||||
return conn_result
|
||||
141
backend/app/routers/sse.py
Normal file
141
backend/app/routers/sse.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""SSE streaming endpoint for real-time event delivery.
|
||||
|
||||
Provides a Server-Sent Events endpoint per tenant that streams device status,
|
||||
alert, config push, and firmware progress events in real time. Authentication
|
||||
is via a short-lived, single-use exchange token (obtained from POST /auth/sse-token)
|
||||
to avoid exposing the full JWT in query parameters.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
import structlog
|
||||
from fastapi import APIRouter, HTTPException, Query, Request, status
|
||||
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
|
||||
|
||||
from app.services.sse_manager import SSEConnectionManager
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
router = APIRouter(tags=["sse"])
|
||||
|
||||
# ─── Redis for SSE token validation ───────────────────────────────────────────

# Module-level client, created lazily on first use (see _get_sse_redis).
_redis: aioredis.Redis | None = None
|
||||
|
||||
|
||||
async def _get_sse_redis() -> aioredis.Redis:
    """Lazily initialise and return the SSE Redis client.

    The client is created once and cached in the module-level `_redis`.
    settings is imported at call time — presumably to avoid an import
    cycle at module load; confirm before moving it to the top of the file.
    """
    global _redis
    if _redis is None:
        from app.config import settings
        _redis = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
    return _redis
|
||||
|
||||
|
||||
async def _validate_sse_token(token: str) -> dict:
    """Validate a short-lived SSE exchange token via Redis.

    The token is single-use: retrieved and deleted atomically with GETDEL
    (requires Redis 6.2+). If the token is not found (expired or already
    used), raises 401.

    Args:
        token: SSE exchange token string (from query param).

    Returns:
        Dict with user_id, tenant_id, and role.

    Raises:
        HTTPException 401: If the token is invalid, expired, or already used.
    """
    redis = await _get_sse_redis()
    key = f"sse_token:{token}"
    data = await redis.getdel(key)  # Single-use: delete on retrieval
    if not data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired SSE token",
        )
    return json.loads(data)
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/events/stream",
    summary="SSE event stream for real-time tenant events",
    response_class=EventSourceResponse,
)
async def event_stream(
    request: Request,
    tenant_id: uuid.UUID,
    token: str = Query(..., description="Short-lived SSE exchange token (from POST /auth/sse-token)"),
) -> EventSourceResponse:
    """Stream real-time events for a tenant via Server-Sent Events.

    Event types: device_status, alert_fired, alert_resolved, config_push,
    firmware_progress, metric_update.

    Supports Last-Event-ID header for reconnection replay.
    Sends heartbeat comments every 15 seconds on idle connections.

    Raises:
        HTTPException 401: invalid/expired/already-used exchange token.
        HTTPException 403: token's tenant does not match the requested one.
    """
    # Validate exchange token from query parameter (single-use, 30s TTL)
    user_context = await _validate_sse_token(token)
    user_role = user_context.get("role", "")
    user_tenant_id = user_context.get("tenant_id")
    user_id = user_context.get("user_id", "")

    # Authorization: user must belong to the requested tenant or be super_admin
    if user_role != "super_admin" and (user_tenant_id is None or str(user_tenant_id) != str(tenant_id)):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not authorized for this tenant",
        )

    # super_admin receives events from ALL tenants (tenant_id filter = None)
    filter_tenant_id: Optional[str] = None if user_role == "super_admin" else str(tenant_id)

    # Generate unique connection ID (used for logging and manager bookkeeping)
    connection_id = f"sse-{uuid.uuid4().hex[:12]}"

    # Check for Last-Event-ID header (reconnection replay)
    last_event_id = request.headers.get("Last-Event-ID")

    logger.info(
        "sse.stream_requested",
        connection_id=connection_id,
        tenant_id=str(tenant_id),
        user_id=user_id,
        role=user_role,
        last_event_id=last_event_id,
    )

    # NOTE(review): a fresh SSEConnectionManager is constructed per request,
    # and disconnect() below is called without the connection_id — presumably
    # the manager is a singleton (or tracks the connection internally);
    # confirm against app.services.sse_manager.
    manager = SSEConnectionManager()
    queue = await manager.connect(
        connection_id=connection_id,
        tenant_id=filter_tenant_id,
        last_event_id=last_event_id,
    )

    async def event_generator() -> AsyncGenerator[ServerSentEvent, None]:
        """Yield SSE events from the queue with 15s heartbeat on idle."""
        try:
            while True:
                try:
                    # Block up to 15s for the next published event.
                    event = await asyncio.wait_for(queue.get(), timeout=15.0)
                    yield ServerSentEvent(
                        data=event["data"],
                        event=event["event"],
                        id=event["id"],
                    )
                except asyncio.TimeoutError:
                    # Send heartbeat comment to keep connection alive
                    yield ServerSentEvent(comment="heartbeat")
                except asyncio.CancelledError:
                    # Client disconnect / shutdown: leave the loop cleanly.
                    break
        finally:
            # Always deregister the connection, even on unexpected errors.
            await manager.disconnect()
            logger.info("sse.stream_closed", connection_id=connection_id)

    return EventSourceResponse(event_generator())
|
||||
613
backend/app/routers/templates.py
Normal file
613
backend/app/routers/templates.py
Normal file
@@ -0,0 +1,613 @@
|
||||
"""
|
||||
Config template CRUD, preview, and push API endpoints.
|
||||
|
||||
All routes are tenant-scoped under:
|
||||
/api/tenants/{tenant_id}/templates/
|
||||
|
||||
Provides:
|
||||
- GET /templates -- list templates (optional tag filter)
|
||||
- POST /templates -- create a template
|
||||
- GET /templates/{id} -- get single template
|
||||
- PUT /templates/{id} -- update a template
|
||||
- DELETE /templates/{id} -- delete a template
|
||||
- POST /templates/{id}/preview -- preview rendered template for a device
|
||||
- POST /templates/{id}/push -- push template to devices (sequential rollout)
|
||||
- GET /templates/push-status/{rollout_id} -- poll push progress
|
||||
|
||||
RLS is enforced via get_db() (app_user engine with tenant context).
|
||||
RBAC: viewer = read (GET/preview); operator and above = write (POST/PUT/DELETE/push).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.middleware.rate_limit import limiter
|
||||
from app.middleware.rbac import require_min_role, require_scope
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
from app.models.config_template import ConfigTemplate, ConfigTemplateTag, TemplatePushJob
|
||||
from app.models.device import Device
|
||||
from app.services import template_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["templates"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Verify the current user is allowed to access the given tenant.

    Super admins are granted access to any tenant; the DB session's RLS
    context is switched to the target tenant first. Regular users must
    belong to the requested tenant.

    Raises:
        HTTPException 403: if a non-super-admin targets another tenant.
    """
    if current_user.is_super_admin:
        # Local import — presumably avoids an import cycle; confirm before
        # hoisting to the module top.
        from app.database import set_tenant_context
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied: you do not belong to this tenant.",
        )
|
||||
|
||||
|
||||
def _serialize_template(template: ConfigTemplate, include_content: bool = False) -> dict:
    """Serialize a ConfigTemplate to a response dict.

    Summary fields (id, name, description, tags, variable_count,
    timestamps) are always present; the raw Jinja2 content and variable
    definitions are included only when include_content is True.
    """
    variable_defs = template.variables or []
    payload: dict[str, Any] = {
        "id": str(template.id),
        "name": template.name,
        "description": template.description,
        "tags": [t.name for t in template.tags],
        "variable_count": len(variable_defs),
        "created_at": template.created_at.isoformat(),
        "updated_at": template.updated_at.isoformat(),
    }
    if include_content:
        payload["content"] = template.content
        payload["variables"] = variable_defs
    return payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request/Response schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class VariableDef(BaseModel):
    """A single template variable definition (name, type, default, docs)."""

    model_config = ConfigDict(extra="forbid")
    name: str
    type: str = "string"  # string | ip | integer | boolean | subnet
    default: Optional[str] = None
    description: Optional[str] = None
|
||||
|
||||
class TemplateCreateRequest(BaseModel):
    """Body for POST /templates: Jinja2 content plus variable definitions."""

    model_config = ConfigDict(extra="forbid")
    name: str
    description: Optional[str] = None
    content: str  # raw Jinja2 template text
    variables: list[VariableDef] = []
    tags: list[str] = []
|
||||
|
||||
|
||||
class TemplateUpdateRequest(BaseModel):
    """Body for PUT /templates/{id}: full replacement of the template."""

    model_config = ConfigDict(extra="forbid")
    name: str
    description: Optional[str] = None
    content: str  # raw Jinja2 template text
    variables: list[VariableDef] = []
    tags: list[str] = []
|
||||
|
||||
|
||||
class PreviewRequest(BaseModel):
    """Body for POST /templates/{id}/preview: target device + variable values."""

    model_config = ConfigDict(extra="forbid")
    device_id: str
    variables: dict[str, str] = {}  # variable name -> rendered value
|
||||
|
||||
|
||||
class PushRequest(BaseModel):
    """Body for POST /templates/{id}/push: target devices + variable values."""

    model_config = ConfigDict(extra="forbid")
    device_ids: list[str]
    variables: dict[str, str] = {}  # variable name -> rendered value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CRUD endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/templates",
    summary="List config templates",
    dependencies=[require_scope("config:read")],
)
async def list_templates(
    tenant_id: uuid.UUID,
    tag: Optional[str] = Query(None, description="Filter by tag name"),
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[dict]:
    """List all config templates for a tenant with optional tag filtering.

    Results are ordered newest-updated first; tags are eagerly loaded to
    avoid per-template lazy loads during serialization.

    Raises:
        HTTPException 403: if the user may not access this tenant.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    query = (
        select(ConfigTemplate)
        .options(selectinload(ConfigTemplate.tags))
        .where(ConfigTemplate.tenant_id == tenant_id)  # type: ignore[arg-type]
        .order_by(ConfigTemplate.updated_at.desc())
    )

    # Tag filter as a subquery against the tag join table, scoped to the
    # same tenant.
    if tag:
        query = query.where(
            ConfigTemplate.id.in_(  # type: ignore[attr-defined]
                select(ConfigTemplateTag.template_id).where(
                    ConfigTemplateTag.name == tag,
                    ConfigTemplateTag.tenant_id == tenant_id,  # type: ignore[arg-type]
                )
            )
        )

    result = await db.execute(query)
    templates = result.scalars().all()

    return [_serialize_template(t) for t in templates]
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/templates",
    summary="Create a config template",
    status_code=status.HTTP_201_CREATED,
    dependencies=[require_scope("config:write")],
)
@limiter.limit("20/minute")
async def create_template(
    request: Request,
    tenant_id: uuid.UUID,
    body: TemplateCreateRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Create a new config template with Jinja2 content and variable definitions.

    Variables referenced in the content but missing from ``body.variables``
    are auto-added as plain string variables (a warning is logged), so the
    stored definition covers every placeholder in the content.

    Returns the serialized template including content. Requires operator
    role and ``config:write`` scope.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    # Auto-extract variables from content for comparison
    detected = template_service.extract_variables(body.content)
    provided_names = {v.name for v in body.variables}
    unmatched = set(detected) - provided_names

    variables = [v.model_dump() for v in body.variables]
    if unmatched:
        logger.warning(
            "Template '%s' has undeclared variables: %s (auto-adding as string type)",
            body.name, unmatched,
        )
        # Bug fix: the warning previously claimed auto-adding but the
        # undeclared variables were silently dropped. Store them as declared
        # string-typed variables so validation/preview sees them.
        variables.extend(
            {"name": name, "type": "string", "default": None, "description": None}
            for name in sorted(unmatched)
        )

    # Create template
    template = ConfigTemplate(
        tenant_id=tenant_id,
        name=body.name,
        description=body.description,
        content=body.content,
        variables=variables,
    )
    db.add(template)
    await db.flush()  # Get the generated ID

    # Create tags
    for tag_name in body.tags:
        tag = ConfigTemplateTag(
            tenant_id=tenant_id,
            name=tag_name,
            template_id=template.id,
        )
        db.add(tag)

    await db.flush()

    # Re-query with tags loaded (selectinload) so serialization is lazy-load free.
    result = await db.execute(
        select(ConfigTemplate)
        .options(selectinload(ConfigTemplate.tags))
        .where(ConfigTemplate.id == template.id)  # type: ignore[arg-type]
    )
    template = result.scalar_one()

    return _serialize_template(template, include_content=True)
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/templates/{template_id}",
    summary="Get a single config template",
    dependencies=[require_scope("config:read")],
)
async def get_template(
    tenant_id: uuid.UUID,
    template_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Get a config template with full content, variables, and tags."""
    await _check_tenant_access(current_user, tenant_id, db)

    # Single tenant-scoped lookup with tags eagerly loaded.
    stmt = (
        select(ConfigTemplate)
        .options(selectinload(ConfigTemplate.tags))
        .where(
            ConfigTemplate.id == template_id,  # type: ignore[arg-type]
            ConfigTemplate.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
    )
    found = (await db.execute(stmt)).scalar_one_or_none()

    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Template {template_id} not found",
        )

    return _serialize_template(found, include_content=True)
|
||||
|
||||
|
||||
@router.put(
    "/tenants/{tenant_id}/templates/{template_id}",
    summary="Update a config template",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("20/minute")
async def update_template(
    request: Request,
    tenant_id: uuid.UUID,
    template_id: uuid.UUID,
    body: TemplateUpdateRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Update an existing config template.

    Full replacement semantics (PUT): all scalar fields are overwritten and
    the tag set is replaced wholesale. Raises 404 if the template does not
    exist within this tenant. Requires operator role and ``config:write``.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    # Tenant-scoped fetch — prevents updating another tenant's template by id.
    result = await db.execute(
        select(ConfigTemplate)
        .options(selectinload(ConfigTemplate.tags))
        .where(
            ConfigTemplate.id == template_id,  # type: ignore[arg-type]
            ConfigTemplate.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
    )
    template = result.scalar_one_or_none()

    if template is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Template {template_id} not found",
        )

    # Update fields
    template.name = body.name
    template.description = body.description
    template.content = body.content
    template.variables = [v.model_dump() for v in body.variables]

    # Replace tags: delete old, create new
    await db.execute(
        delete(ConfigTemplateTag).where(
            ConfigTemplateTag.template_id == template_id  # type: ignore[arg-type]
        )
    )
    for tag_name in body.tags:
        tag = ConfigTemplateTag(
            tenant_id=tenant_id,
            name=tag_name,
            template_id=template.id,
        )
        db.add(tag)

    await db.flush()

    # Re-query with fresh tags — the ORM relationship on `template` is stale
    # after the bulk DELETE above, so reload with selectinload.
    result = await db.execute(
        select(ConfigTemplate)
        .options(selectinload(ConfigTemplate.tags))
        .where(ConfigTemplate.id == template.id)  # type: ignore[arg-type]
    )
    template = result.scalar_one()

    return _serialize_template(template, include_content=True)
|
||||
|
||||
|
||||
@router.delete(
    "/tenants/{tenant_id}/templates/{template_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a config template",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("5/minute")
async def delete_template(
    request: Request,
    tenant_id: uuid.UUID,
    template_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Delete a config template. Tags are cascade-deleted. Push jobs are SET NULL."""
    await _check_tenant_access(current_user, tenant_id, db)

    stmt = select(ConfigTemplate).where(
        ConfigTemplate.id == template_id,  # type: ignore[arg-type]
        ConfigTemplate.tenant_id == tenant_id,  # type: ignore[arg-type]
    )
    doomed = (await db.execute(stmt)).scalar_one_or_none()

    if doomed is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Template {template_id} not found",
        )

    # ORM-level delete so cascade rules on relationships fire as configured.
    await db.delete(doomed)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preview & Push endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/templates/{template_id}/preview",
    summary="Preview template rendered for a specific device",
    dependencies=[require_scope("config:read")],
)
async def preview_template(
    tenant_id: uuid.UUID,
    template_id: uuid.UUID,
    body: PreviewRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Render a template with device context and custom variables for preview.

    Raises 404 if the template or device is not found in this tenant,
    422 on variable validation or rendering failure.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    # Load template (tenant-scoped)
    result = await db.execute(
        select(ConfigTemplate).where(
            ConfigTemplate.id == template_id,  # type: ignore[arg-type]
            ConfigTemplate.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
    )
    template = result.scalar_one_or_none()
    if template is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Template {template_id} not found",
        )

    # Load device. Fix: scope by tenant_id — previously the lookup was by id
    # only, relying solely on RLS (if active) for tenant isolation.
    result = await db.execute(
        select(Device).where(
            Device.id == body.device_id,  # type: ignore[arg-type]
            Device.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
    )
    device = result.scalar_one_or_none()
    if device is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Device {body.device_id} not found",
        )

    # Validate variables against type definitions; missing values fall back
    # to the declared default, otherwise validate_variable reports the error.
    if template.variables:
        for var_def in template.variables:
            var_name = var_def.get("name", "")
            var_type = var_def.get("type", "string")
            value = body.variables.get(var_name)
            if value is None:
                # Use default if available
                default = var_def.get("default")
                if default is not None:
                    body.variables[var_name] = default
                    continue
            error = template_service.validate_variable(var_name, value, var_type)
            if error:
                raise HTTPException(
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                    detail=error,
                )

    # Render with device context merged in; chain the cause for debuggability.
    try:
        rendered = template_service.render_template(
            template.content,
            {
                "hostname": device.hostname,
                "ip_address": device.ip_address,
                "model": device.model,
            },
            body.variables,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Template rendering failed: {exc}",
        ) from exc

    return {
        "rendered": rendered,
        "device_hostname": device.hostname,
    }
|
||||
|
||||
|
||||
@router.post(
    "/tenants/{tenant_id}/templates/{template_id}/push",
    summary="Push template to devices (sequential rollout with panic-revert)",
    dependencies=[require_scope("config:write")],
)
@limiter.limit("5/minute")
async def push_template(
    request: Request,
    tenant_id: uuid.UUID,
    template_id: uuid.UUID,
    body: PushRequest,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("operator")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Start a template push to one or more devices.

    Creates push jobs for each device and starts a background sequential rollout.
    Returns the rollout_id for status polling.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    # Load template (tenant-scoped)
    result = await db.execute(
        select(ConfigTemplate).where(
            ConfigTemplate.id == template_id,  # type: ignore[arg-type]
            ConfigTemplate.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
    )
    template = result.scalar_one_or_none()
    if template is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Template {template_id} not found",
        )

    if not body.device_ids:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="At least one device_id is required",
        )

    # Validate variables (same contract as preview): defaults fill gaps,
    # anything else must pass the declared type check.
    if template.variables:
        for var_def in template.variables:
            var_name = var_def.get("name", "")
            var_type = var_def.get("type", "string")
            value = body.variables.get(var_name)
            if value is None:
                default = var_def.get("default")
                if default is not None:
                    body.variables[var_name] = default
                    continue
            error = template_service.validate_variable(var_name, value, var_type)
            if error:
                raise HTTPException(
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                    detail=error,
                )

    rollout_id = uuid.uuid4()
    jobs_created = []

    for device_id_str in body.device_ids:
        # Load device to render template per-device. Fix: scope by tenant_id —
        # previously the lookup was by id only, so a caller could (absent RLS)
        # target another tenant's device.
        result = await db.execute(
            select(Device).where(
                Device.id == device_id_str,  # type: ignore[arg-type]
                Device.tenant_id == tenant_id,  # type: ignore[arg-type]
            )
        )
        device = result.scalar_one_or_none()
        if device is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Device {device_id_str} not found",
            )

        # Render template with this device's context
        try:
            rendered = template_service.render_template(
                template.content,
                {
                    "hostname": device.hostname,
                    "ip_address": device.ip_address,
                    "model": device.model,
                },
                body.variables,
            )
        except Exception as exc:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"Template rendering failed for device {device.hostname}: {exc}",
            ) from exc

        # Create push job
        job = TemplatePushJob(
            tenant_id=tenant_id,
            template_id=template_id,
            device_id=device.id,
            rollout_id=rollout_id,
            rendered_content=rendered,
            status="pending",
        )
        db.add(job)
        # NOTE(review): job.id is read before flush — assumes the model assigns
        # a Python-side default UUID at construction; confirm on the model.
        jobs_created.append({
            "job_id": str(job.id),
            "device_id": str(device.id),
            "device_hostname": device.hostname,
        })

    await db.flush()

    # Start background push task (fire-and-forget; status is polled separately)
    asyncio.create_task(template_service.push_to_devices(str(rollout_id)))

    return {
        "rollout_id": str(rollout_id),
        "jobs": jobs_created,
    }
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/templates/push-status/{rollout_id}",
    summary="Poll push progress for a rollout",
    dependencies=[require_scope("config:read")],
)
async def push_status(
    tenant_id: uuid.UUID,
    rollout_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Return all push job statuses for a rollout with device hostnames.

    An unknown rollout_id yields an empty ``jobs`` list rather than a 404.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    # Join jobs to devices so each row carries the hostname; tenant filter
    # keeps the rollout scoped, ordering reflects rollout sequence.
    result = await db.execute(
        select(TemplatePushJob, Device.hostname)
        .join(Device, TemplatePushJob.device_id == Device.id)  # type: ignore[arg-type]
        .where(
            TemplatePushJob.rollout_id == rollout_id,  # type: ignore[arg-type]
            TemplatePushJob.tenant_id == tenant_id,  # type: ignore[arg-type]
        )
        .order_by(TemplatePushJob.created_at.asc())
    )
    rows = result.all()

    jobs = []
    for job, hostname in rows:
        jobs.append({
            "device_id": str(job.device_id),
            "hostname": hostname,
            "status": job.status,
            "error_message": job.error_message,
            # Timestamps are null until the job starts/completes.
            "started_at": job.started_at.isoformat() if job.started_at else None,
            "completed_at": job.completed_at.isoformat() if job.completed_at else None,
        })

    return {
        "rollout_id": str(rollout_id),
        "jobs": jobs,
    }
|
||||
367
backend/app/routers/tenants.py
Normal file
367
backend/app/routers/tenants.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Tenant management endpoints.
|
||||
|
||||
GET /api/tenants — list tenants (super_admin: all; tenant_admin: own only)
|
||||
POST /api/tenants — create tenant (super_admin only)
|
||||
GET /api/tenants/{id} — get tenant detail
|
||||
PUT /api/tenants/{id} — update tenant (super_admin only)
|
||||
DELETE /api/tenants/{id} — delete tenant (super_admin only)
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.middleware.rate_limit import limiter
|
||||
|
||||
from app.database import get_admin_db, get_db
|
||||
from app.middleware.rbac import require_super_admin, require_tenant_admin_or_above
|
||||
from app.middleware.tenant_context import CurrentUser
|
||||
from app.models.device import Device
|
||||
from app.models.tenant import Tenant
|
||||
from app.models.user import User
|
||||
from app.schemas.tenant import TenantCreate, TenantResponse, TenantUpdate
|
||||
|
||||
router = APIRouter(prefix="/tenants", tags=["tenants"])
|
||||
|
||||
|
||||
async def _get_tenant_response(
    tenant: Tenant,
    db: AsyncSession,
) -> TenantResponse:
    """Build a TenantResponse, enriching the tenant row with aggregate counts."""

    async def _count_for_tenant(target_column, tenant_column) -> int:
        # COUNT(...) scoped to this tenant; coerce NULL to 0 defensively.
        res = await db.execute(
            select(func.count(target_column)).where(tenant_column == tenant.id)
        )
        return res.scalar_one() or 0

    users_total = await _count_for_tenant(User.id, User.tenant_id)
    devices_total = await _count_for_tenant(Device.id, Device.tenant_id)

    return TenantResponse(
        id=tenant.id,
        name=tenant.name,
        description=tenant.description,
        contact_email=tenant.contact_email,
        user_count=users_total,
        device_count=devices_total,
        created_at=tenant.created_at,
    )
|
||||
|
||||
|
||||
@router.get("", response_model=list[TenantResponse], summary="List tenants")
async def list_tenants(
    current_user: CurrentUser = Depends(require_tenant_admin_or_above),
    db: AsyncSession = Depends(get_admin_db),
) -> list[TenantResponse]:
    """
    List tenants.
    - super_admin: sees all tenants
    - tenant_admin: sees only their own tenant
    """
    if current_user.is_super_admin:
        result = await db.execute(select(Tenant).order_by(Tenant.name))
        tenants = result.scalars().all()
    else:
        # A tenant_admin without a tenant_id has nothing to see.
        if not current_user.tenant_id:
            return []
        result = await db.execute(
            select(Tenant).where(Tenant.id == current_user.tenant_id)
        )
        tenants = result.scalars().all()

    # NOTE: two count queries per tenant (N+1) — acceptable while tenant
    # counts are small; batch with a grouped aggregate if that changes.
    return [await _get_tenant_response(tenant, db) for tenant in tenants]
|
||||
|
||||
|
||||
@router.post("", response_model=TenantResponse, status_code=status.HTTP_201_CREATED, summary="Create a tenant")
@limiter.limit("20/minute")
async def create_tenant(
    request: Request,
    data: TenantCreate,
    current_user: CurrentUser = Depends(require_super_admin),
    db: AsyncSession = Depends(get_admin_db),
) -> TenantResponse:
    """Create a new tenant (super_admin only).

    Side effects after the tenant row is committed:
      1. Seeds default alert rules (raw SQL inserts).
      2. Seeds starter config templates.
      3. Best-effort OpenBao Transit key provisioning (failure is logged,
         not raised — keys are re-provisioned on next startup).
    Each step commits separately so a later failure does not roll back
    the tenant itself.
    """
    # Check for name uniqueness
    existing = await db.execute(select(Tenant).where(Tenant.name == data.name))
    if existing.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Tenant with name '{data.name}' already exists",
        )

    tenant = Tenant(name=data.name, description=data.description, contact_email=data.contact_email)
    db.add(tenant)
    await db.commit()
    await db.refresh(tenant)

    # Seed default alert rules for new tenant
    # (name, metric, operator, threshold, duration_polls, severity)
    default_rules = [
        ("High CPU Usage", "cpu_load", "gt", 90, 5, "warning"),
        ("High Memory Usage", "memory_used_pct", "gt", 90, 5, "warning"),
        ("High Disk Usage", "disk_used_pct", "gt", 85, 3, "warning"),
        ("Device Offline", "device_offline", "eq", 1, 1, "critical"),
    ]
    for name, metric, operator, threshold, duration, sev in default_rules:
        await db.execute(text("""
            INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default)
            VALUES (gen_random_uuid(), CAST(:tenant_id AS uuid), :name, :metric, :operator, :threshold, :duration, :severity, TRUE, TRUE)
        """), {
            "tenant_id": str(tenant.id), "name": name, "metric": metric,
            "operator": operator, "threshold": threshold, "duration": duration, "severity": sev,
        })
    await db.commit()

    # Seed starter config templates for new tenant
    await _seed_starter_templates(db, tenant.id)
    await db.commit()

    # Provision OpenBao Transit key for the new tenant (non-blocking)
    try:
        # Local imports keep OpenBao strictly optional at module import time.
        from app.config import settings
        from app.services.key_service import provision_tenant_key

        if settings.OPENBAO_ADDR:
            await provision_tenant_key(db, tenant.id)
            await db.commit()
    except Exception as exc:
        import logging
        logging.getLogger(__name__).warning(
            "OpenBao key provisioning failed for tenant %s (will be provisioned on next startup): %s",
            tenant.id,
            exc,
        )

    return await _get_tenant_response(tenant, db)
|
||||
|
||||
|
||||
@router.get("/{tenant_id}", response_model=TenantResponse, summary="Get tenant detail")
async def get_tenant(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_tenant_admin_or_above),
    db: AsyncSession = Depends(get_admin_db),
) -> TenantResponse:
    """Get tenant detail. Tenant admins can only view their own tenant."""
    # Guard: only super_admin may look outside their own tenant.
    may_view = current_user.is_super_admin or current_user.tenant_id == tenant_id
    if not may_view:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )

    row = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
    found = row.scalar_one_or_none()

    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tenant not found",
        )

    return await _get_tenant_response(found, db)
|
||||
|
||||
|
||||
@router.put("/{tenant_id}", response_model=TenantResponse, summary="Update a tenant")
@limiter.limit("20/minute")
async def update_tenant(
    request: Request,
    tenant_id: uuid.UUID,
    data: TenantUpdate,
    current_user: CurrentUser = Depends(require_super_admin),
    db: AsyncSession = Depends(get_admin_db),
) -> TenantResponse:
    """Update tenant (super_admin only).

    Partial update: only fields present (non-None) in the payload change.
    Renames are checked for uniqueness against all other tenants.
    """
    result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
    tenant = result.scalar_one_or_none()

    if not tenant:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tenant not found",
        )

    if data.name is not None:
        # Check name uniqueness (excluding this tenant itself)
        name_check = await db.execute(
            select(Tenant).where(Tenant.name == data.name, Tenant.id != tenant_id)
        )
        if name_check.scalar_one_or_none():
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=f"Tenant with name '{data.name}' already exists",
            )
        tenant.name = data.name

    if data.description is not None:
        tenant.description = data.description

    if data.contact_email is not None:
        tenant.contact_email = data.contact_email

    await db.commit()
    await db.refresh(tenant)

    return await _get_tenant_response(tenant, db)
|
||||
|
||||
|
||||
@router.delete("/{tenant_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete a tenant")
@limiter.limit("5/minute")
async def delete_tenant(
    request: Request,
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_super_admin),
    db: AsyncSession = Depends(get_admin_db),
) -> None:
    """Delete tenant (super_admin only). Cascades to all users and devices."""
    row = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
    victim = row.scalar_one_or_none()

    if victim is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tenant not found",
        )

    # ORM delete + commit so relationship cascades run as configured.
    await db.delete(victim)
    await db.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Starter template seeding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_STARTER_TEMPLATES = [
|
||||
{
|
||||
"name": "Basic Router",
|
||||
"description": "Complete SOHO/branch router setup: WAN on ether1, LAN bridge, DHCP, DNS, NAT, basic firewall",
|
||||
"content": """/interface bridge add name=bridge-lan comment="LAN bridge"
|
||||
/interface bridge port add bridge=bridge-lan interface=ether2
|
||||
/interface bridge port add bridge=bridge-lan interface=ether3
|
||||
/interface bridge port add bridge=bridge-lan interface=ether4
|
||||
/interface bridge port add bridge=bridge-lan interface=ether5
|
||||
|
||||
# WAN — DHCP client on ether1
|
||||
/ip dhcp-client add interface={{ wan_interface }} disabled=no comment="WAN uplink"
|
||||
|
||||
# LAN address
|
||||
/ip address add address={{ lan_gateway }}/{{ lan_cidr }} interface=bridge-lan
|
||||
|
||||
# DNS
|
||||
/ip dns set servers={{ dns_servers }} allow-remote-requests=yes
|
||||
|
||||
# DHCP server for LAN
|
||||
/ip pool add name=lan-pool ranges={{ dhcp_start }}-{{ dhcp_end }}
|
||||
/ip dhcp-server network add address={{ lan_network }}/{{ lan_cidr }} gateway={{ lan_gateway }} dns-server={{ lan_gateway }}
|
||||
/ip dhcp-server add name=lan-dhcp interface=bridge-lan address-pool=lan-pool disabled=no
|
||||
|
||||
# NAT masquerade
|
||||
/ip firewall nat add chain=srcnat out-interface={{ wan_interface }} action=masquerade
|
||||
|
||||
# Firewall — input chain
|
||||
/ip firewall filter
|
||||
add chain=input connection-state=established,related action=accept
|
||||
add chain=input connection-state=invalid action=drop
|
||||
add chain=input in-interface={{ wan_interface }} action=drop comment="Drop all other WAN input"
|
||||
|
||||
# Firewall — forward chain
|
||||
add chain=forward connection-state=established,related action=accept
|
||||
add chain=forward connection-state=invalid action=drop
|
||||
add chain=forward in-interface=bridge-lan out-interface={{ wan_interface }} action=accept comment="Allow LAN to WAN"
|
||||
add chain=forward action=drop comment="Drop everything else"
|
||||
|
||||
# NTP
|
||||
/system ntp client set enabled=yes servers={{ ntp_server }}
|
||||
|
||||
# Identity
|
||||
/system identity set name={{ device.hostname }}""",
|
||||
"variables": [
|
||||
{"name": "wan_interface", "type": "string", "default": "ether1", "description": "WAN-facing interface"},
|
||||
{"name": "lan_gateway", "type": "ip", "default": "192.168.88.1", "description": "LAN gateway IP"},
|
||||
{"name": "lan_cidr", "type": "integer", "default": "24", "description": "LAN subnet mask bits"},
|
||||
{"name": "lan_network", "type": "ip", "default": "192.168.88.0", "description": "LAN network address"},
|
||||
{"name": "dhcp_start", "type": "ip", "default": "192.168.88.100", "description": "DHCP pool start"},
|
||||
{"name": "dhcp_end", "type": "ip", "default": "192.168.88.254", "description": "DHCP pool end"},
|
||||
{"name": "dns_servers", "type": "string", "default": "8.8.8.8,8.8.4.4", "description": "Upstream DNS servers"},
|
||||
{"name": "ntp_server", "type": "string", "default": "pool.ntp.org", "description": "NTP server"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Basic Firewall",
|
||||
"description": "Standard firewall ruleset with WAN protection and LAN forwarding",
|
||||
"content": """/ip firewall filter
|
||||
add chain=input connection-state=established,related action=accept
|
||||
add chain=input connection-state=invalid action=drop
|
||||
add chain=input in-interface={{ wan_interface }} protocol=tcp dst-port=8291 action=drop comment="Block Winbox from WAN"
|
||||
add chain=input in-interface={{ wan_interface }} protocol=tcp dst-port=22 action=drop comment="Block SSH from WAN"
|
||||
add chain=forward connection-state=established,related action=accept
|
||||
add chain=forward connection-state=invalid action=drop
|
||||
add chain=forward src-address={{ allowed_network }} action=accept
|
||||
add chain=forward action=drop""",
|
||||
"variables": [
|
||||
{"name": "wan_interface", "type": "string", "default": "ether1", "description": "WAN-facing interface"},
|
||||
{"name": "allowed_network", "type": "subnet", "default": "192.168.88.0/24", "description": "Allowed source network"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "DHCP Server Setup",
|
||||
"description": "Configure DHCP server with address pool, DNS, and gateway",
|
||||
"content": """/ip pool add name=dhcp-pool ranges={{ pool_start }}-{{ pool_end }}
|
||||
/ip dhcp-server network add address={{ gateway }}/24 gateway={{ gateway }} dns-server={{ dns_server }}
|
||||
/ip dhcp-server add name=dhcp1 interface={{ interface }} address-pool=dhcp-pool disabled=no""",
|
||||
"variables": [
|
||||
{"name": "pool_start", "type": "ip", "default": "192.168.88.100", "description": "DHCP pool start address"},
|
||||
{"name": "pool_end", "type": "ip", "default": "192.168.88.254", "description": "DHCP pool end address"},
|
||||
{"name": "gateway", "type": "ip", "default": "192.168.88.1", "description": "Default gateway"},
|
||||
{"name": "dns_server", "type": "ip", "default": "8.8.8.8", "description": "DNS server address"},
|
||||
{"name": "interface", "type": "string", "default": "bridge-lan", "description": "Interface to serve DHCP on"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Wireless AP Config",
|
||||
"description": "Configure wireless access point with WPA2 security",
|
||||
"content": """/interface wireless security-profiles add name=portal-wpa2 mode=dynamic-keys authentication-types=wpa2-psk wpa2-pre-shared-key={{ password }}
|
||||
/interface wireless set wlan1 mode=ap-bridge ssid={{ ssid }} security-profile=portal-wpa2 frequency={{ frequency }} channel-width={{ channel_width }} disabled=no""",
|
||||
"variables": [
|
||||
{"name": "ssid", "type": "string", "default": "MikroTik-AP", "description": "Wireless network name"},
|
||||
{"name": "password", "type": "string", "default": "", "description": "WPA2 pre-shared key (min 8 characters)"},
|
||||
{"name": "frequency", "type": "integer", "default": "2412", "description": "Wireless frequency in MHz"},
|
||||
{"name": "channel_width", "type": "string", "default": "20/40mhz-XX", "description": "Channel width setting"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Initial Device Setup",
|
||||
"description": "Set device identity, NTP, DNS, and disable unused services",
|
||||
"content": """/system identity set name={{ device.hostname }}
|
||||
/system ntp client set enabled=yes servers={{ ntp_server }}
|
||||
/ip dns set servers={{ dns_servers }} allow-remote-requests=no
|
||||
/ip service disable telnet,ftp,www,api-ssl
|
||||
/ip service set ssh port=22
|
||||
/ip service set winbox port=8291""",
|
||||
"variables": [
|
||||
{"name": "ntp_server", "type": "ip", "default": "pool.ntp.org", "description": "NTP server address"},
|
||||
{"name": "dns_servers", "type": "string", "default": "8.8.8.8,8.8.4.4", "description": "Comma-separated DNS servers"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def _seed_starter_templates(db, tenant_id) -> None:
    """Insert the built-in starter config templates for a newly created tenant.

    Executes one parameterized INSERT per entry in ``_STARTER_TEMPLATES``;
    the caller owns the transaction/commit.
    """
    import json as _json

    # Statement is identical for every template, so build it once.
    insert_stmt = text("""
    INSERT INTO config_templates (id, tenant_id, name, description, content, variables)
    VALUES (gen_random_uuid(), CAST(:tid AS uuid), :name, :desc, :content, CAST(:vars AS jsonb))
    """)

    for template in _STARTER_TEMPLATES:
        await db.execute(
            insert_stmt,
            {
                "tid": str(tenant_id),
                "name": template["name"],
                "desc": template["description"],
                "content": template["content"],
                # variables is a list of dicts; serialize for the jsonb cast
                "vars": _json.dumps(template["variables"]),
            },
        )
|
||||
374
backend/app/routers/topology.py
Normal file
374
backend/app/routers/topology.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""
|
||||
Network topology inference endpoint.
|
||||
|
||||
Endpoint: GET /api/tenants/{tenant_id}/topology
|
||||
|
||||
Builds a topology graph of managed devices by:
|
||||
1. Querying all devices for the tenant (via RLS)
|
||||
2. Fetching /ip/neighbor tables from online devices via NATS
|
||||
3. Matching neighbor addresses to known devices
|
||||
4. Falling back to shared /24 subnet inference when neighbor data is unavailable
|
||||
5. Caching results in Redis with 5-minute TTL
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db, set_tenant_context
|
||||
from app.middleware.rbac import require_min_role
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
from app.models.device import Device
|
||||
from app.models.vpn import VpnPeer
|
||||
from app.services import routeros_proxy
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
router = APIRouter(tags=["topology"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Redis connection (lazy initialized, same pattern as routeros_proxy NATS)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_redis: aioredis.Redis | None = None
|
||||
TOPOLOGY_CACHE_TTL = 300 # 5 minutes
|
||||
|
||||
|
||||
async def _get_redis() -> aioredis.Redis:
    """Get or create the shared Redis connection used for topology caching."""
    global _redis
    if _redis is not None:
        return _redis
    # First use: connect lazily so importing the module never touches Redis.
    _redis = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
    logger.info("Topology Redis connection established")
    return _redis
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TopologyNode(BaseModel):
    """A managed device rendered as one node of the topology graph."""

    id: str  # device UUID as a string
    hostname: str
    ip: str  # device IP address (also used to match neighbor entries)
    status: str  # e.g. "online" — drives neighbor polling; TODO confirm full value set
    model: str | None  # hardware model, when known
    uptime: str | None  # human-readable uptime from _format_uptime, e.g. "2d 3h 4m"
|
||||
|
||||
|
||||
class TopologyEdge(BaseModel):
    """An undirected link between two topology nodes (source/target are node ids)."""

    source: str
    target: str
    label: str  # interface name, "vpn tunnel", or "shared subnet"
|
||||
|
||||
|
||||
class TopologyResponse(BaseModel):
    """Full topology graph returned by GET /tenants/{tenant_id}/topology."""

    nodes: list[TopologyNode]
    edges: list[TopologyEdge]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Verify the current user is allowed to access the given tenant.

    Super admins may view any tenant (RLS context is pinned to the target
    tenant); everyone else must belong to the tenant or gets a 403.
    """
    if current_user.is_super_admin:
        await set_tenant_context(db, str(tenant_id))
    elif current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied: you do not belong to this tenant.",
        )
|
||||
|
||||
|
||||
def _format_uptime(seconds: int | None) -> str | None:
|
||||
"""Convert uptime seconds to a human-readable string."""
|
||||
if seconds is None:
|
||||
return None
|
||||
days = seconds // 86400
|
||||
hours = (seconds % 86400) // 3600
|
||||
minutes = (seconds % 3600) // 60
|
||||
if days > 0:
|
||||
return f"{days}d {hours}h {minutes}m"
|
||||
if hours > 0:
|
||||
return f"{hours}h {minutes}m"
|
||||
return f"{minutes}m"
|
||||
|
||||
|
||||
def _get_subnet_key(ip_str: str) -> str | None:
|
||||
"""Return the /24 network key for an IPv4 address, or None if invalid."""
|
||||
try:
|
||||
addr = ipaddress.ip_address(ip_str)
|
||||
if isinstance(addr, ipaddress.IPv4Address):
|
||||
network = ipaddress.ip_network(f"{ip_str}/24", strict=False)
|
||||
return str(network)
|
||||
except ValueError:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _build_edges_from_neighbors(
    neighbor_data: dict[str, list[dict[str, Any]]],
    ip_to_device: dict[str, str],
) -> list[TopologyEdge]:
    """Build topology edges from neighbor discovery results.

    Args:
        neighbor_data: Mapping of device_id -> list of neighbor entries.
        ip_to_device: Mapping of IP address -> device_id for known devices.

    Returns:
        De-duplicated list of topology edges (A->B and B->A collapse to one).
    """
    edges: list[TopologyEdge] = []
    seen: set[tuple[str, str]] = set()

    for src_id, entries in neighbor_data.items():
        for entry in entries:
            # RouterOS neighbor entries expose 'address' (or 'address4').
            peer_ip = entry.get("address") or entry.get("address4", "")
            if not peer_ip:
                continue

            dst_id = ip_to_device.get(peer_ip)
            # Skip unknown neighbors and self-references.
            if dst_id is None or dst_id == src_id:
                continue

            # Canonical sorted pair de-duplicates bidirectional links.
            pair = tuple(sorted((src_id, dst_id)))
            if pair in seen:
                continue
            seen.add(pair)

            edges.append(
                TopologyEdge(
                    source=src_id,
                    target=dst_id,
                    label=entry.get("interface", "neighbor"),
                )
            )

    return edges
|
||||
|
||||
|
||||
def _build_edges_from_subnets(
    devices: list[Device],
    existing_connected: set[tuple[str, str]],
) -> list[TopologyEdge]:
    """Infer edges from shared /24 subnets for devices without neighbor data.

    Only adds subnet-based edges for device pairs that are NOT already
    connected (neighbor or VPN edges); ``existing_connected`` is mutated
    so later callers see the new pairs too.
    """
    # Bucket device ids by their /24 network.
    by_subnet: dict[str, list[str]] = {}
    for dev in devices:
        key = _get_subnet_key(dev.ip_address)
        if key:
            by_subnet.setdefault(key, []).append(str(dev.id))

    inferred: list[TopologyEdge] = []
    for members in by_subnet.values():
        if len(members) < 2:
            continue
        # Fully connect every pair within the subnet.
        for i, left in enumerate(members):
            for right in members[i + 1:]:
                pair = tuple(sorted((left, right)))
                if pair in existing_connected:
                    continue
                inferred.append(
                    TopologyEdge(source=left, target=right, label="shared subnet")
                )
                existing_connected.add(pair)

    return inferred
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/topology",
    response_model=TopologyResponse,
    summary="Get network topology for a tenant",
)
async def get_topology(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    _role: CurrentUser = Depends(require_min_role("viewer")),
    db: AsyncSession = Depends(get_db),
) -> TopologyResponse:
    """Build and return a network topology graph for the given tenant.

    The topology is inferred from:
    1. LLDP/CDP/MNDP neighbor discovery on online devices
    2. WireGuard VPN peer hub-spoke inference
    3. Shared /24 subnet fallback for devices without neighbor data

    Results are cached in Redis with a 5-minute TTL. Redis failures are
    non-fatal: the graph is computed fresh and the cache write is skipped.

    Raises:
        HTTPException: 403 if the user does not belong to the tenant.
    """
    await _check_tenant_access(current_user, tenant_id, db)

    cache_key = f"topology:{tenant_id}"

    # Cache hit: return the stored graph without touching the DB or devices.
    try:
        rd = await _get_redis()
        cached = await rd.get(cache_key)
        if cached:
            data = json.loads(cached)
            return TopologyResponse(**data)
    except Exception as exc:
        logger.warning("Redis cache read failed, computing topology fresh", error=str(exc))

    # Fetch all devices for tenant (RLS enforced via get_db)
    result = await db.execute(
        select(
            Device.id,
            Device.hostname,
            Device.ip_address,
            Device.status,
            Device.model,
            Device.uptime_seconds,
        )
    )
    rows = result.all()

    if not rows:
        return TopologyResponse(nodes=[], edges=[])

    # Build nodes plus the lookup tables the edge builders need.
    # (Removed an unused `devices_by_id` local that was never populated.)
    nodes: list[TopologyNode] = []
    ip_to_device: dict[str, str] = {}
    online_device_ids: list[str] = []

    for row in rows:
        device_id = str(row.id)
        nodes.append(
            TopologyNode(
                id=device_id,
                hostname=row.hostname,
                ip=row.ip_address,
                status=row.status,
                model=row.model,
                uptime=_format_uptime(row.uptime_seconds),
            )
        )
        ip_to_device[row.ip_address] = device_id
        if row.status == "online":
            online_device_ids.append(device_id)

    # Fetch neighbor tables from online devices in parallel via the NATS proxy.
    neighbor_data: dict[str, list[dict[str, Any]]] = {}

    if online_device_ids:
        tasks = [
            routeros_proxy.execute_command(
                device_id, "/ip/neighbor/print", timeout=10.0
            )
            for device_id in online_device_ids
        ]
        # return_exceptions=True so one unreachable device doesn't fail the graph.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        for device_id, res in zip(online_device_ids, results):
            if isinstance(res, Exception):
                logger.warning(
                    "Neighbor fetch failed",
                    device_id=device_id,
                    error=str(res),
                )
                continue
            if isinstance(res, dict) and res.get("success") and res.get("data"):
                neighbor_data[device_id] = res["data"]

    # Build edges from neighbor discovery
    neighbor_edges = _build_edges_from_neighbors(neighbor_data, ip_to_device)

    # Track connected pairs so VPN/subnet inference never duplicates an edge.
    connected_pairs: set[tuple[str, str]] = set()
    for edge in neighbor_edges:
        connected_pairs.add(tuple(sorted([edge.source, edge.target])))

    # VPN-based edges: query WireGuard peers to infer hub-spoke topology.
    # VPN peers all connect to the same WireGuard server. The gateway device
    # is the managed device NOT in the VPN peers list (it's the server, not a
    # client). If found, create star edges from gateway to each VPN peer device.
    vpn_edges: list[TopologyEdge] = []
    vpn_peer_device_ids: set[str] = set()
    try:
        # NOTE(review): relies on RLS to scope VpnPeer rows to this tenant —
        # there is no explicit tenant filter here; confirm against the model.
        peer_result = await db.execute(
            select(VpnPeer.device_id).where(VpnPeer.is_enabled.is_(True))
        )
        vpn_peer_device_ids = {str(row[0]) for row in peer_result.all()}

        if vpn_peer_device_ids:
            # Gateway = managed devices NOT in VPN peers (typically the Core router)
            all_device_ids = {str(row.id) for row in rows}
            gateway_ids = all_device_ids - vpn_peer_device_ids
            # Prefer an online gateway; otherwise take any candidate.
            gateway_id = None
            for gid in gateway_ids:
                if gid in online_device_ids:
                    gateway_id = gid
                    break
            if not gateway_id and gateway_ids:
                gateway_id = next(iter(gateway_ids))

            if gateway_id:
                for peer_device_id in vpn_peer_device_ids:
                    edge_key = tuple(sorted([gateway_id, peer_device_id]))
                    if edge_key not in connected_pairs:
                        vpn_edges.append(
                            TopologyEdge(
                                source=gateway_id,
                                target=peer_device_id,
                                label="vpn tunnel",
                            )
                        )
                        connected_pairs.add(edge_key)
    except Exception as exc:
        logger.warning("VPN edge detection failed", error=str(exc))

    # Fallback: infer connections from shared /24 subnets.
    # Query full Device objects for subnet analysis.
    device_result = await db.execute(select(Device))
    all_devices = list(device_result.scalars().all())
    subnet_edges = _build_edges_from_subnets(all_devices, connected_pairs)

    all_edges = neighbor_edges + vpn_edges + subnet_edges

    topology = TopologyResponse(nodes=nodes, edges=all_edges)

    # Best-effort cache write; failures only cost a recompute next time.
    try:
        rd = await _get_redis()
        await rd.set(cache_key, topology.model_dump_json(), ex=TOPOLOGY_CACHE_TTL)
    except Exception as exc:
        logger.warning("Redis cache write failed", error=str(exc))

    return topology
|
||||
391
backend/app/routers/transparency.py
Normal file
391
backend/app/routers/transparency.py
Normal file
@@ -0,0 +1,391 @@
|
||||
"""Transparency log API endpoints.
|
||||
|
||||
Tenant-scoped routes under /api/tenants/{tenant_id}/ for:
|
||||
- Paginated, filterable key access transparency log listing
|
||||
- Transparency log statistics (total events, last 24h, unique devices, justification breakdown)
|
||||
- CSV export of transparency logs
|
||||
|
||||
RLS enforced via get_db() (app_user engine with tenant context).
|
||||
RBAC: admin and above can view transparency logs (tenant_admin or super_admin).
|
||||
|
||||
Phase 31: Data Access Transparency Dashboard - TRUST-01, TRUST-02
|
||||
Shows tenant admins every KMS credential access event for their tenant.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, func, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db, set_tenant_context
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["transparency"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Verify the current user is allowed to access the given tenant.

    Super admins get RLS pinned to the requested tenant; other users must
    belong to it or receive a 403.
    """
    if current_user.is_super_admin:
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )
|
||||
|
||||
|
||||
def _require_admin(current_user: CurrentUser) -> None:
    """Raise 403 if user does not have at least admin role.

    Transparency data is sensitive operational intelligence --
    only tenant_admin and super_admin can view it.
    """
    if current_user.role in ("super_admin", "admin", "tenant_admin"):
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="At least admin role required to view transparency logs.",
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TransparencyLogItem(BaseModel):
    """One key_access_log event, joined with device hostname and operator email."""

    id: str
    action: str  # action recorded in key_access_log — value set not visible here
    device_name: Optional[str] = None  # devices.hostname via LEFT JOIN
    device_id: Optional[str] = None
    justification: Optional[str] = None
    operator_email: Optional[str] = None  # users.email via LEFT JOIN
    correlation_id: Optional[str] = None
    resource_type: Optional[str] = None
    resource_id: Optional[str] = None
    ip_address: Optional[str] = None
    created_at: str  # ISO-8601 timestamp; "" when the column is NULL
|
||||
|
||||
|
||||
class TransparencyLogResponse(BaseModel):
    """Paginated page of transparency log items."""

    items: list[TransparencyLogItem]
    total: int  # total matching rows across all pages
    page: int  # 1-based page number echoed back
    per_page: int
|
||||
|
||||
|
||||
class TransparencyStats(BaseModel):
    """Aggregate statistics over a tenant's key_access_log."""

    total_events: int
    events_last_24h: int
    unique_devices: int  # distinct non-NULL device_id values
    justification_breakdown: dict[str, int]  # justification label -> event count ('system' for NULL)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/transparency-logs",
    response_model=TransparencyLogResponse,
    summary="List KMS credential access events for tenant",
)
async def list_transparency_logs(
    tenant_id: uuid.UUID,
    page: int = Query(default=1, ge=1),
    per_page: int = Query(default=50, ge=1, le=100),
    device_id: Optional[uuid.UUID] = Query(default=None),
    justification: Optional[str] = Query(default=None),
    action: Optional[str] = Query(default=None),
    date_from: Optional[datetime] = Query(default=None),
    date_to: Optional[datetime] = Query(default=None),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> TransparencyLogResponse:
    """Return a paginated, filterable page of key-access transparency events.

    Admin-only (tenant_admin/super_admin). Filters are optional and ANDed
    together; results are ordered newest-first. Raises 403 on insufficient
    role or wrong tenant.
    """
    _require_admin(current_user)
    await _check_tenant_access(current_user, tenant_id, db)

    # Build filter conditions using parameterized text fragments
    # (bind parameters, never string interpolation, to avoid SQL injection).
    conditions = [text("k.tenant_id = :tenant_id")]
    params: dict[str, Any] = {"tenant_id": str(tenant_id)}

    if device_id:
        conditions.append(text("k.device_id = :device_id"))
        params["device_id"] = str(device_id)

    if justification:
        conditions.append(text("k.justification = :justification"))
        params["justification"] = justification

    if action:
        conditions.append(text("k.action = :action"))
        params["action"] = action

    if date_from:
        conditions.append(text("k.created_at >= :date_from"))
        params["date_from"] = date_from.isoformat()

    if date_to:
        conditions.append(text("k.created_at <= :date_to"))
        params["date_to"] = date_to.isoformat()

    where_clause = and_(*conditions)

    # Shared SELECT columns for data queries
    _data_columns = text(
        "k.id, k.action, d.hostname AS device_name, "
        "k.device_id, k.justification, u.email AS operator_email, "
        "k.correlation_id, k.resource_type, k.resource_id, "
        "k.ip_address, k.created_at"
    )
    # LEFT JOINs so events survive deleted users/devices.
    _data_from = text(
        "key_access_log k "
        "LEFT JOIN users u ON k.user_id = u.id "
        "LEFT JOIN devices d ON k.device_id = d.id"
    )

    # Count total — no joins needed since all filters reference k.* columns.
    count_result = await db.execute(
        select(func.count())
        .select_from(text("key_access_log k"))
        .where(where_clause),
        params,
    )
    total = count_result.scalar() or 0

    # Paginated query
    offset = (page - 1) * per_page
    params["limit"] = per_page
    params["offset"] = offset

    result = await db.execute(
        select(_data_columns)
        .select_from(_data_from)
        .where(where_clause)
        .order_by(text("k.created_at DESC"))
        .limit(per_page)
        .offset(offset),
        params,
    )
    rows = result.mappings().all()

    items = [
        TransparencyLogItem(
            id=str(row["id"]),
            action=row["action"],
            device_name=row["device_name"],
            device_id=str(row["device_id"]) if row["device_id"] else None,
            justification=row["justification"],
            operator_email=row["operator_email"],
            correlation_id=row["correlation_id"],
            resource_type=row["resource_type"],
            resource_id=row["resource_id"],
            ip_address=row["ip_address"],
            created_at=row["created_at"].isoformat() if row["created_at"] else "",
        )
        for row in rows
    ]

    return TransparencyLogResponse(
        items=items,
        total=total,
        page=page,
        per_page=per_page,
    )
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/transparency-logs/stats",
    response_model=TransparencyStats,
    summary="Get transparency log statistics",
)
async def get_transparency_stats(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> TransparencyStats:
    """Return aggregate key-access statistics for the tenant.

    Admin-only. Runs four independent aggregate queries over
    key_access_log: total events, events in the last 24h, distinct
    devices, and a per-justification breakdown.
    """
    _require_admin(current_user)
    await _check_tenant_access(current_user, tenant_id, db)

    params: dict[str, Any] = {"tenant_id": str(tenant_id)}

    # Total events
    total_result = await db.execute(
        select(func.count())
        .select_from(text("key_access_log"))
        .where(text("tenant_id = :tenant_id")),
        params,
    )
    total_events = total_result.scalar() or 0

    # Events in last 24 hours (window computed DB-side with NOW()).
    last_24h_result = await db.execute(
        select(func.count())
        .select_from(text("key_access_log"))
        .where(
            and_(
                text("tenant_id = :tenant_id"),
                text("created_at >= NOW() - INTERVAL '24 hours'"),
            )
        ),
        params,
    )
    events_last_24h = last_24h_result.scalar() or 0

    # Unique devices — events without a device (NULL device_id) are excluded.
    unique_devices_result = await db.execute(
        select(func.count(text("DISTINCT device_id")))
        .select_from(text("key_access_log"))
        .where(
            and_(
                text("tenant_id = :tenant_id"),
                text("device_id IS NOT NULL"),
            )
        ),
        params,
    )
    unique_devices = unique_devices_result.scalar() or 0

    # Justification breakdown — NULL justification is bucketed as 'system'.
    breakdown_result = await db.execute(
        select(
            text("COALESCE(justification, 'system') AS justification_label"),
            func.count().label("count"),
        )
        .select_from(text("key_access_log"))
        .where(text("tenant_id = :tenant_id"))
        .group_by(text("justification_label")),
        params,
    )
    justification_breakdown: dict[str, int] = {}
    for row in breakdown_result.mappings().all():
        justification_breakdown[row["justification_label"]] = row["count"]

    return TransparencyStats(
        total_events=total_events,
        events_last_24h=events_last_24h,
        unique_devices=unique_devices,
        justification_breakdown=justification_breakdown,
    )
|
||||
|
||||
|
||||
@router.get(
    "/tenants/{tenant_id}/transparency-logs/export",
    summary="Export transparency logs as CSV",
)
async def export_transparency_logs(
    tenant_id: uuid.UUID,
    device_id: Optional[uuid.UUID] = Query(default=None),
    justification: Optional[str] = Query(default=None),
    action: Optional[str] = Query(default=None),
    date_from: Optional[datetime] = Query(default=None),
    date_to: Optional[datetime] = Query(default=None),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> StreamingResponse:
    """Export transparency logs matching the given filters as a CSV download.

    Admin-only. Accepts the same filters as the list endpoint but no
    pagination: ALL matching rows are fetched into memory and written to
    one in-memory CSV. NOTE(review): filter/query construction duplicates
    list_transparency_logs — candidate for a shared helper.
    """
    _require_admin(current_user)
    await _check_tenant_access(current_user, tenant_id, db)

    # Build filter conditions (parameterized fragments, ANDed together).
    conditions = [text("k.tenant_id = :tenant_id")]
    params: dict[str, Any] = {"tenant_id": str(tenant_id)}

    if device_id:
        conditions.append(text("k.device_id = :device_id"))
        params["device_id"] = str(device_id)

    if justification:
        conditions.append(text("k.justification = :justification"))
        params["justification"] = justification

    if action:
        conditions.append(text("k.action = :action"))
        params["action"] = action

    if date_from:
        conditions.append(text("k.created_at >= :date_from"))
        params["date_from"] = date_from.isoformat()

    if date_to:
        conditions.append(text("k.created_at <= :date_to"))
        params["date_to"] = date_to.isoformat()

    where_clause = and_(*conditions)

    _data_columns = text(
        "k.id, k.action, d.hostname AS device_name, "
        "k.device_id, k.justification, u.email AS operator_email, "
        "k.correlation_id, k.resource_type, k.resource_id, "
        "k.ip_address, k.created_at"
    )
    # LEFT JOINs so events survive deleted users/devices.
    _data_from = text(
        "key_access_log k "
        "LEFT JOIN users u ON k.user_id = u.id "
        "LEFT JOIN devices d ON k.device_id = d.id"
    )

    result = await db.execute(
        select(_data_columns)
        .select_from(_data_from)
        .where(where_clause)
        .order_by(text("k.created_at DESC")),
        params,
    )
    all_rows = result.mappings().all()

    # Write CSV into memory; csv.writer handles quoting/escaping.
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow([
        "ID",
        "Action",
        "Device Name",
        "Device ID",
        "Justification",
        "Operator Email",
        "Correlation ID",
        "Resource Type",
        "Resource ID",
        "IP Address",
        "Timestamp",
    ])
    for row in all_rows:
        # NULLs become empty cells so every row has the same column count.
        writer.writerow([
            str(row["id"]),
            row["action"],
            row["device_name"] or "",
            str(row["device_id"]) if row["device_id"] else "",
            row["justification"] or "",
            row["operator_email"] or "",
            row["correlation_id"] or "",
            row["resource_type"] or "",
            row["resource_id"] or "",
            row["ip_address"] or "",
            str(row["created_at"]),
        ])

    output.seek(0)
    # Single-chunk "stream": the whole CSV is already materialized in memory.
    return StreamingResponse(
        iter([output.getvalue()]),
        media_type="text/csv",
        headers={
            "Content-Disposition": "attachment; filename=transparency-logs.csv"
        },
    )
|
||||
231
backend/app/routers/users.py
Normal file
231
backend/app/routers/users.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
User management endpoints (scoped to tenant).
|
||||
|
||||
GET /api/tenants/{tenant_id}/users — list users in tenant
|
||||
POST /api/tenants/{tenant_id}/users — create user in tenant
|
||||
GET /api/tenants/{tenant_id}/users/{id} — get user detail
|
||||
PUT /api/tenants/{tenant_id}/users/{id} — update user
|
||||
DELETE /api/tenants/{tenant_id}/users/{id} — deactivate user
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.middleware.rate_limit import limiter
|
||||
|
||||
from app.database import get_admin_db
|
||||
from app.middleware.rbac import require_tenant_admin_or_above
|
||||
from app.middleware.tenant_context import CurrentUser
|
||||
from app.models.tenant import Tenant
|
||||
from app.models.user import User, UserRole
|
||||
from app.schemas.user import UserCreate, UserResponse, UserUpdate
|
||||
from app.services.auth import hash_password
|
||||
|
||||
router = APIRouter(prefix="/tenants", tags=["users"])
|
||||
|
||||
|
||||
async def _check_tenant_access(
    tenant_id: uuid.UUID,
    current_user: CurrentUser,
    db: AsyncSession,
) -> Tenant:
    """Verify the tenant exists and the current user has access to it.

    super_admin can access any tenant.
    tenant_admin can only access their own tenant.

    Returns the Tenant row; raises 403 on foreign tenant, 404 if missing.
    """
    is_own_tenant = current_user.tenant_id == tenant_id
    if not (current_user.is_super_admin or is_own_tenant):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to this tenant",
        )

    tenant = (
        await db.execute(select(Tenant).where(Tenant.id == tenant_id))
    ).scalar_one_or_none()
    if tenant is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tenant not found",
        )
    return tenant
|
||||
|
||||
|
||||
@router.get("/{tenant_id}/users", response_model=list[UserResponse], summary="List users in tenant")
|
||||
async def list_users(
|
||||
tenant_id: uuid.UUID,
|
||||
current_user: CurrentUser = Depends(require_tenant_admin_or_above),
|
||||
db: AsyncSession = Depends(get_admin_db),
|
||||
) -> list[UserResponse]:
|
||||
"""
|
||||
List users in a tenant.
|
||||
- super_admin: can list users in any tenant
|
||||
- tenant_admin: can only list users in their own tenant
|
||||
"""
|
||||
await _check_tenant_access(tenant_id, current_user, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(User)
|
||||
.where(User.tenant_id == tenant_id)
|
||||
.order_by(User.name)
|
||||
)
|
||||
users = result.scalars().all()
|
||||
|
||||
return [UserResponse.model_validate(user) for user in users]
|
||||
|
||||
|
||||
@router.post(
    "/{tenant_id}/users",
    response_model=UserResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Create a user in tenant",
)
@limiter.limit("20/minute")
async def create_user(
    request: Request,
    tenant_id: uuid.UUID,
    data: UserCreate,
    current_user: CurrentUser = Depends(require_tenant_admin_or_above),
    db: AsyncSession = Depends(get_admin_db),
) -> UserResponse:
    """Create a user within a tenant.

    - super_admin: can create users in any tenant
    - tenant_admin: can only create users in their own tenant
    - No email invitation flow — admin creates accounts with temporary passwords
    """
    await _check_tenant_access(tenant_id, current_user, db)

    # Emails are stored lowercased and must be unique globally, not per-tenant.
    normalized_email = data.email.lower()
    duplicate = (
        await db.execute(select(User).where(User.email == normalized_email))
    ).scalar_one_or_none()
    if duplicate is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="A user with this email already exists",
        )

    new_user = User(
        email=normalized_email,
        hashed_password=hash_password(data.password),
        name=data.name,
        role=data.role.value,
        tenant_id=tenant_id,
        is_active=True,
        # Temporary-password accounts must upgrade their auth on first login.
        must_upgrade_auth=True,
    )
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)

    return UserResponse.model_validate(new_user)
|
||||
|
||||
|
||||
@router.get("/{tenant_id}/users/{user_id}", response_model=UserResponse, summary="Get user detail")
|
||||
async def get_user(
|
||||
tenant_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
current_user: CurrentUser = Depends(require_tenant_admin_or_above),
|
||||
db: AsyncSession = Depends(get_admin_db),
|
||||
) -> UserResponse:
|
||||
"""Get user detail."""
|
||||
await _check_tenant_access(tenant_id, current_user, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == user_id, User.tenant_id == tenant_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
return UserResponse.model_validate(user)
|
||||
|
||||
|
||||
@router.put("/{tenant_id}/users/{user_id}", response_model=UserResponse, summary="Update a user")
@limiter.limit("20/minute")
async def update_user(
    request: Request,
    tenant_id: uuid.UUID,
    user_id: uuid.UUID,
    data: UserUpdate,
    current_user: CurrentUser = Depends(require_tenant_admin_or_above),
    db: AsyncSession = Depends(get_admin_db),
) -> UserResponse:
    """
    Update user attributes (name, role, is_active).

    Role assignment is editable by admins, with one exception: super_admin
    cannot be assigned here. UserCreate blocks that role at creation time,
    so allowing it on update would reopen a privilege-escalation path.
    """
    await _check_tenant_access(tenant_id, current_user, db)

    result = await db.execute(
        select(User).where(User.id == user_id, User.tenant_id == tenant_id)
    )
    user = result.scalar_one_or_none()

    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found",
        )

    if data.name is not None:
        user.name = data.name

    if data.role is not None:
        # Mirror UserCreate.validate_role: without this check a tenant_admin
        # could promote any user to super_admin via PUT even though creation
        # forbids it. (Assumes the enum value string is "super_admin",
        # consistent with the string roles used elsewhere — confirm.)
        if data.role.value == "super_admin":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Cannot assign the super_admin role via this endpoint",
            )
        user.role = data.role.value

    if data.is_active is not None:
        user.is_active = data.is_active

    await db.commit()
    await db.refresh(user)

    return UserResponse.model_validate(user)
|
||||
|
||||
|
||||
@router.delete("/{tenant_id}/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Deactivate a user")
@limiter.limit("5/minute")
async def deactivate_user(
    request: Request,
    tenant_id: uuid.UUID,
    user_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_tenant_admin_or_above),
    db: AsyncSession = Depends(get_admin_db),
) -> None:
    """
    Deactivate a user (soft delete — sets is_active=False).

    This preserves audit trail while preventing login.
    """
    await _check_tenant_access(tenant_id, current_user, db)

    query = select(User).where(User.id == user_id, User.tenant_id == tenant_id)
    target = (await db.execute(query)).scalar_one_or_none()

    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found",
        )

    # Admins may not lock themselves out of their own account.
    if target.id == current_user.user_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot deactivate your own account",
        )

    target.is_active = False
    await db.commit()
|
||||
236
backend/app/routers/vpn.py
Normal file
236
backend/app/routers/vpn.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""WireGuard VPN API endpoints.
|
||||
|
||||
Tenant-scoped routes under /api/tenants/{tenant_id}/vpn/ for:
|
||||
- VPN setup (enable WireGuard for tenant)
|
||||
- VPN config management (update endpoint, enable/disable)
|
||||
- Peer management (add device, remove, get config)
|
||||
|
||||
RLS enforced via get_db() (app_user engine with tenant context).
|
||||
RBAC: operator and above for all operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db, set_tenant_context
|
||||
from app.middleware.rate_limit import limiter
|
||||
from app.middleware.tenant_context import CurrentUser, get_current_user
|
||||
from app.models.device import Device
|
||||
from app.schemas.vpn import (
|
||||
VpnConfigResponse,
|
||||
VpnConfigUpdate,
|
||||
VpnOnboardRequest,
|
||||
VpnOnboardResponse,
|
||||
VpnPeerConfig,
|
||||
VpnPeerCreate,
|
||||
VpnPeerResponse,
|
||||
VpnSetupRequest,
|
||||
)
|
||||
from app.services import vpn_service
|
||||
|
||||
router = APIRouter(tags=["vpn"])
|
||||
|
||||
|
||||
async def _check_tenant_access(
    current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
    """Gate tenant-scoped access for VPN routes.

    Super admins may act on any tenant, but the RLS tenant context has to be
    switched explicitly on the session; everyone else is rejected unless the
    path tenant matches their own.
    """
    if current_user.is_super_admin:
        await set_tenant_context(db, str(tenant_id))
        return
    if current_user.tenant_id != tenant_id:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
|
||||
|
||||
def _require_operator(current_user: CurrentUser) -> None:
    """Reject read-only users; VPN mutations require operator or above.

    NOTE(review): this is a deny-list — any role other than "viewer" passes.
    If additional read-only roles are ever introduced, switch to an explicit
    allow-list of mutating roles.
    """
    if current_user.role == "viewer":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Operator role required")
|
||||
|
||||
|
||||
# ── VPN Config ──
|
||||
|
||||
|
||||
@router.get("/tenants/{tenant_id}/vpn", response_model=VpnConfigResponse | None)
async def get_vpn_config(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return this tenant's VPN configuration, or None when VPN is not set up."""
    await _check_tenant_access(current_user, tenant_id, db)

    vpn_config = await vpn_service.get_vpn_config(db, tenant_id)
    if not vpn_config:
        return None

    # Enrich the response with the current peer count.
    response = VpnConfigResponse.model_validate(vpn_config)
    response.peer_count = len(await vpn_service.get_peers(db, tenant_id))
    return response
|
||||
|
||||
|
||||
@router.post("/tenants/{tenant_id}/vpn", response_model=VpnConfigResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("20/minute")
async def setup_vpn(
    request: Request,
    tenant_id: uuid.UUID,
    body: VpnSetupRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Enable VPN for this tenant — generates server keys.

    A service-layer ValueError (presumably "already configured" — confirm
    against vpn_service) is surfaced as 409 Conflict.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    try:
        config = await vpn_service.setup_vpn(db, tenant_id, endpoint=body.endpoint)
    except ValueError as e:
        # Chain the cause so the original traceback survives (PEP 3134 / ruff B904).
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
    return VpnConfigResponse.model_validate(config)
|
||||
|
||||
|
||||
@router.patch("/tenants/{tenant_id}/vpn", response_model=VpnConfigResponse)
@limiter.limit("20/minute")
async def update_vpn_config(
    request: Request,
    tenant_id: uuid.UUID,
    body: VpnConfigUpdate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update VPN settings (endpoint, enable/disable).

    A service-layer ValueError (config missing) is surfaced as 404.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    try:
        config = await vpn_service.update_vpn_config(
            db, tenant_id, endpoint=body.endpoint, is_enabled=body.is_enabled
        )
    except ValueError as e:
        # Chain the cause so the original traceback survives (PEP 3134 / ruff B904).
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e

    # Enrich the response with the current peer count.
    peers = await vpn_service.get_peers(db, tenant_id)
    resp = VpnConfigResponse.model_validate(config)
    resp.peer_count = len(peers)
    return resp
|
||||
|
||||
|
||||
# ── VPN Peers ──
|
||||
|
||||
|
||||
@router.get("/tenants/{tenant_id}/vpn/peers", response_model=list[VpnPeerResponse])
async def list_peers(
    tenant_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List all VPN peers for this tenant, enriched with device info and live handshakes."""
    await _check_tenant_access(current_user, tenant_id, db)
    peers = await vpn_service.get_peers(db, tenant_id)

    # Batch-load the referenced devices in a single query, keyed by id.
    device_map = {}
    referenced_ids = [p.device_id for p in peers]
    if referenced_ids:
        rows = await db.execute(select(Device).where(Device.id.in_(referenced_ids)))
        device_map = {d.id: d for d in rows.scalars().all()}

    # Read live WireGuard status once for the whole listing.
    wg_status = vpn_service.read_wg_status()

    enriched = []
    for peer in peers:
        item = VpnPeerResponse.model_validate(peer)
        device = device_map.get(peer.device_id)
        if device:
            item.device_hostname = device.hostname
            item.device_ip = device.ip_address
        # Prefer the live handshake reported by the WireGuard container.
        handshake = vpn_service.get_peer_handshake(wg_status, peer.peer_public_key)
        if handshake:
            item.last_handshake = handshake
        enriched.append(item)
    return enriched
|
||||
|
||||
|
||||
@router.post("/tenants/{tenant_id}/vpn/peers", response_model=VpnPeerResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("20/minute")
async def add_peer(
    request: Request,
    tenant_id: uuid.UUID,
    body: VpnPeerCreate,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Add a device as a VPN peer.

    A service-layer ValueError (e.g. duplicate peer — confirm against
    vpn_service) is surfaced as 409 Conflict.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    try:
        peer = await vpn_service.add_peer(db, tenant_id, body.device_id, additional_allowed_ips=body.additional_allowed_ips)
    except ValueError as e:
        # Chain the cause so the original traceback survives (PEP 3134 / ruff B904).
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e

    # Enrich with device info
    result = await db.execute(select(Device).where(Device.id == peer.device_id))
    device = result.scalar_one_or_none()

    resp = VpnPeerResponse.model_validate(peer)
    if device:
        resp.device_hostname = device.hostname
        resp.device_ip = device.ip_address
    return resp
|
||||
|
||||
|
||||
@router.post("/tenants/{tenant_id}/vpn/peers/onboard", response_model=VpnOnboardResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("10/minute")
async def onboard_device(
    request: Request,
    tenant_id: uuid.UUID,
    body: VpnOnboardRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create device + VPN peer in one step. Returns RouterOS commands for tunnel setup.

    A service-layer ValueError is surfaced as 409 Conflict.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    try:
        result = await vpn_service.onboard_device(
            db, tenant_id,
            hostname=body.hostname,
            username=body.username,
            password=body.password,
        )
    except ValueError as e:
        # Chain the cause so the original traceback survives (PEP 3134 / ruff B904).
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e
    return VpnOnboardResponse(**result)
|
||||
|
||||
|
||||
@router.delete("/tenants/{tenant_id}/vpn/peers/{peer_id}", status_code=status.HTTP_204_NO_CONTENT)
@limiter.limit("5/minute")
async def remove_peer(
    request: Request,
    tenant_id: uuid.UUID,
    peer_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Remove a VPN peer.

    A service-layer ValueError (peer not found) is surfaced as 404.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    try:
        await vpn_service.remove_peer(db, tenant_id, peer_id)
    except ValueError as e:
        # Chain the cause so the original traceback survives (PEP 3134 / ruff B904).
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/tenants/{tenant_id}/vpn/peers/{peer_id}/config", response_model=VpnPeerConfig)
async def get_peer_device_config(
    tenant_id: uuid.UUID,
    peer_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get the full config for a peer — includes private key and RouterOS commands.

    Restricted to operator and above because the payload contains the peer's
    private key. A service-layer ValueError (peer not found) is surfaced as 404.
    """
    await _check_tenant_access(current_user, tenant_id, db)
    _require_operator(current_user)
    try:
        config = await vpn_service.get_peer_config(db, tenant_id, peer_id)
    except ValueError as e:
        # Chain the cause so the original traceback survives (PEP 3134 / ruff B904).
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
    return VpnPeerConfig(**config)
|
||||
18
backend/app/schemas/__init__.py
Normal file
18
backend/app/schemas/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Pydantic schemas for request/response validation."""
|
||||
|
||||
from app.schemas.auth import LoginRequest, TokenResponse, RefreshRequest, UserMeResponse
|
||||
from app.schemas.tenant import TenantCreate, TenantResponse, TenantUpdate
|
||||
from app.schemas.user import UserCreate, UserResponse, UserUpdate
|
||||
|
||||
__all__ = [
|
||||
"LoginRequest",
|
||||
"TokenResponse",
|
||||
"RefreshRequest",
|
||||
"UserMeResponse",
|
||||
"TenantCreate",
|
||||
"TenantResponse",
|
||||
"TenantUpdate",
|
||||
"UserCreate",
|
||||
"UserResponse",
|
||||
"UserUpdate",
|
||||
]
|
||||
123
backend/app/schemas/auth.py
Normal file
123
backend/app/schemas/auth.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Authentication request/response schemas."""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
auth_upgrade_required: bool = False # True when bcrypt user needs SRP registration
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class UserMeResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
email: str
|
||||
name: str
|
||||
role: str
|
||||
tenant_id: Optional[uuid.UUID] = None
|
||||
auth_version: int = 1
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
current_password: str
|
||||
new_password: str
|
||||
# SRP users must provide re-derived credentials
|
||||
new_srp_salt: Optional[str] = None
|
||||
new_srp_verifier: Optional[str] = None
|
||||
# Re-wrapped key bundle (SRP users re-encrypt with new AUK)
|
||||
encrypted_private_key: Optional[str] = None
|
||||
private_key_nonce: Optional[str] = None
|
||||
encrypted_vault_key: Optional[str] = None
|
||||
vault_key_nonce: Optional[str] = None
|
||||
public_key: Optional[str] = None
|
||||
pbkdf2_salt: Optional[str] = None
|
||||
hkdf_salt: Optional[str] = None
|
||||
|
||||
|
||||
class ForgotPasswordRequest(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
# --- SRP Zero-Knowledge Authentication Schemas ---
|
||||
|
||||
|
||||
class SRPInitRequest(BaseModel):
|
||||
"""Step 1 request: client sends email to begin SRP handshake."""
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class SRPInitResponse(BaseModel):
|
||||
"""Step 1 response: server returns ephemeral B and key derivation salts."""
|
||||
salt: str # hex-encoded SRP salt
|
||||
server_public: str # hex-encoded server ephemeral B
|
||||
session_id: str # Redis session key nonce
|
||||
pbkdf2_salt: str # base64-encoded, from user_key_sets (needed for 2SKD before SRP verify)
|
||||
hkdf_salt: str # base64-encoded, from user_key_sets (needed for 2SKD before SRP verify)
|
||||
|
||||
|
||||
class SRPVerifyRequest(BaseModel):
|
||||
"""Step 2 request: client sends proof M1 to complete handshake."""
|
||||
email: EmailStr
|
||||
session_id: str
|
||||
client_public: str # hex-encoded client ephemeral A
|
||||
client_proof: str # hex-encoded client proof M1
|
||||
|
||||
|
||||
class SRPVerifyResponse(BaseModel):
|
||||
"""Step 2 response: server returns tokens and proof M2."""
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
server_proof: str # hex-encoded server proof M2
|
||||
encrypted_key_set: Optional[dict] = None # Key bundle for client-side decryption
|
||||
|
||||
|
||||
class SRPRegisterRequest(BaseModel):
|
||||
"""Used during registration to store SRP verifier and key set."""
|
||||
srp_salt: str # hex-encoded
|
||||
srp_verifier: str # hex-encoded
|
||||
encrypted_private_key: str # base64-encoded
|
||||
private_key_nonce: str # base64-encoded
|
||||
encrypted_vault_key: str # base64-encoded
|
||||
vault_key_nonce: str # base64-encoded
|
||||
public_key: str # base64-encoded
|
||||
pbkdf2_salt: str # base64-encoded
|
||||
hkdf_salt: str # base64-encoded
|
||||
|
||||
|
||||
# --- Account Self-Service Schemas ---
|
||||
|
||||
|
||||
class DeleteAccountRequest(BaseModel):
|
||||
"""Request body for account self-deletion. User must type 'DELETE' to confirm."""
|
||||
confirmation: str # Must be "DELETE" to confirm
|
||||
|
||||
|
||||
class DeleteAccountResponse(BaseModel):
|
||||
"""Response after successful account deletion."""
|
||||
message: str
|
||||
deleted: bool
|
||||
78
backend/app/schemas/certificate.py
Normal file
78
backend/app/schemas/certificate.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Pydantic request/response schemas for the Internal Certificate Authority."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CACreateRequest(BaseModel):
|
||||
"""Request to generate a new root CA for the tenant."""
|
||||
|
||||
common_name: str = "Portal Root CA"
|
||||
validity_years: int = 10 # Default 10 years for CA
|
||||
|
||||
|
||||
class CertSignRequest(BaseModel):
|
||||
"""Request to sign a per-device certificate using the tenant CA."""
|
||||
|
||||
device_id: UUID
|
||||
validity_days: int = 730 # Default 2 years for device certs
|
||||
|
||||
|
||||
class BulkCertDeployRequest(BaseModel):
|
||||
"""Request to deploy certificates to multiple devices."""
|
||||
|
||||
device_ids: list[UUID]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CAResponse(BaseModel):
|
||||
"""Public details of a tenant's Certificate Authority (no private key)."""
|
||||
|
||||
id: UUID
|
||||
tenant_id: UUID
|
||||
common_name: str
|
||||
fingerprint_sha256: str
|
||||
serial_number: str
|
||||
not_valid_before: datetime
|
||||
not_valid_after: datetime
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class DeviceCertResponse(BaseModel):
|
||||
"""Public details of a device certificate (no private key)."""
|
||||
|
||||
id: UUID
|
||||
tenant_id: UUID
|
||||
device_id: UUID
|
||||
ca_id: UUID
|
||||
common_name: str
|
||||
fingerprint_sha256: str
|
||||
serial_number: str
|
||||
not_valid_before: datetime
|
||||
not_valid_after: datetime
|
||||
status: str
|
||||
deployed_at: datetime | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class CertDeployResponse(BaseModel):
|
||||
"""Result of a single device certificate deployment attempt."""
|
||||
|
||||
success: bool
|
||||
device_id: UUID
|
||||
cert_name_on_device: str | None = None
|
||||
error: str | None = None
|
||||
271
backend/app/schemas/device.py
Normal file
271
backend/app/schemas/device.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""Pydantic schemas for Device, DeviceGroup, and DeviceTag endpoints."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Device schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DeviceCreate(BaseModel):
|
||||
"""Schema for creating a new device."""
|
||||
|
||||
hostname: str
|
||||
ip_address: str
|
||||
api_port: int = 8728
|
||||
api_ssl_port: int = 8729
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class DeviceUpdate(BaseModel):
|
||||
"""Schema for updating an existing device. All fields optional."""
|
||||
|
||||
hostname: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
api_port: Optional[int] = None
|
||||
api_ssl_port: Optional[int] = None
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
latitude: Optional[float] = None
|
||||
longitude: Optional[float] = None
|
||||
tls_mode: Optional[str] = None
|
||||
|
||||
@field_validator("tls_mode")
|
||||
@classmethod
|
||||
def validate_tls_mode(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate tls_mode is one of the allowed values."""
|
||||
if v is None:
|
||||
return v
|
||||
allowed = {"auto", "insecure", "plain", "portal_ca"}
|
||||
if v not in allowed:
|
||||
raise ValueError(f"tls_mode must be one of: {', '.join(sorted(allowed))}")
|
||||
return v
|
||||
|
||||
|
||||
class DeviceTagRef(BaseModel):
|
||||
"""Minimal tag info embedded in device responses."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
color: Optional[str] = None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DeviceGroupRef(BaseModel):
|
||||
"""Minimal group info embedded in device responses."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DeviceResponse(BaseModel):
|
||||
"""Device response schema. NEVER includes credential fields."""
|
||||
|
||||
id: uuid.UUID
|
||||
hostname: str
|
||||
ip_address: str
|
||||
api_port: int
|
||||
api_ssl_port: int
|
||||
model: Optional[str] = None
|
||||
serial_number: Optional[str] = None
|
||||
firmware_version: Optional[str] = None
|
||||
routeros_version: Optional[str] = None
|
||||
routeros_major_version: Optional[int] = None
|
||||
uptime_seconds: Optional[int] = None
|
||||
last_seen: Optional[datetime] = None
|
||||
latitude: Optional[float] = None
|
||||
longitude: Optional[float] = None
|
||||
status: str
|
||||
tls_mode: str = "auto"
|
||||
tags: list[DeviceTagRef] = []
|
||||
groups: list[DeviceGroupRef] = []
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DeviceListResponse(BaseModel):
|
||||
"""Paginated device list response."""
|
||||
|
||||
items: list[DeviceResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subnet scan schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SubnetScanRequest(BaseModel):
|
||||
"""Request body for a subnet scan."""
|
||||
|
||||
cidr: str
|
||||
|
||||
@field_validator("cidr")
|
||||
@classmethod
|
||||
def validate_cidr(cls, v: str) -> str:
|
||||
"""Validate that the value is a valid CIDR notation and RFC 1918 private range."""
|
||||
import ipaddress
|
||||
try:
|
||||
network = ipaddress.ip_network(v, strict=False)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid CIDR notation: {e}") from e
|
||||
# Only allow private IP ranges (RFC 1918: 10/8, 172.16/12, 192.168/16)
|
||||
if not network.is_private:
|
||||
raise ValueError(
|
||||
"Only private IP ranges can be scanned (RFC 1918: "
|
||||
"10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16)"
|
||||
)
|
||||
# Reject ranges larger than /20 (4096 IPs) to prevent abuse
|
||||
if network.num_addresses > 4096:
|
||||
raise ValueError(
|
||||
f"CIDR range too large ({network.num_addresses} addresses). "
|
||||
"Maximum allowed: /20 (4096 addresses)."
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class SubnetScanResult(BaseModel):
|
||||
"""A single discovered host from a subnet scan."""
|
||||
|
||||
ip_address: str
|
||||
hostname: Optional[str] = None
|
||||
api_port_open: bool = False
|
||||
api_ssl_port_open: bool = False
|
||||
|
||||
|
||||
class SubnetScanResponse(BaseModel):
|
||||
"""Response for a subnet scan operation."""
|
||||
|
||||
cidr: str
|
||||
discovered: list[SubnetScanResult]
|
||||
total_scanned: int
|
||||
total_discovered: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bulk add from scan
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BulkDeviceAdd(BaseModel):
|
||||
"""One device entry within a bulk-add request."""
|
||||
|
||||
ip_address: str
|
||||
hostname: Optional[str] = None
|
||||
api_port: int = 8728
|
||||
api_ssl_port: int = 8729
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
|
||||
|
||||
class BulkAddRequest(BaseModel):
|
||||
"""
|
||||
Bulk-add devices selected from a scan result.
|
||||
|
||||
shared_username / shared_password are used for all devices that do not
|
||||
provide their own credentials.
|
||||
"""
|
||||
|
||||
devices: list[BulkDeviceAdd]
|
||||
shared_username: Optional[str] = None
|
||||
shared_password: Optional[str] = None
|
||||
|
||||
|
||||
class BulkAddResult(BaseModel):
|
||||
"""Summary result of a bulk-add operation."""
|
||||
|
||||
added: list[DeviceResponse]
|
||||
failed: list[dict] # {ip_address, error}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DeviceGroup schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DeviceGroupCreate(BaseModel):
|
||||
"""Schema for creating a device group."""
|
||||
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class DeviceGroupUpdate(BaseModel):
|
||||
"""Schema for updating a device group."""
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class DeviceGroupResponse(BaseModel):
|
||||
"""Device group response schema."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
device_count: int = 0
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DeviceTag schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DeviceTagCreate(BaseModel):
|
||||
"""Schema for creating a device tag."""
|
||||
|
||||
name: str
|
||||
color: Optional[str] = None
|
||||
|
||||
@field_validator("color")
|
||||
@classmethod
|
||||
def validate_color(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate hex color format if provided."""
|
||||
if v is None:
|
||||
return v
|
||||
import re
|
||||
if not re.match(r"^#[0-9A-Fa-f]{6}$", v):
|
||||
raise ValueError("Color must be a valid 6-digit hex color (e.g. #FF5733)")
|
||||
return v
|
||||
|
||||
|
||||
class DeviceTagUpdate(BaseModel):
|
||||
"""Schema for updating a device tag."""
|
||||
|
||||
name: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
|
||||
@field_validator("color")
|
||||
@classmethod
|
||||
def validate_color(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
import re
|
||||
if not re.match(r"^#[0-9A-Fa-f]{6}$", v):
|
||||
raise ValueError("Color must be a valid 6-digit hex color (e.g. #FF5733)")
|
||||
return v
|
||||
|
||||
|
||||
class DeviceTagResponse(BaseModel):
|
||||
"""Device tag response schema."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
color: Optional[str] = None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
31
backend/app/schemas/tenant.py
Normal file
31
backend/app/schemas/tenant.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Tenant request/response schemas."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TenantCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
contact_email: Optional[str] = None
|
||||
|
||||
|
||||
class TenantUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
contact_email: Optional[str] = None
|
||||
|
||||
|
||||
class TenantResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
contact_email: Optional[str] = None
|
||||
user_count: int = 0
|
||||
device_count: int = 0
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
53
backend/app/schemas/user.py
Normal file
53
backend/app/schemas/user.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""User request/response schemas."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, EmailStr, field_validator
|
||||
|
||||
from app.models.user import UserRole
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
    """Payload for creating a user within a tenant (admin-driven, no invite flow)."""

    name: str
    email: EmailStr
    password: str
    role: UserRole = UserRole.VIEWER

    @field_validator("password")
    @classmethod
    def validate_password(cls, v: str) -> str:
        """Enforce a minimum password length; no other complexity rules apply here."""
        if len(v) < 8:
            raise ValueError("Password must be at least 8 characters")
        return v

    @field_validator("role")
    @classmethod
    def validate_role(cls, v: UserRole) -> UserRole:
        """Restrict assignable roles to tenant-scoped ones (tenant_admin,
        operator, viewer); super_admin is never assignable here and is
        created via a separate flow.
        """
        allowed_tenant_roles = {UserRole.TENANT_ADMIN, UserRole.OPERATOR, UserRole.VIEWER}
        if v not in allowed_tenant_roles:
            raise ValueError(
                f"Role must be one of: {', '.join(r.value for r in allowed_tenant_roles)}"
            )
        return v
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
email: str
|
||||
role: str
|
||||
tenant_id: Optional[uuid.UUID] = None
|
||||
is_active: bool
|
||||
last_login: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
role: Optional[UserRole] = None
|
||||
is_active: Optional[bool] = None
|
||||
91
backend/app/schemas/vpn.py
Normal file
91
backend/app/schemas/vpn.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Pydantic schemas for WireGuard VPN management."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# ── VPN Config (server-side) ──
|
||||
|
||||
|
||||
class VpnSetupRequest(BaseModel):
|
||||
"""Request to enable VPN for a tenant."""
|
||||
endpoint: Optional[str] = None # public hostname:port — if blank, devices must be configured manually
|
||||
|
||||
|
||||
class VpnConfigResponse(BaseModel):
|
||||
"""VPN server configuration (never exposes private key)."""
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
id: uuid.UUID
|
||||
tenant_id: uuid.UUID
|
||||
server_public_key: str
|
||||
subnet: str
|
||||
server_port: int
|
||||
server_address: str
|
||||
endpoint: Optional[str]
|
||||
is_enabled: bool
|
||||
peer_count: int = 0
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class VpnConfigUpdate(BaseModel):
|
||||
"""Update VPN configuration."""
|
||||
endpoint: Optional[str] = None
|
||||
is_enabled: Optional[bool] = None
|
||||
|
||||
|
||||
# ── VPN Peers ──
|
||||
|
||||
|
||||
class VpnPeerCreate(BaseModel):
|
||||
"""Add a device as a VPN peer."""
|
||||
device_id: uuid.UUID
|
||||
additional_allowed_ips: Optional[str] = None # comma-separated subnets for site-to-site routing
|
||||
|
||||
|
||||
class VpnPeerResponse(BaseModel):
    """VPN peer info (never exposes private key)."""

    # Built directly from ORM rows via attribute access.
    model_config = {"from_attributes": True}

    id: uuid.UUID
    device_id: uuid.UUID
    device_hostname: str = ""  # denormalized from device; empty when unknown
    device_ip: str = ""  # denormalized from device; empty when unknown
    peer_public_key: str
    assigned_ip: str  # address allocated to the peer within the VPN subnet
    is_enabled: bool
    last_handshake: Optional[datetime]  # None if the peer never connected
    created_at: datetime
|
||||
|
||||
|
||||
# ── VPN Onboarding (combined device + peer creation) ──
|
||||
|
||||
|
||||
class VpnOnboardRequest(BaseModel):
    """Combined device creation + VPN peer onboarding."""

    hostname: str
    username: str  # device management credentials
    password: str  # device management credentials
|
||||
|
||||
|
||||
class VpnOnboardResponse(BaseModel):
    """Response from onboarding — device, peer, and RouterOS commands."""

    device_id: uuid.UUID
    peer_id: uuid.UUID
    hostname: str
    assigned_ip: str  # VPN address allocated to the new peer
    # Ready-to-paste RouterOS CLI commands to configure the device side.
    routeros_commands: list[str]
|
||||
|
||||
|
||||
class VpnPeerConfig(BaseModel):
    """Full peer config for display/export — includes private key for device setup.

    Unlike the other VPN schemas, this one intentionally carries the peer's
    private key; it is meant for one-time display/export during device setup.
    """

    peer_private_key: str  # sensitive — only exposed for initial device setup
    peer_public_key: str
    assigned_ip: str
    server_public_key: str
    server_endpoint: str
    allowed_ips: str
    # Ready-to-paste RouterOS CLI commands for the device side.
    routeros_commands: list[str]
|
||||
0
backend/app/security/__init__.py
Normal file
0
backend/app/security/__init__.py
Normal file
95
backend/app/security/command_blocklist.py
Normal file
95
backend/app/security/command_blocklist.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Dangerous RouterOS command and path blocklist.
|
||||
|
||||
Prevents destructive or sensitive operations from being executed through
|
||||
the config editor. Commands and paths are checked via case-insensitive
|
||||
prefix matching against known-dangerous entries.
|
||||
|
||||
To extend: add strings to DANGEROUS_COMMANDS, BROWSE_BLOCKED_PATHS,
|
||||
or WRITE_BLOCKED_PATHS.
|
||||
"""
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
# CLI commands blocked from the execute endpoint.
# Matched as case-insensitive prefixes (e.g., "/user" blocks "/user/print" too).
DANGEROUS_COMMANDS: list[str] = [
    "/system/reset-configuration",  # wipes the device
    "/system/shutdown",
    "/system/reboot",
    "/system/backup",  # could exfiltrate full device state
    "/system/license",
    "/user",  # account management — blocks the whole subtree
    "/password",
    "/certificate",
    "/radius",
    "/export",  # full-config dump
    "/import",
]
|
||||
|
||||
# Paths blocked from ALL operations including browse (truly dangerous to read).
# Compared after normalization (lowercased, leading '/' stripped), prefix match.
BROWSE_BLOCKED_PATHS: list[str] = [
    "system/reset-configuration",
    "system/shutdown",
    "system/reboot",
    "system/backup",
    "system/license",
    "password",
]
|
||||
|
||||
# Paths blocked from write operations (add/set/remove) but readable via browse.
WRITE_BLOCKED_PATHS: list[str] = [
    "user",
    "certificate",
    "radius",
]
|
||||
|
||||
|
||||
def check_command_safety(command: str) -> None:
    """Validate a CLI command against the dangerous-command blocklist.

    The command is normalized (whitespace stripped, lowercased) and compared
    against DANGEROUS_COMMANDS via case-insensitive prefix matching.

    Raises:
        HTTPException: 403 if the command starts with a blocked prefix.
    """
    candidate = command.strip().lower()
    offender = next(
        (prefix for prefix in DANGEROUS_COMMANDS if candidate.startswith(prefix)),
        None,
    )
    if offender is not None:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=(
                f"Command blocked: '{command}' matches dangerous prefix '{offender}'. "
                f"This operation is not allowed through the config editor."
            ),
        )
|
||||
|
||||
|
||||
def check_path_safety(path: str, *, write: bool = False) -> None:
    """Validate a RouterOS menu path against the blocked-path lists.

    The path is normalized (stripped, lowercased, leading '/' removed) and
    prefix-matched against BROWSE_BLOCKED_PATHS — plus WRITE_BLOCKED_PATHS
    when a write operation is requested.

    Args:
        path: The RouterOS menu path to check.
        write: True for add/set/remove operations (stricter check);
            False for read-only browsing.

    Raises:
        HTTPException: 403 if the path matches a blocked prefix.
    """
    candidate = path.strip().lower().lstrip("/")
    # Browse-blocked prefixes always apply; write-blocked ones only for writes.
    prefixes = BROWSE_BLOCKED_PATHS + WRITE_BLOCKED_PATHS if write else BROWSE_BLOCKED_PATHS
    for prefix in prefixes:
        if candidate.startswith(prefix):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=(
                    f"Path blocked: '{path}' matches dangerous prefix '{prefix}'. "
                    f"This operation is not allowed through the config editor."
                ),
            )
|
||||
1
backend/app/services/__init__.py
Normal file
1
backend/app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Backend services — auth, crypto, and business logic."""
|
||||
240
backend/app/services/account_service.py
Normal file
240
backend/app/services/account_service.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Account self-service operations: deletion and data export.
|
||||
|
||||
Provides GDPR/CCPA-compliant account deletion with full PII erasure
|
||||
and data portability export (Article 20).
|
||||
|
||||
All queries use raw SQL via text() with admin sessions (bypass RLS)
|
||||
since these are cross-table operations on the authenticated user's data.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
logger = structlog.get_logger("account_service")
|
||||
|
||||
|
||||
async def delete_user_account(
    db: AsyncSession,
    user_id: uuid.UUID,
    tenant_id: uuid.UUID | None,
    user_email: str,
) -> dict[str, Any]:
    """Hard-delete a user account with full PII erasure.

    Steps:
        1. Create a deletion receipt audit log (persisted via separate session)
        2. Anonymize PII in existing audit_logs for this user
        3. Hard-delete the user row (CASCADE handles related tables)
        4. Best-effort session invalidation via Redis

    Args:
        db: Admin async session (bypasses RLS).
        user_id: UUID of the user to delete.
        tenant_id: Tenant UUID (None for super_admin).
        user_email: User's email (needed for audit hash before deletion).

    Returns:
        Dict with deleted=True and user_id on success.
    """
    # Nil UUID stands in for "no tenant" (super_admin) so the audit row
    # always has a tenant value.
    effective_tenant_id = tenant_id or uuid.UUID(int=0)
    # Hash the email so the receipt can be correlated later without storing PII.
    email_hash = hashlib.sha256(user_email.encode()).hexdigest()

    # ── 1. Pre-deletion audit receipt (separate session so it persists) ────
    # A dedicated session ensures the receipt survives even if the main
    # transaction on `db` later fails or is rolled back.
    try:
        async with AdminAsyncSessionLocal() as audit_db:
            await log_action(
                audit_db,
                tenant_id=effective_tenant_id,
                user_id=user_id,
                action="account_deleted",
                resource_type="user",
                resource_id=str(user_id),
                details={
                    "deleted_user_id": str(user_id),
                    "email_hash": email_hash,
                    "deletion_type": "self_service",
                    "deleted_at": datetime.now(UTC).isoformat(),
                },
            )
            await audit_db.commit()
    except Exception:
        # Receipt failure must not block the deletion itself.
        logger.warning(
            "deletion_receipt_failed",
            user_id=str(user_id),
            exc_info=True,
        )

    # ── 2. Anonymize PII in audit_logs for this user ─────────────────────
    # Strip PII keys from details JSONB (email, name, user_email, user_name)
    await db.execute(
        text(
            "UPDATE audit_logs "
            "SET details = details - 'email' - 'name' - 'user_email' - 'user_name' "
            "WHERE user_id = :user_id"
        ),
        {"user_id": user_id},
    )

    # Null out encrypted_details (may contain encrypted PII)
    await db.execute(
        text(
            "UPDATE audit_logs "
            "SET encrypted_details = NULL "
            "WHERE user_id = :user_id"
        ),
        {"user_id": user_id},
    )

    # ── 3. Hard delete user row ──────────────────────────────────────────
    # CASCADE handles: user_key_sets, api_keys, password_reset_tokens
    # SET NULL handles: audit_logs.user_id, key_access_log.user_id,
    # maintenance_windows.created_by, alert_events.acknowledged_by
    await db.execute(
        text("DELETE FROM users WHERE id = :user_id"),
        {"user_id": user_id},
    )

    await db.commit()

    # ── 4. Best-effort Redis session invalidation ────────────────────────
    # Imports are local, presumably to avoid circular imports at module
    # load time — confirm before moving them to the top of the file.
    try:
        import redis.asyncio as aioredis
        from app.config import settings
        from app.services.auth import revoke_user_tokens

        r = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
        await revoke_user_tokens(r, str(user_id))
        await r.aclose()
    except Exception:
        # JWT expires in 15 min anyway; not critical
        logger.debug("redis_session_invalidation_skipped", user_id=str(user_id))

    logger.info("account_deleted", user_id=str(user_id), email_hash=email_hash)

    return {"deleted": True, "user_id": str(user_id)}
|
||||
|
||||
|
||||
async def export_user_data(
    db: AsyncSession,
    user_id: uuid.UUID,
    tenant_id: uuid.UUID | None,
) -> dict[str, Any]:
    """Assemble all user data for GDPR Art. 20 data portability export.

    Returns a structured dict with user profile, API keys, audit logs,
    and key access log entries.

    Args:
        db: Admin async session (bypasses RLS).
        user_id: UUID of the user whose data to export.
        tenant_id: Tenant UUID (None for super_admin). Currently unused by
            the queries below, which filter on user_id only.

    Returns:
        Envelope dict with export_date, format_version, and all user data.
    """

    # ── User profile ─────────────────────────────────────────────────────
    result = await db.execute(
        text(
            "SELECT id, email, name, role, tenant_id, "
            "created_at, last_login, auth_version "
            "FROM users WHERE id = :user_id"
        ),
        {"user_id": user_id},
    )
    user_row = result.mappings().first()
    # Empty dict (not an error) when the user row is already gone.
    user_data: dict[str, Any] = {}
    if user_row:
        user_data = {
            "id": str(user_row["id"]),
            "email": user_row["email"],
            "name": user_row["name"],
            "role": user_row["role"],
            "tenant_id": str(user_row["tenant_id"]) if user_row["tenant_id"] else None,
            "created_at": user_row["created_at"].isoformat() if user_row["created_at"] else None,
            "last_login": user_row["last_login"].isoformat() if user_row["last_login"] else None,
            "auth_version": user_row["auth_version"],
        }

    # ── API keys (exclude key_hash for security) ─────────────────────────
    result = await db.execute(
        text(
            "SELECT id, name, key_prefix, scopes, created_at, "
            "expires_at, revoked_at, last_used_at "
            "FROM api_keys WHERE user_id = :user_id "
            "ORDER BY created_at DESC"
        ),
        {"user_id": user_id},
    )
    api_keys = []
    for row in result.mappings().all():
        api_keys.append({
            "id": str(row["id"]),
            "name": row["name"],
            "key_prefix": row["key_prefix"],
            "scopes": row["scopes"],
            "created_at": row["created_at"].isoformat() if row["created_at"] else None,
            "expires_at": row["expires_at"].isoformat() if row["expires_at"] else None,
            "revoked_at": row["revoked_at"].isoformat() if row["revoked_at"] else None,
            "last_used_at": row["last_used_at"].isoformat() if row["last_used_at"] else None,
        })

    # ── Audit logs (limit 1000, most recent first) ───────────────────────
    result = await db.execute(
        text(
            "SELECT id, action, resource_type, resource_id, "
            "details, ip_address, created_at "
            "FROM audit_logs WHERE user_id = :user_id "
            "ORDER BY created_at DESC LIMIT 1000"
        ),
        {"user_id": user_id},
    )
    audit_logs = []
    for row in result.mappings().all():
        # Normalize NULL details to an empty dict for a consistent export shape.
        details = row["details"] if row["details"] else {}
        audit_logs.append({
            "id": str(row["id"]),
            "action": row["action"],
            "resource_type": row["resource_type"],
            "resource_id": row["resource_id"],
            "details": details,
            "ip_address": row["ip_address"],
            "created_at": row["created_at"].isoformat() if row["created_at"] else None,
        })

    # ── Key access log (limit 1000, most recent first) ───────────────────
    result = await db.execute(
        text(
            "SELECT id, action, resource_type, ip_address, created_at "
            "FROM key_access_log WHERE user_id = :user_id "
            "ORDER BY created_at DESC LIMIT 1000"
        ),
        {"user_id": user_id},
    )
    key_access_entries = []
    for row in result.mappings().all():
        key_access_entries.append({
            "id": str(row["id"]),
            "action": row["action"],
            "resource_type": row["resource_type"],
            "ip_address": row["ip_address"],
            "created_at": row["created_at"].isoformat() if row["created_at"] else None,
        })

    return {
        "export_date": datetime.now(UTC).isoformat(),
        "format_version": "1.0",
        "user": user_data,
        "api_keys": api_keys,
        "audit_logs": audit_logs,
        "key_access_log": key_access_entries,
    }
|
||||
723
backend/app/services/alert_evaluator.py
Normal file
723
backend/app/services/alert_evaluator.py
Normal file
@@ -0,0 +1,723 @@
|
||||
"""Alert rule evaluation engine with Redis breach counters and flap detection.
|
||||
|
||||
Entry points:
|
||||
- evaluate(device_id, tenant_id, metric_type, data): called from metrics_subscriber
|
||||
- evaluate_offline(device_id, tenant_id): called from nats_subscriber on device offline
|
||||
- evaluate_online(device_id, tenant_id): called from nats_subscriber on device online
|
||||
|
||||
Uses Redis for:
|
||||
- Consecutive breach counting (alert:breach:{device_id}:{rule_id})
|
||||
- Flap detection (alert:flap:{device_id}:{rule_id} sorted set)
|
||||
|
||||
Uses AdminAsyncSessionLocal for all DB operations (runs cross-tenant in NATS handlers).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.services.event_publisher import publish_event
|
||||
|
||||
logger = logging.getLogger(__name__)

# Module-level Redis client, lazily initialized by _get_redis().
_redis_client: aioredis.Redis | None = None

# Module-level rule cache: {tenant_id: (rules_list, fetched_at_timestamp)}
_rule_cache: dict[str, tuple[list[dict], float]] = {}
_CACHE_TTL_SECONDS = 60  # alert rules may be up to 60s stale

# Module-level maintenance window cache: {tenant_id: (active_windows_list, fetched_at_timestamp)}
# Each window: {"device_ids": [...], "suppress_alerts": True}
_maintenance_cache: dict[str, tuple[list[dict], float]] = {}
_MAINTENANCE_CACHE_TTL = 30  # 30 seconds
|
||||
|
||||
|
||||
async def _get_redis() -> aioredis.Redis:
    """Return the module-level Redis client, creating it on first use."""
    global _redis_client
    client = _redis_client
    if client is None:
        client = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
        _redis_client = client
    return client
|
||||
|
||||
|
||||
async def _get_active_maintenance_windows(tenant_id: str) -> list[dict]:
    """Fetch active maintenance windows for a tenant, with 30s cache.

    Only windows with suppress_alerts=true whose [start_at, end_at] range
    covers NOW() are returned. Results are cached per tenant for
    _MAINTENANCE_CACHE_TTL seconds, so suppression can lag reality by up
    to 30s.

    Returns:
        List of {"device_ids": [...], "suppress_alerts": bool} dicts.
    """
    now = time.time()
    cached = _maintenance_cache.get(tenant_id)
    if cached and (now - cached[1]) < _MAINTENANCE_CACHE_TTL:
        return cached[0]

    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT device_ids, suppress_alerts
                FROM maintenance_windows
                WHERE tenant_id = CAST(:tenant_id AS uuid)
                  AND suppress_alerts = true
                  AND start_at <= NOW()
                  AND end_at >= NOW()
            """),
            {"tenant_id": tenant_id},
        )
        rows = result.fetchall()

    # NOTE(review): a non-list device_ids value is coerced to [], which
    # downstream (_is_device_in_maintenance) means "all devices" — confirm
    # that is the intended fallback for malformed rows.
    windows = [
        {
            "device_ids": row[0] if isinstance(row[0], list) else [],
            "suppress_alerts": row[1],
        }
        for row in rows
    ]

    _maintenance_cache[tenant_id] = (windows, now)
    return windows
|
||||
|
||||
|
||||
async def _is_device_in_maintenance(tenant_id: str, device_id: str) -> bool:
    """Return True if the device is covered by an active, alert-suppressing
    maintenance window.

    A window with an empty device_ids list applies to every device in the
    tenant.
    """
    active_windows = await _get_active_maintenance_windows(tenant_id)
    return any(
        not window["device_ids"] or device_id in window["device_ids"]
        for window in active_windows
    )
|
||||
|
||||
|
||||
async def _get_rules_for_tenant(tenant_id: str) -> list[dict]:
    """Fetch active alert rules for a tenant, with 60s cache.

    Only rules with enabled = TRUE are returned. Results are cached per
    tenant for _CACHE_TTL_SECONDS, so newly created/edited rules can take
    up to a minute to be picked up.

    Returns:
        List of rule dicts with UUID fields stringified and threshold
        coerced to float.
    """
    now = time.time()
    cached = _rule_cache.get(tenant_id)
    if cached and (now - cached[1]) < _CACHE_TTL_SECONDS:
        return cached[0]

    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, tenant_id, device_id, group_id, name, metric,
                       operator, threshold, duration_polls, severity
                FROM alert_rules
                WHERE tenant_id = CAST(:tenant_id AS uuid) AND enabled = TRUE
            """),
            {"tenant_id": tenant_id},
        )
        rows = result.fetchall()

    rules = [
        {
            "id": str(row[0]),
            "tenant_id": str(row[1]),
            # device_id / group_id are NULL for tenant-wide rules.
            "device_id": str(row[2]) if row[2] else None,
            "group_id": str(row[3]) if row[3] else None,
            "name": row[4],
            "metric": row[5],
            "operator": row[6],
            "threshold": float(row[7]),
            "duration_polls": row[8],
            "severity": row[9],
        }
        for row in rows
    ]

    _rule_cache[tenant_id] = (rules, now)
    return rules
|
||||
|
||||
|
||||
def _check_threshold(value: float, operator: str, threshold: float) -> bool:
|
||||
"""Check if a metric value breaches a threshold."""
|
||||
if operator == "gt":
|
||||
return value > threshold
|
||||
elif operator == "lt":
|
||||
return value < threshold
|
||||
elif operator == "gte":
|
||||
return value >= threshold
|
||||
elif operator == "lte":
|
||||
return value <= threshold
|
||||
return False
|
||||
|
||||
|
||||
def _extract_metrics(metric_type: str, data: dict) -> dict[str, float]:
|
||||
"""Extract metric name->value pairs from a NATS metrics event."""
|
||||
metrics: dict[str, float] = {}
|
||||
|
||||
if metric_type == "health":
|
||||
health = data.get("health", {})
|
||||
for key in ("cpu_load", "temperature"):
|
||||
val = health.get(key)
|
||||
if val is not None and val != "":
|
||||
try:
|
||||
metrics[key] = float(val)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
# Compute memory_used_pct and disk_used_pct
|
||||
free_mem = health.get("free_memory")
|
||||
total_mem = health.get("total_memory")
|
||||
if free_mem is not None and total_mem is not None:
|
||||
try:
|
||||
total = float(total_mem)
|
||||
free = float(free_mem)
|
||||
if total > 0:
|
||||
metrics["memory_used_pct"] = round((1.0 - free / total) * 100, 1)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
free_disk = health.get("free_disk")
|
||||
total_disk = health.get("total_disk")
|
||||
if free_disk is not None and total_disk is not None:
|
||||
try:
|
||||
total = float(total_disk)
|
||||
free = float(free_disk)
|
||||
if total > 0:
|
||||
metrics["disk_used_pct"] = round((1.0 - free / total) * 100, 1)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
elif metric_type == "wireless":
|
||||
wireless = data.get("wireless", [])
|
||||
# Aggregate: use worst signal, lowest CCQ, sum client_count
|
||||
for wif in wireless:
|
||||
for key in ("signal_strength", "ccq", "client_count"):
|
||||
val = wif.get(key) if key != "avg_signal" else wif.get("avg_signal")
|
||||
if key == "signal_strength":
|
||||
val = wif.get("avg_signal")
|
||||
if val is not None and val != "":
|
||||
try:
|
||||
fval = float(val)
|
||||
if key not in metrics:
|
||||
metrics[key] = fval
|
||||
elif key == "signal_strength":
|
||||
metrics[key] = min(metrics[key], fval) # worst signal
|
||||
elif key == "ccq":
|
||||
metrics[key] = min(metrics[key], fval) # worst CCQ
|
||||
elif key == "client_count":
|
||||
metrics[key] = metrics.get(key, 0) + fval # sum
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# TODO: Interface bandwidth alerting (rx_bps/tx_bps) requires stateful delta
|
||||
# computation between consecutive poll values. Deferred for now — the alert_rules
|
||||
# table supports these metric types, but evaluation is skipped.
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
async def _increment_breach(
    r: aioredis.Redis, device_id: str, rule_id: str, required_polls: int
) -> bool:
    """Increment the breach counter in Redis.

    Args:
        r: Redis client.
        device_id: Device whose metric breached.
        rule_id: Rule that was breached.
        required_polls: Consecutive breaches needed before firing.

    Returns:
        True once the counter reaches required_polls.
    """
    key = f"alert:breach:{device_id}:{rule_id}"
    count = await r.incr(key)
    # Set TTL to (required_polls + 2) * 60 seconds so it expires if breaches stop
    # (assumes roughly one poll per minute — TODO confirm against poller config).
    await r.expire(key, (required_polls + 2) * 60)
    return count >= required_polls
|
||||
|
||||
|
||||
async def _reset_breach(r: aioredis.Redis, device_id: str, rule_id: str) -> None:
    """Clear the consecutive-breach counter once the metric returns to normal."""
    await r.delete(f"alert:breach:{device_id}:{rule_id}")
|
||||
|
||||
|
||||
async def _check_flapping(r: aioredis.Redis, device_id: str, rule_id: str) -> bool:
    """Check if alert is flapping (>= 5 state transitions in 10 minutes).

    Uses a Redis sorted set with timestamps as scores; each call records
    one transition, trims entries older than the window, and counts what
    remains.

    Returns:
        True when 5 or more transitions fall within the last 10 minutes
        (including the one recorded by this call).
    """
    key = f"alert:flap:{device_id}:{rule_id}"
    now = time.time()
    window_start = now - 600  # 10 minute window

    # Add this transition
    await r.zadd(key, {str(now): now})
    # Remove entries outside the window
    await r.zremrangebyscore(key, "-inf", window_start)
    # Set TTL on the key (2x the window) so idle keys clean themselves up
    await r.expire(key, 1200)
    # Count transitions in window
    count = await r.zcard(key)
    return count >= 5
|
||||
|
||||
|
||||
async def _get_device_groups(device_id: str) -> list[str]:
    """Return the IDs (as strings) of all groups this device belongs to."""
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("SELECT group_id FROM device_group_memberships WHERE device_id = CAST(:device_id AS uuid)"),
            {"device_id": device_id},
        )
        rows = result.fetchall()
    return [str(group_id) for (group_id,) in rows]
|
||||
|
||||
|
||||
async def _has_open_alert(device_id: str, rule_id: str | None, metric: str | None = None) -> bool:
    """Check if there's an open (firing, unresolved) alert for this device+rule.

    Args:
        device_id: Device to check.
        rule_id: Rule to match; when None, matches rule-less alerts
            (e.g. device-offline alerts) by metric instead.
        metric: Metric name used only when rule_id is None; defaults to
            "offline" when not supplied.

    Returns:
        True if a firing, unresolved alert_events row exists.
    """
    async with AdminAsyncSessionLocal() as session:
        if rule_id:
            result = await session.execute(
                text("""
                    SELECT 1 FROM alert_events
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id = CAST(:rule_id AS uuid)
                      AND status = 'firing' AND resolved_at IS NULL
                    LIMIT 1
                """),
                {"device_id": device_id, "rule_id": rule_id},
            )
        else:
            result = await session.execute(
                text("""
                    SELECT 1 FROM alert_events
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id IS NULL
                      AND metric = :metric AND status = 'firing' AND resolved_at IS NULL
                    LIMIT 1
                """),
                {"device_id": device_id, "metric": metric or "offline"},
            )
        return result.fetchone() is not None
|
||||
|
||||
|
||||
async def _create_alert_event(
    device_id: str,
    tenant_id: str,
    rule_id: str | None,
    status: str,
    severity: str,
    metric: str | None,
    value: float | None,
    threshold: float | None,
    message: str | None,
    is_flapping: bool = False,
) -> dict:
    """Create an alert event row and return its data.

    Inserts into alert_events (resolved_at is set immediately when status
    is 'resolved'), commits, then publishes a matching real-time event to
    NATS for the SSE pipeline.

    Args:
        device_id: Device the alert concerns.
        tenant_id: Owning tenant.
        rule_id: Triggering rule, or None for rule-less (e.g. offline) alerts.
        status: 'firing', 'flapping', or 'resolved'.
        severity: Alert severity label.
        metric: Metric name, if metric-driven.
        value: Observed metric value at fire time.
        threshold: Rule threshold that was breached.
        message: Human-readable alert message.
        is_flapping: Whether flap detection marked this alert.

    Returns:
        Dict mirroring the inserted row's fields (id as string).
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                INSERT INTO alert_events
                    (id, device_id, tenant_id, rule_id, status, severity, metric,
                     value, threshold, message, is_flapping, fired_at,
                     resolved_at)
                VALUES
                    (gen_random_uuid(), CAST(:device_id AS uuid), CAST(:tenant_id AS uuid),
                     :rule_id, :status, :severity, :metric,
                     :value, :threshold, :message, :is_flapping, NOW(),
                     CASE WHEN :status = 'resolved' THEN NOW() ELSE NULL END)
                RETURNING id, fired_at
            """),
            {
                "device_id": device_id,
                "tenant_id": tenant_id,
                "rule_id": rule_id,
                "status": status,
                "severity": severity,
                "metric": metric,
                "value": value,
                "threshold": threshold,
                "message": message,
                "is_flapping": is_flapping,
            },
        )
        row = result.fetchone()
        await session.commit()

    alert_data = {
        "id": str(row[0]) if row else None,
        "device_id": device_id,
        "tenant_id": tenant_id,
        "rule_id": rule_id,
        "status": status,
        "severity": severity,
        "metric": metric,
        "value": value,
        "threshold": threshold,
        "message": message,
        "is_flapping": is_flapping,
    }

    # Publish real-time event to NATS for SSE pipeline (fire-and-forget)
    if status in ("firing", "flapping"):
        await publish_event(f"alert.fired.{tenant_id}", {
            "event_type": "alert_fired",
            "tenant_id": tenant_id,
            "device_id": device_id,
            "alert_event_id": alert_data["id"],
            "severity": severity,
            "metric": metric,
            "current_value": value,
            "threshold": threshold,
            "message": message,
            "is_flapping": is_flapping,
            "fired_at": datetime.now(timezone.utc).isoformat(),
        })
    elif status == "resolved":
        await publish_event(f"alert.resolved.{tenant_id}", {
            "event_type": "alert_resolved",
            "tenant_id": tenant_id,
            "device_id": device_id,
            "alert_event_id": alert_data["id"],
            "severity": severity,
            "metric": metric,
            "message": message,
            "resolved_at": datetime.now(timezone.utc).isoformat(),
        })

    return alert_data
|
||||
|
||||
|
||||
async def _resolve_alert(device_id: str, rule_id: str | None, metric: str | None = None) -> None:
    """Resolve an open alert by setting resolved_at and status.

    Mirrors the matching logic of _has_open_alert: rule-driven alerts are
    matched by rule_id; rule-less alerts (rule_id None) are matched by
    metric, defaulting to "offline".
    """
    async with AdminAsyncSessionLocal() as session:
        if rule_id:
            await session.execute(
                text("""
                    UPDATE alert_events SET resolved_at = NOW(), status = 'resolved'
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id = CAST(:rule_id AS uuid)
                      AND status = 'firing' AND resolved_at IS NULL
                """),
                {"device_id": device_id, "rule_id": rule_id},
            )
        else:
            await session.execute(
                text("""
                    UPDATE alert_events SET resolved_at = NOW(), status = 'resolved'
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id IS NULL
                      AND metric = :metric AND status = 'firing' AND resolved_at IS NULL
                """),
                {"device_id": device_id, "metric": metric or "offline"},
            )
        await session.commit()
|
||||
|
||||
|
||||
async def _get_channels_for_tenant(tenant_id: str) -> list[dict]:
    """Get all notification channels for a tenant.

    Returns:
        List of channel dicts, one per notification_channels row.
        NOTE(review): the result includes smtp_password and
        smtp_password_transit — sensitive values; ensure callers never log
        or expose these dicts.
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, name, channel_type, smtp_host, smtp_port, smtp_user,
                       smtp_password, smtp_use_tls, from_address, to_address,
                       webhook_url, smtp_password_transit, slack_webhook_url, tenant_id
                FROM notification_channels
                WHERE tenant_id = CAST(:tenant_id AS uuid)
            """),
            {"tenant_id": tenant_id},
        )
        return [
            {
                "id": str(row[0]),
                "name": row[1],
                "channel_type": row[2],
                "smtp_host": row[3],
                "smtp_port": row[4],
                "smtp_user": row[5],
                "smtp_password": row[6],
                "smtp_use_tls": row[7],
                "from_address": row[8],
                "to_address": row[9],
                "webhook_url": row[10],
                "smtp_password_transit": row[11],
                "slack_webhook_url": row[12],
                "tenant_id": str(row[13]) if row[13] else None,
            }
            for row in result.fetchall()
        ]
|
||||
|
||||
|
||||
async def _get_channels_for_rule(rule_id: str) -> list[dict]:
    """Get notification channels linked to a specific alert rule.

    Joins through alert_rule_channels; the returned dict shape matches
    _get_channels_for_tenant so callers can use either interchangeably.

    Returns:
        List of channel dicts (includes sensitive SMTP credential fields —
        handle with care).
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT nc.id, nc.name, nc.channel_type, nc.smtp_host, nc.smtp_port,
                       nc.smtp_user, nc.smtp_password, nc.smtp_use_tls,
                       nc.from_address, nc.to_address, nc.webhook_url,
                       nc.smtp_password_transit, nc.slack_webhook_url, nc.tenant_id
                FROM notification_channels nc
                JOIN alert_rule_channels arc ON arc.channel_id = nc.id
                WHERE arc.rule_id = CAST(:rule_id AS uuid)
            """),
            {"rule_id": rule_id},
        )
        return [
            {
                "id": str(row[0]),
                "name": row[1],
                "channel_type": row[2],
                "smtp_host": row[3],
                "smtp_port": row[4],
                "smtp_user": row[5],
                "smtp_password": row[6],
                "smtp_use_tls": row[7],
                "from_address": row[8],
                "to_address": row[9],
                "webhook_url": row[10],
                "smtp_password_transit": row[11],
                "slack_webhook_url": row[12],
                "tenant_id": str(row[13]) if row[13] else None,
            }
            for row in result.fetchall()
        ]
|
||||
|
||||
|
||||
async def _dispatch_async(alert_event: dict, channels: list[dict], device_hostname: str) -> None:
    """Fire-and-forget notification dispatch.

    The import is deferred to call time — presumably to avoid a circular
    import at module load; confirm before hoisting it. Any dispatch failure
    is logged and swallowed so alert evaluation is never interrupted.
    """
    try:
        from app.services.notification_service import dispatch_notifications
        await dispatch_notifications(alert_event, channels, device_hostname)
    except Exception as e:
        logger.warning("Notification dispatch failed: %s", e)
|
||||
|
||||
|
||||
async def _get_device_hostname(device_id: str) -> str:
    """Look up a device's hostname; fall back to its ID when not found."""
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("SELECT hostname FROM devices WHERE id = CAST(:device_id AS uuid)"),
            {"device_id": device_id},
        )
        row = result.fetchone()
    if not row:
        return device_id
    return row[0]
|
||||
|
||||
|
||||
async def evaluate(
    device_id: str,
    tenant_id: str,
    metric_type: str,
    data: dict[str, Any],
) -> None:
    """Evaluate alert rules for incoming device metrics.

    Called from metrics_subscriber after metric DB write.

    Flow per matching rule:
      1. Maintenance-window suppression short-circuits everything.
      2. Rule applicability: device-specific rules win; a tenant-wide rule is
         skipped for a metric that also has a device-specific rule; group
         rules apply when the device belongs to the rule's group.
      3. Breaching values increment a Redis counter; only after
         ``duration_polls`` consecutive breaches does an alert fire.
      4. Non-breaching values reset the counter and resolve any open alert.
      5. Flapping alerts are recorded but their notifications are suppressed.
    """
    # Check maintenance window suppression before evaluating rules
    if await _is_device_in_maintenance(tenant_id, device_id):
        logger.debug(
            "Alert suppressed by maintenance window for device %s tenant %s",
            device_id, tenant_id,
        )
        return

    rules = await _get_rules_for_tenant(tenant_id)
    if not rules:
        return

    metrics = _extract_metrics(metric_type, data)
    if not metrics:
        return

    r = await _get_redis()
    device_groups = await _get_device_groups(device_id)

    # Build a set of metrics that have device-specific rules
    # (used below so tenant-wide rules don't double-fire for those metrics).
    device_specific_metrics: set[str] = set()
    for rule in rules:
        if rule["device_id"] == device_id:
            device_specific_metrics.add(rule["metric"])

    for rule in rules:
        rule_metric = rule["metric"]
        if rule_metric not in metrics:
            continue

        # Check if rule applies to this device
        applies = False
        if rule["device_id"] == device_id:
            applies = True
        elif rule["device_id"] is None and rule["group_id"] is None:
            # Tenant-wide rule — skip if device-specific rule exists for same metric
            if rule_metric in device_specific_metrics:
                continue
            applies = True
        elif rule["group_id"] and rule["group_id"] in device_groups:
            applies = True

        if not applies:
            continue

        value = metrics[rule_metric]
        breaching = _check_threshold(value, rule["operator"], rule["threshold"])

        if breaching:
            reached = await _increment_breach(r, device_id, rule["id"], rule["duration_polls"])
            if reached:
                # Check if already firing — never open a duplicate alert.
                if await _has_open_alert(device_id, rule["id"]):
                    continue

                # Check flapping
                is_flapping = await _check_flapping(r, device_id, rule["id"])

                hostname = await _get_device_hostname(device_id)
                message = f"{rule['name']}: {rule_metric} = {value} (threshold: {rule['operator']} {rule['threshold']})"

                alert_event = await _create_alert_event(
                    device_id=device_id,
                    tenant_id=tenant_id,
                    rule_id=rule["id"],
                    status="flapping" if is_flapping else "firing",
                    severity=rule["severity"],
                    metric=rule_metric,
                    value=value,
                    threshold=rule["threshold"],
                    message=message,
                    is_flapping=is_flapping,
                )

                if is_flapping:
                    logger.info(
                        "Alert %s for device %s is flapping — notifications suppressed",
                        rule["name"], device_id,
                    )
                else:
                    channels = await _get_channels_for_rule(rule["id"])
                    if channels:
                        # NOTE(review): fire-and-forget task with no retained
                        # reference — could be GC'd before completion; confirm.
                        asyncio.create_task(_dispatch_async(alert_event, channels, hostname))
        else:
            # Not breaching — reset counter and check for open alert to resolve
            await _reset_breach(r, device_id, rule["id"])

            if await _has_open_alert(device_id, rule["id"]):
                # Check flapping before resolving (resolve may clear state).
                is_flapping = await _check_flapping(r, device_id, rule["id"])

                await _resolve_alert(device_id, rule["id"])

                hostname = await _get_device_hostname(device_id)
                message = f"Resolved: {rule['name']}: {rule_metric} = {value}"

                resolved_event = await _create_alert_event(
                    device_id=device_id,
                    tenant_id=tenant_id,
                    rule_id=rule["id"],
                    status="resolved",
                    severity=rule["severity"],
                    metric=rule_metric,
                    value=value,
                    threshold=rule["threshold"],
                    message=message,
                    is_flapping=is_flapping,
                )

                if not is_flapping:
                    channels = await _get_channels_for_rule(rule["id"])
                    if channels:
                        asyncio.create_task(_dispatch_async(resolved_event, channels, hostname))
|
||||
|
||||
|
||||
async def _get_offline_rule(tenant_id: str) -> dict | None:
    """Look up the device_offline default rule for a tenant.

    Returns ``{"id": <str uuid>, "enabled": <bool>}`` or None when the tenant
    has no default offline rule (legacy tenants — see evaluate_offline).
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, enabled FROM alert_rules
                WHERE tenant_id = CAST(:tenant_id AS uuid)
                  AND metric = 'device_offline' AND is_default = TRUE
                LIMIT 1
            """),
            {"tenant_id": tenant_id},
        )
        row = result.fetchone()
        if row:
            return {"id": str(row[0]), "enabled": row[1]}
        return None
|
||||
|
||||
|
||||
async def evaluate_offline(device_id: str, tenant_id: str) -> None:
    """Create a critical alert when a device goes offline.

    Uses the tenant's device_offline default rule if it exists and is enabled.
    Falls back to system-level alert (rule_id=NULL) for backward compatibility.
    Maintenance windows suppress the alert entirely; a disabled default rule
    means the tenant explicitly opted out.
    """
    if await _is_device_in_maintenance(tenant_id, device_id):
        logger.debug(
            "Offline alert suppressed by maintenance window for device %s",
            device_id,
        )
        return

    rule = await _get_offline_rule(tenant_id)
    rule_id = rule["id"] if rule else None

    # If rule exists but is disabled, skip alert creation (user opted out)
    if rule and not rule["enabled"]:
        return

    # De-duplicate: never open a second alert while one is still firing.
    if rule_id:
        if await _has_open_alert(device_id, rule_id):
            return
    else:
        # Legacy path: system-level offline alerts are keyed by metric name.
        if await _has_open_alert(device_id, None, "offline"):
            return

    hostname = await _get_device_hostname(device_id)
    message = f"Device {hostname} is offline"

    alert_event = await _create_alert_event(
        device_id=device_id,
        tenant_id=tenant_id,
        rule_id=rule_id,
        status="firing",
        severity="critical",
        metric="offline",
        value=None,
        threshold=None,
        message=message,
    )

    # Use rule-linked channels if available, otherwise tenant-wide channels
    if rule_id:
        channels = await _get_channels_for_rule(rule_id)
        if not channels:
            channels = await _get_channels_for_tenant(tenant_id)
    else:
        channels = await _get_channels_for_tenant(tenant_id)

    if channels:
        # NOTE(review): fire-and-forget task with no retained reference —
        # could be GC'd before completion; confirm this is acceptable.
        asyncio.create_task(_dispatch_async(alert_event, channels, hostname))
|
||||
|
||||
|
||||
async def evaluate_online(device_id: str, tenant_id: str) -> None:
    """Resolve the offline alert when a device comes back online.

    Mirrors evaluate_offline's rule/legacy split: resolves against the
    tenant's default offline rule when one exists, otherwise the legacy
    system-level (rule_id=NULL, metric="offline") alert. Returns early when
    no open offline alert exists, so no spurious "resolved" events are sent.
    """
    rule = await _get_offline_rule(tenant_id)
    rule_id = rule["id"] if rule else None

    if rule_id:
        if not await _has_open_alert(device_id, rule_id):
            return
        await _resolve_alert(device_id, rule_id)
    else:
        if not await _has_open_alert(device_id, None, "offline"):
            return
        await _resolve_alert(device_id, None, "offline")

    hostname = await _get_device_hostname(device_id)
    message = f"Device {hostname} is back online"

    resolved_event = await _create_alert_event(
        device_id=device_id,
        tenant_id=tenant_id,
        rule_id=rule_id,
        # Severity matches the firing offline alert it resolves.
        status="resolved",
        severity="critical",
        metric="offline",
        value=None,
        threshold=None,
        message=message,
    )

    # Channel selection mirrors evaluate_offline: rule-linked first, then
    # tenant-wide fallback.
    if rule_id:
        channels = await _get_channels_for_rule(rule_id)
        if not channels:
            channels = await _get_channels_for_tenant(tenant_id)
    else:
        channels = await _get_channels_for_tenant(tenant_id)

    if channels:
        asyncio.create_task(_dispatch_async(resolved_event, channels, hostname))
|
||||
190
backend/app/services/api_key_service.py
Normal file
190
backend/app/services/api_key_service.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""API key generation, validation, and management service.
|
||||
|
||||
Keys use the mktp_ prefix for easy identification in logs.
|
||||
Storage uses SHA-256 hash -- the plaintext key is never persisted.
|
||||
Validation uses AdminAsyncSessionLocal since it runs before tenant context is set.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
|
||||
# Allowed scopes for API keys
# Scope strings follow "<resource>:<action>". NOTE(review): nothing in this
# module enforces membership in this set — presumably the API layer validates
# requested scopes against it before calling create_api_key; confirm.
ALLOWED_SCOPES: set[str] = {
    "devices:read",
    "devices:write",
    "config:read",
    "config:write",
    "alerts:read",
    "firmware:write",
}
|
||||
|
||||
|
||||
def generate_raw_key() -> str:
    """Return a fresh API key: the literal ``mktp_`` prefix followed by
    32 random bytes encoded as URL-safe base64 (43 characters)."""
    return "mktp_" + secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def hash_key(raw_key: str) -> str:
    """Return the hex-encoded SHA-256 digest of *raw_key*.

    This digest — never the plaintext key — is what gets persisted and used
    for lookups.
    """
    digest = hashlib.sha256(raw_key.encode())
    return digest.hexdigest()
|
||||
|
||||
|
||||
async def create_api_key(
    db,
    tenant_id: uuid.UUID,
    user_id: uuid.UUID,
    name: str,
    scopes: list[str],
    expires_at: Optional[datetime] = None,
) -> dict:
    """Create a new API key.

    Only the SHA-256 hash is persisted; the plaintext key appears once in
    the return value and is then unrecoverable.

    Args:
        db: Async SQLAlchemy session (caller's tenant-scoped session).
        tenant_id: Owning tenant.
        user_id: Creating user.
        name: Human-readable label.
        scopes: Scope strings (NOTE(review): not validated against
            ALLOWED_SCOPES here — confirm the API layer does this).
        expires_at: Optional expiry; None means the key never expires.

    Returns dict with:
        - key: the plaintext key (shown once, never again)
        - id: the key UUID
        - key_prefix: first 9 chars of the key (e.g. "mktp_abc1")
    """
    raw_key = generate_raw_key()
    key_hash_value = hash_key(raw_key)
    key_prefix = raw_key[:9]  # "mktp_" + first 4 random chars

    result = await db.execute(
        text("""
            INSERT INTO api_keys (tenant_id, user_id, name, key_prefix, key_hash, scopes, expires_at)
            VALUES (:tenant_id, :user_id, :name, :key_prefix, :key_hash, CAST(:scopes AS jsonb), :expires_at)
            RETURNING id, created_at
        """),
        {
            "tenant_id": str(tenant_id),
            "user_id": str(user_id),
            "name": name,
            "key_prefix": key_prefix,
            "key_hash": key_hash_value,
            # scopes stored as a jsonb array.
            "scopes": json.dumps(scopes),
            "expires_at": expires_at,
        },
    )
    row = result.fetchone()
    await db.commit()

    return {
        "key": raw_key,
        "id": row.id,
        "key_prefix": key_prefix,
        "name": name,
        "scopes": scopes,
        "expires_at": expires_at,
        "created_at": row.created_at,
    }
|
||||
|
||||
|
||||
async def validate_api_key(raw_key: str) -> Optional[dict]:
    """Validate an API key and return context if valid.

    Uses AdminAsyncSessionLocal since this runs before tenant context is set.

    Returns dict with tenant_id, user_id, scopes, key_id on success.
    Returns None for invalid, expired, or revoked keys.
    Updates last_used_at on successful validation.
    """
    # Lookup is by hash — the plaintext never touches the DB.
    key_hash_value = hash_key(raw_key)

    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, tenant_id, user_id, scopes, expires_at, revoked_at
                FROM api_keys
                WHERE key_hash = :key_hash
            """),
            {"key_hash": key_hash_value},
        )
        row = result.fetchone()

        if not row:
            return None

        # Check revoked
        if row.revoked_at is not None:
            return None

        # Check expired
        # NOTE(review): comparison assumes expires_at comes back tz-aware
        # (timestamptz column); a naive datetime here would raise TypeError
        # against the aware now() — confirm the column type.
        if row.expires_at is not None and row.expires_at <= datetime.now(timezone.utc):
            return None

        # Update last_used_at (best-effort usage tracking).
        await session.execute(
            text("""
                UPDATE api_keys SET last_used_at = now()
                WHERE id = :key_id
            """),
            {"key_id": str(row.id)},
        )
        await session.commit()

        return {
            "tenant_id": row.tenant_id,
            "user_id": row.user_id,
            "scopes": row.scopes if row.scopes else [],
            "key_id": row.id,
        }
|
||||
|
||||
|
||||
async def list_api_keys(db, tenant_id: uuid.UUID) -> list[dict]:
    """List all API keys for a tenant (active and revoked).

    Returns keys with masked display (key_prefix + "...").
    """
    def _iso(value):
        # Datetime columns are serialized to ISO-8601 strings; NULLs stay None.
        return value.isoformat() if value else None

    result = await db.execute(
        text("""
            SELECT id, name, key_prefix, scopes, expires_at, last_used_at,
                   created_at, revoked_at, user_id
            FROM api_keys
            WHERE tenant_id = :tenant_id
            ORDER BY created_at DESC
        """),
        {"tenant_id": str(tenant_id)},
    )

    keys: list[dict] = []
    for row in result.fetchall():
        keys.append({
            "id": row.id,
            "name": row.name,
            "key_prefix": row.key_prefix,
            "scopes": row.scopes if row.scopes else [],
            "expires_at": _iso(row.expires_at),
            "last_used_at": _iso(row.last_used_at),
            "created_at": _iso(row.created_at),
            "revoked_at": _iso(row.revoked_at),
            "user_id": str(row.user_id),
        })
    return keys
|
||||
|
||||
|
||||
async def revoke_api_key(db, tenant_id: uuid.UUID, key_id: uuid.UUID) -> bool:
    """Revoke an API key by stamping ``revoked_at = now()``.

    Returns True only when a live key matched; a missing or already-revoked
    key returns False (the ``revoked_at IS NULL`` guard makes this idempotent).
    """
    params = {"key_id": str(key_id), "tenant_id": str(tenant_id)}
    result = await db.execute(
        text("""
            UPDATE api_keys
            SET revoked_at = now()
            WHERE id = :key_id AND tenant_id = :tenant_id AND revoked_at IS NULL
            RETURNING id
        """),
        params,
    )
    revoked = result.fetchone() is not None
    await db.commit()
    return revoked
|
||||
92
backend/app/services/audit_service.py
Normal file
92
backend/app/services/audit_service.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Centralized audit logging service.
|
||||
|
||||
Provides a fire-and-forget ``log_action`` coroutine that inserts a row into
|
||||
the ``audit_logs`` table. Uses raw SQL INSERT (not ORM) for minimal overhead.
|
||||
|
||||
The function is wrapped in a try/except so that a logging failure **never**
|
||||
breaks the parent operation.
|
||||
|
||||
Phase 30: When details are non-empty, they are encrypted via OpenBao Transit
|
||||
(per-tenant data key) and stored in encrypted_details. The plaintext details
|
||||
column is set to '{}' for column compatibility. If Transit encryption fails
|
||||
(e.g., OpenBao unavailable), details are stored in plaintext as a fallback.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = structlog.get_logger("audit")
|
||||
|
||||
|
||||
async def log_action(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    user_id: uuid.UUID,
    action: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    device_id: Optional[uuid.UUID] = None,
    details: Optional[dict[str, Any]] = None,
    ip_address: Optional[str] = None,
) -> None:
    """Insert a row into audit_logs. Swallows all exceptions on failure.

    Non-empty ``details`` are Transit-encrypted per-tenant when possible
    (plaintext column then stores '{}'); if encryption fails, details are
    stored in plaintext as a fallback — see module docstring. The caller's
    session is used but NOT committed here; the caller owns the transaction.
    """
    try:
        import json as _json

        details_dict = details or {}
        details_json = _json.dumps(details_dict)
        encrypted_details: Optional[str] = None

        # Attempt Transit encryption for non-empty details
        if details_dict:
            try:
                from app.services.crypto import encrypt_data_transit

                encrypted_details = await encrypt_data_transit(
                    details_json, str(tenant_id)
                )
                # Encryption succeeded — clear plaintext details
                details_json = _json.dumps({})
            except Exception:
                # Transit unavailable — fall back to plaintext details
                logger.warning(
                    "audit_transit_encryption_failed",
                    action=action,
                    tenant_id=str(tenant_id),
                    exc_info=True,
                )
                # Keep details_json as-is (plaintext fallback)
                encrypted_details = None

        await db.execute(
            text(
                "INSERT INTO audit_logs "
                "(tenant_id, user_id, action, resource_type, resource_id, "
                "device_id, details, encrypted_details, ip_address) "
                "VALUES (:tenant_id, :user_id, :action, :resource_type, "
                ":resource_id, :device_id, CAST(:details AS jsonb), "
                ":encrypted_details, :ip_address)"
            ),
            {
                "tenant_id": str(tenant_id),
                "user_id": str(user_id),
                "action": action,
                "resource_type": resource_type,
                "resource_id": resource_id,
                "device_id": str(device_id) if device_id else None,
                "details": details_json,
                "encrypted_details": encrypted_details,
                "ip_address": ip_address,
            },
        )
    except Exception:
        # Logging must never break the parent operation — record and move on.
        logger.warning(
            "audit_log_insert_failed",
            action=action,
            tenant_id=str(tenant_id),
            exc_info=True,
        )
|
||||
154
backend/app/services/auth.py
Normal file
154
backend/app/services/auth.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
JWT authentication service.
|
||||
|
||||
Handles password hashing, JWT token creation, token verification,
|
||||
and token revocation via Redis.
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import bcrypt
|
||||
from fastapi import HTTPException, status
|
||||
from jose import JWTError, jwt
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.config import settings
|
||||
|
||||
# Redis key prefix for per-user revocation timestamps; keys are
# "token_revoked:<user_id>" (see revoke_user_tokens / is_token_revoked).
TOKEN_REVOCATION_PREFIX = "token_revoked:"
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
    """Hash a plaintext password with bcrypt (fresh random salt per call).

    DEPRECATED: Used only by password reset (temporary bcrypt hash for
    upgrade flow) and bootstrap_first_admin. Remove post-v6.0.
    """
    salt = bcrypt.gensalt()
    return bcrypt.hashpw(password.encode(), salt).decode()
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a stored bcrypt hash.

    DEPRECATED: Used only by the one-time SRP upgrade flow (login with
    must_upgrade_auth=True) and anti-enumeration dummy calls. Remove post-v6.0.
    """
    return bcrypt.checkpw(
        plain_password.encode(),
        hashed_password.encode(),
    )
|
||||
|
||||
|
||||
def create_access_token(
    user_id: uuid.UUID,
    tenant_id: Optional[uuid.UUID],
    role: str,
) -> str:
    """Create a short-lived JWT access token.

    Claims:
        sub: user UUID (subject)
        tenant_id: tenant UUID or None for super_admin
        role: user's role string
        type: "access" (to distinguish from refresh tokens)
        iat/exp: issue and expiry timestamps
    """
    issued = datetime.now(UTC)
    claims = {
        "sub": str(user_id),
        "tenant_id": str(tenant_id) if tenant_id else None,
        "role": role,
        "type": "access",
        "iat": issued,
        "exp": issued + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES),
    }
    return jwt.encode(claims, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(user_id: uuid.UUID) -> str:
    """Create a long-lived JWT refresh token.

    Claims:
        sub: user UUID (subject)
        type: "refresh" (to distinguish from access tokens)
        iat/exp: issue and expiry timestamps (expiry after
            JWT_REFRESH_TOKEN_EXPIRE_DAYS days)
    """
    issued = datetime.now(UTC)
    claims = {
        "sub": str(user_id),
        "type": "refresh",
        "iat": issued,
        "exp": issued + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS),
    }
    return jwt.encode(claims, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def verify_token(token: str, expected_type: str = "access") -> dict:
    """Decode and validate a JWT token.

    Args:
        token: JWT string to validate.
        expected_type: "access" or "refresh".

    Returns:
        dict: Decoded payload (sub, tenant_id, role, type, exp, iat).

    Raises:
        HTTPException 401: If token is invalid, expired, or wrong type.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = jwt.decode(
            token,
            settings.JWT_SECRET_KEY,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except JWTError:
        # Covers bad signature, malformed token, and expiry.
        raise unauthorized

    # Reject wrong token type (access vs refresh) and payloads missing a subject.
    if payload.get("type") != expected_type or not payload.get("sub"):
        raise unauthorized

    return payload
|
||||
|
||||
|
||||
async def revoke_user_tokens(redis: Redis, user_id: str) -> None:
    """Mark all tokens for a user as revoked by storing the current timestamp.

    Any refresh token issued before this timestamp will be rejected.
    TTL matches the maximum refresh token lifetime (7 days), after which the
    marker can expire because no older token can still be valid.
    """
    await redis.set(
        f"{TOKEN_REVOCATION_PREFIX}{user_id}",
        str(time.time()),
        ex=7 * 24 * 3600,
    )
|
||||
|
||||
|
||||
async def is_token_revoked(redis: Redis, user_id: str, issued_at: float) -> bool:
    """Return True when the token predates the user's revocation timestamp.

    No revocation marker in Redis means the token is accepted.
    """
    revoked_at = await redis.get(f"{TOKEN_REVOCATION_PREFIX}{user_id}")
    return revoked_at is not None and issued_at < float(revoked_at)
|
||||
197
backend/app/services/backup_scheduler.py
Normal file
197
backend/app/services/backup_scheduler.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Dynamic backup scheduler — reads cron schedules from DB, manages APScheduler jobs."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.models.config_backup import ConfigBackupSchedule
|
||||
from app.models.device import Device
|
||||
from app.services import backup_service
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level scheduler handle; created by start_backup_scheduler() and
# cleared by stop_backup_scheduler(). None while the scheduler is not running.
_scheduler: Optional[AsyncIOScheduler] = None

# System default: 2am UTC daily
# Applied when a device has neither a device-specific schedule nor a tenant
# default (see _load_effective_schedules).
DEFAULT_CRON = "0 2 * * *"
|
||||
|
||||
|
||||
def _cron_to_trigger(cron_expr: str) -> Optional[CronTrigger]:
    """Parse a 5-field cron expression into an APScheduler CronTrigger.

    Field order is standard crontab: minute hour day month day_of_week.
    All triggers are pinned to UTC.

    Returns None if the expression is invalid (wrong field count or values
    rejected by CronTrigger).
    """
    try:
        parts = cron_expr.strip().split()
        if len(parts) != 5:
            return None
        minute, hour, day, month, day_of_week = parts
        return CronTrigger(
            minute=minute, hour=hour, day=day, month=month,
            day_of_week=day_of_week, timezone="UTC",
        )
    except Exception as e:
        # Broad catch: a malformed DB row must not crash the schedule sync.
        logger.warning("Invalid cron expression '%s': %s", cron_expr, e)
        return None
|
||||
|
||||
|
||||
def build_schedule_map(schedules: list) -> dict[str, list[dict]]:
    """Group enabled device schedules by cron expression.

    Args:
        schedules: Objects exposing ``enabled``, ``cron_expression``,
            ``device_id`` and ``tenant_id`` (ORM rows or SimpleNamespace —
            see _load_effective_schedules).

    Returns:
        {cron_expression: [{"device_id": ..., "tenant_id": ...}, ...]}
        Disabled schedules are skipped; a falsy cron_expression falls back
        to DEFAULT_CRON. IDs are stringified for JSON/job-arg friendliness.
    """
    schedule_map: dict[str, list[dict]] = {}
    for s in schedules:
        if not s.enabled:
            continue
        cron = s.cron_expression or DEFAULT_CRON
        # setdefault replaces the manual "if cron not in map" membership check.
        schedule_map.setdefault(cron, []).append({
            "device_id": str(s.device_id),
            "tenant_id": str(s.tenant_id),
        })
    return schedule_map
|
||||
|
||||
|
||||
async def _run_scheduled_backups(devices: list[dict]) -> None:
    """Run backups for a list of devices. Each failure is isolated.

    Invoked by APScheduler (one job per distinct cron expression — see
    sync_schedules). Every device gets its own DB session and try/except so
    one failing device cannot abort the rest of the batch.
    """
    success_count = 0
    failure_count = 0

    for dev_info in devices:
        try:
            async with AdminAsyncSessionLocal() as session:
                await backup_service.run_backup(
                    device_id=dev_info["device_id"],
                    tenant_id=dev_info["tenant_id"],
                    trigger_type="scheduled",
                    db_session=session,
                )
                await session.commit()
                logger.info("Scheduled backup OK: device %s", dev_info["device_id"])
                success_count += 1
        except Exception as e:
            logger.error(
                "Scheduled backup FAILED: device %s: %s",
                dev_info["device_id"], e,
            )
            failure_count += 1

    logger.info(
        "Backup batch complete — %d succeeded, %d failed",
        success_count, failure_count,
    )
|
||||
|
||||
|
||||
async def _load_effective_schedules() -> list:
    """Load all effective schedules from DB.

    Resolution order per device: device-specific schedule > tenant default >
    system default (DEFAULT_CRON, enabled). Loads ALL devices and ALL
    schedules in two queries, then joins in memory.

    Returns flat list of SimpleNamespace(device_id, tenant_id,
    cron_expression, enabled) objects, one per device.
    """
    from types import SimpleNamespace

    async with AdminAsyncSessionLocal() as session:
        # Get all devices
        dev_result = await session.execute(select(Device))
        devices = dev_result.scalars().all()

        # Get all schedules
        sched_result = await session.execute(select(ConfigBackupSchedule))
        schedules = sched_result.scalars().all()

    # Index: device-specific and tenant defaults
    device_schedules = {}  # device_id -> schedule
    tenant_defaults = {}  # tenant_id -> schedule

    for s in schedules:
        # A NULL device_id marks a tenant-wide default schedule.
        if s.device_id:
            device_schedules[str(s.device_id)] = s
        else:
            tenant_defaults[str(s.tenant_id)] = s

    effective = []
    for dev in devices:
        dev_id = str(dev.id)
        tenant_id = str(dev.tenant_id)

        if dev_id in device_schedules:
            sched = device_schedules[dev_id]
        elif tenant_id in tenant_defaults:
            sched = tenant_defaults[tenant_id]
        else:
            # No schedule configured — use system default
            sched = None

        effective.append(SimpleNamespace(
            device_id=dev_id,
            tenant_id=tenant_id,
            cron_expression=sched.cron_expression if sched else DEFAULT_CRON,
            enabled=sched.enabled if sched else True,
        ))

    return effective
|
||||
|
||||
|
||||
async def sync_schedules() -> None:
    """Reload all schedules from DB and reconfigure APScheduler jobs.

    Idempotent full resync: removes every existing "backup_cron_*" job
    (other jobs, e.g. firmware checks, are untouched), then re-adds one job
    per distinct cron expression with the current device list as its args.
    No-op when the scheduler has not been started.
    """
    global _scheduler
    if not _scheduler:
        return

    # Remove all existing backup jobs (keep other jobs like firmware check)
    # NOTE(review): assumes get_jobs() returns a snapshot safe to mutate
    # against while iterating — confirm with the APScheduler version in use.
    for job in _scheduler.get_jobs():
        if job.id.startswith("backup_cron_"):
            job.remove()

    schedules = await _load_effective_schedules()
    schedule_map = build_schedule_map(schedules)

    for cron_expr, devices in schedule_map.items():
        trigger = _cron_to_trigger(cron_expr)
        if not trigger:
            # Invalid expressions degrade to the system default rather than
            # silently dropping the devices from the backup rotation.
            logger.warning("Skipping invalid cron '%s', using default", cron_expr)
            trigger = _cron_to_trigger(DEFAULT_CRON)

        job_id = f"backup_cron_{cron_expr.replace(' ', '_')}"
        _scheduler.add_job(
            _run_scheduled_backups,
            trigger=trigger,
            args=[devices],
            id=job_id,
            name=f"Backup: {cron_expr} ({len(devices)} devices)",
            # Prevent overlapping runs of the same batch.
            max_instances=1,
            replace_existing=True,
        )
        logger.info("Scheduled %d devices with cron '%s'", len(devices), cron_expr)
|
||||
|
||||
|
||||
async def on_schedule_change(tenant_id: str, device_id: str) -> None:
    """API hook: a schedule was created or updated — hot-reload all jobs.

    Performs a full resync rather than a targeted update; the tenant/device
    pair is logged for traceability only.
    """
    logger.info("Schedule changed for tenant=%s device=%s, resyncing", tenant_id, device_id)
    await sync_schedules()
|
||||
|
||||
|
||||
async def start_backup_scheduler() -> None:
    """Start the APScheduler and load initial schedules from DB.

    Assigns the module-level handle before starting so sync_schedules (which
    reads the global) sees a live scheduler.
    """
    global _scheduler
    _scheduler = AsyncIOScheduler(timezone="UTC")
    _scheduler.start()

    await sync_schedules()
    logger.info("Backup scheduler started with dynamic schedules")
|
||||
|
||||
|
||||
async def stop_backup_scheduler() -> None:
    """Gracefully shut down the scheduler; no-op when it never started."""
    global _scheduler
    if _scheduler is None:
        return
    # wait=False: don't block shutdown on in-flight backup jobs.
    _scheduler.shutdown(wait=False)
    _scheduler = None
    logger.info("Backup scheduler stopped")
|
||||
378
backend/app/services/backup_service.py
Normal file
378
backend/app/services/backup_service.py
Normal file
@@ -0,0 +1,378 @@
|
||||
"""SSH-based config capture service for RouterOS devices.
|
||||
|
||||
This service handles:
|
||||
1. capture_export() — SSH to device, run /export compact, return stdout text
|
||||
2. capture_binary_backup() — SSH to device, trigger /system backup save, SFTP-download result
|
||||
3. run_backup() — Orchestrate a full backup: capture + git commit + DB record
|
||||
|
||||
All functions are async (asyncssh is asyncio-native).
|
||||
|
||||
Security policy:
|
||||
known_hosts=None is intentional — RouterOS devices use self-signed SSH host keys
|
||||
that change on reset or key regeneration. This mirrors InsecureSkipVerify=true
|
||||
used in the poller's TLS connection. The threat model accepts device impersonation
|
||||
risk in exchange for operational simplicity (no pre-enrollment of host keys needed).
|
||||
See Pitfall 2 in 04-RESEARCH.md.
|
||||
|
||||
pygit2 calls are synchronous C bindings and MUST be wrapped in run_in_executor.
|
||||
See Pitfall 3 in 04-RESEARCH.md.
|
||||
|
||||
Phase 30: ALL backups (manual, scheduled, pre-restore) are encrypted via OpenBao
|
||||
Transit (Tier 2) before git commit. The server retains decrypt capability for
|
||||
on-demand viewing. Raw files in git are ciphertext; the API decrypts on GET.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncssh
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal, set_tenant_context
|
||||
from app.models.config_backup import ConfigBackupRun
|
||||
from app.models.device import Device
|
||||
from app.services import git_store
|
||||
from app.services.crypto import decrypt_credentials_hybrid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Fixed backup file name on device flash — overwrites on each run so files
# don't accumulate. See Pitfall 4 in 04-RESEARCH.md.
# The ".backup" extension is appended where the name is used.
_BACKUP_NAME = "portal-backup"
|
||||
|
||||
|
||||
async def capture_export(
    ip: str,
    port: int = 22,
    username: str = "",
    password: str = "",
) -> str:
    """SSH to a RouterOS device and capture /export compact output.

    Args:
        ip: Device IP address.
        port: SSH port (default 22; RouterOS default is 22).
        username: SSH login username.
        password: SSH login password.

    Returns:
        The raw RSC text from /export compact (may include RouterOS header line).

    Raises:
        asyncssh.Error: On SSH connection or command execution failure
            (check=True raises on non-zero exit status).
    """
    async with asyncssh.connect(
        ip,
        port=port,
        username=username,
        password=password,
        known_hosts=None,  # RouterOS self-signed host keys — see module docstring
        connect_timeout=30,
    ) as conn:
        result = await conn.run("/export compact", check=True)
        return result.stdout
|
||||
|
||||
|
||||
async def capture_binary_backup(
    ip: str,
    port: int = 22,
    username: str = "",
    password: str = "",
) -> bytes:
    """Create a binary backup on a RouterOS device, download it, then clean up.

    The on-device file name is fixed ({_BACKUP_NAME}.backup) so each run
    overwrites the previous file instead of accumulating on flash. Removal
    of the file afterwards is best-effort: a cleanup failure is logged but
    never masks the download result or the original error. See Pitfall 4
    in 04-RESEARCH.md.

    Args:
        ip: Device IP address.
        port: SSH port (default 22).
        username: SSH login username.
        password: SSH login password.

    Returns:
        Raw bytes of the binary backup file.

    Raises:
        asyncssh.Error: On SSH connection, command, or SFTP failure.
    """
    async with asyncssh.connect(
        ip,
        port=port,
        username=username,
        password=password,
        known_hosts=None,
        connect_timeout=30,
    ) as conn:
        # Step 1: create the backup file on device flash.
        await conn.run(
            f"/system backup save name={_BACKUP_NAME} dont-encrypt=yes",
            check=True,
        )

        payload = b""
        try:
            # Step 2: pull the file down over SFTP.
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(f"{_BACKUP_NAME}.backup", "rb") as fh:
                    payload = await fh.read()
        finally:
            # Step 3: best-effort removal from device flash.
            try:
                await conn.run(f"/file remove {_BACKUP_NAME}.backup", check=True)
            except Exception as cleanup_err:
                logger.warning(
                    "Failed to remove backup file from device %s: %s",
                    ip,
                    cleanup_err,
                )

        return payload
|
||||
|
||||
|
||||
async def run_backup(
    device_id: str,
    tenant_id: str,
    trigger_type: str,
    db_session: AsyncSession | None = None,
) -> dict:
    """Orchestrate a full config backup for a device.

    Steps:
        1. Load device from DB (ip_address, encrypted credentials).
        2. Decrypt credentials via decrypt_credentials_hybrid()
           (Transit preferred, legacy AES-GCM fallback).
        3. Capture /export compact and binary backup concurrently via asyncio.gather().
        4. Compute line delta vs the most recent export.rsc in git (first
           backup counts every line as added, none removed).
        5. Encrypt both artifacts via OpenBao Transit (Tier 2); fall back to
           plaintext storage if Transit is unavailable (non-fatal).
        6. Commit both files to the tenant's bare git repo (run_in_executor for pygit2).
        7. Insert ConfigBackupRun record with commit SHA, trigger type, line deltas.

    Args:
        device_id: Device UUID as string.
        tenant_id: Tenant UUID as string.
        trigger_type: 'scheduled' | 'manual' | 'pre-restore' | 'config-change'
            (the last is passed by the NATS config-change subscriber).
        db_session: Optional AsyncSession with RLS context already set.
            If None, uses AdminAsyncSessionLocal (for scheduler context).

    Returns:
        Dict: {"commit_sha": str, "trigger_type": str, "lines_added": int|None, "lines_removed": int|None}

    Raises:
        ValueError: If device not found or missing credentials.
        asyncssh.Error: On SSH/SFTP failure.
    """
    # get_running_loop() is the supported API inside a coroutine;
    # get_event_loop() is deprecated for this usage since Python 3.10.
    loop = asyncio.get_running_loop()
    ts = datetime.now(timezone.utc).isoformat()

    # -----------------------------------------------------------------------
    # Step 1: Load device from DB
    # -----------------------------------------------------------------------
    if db_session is not None:
        session = db_session
        should_close = False
    else:
        # Scheduler context: use admin session (cross-tenant; RLS bypassed)
        session = AdminAsyncSessionLocal()
        should_close = True

    try:
        from sqlalchemy import select

        if should_close:
            # Admin session doesn't have RLS context — filter by tenant explicitly.
            result = await session.execute(
                select(Device).where(
                    Device.id == device_id,  # type: ignore[arg-type]
                    Device.tenant_id == tenant_id,  # type: ignore[arg-type]
                )
            )
        else:
            # RLS-scoped session already constrains tenant.
            result = await session.execute(
                select(Device).where(Device.id == device_id)  # type: ignore[arg-type]
            )

        device = result.scalar_one_or_none()
        if device is None:
            raise ValueError(f"Device {device_id!r} not found for tenant {tenant_id!r}")

        if not device.encrypted_credentials_transit and not device.encrypted_credentials:
            raise ValueError(
                f"Device {device_id!r} has no stored credentials — cannot perform backup"
            )

        # -----------------------------------------------------------------------
        # Step 2: Decrypt credentials (dual-read: Transit preferred, legacy fallback)
        # -----------------------------------------------------------------------
        key = settings.get_encryption_key_bytes()
        creds_json = await decrypt_credentials_hybrid(
            device.encrypted_credentials_transit,
            device.encrypted_credentials,
            str(device.tenant_id),
            key,
        )
        creds = json.loads(creds_json)
        ssh_username = creds.get("username", "")
        ssh_password = creds.get("password", "")
        ip = device.ip_address

        hostname = device.hostname or ip

        # -----------------------------------------------------------------------
        # Step 3: Capture export and binary backup concurrently
        # -----------------------------------------------------------------------
        logger.info(
            "Starting %s backup for device %s (%s) tenant %s",
            trigger_type,
            hostname,
            ip,
            tenant_id,
        )

        export_text, binary_backup = await asyncio.gather(
            capture_export(ip, username=ssh_username, password=ssh_password),
            capture_binary_backup(ip, username=ssh_username, password=ssh_password),
        )

        # -----------------------------------------------------------------------
        # Step 4: Compute line delta vs prior version
        # -----------------------------------------------------------------------
        lines_added: int | None = None
        lines_removed: int | None = None

        prior_commits = await loop.run_in_executor(
            None, git_store.list_device_commits, tenant_id, device_id
        )

        if prior_commits:
            try:
                prior_export_bytes = await loop.run_in_executor(
                    None, git_store.read_file, tenant_id, prior_commits[0]["sha"], device_id, "export.rsc"
                )
                prior_text = prior_export_bytes.decode("utf-8", errors="replace")
                lines_added, lines_removed = await loop.run_in_executor(
                    None, git_store.compute_line_delta, prior_text, export_text
                )
            except Exception as delta_err:
                logger.warning(
                    "Failed to compute line delta for device %s: %s",
                    device_id,
                    delta_err,
                )
                # Keep lines_added/lines_removed as None on error — non-fatal
        else:
            # First backup: all lines are "added", none removed
            all_lines = len(export_text.splitlines())
            lines_added = all_lines
            lines_removed = 0

        # -----------------------------------------------------------------------
        # Step 5: Encrypt ALL backups via Transit (Tier 2: OpenBao Transit)
        # -----------------------------------------------------------------------
        encryption_tier: int | None = None
        git_export_content = export_text
        git_binary_content = binary_backup

        try:
            from app.services.crypto import encrypt_data_transit

            encrypted_export = await encrypt_data_transit(
                export_text, tenant_id
            )
            # Binary payload is base64-wrapped so Transit sees text.
            encrypted_binary = await encrypt_data_transit(
                base64.b64encode(binary_backup).decode(), tenant_id
            )
            # Transit ciphertext is text — store directly in git
            git_export_content = encrypted_export
            git_binary_content = encrypted_binary.encode("utf-8")
            encryption_tier = 2
            logger.info(
                "Tier 2 Transit encryption applied for %s backup of device %s",
                trigger_type,
                device_id,
            )
        except Exception as enc_err:
            # Transit unavailable — fall back to plaintext (non-fatal)
            logger.warning(
                "Transit encryption failed for %s backup of device %s, "
                "storing plaintext: %s",
                trigger_type,
                device_id,
                enc_err,
            )
            # Keep encryption_tier = None (plaintext fallback)

        # -----------------------------------------------------------------------
        # Step 6: Commit to git (wrapped in run_in_executor — pygit2 is sync C bindings)
        # -----------------------------------------------------------------------
        commit_message = (
            f"{trigger_type}: {hostname} ({ip}) at {ts}"
        )

        commit_sha = await loop.run_in_executor(
            None,
            git_store.commit_backup,
            tenant_id,
            device_id,
            git_export_content,
            git_binary_content,
            commit_message,
        )

        logger.info(
            "Committed backup for device %s to git SHA %s (tier=%s)",
            device_id,
            commit_sha[:8],
            encryption_tier,
        )

        # -----------------------------------------------------------------------
        # Step 7: Insert ConfigBackupRun record
        # -----------------------------------------------------------------------
        if not should_close:
            # RLS-scoped session from API context — record directly
            backup_run = ConfigBackupRun(
                device_id=device.id,
                tenant_id=device.tenant_id,
                commit_sha=commit_sha,
                trigger_type=trigger_type,
                lines_added=lines_added,
                lines_removed=lines_removed,
                encryption_tier=encryption_tier,
            )
            session.add(backup_run)
            await session.flush()
        else:
            # Admin session — open a fresh session and set tenant context
            # before insert so the RLS policy on the table is satisfied.
            async with AdminAsyncSessionLocal() as admin_session:
                await set_tenant_context(admin_session, str(device.tenant_id))
                backup_run = ConfigBackupRun(
                    device_id=device.id,
                    tenant_id=device.tenant_id,
                    commit_sha=commit_sha,
                    trigger_type=trigger_type,
                    lines_added=lines_added,
                    lines_removed=lines_removed,
                    encryption_tier=encryption_tier,
                )
                admin_session.add(backup_run)
                await admin_session.commit()

        return {
            "commit_sha": commit_sha,
            "trigger_type": trigger_type,
            "lines_added": lines_added,
            "lines_removed": lines_removed,
        }

    finally:
        if should_close:
            await session.close()
|
||||
462
backend/app/services/ca_service.py
Normal file
462
backend/app/services/ca_service.py
Normal file
@@ -0,0 +1,462 @@
|
||||
"""Certificate Authority service — CA generation, device cert signing, lifecycle.
|
||||
|
||||
This module provides the core PKI functionality for the Internal Certificate
|
||||
Authority feature. All functions receive an ``AsyncSession`` and an
|
||||
``encryption_key`` as parameters (no direct Settings access) for testability.
|
||||
|
||||
Security notes:
|
||||
- CA private keys are encrypted with AES-256-GCM before database storage.
|
||||
- PEM key material is NEVER logged.
|
||||
- Device keys are decrypted only when needed for NATS transmission.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import ipaddress
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.certificate import CertificateAuthority, DeviceCertificate
|
||||
from app.services.crypto import (
|
||||
decrypt_credentials_hybrid,
|
||||
encrypt_credentials_transit,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Valid status transitions for the device certificate lifecycle.
# Keys are current states; values are the set of states reachable from them.
# Terminal states ('revoked', 'superseded') map to the empty set, so no
# further transition is ever accepted by update_cert_status().
_VALID_TRANSITIONS: dict[str, set[str]] = {
    "issued": {"deploying"},
    "deploying": {"deployed", "issued"},  # issued = rollback on deploy failure
    "deployed": {"expiring", "revoked", "superseded"},
    "expiring": {"expired", "revoked", "superseded"},
    "expired": {"superseded"},
    "revoked": set(),
    "superseded": set(),
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CA Generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def generate_ca(
    db: AsyncSession,
    tenant_id: UUID,
    common_name: str,
    validity_years: int,
    encryption_key: bytes,
) -> CertificateAuthority:
    """Generate a self-signed root CA for a tenant.

    Creates an RSA-2048 key pair, builds a self-signed X.509 root with CA
    basic constraints, Transit-encrypts the private key PEM, and persists a
    CertificateAuthority row (flushed, not committed — caller owns commit).

    Args:
        db: Async database session.
        tenant_id: Tenant UUID — only one CA per tenant.
        common_name: CN for the CA certificate (e.g., "Portal Root CA").
        validity_years: How many years the CA cert is valid
            (computed as 365 days/year; leap days are not accounted for).
        encryption_key: 32-byte AES-256-GCM key.
            NOTE(review): unused in this function — the private key is
            encrypted via OpenBao Transit instead; parameter appears kept
            for signature symmetry with legacy call sites — confirm.

    Returns:
        The newly created ``CertificateAuthority`` model instance.

    Raises:
        ValueError: If the tenant already has a CA.
    """
    # Ensure one CA per tenant
    existing = await get_ca_for_tenant(db, tenant_id)
    if existing is not None:
        raise ValueError(
            f"Tenant {tenant_id} already has a CA (id={existing.id}). "
            "Delete the existing CA before creating a new one."
        )

    # Generate RSA 2048 key pair
    ca_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    now = datetime.datetime.now(datetime.timezone.utc)
    expiry = now + datetime.timedelta(days=365 * validity_years)

    # Self-signed: subject and issuer are the same name.
    subject = issuer = x509.Name([
        x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
        x509.NameAttribute(NameOID.COMMON_NAME, common_name),
    ])

    ca_cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(ca_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(expiry)
        # path_length=0: this CA may sign leaf certificates only — it cannot
        # delegate to intermediate CAs.
        .add_extension(
            x509.BasicConstraints(ca=True, path_length=0), critical=True
        )
        # Key usage restricted to what a signing CA needs: cert + CRL signing.
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        # SKI lets leaf certs reference this CA via AuthorityKeyIdentifier.
        .add_extension(
            x509.SubjectKeyIdentifier.from_public_key(ca_key.public_key()),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )

    # Serialize public cert to PEM
    cert_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")

    # Serialize private key to PEM (unencrypted PKCS8), then encrypt with
    # OpenBao Transit — the plaintext PEM only ever lives in memory here.
    key_pem = ca_key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.PKCS8,
        serialization.NoEncryption(),
    ).decode("utf-8")
    encrypted_key_transit = await encrypt_credentials_transit(key_pem, str(tenant_id))

    # Compute SHA-256 fingerprint (colon-separated hex)
    fingerprint_bytes = ca_cert.fingerprint(hashes.SHA256())
    fingerprint = ":".join(f"{b:02X}" for b in fingerprint_bytes)

    # Serial number as hex string
    serial_hex = format(ca_cert.serial_number, "X")

    model = CertificateAuthority(
        tenant_id=tenant_id,
        common_name=common_name,
        cert_pem=cert_pem,
        encrypted_private_key=b"",  # Legacy column kept for schema compat
        encrypted_private_key_transit=encrypted_key_transit,
        serial_number=serial_hex,
        fingerprint_sha256=fingerprint,
        not_valid_before=now,
        not_valid_after=expiry,
    )
    db.add(model)
    await db.flush()

    # Key material is never logged — only public identifiers.
    logger.info(
        "Generated CA for tenant %s: cn=%s fingerprint=%s",
        tenant_id,
        common_name,
        fingerprint,
    )
    return model
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Device Certificate Signing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def sign_device_cert(
    db: AsyncSession,
    ca: CertificateAuthority,
    device_id: UUID,
    hostname: str,
    ip_address: str,
    validity_days: int,
    encryption_key: bytes,
) -> DeviceCertificate:
    """Sign a per-device TLS server certificate using the tenant's CA.

    Generates a fresh RSA-2048 key for the device, builds a leaf cert with
    serverAuth EKU and SANs for both the IP and hostname, Transit-encrypts
    the device key PEM, and persists a DeviceCertificate row (flushed, not
    committed — caller owns commit).

    Args:
        db: Async database session.
        ca: The tenant's CertificateAuthority model instance.
        device_id: UUID of the device receiving the cert.
        hostname: Device hostname — used as CN and SAN DNSName.
        ip_address: Device IP — used as SAN IPAddress; must parse via
            ``ipaddress.ip_address`` (raises ValueError otherwise).
        validity_days: Certificate validity in days.
        encryption_key: 32-byte AES-256-GCM key used only for the legacy
            fallback path when decrypting the CA private key.

    Returns:
        The newly created ``DeviceCertificate`` model instance (status='issued').
    """
    # Decrypt CA private key (dual-read: Transit preferred, legacy fallback)
    ca_key_pem = await decrypt_credentials_hybrid(
        ca.encrypted_private_key_transit,
        ca.encrypted_private_key,
        str(ca.tenant_id),
        encryption_key,
    )
    ca_key = serialization.load_pem_private_key(
        ca_key_pem.encode("utf-8"), password=None
    )

    # Load CA certificate for issuer info and AuthorityKeyIdentifier
    ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8"))

    # Generate device RSA 2048 key
    device_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    now = datetime.datetime.now(datetime.timezone.utc)
    expiry = now + datetime.timedelta(days=validity_days)

    device_cert = (
        x509.CertificateBuilder()
        .subject_name(
            x509.Name([
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
                x509.NameAttribute(NameOID.COMMON_NAME, hostname),
            ])
        )
        # Issuer is the CA's subject — chain validation depends on this match.
        .issuer_name(ca_cert.subject)
        .public_key(device_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(expiry)
        # Leaf certificate: explicitly not a CA.
        .add_extension(
            x509.BasicConstraints(ca=False, path_length=None), critical=True
        )
        # TLS-server key usage: signature + key encipherment (RSA key exchange).
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=True,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=False,
                crl_sign=False,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        .add_extension(
            x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]),
            critical=False,
        )
        # SANs cover both access paths (by IP and by hostname); modern TLS
        # clients validate against SANs, not the CN.
        .add_extension(
            x509.SubjectAlternativeName([
                x509.IPAddress(ipaddress.ip_address(ip_address)),
                x509.DNSName(hostname),
            ]),
            critical=False,
        )
        # AKI derived from the CA's SubjectKeyIdentifier extension — assumes
        # the CA cert carries one (generate_ca always adds it).
        .add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                ca_cert.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier
                ).value
            ),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )

    # Serialize device cert and key to PEM
    cert_pem = device_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
    key_pem = device_key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.PKCS8,
        serialization.NoEncryption(),
    ).decode("utf-8")

    # Encrypt device private key via OpenBao Transit
    encrypted_key_transit = await encrypt_credentials_transit(key_pem, str(ca.tenant_id))

    # Compute fingerprint
    fingerprint_bytes = device_cert.fingerprint(hashes.SHA256())
    fingerprint = ":".join(f"{b:02X}" for b in fingerprint_bytes)

    serial_hex = format(device_cert.serial_number, "X")

    model = DeviceCertificate(
        tenant_id=ca.tenant_id,
        device_id=device_id,
        ca_id=ca.id,
        common_name=hostname,
        serial_number=serial_hex,
        fingerprint_sha256=fingerprint,
        cert_pem=cert_pem,
        encrypted_private_key=b"",  # Legacy column kept for schema compat
        encrypted_private_key_transit=encrypted_key_transit,
        not_valid_before=now,
        not_valid_after=expiry,
        status="issued",
    )
    db.add(model)
    await db.flush()

    # Key material is never logged — only public identifiers.
    logger.info(
        "Signed device cert for device %s: cn=%s fingerprint=%s",
        device_id,
        hostname,
        fingerprint,
    )
    return model
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def get_ca_for_tenant(
    db: AsyncSession,
    tenant_id: UUID,
) -> CertificateAuthority | None:
    """Fetch the single CA belonging to *tenant_id*, or None when absent."""
    stmt = select(CertificateAuthority).where(
        CertificateAuthority.tenant_id == tenant_id
    )
    row = await db.execute(stmt)
    return row.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_device_certs(
    db: AsyncSession,
    tenant_id: UUID,
    device_id: UUID | None = None,
) -> list[DeviceCertificate]:
    """List a tenant's device certificates, newest first.

    Superseded certificates are always excluded. When *device_id* is given,
    results are narrowed to that device only.

    Args:
        db: Async database session.
        tenant_id: Tenant UUID.
        device_id: Optional device filter.

    Returns:
        List of DeviceCertificate models ordered by created_at descending.
    """
    query = select(DeviceCertificate).where(
        DeviceCertificate.tenant_id == tenant_id,
        DeviceCertificate.status != "superseded",
    )
    if device_id is not None:
        query = query.where(DeviceCertificate.device_id == device_id)
    query = query.order_by(DeviceCertificate.created_at.desc())
    rows = await db.execute(query)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Status Management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def update_cert_status(
    db: AsyncSession,
    cert_id: UUID,
    status: str,
    deployed_at: datetime.datetime | None = None,
) -> DeviceCertificate:
    """Move a device certificate to a new lifecycle status.

    The transition must be permitted by ``_VALID_TRANSITIONS``:
        issued -> deploying -> deployed -> expiring -> expired
                              \\-> revoked
                              \\-> superseded

    Args:
        db: Async database session.
        cert_id: Certificate UUID.
        status: New status value.
        deployed_at: Timestamp to set when transitioning to 'deployed'
            (defaults to the update time if omitted).

    Returns:
        The updated DeviceCertificate model.

    Raises:
        ValueError: If the certificate is not found or the transition is invalid.
    """
    row = await db.execute(
        select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
    )
    cert = row.scalar_one_or_none()
    if cert is None:
        raise ValueError(f"Device certificate {cert_id} not found")

    permitted = _VALID_TRANSITIONS.get(cert.status, set())
    if status not in permitted:
        raise ValueError(
            f"Invalid status transition: {cert.status} -> {status}. "
            f"Allowed transitions from '{cert.status}': {permitted or 'none'}"
        )

    cert.status = status
    cert.updated_at = datetime.datetime.now(datetime.timezone.utc)

    # Record deployment time: explicit timestamp wins, otherwise "now".
    if status == "deployed":
        cert.deployed_at = deployed_at if deployed_at is not None else cert.updated_at

    await db.flush()

    logger.info(
        "Updated cert %s status to %s",
        cert_id,
        status,
    )
    return cert
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cert Data for Deployment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def get_cert_for_deploy(
    db: AsyncSession,
    cert_id: UUID,
    encryption_key: bytes,
) -> tuple[str, str, str]:
    """Gather everything needed to push a certificate to a device.

    Loads the device certificate and its issuing CA, decrypts the device
    private key (Transit preferred, legacy AES-GCM fallback), and returns
    the three PEM strings the Go poller needs for deployment over NATS.

    Args:
        db: Async database session.
        cert_id: Device certificate UUID.
        encryption_key: 32-byte AES-256-GCM key for the legacy decrypt path.

    Returns:
        Tuple of (cert_pem, key_pem_decrypted, ca_cert_pem).

    Raises:
        ValueError: If the certificate or its CA is not found.
    """
    cert_row = await db.execute(
        select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
    )
    cert = cert_row.scalar_one_or_none()
    if cert is None:
        raise ValueError(f"Device certificate {cert_id} not found")

    # The CA cert is shipped alongside so the device can present a full chain.
    ca_row = await db.execute(
        select(CertificateAuthority).where(
            CertificateAuthority.id == cert.ca_id
        )
    )
    ca = ca_row.scalar_one_or_none()
    if ca is None:
        raise ValueError(f"CA {cert.ca_id} not found for certificate {cert_id}")

    # Decrypt device private key (dual-read: Transit preferred, legacy fallback)
    key_pem = await decrypt_credentials_hybrid(
        cert.encrypted_private_key_transit,
        cert.encrypted_private_key,
        str(cert.tenant_id),
        encryption_key,
    )

    return cert.cert_pem, key_pem, ca.cert_pem
|
||||
118
backend/app/services/config_change_subscriber.py
Normal file
118
backend/app/services/config_change_subscriber.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""NATS subscriber for config change events from the Go poller.
|
||||
|
||||
Triggers automatic backups when out-of-band config changes are detected,
|
||||
with 5-minute deduplication to prevent rapid-fire backups.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.models.config_backup import ConfigBackupRun
|
||||
from app.services import backup_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEDUP_WINDOW_MINUTES = 5
|
||||
|
||||
_nc: Optional[Any] = None
|
||||
|
||||
|
||||
async def _last_backup_within_dedup_window(device_id: str) -> bool:
    """Return True when this device already got a backup inside the dedup window."""
    window_start = datetime.now(timezone.utc) - timedelta(minutes=DEDUP_WINDOW_MINUTES)
    stmt = (
        select(ConfigBackupRun)
        .where(
            ConfigBackupRun.device_id == device_id,
            ConfigBackupRun.created_at > window_start,
        )
        .limit(1)
    )
    async with AdminAsyncSessionLocal() as session:
        found = await session.execute(stmt)
        return found.scalar_one_or_none() is not None
|
||||
|
||||
|
||||
async def handle_config_changed(event: dict) -> None:
    """React to a config change event by triggering a deduplicated backup.

    Events missing identifiers are dropped with a warning; devices backed up
    within the last DEDUP_WINDOW_MINUTES are skipped. Backup failures are
    logged, never raised — this runs inside a NATS callback.
    """
    device_id = event.get("device_id")
    tenant_id = event.get("tenant_id")

    if not device_id or not tenant_id:
        logger.warning("Config change event missing device_id or tenant_id: %s", event)
        return

    # Dedup: avoid rapid-fire backups from bursts of change events.
    if await _last_backup_within_dedup_window(device_id):
        logger.info(
            "Config change on device %s — skipping backup (within %dm dedup window)",
            device_id, DEDUP_WINDOW_MINUTES,
        )
        return

    logger.info(
        "Config change detected on device %s (tenant %s): %s -> %s",
        device_id, tenant_id,
        event.get("old_timestamp", "?"),
        event.get("new_timestamp", "?"),
    )

    try:
        async with AdminAsyncSessionLocal() as session:
            await backup_service.run_backup(
                device_id=device_id,
                tenant_id=tenant_id,
                trigger_type="config-change",
                db_session=session,
            )
            await session.commit()
            logger.info("Config-change backup completed for device %s", device_id)
    except Exception as e:
        logger.error("Config-change backup failed for device %s: %s", device_id, e)
|
||||
|
||||
|
||||
async def _on_message(msg) -> None:
    """NATS message handler for config.changed.> subjects.

    Malformed (non-JSON / non-UTF-8) payloads are acknowledged and dropped:
    NAK'ing them would make JetStream redeliver the same poison message
    forever, since it can never parse on retry.
    """
    try:
        event = json.loads(msg.data.decode())
    except (ValueError, UnicodeDecodeError) as parse_err:
        # json.JSONDecodeError is a ValueError subclass.
        logger.error("Dropping malformed config change message: %s", parse_err)
        await msg.ack()
        return

    try:
        await handle_config_changed(event)
        await msg.ack()
    except Exception as e:
        # Transient failure — NAK so JetStream redelivers.
        logger.error("Error handling config change message: %s", e)
        await msg.nak()
|
||||
|
||||
|
||||
async def start_config_change_subscriber() -> Optional[Any]:
    """Connect to NATS and subscribe to config.changed.> events.

    Returns the live connection on success, or None when the connection or
    JetStream subscription fails (startup continues without config-change
    backups in that case).
    """
    import nats

    global _nc
    try:
        logger.info("NATS config-change: connecting to %s", settings.NATS_URL)
        _nc = await nats.connect(settings.NATS_URL)
        jetstream = _nc.jetstream()
        # Durable consumer with manual ack so failed handling is redelivered.
        await jetstream.subscribe(
            "config.changed.>",
            cb=_on_message,
            durable="api-config-change-consumer",
            stream="DEVICE_EVENTS",
            manual_ack=True,
        )
        logger.info("Config change subscriber started")
        return _nc
    except Exception as e:
        logger.error("Failed to start config change subscriber: %s", e)
        return None
|
||||
|
||||
|
||||
async def stop_config_change_subscriber() -> None:
    """Drain and release the NATS connection, if one is active."""
    global _nc
    if _nc is None:
        return
    await _nc.drain()
    _nc = None
|
||||
183
backend/app/services/crypto.py
Normal file
183
backend/app/services/crypto.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Credential encryption/decryption with dual-read (OpenBao Transit + legacy AES-256-GCM).
|
||||
|
||||
This module provides two encryption paths:
|
||||
1. Legacy (sync): AES-256-GCM with static CREDENTIAL_ENCRYPTION_KEY — used for fallback reads.
|
||||
2. Transit (async): OpenBao Transit per-tenant keys — used for all new writes.
|
||||
|
||||
The dual-read pattern:
|
||||
- New writes always use OpenBao Transit (encrypt_credentials_transit).
|
||||
- Reads prefer Transit ciphertext, falling back to legacy (decrypt_credentials_hybrid).
|
||||
- Legacy functions are preserved for backward compatibility during migration.
|
||||
|
||||
Security properties:
|
||||
- AES-256-GCM provides authenticated encryption (confidentiality + integrity)
|
||||
- A unique 12-byte random nonce is generated per legacy encryption operation
|
||||
- OpenBao Transit keys are AES-256-GCM96, managed entirely by OpenBao
|
||||
- Ciphertext format: "vault:v1:..." for Transit, raw bytes for legacy
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def encrypt_credentials(plaintext: str, key: bytes) -> bytes:
    """
    Encrypt a plaintext string using AES-256-GCM.

    Args:
        plaintext: The credential string to encrypt (e.g., JSON with username/password)
        key: 32-byte encryption key

    Returns:
        bytes: nonce (12 bytes) + ciphertext + GCM tag (16 bytes)

    Raises:
        ValueError: If key is not exactly 32 bytes
    """
    if len(key) != 32:
        raise ValueError(f"Key must be exactly 32 bytes, got {len(key)}")

    # Imported lazily so the module can load without `cryptography` installed.
    from cryptography.hazmat.primitives.ciphers.aead import AESGCM

    # 96-bit random nonce — must be unique per encryption under the same key.
    iv = os.urandom(12)
    sealed = AESGCM(key).encrypt(iv, plaintext.encode("utf-8"), None)

    # Stored layout: nonce || ciphertext+tag (the library appends the GCM tag).
    return iv + sealed
|
||||
|
||||
|
||||
def decrypt_credentials(ciphertext: bytes, key: bytes) -> str:
    """
    Decrypt AES-256-GCM encrypted credentials.

    Args:
        ciphertext: bytes from encrypt_credentials (nonce + encrypted data + GCM tag)
        key: 32-byte encryption key (must match the key used for encryption)

    Returns:
        str: The original plaintext string

    Raises:
        ValueError: If key is not exactly 32 bytes
        cryptography.exceptions.InvalidTag: If authentication fails (tampered data or wrong key)
    """
    if len(key) != 32:
        raise ValueError(f"Key must be exactly 32 bytes, got {len(key)}")

    # Imported lazily so the module can load without `cryptography` installed.
    from cryptography.hazmat.primitives.ciphers.aead import AESGCM

    # Split the stored blob back into its parts: 12-byte nonce, then payload+tag.
    iv, sealed = ciphertext[:12], ciphertext[12:]
    return AESGCM(key).decrypt(iv, sealed, None).decode("utf-8")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenBao Transit functions (async, per-tenant keys)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def encrypt_credentials_transit(plaintext: str, tenant_id: str) -> str:
    """Encrypt via OpenBao Transit. Returns ciphertext string (vault:v1:...).

    Args:
        plaintext: The credential string to encrypt.
        tenant_id: Tenant UUID string for key lookup.

    Returns:
        Transit ciphertext string (vault:v1:base64...).
    """
    # Local import avoids a circular dependency at module load time.
    from app.services.openbao_service import get_openbao_service

    return await get_openbao_service().encrypt(tenant_id, plaintext.encode("utf-8"))
|
||||
|
||||
|
||||
async def decrypt_credentials_transit(ciphertext: str, tenant_id: str) -> str:
    """Decrypt OpenBao Transit ciphertext. Returns plaintext string.

    Args:
        ciphertext: Transit ciphertext (vault:v1:...).
        tenant_id: Tenant UUID string for key lookup.

    Returns:
        Decrypted plaintext string.
    """
    # Local import avoids a circular dependency at module load time.
    from app.services.openbao_service import get_openbao_service

    raw = await get_openbao_service().decrypt(tenant_id, ciphertext)
    return raw.decode("utf-8")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenBao Transit data encryption (async, per-tenant _data keys — Phase 30)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def encrypt_data_transit(plaintext: str, tenant_id: str) -> str:
    """Encrypt non-credential data via OpenBao Transit using per-tenant data key.

    Used for audit log details, config backups, and reports. Data keys are
    separate from credential keys (tenant_{uuid}_data vs tenant_{uuid}).

    Args:
        plaintext: The data string to encrypt.
        tenant_id: Tenant UUID string for data key lookup.

    Returns:
        Transit ciphertext string (vault:v1:base64...).
    """
    # Local import avoids a circular dependency at module load time.
    from app.services.openbao_service import get_openbao_service

    return await get_openbao_service().encrypt_data(tenant_id, plaintext.encode("utf-8"))
|
||||
|
||||
|
||||
async def decrypt_data_transit(ciphertext: str, tenant_id: str) -> str:
    """Decrypt OpenBao Transit data ciphertext. Returns plaintext string.

    Args:
        ciphertext: Transit ciphertext (vault:v1:...).
        tenant_id: Tenant UUID string for data key lookup.

    Returns:
        Decrypted plaintext string.
    """
    # Local import avoids a circular dependency at module load time.
    from app.services.openbao_service import get_openbao_service

    raw = await get_openbao_service().decrypt_data(tenant_id, ciphertext)
    return raw.decode("utf-8")
|
||||
|
||||
|
||||
async def decrypt_credentials_hybrid(
|
||||
transit_ciphertext: str | None,
|
||||
legacy_ciphertext: bytes | None,
|
||||
tenant_id: str,
|
||||
legacy_key: bytes,
|
||||
) -> str:
|
||||
"""Dual-read: prefer Transit ciphertext, fall back to legacy.
|
||||
|
||||
Args:
|
||||
transit_ciphertext: OpenBao Transit ciphertext (vault:v1:...) or None.
|
||||
legacy_ciphertext: Legacy AES-256-GCM bytes (nonce+ciphertext+tag) or None.
|
||||
tenant_id: Tenant UUID string for Transit key lookup.
|
||||
legacy_key: 32-byte legacy encryption key for fallback.
|
||||
|
||||
Returns:
|
||||
Decrypted plaintext string.
|
||||
|
||||
Raises:
|
||||
ValueError: If neither ciphertext is available.
|
||||
"""
|
||||
if transit_ciphertext and transit_ciphertext.startswith("vault:v"):
|
||||
return await decrypt_credentials_transit(transit_ciphertext, tenant_id)
|
||||
elif legacy_ciphertext:
|
||||
return decrypt_credentials(legacy_ciphertext, legacy_key)
|
||||
else:
|
||||
raise ValueError("No credentials available (both transit and legacy are empty)")
|
||||
670
backend/app/services/device.py
Normal file
670
backend/app/services/device.py
Normal file
@@ -0,0 +1,670 @@
|
||||
"""
|
||||
Device service — business logic for device CRUD, credential encryption, groups, and tags.
|
||||
|
||||
All functions operate via the app_user engine (RLS enforced).
|
||||
Tenant isolation is handled automatically by PostgreSQL RLS policies
|
||||
(SET LOCAL app.current_tenant is set by the get_current_user dependency before
|
||||
this layer is called).
|
||||
|
||||
Credential policy:
|
||||
- Credentials are always stored as AES-256-GCM encrypted JSON blobs.
|
||||
- Credentials are NEVER returned in any public-facing response.
|
||||
- Re-encryption happens only when a new password is explicitly provided in an update.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models.device import (
|
||||
Device,
|
||||
DeviceGroup,
|
||||
DeviceGroupMembership,
|
||||
DeviceTag,
|
||||
DeviceTagAssignment,
|
||||
)
|
||||
from app.schemas.device import (
|
||||
BulkAddRequest,
|
||||
BulkAddResult,
|
||||
DeviceCreate,
|
||||
DeviceGroupCreate,
|
||||
DeviceGroupResponse,
|
||||
DeviceGroupUpdate,
|
||||
DeviceResponse,
|
||||
DeviceTagCreate,
|
||||
DeviceTagResponse,
|
||||
DeviceTagUpdate,
|
||||
DeviceUpdate,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.services.crypto import (
|
||||
decrypt_credentials,
|
||||
decrypt_credentials_hybrid,
|
||||
encrypt_credentials,
|
||||
encrypt_credentials_transit,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _tcp_reachable(ip: str, port: int, timeout: float = 3.0) -> bool:
|
||||
"""Return True if a TCP connection to ip:port succeeds within timeout."""
|
||||
try:
|
||||
_, writer = await asyncio.wait_for(
|
||||
asyncio.open_connection(ip, port), timeout=timeout
|
||||
)
|
||||
writer.close()
|
||||
try:
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _build_device_response(device: Device) -> DeviceResponse:
    """
    Build a DeviceResponse from an ORM Device instance.

    Tags and groups are extracted from pre-loaded relationships.
    Credentials are explicitly EXCLUDED.
    """
    # Local import keeps schema dependencies out of module import time.
    from app.schemas.device import DeviceGroupRef, DeviceTagRef

    # Assumes tag_assignments/group_memberships were eagerly loaded (see
    # _device_with_relations); accessing unloaded relationships here would
    # otherwise trigger lazy loads — TODO confirm loading strategy at callers.
    tags = [
        DeviceTagRef(
            id=a.tag.id,
            name=a.tag.name,
            color=a.tag.color,
        )
        for a in device.tag_assignments
    ]

    groups = [
        DeviceGroupRef(
            id=m.group.id,
            name=m.group.name,
        )
        for m in device.group_memberships
    ]

    # Explicit field-by-field mapping: encrypted_credentials* columns are
    # deliberately never copied into the public response.
    return DeviceResponse(
        id=device.id,
        hostname=device.hostname,
        ip_address=device.ip_address,
        api_port=device.api_port,
        api_ssl_port=device.api_ssl_port,
        model=device.model,
        serial_number=device.serial_number,
        firmware_version=device.firmware_version,
        routeros_version=device.routeros_version,
        uptime_seconds=device.uptime_seconds,
        last_seen=device.last_seen,
        latitude=device.latitude,
        longitude=device.longitude,
        status=device.status,
        tls_mode=device.tls_mode,
        tags=tags,
        groups=groups,
        created_at=device.created_at,
    )
|
||||
|
||||
|
||||
def _device_with_relations():
    """Return a select() for Device with tags and groups eagerly loaded."""
    # Two-level selectin loads: assignment/membership rows, then their targets.
    tag_loader = selectinload(Device.tag_assignments).selectinload(
        DeviceTagAssignment.tag
    )
    group_loader = selectinload(Device.group_memberships).selectinload(
        DeviceGroupMembership.group
    )
    return select(Device).options(tag_loader, group_loader)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Device CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def create_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    data: DeviceCreate,
    encryption_key: bytes,
) -> DeviceResponse:
    """
    Create a new device.

    - Validates TCP connectivity (api_port or api_ssl_port must be reachable).
    - Encrypts credentials before storage.
    - Status set to "unknown" until the Go poller runs a full auth check (Phase 2).

    NOTE(review): `encryption_key` is not used in this body — encryption now
    goes through OpenBao Transit. Presumably kept for interface compatibility
    with the legacy path; confirm before removing.
    """
    # Test connectivity before accepting the device
    api_reachable = await _tcp_reachable(data.ip_address, data.api_port)
    ssl_reachable = await _tcp_reachable(data.ip_address, data.api_ssl_port)

    # Reject only if BOTH ports are unreachable — one working port suffices.
    if not api_reachable and not ssl_reachable:
        from fastapi import HTTPException, status
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(
                f"Cannot reach {data.ip_address} on port {data.api_port} "
                f"(RouterOS API) or {data.api_ssl_port} (RouterOS SSL API). "
                "Verify the IP address and that the RouterOS API is enabled."
            ),
        )

    # Encrypt credentials via OpenBao Transit (new writes go through Transit)
    credentials_json = json.dumps({"username": data.username, "password": data.password})
    transit_ciphertext = await encrypt_credentials_transit(
        credentials_json, str(tenant_id)
    )

    device = Device(
        tenant_id=tenant_id,
        hostname=data.hostname,
        ip_address=data.ip_address,
        api_port=data.api_port,
        api_ssl_port=data.api_ssl_port,
        encrypted_credentials_transit=transit_ciphertext,
        status="unknown",
    )
    db.add(device)
    await db.flush()  # Get the ID without committing
    await db.refresh(device)

    # Re-query with relationships loaded so _build_device_response can read
    # tags/groups without lazy loads.
    result = await db.execute(
        _device_with_relations().where(Device.id == device.id)
    )
    device = result.scalar_one()
    return _build_device_response(device)
|
||||
|
||||
|
||||
async def get_devices(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    page: int = 1,
    page_size: int = 25,
    status: Optional[str] = None,
    search: Optional[str] = None,
    tag_id: Optional[uuid.UUID] = None,
    group_id: Optional[uuid.UUID] = None,
    sort_by: str = "created_at",
    sort_order: str = "desc",
) -> tuple[list[DeviceResponse], int]:
    """
    Return a paginated list of devices with optional filtering and sorting.

    Returns (items, total_count).
    RLS automatically scopes this to the caller's tenant.
    """
    base_q = _device_with_relations()

    # Filtering — each filter narrows base_q; all are ANDed together.
    if status:
        base_q = base_q.where(Device.status == status)

    if search:
        # Case-insensitive substring match across hostname and IP.
        pattern = f"%{search}%"
        base_q = base_q.where(
            or_(
                Device.hostname.ilike(pattern),
                Device.ip_address.ilike(pattern),
            )
        )

    if tag_id:
        # Subquery membership test against the tag assignment table.
        base_q = base_q.where(
            Device.id.in_(
                select(DeviceTagAssignment.device_id).where(
                    DeviceTagAssignment.tag_id == tag_id
                )
            )
        )

    if group_id:
        # Same pattern for group membership.
        base_q = base_q.where(
            Device.id.in_(
                select(DeviceGroupMembership.device_id).where(
                    DeviceGroupMembership.group_id == group_id
                )
            )
        )

    # Count total before pagination (count over the filtered query).
    count_q = select(func.count()).select_from(base_q.subquery())
    total_result = await db.execute(count_q)
    total = total_result.scalar_one()

    # Sorting — allow-list of sortable columns prevents arbitrary/unsafe
    # column names reaching the query; unknown keys fall back to created_at.
    allowed_sort_cols = {
        "created_at": Device.created_at,
        "hostname": Device.hostname,
        "ip_address": Device.ip_address,
        "status": Device.status,
        "last_seen": Device.last_seen,
    }
    sort_col = allowed_sort_cols.get(sort_by, Device.created_at)
    if sort_order.lower() == "asc":
        base_q = base_q.order_by(sort_col.asc())
    else:
        base_q = base_q.order_by(sort_col.desc())

    # Pagination — page is 1-based.
    offset = (page - 1) * page_size
    base_q = base_q.offset(offset).limit(page_size)

    result = await db.execute(base_q)
    devices = result.scalars().all()
    return [_build_device_response(d) for d in devices], total
|
||||
|
||||
|
||||
async def get_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
) -> DeviceResponse:
    """Get a single device by ID.

    Raises 404 when the device does not exist; a device belonging to another
    tenant is filtered out by RLS and therefore also reads as missing.
    """
    from fastapi import HTTPException, status

    query = _device_with_relations().where(Device.id == device_id)
    device = (await db.execute(query)).scalar_one_or_none()
    if device is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    return _build_device_response(device)
|
||||
|
||||
|
||||
async def update_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    data: DeviceUpdate,
    encryption_key: bytes,
) -> DeviceResponse:
    """
    Update device fields. Re-encrypts credentials only if password is provided.

    NOTE(review): `encryption_key` is not read in this body (the legacy key is
    fetched from settings instead); presumably retained for interface
    compatibility — confirm before removing.
    """
    from fastapi import HTTPException, status

    result = await db.execute(
        _device_with_relations().where(Device.id == device_id)
    )
    device = result.scalar_one_or_none()
    if not device:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")

    # Update scalar fields — None means "leave unchanged" (PATCH semantics).
    if data.hostname is not None:
        device.hostname = data.hostname
    if data.ip_address is not None:
        device.ip_address = data.ip_address
    if data.api_port is not None:
        device.api_port = data.api_port
    if data.api_ssl_port is not None:
        device.api_ssl_port = data.api_ssl_port
    if data.latitude is not None:
        device.latitude = data.latitude
    if data.longitude is not None:
        device.longitude = data.longitude
    if data.tls_mode is not None:
        device.tls_mode = data.tls_mode

    # Re-encrypt credentials if new ones are provided
    credentials_changed = False
    if data.password is not None:
        # Decrypt existing to get current username if no new username given
        current_username: str = data.username or ""
        if not current_username and (device.encrypted_credentials_transit or device.encrypted_credentials):
            try:
                # Dual-read: Transit ciphertext preferred, legacy as fallback.
                existing_json = await decrypt_credentials_hybrid(
                    device.encrypted_credentials_transit,
                    device.encrypted_credentials,
                    str(device.tenant_id),
                    settings.get_encryption_key_bytes(),
                )
                existing = json.loads(existing_json)
                current_username = existing.get("username", "")
            except Exception:
                # Best-effort: an undecryptable blob means we simply lose the
                # old username; the new password is stored regardless.
                current_username = ""

        credentials_json = json.dumps({
            "username": data.username if data.username is not None else current_username,
            "password": data.password,
        })
        # New writes go through Transit
        device.encrypted_credentials_transit = await encrypt_credentials_transit(
            credentials_json, str(device.tenant_id)
        )
        device.encrypted_credentials = None  # Clear legacy (Transit is canonical)
        credentials_changed = True
    elif data.username is not None and (device.encrypted_credentials_transit or device.encrypted_credentials):
        # Only username changed — update it without changing the password
        try:
            existing_json = await decrypt_credentials_hybrid(
                device.encrypted_credentials_transit,
                device.encrypted_credentials,
                str(device.tenant_id),
                settings.get_encryption_key_bytes(),
            )
            existing = json.loads(existing_json)
            existing["username"] = data.username
            # Re-encrypt via Transit
            device.encrypted_credentials_transit = await encrypt_credentials_transit(
                json.dumps(existing), str(device.tenant_id)
            )
            device.encrypted_credentials = None
            credentials_changed = True
        except Exception:
            pass  # Keep existing encrypted blob if decryption fails

    await db.flush()
    await db.refresh(device)

    # Notify poller to invalidate cached credentials (fire-and-forget via NATS)
    if credentials_changed:
        try:
            from app.services.event_publisher import publish_event
            await publish_event(
                f"device.credential_changed.{device_id}",
                {"device_id": str(device_id), "tenant_id": str(tenant_id)},
            )
        except Exception:
            pass  # Never fail the update due to NATS issues

    # Re-query so relationships are freshly loaded for the response.
    result2 = await db.execute(
        _device_with_relations().where(Device.id == device_id)
    )
    device = result2.scalar_one()
    return _build_device_response(device)
|
||||
|
||||
|
||||
async def delete_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
) -> None:
    """Hard-delete a device (v1 — no soft delete for devices)."""
    from fastapi import HTTPException, status

    lookup = await db.execute(select(Device).where(Device.id == device_id))
    target = lookup.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")

    await db.delete(target)
    await db.flush()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group / Tag assignment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def assign_device_to_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    group_id: uuid.UUID,
) -> None:
    """Assign a device to a group (idempotent)."""
    from fastapi import HTTPException, status

    # Verify device and group exist (RLS scopes both lookups to the tenant).
    if await db.get(Device, device_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    if await db.get(DeviceGroup, group_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")

    # No-op if the membership row already exists.
    already = await db.get(DeviceGroupMembership, (device_id, group_id))
    if already is None:
        db.add(DeviceGroupMembership(device_id=device_id, group_id=group_id))
        await db.flush()
|
||||
|
||||
|
||||
async def remove_device_from_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    group_id: uuid.UUID,
) -> None:
    """Remove a device from a group."""
    from fastapi import HTTPException, status

    link = await db.get(DeviceGroupMembership, (device_id, group_id))
    if link is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Device is not in this group",
        )

    await db.delete(link)
    await db.flush()
|
||||
|
||||
|
||||
async def assign_tag_to_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    tag_id: uuid.UUID,
) -> None:
    """Assign a tag to a device (idempotent)."""
    from fastapi import HTTPException, status

    # Both lookups are tenant-scoped by RLS; a foreign ID reads as missing.
    if await db.get(Device, device_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    if await db.get(DeviceTag, tag_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tag not found")

    # No-op if the assignment row already exists.
    already = await db.get(DeviceTagAssignment, (device_id, tag_id))
    if already is None:
        db.add(DeviceTagAssignment(device_id=device_id, tag_id=tag_id))
        await db.flush()
|
||||
|
||||
|
||||
async def remove_tag_from_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    tag_id: uuid.UUID,
) -> None:
    """Remove a tag from a device."""
    from fastapi import HTTPException, status

    link = await db.get(DeviceTagAssignment, (device_id, tag_id))
    if link is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tag is not assigned to this device",
        )

    await db.delete(link)
    await db.flush()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DeviceGroup CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def create_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    data: DeviceGroupCreate,
) -> DeviceGroupResponse:
    """Create a new device group."""
    new_group = DeviceGroup(
        tenant_id=tenant_id,
        name=data.name,
        description=data.description,
    )
    db.add(new_group)
    await db.flush()
    await db.refresh(new_group)

    # A freshly created group has no members, so the count is always 0.
    return DeviceGroupResponse(
        id=new_group.id,
        name=new_group.name,
        description=new_group.description,
        device_count=0,
        created_at=new_group.created_at,
    )
|
||||
|
||||
|
||||
async def get_groups(
    db: AsyncSession,
    tenant_id: uuid.UUID,
) -> list[DeviceGroupResponse]:
    """Return all device groups for the current tenant with device counts."""
    # Eager-load memberships so len() below does not trigger lazy loads.
    query = select(DeviceGroup).options(selectinload(DeviceGroup.memberships))
    rows = (await db.execute(query)).scalars().all()

    def _to_response(g: DeviceGroup) -> DeviceGroupResponse:
        return DeviceGroupResponse(
            id=g.id,
            name=g.name,
            description=g.description,
            device_count=len(g.memberships),
            created_at=g.created_at,
        )

    return [_to_response(g) for g in rows]
|
||||
|
||||
|
||||
async def update_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    group_id: uuid.UUID,
    data: DeviceGroupUpdate,
) -> DeviceGroupResponse:
    """Update a device group."""
    from fastapi import HTTPException, status

    base_query = (
        select(DeviceGroup)
        .options(selectinload(DeviceGroup.memberships))
        .where(DeviceGroup.id == group_id)
    )

    group = (await db.execute(base_query)).scalar_one_or_none()
    if group is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")

    # None means "leave unchanged" (PATCH semantics).
    if data.name is not None:
        group.name = data.name
    if data.description is not None:
        group.description = data.description

    await db.flush()
    await db.refresh(group)

    # Re-select so memberships are freshly loaded for the device count.
    group = (await db.execute(base_query)).scalar_one()
    return DeviceGroupResponse(
        id=group.id,
        name=group.name,
        description=group.description,
        device_count=len(group.memberships),
        created_at=group.created_at,
    )
|
||||
|
||||
|
||||
async def delete_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    group_id: uuid.UUID,
) -> None:
    """Delete a device group."""
    from fastapi import HTTPException, status

    target = await db.get(DeviceGroup, group_id)
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")

    await db.delete(target)
    await db.flush()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DeviceTag CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def create_tag(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    data: DeviceTagCreate,
) -> DeviceTagResponse:
    """Create a new device tag."""
    new_tag = DeviceTag(
        tenant_id=tenant_id,
        name=data.name,
        color=data.color,
    )
    db.add(new_tag)
    await db.flush()
    await db.refresh(new_tag)
    return DeviceTagResponse(id=new_tag.id, name=new_tag.name, color=new_tag.color)
|
||||
|
||||
|
||||
async def get_tags(
    db: AsyncSession,
    tenant_id: uuid.UUID,
) -> list[DeviceTagResponse]:
    """Return all device tags for the current tenant."""
    rows = (await db.execute(select(DeviceTag))).scalars().all()
    return [
        DeviceTagResponse(id=row.id, name=row.name, color=row.color)
        for row in rows
    ]
|
||||
|
||||
|
||||
async def update_tag(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    tag_id: uuid.UUID,
    data: DeviceTagUpdate,
) -> DeviceTagResponse:
    """Update a device tag."""
    from fastapi import HTTPException, status

    target = await db.get(DeviceTag, tag_id)
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tag not found")

    # None means "leave unchanged" (PATCH semantics).
    if data.name is not None:
        target.name = data.name
    if data.color is not None:
        target.color = data.color

    await db.flush()
    await db.refresh(target)
    return DeviceTagResponse(id=target.id, name=target.name, color=target.color)
|
||||
|
||||
|
||||
async def delete_tag(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    tag_id: uuid.UUID,
) -> None:
    """Delete a device tag."""
    from fastapi import HTTPException, status

    target = await db.get(DeviceTag, tag_id)
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tag not found")

    await db.delete(target)
    await db.flush()
|
||||
124
backend/app/services/email_service.py
Normal file
124
backend/app/services/email_service.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Unified email sending service.
|
||||
|
||||
All email sending (system emails, alert notifications) goes through this module.
|
||||
Supports TLS, STARTTLS, and plain SMTP. Handles Transit + legacy Fernet password decryption.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from email.message import EmailMessage
|
||||
from typing import Optional
|
||||
|
||||
import aiosmtplib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SMTPConfig:
    """SMTP connection configuration.

    Plain value object holding everything needed to open an SMTP session:
    host/port, optional auth credentials, implicit-TLS flag, and the
    From address used on outgoing messages. Values are stored verbatim;
    validation happens at connection time in send_email/test_smtp_connection.
    """

    def __init__(
        self,
        host: str,
        port: int = 587,  # 587 = submission port (STARTTLS by convention)
        user: Optional[str] = None,
        password: Optional[str] = None,
        use_tls: bool = False,  # True = implicit TLS (typically port 465)
        from_address: str = "noreply@example.com",
    ):
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.use_tls = use_tls
        self.from_address = from_address
|
||||
|
||||
|
||||
async def send_email(
    to: str,
    subject: str,
    html: str,
    plain_text: str,
    smtp_config: SMTPConfig,
) -> None:
    """Send an email via SMTP.

    Args:
        to: Recipient email address.
        subject: Email subject line.
        html: HTML body.
        plain_text: Plain text fallback body.
        smtp_config: SMTP connection settings.

    Raises:
        aiosmtplib.SMTPException: On SMTP connection or send failure.
    """
    # Build a multipart/alternative message: plain text first, HTML preferred.
    msg = EmailMessage()
    msg["Subject"] = subject
    msg["From"] = smtp_config.from_address
    msg["To"] = to
    msg.set_content(plain_text)
    msg.add_alternative(html, subtype="html")

    # Implicit TLS and STARTTLS are mutually exclusive; port 25 gets neither.
    use_tls = smtp_config.use_tls
    if use_tls or smtp_config.port == 25:
        start_tls = False
    else:
        start_tls = True

    await aiosmtplib.send(
        msg,
        hostname=smtp_config.host,
        port=smtp_config.port,
        username=smtp_config.user or None,
        password=smtp_config.password or None,
        use_tls=use_tls,
        start_tls=start_tls,
    )
|
||||
|
||||
|
||||
async def test_smtp_connection(smtp_config: SMTPConfig) -> dict:
    """Test SMTP connectivity without sending an email.

    Returns:
        dict with "success" bool and "message" string.
    """
    # Same TLS selection as send_email: STARTTLS unless implicit TLS or port 25.
    if smtp_config.port != 25:
        start_tls = not smtp_config.use_tls
    else:
        start_tls = False

    try:
        client = aiosmtplib.SMTP(
            hostname=smtp_config.host,
            port=smtp_config.port,
            use_tls=smtp_config.use_tls,
            start_tls=start_tls,
        )
        await client.connect()
        # Only attempt AUTH when both pieces of the credential are present.
        if smtp_config.user and smtp_config.password:
            await client.login(smtp_config.user, smtp_config.password)
        await client.quit()
    except Exception as e:
        # Any failure (DNS, refused, TLS, auth) is reported, not raised.
        return {"success": False, "message": str(e)}
    return {"success": True, "message": "SMTP connection successful"}
|
||||
|
||||
|
||||
async def send_test_email(to: str, smtp_config: SMTPConfig) -> dict:
    """Send a test email to verify the full SMTP flow.

    Unlike test_smtp_connection(), this exercises the complete pipeline
    (connect, authenticate, deliver) by sending a real message through
    send_email().

    Args:
        to: Recipient address for the test message.
        smtp_config: SMTP connection settings.

    Returns:
        dict with "success" bool and "message" string. Never raises.
    """
    # Inline-styled HTML body; send_email() attaches it as the multipart
    # alternative to the plain-text fallback below.
    html = """
    <div style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; max-width: 600px; margin: 0 auto;">
    <div style="background: #0f172a; padding: 24px; border-radius: 8px 8px 0 0;">
    <h2 style="color: #38bdf8; margin: 0;">TOD — Email Test</h2>
    </div>
    <div style="background: #1e293b; padding: 24px; border-radius: 0 0 8px 8px; color: #e2e8f0;">
    <p>This is a test email from The Other Dude.</p>
    <p>If you're reading this, your SMTP configuration is working correctly.</p>
    <p style="color: #94a3b8; font-size: 13px; margin-top: 24px;">
    Sent from TOD Fleet Management
    </p>
    </div>
    </div>
    """
    plain = "TOD — Email Test\n\nThis is a test email from The Other Dude.\nIf you're reading this, your SMTP configuration is working correctly."

    try:
        await send_email(to, "TOD — Test Email", html, plain, smtp_config)
        return {"success": True, "message": f"Test email sent to {to}"}
    except Exception as e:
        # Surface the failure as data; callers render it in the UI.
        return {"success": False, "message": str(e)}
|
||||
54
backend/app/services/emergency_kit_service.py
Normal file
54
backend/app/services/emergency_kit_service.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Emergency Kit PDF template generation.
|
||||
|
||||
Generates an Emergency Kit PDF containing the user's email and sign-in URL
|
||||
but NOT the Secret Key. The Secret Key placeholder is filled client-side
|
||||
so that the server never sees it.
|
||||
|
||||
Uses Jinja2 + WeasyPrint following the same pattern as the reports service.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
from app.config import settings
|
||||
|
||||
TEMPLATE_DIR = Path(__file__).parent.parent.parent / "templates"
|
||||
|
||||
|
||||
async def generate_emergency_kit_template(
    email: str,
) -> bytes:
    """Render the Emergency Kit PDF WITHOUT the Secret Key.

    The PDF carries the account email, sign-in URL, and a placeholder
    string; the browser substitutes the real Secret Key client-side,
    so the server never sees it.

    Args:
        email: The user's email address to display in the PDF.

    Returns:
        PDF bytes ready for streaming response.
    """
    jinja_env = Environment(
        loader=FileSystemLoader(str(TEMPLATE_DIR)),
        autoescape=True,
    )
    rendered_html = jinja_env.get_template("emergency_kit.html").render(
        email=email,
        signin_url=settings.APP_BASE_URL,
        date=datetime.now(UTC).strftime("%Y-%m-%d"),
        secret_key_placeholder="[Download complete -- your Secret Key will be inserted by your browser]",
    )

    # WeasyPrint is CPU-bound; render in a worker thread so the event
    # loop is never blocked.
    from weasyprint import HTML

    return await asyncio.to_thread(
        lambda: HTML(string=rendered_html).write_pdf()
    )
|
||||
52
backend/app/services/event_publisher.py
Normal file
52
backend/app/services/event_publisher.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Fire-and-forget NATS JetStream event publisher for real-time SSE pipeline.
|
||||
|
||||
Provides a shared lazy NATS connection and publish helper used by:
|
||||
- alert_evaluator.py (alert.fired.{tenant_id}, alert.resolved.{tenant_id})
|
||||
- restore_service.py (config.push.{tenant_id}.{device_id})
|
||||
- upgrade_service.py (firmware.progress.{tenant_id}.{device_id})
|
||||
|
||||
All publishes are fire-and-forget: errors are logged but never propagate
|
||||
to the caller. A NATS outage must never block alert evaluation, config
|
||||
push, or firmware upgrade operations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import nats
|
||||
import nats.aio.client
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level NATS connection (lazy initialized, reused across publishes)
_nc: nats.aio.client.Client | None = None


async def _get_nats() -> nats.aio.client.Client:
    """Return the shared NATS connection, (re)connecting when absent or closed."""
    global _nc
    needs_connect = _nc is None or _nc.is_closed
    if needs_connect:
        _nc = await nats.connect(settings.NATS_URL)
        logger.info("Event publisher NATS connection established")
    return _nc
|
||||
|
||||
|
||||
async def publish_event(subject: str, payload: dict[str, Any]) -> None:
    """Publish *payload* as JSON to a JetStream *subject* (fire-and-forget).

    Args:
        subject: NATS subject, e.g. "alert.fired.{tenant_id}".
        payload: Dict that will be JSON-serialized as the message body.

    Never raises — a NATS outage must never block alert evaluation,
    config push, or firmware upgrades, so every exception is caught
    and logged as a warning.
    """
    try:
        connection = await _get_nats()
        js = connection.jetstream()
        await js.publish(subject, json.dumps(payload).encode())
    except Exception as exc:
        logger.warning("Failed to publish event to %s: %s", subject, exc)
    else:
        logger.debug("Published event to %s", subject)
|
||||
303
backend/app/services/firmware_service.py
Normal file
303
backend/app/services/firmware_service.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""Firmware version cache service and NPK downloader.
|
||||
|
||||
Responsibilities:
|
||||
- check_latest_versions(): fetch latest RouterOS versions from download.mikrotik.com
|
||||
- download_firmware(): download NPK packages to local PVC cache
|
||||
- get_firmware_overview(): return fleet firmware status for a tenant
|
||||
- schedule_firmware_checks(): register daily firmware check job with APScheduler
|
||||
|
||||
Version discovery comes from two sources:
|
||||
1. Go poller runs /system/package/update per device (rate-limited to once/day)
|
||||
and publishes via NATS -> firmware_subscriber processes these events
|
||||
2. check_latest_versions() fetches LATEST.7 / LATEST.6 from download.mikrotik.com
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Architectures supported by RouterOS v7 and v6
|
||||
_V7_ARCHITECTURES = ["arm", "arm64", "mipsbe", "mmips", "smips", "tile", "ppc", "x86"]
|
||||
_V6_ARCHITECTURES = ["mipsbe", "mmips", "smips", "tile", "ppc", "x86"]
|
||||
|
||||
# Version source files on download.mikrotik.com
|
||||
_VERSION_SOURCES = [
|
||||
("LATEST.7", "stable", 7),
|
||||
("LATEST.7long", "long-term", 7),
|
||||
("LATEST.6", "stable", 6),
|
||||
("LATEST.6long", "long-term", 6),
|
||||
]
|
||||
|
||||
|
||||
async def check_latest_versions() -> list[dict]:
    """Fetch latest RouterOS versions from download.mikrotik.com.

    Checks LATEST.7, LATEST.7long, LATEST.6, and LATEST.6long files for
    version strings, then upserts into firmware_versions table for each
    architecture/channel combination.

    Returns list of discovered version dicts
    ({"architecture", "channel", "version", "npk_url"}).
    """
    results: list[dict] = []

    async with httpx.AsyncClient(timeout=30.0) as client:
        for channel_file, channel, major in _VERSION_SOURCES:
            try:
                resp = await client.get(
                    f"https://download.mikrotik.com/routeros/{channel_file}"
                )
                if resp.status_code != 200:
                    logger.warning(
                        "MikroTik version check returned %d for %s",
                        resp.status_code, channel_file,
                    )
                    continue

                # Response body is expected to start with a digit version
                # string (e.g. "7.x..."); anything else is rejected.
                version = resp.text.strip()
                if not version or not version[0].isdigit():
                    logger.warning("Invalid version string from %s: %r", channel_file, version)
                    continue

                # v6 supports fewer architectures than v7 (see module constants).
                architectures = _V7_ARCHITECTURES if major == 7 else _V6_ARCHITECTURES
                for arch in architectures:
                    npk_url = (
                        f"https://download.mikrotik.com/routeros/"
                        f"{version}/routeros-{version}-{arch}.npk"
                    )
                    results.append({
                        "architecture": arch,
                        "channel": channel,
                        "version": version,
                        "npk_url": npk_url,
                    })

            except Exception as e:
                # One failed source must not abort the remaining checks.
                logger.warning("Failed to check %s: %s", channel_file, e)

    # Upsert into firmware_versions table; on conflict only refresh
    # checked_at so existing rows (with local paths etc.) are preserved.
    if results:
        async with AdminAsyncSessionLocal() as session:
            for r in results:
                await session.execute(
                    text("""
                        INSERT INTO firmware_versions (id, architecture, channel, version, npk_url, checked_at)
                        VALUES (gen_random_uuid(), :arch, :channel, :version, :npk_url, NOW())
                        ON CONFLICT (architecture, channel, version) DO UPDATE SET checked_at = NOW()
                    """),
                    {
                        "arch": r["architecture"],
                        "channel": r["channel"],
                        "version": r["version"],
                        "npk_url": r["npk_url"],
                    },
                )
            await session.commit()

    logger.info("Firmware version check complete — %d versions discovered", len(results))
    return results
|
||||
|
||||
|
||||
async def download_firmware(architecture: str, channel: str, version: str) -> str:
    """Download an NPK package to the local firmware cache.

    Args:
        architecture: CPU architecture (e.g. "arm64", "mipsbe").
        channel: Firmware channel ("stable", "long-term").
        version: RouterOS version string.

    Returns:
        The local file path as a string. Skips the download if the file
        already exists and is non-empty.
    """
    cache_dir = Path(settings.FIRMWARE_CACHE_DIR) / version
    cache_dir.mkdir(parents=True, exist_ok=True)

    filename = f"routeros-{version}-{architecture}.npk"
    local_path = cache_dir / filename
    # BUGFIX: the URL previously ended in a literal "(unknown)" instead of
    # the package filename, so every download failed.
    npk_url = f"https://download.mikrotik.com/routeros/{version}/{filename}"

    # Serve from cache when a previous download completed (non-empty file).
    if local_path.exists() and local_path.stat().st_size > 0:
        logger.info("Firmware already cached: %s", local_path)
        return str(local_path)

    logger.info("Downloading firmware: %s", npk_url)

    # Stream to disk in 64 KiB chunks — NPK packages can be large.
    async with httpx.AsyncClient(timeout=300.0) as client:
        async with client.stream("GET", npk_url) as response:
            response.raise_for_status()
            with open(local_path, "wb") as f:
                async for chunk in response.aiter_bytes(chunk_size=65536):
                    f.write(chunk)

    file_size = local_path.stat().st_size
    logger.info("Firmware downloaded: %s (%d bytes)", local_path, file_size)

    # Record the cached artifact's path and size alongside the version row.
    async with AdminAsyncSessionLocal() as session:
        await session.execute(
            text("""
                UPDATE firmware_versions
                SET npk_local_path = :path, npk_size_bytes = :size
                WHERE architecture = :arch AND channel = :channel AND version = :version
            """),
            {
                "path": str(local_path),
                "size": file_size,
                "arch": architecture,
                "channel": channel,
                "version": version,
            },
        )
        await session.commit()

    return str(local_path)
|
||||
|
||||
|
||||
async def get_firmware_overview(tenant_id: str) -> dict:
    """Return fleet firmware status for a tenant.

    Returns devices grouped by firmware version, annotated with up-to-date status
    based on the latest known version for each device's architecture and preferred channel.

    Args:
        tenant_id: Tenant UUID as string.

    Returns:
        dict with "devices" (per-device status list), "version_groups"
        (devices bucketed by installed version), and "summary" (counts).
    """
    async with AdminAsyncSessionLocal() as session:
        # Get all devices for tenant
        devices_result = await session.execute(
            text("""
                SELECT id, hostname, ip_address, routeros_version, architecture,
                       preferred_channel, routeros_major_version,
                       serial_number, firmware_version, model
                FROM devices
                WHERE tenant_id = CAST(:tenant_id AS uuid)
                ORDER BY hostname
            """),
            {"tenant_id": tenant_id},
        )
        devices = devices_result.fetchall()

        # Get latest firmware versions per architecture/channel.
        # DISTINCT ON + checked_at DESC keeps the most recently checked row.
        versions_result = await session.execute(
            text("""
                SELECT DISTINCT ON (architecture, channel)
                       architecture, channel, version, npk_url
                FROM firmware_versions
                ORDER BY architecture, channel, checked_at DESC
            """)
        )
        latest_versions = {
            (row[0], row[1]): {"version": row[2], "npk_url": row[3]}
            for row in versions_result.fetchall()
        }

        # Build per-device status
        device_list = []
        version_groups: dict[str, list] = {}
        summary = {"total": 0, "up_to_date": 0, "outdated": 0, "unknown": 0}

        for dev in devices:
            # Positional access — indices match the SELECT column order above.
            dev_id = str(dev[0])
            hostname = dev[1]
            current_version = dev[3]
            arch = dev[4]
            channel = dev[5] or "stable"

            latest = latest_versions.get((arch, channel)) if arch else None
            latest_version = latest["version"] if latest else None

            is_up_to_date = False
            if not current_version or not arch:
                summary["unknown"] += 1
            elif latest_version and current_version == latest_version:
                is_up_to_date = True
                summary["up_to_date"] += 1
            else:
                # Also counts devices whose arch/channel has no known latest.
                summary["outdated"] += 1

            summary["total"] += 1

            dev_info = {
                "id": dev_id,
                "hostname": hostname,
                "ip_address": dev[2],
                "routeros_version": current_version,
                "architecture": arch,
                "latest_version": latest_version,
                "channel": channel,
                "is_up_to_date": is_up_to_date,
                "serial_number": dev[7],
                "firmware_version": dev[8],
                "model": dev[9],
            }
            device_list.append(dev_info)

            # Group by version
            ver_key = current_version or "unknown"
            if ver_key not in version_groups:
                version_groups[ver_key] = []
            version_groups[ver_key].append(dev_info)

        # Build version groups with is_latest flag
        groups = []
        for ver, devs in sorted(version_groups.items()):
            # A version is "latest" if it matches the latest for any arch/channel combo
            is_latest = any(
                v["version"] == ver for v in latest_versions.values()
            )
            groups.append({
                "version": ver,
                "count": len(devs),
                "is_latest": is_latest,
                "devices": devs,
            })

        return {
            "devices": device_list,
            "version_groups": groups,
            "summary": summary,
        }
|
||||
|
||||
|
||||
async def get_cached_firmware() -> list[dict]:
    """List all locally cached NPK files with their sizes.

    Scans {FIRMWARE_CACHE_DIR}/{version}/*.npk and returns one dict per
    package ({"path", "version", "filename", "size_bytes"}), sorted by
    version directory then filename.
    """
    root = Path(settings.FIRMWARE_CACHE_DIR)
    if not root.exists():
        return []

    entries: list[dict] = []
    for version_dir in sorted(root.iterdir()):
        if not version_dir.is_dir():
            continue
        for npk_file in sorted(version_dir.iterdir()):
            if npk_file.suffix != ".npk":
                continue
            entries.append({
                "path": str(npk_file),
                "version": version_dir.name,
                "filename": npk_file.name,
                "size_bytes": npk_file.stat().st_size,
            })

    return entries
|
||||
|
||||
|
||||
def schedule_firmware_checks() -> None:
    """Register the daily RouterOS firmware version check with APScheduler.

    Called from FastAPI lifespan startup; runs check_latest_versions()
    every day at 03:00 UTC on the shared backup scheduler.
    """
    # Function-scope imports — presumably to avoid import cycles at module
    # load time; confirm before hoisting to the top of the file.
    from apscheduler.triggers.cron import CronTrigger
    from app.services.backup_scheduler import backup_scheduler

    daily_3am_utc = CronTrigger(hour=3, minute=0, timezone="UTC")

    backup_scheduler.add_job(
        check_latest_versions,
        trigger=daily_3am_utc,
        id="firmware_version_check",
        name="Check for new RouterOS firmware versions",
        max_instances=1,
        replace_existing=True,
    )

    logger.info("Firmware version check scheduled — daily at 3am UTC")
|
||||
206
backend/app/services/firmware_subscriber.py
Normal file
206
backend/app/services/firmware_subscriber.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""NATS JetStream subscriber for device firmware events from the Go poller.
|
||||
|
||||
Subscribes to device.firmware.> and:
|
||||
1. Updates devices.routeros_version and devices.architecture from poller data
|
||||
2. Upserts firmware_versions table with latest version per architecture/channel
|
||||
|
||||
Uses AdminAsyncSessionLocal (superuser bypass RLS) so firmware data from any
|
||||
tenant can be written without setting app.current_tenant.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import nats
|
||||
from nats.js import JetStreamContext
|
||||
from nats.aio.client import Client as NATSClient
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_firmware_client: Optional[NATSClient] = None
|
||||
|
||||
|
||||
async def on_device_firmware(msg) -> None:
    """Handle a device.firmware event published by the Go poller.

    Payload (JSON):
        device_id (str) -- UUID of the device
        tenant_id (str) -- UUID of the owning tenant
        installed_version (str) -- currently installed RouterOS version
        latest_version (str) -- latest available version (may be empty)
        channel (str) -- firmware channel ("stable", "long-term")
        status (str) -- "New version is available", etc.
        architecture (str) -- CPU architecture (arm, arm64, mipsbe, etc.)

    Acks on success (and on malformed payloads, to stop redelivery);
    naks on processing errors so JetStream redelivers.
    """
    try:
        data = json.loads(msg.data)
        device_id = data.get("device_id")
        # tenant_id is parsed but unused here: the admin session bypasses
        # RLS, so no tenant context is needed for the writes below.
        tenant_id = data.get("tenant_id")
        architecture = data.get("architecture")
        installed_version = data.get("installed_version")
        latest_version = data.get("latest_version")
        channel = data.get("channel", "stable")

        if not device_id:
            # Malformed event — ack so it is not redelivered forever.
            logger.warning("device.firmware event missing device_id — skipping")
            await msg.ack()
            return

        async with AdminAsyncSessionLocal() as session:
            # Update device routeros_version and architecture from poller data.
            # COALESCE keeps the existing column value when the event field is null.
            if architecture or installed_version:
                await session.execute(
                    text("""
                        UPDATE devices
                        SET routeros_version = COALESCE(:installed_ver, routeros_version),
                            architecture = COALESCE(:architecture, architecture),
                            updated_at = NOW()
                        WHERE id = CAST(:device_id AS uuid)
                    """),
                    {
                        "installed_ver": installed_version,
                        "architecture": architecture,
                        "device_id": device_id,
                    },
                )

            # Upsert firmware_versions if we got latest version info;
            # on conflict only refresh checked_at.
            if latest_version and architecture:
                npk_url = (
                    f"https://download.mikrotik.com/routeros/"
                    f"{latest_version}/routeros-{latest_version}-{architecture}.npk"
                )
                await session.execute(
                    text("""
                        INSERT INTO firmware_versions (id, architecture, channel, version, npk_url, checked_at)
                        VALUES (gen_random_uuid(), :arch, :channel, :version, :url, NOW())
                        ON CONFLICT (architecture, channel, version) DO UPDATE SET checked_at = NOW()
                    """),
                    {
                        "arch": architecture,
                        "channel": channel,
                        "version": latest_version,
                        "url": npk_url,
                    },
                )

            await session.commit()

        logger.debug(
            "device.firmware processed",
            extra={
                "device_id": device_id,
                "architecture": architecture,
                "installed": installed_version,
                "latest": latest_version,
            },
        )
        # Ack only after the commit succeeded.
        await msg.ack()

    except Exception as exc:
        logger.error(
            "Failed to process device.firmware event: %s",
            exc,
            exc_info=True,
        )
        # Nak so JetStream redelivers; swallow nak failures (the
        # connection itself may be the thing that broke).
        try:
            await msg.nak()
        except Exception:
            pass
|
||||
|
||||
|
||||
async def _subscribe_with_retry(js: JetStreamContext) -> None:
    """Subscribe to device.firmware.> with a durable consumer, retrying while the stream warms up."""
    attempts = 6  # ~30 seconds at 5s intervals
    for attempt in range(1, attempts + 1):
        try:
            await js.subscribe(
                "device.firmware.>",
                cb=on_device_firmware,
                durable="api-firmware-consumer",
                stream="DEVICE_EVENTS",
            )
        except Exception as exc:
            if attempt == attempts:
                # Final failure is non-fatal: the API keeps running, just
                # without poller-driven firmware updates.
                logger.warning(
                    "NATS: giving up on device.firmware.> after %d attempts: %s — API will run without firmware updates",
                    attempts,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready for firmware (attempt %d/%d): %s — retrying in 5s",
                attempt,
                attempts,
                exc,
            )
            await asyncio.sleep(5)
        else:
            logger.info(
                "NATS: subscribed to device.firmware.> (durable: api-firmware-consumer)"
            )
            return
|
||||
|
||||
|
||||
async def start_firmware_subscriber() -> Optional[NATSClient]:
    """Connect to NATS and start the device.firmware.> subscription.

    Uses a dedicated NATS connection, separate from the status and
    metrics subscribers.

    Returns:
        The NATS connection; pass it to stop_firmware_subscriber() on shutdown.

    Raises on fatal connection errors.
    """
    global _firmware_client

    logger.info("NATS firmware: connecting to %s", settings.NATS_URL)

    # Reconnect forever (-1) with a short backoff; callbacks below log
    # connection state transitions.
    connection = await nats.connect(
        settings.NATS_URL,
        max_reconnect_attempts=-1,
        reconnect_time_wait=2,
        error_cb=_on_error,
        reconnected_cb=_on_reconnected,
        disconnected_cb=_on_disconnected,
    )
    logger.info("NATS firmware: connected to %s", settings.NATS_URL)

    await _subscribe_with_retry(connection.jetstream())

    _firmware_client = connection
    return connection
|
||||
|
||||
|
||||
async def stop_firmware_subscriber(nc: Optional[NATSClient]) -> None:
    """Drain and close the firmware NATS connection gracefully."""
    if nc is None:
        return

    try:
        logger.info("NATS firmware: draining connection...")
        await nc.drain()
    except Exception as exc:
        logger.warning("NATS firmware: error during drain: %s", exc)
        # Drain failed — force-close as a fallback, ignoring any error.
        try:
            await nc.close()
        except Exception:
            pass
    else:
        logger.info("NATS firmware: connection closed")
|
||||
|
||||
|
||||
async def _on_error(exc: Exception) -> None:
    """NATS async error callback — log and continue (reconnects are automatic)."""
    logger.error("NATS firmware error: %s", exc)
|
||||
|
||||
|
||||
async def _on_reconnected() -> None:
    """NATS reconnect callback — informational log only."""
    logger.info("NATS firmware: reconnected")
|
||||
|
||||
|
||||
async def _on_disconnected() -> None:
    """NATS disconnect callback — warn; the client retries in the background."""
    logger.warning("NATS firmware: disconnected")
|
||||
296
backend/app/services/git_store.py
Normal file
296
backend/app/services/git_store.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""pygit2-based git store for versioned config backup storage.
|
||||
|
||||
All functions in this module are synchronous (pygit2 is C bindings over libgit2).
|
||||
Callers running in an async context MUST wrap calls in:
|
||||
loop.run_in_executor(None, func, *args)
|
||||
or:
|
||||
asyncio.get_event_loop().run_in_executor(None, func, *args)
|
||||
|
||||
See Pitfall 3 in 04-RESEARCH.md — blocking pygit2 in async context stalls
|
||||
the event loop and causes timeouts for other concurrent requests.
|
||||
|
||||
Git layout:
|
||||
{GIT_STORE_PATH}/{tenant_id}.git/ <- bare repo per tenant
|
||||
objects/ refs/ HEAD <- standard bare git structure
|
||||
{device_id}/ <- device subtree
|
||||
export.rsc <- text export (/export compact)
|
||||
backup.bin <- binary system backup
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pygit2
|
||||
|
||||
from app.config import settings
|
||||
|
||||
# =========================================================================
|
||||
# Per-tenant mutex to prevent TreeBuilder race condition (Pitfall 5 in RESEARCH.md).
|
||||
# Two simultaneous backups for different devices in the same tenant repo would
|
||||
# each read HEAD, build their own device subtrees, and write conflicting root
|
||||
# trees. The second commit would lose the first's device subtree.
|
||||
# Lock scope is the entire tenant repo — not just the device.
|
||||
# =========================================================================
|
||||
_tenant_locks: dict[str, threading.Lock] = {}
|
||||
_tenant_locks_guard = threading.Lock()
|
||||
|
||||
|
||||
def _get_tenant_lock(tenant_id: str) -> threading.Lock:
|
||||
"""Return (creating if needed) the per-tenant commit lock."""
|
||||
with _tenant_locks_guard:
|
||||
if tenant_id not in _tenant_locks:
|
||||
_tenant_locks[tenant_id] = threading.Lock()
|
||||
return _tenant_locks[tenant_id]
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# PUBLIC API
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def get_or_create_repo(tenant_id: str) -> pygit2.Repository:
    """Open the tenant's bare git repo, creating it on first use.

    The repo lives at {GIT_STORE_PATH}/{tenant_id}.git; the store root
    directory is created if missing.

    Args:
        tenant_id: Tenant UUID as string.

    Returns:
        An open pygit2.Repository instance (bare).
    """
    store_root = Path(settings.GIT_STORE_PATH)
    store_root.mkdir(parents=True, exist_ok=True)

    repo_path = store_root / f"{tenant_id}.git"
    if not repo_path.exists():
        return pygit2.init_repository(str(repo_path), bare=True)
    return pygit2.Repository(str(repo_path))
|
||||
|
||||
|
||||
def commit_backup(
    tenant_id: str,
    device_id: str,
    export_text: str,
    binary_backup: bytes,
    message: str,
) -> str:
    """Write a backup pair (export.rsc + backup.bin) as a git commit.

    Creates or updates the device subdirectory in the tenant's bare repo.
    Preserves other devices' subdirectories by merging the device subtree
    into the existing root tree.

    Per-tenant locking (threading.Lock) prevents the TreeBuilder race
    condition when two devices in the same tenant back up concurrently.

    Args:
        tenant_id: Tenant UUID as string.
        device_id: Device UUID as string (becomes a subdirectory in the repo).
        export_text: Text output of /export compact.
        binary_backup: Raw bytes from /system backup save.
        message: Commit message (format: "{trigger}: {hostname} ({ip}) at {ts}").

    Returns:
        The hex commit SHA string (40 characters).
    """
    lock = _get_tenant_lock(tenant_id)

    with lock:
        repo = get_or_create_repo(tenant_id)

        # Create blobs from content
        export_oid = repo.create_blob(export_text.encode("utf-8"))
        binary_oid = repo.create_blob(binary_backup)

        # Build device subtree: {device_id}/export.rsc and {device_id}/backup.bin
        device_builder = repo.TreeBuilder()
        device_builder.insert("export.rsc", export_oid, pygit2.GIT_FILEMODE_BLOB)
        device_builder.insert("backup.bin", binary_oid, pygit2.GIT_FILEMODE_BLOB)
        device_tree_oid = device_builder.write()

        # Merge device subtree into root tree, preserving all other device subtrees.
        # If the repo has no commits yet, start with an empty root tree.
        root_ref = repo.references.get("refs/heads/main")
        parent_commit: Optional[pygit2.Commit] = None

        if root_ref is not None:
            try:
                parent_commit = repo.get(root_ref.target)
                # Seed the builder with HEAD's tree so sibling device
                # subtrees carry over into the new root tree.
                root_builder = repo.TreeBuilder(parent_commit.tree)
            except Exception:
                # Unreadable HEAD tree — fall back to an empty root tree.
                root_builder = repo.TreeBuilder()
        else:
            root_builder = repo.TreeBuilder()

        root_builder.insert(device_id, device_tree_oid, pygit2.GIT_FILEMODE_TREE)
        root_tree_oid = root_builder.write()

        # Author signature — no real identity, portal service account
        author = pygit2.Signature("The Other Dude", "backup@tod.local")

        # Single parent (current main tip) or none for the very first commit.
        parents = [root_ref.target] if root_ref is not None else []

        # Same signature is used for author and committer; the commit
        # advances refs/heads/main directly (bare repo, no checkout).
        commit_oid = repo.create_commit(
            "refs/heads/main",
            author,
            author,
            message,
            root_tree_oid,
            parents,
        )

        return str(commit_oid)
|
||||
|
||||
|
||||
def read_file(
    tenant_id: str,
    commit_sha: str,
    device_id: str,
    filename: str,
) -> bytes:
    """Read a single file blob from a backup commit.

    Tree navigation: root tree -> {device_id} subtree -> {filename} blob.

    Args:
        tenant_id: Tenant UUID as string.
        commit_sha: Full or abbreviated git commit SHA.
        device_id: Device UUID as string (subdirectory name in the repo).
        filename: File to read: "export.rsc" or "backup.bin".

    Returns:
        Raw bytes of the file content.

    Raises:
        KeyError: If the commit is not found, or the device_id subtree /
            filename does not exist in it.
    """
    repo = get_or_create_repo(tenant_id)

    commit_obj = repo.get(commit_sha)
    if commit_obj is None:
        raise KeyError(f"Commit {commit_sha!r} not found in tenant {tenant_id!r} repo")

    device_tree = repo.get(commit_obj.tree[device_id].id)
    file_blob = repo.get(device_tree[filename].id)
    return file_blob.data
|
||||
|
||||
|
||||
def list_device_commits(
    tenant_id: str,
    device_id: str,
) -> list[dict]:
    """Walk commit history and return commits that include the device subtree.

    Walks commits newest-first. Returns only commits where the device_id
    subtree is present in the root tree (the device had a backup in that commit).

    Args:
        tenant_id: Tenant UUID as string.
        device_id: Device UUID as string.

    Returns:
        List of dicts (newest first):
            [{"sha": str, "message": str, "timestamp": int}, ...]
        Empty list if no commits or device has never been backed up.
    """
    repo = get_or_create_repo(tenant_id)

    # If there are no commits, return empty list immediately.
    # Use refs/heads/main explicitly rather than repo.head (which defaults to
    # refs/heads/master — wrong when the repo uses 'main' as the default branch).
    main_ref = repo.references.get("refs/heads/main")
    if main_ref is None:
        return []
    head_target = main_ref.target

    results = []
    walker = repo.walk(head_target, pygit2.GIT_SORT_TIME)

    for commit in walker:
        # Check if device_id subtree exists in this commit's root tree.
        try:
            device_entry = commit.tree[device_id]
        except KeyError:
            # Device not present in this commit at all — skip.
            continue

        # Only include this commit if it actually changed the device's subtree
        # vs its parent. This prevents every subsequent backup (for any device
        # in the same tenant) from appearing in all devices' histories.
        if commit.parents:
            # Only the first parent is compared; commit_backup creates
            # single-parent (linear) history in this repo.
            parent = commit.parents[0]
            try:
                parent_device_entry = parent.tree[device_id]
                if parent_device_entry.id == device_entry.id:
                    # Device subtree unchanged in this commit — skip.
                    continue
            except KeyError:
                # Device wasn't in parent but is in this commit — it's the first entry.
                pass

        results.append({
            "sha": str(commit.id),
            "message": commit.message.strip(),
            "timestamp": commit.commit_time,
        })

    return results
|
||||
|
||||
|
||||
def compute_line_delta(old_text: str, new_text: str) -> tuple[int, int]:
    """Compute (lines_added, lines_removed) between two text versions.

    Uses difflib.SequenceMatcher to derive only the line counts, which is
    cheaper than producing a full unified diff for large config files.

    For a first backup (no prior version), pass old_text="" to obtain
    (total_lines, 0).

    Args:
        old_text: Previous export.rsc content (empty string for first backup).
        new_text: New export.rsc content.

    Returns:
        Tuple of (lines_added, lines_removed).
    """
    prev = old_text.splitlines() if old_text else []
    curr = new_text.splitlines() if new_text else []

    # Trivial cases: everything added, everything removed, or nothing at all.
    if not prev:
        return (len(curr), 0)
    if not curr:
        return (0, len(prev))

    matcher = difflib.SequenceMatcher(None, prev, curr, autojunk=False)

    added = 0
    removed = 0
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        # "replace" counts on both sides; "equal" counts on neither.
        if tag in ("replace", "delete"):
            removed += i2 - i1
        if tag in ("replace", "insert"):
            added += j2 - j1

    return (added, removed)
|
||||
324
backend/app/services/key_service.py
Normal file
324
backend/app/services/key_service.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""Key hierarchy management service for zero-knowledge architecture.
|
||||
|
||||
Provides CRUD operations for encrypted key bundles (UserKeySet),
|
||||
append-only audit logging (KeyAccessLog), and OpenBao Transit
|
||||
tenant key provisioning with credential migration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.key_set import KeyAccessLog, UserKeySet
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def store_user_key_set(
    db: AsyncSession,
    user_id: UUID,
    tenant_id: UUID | None,
    encrypted_private_key: bytes,
    private_key_nonce: bytes,
    encrypted_vault_key: bytes,
    vault_key_nonce: bytes,
    public_key: bytes,
    pbkdf2_salt: bytes,
    hkdf_salt: bytes,
    pbkdf2_iterations: int = 650000,
) -> UserKeySet:
    """Persist a user's encrypted key bundle created at registration.

    Each user owns exactly one key set (UNIQUE constraint on user_id), so
    any leftover row from an earlier, failed attempt is removed first.

    Args:
        db: Async database session.
        user_id: The user's UUID.
        tenant_id: The user's tenant UUID (None for super_admin).
        encrypted_private_key: RSA private key wrapped by AUK (AES-GCM).
        private_key_nonce: 12-byte AES-GCM nonce for private key.
        encrypted_vault_key: Tenant vault key wrapped by user's public key.
        vault_key_nonce: 12-byte AES-GCM nonce for vault key.
        public_key: RSA-2048 public key in SPKI format.
        pbkdf2_salt: 32-byte salt for PBKDF2 key derivation.
        hkdf_salt: 32-byte salt for HKDF Secret Key derivation.
        pbkdf2_iterations: PBKDF2 iteration count (default 650000).

    Returns:
        The created UserKeySet instance.
    """
    from sqlalchemy import delete

    # Clear any stale row (e.g. from a failed prior upgrade attempt) so the
    # UNIQUE(user_id) constraint cannot reject the insert below.
    await db.execute(delete(UserKeySet).where(UserKeySet.user_id == user_id))

    record = UserKeySet(
        user_id=user_id,
        tenant_id=tenant_id,
        encrypted_private_key=encrypted_private_key,
        private_key_nonce=private_key_nonce,
        encrypted_vault_key=encrypted_vault_key,
        vault_key_nonce=vault_key_nonce,
        public_key=public_key,
        pbkdf2_salt=pbkdf2_salt,
        hkdf_salt=hkdf_salt,
        pbkdf2_iterations=pbkdf2_iterations,
    )
    db.add(record)
    # Flush (not commit) so the caller keeps control of the transaction.
    await db.flush()
    return record
|
||||
|
||||
|
||||
async def get_user_key_set(
    db: AsyncSession, user_id: UUID
) -> UserKeySet | None:
    """Fetch the encrypted key bundle for a user (used in the login response).

    Args:
        db: Async database session.
        user_id: The user's UUID.

    Returns:
        The UserKeySet if one exists for this user, otherwise None.
    """
    stmt = select(UserKeySet).where(UserKeySet.user_id == user_id)
    result = await db.execute(stmt)
    return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def log_key_access(
    db: AsyncSession,
    tenant_id: UUID,
    user_id: UUID | None,
    action: str,
    resource_type: str | None = None,
    resource_id: str | None = None,
    key_version: int | None = None,
    ip_address: str | None = None,
    device_id: UUID | None = None,
    justification: str | None = None,
    correlation_id: str | None = None,
) -> None:
    """Record one entry in the append-only key_access_log.

    The table permits INSERT+SELECT only (enforced via RLS policy);
    rows are never updated or deleted.

    Args:
        db: Async database session.
        tenant_id: The tenant UUID for RLS isolation.
        user_id: The user who performed the action (None for system ops).
        action: Action description (e.g., 'create_key_set', 'decrypt_vault_key').
        resource_type: Optional resource type being accessed.
        resource_id: Optional resource identifier.
        key_version: Optional key version involved.
        ip_address: Optional client IP address.
        device_id: Optional device UUID for credential access tracking.
        justification: Optional justification for the access (e.g., 'api_backup').
        correlation_id: Optional correlation ID for request tracing.
    """
    entry = KeyAccessLog(
        tenant_id=tenant_id,
        user_id=user_id,
        action=action,
        resource_type=resource_type,
        resource_id=resource_id,
        key_version=key_version,
        ip_address=ip_address,
        device_id=device_id,
        justification=justification,
        correlation_id=correlation_id,
    )
    db.add(entry)
    # Flush (not commit) so the entry joins the caller's transaction.
    await db.flush()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenBao Transit tenant key provisioning and credential migration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def provision_tenant_key(db: AsyncSession, tenant_id: UUID) -> str:
    """Provision an OpenBao Transit key for a tenant and update the tenant record.

    Idempotent: if the key already exists in OpenBao, it's a no-op on the
    OpenBao side. The tenant record is always updated with the key name.

    Args:
        db: Async database session (admin engine, no RLS).
        tenant_id: Tenant UUID.

    Returns:
        The key name (tenant_{uuid}).
    """
    from app.models.tenant import Tenant
    from app.services.openbao_service import get_openbao_service

    key_name = f"tenant_{tenant_id}"

    # Create (or confirm) the Transit key in OpenBao first.
    openbao = get_openbao_service()
    await openbao.create_tenant_key(str(tenant_id))

    # Record the key name on the tenant row, if the tenant still exists.
    lookup = await db.execute(
        select(Tenant).where(Tenant.id == tenant_id)
    )
    tenant = lookup.scalar_one_or_none()
    if tenant is not None:
        tenant.openbao_key_name = key_name
        await db.flush()

    logger.info(
        "Provisioned OpenBao Transit key for tenant %s (key=%s)",
        tenant_id,
        key_name,
    )
    return key_name
|
||||
|
||||
|
||||
async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
    """Re-encrypt all legacy credentials for a tenant from AES-256-GCM to Transit.

    Migrates device credentials, CA private keys, device cert private keys,
    and notification channel secrets. Already-migrated items are skipped
    (selection filters require the legacy column to be set and the Transit
    column to be NULL or empty). Failures are counted, not raised, so one
    bad item cannot abort the rest of the migration.

    Args:
        db: Async database session (admin engine, no RLS).
        tenant_id: Tenant UUID.

    Returns:
        Dict with counts: {"devices": N, "cas": N, "certs": N, "channels": N, "errors": N}
    """
    # Local imports avoid circular imports at module load time.
    from app.config import settings
    from app.models.alert import NotificationChannel
    from app.models.certificate import CertificateAuthority, DeviceCertificate
    from app.models.device import Device
    from app.services.crypto import decrypt_credentials
    from app.services.openbao_service import get_openbao_service

    openbao = get_openbao_service()
    # Legacy AES-256-GCM key used to decrypt the old ciphertexts.
    legacy_key = settings.get_encryption_key_bytes()
    tid = str(tenant_id)

    counts = {"devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0}

    # --- Migrate device credentials ---
    # Only rows with legacy ciphertext and no Transit ciphertext yet.
    result = await db.execute(
        select(Device).where(
            Device.tenant_id == tenant_id,
            Device.encrypted_credentials.isnot(None),
            (Device.encrypted_credentials_transit.is_(None) | (Device.encrypted_credentials_transit == "")),
        )
    )
    for device in result.scalars().all():
        try:
            # Decrypt with the legacy key, re-encrypt via the tenant's Transit key.
            plaintext = decrypt_credentials(device.encrypted_credentials, legacy_key)
            device.encrypted_credentials_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
            counts["devices"] += 1
        except Exception as e:
            logger.error("Failed to migrate device %s credentials: %s", device.id, e)
            counts["errors"] += 1

    # --- Migrate CA private keys ---
    result = await db.execute(
        select(CertificateAuthority).where(
            CertificateAuthority.tenant_id == tenant_id,
            CertificateAuthority.encrypted_private_key.isnot(None),
            (CertificateAuthority.encrypted_private_key_transit.is_(None) | (CertificateAuthority.encrypted_private_key_transit == "")),
        )
    )
    for ca in result.scalars().all():
        try:
            plaintext = decrypt_credentials(ca.encrypted_private_key, legacy_key)
            ca.encrypted_private_key_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
            counts["cas"] += 1
        except Exception as e:
            logger.error("Failed to migrate CA %s private key: %s", ca.id, e)
            counts["errors"] += 1

    # --- Migrate device cert private keys ---
    result = await db.execute(
        select(DeviceCertificate).where(
            DeviceCertificate.tenant_id == tenant_id,
            DeviceCertificate.encrypted_private_key.isnot(None),
            (DeviceCertificate.encrypted_private_key_transit.is_(None) | (DeviceCertificate.encrypted_private_key_transit == "")),
        )
    )
    for cert in result.scalars().all():
        try:
            plaintext = decrypt_credentials(cert.encrypted_private_key, legacy_key)
            cert.encrypted_private_key_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
            counts["certs"] += 1
        except Exception as e:
            logger.error("Failed to migrate cert %s private key: %s", cert.id, e)
            counts["errors"] += 1

    # --- Migrate notification channel secrets ---
    # All channels are fetched; per-field checks below skip migrated secrets.
    result = await db.execute(
        select(NotificationChannel).where(
            NotificationChannel.tenant_id == tenant_id,
        )
    )
    for ch in result.scalars().all():
        migrated_any = False
        try:
            # SMTP password
            # NOTE(review): only the SMTP password is migrated here — confirm
            # channels carry no other legacy-encrypted secret fields.
            if ch.smtp_password and not ch.smtp_password_transit:
                plaintext = decrypt_credentials(ch.smtp_password, legacy_key)
                ch.smtp_password_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
                migrated_any = True
            if migrated_any:
                counts["channels"] += 1
        except Exception as e:
            logger.error("Failed to migrate channel %s secrets: %s", ch.id, e)
            counts["errors"] += 1

    # Flush (not commit): the caller owns the transaction boundary.
    await db.flush()

    logger.info(
        "Tenant %s credential migration complete: %s",
        tenant_id,
        counts,
    )
    return counts
|
||||
|
||||
|
||||
async def provision_existing_tenants(db: AsyncSession) -> dict:
    """Provision OpenBao Transit keys for all existing tenants and migrate credentials.

    Called on app startup to ensure all tenants have Transit keys.
    Idempotent -- running multiple times is safe (already-migrated items are skipped).

    Args:
        db: Async database session (admin engine, no RLS).

    Returns:
        Summary dict with total counts across all tenants.
    """
    from app.models.tenant import Tenant

    tenants = (await db.execute(select(Tenant))).scalars().all()

    total = {"tenants": len(tenants), "devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0}

    for tenant in tenants:
        try:
            await provision_tenant_key(db, tenant.id)
            per_tenant = await migrate_tenant_credentials(db, tenant.id)
            # Fold this tenant's counters into the grand total.
            for key in ("devices", "cas", "certs", "channels", "errors"):
                total[key] += per_tenant[key]
        except Exception as e:
            # One failing tenant must not abort provisioning for the rest.
            logger.error("Failed to provision/migrate tenant %s: %s", tenant.id, e)
            total["errors"] += 1

    await db.commit()

    logger.info("Existing tenant provisioning complete: %s", total)
    return total
|
||||
346
backend/app/services/metrics_subscriber.py
Normal file
346
backend/app/services/metrics_subscriber.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""NATS JetStream subscriber for device metrics events.
|
||||
|
||||
Subscribes to device.metrics.> and inserts into TimescaleDB hypertables:
|
||||
- interface_metrics — per-interface rx/tx byte counters
|
||||
- health_metrics — CPU, memory, disk, temperature per device
|
||||
- wireless_metrics — per-wireless-interface aggregated client stats
|
||||
|
||||
Also maintains denormalized last_cpu_load and last_memory_used_pct columns
|
||||
on the devices table for efficient fleet table display.
|
||||
|
||||
Uses AdminAsyncSessionLocal (superuser bypass RLS) so metrics from any tenant
|
||||
can be written without setting app.current_tenant.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import nats
|
||||
from nats.js import JetStreamContext
|
||||
from nats.aio.client import Client as NATSClient
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_metrics_client: Optional[NATSClient] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# INSERT HANDLERS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _parse_timestamp(val: str | None) -> datetime:
|
||||
"""Parse an ISO 8601 / RFC 3339 timestamp string into a datetime object."""
|
||||
if not val:
|
||||
return datetime.now(timezone.utc)
|
||||
try:
|
||||
return datetime.fromisoformat(val.replace("Z", "+00:00"))
|
||||
except (ValueError, AttributeError):
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
async def _insert_health_metrics(session, data: dict) -> None:
    """Insert a health metrics event into health_metrics and update devices.

    Expects ``data`` to carry a 'health' dict plus 'device_id', 'tenant_id'
    and 'collected_at'. Events without a 'health' payload are logged and
    dropped. Also refreshes the denormalized last_cpu_load /
    last_memory_used_pct columns on the devices row.
    """
    health = data.get("health")
    if not health:
        logger.warning("health metrics event missing 'health' field — skipping")
        return

    device_id = data.get("device_id")
    tenant_id = data.get("tenant_id")
    collected_at = _parse_timestamp(data.get("collected_at"))

    # Parse numeric values; treat empty strings as NULL.
    def parse_int(val: str | None) -> int | None:
        # Empty string / None / non-numeric all map to None (SQL NULL).
        if not val:
            return None
        try:
            return int(val)
        except (ValueError, TypeError):
            return None

    cpu_load = parse_int(health.get("cpu_load"))
    free_memory = parse_int(health.get("free_memory"))
    total_memory = parse_int(health.get("total_memory"))
    free_disk = parse_int(health.get("free_disk"))
    total_disk = parse_int(health.get("total_disk"))
    temperature = parse_int(health.get("temperature"))

    await session.execute(
        text("""
            INSERT INTO health_metrics
                (time, device_id, tenant_id, cpu_load, free_memory, total_memory,
                 free_disk, total_disk, temperature)
            VALUES
                (:time, :device_id, :tenant_id, :cpu_load, :free_memory, :total_memory,
                 :free_disk, :total_disk, :temperature)
        """),
        {
            "time": collected_at,
            "device_id": device_id,
            "tenant_id": tenant_id,
            "cpu_load": cpu_load,
            "free_memory": free_memory,
            "total_memory": total_memory,
            "free_disk": free_disk,
            "total_disk": total_disk,
            "temperature": temperature,
        },
    )

    # Update denormalized columns on devices for fleet table display.
    # Compute memory percentage in Python to avoid asyncpg type ambiguity.
    mem_pct = None
    if total_memory and total_memory > 0 and free_memory is not None:
        mem_pct = round((1.0 - free_memory / total_memory) * 100)

    # COALESCE keeps the previous value when this event lacked a reading.
    await session.execute(
        text("""
            UPDATE devices SET
                last_cpu_load = COALESCE(:cpu_load, last_cpu_load),
                last_memory_used_pct = COALESCE(:mem_pct, last_memory_used_pct),
                updated_at = NOW()
            WHERE id = CAST(:device_id AS uuid)
        """),
        {
            "cpu_load": cpu_load,
            "mem_pct": mem_pct,
            "device_id": device_id,
        },
    )
|
||||
|
||||
|
||||
async def _insert_interface_metrics(session, data: dict) -> None:
    """Insert per-interface traffic counters into interface_metrics.

    One row is written per interface entry in the event; rx_bps/tx_bps are
    stored as NULL here.
    """
    interfaces = data.get("interfaces")
    if not interfaces:
        # Device may have no interfaces (unlikely but safe to skip)
        return

    device_id = data.get("device_id")
    tenant_id = data.get("tenant_id")
    collected_at = _parse_timestamp(data.get("collected_at"))

    for nic in interfaces:
        params = {
            "time": collected_at,
            "device_id": device_id,
            "tenant_id": tenant_id,
            "interface": nic.get("name"),
            "rx_bytes": nic.get("rx_bytes"),
            "tx_bytes": nic.get("tx_bytes"),
        }
        await session.execute(
            text("""
                INSERT INTO interface_metrics
                    (time, device_id, tenant_id, interface, rx_bytes, tx_bytes, rx_bps, tx_bps)
                VALUES
                    (:time, :device_id, :tenant_id, :interface, :rx_bytes, :tx_bytes, NULL, NULL)
            """),
            params,
        )
|
||||
|
||||
|
||||
async def _insert_wireless_metrics(session, data: dict) -> None:
    """Insert per-wireless-interface aggregated client stats into wireless_metrics.

    One row is written per wireless-interface entry in the event.
    """
    wireless = data.get("wireless")
    if not wireless:
        # Device may have no wireless interfaces
        return

    device_id = data.get("device_id")
    tenant_id = data.get("tenant_id")
    collected_at = _parse_timestamp(data.get("collected_at"))

    for entry in wireless:
        params = {
            "time": collected_at,
            "device_id": device_id,
            "tenant_id": tenant_id,
            "interface": entry.get("interface"),
            "client_count": entry.get("client_count"),
            "avg_signal": entry.get("avg_signal"),
            "ccq": entry.get("ccq"),
            "frequency": entry.get("frequency"),
        }
        await session.execute(
            text("""
                INSERT INTO wireless_metrics
                    (time, device_id, tenant_id, interface, client_count, avg_signal, ccq, frequency)
                VALUES
                    (:time, :device_id, :tenant_id, :interface,
                     :client_count, :avg_signal, :ccq, :frequency)
            """),
            params,
        )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MAIN MESSAGE HANDLER
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def on_device_metrics(msg) -> None:
    """Handle a device.metrics event published by the Go poller.

    Dispatches to the appropriate insert handler based on the 'type' field:
    - "health" → _insert_health_metrics + update devices
    - "interfaces" → _insert_interface_metrics
    - "wireless" → _insert_wireless_metrics

    On success, acknowledges the message. On error, NAKs so NATS can redeliver.
    Malformed or unknown-type events are ACKed (not NAKed) so they are not
    redelivered forever.
    """
    try:
        data = json.loads(msg.data)
        metric_type = data.get("type")
        device_id = data.get("device_id")

        if not metric_type or not device_id:
            # Unusable event — ACK to drop it from the stream.
            logger.warning(
                "device.metrics event missing 'type' or 'device_id' — skipping"
            )
            await msg.ack()
            return

        # Admin session: metrics arrive for any tenant, so RLS is bypassed.
        async with AdminAsyncSessionLocal() as session:
            if metric_type == "health":
                await _insert_health_metrics(session, data)
            elif metric_type == "interfaces":
                await _insert_interface_metrics(session, data)
            elif metric_type == "wireless":
                await _insert_wireless_metrics(session, data)
            else:
                logger.warning("Unknown metric type '%s' — skipping", metric_type)
                await msg.ack()
                return

            await session.commit()

        # Alert evaluation — non-fatal; metric write is the primary operation
        try:
            from app.services import alert_evaluator
            await alert_evaluator.evaluate(
                device_id=device_id,
                tenant_id=data.get("tenant_id", ""),
                metric_type=metric_type,
                data=data,
            )
        except Exception as eval_err:
            logger.warning("Alert evaluation failed for device %s: %s", device_id, eval_err)

        logger.debug(
            "device.metrics processed",
            extra={"device_id": device_id, "type": metric_type},
        )
        # ACK only after the DB commit succeeded.
        await msg.ack()

    except Exception as exc:
        logger.error(
            "Failed to process device.metrics event: %s",
            exc,
            exc_info=True,
        )
        try:
            # NAK asks NATS to redeliver the message.
            await msg.nak()
        except Exception:
            pass  # If NAK also fails, NATS will redeliver after ack_wait
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SUBSCRIPTION SETUP
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def _subscribe_with_retry(js: JetStreamContext) -> None:
    """Subscribe to device.metrics.> with durable consumer, retrying if stream not ready.

    Gives up (without raising) after the final attempt so the API can still
    start without metrics ingestion.
    """
    max_attempts = 6  # ~30 seconds at 5s intervals
    attempt = 0
    while attempt < max_attempts:
        attempt += 1
        try:
            await js.subscribe(
                "device.metrics.>",
                cb=on_device_metrics,
                durable="api-metrics-consumer",
                stream="DEVICE_EVENTS",
            )
        except Exception as exc:
            if attempt == max_attempts:
                logger.warning(
                    "NATS: giving up on device.metrics.> after %d attempts: %s — API will run without metrics ingestion",
                    max_attempts,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready for metrics (attempt %d/%d): %s — retrying in 5s",
                attempt,
                max_attempts,
                exc,
            )
            await asyncio.sleep(5)
        else:
            logger.info(
                "NATS: subscribed to device.metrics.> (durable: api-metrics-consumer)"
            )
            return
|
||||
|
||||
|
||||
async def start_metrics_subscriber() -> Optional[NATSClient]:
    """Connect to NATS and start the device.metrics.> subscription.

    Uses a separate NATS connection from the status subscriber — simpler and
    NATS handles multiple connections per client efficiently.

    Returns the NATS connection (must be passed to stop_metrics_subscriber on shutdown).
    Raises on fatal connection errors after retry exhaustion.
    """
    global _metrics_client

    logger.info("NATS metrics: connecting to %s", settings.NATS_URL)

    nc = await nats.connect(
        settings.NATS_URL,
        max_reconnect_attempts=-1,  # reconnect forever after transient drops
        reconnect_time_wait=2,  # seconds between reconnect attempts
        error_cb=_on_error,
        reconnected_cb=_on_reconnected,
        disconnected_cb=_on_disconnected,
    )

    logger.info("NATS metrics: connected to %s", settings.NATS_URL)

    # Subscription failure is non-fatal: _subscribe_with_retry logs and
    # returns instead of raising, so startup continues.
    js = nc.jetstream()
    await _subscribe_with_retry(js)

    _metrics_client = nc
    return nc
|
||||
|
||||
|
||||
async def stop_metrics_subscriber(nc: Optional[NATSClient]) -> None:
    """Drain and close the metrics NATS connection gracefully.

    A None connection is a no-op; if draining fails, fall back to a hard
    close and swallow any secondary error.
    """
    if nc is None:
        return
    try:
        logger.info("NATS metrics: draining connection...")
        await nc.drain()
        logger.info("NATS metrics: connection closed")
    except Exception as exc:
        logger.warning("NATS metrics: error during drain: %s", exc)
        try:
            await nc.close()
        except Exception:
            # Best-effort shutdown: nothing useful left to do.
            pass
|
||||
|
||||
|
||||
async def _on_error(exc: Exception) -> None:
    # NATS client error callback — surfaces asynchronous client errors in logs.
    logger.error("NATS metrics error: %s", exc)


async def _on_reconnected() -> None:
    # Fired when the client re-establishes a previously dropped connection.
    logger.info("NATS metrics: reconnected")


async def _on_disconnected() -> None:
    # Fired when the connection drops; the client keeps retrying on its own.
    logger.warning("NATS metrics: disconnected")
|
||||
231
backend/app/services/nats_subscriber.py
Normal file
231
backend/app/services/nats_subscriber.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""NATS JetStream subscriber for device status events from the Go poller.
|
||||
|
||||
Subscribes to device.status.> and updates device records in PostgreSQL.
|
||||
This is a system-level process that needs to update devices across all tenants,
|
||||
so it uses the admin engine (bypasses RLS).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import nats
|
||||
from nats.js import JetStreamContext
|
||||
from nats.aio.client import Client as NATSClient
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_nats_client: Optional[NATSClient] = None
|
||||
|
||||
# Regex for RouterOS uptime strings like "42d14h23m15s", "14h23m15s", "23m15s", "3w2d"
|
||||
_UPTIME_RE = re.compile(r"(?:(\d+)w)?(?:(\d+)d)?(?:(\d+)h)?(?:(\d+)m)?(?:(\d+)s)?")
|
||||
|
||||
|
||||
def _parse_uptime(raw: str) -> int | None:
|
||||
"""Parse a RouterOS uptime string into total seconds."""
|
||||
if not raw:
|
||||
return None
|
||||
m = _UPTIME_RE.fullmatch(raw)
|
||||
if not m:
|
||||
return None
|
||||
weeks = int(m.group(1) or 0)
|
||||
days = int(m.group(2) or 0)
|
||||
hours = int(m.group(3) or 0)
|
||||
minutes = int(m.group(4) or 0)
|
||||
seconds = int(m.group(5) or 0)
|
||||
total = weeks * 604800 + days * 86400 + hours * 3600 + minutes * 60 + seconds
|
||||
return total if total > 0 else None
|
||||
|
||||
|
||||
async def on_device_status(msg) -> None:
    """Handle a device.status event published by the Go poller.

    Updates the matching devices row (admin engine, bypasses RLS) and then
    triggers best-effort alert evaluation for online/offline transitions.
    ACKs on success; NAKs on failure so NATS redelivers.

    Payload (JSON):
        device_id (str) — UUID of the device
        tenant_id (str) — UUID of the owning tenant
        status (str) — "online" or "offline"
        routeros_version (str | None) — e.g. "7.16.2"
        major_version (int | None) — e.g. 7
        board_name (str | None) — e.g. "RB4011iGS+5HacQ2HnD"
        last_seen (str | None) — ISO-8601 timestamp
    """
    try:
        data = json.loads(msg.data)
        device_id = data.get("device_id")
        status = data.get("status")
        routeros_version = data.get("routeros_version")
        major_version = data.get("major_version")
        board_name = data.get("board_name")
        last_seen_raw = data.get("last_seen")
        # Normalize empty strings to None so COALESCE keeps prior DB values.
        serial_number = data.get("serial_number") or None
        firmware_version = data.get("firmware_version") or None
        uptime_seconds = _parse_uptime(data.get("uptime", ""))

        if not device_id or not status:
            # Unusable event — ACK so it is not redelivered forever.
            logger.warning("Received device.status event with missing device_id or status — skipping")
            await msg.ack()
            return

        # Parse timestamp in Python — asyncpg needs datetime objects, not strings
        last_seen_dt = None
        if last_seen_raw:
            try:
                last_seen_dt = datetime.fromisoformat(last_seen_raw.replace("Z", "+00:00"))
            except (ValueError, AttributeError):
                # Malformed timestamp: record "now" rather than dropping it.
                last_seen_dt = datetime.now(timezone.utc)

        async with AdminAsyncSessionLocal() as session:
            # COALESCE preserves existing column values for fields the
            # event did not include.
            await session.execute(
                text(
                    """
                    UPDATE devices SET
                        status = :status,
                        routeros_version = COALESCE(:routeros_version, routeros_version),
                        routeros_major_version = COALESCE(:major_version, routeros_major_version),
                        model = COALESCE(:board_name, model),
                        serial_number = COALESCE(:serial_number, serial_number),
                        firmware_version = COALESCE(:firmware_version, firmware_version),
                        uptime_seconds = COALESCE(:uptime_seconds, uptime_seconds),
                        last_seen = COALESCE(:last_seen, last_seen),
                        updated_at = NOW()
                    WHERE id = CAST(:device_id AS uuid)
                    """
                ),
                {
                    "status": status,
                    "routeros_version": routeros_version,
                    "major_version": major_version,
                    "board_name": board_name,
                    "serial_number": serial_number,
                    "firmware_version": firmware_version,
                    "uptime_seconds": uptime_seconds,
                    "last_seen": last_seen_dt,
                    "device_id": device_id,
                },
            )
            await session.commit()

        # Alert evaluation for offline/online status changes — non-fatal
        try:
            from app.services import alert_evaluator
            if status == "offline":
                await alert_evaluator.evaluate_offline(device_id, data.get("tenant_id", ""))
            elif status == "online":
                await alert_evaluator.evaluate_online(device_id, data.get("tenant_id", ""))
        except Exception as e:
            logger.warning("Alert evaluation failed for device %s status=%s: %s", device_id, status, e)

        logger.info(
            "Device status updated",
            extra={
                "device_id": device_id,
                "status": status,
                "routeros_version": routeros_version,
            },
        )
        # ACK only after the DB commit succeeded.
        await msg.ack()

    except Exception as exc:
        logger.error(
            "Failed to process device.status event: %s",
            exc,
            exc_info=True,
        )
        try:
            # NAK asks NATS to redeliver the message.
            await msg.nak()
        except Exception:
            pass  # If NAK also fails, NATS will redeliver after ack_wait
|
||||
|
||||
|
||||
async def _subscribe_with_retry(js: JetStreamContext) -> None:
    """Subscribe to device.status.> with durable consumer, retrying if stream not ready.

    Gives up (without raising) after the final attempt so the API can still
    start without real-time status updates.
    """
    max_attempts = 6  # ~30 seconds at 5s intervals
    attempt = 0
    while attempt < max_attempts:
        attempt += 1
        try:
            await js.subscribe(
                "device.status.>",
                cb=on_device_status,
                durable="api-status-consumer",
                stream="DEVICE_EVENTS",
            )
        except Exception as exc:
            if attempt == max_attempts:
                logger.warning(
                    "NATS: giving up on device.status.> after %d attempts: %s — API will run without real-time status updates",
                    max_attempts,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready (attempt %d/%d): %s — retrying in 5s",
                attempt,
                max_attempts,
                exc,
            )
            await asyncio.sleep(5)
        else:
            logger.info("NATS: subscribed to device.status.> (durable: api-status-consumer)")
            return
|
||||
|
||||
|
||||
async def start_nats_subscriber() -> Optional[NATSClient]:
    """Connect to NATS and start the device.status.> subscription.

    Returns the NATS connection (must be passed to stop_nats_subscriber on shutdown).
    Raises on fatal connection errors after retry exhaustion.
    """
    global _nats_client

    logger.info("NATS: connecting to %s", settings.NATS_URL)

    connection = await nats.connect(
        settings.NATS_URL,
        reconnect_time_wait=2,
        max_reconnect_attempts=-1,  # reconnect forever (pod-to-pod transient failures)
        disconnected_cb=_on_disconnected,
        reconnected_cb=_on_reconnected,
        error_cb=_on_error,
    )
    logger.info("NATS: connected to %s", settings.NATS_URL)

    # Attach the JetStream consumer; subscription retries internally.
    await _subscribe_with_retry(connection.jetstream())

    _nats_client = connection
    return connection
|
||||
|
||||
|
||||
async def stop_nats_subscriber(nc: Optional[NATSClient]) -> None:
    """Drain and close the NATS connection gracefully."""
    if nc is None:
        return
    try:
        logger.info("NATS: draining connection...")
        await nc.drain()
    except Exception as err:
        logger.warning("NATS: error during drain: %s", err)
        # Drain failed — fall back to a hard close, ignoring further errors.
        try:
            await nc.close()
        except Exception:
            pass
    else:
        logger.info("NATS: connection closed")
|
||||
|
||||
|
||||
async def _on_error(exc: Exception) -> None:
    """NATS client error callback — log the error; reconnection is handled by the client."""
    logger.error("NATS error: %s", exc)
|
||||
|
||||
|
||||
async def _on_reconnected() -> None:
    """NATS client callback — connection re-established after a drop."""
    logger.info("NATS: reconnected")
|
||||
|
||||
|
||||
async def _on_disconnected() -> None:
    """NATS client callback — connection lost; the client will retry on its own."""
    logger.warning("NATS: disconnected")
|
||||
256
backend/app/services/notification_service.py
Normal file
256
backend/app/services/notification_service.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""Email and webhook notification delivery for alert events.
|
||||
|
||||
Best-effort delivery: failures are logged but never raised.
|
||||
Each dispatch is wrapped in try/except so one failing channel
|
||||
doesn't prevent delivery to other channels.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def dispatch_notifications(
    alert_event: dict[str, Any],
    channels: list[dict[str, Any]],
    device_hostname: str,
) -> None:
    """Send notifications for an alert event to all provided channels.

    Delivery is best-effort: a failure on one channel is logged and the
    remaining channels are still attempted.

    Args:
        alert_event: Dict with alert event fields (status, severity, metric, etc.)
        channels: List of notification channel dicts
        device_hostname: Human-readable device name for messages
    """
    for channel in channels:
        try:
            # Resolve the sender for this channel type; None means unsupported.
            sender = {
                "email": _send_email,
                "webhook": _send_webhook,
                "slack": _send_slack,
            }.get(channel["channel_type"])
            if sender is None:
                logger.warning("Unknown channel type: %s", channel["channel_type"])
            else:
                await sender(channel, alert_event, device_hostname)
        except Exception as e:
            logger.warning(
                "Notification delivery failed for channel %s (%s): %s",
                channel.get("name"), channel.get("channel_type"), e,
            )
|
||||
|
||||
|
||||
async def _send_email(channel: dict, alert_event: dict, device_hostname: str) -> None:
    """Send alert notification email using per-channel SMTP config.

    Builds HTML and plaintext bodies from the alert event, resolves the SMTP
    password (OpenBao Transit ciphertext first, then the legacy Fernet column),
    and delivers via email_service.send_email.

    Args:
        channel: Notification channel dict (smtp_* fields, to_address, tenant_id, ...).
        alert_event: Alert event fields (status, severity, metric, value, threshold, ...).
        device_hostname: Human-readable device name shown in the message.
    """
    # Local import keeps email_service out of module import time.
    from app.services.email_service import SMTPConfig, send_email

    # Each field falls back to an alternate key name used by other event producers.
    severity = alert_event.get("severity", "warning")
    status = alert_event.get("status", "firing")
    rule_name = alert_event.get("rule_name") or alert_event.get("message", "Unknown Rule")
    metric = alert_event.get("metric_name") or alert_event.get("metric", "")
    value = alert_event.get("current_value") or alert_event.get("value", "")
    threshold = alert_event.get("threshold", "")

    # Header banner color keyed by severity; unknown severities get the info color.
    severity_colors = {
        "critical": "#ef4444",
        "warning": "#f59e0b",
        "info": "#38bdf8",
    }
    color = severity_colors.get(severity, "#38bdf8")
    status_label = "RESOLVED" if status == "resolved" else "FIRING"

    # NOTE(review): event fields are interpolated into the HTML unescaped —
    # assumed to be trusted, internally-generated values; confirm upstream.
    html = f"""
    <div style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; max-width: 600px; margin: 0 auto;">
      <div style="background: {color}; padding: 16px 24px; border-radius: 8px 8px 0 0;">
        <h2 style="color: #fff; margin: 0;">[{status_label}] {rule_name}</h2>
      </div>
      <div style="background: #1e293b; padding: 24px; border-radius: 0 0 8px 8px; color: #e2e8f0;">
        <table style="width: 100%; border-collapse: collapse;">
          <tr><td style="padding: 8px 0; color: #94a3b8;">Device</td><td style="padding: 8px 0;">{device_hostname}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Severity</td><td style="padding: 8px 0;">{severity.upper()}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Metric</td><td style="padding: 8px 0;">{metric}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Value</td><td style="padding: 8px 0;">{value}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Threshold</td><td style="padding: 8px 0;">{threshold}</td></tr>
        </table>
        <p style="color: #64748b; font-size: 12px; margin-top: 24px;">
          TOD — Fleet Management for MikroTik RouterOS
        </p>
      </div>
    </div>
    """

    # Plaintext alternative for clients that do not render HTML.
    plain = (
        f"[{status_label}] {rule_name}\n\n"
        f"Device: {device_hostname}\n"
        f"Severity: {severity}\n"
        f"Metric: {metric}\n"
        f"Value: {value}\n"
        f"Threshold: {threshold}\n"
    )

    # Decrypt SMTP password (Transit first, then legacy Fernet)
    smtp_password = None
    transit_cipher = channel.get("smtp_password_transit")
    legacy_cipher = channel.get("smtp_password")
    tenant_id = channel.get("tenant_id")

    # Preferred path: per-tenant OpenBao Transit ciphertext.
    if transit_cipher and tenant_id:
        try:
            from app.services.kms_service import decrypt_transit
            smtp_password = await decrypt_transit(transit_cipher, tenant_id)
        except Exception:
            logger.warning("Transit decryption failed for channel %s, trying legacy", channel.get("id"))

    # Fallback: legacy Fernet column encrypted with the app-wide key.
    if not smtp_password and legacy_cipher:
        try:
            from app.config import settings as app_settings
            from cryptography.fernet import Fernet
            # DB drivers may hand back a memoryview; Fernet wants bytes.
            raw = bytes(legacy_cipher) if isinstance(legacy_cipher, memoryview) else legacy_cipher
            f = Fernet(app_settings.CREDENTIAL_ENCRYPTION_KEY.encode())
            smtp_password = f.decrypt(raw).decode()
        except Exception:
            logger.warning("Legacy decryption failed for channel %s", channel.get("id"))

    # Assemble per-channel SMTP settings; sensible defaults for missing fields.
    config = SMTPConfig(
        host=channel.get("smtp_host", "localhost"),
        port=channel.get("smtp_port", 587),
        user=channel.get("smtp_user"),
        password=smtp_password,
        use_tls=channel.get("smtp_use_tls", False),
        from_address=channel.get("from_address") or "alerts@mikrotik-portal.local",
    )

    to = channel.get("to_address")
    subject = f"[TOD {status_label}] {rule_name} — {device_hostname}"
    await send_email(to, subject, html, plain, config)
|
||||
|
||||
|
||||
async def _send_webhook(
    channel: dict[str, Any],
    alert_event: dict[str, Any],
    device_hostname: str,
) -> None:
    """Send alert notification to a webhook URL (Slack-compatible JSON)."""
    # Bail out early if the channel was saved without a destination URL.
    webhook_url = channel.get("webhook_url", "")
    if not webhook_url:
        logger.warning("Webhook channel %s has no URL configured", channel.get("name"))
        return

    severity = alert_event.get("severity", "info")
    status = alert_event.get("status", "firing")
    message_text = alert_event.get("message", "")

    # Flat JSON body plus a Slack-style "text" summary line.
    payload = {
        "alert_name": message_text,
        "severity": severity,
        "status": status,
        "device": device_hostname,
        "device_id": alert_event.get("device_id"),
        "metric": alert_event.get("metric"),
        "value": alert_event.get("value"),
        "threshold": alert_event.get("threshold"),
        "timestamp": str(alert_event.get("fired_at", "")),
        "text": f"[{severity.upper()}] {device_hostname}: {message_text}",
    }

    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.post(webhook_url, json=payload)
    logger.info(
        "Webhook notification sent to %s — status %d",
        webhook_url, response.status_code,
    )
|
||||
|
||||
|
||||
async def _send_slack(
    channel: dict[str, Any],
    alert_event: dict[str, Any],
    device_hostname: str,
) -> None:
    """Send alert notification to Slack via incoming webhook with Block Kit formatting."""
    # Reject unconfigured channels before doing any formatting work.
    slack_url = channel.get("slack_webhook_url", "")
    if not slack_url:
        logger.warning("Slack channel %s has no webhook URL configured", channel.get("name"))
        return

    severity = alert_event.get("severity", "info").upper()
    status = alert_event.get("status", "firing")
    metric = alert_event.get("metric", "unknown")
    message_text = alert_event.get("message", "")
    value = alert_event.get("value")
    threshold = alert_event.get("threshold")

    # Attachment bar color by severity; gray when unrecognized.
    color = {"CRITICAL": "#dc2626", "WARNING": "#f59e0b", "INFO": "#3b82f6"}.get(severity, "#6b7280")
    status_label = "RESOLVED" if status == "resolved" else status
    icon = "✅" if status == "resolved" else "🚨"

    blocks = [
        {
            "type": "header",
            "text": {"type": "plain_text", "text": f"{icon} [{severity}] {status_label.upper()}"},
        },
        {
            "type": "section",
            "fields": [
                {"type": "mrkdwn", "text": f"*Device:*\n{device_hostname}"},
                {"type": "mrkdwn", "text": f"*Metric:*\n{metric}"},
            ],
        },
    ]

    # Optional numeric section — included only when at least one value is known.
    numeric_fields = []
    if value is not None:
        numeric_fields.append({"type": "mrkdwn", "text": f"*Value:*\n{value}"})
    if threshold is not None:
        numeric_fields.append({"type": "mrkdwn", "text": f"*Threshold:*\n{threshold}"})
    if numeric_fields:
        blocks.append({"type": "section", "fields": numeric_fields})

    if message_text:
        blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": f"*Message:*\n{message_text}"}})

    blocks.append({"type": "context", "elements": [{"type": "mrkdwn", "text": "TOD Alert System"}]})

    payload = {"attachments": [{"color": color, "blocks": blocks}]}

    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.post(slack_url, json=payload)
    logger.info("Slack notification sent — status %d", response.status_code)
|
||||
|
||||
|
||||
async def send_test_notification(channel: dict[str, Any]) -> bool:
    """Send a test notification through a channel to verify configuration.

    Args:
        channel: Notification channel dict with all config fields

    Returns:
        True on success

    Raises:
        Exception on delivery failure (caller handles)
    """
    # Synthetic event exercising the full formatting path of each sender.
    test_event = {
        "status": "test",
        "severity": "info",
        "metric": "test",
        "value": None,
        "threshold": None,
        "message": "Test notification from TOD",
        "device_id": "00000000-0000-0000-0000-000000000000",
        "fired_at": "",
    }

    kind = channel["channel_type"]
    if kind not in ("email", "webhook", "slack"):
        raise ValueError(f"Unknown channel type: {channel['channel_type']}")

    if kind == "email":
        await _send_email(channel, test_event, "Test Device")
    elif kind == "webhook":
        await _send_webhook(channel, test_event, "Test Device")
    else:
        await _send_slack(channel, test_event, "Test Device")

    return True
|
||||
174
backend/app/services/openbao_service.py
Normal file
174
backend/app/services/openbao_service.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
OpenBao Transit secrets engine client for per-tenant envelope encryption.
|
||||
|
||||
Provides encrypt/decrypt operations via OpenBao's HTTP API. Each tenant gets
|
||||
a dedicated Transit key (tenant_{uuid}) for AES-256-GCM encryption. The key
|
||||
material never leaves OpenBao -- the application only sees ciphertext.
|
||||
|
||||
Ciphertext format: "vault:v1:base64..." (compatible with Vault Transit format)
|
||||
"""
|
||||
import base64
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenBaoTransitService:
    """Async client for OpenBao Transit secrets engine.

    Each tenant has two AES-256-GCM Transit keys:
      - tenant_{uuid}        -- credential encryption
      - tenant_{uuid}_data   -- data encryption (Phase 30)

    The create/encrypt/decrypt HTTP calls were previously duplicated for the
    credential and data paths; they are factored into private per-key-name
    primitives so the two paths cannot drift apart. Key material never leaves
    OpenBao — the application only ever sees "vault:v1:..." ciphertext.
    """

    def __init__(self, addr: str | None = None, token: str | None = None):
        # Fall back to app settings when an explicit endpoint/token is not given.
        self.addr = addr or settings.OPENBAO_ADDR
        self.token = token or settings.OPENBAO_TOKEN
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Return a lazily-created HTTP client, rebuilding it after close()."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                base_url=self.addr,
                headers={"X-Vault-Token": self.token},
                timeout=5.0,
            )
        return self._client

    async def close(self) -> None:
        """Close the underlying HTTP client, if open."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    # ------------------------------------------------------------------
    # Key-name helpers and shared HTTP primitives
    # ------------------------------------------------------------------

    @staticmethod
    def _credential_key(tenant_id: str) -> str:
        """Transit key name used for tenant credentials."""
        return f"tenant_{tenant_id}"

    @staticmethod
    def _data_key(tenant_id: str) -> str:
        """Transit key name used for tenant data (Phase 30)."""
        return f"tenant_{tenant_id}_data"

    async def _create_key(self, key_name: str) -> None:
        """Create an AES-256-GCM Transit key. Idempotent.

        Raises:
            httpx.HTTPStatusError: on any status other than 200/204.
        """
        client = await self._get_client()
        resp = await client.post(
            f"/v1/transit/keys/{key_name}",
            json={"type": "aes256-gcm96"},
        )
        # 200/204 both mean "key exists now" (created or already present).
        if resp.status_code not in (200, 204):
            resp.raise_for_status()

    async def _key_exists(self, key_name: str) -> bool:
        """Return True if the named Transit key exists (GET returns 200)."""
        client = await self._get_client()
        resp = await client.get(f"/v1/transit/keys/{key_name}")
        return resp.status_code == 200

    async def _encrypt_with_key(self, key_name: str, plaintext: bytes) -> str:
        """Encrypt plaintext with the named key; returns 'vault:v1:...' ciphertext."""
        client = await self._get_client()
        resp = await client.post(
            f"/v1/transit/encrypt/{key_name}",
            # Transit requires base64-encoded plaintext.
            json={"plaintext": base64.b64encode(plaintext).decode()},
        )
        resp.raise_for_status()
        return resp.json()["data"]["ciphertext"]

    async def _decrypt_with_key(self, key_name: str, ciphertext: str) -> bytes:
        """Decrypt Transit ciphertext with the named key; returns plaintext bytes."""
        client = await self._get_client()
        resp = await client.post(
            f"/v1/transit/decrypt/{key_name}",
            json={"ciphertext": ciphertext},
        )
        resp.raise_for_status()
        # Transit returns plaintext base64-encoded.
        return base64.b64decode(resp.json()["data"]["plaintext"])

    # ------------------------------------------------------------------
    # Credential keys (tenant_{uuid})
    # ------------------------------------------------------------------

    async def create_tenant_key(self, tenant_id: str) -> None:
        """Create Transit encryption keys for a tenant (credential + data). Idempotent."""
        key_name = self._credential_key(tenant_id)
        await self._create_key(key_name)
        logger.info("OpenBao Transit key ensured", extra={"key_name": key_name})
        # Data key: tenant_{uuid}_data (Phase 30)
        await self.create_tenant_data_key(tenant_id)

    async def encrypt(self, tenant_id: str, plaintext: bytes) -> str:
        """Encrypt plaintext via Transit engine. Returns ciphertext string."""
        return await self._encrypt_with_key(self._credential_key(tenant_id), plaintext)

    async def decrypt(self, tenant_id: str, ciphertext: str) -> bytes:
        """Decrypt Transit ciphertext. Returns plaintext bytes."""
        return await self._decrypt_with_key(self._credential_key(tenant_id), ciphertext)

    async def key_exists(self, tenant_id: str) -> bool:
        """Check if a Transit credential key exists for a tenant."""
        return await self._key_exists(self._credential_key(tenant_id))

    # ------------------------------------------------------------------
    # Data encryption keys (tenant_{uuid}_data) — Phase 30
    # ------------------------------------------------------------------

    async def create_tenant_data_key(self, tenant_id: str) -> None:
        """Create a Transit data encryption key for a tenant. Idempotent."""
        key_name = self._data_key(tenant_id)
        await self._create_key(key_name)
        logger.info("OpenBao Transit data key ensured", extra={"key_name": key_name})

    async def ensure_tenant_data_key(self, tenant_id: str) -> None:
        """Ensure a data encryption key exists for a tenant. Idempotent.

        Fast path is a single GET existence check; creates the key if missing.
        """
        if not await self._key_exists(self._data_key(tenant_id)):
            await self.create_tenant_data_key(tenant_id)

    async def encrypt_data(self, tenant_id: str, plaintext: bytes) -> str:
        """Encrypt data via Transit using the per-tenant data key.

        Args:
            tenant_id: Tenant UUID string.
            plaintext: Raw bytes to encrypt.

        Returns:
            Transit ciphertext string (vault:v1:...).
        """
        return await self._encrypt_with_key(self._data_key(tenant_id), plaintext)

    async def decrypt_data(self, tenant_id: str, ciphertext: str) -> bytes:
        """Decrypt Transit data ciphertext using the per-tenant data key.

        Args:
            tenant_id: Tenant UUID string.
            ciphertext: Transit ciphertext (vault:v1:...).

        Returns:
            Decrypted plaintext bytes.
        """
        return await self._decrypt_with_key(self._data_key(tenant_id), ciphertext)
|
||||
|
||||
|
||||
# Module-level singleton
_openbao_service: Optional[OpenBaoTransitService] = None


def get_openbao_service() -> OpenBaoTransitService:
    """Return module-level OpenBao Transit service singleton."""
    global _openbao_service
    if _openbao_service is not None:
        return _openbao_service
    # First call: build and memoize the shared client.
    _openbao_service = OpenBaoTransitService()
    return _openbao_service
|
||||
141
backend/app/services/push_rollback_subscriber.py
Normal file
141
backend/app/services/push_rollback_subscriber.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""NATS subscribers for push rollback (auto) and push alert (manual).
|
||||
|
||||
- config.push.rollback.> -> auto-restore for template pushes
|
||||
- config.push.alert.> -> create alert for editor pushes
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.models.alert import AlertEvent
|
||||
from app.services import restore_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_nc: Optional[Any] = None
|
||||
|
||||
|
||||
async def _create_push_alert(device_id: str, tenant_id: str, push_type: str) -> None:
    """Create a high-priority alert for device offline after config push."""
    # Admin session bypasses tenant RLS — the event carries its own tenant_id.
    async with AdminAsyncSessionLocal() as session:
        session.add(
            AlertEvent(
                tenant_id=tenant_id,
                device_id=device_id,
                severity="critical",
                status="firing",
                message=f"Device went offline after config {push_type} — rollback available",
            )
        )
        await session.commit()
    logger.info("Created push alert for device %s (type=%s)", device_id, push_type)
|
||||
|
||||
|
||||
async def handle_push_rollback(event: dict) -> None:
    """Auto-rollback: restore device to pre-push config."""
    device_id = event.get("device_id")
    tenant_id = event.get("tenant_id")
    commit_sha = event.get("pre_push_commit_sha")

    # All three identifiers are required to attempt a restore.
    if not (device_id and tenant_id and commit_sha):
        logger.warning("Push rollback event missing fields: %s", event)
        return

    logger.warning(
        "AUTO-ROLLBACK: Device %s offline after template push, restoring to %s",
        device_id,
        commit_sha,
    )

    try:
        async with AdminAsyncSessionLocal() as session:
            result = await restore_service.restore_config(
                device_id=device_id,
                tenant_id=tenant_id,
                commit_sha=commit_sha,
                db_session=session,
            )
            await session.commit()
        logger.info(
            "Auto-rollback result for device %s: %s",
            device_id,
            result.get("status"),
        )
    except Exception as e:
        # Rollback itself failed — fall back to raising a critical alert.
        logger.error("Auto-rollback failed for device %s: %s", device_id, e)
        await _create_push_alert(device_id, tenant_id, "template (auto-rollback failed)")
|
||||
|
||||
|
||||
async def handle_push_alert(event: dict) -> None:
    """Alert: create notification for device offline after editor push."""
    device_id = event.get("device_id")
    tenant_id = event.get("tenant_id")

    # Both identifiers are mandatory; drop malformed events.
    if not (device_id and tenant_id):
        logger.warning("Push alert event missing fields: %s", event)
        return

    await _create_push_alert(device_id, tenant_id, event.get("push_type", "editor"))
|
||||
|
||||
|
||||
async def _on_rollback_message(msg) -> None:
    """NATS message handler for config.push.rollback.> subjects."""
    try:
        payload = json.loads(msg.data.decode())
        await handle_push_rollback(payload)
    except Exception as err:
        # NAK so JetStream redelivers the message later.
        logger.error("Error handling rollback message: %s", err)
        await msg.nak()
    else:
        await msg.ack()
|
||||
|
||||
|
||||
async def _on_alert_message(msg) -> None:
    """NATS message handler for config.push.alert.> subjects."""
    try:
        payload = json.loads(msg.data.decode())
        await handle_push_alert(payload)
    except Exception as err:
        # NAK so JetStream redelivers the message later.
        logger.error("Error handling push alert message: %s", err)
        await msg.nak()
    else:
        await msg.ack()
|
||||
|
||||
|
||||
async def start_push_rollback_subscriber() -> Optional[Any]:
    """Connect to NATS and subscribe to push rollback/alert events."""
    import nats

    global _nc
    try:
        logger.info("NATS push-rollback: connecting to %s", settings.NATS_URL)
        _nc = await nats.connect(settings.NATS_URL)
        js = _nc.jetstream()

        # Both durable consumers attach to the shared DEVICE_EVENTS stream.
        subscriptions = (
            ("config.push.rollback.>", _on_rollback_message, "api-push-rollback-consumer"),
            ("config.push.alert.>", _on_alert_message, "api-push-alert-consumer"),
        )
        for subject, handler, durable in subscriptions:
            await js.subscribe(
                subject,
                cb=handler,
                durable=durable,
                stream="DEVICE_EVENTS",
                manual_ack=True,
            )

        logger.info("Push rollback/alert subscriber started")
        return _nc
    except Exception as e:
        # Non-fatal: the API runs without push rollback handling.
        logger.error("Failed to start push rollback subscriber: %s", e)
        return None
|
||||
|
||||
|
||||
async def stop_push_rollback_subscriber() -> None:
    """Gracefully close the NATS connection.

    Robustness fix: previously a failed drain() propagated out of shutdown and
    left the stale module-level connection reference set. Now the reference is
    always cleared first, drain errors are logged instead of raised (matching
    the device-status subscriber's shutdown behavior), and we fall back to a
    hard close() on failure.
    """
    global _nc
    if _nc is None:
        return
    # Clear the module reference before awaiting, so a failure cannot leave
    # a half-closed connection behind for a later call to reuse.
    nc, _nc = _nc, None
    try:
        await nc.drain()
    except Exception as e:
        logger.warning("NATS push-rollback: error during drain: %s", e)
        try:
            await nc.close()
        except Exception:
            pass  # Nothing more we can do at shutdown.
|
||||
70
backend/app/services/push_tracker.py
Normal file
70
backend/app/services/push_tracker.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Track recent config pushes in Redis for poller-aware rollback.
|
||||
|
||||
When a device goes offline shortly after a push, the poller checks these
|
||||
keys and triggers rollback (template/restore) or alert (editor).
|
||||
|
||||
Redis key format: push:recent:{device_id}
|
||||
TTL: 300 seconds (5 minutes)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PUSH_TTL_SECONDS = 300 # 5 minutes
|
||||
|
||||
_redis: Optional[redis.Redis] = None
|
||||
|
||||
|
||||
async def _get_redis() -> redis.Redis:
    """Return the lazily-created module-level Redis client."""
    global _redis
    if _redis is not None:
        return _redis
    # First use: build the client from the configured URL and memoize it.
    _redis = redis.from_url(settings.REDIS_URL)
    return _redis
|
||||
|
||||
|
||||
async def record_push(
    device_id: str,
    tenant_id: str,
    push_type: str,
    push_operation_id: str = "",
    pre_push_commit_sha: str = "",
) -> None:
    """Record a recent config push in Redis.

    Args:
        device_id: UUID of the device.
        tenant_id: UUID of the tenant.
        push_type: 'template' (auto-rollback) or 'editor' (alert only) or 'restore'.
        push_operation_id: ID of the ConfigPushOperation row.
        pre_push_commit_sha: Git SHA of the pre-push backup (for rollback).
    """
    client = await _get_redis()
    payload = {
        "device_id": device_id,
        "tenant_id": tenant_id,
        "push_type": push_type,
        "push_operation_id": push_operation_id,
        "pre_push_commit_sha": pre_push_commit_sha,
    }
    # Key expires automatically so stale pushes never trigger a rollback.
    await client.set(f"push:recent:{device_id}", json.dumps(payload), ex=PUSH_TTL_SECONDS)
    logger.debug(
        "Recorded push for device %s (type=%s, TTL=%ds)",
        device_id,
        push_type,
        PUSH_TTL_SECONDS,
    )
|
||||
|
||||
|
||||
async def clear_push(device_id: str) -> None:
    """Clear the push tracking key (e.g., after successful commit)."""
    client = await _get_redis()
    key = f"push:recent:{device_id}"
    await client.delete(key)
    logger.debug("Cleared push tracking for device %s", device_id)
|
||||
572
backend/app/services/report_service.py
Normal file
572
backend/app/services/report_service.py
Normal file
@@ -0,0 +1,572 @@
|
||||
"""Report generation service.
|
||||
|
||||
Generates PDF (via Jinja2 + weasyprint) and CSV reports for:
|
||||
- Device inventory
|
||||
- Metrics summary
|
||||
- Alert history
|
||||
- Change log (audit_logs if available, else config_backups fallback)
|
||||
|
||||
Phase 30 NOTE: Reports are currently ephemeral (generated on-demand per request,
|
||||
never stored at rest). DATAENC-03 requires "report content is encrypted before
|
||||
storage." Since no report storage exists yet, encryption will be applied when
|
||||
report caching/storage is added. The generation pipeline is Transit-ready --
|
||||
wrap the file_bytes with encrypt_data_transit() before any future INSERT.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import structlog
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Jinja2 environment pointing at the templates directory
# (backend/app/templates — one level up from this services module).
_TEMPLATE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "templates")
_jinja_env = Environment(
    loader=FileSystemLoader(_TEMPLATE_DIR),
    autoescape=True,  # HTML-escape report values by default
)
|
||||
|
||||
|
||||
async def generate_report(
    db: AsyncSession,
    tenant_id: UUID,
    report_type: str,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
    fmt: str = "pdf",
) -> tuple[bytes, str, str]:
    """Generate a report and return (file_bytes, content_type, filename).

    Args:
        db: RLS-enforced async session (tenant context already set).
        tenant_id: Tenant UUID for scoping.
        report_type: One of device_inventory, metrics_summary, alert_history, change_log.
        date_from: Start date for time-ranged reports.
        date_to: End date for time-ranged reports.
        fmt: Output format -- "pdf" or "csv".

    Returns:
        Tuple of (file_bytes, content_type, filename).
    """
    t0 = time.monotonic()

    # Tenant name appears in the report header.
    tenant_name = await _get_tenant_name(db, tenant_id)

    # Look up the data-gathering routine for this report type.
    builders = {
        "device_inventory": _device_inventory,
        "metrics_summary": _metrics_summary,
        "alert_history": _alert_history,
        "change_log": _change_log,
    }
    template_data = await builders[report_type](db, tenant_id, date_from, date_to)

    generated_at = datetime.utcnow().strftime("%Y-%m-%d %H:%M UTC")
    timestamp_str = datetime.utcnow().strftime("%Y%m%d_%H%M%S")

    if fmt == "csv":
        file_bytes = _render_csv(report_type, template_data)
        content_type = "text/csv; charset=utf-8"
        filename = f"{report_type}_{timestamp_str}.csv"
    else:
        # Header context is only needed by the PDF templates.
        context = {
            "tenant_name": tenant_name,
            "generated_at": generated_at,
            **template_data,
        }
        file_bytes = _render_pdf(report_type, context)
        content_type = "application/pdf"
        filename = f"{report_type}_{timestamp_str}.pdf"

    logger.info(
        "report_generated",
        report_type=report_type,
        format=fmt,
        tenant_id=str(tenant_id),
        size_bytes=len(file_bytes),
        elapsed_seconds=round(time.monotonic() - t0, 2),
    )

    return file_bytes, content_type, filename
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tenant name helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _get_tenant_name(db: AsyncSession, tenant_id: UUID) -> str:
    """Fetch the tenant name by ID."""
    result = await db.execute(
        text("SELECT name FROM tenants WHERE id = CAST(:tid AS uuid)"),
        {"tid": str(tenant_id)},
    )
    row = result.fetchone()
    if not row:
        # No matching tenant — fall back to a placeholder for the header.
        return "Unknown Tenant"
    return row[0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Report type handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _device_inventory(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Gather device inventory data."""
    result = await db.execute(
        text("""
            SELECT d.hostname, d.ip_address, d.model, d.routeros_version,
                   d.status, d.last_seen, d.uptime_seconds,
                   COALESCE(
                       (SELECT string_agg(dg.name, ', ')
                        FROM device_group_memberships dgm
                        JOIN device_groups dg ON dg.id = dgm.group_id
                        WHERE dgm.device_id = d.id),
                       ''
                   ) AS groups
            FROM devices d
            ORDER BY d.hostname ASC
        """)
    )

    devices = []
    tally = {"online": 0, "offline": 0, "unknown": 0}

    for hostname, ip_address, model, ros_version, status, last_seen, uptime_seconds, groups in result.fetchall():
        # Anything that is neither online nor offline counts as unknown.
        bucket = status if status in ("online", "offline") else "unknown"
        tally[bucket] += 1

        devices.append({
            "hostname": hostname,
            "ip_address": ip_address,
            "model": model,
            "routeros_version": ros_version,
            "status": status,
            "last_seen": last_seen.strftime("%Y-%m-%d %H:%M") if last_seen else None,
            "uptime": _format_uptime(uptime_seconds) if uptime_seconds else None,
            "groups": groups if groups else None,
        })

    return {
        "report_title": "Device Inventory",
        "devices": devices,
        "total_devices": len(devices),
        "online_count": tally["online"],
        "offline_count": tally["offline"],
        "unknown_count": tally["unknown"],
    }
|
||||
|
||||
|
||||
async def _metrics_summary(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Collect per-device health-metric aggregates for the report window.

    Averages and peaks are computed in SQL over ``health_metrics`` rows
    between *date_from* and *date_to*; memory/disk percentages are derived
    from total/free pairs, guarding against division by zero.
    """
    query_result = await db.execute(
        text("""
            SELECT d.hostname,
                   AVG(hm.cpu_load) AS avg_cpu,
                   MAX(hm.cpu_load) AS peak_cpu,
                   AVG(CASE WHEN hm.total_memory > 0
                       THEN 100.0 * (hm.total_memory - hm.free_memory) / hm.total_memory
                       END) AS avg_mem,
                   MAX(CASE WHEN hm.total_memory > 0
                       THEN 100.0 * (hm.total_memory - hm.free_memory) / hm.total_memory
                       END) AS peak_mem,
                   AVG(CASE WHEN hm.total_disk > 0
                       THEN 100.0 * (hm.total_disk - hm.free_disk) / hm.total_disk
                       END) AS avg_disk,
                   AVG(hm.temperature) AS avg_temp,
                   COUNT(*) AS data_points
            FROM health_metrics hm
            JOIN devices d ON d.id = hm.device_id
            WHERE hm.time >= :date_from
              AND hm.time <= :date_to
            GROUP BY d.id, d.hostname
            ORDER BY avg_cpu DESC NULLS LAST
        """),
        {"date_from": date_from, "date_to": date_to},
    )

    def to_float(value: Any) -> Optional[float]:
        # Aggregates may come back as Decimal; normalize to float for templates.
        return None if value is None else float(value)

    devices = [
        {
            "hostname": row[0],
            "avg_cpu": to_float(row[1]),
            "peak_cpu": to_float(row[2]),
            "avg_mem": to_float(row[3]),
            "peak_mem": to_float(row[4]),
            "avg_disk": to_float(row[5]),
            "avg_temp": to_float(row[6]),
            "data_points": row[7],
        }
        for row in query_result.fetchall()
    ]

    return {
        "report_title": "Metrics Summary",
        "devices": devices,
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
|
||||
|
||||
|
||||
async def _alert_history(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Collect alert-history report data for the given window.

    Tallies alerts by severity and computes MTTR (mean time to resolve)
    over alerts that have a resolved_at timestamp; duration is derived in
    SQL as ``resolved_at - fired_at`` in seconds.
    """
    query_result = await db.execute(
        text("""
            SELECT ae.fired_at, ae.resolved_at, ae.severity, ae.status,
                   ae.message, d.hostname,
                   EXTRACT(EPOCH FROM (ae.resolved_at - ae.fired_at)) AS duration_secs
            FROM alert_events ae
            LEFT JOIN devices d ON d.id = ae.device_id
            WHERE ae.fired_at >= :date_from
              AND ae.fired_at <= :date_to
            ORDER BY ae.fired_at DESC
        """),
        {"date_from": date_from, "date_to": date_to},
    )

    severity_tally = {"critical": 0, "warning": 0, "info": 0}
    durations: list[float] = []
    alerts: list[dict[str, Any]] = []

    for row in query_result.fetchall():
        severity = row[2]
        # Any severity other than critical/warning is counted as info.
        bucket = severity if severity in ("critical", "warning") else "info"
        severity_tally[bucket] += 1

        secs = float(row[6]) if row[6] is not None else None
        if secs is not None:
            durations.append(secs)

        alerts.append({
            "fired_at": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
            "hostname": row[5],
            "severity": severity,
            "status": row[3],
            "message": row[4],
            "duration": _format_duration(secs) if secs is not None else None,
        })

    if durations:
        mean_secs = sum(durations) / len(durations)
        mttr_minutes: Optional[float] = round(mean_secs / 60, 1)
        mttr_display: Optional[str] = _format_duration(mean_secs)
    else:
        # No resolved alerts in the window — MTTR is undefined.
        mttr_minutes = None
        mttr_display = None

    return {
        "report_title": "Alert History",
        "alerts": alerts,
        "total_alerts": len(alerts),
        "critical_count": severity_tally["critical"],
        "warning_count": severity_tally["warning"],
        "info_count": severity_tally["info"],
        "mttr_minutes": mttr_minutes,
        "mttr_display": mttr_display,
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
|
||||
|
||||
|
||||
async def _change_log(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Gather change-log data, preferring audit_logs when available.

    The audit_logs table is created by a later migration (17-01) that may
    not have run yet; in that case fall back to synthesizing a change log
    from config backups and alert events.
    """
    if await _table_exists(db, "audit_logs"):
        return await _change_log_from_audit(db, date_from, date_to)
    return await _change_log_from_backups(db, date_from, date_to)
|
||||
|
||||
|
||||
async def _table_exists(db: AsyncSession, table_name: str) -> bool:
    """Return True when *table_name* exists in the public schema."""
    existence_query = text("""
        SELECT EXISTS (
            SELECT 1 FROM information_schema.tables
            WHERE table_schema = 'public' AND table_name = :table_name
        )
    """)
    check = await db.execute(existence_query, {"table_name": table_name})
    return bool(check.scalar())
|
||||
|
||||
|
||||
async def _change_log_from_audit(
    db: AsyncSession,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Build the change-log report context from the audit_logs table."""
    audit_result = await db.execute(
        text("""
            SELECT al.created_at, u.name AS user_name, al.action,
                   d.hostname, al.resource_type,
                   al.details
            FROM audit_logs al
            LEFT JOIN users u ON u.id = al.user_id
            LEFT JOIN devices d ON d.id = al.device_id
            WHERE al.created_at >= :date_from
              AND al.created_at <= :date_to
            ORDER BY al.created_at DESC
        """),
        {"date_from": date_from, "date_to": date_to},
    )

    entries = [
        {
            "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
            "user": row[1],
            "action": row[2],
            "device": row[3],
            # NOTE(review): this prefers resource_type (row[4]) over the
            # details column (row[5]) — confirm that precedence is intended
            # and not an index mixup.
            "details": row[4] or row[5] or "",
        }
        for row in audit_result.fetchall()
    ]

    return {
        "report_title": "Change Log",
        "entries": entries,
        "total_entries": len(entries),
        "data_source": "Audit Logs",
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
|
||||
|
||||
|
||||
async def _change_log_from_backups(
    db: AsyncSession,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Build change log from config_backups + alert_events as fallback.

    Used when the audit_logs table does not exist yet. Both queries are
    shaped to the same five-column layout (timestamp, user, action,
    device, details) so their rows can be merged with a single loop.

    Args:
        db: Session with tenant RLS context already set.
        date_from: Inclusive lower bound of the report window.
        date_to: Inclusive upper bound of the report window.

    Returns:
        Template context dict with the merged, newest-first entry list.
    """
    # Config backups as change events.
    backup_result = await db.execute(
        text("""
            SELECT cb.created_at, 'system' AS user_name, 'config_backup' AS action,
                   d.hostname, cb.trigger_type AS details
            FROM config_backups cb
            JOIN devices d ON d.id = cb.device_id
            WHERE cb.created_at >= :date_from
              AND cb.created_at <= :date_to
        """),
        {"date_from": date_from, "date_to": date_to},
    )
    backup_rows = backup_result.fetchall()

    # Alert events as change events.
    alert_result = await db.execute(
        text("""
            SELECT ae.fired_at, 'system' AS user_name,
                   ae.severity || '_alert' AS action,
                   d.hostname, ae.message AS details
            FROM alert_events ae
            LEFT JOIN devices d ON d.id = ae.device_id
            WHERE ae.fired_at >= :date_from
              AND ae.fired_at <= :date_to
        """),
        {"date_from": date_from, "date_to": date_to},
    )
    alert_rows = alert_result.fetchall()

    # Both result sets share the same column layout, so one loop handles
    # the merge (previously two identical copy-pasted loops).
    entries = [
        {
            "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
            "user": row[1],
            "action": row[2],
            "device": row[3],
            "details": row[4] or "",
        }
        for row in backup_rows + alert_rows
    ]

    # Newest first; the fixed-width "%Y-%m-%d %H:%M" format sorts
    # lexicographically in chronological order.
    entries.sort(key=lambda e: e["timestamp"], reverse=True)

    return {
        "report_title": "Change Log",
        "entries": entries,
        "total_entries": len(entries),
        "data_source": "Backups + Alerts",
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rendering helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _render_pdf(report_type: str, context: dict[str, Any]) -> bytes:
    """Render the report's Jinja HTML template and convert it to PDF bytes."""
    # Deferred import: weasyprint is heavyweight and only needed for PDFs.
    import weasyprint

    markup = _jinja_env.get_template(f"reports/{report_type}.html").render(**context)
    return weasyprint.HTML(string=markup).write_pdf()
|
||||
|
||||
|
||||
def _render_csv(report_type: str, data: dict[str, Any]) -> bytes:
|
||||
"""Render report data as CSV bytes."""
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
|
||||
if report_type == "device_inventory":
|
||||
writer.writerow([
|
||||
"Hostname", "IP Address", "Model", "RouterOS Version",
|
||||
"Status", "Last Seen", "Uptime", "Groups",
|
||||
])
|
||||
for d in data.get("devices", []):
|
||||
writer.writerow([
|
||||
d["hostname"], d["ip_address"], d["model"] or "",
|
||||
d["routeros_version"] or "", d["status"],
|
||||
d["last_seen"] or "", d["uptime"] or "",
|
||||
d["groups"] or "",
|
||||
])
|
||||
|
||||
elif report_type == "metrics_summary":
|
||||
writer.writerow([
|
||||
"Hostname", "Avg CPU %", "Peak CPU %", "Avg Memory %",
|
||||
"Peak Memory %", "Avg Disk %", "Avg Temp", "Data Points",
|
||||
])
|
||||
for d in data.get("devices", []):
|
||||
writer.writerow([
|
||||
d["hostname"],
|
||||
f"{d['avg_cpu']:.1f}" if d["avg_cpu"] is not None else "",
|
||||
f"{d['peak_cpu']:.1f}" if d["peak_cpu"] is not None else "",
|
||||
f"{d['avg_mem']:.1f}" if d["avg_mem"] is not None else "",
|
||||
f"{d['peak_mem']:.1f}" if d["peak_mem"] is not None else "",
|
||||
f"{d['avg_disk']:.1f}" if d["avg_disk"] is not None else "",
|
||||
f"{d['avg_temp']:.1f}" if d["avg_temp"] is not None else "",
|
||||
d["data_points"],
|
||||
])
|
||||
|
||||
elif report_type == "alert_history":
|
||||
writer.writerow([
|
||||
"Timestamp", "Device", "Severity", "Message", "Status", "Duration",
|
||||
])
|
||||
for a in data.get("alerts", []):
|
||||
writer.writerow([
|
||||
a["fired_at"], a["hostname"] or "", a["severity"],
|
||||
a["message"] or "", a["status"], a["duration"] or "",
|
||||
])
|
||||
|
||||
elif report_type == "change_log":
|
||||
writer.writerow([
|
||||
"Timestamp", "User", "Action", "Device", "Details",
|
||||
])
|
||||
for e in data.get("entries", []):
|
||||
writer.writerow([
|
||||
e["timestamp"], e["user"] or "", e["action"],
|
||||
e["device"] or "", e["details"] or "",
|
||||
])
|
||||
|
||||
return output.getvalue().encode("utf-8")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Formatting utilities
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _format_uptime(seconds: int) -> str:
|
||||
"""Format uptime seconds as human-readable string."""
|
||||
days = seconds // 86400
|
||||
hours = (seconds % 86400) // 3600
|
||||
minutes = (seconds % 3600) // 60
|
||||
if days > 0:
|
||||
return f"{days}d {hours}h {minutes}m"
|
||||
elif hours > 0:
|
||||
return f"{hours}h {minutes}m"
|
||||
else:
|
||||
return f"{minutes}m"
|
||||
|
||||
|
||||
def _format_duration(seconds: float) -> str:
|
||||
"""Format a duration in seconds as a human-readable string."""
|
||||
if seconds < 60:
|
||||
return f"{int(seconds)}s"
|
||||
elif seconds < 3600:
|
||||
return f"{int(seconds // 60)}m {int(seconds % 60)}s"
|
||||
elif seconds < 86400:
|
||||
hours = int(seconds // 3600)
|
||||
mins = int((seconds % 3600) // 60)
|
||||
return f"{hours}h {mins}m"
|
||||
else:
|
||||
days = int(seconds // 86400)
|
||||
hours = int((seconds % 86400) // 3600)
|
||||
return f"{days}d {hours}h"
|
||||
599
backend/app/services/restore_service.py
Normal file
599
backend/app/services/restore_service.py
Normal file
@@ -0,0 +1,599 @@
|
||||
"""Two-phase config push with panic-revert safety for RouterOS devices.
|
||||
|
||||
This module implements the critical safety mechanism for config restoration:
|
||||
|
||||
Phase 1 — Push:
|
||||
1. Pre-backup (mandatory) — snapshot current config before any changes
|
||||
2. Install panic-revert RouterOS scheduler — auto-reverts if device becomes
|
||||
unreachable (the scheduler fires after 90s and loads the pre-push backup)
|
||||
3. Push the target config via SSH /import
|
||||
|
||||
Phase 2 — Verification (60s settle window):
|
||||
4. Wait 60s for config to settle (scheduled processes restart, etc.)
|
||||
5. Reachability check via asyncssh
|
||||
6a. Reachable — remove panic-revert scheduler; mark operation committed
|
||||
6b. Unreachable — RouterOS is auto-reverting; mark operation reverted
|
||||
|
||||
Pitfall 6 handling:
|
||||
If the API pod restarts during the 60s window, the config_push_operations
|
||||
row with status='pending_verification' serves as the recovery signal.
|
||||
On startup, recover_stale_push_operations() resolves any stale rows.
|
||||
|
||||
Security policy:
|
||||
known_hosts=None — RouterOS self-signed host keys; mirrors InsecureSkipVerify
|
||||
used in the poller's TLS connection. See Pitfall 2 in 04-RESEARCH.md.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import asyncssh
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import set_tenant_context, AdminAsyncSessionLocal
|
||||
from app.models.config_backup import ConfigPushOperation
|
||||
from app.models.device import Device
|
||||
from app.services import backup_service, git_store
|
||||
from app.services.event_publisher import publish_event
|
||||
from app.services.push_tracker import record_push, clear_push
|
||||
|
||||
# Module-level logger for the restore service.
logger = logging.getLogger(__name__)

# Name of the panic-revert scheduler installed on the RouterOS device.
# Referenced both when installing and when removing the safety net.
_PANIC_REVERT_SCHEDULER = "mikrotik-portal-panic-revert"
# Name of the pre-push binary backup saved on device flash; the scheduler
# loads "<name>.backup" if it ever fires.
_PRE_PUSH_BACKUP = "portal-pre-push"
# Name of the RSC file uploaded to the device and applied via /import.
_RESTORE_RSC = "portal-restore.rsc"
|
||||
|
||||
|
||||
async def _publish_push_progress(
    tenant_id: str,
    device_id: str,
    stage: str,
    message: str,
    push_op_id: str | None = None,
    error: str | None = None,
) -> None:
    """Publish a config-push progress event to NATS (fire-and-forget).

    The event subject is scoped per tenant and device so subscribers can
    filter; *error* is attached only when non-empty.
    """
    event = {
        "event_type": "config_push",
        "tenant_id": tenant_id,
        "device_id": device_id,
        "stage": stage,
        "message": message,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "push_operation_id": push_op_id,
    }
    if error:
        event["error"] = error
    await publish_event(f"config.push.{tenant_id}.{device_id}", event)
|
||||
|
||||
|
||||
async def restore_config(
    device_id: str,
    tenant_id: str,
    commit_sha: str,
    db_session: AsyncSession,
) -> dict:
    """Restore a device config to a specific backup version via two-phase push.

    Phase 1 (push): mandatory pre-backup, install panic-revert scheduler,
    upload + /import the target config. Phase 2 (verify): wait 60s, then
    check SSH reachability; commit (remove the scheduler) or let RouterOS
    auto-revert.

    Args:
        device_id: Device UUID as string.
        tenant_id: Tenant UUID as string.
        commit_sha: Git commit SHA of the backup version to restore.
        db_session: AsyncSession with RLS context already set (from API endpoint).

    Returns:
        {
            "status": "committed" | "reverted" | "failed",
            "message": str,
            "pre_backup_sha": str,
        }

    Raises:
        ValueError: If device not found, missing credentials, or the backup
            version cannot be read.
    """
    # get_running_loop() is the supported call inside a coroutine
    # (get_event_loop() is deprecated here since Python 3.10).
    loop = asyncio.get_running_loop()

    # ------------------------------------------------------------------
    # Step 1: Load device from DB and decrypt credentials
    # ------------------------------------------------------------------
    from sqlalchemy import select

    result = await db_session.execute(
        select(Device).where(Device.id == device_id)  # type: ignore[arg-type]
    )
    device = result.scalar_one_or_none()
    if device is None:
        raise ValueError(f"Device {device_id!r} not found")

    if not device.encrypted_credentials_transit and not device.encrypted_credentials:
        raise ValueError(
            f"Device {device_id!r} has no stored credentials — cannot perform restore"
        )

    key = settings.get_encryption_key_bytes()
    # Local import avoids a circular dependency at module load time.
    from app.services.crypto import decrypt_credentials_hybrid
    creds_json = await decrypt_credentials_hybrid(
        device.encrypted_credentials_transit,
        device.encrypted_credentials,
        str(device.tenant_id),
        key,
    )
    creds = json.loads(creds_json)
    ssh_username = creds.get("username", "")
    ssh_password = creds.get("password", "")
    ip = device.ip_address

    hostname = device.hostname or ip

    # Publish "started" progress event
    await _publish_push_progress(tenant_id, device_id, "started", f"Config restore started for {hostname}")

    # ------------------------------------------------------------------
    # Step 2: Read the target export.rsc from the backup commit
    # ------------------------------------------------------------------
    try:
        # git_store is synchronous; run it off the event loop.
        export_bytes = await loop.run_in_executor(
            None,
            git_store.read_file,
            tenant_id,
            commit_sha,
            device_id,
            "export.rsc",
        )
    except Exception as exc:
        # Was `except (KeyError, Exception)` — KeyError is a subclass of
        # Exception, so the tuple was redundant; any read failure maps to
        # a "backup version not found" error for the caller.
        raise ValueError(
            f"Backup version {commit_sha!r} not found for device {device_id!r}: {exc}"
        ) from exc

    export_text = export_bytes.decode("utf-8", errors="replace")

    # ------------------------------------------------------------------
    # Step 3: Mandatory pre-backup before push
    # ------------------------------------------------------------------
    await _publish_push_progress(tenant_id, device_id, "backing_up", f"Creating pre-restore backup for {hostname}")

    logger.info(
        "Starting pre-restore backup for device %s (%s) before pushing commit %s",
        hostname,
        ip,
        commit_sha[:8],
    )
    pre_backup_result = await backup_service.run_backup(
        device_id=device_id,
        tenant_id=tenant_id,
        trigger_type="pre-restore",
        db_session=db_session,
    )
    pre_backup_sha = pre_backup_result["commit_sha"]
    logger.info("Pre-restore backup complete: %s", pre_backup_sha[:8])

    # ------------------------------------------------------------------
    # Step 4: Record push operation (pending_verification for recovery)
    # ------------------------------------------------------------------
    # The pending_verification row is the Pitfall-6 recovery signal: if the
    # API pod restarts mid-window, startup recovery resolves this row.
    push_op = ConfigPushOperation(
        device_id=device.id,
        tenant_id=device.tenant_id,
        pre_push_commit_sha=pre_backup_sha,
        scheduler_name=_PANIC_REVERT_SCHEDULER,
        status="pending_verification",
    )
    db_session.add(push_op)
    await db_session.flush()
    push_op_id = push_op.id

    logger.info(
        "Push op %s in pending_verification — if API restarts, "
        "recover_stale_push_operations() will resolve on next startup",
        push_op.id,
    )

    # ------------------------------------------------------------------
    # Step 5: SSH to device — install panic-revert, push config
    # ------------------------------------------------------------------
    push_op_id_str = str(push_op_id)
    await _publish_push_progress(tenant_id, device_id, "pushing", f"Pushing config to {hostname}", push_op_id=push_op_id_str)

    logger.info(
        "Pushing config to device %s (%s): installing panic-revert scheduler and uploading config",
        hostname,
        ip,
    )

    try:
        async with asyncssh.connect(
            ip,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,  # RouterOS self-signed host keys — see module docstring
            connect_timeout=30,
        ) as conn:
            # 5a: Create binary backup on device as revert point
            await conn.run(
                f"/system backup save name={_PRE_PUSH_BACKUP} dont-encrypt=yes",
                check=True,
            )
            logger.debug("Pre-push binary backup saved on device as %s.backup", _PRE_PUSH_BACKUP)

            # 5b: Install panic-revert RouterOS scheduler.
            # It fires after 90s on startup and loads the pre-push backup —
            # the safety net if the device becomes unreachable after push.
            await conn.run(
                f"/system scheduler add "
                f'name="{_PANIC_REVERT_SCHEDULER}" '
                f"interval=90s "
                f'on-event=":delay 0; /system backup load name={_PRE_PUSH_BACKUP}" '
                f"start-time=startup",
                check=True,
            )
            logger.debug("Panic-revert scheduler installed on device")

            # 5c: Upload export.rsc via SFTP (RouterOS 7), then /import it.
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(_RESTORE_RSC, "wb") as f:
                    await f.write(export_text.encode("utf-8"))
            logger.debug("Uploaded %s to device flash", _RESTORE_RSC)

            # /import the config file
            import_result = await conn.run(
                f"/import file={_RESTORE_RSC}",
                check=False,  # Don't raise on non-zero exit — import may succeed with warnings
            )
            logger.info(
                "Config import result for device %s: exit_status=%s stdout=%r",
                hostname,
                import_result.exit_status,
                (import_result.stdout or "")[:200],
            )

            # Clean up the uploaded RSC file (best-effort)
            try:
                await conn.run(f"/file remove {_RESTORE_RSC}", check=True)
            except Exception as cleanup_err:
                logger.warning(
                    "Failed to clean up %s from device %s: %s",
                    _RESTORE_RSC,
                    ip,
                    cleanup_err,
                )

    except Exception as push_err:
        logger.error(
            "SSH push phase failed for device %s (%s): %s",
            hostname,
            ip,
            push_err,
        )
        # Update push operation to failed
        await _update_push_op_status(push_op_id, "failed", db_session)
        await _publish_push_progress(
            tenant_id, device_id, "failed",
            f"Config push failed for {hostname}: {push_err}",
            push_op_id=push_op_id_str, error=str(push_err),
        )
        return {
            "status": "failed",
            "message": f"Config push failed during SSH phase: {push_err}",
            "pre_backup_sha": pre_backup_sha,
        }

    # Record push in Redis so the poller can detect post-push offline events
    await record_push(
        device_id=device_id,
        tenant_id=tenant_id,
        push_type="restore",
        push_operation_id=push_op_id_str,
        pre_push_commit_sha=pre_backup_sha,
    )

    # ------------------------------------------------------------------
    # Step 6: Wait 60s for config to settle
    # ------------------------------------------------------------------
    await _publish_push_progress(tenant_id, device_id, "settling", f"Config pushed to {hostname} — waiting 60s for settle", push_op_id=push_op_id_str)

    logger.info(
        "Config pushed to device %s — waiting 60s for config to settle",
        hostname,
    )
    await asyncio.sleep(60)

    # ------------------------------------------------------------------
    # Step 7: Reachability check
    # ------------------------------------------------------------------
    await _publish_push_progress(tenant_id, device_id, "verifying", f"Verifying device {hostname} reachability", push_op_id=push_op_id_str)

    reachable = await _check_reachability(ip, ssh_username, ssh_password)

    if reachable:
        # ------------------------------------------------------------------
        # Step 8a: Device is reachable — remove panic-revert scheduler + cleanup
        # ------------------------------------------------------------------
        logger.info("Device %s (%s) is reachable after push — committing", hostname, ip)
        try:
            async with asyncssh.connect(
                ip,
                port=22,
                username=ssh_username,
                password=ssh_password,
                known_hosts=None,
                connect_timeout=30,
            ) as conn:
                # Remove the panic-revert scheduler
                await conn.run(
                    f'/system scheduler remove "{_PANIC_REVERT_SCHEDULER}"',
                    check=False,  # Non-fatal if already removed
                )
                # Clean up the pre-push binary backup from device flash
                await conn.run(
                    f"/file remove {_PRE_PUSH_BACKUP}.backup",
                    check=False,  # Non-fatal if already removed
                )
        except Exception as cleanup_err:
            # Cleanup failure is non-fatal — scheduler will eventually fire but
            # the backup is now the correct config, so it's acceptable.
            logger.warning(
                "Failed to clean up panic-revert scheduler/backup on device %s: %s",
                hostname,
                cleanup_err,
            )

        await _update_push_op_status(push_op_id, "committed", db_session)
        await clear_push(device_id)
        await _publish_push_progress(tenant_id, device_id, "committed", f"Config restored successfully on {hostname}", push_op_id=push_op_id_str)

        return {
            "status": "committed",
            "message": "Config restored successfully",
            "pre_backup_sha": pre_backup_sha,
        }

    else:
        # ------------------------------------------------------------------
        # Step 8b: Device unreachable — RouterOS is auto-reverting via scheduler
        # ------------------------------------------------------------------
        logger.warning(
            "Device %s (%s) is unreachable after push — RouterOS panic-revert scheduler "
            "will auto-revert to %s.backup",
            hostname,
            ip,
            _PRE_PUSH_BACKUP,
        )

        await _update_push_op_status(push_op_id, "reverted", db_session)
        await _publish_push_progress(
            tenant_id, device_id, "reverted",
            f"Device {hostname} unreachable — auto-reverting via panic-revert scheduler",
            push_op_id=push_op_id_str,
        )

        return {
            "status": "reverted",
            "message": (
                "Device unreachable after push; RouterOS is auto-reverting "
                "via panic-revert scheduler"
            ),
            "pre_backup_sha": pre_backup_sha,
        }
|
||||
|
||||
|
||||
async def _check_reachability(ip: str, username: str, password: str) -> bool:
    """Check if a RouterOS device is reachable via SSH.

    Connects and runs a trivial command (/system identity print). Any
    connection or command failure — including the 30s connect timeout —
    is treated as "unreachable", in which case the on-device panic-revert
    scheduler will handle recovery.

    Uses asyncssh (not the poller's binary API) to avoid a circular import.

    Args:
        ip: Device IP address.
        username: SSH username.
        password: SSH password.

    Returns:
        True if reachable, False if unreachable.
    """
    try:
        async with asyncssh.connect(
            ip,
            port=22,
            username=username,
            password=password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            probe = await conn.run("/system identity print", check=True)
            logger.debug("Reachability check OK for %s: %r", ip, probe.stdout[:50])
        return True
    except Exception as exc:
        logger.info("Device %s unreachable after push: %s", ip, exc)
        return False
|
||||
|
||||
|
||||
async def _update_push_op_status(
    push_op_id,
    new_status: str,
    db_session: AsyncSession,
) -> None:
    """Update the status and completed_at of a ConfigPushOperation row.

    Args:
        push_op_id: UUID of the ConfigPushOperation row.
        new_status: New status value ('committed' | 'reverted' | 'failed').
        db_session: Database session (must already have tenant context set).
    """
    # Was `from sqlalchemy import select, update` — `select` was unused.
    from sqlalchemy import update

    await db_session.execute(
        update(ConfigPushOperation)
        .where(ConfigPushOperation.id == push_op_id)  # type: ignore[arg-type]
        .values(
            status=new_status,
            completed_at=datetime.now(timezone.utc),
        )
    )
    # Don't commit here — the caller (endpoint) owns the transaction
|
||||
|
||||
|
||||
async def _remove_panic_scheduler(
    ip: str, username: str, password: str, scheduler_name: str
) -> bool:
    """SSH to the device and remove the panic-revert scheduler.

    Also removes the pre-push binary backup file when the scheduler is
    found. Returns True if the scheduler was present and removed; False
    if it was already gone (device reverted itself) or on any SSH error.
    """
    try:
        async with asyncssh.connect(
            ip,
            username=username,
            password=password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            # Probe first: only attempt removal if the scheduler exists.
            listing = await conn.run(
                f'/system scheduler print where name="{scheduler_name}"',
                check=False,
            )
            if scheduler_name not in listing.stdout:
                # Scheduler already gone — the device reverted on its own.
                return False

            await conn.run(
                f'/system scheduler remove [find name="{scheduler_name}"]',
                check=False,
            )
            # Also clean up pre-push backup file
            await conn.run(
                f'/file remove [find name="{_PRE_PUSH_BACKUP}.backup"]',
                check=False,
            )
            return True
    except Exception as e:
        logger.error("Failed to remove panic scheduler from %s: %s", ip, e)
        return False
|
||||
|
||||
|
||||
async def recover_stale_push_operations(db_session: AsyncSession) -> None:
    """Recover stale pending_verification push operations on API startup.

    Scans for operations older than 5 minutes that are still pending.
    For each, checks device reachability and resolves the operation.

    Resolution policy (per operation):
    - device unreachable        -> status "failed"
    - reachable, scheduler kept -> removed here, status "committed"
    - reachable, scheduler gone -> device may have self-reverted, but since
      it is reachable we still mark "committed" (logged as a warning)

    Args:
        db_session: Async session with tenant context already set; this
            function issues the final commit for all status updates.
    """
    from sqlalchemy import select

    # Local imports avoid circular-import issues at module load time.
    from app.models.config_backup import ConfigPushOperation
    from app.models.device import Device
    from app.services.crypto import decrypt_credentials_hybrid

    # Anything still pending after 5 minutes predates the last API restart.
    cutoff = datetime.now(timezone.utc) - timedelta(minutes=5)

    result = await db_session.execute(
        select(ConfigPushOperation).where(
            ConfigPushOperation.status == "pending_verification",
            ConfigPushOperation.started_at < cutoff,
        )
    )
    stale_ops = result.scalars().all()

    if not stale_ops:
        logger.info("No stale push operations to recover")
        return

    logger.warning("Found %d stale push operations to recover", len(stale_ops))

    key = settings.get_encryption_key_bytes()

    for op in stale_ops:
        try:
            # Load device
            dev_result = await db_session.execute(
                select(Device).where(Device.id == op.device_id)
            )
            device = dev_result.scalar_one_or_none()
            if not device:
                logger.error("Device %s not found for stale op %s", op.device_id, op.id)
                await _update_push_op_status(op.id, "failed", db_session)
                continue

            # Decrypt credentials
            creds_json = await decrypt_credentials_hybrid(
                device.encrypted_credentials_transit,
                device.encrypted_credentials,
                str(op.tenant_id),
                key,
            )
            creds = json.loads(creds_json)
            ssh_username = creds.get("username", "admin")
            ssh_password = creds.get("password", "")

            # Check reachability
            reachable = await _check_reachability(
                device.ip_address, ssh_username, ssh_password
            )

            if reachable:
                # Try to remove scheduler (if still there, push was good)
                removed = await _remove_panic_scheduler(
                    device.ip_address,
                    ssh_username,
                    ssh_password,
                    op.scheduler_name,
                )
                if removed:
                    logger.info("Recovery: committed op %s (scheduler removed)", op.id)
                else:
                    # Scheduler already gone — device may have reverted
                    logger.warning(
                        "Recovery: op %s — scheduler gone, device may have reverted. "
                        "Marking committed (device is reachable).",
                        op.id,
                    )
                # Committed in both sub-cases: reachability is the deciding signal.
                await _update_push_op_status(op.id, "committed", db_session)

                await _publish_push_progress(
                    str(op.tenant_id),
                    str(op.device_id),
                    "committed",
                    "Recovered after API restart",
                    push_op_id=str(op.id),
                )
            else:
                logger.warning(
                    "Recovery: device %s unreachable, marking op %s failed",
                    op.device_id,
                    op.id,
                )
                await _update_push_op_status(op.id, "failed", db_session)
                await _publish_push_progress(
                    str(op.tenant_id),
                    str(op.device_id),
                    "failed",
                    "Device unreachable during recovery after API restart",
                    push_op_id=str(op.id),
                )

        except Exception as e:
            # One bad operation must not abort recovery of the others.
            logger.error("Recovery failed for op %s: %s", op.id, e)
            await _update_push_op_status(op.id, "failed", db_session)

    # Single commit covers every status update made above.
    await db_session.commit()
||||
165
backend/app/services/routeros_proxy.py
Normal file
165
backend/app/services/routeros_proxy.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""RouterOS command proxy via NATS request-reply.
|
||||
|
||||
Sends command requests to the Go poller's CmdResponder subscription
|
||||
(device.cmd.{device_id}) and returns structured RouterOS API response data.
|
||||
|
||||
Used by:
|
||||
- Config editor API (browse menu paths, add/edit/delete entries)
|
||||
- Template push service (execute rendered template commands)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import nats
|
||||
import nats.aio.client
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level NATS connection (lazy initialized)
|
||||
_nc: nats.aio.client.Client | None = None
|
||||
|
||||
|
||||
async def _get_nats() -> nats.aio.client.Client:
    """Return the shared NATS client, establishing it on first use.

    The connection is cached at module level and transparently
    re-created if a previously cached one has been closed.
    """
    global _nc
    needs_connect = _nc is None or _nc.is_closed
    if needs_connect:
        _nc = await nats.connect(settings.NATS_URL)
        logger.info("RouterOS proxy NATS connection established")
    return _nc
|
||||
async def execute_command(
    device_id: str,
    command: str,
    args: list[str] | None = None,
    timeout: float = 15.0,
) -> dict[str, Any]:
    """Execute a RouterOS API command on a device via the Go poller.

    Args:
        device_id: UUID string of the target device.
        command: Full RouterOS API path, e.g. "/ip/address/print".
        args: Optional list of RouterOS API args, e.g. ["=.proplist=.id,address"].
        timeout: NATS request timeout in seconds (default 15s).

    Returns:
        {"success": bool, "data": list[dict], "error": str|None}
        Failures are never raised — they are normalized into this shape.
    """
    nc = await _get_nats()
    # Payload shape must match the Go poller's CmdResponder request schema.
    request = {
        "device_id": device_id,
        "command": command,
        "args": args or [],
    }

    try:
        # Request-reply on the per-device subject; the poller serving this
        # device answers with the structured RouterOS API response.
        reply = await nc.request(
            f"device.cmd.{device_id}",
            json.dumps(request).encode(),
            timeout=timeout,
        )
        return json.loads(reply.data)
    except nats.errors.TimeoutError:
        # No responder answered in time: device offline or poller not subscribed.
        return {
            "success": False,
            "data": [],
            "error": "Device command timed out — device may be offline or unreachable",
        }
    except Exception as exc:
        logger.error("NATS request failed for device %s: %s", device_id, exc)
        return {"success": False, "data": [], "error": str(exc)}
|
||||
|
||||
async def browse_menu(device_id: str, path: str) -> dict[str, Any]:
    """List every entry under a RouterOS menu path.

    Args:
        device_id: Device UUID string.
        path: RouterOS menu path, e.g. "/ip/address" or "/interface".

    Returns:
        {"success": bool, "data": list[dict], "error": str|None}
    """
    return await execute_command(device_id, f"{path}/print")
|
||||
|
||||
async def add_entry(
    device_id: str, path: str, properties: dict[str, str]
) -> dict[str, Any]:
    """Create a new entry under a RouterOS menu path.

    Args:
        device_id: Device UUID.
        path: Menu path, e.g. "/ip/address".
        properties: Key-value pairs for the new entry.

    Returns:
        Command response dict.
    """
    # RouterOS API word syntax: each property becomes "=key=value".
    prop_words = []
    for key, value in properties.items():
        prop_words.append(f"={key}={value}")
    return await execute_command(device_id, f"{path}/add", prop_words)
|
||||
|
||||
async def update_entry(
    device_id: str, path: str, entry_id: str | None, properties: dict[str, str]
) -> dict[str, Any]:
    """Modify an existing entry under a RouterOS menu path.

    Args:
        device_id: Device UUID.
        path: Menu path.
        entry_id: RouterOS .id value (e.g. "*1"). None for singleton paths.
        properties: Key-value pairs to update.

    Returns:
        Command response dict.
    """
    words: list[str] = []
    # Singleton paths (e.g. /system/identity) take no .id selector.
    if entry_id:
        words.append(f"=.id={entry_id}")
    words.extend(f"={k}={v}" for k, v in properties.items())
    return await execute_command(device_id, f"{path}/set", words)
|
||||
|
||||
async def remove_entry(
    device_id: str, path: str, entry_id: str
) -> dict[str, Any]:
    """Delete an entry from a RouterOS menu path.

    Args:
        device_id: Device UUID.
        path: Menu path.
        entry_id: RouterOS .id value.

    Returns:
        Command response dict.
    """
    selector = [f"=.id={entry_id}"]
    return await execute_command(device_id, f"{path}/remove", selector)
|
||||
|
||||
async def execute_cli(device_id: str, cli_command: str) -> dict[str, Any]:
    """Run an arbitrary RouterOS CLI command on a device.

    Escape hatch for commands that don't fit the /path/action helpers;
    the string is forwarded as-is to the RouterOS API.

    Args:
        device_id: Device UUID.
        cli_command: Full CLI command string.

    Returns:
        Command response dict.
    """
    return await execute_command(device_id, cli_command)
|
||||
|
||||
async def close() -> None:
    """Drain and drop the module-level NATS connection (app shutdown hook)."""
    global _nc
    if _nc is None or _nc.is_closed:
        return
    await _nc.drain()
    _nc = None
    logger.info("RouterOS proxy NATS connection closed")
||||
220
backend/app/services/rsc_parser.py
Normal file
220
backend/app/services/rsc_parser.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""RouterOS RSC export parser — extracts categories, validates syntax, computes impact."""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Command paths whose modification can disrupt connectivity or management
# access. compute_impact() flags any change under these paths as "high" risk.
HIGH_RISK_PATHS = {
    "/ip address", "/ip route", "/ip firewall filter", "/ip firewall nat",
    "/interface", "/interface bridge", "/interface vlan",
    "/system identity", "/ip service", "/ip ssh", "/user",
}

# (compiled regex, warning message) pairs matched against the *target* config
# text in compute_impact(); every hit yields one human-readable warning about
# a change that could cut off management access.
MANAGEMENT_PATTERNS = [
    (re.compile(r"chain=input.*dst-port=(22|8291|8728|8729|443|80)", re.I),
     "Modifies firewall rules for management ports (SSH/WinBox/API/Web)"),
    (re.compile(r"chain=input.*action=drop", re.I),
     "Adds drop rule on input chain — may block management access"),
    (re.compile(r"/ip service", re.I),
     "Modifies IP services — may disable API/SSH/WinBox access"),
    (re.compile(r"/user.*set.*password", re.I),
     "Changes user password — may affect automated access"),
]
|
||||
|
||||
def _join_continuation_lines(text: str) -> list[str]:
|
||||
"""Join lines ending with \\ into single logical lines."""
|
||||
lines = text.split("\n")
|
||||
joined: list[str] = []
|
||||
buf = ""
|
||||
for line in lines:
|
||||
stripped = line.rstrip()
|
||||
if stripped.endswith("\\"):
|
||||
buf += stripped[:-1].rstrip() + " "
|
||||
else:
|
||||
if buf:
|
||||
buf += stripped
|
||||
joined.append(buf)
|
||||
buf = ""
|
||||
else:
|
||||
joined.append(stripped)
|
||||
if buf:
|
||||
joined.append(buf + " <<TRUNCATED>>")
|
||||
return joined
|
||||
|
||||
|
||||
def parse_rsc(text: str) -> dict[str, Any]:
    """Parse a RouterOS /export compact output.

    Returns a dict with a "categories" list, each containing:
    - path: the RouterOS command path (e.g. "/ip address")
    - adds: count of "add" commands
    - sets: count of "set" commands
    - removes: count of "remove" commands
    - commands: list of command strings under this path
    """
    lines = _join_continuation_lines(text)
    categories: dict[str, dict] = {}
    # Most recently seen path header; commands accumulate under it.
    current_path: str | None = None

    for line in lines:
        line = line.strip()
        # Skip blanks and comment lines (export headers start with '#').
        if not line or line.startswith("#"):
            continue

        if line.startswith("/"):
            # Could be just a path header, or a path followed by a command
            parts = line.split(None, 1)
            if len(parts) == 1:
                # Pure path header like "/interface bridge"
                current_path = parts[0]
            else:
                # Check if second part starts with a known command verb
                cmd_check = parts[1].strip().split(None, 1)
                if cmd_check and cmd_check[0] in ("add", "set", "remove", "print", "enable", "disable"):
                    # Inline form "/path verb ..." — rewrite `line` to the
                    # command part so the counting code below sees it.
                    current_path = parts[0]
                    line = parts[1].strip()
                else:
                    # The whole line is a path (e.g. "/ip firewall filter")
                    current_path = line
                    continue

            if current_path and current_path not in categories:
                categories[current_path] = {
                    "path": current_path,
                    "adds": 0,
                    "sets": 0,
                    "removes": 0,
                    "commands": [],
                }

            # Pure header: category registered, nothing to count yet.
            if len(parts) == 1:
                continue

        # Command lines before any path header are unattributable — skip.
        if current_path is None:
            continue

        # Defensive re-init; normally the category was created above when the
        # header was seen — presumably unreachable in practice, TODO confirm.
        if current_path not in categories:
            categories[current_path] = {
                "path": current_path,
                "adds": 0,
                "sets": 0,
                "removes": 0,
                "commands": [],
            }

        cat = categories[current_path]
        cat["commands"].append(line)

        # Only add/set/remove are tallied; other verbs are stored uncounted.
        if line.startswith("add ") or line.startswith("add\t"):
            cat["adds"] += 1
        elif line.startswith("set "):
            cat["sets"] += 1
        elif line.startswith("remove "):
            cat["removes"] += 1

    return {"categories": list(categories.values())}
|
||||
|
||||
def validate_rsc(text: str) -> dict[str, Any]:
    """Validate RSC export syntax.

    Detects two symptoms of a truncated or corrupted export:
    - an odd number of unescaped double quotes across the whole file
    - a final line ending on a continuation backslash

    Returns dict with "valid" (bool) and "errors" (list of strings).
    """
    problems: list[str] = []

    # Quote balance across the entire file: every line containing an odd
    # number of unescaped quotes flips the open/closed state.
    open_quote = False
    for raw in text.split("\n"):
        body = raw.rstrip()
        if body.endswith("\\"):
            body = body[:-1]
        # Count unescaped quotes
        unescaped = body.count('"') - body.count('\\"')
        if unescaped % 2 != 0:
            open_quote = not open_quote

    if open_quote:
        problems.append("Unbalanced quote detected — file may be truncated")

    # A trailing continuation backslash means the export was cut mid-line.
    tail_lines = text.rstrip().split("\n")
    if tail_lines and tail_lines[-1].rstrip().endswith("\\"):
        problems.append("File ends with continuation line (\\) — truncated export")

    return {"valid": len(problems) == 0, "errors": problems}
|
||||
|
||||
def compute_impact(
    current_parsed: dict[str, Any],
    target_parsed: dict[str, Any],
) -> dict[str, Any]:
    """Compare current vs target parsed RSC and compute impact analysis.

    Both arguments are outputs of parse_rsc(). Commands are compared as
    exact strings per path — no semantic normalization is done, so an
    edited command counts as one add plus one remove.

    Returns dict with:
    - categories: list of per-path diffs with risk levels
    - warnings: list of human-readable warning strings
    - diff: summary counts (added, removed, modified)
      NOTE: "modified" is always 0 here — modification detection is not
      implemented anywhere in this function.
    """
    current_map = {c["path"]: c for c in current_parsed["categories"]}
    target_map = {c["path"]: c for c in target_parsed["categories"]}
    all_paths = sorted(set(list(current_map.keys()) + list(target_map.keys())))

    result_categories = []
    warnings: list[str] = []
    total_added = total_removed = total_modified = 0

    for path in all_paths:
        # Paths missing on one side are treated as empty categories so the
        # set arithmetic below works uniformly.
        curr = current_map.get(path, {"adds": 0, "sets": 0, "removes": 0, "commands": []})
        tgt = target_map.get(path, {"adds": 0, "sets": 0, "removes": 0, "commands": []})
        curr_cmds = set(curr.get("commands", []))
        tgt_cmds = set(tgt.get("commands", []))
        added = len(tgt_cmds - curr_cmds)
        removed = len(curr_cmds - tgt_cmds)
        total_added += added
        total_removed += removed

        has_changes = added > 0 or removed > 0
        risk = "none"
        if has_changes:
            # Any change under a connectivity-critical path is high risk.
            risk = "high" if path in HIGH_RISK_PATHS else "low"
        result_categories.append({
            "path": path,
            "adds": added,
            "removes": removed,
            "risk": risk,
        })

    # Check target commands against management patterns
    target_text = "\n".join(
        cmd for cat in target_parsed["categories"] for cmd in cat.get("commands", [])
    )
    for pattern, message in MANAGEMENT_PATTERNS:
        if pattern.search(target_text):
            warnings.append(message)

    # Warn about removed IP addresses
    if "/ip address" in current_map and "/ip address" in target_map:
        curr_addrs = current_map["/ip address"].get("commands", [])
        tgt_addrs = target_map["/ip address"].get("commands", [])
        removed_addrs = set(curr_addrs) - set(tgt_addrs)
        if removed_addrs:
            warnings.append(
                f"Removes {len(removed_addrs)} IP address(es) — verify none are management interfaces"
            )

    return {
        "categories": result_categories,
        "warnings": warnings,
        "diff": {
            "added": total_added,
            "removed": total_removed,
            "modified": total_modified,
        },
    }
||||
124
backend/app/services/scanner.py
Normal file
124
backend/app/services/scanner.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Subnet scanner for MikroTik device discovery.
|
||||
|
||||
Scans a CIDR range by attempting TCP connections to RouterOS API ports
|
||||
(8728 and 8729) with configurable concurrency limits and timeouts.
|
||||
|
||||
Security constraints:
|
||||
- CIDR range limited to /20 or smaller (4096 IPs maximum)
|
||||
- Maximum 50 concurrent connections to prevent network flooding
|
||||
- 2-second timeout per connection attempt
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import socket
|
||||
from typing import Optional
|
||||
|
||||
from app.schemas.device import SubnetScanResult
|
||||
|
||||
# Maximum concurrency for TCP probes
|
||||
_MAX_CONCURRENT = 50
|
||||
# Timeout (seconds) per TCP connection attempt
|
||||
_TCP_TIMEOUT = 2.0
|
||||
# RouterOS API port
|
||||
_API_PORT = 8728
|
||||
# RouterOS SSL API port
|
||||
_SSL_PORT = 8729
|
||||
|
||||
|
||||
async def _probe_host(
    semaphore: asyncio.Semaphore,
    ip_str: str,
) -> Optional[SubnetScanResult]:
    """Probe one IP for the plain and SSL RouterOS API ports.

    Returns a populated SubnetScanResult if at least one port answers,
    otherwise None.
    """
    async with semaphore:
        open_flags = await asyncio.gather(
            _tcp_connect(ip_str, _API_PORT),
            _tcp_connect(ip_str, _SSL_PORT),
            return_exceptions=False,
        )
        api_open, ssl_open = open_flags

        if api_open or ssl_open:
            # Attempt reverse DNS (best-effort; won't fail the scan)
            name = await _reverse_dns(ip_str)
            return SubnetScanResult(
                ip_address=ip_str,
                hostname=name,
                api_port_open=api_open,
                api_ssl_port_open=ssl_open,
            )
        return None
|
||||
|
||||
async def _tcp_connect(ip: str, port: int) -> bool:
    """Return True if a TCP connection to ip:port succeeds within _TCP_TIMEOUT."""
    try:
        streams = await asyncio.wait_for(
            asyncio.open_connection(ip, port),
            timeout=_TCP_TIMEOUT,
        )
    except Exception:
        # Refused, unreachable, or timed out — the port is not usable.
        return False
    writer = streams[1]
    writer.close()
    try:
        await writer.wait_closed()
    except Exception:
        # Teardown errors don't change the answer: the connect succeeded.
        pass
    return True
|
||||
|
||||
async def _reverse_dns(ip: str) -> Optional[str]:
|
||||
"""Attempt a reverse DNS lookup. Returns None on failure."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
hostname, _, _ = await asyncio.wait_for(
|
||||
loop.run_in_executor(None, socket.gethostbyaddr, ip),
|
||||
timeout=1.5,
|
||||
)
|
||||
return hostname
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
async def scan_subnet(cidr: str) -> list[SubnetScanResult]:
    """Scan a CIDR range for hosts exposing RouterOS API ports.

    Args:
        cidr: CIDR notation string, e.g. "192.168.1.0/24".
            Must be /20 or smaller (validated by SubnetScanRequest).

    Returns:
        List of SubnetScanResult for each host with at least one open API port.

    Raises:
        ValueError: If CIDR is malformed or too large.
    """
    try:
        network = ipaddress.ip_network(cidr, strict=False)
    except ValueError as e:
        raise ValueError(f"Invalid CIDR: {e}") from e

    if network.num_addresses > 4096:
        raise ValueError(
            f"CIDR range too large ({network.num_addresses} addresses). "
            "Maximum allowed is /20 (4096 addresses)."
        )

    # Skip network address and broadcast address for IPv4
    if network.num_addresses > 2:
        targets = list(network.hosts())
    else:
        targets = list(network)

    # Cap concurrent probes so the scan doesn't flood the network.
    gate = asyncio.Semaphore(_MAX_CONCURRENT)
    probes = [_probe_host(gate, str(addr)) for addr in targets]
    probe_results = await asyncio.gather(*probes, return_exceptions=False)

    # Keep only hosts where at least one API port answered.
    return [hit for hit in probe_results if hit is not None]
113
backend/app/services/srp_service.py
Normal file
113
backend/app/services/srp_service.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""SRP-6a server-side authentication service.
|
||||
|
||||
Wraps the srptools library for the two-step SRP handshake.
|
||||
All functions are async, using asyncio.to_thread() because
|
||||
srptools operations are CPU-bound and synchronous.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
|
||||
from srptools import SRPContext, SRPServerSession
|
||||
from srptools.constants import PRIME_2048, PRIME_2048_GEN
|
||||
|
||||
# Client uses Web Crypto SHA-256 — server must match.
|
||||
# srptools defaults to SHA-1 which would cause proof mismatch.
|
||||
_SRP_HASH = hashlib.sha256
|
||||
|
||||
|
||||
async def create_srp_verifier(
    salt_hex: str, verifier_hex: str
) -> tuple[bytes, bytes]:
    """Decode the client-supplied hex salt and verifier into raw bytes.

    The client computes v = g^x mod N itself (using the 2SKD-derived
    SRP-x); the server never derives x from the password and simply
    persists what it receives.

    Returns:
        Tuple of (salt_bytes, verifier_bytes) ready for database storage.
    """
    salt = bytes.fromhex(salt_hex)
    verifier = bytes.fromhex(verifier_hex)
    return salt, verifier
|
||||
|
||||
async def srp_init(
    email: str, srp_verifier_hex: str
) -> tuple[str, str]:
    """SRP Step 1: Generate server ephemeral (B) and private key (b).

    Args:
        email: User email (SRP identity I).
        srp_verifier_hex: Hex-encoded SRP verifier from database.

    Returns:
        Tuple of (server_public_hex, server_private_hex).
        Caller stores server_private in Redis with 60s TTL.

    Raises:
        ValueError: If SRP initialization fails for any reason.
    """
    def _make_session() -> tuple[str, str]:
        # CPU-bound modular exponentiation — executed off the event loop.
        ctx = SRPContext(
            email, prime=PRIME_2048, generator=PRIME_2048_GEN,
            hash_func=_SRP_HASH,
        )
        session = SRPServerSession(ctx, srp_verifier_hex)
        return session.public, session.private

    try:
        return await asyncio.to_thread(_make_session)
    except Exception as e:
        raise ValueError(f"SRP initialization failed: {e}") from e
|
||||
|
||||
async def srp_verify(
    email: str,
    srp_verifier_hex: str,
    server_private: str,
    client_public: str,
    client_proof: str,
    srp_salt_hex: str,
) -> tuple[bool, str | None]:
    """SRP Step 2: Verify client proof M1, return server proof M2.

    Args:
        email: User email (SRP identity I).
        srp_verifier_hex: Hex-encoded SRP verifier from database.
        server_private: Server private ephemeral from Redis session.
        client_public: Hex-encoded client public ephemeral A.
        client_proof: Hex-encoded client proof M1.
        srp_salt_hex: Hex-encoded SRP salt.

    Returns:
        Tuple of (is_valid, server_proof_hex_or_none).
        If valid, server_proof is M2 for the client to verify.

    Raises:
        ValueError: If the underlying SRP computation fails.
    """
    def _verify() -> tuple[bool, str | None]:
        # Fix: removed leftover debug residue (an unused
        # `logging.getLogger("srp_debug")` local) from this closure.
        import hmac

        context = SRPContext(
            email, prime=PRIME_2048, generator=PRIME_2048_GEN,
            hash_func=_SRP_HASH,
        )
        server_session = SRPServerSession(
            context, srp_verifier_hex, private=server_private
        )
        _key, _key_proof, _key_proof_hash = server_session.process(client_public, srp_salt_hex)
        # srptools verify_proof has a Python 3 bug: hexlify() returns bytes
        # but client_proof is str, so bytes == str is always False.
        # Compare manually with consistent types; use a constant-time
        # comparison since M1 is an authentication secret.
        server_m1 = _key_proof if isinstance(_key_proof, str) else _key_proof.decode('ascii')
        is_valid = hmac.compare_digest(client_proof.lower(), server_m1.lower())
        if not is_valid:
            return False, None
        # Return M2 (key_proof_hash), also fixing the bytes/str issue
        m2 = _key_proof_hash if isinstance(_key_proof_hash, str) else _key_proof_hash.decode('ascii')
        return True, m2

    try:
        return await asyncio.to_thread(_verify)
    except Exception as e:
        raise ValueError(f"SRP verification failed: {e}") from e
||||
311
backend/app/services/sse_manager.py
Normal file
311
backend/app/services/sse_manager.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""SSE Connection Manager -- bridges NATS JetStream to per-client asyncio queues.
|
||||
|
||||
Each SSE client gets its own NATS connection with ephemeral consumers.
|
||||
Events are tenant-filtered and placed onto an asyncio.Queue that the
|
||||
SSE router drains via EventSourceResponse.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import nats
|
||||
import structlog
|
||||
from nats.js.api import ConsumerConfig, DeliverPolicy, StreamConfig
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Subjects per stream for SSE subscriptions
|
||||
# Note: config.push.* subjects live in DEVICE_EVENTS (created by Go poller)
|
||||
_DEVICE_EVENT_SUBJECTS = [
|
||||
"device.status.>",
|
||||
"device.metrics.>",
|
||||
"config.push.rollback.>",
|
||||
"config.push.alert.>",
|
||||
]
|
||||
_ALERT_EVENT_SUBJECTS = ["alert.fired.>", "alert.resolved.>"]
|
||||
_OPERATION_EVENT_SUBJECTS = ["firmware.progress.>"]
|
||||
|
||||
|
||||
def _map_subject_to_event_type(subject: str) -> str:
|
||||
"""Map a NATS subject prefix to an SSE event type string."""
|
||||
if subject.startswith("device.status."):
|
||||
return "device_status"
|
||||
if subject.startswith("device.metrics."):
|
||||
return "metric_update"
|
||||
if subject.startswith("alert.fired."):
|
||||
return "alert_fired"
|
||||
if subject.startswith("alert.resolved."):
|
||||
return "alert_resolved"
|
||||
if subject.startswith("config.push."):
|
||||
return "config_push"
|
||||
if subject.startswith("firmware.progress."):
|
||||
return "firmware_progress"
|
||||
return "unknown"
|
||||
|
||||
|
||||
async def ensure_sse_streams() -> None:
    """Create ALERT_EVENTS and OPERATION_EVENTS NATS streams if they don't exist.

    Called once during app startup so the streams are ready before any
    SSE connection or event publisher needs them. Idempotent -- uses
    add_stream which acts as create-or-update.

    Raises:
        Exception: re-raised after logging, so startup fails loudly when
            the streams cannot be created.
    """
    nc = None
    try:
        # Short-lived connection used only for stream administration.
        nc = await nats.connect(settings.NATS_URL)
        js = nc.jetstream()

        await js.add_stream(
            StreamConfig(
                name="ALERT_EVENTS",
                subjects=["alert.fired.>", "alert.resolved.>"],
                max_age=3600,  # 1 hour retention
            )
        )
        logger.info("nats.stream.ensured", stream="ALERT_EVENTS")

        await js.add_stream(
            StreamConfig(
                name="OPERATION_EVENTS",
                subjects=["firmware.progress.>"],
                max_age=3600,  # 1 hour retention
            )
        )
        logger.info("nats.stream.ensured", stream="OPERATION_EVENTS")

    except Exception as exc:
        logger.warning("sse.streams.ensure_failed", error=str(exc))
        raise
    finally:
        # Always release the admin connection, even when setup failed.
        if nc:
            try:
                await nc.close()
            except Exception:
                pass
|
||||
|
||||
class SSEConnectionManager:
|
||||
"""Manages a single SSE client's lifecycle: NATS connection, subscriptions, and event queue."""
|
||||
|
||||
    def __init__(self) -> None:
        """Create an idle manager; connect() establishes all NATS state."""
        # Per-client NATS connection, created in connect().
        self._nc: Optional[nats.aio.client.Client] = None
        # JetStream push subscriptions feeding this client's queue.
        self._subscriptions: list = []
        # Bounded queue drained by the SSE response generator.
        self._queue: Optional[asyncio.Queue] = None
        # Tenant filter; None means super_admin (events from all tenants).
        self._tenant_id: Optional[str] = None
        # Identifier used in structured log entries for this connection.
        self._connection_id: Optional[str] = None
|
||||
    async def connect(
        self,
        connection_id: str,
        tenant_id: Optional[str],
        last_event_id: Optional[str] = None,
    ) -> asyncio.Queue:
        """Set up NATS subscriptions and return an asyncio.Queue for SSE events.

        Args:
            connection_id: Unique identifier for this SSE connection.
            tenant_id: Tenant UUID string to filter events. None for super_admin
                (receives events from all tenants).
            last_event_id: NATS stream sequence number from the Last-Event-ID header.
                If provided, replay starts from sequence + 1.

        Returns:
            asyncio.Queue that the SSE generator should drain.
        """
        self._connection_id = connection_id
        self._tenant_id = tenant_id
        # Bounded queue: a slow client drops events rather than exhausting memory.
        self._queue = asyncio.Queue(maxsize=256)

        self._nc = await nats.connect(
            settings.NATS_URL,
            max_reconnect_attempts=5,
            reconnect_time_wait=2,
        )
        js = self._nc.jetstream()

        logger.info(
            "sse.connecting",
            connection_id=connection_id,
            tenant_id=tenant_id,
            last_event_id=last_event_id,
        )

        # Build consumer config for replay support
        if last_event_id is not None:
            try:
                # Resume one past the last event the client acknowledged.
                start_seq = int(last_event_id) + 1
                consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.BY_START_SEQUENCE, opt_start_seq=start_seq)
            except (ValueError, TypeError):
                # Malformed Last-Event-ID: fall back to new messages only.
                consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.NEW)
        else:
            consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.NEW)

        # Subscribe to device events (DEVICE_EVENTS stream -- created by Go poller)
        for subject in _DEVICE_EVENT_SUBJECTS:
            try:
                sub = await js.subscribe(
                    subject,
                    stream="DEVICE_EVENTS",
                    config=consumer_cfg,
                )
                self._subscriptions.append(sub)
            except Exception as exc:
                # A failed subject degrades the feed but doesn't kill the connection.
                logger.warning(
                    "sse.subscribe_failed",
                    subject=subject,
                    stream="DEVICE_EVENTS",
                    error=str(exc),
                )

        # Subscribe to alert events (ALERT_EVENTS stream)
        # Lazily create the stream if it doesn't exist yet (startup race)
        for subject in _ALERT_EVENT_SUBJECTS:
            try:
                sub = await js.subscribe(
                    subject,
                    stream="ALERT_EVENTS",
                    config=consumer_cfg,
                )
                self._subscriptions.append(sub)
            except Exception as exc:
                if "stream not found" in str(exc):
                    try:
                        await js.add_stream(StreamConfig(
                            name="ALERT_EVENTS",
                            subjects=_ALERT_EVENT_SUBJECTS,
                            max_age=3600,
                        ))
                        sub = await js.subscribe(subject, stream="ALERT_EVENTS", config=consumer_cfg)
                        self._subscriptions.append(sub)
                        logger.info("sse.stream_created_lazily", stream="ALERT_EVENTS")
                    except Exception as retry_exc:
                        logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(retry_exc))
                else:
                    logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(exc))

        # Subscribe to operation events (OPERATION_EVENTS stream)
        # Same lazy-create fallback as ALERT_EVENTS above.
        for subject in _OPERATION_EVENT_SUBJECTS:
            try:
                sub = await js.subscribe(
                    subject,
                    stream="OPERATION_EVENTS",
                    config=consumer_cfg,
                )
                self._subscriptions.append(sub)
            except Exception as exc:
                if "stream not found" in str(exc):
                    try:
                        await js.add_stream(StreamConfig(
                            name="OPERATION_EVENTS",
                            subjects=_OPERATION_EVENT_SUBJECTS,
                            max_age=3600,
                        ))
                        sub = await js.subscribe(subject, stream="OPERATION_EVENTS", config=consumer_cfg)
                        self._subscriptions.append(sub)
                        logger.info("sse.stream_created_lazily", stream="OPERATION_EVENTS")
                    except Exception as retry_exc:
                        logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(retry_exc))
                else:
                    logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(exc))

        # Start background task to pull messages from subscriptions into the queue
        asyncio.create_task(self._pump_messages())

        logger.info(
            "sse.connected",
            connection_id=connection_id,
            subscription_count=len(self._subscriptions),
        )

        return self._queue
|
||||
    async def _pump_messages(self) -> None:
        """Read messages from all NATS push subscriptions and push them onto the asyncio queue.

        Uses next_msg with a short timeout so we can interleave across
        subscriptions without blocking. Runs until the NATS connection is closed
        or drained.
        """
        # Outer loop: keep pumping for the lifetime of the NATS connection.
        while self._nc and self._nc.is_connected:
            for sub in self._subscriptions:
                try:
                    # Short timeout so a single idle subscription cannot
                    # starve the others.
                    msg = await sub.next_msg(timeout=0.5)
                    await self._handle_message(msg)
                except nats.errors.TimeoutError:
                    # No messages available on this subscription -- move on
                    continue
                except Exception as exc:
                    # Only log when the connection is still up; errors during
                    # shutdown/drain are expected noise.
                    if self._nc and self._nc.is_connected:
                        logger.warning(
                            "sse.pump_error",
                            connection_id=self._connection_id,
                            error=str(exc),
                        )
                    # NOTE(review): an unexpected error aborts this pass over
                    # the subscriptions; the remaining ones are retried on the
                    # next outer-loop iteration.
                    break
            # Brief yield to avoid tight-looping
            await asyncio.sleep(0.1)
|
||||
|
||||
async def _handle_message(self, msg) -> None:
|
||||
"""Parse a NATS message, apply tenant filter, and enqueue as SSE event."""
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
await msg.ack()
|
||||
return
|
||||
|
||||
# Tenant filtering: skip messages not matching this connection's tenant
|
||||
if self._tenant_id is not None:
|
||||
msg_tenant = data.get("tenant_id", "")
|
||||
if str(msg_tenant) != self._tenant_id:
|
||||
await msg.ack()
|
||||
return
|
||||
|
||||
event_type = _map_subject_to_event_type(msg.subject)
|
||||
|
||||
# Extract NATS stream sequence for Last-Event-ID support
|
||||
seq_id = "0"
|
||||
if msg.metadata and msg.metadata.sequence:
|
||||
seq_id = str(msg.metadata.sequence.stream)
|
||||
|
||||
sse_event = {
|
||||
"event": event_type,
|
||||
"data": json.dumps(data),
|
||||
"id": seq_id,
|
||||
}
|
||||
|
||||
try:
|
||||
self._queue.put_nowait(sse_event)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
"sse.queue_full",
|
||||
connection_id=self._connection_id,
|
||||
dropped_event=event_type,
|
||||
)
|
||||
|
||||
await msg.ack()
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Unsubscribe from all NATS subscriptions and close the connection."""
|
||||
logger.info("sse.disconnecting", connection_id=self._connection_id)
|
||||
|
||||
for sub in self._subscriptions:
|
||||
try:
|
||||
await sub.unsubscribe()
|
||||
except Exception:
|
||||
pass
|
||||
self._subscriptions.clear()
|
||||
|
||||
if self._nc:
|
||||
try:
|
||||
await self._nc.drain()
|
||||
except Exception:
|
||||
try:
|
||||
await self._nc.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._nc = None
|
||||
|
||||
logger.info("sse.disconnected", connection_id=self._connection_id)
|
||||
480
backend/app/services/template_service.py
Normal file
480
backend/app/services/template_service.py
Normal file
@@ -0,0 +1,480 @@
|
||||
"""Config template service: Jinja2 rendering, variable extraction, and multi-device push.
|
||||
|
||||
Provides:
|
||||
- extract_variables: Parse template content to find all undeclared Jinja2 variables
|
||||
- render_template: Render a template with device context and custom variables
|
||||
- validate_variable: Type-check a variable value against its declared type
|
||||
- push_to_devices: Sequential multi-device push with pause-on-failure
|
||||
- push_single_device: Two-phase panic-revert push for a single device
|
||||
|
||||
The push logic follows the same two-phase pattern as restore_service but uses
|
||||
separate scheduler and file names to avoid conflicts with restore operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import ipaddress
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncssh
|
||||
from jinja2 import meta
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
from sqlalchemy import select, text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.models.config_template import TemplatePushJob
|
||||
from app.models.device import Device
|
||||
|
||||
logger = logging.getLogger(__name__)

# Sandboxed Jinja2 environment prevents template injection: tenant-authored
# templates are rendered without access to unsafe Python attributes.
_env = SandboxedEnvironment()

# Names used on the RouterOS device during template push.
# Scheduler that auto-restores the pre-push backup if the portal never
# confirms the push ("panic revert").
_PANIC_REVERT_SCHEDULER = "mikrotik-portal-template-revert"
# Binary backup saved on the device immediately before the push.
_PRE_PUSH_BACKUP = "portal-template-pre-push"
# Filename the rendered template is uploaded under before /import.
_TEMPLATE_RSC = "portal-template.rsc"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Variable extraction & rendering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extract_variables(template_content: str) -> list[str]:
    """Extract all undeclared variables from a Jinja2 template.

    Returns a sorted list of variable names, excluding the built-in 'device'
    variable which is auto-populated at render time.
    """
    parsed = _env.parse(template_content)
    undeclared = meta.find_undeclared_variables(parsed)
    # 'device' is injected automatically at render time, so the user never
    # has to supply it.
    undeclared.discard("device")
    return sorted(undeclared)
|
||||
|
||||
|
||||
def render_template(
    template_content: str,
    device: dict,
    custom_variables: dict[str, str],
) -> str:
    """Render a Jinja2 template with device context and custom variables.

    The 'device' variable is auto-populated from the device dict; custom
    variables are user-provided at push time and may shadow it. Rendering
    happens in a SandboxedEnvironment to prevent template injection.

    Args:
        template_content: Jinja2 template string.
        device: Device info dict with keys: hostname, ip_address, model.
        custom_variables: User-supplied variable values.

    Returns:
        Rendered template string.

    Raises:
        jinja2.TemplateSyntaxError: If template has syntax errors.
        jinja2.UndefinedError: If required variables are missing.
    """
    device_context = {
        "hostname": device.get("hostname", ""),
        "ip": device.get("ip_address", ""),
        "model": device.get("model", ""),
    }
    # Custom variables are applied last, preserving the original override
    # order (a user-supplied 'device' key wins over the built-in).
    render_args: dict = {"device": device_context}
    render_args.update(custom_variables)

    template = _env.from_string(template_content)
    return template.render(render_args)
|
||||
|
||||
|
||||
def validate_variable(name: str, value: str, var_type: str) -> str | None:
|
||||
"""Validate a variable value against its declared type.
|
||||
|
||||
Returns None on success, or an error message string on failure.
|
||||
"""
|
||||
if var_type == "string":
|
||||
return None # any string is valid
|
||||
elif var_type == "ip":
|
||||
try:
|
||||
ipaddress.ip_address(value)
|
||||
return None
|
||||
except ValueError:
|
||||
return f"'{name}' must be a valid IP address"
|
||||
elif var_type == "subnet":
|
||||
try:
|
||||
ipaddress.ip_network(value, strict=False)
|
||||
return None
|
||||
except ValueError:
|
||||
return f"'{name}' must be a valid subnet (e.g., 192.168.1.0/24)"
|
||||
elif var_type == "integer":
|
||||
try:
|
||||
int(value)
|
||||
return None
|
||||
except ValueError:
|
||||
return f"'{name}' must be an integer"
|
||||
elif var_type == "boolean":
|
||||
if value.lower() in ("true", "false", "yes", "no", "1", "0"):
|
||||
return None
|
||||
return f"'{name}' must be a boolean (true/false)"
|
||||
return None # unknown type, allow
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-device push orchestration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def push_to_devices(rollout_id: str) -> dict:
    """Execute sequential template push for all jobs in a rollout.

    Devices are processed one at a time; a failure or revert leaves the
    remaining jobs pending (the rollout pauses). Mirrors the pattern used
    by firmware upgrade_service.start_mass_upgrade.

    Designed to run as a background task (asyncio.create_task) after the
    API creates the push jobs and returns the rollout_id; it never raises.
    """
    try:
        summary = await _run_push_rollout(rollout_id)
    except Exception as exc:
        # Last-resort guard: a background task must not die silently.
        logger.error(
            "Uncaught exception in template push rollout %s: %s",
            rollout_id,
            exc,
            exc_info=True,
        )
        return {"completed": 0, "failed": 1, "pending": 0}
    return summary
|
||||
|
||||
|
||||
async def _run_push_rollout(rollout_id: str) -> dict:
    """Internal rollout implementation.

    Loads every push job for the rollout (oldest first) and pushes devices
    sequentially. Processing stops at the first job that ends 'failed' or
    'reverted'; untouched jobs stay 'pending' (a paused rollout).

    Returns:
        Summary dict: 'completed' (committed jobs, including ones committed
        before this run), 'failed' (0 or 1), 'pending' (pending jobs that
        were never attempted).
    """
    # Load all jobs for this rollout
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id::text, j.status, d.hostname
                FROM template_push_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.rollout_id = CAST(:rollout_id AS uuid)
                ORDER BY j.created_at ASC
            """),
            {"rollout_id": rollout_id},
        )
        jobs = result.fetchall()

    if not jobs:
        logger.warning("No jobs found for template push rollout %s", rollout_id)
        return {"completed": 0, "failed": 0, "pending": 0}

    pending_total = sum(1 for _, status, _ in jobs if status == "pending")
    attempted = 0  # pending jobs actually pushed during this run
    completed = 0
    failed = False

    for job_id, current_status, hostname in jobs:
        if current_status != "pending":
            # Count previously committed jobs toward the summary, but never
            # re-push a job that already ran.
            if current_status == "committed":
                completed += 1
            continue

        attempted += 1
        logger.info(
            "Template push rollout %s: pushing to device %s (job %s)",
            rollout_id, hostname, job_id,
        )

        await push_single_device(job_id)

        # Re-read the job row to see how the push ended.
        async with AdminAsyncSessionLocal() as session:
            result = await session.execute(
                text("SELECT status FROM template_push_jobs WHERE id = CAST(:id AS uuid)"),
                {"id": job_id},
            )
            row = result.fetchone()

        if row and row[0] == "committed":
            completed += 1
        elif row and row[0] in ("failed", "reverted"):
            failed = True
            logger.error(
                "Template push rollout %s paused: device %s %s",
                rollout_id, hostname, row[0],
            )
            break

    # BUG FIX: the previous computation subtracted `completed` from the
    # pending total, but `completed` also counts jobs committed *before*
    # this run (which were never pending), under-counting paused jobs.
    # The correct remainder is simply the pending jobs never attempted.
    remaining = pending_total - attempted

    return {
        "completed": completed,
        "failed": 1 if failed else 0,
        "pending": max(0, remaining),
    }
|
||||
|
||||
|
||||
async def push_single_device(job_id: str) -> None:
    """Push rendered template content to a single device.

    Implements the two-phase panic-revert pattern:
      1. Pre-backup (mandatory)
      2. Install panic-revert scheduler on device
      3. Write template content as RSC file via SFTP
      4. /import the RSC file
      5. Wait 60s for config to settle
      6. Reachability check -> committed or reverted

    All errors are caught and recorded in the job row; this coroutine
    never raises.
    """
    try:
        await _run_single_push(job_id)
    except Exception as unexpected_err:
        # Defensive guard: record the failure instead of crashing the task.
        logger.error(
            "Uncaught exception in template push job %s: %s",
            job_id,
            unexpected_err,
            exc_info=True,
        )
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Unexpected error: {unexpected_err}",
        )
|
||||
|
||||
|
||||
async def _run_single_push(job_id: str) -> None:
    """Internal single-device push implementation.

    Loads the job + device row, decrypts SSH credentials, takes a mandatory
    pre-push backup, then performs the two-phase panic-revert push over SSH.
    Every failure path records status/error on the job row and returns;
    nothing is raised to the caller.
    """

    # Step 1: Load job and device info
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id, j.device_id, j.tenant_id, j.rendered_content,
                       d.ip_address, d.hostname, d.encrypted_credentials,
                       d.encrypted_credentials_transit
                FROM template_push_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.id = CAST(:job_id AS uuid)
            """),
            {"job_id": job_id},
        )
        row = result.fetchone()

    if not row:
        logger.error("Template push job %s not found", job_id)
        return

    (
        _, device_id, tenant_id, rendered_content,
        ip_address, hostname, encrypted_credentials,
        encrypted_credentials_transit,
    ) = row

    # Normalize UUID columns to strings for downstream helpers/logging.
    device_id = str(device_id)
    tenant_id = str(tenant_id)
    hostname = hostname or ip_address

    # Step 2: Update status to pushing
    await _update_job(job_id, status="pushing", started_at=datetime.now(timezone.utc))

    # Step 3: Decrypt credentials (dual-read: Transit preferred, legacy fallback)
    if not encrypted_credentials_transit and not encrypted_credentials:
        await _update_job(job_id, status="failed", error_message="Device has no stored credentials")
        return

    try:
        from app.services.crypto import decrypt_credentials_hybrid
        key = settings.get_encryption_key_bytes()
        creds_json = await decrypt_credentials_hybrid(
            encrypted_credentials_transit, encrypted_credentials, tenant_id, key,
        )
        creds = json.loads(creds_json)
        ssh_username = creds.get("username", "")
        ssh_password = creds.get("password", "")
    except Exception as cred_err:
        await _update_job(
            job_id, status="failed",
            error_message=f"Failed to decrypt credentials: {cred_err}",
        )
        return

    # Step 4: Mandatory pre-push backup (portal-side, committed to the
    # backup store) — distinct from the on-device binary backup in 5a.
    logger.info("Running mandatory pre-push backup for device %s (%s)", hostname, ip_address)
    try:
        from app.services import backup_service
        backup_result = await backup_service.run_backup(
            device_id=device_id,
            tenant_id=tenant_id,
            trigger_type="pre-template-push",
        )
        backup_sha = backup_result["commit_sha"]
        await _update_job(job_id, pre_push_backup_sha=backup_sha)
        logger.info("Pre-push backup complete: %s", backup_sha[:8])
    except Exception as backup_err:
        logger.error("Pre-push backup failed for %s: %s", hostname, backup_err)
        await _update_job(
            job_id, status="failed",
            error_message=f"Pre-push backup failed: {backup_err}",
        )
        return

    # Step 5: SSH to device - install panic-revert, push config
    logger.info(
        "Pushing template to device %s (%s): installing panic-revert and uploading config",
        hostname, ip_address,
    )

    try:
        # NOTE(review): known_hosts=None disables host-key verification —
        # presumably intentional for managed devices; confirm.
        async with asyncssh.connect(
            ip_address,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            # 5a: Create binary backup on device as revert point
            await conn.run(
                f"/system backup save name={_PRE_PUSH_BACKUP} dont-encrypt=yes",
                check=True,
            )
            logger.debug("Pre-push binary backup saved on device as %s.backup", _PRE_PUSH_BACKUP)

            # 5b: Install panic-revert RouterOS scheduler
            # (reloads the backup unless we remove the scheduler in step 8a).
            await conn.run(
                f"/system scheduler add "
                f'name="{_PANIC_REVERT_SCHEDULER}" '
                f"interval=90s "
                f'on-event=":delay 0; /system backup load name={_PRE_PUSH_BACKUP}" '
                f"start-time=startup",
                check=True,
            )
            logger.debug("Panic-revert scheduler installed on device")

            # 5c: Upload rendered template as RSC file via SFTP
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(_TEMPLATE_RSC, "wb") as f:
                    await f.write(rendered_content.encode("utf-8"))
            logger.debug("Uploaded %s to device flash", _TEMPLATE_RSC)

            # 5d: /import the config file
            # NOTE(review): check=False and the exit status is only logged —
            # a failed import still proceeds to the reachability check and
            # can end up marked 'committed'; confirm this is intended.
            import_result = await conn.run(
                f"/import file={_TEMPLATE_RSC}",
                check=False,
            )
            logger.info(
                "Template import result for device %s: exit_status=%s stdout=%r",
                hostname, import_result.exit_status,
                (import_result.stdout or "")[:200],
            )

            # 5e: Clean up the uploaded RSC file (best-effort)
            try:
                await conn.run(f"/file remove {_TEMPLATE_RSC}", check=True)
            except Exception as cleanup_err:
                logger.warning(
                    "Failed to clean up %s from device %s: %s",
                    _TEMPLATE_RSC, ip_address, cleanup_err,
                )

    except Exception as push_err:
        logger.error(
            "SSH push phase failed for device %s (%s): %s",
            hostname, ip_address, push_err,
        )
        await _update_job(
            job_id, status="failed",
            error_message=f"Config push failed during SSH phase: {push_err}",
        )
        return

    # Step 6: Wait 60s for config to settle
    logger.info("Template pushed to device %s - waiting 60s for config to settle", hostname)
    await asyncio.sleep(60)

    # Step 7: Reachability check
    reachable = await _check_reachability(ip_address, ssh_username, ssh_password)

    if reachable:
        # Step 8a: Device is reachable - remove panic-revert scheduler + cleanup
        logger.info("Device %s (%s) is reachable after push - committing", hostname, ip_address)
        try:
            async with asyncssh.connect(
                ip_address, port=22,
                username=ssh_username, password=ssh_password,
                known_hosts=None, connect_timeout=30,
            ) as conn:
                # check=False: cleanup is best-effort; a missing scheduler or
                # backup file must not fail the commit.
                await conn.run(
                    f'/system scheduler remove "{_PANIC_REVERT_SCHEDULER}"',
                    check=False,
                )
                await conn.run(
                    f"/file remove {_PRE_PUSH_BACKUP}.backup",
                    check=False,
                )
        except Exception as cleanup_err:
            logger.warning(
                "Failed to clean up panic-revert scheduler/backup on device %s: %s",
                hostname, cleanup_err,
            )

        await _update_job(
            job_id, status="committed",
            completed_at=datetime.now(timezone.utc),
        )
    else:
        # Step 8b: Device unreachable - RouterOS is auto-reverting
        logger.warning(
            "Device %s (%s) is unreachable after push - panic-revert scheduler "
            "will auto-revert to %s.backup",
            hostname, ip_address, _PRE_PUSH_BACKUP,
        )
        await _update_job(
            job_id, status="reverted",
            error_message="Device unreachable after push; auto-reverted via panic-revert scheduler",
            completed_at=datetime.now(timezone.utc),
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _check_reachability(ip: str, username: str, password: str) -> bool:
    """Check if a RouterOS device is reachable via SSH.

    Opens an SSH session and runs a trivial read-only command; any failure
    (connect, auth, or command) counts as unreachable.
    """
    try:
        async with asyncssh.connect(
            ip,
            port=22,
            username=username,
            password=password,
            known_hosts=None,
            connect_timeout=30,
        ) as ssh:
            identity = await ssh.run("/system identity print", check=True)
    except Exception as exc:
        logger.info("Device %s unreachable after push: %s", ip, exc)
        return False
    logger.debug("Reachability check OK for %s: %r", ip, identity.stdout[:50])
    return True
|
||||
|
||||
|
||||
async def _update_job(job_id: str, **kwargs) -> None:
    """Update TemplatePushJob fields via raw SQL (background task, no RLS).

    Column names are interpolated into the SQL text, so they are validated
    against an explicit allowlist first (identifier-injection guard); values
    are always passed as bound parameters.

    Raises:
        ValueError: If a keyword does not name an updatable column.
    """
    # Allowlist guards the f-string SQL below against identifier injection.
    # These are exactly the columns the push code paths in this module set.
    allowed_columns = {
        "status", "error_message", "started_at", "completed_at", "pre_push_backup_sha",
    }
    unknown = set(kwargs) - allowed_columns
    if unknown:
        raise ValueError(f"Refusing to update unknown column(s): {sorted(unknown)}")

    sets = []
    params: dict = {"job_id": job_id}

    for key, value in kwargs.items():
        param_name = f"v_{key}"
        # Nullable columns may be explicitly cleared by passing None.
        if value is None and key in ("error_message", "started_at", "completed_at", "pre_push_backup_sha"):
            sets.append(f"{key} = NULL")
        else:
            sets.append(f"{key} = :{param_name}")
            params[param_name] = value

    if not sets:
        return

    async with AdminAsyncSessionLocal() as session:
        await session.execute(
            text(f"""
                UPDATE template_push_jobs
                SET {', '.join(sets)}
                WHERE id = CAST(:job_id AS uuid)
            """),
            params,
        )
        await session.commit()
|
||||
564
backend/app/services/upgrade_service.py
Normal file
564
backend/app/services/upgrade_service.py
Normal file
@@ -0,0 +1,564 @@
|
||||
"""Firmware upgrade orchestration service.
|
||||
|
||||
Handles single-device and mass firmware upgrades with:
|
||||
- Mandatory pre-upgrade config backup
|
||||
- NPK download and SFTP upload to device
|
||||
- Reboot trigger and reconnect polling
|
||||
- Post-upgrade version verification
|
||||
- Sequential mass rollout with pause-on-failure
|
||||
- Scheduled upgrades via APScheduler DateTrigger
|
||||
|
||||
All DB operations use AdminAsyncSessionLocal to bypass RLS since upgrade
|
||||
jobs may span multiple tenants and run in background asyncio tasks.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import asyncssh
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.services.event_publisher import publish_event
|
||||
|
||||
logger = logging.getLogger(__name__)

# Maximum time to wait for a device to reconnect after reboot (seconds)
_RECONNECT_TIMEOUT = 300  # 5 minutes
# Delay between successive SSH reachability probes while waiting (seconds).
_RECONNECT_POLL_INTERVAL = 15  # seconds
# Grace period before the first probe, so the device can finish its boot
# (and NPK install) cycle before we start polling.
_INITIAL_WAIT = 60  # Wait before first reconnect attempt (boot cycle)
|
||||
|
||||
|
||||
async def start_upgrade(job_id: str) -> None:
    """Execute a single device firmware upgrade.

    Lifecycle: pending -> downloading -> uploading -> rebooting -> verifying -> completed/failed

    Designed to run as a background asyncio.create_task or APScheduler job:
    it never raises — any error is caught and recorded on the
    FirmwareUpgradeJob row instead.
    """
    try:
        await _run_upgrade(job_id)
    except Exception as unexpected_err:
        # Last-resort guard so a background task never dies silently.
        logger.error(
            "Uncaught exception in firmware upgrade %s: %s",
            job_id,
            unexpected_err,
            exc_info=True,
        )
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Unexpected error: {unexpected_err}",
        )
|
||||
|
||||
|
||||
async def _publish_upgrade_progress(
    tenant_id: str,
    device_id: str,
    job_id: str,
    stage: str,
    target_version: str,
    message: str,
    error: str | None = None,
) -> None:
    """Publish firmware upgrade progress event to NATS (fire-and-forget)."""
    subject = f"firmware.progress.{tenant_id}.{device_id}"
    event = {
        "event_type": "firmware_progress",
        "tenant_id": tenant_id,
        "device_id": device_id,
        "job_id": job_id,
        "stage": stage,
        "target_version": target_version,
        "message": message,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    # Only attach the error field when a non-empty error was supplied.
    if error:
        event["error"] = error
    await publish_event(subject, event)
|
||||
|
||||
|
||||
def _version_installed(target_version: str, reported: str | None) -> bool:
    """Return True if *reported* contains *target_version* as a whole version token.

    BUG FIX: the previous check used a plain substring test
    (``target_version in actual_version``), which wrongly accepts e.g.
    target '7.1' against a device reporting '7.15.2'. Require that the
    match is not flanked by further digits or dots.
    """
    import re  # local import keeps this fix self-contained

    if not reported:
        return False
    pattern = rf"(?<![0-9.]){re.escape(target_version)}(?![0-9.])"
    return re.search(pattern, reported) is not None


async def _run_upgrade(job_id: str) -> None:
    """Internal upgrade implementation.

    Loads the job + device row, enforces major-version confirmation, takes a
    mandatory pre-upgrade backup, downloads and uploads the NPK, reboots the
    device, polls for reconnect, and verifies the installed version. Every
    failure path records status/error on the job row and publishes a
    progress event; nothing is raised to the caller.
    """

    # Step 1: Load job
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id, j.device_id, j.tenant_id, j.target_version,
                       j.architecture, j.channel, j.status, j.confirmed_major_upgrade,
                       d.ip_address, d.hostname, d.encrypted_credentials,
                       d.routeros_version, d.encrypted_credentials_transit
                FROM firmware_upgrade_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.id = CAST(:job_id AS uuid)
            """),
            {"job_id": job_id},
        )
        row = result.fetchone()

    if not row:
        logger.error("Upgrade job %s not found", job_id)
        return

    (
        _, device_id, tenant_id, target_version,
        architecture, channel, status, confirmed_major,
        ip_address, hostname, encrypted_credentials,
        current_version, encrypted_credentials_transit,
    ) = row

    # Normalize UUID columns to strings for downstream helpers/logging.
    device_id = str(device_id)
    tenant_id = str(tenant_id)
    hostname = hostname or ip_address

    # Skip if already running or completed
    if status not in ("pending", "scheduled"):
        logger.info("Upgrade job %s already in status %s — skipping", job_id, status)
        return

    logger.info(
        "Starting firmware upgrade for %s (%s): %s -> %s",
        hostname, ip_address, current_version, target_version,
    )

    # Step 2: Update status to downloading
    await _update_job(job_id, status="downloading", started_at=datetime.now(timezone.utc))
    await _publish_upgrade_progress(tenant_id, device_id, job_id, "downloading", target_version, f"Downloading firmware {target_version} for {hostname}")

    # Step 3: Check major version upgrade confirmation
    if current_version and target_version:
        current_major = current_version.split(".")[0]
        target_major = target_version.split(".")[0]
        if current_major != target_major and not confirmed_major:
            await _update_job(
                job_id,
                status="failed",
                error_message="Major version upgrade requires explicit confirmation",
            )
            await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Major version upgrade requires explicit confirmation for {hostname}", error="Major version upgrade requires explicit confirmation")
            return

    # Step 4: Mandatory config backup
    logger.info("Running mandatory pre-upgrade backup for %s", hostname)
    try:
        from app.services import backup_service
        backup_result = await backup_service.run_backup(
            device_id=device_id,
            tenant_id=tenant_id,
            trigger_type="pre-upgrade",
        )
        backup_sha = backup_result["commit_sha"]
        await _update_job(job_id, pre_upgrade_backup_sha=backup_sha)
        logger.info("Pre-upgrade backup complete: %s", backup_sha[:8])
    except Exception as backup_err:
        logger.error("Pre-upgrade backup failed for %s: %s", hostname, backup_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Pre-upgrade backup failed: {backup_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Pre-upgrade backup failed for {hostname}", error=str(backup_err))
        return

    # Step 5: Download NPK
    logger.info("Downloading firmware %s for %s/%s", target_version, architecture, channel)
    try:
        from app.services.firmware_service import download_firmware
        npk_path = await download_firmware(architecture, channel, target_version)
        logger.info("Firmware cached at %s", npk_path)
    except Exception as dl_err:
        logger.error("Firmware download failed: %s", dl_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Firmware download failed: {dl_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Firmware download failed for {hostname}", error=str(dl_err))
        return

    # Step 6: Upload NPK to device via SFTP
    await _update_job(job_id, status="uploading")
    await _publish_upgrade_progress(tenant_id, device_id, job_id, "uploading", target_version, f"Uploading firmware to {hostname}")

    # Decrypt device credentials (dual-read: Transit preferred, legacy fallback)
    if not encrypted_credentials_transit and not encrypted_credentials:
        await _update_job(job_id, status="failed", error_message="Device has no stored credentials")
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"No stored credentials for {hostname}", error="Device has no stored credentials")
        return

    try:
        from app.services.crypto import decrypt_credentials_hybrid
        key = settings.get_encryption_key_bytes()
        creds_json = await decrypt_credentials_hybrid(
            encrypted_credentials_transit, encrypted_credentials, tenant_id, key,
        )
        creds = json.loads(creds_json)
        ssh_username = creds.get("username", "")
        ssh_password = creds.get("password", "")
    except Exception as cred_err:
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Failed to decrypt credentials: {cred_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Failed to decrypt credentials for {hostname}", error=str(cred_err))
        return

    try:
        npk_data = Path(npk_path).read_bytes()
        npk_filename = Path(npk_path).name

        async with asyncssh.connect(
            ip_address,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            # RouterOS installs packages found in the root directory on boot.
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(f"/{npk_filename}", "wb") as f:
                    await f.write(npk_data)
        logger.info("Uploaded %s to %s", npk_filename, hostname)
    except Exception as upload_err:
        logger.error("NPK upload failed for %s: %s", hostname, upload_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"NPK upload failed: {upload_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"NPK upload failed for {hostname}", error=str(upload_err))
        return

    # Step 7: Trigger reboot
    await _update_job(job_id, status="rebooting")
    await _publish_upgrade_progress(tenant_id, device_id, job_id, "rebooting", target_version, f"Rebooting {hostname} for firmware install")
    try:
        async with asyncssh.connect(
            ip_address,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            # RouterOS will install NPK on boot
            await conn.run("/system reboot", check=False)
        logger.info("Reboot command sent to %s", hostname)
    except Exception as reboot_err:
        # Device may drop connection during reboot — this is expected
        logger.info("Device %s dropped connection after reboot command (expected): %s", hostname, reboot_err)

    # Step 8: Wait for reconnect
    logger.info("Waiting %ds before polling %s for reconnect", _INITIAL_WAIT, hostname)
    await asyncio.sleep(_INITIAL_WAIT)

    reconnected = False
    elapsed = 0
    while elapsed < _RECONNECT_TIMEOUT:
        if await _check_ssh_reachable(ip_address, ssh_username, ssh_password):
            reconnected = True
            break
        await asyncio.sleep(_RECONNECT_POLL_INTERVAL)
        elapsed += _RECONNECT_POLL_INTERVAL

    if not reconnected:
        logger.error("Device %s did not reconnect within %ds", hostname, _RECONNECT_TIMEOUT)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Device did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes after reboot",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Device {hostname} did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes", error="Reconnect timeout")
        return

    # Step 9: Verify upgrade
    await _update_job(job_id, status="verifying")
    await _publish_upgrade_progress(tenant_id, device_id, job_id, "verifying", target_version, f"Verifying firmware version on {hostname}")
    try:
        actual_version = await _get_device_version(ip_address, ssh_username, ssh_password)
        # Token-boundary match (see _version_installed): a bare substring
        # test would accept e.g. target '7.1' against a reported '7.15.2'.
        if _version_installed(target_version, actual_version):
            logger.info(
                "Firmware upgrade verified for %s: %s",
                hostname, actual_version,
            )
            await _update_job(
                job_id,
                status="completed",
                completed_at=datetime.now(timezone.utc),
            )
            await _publish_upgrade_progress(tenant_id, device_id, job_id, "completed", target_version, f"Firmware upgrade to {target_version} completed on {hostname}")
        else:
            logger.error(
                "Version mismatch for %s: expected %s, got %s",
                hostname, target_version, actual_version,
            )
            await _update_job(
                job_id,
                status="failed",
                error_message=f"Expected {target_version} but got {actual_version}",
            )
            await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Version mismatch on {hostname}: expected {target_version}, got {actual_version}", error=f"Expected {target_version} but got {actual_version}")
    except Exception as verify_err:
        logger.error("Post-upgrade verification failed for %s: %s", hostname, verify_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Post-upgrade verification failed: {verify_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Post-upgrade verification failed for {hostname}", error=str(verify_err))
|
||||
|
||||
|
||||
async def start_mass_upgrade(rollout_group_id: str) -> dict:
    """Execute a sequential mass firmware upgrade.

    Processes upgrade jobs one at a time. If any device fails,
    all remaining jobs in the group are paused.

    Returns summary dict with completed/failed/paused counts.
    """
    # Load every job in the rollout group, oldest first, so devices are
    # upgraded in the order their jobs were created.
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id, j.status, d.hostname
                FROM firmware_upgrade_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.rollout_group_id = CAST(:group_id AS uuid)
                ORDER BY j.created_at ASC
            """),
            {"group_id": rollout_group_id},
        )
        jobs = result.fetchall()

    if not jobs:
        logger.warning("No jobs found for rollout group %s", rollout_group_id)
        return {"completed": 0, "failed": 0, "paused": 0}

    completed = 0
    failed_device = None  # hostname of the first device that failed, if any

    for job_id, current_status, hostname in jobs:
        job_id_str = str(job_id)

        # Only process pending/scheduled jobs
        if current_status not in ("pending", "scheduled"):
            # Jobs that finished in a previous run still count in the summary.
            if current_status == "completed":
                completed += 1
            continue

        logger.info("Mass rollout: upgrading device %s (job %s)", hostname, job_id_str)
        # Runs the full single-device upgrade to completion before moving on
        # — this is what makes the rollout sequential.
        await start_upgrade(job_id_str)

        # Check if it completed or failed (start_upgrade persists the outcome
        # in the jobs table rather than returning it).
        async with AdminAsyncSessionLocal() as session:
            result = await session.execute(
                text("SELECT status FROM firmware_upgrade_jobs WHERE id = CAST(:id AS uuid)"),
                {"id": job_id_str},
            )
            row = result.fetchone()

        if row and row[0] == "completed":
            completed += 1
        elif row and row[0] == "failed":
            failed_device = hostname
            logger.error("Mass rollout paused: %s failed", hostname)
            break

    # Pause remaining jobs if one failed
    paused = 0
    if failed_device:
        async with AdminAsyncSessionLocal() as session:
            result = await session.execute(
                text("""
                    UPDATE firmware_upgrade_jobs
                    SET status = 'paused'
                    WHERE rollout_group_id = CAST(:group_id AS uuid)
                    AND status IN ('pending', 'scheduled')
                    RETURNING id
                """),
                {"group_id": rollout_group_id},
            )
            paused = len(result.fetchall())
            await session.commit()

    return {
        "completed": completed,
        "failed": 1 if failed_device else 0,
        "failed_device": failed_device,
        "paused": paused,
    }
|
||||
|
||||
def schedule_upgrade(job_id: str, scheduled_at: datetime) -> None:
    """Register a one-shot APScheduler job that runs the firmware upgrade at *scheduled_at*."""
    from app.services.backup_scheduler import backup_scheduler

    job_key = f"fw_upgrade_{job_id}"
    scheduler_options = {
        "trigger": "date",
        "run_date": scheduled_at,
        "args": [job_id],
        "id": job_key,
        "name": f"Firmware upgrade {job_id}",
        "max_instances": 1,        # never run the same upgrade concurrently
        "replace_existing": True,  # re-scheduling moves the existing entry
    }
    backup_scheduler.add_job(start_upgrade, **scheduler_options)
    logger.info("Scheduled firmware upgrade %s for %s", job_id, scheduled_at)
|
||||
|
||||
def schedule_mass_upgrade(rollout_group_id: str, scheduled_at: datetime) -> None:
    """Register a one-shot APScheduler job that starts the mass rollout at *scheduled_at*."""
    from app.services.backup_scheduler import backup_scheduler

    group_key = f"fw_mass_upgrade_{rollout_group_id}"
    scheduler_options = {
        "trigger": "date",
        "run_date": scheduled_at,
        "args": [rollout_group_id],
        "id": group_key,
        "name": f"Mass firmware upgrade {rollout_group_id}",
        "max_instances": 1,        # never run the same rollout concurrently
        "replace_existing": True,  # re-scheduling moves the existing entry
    }
    backup_scheduler.add_job(start_mass_upgrade, **scheduler_options)
    logger.info("Scheduled mass firmware upgrade %s for %s", rollout_group_id, scheduled_at)
|
||||
|
||||
async def cancel_upgrade(job_id: str) -> None:
    """Cancel a scheduled or pending upgrade.

    Removes any APScheduler entry for the job (best-effort), then marks
    the job terminally 'failed' with an operator-cancellation message.
    """
    from app.services.backup_scheduler import backup_scheduler

    # Remove APScheduler job if it exists
    try:
        backup_scheduler.remove_job(f"fw_upgrade_{job_id}")
    except Exception:
        pass  # Job might not be scheduled

    await _update_job(
        job_id,
        status="failed",
        error_message="Cancelled by operator",
        completed_at=datetime.now(timezone.utc),
    )
    logger.info("Upgrade job %s cancelled", job_id)
|
||||
|
||||
async def retry_failed_upgrade(job_id: str) -> None:
    """Reset a failed upgrade job to pending and re-execute.

    Clears the error message and both timestamps so the job looks freshly
    created, then kicks the upgrade off in the background without awaiting it.
    """
    await _update_job(
        job_id,
        status="pending",
        error_message=None,   # explicit None → written as SQL NULL by _update_job
        started_at=None,
        completed_at=None,
    )
    # NOTE(review): the Task reference is not retained; per the asyncio docs a
    # fire-and-forget task may be garbage-collected before it finishes —
    # confirm a reference is held elsewhere (e.g. a module-level task set).
    asyncio.create_task(start_upgrade(job_id))
    logger.info("Retrying upgrade job %s", job_id)
|
||||
|
||||
async def resume_mass_upgrade(rollout_group_id: str) -> None:
    """Resume a paused mass rollout from where it left off.

    Flips every 'paused' job in the group back to 'pending', then restarts
    the sequential rollout in the background (start_mass_upgrade skips jobs
    that already completed).
    """
    # Reset paused jobs to pending, then restart sequential processing.
    # The UPDATE's result set is not needed (it was previously bound to an
    # unused local).
    async with AdminAsyncSessionLocal() as session:
        await session.execute(
            text("""
                UPDATE firmware_upgrade_jobs
                SET status = 'pending'
                WHERE rollout_group_id = CAST(:group_id AS uuid)
                AND status = 'paused'
            """),
            {"group_id": rollout_group_id},
        )
        await session.commit()

    # NOTE(review): fire-and-forget task — the asyncio docs advise keeping a
    # reference so the task is not garbage-collected mid-run; confirm one is
    # held elsewhere.
    asyncio.create_task(start_mass_upgrade(rollout_group_id))
    logger.info("Resuming mass rollout %s", rollout_group_id)
|
||||
|
||||
async def abort_mass_upgrade(rollout_group_id: str) -> int:
    """Abort all remaining jobs in a paused mass rollout.

    Marks every not-yet-finished job in the group as failed with an
    operator-abort message and returns how many jobs were affected.
    """
    abort_stmt = text("""
        UPDATE firmware_upgrade_jobs
        SET status = 'failed',
            error_message = 'Aborted by operator',
            completed_at = NOW()
        WHERE rollout_group_id = CAST(:group_id AS uuid)
        AND status IN ('pending', 'scheduled', 'paused')
        RETURNING id
    """)

    async with AdminAsyncSessionLocal() as session:
        outcome = await session.execute(abort_stmt, {"group_id": rollout_group_id})
        aborted_count = len(outcome.fetchall())
        await session.commit()

    logger.info("Aborted %d remaining jobs in rollout %s", aborted_count, rollout_group_id)
    return aborted_count
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _update_job(job_id: str, **kwargs) -> None:
    """Update FirmwareUpgradeJob fields.

    Keyword argument names are used verbatim as column names in the
    generated UPDATE statement; values are always bound as parameters.
    An explicit None for the nullable columns (error_message, started_at,
    completed_at) is written as SQL NULL.

    Raises:
        ValueError: if a keyword name is not a plain identifier — defense
            in depth, since column names are interpolated into the SQL text.
    """
    sets = []
    params: dict = {"job_id": job_id}

    for key, value in kwargs.items():
        # Column names go into the SQL string below via f-string; refuse
        # anything that is not a simple identifier so no caller can smuggle
        # SQL through a key. Values stay parameterized regardless.
        if not key.isidentifier():
            raise ValueError(f"Invalid column name: {key!r}")
        param_name = f"v_{key}"
        if value is None and key in ("error_message", "started_at", "completed_at"):
            sets.append(f"{key} = NULL")
        else:
            sets.append(f"{key} = :{param_name}")
            params[param_name] = value

    # Nothing to update — avoid emitting an invalid "SET" clause.
    if not sets:
        return

    async with AdminAsyncSessionLocal() as session:
        await session.execute(
            text(f"""
                UPDATE firmware_upgrade_jobs
                SET {', '.join(sets)}
                WHERE id = CAST(:job_id AS uuid)
            """),
            params,
        )
        await session.commit()
||||
|
||||
|
||||
async def _check_ssh_reachable(ip: str, username: str, password: str) -> bool:
    """Check if a device is reachable via SSH.

    Opens a short-lived SSH session (15 s connect timeout) and runs a
    harmless RouterOS command to prove the credentials actually work.
    Host-key verification is disabled (known_hosts=None) because device
    keys are not pre-registered.

    Returns True only when both the connection and the command succeed.
    The broad except is deliberate: this is a best-effort probe, and any
    failure mode (timeout, auth error, command error) means "unreachable".
    """
    try:
        async with asyncssh.connect(
            ip,
            port=22,
            username=username,
            password=password,
            known_hosts=None,     # skip host-key verification
            connect_timeout=15,
        ) as conn:
            await conn.run("/system identity print", check=True)
            return True
    except Exception:
        return False
|
||||
|
||||
async def _get_device_version(ip: str, username: str, password: str) -> str:
    """Get the current RouterOS version from a device via SSH.

    Runs `/system resource print` and returns the text after the first
    ':' on the first output line containing "version" (e.g.
    "7.17 (stable)"). Returns "" when no such line is found.

    Connection or command failures propagate to the caller (check=True
    raises on non-zero exit).
    """
    async with asyncssh.connect(
        ip,
        port=22,
        username=username,
        password=password,
        known_hosts=None,     # device host keys are not pre-registered
        connect_timeout=30,
    ) as conn:
        result = await conn.run("/system resource print", check=True)
        # Parse version from output: "version: 7.17 (stable)"
        # Takes the FIRST line containing "version" — assumes RouterOS
        # prints the version field before any other "version"-like field.
        for line in result.stdout.splitlines():
            if "version" in line.lower():
                parts = line.split(":", 1)
                if len(parts) == 2:
                    return parts[1].strip()
        return ""
||||
392
backend/app/services/vpn_service.py
Normal file
392
backend/app/services/vpn_service.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""WireGuard VPN management service.
|
||||
|
||||
Handles key generation, peer management, config file sync, and RouterOS command generation.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import ipaddress
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import structlog
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
|
||||
from cryptography.hazmat.primitives.serialization import (
|
||||
Encoding,
|
||||
NoEncryption,
|
||||
PrivateFormat,
|
||||
PublicFormat,
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.models.device import Device
|
||||
from app.models.vpn import VpnConfig, VpnPeer
|
||||
from app.services.crypto import decrypt_credentials, encrypt_credentials, encrypt_credentials_transit
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
# ── Key Generation ──
|
||||
|
||||
|
||||
def generate_wireguard_keypair() -> tuple[str, str]:
    """Generate a WireGuard X25519 keypair. Returns (private_key_b64, public_key_b64)."""
    key = X25519PrivateKey.generate()
    raw_private = key.private_bytes(Encoding.Raw, PrivateFormat.Raw, NoEncryption())
    raw_public = key.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw)
    # WireGuard keys are the raw 32-byte values, base64-encoded.
    encoded_private = base64.b64encode(raw_private).decode()
    encoded_public = base64.b64encode(raw_public).decode()
    return encoded_private, encoded_public
|
||||
|
||||
def generate_preshared_key() -> str:
    """Generate a WireGuard preshared key (32 random bytes, base64)."""
    random_bytes = os.urandom(32)
    return base64.b64encode(random_bytes).decode()
||||
|
||||
|
||||
# ── Config File Management ──
|
||||
|
||||
|
||||
def _get_wg_config_path() -> Path:
    """Return the path to the shared WireGuard config directory.

    Overridable via the WIREGUARD_CONFIG_PATH environment variable;
    defaults to /data/wireguard.
    """
    configured = os.getenv("WIREGUARD_CONFIG_PATH", "/data/wireguard")
    return Path(configured)
||||
|
||||
|
||||
async def sync_wireguard_config(db: AsyncSession, tenant_id: uuid.UUID) -> None:
    """Regenerate wg0.conf from database state and write to shared volume.

    No-op when VPN is not configured or is disabled for the tenant.
    Renders an [Interface] section plus one [Peer] section per enabled
    peer, writes it to <WIREGUARD_CONFIG_PATH>/wg_confs/wg0.conf, and
    touches a ".reload" flag file that the WireGuard container watches.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config or not config.is_enabled:
        return

    # Keys are stored encrypted at rest; decrypt only for rendering.
    key_bytes = settings.get_encryption_key_bytes()
    server_private_key = decrypt_credentials(config.server_private_key, key_bytes)

    result = await db.execute(
        select(VpnPeer).where(VpnPeer.tenant_id == tenant_id, VpnPeer.is_enabled.is_(True))
    )
    peers = result.scalars().all()

    # Build wg0.conf
    lines = [
        "[Interface]",
        f"Address = {config.server_address}",
        f"ListenPort = {config.server_port}",
        f"PrivateKey = {server_private_key}",
        "",
    ]

    for peer in peers:
        peer_ip = peer.assigned_ip.split("/")[0]  # strip CIDR for AllowedIPs
        allowed_ips = [f"{peer_ip}/32"]
        if peer.additional_allowed_ips:
            # Comma-separated additional subnets (e.g. site-to-site routing)
            extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()]
            allowed_ips.extend(extra)
        lines.append("[Peer]")
        lines.append(f"PublicKey = {peer.peer_public_key}")
        if peer.preshared_key:
            psk = decrypt_credentials(peer.preshared_key, key_bytes)
            lines.append(f"PresharedKey = {psk}")
        lines.append(f"AllowedIPs = {', '.join(allowed_ips)}")
        lines.append("")

    config_dir = _get_wg_config_path()
    wg_confs_dir = config_dir / "wg_confs"
    wg_confs_dir.mkdir(parents=True, exist_ok=True)

    conf_path = wg_confs_dir / "wg0.conf"
    # NOTE(review): the rendered file contains plaintext private keys —
    # confirm the shared volume's permissions restrict access.
    conf_path.write_text("\n".join(lines))

    # Signal WireGuard container to reload config
    reload_flag = wg_confs_dir / ".reload"
    reload_flag.write_text("1")

    logger.info("wireguard config synced", tenant_id=str(tenant_id), peers=len(peers))
|
||||
|
||||
|
||||
# ── Live Status ──
|
||||
|
||||
|
||||
def read_wg_status() -> dict[str, dict]:
    """Read live WireGuard peer status from the shared volume.

    The WireGuard container writes wg_status.json every 15 seconds
    with output from `wg show wg0 dump`. Returns a dict keyed by
    peer public key with handshake timestamp and transfer stats.
    Returns an empty dict when the file is missing or unparseable.
    """
    status_file = _get_wg_config_path() / "wg_status.json"
    if not status_file.exists():
        return {}

    try:
        entries = json.loads(status_file.read_text())
        by_public_key = {entry["public_key"]: entry for entry in entries}
    except (json.JSONDecodeError, KeyError, OSError):
        # Best effort: a half-written or malformed file means "no data".
        return {}
    return by_public_key
|
||||
|
||||
|
||||
def get_peer_handshake(wg_status: dict[str, dict], public_key: str) -> Optional[datetime]:
    """Get last_handshake datetime for a peer from live WireGuard status.

    Returns None when the peer is absent, has an empty entry, or has
    never completed a handshake (zero/missing timestamp).
    """
    entry = wg_status.get(public_key)
    if not entry:
        return None
    # Zero or absent means "no handshake yet" in `wg show` output.
    ts = entry.get("last_handshake", 0)
    return datetime.fromtimestamp(ts, tz=timezone.utc) if ts and ts > 0 else None
|
||||
|
||||
|
||||
# ── CRUD Operations ──
|
||||
|
||||
|
||||
async def get_vpn_config(db: AsyncSession, tenant_id: uuid.UUID) -> Optional[VpnConfig]:
    """Get the VPN config for a tenant, or None if VPN was never set up."""
    query = select(VpnConfig).where(VpnConfig.tenant_id == tenant_id)
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
||||
|
||||
|
||||
async def setup_vpn(
    db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None
) -> VpnConfig:
    """Initialize VPN for a tenant — generates server keys and creates config.

    Args:
        db: open async session (rows are flushed, not committed here).
        tenant_id: tenant to enable VPN for.
        endpoint: optional public "host:port" that clients will connect to.

    Returns:
        The newly created, enabled VpnConfig.

    Raises:
        ValueError: if a VPN config already exists for this tenant.
    """
    existing = await get_vpn_config(db, tenant_id)
    if existing:
        raise ValueError("VPN already configured for this tenant")

    private_key_b64, public_key_b64 = generate_wireguard_keypair()

    # Private key is encrypted at rest; the public key is stored in the clear.
    key_bytes = settings.get_encryption_key_bytes()
    encrypted_private = encrypt_credentials(private_key_b64, key_bytes)

    config = VpnConfig(
        tenant_id=tenant_id,
        server_private_key=encrypted_private,
        server_public_key=public_key_b64,
        endpoint=endpoint,
        is_enabled=True,
    )
    db.add(config)
    await db.flush()  # persist the row before rendering the config file

    await sync_wireguard_config(db, tenant_id)
    return config
||||
|
||||
|
||||
async def update_vpn_config(
    db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None, is_enabled: Optional[bool] = None
) -> VpnConfig:
    """Update VPN config settings.

    A None argument means "leave that field unchanged". Re-syncs the
    rendered WireGuard config after applying updates.

    Raises:
        ValueError: if VPN has not been set up for this tenant.
    """
    config = await get_vpn_config(db, tenant_id)
    if config is None:
        raise ValueError("VPN not configured for this tenant")

    for attr, new_value in (("endpoint", endpoint), ("is_enabled", is_enabled)):
        if new_value is not None:
            setattr(config, attr, new_value)

    await db.flush()
    await sync_wireguard_config(db, tenant_id)
    return config
||||
|
||||
|
||||
async def get_peers(db: AsyncSession, tenant_id: uuid.UUID) -> list[VpnPeer]:
    """List all VPN peers for a tenant, oldest first."""
    query = (
        select(VpnPeer)
        .where(VpnPeer.tenant_id == tenant_id)
        .order_by(VpnPeer.created_at)
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
||||
|
||||
|
||||
async def _next_available_ip(db: AsyncSession, tenant_id: uuid.UUID, config: VpnConfig) -> str:
    """Allocate the next available IP in the VPN subnet.

    Returns the address in CIDR notation using the subnet's own prefix
    length. (Previously the prefix was hard-coded to /24, which produced
    wrong addresses for any other subnet size.)

    Raises:
        ValueError: when every host address in the subnet is taken.
    """
    # Parse subnet: e.g. "10.10.0.0/24" → start from .2 (server is .1)
    network = ipaddress.ip_network(config.subnet, strict=False)
    hosts = list(network.hosts())

    # Get already assigned IPs (stored with a CIDR suffix — strip it)
    result = await db.execute(select(VpnPeer.assigned_ip).where(VpnPeer.tenant_id == tenant_id))
    used_ips = {row[0].split("/")[0] for row in result.all()}
    used_ips.add(config.server_address.split("/")[0])  # exclude server IP

    for host in hosts[1:]:  # skip first host (conventionally the server)
        if str(host) not in used_ips:
            # Use the subnet's actual prefix length, not a hard-coded /24.
            return f"{host}/{network.prefixlen}"

    raise ValueError("No available IPs in VPN subnet")
||||
|
||||
|
||||
async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID, additional_allowed_ips: Optional[str] = None) -> VpnPeer:
    """Add a device as a VPN peer.

    Generates a fresh keypair and preshared key for the device, allocates
    the next free VPN IP, persists the peer (secrets encrypted at rest)
    and rewrites the server-side WireGuard config.

    Args:
        db: open async session (flushed, not committed here).
        tenant_id: owning tenant.
        device_id: device to enroll; must belong to the tenant.
        additional_allowed_ips: optional comma-separated extra subnets to
            route through this peer (e.g. site-to-site).

    Raises:
        ValueError: VPN not configured, device not found, device already a
            peer, or the subnet is full.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config:
        raise ValueError("VPN not configured — enable VPN first")

    # Check device exists
    device = await db.execute(select(Device).where(Device.id == device_id, Device.tenant_id == tenant_id))
    if not device.scalar_one_or_none():
        raise ValueError("Device not found")

    # Check if already a peer
    existing = await db.execute(select(VpnPeer).where(VpnPeer.device_id == device_id))
    if existing.scalar_one_or_none():
        raise ValueError("Device is already a VPN peer")

    private_key_b64, public_key_b64 = generate_wireguard_keypair()
    psk = generate_preshared_key()

    # Secrets are encrypted before they touch the database.
    key_bytes = settings.get_encryption_key_bytes()
    encrypted_private = encrypt_credentials(private_key_b64, key_bytes)
    encrypted_psk = encrypt_credentials(psk, key_bytes)

    assigned_ip = await _next_available_ip(db, tenant_id, config)

    peer = VpnPeer(
        tenant_id=tenant_id,
        device_id=device_id,
        peer_private_key=encrypted_private,
        peer_public_key=public_key_b64,
        preshared_key=encrypted_psk,
        assigned_ip=assigned_ip,
        additional_allowed_ips=additional_allowed_ips,
    )
    db.add(peer)
    await db.flush()

    # Make the new peer live by regenerating the server-side config.
    await sync_wireguard_config(db, tenant_id)
    return peer
|
||||
|
||||
async def remove_peer(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> None:
    """Remove a VPN peer.

    Raises:
        ValueError: if no peer with this id exists for the tenant.
    """
    query = select(VpnPeer).where(VpnPeer.id == peer_id, VpnPeer.tenant_id == tenant_id)
    peer = (await db.execute(query)).scalar_one_or_none()
    if peer is None:
        raise ValueError("Peer not found")

    await db.delete(peer)
    await db.flush()
    # Rewrite wg0.conf so the removed peer loses access immediately.
    await sync_wireguard_config(db, tenant_id)
||||
|
||||
|
||||
async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> dict:
    """Get the full config for a peer — includes private key for device setup.

    Decrypts the peer's private key and preshared key and renders
    ready-to-paste RouterOS commands. Treat the returned dict as
    sensitive — it contains plaintext key material.

    Raises:
        ValueError: VPN not configured, or peer not found in this tenant.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config:
        raise ValueError("VPN not configured")

    result = await db.execute(
        select(VpnPeer).where(VpnPeer.id == peer_id, VpnPeer.tenant_id == tenant_id)
    )
    peer = result.scalar_one_or_none()
    if not peer:
        raise ValueError("Peer not found")

    key_bytes = settings.get_encryption_key_bytes()
    private_key = decrypt_credentials(peer.peer_private_key, key_bytes)
    psk = decrypt_credentials(peer.preshared_key, key_bytes) if peer.preshared_key else None

    # Placeholder endpoint when the operator has not set one yet.
    endpoint = config.endpoint or "YOUR_SERVER_IP:51820"

    # NOTE(review): endpoint.split(":") assumes "host:port" — a bracketed
    # IPv6 endpoint would split incorrectly; confirm endpoints are IPv4/DNS.
    # (An unused local `peer_ip_no_cidr` was removed here.)
    routeros_commands = [
        f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key}"',
        f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" '
        f'endpoint-address={endpoint.split(":")[0]} endpoint-port={endpoint.split(":")[-1]} '
        f'allowed-address={config.subnet} persistent-keepalive=25'
        + (f' preshared-key="{psk}"' if psk else ""),
        f"/ip address add address={peer.assigned_ip} interface=wg-portal",
    ]

    return {
        "peer_private_key": private_key,
        "peer_public_key": peer.peer_public_key,
        "assigned_ip": peer.assigned_ip,
        "server_public_key": config.server_public_key,
        "server_endpoint": endpoint,
        "allowed_ips": config.subnet,
        "routeros_commands": routeros_commands,
    }
||||
|
||||
|
||||
async def onboard_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    hostname: str,
    username: str,
    password: str,
) -> dict:
    """Create device + VPN peer in one transaction. Returns device, peer, and RouterOS commands.

    Unlike regular device creation, this skips TCP connectivity checks because
    the VPN tunnel isn't up yet. The device IP is set to the VPN-assigned address.

    Raises:
        ValueError: VPN not configured for this tenant, or the subnet is full.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config:
        raise ValueError("VPN not configured — enable VPN first")

    # Allocate VPN IP before creating device
    assigned_ip = await _next_available_ip(db, tenant_id, config)
    vpn_ip_no_cidr = assigned_ip.split("/")[0]

    # Create device with VPN IP (skip TCP check — tunnel not up yet).
    # Device credentials are transit-encrypted per tenant before storage.
    credentials_json = json.dumps({"username": username, "password": password})
    transit_ciphertext = await encrypt_credentials_transit(credentials_json, str(tenant_id))

    device = Device(
        tenant_id=tenant_id,
        hostname=hostname,
        ip_address=vpn_ip_no_cidr,
        api_port=8728,        # RouterOS API default port
        api_ssl_port=8729,    # RouterOS API-SSL default port
        encrypted_credentials_transit=transit_ciphertext,
        status="unknown",     # cannot probe until the tunnel comes up
    )
    db.add(device)
    await db.flush()  # need device.id for the peer row below

    # Create VPN peer linked to this device
    private_key_b64, public_key_b64 = generate_wireguard_keypair()
    psk = generate_preshared_key()

    # Secrets are encrypted before they touch the database.
    key_bytes = settings.get_encryption_key_bytes()
    encrypted_private = encrypt_credentials(private_key_b64, key_bytes)
    encrypted_psk = encrypt_credentials(psk, key_bytes)

    peer = VpnPeer(
        tenant_id=tenant_id,
        device_id=device.id,
        peer_private_key=encrypted_private,
        peer_public_key=public_key_b64,
        preshared_key=encrypted_psk,
        assigned_ip=assigned_ip,
    )
    db.add(peer)
    await db.flush()

    await sync_wireguard_config(db, tenant_id)

    # Generate RouterOS commands
    endpoint = config.endpoint or "YOUR_SERVER_IP:51820"
    psk_decrypted = decrypt_credentials(encrypted_psk, key_bytes)

    # NOTE(review): endpoint.split(":") assumes "host:port" — a bracketed
    # IPv6 endpoint would split incorrectly; confirm endpoints are IPv4/DNS.
    routeros_commands = [
        f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key_b64}"',
        f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" '
        f'endpoint-address={endpoint.split(":")[0]} endpoint-port={endpoint.split(":")[-1]} '
        f'allowed-address={config.subnet} persistent-keepalive=25'
        f' preshared-key="{psk_decrypted}"',
        f"/ip address add address={assigned_ip} interface=wg-portal",
    ]

    return {
        "device_id": device.id,
        "peer_id": peer.id,
        "hostname": hostname,
        "assigned_ip": assigned_ip,
        "routeros_commands": routeros_commands,
    }
|
||||
66
backend/app/templates/reports/alert_history.html
Normal file
66
backend/app/templates/reports/alert_history.html
Normal file
@@ -0,0 +1,66 @@
|
||||
{# Alert History report body. Extends the shared report shell.
   Expected context: date_from, date_to, total_alerts, critical_count,
   warning_count, info_count, optional mttr_minutes/mttr_display, and
   `alerts` — a list with fired_at, hostname, severity, message, status,
   duration. Badge classes come from base.html. #}
{% extends "reports/base.html" %}

{% block content %}
<div class="date-range">
    Report period: {{ date_from }} to {{ date_to }}
</div>

<div class="summary-box">
    <div class="summary-stat">
        <div class="value">{{ total_alerts }}</div>
        <div class="label">Total Alerts</div>
    </div>
    <div class="summary-stat">
        <div class="value" style="color: #991B1B;">{{ critical_count }}</div>
        <div class="label">Critical</div>
    </div>
    <div class="summary-stat">
        <div class="value" style="color: #92400E;">{{ warning_count }}</div>
        <div class="label">Warning</div>
    </div>
    <div class="summary-stat">
        <div class="value" style="color: #1E40AF;">{{ info_count }}</div>
        <div class="label">Info</div>
    </div>
    {# MTTR stat is optional — only rendered when the caller computed it. #}
    {% if mttr_minutes is not none %}
    <div class="summary-stat">
        <div class="value">{{ mttr_display }}</div>
        <div class="label">Avg MTTR</div>
    </div>
    {% endif %}
</div>

{% if alerts %}
<h2 class="section-title">Alert Events</h2>
<table>
    <thead>
        <tr>
            <th>Timestamp</th>
            <th>Device</th>
            <th>Severity</th>
            <th>Message</th>
            <th>Status</th>
            <th>Duration</th>
        </tr>
    </thead>
    <tbody>
        {% for alert in alerts %}
        <tr>
            <td>{{ alert.fired_at }}</td>
            <td style="font-weight: 600;">{{ alert.hostname or '-' }}</td>
            <td>
                <span class="badge badge-{{ alert.severity }}">{{ alert.severity | upper }}</span>
            </td>
            <td>{{ alert.message or '-' }}</td>
            <td>
                <span class="badge badge-{{ alert.status }}">{{ alert.status | upper }}</span>
            </td>
            <td>{{ alert.duration or '-' }}</td>
        </tr>
        {% endfor %}
    </tbody>
</table>
{% else %}
<div class="no-data">No alerts found for the selected date range.</div>
{% endif %}
{% endblock %}
|
||||
208
backend/app/templates/reports/base.html
Normal file
208
backend/app/templates/reports/base.html
Normal file
@@ -0,0 +1,208 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>{{ report_title }} - TOD</title>
|
||||
<style>
|
||||
@page {
|
||||
size: {% block page_size %}A4 portrait{% endblock %};
|
||||
margin: 20mm 15mm 25mm 15mm;
|
||||
@bottom-center {
|
||||
content: "Page " counter(page) " of " counter(pages);
|
||||
font-size: 9px;
|
||||
color: #64748B;
|
||||
}
|
||||
@bottom-right {
|
||||
content: "Generated by TOD";
|
||||
font-size: 9px;
|
||||
color: #64748B;
|
||||
}
|
||||
}
|
||||
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
|
||||
font-size: 11px;
|
||||
color: #1E293B;
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
.report-header {
|
||||
background: #1E293B;
|
||||
color: #FFFFFF;
|
||||
padding: 16px 20px;
|
||||
margin: -20mm -15mm 20px -15mm;
|
||||
padding-left: 15mm;
|
||||
padding-right: 15mm;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.report-header .brand {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.report-header .brand .logo {
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
background: #38BDF8;
|
||||
border-radius: 6px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-weight: 700;
|
||||
font-size: 16px;
|
||||
color: #0F172A;
|
||||
}
|
||||
|
||||
.report-header .brand .name {
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.report-header .meta {
|
||||
text-align: right;
|
||||
font-size: 10px;
|
||||
color: #94A3B8;
|
||||
}
|
||||
|
||||
.report-header .meta .title {
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
color: #FFFFFF;
|
||||
margin-bottom: 2px;
|
||||
}
|
||||
|
||||
.report-body {
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.section-title {
|
||||
font-size: 13px;
|
||||
font-weight: 600;
|
||||
color: #0F172A;
|
||||
border-bottom: 2px solid #38BDF8;
|
||||
padding-bottom: 6px;
|
||||
margin-bottom: 12px;
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.section-title:first-child {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
margin-bottom: 16px;
|
||||
font-size: 10px;
|
||||
}
|
||||
|
||||
thead th {
|
||||
background: #F1F5F9;
|
||||
color: #475569;
|
||||
font-weight: 600;
|
||||
text-align: left;
|
||||
padding: 8px 10px;
|
||||
border-bottom: 2px solid #E2E8F0;
|
||||
font-size: 9px;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
tbody td {
|
||||
padding: 7px 10px;
|
||||
border-bottom: 1px solid #F1F5F9;
|
||||
vertical-align: top;
|
||||
}
|
||||
|
||||
tbody tr:nth-child(even) {
|
||||
background: #FAFBFC;
|
||||
}
|
||||
|
||||
.summary-box {
|
||||
background: #F8FAFC;
|
||||
border: 1px solid #E2E8F0;
|
||||
border-radius: 6px;
|
||||
padding: 12px 16px;
|
||||
margin-bottom: 16px;
|
||||
display: flex;
|
||||
gap: 24px;
|
||||
}
|
||||
|
||||
.summary-stat {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.summary-stat .value {
|
||||
font-size: 20px;
|
||||
font-weight: 700;
|
||||
color: #0F172A;
|
||||
}
|
||||
|
||||
.summary-stat .label {
|
||||
font-size: 9px;
|
||||
color: #64748B;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.badge {
|
||||
display: inline-block;
|
||||
padding: 2px 8px;
|
||||
border-radius: 10px;
|
||||
font-size: 9px;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.badge-online { background: #DCFCE7; color: #166534; }
|
||||
.badge-offline { background: #FEE2E2; color: #991B1B; }
|
||||
.badge-unknown { background: #F1F5F9; color: #475569; }
|
||||
.badge-critical { background: #FEE2E2; color: #991B1B; }
|
||||
.badge-warning { background: #FEF3C7; color: #92400E; }
|
||||
.badge-info { background: #DBEAFE; color: #1E40AF; }
|
||||
.badge-firing { background: #FEE2E2; color: #991B1B; }
|
||||
.badge-resolved { background: #DCFCE7; color: #166534; }
|
||||
.badge-acknowledged { background: #DBEAFE; color: #1E40AF; }
|
||||
|
||||
.date-range {
|
||||
font-size: 11px;
|
||||
color: #64748B;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.no-data {
|
||||
text-align: center;
|
||||
padding: 40px 20px;
|
||||
color: #94A3B8;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
{% block extra_styles %}{% endblock %}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="report-header">
|
||||
<div class="brand">
|
||||
<div class="logo">T</div>
|
||||
<div class="name">TOD - The Other Dude</div>
|
||||
</div>
|
||||
<div class="meta">
|
||||
<div class="title">{{ report_title }}</div>
|
||||
<div>{{ tenant_name }} • Generated {{ generated_at }}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="report-body">
|
||||
{% block content %}{% endblock %}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
46
backend/app/templates/reports/change_log.html
Normal file
46
backend/app/templates/reports/change_log.html
Normal file
@@ -0,0 +1,46 @@
|
||||
{# Change Log report body. Extends the shared report shell.
   Expected context: date_from, date_to, total_entries, data_source, and
   `entries` — a list with timestamp, user, action, device, details. #}
{% extends "reports/base.html" %}

{% block content %}
<div class="date-range">
    Report period: {{ date_from }} to {{ date_to }}
</div>

<div class="summary-box">
    <div class="summary-stat">
        <div class="value">{{ total_entries }}</div>
        <div class="label">Total Changes</div>
    </div>
    <div class="summary-stat">
        <div class="value">{{ data_source }}</div>
        <div class="label">Data Source</div>
    </div>
</div>

{% if entries %}
<h2 class="section-title">Change Log</h2>
<table>
    <thead>
        <tr>
            <th>Timestamp</th>
            <th>User</th>
            <th>Action</th>
            <th>Device</th>
            <th>Details</th>
        </tr>
    </thead>
    <tbody>
        {% for entry in entries %}
        <tr>
            <td>{{ entry.timestamp }}</td>
            <td>{{ entry.user or '-' }}</td>
            <td>{{ entry.action }}</td>
            <td style="font-weight: 600;">{{ entry.device or '-' }}</td>
            <td>{{ entry.details or '-' }}</td>
        </tr>
        {% endfor %}
    </tbody>
</table>
{% else %}
<div class="no-data">No changes found for the selected date range.</div>
{% endif %}
{% endblock %}
|
||||
59
backend/app/templates/reports/device_inventory.html
Normal file
59
backend/app/templates/reports/device_inventory.html
Normal file
@@ -0,0 +1,59 @@
|
||||
{# Device Inventory report body. Extends the shared report shell; switches
   the page to landscape because the table is wide. Expected context:
   total_devices, online_count, offline_count, unknown_count, and
   `devices` — a list with hostname, ip_address, model, routeros_version,
   status, last_seen, uptime, groups. #}
{% extends "reports/base.html" %}

{# Wide table → landscape page. #}
{% block page_size %}A4 landscape{% endblock %}

{% block content %}
<div class="summary-box">
    <div class="summary-stat">
        <div class="value">{{ total_devices }}</div>
        <div class="label">Total Devices</div>
    </div>
    <div class="summary-stat">
        <div class="value" style="color: #166534;">{{ online_count }}</div>
        <div class="label">Online</div>
    </div>
    <div class="summary-stat">
        <div class="value" style="color: #991B1B;">{{ offline_count }}</div>
        <div class="label">Offline</div>
    </div>
    <div class="summary-stat">
        <div class="value" style="color: #475569;">{{ unknown_count }}</div>
        <div class="label">Unknown</div>
    </div>
</div>

{% if devices %}
<table>
    <thead>
        <tr>
            <th>Hostname</th>
            <th>IP Address</th>
            <th>Model</th>
            <th>RouterOS</th>
            <th>Status</th>
            <th>Last Seen</th>
            <th>Uptime</th>
            <th>Groups</th>
        </tr>
    </thead>
    <tbody>
        {% for device in devices %}
        <tr>
            <td style="font-weight: 600;">{{ device.hostname }}</td>
            <td>{{ device.ip_address }}</td>
            <td>{{ device.model or '-' }}</td>
            <td>{{ device.routeros_version or '-' }}</td>
            <td>
                <span class="badge badge-{{ device.status }}">{{ device.status | upper }}</span>
            </td>
            <td>{{ device.last_seen or '-' }}</td>
            <td>{{ device.uptime or '-' }}</td>
            <td>{{ device.groups or '-' }}</td>
        </tr>
        {% endfor %}
    </tbody>
</table>
{% else %}
<div class="no-data">No devices found for this tenant.</div>
{% endif %}
{% endblock %}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user