feat: The Other Dude v9.0.1 — full-featured email system
ci: add GitHub Pages deployment workflow for docs site Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
1
backend/app/services/__init__.py
Normal file
1
backend/app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Backend services — auth, crypto, and business logic."""
|
||||
240
backend/app/services/account_service.py
Normal file
240
backend/app/services/account_service.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Account self-service operations: deletion and data export.
|
||||
|
||||
Provides GDPR/CCPA-compliant account deletion with full PII erasure
|
||||
and data portability export (Article 20).
|
||||
|
||||
All queries use raw SQL via text() with admin sessions (bypass RLS)
|
||||
since these are cross-table operations on the authenticated user's data.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
logger = structlog.get_logger("account_service")
|
||||
|
||||
|
||||
async def delete_user_account(
    db: AsyncSession,
    user_id: uuid.UUID,
    tenant_id: uuid.UUID | None,
    user_email: str,
) -> dict[str, Any]:
    """Hard-delete a user account with full PII erasure.

    Steps:
        1. Create a deletion receipt audit log (persisted via separate session)
        2. Anonymize PII in existing audit_logs for this user
        3. Hard-delete the user row (CASCADE handles related tables)
        4. Best-effort session invalidation via Redis

    Steps 2 and 3 share the caller's session and are committed together,
    so PII anonymization and the row deletion land atomically.

    Args:
        db: Admin async session (bypasses RLS).
        user_id: UUID of the user to delete.
        tenant_id: Tenant UUID (None for super_admin).
        user_email: User's email (needed for audit hash before deletion).

    Returns:
        Dict with deleted=True and user_id on success.
    """
    # Nil UUID stands in for super_admin accounts that belong to no tenant.
    effective_tenant_id = tenant_id or uuid.UUID(int=0)
    # Non-reversible reference to the deleted account for the audit receipt.
    email_hash = hashlib.sha256(user_email.encode()).hexdigest()

    # ── 1. Pre-deletion audit receipt (separate session so it persists) ────
    # A dedicated session ensures the receipt survives independently of the
    # caller's transaction; failure here is logged but never blocks deletion.
    try:
        async with AdminAsyncSessionLocal() as audit_db:
            await log_action(
                audit_db,
                tenant_id=effective_tenant_id,
                user_id=user_id,
                action="account_deleted",
                resource_type="user",
                resource_id=str(user_id),
                details={
                    "deleted_user_id": str(user_id),
                    "email_hash": email_hash,
                    "deletion_type": "self_service",
                    "deleted_at": datetime.now(UTC).isoformat(),
                },
            )
            await audit_db.commit()
    except Exception:
        # Deliberate best-effort: proceed with deletion even if the receipt
        # could not be written.
        logger.warning(
            "deletion_receipt_failed",
            user_id=str(user_id),
            exc_info=True,
        )

    # ── 2. Anonymize PII in audit_logs for this user ─────────────────────
    # Strip PII keys from details JSONB (email, name, user_email, user_name)
    await db.execute(
        text(
            "UPDATE audit_logs "
            "SET details = details - 'email' - 'name' - 'user_email' - 'user_name' "
            "WHERE user_id = :user_id"
        ),
        {"user_id": user_id},
    )

    # Null out encrypted_details (may contain encrypted PII)
    await db.execute(
        text(
            "UPDATE audit_logs "
            "SET encrypted_details = NULL "
            "WHERE user_id = :user_id"
        ),
        {"user_id": user_id},
    )

    # ── 3. Hard delete user row ──────────────────────────────────────────
    # CASCADE handles: user_key_sets, api_keys, password_reset_tokens
    # SET NULL handles: audit_logs.user_id, key_access_log.user_id,
    # maintenance_windows.created_by, alert_events.acknowledged_by
    await db.execute(
        text("DELETE FROM users WHERE id = :user_id"),
        {"user_id": user_id},
    )

    await db.commit()

    # ── 4. Best-effort Redis session invalidation ────────────────────────
    # Imports are local so the DB deletion above never depends on Redis
    # availability or import-time side effects.
    try:
        import redis.asyncio as aioredis
        from app.config import settings
        from app.services.auth import revoke_user_tokens

        r = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
        await revoke_user_tokens(r, str(user_id))
        await r.aclose()
    except Exception:
        # JWT expires in 15 min anyway; not critical
        logger.debug("redis_session_invalidation_skipped", user_id=str(user_id))

    logger.info("account_deleted", user_id=str(user_id), email_hash=email_hash)

    return {"deleted": True, "user_id": str(user_id)}
|
||||
|
||||
|
||||
async def export_user_data(
    db: AsyncSession,
    user_id: uuid.UUID,
    tenant_id: uuid.UUID | None,
) -> dict[str, Any]:
    """Assemble all user data for GDPR Art. 20 data portability export.

    Collects the user's profile row, API key metadata (the key_hash column is
    deliberately never selected), and up to 1000 of the most recent audit-log
    and key-access-log entries each.

    Args:
        db: Admin async session (bypasses RLS).
        user_id: UUID of the user whose data to export.
        tenant_id: Tenant UUID (None for super_admin).

    Returns:
        Envelope dict with export_date, format_version, and all user data.
    """

    def _iso(stamp: Any) -> str | None:
        # Render a datetime column as ISO-8601; falsy values pass through as None.
        return stamp.isoformat() if stamp else None

    params = {"user_id": user_id}

    # ── User profile ─────────────────────────────────────────────────────
    profile_result = await db.execute(
        text(
            "SELECT id, email, name, role, tenant_id, "
            "created_at, last_login, auth_version "
            "FROM users WHERE id = :user_id"
        ),
        params,
    )
    profile = profile_result.mappings().first()
    user_data: dict[str, Any] = {}
    if profile:
        user_data = {
            "id": str(profile["id"]),
            "email": profile["email"],
            "name": profile["name"],
            "role": profile["role"],
            "tenant_id": str(profile["tenant_id"]) if profile["tenant_id"] else None,
            "created_at": _iso(profile["created_at"]),
            "last_login": _iso(profile["last_login"]),
            "auth_version": profile["auth_version"],
        }

    # ── API keys (exclude key_hash for security) ─────────────────────────
    key_result = await db.execute(
        text(
            "SELECT id, name, key_prefix, scopes, created_at, "
            "expires_at, revoked_at, last_used_at "
            "FROM api_keys WHERE user_id = :user_id "
            "ORDER BY created_at DESC"
        ),
        params,
    )
    api_keys = [
        {
            "id": str(entry["id"]),
            "name": entry["name"],
            "key_prefix": entry["key_prefix"],
            "scopes": entry["scopes"],
            "created_at": _iso(entry["created_at"]),
            "expires_at": _iso(entry["expires_at"]),
            "revoked_at": _iso(entry["revoked_at"]),
            "last_used_at": _iso(entry["last_used_at"]),
        }
        for entry in key_result.mappings().all()
    ]

    # ── Audit logs (limit 1000, most recent first) ───────────────────────
    audit_result = await db.execute(
        text(
            "SELECT id, action, resource_type, resource_id, "
            "details, ip_address, created_at "
            "FROM audit_logs WHERE user_id = :user_id "
            "ORDER BY created_at DESC LIMIT 1000"
        ),
        params,
    )
    audit_logs = [
        {
            "id": str(entry["id"]),
            "action": entry["action"],
            "resource_type": entry["resource_type"],
            "resource_id": entry["resource_id"],
            # NULL details normalize to an empty dict for the export.
            "details": entry["details"] or {},
            "ip_address": entry["ip_address"],
            "created_at": _iso(entry["created_at"]),
        }
        for entry in audit_result.mappings().all()
    ]

    # ── Key access log (limit 1000, most recent first) ───────────────────
    access_result = await db.execute(
        text(
            "SELECT id, action, resource_type, ip_address, created_at "
            "FROM key_access_log WHERE user_id = :user_id "
            "ORDER BY created_at DESC LIMIT 1000"
        ),
        params,
    )
    key_access_entries = [
        {
            "id": str(entry["id"]),
            "action": entry["action"],
            "resource_type": entry["resource_type"],
            "ip_address": entry["ip_address"],
            "created_at": _iso(entry["created_at"]),
        }
        for entry in access_result.mappings().all()
    ]

    return {
        "export_date": datetime.now(UTC).isoformat(),
        "format_version": "1.0",
        "user": user_data,
        "api_keys": api_keys,
        "audit_logs": audit_logs,
        "key_access_log": key_access_entries,
    }
|
||||
723
backend/app/services/alert_evaluator.py
Normal file
723
backend/app/services/alert_evaluator.py
Normal file
@@ -0,0 +1,723 @@
|
||||
"""Alert rule evaluation engine with Redis breach counters and flap detection.
|
||||
|
||||
Entry points:
|
||||
- evaluate(device_id, tenant_id, metric_type, data): called from metrics_subscriber
|
||||
- evaluate_offline(device_id, tenant_id): called from nats_subscriber on device offline
|
||||
- evaluate_online(device_id, tenant_id): called from nats_subscriber on device online
|
||||
|
||||
Uses Redis for:
|
||||
- Consecutive breach counting (alert:breach:{device_id}:{rule_id})
|
||||
- Flap detection (alert:flap:{device_id}:{rule_id} sorted set)
|
||||
|
||||
Uses AdminAsyncSessionLocal for all DB operations (runs cross-tenant in NATS handlers).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.services.event_publisher import publish_event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level Redis client, lazily initialized
|
||||
_redis_client: aioredis.Redis | None = None
|
||||
|
||||
# Module-level rule cache: {tenant_id: (rules_list, fetched_at_timestamp)}
|
||||
_rule_cache: dict[str, tuple[list[dict], float]] = {}
|
||||
_CACHE_TTL_SECONDS = 60
|
||||
|
||||
# Module-level maintenance window cache: {tenant_id: (active_windows_list, fetched_at_timestamp)}
|
||||
# Each window: {"device_ids": [...], "suppress_alerts": True}
|
||||
_maintenance_cache: dict[str, tuple[list[dict], float]] = {}
|
||||
_MAINTENANCE_CACHE_TTL = 30 # 30 seconds
|
||||
|
||||
|
||||
async def _get_redis() -> aioredis.Redis:
    """Return the shared module-level Redis client, creating it on first use."""
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    _redis_client = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
    return _redis_client
|
||||
|
||||
|
||||
async def _get_active_maintenance_windows(tenant_id: str) -> list[dict]:
    """Fetch active maintenance windows for a tenant, with 30s cache.

    Only windows with suppress_alerts = true whose [start_at, end_at] span
    covers NOW() are returned. Results are memoised per tenant in the
    module-level ``_maintenance_cache`` as a (windows, fetched_at) tuple for
    ``_MAINTENANCE_CACHE_TTL`` seconds, so bursts of metric events do not
    hammer the database.

    Args:
        tenant_id: Tenant UUID as a string.

    Returns:
        List of {"device_ids": [...], "suppress_alerts": bool} dicts. An empty
        device_ids list means the window covers all devices in the tenant
        (interpreted by _is_device_in_maintenance).
    """
    now = time.time()
    cached = _maintenance_cache.get(tenant_id)
    if cached and (now - cached[1]) < _MAINTENANCE_CACHE_TTL:
        # Cache hit: entry is still within its TTL.
        return cached[0]

    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT device_ids, suppress_alerts
                FROM maintenance_windows
                WHERE tenant_id = CAST(:tenant_id AS uuid)
                AND suppress_alerts = true
                AND start_at <= NOW()
                AND end_at >= NOW()
            """),
            {"tenant_id": tenant_id},
        )
        rows = result.fetchall()

    windows = [
        {
            # Defensive: coerce non-list device_ids values to "all devices".
            "device_ids": row[0] if isinstance(row[0], list) else [],
            "suppress_alerts": row[1],
        }
        for row in rows
    ]

    _maintenance_cache[tenant_id] = (windows, now)
    return windows
|
||||
|
||||
|
||||
async def _is_device_in_maintenance(tenant_id: str, device_id: str) -> bool:
    """Check if a device is currently under active maintenance with alert suppression.

    Returns True if there is at least one active maintenance window covering
    this device (or all devices via empty device_ids array).
    """
    active_windows = await _get_active_maintenance_windows(tenant_id)
    # An empty device_ids list means the window applies tenant-wide.
    return any(
        not window["device_ids"] or device_id in window["device_ids"]
        for window in active_windows
    )
|
||||
|
||||
|
||||
async def _get_rules_for_tenant(tenant_id: str) -> list[dict]:
    """Fetch active alert rules for a tenant, with 60s cache.

    Only rules with enabled = TRUE are returned. Results are memoised per
    tenant in the module-level ``_rule_cache`` as a (rules, fetched_at) tuple
    for ``_CACHE_TTL_SECONDS`` seconds, so rule changes may take up to a
    minute to affect evaluation.

    Args:
        tenant_id: Tenant UUID as a string.

    Returns:
        List of rule dicts with UUID columns stringified and threshold cast
        to float; device_id/group_id are None for tenant-wide rules.
    """
    now = time.time()
    cached = _rule_cache.get(tenant_id)
    if cached and (now - cached[1]) < _CACHE_TTL_SECONDS:
        # Cache hit: entry is still within its TTL.
        return cached[0]

    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, tenant_id, device_id, group_id, name, metric,
                operator, threshold, duration_polls, severity
                FROM alert_rules
                WHERE tenant_id = CAST(:tenant_id AS uuid) AND enabled = TRUE
            """),
            {"tenant_id": tenant_id},
        )
        rows = result.fetchall()

    # Positional mapping follows the SELECT column order above.
    rules = [
        {
            "id": str(row[0]),
            "tenant_id": str(row[1]),
            "device_id": str(row[2]) if row[2] else None,
            "group_id": str(row[3]) if row[3] else None,
            "name": row[4],
            "metric": row[5],
            "operator": row[6],
            "threshold": float(row[7]),
            "duration_polls": row[8],
            "severity": row[9],
        }
        for row in rows
    ]

    _rule_cache[tenant_id] = (rules, now)
    return rules
|
||||
|
||||
|
||||
def _check_threshold(value: float, operator: str, threshold: float) -> bool:
|
||||
"""Check if a metric value breaches a threshold."""
|
||||
if operator == "gt":
|
||||
return value > threshold
|
||||
elif operator == "lt":
|
||||
return value < threshold
|
||||
elif operator == "gte":
|
||||
return value >= threshold
|
||||
elif operator == "lte":
|
||||
return value <= threshold
|
||||
return False
|
||||
|
||||
|
||||
def _extract_metrics(metric_type: str, data: dict) -> dict[str, float]:
|
||||
"""Extract metric name->value pairs from a NATS metrics event."""
|
||||
metrics: dict[str, float] = {}
|
||||
|
||||
if metric_type == "health":
|
||||
health = data.get("health", {})
|
||||
for key in ("cpu_load", "temperature"):
|
||||
val = health.get(key)
|
||||
if val is not None and val != "":
|
||||
try:
|
||||
metrics[key] = float(val)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
# Compute memory_used_pct and disk_used_pct
|
||||
free_mem = health.get("free_memory")
|
||||
total_mem = health.get("total_memory")
|
||||
if free_mem is not None and total_mem is not None:
|
||||
try:
|
||||
total = float(total_mem)
|
||||
free = float(free_mem)
|
||||
if total > 0:
|
||||
metrics["memory_used_pct"] = round((1.0 - free / total) * 100, 1)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
free_disk = health.get("free_disk")
|
||||
total_disk = health.get("total_disk")
|
||||
if free_disk is not None and total_disk is not None:
|
||||
try:
|
||||
total = float(total_disk)
|
||||
free = float(free_disk)
|
||||
if total > 0:
|
||||
metrics["disk_used_pct"] = round((1.0 - free / total) * 100, 1)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
elif metric_type == "wireless":
|
||||
wireless = data.get("wireless", [])
|
||||
# Aggregate: use worst signal, lowest CCQ, sum client_count
|
||||
for wif in wireless:
|
||||
for key in ("signal_strength", "ccq", "client_count"):
|
||||
val = wif.get(key) if key != "avg_signal" else wif.get("avg_signal")
|
||||
if key == "signal_strength":
|
||||
val = wif.get("avg_signal")
|
||||
if val is not None and val != "":
|
||||
try:
|
||||
fval = float(val)
|
||||
if key not in metrics:
|
||||
metrics[key] = fval
|
||||
elif key == "signal_strength":
|
||||
metrics[key] = min(metrics[key], fval) # worst signal
|
||||
elif key == "ccq":
|
||||
metrics[key] = min(metrics[key], fval) # worst CCQ
|
||||
elif key == "client_count":
|
||||
metrics[key] = metrics.get(key, 0) + fval # sum
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# TODO: Interface bandwidth alerting (rx_bps/tx_bps) requires stateful delta
|
||||
# computation between consecutive poll values. Deferred for now — the alert_rules
|
||||
# table supports these metric types, but evaluation is skipped.
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
async def _increment_breach(
|
||||
r: aioredis.Redis, device_id: str, rule_id: str, required_polls: int
|
||||
) -> bool:
|
||||
"""Increment breach counter in Redis. Returns True when threshold duration reached."""
|
||||
key = f"alert:breach:{device_id}:{rule_id}"
|
||||
count = await r.incr(key)
|
||||
# Set TTL to (required_polls + 2) * 60 seconds so it expires if breaches stop
|
||||
await r.expire(key, (required_polls + 2) * 60)
|
||||
return count >= required_polls
|
||||
|
||||
|
||||
async def _reset_breach(r: aioredis.Redis, device_id: str, rule_id: str) -> None:
|
||||
"""Reset breach counter when metric returns to normal."""
|
||||
key = f"alert:breach:{device_id}:{rule_id}"
|
||||
await r.delete(key)
|
||||
|
||||
|
||||
async def _check_flapping(r: aioredis.Redis, device_id: str, rule_id: str) -> bool:
|
||||
"""Check if alert is flapping (>= 5 state transitions in 10 minutes).
|
||||
|
||||
Uses a Redis sorted set with timestamps as scores.
|
||||
"""
|
||||
key = f"alert:flap:{device_id}:{rule_id}"
|
||||
now = time.time()
|
||||
window_start = now - 600 # 10 minute window
|
||||
|
||||
# Add this transition
|
||||
await r.zadd(key, {str(now): now})
|
||||
# Remove entries outside the window
|
||||
await r.zremrangebyscore(key, "-inf", window_start)
|
||||
# Set TTL on the key
|
||||
await r.expire(key, 1200)
|
||||
# Count transitions in window
|
||||
count = await r.zcard(key)
|
||||
return count >= 5
|
||||
|
||||
|
||||
async def _get_device_groups(device_id: str) -> list[str]:
    """Return the group IDs (stringified) that this device belongs to."""
    async with AdminAsyncSessionLocal() as session:
        lookup = await session.execute(
            text("SELECT group_id FROM device_group_memberships WHERE device_id = CAST(:device_id AS uuid)"),
            {"device_id": device_id},
        )
        return [str(group_id) for (group_id,) in lookup.fetchall()]
|
||||
|
||||
|
||||
async def _has_open_alert(device_id: str, rule_id: str | None, metric: str | None = None) -> bool:
    """Check if there's an open (firing, unresolved) alert for this device+rule.

    Args:
        device_id: Device UUID string.
        rule_id: Rule UUID string, or None for system-level alerts that have
            no backing rule (e.g. legacy offline alerts).
        metric: Metric name used to match when rule_id is None; defaults to
            "offline".

    Returns:
        True when an alert row with status = 'firing' and resolved_at IS NULL
        exists for the given key.
    """
    async with AdminAsyncSessionLocal() as session:
        if rule_id:
            # Rule-backed alerts are keyed by (device, rule).
            result = await session.execute(
                text("""
                    SELECT 1 FROM alert_events
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id = CAST(:rule_id AS uuid)
                    AND status = 'firing' AND resolved_at IS NULL
                    LIMIT 1
                """),
                {"device_id": device_id, "rule_id": rule_id},
            )
        else:
            # System-level alerts (rule_id IS NULL) are keyed by (device, metric).
            result = await session.execute(
                text("""
                    SELECT 1 FROM alert_events
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id IS NULL
                    AND metric = :metric AND status = 'firing' AND resolved_at IS NULL
                    LIMIT 1
                """),
                {"device_id": device_id, "metric": metric or "offline"},
            )
        return result.fetchone() is not None
|
||||
|
||||
|
||||
async def _create_alert_event(
    device_id: str,
    tenant_id: str,
    rule_id: str | None,
    status: str,
    severity: str,
    metric: str | None,
    value: float | None,
    threshold: float | None,
    message: str | None,
    is_flapping: bool = False,
) -> dict:
    """Create an alert event row and return its data.

    Inserts into alert_events (resolved_at is populated in the same INSERT
    when status is 'resolved'), commits, then publishes a real-time NATS
    event for the SSE pipeline: alert.fired.{tenant_id} for
    'firing'/'flapping' rows, alert.resolved.{tenant_id} for 'resolved' rows.

    Args:
        device_id: Device UUID string.
        tenant_id: Tenant UUID string.
        rule_id: Backing rule UUID string, or None for system-level alerts.
        status: 'firing', 'flapping', or 'resolved'.
        severity: Severity label (e.g. 'critical').
        metric: Metric name the alert concerns, if any.
        value: Observed metric value, if any.
        threshold: Rule threshold, if any.
        message: Human-readable alert message.
        is_flapping: Whether flap detection marked this alert.

    Returns:
        Dict describing the new alert event; "id" is None if the INSERT
        returned no row.
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                INSERT INTO alert_events
                (id, device_id, tenant_id, rule_id, status, severity, metric,
                value, threshold, message, is_flapping, fired_at,
                resolved_at)
                VALUES
                (gen_random_uuid(), CAST(:device_id AS uuid), CAST(:tenant_id AS uuid),
                :rule_id, :status, :severity, :metric,
                :value, :threshold, :message, :is_flapping, NOW(),
                CASE WHEN :status = 'resolved' THEN NOW() ELSE NULL END)
                RETURNING id, fired_at
            """),
            {
                "device_id": device_id,
                "tenant_id": tenant_id,
                "rule_id": rule_id,
                "status": status,
                "severity": severity,
                "metric": metric,
                "value": value,
                "threshold": threshold,
                "message": message,
                "is_flapping": is_flapping,
            },
        )
        row = result.fetchone()
        await session.commit()

    alert_data = {
        "id": str(row[0]) if row else None,
        "device_id": device_id,
        "tenant_id": tenant_id,
        "rule_id": rule_id,
        "status": status,
        "severity": severity,
        "metric": metric,
        "value": value,
        "threshold": threshold,
        "message": message,
        "is_flapping": is_flapping,
    }

    # Publish real-time event to NATS for SSE pipeline (fire-and-forget)
    if status in ("firing", "flapping"):
        await publish_event(f"alert.fired.{tenant_id}", {
            "event_type": "alert_fired",
            "tenant_id": tenant_id,
            "device_id": device_id,
            "alert_event_id": alert_data["id"],
            "severity": severity,
            "metric": metric,
            "current_value": value,
            "threshold": threshold,
            "message": message,
            "is_flapping": is_flapping,
            "fired_at": datetime.now(timezone.utc).isoformat(),
        })
    elif status == "resolved":
        await publish_event(f"alert.resolved.{tenant_id}", {
            "event_type": "alert_resolved",
            "tenant_id": tenant_id,
            "device_id": device_id,
            "alert_event_id": alert_data["id"],
            "severity": severity,
            "metric": metric,
            "message": message,
            "resolved_at": datetime.now(timezone.utc).isoformat(),
        })

    return alert_data
|
||||
|
||||
|
||||
async def _resolve_alert(device_id: str, rule_id: str | None, metric: str | None = None) -> None:
    """Resolve an open alert by setting resolved_at.

    Mirrors the keying in _has_open_alert: rule-backed alerts are matched by
    (device, rule); system-level alerts (rule_id IS NULL) by (device, metric),
    defaulting to "offline". Only rows currently in status 'firing' with
    resolved_at IS NULL are updated.
    """
    async with AdminAsyncSessionLocal() as session:
        if rule_id:
            await session.execute(
                text("""
                    UPDATE alert_events SET resolved_at = NOW(), status = 'resolved'
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id = CAST(:rule_id AS uuid)
                    AND status = 'firing' AND resolved_at IS NULL
                """),
                {"device_id": device_id, "rule_id": rule_id},
            )
        else:
            await session.execute(
                text("""
                    UPDATE alert_events SET resolved_at = NOW(), status = 'resolved'
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id IS NULL
                    AND metric = :metric AND status = 'firing' AND resolved_at IS NULL
                """),
                {"device_id": device_id, "metric": metric or "offline"},
            )
        await session.commit()
|
||||
|
||||
|
||||
async def _get_channels_for_tenant(tenant_id: str) -> list[dict]:
    """Get all notification channels for a tenant.

    Used as the fallback when an alert has no rule-linked channels
    (see evaluate_offline).

    Returns:
        List of channel dicts keyed by column name; the positional mapping
        below must stay in sync with the SELECT column order.
        NOTE(review): smtp_password is returned as stored — presumably for
        downstream dispatch; confirm it is never logged or exported.
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, name, channel_type, smtp_host, smtp_port, smtp_user,
                smtp_password, smtp_use_tls, from_address, to_address,
                webhook_url, smtp_password_transit, slack_webhook_url, tenant_id
                FROM notification_channels
                WHERE tenant_id = CAST(:tenant_id AS uuid)
            """),
            {"tenant_id": tenant_id},
        )
        return [
            {
                "id": str(row[0]),
                "name": row[1],
                "channel_type": row[2],
                "smtp_host": row[3],
                "smtp_port": row[4],
                "smtp_user": row[5],
                "smtp_password": row[6],
                "smtp_use_tls": row[7],
                "from_address": row[8],
                "to_address": row[9],
                "webhook_url": row[10],
                "smtp_password_transit": row[11],
                "slack_webhook_url": row[12],
                "tenant_id": str(row[13]) if row[13] else None,
            }
            for row in result.fetchall()
        ]
|
||||
|
||||
|
||||
async def _get_channels_for_rule(rule_id: str) -> list[dict]:
    """Get notification channels linked to a specific alert rule.

    Joins through the alert_rule_channels association table; returns the same
    channel dict shape as _get_channels_for_tenant (the positional mapping
    below must stay in sync with the SELECT column order).
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT nc.id, nc.name, nc.channel_type, nc.smtp_host, nc.smtp_port,
                nc.smtp_user, nc.smtp_password, nc.smtp_use_tls,
                nc.from_address, nc.to_address, nc.webhook_url,
                nc.smtp_password_transit, nc.slack_webhook_url, nc.tenant_id
                FROM notification_channels nc
                JOIN alert_rule_channels arc ON arc.channel_id = nc.id
                WHERE arc.rule_id = CAST(:rule_id AS uuid)
            """),
            {"rule_id": rule_id},
        )
        return [
            {
                "id": str(row[0]),
                "name": row[1],
                "channel_type": row[2],
                "smtp_host": row[3],
                "smtp_port": row[4],
                "smtp_user": row[5],
                "smtp_password": row[6],
                "smtp_use_tls": row[7],
                "from_address": row[8],
                "to_address": row[9],
                "webhook_url": row[10],
                "smtp_password_transit": row[11],
                "slack_webhook_url": row[12],
                "tenant_id": str(row[13]) if row[13] else None,
            }
            for row in result.fetchall()
        ]
|
||||
|
||||
|
||||
async def _dispatch_async(alert_event: dict, channels: list[dict], device_hostname: str) -> None:
    """Best-effort notification fan-out; failures are logged, never raised."""
    try:
        # Imported lazily to avoid a circular import at module load time.
        from app.services.notification_service import dispatch_notifications

        await dispatch_notifications(alert_event, channels, device_hostname)
    except Exception as exc:
        logger.warning("Notification dispatch failed: %s", exc)
|
||||
|
||||
|
||||
async def _get_device_hostname(device_id: str) -> str:
    """Resolve a device's hostname for notifications, falling back to its ID."""
    async with AdminAsyncSessionLocal() as session:
        lookup = await session.execute(
            text("SELECT hostname FROM devices WHERE id = CAST(:device_id AS uuid)"),
            {"device_id": device_id},
        )
        match = lookup.fetchone()
    return match[0] if match else device_id
|
||||
|
||||
|
||||
async def evaluate(
    device_id: str,
    tenant_id: str,
    metric_type: str,
    data: dict[str, Any],
) -> None:
    """Evaluate alert rules for incoming device metrics.

    Called from metrics_subscriber after metric DB write.

    Flow per applicable rule:
      - breaching: increment the Redis breach counter; once the configured
        number of consecutive polls is reached (and no alert is already open),
        create a firing/flapping alert event and dispatch notifications
        (suppressed while flapping).
      - not breaching: reset the counter and, if an alert is open, resolve it
        and emit a resolution event (notifications again suppressed while
        flapping).

    Rule precedence: a device-specific rule on a metric shadows tenant-wide
    rules for that same metric; group rules apply to the device's groups.

    Args:
        device_id: Device UUID string.
        tenant_id: Tenant UUID string.
        metric_type: NATS event kind ("health", "wireless", ...).
        data: Raw metrics payload for _extract_metrics.
    """
    # Check maintenance window suppression before evaluating rules
    if await _is_device_in_maintenance(tenant_id, device_id):
        logger.debug(
            "Alert suppressed by maintenance window for device %s tenant %s",
            device_id, tenant_id,
        )
        return

    rules = await _get_rules_for_tenant(tenant_id)
    if not rules:
        return

    metrics = _extract_metrics(metric_type, data)
    if not metrics:
        return

    r = await _get_redis()
    device_groups = await _get_device_groups(device_id)

    # Build a set of metrics that have device-specific rules
    # (used below so tenant-wide rules defer to device-specific ones).
    device_specific_metrics: set[str] = set()
    for rule in rules:
        if rule["device_id"] == device_id:
            device_specific_metrics.add(rule["metric"])

    for rule in rules:
        rule_metric = rule["metric"]
        if rule_metric not in metrics:
            continue

        # Check if rule applies to this device
        applies = False
        if rule["device_id"] == device_id:
            applies = True
        elif rule["device_id"] is None and rule["group_id"] is None:
            # Tenant-wide rule — skip if device-specific rule exists for same metric
            if rule_metric in device_specific_metrics:
                continue
            applies = True
        elif rule["group_id"] and rule["group_id"] in device_groups:
            applies = True

        if not applies:
            continue

        value = metrics[rule_metric]
        breaching = _check_threshold(value, rule["operator"], rule["threshold"])

        if breaching:
            reached = await _increment_breach(r, device_id, rule["id"], rule["duration_polls"])
            if reached:
                # Check if already firing — avoid duplicate open alerts.
                if await _has_open_alert(device_id, rule["id"]):
                    continue

                # Check flapping (counts this transition toward the flap window).
                is_flapping = await _check_flapping(r, device_id, rule["id"])

                hostname = await _get_device_hostname(device_id)
                message = f"{rule['name']}: {rule_metric} = {value} (threshold: {rule['operator']} {rule['threshold']})"

                alert_event = await _create_alert_event(
                    device_id=device_id,
                    tenant_id=tenant_id,
                    rule_id=rule["id"],
                    status="flapping" if is_flapping else "firing",
                    severity=rule["severity"],
                    metric=rule_metric,
                    value=value,
                    threshold=rule["threshold"],
                    message=message,
                    is_flapping=is_flapping,
                )

                if is_flapping:
                    logger.info(
                        "Alert %s for device %s is flapping — notifications suppressed",
                        rule["name"], device_id,
                    )
                else:
                    channels = await _get_channels_for_rule(rule["id"])
                    if channels:
                        # Fire-and-forget so dispatch latency never blocks evaluation.
                        asyncio.create_task(_dispatch_async(alert_event, channels, hostname))
        else:
            # Not breaching — reset counter and check for open alert to resolve
            await _reset_breach(r, device_id, rule["id"])

            if await _has_open_alert(device_id, rule["id"]):
                # Check flapping before resolving
                is_flapping = await _check_flapping(r, device_id, rule["id"])

                await _resolve_alert(device_id, rule["id"])

                hostname = await _get_device_hostname(device_id)
                message = f"Resolved: {rule['name']}: {rule_metric} = {value}"

                resolved_event = await _create_alert_event(
                    device_id=device_id,
                    tenant_id=tenant_id,
                    rule_id=rule["id"],
                    status="resolved",
                    severity=rule["severity"],
                    metric=rule_metric,
                    value=value,
                    threshold=rule["threshold"],
                    message=message,
                    is_flapping=is_flapping,
                )

                if not is_flapping:
                    channels = await _get_channels_for_rule(rule["id"])
                    if channels:
                        asyncio.create_task(_dispatch_async(resolved_event, channels, hostname))
|
||||
|
||||
|
||||
async def _get_offline_rule(tenant_id: str) -> dict | None:
    """Look up the device_offline default rule for a tenant.

    Matches only rows flagged is_default = TRUE with metric 'device_offline'.

    Returns:
        {"id": str, "enabled": bool} for the rule, or None when the tenant
        has no such default rule.
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, enabled FROM alert_rules
                WHERE tenant_id = CAST(:tenant_id AS uuid)
                AND metric = 'device_offline' AND is_default = TRUE
                LIMIT 1
            """),
            {"tenant_id": tenant_id},
        )
        row = result.fetchone()
        if row:
            return {"id": str(row[0]), "enabled": row[1]}
        return None
|
||||
|
||||
|
||||
async def evaluate_offline(device_id: str, tenant_id: str) -> None:
    """Create a critical alert when a device goes offline.

    Uses the tenant's device_offline default rule if it exists and is enabled.
    Falls back to system-level alert (rule_id=NULL) for backward compatibility.
    """
    # Maintenance windows suppress offline alerts entirely.
    if await _is_device_in_maintenance(tenant_id, device_id):
        logger.debug(
            "Offline alert suppressed by maintenance window for device %s",
            device_id,
        )
        return

    offline_rule = await _get_offline_rule(tenant_id)
    rule_id = offline_rule["id"] if offline_rule else None

    # A disabled default rule means the user explicitly opted out.
    if offline_rule and not offline_rule["enabled"]:
        return

    # De-duplicate: bail out if a matching alert is already open.
    if rule_id:
        already_open = await _has_open_alert(device_id, rule_id)
    else:
        already_open = await _has_open_alert(device_id, None, "offline")
    if already_open:
        return

    hostname = await _get_device_hostname(device_id)
    message = f"Device {hostname} is offline"

    alert_event = await _create_alert_event(
        device_id=device_id,
        tenant_id=tenant_id,
        rule_id=rule_id,
        status="firing",
        severity="critical",
        metric="offline",
        value=None,
        threshold=None,
        message=message,
    )

    # Prefer rule-linked channels; fall back to tenant-wide channels.
    if rule_id:
        channels = (
            await _get_channels_for_rule(rule_id)
            or await _get_channels_for_tenant(tenant_id)
        )
    else:
        channels = await _get_channels_for_tenant(tenant_id)

    if channels:
        asyncio.create_task(_dispatch_async(alert_event, channels, hostname))
|
||||
|
||||
|
||||
async def evaluate_online(device_id: str, tenant_id: str) -> None:
    """Resolve offline alert when device comes back online."""
    online_rule = await _get_offline_rule(tenant_id)
    rule_id = online_rule["id"] if online_rule else None

    # Nothing to do unless a matching offline alert is currently open;
    # resolve it in place when there is one.
    if rule_id:
        if not await _has_open_alert(device_id, rule_id):
            return
        await _resolve_alert(device_id, rule_id)
    else:
        if not await _has_open_alert(device_id, None, "offline"):
            return
        await _resolve_alert(device_id, None, "offline")

    hostname = await _get_device_hostname(device_id)
    message = f"Device {hostname} is back online"

    resolved_event = await _create_alert_event(
        device_id=device_id,
        tenant_id=tenant_id,
        rule_id=rule_id,
        status="resolved",
        severity="critical",
        metric="offline",
        value=None,
        threshold=None,
        message=message,
    )

    # Prefer rule-linked channels; fall back to tenant-wide channels.
    if rule_id:
        channels = (
            await _get_channels_for_rule(rule_id)
            or await _get_channels_for_tenant(tenant_id)
        )
    else:
        channels = await _get_channels_for_tenant(tenant_id)

    if channels:
        asyncio.create_task(_dispatch_async(resolved_event, channels, hostname))
|
||||
190
backend/app/services/api_key_service.py
Normal file
190
backend/app/services/api_key_service.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""API key generation, validation, and management service.
|
||||
|
||||
Keys use the mktp_ prefix for easy identification in logs.
|
||||
Storage uses SHA-256 hash -- the plaintext key is never persisted.
|
||||
Validation uses AdminAsyncSessionLocal since it runs before tenant context is set.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
|
||||
# Allowed scopes for API keys
|
||||
ALLOWED_SCOPES: set[str] = {
|
||||
"devices:read",
|
||||
"devices:write",
|
||||
"config:read",
|
||||
"config:write",
|
||||
"alerts:read",
|
||||
"firmware:write",
|
||||
}
|
||||
|
||||
|
||||
def generate_raw_key() -> str:
    """Generate a raw API key.

    The key is the ``mktp_`` prefix followed by 32 random bytes encoded
    as a URL-safe base64 token (43 characters, 48 total).
    """
    return "mktp_" + secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def hash_key(raw_key: str) -> str:
    """Return the SHA-256 hex digest of *raw_key* (UTF-8 encoded)."""
    digest = hashlib.sha256(raw_key.encode())
    return digest.hexdigest()
|
||||
|
||||
|
||||
async def create_api_key(
    db,
    tenant_id: uuid.UUID,
    user_id: uuid.UUID,
    name: str,
    scopes: list[str],
    expires_at: Optional[datetime] = None,
) -> dict:
    """Create a new API key.

    The plaintext key is generated here, hashed (SHA-256) for storage,
    and returned exactly once — it is never persisted.

    Args:
        db: Async SQLAlchemy session (this function commits).
        tenant_id: Owning tenant.
        user_id: User the key belongs to.
        name: Human-readable key label.
        scopes: Scope strings (expected subset of ALLOWED_SCOPES — not
            validated here; presumably enforced by the caller. TODO confirm).
        expires_at: Optional expiry timestamp; None means no expiry.

    Returns dict with:
        - key: the plaintext key (shown once, never again)
        - id: the key UUID
        - key_prefix: first 9 chars of the key (e.g. "mktp_abc1")
        - name, scopes, expires_at, created_at
    """
    raw_key = generate_raw_key()
    key_hash_value = hash_key(raw_key)  # only the hash is stored
    key_prefix = raw_key[:9]  # "mktp_" + first 4 random chars

    result = await db.execute(
        text("""
            INSERT INTO api_keys (tenant_id, user_id, name, key_prefix, key_hash, scopes, expires_at)
            VALUES (:tenant_id, :user_id, :name, :key_prefix, :key_hash, CAST(:scopes AS jsonb), :expires_at)
            RETURNING id, created_at
        """),
        {
            "tenant_id": str(tenant_id),
            "user_id": str(user_id),
            "name": name,
            "key_prefix": key_prefix,
            "key_hash": key_hash_value,
            "scopes": json.dumps(scopes),
            "expires_at": expires_at,
        },
    )
    row = result.fetchone()
    await db.commit()

    return {
        "key": raw_key,
        "id": row.id,
        "key_prefix": key_prefix,
        "name": name,
        "scopes": scopes,
        "expires_at": expires_at,
        "created_at": row.created_at,
    }
|
||||
|
||||
|
||||
async def validate_api_key(raw_key: str) -> Optional[dict]:
    """Validate an API key and return context if valid.

    Uses AdminAsyncSessionLocal since this runs before tenant context is set.
    Lookup is by the SHA-256 hash of the raw key.

    Args:
        raw_key: The plaintext key presented by the client.

    Returns:
        Dict with tenant_id, user_id, scopes, key_id on success.
        None for unknown, revoked, or expired keys.

    Side effects:
        Updates last_used_at (and commits) on successful validation.
    """
    key_hash_value = hash_key(raw_key)

    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, tenant_id, user_id, scopes, expires_at, revoked_at
                FROM api_keys
                WHERE key_hash = :key_hash
            """),
            {"key_hash": key_hash_value},
        )
        row = result.fetchone()

        if not row:
            return None

        # Check revoked
        if row.revoked_at is not None:
            return None

        # Check expired
        # NOTE(review): comparison against a tz-aware now() assumes
        # expires_at is stored timezone-aware — confirm against schema.
        if row.expires_at is not None and row.expires_at <= datetime.now(timezone.utc):
            return None

        # Update last_used_at (best-effort usage tracking)
        await session.execute(
            text("""
                UPDATE api_keys SET last_used_at = now()
                WHERE id = :key_id
            """),
            {"key_id": str(row.id)},
        )
        await session.commit()

        return {
            "tenant_id": row.tenant_id,
            "user_id": row.user_id,
            "scopes": row.scopes if row.scopes else [],
            "key_id": row.id,
        }
|
||||
|
||||
|
||||
async def list_api_keys(db, tenant_id: uuid.UUID) -> list[dict]:
    """List all API keys for a tenant (active and revoked), newest first.

    Only key_prefix is returned for masked display (key_prefix + "...");
    the plaintext key is never stored.

    Args:
        db: Async SQLAlchemy session.
        tenant_id: Tenant whose keys to list.

    Returns:
        List of dicts with id, name, key_prefix, scopes, user_id, and
        ISO-formatted expires_at / last_used_at / created_at / revoked_at
        (None when unset).
    """
    result = await db.execute(
        text("""
            SELECT id, name, key_prefix, scopes, expires_at, last_used_at,
                   created_at, revoked_at, user_id
            FROM api_keys
            WHERE tenant_id = :tenant_id
            ORDER BY created_at DESC
        """),
        {"tenant_id": str(tenant_id)},
    )
    rows = result.fetchall()

    return [
        {
            "id": row.id,
            "name": row.name,
            "key_prefix": row.key_prefix,
            "scopes": row.scopes if row.scopes else [],
            "expires_at": row.expires_at.isoformat() if row.expires_at else None,
            "last_used_at": row.last_used_at.isoformat() if row.last_used_at else None,
            "created_at": row.created_at.isoformat() if row.created_at else None,
            "revoked_at": row.revoked_at.isoformat() if row.revoked_at else None,
            "user_id": str(row.user_id),
        }
        for row in rows
    ]
|
||||
|
||||
|
||||
async def revoke_api_key(db, tenant_id: uuid.UUID, key_id: uuid.UUID) -> bool:
    """Revoke an API key by stamping revoked_at = now().

    The ``revoked_at IS NULL`` predicate makes this idempotent: keys that
    are already revoked are left untouched.

    Returns:
        True only when this call performed the revocation; False when the
        key was not found (for this tenant) or was already revoked.
    """
    update_result = await db.execute(
        text("""
            UPDATE api_keys
            SET revoked_at = now()
            WHERE id = :key_id AND tenant_id = :tenant_id AND revoked_at IS NULL
            RETURNING id
        """),
        {"key_id": str(key_id), "tenant_id": str(tenant_id)},
    )
    revoked_row = update_result.fetchone()
    await db.commit()
    return revoked_row is not None
|
||||
92
backend/app/services/audit_service.py
Normal file
92
backend/app/services/audit_service.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Centralized audit logging service.
|
||||
|
||||
Provides a fire-and-forget ``log_action`` coroutine that inserts a row into
|
||||
the ``audit_logs`` table. Uses raw SQL INSERT (not ORM) for minimal overhead.
|
||||
|
||||
The function is wrapped in a try/except so that a logging failure **never**
|
||||
breaks the parent operation.
|
||||
|
||||
Phase 30: When details are non-empty, they are encrypted via OpenBao Transit
|
||||
(per-tenant data key) and stored in encrypted_details. The plaintext details
|
||||
column is set to '{}' for column compatibility. If Transit encryption fails
|
||||
(e.g., OpenBao unavailable), details are stored in plaintext as a fallback.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = structlog.get_logger("audit")
|
||||
|
||||
|
||||
async def log_action(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    user_id: uuid.UUID,
    action: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    device_id: Optional[uuid.UUID] = None,
    details: Optional[dict[str, Any]] = None,
    ip_address: Optional[str] = None,
) -> None:
    """Insert a row into audit_logs. Swallows all exceptions on failure.

    No commit is issued here — the caller's session/transaction owns it.

    Non-empty ``details`` are encrypted via OpenBao Transit (per-tenant
    data key) into encrypted_details, with the plaintext column set to
    '{}'. If encryption fails, details are stored in plaintext as a
    fallback (see module docstring).

    Args:
        db: Caller's async session (RLS context assumed already set).
        tenant_id / user_id: Actor identity recorded on the row.
        action: Short action name (e.g. "device.delete").
        resource_type / resource_id: Optional resource reference.
        device_id: Optional related device.
        details: Optional structured payload (JSON-serializable).
        ip_address: Optional client IP.
    """
    try:
        import json as _json  # local import keeps this hot-path-light module lean

        details_dict = details or {}
        details_json = _json.dumps(details_dict)
        encrypted_details: Optional[str] = None

        # Attempt Transit encryption for non-empty details
        if details_dict:
            try:
                # Imported lazily to avoid a hard dependency at module import time
                from app.services.crypto import encrypt_data_transit

                encrypted_details = await encrypt_data_transit(
                    details_json, str(tenant_id)
                )
                # Encryption succeeded — clear plaintext details
                details_json = _json.dumps({})
            except Exception:
                # Transit unavailable — fall back to plaintext details
                logger.warning(
                    "audit_transit_encryption_failed",
                    action=action,
                    tenant_id=str(tenant_id),
                    exc_info=True,
                )
                # Keep details_json as-is (plaintext fallback)
                encrypted_details = None

        await db.execute(
            text(
                "INSERT INTO audit_logs "
                "(tenant_id, user_id, action, resource_type, resource_id, "
                "device_id, details, encrypted_details, ip_address) "
                "VALUES (:tenant_id, :user_id, :action, :resource_type, "
                ":resource_id, :device_id, CAST(:details AS jsonb), "
                ":encrypted_details, :ip_address)"
            ),
            {
                "tenant_id": str(tenant_id),
                "user_id": str(user_id),
                "action": action,
                "resource_type": resource_type,
                "resource_id": resource_id,
                "device_id": str(device_id) if device_id else None,
                "details": details_json,
                "encrypted_details": encrypted_details,
                "ip_address": ip_address,
            },
        )
    except Exception:
        # Audit logging must never break the parent operation — log and move on.
        logger.warning(
            "audit_log_insert_failed",
            action=action,
            tenant_id=str(tenant_id),
            exc_info=True,
        )
|
||||
154
backend/app/services/auth.py
Normal file
154
backend/app/services/auth.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
JWT authentication service.
|
||||
|
||||
Handles password hashing, JWT token creation, token verification,
|
||||
and token revocation via Redis.
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import bcrypt
|
||||
from fastapi import HTTPException, status
|
||||
from jose import JWTError, jwt
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.config import settings
|
||||
|
||||
TOKEN_REVOCATION_PREFIX = "token_revoked:"
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
    """Hash a plaintext password with bcrypt (fresh random salt per call).

    DEPRECATED: Used only by password reset (temporary bcrypt hash for
    upgrade flow) and bootstrap_first_admin. Remove post-v6.0.
    """
    salt = bcrypt.gensalt()
    hashed = bcrypt.hashpw(password.encode(), salt)
    return hashed.decode()
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a stored bcrypt hash.

    DEPRECATED: Used only by the one-time SRP upgrade flow (login with
    must_upgrade_auth=True) and anti-enumeration dummy calls. Remove post-v6.0.
    """
    return bcrypt.checkpw(
        plain_password.encode(),
        hashed_password.encode(),
    )
|
||||
|
||||
|
||||
def create_access_token(
    user_id: uuid.UUID,
    tenant_id: Optional[uuid.UUID],
    role: str,
) -> str:
    """Create a short-lived JWT access token.

    Claims:
        sub: user UUID (subject)
        tenant_id: tenant UUID string, or None for super_admin
        role: user's role string
        type: "access" (distinguishes from refresh tokens)
        iat / exp: issue and expiry timestamps
    """
    issued_at = datetime.now(UTC)
    expires_at = issued_at + timedelta(
        minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
    )

    claims = {
        "sub": str(user_id),
        "tenant_id": str(tenant_id) if tenant_id else None,
        "role": role,
        "type": "access",
        "iat": issued_at,
        "exp": expires_at,
    }
    return jwt.encode(claims, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(user_id: uuid.UUID) -> str:
    """Create a long-lived JWT refresh token.

    Claims:
        sub: user UUID (subject)
        type: "refresh" (distinguishes from access tokens)
        iat / exp: issue and expiry timestamps
        (lifetime is settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS days)
    """
    issued_at = datetime.now(UTC)
    expires_at = issued_at + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)

    claims = {
        "sub": str(user_id),
        "type": "refresh",
        "iat": issued_at,
        "exp": expires_at,
    }
    return jwt.encode(claims, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def verify_token(token: str, expected_type: str = "access") -> dict:
    """Decode and validate a JWT token.

    Args:
        token: JWT string to validate.
        expected_type: "access" or "refresh".

    Returns:
        dict: Decoded payload (sub, tenant_id, role, type, exp, iat).

    Raises:
        HTTPException 401: If the token is invalid, expired, of the wrong
            type, or missing its subject claim.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Signature + expiry are checked by jose during decode.
    try:
        claims = jwt.decode(
            token,
            settings.JWT_SECRET_KEY,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except JWTError:
        raise unauthorized

    # Guard clauses: wrong token type, or no subject.
    if claims.get("type") != expected_type:
        raise unauthorized
    if not claims.get("sub"):
        raise unauthorized

    return claims
|
||||
|
||||
|
||||
async def revoke_user_tokens(redis: Redis, user_id: str) -> None:
    """Mark all tokens for a user as revoked by storing the current timestamp.

    Any refresh token issued before this timestamp will be rejected by
    is_token_revoked(). The marker's TTL equals the maximum refresh-token
    lifetime so it outlives every token it must invalidate.

    Fix: the TTL was previously hard-coded to 7 days while refresh tokens
    live settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS days (see
    create_refresh_token). If that setting exceeded 7, the marker expired
    while revoked tokens were still valid, silently re-enabling them.
    Derive the TTL from the same setting instead.

    Args:
        redis: Async Redis client.
        user_id: User whose tokens should be invalidated.
    """
    key = f"{TOKEN_REVOCATION_PREFIX}{user_id}"
    ttl_seconds = settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 3600
    await redis.set(key, str(time.time()), ex=ttl_seconds)
|
||||
|
||||
|
||||
async def is_token_revoked(redis: Redis, user_id: str, issued_at: float) -> bool:
    """Return True when the token's issue time predates the user's
    revocation marker (i.e. the token should be rejected).

    A missing marker means no revocation has been issued for this user.
    """
    revocation_ts = await redis.get(f"{TOKEN_REVOCATION_PREFIX}{user_id}")
    if revocation_ts is None:
        return False
    return issued_at < float(revocation_ts)
|
||||
197
backend/app/services/backup_scheduler.py
Normal file
197
backend/app/services/backup_scheduler.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Dynamic backup scheduler — reads cron schedules from DB, manages APScheduler jobs."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.models.config_backup import ConfigBackupSchedule
|
||||
from app.models.device import Device
|
||||
from app.services import backup_service
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_scheduler: Optional[AsyncIOScheduler] = None
|
||||
|
||||
# System default: 2am UTC daily
|
||||
DEFAULT_CRON = "0 2 * * *"
|
||||
|
||||
|
||||
def _cron_to_trigger(cron_expr: str) -> Optional[CronTrigger]:
    """Parse a 5-field cron expression into an APScheduler CronTrigger.

    Uses CronTrigger.from_crontab() so field semantics follow standard
    crontab conventions — in particular day-of-week numbering, where cron's
    0 (and 7) mean Sunday. Passing raw fields to CronTrigger() interprets
    0 as Monday (APScheduler's own numbering), silently shifting weekly
    schedules by a day.

    Args:
        cron_expr: Standard 5-field cron string ("m h dom mon dow").

    Returns:
        A UTC-based CronTrigger, or None if the expression is invalid
        (wrong field count or unparseable fields).
    """
    try:
        parts = cron_expr.strip().split()
        if len(parts) != 5:
            return None
        return CronTrigger.from_crontab(cron_expr.strip(), timezone="UTC")
    except Exception as e:
        logger.warning("Invalid cron expression '%s': %s", cron_expr, e)
        return None
|
||||
|
||||
|
||||
def build_schedule_map(schedules: list) -> dict[str, list[dict]]:
    """Group enabled device schedules by their cron expression.

    Disabled schedules are skipped; a schedule without a cron expression
    falls back to DEFAULT_CRON.

    Returns:
        {cron_expression: [{device_id, tenant_id}, ...]} with IDs as strings.
    """
    grouped: dict[str, list[dict]] = {}
    for schedule in schedules:
        if not schedule.enabled:
            continue
        cron_expr = schedule.cron_expression or DEFAULT_CRON
        grouped.setdefault(cron_expr, []).append({
            "device_id": str(schedule.device_id),
            "tenant_id": str(schedule.tenant_id),
        })
    return grouped
|
||||
|
||||
|
||||
async def _run_scheduled_backups(devices: list[dict]) -> None:
    """Run backups for a list of devices; each failure is isolated so one
    broken device never aborts the rest of the batch."""
    ok_count = 0
    fail_count = 0

    for device_entry in devices:
        device_id = device_entry["device_id"]
        try:
            async with AdminAsyncSessionLocal() as session:
                await backup_service.run_backup(
                    device_id=device_id,
                    tenant_id=device_entry["tenant_id"],
                    trigger_type="scheduled",
                    db_session=session,
                )
                await session.commit()
                logger.info("Scheduled backup OK: device %s", device_id)
                ok_count += 1
        except Exception as e:
            logger.error(
                "Scheduled backup FAILED: device %s: %s",
                device_id, e,
            )
            fail_count += 1

    logger.info(
        "Backup batch complete — %d succeeded, %d failed",
        ok_count, fail_count,
    )
|
||||
|
||||
|
||||
async def _load_effective_schedules() -> list:
    """Load all effective backup schedules from the DB.

    Precedence per device: device-specific schedule > tenant default
    (a ConfigBackupSchedule row with no device_id) > system default
    (DEFAULT_CRON, enabled).

    Returns:
        Flat list of SimpleNamespace(device_id, tenant_id, cron_expression,
        enabled) objects — one entry per device, IDs as strings.
    """
    from types import SimpleNamespace

    async with AdminAsyncSessionLocal() as session:
        # Get all devices
        dev_result = await session.execute(select(Device))
        devices = dev_result.scalars().all()

        # Get all schedules
        sched_result = await session.execute(select(ConfigBackupSchedule))
        schedules = sched_result.scalars().all()

        # Index: device-specific and tenant defaults
        device_schedules = {}  # device_id -> schedule
        tenant_defaults = {}  # tenant_id -> schedule

        for s in schedules:
            if s.device_id:
                device_schedules[str(s.device_id)] = s
            else:
                # No device_id means this row is the tenant-wide default
                tenant_defaults[str(s.tenant_id)] = s

        effective = []
        for dev in devices:
            dev_id = str(dev.id)
            tenant_id = str(dev.tenant_id)

            if dev_id in device_schedules:
                sched = device_schedules[dev_id]
            elif tenant_id in tenant_defaults:
                sched = tenant_defaults[tenant_id]
            else:
                # No schedule configured — use system default
                sched = None

            effective.append(SimpleNamespace(
                device_id=dev_id,
                tenant_id=tenant_id,
                cron_expression=sched.cron_expression if sched else DEFAULT_CRON,
                enabled=sched.enabled if sched else True,
            ))

        return effective
|
||||
|
||||
|
||||
async def sync_schedules() -> None:
    """Reload all schedules from DB and reconfigure APScheduler jobs."""
    global _scheduler
    if not _scheduler:
        return

    # Drop only backup jobs; unrelated jobs (e.g. firmware check) survive.
    for existing_job in _scheduler.get_jobs():
        if existing_job.id.startswith("backup_cron_"):
            existing_job.remove()

    effective_schedules = await _load_effective_schedules()
    schedule_map = build_schedule_map(effective_schedules)

    for cron_expr, device_list in schedule_map.items():
        trigger = _cron_to_trigger(cron_expr)
        if not trigger:
            logger.warning("Skipping invalid cron '%s', using default", cron_expr)
            trigger = _cron_to_trigger(DEFAULT_CRON)

        _scheduler.add_job(
            _run_scheduled_backups,
            trigger=trigger,
            args=[device_list],
            id=f"backup_cron_{cron_expr.replace(' ', '_')}",
            name=f"Backup: {cron_expr} ({len(device_list)} devices)",
            max_instances=1,
            replace_existing=True,
        )
        logger.info("Scheduled %d devices with cron '%s'", len(device_list), cron_expr)
|
||||
|
||||
|
||||
async def on_schedule_change(tenant_id: str, device_id: str) -> None:
    """Hot-reload all backup schedules after a create/update via the API."""
    logger.info(
        "Schedule changed for tenant=%s device=%s, resyncing", tenant_id, device_id
    )
    await sync_schedules()
|
||||
|
||||
|
||||
async def start_backup_scheduler() -> None:
    """Start APScheduler and load the initial schedule set from the DB."""
    global _scheduler
    _scheduler = AsyncIOScheduler(timezone="UTC")
    _scheduler.start()

    # Populate jobs from whatever is currently configured in the DB.
    await sync_schedules()
    logger.info("Backup scheduler started with dynamic schedules")
|
||||
|
||||
|
||||
async def stop_backup_scheduler() -> None:
    """Gracefully shut down the scheduler, if one is running."""
    global _scheduler
    if _scheduler:
        _scheduler.shutdown(wait=False)
        _scheduler = None
        logger.info("Backup scheduler stopped")
|
||||
378
backend/app/services/backup_service.py
Normal file
378
backend/app/services/backup_service.py
Normal file
@@ -0,0 +1,378 @@
|
||||
"""SSH-based config capture service for RouterOS devices.
|
||||
|
||||
This service handles:
|
||||
1. capture_export() — SSH to device, run /export compact, return stdout text
|
||||
2. capture_binary_backup() — SSH to device, trigger /system backup save, SFTP-download result
|
||||
3. run_backup() — Orchestrate a full backup: capture + git commit + DB record
|
||||
|
||||
All functions are async (asyncssh is asyncio-native).
|
||||
|
||||
Security policy:
|
||||
known_hosts=None is intentional — RouterOS devices use self-signed SSH host keys
|
||||
that change on reset or key regeneration. This mirrors InsecureSkipVerify=true
|
||||
used in the poller's TLS connection. The threat model accepts device impersonation
|
||||
risk in exchange for operational simplicity (no pre-enrollment of host keys needed).
|
||||
See Pitfall 2 in 04-RESEARCH.md.
|
||||
|
||||
pygit2 calls are synchronous C bindings and MUST be wrapped in run_in_executor.
|
||||
See Pitfall 3 in 04-RESEARCH.md.
|
||||
|
||||
Phase 30: ALL backups (manual, scheduled, pre-restore) are encrypted via OpenBao
|
||||
Transit (Tier 2) before git commit. The server retains decrypt capability for
|
||||
on-demand viewing. Raw files in git are ciphertext; the API decrypts on GET.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncssh
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal, set_tenant_context
|
||||
from app.models.config_backup import ConfigBackupRun
|
||||
from app.models.device import Device
|
||||
from app.services import git_store
|
||||
from app.services.crypto import decrypt_credentials_hybrid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Fixed backup file name on device flash — overwrites on each run so files
|
||||
# don't accumulate. See Pitfall 4 in 04-RESEARCH.md.
|
||||
_BACKUP_NAME = "portal-backup"
|
||||
|
||||
|
||||
async def capture_export(
    ip: str,
    port: int = 22,
    username: str = "",
    password: str = "",
) -> str:
    """SSH to a RouterOS device and return the output of ``/export compact``.

    Args:
        ip: Device IP address.
        port: SSH port (default 22; RouterOS default is 22).
        username: SSH login username.
        password: SSH login password.

    Returns:
        The raw RSC text (may include the RouterOS header line).

    Raises:
        asyncssh.Error: On SSH connection or command execution failure.
    """
    async with asyncssh.connect(
        ip,
        port=port,
        username=username,
        password=password,
        known_hosts=None,  # RouterOS self-signed host keys — see module docstring
        connect_timeout=30,
    ) as conn:
        export_result = await conn.run("/export compact", check=True)
        return export_result.stdout
|
||||
|
||||
|
||||
async def capture_binary_backup(
    ip: str,
    port: int = 22,
    username: str = "",
    password: str = "",
) -> bytes:
    """Create a binary backup on a RouterOS device, SFTP-download it, clean up.

    A fixed backup name ({_BACKUP_NAME}.backup) is reused so each run
    overwrites the previous file and flash storage never accumulates
    backups. The on-device file removal is best-effort and runs in a
    ``finally`` block so a cleanup failure is logged but never masks the
    real backup error. See Pitfall 4 in 04-RESEARCH.md.

    Args:
        ip: Device IP address.
        port: SSH port (default 22).
        username: SSH login username.
        password: SSH login password.

    Returns:
        Raw bytes of the binary backup file.

    Raises:
        asyncssh.Error: On SSH connection, command, or SFTP failure.
    """
    remote_file = f"{_BACKUP_NAME}.backup"

    async with asyncssh.connect(
        ip,
        port=port,
        username=username,
        password=password,
        known_hosts=None,
        connect_timeout=30,
    ) as conn:
        # Step 1: Trigger backup creation on device flash.
        await conn.run(
            f"/system backup save name={_BACKUP_NAME} dont-encrypt=yes",
            check=True,
        )

        payload = io.BytesIO()
        try:
            # Step 2: SFTP-download the backup file.
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(remote_file, "rb") as remote:
                    payload.write(await remote.read())
        finally:
            # Step 3: Remove backup file from device flash (best-effort cleanup).
            try:
                await conn.run(f"/file remove {remote_file}", check=True)
            except Exception as cleanup_err:
                logger.warning(
                    "Failed to remove backup file from device %s: %s",
                    ip,
                    cleanup_err,
                )

        return payload.getvalue()
|
||||
|
||||
|
||||
async def run_backup(
    device_id: str,
    tenant_id: str,
    trigger_type: str,
    db_session: AsyncSession | None = None,
) -> dict:
    """Orchestrate a full config backup for a device.

    Steps:
        1. Load device from DB (ip_address, encrypted_credentials).
        2. Decrypt credentials (dual-read: Transit preferred, legacy fallback).
        3. Capture /export compact and binary backup concurrently via asyncio.gather().
        4. Compute line delta vs the most recent export.rsc in git (None for first backup).
        5. Encrypt both artifacts via OpenBao Transit (plaintext fallback if unavailable).
        6. Commit both files to the tenant's bare git repo (run_in_executor for pygit2).
        7. Insert ConfigBackupRun record with commit SHA, trigger type, line deltas.

    Args:
        device_id: Device UUID as string.
        tenant_id: Tenant UUID as string.
        trigger_type: 'scheduled' | 'manual' | 'pre-restore'
        db_session: Optional AsyncSession with RLS context already set.
            If None, uses AdminAsyncSessionLocal (for scheduler context).

    Returns:
        Dict: {"commit_sha": str, "trigger_type": str, "lines_added": int|None, "lines_removed": int|None}

    Raises:
        ValueError: If device not found or missing credentials.
        asyncssh.Error: On SSH/SFTP failure.
    """
    # FIX: asyncio.get_event_loop() is deprecated when called from a coroutine
    # (Python 3.10+); get_running_loop() is the supported equivalent here.
    loop = asyncio.get_running_loop()
    ts = datetime.now(timezone.utc).isoformat()

    # -----------------------------------------------------------------------
    # Step 1: Load device from DB
    # -----------------------------------------------------------------------
    if db_session is not None:
        session = db_session
        should_close = False
    else:
        # Scheduler context: use admin session (cross-tenant; RLS bypassed)
        session = AdminAsyncSessionLocal()
        should_close = True

    try:
        from sqlalchemy import select

        if should_close:
            # Admin session doesn't have RLS context — filter by tenant explicitly.
            result = await session.execute(
                select(Device).where(
                    Device.id == device_id,  # type: ignore[arg-type]
                    Device.tenant_id == tenant_id,  # type: ignore[arg-type]
                )
            )
        else:
            # RLS-scoped session already restricts visibility to the tenant.
            result = await session.execute(
                select(Device).where(Device.id == device_id)  # type: ignore[arg-type]
            )

        device = result.scalar_one_or_none()
        if device is None:
            raise ValueError(f"Device {device_id!r} not found for tenant {tenant_id!r}")

        if not device.encrypted_credentials_transit and not device.encrypted_credentials:
            raise ValueError(
                f"Device {device_id!r} has no stored credentials — cannot perform backup"
            )

        # -----------------------------------------------------------------------
        # Step 2: Decrypt credentials (dual-read: Transit preferred, legacy fallback)
        # -----------------------------------------------------------------------
        key = settings.get_encryption_key_bytes()
        creds_json = await decrypt_credentials_hybrid(
            device.encrypted_credentials_transit,
            device.encrypted_credentials,
            str(device.tenant_id),
            key,
        )
        creds = json.loads(creds_json)
        ssh_username = creds.get("username", "")
        ssh_password = creds.get("password", "")
        ip = device.ip_address

        hostname = device.hostname or ip

        # -----------------------------------------------------------------------
        # Step 3: Capture export and binary backup concurrently
        # -----------------------------------------------------------------------
        logger.info(
            "Starting %s backup for device %s (%s) tenant %s",
            trigger_type,
            hostname,
            ip,
            tenant_id,
        )

        export_text, binary_backup = await asyncio.gather(
            capture_export(ip, username=ssh_username, password=ssh_password),
            capture_binary_backup(ip, username=ssh_username, password=ssh_password),
        )

        # -----------------------------------------------------------------------
        # Step 4: Compute line delta vs prior version
        # -----------------------------------------------------------------------
        lines_added: int | None = None
        lines_removed: int | None = None

        prior_commits = await loop.run_in_executor(
            None, git_store.list_device_commits, tenant_id, device_id
        )

        if prior_commits:
            try:
                prior_export_bytes = await loop.run_in_executor(
                    None, git_store.read_file, tenant_id, prior_commits[0]["sha"], device_id, "export.rsc"
                )
                prior_text = prior_export_bytes.decode("utf-8", errors="replace")
                lines_added, lines_removed = await loop.run_in_executor(
                    None, git_store.compute_line_delta, prior_text, export_text
                )
            except Exception as delta_err:
                logger.warning(
                    "Failed to compute line delta for device %s: %s",
                    device_id,
                    delta_err,
                )
                # Keep lines_added/lines_removed as None on error — non-fatal
        else:
            # First backup: all lines are "added", none removed
            all_lines = len(export_text.splitlines())
            lines_added = all_lines
            lines_removed = 0

        # -----------------------------------------------------------------------
        # Step 5: Encrypt ALL backups via Transit (Tier 2: OpenBao Transit)
        # -----------------------------------------------------------------------
        encryption_tier: int | None = None
        git_export_content = export_text
        git_binary_content = binary_backup

        try:
            from app.services.crypto import encrypt_data_transit

            encrypted_export = await encrypt_data_transit(
                export_text, tenant_id
            )
            encrypted_binary = await encrypt_data_transit(
                base64.b64encode(binary_backup).decode(), tenant_id
            )
            # Transit ciphertext is text — store directly in git
            git_export_content = encrypted_export
            git_binary_content = encrypted_binary.encode("utf-8")
            encryption_tier = 2
            logger.info(
                "Tier 2 Transit encryption applied for %s backup of device %s",
                trigger_type,
                device_id,
            )
        except Exception as enc_err:
            # Transit unavailable — fall back to plaintext (non-fatal)
            logger.warning(
                "Transit encryption failed for %s backup of device %s, "
                "storing plaintext: %s",
                trigger_type,
                device_id,
                enc_err,
            )
            # Keep encryption_tier = None (plaintext fallback)

        # -----------------------------------------------------------------------
        # Step 6: Commit to git (wrapped in run_in_executor — pygit2 is sync C bindings)
        # -----------------------------------------------------------------------
        commit_message = (
            f"{trigger_type}: {hostname} ({ip}) at {ts}"
        )

        commit_sha = await loop.run_in_executor(
            None,
            git_store.commit_backup,
            tenant_id,
            device_id,
            git_export_content,
            git_binary_content,
            commit_message,
        )

        logger.info(
            "Committed backup for device %s to git SHA %s (tier=%s)",
            device_id,
            commit_sha[:8],
            encryption_tier,
        )

        # -----------------------------------------------------------------------
        # Step 7: Insert ConfigBackupRun record
        # -----------------------------------------------------------------------
        if not should_close:
            # RLS-scoped session from API context — record directly
            backup_run = ConfigBackupRun(
                device_id=device.id,
                tenant_id=device.tenant_id,
                commit_sha=commit_sha,
                trigger_type=trigger_type,
                lines_added=lines_added,
                lines_removed=lines_removed,
                encryption_tier=encryption_tier,
            )
            session.add(backup_run)
            await session.flush()
        else:
            # Admin session — set tenant context before insert so RLS policy is satisfied
            async with AdminAsyncSessionLocal() as admin_session:
                await set_tenant_context(admin_session, str(device.tenant_id))
                backup_run = ConfigBackupRun(
                    device_id=device.id,
                    tenant_id=device.tenant_id,
                    commit_sha=commit_sha,
                    trigger_type=trigger_type,
                    lines_added=lines_added,
                    lines_removed=lines_removed,
                    encryption_tier=encryption_tier,
                )
                admin_session.add(backup_run)
                await admin_session.commit()

        return {
            "commit_sha": commit_sha,
            "trigger_type": trigger_type,
            "lines_added": lines_added,
            "lines_removed": lines_removed,
        }

    finally:
        if should_close:
            await session.close()
|
||||
462
backend/app/services/ca_service.py
Normal file
462
backend/app/services/ca_service.py
Normal file
@@ -0,0 +1,462 @@
|
||||
"""Certificate Authority service — CA generation, device cert signing, lifecycle.
|
||||
|
||||
This module provides the core PKI functionality for the Internal Certificate
|
||||
Authority feature. All functions receive an ``AsyncSession`` and an
|
||||
``encryption_key`` as parameters (no direct Settings access) for testability.
|
||||
|
||||
Security notes:
|
||||
- CA private keys are encrypted with AES-256-GCM before database storage.
|
||||
- PEM key material is NEVER logged.
|
||||
- Device keys are decrypted only when needed for NATS transmission.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import ipaddress
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.certificate import CertificateAuthority, DeviceCertificate
|
||||
from app.services.crypto import (
|
||||
decrypt_credentials_hybrid,
|
||||
encrypt_credentials_transit,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Valid status transitions for the device certificate lifecycle.
# Keys are current states; values are the states reachable from them.
# Terminal states ("revoked", "superseded") map to the empty set — no
# further transitions are permitted once reached.
_VALID_TRANSITIONS: dict[str, set[str]] = {
    "issued": {"deploying"},
    "deploying": {"deployed", "issued"},  # issued = rollback on deploy failure
    "deployed": {"expiring", "revoked", "superseded"},
    "expiring": {"expired", "revoked", "superseded"},
    "expired": {"superseded"},
    "revoked": set(),
    "superseded": set(),
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CA Generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def generate_ca(
    db: AsyncSession,
    tenant_id: UUID,
    common_name: str,
    validity_years: int,
    encryption_key: bytes,
) -> CertificateAuthority:
    """Generate a self-signed root CA for a tenant.

    The CA private key is encrypted with the tenant's OpenBao Transit key
    before storage; the legacy ``encrypted_private_key`` column is written
    as an empty value for schema compatibility.

    Args:
        db: Async database session.
        tenant_id: Tenant UUID — only one CA per tenant.
        common_name: CN for the CA certificate (e.g., "Portal Root CA").
        validity_years: How many years the CA cert is valid (computed as
            365-day years; leap days are not accounted for).
        encryption_key: Legacy 32-byte AES-256-GCM key. Not referenced in
            this function (Transit is used for the key material instead);
            kept for call-site compatibility.

    Returns:
        The newly created ``CertificateAuthority`` model instance (flushed,
        not committed — the caller owns the transaction).

    Raises:
        ValueError: If the tenant already has a CA.
    """
    # Ensure one CA per tenant
    existing = await get_ca_for_tenant(db, tenant_id)
    if existing is not None:
        raise ValueError(
            f"Tenant {tenant_id} already has a CA (id={existing.id}). "
            "Delete the existing CA before creating a new one."
        )

    # Generate RSA 2048 key pair
    # NOTE(review): RSA keygen is CPU-bound and runs directly on the event
    # loop here — consider run_in_executor if this shows up in latency.
    ca_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    now = datetime.datetime.now(datetime.timezone.utc)
    expiry = now + datetime.timedelta(days=365 * validity_years)

    # Self-signed root: subject and issuer are the same name.
    subject = issuer = x509.Name([
        x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
        x509.NameAttribute(NameOID.COMMON_NAME, common_name),
    ])

    # path_length=0 forbids this CA from signing intermediate CAs.
    ca_cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(ca_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(expiry)
        .add_extension(
            x509.BasicConstraints(ca=True, path_length=0), critical=True
        )
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        .add_extension(
            x509.SubjectKeyIdentifier.from_public_key(ca_key.public_key()),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )

    # Serialize public cert to PEM
    cert_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")

    # Serialize private key to PEM, then encrypt with OpenBao Transit
    key_pem = ca_key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.PKCS8,
        serialization.NoEncryption(),
    ).decode("utf-8")
    encrypted_key_transit = await encrypt_credentials_transit(key_pem, str(tenant_id))

    # Compute SHA-256 fingerprint (colon-separated hex)
    fingerprint_bytes = ca_cert.fingerprint(hashes.SHA256())
    fingerprint = ":".join(f"{b:02X}" for b in fingerprint_bytes)

    # Serial number as hex string
    serial_hex = format(ca_cert.serial_number, "X")

    model = CertificateAuthority(
        tenant_id=tenant_id,
        common_name=common_name,
        cert_pem=cert_pem,
        encrypted_private_key=b"",  # Legacy column kept for schema compat
        encrypted_private_key_transit=encrypted_key_transit,
        serial_number=serial_hex,
        fingerprint_sha256=fingerprint,
        not_valid_before=now,
        not_valid_after=expiry,
    )
    db.add(model)
    await db.flush()

    # PEM key material is never logged — only the public fingerprint.
    logger.info(
        "Generated CA for tenant %s: cn=%s fingerprint=%s",
        tenant_id,
        common_name,
        fingerprint,
    )
    return model
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Device Certificate Signing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def sign_device_cert(
    db: AsyncSession,
    ca: CertificateAuthority,
    device_id: UUID,
    hostname: str,
    ip_address: str,
    validity_days: int,
    encryption_key: bytes,
) -> DeviceCertificate:
    """Sign a per-device TLS certificate using the tenant's CA.

    The new device private key is encrypted with OpenBao Transit before
    storage. ``encryption_key`` is only consumed by the dual-read decrypt
    of the CA private key (legacy AES-GCM fallback path).

    Args:
        db: Async database session.
        ca: The tenant's CertificateAuthority model instance.
        device_id: UUID of the device receiving the cert.
        hostname: Device hostname — used as CN and SAN DNSName.
        ip_address: Device IP — used as SAN IPAddress; must parse as an
            IP literal.
        validity_days: Certificate validity in days.
        encryption_key: 32-byte AES-256-GCM key for the legacy decrypt
            fallback of the CA private key.

    Returns:
        The newly created ``DeviceCertificate`` model instance
        (status='issued', flushed but not committed).

    Raises:
        ValueError: If ``ip_address`` is not a valid IP literal
            (raised by ``ipaddress.ip_address``).
    """
    # Decrypt CA private key (dual-read: Transit preferred, legacy fallback)
    ca_key_pem = await decrypt_credentials_hybrid(
        ca.encrypted_private_key_transit,
        ca.encrypted_private_key,
        str(ca.tenant_id),
        encryption_key,
    )
    ca_key = serialization.load_pem_private_key(
        ca_key_pem.encode("utf-8"), password=None
    )

    # Load CA certificate for issuer info and AuthorityKeyIdentifier
    ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8"))

    # Generate device RSA 2048 key
    # NOTE(review): CPU-bound keygen runs directly on the event loop.
    device_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    now = datetime.datetime.now(datetime.timezone.utc)
    expiry = now + datetime.timedelta(days=validity_days)

    # Leaf cert: ca=False, server-auth EKU, SANs for both IP and hostname.
    device_cert = (
        x509.CertificateBuilder()
        .subject_name(
            x509.Name([
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
                x509.NameAttribute(NameOID.COMMON_NAME, hostname),
            ])
        )
        .issuer_name(ca_cert.subject)
        .public_key(device_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(expiry)
        .add_extension(
            x509.BasicConstraints(ca=False, path_length=None), critical=True
        )
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=True,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=False,
                crl_sign=False,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        .add_extension(
            x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]),
            critical=False,
        )
        .add_extension(
            x509.SubjectAlternativeName([
                x509.IPAddress(ipaddress.ip_address(ip_address)),
                x509.DNSName(hostname),
            ]),
            critical=False,
        )
        .add_extension(
            # Chain the leaf to the CA via the CA's SubjectKeyIdentifier.
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                ca_cert.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier
                ).value
            ),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )

    # Serialize device cert and key to PEM
    cert_pem = device_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
    key_pem = device_key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.PKCS8,
        serialization.NoEncryption(),
    ).decode("utf-8")

    # Encrypt device private key via OpenBao Transit
    encrypted_key_transit = await encrypt_credentials_transit(key_pem, str(ca.tenant_id))

    # Compute fingerprint
    fingerprint_bytes = device_cert.fingerprint(hashes.SHA256())
    fingerprint = ":".join(f"{b:02X}" for b in fingerprint_bytes)

    serial_hex = format(device_cert.serial_number, "X")

    model = DeviceCertificate(
        tenant_id=ca.tenant_id,
        device_id=device_id,
        ca_id=ca.id,
        common_name=hostname,
        serial_number=serial_hex,
        fingerprint_sha256=fingerprint,
        cert_pem=cert_pem,
        encrypted_private_key=b"",  # Legacy column kept for schema compat
        encrypted_private_key_transit=encrypted_key_transit,
        not_valid_before=now,
        not_valid_after=expiry,
        status="issued",
    )
    db.add(model)
    await db.flush()

    # PEM key material is never logged — only the public fingerprint.
    logger.info(
        "Signed device cert for device %s: cn=%s fingerprint=%s",
        device_id,
        hostname,
        fingerprint,
    )
    return model
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def get_ca_for_tenant(
    db: AsyncSession,
    tenant_id: UUID,
) -> CertificateAuthority | None:
    """Look up the single CA owned by *tenant_id*; None when not yet created."""
    stmt = select(CertificateAuthority).where(
        CertificateAuthority.tenant_id == tenant_id
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_device_certs(
    db: AsyncSession,
    tenant_id: UUID,
    device_id: UUID | None = None,
) -> list[DeviceCertificate]:
    """List a tenant's device certificates, newest first.

    Superseded certificates are always filtered out; when *device_id* is
    given, results are further restricted to that single device.

    Args:
        db: Async database session.
        tenant_id: Tenant UUID.
        device_id: If provided, filter to certs for this device only.

    Returns:
        List of DeviceCertificate models ordered by created_at descending.
    """
    conditions = [
        DeviceCertificate.tenant_id == tenant_id,
        DeviceCertificate.status != "superseded",
    ]
    if device_id is not None:
        conditions.append(DeviceCertificate.device_id == device_id)

    query = (
        select(DeviceCertificate)
        .where(*conditions)
        .order_by(DeviceCertificate.created_at.desc())
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Status Management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def update_cert_status(
    db: AsyncSession,
    cert_id: UUID,
    status: str,
    deployed_at: datetime.datetime | None = None,
) -> DeviceCertificate:
    """Update a device certificate's lifecycle status.

    Enforces the transition state machine:

        issued -> deploying -> deployed -> expiring -> expired
                                   \\-> revoked
                                   \\-> superseded

    Args:
        db: Async database session.
        cert_id: Certificate UUID.
        status: New status value.
        deployed_at: Timestamp to set when transitioning to 'deployed';
            defaults to the update timestamp when omitted.

    Returns:
        The updated DeviceCertificate model.

    Raises:
        ValueError: If the certificate is not found or the transition is invalid.
    """
    lookup = await db.execute(
        select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
    )
    cert = lookup.scalar_one_or_none()
    if cert is None:
        raise ValueError(f"Device certificate {cert_id} not found")

    allowed = _VALID_TRANSITIONS.get(cert.status, set())
    if status not in allowed:
        raise ValueError(
            f"Invalid status transition: {cert.status} -> {status}. "
            f"Allowed transitions from '{cert.status}': {allowed or 'none'}"
        )

    cert.status = status
    cert.updated_at = datetime.datetime.now(datetime.timezone.utc)

    # Record deployment time: explicit value wins, else the update timestamp.
    if status == "deployed":
        cert.deployed_at = deployed_at if deployed_at is not None else cert.updated_at

    await db.flush()

    logger.info(
        "Updated cert %s status to %s",
        cert_id,
        status,
    )
    return cert
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cert Data for Deployment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def get_cert_for_deploy(
    db: AsyncSession,
    cert_id: UUID,
    encryption_key: bytes,
) -> tuple[str, str, str]:
    """Retrieve and decrypt certificate data for NATS deployment.

    Gathers the device cert PEM, the decrypted device key PEM, and the
    issuing CA cert PEM — the full bundle the Go poller pushes to a device.

    Args:
        db: Async database session.
        cert_id: Device certificate UUID.
        encryption_key: 32-byte AES-256-GCM key used as the legacy
            fallback when no Transit ciphertext is stored.

    Returns:
        Tuple of (cert_pem, key_pem_decrypted, ca_cert_pem).

    Raises:
        ValueError: If the certificate or its CA is not found.
    """
    cert_rows = await db.execute(
        select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
    )
    cert = cert_rows.scalar_one_or_none()
    if cert is None:
        raise ValueError(f"Device certificate {cert_id} not found")

    # The CA record supplies the trust-chain PEM for the bundle.
    ca_rows = await db.execute(
        select(CertificateAuthority).where(
            CertificateAuthority.id == cert.ca_id
        )
    )
    ca = ca_rows.scalar_one_or_none()
    if ca is None:
        raise ValueError(f"CA {cert.ca_id} not found for certificate {cert_id}")

    # Dual-read decrypt: Transit ciphertext preferred, legacy AES-GCM fallback.
    key_pem = await decrypt_credentials_hybrid(
        cert.encrypted_private_key_transit,
        cert.encrypted_private_key,
        str(cert.tenant_id),
        encryption_key,
    )

    return cert.cert_pem, key_pem, ca.cert_pem
|
||||
118
backend/app/services/config_change_subscriber.py
Normal file
118
backend/app/services/config_change_subscriber.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""NATS subscriber for config change events from the Go poller.
|
||||
|
||||
Triggers automatic backups when out-of-band config changes are detected,
|
||||
with 5-minute deduplication to prevent rapid-fire backups.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.models.config_backup import ConfigBackupRun
|
||||
from app.services import backup_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Minimum gap between automatic config-change backups for a single device.
DEDUP_WINDOW_MINUTES = 5

# Module-level NATS client handle; set by start_config_change_subscriber()
# and cleared by stop_config_change_subscriber().
_nc: Optional[Any] = None
|
||||
|
||||
|
||||
async def _last_backup_within_dedup_window(device_id: str) -> bool:
    """Return True when this device already got a backup within the dedup window."""
    window_start = datetime.now(timezone.utc) - timedelta(minutes=DEDUP_WINDOW_MINUTES)
    stmt = (
        select(ConfigBackupRun)
        .where(
            ConfigBackupRun.device_id == device_id,
            ConfigBackupRun.created_at > window_start,
        )
        .limit(1)
    )
    async with AdminAsyncSessionLocal() as session:
        hit = await session.execute(stmt)
        return hit.scalar_one_or_none() is not None
|
||||
|
||||
|
||||
async def handle_config_changed(event: dict) -> None:
    """Handle a config change event: trigger a backup with deduplication.

    Args:
        event: Decoded NATS payload; must contain 'device_id' and 'tenant_id'.
            'old_timestamp'/'new_timestamp' are logged when present.
    """
    device_id = event.get("device_id")
    tenant_id = event.get("tenant_id")

    if not device_id or not tenant_id:
        logger.warning("Config change event missing device_id or tenant_id: %s", event)
        return

    # Dedup check: skip if a backup already ran within the window.
    if await _last_backup_within_dedup_window(device_id):
        logger.info(
            "Config change on device %s — skipping backup (within %dm dedup window)",
            device_id, DEDUP_WINDOW_MINUTES,
        )
        return

    logger.info(
        "Config change detected on device %s (tenant %s): %s -> %s",
        device_id, tenant_id,
        event.get("old_timestamp", "?"),
        event.get("new_timestamp", "?"),
    )

    try:
        # NOTE(review): an admin session is handed to run_backup as db_session,
        # which run_backup treats as an RLS-scoped session — confirm the device
        # query inside run_backup filters by tenant on this path.
        async with AdminAsyncSessionLocal() as session:
            await backup_service.run_backup(
                device_id=device_id,
                tenant_id=tenant_id,
                trigger_type="config-change",
                db_session=session,
            )
            await session.commit()
            logger.info("Config-change backup completed for device %s", device_id)
    except Exception:
        # FIX: logger.exception preserves the traceback; logger.error("... %s", e)
        # discarded it, making SSH/git failures hard to diagnose.
        logger.exception("Config-change backup failed for device %s", device_id)
|
||||
|
||||
|
||||
async def _on_message(msg) -> None:
    """NATS message handler for config.changed.> subjects.

    Decodes the JSON payload, dispatches it, and ACKs on success.
    Any failure is logged with traceback and the message is NAK'd so
    JetStream redelivers it.
    """
    try:
        event = json.loads(msg.data.decode())
        await handle_config_changed(event)
        await msg.ack()
    except Exception:
        # FIX: logger.exception preserves the traceback; logger.error("... %s", e)
        # discarded it.
        logger.exception("Error handling config change message")
        await msg.nak()
|
||||
|
||||
|
||||
async def start_config_change_subscriber() -> Optional[Any]:
    """Connect to NATS and subscribe to config.changed.> events.

    Returns the connected client on success, or None if connecting or
    subscribing fails (startup continues without the subscriber).
    """
    import nats

    global _nc
    try:
        logger.info("NATS config-change: connecting to %s", settings.NATS_URL)
        conn = await nats.connect(settings.NATS_URL)
        _nc = conn
        jetstream = conn.jetstream()
        await jetstream.subscribe(
            "config.changed.>",
            cb=_on_message,
            durable="api-config-change-consumer",
            stream="DEVICE_EVENTS",
            manual_ack=True,
        )
        logger.info("Config change subscriber started")
        return conn
    except Exception as exc:
        logger.error("Failed to start config change subscriber: %s", exc)
        return None
|
||||
|
||||
|
||||
async def stop_config_change_subscriber() -> None:
    """Gracefully close the NATS connection (no-op when not started)."""
    global _nc
    if not _nc:
        return
    await _nc.drain()
    _nc = None
|
||||
183
backend/app/services/crypto.py
Normal file
183
backend/app/services/crypto.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Credential encryption/decryption with dual-read (OpenBao Transit + legacy AES-256-GCM).
|
||||
|
||||
This module provides two encryption paths:
|
||||
1. Legacy (sync): AES-256-GCM with static CREDENTIAL_ENCRYPTION_KEY — used for fallback reads.
|
||||
2. Transit (async): OpenBao Transit per-tenant keys — used for all new writes.
|
||||
|
||||
The dual-read pattern:
|
||||
- New writes always use OpenBao Transit (encrypt_credentials_transit).
|
||||
- Reads prefer Transit ciphertext, falling back to legacy (decrypt_credentials_hybrid).
|
||||
- Legacy functions are preserved for backward compatibility during migration.
|
||||
|
||||
Security properties:
|
||||
- AES-256-GCM provides authenticated encryption (confidentiality + integrity)
|
||||
- A unique 12-byte random nonce is generated per legacy encryption operation
|
||||
- OpenBao Transit keys are AES-256-GCM96, managed entirely by OpenBao
|
||||
- Ciphertext format: "vault:v1:..." for Transit, raw bytes for legacy
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def encrypt_credentials(plaintext: str, key: bytes) -> bytes:
|
||||
"""
|
||||
Encrypt a plaintext string using AES-256-GCM.
|
||||
|
||||
Args:
|
||||
plaintext: The credential string to encrypt (e.g., JSON with username/password)
|
||||
key: 32-byte encryption key
|
||||
|
||||
Returns:
|
||||
bytes: nonce (12 bytes) + ciphertext + GCM tag (16 bytes)
|
||||
|
||||
Raises:
|
||||
ValueError: If key is not exactly 32 bytes
|
||||
"""
|
||||
if len(key) != 32:
|
||||
raise ValueError(f"Key must be exactly 32 bytes, got {len(key)}")
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
aesgcm = AESGCM(key)
|
||||
nonce = os.urandom(12) # 96-bit nonce, unique per encryption
|
||||
ciphertext = aesgcm.encrypt(nonce, plaintext.encode("utf-8"), None)
|
||||
|
||||
# Store as: nonce (12 bytes) + ciphertext + GCM tag (included in ciphertext by library)
|
||||
return nonce + ciphertext
|
||||
|
||||
|
||||
def decrypt_credentials(ciphertext: bytes, key: bytes) -> str:
|
||||
"""
|
||||
Decrypt AES-256-GCM encrypted credentials.
|
||||
|
||||
Args:
|
||||
ciphertext: bytes from encrypt_credentials (nonce + encrypted data + GCM tag)
|
||||
key: 32-byte encryption key (must match the key used for encryption)
|
||||
|
||||
Returns:
|
||||
str: The original plaintext string
|
||||
|
||||
Raises:
|
||||
ValueError: If key is not exactly 32 bytes
|
||||
cryptography.exceptions.InvalidTag: If authentication fails (tampered data or wrong key)
|
||||
"""
|
||||
if len(key) != 32:
|
||||
raise ValueError(f"Key must be exactly 32 bytes, got {len(key)}")
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
nonce = ciphertext[:12]
|
||||
encrypted_data = ciphertext[12:]
|
||||
|
||||
aesgcm = AESGCM(key)
|
||||
plaintext_bytes = aesgcm.decrypt(nonce, encrypted_data, None)
|
||||
|
||||
return plaintext_bytes.decode("utf-8")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenBao Transit functions (async, per-tenant keys)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def encrypt_credentials_transit(plaintext: str, tenant_id: str) -> str:
    """Encrypt a credential string via OpenBao Transit.

    Args:
        plaintext: The credential string to encrypt.
        tenant_id: Tenant UUID string for key lookup.

    Returns:
        Transit ciphertext string (vault:v1:base64...).
    """
    from app.services.openbao_service import get_openbao_service

    backend = get_openbao_service()
    return await backend.encrypt(tenant_id, plaintext.encode("utf-8"))
|
||||
|
||||
|
||||
async def decrypt_credentials_transit(ciphertext: str, tenant_id: str) -> str:
    """Decrypt OpenBao Transit credential ciphertext back to plaintext.

    Args:
        ciphertext: Transit ciphertext (vault:v1:...).
        tenant_id: Tenant UUID string for key lookup.

    Returns:
        Decrypted plaintext string.
    """
    from app.services.openbao_service import get_openbao_service

    backend = get_openbao_service()
    raw = await backend.decrypt(tenant_id, ciphertext)
    return raw.decode("utf-8")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenBao Transit data encryption (async, per-tenant _data keys — Phase 30)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def encrypt_data_transit(plaintext: str, tenant_id: str) -> str:
    """Encrypt non-credential data via OpenBao Transit using the per-tenant data key.

    Used for audit log details, config backups, and reports. Data keys are
    kept separate from credential keys (tenant_{uuid}_data vs tenant_{uuid}).

    Args:
        plaintext: The data string to encrypt.
        tenant_id: Tenant UUID string used to select the per-tenant data key.

    Returns:
        Transit ciphertext string ("vault:v1:..." base64 form).
    """
    # Imported lazily to avoid a module-level dependency cycle with the service layer.
    from app.services.openbao_service import get_openbao_service

    openbao = get_openbao_service()
    payload = plaintext.encode("utf-8")
    return await openbao.encrypt_data(tenant_id, payload)
|
||||
|
||||
|
||||
async def decrypt_data_transit(ciphertext: str, tenant_id: str) -> str:
    """Decrypt an OpenBao Transit data ciphertext (per-tenant data key).

    Args:
        ciphertext: Transit ciphertext ("vault:v1:...").
        tenant_id: Tenant UUID string used to select the per-tenant data key.

    Returns:
        The decrypted plaintext string.
    """
    # Imported lazily to avoid a module-level dependency cycle with the service layer.
    from app.services.openbao_service import get_openbao_service

    raw = await get_openbao_service().decrypt_data(tenant_id, ciphertext)
    return raw.decode("utf-8")
|
||||
|
||||
|
||||
async def decrypt_credentials_hybrid(
|
||||
transit_ciphertext: str | None,
|
||||
legacy_ciphertext: bytes | None,
|
||||
tenant_id: str,
|
||||
legacy_key: bytes,
|
||||
) -> str:
|
||||
"""Dual-read: prefer Transit ciphertext, fall back to legacy.
|
||||
|
||||
Args:
|
||||
transit_ciphertext: OpenBao Transit ciphertext (vault:v1:...) or None.
|
||||
legacy_ciphertext: Legacy AES-256-GCM bytes (nonce+ciphertext+tag) or None.
|
||||
tenant_id: Tenant UUID string for Transit key lookup.
|
||||
legacy_key: 32-byte legacy encryption key for fallback.
|
||||
|
||||
Returns:
|
||||
Decrypted plaintext string.
|
||||
|
||||
Raises:
|
||||
ValueError: If neither ciphertext is available.
|
||||
"""
|
||||
if transit_ciphertext and transit_ciphertext.startswith("vault:v"):
|
||||
return await decrypt_credentials_transit(transit_ciphertext, tenant_id)
|
||||
elif legacy_ciphertext:
|
||||
return decrypt_credentials(legacy_ciphertext, legacy_key)
|
||||
else:
|
||||
raise ValueError("No credentials available (both transit and legacy are empty)")
|
||||
670
backend/app/services/device.py
Normal file
670
backend/app/services/device.py
Normal file
@@ -0,0 +1,670 @@
|
||||
"""
|
||||
Device service — business logic for device CRUD, credential encryption, groups, and tags.
|
||||
|
||||
All functions operate via the app_user engine (RLS enforced).
|
||||
Tenant isolation is handled automatically by PostgreSQL RLS policies
|
||||
(SET LOCAL app.current_tenant is set by the get_current_user dependency before
|
||||
this layer is called).
|
||||
|
||||
Credential policy:
|
||||
- Credentials are always stored as AES-256-GCM encrypted JSON blobs.
|
||||
- Credentials are NEVER returned in any public-facing response.
|
||||
- Re-encryption happens only when a new password is explicitly provided in an update.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models.device import (
|
||||
Device,
|
||||
DeviceGroup,
|
||||
DeviceGroupMembership,
|
||||
DeviceTag,
|
||||
DeviceTagAssignment,
|
||||
)
|
||||
from app.schemas.device import (
|
||||
BulkAddRequest,
|
||||
BulkAddResult,
|
||||
DeviceCreate,
|
||||
DeviceGroupCreate,
|
||||
DeviceGroupResponse,
|
||||
DeviceGroupUpdate,
|
||||
DeviceResponse,
|
||||
DeviceTagCreate,
|
||||
DeviceTagResponse,
|
||||
DeviceTagUpdate,
|
||||
DeviceUpdate,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.services.crypto import (
|
||||
decrypt_credentials,
|
||||
decrypt_credentials_hybrid,
|
||||
encrypt_credentials,
|
||||
encrypt_credentials_transit,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _tcp_reachable(ip: str, port: int, timeout: float = 3.0) -> bool:
|
||||
"""Return True if a TCP connection to ip:port succeeds within timeout."""
|
||||
try:
|
||||
_, writer = await asyncio.wait_for(
|
||||
asyncio.open_connection(ip, port), timeout=timeout
|
||||
)
|
||||
writer.close()
|
||||
try:
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _build_device_response(device: Device) -> DeviceResponse:
    """
    Build a DeviceResponse from an ORM Device instance.

    Tags and groups are extracted from pre-loaded relationships.
    Credentials are explicitly EXCLUDED.

    NOTE: callers must have loaded ``tag_assignments`` and ``group_memberships``
    eagerly (see ``_device_with_relations``) — accessing them lazily here would
    fail under an async session.
    """
    # Local import presumably avoids a circular import with the schemas module — confirm.
    from app.schemas.device import DeviceGroupRef, DeviceTagRef

    # Flatten the association rows into lightweight tag references.
    tags = [
        DeviceTagRef(
            id=a.tag.id,
            name=a.tag.name,
            color=a.tag.color,
        )
        for a in device.tag_assignments
    ]

    # Same for group memberships.
    groups = [
        DeviceGroupRef(
            id=m.group.id,
            name=m.group.name,
        )
        for m in device.group_memberships
    ]

    # Note: no credential fields are copied — this is the public-facing shape.
    return DeviceResponse(
        id=device.id,
        hostname=device.hostname,
        ip_address=device.ip_address,
        api_port=device.api_port,
        api_ssl_port=device.api_ssl_port,
        model=device.model,
        serial_number=device.serial_number,
        firmware_version=device.firmware_version,
        routeros_version=device.routeros_version,
        uptime_seconds=device.uptime_seconds,
        last_seen=device.last_seen,
        latitude=device.latitude,
        longitude=device.longitude,
        status=device.status,
        tls_mode=device.tls_mode,
        tags=tags,
        groups=groups,
        created_at=device.created_at,
    )
|
||||
|
||||
|
||||
def _device_with_relations():
    """Base select() for Device with tag and group relationships eagerly loaded."""
    tag_loader = selectinload(Device.tag_assignments).selectinload(DeviceTagAssignment.tag)
    group_loader = selectinload(Device.group_memberships).selectinload(
        DeviceGroupMembership.group
    )
    return select(Device).options(tag_loader, group_loader)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Device CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def create_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    data: DeviceCreate,
    encryption_key: bytes,
) -> DeviceResponse:
    """
    Create a new device.

    - Validates TCP connectivity (api_port or api_ssl_port must be reachable).
    - Encrypts credentials via OpenBao Transit before storage.
    - Status set to "unknown" until the Go poller runs a full auth check (Phase 2).

    Args:
        db: RLS-scoped async session.
        tenant_id: Owning tenant UUID.
        data: Validated device payload (hostname, IP, ports, credentials).
        encryption_key: Legacy AES key — unused for new writes (Transit is
            canonical) but kept for interface compatibility with callers.

    Raises:
        HTTPException: 422 when neither RouterOS API port is reachable.
    """
    # Probe both ports concurrently — the checks are independent, and running
    # them sequentially doubles the worst-case wait (2 × timeout) for an
    # unreachable host.
    api_reachable, ssl_reachable = await asyncio.gather(
        _tcp_reachable(data.ip_address, data.api_port),
        _tcp_reachable(data.ip_address, data.api_ssl_port),
    )

    if not api_reachable and not ssl_reachable:
        from fastapi import HTTPException, status
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(
                f"Cannot reach {data.ip_address} on port {data.api_port} "
                f"(RouterOS API) or {data.api_ssl_port} (RouterOS SSL API). "
                "Verify the IP address and that the RouterOS API is enabled."
            ),
        )

    # Encrypt credentials via OpenBao Transit (new writes go through Transit)
    credentials_json = json.dumps({"username": data.username, "password": data.password})
    transit_ciphertext = await encrypt_credentials_transit(
        credentials_json, str(tenant_id)
    )

    device = Device(
        tenant_id=tenant_id,
        hostname=data.hostname,
        ip_address=data.ip_address,
        api_port=data.api_port,
        api_ssl_port=data.api_ssl_port,
        encrypted_credentials_transit=transit_ciphertext,
        status="unknown",
    )
    db.add(device)
    await db.flush()  # Get the ID without committing
    await db.refresh(device)

    # Re-query with relationships loaded so the response includes tags/groups
    # (empty for a brand-new device, but keeps the response shape uniform).
    result = await db.execute(
        _device_with_relations().where(Device.id == device.id)
    )
    device = result.scalar_one()
    return _build_device_response(device)
|
||||
|
||||
|
||||
async def get_devices(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    page: int = 1,
    page_size: int = 25,
    status: Optional[str] = None,
    search: Optional[str] = None,
    tag_id: Optional[uuid.UUID] = None,
    group_id: Optional[uuid.UUID] = None,
    sort_by: str = "created_at",
    sort_order: str = "desc",
) -> tuple[list[DeviceResponse], int]:
    """
    Return a paginated list of devices with optional filtering and sorting.

    Args:
        db: RLS-scoped async session (tenant isolation enforced by PostgreSQL).
        tenant_id: Caller's tenant UUID (isolation itself comes from RLS).
        page: 1-based page number.
        page_size: Items per page.
        status: Optional exact-match status filter.
        search: Optional case-insensitive substring match on hostname/IP.
        tag_id: Optional filter — only devices carrying this tag.
        group_id: Optional filter — only devices in this group.
        sort_by: One of the allow-listed column names; unknown values fall
            back to created_at.
        sort_order: "asc" or anything else (treated as "desc").

    Returns (items, total_count).
    RLS automatically scopes this to the caller's tenant.
    """
    base_q = _device_with_relations()

    # Filtering
    if status:
        base_q = base_q.where(Device.status == status)

    if search:
        # ilike with surrounding wildcards: case-insensitive substring match.
        pattern = f"%{search}%"
        base_q = base_q.where(
            or_(
                Device.hostname.ilike(pattern),
                Device.ip_address.ilike(pattern),
            )
        )

    if tag_id:
        # Membership expressed as an IN-subquery on the association table.
        base_q = base_q.where(
            Device.id.in_(
                select(DeviceTagAssignment.device_id).where(
                    DeviceTagAssignment.tag_id == tag_id
                )
            )
        )

    if group_id:
        base_q = base_q.where(
            Device.id.in_(
                select(DeviceGroupMembership.device_id).where(
                    DeviceGroupMembership.group_id == group_id
                )
            )
        )

    # Count total before pagination
    count_q = select(func.count()).select_from(base_q.subquery())
    total_result = await db.execute(count_q)
    total = total_result.scalar_one()

    # Sorting — allow-list prevents arbitrary column injection via sort_by.
    allowed_sort_cols = {
        "created_at": Device.created_at,
        "hostname": Device.hostname,
        "ip_address": Device.ip_address,
        "status": Device.status,
        "last_seen": Device.last_seen,
    }
    sort_col = allowed_sort_cols.get(sort_by, Device.created_at)
    if sort_order.lower() == "asc":
        base_q = base_q.order_by(sort_col.asc())
    else:
        base_q = base_q.order_by(sort_col.desc())

    # Pagination
    offset = (page - 1) * page_size
    base_q = base_q.offset(offset).limit(page_size)

    result = await db.execute(base_q)
    devices = result.scalars().all()
    return [_build_device_response(d) for d in devices], total
|
||||
|
||||
|
||||
async def get_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
) -> DeviceResponse:
    """Fetch a single device by ID; raises 404 when absent (lookup is RLS-scoped)."""
    from fastapi import HTTPException, status

    row = await db.execute(_device_with_relations().where(Device.id == device_id))
    found = row.scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    return _build_device_response(found)
|
||||
|
||||
|
||||
async def update_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    data: DeviceUpdate,
    encryption_key: bytes,
) -> DeviceResponse:
    """
    Update device fields. Re-encrypts credentials only if password is provided.

    Credential handling (dual-read, Transit-canonical writes):
    - password provided: re-encrypt {username, password} via Transit and clear
      the legacy blob.
    - only username provided: decrypt the existing blob, swap the username,
      re-encrypt via Transit.
    On any credential change a fire-and-forget NATS event notifies the poller.

    NOTE(review): the ``encryption_key`` parameter is unused — legacy decryption
    uses settings.get_encryption_key_bytes() instead. Confirm whether callers
    still need to pass it.

    Raises:
        HTTPException: 404 when the device does not exist (within RLS scope).
    """
    from fastapi import HTTPException, status

    result = await db.execute(
        _device_with_relations().where(Device.id == device_id)
    )
    device = result.scalar_one_or_none()
    if not device:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")

    # Update scalar fields — None means "not provided", so a field cannot be
    # cleared to NULL through this endpoint.
    if data.hostname is not None:
        device.hostname = data.hostname
    if data.ip_address is not None:
        device.ip_address = data.ip_address
    if data.api_port is not None:
        device.api_port = data.api_port
    if data.api_ssl_port is not None:
        device.api_ssl_port = data.api_ssl_port
    if data.latitude is not None:
        device.latitude = data.latitude
    if data.longitude is not None:
        device.longitude = data.longitude
    if data.tls_mode is not None:
        device.tls_mode = data.tls_mode

    # Re-encrypt credentials if new ones are provided
    credentials_changed = False
    if data.password is not None:
        # Decrypt existing to get current username if no new username given
        current_username: str = data.username or ""
        if not current_username and (device.encrypted_credentials_transit or device.encrypted_credentials):
            try:
                existing_json = await decrypt_credentials_hybrid(
                    device.encrypted_credentials_transit,
                    device.encrypted_credentials,
                    str(device.tenant_id),
                    settings.get_encryption_key_bytes(),
                )
                existing = json.loads(existing_json)
                current_username = existing.get("username", "")
            except Exception:
                # Best-effort: if the old blob is unreadable, store an empty
                # username rather than failing the whole update.
                current_username = ""

        credentials_json = json.dumps({
            "username": data.username if data.username is not None else current_username,
            "password": data.password,
        })
        # New writes go through Transit
        device.encrypted_credentials_transit = await encrypt_credentials_transit(
            credentials_json, str(device.tenant_id)
        )
        device.encrypted_credentials = None  # Clear legacy (Transit is canonical)
        credentials_changed = True
    elif data.username is not None and (device.encrypted_credentials_transit or device.encrypted_credentials):
        # Only username changed — update it without changing the password
        try:
            existing_json = await decrypt_credentials_hybrid(
                device.encrypted_credentials_transit,
                device.encrypted_credentials,
                str(device.tenant_id),
                settings.get_encryption_key_bytes(),
            )
            existing = json.loads(existing_json)
            existing["username"] = data.username
            # Re-encrypt via Transit
            device.encrypted_credentials_transit = await encrypt_credentials_transit(
                json.dumps(existing), str(device.tenant_id)
            )
            device.encrypted_credentials = None
            credentials_changed = True
        except Exception:
            pass  # Keep existing encrypted blob if decryption fails

    await db.flush()
    await db.refresh(device)

    # Notify poller to invalidate cached credentials (fire-and-forget via NATS)
    if credentials_changed:
        try:
            from app.services.event_publisher import publish_event
            await publish_event(
                f"device.credential_changed.{device_id}",
                {"device_id": str(device_id), "tenant_id": str(tenant_id)},
            )
        except Exception:
            pass  # Never fail the update due to NATS issues

    # Reload with relationships eagerly loaded for the response payload.
    result2 = await db.execute(
        _device_with_relations().where(Device.id == device_id)
    )
    device = result2.scalar_one()
    return _build_device_response(device)
|
||||
|
||||
|
||||
async def delete_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
) -> None:
    """Hard-delete a device (v1 — no soft delete for devices)."""
    from fastapi import HTTPException, status

    row = await db.execute(select(Device).where(Device.id == device_id))
    target = row.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    await db.delete(target)
    await db.flush()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group / Tag assignment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def assign_device_to_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    group_id: uuid.UUID,
) -> None:
    """Assign a device to a group (idempotent)."""
    from fastapi import HTTPException, status

    # Both existence checks are tenant-scoped by RLS.
    if await db.get(Device, device_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    if await db.get(DeviceGroup, group_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")

    # Only insert when the membership row is not already present.
    if await db.get(DeviceGroupMembership, (device_id, group_id)) is None:
        db.add(DeviceGroupMembership(device_id=device_id, group_id=group_id))
        await db.flush()
|
||||
|
||||
|
||||
async def remove_device_from_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    group_id: uuid.UUID,
) -> None:
    """Remove a device from a group; 404 when no such membership exists."""
    from fastapi import HTTPException, status

    link = await db.get(DeviceGroupMembership, (device_id, group_id))
    if link is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Device is not in this group",
        )
    await db.delete(link)
    await db.flush()
|
||||
|
||||
|
||||
async def assign_tag_to_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    tag_id: uuid.UUID,
) -> None:
    """Assign a tag to a device (idempotent)."""
    from fastapi import HTTPException, status

    # Both existence checks are tenant-scoped by RLS.
    if await db.get(Device, device_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    if await db.get(DeviceTag, tag_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tag not found")

    # Only insert when the assignment row is not already present.
    if await db.get(DeviceTagAssignment, (device_id, tag_id)) is None:
        db.add(DeviceTagAssignment(device_id=device_id, tag_id=tag_id))
        await db.flush()
|
||||
|
||||
|
||||
async def remove_tag_from_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    tag_id: uuid.UUID,
) -> None:
    """Remove a tag from a device; 404 when the tag is not assigned."""
    from fastapi import HTTPException, status

    link = await db.get(DeviceTagAssignment, (device_id, tag_id))
    if link is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tag is not assigned to this device",
        )
    await db.delete(link)
    await db.flush()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DeviceGroup CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def create_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    data: DeviceGroupCreate,
) -> DeviceGroupResponse:
    """Create a new device group; a freshly created group has zero devices."""
    new_group = DeviceGroup(
        tenant_id=tenant_id,
        name=data.name,
        description=data.description,
    )
    db.add(new_group)
    await db.flush()
    await db.refresh(new_group)

    return DeviceGroupResponse(
        id=new_group.id,
        name=new_group.name,
        description=new_group.description,
        device_count=0,
        created_at=new_group.created_at,
    )
|
||||
|
||||
|
||||
async def get_groups(
    db: AsyncSession,
    tenant_id: uuid.UUID,
) -> list[DeviceGroupResponse]:
    """Return all device groups for the current tenant with device counts."""
    rows = await db.execute(
        select(DeviceGroup).options(selectinload(DeviceGroup.memberships))
    )
    responses: list[DeviceGroupResponse] = []
    for grp in rows.scalars().all():
        responses.append(
            DeviceGroupResponse(
                id=grp.id,
                name=grp.name,
                description=grp.description,
                device_count=len(grp.memberships),
                created_at=grp.created_at,
            )
        )
    return responses
|
||||
|
||||
|
||||
async def update_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    group_id: uuid.UUID,
    data: DeviceGroupUpdate,
) -> DeviceGroupResponse:
    """Apply a partial update to a device group and return its fresh state."""
    from fastapi import HTTPException, status

    stmt = (
        select(DeviceGroup)
        .options(selectinload(DeviceGroup.memberships))
        .where(DeviceGroup.id == group_id)
    )
    group = (await db.execute(stmt)).scalar_one_or_none()
    if group is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")

    if data.name is not None:
        group.name = data.name
    if data.description is not None:
        group.description = data.description

    await db.flush()
    await db.refresh(group)

    # Reload with memberships eagerly loaded so device_count stays accurate.
    group = (await db.execute(stmt)).scalar_one()
    return DeviceGroupResponse(
        id=group.id,
        name=group.name,
        description=group.description,
        device_count=len(group.memberships),
        created_at=group.created_at,
    )
|
||||
|
||||
|
||||
async def delete_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    group_id: uuid.UUID,
) -> None:
    """Delete a device group; 404 when it does not exist."""
    from fastapi import HTTPException, status

    group = await db.get(DeviceGroup, group_id)
    if group is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")
    await db.delete(group)
    await db.flush()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DeviceTag CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def create_tag(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    data: DeviceTagCreate,
) -> DeviceTagResponse:
    """Create a new device tag."""
    new_tag = DeviceTag(tenant_id=tenant_id, name=data.name, color=data.color)
    db.add(new_tag)
    await db.flush()
    await db.refresh(new_tag)
    return DeviceTagResponse(id=new_tag.id, name=new_tag.name, color=new_tag.color)
|
||||
|
||||
|
||||
async def get_tags(
    db: AsyncSession,
    tenant_id: uuid.UUID,
) -> list[DeviceTagResponse]:
    """Return all device tags for the current tenant (RLS-scoped)."""
    rows = await db.execute(select(DeviceTag))
    return [
        DeviceTagResponse(id=tag.id, name=tag.name, color=tag.color)
        for tag in rows.scalars().all()
    ]
|
||||
|
||||
|
||||
async def update_tag(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    tag_id: uuid.UUID,
    data: DeviceTagUpdate,
) -> DeviceTagResponse:
    """Apply a partial update (name and/or color) to a device tag."""
    from fastapi import HTTPException, status

    tag = await db.get(DeviceTag, tag_id)
    if tag is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tag not found")

    # Copy only the fields that were actually provided.
    for attr in ("name", "color"):
        value = getattr(data, attr)
        if value is not None:
            setattr(tag, attr, value)

    await db.flush()
    await db.refresh(tag)
    return DeviceTagResponse(id=tag.id, name=tag.name, color=tag.color)
|
||||
|
||||
|
||||
async def delete_tag(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    tag_id: uuid.UUID,
) -> None:
    """Delete a device tag; 404 when it does not exist."""
    from fastapi import HTTPException, status

    tag = await db.get(DeviceTag, tag_id)
    if tag is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tag not found")
    await db.delete(tag)
    await db.flush()
|
||||
124
backend/app/services/email_service.py
Normal file
124
backend/app/services/email_service.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Unified email sending service.
|
||||
|
||||
All email sending (system emails, alert notifications) goes through this module.
|
||||
Supports TLS, STARTTLS, and plain SMTP. Handles Transit + legacy Fernet password decryption.
|
||||
"""
|
||||
|
||||
import logging
from dataclasses import dataclass, field
from email.message import EmailMessage
from typing import Optional

import aiosmtplib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class SMTPConfig:
    """SMTP connection configuration.

    Field order matches the original positional constructor, so both positional
    and keyword callers keep working. ``password`` is excluded from repr() so
    the credential cannot leak through logging or error messages.
    """

    host: str  # SMTP server hostname
    port: int = 587
    user: Optional[str] = None  # SMTP auth username; None = no auth
    password: Optional[str] = field(default=None, repr=False)
    use_tls: bool = False  # True = implicit TLS (SMTPS); False lets send_email decide STARTTLS
    from_address: str = "noreply@example.com"
|
||||
|
||||
|
||||
async def send_email(
    to: str,
    subject: str,
    html: str,
    plain_text: str,
    smtp_config: SMTPConfig,
) -> None:
    """Send an email via SMTP.

    Args:
        to: Recipient email address.
        subject: Email subject line.
        html: HTML body.
        plain_text: Plain text fallback body.
        smtp_config: SMTP connection settings.

    Raises:
        aiosmtplib.SMTPException: On SMTP connection or send failure.
    """
    # Build a multipart/alternative message: plain text first, HTML as the
    # preferred alternative.
    msg = EmailMessage()
    msg["Subject"] = subject
    msg["From"] = smtp_config.from_address
    msg["To"] = to
    msg.set_content(plain_text)
    msg.add_alternative(html, subtype="html")

    # TLS selection: implicit TLS when use_tls is set; otherwise force STARTTLS
    # on any port except 25 (plain SMTP).
    # NOTE(review): this forces STARTTLS on non-25 ports even when the server
    # may not support it (e.g. plain relay on 2525) — confirm intended.
    use_tls = smtp_config.use_tls
    start_tls = not use_tls if smtp_config.port != 25 else False

    await aiosmtplib.send(
        msg,
        hostname=smtp_config.host,
        port=smtp_config.port,
        username=smtp_config.user or None,
        password=smtp_config.password or None,
        use_tls=use_tls,
        start_tls=start_tls,
    )
|
||||
|
||||
|
||||
async def test_smtp_connection(smtp_config: SMTPConfig) -> dict:
    """Probe SMTP connectivity (and auth, when credentials are set) without sending mail.

    Returns:
        dict with "success" bool and "message" string.
    """
    # Mirror send_email's TLS selection: implicit TLS when configured,
    # otherwise force STARTTLS on any port except 25.
    starttls = False if smtp_config.port == 25 else not smtp_config.use_tls
    try:
        client = aiosmtplib.SMTP(
            hostname=smtp_config.host,
            port=smtp_config.port,
            use_tls=smtp_config.use_tls,
            start_tls=starttls,
        )
        await client.connect()
        if smtp_config.user and smtp_config.password:
            await client.login(smtp_config.user, smtp_config.password)
        await client.quit()
    except Exception as exc:
        return {"success": False, "message": str(exc)}
    return {"success": True, "message": "SMTP connection successful"}
|
||||
|
||||
|
||||
async def send_test_email(to: str, smtp_config: SMTPConfig) -> dict:
    """Send a test email to verify the full SMTP flow.

    Args:
        to: Recipient address for the test message.
        smtp_config: SMTP connection settings to exercise.

    Returns:
        dict with "success" bool and "message" string.
    """
    # Fixed branded HTML body; any change here is user-visible in the test mail.
    html = """
    <div style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; max-width: 600px; margin: 0 auto;">
      <div style="background: #0f172a; padding: 24px; border-radius: 8px 8px 0 0;">
        <h2 style="color: #38bdf8; margin: 0;">TOD — Email Test</h2>
      </div>
      <div style="background: #1e293b; padding: 24px; border-radius: 0 0 8px 8px; color: #e2e8f0;">
        <p>This is a test email from The Other Dude.</p>
        <p>If you're reading this, your SMTP configuration is working correctly.</p>
        <p style="color: #94a3b8; font-size: 13px; margin-top: 24px;">
          Sent from TOD Fleet Management
        </p>
      </div>
    </div>
    """
    plain = "TOD — Email Test\n\nThis is a test email from The Other Dude.\nIf you're reading this, your SMTP configuration is working correctly."

    # send_email raises on any SMTP failure; collapse that into a result dict
    # suitable for returning straight to the API caller.
    try:
        await send_email(to, "TOD — Test Email", html, plain, smtp_config)
        return {"success": True, "message": f"Test email sent to {to}"}
    except Exception as e:
        return {"success": False, "message": str(e)}
|
||||
54
backend/app/services/emergency_kit_service.py
Normal file
54
backend/app/services/emergency_kit_service.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Emergency Kit PDF template generation.
|
||||
|
||||
Generates an Emergency Kit PDF containing the user's email and sign-in URL
|
||||
but NOT the Secret Key. The Secret Key placeholder is filled client-side
|
||||
so that the server never sees it.
|
||||
|
||||
Uses Jinja2 + WeasyPrint following the same pattern as the reports service.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
from app.config import settings
|
||||
|
||||
TEMPLATE_DIR = Path(__file__).parent.parent.parent / "templates"
|
||||
|
||||
|
||||
async def generate_emergency_kit_template(
    email: str,
) -> bytes:
    """Generate the Emergency Kit PDF template WITHOUT the Secret Key.

    The Secret Key placeholder is substituted client-side, so the server
    never sees the real key.

    Args:
        email: The user's email address to display in the PDF.

    Returns:
        PDF bytes ready for a streaming response.
    """
    jinja_env = Environment(
        loader=FileSystemLoader(str(TEMPLATE_DIR)),
        autoescape=True,
    )
    rendered = jinja_env.get_template("emergency_kit.html").render(
        email=email,
        signin_url=settings.APP_BASE_URL,
        date=datetime.now(UTC).strftime("%Y-%m-%d"),
        secret_key_placeholder="[Download complete -- your Secret Key will be inserted by your browser]",
    )

    # WeasyPrint rendering is CPU-bound; run it in a worker thread so the
    # event loop stays responsive.
    from weasyprint import HTML

    def _render_pdf() -> bytes:
        return HTML(string=rendered).write_pdf()

    return await asyncio.to_thread(_render_pdf)
|
||||
52
backend/app/services/event_publisher.py
Normal file
52
backend/app/services/event_publisher.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Fire-and-forget NATS JetStream event publisher for real-time SSE pipeline.
|
||||
|
||||
Provides a shared lazy NATS connection and publish helper used by:
|
||||
- alert_evaluator.py (alert.fired.{tenant_id}, alert.resolved.{tenant_id})
|
||||
- restore_service.py (config.push.{tenant_id}.{device_id})
|
||||
- upgrade_service.py (firmware.progress.{tenant_id}.{device_id})
|
||||
|
||||
All publishes are fire-and-forget: errors are logged but never propagate
|
||||
to the caller. A NATS outage must never block alert evaluation, config
|
||||
push, or firmware upgrade operations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import nats
|
||||
import nats.aio.client
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level NATS connection (lazy initialized, reused across publishes)
|
||||
_nc: nats.aio.client.Client | None = None
|
||||
|
||||
|
||||
async def _get_nats() -> nats.aio.client.Client:
    """Return the shared publisher connection, (re)connecting lazily.

    The module-level client is reused across publishes; a fresh connection
    is established on first use or after the previous one was closed.
    """
    global _nc
    needs_connect = _nc is None or _nc.is_closed
    if needs_connect:
        _nc = await nats.connect(settings.NATS_URL)
        logger.info("Event publisher NATS connection established")
    return _nc
|
||||
|
||||
|
||||
async def publish_event(subject: str, payload: dict[str, Any]) -> None:
    """Publish a JSON event to a NATS JetStream subject (fire-and-forget).

    Any failure is logged at warning level and swallowed so callers
    (alert evaluation, config push, firmware upgrade) are never blocked
    by a NATS outage.

    Args:
        subject: NATS subject, e.g. "alert.fired.{tenant_id}".
        payload: Dict that will be JSON-serialized as the message body.
    """
    try:
        connection = await _get_nats()
        body = json.dumps(payload).encode()
        await connection.jetstream().publish(subject, body)
        logger.debug("Published event to %s", subject)
    except Exception as exc:
        logger.warning("Failed to publish event to %s: %s", subject, exc)
|
||||
303
backend/app/services/firmware_service.py
Normal file
303
backend/app/services/firmware_service.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""Firmware version cache service and NPK downloader.
|
||||
|
||||
Responsibilities:
|
||||
- check_latest_versions(): fetch latest RouterOS versions from download.mikrotik.com
|
||||
- download_firmware(): download NPK packages to local PVC cache
|
||||
- get_firmware_overview(): return fleet firmware status for a tenant
|
||||
- schedule_firmware_checks(): register daily firmware check job with APScheduler
|
||||
|
||||
Version discovery comes from two sources:
|
||||
1. Go poller runs /system/package/update per device (rate-limited to once/day)
|
||||
and publishes via NATS -> firmware_subscriber processes these events
|
||||
2. check_latest_versions() fetches LATEST.7 / LATEST.6 from download.mikrotik.com
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Architectures supported by RouterOS v7 and v6
|
||||
_V7_ARCHITECTURES = ["arm", "arm64", "mipsbe", "mmips", "smips", "tile", "ppc", "x86"]
|
||||
_V6_ARCHITECTURES = ["mipsbe", "mmips", "smips", "tile", "ppc", "x86"]
|
||||
|
||||
# Version source files on download.mikrotik.com
|
||||
_VERSION_SOURCES = [
|
||||
("LATEST.7", "stable", 7),
|
||||
("LATEST.7long", "long-term", 7),
|
||||
("LATEST.6", "stable", 6),
|
||||
("LATEST.6long", "long-term", 6),
|
||||
]
|
||||
|
||||
|
||||
async def check_latest_versions() -> list[dict]:
    """Fetch latest RouterOS versions from download.mikrotik.com.

    Checks LATEST.7, LATEST.7long, LATEST.6, and LATEST.6long files for
    version strings, then upserts into firmware_versions table for each
    architecture/channel combination.

    Each source file is assumed to contain a version string on its first
    line (e.g. "7.15.2 ..."); only the leading-digit sanity check below
    validates it.

    Returns list of discovered version dicts with keys:
    architecture, channel, version, npk_url.
    """
    results: list[dict] = []

    async with httpx.AsyncClient(timeout=30.0) as client:
        for channel_file, channel, major in _VERSION_SOURCES:
            # Per-source failures are non-fatal: a bad/unreachable source
            # is logged and skipped so the remaining channels still update.
            try:
                resp = await client.get(
                    f"https://download.mikrotik.com/routeros/{channel_file}"
                )
                if resp.status_code != 200:
                    logger.warning(
                        "MikroTik version check returned %d for %s",
                        resp.status_code, channel_file,
                    )
                    continue

                version = resp.text.strip()
                # Reject empty bodies or HTML error pages masquerading as 200s.
                if not version or not version[0].isdigit():
                    logger.warning("Invalid version string from %s: %r", channel_file, version)
                    continue

                # v6 supports fewer architectures than v7.
                architectures = _V7_ARCHITECTURES if major == 7 else _V6_ARCHITECTURES
                for arch in architectures:
                    npk_url = (
                        f"https://download.mikrotik.com/routeros/"
                        f"{version}/routeros-{version}-{arch}.npk"
                    )
                    results.append({
                        "architecture": arch,
                        "channel": channel,
                        "version": version,
                        "npk_url": npk_url,
                    })

            except Exception as e:
                logger.warning("Failed to check %s: %s", channel_file, e)

    # Upsert into firmware_versions table (admin session: cross-tenant data,
    # no RLS tenant context applies here).
    if results:
        async with AdminAsyncSessionLocal() as session:
            for r in results:
                await session.execute(
                    text("""
                        INSERT INTO firmware_versions (id, architecture, channel, version, npk_url, checked_at)
                        VALUES (gen_random_uuid(), :arch, :channel, :version, :npk_url, NOW())
                        ON CONFLICT (architecture, channel, version) DO UPDATE SET checked_at = NOW()
                    """),
                    {
                        "arch": r["architecture"],
                        "channel": r["channel"],
                        "version": r["version"],
                        "npk_url": r["npk_url"],
                    },
                )
            await session.commit()

    logger.info("Firmware version check complete — %d versions discovered", len(results))
    return results
|
||||
|
||||
|
||||
async def download_firmware(architecture: str, channel: str, version: str) -> str:
    """Download an NPK package to the local firmware cache.

    Streams the package from download.mikrotik.com into
    {FIRMWARE_CACHE_DIR}/{version}/routeros-{version}-{architecture}.npk,
    then records the local path and size in the firmware_versions table.
    Skips the download if the file already exists with a non-zero size.

    Args:
        architecture: CPU architecture (arm, arm64, mipsbe, ...).
        channel: Firmware channel ("stable", "long-term") — used only to
            locate the firmware_versions row, not to build the URL.
        version: RouterOS version string, e.g. "7.15.2".

    Returns:
        The local file path as a string.

    Raises:
        httpx.HTTPStatusError: If the download request returns an error status.
    """
    cache_dir = Path(settings.FIRMWARE_CACHE_DIR) / version
    cache_dir.mkdir(parents=True, exist_ok=True)

    filename = f"routeros-{version}-{architecture}.npk"
    local_path = cache_dir / filename
    # BUG FIX: the URL previously ended in the literal text "(unknown)"
    # instead of the NPK filename, so every download hit a 404. Use the
    # same {version}/{filename} layout as check_latest_versions().
    npk_url = f"https://download.mikrotik.com/routeros/{version}/{filename}"

    # Check if already cached (non-zero size guards against a previous
    # interrupted download leaving an empty file behind).
    if local_path.exists() and local_path.stat().st_size > 0:
        logger.info("Firmware already cached: %s", local_path)
        return str(local_path)

    logger.info("Downloading firmware: %s", npk_url)

    # Stream in 64 KiB chunks so large NPKs never sit fully in memory.
    async with httpx.AsyncClient(timeout=300.0) as client:
        async with client.stream("GET", npk_url) as response:
            response.raise_for_status()
            with open(local_path, "wb") as f:
                async for chunk in response.aiter_bytes(chunk_size=65536):
                    f.write(chunk)

    file_size = local_path.stat().st_size
    logger.info("Firmware downloaded: %s (%d bytes)", local_path, file_size)

    # Update firmware_versions table with local path and size
    async with AdminAsyncSessionLocal() as session:
        await session.execute(
            text("""
                UPDATE firmware_versions
                SET npk_local_path = :path, npk_size_bytes = :size
                WHERE architecture = :arch AND channel = :channel AND version = :version
            """),
            {
                "path": str(local_path),
                "size": file_size,
                "arch": architecture,
                "channel": channel,
                "version": version,
            },
        )
        await session.commit()

    return str(local_path)
|
||||
|
||||
|
||||
async def get_firmware_overview(tenant_id: str) -> dict:
    """Return fleet firmware status for a tenant.

    Returns devices grouped by firmware version, annotated with up-to-date status
    based on the latest known version for each device's architecture and preferred channel.

    Args:
        tenant_id: Tenant UUID as a string (cast to uuid in SQL).

    Returns:
        Dict with keys:
            devices: flat list of per-device status dicts.
            version_groups: devices grouped by routeros_version, sorted by version key.
            summary: counts {total, up_to_date, outdated, unknown}.
    """
    async with AdminAsyncSessionLocal() as session:
        # Get all devices for tenant
        devices_result = await session.execute(
            text("""
                SELECT id, hostname, ip_address, routeros_version, architecture,
                       preferred_channel, routeros_major_version,
                       serial_number, firmware_version, model
                FROM devices
                WHERE tenant_id = CAST(:tenant_id AS uuid)
                ORDER BY hostname
            """),
            {"tenant_id": tenant_id},
        )
        devices = devices_result.fetchall()

        # Get latest firmware versions per architecture/channel.
        # DISTINCT ON + ORDER BY checked_at DESC keeps only the most
        # recently checked row per (architecture, channel) pair.
        versions_result = await session.execute(
            text("""
                SELECT DISTINCT ON (architecture, channel)
                       architecture, channel, version, npk_url
                FROM firmware_versions
                ORDER BY architecture, channel, checked_at DESC
            """)
        )
        latest_versions = {
            (row[0], row[1]): {"version": row[2], "npk_url": row[3]}
            for row in versions_result.fetchall()
        }

    # Build per-device status
    device_list = []
    version_groups: dict[str, list] = {}
    summary = {"total": 0, "up_to_date": 0, "outdated": 0, "unknown": 0}

    for dev in devices:
        # Positional indices follow the SELECT column order above.
        dev_id = str(dev[0])            # id
        hostname = dev[1]               # hostname
        current_version = dev[3]        # routeros_version
        arch = dev[4]                   # architecture
        channel = dev[5] or "stable"    # preferred_channel, default "stable"

        latest = latest_versions.get((arch, channel)) if arch else None
        latest_version = latest["version"] if latest else None

        # Classification: unknown (missing version/arch) takes precedence;
        # a device with no known "latest" for its arch/channel counts as outdated.
        is_up_to_date = False
        if not current_version or not arch:
            summary["unknown"] += 1
        elif latest_version and current_version == latest_version:
            is_up_to_date = True
            summary["up_to_date"] += 1
        else:
            summary["outdated"] += 1

        summary["total"] += 1

        dev_info = {
            "id": dev_id,
            "hostname": hostname,
            "ip_address": dev[2],
            "routeros_version": current_version,
            "architecture": arch,
            "latest_version": latest_version,
            "channel": channel,
            "is_up_to_date": is_up_to_date,
            "serial_number": dev[7],
            "firmware_version": dev[8],
            "model": dev[9],
        }
        device_list.append(dev_info)

        # Group by version
        ver_key = current_version or "unknown"
        if ver_key not in version_groups:
            version_groups[ver_key] = []
        version_groups[ver_key].append(dev_info)

    # Build version groups with is_latest flag
    groups = []
    for ver, devs in sorted(version_groups.items()):
        # A version is "latest" if it matches the latest for any arch/channel combo
        is_latest = any(
            v["version"] == ver for v in latest_versions.values()
        )
        groups.append({
            "version": ver,
            "count": len(devs),
            "is_latest": is_latest,
            "devices": devs,
        })

    return {
        "devices": device_list,
        "version_groups": groups,
        "summary": summary,
    }
|
||||
|
||||
|
||||
async def get_cached_firmware() -> list[dict]:
    """Enumerate locally cached NPK packages.

    Scans {FIRMWARE_CACHE_DIR}/{version}/*.npk and returns one dict per
    file with path, version (directory name), filename and size_bytes.
    An empty list is returned when the cache directory does not exist.
    """
    root = Path(settings.FIRMWARE_CACHE_DIR)
    if not root.exists():
        return []

    entries: list[dict] = []
    for version_dir in sorted(root.iterdir()):
        if not version_dir.is_dir():
            continue
        for pkg in sorted(version_dir.iterdir()):
            if pkg.suffix != ".npk":
                continue
            entries.append({
                "path": str(pkg),
                "version": version_dir.name,
                "filename": pkg.name,
                "size_bytes": pkg.stat().st_size,
            })

    return entries
|
||||
|
||||
|
||||
def schedule_firmware_checks() -> None:
    """Register the daily RouterOS firmware version check.

    Called from FastAPI lifespan startup; schedules check_latest_versions()
    on the shared backup scheduler at 03:00 UTC every day.
    """
    from apscheduler.triggers.cron import CronTrigger

    from app.services.backup_scheduler import backup_scheduler

    daily_3am_utc = CronTrigger(hour=3, minute=0, timezone="UTC")
    backup_scheduler.add_job(
        check_latest_versions,
        trigger=daily_3am_utc,
        id="firmware_version_check",
        name="Check for new RouterOS firmware versions",
        max_instances=1,
        replace_existing=True,
    )

    logger.info("Firmware version check scheduled — daily at 3am UTC")
|
||||
206
backend/app/services/firmware_subscriber.py
Normal file
206
backend/app/services/firmware_subscriber.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""NATS JetStream subscriber for device firmware events from the Go poller.
|
||||
|
||||
Subscribes to device.firmware.> and:
|
||||
1. Updates devices.routeros_version and devices.architecture from poller data
|
||||
2. Upserts firmware_versions table with latest version per architecture/channel
|
||||
|
||||
Uses AdminAsyncSessionLocal (superuser bypass RLS) so firmware data from any
|
||||
tenant can be written without setting app.current_tenant.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import nats
|
||||
from nats.js import JetStreamContext
|
||||
from nats.aio.client import Client as NATSClient
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_firmware_client: Optional[NATSClient] = None
|
||||
|
||||
|
||||
async def on_device_firmware(msg) -> None:
    """Handle a device.firmware event published by the Go poller.

    Payload (JSON):
        device_id (str) -- UUID of the device
        tenant_id (str) -- UUID of the owning tenant
        installed_version (str) -- currently installed RouterOS version
        latest_version (str) -- latest available version (may be empty)
        channel (str) -- firmware channel ("stable", "long-term")
        status (str) -- "New version is available", etc.
        architecture (str) -- CPU architecture (arm, arm64, mipsbe, etc.)

    Acks the message both on success and on unusable payloads (missing
    device_id); naks on processing errors so JetStream redelivers.
    """
    try:
        data = json.loads(msg.data)
        device_id = data.get("device_id")
        # tenant_id is part of the documented payload but is not used here:
        # writes go through the admin session, which bypasses RLS.
        tenant_id = data.get("tenant_id")
        architecture = data.get("architecture")
        installed_version = data.get("installed_version")
        latest_version = data.get("latest_version")
        channel = data.get("channel", "stable")

        if not device_id:
            # Nothing actionable without a device reference — ack and drop.
            logger.warning("device.firmware event missing device_id — skipping")
            await msg.ack()
            return

        async with AdminAsyncSessionLocal() as session:
            # Update device routeros_version and architecture from poller data.
            # COALESCE keeps the existing column value when the payload field is null.
            if architecture or installed_version:
                await session.execute(
                    text("""
                        UPDATE devices
                        SET routeros_version = COALESCE(:installed_ver, routeros_version),
                            architecture = COALESCE(:architecture, architecture),
                            updated_at = NOW()
                        WHERE id = CAST(:device_id AS uuid)
                    """),
                    {
                        "installed_ver": installed_version,
                        "architecture": architecture,
                        "device_id": device_id,
                    },
                )

            # Upsert firmware_versions if we got latest version info
            if latest_version and architecture:
                npk_url = (
                    f"https://download.mikrotik.com/routeros/"
                    f"{latest_version}/routeros-{latest_version}-{architecture}.npk"
                )
                await session.execute(
                    text("""
                        INSERT INTO firmware_versions (id, architecture, channel, version, npk_url, checked_at)
                        VALUES (gen_random_uuid(), :arch, :channel, :version, :url, NOW())
                        ON CONFLICT (architecture, channel, version) DO UPDATE SET checked_at = NOW()
                    """),
                    {
                        "arch": architecture,
                        "channel": channel,
                        "version": latest_version,
                        "url": npk_url,
                    },
                )

            await session.commit()

        logger.debug(
            "device.firmware processed",
            extra={
                "device_id": device_id,
                "architecture": architecture,
                "installed": installed_version,
                "latest": latest_version,
            },
        )
        await msg.ack()

    except Exception as exc:
        logger.error(
            "Failed to process device.firmware event: %s",
            exc,
            exc_info=True,
        )
        # Negative-ack for redelivery; best-effort — the connection may be gone.
        try:
            await msg.nak()
        except Exception:
            pass
|
||||
|
||||
|
||||
async def _subscribe_with_retry(js: JetStreamContext) -> None:
    """Subscribe to device.firmware.> with a durable consumer.

    The DEVICE_EVENTS stream may not exist yet at startup; retry every
    5 seconds for up to ~30 seconds, then give up so the API keeps
    running without firmware updates.
    """
    max_attempts = 6  # ~30 seconds at 5s intervals
    for attempt in range(1, max_attempts + 1):
        try:
            await js.subscribe(
                "device.firmware.>",
                cb=on_device_firmware,
                durable="api-firmware-consumer",
                stream="DEVICE_EVENTS",
            )
        except Exception as exc:
            if attempt >= max_attempts:
                logger.warning(
                    "NATS: giving up on device.firmware.> after %d attempts: %s — API will run without firmware updates",
                    max_attempts,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready for firmware (attempt %d/%d): %s — retrying in 5s",
                attempt,
                max_attempts,
                exc,
            )
            await asyncio.sleep(5)
        else:
            logger.info(
                "NATS: subscribed to device.firmware.> (durable: api-firmware-consumer)"
            )
            return
|
||||
|
||||
|
||||
async def start_firmware_subscriber() -> Optional[NATSClient]:
    """Connect to NATS and subscribe to device.firmware.> events.

    Uses a dedicated connection, separate from the status and metrics
    subscribers, with unlimited reconnect attempts.

    Returns the NATS client; hand it to stop_firmware_subscriber() on
    shutdown. Raises on fatal connection errors.
    """
    global _firmware_client

    logger.info("NATS firmware: connecting to %s", settings.NATS_URL)

    connection = await nats.connect(
        settings.NATS_URL,
        max_reconnect_attempts=-1,
        reconnect_time_wait=2,
        error_cb=_on_error,
        reconnected_cb=_on_reconnected,
        disconnected_cb=_on_disconnected,
    )

    logger.info("NATS firmware: connected to %s", settings.NATS_URL)

    await _subscribe_with_retry(connection.jetstream())

    _firmware_client = connection
    return connection
|
||||
|
||||
|
||||
async def stop_firmware_subscriber(nc: Optional[NATSClient]) -> None:
    """Drain and close the firmware NATS connection gracefully.

    drain() flushes pending messages before closing; if it fails, fall
    back to a best-effort close(). Accepts None (subscriber never started).
    """
    if nc is None:
        return
    try:
        logger.info("NATS firmware: draining connection...")
        await nc.drain()
        logger.info("NATS firmware: connection closed")
    except Exception as exc:
        logger.warning("NATS firmware: error during drain: %s", exc)
        try:
            await nc.close()
        except Exception:
            # Connection already gone — nothing more to do.
            pass
|
||||
|
||||
|
||||
async def _on_error(exc: Exception) -> None:
    """NATS error callback: log errors raised on the firmware connection."""
    logger.error("NATS firmware error: %s", exc)
|
||||
|
||||
|
||||
async def _on_reconnected() -> None:
    """NATS reconnect callback: note that the firmware connection recovered."""
    logger.info("NATS firmware: reconnected")
|
||||
|
||||
|
||||
async def _on_disconnected() -> None:
    """NATS disconnect callback: warn that the firmware connection dropped."""
    logger.warning("NATS firmware: disconnected")
|
||||
296
backend/app/services/git_store.py
Normal file
296
backend/app/services/git_store.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""pygit2-based git store for versioned config backup storage.
|
||||
|
||||
All functions in this module are synchronous (pygit2 is C bindings over libgit2).
|
||||
Callers running in an async context MUST wrap calls in:
|
||||
loop.run_in_executor(None, func, *args)
|
||||
or:
|
||||
asyncio.get_event_loop().run_in_executor(None, func, *args)
|
||||
|
||||
See Pitfall 3 in 04-RESEARCH.md — blocking pygit2 in async context stalls
|
||||
the event loop and causes timeouts for other concurrent requests.
|
||||
|
||||
Git layout:
|
||||
{GIT_STORE_PATH}/{tenant_id}.git/ <- bare repo per tenant
|
||||
objects/ refs/ HEAD <- standard bare git structure
|
||||
{device_id}/ <- device subtree
|
||||
export.rsc <- text export (/export compact)
|
||||
backup.bin <- binary system backup
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pygit2
|
||||
|
||||
from app.config import settings
|
||||
|
||||
# =========================================================================
|
||||
# Per-tenant mutex to prevent TreeBuilder race condition (Pitfall 5 in RESEARCH.md).
|
||||
# Two simultaneous backups for different devices in the same tenant repo would
|
||||
# each read HEAD, build their own device subtrees, and write conflicting root
|
||||
# trees. The second commit would lose the first's device subtree.
|
||||
# Lock scope is the entire tenant repo — not just the device.
|
||||
# =========================================================================
|
||||
_tenant_locks: dict[str, threading.Lock] = {}
|
||||
_tenant_locks_guard = threading.Lock()
|
||||
|
||||
|
||||
def _get_tenant_lock(tenant_id: str) -> threading.Lock:
    """Return the per-tenant commit lock, creating it on first use.

    Creation is serialized by _tenant_locks_guard so two threads cannot
    race to install different Lock objects for the same tenant.
    """
    with _tenant_locks_guard:
        return _tenant_locks.setdefault(tenant_id, threading.Lock())
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# PUBLIC API
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def get_or_create_repo(tenant_id: str) -> pygit2.Repository:
    """Open the tenant's bare git repo at {GIT_STORE_PATH}/{tenant_id}.git.

    Initializes a new bare repository on first use; the store root
    directory is created if it does not exist.

    Args:
        tenant_id: Tenant UUID as string.

    Returns:
        An open pygit2.Repository instance (bare).
    """
    store_root = Path(settings.GIT_STORE_PATH)
    store_root.mkdir(parents=True, exist_ok=True)

    repo_path = store_root / f"{tenant_id}.git"
    if not repo_path.exists():
        return pygit2.init_repository(str(repo_path), bare=True)
    return pygit2.Repository(str(repo_path))
|
||||
|
||||
|
||||
def commit_backup(
    tenant_id: str,
    device_id: str,
    export_text: str,
    binary_backup: bytes,
    message: str,
) -> str:
    """Write a backup pair (export.rsc + backup.bin) as a git commit.

    Creates or updates the device subdirectory in the tenant's bare repo.
    Preserves other devices' subdirectories by merging the device subtree
    into the existing root tree.

    Per-tenant locking (threading.Lock) prevents the TreeBuilder race
    condition when two devices in the same tenant back up concurrently:
    without it, both would read the same HEAD and the second commit would
    drop the first's device subtree.

    Args:
        tenant_id: Tenant UUID as string.
        device_id: Device UUID as string (becomes a subdirectory in the repo).
        export_text: Text output of /export compact.
        binary_backup: Raw bytes from /system backup save.
        message: Commit message (format: "{trigger}: {hostname} ({ip}) at {ts}").

    Returns:
        The hex commit SHA string (40 characters).
    """
    lock = _get_tenant_lock(tenant_id)

    with lock:
        repo = get_or_create_repo(tenant_id)

        # Create blobs from content
        export_oid = repo.create_blob(export_text.encode("utf-8"))
        binary_oid = repo.create_blob(binary_backup)

        # Build device subtree: {device_id}/export.rsc and {device_id}/backup.bin
        device_builder = repo.TreeBuilder()
        device_builder.insert("export.rsc", export_oid, pygit2.GIT_FILEMODE_BLOB)
        device_builder.insert("backup.bin", binary_oid, pygit2.GIT_FILEMODE_BLOB)
        device_tree_oid = device_builder.write()

        # Merge device subtree into root tree, preserving all other device subtrees.
        # If the repo has no commits yet, start with an empty root tree.
        root_ref = repo.references.get("refs/heads/main")
        parent_commit: Optional[pygit2.Commit] = None

        if root_ref is not None:
            try:
                parent_commit = repo.get(root_ref.target)
                root_builder = repo.TreeBuilder(parent_commit.tree)
            except Exception:
                # Unreadable HEAD commit — fall back to a fresh root tree
                # rather than failing the backup.
                root_builder = repo.TreeBuilder()
        else:
            root_builder = repo.TreeBuilder()

        # insert() replaces any existing entry for this device_id while
        # leaving the other entries of the (pre-seeded) root tree intact.
        root_builder.insert(device_id, device_tree_oid, pygit2.GIT_FILEMODE_TREE)
        root_tree_oid = root_builder.write()

        # Author signature — no real identity, portal service account
        author = pygit2.Signature("The Other Dude", "backup@tod.local")

        parents = [root_ref.target] if root_ref is not None else []

        commit_oid = repo.create_commit(
            "refs/heads/main",
            author,
            author,
            message,
            root_tree_oid,
            parents,
        )

        return str(commit_oid)
|
||||
|
||||
|
||||
def read_file(
    tenant_id: str,
    commit_sha: str,
    device_id: str,
    filename: str,
) -> bytes:
    """Read a file blob from a specific backup commit.

    Navigates the tree: root -> device_id subtree -> filename.

    Args:
        tenant_id: Tenant UUID as string.
        commit_sha: Full or abbreviated git commit SHA.
        device_id: Device UUID as string (subdirectory name in the repo).
        filename: File to read: "export.rsc" or "backup.bin".

    Returns:
        Raw bytes of the file content.

    Raises:
        KeyError: If the commit, device subtree, or filename is not found.
        pygit2.GitError: If commit_sha is malformed.
    """
    repo = get_or_create_repo(tenant_id)

    commit_obj = repo.get(commit_sha)
    if commit_obj is None:
        raise KeyError(f"Commit {commit_sha!r} not found in tenant {tenant_id!r} repo")

    # root tree -> device subtree -> file blob
    device_tree = repo.get(commit_obj.tree[device_id].id)
    file_blob = repo.get(device_tree[filename].id)
    return file_blob.data
|
||||
|
||||
|
||||
def list_device_commits(
    tenant_id: str,
    device_id: str,
) -> list[dict]:
    """Walk commit history and return commits that include the device subtree.

    Walks commits newest-first. Returns only commits where the device_id
    subtree is present in the root tree AND differs from the first parent's
    subtree (i.e. commits that actually changed this device's backup).

    Args:
        tenant_id: Tenant UUID as string.
        device_id: Device UUID as string.

    Returns:
        List of dicts (newest first):
            [{"sha": str, "message": str, "timestamp": int}, ...]
        Empty list if no commits or device has never been backed up.
    """
    repo = get_or_create_repo(tenant_id)

    # If there are no commits, return empty list immediately.
    # Use refs/heads/main explicitly rather than repo.head (which defaults to
    # refs/heads/master — wrong when the repo uses 'main' as the default branch).
    main_ref = repo.references.get("refs/heads/main")
    if main_ref is None:
        return []
    head_target = main_ref.target

    results = []
    # GIT_SORT_TIME walks by commit time, newest first.
    walker = repo.walk(head_target, pygit2.GIT_SORT_TIME)

    for commit in walker:
        # Check if device_id subtree exists in this commit's root tree.
        try:
            device_entry = commit.tree[device_id]
        except KeyError:
            # Device not present in this commit at all — skip.
            continue

        # Only include this commit if it actually changed the device's subtree
        # vs its parent. This prevents every subsequent backup (for any device
        # in the same tenant) from appearing in all devices' histories.
        if commit.parents:
            parent = commit.parents[0]
            try:
                parent_device_entry = parent.tree[device_id]
                if parent_device_entry.id == device_entry.id:
                    # Device subtree unchanged in this commit — skip.
                    continue
            except KeyError:
                # Device wasn't in parent but is in this commit — it's the first entry.
                pass

        results.append({
            "sha": str(commit.id),
            "message": commit.message.strip(),
            "timestamp": commit.commit_time,
        })

    return results
|
||||
|
||||
|
||||
def compute_line_delta(old_text: str, new_text: str) -> tuple[int, int]:
    """Count (lines_added, lines_removed) between two text versions.

    Tallies line counts from difflib.SequenceMatcher opcodes instead of
    rendering a full unified diff, which is cheaper on large config files.
    Pass old_text="" for the first backup to get (total_lines, 0).

    Args:
        old_text: Previous export.rsc content (empty string for first backup).
        new_text: New export.rsc content.

    Returns:
        Tuple of (lines_added, lines_removed).
    """
    before = old_text.splitlines() if old_text else []
    after = new_text.splitlines() if new_text else []

    # Degenerate cases: first backup counts every line as added; wiping
    # all content counts every line as removed. Covers empty-vs-empty too.
    if not before:
        return (len(after), 0)
    if not after:
        return (0, len(before))

    matcher = difflib.SequenceMatcher(None, before, after, autojunk=False)

    added = 0
    removed = 0
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        # "replace" contributes to both counters; "equal" to neither.
        if tag in ("replace", "delete"):
            removed += i2 - i1
        if tag in ("replace", "insert"):
            added += j2 - j1

    return (added, removed)
|
||||
324
backend/app/services/key_service.py
Normal file
324
backend/app/services/key_service.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""Key hierarchy management service for zero-knowledge architecture.
|
||||
|
||||
Provides CRUD operations for encrypted key bundles (UserKeySet),
|
||||
append-only audit logging (KeyAccessLog), and OpenBao Transit
|
||||
tenant key provisioning with credential migration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.key_set import KeyAccessLog, UserKeySet
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def store_user_key_set(
    db: AsyncSession,
    user_id: UUID,
    tenant_id: UUID | None,
    encrypted_private_key: bytes,
    private_key_nonce: bytes,
    encrypted_vault_key: bytes,
    vault_key_nonce: bytes,
    public_key: bytes,
    pbkdf2_salt: bytes,
    hkdf_salt: bytes,
    pbkdf2_iterations: int = 650000,
) -> UserKeySet:
    """Persist a user's encrypted key bundle during registration.

    Each user has exactly one key set (UNIQUE constraint on user_id),
    so any pre-existing row — e.g. left behind by a failed earlier
    upgrade attempt — is removed first.

    Args:
        db: Async database session.
        user_id: The user's UUID.
        tenant_id: The user's tenant UUID (None for super_admin).
        encrypted_private_key: RSA private key wrapped by AUK (AES-GCM).
        private_key_nonce: 12-byte AES-GCM nonce for private key.
        encrypted_vault_key: Tenant vault key wrapped by user's public key.
        vault_key_nonce: 12-byte AES-GCM nonce for vault key.
        public_key: RSA-2048 public key in SPKI format.
        pbkdf2_salt: 32-byte salt for PBKDF2 key derivation.
        hkdf_salt: 32-byte salt for HKDF Secret Key derivation.
        pbkdf2_iterations: PBKDF2 iteration count (default 650000).

    Returns:
        The created UserKeySet instance.
    """
    from sqlalchemy import delete

    # Clear any stale key set before inserting the fresh one.
    await db.execute(delete(UserKeySet).where(UserKeySet.user_id == user_id))

    record = UserKeySet(
        user_id=user_id,
        tenant_id=tenant_id,
        encrypted_private_key=encrypted_private_key,
        private_key_nonce=private_key_nonce,
        encrypted_vault_key=encrypted_vault_key,
        vault_key_nonce=vault_key_nonce,
        public_key=public_key,
        pbkdf2_salt=pbkdf2_salt,
        hkdf_salt=hkdf_salt,
        pbkdf2_iterations=pbkdf2_iterations,
    )
    db.add(record)
    await db.flush()
    return record
|
||||
|
||||
|
||||
async def get_user_key_set(
    db: AsyncSession, user_id: UUID
) -> UserKeySet | None:
    """Fetch the user's encrypted key bundle (for the login response).

    Args:
        db: Async database session.
        user_id: The user's UUID.

    Returns:
        The UserKeySet if found, None otherwise.
    """
    query = select(UserKeySet).where(UserKeySet.user_id == user_id)
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
|
||||
|
||||
|
||||
async def log_key_access(
    db: AsyncSession,
    tenant_id: UUID,
    user_id: UUID | None,
    action: str,
    resource_type: str | None = None,
    resource_id: str | None = None,
    key_version: int | None = None,
    ip_address: str | None = None,
    device_id: UUID | None = None,
    justification: str | None = None,
    correlation_id: str | None = None,
) -> None:
    """Append one entry to the immutable key_access_log.

    The table is append-only (INSERT+SELECT only via RLS policy); no
    UPDATE or DELETE is ever issued here.

    Args:
        db: Async database session.
        tenant_id: The tenant UUID for RLS isolation.
        user_id: The user who performed the action (None for system ops).
        action: Action description (e.g., 'create_key_set', 'decrypt_vault_key').
        resource_type: Optional resource type being accessed.
        resource_id: Optional resource identifier.
        key_version: Optional key version involved.
        ip_address: Optional client IP address.
        device_id: Optional device UUID for credential access tracking.
        justification: Optional justification for the access (e.g., 'api_backup').
        correlation_id: Optional correlation ID for request tracing.
    """
    fields = {
        "tenant_id": tenant_id,
        "user_id": user_id,
        "action": action,
        "resource_type": resource_type,
        "resource_id": resource_id,
        "key_version": key_version,
        "ip_address": ip_address,
        "device_id": device_id,
        "justification": justification,
        "correlation_id": correlation_id,
    }
    db.add(KeyAccessLog(**fields))
    await db.flush()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenBao Transit tenant key provisioning and credential migration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def provision_tenant_key(db: AsyncSession, tenant_id: UUID) -> str:
    """Provision an OpenBao Transit key for a tenant and update the tenant record.

    Idempotent: if the key already exists in OpenBao, it's a no-op on the
    OpenBao side. The tenant record is always updated with the key name.

    Args:
        db: Async database session (admin engine, no RLS).
        tenant_id: Tenant UUID.

    Returns:
        The key name (tenant_{uuid}).
    """
    from app.models.tenant import Tenant
    from app.services.openbao_service import get_openbao_service

    key_name = f"tenant_{tenant_id}"

    # Create the Transit key first; re-creation is a no-op server-side.
    await get_openbao_service().create_tenant_key(str(tenant_id))

    # Record the key name on the tenant row (skipped if the tenant is gone).
    lookup = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
    tenant = lookup.scalar_one_or_none()
    if tenant is not None:
        tenant.openbao_key_name = key_name
        await db.flush()

    logger.info(
        "Provisioned OpenBao Transit key for tenant %s (key=%s)",
        tenant_id,
        key_name,
    )
    return key_name
|
||||
|
||||
|
||||
async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
    """Re-encrypt all legacy credentials for a tenant from AES-256-GCM to Transit.

    Migrates device credentials, CA private keys, device cert private keys,
    and notification channel secrets. Items that already have a non-empty
    *_transit value are skipped, so re-running is safe (idempotent).
    Per-item failures are logged and counted but never abort the run.

    Args:
        db: Async database session (admin engine, no RLS).
        tenant_id: Tenant UUID.

    Returns:
        Dict with counts: {"devices": N, "cas": N, "certs": N, "channels": N, "errors": N}
    """
    from app.config import settings
    from app.models.alert import NotificationChannel
    from app.models.certificate import CertificateAuthority, DeviceCertificate
    from app.models.device import Device
    from app.services.crypto import decrypt_credentials
    from app.services.openbao_service import get_openbao_service

    openbao = get_openbao_service()
    legacy_key = settings.get_encryption_key_bytes()
    tid = str(tenant_id)

    counts = {"devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0}

    async def _migrate_column(
        model, source_attr: str, transit_attr: str, count_key: str, error_fmt: str
    ) -> None:
        """Re-encrypt one legacy column into its *_transit twin for all rows
        of *model* in this tenant that have not yet been migrated."""
        source_col = getattr(model, source_attr)
        transit_col = getattr(model, transit_attr)
        result = await db.execute(
            select(model).where(
                model.tenant_id == tenant_id,
                source_col.isnot(None),
                (transit_col.is_(None) | (transit_col == "")),
            )
        )
        for row in result.scalars().all():
            try:
                plaintext = decrypt_credentials(getattr(row, source_attr), legacy_key)
                setattr(
                    row,
                    transit_attr,
                    await openbao.encrypt(tid, plaintext.encode("utf-8")),
                )
                counts[count_key] += 1
            except Exception as e:
                # error_fmt takes (row id, exception) — kept identical to the
                # original per-model log lines for log continuity.
                logger.error(error_fmt, row.id, e)
                counts["errors"] += 1

    # Devices, CAs and device certs all follow the same "legacy column set,
    # transit column empty" pattern — handled by the shared helper above.
    await _migrate_column(
        Device,
        "encrypted_credentials",
        "encrypted_credentials_transit",
        "devices",
        "Failed to migrate device %s credentials: %s",
    )
    await _migrate_column(
        CertificateAuthority,
        "encrypted_private_key",
        "encrypted_private_key_transit",
        "cas",
        "Failed to migrate CA %s private key: %s",
    )
    await _migrate_column(
        DeviceCertificate,
        "encrypted_private_key",
        "encrypted_private_key_transit",
        "certs",
        "Failed to migrate cert %s private key: %s",
    )

    # Notification channels hold several independent secrets per row, so they
    # keep a bespoke loop: a channel counts once if any secret was migrated.
    result = await db.execute(
        select(NotificationChannel).where(
            NotificationChannel.tenant_id == tenant_id,
        )
    )
    for ch in result.scalars().all():
        migrated_any = False
        try:
            # SMTP password
            if ch.smtp_password and not ch.smtp_password_transit:
                plaintext = decrypt_credentials(ch.smtp_password, legacy_key)
                ch.smtp_password_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
                migrated_any = True
            if migrated_any:
                counts["channels"] += 1
        except Exception as e:
            logger.error("Failed to migrate channel %s secrets: %s", ch.id, e)
            counts["errors"] += 1

    await db.flush()

    logger.info(
        "Tenant %s credential migration complete: %s",
        tenant_id,
        counts,
    )
    return counts
|
||||
|
||||
|
||||
async def provision_existing_tenants(db: AsyncSession) -> dict:
    """Provision OpenBao Transit keys for all existing tenants and migrate credentials.

    Called on app startup to ensure all tenants have Transit keys.
    Idempotent -- running multiple times is safe (already-migrated items are skipped).

    Args:
        db: Async database session (admin engine, no RLS).

    Returns:
        Summary dict with total counts across all tenants.
    """
    from app.models.tenant import Tenant

    rows = await db.execute(select(Tenant))
    tenants = rows.scalars().all()

    total = {"tenants": len(tenants), "devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0}

    for tenant in tenants:
        try:
            await provision_tenant_key(db, tenant.id)
            per_tenant = await migrate_tenant_credentials(db, tenant.id)
            # Fold the per-tenant counters into the running totals.
            for key in ("devices", "cas", "certs", "channels", "errors"):
                total[key] += per_tenant[key]
        except Exception as e:
            logger.error("Failed to provision/migrate tenant %s: %s", tenant.id, e)
            total["errors"] += 1

    await db.commit()

    logger.info("Existing tenant provisioning complete: %s", total)
    return total
|
||||
346
backend/app/services/metrics_subscriber.py
Normal file
346
backend/app/services/metrics_subscriber.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""NATS JetStream subscriber for device metrics events.
|
||||
|
||||
Subscribes to device.metrics.> and inserts into TimescaleDB hypertables:
|
||||
- interface_metrics — per-interface rx/tx byte counters
|
||||
- health_metrics — CPU, memory, disk, temperature per device
|
||||
- wireless_metrics — per-wireless-interface aggregated client stats
|
||||
|
||||
Also maintains denormalized last_cpu_load and last_memory_used_pct columns
|
||||
on the devices table for efficient fleet table display.
|
||||
|
||||
Uses AdminAsyncSessionLocal (superuser bypass RLS) so metrics from any tenant
|
||||
can be written without setting app.current_tenant.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import nats
|
||||
from nats.js import JetStreamContext
|
||||
from nats.aio.client import Client as NATSClient
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_metrics_client: Optional[NATSClient] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# INSERT HANDLERS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _parse_timestamp(val: str | None) -> datetime:
|
||||
"""Parse an ISO 8601 / RFC 3339 timestamp string into a datetime object."""
|
||||
if not val:
|
||||
return datetime.now(timezone.utc)
|
||||
try:
|
||||
return datetime.fromisoformat(val.replace("Z", "+00:00"))
|
||||
except (ValueError, AttributeError):
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
async def _insert_health_metrics(session, data: dict) -> None:
    """Insert one health metrics event and refresh denormalized device columns."""
    health = data.get("health")
    if not health:
        logger.warning("health metrics event missing 'health' field — skipping")
        return

    def _to_int(raw: str | None) -> int | None:
        # Empty strings and malformed numbers become NULL rather than raising.
        try:
            return int(raw) if raw else None
        except (ValueError, TypeError):
            return None

    params = {
        "time": _parse_timestamp(data.get("collected_at")),
        "device_id": data.get("device_id"),
        "tenant_id": data.get("tenant_id"),
        "cpu_load": _to_int(health.get("cpu_load")),
        "free_memory": _to_int(health.get("free_memory")),
        "total_memory": _to_int(health.get("total_memory")),
        "free_disk": _to_int(health.get("free_disk")),
        "total_disk": _to_int(health.get("total_disk")),
        "temperature": _to_int(health.get("temperature")),
    }

    await session.execute(
        text("""
            INSERT INTO health_metrics
                (time, device_id, tenant_id, cpu_load, free_memory, total_memory,
                 free_disk, total_disk, temperature)
            VALUES
                (:time, :device_id, :tenant_id, :cpu_load, :free_memory, :total_memory,
                 :free_disk, :total_disk, :temperature)
        """),
        params,
    )

    # Denormalized fleet-table columns on devices. The used-memory percentage
    # is computed in Python to avoid asyncpg type ambiguity in SQL.
    free_memory = params["free_memory"]
    total_memory = params["total_memory"]
    mem_pct = None
    if total_memory and total_memory > 0 and free_memory is not None:
        mem_pct = round((1.0 - free_memory / total_memory) * 100)

    await session.execute(
        text("""
            UPDATE devices SET
                last_cpu_load = COALESCE(:cpu_load, last_cpu_load),
                last_memory_used_pct = COALESCE(:mem_pct, last_memory_used_pct),
                updated_at = NOW()
            WHERE id = CAST(:device_id AS uuid)
        """),
        {
            "cpu_load": params["cpu_load"],
            "mem_pct": mem_pct,
            "device_id": params["device_id"],
        },
    )
|
||||
|
||||
|
||||
async def _insert_interface_metrics(session, data: dict) -> None:
    """Insert per-interface rx/tx byte counters into interface_metrics."""
    interfaces = data.get("interfaces")
    if not interfaces:
        # A device with no interfaces is unlikely but safe to skip.
        return

    collected_at = _parse_timestamp(data.get("collected_at"))
    device_id = data.get("device_id")
    tenant_id = data.get("tenant_id")

    # One prepared statement reused for every interface row.
    stmt = text("""
        INSERT INTO interface_metrics
            (time, device_id, tenant_id, interface, rx_bytes, tx_bytes, rx_bps, tx_bps)
        VALUES
            (:time, :device_id, :tenant_id, :interface, :rx_bytes, :tx_bytes, NULL, NULL)
    """)
    for iface in interfaces:
        await session.execute(
            stmt,
            {
                "time": collected_at,
                "device_id": device_id,
                "tenant_id": tenant_id,
                "interface": iface.get("name"),
                "rx_bytes": iface.get("rx_bytes"),
                "tx_bytes": iface.get("tx_bytes"),
            },
        )
|
||||
|
||||
|
||||
async def _insert_wireless_metrics(session, data: dict) -> None:
    """Insert per-wireless-interface aggregated client stats into wireless_metrics."""
    wireless = data.get("wireless")
    if not wireless:
        # Device may simply have no wireless interfaces.
        return

    collected_at = _parse_timestamp(data.get("collected_at"))
    device_id = data.get("device_id")
    tenant_id = data.get("tenant_id")

    # One prepared statement reused for every wireless interface row.
    stmt = text("""
        INSERT INTO wireless_metrics
            (time, device_id, tenant_id, interface, client_count, avg_signal, ccq, frequency)
        VALUES
            (:time, :device_id, :tenant_id, :interface,
             :client_count, :avg_signal, :ccq, :frequency)
    """)
    for wif in wireless:
        await session.execute(
            stmt,
            {
                "time": collected_at,
                "device_id": device_id,
                "tenant_id": tenant_id,
                "interface": wif.get("interface"),
                "client_count": wif.get("client_count"),
                "avg_signal": wif.get("avg_signal"),
                "ccq": wif.get("ccq"),
                "frequency": wif.get("frequency"),
            },
        )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MAIN MESSAGE HANDLER
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def on_device_metrics(msg) -> None:
    """Handle a device.metrics event published by the Go poller.

    Dispatches on the payload's 'type' field:
      - "health"     → _insert_health_metrics (also updates devices)
      - "interfaces" → _insert_interface_metrics
      - "wireless"   → _insert_wireless_metrics

    Acks on success and on malformed/unknown payloads (not retryable);
    NAKs on processing errors so JetStream redelivers.
    """
    try:
        data = json.loads(msg.data)
        metric_type = data.get("type")
        device_id = data.get("device_id")

        if not metric_type or not device_id:
            logger.warning(
                "device.metrics event missing 'type' or 'device_id' — skipping"
            )
            await msg.ack()
            return

        # Table-driven dispatch instead of an if/elif chain.
        handlers = {
            "health": _insert_health_metrics,
            "interfaces": _insert_interface_metrics,
            "wireless": _insert_wireless_metrics,
        }
        handler = handlers.get(metric_type)

        async with AdminAsyncSessionLocal() as session:
            if handler is None:
                logger.warning("Unknown metric type '%s' — skipping", metric_type)
                await msg.ack()
                return
            await handler(session, data)
            await session.commit()

        # Alert evaluation — non-fatal; metric write is the primary operation
        try:
            from app.services import alert_evaluator

            await alert_evaluator.evaluate(
                device_id=device_id,
                tenant_id=data.get("tenant_id", ""),
                metric_type=metric_type,
                data=data,
            )
        except Exception as eval_err:
            logger.warning("Alert evaluation failed for device %s: %s", device_id, eval_err)

        logger.debug(
            "device.metrics processed",
            extra={"device_id": device_id, "type": metric_type},
        )
        await msg.ack()

    except Exception as exc:
        logger.error(
            "Failed to process device.metrics event: %s",
            exc,
            exc_info=True,
        )
        try:
            await msg.nak()
        except Exception:
            pass  # If NAK also fails, NATS will redeliver after ack_wait
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SUBSCRIPTION SETUP
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def _subscribe_with_retry(js: JetStreamContext) -> None:
    """Subscribe to device.metrics.> (durable consumer), retrying while the stream spins up."""
    max_attempts = 6  # ~30 seconds at 5s intervals
    for attempt in range(1, max_attempts + 1):
        try:
            await js.subscribe(
                "device.metrics.>",
                cb=on_device_metrics,
                durable="api-metrics-consumer",
                stream="DEVICE_EVENTS",
            )
        except Exception as exc:
            if attempt == max_attempts:
                # Out of retries: degrade gracefully instead of crashing the API.
                logger.warning(
                    "NATS: giving up on device.metrics.> after %d attempts: %s — API will run without metrics ingestion",
                    max_attempts,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready for metrics (attempt %d/%d): %s — retrying in 5s",
                attempt,
                max_attempts,
                exc,
            )
            await asyncio.sleep(5)
        else:
            logger.info(
                "NATS: subscribed to device.metrics.> (durable: api-metrics-consumer)"
            )
            return
|
||||
|
||||
|
||||
async def start_metrics_subscriber() -> Optional[NATSClient]:
    """Connect to NATS and start the device.metrics.> subscription.

    A dedicated connection is used, separate from the status subscriber —
    simpler, and NATS handles multiple connections per client efficiently.

    Returns:
        The NATS connection; pass it to stop_metrics_subscriber on shutdown.

    Raises:
        Propagates fatal connection errors after retry exhaustion.
    """
    global _metrics_client

    logger.info("NATS metrics: connecting to %s", settings.NATS_URL)
    nc = await nats.connect(
        settings.NATS_URL,
        max_reconnect_attempts=-1,  # reconnect forever
        reconnect_time_wait=2,
        error_cb=_on_error,
        reconnected_cb=_on_reconnected,
        disconnected_cb=_on_disconnected,
    )
    logger.info("NATS metrics: connected to %s", settings.NATS_URL)

    await _subscribe_with_retry(nc.jetstream())

    _metrics_client = nc
    return nc
|
||||
|
||||
|
||||
async def stop_metrics_subscriber(nc: Optional[NATSClient]) -> None:
    """Drain and close the metrics NATS connection gracefully."""
    if nc is None:
        return
    try:
        logger.info("NATS metrics: draining connection...")
        await nc.drain()
        logger.info("NATS metrics: connection closed")
    except Exception as exc:
        # Drain failed — fall back to a hard close, best-effort.
        logger.warning("NATS metrics: error during drain: %s", exc)
        try:
            await nc.close()
        except Exception:
            pass
|
||||
|
||||
|
||||
async def _on_error(exc: Exception) -> None:
    """NATS client error callback: log and continue (client reconnects on its own)."""
    logger.error("NATS metrics error: %s", exc)
|
||||
|
||||
|
||||
async def _on_reconnected() -> None:
    """NATS client callback: connection re-established after a drop."""
    logger.info("NATS metrics: reconnected")
|
||||
|
||||
|
||||
async def _on_disconnected() -> None:
    """NATS client callback: connection lost; the client will retry."""
    logger.warning("NATS metrics: disconnected")
|
||||
231
backend/app/services/nats_subscriber.py
Normal file
231
backend/app/services/nats_subscriber.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""NATS JetStream subscriber for device status events from the Go poller.
|
||||
|
||||
Subscribes to device.status.> and updates device records in PostgreSQL.
|
||||
This is a system-level process that needs to update devices across all tenants,
|
||||
so it uses the admin engine (bypasses RLS).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import nats
|
||||
from nats.js import JetStreamContext
|
||||
from nats.aio.client import Client as NATSClient
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_nats_client: Optional[NATSClient] = None
|
||||
|
||||
# Regex for RouterOS uptime strings like "42d14h23m15s", "14h23m15s", "23m15s", "3w2d"
|
||||
_UPTIME_RE = re.compile(r"(?:(\d+)w)?(?:(\d+)d)?(?:(\d+)h)?(?:(\d+)m)?(?:(\d+)s)?")
|
||||
|
||||
|
||||
def _parse_uptime(raw: str) -> int | None:
|
||||
"""Parse a RouterOS uptime string into total seconds."""
|
||||
if not raw:
|
||||
return None
|
||||
m = _UPTIME_RE.fullmatch(raw)
|
||||
if not m:
|
||||
return None
|
||||
weeks = int(m.group(1) or 0)
|
||||
days = int(m.group(2) or 0)
|
||||
hours = int(m.group(3) or 0)
|
||||
minutes = int(m.group(4) or 0)
|
||||
seconds = int(m.group(5) or 0)
|
||||
total = weeks * 604800 + days * 86400 + hours * 3600 + minutes * 60 + seconds
|
||||
return total if total > 0 else None
|
||||
|
||||
|
||||
async def on_device_status(msg) -> None:
    """Handle a device.status event published by the Go poller.

    Payload (JSON):
        device_id (str) — UUID of the device
        tenant_id (str) — UUID of the owning tenant
        status (str) — "online" or "offline"
        routeros_version (str | None) — e.g. "7.16.2"
        major_version (int | None) — e.g. 7
        board_name (str | None) — e.g. "RB4011iGS+5HacQ2HnD"
        last_seen (str | None) — ISO-8601 timestamp

    Acks on success and on malformed payloads (not retryable); NAKs on
    processing errors so JetStream redelivers after ack_wait.
    """
    try:
        data = json.loads(msg.data)
        device_id = data.get("device_id")
        status = data.get("status")
        routeros_version = data.get("routeros_version")
        major_version = data.get("major_version")
        board_name = data.get("board_name")
        last_seen_raw = data.get("last_seen")
        # `or None` collapses empty strings to NULL so the SQL COALESCE below
        # keeps the previously stored value.
        serial_number = data.get("serial_number") or None
        firmware_version = data.get("firmware_version") or None
        uptime_seconds = _parse_uptime(data.get("uptime", ""))

        if not device_id or not status:
            logger.warning("Received device.status event with missing device_id or status — skipping")
            await msg.ack()
            return

        # Parse timestamp in Python — asyncpg needs datetime objects, not strings
        last_seen_dt = None
        if last_seen_raw:
            try:
                last_seen_dt = datetime.fromisoformat(last_seen_raw.replace("Z", "+00:00"))
            except (ValueError, AttributeError):
                # Unparseable timestamp: record "now" rather than dropping the event.
                last_seen_dt = datetime.now(timezone.utc)

        # Admin session: this system process updates devices across all
        # tenants, so RLS is bypassed deliberately.
        async with AdminAsyncSessionLocal() as session:
            await session.execute(
                text(
                    """
                    UPDATE devices SET
                        status = :status,
                        routeros_version = COALESCE(:routeros_version, routeros_version),
                        routeros_major_version = COALESCE(:major_version, routeros_major_version),
                        model = COALESCE(:board_name, model),
                        serial_number = COALESCE(:serial_number, serial_number),
                        firmware_version = COALESCE(:firmware_version, firmware_version),
                        uptime_seconds = COALESCE(:uptime_seconds, uptime_seconds),
                        last_seen = COALESCE(:last_seen, last_seen),
                        updated_at = NOW()
                    WHERE id = CAST(:device_id AS uuid)
                    """
                ),
                {
                    "status": status,
                    "routeros_version": routeros_version,
                    "major_version": major_version,
                    "board_name": board_name,
                    "serial_number": serial_number,
                    "firmware_version": firmware_version,
                    "uptime_seconds": uptime_seconds,
                    "last_seen": last_seen_dt,
                    "device_id": device_id,
                },
            )
            await session.commit()

        # Alert evaluation for offline/online status changes — non-fatal
        try:
            from app.services import alert_evaluator
            if status == "offline":
                await alert_evaluator.evaluate_offline(device_id, data.get("tenant_id", ""))
            elif status == "online":
                await alert_evaluator.evaluate_online(device_id, data.get("tenant_id", ""))
        except Exception as e:
            logger.warning("Alert evaluation failed for device %s status=%s: %s", device_id, status, e)

        logger.info(
            "Device status updated",
            extra={
                "device_id": device_id,
                "status": status,
                "routeros_version": routeros_version,
            },
        )
        await msg.ack()

    except Exception as exc:
        logger.error(
            "Failed to process device.status event: %s",
            exc,
            exc_info=True,
        )
        try:
            await msg.nak()
        except Exception:
            pass  # If NAK also fails, NATS will redeliver after ack_wait
|
||||
|
||||
|
||||
async def _subscribe_with_retry(js: JetStreamContext) -> None:
    """Subscribe to device.status.> (durable consumer), retrying while the stream spins up."""
    max_attempts = 6  # ~30 seconds at 5s intervals
    for attempt in range(1, max_attempts + 1):
        try:
            await js.subscribe(
                "device.status.>",
                cb=on_device_status,
                durable="api-status-consumer",
                stream="DEVICE_EVENTS",
            )
        except Exception as exc:
            if attempt == max_attempts:
                # Out of retries: degrade gracefully instead of crashing the API.
                logger.warning(
                    "NATS: giving up on device.status.> after %d attempts: %s — API will run without real-time status updates",
                    max_attempts,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready (attempt %d/%d): %s — retrying in 5s",
                attempt,
                max_attempts,
                exc,
            )
            await asyncio.sleep(5)
        else:
            logger.info("NATS: subscribed to device.status.> (durable: api-status-consumer)")
            return
|
||||
|
||||
|
||||
async def start_nats_subscriber() -> Optional[NATSClient]:
    """Connect to NATS and start the device.status.> subscription.

    Returns:
        The NATS connection; pass it to stop_nats_subscriber on shutdown.

    Raises:
        Propagates fatal connection errors after retry exhaustion.
    """
    global _nats_client

    logger.info("NATS: connecting to %s", settings.NATS_URL)
    nc = await nats.connect(
        settings.NATS_URL,
        max_reconnect_attempts=-1,  # reconnect forever (pod-to-pod transient failures)
        reconnect_time_wait=2,
        error_cb=_on_error,
        reconnected_cb=_on_reconnected,
        disconnected_cb=_on_disconnected,
    )
    logger.info("NATS: connected to %s", settings.NATS_URL)

    await _subscribe_with_retry(nc.jetstream())

    _nats_client = nc
    return nc
|
||||
|
||||
|
||||
async def stop_nats_subscriber(nc: Optional[NATSClient]) -> None:
    """Drain and close the NATS connection gracefully."""
    if nc is None:
        return
    try:
        logger.info("NATS: draining connection...")
        await nc.drain()
        logger.info("NATS: connection closed")
    except Exception as exc:
        # Drain failed — fall back to a hard close, best-effort.
        logger.warning("NATS: error during drain: %s", exc)
        try:
            await nc.close()
        except Exception:
            pass
|
||||
|
||||
|
||||
async def _on_error(exc: Exception) -> None:
    """NATS client error callback: log and continue (client reconnects on its own)."""
    logger.error("NATS error: %s", exc)
|
||||
|
||||
|
||||
async def _on_reconnected() -> None:
    """NATS client callback: connection re-established after a drop."""
    logger.info("NATS: reconnected")
|
||||
|
||||
|
||||
async def _on_disconnected() -> None:
    """NATS client callback: connection lost; the client will retry."""
    logger.warning("NATS: disconnected")
|
||||
256
backend/app/services/notification_service.py
Normal file
256
backend/app/services/notification_service.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""Email and webhook notification delivery for alert events.
|
||||
|
||||
Best-effort delivery: failures are logged but never raised.
|
||||
Each dispatch is wrapped in try/except so one failing channel
|
||||
doesn't prevent delivery to other channels.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def dispatch_notifications(
    alert_event: dict[str, Any],
    channels: list[dict[str, Any]],
    device_hostname: str,
) -> None:
    """Send notifications for an alert event to all provided channels.

    Best-effort: a failure on one channel is logged and never blocks
    delivery to the remaining channels, and no exception escapes.

    Args:
        alert_event: Dict with alert event fields (status, severity, metric, etc.)
        channels: List of notification channel dicts
        device_hostname: Human-readable device name for messages
    """
    senders = {
        "email": _send_email,
        "webhook": _send_webhook,
        "slack": _send_slack,
    }
    for channel in channels:
        try:
            sender = senders.get(channel["channel_type"])
            if sender is None:
                logger.warning("Unknown channel type: %s", channel["channel_type"])
                continue
            await sender(channel, alert_event, device_hostname)
        except Exception as e:
            logger.warning(
                "Notification delivery failed for channel %s (%s): %s",
                channel.get("name"), channel.get("channel_type"), e,
            )
|
||||
|
||||
|
||||
async def _send_email(channel: dict, alert_event: dict, device_hostname: str) -> None:
    """Send alert notification email using per-channel SMTP config.

    The SMTP password is resolved by trying OpenBao Transit decryption
    first, then the legacy Fernet-encrypted column; both failures are
    logged and tolerated (the send is attempted without a password).

    Args:
        channel: Channel dict with SMTP settings and password ciphertexts.
        alert_event: Alert event fields (severity, status, metric, etc.).
        device_hostname: Human-readable device name for the message.
    """
    from html import escape

    from app.services.email_service import SMTPConfig, send_email

    severity = alert_event.get("severity", "warning")
    status = alert_event.get("status", "firing")
    rule_name = alert_event.get("rule_name") or alert_event.get("message", "Unknown Rule")
    metric = alert_event.get("metric_name") or alert_event.get("metric", "")
    value = alert_event.get("current_value") or alert_event.get("value", "")
    threshold = alert_event.get("threshold", "")

    severity_colors = {
        "critical": "#ef4444",
        "warning": "#f59e0b",
        "info": "#38bdf8",
    }
    color = severity_colors.get(severity, "#38bdf8")
    status_label = "RESOLVED" if status == "resolved" else "FIRING"

    # SECURITY: escape event-derived values before interpolating into the
    # HTML body — alert messages and hostnames are user/device-controlled,
    # so raw interpolation would allow HTML injection into the email.
    esc_rule = escape(str(rule_name))
    esc_host = escape(str(device_hostname))
    esc_metric = escape(str(metric))
    esc_value = escape(str(value))
    esc_threshold = escape(str(threshold))

    html = f"""
    <div style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; max-width: 600px; margin: 0 auto;">
      <div style="background: {color}; padding: 16px 24px; border-radius: 8px 8px 0 0;">
        <h2 style="color: #fff; margin: 0;">[{status_label}] {esc_rule}</h2>
      </div>
      <div style="background: #1e293b; padding: 24px; border-radius: 0 0 8px 8px; color: #e2e8f0;">
        <table style="width: 100%; border-collapse: collapse;">
          <tr><td style="padding: 8px 0; color: #94a3b8;">Device</td><td style="padding: 8px 0;">{esc_host}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Severity</td><td style="padding: 8px 0;">{severity.upper()}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Metric</td><td style="padding: 8px 0;">{esc_metric}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Value</td><td style="padding: 8px 0;">{esc_value}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Threshold</td><td style="padding: 8px 0;">{esc_threshold}</td></tr>
        </table>
        <p style="color: #64748b; font-size: 12px; margin-top: 24px;">
          TOD — Fleet Management for MikroTik RouterOS
        </p>
      </div>
    </div>
    """

    plain = (
        f"[{status_label}] {rule_name}\n\n"
        f"Device: {device_hostname}\n"
        f"Severity: {severity}\n"
        f"Metric: {metric}\n"
        f"Value: {value}\n"
        f"Threshold: {threshold}\n"
    )

    # Decrypt SMTP password (Transit first, then legacy Fernet)
    smtp_password = None
    transit_cipher = channel.get("smtp_password_transit")
    legacy_cipher = channel.get("smtp_password")
    tenant_id = channel.get("tenant_id")

    if transit_cipher and tenant_id:
        try:
            from app.services.kms_service import decrypt_transit
            smtp_password = await decrypt_transit(transit_cipher, tenant_id)
        except Exception:
            logger.warning("Transit decryption failed for channel %s, trying legacy", channel.get("id"))

    if not smtp_password and legacy_cipher:
        try:
            from app.config import settings as app_settings
            from cryptography.fernet import Fernet
            # BYTEA ciphertexts may arrive as memoryview; Fernet wants bytes.
            raw = bytes(legacy_cipher) if isinstance(legacy_cipher, memoryview) else legacy_cipher
            f = Fernet(app_settings.CREDENTIAL_ENCRYPTION_KEY.encode())
            smtp_password = f.decrypt(raw).decode()
        except Exception:
            logger.warning("Legacy decryption failed for channel %s", channel.get("id"))

    config = SMTPConfig(
        host=channel.get("smtp_host", "localhost"),
        port=channel.get("smtp_port", 587),
        user=channel.get("smtp_user"),
        password=smtp_password,
        use_tls=channel.get("smtp_use_tls", False),
        from_address=channel.get("from_address") or "alerts@mikrotik-portal.local",
    )

    to = channel.get("to_address")
    subject = f"[TOD {status_label}] {rule_name} — {device_hostname}"
    await send_email(to, subject, html, plain, config)
|
||||
|
||||
|
||||
async def _send_webhook(
    channel: dict[str, Any],
    alert_event: dict[str, Any],
    device_hostname: str,
) -> None:
    """Send alert notification to a webhook URL (Slack-compatible JSON).

    Args:
        channel: Channel dict; must carry a non-empty ``webhook_url``.
        alert_event: Alert event fields used to build the JSON payload.
        device_hostname: Human-readable device name for the message.

    Raises:
        httpx.HTTPStatusError: If the endpoint returns a non-2xx status,
            so the caller logs the delivery as failed instead of a false
            "sent" line.
    """
    # Validate configuration before building any payload.
    webhook_url = channel.get("webhook_url", "")
    if not webhook_url:
        logger.warning("Webhook channel %s has no URL configured", channel.get("name"))
        return

    severity = alert_event.get("severity", "info")
    status = alert_event.get("status", "firing")
    metric = alert_event.get("metric")
    value = alert_event.get("value")
    threshold = alert_event.get("threshold")
    message_text = alert_event.get("message", "")

    payload = {
        "alert_name": message_text,
        "severity": severity,
        "status": status,
        "device": device_hostname,
        "device_id": alert_event.get("device_id"),
        "metric": metric,
        "value": value,
        "threshold": threshold,
        "timestamp": str(alert_event.get("fired_at", "")),
        "text": f"[{severity.upper()}] {device_hostname}: {message_text}",
    }

    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.post(webhook_url, json=payload)
        # Treat 4xx/5xx as delivery failure instead of logging success.
        response.raise_for_status()
        logger.info(
            "Webhook notification sent to %s — status %d",
            webhook_url, response.status_code,
        )
|
||||
|
||||
|
||||
async def _send_slack(
    channel: dict[str, Any],
    alert_event: dict[str, Any],
    device_hostname: str,
) -> None:
    """Send alert notification to Slack via incoming webhook with Block Kit formatting.

    Args:
        channel: Channel dict; must carry a non-empty ``slack_webhook_url``.
        alert_event: Alert event fields used to build the Block Kit blocks.
        device_hostname: Human-readable device name for the message.

    Raises:
        httpx.HTTPStatusError: If Slack returns a non-2xx status, so the
            caller logs the delivery as failed instead of a false "sent"
            line.
    """
    # Validate configuration before building any payload.
    slack_url = channel.get("slack_webhook_url", "")
    if not slack_url:
        logger.warning("Slack channel %s has no webhook URL configured", channel.get("name"))
        return

    severity = alert_event.get("severity", "info").upper()
    status = alert_event.get("status", "firing")
    metric = alert_event.get("metric", "unknown")
    message_text = alert_event.get("message", "")
    value = alert_event.get("value")
    threshold = alert_event.get("threshold")

    color = {"CRITICAL": "#dc2626", "WARNING": "#f59e0b", "INFO": "#3b82f6"}.get(severity, "#6b7280")
    status_label = "RESOLVED" if status == "resolved" else status

    blocks = [
        {
            "type": "header",
            "text": {"type": "plain_text", "text": f"{'✅' if status == 'resolved' else '🚨'} [{severity}] {status_label.upper()}"},
        },
        {
            "type": "section",
            "fields": [
                {"type": "mrkdwn", "text": f"*Device:*\n{device_hostname}"},
                {"type": "mrkdwn", "text": f"*Metric:*\n{metric}"},
            ],
        },
    ]
    # Value/threshold section only when at least one is present.
    if value is not None or threshold is not None:
        fields = []
        if value is not None:
            fields.append({"type": "mrkdwn", "text": f"*Value:*\n{value}"})
        if threshold is not None:
            fields.append({"type": "mrkdwn", "text": f"*Threshold:*\n{threshold}"})
        blocks.append({"type": "section", "fields": fields})

    if message_text:
        blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": f"*Message:*\n{message_text}"}})

    blocks.append({"type": "context", "elements": [{"type": "mrkdwn", "text": "TOD Alert System"}]})

    payload = {"attachments": [{"color": color, "blocks": blocks}]}

    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.post(slack_url, json=payload)
        # Treat 4xx/5xx as delivery failure instead of logging success.
        response.raise_for_status()
        logger.info("Slack notification sent — status %d", response.status_code)
|
||||
|
||||
|
||||
async def send_test_notification(channel: dict[str, Any]) -> bool:
    """Send a test notification through a channel to verify configuration.

    Args:
        channel: Notification channel dict with all config fields

    Returns:
        True on success

    Raises:
        Exception on delivery failure (caller handles);
        ValueError for an unrecognized channel type.
    """
    test_event = {
        "status": "test",
        "severity": "info",
        "metric": "test",
        "value": None,
        "threshold": None,
        "message": "Test notification from TOD",
        "device_id": "00000000-0000-0000-0000-000000000000",
        "fired_at": "",
    }

    channel_type = channel["channel_type"]
    senders = {
        "email": _send_email,
        "webhook": _send_webhook,
        "slack": _send_slack,
    }
    sender = senders.get(channel_type)
    if sender is None:
        raise ValueError(f"Unknown channel type: {channel_type}")

    await sender(channel, test_event, "Test Device")
    return True
|
||||
174
backend/app/services/openbao_service.py
Normal file
174
backend/app/services/openbao_service.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
OpenBao Transit secrets engine client for per-tenant envelope encryption.
|
||||
|
||||
Provides encrypt/decrypt operations via OpenBao's HTTP API. Each tenant gets
|
||||
a dedicated Transit key (tenant_{uuid}) for AES-256-GCM encryption. The key
|
||||
material never leaves OpenBao -- the application only sees ciphertext.
|
||||
|
||||
Ciphertext format: "vault:v1:base64..." (compatible with Vault Transit format)
|
||||
"""
|
||||
import base64
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenBaoTransitService:
    """Async client for OpenBao Transit secrets engine.

    Key naming convention:
        tenant_{uuid}       -- credential encryption key
        tenant_{uuid}_data  -- data encryption key (Phase 30)

    All Transit HTTP calls go through the private ``_create_key`` /
    ``_transit_encrypt`` / ``_transit_decrypt`` helpers so the request
    shape lives in exactly one place.
    """

    def __init__(self, addr: str | None = None, token: str | None = None):
        # Fall back to app settings when explicit values are not given.
        self.addr = addr or settings.OPENBAO_ADDR
        self.token = token or settings.OPENBAO_TOKEN
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the cached AsyncClient, recreating it if closed."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                base_url=self.addr,
                headers={"X-Vault-Token": self.token},
                timeout=5.0,
            )
        return self._client

    async def close(self) -> None:
        """Close the underlying HTTP client (idempotent)."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    # ------------------------------------------------------------------
    # Shared Transit helpers (used for both key flavours)
    # ------------------------------------------------------------------

    async def _create_key(self, key_name: str, log_message: str) -> None:
        """POST /v1/transit/keys/{key_name}; server-side idempotent."""
        client = await self._get_client()
        resp = await client.post(
            f"/v1/transit/keys/{key_name}",
            json={"type": "aes256-gcm96"},
        )
        # 200/204 cover both "created" and "already exists".
        if resp.status_code not in (200, 204):
            resp.raise_for_status()
        logger.info(log_message, extra={"key_name": key_name})

    async def _transit_encrypt(self, key_name: str, plaintext: bytes) -> str:
        """Encrypt plaintext with the named Transit key; return ciphertext."""
        client = await self._get_client()
        resp = await client.post(
            f"/v1/transit/encrypt/{key_name}",
            json={"plaintext": base64.b64encode(plaintext).decode()},
        )
        resp.raise_for_status()
        return resp.json()["data"]["ciphertext"]  # "vault:v1:..."

    async def _transit_decrypt(self, key_name: str, ciphertext: str) -> bytes:
        """Decrypt Transit ciphertext with the named key; return plaintext bytes."""
        client = await self._get_client()
        resp = await client.post(
            f"/v1/transit/decrypt/{key_name}",
            json={"ciphertext": ciphertext},
        )
        resp.raise_for_status()
        return base64.b64decode(resp.json()["data"]["plaintext"])

    # ------------------------------------------------------------------
    # Credential keys (tenant_{uuid})
    # ------------------------------------------------------------------

    async def create_tenant_key(self, tenant_id: str) -> None:
        """Create Transit encryption keys for a tenant (credential + data). Idempotent."""
        await self._create_key(f"tenant_{tenant_id}", "OpenBao Transit key ensured")
        # Data key: tenant_{uuid}_data (Phase 30)
        await self.create_tenant_data_key(tenant_id)

    async def encrypt(self, tenant_id: str, plaintext: bytes) -> str:
        """Encrypt plaintext via Transit engine. Returns ciphertext string."""
        return await self._transit_encrypt(f"tenant_{tenant_id}", plaintext)

    async def decrypt(self, tenant_id: str, ciphertext: str) -> bytes:
        """Decrypt Transit ciphertext. Returns plaintext bytes."""
        return await self._transit_decrypt(f"tenant_{tenant_id}", ciphertext)

    async def key_exists(self, tenant_id: str) -> bool:
        """Check if a Transit credential key exists for a tenant."""
        client = await self._get_client()
        resp = await client.get(f"/v1/transit/keys/tenant_{tenant_id}")
        return resp.status_code == 200

    # ------------------------------------------------------------------
    # Data encryption keys (tenant_{uuid}_data) — Phase 30
    # ------------------------------------------------------------------

    async def create_tenant_data_key(self, tenant_id: str) -> None:
        """Create a Transit data encryption key for a tenant. Idempotent.

        Data keys use the suffix '_data' to separate them from credential
        keys: tenant_{uuid}_data vs tenant_{uuid}.
        """
        await self._create_key(
            f"tenant_{tenant_id}_data", "OpenBao Transit data key ensured"
        )

    async def ensure_tenant_data_key(self, tenant_id: str) -> None:
        """Ensure a data encryption key exists for a tenant. Idempotent.

        Fast path is a single GET existence check; the key is created
        only when missing.
        """
        client = await self._get_client()
        resp = await client.get(f"/v1/transit/keys/tenant_{tenant_id}_data")
        if resp.status_code != 200:
            await self.create_tenant_data_key(tenant_id)

    async def encrypt_data(self, tenant_id: str, plaintext: bytes) -> str:
        """Encrypt data via Transit using the per-tenant data key.

        Args:
            tenant_id: Tenant UUID string.
            plaintext: Raw bytes to encrypt.

        Returns:
            Transit ciphertext string (vault:v1:...).
        """
        return await self._transit_encrypt(f"tenant_{tenant_id}_data", plaintext)

    async def decrypt_data(self, tenant_id: str, ciphertext: str) -> bytes:
        """Decrypt Transit data ciphertext using the per-tenant data key.

        Args:
            tenant_id: Tenant UUID string.
            ciphertext: Transit ciphertext (vault:v1:...).

        Returns:
            Decrypted plaintext bytes.
        """
        return await self._transit_decrypt(f"tenant_{tenant_id}_data", ciphertext)
|
||||
|
||||
|
||||
# Module-level singleton
_openbao_service: Optional[OpenBaoTransitService] = None


def get_openbao_service() -> OpenBaoTransitService:
    """Return the process-wide OpenBao Transit client, creating it on first use."""
    global _openbao_service
    if _openbao_service is not None:
        return _openbao_service
    _openbao_service = OpenBaoTransitService()
    return _openbao_service
|
||||
141
backend/app/services/push_rollback_subscriber.py
Normal file
141
backend/app/services/push_rollback_subscriber.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""NATS subscribers for push rollback (auto) and push alert (manual).
|
||||
|
||||
- config.push.rollback.> -> auto-restore for template pushes
|
||||
- config.push.alert.> -> create alert for editor pushes
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.models.alert import AlertEvent
|
||||
from app.services import restore_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_nc: Optional[Any] = None
|
||||
|
||||
|
||||
async def _create_push_alert(device_id: str, tenant_id: str, push_type: str) -> None:
    """Persist a critical AlertEvent for a device that went offline after a push."""
    alert = AlertEvent(
        device_id=device_id,
        tenant_id=tenant_id,
        status="firing",
        severity="critical",
        message=f"Device went offline after config {push_type} — rollback available",
    )
    async with AdminAsyncSessionLocal() as session:
        session.add(alert)
        await session.commit()
    logger.info("Created push alert for device %s (type=%s)", device_id, push_type)
|
||||
|
||||
|
||||
async def handle_push_rollback(event: dict) -> None:
    """Auto-rollback: restore device to pre-push config.

    Requires device_id, tenant_id and pre_push_commit_sha in the event;
    on restore failure a critical alert is created instead.
    """
    device_id = event.get("device_id")
    tenant_id = event.get("tenant_id")
    sha = event.get("pre_push_commit_sha")

    if not (device_id and tenant_id and sha):
        logger.warning("Push rollback event missing fields: %s", event)
        return

    logger.warning(
        "AUTO-ROLLBACK: Device %s offline after template push, restoring to %s",
        device_id,
        sha,
    )

    try:
        async with AdminAsyncSessionLocal() as session:
            outcome = await restore_service.restore_config(
                device_id=device_id,
                tenant_id=tenant_id,
                commit_sha=sha,
                db_session=session,
            )
            await session.commit()
        logger.info(
            "Auto-rollback result for device %s: %s",
            device_id,
            outcome.get("status"),
        )
    except Exception as exc:
        logger.error("Auto-rollback failed for device %s: %s", device_id, exc)
        await _create_push_alert(device_id, tenant_id, "template (auto-rollback failed)")
|
||||
|
||||
|
||||
async def handle_push_alert(event: dict) -> None:
    """Alert: create notification for device offline after editor push."""
    device_id = event.get("device_id")
    tenant_id = event.get("tenant_id")

    if not device_id or not tenant_id:
        logger.warning("Push alert event missing fields: %s", event)
        return

    await _create_push_alert(device_id, tenant_id, event.get("push_type", "editor"))
|
||||
|
||||
|
||||
async def _on_rollback_message(msg) -> None:
    """NATS message handler for config.push.rollback.> subjects.

    Malformed payloads are terminated (not nak'd) so JetStream does not
    redeliver a poison message forever; transient handler failures are
    nak'd so they are retried.
    """
    try:
        event = json.loads(msg.data.decode())
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        logger.error("Unparseable rollback message, terminating: %s", e)
        await msg.term()
        return
    try:
        await handle_push_rollback(event)
        await msg.ack()
    except Exception as e:
        logger.error("Error handling rollback message: %s", e)
        await msg.nak()
|
||||
|
||||
|
||||
async def _on_alert_message(msg) -> None:
    """NATS message handler for config.push.alert.> subjects.

    Malformed payloads are terminated (not nak'd) so JetStream does not
    redeliver a poison message forever; transient handler failures are
    nak'd so they are retried.
    """
    try:
        event = json.loads(msg.data.decode())
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        logger.error("Unparseable push alert message, terminating: %s", e)
        await msg.term()
        return
    try:
        await handle_push_alert(event)
        await msg.ack()
    except Exception as e:
        logger.error("Error handling push alert message: %s", e)
        await msg.nak()
|
||||
|
||||
|
||||
async def start_push_rollback_subscriber() -> Optional[Any]:
    """Connect to NATS and subscribe to push rollback/alert events.

    Returns the connected client, or None when startup fails (failure
    is logged, never raised).
    """
    import nats

    global _nc
    try:
        logger.info("NATS push-rollback: connecting to %s", settings.NATS_URL)
        _nc = await nats.connect(settings.NATS_URL)
        js = _nc.jetstream()
        # (subject, callback, durable consumer name) — both live on DEVICE_EVENTS.
        subscriptions = (
            ("config.push.rollback.>", _on_rollback_message, "api-push-rollback-consumer"),
            ("config.push.alert.>", _on_alert_message, "api-push-alert-consumer"),
        )
        for subject, callback, durable_name in subscriptions:
            await js.subscribe(
                subject,
                cb=callback,
                durable=durable_name,
                stream="DEVICE_EVENTS",
                manual_ack=True,
            )
        logger.info("Push rollback/alert subscriber started")
        return _nc
    except Exception as e:
        logger.error("Failed to start push rollback subscriber: %s", e)
        return None
|
||||
|
||||
|
||||
async def stop_push_rollback_subscriber() -> None:
    """Gracefully close the NATS connection (no-op when not connected)."""
    global _nc
    if _nc is None:
        return
    await _nc.drain()
    _nc = None
|
||||
70
backend/app/services/push_tracker.py
Normal file
70
backend/app/services/push_tracker.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Track recent config pushes in Redis for poller-aware rollback.
|
||||
|
||||
When a device goes offline shortly after a push, the poller checks these
|
||||
keys and triggers rollback (template/restore) or alert (editor).
|
||||
|
||||
Redis key format: push:recent:{device_id}
|
||||
TTL: 300 seconds (5 minutes)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PUSH_TTL_SECONDS = 300 # 5 minutes
|
||||
|
||||
_redis: Optional[redis.Redis] = None
|
||||
|
||||
|
||||
async def _get_redis() -> redis.Redis:
    """Return the lazily-initialized module-level Redis client."""
    global _redis
    if _redis is not None:
        return _redis
    _redis = redis.from_url(settings.REDIS_URL)
    return _redis
|
||||
|
||||
|
||||
async def record_push(
    device_id: str,
    tenant_id: str,
    push_type: str,
    push_operation_id: str = "",
    pre_push_commit_sha: str = "",
) -> None:
    """Record a recent config push in Redis.

    Args:
        device_id: UUID of the device.
        tenant_id: UUID of the tenant.
        push_type: 'template' (auto-rollback) or 'editor' (alert only) or 'restore'.
        push_operation_id: ID of the ConfigPushOperation row.
        pre_push_commit_sha: Git SHA of the pre-push backup (for rollback).
    """
    payload = {
        "device_id": device_id,
        "tenant_id": tenant_id,
        "push_type": push_type,
        "push_operation_id": push_operation_id,
        "pre_push_commit_sha": pre_push_commit_sha,
    }
    client = await _get_redis()
    # Key expires automatically; the poller only cares about recent pushes.
    await client.set(f"push:recent:{device_id}", json.dumps(payload), ex=PUSH_TTL_SECONDS)
    logger.debug(
        "Recorded push for device %s (type=%s, TTL=%ds)",
        device_id,
        push_type,
        PUSH_TTL_SECONDS,
    )
|
||||
|
||||
|
||||
async def clear_push(device_id: str) -> None:
    """Clear the push tracking key (e.g., after successful commit)."""
    client = await _get_redis()
    await client.delete(f"push:recent:{device_id}")
    logger.debug("Cleared push tracking for device %s", device_id)
|
||||
572
backend/app/services/report_service.py
Normal file
572
backend/app/services/report_service.py
Normal file
@@ -0,0 +1,572 @@
|
||||
"""Report generation service.
|
||||
|
||||
Generates PDF (via Jinja2 + weasyprint) and CSV reports for:
|
||||
- Device inventory
|
||||
- Metrics summary
|
||||
- Alert history
|
||||
- Change log (audit_logs if available, else config_backups fallback)
|
||||
|
||||
Phase 30 NOTE: Reports are currently ephemeral (generated on-demand per request,
|
||||
never stored at rest). DATAENC-03 requires "report content is encrypted before
|
||||
storage." Since no report storage exists yet, encryption will be applied when
|
||||
report caching/storage is added. The generation pipeline is Transit-ready --
|
||||
wrap the file_bytes with encrypt_data_transit() before any future INSERT.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import structlog
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Jinja2 environment pointing at the templates directory
|
||||
_TEMPLATE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "templates")
|
||||
_jinja_env = Environment(
|
||||
loader=FileSystemLoader(_TEMPLATE_DIR),
|
||||
autoescape=True,
|
||||
)
|
||||
|
||||
|
||||
async def generate_report(
    db: AsyncSession,
    tenant_id: UUID,
    report_type: str,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
    fmt: str = "pdf",
) -> tuple[bytes, str, str]:
    """Generate a report and return (file_bytes, content_type, filename).

    Args:
        db: RLS-enforced async session (tenant context already set).
        tenant_id: Tenant UUID for scoping.
        report_type: One of device_inventory, metrics_summary, alert_history, change_log.
        date_from: Start date for time-ranged reports.
        date_to: End date for time-ranged reports.
        fmt: Output format -- "pdf" or "csv".

    Returns:
        Tuple of (file_bytes, content_type, filename).

    Raises:
        ValueError: If report_type is not one of the known reports.
    """
    start = time.monotonic()

    # Fetch tenant name for the header
    tenant_name = await _get_tenant_name(db, tenant_id)

    # Dispatch to the appropriate handler
    handlers = {
        "device_inventory": _device_inventory,
        "metrics_summary": _metrics_summary,
        "alert_history": _alert_history,
        "change_log": _change_log,
    }
    handler = handlers.get(report_type)
    if handler is None:
        # Surface a clear error instead of a bare KeyError bubbling up.
        raise ValueError(f"Unknown report_type: {report_type!r}")
    template_data = await handler(db, tenant_id, date_from, date_to)

    # Single timestamp so the header and the filename always agree.
    now = datetime.utcnow()
    generated_at = now.strftime("%Y-%m-%d %H:%M UTC")
    base_context = {
        "tenant_name": tenant_name,
        "generated_at": generated_at,
    }

    timestamp_str = now.strftime("%Y%m%d_%H%M%S")

    if fmt == "csv":
        file_bytes = _render_csv(report_type, template_data)
        content_type = "text/csv; charset=utf-8"
        filename = f"{report_type}_{timestamp_str}.csv"
    else:
        file_bytes = _render_pdf(report_type, {**base_context, **template_data})
        content_type = "application/pdf"
        filename = f"{report_type}_{timestamp_str}.pdf"

    elapsed = time.monotonic() - start
    logger.info(
        "report_generated",
        report_type=report_type,
        format=fmt,
        tenant_id=str(tenant_id),
        size_bytes=len(file_bytes),
        elapsed_seconds=round(elapsed, 2),
    )

    return file_bytes, content_type, filename
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tenant name helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _get_tenant_name(db: AsyncSession, tenant_id: UUID) -> str:
    """Look up the tenant's display name, defaulting when the row is missing."""
    query = text("SELECT name FROM tenants WHERE id = CAST(:tid AS uuid)")
    result = await db.execute(query, {"tid": str(tenant_id)})
    row = result.fetchone()
    if row is None:
        return "Unknown Tenant"
    return row[0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Report type handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _device_inventory(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Collect per-device inventory rows plus online/offline/unknown totals."""
    result = await db.execute(
        text("""
            SELECT d.hostname, d.ip_address, d.model, d.routeros_version,
                   d.status, d.last_seen, d.uptime_seconds,
                   COALESCE(
                       (SELECT string_agg(dg.name, ', ')
                        FROM device_group_memberships dgm
                        JOIN device_groups dg ON dg.id = dgm.group_id
                        WHERE dgm.device_id = d.id),
                       ''
                   ) AS groups
            FROM devices d
            ORDER BY d.hostname ASC
        """)
    )
    rows = result.fetchall()

    status_counts = {"online": 0, "offline": 0, "unknown": 0}
    devices: list[dict[str, Any]] = []

    for hostname, ip_address, model, ros_version, status, last_seen, uptime_secs, groups in rows:
        # Anything that is neither online nor offline counts as unknown.
        bucket = status if status in ("online", "offline") else "unknown"
        status_counts[bucket] += 1

        devices.append({
            "hostname": hostname,
            "ip_address": ip_address,
            "model": model,
            "routeros_version": ros_version,
            "status": status,
            "last_seen": last_seen.strftime("%Y-%m-%d %H:%M") if last_seen else None,
            "uptime": _format_uptime(uptime_secs) if uptime_secs else None,
            "groups": groups if groups else None,
        })

    return {
        "report_title": "Device Inventory",
        "devices": devices,
        "total_devices": len(devices),
        "online_count": status_counts["online"],
        "offline_count": status_counts["offline"],
        "unknown_count": status_counts["unknown"],
    }
|
||||
|
||||
|
||||
async def _metrics_summary(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Aggregate CPU/memory/disk/temperature statistics per device over the range."""
    result = await db.execute(
        text("""
            SELECT d.hostname,
                   AVG(hm.cpu_load) AS avg_cpu,
                   MAX(hm.cpu_load) AS peak_cpu,
                   AVG(CASE WHEN hm.total_memory > 0
                       THEN 100.0 * (hm.total_memory - hm.free_memory) / hm.total_memory
                       END) AS avg_mem,
                   MAX(CASE WHEN hm.total_memory > 0
                       THEN 100.0 * (hm.total_memory - hm.free_memory) / hm.total_memory
                       END) AS peak_mem,
                   AVG(CASE WHEN hm.total_disk > 0
                       THEN 100.0 * (hm.total_disk - hm.free_disk) / hm.total_disk
                       END) AS avg_disk,
                   AVG(hm.temperature) AS avg_temp,
                   COUNT(*) AS data_points
            FROM health_metrics hm
            JOIN devices d ON d.id = hm.device_id
            WHERE hm.time >= :date_from
              AND hm.time <= :date_to
            GROUP BY d.id, d.hostname
            ORDER BY avg_cpu DESC NULLS LAST
        """),
        {
            "date_from": date_from,
            "date_to": date_to,
        },
    )
    rows = result.fetchall()

    def _as_float(v: Any) -> Optional[float]:
        # SQL aggregates come back as Decimal/None; normalize to float/None.
        return float(v) if v is not None else None

    devices = [
        {
            "hostname": row[0],
            "avg_cpu": _as_float(row[1]),
            "peak_cpu": _as_float(row[2]),
            "avg_mem": _as_float(row[3]),
            "peak_mem": _as_float(row[4]),
            "avg_disk": _as_float(row[5]),
            "avg_temp": _as_float(row[6]),
            "data_points": row[7],
        }
        for row in rows
    ]

    return {
        "report_title": "Metrics Summary",
        "devices": devices,
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
|
||||
|
||||
|
||||
async def _alert_history(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Gather alert history rows plus severity counts and MTTR for the window."""
    result = await db.execute(
        text("""
            SELECT ae.fired_at, ae.resolved_at, ae.severity, ae.status,
                   ae.message, d.hostname,
                   EXTRACT(EPOCH FROM (ae.resolved_at - ae.fired_at)) AS duration_secs
            FROM alert_events ae
            LEFT JOIN devices d ON d.id = ae.device_id
            WHERE ae.fired_at >= :date_from
              AND ae.fired_at <= :date_to
            ORDER BY ae.fired_at DESC
        """),
        {
            "date_from": date_from,
            "date_to": date_to,
        },
    )

    alerts: list[dict[str, Any]] = []
    severity_counts = {"critical": 0, "warning": 0, "info": 0}
    resolved_durations: list[float] = []

    for row in result.fetchall():
        severity = row[2]
        # Anything that is neither critical nor warning is bucketed as info.
        bucket = severity if severity in ("critical", "warning") else "info"
        severity_counts[bucket] += 1

        # duration_secs is NULL for alerts that never resolved.
        duration_secs = float(row[6]) if row[6] is not None else None
        if duration_secs is not None:
            resolved_durations.append(duration_secs)

        alerts.append({
            "fired_at": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
            "hostname": row[5],
            "severity": severity,
            "status": row[3],
            "message": row[4],
            "duration": _format_duration(duration_secs) if duration_secs is not None else None,
        })

    # MTTR = mean time to resolve, over the alerts that actually resolved.
    mttr_minutes = None
    mttr_display = None
    if resolved_durations:
        avg_secs = sum(resolved_durations) / len(resolved_durations)
        mttr_minutes = round(avg_secs / 60, 1)
        mttr_display = _format_duration(avg_secs)

    return {
        "report_title": "Alert History",
        "alerts": alerts,
        "total_alerts": len(alerts),
        "critical_count": severity_counts["critical"],
        "warning_count": severity_counts["warning"],
        "info_count": severity_counts["info"],
        "mttr_minutes": mttr_minutes,
        "mttr_display": mttr_display,
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
|
||||
|
||||
|
||||
async def _change_log(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Gather change log data -- try audit_logs table first, fall back to config_backups."""
    # audit_logs may not exist yet (migration 17-01 may not have run),
    # so pick the data source at runtime.
    if await _table_exists(db, "audit_logs"):
        return await _change_log_from_audit(db, date_from, date_to)
    return await _change_log_from_backups(db, date_from, date_to)
|
||||
|
||||
|
||||
async def _table_exists(db: AsyncSession, table_name: str) -> bool:
    """Return True if *table_name* exists in the public schema."""
    exists_query = text("""
            SELECT EXISTS (
                SELECT 1 FROM information_schema.tables
                WHERE table_schema = 'public' AND table_name = :table_name
            )
        """)
    # EXISTS yields a single boolean scalar; bool() guards against None.
    result = await db.execute(exists_query, {"table_name": table_name})
    return bool(result.scalar())
|
||||
|
||||
|
||||
async def _change_log_from_audit(
    db: AsyncSession,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Build change log entries from the audit_logs table."""
    result = await db.execute(
        text("""
            SELECT al.created_at, u.name AS user_name, al.action,
                   d.hostname, al.resource_type,
                   al.details
            FROM audit_logs al
            LEFT JOIN users u ON u.id = al.user_id
            LEFT JOIN devices d ON d.id = al.device_id
            WHERE al.created_at >= :date_from
              AND al.created_at <= :date_to
            ORDER BY al.created_at DESC
        """),
        {
            "date_from": date_from,
            "date_to": date_to,
        },
    )

    entries = [
        {
            "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
            "user": row[1],
            "action": row[2],
            "device": row[3],
            # NOTE(review): row[4] is resource_type and row[5] is details, so
            # this prefers resource_type over the details column — confirm
            # that ordering is intentional and not an off-by-one.
            "details": row[4] or row[5] or "",
        }
        for row in result.fetchall()
    ]

    return {
        "report_title": "Change Log",
        "entries": entries,
        "total_entries": len(entries),
        "data_source": "Audit Logs",
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
|
||||
|
||||
|
||||
async def _change_log_from_backups(
    db: AsyncSession,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Build change log from config_backups + alert_events as fallback."""
    window = {
        "date_from": date_from,
        "date_to": date_to,
    }

    # Config backups as change events
    backup_result = await db.execute(
        text("""
            SELECT cb.created_at, 'system' AS user_name, 'config_backup' AS action,
                   d.hostname, cb.trigger_type AS details
            FROM config_backups cb
            JOIN devices d ON d.id = cb.device_id
            WHERE cb.created_at >= :date_from
              AND cb.created_at <= :date_to
        """),
        window,
    )

    # Alert events as change events
    alert_result = await db.execute(
        text("""
            SELECT ae.fired_at, 'system' AS user_name,
                   ae.severity || '_alert' AS action,
                   d.hostname, ae.message AS details
            FROM alert_events ae
            LEFT JOIN devices d ON d.id = ae.device_id
            WHERE ae.fired_at >= :date_from
              AND ae.fired_at <= :date_to
        """),
        window,
    )

    # Both queries project the same five columns, so one shaping pass
    # covers the merged row set.
    entries = [
        {
            "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
            "user": row[1],
            "action": row[2],
            "device": row[3],
            "details": row[4] or "",
        }
        for row in [*backup_result.fetchall(), *alert_result.fetchall()]
    ]

    # Sort by the formatted timestamp string, newest first. "-" placeholders
    # (missing timestamps) compare below digits and therefore sink to the end.
    entries.sort(key=lambda e: e["timestamp"], reverse=True)

    return {
        "report_title": "Change Log",
        "entries": entries,
        "total_entries": len(entries),
        "data_source": "Backups + Alerts",
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rendering helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _render_pdf(report_type: str, context: dict[str, Any]) -> bytes:
    """Render the report's HTML template and convert it to PDF via weasyprint."""
    # Imported lazily: weasyprint is heavy and only needed for PDF output.
    import weasyprint

    html_markup = _jinja_env.get_template(f"reports/{report_type}.html").render(**context)
    return weasyprint.HTML(string=html_markup).write_pdf()
|
||||
|
||||
|
||||
def _render_csv(report_type: str, data: dict[str, Any]) -> bytes:
|
||||
"""Render report data as CSV bytes."""
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
|
||||
if report_type == "device_inventory":
|
||||
writer.writerow([
|
||||
"Hostname", "IP Address", "Model", "RouterOS Version",
|
||||
"Status", "Last Seen", "Uptime", "Groups",
|
||||
])
|
||||
for d in data.get("devices", []):
|
||||
writer.writerow([
|
||||
d["hostname"], d["ip_address"], d["model"] or "",
|
||||
d["routeros_version"] or "", d["status"],
|
||||
d["last_seen"] or "", d["uptime"] or "",
|
||||
d["groups"] or "",
|
||||
])
|
||||
|
||||
elif report_type == "metrics_summary":
|
||||
writer.writerow([
|
||||
"Hostname", "Avg CPU %", "Peak CPU %", "Avg Memory %",
|
||||
"Peak Memory %", "Avg Disk %", "Avg Temp", "Data Points",
|
||||
])
|
||||
for d in data.get("devices", []):
|
||||
writer.writerow([
|
||||
d["hostname"],
|
||||
f"{d['avg_cpu']:.1f}" if d["avg_cpu"] is not None else "",
|
||||
f"{d['peak_cpu']:.1f}" if d["peak_cpu"] is not None else "",
|
||||
f"{d['avg_mem']:.1f}" if d["avg_mem"] is not None else "",
|
||||
f"{d['peak_mem']:.1f}" if d["peak_mem"] is not None else "",
|
||||
f"{d['avg_disk']:.1f}" if d["avg_disk"] is not None else "",
|
||||
f"{d['avg_temp']:.1f}" if d["avg_temp"] is not None else "",
|
||||
d["data_points"],
|
||||
])
|
||||
|
||||
elif report_type == "alert_history":
|
||||
writer.writerow([
|
||||
"Timestamp", "Device", "Severity", "Message", "Status", "Duration",
|
||||
])
|
||||
for a in data.get("alerts", []):
|
||||
writer.writerow([
|
||||
a["fired_at"], a["hostname"] or "", a["severity"],
|
||||
a["message"] or "", a["status"], a["duration"] or "",
|
||||
])
|
||||
|
||||
elif report_type == "change_log":
|
||||
writer.writerow([
|
||||
"Timestamp", "User", "Action", "Device", "Details",
|
||||
])
|
||||
for e in data.get("entries", []):
|
||||
writer.writerow([
|
||||
e["timestamp"], e["user"] or "", e["action"],
|
||||
e["device"] or "", e["details"] or "",
|
||||
])
|
||||
|
||||
return output.getvalue().encode("utf-8")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Formatting utilities
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _format_uptime(seconds: int) -> str:
|
||||
"""Format uptime seconds as human-readable string."""
|
||||
days = seconds // 86400
|
||||
hours = (seconds % 86400) // 3600
|
||||
minutes = (seconds % 3600) // 60
|
||||
if days > 0:
|
||||
return f"{days}d {hours}h {minutes}m"
|
||||
elif hours > 0:
|
||||
return f"{hours}h {minutes}m"
|
||||
else:
|
||||
return f"{minutes}m"
|
||||
|
||||
|
||||
def _format_duration(seconds: float) -> str:
|
||||
"""Format a duration in seconds as a human-readable string."""
|
||||
if seconds < 60:
|
||||
return f"{int(seconds)}s"
|
||||
elif seconds < 3600:
|
||||
return f"{int(seconds // 60)}m {int(seconds % 60)}s"
|
||||
elif seconds < 86400:
|
||||
hours = int(seconds // 3600)
|
||||
mins = int((seconds % 3600) // 60)
|
||||
return f"{hours}h {mins}m"
|
||||
else:
|
||||
days = int(seconds // 86400)
|
||||
hours = int((seconds % 86400) // 3600)
|
||||
return f"{days}d {hours}h"
|
||||
599
backend/app/services/restore_service.py
Normal file
599
backend/app/services/restore_service.py
Normal file
@@ -0,0 +1,599 @@
|
||||
"""Two-phase config push with panic-revert safety for RouterOS devices.
|
||||
|
||||
This module implements the critical safety mechanism for config restoration:
|
||||
|
||||
Phase 1 — Push:
|
||||
1. Pre-backup (mandatory) — snapshot current config before any changes
|
||||
2. Install panic-revert RouterOS scheduler — auto-reverts if device becomes
|
||||
unreachable (the scheduler fires after 90s and loads the pre-push backup)
|
||||
3. Push the target config via SSH /import
|
||||
|
||||
Phase 2 — Verification (60s settle window):
|
||||
4. Wait 60s for config to settle (scheduled processes restart, etc.)
|
||||
5. Reachability check via asyncssh
|
||||
6a. Reachable — remove panic-revert scheduler; mark operation committed
|
||||
6b. Unreachable — RouterOS is auto-reverting; mark operation reverted
|
||||
|
||||
Pitfall 6 handling:
|
||||
If the API pod restarts during the 60s window, the config_push_operations
|
||||
row with status='pending_verification' serves as the recovery signal.
|
||||
On startup, recover_stale_push_operations() resolves any stale rows.
|
||||
|
||||
Security policy:
|
||||
known_hosts=None — RouterOS self-signed host keys; mirrors InsecureSkipVerify
|
||||
used in the poller's TLS connection. See Pitfall 2 in 04-RESEARCH.md.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import asyncssh
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import set_tenant_context, AdminAsyncSessionLocal
|
||||
from app.models.config_backup import ConfigPushOperation
|
||||
from app.models.device import Device
|
||||
from app.services import backup_service, git_store
|
||||
from app.services.event_publisher import publish_event
|
||||
from app.services.push_tracker import record_push, clear_push
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Name of the panic-revert scheduler installed on the RouterOS device
|
||||
_PANIC_REVERT_SCHEDULER = "mikrotik-portal-panic-revert"
|
||||
# Name of the pre-push binary backup saved on device flash
|
||||
_PRE_PUSH_BACKUP = "portal-pre-push"
|
||||
# Name of the RSC file used for /import on device
|
||||
_RESTORE_RSC = "portal-restore.rsc"
|
||||
|
||||
|
||||
async def _publish_push_progress(
    tenant_id: str,
    device_id: str,
    stage: str,
    message: str,
    push_op_id: str | None = None,
    error: str | None = None,
) -> None:
    """Publish config push progress event to NATS (fire-and-forget)."""
    event: dict = {
        "event_type": "config_push",
        "tenant_id": tenant_id,
        "device_id": device_id,
        "stage": stage,
        "message": message,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "push_operation_id": push_op_id,
    }
    # Only attach the error key when there is an actual error string.
    if error:
        event["error"] = error
    await publish_event(f"config.push.{tenant_id}.{device_id}", event)
|
||||
|
||||
|
||||
async def restore_config(
    device_id: str,
    tenant_id: str,
    commit_sha: str,
    db_session: AsyncSession,
) -> dict:
    """Restore a device config to a specific backup version via two-phase push.

    Args:
        device_id: Device UUID as string.
        tenant_id: Tenant UUID as string.
        commit_sha: Git commit SHA of the backup version to restore.
        db_session: AsyncSession with RLS context already set (from API endpoint).

    Returns:
        {
            "status": "committed" | "reverted" | "failed",
            "message": str,
            "pre_backup_sha": str,
        }

    Raises:
        ValueError: If device not found or missing credentials.
        Exception: On SSH failure during push phase (reverted status logged).
    """
    # get_running_loop() is the correct API inside a coroutine;
    # get_event_loop() is deprecated for this use since Python 3.10.
    loop = asyncio.get_running_loop()

    # ------------------------------------------------------------------
    # Step 1: Load device from DB and decrypt credentials
    # ------------------------------------------------------------------
    from sqlalchemy import select

    result = await db_session.execute(
        select(Device).where(Device.id == device_id)  # type: ignore[arg-type]
    )
    device = result.scalar_one_or_none()
    if device is None:
        raise ValueError(f"Device {device_id!r} not found")

    if not device.encrypted_credentials_transit and not device.encrypted_credentials:
        raise ValueError(
            f"Device {device_id!r} has no stored credentials — cannot perform restore"
        )

    key = settings.get_encryption_key_bytes()
    from app.services.crypto import decrypt_credentials_hybrid
    creds_json = await decrypt_credentials_hybrid(
        device.encrypted_credentials_transit,
        device.encrypted_credentials,
        str(device.tenant_id),
        key,
    )
    creds = json.loads(creds_json)
    ssh_username = creds.get("username", "")
    ssh_password = creds.get("password", "")
    ip = device.ip_address

    hostname = device.hostname or ip

    # Publish "started" progress event
    await _publish_push_progress(tenant_id, device_id, "started", f"Config restore started for {hostname}")

    # ------------------------------------------------------------------
    # Step 2: Read the target export.rsc from the backup commit
    # ------------------------------------------------------------------
    try:
        # git_store is synchronous — run it off the event loop.
        export_bytes = await loop.run_in_executor(
            None,
            git_store.read_file,
            tenant_id,
            commit_sha,
            device_id,
            "export.rsc",
        )
    except Exception as exc:
        # Exception already covers KeyError; the old (KeyError, Exception)
        # tuple was redundant. Any read failure means the version is unusable.
        raise ValueError(
            f"Backup version {commit_sha!r} not found for device {device_id!r}: {exc}"
        ) from exc

    export_text = export_bytes.decode("utf-8", errors="replace")

    # ------------------------------------------------------------------
    # Step 3: Mandatory pre-backup before push
    # ------------------------------------------------------------------
    await _publish_push_progress(tenant_id, device_id, "backing_up", f"Creating pre-restore backup for {hostname}")

    logger.info(
        "Starting pre-restore backup for device %s (%s) before pushing commit %s",
        hostname,
        ip,
        commit_sha[:8],
    )
    pre_backup_result = await backup_service.run_backup(
        device_id=device_id,
        tenant_id=tenant_id,
        trigger_type="pre-restore",
        db_session=db_session,
    )
    pre_backup_sha = pre_backup_result["commit_sha"]
    logger.info("Pre-restore backup complete: %s", pre_backup_sha[:8])

    # ------------------------------------------------------------------
    # Step 4: Record push operation (pending_verification for recovery)
    # ------------------------------------------------------------------
    push_op = ConfigPushOperation(
        device_id=device.id,
        tenant_id=device.tenant_id,
        pre_push_commit_sha=pre_backup_sha,
        scheduler_name=_PANIC_REVERT_SCHEDULER,
        status="pending_verification",
    )
    db_session.add(push_op)
    await db_session.flush()
    push_op_id = push_op.id

    logger.info(
        "Push op %s in pending_verification — if API restarts, "
        "recover_stale_push_operations() will resolve on next startup",
        push_op.id,
    )

    # ------------------------------------------------------------------
    # Step 5: SSH to device — install panic-revert, push config
    # ------------------------------------------------------------------
    push_op_id_str = str(push_op_id)
    await _publish_push_progress(tenant_id, device_id, "pushing", f"Pushing config to {hostname}", push_op_id=push_op_id_str)

    logger.info(
        "Pushing config to device %s (%s): installing panic-revert scheduler and uploading config",
        hostname,
        ip,
    )

    try:
        async with asyncssh.connect(
            ip,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,  # RouterOS self-signed host keys — see module docstring
            connect_timeout=30,
        ) as conn:
            # 5a: Create binary backup on device as revert point
            await conn.run(
                f"/system backup save name={_PRE_PUSH_BACKUP} dont-encrypt=yes",
                check=True,
            )
            logger.debug("Pre-push binary backup saved on device as %s.backup", _PRE_PUSH_BACKUP)

            # 5b: Install panic-revert RouterOS scheduler
            # The scheduler fires after 90s on startup and loads the pre-push backup.
            # This is the safety net: if the device becomes unreachable after push,
            # RouterOS will auto-revert to the known-good config on the next reboot
            # or after 90s of uptime.
            await conn.run(
                f"/system scheduler add "
                f'name="{_PANIC_REVERT_SCHEDULER}" '
                f"interval=90s "
                f'on-event=":delay 0; /system backup load name={_PRE_PUSH_BACKUP}" '
                f"start-time=startup",
                check=True,
            )
            logger.debug("Panic-revert scheduler installed on device")

            # 5c: Upload export.rsc and /import it.
            # RouterOS 7 supports SFTP, which is the reliable way to place a
            # config file on flash (line-by-line /file writes are RouterOS 6
            # only and flaky). The file is removed after import.
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(_RESTORE_RSC, "wb") as f:
                    await f.write(export_text.encode("utf-8"))
            logger.debug("Uploaded %s to device flash", _RESTORE_RSC)

            # /import the config file
            import_result = await conn.run(
                f"/import file={_RESTORE_RSC}",
                check=False,  # Don't raise on non-zero exit — import may succeed with warnings
            )
            logger.info(
                "Config import result for device %s: exit_status=%s stdout=%r",
                hostname,
                import_result.exit_status,
                (import_result.stdout or "")[:200],
            )

            # Clean up the uploaded RSC file (best-effort)
            try:
                await conn.run(f"/file remove {_RESTORE_RSC}", check=True)
            except Exception as cleanup_err:
                logger.warning(
                    "Failed to clean up %s from device %s: %s",
                    _RESTORE_RSC,
                    ip,
                    cleanup_err,
                )

    except Exception as push_err:
        logger.error(
            "SSH push phase failed for device %s (%s): %s",
            hostname,
            ip,
            push_err,
        )
        # Update push operation to failed
        await _update_push_op_status(push_op_id, "failed", db_session)
        await _publish_push_progress(
            tenant_id, device_id, "failed",
            f"Config push failed for {hostname}: {push_err}",
            push_op_id=push_op_id_str, error=str(push_err),
        )
        return {
            "status": "failed",
            "message": f"Config push failed during SSH phase: {push_err}",
            "pre_backup_sha": pre_backup_sha,
        }

    # Record push in Redis so the poller can detect post-push offline events
    await record_push(
        device_id=device_id,
        tenant_id=tenant_id,
        push_type="restore",
        push_operation_id=push_op_id_str,
        pre_push_commit_sha=pre_backup_sha,
    )

    # ------------------------------------------------------------------
    # Step 6: Wait 60s for config to settle
    # ------------------------------------------------------------------
    await _publish_push_progress(tenant_id, device_id, "settling", f"Config pushed to {hostname} — waiting 60s for settle", push_op_id=push_op_id_str)

    logger.info(
        "Config pushed to device %s — waiting 60s for config to settle",
        hostname,
    )
    await asyncio.sleep(60)

    # ------------------------------------------------------------------
    # Step 7: Reachability check
    # ------------------------------------------------------------------
    await _publish_push_progress(tenant_id, device_id, "verifying", f"Verifying device {hostname} reachability", push_op_id=push_op_id_str)

    reachable = await _check_reachability(ip, ssh_username, ssh_password)

    if reachable:
        # ------------------------------------------------------------------
        # Step 8a: Device is reachable — remove panic-revert scheduler + cleanup
        # ------------------------------------------------------------------
        logger.info("Device %s (%s) is reachable after push — committing", hostname, ip)
        try:
            async with asyncssh.connect(
                ip,
                port=22,
                username=ssh_username,
                password=ssh_password,
                known_hosts=None,
                connect_timeout=30,
            ) as conn:
                # Remove the panic-revert scheduler
                await conn.run(
                    f'/system scheduler remove "{_PANIC_REVERT_SCHEDULER}"',
                    check=False,  # Non-fatal if already removed
                )
                # Clean up the pre-push binary backup from device flash
                await conn.run(
                    f"/file remove {_PRE_PUSH_BACKUP}.backup",
                    check=False,  # Non-fatal if already removed
                )
        except Exception as cleanup_err:
            # Cleanup failure is non-fatal — scheduler will eventually fire but
            # the backup is now the correct config, so it's acceptable.
            logger.warning(
                "Failed to clean up panic-revert scheduler/backup on device %s: %s",
                hostname,
                cleanup_err,
            )

        await _update_push_op_status(push_op_id, "committed", db_session)
        await clear_push(device_id)
        await _publish_push_progress(tenant_id, device_id, "committed", f"Config restored successfully on {hostname}", push_op_id=push_op_id_str)

        return {
            "status": "committed",
            "message": "Config restored successfully",
            "pre_backup_sha": pre_backup_sha,
        }

    else:
        # ------------------------------------------------------------------
        # Step 8b: Device unreachable — RouterOS is auto-reverting via scheduler
        # ------------------------------------------------------------------
        logger.warning(
            "Device %s (%s) is unreachable after push — RouterOS panic-revert scheduler "
            "will auto-revert to %s.backup",
            hostname,
            ip,
            _PRE_PUSH_BACKUP,
        )

        await _update_push_op_status(push_op_id, "reverted", db_session)
        await _publish_push_progress(
            tenant_id, device_id, "reverted",
            f"Device {hostname} unreachable — auto-reverting via panic-revert scheduler",
            push_op_id=push_op_id_str,
        )

        return {
            "status": "reverted",
            "message": (
                "Device unreachable after push; RouterOS is auto-reverting "
                "via panic-revert scheduler"
            ),
            "pre_backup_sha": pre_backup_sha,
        }
|
||||
|
||||
|
||||
async def _check_reachability(ip: str, username: str, password: str) -> bool:
    """Check if a RouterOS device is reachable via SSH.

    Connects and runs a trivial command (/system identity print); any
    failure — auth error, timeout, command failure — counts as unreachable
    (the panic-revert scheduler will then handle recovery).

    Uses asyncssh (not the poller's binary API) to avoid a circular import.
    A 30-second connect timeout bounds the wait.

    Args:
        ip: Device IP address.
        username: SSH username.
        password: SSH password.

    Returns:
        True if reachable, False if unreachable.
    """
    try:
        async with asyncssh.connect(
            ip,
            port=22,
            username=username,
            password=password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            identity = await conn.run("/system identity print", check=True)
            logger.debug("Reachability check OK for %s: %r", ip, identity.stdout[:50])
    except Exception as exc:
        logger.info("Device %s unreachable after push: %s", ip, exc)
        return False
    return True
|
||||
|
||||
|
||||
async def _update_push_op_status(
    push_op_id,
    new_status: str,
    db_session: AsyncSession,
) -> None:
    """Update the status and completed_at of a ConfigPushOperation row.

    Args:
        push_op_id: UUID of the ConfigPushOperation row.
        new_status: New status value ('committed' | 'reverted' | 'failed').
        db_session: Database session (must already have tenant context set).
    """
    # `select` was imported here but never used — only `update` is needed.
    from sqlalchemy import update

    await db_session.execute(
        update(ConfigPushOperation)
        .where(ConfigPushOperation.id == push_op_id)  # type: ignore[arg-type]
        .values(
            status=new_status,
            completed_at=datetime.now(timezone.utc),
        )
    )
    # Don't commit here — the caller (endpoint) owns the transaction
|
||||
|
||||
|
||||
async def _remove_panic_scheduler(
    ip: str, username: str, password: str, scheduler_name: str
) -> bool:
    """SSH to device and remove the panic-revert scheduler. Returns True if removed."""
    try:
        async with asyncssh.connect(
            ip,
            username=username,
            password=password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            # List schedulers with the given name; absence means the device
            # already fired the revert (or a previous cleanup removed it).
            listing = await conn.run(
                f'/system scheduler print where name="{scheduler_name}"',
                check=False,
            )
            if scheduler_name not in listing.stdout:
                return False  # Scheduler already gone (device reverted itself)

            await conn.run(
                f'/system scheduler remove [find name="{scheduler_name}"]',
                check=False,
            )
            # Also clean up pre-push backup file
            await conn.run(
                f'/file remove [find name="{_PRE_PUSH_BACKUP}.backup"]',
                check=False,
            )
            return True
    except Exception as e:
        logger.error("Failed to remove panic scheduler from %s: %s", ip, e)
        return False
|
||||
|
||||
|
||||
async def recover_stale_push_operations(db_session: AsyncSession) -> None:
    """Recover stale pending_verification push operations on API startup.

    Scans for operations older than 5 minutes that are still pending.
    For each, checks device reachability and resolves the operation:

    - Reachable: marks the op "committed" regardless of whether the panic
      scheduler was still present (the scheduler being gone only means the
      device may have reverted itself), and publishes a progress event.
    - Unreachable (or device row missing / any per-op error): marks the op
      "failed" and publishes a progress event where applicable.

    Args:
        db_session: Async DB session used for all queries; committed once
            at the end of the scan.
    """
    # Local imports avoid import cycles at module load time.
    from sqlalchemy import select

    from app.models.config_backup import ConfigPushOperation
    from app.models.device import Device
    from app.services.crypto import decrypt_credentials_hybrid

    # Anything still pending after 5 minutes is considered orphaned by a
    # previous API process.
    cutoff = datetime.now(timezone.utc) - timedelta(minutes=5)

    result = await db_session.execute(
        select(ConfigPushOperation).where(
            ConfigPushOperation.status == "pending_verification",
            ConfigPushOperation.started_at < cutoff,
        )
    )
    stale_ops = result.scalars().all()

    if not stale_ops:
        logger.info("No stale push operations to recover")
        return

    logger.warning("Found %d stale push operations to recover", len(stale_ops))

    # Encryption key fetched once; reused for every op's credential decrypt.
    key = settings.get_encryption_key_bytes()

    for op in stale_ops:
        try:
            # Load device
            dev_result = await db_session.execute(
                select(Device).where(Device.id == op.device_id)
            )
            device = dev_result.scalar_one_or_none()
            if not device:
                logger.error("Device %s not found for stale op %s", op.device_id, op.id)
                await _update_push_op_status(op.id, "failed", db_session)
                continue

            # Decrypt credentials
            creds_json = await decrypt_credentials_hybrid(
                device.encrypted_credentials_transit,
                device.encrypted_credentials,
                str(op.tenant_id),
                key,
            )
            creds = json.loads(creds_json)
            ssh_username = creds.get("username", "admin")
            ssh_password = creds.get("password", "")

            # Check reachability
            reachable = await _check_reachability(
                device.ip_address, ssh_username, ssh_password
            )

            if reachable:
                # Try to remove scheduler (if still there, push was good)
                removed = await _remove_panic_scheduler(
                    device.ip_address,
                    ssh_username,
                    ssh_password,
                    op.scheduler_name,
                )
                if removed:
                    logger.info("Recovery: committed op %s (scheduler removed)", op.id)
                else:
                    # Scheduler already gone — device may have reverted
                    logger.warning(
                        "Recovery: op %s — scheduler gone, device may have reverted. "
                        "Marking committed (device is reachable).",
                        op.id,
                    )
                # Either way the device is reachable, so resolve as committed.
                await _update_push_op_status(op.id, "committed", db_session)

                await _publish_push_progress(
                    str(op.tenant_id),
                    str(op.device_id),
                    "committed",
                    "Recovered after API restart",
                    push_op_id=str(op.id),
                )
            else:
                logger.warning(
                    "Recovery: device %s unreachable, marking op %s failed",
                    op.device_id,
                    op.id,
                )
                await _update_push_op_status(op.id, "failed", db_session)
                await _publish_push_progress(
                    str(op.tenant_id),
                    str(op.device_id),
                    "failed",
                    "Device unreachable during recovery after API restart",
                    push_op_id=str(op.id),
                )

        except Exception as e:
            # Per-op failures must not abort recovery of the remaining ops.
            logger.error("Recovery failed for op %s: %s", op.id, e)
            await _update_push_op_status(op.id, "failed", db_session)

    # Single commit for all status updates made during the scan.
    await db_session.commit()
|
||||
165
backend/app/services/routeros_proxy.py
Normal file
165
backend/app/services/routeros_proxy.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""RouterOS command proxy via NATS request-reply.
|
||||
|
||||
Sends command requests to the Go poller's CmdResponder subscription
|
||||
(device.cmd.{device_id}) and returns structured RouterOS API response data.
|
||||
|
||||
Used by:
|
||||
- Config editor API (browse menu paths, add/edit/delete entries)
|
||||
- Template push service (execute rendered template commands)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import nats
|
||||
import nats.aio.client
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level NATS connection (lazy initialized)
|
||||
_nc: nats.aio.client.Client | None = None
|
||||
|
||||
|
||||
async def _get_nats() -> nats.aio.client.Client:
    """Get or create a NATS connection for command proxy requests."""
    global _nc
    # Reuse the cached connection while it is still open.
    if _nc is not None and not _nc.is_closed:
        return _nc
    _nc = await nats.connect(settings.NATS_URL)
    logger.info("RouterOS proxy NATS connection established")
    return _nc
||||
|
||||
|
||||
async def execute_command(
    device_id: str,
    command: str,
    args: list[str] | None = None,
    timeout: float = 15.0,
) -> dict[str, Any]:
    """Execute a RouterOS API command on a device via the Go poller.

    Args:
        device_id: UUID string of the target device.
        command: Full RouterOS API path, e.g. "/ip/address/print".
        args: Optional list of RouterOS API args, e.g. ["=.proplist=.id,address"].
        timeout: NATS request timeout in seconds (default 15s).

    Returns:
        {"success": bool, "data": list[dict], "error": str|None}
    """
    nc = await _get_nats()
    payload = json.dumps(
        {
            "device_id": device_id,
            "command": command,
            "args": args or [],
        }
    ).encode()

    try:
        # Request-reply to the poller's per-device CmdResponder subject;
        # the reply body is the structured response dict as JSON.
        reply = await nc.request(
            f"device.cmd.{device_id}",
            payload,
            timeout=timeout,
        )
        return json.loads(reply.data)
    except nats.errors.TimeoutError:
        return {
            "success": False,
            "data": [],
            "error": "Device command timed out — device may be offline or unreachable",
        }
    except Exception as exc:
        logger.error("NATS request failed for device %s: %s", device_id, exc)
        return {"success": False, "data": [], "error": str(exc)}
|
||||
|
||||
|
||||
async def browse_menu(device_id: str, path: str) -> dict[str, Any]:
    """Browse a RouterOS menu path and return all entries.

    Args:
        device_id: Device UUID string.
        path: RouterOS menu path, e.g. "/ip/address" or "/interface".

    Returns:
        {"success": bool, "data": list[dict], "error": str|None}
    """
    # Listing a menu is simply the path's "print" action with no args.
    return await execute_command(device_id, f"{path}/print")
|
||||
|
||||
|
||||
async def add_entry(
    device_id: str, path: str, properties: dict[str, str]
) -> dict[str, Any]:
    """Add a new entry to a RouterOS menu path.

    Args:
        device_id: Device UUID.
        path: Menu path, e.g. "/ip/address".
        properties: Key-value pairs for the new entry.

    Returns:
        Command response dict.
    """
    # Each property becomes a "=key=value" RouterOS API argument.
    prop_args = [f"={key}={value}" for key, value in properties.items()]
    return await execute_command(device_id, f"{path}/add", prop_args)
|
||||
|
||||
|
||||
async def update_entry(
    device_id: str, path: str, entry_id: str | None, properties: dict[str, str]
) -> dict[str, Any]:
    """Update an existing entry in a RouterOS menu path.

    Args:
        device_id: Device UUID.
        path: Menu path.
        entry_id: RouterOS .id value (e.g. "*1"). None for singleton paths.
        properties: Key-value pairs to update.

    Returns:
        Command response dict.
    """
    cmd_args: list[str] = []
    if entry_id:
        # Target a specific row; singleton paths (e.g. /system identity)
        # take no .id argument.
        cmd_args.append(f"=.id={entry_id}")
    cmd_args.extend(f"={key}={value}" for key, value in properties.items())
    return await execute_command(device_id, f"{path}/set", cmd_args)
|
||||
|
||||
|
||||
async def remove_entry(
    device_id: str, path: str, entry_id: str
) -> dict[str, Any]:
    """Remove an entry from a RouterOS menu path.

    Args:
        device_id: Device UUID.
        path: Menu path.
        entry_id: RouterOS .id value.

    Returns:
        Command response dict.
    """
    id_arg = f"=.id={entry_id}"
    return await execute_command(device_id, f"{path}/remove", [id_arg])
|
||||
|
||||
|
||||
async def execute_cli(device_id: str, cli_command: str) -> dict[str, Any]:
    """Execute an arbitrary RouterOS CLI command.

    For commands that don't follow the standard /path/action pattern.
    The command is sent as-is to the RouterOS API.

    Args:
        device_id: Device UUID.
        cli_command: Full CLI command string.

    Returns:
        Command response dict.
    """
    # Forward the raw command string with no extra API arguments.
    return await execute_command(device_id, cli_command)
|
||||
|
||||
|
||||
async def close() -> None:
    """Close the NATS connection. Called on application shutdown."""
    global _nc
    # Nothing to do when the connection was never opened or already closed.
    if _nc is None or _nc.is_closed:
        return
    # drain() flushes pending messages before closing.
    await _nc.drain()
    _nc = None
    logger.info("RouterOS proxy NATS connection closed")
|
||||
220
backend/app/services/rsc_parser.py
Normal file
220
backend/app/services/rsc_parser.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""RouterOS RSC export parser — extracts categories, validates syntax, computes impact."""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# RouterOS paths whose changes can break connectivity or lock out
# management access; compute_impact() flags changed commands under these
# paths as "high" risk.
HIGH_RISK_PATHS = {
    "/ip address", "/ip route", "/ip firewall filter", "/ip firewall nat",
    "/interface", "/interface bridge", "/interface vlan",
    "/system identity", "/ip service", "/ip ssh", "/user",
}

# (compiled regex, warning message) pairs matched against the flattened
# target configuration in compute_impact(); each hit yields one
# human-readable warning about potential loss of management access.
MANAGEMENT_PATTERNS = [
    (re.compile(r"chain=input.*dst-port=(22|8291|8728|8729|443|80)", re.I),
     "Modifies firewall rules for management ports (SSH/WinBox/API/Web)"),
    (re.compile(r"chain=input.*action=drop", re.I),
     "Adds drop rule on input chain — may block management access"),
    (re.compile(r"/ip service", re.I),
     "Modifies IP services — may disable API/SSH/WinBox access"),
    (re.compile(r"/user.*set.*password", re.I),
     "Changes user password — may affect automated access"),
]
|
||||
|
||||
|
||||
def _join_continuation_lines(text: str) -> list[str]:
|
||||
"""Join lines ending with \\ into single logical lines."""
|
||||
lines = text.split("\n")
|
||||
joined: list[str] = []
|
||||
buf = ""
|
||||
for line in lines:
|
||||
stripped = line.rstrip()
|
||||
if stripped.endswith("\\"):
|
||||
buf += stripped[:-1].rstrip() + " "
|
||||
else:
|
||||
if buf:
|
||||
buf += stripped
|
||||
joined.append(buf)
|
||||
buf = ""
|
||||
else:
|
||||
joined.append(stripped)
|
||||
if buf:
|
||||
joined.append(buf + " <<TRUNCATED>>")
|
||||
return joined
|
||||
|
||||
|
||||
def parse_rsc(text: str) -> dict[str, Any]:
|
||||
"""Parse a RouterOS /export compact output.
|
||||
|
||||
Returns a dict with a "categories" list, each containing:
|
||||
- path: the RouterOS command path (e.g. "/ip address")
|
||||
- adds: count of "add" commands
|
||||
- sets: count of "set" commands
|
||||
- removes: count of "remove" commands
|
||||
- commands: list of command strings under this path
|
||||
"""
|
||||
lines = _join_continuation_lines(text)
|
||||
categories: dict[str, dict] = {}
|
||||
current_path: str | None = None
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
if line.startswith("/"):
|
||||
# Could be just a path header, or a path followed by a command
|
||||
parts = line.split(None, 1)
|
||||
if len(parts) == 1:
|
||||
# Pure path header like "/interface bridge"
|
||||
current_path = parts[0]
|
||||
else:
|
||||
# Check if second part starts with a known command verb
|
||||
cmd_check = parts[1].strip().split(None, 1)
|
||||
if cmd_check and cmd_check[0] in ("add", "set", "remove", "print", "enable", "disable"):
|
||||
current_path = parts[0]
|
||||
line = parts[1].strip()
|
||||
else:
|
||||
# The whole line is a path (e.g. "/ip firewall filter")
|
||||
current_path = line
|
||||
continue
|
||||
|
||||
if current_path and current_path not in categories:
|
||||
categories[current_path] = {
|
||||
"path": current_path,
|
||||
"adds": 0,
|
||||
"sets": 0,
|
||||
"removes": 0,
|
||||
"commands": [],
|
||||
}
|
||||
|
||||
if len(parts) == 1:
|
||||
continue
|
||||
|
||||
if current_path is None:
|
||||
continue
|
||||
|
||||
if current_path not in categories:
|
||||
categories[current_path] = {
|
||||
"path": current_path,
|
||||
"adds": 0,
|
||||
"sets": 0,
|
||||
"removes": 0,
|
||||
"commands": [],
|
||||
}
|
||||
|
||||
cat = categories[current_path]
|
||||
cat["commands"].append(line)
|
||||
|
||||
if line.startswith("add ") or line.startswith("add\t"):
|
||||
cat["adds"] += 1
|
||||
elif line.startswith("set "):
|
||||
cat["sets"] += 1
|
||||
elif line.startswith("remove "):
|
||||
cat["removes"] += 1
|
||||
|
||||
return {"categories": list(categories.values())}
|
||||
|
||||
|
||||
def validate_rsc(text: str) -> dict[str, Any]:
    """Validate RSC export syntax.

    Checks for:
    - Unbalanced quotes (indicates truncation or corruption)
    - Trailing continuation lines (indicates truncated export)

    Returns dict with "valid" (bool) and "errors" (list of strings).
    """
    problems: list[str] = []

    # Track quote balance across the whole file: a line with an odd number
    # of unescaped quotes flips the open-quote state.
    open_quote = False
    for raw in text.split("\n"):
        candidate = raw.rstrip()
        if candidate.endswith("\\"):
            candidate = candidate[:-1]
        unescaped = candidate.count('"') - candidate.count('\\"')
        if unescaped % 2:
            open_quote = not open_quote

    if open_quote:
        problems.append("Unbalanced quote detected — file may be truncated")

    # A final continuation backslash means the export stopped mid-line.
    tail = text.rstrip().split("\n")
    if tail and tail[-1].rstrip().endswith("\\"):
        problems.append("File ends with continuation line (\\) — truncated export")

    return {"valid": not problems, "errors": problems}
|
||||
|
||||
|
||||
def compute_impact(
    current_parsed: dict[str, Any],
    target_parsed: dict[str, Any],
) -> dict[str, Any]:
    """Compare current vs target parsed RSC and compute impact analysis.

    Returns dict with:
    - categories: per-path add/remove counts with a risk level
      ("high" for HIGH_RISK_PATHS, "low" otherwise, "none" if unchanged)
    - warnings: human-readable warning strings
    - diff: summary counts (added, removed, modified — "modified" is
      currently always 0; command-level modification is not tracked)
    """
    curr_by_path = {cat["path"]: cat for cat in current_parsed["categories"]}
    tgt_by_path = {cat["path"]: cat for cat in target_parsed["categories"]}

    # Read-only placeholder for paths present on only one side.
    absent = {"adds": 0, "sets": 0, "removes": 0, "commands": []}
    categories: list[dict[str, Any]] = []
    warnings: list[str] = []
    grand_added = 0
    grand_removed = 0

    for path in sorted(curr_by_path.keys() | tgt_by_path.keys()):
        curr_cmds = set(curr_by_path.get(path, absent).get("commands", []))
        tgt_cmds = set(tgt_by_path.get(path, absent).get("commands", []))
        n_added = len(tgt_cmds - curr_cmds)
        n_removed = len(curr_cmds - tgt_cmds)
        grand_added += n_added
        grand_removed += n_removed

        if n_added or n_removed:
            risk = "high" if path in HIGH_RISK_PATHS else "low"
        else:
            risk = "none"
        categories.append({
            "path": path,
            "adds": n_added,
            "removes": n_removed,
            "risk": risk,
        })

    # Scan the flattened target commands for patterns that could cut off
    # management access.
    flat_target = "\n".join(
        cmd for cat in target_parsed["categories"] for cmd in cat.get("commands", [])
    )
    warnings.extend(
        message for pattern, message in MANAGEMENT_PATTERNS
        if pattern.search(flat_target)
    )

    # Warn about removed IP addresses
    if "/ip address" in curr_by_path and "/ip address" in tgt_by_path:
        dropped = set(curr_by_path["/ip address"].get("commands", [])) - set(
            tgt_by_path["/ip address"].get("commands", [])
        )
        if dropped:
            warnings.append(
                f"Removes {len(dropped)} IP address(es) — verify none are management interfaces"
            )

    return {
        "categories": categories,
        "warnings": warnings,
        "diff": {
            "added": grand_added,
            "removed": grand_removed,
            "modified": 0,
        },
    }
|
||||
124
backend/app/services/scanner.py
Normal file
124
backend/app/services/scanner.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Subnet scanner for MikroTik device discovery.
|
||||
|
||||
Scans a CIDR range by attempting TCP connections to RouterOS API ports
|
||||
(8728 and 8729) with configurable concurrency limits and timeouts.
|
||||
|
||||
Security constraints:
|
||||
- CIDR range limited to /20 or smaller (4096 IPs maximum)
|
||||
- Maximum 50 concurrent connections to prevent network flooding
|
||||
- 2-second timeout per connection attempt
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import socket
|
||||
from typing import Optional
|
||||
|
||||
from app.schemas.device import SubnetScanResult
|
||||
|
||||
# Maximum concurrency for TCP probes
|
||||
_MAX_CONCURRENT = 50
|
||||
# Timeout (seconds) per TCP connection attempt
|
||||
_TCP_TIMEOUT = 2.0
|
||||
# RouterOS API port
|
||||
_API_PORT = 8728
|
||||
# RouterOS SSL API port
|
||||
_SSL_PORT = 8729
|
||||
|
||||
|
||||
async def _probe_host(
    semaphore: asyncio.Semaphore,
    ip_str: str,
) -> Optional[SubnetScanResult]:
    """
    Probe a single IP for RouterOS API ports.

    Returns a SubnetScanResult if either port is open, None otherwise.
    """
    async with semaphore:
        # Check both the plain and the SSL API port concurrently.
        plain_open, tls_open = await asyncio.gather(
            _tcp_connect(ip_str, _API_PORT),
            _tcp_connect(ip_str, _SSL_PORT),
            return_exceptions=False,
        )

        if not (plain_open or tls_open):
            return None

        # Attempt reverse DNS (best-effort; won't fail the scan)
        name = await _reverse_dns(ip_str)

        return SubnetScanResult(
            ip_address=ip_str,
            hostname=name,
            api_port_open=plain_open,
            api_ssl_port_open=tls_open,
        )
|
||||
|
||||
|
||||
async def _tcp_connect(ip: str, port: int) -> bool:
    """Return True if a TCP connection to ip:port succeeds within _TCP_TIMEOUT."""
    try:
        _reader, writer = await asyncio.wait_for(
            asyncio.open_connection(ip, port),
            timeout=_TCP_TIMEOUT,
        )
    except Exception:
        # Refused, timed out, unreachable — all count as "closed".
        return False

    # Connection succeeded; tear it down politely but tolerate errors.
    writer.close()
    try:
        await writer.wait_closed()
    except Exception:
        pass
    return True
|
||||
|
||||
|
||||
async def _reverse_dns(ip: str) -> Optional[str]:
|
||||
"""Attempt a reverse DNS lookup. Returns None on failure."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
hostname, _, _ = await asyncio.wait_for(
|
||||
loop.run_in_executor(None, socket.gethostbyaddr, ip),
|
||||
timeout=1.5,
|
||||
)
|
||||
return hostname
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
async def scan_subnet(cidr: str) -> list[SubnetScanResult]:
    """
    Scan a CIDR range for hosts with open RouterOS API ports.

    Args:
        cidr: CIDR notation string, e.g. "192.168.1.0/24".
            Must be /20 or smaller (validated by SubnetScanRequest).

    Returns:
        List of SubnetScanResult for each host with at least one open API port.

    Raises:
        ValueError: If CIDR is malformed or too large.
    """
    try:
        net = ipaddress.ip_network(cidr, strict=False)
    except ValueError as e:
        raise ValueError(f"Invalid CIDR: {e}") from e

    if net.num_addresses > 4096:
        raise ValueError(
            f"CIDR range too large ({net.num_addresses} addresses). "
            "Maximum allowed is /20 (4096 addresses)."
        )

    # Skip network address and broadcast address for IPv4; tiny networks
    # (/31, /32) are scanned in full since .hosts() would exclude them.
    candidates = list(net.hosts()) if net.num_addresses > 2 else list(net)

    gate = asyncio.Semaphore(_MAX_CONCURRENT)
    probes = [_probe_host(gate, str(addr)) for addr in candidates]
    found = await asyncio.gather(*probes, return_exceptions=False)

    # Drop None entries (hosts with no open ports).
    return [hit for hit in found if hit is not None]
|
||||
113
backend/app/services/srp_service.py
Normal file
113
backend/app/services/srp_service.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""SRP-6a server-side authentication service.
|
||||
|
||||
Wraps the srptools library for the two-step SRP handshake.
|
||||
All functions are async, using asyncio.to_thread() because
|
||||
srptools operations are CPU-bound and synchronous.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
|
||||
from srptools import SRPContext, SRPServerSession
|
||||
from srptools.constants import PRIME_2048, PRIME_2048_GEN
|
||||
|
||||
# Client uses Web Crypto SHA-256 — server must match.
|
||||
# srptools defaults to SHA-1 which would cause proof mismatch.
|
||||
_SRP_HASH = hashlib.sha256
|
||||
|
||||
|
||||
async def create_srp_verifier(
    salt_hex: str, verifier_hex: str
) -> tuple[bytes, bytes]:
    """Convert client-provided hex salt and verifier to bytes for storage.

    The client computes v = g^x mod N using 2SKD-derived SRP-x.
    The server stores the verifier directly and never computes x
    from the password.

    Returns:
        Tuple of (salt_bytes, verifier_bytes) ready for database storage.
    """
    salt = bytes.fromhex(salt_hex)
    verifier = bytes.fromhex(verifier_hex)
    return salt, verifier
|
||||
|
||||
|
||||
async def srp_init(
    email: str, srp_verifier_hex: str
) -> tuple[str, str]:
    """SRP Step 1: Generate server ephemeral (B) and private key (b).

    Args:
        email: User email (SRP identity I).
        srp_verifier_hex: Hex-encoded SRP verifier from database.

    Returns:
        Tuple of (server_public_hex, server_private_hex).
        Caller stores server_private in Redis with 60s TTL.

    Raises:
        ValueError: If SRP initialization fails for any reason.
    """
    def _handshake() -> tuple[str, str]:
        # CPU-bound srptools work, executed off the event loop.
        ctx = SRPContext(
            email, prime=PRIME_2048, generator=PRIME_2048_GEN,
            hash_func=_SRP_HASH,
        )
        session = SRPServerSession(ctx, srp_verifier_hex)
        return session.public, session.private

    try:
        return await asyncio.to_thread(_handshake)
    except Exception as e:
        raise ValueError(f"SRP initialization failed: {e}") from e
|
||||
|
||||
|
||||
async def srp_verify(
    email: str,
    srp_verifier_hex: str,
    server_private: str,
    client_public: str,
    client_proof: str,
    srp_salt_hex: str,
) -> tuple[bool, str | None]:
    """SRP Step 2: Verify client proof M1, return server proof M2.

    Args:
        email: User email (SRP identity I).
        srp_verifier_hex: Hex-encoded SRP verifier from database.
        server_private: Server private ephemeral from Redis session.
        client_public: Hex-encoded client public ephemeral A.
        client_proof: Hex-encoded client proof M1.
        srp_salt_hex: Hex-encoded SRP salt.

    Returns:
        Tuple of (is_valid, server_proof_hex_or_none).
        If valid, server_proof is M2 for the client to verify.

    Raises:
        ValueError: If the SRP computation itself fails.
    """
    # Fix: dropped the leftover "srp_debug" logging setup that was never
    # used (dead code from a debugging session).
    def _verify() -> tuple[bool, str | None]:
        context = SRPContext(
            email, prime=PRIME_2048, generator=PRIME_2048_GEN,
            hash_func=_SRP_HASH,
        )
        server_session = SRPServerSession(
            context, srp_verifier_hex, private=server_private
        )
        _key, key_proof, key_proof_hash = server_session.process(client_public, srp_salt_hex)
        # srptools verify_proof has a Python 3 bug: hexlify() returns bytes
        # but client_proof is str, so bytes == str is always False.
        # Compare manually with consistent types.
        server_m1 = key_proof if isinstance(key_proof, str) else key_proof.decode('ascii')
        is_valid = client_proof.lower() == server_m1.lower()
        if not is_valid:
            return False, None
        # Return M2 (key_proof_hash), also fixing the bytes/str issue
        m2 = key_proof_hash if isinstance(key_proof_hash, str) else key_proof_hash.decode('ascii')
        return True, m2

    try:
        # srptools is CPU-bound and synchronous; keep it off the event loop.
        return await asyncio.to_thread(_verify)
    except Exception as e:
        raise ValueError(f"SRP verification failed: {e}") from e
|
||||
311
backend/app/services/sse_manager.py
Normal file
311
backend/app/services/sse_manager.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""SSE Connection Manager -- bridges NATS JetStream to per-client asyncio queues.
|
||||
|
||||
Each SSE client gets its own NATS connection with ephemeral consumers.
|
||||
Events are tenant-filtered and placed onto an asyncio.Queue that the
|
||||
SSE router drains via EventSourceResponse.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import nats
|
||||
import structlog
|
||||
from nats.js.api import ConsumerConfig, DeliverPolicy, StreamConfig
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Subjects per stream for SSE subscriptions.
# Note: config.push.* subjects live in DEVICE_EVENTS (created by Go poller),
# which is why they are grouped with the device subjects below.
_DEVICE_EVENT_SUBJECTS = [
    "device.status.>",
    "device.metrics.>",
    "config.push.rollback.>",
    "config.push.alert.>",
]
# Alert lifecycle events (ALERT_EVENTS stream, created by this API).
_ALERT_EVENT_SUBJECTS = ["alert.fired.>", "alert.resolved.>"]
# Long-running operation progress (OPERATION_EVENTS stream, created by this API).
_OPERATION_EVENT_SUBJECTS = ["firmware.progress.>"]
|
||||
|
||||
|
||||
def _map_subject_to_event_type(subject: str) -> str:
|
||||
"""Map a NATS subject prefix to an SSE event type string."""
|
||||
if subject.startswith("device.status."):
|
||||
return "device_status"
|
||||
if subject.startswith("device.metrics."):
|
||||
return "metric_update"
|
||||
if subject.startswith("alert.fired."):
|
||||
return "alert_fired"
|
||||
if subject.startswith("alert.resolved."):
|
||||
return "alert_resolved"
|
||||
if subject.startswith("config.push."):
|
||||
return "config_push"
|
||||
if subject.startswith("firmware.progress."):
|
||||
return "firmware_progress"
|
||||
return "unknown"
|
||||
|
||||
|
||||
async def ensure_sse_streams() -> None:
    """Create ALERT_EVENTS and OPERATION_EVENTS NATS streams if they don't exist.

    Called once during app startup so the streams are ready before any
    SSE connection or event publisher needs them. Idempotent -- uses
    add_stream which acts as create-or-update.
    """
    nc = None
    try:
        nc = await nats.connect(settings.NATS_URL)
        js = nc.jetstream()

        # (stream name, subjects) for each stream this API owns.
        wanted = (
            ("ALERT_EVENTS", ["alert.fired.>", "alert.resolved.>"]),
            ("OPERATION_EVENTS", ["firmware.progress.>"]),
        )
        for stream_name, subjects in wanted:
            await js.add_stream(
                StreamConfig(
                    name=stream_name,
                    subjects=subjects,
                    max_age=3600,  # 1 hour retention
                )
            )
            logger.info("nats.stream.ensured", stream=stream_name)

    except Exception as exc:
        logger.warning("sse.streams.ensure_failed", error=str(exc))
        raise
    finally:
        # Best-effort close of the short-lived bootstrap connection.
        if nc:
            try:
                await nc.close()
            except Exception:
                pass
|
||||
|
||||
|
||||
class SSEConnectionManager:
|
||||
"""Manages a single SSE client's lifecycle: NATS connection, subscriptions, and event queue."""
|
||||
|
||||
    def __init__(self) -> None:
        """Initialize an unconnected manager; connect() populates all state."""
        # Dedicated NATS connection for this SSE client (set in connect()).
        self._nc: Optional[nats.aio.client.Client] = None
        # JetStream subscriptions feeding this client's queue.
        self._subscriptions: list = []
        # Queue drained by the SSE response generator (set in connect()).
        self._queue: Optional[asyncio.Queue] = None
        # Tenant UUID string used to filter events; None means super_admin.
        self._tenant_id: Optional[str] = None
        # Unique identifier for this SSE connection, used in log context.
        self._connection_id: Optional[str] = None
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
connection_id: str,
|
||||
tenant_id: Optional[str],
|
||||
last_event_id: Optional[str] = None,
|
||||
) -> asyncio.Queue:
|
||||
"""Set up NATS subscriptions and return an asyncio.Queue for SSE events.
|
||||
|
||||
Args:
|
||||
connection_id: Unique identifier for this SSE connection.
|
||||
tenant_id: Tenant UUID string to filter events. None for super_admin
|
||||
(receives events from all tenants).
|
||||
last_event_id: NATS stream sequence number from the Last-Event-ID header.
|
||||
If provided, replay starts from sequence + 1.
|
||||
|
||||
Returns:
|
||||
asyncio.Queue that the SSE generator should drain.
|
||||
"""
|
||||
self._connection_id = connection_id
|
||||
self._tenant_id = tenant_id
|
||||
self._queue = asyncio.Queue(maxsize=256)
|
||||
|
||||
self._nc = await nats.connect(
|
||||
settings.NATS_URL,
|
||||
max_reconnect_attempts=5,
|
||||
reconnect_time_wait=2,
|
||||
)
|
||||
js = self._nc.jetstream()
|
||||
|
||||
logger.info(
|
||||
"sse.connecting",
|
||||
connection_id=connection_id,
|
||||
tenant_id=tenant_id,
|
||||
last_event_id=last_event_id,
|
||||
)
|
||||
|
||||
# Build consumer config for replay support
|
||||
if last_event_id is not None:
|
||||
try:
|
||||
start_seq = int(last_event_id) + 1
|
||||
consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.BY_START_SEQUENCE, opt_start_seq=start_seq)
|
||||
except (ValueError, TypeError):
|
||||
consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.NEW)
|
||||
else:
|
||||
consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.NEW)
|
||||
|
||||
# Subscribe to device events (DEVICE_EVENTS stream -- created by Go poller)
|
||||
for subject in _DEVICE_EVENT_SUBJECTS:
|
||||
try:
|
||||
sub = await js.subscribe(
|
||||
subject,
|
||||
stream="DEVICE_EVENTS",
|
||||
config=consumer_cfg,
|
||||
)
|
||||
self._subscriptions.append(sub)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"sse.subscribe_failed",
|
||||
subject=subject,
|
||||
stream="DEVICE_EVENTS",
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
# Subscribe to alert events (ALERT_EVENTS stream)
|
||||
# Lazily create the stream if it doesn't exist yet (startup race)
|
||||
for subject in _ALERT_EVENT_SUBJECTS:
|
||||
try:
|
||||
sub = await js.subscribe(
|
||||
subject,
|
||||
stream="ALERT_EVENTS",
|
||||
config=consumer_cfg,
|
||||
)
|
||||
self._subscriptions.append(sub)
|
||||
except Exception as exc:
|
||||
if "stream not found" in str(exc):
|
||||
try:
|
||||
await js.add_stream(StreamConfig(
|
||||
name="ALERT_EVENTS",
|
||||
subjects=_ALERT_EVENT_SUBJECTS,
|
||||
max_age=3600,
|
||||
))
|
||||
sub = await js.subscribe(subject, stream="ALERT_EVENTS", config=consumer_cfg)
|
||||
self._subscriptions.append(sub)
|
||||
logger.info("sse.stream_created_lazily", stream="ALERT_EVENTS")
|
||||
except Exception as retry_exc:
|
||||
logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(retry_exc))
|
||||
else:
|
||||
logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(exc))
|
||||
|
||||
# Subscribe to operation events (OPERATION_EVENTS stream)
|
||||
for subject in _OPERATION_EVENT_SUBJECTS:
|
||||
try:
|
||||
sub = await js.subscribe(
|
||||
subject,
|
||||
stream="OPERATION_EVENTS",
|
||||
config=consumer_cfg,
|
||||
)
|
||||
self._subscriptions.append(sub)
|
||||
except Exception as exc:
|
||||
if "stream not found" in str(exc):
|
||||
try:
|
||||
await js.add_stream(StreamConfig(
|
||||
name="OPERATION_EVENTS",
|
||||
subjects=_OPERATION_EVENT_SUBJECTS,
|
||||
max_age=3600,
|
||||
))
|
||||
sub = await js.subscribe(subject, stream="OPERATION_EVENTS", config=consumer_cfg)
|
||||
self._subscriptions.append(sub)
|
||||
logger.info("sse.stream_created_lazily", stream="OPERATION_EVENTS")
|
||||
except Exception as retry_exc:
|
||||
logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(retry_exc))
|
||||
else:
|
||||
logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(exc))
|
||||
|
||||
# Start background task to pull messages from subscriptions into the queue
|
||||
asyncio.create_task(self._pump_messages())
|
||||
|
||||
logger.info(
|
||||
"sse.connected",
|
||||
connection_id=connection_id,
|
||||
subscription_count=len(self._subscriptions),
|
||||
)
|
||||
|
||||
return self._queue
|
||||
|
||||
async def _pump_messages(self) -> None:
    """Continuously drain every NATS subscription into this connection's queue.

    Each subscription is polled with a short ``next_msg`` timeout so that no
    single subject can starve the others. The loop terminates once the NATS
    connection is closed or drained.
    """
    while self._nc and self._nc.is_connected:
        for subscription in self._subscriptions:
            try:
                message = await subscription.next_msg(timeout=0.5)
                await self._handle_message(message)
            except nats.errors.TimeoutError:
                # Nothing waiting on this subscription right now; try the next one.
                continue
            except Exception as exc:
                # Only worth reporting while the connection is still alive —
                # errors during shutdown/drain are expected noise.
                if self._nc and self._nc.is_connected:
                    logger.warning(
                        "sse.pump_error",
                        connection_id=self._connection_id,
                        error=str(exc),
                    )
                break
        # Brief yield so a fully idle set of subscriptions doesn't spin the CPU.
        await asyncio.sleep(0.1)
|
||||
|
||||
async def _handle_message(self, msg) -> None:
    """Decode a NATS message, drop cross-tenant traffic, and enqueue an SSE event.

    Undecodable payloads and messages for other tenants are acked and
    discarded. When the per-connection queue is full, the event is dropped
    with a warning rather than blocking the pump loop.
    """
    try:
        payload = json.loads(msg.data)
    except (json.JSONDecodeError, UnicodeDecodeError):
        # Malformed message: ack so it is not redelivered, then drop it.
        await msg.ack()
        return

    # Tenant isolation: forward only events belonging to this connection's tenant.
    if self._tenant_id is not None and str(payload.get("tenant_id", "")) != self._tenant_id:
        await msg.ack()
        return

    event_name = _map_subject_to_event_type(msg.subject)

    # The JetStream stream sequence doubles as the SSE event id so clients
    # can resume via Last-Event-ID.
    if msg.metadata and msg.metadata.sequence:
        event_id = str(msg.metadata.sequence.stream)
    else:
        event_id = "0"

    try:
        self._queue.put_nowait({
            "event": event_name,
            "data": json.dumps(payload),
            "id": event_id,
        })
    except asyncio.QueueFull:
        logger.warning(
            "sse.queue_full",
            connection_id=self._connection_id,
            dropped_event=event_name,
        )

    await msg.ack()
|
||||
|
||||
async def disconnect(self) -> None:
    """Tear down all NATS subscriptions and release the underlying connection.

    All cleanup is best-effort: unsubscribe failures are ignored, and if a
    graceful drain fails the connection is force-closed instead.
    """
    logger.info("sse.disconnecting", connection_id=self._connection_id)

    for subscription in self._subscriptions:
        try:
            await subscription.unsubscribe()
        except Exception:
            # Subscription may already be gone; nothing useful to do.
            pass
    self._subscriptions.clear()

    if self._nc:
        try:
            await self._nc.drain()
        except Exception:
            # Drain failed — fall back to a hard close.
            try:
                await self._nc.close()
            except Exception:
                pass
        self._nc = None

    logger.info("sse.disconnected", connection_id=self._connection_id)
|
||||
480
backend/app/services/template_service.py
Normal file
480
backend/app/services/template_service.py
Normal file
@@ -0,0 +1,480 @@
|
||||
"""Config template service: Jinja2 rendering, variable extraction, and multi-device push.
|
||||
|
||||
Provides:
|
||||
- extract_variables: Parse template content to find all undeclared Jinja2 variables
|
||||
- render_template: Render a template with device context and custom variables
|
||||
- validate_variable: Type-check a variable value against its declared type
|
||||
- push_to_devices: Sequential multi-device push with pause-on-failure
|
||||
- push_single_device: Two-phase panic-revert push for a single device
|
||||
|
||||
The push logic follows the same two-phase pattern as restore_service but uses
|
||||
separate scheduler and file names to avoid conflicts with restore operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import ipaddress
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncssh
|
||||
from jinja2 import meta
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
from sqlalchemy import select, text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.models.config_template import TemplatePushJob
|
||||
from app.models.device import Device
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sandboxed Jinja2 environment prevents template injection: sandboxing blocks
# access to unsafe attributes/methods from within user-authored templates.
_env = SandboxedEnvironment()

# Names used on the RouterOS device during template push.
_PANIC_REVERT_SCHEDULER = "mikrotik-portal-template-revert"  # scheduler that auto-restores the pre-push backup
_PRE_PUSH_BACKUP = "portal-template-pre-push"  # binary backup saved on the device before /import
_TEMPLATE_RSC = "portal-template.rsc"  # filename of the rendered template uploaded via SFTP
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Variable extraction & rendering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extract_variables(template_content: str) -> list[str]:
    """Return the sorted user-facing variable names referenced by a template.

    Parses the template and collects every undeclared Jinja2 variable,
    dropping the reserved 'device' name, which is injected automatically
    at render time rather than supplied by the user.
    """
    parsed = _env.parse(template_content)
    undeclared = meta.find_undeclared_variables(parsed)
    undeclared.discard("device")
    return sorted(undeclared)
|
||||
|
||||
|
||||
def render_template(
    template_content: str,
    device: dict,
    custom_variables: dict[str, str],
) -> str:
    """Render a Jinja2 config template against a device plus user variables.

    A 'device' mapping (hostname/ip/model) is injected automatically;
    user-supplied variables are layered on top and therefore take precedence
    over the built-in 'device' entry if a collision occurs.

    Rendering happens inside a SandboxedEnvironment to prevent template
    injection.

    Args:
        template_content: Jinja2 template string.
        device: Device info dict with keys: hostname, ip_address, model.
        custom_variables: User-supplied variable values.

    Returns:
        The rendered configuration text.

    Raises:
        jinja2.TemplateSyntaxError: If the template has syntax errors.
        jinja2.UndefinedError: If required variables are missing.
    """
    render_vars: dict = {
        "device": {
            "hostname": device.get("hostname", ""),
            "ip": device.get("ip_address", ""),
            "model": device.get("model", ""),
        },
    }
    render_vars.update(custom_variables)
    compiled = _env.from_string(template_content)
    return compiled.render(render_vars)
|
||||
|
||||
|
||||
def validate_variable(name: str, value: str, var_type: str) -> str | None:
|
||||
"""Validate a variable value against its declared type.
|
||||
|
||||
Returns None on success, or an error message string on failure.
|
||||
"""
|
||||
if var_type == "string":
|
||||
return None # any string is valid
|
||||
elif var_type == "ip":
|
||||
try:
|
||||
ipaddress.ip_address(value)
|
||||
return None
|
||||
except ValueError:
|
||||
return f"'{name}' must be a valid IP address"
|
||||
elif var_type == "subnet":
|
||||
try:
|
||||
ipaddress.ip_network(value, strict=False)
|
||||
return None
|
||||
except ValueError:
|
||||
return f"'{name}' must be a valid subnet (e.g., 192.168.1.0/24)"
|
||||
elif var_type == "integer":
|
||||
try:
|
||||
int(value)
|
||||
return None
|
||||
except ValueError:
|
||||
return f"'{name}' must be an integer"
|
||||
elif var_type == "boolean":
|
||||
if value.lower() in ("true", "false", "yes", "no", "1", "0"):
|
||||
return None
|
||||
return f"'{name}' must be a boolean (true/false)"
|
||||
return None # unknown type, allow
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-device push orchestration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def push_to_devices(rollout_id: str) -> dict:
    """Run the sequential template push for every job in *rollout_id*.

    Devices are handled one at a time; a failure or revert leaves the
    remaining jobs pending so the rollout pauses. Intended to be launched
    as a background task (asyncio.create_task) after the API has created
    the push jobs.

    Never raises: unexpected errors are logged and reported as one failure.
    """
    try:
        summary = await _run_push_rollout(rollout_id)
    except Exception as exc:
        logger.error(
            "Uncaught exception in template push rollout %s: %s",
            rollout_id,
            exc,
            exc_info=True,
        )
        summary = {"completed": 0, "failed": 1, "pending": 0}
    return summary
|
||||
|
||||
|
||||
async def _run_push_rollout(rollout_id: str) -> dict:
    """Internal rollout implementation.

    Loads every push job for the rollout (creation order), processes the
    pending ones sequentially, and stops at the first failure/revert so the
    remaining jobs stay pending (paused rollout).

    Returns:
        Summary dict with 'completed', 'failed', and 'pending' counts.
    """
    # Load all jobs for this rollout
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id::text, j.status, d.hostname
                FROM template_push_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.rollout_id = CAST(:rollout_id AS uuid)
                ORDER BY j.created_at ASC
            """),
            {"rollout_id": rollout_id},
        )
        jobs = result.fetchall()

    if not jobs:
        logger.warning("No jobs found for template push rollout %s", rollout_id)
        return {"completed": 0, "failed": 0, "pending": 0}

    initial_pending = sum(1 for _, s, _ in jobs if s == "pending")
    completed = 0        # committed jobs, including ones committed by earlier runs
    newly_processed = 0  # pending jobs consumed by THIS run (success or failure)
    failed = False

    for job_id, current_status, hostname in jobs:
        if current_status != "pending":
            # Jobs committed by a previous run still count toward the total;
            # everything else non-pending is skipped.
            if current_status == "committed":
                completed += 1
            continue

        logger.info(
            "Template push rollout %s: pushing to device %s (job %s)",
            rollout_id, hostname, job_id,
        )

        await push_single_device(job_id)
        newly_processed += 1

        # Re-read the job row to see how the push ended.
        async with AdminAsyncSessionLocal() as session:
            result = await session.execute(
                text("SELECT status FROM template_push_jobs WHERE id = CAST(:id AS uuid)"),
                {"id": job_id},
            )
            row = result.fetchone()

        if row and row[0] == "committed":
            completed += 1
        elif row and row[0] in ("failed", "reverted"):
            failed = True
            logger.error(
                "Template push rollout %s paused: device %s %s",
                rollout_id, hostname, row[0],
            )
            break

    # BUG FIX: the previous computation subtracted the TOTAL committed count
    # (which includes jobs committed before this run) from the initial
    # pending count, under-reporting 'pending' on resumed rollouts. Only the
    # jobs actually processed in this run reduce the pending pool.
    remaining = initial_pending - newly_processed

    return {
        "completed": completed,
        "failed": 1 if failed else 0,
        "pending": max(0, remaining),
    }
|
||||
|
||||
|
||||
async def push_single_device(job_id: str) -> None:
    """Push rendered template content to one device, never raising.

    Delegates to the two-phase panic-revert implementation:
      1. Mandatory pre-push backup
      2. Install panic-revert scheduler on the device
      3. Upload the rendered RSC file via SFTP
      4. /import the RSC file
      5. 60s settle period
      6. Reachability check -> committed or reverted

    Any unexpected error is caught here and recorded on the job row.
    """
    try:
        await _run_single_push(job_id)
    except Exception as exc:
        logger.error(
            "Uncaught exception in template push job %s: %s",
            job_id,
            exc,
            exc_info=True,
        )
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Unexpected error: {exc}",
        )
|
||||
|
||||
|
||||
async def _run_single_push(job_id: str) -> None:
    """Internal single-device push implementation.

    Two-phase panic-revert flow: a RouterOS scheduler is armed to restore a
    pre-push binary backup; if the device is still reachable after the new
    config settles, the scheduler is removed (commit), otherwise the device
    restores itself (revert). All outcomes are recorded on the job row via
    _update_job; this function returns None in every case.
    """

    # Step 1: Load job and device info (admin session — background task, no RLS)
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id, j.device_id, j.tenant_id, j.rendered_content,
                       d.ip_address, d.hostname, d.encrypted_credentials,
                       d.encrypted_credentials_transit
                FROM template_push_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.id = CAST(:job_id AS uuid)
            """),
            {"job_id": job_id},
        )
        row = result.fetchone()

    if not row:
        logger.error("Template push job %s not found", job_id)
        return

    (
        _, device_id, tenant_id, rendered_content,
        ip_address, hostname, encrypted_credentials,
        encrypted_credentials_transit,
    ) = row

    # Normalize UUIDs to strings for downstream service calls / SQL params.
    device_id = str(device_id)
    tenant_id = str(tenant_id)
    hostname = hostname or ip_address  # fall back to IP for log readability

    # Step 2: Update status to pushing
    await _update_job(job_id, status="pushing", started_at=datetime.now(timezone.utc))

    # Step 3: Decrypt credentials (dual-read: Transit preferred, legacy fallback)
    if not encrypted_credentials_transit and not encrypted_credentials:
        await _update_job(job_id, status="failed", error_message="Device has no stored credentials")
        return

    try:
        from app.services.crypto import decrypt_credentials_hybrid
        key = settings.get_encryption_key_bytes()
        creds_json = await decrypt_credentials_hybrid(
            encrypted_credentials_transit, encrypted_credentials, tenant_id, key,
        )
        creds = json.loads(creds_json)
        ssh_username = creds.get("username", "")
        ssh_password = creds.get("password", "")
    except Exception as cred_err:
        await _update_job(
            job_id, status="failed",
            error_message=f"Failed to decrypt credentials: {cred_err}",
        )
        return

    # Step 4: Mandatory pre-push backup (portal-side; separate from the
    # on-device binary backup taken in step 5a)
    logger.info("Running mandatory pre-push backup for device %s (%s)", hostname, ip_address)
    try:
        from app.services import backup_service
        backup_result = await backup_service.run_backup(
            device_id=device_id,
            tenant_id=tenant_id,
            trigger_type="pre-template-push",
        )
        backup_sha = backup_result["commit_sha"]
        await _update_job(job_id, pre_push_backup_sha=backup_sha)
        logger.info("Pre-push backup complete: %s", backup_sha[:8])
    except Exception as backup_err:
        logger.error("Pre-push backup failed for %s: %s", hostname, backup_err)
        await _update_job(
            job_id, status="failed",
            error_message=f"Pre-push backup failed: {backup_err}",
        )
        return

    # Step 5: SSH to device - install panic-revert, push config
    logger.info(
        "Pushing template to device %s (%s): installing panic-revert and uploading config",
        hostname, ip_address,
    )

    try:
        async with asyncssh.connect(
            ip_address,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,  # NOTE(review): host-key verification disabled — confirm this is intended
            connect_timeout=30,
        ) as conn:
            # 5a: Create binary backup on device as revert point
            await conn.run(
                f"/system backup save name={_PRE_PUSH_BACKUP} dont-encrypt=yes",
                check=True,
            )
            logger.debug("Pre-push binary backup saved on device as %s.backup", _PRE_PUSH_BACKUP)

            # 5b: Install panic-revert RouterOS scheduler — restores the
            # backup every 90s (and at startup) until removed in step 8a
            await conn.run(
                f"/system scheduler add "
                f'name="{_PANIC_REVERT_SCHEDULER}" '
                f"interval=90s "
                f'on-event=":delay 0; /system backup load name={_PRE_PUSH_BACKUP}" '
                f"start-time=startup",
                check=True,
            )
            logger.debug("Panic-revert scheduler installed on device")

            # 5c: Upload rendered template as RSC file via SFTP
            # NOTE(review): path is relative (no leading '/'), unlike the
            # firmware upload which uses '/<name>' — confirm both land where intended
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(_TEMPLATE_RSC, "wb") as f:
                    await f.write(rendered_content.encode("utf-8"))
            logger.debug("Uploaded %s to device flash", _TEMPLATE_RSC)

            # 5d: /import the config file (check=False: exit status is logged,
            # not treated as fatal — the reachability check decides the outcome)
            import_result = await conn.run(
                f"/import file={_TEMPLATE_RSC}",
                check=False,
            )
            logger.info(
                "Template import result for device %s: exit_status=%s stdout=%r",
                hostname, import_result.exit_status,
                (import_result.stdout or "")[:200],
            )

            # 5e: Clean up the uploaded RSC file (best-effort)
            try:
                await conn.run(f"/file remove {_TEMPLATE_RSC}", check=True)
            except Exception as cleanup_err:
                logger.warning(
                    "Failed to clean up %s from device %s: %s",
                    _TEMPLATE_RSC, ip_address, cleanup_err,
                )

    except Exception as push_err:
        logger.error(
            "SSH push phase failed for device %s (%s): %s",
            hostname, ip_address, push_err,
        )
        await _update_job(
            job_id, status="failed",
            error_message=f"Config push failed during SSH phase: {push_err}",
        )
        return

    # Step 6: Wait 60s for config to settle (shorter than the 90s scheduler
    # interval, so a healthy device is confirmed before the first revert tick)
    logger.info("Template pushed to device %s - waiting 60s for config to settle", hostname)
    await asyncio.sleep(60)

    # Step 7: Reachability check
    reachable = await _check_reachability(ip_address, ssh_username, ssh_password)

    if reachable:
        # Step 8a: Device is reachable - remove panic-revert scheduler + cleanup
        logger.info("Device %s (%s) is reachable after push - committing", hostname, ip_address)
        try:
            async with asyncssh.connect(
                ip_address, port=22,
                username=ssh_username, password=ssh_password,
                known_hosts=None, connect_timeout=30,
            ) as conn:
                # check=False: cleanup failures are non-fatal; the push itself succeeded
                await conn.run(
                    f'/system scheduler remove "{_PANIC_REVERT_SCHEDULER}"',
                    check=False,
                )
                await conn.run(
                    f"/file remove {_PRE_PUSH_BACKUP}.backup",
                    check=False,
                )
        except Exception as cleanup_err:
            logger.warning(
                "Failed to clean up panic-revert scheduler/backup on device %s: %s",
                hostname, cleanup_err,
            )

        await _update_job(
            job_id, status="committed",
            completed_at=datetime.now(timezone.utc),
        )
    else:
        # Step 8b: Device unreachable - RouterOS is auto-reverting via the
        # scheduler installed in step 5b; nothing more to do from our side
        logger.warning(
            "Device %s (%s) is unreachable after push - panic-revert scheduler "
            "will auto-revert to %s.backup",
            hostname, ip_address, _PRE_PUSH_BACKUP,
        )
        await _update_job(
            job_id, status="reverted",
            error_message="Device unreachable after push; auto-reverted via panic-revert scheduler",
            completed_at=datetime.now(timezone.utc),
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _check_reachability(ip: str, username: str, password: str) -> bool:
    """Return True when the RouterOS device answers an SSH command.

    Opens a short-lived SSH session and runs '/system identity print';
    any failure along the way (connect, auth, command) is treated as
    "unreachable" and logged at info level.
    """
    try:
        async with asyncssh.connect(
            ip,
            port=22,
            username=username,
            password=password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            result = await conn.run("/system identity print", check=True)
            logger.debug("Reachability check OK for %s: %r", ip, result.stdout[:50])
            return True
    except Exception as exc:
        logger.info("Device %s unreachable after push: %s", ip, exc)
        return False
|
||||
|
||||
|
||||
async def _update_job(job_id: str, **kwargs) -> None:
    """Apply keyword-argument field updates to a template_push_jobs row.

    Builds a dynamic UPDATE statement from the given fields and runs it on
    an admin session (background task, no RLS). Passing None for one of the
    nullable columns writes SQL NULL explicitly.
    """
    nullable_fields = ("error_message", "started_at", "completed_at", "pre_push_backup_sha")

    sets: list[str] = []
    params: dict = {"job_id": job_id}

    for field, new_value in kwargs.items():
        if new_value is None and field in nullable_fields:
            sets.append(f"{field} = NULL")
        else:
            placeholder = f"v_{field}"
            sets.append(f"{field} = :{placeholder}")
            params[placeholder] = new_value

    # Nothing to update — avoid emitting "SET WHERE ..."
    if not sets:
        return

    async with AdminAsyncSessionLocal() as session:
        await session.execute(
            text(f"""
                UPDATE template_push_jobs
                SET {', '.join(sets)}
                WHERE id = CAST(:job_id AS uuid)
            """),
            params,
        )
        await session.commit()
|
||||
564
backend/app/services/upgrade_service.py
Normal file
564
backend/app/services/upgrade_service.py
Normal file
@@ -0,0 +1,564 @@
|
||||
"""Firmware upgrade orchestration service.
|
||||
|
||||
Handles single-device and mass firmware upgrades with:
|
||||
- Mandatory pre-upgrade config backup
|
||||
- NPK download and SFTP upload to device
|
||||
- Reboot trigger and reconnect polling
|
||||
- Post-upgrade version verification
|
||||
- Sequential mass rollout with pause-on-failure
|
||||
- Scheduled upgrades via APScheduler DateTrigger
|
||||
|
||||
All DB operations use AdminAsyncSessionLocal to bypass RLS since upgrade
|
||||
jobs may span multiple tenants and run in background asyncio tasks.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import asyncssh
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.config import settings
|
||||
from app.database import AdminAsyncSessionLocal
|
||||
from app.services.event_publisher import publish_event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum time to wait for a device to reconnect after reboot (seconds)
_RECONNECT_TIMEOUT = 300  # 5 minutes
_RECONNECT_POLL_INTERVAL = 15  # seconds between SSH reachability probes
_INITIAL_WAIT = 60  # Wait before first reconnect attempt (boot cycle)
|
||||
|
||||
|
||||
async def start_upgrade(job_id: str) -> None:
    """Execute one device's firmware upgrade end to end.

    Lifecycle: pending -> downloading -> uploading -> rebooting ->
    verifying -> completed/failed.

    Safe to run as a background asyncio task or an APScheduler job: any
    exception escaping the internal implementation is caught here and
    persisted on the FirmwareUpgradeJob row instead of propagating.
    """
    try:
        await _run_upgrade(job_id)
    except Exception as err:
        logger.error(
            "Uncaught exception in firmware upgrade %s: %s",
            job_id,
            err,
            exc_info=True,
        )
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Unexpected error: {err}",
        )
|
||||
|
||||
|
||||
async def _publish_upgrade_progress(
    tenant_id: str,
    device_id: str,
    job_id: str,
    stage: str,
    target_version: str,
    message: str,
    error: str | None = None,
) -> None:
    """Emit a firmware-upgrade progress event onto NATS (fire-and-forget).

    The payload mirrors the job's current state so SSE clients can render
    live progress; a non-empty *error* string is attached under 'error'.
    """
    event = {
        "event_type": "firmware_progress",
        "tenant_id": tenant_id,
        "device_id": device_id,
        "job_id": job_id,
        "stage": stage,
        "target_version": target_version,
        "message": message,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    if error:
        event["error"] = error
    await publish_event(f"firmware.progress.{tenant_id}.{device_id}", event)
|
||||
|
||||
|
||||
async def _run_upgrade(job_id: str) -> None:
    """Internal single-device upgrade implementation.

    Steps:
      1. Load the job + device row (admin session, bypasses RLS).
      2. Mark the job as downloading.
      3. Refuse unconfirmed major-version jumps.
      4. Take a mandatory pre-upgrade config backup.
      5. Download the target NPK into the local cache.
      6. Decrypt credentials and SFTP the NPK onto the device.
      7. Trigger a reboot (RouterOS installs the NPK during boot).
      8. Poll SSH reachability until _RECONNECT_TIMEOUT.
      9. Verify the device reports the target version.

    Every failure path records an error_message on the job row and publishes
    a 'failed' progress event; callers (start_upgrade) add a final safety net.
    """

    # Step 1: Load job
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id, j.device_id, j.tenant_id, j.target_version,
                       j.architecture, j.channel, j.status, j.confirmed_major_upgrade,
                       d.ip_address, d.hostname, d.encrypted_credentials,
                       d.routeros_version, d.encrypted_credentials_transit
                FROM firmware_upgrade_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.id = CAST(:job_id AS uuid)
            """),
            {"job_id": job_id},
        )
        row = result.fetchone()

    if not row:
        logger.error("Upgrade job %s not found", job_id)
        return

    (
        _, device_id, tenant_id, target_version,
        architecture, channel, status, confirmed_major,
        ip_address, hostname, encrypted_credentials,
        current_version, encrypted_credentials_transit,
    ) = row

    device_id = str(device_id)
    tenant_id = str(tenant_id)
    hostname = hostname or ip_address

    # Skip if already running or completed (idempotent re-entry guard)
    if status not in ("pending", "scheduled"):
        logger.info("Upgrade job %s already in status %s — skipping", job_id, status)
        return

    logger.info(
        "Starting firmware upgrade for %s (%s): %s -> %s",
        hostname, ip_address, current_version, target_version,
    )

    # Step 2: Update status to downloading
    await _update_job(job_id, status="downloading", started_at=datetime.now(timezone.utc))
    await _publish_upgrade_progress(
        tenant_id, device_id, job_id, "downloading", target_version,
        f"Downloading firmware {target_version} for {hostname}",
    )

    # Step 3: Check major version upgrade confirmation
    if current_version and target_version:
        current_major = current_version.split(".")[0]
        target_major = target_version.split(".")[0]
        if current_major != target_major and not confirmed_major:
            await _update_job(
                job_id,
                status="failed",
                error_message="Major version upgrade requires explicit confirmation",
            )
            await _publish_upgrade_progress(
                tenant_id, device_id, job_id, "failed", target_version,
                f"Major version upgrade requires explicit confirmation for {hostname}",
                error="Major version upgrade requires explicit confirmation",
            )
            return

    # Step 4: Mandatory config backup
    logger.info("Running mandatory pre-upgrade backup for %s", hostname)
    try:
        from app.services import backup_service
        backup_result = await backup_service.run_backup(
            device_id=device_id,
            tenant_id=tenant_id,
            trigger_type="pre-upgrade",
        )
        backup_sha = backup_result["commit_sha"]
        await _update_job(job_id, pre_upgrade_backup_sha=backup_sha)
        logger.info("Pre-upgrade backup complete: %s", backup_sha[:8])
    except Exception as backup_err:
        logger.error("Pre-upgrade backup failed for %s: %s", hostname, backup_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Pre-upgrade backup failed: {backup_err}",
        )
        await _publish_upgrade_progress(
            tenant_id, device_id, job_id, "failed", target_version,
            f"Pre-upgrade backup failed for {hostname}",
            error=str(backup_err),
        )
        return

    # Step 5: Download NPK
    logger.info("Downloading firmware %s for %s/%s", target_version, architecture, channel)
    try:
        from app.services.firmware_service import download_firmware
        npk_path = await download_firmware(architecture, channel, target_version)
        logger.info("Firmware cached at %s", npk_path)
    except Exception as dl_err:
        logger.error("Firmware download failed: %s", dl_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Firmware download failed: {dl_err}",
        )
        await _publish_upgrade_progress(
            tenant_id, device_id, job_id, "failed", target_version,
            f"Firmware download failed for {hostname}",
            error=str(dl_err),
        )
        return

    # Step 6: Upload NPK to device via SFTP
    await _update_job(job_id, status="uploading")
    await _publish_upgrade_progress(
        tenant_id, device_id, job_id, "uploading", target_version,
        f"Uploading firmware to {hostname}",
    )

    # Decrypt device credentials (dual-read: Transit preferred, legacy fallback)
    if not encrypted_credentials_transit and not encrypted_credentials:
        await _update_job(job_id, status="failed", error_message="Device has no stored credentials")
        await _publish_upgrade_progress(
            tenant_id, device_id, job_id, "failed", target_version,
            f"No stored credentials for {hostname}",
            error="Device has no stored credentials",
        )
        return

    try:
        from app.services.crypto import decrypt_credentials_hybrid
        key = settings.get_encryption_key_bytes()
        creds_json = await decrypt_credentials_hybrid(
            encrypted_credentials_transit, encrypted_credentials, tenant_id, key,
        )
        creds = json.loads(creds_json)
        ssh_username = creds.get("username", "")
        ssh_password = creds.get("password", "")
    except Exception as cred_err:
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Failed to decrypt credentials: {cred_err}",
        )
        await _publish_upgrade_progress(
            tenant_id, device_id, job_id, "failed", target_version,
            f"Failed to decrypt credentials for {hostname}",
            error=str(cred_err),
        )
        return

    try:
        npk_data = Path(npk_path).read_bytes()
        npk_filename = Path(npk_path).name

        async with asyncssh.connect(
            ip_address,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(f"/{npk_filename}", "wb") as f:
                    await f.write(npk_data)
        logger.info("Uploaded %s to %s", npk_filename, hostname)
    except Exception as upload_err:
        logger.error("NPK upload failed for %s: %s", hostname, upload_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"NPK upload failed: {upload_err}",
        )
        await _publish_upgrade_progress(
            tenant_id, device_id, job_id, "failed", target_version,
            f"NPK upload failed for {hostname}",
            error=str(upload_err),
        )
        return

    # Step 7: Trigger reboot
    await _update_job(job_id, status="rebooting")
    await _publish_upgrade_progress(
        tenant_id, device_id, job_id, "rebooting", target_version,
        f"Rebooting {hostname} for firmware install",
    )
    try:
        async with asyncssh.connect(
            ip_address,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            # RouterOS will install NPK on boot
            await conn.run("/system reboot", check=False)
        logger.info("Reboot command sent to %s", hostname)
    except Exception as reboot_err:
        # Device may drop connection during reboot — this is expected
        logger.info("Device %s dropped connection after reboot command (expected): %s", hostname, reboot_err)

    # Step 8: Wait for reconnect
    logger.info("Waiting %ds before polling %s for reconnect", _INITIAL_WAIT, hostname)
    await asyncio.sleep(_INITIAL_WAIT)

    reconnected = False
    elapsed = 0
    while elapsed < _RECONNECT_TIMEOUT:
        if await _check_ssh_reachable(ip_address, ssh_username, ssh_password):
            reconnected = True
            break
        await asyncio.sleep(_RECONNECT_POLL_INTERVAL)
        elapsed += _RECONNECT_POLL_INTERVAL

    if not reconnected:
        logger.error("Device %s did not reconnect within %ds", hostname, _RECONNECT_TIMEOUT)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Device did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes after reboot",
        )
        await _publish_upgrade_progress(
            tenant_id, device_id, job_id, "failed", target_version,
            f"Device {hostname} did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes",
            error="Reconnect timeout",
        )
        return

    # Step 9: Verify upgrade
    await _update_job(job_id, status="verifying")
    await _publish_upgrade_progress(
        tenant_id, device_id, job_id, "verifying", target_version,
        f"Verifying firmware version on {hostname}",
    )
    try:
        actual_version = await _get_device_version(ip_address, ssh_username, ssh_password)
        # BUG FIX: the old check was a raw substring test
        # (`target_version in actual_version`), so a target of "7.1" would
        # falsely "verify" a device reporting "7.14.2". Require the target
        # to appear as a standalone version token instead.
        import re
        version_ok = bool(
            actual_version
            and re.search(
                rf"(?<![\w.]){re.escape(target_version)}(?![\w.])",
                actual_version,
            )
        )
        if version_ok:
            logger.info(
                "Firmware upgrade verified for %s: %s",
                hostname, actual_version,
            )
            await _update_job(
                job_id,
                status="completed",
                completed_at=datetime.now(timezone.utc),
            )
            await _publish_upgrade_progress(
                tenant_id, device_id, job_id, "completed", target_version,
                f"Firmware upgrade to {target_version} completed on {hostname}",
            )
        else:
            logger.error(
                "Version mismatch for %s: expected %s, got %s",
                hostname, target_version, actual_version,
            )
            await _update_job(
                job_id,
                status="failed",
                error_message=f"Expected {target_version} but got {actual_version}",
            )
            await _publish_upgrade_progress(
                tenant_id, device_id, job_id, "failed", target_version,
                f"Version mismatch on {hostname}: expected {target_version}, got {actual_version}",
                error=f"Expected {target_version} but got {actual_version}",
            )
    except Exception as verify_err:
        logger.error("Post-upgrade verification failed for %s: %s", hostname, verify_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Post-upgrade verification failed: {verify_err}",
        )
        await _publish_upgrade_progress(
            tenant_id, device_id, job_id, "failed", target_version,
            f"Post-upgrade verification failed for {hostname}",
            error=str(verify_err),
        )
|
||||
|
||||
|
||||
async def start_mass_upgrade(rollout_group_id: str) -> dict:
    """Execute a sequential mass firmware upgrade.

    Processes upgrade jobs one at a time. If any device fails,
    all remaining jobs in the group are paused.

    Args:
        rollout_group_id: UUID string of the rollout group to process.

    Returns:
        Summary dict with ``completed``/``failed``/``paused`` counts and,
        when a device failed, its hostname under ``failed_device``.
    """
    # Load every job in the group (oldest first) together with its device hostname.
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id, j.status, d.hostname
                FROM firmware_upgrade_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.rollout_group_id = CAST(:group_id AS uuid)
                ORDER BY j.created_at ASC
            """),
            {"group_id": rollout_group_id},
        )
        jobs = result.fetchall()

    if not jobs:
        logger.warning("No jobs found for rollout group %s", rollout_group_id)
        return {"completed": 0, "failed": 0, "paused": 0}

    completed = 0
    failed_device = None  # hostname of the first device that failed, if any

    for job_id, current_status, hostname in jobs:
        job_id_str = str(job_id)

        # Only process pending/scheduled jobs
        if current_status not in ("pending", "scheduled"):
            # Jobs already completed in a previous (resumed) run still count.
            if current_status == "completed":
                completed += 1
            continue

        logger.info("Mass rollout: upgrading device %s (job %s)", hostname, job_id_str)
        # Runs the full single-device upgrade sequence and waits for it to finish.
        await start_upgrade(job_id_str)

        # Check if it completed or failed — re-read status from the DB since
        # start_upgrade records its outcome there rather than returning it.
        async with AdminAsyncSessionLocal() as session:
            result = await session.execute(
                text("SELECT status FROM firmware_upgrade_jobs WHERE id = CAST(:id AS uuid)"),
                {"id": job_id_str},
            )
            row = result.fetchone()

        if row and row[0] == "completed":
            completed += 1
        elif row and row[0] == "failed":
            # First failure stops the rollout; remaining jobs are paused below.
            failed_device = hostname
            logger.error("Mass rollout paused: %s failed", hostname)
            break

    # Pause remaining jobs if one failed
    paused = 0
    if failed_device:
        async with AdminAsyncSessionLocal() as session:
            result = await session.execute(
                text("""
                    UPDATE firmware_upgrade_jobs
                    SET status = 'paused'
                    WHERE rollout_group_id = CAST(:group_id AS uuid)
                    AND status IN ('pending', 'scheduled')
                    RETURNING id
                """),
                {"group_id": rollout_group_id},
            )
            paused = len(result.fetchall())
            await session.commit()

    return {
        "completed": completed,
        "failed": 1 if failed_device else 0,
        "failed_device": failed_device,
        "paused": paused,
    }
|
||||
|
||||
|
||||
def schedule_upgrade(job_id: str, scheduled_at: datetime) -> None:
    """Register a one-shot APScheduler job that runs the upgrade at *scheduled_at*."""
    # Local import — presumably avoids a circular import at module load time.
    from app.services.backup_scheduler import backup_scheduler

    job_options = {
        "trigger": "date",
        "run_date": scheduled_at,
        "args": [job_id],
        "id": f"fw_upgrade_{job_id}",
        "name": f"Firmware upgrade {job_id}",
        "max_instances": 1,
        "replace_existing": True,
    }
    backup_scheduler.add_job(start_upgrade, **job_options)
    logger.info("Scheduled firmware upgrade %s for %s", job_id, scheduled_at)
|
||||
|
||||
|
||||
def schedule_mass_upgrade(rollout_group_id: str, scheduled_at: datetime) -> None:
    """Register a one-shot APScheduler job that starts the mass rollout at *scheduled_at*."""
    # Local import — presumably avoids a circular import at module load time.
    from app.services.backup_scheduler import backup_scheduler

    job_options = {
        "trigger": "date",
        "run_date": scheduled_at,
        "args": [rollout_group_id],
        "id": f"fw_mass_upgrade_{rollout_group_id}",
        "name": f"Mass firmware upgrade {rollout_group_id}",
        "max_instances": 1,
        "replace_existing": True,
    }
    backup_scheduler.add_job(start_mass_upgrade, **job_options)
    logger.info("Scheduled mass firmware upgrade %s for %s", rollout_group_id, scheduled_at)
|
||||
|
||||
|
||||
async def cancel_upgrade(job_id: str) -> None:
    """Cancel a scheduled or pending upgrade job.

    The cancellation is recorded as a failed job with an explanatory message.
    """
    from app.services.backup_scheduler import backup_scheduler

    # Best-effort removal of the APScheduler entry; the job may never have
    # been scheduled (e.g. it was created for immediate execution).
    scheduler_id = f"fw_upgrade_{job_id}"
    try:
        backup_scheduler.remove_job(scheduler_id)
    except Exception:
        pass  # Job might not be scheduled

    await _update_job(
        job_id,
        status="failed",
        error_message="Cancelled by operator",
        completed_at=datetime.now(timezone.utc),
    )
    logger.info("Upgrade job %s cancelled", job_id)
|
||||
|
||||
|
||||
async def retry_failed_upgrade(job_id: str) -> None:
    """Reset a failed upgrade job to pending and re-execute it in the background."""
    reset_fields = {
        "status": "pending",
        "error_message": None,
        "started_at": None,
        "completed_at": None,
    }
    await _update_job(job_id, **reset_fields)
    # NOTE(review): fire-and-forget — no reference to the task is kept, so it
    # may be garbage-collected before completion; confirm a reference is held
    # elsewhere or that best-effort execution is acceptable here.
    asyncio.create_task(start_upgrade(job_id))
    logger.info("Retrying upgrade job %s", job_id)
|
||||
|
||||
|
||||
async def resume_mass_upgrade(rollout_group_id: str) -> None:
    """Resume a paused mass rollout from where it left off."""
    # Flip every paused job in the group back to pending, then restart the
    # sequential processor in the background.
    async with AdminAsyncSessionLocal() as session:
        await session.execute(
            text("""
                UPDATE firmware_upgrade_jobs
                SET status = 'pending'
                WHERE rollout_group_id = CAST(:group_id AS uuid)
                AND status = 'paused'
            """),
            {"group_id": rollout_group_id},
        )
        await session.commit()

    # NOTE(review): fire-and-forget task, no reference kept — same caveat as
    # retry_failed_upgrade.
    asyncio.create_task(start_mass_upgrade(rollout_group_id))
    logger.info("Resuming mass rollout %s", rollout_group_id)
|
||||
|
||||
|
||||
async def abort_mass_upgrade(rollout_group_id: str) -> int:
    """Abort all remaining jobs in a paused mass rollout.

    Marks every pending/scheduled/paused job in the group as failed with an
    "Aborted by operator" message and returns how many were affected.
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                UPDATE firmware_upgrade_jobs
                SET status = 'failed',
                    error_message = 'Aborted by operator',
                    completed_at = NOW()
                WHERE rollout_group_id = CAST(:group_id AS uuid)
                AND status IN ('pending', 'scheduled', 'paused')
                RETURNING id
            """),
            {"group_id": rollout_group_id},
        )
        aborted_rows = result.fetchall()
        await session.commit()

    aborted_count = len(aborted_rows)
    logger.info("Aborted %d remaining jobs in rollout %s", aborted_count, rollout_group_id)
    return aborted_count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _update_job(job_id: str, **kwargs) -> None:
|
||||
"""Update FirmwareUpgradeJob fields."""
|
||||
sets = []
|
||||
params: dict = {"job_id": job_id}
|
||||
|
||||
for key, value in kwargs.items():
|
||||
param_name = f"v_{key}"
|
||||
if value is None and key in ("error_message", "started_at", "completed_at"):
|
||||
sets.append(f"{key} = NULL")
|
||||
else:
|
||||
sets.append(f"{key} = :{param_name}")
|
||||
params[param_name] = value
|
||||
|
||||
if not sets:
|
||||
return
|
||||
|
||||
async with AdminAsyncSessionLocal() as session:
|
||||
await session.execute(
|
||||
text(f"""
|
||||
UPDATE firmware_upgrade_jobs
|
||||
SET {', '.join(sets)}
|
||||
WHERE id = CAST(:job_id AS uuid)
|
||||
"""),
|
||||
params,
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def _check_ssh_reachable(ip: str, username: str, password: str) -> bool:
    """Return True if the device answers an SSH command within the timeout."""
    connect_options = {
        "port": 22,
        "username": username,
        "password": password,
        "known_hosts": None,  # device host keys are not tracked
        "connect_timeout": 15,
    }
    try:
        async with asyncssh.connect(ip, **connect_options) as conn:
            # Any successful command proves reachability; identity print is cheap.
            await conn.run("/system identity print", check=True)
        return True
    except Exception:
        return False
|
||||
|
||||
|
||||
async def _get_device_version(ip: str, username: str, password: str) -> str:
    """Fetch the current RouterOS version string from a device via SSH.

    Returns an empty string when no version line is found in the output.
    """
    async with asyncssh.connect(
        ip,
        port=22,
        username=username,
        password=password,
        known_hosts=None,
        connect_timeout=30,
    ) as conn:
        result = await conn.run("/system resource print", check=True)

    # Output contains a line like "version: 7.17 (stable)".
    for line in result.stdout.splitlines():
        if "version" not in line.lower():
            continue
        _, sep, value = line.partition(":")
        if sep:
            return value.strip()
    return ""
|
||||
392
backend/app/services/vpn_service.py
Normal file
392
backend/app/services/vpn_service.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""WireGuard VPN management service.
|
||||
|
||||
Handles key generation, peer management, config file sync, and RouterOS command generation.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import ipaddress
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import structlog
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
|
||||
from cryptography.hazmat.primitives.serialization import (
|
||||
Encoding,
|
||||
NoEncryption,
|
||||
PrivateFormat,
|
||||
PublicFormat,
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.models.device import Device
|
||||
from app.models.vpn import VpnConfig, VpnPeer
|
||||
from app.services.crypto import decrypt_credentials, encrypt_credentials, encrypt_credentials_transit
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
# ── Key Generation ──
|
||||
|
||||
|
||||
def generate_wireguard_keypair() -> tuple[str, str]:
    """Generate a WireGuard X25519 keypair.

    Returns:
        Tuple of (private_key_b64, public_key_b64), both standard base64.
    """
    key = X25519PrivateKey.generate()
    raw_private = key.private_bytes(Encoding.Raw, PrivateFormat.Raw, NoEncryption())
    raw_public = key.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw)
    encode = base64.b64encode
    return encode(raw_private).decode(), encode(raw_public).decode()
|
||||
|
||||
|
||||
def generate_preshared_key() -> str:
    """Return a fresh WireGuard preshared key: 32 random bytes, base64-encoded."""
    random_bytes = os.urandom(32)
    return base64.b64encode(random_bytes).decode()
|
||||
|
||||
|
||||
# ── Config File Management ──
|
||||
|
||||
|
||||
def _get_wg_config_path() -> Path:
|
||||
"""Return the path to the shared WireGuard config directory."""
|
||||
return Path(os.getenv("WIREGUARD_CONFIG_PATH", "/data/wireguard"))
|
||||
|
||||
|
||||
async def sync_wireguard_config(db: AsyncSession, tenant_id: uuid.UUID) -> None:
    """Regenerate wg0.conf from database state and write to shared volume.

    Silently returns when the tenant has no VPN config or VPN is disabled.
    Only enabled peers are written. After writing the config, a ``.reload``
    flag file is touched so the WireGuard container picks up the change.

    Args:
        db: Async session used to load the config and peer rows.
        tenant_id: Tenant whose config is being synced.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config or not config.is_enabled:
        return

    # Server private key is stored encrypted; decrypt only for file emission.
    key_bytes = settings.get_encryption_key_bytes()
    server_private_key = decrypt_credentials(config.server_private_key, key_bytes)

    result = await db.execute(
        select(VpnPeer).where(VpnPeer.tenant_id == tenant_id, VpnPeer.is_enabled.is_(True))
    )
    peers = result.scalars().all()

    # Build wg0.conf — [Interface] section first, then one [Peer] per row.
    lines = [
        "[Interface]",
        f"Address = {config.server_address}",
        f"ListenPort = {config.server_port}",
        f"PrivateKey = {server_private_key}",
        "",
    ]

    for peer in peers:
        peer_ip = peer.assigned_ip.split("/")[0]  # strip CIDR for AllowedIPs
        allowed_ips = [f"{peer_ip}/32"]
        if peer.additional_allowed_ips:
            # Comma-separated additional subnets (e.g. site-to-site routing)
            extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()]
            allowed_ips.extend(extra)
        lines.append("[Peer]")
        lines.append(f"PublicKey = {peer.peer_public_key}")
        if peer.preshared_key:
            psk = decrypt_credentials(peer.preshared_key, key_bytes)
            lines.append(f"PresharedKey = {psk}")
        lines.append(f"AllowedIPs = {', '.join(allowed_ips)}")
        lines.append("")

    config_dir = _get_wg_config_path()
    wg_confs_dir = config_dir / "wg_confs"
    wg_confs_dir.mkdir(parents=True, exist_ok=True)

    conf_path = wg_confs_dir / "wg0.conf"
    conf_path.write_text("\n".join(lines))

    # Signal WireGuard container to reload config
    reload_flag = wg_confs_dir / ".reload"
    reload_flag.write_text("1")

    logger.info("wireguard config synced", tenant_id=str(tenant_id), peers=len(peers))
|
||||
|
||||
|
||||
# ── Live Status ──
|
||||
|
||||
|
||||
def read_wg_status() -> dict[str, dict]:
    """Load live WireGuard peer status keyed by peer public key.

    The WireGuard container writes wg_status.json every 15 seconds with
    output from `wg show wg0 dump`. Returns {} when the file is missing,
    unreadable, or malformed.
    """
    status_path = _get_wg_config_path() / "wg_status.json"
    try:
        # EAFP: reading directly avoids a separate exists() race.
        entries = json.loads(status_path.read_text())
        return {entry["public_key"]: entry for entry in entries}
    except (json.JSONDecodeError, KeyError, OSError):
        return {}
|
||||
|
||||
|
||||
def get_peer_handshake(wg_status: dict[str, dict], public_key: str) -> Optional[datetime]:
    """Return the last-handshake time for *public_key*, or None if unknown.

    A missing peer entry and a zero/absent unix timestamp both mean the
    peer has never completed a handshake.
    """
    entry = wg_status.get(public_key)
    if not entry:
        return None
    timestamp = entry.get("last_handshake", 0)
    if not timestamp or timestamp <= 0:
        return None
    return datetime.fromtimestamp(timestamp, tz=timezone.utc)
|
||||
|
||||
|
||||
# ── CRUD Operations ──
|
||||
|
||||
|
||||
async def get_vpn_config(db: AsyncSession, tenant_id: uuid.UUID) -> Optional[VpnConfig]:
    """Fetch the tenant's VPN config row, or None if VPN was never set up."""
    query = select(VpnConfig).where(VpnConfig.tenant_id == tenant_id)
    result = await db.execute(query)
    return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def setup_vpn(
    db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None
) -> VpnConfig:
    """Initialize VPN for a tenant — generates server keys and creates config.

    Raises:
        ValueError: if the tenant already has a VPN config.
    """
    if await get_vpn_config(db, tenant_id):
        raise ValueError("VPN already configured for this tenant")

    private_key_b64, public_key_b64 = generate_wireguard_keypair()

    # The server private key is stored encrypted at rest.
    key_bytes = settings.get_encryption_key_bytes()
    encrypted_private = encrypt_credentials(private_key_b64, key_bytes)

    config = VpnConfig(
        tenant_id=tenant_id,
        server_private_key=encrypted_private,
        server_public_key=public_key_b64,
        endpoint=endpoint,
        is_enabled=True,
    )
    db.add(config)
    await db.flush()

    # Materialize wg0.conf immediately so the tunnel can come up.
    await sync_wireguard_config(db, tenant_id)
    return config
|
||||
|
||||
|
||||
async def update_vpn_config(
    db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None, is_enabled: Optional[bool] = None
) -> VpnConfig:
    """Update the VPN endpoint and/or enabled flag, then resync wg0.conf.

    A value of None means "leave unchanged" for both fields.

    Raises:
        ValueError: if VPN has not been set up for this tenant.
    """
    config = await get_vpn_config(db, tenant_id)
    if config is None:
        raise ValueError("VPN not configured for this tenant")

    if endpoint is not None:
        config.endpoint = endpoint
    if is_enabled is not None:
        config.is_enabled = is_enabled

    await db.flush()
    await sync_wireguard_config(db, tenant_id)
    return config
|
||||
|
||||
|
||||
async def get_peers(db: AsyncSession, tenant_id: uuid.UUID) -> list[VpnPeer]:
    """Return the tenant's VPN peers ordered by creation time."""
    query = (
        select(VpnPeer)
        .where(VpnPeer.tenant_id == tenant_id)
        .order_by(VpnPeer.created_at)
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
async def _next_available_ip(db: AsyncSession, tenant_id: uuid.UUID, config: VpnConfig) -> str:
    """Allocate the next free host address in the tenant's VPN subnet.

    Args:
        db: Session used to look up already-assigned peer addresses.
        tenant_id: Tenant whose peers are considered.
        config: VPN config providing ``subnet`` and ``server_address``.

    Returns:
        The address in CIDR notation using the subnet's actual prefix
        length (was previously hard-coded to /24, which produced a wrong
        mask for any non-/24 subnet).

    Raises:
        ValueError: when every host address in the subnet is taken.
    """
    # Parse subnet: e.g. "10.10.0.0/24" → start from .2 (server is .1)
    network = ipaddress.ip_network(config.subnet, strict=False)
    hosts = list(network.hosts())

    # Addresses already in use: every peer plus the server itself.
    result = await db.execute(select(VpnPeer.assigned_ip).where(VpnPeer.tenant_id == tenant_id))
    used_ips = {row[0].split("/")[0] for row in result.all()}
    used_ips.add(config.server_address.split("/")[0])  # exclude server IP

    for host in hosts[1:]:  # skip first host (server)
        if str(host) not in used_ips:
            # Bug fix: use the subnet's real prefix length, not a hard-coded /24.
            return f"{host}/{network.prefixlen}"

    raise ValueError("No available IPs in VPN subnet")
|
||||
|
||||
|
||||
async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID, additional_allowed_ips: Optional[str] = None) -> VpnPeer:
    """Add a device as a VPN peer.

    Generates a fresh keypair and preshared key for the device, stores them
    encrypted at rest, allocates a free subnet address, and resyncs wg0.conf.

    Args:
        db: Async session (caller owns the transaction; only flush here).
        tenant_id: Tenant that owns the device.
        device_id: Device to enroll as a peer.
        additional_allowed_ips: Optional comma-separated extra subnets to
            route to this peer (e.g. site-to-site routing).

    Raises:
        ValueError: if VPN is not configured, the device is missing, the
            device already has a peer, or the subnet is exhausted.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config:
        raise ValueError("VPN not configured — enable VPN first")

    # Check device exists
    device = await db.execute(select(Device).where(Device.id == device_id, Device.tenant_id == tenant_id))
    if not device.scalar_one_or_none():
        raise ValueError("Device not found")

    # Check if already a peer
    existing = await db.execute(select(VpnPeer).where(VpnPeer.device_id == device_id))
    if existing.scalar_one_or_none():
        raise ValueError("Device is already a VPN peer")

    private_key_b64, public_key_b64 = generate_wireguard_keypair()
    psk = generate_preshared_key()

    # Keys are stored encrypted at rest with the portal's symmetric key.
    key_bytes = settings.get_encryption_key_bytes()
    encrypted_private = encrypt_credentials(private_key_b64, key_bytes)
    encrypted_psk = encrypt_credentials(psk, key_bytes)

    assigned_ip = await _next_available_ip(db, tenant_id, config)

    peer = VpnPeer(
        tenant_id=tenant_id,
        device_id=device_id,
        peer_private_key=encrypted_private,
        peer_public_key=public_key_b64,
        preshared_key=encrypted_psk,
        assigned_ip=assigned_ip,
        additional_allowed_ips=additional_allowed_ips,
    )
    db.add(peer)
    await db.flush()  # assigns peer.id before the config resync

    await sync_wireguard_config(db, tenant_id)
    return peer
|
||||
|
||||
|
||||
async def remove_peer(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> None:
    """Delete a VPN peer and resync the WireGuard config.

    Raises:
        ValueError: if no peer with *peer_id* exists for this tenant.
    """
    query = select(VpnPeer).where(VpnPeer.id == peer_id, VpnPeer.tenant_id == tenant_id)
    peer = (await db.execute(query)).scalar_one_or_none()
    if peer is None:
        raise ValueError("Peer not found")

    await db.delete(peer)
    await db.flush()
    await sync_wireguard_config(db, tenant_id)
|
||||
|
||||
|
||||
async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> dict:
    """Get the full config for a peer — includes private key for device setup.

    Returns a dict with the decrypted peer keys, server details, and a
    ready-to-paste list of RouterOS commands.

    Raises:
        ValueError: if VPN is not configured or the peer does not exist.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config:
        raise ValueError("VPN not configured")

    result = await db.execute(
        select(VpnPeer).where(VpnPeer.id == peer_id, VpnPeer.tenant_id == tenant_id)
    )
    peer = result.scalar_one_or_none()
    if not peer:
        raise ValueError("Peer not found")

    # Decrypt only for the response payload; keys stay encrypted at rest.
    key_bytes = settings.get_encryption_key_bytes()
    private_key = decrypt_credentials(peer.peer_private_key, key_bytes)
    psk = decrypt_credentials(peer.preshared_key, key_bytes) if peer.preshared_key else None

    endpoint = config.endpoint or "YOUR_SERVER_IP:51820"
    # Bug fix: split on the LAST colon so hostnames with colons / bracketed
    # IPv6 endpoints like "[2001:db8::1]:51820" keep their host part intact
    # (split(":")[0] truncated at the first colon).
    endpoint_parts = endpoint.rsplit(":", 1)
    endpoint_host, endpoint_port = endpoint_parts[0], endpoint_parts[-1]

    routeros_commands = [
        f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key}"',
        f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" '
        f'endpoint-address={endpoint_host} endpoint-port={endpoint_port} '
        f'allowed-address={config.subnet} persistent-keepalive=25'
        + (f' preshared-key="{psk}"' if psk else ""),
        f"/ip address add address={peer.assigned_ip} interface=wg-portal",
    ]

    return {
        "peer_private_key": private_key,
        "peer_public_key": peer.peer_public_key,
        "assigned_ip": peer.assigned_ip,
        "server_public_key": config.server_public_key,
        "server_endpoint": endpoint,
        "allowed_ips": config.subnet,
        "routeros_commands": routeros_commands,
    }
|
||||
|
||||
|
||||
async def onboard_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    hostname: str,
    username: str,
    password: str,
) -> dict:
    """Create device + VPN peer in one transaction. Returns device, peer, and RouterOS commands.

    Unlike regular device creation, this skips TCP connectivity checks because
    the VPN tunnel isn't up yet. The device IP is set to the VPN-assigned address.

    Raises:
        ValueError: if VPN is not configured or the subnet is exhausted.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config:
        raise ValueError("VPN not configured — enable VPN first")

    # Allocate VPN IP before creating device
    assigned_ip = await _next_available_ip(db, tenant_id, config)
    vpn_ip_no_cidr = assigned_ip.split("/")[0]

    # Create device with VPN IP (skip TCP check — tunnel not up yet)
    credentials_json = json.dumps({"username": username, "password": password})
    transit_ciphertext = await encrypt_credentials_transit(credentials_json, str(tenant_id))

    device = Device(
        tenant_id=tenant_id,
        hostname=hostname,
        ip_address=vpn_ip_no_cidr,
        api_port=8728,
        api_ssl_port=8729,
        encrypted_credentials_transit=transit_ciphertext,
        status="unknown",
    )
    db.add(device)
    await db.flush()  # populates device.id for the peer FK

    # Create VPN peer linked to this device
    private_key_b64, public_key_b64 = generate_wireguard_keypair()
    psk = generate_preshared_key()

    # Keys are stored encrypted at rest.
    key_bytes = settings.get_encryption_key_bytes()
    encrypted_private = encrypt_credentials(private_key_b64, key_bytes)
    encrypted_psk = encrypt_credentials(psk, key_bytes)

    peer = VpnPeer(
        tenant_id=tenant_id,
        device_id=device.id,
        peer_private_key=encrypted_private,
        peer_public_key=public_key_b64,
        preshared_key=encrypted_psk,
        assigned_ip=assigned_ip,
    )
    db.add(peer)
    await db.flush()

    await sync_wireguard_config(db, tenant_id)

    # Generate RouterOS commands for the operator to run on the device.
    endpoint = config.endpoint or "YOUR_SERVER_IP:51820"
    # Bug fix: split on the LAST colon so bracketed IPv6 endpoints keep their
    # host part (split(":")[0] truncated at the first colon). Also use the
    # plaintext psk we already hold instead of decrypting what we just encrypted.
    endpoint_parts = endpoint.rsplit(":", 1)
    endpoint_host, endpoint_port = endpoint_parts[0], endpoint_parts[-1]

    routeros_commands = [
        f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key_b64}"',
        f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" '
        f'endpoint-address={endpoint_host} endpoint-port={endpoint_port} '
        f'allowed-address={config.subnet} persistent-keepalive=25'
        f' preshared-key="{psk}"',
        f"/ip address add address={assigned_ip} interface=wg-portal",
    ]

    return {
        "device_id": device.id,
        "peer_id": peer.id,
        "hostname": hostname,
        "assigned_ip": assigned_ip,
        "routeros_commands": routeros_commands,
    }
|
||||
Reference in New Issue
Block a user