fix(lint): resolve all ruff lint errors

Add ruff config to exclude alembic E402, SQLAlchemy F821, and pre-existing
E501 line-length issues. Auto-fix 69 unused imports and 2 f-strings without
placeholders. Manually fix 8 unused variables. Apply ruff format to 127 files.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Jason Staack
2026-03-14 22:17:50 -05:00
parent 2ad0367c91
commit 06a41ca9bf
133 changed files with 2927 additions and 1890 deletions

View File

@@ -12,10 +12,8 @@ RLS enforced via get_db() (app_user engine with tenant context).
RBAC: viewer = read-only (GET); operator and above = write (POST/PUT/PATCH/DELETE).
"""
import base64
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
@@ -66,8 +64,13 @@ def _require_write(current_user: CurrentUser) -> None:
EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
ALLOWED_METRICS = {
"cpu_load", "memory_used_pct", "disk_used_pct", "temperature",
"signal_strength", "ccq", "client_count",
"cpu_load",
"memory_used_pct",
"disk_used_pct",
"temperature",
"signal_strength",
"ccq",
"client_count",
}
ALLOWED_OPERATORS = {"gt", "lt", "gte", "lte"}
ALLOWED_SEVERITIES = {"critical", "warning", "info"}
@@ -252,7 +255,9 @@ async def create_alert_rule(
if body.operator not in ALLOWED_OPERATORS:
raise HTTPException(422, f"operator must be one of: {', '.join(sorted(ALLOWED_OPERATORS))}")
if body.severity not in ALLOWED_SEVERITIES:
raise HTTPException(422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}")
raise HTTPException(
422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}"
)
rule_id = str(uuid.uuid4())
@@ -296,8 +301,12 @@ async def create_alert_rule(
try:
await log_action(
db, tenant_id, current_user.user_id, "alert_rule_create",
resource_type="alert_rule", resource_id=rule_id,
db,
tenant_id,
current_user.user_id,
"alert_rule_create",
resource_type="alert_rule",
resource_id=rule_id,
details={"name": body.name, "metric": body.metric, "severity": body.severity},
)
except Exception:
@@ -338,7 +347,9 @@ async def update_alert_rule(
if body.operator not in ALLOWED_OPERATORS:
raise HTTPException(422, f"operator must be one of: {', '.join(sorted(ALLOWED_OPERATORS))}")
if body.severity not in ALLOWED_SEVERITIES:
raise HTTPException(422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}")
raise HTTPException(
422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}"
)
result = await db.execute(
text("""
@@ -384,8 +395,12 @@ async def update_alert_rule(
try:
await log_action(
db, tenant_id, current_user.user_id, "alert_rule_update",
resource_type="alert_rule", resource_id=str(rule_id),
db,
tenant_id,
current_user.user_id,
"alert_rule_update",
resource_type="alert_rule",
resource_id=str(rule_id),
details={"name": body.name, "metric": body.metric, "severity": body.severity},
)
except Exception:
@@ -439,8 +454,12 @@ async def delete_alert_rule(
try:
await log_action(
db, tenant_id, current_user.user_id, "alert_rule_delete",
resource_type="alert_rule", resource_id=str(rule_id),
db,
tenant_id,
current_user.user_id,
"alert_rule_delete",
resource_type="alert_rule",
resource_id=str(rule_id),
)
except Exception:
pass
@@ -592,7 +611,8 @@ async def create_notification_channel(
encrypted_password_transit = None
if body.smtp_password:
encrypted_password_transit = await encrypt_credentials_transit(
body.smtp_password, str(tenant_id),
body.smtp_password,
str(tenant_id),
)
await db.execute(
@@ -665,10 +685,14 @@ async def update_notification_channel(
# Build SET clauses dynamically based on which secrets are provided
set_parts = [
"name = :name", "channel_type = :channel_type",
"smtp_host = :smtp_host", "smtp_port = :smtp_port",
"smtp_user = :smtp_user", "smtp_use_tls = :smtp_use_tls",
"from_address = :from_address", "to_address = :to_address",
"name = :name",
"channel_type = :channel_type",
"smtp_host = :smtp_host",
"smtp_port = :smtp_port",
"smtp_user = :smtp_user",
"smtp_use_tls = :smtp_use_tls",
"from_address = :from_address",
"to_address = :to_address",
"webhook_url = :webhook_url",
"slack_webhook_url = :slack_webhook_url",
]
@@ -689,7 +713,8 @@ async def update_notification_channel(
if body.smtp_password:
set_parts.append("smtp_password_transit = :smtp_password_transit")
params["smtp_password_transit"] = await encrypt_credentials_transit(
body.smtp_password, str(tenant_id),
body.smtp_password,
str(tenant_id),
)
# Clear legacy column
set_parts.append("smtp_password = NULL")
@@ -799,6 +824,7 @@ async def test_notification_channel(
}
from app.services.notification_service import send_test_notification
try:
success = await send_test_notification(channel)
if success:

View File

@@ -221,29 +221,38 @@ async def list_audit_logs(
all_rows = result.mappings().all()
# Decrypt encrypted details concurrently
decrypted_details = await _decrypt_details_batch(
all_rows, str(tenant_id)
)
decrypted_details = await _decrypt_details_batch(all_rows, str(tenant_id))
output = io.StringIO()
writer = csv.writer(output)
writer.writerow([
"ID", "User Email", "Action", "Resource Type",
"Resource ID", "Device", "Details", "IP Address", "Timestamp",
])
writer.writerow(
[
"ID",
"User Email",
"Action",
"Resource Type",
"Resource ID",
"Device",
"Details",
"IP Address",
"Timestamp",
]
)
for row, details in zip(all_rows, decrypted_details):
details_str = json.dumps(details) if details else "{}"
writer.writerow([
str(row["id"]),
row["user_email"] or "",
row["action"],
row["resource_type"] or "",
row["resource_id"] or "",
row["device_name"] or "",
details_str,
row["ip_address"] or "",
str(row["created_at"]),
])
writer.writerow(
[
str(row["id"]),
row["user_email"] or "",
row["action"],
row["resource_type"] or "",
row["resource_id"] or "",
row["device_name"] or "",
details_str,
row["ip_address"] or "",
str(row["created_at"]),
]
)
output.seek(0)
return StreamingResponse(

View File

@@ -103,7 +103,11 @@ async def get_redis() -> aioredis.Redis:
# ─── SRP Zero-Knowledge Authentication ───────────────────────────────────────
@router.post("/srp/init", response_model=SRPInitResponse, summary="SRP Step 1: return salt and server ephemeral B")
@router.post(
"/srp/init",
response_model=SRPInitResponse,
summary="SRP Step 1: return salt and server ephemeral B",
)
@limiter.limit("5/minute")
async def srp_init_endpoint(
request: StarletteRequest,
@@ -137,9 +141,7 @@ async def srp_init_endpoint(
# Generate server ephemeral
try:
server_public, server_private = await srp_init(
user.email, user.srp_verifier.hex()
)
server_public, server_private = await srp_init(user.email, user.srp_verifier.hex())
except ValueError as e:
logger.error("SRP init failed for %s: %s", user.email, e)
raise HTTPException(
@@ -150,13 +152,15 @@ async def srp_init_endpoint(
# Store session in Redis with 60s TTL
session_id = secrets.token_urlsafe(16)
redis = await get_redis()
session_data = json.dumps({
"email": user.email,
"server_private": server_private,
"srp_verifier_hex": user.srp_verifier.hex(),
"srp_salt_hex": user.srp_salt.hex(),
"user_id": str(user.id),
})
session_data = json.dumps(
{
"email": user.email,
"server_private": server_private,
"srp_verifier_hex": user.srp_verifier.hex(),
"srp_salt_hex": user.srp_salt.hex(),
"user_id": str(user.id),
}
)
await redis.set(f"srp:session:{session_id}", session_data, ex=60)
return SRPInitResponse(
@@ -168,7 +172,11 @@ async def srp_init_endpoint(
)
@router.post("/srp/verify", response_model=SRPVerifyResponse, summary="SRP Step 2: verify client proof and return tokens")
@router.post(
"/srp/verify",
response_model=SRPVerifyResponse,
summary="SRP Step 2: verify client proof and return tokens",
)
@limiter.limit("5/minute")
async def srp_verify_endpoint(
request: StarletteRequest,
@@ -236,7 +244,9 @@ async def srp_verify_endpoint(
# Update last_login and clear upgrade flag on successful SRP login
await db.execute(
update(User).where(User.id == user.id).values(
update(User)
.where(User.id == user.id)
.values(
last_login=datetime.now(UTC),
must_upgrade_auth=False,
)
@@ -323,9 +333,7 @@ async def login(
Rate limited to 5 requests per minute per IP.
"""
# Look up user by email (case-insensitive)
result = await db.execute(
select(User).where(User.email == body.email.lower())
)
result = await db.execute(select(User).where(User.email == body.email.lower()))
user = result.scalar_one_or_none()
# Generic error — do not reveal whether email exists (no user enumeration)
@@ -389,7 +397,9 @@ async def login(
# Update last_login
await db.execute(
update(User).where(User.id == user.id).values(
update(User)
.where(User.id == user.id)
.values(
last_login=datetime.now(UTC),
)
)
@@ -404,7 +414,10 @@ async def login(
user_id=user.id,
action="login_upgrade" if user.must_upgrade_auth else "login",
resource_type="auth",
details={"email": user.email, **({"upgrade": "bcrypt_to_srp"} if user.must_upgrade_auth else {})},
details={
"email": user.email,
**({"upgrade": "bcrypt_to_srp"} if user.must_upgrade_auth else {}),
},
ip_address=request.client.host if request.client else None,
)
await audit_db.commit()
@@ -440,7 +453,9 @@ async def refresh_token(
Rate limited to 10 requests per minute per IP.
"""
# Resolve token: body takes precedence over cookie
raw_token = (body.refresh_token if body and body.refresh_token else None) or refresh_token_cookie
raw_token = (
body.refresh_token if body and body.refresh_token else None
) or refresh_token_cookie
if not raw_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -518,7 +533,9 @@ async def refresh_token(
)
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT, summary="Log out and clear session cookie")
@router.post(
"/logout", status_code=status.HTTP_204_NO_CONTENT, summary="Log out and clear session cookie"
)
@limiter.limit("10/minute")
async def logout(
request: StarletteRequest,
@@ -535,7 +552,10 @@ async def logout(
tenant_id = current_user.tenant_id or uuid.UUID(int=0)
async with AdminAsyncSessionLocal() as audit_db:
await log_action(
audit_db, tenant_id, current_user.user_id, "logout",
audit_db,
tenant_id,
current_user.user_id,
"logout",
resource_type="auth",
ip_address=request.client.host if request.client else None,
)
@@ -558,7 +578,11 @@ async def logout(
)
@router.post("/change-password", response_model=MessageResponse, summary="Change password for authenticated user")
@router.post(
"/change-password",
response_model=MessageResponse,
summary="Change password for authenticated user",
)
@limiter.limit("3/minute")
async def change_password(
request: StarletteRequest,
@@ -602,7 +626,9 @@ async def change_password(
existing_ks.hkdf_salt = base64.b64decode(body.hkdf_salt or "")
else:
# Legacy bcrypt user — verify current password
if not user.hashed_password or not verify_password(body.current_password, user.hashed_password):
if not user.hashed_password or not verify_password(
body.current_password, user.hashed_password
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is incorrect",
@@ -822,7 +848,9 @@ async def get_emergency_kit_template(
)
@router.post("/register-srp", response_model=MessageResponse, summary="Register SRP credentials for a user")
@router.post(
"/register-srp", response_model=MessageResponse, summary="Register SRP credentials for a user"
)
@limiter.limit("3/minute")
async def register_srp(
request: StarletteRequest,
@@ -845,7 +873,9 @@ async def register_srp(
# Update user with SRP credentials and clear upgrade flag
await db.execute(
update(User).where(User.id == user.id).values(
update(User)
.where(User.id == user.id)
.values(
srp_salt=bytes.fromhex(body.srp_salt),
srp_verifier=bytes.fromhex(body.srp_verifier),
auth_version=2,
@@ -873,8 +903,11 @@ async def register_srp(
try:
async with AdminAsyncSessionLocal() as audit_db:
await log_key_access(
audit_db, user.tenant_id or uuid.UUID(int=0), user.id,
"create_key_set", resource_type="user_key_set",
audit_db,
user.tenant_id or uuid.UUID(int=0),
user.id,
"create_key_set",
resource_type="user_key_set",
ip_address=request.client.host if request.client else None,
)
await audit_db.commit()
@@ -901,11 +934,17 @@ async def create_sse_token(
token = secrets.token_urlsafe(32)
key = f"sse_token:{token}"
# Store user context for the SSE endpoint to retrieve
await redis.set(key, json.dumps({
"user_id": str(current_user.user_id),
"tenant_id": str(current_user.tenant_id) if current_user.tenant_id else None,
"role": current_user.role,
}), ex=30) # 30 second TTL
await redis.set(
key,
json.dumps(
{
"user_id": str(current_user.user_id),
"tenant_id": str(current_user.tenant_id) if current_user.tenant_id else None,
"role": current_user.role,
}
),
ex=30,
) # 30 second TTL
return {"token": token}
@@ -977,9 +1016,7 @@ async def forgot_password(
"""
generic_msg = "If an account with that email exists, a reset link has been sent."
result = await db.execute(
select(User).where(User.email == body.email.lower())
)
result = await db.execute(select(User).where(User.email == body.email.lower()))
user = result.scalar_one_or_none()
if not user or not user.is_active:
@@ -988,9 +1025,7 @@ async def forgot_password(
# Generate a secure token
raw_token = secrets.token_urlsafe(32)
token_hash = _hash_token(raw_token)
expires_at = datetime.now(UTC) + timedelta(
minutes=settings.PASSWORD_RESET_TOKEN_EXPIRE_MINUTES
)
expires_at = datetime.now(UTC) + timedelta(minutes=settings.PASSWORD_RESET_TOKEN_EXPIRE_MINUTES)
# Insert token record (using raw SQL to avoid importing the model globally)
from sqlalchemy import text

View File

@@ -12,9 +12,7 @@ RBAC: viewer = read-only (GET); tenant_admin and above = mutating actions.
"""
import json
import logging
import uuid
from datetime import datetime, timezone
import nats
import nats.aio.client
@@ -30,7 +28,7 @@ from app.database import get_db, set_tenant_context
from app.middleware.rate_limit import limiter
from app.middleware.rbac import require_min_role
from app.middleware.tenant_context import CurrentUser, get_current_user
from app.models.certificate import CertificateAuthority, DeviceCertificate
from app.models.certificate import DeviceCertificate
from app.models.device import Device
from app.schemas.certificate import (
BulkCertDeployRequest,
@@ -87,13 +85,15 @@ async def _deploy_cert_via_nats(
Dict with success, cert_name_on_device, and error fields.
"""
nc = await _get_nats()
payload = json.dumps({
"device_id": device_id,
"cert_pem": cert_pem,
"key_pem": key_pem,
"cert_name": cert_name,
"ssh_port": ssh_port,
}).encode()
payload = json.dumps(
{
"device_id": device_id,
"cert_pem": cert_pem,
"key_pem": key_pem,
"cert_name": cert_name,
"ssh_port": ssh_port,
}
).encode()
try:
reply = await nc.request(
@@ -121,9 +121,7 @@ async def _get_device_for_tenant(
db: AsyncSession, device_id: uuid.UUID, current_user: CurrentUser
) -> Device:
"""Fetch a device and verify tenant ownership."""
result = await db.execute(
select(Device).where(Device.id == device_id)
)
result = await db.execute(select(Device).where(Device.id == device_id))
device = result.scalar_one_or_none()
if device is None:
raise HTTPException(
@@ -164,9 +162,7 @@ async def _get_cert_with_tenant_check(
db: AsyncSession, cert_id: uuid.UUID, tenant_id: uuid.UUID
) -> DeviceCertificate:
"""Fetch a device certificate and verify tenant ownership."""
result = await db.execute(
select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
)
result = await db.execute(select(DeviceCertificate).where(DeviceCertificate.id == cert_id))
cert = result.scalar_one_or_none()
if cert is None:
raise HTTPException(
@@ -226,8 +222,12 @@ async def create_ca(
try:
await log_action(
db, tenant_id, current_user.user_id, "ca_create",
resource_type="certificate_authority", resource_id=str(ca.id),
db,
tenant_id,
current_user.user_id,
"ca_create",
resource_type="certificate_authority",
resource_id=str(ca.id),
details={"common_name": body.common_name, "validity_years": body.validity_years},
)
except Exception:
@@ -332,8 +332,12 @@ async def sign_cert(
try:
await log_action(
db, tenant_id, current_user.user_id, "cert_sign",
resource_type="device_certificate", resource_id=str(cert.id),
db,
tenant_id,
current_user.user_id,
"cert_sign",
resource_type="device_certificate",
resource_id=str(cert.id),
device_id=body.device_id,
details={"hostname": device.hostname, "validity_days": body.validity_days},
)
@@ -404,17 +408,19 @@ async def deploy_cert(
await update_cert_status(db, cert_id, "deployed")
# Update device tls_mode to portal_ca
device_result = await db.execute(
select(Device).where(Device.id == cert.device_id)
)
device_result = await db.execute(select(Device).where(Device.id == cert.device_id))
device = device_result.scalar_one_or_none()
if device is not None:
device.tls_mode = "portal_ca"
try:
await log_action(
db, tenant_id, current_user.user_id, "cert_deploy",
resource_type="device_certificate", resource_id=str(cert_id),
db,
tenant_id,
current_user.user_id,
"cert_deploy",
resource_type="device_certificate",
resource_id=str(cert_id),
device_id=cert.device_id,
details={"cert_name_on_device": result.get("cert_name_on_device")},
)
@@ -528,36 +534,47 @@ async def bulk_deploy(
await update_cert_status(db, issued_cert.id, "deployed")
device.tls_mode = "portal_ca"
results.append(CertDeployResponse(
success=True,
device_id=device_id,
cert_name_on_device=result.get("cert_name_on_device"),
))
results.append(
CertDeployResponse(
success=True,
device_id=device_id,
cert_name_on_device=result.get("cert_name_on_device"),
)
)
else:
await update_cert_status(db, issued_cert.id, "issued")
results.append(CertDeployResponse(
success=False,
device_id=device_id,
error=result.get("error"),
))
results.append(
CertDeployResponse(
success=False,
device_id=device_id,
error=result.get("error"),
)
)
except HTTPException as e:
results.append(CertDeployResponse(
success=False,
device_id=device_id,
error=e.detail,
))
results.append(
CertDeployResponse(
success=False,
device_id=device_id,
error=e.detail,
)
)
except Exception as e:
logger.error("Bulk deploy error", device_id=str(device_id), error=str(e))
results.append(CertDeployResponse(
success=False,
device_id=device_id,
error=str(e),
))
results.append(
CertDeployResponse(
success=False,
device_id=device_id,
error=str(e),
)
)
try:
await log_action(
db, tenant_id, current_user.user_id, "cert_bulk_deploy",
db,
tenant_id,
current_user.user_id,
"cert_bulk_deploy",
resource_type="device_certificate",
details={
"device_count": len(body.device_ids),
@@ -619,17 +636,19 @@ async def revoke_cert(
)
# Reset device tls_mode to insecure
device_result = await db.execute(
select(Device).where(Device.id == cert.device_id)
)
device_result = await db.execute(select(Device).where(Device.id == cert.device_id))
device = device_result.scalar_one_or_none()
if device is not None:
device.tls_mode = "insecure"
try:
await log_action(
db, tenant_id, current_user.user_id, "cert_revoke",
resource_type="device_certificate", resource_id=str(cert_id),
db,
tenant_id,
current_user.user_id,
"cert_revoke",
resource_type="device_certificate",
resource_id=str(cert_id),
device_id=cert.device_id,
)
except Exception:
@@ -661,9 +680,7 @@ async def rotate_cert(
old_cert = await _get_cert_with_tenant_check(db, cert_id, tenant_id)
# Get the device for hostname/IP
device_result = await db.execute(
select(Device).where(Device.id == old_cert.device_id)
)
device_result = await db.execute(select(Device).where(Device.id == old_cert.device_id))
device = device_result.scalar_one_or_none()
if device is None:
raise HTTPException(
@@ -722,8 +739,12 @@ async def rotate_cert(
try:
await log_action(
db, tenant_id, current_user.user_id, "cert_rotate",
resource_type="device_certificate", resource_id=str(new_cert.id),
db,
tenant_id,
current_user.user_id,
"cert_rotate",
resource_type="device_certificate",
resource_id=str(new_cert.id),
device_id=old_cert.device_id,
details={
"old_cert_id": str(cert_id),

View File

@@ -43,6 +43,7 @@ async def _check_tenant_access(
"""Verify the current user is allowed to access the given tenant."""
if current_user.is_super_admin:
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
@@ -52,9 +53,7 @@ async def _check_tenant_access(
)
async def _check_device_online(
db: AsyncSession, device_id: uuid.UUID
) -> Device:
async def _check_device_online(db: AsyncSession, device_id: uuid.UUID) -> Device:
"""Verify the device exists and is online. Returns the Device object."""
result = await db.execute(
select(Device).where(Device.id == device_id) # type: ignore[arg-type]

View File

@@ -25,7 +25,6 @@ import asyncio
import json
import logging
import uuid
from datetime import timezone, datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Request, status
@@ -67,6 +66,7 @@ async def _check_tenant_access(
"""
if current_user.is_super_admin:
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
@@ -291,14 +291,14 @@ async def get_export(
try:
from app.services.crypto import decrypt_data_transit
plaintext = await decrypt_data_transit(
content_bytes.decode("utf-8"), str(tenant_id)
)
plaintext = await decrypt_data_transit(content_bytes.decode("utf-8"), str(tenant_id))
content_bytes = plaintext.encode("utf-8")
except Exception as dec_err:
logger.error(
"Failed to decrypt export for device %s sha %s: %s",
device_id, commit_sha, dec_err,
device_id,
commit_sha,
dec_err,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -370,7 +370,9 @@ async def get_binary(
except Exception as dec_err:
logger.error(
"Failed to decrypt binary backup for device %s sha %s: %s",
device_id, commit_sha, dec_err,
device_id,
commit_sha,
dec_err,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -380,9 +382,7 @@ async def get_binary(
return Response(
content=content_bytes,
media_type="application/octet-stream",
headers={
"Content-Disposition": f'attachment; filename="backup-{commit_sha[:8]}.bin"'
},
headers={"Content-Disposition": f'attachment; filename="backup-{commit_sha[:8]}.bin"'},
)
@@ -445,6 +445,7 @@ async def preview_restore(
key,
)
import json
creds = json.loads(creds_json)
current_text = await backup_service.capture_export(
device.ip_address,
@@ -578,9 +579,7 @@ async def emergency_rollback(
.where(
ConfigBackupRun.device_id == device_id, # type: ignore[arg-type]
ConfigBackupRun.tenant_id == tenant_id, # type: ignore[arg-type]
ConfigBackupRun.trigger_type.in_(
["pre-restore", "checkpoint", "pre-template-push"]
),
ConfigBackupRun.trigger_type.in_(["pre-restore", "checkpoint", "pre-template-push"]),
)
.order_by(ConfigBackupRun.created_at.desc())
.limit(1)
@@ -735,6 +734,7 @@ async def update_schedule(
# Hot-reload the scheduler so changes take effect immediately
from app.services.backup_scheduler import on_schedule_change
await on_schedule_change(tenant_id, device_id)
return {
@@ -758,6 +758,7 @@ async def _get_nats():
Reuses the same lazy-init pattern as routeros_proxy._get_nats().
"""
from app.services.routeros_proxy import _get_nats as _proxy_get_nats
return await _proxy_get_nats()
@@ -839,6 +840,7 @@ async def trigger_config_snapshot(
if reply_data.get("status") == "success":
try:
from app.services.audit_service import log_action
await log_action(
db,
tenant_id,

View File

@@ -64,9 +64,7 @@ async def _check_tenant_access(
await set_tenant_context(db, str(tenant_id))
async def _check_device_online(
db: AsyncSession, device_id: uuid.UUID
) -> Device:
async def _check_device_online(db: AsyncSession, device_id: uuid.UUID) -> Device:
"""Verify the device exists and is online. Returns the Device object."""
result = await db.execute(
select(Device).where(Device.id == device_id) # type: ignore[arg-type]
@@ -201,8 +199,12 @@ async def add_entry(
try:
await log_action(
db, tenant_id, current_user.user_id, "config_add",
resource_type="config", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"config_add",
resource_type="config",
resource_id=str(device_id),
device_id=device_id,
details={"path": body.path, "properties": body.properties},
)
@@ -255,8 +257,12 @@ async def set_entry(
try:
await log_action(
db, tenant_id, current_user.user_id, "config_set",
resource_type="config", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"config_set",
resource_type="config",
resource_id=str(device_id),
device_id=device_id,
details={"path": body.path, "entry_id": body.entry_id, "properties": body.properties},
)
@@ -286,9 +292,7 @@ async def remove_entry(
await _check_device_online(db, device_id)
check_path_safety(body.path, write=True)
result = await routeros_proxy.remove_entry(
str(device_id), body.path, body.entry_id
)
result = await routeros_proxy.remove_entry(str(device_id), body.path, body.entry_id)
if not result.get("success"):
raise HTTPException(
@@ -309,8 +313,12 @@ async def remove_entry(
try:
await log_action(
db, tenant_id, current_user.user_id, "config_remove",
resource_type="config", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"config_remove",
resource_type="config",
resource_id=str(device_id),
device_id=device_id,
details={"path": body.path, "entry_id": body.entry_id},
)
@@ -360,8 +368,12 @@ async def execute_command(
try:
await log_action(
db, tenant_id, current_user.user_id, "config_execute",
resource_type="config", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"config_execute",
resource_type="config",
resource_id=str(device_id),
device_id=device_id,
details={"command": body.command},
)

View File

@@ -43,6 +43,7 @@ async def _check_tenant_access(
"""
if current_user.is_super_admin:
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
@@ -115,9 +116,7 @@ async def view_snapshot(
session=db,
)
except Exception:
logger.exception(
"Failed to decrypt snapshot %s for device %s", snapshot_id, device_id
)
logger.exception("Failed to decrypt snapshot %s for device %s", snapshot_id, device_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to decrypt snapshot content",

View File

@@ -29,12 +29,14 @@ router = APIRouter(tags=["device-logs"])
# Helpers (same pattern as config_editor.py)
# ---------------------------------------------------------------------------
async def _check_tenant_access(
current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
"""Verify the current user is allowed to access the given tenant."""
if current_user.is_super_admin:
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
@@ -44,16 +46,12 @@ async def _check_tenant_access(
)
async def _check_device_exists(
db: AsyncSession, device_id: uuid.UUID
) -> None:
async def _check_device_exists(db: AsyncSession, device_id: uuid.UUID) -> None:
"""Verify the device exists (does not require online status for logs)."""
from sqlalchemy import select
from app.models.device import Device
result = await db.execute(
select(Device).where(Device.id == device_id)
)
result = await db.execute(select(Device).where(Device.id == device_id))
device = result.scalar_one_or_none()
if device is None:
raise HTTPException(
@@ -66,6 +64,7 @@ async def _check_device_exists(
# Response model
# ---------------------------------------------------------------------------
class LogEntry(BaseModel):
time: str
topics: str
@@ -82,6 +81,7 @@ class LogsResponse(BaseModel):
# Endpoint
# ---------------------------------------------------------------------------
@router.get(
"/tenants/{tenant_id}/devices/{device_id}/logs",
response_model=LogsResponse,

View File

@@ -72,9 +72,7 @@ async def update_tag(
) -> DeviceTagResponse:
"""Update a device tag. Requires operator role or above."""
await _check_tenant_access(current_user, tenant_id, db)
return await device_service.update_tag(
db=db, tenant_id=tenant_id, tag_id=tag_id, data=data
)
return await device_service.update_tag(db=db, tenant_id=tenant_id, tag_id=tag_id, data=data)
@router.delete(

View File

@@ -23,7 +23,6 @@ from app.database import get_db
from app.middleware.rate_limit import limiter
from app.services.audit_service import log_action
from app.middleware.rbac import (
require_min_role,
require_operator_or_above,
require_scope,
require_tenant_admin_or_above,
@@ -57,6 +56,7 @@ async def _check_tenant_access(
if current_user.is_super_admin:
# Re-set tenant context to the target tenant so RLS allows the operation
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
@@ -138,8 +138,12 @@ async def create_device(
)
try:
await log_action(
db, tenant_id, current_user.user_id, "device_create",
resource_type="device", resource_id=str(result.id),
db,
tenant_id,
current_user.user_id,
"device_create",
resource_type="device",
resource_id=str(result.id),
details={"hostname": data.hostname, "ip_address": data.ip_address},
ip_address=request.client.host if request.client else None,
)
@@ -191,8 +195,12 @@ async def update_device(
)
try:
await log_action(
db, tenant_id, current_user.user_id, "device_update",
resource_type="device", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"device_update",
resource_type="device",
resource_id=str(device_id),
device_id=device_id,
details={"changes": data.model_dump(exclude_unset=True)},
ip_address=request.client.host if request.client else None,
@@ -220,8 +228,12 @@ async def delete_device(
await _check_tenant_access(current_user, tenant_id, db)
try:
await log_action(
db, tenant_id, current_user.user_id, "device_delete",
resource_type="device", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"device_delete",
resource_type="device",
resource_id=str(device_id),
device_id=device_id,
ip_address=request.client.host if request.client else None,
)
@@ -262,14 +274,21 @@ async def scan_devices(
discovered = await scan_subnet(data.cidr)
import ipaddress
network = ipaddress.ip_network(data.cidr, strict=False)
total_scanned = network.num_addresses - 2 if network.num_addresses > 2 else network.num_addresses
total_scanned = (
network.num_addresses - 2 if network.num_addresses > 2 else network.num_addresses
)
# Audit log the scan (fire-and-forget — never breaks the response)
try:
await log_action(
db, tenant_id, current_user.user_id, "subnet_scan",
resource_type="network", resource_id=data.cidr,
db,
tenant_id,
current_user.user_id,
"subnet_scan",
resource_type="network",
resource_id=data.cidr,
details={
"cidr": data.cidr,
"devices_found": len(discovered),
@@ -322,10 +341,12 @@ async def bulk_add_devices(
password = dev_data.password or data.shared_password
if not username or not password:
failed.append({
"ip_address": dev_data.ip_address,
"error": "No credentials provided (set per-device or shared credentials)",
})
failed.append(
{
"ip_address": dev_data.ip_address,
"error": "No credentials provided (set per-device or shared credentials)",
}
)
continue
create_data = DeviceCreate(
@@ -347,9 +368,16 @@ async def bulk_add_devices(
added.append(device)
try:
await log_action(
db, tenant_id, current_user.user_id, "device_adopt",
resource_type="device", resource_id=str(device.id),
details={"hostname": create_data.hostname, "ip_address": create_data.ip_address},
db,
tenant_id,
current_user.user_id,
"device_adopt",
resource_type="device",
resource_id=str(device.id),
details={
"hostname": create_data.hostname,
"ip_address": create_data.ip_address,
},
ip_address=request.client.host if request.client else None,
)
except Exception:

View File

@@ -90,16 +90,18 @@ async def list_events(
for row in alert_result.fetchall():
alert_status = row[1] or "firing"
metric = row[3] or "unknown"
events.append({
"id": str(row[0]),
"event_type": "alert",
"severity": row[2],
"title": f"{alert_status}: {metric}",
"description": row[4] or f"Alert {alert_status} for {metric}",
"device_hostname": row[7],
"device_id": str(row[6]) if row[6] else None,
"timestamp": row[5].isoformat() if row[5] else None,
})
events.append(
{
"id": str(row[0]),
"event_type": "alert",
"severity": row[2],
"title": f"{alert_status}: {metric}",
"description": row[4] or f"Alert {alert_status} for {metric}",
"device_hostname": row[7],
"device_id": str(row[6]) if row[6] else None,
"timestamp": row[5].isoformat() if row[5] else None,
}
)
# 2. Device status changes (inferred from current status + last_seen)
if not event_type or event_type == "status_change":
@@ -117,16 +119,18 @@ async def list_events(
device_status = row[2] or "unknown"
hostname = row[1] or "Unknown device"
severity = "info" if device_status == "online" else "warning"
events.append({
"id": f"status-{row[0]}",
"event_type": "status_change",
"severity": severity,
"title": f"Device {device_status}",
"description": f"{hostname} is now {device_status}",
"device_hostname": hostname,
"device_id": str(row[0]),
"timestamp": row[3].isoformat() if row[3] else None,
})
events.append(
{
"id": f"status-{row[0]}",
"event_type": "status_change",
"severity": severity,
"title": f"Device {device_status}",
"description": f"{hostname} is now {device_status}",
"device_hostname": hostname,
"device_id": str(row[0]),
"timestamp": row[3].isoformat() if row[3] else None,
}
)
# 3. Config backup runs
if not event_type or event_type == "config_backup":
@@ -144,16 +148,18 @@ async def list_events(
for row in backup_result.fetchall():
trigger_type = row[1] or "manual"
hostname = row[4] or "Unknown device"
events.append({
"id": str(row[0]),
"event_type": "config_backup",
"severity": "info",
"title": "Config backup",
"description": f"{trigger_type} backup completed for {hostname}",
"device_hostname": hostname,
"device_id": str(row[3]) if row[3] else None,
"timestamp": row[2].isoformat() if row[2] else None,
})
events.append(
{
"id": str(row[0]),
"event_type": "config_backup",
"severity": "info",
"title": "Config backup",
"description": f"{trigger_type} backup completed for {hostname}",
"device_hostname": hostname,
"device_id": str(row[3]) if row[3] else None,
"timestamp": row[2].isoformat() if row[2] else None,
}
)
# Sort all events by timestamp descending, then apply final limit
events.sort(

View File

@@ -67,6 +67,7 @@ async def get_firmware_overview(
await _check_tenant_access(current_user, tenant_id, db)
from app.services.firmware_service import get_firmware_overview as _get_overview
return await _get_overview(str(tenant_id))
@@ -206,6 +207,7 @@ async def trigger_firmware_check(
raise HTTPException(status_code=403, detail="Super admin only")
from app.services.firmware_service import check_latest_versions
results = await check_latest_versions()
return {"status": "ok", "versions_discovered": len(results), "versions": results}
@@ -221,6 +223,7 @@ async def list_firmware_cache(
raise HTTPException(status_code=403, detail="Super admin only")
from app.services.firmware_service import get_cached_firmware
return await get_cached_firmware()
@@ -236,6 +239,7 @@ async def download_firmware(
raise HTTPException(status_code=403, detail="Super admin only")
from app.services.firmware_service import download_firmware as _download
path = await _download(body.architecture, body.channel, body.version)
return {"status": "ok", "path": path}
@@ -324,15 +328,21 @@ async def start_firmware_upgrade(
# Schedule or start immediately
if body.scheduled_at:
from app.services.upgrade_service import schedule_upgrade
schedule_upgrade(job_id, datetime.fromisoformat(body.scheduled_at))
else:
from app.services.upgrade_service import start_upgrade
asyncio.create_task(start_upgrade(job_id))
try:
await log_action(
db, tenant_id, current_user.user_id, "firmware_upgrade",
resource_type="firmware", resource_id=job_id,
db,
tenant_id,
current_user.user_id,
"firmware_upgrade",
resource_type="firmware",
resource_id=job_id,
device_id=uuid.UUID(body.device_id),
details={"target_version": body.target_version, "channel": body.channel},
)
@@ -406,9 +416,11 @@ async def start_mass_firmware_upgrade(
# Schedule or start immediately
if body.scheduled_at:
from app.services.upgrade_service import schedule_mass_upgrade
schedule_mass_upgrade(rollout_group_id, datetime.fromisoformat(body.scheduled_at))
else:
from app.services.upgrade_service import start_mass_upgrade
asyncio.create_task(start_mass_upgrade(rollout_group_id))
return {
@@ -639,6 +651,7 @@ async def cancel_upgrade_endpoint(
raise HTTPException(403, "Viewers cannot cancel upgrades")
from app.services.upgrade_service import cancel_upgrade
await cancel_upgrade(str(job_id))
return {"status": "ok", "message": "Upgrade cancelled"}
@@ -662,6 +675,7 @@ async def retry_upgrade_endpoint(
raise HTTPException(403, "Viewers cannot retry upgrades")
from app.services.upgrade_service import retry_failed_upgrade
await retry_failed_upgrade(str(job_id))
return {"status": "ok", "message": "Upgrade retry started"}
@@ -685,6 +699,7 @@ async def resume_rollout_endpoint(
raise HTTPException(403, "Viewers cannot resume rollouts")
from app.services.upgrade_service import resume_mass_upgrade
await resume_mass_upgrade(str(rollout_group_id))
return {"status": "ok", "message": "Rollout resumed"}
@@ -708,5 +723,6 @@ async def abort_rollout_endpoint(
raise HTTPException(403, "Viewers cannot abort rollouts")
from app.services.upgrade_service import abort_mass_upgrade
aborted = await abort_mass_upgrade(str(rollout_group_id))
return {"status": "ok", "aborted_count": aborted}

View File

@@ -299,9 +299,7 @@ async def delete_maintenance_window(
_require_operator(current_user)
result = await db.execute(
text(
"DELETE FROM maintenance_windows WHERE id = CAST(:id AS uuid) RETURNING id"
),
text("DELETE FROM maintenance_windows WHERE id = CAST(:id AS uuid) RETURNING id"),
{"id": str(window_id)},
)
if not result.fetchone():

View File

@@ -65,6 +65,7 @@ async def _check_tenant_access(
if current_user.is_super_admin:
# Re-set tenant context to the target tenant so RLS allows the operation
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:

View File

@@ -81,9 +81,12 @@ async def _get_device(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UU
return device
async def _check_tenant_access(current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession) -> None:
async def _check_tenant_access(
current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
if current_user.is_super_admin:
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
@@ -124,8 +127,12 @@ async def open_winbox_session(
try:
await log_action(
db, tenant_id, current_user.user_id, "winbox_tunnel_open",
resource_type="device", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"winbox_tunnel_open",
resource_type="device",
resource_id=str(device_id),
device_id=device_id,
details={"source_ip": source_ip},
ip_address=source_ip,
@@ -133,24 +140,31 @@ async def open_winbox_session(
except Exception:
pass
payload = json.dumps({
"device_id": str(device_id),
"tenant_id": str(tenant_id),
"user_id": str(current_user.user_id),
"target_port": 8291,
}).encode()
payload = json.dumps(
{
"device_id": str(device_id),
"tenant_id": str(tenant_id),
"user_id": str(current_user.user_id),
"target_port": 8291,
}
).encode()
try:
nc = await _get_nats()
msg = await nc.request("tunnel.open", payload, timeout=10)
except Exception as exc:
logger.error("NATS tunnel.open failed: %s", exc)
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Tunnel service unavailable")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Tunnel service unavailable"
)
try:
data = json.loads(msg.data)
except Exception:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Invalid response from tunnel service")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Invalid response from tunnel service",
)
if "error" in data:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=data["error"])
@@ -158,11 +172,16 @@ async def open_winbox_session(
port = data.get("local_port")
tunnel_id = data.get("tunnel_id", "")
if not isinstance(port, int) or not (49000 <= port <= 49100):
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Invalid port allocation from tunnel service")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Invalid port allocation from tunnel service",
)
# Derive the tunnel host from the request so remote clients get the server's
# address rather than 127.0.0.1 (which would point to the user's own machine).
tunnel_host = (request.headers.get("x-forwarded-host") or request.headers.get("host") or "127.0.0.1")
tunnel_host = (
request.headers.get("x-forwarded-host") or request.headers.get("host") or "127.0.0.1"
)
# Strip port from host header (e.g. "10.101.0.175:8001" → "10.101.0.175")
tunnel_host = tunnel_host.split(":")[0]
@@ -213,8 +232,12 @@ async def open_ssh_session(
try:
await log_action(
db, tenant_id, current_user.user_id, "ssh_session_open",
resource_type="device", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"ssh_session_open",
resource_type="device",
resource_id=str(device_id),
device_id=device_id,
details={"source_ip": source_ip, "cols": body.cols, "rows": body.rows},
ip_address=source_ip,
@@ -223,22 +246,26 @@ async def open_ssh_session(
pass
token = secrets.token_urlsafe(32)
token_payload = json.dumps({
"device_id": str(device_id),
"tenant_id": str(tenant_id),
"user_id": str(current_user.user_id),
"source_ip": source_ip,
"cols": body.cols,
"rows": body.rows,
"created_at": int(time.time()),
})
token_payload = json.dumps(
{
"device_id": str(device_id),
"tenant_id": str(tenant_id),
"user_id": str(current_user.user_id),
"source_ip": source_ip,
"cols": body.cols,
"rows": body.rows,
"created_at": int(time.time()),
}
)
try:
rd = await _get_redis()
await rd.setex(f"ssh:token:{token}", 120, token_payload)
except Exception as exc:
logger.error("Redis setex failed for SSH token: %s", exc)
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Session store unavailable")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Session store unavailable"
)
return SSHSessionResponse(
token=token,
@@ -274,8 +301,12 @@ async def close_winbox_session(
try:
await log_action(
db, tenant_id, current_user.user_id, "winbox_tunnel_close",
resource_type="device", resource_id=str(device_id),
db,
tenant_id,
current_user.user_id,
"winbox_tunnel_close",
resource_type="device",
resource_id=str(device_id),
device_id=device_id,
details={"tunnel_id": tunnel_id, "source_ip": source_ip},
ip_address=source_ip,

View File

@@ -72,7 +72,11 @@ async def _set_system_settings(updates: dict, user_id: str) -> None:
ON CONFLICT (key) DO UPDATE
SET value = :value, updated_by = CAST(:user_id AS uuid), updated_at = now()
"""),
{"key": key, "value": str(value) if value is not None else None, "user_id": user_id},
{
"key": key,
"value": str(value) if value is not None else None,
"user_id": user_id,
},
)
await session.commit()
@@ -100,7 +104,8 @@ async def get_smtp_settings(user=Depends(require_role("super_admin"))):
"smtp_host": db_settings.get("smtp_host") or settings.SMTP_HOST,
"smtp_port": int(db_settings.get("smtp_port") or settings.SMTP_PORT),
"smtp_user": db_settings.get("smtp_user") or settings.SMTP_USER or "",
"smtp_use_tls": (db_settings.get("smtp_use_tls") or str(settings.SMTP_USE_TLS)).lower() == "true",
"smtp_use_tls": (db_settings.get("smtp_use_tls") or str(settings.SMTP_USE_TLS)).lower()
== "true",
"smtp_from_address": db_settings.get("smtp_from_address") or settings.SMTP_FROM_ADDRESS,
"smtp_provider": db_settings.get("smtp_provider") or "custom",
"smtp_password_set": bool(db_settings.get("smtp_password") or settings.SMTP_PASSWORD),

View File

@@ -32,6 +32,7 @@ async def _get_sse_redis() -> aioredis.Redis:
global _redis
if _redis is None:
from app.config import settings
_redis = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
return _redis
@@ -70,7 +71,9 @@ async def _validate_sse_token(token: str) -> dict:
async def event_stream(
request: Request,
tenant_id: uuid.UUID,
token: str = Query(..., description="Short-lived SSE exchange token (from POST /auth/sse-token)"),
token: str = Query(
..., description="Short-lived SSE exchange token (from POST /auth/sse-token)"
),
) -> EventSourceResponse:
"""Stream real-time events for a tenant via Server-Sent Events.
@@ -87,7 +90,9 @@ async def event_stream(
user_id = user_context.get("user_id", "")
# Authorization: user must belong to the requested tenant or be super_admin
if user_role != "super_admin" and (user_tenant_id is None or str(user_tenant_id) != str(tenant_id)):
if user_role != "super_admin" and (
user_tenant_id is None or str(user_tenant_id) != str(tenant_id)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized for this tenant",

View File

@@ -21,7 +21,6 @@ RBAC: viewer = read (GET/preview); operator and above = write (POST/PUT/DELETE/p
import asyncio
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
@@ -54,6 +53,7 @@ async def _check_tenant_access(
"""Verify the current user is allowed to access the given tenant."""
if current_user.is_super_admin:
from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id))
return
if current_user.tenant_id != tenant_id:
@@ -191,7 +191,8 @@ async def create_template(
if unmatched:
logger.warning(
"Template '%s' has undeclared variables: %s (auto-adding as string type)",
body.name, unmatched,
body.name,
unmatched,
)
# Create template
@@ -553,11 +554,13 @@ async def push_template(
status="pending",
)
db.add(job)
jobs_created.append({
"job_id": str(job.id),
"device_id": str(device.id),
"device_hostname": device.hostname,
})
jobs_created.append(
{
"job_id": str(job.id),
"device_id": str(device.id),
"device_hostname": device.hostname,
}
)
await db.flush()
@@ -598,14 +601,16 @@ async def push_status(
jobs = []
for job, hostname in rows:
jobs.append({
"device_id": str(job.device_id),
"hostname": hostname,
"status": job.status,
"error_message": job.error_message,
"started_at": job.started_at.isoformat() if job.started_at else None,
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
})
jobs.append(
{
"device_id": str(job.device_id),
"hostname": hostname,
"status": job.status,
"error_message": job.error_message,
"started_at": job.started_at.isoformat() if job.started_at else None,
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
}
)
return {
"rollout_id": str(rollout_id),

View File

@@ -9,7 +9,6 @@ DELETE /api/tenants/{id} — delete tenant (super_admin only)
"""
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy import func, select, text
@@ -17,7 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.middleware.rate_limit import limiter
from app.database import get_admin_db, get_db
from app.database import get_admin_db
from app.middleware.rbac import require_super_admin, require_tenant_admin_or_above
from app.middleware.tenant_context import CurrentUser
from app.models.device import Device
@@ -70,15 +69,18 @@ async def list_tenants(
else:
if not current_user.tenant_id:
return []
result = await db.execute(
select(Tenant).where(Tenant.id == current_user.tenant_id)
)
result = await db.execute(select(Tenant).where(Tenant.id == current_user.tenant_id))
tenants = result.scalars().all()
return [await _get_tenant_response(tenant, db) for tenant in tenants]
@router.post("", response_model=TenantResponse, status_code=status.HTTP_201_CREATED, summary="Create a tenant")
@router.post(
"",
response_model=TenantResponse,
status_code=status.HTTP_201_CREATED,
summary="Create a tenant",
)
@limiter.limit("20/minute")
async def create_tenant(
request: Request,
@@ -108,13 +110,21 @@ async def create_tenant(
("Device Offline", "device_offline", "eq", 1, 1, "critical"),
]
for name, metric, operator, threshold, duration, sev in default_rules:
await db.execute(text("""
await db.execute(
text("""
INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default)
VALUES (gen_random_uuid(), CAST(:tenant_id AS uuid), :name, :metric, :operator, :threshold, :duration, :severity, TRUE, TRUE)
"""), {
"tenant_id": str(tenant.id), "name": name, "metric": metric,
"operator": operator, "threshold": threshold, "duration": duration, "severity": sev,
})
"""),
{
"tenant_id": str(tenant.id),
"name": name,
"metric": metric,
"operator": operator,
"threshold": threshold,
"duration": duration,
"severity": sev,
},
)
await db.commit()
# Seed starter config templates for new tenant
@@ -131,6 +141,7 @@ async def create_tenant(
await db.commit()
except Exception as exc:
import logging
logging.getLogger(__name__).warning(
"OpenBao key provisioning failed for tenant %s (will be provisioned on next startup): %s",
tenant.id,
@@ -229,6 +240,7 @@ async def delete_tenant(
# Check if tenant had VPN configured (before cascade deletes it)
from app.services.vpn_service import get_vpn_config, sync_wireguard_config
had_vpn = await get_vpn_config(db, tenant_id)
await db.delete(tenant)
@@ -288,14 +300,54 @@ add chain=forward action=drop comment="Drop everything else"
# Identity
/system identity set name={{ device.hostname }}""",
"variables": [
{"name": "wan_interface", "type": "string", "default": "ether1", "description": "WAN-facing interface"},
{"name": "lan_gateway", "type": "ip", "default": "192.168.88.1", "description": "LAN gateway IP"},
{"name": "lan_cidr", "type": "integer", "default": "24", "description": "LAN subnet mask bits"},
{"name": "lan_network", "type": "ip", "default": "192.168.88.0", "description": "LAN network address"},
{"name": "dhcp_start", "type": "ip", "default": "192.168.88.100", "description": "DHCP pool start"},
{"name": "dhcp_end", "type": "ip", "default": "192.168.88.254", "description": "DHCP pool end"},
{"name": "dns_servers", "type": "string", "default": "8.8.8.8,8.8.4.4", "description": "Upstream DNS servers"},
{"name": "ntp_server", "type": "string", "default": "pool.ntp.org", "description": "NTP server"},
{
"name": "wan_interface",
"type": "string",
"default": "ether1",
"description": "WAN-facing interface",
},
{
"name": "lan_gateway",
"type": "ip",
"default": "192.168.88.1",
"description": "LAN gateway IP",
},
{
"name": "lan_cidr",
"type": "integer",
"default": "24",
"description": "LAN subnet mask bits",
},
{
"name": "lan_network",
"type": "ip",
"default": "192.168.88.0",
"description": "LAN network address",
},
{
"name": "dhcp_start",
"type": "ip",
"default": "192.168.88.100",
"description": "DHCP pool start",
},
{
"name": "dhcp_end",
"type": "ip",
"default": "192.168.88.254",
"description": "DHCP pool end",
},
{
"name": "dns_servers",
"type": "string",
"default": "8.8.8.8,8.8.4.4",
"description": "Upstream DNS servers",
},
{
"name": "ntp_server",
"type": "string",
"default": "pool.ntp.org",
"description": "NTP server",
},
],
},
{
@@ -311,8 +363,18 @@ add chain=forward connection-state=invalid action=drop
add chain=forward src-address={{ allowed_network }} action=accept
add chain=forward action=drop""",
"variables": [
{"name": "wan_interface", "type": "string", "default": "ether1", "description": "WAN-facing interface"},
{"name": "allowed_network", "type": "subnet", "default": "192.168.88.0/24", "description": "Allowed source network"},
{
"name": "wan_interface",
"type": "string",
"default": "ether1",
"description": "WAN-facing interface",
},
{
"name": "allowed_network",
"type": "subnet",
"default": "192.168.88.0/24",
"description": "Allowed source network",
},
],
},
{
@@ -322,11 +384,36 @@ add chain=forward action=drop""",
/ip dhcp-server network add address={{ gateway }}/24 gateway={{ gateway }} dns-server={{ dns_server }}
/ip dhcp-server add name=dhcp1 interface={{ interface }} address-pool=dhcp-pool disabled=no""",
"variables": [
{"name": "pool_start", "type": "ip", "default": "192.168.88.100", "description": "DHCP pool start address"},
{"name": "pool_end", "type": "ip", "default": "192.168.88.254", "description": "DHCP pool end address"},
{"name": "gateway", "type": "ip", "default": "192.168.88.1", "description": "Default gateway"},
{"name": "dns_server", "type": "ip", "default": "8.8.8.8", "description": "DNS server address"},
{"name": "interface", "type": "string", "default": "bridge-lan", "description": "Interface to serve DHCP on"},
{
"name": "pool_start",
"type": "ip",
"default": "192.168.88.100",
"description": "DHCP pool start address",
},
{
"name": "pool_end",
"type": "ip",
"default": "192.168.88.254",
"description": "DHCP pool end address",
},
{
"name": "gateway",
"type": "ip",
"default": "192.168.88.1",
"description": "Default gateway",
},
{
"name": "dns_server",
"type": "ip",
"default": "8.8.8.8",
"description": "DNS server address",
},
{
"name": "interface",
"type": "string",
"default": "bridge-lan",
"description": "Interface to serve DHCP on",
},
],
},
{
@@ -335,10 +422,30 @@ add chain=forward action=drop""",
"content": """/interface wireless security-profiles add name=portal-wpa2 mode=dynamic-keys authentication-types=wpa2-psk wpa2-pre-shared-key={{ password }}
/interface wireless set wlan1 mode=ap-bridge ssid={{ ssid }} security-profile=portal-wpa2 frequency={{ frequency }} channel-width={{ channel_width }} disabled=no""",
"variables": [
{"name": "ssid", "type": "string", "default": "MikroTik-AP", "description": "Wireless network name"},
{"name": "password", "type": "string", "default": "", "description": "WPA2 pre-shared key (min 8 characters)"},
{"name": "frequency", "type": "integer", "default": "2412", "description": "Wireless frequency in MHz"},
{"name": "channel_width", "type": "string", "default": "20/40mhz-XX", "description": "Channel width setting"},
{
"name": "ssid",
"type": "string",
"default": "MikroTik-AP",
"description": "Wireless network name",
},
{
"name": "password",
"type": "string",
"default": "",
"description": "WPA2 pre-shared key (min 8 characters)",
},
{
"name": "frequency",
"type": "integer",
"default": "2412",
"description": "Wireless frequency in MHz",
},
{
"name": "channel_width",
"type": "string",
"default": "20/40mhz-XX",
"description": "Channel width setting",
},
],
},
{
@@ -351,8 +458,18 @@ add chain=forward action=drop""",
/ip service set ssh port=22
/ip service set winbox port=8291""",
"variables": [
{"name": "ntp_server", "type": "ip", "default": "pool.ntp.org", "description": "NTP server address"},
{"name": "dns_servers", "type": "string", "default": "8.8.8.8,8.8.4.4", "description": "Comma-separated DNS servers"},
{
"name": "ntp_server",
"type": "ip",
"default": "pool.ntp.org",
"description": "NTP server address",
},
{
"name": "dns_servers",
"type": "string",
"default": "8.8.8.8,8.8.4.4",
"description": "Comma-separated DNS servers",
},
],
},
]
@@ -363,13 +480,16 @@ async def _seed_starter_templates(db, tenant_id) -> None:
import json as _json
for tmpl in _STARTER_TEMPLATES:
await db.execute(text("""
await db.execute(
text("""
INSERT INTO config_templates (id, tenant_id, name, description, content, variables)
VALUES (gen_random_uuid(), CAST(:tid AS uuid), :name, :desc, :content, CAST(:vars AS jsonb))
"""), {
"tid": str(tenant_id),
"name": tmpl["name"],
"desc": tmpl["description"],
"content": tmpl["content"],
"vars": _json.dumps(tmpl["variables"]),
})
"""),
{
"tid": str(tenant_id),
"name": tmpl["name"],
"desc": tmpl["description"],
"content": tmpl["content"],
"vars": _json.dumps(tmpl["variables"]),
},
)

View File

@@ -14,7 +14,6 @@ Builds a topology graph of managed devices by:
import asyncio
import ipaddress
import json
import logging
import uuid
from typing import Any
@@ -265,7 +264,7 @@ async def get_topology(
nodes: list[TopologyNode] = []
ip_to_device: dict[str, str] = {}
online_device_ids: list[str] = []
devices_by_id: dict[str, Any] = {}
_devices_by_id: dict[str, Any] = {}
for row in rows:
device_id = str(row.id)
@@ -288,9 +287,7 @@ async def get_topology(
if online_device_ids:
tasks = [
routeros_proxy.execute_command(
device_id, "/ip/neighbor/print", timeout=10.0
)
routeros_proxy.execute_command(device_id, "/ip/neighbor/print", timeout=10.0)
for device_id in online_device_ids
]
results = await asyncio.gather(*tasks, return_exceptions=True)

View File

@@ -164,9 +164,7 @@ async def list_transparency_logs(
# Count total
count_result = await db.execute(
select(func.count())
.select_from(text("key_access_log k"))
.where(where_clause),
select(func.count()).select_from(text("key_access_log k")).where(where_clause),
params,
)
total = count_result.scalar() or 0
@@ -353,39 +351,41 @@ async def export_transparency_logs(
output = io.StringIO()
writer = csv.writer(output)
writer.writerow([
"ID",
"Action",
"Device Name",
"Device ID",
"Justification",
"Operator Email",
"Correlation ID",
"Resource Type",
"Resource ID",
"IP Address",
"Timestamp",
])
writer.writerow(
[
"ID",
"Action",
"Device Name",
"Device ID",
"Justification",
"Operator Email",
"Correlation ID",
"Resource Type",
"Resource ID",
"IP Address",
"Timestamp",
]
)
for row in all_rows:
writer.writerow([
str(row["id"]),
row["action"],
row["device_name"] or "",
str(row["device_id"]) if row["device_id"] else "",
row["justification"] or "",
row["operator_email"] or "",
row["correlation_id"] or "",
row["resource_type"] or "",
row["resource_id"] or "",
row["ip_address"] or "",
str(row["created_at"]),
])
writer.writerow(
[
str(row["id"]),
row["action"],
row["device_name"] or "",
str(row["device_id"]) if row["device_id"] else "",
row["justification"] or "",
row["operator_email"] or "",
row["correlation_id"] or "",
row["resource_type"] or "",
row["resource_id"] or "",
row["ip_address"] or "",
str(row["created_at"]),
]
)
output.seek(0)
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={
"Content-Disposition": "attachment; filename=transparency-logs.csv"
},
headers={"Content-Disposition": "attachment; filename=transparency-logs.csv"},
)

View File

@@ -20,7 +20,7 @@ from app.database import get_admin_db
from app.middleware.rbac import require_tenant_admin_or_above
from app.middleware.tenant_context import CurrentUser
from app.models.tenant import Tenant
from app.models.user import User, UserRole
from app.models.user import User
from app.schemas.user import UserCreate, UserResponse, UserUpdate
from app.services.auth import hash_password
@@ -69,11 +69,7 @@ async def list_users(
"""
await _check_tenant_access(tenant_id, current_user, db)
result = await db.execute(
select(User)
.where(User.tenant_id == tenant_id)
.order_by(User.name)
)
result = await db.execute(select(User).where(User.tenant_id == tenant_id).order_by(User.name))
users = result.scalars().all()
return [UserResponse.model_validate(user) for user in users]
@@ -103,9 +99,7 @@ async def create_user(
await _check_tenant_access(tenant_id, current_user, db)
# Check email uniqueness (global, not per-tenant)
existing = await db.execute(
select(User).where(User.email == data.email.lower())
)
existing = await db.execute(select(User).where(User.email == data.email.lower()))
if existing.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
@@ -138,9 +132,7 @@ async def get_user(
"""Get user detail."""
await _check_tenant_access(tenant_id, current_user, db)
result = await db.execute(
select(User).where(User.id == user_id, User.tenant_id == tenant_id)
)
result = await db.execute(select(User).where(User.id == user_id, User.tenant_id == tenant_id))
user = result.scalar_one_or_none()
if not user:
@@ -168,9 +160,7 @@ async def update_user(
"""
await _check_tenant_access(tenant_id, current_user, db)
result = await db.execute(
select(User).where(User.id == user_id, User.tenant_id == tenant_id)
)
result = await db.execute(select(User).where(User.id == user_id, User.tenant_id == tenant_id))
user = result.scalar_one_or_none()
if not user:
@@ -194,7 +184,11 @@ async def update_user(
return UserResponse.model_validate(user)
@router.delete("/{tenant_id}/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Deactivate a user")
@router.delete(
"/{tenant_id}/users/{user_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Deactivate a user",
)
@limiter.limit("5/minute")
async def deactivate_user(
request: Request,
@@ -209,9 +203,7 @@ async def deactivate_user(
"""
await _check_tenant_access(tenant_id, current_user, db)
result = await db.execute(
select(User).where(User.id == user_id, User.tenant_id == tenant_id)
)
result = await db.execute(select(User).where(User.id == user_id, User.tenant_id == tenant_id))
user = result.scalar_one_or_none()
if not user:

View File

@@ -68,7 +68,11 @@ async def get_vpn_config(
return resp
@router.post("/tenants/{tenant_id}/vpn", response_model=VpnConfigResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/tenants/{tenant_id}/vpn",
response_model=VpnConfigResponse,
status_code=status.HTTP_201_CREATED,
)
@limiter.limit("20/minute")
async def setup_vpn(
request: Request,
@@ -177,7 +181,11 @@ async def list_peers(
return responses
@router.post("/tenants/{tenant_id}/vpn/peers", response_model=VpnPeerResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/tenants/{tenant_id}/vpn/peers",
response_model=VpnPeerResponse,
status_code=status.HTTP_201_CREATED,
)
@limiter.limit("20/minute")
async def add_peer(
request: Request,
@@ -190,7 +198,9 @@ async def add_peer(
await _check_tenant_access(current_user, tenant_id, db)
_require_operator(current_user)
try:
peer = await vpn_service.add_peer(db, tenant_id, body.device_id, additional_allowed_ips=body.additional_allowed_ips)
peer = await vpn_service.add_peer(
db, tenant_id, body.device_id, additional_allowed_ips=body.additional_allowed_ips
)
except ValueError as e:
msg = str(e)
if "must not overlap" in msg:
@@ -208,7 +218,11 @@ async def add_peer(
return resp
@router.post("/tenants/{tenant_id}/vpn/peers/onboard", response_model=VpnOnboardResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/tenants/{tenant_id}/vpn/peers/onboard",
response_model=VpnOnboardResponse,
status_code=status.HTTP_201_CREATED,
)
@limiter.limit("10/minute")
async def onboard_device(
request: Request,
@@ -222,7 +236,8 @@ async def onboard_device(
_require_operator(current_user)
try:
result = await vpn_service.onboard_device(
db, tenant_id,
db,
tenant_id,
hostname=body.hostname,
username=body.username,
password=body.password,

View File

@@ -146,16 +146,16 @@ async def _delete_session_from_redis(session_id: str) -> None:
await rd.delete(f"{REDIS_PREFIX}{session_id}")
async def _open_tunnel(
device_id: uuid.UUID, tenant_id: uuid.UUID, user_id: uuid.UUID
) -> dict:
async def _open_tunnel(device_id: uuid.UUID, tenant_id: uuid.UUID, user_id: uuid.UUID) -> dict:
"""Open a TCP tunnel to device port 8291 via NATS request-reply."""
payload = json.dumps({
"device_id": str(device_id),
"tenant_id": str(tenant_id),
"user_id": str(user_id),
"target_port": 8291,
}).encode()
payload = json.dumps(
{
"device_id": str(device_id),
"tenant_id": str(tenant_id),
"user_id": str(user_id),
"target_port": 8291,
}
).encode()
try:
nc = await _get_nats()
@@ -176,9 +176,7 @@ async def _open_tunnel(
)
if "error" in data:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=data["error"]
)
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=data["error"])
return data
@@ -250,9 +248,7 @@ async def create_winbox_remote_session(
except Exception:
worker_info = None
if worker_info is None:
logger.warning(
"Cleaning stale Redis session %s (worker 404)", stale_sid
)
logger.warning("Cleaning stale Redis session %s (worker 404)", stale_sid)
tunnel_id = sess.get("tunnel_id")
if tunnel_id:
await _close_tunnel(tunnel_id)
@@ -333,12 +329,8 @@ async def create_winbox_remote_session(
username = "" # noqa: F841
password = "" # noqa: F841
expires_at = datetime.fromisoformat(
worker_resp.get("expires_at", now.isoformat())
)
max_expires_at = datetime.fromisoformat(
worker_resp.get("max_expires_at", now.isoformat())
)
expires_at = datetime.fromisoformat(worker_resp.get("expires_at", now.isoformat()))
max_expires_at = datetime.fromisoformat(worker_resp.get("max_expires_at", now.isoformat()))
# Save session to Redis
session_data = {
@@ -375,8 +367,7 @@ async def create_winbox_remote_session(
pass
ws_path = (
f"/api/tenants/{tenant_id}/devices/{device_id}"
f"/winbox-remote-sessions/{session_id}/ws"
f"/api/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/ws"
)
return RemoteWinboxSessionResponse(
@@ -425,14 +416,10 @@ async def get_winbox_remote_session(
sess = await _get_session_from_redis(str(session_id))
if sess is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Session not found"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
if sess.get("tenant_id") != str(tenant_id) or sess.get("device_id") != str(device_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Session not found"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
return RemoteWinboxStatusResponse(
session_id=uuid.UUID(sess["session_id"]),
@@ -478,10 +465,7 @@ async def list_winbox_remote_sessions(
sess = json.loads(raw)
except Exception:
continue
if (
sess.get("tenant_id") == str(tenant_id)
and sess.get("device_id") == str(device_id)
):
if sess.get("tenant_id") == str(tenant_id) and sess.get("device_id") == str(device_id):
sessions.append(
RemoteWinboxStatusResponse(
session_id=uuid.UUID(sess["session_id"]),
@@ -533,9 +517,7 @@ async def terminate_winbox_remote_session(
)
if sess.get("tenant_id") != str(tenant_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Session not found"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
# Rollback order: worker -> tunnel -> redis -> audit
await worker_terminate_session(str(session_id))
@@ -574,14 +556,12 @@ async def terminate_winbox_remote_session(
@router.get(
"/tenants/{tenant_id}/devices/{device_id}"
"/winbox-remote-sessions/{session_id}/xpra/{path:path}",
"/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/xpra/{path:path}",
summary="Proxy Xpra HTML5 client files",
dependencies=[Depends(require_operator_or_above)],
)
@router.get(
"/tenants/{tenant_id}/devices/{device_id}"
"/winbox-remote-sessions/{session_id}/xpra",
"/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/xpra",
summary="Proxy Xpra HTML5 client (root)",
dependencies=[Depends(require_operator_or_above)],
)
@@ -626,7 +606,8 @@ async def proxy_xpra_html(
content=proxy_resp.content,
status_code=proxy_resp.status_code,
headers={
k: v for k, v in proxy_resp.headers.items()
k: v
for k, v in proxy_resp.headers.items()
if k.lower() in ("content-type", "cache-control", "content-encoding")
},
)
@@ -637,9 +618,7 @@ async def proxy_xpra_html(
# ---------------------------------------------------------------------------
@router.websocket(
"/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/ws"
)
@router.websocket("/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/ws")
async def winbox_remote_ws_proxy(
websocket: WebSocket,
tenant_id: uuid.UUID,