feat: The Other Dude v9.0.1 — full-featured email system

ci: add GitHub Pages deployment workflow for docs site

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Jason Staack
2026-03-08 17:46:37 -05:00
commit b840047e19
511 changed files with 106948 additions and 0 deletions

View File

@@ -0,0 +1 @@
"""Backend services — auth, crypto, and business logic."""

View File

@@ -0,0 +1,240 @@
"""Account self-service operations: deletion and data export.
Provides GDPR/CCPA-compliant account deletion with full PII erasure
and data portability export (Article 20).
All queries use raw SQL via text() with admin sessions (bypass RLS)
since these are cross-table operations on the authenticated user's data.
"""
import hashlib
import uuid
from datetime import UTC, datetime
from typing import Any
import structlog
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import AdminAsyncSessionLocal
from app.services.audit_service import log_action
logger = structlog.get_logger("account_service")
async def delete_user_account(
    db: AsyncSession,
    user_id: uuid.UUID,
    tenant_id: uuid.UUID | None,
    user_email: str,
) -> dict[str, Any]:
    """Hard-delete a user account with full PII erasure.

    Steps:
    1. Create a deletion receipt audit log (persisted via separate session)
    2. Anonymize PII in existing audit_logs for this user
    3. Hard-delete the user row (CASCADE handles related tables)
    4. Best-effort session invalidation via Redis

    Args:
        db: Admin async session (bypasses RLS).
        user_id: UUID of the user to delete.
        tenant_id: Tenant UUID (None for super_admin).
        user_email: User's email (needed for audit hash before deletion).

    Returns:
        Dict with deleted=True and user_id on success.
    """
    # super_admins have no tenant; the nil UUID keeps the audit row NOT NULL.
    effective_tenant_id = tenant_id or uuid.UUID(int=0)
    # Hash the email so the receipt proves *which* account was deleted
    # without itself retaining PII.
    email_hash = hashlib.sha256(user_email.encode()).hexdigest()
    # ── 1. Pre-deletion audit receipt (separate session so it persists) ────
    # A separate session is used so the receipt survives even if the main
    # deletion transaction below is rolled back. Failure here is non-fatal.
    try:
        async with AdminAsyncSessionLocal() as audit_db:
            await log_action(
                audit_db,
                tenant_id=effective_tenant_id,
                user_id=user_id,
                action="account_deleted",
                resource_type="user",
                resource_id=str(user_id),
                details={
                    "deleted_user_id": str(user_id),
                    "email_hash": email_hash,
                    "deletion_type": "self_service",
                    "deleted_at": datetime.now(UTC).isoformat(),
                },
            )
            await audit_db.commit()
    except Exception:
        logger.warning(
            "deletion_receipt_failed",
            user_id=str(user_id),
            exc_info=True,
        )
    # ── 2. Anonymize PII in audit_logs for this user ─────────────────────
    # Strip PII keys from details JSONB (email, name, user_email, user_name)
    await db.execute(
        text(
            "UPDATE audit_logs "
            "SET details = details - 'email' - 'name' - 'user_email' - 'user_name' "
            "WHERE user_id = :user_id"
        ),
        {"user_id": user_id},
    )
    # Null out encrypted_details (may contain encrypted PII)
    await db.execute(
        text(
            "UPDATE audit_logs "
            "SET encrypted_details = NULL "
            "WHERE user_id = :user_id"
        ),
        {"user_id": user_id},
    )
    # ── 3. Hard delete user row ──────────────────────────────────────────
    # CASCADE handles: user_key_sets, api_keys, password_reset_tokens
    # SET NULL handles: audit_logs.user_id, key_access_log.user_id,
    #                   maintenance_windows.created_by, alert_events.acknowledged_by
    await db.execute(
        text("DELETE FROM users WHERE id = :user_id"),
        {"user_id": user_id},
    )
    # Single commit covers both anonymization UPDATEs and the DELETE.
    await db.commit()
    # ── 4. Best-effort Redis session invalidation ────────────────────────
    try:
        import redis.asyncio as aioredis
        from app.config import settings
        from app.services.auth import revoke_user_tokens

        r = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
        await revoke_user_tokens(r, str(user_id))
        await r.aclose()
    except Exception:
        # JWT expires in 15 min anyway; not critical
        logger.debug("redis_session_invalidation_skipped", user_id=str(user_id))
    logger.info("account_deleted", user_id=str(user_id), email_hash=email_hash)
    return {"deleted": True, "user_id": str(user_id)}
async def export_user_data(
    db: AsyncSession,
    user_id: uuid.UUID,
    tenant_id: uuid.UUID | None,
) -> dict[str, Any]:
    """Assemble all user data for GDPR Art. 20 data portability export.

    Returns a structured dict with user profile, API keys, audit logs,
    and key access log entries. Timestamps are serialized as ISO-8601
    strings; UUIDs as strings, so the result is directly JSON-encodable.

    Args:
        db: Admin async session (bypasses RLS).
        user_id: UUID of the user whose data to export.
        tenant_id: Tenant UUID (None for super_admin). Currently unused in
            the queries themselves (all lookups key on user_id), kept for
            interface symmetry with delete_user_account.

    Returns:
        Envelope dict with export_date, format_version, and all user data.
    """
    # ── User profile ─────────────────────────────────────────────────────
    result = await db.execute(
        text(
            "SELECT id, email, name, role, tenant_id, "
            "created_at, last_login, auth_version "
            "FROM users WHERE id = :user_id"
        ),
        {"user_id": user_id},
    )
    user_row = result.mappings().first()
    # Empty dict if the user row is gone (e.g. concurrent deletion).
    user_data: dict[str, Any] = {}
    if user_row:
        user_data = {
            "id": str(user_row["id"]),
            "email": user_row["email"],
            "name": user_row["name"],
            "role": user_row["role"],
            "tenant_id": str(user_row["tenant_id"]) if user_row["tenant_id"] else None,
            "created_at": user_row["created_at"].isoformat() if user_row["created_at"] else None,
            "last_login": user_row["last_login"].isoformat() if user_row["last_login"] else None,
            "auth_version": user_row["auth_version"],
        }
    # ── API keys (exclude key_hash for security) ─────────────────────────
    result = await db.execute(
        text(
            "SELECT id, name, key_prefix, scopes, created_at, "
            "expires_at, revoked_at, last_used_at "
            "FROM api_keys WHERE user_id = :user_id "
            "ORDER BY created_at DESC"
        ),
        {"user_id": user_id},
    )
    api_keys = []
    for row in result.mappings().all():
        api_keys.append({
            "id": str(row["id"]),
            "name": row["name"],
            "key_prefix": row["key_prefix"],
            "scopes": row["scopes"],
            "created_at": row["created_at"].isoformat() if row["created_at"] else None,
            "expires_at": row["expires_at"].isoformat() if row["expires_at"] else None,
            "revoked_at": row["revoked_at"].isoformat() if row["revoked_at"] else None,
            "last_used_at": row["last_used_at"].isoformat() if row["last_used_at"] else None,
        })
    # ── Audit logs (limit 1000, most recent first) ───────────────────────
    result = await db.execute(
        text(
            "SELECT id, action, resource_type, resource_id, "
            "details, ip_address, created_at "
            "FROM audit_logs WHERE user_id = :user_id "
            "ORDER BY created_at DESC LIMIT 1000"
        ),
        {"user_id": user_id},
    )
    audit_logs = []
    for row in result.mappings().all():
        # details is JSONB; normalize NULL to an empty dict for the export.
        details = row["details"] if row["details"] else {}
        audit_logs.append({
            "id": str(row["id"]),
            "action": row["action"],
            "resource_type": row["resource_type"],
            "resource_id": row["resource_id"],
            "details": details,
            "ip_address": row["ip_address"],
            "created_at": row["created_at"].isoformat() if row["created_at"] else None,
        })
    # ── Key access log (limit 1000, most recent first) ───────────────────
    result = await db.execute(
        text(
            "SELECT id, action, resource_type, ip_address, created_at "
            "FROM key_access_log WHERE user_id = :user_id "
            "ORDER BY created_at DESC LIMIT 1000"
        ),
        {"user_id": user_id},
    )
    key_access_entries = []
    for row in result.mappings().all():
        key_access_entries.append({
            "id": str(row["id"]),
            "action": row["action"],
            "resource_type": row["resource_type"],
            "ip_address": row["ip_address"],
            "created_at": row["created_at"].isoformat() if row["created_at"] else None,
        })
    return {
        "export_date": datetime.now(UTC).isoformat(),
        "format_version": "1.0",
        "user": user_data,
        "api_keys": api_keys,
        "audit_logs": audit_logs,
        "key_access_log": key_access_entries,
    }

View File

@@ -0,0 +1,723 @@
"""Alert rule evaluation engine with Redis breach counters and flap detection.
Entry points:
- evaluate(device_id, tenant_id, metric_type, data): called from metrics_subscriber
- evaluate_offline(device_id, tenant_id): called from nats_subscriber on device offline
- evaluate_online(device_id, tenant_id): called from nats_subscriber on device online
Uses Redis for:
- Consecutive breach counting (alert:breach:{device_id}:{rule_id})
- Flap detection (alert:flap:{device_id}:{rule_id} sorted set)
Uses AdminAsyncSessionLocal for all DB operations (runs cross-tenant in NATS handlers).
"""
import asyncio
import logging
import time
from datetime import datetime, timezone
from typing import Any
import redis.asyncio as aioredis
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.services.event_publisher import publish_event
logger = logging.getLogger(__name__)
# Module-level Redis client, lazily initialized
_redis_client: aioredis.Redis | None = None
# Module-level rule cache: {tenant_id: (rules_list, fetched_at_timestamp)}
_rule_cache: dict[str, tuple[list[dict], float]] = {}
_CACHE_TTL_SECONDS = 60
# Module-level maintenance window cache: {tenant_id: (active_windows_list, fetched_at_timestamp)}
# Each window: {"device_ids": [...], "suppress_alerts": True}
_maintenance_cache: dict[str, tuple[list[dict], float]] = {}
_MAINTENANCE_CACHE_TTL = 30 # 30 seconds
async def _get_redis() -> aioredis.Redis:
    """Return the shared module-level Redis client, creating it on first use."""
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    _redis_client = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
    return _redis_client
async def _get_active_maintenance_windows(tenant_id: str) -> list[dict]:
    """Fetch active maintenance windows for a tenant, with 30s cache.

    Only windows with suppress_alerts=true whose [start_at, end_at] interval
    contains NOW() are returned. Results are cached per tenant in the
    module-level _maintenance_cache for _MAINTENANCE_CACHE_TTL seconds.

    Args:
        tenant_id: Tenant UUID as a string.

    Returns:
        List of {"device_ids": [...], "suppress_alerts": bool} dicts.
    """
    now = time.time()
    cached = _maintenance_cache.get(tenant_id)
    if cached and (now - cached[1]) < _MAINTENANCE_CACHE_TTL:
        return cached[0]
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT device_ids, suppress_alerts
                FROM maintenance_windows
                WHERE tenant_id = CAST(:tenant_id AS uuid)
                  AND suppress_alerts = true
                  AND start_at <= NOW()
                  AND end_at >= NOW()
            """),
            {"tenant_id": tenant_id},
        )
        rows = result.fetchall()
    # Defensive: non-list device_ids (e.g. NULL) is normalized to [].
    windows = [
        {
            "device_ids": row[0] if isinstance(row[0], list) else [],
            "suppress_alerts": row[1],
        }
        for row in rows
    ]
    _maintenance_cache[tenant_id] = (windows, now)
    return windows
async def _is_device_in_maintenance(tenant_id: str, device_id: str) -> bool:
    """Check if a device is currently under active maintenance with alert suppression.

    Returns True if there is at least one active maintenance window covering
    this device. A window with an empty device_ids list applies to every
    device in the tenant.
    """
    active = await _get_active_maintenance_windows(tenant_id)
    return any(
        not window["device_ids"] or device_id in window["device_ids"]
        for window in active
    )
async def _get_rules_for_tenant(tenant_id: str) -> list[dict]:
    """Fetch active alert rules for a tenant, with 60s cache.

    Only enabled rules are returned. Results are cached per tenant in the
    module-level _rule_cache for _CACHE_TTL_SECONDS seconds, so rule edits
    may take up to a minute to be picked up by the evaluator.

    Args:
        tenant_id: Tenant UUID as a string.

    Returns:
        List of rule dicts with string UUIDs (device_id/group_id may be None
        for tenant-wide rules) and a float threshold.
    """
    now = time.time()
    cached = _rule_cache.get(tenant_id)
    if cached and (now - cached[1]) < _CACHE_TTL_SECONDS:
        return cached[0]
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, tenant_id, device_id, group_id, name, metric,
                       operator, threshold, duration_polls, severity
                FROM alert_rules
                WHERE tenant_id = CAST(:tenant_id AS uuid) AND enabled = TRUE
            """),
            {"tenant_id": tenant_id},
        )
        rows = result.fetchall()
    rules = [
        {
            "id": str(row[0]),
            "tenant_id": str(row[1]),
            "device_id": str(row[2]) if row[2] else None,
            "group_id": str(row[3]) if row[3] else None,
            "name": row[4],
            "metric": row[5],
            "operator": row[6],
            # threshold may come back as Decimal; normalize to float.
            "threshold": float(row[7]),
            "duration_polls": row[8],
            "severity": row[9],
        }
        for row in rows
    ]
    _rule_cache[tenant_id] = (rules, now)
    return rules
def _check_threshold(value: float, operator: str, threshold: float) -> bool:
"""Check if a metric value breaches a threshold."""
if operator == "gt":
return value > threshold
elif operator == "lt":
return value < threshold
elif operator == "gte":
return value >= threshold
elif operator == "lte":
return value <= threshold
return False
def _extract_metrics(metric_type: str, data: dict) -> dict[str, float]:
"""Extract metric name->value pairs from a NATS metrics event."""
metrics: dict[str, float] = {}
if metric_type == "health":
health = data.get("health", {})
for key in ("cpu_load", "temperature"):
val = health.get(key)
if val is not None and val != "":
try:
metrics[key] = float(val)
except (ValueError, TypeError):
pass
# Compute memory_used_pct and disk_used_pct
free_mem = health.get("free_memory")
total_mem = health.get("total_memory")
if free_mem is not None and total_mem is not None:
try:
total = float(total_mem)
free = float(free_mem)
if total > 0:
metrics["memory_used_pct"] = round((1.0 - free / total) * 100, 1)
except (ValueError, TypeError):
pass
free_disk = health.get("free_disk")
total_disk = health.get("total_disk")
if free_disk is not None and total_disk is not None:
try:
total = float(total_disk)
free = float(free_disk)
if total > 0:
metrics["disk_used_pct"] = round((1.0 - free / total) * 100, 1)
except (ValueError, TypeError):
pass
elif metric_type == "wireless":
wireless = data.get("wireless", [])
# Aggregate: use worst signal, lowest CCQ, sum client_count
for wif in wireless:
for key in ("signal_strength", "ccq", "client_count"):
val = wif.get(key) if key != "avg_signal" else wif.get("avg_signal")
if key == "signal_strength":
val = wif.get("avg_signal")
if val is not None and val != "":
try:
fval = float(val)
if key not in metrics:
metrics[key] = fval
elif key == "signal_strength":
metrics[key] = min(metrics[key], fval) # worst signal
elif key == "ccq":
metrics[key] = min(metrics[key], fval) # worst CCQ
elif key == "client_count":
metrics[key] = metrics.get(key, 0) + fval # sum
except (ValueError, TypeError):
pass
# TODO: Interface bandwidth alerting (rx_bps/tx_bps) requires stateful delta
# computation between consecutive poll values. Deferred for now — the alert_rules
# table supports these metric types, but evaluation is skipped.
return metrics
async def _increment_breach(
    r: aioredis.Redis, device_id: str, rule_id: str, required_polls: int
) -> bool:
    """Increment breach counter in Redis. Returns True when threshold duration reached.

    The TTL is refreshed on every increment so the counter only expires
    once breaches stop arriving for roughly (required_polls + 2) minutes.
    """
    key = f"alert:breach:{device_id}:{rule_id}"
    count = await r.incr(key)
    # Set TTL to (required_polls + 2) * 60 seconds so it expires if breaches stop
    await r.expire(key, (required_polls + 2) * 60)
    return count >= required_polls
async def _reset_breach(r: aioredis.Redis, device_id: str, rule_id: str) -> None:
    """Clear the consecutive-breach counter once the metric is back in range."""
    await r.delete(f"alert:breach:{device_id}:{rule_id}")
async def _check_flapping(r: aioredis.Redis, device_id: str, rule_id: str) -> bool:
    """Check if alert is flapping (>= 5 state transitions in 10 minutes).

    Uses a Redis sorted set with timestamps as scores. NOTE: calling this
    function itself records a transition, so it must be invoked exactly once
    per state change.
    """
    key = f"alert:flap:{device_id}:{rule_id}"
    now = time.time()
    window_start = now - 600  # 10 minute window
    # Add this transition
    await r.zadd(key, {str(now): now})
    # Remove entries outside the window
    await r.zremrangebyscore(key, "-inf", window_start)
    # Set TTL on the key
    await r.expire(key, 1200)
    # Count transitions in window
    count = await r.zcard(key)
    return count >= 5
async def _get_device_groups(device_id: str) -> list[str]:
    """Get group IDs (as strings) for a device from device_group_memberships."""
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("SELECT group_id FROM device_group_memberships WHERE device_id = CAST(:device_id AS uuid)"),
            {"device_id": device_id},
        )
        return [str(row[0]) for row in result.fetchall()]
async def _has_open_alert(device_id: str, rule_id: str | None, metric: str | None = None) -> bool:
    """Check if there's an open (firing, unresolved) alert for this device+rule.

    Two lookup modes:
    - rule_id given: match on device_id + rule_id.
    - rule_id None: match system-level alerts (rule_id IS NULL) on
      device_id + metric, defaulting the metric to "offline".
    """
    async with AdminAsyncSessionLocal() as session:
        if rule_id:
            result = await session.execute(
                text("""
                    SELECT 1 FROM alert_events
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id = CAST(:rule_id AS uuid)
                      AND status = 'firing' AND resolved_at IS NULL
                    LIMIT 1
                """),
                {"device_id": device_id, "rule_id": rule_id},
            )
        else:
            result = await session.execute(
                text("""
                    SELECT 1 FROM alert_events
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id IS NULL
                      AND metric = :metric AND status = 'firing' AND resolved_at IS NULL
                    LIMIT 1
                """),
                {"device_id": device_id, "metric": metric or "offline"},
            )
        return result.fetchone() is not None
async def _create_alert_event(
    device_id: str,
    tenant_id: str,
    rule_id: str | None,
    status: str,
    severity: str,
    metric: str | None,
    value: float | None,
    threshold: float | None,
    message: str | None,
    is_flapping: bool = False,
) -> dict:
    """Create an alert event row and return its data.

    Inserts into alert_events (resolved_at is set in-database when status is
    'resolved'), commits, then publishes a real-time NATS event for the SSE
    pipeline: alert.fired.{tenant_id} for 'firing'/'flapping' statuses,
    alert.resolved.{tenant_id} for 'resolved'.

    Args:
        device_id / tenant_id / rule_id: string UUIDs (rule_id None for
            system-level alerts).
        status: 'firing', 'flapping', or 'resolved'.
        severity: rule severity (e.g. 'critical').
        metric / value / threshold / message: alert context; may be None.
        is_flapping: marks events created while the alert is flapping.

    Returns:
        Dict mirroring the inserted row (id is a string UUID, or None if the
        INSERT returned no row).
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                INSERT INTO alert_events
                    (id, device_id, tenant_id, rule_id, status, severity, metric,
                     value, threshold, message, is_flapping, fired_at,
                     resolved_at)
                VALUES
                    (gen_random_uuid(), CAST(:device_id AS uuid), CAST(:tenant_id AS uuid),
                     :rule_id, :status, :severity, :metric,
                     :value, :threshold, :message, :is_flapping, NOW(),
                     CASE WHEN :status = 'resolved' THEN NOW() ELSE NULL END)
                RETURNING id, fired_at
            """),
            {
                "device_id": device_id,
                "tenant_id": tenant_id,
                "rule_id": rule_id,
                "status": status,
                "severity": severity,
                "metric": metric,
                "value": value,
                "threshold": threshold,
                "message": message,
                "is_flapping": is_flapping,
            },
        )
        row = result.fetchone()
        await session.commit()
    alert_data = {
        "id": str(row[0]) if row else None,
        "device_id": device_id,
        "tenant_id": tenant_id,
        "rule_id": rule_id,
        "status": status,
        "severity": severity,
        "metric": metric,
        "value": value,
        "threshold": threshold,
        "message": message,
        "is_flapping": is_flapping,
    }
    # Publish real-time event to NATS for SSE pipeline (fire-and-forget)
    if status in ("firing", "flapping"):
        await publish_event(f"alert.fired.{tenant_id}", {
            "event_type": "alert_fired",
            "tenant_id": tenant_id,
            "device_id": device_id,
            "alert_event_id": alert_data["id"],
            "severity": severity,
            "metric": metric,
            "current_value": value,
            "threshold": threshold,
            "message": message,
            "is_flapping": is_flapping,
            "fired_at": datetime.now(timezone.utc).isoformat(),
        })
    elif status == "resolved":
        await publish_event(f"alert.resolved.{tenant_id}", {
            "event_type": "alert_resolved",
            "tenant_id": tenant_id,
            "device_id": device_id,
            "alert_event_id": alert_data["id"],
            "severity": severity,
            "metric": metric,
            "message": message,
            "resolved_at": datetime.now(timezone.utc).isoformat(),
        })
    return alert_data
async def _resolve_alert(device_id: str, rule_id: str | None, metric: str | None = None) -> None:
    """Resolve an open alert by setting resolved_at.

    Mirrors the lookup modes of _has_open_alert: rule-based (rule_id given)
    or system-level (rule_id None, matched on metric, default "offline").
    Resolves every matching firing row, not just one.
    """
    async with AdminAsyncSessionLocal() as session:
        if rule_id:
            await session.execute(
                text("""
                    UPDATE alert_events SET resolved_at = NOW(), status = 'resolved'
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id = CAST(:rule_id AS uuid)
                      AND status = 'firing' AND resolved_at IS NULL
                """),
                {"device_id": device_id, "rule_id": rule_id},
            )
        else:
            await session.execute(
                text("""
                    UPDATE alert_events SET resolved_at = NOW(), status = 'resolved'
                    WHERE device_id = CAST(:device_id AS uuid) AND rule_id IS NULL
                      AND metric = :metric AND status = 'firing' AND resolved_at IS NULL
                """),
                {"device_id": device_id, "metric": metric or "offline"},
            )
        await session.commit()
async def _get_channels_for_tenant(tenant_id: str) -> list[dict]:
    """Get all notification channels for a tenant.

    Returns one dict per channel with SMTP/webhook/Slack fields; unused
    fields for a given channel_type are simply None.
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, name, channel_type, smtp_host, smtp_port, smtp_user,
                       smtp_password, smtp_use_tls, from_address, to_address,
                       webhook_url, smtp_password_transit, slack_webhook_url, tenant_id
                FROM notification_channels
                WHERE tenant_id = CAST(:tenant_id AS uuid)
            """),
            {"tenant_id": tenant_id},
        )
        return [
            {
                "id": str(row[0]),
                "name": row[1],
                "channel_type": row[2],
                "smtp_host": row[3],
                "smtp_port": row[4],
                "smtp_user": row[5],
                "smtp_password": row[6],
                "smtp_use_tls": row[7],
                "from_address": row[8],
                "to_address": row[9],
                "webhook_url": row[10],
                "smtp_password_transit": row[11],
                "slack_webhook_url": row[12],
                "tenant_id": str(row[13]) if row[13] else None,
            }
            for row in result.fetchall()
        ]
async def _get_channels_for_rule(rule_id: str) -> list[dict]:
    """Get notification channels linked to a specific alert rule.

    Joins through alert_rule_channels; returns the same dict shape as
    _get_channels_for_tenant so callers can use either interchangeably.
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT nc.id, nc.name, nc.channel_type, nc.smtp_host, nc.smtp_port,
                       nc.smtp_user, nc.smtp_password, nc.smtp_use_tls,
                       nc.from_address, nc.to_address, nc.webhook_url,
                       nc.smtp_password_transit, nc.slack_webhook_url, nc.tenant_id
                FROM notification_channels nc
                JOIN alert_rule_channels arc ON arc.channel_id = nc.id
                WHERE arc.rule_id = CAST(:rule_id AS uuid)
            """),
            {"rule_id": rule_id},
        )
        return [
            {
                "id": str(row[0]),
                "name": row[1],
                "channel_type": row[2],
                "smtp_host": row[3],
                "smtp_port": row[4],
                "smtp_user": row[5],
                "smtp_password": row[6],
                "smtp_use_tls": row[7],
                "from_address": row[8],
                "to_address": row[9],
                "webhook_url": row[10],
                "smtp_password_transit": row[11],
                "slack_webhook_url": row[12],
                "tenant_id": str(row[13]) if row[13] else None,
            }
            for row in result.fetchall()
        ]
async def _dispatch_async(alert_event: dict, channels: list[dict], device_hostname: str) -> None:
    """Fire-and-forget notification dispatch.

    Runs as an asyncio task; any failure (including the deliberate lazy
    import, which avoids a circular import at module load) is logged and
    swallowed so it cannot crash the evaluator.
    """
    try:
        from app.services.notification_service import dispatch_notifications
        await dispatch_notifications(alert_event, channels, device_hostname)
    except Exception as e:
        logger.warning("Notification dispatch failed: %s", e)
async def _get_device_hostname(device_id: str) -> str:
    """Get device hostname for notification messages.

    Falls back to the raw device_id string when the device row is missing.
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("SELECT hostname FROM devices WHERE id = CAST(:device_id AS uuid)"),
            {"device_id": device_id},
        )
        row = result.fetchone()
        return row[0] if row else device_id
async def evaluate(
    device_id: str,
    tenant_id: str,
    metric_type: str,
    data: dict[str, Any],
) -> None:
    """Evaluate alert rules for incoming device metrics.

    Called from metrics_subscriber after metric DB write.

    Flow per applicable rule:
    - breaching: bump the Redis breach counter; once the rule's
      duration_polls is reached (and no alert is already open), create a
      firing/flapping event and dispatch notifications (suppressed while
      flapping).
    - not breaching: reset the counter and, if an alert is open, resolve it
      and emit a resolved event (notifications again suppressed if flapping).

    Rule precedence: a device-specific rule on a metric shadows tenant-wide
    rules for that same metric; group rules apply to group members.
    """
    # Check maintenance window suppression before evaluating rules
    if await _is_device_in_maintenance(tenant_id, device_id):
        logger.debug(
            "Alert suppressed by maintenance window for device %s tenant %s",
            device_id, tenant_id,
        )
        return
    rules = await _get_rules_for_tenant(tenant_id)
    if not rules:
        return
    metrics = _extract_metrics(metric_type, data)
    if not metrics:
        return
    r = await _get_redis()
    device_groups = await _get_device_groups(device_id)
    # Build a set of metrics that have device-specific rules
    device_specific_metrics: set[str] = set()
    for rule in rules:
        if rule["device_id"] == device_id:
            device_specific_metrics.add(rule["metric"])
    for rule in rules:
        rule_metric = rule["metric"]
        if rule_metric not in metrics:
            continue
        # Check if rule applies to this device
        applies = False
        if rule["device_id"] == device_id:
            applies = True
        elif rule["device_id"] is None and rule["group_id"] is None:
            # Tenant-wide rule — skip if device-specific rule exists for same metric
            if rule_metric in device_specific_metrics:
                continue
            applies = True
        elif rule["group_id"] and rule["group_id"] in device_groups:
            applies = True
        if not applies:
            continue
        value = metrics[rule_metric]
        breaching = _check_threshold(value, rule["operator"], rule["threshold"])
        if breaching:
            reached = await _increment_breach(r, device_id, rule["id"], rule["duration_polls"])
            if reached:
                # Check if already firing
                if await _has_open_alert(device_id, rule["id"]):
                    continue
                # Check flapping (this call also records the transition)
                is_flapping = await _check_flapping(r, device_id, rule["id"])
                hostname = await _get_device_hostname(device_id)
                message = f"{rule['name']}: {rule_metric} = {value} (threshold: {rule['operator']} {rule['threshold']})"
                alert_event = await _create_alert_event(
                    device_id=device_id,
                    tenant_id=tenant_id,
                    rule_id=rule["id"],
                    status="flapping" if is_flapping else "firing",
                    severity=rule["severity"],
                    metric=rule_metric,
                    value=value,
                    threshold=rule["threshold"],
                    message=message,
                    is_flapping=is_flapping,
                )
                if is_flapping:
                    logger.info(
                        "Alert %s for device %s is flapping — notifications suppressed",
                        rule["name"], device_id,
                    )
                else:
                    channels = await _get_channels_for_rule(rule["id"])
                    if channels:
                        # Fire-and-forget so dispatch latency never blocks evaluation.
                        asyncio.create_task(_dispatch_async(alert_event, channels, hostname))
        else:
            # Not breaching — reset counter and check for open alert to resolve
            await _reset_breach(r, device_id, rule["id"])
            if await _has_open_alert(device_id, rule["id"]):
                # Check flapping before resolving
                is_flapping = await _check_flapping(r, device_id, rule["id"])
                await _resolve_alert(device_id, rule["id"])
                hostname = await _get_device_hostname(device_id)
                message = f"Resolved: {rule['name']}: {rule_metric} = {value}"
                resolved_event = await _create_alert_event(
                    device_id=device_id,
                    tenant_id=tenant_id,
                    rule_id=rule["id"],
                    status="resolved",
                    severity=rule["severity"],
                    metric=rule_metric,
                    value=value,
                    threshold=rule["threshold"],
                    message=message,
                    is_flapping=is_flapping,
                )
                if not is_flapping:
                    channels = await _get_channels_for_rule(rule["id"])
                    if channels:
                        asyncio.create_task(_dispatch_async(resolved_event, channels, hostname))
async def _get_offline_rule(tenant_id: str) -> dict | None:
    """Look up the device_offline default rule for a tenant.

    Returns {"id": str, "enabled": bool} or None when the tenant has no
    default offline rule (legacy tenants fall back to rule_id=NULL alerts).
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, enabled FROM alert_rules
                WHERE tenant_id = CAST(:tenant_id AS uuid)
                  AND metric = 'device_offline' AND is_default = TRUE
                LIMIT 1
            """),
            {"tenant_id": tenant_id},
        )
        row = result.fetchone()
        if row:
            return {"id": str(row[0]), "enabled": row[1]}
        return None
async def evaluate_offline(device_id: str, tenant_id: str) -> None:
    """Create a critical alert when a device goes offline.

    Uses the tenant's device_offline default rule if it exists and is enabled.
    Falls back to system-level alert (rule_id=NULL) for backward compatibility.
    Suppressed entirely when the device is in an active maintenance window,
    or when a disabled default rule signals the tenant opted out.
    """
    if await _is_device_in_maintenance(tenant_id, device_id):
        logger.debug(
            "Offline alert suppressed by maintenance window for device %s",
            device_id,
        )
        return
    rule = await _get_offline_rule(tenant_id)
    rule_id = rule["id"] if rule else None
    # If rule exists but is disabled, skip alert creation (user opted out)
    if rule and not rule["enabled"]:
        return
    # Deduplicate: never stack a second firing offline alert.
    if rule_id:
        if await _has_open_alert(device_id, rule_id):
            return
    else:
        if await _has_open_alert(device_id, None, "offline"):
            return
    hostname = await _get_device_hostname(device_id)
    message = f"Device {hostname} is offline"
    alert_event = await _create_alert_event(
        device_id=device_id,
        tenant_id=tenant_id,
        rule_id=rule_id,
        status="firing",
        severity="critical",
        metric="offline",
        value=None,
        threshold=None,
        message=message,
    )
    # Use rule-linked channels if available, otherwise tenant-wide channels
    if rule_id:
        channels = await _get_channels_for_rule(rule_id)
        if not channels:
            channels = await _get_channels_for_tenant(tenant_id)
    else:
        channels = await _get_channels_for_tenant(tenant_id)
    if channels:
        asyncio.create_task(_dispatch_async(alert_event, channels, hostname))
async def evaluate_online(device_id: str, tenant_id: str) -> None:
    """Resolve offline alert when device comes back online.

    Mirrors evaluate_offline's rule lookup: prefers the tenant's
    device_offline default rule, falling back to system-level (rule_id=NULL)
    alerts. If no matching open alert exists, this is a no-op.
    """
    rule = await _get_offline_rule(tenant_id)
    rule_id = rule["id"] if rule else None
    if rule_id:
        if not await _has_open_alert(device_id, rule_id):
            return
        await _resolve_alert(device_id, rule_id)
    else:
        if not await _has_open_alert(device_id, None, "offline"):
            return
        await _resolve_alert(device_id, None, "offline")
    hostname = await _get_device_hostname(device_id)
    message = f"Device {hostname} is back online"
    resolved_event = await _create_alert_event(
        device_id=device_id,
        tenant_id=tenant_id,
        rule_id=rule_id,
        status="resolved",
        severity="critical",
        metric="offline",
        value=None,
        threshold=None,
        message=message,
    )
    # Same channel-selection fallback as evaluate_offline.
    if rule_id:
        channels = await _get_channels_for_rule(rule_id)
        if not channels:
            channels = await _get_channels_for_tenant(tenant_id)
    else:
        channels = await _get_channels_for_tenant(tenant_id)
    if channels:
        asyncio.create_task(_dispatch_async(resolved_event, channels, hostname))

View File

@@ -0,0 +1,190 @@
"""API key generation, validation, and management service.
Keys use the mktp_ prefix for easy identification in logs.
Storage uses SHA-256 hash -- the plaintext key is never persisted.
Validation uses AdminAsyncSessionLocal since it runs before tenant context is set.
"""
import hashlib
import json
import secrets
import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import text
from app.database import AdminAsyncSessionLocal
# Allowed scopes for API keys
ALLOWED_SCOPES: set[str] = {
"devices:read",
"devices:write",
"config:read",
"config:write",
"alerts:read",
"firmware:write",
}
def generate_raw_key() -> str:
    """Create a fresh plaintext API key: "mktp_" + 32 bytes of URL-safe entropy."""
    return "mktp_" + secrets.token_urlsafe(32)
def hash_key(raw_key: str) -> str:
    """Return the SHA-256 hex digest used to store and compare API keys."""
    digest = hashlib.sha256()
    digest.update(raw_key.encode())
    return digest.hexdigest()
async def create_api_key(
    db,
    tenant_id: uuid.UUID,
    user_id: uuid.UUID,
    name: str,
    scopes: list[str],
    expires_at: Optional[datetime] = None,
) -> dict:
    """Create a new API key.

    Only the SHA-256 hash and a short display prefix are persisted; the
    plaintext key appears solely in the return value.

    Args:
        db: Async session to insert with (committed here).
        tenant_id / user_id: Owning tenant and user.
        name: Human-readable key label.
        scopes: Scope strings (stored as JSONB). Validation against
            ALLOWED_SCOPES is expected to happen at the API layer —
            TODO confirm; this function does not check.
        expires_at: Optional expiry; None means non-expiring.

    Returns dict with:
        - key: the plaintext key (shown once, never again)
        - id: the key UUID
        - key_prefix: first 9 chars of the key (e.g. "mktp_abc1")
        plus name, scopes, expires_at, created_at.
    """
    raw_key = generate_raw_key()
    key_hash_value = hash_key(raw_key)
    key_prefix = raw_key[:9]  # "mktp_" + first 4 random chars
    result = await db.execute(
        text("""
            INSERT INTO api_keys (tenant_id, user_id, name, key_prefix, key_hash, scopes, expires_at)
            VALUES (:tenant_id, :user_id, :name, :key_prefix, :key_hash, CAST(:scopes AS jsonb), :expires_at)
            RETURNING id, created_at
        """),
        {
            "tenant_id": str(tenant_id),
            "user_id": str(user_id),
            "name": name,
            "key_prefix": key_prefix,
            "key_hash": key_hash_value,
            "scopes": json.dumps(scopes),
            "expires_at": expires_at,
        },
    )
    row = result.fetchone()
    await db.commit()
    return {
        "key": raw_key,
        "id": row.id,
        "key_prefix": key_prefix,
        "name": name,
        "scopes": scopes,
        "expires_at": expires_at,
        "created_at": row.created_at,
    }
async def validate_api_key(raw_key: str) -> Optional[dict]:
    """Validate an API key and return context if valid.

    Uses AdminAsyncSessionLocal since this runs before tenant context is set.
    Lookup is by SHA-256 hash of the presented key, so timing is independent
    of where a mismatch occurs.

    Returns dict with tenant_id, user_id, scopes, key_id on success.
    Returns None for invalid, expired, or revoked keys.
    Updates last_used_at on successful validation.
    """
    key_hash_value = hash_key(raw_key)
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT id, tenant_id, user_id, scopes, expires_at, revoked_at
                FROM api_keys
                WHERE key_hash = :key_hash
            """),
            {"key_hash": key_hash_value},
        )
        row = result.fetchone()
        if not row:
            return None
        # Check revoked
        if row.revoked_at is not None:
            return None
        # Check expired (assumes expires_at is stored timezone-aware — TODO confirm
        # against the schema; naive timestamps would raise on this comparison)
        if row.expires_at is not None and row.expires_at <= datetime.now(timezone.utc):
            return None
        # Update last_used_at
        await session.execute(
            text("""
                UPDATE api_keys SET last_used_at = now()
                WHERE id = :key_id
            """),
            {"key_id": str(row.id)},
        )
        await session.commit()
        return {
            "tenant_id": row.tenant_id,
            "user_id": row.user_id,
            "scopes": row.scopes if row.scopes else [],
            "key_id": row.id,
        }
async def list_api_keys(db, tenant_id: uuid.UUID) -> list[dict]:
    """List all API keys for a tenant (active and revoked).

    Returns keys with masked display (key_prefix + "...") — the hash is
    never selected. Timestamps are serialized to ISO-8601 strings.
    """
    result = await db.execute(
        text("""
            SELECT id, name, key_prefix, scopes, expires_at, last_used_at,
                   created_at, revoked_at, user_id
            FROM api_keys
            WHERE tenant_id = :tenant_id
            ORDER BY created_at DESC
        """),
        {"tenant_id": str(tenant_id)},
    )
    rows = result.fetchall()
    return [
        {
            "id": row.id,
            "name": row.name,
            "key_prefix": row.key_prefix,
            "scopes": row.scopes if row.scopes else [],
            "expires_at": row.expires_at.isoformat() if row.expires_at else None,
            "last_used_at": row.last_used_at.isoformat() if row.last_used_at else None,
            "created_at": row.created_at.isoformat() if row.created_at else None,
            "revoked_at": row.revoked_at.isoformat() if row.revoked_at else None,
            "user_id": str(row.user_id),
        }
        for row in rows
    ]
async def revoke_api_key(db, tenant_id: uuid.UUID, key_id: uuid.UUID) -> bool:
    """Soft-revoke an API key by stamping revoked_at = now().

    The WHERE clause guards on tenant ownership and on the key not already
    being revoked, so the operation is tenant-scoped and idempotent.

    Returns:
        True when a row was actually updated; False when no matching
        active key exists.
    """
    outcome = await db.execute(
        text("""
            UPDATE api_keys
            SET revoked_at = now()
            WHERE id = :key_id AND tenant_id = :tenant_id AND revoked_at IS NULL
            RETURNING id
        """),
        {"key_id": str(key_id), "tenant_id": str(tenant_id)},
    )
    revoked_row = outcome.fetchone()
    await db.commit()
    return revoked_row is not None

View File

@@ -0,0 +1,92 @@
"""Centralized audit logging service.
Provides a fire-and-forget ``log_action`` coroutine that inserts a row into
the ``audit_logs`` table. Uses raw SQL INSERT (not ORM) for minimal overhead.
The function is wrapped in a try/except so that a logging failure **never**
breaks the parent operation.
Phase 30: When details are non-empty, they are encrypted via OpenBao Transit
(per-tenant data key) and stored in encrypted_details. The plaintext details
column is set to '{}' for column compatibility. If Transit encryption fails
(e.g., OpenBao unavailable), details are stored in plaintext as a fallback.
"""
import uuid
from typing import Any, Optional
import structlog
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
logger = structlog.get_logger("audit")
async def log_action(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    user_id: uuid.UUID,
    action: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
    device_id: Optional[uuid.UUID] = None,
    details: Optional[dict[str, Any]] = None,
    ip_address: Optional[str] = None,
) -> None:
    """Insert a row into audit_logs. Swallows all exceptions on failure.

    When ``details`` is non-empty, OpenBao Transit encryption is attempted
    with the tenant's key: on success the ciphertext goes to
    ``encrypted_details`` and the plaintext ``details`` column receives
    '{}'; on failure the details are stored in plaintext and
    ``encrypted_details`` stays NULL. Either way this function never
    raises — a logging failure must not break the parent operation.

    Note: no commit happens here; durability is owned by the caller's
    transaction on ``db``.

    Args:
        db: Session the INSERT is executed on (caller commits).
        tenant_id: Tenant UUID recorded on the row.
        user_id: Acting user's UUID.
        action: Short action identifier string.
        resource_type: Optional type of the resource acted upon.
        resource_id: Optional identifier of that resource.
        device_id: Optional device UUID involved in the action.
        details: Optional structured context; must be JSON-serializable.
        ip_address: Optional client IP address.
    """
    try:
        import json as _json
        details_dict = details or {}
        details_json = _json.dumps(details_dict)
        encrypted_details: Optional[str] = None
        # Attempt Transit encryption for non-empty details
        if details_dict:
            try:
                from app.services.crypto import encrypt_data_transit
                encrypted_details = await encrypt_data_transit(
                    details_json, str(tenant_id)
                )
                # Encryption succeeded — clear plaintext details
                details_json = _json.dumps({})
            except Exception:
                # Transit unavailable — fall back to plaintext details
                logger.warning(
                    "audit_transit_encryption_failed",
                    action=action,
                    tenant_id=str(tenant_id),
                    exc_info=True,
                )
                # Keep details_json as-is (plaintext fallback)
                encrypted_details = None
        # Raw SQL INSERT for minimal overhead; CAST keeps the jsonb column
        # happy regardless of which path populated details_json.
        await db.execute(
            text(
                "INSERT INTO audit_logs "
                "(tenant_id, user_id, action, resource_type, resource_id, "
                "device_id, details, encrypted_details, ip_address) "
                "VALUES (:tenant_id, :user_id, :action, :resource_type, "
                ":resource_id, :device_id, CAST(:details AS jsonb), "
                ":encrypted_details, :ip_address)"
            ),
            {
                "tenant_id": str(tenant_id),
                "user_id": str(user_id),
                "action": action,
                "resource_type": resource_type,
                "resource_id": resource_id,
                "device_id": str(device_id) if device_id else None,
                "details": details_json,
                "encrypted_details": encrypted_details,
                "ip_address": ip_address,
            },
        )
    except Exception:
        # Fire-and-forget contract: log the failure, never propagate.
        logger.warning(
            "audit_log_insert_failed",
            action=action,
            tenant_id=str(tenant_id),
            exc_info=True,
        )

View File

@@ -0,0 +1,154 @@
"""
JWT authentication service.
Handles password hashing, JWT token creation, token verification,
and token revocation via Redis.
"""
import time
import uuid
from datetime import UTC, datetime, timedelta
from typing import Optional
import bcrypt
from fastapi import HTTPException, status
from jose import JWTError, jwt
from redis.asyncio import Redis
from app.config import settings
TOKEN_REVOCATION_PREFIX = "token_revoked:"
def hash_password(password: str) -> str:
    """Return a bcrypt hash of *password* (fresh salt per call).

    DEPRECATED: Used only by password reset (temporary bcrypt hash for
    upgrade flow) and bootstrap_first_admin. Remove post-v6.0.
    """
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password.encode(), salt)
    return digest.decode()
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a stored bcrypt hash.

    DEPRECATED: Used only by the one-time SRP upgrade flow (login with
    must_upgrade_auth=True) and anti-enumeration dummy calls. Remove post-v6.0.
    """
    candidate = plain_password.encode()
    stored = hashed_password.encode()
    return bcrypt.checkpw(candidate, stored)
def create_access_token(
    user_id: uuid.UUID,
    tenant_id: Optional[uuid.UUID],
    role: str,
) -> str:
    """Issue a short-lived access JWT for *user_id*.

    The payload carries:
        sub       — user UUID (subject)
        tenant_id — tenant UUID string, or None for super_admin
        role      — the user's role string
        type      — "access", so refresh tokens can't be replayed here
        iat / exp — issue and expiry timestamps
    """
    issued_at = datetime.now(UTC)
    expires_at = issued_at + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {
        "sub": str(user_id),
        "tenant_id": str(tenant_id) if tenant_id else None,
        "role": role,
        "type": "access",
        "iat": issued_at,
        "exp": expires_at,
    }
    return jwt.encode(claims, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
def create_refresh_token(user_id: uuid.UUID) -> str:
    """Issue a long-lived refresh JWT for *user_id*.

    Carries only sub, type="refresh", iat, and exp; the lifetime comes
    from settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS.
    """
    issued_at = datetime.now(UTC)
    claims = {
        "sub": str(user_id),
        "type": "refresh",
        "iat": issued_at,
        "exp": issued_at + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS),
    }
    return jwt.encode(claims, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
def verify_token(token: str, expected_type: str = "access") -> dict:
    """Decode a JWT and enforce its type and subject claims.

    Args:
        token: Raw JWT string to validate.
        expected_type: "access" or "refresh".

    Returns:
        The decoded payload dict (sub, tenant_id, role, type, exp, iat).

    Raises:
        HTTPException 401: bad signature, expired, wrong token type,
            or missing subject claim.
    """
    invalid_token = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(
            token,
            settings.JWT_SECRET_KEY,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except JWTError:
        raise invalid_token
    # Refresh tokens must never pass where an access token is expected
    # (and vice versa), and a token with no subject is unusable.
    if claims.get("type") != expected_type or not claims.get("sub"):
        raise invalid_token
    return claims
async def revoke_user_tokens(redis: Redis, user_id: str) -> None:
    """Mark all tokens for a user as revoked by storing the current timestamp.

    Any token issued before this timestamp will be rejected by
    ``is_token_revoked``. The key's TTL matches the maximum refresh token
    lifetime so the revocation marker outlives every token it invalidates.

    Args:
        redis: Async Redis client.
        user_id: User UUID as a string (used in the Redis key).
    """
    key = f"{TOKEN_REVOCATION_PREFIX}{user_id}"
    # Derive the TTL from settings so it stays in lockstep with
    # create_refresh_token(). The previous hard-coded 7-day TTL would let
    # the marker expire before still-valid tokens if the configured
    # refresh lifetime were ever raised above 7 days.
    ttl_seconds = settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 3600
    await redis.set(key, str(time.time()), ex=ttl_seconds)
async def is_token_revoked(redis: Redis, user_id: str, issued_at: float) -> bool:
    """Return True when the token predates the user's revocation marker."""
    marker = await redis.get(f"{TOKEN_REVOCATION_PREFIX}{user_id}")
    if marker is None:
        # No revocation ever recorded for this user.
        return False
    return issued_at < float(marker)

View File

@@ -0,0 +1,197 @@
"""Dynamic backup scheduler — reads cron schedules from DB, manages APScheduler jobs."""
import logging
from typing import Optional
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from app.database import AdminAsyncSessionLocal
from app.models.config_backup import ConfigBackupSchedule
from app.models.device import Device
from app.services import backup_service
from sqlalchemy import select
logger = logging.getLogger(__name__)
_scheduler: Optional[AsyncIOScheduler] = None
# System default: 2am UTC daily
DEFAULT_CRON = "0 2 * * *"
def _cron_to_trigger(cron_expr: str) -> Optional[CronTrigger]:
    """Build a UTC CronTrigger from a 5-field cron string.

    Returns None when the expression has the wrong field count or
    APScheduler rejects it.
    """
    try:
        fields = cron_expr.strip().split()
        if len(fields) != 5:
            return None
        return CronTrigger(
            minute=fields[0],
            hour=fields[1],
            day=fields[2],
            month=fields[3],
            day_of_week=fields[4],
            timezone="UTC",
        )
    except Exception as e:
        logger.warning("Invalid cron expression '%s': %s", cron_expr, e)
        return None
def build_schedule_map(schedules: list) -> dict[str, list[dict]]:
    """Group enabled device schedules by cron expression.

    Args:
        schedules: Objects exposing .enabled, .cron_expression, .device_id
            and .tenant_id (ORM rows or SimpleNamespace).

    Returns:
        Mapping of {cron_expression: [{"device_id", "tenant_id"}, ...]}.
        Disabled schedules are skipped; a missing/empty cron_expression
        falls back to DEFAULT_CRON.
    """
    schedule_map: dict[str, list[dict]] = {}
    for sched in schedules:
        if not sched.enabled:
            continue
        cron = sched.cron_expression or DEFAULT_CRON
        # setdefault replaces the manual membership-check-then-append.
        schedule_map.setdefault(cron, []).append({
            "device_id": str(sched.device_id),
            "tenant_id": str(sched.tenant_id),
        })
    return schedule_map
async def _run_scheduled_backups(devices: list[dict]) -> None:
    """Execute one backup per device; a single failure never aborts the batch."""
    ok = 0
    failed = 0
    for entry in devices:
        dev_id = entry["device_id"]
        try:
            # Fresh admin session per device so one failure can't poison
            # the session used by the next backup.
            async with AdminAsyncSessionLocal() as session:
                await backup_service.run_backup(
                    device_id=dev_id,
                    tenant_id=entry["tenant_id"],
                    trigger_type="scheduled",
                    db_session=session,
                )
                await session.commit()
            logger.info("Scheduled backup OK: device %s", dev_id)
            ok += 1
        except Exception as e:
            logger.error(
                "Scheduled backup FAILED: device %s: %s",
                dev_id, e,
            )
            failed += 1
    logger.info(
        "Backup batch complete — %d succeeded, %d failed",
        ok, failed,
    )
async def _load_effective_schedules() -> list:
    """Resolve the effective backup schedule for every device.

    Precedence per device: device-specific schedule, then its tenant's
    default, then the system DEFAULT_CRON (treated as enabled).

    Returns:
        Flat list of objects carrying device_id, tenant_id,
        cron_expression, and enabled attributes.
    """
    from types import SimpleNamespace
    async with AdminAsyncSessionLocal() as session:
        devices = (await session.execute(select(Device))).scalars().all()
        schedules = (await session.execute(select(ConfigBackupSchedule))).scalars().all()
        # Split rows into per-device overrides and tenant-wide defaults.
        per_device: dict = {}   # device_id -> schedule
        per_tenant: dict = {}   # tenant_id -> schedule
        for sched in schedules:
            if sched.device_id:
                per_device[str(sched.device_id)] = sched
            else:
                per_tenant[str(sched.tenant_id)] = sched
        resolved = []
        for dev in devices:
            dev_id = str(dev.id)
            tenant_id = str(dev.tenant_id)
            match = per_device.get(dev_id)
            if match is None:
                match = per_tenant.get(tenant_id)
            # match may still be None — fall back to the system default.
            resolved.append(SimpleNamespace(
                device_id=dev_id,
                tenant_id=tenant_id,
                cron_expression=match.cron_expression if match else DEFAULT_CRON,
                enabled=match.enabled if match else True,
            ))
        return resolved
async def sync_schedules() -> None:
    """Rebuild all APScheduler backup jobs from the current DB state."""
    global _scheduler
    if not _scheduler:
        return
    # Drop only backup jobs; unrelated jobs (e.g. firmware check) survive.
    for existing_job in _scheduler.get_jobs():
        if existing_job.id.startswith("backup_cron_"):
            existing_job.remove()
    grouped = build_schedule_map(await _load_effective_schedules())
    for cron_expr, device_batch in grouped.items():
        trigger = _cron_to_trigger(cron_expr)
        if not trigger:
            logger.warning("Skipping invalid cron '%s', using default", cron_expr)
            trigger = _cron_to_trigger(DEFAULT_CRON)
        _scheduler.add_job(
            _run_scheduled_backups,
            trigger=trigger,
            args=[device_batch],
            id=f"backup_cron_{cron_expr.replace(' ', '_')}",
            name=f"Backup: {cron_expr} ({len(device_batch)} devices)",
            max_instances=1,
            replace_existing=True,
        )
        logger.info("Scheduled %d devices with cron '%s'", len(device_batch), cron_expr)
async def on_schedule_change(tenant_id: str, device_id: str) -> None:
    """Hot-reload every schedule after a create/update through the API."""
    logger.info(
        "Schedule changed for tenant=%s device=%s, resyncing", tenant_id, device_id
    )
    await sync_schedules()
async def start_backup_scheduler() -> None:
    """Create the global AsyncIOScheduler and populate it from the DB."""
    global _scheduler
    _scheduler = AsyncIOScheduler(timezone="UTC")
    _scheduler.start()
    # Initial load of every device/tenant schedule.
    await sync_schedules()
    logger.info("Backup scheduler started with dynamic schedules")
async def stop_backup_scheduler() -> None:
    """Shut the scheduler down (no wait) and clear the module global."""
    global _scheduler
    if not _scheduler:
        return
    _scheduler.shutdown(wait=False)
    _scheduler = None
    logger.info("Backup scheduler stopped")

View File

@@ -0,0 +1,378 @@
"""SSH-based config capture service for RouterOS devices.
This service handles:
1. capture_export() — SSH to device, run /export compact, return stdout text
2. capture_binary_backup() — SSH to device, trigger /system backup save, SFTP-download result
3. run_backup() — Orchestrate a full backup: capture + git commit + DB record
All functions are async (asyncssh is asyncio-native).
Security policy:
known_hosts=None is intentional — RouterOS devices use self-signed SSH host keys
that change on reset or key regeneration. This mirrors InsecureSkipVerify=true
used in the poller's TLS connection. The threat model accepts device impersonation
risk in exchange for operational simplicity (no pre-enrollment of host keys needed).
See Pitfall 2 in 04-RESEARCH.md.
pygit2 calls are synchronous C bindings and MUST be wrapped in run_in_executor.
See Pitfall 3 in 04-RESEARCH.md.
Phase 30: ALL backups (manual, scheduled, pre-restore) are encrypted via OpenBao
Transit (Tier 2) before git commit. The server retains decrypt capability for
on-demand viewing. Raw files in git are ciphertext; the API decrypts on GET.
"""
import asyncio
import base64
import io
import json
import logging
from datetime import datetime, timezone
import asyncssh
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import AdminAsyncSessionLocal, set_tenant_context
from app.models.config_backup import ConfigBackupRun
from app.models.device import Device
from app.services import git_store
from app.services.crypto import decrypt_credentials_hybrid
logger = logging.getLogger(__name__)
# Fixed backup file name on device flash — overwrites on each run so files
# don't accumulate. See Pitfall 4 in 04-RESEARCH.md.
_BACKUP_NAME = "portal-backup"
async def capture_export(
    ip: str,
    port: int = 22,
    username: str = "",
    password: str = "",
) -> str:
    """Run ``/export compact`` on a RouterOS device over SSH.

    Args:
        ip: Device IP address.
        port: SSH port (RouterOS default is 22).
        username: SSH login username.
        password: SSH login password.

    Returns:
        Raw RSC text from the export (may include the RouterOS header line).

    Raises:
        asyncssh.Error: On SSH connection or command execution failure.
    """
    ssh_options = dict(
        port=port,
        username=username,
        password=password,
        known_hosts=None,  # RouterOS self-signed host keys — see module docstring
        connect_timeout=30,
    )
    async with asyncssh.connect(ip, **ssh_options) as ssh:
        export_result = await ssh.run("/export compact", check=True)
        return export_result.stdout
async def capture_binary_backup(
    ip: str,
    port: int = 22,
    username: str = "",
    password: str = "",
) -> bytes:
    """Create, download, and clean up a RouterOS binary backup over SSH/SFTP.

    The fixed on-device name ({_BACKUP_NAME}.backup) is reused every run so
    backup files never accumulate on flash. Cleanup happens in a finally
    block: a cleanup failure is logged but never masks the real backup
    error. See Pitfall 4 in 04-RESEARCH.md.

    Args:
        ip: Device IP address.
        port: SSH port (default 22).
        username: SSH login username.
        password: SSH login password.

    Returns:
        Raw bytes of the downloaded .backup file.

    Raises:
        asyncssh.Error: On SSH connection, command, or SFTP failure.
    """
    remote_file = f"{_BACKUP_NAME}.backup"
    async with asyncssh.connect(
        ip,
        port=port,
        username=username,
        password=password,
        known_hosts=None,
        connect_timeout=30,
    ) as ssh:
        # Create the backup file on device flash first.
        await ssh.run(
            f"/system backup save name={_BACKUP_NAME} dont-encrypt=yes",
            check=True,
        )
        payload = io.BytesIO()
        try:
            # Pull the file down via SFTP.
            async with ssh.start_sftp_client() as sftp:
                async with sftp.open(remote_file, "rb") as handle:
                    payload.write(await handle.read())
        finally:
            # Best-effort removal of the file from device flash.
            try:
                await ssh.run(f"/file remove {remote_file}", check=True)
            except Exception as cleanup_err:
                logger.warning(
                    "Failed to remove backup file from device %s: %s",
                    ip,
                    cleanup_err,
                )
        return payload.getvalue()
async def run_backup(
    device_id: str,
    tenant_id: str,
    trigger_type: str,
    db_session: AsyncSession | None = None,
) -> dict:
    """Orchestrate a full config backup for a device.

    Steps:
    1. Load device from DB (ip_address, encrypted_credentials).
    2. Decrypt credentials using crypto.decrypt_credentials().
    3. Capture /export compact and binary backup concurrently via asyncio.gather().
    4. Compute line delta vs the most recent export.rsc in git (None for first backup).
    5. Commit both files to the tenant's bare git repo (run_in_executor for pygit2).
    6. Insert ConfigBackupRun record with commit SHA, trigger type, line deltas.
    7. Return summary dict.

    Args:
        device_id: Device UUID as string.
        tenant_id: Tenant UUID as string.
        trigger_type: 'scheduled' | 'manual' | 'pre-restore'
        db_session: Optional AsyncSession with RLS context already set.
            If None, uses AdminAsyncSessionLocal (for scheduler context).

    Returns:
        Dict: {"commit_sha": str, "trigger_type": str, "lines_added": int|None, "lines_removed": int|None}

    Raises:
        ValueError: If device not found or missing credentials.
        asyncssh.Error: On SSH/SFTP failure.
    """
    # get_running_loop(): this is always called from within a coroutine,
    # and get_event_loop() is deprecated for this use since Python 3.10.
    loop = asyncio.get_running_loop()
    ts = datetime.now(timezone.utc).isoformat()
    # -----------------------------------------------------------------------
    # Step 1: Load device from DB
    # -----------------------------------------------------------------------
    if db_session is not None:
        session = db_session
        should_close = False
    else:
        # Scheduler context: use admin session (cross-tenant; RLS bypassed)
        session = AdminAsyncSessionLocal()
        should_close = True
    try:
        from sqlalchemy import select
        if should_close:
            # Admin session doesn't have RLS context — query directly.
            result = await session.execute(
                select(Device).where(
                    Device.id == device_id,  # type: ignore[arg-type]
                    Device.tenant_id == tenant_id,  # type: ignore[arg-type]
                )
            )
        else:
            result = await session.execute(
                select(Device).where(Device.id == device_id)  # type: ignore[arg-type]
            )
        device = result.scalar_one_or_none()
        if device is None:
            raise ValueError(f"Device {device_id!r} not found for tenant {tenant_id!r}")
        if not device.encrypted_credentials_transit and not device.encrypted_credentials:
            raise ValueError(
                f"Device {device_id!r} has no stored credentials — cannot perform backup"
            )
        # -----------------------------------------------------------------------
        # Step 2: Decrypt credentials (dual-read: Transit preferred, legacy fallback)
        # -----------------------------------------------------------------------
        key = settings.get_encryption_key_bytes()
        creds_json = await decrypt_credentials_hybrid(
            device.encrypted_credentials_transit,
            device.encrypted_credentials,
            str(device.tenant_id),
            key,
        )
        creds = json.loads(creds_json)
        ssh_username = creds.get("username", "")
        ssh_password = creds.get("password", "")
        ip = device.ip_address
        hostname = device.hostname or ip
        # -----------------------------------------------------------------------
        # Step 3: Capture export and binary backup concurrently
        # -----------------------------------------------------------------------
        logger.info(
            "Starting %s backup for device %s (%s) tenant %s",
            trigger_type,
            hostname,
            ip,
            tenant_id,
        )
        export_text, binary_backup = await asyncio.gather(
            capture_export(ip, username=ssh_username, password=ssh_password),
            capture_binary_backup(ip, username=ssh_username, password=ssh_password),
        )
        # -----------------------------------------------------------------------
        # Step 4: Compute line delta vs prior version
        # -----------------------------------------------------------------------
        lines_added: int | None = None
        lines_removed: int | None = None
        prior_commits = await loop.run_in_executor(
            None, git_store.list_device_commits, tenant_id, device_id
        )
        if prior_commits:
            try:
                prior_export_bytes = await loop.run_in_executor(
                    None, git_store.read_file, tenant_id, prior_commits[0]["sha"], device_id, "export.rsc"
                )
                prior_text = prior_export_bytes.decode("utf-8", errors="replace")
                lines_added, lines_removed = await loop.run_in_executor(
                    None, git_store.compute_line_delta, prior_text, export_text
                )
            except Exception as delta_err:
                logger.warning(
                    "Failed to compute line delta for device %s: %s",
                    device_id,
                    delta_err,
                )
                # Keep lines_added/lines_removed as None on error — non-fatal
        else:
            # First backup: all lines are "added", none removed
            all_lines = len(export_text.splitlines())
            lines_added = all_lines
            lines_removed = 0
        # -----------------------------------------------------------------------
        # Step 5: Encrypt ALL backups via Transit (Tier 2: OpenBao Transit)
        # -----------------------------------------------------------------------
        encryption_tier: int | None = None
        git_export_content = export_text
        git_binary_content = binary_backup
        try:
            from app.services.crypto import encrypt_data_transit
            encrypted_export = await encrypt_data_transit(
                export_text, tenant_id
            )
            encrypted_binary = await encrypt_data_transit(
                base64.b64encode(binary_backup).decode(), tenant_id
            )
            # Transit ciphertext is text — store directly in git
            git_export_content = encrypted_export
            git_binary_content = encrypted_binary.encode("utf-8")
            encryption_tier = 2
            logger.info(
                "Tier 2 Transit encryption applied for %s backup of device %s",
                trigger_type,
                device_id,
            )
        except Exception as enc_err:
            # Transit unavailable — fall back to plaintext (non-fatal)
            logger.warning(
                "Transit encryption failed for %s backup of device %s, "
                "storing plaintext: %s",
                trigger_type,
                device_id,
                enc_err,
            )
            # Keep encryption_tier = None (plaintext fallback)
        # -----------------------------------------------------------------------
        # Step 6: Commit to git (wrapped in run_in_executor — pygit2 is sync C bindings)
        # -----------------------------------------------------------------------
        commit_message = (
            f"{trigger_type}: {hostname} ({ip}) at {ts}"
        )
        commit_sha = await loop.run_in_executor(
            None,
            git_store.commit_backup,
            tenant_id,
            device_id,
            git_export_content,
            git_binary_content,
            commit_message,
        )
        logger.info(
            "Committed backup for device %s to git SHA %s (tier=%s)",
            device_id,
            commit_sha[:8],
            encryption_tier,
        )
        # -----------------------------------------------------------------------
        # Step 7: Insert ConfigBackupRun record
        # -----------------------------------------------------------------------
        if not should_close:
            # RLS-scoped session from API context — record directly
            backup_run = ConfigBackupRun(
                device_id=device.id,
                tenant_id=device.tenant_id,
                commit_sha=commit_sha,
                trigger_type=trigger_type,
                lines_added=lines_added,
                lines_removed=lines_removed,
                encryption_tier=encryption_tier,
            )
            session.add(backup_run)
            await session.flush()
        else:
            # Admin session — set tenant context before insert so RLS policy is satisfied
            async with AdminAsyncSessionLocal() as admin_session:
                await set_tenant_context(admin_session, str(device.tenant_id))
                backup_run = ConfigBackupRun(
                    device_id=device.id,
                    tenant_id=device.tenant_id,
                    commit_sha=commit_sha,
                    trigger_type=trigger_type,
                    lines_added=lines_added,
                    lines_removed=lines_removed,
                    encryption_tier=encryption_tier,
                )
                admin_session.add(backup_run)
                await admin_session.commit()
        return {
            "commit_sha": commit_sha,
            "trigger_type": trigger_type,
            "lines_added": lines_added,
            "lines_removed": lines_removed,
        }
    finally:
        if should_close:
            await session.close()

View File

@@ -0,0 +1,462 @@
"""Certificate Authority service — CA generation, device cert signing, lifecycle.
This module provides the core PKI functionality for the Internal Certificate
Authority feature. All functions receive an ``AsyncSession`` and an
``encryption_key`` as parameters (no direct Settings access) for testability.
Security notes:
- CA private keys are encrypted with AES-256-GCM before database storage.
- PEM key material is NEVER logged.
- Device keys are decrypted only when needed for NATS transmission.
"""
from __future__ import annotations
import datetime
import ipaddress
import logging
from uuid import UUID
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.certificate import CertificateAuthority, DeviceCertificate
from app.services.crypto import (
decrypt_credentials_hybrid,
encrypt_credentials_transit,
)
logger = logging.getLogger(__name__)
# Valid status transitions for the device certificate lifecycle.
# Keys are current states; values are the set of states reachable from
# each. 'revoked' and 'superseded' are terminal (no outgoing transitions).
_VALID_TRANSITIONS: dict[str, set[str]] = {
    "issued": {"deploying"},
    "deploying": {"deployed", "issued"},  # issued = rollback on deploy failure
    "deployed": {"expiring", "revoked", "superseded"},
    "expiring": {"expired", "revoked", "superseded"},
    "expired": {"superseded"},
    "revoked": set(),
    "superseded": set(),
}
# ---------------------------------------------------------------------------
# CA Generation
# ---------------------------------------------------------------------------
async def generate_ca(
    db: AsyncSession,
    tenant_id: UUID,
    common_name: str,
    validity_years: int,
    encryption_key: bytes,
) -> CertificateAuthority:
    """Generate a self-signed root CA for a tenant.

    The RSA-2048 private key is PEM-serialized and encrypted via OpenBao
    Transit (per-tenant key); the legacy AES column is stored empty. The
    new row is flushed but not committed — the caller owns the transaction.

    Args:
        db: Async database session.
        tenant_id: Tenant UUID — only one CA per tenant.
        common_name: CN for the CA certificate (e.g., "Portal Root CA").
        validity_years: How many years the CA cert is valid.
        encryption_key: Unused in this function — the private key is
            encrypted via OpenBao Transit, not AES-256-GCM. Kept for
            interface compatibility with existing callers.

    Returns:
        The newly created ``CertificateAuthority`` model instance.

    Raises:
        ValueError: If the tenant already has a CA.
    """
    # Ensure one CA per tenant
    existing = await get_ca_for_tenant(db, tenant_id)
    if existing is not None:
        raise ValueError(
            f"Tenant {tenant_id} already has a CA (id={existing.id}). "
            "Delete the existing CA before creating a new one."
        )
    # Generate RSA 2048 key pair
    ca_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    now = datetime.datetime.now(datetime.timezone.utc)
    expiry = now + datetime.timedelta(days=365 * validity_years)
    subject = issuer = x509.Name([
        x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
        x509.NameAttribute(NameOID.COMMON_NAME, common_name),
    ])
    # path_length=0: this CA may only sign end-entity (device) certs,
    # never intermediate CAs.
    ca_cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(ca_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(expiry)
        .add_extension(
            x509.BasicConstraints(ca=True, path_length=0), critical=True
        )
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        .add_extension(
            x509.SubjectKeyIdentifier.from_public_key(ca_key.public_key()),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )
    # Serialize public cert to PEM
    cert_pem = ca_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
    # Serialize private key to PEM, then encrypt with OpenBao Transit
    key_pem = ca_key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.PKCS8,
        serialization.NoEncryption(),
    ).decode("utf-8")
    encrypted_key_transit = await encrypt_credentials_transit(key_pem, str(tenant_id))
    # Compute SHA-256 fingerprint (colon-separated hex)
    fingerprint_bytes = ca_cert.fingerprint(hashes.SHA256())
    fingerprint = ":".join(f"{b:02X}" for b in fingerprint_bytes)
    # Serial number as hex string
    serial_hex = format(ca_cert.serial_number, "X")
    model = CertificateAuthority(
        tenant_id=tenant_id,
        common_name=common_name,
        cert_pem=cert_pem,
        encrypted_private_key=b"",  # Legacy column kept for schema compat
        encrypted_private_key_transit=encrypted_key_transit,
        serial_number=serial_hex,
        fingerprint_sha256=fingerprint,
        not_valid_before=now,
        not_valid_after=expiry,
    )
    db.add(model)
    await db.flush()
    # Key material (key_pem) is intentionally never logged.
    logger.info(
        "Generated CA for tenant %s: cn=%s fingerprint=%s",
        tenant_id,
        common_name,
        fingerprint,
    )
    return model
# ---------------------------------------------------------------------------
# Device Certificate Signing
# ---------------------------------------------------------------------------
async def sign_device_cert(
    db: AsyncSession,
    ca: CertificateAuthority,
    device_id: UUID,
    hostname: str,
    ip_address: str,
    validity_days: int,
    encryption_key: bytes,
) -> DeviceCertificate:
    """Sign a per-device TLS certificate using the tenant's CA.

    The new RSA-2048 device key is encrypted via OpenBao Transit; the
    legacy AES column is stored empty. The row is flushed but not
    committed — the caller owns the transaction.

    Args:
        db: Async database session.
        ca: The tenant's CertificateAuthority model instance.
        device_id: UUID of the device receiving the cert.
        hostname: Device hostname — used as CN and SAN DNSName.
        ip_address: Device IP — used as SAN IPAddress; must be a valid
            IP literal (``ipaddress.ip_address`` raises ValueError otherwise).
        validity_days: Certificate validity in days.
        encryption_key: Legacy AES-256-GCM key. Used only as the fallback
            key when decrypting the CA private key via
            decrypt_credentials_hybrid(); the device private key itself is
            encrypted with OpenBao Transit, not this key.

    Returns:
        The newly created ``DeviceCertificate`` model instance (status='issued').
    """
    # Decrypt CA private key (dual-read: Transit preferred, legacy fallback)
    ca_key_pem = await decrypt_credentials_hybrid(
        ca.encrypted_private_key_transit,
        ca.encrypted_private_key,
        str(ca.tenant_id),
        encryption_key,
    )
    ca_key = serialization.load_pem_private_key(
        ca_key_pem.encode("utf-8"), password=None
    )
    # Load CA certificate for issuer info and AuthorityKeyIdentifier
    ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8"))
    # Generate device RSA 2048 key
    device_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    now = datetime.datetime.now(datetime.timezone.utc)
    expiry = now + datetime.timedelta(days=validity_days)
    # End-entity cert: ca=False, SERVER_AUTH EKU, SAN carries both the IP
    # and the hostname so TLS clients can verify by either.
    device_cert = (
        x509.CertificateBuilder()
        .subject_name(
            x509.Name([
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
                x509.NameAttribute(NameOID.COMMON_NAME, hostname),
            ])
        )
        .issuer_name(ca_cert.subject)
        .public_key(device_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(expiry)
        .add_extension(
            x509.BasicConstraints(ca=False, path_length=None), critical=True
        )
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=True,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=False,
                crl_sign=False,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        .add_extension(
            x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]),
            critical=False,
        )
        .add_extension(
            x509.SubjectAlternativeName([
                x509.IPAddress(ipaddress.ip_address(ip_address)),
                x509.DNSName(hostname),
            ]),
            critical=False,
        )
        .add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                ca_cert.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier
                ).value
            ),
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )
    # Serialize device cert and key to PEM
    cert_pem = device_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")
    key_pem = device_key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.PKCS8,
        serialization.NoEncryption(),
    ).decode("utf-8")
    # Encrypt device private key via OpenBao Transit
    encrypted_key_transit = await encrypt_credentials_transit(key_pem, str(ca.tenant_id))
    # Compute fingerprint
    fingerprint_bytes = device_cert.fingerprint(hashes.SHA256())
    fingerprint = ":".join(f"{b:02X}" for b in fingerprint_bytes)
    serial_hex = format(device_cert.serial_number, "X")
    model = DeviceCertificate(
        tenant_id=ca.tenant_id,
        device_id=device_id,
        ca_id=ca.id,
        common_name=hostname,
        serial_number=serial_hex,
        fingerprint_sha256=fingerprint,
        cert_pem=cert_pem,
        encrypted_private_key=b"",  # Legacy column kept for schema compat
        encrypted_private_key_transit=encrypted_key_transit,
        not_valid_before=now,
        not_valid_after=expiry,
        status="issued",
    )
    db.add(model)
    await db.flush()
    # Key material (key_pem / ca_key_pem) is intentionally never logged.
    logger.info(
        "Signed device cert for device %s: cn=%s fingerprint=%s",
        device_id,
        hostname,
        fingerprint,
    )
    return model
# ---------------------------------------------------------------------------
# Queries
# ---------------------------------------------------------------------------
async def get_ca_for_tenant(
    db: AsyncSession,
    tenant_id: UUID,
) -> CertificateAuthority | None:
    """Look up the certificate authority owned by *tenant_id*.

    Returns None when the tenant has not initialized a CA yet.
    """
    row = await db.execute(
        select(CertificateAuthority).where(CertificateAuthority.tenant_id == tenant_id)
    )
    return row.scalar_one_or_none()
async def get_device_certs(
    db: AsyncSession,
    tenant_id: UUID,
    device_id: UUID | None = None,
) -> list[DeviceCertificate]:
    """List device certificates for a tenant, newest first.

    Superseded certificates are always excluded.

    Args:
        db: Async database session.
        tenant_id: Tenant UUID.
        device_id: If provided, restrict results to this device's certs.

    Returns:
        List of DeviceCertificate models.
    """
    conditions = [
        DeviceCertificate.tenant_id == tenant_id,
        DeviceCertificate.status != "superseded",
    ]
    if device_id is not None:
        conditions.append(DeviceCertificate.device_id == device_id)
    query = (
        select(DeviceCertificate)
        .where(*conditions)
        .order_by(DeviceCertificate.created_at.desc())
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
# ---------------------------------------------------------------------------
# Status Management
# ---------------------------------------------------------------------------
async def update_cert_status(
    db: AsyncSession,
    cert_id: UUID,
    status: str,
    deployed_at: datetime.datetime | None = None,
) -> DeviceCertificate:
    """Transition a device certificate to a new lifecycle status.

    The transition must be permitted by the state machine:
        issued -> deploying -> deployed -> expiring -> expired
              \\-> revoked
              \\-> superseded

    Args:
        db: Async database session.
        cert_id: Certificate UUID.
        status: Target status value.
        deployed_at: Explicit timestamp recorded when moving to 'deployed'.

    Returns:
        The updated DeviceCertificate model.

    Raises:
        ValueError: If the certificate is not found or the transition is invalid.
    """
    row = await db.execute(
        select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
    )
    cert = row.scalar_one_or_none()
    if cert is None:
        raise ValueError(f"Device certificate {cert_id} not found")

    permitted = _VALID_TRANSITIONS.get(cert.status, set())
    if status not in permitted:
        raise ValueError(
            f"Invalid status transition: {cert.status} -> {status}. "
            f"Allowed transitions from '{cert.status}': {permitted or 'none'}"
        )

    now = datetime.datetime.now(datetime.timezone.utc)
    cert.status = status
    cert.updated_at = now
    if status == "deployed":
        # Prefer the caller-supplied deployment timestamp; fall back to now.
        cert.deployed_at = deployed_at if deployed_at is not None else now
    await db.flush()
    logger.info(
        "Updated cert %s status to %s",
        cert_id,
        status,
    )
    return cert
# ---------------------------------------------------------------------------
# Cert Data for Deployment
# ---------------------------------------------------------------------------
async def get_cert_for_deploy(
    db: AsyncSession,
    cert_id: UUID,
    encryption_key: bytes,
) -> tuple[str, str, str]:
    """Fetch everything needed to push a certificate to a device via NATS.

    Decrypts the device private key (Transit preferred, legacy AES-256-GCM
    fallback) and resolves the issuing CA's certificate PEM.

    Args:
        db: Async database session.
        cert_id: Device certificate UUID.
        encryption_key: 32-byte AES-256-GCM key for the legacy fallback path.

    Returns:
        Tuple of (cert_pem, key_pem_decrypted, ca_cert_pem).

    Raises:
        ValueError: If the certificate or its CA is not found.
    """
    cert_row = await db.execute(
        select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
    )
    cert = cert_row.scalar_one_or_none()
    if cert is None:
        raise ValueError(f"Device certificate {cert_id} not found")

    # Resolve the issuing CA so the caller gets the chain's CA PEM too.
    ca_row = await db.execute(
        select(CertificateAuthority).where(CertificateAuthority.id == cert.ca_id)
    )
    ca = ca_row.scalar_one_or_none()
    if ca is None:
        raise ValueError(f"CA {cert.ca_id} not found for certificate {cert_id}")

    # Dual-read decrypt: Transit ciphertext preferred, legacy bytes as fallback.
    key_pem = await decrypt_credentials_hybrid(
        cert.encrypted_private_key_transit,
        cert.encrypted_private_key,
        str(cert.tenant_id),
        encryption_key,
    )
    return cert.cert_pem, key_pem, ca.cert_pem

View File

@@ -0,0 +1,118 @@
"""NATS subscriber for config change events from the Go poller.
Triggers automatic backups when out-of-band config changes are detected,
with 5-minute deduplication to prevent rapid-fire backups.
"""
import json
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Optional
from sqlalchemy import select
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.models.config_backup import ConfigBackupRun
from app.services import backup_service
logger = logging.getLogger(__name__)
DEDUP_WINDOW_MINUTES = 5
_nc: Optional[Any] = None
async def _last_backup_within_dedup_window(device_id: str) -> bool:
    """Return True when this device already has a backup newer than the window."""
    window_start = datetime.now(timezone.utc) - timedelta(minutes=DEDUP_WINDOW_MINUTES)
    async with AdminAsyncSessionLocal() as session:
        row = await session.execute(
            select(ConfigBackupRun)
            .where(
                ConfigBackupRun.device_id == device_id,
                ConfigBackupRun.created_at > window_start,
            )
            .limit(1)
        )
        return row.scalar_one_or_none() is not None
async def handle_config_changed(event: dict) -> None:
    """Handle a config change event from the Go poller.

    Triggers an automatic backup unless one already ran for this device
    within the dedup window (DEDUP_WINDOW_MINUTES). Backup failures are
    logged but never raised — this runs inside a NATS callback.

    Args:
        event: Decoded event payload; must contain "device_id" and
            "tenant_id", and may carry "old_timestamp"/"new_timestamp"
            for logging.
    """
    device_id = event.get("device_id")
    tenant_id = event.get("tenant_id")
    # Malformed events are dropped (logged), not retried.
    if not device_id or not tenant_id:
        logger.warning("Config change event missing device_id or tenant_id: %s", event)
        return
    # Dedup check: skip if a backup already exists inside the window.
    if await _last_backup_within_dedup_window(device_id):
        logger.info(
            "Config change on device %s — skipping backup (within %dm dedup window)",
            device_id, DEDUP_WINDOW_MINUTES,
        )
        return
    logger.info(
        "Config change detected on device %s (tenant %s): %s -> %s",
        device_id, tenant_id,
        event.get("old_timestamp", "?"),
        event.get("new_timestamp", "?"),
    )
    try:
        # Admin session: this runs outside any request context, so no RLS user.
        async with AdminAsyncSessionLocal() as session:
            await backup_service.run_backup(
                device_id=device_id,
                tenant_id=tenant_id,
                trigger_type="config-change",
                db_session=session,
            )
            await session.commit()
        logger.info("Config-change backup completed for device %s", device_id)
    except Exception as e:
        # Best-effort: a failed backup must not crash the subscriber loop.
        logger.error("Config-change backup failed for device %s: %s", device_id, e)
async def _on_message(msg) -> None:
    """JetStream callback for config.changed.> — decode, dispatch, ack/nak."""
    try:
        payload = json.loads(msg.data.decode())
        await handle_config_changed(payload)
        await msg.ack()
    except Exception as exc:
        logger.error("Error handling config change message: %s", exc)
        await msg.nak()
async def start_config_change_subscriber() -> Optional[Any]:
    """Connect to NATS JetStream and subscribe to config.changed.> events.

    Returns:
        The NATS client on success, or None if connecting/subscribing failed
        (the API keeps running without the subscriber).
    """
    import nats

    global _nc
    try:
        logger.info("NATS config-change: connecting to %s", settings.NATS_URL)
        _nc = await nats.connect(settings.NATS_URL)
        jetstream = _nc.jetstream()
        # Durable consumer so redeliveries survive API restarts.
        await jetstream.subscribe(
            "config.changed.>",
            cb=_on_message,
            durable="api-config-change-consumer",
            stream="DEVICE_EVENTS",
            manual_ack=True,
        )
        logger.info("Config change subscriber started")
        return _nc
    except Exception as exc:
        logger.error("Failed to start config change subscriber: %s", exc)
        return None
async def stop_config_change_subscriber() -> None:
    """Drain and release the NATS connection, if one was established."""
    global _nc
    if not _nc:
        return
    await _nc.drain()
    _nc = None

View File

@@ -0,0 +1,183 @@
"""
Credential encryption/decryption with dual-read (OpenBao Transit + legacy AES-256-GCM).
This module provides two encryption paths:
1. Legacy (sync): AES-256-GCM with static CREDENTIAL_ENCRYPTION_KEY — used for fallback reads.
2. Transit (async): OpenBao Transit per-tenant keys — used for all new writes.
The dual-read pattern:
- New writes always use OpenBao Transit (encrypt_credentials_transit).
- Reads prefer Transit ciphertext, falling back to legacy (decrypt_credentials_hybrid).
- Legacy functions are preserved for backward compatibility during migration.
Security properties:
- AES-256-GCM provides authenticated encryption (confidentiality + integrity)
- A unique 12-byte random nonce is generated per legacy encryption operation
- OpenBao Transit keys are AES-256-GCM96, managed entirely by OpenBao
- Ciphertext format: "vault:v1:..." for Transit, raw bytes for legacy
"""
import os
def encrypt_credentials(plaintext: str, key: bytes) -> bytes:
    """
    Encrypt a plaintext string using AES-256-GCM (legacy path).

    Args:
        plaintext: The credential string to encrypt (e.g., JSON with username/password)
        key: 32-byte encryption key

    Returns:
        bytes: nonce (12 bytes) + ciphertext + GCM tag (16 bytes)

    Raises:
        ValueError: If key is not exactly 32 bytes
    """
    if len(key) != 32:
        raise ValueError(f"Key must be exactly 32 bytes, got {len(key)}")

    from cryptography.hazmat.primitives.ciphers.aead import AESGCM

    # 96-bit nonce, freshly generated per call — never reused with the same key.
    nonce = os.urandom(12)
    sealed = AESGCM(key).encrypt(nonce, plaintext.encode("utf-8"), None)
    # The library appends the 16-byte GCM tag to the ciphertext.
    return nonce + sealed
def decrypt_credentials(ciphertext: bytes, key: bytes) -> str:
    """
    Decrypt AES-256-GCM encrypted credentials (legacy path).

    Args:
        ciphertext: bytes from encrypt_credentials (nonce + encrypted data + GCM tag)
        key: 32-byte encryption key (must match the key used for encryption)

    Returns:
        str: The original plaintext string

    Raises:
        ValueError: If key is not exactly 32 bytes
        cryptography.exceptions.InvalidTag: If authentication fails (tampered data or wrong key)
    """
    if len(key) != 32:
        raise ValueError(f"Key must be exactly 32 bytes, got {len(key)}")

    from cryptography.hazmat.primitives.ciphers.aead import AESGCM

    # Layout written by encrypt_credentials: 12-byte nonce, then payload + tag.
    nonce, payload = ciphertext[:12], ciphertext[12:]
    return AESGCM(key).decrypt(nonce, payload, None).decode("utf-8")
# ---------------------------------------------------------------------------
# OpenBao Transit functions (async, per-tenant keys)
# ---------------------------------------------------------------------------
async def encrypt_credentials_transit(plaintext: str, tenant_id: str) -> str:
    """Encrypt via OpenBao Transit with the tenant's credential key.

    Args:
        plaintext: The credential string to encrypt.
        tenant_id: Tenant UUID string for key lookup.

    Returns:
        Transit ciphertext string (vault:v1:base64...).
    """
    from app.services.openbao_service import get_openbao_service

    return await get_openbao_service().encrypt(tenant_id, plaintext.encode("utf-8"))
async def decrypt_credentials_transit(ciphertext: str, tenant_id: str) -> str:
    """Decrypt OpenBao Transit ciphertext with the tenant's credential key.

    Args:
        ciphertext: Transit ciphertext (vault:v1:...).
        tenant_id: Tenant UUID string for key lookup.

    Returns:
        Decrypted plaintext string.
    """
    from app.services.openbao_service import get_openbao_service

    raw = await get_openbao_service().decrypt(tenant_id, ciphertext)
    return raw.decode("utf-8")
# ---------------------------------------------------------------------------
# OpenBao Transit data encryption (async, per-tenant _data keys — Phase 30)
# ---------------------------------------------------------------------------
async def encrypt_data_transit(plaintext: str, tenant_id: str) -> str:
    """Encrypt non-credential data via OpenBao Transit (per-tenant data key).

    Used for audit log details, config backups, and reports. Data keys are
    kept separate from credential keys (tenant_{uuid}_data vs tenant_{uuid}).

    Args:
        plaintext: The data string to encrypt.
        tenant_id: Tenant UUID string for data key lookup.

    Returns:
        Transit ciphertext string (vault:v1:base64...).
    """
    from app.services.openbao_service import get_openbao_service

    return await get_openbao_service().encrypt_data(tenant_id, plaintext.encode("utf-8"))
async def decrypt_data_transit(ciphertext: str, tenant_id: str) -> str:
    """Decrypt OpenBao Transit data ciphertext (per-tenant data key).

    Args:
        ciphertext: Transit ciphertext (vault:v1:...).
        tenant_id: Tenant UUID string for data key lookup.

    Returns:
        Decrypted plaintext string.
    """
    from app.services.openbao_service import get_openbao_service

    raw = await get_openbao_service().decrypt_data(tenant_id, ciphertext)
    return raw.decode("utf-8")
async def decrypt_credentials_hybrid(
transit_ciphertext: str | None,
legacy_ciphertext: bytes | None,
tenant_id: str,
legacy_key: bytes,
) -> str:
"""Dual-read: prefer Transit ciphertext, fall back to legacy.
Args:
transit_ciphertext: OpenBao Transit ciphertext (vault:v1:...) or None.
legacy_ciphertext: Legacy AES-256-GCM bytes (nonce+ciphertext+tag) or None.
tenant_id: Tenant UUID string for Transit key lookup.
legacy_key: 32-byte legacy encryption key for fallback.
Returns:
Decrypted plaintext string.
Raises:
ValueError: If neither ciphertext is available.
"""
if transit_ciphertext and transit_ciphertext.startswith("vault:v"):
return await decrypt_credentials_transit(transit_ciphertext, tenant_id)
elif legacy_ciphertext:
return decrypt_credentials(legacy_ciphertext, legacy_key)
else:
raise ValueError("No credentials available (both transit and legacy are empty)")

View File

@@ -0,0 +1,670 @@
"""
Device service — business logic for device CRUD, credential encryption, groups, and tags.
All functions operate via the app_user engine (RLS enforced).
Tenant isolation is handled automatically by PostgreSQL RLS policies
(SET LOCAL app.current_tenant is set by the get_current_user dependency before
this layer is called).
Credential policy:
- Credentials are always stored as AES-256-GCM encrypted JSON blobs.
- Credentials are NEVER returned in any public-facing response.
- Re-encryption happens only when a new password is explicitly provided in an update.
"""
import asyncio
import json
import uuid
from typing import Optional
from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models.device import (
Device,
DeviceGroup,
DeviceGroupMembership,
DeviceTag,
DeviceTagAssignment,
)
from app.schemas.device import (
BulkAddRequest,
BulkAddResult,
DeviceCreate,
DeviceGroupCreate,
DeviceGroupResponse,
DeviceGroupUpdate,
DeviceResponse,
DeviceTagCreate,
DeviceTagResponse,
DeviceTagUpdate,
DeviceUpdate,
)
from app.config import settings
from app.services.crypto import (
decrypt_credentials,
decrypt_credentials_hybrid,
encrypt_credentials,
encrypt_credentials_transit,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _tcp_reachable(ip: str, port: int, timeout: float = 3.0) -> bool:
"""Return True if a TCP connection to ip:port succeeds within timeout."""
try:
_, writer = await asyncio.wait_for(
asyncio.open_connection(ip, port), timeout=timeout
)
writer.close()
try:
await writer.wait_closed()
except Exception:
pass
return True
except Exception:
return False
def _build_device_response(device: Device) -> DeviceResponse:
    """
    Map an ORM Device (with relationships pre-loaded) to a DeviceResponse.

    Tag and group refs come from the eagerly-loaded relationships; encrypted
    credentials are deliberately never included.
    """
    from app.schemas.device import DeviceGroupRef, DeviceTagRef

    tag_refs = [
        DeviceTagRef(id=assignment.tag.id, name=assignment.tag.name, color=assignment.tag.color)
        for assignment in device.tag_assignments
    ]
    group_refs = [
        DeviceGroupRef(id=membership.group.id, name=membership.group.name)
        for membership in device.group_memberships
    ]
    return DeviceResponse(
        id=device.id,
        hostname=device.hostname,
        ip_address=device.ip_address,
        api_port=device.api_port,
        api_ssl_port=device.api_ssl_port,
        model=device.model,
        serial_number=device.serial_number,
        firmware_version=device.firmware_version,
        routeros_version=device.routeros_version,
        uptime_seconds=device.uptime_seconds,
        last_seen=device.last_seen,
        latitude=device.latitude,
        longitude=device.longitude,
        status=device.status,
        tls_mode=device.tls_mode,
        tags=tag_refs,
        groups=group_refs,
        created_at=device.created_at,
    )
def _device_with_relations():
    """Base select() for Device with tag and group relationships eager-loaded."""
    tag_loader = selectinload(Device.tag_assignments).selectinload(DeviceTagAssignment.tag)
    group_loader = selectinload(Device.group_memberships).selectinload(
        DeviceGroupMembership.group
    )
    return select(Device).options(tag_loader, group_loader)
# ---------------------------------------------------------------------------
# Device CRUD
# ---------------------------------------------------------------------------
async def create_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    data: DeviceCreate,
    encryption_key: bytes,
) -> DeviceResponse:
    """
    Create a new device.

    - Validates TCP connectivity (api_port or api_ssl_port must be reachable);
      both probes run concurrently so the worst case is one timeout, not two.
    - Encrypts credentials via OpenBao Transit before storage.
    - Status set to "unknown" until the Go poller runs a full auth check (Phase 2).

    Args:
        db: Async database session (RLS-scoped).
        tenant_id: Owning tenant UUID.
        data: Validated device-creation payload.
        encryption_key: Legacy AES key (unused for new writes; kept for
            interface compatibility with callers).

    Raises:
        HTTPException: 422 when neither RouterOS API port is reachable.
    """
    # Probe both RouterOS API ports concurrently before accepting the device.
    api_reachable, ssl_reachable = await asyncio.gather(
        _tcp_reachable(data.ip_address, data.api_port),
        _tcp_reachable(data.ip_address, data.api_ssl_port),
    )
    if not api_reachable and not ssl_reachable:
        from fastapi import HTTPException, status

        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(
                f"Cannot reach {data.ip_address} on port {data.api_port} "
                f"(RouterOS API) or {data.api_ssl_port} (RouterOS SSL API). "
                "Verify the IP address and that the RouterOS API is enabled."
            ),
        )

    # New credential writes always go through OpenBao Transit.
    credentials_json = json.dumps({"username": data.username, "password": data.password})
    transit_ciphertext = await encrypt_credentials_transit(
        credentials_json, str(tenant_id)
    )

    device = Device(
        tenant_id=tenant_id,
        hostname=data.hostname,
        ip_address=data.ip_address,
        api_port=data.api_port,
        api_ssl_port=data.api_ssl_port,
        encrypted_credentials_transit=transit_ciphertext,
        status="unknown",
    )
    db.add(device)
    await db.flush()  # Assigns the ID without committing the transaction
    await db.refresh(device)

    # Re-query with tag/group relationships loaded for the response.
    result = await db.execute(
        _device_with_relations().where(Device.id == device.id)
    )
    return _build_device_response(result.scalar_one())
async def get_devices(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    page: int = 1,
    page_size: int = 25,
    status: Optional[str] = None,
    search: Optional[str] = None,
    tag_id: Optional[uuid.UUID] = None,
    group_id: Optional[uuid.UUID] = None,
    sort_by: str = "created_at",
    sort_order: str = "desc",
) -> tuple[list[DeviceResponse], int]:
    """
    Return a paginated list of devices with optional filtering and sorting.

    Returns (items, total_count). RLS automatically scopes this to the
    caller's tenant.
    """
    query = _device_with_relations()

    # -- filters -----------------------------------------------------------
    if status:
        query = query.where(Device.status == status)
    if search:
        like = f"%{search}%"
        query = query.where(
            or_(Device.hostname.ilike(like), Device.ip_address.ilike(like))
        )
    if tag_id:
        tagged_ids = select(DeviceTagAssignment.device_id).where(
            DeviceTagAssignment.tag_id == tag_id
        )
        query = query.where(Device.id.in_(tagged_ids))
    if group_id:
        grouped_ids = select(DeviceGroupMembership.device_id).where(
            DeviceGroupMembership.group_id == group_id
        )
        query = query.where(Device.id.in_(grouped_ids))

    # -- total count (before ordering/pagination) --------------------------
    count_result = await db.execute(
        select(func.count()).select_from(query.subquery())
    )
    total = count_result.scalar_one()

    # -- sorting (whitelisted columns only) --------------------------------
    sortable = {
        "created_at": Device.created_at,
        "hostname": Device.hostname,
        "ip_address": Device.ip_address,
        "status": Device.status,
        "last_seen": Device.last_seen,
    }
    column = sortable.get(sort_by, Device.created_at)
    ordering = column.asc() if sort_order.lower() == "asc" else column.desc()
    query = query.order_by(ordering)

    # -- pagination --------------------------------------------------------
    query = query.offset((page - 1) * page_size).limit(page_size)

    rows = await db.execute(query)
    items = [_build_device_response(d) for d in rows.scalars().all()]
    return items, total
async def get_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
) -> DeviceResponse:
    """Fetch a single device by ID, with tags and groups loaded."""
    from fastapi import HTTPException, status

    row = await db.execute(
        _device_with_relations().where(Device.id == device_id)
    )
    device = row.scalar_one_or_none()
    if device is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    return _build_device_response(device)
async def update_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    data: DeviceUpdate,
    encryption_key: bytes,
) -> DeviceResponse:
    """
    Update device fields. Re-encrypts credentials only if password is provided.

    Credential handling has two paths:
      - password provided: build a fresh {username, password} blob (reusing
        the stored username when no new one was given) and write via Transit.
      - username only: decrypt the existing blob, swap the username, and
        re-encrypt via Transit; on decryption failure the old blob is kept.
    After either path the legacy ciphertext column is cleared — Transit is
    canonical — and the poller is notified via NATS (best-effort).

    Args:
        db: Async database session (RLS-scoped).
        tenant_id: Tenant UUID (used only in the NATS notification payload).
        device_id: Device UUID to update.
        data: Partial update payload; None fields are left unchanged.
        encryption_key: Unused here directly (legacy decryption uses
            settings.get_encryption_key_bytes()); kept for interface compat.

    Raises:
        HTTPException: 404 when the device does not exist (or RLS hides it).
    """
    from fastapi import HTTPException, status
    result = await db.execute(
        _device_with_relations().where(Device.id == device_id)
    )
    device = result.scalar_one_or_none()
    if not device:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    # Update scalar fields — only fields explicitly present in the payload.
    if data.hostname is not None:
        device.hostname = data.hostname
    if data.ip_address is not None:
        device.ip_address = data.ip_address
    if data.api_port is not None:
        device.api_port = data.api_port
    if data.api_ssl_port is not None:
        device.api_ssl_port = data.api_ssl_port
    if data.latitude is not None:
        device.latitude = data.latitude
    if data.longitude is not None:
        device.longitude = data.longitude
    if data.tls_mode is not None:
        device.tls_mode = data.tls_mode
    # Re-encrypt credentials if new ones are provided
    credentials_changed = False
    if data.password is not None:
        # Decrypt existing to get current username if no new username given
        current_username: str = data.username or ""
        if not current_username and (device.encrypted_credentials_transit or device.encrypted_credentials):
            try:
                existing_json = await decrypt_credentials_hybrid(
                    device.encrypted_credentials_transit,
                    device.encrypted_credentials,
                    str(device.tenant_id),
                    settings.get_encryption_key_bytes(),
                )
                existing = json.loads(existing_json)
                current_username = existing.get("username", "")
            except Exception:
                # Undecryptable blob: fall back to empty username rather than fail.
                current_username = ""
        credentials_json = json.dumps({
            "username": data.username if data.username is not None else current_username,
            "password": data.password,
        })
        # New writes go through Transit
        device.encrypted_credentials_transit = await encrypt_credentials_transit(
            credentials_json, str(device.tenant_id)
        )
        device.encrypted_credentials = None  # Clear legacy (Transit is canonical)
        credentials_changed = True
    elif data.username is not None and (device.encrypted_credentials_transit or device.encrypted_credentials):
        # Only username changed — update it without changing the password
        try:
            existing_json = await decrypt_credentials_hybrid(
                device.encrypted_credentials_transit,
                device.encrypted_credentials,
                str(device.tenant_id),
                settings.get_encryption_key_bytes(),
            )
            existing = json.loads(existing_json)
            existing["username"] = data.username
            # Re-encrypt via Transit
            device.encrypted_credentials_transit = await encrypt_credentials_transit(
                json.dumps(existing), str(device.tenant_id)
            )
            device.encrypted_credentials = None
            credentials_changed = True
        except Exception:
            pass  # Keep existing encrypted blob if decryption fails
    await db.flush()
    await db.refresh(device)
    # Notify poller to invalidate cached credentials (fire-and-forget via NATS)
    if credentials_changed:
        try:
            from app.services.event_publisher import publish_event
            await publish_event(
                f"device.credential_changed.{device_id}",
                {"device_id": str(device_id), "tenant_id": str(tenant_id)},
            )
        except Exception:
            pass  # Never fail the update due to NATS issues
    # Re-query so the response has tag/group relationships loaded.
    result2 = await db.execute(
        _device_with_relations().where(Device.id == device_id)
    )
    device = result2.scalar_one()
    return _build_device_response(device)
async def delete_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
) -> None:
    """Hard-delete a device (v1 — no soft delete for devices)."""
    from fastapi import HTTPException, status

    row = await db.execute(select(Device).where(Device.id == device_id))
    target = row.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    await db.delete(target)
    await db.flush()
# ---------------------------------------------------------------------------
# Group / Tag assignment
# ---------------------------------------------------------------------------
async def assign_device_to_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    group_id: uuid.UUID,
) -> None:
    """Add a device to a group; idempotent when the membership already exists."""
    from fastapi import HTTPException, status

    # Both lookups are tenant-scoped by RLS.
    if await db.get(Device, device_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    if await db.get(DeviceGroup, group_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")

    if await db.get(DeviceGroupMembership, (device_id, group_id)) is None:
        db.add(DeviceGroupMembership(device_id=device_id, group_id=group_id))
        await db.flush()
async def remove_device_from_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    group_id: uuid.UUID,
) -> None:
    """Remove a device from a group; 404 when no such membership exists."""
    from fastapi import HTTPException, status

    link = await db.get(DeviceGroupMembership, (device_id, group_id))
    if link is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Device is not in this group",
        )
    await db.delete(link)
    await db.flush()
async def assign_tag_to_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    tag_id: uuid.UUID,
) -> None:
    """Attach a tag to a device; idempotent when already assigned."""
    from fastapi import HTTPException, status

    if await db.get(Device, device_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
    if await db.get(DeviceTag, tag_id) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tag not found")

    if await db.get(DeviceTagAssignment, (device_id, tag_id)) is None:
        db.add(DeviceTagAssignment(device_id=device_id, tag_id=tag_id))
        await db.flush()
async def remove_tag_from_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    device_id: uuid.UUID,
    tag_id: uuid.UUID,
) -> None:
    """Detach a tag from a device; 404 when the assignment does not exist."""
    from fastapi import HTTPException, status

    link = await db.get(DeviceTagAssignment, (device_id, tag_id))
    if link is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tag is not assigned to this device",
        )
    await db.delete(link)
    await db.flush()
# ---------------------------------------------------------------------------
# DeviceGroup CRUD
# ---------------------------------------------------------------------------
async def create_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    data: DeviceGroupCreate,
) -> DeviceGroupResponse:
    """Create a new device group for the tenant."""
    group = DeviceGroup(
        tenant_id=tenant_id,
        name=data.name,
        description=data.description,
    )
    db.add(group)
    await db.flush()
    await db.refresh(group)
    # A freshly created group has no members yet.
    return DeviceGroupResponse(
        id=group.id,
        name=group.name,
        description=group.description,
        device_count=0,
        created_at=group.created_at,
    )
async def get_groups(
    db: AsyncSession,
    tenant_id: uuid.UUID,
) -> list[DeviceGroupResponse]:
    """Return all device groups for the current tenant with device counts."""
    rows = await db.execute(
        select(DeviceGroup).options(selectinload(DeviceGroup.memberships))
    )
    responses: list[DeviceGroupResponse] = []
    for grp in rows.scalars().all():
        responses.append(
            DeviceGroupResponse(
                id=grp.id,
                name=grp.name,
                description=grp.description,
                device_count=len(grp.memberships),
                created_at=grp.created_at,
            )
        )
    return responses
async def update_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    group_id: uuid.UUID,
    data: DeviceGroupUpdate,
) -> DeviceGroupResponse:
    """Update a device group's name and/or description."""
    from fastapi import HTTPException, status

    base_q = (
        select(DeviceGroup)
        .options(selectinload(DeviceGroup.memberships))
        .where(DeviceGroup.id == group_id)
    )
    group = (await db.execute(base_q)).scalar_one_or_none()
    if group is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")

    if data.name is not None:
        group.name = data.name
    if data.description is not None:
        group.description = data.description
    await db.flush()
    await db.refresh(group)

    # Re-query to guarantee memberships are loaded for the response.
    group = (await db.execute(base_q)).scalar_one()
    return DeviceGroupResponse(
        id=group.id,
        name=group.name,
        description=group.description,
        device_count=len(group.memberships),
        created_at=group.created_at,
    )
async def delete_group(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    group_id: uuid.UUID,
) -> None:
    """Delete a device group; 404 when it does not exist."""
    from fastapi import HTTPException, status

    group = await db.get(DeviceGroup, group_id)
    if group is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")
    await db.delete(group)
    await db.flush()
# ---------------------------------------------------------------------------
# DeviceTag CRUD
# ---------------------------------------------------------------------------
async def create_tag(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    data: DeviceTagCreate,
) -> DeviceTagResponse:
    """Create a new device tag for the tenant."""
    tag = DeviceTag(
        tenant_id=tenant_id,
        name=data.name,
        color=data.color,
    )
    db.add(tag)
    await db.flush()
    await db.refresh(tag)
    return DeviceTagResponse(id=tag.id, name=tag.name, color=tag.color)
async def get_tags(
    db: AsyncSession,
    tenant_id: uuid.UUID,
) -> list[DeviceTagResponse]:
    """Return all device tags for the current tenant (RLS-scoped)."""
    rows = await db.execute(select(DeviceTag))
    return [
        DeviceTagResponse(id=tag.id, name=tag.name, color=tag.color)
        for tag in rows.scalars().all()
    ]
async def update_tag(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    tag_id: uuid.UUID,
    data: DeviceTagUpdate,
) -> DeviceTagResponse:
    """Update a device tag's name and/or color."""
    from fastapi import HTTPException, status

    tag = await db.get(DeviceTag, tag_id)
    if tag is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tag not found")

    if data.name is not None:
        tag.name = data.name
    if data.color is not None:
        tag.color = data.color
    await db.flush()
    await db.refresh(tag)
    return DeviceTagResponse(id=tag.id, name=tag.name, color=tag.color)
async def delete_tag(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    tag_id: uuid.UUID,
) -> None:
    """Delete a device tag; 404 when it does not exist."""
    from fastapi import HTTPException, status

    tag = await db.get(DeviceTag, tag_id)
    if tag is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Tag not found")
    await db.delete(tag)
    await db.flush()

View File

@@ -0,0 +1,124 @@
"""Unified email sending service.
All email sending (system emails, alert notifications) goes through this module.
Supports TLS, STARTTLS, and plain SMTP. Handles Transit + legacy Fernet password decryption.
"""
import logging
from email.message import EmailMessage
from typing import Optional
import aiosmtplib
logger = logging.getLogger(__name__)
class SMTPConfig:
    """SMTP connection configuration.

    Plain value holder passed to send_email / test_smtp_connection.

    Args:
        host: SMTP server hostname.
        port: SMTP server port (default 587; port 25 disables STARTTLS in
            send_email's heuristic).
        user: Optional username for SMTP AUTH.
        password: Optional password for SMTP AUTH.
        use_tls: When True, passed to aiosmtplib as use_tls (implicit TLS).
        from_address: Value used for the From header on outgoing mail.
    """

    def __init__(
        self,
        host: str,
        port: int = 587,
        user: Optional[str] = None,
        password: Optional[str] = None,
        use_tls: bool = False,
        from_address: str = "noreply@example.com",
    ):
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.use_tls = use_tls
        self.from_address = from_address
async def send_email(
    to: str,
    subject: str,
    html: str,
    plain_text: str,
    smtp_config: SMTPConfig,
) -> None:
    """Send a multipart (plain + HTML) email via SMTP.

    Args:
        to: Recipient email address.
        subject: Email subject line.
        html: HTML body.
        plain_text: Plain text fallback body.
        smtp_config: SMTP connection settings.

    Raises:
        aiosmtplib.SMTPException: On SMTP connection or send failure.
    """
    message = EmailMessage()
    message["Subject"] = subject
    message["From"] = smtp_config.from_address
    message["To"] = to
    message.set_content(plain_text)
    message.add_alternative(html, subtype="html")

    # Implicit TLS when configured; otherwise opportunistic STARTTLS —
    # except on port 25, which stays plain.
    implicit_tls = smtp_config.use_tls
    if smtp_config.port == 25:
        starttls = False
    else:
        starttls = not implicit_tls

    await aiosmtplib.send(
        message,
        hostname=smtp_config.host,
        port=smtp_config.port,
        username=smtp_config.user or None,
        password=smtp_config.password or None,
        use_tls=implicit_tls,
        start_tls=starttls,
    )
async def test_smtp_connection(smtp_config: SMTPConfig) -> dict:
    """Test SMTP connectivity without sending an email.

    Connects, optionally authenticates, and disconnects cleanly.

    Returns:
        dict with "success" bool and "message" string.
    """
    # Same STARTTLS policy as send_email: never on port 25, otherwise
    # whenever implicit TLS is off.
    starttls = not smtp_config.use_tls if smtp_config.port != 25 else False
    try:
        client = aiosmtplib.SMTP(
            hostname=smtp_config.host,
            port=smtp_config.port,
            use_tls=smtp_config.use_tls,
            start_tls=starttls,
        )
        await client.connect()
        # Only attempt AUTH when both credentials are configured.
        if smtp_config.user and smtp_config.password:
            await client.login(smtp_config.user, smtp_config.password)
        await client.quit()
    except Exception as e:
        return {"success": False, "message": str(e)}
    return {"success": True, "message": "SMTP connection successful"}
async def send_test_email(to: str, smtp_config: SMTPConfig) -> dict:
    """Send a test email to verify the full SMTP flow.

    Returns:
        dict with "success" bool and "message" string.
    """
    # Static test bodies — rendered HTML plus a plain-text fallback.
    html = """
    <div style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; max-width: 600px; margin: 0 auto;">
    <div style="background: #0f172a; padding: 24px; border-radius: 8px 8px 0 0;">
        <h2 style="color: #38bdf8; margin: 0;">TOD — Email Test</h2>
    </div>
    <div style="background: #1e293b; padding: 24px; border-radius: 0 0 8px 8px; color: #e2e8f0;">
        <p>This is a test email from The Other Dude.</p>
        <p>If you're reading this, your SMTP configuration is working correctly.</p>
        <p style="color: #94a3b8; font-size: 13px; margin-top: 24px;">
        Sent from TOD Fleet Management
        </p>
    </div>
    </div>
    """
    plain = "TOD — Email Test\n\nThis is a test email from The Other Dude.\nIf you're reading this, your SMTP configuration is working correctly."
    try:
        await send_email(to, "TOD — Test Email", html, plain, smtp_config)
    except Exception as e:
        return {"success": False, "message": str(e)}
    return {"success": True, "message": f"Test email sent to {to}"}

View File

@@ -0,0 +1,54 @@
"""Emergency Kit PDF template generation.
Generates an Emergency Kit PDF containing the user's email and sign-in URL
but NOT the Secret Key. The Secret Key placeholder is filled client-side
so that the server never sees it.
Uses Jinja2 + WeasyPrint following the same pattern as the reports service.
"""
import asyncio
from datetime import UTC, datetime
from pathlib import Path
from jinja2 import Environment, FileSystemLoader
from app.config import settings
TEMPLATE_DIR = Path(__file__).parent.parent.parent / "templates"
async def generate_emergency_kit_template(
    email: str,
) -> bytes:
    """Generate the Emergency Kit PDF template WITHOUT the Secret Key.

    The Secret Key placeholder is filled in client-side so that the server
    never sees the real key.

    Args:
        email: The user's email address to display in the PDF.

    Returns:
        PDF bytes ready for streaming response.
    """
    jinja_env = Environment(
        loader=FileSystemLoader(str(TEMPLATE_DIR)),
        autoescape=True,
    )
    rendered = jinja_env.get_template("emergency_kit.html").render(
        email=email,
        signin_url=settings.APP_BASE_URL,
        date=datetime.now(UTC).strftime("%Y-%m-%d"),
        secret_key_placeholder="[Download complete -- your Secret Key will be inserted by your browser]",
    )
    # WeasyPrint is synchronous and CPU-bound; off-load to a worker thread
    # so the event loop is not blocked during PDF generation.
    from weasyprint import HTML

    def _render_pdf() -> bytes:
        return HTML(string=rendered).write_pdf()

    return await asyncio.to_thread(_render_pdf)

View File

@@ -0,0 +1,52 @@
"""Fire-and-forget NATS JetStream event publisher for real-time SSE pipeline.
Provides a shared lazy NATS connection and publish helper used by:
- alert_evaluator.py (alert.fired.{tenant_id}, alert.resolved.{tenant_id})
- restore_service.py (config.push.{tenant_id}.{device_id})
- upgrade_service.py (firmware.progress.{tenant_id}.{device_id})
All publishes are fire-and-forget: errors are logged but never propagate
to the caller. A NATS outage must never block alert evaluation, config
push, or firmware upgrade operations.
"""
import json
import logging
from typing import Any
import nats
import nats.aio.client
from app.config import settings
logger = logging.getLogger(__name__)
# Module-level NATS connection (lazy initialized, reused across publishes)
_nc: nats.aio.client.Client | None = None
async def _get_nats() -> nats.aio.client.Client:
    """Return the shared NATS connection, (re)connecting when needed."""
    global _nc
    needs_connect = _nc is None or _nc.is_closed
    if needs_connect:
        _nc = await nats.connect(settings.NATS_URL)
        logger.info("Event publisher NATS connection established")
    return _nc
async def publish_event(subject: str, payload: dict[str, Any]) -> None:
    """Publish a JSON event to a NATS JetStream subject (fire-and-forget).

    Args:
        subject: NATS subject, e.g. "alert.fired.{tenant_id}".
        payload: Dict that will be JSON-serialized as the message body.

    Never raises -- all exceptions are caught and logged as warnings.
    """
    try:
        connection = await _get_nats()
        # Serialization stays inside the try: a non-serializable payload
        # must be logged, not propagated to the caller.
        body = json.dumps(payload).encode()
        await connection.jetstream().publish(subject, body)
        logger.debug("Published event to %s", subject)
    except Exception as exc:
        logger.warning("Failed to publish event to %s: %s", subject, exc)

View File

@@ -0,0 +1,303 @@
"""Firmware version cache service and NPK downloader.
Responsibilities:
- check_latest_versions(): fetch latest RouterOS versions from download.mikrotik.com
- download_firmware(): download NPK packages to local PVC cache
- get_firmware_overview(): return fleet firmware status for a tenant
- schedule_firmware_checks(): register daily firmware check job with APScheduler
Version discovery comes from two sources:
1. Go poller runs /system/package/update per device (rate-limited to once/day)
and publishes via NATS -> firmware_subscriber processes these events
2. check_latest_versions() fetches LATEST.7 / LATEST.6 from download.mikrotik.com
"""
import logging
import os
from pathlib import Path
import httpx
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
logger = logging.getLogger(__name__)
# Architectures supported by RouterOS v7 and v6
_V7_ARCHITECTURES = ["arm", "arm64", "mipsbe", "mmips", "smips", "tile", "ppc", "x86"]
_V6_ARCHITECTURES = ["mipsbe", "mmips", "smips", "tile", "ppc", "x86"]
# Version source files on download.mikrotik.com
_VERSION_SOURCES = [
("LATEST.7", "stable", 7),
("LATEST.7long", "long-term", 7),
("LATEST.6", "stable", 6),
("LATEST.6long", "long-term", 6),
]
async def check_latest_versions() -> list[dict]:
    """Fetch latest RouterOS versions from download.mikrotik.com.

    Checks LATEST.7, LATEST.7long, LATEST.6, and LATEST.6long files for
    version strings, then upserts into firmware_versions table for each
    architecture/channel combination.

    Returns:
        List of discovered version dicts with keys: architecture,
        channel, version, npk_url.
    """
    results: list[dict] = []
    async with httpx.AsyncClient(timeout=30.0) as client:
        for channel_file, channel, major in _VERSION_SOURCES:
            # A failure for one source must not abort the remaining ones.
            try:
                resp = await client.get(
                    f"https://download.mikrotik.com/routeros/{channel_file}"
                )
                if resp.status_code != 200:
                    logger.warning(
                        "MikroTik version check returned %d for %s",
                        resp.status_code, channel_file,
                    )
                    continue
                # LATEST.* files carry a bare version string; a leading
                # digit is used as a cheap sanity check against error pages.
                version = resp.text.strip()
                if not version or not version[0].isdigit():
                    logger.warning("Invalid version string from %s: %r", channel_file, version)
                    continue
                # v6 supports fewer CPU architectures than v7.
                architectures = _V7_ARCHITECTURES if major == 7 else _V6_ARCHITECTURES
                for arch in architectures:
                    npk_url = (
                        f"https://download.mikrotik.com/routeros/"
                        f"{version}/routeros-{version}-{arch}.npk"
                    )
                    results.append({
                        "architecture": arch,
                        "channel": channel,
                        "version": version,
                        "npk_url": npk_url,
                    })
            except Exception as e:
                logger.warning("Failed to check %s: %s", channel_file, e)
    # Upsert into firmware_versions table — re-discovering a known version
    # only refreshes checked_at (ON CONFLICT ... DO UPDATE).
    if results:
        async with AdminAsyncSessionLocal() as session:
            for r in results:
                await session.execute(
                    text("""
                        INSERT INTO firmware_versions (id, architecture, channel, version, npk_url, checked_at)
                        VALUES (gen_random_uuid(), :arch, :channel, :version, :npk_url, NOW())
                        ON CONFLICT (architecture, channel, version) DO UPDATE SET checked_at = NOW()
                    """),
                    {
                        "arch": r["architecture"],
                        "channel": r["channel"],
                        "version": r["version"],
                        "npk_url": r["npk_url"],
                    },
                )
            await session.commit()
    logger.info("Firmware version check complete — %d versions discovered", len(results))
    return results
async def download_firmware(architecture: str, channel: str, version: str) -> str:
    """Download an NPK package to the local firmware cache.

    Skips the download when the file already exists with a non-zero size,
    then records the local path and size in firmware_versions.

    Args:
        architecture: RouterOS CPU architecture (e.g. "arm64", "mipsbe").
        channel: Firmware channel ("stable" or "long-term").
        version: RouterOS version string (e.g. "7.14.2").

    Returns:
        The local file path as a string.

    Raises:
        httpx.HTTPStatusError: If the download URL returns an error status.
    """
    cache_dir = Path(settings.FIRMWARE_CACHE_DIR) / version
    cache_dir.mkdir(parents=True, exist_ok=True)
    filename = f"routeros-{version}-{architecture}.npk"
    local_path = cache_dir / filename
    # BUG FIX: the URL previously ended in a literal "(unknown)" placeholder,
    # so every uncached download 404'd. It must point at the NPK filename,
    # matching the URL format used by check_latest_versions().
    npk_url = f"https://download.mikrotik.com/routeros/{version}/{filename}"

    # Check if already cached (non-empty file is treated as complete).
    if local_path.exists() and local_path.stat().st_size > 0:
        logger.info("Firmware already cached: %s", local_path)
        return str(local_path)

    logger.info("Downloading firmware: %s", npk_url)
    async with httpx.AsyncClient(timeout=300.0) as client:
        async with client.stream("GET", npk_url) as response:
            response.raise_for_status()
            # Stream to disk in 64 KiB chunks to keep memory flat.
            with open(local_path, "wb") as f:
                async for chunk in response.aiter_bytes(chunk_size=65536):
                    f.write(chunk)

    file_size = local_path.stat().st_size
    logger.info("Firmware downloaded: %s (%d bytes)", local_path, file_size)

    # Record the cached artifact's location and size for the overview/scheduler.
    async with AdminAsyncSessionLocal() as session:
        await session.execute(
            text("""
                UPDATE firmware_versions
                SET npk_local_path = :path, npk_size_bytes = :size
                WHERE architecture = :arch AND channel = :channel AND version = :version
            """),
            {
                "path": str(local_path),
                "size": file_size,
                "arch": architecture,
                "channel": channel,
                "version": version,
            },
        )
        await session.commit()
    return str(local_path)
async def get_firmware_overview(tenant_id: str) -> dict:
    """Return fleet firmware status for a tenant.

    Devices are grouped by installed firmware version and annotated with an
    up-to-date flag, based on the latest known version for each device's
    architecture and preferred channel.

    Args:
        tenant_id: Tenant UUID as string.

    Returns:
        Dict with keys "devices" (flat per-device list), "version_groups"
        (devices grouped by installed version, with is_latest flags) and
        "summary" (total / up_to_date / outdated / unknown counts).
    """
    async with AdminAsyncSessionLocal() as session:
        # Get all devices for tenant
        devices_result = await session.execute(
            text("""
                SELECT id, hostname, ip_address, routeros_version, architecture,
                       preferred_channel, routeros_major_version,
                       serial_number, firmware_version, model
                FROM devices
                WHERE tenant_id = CAST(:tenant_id AS uuid)
                ORDER BY hostname
            """),
            {"tenant_id": tenant_id},
        )
        devices = devices_result.fetchall()
        # Get latest firmware versions per architecture/channel.
        # DISTINCT ON keeps the most recently checked row per combination.
        versions_result = await session.execute(
            text("""
                SELECT DISTINCT ON (architecture, channel)
                    architecture, channel, version, npk_url
                FROM firmware_versions
                ORDER BY architecture, channel, checked_at DESC
            """)
        )
        latest_versions = {
            (row[0], row[1]): {"version": row[2], "npk_url": row[3]}
            for row in versions_result.fetchall()
        }
        # Build per-device status.
        # Row index map (mirrors the SELECT order above): 0=id, 1=hostname,
        # 2=ip_address, 3=routeros_version, 4=architecture,
        # 5=preferred_channel, 6=routeros_major_version, 7=serial_number,
        # 8=firmware_version, 9=model.
        device_list = []
        version_groups: dict[str, list] = {}
        summary = {"total": 0, "up_to_date": 0, "outdated": 0, "unknown": 0}
        for dev in devices:
            dev_id = str(dev[0])
            hostname = dev[1]
            current_version = dev[3]
            arch = dev[4]
            channel = dev[5] or "stable"  # default channel when unset
            latest = latest_versions.get((arch, channel)) if arch else None
            latest_version = latest["version"] if latest else None
            is_up_to_date = False
            # NOTE(review): a device with a known version but no known
            # latest release for its arch/channel is counted as
            # "outdated", not "unknown" — confirm this is intended.
            if not current_version or not arch:
                summary["unknown"] += 1
            elif latest_version and current_version == latest_version:
                is_up_to_date = True
                summary["up_to_date"] += 1
            else:
                summary["outdated"] += 1
            summary["total"] += 1
            dev_info = {
                "id": dev_id,
                "hostname": hostname,
                "ip_address": dev[2],
                "routeros_version": current_version,
                "architecture": arch,
                "latest_version": latest_version,
                "channel": channel,
                "is_up_to_date": is_up_to_date,
                "serial_number": dev[7],
                "firmware_version": dev[8],
                "model": dev[9],
            }
            device_list.append(dev_info)
            # Group by version
            ver_key = current_version or "unknown"
            if ver_key not in version_groups:
                version_groups[ver_key] = []
            version_groups[ver_key].append(dev_info)
        # Build version groups with is_latest flag
        groups = []
        for ver, devs in sorted(version_groups.items()):
            # A version is "latest" if it matches the latest for any arch/channel combo
            is_latest = any(
                v["version"] == ver for v in latest_versions.values()
            )
            groups.append({
                "version": ver,
                "count": len(devs),
                "is_latest": is_latest,
                "devices": devs,
            })
        return {
            "devices": device_list,
            "version_groups": groups,
            "summary": summary,
        }
async def get_cached_firmware() -> list[dict]:
    """List all locally cached NPK files with their sizes."""
    root = Path(settings.FIRMWARE_CACHE_DIR)
    if not root.exists():
        return []
    entries: list[dict] = []
    # Layout is {cache_dir}/{version}/{*.npk}; skip stray files at the top level.
    for version_dir in sorted(root.iterdir()):
        if not version_dir.is_dir():
            continue
        npk_files = (p for p in sorted(version_dir.iterdir()) if p.suffix == ".npk")
        for npk_file in npk_files:
            entries.append({
                "path": str(npk_file),
                "version": version_dir.name,
                "filename": npk_file.name,
                "size_bytes": npk_file.stat().st_size,
            })
    return entries
def schedule_firmware_checks() -> None:
    """Register the daily firmware version check with APScheduler.

    Invoked from FastAPI lifespan startup; schedules
    check_latest_versions() on the shared backup scheduler at 03:00 UTC.
    """
    from apscheduler.triggers.cron import CronTrigger

    from app.services.backup_scheduler import backup_scheduler

    daily_3am_utc = CronTrigger(hour=3, minute=0, timezone="UTC")
    backup_scheduler.add_job(
        check_latest_versions,
        trigger=daily_3am_utc,
        id="firmware_version_check",
        name="Check for new RouterOS firmware versions",
        max_instances=1,
        replace_existing=True,
    )
    logger.info("Firmware version check scheduled — daily at 3am UTC")

View File

@@ -0,0 +1,206 @@
"""NATS JetStream subscriber for device firmware events from the Go poller.
Subscribes to device.firmware.> and:
1. Updates devices.routeros_version and devices.architecture from poller data
2. Upserts firmware_versions table with latest version per architecture/channel
Uses AdminAsyncSessionLocal (superuser bypass RLS) so firmware data from any
tenant can be written without setting app.current_tenant.
"""
import asyncio
import json
import logging
from typing import Optional
import nats
from nats.js import JetStreamContext
from nats.aio.client import Client as NATSClient
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
logger = logging.getLogger(__name__)
_firmware_client: Optional[NATSClient] = None
async def on_device_firmware(msg) -> None:
    """Handle a device.firmware event published by the Go poller.

    Payload (JSON):
        device_id (str) -- UUID of the device
        tenant_id (str) -- UUID of the owning tenant
        installed_version (str) -- currently installed RouterOS version
        latest_version (str) -- latest available version (may be empty)
        channel (str) -- firmware channel ("stable", "long-term")
        status (str) -- "New version is available", etc.
        architecture (str) -- CPU architecture (arm, arm64, mipsbe, etc.)

    Messages are ack'd on success and on skip; processing failures nak so
    JetStream can redeliver.
    """
    try:
        data = json.loads(msg.data)
        device_id = data.get("device_id")
        tenant_id = data.get("tenant_id")
        architecture = data.get("architecture")
        installed_version = data.get("installed_version")
        latest_version = data.get("latest_version")
        channel = data.get("channel", "stable")
        if not device_id:
            # Malformed event — ack so it is not redelivered forever.
            logger.warning("device.firmware event missing device_id — skipping")
            await msg.ack()
            return
        async with AdminAsyncSessionLocal() as session:
            # Update device routeros_version and architecture from poller data.
            # COALESCE keeps the existing column value when the event field is null.
            if architecture or installed_version:
                await session.execute(
                    text("""
                        UPDATE devices
                        SET routeros_version = COALESCE(:installed_ver, routeros_version),
                            architecture = COALESCE(:architecture, architecture),
                            updated_at = NOW()
                        WHERE id = CAST(:device_id AS uuid)
                    """),
                    {
                        "installed_ver": installed_version,
                        "architecture": architecture,
                        "device_id": device_id,
                    },
                )
            # Upsert firmware_versions if we got latest version info;
            # re-seen versions only refresh checked_at.
            if latest_version and architecture:
                npk_url = (
                    f"https://download.mikrotik.com/routeros/"
                    f"{latest_version}/routeros-{latest_version}-{architecture}.npk"
                )
                await session.execute(
                    text("""
                        INSERT INTO firmware_versions (id, architecture, channel, version, npk_url, checked_at)
                        VALUES (gen_random_uuid(), :arch, :channel, :version, :url, NOW())
                        ON CONFLICT (architecture, channel, version) DO UPDATE SET checked_at = NOW()
                    """),
                    {
                        "arch": architecture,
                        "channel": channel,
                        "version": latest_version,
                        "url": npk_url,
                    },
                )
            await session.commit()
        logger.debug(
            "device.firmware processed",
            extra={
                "device_id": device_id,
                "architecture": architecture,
                "installed": installed_version,
                "latest": latest_version,
            },
        )
        # Ack only after the DB commit succeeded.
        await msg.ack()
    except Exception as exc:
        logger.error(
            "Failed to process device.firmware event: %s",
            exc,
            exc_info=True,
        )
        # nak() requests redelivery; swallow secondary failures so this
        # subscriber callback never raises into the NATS client.
        try:
            await msg.nak()
        except Exception:
            pass
async def _subscribe_with_retry(js: JetStreamContext) -> None:
    """Subscribe to device.firmware.> with a durable consumer, retrying
    while the DEVICE_EVENTS stream is not ready yet."""
    max_attempts = 6  # ~30 seconds at 5s intervals
    attempt = 0
    while True:
        attempt += 1
        try:
            await js.subscribe(
                "device.firmware.>",
                cb=on_device_firmware,
                durable="api-firmware-consumer",
                stream="DEVICE_EVENTS",
            )
        except Exception as exc:
            if attempt >= max_attempts:
                # Non-fatal by design: the API keeps running without
                # firmware event updates.
                logger.warning(
                    "NATS: giving up on device.firmware.> after %d attempts: %s — API will run without firmware updates",
                    max_attempts,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready for firmware (attempt %d/%d): %s — retrying in 5s",
                attempt,
                max_attempts,
                exc,
            )
            await asyncio.sleep(5)
        else:
            logger.info(
                "NATS: subscribed to device.firmware.> (durable: api-firmware-consumer)"
            )
            return
async def start_firmware_subscriber() -> Optional[NATSClient]:
    """Connect to NATS and start the device.firmware.> subscription.

    Uses a dedicated NATS connection, separate from the status and metrics
    subscribers. The returned connection must be handed to
    stop_firmware_subscriber() on shutdown.

    Raises on fatal connection errors after retry exhaustion.
    """
    global _firmware_client
    logger.info("NATS firmware: connecting to %s", settings.NATS_URL)
    # Reconnect forever (-1) with a 2s back-off between attempts.
    connection = await nats.connect(
        settings.NATS_URL,
        max_reconnect_attempts=-1,
        reconnect_time_wait=2,
        error_cb=_on_error,
        reconnected_cb=_on_reconnected,
        disconnected_cb=_on_disconnected,
    )
    logger.info("NATS firmware: connected to %s", settings.NATS_URL)
    await _subscribe_with_retry(connection.jetstream())
    _firmware_client = connection
    return connection
async def stop_firmware_subscriber(nc: Optional[NATSClient]) -> None:
    """Drain and close the firmware NATS connection gracefully."""
    if nc is None:
        return
    try:
        logger.info("NATS firmware: draining connection...")
        await nc.drain()
    except Exception as exc:
        logger.warning("NATS firmware: error during drain: %s", exc)
        # Drain failed — fall back to a hard close, ignoring further errors.
        try:
            await nc.close()
        except Exception:
            pass
    else:
        logger.info("NATS firmware: connection closed")
async def _on_error(exc: Exception) -> None:
    """NATS async error callback — log and continue."""
    logger.error("NATS firmware error: %s", exc)
async def _on_reconnected() -> None:
    """NATS reconnect callback — informational only."""
    logger.info("NATS firmware: reconnected")
async def _on_disconnected() -> None:
    """NATS disconnect callback — the client reconnects automatically."""
    logger.warning("NATS firmware: disconnected")

View File

@@ -0,0 +1,296 @@
"""pygit2-based git store for versioned config backup storage.
All functions in this module are synchronous (pygit2 is C bindings over libgit2).
Callers running in an async context MUST wrap calls in:
loop.run_in_executor(None, func, *args)
or:
asyncio.get_event_loop().run_in_executor(None, func, *args)
See Pitfall 3 in 04-RESEARCH.md — blocking pygit2 in async context stalls
the event loop and causes timeouts for other concurrent requests.
Git layout:
{GIT_STORE_PATH}/{tenant_id}.git/ <- bare repo per tenant
objects/ refs/ HEAD <- standard bare git structure
{device_id}/ <- device subtree
export.rsc <- text export (/export compact)
backup.bin <- binary system backup
"""
import difflib
import threading
from pathlib import Path
from typing import Optional
import pygit2
from app.config import settings
# =========================================================================
# Per-tenant mutex to prevent TreeBuilder race condition (Pitfall 5 in RESEARCH.md).
# Two simultaneous backups for different devices in the same tenant repo would
# each read HEAD, build their own device subtrees, and write conflicting root
# trees. The second commit would lose the first's device subtree.
# Lock scope is the entire tenant repo — not just the device.
# =========================================================================
_tenant_locks: dict[str, threading.Lock] = {}
_tenant_locks_guard = threading.Lock()
def _get_tenant_lock(tenant_id: str) -> threading.Lock:
"""Return (creating if needed) the per-tenant commit lock."""
with _tenant_locks_guard:
if tenant_id not in _tenant_locks:
_tenant_locks[tenant_id] = threading.Lock()
return _tenant_locks[tenant_id]
# =========================================================================
# PUBLIC API
# =========================================================================
def get_or_create_repo(tenant_id: str) -> pygit2.Repository:
    """Open the tenant's bare git repo, creating it on first use.

    The repo lives at {GIT_STORE_PATH}/{tenant_id}.git; the parent
    directory is created when missing.

    Args:
        tenant_id: Tenant UUID as string.

    Returns:
        An open pygit2.Repository instance (bare).
    """
    store_root = Path(settings.GIT_STORE_PATH)
    store_root.mkdir(parents=True, exist_ok=True)
    repo_path = store_root / f"{tenant_id}.git"
    if not repo_path.exists():
        return pygit2.init_repository(str(repo_path), bare=True)
    return pygit2.Repository(str(repo_path))
def commit_backup(
    tenant_id: str,
    device_id: str,
    export_text: str,
    binary_backup: bytes,
    message: str,
) -> str:
    """Write a backup pair (export.rsc + backup.bin) as a git commit.

    Creates or updates the device subdirectory in the tenant's bare repo.
    Preserves other devices' subdirectories by merging the device subtree
    into the existing root tree.

    Per-tenant locking (threading.Lock) prevents the TreeBuilder race
    condition when two devices in the same tenant back up concurrently.

    Args:
        tenant_id: Tenant UUID as string.
        device_id: Device UUID as string (becomes a subdirectory in the repo).
        export_text: Text output of /export compact.
        binary_backup: Raw bytes from /system backup save.
        message: Commit message (format: "{trigger}: {hostname} ({ip}) at {ts}").

    Returns:
        The hex commit SHA string (40 characters).
    """
    lock = _get_tenant_lock(tenant_id)
    with lock:
        repo = get_or_create_repo(tenant_id)
        # Create blobs from content
        export_oid = repo.create_blob(export_text.encode("utf-8"))
        binary_oid = repo.create_blob(binary_backup)
        # Build device subtree: {device_id}/export.rsc and {device_id}/backup.bin
        device_builder = repo.TreeBuilder()
        device_builder.insert("export.rsc", export_oid, pygit2.GIT_FILEMODE_BLOB)
        device_builder.insert("backup.bin", binary_oid, pygit2.GIT_FILEMODE_BLOB)
        device_tree_oid = device_builder.write()
        # Merge device subtree into root tree, preserving all other device subtrees.
        # If the repo has no commits yet, start with an empty root tree.
        root_ref = repo.references.get("refs/heads/main")
        parent_commit: Optional[pygit2.Commit] = None
        if root_ref is not None:
            try:
                parent_commit = repo.get(root_ref.target)
                # Seed the builder from HEAD's tree so sibling device
                # subtrees carry over into the new root tree.
                root_builder = repo.TreeBuilder(parent_commit.tree)
            except Exception:
                # HEAD tree unreadable — fall back to a fresh root tree.
                root_builder = repo.TreeBuilder()
        else:
            root_builder = repo.TreeBuilder()
        # insert() replaces any existing entry for this device_id.
        root_builder.insert(device_id, device_tree_oid, pygit2.GIT_FILEMODE_TREE)
        root_tree_oid = root_builder.write()
        # Author signature — no real identity, portal service account
        author = pygit2.Signature("The Other Dude", "backup@tod.local")
        # First commit in the repo has no parent.
        parents = [root_ref.target] if root_ref is not None else []
        commit_oid = repo.create_commit(
            "refs/heads/main",
            author,
            author,
            message,
            root_tree_oid,
            parents,
        )
        return str(commit_oid)
def read_file(
    tenant_id: str,
    commit_sha: str,
    device_id: str,
    filename: str,
) -> bytes:
    """Read a file blob from a specific backup commit.

    Walks the tree: root -> {device_id} subtree -> {filename} blob.

    Args:
        tenant_id: Tenant UUID as string.
        commit_sha: Full or abbreviated git commit SHA.
        device_id: Device UUID as string (subdirectory name in the repo).
        filename: File to read: "export.rsc" or "backup.bin".

    Returns:
        Raw bytes of the file content.

    Raises:
        KeyError: If the commit, device subtree, or filename is missing.
        pygit2.GitError: On repository-level errors.
    """
    repo = get_or_create_repo(tenant_id)
    commit_obj = repo.get(commit_sha)
    if commit_obj is None:
        raise KeyError(f"Commit {commit_sha!r} not found in tenant {tenant_id!r} repo")
    # Tree navigation raises KeyError when either entry is absent.
    device_subtree = repo.get(commit_obj.tree[device_id].id)
    file_blob = repo.get(device_subtree[filename].id)
    return file_blob.data
def list_device_commits(
    tenant_id: str,
    device_id: str,
) -> list[dict]:
    """Walk commit history and return commits that include the device subtree.

    Walks commits newest-first. Returns only commits where the device_id
    subtree is present in the root tree AND differs from the first parent's
    copy — i.e. commits where this device's backup actually changed.

    Args:
        tenant_id: Tenant UUID as string.
        device_id: Device UUID as string.

    Returns:
        List of dicts (newest first):
            [{"sha": str, "message": str, "timestamp": int}, ...]
        Empty list if no commits or device has never been backed up.
    """
    repo = get_or_create_repo(tenant_id)
    # If there are no commits, return empty list immediately.
    # Use refs/heads/main explicitly rather than repo.head (which defaults to
    # refs/heads/master — wrong when the repo uses 'main' as the default branch).
    main_ref = repo.references.get("refs/heads/main")
    if main_ref is None:
        return []
    head_target = main_ref.target
    results = []
    # GIT_SORT_TIME yields commits in (reverse) chronological order.
    walker = repo.walk(head_target, pygit2.GIT_SORT_TIME)
    for commit in walker:
        # Check if device_id subtree exists in this commit's root tree.
        try:
            device_entry = commit.tree[device_id]
        except KeyError:
            # Device not present in this commit at all — skip.
            continue
        # Only include this commit if it actually changed the device's subtree
        # vs its parent. This prevents every subsequent backup (for any device
        # in the same tenant) from appearing in all devices' histories.
        if commit.parents:
            parent = commit.parents[0]
            try:
                parent_device_entry = parent.tree[device_id]
                if parent_device_entry.id == device_entry.id:
                    # Device subtree unchanged in this commit — skip.
                    continue
            except KeyError:
                # Device wasn't in parent but is in this commit — it's the first entry.
                pass
        results.append({
            "sha": str(commit.id),
            "message": commit.message.strip(),
            "timestamp": commit.commit_time,
        })
    return results
def compute_line_delta(old_text: str, new_text: str) -> tuple[int, int]:
    """Compute (lines_added, lines_removed) between two text versions.

    Tallies SequenceMatcher opcodes instead of generating a full unified
    diff, which is cheaper for large config exports.

    For the first backup (no prior version), pass old_text="" to get
    (total_lines, 0) as the delta.

    Args:
        old_text: Previous export.rsc content (empty string for first backup).
        new_text: New export.rsc content.

    Returns:
        Tuple of (lines_added, lines_removed).
    """
    before = old_text.splitlines() if old_text else []
    after = new_text.splitlines() if new_text else []

    # Degenerate cases: one or both sides empty.
    if not before:
        return (len(after), 0) if after else (0, 0)
    if not after:
        return (0, len(before))

    added = removed = 0
    matcher = difflib.SequenceMatcher(None, before, after, autojunk=False)
    for op, i1, i2, j1, j2 in matcher.get_opcodes():
        # "replace" counts on both sides; "equal" contributes nothing.
        if op in ("replace", "delete"):
            removed += i2 - i1
        if op in ("replace", "insert"):
            added += j2 - j1
    return (added, removed)

View File

@@ -0,0 +1,324 @@
"""Key hierarchy management service for zero-knowledge architecture.
Provides CRUD operations for encrypted key bundles (UserKeySet),
append-only audit logging (KeyAccessLog), and OpenBao Transit
tenant key provisioning with credential migration.
"""
import logging
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.key_set import KeyAccessLog, UserKeySet
logger = logging.getLogger(__name__)
async def store_user_key_set(
    db: AsyncSession,
    user_id: UUID,
    tenant_id: UUID | None,
    encrypted_private_key: bytes,
    private_key_nonce: bytes,
    encrypted_vault_key: bytes,
    vault_key_nonce: bytes,
    public_key: bytes,
    pbkdf2_salt: bytes,
    hkdf_salt: bytes,
    pbkdf2_iterations: int = 650000,
) -> UserKeySet:
    """Store the user's encrypted key bundle during registration.

    Each user holds exactly one key set (UNIQUE constraint on user_id),
    so any pre-existing row — e.g. from a failed prior upgrade attempt —
    is removed first.

    Args:
        db: Async database session.
        user_id: The user's UUID.
        tenant_id: The user's tenant UUID (None for super_admin).
        encrypted_private_key: RSA private key wrapped by AUK (AES-GCM).
        private_key_nonce: 12-byte AES-GCM nonce for private key.
        encrypted_vault_key: Tenant vault key wrapped by user's public key.
        vault_key_nonce: 12-byte AES-GCM nonce for vault key.
        public_key: RSA-2048 public key in SPKI format.
        pbkdf2_salt: 32-byte salt for PBKDF2 key derivation.
        hkdf_salt: 32-byte salt for HKDF Secret Key derivation.
        pbkdf2_iterations: PBKDF2 iteration count (default 650000).

    Returns:
        The created UserKeySet instance.
    """
    from sqlalchemy import delete

    # Replace rather than update: drop any stale key set for this user.
    await db.execute(delete(UserKeySet).where(UserKeySet.user_id == user_id))

    fresh_key_set = UserKeySet(
        user_id=user_id,
        tenant_id=tenant_id,
        encrypted_private_key=encrypted_private_key,
        private_key_nonce=private_key_nonce,
        encrypted_vault_key=encrypted_vault_key,
        vault_key_nonce=vault_key_nonce,
        public_key=public_key,
        pbkdf2_salt=pbkdf2_salt,
        hkdf_salt=hkdf_salt,
        pbkdf2_iterations=pbkdf2_iterations,
    )
    db.add(fresh_key_set)
    await db.flush()
    return fresh_key_set
async def get_user_key_set(
    db: AsyncSession, user_id: UUID
) -> UserKeySet | None:
    """Fetch the user's encrypted key bundle (e.g. for the login response).

    Args:
        db: Async database session.
        user_id: The user's UUID.

    Returns:
        The UserKeySet if found, None otherwise.
    """
    stmt = select(UserKeySet).where(UserKeySet.user_id == user_id)
    return (await db.execute(stmt)).scalar_one_or_none()
async def log_key_access(
    db: AsyncSession,
    tenant_id: UUID,
    user_id: UUID | None,
    action: str,
    resource_type: str | None = None,
    resource_id: str | None = None,
    key_version: int | None = None,
    ip_address: str | None = None,
    device_id: UUID | None = None,
    justification: str | None = None,
    correlation_id: str | None = None,
) -> None:
    """Append an entry to the immutable key_access_log.

    The table is append-only (INSERT+SELECT only via RLS policy); no
    UPDATE or DELETE is permitted.

    Args:
        db: Async database session.
        tenant_id: The tenant UUID for RLS isolation.
        user_id: The user who performed the action (None for system ops).
        action: Action description (e.g., 'create_key_set', 'decrypt_vault_key').
        resource_type: Optional resource type being accessed.
        resource_id: Optional resource identifier.
        key_version: Optional key version involved.
        ip_address: Optional client IP address.
        device_id: Optional device UUID for credential access tracking.
        justification: Optional justification for the access (e.g., 'api_backup').
        correlation_id: Optional correlation ID for request tracing.
    """
    entry = KeyAccessLog(
        tenant_id=tenant_id,
        user_id=user_id,
        action=action,
        resource_type=resource_type,
        resource_id=resource_id,
        key_version=key_version,
        ip_address=ip_address,
        device_id=device_id,
        justification=justification,
        correlation_id=correlation_id,
    )
    db.add(entry)
    await db.flush()
# ---------------------------------------------------------------------------
# OpenBao Transit tenant key provisioning and credential migration
# ---------------------------------------------------------------------------
async def provision_tenant_key(db: AsyncSession, tenant_id: UUID) -> str:
    """Provision an OpenBao Transit key for a tenant and record its name.

    Idempotent: creating an already-existing key is a no-op on the OpenBao
    side; the tenant record is always (re)stamped with the key name.

    Args:
        db: Async database session (admin engine, no RLS).
        tenant_id: Tenant UUID.

    Returns:
        The key name (tenant_{uuid}).
    """
    from app.models.tenant import Tenant
    from app.services.openbao_service import get_openbao_service

    key_name = f"tenant_{tenant_id}"
    await get_openbao_service().create_tenant_key(str(tenant_id))

    # Update tenant record with key name (skip silently if the row is gone).
    tenant_row = (
        await db.execute(select(Tenant).where(Tenant.id == tenant_id))
    ).scalar_one_or_none()
    if tenant_row is not None:
        tenant_row.openbao_key_name = key_name
        await db.flush()

    logger.info(
        "Provisioned OpenBao Transit key for tenant %s (key=%s)",
        tenant_id,
        key_name,
    )
    return key_name
async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
    """Re-encrypt all legacy credentials for a tenant from AES-256-GCM to Transit.

    Migrates device credentials, CA private keys, device cert private keys,
    and notification channel secrets. Already-migrated items are skipped
    (rows with a non-empty *_transit column are excluded from the queries).
    Per-row failures are logged and counted; migration always continues.

    Args:
        db: Async database session (admin engine, no RLS).
        tenant_id: Tenant UUID.

    Returns:
        Dict with counts: {"devices": N, "cas": N, "certs": N, "channels": N, "errors": N}
    """
    from app.config import settings
    from app.models.alert import NotificationChannel
    from app.models.certificate import CertificateAuthority, DeviceCertificate
    from app.models.device import Device
    from app.services.crypto import decrypt_credentials
    from app.services.openbao_service import get_openbao_service

    openbao = get_openbao_service()
    legacy_key = settings.get_encryption_key_bytes()
    tid = str(tenant_id)
    counts = {"devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0}

    async def _migrate_column(model, legacy_attr, transit_attr, count_key, log_template):
        """Re-encrypt one legacy column for every not-yet-migrated row of *model*.

        Selects rows in this tenant where the legacy column is set and the
        Transit column is NULL/empty, decrypts with the legacy AES key,
        re-encrypts via Transit, and updates `counts` in place. `log_template`
        is the error-log format string taking (row id, exception).
        """
        legacy_col = getattr(model, legacy_attr)
        transit_col = getattr(model, transit_attr)
        result = await db.execute(
            select(model).where(
                model.tenant_id == tenant_id,
                legacy_col.isnot(None),
                (transit_col.is_(None) | (transit_col == "")),
            )
        )
        for obj in result.scalars().all():
            try:
                plaintext = decrypt_credentials(getattr(obj, legacy_attr), legacy_key)
                setattr(
                    obj,
                    transit_attr,
                    await openbao.encrypt(tid, plaintext.encode("utf-8")),
                )
                counts[count_key] += 1
            except Exception as e:
                logger.error(log_template, obj.id, e)
                counts["errors"] += 1

    # The three credential/key tables follow the same legacy→Transit pattern;
    # handle them with one helper instead of three copy-pasted loops.
    await _migrate_column(
        Device, "encrypted_credentials", "encrypted_credentials_transit",
        "devices", "Failed to migrate device %s credentials: %s",
    )
    await _migrate_column(
        CertificateAuthority, "encrypted_private_key", "encrypted_private_key_transit",
        "cas", "Failed to migrate CA %s private key: %s",
    )
    await _migrate_column(
        DeviceCertificate, "encrypted_private_key", "encrypted_private_key_transit",
        "certs", "Failed to migrate cert %s private key: %s",
    )

    # --- Migrate notification channel secrets (field-by-field; only SMTP today) ---
    result = await db.execute(
        select(NotificationChannel).where(
            NotificationChannel.tenant_id == tenant_id,
        )
    )
    for ch in result.scalars().all():
        migrated_any = False
        try:
            # SMTP password
            if ch.smtp_password and not ch.smtp_password_transit:
                plaintext = decrypt_credentials(ch.smtp_password, legacy_key)
                ch.smtp_password_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
                migrated_any = True
            if migrated_any:
                counts["channels"] += 1
        except Exception as e:
            logger.error("Failed to migrate channel %s secrets: %s", ch.id, e)
            counts["errors"] += 1

    await db.flush()
    logger.info(
        "Tenant %s credential migration complete: %s",
        tenant_id,
        counts,
    )
    return counts
async def provision_existing_tenants(db: AsyncSession) -> dict:
    """Provision OpenBao Transit keys for all existing tenants and migrate credentials.

    Called on app startup to ensure all tenants have Transit keys.
    Idempotent -- running multiple times is safe (already-migrated items are skipped).

    Args:
        db: Async database session (admin engine, no RLS).

    Returns:
        Summary dict with total counts across all tenants.
    """
    from app.models.tenant import Tenant

    tenants = (await db.execute(select(Tenant))).scalars().all()
    total = {"tenants": len(tenants), "devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0}

    for tenant in tenants:
        try:
            await provision_tenant_key(db, tenant.id)
            per_tenant = await migrate_tenant_credentials(db, tenant.id)
            # Fold this tenant's counters into the fleet-wide summary.
            for key in ("devices", "cas", "certs", "channels", "errors"):
                total[key] += per_tenant[key]
        except Exception as e:
            logger.error("Failed to provision/migrate tenant %s: %s", tenant.id, e)
            total["errors"] += 1

    await db.commit()
    logger.info("Existing tenant provisioning complete: %s", total)
    return total

View File

@@ -0,0 +1,346 @@
"""NATS JetStream subscriber for device metrics events.
Subscribes to device.metrics.> and inserts into TimescaleDB hypertables:
- interface_metrics — per-interface rx/tx byte counters
- health_metrics — CPU, memory, disk, temperature per device
- wireless_metrics — per-wireless-interface aggregated client stats
Also maintains denormalized last_cpu_load and last_memory_used_pct columns
on the devices table for efficient fleet table display.
Uses AdminAsyncSessionLocal (superuser bypass RLS) so metrics from any tenant
can be written without setting app.current_tenant.
"""
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Optional
import nats
from nats.js import JetStreamContext
from nats.aio.client import Client as NATSClient
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
logger = logging.getLogger(__name__)
_metrics_client: Optional[NATSClient] = None
# =============================================================================
# INSERT HANDLERS
# =============================================================================
def _parse_timestamp(val: str | None) -> datetime:
"""Parse an ISO 8601 / RFC 3339 timestamp string into a datetime object."""
if not val:
return datetime.now(timezone.utc)
try:
return datetime.fromisoformat(val.replace("Z", "+00:00"))
except (ValueError, AttributeError):
return datetime.now(timezone.utc)
async def _insert_health_metrics(session, data: dict) -> None:
    """Insert a health metrics event into health_metrics and update devices.

    Args:
        session: Async SQLAlchemy session (admin engine, bypasses RLS).
        data: Decoded device.metrics payload; must contain a 'health' dict,
            plus 'device_id', 'tenant_id' and optionally 'collected_at'.
    """
    health = data.get("health")
    if not health:
        logger.warning("health metrics event missing 'health' field — skipping")
        return
    device_id = data.get("device_id")
    tenant_id = data.get("tenant_id")
    collected_at = _parse_timestamp(data.get("collected_at"))

    # Parse numeric values; treat missing values and empty strings as NULL.
    def parse_int(val: str | int | None) -> int | None:
        # A genuine zero must survive: the previous `if not val` check silently
        # turned a numeric 0 (e.g. idle CPU reported as JSON number 0) into NULL.
        if val is None or val == "":
            return None
        try:
            return int(val)
        except (ValueError, TypeError):
            return None

    cpu_load = parse_int(health.get("cpu_load"))
    free_memory = parse_int(health.get("free_memory"))
    total_memory = parse_int(health.get("total_memory"))
    free_disk = parse_int(health.get("free_disk"))
    total_disk = parse_int(health.get("total_disk"))
    temperature = parse_int(health.get("temperature"))

    await session.execute(
        text("""
            INSERT INTO health_metrics
            (time, device_id, tenant_id, cpu_load, free_memory, total_memory,
             free_disk, total_disk, temperature)
            VALUES
            (:time, :device_id, :tenant_id, :cpu_load, :free_memory, :total_memory,
             :free_disk, :total_disk, :temperature)
        """),
        {
            "time": collected_at,
            "device_id": device_id,
            "tenant_id": tenant_id,
            "cpu_load": cpu_load,
            "free_memory": free_memory,
            "total_memory": total_memory,
            "free_disk": free_disk,
            "total_disk": total_disk,
            "temperature": temperature,
        },
    )

    # Update denormalized columns on devices for fleet table display.
    # Compute memory percentage in Python to avoid asyncpg type ambiguity.
    mem_pct = None
    if total_memory and total_memory > 0 and free_memory is not None:
        mem_pct = round((1.0 - free_memory / total_memory) * 100)
    await session.execute(
        text("""
            UPDATE devices SET
                last_cpu_load = COALESCE(:cpu_load, last_cpu_load),
                last_memory_used_pct = COALESCE(:mem_pct, last_memory_used_pct),
                updated_at = NOW()
            WHERE id = CAST(:device_id AS uuid)
        """),
        {
            "cpu_load": cpu_load,
            "mem_pct": mem_pct,
            "device_id": device_id,
        },
    )
async def _insert_interface_metrics(session, data: dict) -> None:
    """Insert per-interface traffic counters into interface_metrics."""
    interfaces = data.get("interfaces")
    if not interfaces:
        # Device may have no interfaces (unlikely but safe to skip).
        return

    common = {
        "time": _parse_timestamp(data.get("collected_at")),
        "device_id": data.get("device_id"),
        "tenant_id": data.get("tenant_id"),
    }
    stmt = text("""
        INSERT INTO interface_metrics
        (time, device_id, tenant_id, interface, rx_bytes, tx_bytes, rx_bps, tx_bps)
        VALUES
        (:time, :device_id, :tenant_id, :interface, :rx_bytes, :tx_bytes, NULL, NULL)
    """)
    for iface in interfaces:
        await session.execute(
            stmt,
            {
                **common,
                "interface": iface.get("name"),
                "rx_bytes": iface.get("rx_bytes"),
                "tx_bytes": iface.get("tx_bytes"),
            },
        )
async def _insert_wireless_metrics(session, data: dict) -> None:
    """Insert per-wireless-interface aggregated client stats into wireless_metrics."""
    wireless = data.get("wireless")
    if not wireless:
        # Device may have no wireless interfaces.
        return

    common = {
        "time": _parse_timestamp(data.get("collected_at")),
        "device_id": data.get("device_id"),
        "tenant_id": data.get("tenant_id"),
    }
    stmt = text("""
        INSERT INTO wireless_metrics
        (time, device_id, tenant_id, interface, client_count, avg_signal, ccq, frequency)
        VALUES
        (:time, :device_id, :tenant_id, :interface,
         :client_count, :avg_signal, :ccq, :frequency)
    """)
    for wif in wireless:
        params = dict(common)
        params.update(
            interface=wif.get("interface"),
            client_count=wif.get("client_count"),
            avg_signal=wif.get("avg_signal"),
            ccq=wif.get("ccq"),
            frequency=wif.get("frequency"),
        )
        await session.execute(stmt, params)
# =============================================================================
# MAIN MESSAGE HANDLER
# =============================================================================
async def on_device_metrics(msg) -> None:
    """Handle a device.metrics event published by the Go poller.

    Dispatches to the appropriate insert handler based on the 'type' field:
    - "health"     → _insert_health_metrics + update devices
    - "interfaces" → _insert_interface_metrics
    - "wireless"   → _insert_wireless_metrics

    On success, acknowledges the message. On error, NAKs so NATS can redeliver.
    Malformed or unknown events are ACKed (not NAKed) so they are not
    redelivered forever.
    """
    try:
        data = json.loads(msg.data)
        metric_type = data.get("type")
        device_id = data.get("device_id")
        if not metric_type or not device_id:
            # Required routing fields missing: drop the event permanently.
            logger.warning(
                "device.metrics event missing 'type' or 'device_id' — skipping"
            )
            await msg.ack()
            return
        # Admin session: metrics arrive for any tenant, so RLS must be bypassed.
        async with AdminAsyncSessionLocal() as session:
            if metric_type == "health":
                await _insert_health_metrics(session, data)
            elif metric_type == "interfaces":
                await _insert_interface_metrics(session, data)
            elif metric_type == "wireless":
                await _insert_wireless_metrics(session, data)
            else:
                logger.warning("Unknown metric type '%s' — skipping", metric_type)
                await msg.ack()
                return
            await session.commit()
        # Alert evaluation — non-fatal; metric write is the primary operation
        try:
            from app.services import alert_evaluator
            await alert_evaluator.evaluate(
                device_id=device_id,
                tenant_id=data.get("tenant_id", ""),
                metric_type=metric_type,
                data=data,
            )
        except Exception as eval_err:
            logger.warning("Alert evaluation failed for device %s: %s", device_id, eval_err)
        logger.debug(
            "device.metrics processed",
            extra={"device_id": device_id, "type": metric_type},
        )
        # ACK only after the commit succeeded — redelivery covers any earlier failure.
        await msg.ack()
    except Exception as exc:
        logger.error(
            "Failed to process device.metrics event: %s",
            exc,
            exc_info=True,
        )
        try:
            await msg.nak()
        except Exception:
            pass  # If NAK also fails, NATS will redeliver after ack_wait
# =============================================================================
# SUBSCRIPTION SETUP
# =============================================================================
async def _subscribe_with_retry(js: JetStreamContext) -> None:
    """Subscribe to device.metrics.> with durable consumer, retrying if stream not ready."""
    max_attempts = 6  # ~30 seconds at 5s intervals
    attempt = 0
    while attempt < max_attempts:
        attempt += 1
        try:
            await js.subscribe(
                "device.metrics.>",
                cb=on_device_metrics,
                durable="api-metrics-consumer",
                stream="DEVICE_EVENTS",
            )
        except Exception as exc:
            if attempt >= max_attempts:
                # Out of retries: degrade gracefully instead of blocking startup.
                logger.warning(
                    "NATS: giving up on device.metrics.> after %d attempts: %s — API will run without metrics ingestion",
                    max_attempts,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready for metrics (attempt %d/%d): %s — retrying in 5s",
                attempt,
                max_attempts,
                exc,
            )
            await asyncio.sleep(5)
        else:
            logger.info(
                "NATS: subscribed to device.metrics.> (durable: api-metrics-consumer)"
            )
            return
async def start_metrics_subscriber() -> Optional[NATSClient]:
    """Connect to NATS and start the device.metrics.> subscription.

    Uses a separate NATS connection from the status subscriber — simpler and
    NATS handles multiple connections per client efficiently.

    Returns the NATS connection (must be passed to stop_metrics_subscriber on shutdown).
    Raises on fatal connection errors after retry exhaustion.
    """
    global _metrics_client

    url = settings.NATS_URL
    logger.info("NATS metrics: connecting to %s", url)
    connection = await nats.connect(
        url,
        max_reconnect_attempts=-1,
        reconnect_time_wait=2,
        error_cb=_on_error,
        reconnected_cb=_on_reconnected,
        disconnected_cb=_on_disconnected,
    )
    logger.info("NATS metrics: connected to %s", url)

    await _subscribe_with_retry(connection.jetstream())

    _metrics_client = connection
    return connection
async def stop_metrics_subscriber(nc: Optional[NATSClient]) -> None:
    """Drain and close the metrics NATS connection gracefully."""
    if nc is None:
        return
    try:
        logger.info("NATS metrics: draining connection...")
        await nc.drain()
    except Exception as exc:
        # Drain failed — fall back to a hard close, ignoring secondary errors.
        logger.warning("NATS metrics: error during drain: %s", exc)
        try:
            await nc.close()
        except Exception:
            pass
    else:
        logger.info("NATS metrics: connection closed")
async def _on_error(exc: Exception) -> None:
    # NATS client callback: invoked on protocol/connection errors.
    logger.error("NATS metrics error: %s", exc)
async def _on_reconnected() -> None:
    # NATS client callback: connection re-established after a drop.
    logger.info("NATS metrics: reconnected")
async def _on_disconnected() -> None:
    # NATS client callback: connection lost; the client reconnects automatically
    # (connect() is configured with max_reconnect_attempts=-1).
    logger.warning("NATS metrics: disconnected")

View File

@@ -0,0 +1,231 @@
"""NATS JetStream subscriber for device status events from the Go poller.
Subscribes to device.status.> and updates device records in PostgreSQL.
This is a system-level process that needs to update devices across all tenants,
so it uses the admin engine (bypasses RLS).
"""
import asyncio
import json
import logging
import re
from datetime import datetime, timezone
from typing import Optional
import nats
from nats.js import JetStreamContext
from nats.aio.client import Client as NATSClient
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
logger = logging.getLogger(__name__)
_nats_client: Optional[NATSClient] = None
# Regex for RouterOS uptime strings like "42d14h23m15s", "14h23m15s", "23m15s", "3w2d"
_UPTIME_RE = re.compile(r"(?:(\d+)w)?(?:(\d+)d)?(?:(\d+)h)?(?:(\d+)m)?(?:(\d+)s)?")
def _parse_uptime(raw: str) -> int | None:
"""Parse a RouterOS uptime string into total seconds."""
if not raw:
return None
m = _UPTIME_RE.fullmatch(raw)
if not m:
return None
weeks = int(m.group(1) or 0)
days = int(m.group(2) or 0)
hours = int(m.group(3) or 0)
minutes = int(m.group(4) or 0)
seconds = int(m.group(5) or 0)
total = weeks * 604800 + days * 86400 + hours * 3600 + minutes * 60 + seconds
return total if total > 0 else None
async def on_device_status(msg) -> None:
    """Handle a device.status event published by the Go poller.

    Payload (JSON):
        device_id (str) — UUID of the device
        tenant_id (str) — UUID of the owning tenant
        status (str) — "online" or "offline"
        routeros_version (str | None) — e.g. "7.16.2"
        major_version (int | None) — e.g. 7
        board_name (str | None) — e.g. "RB4011iGS+5HacQ2HnD"
        last_seen (str | None) — ISO-8601 timestamp

    ACKs on success and on malformed payloads (so they are not redelivered
    forever); NAKs on processing errors so NATS redelivers after ack_wait.
    """
    try:
        data = json.loads(msg.data)
        device_id = data.get("device_id")
        status = data.get("status")
        routeros_version = data.get("routeros_version")
        major_version = data.get("major_version")
        board_name = data.get("board_name")
        last_seen_raw = data.get("last_seen")
        # `or None` collapses empty strings to NULL so COALESCE below keeps
        # the previously stored DB value instead of overwriting it with "".
        serial_number = data.get("serial_number") or None
        firmware_version = data.get("firmware_version") or None
        uptime_seconds = _parse_uptime(data.get("uptime", ""))
        if not device_id or not status:
            logger.warning("Received device.status event with missing device_id or status — skipping")
            await msg.ack()
            return
        # Parse timestamp in Python — asyncpg needs datetime objects, not strings
        last_seen_dt = None
        if last_seen_raw:
            try:
                last_seen_dt = datetime.fromisoformat(last_seen_raw.replace("Z", "+00:00"))
            except (ValueError, AttributeError):
                # Unparseable timestamp: fall back to "now" rather than dropping the event.
                last_seen_dt = datetime.now(timezone.utc)
        # Admin session: status events arrive for any tenant, so RLS must be bypassed.
        async with AdminAsyncSessionLocal() as session:
            # COALESCE keeps the existing column value whenever the poller omits a field.
            await session.execute(
                text(
                    """
                    UPDATE devices SET
                        status = :status,
                        routeros_version = COALESCE(:routeros_version, routeros_version),
                        routeros_major_version = COALESCE(:major_version, routeros_major_version),
                        model = COALESCE(:board_name, model),
                        serial_number = COALESCE(:serial_number, serial_number),
                        firmware_version = COALESCE(:firmware_version, firmware_version),
                        uptime_seconds = COALESCE(:uptime_seconds, uptime_seconds),
                        last_seen = COALESCE(:last_seen, last_seen),
                        updated_at = NOW()
                    WHERE id = CAST(:device_id AS uuid)
                    """
                ),
                {
                    "status": status,
                    "routeros_version": routeros_version,
                    "major_version": major_version,
                    "board_name": board_name,
                    "serial_number": serial_number,
                    "firmware_version": firmware_version,
                    "uptime_seconds": uptime_seconds,
                    "last_seen": last_seen_dt,
                    "device_id": device_id,
                },
            )
            await session.commit()
        # Alert evaluation for offline/online status changes — non-fatal
        try:
            from app.services import alert_evaluator
            if status == "offline":
                await alert_evaluator.evaluate_offline(device_id, data.get("tenant_id", ""))
            elif status == "online":
                await alert_evaluator.evaluate_online(device_id, data.get("tenant_id", ""))
        except Exception as e:
            logger.warning("Alert evaluation failed for device %s status=%s: %s", device_id, status, e)
        logger.info(
            "Device status updated",
            extra={
                "device_id": device_id,
                "status": status,
                "routeros_version": routeros_version,
            },
        )
        # ACK only after the commit succeeded — redelivery covers any earlier failure.
        await msg.ack()
    except Exception as exc:
        logger.error(
            "Failed to process device.status event: %s",
            exc,
            exc_info=True,
        )
        try:
            await msg.nak()
        except Exception:
            pass  # If NAK also fails, NATS will redeliver after ack_wait
async def _subscribe_with_retry(js: JetStreamContext) -> None:
    """Subscribe to device.status.> with durable consumer, retrying if stream not ready."""
    max_attempts = 6  # ~30 seconds at 5s intervals
    attempt = 0
    while attempt < max_attempts:
        attempt += 1
        try:
            await js.subscribe(
                "device.status.>",
                cb=on_device_status,
                durable="api-status-consumer",
                stream="DEVICE_EVENTS",
            )
        except Exception as exc:
            if attempt >= max_attempts:
                # Out of retries: degrade gracefully instead of blocking startup.
                logger.warning(
                    "NATS: giving up on device.status.> after %d attempts: %s — API will run without real-time status updates",
                    max_attempts,
                    exc,
                )
                return
            logger.warning(
                "NATS: stream DEVICE_EVENTS not ready (attempt %d/%d): %s — retrying in 5s",
                attempt,
                max_attempts,
                exc,
            )
            await asyncio.sleep(5)
        else:
            logger.info("NATS: subscribed to device.status.> (durable: api-status-consumer)")
            return
async def start_nats_subscriber() -> Optional[NATSClient]:
    """Connect to NATS and start the device.status.> subscription.

    Returns the NATS connection (must be passed to stop_nats_subscriber on shutdown).
    Raises on fatal connection errors after retry exhaustion.
    """
    global _nats_client

    url = settings.NATS_URL
    logger.info("NATS: connecting to %s", url)
    connection = await nats.connect(
        url,
        max_reconnect_attempts=-1,  # reconnect forever (pod-to-pod transient failures)
        reconnect_time_wait=2,
        error_cb=_on_error,
        reconnected_cb=_on_reconnected,
        disconnected_cb=_on_disconnected,
    )
    logger.info("NATS: connected to %s", url)

    await _subscribe_with_retry(connection.jetstream())

    _nats_client = connection
    return connection
async def stop_nats_subscriber(nc: Optional[NATSClient]) -> None:
    """Drain and close the NATS connection gracefully."""
    if nc is None:
        return
    try:
        logger.info("NATS: draining connection...")
        await nc.drain()
    except Exception as exc:
        # Drain failed — fall back to a hard close, ignoring secondary errors.
        logger.warning("NATS: error during drain: %s", exc)
        try:
            await nc.close()
        except Exception:
            pass
    else:
        logger.info("NATS: connection closed")
async def _on_error(exc: Exception) -> None:
    # NATS client callback: invoked on protocol/connection errors.
    logger.error("NATS error: %s", exc)
async def _on_reconnected() -> None:
    # NATS client callback: connection re-established after a drop.
    logger.info("NATS: reconnected")
async def _on_disconnected() -> None:
    # NATS client callback: connection lost; the client reconnects automatically
    # (connect() is configured with max_reconnect_attempts=-1).
    logger.warning("NATS: disconnected")

View File

@@ -0,0 +1,256 @@
"""Email and webhook notification delivery for alert events.
Best-effort delivery: failures are logged but never raised.
Each dispatch is wrapped in try/except so one failing channel
doesn't prevent delivery to other channels.
"""
import logging
from typing import Any
import httpx
logger = logging.getLogger(__name__)
async def dispatch_notifications(
    alert_event: dict[str, Any],
    channels: list[dict[str, Any]],
    device_hostname: str,
) -> None:
    """Send notifications for an alert event to all provided channels.

    Best-effort: a failure on one channel is logged and never blocks
    delivery to the remaining channels.

    Args:
        alert_event: Dict with alert event fields (status, severity, metric, etc.)
        channels: List of notification channel dicts
        device_hostname: Human-readable device name for messages
    """
    for channel in channels:
        try:
            ctype = channel["channel_type"]
            if ctype == "email":
                await _send_email(channel, alert_event, device_hostname)
            elif ctype == "webhook":
                await _send_webhook(channel, alert_event, device_hostname)
            elif ctype == "slack":
                await _send_slack(channel, alert_event, device_hostname)
            else:
                logger.warning("Unknown channel type: %s", ctype)
        except Exception as e:
            logger.warning(
                "Notification delivery failed for channel %s (%s): %s",
                channel.get("name"), channel.get("channel_type"), e,
            )
async def _send_email(channel: dict, alert_event: dict, device_hostname: str) -> None:
    """Send alert notification email using per-channel SMTP config.

    Renders HTML and plain-text bodies, decrypts the channel's SMTP password
    (OpenBao Transit first, legacy Fernet as fallback), then delivers via
    the shared email service.

    Args:
        channel: Notification channel dict (SMTP settings, encrypted password
            fields, from/to addresses, tenant_id).
        alert_event: Alert event fields (severity, status, metric, value, ...).
        device_hostname: Human-readable device name for the message.
    """
    from app.services.email_service import SMTPConfig, send_email
    severity = alert_event.get("severity", "warning")
    status = alert_event.get("status", "firing")
    # Field names differ between live alert events and test events; accept both.
    rule_name = alert_event.get("rule_name") or alert_event.get("message", "Unknown Rule")
    metric = alert_event.get("metric_name") or alert_event.get("metric", "")
    value = alert_event.get("current_value") or alert_event.get("value", "")
    threshold = alert_event.get("threshold", "")
    severity_colors = {
        "critical": "#ef4444",
        "warning": "#f59e0b",
        "info": "#38bdf8",
    }
    color = severity_colors.get(severity, "#38bdf8")
    status_label = "RESOLVED" if status == "resolved" else "FIRING"
    # NOTE(review): event values are interpolated into the HTML unescaped —
    # fine for internally generated alert data, but confirm rule names cannot
    # contain user-controlled markup.
    html = f"""
    <div style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; max-width: 600px; margin: 0 auto;">
      <div style="background: {color}; padding: 16px 24px; border-radius: 8px 8px 0 0;">
        <h2 style="color: #fff; margin: 0;">[{status_label}] {rule_name}</h2>
      </div>
      <div style="background: #1e293b; padding: 24px; border-radius: 0 0 8px 8px; color: #e2e8f0;">
        <table style="width: 100%; border-collapse: collapse;">
          <tr><td style="padding: 8px 0; color: #94a3b8;">Device</td><td style="padding: 8px 0;">{device_hostname}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Severity</td><td style="padding: 8px 0;">{severity.upper()}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Metric</td><td style="padding: 8px 0;">{metric}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Value</td><td style="padding: 8px 0;">{value}</td></tr>
          <tr><td style="padding: 8px 0; color: #94a3b8;">Threshold</td><td style="padding: 8px 0;">{threshold}</td></tr>
        </table>
        <p style="color: #64748b; font-size: 12px; margin-top: 24px;">
          TOD — Fleet Management for MikroTik RouterOS
        </p>
      </div>
    </div>
    """
    plain = (
        f"[{status_label}] {rule_name}\n\n"
        f"Device: {device_hostname}\n"
        f"Severity: {severity}\n"
        f"Metric: {metric}\n"
        f"Value: {value}\n"
        f"Threshold: {threshold}\n"
    )
    # Decrypt SMTP password (Transit first, then legacy Fernet)
    smtp_password = None
    transit_cipher = channel.get("smtp_password_transit")
    legacy_cipher = channel.get("smtp_password")
    tenant_id = channel.get("tenant_id")
    if transit_cipher and tenant_id:
        try:
            from app.services.kms_service import decrypt_transit
            smtp_password = await decrypt_transit(transit_cipher, tenant_id)
        except Exception:
            # Fall through to the legacy path below rather than failing delivery.
            logger.warning("Transit decryption failed for channel %s, trying legacy", channel.get("id"))
    if not smtp_password and legacy_cipher:
        try:
            from app.config import settings as app_settings
            from cryptography.fernet import Fernet
            # Postgres bytea columns may surface as memoryview; Fernet wants bytes.
            raw = bytes(legacy_cipher) if isinstance(legacy_cipher, memoryview) else legacy_cipher
            f = Fernet(app_settings.CREDENTIAL_ENCRYPTION_KEY.encode())
            smtp_password = f.decrypt(raw).decode()
        except Exception:
            # Proceed with smtp_password=None; the SMTP server may allow anonymous send.
            logger.warning("Legacy decryption failed for channel %s", channel.get("id"))
    config = SMTPConfig(
        host=channel.get("smtp_host", "localhost"),
        port=channel.get("smtp_port", 587),
        user=channel.get("smtp_user"),
        password=smtp_password,
        use_tls=channel.get("smtp_use_tls", False),
        from_address=channel.get("from_address") or "alerts@mikrotik-portal.local",
    )
    to = channel.get("to_address")
    # NOTE(review): subject joins rule name and hostname with no separator —
    # looks like a separator may have been lost; confirm intended format.
    subject = f"[TOD {status_label}] {rule_name}{device_hostname}"
    await send_email(to, subject, html, plain, config)
async def _send_webhook(
    channel: dict[str, Any],
    alert_event: dict[str, Any],
    device_hostname: str,
) -> None:
    """Send alert notification to a webhook URL (Slack-compatible JSON)."""
    webhook_url = channel.get("webhook_url", "")
    if not webhook_url:
        logger.warning("Webhook channel %s has no URL configured", channel.get("name"))
        return

    severity = alert_event.get("severity", "info")
    status = alert_event.get("status", "firing")
    message_text = alert_event.get("message", "")
    payload = {
        "alert_name": message_text,
        "severity": severity,
        "status": status,
        "device": device_hostname,
        "device_id": alert_event.get("device_id"),
        "metric": alert_event.get("metric"),
        "value": alert_event.get("value"),
        "threshold": alert_event.get("threshold"),
        "timestamp": str(alert_event.get("fired_at", "")),
        "text": f"[{severity.upper()}] {device_hostname}: {message_text}",
    }

    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.post(webhook_url, json=payload)
        logger.info(
            "Webhook notification sent to %s — status %d",
            webhook_url, response.status_code,
        )
async def _send_slack(
    channel: dict[str, Any],
    alert_event: dict[str, Any],
    device_hostname: str,
) -> None:
    """Send alert notification to Slack via incoming webhook with Block Kit formatting."""
    slack_url = channel.get("slack_webhook_url", "")
    if not slack_url:
        logger.warning("Slack channel %s has no webhook URL configured", channel.get("name"))
        return

    severity = alert_event.get("severity", "info").upper()
    status = alert_event.get("status", "firing")
    metric = alert_event.get("metric", "unknown")
    message_text = alert_event.get("message", "")
    value = alert_event.get("value")
    threshold = alert_event.get("threshold")

    color_map = {"CRITICAL": "#dc2626", "WARNING": "#f59e0b", "INFO": "#3b82f6"}
    color = color_map.get(severity, "#6b7280")
    status_label = "RESOLVED" if status == "resolved" else status

    blocks = [
        {
            "type": "header",
            "text": {"type": "plain_text", "text": f"{'' if status == 'resolved' else '🚨'} [{severity}] {status_label.upper()}"},
        },
        {
            "type": "section",
            "fields": [
                {"type": "mrkdwn", "text": f"*Device:*\n{device_hostname}"},
                {"type": "mrkdwn", "text": f"*Metric:*\n{metric}"},
            ],
        },
    ]

    # Value/threshold section only when at least one is present.
    numeric_fields = []
    if value is not None:
        numeric_fields.append({"type": "mrkdwn", "text": f"*Value:*\n{value}"})
    if threshold is not None:
        numeric_fields.append({"type": "mrkdwn", "text": f"*Threshold:*\n{threshold}"})
    if numeric_fields:
        blocks.append({"type": "section", "fields": numeric_fields})

    if message_text:
        blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": f"*Message:*\n{message_text}"}})
    blocks.append({"type": "context", "elements": [{"type": "mrkdwn", "text": "TOD Alert System"}]})

    payload = {"attachments": [{"color": color, "blocks": blocks}]}
    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.post(slack_url, json=payload)
        logger.info("Slack notification sent — status %d", response.status_code)
async def send_test_notification(channel: dict[str, Any]) -> bool:
    """Send a test notification through a channel to verify configuration.

    Args:
        channel: Notification channel dict with all config fields

    Returns:
        True on success

    Raises:
        ValueError: If the channel type is not recognized.
        Exception: On delivery failure (caller handles).
    """
    test_event = {
        "status": "test",
        "severity": "info",
        "metric": "test",
        "value": None,
        "threshold": None,
        "message": "Test notification from TOD",
        "device_id": "00000000-0000-0000-0000-000000000000",
        "fired_at": "",
    }
    ctype = channel["channel_type"]
    if ctype == "email":
        await _send_email(channel, test_event, "Test Device")
    elif ctype == "webhook":
        await _send_webhook(channel, test_event, "Test Device")
    elif ctype == "slack":
        await _send_slack(channel, test_event, "Test Device")
    else:
        raise ValueError(f"Unknown channel type: {ctype}")
    return True

View File

@@ -0,0 +1,174 @@
"""
OpenBao Transit secrets engine client for per-tenant envelope encryption.
Provides encrypt/decrypt operations via OpenBao's HTTP API. Each tenant gets
a dedicated Transit key (tenant_{uuid}) for AES-256-GCM encryption. The key
material never leaves OpenBao -- the application only sees ciphertext.
Ciphertext format: "vault:v1:base64..." (compatible with Vault Transit format)
"""
import base64
import logging
from typing import Optional
import httpx
from app.config import settings
logger = logging.getLogger(__name__)
class OpenBaoTransitService:
"""Async client for OpenBao Transit secrets engine."""
    def __init__(self, addr: str | None = None, token: str | None = None):
        """Store connection settings, falling back to app config when not given.

        Args:
            addr: OpenBao base URL; defaults to settings.OPENBAO_ADDR.
            token: OpenBao auth token; defaults to settings.OPENBAO_TOKEN.
        """
        self.addr = addr or settings.OPENBAO_ADDR
        self.token = token or settings.OPENBAO_TOKEN
        # Shared HTTP client, created lazily on first use (see _get_client).
        self._client: httpx.AsyncClient | None = None
async def _get_client(self) -> httpx.AsyncClient:
if self._client is None or self._client.is_closed:
self._client = httpx.AsyncClient(
base_url=self.addr,
headers={"X-Vault-Token": self.token},
timeout=5.0,
)
return self._client
async def close(self) -> None:
if self._client and not self._client.is_closed:
await self._client.aclose()
self._client = None
async def create_tenant_key(self, tenant_id: str) -> None:
"""Create Transit encryption keys for a tenant (credential + data). Idempotent."""
client = await self._get_client()
# Credential key: tenant_{uuid}
key_name = f"tenant_{tenant_id}"
resp = await client.post(
f"/v1/transit/keys/{key_name}",
json={"type": "aes256-gcm96"},
)
if resp.status_code not in (200, 204):
resp.raise_for_status()
logger.info("OpenBao Transit key ensured", extra={"key_name": key_name})
# Data key: tenant_{uuid}_data (Phase 30)
await self.create_tenant_data_key(tenant_id)
async def encrypt(self, tenant_id: str, plaintext: bytes) -> str:
"""Encrypt plaintext via Transit engine. Returns ciphertext string."""
client = await self._get_client()
key_name = f"tenant_{tenant_id}"
resp = await client.post(
f"/v1/transit/encrypt/{key_name}",
json={"plaintext": base64.b64encode(plaintext).decode()},
)
resp.raise_for_status()
ciphertext = resp.json()["data"]["ciphertext"]
return ciphertext # "vault:v1:..."
async def decrypt(self, tenant_id: str, ciphertext: str) -> bytes:
"""Decrypt Transit ciphertext. Returns plaintext bytes."""
client = await self._get_client()
key_name = f"tenant_{tenant_id}"
resp = await client.post(
f"/v1/transit/decrypt/{key_name}",
json={"ciphertext": ciphertext},
)
resp.raise_for_status()
plaintext_b64 = resp.json()["data"]["plaintext"]
return base64.b64decode(plaintext_b64)
async def key_exists(self, tenant_id: str) -> bool:
"""Check if a Transit key exists for a tenant."""
client = await self._get_client()
key_name = f"tenant_{tenant_id}"
resp = await client.get(f"/v1/transit/keys/{key_name}")
return resp.status_code == 200
# ------------------------------------------------------------------
# Data encryption keys (tenant_{uuid}_data) — Phase 30
# ------------------------------------------------------------------
async def create_tenant_data_key(self, tenant_id: str) -> None:
"""Create a Transit data encryption key for a tenant. Idempotent.
Data keys use the suffix '_data' to separate them from credential keys.
Key naming: tenant_{uuid}_data (vs tenant_{uuid} for credentials).
"""
client = await self._get_client()
key_name = f"tenant_{tenant_id}_data"
resp = await client.post(
f"/v1/transit/keys/{key_name}",
json={"type": "aes256-gcm96"},
)
if resp.status_code not in (200, 204):
resp.raise_for_status()
logger.info("OpenBao Transit data key ensured", extra={"key_name": key_name})
async def ensure_tenant_data_key(self, tenant_id: str) -> None:
"""Ensure a data encryption key exists for a tenant. Idempotent.
Checks existence first and creates if missing. Safe to call on every
encrypt operation (fast path: single GET to check existence).
"""
client = await self._get_client()
key_name = f"tenant_{tenant_id}_data"
resp = await client.get(f"/v1/transit/keys/{key_name}")
if resp.status_code != 200:
await self.create_tenant_data_key(tenant_id)
async def encrypt_data(self, tenant_id: str, plaintext: bytes) -> str:
"""Encrypt data via Transit using per-tenant data key.
Uses the tenant_{uuid}_data key (separate from credential key).
Args:
tenant_id: Tenant UUID string.
plaintext: Raw bytes to encrypt.
Returns:
Transit ciphertext string (vault:v1:...).
"""
client = await self._get_client()
key_name = f"tenant_{tenant_id}_data"
resp = await client.post(
f"/v1/transit/encrypt/{key_name}",
json={"plaintext": base64.b64encode(plaintext).decode()},
)
resp.raise_for_status()
return resp.json()["data"]["ciphertext"]
async def decrypt_data(self, tenant_id: str, ciphertext: str) -> bytes:
"""Decrypt Transit data ciphertext using per-tenant data key.
Args:
tenant_id: Tenant UUID string.
ciphertext: Transit ciphertext (vault:v1:...).
Returns:
Decrypted plaintext bytes.
"""
client = await self._get_client()
key_name = f"tenant_{tenant_id}_data"
resp = await client.post(
f"/v1/transit/decrypt/{key_name}",
json={"ciphertext": ciphertext},
)
resp.raise_for_status()
plaintext_b64 = resp.json()["data"]["plaintext"]
return base64.b64decode(plaintext_b64)
# Module-level singleton
_openbao_service: Optional[OpenBaoTransitService] = None


def get_openbao_service() -> OpenBaoTransitService:
    """Lazily construct and return the module-level Transit service singleton."""
    global _openbao_service
    service = _openbao_service
    if service is None:
        service = OpenBaoTransitService()
        _openbao_service = service
    return service

View File

@@ -0,0 +1,141 @@
"""NATS subscribers for push rollback (auto) and push alert (manual).
- config.push.rollback.> -> auto-restore for template pushes
- config.push.alert.> -> create alert for editor pushes
"""
import json
import logging
from typing import Any, Optional
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.models.alert import AlertEvent
from app.services import restore_service
logger = logging.getLogger(__name__)
_nc: Optional[Any] = None
async def _create_push_alert(device_id: str, tenant_id: str, push_type: str) -> None:
    """Persist a critical 'firing' alert for a device that dropped offline post-push."""
    message = f"Device went offline after config {push_type} — rollback available"
    async with AdminAsyncSessionLocal() as session:
        event = AlertEvent(
            device_id=device_id,
            tenant_id=tenant_id,
            status="firing",
            severity="critical",
            message=message,
        )
        session.add(event)
        await session.commit()
    logger.info("Created push alert for device %s (type=%s)", device_id, push_type)
async def handle_push_rollback(event: dict) -> None:
    """Auto-rollback: restore device to pre-push config.

    Expects the event payload to carry ``device_id``, ``tenant_id`` and
    ``pre_push_commit_sha``; events missing any field are logged and dropped.
    If the restore itself fails, a critical push alert is created instead so
    an operator can intervene.
    """
    device_id = event.get("device_id")
    tenant_id = event.get("tenant_id")
    commit_sha = event.get("pre_push_commit_sha")
    if not all([device_id, tenant_id, commit_sha]):
        logger.warning("Push rollback event missing fields: %s", event)
        return
    logger.warning(
        "AUTO-ROLLBACK: Device %s offline after template push, restoring to %s",
        device_id,
        commit_sha,
    )
    try:
        async with AdminAsyncSessionLocal() as session:
            result = await restore_service.restore_config(
                device_id=device_id,
                tenant_id=tenant_id,
                commit_sha=commit_sha,
                db_session=session,
            )
            await session.commit()
        logger.info(
            "Auto-rollback result for device %s: %s",
            device_id,
            result.get("status"),
        )
    except Exception:
        # FIX: logger.exception preserves the full traceback; logger.error
        # with just str(e) discarded it, making rollback failures hard to debug.
        logger.exception("Auto-rollback failed for device %s", device_id)
        await _create_push_alert(device_id, tenant_id, "template (auto-rollback failed)")
async def handle_push_alert(event: dict) -> None:
    """Create a notification alert for a device offline after an editor push."""
    device_id = event.get("device_id")
    tenant_id = event.get("tenant_id")
    if not (device_id and tenant_id):
        logger.warning("Push alert event missing fields: %s", event)
        return
    await _create_push_alert(device_id, tenant_id, event.get("push_type", "editor"))
async def _on_rollback_message(msg) -> None:
    """JetStream callback for config.push.rollback.> — decode, dispatch, ack/nak."""
    try:
        payload = json.loads(msg.data.decode())
        await handle_push_rollback(payload)
        await msg.ack()
    except Exception as exc:
        logger.error("Error handling rollback message: %s", exc)
        await msg.nak()
async def _on_alert_message(msg) -> None:
    """JetStream callback for config.push.alert.> — decode, dispatch, ack/nak."""
    try:
        payload = json.loads(msg.data.decode())
        await handle_push_alert(payload)
        await msg.ack()
    except Exception as exc:
        logger.error("Error handling push alert message: %s", exc)
        await msg.nak()
async def start_push_rollback_subscriber() -> Optional[Any]:
    """Connect to NATS JetStream and attach the rollback and alert consumers.

    Returns the NATS connection on success, or None on failure (the caller's
    startup must not be blocked by an unreachable broker).
    """
    import nats

    global _nc
    try:
        logger.info("NATS push-rollback: connecting to %s", settings.NATS_URL)
        _nc = await nats.connect(settings.NATS_URL)
        js = _nc.jetstream()
        # Both consumers live on the DEVICE_EVENTS stream with manual acks.
        consumers = (
            ("config.push.rollback.>", _on_rollback_message, "api-push-rollback-consumer"),
            ("config.push.alert.>", _on_alert_message, "api-push-alert-consumer"),
        )
        for subject, callback, durable_name in consumers:
            await js.subscribe(
                subject,
                cb=callback,
                durable=durable_name,
                stream="DEVICE_EVENTS",
                manual_ack=True,
            )
        logger.info("Push rollback/alert subscriber started")
        return _nc
    except Exception as exc:
        logger.error("Failed to start push rollback subscriber: %s", exc)
        return None
async def stop_push_rollback_subscriber() -> None:
    """Drain and release the NATS connection, if one was established."""
    global _nc
    if _nc is None:
        return
    await _nc.drain()
    _nc = None

View File

@@ -0,0 +1,70 @@
"""Track recent config pushes in Redis for poller-aware rollback.
When a device goes offline shortly after a push, the poller checks these
keys and triggers rollback (template/restore) or alert (editor).
Redis key format: push:recent:{device_id}
TTL: 300 seconds (5 minutes)
"""
import json
import logging
from typing import Optional
import redis.asyncio as redis
from app.config import settings
logger = logging.getLogger(__name__)
PUSH_TTL_SECONDS = 300 # 5 minutes
_redis: Optional[redis.Redis] = None
async def _get_redis() -> redis.Redis:
    """Lazily create and return the shared async Redis client."""
    global _redis
    client = _redis
    if client is None:
        client = redis.from_url(settings.REDIS_URL)
        _redis = client
    return client
async def record_push(
    device_id: str,
    tenant_id: str,
    push_type: str,
    push_operation_id: str = "",
    pre_push_commit_sha: str = "",
) -> None:
    """Record a recent config push in Redis (TTL'd) for the poller to inspect.

    Args:
        device_id: UUID of the device.
        tenant_id: UUID of the tenant.
        push_type: 'template' (auto-rollback) or 'editor' (alert only) or 'restore'.
        push_operation_id: ID of the ConfigPushOperation row.
        pre_push_commit_sha: Git SHA of the pre-push backup (for rollback).
    """
    payload = {
        "device_id": device_id,
        "tenant_id": tenant_id,
        "push_type": push_type,
        "push_operation_id": push_operation_id,
        "pre_push_commit_sha": pre_push_commit_sha,
    }
    client = await _get_redis()
    # The key auto-expires, so a device that stays healthy needs no cleanup.
    await client.set(f"push:recent:{device_id}", json.dumps(payload), ex=PUSH_TTL_SECONDS)
    logger.debug(
        "Recorded push for device %s (type=%s, TTL=%ds)",
        device_id,
        push_type,
        PUSH_TTL_SECONDS,
    )
async def clear_push(device_id: str) -> None:
    """Drop the push-tracking key for a device (e.g. after a verified commit)."""
    client = await _get_redis()
    key = f"push:recent:{device_id}"
    await client.delete(key)
    logger.debug("Cleared push tracking for device %s", device_id)

View File

@@ -0,0 +1,572 @@
"""Report generation service.
Generates PDF (via Jinja2 + weasyprint) and CSV reports for:
- Device inventory
- Metrics summary
- Alert history
- Change log (audit_logs if available, else config_backups fallback)
Phase 30 NOTE: Reports are currently ephemeral (generated on-demand per request,
never stored at rest). DATAENC-03 requires "report content is encrypted before
storage." Since no report storage exists yet, encryption will be applied when
report caching/storage is added. The generation pipeline is Transit-ready --
wrap the file_bytes with encrypt_data_transit() before any future INSERT.
"""
import csv
import io
import os
import time
from datetime import datetime
from typing import Any, Optional
from uuid import UUID
import structlog
from jinja2 import Environment, FileSystemLoader
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
logger = structlog.get_logger(__name__)
# Jinja2 environment pointing at the templates directory
_TEMPLATE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "templates")
_jinja_env = Environment(
loader=FileSystemLoader(_TEMPLATE_DIR),
autoescape=True,
)
async def generate_report(
    db: AsyncSession,
    tenant_id: UUID,
    report_type: str,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
    fmt: str = "pdf",
) -> tuple[bytes, str, str]:
    """Generate a report and return (file_bytes, content_type, filename).

    Args:
        db: RLS-enforced async session (tenant context already set).
        tenant_id: Tenant UUID for scoping.
        report_type: One of device_inventory, metrics_summary, alert_history, change_log.
        date_from: Start date for time-ranged reports.
        date_to: End date for time-ranged reports.
        fmt: Output format -- "pdf" or "csv".

    Returns:
        Tuple of (file_bytes, content_type, filename).

    Raises:
        KeyError: If report_type is not one of the supported handlers.
    """
    from datetime import timezone  # local import keeps the module header untouched

    start = time.monotonic()
    # Fetch tenant name for the header
    tenant_name = await _get_tenant_name(db, tenant_id)
    # Dispatch to the appropriate handler
    handlers = {
        "device_inventory": _device_inventory,
        "metrics_summary": _metrics_summary,
        "alert_history": _alert_history,
        "change_log": _change_log,
    }
    handler = handlers[report_type]
    template_data = await handler(db, tenant_id, date_from, date_to)
    # FIX: datetime.utcnow() is deprecated (Python 3.12+) and was called twice,
    # so the header text and the filename could disagree across a minute
    # boundary. Capture one aware UTC timestamp and derive both from it.
    now_utc = datetime.now(timezone.utc)
    base_context = {
        "tenant_name": tenant_name,
        "generated_at": now_utc.strftime("%Y-%m-%d %H:%M UTC"),
    }
    timestamp_str = now_utc.strftime("%Y%m%d_%H%M%S")
    if fmt == "csv":
        file_bytes = _render_csv(report_type, template_data)
        content_type = "text/csv; charset=utf-8"
        filename = f"{report_type}_{timestamp_str}.csv"
    else:
        file_bytes = _render_pdf(report_type, {**base_context, **template_data})
        content_type = "application/pdf"
        filename = f"{report_type}_{timestamp_str}.pdf"
    elapsed = time.monotonic() - start
    logger.info(
        "report_generated",
        report_type=report_type,
        format=fmt,
        tenant_id=str(tenant_id),
        size_bytes=len(file_bytes),
        elapsed_seconds=round(elapsed, 2),
    )
    return file_bytes, content_type, filename
# ---------------------------------------------------------------------------
# Tenant name helper
# ---------------------------------------------------------------------------
async def _get_tenant_name(db: AsyncSession, tenant_id: UUID) -> str:
    """Look up the tenant's display name; fall back to a placeholder if absent."""
    query = text("SELECT name FROM tenants WHERE id = CAST(:tid AS uuid)")
    result = await db.execute(query, {"tid": str(tenant_id)})
    row = result.fetchone()
    if row is None:
        return "Unknown Tenant"
    return row[0]
# ---------------------------------------------------------------------------
# Report type handlers
# ---------------------------------------------------------------------------
async def _device_inventory(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Collect the full device list plus online/offline/unknown tallies.

    tenant_id and the date range are unused here: the RLS session scopes
    rows to the tenant, and inventory is a point-in-time snapshot.
    """
    result = await db.execute(
        text("""
            SELECT d.hostname, d.ip_address, d.model, d.routeros_version,
                   d.status, d.last_seen, d.uptime_seconds,
                   COALESCE(
                       (SELECT string_agg(dg.name, ', ')
                        FROM device_group_memberships dgm
                        JOIN device_groups dg ON dg.id = dgm.group_id
                        WHERE dgm.device_id = d.id),
                       ''
                   ) AS groups
            FROM devices d
            ORDER BY d.hostname ASC
        """)
    )
    devices: list[dict[str, Any]] = []
    counts = {"online": 0, "offline": 0, "unknown": 0}
    for hostname, ip, model, ros_version, status, last_seen, uptime, groups in result.fetchall():
        bucket = status if status in ("online", "offline") else "unknown"
        counts[bucket] += 1
        devices.append({
            "hostname": hostname,
            "ip_address": ip,
            "model": model,
            "routeros_version": ros_version,
            "status": status,
            "last_seen": last_seen.strftime("%Y-%m-%d %H:%M") if last_seen else None,
            "uptime": _format_uptime(uptime) if uptime else None,
            "groups": groups if groups else None,
        })
    return {
        "report_title": "Device Inventory",
        "devices": devices,
        "total_devices": len(devices),
        "online_count": counts["online"],
        "offline_count": counts["offline"],
        "unknown_count": counts["unknown"],
    }
async def _metrics_summary(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Per-device CPU / memory / disk / temperature aggregates over the window."""
    result = await db.execute(
        text("""
            SELECT d.hostname,
                   AVG(hm.cpu_load) AS avg_cpu,
                   MAX(hm.cpu_load) AS peak_cpu,
                   AVG(CASE WHEN hm.total_memory > 0
                       THEN 100.0 * (hm.total_memory - hm.free_memory) / hm.total_memory
                       END) AS avg_mem,
                   MAX(CASE WHEN hm.total_memory > 0
                       THEN 100.0 * (hm.total_memory - hm.free_memory) / hm.total_memory
                       END) AS peak_mem,
                   AVG(CASE WHEN hm.total_disk > 0
                       THEN 100.0 * (hm.total_disk - hm.free_disk) / hm.total_disk
                       END) AS avg_disk,
                   AVG(hm.temperature) AS avg_temp,
                   COUNT(*) AS data_points
            FROM health_metrics hm
            JOIN devices d ON d.id = hm.device_id
            WHERE hm.time >= :date_from
              AND hm.time <= :date_to
            GROUP BY d.id, d.hostname
            ORDER BY avg_cpu DESC NULLS LAST
        """),
        {"date_from": date_from, "date_to": date_to},
    )

    def as_float(value: Any) -> Optional[float]:
        # SQL aggregates may come back as Decimal; normalize (None stays None).
        return None if value is None else float(value)

    devices = [
        {
            "hostname": row[0],
            "avg_cpu": as_float(row[1]),
            "peak_cpu": as_float(row[2]),
            "avg_mem": as_float(row[3]),
            "peak_mem": as_float(row[4]),
            "avg_disk": as_float(row[5]),
            "avg_temp": as_float(row[6]),
            "data_points": row[7],
        }
        for row in result.fetchall()
    ]
    return {
        "report_title": "Metrics Summary",
        "devices": devices,
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
async def _alert_history(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Collect alert rows in the window plus severity tallies and MTTR."""
    result = await db.execute(
        text("""
            SELECT ae.fired_at, ae.resolved_at, ae.severity, ae.status,
                   ae.message, d.hostname,
                   EXTRACT(EPOCH FROM (ae.resolved_at - ae.fired_at)) AS duration_secs
            FROM alert_events ae
            LEFT JOIN devices d ON d.id = ae.device_id
            WHERE ae.fired_at >= :date_from
              AND ae.fired_at <= :date_to
            ORDER BY ae.fired_at DESC
        """),
        {"date_from": date_from, "date_to": date_to},
    )
    alerts: list[dict[str, Any]] = []
    counts = {"critical": 0, "warning": 0, "info": 0}
    resolved_durations: list[float] = []
    for row in result.fetchall():
        severity = row[2]
        # Anything that isn't critical/warning is bucketed as info.
        bucket = severity if severity in ("critical", "warning") else "info"
        counts[bucket] += 1
        duration_secs = float(row[6]) if row[6] is not None else None
        if duration_secs is not None:
            resolved_durations.append(duration_secs)
        alerts.append({
            "fired_at": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
            "hostname": row[5],
            "severity": severity,
            "status": row[3],
            "message": row[4],
            "duration": _format_duration(duration_secs) if duration_secs is not None else None,
        })
    # Mean time to resolution, computed over resolved alerts only.
    mttr_minutes = None
    mttr_display = None
    if resolved_durations:
        avg_secs = sum(resolved_durations) / len(resolved_durations)
        mttr_minutes = round(avg_secs / 60, 1)
        mttr_display = _format_duration(avg_secs)
    return {
        "report_title": "Alert History",
        "alerts": alerts,
        "total_alerts": len(alerts),
        "critical_count": counts["critical"],
        "warning_count": counts["warning"],
        "info_count": counts["info"],
        "mttr_minutes": mttr_minutes,
        "mttr_display": mttr_display,
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
async def _change_log(
    db: AsyncSession,
    tenant_id: UUID,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Gather change log data, preferring audit_logs over the backups fallback."""
    # audit_logs may not exist yet (migration 17-01 may not have run).
    if await _table_exists(db, "audit_logs"):
        return await _change_log_from_audit(db, date_from, date_to)
    return await _change_log_from_backups(db, date_from, date_to)
async def _table_exists(db: AsyncSession, table_name: str) -> bool:
    """Return True if a public-schema table with the given name exists."""
    query = text("""
        SELECT EXISTS (
            SELECT 1 FROM information_schema.tables
            WHERE table_schema = 'public' AND table_name = :table_name
        )
    """)
    result = await db.execute(query, {"table_name": table_name})
    return bool(result.scalar())
async def _change_log_from_audit(
    db: AsyncSession,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Build the change log from the audit_logs table."""
    result = await db.execute(
        text("""
            SELECT al.created_at, u.name AS user_name, al.action,
                   d.hostname, al.resource_type,
                   al.details
            FROM audit_logs al
            LEFT JOIN users u ON u.id = al.user_id
            LEFT JOIN devices d ON d.id = al.device_id
            WHERE al.created_at >= :date_from
              AND al.created_at <= :date_to
            ORDER BY al.created_at DESC
        """),
        {"date_from": date_from, "date_to": date_to},
    )
    entries = [
        {
            "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
            "user": row[1],
            "action": row[2],
            "device": row[3],
            # resource_type first, then the free-form details payload.
            "details": row[4] or row[5] or "",
        }
        for row in result.fetchall()
    ]
    return {
        "report_title": "Change Log",
        "entries": entries,
        "total_entries": len(entries),
        "data_source": "Audit Logs",
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
async def _change_log_from_backups(
    db: AsyncSession,
    date_from: Optional[datetime],
    date_to: Optional[datetime],
) -> dict[str, Any]:
    """Fallback change log built from config_backups and alert_events rows."""
    # Config backups as change events
    backup_result = await db.execute(
        text("""
            SELECT cb.created_at, 'system' AS user_name, 'config_backup' AS action,
                   d.hostname, cb.trigger_type AS details
            FROM config_backups cb
            JOIN devices d ON d.id = cb.device_id
            WHERE cb.created_at >= :date_from
              AND cb.created_at <= :date_to
        """),
        {"date_from": date_from, "date_to": date_to},
    )
    # Alert events as change events
    alert_result = await db.execute(
        text("""
            SELECT ae.fired_at, 'system' AS user_name,
                   ae.severity || '_alert' AS action,
                   d.hostname, ae.message AS details
            FROM alert_events ae
            LEFT JOIN devices d ON d.id = ae.device_id
            WHERE ae.fired_at >= :date_from
              AND ae.fired_at <= :date_to
        """),
        {"date_from": date_from, "date_to": date_to},
    )
    # Both result sets share a column layout, so one loop handles the merge.
    all_rows = list(backup_result.fetchall()) + list(alert_result.fetchall())
    entries = [
        {
            "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
            "user": row[1],
            "action": row[2],
            "device": row[3],
            "details": row[4] or "",
        }
        for row in all_rows
    ]
    # The timestamp format is lexicographically sortable, so a string sort works.
    entries.sort(key=lambda e: e["timestamp"], reverse=True)
    return {
        "report_title": "Change Log",
        "entries": entries,
        "total_entries": len(entries),
        "data_source": "Backups + Alerts",
        "date_from": date_from.strftime("%Y-%m-%d") if date_from else "",
        "date_to": date_to.strftime("%Y-%m-%d") if date_to else "",
    }
# ---------------------------------------------------------------------------
# Rendering helpers
# ---------------------------------------------------------------------------
def _render_pdf(report_type: str, context: dict[str, Any]) -> bytes:
    """Render the report's HTML template and convert it to PDF bytes."""
    import weasyprint  # deferred: heavy import, only needed for PDF output

    html_str = _jinja_env.get_template(f"reports/{report_type}.html").render(**context)
    return weasyprint.HTML(string=html_str).write_pdf()
def _render_csv(report_type: str, data: dict[str, Any]) -> bytes:
"""Render report data as CSV bytes."""
output = io.StringIO()
writer = csv.writer(output)
if report_type == "device_inventory":
writer.writerow([
"Hostname", "IP Address", "Model", "RouterOS Version",
"Status", "Last Seen", "Uptime", "Groups",
])
for d in data.get("devices", []):
writer.writerow([
d["hostname"], d["ip_address"], d["model"] or "",
d["routeros_version"] or "", d["status"],
d["last_seen"] or "", d["uptime"] or "",
d["groups"] or "",
])
elif report_type == "metrics_summary":
writer.writerow([
"Hostname", "Avg CPU %", "Peak CPU %", "Avg Memory %",
"Peak Memory %", "Avg Disk %", "Avg Temp", "Data Points",
])
for d in data.get("devices", []):
writer.writerow([
d["hostname"],
f"{d['avg_cpu']:.1f}" if d["avg_cpu"] is not None else "",
f"{d['peak_cpu']:.1f}" if d["peak_cpu"] is not None else "",
f"{d['avg_mem']:.1f}" if d["avg_mem"] is not None else "",
f"{d['peak_mem']:.1f}" if d["peak_mem"] is not None else "",
f"{d['avg_disk']:.1f}" if d["avg_disk"] is not None else "",
f"{d['avg_temp']:.1f}" if d["avg_temp"] is not None else "",
d["data_points"],
])
elif report_type == "alert_history":
writer.writerow([
"Timestamp", "Device", "Severity", "Message", "Status", "Duration",
])
for a in data.get("alerts", []):
writer.writerow([
a["fired_at"], a["hostname"] or "", a["severity"],
a["message"] or "", a["status"], a["duration"] or "",
])
elif report_type == "change_log":
writer.writerow([
"Timestamp", "User", "Action", "Device", "Details",
])
for e in data.get("entries", []):
writer.writerow([
e["timestamp"], e["user"] or "", e["action"],
e["device"] or "", e["details"] or "",
])
return output.getvalue().encode("utf-8")
# ---------------------------------------------------------------------------
# Formatting utilities
# ---------------------------------------------------------------------------
def _format_uptime(seconds: int) -> str:
"""Format uptime seconds as human-readable string."""
days = seconds // 86400
hours = (seconds % 86400) // 3600
minutes = (seconds % 3600) // 60
if days > 0:
return f"{days}d {hours}h {minutes}m"
elif hours > 0:
return f"{hours}h {minutes}m"
else:
return f"{minutes}m"
def _format_duration(seconds: float) -> str:
"""Format a duration in seconds as a human-readable string."""
if seconds < 60:
return f"{int(seconds)}s"
elif seconds < 3600:
return f"{int(seconds // 60)}m {int(seconds % 60)}s"
elif seconds < 86400:
hours = int(seconds // 3600)
mins = int((seconds % 3600) // 60)
return f"{hours}h {mins}m"
else:
days = int(seconds // 86400)
hours = int((seconds % 86400) // 3600)
return f"{days}d {hours}h"

View File

@@ -0,0 +1,599 @@
"""Two-phase config push with panic-revert safety for RouterOS devices.
This module implements the critical safety mechanism for config restoration:
Phase 1 — Push:
1. Pre-backup (mandatory) — snapshot current config before any changes
2. Install panic-revert RouterOS scheduler — auto-reverts if device becomes
unreachable (the scheduler fires after 90s and loads the pre-push backup)
3. Push the target config via SSH /import
Phase 2 — Verification (60s settle window):
4. Wait 60s for config to settle (scheduled processes restart, etc.)
5. Reachability check via asyncssh
6a. Reachable — remove panic-revert scheduler; mark operation committed
6b. Unreachable — RouterOS is auto-reverting; mark operation reverted
Pitfall 6 handling:
If the API pod restarts during the 60s window, the config_push_operations
row with status='pending_verification' serves as the recovery signal.
On startup, recover_stale_push_operations() resolves any stale rows.
Security policy:
known_hosts=None — RouterOS self-signed host keys; mirrors InsecureSkipVerify
used in the poller's TLS connection. See Pitfall 2 in 04-RESEARCH.md.
"""
import asyncio
import json
import logging
from datetime import datetime, timedelta, timezone
import asyncssh
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import set_tenant_context, AdminAsyncSessionLocal
from app.models.config_backup import ConfigPushOperation
from app.models.device import Device
from app.services import backup_service, git_store
from app.services.event_publisher import publish_event
from app.services.push_tracker import record_push, clear_push
logger = logging.getLogger(__name__)
# Name of the panic-revert scheduler installed on the RouterOS device
_PANIC_REVERT_SCHEDULER = "mikrotik-portal-panic-revert"
# Name of the pre-push binary backup saved on device flash
_PRE_PUSH_BACKUP = "portal-pre-push"
# Name of the RSC file used for /import on device
_RESTORE_RSC = "portal-restore.rsc"
async def _publish_push_progress(
    tenant_id: str,
    device_id: str,
    stage: str,
    message: str,
    push_op_id: str | None = None,
    error: str | None = None,
) -> None:
    """Emit a config-push progress event on NATS (fire-and-forget)."""
    payload: dict = {
        "event_type": "config_push",
        "tenant_id": tenant_id,
        "device_id": device_id,
        "stage": stage,
        "message": message,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "push_operation_id": push_op_id,
    }
    # Only attach the error field when a non-empty error string was given.
    if error:
        payload["error"] = error
    subject = f"config.push.{tenant_id}.{device_id}"
    await publish_event(subject, payload)
async def restore_config(
device_id: str,
tenant_id: str,
commit_sha: str,
db_session: AsyncSession,
) -> dict:
"""Restore a device config to a specific backup version via two-phase push.
Args:
device_id: Device UUID as string.
tenant_id: Tenant UUID as string.
commit_sha: Git commit SHA of the backup version to restore.
db_session: AsyncSession with RLS context already set (from API endpoint).
Returns:
{
"status": "committed" | "reverted" | "failed",
"message": str,
"pre_backup_sha": str,
}
Raises:
ValueError: If device not found or missing credentials.
Exception: On SSH failure during push phase (reverted status logged).
"""
loop = asyncio.get_event_loop()
# ------------------------------------------------------------------
# Step 1: Load device from DB and decrypt credentials
# ------------------------------------------------------------------
from sqlalchemy import select
result = await db_session.execute(
select(Device).where(Device.id == device_id) # type: ignore[arg-type]
)
device = result.scalar_one_or_none()
if device is None:
raise ValueError(f"Device {device_id!r} not found")
if not device.encrypted_credentials_transit and not device.encrypted_credentials:
raise ValueError(
f"Device {device_id!r} has no stored credentials — cannot perform restore"
)
key = settings.get_encryption_key_bytes()
from app.services.crypto import decrypt_credentials_hybrid
creds_json = await decrypt_credentials_hybrid(
device.encrypted_credentials_transit,
device.encrypted_credentials,
str(device.tenant_id),
key,
)
creds = json.loads(creds_json)
ssh_username = creds.get("username", "")
ssh_password = creds.get("password", "")
ip = device.ip_address
hostname = device.hostname or ip
# Publish "started" progress event
await _publish_push_progress(tenant_id, device_id, "started", f"Config restore started for {hostname}")
# ------------------------------------------------------------------
# Step 2: Read the target export.rsc from the backup commit
# ------------------------------------------------------------------
try:
export_bytes = await loop.run_in_executor(
None,
git_store.read_file,
tenant_id,
commit_sha,
device_id,
"export.rsc",
)
except (KeyError, Exception) as exc:
raise ValueError(
f"Backup version {commit_sha!r} not found for device {device_id!r}: {exc}"
) from exc
export_text = export_bytes.decode("utf-8", errors="replace")
# ------------------------------------------------------------------
# Step 3: Mandatory pre-backup before push
# ------------------------------------------------------------------
await _publish_push_progress(tenant_id, device_id, "backing_up", f"Creating pre-restore backup for {hostname}")
logger.info(
"Starting pre-restore backup for device %s (%s) before pushing commit %s",
hostname,
ip,
commit_sha[:8],
)
pre_backup_result = await backup_service.run_backup(
device_id=device_id,
tenant_id=tenant_id,
trigger_type="pre-restore",
db_session=db_session,
)
pre_backup_sha = pre_backup_result["commit_sha"]
logger.info("Pre-restore backup complete: %s", pre_backup_sha[:8])
# ------------------------------------------------------------------
# Step 4: Record push operation (pending_verification for recovery)
# ------------------------------------------------------------------
push_op = ConfigPushOperation(
device_id=device.id,
tenant_id=device.tenant_id,
pre_push_commit_sha=pre_backup_sha,
scheduler_name=_PANIC_REVERT_SCHEDULER,
status="pending_verification",
)
db_session.add(push_op)
await db_session.flush()
push_op_id = push_op.id
logger.info(
"Push op %s in pending_verification — if API restarts, "
"recover_stale_push_operations() will resolve on next startup",
push_op.id,
)
# ------------------------------------------------------------------
# Step 5: SSH to device — install panic-revert, push config
# ------------------------------------------------------------------
push_op_id_str = str(push_op_id)
await _publish_push_progress(tenant_id, device_id, "pushing", f"Pushing config to {hostname}", push_op_id=push_op_id_str)
logger.info(
"Pushing config to device %s (%s): installing panic-revert scheduler and uploading config",
hostname,
ip,
)
try:
async with asyncssh.connect(
ip,
port=22,
username=ssh_username,
password=ssh_password,
known_hosts=None, # RouterOS self-signed host keys — see module docstring
connect_timeout=30,
) as conn:
# 5a: Create binary backup on device as revert point
await conn.run(
f"/system backup save name={_PRE_PUSH_BACKUP} dont-encrypt=yes",
check=True,
)
logger.debug("Pre-push binary backup saved on device as %s.backup", _PRE_PUSH_BACKUP)
# 5b: Install panic-revert RouterOS scheduler
# The scheduler fires after 90s on startup and loads the pre-push backup.
# This is the safety net: if the device becomes unreachable after push,
# RouterOS will auto-revert to the known-good config on the next reboot
# or after 90s of uptime.
await conn.run(
f"/system scheduler add "
f'name="{_PANIC_REVERT_SCHEDULER}" '
f"interval=90s "
f'on-event=":delay 0; /system backup load name={_PRE_PUSH_BACKUP}" '
f"start-time=startup",
check=True,
)
logger.debug("Panic-revert scheduler installed on device")
# 5c: Upload export.rsc and /import it
# Write the RSC content to the device filesystem via SSH exec,
# then use /import to apply it. The file is cleaned up after import.
# We use a here-doc approach: write content line-by-line via /file set.
# RouterOS supports writing files via /tool fetch or direct file commands.
# Simplest approach for large configs: use asyncssh's write_into to
# write file content, then /import.
#
# RouterOS doesn't support direct SFTP uploads via SSH open_sftp() easily
# for config files. Use the script approach instead:
# /system script add + run + remove (avoids flash write concerns).
#
# Actually the simplest method: write the export.rsc line by line via
# /file print / set commands is RouterOS 6 only and unreliable.
# Best approach for RouterOS 7: use SFTP to upload the file.
async with conn.start_sftp_client() as sftp:
async with sftp.open(_RESTORE_RSC, "wb") as f:
await f.write(export_text.encode("utf-8"))
logger.debug("Uploaded %s to device flash", _RESTORE_RSC)
# /import the config file
import_result = await conn.run(
f"/import file={_RESTORE_RSC}",
check=False, # Don't raise on non-zero exit — import may succeed with warnings
)
logger.info(
"Config import result for device %s: exit_status=%s stdout=%r",
hostname,
import_result.exit_status,
(import_result.stdout or "")[:200],
)
# Clean up the uploaded RSC file (best-effort)
try:
await conn.run(f"/file remove {_RESTORE_RSC}", check=True)
except Exception as cleanup_err:
logger.warning(
"Failed to clean up %s from device %s: %s",
_RESTORE_RSC,
ip,
cleanup_err,
)
except Exception as push_err:
logger.error(
"SSH push phase failed for device %s (%s): %s",
hostname,
ip,
push_err,
)
# Update push operation to failed
await _update_push_op_status(push_op_id, "failed", db_session)
await _publish_push_progress(
tenant_id, device_id, "failed",
f"Config push failed for {hostname}: {push_err}",
push_op_id=push_op_id_str, error=str(push_err),
)
return {
"status": "failed",
"message": f"Config push failed during SSH phase: {push_err}",
"pre_backup_sha": pre_backup_sha,
}
# Record push in Redis so the poller can detect post-push offline events
await record_push(
device_id=device_id,
tenant_id=tenant_id,
push_type="restore",
push_operation_id=push_op_id_str,
pre_push_commit_sha=pre_backup_sha,
)
# ------------------------------------------------------------------
# Step 6: Wait 60s for config to settle
# ------------------------------------------------------------------
await _publish_push_progress(tenant_id, device_id, "settling", f"Config pushed to {hostname} — waiting 60s for settle", push_op_id=push_op_id_str)
logger.info(
"Config pushed to device %s — waiting 60s for config to settle",
hostname,
)
await asyncio.sleep(60)
# ------------------------------------------------------------------
# Step 7: Reachability check
# ------------------------------------------------------------------
await _publish_push_progress(tenant_id, device_id, "verifying", f"Verifying device {hostname} reachability", push_op_id=push_op_id_str)
reachable = await _check_reachability(ip, ssh_username, ssh_password)
if reachable:
# ------------------------------------------------------------------
# Step 8a: Device is reachable — remove panic-revert scheduler + cleanup
# ------------------------------------------------------------------
logger.info("Device %s (%s) is reachable after push — committing", hostname, ip)
try:
async with asyncssh.connect(
ip,
port=22,
username=ssh_username,
password=ssh_password,
known_hosts=None,
connect_timeout=30,
) as conn:
# Remove the panic-revert scheduler
await conn.run(
f'/system scheduler remove "{_PANIC_REVERT_SCHEDULER}"',
check=False, # Non-fatal if already removed
)
# Clean up the pre-push binary backup from device flash
await conn.run(
f"/file remove {_PRE_PUSH_BACKUP}.backup",
check=False, # Non-fatal if already removed
)
except Exception as cleanup_err:
# Cleanup failure is non-fatal — scheduler will eventually fire but
# the backup is now the correct config, so it's acceptable.
logger.warning(
"Failed to clean up panic-revert scheduler/backup on device %s: %s",
hostname,
cleanup_err,
)
await _update_push_op_status(push_op_id, "committed", db_session)
await clear_push(device_id)
await _publish_push_progress(tenant_id, device_id, "committed", f"Config restored successfully on {hostname}", push_op_id=push_op_id_str)
return {
"status": "committed",
"message": "Config restored successfully",
"pre_backup_sha": pre_backup_sha,
}
else:
# ------------------------------------------------------------------
# Step 8b: Device unreachable — RouterOS is auto-reverting via scheduler
# ------------------------------------------------------------------
logger.warning(
"Device %s (%s) is unreachable after push — RouterOS panic-revert scheduler "
"will auto-revert to %s.backup",
hostname,
ip,
_PRE_PUSH_BACKUP,
)
await _update_push_op_status(push_op_id, "reverted", db_session)
await _publish_push_progress(
tenant_id, device_id, "reverted",
f"Device {hostname} unreachable — auto-reverting via panic-revert scheduler",
push_op_id=push_op_id_str,
)
return {
"status": "reverted",
"message": (
"Device unreachable after push; RouterOS is auto-reverting "
"via panic-revert scheduler"
),
"pre_backup_sha": pre_backup_sha,
}
async def _check_reachability(ip: str, username: str, password: str) -> bool:
    """Return True if the RouterOS device at *ip* answers over SSH.

    Connects with asyncssh and runs ``/system identity print`` as a probe
    command. Any failure — auth error, timeout, network error — is treated
    as "unreachable" so the on-device panic-revert scheduler can take over.

    Args:
        ip: Device IP address.
        username: SSH username.
        password: SSH password.

    Returns:
        True if the probe succeeds, False otherwise.
    """
    try:
        async with asyncssh.connect(
            ip,
            port=22,
            username=username,
            password=password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            probe = await conn.run("/system identity print", check=True)
            logger.debug("Reachability check OK for %s: %r", ip, probe.stdout[:50])
            return True
    except Exception as exc:
        logger.info("Device %s unreachable after push: %s", ip, exc)
        return False
async def _update_push_op_status(
    push_op_id,
    new_status: str,
    db_session: AsyncSession,
) -> None:
    """Update the status and completed_at of a ConfigPushOperation row.

    Args:
        push_op_id: UUID of the ConfigPushOperation row.
        new_status: New status value ('committed' | 'reverted' | 'failed').
        db_session: Database session (must already have tenant context set).
    """
    # Local import keeps this helper self-contained; `select` was previously
    # imported here but never used.
    from sqlalchemy import update

    await db_session.execute(
        update(ConfigPushOperation)
        .where(ConfigPushOperation.id == push_op_id)  # type: ignore[arg-type]
        .values(
            status=new_status,
            # Timezone-aware timestamp; marks the operation as resolved.
            completed_at=datetime.now(timezone.utc),
        )
    )
    # Don't commit here — the caller (endpoint) owns the transaction
async def _remove_panic_scheduler(
    ip: str, username: str, password: str, scheduler_name: str
) -> bool:
    """SSH to device and remove the panic-revert scheduler. Returns True if removed."""
    try:
        async with asyncssh.connect(
            ip,
            username=username,
            password=password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            # See whether the scheduler is still installed on the device.
            listing = await conn.run(
                f'/system scheduler print where name="{scheduler_name}"',
                check=False,
            )
            if scheduler_name not in listing.stdout:
                # Scheduler already gone (device reverted itself)
                return False
            await conn.run(
                f'/system scheduler remove [find name="{scheduler_name}"]',
                check=False,
            )
            # Also clean up pre-push backup file
            await conn.run(
                f'/file remove [find name="{_PRE_PUSH_BACKUP}.backup"]',
                check=False,
            )
            return True
    except Exception as e:
        logger.error("Failed to remove panic scheduler from %s: %s", ip, e)
        return False
async def recover_stale_push_operations(db_session: AsyncSession) -> None:
    """Recover stale pending_verification push operations on API startup.

    If the API process died between pushing a config and verifying the
    device, ConfigPushOperation rows are stranded in 'pending_verification'.
    This scans for such rows older than 5 minutes and resolves each one:
    a device that still answers SSH is marked 'committed' (removing the
    on-device panic-revert scheduler when it is still installed), an
    unreachable device is marked 'failed'. Commits once at the end.

    Args:
        db_session: Async session used for all reads and status updates;
            this function owns the final commit.
    """
    # Local imports keep module import time light — this runs once at startup.
    from sqlalchemy import select
    from app.models.config_backup import ConfigPushOperation
    from app.models.device import Device
    from app.services.crypto import decrypt_credentials_hybrid
    # Anything pending for more than 5 minutes is considered orphaned.
    cutoff = datetime.now(timezone.utc) - timedelta(minutes=5)
    result = await db_session.execute(
        select(ConfigPushOperation).where(
            ConfigPushOperation.status == "pending_verification",
            ConfigPushOperation.started_at < cutoff,
        )
    )
    stale_ops = result.scalars().all()
    if not stale_ops:
        logger.info("No stale push operations to recover")
        return
    logger.warning("Found %d stale push operations to recover", len(stale_ops))
    key = settings.get_encryption_key_bytes()
    for op in stale_ops:
        try:
            # Load device
            dev_result = await db_session.execute(
                select(Device).where(Device.id == op.device_id)
            )
            device = dev_result.scalar_one_or_none()
            if not device:
                # Device row is gone — nothing to verify against; fail the op.
                logger.error("Device %s not found for stale op %s", op.device_id, op.id)
                await _update_push_op_status(op.id, "failed", db_session)
                continue
            # Decrypt credentials
            creds_json = await decrypt_credentials_hybrid(
                device.encrypted_credentials_transit,
                device.encrypted_credentials,
                str(op.tenant_id),
                key,
            )
            creds = json.loads(creds_json)
            ssh_username = creds.get("username", "admin")
            ssh_password = creds.get("password", "")
            # Check reachability
            reachable = await _check_reachability(
                device.ip_address, ssh_username, ssh_password
            )
            if reachable:
                # Try to remove scheduler (if still there, push was good)
                removed = await _remove_panic_scheduler(
                    device.ip_address,
                    ssh_username,
                    ssh_password,
                    op.scheduler_name,
                )
                if removed:
                    logger.info("Recovery: committed op %s (scheduler removed)", op.id)
                else:
                    # Scheduler already gone — device may have reverted
                    logger.warning(
                        "Recovery: op %s — scheduler gone, device may have reverted. "
                        "Marking committed (device is reachable).",
                        op.id,
                    )
                # Device answers SSH, so treat the op as committed either way.
                await _update_push_op_status(op.id, "committed", db_session)
                await _publish_push_progress(
                    str(op.tenant_id),
                    str(op.device_id),
                    "committed",
                    "Recovered after API restart",
                    push_op_id=str(op.id),
                )
            else:
                logger.warning(
                    "Recovery: device %s unreachable, marking op %s failed",
                    op.device_id,
                    op.id,
                )
                await _update_push_op_status(op.id, "failed", db_session)
                await _publish_push_progress(
                    str(op.tenant_id),
                    str(op.device_id),
                    "failed",
                    "Device unreachable during recovery after API restart",
                    push_op_id=str(op.id),
                )
        except Exception as e:
            # Best-effort: one bad operation must not block recovery of the rest.
            logger.error("Recovery failed for op %s: %s", op.id, e)
            await _update_push_op_status(op.id, "failed", db_session)
    await db_session.commit()

View File

@@ -0,0 +1,165 @@
"""RouterOS command proxy via NATS request-reply.
Sends command requests to the Go poller's CmdResponder subscription
(device.cmd.{device_id}) and returns structured RouterOS API response data.
Used by:
- Config editor API (browse menu paths, add/edit/delete entries)
- Template push service (execute rendered template commands)
"""
import json
import logging
from typing import Any
import nats
import nats.aio.client
from app.config import settings
logger = logging.getLogger(__name__)
# Module-level NATS connection (lazy initialized)
_nc: nats.aio.client.Client | None = None
async def _get_nats() -> nats.aio.client.Client:
    """Get or create a NATS connection for command proxy requests."""
    global _nc
    # Reconnect lazily when there is no client yet or the old one was closed.
    needs_connect = _nc is None or _nc.is_closed
    if needs_connect:
        _nc = await nats.connect(settings.NATS_URL)
        logger.info("RouterOS proxy NATS connection established")
    return _nc
async def execute_command(
    device_id: str,
    command: str,
    args: list[str] | None = None,
    timeout: float = 15.0,
) -> dict[str, Any]:
    """Execute a RouterOS API command on a device via the Go poller.

    Publishes a request on ``device.cmd.{device_id}`` and waits for the
    poller's reply.

    Args:
        device_id: UUID string of the target device.
        command: Full RouterOS API path, e.g. "/ip/address/print".
        args: Optional list of RouterOS API args, e.g. ["=.proplist=.id,address"].
        timeout: NATS request timeout in seconds (default 15s).

    Returns:
        {"success": bool, "data": list[dict], "error": str|None}
    """
    nc = await _get_nats()
    payload = json.dumps(
        {
            "device_id": device_id,
            "command": command,
            "args": args or [],
        }
    ).encode()
    try:
        reply = await nc.request(
            f"device.cmd.{device_id}",
            payload,
            timeout=timeout,
        )
        return json.loads(reply.data)
    except nats.errors.TimeoutError:
        return {
            "success": False,
            "data": [],
            "error": "Device command timed out — device may be offline or unreachable",
        }
    except Exception as exc:
        logger.error("NATS request failed for device %s: %s", device_id, exc)
        return {"success": False, "data": [], "error": str(exc)}
async def browse_menu(device_id: str, path: str) -> dict[str, Any]:
    """Browse a RouterOS menu path and return all entries.

    Args:
        device_id: Device UUID string.
        path: RouterOS menu path, e.g. "/ip/address" or "/interface".

    Returns:
        {"success": bool, "data": list[dict], "error": str|None}
    """
    # Browsing a menu is simply the path's print action with no args.
    return await execute_command(device_id, f"{path}/print")
async def add_entry(
    device_id: str, path: str, properties: dict[str, str]
) -> dict[str, Any]:
    """Add a new entry to a RouterOS menu path.

    Args:
        device_id: Device UUID.
        path: Menu path, e.g. "/ip/address".
        properties: Key-value pairs for the new entry.

    Returns:
        Command response dict.
    """
    # RouterOS API word format: "=key=value" per property.
    prop_args = [f"={key}={value}" for key, value in properties.items()]
    return await execute_command(device_id, f"{path}/add", prop_args)
async def update_entry(
    device_id: str, path: str, entry_id: str | None, properties: dict[str, str]
) -> dict[str, Any]:
    """Update an existing entry in a RouterOS menu path.

    Args:
        device_id: Device UUID.
        path: Menu path.
        entry_id: RouterOS .id value (e.g. "*1"). None for singleton paths.
        properties: Key-value pairs to update.

    Returns:
        Command response dict.
    """
    # The .id selector goes first; singleton paths (entry_id None) omit it.
    call_args: list[str] = []
    if entry_id:
        call_args.append(f"=.id={entry_id}")
    call_args.extend(f"={k}={v}" for k, v in properties.items())
    return await execute_command(device_id, f"{path}/set", call_args)
async def remove_entry(
    device_id: str, path: str, entry_id: str
) -> dict[str, Any]:
    """Remove an entry from a RouterOS menu path.

    Args:
        device_id: Device UUID.
        path: Menu path.
        entry_id: RouterOS .id value.

    Returns:
        Command response dict.
    """
    id_arg = f"=.id={entry_id}"
    return await execute_command(device_id, f"{path}/remove", [id_arg])
async def execute_cli(device_id: str, cli_command: str) -> dict[str, Any]:
    """Execute an arbitrary RouterOS CLI command.

    For commands that don't follow the standard /path/action pattern.
    The command is sent as-is to the RouterOS API.

    Args:
        device_id: Device UUID.
        cli_command: Full CLI command string.

    Returns:
        Command response dict.
    """
    # No argument list — the raw command string is forwarded unchanged.
    response = await execute_command(device_id, cli_command)
    return response
async def close() -> None:
    """Close the NATS connection. Called on application shutdown."""
    global _nc
    # Drain (flush pending messages) before dropping the module-level client.
    if _nc is not None and not _nc.is_closed:
        await _nc.drain()
        _nc = None
        logger.info("RouterOS proxy NATS connection closed")

View File

@@ -0,0 +1,220 @@
"""RouterOS RSC export parser — extracts categories, validates syntax, computes impact."""
import re
import logging
from typing import Any
logger = logging.getLogger(__name__)
# Paths whose modification can sever management connectivity or routing;
# compute_impact() flags any change under these as "high" risk.
HIGH_RISK_PATHS = {
    "/ip address", "/ip route", "/ip firewall filter", "/ip firewall nat",
    "/interface", "/interface bridge", "/interface vlan",
    "/system identity", "/ip service", "/ip ssh", "/user",
}
# (pattern, warning) pairs matched against the TARGET config text in
# compute_impact(); each hit emits its human-readable warning about
# possible loss of management access.
MANAGEMENT_PATTERNS = [
    (re.compile(r"chain=input.*dst-port=(22|8291|8728|8729|443|80)", re.I),
     "Modifies firewall rules for management ports (SSH/WinBox/API/Web)"),
    (re.compile(r"chain=input.*action=drop", re.I),
     "Adds drop rule on input chain — may block management access"),
    (re.compile(r"/ip service", re.I),
     "Modifies IP services — may disable API/SSH/WinBox access"),
    (re.compile(r"/user.*set.*password", re.I),
     "Changes user password — may affect automated access"),
]
def _join_continuation_lines(text: str) -> list[str]:
"""Join lines ending with \\ into single logical lines."""
lines = text.split("\n")
joined: list[str] = []
buf = ""
for line in lines:
stripped = line.rstrip()
if stripped.endswith("\\"):
buf += stripped[:-1].rstrip() + " "
else:
if buf:
buf += stripped
joined.append(buf)
buf = ""
else:
joined.append(stripped)
if buf:
joined.append(buf + " <<TRUNCATED>>")
return joined
def parse_rsc(text: str) -> dict[str, Any]:
"""Parse a RouterOS /export compact output.
Returns a dict with a "categories" list, each containing:
- path: the RouterOS command path (e.g. "/ip address")
- adds: count of "add" commands
- sets: count of "set" commands
- removes: count of "remove" commands
- commands: list of command strings under this path
"""
lines = _join_continuation_lines(text)
categories: dict[str, dict] = {}
current_path: str | None = None
for line in lines:
line = line.strip()
if not line or line.startswith("#"):
continue
if line.startswith("/"):
# Could be just a path header, or a path followed by a command
parts = line.split(None, 1)
if len(parts) == 1:
# Pure path header like "/interface bridge"
current_path = parts[0]
else:
# Check if second part starts with a known command verb
cmd_check = parts[1].strip().split(None, 1)
if cmd_check and cmd_check[0] in ("add", "set", "remove", "print", "enable", "disable"):
current_path = parts[0]
line = parts[1].strip()
else:
# The whole line is a path (e.g. "/ip firewall filter")
current_path = line
continue
if current_path and current_path not in categories:
categories[current_path] = {
"path": current_path,
"adds": 0,
"sets": 0,
"removes": 0,
"commands": [],
}
if len(parts) == 1:
continue
if current_path is None:
continue
if current_path not in categories:
categories[current_path] = {
"path": current_path,
"adds": 0,
"sets": 0,
"removes": 0,
"commands": [],
}
cat = categories[current_path]
cat["commands"].append(line)
if line.startswith("add ") or line.startswith("add\t"):
cat["adds"] += 1
elif line.startswith("set "):
cat["sets"] += 1
elif line.startswith("remove "):
cat["removes"] += 1
return {"categories": list(categories.values())}
def validate_rsc(text: str) -> dict[str, Any]:
    """Validate RSC export syntax.

    Checks for:
    - Unbalanced quotes (indicates truncation or corruption)
    - Trailing continuation lines (indicates truncated export)
    Returns dict with "valid" (bool) and "errors" (list of strings).
    """
    errors: list[str] = []
    # Track quote parity across the whole file: a line with an odd number of
    # unescaped quotes toggles the state, and ending "open" means a quoted
    # value was never closed.
    open_quote = False
    for raw in text.split("\n"):
        candidate = raw.rstrip()
        if candidate.endswith("\\"):
            candidate = candidate[:-1]
        unescaped = candidate.count('"') - candidate.count('\\"')
        if unescaped % 2 != 0:
            open_quote = not open_quote
    if open_quote:
        errors.append("Unbalanced quote detected — file may be truncated")
    # A trailing backslash on the last line means the export was cut off
    # mid-command.
    tail = text.rstrip().split("\n")
    if tail and tail[-1].rstrip().endswith("\\"):
        errors.append("File ends with continuation line (\\) — truncated export")
    return {"valid": not errors, "errors": errors}
def compute_impact(
    current_parsed: dict[str, Any],
    target_parsed: dict[str, Any],
) -> dict[str, Any]:
    """Compare current vs target parsed RSC and compute impact analysis.

    Returns dict with:
    - categories: list of per-path diffs with risk levels
    - warnings: list of human-readable warning strings
    - diff: summary counts (added, removed, modified)
    """
    empty_cat: dict[str, Any] = {"adds": 0, "sets": 0, "removes": 0, "commands": []}
    current_map = {c["path"]: c for c in current_parsed["categories"]}
    target_map = {c["path"]: c for c in target_parsed["categories"]}
    result_categories = []
    warnings: list[str] = []
    total_added = 0
    total_removed = 0
    total_modified = 0  # reserved; command-level diffing only tracks add/remove
    # Walk the union of paths so categories present on only one side still show.
    for path in sorted(set(current_map) | set(target_map)):
        curr_cmds = set(current_map.get(path, empty_cat).get("commands", []))
        tgt_cmds = set(target_map.get(path, empty_cat).get("commands", []))
        added = len(tgt_cmds - curr_cmds)
        removed = len(curr_cmds - tgt_cmds)
        total_added += added
        total_removed += removed
        if added or removed:
            risk = "high" if path in HIGH_RISK_PATHS else "low"
        else:
            risk = "none"
        result_categories.append({
            "path": path,
            "adds": added,
            "removes": removed,
            "risk": risk,
        })
    # Check target commands against management patterns
    target_text = "\n".join(
        cmd for cat in target_parsed["categories"] for cmd in cat.get("commands", [])
    )
    warnings.extend(
        message for pattern, message in MANAGEMENT_PATTERNS
        if pattern.search(target_text)
    )
    # Warn about removed IP addresses
    if "/ip address" in current_map and "/ip address" in target_map:
        dropped = (
            set(current_map["/ip address"].get("commands", []))
            - set(target_map["/ip address"].get("commands", []))
        )
        if dropped:
            warnings.append(
                f"Removes {len(dropped)} IP address(es) — verify none are management interfaces"
            )
    return {
        "categories": result_categories,
        "warnings": warnings,
        "diff": {
            "added": total_added,
            "removed": total_removed,
            "modified": total_modified,
        },
    }

View File

@@ -0,0 +1,124 @@
"""
Subnet scanner for MikroTik device discovery.
Scans a CIDR range by attempting TCP connections to RouterOS API ports
(8728 and 8729) with configurable concurrency limits and timeouts.
Security constraints:
- CIDR range limited to /20 or smaller (4096 IPs maximum)
- Maximum 50 concurrent connections to prevent network flooding
- 2-second timeout per connection attempt
"""
import asyncio
import ipaddress
import socket
from typing import Optional
from app.schemas.device import SubnetScanResult
# Maximum concurrency for TCP probes (caps simultaneously open sockets
# via the semaphore in scan_subnet)
_MAX_CONCURRENT = 50
# Timeout (seconds) per TCP connection attempt
_TCP_TIMEOUT = 2.0
# RouterOS API port (plaintext binary API)
_API_PORT = 8728
# RouterOS SSL API port (TLS-wrapped binary API)
_SSL_PORT = 8729
async def _probe_host(
    semaphore: asyncio.Semaphore,
    ip_str: str,
) -> Optional[SubnetScanResult]:
    """
    Probe a single IP for RouterOS API ports.
    Returns a SubnetScanResult if either port is open, None otherwise.
    """
    async with semaphore:
        # Probe the plaintext and TLS API ports concurrently.
        api_open, ssl_open = await asyncio.gather(
            _tcp_connect(ip_str, _API_PORT),
            _tcp_connect(ip_str, _SSL_PORT),
            return_exceptions=False,
        )
        if not (api_open or ssl_open):
            return None
        # Attempt reverse DNS (best-effort; won't fail the scan)
        resolved_name = await _reverse_dns(ip_str)
        return SubnetScanResult(
            ip_address=ip_str,
            hostname=resolved_name,
            api_port_open=api_open,
            api_ssl_port_open=ssl_open,
        )
async def _tcp_connect(ip: str, port: int) -> bool:
    """Return True if a TCP connection to ip:port succeeds within _TCP_TIMEOUT."""
    try:
        stream_pair = await asyncio.wait_for(
            asyncio.open_connection(ip, port),
            timeout=_TCP_TIMEOUT,
        )
    except Exception:
        # Refused, timed out, unreachable — all count as "closed".
        return False
    writer = stream_pair[1]
    writer.close()
    try:
        await writer.wait_closed()
    except Exception:
        # Teardown errors don't change the answer: the port was open.
        pass
    return True
async def _reverse_dns(ip: str) -> Optional[str]:
"""Attempt a reverse DNS lookup. Returns None on failure."""
try:
loop = asyncio.get_running_loop()
hostname, _, _ = await asyncio.wait_for(
loop.run_in_executor(None, socket.gethostbyaddr, ip),
timeout=1.5,
)
return hostname
except Exception:
return None
async def scan_subnet(cidr: str) -> list[SubnetScanResult]:
    """
    Scan a CIDR range for hosts with open RouterOS API ports.

    Args:
        cidr: CIDR notation string, e.g. "192.168.1.0/24".
            Must be /20 or smaller (validated by SubnetScanRequest).

    Returns:
        List of SubnetScanResult for each host with at least one open API port.

    Raises:
        ValueError: If CIDR is malformed or too large.
    """
    try:
        network = ipaddress.ip_network(cidr, strict=False)
    except ValueError as e:
        raise ValueError(f"Invalid CIDR: {e}") from e
    if network.num_addresses > 4096:
        raise ValueError(
            f"CIDR range too large ({network.num_addresses} addresses). "
            "Maximum allowed is /20 (4096 addresses)."
        )
    # .hosts() omits the IPv4 network/broadcast addresses; /31 and /32
    # networks have no such distinction, so probe every address there.
    if network.num_addresses > 2:
        candidates = list(network.hosts())
    else:
        candidates = list(network)
    gate = asyncio.Semaphore(_MAX_CONCURRENT)
    probes = [_probe_host(gate, str(addr)) for addr in candidates]
    scan_hits = await asyncio.gather(*probes, return_exceptions=False)
    # Drop hosts with no open ports (probes that returned None).
    return [hit for hit in scan_hits if hit is not None]

View File

@@ -0,0 +1,113 @@
"""SRP-6a server-side authentication service.
Wraps the srptools library for the two-step SRP handshake.
All functions are async, using asyncio.to_thread() because
srptools operations are CPU-bound and synchronous.
"""
import asyncio
import hashlib
from srptools import SRPContext, SRPServerSession
from srptools.constants import PRIME_2048, PRIME_2048_GEN
# Client uses Web Crypto SHA-256 — server must match.
# srptools defaults to SHA-1 which would cause proof mismatch.
_SRP_HASH = hashlib.sha256
async def create_srp_verifier(
    salt_hex: str, verifier_hex: str
) -> tuple[bytes, bytes]:
    """Convert client-provided hex salt and verifier to bytes for storage.

    The client computes v = g^x mod N using 2SKD-derived SRP-x.
    The server stores the verifier directly and never computes x
    from the password.

    Returns:
        Tuple of (salt_bytes, verifier_bytes) ready for database storage.
    """
    salt_bytes = bytes.fromhex(salt_hex)
    verifier_bytes = bytes.fromhex(verifier_hex)
    return salt_bytes, verifier_bytes
async def srp_init(
    email: str, srp_verifier_hex: str
) -> tuple[str, str]:
    """SRP Step 1: Generate server ephemeral (B) and private key (b).

    Args:
        email: User email (SRP identity I).
        srp_verifier_hex: Hex-encoded SRP verifier from database.

    Returns:
        Tuple of (server_public_hex, server_private_hex).
        Caller stores server_private in Redis with 60s TTL.

    Raises:
        ValueError: If SRP initialization fails for any reason.
    """
    def _make_session() -> tuple[str, str]:
        # CPU-bound srptools work — executed off the event loop below.
        ctx = SRPContext(
            email, prime=PRIME_2048, generator=PRIME_2048_GEN,
            hash_func=_SRP_HASH,
        )
        session = SRPServerSession(ctx, srp_verifier_hex)
        return session.public, session.private
    try:
        return await asyncio.to_thread(_make_session)
    except Exception as e:
        raise ValueError(f"SRP initialization failed: {e}") from e
async def srp_verify(
    email: str,
    srp_verifier_hex: str,
    server_private: str,
    client_public: str,
    client_proof: str,
    srp_salt_hex: str,
) -> tuple[bool, str | None]:
    """SRP Step 2: Verify client proof M1, return server proof M2.

    Args:
        email: User email (SRP identity I).
        srp_verifier_hex: Hex-encoded SRP verifier from database.
        server_private: Server private ephemeral from Redis session.
        client_public: Hex-encoded client public ephemeral A.
        client_proof: Hex-encoded client proof M1.
        srp_salt_hex: Hex-encoded SRP salt.

    Returns:
        Tuple of (is_valid, server_proof_hex_or_none).
        If valid, server_proof is M2 for the client to verify.

    Raises:
        ValueError: If srptools session processing fails (malformed
            ephemeral/verifier data).
    """
    def _verify() -> tuple[bool, str | None]:
        import hmac  # timing-safe comparison for the untrusted client proof

        context = SRPContext(
            email, prime=PRIME_2048, generator=PRIME_2048_GEN,
            hash_func=_SRP_HASH,
        )
        server_session = SRPServerSession(
            context, srp_verifier_hex, private=server_private
        )
        _key, _key_proof, _key_proof_hash = server_session.process(client_public, srp_salt_hex)
        # srptools verify_proof has a Python 3 bug: hexlify() returns bytes
        # but client_proof is str, so bytes == str is always False.
        # Compare manually with consistent types — and in constant time,
        # since client_proof comes straight from the client.
        server_m1 = _key_proof if isinstance(_key_proof, str) else _key_proof.decode('ascii')
        is_valid = hmac.compare_digest(
            client_proof.lower().encode(), server_m1.lower().encode()
        )
        if not is_valid:
            return False, None
        # Return M2 (key_proof_hash), also fixing the bytes/str issue
        m2 = _key_proof_hash if isinstance(_key_proof_hash, str) else _key_proof_hash.decode('ascii')
        return True, m2
    try:
        return await asyncio.to_thread(_verify)
    except Exception as e:
        raise ValueError(f"SRP verification failed: {e}") from e

View File

@@ -0,0 +1,311 @@
"""SSE Connection Manager -- bridges NATS JetStream to per-client asyncio queues.
Each SSE client gets its own NATS connection with ephemeral consumers.
Events are tenant-filtered and placed onto an asyncio.Queue that the
SSE router drains via EventSourceResponse.
"""
import asyncio
import json
from typing import Optional
import nats
import structlog
from nats.js.api import ConsumerConfig, DeliverPolicy, StreamConfig
from app.config import settings
logger = structlog.get_logger(__name__)
# Subjects per stream for SSE subscriptions
# Note: config.push.* subjects live in DEVICE_EVENTS (created by Go poller)
# Subjects per stream for SSE subscriptions
# Note: config.push.* subjects live in DEVICE_EVENTS (created by Go poller)
_DEVICE_EVENT_SUBJECTS = [
    "device.status.>",
    "device.metrics.>",
    "config.push.rollback.>",
    "config.push.alert.>",
]
# Alert lifecycle subjects (ALERT_EVENTS stream)
_ALERT_EVENT_SUBJECTS = ["alert.fired.>", "alert.resolved.>"]
# Long-running operation progress subjects (OPERATION_EVENTS stream)
_OPERATION_EVENT_SUBJECTS = ["firmware.progress.>"]
def _map_subject_to_event_type(subject: str) -> str:
"""Map a NATS subject prefix to an SSE event type string."""
if subject.startswith("device.status."):
return "device_status"
if subject.startswith("device.metrics."):
return "metric_update"
if subject.startswith("alert.fired."):
return "alert_fired"
if subject.startswith("alert.resolved."):
return "alert_resolved"
if subject.startswith("config.push."):
return "config_push"
if subject.startswith("firmware.progress."):
return "firmware_progress"
return "unknown"
async def ensure_sse_streams() -> None:
    """Create ALERT_EVENTS and OPERATION_EVENTS NATS streams if they don't exist.

    Called once during app startup so the streams are ready before any
    SSE connection or event publisher needs them. Idempotent -- uses
    add_stream which acts as create-or-update.
    """
    nc = None
    try:
        nc = await nats.connect(settings.NATS_URL)
        js = nc.jetstream()
        stream_specs = (
            ("ALERT_EVENTS", ["alert.fired.>", "alert.resolved.>"]),
            ("OPERATION_EVENTS", ["firmware.progress.>"]),
        )
        for stream_name, subjects in stream_specs:
            await js.add_stream(
                StreamConfig(
                    name=stream_name,
                    subjects=subjects,
                    max_age=3600,  # 1 hour retention
                )
            )
            logger.info("nats.stream.ensured", stream=stream_name)
    except Exception as exc:
        logger.warning("sse.streams.ensure_failed", error=str(exc))
        raise
    finally:
        # Always release the short-lived bootstrap connection.
        if nc:
            try:
                await nc.close()
            except Exception:
                pass
class SSEConnectionManager:
"""Manages a single SSE client's lifecycle: NATS connection, subscriptions, and event queue."""
    def __init__(self) -> None:
        # Per-connection NATS client; created in connect(), torn down in disconnect().
        self._nc: Optional[nats.aio.client.Client] = None
        # Active JetStream subscriptions feeding this client's queue.
        self._subscriptions: list = []
        # Bounded queue drained by the SSE response generator.
        self._queue: Optional[asyncio.Queue] = None
        # Tenant filter; None means super_admin (sees all tenants).
        self._tenant_id: Optional[str] = None
        self._connection_id: Optional[str] = None
    async def connect(
        self,
        connection_id: str,
        tenant_id: Optional[str],
        last_event_id: Optional[str] = None,
    ) -> asyncio.Queue:
        """Set up NATS subscriptions and return an asyncio.Queue for SSE events.

        Opens a dedicated NATS connection for this client, subscribes to the
        device/alert/operation JetStream subjects, and starts a background
        pump task that feeds matching events into the returned queue.

        Args:
            connection_id: Unique identifier for this SSE connection.
            tenant_id: Tenant UUID string to filter events. None for super_admin
                (receives events from all tenants).
            last_event_id: NATS stream sequence number from the Last-Event-ID header.
                If provided, replay starts from sequence + 1.

        Returns:
            asyncio.Queue that the SSE generator should drain.
        """
        self._connection_id = connection_id
        self._tenant_id = tenant_id
        # Bounded queue: slow clients drop events (see _handle_message)
        # instead of growing memory without limit.
        self._queue = asyncio.Queue(maxsize=256)
        self._nc = await nats.connect(
            settings.NATS_URL,
            max_reconnect_attempts=5,
            reconnect_time_wait=2,
        )
        js = self._nc.jetstream()
        logger.info(
            "sse.connecting",
            connection_id=connection_id,
            tenant_id=tenant_id,
            last_event_id=last_event_id,
        )
        # Build consumer config for replay support
        if last_event_id is not None:
            try:
                start_seq = int(last_event_id) + 1
                consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.BY_START_SEQUENCE, opt_start_seq=start_seq)
            except (ValueError, TypeError):
                # Malformed Last-Event-ID — fall back to new messages only.
                consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.NEW)
        else:
            consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.NEW)
        # Subscribe to device events (DEVICE_EVENTS stream -- created by Go poller)
        # A failed subject is logged but non-fatal: the connection still
        # serves whatever subjects did subscribe.
        for subject in _DEVICE_EVENT_SUBJECTS:
            try:
                sub = await js.subscribe(
                    subject,
                    stream="DEVICE_EVENTS",
                    config=consumer_cfg,
                )
                self._subscriptions.append(sub)
            except Exception as exc:
                logger.warning(
                    "sse.subscribe_failed",
                    subject=subject,
                    stream="DEVICE_EVENTS",
                    error=str(exc),
                )
        # Subscribe to alert events (ALERT_EVENTS stream)
        # Lazily create the stream if it doesn't exist yet (startup race)
        for subject in _ALERT_EVENT_SUBJECTS:
            try:
                sub = await js.subscribe(
                    subject,
                    stream="ALERT_EVENTS",
                    config=consumer_cfg,
                )
                self._subscriptions.append(sub)
            except Exception as exc:
                if "stream not found" in str(exc):
                    try:
                        await js.add_stream(StreamConfig(
                            name="ALERT_EVENTS",
                            subjects=_ALERT_EVENT_SUBJECTS,
                            max_age=3600,
                        ))
                        sub = await js.subscribe(subject, stream="ALERT_EVENTS", config=consumer_cfg)
                        self._subscriptions.append(sub)
                        logger.info("sse.stream_created_lazily", stream="ALERT_EVENTS")
                    except Exception as retry_exc:
                        logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(retry_exc))
                else:
                    logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(exc))
        # Subscribe to operation events (OPERATION_EVENTS stream)
        # Same lazy-create fallback as ALERT_EVENTS above.
        for subject in _OPERATION_EVENT_SUBJECTS:
            try:
                sub = await js.subscribe(
                    subject,
                    stream="OPERATION_EVENTS",
                    config=consumer_cfg,
                )
                self._subscriptions.append(sub)
            except Exception as exc:
                if "stream not found" in str(exc):
                    try:
                        await js.add_stream(StreamConfig(
                            name="OPERATION_EVENTS",
                            subjects=_OPERATION_EVENT_SUBJECTS,
                            max_age=3600,
                        ))
                        sub = await js.subscribe(subject, stream="OPERATION_EVENTS", config=consumer_cfg)
                        self._subscriptions.append(sub)
                        logger.info("sse.stream_created_lazily", stream="OPERATION_EVENTS")
                    except Exception as retry_exc:
                        logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(retry_exc))
                else:
                    logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(exc))
        # Start background task to pull messages from subscriptions into the queue
        # NOTE(review): the task handle is not retained — the pump loop ends
        # when the NATS connection closes (see _pump_messages).
        asyncio.create_task(self._pump_messages())
        logger.info(
            "sse.connected",
            connection_id=connection_id,
            subscription_count=len(self._subscriptions),
        )
        return self._queue
async def _pump_messages(self) -> None:
"""Read messages from all NATS push subscriptions and push them onto the asyncio queue.
Uses next_msg with a short timeout so we can interleave across
subscriptions without blocking. Runs until the NATS connection is closed
or drained.
"""
while self._nc and self._nc.is_connected:
for sub in self._subscriptions:
try:
msg = await sub.next_msg(timeout=0.5)
await self._handle_message(msg)
except nats.errors.TimeoutError:
# No messages available on this subscription -- move on
continue
except Exception as exc:
if self._nc and self._nc.is_connected:
logger.warning(
"sse.pump_error",
connection_id=self._connection_id,
error=str(exc),
)
break
# Brief yield to avoid tight-looping
await asyncio.sleep(0.1)
async def _handle_message(self, msg) -> None:
"""Parse a NATS message, apply tenant filter, and enqueue as SSE event."""
try:
data = json.loads(msg.data)
except (json.JSONDecodeError, UnicodeDecodeError):
await msg.ack()
return
# Tenant filtering: skip messages not matching this connection's tenant
if self._tenant_id is not None:
msg_tenant = data.get("tenant_id", "")
if str(msg_tenant) != self._tenant_id:
await msg.ack()
return
event_type = _map_subject_to_event_type(msg.subject)
# Extract NATS stream sequence for Last-Event-ID support
seq_id = "0"
if msg.metadata and msg.metadata.sequence:
seq_id = str(msg.metadata.sequence.stream)
sse_event = {
"event": event_type,
"data": json.dumps(data),
"id": seq_id,
}
try:
self._queue.put_nowait(sse_event)
except asyncio.QueueFull:
logger.warning(
"sse.queue_full",
connection_id=self._connection_id,
dropped_event=event_type,
)
await msg.ack()
async def disconnect(self) -> None:
"""Unsubscribe from all NATS subscriptions and close the connection."""
logger.info("sse.disconnecting", connection_id=self._connection_id)
for sub in self._subscriptions:
try:
await sub.unsubscribe()
except Exception:
pass
self._subscriptions.clear()
if self._nc:
try:
await self._nc.drain()
except Exception:
try:
await self._nc.close()
except Exception:
pass
self._nc = None
logger.info("sse.disconnected", connection_id=self._connection_id)

View File

@@ -0,0 +1,480 @@
"""Config template service: Jinja2 rendering, variable extraction, and multi-device push.
Provides:
- extract_variables: Parse template content to find all undeclared Jinja2 variables
- render_template: Render a template with device context and custom variables
- validate_variable: Type-check a variable value against its declared type
- push_to_devices: Sequential multi-device push with pause-on-failure
- push_single_device: Two-phase panic-revert push for a single device
The push logic follows the same two-phase pattern as restore_service but uses
separate scheduler and file names to avoid conflicts with restore operations.
"""
import asyncio
import io
import ipaddress
import json
import logging
import uuid
from datetime import datetime, timezone
import asyncssh
from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy import select, text
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.models.config_template import TemplatePushJob
from app.models.device import Device
logger = logging.getLogger(__name__)
# Sandboxed Jinja2 environment prevents template injection
_env = SandboxedEnvironment()
# Names used on the RouterOS device during template push
_PANIC_REVERT_SCHEDULER = "mikrotik-portal-template-revert"
_PRE_PUSH_BACKUP = "portal-template-pre-push"
_TEMPLATE_RSC = "portal-template.rsc"
# ---------------------------------------------------------------------------
# Variable extraction & rendering
# ---------------------------------------------------------------------------
def extract_variables(template_content: str) -> list[str]:
    """Return the sorted names of all undeclared Jinja2 variables in a template.

    The built-in 'device' variable is excluded because it is injected
    automatically at render time rather than supplied by the user.
    """
    parsed = _env.parse(template_content)
    discovered = meta.find_undeclared_variables(parsed)
    return sorted(name for name in discovered if name != "device")
def render_template(
    template_content: str,
    device: dict,
    custom_variables: dict[str, str],
) -> str:
    """Render a Jinja2 template against device context plus user variables.

    A 'device' mapping (hostname/ip/model) is built from the device dict and
    merged with the caller-supplied custom variables. Rendering happens in a
    SandboxedEnvironment to block template-injection attacks.

    Args:
        template_content: Jinja2 template string.
        device: Device info dict with keys: hostname, ip_address, model.
        custom_variables: User-supplied variable values.

    Returns:
        The rendered template string.

    Raises:
        jinja2.TemplateSyntaxError: If template has syntax errors.
        jinja2.UndefinedError: If required variables are missing.
    """
    device_context = {
        "hostname": device.get("hostname", ""),
        "ip": device.get("ip_address", ""),
        "model": device.get("model", ""),
    }
    render_context = {"device": device_context, **custom_variables}
    return _env.from_string(template_content).render(render_context)
def validate_variable(name: str, value: str, var_type: str) -> str | None:
"""Validate a variable value against its declared type.
Returns None on success, or an error message string on failure.
"""
if var_type == "string":
return None # any string is valid
elif var_type == "ip":
try:
ipaddress.ip_address(value)
return None
except ValueError:
return f"'{name}' must be a valid IP address"
elif var_type == "subnet":
try:
ipaddress.ip_network(value, strict=False)
return None
except ValueError:
return f"'{name}' must be a valid subnet (e.g., 192.168.1.0/24)"
elif var_type == "integer":
try:
int(value)
return None
except ValueError:
return f"'{name}' must be an integer"
elif var_type == "boolean":
if value.lower() in ("true", "false", "yes", "no", "1", "0"):
return None
return f"'{name}' must be a boolean (true/false)"
return None # unknown type, allow
# ---------------------------------------------------------------------------
# Multi-device push orchestration
# ---------------------------------------------------------------------------
async def push_to_devices(rollout_id: str) -> dict:
    """Run a sequential template-push rollout as a background task.

    Devices are processed one at a time; a failure or revert on any device
    leaves the remaining jobs pending (the rollout pauses). Designed to run
    via asyncio.create_task after the API has created the push jobs, so it
    never raises: unexpected errors are logged and reported as a failure in
    the returned summary dict.
    """
    try:
        return await _run_push_rollout(rollout_id)
    except Exception as exc:
        logger.error(
            "Uncaught exception in template push rollout %s: %s",
            rollout_id, exc, exc_info=True,
        )
        return {"completed": 0, "failed": 1, "pending": 0}
async def _run_push_rollout(rollout_id: str) -> dict:
    """Internal rollout implementation.

    Loads every push job for the rollout (creation order), pushes to each
    still-pending device, and stops at the first failure/revert so the
    remaining jobs stay pending (paused).

    Returns:
        Summary dict with 'completed' (total committed, including jobs that
        were already committed before this run), 'failed' (0 or 1), and
        'pending' (jobs still awaiting a push).
    """
    # Load all jobs for this rollout (admin session: background task, no RLS)
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id::text, j.status, d.hostname
                FROM template_push_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.rollout_id = CAST(:rollout_id AS uuid)
                ORDER BY j.created_at ASC
            """),
            {"rollout_id": rollout_id},
        )
        jobs = result.fetchall()
    if not jobs:
        logger.warning("No jobs found for template push rollout %s", rollout_id)
        return {"completed": 0, "failed": 0, "pending": 0}
    pending_at_start = sum(1 for _, s, _ in jobs if s == "pending")
    completed = 0        # total committed jobs (pre-existing + this run)
    newly_completed = 0  # jobs moved pending -> committed by this run
    failed = False
    for job_id, current_status, hostname in jobs:
        if current_status != "pending":
            # Count jobs committed by an earlier run; skip everything else.
            if current_status == "committed":
                completed += 1
            continue
        logger.info(
            "Template push rollout %s: pushing to device %s (job %s)",
            rollout_id, hostname, job_id,
        )
        await push_single_device(job_id)
        # Re-read the job status to decide whether to continue the rollout
        async with AdminAsyncSessionLocal() as session:
            result = await session.execute(
                text("SELECT status FROM template_push_jobs WHERE id = CAST(:id AS uuid)"),
                {"id": job_id},
            )
            row = result.fetchone()
        if row and row[0] == "committed":
            completed += 1
            newly_completed += 1
        elif row and row[0] in ("failed", "reverted"):
            failed = True
            logger.error(
                "Template push rollout %s paused: device %s %s",
                rollout_id, hostname, row[0],
            )
            break
    # BUGFIX: the previous computation subtracted the *total* completed count
    # (which includes jobs already committed before this run) from the
    # pending-at-start count, under-reporting the remaining pending jobs
    # whenever a rollout containing pre-committed jobs paused on a failure.
    # Only jobs committed by this run consume pending slots.
    remaining = pending_at_start - newly_completed - (1 if failed else 0)
    return {
        "completed": completed,
        "failed": 1 if failed else 0,
        "pending": max(0, remaining),
    }
async def push_single_device(job_id: str) -> None:
    """Push rendered template content to a single device.

    Thin safety wrapper around the two-phase panic-revert implementation
    (_run_single_push): mandatory pre-backup, panic-revert scheduler install,
    SFTP upload, /import, 60s settle, then a reachability check that decides
    between 'committed' and 'reverted'. Never raises: any unexpected error is
    logged and recorded on the job row as status='failed'.
    """
    try:
        await _run_single_push(job_id)
    except Exception as exc:
        logger.error(
            "Uncaught exception in template push job %s: %s",
            job_id, exc, exc_info=True,
        )
        await _update_job(job_id, status="failed", error_message=f"Unexpected error: {exc}")
async def _run_single_push(job_id: str) -> None:
    """Internal single-device push implementation (two-phase panic-revert).

    Sequence:
      1. Load job + device row (admin session, no RLS).
      2. Mark job 'pushing'.
      3. Decrypt device SSH credentials (Transit preferred, legacy fallback).
      4. Take a mandatory pre-push backup via backup_service.
      5. Over SSH: save an on-device binary backup, install a panic-revert
         scheduler, SFTP-upload the rendered template, /import it, and
         best-effort delete the uploaded file.
      6. Sleep 60s so the new config can settle.
      7. Reachability check decides the outcome: reachable -> remove the
         panic-revert scheduler and mark 'committed'; unreachable -> the
         device auto-reverts and the job is marked 'reverted'.

    Every failure path records an error_message on the job row and returns;
    this function itself is wrapped by push_single_device, which catches
    anything that still escapes.

    Args:
        job_id: UUID (string) of the template_push_jobs row to execute.
    """
    # Step 1: Load job and device info
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id, j.device_id, j.tenant_id, j.rendered_content,
                       d.ip_address, d.hostname, d.encrypted_credentials,
                       d.encrypted_credentials_transit
                FROM template_push_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.id = CAST(:job_id AS uuid)
            """),
            {"job_id": job_id},
        )
        row = result.fetchone()
    if not row:
        logger.error("Template push job %s not found", job_id)
        return
    (
        _, device_id, tenant_id, rendered_content,
        ip_address, hostname, encrypted_credentials,
        encrypted_credentials_transit,
    ) = row
    # Normalize UUIDs to strings for downstream service calls.
    device_id = str(device_id)
    tenant_id = str(tenant_id)
    hostname = hostname or ip_address
    # Step 2: Update status to pushing
    await _update_job(job_id, status="pushing", started_at=datetime.now(timezone.utc))
    # Step 3: Decrypt credentials (dual-read: Transit preferred, legacy fallback)
    if not encrypted_credentials_transit and not encrypted_credentials:
        await _update_job(job_id, status="failed", error_message="Device has no stored credentials")
        return
    try:
        from app.services.crypto import decrypt_credentials_hybrid
        key = settings.get_encryption_key_bytes()
        creds_json = await decrypt_credentials_hybrid(
            encrypted_credentials_transit, encrypted_credentials, tenant_id, key,
        )
        creds = json.loads(creds_json)
        ssh_username = creds.get("username", "")
        ssh_password = creds.get("password", "")
    except Exception as cred_err:
        await _update_job(
            job_id, status="failed",
            error_message=f"Failed to decrypt credentials: {cred_err}",
        )
        return
    # Step 4: Mandatory pre-push backup (rollout must be recoverable off-device)
    logger.info("Running mandatory pre-push backup for device %s (%s)", hostname, ip_address)
    try:
        from app.services import backup_service
        backup_result = await backup_service.run_backup(
            device_id=device_id,
            tenant_id=tenant_id,
            trigger_type="pre-template-push",
        )
        backup_sha = backup_result["commit_sha"]
        await _update_job(job_id, pre_push_backup_sha=backup_sha)
        logger.info("Pre-push backup complete: %s", backup_sha[:8])
    except Exception as backup_err:
        logger.error("Pre-push backup failed for %s: %s", hostname, backup_err)
        await _update_job(
            job_id, status="failed",
            error_message=f"Pre-push backup failed: {backup_err}",
        )
        return
    # Step 5: SSH to device - install panic-revert, push config
    logger.info(
        "Pushing template to device %s (%s): installing panic-revert and uploading config",
        hostname, ip_address,
    )
    try:
        async with asyncssh.connect(
            ip_address,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            # 5a: Create binary backup on device as revert point
            await conn.run(
                f"/system backup save name={_PRE_PUSH_BACKUP} dont-encrypt=yes",
                check=True,
            )
            logger.debug("Pre-push binary backup saved on device as %s.backup", _PRE_PUSH_BACKUP)
            # 5b: Install panic-revert RouterOS scheduler
            # NOTE(review): interval=90s means the revert fires ~90s after
            # install unless removed in step 8a; presumably the 60s settle
            # window plus the reachability SSH round-trip is expected to stay
            # under 90s — confirm on slow links.
            await conn.run(
                f"/system scheduler add "
                f'name="{_PANIC_REVERT_SCHEDULER}" '
                f"interval=90s "
                f'on-event=":delay 0; /system backup load name={_PRE_PUSH_BACKUP}" '
                f"start-time=startup",
                check=True,
            )
            logger.debug("Panic-revert scheduler installed on device")
            # 5c: Upload rendered template as RSC file via SFTP
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(_TEMPLATE_RSC, "wb") as f:
                    await f.write(rendered_content.encode("utf-8"))
            logger.debug("Uploaded %s to device flash", _TEMPLATE_RSC)
            # 5d: /import the config file
            # check=False: import errors are logged, not fatal here — the
            # reachability check in step 7 is the real commit gate.
            import_result = await conn.run(
                f"/import file={_TEMPLATE_RSC}",
                check=False,
            )
            logger.info(
                "Template import result for device %s: exit_status=%s stdout=%r",
                hostname, import_result.exit_status,
                (import_result.stdout or "")[:200],
            )
            # 5e: Clean up the uploaded RSC file (best-effort)
            try:
                await conn.run(f"/file remove {_TEMPLATE_RSC}", check=True)
            except Exception as cleanup_err:
                logger.warning(
                    "Failed to clean up %s from device %s: %s",
                    _TEMPLATE_RSC, ip_address, cleanup_err,
                )
    except Exception as push_err:
        logger.error(
            "SSH push phase failed for device %s (%s): %s",
            hostname, ip_address, push_err,
        )
        await _update_job(
            job_id, status="failed",
            error_message=f"Config push failed during SSH phase: {push_err}",
        )
        return
    # Step 6: Wait 60s for config to settle
    logger.info("Template pushed to device %s - waiting 60s for config to settle", hostname)
    await asyncio.sleep(60)
    # Step 7: Reachability check
    reachable = await _check_reachability(ip_address, ssh_username, ssh_password)
    if reachable:
        # Step 8a: Device is reachable - remove panic-revert scheduler + cleanup
        logger.info("Device %s (%s) is reachable after push - committing", hostname, ip_address)
        try:
            async with asyncssh.connect(
                ip_address, port=22,
                username=ssh_username, password=ssh_password,
                known_hosts=None, connect_timeout=30,
            ) as conn:
                # check=False: commit must proceed even if cleanup commands fail
                await conn.run(
                    f'/system scheduler remove "{_PANIC_REVERT_SCHEDULER}"',
                    check=False,
                )
                await conn.run(
                    f"/file remove {_PRE_PUSH_BACKUP}.backup",
                    check=False,
                )
        except Exception as cleanup_err:
            logger.warning(
                "Failed to clean up panic-revert scheduler/backup on device %s: %s",
                hostname, cleanup_err,
            )
        await _update_job(
            job_id, status="committed",
            completed_at=datetime.now(timezone.utc),
        )
    else:
        # Step 8b: Device unreachable - RouterOS is auto-reverting
        logger.warning(
            "Device %s (%s) is unreachable after push - panic-revert scheduler "
            "will auto-revert to %s.backup",
            hostname, ip_address, _PRE_PUSH_BACKUP,
        )
        await _update_job(
            job_id, status="reverted",
            error_message="Device unreachable after push; auto-reverted via panic-revert scheduler",
            completed_at=datetime.now(timezone.utc),
        )
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
async def _check_reachability(ip: str, username: str, password: str) -> bool:
    """Return True when the RouterOS device answers a trivial SSH command."""
    try:
        async with asyncssh.connect(
            ip, port=22,
            username=username, password=password,
            known_hosts=None, connect_timeout=30,
        ) as ssh:
            identity = await ssh.run("/system identity print", check=True)
            logger.debug("Reachability check OK for %s: %r", ip, identity.stdout[:50])
            return True
    except Exception as err:
        # Any SSH/command failure counts as unreachable.
        logger.info("Device %s unreachable after push: %s", ip, err)
        return False
async def _update_job(job_id: str, **kwargs) -> None:
    """Update TemplatePushJob fields via raw SQL (background task, no RLS)."""
    nullable_fields = ("error_message", "started_at", "completed_at", "pre_push_backup_sha")
    clauses: list = []
    params: dict = {"job_id": job_id}
    for field, new_value in kwargs.items():
        # Explicit NULL for clearable columns; bind parameter otherwise.
        if new_value is None and field in nullable_fields:
            clauses.append(f"{field} = NULL")
        else:
            bind = f"v_{field}"
            clauses.append(f"{field} = :{bind}")
            params[bind] = new_value
    if not clauses:
        return
    async with AdminAsyncSessionLocal() as session:
        await session.execute(
            text(f"""
                UPDATE template_push_jobs
                SET {', '.join(clauses)}
                WHERE id = CAST(:job_id AS uuid)
            """),
            params,
        )
        await session.commit()

View File

@@ -0,0 +1,564 @@
"""Firmware upgrade orchestration service.
Handles single-device and mass firmware upgrades with:
- Mandatory pre-upgrade config backup
- NPK download and SFTP upload to device
- Reboot trigger and reconnect polling
- Post-upgrade version verification
- Sequential mass rollout with pause-on-failure
- Scheduled upgrades via APScheduler DateTrigger
All DB operations use AdminAsyncSessionLocal to bypass RLS since upgrade
jobs may span multiple tenants and run in background asyncio tasks.
"""
import asyncio
import io
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
import asyncssh
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.services.event_publisher import publish_event
logger = logging.getLogger(__name__)
# Maximum time to wait for a device to reconnect after reboot (seconds)
_RECONNECT_TIMEOUT = 300 # 5 minutes
_RECONNECT_POLL_INTERVAL = 15 # seconds
_INITIAL_WAIT = 60 # Wait before first reconnect attempt (boot cycle)
async def start_upgrade(job_id: str) -> None:
    """Execute a single device firmware upgrade as a background task.

    Lifecycle: pending -> downloading -> uploading -> rebooting -> verifying
    -> completed/failed. Suitable for asyncio.create_task or an APScheduler
    job: it never raises, recording any unexpected error on the job row.
    """
    try:
        await _run_upgrade(job_id)
    except Exception as exc:
        logger.error("Uncaught exception in firmware upgrade %s: %s", job_id, exc, exc_info=True)
        await _update_job(job_id, status="failed", error_message=f"Unexpected error: {exc}")
async def _publish_upgrade_progress(
    tenant_id: str,
    device_id: str,
    job_id: str,
    stage: str,
    target_version: str,
    message: str,
    error: str | None = None,
) -> None:
    """Publish a firmware upgrade progress event to NATS (fire-and-forget).

    The event carries the lifecycle stage plus a human-readable message; an
    'error' field is attached only when a non-empty error string is given.
    """
    event = {
        "event_type": "firmware_progress",
        "tenant_id": tenant_id,
        "device_id": device_id,
        "job_id": job_id,
        "stage": stage,
        "target_version": target_version,
        "message": message,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    if error:
        event["error"] = error
    await publish_event(f"firmware.progress.{tenant_id}.{device_id}", event)
async def _run_upgrade(job_id: str) -> None:
    """Internal upgrade implementation.

    Loads the job + device row, enforces the major-version confirmation gate,
    takes a mandatory pre-upgrade backup, downloads the NPK, uploads it via
    SFTP, reboots the device, polls for reconnect, and verifies the installed
    version. Every failure path records an error_message on the job row and
    publishes a NATS progress event before returning.

    Args:
        job_id: UUID (string) of the firmware_upgrade_jobs row to execute.
    """
    # Step 1: Load job (admin session: runs in a background task, no RLS)
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id, j.device_id, j.tenant_id, j.target_version,
                       j.architecture, j.channel, j.status, j.confirmed_major_upgrade,
                       d.ip_address, d.hostname, d.encrypted_credentials,
                       d.routeros_version, d.encrypted_credentials_transit
                FROM firmware_upgrade_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.id = CAST(:job_id AS uuid)
            """),
            {"job_id": job_id},
        )
        row = result.fetchone()
    if not row:
        logger.error("Upgrade job %s not found", job_id)
        return
    (
        _, device_id, tenant_id, target_version,
        architecture, channel, status, confirmed_major,
        ip_address, hostname, encrypted_credentials,
        current_version, encrypted_credentials_transit,
    ) = row
    device_id = str(device_id)
    tenant_id = str(tenant_id)
    hostname = hostname or ip_address
    # Skip if already running or completed
    if status not in ("pending", "scheduled"):
        logger.info("Upgrade job %s already in status %s — skipping", job_id, status)
        return
    logger.info(
        "Starting firmware upgrade for %s (%s): %s -> %s",
        hostname, ip_address, current_version, target_version,
    )
    # Step 2: Update status to downloading
    await _update_job(job_id, status="downloading", started_at=datetime.now(timezone.utc))
    await _publish_upgrade_progress(tenant_id, device_id, job_id, "downloading", target_version, f"Downloading firmware {target_version} for {hostname}")
    # Step 3: Check major version upgrade confirmation
    if current_version and target_version:
        current_major = current_version.split(".")[0] if current_version else ""
        target_major = target_version.split(".")[0]
        if current_major != target_major and not confirmed_major:
            await _update_job(
                job_id,
                status="failed",
                error_message="Major version upgrade requires explicit confirmation",
            )
            await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Major version upgrade requires explicit confirmation for {hostname}", error="Major version upgrade requires explicit confirmation")
            return
    # Step 4: Mandatory config backup
    logger.info("Running mandatory pre-upgrade backup for %s", hostname)
    try:
        from app.services import backup_service
        backup_result = await backup_service.run_backup(
            device_id=device_id,
            tenant_id=tenant_id,
            trigger_type="pre-upgrade",
        )
        backup_sha = backup_result["commit_sha"]
        await _update_job(job_id, pre_upgrade_backup_sha=backup_sha)
        logger.info("Pre-upgrade backup complete: %s", backup_sha[:8])
    except Exception as backup_err:
        logger.error("Pre-upgrade backup failed for %s: %s", hostname, backup_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Pre-upgrade backup failed: {backup_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Pre-upgrade backup failed for {hostname}", error=str(backup_err))
        return
    # Step 5: Download NPK
    logger.info("Downloading firmware %s for %s/%s", target_version, architecture, channel)
    try:
        from app.services.firmware_service import download_firmware
        npk_path = await download_firmware(architecture, channel, target_version)
        logger.info("Firmware cached at %s", npk_path)
    except Exception as dl_err:
        logger.error("Firmware download failed: %s", dl_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Firmware download failed: {dl_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Firmware download failed for {hostname}", error=str(dl_err))
        return
    # Step 6: Upload NPK to device via SFTP
    await _update_job(job_id, status="uploading")
    await _publish_upgrade_progress(tenant_id, device_id, job_id, "uploading", target_version, f"Uploading firmware to {hostname}")
    # Decrypt device credentials (dual-read: Transit preferred, legacy fallback)
    if not encrypted_credentials_transit and not encrypted_credentials:
        await _update_job(job_id, status="failed", error_message="Device has no stored credentials")
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"No stored credentials for {hostname}", error="Device has no stored credentials")
        return
    try:
        from app.services.crypto import decrypt_credentials_hybrid
        key = settings.get_encryption_key_bytes()
        creds_json = await decrypt_credentials_hybrid(
            encrypted_credentials_transit, encrypted_credentials, tenant_id, key,
        )
        creds = json.loads(creds_json)
        ssh_username = creds.get("username", "")
        ssh_password = creds.get("password", "")
    except Exception as cred_err:
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Failed to decrypt credentials: {cred_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Failed to decrypt credentials for {hostname}", error=str(cred_err))
        return
    try:
        npk_data = Path(npk_path).read_bytes()
        npk_filename = Path(npk_path).name
        async with asyncssh.connect(
            ip_address,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            async with conn.start_sftp_client() as sftp:
                # Upload to device root: RouterOS installs any NPK found there on boot
                async with sftp.open(f"/{npk_filename}", "wb") as f:
                    await f.write(npk_data)
        logger.info("Uploaded %s to %s", npk_filename, hostname)
    except Exception as upload_err:
        logger.error("NPK upload failed for %s: %s", hostname, upload_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"NPK upload failed: {upload_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"NPK upload failed for {hostname}", error=str(upload_err))
        return
    # Step 7: Trigger reboot
    await _update_job(job_id, status="rebooting")
    await _publish_upgrade_progress(tenant_id, device_id, job_id, "rebooting", target_version, f"Rebooting {hostname} for firmware install")
    try:
        async with asyncssh.connect(
            ip_address,
            port=22,
            username=ssh_username,
            password=ssh_password,
            known_hosts=None,
            connect_timeout=30,
        ) as conn:
            # RouterOS will install NPK on boot
            await conn.run("/system reboot", check=False)
        logger.info("Reboot command sent to %s", hostname)
    except Exception as reboot_err:
        # Device may drop connection during reboot — this is expected
        logger.info("Device %s dropped connection after reboot command (expected): %s", hostname, reboot_err)
    # Step 8: Wait for reconnect
    logger.info("Waiting %ds before polling %s for reconnect", _INITIAL_WAIT, hostname)
    await asyncio.sleep(_INITIAL_WAIT)
    reconnected = False
    elapsed = 0
    while elapsed < _RECONNECT_TIMEOUT:
        if await _check_ssh_reachable(ip_address, ssh_username, ssh_password):
            reconnected = True
            break
        await asyncio.sleep(_RECONNECT_POLL_INTERVAL)
        elapsed += _RECONNECT_POLL_INTERVAL
    if not reconnected:
        logger.error("Device %s did not reconnect within %ds", hostname, _RECONNECT_TIMEOUT)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Device did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes after reboot",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Device {hostname} did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes", error="Reconnect timeout")
        return
    # Step 9: Verify upgrade
    await _update_job(job_id, status="verifying")
    await _publish_upgrade_progress(tenant_id, device_id, job_id, "verifying", target_version, f"Verifying firmware version on {hostname}")
    try:
        actual_version = await _get_device_version(ip_address, ssh_username, ssh_password)
        # BUGFIX: the previous check used substring containment
        # (`target_version in actual_version`), which falsely verified e.g.
        # target "7.1" against an actual "7.17 (stable)". Compare the exact
        # version token instead (device reports "7.17 (stable)").
        actual_token = actual_version.split()[0] if actual_version else ""
        if actual_token == target_version:
            logger.info(
                "Firmware upgrade verified for %s: %s",
                hostname, actual_version,
            )
            await _update_job(
                job_id,
                status="completed",
                completed_at=datetime.now(timezone.utc),
            )
            await _publish_upgrade_progress(tenant_id, device_id, job_id, "completed", target_version, f"Firmware upgrade to {target_version} completed on {hostname}")
        else:
            logger.error(
                "Version mismatch for %s: expected %s, got %s",
                hostname, target_version, actual_version,
            )
            await _update_job(
                job_id,
                status="failed",
                error_message=f"Expected {target_version} but got {actual_version}",
            )
            await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Version mismatch on {hostname}: expected {target_version}, got {actual_version}", error=f"Expected {target_version} but got {actual_version}")
    except Exception as verify_err:
        logger.error("Post-upgrade verification failed for %s: %s", hostname, verify_err)
        await _update_job(
            job_id,
            status="failed",
            error_message=f"Post-upgrade verification failed: {verify_err}",
        )
        await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Post-upgrade verification failed for {hostname}", error=str(verify_err))
async def start_mass_upgrade(rollout_group_id: str) -> dict:
    """Run a sequential mass firmware upgrade for one rollout group.

    Jobs are processed one at a time in creation order. The first failure
    pauses every remaining pending/scheduled job in the group.

    Returns:
        Summary dict with completed/failed/failed_device/paused fields.
    """
    async with AdminAsyncSessionLocal() as session:
        result = await session.execute(
            text("""
                SELECT j.id, j.status, d.hostname
                FROM firmware_upgrade_jobs j
                JOIN devices d ON d.id = j.device_id
                WHERE j.rollout_group_id = CAST(:group_id AS uuid)
                ORDER BY j.created_at ASC
            """),
            {"group_id": rollout_group_id},
        )
        jobs = result.fetchall()
    if not jobs:
        logger.warning("No jobs found for rollout group %s", rollout_group_id)
        return {"completed": 0, "failed": 0, "paused": 0}
    done_count = 0
    failing_host = None
    for job_id, job_status, hostname in jobs:
        job_key = str(job_id)
        # Skip anything not awaiting execution; count prior completions.
        if job_status not in ("pending", "scheduled"):
            if job_status == "completed":
                done_count += 1
            continue
        logger.info("Mass rollout: upgrading device %s (job %s)", hostname, job_key)
        await start_upgrade(job_key)
        # Inspect the outcome before moving to the next device.
        async with AdminAsyncSessionLocal() as session:
            result = await session.execute(
                text("SELECT status FROM firmware_upgrade_jobs WHERE id = CAST(:id AS uuid)"),
                {"id": job_key},
            )
            row = result.fetchone()
        if row and row[0] == "completed":
            done_count += 1
        elif row and row[0] == "failed":
            failing_host = hostname
            logger.error("Mass rollout paused: %s failed", hostname)
            break
    paused = 0
    if failing_host:
        # One device failed: park every remaining queued job in the group.
        async with AdminAsyncSessionLocal() as session:
            result = await session.execute(
                text("""
                    UPDATE firmware_upgrade_jobs
                    SET status = 'paused'
                    WHERE rollout_group_id = CAST(:group_id AS uuid)
                    AND status IN ('pending', 'scheduled')
                    RETURNING id
                """),
                {"group_id": rollout_group_id},
            )
            paused = len(result.fetchall())
            await session.commit()
    return {
        "completed": done_count,
        "failed": 1 if failing_host else 0,
        "failed_device": failing_host,
        "paused": paused,
    }
def schedule_upgrade(job_id: str, scheduled_at: datetime) -> None:
    """Register a one-shot APScheduler job that runs the upgrade at *scheduled_at*.

    Uses a DateTrigger on the shared backup scheduler; re-scheduling the same
    job id replaces the previous entry.
    """
    from app.services.backup_scheduler import backup_scheduler
    backup_scheduler.add_job(
        start_upgrade,
        args=[job_id],
        trigger="date",
        run_date=scheduled_at,
        id=f"fw_upgrade_{job_id}",
        name=f"Firmware upgrade {job_id}",
        replace_existing=True,
        max_instances=1,
    )
    logger.info("Scheduled firmware upgrade %s for %s", job_id, scheduled_at)
def schedule_mass_upgrade(rollout_group_id: str, scheduled_at: datetime) -> None:
    """Register a one-shot APScheduler job that starts the mass rollout later.

    Mirrors schedule_upgrade but targets the whole rollout group.
    """
    from app.services.backup_scheduler import backup_scheduler
    backup_scheduler.add_job(
        start_mass_upgrade,
        args=[rollout_group_id],
        trigger="date",
        run_date=scheduled_at,
        id=f"fw_mass_upgrade_{rollout_group_id}",
        name=f"Mass firmware upgrade {rollout_group_id}",
        replace_existing=True,
        max_instances=1,
    )
    logger.info("Scheduled mass firmware upgrade %s for %s", rollout_group_id, scheduled_at)
async def cancel_upgrade(job_id: str) -> None:
    """Cancel a scheduled or pending upgrade.

    Removes the APScheduler entry when one exists, then marks the job as
    failed with a 'Cancelled by operator' message.
    """
    from app.services.backup_scheduler import backup_scheduler
    try:
        backup_scheduler.remove_job(f"fw_upgrade_{job_id}")
    except Exception:
        # Immediate (unscheduled) upgrades have no APScheduler entry.
        pass
    await _update_job(
        job_id,
        status="failed",
        error_message="Cancelled by operator",
        completed_at=datetime.now(timezone.utc),
    )
    logger.info("Upgrade job %s cancelled", job_id)
async def retry_failed_upgrade(job_id: str) -> None:
    """Reset a failed upgrade job to pending and kick off a fresh attempt.

    The previous attempt's error message and timestamps are cleared, then the
    upgrade is re-run as a fire-and-forget background task.
    """
    await _update_job(
        job_id,
        status="pending",
        error_message=None,
        started_at=None,
        completed_at=None,
    )
    asyncio.create_task(start_upgrade(job_id))
    logger.info("Retrying upgrade job %s", job_id)
async def resume_mass_upgrade(rollout_group_id: str) -> None:
    """Resume a paused mass rollout from where it left off.

    Flips every 'paused' job in the group back to 'pending', then restarts
    the sequential processor as a fire-and-forget background task.

    Args:
        rollout_group_id: UUID (string) of the rollout group to resume.
    """
    async with AdminAsyncSessionLocal() as session:
        # The update's result is not inspected; no binding needed (the
        # previous version assigned it to an unused local).
        await session.execute(
            text("""
                UPDATE firmware_upgrade_jobs
                SET status = 'pending'
                WHERE rollout_group_id = CAST(:group_id AS uuid)
                AND status = 'paused'
            """),
            {"group_id": rollout_group_id},
        )
        await session.commit()
    asyncio.create_task(start_mass_upgrade(rollout_group_id))
    logger.info("Resuming mass rollout %s", rollout_group_id)
async def abort_mass_upgrade(rollout_group_id: str) -> int:
    """Mark every remaining queued job in a paused mass rollout as failed.

    Returns:
        The number of jobs that were aborted.
    """
    async with AdminAsyncSessionLocal() as session:
        outcome = await session.execute(
            text("""
                UPDATE firmware_upgrade_jobs
                SET status = 'failed',
                    error_message = 'Aborted by operator',
                    completed_at = NOW()
                WHERE rollout_group_id = CAST(:group_id AS uuid)
                AND status IN ('pending', 'scheduled', 'paused')
                RETURNING id
            """),
            {"group_id": rollout_group_id},
        )
        aborted_count = len(outcome.fetchall())
        await session.commit()
    logger.info("Aborted %d remaining jobs in rollout %s", aborted_count, rollout_group_id)
    return aborted_count
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
async def _update_job(job_id: str, **kwargs) -> None:
    """Update FirmwareUpgradeJob fields.

    Builds a dynamic UPDATE statement from the given keyword arguments.
    Passing None for error_message/started_at/completed_at clears the
    column (SET ... = NULL); every other value is bound as a parameter.
    A call with no kwargs is a no-op.

    Raises:
        ValueError: if a field name is not a plain identifier (field
            names are interpolated into the SQL text, not bound).
    """
    sets: list[str] = []
    params: dict = {"job_id": job_id}
    for key, value in kwargs.items():
        # Column names cannot be bound as parameters, so they land in the
        # SQL string itself — reject anything that is not a plain identifier
        # as defense-in-depth against SQL injection via field names.
        if not key.isidentifier():
            raise ValueError(f"invalid job field name: {key!r}")
        if value is None and key in ("error_message", "started_at", "completed_at"):
            sets.append(f"{key} = NULL")
        else:
            param_name = f"v_{key}"
            sets.append(f"{key} = :{param_name}")
            params[param_name] = value
    if not sets:
        return
    async with AdminAsyncSessionLocal() as session:
        await session.execute(
            text(f"""
                UPDATE firmware_upgrade_jobs
                SET {', '.join(sets)}
                WHERE id = CAST(:job_id AS uuid)
            """),
            params,
        )
        await session.commit()
async def _check_ssh_reachable(ip: str, username: str, password: str) -> bool:
    """Return True if the device answers an SSH command, False otherwise."""
    connect_options = {
        "port": 22,
        "username": username,
        "password": password,
        "known_hosts": None,  # device host keys are not pinned
        "connect_timeout": 15,
    }
    try:
        async with asyncssh.connect(ip, **connect_options) as conn:
            await conn.run("/system identity print", check=True)
    except Exception:
        # Any failure (timeout, auth, command error) means "not reachable".
        return False
    return True
async def _get_device_version(ip: str, username: str, password: str) -> str:
    """Get the current RouterOS version from a device via SSH.

    Returns an empty string when no version line can be found in the
    `/system resource print` output.
    """
    async with asyncssh.connect(
        ip,
        port=22,
        username=username,
        password=password,
        known_hosts=None,
        connect_timeout=30,
    ) as conn:
        result = await conn.run("/system resource print", check=True)
    # Parse version from output: "version: 7.17 (stable)"
    for line in result.stdout.splitlines():
        if "version" not in line.lower():
            continue
        _label, sep, value = line.partition(":")
        if sep:
            return value.strip()
    return ""

View File

@@ -0,0 +1,392 @@
"""WireGuard VPN management service.
Handles key generation, peer management, config file sync, and RouterOS command generation.
"""
import base64
import ipaddress
import json
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
import structlog
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
from cryptography.hazmat.primitives.serialization import (
Encoding,
NoEncryption,
PrivateFormat,
PublicFormat,
)
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.models.device import Device
from app.models.vpn import VpnConfig, VpnPeer
from app.services.crypto import decrypt_credentials, encrypt_credentials, encrypt_credentials_transit
logger = structlog.get_logger(__name__)
# ── Key Generation ──
def generate_wireguard_keypair() -> tuple[str, str]:
    """Generate a WireGuard X25519 keypair. Returns (private_key_b64, public_key_b64)."""
    def _b64(raw: bytes) -> str:
        return base64.b64encode(raw).decode()

    key = X25519PrivateKey.generate()
    raw_private = key.private_bytes(Encoding.Raw, PrivateFormat.Raw, NoEncryption())
    raw_public = key.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw)
    return _b64(raw_private), _b64(raw_public)
def generate_preshared_key() -> str:
    """Generate a WireGuard preshared key (32 random bytes, base64)."""
    raw = os.urandom(32)
    return base64.b64encode(raw).decode()
# ── Config File Management ──
def _get_wg_config_path() -> Path:
"""Return the path to the shared WireGuard config directory."""
return Path(os.getenv("WIREGUARD_CONFIG_PATH", "/data/wireguard"))
async def sync_wireguard_config(db: AsyncSession, tenant_id: uuid.UUID) -> None:
    """Regenerate wg0.conf from database state and write to shared volume.

    Builds an [Interface] section from the tenant's VpnConfig plus one
    [Peer] section per enabled VpnPeer, writes the file to the shared
    WireGuard volume, and touches a .reload flag so the WireGuard
    container picks up the change.

    No-op when VPN is not configured or is disabled for the tenant.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config or not config.is_enabled:
        return
    key_bytes = settings.get_encryption_key_bytes()
    server_private_key = decrypt_credentials(config.server_private_key, key_bytes)
    result = await db.execute(
        select(VpnPeer).where(VpnPeer.tenant_id == tenant_id, VpnPeer.is_enabled.is_(True))
    )
    peers = result.scalars().all()
    # Build wg0.conf
    lines = [
        "[Interface]",
        f"Address = {config.server_address}",
        f"ListenPort = {config.server_port}",
        f"PrivateKey = {server_private_key}",
        "",
    ]
    for peer in peers:
        peer_ip = peer.assigned_ip.split("/")[0]  # strip CIDR for AllowedIPs
        allowed_ips = [f"{peer_ip}/32"]
        if peer.additional_allowed_ips:
            # Comma-separated additional subnets (e.g. site-to-site routing)
            extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()]
            allowed_ips.extend(extra)
        lines.append("[Peer]")
        lines.append(f"PublicKey = {peer.peer_public_key}")
        if peer.preshared_key:
            psk = decrypt_credentials(peer.preshared_key, key_bytes)
            lines.append(f"PresharedKey = {psk}")
        lines.append(f"AllowedIPs = {', '.join(allowed_ips)}")
        lines.append("")
    config_dir = _get_wg_config_path()
    wg_confs_dir = config_dir / "wg_confs"
    wg_confs_dir.mkdir(parents=True, exist_ok=True)
    conf_path = wg_confs_dir / "wg0.conf"
    conf_path.write_text("\n".join(lines))
    # The conf contains the server private key and peer PSKs in plaintext —
    # restrict it to owner-only access (mirrors wg-quick's permission warning).
    conf_path.chmod(0o600)
    # Signal WireGuard container to reload config
    reload_flag = wg_confs_dir / ".reload"
    reload_flag.write_text("1")
    logger.info("wireguard config synced", tenant_id=str(tenant_id), peers=len(peers))
# ── Live Status ──
def read_wg_status() -> dict[str, dict]:
    """Read live WireGuard peer status from the shared volume.

    The WireGuard container writes wg_status.json every 15 seconds
    with output from `wg show wg0 dump`. Returns a dict keyed by
    peer public key with handshake timestamp and transfer stats;
    an empty dict when the file is missing or unparseable.
    """
    path = _get_wg_config_path() / "wg_status.json"
    if not path.exists():
        return {}
    try:
        parsed = json.loads(path.read_text())
        status_by_key: dict[str, dict] = {}
        for entry in parsed:
            status_by_key[entry["public_key"]] = entry
        return status_by_key
    except (json.JSONDecodeError, KeyError, OSError):
        # Treat a half-written or malformed status file as "no data".
        return {}
def get_peer_handshake(wg_status: dict[str, dict], public_key: str) -> Optional[datetime]:
    """Return the last-handshake time (UTC) for a peer, or None if unknown."""
    entry = wg_status.get(public_key) or {}
    ts = entry.get("last_handshake", 0)
    # A zero/absent timestamp means the peer has never completed a handshake.
    if ts and ts > 0:
        return datetime.fromtimestamp(ts, tz=timezone.utc)
    return None
# ── CRUD Operations ──
async def get_vpn_config(db: AsyncSession, tenant_id: uuid.UUID) -> Optional[VpnConfig]:
    """Fetch the tenant's VPN configuration row, or None if VPN is not set up."""
    query = select(VpnConfig).where(VpnConfig.tenant_id == tenant_id)
    return (await db.execute(query)).scalar_one_or_none()
async def setup_vpn(
    db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None
) -> VpnConfig:
    """Initialize VPN for a tenant — generates server keys and creates config.

    Raises:
        ValueError: if a VPN config already exists for the tenant.
    """
    if await get_vpn_config(db, tenant_id) is not None:
        raise ValueError("VPN already configured for this tenant")
    priv_b64, pub_b64 = generate_wireguard_keypair()
    # Private key is stored encrypted at rest; public key is plaintext.
    encrypted_priv = encrypt_credentials(priv_b64, settings.get_encryption_key_bytes())
    new_config = VpnConfig(
        tenant_id=tenant_id,
        server_private_key=encrypted_priv,
        server_public_key=pub_b64,
        endpoint=endpoint,
        is_enabled=True,
    )
    db.add(new_config)
    await db.flush()
    await sync_wireguard_config(db, tenant_id)
    return new_config
async def update_vpn_config(
    db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None, is_enabled: Optional[bool] = None
) -> VpnConfig:
    """Update VPN config settings; a None argument means "leave unchanged".

    Raises:
        ValueError: if the tenant has no VPN config.
    """
    config = await get_vpn_config(db, tenant_id)
    if config is None:
        raise ValueError("VPN not configured for this tenant")
    for attr, value in (("endpoint", endpoint), ("is_enabled", is_enabled)):
        if value is not None:
            setattr(config, attr, value)
    await db.flush()
    await sync_wireguard_config(db, tenant_id)
    return config
async def get_peers(db: AsyncSession, tenant_id: uuid.UUID) -> list[VpnPeer]:
    """List all VPN peers for a tenant, oldest first."""
    stmt = (
        select(VpnPeer)
        .where(VpnPeer.tenant_id == tenant_id)
        .order_by(VpnPeer.created_at)
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
async def _next_available_ip(db: AsyncSession, tenant_id: uuid.UUID, config: VpnConfig) -> str:
    """Allocate the next available IP in the VPN subnet.

    Skips the first host (reserved for the server) and every address
    already assigned to a peer, and returns the chosen host with the
    subnet's own prefix length (e.g. "10.10.0.5/24").

    Raises:
        ValueError: when the subnet has no free host addresses left.
    """
    # Parse subnet: e.g. "10.10.0.0/24" → start from .2 (server is .1)
    network = ipaddress.ip_network(config.subnet, strict=False)
    hosts = list(network.hosts())
    # Get already assigned IPs
    result = await db.execute(select(VpnPeer.assigned_ip).where(VpnPeer.tenant_id == tenant_id))
    used_ips = {row[0].split("/")[0] for row in result.all()}
    used_ips.add(config.server_address.split("/")[0])  # exclude server IP
    for host in hosts[1:]:  # skip .1 (server)
        if str(host) not in used_ips:
            # Fix: use the subnet's actual prefix length instead of a
            # hard-coded /24, so non-/24 VPN subnets get correct CIDRs.
            return f"{host}/{network.prefixlen}"
    raise ValueError("No available IPs in VPN subnet")
async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID, additional_allowed_ips: Optional[str] = None) -> VpnPeer:
    """Add a device as a VPN peer.

    Generates a fresh keypair and preshared key, allocates the next free
    subnet IP, persists the peer, and resyncs wg0.conf.

    Raises:
        ValueError: if VPN is not set up, the device does not exist in
            this tenant, or the device is already enrolled as a peer.
    """
    config = await get_vpn_config(db, tenant_id)
    if config is None:
        raise ValueError("VPN not configured — enable VPN first")
    # Device must exist and belong to this tenant.
    device_row = await db.execute(
        select(Device).where(Device.id == device_id, Device.tenant_id == tenant_id)
    )
    if device_row.scalar_one_or_none() is None:
        raise ValueError("Device not found")
    # A device may only be enrolled once.
    peer_row = await db.execute(select(VpnPeer).where(VpnPeer.device_id == device_id))
    if peer_row.scalar_one_or_none() is not None:
        raise ValueError("Device is already a VPN peer")
    priv_b64, pub_b64 = generate_wireguard_keypair()
    plain_psk = generate_preshared_key()
    key_bytes = settings.get_encryption_key_bytes()
    new_peer = VpnPeer(
        tenant_id=tenant_id,
        device_id=device_id,
        peer_private_key=encrypt_credentials(priv_b64, key_bytes),
        peer_public_key=pub_b64,
        preshared_key=encrypt_credentials(plain_psk, key_bytes),
        assigned_ip=await _next_available_ip(db, tenant_id, config),
        additional_allowed_ips=additional_allowed_ips,
    )
    db.add(new_peer)
    await db.flush()
    await sync_wireguard_config(db, tenant_id)
    return new_peer
async def remove_peer(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> None:
    """Remove a VPN peer and resync the WireGuard config.

    Raises:
        ValueError: if no such peer exists in this tenant.
    """
    lookup = await db.execute(
        select(VpnPeer).where(VpnPeer.id == peer_id, VpnPeer.tenant_id == tenant_id)
    )
    target = lookup.scalar_one_or_none()
    if target is None:
        raise ValueError("Peer not found")
    await db.delete(target)
    await db.flush()
    await sync_wireguard_config(db, tenant_id)
async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.UUID) -> dict:
    """Get the full config for a peer — includes private key for device setup.

    Returns a dict with the decrypted peer keys, server details, and
    ready-to-paste RouterOS commands for configuring the device side.

    Raises:
        ValueError: if VPN is not configured or the peer does not exist.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config:
        raise ValueError("VPN not configured")
    result = await db.execute(
        select(VpnPeer).where(VpnPeer.id == peer_id, VpnPeer.tenant_id == tenant_id)
    )
    peer = result.scalar_one_or_none()
    if not peer:
        raise ValueError("Peer not found")
    key_bytes = settings.get_encryption_key_bytes()
    private_key = decrypt_credentials(peer.peer_private_key, key_bytes)
    psk = decrypt_credentials(peer.preshared_key, key_bytes) if peer.preshared_key else None
    endpoint = config.endpoint or "YOUR_SERVER_IP:51820"
    # Split "host:port" once for both command fields.
    # (Removed dead local peer_ip_no_cidr — it was never used.)
    endpoint_host = endpoint.split(":")[0]
    endpoint_port = endpoint.split(":")[-1]
    routeros_commands = [
        f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key}"',
        f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" '
        f'endpoint-address={endpoint_host} endpoint-port={endpoint_port} '
        f'allowed-address={config.subnet} persistent-keepalive=25'
        + (f' preshared-key="{psk}"' if psk else ""),
        f"/ip address add address={peer.assigned_ip} interface=wg-portal",
    ]
    return {
        "peer_private_key": private_key,
        "peer_public_key": peer.peer_public_key,
        "assigned_ip": peer.assigned_ip,
        "server_public_key": config.server_public_key,
        "server_endpoint": endpoint,
        "allowed_ips": config.subnet,
        "routeros_commands": routeros_commands,
    }
async def onboard_device(
    db: AsyncSession,
    tenant_id: uuid.UUID,
    hostname: str,
    username: str,
    password: str,
) -> dict:
    """Create device + VPN peer in one transaction. Returns device, peer, and RouterOS commands.

    Unlike regular device creation, this skips TCP connectivity checks because
    the VPN tunnel isn't up yet. The device IP is set to the VPN-assigned address.

    Raises:
        ValueError: if VPN is not configured for the tenant, or the
            subnet has no free addresses.
    """
    config = await get_vpn_config(db, tenant_id)
    if not config:
        raise ValueError("VPN not configured — enable VPN first")
    # Allocate VPN IP before creating device
    assigned_ip = await _next_available_ip(db, tenant_id, config)
    vpn_ip_no_cidr = assigned_ip.split("/")[0]
    # Create device with VPN IP (skip TCP check — tunnel not up yet)
    credentials_json = json.dumps({"username": username, "password": password})
    transit_ciphertext = await encrypt_credentials_transit(credentials_json, str(tenant_id))
    device = Device(
        tenant_id=tenant_id,
        hostname=hostname,
        ip_address=vpn_ip_no_cidr,
        api_port=8728,
        api_ssl_port=8729,
        encrypted_credentials_transit=transit_ciphertext,
        status="unknown",
    )
    db.add(device)
    await db.flush()  # assigns device.id needed for the peer FK below
    # Create VPN peer linked to this device
    private_key_b64, public_key_b64 = generate_wireguard_keypair()
    psk = generate_preshared_key()
    key_bytes = settings.get_encryption_key_bytes()
    encrypted_private = encrypt_credentials(private_key_b64, key_bytes)
    encrypted_psk = encrypt_credentials(psk, key_bytes)
    peer = VpnPeer(
        tenant_id=tenant_id,
        device_id=device.id,
        peer_private_key=encrypted_private,
        peer_public_key=public_key_b64,
        preshared_key=encrypted_psk,
        assigned_ip=assigned_ip,
    )
    db.add(peer)
    await db.flush()
    await sync_wireguard_config(db, tenant_id)
    # Generate RouterOS commands. Use the plaintext PSK still in scope
    # instead of re-decrypting the ciphertext we just produced.
    endpoint = config.endpoint or "YOUR_SERVER_IP:51820"
    routeros_commands = [
        f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key_b64}"',
        f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" '
        f'endpoint-address={endpoint.split(":")[0]} endpoint-port={endpoint.split(":")[-1]} '
        f'allowed-address={config.subnet} persistent-keepalive=25'
        f' preshared-key="{psk}"',
        f"/ip address add address={assigned_ip} interface=wg-portal",
    ]
    return {
        "device_id": device.id,
        "peer_id": peer.id,
        "hostname": hostname,
        "assigned_ip": assigned_ip,
        "routeros_commands": routeros_commands,
    }