fix(lint): resolve all ruff lint errors

Add ruff config to exclude alembic E402, SQLAlchemy F821, and pre-existing
E501 line-length issues. Auto-fix 69 unused imports and 2 f-strings without
placeholders. Manually fix 8 unused variables. Apply ruff format to 127 files.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Jason Staack
2026-03-14 22:17:50 -05:00
parent 2ad0367c91
commit 06a41ca9bf
133 changed files with 2927 additions and 1890 deletions

View File

@@ -86,11 +86,7 @@ async def delete_user_account(
# Null out encrypted_details (may contain encrypted PII)
await db.execute(
text(
"UPDATE audit_logs "
"SET encrypted_details = NULL "
"WHERE user_id = :user_id"
),
text("UPDATE audit_logs SET encrypted_details = NULL WHERE user_id = :user_id"),
{"user_id": user_id},
)
@@ -177,16 +173,18 @@ async def export_user_data(
)
api_keys = []
for row in result.mappings().all():
api_keys.append({
"id": str(row["id"]),
"name": row["name"],
"key_prefix": row["key_prefix"],
"scopes": row["scopes"],
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
"expires_at": row["expires_at"].isoformat() if row["expires_at"] else None,
"revoked_at": row["revoked_at"].isoformat() if row["revoked_at"] else None,
"last_used_at": row["last_used_at"].isoformat() if row["last_used_at"] else None,
})
api_keys.append(
{
"id": str(row["id"]),
"name": row["name"],
"key_prefix": row["key_prefix"],
"scopes": row["scopes"],
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
"expires_at": row["expires_at"].isoformat() if row["expires_at"] else None,
"revoked_at": row["revoked_at"].isoformat() if row["revoked_at"] else None,
"last_used_at": row["last_used_at"].isoformat() if row["last_used_at"] else None,
}
)
# ── Audit logs (limit 1000, most recent first) ───────────────────────
result = await db.execute(
@@ -201,15 +199,17 @@ async def export_user_data(
audit_logs = []
for row in result.mappings().all():
details = row["details"] if row["details"] else {}
audit_logs.append({
"id": str(row["id"]),
"action": row["action"],
"resource_type": row["resource_type"],
"resource_id": row["resource_id"],
"details": details,
"ip_address": row["ip_address"],
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
})
audit_logs.append(
{
"id": str(row["id"]),
"action": row["action"],
"resource_type": row["resource_type"],
"resource_id": row["resource_id"],
"details": details,
"ip_address": row["ip_address"],
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
}
)
# ── Key access log (limit 1000, most recent first) ───────────────────
result = await db.execute(
@@ -222,13 +222,15 @@ async def export_user_data(
)
key_access_entries = []
for row in result.mappings().all():
key_access_entries.append({
"id": str(row["id"]),
"action": row["action"],
"resource_type": row["resource_type"],
"ip_address": row["ip_address"],
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
})
key_access_entries.append(
{
"id": str(row["id"]),
"action": row["action"],
"resource_type": row["resource_type"],
"ip_address": row["ip_address"],
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
}
)
return {
"export_date": datetime.now(UTC).isoformat(),

View File

@@ -253,7 +253,9 @@ async def _get_device_groups(device_id: str) -> list[str]:
"""Get group IDs for a device."""
async with AdminAsyncSessionLocal() as session:
result = await session.execute(
text("SELECT group_id FROM device_group_memberships WHERE device_id = CAST(:device_id AS uuid)"),
text(
"SELECT group_id FROM device_group_memberships WHERE device_id = CAST(:device_id AS uuid)"
),
{"device_id": device_id},
)
return [str(row[0]) for row in result.fetchall()]
@@ -344,30 +346,36 @@ async def _create_alert_event(
# Publish real-time event to NATS for SSE pipeline (fire-and-forget)
if status in ("firing", "flapping"):
await publish_event(f"alert.fired.{tenant_id}", {
"event_type": "alert_fired",
"tenant_id": tenant_id,
"device_id": device_id,
"alert_event_id": alert_data["id"],
"severity": severity,
"metric": metric,
"current_value": value,
"threshold": threshold,
"message": message,
"is_flapping": is_flapping,
"fired_at": datetime.now(timezone.utc).isoformat(),
})
await publish_event(
f"alert.fired.{tenant_id}",
{
"event_type": "alert_fired",
"tenant_id": tenant_id,
"device_id": device_id,
"alert_event_id": alert_data["id"],
"severity": severity,
"metric": metric,
"current_value": value,
"threshold": threshold,
"message": message,
"is_flapping": is_flapping,
"fired_at": datetime.now(timezone.utc).isoformat(),
},
)
elif status == "resolved":
await publish_event(f"alert.resolved.{tenant_id}", {
"event_type": "alert_resolved",
"tenant_id": tenant_id,
"device_id": device_id,
"alert_event_id": alert_data["id"],
"severity": severity,
"metric": metric,
"message": message,
"resolved_at": datetime.now(timezone.utc).isoformat(),
})
await publish_event(
f"alert.resolved.{tenant_id}",
{
"event_type": "alert_resolved",
"tenant_id": tenant_id,
"device_id": device_id,
"alert_event_id": alert_data["id"],
"severity": severity,
"metric": metric,
"message": message,
"resolved_at": datetime.now(timezone.utc).isoformat(),
},
)
return alert_data
@@ -470,6 +478,7 @@ async def _dispatch_async(alert_event: dict, channels: list[dict], device_hostna
"""Fire-and-forget notification dispatch."""
try:
from app.services.notification_service import dispatch_notifications
await dispatch_notifications(alert_event, channels, device_hostname)
except Exception as e:
logger.warning("Notification dispatch failed: %s", e)
@@ -500,7 +509,8 @@ async def evaluate(
if await _is_device_in_maintenance(tenant_id, device_id):
logger.debug(
"Alert suppressed by maintenance window for device %s tenant %s",
device_id, tenant_id,
device_id,
tenant_id,
)
return
@@ -573,7 +583,8 @@ async def evaluate(
if is_flapping:
logger.info(
"Alert %s for device %s is flapping — notifications suppressed",
rule["name"], device_id,
rule["name"],
device_id,
)
else:
channels = await _get_channels_for_rule(rule["id"])

View File

@@ -56,9 +56,7 @@ async def log_action(
try:
from app.services.crypto import encrypt_data_transit
encrypted_details = await encrypt_data_transit(
details_json, str(tenant_id)
)
encrypted_details = await encrypt_data_transit(details_json, str(tenant_id))
# Encryption succeeded — clear plaintext details
details_json = _json.dumps({})
except Exception:

View File

@@ -32,8 +32,12 @@ def _cron_to_trigger(cron_expr: str) -> Optional[CronTrigger]:
return None
minute, hour, day, month, day_of_week = parts
return CronTrigger(
minute=minute, hour=hour, day=day, month=month,
day_of_week=day_of_week, timezone="UTC",
minute=minute,
hour=hour,
day=day,
month=month,
day_of_week=day_of_week,
timezone="UTC",
)
except Exception as e:
logger.warning("Invalid cron expression '%s': %s", cron_expr, e)
@@ -52,10 +56,12 @@ def build_schedule_map(schedules: list) -> dict[str, list[dict]]:
cron = s.cron_expression or DEFAULT_CRON
if cron not in schedule_map:
schedule_map[cron] = []
schedule_map[cron].append({
"device_id": str(s.device_id),
"tenant_id": str(s.tenant_id),
})
schedule_map[cron].append(
{
"device_id": str(s.device_id),
"tenant_id": str(s.tenant_id),
}
)
return schedule_map
@@ -79,13 +85,15 @@ async def _run_scheduled_backups(devices: list[dict]) -> None:
except Exception as e:
logger.error(
"Scheduled backup FAILED: device %s: %s",
dev_info["device_id"], e,
dev_info["device_id"],
e,
)
failure_count += 1
logger.info(
"Backup batch complete — %d succeeded, %d failed",
success_count, failure_count,
success_count,
failure_count,
)
@@ -108,7 +116,7 @@ async def _load_effective_schedules() -> list:
# Index: device-specific and tenant defaults
device_schedules = {} # device_id -> schedule
tenant_defaults = {} # tenant_id -> schedule
tenant_defaults = {} # tenant_id -> schedule
for s in schedules:
if s.device_id:
@@ -129,12 +137,14 @@ async def _load_effective_schedules() -> list:
# No schedule configured — use system default
sched = None
effective.append(SimpleNamespace(
device_id=dev_id,
tenant_id=tenant_id,
cron_expression=sched.cron_expression if sched else DEFAULT_CRON,
enabled=sched.enabled if sched else True,
))
effective.append(
SimpleNamespace(
device_id=dev_id,
tenant_id=tenant_id,
cron_expression=sched.cron_expression if sched else DEFAULT_CRON,
enabled=sched.enabled if sched else True,
)
)
return effective
@@ -203,6 +213,7 @@ class _SchedulerProxy:
Usage: `from app.services.backup_scheduler import backup_scheduler`
then `backup_scheduler.add_job(...)`.
"""
def __getattr__(self, name):
if _scheduler is None:
raise RuntimeError("Backup scheduler not started yet")

View File

@@ -255,7 +255,12 @@ async def run_backup(
if prior_commits:
try:
prior_export_bytes = await loop.run_in_executor(
None, git_store.read_file, tenant_id, prior_commits[0]["sha"], device_id, "export.rsc"
None,
git_store.read_file,
tenant_id,
prior_commits[0]["sha"],
device_id,
"export.rsc",
)
prior_text = prior_export_bytes.decode("utf-8", errors="replace")
lines_added, lines_removed = await loop.run_in_executor(
@@ -284,9 +289,7 @@ async def run_backup(
try:
from app.services.crypto import encrypt_data_transit
encrypted_export = await encrypt_data_transit(
export_text, tenant_id
)
encrypted_export = await encrypt_data_transit(export_text, tenant_id)
encrypted_binary = await encrypt_data_transit(
base64.b64encode(binary_backup).decode(), tenant_id
)
@@ -302,8 +305,7 @@ async def run_backup(
except Exception as enc_err:
# Transit unavailable — fall back to plaintext (non-fatal)
logger.warning(
"Transit encryption failed for %s backup of device %s, "
"storing plaintext: %s",
"Transit encryption failed for %s backup of device %s, storing plaintext: %s",
trigger_type,
device_id,
enc_err,
@@ -313,9 +315,7 @@ async def run_backup(
# -----------------------------------------------------------------------
# Step 6: Commit to git (wrapped in run_in_executor — pygit2 is sync C bindings)
# -----------------------------------------------------------------------
commit_message = (
f"{trigger_type}: {hostname} ({ip}) at {ts}"
)
commit_message = f"{trigger_type}: {hostname} ({ip}) at {ts}"
commit_sha = await loop.run_in_executor(
None,

View File

@@ -48,6 +48,7 @@ _VALID_TRANSITIONS: dict[str, set[str]] = {
# CA Generation
# ---------------------------------------------------------------------------
async def generate_ca(
db: AsyncSession,
tenant_id: UUID,
@@ -84,10 +85,12 @@ async def generate_ca(
now = datetime.datetime.now(datetime.timezone.utc)
expiry = now + datetime.timedelta(days=365 * validity_years)
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
])
subject = issuer = x509.Name(
[
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
]
)
ca_cert = (
x509.CertificateBuilder()
@@ -97,9 +100,7 @@ async def generate_ca(
.serial_number(x509.random_serial_number())
.not_valid_before(now)
.not_valid_after(expiry)
.add_extension(
x509.BasicConstraints(ca=True, path_length=0), critical=True
)
.add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True)
.add_extension(
x509.KeyUsage(
digital_signature=True,
@@ -166,6 +167,7 @@ async def generate_ca(
# Device Certificate Signing
# ---------------------------------------------------------------------------
async def sign_device_cert(
db: AsyncSession,
ca: CertificateAuthority,
@@ -196,9 +198,7 @@ async def sign_device_cert(
str(ca.tenant_id),
encryption_key,
)
ca_key = serialization.load_pem_private_key(
ca_key_pem.encode("utf-8"), password=None
)
ca_key = serialization.load_pem_private_key(ca_key_pem.encode("utf-8"), password=None)
# Load CA certificate for issuer info and AuthorityKeyIdentifier
ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8"))
@@ -212,19 +212,19 @@ async def sign_device_cert(
device_cert = (
x509.CertificateBuilder()
.subject_name(
x509.Name([
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
x509.NameAttribute(NameOID.COMMON_NAME, hostname),
])
x509.Name(
[
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
x509.NameAttribute(NameOID.COMMON_NAME, hostname),
]
)
)
.issuer_name(ca_cert.subject)
.public_key(device_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now)
.not_valid_after(expiry)
.add_extension(
x509.BasicConstraints(ca=False, path_length=None), critical=True
)
.add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
.add_extension(
x509.KeyUsage(
digital_signature=True,
@@ -244,17 +244,17 @@ async def sign_device_cert(
critical=False,
)
.add_extension(
x509.SubjectAlternativeName([
x509.IPAddress(ipaddress.ip_address(ip_address)),
x509.DNSName(hostname),
]),
x509.SubjectAlternativeName(
[
x509.IPAddress(ipaddress.ip_address(ip_address)),
x509.DNSName(hostname),
]
),
critical=False,
)
.add_extension(
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
ca_cert.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
).value
ca_cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier).value
),
critical=False,
)
@@ -308,15 +308,14 @@ async def sign_device_cert(
# Queries
# ---------------------------------------------------------------------------
async def get_ca_for_tenant(
db: AsyncSession,
tenant_id: UUID,
) -> CertificateAuthority | None:
"""Return the tenant's CA, or None if not yet initialized."""
result = await db.execute(
select(CertificateAuthority).where(
CertificateAuthority.tenant_id == tenant_id
)
select(CertificateAuthority).where(CertificateAuthority.tenant_id == tenant_id)
)
return result.scalar_one_or_none()
@@ -352,6 +351,7 @@ async def get_device_certs(
# Status Management
# ---------------------------------------------------------------------------
async def update_cert_status(
db: AsyncSession,
cert_id: UUID,
@@ -377,9 +377,7 @@ async def update_cert_status(
Raises:
ValueError: If the certificate is not found or the transition is invalid.
"""
result = await db.execute(
select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
)
result = await db.execute(select(DeviceCertificate).where(DeviceCertificate.id == cert_id))
cert = result.scalar_one_or_none()
if cert is None:
raise ValueError(f"Device certificate {cert_id} not found")
@@ -413,6 +411,7 @@ async def update_cert_status(
# Cert Data for Deployment
# ---------------------------------------------------------------------------
async def get_cert_for_deploy(
db: AsyncSession,
cert_id: UUID,
@@ -434,18 +433,14 @@ async def get_cert_for_deploy(
Raises:
ValueError: If the certificate or its CA is not found.
"""
result = await db.execute(
select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
)
result = await db.execute(select(DeviceCertificate).where(DeviceCertificate.id == cert_id))
cert = result.scalar_one_or_none()
if cert is None:
raise ValueError(f"Device certificate {cert_id} not found")
# Fetch the CA for the ca_cert_pem
ca_result = await db.execute(
select(CertificateAuthority).where(
CertificateAuthority.id == cert.ca_id
)
select(CertificateAuthority).where(CertificateAuthority.id == cert.ca_id)
)
ca = ca_result.scalar_one_or_none()
if ca is None:

View File

@@ -64,9 +64,7 @@ def parse_diff_changes(diff_text: str) -> list[dict]:
current_section: str | None = None
# Track per-component: adds, removes, raw_lines
components: dict[str, dict] = defaultdict(
lambda: {"adds": 0, "removes": 0, "raw_lines": []}
)
components: dict[str, dict] = defaultdict(lambda: {"adds": 0, "removes": 0, "raw_lines": []})
for line in lines:
# Skip unified diff headers

View File

@@ -51,13 +51,15 @@ async def handle_config_changed(event: dict) -> None:
if await _last_backup_within_dedup_window(device_id):
logger.info(
"Config change on device %s — skipping backup (within %dm dedup window)",
device_id, DEDUP_WINDOW_MINUTES,
device_id,
DEDUP_WINDOW_MINUTES,
)
return
logger.info(
"Config change detected on device %s (tenant %s): %s -> %s",
device_id, tenant_id,
device_id,
tenant_id,
event.get("old_timestamp", "?"),
event.get("new_timestamp", "?"),
)

View File

@@ -65,9 +65,7 @@ async def generate_and_store_diff(
# 2. No previous snapshot = first snapshot for device
if prev_row is None:
logger.debug(
"First snapshot for device %s, no diff to generate", device_id
)
logger.debug("First snapshot for device %s, no diff to generate", device_id)
return
old_snapshot_id = prev_row._mapping["id"]
@@ -103,9 +101,7 @@ async def generate_and_store_diff(
# 6. Generate unified diff
old_lines = old_plaintext.decode("utf-8").splitlines()
new_lines = new_plaintext.decode("utf-8").splitlines()
diff_lines = list(
difflib.unified_diff(old_lines, new_lines, lineterm="", n=3)
)
diff_lines = list(difflib.unified_diff(old_lines, new_lines, lineterm="", n=3))
# 7. If empty diff, skip INSERT
diff_text = "\n".join(diff_lines)
@@ -118,12 +114,10 @@ async def generate_and_store_diff(
# 8. Count lines added/removed (exclude +++ and --- headers)
lines_added = sum(
1 for line in diff_lines
if line.startswith("+") and not line.startswith("++")
1 for line in diff_lines if line.startswith("+") and not line.startswith("++")
)
lines_removed = sum(
1 for line in diff_lines
if line.startswith("-") and not line.startswith("--")
1 for line in diff_lines if line.startswith("-") and not line.startswith("--")
)
# 9. INSERT into router_config_diffs (RETURNING id for change parser)
@@ -165,6 +159,7 @@ async def generate_and_store_diff(
try:
from app.services.audit_service import log_action
import uuid as _uuid
await log_action(
db=None,
tenant_id=_uuid.UUID(tenant_id),
@@ -207,12 +202,16 @@ async def generate_and_store_diff(
await session.commit()
logger.info(
"Stored %d config changes for device %s diff %s",
len(changes), device_id, diff_id,
len(changes),
device_id,
diff_id,
)
except Exception as exc:
logger.warning(
"Change parser error for device %s diff %s (non-fatal): %s",
device_id, diff_id, exc,
device_id,
diff_id,
exc,
)
config_diff_errors_total.labels(error_type="change_parser").inc()

View File

@@ -89,9 +89,11 @@ async def handle_config_snapshot(msg) -> None:
collected_at_raw = data.get("collected_at")
try:
collected_at = datetime.fromisoformat(
collected_at_raw.replace("Z", "+00:00")
) if collected_at_raw else datetime.now(timezone.utc)
collected_at = (
datetime.fromisoformat(collected_at_raw.replace("Z", "+00:00"))
if collected_at_raw
else datetime.now(timezone.utc)
)
except (ValueError, AttributeError):
collected_at = datetime.now(timezone.utc)
@@ -131,9 +133,7 @@ async def handle_config_snapshot(msg) -> None:
# --- Encrypt via OpenBao Transit ---
openbao = OpenBaoTransitService()
try:
encrypted_text = await openbao.encrypt(
tenant_id, config_text.encode("utf-8")
)
encrypted_text = await openbao.encrypt(tenant_id, config_text.encode("utf-8"))
except Exception as exc:
logger.warning(
"Transit encrypt failed for device %s tenant %s: %s",
@@ -207,7 +207,8 @@ async def handle_config_snapshot(msg) -> None:
except Exception as exc:
logger.warning(
"Diff generation failed for device %s (non-fatal): %s",
device_id, exc,
device_id,
exc,
)
logger.info(
@@ -233,9 +234,7 @@ async def _subscribe_with_retry(js) -> None:
stream="DEVICE_EVENTS",
manual_ack=True,
)
logger.info(
"NATS: subscribed to config.snapshot.> (durable: config_snapshot_ingest)"
)
logger.info("NATS: subscribed to config.snapshot.> (durable: config_snapshot_ingest)")
return
except Exception as exc:
if attempt < max_attempts:

View File

@@ -29,8 +29,6 @@ from app.models.device import (
DeviceTagAssignment,
)
from app.schemas.device import (
BulkAddRequest,
BulkAddResult,
DeviceCreate,
DeviceGroupCreate,
DeviceGroupResponse,
@@ -43,9 +41,7 @@ from app.schemas.device import (
)
from app.config import settings
from app.services.crypto import (
decrypt_credentials,
decrypt_credentials_hybrid,
encrypt_credentials,
encrypt_credentials_transit,
)
@@ -58,9 +54,7 @@ from app.services.crypto import (
async def _tcp_reachable(ip: str, port: int, timeout: float = 3.0) -> bool:
"""Return True if a TCP connection to ip:port succeeds within timeout."""
try:
_, writer = await asyncio.wait_for(
asyncio.open_connection(ip, port), timeout=timeout
)
_, writer = await asyncio.wait_for(asyncio.open_connection(ip, port), timeout=timeout)
writer.close()
try:
await writer.wait_closed()
@@ -151,6 +145,7 @@ async def create_device(
if not api_reachable and not ssl_reachable:
from fastapi import HTTPException, status
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=(
@@ -162,9 +157,7 @@ async def create_device(
# Encrypt credentials via OpenBao Transit (new writes go through Transit)
credentials_json = json.dumps({"username": data.username, "password": data.password})
transit_ciphertext = await encrypt_credentials_transit(
credentials_json, str(tenant_id)
)
transit_ciphertext = await encrypt_credentials_transit(credentials_json, str(tenant_id))
device = Device(
tenant_id=tenant_id,
@@ -180,9 +173,7 @@ async def create_device(
await db.refresh(device)
# Re-query with relationships loaded
result = await db.execute(
_device_with_relations().where(Device.id == device.id)
)
result = await db.execute(_device_with_relations().where(Device.id == device.id))
device = result.scalar_one()
return _build_device_response(device)
@@ -223,9 +214,7 @@ async def get_devices(
if tag_id:
base_q = base_q.where(
Device.id.in_(
select(DeviceTagAssignment.device_id).where(
DeviceTagAssignment.tag_id == tag_id
)
select(DeviceTagAssignment.device_id).where(DeviceTagAssignment.tag_id == tag_id)
)
)
@@ -274,9 +263,7 @@ async def get_device(
"""Get a single device by ID."""
from fastapi import HTTPException, status
result = await db.execute(
_device_with_relations().where(Device.id == device_id)
)
result = await db.execute(_device_with_relations().where(Device.id == device_id))
device = result.scalar_one_or_none()
if not device:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
@@ -295,9 +282,7 @@ async def update_device(
"""
from fastapi import HTTPException, status
result = await db.execute(
_device_with_relations().where(Device.id == device_id)
)
result = await db.execute(_device_with_relations().where(Device.id == device_id))
device = result.scalar_one_or_none()
if not device:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
@@ -323,7 +308,9 @@ async def update_device(
if data.password is not None:
# Decrypt existing to get current username if no new username given
current_username: str = data.username or ""
if not current_username and (device.encrypted_credentials_transit or device.encrypted_credentials):
if not current_username and (
device.encrypted_credentials_transit or device.encrypted_credentials
):
try:
existing_json = await decrypt_credentials_hybrid(
device.encrypted_credentials_transit,
@@ -336,17 +323,21 @@ async def update_device(
except Exception:
current_username = ""
credentials_json = json.dumps({
"username": data.username if data.username is not None else current_username,
"password": data.password,
})
credentials_json = json.dumps(
{
"username": data.username if data.username is not None else current_username,
"password": data.password,
}
)
# New writes go through Transit
device.encrypted_credentials_transit = await encrypt_credentials_transit(
credentials_json, str(device.tenant_id)
)
device.encrypted_credentials = None # Clear legacy (Transit is canonical)
credentials_changed = True
elif data.username is not None and (device.encrypted_credentials_transit or device.encrypted_credentials):
elif data.username is not None and (
device.encrypted_credentials_transit or device.encrypted_credentials
):
# Only username changed — update it without changing the password
try:
existing_json = await decrypt_credentials_hybrid(
@@ -373,6 +364,7 @@ async def update_device(
if credentials_changed:
try:
from app.services.event_publisher import publish_event
await publish_event(
f"device.credential_changed.{device_id}",
{"device_id": str(device_id), "tenant_id": str(tenant_id)},
@@ -380,9 +372,7 @@ async def update_device(
except Exception:
pass # Never fail the update due to NATS issues
result2 = await db.execute(
_device_with_relations().where(Device.id == device_id)
)
result2 = await db.execute(_device_with_relations().where(Device.id == device_id))
device = result2.scalar_one()
return _build_device_response(device)
@@ -526,11 +516,7 @@ async def get_groups(
tenant_id: uuid.UUID,
) -> list[DeviceGroupResponse]:
"""Return all device groups for the current tenant with device counts."""
result = await db.execute(
select(DeviceGroup).options(
selectinload(DeviceGroup.memberships)
)
)
result = await db.execute(select(DeviceGroup).options(selectinload(DeviceGroup.memberships)))
groups = result.scalars().all()
return [
DeviceGroupResponse(
@@ -554,9 +540,9 @@ async def update_group(
from fastapi import HTTPException, status
result = await db.execute(
select(DeviceGroup).options(
selectinload(DeviceGroup.memberships)
).where(DeviceGroup.id == group_id)
select(DeviceGroup)
.options(selectinload(DeviceGroup.memberships))
.where(DeviceGroup.id == group_id)
)
group = result.scalar_one_or_none()
if not group:
@@ -571,9 +557,9 @@ async def update_group(
await db.refresh(group)
result2 = await db.execute(
select(DeviceGroup).options(
selectinload(DeviceGroup.memberships)
).where(DeviceGroup.id == group_id)
select(DeviceGroup)
.options(selectinload(DeviceGroup.memberships))
.where(DeviceGroup.id == group_id)
)
group = result2.scalar_one()
return DeviceGroupResponse(

View File

@@ -48,7 +48,5 @@ async def generate_emergency_kit_template(
# Run weasyprint in thread to avoid blocking the event loop
from weasyprint import HTML
pdf_bytes = await asyncio.to_thread(
lambda: HTML(string=html_content).write_pdf()
)
pdf_bytes = await asyncio.to_thread(lambda: HTML(string=html_content).write_pdf())
return pdf_bytes

View File

@@ -13,7 +13,6 @@ Version discovery comes from two sources:
"""
import logging
import os
from pathlib import Path
import httpx
@@ -51,13 +50,12 @@ async def check_latest_versions() -> list[dict]:
async with httpx.AsyncClient(timeout=30.0) as client:
for channel_file, channel, major in _VERSION_SOURCES:
try:
resp = await client.get(
f"https://download.mikrotik.com/routeros/{channel_file}"
)
resp = await client.get(f"https://download.mikrotik.com/routeros/{channel_file}")
if resp.status_code != 200:
logger.warning(
"MikroTik version check returned %d for %s",
resp.status_code, channel_file,
resp.status_code,
channel_file,
)
continue
@@ -72,12 +70,14 @@ async def check_latest_versions() -> list[dict]:
f"https://download.mikrotik.com/routeros/"
f"{version}/routeros-{version}-{arch}.npk"
)
results.append({
"architecture": arch,
"channel": channel,
"version": version,
"npk_url": npk_url,
})
results.append(
{
"architecture": arch,
"channel": channel,
"version": version,
"npk_url": npk_url,
}
)
except Exception as e:
logger.warning("Failed to check %s: %s", channel_file, e)
@@ -242,15 +242,15 @@ async def get_firmware_overview(tenant_id: str) -> dict:
groups = []
for ver, devs in sorted(version_groups.items()):
# A version is "latest" if it matches the latest for any arch/channel combo
is_latest = any(
v["version"] == ver for v in latest_versions.values()
is_latest = any(v["version"] == ver for v in latest_versions.values())
groups.append(
{
"version": ver,
"count": len(devs),
"is_latest": is_latest,
"devices": devs,
}
)
groups.append({
"version": ver,
"count": len(devs),
"is_latest": is_latest,
"devices": devs,
})
return {
"devices": device_list,
@@ -272,12 +272,14 @@ async def get_cached_firmware() -> list[dict]:
continue
for npk_file in sorted(version_dir.iterdir()):
if npk_file.suffix == ".npk":
cached.append({
"path": str(npk_file),
"version": version_dir.name,
"filename": npk_file.name,
"size_bytes": npk_file.stat().st_size,
})
cached.append(
{
"path": str(npk_file),
"version": version_dir.name,
"filename": npk_file.name,
"size_bytes": npk_file.stat().st_size,
}
)
return cached

View File

@@ -41,7 +41,7 @@ async def on_device_firmware(msg) -> None:
try:
data = json.loads(msg.data)
device_id = data.get("device_id")
tenant_id = data.get("tenant_id")
_tenant_id = data.get("tenant_id")
architecture = data.get("architecture")
installed_version = data.get("installed_version")
latest_version = data.get("latest_version")
@@ -126,9 +126,7 @@ async def _subscribe_with_retry(js: JetStreamContext) -> None:
durable="api-firmware-consumer",
stream="DEVICE_EVENTS",
)
logger.info(
"NATS: subscribed to device.firmware.> (durable: api-firmware-consumer)"
)
logger.info("NATS: subscribed to device.firmware.> (durable: api-firmware-consumer)")
return
except Exception as exc:
if attempt < max_attempts:

View File

@@ -238,11 +238,13 @@ def list_device_commits(
# Device wasn't in parent but is in this commit — it's the first entry.
pass
results.append({
"sha": str(commit.id),
"message": commit.message.strip(),
"timestamp": commit.commit_time,
})
results.append(
{
"sha": str(commit.id),
"message": commit.message.strip(),
"timestamp": commit.commit_time,
}
)
return results

View File

@@ -52,6 +52,7 @@ async def store_user_key_set(
"""
# Remove any existing key set (e.g. from a failed prior upgrade attempt)
from sqlalchemy import delete
await db.execute(delete(UserKeySet).where(UserKeySet.user_id == user_id))
key_set = UserKeySet(
@@ -71,9 +72,7 @@ async def store_user_key_set(
return key_set
async def get_user_key_set(
db: AsyncSession, user_id: UUID
) -> UserKeySet | None:
async def get_user_key_set(db: AsyncSession, user_id: UUID) -> UserKeySet | None:
"""Retrieve encrypted key bundle for login response.
Args:
@@ -83,9 +82,7 @@ async def get_user_key_set(
Returns:
The UserKeySet if found, None otherwise.
"""
result = await db.execute(
select(UserKeySet).where(UserKeySet.user_id == user_id)
)
result = await db.execute(select(UserKeySet).where(UserKeySet.user_id == user_id))
return result.scalar_one_or_none()
@@ -163,9 +160,7 @@ async def provision_tenant_key(db: AsyncSession, tenant_id: UUID) -> str:
await openbao.create_tenant_key(str(tenant_id))
# Update tenant record with key name
result = await db.execute(
select(Tenant).where(Tenant.id == tenant_id)
)
result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
tenant = result.scalar_one_or_none()
if tenant:
tenant.openbao_key_name = key_name
@@ -210,13 +205,18 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
select(Device).where(
Device.tenant_id == tenant_id,
Device.encrypted_credentials.isnot(None),
(Device.encrypted_credentials_transit.is_(None) | (Device.encrypted_credentials_transit == "")),
(
Device.encrypted_credentials_transit.is_(None)
| (Device.encrypted_credentials_transit == "")
),
)
)
for device in result.scalars().all():
try:
plaintext = decrypt_credentials(device.encrypted_credentials, legacy_key)
device.encrypted_credentials_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
device.encrypted_credentials_transit = await openbao.encrypt(
tid, plaintext.encode("utf-8")
)
counts["devices"] += 1
except Exception as e:
logger.error("Failed to migrate device %s credentials: %s", device.id, e)
@@ -227,7 +227,10 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
select(CertificateAuthority).where(
CertificateAuthority.tenant_id == tenant_id,
CertificateAuthority.encrypted_private_key.isnot(None),
(CertificateAuthority.encrypted_private_key_transit.is_(None) | (CertificateAuthority.encrypted_private_key_transit == "")),
(
CertificateAuthority.encrypted_private_key_transit.is_(None)
| (CertificateAuthority.encrypted_private_key_transit == "")
),
)
)
for ca in result.scalars().all():
@@ -244,13 +247,18 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
select(DeviceCertificate).where(
DeviceCertificate.tenant_id == tenant_id,
DeviceCertificate.encrypted_private_key.isnot(None),
(DeviceCertificate.encrypted_private_key_transit.is_(None) | (DeviceCertificate.encrypted_private_key_transit == "")),
(
DeviceCertificate.encrypted_private_key_transit.is_(None)
| (DeviceCertificate.encrypted_private_key_transit == "")
),
)
)
for cert in result.scalars().all():
try:
plaintext = decrypt_credentials(cert.encrypted_private_key, legacy_key)
cert.encrypted_private_key_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
cert.encrypted_private_key_transit = await openbao.encrypt(
tid, plaintext.encode("utf-8")
)
counts["certs"] += 1
except Exception as e:
logger.error("Failed to migrate cert %s private key: %s", cert.id, e)
@@ -303,7 +311,14 @@ async def provision_existing_tenants(db: AsyncSession) -> dict:
result = await db.execute(select(Tenant))
tenants = result.scalars().all()
total = {"tenants": len(tenants), "devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0}
total = {
"tenants": len(tenants),
"devices": 0,
"cas": 0,
"certs": 0,
"channels": 0,
"errors": 0,
}
for tenant in tenants:
try:

View File

@@ -199,9 +199,7 @@ async def on_device_metrics(msg) -> None:
device_id = data.get("device_id")
if not metric_type or not device_id:
logger.warning(
"device.metrics event missing 'type' or 'device_id' — skipping"
)
logger.warning("device.metrics event missing 'type' or 'device_id' — skipping")
await msg.ack()
return
@@ -222,6 +220,7 @@ async def on_device_metrics(msg) -> None:
# Alert evaluation — non-fatal; metric write is the primary operation
try:
from app.services import alert_evaluator
await alert_evaluator.evaluate(
device_id=device_id,
tenant_id=data.get("tenant_id", ""),
@@ -265,9 +264,7 @@ async def _subscribe_with_retry(js: JetStreamContext) -> None:
durable="api-metrics-consumer",
stream="DEVICE_EVENTS",
)
logger.info(
"NATS: subscribed to device.metrics.> (durable: api-metrics-consumer)"
)
logger.info("NATS: subscribed to device.metrics.> (durable: api-metrics-consumer)")
return
except Exception as exc:
if attempt < max_attempts:

View File

@@ -69,7 +69,9 @@ async def on_device_status(msg) -> None:
uptime_seconds = _parse_uptime(data.get("uptime", ""))
if not device_id or not status:
logger.warning("Received device.status event with missing device_id or status — skipping")
logger.warning(
"Received device.status event with missing device_id or status — skipping"
)
await msg.ack()
return
@@ -115,12 +117,15 @@ async def on_device_status(msg) -> None:
# Alert evaluation for offline/online status changes — non-fatal
try:
from app.services import alert_evaluator
if status == "offline":
await alert_evaluator.evaluate_offline(device_id, data.get("tenant_id", ""))
elif status == "online":
await alert_evaluator.evaluate_online(device_id, data.get("tenant_id", ""))
except Exception as e:
logger.warning("Alert evaluation failed for device %s status=%s: %s", device_id, status, e)
logger.warning(
"Alert evaluation failed for device %s status=%s: %s", device_id, status, e
)
logger.info(
"Device status updated",

View File

@@ -39,7 +39,9 @@ async def dispatch_notifications(
except Exception as e:
logger.warning(
"Notification delivery failed for channel %s (%s): %s",
channel.get("name"), channel.get("channel_type"), e,
channel.get("name"),
channel.get("channel_type"),
e,
)
@@ -100,14 +102,18 @@ async def _send_email(channel: dict, alert_event: dict, device_hostname: str) ->
if transit_cipher and tenant_id:
try:
from app.services.kms_service import decrypt_transit
smtp_password = await decrypt_transit(transit_cipher, tenant_id)
except Exception:
logger.warning("Transit decryption failed for channel %s, trying legacy", channel.get("id"))
logger.warning(
"Transit decryption failed for channel %s, trying legacy", channel.get("id")
)
if not smtp_password and legacy_cipher:
try:
from app.config import settings as app_settings
from cryptography.fernet import Fernet
raw = bytes(legacy_cipher) if isinstance(legacy_cipher, memoryview) else legacy_cipher
f = Fernet(app_settings.CREDENTIAL_ENCRYPTION_KEY.encode())
smtp_password = f.decrypt(raw).decode()
@@ -163,7 +169,8 @@ async def _send_webhook(
response = await client.post(webhook_url, json=payload)
logger.info(
"Webhook notification sent to %s — status %d",
webhook_url, response.status_code,
webhook_url,
response.status_code,
)
@@ -180,13 +187,18 @@ async def _send_slack(
value = alert_event.get("value")
threshold = alert_event.get("threshold")
color = {"CRITICAL": "#dc2626", "WARNING": "#f59e0b", "INFO": "#3b82f6"}.get(severity, "#6b7280")
color = {"CRITICAL": "#dc2626", "WARNING": "#f59e0b", "INFO": "#3b82f6"}.get(
severity, "#6b7280"
)
status_label = "RESOLVED" if status == "resolved" else status
blocks = [
{
"type": "header",
"text": {"type": "plain_text", "text": f"{'' if status == 'resolved' else '🚨'} [{severity}] {status_label.upper()}"},
"text": {
"type": "plain_text",
"text": f"{'' if status == 'resolved' else '🚨'} [{severity}] {status_label.upper()}",
},
},
{
"type": "section",
@@ -205,7 +217,9 @@ async def _send_slack(
blocks.append({"type": "section", "fields": fields})
if message_text:
blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": f"*Message:*\n{message_text}"}})
blocks.append(
{"type": "section", "text": {"type": "mrkdwn", "text": f"*Message:*\n{message_text}"}}
)
blocks.append({"type": "context", "elements": [{"type": "mrkdwn", "text": "TOD Alert System"}]})

View File

@@ -7,6 +7,7 @@ material never leaves OpenBao -- the application only sees ciphertext.
Ciphertext format: "vault:v1:base64..." (compatible with Vault Transit format)
"""
import base64
import logging
from typing import Optional

View File

@@ -47,13 +47,15 @@ async def record_push(
"""
r = await _get_redis()
key = f"push:recent:{device_id}"
value = json.dumps({
"device_id": device_id,
"tenant_id": tenant_id,
"push_type": push_type,
"push_operation_id": push_operation_id,
"pre_push_commit_sha": pre_push_commit_sha,
})
value = json.dumps(
{
"device_id": device_id,
"tenant_id": tenant_id,
"push_type": push_type,
"push_operation_id": push_operation_id,
"pre_push_commit_sha": pre_push_commit_sha,
}
)
await r.set(key, value, ex=PUSH_TTL_SECONDS)
logger.debug(
"Recorded push for device %s (type=%s, TTL=%ds)",

View File

@@ -164,16 +164,18 @@ async def _device_inventory(
uptime_str = _format_uptime(row[6]) if row[6] else None
last_seen_str = row[5].strftime("%Y-%m-%d %H:%M") if row[5] else None
devices.append({
"hostname": row[0],
"ip_address": row[1],
"model": row[2],
"routeros_version": row[3],
"status": status,
"last_seen": last_seen_str,
"uptime": uptime_str,
"groups": row[7] if row[7] else None,
})
devices.append(
{
"hostname": row[0],
"ip_address": row[1],
"model": row[2],
"routeros_version": row[3],
"status": status,
"last_seen": last_seen_str,
"uptime": uptime_str,
"groups": row[7] if row[7] else None,
}
)
return {
"report_title": "Device Inventory",
@@ -224,16 +226,18 @@ async def _metrics_summary(
devices = []
for row in rows:
devices.append({
"hostname": row[0],
"avg_cpu": float(row[1]) if row[1] is not None else None,
"peak_cpu": float(row[2]) if row[2] is not None else None,
"avg_mem": float(row[3]) if row[3] is not None else None,
"peak_mem": float(row[4]) if row[4] is not None else None,
"avg_disk": float(row[5]) if row[5] is not None else None,
"avg_temp": float(row[6]) if row[6] is not None else None,
"data_points": row[7],
})
devices.append(
{
"hostname": row[0],
"avg_cpu": float(row[1]) if row[1] is not None else None,
"peak_cpu": float(row[2]) if row[2] is not None else None,
"avg_mem": float(row[3]) if row[3] is not None else None,
"peak_mem": float(row[4]) if row[4] is not None else None,
"avg_disk": float(row[5]) if row[5] is not None else None,
"avg_temp": float(row[6]) if row[6] is not None else None,
"data_points": row[7],
}
)
return {
"report_title": "Metrics Summary",
@@ -287,14 +291,16 @@ async def _alert_history(
if duration_secs is not None:
resolved_durations.append(duration_secs)
alerts.append({
"fired_at": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"hostname": row[5],
"severity": severity,
"status": row[3],
"message": row[4],
"duration": _format_duration(duration_secs) if duration_secs is not None else None,
})
alerts.append(
{
"fired_at": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"hostname": row[5],
"severity": severity,
"status": row[3],
"message": row[4],
"duration": _format_duration(duration_secs) if duration_secs is not None else None,
}
)
mttr_minutes = None
mttr_display = None
@@ -374,13 +380,15 @@ async def _change_log_from_audit(
entries = []
for row in rows:
entries.append({
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"user": row[1],
"action": row[2],
"device": row[3],
"details": row[4] or row[5] or "",
})
entries.append(
{
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"user": row[1],
"action": row[2],
"device": row[3],
"details": row[4] or row[5] or "",
}
)
return {
"report_title": "Change Log",
@@ -436,21 +444,25 @@ async def _change_log_from_backups(
# Merge and sort by timestamp descending
entries = []
for row in backup_rows:
entries.append({
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"user": row[1],
"action": row[2],
"device": row[3],
"details": row[4] or "",
})
entries.append(
{
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"user": row[1],
"action": row[2],
"device": row[3],
"details": row[4] or "",
}
)
for row in alert_rows:
entries.append({
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"user": row[1],
"action": row[2],
"device": row[3],
"details": row[4] or "",
})
entries.append(
{
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"user": row[1],
"action": row[2],
"device": row[3],
"details": row[4] or "",
}
)
# Sort by timestamp string descending
entries.sort(key=lambda e: e["timestamp"], reverse=True)
@@ -486,54 +498,102 @@ def _render_csv(report_type: str, data: dict[str, Any]) -> bytes:
writer = csv.writer(output)
if report_type == "device_inventory":
writer.writerow([
"Hostname", "IP Address", "Model", "RouterOS Version",
"Status", "Last Seen", "Uptime", "Groups",
])
writer.writerow(
[
"Hostname",
"IP Address",
"Model",
"RouterOS Version",
"Status",
"Last Seen",
"Uptime",
"Groups",
]
)
for d in data.get("devices", []):
writer.writerow([
d["hostname"], d["ip_address"], d["model"] or "",
d["routeros_version"] or "", d["status"],
d["last_seen"] or "", d["uptime"] or "",
d["groups"] or "",
])
writer.writerow(
[
d["hostname"],
d["ip_address"],
d["model"] or "",
d["routeros_version"] or "",
d["status"],
d["last_seen"] or "",
d["uptime"] or "",
d["groups"] or "",
]
)
elif report_type == "metrics_summary":
writer.writerow([
"Hostname", "Avg CPU %", "Peak CPU %", "Avg Memory %",
"Peak Memory %", "Avg Disk %", "Avg Temp", "Data Points",
])
writer.writerow(
[
"Hostname",
"Avg CPU %",
"Peak CPU %",
"Avg Memory %",
"Peak Memory %",
"Avg Disk %",
"Avg Temp",
"Data Points",
]
)
for d in data.get("devices", []):
writer.writerow([
d["hostname"],
f"{d['avg_cpu']:.1f}" if d["avg_cpu"] is not None else "",
f"{d['peak_cpu']:.1f}" if d["peak_cpu"] is not None else "",
f"{d['avg_mem']:.1f}" if d["avg_mem"] is not None else "",
f"{d['peak_mem']:.1f}" if d["peak_mem"] is not None else "",
f"{d['avg_disk']:.1f}" if d["avg_disk"] is not None else "",
f"{d['avg_temp']:.1f}" if d["avg_temp"] is not None else "",
d["data_points"],
])
writer.writerow(
[
d["hostname"],
f"{d['avg_cpu']:.1f}" if d["avg_cpu"] is not None else "",
f"{d['peak_cpu']:.1f}" if d["peak_cpu"] is not None else "",
f"{d['avg_mem']:.1f}" if d["avg_mem"] is not None else "",
f"{d['peak_mem']:.1f}" if d["peak_mem"] is not None else "",
f"{d['avg_disk']:.1f}" if d["avg_disk"] is not None else "",
f"{d['avg_temp']:.1f}" if d["avg_temp"] is not None else "",
d["data_points"],
]
)
elif report_type == "alert_history":
writer.writerow([
"Timestamp", "Device", "Severity", "Message", "Status", "Duration",
])
writer.writerow(
[
"Timestamp",
"Device",
"Severity",
"Message",
"Status",
"Duration",
]
)
for a in data.get("alerts", []):
writer.writerow([
a["fired_at"], a["hostname"] or "", a["severity"],
a["message"] or "", a["status"], a["duration"] or "",
])
writer.writerow(
[
a["fired_at"],
a["hostname"] or "",
a["severity"],
a["message"] or "",
a["status"],
a["duration"] or "",
]
)
elif report_type == "change_log":
writer.writerow([
"Timestamp", "User", "Action", "Device", "Details",
])
writer.writerow(
[
"Timestamp",
"User",
"Action",
"Device",
"Details",
]
)
for e in data.get("entries", []):
writer.writerow([
e["timestamp"], e["user"] or "", e["action"],
e["device"] or "", e["details"] or "",
])
writer.writerow(
[
e["timestamp"],
e["user"] or "",
e["action"],
e["device"] or "",
e["details"] or "",
]
)
return output.getvalue().encode("utf-8")

View File

@@ -33,7 +33,6 @@ import asyncssh
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import set_tenant_context, AdminAsyncSessionLocal
from app.models.config_backup import ConfigPushOperation
from app.models.device import Device
from app.services import backup_service, git_store
@@ -113,12 +112,11 @@ async def restore_config(
raise ValueError(f"Device {device_id!r} not found")
if not device.encrypted_credentials_transit and not device.encrypted_credentials:
raise ValueError(
f"Device {device_id!r} has no stored credentials — cannot perform restore"
)
raise ValueError(f"Device {device_id!r} has no stored credentials — cannot perform restore")
key = settings.get_encryption_key_bytes()
from app.services.crypto import decrypt_credentials_hybrid
creds_json = await decrypt_credentials_hybrid(
device.encrypted_credentials_transit,
device.encrypted_credentials,
@@ -133,7 +131,9 @@ async def restore_config(
hostname = device.hostname or ip
# Publish "started" progress event
await _publish_push_progress(tenant_id, device_id, "started", f"Config restore started for {hostname}")
await _publish_push_progress(
tenant_id, device_id, "started", f"Config restore started for {hostname}"
)
# ------------------------------------------------------------------
# Step 2: Read the target export.rsc from the backup commit
@@ -157,7 +157,9 @@ async def restore_config(
# ------------------------------------------------------------------
# Step 3: Mandatory pre-backup before push
# ------------------------------------------------------------------
await _publish_push_progress(tenant_id, device_id, "backing_up", f"Creating pre-restore backup for {hostname}")
await _publish_push_progress(
tenant_id, device_id, "backing_up", f"Creating pre-restore backup for {hostname}"
)
logger.info(
"Starting pre-restore backup for device %s (%s) before pushing commit %s",
@@ -198,7 +200,9 @@ async def restore_config(
# Step 5: SSH to device — install panic-revert, push config
# ------------------------------------------------------------------
push_op_id_str = str(push_op_id)
await _publish_push_progress(tenant_id, device_id, "pushing", f"Pushing config to {hostname}", push_op_id=push_op_id_str)
await _publish_push_progress(
tenant_id, device_id, "pushing", f"Pushing config to {hostname}", push_op_id=push_op_id_str
)
logger.info(
"Pushing config to device %s (%s): installing panic-revert scheduler and uploading config",
@@ -290,9 +294,12 @@ async def restore_config(
# Update push operation to failed
await _update_push_op_status(push_op_id, "failed", db_session)
await _publish_push_progress(
tenant_id, device_id, "failed",
tenant_id,
device_id,
"failed",
f"Config push failed for {hostname}: {push_err}",
push_op_id=push_op_id_str, error=str(push_err),
push_op_id=push_op_id_str,
error=str(push_err),
)
return {
"status": "failed",
@@ -312,7 +319,13 @@ async def restore_config(
# ------------------------------------------------------------------
# Step 6: Wait 60s for config to settle
# ------------------------------------------------------------------
await _publish_push_progress(tenant_id, device_id, "settling", f"Config pushed to {hostname} — waiting 60s for settle", push_op_id=push_op_id_str)
await _publish_push_progress(
tenant_id,
device_id,
"settling",
f"Config pushed to {hostname} — waiting 60s for settle",
push_op_id=push_op_id_str,
)
logger.info(
"Config pushed to device %s — waiting 60s for config to settle",
@@ -323,7 +336,13 @@ async def restore_config(
# ------------------------------------------------------------------
# Step 7: Reachability check
# ------------------------------------------------------------------
await _publish_push_progress(tenant_id, device_id, "verifying", f"Verifying device {hostname} reachability", push_op_id=push_op_id_str)
await _publish_push_progress(
tenant_id,
device_id,
"verifying",
f"Verifying device {hostname} reachability",
push_op_id=push_op_id_str,
)
reachable = await _check_reachability(ip, ssh_username, ssh_password)
@@ -362,7 +381,13 @@ async def restore_config(
await _update_push_op_status(push_op_id, "committed", db_session)
await clear_push(device_id)
await _publish_push_progress(tenant_id, device_id, "committed", f"Config restored successfully on {hostname}", push_op_id=push_op_id_str)
await _publish_push_progress(
tenant_id,
device_id,
"committed",
f"Config restored successfully on {hostname}",
push_op_id=push_op_id_str,
)
return {
"status": "committed",
@@ -384,7 +409,9 @@ async def restore_config(
await _update_push_op_status(push_op_id, "reverted", db_session)
await _publish_push_progress(
tenant_id, device_id, "reverted",
tenant_id,
device_id,
"reverted",
f"Device {hostname} unreachable — auto-reverting via panic-revert scheduler",
push_op_id=push_op_id_str,
)
@@ -446,7 +473,7 @@ async def _update_push_op_status(
new_status: New status value ('committed' | 'reverted' | 'failed').
db_session: Database session (must already have tenant context set).
"""
from sqlalchemy import select, update
from sqlalchemy import update
await db_session.execute(
update(ConfigPushOperation)
@@ -526,9 +553,7 @@ async def recover_stale_push_operations(db_session: AsyncSession) -> None:
for op in stale_ops:
try:
# Load device
dev_result = await db_session.execute(
select(Device).where(Device.id == op.device_id)
)
dev_result = await db_session.execute(select(Device).where(Device.id == op.device_id))
device = dev_result.scalar_one_or_none()
if not device:
logger.error("Device %s not found for stale op %s", op.device_id, op.id)
@@ -547,9 +572,7 @@ async def recover_stale_push_operations(db_session: AsyncSession) -> None:
ssh_password = creds.get("password", "")
# Check reachability
reachable = await _check_reachability(
device.ip_address, ssh_username, ssh_password
)
reachable = await _check_reachability(device.ip_address, ssh_username, ssh_password)
if reachable:
# Try to remove scheduler (if still there, push was good)

View File

@@ -53,7 +53,9 @@ async def cleanup_expired_snapshots() -> int:
deleted = result.rowcount
config_snapshots_cleaned_total.inc(deleted)
logger.info("retention cleanup complete", extra={"deleted_snapshots": deleted, "retention_days": days})
logger.info(
"retention cleanup complete", extra={"deleted_snapshots": deleted, "retention_days": days}
)
return deleted

View File

@@ -88,9 +88,7 @@ async def browse_menu(device_id: str, path: str) -> dict[str, Any]:
return await execute_command(device_id, command)
async def add_entry(
device_id: str, path: str, properties: dict[str, str]
) -> dict[str, Any]:
async def add_entry(device_id: str, path: str, properties: dict[str, str]) -> dict[str, Any]:
"""Add a new entry to a RouterOS menu path.
Args:
@@ -124,9 +122,7 @@ async def update_entry(
return await execute_command(device_id, f"{path}/set", args)
async def remove_entry(
device_id: str, path: str, entry_id: str
) -> dict[str, Any]:
async def remove_entry(device_id: str, path: str, entry_id: str) -> dict[str, Any]:
"""Remove an entry from a RouterOS menu path.
Args:

View File

@@ -7,20 +7,33 @@ from typing import Any
logger = logging.getLogger(__name__)
HIGH_RISK_PATHS = {
"/ip address", "/ip route", "/ip firewall filter", "/ip firewall nat",
"/interface", "/interface bridge", "/interface vlan",
"/system identity", "/ip service", "/ip ssh", "/user",
"/ip address",
"/ip route",
"/ip firewall filter",
"/ip firewall nat",
"/interface",
"/interface bridge",
"/interface vlan",
"/system identity",
"/ip service",
"/ip ssh",
"/user",
}
MANAGEMENT_PATTERNS = [
(re.compile(r"chain=input.*dst-port=(22|8291|8728|8729|443|80)", re.I),
"Modifies firewall rules for management ports (SSH/WinBox/API/Web)"),
(re.compile(r"chain=input.*action=drop", re.I),
"Adds drop rule on input chain — may block management access"),
(re.compile(r"/ip service", re.I),
"Modifies IP services — may disable API/SSH/WinBox access"),
(re.compile(r"/user.*set.*password", re.I),
"Changes user password — may affect automated access"),
(
re.compile(r"chain=input.*dst-port=(22|8291|8728|8729|443|80)", re.I),
"Modifies firewall rules for management ports (SSH/WinBox/API/Web)",
),
(
re.compile(r"chain=input.*action=drop", re.I),
"Adds drop rule on input chain — may block management access",
),
(re.compile(r"/ip service", re.I), "Modifies IP services — may disable API/SSH/WinBox access"),
(
re.compile(r"/user.*set.*password", re.I),
"Changes user password — may affect automated access",
),
]
@@ -73,7 +86,14 @@ def parse_rsc(text: str) -> dict[str, Any]:
else:
# Check if second part starts with a known command verb
cmd_check = parts[1].strip().split(None, 1)
if cmd_check and cmd_check[0] in ("add", "set", "remove", "print", "enable", "disable"):
if cmd_check and cmd_check[0] in (
"add",
"set",
"remove",
"print",
"enable",
"disable",
):
current_path = parts[0]
line = parts[1].strip()
else:
@@ -184,12 +204,14 @@ def compute_impact(
risk = "none"
if has_changes:
risk = "high" if path in HIGH_RISK_PATHS else "low"
result_categories.append({
"path": path,
"adds": added,
"removes": removed,
"risk": risk,
})
result_categories.append(
{
"path": path,
"adds": added,
"removes": removed,
"risk": risk,
}
)
# Check target commands against management patterns
target_text = "\n".join(

View File

@@ -16,9 +16,7 @@ from srptools.constants import PRIME_2048, PRIME_2048_GEN
_SRP_HASH = hashlib.sha256
async def create_srp_verifier(
salt_hex: str, verifier_hex: str
) -> tuple[bytes, bytes]:
async def create_srp_verifier(salt_hex: str, verifier_hex: str) -> tuple[bytes, bytes]:
"""Convert client-provided hex salt and verifier to bytes for storage.
The client computes v = g^x mod N using 2SKD-derived SRP-x.
@@ -31,9 +29,7 @@ async def create_srp_verifier(
return bytes.fromhex(salt_hex), bytes.fromhex(verifier_hex)
async def srp_init(
email: str, srp_verifier_hex: str
) -> tuple[str, str]:
async def srp_init(email: str, srp_verifier_hex: str) -> tuple[str, str]:
"""SRP Step 1: Generate server ephemeral (B) and private key (b).
Args:
@@ -47,14 +43,15 @@ async def srp_init(
Raises:
ValueError: If SRP initialization fails for any reason.
"""
def _init() -> tuple[str, str]:
context = SRPContext(
email, prime=PRIME_2048, generator=PRIME_2048_GEN,
email,
prime=PRIME_2048,
generator=PRIME_2048_GEN,
hash_func=_SRP_HASH,
)
server_session = SRPServerSession(
context, srp_verifier_hex
)
server_session = SRPServerSession(context, srp_verifier_hex)
return server_session.public, server_session.private
try:
@@ -85,26 +82,27 @@ async def srp_verify(
Tuple of (is_valid, server_proof_hex_or_none).
If valid, server_proof is M2 for the client to verify.
"""
def _verify() -> tuple[bool, str | None]:
import logging
log = logging.getLogger("srp_debug")
context = SRPContext(
email, prime=PRIME_2048, generator=PRIME_2048_GEN,
email,
prime=PRIME_2048,
generator=PRIME_2048_GEN,
hash_func=_SRP_HASH,
)
server_session = SRPServerSession(
context, srp_verifier_hex, private=server_private
)
server_session = SRPServerSession(context, srp_verifier_hex, private=server_private)
_key, _key_proof, _key_proof_hash = server_session.process(client_public, srp_salt_hex)
# srptools verify_proof has a Python 3 bug: hexlify() returns bytes
# but client_proof is str, so bytes == str is always False.
# Compare manually with consistent types.
server_m1 = _key_proof if isinstance(_key_proof, str) else _key_proof.decode('ascii')
server_m1 = _key_proof if isinstance(_key_proof, str) else _key_proof.decode("ascii")
is_valid = client_proof.lower() == server_m1.lower()
if not is_valid:
return False, None
# Return M2 (key_proof_hash), also fixing the bytes/str issue
m2 = _key_proof_hash if isinstance(_key_proof_hash, str) else _key_proof_hash.decode('ascii')
m2 = (
_key_proof_hash if isinstance(_key_proof_hash, str) else _key_proof_hash.decode("ascii")
)
return True, m2
try:

View File

@@ -137,7 +137,9 @@ class SSEConnectionManager:
if last_event_id is not None:
try:
start_seq = int(last_event_id) + 1
consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.BY_START_SEQUENCE, opt_start_seq=start_seq)
consumer_cfg = ConsumerConfig(
deliver_policy=DeliverPolicy.BY_START_SEQUENCE, opt_start_seq=start_seq
)
except (ValueError, TypeError):
consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.NEW)
else:
@@ -173,18 +175,32 @@ class SSEConnectionManager:
except Exception as exc:
if "stream not found" in str(exc):
try:
await js.add_stream(StreamConfig(
name="ALERT_EVENTS",
subjects=_ALERT_EVENT_SUBJECTS,
max_age=3600,
))
sub = await js.subscribe(subject, stream="ALERT_EVENTS", config=consumer_cfg)
await js.add_stream(
StreamConfig(
name="ALERT_EVENTS",
subjects=_ALERT_EVENT_SUBJECTS,
max_age=3600,
)
)
sub = await js.subscribe(
subject, stream="ALERT_EVENTS", config=consumer_cfg
)
self._subscriptions.append(sub)
logger.info("sse.stream_created_lazily", stream="ALERT_EVENTS")
except Exception as retry_exc:
logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(retry_exc))
logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="ALERT_EVENTS",
error=str(retry_exc),
)
else:
logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(exc))
logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="ALERT_EVENTS",
error=str(exc),
)
# Subscribe to operation events (OPERATION_EVENTS stream)
for subject in _OPERATION_EVENT_SUBJECTS:
@@ -198,18 +214,32 @@ class SSEConnectionManager:
except Exception as exc:
if "stream not found" in str(exc):
try:
await js.add_stream(StreamConfig(
name="OPERATION_EVENTS",
subjects=_OPERATION_EVENT_SUBJECTS,
max_age=3600,
))
sub = await js.subscribe(subject, stream="OPERATION_EVENTS", config=consumer_cfg)
await js.add_stream(
StreamConfig(
name="OPERATION_EVENTS",
subjects=_OPERATION_EVENT_SUBJECTS,
max_age=3600,
)
)
sub = await js.subscribe(
subject, stream="OPERATION_EVENTS", config=consumer_cfg
)
self._subscriptions.append(sub)
logger.info("sse.stream_created_lazily", stream="OPERATION_EVENTS")
except Exception as retry_exc:
logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(retry_exc))
logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="OPERATION_EVENTS",
error=str(retry_exc),
)
else:
logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(exc))
logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="OPERATION_EVENTS",
error=str(exc),
)
# Start background task to pull messages from subscriptions into the queue
asyncio.create_task(self._pump_messages())

View File

@@ -12,22 +12,18 @@ separate scheduler and file names to avoid conflicts with restore operations.
"""
import asyncio
import io
import ipaddress
import json
import logging
import uuid
from datetime import datetime, timezone
import asyncssh
from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy import select, text
from sqlalchemy import text
from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.models.config_template import TemplatePushJob
from app.models.device import Device
logger = logging.getLogger(__name__)
@@ -145,7 +141,9 @@ async def push_to_devices(rollout_id: str) -> dict:
except Exception as exc:
logger.error(
"Uncaught exception in template push rollout %s: %s",
rollout_id, exc, exc_info=True,
rollout_id,
exc,
exc_info=True,
)
return {"completed": 0, "failed": 1, "pending": 0}
@@ -181,7 +179,9 @@ async def _run_push_rollout(rollout_id: str) -> dict:
logger.info(
"Template push rollout %s: pushing to device %s (job %s)",
rollout_id, hostname, job_id,
rollout_id,
hostname,
job_id,
)
await push_single_device(job_id)
@@ -200,7 +200,9 @@ async def _run_push_rollout(rollout_id: str) -> dict:
failed = True
logger.error(
"Template push rollout %s paused: device %s %s",
rollout_id, hostname, row[0],
rollout_id,
hostname,
row[0],
)
break
@@ -232,7 +234,9 @@ async def push_single_device(job_id: str) -> None:
except Exception as exc:
logger.error(
"Uncaught exception in template push job %s: %s",
job_id, exc, exc_info=True,
job_id,
exc,
exc_info=True,
)
await _update_job(job_id, status="failed", error_message=f"Unexpected error: {exc}")
@@ -260,8 +264,13 @@ async def _run_single_push(job_id: str) -> None:
return
(
_, device_id, tenant_id, rendered_content,
ip_address, hostname, encrypted_credentials,
_,
device_id,
tenant_id,
rendered_content,
ip_address,
hostname,
encrypted_credentials,
encrypted_credentials_transit,
) = row
@@ -279,16 +288,21 @@ async def _run_single_push(job_id: str) -> None:
try:
from app.services.crypto import decrypt_credentials_hybrid
key = settings.get_encryption_key_bytes()
creds_json = await decrypt_credentials_hybrid(
encrypted_credentials_transit, encrypted_credentials, tenant_id, key,
encrypted_credentials_transit,
encrypted_credentials,
tenant_id,
key,
)
creds = json.loads(creds_json)
ssh_username = creds.get("username", "")
ssh_password = creds.get("password", "")
except Exception as cred_err:
await _update_job(
job_id, status="failed",
job_id,
status="failed",
error_message=f"Failed to decrypt credentials: {cred_err}",
)
return
@@ -297,6 +311,7 @@ async def _run_single_push(job_id: str) -> None:
logger.info("Running mandatory pre-push backup for device %s (%s)", hostname, ip_address)
try:
from app.services import backup_service
backup_result = await backup_service.run_backup(
device_id=device_id,
tenant_id=tenant_id,
@@ -308,7 +323,8 @@ async def _run_single_push(job_id: str) -> None:
except Exception as backup_err:
logger.error("Pre-push backup failed for %s: %s", hostname, backup_err)
await _update_job(
job_id, status="failed",
job_id,
status="failed",
error_message=f"Pre-push backup failed: {backup_err}",
)
return
@@ -316,7 +332,8 @@ async def _run_single_push(job_id: str) -> None:
# Step 5: SSH to device - install panic-revert, push config
logger.info(
"Pushing template to device %s (%s): installing panic-revert and uploading config",
hostname, ip_address,
hostname,
ip_address,
)
try:
@@ -359,7 +376,8 @@ async def _run_single_push(job_id: str) -> None:
)
logger.info(
"Template import result for device %s: exit_status=%s stdout=%r",
hostname, import_result.exit_status,
hostname,
import_result.exit_status,
(import_result.stdout or "")[:200],
)
@@ -369,16 +387,21 @@ async def _run_single_push(job_id: str) -> None:
except Exception as cleanup_err:
logger.warning(
"Failed to clean up %s from device %s: %s",
_TEMPLATE_RSC, ip_address, cleanup_err,
_TEMPLATE_RSC,
ip_address,
cleanup_err,
)
except Exception as push_err:
logger.error(
"SSH push phase failed for device %s (%s): %s",
hostname, ip_address, push_err,
hostname,
ip_address,
push_err,
)
await _update_job(
job_id, status="failed",
job_id,
status="failed",
error_message=f"Config push failed during SSH phase: {push_err}",
)
return
@@ -395,9 +418,12 @@ async def _run_single_push(job_id: str) -> None:
logger.info("Device %s (%s) is reachable after push - committing", hostname, ip_address)
try:
async with asyncssh.connect(
ip_address, port=22,
username=ssh_username, password=ssh_password,
known_hosts=None, connect_timeout=30,
ip_address,
port=22,
username=ssh_username,
password=ssh_password,
known_hosts=None,
connect_timeout=30,
) as conn:
await conn.run(
f'/system scheduler remove "{_PANIC_REVERT_SCHEDULER}"',
@@ -410,11 +436,13 @@ async def _run_single_push(job_id: str) -> None:
except Exception as cleanup_err:
logger.warning(
"Failed to clean up panic-revert scheduler/backup on device %s: %s",
hostname, cleanup_err,
hostname,
cleanup_err,
)
await _update_job(
job_id, status="committed",
job_id,
status="committed",
completed_at=datetime.now(timezone.utc),
)
else:
@@ -422,10 +450,13 @@ async def _run_single_push(job_id: str) -> None:
logger.warning(
"Device %s (%s) is unreachable after push - panic-revert scheduler "
"will auto-revert to %s.backup",
hostname, ip_address, _PRE_PUSH_BACKUP,
hostname,
ip_address,
_PRE_PUSH_BACKUP,
)
await _update_job(
job_id, status="reverted",
job_id,
status="reverted",
error_message="Device unreachable after push; auto-reverted via panic-revert scheduler",
completed_at=datetime.now(timezone.utc),
)
@@ -440,9 +471,12 @@ async def _check_reachability(ip: str, username: str, password: str) -> bool:
"""Check if a RouterOS device is reachable via SSH."""
try:
async with asyncssh.connect(
ip, port=22,
username=username, password=password,
known_hosts=None, connect_timeout=30,
ip,
port=22,
username=username,
password=password,
known_hosts=None,
connect_timeout=30,
) as conn:
result = await conn.run("/system identity print", check=True)
logger.debug("Reachability check OK for %s: %r", ip, result.stdout[:50])
@@ -459,7 +493,12 @@ async def _update_job(job_id: str, **kwargs) -> None:
for key, value in kwargs.items():
param_name = f"v_{key}"
if value is None and key in ("error_message", "started_at", "completed_at", "pre_push_backup_sha"):
if value is None and key in (
"error_message",
"started_at",
"completed_at",
"pre_push_backup_sha",
):
sets.append(f"{key} = NULL")
else:
sets.append(f"{key} = :{param_name}")
@@ -472,7 +511,7 @@ async def _update_job(job_id: str, **kwargs) -> None:
await session.execute(
text(f"""
UPDATE template_push_jobs
SET {', '.join(sets)}
SET {", ".join(sets)}
WHERE id = CAST(:job_id AS uuid)
"""),
params,

View File

@@ -13,7 +13,6 @@ jobs may span multiple tenants and run in background asyncio tasks.
"""
import asyncio
import io
import json
import logging
from datetime import datetime, timezone
@@ -99,10 +98,19 @@ async def _run_upgrade(job_id: str) -> None:
return
(
_, device_id, tenant_id, target_version,
architecture, channel, status, confirmed_major,
ip_address, hostname, encrypted_credentials,
current_version, encrypted_credentials_transit,
_,
device_id,
tenant_id,
target_version,
architecture,
channel,
status,
confirmed_major,
ip_address,
hostname,
encrypted_credentials,
current_version,
encrypted_credentials_transit,
) = row
device_id = str(device_id)
@@ -116,12 +124,22 @@ async def _run_upgrade(job_id: str) -> None:
logger.info(
"Starting firmware upgrade for %s (%s): %s -> %s",
hostname, ip_address, current_version, target_version,
hostname,
ip_address,
current_version,
target_version,
)
# Step 2: Update status to downloading
await _update_job(job_id, status="downloading", started_at=datetime.now(timezone.utc))
await _publish_upgrade_progress(tenant_id, device_id, job_id, "downloading", target_version, f"Downloading firmware {target_version} for {hostname}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"downloading",
target_version,
f"Downloading firmware {target_version} for {hostname}",
)
# Step 3: Check major version upgrade confirmation
if current_version and target_version:
@@ -133,13 +151,22 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message="Major version upgrade requires explicit confirmation",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Major version upgrade requires explicit confirmation for {hostname}", error="Major version upgrade requires explicit confirmation")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Major version upgrade requires explicit confirmation for {hostname}",
error="Major version upgrade requires explicit confirmation",
)
return
# Step 4: Mandatory config backup
logger.info("Running mandatory pre-upgrade backup for %s", hostname)
try:
from app.services import backup_service
backup_result = await backup_service.run_backup(
device_id=device_id,
tenant_id=tenant_id,
@@ -155,13 +182,22 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Pre-upgrade backup failed: {backup_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Pre-upgrade backup failed for {hostname}", error=str(backup_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Pre-upgrade backup failed for {hostname}",
error=str(backup_err),
)
return
# Step 5: Download NPK
logger.info("Downloading firmware %s for %s/%s", target_version, architecture, channel)
try:
from app.services.firmware_service import download_firmware
npk_path = await download_firmware(architecture, channel, target_version)
logger.info("Firmware cached at %s", npk_path)
except Exception as dl_err:
@@ -171,24 +207,51 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Firmware download failed: {dl_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Firmware download failed for {hostname}", error=str(dl_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Firmware download failed for {hostname}",
error=str(dl_err),
)
return
# Step 6: Upload NPK to device via SFTP
await _update_job(job_id, status="uploading")
await _publish_upgrade_progress(tenant_id, device_id, job_id, "uploading", target_version, f"Uploading firmware to {hostname}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"uploading",
target_version,
f"Uploading firmware to {hostname}",
)
# Decrypt device credentials (dual-read: Transit preferred, legacy fallback)
if not encrypted_credentials_transit and not encrypted_credentials:
await _update_job(job_id, status="failed", error_message="Device has no stored credentials")
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"No stored credentials for {hostname}", error="Device has no stored credentials")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"No stored credentials for {hostname}",
error="Device has no stored credentials",
)
return
try:
from app.services.crypto import decrypt_credentials_hybrid
key = settings.get_encryption_key_bytes()
creds_json = await decrypt_credentials_hybrid(
encrypted_credentials_transit, encrypted_credentials, tenant_id, key,
encrypted_credentials_transit,
encrypted_credentials,
tenant_id,
key,
)
creds = json.loads(creds_json)
ssh_username = creds.get("username", "")
@@ -199,7 +262,15 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Failed to decrypt credentials: {cred_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Failed to decrypt credentials for {hostname}", error=str(cred_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Failed to decrypt credentials for {hostname}",
error=str(cred_err),
)
return
try:
@@ -225,12 +296,27 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"NPK upload failed: {upload_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"NPK upload failed for {hostname}", error=str(upload_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"NPK upload failed for {hostname}",
error=str(upload_err),
)
return
# Step 7: Trigger reboot
await _update_job(job_id, status="rebooting")
await _publish_upgrade_progress(tenant_id, device_id, job_id, "rebooting", target_version, f"Rebooting {hostname} for firmware install")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"rebooting",
target_version,
f"Rebooting {hostname} for firmware install",
)
try:
async with asyncssh.connect(
ip_address,
@@ -245,7 +331,9 @@ async def _run_upgrade(job_id: str) -> None:
logger.info("Reboot command sent to %s", hostname)
except Exception as reboot_err:
# Device may drop connection during reboot — this is expected
logger.info("Device %s dropped connection after reboot command (expected): %s", hostname, reboot_err)
logger.info(
"Device %s dropped connection after reboot command (expected): %s", hostname, reboot_err
)
# Step 8: Wait for reconnect
logger.info("Waiting %ds before polling %s for reconnect", _INITIAL_WAIT, hostname)
@@ -267,36 +355,69 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Device did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes after reboot",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Device {hostname} did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes", error="Reconnect timeout")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Device {hostname} did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes",
error="Reconnect timeout",
)
return
# Step 9: Verify upgrade
await _update_job(job_id, status="verifying")
await _publish_upgrade_progress(tenant_id, device_id, job_id, "verifying", target_version, f"Verifying firmware version on {hostname}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"verifying",
target_version,
f"Verifying firmware version on {hostname}",
)
try:
actual_version = await _get_device_version(ip_address, ssh_username, ssh_password)
if actual_version and target_version in actual_version:
logger.info(
"Firmware upgrade verified for %s: %s",
hostname, actual_version,
hostname,
actual_version,
)
await _update_job(
job_id,
status="completed",
completed_at=datetime.now(timezone.utc),
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "completed", target_version, f"Firmware upgrade to {target_version} completed on {hostname}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"completed",
target_version,
f"Firmware upgrade to {target_version} completed on {hostname}",
)
else:
logger.error(
"Version mismatch for %s: expected %s, got %s",
hostname, target_version, actual_version,
hostname,
target_version,
actual_version,
)
await _update_job(
job_id,
status="failed",
error_message=f"Expected {target_version} but got {actual_version}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Version mismatch on {hostname}: expected {target_version}, got {actual_version}", error=f"Expected {target_version} but got {actual_version}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Version mismatch on {hostname}: expected {target_version}, got {actual_version}",
error=f"Expected {target_version} but got {actual_version}",
)
except Exception as verify_err:
logger.error("Post-upgrade verification failed for %s: %s", hostname, verify_err)
await _update_job(
@@ -304,7 +425,15 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Post-upgrade verification failed: {verify_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Post-upgrade verification failed for {hostname}", error=str(verify_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Post-upgrade verification failed for {hostname}",
error=str(verify_err),
)
async def start_mass_upgrade(rollout_group_id: str) -> dict:
@@ -457,7 +586,7 @@ async def resume_mass_upgrade(rollout_group_id: str) -> None:
"""Resume a paused mass rollout from where it left off."""
# Reset first paused job to pending, then restart sequential processing
async with AdminAsyncSessionLocal() as session:
result = await session.execute(
await session.execute(
text("""
UPDATE firmware_upgrade_jobs
SET status = 'pending'
@@ -519,7 +648,7 @@ async def _update_job(job_id: str, **kwargs) -> None:
await session.execute(
text(f"""
UPDATE firmware_upgrade_jobs
SET {', '.join(sets)}
SET {", ".join(sets)}
WHERE id = CAST(:job_id AS uuid)
"""),
params,

View File

@@ -26,7 +26,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.models.device import Device
from app.models.vpn import VpnConfig, VpnPeer
from app.services.crypto import decrypt_credentials, encrypt_credentials, encrypt_credentials_transit
from app.services.crypto import (
decrypt_credentials,
encrypt_credentials,
encrypt_credentials_transit,
)
logger = structlog.get_logger(__name__)
@@ -62,7 +66,9 @@ async def _get_or_create_global_server_key(db: AsyncSession) -> tuple[str, str]:
await db.execute(sa_text("SELECT pg_advisory_xact_lock(hashtext('vpn_server_keygen'))"))
result = await db.execute(
sa_text("SELECT key, value, encrypted_value FROM system_settings WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')")
sa_text(
"SELECT key, value, encrypted_value FROM system_settings WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')"
)
)
rows = {row[0]: row for row in result.fetchall()}
@@ -188,7 +194,9 @@ async def sync_wireguard_config() -> None:
# Query ALL enabled VPN configs (admin session bypasses RLS)
configs_result = await admin_db.execute(
select(VpnConfig).where(VpnConfig.is_enabled.is_(True)).order_by(VpnConfig.subnet_index)
select(VpnConfig)
.where(VpnConfig.is_enabled.is_(True))
.order_by(VpnConfig.subnet_index)
)
configs = configs_result.scalars().all()
@@ -228,7 +236,9 @@ async def sync_wireguard_config() -> None:
peer_ip = peer.assigned_ip.split("/")[0]
allowed_ips = [f"{peer_ip}/32"]
if peer.additional_allowed_ips:
extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()]
extra = [
s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()
]
allowed_ips.extend(extra)
lines.append("[Peer]")
lines.append(f"PublicKey = {peer.peer_public_key}")
@@ -253,12 +263,13 @@ async def sync_wireguard_config() -> None:
# Docker traffic (172.16.0.0/12) going to each tenant's subnet
# gets SNATted to that tenant's gateway IP (.1) so the router
# can route replies back through the tunnel.
nat_lines = ["#!/bin/sh",
"# Auto-generated per-tenant SNAT rules",
"# Remove old rules",
"iptables -t nat -F POSTROUTING 2>/dev/null",
"# Re-add Docker DNS rules",
]
nat_lines = [
"#!/bin/sh",
"# Auto-generated per-tenant SNAT rules",
"# Remove old rules",
"iptables -t nat -F POSTROUTING 2>/dev/null",
"# Re-add Docker DNS rules",
]
for config in configs:
gateway_ip = config.server_address.split("/")[0] # e.g. 10.10.3.1
subnet = config.subnet # e.g. 10.10.3.0/24
@@ -275,12 +286,15 @@ async def sync_wireguard_config() -> None:
reload_flag = wg_confs_dir / ".reload"
reload_flag.write_text("1")
logger.info("wireguard_config_synced", audit=True,
tenants=len(configs), peers=total_peers)
logger.info(
"wireguard_config_synced", audit=True, tenants=len(configs), peers=total_peers
)
finally:
# Release advisory lock explicitly (session-level lock, not xact-level)
await admin_db.execute(sa_text("SELECT pg_advisory_unlock(hashtext('wireguard_config'))"))
await admin_db.execute(
sa_text("SELECT pg_advisory_unlock(hashtext('wireguard_config'))")
)
# ── Live Status ──
@@ -371,15 +385,23 @@ async def setup_vpn(
db.add(config)
await db.flush()
logger.info("vpn_subnet_allocated", audit=True,
tenant_id=str(tenant_id), subnet_index=subnet_index, subnet=subnet)
logger.info(
"vpn_subnet_allocated",
audit=True,
tenant_id=str(tenant_id),
subnet_index=subnet_index,
subnet=subnet,
)
await _commit_and_sync(db)
return config
async def update_vpn_config(
db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None, is_enabled: Optional[bool] = None
db: AsyncSession,
tenant_id: uuid.UUID,
endpoint: Optional[str] = None,
is_enabled: Optional[bool] = None,
) -> VpnConfig:
"""Update VPN config settings."""
config = await get_vpn_config(db, tenant_id)
@@ -422,14 +444,21 @@ async def _next_available_ip(db: AsyncSession, tenant_id: uuid.UUID, config: Vpn
raise ValueError("No available IPs in VPN subnet")
async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID, additional_allowed_ips: Optional[str] = None) -> VpnPeer:
async def add_peer(
db: AsyncSession,
tenant_id: uuid.UUID,
device_id: uuid.UUID,
additional_allowed_ips: Optional[str] = None,
) -> VpnPeer:
"""Add a device as a VPN peer."""
config = await get_vpn_config(db, tenant_id)
if not config:
raise ValueError("VPN not configured — enable VPN first")
# Check device exists
device = await db.execute(select(Device).where(Device.id == device_id, Device.tenant_id == tenant_id))
device = await db.execute(
select(Device).where(Device.id == device_id, Device.tenant_id == tenant_id)
)
if not device.scalar_one_or_none():
raise ValueError("Device not found")
@@ -497,13 +526,12 @@ async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.
psk = decrypt_credentials(peer.preshared_key, key_bytes) if peer.preshared_key else None
endpoint = config.endpoint or "YOUR_SERVER_IP:51820"
peer_ip_no_cidr = peer.assigned_ip.split("/")[0]
routeros_commands = [
f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key}"',
f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" '
f'endpoint-address={endpoint.split(":")[0]} endpoint-port={endpoint.split(":")[-1]} '
f'allowed-address=10.10.0.0/16 persistent-keepalive=25'
f"endpoint-address={endpoint.split(':')[0]} endpoint-port={endpoint.split(':')[-1]} "
f"allowed-address=10.10.0.0/16 persistent-keepalive=25"
+ (f' preshared-key="{psk}"' if psk else ""),
f"/ip address add address={peer.assigned_ip} interface=wg-portal",
]
@@ -583,8 +611,8 @@ async def onboard_device(
routeros_commands = [
f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key_b64}"',
f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" '
f'endpoint-address={endpoint.split(":")[0]} endpoint-port={endpoint.split(":")[-1]} '
f'allowed-address=10.10.0.0/16 persistent-keepalive=25'
f"endpoint-address={endpoint.split(':')[0]} endpoint-port={endpoint.split(':')[-1]} "
f"allowed-address=10.10.0.0/16 persistent-keepalive=25"
f' preshared-key="{psk_decrypted}"',
f"/ip address add address={assigned_ip} interface=wg-portal",
]