fix(lint): resolve all ruff lint errors
Add ruff config to exclude alembic E402, SQLAlchemy F821, and pre-existing
E501 line-length issues. Auto-fix 69 unused imports and 2 f-strings without
placeholders. Manually fix 8 unused variables. Apply ruff format to 127 files.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
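Note on the changes below: E402 (import not at top of file), F821 (undefined name, common with SQLAlchemy declarative forward references), and E501 (line too long) are real ruff rule codes; such exclusions typically live in pyproject.toml under [tool.ruff.lint] (for example via per-file-ignores mapping "alembic/*" to E402), though the exact config layout in this repo is not shown here. Most hunks in the diff come from ruff format's "magic trailing comma" rule. A minimal, runnable sketch of that behavior, with illustrative names that are not taken from this repo:

    # Sketch (assumed example, not from this diff): when a call's single
    # argument is a bracketed literal ending in a trailing comma, ruff format
    # splits the literal out of the call, one element per line.
    row = {"id": 1, "name": "router-1"}
    results: list[dict] = []

    # Hugging style before formatting:
    results.append({
        "id": str(row["id"]),
        "name": row["name"],
    })

    # Equivalent code after `ruff format`:
    results.append(
        {
            "id": str(row["id"]),
            "name": row["name"],
        }
    )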
@@ -86,11 +86,7 @@ async def delete_user_account(
     # Null out encrypted_details (may contain encrypted PII)
     await db.execute(
-        text(
-            "UPDATE audit_logs "
-            "SET encrypted_details = NULL "
-            "WHERE user_id = :user_id"
-        ),
+        text("UPDATE audit_logs SET encrypted_details = NULL WHERE user_id = :user_id"),
         {"user_id": user_id},
     )

@@ -177,16 +173,18 @@ async def export_user_data(
     )
     api_keys = []
     for row in result.mappings().all():
-        api_keys.append({
-            "id": str(row["id"]),
-            "name": row["name"],
-            "key_prefix": row["key_prefix"],
-            "scopes": row["scopes"],
-            "created_at": row["created_at"].isoformat() if row["created_at"] else None,
-            "expires_at": row["expires_at"].isoformat() if row["expires_at"] else None,
-            "revoked_at": row["revoked_at"].isoformat() if row["revoked_at"] else None,
-            "last_used_at": row["last_used_at"].isoformat() if row["last_used_at"] else None,
-        })
+        api_keys.append(
+            {
+                "id": str(row["id"]),
+                "name": row["name"],
+                "key_prefix": row["key_prefix"],
+                "scopes": row["scopes"],
+                "created_at": row["created_at"].isoformat() if row["created_at"] else None,
+                "expires_at": row["expires_at"].isoformat() if row["expires_at"] else None,
+                "revoked_at": row["revoked_at"].isoformat() if row["revoked_at"] else None,
+                "last_used_at": row["last_used_at"].isoformat() if row["last_used_at"] else None,
+            }
+        )

     # ── Audit logs (limit 1000, most recent first) ───────────────────────
     result = await db.execute(

@@ -201,15 +199,17 @@ async def export_user_data(
     audit_logs = []
     for row in result.mappings().all():
         details = row["details"] if row["details"] else {}
-        audit_logs.append({
-            "id": str(row["id"]),
-            "action": row["action"],
-            "resource_type": row["resource_type"],
-            "resource_id": row["resource_id"],
-            "details": details,
-            "ip_address": row["ip_address"],
-            "created_at": row["created_at"].isoformat() if row["created_at"] else None,
-        })
+        audit_logs.append(
+            {
+                "id": str(row["id"]),
+                "action": row["action"],
+                "resource_type": row["resource_type"],
+                "resource_id": row["resource_id"],
+                "details": details,
+                "ip_address": row["ip_address"],
+                "created_at": row["created_at"].isoformat() if row["created_at"] else None,
+            }
+        )

     # ── Key access log (limit 1000, most recent first) ───────────────────
     result = await db.execute(

@@ -222,13 +222,15 @@ async def export_user_data(
     )
     key_access_entries = []
     for row in result.mappings().all():
-        key_access_entries.append({
-            "id": str(row["id"]),
-            "action": row["action"],
-            "resource_type": row["resource_type"],
-            "ip_address": row["ip_address"],
-            "created_at": row["created_at"].isoformat() if row["created_at"] else None,
-        })
+        key_access_entries.append(
+            {
+                "id": str(row["id"]),
+                "action": row["action"],
+                "resource_type": row["resource_type"],
+                "ip_address": row["ip_address"],
+                "created_at": row["created_at"].isoformat() if row["created_at"] else None,
+            }
+        )

     return {
         "export_date": datetime.now(UTC).isoformat(),

@@ -253,7 +253,9 @@ async def _get_device_groups(device_id: str) -> list[str]:
     """Get group IDs for a device."""
     async with AdminAsyncSessionLocal() as session:
         result = await session.execute(
-            text("SELECT group_id FROM device_group_memberships WHERE device_id = CAST(:device_id AS uuid)"),
+            text(
+                "SELECT group_id FROM device_group_memberships WHERE device_id = CAST(:device_id AS uuid)"
+            ),
             {"device_id": device_id},
         )
         return [str(row[0]) for row in result.fetchall()]
@@ -344,30 +346,36 @@ async def _create_alert_event(

     # Publish real-time event to NATS for SSE pipeline (fire-and-forget)
     if status in ("firing", "flapping"):
-        await publish_event(f"alert.fired.{tenant_id}", {
-            "event_type": "alert_fired",
-            "tenant_id": tenant_id,
-            "device_id": device_id,
-            "alert_event_id": alert_data["id"],
-            "severity": severity,
-            "metric": metric,
-            "current_value": value,
-            "threshold": threshold,
-            "message": message,
-            "is_flapping": is_flapping,
-            "fired_at": datetime.now(timezone.utc).isoformat(),
-        })
+        await publish_event(
+            f"alert.fired.{tenant_id}",
+            {
+                "event_type": "alert_fired",
+                "tenant_id": tenant_id,
+                "device_id": device_id,
+                "alert_event_id": alert_data["id"],
+                "severity": severity,
+                "metric": metric,
+                "current_value": value,
+                "threshold": threshold,
+                "message": message,
+                "is_flapping": is_flapping,
+                "fired_at": datetime.now(timezone.utc).isoformat(),
+            },
+        )
     elif status == "resolved":
-        await publish_event(f"alert.resolved.{tenant_id}", {
-            "event_type": "alert_resolved",
-            "tenant_id": tenant_id,
-            "device_id": device_id,
-            "alert_event_id": alert_data["id"],
-            "severity": severity,
-            "metric": metric,
-            "message": message,
-            "resolved_at": datetime.now(timezone.utc).isoformat(),
-        })
+        await publish_event(
+            f"alert.resolved.{tenant_id}",
+            {
+                "event_type": "alert_resolved",
+                "tenant_id": tenant_id,
+                "device_id": device_id,
+                "alert_event_id": alert_data["id"],
+                "severity": severity,
+                "metric": metric,
+                "message": message,
+                "resolved_at": datetime.now(timezone.utc).isoformat(),
+            },
+        )

     return alert_data

@@ -470,6 +478,7 @@ async def _dispatch_async(alert_event: dict, channels: list[dict], device_hostna
     """Fire-and-forget notification dispatch."""
     try:
         from app.services.notification_service import dispatch_notifications
+
         await dispatch_notifications(alert_event, channels, device_hostname)
     except Exception as e:
         logger.warning("Notification dispatch failed: %s", e)

@@ -500,7 +509,8 @@ async def evaluate(
     if await _is_device_in_maintenance(tenant_id, device_id):
         logger.debug(
             "Alert suppressed by maintenance window for device %s tenant %s",
-            device_id, tenant_id,
+            device_id,
+            tenant_id,
         )
         return

@@ -573,7 +583,8 @@ async def evaluate(
     if is_flapping:
         logger.info(
             "Alert %s for device %s is flapping — notifications suppressed",
-            rule["name"], device_id,
+            rule["name"],
+            device_id,
         )
     else:
         channels = await _get_channels_for_rule(rule["id"])

@@ -56,9 +56,7 @@ async def log_action(
     try:
         from app.services.crypto import encrypt_data_transit

-        encrypted_details = await encrypt_data_transit(
-            details_json, str(tenant_id)
-        )
+        encrypted_details = await encrypt_data_transit(details_json, str(tenant_id))
         # Encryption succeeded — clear plaintext details
         details_json = _json.dumps({})
     except Exception:
@@ -32,8 +32,12 @@ def _cron_to_trigger(cron_expr: str) -> Optional[CronTrigger]:
            return None
        minute, hour, day, month, day_of_week = parts
        return CronTrigger(
-            minute=minute, hour=hour, day=day, month=month,
-            day_of_week=day_of_week, timezone="UTC",
+            minute=minute,
+            hour=hour,
+            day=day,
+            month=month,
+            day_of_week=day_of_week,
+            timezone="UTC",
        )
    except Exception as e:
        logger.warning("Invalid cron expression '%s': %s", cron_expr, e)

@@ -52,10 +56,12 @@ def build_schedule_map(schedules: list) -> dict[str, list[dict]]:
        cron = s.cron_expression or DEFAULT_CRON
        if cron not in schedule_map:
            schedule_map[cron] = []
-        schedule_map[cron].append({
-            "device_id": str(s.device_id),
-            "tenant_id": str(s.tenant_id),
-        })
+        schedule_map[cron].append(
+            {
+                "device_id": str(s.device_id),
+                "tenant_id": str(s.tenant_id),
+            }
+        )
    return schedule_map

@@ -79,13 +85,15 @@ async def _run_scheduled_backups(devices: list[dict]) -> None:
        except Exception as e:
            logger.error(
                "Scheduled backup FAILED: device %s: %s",
-                dev_info["device_id"], e,
+                dev_info["device_id"],
+                e,
            )
            failure_count += 1

    logger.info(
        "Backup batch complete — %d succeeded, %d failed",
-        success_count, failure_count,
+        success_count,
+        failure_count,
    )

@@ -108,7 +116,7 @@ async def _load_effective_schedules() -> list:

    # Index: device-specific and tenant defaults
    device_schedules = {}  # device_id -> schedule
-    tenant_defaults = {}  # tenant_id -> schedule
+    tenant_defaults = {}  # tenant_id -> schedule

    for s in schedules:
        if s.device_id:

@@ -129,12 +137,14 @@ async def _load_effective_schedules() -> list:
            # No schedule configured — use system default
            sched = None

-        effective.append(SimpleNamespace(
-            device_id=dev_id,
-            tenant_id=tenant_id,
-            cron_expression=sched.cron_expression if sched else DEFAULT_CRON,
-            enabled=sched.enabled if sched else True,
-        ))
+        effective.append(
+            SimpleNamespace(
+                device_id=dev_id,
+                tenant_id=tenant_id,
+                cron_expression=sched.cron_expression if sched else DEFAULT_CRON,
+                enabled=sched.enabled if sched else True,
+            )
+        )

    return effective

@@ -203,6 +213,7 @@ class _SchedulerProxy:
    Usage: `from app.services.backup_scheduler import backup_scheduler`
    then `backup_scheduler.add_job(...)`.
    """
+
    def __getattr__(self, name):
        if _scheduler is None:
            raise RuntimeError("Backup scheduler not started yet")
@@ -255,7 +255,12 @@ async def run_backup(
    if prior_commits:
        try:
            prior_export_bytes = await loop.run_in_executor(
-                None, git_store.read_file, tenant_id, prior_commits[0]["sha"], device_id, "export.rsc"
+                None,
+                git_store.read_file,
+                tenant_id,
+                prior_commits[0]["sha"],
+                device_id,
+                "export.rsc",
            )
            prior_text = prior_export_bytes.decode("utf-8", errors="replace")
            lines_added, lines_removed = await loop.run_in_executor(

@@ -284,9 +289,7 @@ async def run_backup(
    try:
        from app.services.crypto import encrypt_data_transit

-        encrypted_export = await encrypt_data_transit(
-            export_text, tenant_id
-        )
+        encrypted_export = await encrypt_data_transit(export_text, tenant_id)
        encrypted_binary = await encrypt_data_transit(
            base64.b64encode(binary_backup).decode(), tenant_id
        )

@@ -302,8 +305,7 @@ async def run_backup(
    except Exception as enc_err:
        # Transit unavailable — fall back to plaintext (non-fatal)
        logger.warning(
-            "Transit encryption failed for %s backup of device %s, "
-            "storing plaintext: %s",
+            "Transit encryption failed for %s backup of device %s, storing plaintext: %s",
            trigger_type,
            device_id,
            enc_err,

@@ -313,9 +315,7 @@ async def run_backup(
    # -----------------------------------------------------------------------
    # Step 6: Commit to git (wrapped in run_in_executor — pygit2 is sync C bindings)
    # -----------------------------------------------------------------------
-    commit_message = (
-        f"{trigger_type}: {hostname} ({ip}) at {ts}"
-    )
+    commit_message = f"{trigger_type}: {hostname} ({ip}) at {ts}"

    commit_sha = await loop.run_in_executor(
        None,
@@ -48,6 +48,7 @@ _VALID_TRANSITIONS: dict[str, set[str]] = {
# CA Generation
# ---------------------------------------------------------------------------

+
async def generate_ca(
    db: AsyncSession,
    tenant_id: UUID,

@@ -84,10 +85,12 @@ async def generate_ca(
    now = datetime.datetime.now(datetime.timezone.utc)
    expiry = now + datetime.timedelta(days=365 * validity_years)

-    subject = issuer = x509.Name([
-        x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
-        x509.NameAttribute(NameOID.COMMON_NAME, common_name),
-    ])
+    subject = issuer = x509.Name(
+        [
+            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
+            x509.NameAttribute(NameOID.COMMON_NAME, common_name),
+        ]
+    )

    ca_cert = (
        x509.CertificateBuilder()

@@ -97,9 +100,7 @@ async def generate_ca(
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(expiry)
-        .add_extension(
-            x509.BasicConstraints(ca=True, path_length=0), critical=True
-        )
+        .add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True)
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,

@@ -166,6 +167,7 @@ async def generate_ca(
# Device Certificate Signing
# ---------------------------------------------------------------------------

+
async def sign_device_cert(
    db: AsyncSession,
    ca: CertificateAuthority,

@@ -196,9 +198,7 @@ async def sign_device_cert(
        str(ca.tenant_id),
        encryption_key,
    )
-    ca_key = serialization.load_pem_private_key(
-        ca_key_pem.encode("utf-8"), password=None
-    )
+    ca_key = serialization.load_pem_private_key(ca_key_pem.encode("utf-8"), password=None)

    # Load CA certificate for issuer info and AuthorityKeyIdentifier
    ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8"))
@@ -212,19 +212,19 @@ async def sign_device_cert(
    device_cert = (
        x509.CertificateBuilder()
        .subject_name(
-            x509.Name([
-                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
-                x509.NameAttribute(NameOID.COMMON_NAME, hostname),
-            ])
+            x509.Name(
+                [
+                    x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
+                    x509.NameAttribute(NameOID.COMMON_NAME, hostname),
+                ]
+            )
        )
        .issuer_name(ca_cert.subject)
        .public_key(device_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(expiry)
-        .add_extension(
-            x509.BasicConstraints(ca=False, path_length=None), critical=True
-        )
+        .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,

@@ -244,17 +244,17 @@ async def sign_device_cert(
            critical=False,
        )
        .add_extension(
-            x509.SubjectAlternativeName([
-                x509.IPAddress(ipaddress.ip_address(ip_address)),
-                x509.DNSName(hostname),
-            ]),
+            x509.SubjectAlternativeName(
+                [
+                    x509.IPAddress(ipaddress.ip_address(ip_address)),
+                    x509.DNSName(hostname),
+                ]
+            ),
            critical=False,
        )
        .add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
-                ca_cert.extensions.get_extension_for_class(
-                    x509.SubjectKeyIdentifier
-                ).value
+                ca_cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier).value
            ),
            critical=False,
        )

@@ -308,15 +308,14 @@ async def sign_device_cert(
# Queries
# ---------------------------------------------------------------------------

+
async def get_ca_for_tenant(
    db: AsyncSession,
    tenant_id: UUID,
) -> CertificateAuthority | None:
    """Return the tenant's CA, or None if not yet initialized."""
    result = await db.execute(
-        select(CertificateAuthority).where(
-            CertificateAuthority.tenant_id == tenant_id
-        )
+        select(CertificateAuthority).where(CertificateAuthority.tenant_id == tenant_id)
    )
    return result.scalar_one_or_none()

@@ -352,6 +351,7 @@ async def get_device_certs(
# Status Management
# ---------------------------------------------------------------------------

+
async def update_cert_status(
    db: AsyncSession,
    cert_id: UUID,

@@ -377,9 +377,7 @@ async def update_cert_status(
    Raises:
        ValueError: If the certificate is not found or the transition is invalid.
    """
-    result = await db.execute(
-        select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
-    )
+    result = await db.execute(select(DeviceCertificate).where(DeviceCertificate.id == cert_id))
    cert = result.scalar_one_or_none()
    if cert is None:
        raise ValueError(f"Device certificate {cert_id} not found")

@@ -413,6 +411,7 @@ async def update_cert_status(
# Cert Data for Deployment
# ---------------------------------------------------------------------------

+
async def get_cert_for_deploy(
    db: AsyncSession,
    cert_id: UUID,

@@ -434,18 +433,14 @@ async def get_cert_for_deploy(
    Raises:
        ValueError: If the certificate or its CA is not found.
    """
-    result = await db.execute(
-        select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
-    )
+    result = await db.execute(select(DeviceCertificate).where(DeviceCertificate.id == cert_id))
    cert = result.scalar_one_or_none()
    if cert is None:
        raise ValueError(f"Device certificate {cert_id} not found")

    # Fetch the CA for the ca_cert_pem
    ca_result = await db.execute(
-        select(CertificateAuthority).where(
-            CertificateAuthority.id == cert.ca_id
-        )
+        select(CertificateAuthority).where(CertificateAuthority.id == cert.ca_id)
    )
    ca = ca_result.scalar_one_or_none()
    if ca is None:
@@ -64,9 +64,7 @@ def parse_diff_changes(diff_text: str) -> list[dict]:
    current_section: str | None = None

    # Track per-component: adds, removes, raw_lines
-    components: dict[str, dict] = defaultdict(
-        lambda: {"adds": 0, "removes": 0, "raw_lines": []}
-    )
+    components: dict[str, dict] = defaultdict(lambda: {"adds": 0, "removes": 0, "raw_lines": []})

    for line in lines:
        # Skip unified diff headers

@@ -51,13 +51,15 @@ async def handle_config_changed(event: dict) -> None:
    if await _last_backup_within_dedup_window(device_id):
        logger.info(
            "Config change on device %s — skipping backup (within %dm dedup window)",
-            device_id, DEDUP_WINDOW_MINUTES,
+            device_id,
+            DEDUP_WINDOW_MINUTES,
        )
        return

    logger.info(
        "Config change detected on device %s (tenant %s): %s -> %s",
-        device_id, tenant_id,
+        device_id,
+        tenant_id,
        event.get("old_timestamp", "?"),
        event.get("new_timestamp", "?"),
    )

@@ -65,9 +65,7 @@ async def generate_and_store_diff(

    # 2. No previous snapshot = first snapshot for device
    if prev_row is None:
-        logger.debug(
-            "First snapshot for device %s, no diff to generate", device_id
-        )
+        logger.debug("First snapshot for device %s, no diff to generate", device_id)
        return

    old_snapshot_id = prev_row._mapping["id"]
@@ -103,9 +101,7 @@ async def generate_and_store_diff(
    # 6. Generate unified diff
    old_lines = old_plaintext.decode("utf-8").splitlines()
    new_lines = new_plaintext.decode("utf-8").splitlines()
-    diff_lines = list(
-        difflib.unified_diff(old_lines, new_lines, lineterm="", n=3)
-    )
+    diff_lines = list(difflib.unified_diff(old_lines, new_lines, lineterm="", n=3))

    # 7. If empty diff, skip INSERT
    diff_text = "\n".join(diff_lines)

@@ -118,12 +114,10 @@ async def generate_and_store_diff(

    # 8. Count lines added/removed (exclude +++ and --- headers)
    lines_added = sum(
-        1 for line in diff_lines
-        if line.startswith("+") and not line.startswith("++")
+        1 for line in diff_lines if line.startswith("+") and not line.startswith("++")
    )
    lines_removed = sum(
-        1 for line in diff_lines
-        if line.startswith("-") and not line.startswith("--")
+        1 for line in diff_lines if line.startswith("-") and not line.startswith("--")
    )

    # 9. INSERT into router_config_diffs (RETURNING id for change parser)

@@ -165,6 +159,7 @@ async def generate_and_store_diff(
    try:
        from app.services.audit_service import log_action
        import uuid as _uuid
+
        await log_action(
            db=None,
            tenant_id=_uuid.UUID(tenant_id),

@@ -207,12 +202,16 @@ async def generate_and_store_diff(
            await session.commit()
            logger.info(
                "Stored %d config changes for device %s diff %s",
-                len(changes), device_id, diff_id,
+                len(changes),
+                device_id,
+                diff_id,
            )
    except Exception as exc:
        logger.warning(
            "Change parser error for device %s diff %s (non-fatal): %s",
-            device_id, diff_id, exc,
+            device_id,
+            diff_id,
+            exc,
        )
        config_diff_errors_total.labels(error_type="change_parser").inc()

@@ -89,9 +89,11 @@ async def handle_config_snapshot(msg) -> None:

    collected_at_raw = data.get("collected_at")
    try:
-        collected_at = datetime.fromisoformat(
-            collected_at_raw.replace("Z", "+00:00")
-        ) if collected_at_raw else datetime.now(timezone.utc)
+        collected_at = (
+            datetime.fromisoformat(collected_at_raw.replace("Z", "+00:00"))
+            if collected_at_raw
+            else datetime.now(timezone.utc)
+        )
    except (ValueError, AttributeError):
        collected_at = datetime.now(timezone.utc)

@@ -131,9 +133,7 @@ async def handle_config_snapshot(msg) -> None:
    # --- Encrypt via OpenBao Transit ---
    openbao = OpenBaoTransitService()
    try:
-        encrypted_text = await openbao.encrypt(
-            tenant_id, config_text.encode("utf-8")
-        )
+        encrypted_text = await openbao.encrypt(tenant_id, config_text.encode("utf-8"))
    except Exception as exc:
        logger.warning(
            "Transit encrypt failed for device %s tenant %s: %s",

@@ -207,7 +207,8 @@ async def handle_config_snapshot(msg) -> None:
    except Exception as exc:
        logger.warning(
            "Diff generation failed for device %s (non-fatal): %s",
-            device_id, exc,
+            device_id,
+            exc,
        )

    logger.info(

@@ -233,9 +234,7 @@ async def _subscribe_with_retry(js) -> None:
                stream="DEVICE_EVENTS",
                manual_ack=True,
            )
-            logger.info(
-                "NATS: subscribed to config.snapshot.> (durable: config_snapshot_ingest)"
-            )
+            logger.info("NATS: subscribed to config.snapshot.> (durable: config_snapshot_ingest)")
            return
        except Exception as exc:
            if attempt < max_attempts:

@@ -29,8 +29,6 @@ from app.models.device import (
    DeviceTagAssignment,
)
from app.schemas.device import (
-    BulkAddRequest,
-    BulkAddResult,
    DeviceCreate,
    DeviceGroupCreate,
    DeviceGroupResponse,

@@ -43,9 +41,7 @@ from app.schemas.device import (
)
from app.config import settings
from app.services.crypto import (
-    decrypt_credentials,
    decrypt_credentials_hybrid,
-    encrypt_credentials,
    encrypt_credentials_transit,
)

@@ -58,9 +54,7 @@ from app.services.crypto import (
async def _tcp_reachable(ip: str, port: int, timeout: float = 3.0) -> bool:
    """Return True if a TCP connection to ip:port succeeds within timeout."""
    try:
-        _, writer = await asyncio.wait_for(
-            asyncio.open_connection(ip, port), timeout=timeout
-        )
+        _, writer = await asyncio.wait_for(asyncio.open_connection(ip, port), timeout=timeout)
        writer.close()
        try:
            await writer.wait_closed()

@@ -151,6 +145,7 @@ async def create_device(

    if not api_reachable and not ssl_reachable:
        from fastapi import HTTPException, status
+
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(

@@ -162,9 +157,7 @@ async def create_device(

    # Encrypt credentials via OpenBao Transit (new writes go through Transit)
    credentials_json = json.dumps({"username": data.username, "password": data.password})
-    transit_ciphertext = await encrypt_credentials_transit(
-        credentials_json, str(tenant_id)
-    )
+    transit_ciphertext = await encrypt_credentials_transit(credentials_json, str(tenant_id))

    device = Device(
        tenant_id=tenant_id,
@@ -180,9 +173,7 @@ async def create_device(
    await db.refresh(device)

    # Re-query with relationships loaded
-    result = await db.execute(
-        _device_with_relations().where(Device.id == device.id)
-    )
+    result = await db.execute(_device_with_relations().where(Device.id == device.id))
    device = result.scalar_one()
    return _build_device_response(device)

@@ -223,9 +214,7 @@ async def get_devices(
    if tag_id:
        base_q = base_q.where(
            Device.id.in_(
-                select(DeviceTagAssignment.device_id).where(
-                    DeviceTagAssignment.tag_id == tag_id
-                )
+                select(DeviceTagAssignment.device_id).where(DeviceTagAssignment.tag_id == tag_id)
            )
        )

@@ -274,9 +263,7 @@ async def get_device(
    """Get a single device by ID."""
    from fastapi import HTTPException, status

-    result = await db.execute(
-        _device_with_relations().where(Device.id == device_id)
-    )
+    result = await db.execute(_device_with_relations().where(Device.id == device_id))
    device = result.scalar_one_or_none()
    if not device:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")

@@ -295,9 +282,7 @@ async def update_device(
    """
    from fastapi import HTTPException, status

-    result = await db.execute(
-        _device_with_relations().where(Device.id == device_id)
-    )
+    result = await db.execute(_device_with_relations().where(Device.id == device_id))
    device = result.scalar_one_or_none()
    if not device:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
@@ -323,7 +308,9 @@ async def update_device(
    if data.password is not None:
        # Decrypt existing to get current username if no new username given
        current_username: str = data.username or ""
-        if not current_username and (device.encrypted_credentials_transit or device.encrypted_credentials):
+        if not current_username and (
+            device.encrypted_credentials_transit or device.encrypted_credentials
+        ):
            try:
                existing_json = await decrypt_credentials_hybrid(
                    device.encrypted_credentials_transit,

@@ -336,17 +323,21 @@ async def update_device(
            except Exception:
                current_username = ""

-        credentials_json = json.dumps({
-            "username": data.username if data.username is not None else current_username,
-            "password": data.password,
-        })
+        credentials_json = json.dumps(
+            {
+                "username": data.username if data.username is not None else current_username,
+                "password": data.password,
+            }
+        )
        # New writes go through Transit
        device.encrypted_credentials_transit = await encrypt_credentials_transit(
            credentials_json, str(device.tenant_id)
        )
        device.encrypted_credentials = None  # Clear legacy (Transit is canonical)
        credentials_changed = True
-    elif data.username is not None and (device.encrypted_credentials_transit or device.encrypted_credentials):
+    elif data.username is not None and (
+        device.encrypted_credentials_transit or device.encrypted_credentials
+    ):
        # Only username changed — update it without changing the password
        try:
            existing_json = await decrypt_credentials_hybrid(

@@ -373,6 +364,7 @@ async def update_device(
    if credentials_changed:
        try:
            from app.services.event_publisher import publish_event
+
            await publish_event(
                f"device.credential_changed.{device_id}",
                {"device_id": str(device_id), "tenant_id": str(tenant_id)},

@@ -380,9 +372,7 @@ async def update_device(
        except Exception:
            pass  # Never fail the update due to NATS issues

-    result2 = await db.execute(
-        _device_with_relations().where(Device.id == device_id)
-    )
+    result2 = await db.execute(_device_with_relations().where(Device.id == device_id))
    device = result2.scalar_one()
    return _build_device_response(device)

@@ -526,11 +516,7 @@ async def get_groups(
    tenant_id: uuid.UUID,
) -> list[DeviceGroupResponse]:
    """Return all device groups for the current tenant with device counts."""
-    result = await db.execute(
-        select(DeviceGroup).options(
-            selectinload(DeviceGroup.memberships)
-        )
-    )
+    result = await db.execute(select(DeviceGroup).options(selectinload(DeviceGroup.memberships)))
    groups = result.scalars().all()
    return [
        DeviceGroupResponse(

@@ -554,9 +540,9 @@ async def update_group(
    from fastapi import HTTPException, status

    result = await db.execute(
-        select(DeviceGroup).options(
-            selectinload(DeviceGroup.memberships)
-        ).where(DeviceGroup.id == group_id)
+        select(DeviceGroup)
+        .options(selectinload(DeviceGroup.memberships))
+        .where(DeviceGroup.id == group_id)
    )
    group = result.scalar_one_or_none()
    if not group:

@@ -571,9 +557,9 @@ async def update_group(
    await db.refresh(group)

    result2 = await db.execute(
-        select(DeviceGroup).options(
-            selectinload(DeviceGroup.memberships)
-        ).where(DeviceGroup.id == group_id)
+        select(DeviceGroup)
+        .options(selectinload(DeviceGroup.memberships))
+        .where(DeviceGroup.id == group_id)
    )
    group = result2.scalar_one()
    return DeviceGroupResponse(

@@ -48,7 +48,5 @@ async def generate_emergency_kit_template(
    # Run weasyprint in thread to avoid blocking the event loop
    from weasyprint import HTML

-    pdf_bytes = await asyncio.to_thread(
-        lambda: HTML(string=html_content).write_pdf()
-    )
+    pdf_bytes = await asyncio.to_thread(lambda: HTML(string=html_content).write_pdf())
    return pdf_bytes

@@ -13,7 +13,6 @@ Version discovery comes from two sources:
"""

import logging
-import os
from pathlib import Path

import httpx

@@ -51,13 +50,12 @@ async def check_latest_versions() -> list[dict]:
    async with httpx.AsyncClient(timeout=30.0) as client:
        for channel_file, channel, major in _VERSION_SOURCES:
            try:
-                resp = await client.get(
-                    f"https://download.mikrotik.com/routeros/{channel_file}"
-                )
+                resp = await client.get(f"https://download.mikrotik.com/routeros/{channel_file}")
                if resp.status_code != 200:
                    logger.warning(
                        "MikroTik version check returned %d for %s",
-                        resp.status_code, channel_file,
+                        resp.status_code,
+                        channel_file,
                    )
                    continue

@@ -72,12 +70,14 @@ async def check_latest_versions() -> list[dict]:
                    f"https://download.mikrotik.com/routeros/"
                    f"{version}/routeros-{version}-{arch}.npk"
                )
-                results.append({
-                    "architecture": arch,
-                    "channel": channel,
-                    "version": version,
-                    "npk_url": npk_url,
-                })
+                results.append(
+                    {
+                        "architecture": arch,
+                        "channel": channel,
+                        "version": version,
+                        "npk_url": npk_url,
+                    }
+                )

            except Exception as e:
                logger.warning("Failed to check %s: %s", channel_file, e)
@@ -242,15 +242,15 @@ async def get_firmware_overview(tenant_id: str) -> dict:
    groups = []
    for ver, devs in sorted(version_groups.items()):
        # A version is "latest" if it matches the latest for any arch/channel combo
-        is_latest = any(
-            v["version"] == ver for v in latest_versions.values()
-        )
-        groups.append({
-            "version": ver,
-            "count": len(devs),
-            "is_latest": is_latest,
-            "devices": devs,
-        })
+        is_latest = any(v["version"] == ver for v in latest_versions.values())
+        groups.append(
+            {
+                "version": ver,
+                "count": len(devs),
+                "is_latest": is_latest,
+                "devices": devs,
+            }
+        )

    return {
        "devices": device_list,

@@ -272,12 +272,14 @@ async def get_cached_firmware() -> list[dict]:
            continue
        for npk_file in sorted(version_dir.iterdir()):
            if npk_file.suffix == ".npk":
-                cached.append({
-                    "path": str(npk_file),
-                    "version": version_dir.name,
-                    "filename": npk_file.name,
-                    "size_bytes": npk_file.stat().st_size,
-                })
+                cached.append(
+                    {
+                        "path": str(npk_file),
+                        "version": version_dir.name,
+                        "filename": npk_file.name,
+                        "size_bytes": npk_file.stat().st_size,
+                    }
+                )

    return cached

@@ -41,7 +41,7 @@ async def on_device_firmware(msg) -> None:
    try:
        data = json.loads(msg.data)
        device_id = data.get("device_id")
-        tenant_id = data.get("tenant_id")
+        _tenant_id = data.get("tenant_id")
        architecture = data.get("architecture")
        installed_version = data.get("installed_version")
        latest_version = data.get("latest_version")
@@ -126,9 +126,7 @@ async def _subscribe_with_retry(js: JetStreamContext) -> None:
                durable="api-firmware-consumer",
                stream="DEVICE_EVENTS",
            )
-            logger.info(
-                "NATS: subscribed to device.firmware.> (durable: api-firmware-consumer)"
-            )
+            logger.info("NATS: subscribed to device.firmware.> (durable: api-firmware-consumer)")
            return
        except Exception as exc:
            if attempt < max_attempts:

@@ -238,11 +238,13 @@ def list_device_commits(
            # Device wasn't in parent but is in this commit — it's the first entry.
            pass

-        results.append({
-            "sha": str(commit.id),
-            "message": commit.message.strip(),
-            "timestamp": commit.commit_time,
-        })
+        results.append(
+            {
+                "sha": str(commit.id),
+                "message": commit.message.strip(),
+                "timestamp": commit.commit_time,
+            }
+        )

    return results

@@ -52,6 +52,7 @@ async def store_user_key_set(
    """
    # Remove any existing key set (e.g. from a failed prior upgrade attempt)
    from sqlalchemy import delete
+
    await db.execute(delete(UserKeySet).where(UserKeySet.user_id == user_id))

    key_set = UserKeySet(

@@ -71,9 +72,7 @@ async def store_user_key_set(
    return key_set


-async def get_user_key_set(
-    db: AsyncSession, user_id: UUID
-) -> UserKeySet | None:
+async def get_user_key_set(db: AsyncSession, user_id: UUID) -> UserKeySet | None:
    """Retrieve encrypted key bundle for login response.

    Args:
@@ -83,9 +82,7 @@ async def get_user_key_set(
    Returns:
        The UserKeySet if found, None otherwise.
    """
-    result = await db.execute(
-        select(UserKeySet).where(UserKeySet.user_id == user_id)
-    )
+    result = await db.execute(select(UserKeySet).where(UserKeySet.user_id == user_id))
    return result.scalar_one_or_none()


@@ -163,9 +160,7 @@ async def provision_tenant_key(db: AsyncSession, tenant_id: UUID) -> str:
    await openbao.create_tenant_key(str(tenant_id))

    # Update tenant record with key name
-    result = await db.execute(
-        select(Tenant).where(Tenant.id == tenant_id)
-    )
+    result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
    tenant = result.scalar_one_or_none()
    if tenant:
        tenant.openbao_key_name = key_name

@@ -210,13 +205,18 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
        select(Device).where(
            Device.tenant_id == tenant_id,
            Device.encrypted_credentials.isnot(None),
-            (Device.encrypted_credentials_transit.is_(None) | (Device.encrypted_credentials_transit == "")),
+            (
+                Device.encrypted_credentials_transit.is_(None)
+                | (Device.encrypted_credentials_transit == "")
+            ),
        )
    )
    for device in result.scalars().all():
        try:
            plaintext = decrypt_credentials(device.encrypted_credentials, legacy_key)
-            device.encrypted_credentials_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
+            device.encrypted_credentials_transit = await openbao.encrypt(
+                tid, plaintext.encode("utf-8")
+            )
            counts["devices"] += 1
        except Exception as e:
            logger.error("Failed to migrate device %s credentials: %s", device.id, e)

@@ -227,7 +227,10 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
        select(CertificateAuthority).where(
            CertificateAuthority.tenant_id == tenant_id,
            CertificateAuthority.encrypted_private_key.isnot(None),
-            (CertificateAuthority.encrypted_private_key_transit.is_(None) | (CertificateAuthority.encrypted_private_key_transit == "")),
+            (
+                CertificateAuthority.encrypted_private_key_transit.is_(None)
+                | (CertificateAuthority.encrypted_private_key_transit == "")
+            ),
        )
    )
    for ca in result.scalars().all():
@@ -244,13 +247,18 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
        select(DeviceCertificate).where(
            DeviceCertificate.tenant_id == tenant_id,
            DeviceCertificate.encrypted_private_key.isnot(None),
-            (DeviceCertificate.encrypted_private_key_transit.is_(None) | (DeviceCertificate.encrypted_private_key_transit == "")),
+            (
+                DeviceCertificate.encrypted_private_key_transit.is_(None)
+                | (DeviceCertificate.encrypted_private_key_transit == "")
+            ),
        )
    )
    for cert in result.scalars().all():
        try:
            plaintext = decrypt_credentials(cert.encrypted_private_key, legacy_key)
-            cert.encrypted_private_key_transit = await openbao.encrypt(tid, plaintext.encode("utf-8"))
+            cert.encrypted_private_key_transit = await openbao.encrypt(
+                tid, plaintext.encode("utf-8")
+            )
            counts["certs"] += 1
        except Exception as e:
            logger.error("Failed to migrate cert %s private key: %s", cert.id, e)

@@ -303,7 +311,14 @@ async def provision_existing_tenants(db: AsyncSession) -> dict:
    result = await db.execute(select(Tenant))
    tenants = result.scalars().all()

-    total = {"tenants": len(tenants), "devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0}
+    total = {
+        "tenants": len(tenants),
+        "devices": 0,
+        "cas": 0,
+        "certs": 0,
+        "channels": 0,
+        "errors": 0,
+    }

    for tenant in tenants:
        try:

@@ -199,9 +199,7 @@ async def on_device_metrics(msg) -> None:
        device_id = data.get("device_id")

        if not metric_type or not device_id:
-            logger.warning(
-                "device.metrics event missing 'type' or 'device_id' — skipping"
-            )
+            logger.warning("device.metrics event missing 'type' or 'device_id' — skipping")
            await msg.ack()
            return

@@ -222,6 +220,7 @@ async def on_device_metrics(msg) -> None:
        # Alert evaluation — non-fatal; metric write is the primary operation
        try:
            from app.services import alert_evaluator
+
            await alert_evaluator.evaluate(
                device_id=device_id,
                tenant_id=data.get("tenant_id", ""),

@@ -265,9 +264,7 @@ async def _subscribe_with_retry(js: JetStreamContext) -> None:
                durable="api-metrics-consumer",
                stream="DEVICE_EVENTS",
            )
-            logger.info(
-                "NATS: subscribed to device.metrics.> (durable: api-metrics-consumer)"
-            )
+            logger.info("NATS: subscribed to device.metrics.> (durable: api-metrics-consumer)")
            return
        except Exception as exc:
            if attempt < max_attempts:

@@ -69,7 +69,9 @@ async def on_device_status(msg) -> None:
        uptime_seconds = _parse_uptime(data.get("uptime", ""))

        if not device_id or not status:
-            logger.warning("Received device.status event with missing device_id or status — skipping")
+            logger.warning(
+                "Received device.status event with missing device_id or status — skipping"
+            )
            await msg.ack()
            return

@@ -115,12 +117,15 @@ async def on_device_status(msg) -> None:
        # Alert evaluation for offline/online status changes — non-fatal
        try:
            from app.services import alert_evaluator
+
            if status == "offline":
                await alert_evaluator.evaluate_offline(device_id, data.get("tenant_id", ""))
            elif status == "online":
                await alert_evaluator.evaluate_online(device_id, data.get("tenant_id", ""))
        except Exception as e:
-            logger.warning("Alert evaluation failed for device %s status=%s: %s", device_id, status, e)
+            logger.warning(
+                "Alert evaluation failed for device %s status=%s: %s", device_id, status, e
+            )

        logger.info(
            "Device status updated",

@@ -39,7 +39,9 @@ async def dispatch_notifications(
        except Exception as e:
            logger.warning(
                "Notification delivery failed for channel %s (%s): %s",
-                channel.get("name"), channel.get("channel_type"), e,
+                channel.get("name"),
+                channel.get("channel_type"),
+                e,
            )

@@ -100,14 +102,18 @@ async def _send_email(channel: dict, alert_event: dict, device_hostname: str) ->
    if transit_cipher and tenant_id:
        try:
            from app.services.kms_service import decrypt_transit
+
            smtp_password = await decrypt_transit(transit_cipher, tenant_id)
        except Exception:
-            logger.warning("Transit decryption failed for channel %s, trying legacy", channel.get("id"))
+            logger.warning(
+                "Transit decryption failed for channel %s, trying legacy", channel.get("id")
+            )

    if not smtp_password and legacy_cipher:
        try:
            from app.config import settings as app_settings
            from cryptography.fernet import Fernet
+
            raw = bytes(legacy_cipher) if isinstance(legacy_cipher, memoryview) else legacy_cipher
            f = Fernet(app_settings.CREDENTIAL_ENCRYPTION_KEY.encode())
            smtp_password = f.decrypt(raw).decode()

@@ -163,7 +169,8 @@ async def _send_webhook(
        response = await client.post(webhook_url, json=payload)
        logger.info(
            "Webhook notification sent to %s — status %d",
-            webhook_url, response.status_code,
+            webhook_url,
+            response.status_code,
        )


@@ -180,13 +187,18 @@ async def _send_slack(
    value = alert_event.get("value")
    threshold = alert_event.get("threshold")

-    color = {"CRITICAL": "#dc2626", "WARNING": "#f59e0b", "INFO": "#3b82f6"}.get(severity, "#6b7280")
+    color = {"CRITICAL": "#dc2626", "WARNING": "#f59e0b", "INFO": "#3b82f6"}.get(
+        severity, "#6b7280"
+    )
    status_label = "RESOLVED" if status == "resolved" else status

    blocks = [
        {
            "type": "header",
-            "text": {"type": "plain_text", "text": f"{'✅' if status == 'resolved' else '🚨'} [{severity}] {status_label.upper()}"},
+            "text": {
+                "type": "plain_text",
+                "text": f"{'✅' if status == 'resolved' else '🚨'} [{severity}] {status_label.upper()}",
+            },
        },
        {
            "type": "section",

@@ -205,7 +217,9 @@ async def _send_slack(
    blocks.append({"type": "section", "fields": fields})

    if message_text:
-        blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": f"*Message:*\n{message_text}"}})
+        blocks.append(
+            {"type": "section", "text": {"type": "mrkdwn", "text": f"*Message:*\n{message_text}"}}
+        )

    blocks.append({"type": "context", "elements": [{"type": "mrkdwn", "text": "TOD Alert System"}]})


@@ -7,6 +7,7 @@ material never leaves OpenBao -- the application only sees ciphertext.

Ciphertext format: "vault:v1:base64..." (compatible with Vault Transit format)
"""
+
import base64
import logging
from typing import Optional

@@ -47,13 +47,15 @@ async def record_push(
    """
    r = await _get_redis()
    key = f"push:recent:{device_id}"
-    value = json.dumps({
-        "device_id": device_id,
-        "tenant_id": tenant_id,
-        "push_type": push_type,
-        "push_operation_id": push_operation_id,
-        "pre_push_commit_sha": pre_push_commit_sha,
-    })
+    value = json.dumps(
+        {
+            "device_id": device_id,
+            "tenant_id": tenant_id,
+            "push_type": push_type,
+            "push_operation_id": push_operation_id,
+            "pre_push_commit_sha": pre_push_commit_sha,
+        }
+    )
    await r.set(key, value, ex=PUSH_TTL_SECONDS)
    logger.debug(
        "Recorded push for device %s (type=%s, TTL=%ds)",

@@ -164,16 +164,18 @@ async def _device_inventory(
        uptime_str = _format_uptime(row[6]) if row[6] else None
        last_seen_str = row[5].strftime("%Y-%m-%d %H:%M") if row[5] else None

-        devices.append({
-            "hostname": row[0],
-            "ip_address": row[1],
-            "model": row[2],
-            "routeros_version": row[3],
-            "status": status,
-            "last_seen": last_seen_str,
-            "uptime": uptime_str,
-            "groups": row[7] if row[7] else None,
-        })
+        devices.append(
+            {
+                "hostname": row[0],
+                "ip_address": row[1],
+                "model": row[2],
+                "routeros_version": row[3],
+                "status": status,
+                "last_seen": last_seen_str,
+                "uptime": uptime_str,
+                "groups": row[7] if row[7] else None,
+            }
+        )

    return {
        "report_title": "Device Inventory",

@@ -224,16 +226,18 @@ async def _metrics_summary(

    devices = []
    for row in rows:
-        devices.append({
-            "hostname": row[0],
-            "avg_cpu": float(row[1]) if row[1] is not None else None,
-            "peak_cpu": float(row[2]) if row[2] is not None else None,
-            "avg_mem": float(row[3]) if row[3] is not None else None,
-            "peak_mem": float(row[4]) if row[4] is not None else None,
-            "avg_disk": float(row[5]) if row[5] is not None else None,
-            "avg_temp": float(row[6]) if row[6] is not None else None,
-            "data_points": row[7],
-        })
+        devices.append(
+            {
+                "hostname": row[0],
+                "avg_cpu": float(row[1]) if row[1] is not None else None,
+                "peak_cpu": float(row[2]) if row[2] is not None else None,
+                "avg_mem": float(row[3]) if row[3] is not None else None,
+                "peak_mem": float(row[4]) if row[4] is not None else None,
+                "avg_disk": float(row[5]) if row[5] is not None else None,
+                "avg_temp": float(row[6]) if row[6] is not None else None,
+                "data_points": row[7],
+            }
+        )

    return {
        "report_title": "Metrics Summary",

@@ -287,14 +291,16 @@ async def _alert_history(
        if duration_secs is not None:
            resolved_durations.append(duration_secs)

-        alerts.append({
-            "fired_at": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
-            "hostname": row[5],
-            "severity": severity,
-            "status": row[3],
-            "message": row[4],
-            "duration": _format_duration(duration_secs) if duration_secs is not None else None,
-        })
+        alerts.append(
+            {
+                "fired_at": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
+                "hostname": row[5],
+                "severity": severity,
+                "status": row[3],
+                "message": row[4],
+                "duration": _format_duration(duration_secs) if duration_secs is not None else None,
+            }
+        )

    mttr_minutes = None
    mttr_display = None
@@ -374,13 +380,15 @@ async def _change_log_from_audit(

    entries = []
    for row in rows:
-        entries.append({
-            "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
-            "user": row[1],
-            "action": row[2],
-            "device": row[3],
-            "details": row[4] or row[5] or "",
-        })
+        entries.append(
+            {
+                "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
+                "user": row[1],
+                "action": row[2],
+                "device": row[3],
+                "details": row[4] or row[5] or "",
+            }
+        )

    return {
        "report_title": "Change Log",

@@ -436,21 +444,25 @@ async def _change_log_from_backups(
    # Merge and sort by timestamp descending
    entries = []
    for row in backup_rows:
-        entries.append({
-            "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
-            "user": row[1],
-            "action": row[2],
-            "device": row[3],
-            "details": row[4] or "",
-        })
+        entries.append(
+            {
+                "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
+                "user": row[1],
+                "action": row[2],
+                "device": row[3],
+                "details": row[4] or "",
+            }
+        )
    for row in alert_rows:
-        entries.append({
-            "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
-            "user": row[1],
-            "action": row[2],
-            "device": row[3],
-            "details": row[4] or "",
-        })
+        entries.append(
+            {
+                "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
+                "user": row[1],
+                "action": row[2],
+                "device": row[3],
+                "details": row[4] or "",
+            }
+        )

    # Sort by timestamp string descending
    entries.sort(key=lambda e: e["timestamp"], reverse=True)
@@ -486,54 +498,102 @@ def _render_csv(report_type: str, data: dict[str, Any]) -> bytes:
    writer = csv.writer(output)

    if report_type == "device_inventory":
-        writer.writerow([
-            "Hostname", "IP Address", "Model", "RouterOS Version",
-            "Status", "Last Seen", "Uptime", "Groups",
-        ])
+        writer.writerow(
+            [
+                "Hostname",
+                "IP Address",
+                "Model",
+                "RouterOS Version",
+                "Status",
+                "Last Seen",
+                "Uptime",
+                "Groups",
+            ]
+        )
        for d in data.get("devices", []):
-            writer.writerow([
-                d["hostname"], d["ip_address"], d["model"] or "",
-                d["routeros_version"] or "", d["status"],
-                d["last_seen"] or "", d["uptime"] or "",
-                d["groups"] or "",
-            ])
+            writer.writerow(
+                [
+                    d["hostname"],
+                    d["ip_address"],
+                    d["model"] or "",
+                    d["routeros_version"] or "",
+                    d["status"],
+                    d["last_seen"] or "",
+                    d["uptime"] or "",
+                    d["groups"] or "",
+                ]
+            )

    elif report_type == "metrics_summary":
-        writer.writerow([
-            "Hostname", "Avg CPU %", "Peak CPU %", "Avg Memory %",
-            "Peak Memory %", "Avg Disk %", "Avg Temp", "Data Points",
-        ])
+        writer.writerow(
+            [
+                "Hostname",
+                "Avg CPU %",
+                "Peak CPU %",
+                "Avg Memory %",
+                "Peak Memory %",
+                "Avg Disk %",
+                "Avg Temp",
+                "Data Points",
+            ]
+        )
        for d in data.get("devices", []):
-            writer.writerow([
-                d["hostname"],
-                f"{d['avg_cpu']:.1f}" if d["avg_cpu"] is not None else "",
-                f"{d['peak_cpu']:.1f}" if d["peak_cpu"] is not None else "",
-                f"{d['avg_mem']:.1f}" if d["avg_mem"] is not None else "",
-                f"{d['peak_mem']:.1f}" if d["peak_mem"] is not None else "",
-                f"{d['avg_disk']:.1f}" if d["avg_disk"] is not None else "",
-                f"{d['avg_temp']:.1f}" if d["avg_temp"] is not None else "",
-                d["data_points"],
-            ])
+            writer.writerow(
+                [
+                    d["hostname"],
+                    f"{d['avg_cpu']:.1f}" if d["avg_cpu"] is not None else "",
+                    f"{d['peak_cpu']:.1f}" if d["peak_cpu"] is not None else "",
+                    f"{d['avg_mem']:.1f}" if d["avg_mem"] is not None else "",
+                    f"{d['peak_mem']:.1f}" if d["peak_mem"] is not None else "",
+                    f"{d['avg_disk']:.1f}" if d["avg_disk"] is not None else "",
+                    f"{d['avg_temp']:.1f}" if d["avg_temp"] is not None else "",
+                    d["data_points"],
+                ]
+            )

    elif report_type == "alert_history":
-        writer.writerow([
-            "Timestamp", "Device", "Severity", "Message", "Status", "Duration",
-        ])
+        writer.writerow(
+            [
+                "Timestamp",
+                "Device",
+                "Severity",
+                "Message",
+                "Status",
+                "Duration",
+            ]
+        )
        for a in data.get("alerts", []):
-            writer.writerow([
-                a["fired_at"], a["hostname"] or "", a["severity"],
-                a["message"] or "", a["status"], a["duration"] or "",
-            ])
+            writer.writerow(
+                [
+                    a["fired_at"],
+                    a["hostname"] or "",
+                    a["severity"],
+                    a["message"] or "",
+                    a["status"],
+                    a["duration"] or "",
+                ]
+            )

    elif report_type == "change_log":
-        writer.writerow([
-            "Timestamp", "User", "Action", "Device", "Details",
-        ])
+        writer.writerow(
+            [
+                "Timestamp",
+                "User",
+                "Action",
+                "Device",
+                "Details",
+            ]
+        )
        for e in data.get("entries", []):
-            writer.writerow([
-                e["timestamp"], e["user"] or "", e["action"],
-                e["device"] or "", e["details"] or "",
-            ])
+            writer.writerow(
+                [
+                    e["timestamp"],
+                    e["user"] or "",
+                    e["action"],
+                    e["device"] or "",
+                    e["details"] or "",
+                ]
+            )

    return output.getvalue().encode("utf-8")

@@ -33,7 +33,6 @@ import asyncssh
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.database import set_tenant_context, AdminAsyncSessionLocal
from app.models.config_backup import ConfigPushOperation
from app.models.device import Device
from app.services import backup_service, git_store

@@ -113,12 +112,11 @@ async def restore_config(
        raise ValueError(f"Device {device_id!r} not found")

    if not device.encrypted_credentials_transit and not device.encrypted_credentials:
-        raise ValueError(
-            f"Device {device_id!r} has no stored credentials — cannot perform restore"
-        )
+        raise ValueError(f"Device {device_id!r} has no stored credentials — cannot perform restore")

    key = settings.get_encryption_key_bytes()
    from app.services.crypto import decrypt_credentials_hybrid

    creds_json = await decrypt_credentials_hybrid(
        device.encrypted_credentials_transit,
        device.encrypted_credentials,

@@ -133,7 +131,9 @@ async def restore_config(
    hostname = device.hostname or ip

    # Publish "started" progress event
-    await _publish_push_progress(tenant_id, device_id, "started", f"Config restore started for {hostname}")
+    await _publish_push_progress(
+        tenant_id, device_id, "started", f"Config restore started for {hostname}"
+    )

    # ------------------------------------------------------------------
    # Step 2: Read the target export.rsc from the backup commit

@@ -157,7 +157,9 @@ async def restore_config(
    # ------------------------------------------------------------------
    # Step 3: Mandatory pre-backup before push
    # ------------------------------------------------------------------
-    await _publish_push_progress(tenant_id, device_id, "backing_up", f"Creating pre-restore backup for {hostname}")
+    await _publish_push_progress(
+        tenant_id, device_id, "backing_up", f"Creating pre-restore backup for {hostname}"
+    )

    logger.info(
        "Starting pre-restore backup for device %s (%s) before pushing commit %s",

@@ -198,7 +200,9 @@ async def restore_config(
    # Step 5: SSH to device — install panic-revert, push config
    # ------------------------------------------------------------------
    push_op_id_str = str(push_op_id)
-    await _publish_push_progress(tenant_id, device_id, "pushing", f"Pushing config to {hostname}", push_op_id=push_op_id_str)
+    await _publish_push_progress(
+        tenant_id, device_id, "pushing", f"Pushing config to {hostname}", push_op_id=push_op_id_str
+    )

    logger.info(
        "Pushing config to device %s (%s): installing panic-revert scheduler and uploading config",
@@ -290,9 +294,12 @@ async def restore_config(
|
||||
# Update push operation to failed
|
||||
await _update_push_op_status(push_op_id, "failed", db_session)
|
||||
await _publish_push_progress(
|
||||
tenant_id, device_id, "failed",
|
||||
tenant_id,
|
||||
device_id,
|
||||
"failed",
|
||||
f"Config push failed for {hostname}: {push_err}",
|
||||
push_op_id=push_op_id_str, error=str(push_err),
|
||||
push_op_id=push_op_id_str,
|
||||
error=str(push_err),
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
@@ -312,7 +319,13 @@ async def restore_config(
|
||||
# ------------------------------------------------------------------
|
||||
# Step 6: Wait 60s for config to settle
|
||||
# ------------------------------------------------------------------
|
||||
await _publish_push_progress(tenant_id, device_id, "settling", f"Config pushed to {hostname} — waiting 60s for settle", push_op_id=push_op_id_str)
|
||||
await _publish_push_progress(
|
||||
tenant_id,
|
||||
device_id,
|
||||
"settling",
|
||||
f"Config pushed to {hostname} — waiting 60s for settle",
|
||||
push_op_id=push_op_id_str,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Config pushed to device %s — waiting 60s for config to settle",
|
||||
@@ -323,7 +336,13 @@ async def restore_config(
|
||||
# ------------------------------------------------------------------
|
||||
# Step 7: Reachability check
|
||||
# ------------------------------------------------------------------
|
||||
await _publish_push_progress(tenant_id, device_id, "verifying", f"Verifying device {hostname} reachability", push_op_id=push_op_id_str)
|
||||
await _publish_push_progress(
|
||||
tenant_id,
|
||||
device_id,
|
||||
"verifying",
|
||||
f"Verifying device {hostname} reachability",
|
||||
push_op_id=push_op_id_str,
|
||||
)
|
||||
|
||||
reachable = await _check_reachability(ip, ssh_username, ssh_password)
|
||||
|
||||
@@ -362,7 +381,13 @@ async def restore_config(
|
||||
|
||||
await _update_push_op_status(push_op_id, "committed", db_session)
|
||||
await clear_push(device_id)
|
||||
await _publish_push_progress(tenant_id, device_id, "committed", f"Config restored successfully on {hostname}", push_op_id=push_op_id_str)
|
||||
await _publish_push_progress(
|
||||
tenant_id,
|
||||
device_id,
|
||||
"committed",
|
||||
f"Config restored successfully on {hostname}",
|
||||
push_op_id=push_op_id_str,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "committed",
|
||||
@@ -384,7 +409,9 @@ async def restore_config(
|
||||
|
||||
await _update_push_op_status(push_op_id, "reverted", db_session)
|
||||
await _publish_push_progress(
|
||||
tenant_id, device_id, "reverted",
|
||||
tenant_id,
|
||||
device_id,
|
||||
"reverted",
|
||||
f"Device {hostname} unreachable — auto-reverting via panic-revert scheduler",
|
||||
push_op_id=push_op_id_str,
|
||||
)
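
Editorial aside: the panic-revert flow these hunks reference arms an on-device fallback before any risky push. A sketch of the idea (assumed helper and scheduler names, illustrative only; the real implementation lives in the service modules):

async def push_with_panic_revert(conn, config_text: str) -> None:
    # 1. Arm the safety net: a RouterOS scheduler that reloads the pre-push
    #    backup after a delay unless it is removed first.
    await conn.run(
        '/system scheduler add name="panic-revert" interval=5m '
        'on-event="/system backup load name=pre-push-backup"'
    )
    # 2. Upload and import the new configuration (assumed helper).
    await upload_and_import(conn, config_text)
    # 3. If the device is still reachable afterwards, disarm the revert;
    #    if it is not, the scheduler fires and the device restores itself.
    await conn.run('/system scheduler remove "panic-revert"')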
@@ -446,7 +473,7 @@ async def _update_push_op_status(
new_status: New status value ('committed' | 'reverted' | 'failed').
db_session: Database session (must already have tenant context set).
"""
from sqlalchemy import select, update
from sqlalchemy import update

await db_session.execute(
update(ConfigPushOperation)
@@ -526,9 +553,7 @@ async def recover_stale_push_operations(db_session: AsyncSession) -> None:
for op in stale_ops:
try:
# Load device
dev_result = await db_session.execute(
select(Device).where(Device.id == op.device_id)
)
dev_result = await db_session.execute(select(Device).where(Device.id == op.device_id))
device = dev_result.scalar_one_or_none()
if not device:
logger.error("Device %s not found for stale op %s", op.device_id, op.id)
@@ -547,9 +572,7 @@ async def recover_stale_push_operations(db_session: AsyncSession) -> None:
ssh_password = creds.get("password", "")

# Check reachability
reachable = await _check_reachability(
device.ip_address, ssh_username, ssh_password
)
reachable = await _check_reachability(device.ip_address, ssh_username, ssh_password)

if reachable:
# Try to remove scheduler (if still there, push was good)

@@ -53,7 +53,9 @@ async def cleanup_expired_snapshots() -> int:
deleted = result.rowcount

config_snapshots_cleaned_total.inc(deleted)
logger.info("retention cleanup complete", extra={"deleted_snapshots": deleted, "retention_days": days})
logger.info(
"retention cleanup complete", extra={"deleted_snapshots": deleted, "retention_days": days}
)
return deleted


@@ -88,9 +88,7 @@ async def browse_menu(device_id: str, path: str) -> dict[str, Any]:
return await execute_command(device_id, command)


async def add_entry(
device_id: str, path: str, properties: dict[str, str]
) -> dict[str, Any]:
async def add_entry(device_id: str, path: str, properties: dict[str, str]) -> dict[str, Any]:
"""Add a new entry to a RouterOS menu path.

Args:
@@ -124,9 +122,7 @@ async def update_entry(
return await execute_command(device_id, f"{path}/set", args)


async def remove_entry(
device_id: str, path: str, entry_id: str
) -> dict[str, Any]:
async def remove_entry(device_id: str, path: str, entry_id: str) -> dict[str, Any]:
"""Remove an entry from a RouterOS menu path.

Args:

@@ -7,20 +7,33 @@ from typing import Any
logger = logging.getLogger(__name__)

HIGH_RISK_PATHS = {
"/ip address", "/ip route", "/ip firewall filter", "/ip firewall nat",
"/interface", "/interface bridge", "/interface vlan",
"/system identity", "/ip service", "/ip ssh", "/user",
"/ip address",
"/ip route",
"/ip firewall filter",
"/ip firewall nat",
"/interface",
"/interface bridge",
"/interface vlan",
"/system identity",
"/ip service",
"/ip ssh",
"/user",
}

MANAGEMENT_PATTERNS = [
(re.compile(r"chain=input.*dst-port=(22|8291|8728|8729|443|80)", re.I),
"Modifies firewall rules for management ports (SSH/WinBox/API/Web)"),
(re.compile(r"chain=input.*action=drop", re.I),
"Adds drop rule on input chain — may block management access"),
(re.compile(r"/ip service", re.I),
"Modifies IP services — may disable API/SSH/WinBox access"),
(re.compile(r"/user.*set.*password", re.I),
"Changes user password — may affect automated access"),
(
re.compile(r"chain=input.*dst-port=(22|8291|8728|8729|443|80)", re.I),
"Modifies firewall rules for management ports (SSH/WinBox/API/Web)",
),
(
re.compile(r"chain=input.*action=drop", re.I),
"Adds drop rule on input chain — may block management access",
),
(re.compile(r"/ip service", re.I), "Modifies IP services — may disable API/SSH/WinBox access"),
(
re.compile(r"/user.*set.*password", re.I),
"Changes user password — may affect automated access",
),
]
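
Editorial aside: a sketch of how these (pattern, warning) pairs get applied downstream (the real check lives in compute_impact, shown later in this diff; the helper name here is assumed for illustration):

def scan_for_management_risks(commands: str) -> list[str]:
    # Collect the human-readable warning for every pattern that matches
    # the pushed command text.
    return [warning for pattern, warning in MANAGEMENT_PATTERNS if pattern.search(commands)]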


@@ -73,7 +86,14 @@ def parse_rsc(text: str) -> dict[str, Any]:
else:
# Check if second part starts with a known command verb
cmd_check = parts[1].strip().split(None, 1)
if cmd_check and cmd_check[0] in ("add", "set", "remove", "print", "enable", "disable"):
if cmd_check and cmd_check[0] in (
"add",
"set",
"remove",
"print",
"enable",
"disable",
):
current_path = parts[0]
line = parts[1].strip()
else:
@@ -184,12 +204,14 @@ def compute_impact(
risk = "none"
if has_changes:
risk = "high" if path in HIGH_RISK_PATHS else "low"
result_categories.append({
"path": path,
"adds": added,
"removes": removed,
"risk": risk,
})
result_categories.append(
{
"path": path,
"adds": added,
"removes": removed,
"risk": risk,
}
)

# Check target commands against management patterns
target_text = "\n".join(

@@ -16,9 +16,7 @@ from srptools.constants import PRIME_2048, PRIME_2048_GEN
_SRP_HASH = hashlib.sha256


async def create_srp_verifier(
salt_hex: str, verifier_hex: str
) -> tuple[bytes, bytes]:
async def create_srp_verifier(salt_hex: str, verifier_hex: str) -> tuple[bytes, bytes]:
"""Convert client-provided hex salt and verifier to bytes for storage.

The client computes v = g^x mod N using 2SKD-derived SRP-x.
@@ -31,9 +29,7 @@ async def create_srp_verifier(
return bytes.fromhex(salt_hex), bytes.fromhex(verifier_hex)


async def srp_init(
email: str, srp_verifier_hex: str
) -> tuple[str, str]:
async def srp_init(email: str, srp_verifier_hex: str) -> tuple[str, str]:
"""SRP Step 1: Generate server ephemeral (B) and private key (b).

Args:
@@ -47,14 +43,15 @@ async def srp_init(
Raises:
ValueError: If SRP initialization fails for any reason.
"""

def _init() -> tuple[str, str]:
context = SRPContext(
email, prime=PRIME_2048, generator=PRIME_2048_GEN,
email,
prime=PRIME_2048,
generator=PRIME_2048_GEN,
hash_func=_SRP_HASH,
)
server_session = SRPServerSession(
context, srp_verifier_hex
)
server_session = SRPServerSession(context, srp_verifier_hex)
return server_session.public, server_session.private

try:
@@ -85,26 +82,27 @@ async def srp_verify(
Tuple of (is_valid, server_proof_hex_or_none).
If valid, server_proof is M2 for the client to verify.
"""

def _verify() -> tuple[bool, str | None]:
import logging
log = logging.getLogger("srp_debug")
context = SRPContext(
email, prime=PRIME_2048, generator=PRIME_2048_GEN,
email,
prime=PRIME_2048,
generator=PRIME_2048_GEN,
hash_func=_SRP_HASH,
)
server_session = SRPServerSession(
context, srp_verifier_hex, private=server_private
)
server_session = SRPServerSession(context, srp_verifier_hex, private=server_private)
_key, _key_proof, _key_proof_hash = server_session.process(client_public, srp_salt_hex)
# srptools verify_proof has a Python 3 bug: hexlify() returns bytes
# but client_proof is str, so bytes == str is always False.
# Compare manually with consistent types.
server_m1 = _key_proof if isinstance(_key_proof, str) else _key_proof.decode('ascii')
server_m1 = _key_proof if isinstance(_key_proof, str) else _key_proof.decode("ascii")
is_valid = client_proof.lower() == server_m1.lower()
if not is_valid:
return False, None
# Return M2 (key_proof_hash), also fixing the bytes/str issue
m2 = _key_proof_hash if isinstance(_key_proof_hash, str) else _key_proof_hash.decode('ascii')
m2 = (
_key_proof_hash if isinstance(_key_proof_hash, str) else _key_proof_hash.decode("ascii")
)
return True, m2

try:
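
Editorial aside: the bytes/str pitfall the comment above describes, reduced to a standalone sketch (illustrative values):

proof_from_hexlify = b"3fa2"   # what hexlify() hands back on Python 3
proof_from_client = "3fa2"     # what the client submits as a str
print(proof_from_hexlify == proof_from_client)                   # False (the silent bug)
print(proof_from_hexlify.decode("ascii") == proof_from_client)   # True (the manual fix)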

@@ -137,7 +137,9 @@ class SSEConnectionManager:
if last_event_id is not None:
try:
start_seq = int(last_event_id) + 1
consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.BY_START_SEQUENCE, opt_start_seq=start_seq)
consumer_cfg = ConsumerConfig(
deliver_policy=DeliverPolicy.BY_START_SEQUENCE, opt_start_seq=start_seq
)
except (ValueError, TypeError):
consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.NEW)
else:
@@ -173,18 +175,32 @@ class SSEConnectionManager:
except Exception as exc:
if "stream not found" in str(exc):
try:
await js.add_stream(StreamConfig(
name="ALERT_EVENTS",
subjects=_ALERT_EVENT_SUBJECTS,
max_age=3600,
))
sub = await js.subscribe(subject, stream="ALERT_EVENTS", config=consumer_cfg)
await js.add_stream(
StreamConfig(
name="ALERT_EVENTS",
subjects=_ALERT_EVENT_SUBJECTS,
max_age=3600,
)
)
sub = await js.subscribe(
subject, stream="ALERT_EVENTS", config=consumer_cfg
)
self._subscriptions.append(sub)
logger.info("sse.stream_created_lazily", stream="ALERT_EVENTS")
except Exception as retry_exc:
logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(retry_exc))
logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="ALERT_EVENTS",
error=str(retry_exc),
)
else:
logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(exc))
logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="ALERT_EVENTS",
error=str(exc),
)

# Subscribe to operation events (OPERATION_EVENTS stream)
for subject in _OPERATION_EVENT_SUBJECTS:
@@ -198,18 +214,32 @@ class SSEConnectionManager:
except Exception as exc:
if "stream not found" in str(exc):
try:
await js.add_stream(StreamConfig(
name="OPERATION_EVENTS",
subjects=_OPERATION_EVENT_SUBJECTS,
max_age=3600,
))
sub = await js.subscribe(subject, stream="OPERATION_EVENTS", config=consumer_cfg)
await js.add_stream(
StreamConfig(
name="OPERATION_EVENTS",
subjects=_OPERATION_EVENT_SUBJECTS,
max_age=3600,
)
)
sub = await js.subscribe(
subject, stream="OPERATION_EVENTS", config=consumer_cfg
)
self._subscriptions.append(sub)
logger.info("sse.stream_created_lazily", stream="OPERATION_EVENTS")
except Exception as retry_exc:
logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(retry_exc))
logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="OPERATION_EVENTS",
error=str(retry_exc),
)
else:
logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(exc))
logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="OPERATION_EVENTS",
error=str(exc),
)

# Start background task to pull messages from subscriptions into the queue
asyncio.create_task(self._pump_messages())
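
Editorial aside: both branches above repeat the same create-on-miss dance. Reduced to a sketch (imports assumed from nats.js.api; helper name illustrative):

async def subscribe_with_lazy_stream(js, subject, stream, subjects, consumer_cfg):
    # Try the normal subscribe first; only the first subscriber ever pays
    # the cost of creating the stream.
    try:
        return await js.subscribe(subject, stream=stream, config=consumer_cfg)
    except Exception as exc:
        if "stream not found" not in str(exc):
            raise
        await js.add_stream(StreamConfig(name=stream, subjects=subjects, max_age=3600))
        return await js.subscribe(subject, stream=stream, config=consumer_cfg)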

@@ -12,22 +12,18 @@ separate scheduler and file names to avoid conflicts with restore operations.
"""

import asyncio
import io
import ipaddress
import json
import logging
import uuid
from datetime import datetime, timezone

import asyncssh
from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy import select, text
from sqlalchemy import text

from app.config import settings
from app.database import AdminAsyncSessionLocal
from app.models.config_template import TemplatePushJob
from app.models.device import Device

logger = logging.getLogger(__name__)

@@ -145,7 +141,9 @@ async def push_to_devices(rollout_id: str) -> dict:
except Exception as exc:
logger.error(
"Uncaught exception in template push rollout %s: %s",
rollout_id, exc, exc_info=True,
rollout_id,
exc,
exc_info=True,
)
return {"completed": 0, "failed": 1, "pending": 0}

@@ -181,7 +179,9 @@ async def _run_push_rollout(rollout_id: str) -> dict:

logger.info(
"Template push rollout %s: pushing to device %s (job %s)",
rollout_id, hostname, job_id,
rollout_id,
hostname,
job_id,
)

await push_single_device(job_id)
@@ -200,7 +200,9 @@ async def _run_push_rollout(rollout_id: str) -> dict:
failed = True
logger.error(
"Template push rollout %s paused: device %s %s",
rollout_id, hostname, row[0],
rollout_id,
hostname,
row[0],
)
break

@@ -232,7 +234,9 @@ async def push_single_device(job_id: str) -> None:
except Exception as exc:
logger.error(
"Uncaught exception in template push job %s: %s",
job_id, exc, exc_info=True,
job_id,
exc,
exc_info=True,
)
await _update_job(job_id, status="failed", error_message=f"Unexpected error: {exc}")

@@ -260,8 +264,13 @@ async def _run_single_push(job_id: str) -> None:
return

(
_, device_id, tenant_id, rendered_content,
ip_address, hostname, encrypted_credentials,
_,
device_id,
tenant_id,
rendered_content,
ip_address,
hostname,
encrypted_credentials,
encrypted_credentials_transit,
) = row

@@ -279,16 +288,21 @@ async def _run_single_push(job_id: str) -> None:

try:
from app.services.crypto import decrypt_credentials_hybrid

key = settings.get_encryption_key_bytes()
creds_json = await decrypt_credentials_hybrid(
encrypted_credentials_transit, encrypted_credentials, tenant_id, key,
encrypted_credentials_transit,
encrypted_credentials,
tenant_id,
key,
)
creds = json.loads(creds_json)
ssh_username = creds.get("username", "")
ssh_password = creds.get("password", "")
except Exception as cred_err:
await _update_job(
job_id, status="failed",
job_id,
status="failed",
error_message=f"Failed to decrypt credentials: {cred_err}",
)
return
@@ -297,6 +311,7 @@ async def _run_single_push(job_id: str) -> None:
logger.info("Running mandatory pre-push backup for device %s (%s)", hostname, ip_address)
try:
from app.services import backup_service

backup_result = await backup_service.run_backup(
device_id=device_id,
tenant_id=tenant_id,
@@ -308,7 +323,8 @@ async def _run_single_push(job_id: str) -> None:
except Exception as backup_err:
logger.error("Pre-push backup failed for %s: %s", hostname, backup_err)
await _update_job(
job_id, status="failed",
job_id,
status="failed",
error_message=f"Pre-push backup failed: {backup_err}",
)
return
@@ -316,7 +332,8 @@ async def _run_single_push(job_id: str) -> None:
# Step 5: SSH to device - install panic-revert, push config
logger.info(
"Pushing template to device %s (%s): installing panic-revert and uploading config",
hostname, ip_address,
hostname,
ip_address,
)

try:
@@ -359,7 +376,8 @@ async def _run_single_push(job_id: str) -> None:
)
logger.info(
"Template import result for device %s: exit_status=%s stdout=%r",
hostname, import_result.exit_status,
hostname,
import_result.exit_status,
(import_result.stdout or "")[:200],
)

@@ -369,16 +387,21 @@ async def _run_single_push(job_id: str) -> None:
except Exception as cleanup_err:
logger.warning(
"Failed to clean up %s from device %s: %s",
_TEMPLATE_RSC, ip_address, cleanup_err,
_TEMPLATE_RSC,
ip_address,
cleanup_err,
)

except Exception as push_err:
logger.error(
"SSH push phase failed for device %s (%s): %s",
hostname, ip_address, push_err,
hostname,
ip_address,
push_err,
)
await _update_job(
job_id, status="failed",
job_id,
status="failed",
error_message=f"Config push failed during SSH phase: {push_err}",
)
return
@@ -395,9 +418,12 @@ async def _run_single_push(job_id: str) -> None:
logger.info("Device %s (%s) is reachable after push - committing", hostname, ip_address)
try:
async with asyncssh.connect(
ip_address, port=22,
username=ssh_username, password=ssh_password,
known_hosts=None, connect_timeout=30,
ip_address,
port=22,
username=ssh_username,
password=ssh_password,
known_hosts=None,
connect_timeout=30,
) as conn:
await conn.run(
f'/system scheduler remove "{_PANIC_REVERT_SCHEDULER}"',
@@ -410,11 +436,13 @@ async def _run_single_push(job_id: str) -> None:
except Exception as cleanup_err:
logger.warning(
"Failed to clean up panic-revert scheduler/backup on device %s: %s",
hostname, cleanup_err,
hostname,
cleanup_err,
)

await _update_job(
job_id, status="committed",
job_id,
status="committed",
completed_at=datetime.now(timezone.utc),
)
else:
@@ -422,10 +450,13 @@ async def _run_single_push(job_id: str) -> None:
logger.warning(
"Device %s (%s) is unreachable after push - panic-revert scheduler "
"will auto-revert to %s.backup",
hostname, ip_address, _PRE_PUSH_BACKUP,
hostname,
ip_address,
_PRE_PUSH_BACKUP,
)
await _update_job(
job_id, status="reverted",
job_id,
status="reverted",
error_message="Device unreachable after push; auto-reverted via panic-revert scheduler",
completed_at=datetime.now(timezone.utc),
)
@@ -440,9 +471,12 @@ async def _check_reachability(ip: str, username: str, password: str) -> bool:
"""Check if a RouterOS device is reachable via SSH."""
try:
async with asyncssh.connect(
ip, port=22,
username=username, password=password,
known_hosts=None, connect_timeout=30,
ip,
port=22,
username=username,
password=password,
known_hosts=None,
connect_timeout=30,
) as conn:
result = await conn.run("/system identity print", check=True)
logger.debug("Reachability check OK for %s: %r", ip, result.stdout[:50])
@@ -459,7 +493,12 @@ async def _update_job(job_id: str, **kwargs) -> None:

for key, value in kwargs.items():
param_name = f"v_{key}"
if value is None and key in ("error_message", "started_at", "completed_at", "pre_push_backup_sha"):
if value is None and key in (
"error_message",
"started_at",
"completed_at",
"pre_push_backup_sha",
):
sets.append(f"{key} = NULL")
else:
sets.append(f"{key} = :{param_name}")
@@ -472,7 +511,7 @@ async def _update_job(job_id: str, **kwargs) -> None:
await session.execute(
text(f"""
UPDATE template_push_jobs
SET {', '.join(sets)}
SET {", ".join(sets)}
WHERE id = CAST(:job_id AS uuid)
"""),
params,

@@ -13,7 +13,6 @@ jobs may span multiple tenants and run in background asyncio tasks.
"""

import asyncio
import io
import json
import logging
from datetime import datetime, timezone
@@ -99,10 +98,19 @@ async def _run_upgrade(job_id: str) -> None:
return

(
_, device_id, tenant_id, target_version,
architecture, channel, status, confirmed_major,
ip_address, hostname, encrypted_credentials,
current_version, encrypted_credentials_transit,
_,
device_id,
tenant_id,
target_version,
architecture,
channel,
status,
confirmed_major,
ip_address,
hostname,
encrypted_credentials,
current_version,
encrypted_credentials_transit,
) = row

device_id = str(device_id)
@@ -116,12 +124,22 @@ async def _run_upgrade(job_id: str) -> None:

logger.info(
"Starting firmware upgrade for %s (%s): %s -> %s",
hostname, ip_address, current_version, target_version,
hostname,
ip_address,
current_version,
target_version,
)

# Step 2: Update status to downloading
await _update_job(job_id, status="downloading", started_at=datetime.now(timezone.utc))
await _publish_upgrade_progress(tenant_id, device_id, job_id, "downloading", target_version, f"Downloading firmware {target_version} for {hostname}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"downloading",
target_version,
f"Downloading firmware {target_version} for {hostname}",
)

# Step 3: Check major version upgrade confirmation
if current_version and target_version:
@@ -133,13 +151,22 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message="Major version upgrade requires explicit confirmation",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Major version upgrade requires explicit confirmation for {hostname}", error="Major version upgrade requires explicit confirmation")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Major version upgrade requires explicit confirmation for {hostname}",
error="Major version upgrade requires explicit confirmation",
)
return

# Step 4: Mandatory config backup
logger.info("Running mandatory pre-upgrade backup for %s", hostname)
try:
from app.services import backup_service

backup_result = await backup_service.run_backup(
device_id=device_id,
tenant_id=tenant_id,
@@ -155,13 +182,22 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Pre-upgrade backup failed: {backup_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Pre-upgrade backup failed for {hostname}", error=str(backup_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Pre-upgrade backup failed for {hostname}",
error=str(backup_err),
)
return

# Step 5: Download NPK
logger.info("Downloading firmware %s for %s/%s", target_version, architecture, channel)
try:
from app.services.firmware_service import download_firmware

npk_path = await download_firmware(architecture, channel, target_version)
logger.info("Firmware cached at %s", npk_path)
except Exception as dl_err:
@@ -171,24 +207,51 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Firmware download failed: {dl_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Firmware download failed for {hostname}", error=str(dl_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Firmware download failed for {hostname}",
error=str(dl_err),
)
return

# Step 6: Upload NPK to device via SFTP
await _update_job(job_id, status="uploading")
await _publish_upgrade_progress(tenant_id, device_id, job_id, "uploading", target_version, f"Uploading firmware to {hostname}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"uploading",
target_version,
f"Uploading firmware to {hostname}",
)

# Decrypt device credentials (dual-read: Transit preferred, legacy fallback)
if not encrypted_credentials_transit and not encrypted_credentials:
await _update_job(job_id, status="failed", error_message="Device has no stored credentials")
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"No stored credentials for {hostname}", error="Device has no stored credentials")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"No stored credentials for {hostname}",
error="Device has no stored credentials",
)
return

try:
from app.services.crypto import decrypt_credentials_hybrid

key = settings.get_encryption_key_bytes()
creds_json = await decrypt_credentials_hybrid(
encrypted_credentials_transit, encrypted_credentials, tenant_id, key,
encrypted_credentials_transit,
encrypted_credentials,
tenant_id,
key,
)
creds = json.loads(creds_json)
ssh_username = creds.get("username", "")
@@ -199,7 +262,15 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Failed to decrypt credentials: {cred_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Failed to decrypt credentials for {hostname}", error=str(cred_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Failed to decrypt credentials for {hostname}",
error=str(cred_err),
)
return

try:
@@ -225,12 +296,27 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"NPK upload failed: {upload_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"NPK upload failed for {hostname}", error=str(upload_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"NPK upload failed for {hostname}",
error=str(upload_err),
)
return

# Step 7: Trigger reboot
await _update_job(job_id, status="rebooting")
await _publish_upgrade_progress(tenant_id, device_id, job_id, "rebooting", target_version, f"Rebooting {hostname} for firmware install")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"rebooting",
target_version,
f"Rebooting {hostname} for firmware install",
)
try:
async with asyncssh.connect(
ip_address,
@@ -245,7 +331,9 @@ async def _run_upgrade(job_id: str) -> None:
logger.info("Reboot command sent to %s", hostname)
except Exception as reboot_err:
# Device may drop connection during reboot — this is expected
logger.info("Device %s dropped connection after reboot command (expected): %s", hostname, reboot_err)
logger.info(
"Device %s dropped connection after reboot command (expected): %s", hostname, reboot_err
)

# Step 8: Wait for reconnect
logger.info("Waiting %ds before polling %s for reconnect", _INITIAL_WAIT, hostname)
@@ -267,36 +355,69 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Device did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes after reboot",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Device {hostname} did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes", error="Reconnect timeout")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Device {hostname} did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes",
error="Reconnect timeout",
)
return

# Step 9: Verify upgrade
await _update_job(job_id, status="verifying")
await _publish_upgrade_progress(tenant_id, device_id, job_id, "verifying", target_version, f"Verifying firmware version on {hostname}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"verifying",
target_version,
f"Verifying firmware version on {hostname}",
)
try:
actual_version = await _get_device_version(ip_address, ssh_username, ssh_password)
if actual_version and target_version in actual_version:
logger.info(
"Firmware upgrade verified for %s: %s",
hostname, actual_version,
hostname,
actual_version,
)
await _update_job(
job_id,
status="completed",
completed_at=datetime.now(timezone.utc),
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "completed", target_version, f"Firmware upgrade to {target_version} completed on {hostname}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"completed",
target_version,
f"Firmware upgrade to {target_version} completed on {hostname}",
)
else:
logger.error(
"Version mismatch for %s: expected %s, got %s",
hostname, target_version, actual_version,
hostname,
target_version,
actual_version,
)
await _update_job(
job_id,
status="failed",
error_message=f"Expected {target_version} but got {actual_version}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Version mismatch on {hostname}: expected {target_version}, got {actual_version}", error=f"Expected {target_version} but got {actual_version}")
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Version mismatch on {hostname}: expected {target_version}, got {actual_version}",
error=f"Expected {target_version} but got {actual_version}",
)
except Exception as verify_err:
logger.error("Post-upgrade verification failed for %s: %s", hostname, verify_err)
await _update_job(
@@ -304,7 +425,15 @@ async def _run_upgrade(job_id: str) -> None:
status="failed",
error_message=f"Post-upgrade verification failed: {verify_err}",
)
await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Post-upgrade verification failed for {hostname}", error=str(verify_err))
await _publish_upgrade_progress(
tenant_id,
device_id,
job_id,
"failed",
target_version,
f"Post-upgrade verification failed for {hostname}",
error=str(verify_err),
)


async def start_mass_upgrade(rollout_group_id: str) -> dict:
@@ -457,7 +586,7 @@ async def resume_mass_upgrade(rollout_group_id: str) -> None:
"""Resume a paused mass rollout from where it left off."""
# Reset first paused job to pending, then restart sequential processing
async with AdminAsyncSessionLocal() as session:
result = await session.execute(
await session.execute(
text("""
UPDATE firmware_upgrade_jobs
SET status = 'pending'
@@ -519,7 +648,7 @@ async def _update_job(job_id: str, **kwargs) -> None:
await session.execute(
text(f"""
UPDATE firmware_upgrade_jobs
SET {', '.join(sets)}
SET {", ".join(sets)}
WHERE id = CAST(:job_id AS uuid)
"""),
params,

@@ -26,7 +26,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.models.device import Device
from app.models.vpn import VpnConfig, VpnPeer
from app.services.crypto import decrypt_credentials, encrypt_credentials, encrypt_credentials_transit
from app.services.crypto import (
decrypt_credentials,
encrypt_credentials,
encrypt_credentials_transit,
)

logger = structlog.get_logger(__name__)

@@ -62,7 +66,9 @@ async def _get_or_create_global_server_key(db: AsyncSession) -> tuple[str, str]:
await db.execute(sa_text("SELECT pg_advisory_xact_lock(hashtext('vpn_server_keygen'))"))

result = await db.execute(
sa_text("SELECT key, value, encrypted_value FROM system_settings WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')")
sa_text(
"SELECT key, value, encrypted_value FROM system_settings WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')"
)
)
rows = {row[0]: row for row in result.fetchall()}

@@ -188,7 +194,9 @@ async def sync_wireguard_config() -> None:

# Query ALL enabled VPN configs (admin session bypasses RLS)
configs_result = await admin_db.execute(
select(VpnConfig).where(VpnConfig.is_enabled.is_(True)).order_by(VpnConfig.subnet_index)
select(VpnConfig)
.where(VpnConfig.is_enabled.is_(True))
.order_by(VpnConfig.subnet_index)
)
configs = configs_result.scalars().all()

@@ -228,7 +236,9 @@ async def sync_wireguard_config() -> None:
peer_ip = peer.assigned_ip.split("/")[0]
allowed_ips = [f"{peer_ip}/32"]
if peer.additional_allowed_ips:
extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()]
extra = [
s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()
]
allowed_ips.extend(extra)
lines.append("[Peer]")
lines.append(f"PublicKey = {peer.peer_public_key}")
@@ -253,12 +263,13 @@ async def sync_wireguard_config() -> None:
# Docker traffic (172.16.0.0/12) going to each tenant's subnet
# gets SNATted to that tenant's gateway IP (.1) so the router
# can route replies back through the tunnel.
nat_lines = ["#!/bin/sh",
"# Auto-generated per-tenant SNAT rules",
"# Remove old rules",
"iptables -t nat -F POSTROUTING 2>/dev/null",
"# Re-add Docker DNS rules",
]
nat_lines = [
"#!/bin/sh",
"# Auto-generated per-tenant SNAT rules",
"# Remove old rules",
"iptables -t nat -F POSTROUTING 2>/dev/null",
"# Re-add Docker DNS rules",
]
for config in configs:
gateway_ip = config.server_address.split("/")[0]  # e.g. 10.10.3.1
subnet = config.subnet  # e.g. 10.10.3.0/24
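
Editorial aside: the rule each loop iteration appends falls outside this hunk; a sketch of what the comment above implies it looks like (illustrative, not the verbatim rule):

nat_lines.append(
    f"iptables -t nat -A POSTROUTING -s 172.16.0.0/12 -d {subnet} "
    f"-j SNAT --to-source {gateway_ip}"
)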
@@ -275,12 +286,15 @@ async def sync_wireguard_config() -> None:
reload_flag = wg_confs_dir / ".reload"
reload_flag.write_text("1")

logger.info("wireguard_config_synced", audit=True,
tenants=len(configs), peers=total_peers)
logger.info(
"wireguard_config_synced", audit=True, tenants=len(configs), peers=total_peers
)

finally:
# Release advisory lock explicitly (session-level lock, not xact-level)
await admin_db.execute(sa_text("SELECT pg_advisory_unlock(hashtext('wireguard_config'))"))
await admin_db.execute(
sa_text("SELECT pg_advisory_unlock(hashtext('wireguard_config'))")
)


# ── Live Status ──
@@ -371,15 +385,23 @@ async def setup_vpn(
db.add(config)
await db.flush()

logger.info("vpn_subnet_allocated", audit=True,
tenant_id=str(tenant_id), subnet_index=subnet_index, subnet=subnet)
logger.info(
"vpn_subnet_allocated",
audit=True,
tenant_id=str(tenant_id),
subnet_index=subnet_index,
subnet=subnet,
)

await _commit_and_sync(db)
return config


async def update_vpn_config(
db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None, is_enabled: Optional[bool] = None
db: AsyncSession,
tenant_id: uuid.UUID,
endpoint: Optional[str] = None,
is_enabled: Optional[bool] = None,
) -> VpnConfig:
"""Update VPN config settings."""
config = await get_vpn_config(db, tenant_id)
@@ -422,14 +444,21 @@ async def _next_available_ip(db: AsyncSession, tenant_id: uuid.UUID, config: Vpn
raise ValueError("No available IPs in VPN subnet")


async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID, additional_allowed_ips: Optional[str] = None) -> VpnPeer:
async def add_peer(
db: AsyncSession,
tenant_id: uuid.UUID,
device_id: uuid.UUID,
additional_allowed_ips: Optional[str] = None,
) -> VpnPeer:
"""Add a device as a VPN peer."""
config = await get_vpn_config(db, tenant_id)
if not config:
raise ValueError("VPN not configured — enable VPN first")

# Check device exists
device = await db.execute(select(Device).where(Device.id == device_id, Device.tenant_id == tenant_id))
device = await db.execute(
select(Device).where(Device.id == device_id, Device.tenant_id == tenant_id)
)
if not device.scalar_one_or_none():
raise ValueError("Device not found")

@@ -497,13 +526,12 @@ async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid.
psk = decrypt_credentials(peer.preshared_key, key_bytes) if peer.preshared_key else None

endpoint = config.endpoint or "YOUR_SERVER_IP:51820"
peer_ip_no_cidr = peer.assigned_ip.split("/")[0]

routeros_commands = [
f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key}"',
f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" '
f'endpoint-address={endpoint.split(":")[0]} endpoint-port={endpoint.split(":")[-1]} '
f'allowed-address=10.10.0.0/16 persistent-keepalive=25'
f"endpoint-address={endpoint.split(':')[0]} endpoint-port={endpoint.split(':')[-1]} "
f"allowed-address=10.10.0.0/16 persistent-keepalive=25"
+ (f' preshared-key="{psk}"' if psk else ""),
f"/ip address add address={peer.assigned_ip} interface=wg-portal",
]
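
Editorial aside: the quote flips in this hunk are ruff format's preference for double-quoted strings; because the f-string expressions themselves contain quotes, the inner ones become single quotes so the result stays valid even on Python versions before 3.12, which disallowed reusing the outer quote character inside an f-string expression. A minimal sketch (value assumed for illustration):

endpoint = "vpn.example.com:51820"
host_part = f"endpoint-address={endpoint.split(':')[0]}"    # formatted style
port_part = f"endpoint-port={endpoint.split(':')[-1]}"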

@@ -583,8 +611,8 @@ async def onboard_device(
routeros_commands = [
f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key_b64}"',
f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" '
f'endpoint-address={endpoint.split(":")[0]} endpoint-port={endpoint.split(":")[-1]} '
f'allowed-address=10.10.0.0/16 persistent-keepalive=25'
f"endpoint-address={endpoint.split(':')[0]} endpoint-port={endpoint.split(':')[-1]} "
f"allowed-address=10.10.0.0/16 persistent-keepalive=25"
f' preshared-key="{psk_decrypted}"',
f"/ip address add address={assigned_ip} interface=wg-portal",
]