fix(lint): resolve all ruff lint errors

Add ruff config to exclude alembic E402, SQLAlchemy F821, and pre-existing
E501 line-length issues. Auto-fix 69 unused imports and 2 f-strings without
placeholders. Manually fix 8 unused variables. Apply ruff format to 127 files.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Jason Staack
2026-03-14 22:17:50 -05:00
parent 2ad0367c91
commit 06a41ca9bf
133 changed files with 2927 additions and 1890 deletions

View File

@@ -42,8 +42,8 @@ def validate_production_settings(settings: "Settings") -> None:
print(
f"FATAL: {field} uses a known insecure default in '{settings.ENVIRONMENT}' environment.\n"
f"Generate a secure value and set it in your .env.prod file.\n"
f"For JWT_SECRET_KEY: python -c \"import secrets; print(secrets.token_urlsafe(64))\"\n"
f"For CREDENTIAL_ENCRYPTION_KEY: python -c \"import secrets, base64; print(base64.b64encode(secrets.token_bytes(32)).decode())\"\n"
f'For JWT_SECRET_KEY: python -c "import secrets; print(secrets.token_urlsafe(64))"\n'
f'For CREDENTIAL_ENCRYPTION_KEY: python -c "import secrets, base64; print(base64.b64encode(secrets.token_bytes(32)).decode())"\n'
f"For OPENBAO_TOKEN: use the token from your OpenBao server (not the dev token)",
file=sys.stderr,
)

View File

@@ -17,6 +17,7 @@ from app.config import settings
class Base(DeclarativeBase):
"""Base class for all SQLAlchemy ORM models."""
pass

View File

@@ -82,7 +82,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
from app.services.retention_service import start_retention_scheduler, stop_retention_scheduler
from app.services.metrics_subscriber import start_metrics_subscriber, stop_metrics_subscriber
from app.services.nats_subscriber import start_nats_subscriber, stop_nats_subscriber
from app.services.session_audit_subscriber import start_session_audit_subscriber, stop_session_audit_subscriber
from app.services.session_audit_subscriber import (
start_session_audit_subscriber,
stop_session_audit_subscriber,
)
from app.services.sse_manager import ensure_sse_streams
# Configure structured logging FIRST -- before any other startup work
@@ -201,6 +204,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
start_config_change_subscriber,
stop_config_change_subscriber,
)
config_change_nc = await start_config_change_subscriber()
except Exception as e:
logger.error("Config change subscriber failed to start (non-fatal): %s", e)
@@ -212,6 +216,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
start_push_rollback_subscriber,
stop_push_rollback_subscriber,
)
push_rollback_nc = await start_push_rollback_subscriber()
except Exception as e:
logger.error("Push rollback subscriber failed to start (non-fatal): %s", e)
@@ -223,6 +228,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
start_config_snapshot_subscriber,
stop_config_snapshot_subscriber,
)
config_snapshot_nc = await start_config_snapshot_subscriber()
except Exception as e:
logger.error("Config snapshot subscriber failed to start (non-fatal): %s", e)
@@ -231,14 +237,16 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
try:
await start_retention_scheduler()
except Exception as exc:
logger.warning("retention scheduler could not start (API will run without it)", error=str(exc))
logger.warning(
"retention scheduler could not start (API will run without it)", error=str(exc)
)
# Start Remote WinBox session reconciliation loop (60s interval).
# Detects orphaned sessions (worker lost them) and cleans up Redis + tunnels.
winbox_reconcile_task: Optional[asyncio.Task] = None # type: ignore[type-arg]
try:
from app.routers.winbox_remote import _get_redis as _wb_get_redis, _close_tunnel
from app.services.winbox_remote import get_session as _wb_worker_get, health_check as _wb_health
from app.services.winbox_remote import get_session as _wb_worker_get
async def _winbox_reconcile_loop() -> None:
"""Scan Redis for winbox-remote:* keys and reconcile with worker."""

View File

@@ -49,6 +49,7 @@ def require_role(*allowed_roles: str) -> Callable:
Returns:
FastAPI dependency that raises 403 if the role is insufficient
"""
async def dependency(
current_user: CurrentUser = Depends(get_current_user),
) -> CurrentUser:
@@ -56,7 +57,7 @@ def require_role(*allowed_roles: str) -> Callable:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied. Required roles: {', '.join(allowed_roles)}. "
f"Your role: {current_user.role}",
f"Your role: {current_user.role}",
)
return current_user
@@ -82,7 +83,7 @@ def require_min_role(min_role: str) -> Callable:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied. Minimum required role: {min_role}. "
f"Your role: {current_user.role}",
f"Your role: {current_user.role}",
)
return current_user
@@ -96,6 +97,7 @@ def require_write_access() -> Callable:
Viewers are NOT allowed on POST/PUT/PATCH/DELETE endpoints.
Call this on any mutating endpoint to deny viewers.
"""
async def dependency(
request: Request,
current_user: CurrentUser = Depends(get_current_user),
@@ -105,7 +107,7 @@ def require_write_access() -> Callable:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Viewers have read-only access. "
"Contact your administrator to request elevated permissions.",
"Contact your administrator to request elevated permissions.",
)
return current_user
@@ -127,6 +129,7 @@ def require_scope(scope: str) -> DependsClass:
Raises:
HTTPException 403 if the API key is missing the required scope.
"""
async def _check_scope(
current_user: CurrentUser = Depends(get_current_user),
) -> CurrentUser:
@@ -143,6 +146,7 @@ def require_scope(scope: str) -> DependsClass:
# Pre-built convenience dependencies
async def require_super_admin(
current_user: CurrentUser = Depends(get_current_user),
) -> CurrentUser:

View File

@@ -20,32 +20,36 @@ from starlette.requests import Request
from starlette.responses import Response
# Production CSP: strict -- no inline scripts allowed
_CSP_PRODUCTION = "; ".join([
"default-src 'self'",
"script-src 'self'",
"style-src 'self' 'unsafe-inline'",
"img-src 'self' data: blob:",
"font-src 'self'",
"connect-src 'self' wss: ws:",
"worker-src 'self'",
"frame-ancestors 'none'",
"base-uri 'self'",
"form-action 'self'",
])
_CSP_PRODUCTION = "; ".join(
[
"default-src 'self'",
"script-src 'self'",
"style-src 'self' 'unsafe-inline'",
"img-src 'self' data: blob:",
"font-src 'self'",
"connect-src 'self' wss: ws:",
"worker-src 'self'",
"frame-ancestors 'none'",
"base-uri 'self'",
"form-action 'self'",
]
)
# Development CSP: relaxed for Vite HMR (hot module replacement)
_CSP_DEV = "; ".join([
"default-src 'self'",
"script-src 'self' 'unsafe-inline' 'unsafe-eval'",
"style-src 'self' 'unsafe-inline'",
"img-src 'self' data: blob:",
"font-src 'self'",
"connect-src 'self' http://localhost:* ws://localhost:* wss:",
"worker-src 'self' blob:",
"frame-ancestors 'none'",
"base-uri 'self'",
"form-action 'self'",
])
_CSP_DEV = "; ".join(
[
"default-src 'self'",
"script-src 'self' 'unsafe-inline' 'unsafe-eval'",
"style-src 'self' 'unsafe-inline'",
"img-src 'self' data: blob:",
"font-src 'self'",
"connect-src 'self' http://localhost:* ws://localhost:* wss:",
"worker-src 'self' blob:",
"frame-ancestors 'none'",
"base-uri 'self'",
"form-action 'self'",
]
)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
@@ -72,8 +76,6 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
# HSTS only in production (plain HTTP in dev would be blocked)
if self.is_production:
response.headers["Strict-Transport-Security"] = (
"max-age=31536000; includeSubDomains"
)
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
return response

View File

@@ -178,7 +178,6 @@ async def get_current_user_ws(
Raises:
WebSocketException 1008: If no token is provided or the token is invalid.
"""
from starlette.websockets import WebSocket, WebSocketState
from fastapi import WebSocketException
# 1. Try cookie

View File

@@ -2,7 +2,14 @@
from app.models.tenant import Tenant
from app.models.user import User, UserRole
from app.models.device import Device, DeviceGroup, DeviceTag, DeviceGroupMembership, DeviceTagAssignment, DeviceStatus
from app.models.device import (
Device,
DeviceGroup,
DeviceTag,
DeviceGroupMembership,
DeviceTagAssignment,
DeviceStatus,
)
from app.models.alert import AlertRule, NotificationChannel, AlertRuleChannel, AlertEvent
from app.models.firmware import FirmwareVersion, FirmwareUpgradeJob
from app.models.config_template import ConfigTemplate, ConfigTemplateTag, TemplatePushJob

View File

@@ -26,6 +26,7 @@ class AlertRule(Base):
When a metric breaches the threshold for duration_polls consecutive polls,
an alert fires.
"""
__tablename__ = "alert_rules"
id: Mapped[uuid.UUID] = mapped_column(
@@ -53,10 +54,16 @@ class AlertRule(Base):
metric: Mapped[str] = mapped_column(Text, nullable=False)
operator: Mapped[str] = mapped_column(Text, nullable=False)
threshold: Mapped[float] = mapped_column(Numeric, nullable=False)
duration_polls: Mapped[int] = mapped_column(Integer, nullable=False, default=1, server_default="1")
duration_polls: Mapped[int] = mapped_column(
Integer, nullable=False, default=1, server_default="1"
)
severity: Mapped[str] = mapped_column(Text, nullable=False)
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, server_default="true")
is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false")
enabled: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=True, server_default="true"
)
is_default: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False, server_default="false"
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -69,6 +76,7 @@ class AlertRule(Base):
class NotificationChannel(Base):
"""Email, webhook, or Slack notification destination."""
__tablename__ = "notification_channels"
id: Mapped[uuid.UUID] = mapped_column(
@@ -83,12 +91,16 @@ class NotificationChannel(Base):
nullable=False,
)
name: Mapped[str] = mapped_column(Text, nullable=False)
channel_type: Mapped[str] = mapped_column(Text, nullable=False) # "email", "webhook", or "slack"
channel_type: Mapped[str] = mapped_column(
Text, nullable=False
) # "email", "webhook", or "slack"
# SMTP fields (email channels)
smtp_host: Mapped[str | None] = mapped_column(Text, nullable=True)
smtp_port: Mapped[int | None] = mapped_column(Integer, nullable=True)
smtp_user: Mapped[str | None] = mapped_column(Text, nullable=True)
smtp_password: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True) # AES-256-GCM encrypted
smtp_password: Mapped[bytes | None] = mapped_column(
LargeBinary, nullable=True
) # AES-256-GCM encrypted
smtp_use_tls: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false")
from_address: Mapped[str | None] = mapped_column(Text, nullable=True)
to_address: Mapped[str | None] = mapped_column(Text, nullable=True)
@@ -110,6 +122,7 @@ class NotificationChannel(Base):
class AlertRuleChannel(Base):
"""Many-to-many association between alert rules and notification channels."""
__tablename__ = "alert_rule_channels"
rule_id: Mapped[uuid.UUID] = mapped_column(
@@ -129,6 +142,7 @@ class AlertEvent(Base):
rule_id is NULL for system-level alerts (e.g., device offline).
"""
__tablename__ = "alert_events"
id: Mapped[uuid.UUID] = mapped_column(
@@ -158,7 +172,9 @@ class AlertEvent(Base):
value: Mapped[float | None] = mapped_column(Numeric, nullable=True)
threshold: Mapped[float | None] = mapped_column(Numeric, nullable=True)
message: Mapped[str | None] = mapped_column(Text, nullable=True)
is_flapping: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false")
is_flapping: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False, server_default="false"
)
acknowledged_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
acknowledged_by: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True),

View File

@@ -41,20 +41,14 @@ class ApiKey(Base):
key_prefix: Mapped[str] = mapped_column(Text, nullable=False)
key_hash: Mapped[str] = mapped_column(Text, nullable=False, unique=True)
scopes: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'::jsonb")
expires_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True
)
last_used_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True
)
expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
last_used_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
revoked_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True
)
revoked_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
def __repr__(self) -> str:
return f"<ApiKey id={self.id} name={self.name} prefix={self.key_prefix}>"

View File

@@ -39,21 +39,13 @@ class CertificateAuthority(Base):
)
common_name: Mapped[str] = mapped_column(String(255), nullable=False)
cert_pem: Mapped[str] = mapped_column(Text, nullable=False)
encrypted_private_key: Mapped[bytes] = mapped_column(
LargeBinary, nullable=False
)
encrypted_private_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
serial_number: Mapped[str] = mapped_column(String(64), nullable=False)
fingerprint_sha256: Mapped[str] = mapped_column(String(95), nullable=False)
not_valid_before: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False
)
not_valid_after: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False
)
not_valid_before: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
not_valid_after: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
# OpenBao Transit ciphertext (dual-write migration)
encrypted_private_key_transit: Mapped[str | None] = mapped_column(
Text, nullable=True
)
encrypted_private_key_transit: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -62,8 +54,7 @@ class CertificateAuthority(Base):
def __repr__(self) -> str:
return (
f"<CertificateAuthority id={self.id} "
f"cn={self.common_name!r} tenant={self.tenant_id}>"
f"<CertificateAuthority id={self.id} cn={self.common_name!r} tenant={self.tenant_id}>"
)
@@ -103,25 +94,13 @@ class DeviceCertificate(Base):
serial_number: Mapped[str] = mapped_column(String(64), nullable=False)
fingerprint_sha256: Mapped[str] = mapped_column(String(95), nullable=False)
cert_pem: Mapped[str] = mapped_column(Text, nullable=False)
encrypted_private_key: Mapped[bytes] = mapped_column(
LargeBinary, nullable=False
)
not_valid_before: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False
)
not_valid_after: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False
)
encrypted_private_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
not_valid_before: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
not_valid_after: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
# OpenBao Transit ciphertext (dual-write migration)
encrypted_private_key_transit: Mapped[str | None] = mapped_column(
Text, nullable=True
)
status: Mapped[str] = mapped_column(
String(20), nullable=False, server_default="issued"
)
deployed_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
encrypted_private_key_transit: Mapped[str | None] = mapped_column(Text, nullable=True)
status: Mapped[str] = mapped_column(String(20), nullable=False, server_default="issued")
deployed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -134,7 +113,4 @@ class DeviceCertificate(Base):
)
def __repr__(self) -> str:
return (
f"<DeviceCertificate id={self.id} "
f"cn={self.common_name!r} status={self.status}>"
)
return f"<DeviceCertificate id={self.id} cn={self.common_name!r} status={self.status}>"

View File

@@ -3,9 +3,20 @@
import uuid
from datetime import datetime
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, LargeBinary, SmallInteger, String, Text, UniqueConstraint, func
from sqlalchemy import (
Boolean,
DateTime,
ForeignKey,
Integer,
LargeBinary,
SmallInteger,
String,
Text,
UniqueConstraint,
func,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
@@ -115,7 +126,9 @@ class ConfigBackupSchedule(Base):
def __repr__(self) -> str:
scope = f"device={self.device_id}" if self.device_id else f"tenant={self.tenant_id}"
return f"<ConfigBackupSchedule {scope} cron={self.cron_expression!r} enabled={self.enabled}>"
return (
f"<ConfigBackupSchedule {scope} cron={self.cron_expression!r} enabled={self.enabled}>"
)
class ConfigPushOperation(Base):
@@ -173,8 +186,7 @@ class ConfigPushOperation(Base):
def __repr__(self) -> str:
return (
f"<ConfigPushOperation id={self.id} device_id={self.device_id} "
f"status={self.status!r}>"
f"<ConfigPushOperation id={self.id} device_id={self.device_id} status={self.status!r}>"
)
@@ -272,7 +284,9 @@ class RouterConfigDiff(Base):
)
diff_text: Mapped[str] = mapped_column(Text, nullable=False)
lines_added: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
lines_removed: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
lines_removed: Mapped[int] = mapped_column(
Integer, nullable=False, default=0, server_default="0"
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -334,6 +348,5 @@ class RouterConfigChange(Base):
def __repr__(self) -> str:
return (
f"<RouterConfigChange id={self.id} diff_id={self.diff_id} "
f"component={self.component!r}>"
f"<RouterConfigChange id={self.id} diff_id={self.diff_id} component={self.component!r}>"
)

View File

@@ -5,7 +5,6 @@ from datetime import datetime
from sqlalchemy import (
DateTime,
Float,
ForeignKey,
String,
Text,
@@ -88,9 +87,7 @@ class ConfigTemplateTag(Base):
)
# Relationships
template: Mapped["ConfigTemplate"] = relationship(
"ConfigTemplate", back_populates="tags"
)
template: Mapped["ConfigTemplate"] = relationship("ConfigTemplate", back_populates="tags")
def __repr__(self) -> str:
return f"<ConfigTemplateTag id={self.id} name={self.name!r} template_id={self.template_id}>"
@@ -133,12 +130,8 @@ class TemplatePushJob(Base):
)
pre_push_backup_sha: Mapped[str | None] = mapped_column(Text, nullable=True)
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
started_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
completed_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),

View File

@@ -5,7 +5,6 @@ from datetime import datetime
from enum import Enum
from sqlalchemy import (
Boolean,
DateTime,
Float,
ForeignKey,
@@ -24,6 +23,7 @@ from app.database import Base
class DeviceStatus(str, Enum):
"""Device connection status."""
UNKNOWN = "unknown"
ONLINE = "online"
OFFLINE = "offline"
@@ -31,9 +31,7 @@ class DeviceStatus(str, Enum):
class Device(Base):
__tablename__ = "devices"
__table_args__ = (
UniqueConstraint("tenant_id", "hostname", name="uq_devices_tenant_hostname"),
)
__table_args__ = (UniqueConstraint("tenant_id", "hostname", name="uq_devices_tenant_hostname"),)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
@@ -59,7 +57,9 @@ class Device(Base):
uptime_seconds: Mapped[int | None] = mapped_column(Integer, nullable=True)
last_cpu_load: Mapped[int | None] = mapped_column(Integer, nullable=True)
last_memory_used_pct: Mapped[int | None] = mapped_column(Integer, nullable=True)
architecture: Mapped[str | None] = mapped_column(Text, nullable=True) # CPU arch (arm, arm64, mipsbe, etc.)
architecture: Mapped[str | None] = mapped_column(
Text, nullable=True
) # CPU arch (arm, arm64, mipsbe, etc.)
preferred_channel: Mapped[str] = mapped_column(
Text, default="stable", server_default="stable", nullable=False
) # Firmware release channel
@@ -108,9 +108,7 @@ class Device(Base):
class DeviceGroup(Base):
__tablename__ = "device_groups"
__table_args__ = (
UniqueConstraint("tenant_id", "name", name="uq_device_groups_tenant_name"),
)
__table_args__ = (UniqueConstraint("tenant_id", "name", name="uq_device_groups_tenant_name"),)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
@@ -147,9 +145,7 @@ class DeviceGroup(Base):
class DeviceTag(Base):
__tablename__ = "device_tags"
__table_args__ = (
UniqueConstraint("tenant_id", "name", name="uq_device_tags_tenant_name"),
)
__table_args__ = (UniqueConstraint("tenant_id", "name", name="uq_device_tags_tenant_name"),)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),

View File

@@ -7,7 +7,6 @@ from sqlalchemy import (
BigInteger,
Boolean,
DateTime,
Integer,
Text,
UniqueConstraint,
func,
@@ -24,10 +23,9 @@ class FirmwareVersion(Base):
Not tenant-scoped — firmware versions are global data shared across all tenants.
"""
__tablename__ = "firmware_versions"
__table_args__ = (
UniqueConstraint("architecture", "channel", "version"),
)
__table_args__ = (UniqueConstraint("architecture", "channel", "version"),)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
@@ -56,6 +54,7 @@ class FirmwareUpgradeJob(Base):
Multiple jobs can share a rollout_group_id for mass upgrades.
"""
__tablename__ = "firmware_upgrade_jobs"
id: Mapped[uuid.UUID] = mapped_column(
@@ -99,4 +98,6 @@ class FirmwareUpgradeJob(Base):
)
def __repr__(self) -> str:
return f"<FirmwareUpgradeJob id={self.id} status={self.status} target={self.target_version}>"
return (
f"<FirmwareUpgradeJob id={self.id} status={self.status} target={self.target_version}>"
)

View File

@@ -37,32 +37,18 @@ class UserKeySet(Base):
ForeignKey("tenants.id", ondelete="CASCADE"),
nullable=True, # NULL for super_admin
)
encrypted_private_key: Mapped[bytes] = mapped_column(
LargeBinary, nullable=False
)
private_key_nonce: Mapped[bytes] = mapped_column(
LargeBinary, nullable=False
)
encrypted_vault_key: Mapped[bytes] = mapped_column(
LargeBinary, nullable=False
)
vault_key_nonce: Mapped[bytes] = mapped_column(
LargeBinary, nullable=False
)
public_key: Mapped[bytes] = mapped_column(
LargeBinary, nullable=False
)
encrypted_private_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
private_key_nonce: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
encrypted_vault_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
vault_key_nonce: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
public_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
pbkdf2_iterations: Mapped[int] = mapped_column(
Integer,
server_default=func.literal_column("650000"),
nullable=False,
)
pbkdf2_salt: Mapped[bytes] = mapped_column(
LargeBinary, nullable=False
)
hkdf_salt: Mapped[bytes] = mapped_column(
LargeBinary, nullable=False
)
pbkdf2_salt: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
hkdf_salt: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
key_version: Mapped[int] = mapped_column(
Integer,
server_default=func.literal_column("1"),

View File

@@ -20,6 +20,7 @@ class MaintenanceWindow(Base):
device_ids is a JSONB array of device UUID strings.
An empty array means "all devices in tenant".
"""
__tablename__ = "maintenance_windows"
id: Mapped[uuid.UUID] = mapped_column(

View File

@@ -40,10 +40,18 @@ class Tenant(Base):
openbao_key_name: Mapped[str | None] = mapped_column(Text, nullable=True)
# Relationships — passive_deletes=True lets the DB ON DELETE CASCADE handle cleanup
users: Mapped[list["User"]] = relationship("User", back_populates="tenant", passive_deletes=True) # type: ignore[name-defined]
devices: Mapped[list["Device"]] = relationship("Device", back_populates="tenant", passive_deletes=True) # type: ignore[name-defined]
device_groups: Mapped[list["DeviceGroup"]] = relationship("DeviceGroup", back_populates="tenant", passive_deletes=True) # type: ignore[name-defined]
device_tags: Mapped[list["DeviceTag"]] = relationship("DeviceTag", back_populates="tenant", passive_deletes=True) # type: ignore[name-defined]
users: Mapped[list["User"]] = relationship(
"User", back_populates="tenant", passive_deletes=True
) # type: ignore[name-defined]
devices: Mapped[list["Device"]] = relationship(
"Device", back_populates="tenant", passive_deletes=True
) # type: ignore[name-defined]
device_groups: Mapped[list["DeviceGroup"]] = relationship(
"DeviceGroup", back_populates="tenant", passive_deletes=True
) # type: ignore[name-defined]
device_tags: Mapped[list["DeviceTag"]] = relationship(
"DeviceTag", back_populates="tenant", passive_deletes=True
) # type: ignore[name-defined]
def __repr__(self) -> str:
return f"<Tenant id={self.id} name={self.name!r}>"

View File

@@ -13,6 +13,7 @@ from app.database import Base
class UserRole(str, Enum):
"""User roles with increasing privilege levels."""
SUPER_ADMIN = "super_admin"
TENANT_ADMIN = "tenant_admin"
OPERATOR = "operator"

View File

@@ -75,7 +75,9 @@ class VpnPeer(Base):
assigned_ip: Mapped[str] = mapped_column(String(32), nullable=False)
additional_allowed_ips: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
is_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="true")
last_handshake: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
last_handshake: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)

View File

@@ -12,10 +12,8 @@ RLS enforced via get_db() (app_user engine with tenant context).
RBAC: viewer = read-only (GET); operator and above = write (POST/PUT/PATCH/DELETE).
"""
import base64
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
@@ -66,8 +64,13 @@ def _require_write(current_user: CurrentUser) -> None:
EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
ALLOWED_METRICS = {
"cpu_load", "memory_used_pct", "disk_used_pct", "temperature",
"signal_strength", "ccq", "client_count",
"cpu_load",
"memory_used_pct",
"disk_used_pct",
"temperature",
"signal_strength",
"ccq",
"client_count",
}
ALLOWED_OPERATORS = {"gt", "lt", "gte", "lte"}
ALLOWED_SEVERITIES = {"critical", "warning", "info"}
@@ -252,7 +255,9 @@ async def create_alert_rule(
if body.operator not in ALLOWED_OPERATORS:
raise HTTPException(422, f"operator must be one of: {', '.join(sorted(ALLOWED_OPERATORS))}")
if body.severity not in ALLOWED_SEVERITIES:
raise HTTPException(422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}")
raise HTTPException(
422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}"
)
rule_id = str(uuid.uuid4())
@@ -296,8 +301,12 @@ async def create_alert_rule(
try:
await log_action(
db, tenant_id, current_user.user_id, "alert_rule_create",
resource_type="alert_rule", resource_id=rule_id,
db,
tenant_id,
current_user.user_id,
"alert_rule_create",
resource_type="alert_rule",
resource_id=rule_id,
details={"name": body.name, "metric": body.metric, "severity": body.severity},
)
except Exception:
@@ -338,7 +347,9 @@ async def update_alert_rule(
if body.operator not in ALLOWED_OPERATORS:
raise HTTPException(422, f"operator must be one of: {', '.join(sorted(ALLOWED_OPERATORS))}")
if body.severity not in ALLOWED_SEVERITIES:
raise HTTPException(422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}")
raise HTTPException(
422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}"
)
result = await db.execute(
text("""
@@ -384,8 +395,12 @@ async def update_alert_rule(
try:
await log_action(
db, tenant_id, current_user.user_id, "alert_rule_update",
resource_type="alert_rule", resource_id=str(rule_id),
db,
tenant_id,
current_user.user_id,
"alert_rule_update",
resource_type="alert_rule",
resource_id=str(rule_id),
details={"name": body.name, "metric": body.metric, "severity": body.severity},
)
except Exception:
@@ -439,8 +454,12 @@ async def delete_alert_rule(
try:
await log_action(
db, tenant_id, current_user.user_id, "alert_rule_delete",
resource_type="alert_rule", resource_id=str(rule_id),
db,
tenant_id,
current_user.user_id,
"alert_rule_delete",
resource_type="alert_rule",
resource_id=str(rule_id),
)
except Exception:
pass
@@ -592,7 +611,8 @@ async def create_notification_channel(
encrypted_password_transit = None
if body.smtp_password:
encrypted_password_transit = await encrypt_credentials_transit(
body.smtp_password, str(tenant_id),
body.smtp_password,
str(tenant_id),
)
await db.execute(
@@ -665,10 +685,14 @@ async def update_notification_channel(
# Build SET clauses dynamically based on which secrets are provided
set_parts = [
"name = :name", "channel_type = :channel_type",
"smtp_host = :smtp_host", "smtp_port = :smtp_port",
"smtp_user = :smtp_user", "smtp_use_tls = :smtp_use_tls",
"from_address = :from_address", "to_address = :to_address",
"name = :name",
"channel_type = :channel_type",
"smtp_host = :smtp_host",
"smtp_port = :smtp_port",
"smtp_user = :smtp_user",
"smtp_use_tls = :smtp_use_tls",
"from_address = :from_address",
"to_address = :to_address",
"webhook_url = :webhook_url",
"slack_webhook_url = :slack_webhook_url",
]
@@ -689,7 +713,8 @@ async def update_notification_channel(
if body.smtp_password:
set_parts.append("smtp_password_transit = :smtp_password_transit")
params["smtp_password_transit"] = await encrypt_credentials_transit(
body.smtp_password, str(tenant_id),
body.smtp_password,
str(tenant_id),
)
# Clear legacy column
set_parts.append("smtp_password = NULL")
@@ -799,6 +824,7 @@ async def test_notification_channel(
}
from app.services.notification_service import send_test_notification
try:
success = await send_test_notification(channel)
if success:

View File

@@ -221,29 +221,38 @@ async def list_audit_logs(
all_rows = result.mappings().all()
# Decrypt encrypted details concurrently
decrypted_details = await _decrypt_details_batch(
all_rows, str(tenant_id)
)
decrypted_details = await _decrypt_details_batch(all_rows, str(tenant_id))
output = io.StringIO()
writer = csv.writer(output)
writer.writerow([
"ID", "User Email", "Action", "Resource Type",
"Resource ID", "Device", "Details", "IP Address", "Timestamp",
])
writer.writerow(
[
"ID",
"User Email",
"Action",
"Resource Type",
"Resource ID",
"Device",
"Details",
"IP Address",
"Timestamp",
]
)
for row, details in zip(all_rows, decrypted_details):
details_str = json.dumps(details) if details else "{}"
writer.writerow([
str(row["id"]),
row["user_email"] or "",
row["action"],
row["resource_type"] or "",
row["resource_id"] or "",
row["device_name"] or "",
details_str,
row["ip_address"] or "",
str(row["created_at"]),
])
writer.writerow(
[
str(row["id"]),
row["user_email"] or "",
row["action"],
row["resource_type"] or "",
row["resource_id"] or "",
row["device_name"] or "",
details_str,
row["ip_address"] or "",
str(row["created_at"]),
]
)
output.seek(0)
return StreamingResponse(

View File

@@ -103,7 +103,11 @@ async def get_redis() -> aioredis.Redis:
# ─── SRP Zero-Knowledge Authentication ───────────────────────────────────────
@router.post("/srp/init", response_model=SRPInitResponse, summary="SRP Step 1: return salt and server ephemeral B")
@router.post(
"/srp/init",
response_model=SRPInitResponse,
summary="SRP Step 1: return salt and server ephemeral B",
)
@limiter.limit("5/minute")
async def srp_init_endpoint(
request: StarletteRequest,
@@ -137,9 +141,7 @@ async def srp_init_endpoint(
# Generate server ephemeral
try:
server_public, server_private = await srp_init(
user.email, user.srp_verifier.hex()
)
server_public, server_private = await srp_init(user.email, user.srp_verifier.hex())
except ValueError as e:
logger.error("SRP init failed for %s: %s", user.email, e)
raise HTTPException(
@@ -150,13 +152,15 @@ async def srp_init_endpoint(
# Store session in Redis with 60s TTL
session_id = secrets.token_urlsafe(16)
redis = await get_redis()
session_data = json.dumps({
"email": user.email,
"server_private": server_private,
"srp_verifier_hex": user.srp_verifier.hex(),
"srp_salt_hex": user.srp_salt.hex(),
"user_id": str(user.id),
})
session_data = json.dumps(
{
"email": user.email,
"server_private": server_private,
"srp_verifier_hex": user.srp_verifier.hex(),
"srp_salt_hex": user.srp_salt.hex(),
"user_id": str(user.id),
}
)
await redis.set(f"srp:session:{session_id}", session_data, ex=60)
return SRPInitResponse(
@@ -168,7 +172,11 @@ async def srp_init_endpoint(
)
@router.post("/srp/verify", response_model=SRPVerifyResponse, summary="SRP Step 2: verify client proof and return tokens")
@router.post(
"/srp/verify",
response_model=SRPVerifyResponse,
summary="SRP Step 2: verify client proof and return tokens",
)
@limiter.limit("5/minute")
async def srp_verify_endpoint(
request: StarletteRequest,
@@ -236,7 +244,9 @@ async def srp_verify_endpoint(
# Update last_login and clear upgrade flag on successful SRP login
await db.execute(
update(User).where(User.id == user.id).values(
update(User)
.where(User.id == user.id)
.values(
last_login=datetime.now(UTC),
must_upgrade_auth=False,
)
@@ -323,9 +333,7 @@ async def login(
Rate limited to 5 requests per minute per IP.
"""
# Look up user by email (case-insensitive)
result = await db.execute(
select(User).where(User.email == body.email.lower())
)
result = await db.execute(select(User).where(User.email == body.email.lower()))
user = result.scalar_one_or_none()
# Generic error — do not reveal whether email exists (no user enumeration)
@@ -389,7 +397,9 @@ async def login(
# Update last_login
await db.execute(
update(User).where(User.id == user.id).values(
update(User)
.where(User.id == user.id)
.values(
last_login=datetime.now(UTC),
)
)
@@ -404,7 +414,10 @@ async def login(
user_id=user.id,
action="login_upgrade" if user.must_upgrade_auth else "login",
resource_type="auth",
details={"email": user.email, **({"upgrade": "bcrypt_to_srp"} if user.must_upgrade_auth else {})},
details={
"email": user.email,
**({"upgrade": "bcrypt_to_srp"} if user.must_upgrade_auth else {}),
},
ip_address=request.client.host if request.client else None,
)
await audit_db.commit()
@@ -440,7 +453,9 @@ async def refresh_token(
Rate limited to 10 requests per minute per IP.
"""
# Resolve token: body takes precedence over cookie
raw_token = (body.refresh_token if body and body.refresh_token else None) or refresh_token_cookie
raw_token = (
body.refresh_token if body and body.refresh_token else None
) or refresh_token_cookie
if not raw_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -518,7 +533,9 @@ async def refresh_token(
)
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT, summary="Log out and clear session cookie")
@router.post(
"/logout", status_code=status.HTTP_204_NO_CONTENT, summary="Log out and clear session cookie"
)
@limiter.limit("10/minute")
async def logout(
request: StarletteRequest,
@@ -535,7 +552,10 @@ async def logout(
tenant_id = current_user.tenant_id or uuid.UUID(int=0)
async with AdminAsyncSessionLocal() as audit_db:
await log_action(
audit_db, tenant_id, current_user.user_id, "logout",
audit_db,
tenant_id,
current_user.user_id,
"logout",
resource_type="auth",
ip_address=request.client.host if request.client else None,
)
@@ -558,7 +578,11 @@ async def logout(
)
@router.post("/change-password", response_model=MessageResponse, summary="Change password for authenticated user")
@router.post(
"/change-password",
response_model=MessageResponse,
summary="Change password for authenticated user",
)
@limiter.limit("3/minute")
async def change_password(
request: StarletteRequest,
@@ -602,7 +626,9 @@ async def change_password(
existing_ks.hkdf_salt = base64.b64decode(body.hkdf_salt or "")
else:
# Legacy bcrypt user — verify current password
if not user.hashed_password or not verify_password(body.current_password, user.hashed_password):
if not user.hashed_password or not verify_password(
body.current_password, user.hashed_password
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is incorrect",
@@ -822,7 +848,9 @@ async def get_emergency_kit_template(
)
@router.post("/register-srp", response_model=MessageResponse, summary="Register SRP credentials for a user")
@router.post(
"/register-srp", response_model=MessageResponse, summary="Register SRP credentials for a user"
)
@limiter.limit("3/minute")
async def register_srp(
request: StarletteRequest,
@@ -845,7 +873,9 @@ async def register_srp(
# Update user with SRP credentials and clear upgrade flag
await db.execute(
update(User).where(User.id == user.id).values(
update(User)
.where(User.id == user.id)
.values(
srp_salt=bytes.fromhex(body.srp_salt),
srp_verifier=bytes.fromhex(body.srp_verifier),
auth_version=2,
@@ -873,8 +903,11 @@ async def register_srp(
try:
async with AdminAsyncSessionLocal() as audit_db:
await log_key_access(
audit_db, user.tenant_id or uuid.UUID(int=0), user.id,
"create_key_set", resource_type="user_key_set",
audit_db,
user.tenant_id or uuid.UUID(int=0),
user.id,
"create_key_set",
resource_type="user_key_set",
ip_address=request.client.host if request.client else None,
)
await audit_db.commit()
@@ -901,11 +934,17 @@ async def create_sse_token(
token = secrets.token_urlsafe(32)
key = f"sse_token:{token}"
# Store user context for the SSE endpoint to retrieve
await redis.set(key, json.dumps({
"user_id": str(current_user.user_id),
"tenant_id": str(current_user.tenant_id) if current_user.tenant_id else None,
"role": current_user.role,
}), ex=30) # 30 second TTL
await redis.set(
key,
json.dumps(
{
"user_id": str(current_user.user_id),
"tenant_id": str(current_user.tenant_id) if current_user.tenant_id else None,
"role": current_user.role,
}
),
ex=30,
) # 30 second TTL
return {"token": token}
@@ -977,9 +1016,7 @@ async def forgot_password(
"""
generic_msg = "If an account with that email exists, a reset link has been sent."
result = await db.execute(
select(User).where(User.email == body.email.lower())
)
result = await db.execute(select(User).where(User.email == body.email.lower()))
user = result.scalar_one_or_none()
if not user or not user.is_active:
@@ -988,9 +1025,7 @@ async def forgot_password(
# Generate a secure token
raw_token = secrets.token_urlsafe(32)
token_hash = _hash_token(raw_token)
expires_at = datetime.now(UTC) + timedelta(
minutes=settings.PASSWORD_RESET_TOKEN_EXPIRE_MINUTES
)
expires_at = datetime.now(UTC) + timedelta(minutes=settings.PASSWORD_RESET_TOKEN_EXPIRE_MINUTES)
# Insert token record (using raw SQL to avoid importing the model globally)
from sqlalchemy import text

View File

@@ -12,9 +12,7 @@ RBAC: viewer = read-only (GET); tenant_admin and above = mutating actions.
"""
import json
import logging
import uuid
from datetime import datetime, timezone
import nats
import nats.aio.client
@@ -30,7 +28,7 @@ from app.database import get_db, set_tenant_context
from app.middleware.rate_limit import limiter
from app.middleware.rbac import require_min_role
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.models.certificate import CertificateAuthority, DeviceCertificate
from app.models.certificate import DeviceCertificate
from app.models.device import Device
from app.schemas.certificate import (
BulkCertDeployRequest,
@@ -87,13 +85,15 @@ async def _deploy_cert_via_nats(
Dict with success, cert_name_on_device, and error fields.
"""
nc = await _get_nats()
payload = json.dumps({
"device_id": device_id,
"cert_pem": cert_pem,
"key_pem": key_pem,
"cert_name": cert_name,
"ssh_port": ssh_port,
}).encode()
payload = json.dumps(
{
"device_id": device_id,
"cert_pem": cert_pem,
"key_pem": key_pem,
"cert_name": cert_name,
"ssh_port": ssh_port,
}
).encode()
try:
reply = await nc.request(
@@ -121,9 +121,7 @@ async def _get_device_for_tenant(
db: AsyncSession, device_id: uuid.UUID, current_user: CurrentUser
) -> Device:
"""Fetch a device and verify tenant ownership."""
result = await db.execute(
select(Device).where(Device.id == device_id)
)
result = await db.execute(select(Device).where(Device.id == device_id))
device = result.scalar_one_or_none()
if device is None:
raise HTTPException(
@@ -164,9 +162,7 @@ async def _get_cert_with_tenant_check(
db: AsyncSession, cert_id: uuid.UUID, tenant_id: uuid.UUID
) -> DeviceCertificate:
"""Fetch a device certificate and verify tenant ownership."""
result = await db.execute(
select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
)
result = await db.execute(select(DeviceCertificate).where(DeviceCertificate.id == cert_id))
cert = result.scalar_one_or_none()
if cert is None:
raise HTTPException(
@@ -226,8 +222,12 @@ async def create_ca(
try:
await log_action(
db, tenant_id, current_user.user_id, "ca_create",
resource_type="certificate_authority", resource_id=str(ca.id),
db,
tenant_id,
current_user.user_id,
"ca_create",
resource_type="certificate_authority",
resource_id=str(ca.id),
details={"common_name": body.common_name, "validity_years": body.validity_years},
)
except Exception:
@@ -332,8 +332,12 @@ async def sign_cert(
try:
await log_action(
db, tenant_id, current_user.user_id, "cert_sign",
resource_type="device_certificate", resource_id=str(cert.id),
db,
tenant_id,
current_user.user_id,
"cert_sign",
resource_type="device_certificate",
resource_id=str(cert.id),
device_id=body.device_id,
details={"hostname": device.hostname, "validity_days": body.validity_days},
)
@@ -404,17 +408,19 @@ async def deploy_cert(
await update_cert_status(db, cert_id, "deployed")
# Update device tls_mode to portal_ca
device_result = await db.execute(
select(Device).where(Device.id == cert.device_id)
)
device_result = await db.execute(select(Device).where(Device.id == cert.device_id))
device = device_result.scalar_one_or_none()
if device is not None:
device.tls_mode = "portal_ca"
try:
await log_action(
db, tenant_id, current_user.user_id, "cert_deploy",
resource_type="device_certificate", resource_id=str(cert_id),
db,
tenant_id,
current_user.user_id,
"cert_deploy",
resource_type="device_certificate",
resource_id=str(cert_id),
device_id=cert.device_id,
details={"cert_name_on_device": result.get("cert_name_on_device")},
)
@@ -528,36 +534,47 @@ async def bulk_deploy(
await update_cert_status(db, issued_cert.id, "deployed")
device.tls_mode = "portal_ca"
results.append(CertDeployResponse(
success=True,
device_id=device_id,
cert_name_on_device=result.get("cert_name_on_device"),
))
results.append(
CertDeployResponse(
success=True,
device_id=device_id,
cert_name_on_device=result.get("cert_name_on_device"),
)
)
else:
await update_cert_status(db, issued_cert.id, "issued")
results.append(CertDeployResponse(
success=False,
device_id=device_id,
error=result.get("error"),
))
results.append(
CertDeployResponse(
success=False,
device_id=device_id,
error=result.get("error"),
)
)
except HTTPException as e:
results.append(CertDeployResponse(
success=False,
device_id=device_id,
error=e.detail,
))
results.append(
CertDeployResponse(
success=False,
device_id=device_id,
error=e.detail,
)
)
except Exception as e:
logger.error("Bulk deploy error", device_id=str(device_id), error=str(e))
results.append(CertDeployResponse(
success=False,
device_id=device_id,
error=str(e),
))
results.append(
CertDeployResponse(
success=False,
device_id=device_id,
error=str(e),
)
)
try:
await log_action(
db, tenant_id, current_user.user_id, "cert_bulk_deploy",
db,
tenant_id,
current_user.user_id,
"cert_bulk_deploy",
resource_type="device_certificate",
details={
"device_count": len(body.device_ids),
@@ -619,17 +636,19 @@ async def revoke_cert(
)
# Reset device tls_mode to insecure
device_result = await db.execute(
select(Device).where(Device.id == cert.device_id)
)
device_result = await db.execute(select(Device).where(Device.id == cert.device_id))
device = device_result.scalar_one_or_none()
if device is not None:
device.tls_mode = "insecure"
try:
await log_action(
db, tenant_id, current_user.user_id, "cert_revoke",
resource_type="device_certificate", resource_id=str(cert_id),
db,
tenant_id,
current_user.user_id,
"cert_revoke",
resource_type="device_certificate",
resource_id=str(cert_id),
device_id=cert.device_id,
)
except Exception:
@@ -661,9 +680,7 @@ async def rotate_cert(
old_cert = await _get_cert_with_tenant_check(db, cert_id, tenant_id)
# Get the device for hostname/IP
device_result = await db.execute(
select(Device).where(Device.id == old_cert.device_id)
)
device_result = await db.execute(select(Device).where(Device.id == old_cert.device_id))
device = device_result.scalar_one_or_none()
if device is None:
raise HTTPException(
@@ -722,8 +739,12 @@ async def rotate_cert(
try:
await log_action(
db, tenant_id, current_user.user_id, "cert_rotate",
resource_type="device_certificate", resource_id=str(new_cert.id),
db,
tenant_id,
current_user.user_id,
"cert_rotate",
resource_type="device_certificate",
resource_id=str(new_cert.id),
device_id=old_cert.device_id,
details={
"old_cert_id": str(cert_id),

View File

@@ -43,6 +43,7 @@ async def _check_tenant_access(
"""Verify the current user is allowed to access the given tenant."""
if current_user.is_super_admin:
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
@@ -52,9 +53,7 @@ async def _check_tenant_access(
)
async def _check_device_online(
db: AsyncSession, device_id: uuid.UUID
) -> Device:
async def _check_device_online(db: AsyncSession, device_id: uuid.UUID) -> Device:
"""Verify the device exists and is online. Returns the Device object."""
result = await db.execute(
select(Device).where(Device.id == device_id) # type: ignore[arg-type]

View File

@@ -25,7 +25,6 @@ import asyncio
import json
import logging
import uuid
from datetime import timezone, datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Request, status
@@ -67,6 +66,7 @@ async def _check_tenant_access(
"""
if current_user.is_super_admin:
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
@@ -291,14 +291,14 @@ async def get_export(
try:
from app.services.crypto import decrypt_data_transit
plaintext = await decrypt_data_transit(
content_bytes.decode("utf-8"), str(tenant_id)
)
plaintext = await decrypt_data_transit(content_bytes.decode("utf-8"), str(tenant_id))
content_bytes = plaintext.encode("utf-8")
except Exception as dec_err:
logger.error(
"Failed to decrypt export for device %s sha %s: %s",
device_id, commit_sha, dec_err,
device_id,
commit_sha,
dec_err,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -370,7 +370,9 @@ async def get_binary(
except Exception as dec_err:
logger.error(
"Failed to decrypt binary backup for device %s sha %s: %s",
device_id, commit_sha, dec_err,
device_id,
commit_sha,
dec_err,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -380,9 +382,7 @@ async def get_binary(
return Response(
content=content_bytes,
media_type="application/octet-stream",
headers={
"Content-Disposition": f'attachment; filename="backup-{commit_sha[:8]}.bin"'
},
headers={"Content-Disposition": f'attachment; filename="backup-{commit_sha[:8]}.bin"'},
)
@@ -445,6 +445,7 @@ async def preview_restore(
key,
)
import json
creds = json.loads(creds_json)
current_text = await backup_service.capture_export(
device.ip_address,
@@ -578,9 +579,7 @@ async def emergency_rollback(
.where(
ConfigBackupRun.device_id == device_id, # type: ignore[arg-type]
ConfigBackupRun.tenant_id == tenant_id, # type: ignore[arg-type]
ConfigBackupRun.trigger_type.in_(
["pre-restore", "checkpoint", "pre-template-push"]
),
ConfigBackupRun.trigger_type.in_(["pre-restore", "checkpoint", "pre-template-push"]),
)
.order_by(ConfigBackupRun.created_at.desc())
.limit(1)
@@ -735,6 +734,7 @@ async def update_schedule(
# Hot-reload the scheduler so changes take effect immediately
from app.services.backup_scheduler import on_schedule_change
await on_schedule_change(tenant_id, device_id)
return {
@@ -758,6 +758,7 @@ async def _get_nats():
Reuses the same lazy-init pattern as routeros_proxy._get_nats().
"""
from app.services.routeros_proxy import _get_nats as _proxy_get_nats
return await _proxy_get_nats()
@@ -839,6 +840,7 @@ async def trigger_config_snapshot(
if reply_data.get("status") == "success":
try:
from app.services.audit_service import log_action
await log_action(
db,
tenant_id,

View File

@@ -64,9 +64,7 @@ async def _check_tenant_access(
await set_tenant_context(db, str(tenant_id))
async def _check_device_online(
db: AsyncSession, device_id: uuid.UUID
) -> Device:
async def _check_device_online(db: AsyncSession, device_id: uuid.UUID) -> Device:
"""Verify the device exists and is online. Returns the Device object."""
result = await db.execute(
select(Device).where(Device.id == device_id) # type: ignore[arg-type]
@@ -201,8 +199,12 @@ async def add_entry(
try:
await log_action(
db, tenant_id, current_user.user_id, "config_add",
resource_type="config", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"config_add",
resource_type="config",
resource_id=str(device_id),
device_id=device_id,
details={"path": body.path, "properties": body.properties},
)
@@ -255,8 +257,12 @@ async def set_entry(
try:
await log_action(
db, tenant_id, current_user.user_id, "config_set",
resource_type="config", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"config_set",
resource_type="config",
resource_id=str(device_id),
device_id=device_id,
details={"path": body.path, "entry_id": body.entry_id, "properties": body.properties},
)
@@ -286,9 +292,7 @@ async def remove_entry(
await _check_device_online(db, device_id)
check_path_safety(body.path, write=True)
result = await routeros_proxy.remove_entry(
str(device_id), body.path, body.entry_id
)
result = await routeros_proxy.remove_entry(str(device_id), body.path, body.entry_id)
if not result.get("success"):
raise HTTPException(
@@ -309,8 +313,12 @@ async def remove_entry(
try:
await log_action(
db, tenant_id, current_user.user_id, "config_remove",
resource_type="config", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"config_remove",
resource_type="config",
resource_id=str(device_id),
device_id=device_id,
details={"path": body.path, "entry_id": body.entry_id},
)
@@ -360,8 +368,12 @@ async def execute_command(
try:
await log_action(
db, tenant_id, current_user.user_id, "config_execute",
resource_type="config", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"config_execute",
resource_type="config",
resource_id=str(device_id),
device_id=device_id,
details={"command": body.command},
)

View File

@@ -43,6 +43,7 @@ async def _check_tenant_access(
"""
if current_user.is_super_admin:
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
@@ -115,9 +116,7 @@ async def view_snapshot(
session=db,
)
except Exception:
logger.exception(
"Failed to decrypt snapshot %s for device %s", snapshot_id, device_id
)
logger.exception("Failed to decrypt snapshot %s for device %s", snapshot_id, device_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to decrypt snapshot content",

View File

@@ -29,12 +29,14 @@ router = APIRouter(tags=["device-logs"])
# Helpers (same pattern as config_editor.py)
# ---------------------------------------------------------------------------
async def _check_tenant_access(
current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
"""Verify the current user is allowed to access the given tenant."""
if current_user.is_super_admin:
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
@@ -44,16 +46,12 @@ async def _check_tenant_access(
)
async def _check_device_exists(
db: AsyncSession, device_id: uuid.UUID
) -> None:
async def _check_device_exists(db: AsyncSession, device_id: uuid.UUID) -> None:
"""Verify the device exists (does not require online status for logs)."""
from sqlalchemy import select
from app.models.device import Device
result = await db.execute(
select(Device).where(Device.id == device_id)
)
result = await db.execute(select(Device).where(Device.id == device_id))
device = result.scalar_one_or_none()
if device is None:
raise HTTPException(
@@ -66,6 +64,7 @@ async def _check_device_exists(
# Response model
# ---------------------------------------------------------------------------
class LogEntry(BaseModel):
time: str
topics: str
@@ -82,6 +81,7 @@ class LogsResponse(BaseModel):
# Endpoint
# ---------------------------------------------------------------------------
@router.get(
"/tenants/{tenant_id}/devices/{device_id}/logs",
response_model=LogsResponse,

View File

@@ -72,9 +72,7 @@ async def update_tag(
) -> DeviceTagResponse:
"""Update a device tag. Requires operator role or above."""
await _check_tenant_access(current_user, tenant_id, db)
return await device_service.update_tag(
db=db, tenant_id=tenant_id, tag_id=tag_id, data=data
)
return await device_service.update_tag(db=db, tenant_id=tenant_id, tag_id=tag_id, data=data)
@router.delete(

View File

@@ -23,7 +23,6 @@ from app.database import get_db
from app.middleware.rate_limit import limiter
from app.services.audit_service import log_action
from app.middleware.rbac import (
require_min_role,
require_operator_or_above,
require_scope,
require_tenant_admin_or_above,
@@ -57,6 +56,7 @@ async def _check_tenant_access(
if current_user.is_super_admin:
# Re-set tenant context to the target tenant so RLS allows the operation
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
@@ -138,8 +138,12 @@ async def create_device(
)
try:
await log_action(
db, tenant_id, current_user.user_id, "device_create",
resource_type="device", resource_id=str(result.id),
db,
tenant_id,
current_user.user_id,
"device_create",
resource_type="device",
resource_id=str(result.id),
details={"hostname": data.hostname, "ip_address": data.ip_address},
ip_address=request.client.host if request.client else None,
)
@@ -191,8 +195,12 @@ async def update_device(
)
try:
await log_action(
db, tenant_id, current_user.user_id, "device_update",
resource_type="device", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"device_update",
resource_type="device",
resource_id=str(device_id),
device_id=device_id,
details={"changes": data.model_dump(exclude_unset=True)},
ip_address=request.client.host if request.client else None,
@@ -220,8 +228,12 @@ async def delete_device(
await _check_tenant_access(current_user, tenant_id, db)
try:
await log_action(
db, tenant_id, current_user.user_id, "device_delete",
resource_type="device", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"device_delete",
resource_type="device",
resource_id=str(device_id),
device_id=device_id,
ip_address=request.client.host if request.client else None,
)
@@ -262,14 +274,21 @@ async def scan_devices(
discovered = await scan_subnet(data.cidr)
import ipaddress
network = ipaddress.ip_network(data.cidr, strict=False)
total_scanned = network.num_addresses - 2 if network.num_addresses > 2 else network.num_addresses
total_scanned = (
network.num_addresses - 2 if network.num_addresses > 2 else network.num_addresses
)
# Audit log the scan (fire-and-forget — never breaks the response)
try:
await log_action(
db, tenant_id, current_user.user_id, "subnet_scan",
resource_type="network", resource_id=data.cidr,
db,
tenant_id,
current_user.user_id,
"subnet_scan",
resource_type="network",
resource_id=data.cidr,
details={
"cidr": data.cidr,
"devices_found": len(discovered),
@@ -322,10 +341,12 @@ async def bulk_add_devices(
password = dev_data.password or data.shared_password
if not username or not password:
failed.append({
"ip_address": dev_data.ip_address,
"error": "No credentials provided (set per-device or shared credentials)",
})
failed.append(
{
"ip_address": dev_data.ip_address,
"error": "No credentials provided (set per-device or shared credentials)",
}
)
continue
create_data = DeviceCreate(
@@ -347,9 +368,16 @@ async def bulk_add_devices(
added.append(device)
try:
await log_action(
db, tenant_id, current_user.user_id, "device_adopt",
resource_type="device", resource_id=str(device.id),
details={"hostname": create_data.hostname, "ip_address": create_data.ip_address},
db,
tenant_id,
current_user.user_id,
"device_adopt",
resource_type="device",
resource_id=str(device.id),
details={
"hostname": create_data.hostname,
"ip_address": create_data.ip_address,
},
ip_address=request.client.host if request.client else None,
)
except Exception:

View File

@@ -90,16 +90,18 @@ async def list_events(
for row in alert_result.fetchall():
alert_status = row[1] or "firing"
metric = row[3] or "unknown"
events.append({
"id": str(row[0]),
"event_type": "alert",
"severity": row[2],
"title": f"{alert_status}: {metric}",
"description": row[4] or f"Alert {alert_status} for {metric}",
"device_hostname": row[7],
"device_id": str(row[6]) if row[6] else None,
"timestamp": row[5].isoformat() if row[5] else None,
})
events.append(
{
"id": str(row[0]),
"event_type": "alert",
"severity": row[2],
"title": f"{alert_status}: {metric}",
"description": row[4] or f"Alert {alert_status} for {metric}",
"device_hostname": row[7],
"device_id": str(row[6]) if row[6] else None,
"timestamp": row[5].isoformat() if row[5] else None,
}
)
# 2. Device status changes (inferred from current status + last_seen)
if not event_type or event_type == "status_change":
@@ -117,16 +119,18 @@ async def list_events(
device_status = row[2] or "unknown"
hostname = row[1] or "Unknown device"
severity = "info" if device_status == "online" else "warning"
events.append({
"id": f"status-{row[0]}",
"event_type": "status_change",
"severity": severity,
"title": f"Device {device_status}",
"description": f"{hostname} is now {device_status}",
"device_hostname": hostname,
"device_id": str(row[0]),
"timestamp": row[3].isoformat() if row[3] else None,
})
events.append(
{
"id": f"status-{row[0]}",
"event_type": "status_change",
"severity": severity,
"title": f"Device {device_status}",
"description": f"{hostname} is now {device_status}",
"device_hostname": hostname,
"device_id": str(row[0]),
"timestamp": row[3].isoformat() if row[3] else None,
}
)
# 3. Config backup runs
if not event_type or event_type == "config_backup":
@@ -144,16 +148,18 @@ async def list_events(
for row in backup_result.fetchall():
trigger_type = row[1] or "manual"
hostname = row[4] or "Unknown device"
events.append({
"id": str(row[0]),
"event_type": "config_backup",
"severity": "info",
"title": "Config backup",
"description": f"{trigger_type} backup completed for {hostname}",
"device_hostname": hostname,
"device_id": str(row[3]) if row[3] else None,
"timestamp": row[2].isoformat() if row[2] else None,
})
events.append(
{
"id": str(row[0]),
"event_type": "config_backup",
"severity": "info",
"title": "Config backup",
"description": f"{trigger_type} backup completed for {hostname}",
"device_hostname": hostname,
"device_id": str(row[3]) if row[3] else None,
"timestamp": row[2].isoformat() if row[2] else None,
}
)
# Sort all events by timestamp descending, then apply final limit
events.sort(

View File

@@ -67,6 +67,7 @@ async def get_firmware_overview(
await _check_tenant_access(current_user, tenant_id, db)
from app.services.firmware_service import get_firmware_overview as _get_overview
return await _get_overview(str(tenant_id))
@@ -206,6 +207,7 @@ async def trigger_firmware_check(
raise HTTPException(status_code=403, detail="Super admin only")
from app.services.firmware_service import check_latest_versions
results = await check_latest_versions()
return {"status": "ok", "versions_discovered": len(results), "versions": results}
@@ -221,6 +223,7 @@ async def list_firmware_cache(
raise HTTPException(status_code=403, detail="Super admin only")
from app.services.firmware_service import get_cached_firmware
return await get_cached_firmware()
@@ -236,6 +239,7 @@ async def download_firmware(
raise HTTPException(status_code=403, detail="Super admin only")
from app.services.firmware_service import download_firmware as _download
path = await _download(body.architecture, body.channel, body.version)
return {"status": "ok", "path": path}
@@ -324,15 +328,21 @@ async def start_firmware_upgrade(
# Schedule or start immediately
if body.scheduled_at:
from app.services.upgrade_service import schedule_upgrade
schedule_upgrade(job_id, datetime.fromisoformat(body.scheduled_at))
else:
from app.services.upgrade_service import start_upgrade
asyncio.create_task(start_upgrade(job_id))
try:
await log_action(
db, tenant_id, current_user.user_id, "firmware_upgrade",
resource_type="firmware", resource_id=job_id,
db,
tenant_id,
current_user.user_id,
"firmware_upgrade",
resource_type="firmware",
resource_id=job_id,
device_id=uuid.UUID(body.device_id),
details={"target_version": body.target_version, "channel": body.channel},
)
@@ -406,9 +416,11 @@ async def start_mass_firmware_upgrade(
# Schedule or start immediately
if body.scheduled_at:
from app.services.upgrade_service import schedule_mass_upgrade
schedule_mass_upgrade(rollout_group_id, datetime.fromisoformat(body.scheduled_at))
else:
from app.services.upgrade_service import start_mass_upgrade
asyncio.create_task(start_mass_upgrade(rollout_group_id))
return {
@@ -639,6 +651,7 @@ async def cancel_upgrade_endpoint(
raise HTTPException(403, "Viewers cannot cancel upgrades")
from app.services.upgrade_service import cancel_upgrade
await cancel_upgrade(str(job_id))
return {"status": "ok", "message": "Upgrade cancelled"}
@@ -662,6 +675,7 @@ async def retry_upgrade_endpoint(
raise HTTPException(403, "Viewers cannot retry upgrades")
from app.services.upgrade_service import retry_failed_upgrade
await retry_failed_upgrade(str(job_id))
return {"status": "ok", "message": "Upgrade retry started"}
@@ -685,6 +699,7 @@ async def resume_rollout_endpoint(
raise HTTPException(403, "Viewers cannot resume rollouts")
from app.services.upgrade_service import resume_mass_upgrade
await resume_mass_upgrade(str(rollout_group_id))
return {"status": "ok", "message": "Rollout resumed"}
@@ -708,5 +723,6 @@ async def abort_rollout_endpoint(
raise HTTPException(403, "Viewers cannot abort rollouts")
from app.services.upgrade_service import abort_mass_upgrade
aborted = await abort_mass_upgrade(str(rollout_group_id))
return {"status": "ok", "aborted_count": aborted}

View File

@@ -299,9 +299,7 @@ async def delete_maintenance_window(
_require_operator(current_user)
result = await db.execute(
text(
"DELETE FROM maintenance_windows WHERE id = CAST(:id AS uuid) RETURNING id"
),
text("DELETE FROM maintenance_windows WHERE id = CAST(:id AS uuid) RETURNING id"),
{"id": str(window_id)},
)
if not result.fetchone():

View File

@@ -65,6 +65,7 @@ async def _check_tenant_access(
if current_user.is_super_admin:
# Re-set tenant context to the target tenant so RLS allows the operation
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:

View File

@@ -81,9 +81,12 @@ async def _get_device(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UU
return device
async def _check_tenant_access(current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession) -> None:
async def _check_tenant_access(
current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
if current_user.is_super_admin:
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
@@ -124,8 +127,12 @@ async def open_winbox_session(
try:
await log_action(
db, tenant_id, current_user.user_id, "winbox_tunnel_open",
resource_type="device", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"winbox_tunnel_open",
resource_type="device",
resource_id=str(device_id),
device_id=device_id,
details={"source_ip": source_ip},
ip_address=source_ip,
@@ -133,24 +140,31 @@ async def open_winbox_session(
except Exception:
pass
payload = json.dumps({
"device_id": str(device_id),
"tenant_id": str(tenant_id),
"user_id": str(current_user.user_id),
"target_port": 8291,
}).encode()
payload = json.dumps(
{
"device_id": str(device_id),
"tenant_id": str(tenant_id),
"user_id": str(current_user.user_id),
"target_port": 8291,
}
).encode()
try:
nc = await _get_nats()
msg = await nc.request("tunnel.open", payload, timeout=10)
except Exception as exc:
logger.error("NATS tunnel.open failed: %s", exc)
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Tunnel service unavailable")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Tunnel service unavailable"
)
try:
data = json.loads(msg.data)
except Exception:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Invalid response from tunnel service")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Invalid response from tunnel service",
)
if "error" in data:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=data["error"])
@@ -158,11 +172,16 @@ async def open_winbox_session(
port = data.get("local_port")
tunnel_id = data.get("tunnel_id", "")
if not isinstance(port, int) or not (49000 <= port <= 49100):
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Invalid port allocation from tunnel service")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Invalid port allocation from tunnel service",
)
# Derive the tunnel host from the request so remote clients get the server's
# address rather than 127.0.0.1 (which would point to the user's own machine).
tunnel_host = (request.headers.get("x-forwarded-host") or request.headers.get("host") or "127.0.0.1")
tunnel_host = (
request.headers.get("x-forwarded-host") or request.headers.get("host") or "127.0.0.1"
)
# Strip port from host header (e.g. "10.101.0.175:8001" → "10.101.0.175")
tunnel_host = tunnel_host.split(":")[0]
@@ -213,8 +232,12 @@ async def open_ssh_session(
try:
await log_action(
db, tenant_id, current_user.user_id, "ssh_session_open",
resource_type="device", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"ssh_session_open",
resource_type="device",
resource_id=str(device_id),
device_id=device_id,
details={"source_ip": source_ip, "cols": body.cols, "rows": body.rows},
ip_address=source_ip,
@@ -223,22 +246,26 @@ async def open_ssh_session(
pass
token = secrets.token_urlsafe(32)
token_payload = json.dumps({
"device_id": str(device_id),
"tenant_id": str(tenant_id),
"user_id": str(current_user.user_id),
"source_ip": source_ip,
"cols": body.cols,
"rows": body.rows,
"created_at": int(time.time()),
})
token_payload = json.dumps(
{
"device_id": str(device_id),
"tenant_id": str(tenant_id),
"user_id": str(current_user.user_id),
"source_ip": source_ip,
"cols": body.cols,
"rows": body.rows,
"created_at": int(time.time()),
}
)
try:
rd = await _get_redis()
await rd.setex(f"ssh:token:{token}", 120, token_payload)
except Exception as exc:
logger.error("Redis setex failed for SSH token: %s", exc)
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Session store unavailable")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Session store unavailable"
)
return SSHSessionResponse(
token=token,
@@ -274,8 +301,12 @@ async def close_winbox_session(
try:
await log_action(
db, tenant_id, current_user.user_id, "winbox_tunnel_close",
resource_type="device", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"winbox_tunnel_close",
resource_type="device",
resource_id=str(device_id),
device_id=device_id,
details={"tunnel_id": tunnel_id, "source_ip": source_ip},
ip_address=source_ip,

View File

@@ -72,7 +72,11 @@ async def _set_system_settings(updates: dict, user_id: str) -> None:
ON CONFLICT (key) DO UPDATE
SET value = :value, updated_by = CAST(:user_id AS uuid), updated_at = now()
"""),
{"key": key, "value": str(value) if value is not None else None, "user_id": user_id},
{
"key": key,
"value": str(value) if value is not None else None,
"user_id": user_id,
},
)
await session.commit()
@@ -100,7 +104,8 @@ async def get_smtp_settings(user=Depends(require_role("super_admin"))):
"smtp_host": db_settings.get("smtp_host") or settings.SMTP_HOST,
"smtp_port": int(db_settings.get("smtp_port") or settings.SMTP_PORT),
"smtp_user": db_settings.get("smtp_user") or settings.SMTP_USER or "",
"smtp_use_tls": (db_settings.get("smtp_use_tls") or str(settings.SMTP_USE_TLS)).lower() == "true",
"smtp_use_tls": (db_settings.get("smtp_use_tls") or str(settings.SMTP_USE_TLS)).lower()
== "true",
"smtp_from_address": db_settings.get("smtp_from_address") or settings.SMTP_FROM_ADDRESS,
"smtp_provider": db_settings.get("smtp_provider") or "custom",
"smtp_password_set": bool(db_settings.get("smtp_password") or settings.SMTP_PASSWORD),

View File

@@ -32,6 +32,7 @@ async def _get_sse_redis() -> aioredis.Redis:
global _redis
if _redis is None:
from app.config import settings
_redis = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
return _redis
@@ -70,7 +71,9 @@ async def _validate_sse_token(token: str) -> dict:
async def event_stream(
request: Request,
tenant_id: uuid.UUID,
token: str = Query(..., description="Short-lived SSE exchange token (from POST /auth/sse-token)"),
token: str = Query(
..., description="Short-lived SSE exchange token (from POST /auth/sse-token)"
),
) -> EventSourceResponse:
"""Stream real-time events for a tenant via Server-Sent Events.
@@ -87,7 +90,9 @@ async def event_stream(
user_id = user_context.get("user_id", "")
# Authorization: user must belong to the requested tenant or be super_admin
if user_role != "super_admin" and (user_tenant_id is None or str(user_tenant_id) != str(tenant_id)):
if user_role != "super_admin" and (
user_tenant_id is None or str(user_tenant_id) != str(tenant_id)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized for this tenant",

View File

@@ -21,7 +21,6 @@ RBAC: viewer = read (GET/preview); operator and above = write (POST/PUT/DELETE/p
import asyncio
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
@@ -54,6 +53,7 @@ async def _check_tenant_access(
"""Verify the current user is allowed to access the given tenant."""
if current_user.is_super_admin:
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
@@ -191,7 +191,8 @@ async def create_template(
if unmatched:
logger.warning(
"Template '%s' has undeclared variables: %s (auto-adding as string type)",
body.name, unmatched,
body.name,
unmatched,
)
# Create template
@@ -553,11 +554,13 @@ async def push_template(
status="pending",
)
db.add(job)
jobs_created.append({
"job_id": str(job.id),
"device_id": str(device.id),
"device_hostname": device.hostname,
})
jobs_created.append(
{
"job_id": str(job.id),
"device_id": str(device.id),
"device_hostname": device.hostname,
}
)
await db.flush()
@@ -598,14 +601,16 @@ async def push_status(
jobs = []
for job, hostname in rows:
jobs.append({
"device_id": str(job.device_id),
"hostname": hostname,
"status": job.status,
"error_message": job.error_message,
"started_at": job.started_at.isoformat() if job.started_at else None,
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
})
jobs.append(
{
"device_id": str(job.device_id),
"hostname": hostname,
"status": job.status,
"error_message": job.error_message,
"started_at": job.started_at.isoformat() if job.started_at else None,
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
}
)
return {
"rollout_id": str(rollout_id),

View File

@@ -9,7 +9,6 @@ DELETE /api/tenants/{id} — delete tenant (super_admin only)
"""
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy import func, select, text
@@ -17,7 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.middleware.rate_limit import limiter
from app.database import get_admin_db, get_db
from app.database import get_admin_db
from app.middleware.rbac import require_super_admin, require_tenant_admin_or_above
from app.middleware.tenant_context import CurrentUser
from app.models.device import Device
@@ -70,15 +69,18 @@ async def list_tenants(
else:
if not current_user.tenant_id:
return []
result = await db.execute(
select(Tenant).where(Tenant.id == current_user.tenant_id)
)
result = await db.execute(select(Tenant).where(Tenant.id == current_user.tenant_id))
tenants = result.scalars().all()
return [await _get_tenant_response(tenant, db) for tenant in tenants]
@router.post("", response_model=TenantResponse, status_code=status.HTTP_201_CREATED, summary="Create a tenant")
@router.post(
"",
response_model=TenantResponse,
status_code=status.HTTP_201_CREATED,
summary="Create a tenant",
)
@limiter.limit("20/minute")
async def create_tenant(
request: Request,
@@ -108,13 +110,21 @@ async def create_tenant(
("Device Offline", "device_offline", "eq", 1, 1, "critical"),
]
for name, metric, operator, threshold, duration, sev in default_rules:
await db.execute(text("""
await db.execute(
text("""
INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default)
VALUES (gen_random_uuid(), CAST(:tenant_id AS uuid), :name, :metric, :operator, :threshold, :duration, :severity, TRUE, TRUE)
"""), {
"tenant_id": str(tenant.id), "name": name, "metric": metric,
"operator": operator, "threshold": threshold, "duration": duration, "severity": sev,
})
"""),
{
"tenant_id": str(tenant.id),
"name": name,
"metric": metric,
"operator": operator,
"threshold": threshold,
"duration": duration,
"severity": sev,
},
)
await db.commit()
# Seed starter config templates for new tenant
@@ -131,6 +141,7 @@ async def create_tenant(
await db.commit()
except Exception as exc:
import logging
logging.getLogger(__name__).warning(
"OpenBao key provisioning failed for tenant %s (will be provisioned on next startup): %s",
tenant.id,
@@ -229,6 +240,7 @@ async def delete_tenant(
# Check if tenant had VPN configured (before cascade deletes it)
from app.services.vpn_service import get_vpn_config, sync_wireguard_config
had_vpn = await get_vpn_config(db, tenant_id)
await db.delete(tenant)
@@ -288,14 +300,54 @@ add chain=forward action=drop comment="Drop everything else"
# Identity
/system identity set name={{ device.hostname }}""",
"variables": [
{"name": "wan_interface", "type": "string", "default": "ether1", "description": "WAN-facing interface"},
{"name": "lan_gateway", "type": "ip", "default": "192.168.88.1", "description": "LAN gateway IP"},
{"name": "lan_cidr", "type": "integer", "default": "24", "description": "LAN subnet mask bits"},
{"name": "lan_network", "type": "ip", "default": "192.168.88.0", "description": "LAN network address"},
{"name": "dhcp_start", "type": "ip", "default": "192.168.88.100", "description": "DHCP pool start"},
{"name": "dhcp_end", "type": "ip", "default": "192.168.88.254", "description": "DHCP pool end"},
{"name": "dns_servers", "type": "string", "default": "8.8.8.8,8.8.4.4", "description": "Upstream DNS servers"},
{"name": "ntp_server", "type": "string", "default": "pool.ntp.org", "description": "NTP server"},
{
"name": "wan_interface",
"type": "string",
"default": "ether1",
"description": "WAN-facing interface",
},
{
"name": "lan_gateway",
"type": "ip",
"default": "192.168.88.1",
"description": "LAN gateway IP",
},
{
"name": "lan_cidr",
"type": "integer",
"default": "24",
"description": "LAN subnet mask bits",
},
{
"name": "lan_network",
"type": "ip",
"default": "192.168.88.0",
"description": "LAN network address",
},
{
"name": "dhcp_start",
"type": "ip",
"default": "192.168.88.100",
"description": "DHCP pool start",
},
{
"name": "dhcp_end",
"type": "ip",
"default": "192.168.88.254",
"description": "DHCP pool end",
},
{
"name": "dns_servers",
"type": "string",
"default": "8.8.8.8,8.8.4.4",
"description": "Upstream DNS servers",
},
{
"name": "ntp_server",
"type": "string",
"default": "pool.ntp.org",
"description": "NTP server",
},
],
},
{
@@ -311,8 +363,18 @@ add chain=forward connection-state=invalid action=drop
add chain=forward src-address={{ allowed_network }} action=accept
add chain=forward action=drop""",
"variables": [
{"name": "wan_interface", "type": "string", "default": "ether1", "description": "WAN-facing interface"},
{"name": "allowed_network", "type": "subnet", "default": "192.168.88.0/24", "description": "Allowed source network"},
{
"name": "wan_interface",
"type": "string",
"default": "ether1",
"description": "WAN-facing interface",
},
{
"name": "allowed_network",
"type": "subnet",
"default": "192.168.88.0/24",
"description": "Allowed source network",
},
],
},
{
@@ -322,11 +384,36 @@ add chain=forward action=drop""",
/ip dhcp-server network add address={{ gateway }}/24 gateway={{ gateway }} dns-server={{ dns_server }}
/ip dhcp-server add name=dhcp1 interface={{ interface }} address-pool=dhcp-pool disabled=no""",
"variables": [
{"name": "pool_start", "type": "ip", "default": "192.168.88.100", "description": "DHCP pool start address"},
{"name": "pool_end", "type": "ip", "default": "192.168.88.254", "description": "DHCP pool end address"},
{"name": "gateway", "type": "ip", "default": "192.168.88.1", "description": "Default gateway"},
{"name": "dns_server", "type": "ip", "default": "8.8.8.8", "description": "DNS server address"},
{"name": "interface", "type": "string", "default": "bridge-lan", "description": "Interface to serve DHCP on"},
{
"name": "pool_start",
"type": "ip",
"default": "192.168.88.100",
"description": "DHCP pool start address",
},
{
"name": "pool_end",
"type": "ip",
"default": "192.168.88.254",
"description": "DHCP pool end address",
},
{
"name": "gateway",
"type": "ip",
"default": "192.168.88.1",
"description": "Default gateway",
},
{
"name": "dns_server",
"type": "ip",
"default": "8.8.8.8",
"description": "DNS server address",
},
{
"name": "interface",
"type": "string",
"default": "bridge-lan",
"description": "Interface to serve DHCP on",
},
],
},
{
@@ -335,10 +422,30 @@ add chain=forward action=drop""",
"content": """/interface wireless security-profiles add name=portal-wpa2 mode=dynamic-keys authentication-types=wpa2-psk wpa2-pre-shared-key={{ password }}
/interface wireless set wlan1 mode=ap-bridge ssid={{ ssid }} security-profile=portal-wpa2 frequency={{ frequency }} channel-width={{ channel_width }} disabled=no""",
"variables": [
{"name": "ssid", "type": "string", "default": "MikroTik-AP", "description": "Wireless network name"},
{"name": "password", "type": "string", "default": "", "description": "WPA2 pre-shared key (min 8 characters)"},
{"name": "frequency", "type": "integer", "default": "2412", "description": "Wireless frequency in MHz"},
{"name": "channel_width", "type": "string", "default": "20/40mhz-XX", "description": "Channel width setting"},
{
"name": "ssid",
"type": "string",
"default": "MikroTik-AP",
"description": "Wireless network name",
},
{
"name": "password",
"type": "string",
"default": "",
"description": "WPA2 pre-shared key (min 8 characters)",
},
{
"name": "frequency",
"type": "integer",
"default": "2412",
"description": "Wireless frequency in MHz",
},
{
"name": "channel_width",
"type": "string",
"default": "20/40mhz-XX",
"description": "Channel width setting",
},
],
},
{
@@ -351,8 +458,18 @@ add chain=forward action=drop""",
/ip service set ssh port=22
/ip service set winbox port=8291""",
"variables": [
{"name": "ntp_server", "type": "ip", "default": "pool.ntp.org", "description": "NTP server address"},
{"name": "dns_servers", "type": "string", "default": "8.8.8.8,8.8.4.4", "description": "Comma-separated DNS servers"},
{
"name": "ntp_server",
"type": "ip",
"default": "pool.ntp.org",
"description": "NTP server address",
},
{
"name": "dns_servers",
"type": "string",
"default": "8.8.8.8,8.8.4.4",
"description": "Comma-separated DNS servers",
},
],
},
]
@@ -363,13 +480,16 @@ async def _seed_starter_templates(db, tenant_id) -> None:
import json as _json
for tmpl in _STARTER_TEMPLATES:
await db.execute(text("""
await db.execute(
text("""
INSERT INTO config_templates (id, tenant_id, name, description, content, variables)
VALUES (gen_random_uuid(), CAST(:tid AS uuid), :name, :desc, :content, CAST(:vars AS jsonb))
"""), {
"tid": str(tenant_id),
"name": tmpl["name"],
"desc": tmpl["description"],
"content": tmpl["content"],
"vars": _json.dumps(tmpl["variables"]),
})
"""),
{
"tid": str(tenant_id),
"name": tmpl["name"],
"desc": tmpl["description"],
"content": tmpl["content"],
"vars": _json.dumps(tmpl["variables"]),
},
)

View File

@@ -14,7 +14,6 @@ Builds a topology graph of managed devices by:
import asyncio
import ipaddress
import json
import logging
import uuid
from typing import Any
@@ -265,7 +264,7 @@ async def get_topology(
nodes: list[TopologyNode] = []
ip_to_device: dict[str, str] = {}
online_device_ids: list[str] = []
devices_by_id: dict[str, Any] = {}
_devices_by_id: dict[str, Any] = {}
for row in rows:
device_id = str(row.id)
@@ -288,9 +287,7 @@ async def get_topology(
if online_device_ids:
tasks = [
routeros_proxy.execute_command(
device_id, "/ip/neighbor/print", timeout=10.0
)
routeros_proxy.execute_command(device_id, "/ip/neighbor/print", timeout=10.0)
for device_id in online_device_ids
]
results = await asyncio.gather(*tasks, return_exceptions=True)

View File

@@ -164,9 +164,7 @@ async def list_transparency_logs(
# Count total
count_result = await db.execute(
select(func.count())
.select_from(text("key_access_log k"))
.where(where_clause),
select(func.count()).select_from(text("key_access_log k")).where(where_clause),
params,
)
total = count_result.scalar() or 0
@@ -353,39 +351,41 @@ async def export_transparency_logs(
output = io.StringIO()
writer = csv.writer(output)
writer.writerow([
"ID",
"Action",
"Device Name",
"Device ID",
"Justification",
"Operator Email",
"Correlation ID",
"Resource Type",
"Resource ID",
"IP Address",
"Timestamp",
])
writer.writerow(
[
"ID",
"Action",
"Device Name",
"Device ID",
"Justification",
"Operator Email",
"Correlation ID",
"Resource Type",
"Resource ID",
"IP Address",
"Timestamp",
]
)
for row in all_rows:
writer.writerow([
str(row["id"]),
row["action"],
row["device_name"] or "",
str(row["device_id"]) if row["device_id"] else "",
row["justification"] or "",
row["operator_email"] or "",
row["correlation_id"] or "",
row["resource_type"] or "",
row["resource_id"] or "",
row["ip_address"] or "",
str(row["created_at"]),
])
writer.writerow(
[
str(row["id"]),
row["action"],
row["device_name"] or "",
str(row["device_id"]) if row["device_id"] else "",
row["justification"] or "",
row["operator_email"] or "",
row["correlation_id"] or "",
row["resource_type"] or "",
row["resource_id"] or "",
row["ip_address"] or "",
str(row["created_at"]),
]
)
output.seek(0)
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={
"Content-Disposition": "attachment; filename=transparency-logs.csv"
},
headers={"Content-Disposition": "attachment; filename=transparency-logs.csv"},
)

View File

@@ -20,7 +20,7 @@ from app.database import get_admin_db
from app.middleware.rbac import require_tenant_admin_or_above
from app.middleware.tenant_context import CurrentUser
from app.models.tenant import Tenant
from app.models.user import User, UserRole
from app.models.user import User
from app.schemas.user import UserCreate, UserResponse, UserUpdate
from app.services.auth import hash_password
@@ -69,11 +69,7 @@ async def list_users(
"""
await _check_tenant_access(tenant_id, current_user, db)
result = await db.execute(
select(User)
.where(User.tenant_id == tenant_id)
.order_by(User.name)
)
result = await db.execute(select(User).where(User.tenant_id == tenant_id).order_by(User.name))
users = result.scalars().all()
return [UserResponse.model_validate(user) for user in users]
@@ -103,9 +99,7 @@ async def create_user(
await _check_tenant_access(tenant_id, current_user, db)
# Check email uniqueness (global, not per-tenant)
existing = await db.execute(
select(User).where(User.email == data.email.lower())
)
existing = await db.execute(select(User).where(User.email == data.email.lower()))
if existing.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
@@ -138,9 +132,7 @@ async def get_user(
"""Get user detail."""
await _check_tenant_access(tenant_id, current_user, db)
result = await db.execute(
select(User).where(User.id == user_id, User.tenant_id == tenant_id)
)
result = await db.execute(select(User).where(User.id == user_id, User.tenant_id == tenant_id))
user = result.scalar_one_or_none()
if not user:
@@ -168,9 +160,7 @@ async def update_user(
"""
await _check_tenant_access(tenant_id, current_user, db)
result = await db.execute(
select(User).where(User.id == user_id, User.tenant_id == tenant_id)
)
result = await db.execute(select(User).where(User.id == user_id, User.tenant_id == tenant_id))
user = result.scalar_one_or_none()
if not user:
@@ -194,7 +184,11 @@ async def update_user(
return UserResponse.model_validate(user)
@router.delete("/{tenant_id}/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Deactivate a user")
@router.delete(
"/{tenant_id}/users/{user_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Deactivate a user",
)
@limiter.limit("5/minute")
async def deactivate_user(
request: Request,
@@ -209,9 +203,7 @@ async def deactivate_user(
"""
await _check_tenant_access(tenant_id, current_user, db)
result = await db.execute(
select(User).where(User.id == user_id, User.tenant_id == tenant_id)
)
result = await db.execute(select(User).where(User.id == user_id, User.tenant_id == tenant_id))
user = result.scalar_one_or_none()
if not user:

View File

@@ -68,7 +68,11 @@ async def get_vpn_config(
return resp
@router.post("/tenants/{tenant_id}/vpn", response_model=VpnConfigResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/tenants/{tenant_id}/vpn",
response_model=VpnConfigResponse,
status_code=status.HTTP_201_CREATED,
)
@limiter.limit("20/minute")
async def setup_vpn(
request: Request,
@@ -177,7 +181,11 @@ async def list_peers(
return responses
@router.post("/tenants/{tenant_id}/vpn/peers", response_model=VpnPeerResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/tenants/{tenant_id}/vpn/peers",
response_model=VpnPeerResponse,
status_code=status.HTTP_201_CREATED,
)
@limiter.limit("20/minute")
async def add_peer(
request: Request,
@@ -190,7 +198,9 @@ async def add_peer(
await _check_tenant_access(current_user, tenant_id, db)
_require_operator(current_user)
try:
peer = await vpn_service.add_peer(db, tenant_id, body.device_id, additional_allowed_ips=body.additional_allowed_ips)
peer = await vpn_service.add_peer(
db, tenant_id, body.device_id, additional_allowed_ips=body.additional_allowed_ips
)
except ValueError as e:
msg = str(e)
if "must not overlap" in msg:
@@ -208,7 +218,11 @@ async def add_peer(
return resp
@router.post("/tenants/{tenant_id}/vpn/peers/onboard", response_model=VpnOnboardResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/tenants/{tenant_id}/vpn/peers/onboard",
response_model=VpnOnboardResponse,
status_code=status.HTTP_201_CREATED,
)
@limiter.limit("10/minute")
async def onboard_device(
request: Request,
@@ -222,7 +236,8 @@ async def onboard_device(
_require_operator(current_user)
try:
result = await vpn_service.onboard_device(
db, tenant_id,
db,
tenant_id,
hostname=body.hostname,
username=body.username,
password=body.password,

View File

@@ -146,16 +146,16 @@ async def _delete_session_from_redis(session_id: str) -> None:
await rd.delete(f"{REDIS_PREFIX}{session_id}")
async def _open_tunnel(
device_id: uuid.UUID, tenant_id: uuid.UUID, user_id: uuid.UUID
) -> dict:
async def _open_tunnel(device_id: uuid.UUID, tenant_id: uuid.UUID, user_id: uuid.UUID) -> dict:
"""Open a TCP tunnel to device port 8291 via NATS request-reply."""
payload = json.dumps({
"device_id": str(device_id),
"tenant_id": str(tenant_id),
"user_id": str(user_id),
"target_port": 8291,
}).encode()
payload = json.dumps(
{
"device_id": str(device_id),
"tenant_id": str(tenant_id),
"user_id": str(user_id),
"target_port": 8291,
}
).encode()
try:
nc = await _get_nats()
@@ -176,9 +176,7 @@ async def _open_tunnel(
)
if "error" in data:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=data["error"]
)
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=data["error"])
return data
@@ -250,9 +248,7 @@ async def create_winbox_remote_session(
except Exception:
worker_info = None
if worker_info is None:
logger.warning(
"Cleaning stale Redis session %s (worker 404)", stale_sid
)
logger.warning("Cleaning stale Redis session %s (worker 404)", stale_sid)
tunnel_id = sess.get("tunnel_id")
if tunnel_id:
await _close_tunnel(tunnel_id)
@@ -333,12 +329,8 @@ async def create_winbox_remote_session(
username = "" # noqa: F841
password = "" # noqa: F841
expires_at = datetime.fromisoformat(
worker_resp.get("expires_at", now.isoformat())
)
max_expires_at = datetime.fromisoformat(
worker_resp.get("max_expires_at", now.isoformat())
)
expires_at = datetime.fromisoformat(worker_resp.get("expires_at", now.isoformat()))
max_expires_at = datetime.fromisoformat(worker_resp.get("max_expires_at", now.isoformat()))
# Save session to Redis
session_data = {
@@ -375,8 +367,7 @@ async def create_winbox_remote_session(
pass
ws_path = (
f"/api/tenants/{tenant_id}/devices/{device_id}"
f"/winbox-remote-sessions/{session_id}/ws"
f"/api/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/ws"
)
return RemoteWinboxSessionResponse(
@@ -425,14 +416,10 @@ async def get_winbox_remote_session(
sess = await _get_session_from_redis(str(session_id))
if sess is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Session not found"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
if sess.get("tenant_id") != str(tenant_id) or sess.get("device_id") != str(device_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Session not found"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
return RemoteWinboxStatusResponse(
session_id=uuid.UUID(sess["session_id"]),
@@ -478,10 +465,7 @@ async def list_winbox_remote_sessions(
sess = json.loads(raw)
except Exception:
continue
if (
sess.get("tenant_id") == str(tenant_id)
and sess.get("device_id") == str(device_id)
):
if sess.get("tenant_id") == str(tenant_id) and sess.get("device_id") == str(device_id):
sessions.append(
RemoteWinboxStatusResponse(
session_id=uuid.UUID(sess["session_id"]),
@@ -533,9 +517,7 @@ async def terminate_winbox_remote_session(
)
if sess.get("tenant_id") != str(tenant_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Session not found"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
# Rollback order: worker -> tunnel -> redis -> audit
await worker_terminate_session(str(session_id))
@@ -574,14 +556,12 @@ async def terminate_winbox_remote_session(
@router.get(
"/tenants/{tenant_id}/devices/{device_id}"
"/winbox-remote-sessions/{session_id}/xpra/{path:path}",
"/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/xpra/{path:path}",
summary="Proxy Xpra HTML5 client files",
dependencies=[Depends(require_operator_or_above)],
)
@router.get(
"/tenants/{tenant_id}/devices/{device_id}"
"/winbox-remote-sessions/{session_id}/xpra",
"/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/xpra",
summary="Proxy Xpra HTML5 client (root)",
dependencies=[Depends(require_operator_or_above)],
)
@@ -626,7 +606,8 @@ async def proxy_xpra_html(
content=proxy_resp.content,
status_code=proxy_resp.status_code,
headers={
k: v for k, v in proxy_resp.headers.items()
k: v
for k, v in proxy_resp.headers.items()
if k.lower() in ("content-type", "cache-control", "content-encoding")
},
)
@@ -637,9 +618,7 @@ async def proxy_xpra_html(
# ---------------------------------------------------------------------------
@router.websocket(
"/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/ws"
)
@router.websocket("/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/ws")
async def winbox_remote_ws_proxy(
websocket: WebSocket,
tenant_id: uuid.UUID,

View File

@@ -67,11 +67,13 @@ class MessageResponse(BaseModel):
class SRPInitRequest(BaseModel):
"""Step 1 request: client sends email to begin SRP handshake."""
email: EmailStr
class SRPInitResponse(BaseModel):
"""Step 1 response: server returns ephemeral B and key derivation salts."""
salt: str # hex-encoded SRP salt
server_public: str # hex-encoded server ephemeral B
session_id: str # Redis session key nonce
@@ -81,6 +83,7 @@ class SRPInitResponse(BaseModel):
class SRPVerifyRequest(BaseModel):
"""Step 2 request: client sends proof M1 to complete handshake."""
email: EmailStr
session_id: str
client_public: str # hex-encoded client ephemeral A
@@ -89,6 +92,7 @@ class SRPVerifyRequest(BaseModel):
class SRPVerifyResponse(BaseModel):
"""Step 2 response: server returns tokens and proof M2."""
access_token: str
refresh_token: str
token_type: str = "bearer"
@@ -98,6 +102,7 @@ class SRPVerifyResponse(BaseModel):
class SRPRegisterRequest(BaseModel):
"""Used during registration to store SRP verifier and key set."""
srp_salt: str # hex-encoded
srp_verifier: str # hex-encoded
encrypted_private_key: str # base64-encoded
@@ -114,10 +119,12 @@ class SRPRegisterRequest(BaseModel):
class DeleteAccountRequest(BaseModel):
"""Request body for account self-deletion. User must type 'DELETE' to confirm."""
confirmation: str # Must be "DELETE" to confirm
class DeleteAccountResponse(BaseModel):
"""Response after successful account deletion."""
message: str
deleted: bool

View File

@@ -10,6 +10,7 @@ from pydantic import BaseModel, ConfigDict
# Request schemas
# ---------------------------------------------------------------------------
class CACreateRequest(BaseModel):
"""Request to generate a new root CA for the tenant."""
@@ -34,6 +35,7 @@ class BulkCertDeployRequest(BaseModel):
# Response schemas
# ---------------------------------------------------------------------------
class CAResponse(BaseModel):
"""Public details of a tenant's Certificate Authority (no private key)."""

View File

@@ -117,6 +117,7 @@ class SubnetScanRequest(BaseModel):
def validate_cidr(cls, v: str) -> str:
"""Validate that the value is a valid CIDR notation and RFC 1918 private range."""
import ipaddress
try:
network = ipaddress.ip_network(v, strict=False)
except ValueError as e:
@@ -239,6 +240,7 @@ class DeviceTagCreate(BaseModel):
if v is None:
return v
import re
if not re.match(r"^#[0-9A-Fa-f]{6}$", v):
raise ValueError("Color must be a valid 6-digit hex color (e.g. #FF5733)")
return v
@@ -256,6 +258,7 @@ class DeviceTagUpdate(BaseModel):
if v is None:
return v
import re
if not re.match(r"^#[0-9A-Fa-f]{6}$", v):
raise ValueError("Color must be a valid 6-digit hex color (e.g. #FF5733)")
return v

View File

@@ -12,11 +12,15 @@ from pydantic import BaseModel
class VpnSetupRequest(BaseModel):
"""Request to enable VPN for a tenant."""
endpoint: Optional[str] = None # public hostname:port — if blank, devices must be configured manually
endpoint: Optional[str] = (
None # public hostname:port — if blank, devices must be configured manually
)
class VpnConfigResponse(BaseModel):
"""VPN server configuration (never exposes private key)."""
model_config = {"from_attributes": True}
id: uuid.UUID
@@ -33,6 +37,7 @@ class VpnConfigResponse(BaseModel):
class VpnConfigUpdate(BaseModel):
"""Update VPN configuration."""
endpoint: Optional[str] = None
is_enabled: Optional[bool] = None
@@ -42,12 +47,14 @@ class VpnConfigUpdate(BaseModel):
class VpnPeerCreate(BaseModel):
"""Add a device as a VPN peer."""
# UUID of the device to add as a peer (required).
device_id: uuid.UUID
# Optional extra subnets; omitted/None by default. Passed as a single
# comma-separated string rather than a list.
additional_allowed_ips: Optional[str] = None # comma-separated subnets for site-to-site routing
class VpnPeerResponse(BaseModel):
"""VPN peer info (never exposes private key)."""
model_config = {"from_attributes": True}
id: uuid.UUID
@@ -66,6 +73,7 @@ class VpnPeerResponse(BaseModel):
class VpnOnboardRequest(BaseModel):
"""Combined device creation + VPN peer onboarding."""
hostname: str
username: str
password: str
@@ -73,6 +81,7 @@ class VpnOnboardRequest(BaseModel):
class VpnOnboardResponse(BaseModel):
"""Response from onboarding — device, peer, and RouterOS commands."""
device_id: uuid.UUID
peer_id: uuid.UUID
hostname: str
@@ -82,6 +91,7 @@ class VpnOnboardResponse(BaseModel):
class VpnPeerConfig(BaseModel):
"""Full peer config for display/export — includes private key for device setup."""
peer_private_key: str
peer_public_key: str
assigned_ip: str

View File

@@ -57,6 +57,7 @@ class RemoteWinboxDuplicateDetail(BaseModel):
class RemoteWinboxSessionItem(BaseModel):
"""Used in the combined active sessions list."""
session_id: uuid.UUID
status: RemoteWinboxState
created_at: datetime

View File

@@ -86,11 +86,7 @@ async def delete_user_account(
# Null out encrypted_details (may contain encrypted PII)
await db.execute(
text(
"UPDATE audit_logs "
"SET encrypted_details = NULL "
"WHERE user_id = :user_id"
),
text("UPDATE audit_logs SET encrypted_details = NULL WHERE user_id = :user_id"),
{"user_id": user_id},
)
@@ -177,16 +173,18 @@ async def export_user_data(
)
api_keys = []
for row in result.mappings().all():
api_keys.append({
"id": str(row["id"]),
"name": row["name"],
"key_prefix": row["key_prefix"],
"scopes": row["scopes"],
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
"expires_at": row["expires_at"].isoformat() if row["expires_at"] else None,
"revoked_at": row["revoked_at"].isoformat() if row["revoked_at"] else None,
"last_used_at": row["last_used_at"].isoformat() if row["last_used_at"] else None,
})
api_keys.append(
{
"id": str(row["id"]),
"name": row["name"],
"key_prefix": row["key_prefix"],
"scopes": row["scopes"],
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
"expires_at": row["expires_at"].isoformat() if row["expires_at"] else None,
"revoked_at": row["revoked_at"].isoformat() if row["revoked_at"] else None,
"last_used_at": row["last_used_at"].isoformat() if row["last_used_at"] else None,
}
)
# ── Audit logs (limit 1000, most recent first) ───────────────────────
result = await db.execute(
@@ -201,15 +199,17 @@ async def export_user_data(
audit_logs = []
for row in result.mappings().all():
details = row["details"] if row["details"] else {}
audit_logs.append({
"id": str(row["id"]),
"action": row["action"],
"resource_type": row["resource_type"],
"resource_id": row["resource_id"],
"details": details,
"ip_address": row["ip_address"],
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
})
audit_logs.append(
{
"id": str(row["id"]),
"action": row["action"],
"resource_type": row["resource_type"],
"resource_id": row["resource_id"],
"details": details,
"ip_address": row["ip_address"],
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
}
)
# ── Key access log (limit 1000, most recent first) ───────────────────
result = await db.execute(
@@ -222,13 +222,15 @@ async def export_user_data(
)
key_access_entries = []
for row in result.mappings().all():
key_access_entries.append({
"id": str(row["id"]),
"action": row["action"],
"resource_type": row["resource_type"],
"ip_address": row["ip_address"],
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
})
key_access_entries.append(
{
"id": str(row["id"]),
"action": row["action"],
"resource_type": row["resource_type"],
"ip_address": row["ip_address"],
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
}
)
return {
"export_date": datetime.now(UTC).isoformat(),

View File

@@ -253,7 +253,9 @@ async def _get_device_groups(device_id: str) -> list[str]:
"""Get group IDs for a device."""
async with AdminAsyncSessionLocal() as session:
result = await session.execute(
text("SELECT group_id FROM device_group_memberships WHERE device_id = CAST(:device_id AS uuid)"),
text(
"SELECT group_id FROM device_group_memberships WHERE device_id = CAST(:device_id AS uuid)"
),
{"device_id": device_id},
)
return [str(row[0]) for row in result.fetchall()]
@@ -344,30 +346,36 @@ async def _create_alert_event(
# Publish real-time event to NATS for SSE pipeline (fire-and-forget)
if status in ("firing", "flapping"):
await publish_event(f"alert.fired.{tenant_id}", {
"event_type": "alert_fired",
"tenant_id": tenant_id,
"device_id": device_id,
"alert_event_id": alert_data["id"],
"severity": severity,
"metric": metric,
"current_value": value,
"threshold": threshold,
"message": message,
"is_flapping": is_flapping,
"fired_at": datetime.now(timezone.utc).isoformat(),
})
await publish_event(
f"alert.fired.{tenant_id}",
{
"event_type": "alert_fired",
"tenant_id": tenant_id,
"device_id": device_id,
"alert_event_id": alert_data["id"],
"severity": severity,
"metric": metric,
"current_value": value,
"threshold": threshold,
"message": message,
"is_flapping": is_flapping,
"fired_at": datetime.now(timezone.utc).isoformat(),
},
)
elif status == "resolved":
await publish_event(f"alert.resolved.{tenant_id}", {
"event_type": "alert_resolved",
"tenant_id": tenant_id,
"device_id": device_id,
"alert_event_id": alert_data["id"],
"severity": severity,
"metric": metric,
"message": message,
"resolved_at": datetime.now(timezone.utc).isoformat(),
})
await publish_event(
f"alert.resolved.{tenant_id}",
{
"event_type": "alert_resolved",
"tenant_id": tenant_id,
"device_id": device_id,
"alert_event_id": alert_data["id"],
"severity": severity,
"metric": metric,
"message": message,
"resolved_at": datetime.now(timezone.utc).isoformat(),
},
)
return alert_data
@@ -470,6 +478,7 @@ async def _dispatch_async(alert_event: dict, channels: list[dict], device_hostna
"""Fire-and-forget notification dispatch."""
try:
from app.services.notification_service import dispatch_notifications
await dispatch_notifications(alert_event, channels, device_hostname)
except Exception as e:
logger.warning("Notification dispatch failed: %s", e)
@@ -500,7 +509,8 @@ async def evaluate(
if await _is_device_in_maintenance(tenant_id, device_id):
logger.debug(
"Alert suppressed by maintenance window for device %s tenant %s",
device_id, tenant_id,
device_id,
tenant_id,
)
return
@@ -573,7 +583,8 @@ async def evaluate(
if is_flapping:
logger.info(
"Alert %s for device %s is flapping — notifications suppressed",
rule["name"], device_id,
rule["name"],
device_id,
)
else:
channels = await _get_channels_for_rule(rule["id"])

View File

@@ -56,9 +56,7 @@ async def log_action(
try:
from app.services.crypto import encrypt_data_transit
encrypted_details = await encrypt_data_transit(
details_json, str(tenant_id)
)
encrypted_details = await encrypt_data_transit(details_json, str(tenant_id))
# Encryption succeeded — clear plaintext details
details_json = _json.dumps({})
except Exception:

View File

@@ -32,8 +32,12 @@ def _cron_to_trigger(cron_expr: str) -> Optional[CronTrigger]:
return None
minute, hour, day, month, day_of_week = parts
return CronTrigger(
minute=minute, hour=hour, day=day, month=month,
day_of_week=day_of_week, timezone="UTC",
minute=minute,
hour=hour,
day=day,
month=month,
day_of_week=day_of_week,
timezone="UTC",
)
except Exception as e:
logger.warning("Invalid cron expression '%s': %s", cron_expr, e)
@@ -52,10 +56,12 @@ def build_schedule_map(schedules: list) -> dict[str, list[dict]]:
cron = s.cron_expression or DEFAULT_CRON
if cron not in schedule_map:
schedule_map[cron] = []
schedule_map[cron].append({
"device_id": str(s.device_id),
"tenant_id": str(s.tenant_id),
})
schedule_map[cron].append(
{
"device_id": str(s.device_id),
"tenant_id": str(s.tenant_id),
}
)
return schedule_map
@@ -79,13 +85,15 @@ async def _run_scheduled_backups(devices: list[dict]) -> None:
except Exception as e:
logger.error(
"Scheduled backup FAILED: device %s: %s",
dev_info["device_id"], e,
dev_info["device_id"],
e,
)
failure_count += 1
logger.info(
"Backup batch complete — %d succeeded, %d failed",
success_count, failure_count,
success_count,
failure_count,
)
@@ -108,7 +116,7 @@ async def _load_effective_schedules() -> list:
# Index: device-specific and tenant defaults
device_schedules = {} # device_id -> schedule
tenant_defaults = {} # tenant_id -> schedule
tenant_defaults = {} # tenant_id -> schedule
for s in schedules:
if s.device_id:
@@ -129,12 +137,14 @@ async def _load_effective_schedules() -> list:
# No schedule configured — use system default
sched = None
effective.append(SimpleNamespace(
device_id=dev_id,
tenant_id=tenant_id,
cron_expression=sched.cron_expression if sched else DEFAULT_CRON,
enabled=sched.enabled if sched else True,
))
effective.append(
SimpleNamespace(
device_id=dev_id,
tenant_id=tenant_id,
cron_expression=sched.cron_expression if sched else DEFAULT_CRON,
enabled=sched.enabled if sched else True,
)
)
return effective
@@ -203,6 +213,7 @@ class _SchedulerProxy:
Usage: `from app.services.backup_scheduler import backup_scheduler`
then `backup_scheduler.add_job(...)`.
"""
def __getattr__(self, name):
if _scheduler is None:
raise RuntimeError("Backup scheduler not started yet")

View File

@@ -255,7 +255,12 @@ async def run_backup(
if prior_commits:
try:
prior_export_bytes = await loop.run_in_executor(
None, git_store.read_file, tenant_id, prior_commits[0]["sha"], device_id, "export.rsc"
None,
git_store.read_file,
tenant_id,
prior_commits[0]["sha"],
device_id,
"export.rsc",
)
prior_text = prior_export_bytes.decode("utf-8", errors="replace")
lines_added, lines_removed = await loop.run_in_executor(
@@ -284,9 +289,7 @@ async def run_backup(
try:
from app.services.crypto import encrypt_data_transit
encrypted_export = await encrypt_data_transit(
export_text, tenant_id
)
encrypted_export = await encrypt_data_transit(export_text, tenant_id)
encrypted_binary = await encrypt_data_transit(
base64.b64encode(binary_backup).decode(), tenant_id
)
@@ -302,8 +305,7 @@ async def run_backup(
except Exception as enc_err:
# Transit unavailable — fall back to plaintext (non-fatal)
logger.warning(
"Transit encryption failed for %s backup of device %s, "
"storing plaintext: %s",
"Transit encryption failed for %s backup of device %s, storing plaintext: %s",
trigger_type,
device_id,
enc_err,
@@ -313,9 +315,7 @@ async def run_backup(
# -----------------------------------------------------------------------
# Step 6: Commit to git (wrapped in run_in_executor — pygit2 is sync C bindings)
# -----------------------------------------------------------------------
commit_message = (
f"{trigger_type}: {hostname} ({ip}) at {ts}"
)
commit_message = f"{trigger_type}: {hostname} ({ip}) at {ts}"
commit_sha = await loop.run_in_executor(
None,

View File

@@ -48,6 +48,7 @@ _VALID_TRANSITIONS: dict[str, set[str]] = {
# CA Generation
# ---------------------------------------------------------------------------
async def generate_ca(
db: AsyncSession,
tenant_id: UUID,
@@ -84,10 +85,12 @@ async def generate_ca(
now = datetime.datetime.now(datetime.timezone.utc)
expiry = now + datetime.timedelta(days=365 * validity_years)
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
])
subject = issuer = x509.Name(
[
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
]
)
ca_cert = (
x509.CertificateBuilder()
@@ -97,9 +100,7 @@ async def generate_ca(
.serial_number(x509.random_serial_number())
.not_valid_before(now)
.not_valid_after(expiry)
.add_extension(
x509.BasicConstraints(ca=True, path_length=0), critical=True
)
.add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True)
.add_extension(
x509.KeyUsage(
digital_signature=True,
@@ -166,6 +167,7 @@ async def generate_ca(
# Device Certificate Signing
# ---------------------------------------------------------------------------
async def sign_device_cert(
db: AsyncSession,
ca: CertificateAuthority,
@@ -196,9 +198,7 @@ async def sign_device_cert(
str(ca.tenant_id),
encryption_key,
)
ca_key = serialization.load_pem_private_key(
ca_key_pem.encode("utf-8"), password=None
)
ca_key = serialization.load_pem_private_key(ca_key_pem.encode("utf-8"), password=None)
# Load CA certificate for issuer info and AuthorityKeyIdentifier
ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8"))
@@ -212,19 +212,19 @@ async def sign_device_cert(
device_cert = (
x509.CertificateBuilder()
.subject_name(
x509.Name([
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
x509.NameAttribute(NameOID.COMMON_NAME, hostname),
])
x509.Name(
[
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
x509.NameAttribute(NameOID.COMMON_NAME, hostname),
]
)
)
.issuer_name(ca_cert.subject)
.public_key(device_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now)
.not_valid_after(expiry)
.add_extension(
x509.BasicConstraints(ca=False, path_length=None), critical=True
)
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
.add_extension(
x509.KeyUsage(
digital_signature=True,
@@ -244,17 +244,17 @@ async def sign_device_cert(
critical=False,
)
.add_extension(
x509.SubjectAlternativeName([
x509.IPAddress(ipaddress.ip_address(ip_address)),
x509.DNSName(hostname),
]),
x509.SubjectAlternativeName(
[
x509.IPAddress(ipaddress.ip_address(ip_address)),
x509.DNSName(hostname),
]
),
critical=False,
)
.add_extension(
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
ca_cert.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
).value
ca_cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier).value
),
critical=False,
)
@@ -308,15 +308,14 @@ async def sign_device_cert(
# Queries
# ---------------------------------------------------------------------------
async def get_ca_for_tenant(
db: AsyncSession,
tenant_id: UUID,
) -> CertificateAuthority | None:
"""Return the tenant's CA, or None if not yet initialized."""
result = await db.execute(
select(CertificateAuthority).where(
CertificateAuthority.tenant_id == tenant_id
)
select(CertificateAuthority).where(CertificateAuthority.tenant_id == tenant_id)
)
return result.scalar_one_or_none()
@@ -352,6 +351,7 @@ async def get_device_certs(
# Status Management
# ---------------------------------------------------------------------------
async def update_cert_status(
db: AsyncSession,
cert_id: UUID,
@@ -377,9 +377,7 @@ async def update_cert_status(
Raises:
ValueError: If the certificate is not found or the transition is invalid.
"""
result = await db.execute(
select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
)
result = await db.execute(select(DeviceCertificate).where(DeviceCertificate.id == cert_id))
cert = result.scalar_one_or_none()
if cert is None:
raise ValueError(f"Device certificate {cert_id} not found")
@@ -413,6 +411,7 @@ async def update_cert_status(
# Cert Data for Deployment
# ---------------------------------------------------------------------------
async def get_cert_for_deploy(
db: AsyncSession,
cert_id: UUID,
@@ -434,18 +433,14 @@ async def get_cert_for_deploy(
Raises:
ValueError: If the certificate or its CA is not found.
"""
result = await db.execute(
select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
)
result = await db.execute(select(DeviceCertificate).where(DeviceCertificate.id == cert_id))
cert = result.scalar_one_or_none()
if cert is None:
raise ValueError(f"Device certificate {cert_id} not found")
# Fetch the CA for the ca_cert_pem
ca_result = await db.execute(
select(CertificateAuthority).where(
CertificateAuthority.id == cert.ca_id
)
select(CertificateAuthority).where(CertificateAuthority.id == cert.ca_id)
)
ca = ca_result.scalar_one_or_none()
if ca is None:

View File

@@ -64,9 +64,7 @@ def parse_diff_changes(diff_text: str) -> list[dict]:
current_section: str | None = None
# Track per-component: adds, removes, raw_lines
components: dict[str, dict] = defaultdict(
lambda: {"adds": 0, "removes": 0, "raw_lines": []}
)
components: dict[str, dict] = defaultdict(lambda: {"adds": 0, "removes": 0, "raw_lines": []})
for line in lines:
# Skip unified diff headers

View File

@@ -51,13 +51,15 @@ async def handle_config_changed(event: dict) -> None:
if await _last_backup_within_dedup_window(device_id):
logger.info(
"Config change on device %s — skipping backup (within %dm dedup window)",
device_id, DEDUP_WINDOW_MINUTES,
device_id,
DEDUP_WINDOW_MINUTES,
)
return
logger.info(
"Config change detected on device %s (tenant %s): %s -> %s",
device_id, tenant_id,
device_id,
tenant_id,
event.get("old_timestamp", "?"),
event.get("new_timestamp", "?"),
)

View File

@@ -65,9 +65,7 @@ async def generate_and_store_diff(
# 2. No previous snapshot = first snapshot for device
if prev_row is None:
logger.debug(
"First snapshot for device %s, no diff to generate", device_id
)
logger.debug("First snapshot for device %s, no diff to generate", device_id)
return
old_snapshot_id = prev_row._mapping["id"]
@@ -103,9 +101,7 @@ async def generate_and_store_diff(
# 6. Generate unified diff
old_lines = old_plaintext.decode("utf-8").splitlines()
new_lines = new_plaintext.decode("utf-8").splitlines()
diff_lines = list(
difflib.unified_diff(old_lines, new_lines, lineterm="", n=3)
)
diff_lines = list(difflib.unified_diff(old_lines, new_lines, lineterm="", n=3))
# 7. If empty diff, skip INSERT
diff_text = "\n".join(diff_lines)
@@ -118,12 +114,10 @@ async def generate_and_store_diff(
# 8. Count lines added/removed (exclude +++ and --- headers)
lines_added = sum(
1 for line in diff_lines
if line.startswith("+") and not line.startswith("++")
1 for line in diff_lines if line.startswith("+") and not line.startswith("++")
)
lines_removed = sum(
1 for line in diff_lines
if line.startswith("-") and not line.startswith("--")
1 for line in diff_lines if line.startswith("-") and not line.startswith("--")
)
# 9. INSERT into router_config_diffs (RETURNING id for change parser)
@@ -165,6 +159,7 @@ async def generate_and_store_diff(
try:
from app.services.audit_service import log_action
import uuid as _uuid
await log_action(
db=None,
tenant_id=_uuid.UUID(tenant_id),
@@ -207,12 +202,16 @@ async def generate_and_store_diff(
await session.commit()
logger.info(
"Stored %d config changes for device %s diff %s",
len(changes), device_id, diff_id,
len(changes),
device_id,
diff_id,
)
except Exception as exc:
logger.warning(
"Change parser error for device %s diff %s (non-fatal): %s",
device_id, diff_id, exc,
device_id,
diff_id,
exc,
)
config_diff_errors_total.labels(error_type="change_parser").inc()

View File

@@ -89,9 +89,11 @@ async def handle_config_snapshot(msg) -> None:
collected_at_raw = data.get("collected_at")
try:
collected_at = datetime.fromisoformat(
collected_at_raw.replace("Z", "+00:00")
) if collected_at_raw else datetime.now(timezone.utc)
collected_at = (
datetime.fromisoformat(collected_at_raw.replace("Z", "+00:00"))
if collected_at_raw
else datetime.now(timezone.utc)
)
except (ValueError, AttributeError):
collected_at = datetime.now(timezone.utc)
@@ -131,9 +133,7 @@ async def handle_config_snapshot(msg) -> None:
# --- Encrypt via OpenBao Transit ---
openbao = OpenBaoTransitService()
try:
encrypted_text = await openbao.encrypt(
tenant_id, config_text.encode("utf-8")
)
encrypted_text = await openbao.encrypt(tenant_id, config_text.encode("utf-8"))
except Exception as exc:
logger.warning(
"Transit encrypt failed for device %s tenant %s: %s",
@@ -207,7 +207,8 @@ async def handle_config_snapshot(msg) -> None:
except Exception as exc:
logger.warning(
"Diff generation failed for device %s (non-fatal): %s",
device_id, exc,
device_id,
exc,
)
logger.info(
@@ -233,9 +234,7 @@ async def _subscribe_with_retry(js) -> None:
stream="DEVICE_EVENTS",
manual_ack=True,
)
logger.info(
"NATS: subscribed to config.snapshot.> (durable: config_snapshot_ingest)"
)
logger.info("NATS: subscribed to config.snapshot.> (durable: config_snapshot_ingest)")
return
except Exception as exc:
if attempt < max_attempts:

View File

@@ -29,8 +29,6 @@ from app.models.device import (
DeviceTagAssignment,
)
from app.schemas.device import (
BulkAddRequest,
BulkAddResult,
DeviceCreate,
DeviceGroupCreate,
DeviceGroupResponse,
@@ -43,9 +41,7 @@ from app.schemas.device import (
)
from app.config import settings
from app.services.crypto import (
decrypt_credentials,
decrypt_credentials_hybrid,
encrypt_credentials,
encrypt_credentials_transit,
)
@@ -58,9 +54,7 @@ from app.services.crypto import (
async def _tcp_reachable(ip: str, port: int, timeout: float = 3.0) -> bool:
"""Return True if a TCP connection to ip:port succeeds within timeout."""
try:
_, writer = await asyncio.wait_for(
asyncio.open_connection(ip, port), timeout=timeout
)
_, writer = await asyncio.wait_for(asyncio.open_connection(ip, port), timeout=timeout)
writer.close()
try:
await writer.wait_closed()
@@ -151,6 +145,7 @@ async def create_device(
if not api_reachable and not ssl_reachable:
from fastapi import HTTPException, status
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=(
@@ -162,9 +157,7 @@ async def create_device(
# Encrypt credentials via OpenBao Transit (new writes go through Transit)
credentials_json = json.dumps({"username": data.username, "password": data.password})
transit_ciphertext = await encrypt_credentials_transit(
credentials_json, str(tenant_id)
)
transit_ciphertext = await encrypt_credentials_transit(credentials_json, str(tenant_id))
device = Device(
tenant_id=tenant_id,
@@ -180,9 +173,7 @@ async def create_device(
await db.refresh(device)
# Re-query with relationships loaded
result = await db.execute(
_device_with_relations().where(Device.id == device.id)
)
result = await db.execute(_device_with_relations().where(Device.id == device.id))
device = result.scalar_one()
return _build_device_response(device)
@@ -223,9 +214,7 @@ async def get_devices(
if tag_id:
base_q = base_q.where(
Device.id.in_(
select(DeviceTagAssignment.device_id).where(
DeviceTagAssignment.tag_id == tag_id
)
select(DeviceTagAssignment.device_id).where(DeviceTagAssignment.tag_id == tag_id)
)
)
@@ -274,9 +263,7 @@ async def get_device(
"""Get a single device by ID."""
from fastapi import HTTPException, status
result = await db.execute(
_device_with_relations().where(Device.id == device_id)
)
result = await db.execute(_device_with_relations().where(Device.id == device_id))
device = result.scalar_one_or_none()
if not device:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
@@ -295,9 +282,7 @@ async def update_device(
"""
from fastapi import HTTPException, status
result = await db.execute(
_device_with_relations().where(Device.id == device_id)
)
result = await db.execute(_device_with_relations().where(Device.id == device_id))
device = result.scalar_one_or_none()
if not device:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
@@ -323,7 +308,9 @@ async def update_device(
if data.password is not None:
# Decrypt existing to get current username if no new username given
current_username: str = data.username or ""
if not current_username and (device.encrypted_credentials_transit or device.encrypted_credentials):
if not current_username and (
device.encrypted_credentials_transit or device.encrypted_credentials
):
try:
existing_json = await decrypt_credentials_hybrid(
device.encrypted_credentials_transit,
@@ -336,17 +323,21 @@ async def update_device(
except Exception:
current_username = ""
credentials_json = json.dumps({
"username": data.username if data.username is not None else current_username,
"password": data.password,
})
credentials_json = json.dumps(
{
"username": data.username if data.username is not None else current_username,
"password": data.password,
}
)
# New writes go through Transit
device.encrypted_credentials_transit = await encrypt_credentials_transit(
credentials_json, str(device.tenant_id)
)
device.encrypted_credentials = None # Clear legacy (Transit is canonical)
credentials_changed = True
elif data.username is not None and (device.encrypted_credentials_transit or device.encrypted_credentials):
elif data.username is not None and (
device.encrypted_credentials_transit or device.encrypted_credentials
):
# Only username changed — update it without changing the password
try:
existing_json = await decrypt_credentials_hybrid(
@@ -373,6 +364,7 @@ async def update_device(
if credentials_changed:
try:
from app.services.event_publisher import publish_event
await publish_event(
f"device.credential_changed.{device_id}",
{"device_id": str(device_id), "tenant_id": str(tenant_id)},
@@ -380,9 +372,7 @@ async def update_device(
except Exception:
pass # Never fail the update due to NATS issues
result2 = await db.execute(
_device_with_relations().where(Device.id == device_id)
)
result2 = await db.execute(_device_with_relations().where(Device.id == device_id))
device = result2.scalar_one()
return _build_device_response(device)
@@ -526,11 +516,7 @@ async def get_groups(
tenant_id: uuid.UUID,
) -> list[DeviceGroupResponse]:
"""Return all device groups for the current tenant with device counts."""
result = await db.execute(
select(DeviceGroup).options(
selectinload(DeviceGroup.memberships)
)
)
result = await db.execute(select(DeviceGroup).options(selectinload(DeviceGroup.memberships)))
groups = result.scalars().all()
return [
DeviceGroupResponse(
@@ -554,9 +540,9 @@ async def update_group(
from fastapi import HTTPException, status
result = await db.execute(
select(DeviceGroup).options(
selectinload(DeviceGroup.memberships)
).where(DeviceGroup.id == group_id)
select(DeviceGroup)
.options(selectinload(DeviceGroup.memberships))
.where(DeviceGroup.id == group_id)
)
group = result.scalar_one_or_none()
if not group:
@@ -571,9 +557,9 @@ async def update_group(
await db.refresh(group)
result2 = await db.execute(
select(DeviceGroup).options(
selectinload(DeviceGroup.memberships)
).where(DeviceGroup.id == group_id)
select(DeviceGroup)
.options(selectinload(DeviceGroup.memberships))
.where(DeviceGroup.id == group_id)
)
group = result2.scalar_one()
return DeviceGroupResponse(

View File

@@ -48,7 +48,5 @@ async def generate_emergency_kit_template(
# Run weasyprint in thread to avoid blocking the event loop
from weasyprint import HTML
pdf_bytes = await asyncio.to_thread(
lambda: HTML(string=html_content).write_pdf()
)
pdf_bytes = await asyncio.to_thread(lambda: HTML(string=html_content).write_pdf())
return pdf_bytes

View File

@@ -13,7 +13,6 @@ Version discovery comes from two sources:
"""
import logging
import os
from pathlib import Path
import httpx
@@ -51,13 +50,12 @@ async def check_latest_versions() -> list[dict]:
async with httpx.AsyncClient(timeout=30.0) as client:
for channel_file, channel, major in _VERSION_SOURCES:
try:
resp = await client.get(
f"https://download.mikrotik.com/routeros/{channel_file}"
)
resp = await client.get(f"https://download.mikrotik.com/routeros/{channel_file}")
if resp.status_code != 200:
logger.warning(
"MikroTik version check returned %d for %s",
resp.status_code, channel_file,
resp.status_code,
channel_file,
)
continue
@@ -72,12 +70,14 @@ async def check_latest_versions() -> list[dict]:
f"https://download.mikrotik.com/routeros/"
f"{version}/routeros-{version}-{arch}.npk"
)
results.append({
"architecture": arch,
"channel": channel,
"version": version,
"npk_url": npk_url,
})
results.append(
{
"architecture": arch,
"channel": channel,
"version": version,
"npk_url": npk_url,
}
)
except Exception as e:
logger.warning("Failed to check %s: %s", channel_file, e)
@@ -242,15 +242,15 @@ async def get_firmware_overview(tenant_id: str) -> dict:
groups = []
for ver, devs in sorted(version_groups.items()):
# A version is "latest" if it matches the latest for any arch/channel combo
is_latest = any(
v["version"] == ver for v in latest_versions.values()
is_latest = any(v["version"] == ver for v in latest_versions.values())
groups.append(
{
"version": ver,
"count": len(devs),
"is_latest": is_latest,
"devices": devs,
}
)
groups.append({
"version": ver,
"count": len(devs),
"is_latest": is_latest,
"devices": devs,
})
return {
"devices": device_list,
@@ -272,12 +272,14 @@ async def get_cached_firmware() -> list[dict]:
continue
for npk_file in sorted(version_dir.iterdir()):
if npk_file.suffix == ".npk":
cached.append({
"path": str(npk_file),
"version": version_dir.name,
"filename": npk_file.name,
"size_bytes": npk_file.stat().st_size,
})
cached.append(
{
"path": str(npk_file),
"version": version_dir.name,
"filename": npk_file.name,
"size_bytes": npk_file.stat().st_size,
}
)
return cached

View File

@@ -41,7 +41,7 @@ async def on_device_firmware(msg) -> None:
try:
data = json.loads(msg.data)
device_id = data.get("device_id")
tenant_id = data.get("tenant_id")
_tenant_id = data.get("tenant_id")
architecture = data.get("architecture")
installed_version = data.get("installed_version")
latest_version = data.get("latest_version")
@@ -126,9 +126,7 @@ async def _subscribe_with_retry(js: JetStreamContext) -> None:
durable="api-firmware-consumer",
stream="DEVICE_EVENTS",
)
logger.info(
"NATS: subscribed to device.firmware.> (durable: api-firmware-consumer)"
)
logger.info("NATS: subscribed to device.firmware.> (durable: api-firmware-consumer)")
return
except Exception as exc:
if attempt < max_attempts:

View File

@@ -238,11 +238,13 @@ def list_device_commits(
# Device wasn't in parent but is in this commit — it's the first entry.
pass
results.append({
"sha": str(commit.id),
"message": commit.message.strip(),
"timestamp": commit.commit_time,
})
results.append(
{
"sha": str(commit.id),
"message": commit.message.strip(),
"timestamp": commit.commit_time,
}
)
return results

View File

@@ -52,6 +52,7 @@ async def store_user_key_set(
"""
# Remove any existing key set (e.g. from a failed prior upgrade attempt)
from sqlalchemy import delete
await db.execute(delete(UserKeySet).where(UserKeySet.user_id == user_id))
key_set = UserKeySet(
@@ -71,9 +72,7 @@ async def store_user_key_set(
return key_set
async def get_user_key_set(
db: AsyncSession, user_id: UUID
) -> UserKeySet | None:
async def get_user_key_set(db: AsyncSession, user_id: UUID) -> UserKeySet | None:
"""Retrieve encrypted key bundle for login response.
Args:
@@ -83,9 +82,7 @@ async def get_user_key_set(
Returns:
The UserKeySet if found, None otherwise.
"""
result = await db.execute(
select(UserKeySet).where(UserKeySet.user_id == user_id)
)
result = await db.execute(select(UserKeySet).where(UserKeySet.user_id == user_id))
return result.scalar_one_or_none()
@@ -163,9 +160,7 @@ async def provision_tenant_key(db: AsyncSession, tenant_id: UUID) -> str:
await openbao.create_tenant_key(str(tenant_id))
# Update tenant record with key name
result = await db.execute(
select(Tenant).where(Tenant.id == tenant_id)
)
result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
tenant = result.scalar_one_or_none()
if tenant:
tenant.openbao_key_name = key_name
@@ -210,13 +205,18 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
select(Device).where(
Device.tenant_id == tenant_id,
Device.encrypted_credentials.isnot(None),
(Device.encrypted_credentials_transit.is_(None) | (Device.encrypted_credentials_transit == "")),
(
Device.encrypted_credentials_transit.is_(None)
| (Device.encrypted_credentials_transit == "")
),
)
)
for device in result.scalars().all():
try:
plaintext = decrypt_credentials(device.encrypted_credentials, legacy_key)
device.encrypted_credentials_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
device.encrypted_credentials_transit = await openbao.encrypt(
tid, plaintext.encode("utf-8")
)
counts["devices"] += 1
except Exception as e:
logger.error("Failed to migrate device %s credentials: %s", device.id, e)
@@ -227,7 +227,10 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
select(CertificateAuthority).where(
CertificateAuthority.tenant_id == tenant_id,
CertificateAuthority.encrypted_private_key.isnot(None),
(CertificateAuthority.encrypted_private_key_transit.is_(None) | (CertificateAuthority.encrypted_private_key_transit == "")),
(
CertificateAuthority.encrypted_private_key_transit.is_(None)
| (CertificateAuthority.encrypted_private_key_transit == "")
),
)
)
for ca in result.scalars().all():
@@ -244,13 +247,18 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
select(DeviceCertificate).where(
DeviceCertificate.tenant_id == tenant_id,
DeviceCertificate.encrypted_private_key.isnot(None),
(DeviceCertificate.encrypted_private_key_transit.is_(None) | (DeviceCertificate.encrypted_private_key_transit == "")),
(
DeviceCertificate.encrypted_private_key_transit.is_(None)
| (DeviceCertificate.encrypted_private_key_transit == "")
),
)
)
for cert in result.scalars().all():
try:
plaintext = decrypt_credentials(cert.encrypted_private_key, legacy_key)
cert.encrypted_private_key_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
cert.encrypted_private_key_transit = await openbao.encrypt(
tid, plaintext.encode("utf-8")
)
counts["certs"] += 1
except Exception as e:
logger.error("Failed to migrate cert %s private key: %s", cert.id, e)
@@ -303,7 +311,14 @@ async def provision_existing_tenants(db: AsyncSession) -> dict:
result = await db.execute(select(Tenant))
tenants = result.scalars().all()
total = {"tenants": len(tenants), "devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0}
total = {
"tenants": len(tenants),
"devices": 0,
"cas": 0,
"certs": 0,
"channels": 0,
"errors": 0,
}
for tenant in tenants:
try:

View File

@@ -199,9 +199,7 @@ async def on_device_metrics(msg) -> None:
device_id = data.get("device_id")
if not metric_type or not device_id:
logger.warning(
"device.metrics event missing 'type' or 'device_id' — skipping"
)
logger.warning("device.metrics event missing 'type' or 'device_id' — skipping")
await msg.ack()
return
@@ -222,6 +220,7 @@ async def on_device_metrics(msg) -> None:
# Alert evaluation — non-fatal; metric write is the primary operation
try:
from app.services import alert_evaluator
await alert_evaluator.evaluate(
device_id=device_id,
tenant_id=data.get("tenant_id", ""),
@@ -265,9 +264,7 @@ async def _subscribe_with_retry(js: JetStreamContext) -> None:
durable="api-metrics-consumer",
stream="DEVICE_EVENTS",
)
logger.info(
"NATS: subscribed to device.metrics.> (durable: api-metrics-consumer)"
)
logger.info("NATS: subscribed to device.metrics.> (durable: api-metrics-consumer)")
return
except Exception as exc:
if attempt < max_attempts:

View File

@@ -69,7 +69,9 @@ async def on_device_status(msg) -> None:
uptime_seconds = _parse_uptime(data.get("uptime", ""))
if not device_id or not status:
logger.warning("Received device.status event with missing device_id or status — skipping")
logger.warning(
"Received device.status event with missing device_id or status — skipping"
)
await msg.ack()
return
@@ -115,12 +117,15 @@ async def on_device_status(msg) -> None:
# Alert evaluation for offline/online status changes — non-fatal
try:
from app.services import alert_evaluator
if status == "offline":
await alert_evaluator.evaluate_offline(device_id, data.get("tenant_id", ""))
elif status == "online":
await alert_evaluator.evaluate_online(device_id, data.get("tenant_id", ""))
except Exception as e:
logger.warning("Alert evaluation failed for device %s status=%s: %s", device_id, status, e)
logger.warning(
"Alert evaluation failed for device %s status=%s: %s", device_id, status, e
)
logger.info(
"Device status updated",

View File

@@ -39,7 +39,9 @@ async def dispatch_notifications(
except Exception as e:
logger.warning(
"Notification delivery failed for channel %s (%s): %s",
channel.get("name"), channel.get("channel_type"), e,
channel.get("name"),
channel.get("channel_type"),
e,
)
@@ -100,14 +102,18 @@ async def _send_email(channel: dict, alert_event: dict, device_hostname: str) ->
if transit_cipher and tenant_id:
try:
from app.services.kms_service import decrypt_transit
smtp_password = await decrypt_transit(transit_cipher, tenant_id)
except Exception:
logger.warning("Transit decryption failed for channel %s, trying legacy", channel.get("id"))
logger.warning(
"Transit decryption failed for channel %s, trying legacy", channel.get("id")
)
if not smtp_password and legacy_cipher:
try:
from app.config import settings as app_settings
from cryptography.fernet import Fernet
raw = bytes(legacy_cipher) if isinstance(legacy_cipher, memoryview) else legacy_cipher
f = Fernet(app_settings.CREDENTIAL_ENCRYPTION_KEY.encode())
smtp_password = f.decrypt(raw).decode()
@@ -163,7 +169,8 @@ async def _send_webhook(
response = await client.post(webhook_url, json=payload)
logger.info(
"Webhook notification sent to %s — status %d",
webhook_url, response.status_code,
webhook_url,
response.status_code,
)
@@ -180,13 +187,18 @@ async def _send_slack(
value = alert_event.get("value")
threshold = alert_event.get("threshold")
color = {"CRITICAL": "#dc2626", "WARNING": "#f59e0b", "INFO": "#3b82f6"}.get(severity, "#6b7280")
color = {"CRITICAL": "#dc2626", "WARNING": "#f59e0b", "INFO": "#3b82f6"}.get(
severity, "#6b7280"
)
status_label = "RESOLVED" if status == "resolved" else status
blocks = [
{
"type": "header",
"text": {"type": "plain_text", "text": f"{'' if status == 'resolved' else '🚨'} [{severity}] {status_label.upper()}"},
"text": {
"type": "plain_text",
"text": f"{'' if status == 'resolved' else '🚨'} [{severity}] {status_label.upper()}",
},
},
{
"type": "section",
@@ -205,7 +217,9 @@ async def _send_slack(
blocks.append({"type": "section", "fields": fields})
if message_text:
blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": f"*Message:*\n{message_text}"}})
blocks.append(
{"type": "section", "text": {"type": "mrkdwn", "text": f"*Message:*\n{message_text}"}}
)
blocks.append({"type": "context", "elements": [{"type": "mrkdwn", "text": "TOD Alert System"}]})

View File

@@ -7,6 +7,7 @@ material never leaves OpenBao -- the application only sees ciphertext.
Ciphertext format: "vault:v1:base64..." (compatible with Vault Transit format)
"""
import base64
import logging
from typing import Optional

View File

@@ -47,13 +47,15 @@ async def record_push(
"""
r = await _get_redis()
key = f"push:recent:{device_id}"
value = json.dumps({
"device_id": device_id,
"tenant_id": tenant_id,
"push_type": push_type,
"push_operation_id": push_operation_id,
"pre_push_commit_sha": pre_push_commit_sha,
})
value = json.dumps(
{
"device_id": device_id,
"tenant_id": tenant_id,
"push_type": push_type,
"push_operation_id": push_operation_id,
"pre_push_commit_sha": pre_push_commit_sha,
}
)
await r.set(key, value, ex=PUSH_TTL_SECONDS)
logger.debug(
"Recorded push for device %s (type=%s, TTL=%ds)",

View File

@@ -164,16 +164,18 @@ async def _device_inventory(
uptime_str = _format_uptime(row[6]) if row[6] else None
last_seen_str = row[5].strftime("%Y-%m-%d %H:%M") if row[5] else None
devices.append({
"hostname": row[0],
"ip_address": row[1],
"model": row[2],
"routeros_version": row[3],
"status": status,
"last_seen": last_seen_str,
"uptime": uptime_str,
"groups": row[7] if row[7] else None,
})
devices.append(
{
"hostname": row[0],
"ip_address": row[1],
"model": row[2],
"routeros_version": row[3],
"status": status,
"last_seen": last_seen_str,
"uptime": uptime_str,
"groups": row[7] if row[7] else None,
}
)
return {
"report_title": "Device Inventory",
@@ -224,16 +226,18 @@ async def _metrics_summary(
devices = []
for row in rows:
devices.append({
"hostname": row[0],
"avg_cpu": float(row[1]) if row[1] is not None else None,
"peak_cpu": float(row[2]) if row[2] is not None else None,
"avg_mem": float(row[3]) if row[3] is not None else None,
"peak_mem": float(row[4]) if row[4] is not None else None,
"avg_disk": float(row[5]) if row[5] is not None else None,
"avg_temp": float(row[6]) if row[6] is not None else None,
"data_points": row[7],
})
devices.append(
{
"hostname": row[0],
"avg_cpu": float(row[1]) if row[1] is not None else None,
"peak_cpu": float(row[2]) if row[2] is not None else None,
"avg_mem": float(row[3]) if row[3] is not None else None,
"peak_mem": float(row[4]) if row[4] is not None else None,
"avg_disk": float(row[5]) if row[5] is not None else None,
"avg_temp": float(row[6]) if row[6] is not None else None,
"data_points": row[7],
}
)
return {
"report_title": "Metrics Summary",
@@ -287,14 +291,16 @@ async def _alert_history(
if duration_secs is not None:
resolved_durations.append(duration_secs)
alerts.append({
"fired_at": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"hostname": row[5],
"severity": severity,
"status": row[3],
"message": row[4],
"duration": _format_duration(duration_secs) if duration_secs is not None else None,
})
alerts.append(
{
"fired_at": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"hostname": row[5],
"severity": severity,
"status": row[3],
"message": row[4],
"duration": _format_duration(duration_secs) if duration_secs is not None else None,
}
)
mttr_minutes = None
mttr_display = None
@@ -374,13 +380,15 @@ async def _change_log_from_audit(
entries = []
for row in rows:
entries.append({
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"user": row[1],
"action": row[2],
"device": row[3],
"details": row[4] or row[5] or "",
})
entries.append(
{
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"user": row[1],
"action": row[2],
"device": row[3],
"details": row[4] or row[5] or "",
}
)
return {
"report_title": "Change Log",
@@ -436,21 +444,25 @@ async def _change_log_from_backups(
# Merge and sort by timestamp descending
entries = []
for row in backup_rows:
entries.append({
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"user": row[1],
"action": row[2],
"device": row[3],
"details": row[4] or "",
})
entries.append(
{
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"user": row[1],
"action": row[2],
"device": row[3],
"details": row[4] or "",
}
)
for row in alert_rows:
entries.append({
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"user": row[1],
"action": row[2],
"device": row[3],
"details": row[4] or "",
})
entries.append(
{
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"user": row[1],
"action": row[2],
"device": row[3],
"details": row[4] or "",
}
)
# Sort by timestamp string descending
entries.sort(key=lambda e: e["timestamp"], reverse=True)
@@ -486,54 +498,102 @@ def _render_csv(report_type: str, data: dict[str, Any]) -> bytes:
writer = csv.writer(output)
if report_type == "device_inventory":
writer.writerow([
"Hostname", "IP Address", "Model", "RouterOS Version",
"Status", "Last Seen", "Uptime", "Groups",
])
writer.writerow(
[
"Hostname",
"IP Address",
"Model",
"RouterOS Version",
"Status",
"Last Seen",
"Uptime",
"Groups",
]
)
for d in data.get("devices", []):
writer.writerow([
d["hostname"], d["ip_address"], d["model"] or "",
d["routeros_version"] or "", d["status"],
d["last_seen"] or "", d["uptime"] or "",
d["groups"] or "",
])
writer.writerow(
[
d["hostname"],
d["ip_address"],
d["model"] or "",
d["routeros_version"] or "",
d["status"],
d["last_seen"] or "",
d["uptime"] or "",
d["groups"] or "",
]
)
elif report_type == "metrics_summary":
writer.writerow([
"Hostname", "Avg CPU %", "Peak CPU %", "Avg Memory %",
"Peak Memory %", "Avg Disk %", "Avg Temp", "Data Points",
])
writer.writerow(
[
"Hostname",
"Avg CPU %",
"Peak CPU %",
"Avg Memory %",
"Peak Memory %",
"Avg Disk %",
"Avg Temp",
"Data Points",
]
)
for d in data.get("devices", []):
writer.writerow([
d["hostname"],
f"{d['avg_cpu']:.1f}" if d["avg_cpu"] is not None else "",
f"{d['peak_cpu']:.1f}" if d["peak_cpu"] is not None else "",
f"{d['avg_mem']:.1f}" if d["avg_mem"] is not None else "",
f"{d['peak_mem']:.1f}" if d["peak_mem"] is not None else "",
f"{d['avg_disk']:.1f}" if d["avg_disk"] is not None else "",
f"{d['avg_temp']:.1f}" if d["avg_temp"] is not None else "",
d["data_points"],
])
writer.writerow(
[
d["hostname"],
f"{d['avg_cpu']:.1f}" if d["avg_cpu"] is not None else "",
f"{d['peak_cpu']:.1f}" if d["peak_cpu"] is not None else "",
f"{d['avg_mem']:.1f}" if d["avg_mem"] is not None else "",
f"{d['peak_mem']:.1f}" if d["peak_mem"] is not None else "",
f"{d['avg_disk']:.1f}" if d["avg_disk"] is not None else "",
f"{d['avg_temp']:.1f}" if d["avg_temp"] is not None else "",
d["data_points"],
]
)
elif report_type == "alert_history":
writer.writerow([
"Timestamp", "Device", "Severity", "Message", "Status", "Duration",
])
writer.writerow(
[
"Timestamp",
"Device",
"Severity",
"Message",
"Status",
"Duration",
]
)
for a in data.get("alerts", []):
writer.writerow([
a["fired_at"], a["hostname"] or "", a["severity"],
a["message"] or "", a["status"], a["duration"] or "",
])
writer.writerow(
[
a["fired_at"],
a["hostname"] or "",
a["severity"],
a["message"] or "",
a["status"],
a["duration"] or "",
]
)
elif report_type == "change_log":
writer.writerow([
"Timestamp", "User", "Action", "Device", "Details",
])
writer.writerow(
[
"Timestamp",
"User",
"Action",
"Device",
"Details",
]
)
for e in data.get("entries", []):
writer.writerow([
e["timestamp"], e["user"] or "", e["action"],
e["device"] or "", e["details"] or "",
])
writer.writerow(
[
e["timestamp"],
e["user"] or "",
e["action"],
e["device"] or "",
e["details"] or "",
]
)
return output.getvalue().encode("utf-8")

View File

@@ -33,7 +33,6 @@ import asyncssh
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import set_tenant_context, AdminAsyncSessionLocal
from app.models.config_backup import ConfigPushOperation
from app.models.device import Device
from app.services import backup_service, git_store
@@ -113,12 +112,11 @@ async def restore_config(
raise ValueError(f"Device {device_id!r} not found")
if not device.encrypted_credentials_transit and not device.encrypted_credentials:
raise ValueError(
f"Device {device_id!r} has no stored credentials — cannot perform restore"
)
raise ValueError(f"Device {device_id!r} has no stored credentials — cannot perform restore")
key = settings.get_encryption_key_bytes()
from app.services.crypto import decrypt_credentials_hybrid
creds_json = await decrypt_credentials_hybrid(
device.encrypted_credentials_transit,
device.encrypted_credentials,
@@ -133,7 +131,9 @@ async def restore_config(
hostname = device.hostname or ip
# Publish "started" progress event
await _publish_push_progress(tenant_id, device_id, "started", f"Config restore started for {hostname}")
await _publish_push_progress(
tenant_id, device_id, "started", f"Config restore started for {hostname}"
)
# ------------------------------------------------------------------
# Step 2: Read the target export.rsc from the backup commit
@@ -157,7 +157,9 @@ async def restore_config(
# ------------------------------------------------------------------
# Step 3: Mandatory pre-backup before push
# ------------------------------------------------------------------
await _publish_push_progress(tenant_id, device_id, "backing_up", f"Creating pre-restore backup for {hostname}")
await _publish_push_progress(
tenant_id, device_id, "backing_up", f"Creating pre-restore backup for {hostname}"
)
logger.info(
"Starting pre-restore backup for device %s (%s) before pushing commit %s",
@@ -198,7 +200,9 @@ async def restore_config(
# Step 5: SSH to device — install panic-revert, push config
# ------------------------------------------------------------------
push_op_id_str = str(push_op_id)
await _publish_push_progress(tenant_id, device_id, "pushing", f"Pushing config to {hostname}", push_op_id=push_op_id_str)
await _publish_push_progress(
tenant_id, device_id, "pushing", f"Pushing config to {hostname}", push_op_id=push_op_id_str
)
logger.info(
"Pushing config to device %s (%s): installing panic-revert scheduler and uploading config",
@@ -290,9 +294,12 @@ async def restore_config(
# Update push operation to failed
await _update_push_op_status(push_op_id, "failed", db_session)
await _publish_push_progress(
tenant_id, device_id, "failed",
tenant_id,
device_id,
"failed",
f"Config push failed for {hostname}: {push_err}",
push_op_id=push_op_id_str, error=str(push_err),
push_op_id=push_op_id_str,
error=str(push_err),
)
return {
"status": "failed",
@@ -312,7 +319,13 @@ async def restore_config(
# ------------------------------------------------------------------
# Step 6: Wait 60s for config to settle
# ------------------------------------------------------------------
await _publish_push_progress(tenant_id, device_id, "settling", f"Config pushed to {hostname} — waiting 60s for settle", push_op_id=push_op_id_str)
await _publish_push_progress(
tenant_id,
device_id,
"settling",
f"Config pushed to {hostname} — waiting 60s for settle",
push_op_id=push_op_id_str,
)
logger.info(
"Config pushed to device %s — waiting 60s for config to settle",
@@ -323,7 +336,13 @@ async def restore_config(
# ------------------------------------------------------------------
# Step 7: Reachability check
# ------------------------------------------------------------------
await _publish_push_progress(tenant_id, device_id, "verifying", f"Verifying device {hostname} reachability", push_op_id=push_op_id_str)
await _publish_push_progress(
tenant_id,
device_id,
"verifying",
f"Verifying device {hostname} reachability",
push_op_id=push_op_id_str,
)
reachable = await _check_reachability(ip, ssh_username, ssh_password)
@@ -362,7 +381,13 @@ async def restore_config(
await _update_push_op_status(push_op_id, "committed", db_session)
await clear_push(device_id)
await _publish_push_progress(tenant_id, device_id, "committed", f"Config restored successfully on {hostname}", push_op_id=push_op_id_str)
await _publish_push_progress(
tenant_id,
device_id,
"committed",
f"Config restored successfully on {hostname}",
push_op_id=push_op_id_str,
)
return {
"status": "committed",
@@ -384,7 +409,9 @@ async def restore_config(
await _update_push_op_status(push_op_id, "reverted", db_session)
await _publish_push_progress(
tenant_id, device_id, "reverted",
tenant_id,
device_id,
"reverted",
f"Device {hostname} unreachable — auto-reverting via panic-revert scheduler",
push_op_id=push_op_id_str,
)
@@ -446,7 +473,7 @@ async def _update_push_op_status(
new_status: New status value ('committed' | 'reverted' | 'failed').
db_session: Database session (must already have tenant context set).
"""
from sqlalchemy import select, update
from sqlalchemy import update
await db_session.execute(
update(ConfigPushOperation)
@@ -526,9 +553,7 @@ async def recover_stale_push_operations(db_session: AsyncSession) -> None:
for op in stale_ops:
try:
# Load device
dev_result = await db_session.execute(
select(Device).where(Device.id == op.device_id)
)
dev_result = await db_session.execute(select(Device).where(Device.id == op.device_id))
device = dev_result.scalar_one_or_none()
if not device:
logger.error("Device %s not found for stale op %s", op.device_id, op.id)
@@ -547,9 +572,7 @@ async def recover_stale_push_operations(db_session: AsyncSession) -> None:
ssh_password = creds.get("password", "")
# Check reachability
reachable = await _check_reachability(
device.ip_address, ssh_username, ssh_password
)
reachable = await _check_reachability(device.ip_address, ssh_username, ssh_password)
if reachable:
# Try to remove scheduler (if still there, push was good)

View File

@@ -53,7 +53,9 @@ async def cleanup_expired_snapshots() -> int:
deleted = result.rowcount
config_snapshots_cleaned_total.inc(deleted)
logger.info("retention cleanup complete", extra={"deleted_snapshots": deleted, "retention_days": days})
logger.info(
"retention cleanup complete", extra={"deleted_snapshots": deleted, "retention_days": days}
)
return deleted

View File

@@ -88,9 +88,7 @@ async def browse_menu(device_id: str, path: str) -> dict[str, Any]:
return await execute_command(device_id, command)
async def add_entry(
device_id: str, path: str, properties: dict[str, str]
) -> dict[str, Any]:
async def add_entry(device_id: str, path: str, properties: dict[str, str]) -> dict[str, Any]:
"""Add a new entry to a RouterOS menu path.
Args:
@@ -124,9 +122,7 @@ async def update_entry(
return await execute_command(device_id, f"{path}/set", args)
async def remove_entry(
device_id: str, path: str, entry_id: str
) -> dict[str, Any]:
async def remove_entry(device_id: str, path: str, entry_id: str) -> dict[str, Any]:
"""Remove an entry from a RouterOS menu path.
Args:

View File

@@ -7,20 +7,33 @@ from typing import Any
logger = logging.getLogger(__name__)
HIGH_RISK_PATHS = {
"/ip address", "/ip route", "/ip firewall filter", "/ip firewall nat",
"/interface", "/interface bridge", "/interface vlan",
"/system identity", "/ip service", "/ip ssh", "/user",
"/ip address",
"/ip route",
"/ip firewall filter",
"/ip firewall nat",
"/interface",
"/interface bridge",
"/interface vlan",
"/system identity",
"/ip service",
"/ip ssh",
"/user",
}
MANAGEMENT_PATTERNS = [
(re.compile(r"chain=input.*dst-port=(22|8291|8728|8729|443|80)", re.I),
"Modifies firewall rules for management ports (SSH/WinBox/API/Web)"),
(re.compile(r"chain=input.*action=drop", re.I),
"Adds drop rule on input chain — may block management access"),
(re.compile(r"/ip service", re.I),
"Modifies IP services — may disable API/SSH/WinBox access"),
(re.compile(r"/user.*set.*password", re.I),
"Changes user password — may affect automated access"),
(
re.compile(r"chain=input.*dst-port=(22|8291|8728|8729|443|80)", re.I),
"Modifies firewall rules for management ports (SSH/WinBox/API/Web)",
),
(
re.compile(r"chain=input.*action=drop", re.I),
"Adds drop rule on input chain — may block management access",
),
(re.compile(r"/ip service", re.I), "Modifies IP services — may disable API/SSH/WinBox access"),
(
re.compile(r"/user.*set.*password", re.I),
"Changes user password — may affect automated access",
),
]
@@ -73,7 +86,14 @@ def parse_rsc(text: str) -> dict[str, Any]:
else:
# Check if second part starts with a known command verb
cmd_check = parts[1].strip().split(None, 1)
if cmd_check and cmd_check[0] in ("add", "set", "remove", "print", "enable", "disable"):
if cmd_check and cmd_check[0] in (
"add",
"set",
"remove",
"print",
"enable",
"disable",
):
current_path = parts[0]
line = parts[1].strip()
else:
@@ -184,12 +204,14 @@ def compute_impact(
risk = "none"
if has_changes:
risk = "high" if path in HIGH_RISK_PATHS else "low"
result_categories.append({
"path": path,
"adds": added,
"removes": removed,
"risk": risk,
})
result_categories.append(
{
"path": path,
"adds": added,
"removes": removed,
"risk": risk,
}
)
# Check target commands against management patterns
target_text = "\n".join(

View File

@@ -16,9 +16,7 @@ from srptools.constants import PRIME_2048, PRIME_2048_GEN
_SRP_HASH = hashlib.sha256
async def create_srp_verifier(
salt_hex: str, verifier_hex: str
) -> tuple[bytes, bytes]:
async def create_srp_verifier(salt_hex: str, verifier_hex: str) -> tuple[bytes, bytes]:
"""Convert client-provided hex salt and verifier to bytes for storage.
The client computes v = g^x mod N using 2SKD-derived SRP-x.
@@ -31,9 +29,7 @@ async def create_srp_verifier(
return bytes.fromhex(salt_hex), bytes.fromhex(verifier_hex)
async def srp_init(
email: str, srp_verifier_hex: str
) -> tuple[str, str]:
async def srp_init(email: str, srp_verifier_hex: str) -> tuple[str, str]:
"""SRP Step 1: Generate server ephemeral (B) and private key (b).
Args:
@@ -47,14 +43,15 @@ async def srp_init(
Raises:
ValueError: If SRP initialization fails for any reason.
"""
def _init() -> tuple[str, str]:
context = SRPContext(
email, prime=PRIME_2048, generator=PRIME_2048_GEN,
email,
prime=PRIME_2048,
generator=PRIME_2048_GEN,
hash_func=_SRP_HASH,
)
server_session = SRPServerSession(
context, srp_verifier_hex
)
server_session = SRPServerSession(context, srp_verifier_hex)
return server_session.public, server_session.private
try:
@@ -85,26 +82,27 @@ async def srp_verify(
Tuple of (is_valid, server_proof_hex_or_none).
If valid, server_proof is M2 for the client to verify.
"""
def _verify() -> tuple[bool, str | None]:
import logging
log = logging.getLogger("srp_debug")
context = SRPContext(
email, prime=PRIME_2048, generator=PRIME_2048_GEN,
email,
prime=PRIME_2048,
generator=PRIME_2048_GEN,
hash_func=_SRP_HASH,
)
server_session = SRPServerSession(
context, srp_verifier_hex, private=server_private
)
server_session = SRPServerSession(context, srp_verifier_hex, private=server_private)
_key, _key_proof, _key_proof_hash = server_session.process(client_public, srp_salt_hex)
# srptools verify_proof has a Python 3 bug: hexlify() returns bytes
# but client_proof is str, so bytes == str is always False.
# Compare manually with consistent types.
server_m1 = _key_proof if isinstance(_key_proof, str) else _key_proof.decode('ascii')
server_m1 = _key_proof if isinstance(_key_proof, str) else _key_proof.decode("ascii")
is_valid = client_proof.lower() == server_m1.lower()
if not is_valid:
return False, None
# Return M2 (key_proof_hash), also fixing the bytes/str issue
m2 = _key_proof_hash if isinstance(_key_proof_hash, str) else _key_proof_hash.decode('ascii')
m2 = (
_key_proof_hash if isinstance(_key_proof_hash, str) else _key_proof_hash.decode("ascii")
)
return True, m2
try:

View File

@@ -137,7 +137,9 @@ class SSEConnectionManager:
if last_event_id is not None:
try:
start_seq = int(last_event_id) + 1
consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.BY_START_SEQUENCE, opt_start_seq=start_seq)
consumer_cfg = ConsumerConfig(
deliver_policy=DeliverPolicy.BY_START_SEQUENCE, opt_start_seq=start_seq
)
except (ValueError, TypeError):
consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.NEW)
else:
@@ -173,18 +175,32 @@ class SSEConnectionManager:
except Exception as exc:
if "stream not found" in str(exc):
try:
await js.add_stream(StreamConfig(
name="ALERT_EVENTS",
subjects=_ALERT_EVENT_SUBJECTS,
max_age=3600,
))
sub = await js.subscribe(subject, stream="ALERT_EVENTS", config=consumer_cfg)
await js.add_stream(
StreamConfig(
name="ALERT_EVENTS",
subjects=_ALERT_EVENT_SUBJECTS,
max_age=3600,
)
)
sub = await js.subscribe(
subject, stream="ALERT_EVENTS", config=consumer_cfg
)
self._subscriptions.append(sub)
logger.info("sse.stream_created_lazily", stream="ALERT_EVENTS")
except Exception as retry_exc:
logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(retry_exc))
logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="ALERT_EVENTS",
error=str(retry_exc),
)
else:
logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(exc))
logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="ALERT_EVENTS",
error=str(exc),
)
# Subscribe to operation events (OPERATION_EVENTS stream)
for subject in _OPERATION_EVENT_SUBJECTS:
@@ -198,18 +214,32 @@ class SSEConnectionManager:
except Exception as exc:
if "stream not found" in str(exc):
try:
await js.add_stream(StreamConfig(
name="OPERATION_EVENTS",
subjects=_OPERATION_EVENT_SUBJECTS,
max_age=3600,
))
sub = await js.subscribe(subject, stream="OPERATION_EVENTS", config=consumer_cfg)
await js.add_stream(
StreamConfig(
name="OPERATION_EVENTS",
subjects=_OPERATION_EVENT_SUBJECTS,
max_age=3600,
)
)
sub = await js.subscribe(
subject, stream="OPERATION_EVENTS", config=consumer_cfg
)
self._subscriptions.append(sub)
logger.info("sse.stream_created_lazily", stream="OPERATION_EVENTS")
except Exception as retry_exc:
logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(retry_exc))
logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="OPERATION_EVENTS",
error=str(retry_exc),
)
else:
logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(exc))
logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="OPERATION_EVENTS",
error=str(exc),
)
# Start background task to pull messages from subscriptions into the queue
asyncio.create_task(self._pump_messages())

View File

@@ -12,22 +12,18 @@ separate scheduler and file names to avoid conflicts with restore operations.
"""
import asyncio
import io
import ipaddress
import json
import logging
import uuid
from datetime import datetime, timezone
import asyncssh
from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy import select, text
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.models.config_template import TemplatePushJob
from app.models.device import Device
logger = logging.getLogger(__name__)
@@ -145,7 +141,9 @@ async def push_to_devices(rollout_id: str) -> dict:
except Exception as exc:
logger.error(
"Uncaught exception in template push rollout %s: %s",
rollout_id, exc, exc_info=True,
rollout_id,
exc,
exc_info=True,
)
return {"completed": 0, "failed": 1, "pending": 0}
@@ -181,7 +179,9 @@ async def _run_push_rollout(rollout_id: str) -> dict:
logger.info(
"Template push rollout %s: pushing to device %s (job %s)",
rollout_id, hostname, job_id,
rollout_id,
hostname,
job_id,
)
await push_single_device(job_id)
@@ -200,7 +200,9 @@ async def _run_push_rollout(rollout_id: str) -> dict:
failed = True
logger.error(
"Template push rollout %s paused: device %s %s",
rollout_id, hostname, row[0],
rollout_id,
hostname,
row[0],
)
break
@@ -232,7 +234,9 @@ async def push_single_device(job_id: str) -> None:
except Exception as exc:
logger.error(
"Uncaught exception in template push job %s: %s",
job_id, exc, exc_info=True,
job_id,
exc,
exc_info=True,
)
await _update_job(job_id, status="failed", error_message=f"Unexpected error: {exc}")
@@ -260,8 +264,13 @@ async def _run_single_push(job_id: str) -> None:
return
(
_, device_id, tenant_id, rendered_content,
ip_address, hostname, encrypted_credentials,
_,
device_id,
tenant_id,
rendered_content,
ip_address,
hostname,
encrypted_credentials,
encrypted_credentials_transit,
) = row
@@ -279,16 +288,21 @@ async def _run_single_push(job_id: str) -> None:
try:
from app.services.crypto import decrypt_credentials_hybrid
key = settings.get_encryption_key_bytes()
creds_json = await decrypt_credentials_hybrid(
encrypted_credentials_transit, encrypted_credentials, tenant_id, key,
encrypted_credentials_transit,
encrypted_credentials,
tenant_id,
key,
)
creds = json.loads(creds_json)
ssh_username = creds.get("username", "")
ssh_password = creds.get("password", "")
except Exception as cred_err:
await _update_job(
job_id, status="failed",
job_id,
status="failed",
error_message=f"Failed to decrypt credentials: {cred_err}",
)
return
@@ -297,6 +311,7 @@ async def _run_single_push(job_id: str) -> None:
logger.info("Running mandatory pre-push backup for device %s (%s)", hostname, ip_address)
try:
from app.services import backup_service
backup_result = await backup_service.run_backup(
device_id=device_id,
tenant_id=tenant_id,
@@ -308,7 +323,8 @@ async def _run_single_push(job_id: str) -> None:
except Exception as backup_err:
logger.error("Pre-push backup failed for %s: %s", hostname, backup_err)
await _update_job(
job_id, status="failed",
job_id,
status="failed",
error_message=f"Pre-push backup failed: {backup_err}",
)
return
@@ -316,7 +332,8 @@ async def _run_single_push(job_id: str) -> None:
# Step 5: SSH to device - install panic-revert, push config
logger.info(
"Pushing template to device %s (%s): installing panic-revert and uploading config",
hostname, ip_address,
hostname,
ip_address,
)
try:
@@ -359,7 +376,8 @@ async def _run_single_push(job_id: str) -> None:
)
logger.info(
"Template import result for device %s: exit_status=%s stdout=%r",
hostname, import_result.exit_status,
hostname,
import_result.exit_status,
(import_result.stdout or "")[:200],
)
@@ -369,16 +387,21 @@ async def _run_single_push(job_id: str) -> None:
except Exception as cleanup_err:
logger.warning(
"Failed to clean up %s from device %s: %s",
_TEMPLATE_RSC, ip_address, cleanup_err,
_TEMPLATE_RSC,
ip_address,
cleanup_err,
)
except Exception as push_err:
logger.error(
"SSH push phase failed for device %s (%s): %s",
hostname, ip_address, push_err,
hostname,
ip_address,
push_err,
)
await _update_job(
job_id, status="failed",
job_id,
status="failed",
error_message=f"Config push failed during SSH phase: {push_err}",
)
return
@@ -395,9 +418,12 @@ async def _run_single_push(job_id: str) -> None:
logger.info("Device %s (%s) is reachable after push - committing", hostname, ip_address)
try:
async with asyncssh.connect(
ip_address, port=22,
username=ssh_username, password=ssh_password,
known_hosts=None, connect_timeout=30,
ip_address,
port=22,
username=ssh_username,
password=ssh_password,
known_hosts=None,
connect_timeout=30,
) as conn:
await conn.run(
f'/system scheduler remove "{_PANIC_REVERT_SCHEDULER}"',
@@ -410,11 +436,13 @@ async def _run_single_push(job_id: str) -> None:
except Exception as cleanup_err:
logger.warning(
"Failed to clean up panic-revert scheduler/backup on device %s: %s",
hostname, cleanup_err,
hostname,
cleanup_err,
)
await _update_job(
job_id, status="committed",
job_id,
status="committed",
completed_at=datetime.now(timezone.utc),
)
else:
@@ -422,10 +450,13 @@ async def _run_single_push(job_id: str) -> None:
logger.warning(
"Device %s (%s) is unreachable after push - panic-revert scheduler "
"will auto-revert to %s.backup",
hostname, ip_address, _PRE_PUSH_BACKUP,
hostname,
ip_address,
_PRE_PUSH_BACKUP,
)
await _update_job(
job_id, status="reverted",
job_id,
status="reverted",
error_message="Device unreachable after push; auto-reverted via panic-revert scheduler",
completed_at=datetime.now(timezone.utc),
)
@@ -440,9 +471,12 @@ async def _check_reachability(ip: str, username: str, password: str) -> bool:
"""Check if a RouterOS device is reachable via SSH."""
try:
async with asyncssh.connect(
ip, port=22,
username=username, password=password,
known_hosts=None, connect_timeout=30,
ip,
port=22,
username=username,
password=password,
known_hosts=None,
connect_timeout=30,
) as conn:
result = await conn.run("/system identity print", check=True)
logger.debug("Reachability check OK for %s: %r", ip, result.stdout[:50])
@@ -459,7 +493,12 @@ async def _update_job(job_id: str, **kwargs) -> None:
for key, value in kwargs.items():
param_name = f"v_{key}"
if value is None and key in ("error_message", "started_at", "completed_at", "pre_push_backup_sha"):
if value is None and key in (
"error_message",
"started_at",
"completed_at",
"pre_push_backup_sha",
):
sets.append(f"{key} = NULL")
else:
sets.append(f"{key} = :{param_name}")
@@ -472,7 +511,7 @@ async def _update_job(job_id: str, **kwargs) -> None:
await session.execute(
text(f"""
UPDATE template_push_jobs
SET {', '.join(sets)}
SET {", ".join(sets)}
WHERE id = CAST(:job_id AS uuid)
"""),
params,

View File

@@ -13,7 +13,6 @@ jobs may span multiple tenants and run in background asyncio tasks.
"""
import asyncio
import io
import json
import logging
from datetime import datetime, timezone
@@ -99,10 +98,19 @@ async def _run_upgrade(job_id: str) -> None:
return
(
_, device_id, tenant_id, target_version,
architecture, channel, status, confirmed_major,
ip_address, hostname, encrypted_credentials,
current_version, encrypted_credentials_transit,
_,
device_id,
tenant_id,
target_version,
architecture,
channel,
status,
confirmed_major,
ip_address,
hostname,
encrypted_credentials,
current_version,
encrypted_credentials_transit,
) = row
device_id = str(device_id)
@@ -116,12 +124,22 @@ async def _run_upgrade(job_id: str) -> None:
logger.info(
"Starting firmware upgrade for %s (%s): %s -> %s",
hostname, ip_address, current_version, target_version,
hostname,
ip_address,
current_version,
target_version,
)
# Step 2: Update status to downloading
await _update_job(job_id, status="downloading", started_at=datetime.now(timezone.utc))
await _publish_upgrade_progress(tenant_id, device_id, job_id, "downloading", target_version, f"Downloading firmware {target_version} for {hostname}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"downloading",
target_version,
f"Downloading firmware {target_version} for {hostname}",
)
# Step 3: Check major version upgrade confirmation
if current_version and target_version:
@@ -133,13 +151,22 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message="Major version upgrade requires explicit confirmation",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Major version upgrade requires explicit confirmation for {hostname}", error="Major version upgrade requires explicit confirmation")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Major version upgrade requires explicit confirmation for {hostname}",
error="Major version upgrade requires explicit confirmation",
)
return
# Step 4: Mandatory config backup
logger.info("Running mandatory pre-upgrade backup for %s", hostname)
try:
from app.services import backup_service
backup_result = await backup_service.run_backup(
device_id=device_id,
tenant_id=tenant_id,
@@ -155,13 +182,22 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Pre-upgrade backup failed: {backup_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Pre-upgrade backup failed for {hostname}", error=str(backup_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Pre-upgrade backup failed for {hostname}",
error=str(backup_err),
)
return
# Step 5: Download NPK
logger.info("Downloading firmware %s for %s/%s", target_version, architecture, channel)
try:
from app.services.firmware_service import download_firmware
npk_path = await download_firmware(architecture, channel, target_version)
logger.info("Firmware cached at %s", npk_path)
except Exception as dl_err:
@@ -171,24 +207,51 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Firmware download failed: {dl_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Firmware download failed for {hostname}", error=str(dl_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Firmware download failed for {hostname}",
error=str(dl_err),
)
return
# Step 6: Upload NPK to device via SFTP
await _update_job(job_id, status="uploading")
await _publish_upgrade_progress(tenant_id, device_id, job_id, "uploading", target_version, f"Uploading firmware to {hostname}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"uploading",
target_version,
f"Uploading firmware to {hostname}",
)
# Decrypt device credentials (dual-read: Transit preferred, legacy fallback)
if not encrypted_credentials_transit and not encrypted_credentials:
await _update_job(job_id, status="failed", error_message="Device has no stored credentials")
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"No stored credentials for {hostname}", error="Device has no stored credentials")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"No stored credentials for {hostname}",
error="Device has no stored credentials",
)
return
try:
from app.services.crypto import decrypt_credentials_hybrid
key = settings.get_encryption_key_bytes()
creds_json = await decrypt_credentials_hybrid(
encrypted_credentials_transit, encrypted_credentials, tenant_id, key,
encrypted_credentials_transit,
encrypted_credentials,
tenant_id,
key,
)
creds = json.loads(creds_json)
ssh_username = creds.get("username", "")
@@ -199,7 +262,15 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Failed to decrypt credentials: {cred_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Failed to decrypt credentials for {hostname}", error=str(cred_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Failed to decrypt credentials for {hostname}",
error=str(cred_err),
)
return
try:
@@ -225,12 +296,27 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"NPK upload failed: {upload_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"NPK upload failed for {hostname}", error=str(upload_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"NPK upload failed for {hostname}",
error=str(upload_err),
)
return
# Step 7: Trigger reboot
await _update_job(job_id, status="rebooting")
await _publish_upgrade_progress(tenant_id, device_id, job_id, "rebooting", target_version, f"Rebooting {hostname} for firmware install")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"rebooting",
target_version,
f"Rebooting {hostname} for firmware install",
)
try:
async with asyncssh.connect(
ip_address,
@@ -245,7 +331,9 @@ async def _run_upgrade(job_id: str) -> None:
logger.info("Reboot command sent to %s", hostname)
except Exception as reboot_err:
# Device may drop connection during reboot — this is expected
logger.info("Device %s dropped connection after reboot command (expected): %s", hostname, reboot_err)
logger.info(
"Device %s dropped connection after reboot command (expected): %s", hostname, reboot_err
)
# Step 8: Wait for reconnect
logger.info("Waiting %ds before polling %s for reconnect", _INITIAL_WAIT, hostname)
@@ -267,36 +355,69 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Device did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes after reboot",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Device {hostname} did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes", error="Reconnect timeout")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Device {hostname} did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes",
error="Reconnect timeout",
)
return
# Step 9: Verify upgrade
await _update_job(job_id, status="verifying")
await _publish_upgrade_progress(tenant_id, device_id, job_id, "verifying", target_version, f"Verifying firmware version on {hostname}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"verifying",
target_version,
f"Verifying firmware version on {hostname}",
)
try:
actual_version = await _get_device_version(ip_address, ssh_username, ssh_password)
if actual_version and target_version in actual_version:
logger.info(
"Firmware upgrade verified for %s: %s",
hostname, actual_version,
hostname,
actual_version,
)
await _update_job(
job_id,
status="completed",
completed_at=datetime.now(timezone.utc),
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "completed", target_version, f"Firmware upgrade to {target_version} completed on {hostname}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"completed",
target_version,
f"Firmware upgrade to {target_version} completed on {hostname}",
)
else:
logger.error(
"Version mismatch for %s: expected %s, got %s",
hostname, target_version, actual_version,
hostname,
target_version,
actual_version,
)
await _update_job(
job_id,
status="failed",
error_message=f"Expected {target_version} but got {actual_version}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Version mismatch on {hostname}: expected {target_version}, got {actual_version}", error=f"Expected {target_version} but got {actual_version}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Version mismatch on {hostname}: expected {target_version}, got {actual_version}",
error=f"Expected {target_version} but got {actual_version}",
)
except Exception as verify_err:
logger.error("Post-upgrade verification failed for %s: %s", hostname, verify_err)
await _update_job(
@@ -304,7 +425,15 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Post-upgrade verification failed: {verify_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Post-upgrade verification failed for {hostname}", error=str(verify_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Post-upgrade verification failed for {hostname}",
error=str(verify_err),
)
async def start_mass_upgrade(rollout_group_id: str) -> dict:
@@ -457,7 +586,7 @@ async def resume_mass_upgrade(rollout_group_id: str) -> None:
"""Resume a paused mass rollout from where it left off."""
# Reset first paused job to pending, then restart sequential processing
async with AdminAsyncSessionLocal() as session:
result = await session.execute(
await session.execute(
text("""
UPDATE firmware_upgrade_jobs
SET status = 'pending'
@@ -519,7 +648,7 @@ async def _update_job(job_id: str, **kwargs) -> None:
await session.execute(
text(f"""
UPDATE firmware_upgrade_jobs
SET {', '.join(sets)}
SET {", ".join(sets)}
WHERE id = CAST(:job_id AS uuid)
"""),
params,

View File

@@ -26,7 +26,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.models.device import Device
from app.models.vpn import VpnConfig, VpnPeer
from app.services.crypto import decrypt_credentials, encrypt_credentials, encrypt_credentials_transit
from app.services.crypto import (
decrypt_credentials,
encrypt_credentials,
encrypt_credentials_transit,
)
logger = structlog.get_logger(__name__)
@@ -62,7 +66,9 @@ async def _get_or_create_global_server_key(db: AsyncSession) -> tuple[str, str]:
await db.execute(sa_text("SELECT pg_advisory_xact_lock(hashtext('vpn_server_keygen'))"))
result = await db.execute(
sa_text("SELECT key, value, encrypted_value FROM system_settings WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')")
sa_text(
"SELECT key, value, encrypted_value FROM system_settings WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')"
)
)
rows = {row[0]: row for row in result.fetchall()}
@@ -188,7 +194,9 @@ async def sync_wireguard_config() -> None:
# Query ALL enabled VPN configs (admin session bypasses RLS)
configs_result = await admin_db.execute(
select(VpnConfig).where(VpnConfig.is_enabled.is_(True)).order_by(VpnConfig.subnet_index)
select(VpnConfig)
.where(VpnConfig.is_enabled.is_(True))
.order_by(VpnConfig.subnet_index)
)
configs = configs_result.scalars().all()
@@ -228,7 +236,9 @@ async def sync_wireguard_config() -> None:
peer_ip = peer.assigned_ip.split("/")[0]
allowed_ips = [f"{peer_ip}/32"]
if peer.additional_allowed_ips:
extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()]
extra = [
s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()
]
allowed_ips.extend(extra)
lines.append("[Peer]")
lines.append(f"PublicKey = {peer.peer_public_key}")
@@ -253,12 +263,13 @@ async def sync_wireguard_config() -> None:
# Docker traffic (172.16.0.0/12) going to each tenant's subnet
# gets SNATted to that tenant's gateway IP (.1) so the router
# can route replies back through the tunnel.
nat_lines = ["#!/bin/sh",
"# Auto-generated per-tenant SNAT rules",
"# Remove old rules",
"iptables -t nat -F POSTROUTING 2>/dev/null",
"# Re-add Docker DNS rules",
]
nat_lines = [
"#!/bin/sh",
"# Auto-generated per-tenant SNAT rules",
"# Remove old rules",
"iptables -t nat -F POSTROUTING 2>/dev/null",
"# Re-add Docker DNS rules",
]
for config in configs:
gateway_ip = config.server_address.split("/")[0] # e.g. 10.10.3.1
subnet = config.subnet # e.g. 10.10.3.0/24
@@ -275,12 +286,15 @@ async def sync_wireguard_config() -> None:
reload_flag = wg_confs_dir / ".reload"
reload_flag.write_text("1")
logger.info("wireguard_config_synced", audit=True,
tenants=len(configs), peers=total_peers)
logger.info(
"wireguard_config_synced", audit=True, tenants=len(configs), peers=total_peers
)
finally:
# Release advisory lock explicitly (session-level lock, not xact-level)
await admin_db.execute(sa_text("SELECT pg_advisory_unlock(hashtext('wireguard_config'))"))
await admin_db.execute(
sa_text("SELECT pg_advisory_unlock(hashtext('wireguard_config'))")
)
# ── Live Status ──
@@ -371,15 +385,23 @@ async def setup_vpn(
db.add(config)
await db.flush()
logger.info("vpn_subnet_allocated", audit=True,
tenant_id=str(tenant_id), subnet_index=subnet_index, subnet=subnet)
logger.info(
"vpn_subnet_allocated",
audit=True,
tenant_id=str(tenant_id),
subnet_index=subnet_index,
subnet=subnet,
)
await _commit_and_sync(db)
return config
async def update_vpn_config(
db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None, is_enabled: Optional[bool] = None
db: AsyncSession,
tenant_id: uuid.UUID,
endpoint: Optional[str] = None,
is_enabled: Optional[bool] = None,
) -> VpnConfig:
"""Update VPN config settings."""
config = await get_vpn_config(db, tenant_id)
@@ -422,14 +444,21 @@ async def _next_available_ip(db: AsyncSession, tenant_id: uuid.UUID, config: Vpn
raise ValueError("No available IPs in VPN subnet")
async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID, additional_allowed_ips: Optional[str] = None) -> VpnPeer:
async def add_peer(
db: AsyncSession,
tenant_id: uuid.UUID,
device_id: uuid.UUID,
additional_allowed_ips: Optional[str] = None,
) -> VpnPeer:
"""Add a device as a VPN peer."""
config = await get_vpn_config(db, tenant_id)
if not config:
raise ValueError("VPN not configured — enable VPN first")
# Check device exists
device = await db.execute(select(Device).where(Device.id == device_id, Device.tenant_id == tenant_id))
device = await db.execute(
select(Device).where(Device.id == device_id, Device.tenant_id == tenant_id)
)
if not device.scalar_one_or_none():
raise ValueError("Device not found")
@@ -497,13 +526,12 @@ async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.
psk = decrypt_credentials(peer.preshared_key, key_bytes) if peer.preshared_key else None
endpoint = config.endpoint or "YOUR_SERVER_IP:51820"
peer_ip_no_cidr = peer.assigned_ip.split("/")[0]
routeros_commands = [
f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key}"',
f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" '
f'endpoint-address={endpoint.split(":")[0]} endpoint-port={endpoint.split(":")[-1]} '
f'allowed-address=10.10.0.0/16 persistent-keepalive=25'
f"endpoint-address={endpoint.split(':')[0]} endpoint-port={endpoint.split(':')[-1]} "
f"allowed-address=10.10.0.0/16 persistent-keepalive=25"
+ (f' preshared-key="{psk}"' if psk else ""),
f"/ip address add address={peer.assigned_ip} interface=wg-portal",
]
@@ -583,8 +611,8 @@ async def onboard_device(
routeros_commands = [
f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key_b64}"',
f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" '
f'endpoint-address={endpoint.split(":")[0]} endpoint-port={endpoint.split(":")[-1]} '
f'allowed-address=10.10.0.0/16 persistent-keepalive=25'
f"endpoint-address={endpoint.split(':')[0]} endpoint-port={endpoint.split(':')[-1]} "
f"allowed-address=10.10.0.0/16 persistent-keepalive=25"
f' preshared-key="{psk_decrypted}"',
f"/ip address add address={assigned_ip} interface=wg-portal",
]