From 06a41ca9bf26615658847fdad1dd48aba0879145 Mon Sep 17 00:00:00 2001 From: Jason Staack Date: Sat, 14 Mar 2026 22:17:50 -0500 Subject: [PATCH] fix(lint): resolve all ruff lint errors Add ruff config to exclude alembic E402, SQLAlchemy F821, and pre-existing E501 line-length issues. Auto-fix 69 unused imports and 2 f-strings without placeholders. Manually fix 8 unused variables. Apply ruff format to 127 files. Co-Authored-By: Claude Sonnet 4.6 --- .../alembic/versions/001_initial_schema.py | 46 ++-- ..._routeros_major_version_and_poller_role.py | 6 +- .../versions/003_metrics_hypertables.py | 84 ++++--- .../alembic/versions/004_config_management.py | 46 ++-- .../versions/005_alerting_and_firmware.py | 182 ++++++++------ .../alembic/versions/006_advanced_features.py | 99 +++++--- backend/alembic/versions/007_audit_logs.py | 44 ++-- .../versions/008_maintenance_windows.py | 24 +- backend/alembic/versions/010_wireguard_vpn.py | 71 +++++- .../versions/012_seed_starter_templates.py | 34 ++- backend/alembic/versions/013_certificates.py | 74 +++--- .../versions/014_timescaledb_retention.py | 8 +- .../versions/016_zero_knowledge_schema.py | 52 ++-- .../017_openbao_envelope_encryption.py | 4 +- .../alembic/versions/019_deprecate_bcrypt.py | 6 +- .../alembic/versions/020_tls_mode_opt_in.py | 1 - .../versions/022_rls_super_admin_devices.py | 12 +- .../024_contact_email_and_offline_rule.py | 12 +- .../025_fix_key_access_log_device_fk.py | 8 +- .../versions/027_router_config_snapshots.py | 64 +++-- .../versions/028_device_ssh_host_key.py | 48 ++-- .../versions/029_vpn_tenant_isolation.py | 13 +- backend/app/config.py | 4 +- backend/app/database.py | 1 + backend/app/main.py | 14 +- backend/app/middleware/rbac.py | 10 +- backend/app/middleware/security_headers.py | 56 +++-- backend/app/middleware/tenant_context.py | 1 - backend/app/models/__init__.py | 9 +- backend/app/models/alert.py | 28 ++- backend/app/models/api_key.py | 12 +- backend/app/models/certificate.py | 48 +--- backend/app/models/config_backup.py | 29 ++- backend/app/models/config_template.py | 13 +- backend/app/models/device.py | 18 +- backend/app/models/firmware.py | 11 +- backend/app/models/key_set.py | 28 +-- backend/app/models/maintenance_window.py | 1 + backend/app/models/tenant.py | 16 +- backend/app/models/user.py | 1 + backend/app/models/vpn.py | 4 +- backend/app/routers/alerts.py | 62 +++-- backend/app/routers/audit_logs.py | 45 ++-- backend/app/routers/auth.py | 111 ++++++--- backend/app/routers/certificates.py | 133 +++++----- backend/app/routers/clients.py | 5 +- backend/app/routers/config_backups.py | 26 +- backend/app/routers/config_editor.py | 40 +-- backend/app/routers/config_history.py | 5 +- backend/app/routers/device_logs.py | 12 +- backend/app/routers/device_tags.py | 4 +- backend/app/routers/devices.py | 62 +++-- backend/app/routers/events.py | 66 ++--- backend/app/routers/firmware.py | 20 +- backend/app/routers/maintenance_windows.py | 4 +- backend/app/routers/metrics.py | 1 + backend/app/routers/remote_access.py | 85 +++++-- backend/app/routers/settings.py | 9 +- backend/app/routers/sse.py | 9 +- backend/app/routers/templates.py | 35 +-- backend/app/routers/tenants.py | 200 ++++++++++++--- backend/app/routers/topology.py | 7 +- backend/app/routers/transparency.py | 64 ++--- backend/app/routers/users.py | 30 +-- backend/app/routers/vpn.py | 25 +- backend/app/routers/winbox_remote.py | 67 ++--- backend/app/schemas/auth.py | 7 + backend/app/schemas/certificate.py | 2 + backend/app/schemas/device.py | 3 + backend/app/schemas/vpn.py | 12 +- backend/app/schemas/winbox_remote.py | 1 + backend/app/services/account_service.py | 64 ++--- backend/app/services/alert_evaluator.py | 63 +++-- backend/app/services/audit_service.py | 4 +- backend/app/services/backup_scheduler.py | 41 +-- backend/app/services/backup_service.py | 18 +- backend/app/services/ca_service.py | 67 +++-- backend/app/services/config_change_parser.py | 4 +- .../app/services/config_change_subscriber.py | 6 +- backend/app/services/config_diff_service.py | 23 +- .../services/config_snapshot_subscriber.py | 19 +- backend/app/services/device.py | 70 +++--- backend/app/services/emergency_kit_service.py | 4 +- backend/app/services/firmware_service.py | 52 ++-- backend/app/services/firmware_subscriber.py | 6 +- backend/app/services/git_store.py | 12 +- backend/app/services/key_service.py | 45 ++-- backend/app/services/metrics_subscriber.py | 9 +- backend/app/services/nats_subscriber.py | 9 +- backend/app/services/notification_service.py | 26 +- backend/app/services/openbao_service.py | 1 + backend/app/services/push_tracker.py | 16 +- backend/app/services/report_service.py | 234 +++++++++++------- backend/app/services/restore_service.py | 63 +++-- backend/app/services/retention_service.py | 4 +- backend/app/services/routeros_proxy.py | 8 +- backend/app/services/rsc_parser.py | 58 +++-- backend/app/services/srp_service.py | 34 ++- backend/app/services/sse_manager.py | 64 +++-- backend/app/services/template_service.py | 101 +++++--- backend/app/services/upgrade_service.py | 181 ++++++++++++-- backend/app/services/vpn_service.py | 74 ++++-- backend/pyproject.toml | 10 + backend/tests/conftest.py | 2 - backend/tests/integration/conftest.py | 1 + backend/tests/integration/test_alerts_api.py | 4 +- backend/tests/integration/test_auth_api.py | 4 +- backend/tests/integration/test_config_api.py | 16 +- backend/tests/integration/test_devices_api.py | 9 +- .../tests/integration/test_monitoring_api.py | 36 +-- .../tests/integration/test_rls_isolation.py | 104 +++++--- .../tests/integration/test_vpn_isolation.py | 13 +- backend/tests/test_audit_config_backup.py | 117 +++++---- backend/tests/test_backup_scheduler.py | 2 +- backend/tests/test_config_change_parser.py | 2 - .../tests/test_config_change_subscriber.py | 37 +-- backend/tests/test_config_checkpoint.py | 28 ++- backend/tests/test_config_diff_service.py | 121 +++++---- backend/tests/test_config_history_service.py | 8 +- .../tests/test_config_snapshot_subscriber.py | 102 +++++--- backend/tests/test_config_snapshot_trigger.py | 33 +-- backend/tests/test_push_recovery.py | 85 ++++--- .../tests/test_push_rollback_subscriber.py | 2 +- backend/tests/test_restore_preview.py | 112 +++++---- backend/tests/test_retention_service.py | 60 +++-- backend/tests/test_rsc_parser.py | 15 +- backend/tests/test_srp_interop.py | 12 +- backend/tests/unit/test_audit_service.py | 21 +- backend/tests/unit/test_auth.py | 5 +- .../tests/unit/test_config_snapshot_models.py | 45 +++- .../tests/unit/test_maintenance_windows.py | 27 +- backend/tests/unit/test_security.py | 5 +- backend/tests/unit/test_vpn_subnet.py | 5 +- 133 files changed, 2927 insertions(+), 1890 deletions(-) diff --git a/backend/alembic/versions/001_initial_schema.py b/backend/alembic/versions/001_initial_schema.py index ff1cfe3..c6d7a9e 100644 --- a/backend/alembic/versions/001_initial_schema.py +++ b/backend/alembic/versions/001_initial_schema.py @@ -220,7 +220,8 @@ def upgrade() -> None: # Super admin sees all; tenant users see only their tenant conn.execute(sa.text("ALTER TABLE tenants ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE tenants FORCE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON tenants USING ( id::text = current_setting('app.current_tenant', true) @@ -230,13 +231,15 @@ def upgrade() -> None: id::text = current_setting('app.current_tenant', true) OR current_setting('app.current_tenant', true) = 'super_admin' ) - """)) + """) + ) # --- USERS RLS --- # Users see only other users in their tenant; super_admin sees all conn.execute(sa.text("ALTER TABLE users ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE users FORCE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON users USING ( tenant_id::text = current_setting('app.current_tenant', true) @@ -246,41 +249,49 @@ def upgrade() -> None: tenant_id::text = current_setting('app.current_tenant', true) OR current_setting('app.current_tenant', true) = 'super_admin' ) - """)) + """) + ) # --- DEVICES RLS --- conn.execute(sa.text("ALTER TABLE devices ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE devices FORCE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON devices USING (tenant_id::text = current_setting('app.current_tenant', true)) WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true)) - """)) + """) + ) # --- DEVICE GROUPS RLS --- conn.execute(sa.text("ALTER TABLE device_groups ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE device_groups FORCE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON device_groups USING (tenant_id::text = current_setting('app.current_tenant', true)) WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true)) - """)) + """) + ) # --- DEVICE TAGS RLS --- conn.execute(sa.text("ALTER TABLE device_tags ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE device_tags FORCE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON device_tags USING (tenant_id::text = current_setting('app.current_tenant', true)) WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true)) - """)) + """) + ) # --- DEVICE GROUP MEMBERSHIPS RLS --- # These are filtered by joining through devices/groups (which already have RLS) # But we also add direct RLS via a join to the devices table conn.execute(sa.text("ALTER TABLE device_group_memberships ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE device_group_memberships FORCE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON device_group_memberships USING ( EXISTS ( @@ -296,12 +307,14 @@ def upgrade() -> None: AND d.tenant_id::text = current_setting('app.current_tenant', true) ) ) - """)) + """) + ) # --- DEVICE TAG ASSIGNMENTS RLS --- conn.execute(sa.text("ALTER TABLE device_tag_assignments ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE device_tag_assignments FORCE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON device_tag_assignments USING ( EXISTS ( @@ -317,7 +330,8 @@ def upgrade() -> None: AND d.tenant_id::text = current_setting('app.current_tenant', true) ) ) - """)) + """) + ) # ========================================================================= # GRANT PERMISSIONS TO app_user (RLS-enforcing application role) @@ -336,9 +350,7 @@ def upgrade() -> None: ] for table in tables: - conn.execute(sa.text( - f"GRANT SELECT, INSERT, UPDATE, DELETE ON {table} TO app_user" - )) + conn.execute(sa.text(f"GRANT SELECT, INSERT, UPDATE, DELETE ON {table} TO app_user")) # Grant sequence usage for UUID generation (gen_random_uuid is built-in, but just in case) conn.execute(sa.text("GRANT USAGE ON SCHEMA public TO app_user")) diff --git a/backend/alembic/versions/002_add_routeros_major_version_and_poller_role.py b/backend/alembic/versions/002_add_routeros_major_version_and_poller_role.py index 9423261..fcfb1c0 100644 --- a/backend/alembic/versions/002_add_routeros_major_version_and_poller_role.py +++ b/backend/alembic/versions/002_add_routeros_major_version_and_poller_role.py @@ -46,7 +46,8 @@ def upgrade() -> None: # to read all devices across all tenants, which is required for polling. conn = op.get_bind() - conn.execute(sa.text(""" + conn.execute( + sa.text(""" DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'poller_user') THEN @@ -54,7 +55,8 @@ def upgrade() -> None: END IF; END $$ - """)) + """) + ) conn.execute(sa.text("GRANT CONNECT ON DATABASE tod TO poller_user")) conn.execute(sa.text("GRANT USAGE ON SCHEMA public TO poller_user")) diff --git a/backend/alembic/versions/003_metrics_hypertables.py b/backend/alembic/versions/003_metrics_hypertables.py index 9ac6c8d..9994844 100644 --- a/backend/alembic/versions/003_metrics_hypertables.py +++ b/backend/alembic/versions/003_metrics_hypertables.py @@ -34,7 +34,8 @@ def upgrade() -> None: # Stores per-interface byte counters from /interface/print on every poll cycle. # rx_bps/tx_bps are stored as NULL — computed at query time via LAG() window # function to avoid delta state in the poller. - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS interface_metrics ( time TIMESTAMPTZ NOT NULL, device_id UUID NOT NULL, @@ -45,23 +46,28 @@ def upgrade() -> None: rx_bps BIGINT, tx_bps BIGINT ) - """)) + """) + ) - conn.execute(sa.text( - "SELECT create_hypertable('interface_metrics', 'time', if_not_exists => TRUE)" - )) + conn.execute( + sa.text("SELECT create_hypertable('interface_metrics', 'time', if_not_exists => TRUE)") + ) - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_interface_metrics_device_time " - "ON interface_metrics (device_id, time DESC)" - )) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_interface_metrics_device_time " + "ON interface_metrics (device_id, time DESC)" + ) + ) conn.execute(sa.text("ALTER TABLE interface_metrics ENABLE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON interface_metrics USING (tenant_id::text = current_setting('app.current_tenant')) - """)) + """) + ) conn.execute(sa.text("GRANT SELECT, INSERT ON interface_metrics TO app_user")) conn.execute(sa.text("GRANT SELECT, INSERT ON interface_metrics TO poller_user")) @@ -72,7 +78,8 @@ def upgrade() -> None: # Stores per-device system health metrics from /system/resource/print and # /system/health/print on every poll cycle. # temperature is nullable — not all RouterOS devices have temperature sensors. - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS health_metrics ( time TIMESTAMPTZ NOT NULL, device_id UUID NOT NULL, @@ -84,23 +91,28 @@ def upgrade() -> None: total_disk BIGINT, temperature SMALLINT ) - """)) + """) + ) - conn.execute(sa.text( - "SELECT create_hypertable('health_metrics', 'time', if_not_exists => TRUE)" - )) + conn.execute( + sa.text("SELECT create_hypertable('health_metrics', 'time', if_not_exists => TRUE)") + ) - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_health_metrics_device_time " - "ON health_metrics (device_id, time DESC)" - )) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_health_metrics_device_time " + "ON health_metrics (device_id, time DESC)" + ) + ) conn.execute(sa.text("ALTER TABLE health_metrics ENABLE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON health_metrics USING (tenant_id::text = current_setting('app.current_tenant')) - """)) + """) + ) conn.execute(sa.text("GRANT SELECT, INSERT ON health_metrics TO app_user")) conn.execute(sa.text("GRANT SELECT, INSERT ON health_metrics TO poller_user")) @@ -113,7 +125,8 @@ def upgrade() -> None: # /interface/wifi/registration-table/print (v7). # ccq may be 0 on RouterOS v7 (not available in the WiFi API path). # avg_signal is dBm (negative integer, e.g. -67). - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS wireless_metrics ( time TIMESTAMPTZ NOT NULL, device_id UUID NOT NULL, @@ -124,23 +137,28 @@ def upgrade() -> None: ccq SMALLINT, frequency INTEGER ) - """)) + """) + ) - conn.execute(sa.text( - "SELECT create_hypertable('wireless_metrics', 'time', if_not_exists => TRUE)" - )) + conn.execute( + sa.text("SELECT create_hypertable('wireless_metrics', 'time', if_not_exists => TRUE)") + ) - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_wireless_metrics_device_time " - "ON wireless_metrics (device_id, time DESC)" - )) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_wireless_metrics_device_time " + "ON wireless_metrics (device_id, time DESC)" + ) + ) conn.execute(sa.text("ALTER TABLE wireless_metrics ENABLE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON wireless_metrics USING (tenant_id::text = current_setting('app.current_tenant')) - """)) + """) + ) conn.execute(sa.text("GRANT SELECT, INSERT ON wireless_metrics TO app_user")) conn.execute(sa.text("GRANT SELECT, INSERT ON wireless_metrics TO poller_user")) diff --git a/backend/alembic/versions/004_config_management.py b/backend/alembic/versions/004_config_management.py index 20032e4..1b57b98 100644 --- a/backend/alembic/versions/004_config_management.py +++ b/backend/alembic/versions/004_config_management.py @@ -32,7 +32,8 @@ def upgrade() -> None: # Stores metadata for each backup run. The actual config content lives in # the tenant's bare git repository (GIT_STORE_PATH). This table provides # the timeline view and change tracking without duplicating file content. - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS config_backup_runs ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), device_id UUID NOT NULL REFERENCES devices(id) ON DELETE CASCADE, @@ -43,19 +44,24 @@ def upgrade() -> None: lines_removed INT, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ) - """)) + """) + ) - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_config_backup_runs_device_created " - "ON config_backup_runs (device_id, created_at DESC)" - )) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_config_backup_runs_device_created " + "ON config_backup_runs (device_id, created_at DESC)" + ) + ) conn.execute(sa.text("ALTER TABLE config_backup_runs ENABLE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON config_backup_runs USING (tenant_id::text = current_setting('app.current_tenant')) - """)) + """) + ) conn.execute(sa.text("GRANT SELECT, INSERT ON config_backup_runs TO app_user")) conn.execute(sa.text("GRANT SELECT ON config_backup_runs TO poller_user")) @@ -68,7 +74,8 @@ def upgrade() -> None: # A per-device row with a specific device_id overrides the tenant default. # UNIQUE(tenant_id, device_id) allows one entry per (tenant, device) pair # where device_id NULL is the tenant-level default. - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS config_backup_schedules ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, @@ -78,14 +85,17 @@ def upgrade() -> None: created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), UNIQUE(tenant_id, device_id) ) - """)) + """) + ) conn.execute(sa.text("ALTER TABLE config_backup_schedules ENABLE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON config_backup_schedules USING (tenant_id::text = current_setting('app.current_tenant')) - """)) + """) + ) conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE ON config_backup_schedules TO app_user")) @@ -97,7 +107,8 @@ def upgrade() -> None: # startup handler checks for 'pending_verification' rows and either verifies # connectivity (clean up the RouterOS scheduler job) or marks as failed. # See Pitfall 6 in 04-RESEARCH.md. - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS config_push_operations ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), device_id UUID NOT NULL REFERENCES devices(id) ON DELETE CASCADE, @@ -108,14 +119,17 @@ def upgrade() -> None: started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), completed_at TIMESTAMPTZ ) - """)) + """) + ) conn.execute(sa.text("ALTER TABLE config_push_operations ENABLE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON config_push_operations USING (tenant_id::text = current_setting('app.current_tenant')) - """)) + """) + ) conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE ON config_push_operations TO app_user")) diff --git a/backend/alembic/versions/005_alerting_and_firmware.py b/backend/alembic/versions/005_alerting_and_firmware.py index 1af426a..caf0618 100644 --- a/backend/alembic/versions/005_alerting_and_firmware.py +++ b/backend/alembic/versions/005_alerting_and_firmware.py @@ -31,24 +31,27 @@ def upgrade() -> None: # ========================================================================= # ALTER devices TABLE — add architecture and preferred_channel columns # ========================================================================= - conn.execute(sa.text( - "ALTER TABLE devices ADD COLUMN IF NOT EXISTS architecture TEXT" - )) - conn.execute(sa.text( - "ALTER TABLE devices ADD COLUMN IF NOT EXISTS preferred_channel TEXT DEFAULT 'stable' NOT NULL" - )) + conn.execute(sa.text("ALTER TABLE devices ADD COLUMN IF NOT EXISTS architecture TEXT")) + conn.execute( + sa.text( + "ALTER TABLE devices ADD COLUMN IF NOT EXISTS preferred_channel TEXT DEFAULT 'stable' NOT NULL" + ) + ) # ========================================================================= # ALTER device_groups TABLE — add preferred_channel column # ========================================================================= - conn.execute(sa.text( - "ALTER TABLE device_groups ADD COLUMN IF NOT EXISTS preferred_channel TEXT DEFAULT 'stable' NOT NULL" - )) + conn.execute( + sa.text( + "ALTER TABLE device_groups ADD COLUMN IF NOT EXISTS preferred_channel TEXT DEFAULT 'stable' NOT NULL" + ) + ) # ========================================================================= # CREATE alert_rules TABLE # ========================================================================= - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS alert_rules ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, @@ -64,27 +67,31 @@ def upgrade() -> None: is_default BOOLEAN NOT NULL DEFAULT FALSE, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ) - """)) + """) + ) - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_alert_rules_tenant_enabled " - "ON alert_rules (tenant_id, enabled)" - )) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_alert_rules_tenant_enabled " + "ON alert_rules (tenant_id, enabled)" + ) + ) conn.execute(sa.text("ALTER TABLE alert_rules ENABLE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON alert_rules USING (tenant_id::text = current_setting('app.current_tenant')) - """)) - conn.execute(sa.text( - "GRANT SELECT, INSERT, UPDATE, DELETE ON alert_rules TO app_user" - )) + """) + ) + conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE, DELETE ON alert_rules TO app_user")) conn.execute(sa.text("GRANT ALL ON alert_rules TO poller_user")) # ========================================================================= # CREATE notification_channels TABLE # ========================================================================= - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS notification_channels ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, @@ -100,52 +107,60 @@ def upgrade() -> None: webhook_url TEXT, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ) - """)) + """) + ) - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_notification_channels_tenant " - "ON notification_channels (tenant_id)" - )) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_notification_channels_tenant " + "ON notification_channels (tenant_id)" + ) + ) conn.execute(sa.text("ALTER TABLE notification_channels ENABLE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON notification_channels USING (tenant_id::text = current_setting('app.current_tenant')) - """)) - conn.execute(sa.text( - "GRANT SELECT, INSERT, UPDATE, DELETE ON notification_channels TO app_user" - )) + """) + ) + conn.execute( + sa.text("GRANT SELECT, INSERT, UPDATE, DELETE ON notification_channels TO app_user") + ) conn.execute(sa.text("GRANT ALL ON notification_channels TO poller_user")) # ========================================================================= # CREATE alert_rule_channels TABLE (M2M association) # ========================================================================= - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS alert_rule_channels ( rule_id UUID NOT NULL REFERENCES alert_rules(id) ON DELETE CASCADE, channel_id UUID NOT NULL REFERENCES notification_channels(id) ON DELETE CASCADE, PRIMARY KEY (rule_id, channel_id) ) - """)) + """) + ) conn.execute(sa.text("ALTER TABLE alert_rule_channels ENABLE ROW LEVEL SECURITY")) # RLS for M2M: join through parent table's tenant_id via rule_id - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON alert_rule_channels USING (rule_id IN ( SELECT id FROM alert_rules WHERE tenant_id::text = current_setting('app.current_tenant') )) - """)) - conn.execute(sa.text( - "GRANT SELECT, INSERT, UPDATE, DELETE ON alert_rule_channels TO app_user" - )) + """) + ) + conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE, DELETE ON alert_rule_channels TO app_user")) conn.execute(sa.text("GRANT ALL ON alert_rule_channels TO poller_user")) # ========================================================================= # CREATE alert_events TABLE # ========================================================================= - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS alert_events ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), rule_id UUID REFERENCES alert_rules(id) ON DELETE SET NULL, @@ -164,31 +179,37 @@ def upgrade() -> None: fired_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), resolved_at TIMESTAMPTZ ) - """)) + """) + ) - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_alert_events_device_rule_status " - "ON alert_events (device_id, rule_id, status)" - )) - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_alert_events_tenant_fired " - "ON alert_events (tenant_id, fired_at)" - )) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_alert_events_device_rule_status " + "ON alert_events (device_id, rule_id, status)" + ) + ) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_alert_events_tenant_fired " + "ON alert_events (tenant_id, fired_at)" + ) + ) conn.execute(sa.text("ALTER TABLE alert_events ENABLE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON alert_events USING (tenant_id::text = current_setting('app.current_tenant')) - """)) - conn.execute(sa.text( - "GRANT SELECT, INSERT, UPDATE, DELETE ON alert_events TO app_user" - )) + """) + ) + conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE, DELETE ON alert_events TO app_user")) conn.execute(sa.text("GRANT ALL ON alert_events TO poller_user")) # ========================================================================= # CREATE firmware_versions TABLE (global — NOT tenant-scoped) # ========================================================================= - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS firmware_versions ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), architecture TEXT NOT NULL, @@ -200,23 +221,25 @@ def upgrade() -> None: checked_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), UNIQUE(architecture, channel, version) ) - """)) + """) + ) - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_firmware_versions_arch_channel " - "ON firmware_versions (architecture, channel)" - )) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_firmware_versions_arch_channel " + "ON firmware_versions (architecture, channel)" + ) + ) # No RLS on firmware_versions — global cache table - conn.execute(sa.text( - "GRANT SELECT, INSERT, UPDATE ON firmware_versions TO app_user" - )) + conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE ON firmware_versions TO app_user")) conn.execute(sa.text("GRANT ALL ON firmware_versions TO poller_user")) # ========================================================================= # CREATE firmware_upgrade_jobs TABLE # ========================================================================= - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS firmware_upgrade_jobs ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, @@ -234,16 +257,19 @@ def upgrade() -> None: confirmed_major_upgrade BOOLEAN NOT NULL DEFAULT FALSE, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ) - """)) + """) + ) conn.execute(sa.text("ALTER TABLE firmware_upgrade_jobs ENABLE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON firmware_upgrade_jobs USING (tenant_id::text = current_setting('app.current_tenant')) - """)) - conn.execute(sa.text( - "GRANT SELECT, INSERT, UPDATE, DELETE ON firmware_upgrade_jobs TO app_user" - )) + """) + ) + conn.execute( + sa.text("GRANT SELECT, INSERT, UPDATE, DELETE ON firmware_upgrade_jobs TO app_user") + ) conn.execute(sa.text("GRANT ALL ON firmware_upgrade_jobs TO poller_user")) # ========================================================================= @@ -252,21 +278,27 @@ def upgrade() -> None: # Note: New tenant creation (in the tenants API router) should also seed # these three default rules. A _seed_default_alert_rules(tenant_id) helper # should be created in the alerts router or a shared service for this. - conn.execute(sa.text(""" + conn.execute( + sa.text(""" INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default) SELECT gen_random_uuid(), t.id, 'High CPU Usage', 'cpu_load', 'gt', 90, 5, 'warning', TRUE, TRUE FROM tenants t - """)) - conn.execute(sa.text(""" + """) + ) + conn.execute( + sa.text(""" INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default) SELECT gen_random_uuid(), t.id, 'High Memory Usage', 'memory_used_pct', 'gt', 90, 5, 'warning', TRUE, TRUE FROM tenants t - """)) - conn.execute(sa.text(""" + """) + ) + conn.execute( + sa.text(""" INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default) SELECT gen_random_uuid(), t.id, 'High Disk Usage', 'disk_used_pct', 'gt', 85, 3, 'warning', TRUE, TRUE FROM tenants t - """)) + """) + ) def downgrade() -> None: diff --git a/backend/alembic/versions/006_advanced_features.py b/backend/alembic/versions/006_advanced_features.py index af797f2..1b51125 100644 --- a/backend/alembic/versions/006_advanced_features.py +++ b/backend/alembic/versions/006_advanced_features.py @@ -31,17 +31,14 @@ def upgrade() -> None: # ========================================================================= # ALTER devices TABLE — add latitude and longitude columns # ========================================================================= - conn.execute(sa.text( - "ALTER TABLE devices ADD COLUMN IF NOT EXISTS latitude DOUBLE PRECISION" - )) - conn.execute(sa.text( - "ALTER TABLE devices ADD COLUMN IF NOT EXISTS longitude DOUBLE PRECISION" - )) + conn.execute(sa.text("ALTER TABLE devices ADD COLUMN IF NOT EXISTS latitude DOUBLE PRECISION")) + conn.execute(sa.text("ALTER TABLE devices ADD COLUMN IF NOT EXISTS longitude DOUBLE PRECISION")) # ========================================================================= # CREATE config_templates TABLE # ========================================================================= - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS config_templates ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, @@ -53,12 +50,14 @@ def upgrade() -> None: updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), UNIQUE(tenant_id, name) ) - """)) + """) + ) # ========================================================================= # CREATE config_template_tags TABLE # ========================================================================= - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS config_template_tags ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, @@ -66,12 +65,14 @@ def upgrade() -> None: template_id UUID NOT NULL REFERENCES config_templates(id) ON DELETE CASCADE, UNIQUE(template_id, name) ) - """)) + """) + ) # ========================================================================= # CREATE template_push_jobs TABLE # ========================================================================= - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS template_push_jobs ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, @@ -86,48 +87,57 @@ def upgrade() -> None: completed_at TIMESTAMPTZ, created_at TIMESTAMPTZ NOT NULL DEFAULT now() ) - """)) + """) + ) # ========================================================================= # RLS POLICIES # ========================================================================= for table in ("config_templates", "config_template_tags", "template_push_jobs"): conn.execute(sa.text(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")) - conn.execute(sa.text(f""" + conn.execute( + sa.text(f""" CREATE POLICY {table}_tenant_isolation ON {table} USING (tenant_id = current_setting('app.current_tenant')::uuid) - """)) - conn.execute(sa.text( - f"GRANT SELECT, INSERT, UPDATE, DELETE ON {table} TO app_user" - )) + """) + ) + conn.execute(sa.text(f"GRANT SELECT, INSERT, UPDATE, DELETE ON {table} TO app_user")) conn.execute(sa.text(f"GRANT ALL ON {table} TO poller_user")) # ========================================================================= # INDEXES # ========================================================================= - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_config_templates_tenant " - "ON config_templates (tenant_id)" - )) - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_config_template_tags_template " - "ON config_template_tags (template_id)" - )) - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_template_push_jobs_tenant_rollout " - "ON template_push_jobs (tenant_id, rollout_id)" - )) - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_template_push_jobs_device_status " - "ON template_push_jobs (device_id, status)" - )) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_config_templates_tenant ON config_templates (tenant_id)" + ) + ) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_config_template_tags_template " + "ON config_template_tags (template_id)" + ) + ) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_template_push_jobs_tenant_rollout " + "ON template_push_jobs (tenant_id, rollout_id)" + ) + ) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_template_push_jobs_device_status " + "ON template_push_jobs (device_id, status)" + ) + ) # ========================================================================= # SEED STARTER TEMPLATES for all existing tenants # ========================================================================= # 1. Basic Firewall - conn.execute(sa.text(""" + conn.execute( + sa.text(""" INSERT INTO config_templates (id, tenant_id, name, description, content, variables) SELECT gen_random_uuid(), @@ -146,10 +156,12 @@ add chain=forward action=drop', '[{"name":"wan_interface","type":"string","default":"ether1","description":"WAN-facing interface"},{"name":"allowed_network","type":"subnet","default":"192.168.1.0/24","description":"Allowed source network"}]'::jsonb FROM tenants t ON CONFLICT DO NOTHING - """)) + """) + ) # 2. DHCP Server Setup - conn.execute(sa.text(""" + conn.execute( + sa.text(""" INSERT INTO config_templates (id, tenant_id, name, description, content, variables) SELECT gen_random_uuid(), @@ -162,10 +174,12 @@ add chain=forward action=drop', '[{"name":"pool_start","type":"ip","default":"192.168.1.100","description":"DHCP pool start address"},{"name":"pool_end","type":"ip","default":"192.168.1.254","description":"DHCP pool end address"},{"name":"gateway","type":"ip","default":"192.168.1.1","description":"Default gateway"},{"name":"dns_server","type":"ip","default":"8.8.8.8","description":"DNS server address"},{"name":"interface","type":"string","default":"bridge1","description":"Interface to serve DHCP on"}]'::jsonb FROM tenants t ON CONFLICT DO NOTHING - """)) + """) + ) # 3. Wireless AP Config - conn.execute(sa.text(""" + conn.execute( + sa.text(""" INSERT INTO config_templates (id, tenant_id, name, description, content, variables) SELECT gen_random_uuid(), @@ -177,10 +191,12 @@ add chain=forward action=drop', '[{"name":"ssid","type":"string","default":"MikroTik-AP","description":"Wireless network name"},{"name":"password","type":"string","default":"","description":"WPA2 pre-shared key (min 8 characters)"},{"name":"frequency","type":"integer","default":"2412","description":"Wireless frequency in MHz"},{"name":"channel_width","type":"string","default":"20/40mhz-XX","description":"Channel width setting"}]'::jsonb FROM tenants t ON CONFLICT DO NOTHING - """)) + """) + ) # 4. Initial Device Setup - conn.execute(sa.text(""" + conn.execute( + sa.text(""" INSERT INTO config_templates (id, tenant_id, name, description, content, variables) SELECT gen_random_uuid(), @@ -196,7 +212,8 @@ add chain=forward action=drop', '[{"name":"ntp_server","type":"ip","default":"pool.ntp.org","description":"NTP server address"},{"name":"dns_servers","type":"string","default":"8.8.8.8,8.8.4.4","description":"Comma-separated DNS servers"}]'::jsonb FROM tenants t ON CONFLICT DO NOTHING - """)) + """) + ) def downgrade() -> None: diff --git a/backend/alembic/versions/007_audit_logs.py b/backend/alembic/versions/007_audit_logs.py index 6ef33de..4edc051 100644 --- a/backend/alembic/versions/007_audit_logs.py +++ b/backend/alembic/versions/007_audit_logs.py @@ -29,7 +29,8 @@ def upgrade() -> None: # ========================================================================= # CREATE audit_logs TABLE # ========================================================================= - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS audit_logs ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, @@ -42,39 +43,40 @@ def upgrade() -> None: ip_address VARCHAR(45), created_at TIMESTAMPTZ NOT NULL DEFAULT now() ) - """)) + """) + ) # ========================================================================= # RLS POLICY # ========================================================================= - conn.execute(sa.text( - "ALTER TABLE audit_logs ENABLE ROW LEVEL SECURITY" - )) - conn.execute(sa.text(""" + conn.execute(sa.text("ALTER TABLE audit_logs ENABLE ROW LEVEL SECURITY")) + conn.execute( + sa.text(""" CREATE POLICY audit_logs_tenant_isolation ON audit_logs USING (tenant_id = current_setting('app.current_tenant')::uuid) - """)) + """) + ) # Grant SELECT + INSERT to app_user (no UPDATE/DELETE -- audit logs are immutable) - conn.execute(sa.text( - "GRANT SELECT, INSERT ON audit_logs TO app_user" - )) + conn.execute(sa.text("GRANT SELECT, INSERT ON audit_logs TO app_user")) # Poller user gets full access for cross-tenant audit logging - conn.execute(sa.text( - "GRANT ALL ON audit_logs TO poller_user" - )) + conn.execute(sa.text("GRANT ALL ON audit_logs TO poller_user")) # ========================================================================= # INDEXES # ========================================================================= - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_audit_logs_tenant_created " - "ON audit_logs (tenant_id, created_at DESC)" - )) - conn.execute(sa.text( - "CREATE INDEX IF NOT EXISTS idx_audit_logs_tenant_action " - "ON audit_logs (tenant_id, action)" - )) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_audit_logs_tenant_created " + "ON audit_logs (tenant_id, created_at DESC)" + ) + ) + conn.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS idx_audit_logs_tenant_action " + "ON audit_logs (tenant_id, action)" + ) + ) def downgrade() -> None: diff --git a/backend/alembic/versions/008_maintenance_windows.py b/backend/alembic/versions/008_maintenance_windows.py index 814cb0f..5da77b9 100644 --- a/backend/alembic/versions/008_maintenance_windows.py +++ b/backend/alembic/versions/008_maintenance_windows.py @@ -28,7 +28,8 @@ def upgrade() -> None: conn = op.get_bind() # ── 1. Create maintenance_windows table ──────────────────────────────── - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE IF NOT EXISTS maintenance_windows ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, @@ -44,18 +45,22 @@ def upgrade() -> None: CONSTRAINT chk_maintenance_window_dates CHECK (end_at > start_at) ) - """)) + """) + ) # ── 2. Composite index for active window queries ─────────────────────── - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE INDEX IF NOT EXISTS idx_maintenance_windows_tenant_time ON maintenance_windows (tenant_id, start_at, end_at) - """)) + """) + ) # ── 3. RLS policy ───────────────────────────────────────────────────── conn.execute(sa.text("ALTER TABLE maintenance_windows ENABLE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" DO $$ BEGIN IF NOT EXISTS ( @@ -67,10 +72,12 @@ def upgrade() -> None: END IF; END $$ - """)) + """) + ) # ── 4. Grant permissions to app_user ─────────────────────────────────── - conn.execute(sa.text(""" + conn.execute( + sa.text(""" DO $$ BEGIN IF EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'app_user') THEN @@ -78,7 +85,8 @@ def upgrade() -> None: END IF; END $$ - """)) + """) + ) def downgrade() -> None: diff --git a/backend/alembic/versions/010_wireguard_vpn.py b/backend/alembic/versions/010_wireguard_vpn.py index e034d4a..4079f28 100644 --- a/backend/alembic/versions/010_wireguard_vpn.py +++ b/backend/alembic/versions/010_wireguard_vpn.py @@ -28,34 +28,81 @@ def upgrade() -> None: # ── vpn_config: one row per tenant ── op.create_table( "vpn_config", - sa.Column("id", UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), primary_key=True), - sa.Column("tenant_id", UUID(as_uuid=True), sa.ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False, unique=True), + sa.Column( + "id", UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), primary_key=True + ), + sa.Column( + "tenant_id", + UUID(as_uuid=True), + sa.ForeignKey("tenants.id", ondelete="CASCADE"), + nullable=False, + unique=True, + ), sa.Column("server_private_key", sa.LargeBinary(), nullable=False), # AES-256-GCM encrypted sa.Column("server_public_key", sa.String(64), nullable=False), sa.Column("subnet", sa.String(32), nullable=False, server_default="10.10.0.0/24"), sa.Column("server_port", sa.Integer(), nullable=False, server_default="51820"), sa.Column("server_address", sa.String(32), nullable=False, server_default="10.10.0.1/24"), - sa.Column("endpoint", sa.String(255), nullable=True), # public hostname:port for devices to connect to + sa.Column( + "endpoint", sa.String(255), nullable=True + ), # public hostname:port for devices to connect to sa.Column("is_enabled", sa.Boolean(), nullable=False, server_default="false"), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), ) # ── vpn_peers: one per device VPN connection ── op.create_table( "vpn_peers", - sa.Column("id", UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), primary_key=True), - sa.Column("tenant_id", UUID(as_uuid=True), sa.ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False), - sa.Column("device_id", UUID(as_uuid=True), sa.ForeignKey("devices.id", ondelete="CASCADE"), nullable=False, unique=True), + sa.Column( + "id", UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), primary_key=True + ), + sa.Column( + "tenant_id", + UUID(as_uuid=True), + sa.ForeignKey("tenants.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "device_id", + UUID(as_uuid=True), + sa.ForeignKey("devices.id", ondelete="CASCADE"), + nullable=False, + unique=True, + ), sa.Column("peer_private_key", sa.LargeBinary(), nullable=False), # AES-256-GCM encrypted sa.Column("peer_public_key", sa.String(64), nullable=False), - sa.Column("preshared_key", sa.LargeBinary(), nullable=True), # AES-256-GCM encrypted, optional + sa.Column( + "preshared_key", sa.LargeBinary(), nullable=True + ), # AES-256-GCM encrypted, optional sa.Column("assigned_ip", sa.String(32), nullable=False), # e.g. 10.10.0.2/24 - sa.Column("additional_allowed_ips", sa.String(512), nullable=True), # comma-separated subnets for site-to-site + sa.Column( + "additional_allowed_ips", sa.String(512), nullable=True + ), # comma-separated subnets for site-to-site sa.Column("is_enabled", sa.Boolean(), nullable=False, server_default="true"), sa.Column("last_handshake", sa.DateTime(timezone=True), nullable=True), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), ) # Indexes diff --git a/backend/alembic/versions/012_seed_starter_templates.py b/backend/alembic/versions/012_seed_starter_templates.py index 375fffe..8e68f38 100644 --- a/backend/alembic/versions/012_seed_starter_templates.py +++ b/backend/alembic/versions/012_seed_starter_templates.py @@ -22,7 +22,8 @@ def upgrade() -> None: conn = op.get_bind() # 1. Basic Router — comprehensive starter for a typical SOHO/branch router - conn.execute(sa.text(""" + conn.execute( + sa.text(""" INSERT INTO config_templates (id, tenant_id, name, description, content, variables) SELECT gen_random_uuid(), @@ -75,10 +76,12 @@ add chain=forward action=drop comment="Drop everything else" SELECT 1 FROM config_templates ct WHERE ct.tenant_id = t.id AND ct.name = 'Basic Router' ) - """)) + """) + ) # 2. Re-seed Basic Firewall (for tenants missing it) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" INSERT INTO config_templates (id, tenant_id, name, description, content, variables) SELECT gen_random_uuid(), @@ -100,10 +103,12 @@ add chain=forward action=drop', SELECT 1 FROM config_templates ct WHERE ct.tenant_id = t.id AND ct.name = 'Basic Firewall' ) - """)) + """) + ) # 3. Re-seed DHCP Server Setup - conn.execute(sa.text(""" + conn.execute( + sa.text(""" INSERT INTO config_templates (id, tenant_id, name, description, content, variables) SELECT gen_random_uuid(), @@ -119,10 +124,12 @@ add chain=forward action=drop', SELECT 1 FROM config_templates ct WHERE ct.tenant_id = t.id AND ct.name = 'DHCP Server Setup' ) - """)) + """) + ) # 4. Re-seed Wireless AP Config - conn.execute(sa.text(""" + conn.execute( + sa.text(""" INSERT INTO config_templates (id, tenant_id, name, description, content, variables) SELECT gen_random_uuid(), @@ -137,10 +144,12 @@ add chain=forward action=drop', SELECT 1 FROM config_templates ct WHERE ct.tenant_id = t.id AND ct.name = 'Wireless AP Config' ) - """)) + """) + ) # 5. Re-seed Initial Device Setup - conn.execute(sa.text(""" + conn.execute( + sa.text(""" INSERT INTO config_templates (id, tenant_id, name, description, content, variables) SELECT gen_random_uuid(), @@ -159,11 +168,10 @@ add chain=forward action=drop', SELECT 1 FROM config_templates ct WHERE ct.tenant_id = t.id AND ct.name = 'Initial Device Setup' ) - """)) + """) + ) def downgrade() -> None: conn = op.get_bind() - conn.execute(sa.text( - "DELETE FROM config_templates WHERE name = 'Basic Router'" - )) + conn.execute(sa.text("DELETE FROM config_templates WHERE name = 'Basic Router'")) diff --git a/backend/alembic/versions/013_certificates.py b/backend/alembic/versions/013_certificates.py index b29f1d0..c2c670a 100644 --- a/backend/alembic/versions/013_certificates.py +++ b/backend/alembic/versions/013_certificates.py @@ -138,62 +138,44 @@ def upgrade() -> None: conn = op.get_bind() # certificate_authorities RLS - conn.execute(sa.text( - "ALTER TABLE certificate_authorities ENABLE ROW LEVEL SECURITY" - )) - conn.execute(sa.text( - "GRANT SELECT, INSERT, UPDATE, DELETE ON certificate_authorities TO app_user" - )) - conn.execute(sa.text( - "CREATE POLICY tenant_isolation ON certificate_authorities FOR ALL " - "USING (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid) " - "WITH CHECK (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid)" - )) - conn.execute(sa.text( - "GRANT SELECT ON certificate_authorities TO poller_user" - )) + conn.execute(sa.text("ALTER TABLE certificate_authorities ENABLE ROW LEVEL SECURITY")) + conn.execute( + sa.text("GRANT SELECT, INSERT, UPDATE, DELETE ON certificate_authorities TO app_user") + ) + conn.execute( + sa.text( + "CREATE POLICY tenant_isolation ON certificate_authorities FOR ALL " + "USING (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid) " + "WITH CHECK (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid)" + ) + ) + conn.execute(sa.text("GRANT SELECT ON certificate_authorities TO poller_user")) # device_certificates RLS - conn.execute(sa.text( - "ALTER TABLE device_certificates ENABLE ROW LEVEL SECURITY" - )) - conn.execute(sa.text( - "GRANT SELECT, INSERT, UPDATE, DELETE ON device_certificates TO app_user" - )) - conn.execute(sa.text( - "CREATE POLICY tenant_isolation ON device_certificates FOR ALL " - "USING (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid) " - "WITH CHECK (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid)" - )) - conn.execute(sa.text( - "GRANT SELECT ON device_certificates TO poller_user" - )) + conn.execute(sa.text("ALTER TABLE device_certificates ENABLE ROW LEVEL SECURITY")) + conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE, DELETE ON device_certificates TO app_user")) + conn.execute( + sa.text( + "CREATE POLICY tenant_isolation ON device_certificates FOR ALL " + "USING (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid) " + "WITH CHECK (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid)" + ) + ) + conn.execute(sa.text("GRANT SELECT ON device_certificates TO poller_user")) def downgrade() -> None: conn = op.get_bind() # Drop RLS policies - conn.execute(sa.text( - "DROP POLICY IF EXISTS tenant_isolation ON device_certificates" - )) - conn.execute(sa.text( - "DROP POLICY IF EXISTS tenant_isolation ON certificate_authorities" - )) + conn.execute(sa.text("DROP POLICY IF EXISTS tenant_isolation ON device_certificates")) + conn.execute(sa.text("DROP POLICY IF EXISTS tenant_isolation ON certificate_authorities")) # Revoke grants - conn.execute(sa.text( - "REVOKE ALL ON device_certificates FROM app_user" - )) - conn.execute(sa.text( - "REVOKE ALL ON device_certificates FROM poller_user" - )) - conn.execute(sa.text( - "REVOKE ALL ON certificate_authorities FROM app_user" - )) - conn.execute(sa.text( - "REVOKE ALL ON certificate_authorities FROM poller_user" - )) + conn.execute(sa.text("REVOKE ALL ON device_certificates FROM app_user")) + conn.execute(sa.text("REVOKE ALL ON device_certificates FROM poller_user")) + conn.execute(sa.text("REVOKE ALL ON certificate_authorities FROM app_user")) + conn.execute(sa.text("REVOKE ALL ON certificate_authorities FROM poller_user")) # Drop tls_mode column from devices op.drop_column("devices", "tls_mode") diff --git a/backend/alembic/versions/014_timescaledb_retention.py b/backend/alembic/versions/014_timescaledb_retention.py index cb48a97..d38c64e 100644 --- a/backend/alembic/versions/014_timescaledb_retention.py +++ b/backend/alembic/versions/014_timescaledb_retention.py @@ -35,9 +35,7 @@ def upgrade() -> None: for table in HYPERTABLES: # Drop chunks older than 90 days - conn.execute(sa.text( - f"SELECT add_retention_policy('{table}', INTERVAL '90 days')" - )) + conn.execute(sa.text(f"SELECT add_retention_policy('{table}', INTERVAL '90 days')")) def downgrade() -> None: @@ -45,6 +43,4 @@ def downgrade() -> None: for table in HYPERTABLES: # Remove retention policy - conn.execute(sa.text( - f"SELECT remove_retention_policy('{table}', if_exists => true)" - )) + conn.execute(sa.text(f"SELECT remove_retention_policy('{table}', if_exists => true)")) diff --git a/backend/alembic/versions/016_zero_knowledge_schema.py b/backend/alembic/versions/016_zero_knowledge_schema.py index 38dd56c..456a8a9 100644 --- a/backend/alembic/versions/016_zero_knowledge_schema.py +++ b/backend/alembic/versions/016_zero_knowledge_schema.py @@ -147,46 +147,36 @@ def upgrade() -> None: conn = op.get_bind() # user_key_sets RLS - conn.execute(sa.text( - "ALTER TABLE user_key_sets ENABLE ROW LEVEL SECURITY" - )) - conn.execute(sa.text( - "CREATE POLICY user_key_sets_tenant_isolation ON user_key_sets " - "USING (tenant_id::text = current_setting('app.current_tenant', true) " - "OR current_setting('app.current_tenant', true) = 'super_admin')" - )) - conn.execute(sa.text( - "GRANT SELECT, INSERT, UPDATE ON user_key_sets TO app_user" - )) + conn.execute(sa.text("ALTER TABLE user_key_sets ENABLE ROW LEVEL SECURITY")) + conn.execute( + sa.text( + "CREATE POLICY user_key_sets_tenant_isolation ON user_key_sets " + "USING (tenant_id::text = current_setting('app.current_tenant', true) " + "OR current_setting('app.current_tenant', true) = 'super_admin')" + ) + ) + conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE ON user_key_sets TO app_user")) # key_access_log RLS (append-only: INSERT+SELECT only, no UPDATE/DELETE) - conn.execute(sa.text( - "ALTER TABLE key_access_log ENABLE ROW LEVEL SECURITY" - )) - conn.execute(sa.text( - "CREATE POLICY key_access_log_tenant_isolation ON key_access_log " - "USING (tenant_id::text = current_setting('app.current_tenant', true) " - "OR current_setting('app.current_tenant', true) = 'super_admin')" - )) - conn.execute(sa.text( - "GRANT INSERT, SELECT ON key_access_log TO app_user" - )) + conn.execute(sa.text("ALTER TABLE key_access_log ENABLE ROW LEVEL SECURITY")) + conn.execute( + sa.text( + "CREATE POLICY key_access_log_tenant_isolation ON key_access_log " + "USING (tenant_id::text = current_setting('app.current_tenant', true) " + "OR current_setting('app.current_tenant', true) = 'super_admin')" + ) + ) + conn.execute(sa.text("GRANT INSERT, SELECT ON key_access_log TO app_user")) # poller_user needs INSERT to log key access events when decrypting credentials - conn.execute(sa.text( - "GRANT INSERT, SELECT ON key_access_log TO poller_user" - )) + conn.execute(sa.text("GRANT INSERT, SELECT ON key_access_log TO poller_user")) def downgrade() -> None: conn = op.get_bind() # Drop RLS policies - conn.execute(sa.text( - "DROP POLICY IF EXISTS key_access_log_tenant_isolation ON key_access_log" - )) - conn.execute(sa.text( - "DROP POLICY IF EXISTS user_key_sets_tenant_isolation ON user_key_sets" - )) + conn.execute(sa.text("DROP POLICY IF EXISTS key_access_log_tenant_isolation ON key_access_log")) + conn.execute(sa.text("DROP POLICY IF EXISTS user_key_sets_tenant_isolation ON user_key_sets")) # Revoke grants conn.execute(sa.text("REVOKE ALL ON key_access_log FROM app_user")) diff --git a/backend/alembic/versions/017_openbao_envelope_encryption.py b/backend/alembic/versions/017_openbao_envelope_encryption.py index b032ceb..b68f953 100644 --- a/backend/alembic/versions/017_openbao_envelope_encryption.py +++ b/backend/alembic/versions/017_openbao_envelope_encryption.py @@ -77,9 +77,7 @@ def upgrade() -> None: def downgrade() -> None: - op.drop_constraint( - "fk_key_access_log_device_id", "key_access_log", type_="foreignkey" - ) + op.drop_constraint("fk_key_access_log_device_id", "key_access_log", type_="foreignkey") op.drop_column("key_access_log", "correlation_id") op.drop_column("key_access_log", "justification") op.drop_column("key_access_log", "device_id") diff --git a/backend/alembic/versions/019_deprecate_bcrypt.py b/backend/alembic/versions/019_deprecate_bcrypt.py index 627a7cf..fdc4f86 100644 --- a/backend/alembic/versions/019_deprecate_bcrypt.py +++ b/backend/alembic/versions/019_deprecate_bcrypt.py @@ -33,8 +33,7 @@ def upgrade() -> None: # Flag all bcrypt-only users for upgrade (auth_version=1 and no SRP verifier) op.execute( - "UPDATE users SET must_upgrade_auth = true " - "WHERE auth_version = 1 AND srp_verifier IS NULL" + "UPDATE users SET must_upgrade_auth = true WHERE auth_version = 1 AND srp_verifier IS NULL" ) # Make hashed_password nullable (SRP users don't need it) @@ -44,8 +43,7 @@ def upgrade() -> None: def downgrade() -> None: # Restore NOT NULL (set a dummy value for any NULLs first) op.execute( - "UPDATE users SET hashed_password = '$2b$12$placeholder' " - "WHERE hashed_password IS NULL" + "UPDATE users SET hashed_password = '$2b$12$placeholder' WHERE hashed_password IS NULL" ) op.alter_column("users", "hashed_password", nullable=False) diff --git a/backend/alembic/versions/020_tls_mode_opt_in.py b/backend/alembic/versions/020_tls_mode_opt_in.py index 0d2b82b..a1bb99b 100644 --- a/backend/alembic/versions/020_tls_mode_opt_in.py +++ b/backend/alembic/versions/020_tls_mode_opt_in.py @@ -14,7 +14,6 @@ Existing 'insecure' devices become 'auto' since the old behavior was an implicit auto-fallback. portal_ca devices keep their mode. """ -import sqlalchemy as sa from alembic import op revision = "020" diff --git a/backend/alembic/versions/022_rls_super_admin_devices.py b/backend/alembic/versions/022_rls_super_admin_devices.py index 6a4bfbb..a1d99d4 100644 --- a/backend/alembic/versions/022_rls_super_admin_devices.py +++ b/backend/alembic/versions/022_rls_super_admin_devices.py @@ -25,7 +25,8 @@ def upgrade() -> None: conn = op.get_bind() for table in _TABLES: conn.execute(sa.text(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")) - conn.execute(sa.text(f""" + conn.execute( + sa.text(f""" CREATE POLICY tenant_isolation ON {table} USING ( tenant_id::text = current_setting('app.current_tenant', true) @@ -35,15 +36,18 @@ def upgrade() -> None: tenant_id::text = current_setting('app.current_tenant', true) OR current_setting('app.current_tenant', true) = 'super_admin' ) - """)) + """) + ) def downgrade() -> None: conn = op.get_bind() for table in _TABLES: conn.execute(sa.text(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")) - conn.execute(sa.text(f""" + conn.execute( + sa.text(f""" CREATE POLICY tenant_isolation ON {table} USING (tenant_id::text = current_setting('app.current_tenant', true)) WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true)) - """)) + """) + ) diff --git a/backend/alembic/versions/024_contact_email_and_offline_rule.py b/backend/alembic/versions/024_contact_email_and_offline_rule.py index 6c5e035..9dd6c0d 100644 --- a/backend/alembic/versions/024_contact_email_and_offline_rule.py +++ b/backend/alembic/versions/024_contact_email_and_offline_rule.py @@ -19,7 +19,8 @@ def upgrade() -> None: op.add_column("tenants", sa.Column("contact_email", sa.String(255), nullable=True)) # 2. Seed device_offline default alert rule for all existing tenants - conn.execute(sa.text(""" + conn.execute( + sa.text(""" INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default) SELECT gen_random_uuid(), t.id, 'Device Offline', 'device_offline', 'eq', 1, 1, 'critical', TRUE, TRUE FROM tenants t @@ -28,14 +29,17 @@ def upgrade() -> None: SELECT 1 FROM alert_rules ar WHERE ar.tenant_id = t.id AND ar.metric = 'device_offline' AND ar.is_default = TRUE ) - """)) + """) + ) def downgrade() -> None: conn = op.get_bind() - conn.execute(sa.text(""" + conn.execute( + sa.text(""" DELETE FROM alert_rules WHERE metric = 'device_offline' AND is_default = TRUE - """)) + """) + ) op.drop_column("tenants", "contact_email") diff --git a/backend/alembic/versions/025_fix_key_access_log_device_fk.py b/backend/alembic/versions/025_fix_key_access_log_device_fk.py index 818c3b6..1ccdaa8 100644 --- a/backend/alembic/versions/025_fix_key_access_log_device_fk.py +++ b/backend/alembic/versions/025_fix_key_access_log_device_fk.py @@ -11,9 +11,7 @@ down_revision = "024" def upgrade() -> None: - op.drop_constraint( - "fk_key_access_log_device_id", "key_access_log", type_="foreignkey" - ) + op.drop_constraint("fk_key_access_log_device_id", "key_access_log", type_="foreignkey") op.create_foreign_key( "fk_key_access_log_device_id", "key_access_log", @@ -25,9 +23,7 @@ def upgrade() -> None: def downgrade() -> None: - op.drop_constraint( - "fk_key_access_log_device_id", "key_access_log", type_="foreignkey" - ) + op.drop_constraint("fk_key_access_log_device_id", "key_access_log", type_="foreignkey") op.create_foreign_key( "fk_key_access_log_device_id", "key_access_log", diff --git a/backend/alembic/versions/027_router_config_snapshots.py b/backend/alembic/versions/027_router_config_snapshots.py index 6ed5fcb..b7ed34a 100644 --- a/backend/alembic/versions/027_router_config_snapshots.py +++ b/backend/alembic/versions/027_router_config_snapshots.py @@ -25,7 +25,8 @@ def upgrade() -> None: conn = op.get_bind() # ── router_config_snapshots ────────────────────────────────────────── - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE router_config_snapshots ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), device_id UUID NOT NULL REFERENCES devices(id) ON DELETE CASCADE, @@ -35,30 +36,38 @@ def upgrade() -> None: collected_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ) - """)) + """) + ) # RLS conn.execute(sa.text("ALTER TABLE router_config_snapshots ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE router_config_snapshots FORCE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON router_config_snapshots USING (tenant_id::text = current_setting('app.current_tenant', true)) WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true)) - """)) + """) + ) # Grants conn.execute(sa.text("GRANT SELECT, INSERT, DELETE ON router_config_snapshots TO app_user")) # Indexes - conn.execute(sa.text( - "CREATE INDEX idx_rcs_device_collected ON router_config_snapshots (device_id, collected_at DESC)" - )) - conn.execute(sa.text( - "CREATE INDEX idx_rcs_device_hash ON router_config_snapshots (device_id, sha256_hash)" - )) + conn.execute( + sa.text( + "CREATE INDEX idx_rcs_device_collected ON router_config_snapshots (device_id, collected_at DESC)" + ) + ) + conn.execute( + sa.text( + "CREATE INDEX idx_rcs_device_hash ON router_config_snapshots (device_id, sha256_hash)" + ) + ) # ── router_config_diffs ────────────────────────────────────────────── - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE router_config_diffs ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), device_id UUID NOT NULL REFERENCES devices(id) ON DELETE CASCADE, @@ -70,27 +79,33 @@ def upgrade() -> None: lines_removed INT NOT NULL DEFAULT 0, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ) - """)) + """) + ) # RLS conn.execute(sa.text("ALTER TABLE router_config_diffs ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE router_config_diffs FORCE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON router_config_diffs USING (tenant_id::text = current_setting('app.current_tenant', true)) WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true)) - """)) + """) + ) # Grants conn.execute(sa.text("GRANT SELECT, INSERT, DELETE ON router_config_diffs TO app_user")) # Indexes - conn.execute(sa.text( - "CREATE UNIQUE INDEX idx_rcd_snapshot_pair ON router_config_diffs (old_snapshot_id, new_snapshot_id)" - )) + conn.execute( + sa.text( + "CREATE UNIQUE INDEX idx_rcd_snapshot_pair ON router_config_diffs (old_snapshot_id, new_snapshot_id)" + ) + ) # ── router_config_changes ──────────────────────────────────────────── - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE TABLE router_config_changes ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), diff_id UUID NOT NULL REFERENCES router_config_diffs(id) ON DELETE CASCADE, @@ -101,24 +116,25 @@ def upgrade() -> None: raw_line TEXT, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ) - """)) + """) + ) # RLS conn.execute(sa.text("ALTER TABLE router_config_changes ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE router_config_changes FORCE ROW LEVEL SECURITY")) - conn.execute(sa.text(""" + conn.execute( + sa.text(""" CREATE POLICY tenant_isolation ON router_config_changes USING (tenant_id::text = current_setting('app.current_tenant', true)) WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true)) - """)) + """) + ) # Grants conn.execute(sa.text("GRANT SELECT, INSERT, DELETE ON router_config_changes TO app_user")) # Indexes - conn.execute(sa.text( - "CREATE INDEX idx_rcc_diff_id ON router_config_changes (diff_id)" - )) + conn.execute(sa.text("CREATE INDEX idx_rcc_diff_id ON router_config_changes (diff_id)")) def downgrade() -> None: diff --git a/backend/alembic/versions/028_device_ssh_host_key.py b/backend/alembic/versions/028_device_ssh_host_key.py index 653d94f..46ad98d 100644 --- a/backend/alembic/versions/028_device_ssh_host_key.py +++ b/backend/alembic/versions/028_device_ssh_host_key.py @@ -25,40 +25,28 @@ import sqlalchemy as sa def upgrade() -> None: conn = op.get_bind() - conn.execute(sa.text( - "ALTER TABLE devices ADD COLUMN ssh_port INTEGER DEFAULT 22" - )) - conn.execute(sa.text( - "ALTER TABLE devices ADD COLUMN ssh_host_key_fingerprint TEXT" - )) - conn.execute(sa.text( - "ALTER TABLE devices ADD COLUMN ssh_host_key_first_seen TIMESTAMPTZ" - )) - conn.execute(sa.text( - "ALTER TABLE devices ADD COLUMN ssh_host_key_last_verified TIMESTAMPTZ" - )) + conn.execute(sa.text("ALTER TABLE devices ADD COLUMN ssh_port INTEGER DEFAULT 22")) + conn.execute(sa.text("ALTER TABLE devices ADD COLUMN ssh_host_key_fingerprint TEXT")) + conn.execute(sa.text("ALTER TABLE devices ADD COLUMN ssh_host_key_first_seen TIMESTAMPTZ")) + conn.execute(sa.text("ALTER TABLE devices ADD COLUMN ssh_host_key_last_verified TIMESTAMPTZ")) # Grant poller_user UPDATE on SSH columns for TOFU host key persistence - conn.execute(sa.text( - "GRANT UPDATE (ssh_host_key_fingerprint, ssh_host_key_first_seen, ssh_host_key_last_verified) ON devices TO poller_user" - )) + conn.execute( + sa.text( + "GRANT UPDATE (ssh_host_key_fingerprint, ssh_host_key_first_seen, ssh_host_key_last_verified) ON devices TO poller_user" + ) + ) def downgrade() -> None: conn = op.get_bind() - conn.execute(sa.text( - "REVOKE UPDATE (ssh_host_key_fingerprint, ssh_host_key_first_seen, ssh_host_key_last_verified) ON devices FROM poller_user" - )) - conn.execute(sa.text( - "ALTER TABLE devices DROP COLUMN ssh_host_key_last_verified" - )) - conn.execute(sa.text( - "ALTER TABLE devices DROP COLUMN ssh_host_key_first_seen" - )) - conn.execute(sa.text( - "ALTER TABLE devices DROP COLUMN ssh_host_key_fingerprint" - )) - conn.execute(sa.text( - "ALTER TABLE devices DROP COLUMN ssh_port" - )) + conn.execute( + sa.text( + "REVOKE UPDATE (ssh_host_key_fingerprint, ssh_host_key_first_seen, ssh_host_key_last_verified) ON devices FROM poller_user" + ) + ) + conn.execute(sa.text("ALTER TABLE devices DROP COLUMN ssh_host_key_last_verified")) + conn.execute(sa.text("ALTER TABLE devices DROP COLUMN ssh_host_key_first_seen")) + conn.execute(sa.text("ALTER TABLE devices DROP COLUMN ssh_host_key_fingerprint")) + conn.execute(sa.text("ALTER TABLE devices DROP COLUMN ssh_port")) diff --git a/backend/alembic/versions/029_vpn_tenant_isolation.py b/backend/alembic/versions/029_vpn_tenant_isolation.py index 71b58e9..b5bfcf9 100644 --- a/backend/alembic/versions/029_vpn_tenant_isolation.py +++ b/backend/alembic/versions/029_vpn_tenant_isolation.py @@ -16,7 +16,12 @@ import base64 from alembic import op import sqlalchemy as sa from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey -from cryptography.hazmat.primitives.serialization import Encoding, NoEncryption, PrivateFormat, PublicFormat +from cryptography.hazmat.primitives.serialization import ( + Encoding, + NoEncryption, + PrivateFormat, + PublicFormat, +) from cryptography.hazmat.primitives.ciphers.aead import AESGCM @@ -120,5 +125,9 @@ def downgrade() -> None: op.alter_column("vpn_config", "subnet", server_default="10.10.0.0/24") op.alter_column("vpn_config", "server_address", server_default="10.10.0.1/24") conn = op.get_bind() - conn.execute(sa.text("DELETE FROM system_settings WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')")) + conn.execute( + sa.text( + "DELETE FROM system_settings WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')" + ) + ) # NOTE: downgrade does not remap peer IPs back. Manual cleanup may be needed. diff --git a/backend/app/config.py b/backend/app/config.py index f9e51d2..74e4e7a 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -42,8 +42,8 @@ def validate_production_settings(settings: "Settings") -> None: print( f"FATAL: {field} uses a known insecure default in '{settings.ENVIRONMENT}' environment.\n" f"Generate a secure value and set it in your .env.prod file.\n" - f"For JWT_SECRET_KEY: python -c \"import secrets; print(secrets.token_urlsafe(64))\"\n" - f"For CREDENTIAL_ENCRYPTION_KEY: python -c \"import secrets, base64; print(base64.b64encode(secrets.token_bytes(32)).decode())\"\n" + f'For JWT_SECRET_KEY: python -c "import secrets; print(secrets.token_urlsafe(64))"\n' + f'For CREDENTIAL_ENCRYPTION_KEY: python -c "import secrets, base64; print(base64.b64encode(secrets.token_bytes(32)).decode())"\n' f"For OPENBAO_TOKEN: use the token from your OpenBao server (not the dev token)", file=sys.stderr, ) diff --git a/backend/app/database.py b/backend/app/database.py index 321aca4..2dc1bcd 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -17,6 +17,7 @@ from app.config import settings class Base(DeclarativeBase): """Base class for all SQLAlchemy ORM models.""" + pass diff --git a/backend/app/main.py b/backend/app/main.py index 2ab5ebc..e094696 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -82,7 +82,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: from app.services.retention_service import start_retention_scheduler, stop_retention_scheduler from app.services.metrics_subscriber import start_metrics_subscriber, stop_metrics_subscriber from app.services.nats_subscriber import start_nats_subscriber, stop_nats_subscriber - from app.services.session_audit_subscriber import start_session_audit_subscriber, stop_session_audit_subscriber + from app.services.session_audit_subscriber import ( + start_session_audit_subscriber, + stop_session_audit_subscriber, + ) from app.services.sse_manager import ensure_sse_streams # Configure structured logging FIRST -- before any other startup work @@ -201,6 +204,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: start_config_change_subscriber, stop_config_change_subscriber, ) + config_change_nc = await start_config_change_subscriber() except Exception as e: logger.error("Config change subscriber failed to start (non-fatal): %s", e) @@ -212,6 +216,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: start_push_rollback_subscriber, stop_push_rollback_subscriber, ) + push_rollback_nc = await start_push_rollback_subscriber() except Exception as e: logger.error("Push rollback subscriber failed to start (non-fatal): %s", e) @@ -223,6 +228,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: start_config_snapshot_subscriber, stop_config_snapshot_subscriber, ) + config_snapshot_nc = await start_config_snapshot_subscriber() except Exception as e: logger.error("Config snapshot subscriber failed to start (non-fatal): %s", e) @@ -231,14 +237,16 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: try: await start_retention_scheduler() except Exception as exc: - logger.warning("retention scheduler could not start (API will run without it)", error=str(exc)) + logger.warning( + "retention scheduler could not start (API will run without it)", error=str(exc) + ) # Start Remote WinBox session reconciliation loop (60s interval). # Detects orphaned sessions (worker lost them) and cleans up Redis + tunnels. winbox_reconcile_task: Optional[asyncio.Task] = None # type: ignore[type-arg] try: from app.routers.winbox_remote import _get_redis as _wb_get_redis, _close_tunnel - from app.services.winbox_remote import get_session as _wb_worker_get, health_check as _wb_health + from app.services.winbox_remote import get_session as _wb_worker_get async def _winbox_reconcile_loop() -> None: """Scan Redis for winbox-remote:* keys and reconcile with worker.""" diff --git a/backend/app/middleware/rbac.py b/backend/app/middleware/rbac.py index ca6129a..9f4eeca 100644 --- a/backend/app/middleware/rbac.py +++ b/backend/app/middleware/rbac.py @@ -49,6 +49,7 @@ def require_role(*allowed_roles: str) -> Callable: Returns: FastAPI dependency that raises 403 if the role is insufficient """ + async def dependency( current_user: CurrentUser = Depends(get_current_user), ) -> CurrentUser: @@ -56,7 +57,7 @@ def require_role(*allowed_roles: str) -> Callable: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Access denied. Required roles: {', '.join(allowed_roles)}. " - f"Your role: {current_user.role}", + f"Your role: {current_user.role}", ) return current_user @@ -82,7 +83,7 @@ def require_min_role(min_role: str) -> Callable: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Access denied. Minimum required role: {min_role}. " - f"Your role: {current_user.role}", + f"Your role: {current_user.role}", ) return current_user @@ -96,6 +97,7 @@ def require_write_access() -> Callable: Viewers are NOT allowed on POST/PUT/PATCH/DELETE endpoints. Call this on any mutating endpoint to deny viewers. """ + async def dependency( request: Request, current_user: CurrentUser = Depends(get_current_user), @@ -105,7 +107,7 @@ def require_write_access() -> Callable: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Viewers have read-only access. " - "Contact your administrator to request elevated permissions.", + "Contact your administrator to request elevated permissions.", ) return current_user @@ -127,6 +129,7 @@ def require_scope(scope: str) -> DependsClass: Raises: HTTPException 403 if the API key is missing the required scope. """ + async def _check_scope( current_user: CurrentUser = Depends(get_current_user), ) -> CurrentUser: @@ -143,6 +146,7 @@ def require_scope(scope: str) -> DependsClass: # Pre-built convenience dependencies + async def require_super_admin( current_user: CurrentUser = Depends(get_current_user), ) -> CurrentUser: diff --git a/backend/app/middleware/security_headers.py b/backend/app/middleware/security_headers.py index c3a0ec3..fe3796f 100644 --- a/backend/app/middleware/security_headers.py +++ b/backend/app/middleware/security_headers.py @@ -20,32 +20,36 @@ from starlette.requests import Request from starlette.responses import Response # Production CSP: strict -- no inline scripts allowed -_CSP_PRODUCTION = "; ".join([ - "default-src 'self'", - "script-src 'self'", - "style-src 'self' 'unsafe-inline'", - "img-src 'self' data: blob:", - "font-src 'self'", - "connect-src 'self' wss: ws:", - "worker-src 'self'", - "frame-ancestors 'none'", - "base-uri 'self'", - "form-action 'self'", -]) +_CSP_PRODUCTION = "; ".join( + [ + "default-src 'self'", + "script-src 'self'", + "style-src 'self' 'unsafe-inline'", + "img-src 'self' data: blob:", + "font-src 'self'", + "connect-src 'self' wss: ws:", + "worker-src 'self'", + "frame-ancestors 'none'", + "base-uri 'self'", + "form-action 'self'", + ] +) # Development CSP: relaxed for Vite HMR (hot module replacement) -_CSP_DEV = "; ".join([ - "default-src 'self'", - "script-src 'self' 'unsafe-inline' 'unsafe-eval'", - "style-src 'self' 'unsafe-inline'", - "img-src 'self' data: blob:", - "font-src 'self'", - "connect-src 'self' http://localhost:* ws://localhost:* wss:", - "worker-src 'self' blob:", - "frame-ancestors 'none'", - "base-uri 'self'", - "form-action 'self'", -]) +_CSP_DEV = "; ".join( + [ + "default-src 'self'", + "script-src 'self' 'unsafe-inline' 'unsafe-eval'", + "style-src 'self' 'unsafe-inline'", + "img-src 'self' data: blob:", + "font-src 'self'", + "connect-src 'self' http://localhost:* ws://localhost:* wss:", + "worker-src 'self' blob:", + "frame-ancestors 'none'", + "base-uri 'self'", + "form-action 'self'", + ] +) class SecurityHeadersMiddleware(BaseHTTPMiddleware): @@ -72,8 +76,6 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware): # HSTS only in production (plain HTTP in dev would be blocked) if self.is_production: - response.headers["Strict-Transport-Security"] = ( - "max-age=31536000; includeSubDomains" - ) + response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" return response diff --git a/backend/app/middleware/tenant_context.py b/backend/app/middleware/tenant_context.py index 73cd15c..ab68a8b 100644 --- a/backend/app/middleware/tenant_context.py +++ b/backend/app/middleware/tenant_context.py @@ -178,7 +178,6 @@ async def get_current_user_ws( Raises: WebSocketException 1008: If no token is provided or the token is invalid. """ - from starlette.websockets import WebSocket, WebSocketState from fastapi import WebSocketException # 1. Try cookie diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 3da20ab..2a62654 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -2,7 +2,14 @@ from app.models.tenant import Tenant from app.models.user import User, UserRole -from app.models.device import Device, DeviceGroup, DeviceTag, DeviceGroupMembership, DeviceTagAssignment, DeviceStatus +from app.models.device import ( + Device, + DeviceGroup, + DeviceTag, + DeviceGroupMembership, + DeviceTagAssignment, + DeviceStatus, +) from app.models.alert import AlertRule, NotificationChannel, AlertRuleChannel, AlertEvent from app.models.firmware import FirmwareVersion, FirmwareUpgradeJob from app.models.config_template import ConfigTemplate, ConfigTemplateTag, TemplatePushJob diff --git a/backend/app/models/alert.py b/backend/app/models/alert.py index cd798f8..b239e99 100644 --- a/backend/app/models/alert.py +++ b/backend/app/models/alert.py @@ -26,6 +26,7 @@ class AlertRule(Base): When a metric breaches the threshold for duration_polls consecutive polls, an alert fires. """ + __tablename__ = "alert_rules" id: Mapped[uuid.UUID] = mapped_column( @@ -53,10 +54,16 @@ class AlertRule(Base): metric: Mapped[str] = mapped_column(Text, nullable=False) operator: Mapped[str] = mapped_column(Text, nullable=False) threshold: Mapped[float] = mapped_column(Numeric, nullable=False) - duration_polls: Mapped[int] = mapped_column(Integer, nullable=False, default=1, server_default="1") + duration_polls: Mapped[int] = mapped_column( + Integer, nullable=False, default=1, server_default="1" + ) severity: Mapped[str] = mapped_column(Text, nullable=False) - enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, server_default="true") - is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false") + enabled: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=True, server_default="true" + ) + is_default: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False, server_default="false" + ) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -69,6 +76,7 @@ class AlertRule(Base): class NotificationChannel(Base): """Email, webhook, or Slack notification destination.""" + __tablename__ = "notification_channels" id: Mapped[uuid.UUID] = mapped_column( @@ -83,12 +91,16 @@ class NotificationChannel(Base): nullable=False, ) name: Mapped[str] = mapped_column(Text, nullable=False) - channel_type: Mapped[str] = mapped_column(Text, nullable=False) # "email", "webhook", or "slack" + channel_type: Mapped[str] = mapped_column( + Text, nullable=False + ) # "email", "webhook", or "slack" # SMTP fields (email channels) smtp_host: Mapped[str | None] = mapped_column(Text, nullable=True) smtp_port: Mapped[int | None] = mapped_column(Integer, nullable=True) smtp_user: Mapped[str | None] = mapped_column(Text, nullable=True) - smtp_password: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True) # AES-256-GCM encrypted + smtp_password: Mapped[bytes | None] = mapped_column( + LargeBinary, nullable=True + ) # AES-256-GCM encrypted smtp_use_tls: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false") from_address: Mapped[str | None] = mapped_column(Text, nullable=True) to_address: Mapped[str | None] = mapped_column(Text, nullable=True) @@ -110,6 +122,7 @@ class NotificationChannel(Base): class AlertRuleChannel(Base): """Many-to-many association between alert rules and notification channels.""" + __tablename__ = "alert_rule_channels" rule_id: Mapped[uuid.UUID] = mapped_column( @@ -129,6 +142,7 @@ class AlertEvent(Base): rule_id is NULL for system-level alerts (e.g., device offline). """ + __tablename__ = "alert_events" id: Mapped[uuid.UUID] = mapped_column( @@ -158,7 +172,9 @@ class AlertEvent(Base): value: Mapped[float | None] = mapped_column(Numeric, nullable=True) threshold: Mapped[float | None] = mapped_column(Numeric, nullable=True) message: Mapped[str | None] = mapped_column(Text, nullable=True) - is_flapping: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false") + is_flapping: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False, server_default="false" + ) acknowledged_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) acknowledged_by: Mapped[uuid.UUID | None] = mapped_column( UUID(as_uuid=True), diff --git a/backend/app/models/api_key.py b/backend/app/models/api_key.py index bef874a..1b873e8 100644 --- a/backend/app/models/api_key.py +++ b/backend/app/models/api_key.py @@ -41,20 +41,14 @@ class ApiKey(Base): key_prefix: Mapped[str] = mapped_column(Text, nullable=False) key_hash: Mapped[str] = mapped_column(Text, nullable=False, unique=True) scopes: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'::jsonb") - expires_at: Mapped[Optional[datetime]] = mapped_column( - DateTime(timezone=True), nullable=True - ) - last_used_at: Mapped[Optional[datetime]] = mapped_column( - DateTime(timezone=True), nullable=True - ) + expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + last_used_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False, ) - revoked_at: Mapped[Optional[datetime]] = mapped_column( - DateTime(timezone=True), nullable=True - ) + revoked_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) def __repr__(self) -> str: return f"" diff --git a/backend/app/models/certificate.py b/backend/app/models/certificate.py index 98149f8..23fc77b 100644 --- a/backend/app/models/certificate.py +++ b/backend/app/models/certificate.py @@ -39,21 +39,13 @@ class CertificateAuthority(Base): ) common_name: Mapped[str] = mapped_column(String(255), nullable=False) cert_pem: Mapped[str] = mapped_column(Text, nullable=False) - encrypted_private_key: Mapped[bytes] = mapped_column( - LargeBinary, nullable=False - ) + encrypted_private_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) serial_number: Mapped[str] = mapped_column(String(64), nullable=False) fingerprint_sha256: Mapped[str] = mapped_column(String(95), nullable=False) - not_valid_before: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=False - ) - not_valid_after: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=False - ) + not_valid_before: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + not_valid_after: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) # OpenBao Transit ciphertext (dual-write migration) - encrypted_private_key_transit: Mapped[str | None] = mapped_column( - Text, nullable=True - ) + encrypted_private_key_transit: Mapped[str | None] = mapped_column(Text, nullable=True) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -62,8 +54,7 @@ class CertificateAuthority(Base): def __repr__(self) -> str: return ( - f"" + f"" ) @@ -103,25 +94,13 @@ class DeviceCertificate(Base): serial_number: Mapped[str] = mapped_column(String(64), nullable=False) fingerprint_sha256: Mapped[str] = mapped_column(String(95), nullable=False) cert_pem: Mapped[str] = mapped_column(Text, nullable=False) - encrypted_private_key: Mapped[bytes] = mapped_column( - LargeBinary, nullable=False - ) - not_valid_before: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=False - ) - not_valid_after: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=False - ) + encrypted_private_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) + not_valid_before: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + not_valid_after: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) # OpenBao Transit ciphertext (dual-write migration) - encrypted_private_key_transit: Mapped[str | None] = mapped_column( - Text, nullable=True - ) - status: Mapped[str] = mapped_column( - String(20), nullable=False, server_default="issued" - ) - deployed_at: Mapped[datetime | None] = mapped_column( - DateTime(timezone=True), nullable=True - ) + encrypted_private_key_transit: Mapped[str | None] = mapped_column(Text, nullable=True) + status: Mapped[str] = mapped_column(String(20), nullable=False, server_default="issued") + deployed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -134,7 +113,4 @@ class DeviceCertificate(Base): ) def __repr__(self) -> str: - return ( - f"" - ) + return f"" diff --git a/backend/app/models/config_backup.py b/backend/app/models/config_backup.py index edc3dd2..e461462 100644 --- a/backend/app/models/config_backup.py +++ b/backend/app/models/config_backup.py @@ -3,9 +3,20 @@ import uuid from datetime import datetime -from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, LargeBinary, SmallInteger, String, Text, UniqueConstraint, func +from sqlalchemy import ( + Boolean, + DateTime, + ForeignKey, + Integer, + LargeBinary, + SmallInteger, + String, + Text, + UniqueConstraint, + func, +) from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column from app.database import Base @@ -115,7 +126,9 @@ class ConfigBackupSchedule(Base): def __repr__(self) -> str: scope = f"device={self.device_id}" if self.device_id else f"tenant={self.tenant_id}" - return f"" + return ( + f"" + ) class ConfigPushOperation(Base): @@ -173,8 +186,7 @@ class ConfigPushOperation(Base): def __repr__(self) -> str: return ( - f"" + f"" ) @@ -272,7 +284,9 @@ class RouterConfigDiff(Base): ) diff_text: Mapped[str] = mapped_column(Text, nullable=False) lines_added: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0") - lines_removed: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0") + lines_removed: Mapped[int] = mapped_column( + Integer, nullable=False, default=0, server_default="0" + ) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -334,6 +348,5 @@ class RouterConfigChange(Base): def __repr__(self) -> str: return ( - f"" + f"" ) diff --git a/backend/app/models/config_template.py b/backend/app/models/config_template.py index b375181..454a668 100644 --- a/backend/app/models/config_template.py +++ b/backend/app/models/config_template.py @@ -5,7 +5,6 @@ from datetime import datetime from sqlalchemy import ( DateTime, - Float, ForeignKey, String, Text, @@ -88,9 +87,7 @@ class ConfigTemplateTag(Base): ) # Relationships - template: Mapped["ConfigTemplate"] = relationship( - "ConfigTemplate", back_populates="tags" - ) + template: Mapped["ConfigTemplate"] = relationship("ConfigTemplate", back_populates="tags") def __repr__(self) -> str: return f"" @@ -133,12 +130,8 @@ class TemplatePushJob(Base): ) pre_push_backup_sha: Mapped[str | None] = mapped_column(Text, nullable=True) error_message: Mapped[str | None] = mapped_column(Text, nullable=True) - started_at: Mapped[datetime | None] = mapped_column( - DateTime(timezone=True), nullable=True - ) - completed_at: Mapped[datetime | None] = mapped_column( - DateTime(timezone=True), nullable=True - ) + started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), diff --git a/backend/app/models/device.py b/backend/app/models/device.py index f4bfb4d..de3080d 100644 --- a/backend/app/models/device.py +++ b/backend/app/models/device.py @@ -5,7 +5,6 @@ from datetime import datetime from enum import Enum from sqlalchemy import ( - Boolean, DateTime, Float, ForeignKey, @@ -24,6 +23,7 @@ from app.database import Base class DeviceStatus(str, Enum): """Device connection status.""" + UNKNOWN = "unknown" ONLINE = "online" OFFLINE = "offline" @@ -31,9 +31,7 @@ class DeviceStatus(str, Enum): class Device(Base): __tablename__ = "devices" - __table_args__ = ( - UniqueConstraint("tenant_id", "hostname", name="uq_devices_tenant_hostname"), - ) + __table_args__ = (UniqueConstraint("tenant_id", "hostname", name="uq_devices_tenant_hostname"),) id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), @@ -59,7 +57,9 @@ class Device(Base): uptime_seconds: Mapped[int | None] = mapped_column(Integer, nullable=True) last_cpu_load: Mapped[int | None] = mapped_column(Integer, nullable=True) last_memory_used_pct: Mapped[int | None] = mapped_column(Integer, nullable=True) - architecture: Mapped[str | None] = mapped_column(Text, nullable=True) # CPU arch (arm, arm64, mipsbe, etc.) + architecture: Mapped[str | None] = mapped_column( + Text, nullable=True + ) # CPU arch (arm, arm64, mipsbe, etc.) preferred_channel: Mapped[str] = mapped_column( Text, default="stable", server_default="stable", nullable=False ) # Firmware release channel @@ -108,9 +108,7 @@ class Device(Base): class DeviceGroup(Base): __tablename__ = "device_groups" - __table_args__ = ( - UniqueConstraint("tenant_id", "name", name="uq_device_groups_tenant_name"), - ) + __table_args__ = (UniqueConstraint("tenant_id", "name", name="uq_device_groups_tenant_name"),) id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), @@ -147,9 +145,7 @@ class DeviceGroup(Base): class DeviceTag(Base): __tablename__ = "device_tags" - __table_args__ = ( - UniqueConstraint("tenant_id", "name", name="uq_device_tags_tenant_name"), - ) + __table_args__ = (UniqueConstraint("tenant_id", "name", name="uq_device_tags_tenant_name"),) id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), diff --git a/backend/app/models/firmware.py b/backend/app/models/firmware.py index 67385c5..01dbd74 100644 --- a/backend/app/models/firmware.py +++ b/backend/app/models/firmware.py @@ -7,7 +7,6 @@ from sqlalchemy import ( BigInteger, Boolean, DateTime, - Integer, Text, UniqueConstraint, func, @@ -24,10 +23,9 @@ class FirmwareVersion(Base): Not tenant-scoped — firmware versions are global data shared across all tenants. """ + __tablename__ = "firmware_versions" - __table_args__ = ( - UniqueConstraint("architecture", "channel", "version"), - ) + __table_args__ = (UniqueConstraint("architecture", "channel", "version"),) id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), @@ -56,6 +54,7 @@ class FirmwareUpgradeJob(Base): Multiple jobs can share a rollout_group_id for mass upgrades. """ + __tablename__ = "firmware_upgrade_jobs" id: Mapped[uuid.UUID] = mapped_column( @@ -99,4 +98,6 @@ class FirmwareUpgradeJob(Base): ) def __repr__(self) -> str: - return f"" + return ( + f"" + ) diff --git a/backend/app/models/key_set.py b/backend/app/models/key_set.py index 4124515..37a2a06 100644 --- a/backend/app/models/key_set.py +++ b/backend/app/models/key_set.py @@ -37,32 +37,18 @@ class UserKeySet(Base): ForeignKey("tenants.id", ondelete="CASCADE"), nullable=True, # NULL for super_admin ) - encrypted_private_key: Mapped[bytes] = mapped_column( - LargeBinary, nullable=False - ) - private_key_nonce: Mapped[bytes] = mapped_column( - LargeBinary, nullable=False - ) - encrypted_vault_key: Mapped[bytes] = mapped_column( - LargeBinary, nullable=False - ) - vault_key_nonce: Mapped[bytes] = mapped_column( - LargeBinary, nullable=False - ) - public_key: Mapped[bytes] = mapped_column( - LargeBinary, nullable=False - ) + encrypted_private_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) + private_key_nonce: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) + encrypted_vault_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) + vault_key_nonce: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) + public_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) pbkdf2_iterations: Mapped[int] = mapped_column( Integer, server_default=func.literal_column("650000"), nullable=False, ) - pbkdf2_salt: Mapped[bytes] = mapped_column( - LargeBinary, nullable=False - ) - hkdf_salt: Mapped[bytes] = mapped_column( - LargeBinary, nullable=False - ) + pbkdf2_salt: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) + hkdf_salt: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) key_version: Mapped[int] = mapped_column( Integer, server_default=func.literal_column("1"), diff --git a/backend/app/models/maintenance_window.py b/backend/app/models/maintenance_window.py index ec3a9f8..886a3cb 100644 --- a/backend/app/models/maintenance_window.py +++ b/backend/app/models/maintenance_window.py @@ -20,6 +20,7 @@ class MaintenanceWindow(Base): device_ids is a JSONB array of device UUID strings. An empty array means "all devices in tenant". """ + __tablename__ = "maintenance_windows" id: Mapped[uuid.UUID] = mapped_column( diff --git a/backend/app/models/tenant.py b/backend/app/models/tenant.py index 4271be4..565c15f 100644 --- a/backend/app/models/tenant.py +++ b/backend/app/models/tenant.py @@ -40,10 +40,18 @@ class Tenant(Base): openbao_key_name: Mapped[str | None] = mapped_column(Text, nullable=True) # Relationships — passive_deletes=True lets the DB ON DELETE CASCADE handle cleanup - users: Mapped[list["User"]] = relationship("User", back_populates="tenant", passive_deletes=True) # type: ignore[name-defined] - devices: Mapped[list["Device"]] = relationship("Device", back_populates="tenant", passive_deletes=True) # type: ignore[name-defined] - device_groups: Mapped[list["DeviceGroup"]] = relationship("DeviceGroup", back_populates="tenant", passive_deletes=True) # type: ignore[name-defined] - device_tags: Mapped[list["DeviceTag"]] = relationship("DeviceTag", back_populates="tenant", passive_deletes=True) # type: ignore[name-defined] + users: Mapped[list["User"]] = relationship( + "User", back_populates="tenant", passive_deletes=True + ) # type: ignore[name-defined] + devices: Mapped[list["Device"]] = relationship( + "Device", back_populates="tenant", passive_deletes=True + ) # type: ignore[name-defined] + device_groups: Mapped[list["DeviceGroup"]] = relationship( + "DeviceGroup", back_populates="tenant", passive_deletes=True + ) # type: ignore[name-defined] + device_tags: Mapped[list["DeviceTag"]] = relationship( + "DeviceTag", back_populates="tenant", passive_deletes=True + ) # type: ignore[name-defined] def __repr__(self) -> str: return f"" diff --git a/backend/app/models/user.py b/backend/app/models/user.py index 8b43f4b..ba7043d 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -13,6 +13,7 @@ from app.database import Base class UserRole(str, Enum): """User roles with increasing privilege levels.""" + SUPER_ADMIN = "super_admin" TENANT_ADMIN = "tenant_admin" OPERATOR = "operator" diff --git a/backend/app/models/vpn.py b/backend/app/models/vpn.py index 7db504d..c5dbe6b 100644 --- a/backend/app/models/vpn.py +++ b/backend/app/models/vpn.py @@ -75,7 +75,9 @@ class VpnPeer(Base): assigned_ip: Mapped[str] = mapped_column(String(32), nullable=False) additional_allowed_ips: Mapped[Optional[str]] = mapped_column(String(512), nullable=True) is_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="true") - last_handshake: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + last_handshake: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), nullable=True + ) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) diff --git a/backend/app/routers/alerts.py b/backend/app/routers/alerts.py index 02dad2d..ee11932 100644 --- a/backend/app/routers/alerts.py +++ b/backend/app/routers/alerts.py @@ -12,10 +12,8 @@ RLS enforced via get_db() (app_user engine with tenant context). RBAC: viewer = read-only (GET); operator and above = write (POST/PUT/PATCH/DELETE). """ -import base64 import logging import uuid -from datetime import datetime, timedelta, timezone from typing import Any, Optional from fastapi import APIRouter, Depends, HTTPException, Query, Request, status @@ -66,8 +64,13 @@ def _require_write(current_user: CurrentUser) -> None: EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") ALLOWED_METRICS = { - "cpu_load", "memory_used_pct", "disk_used_pct", "temperature", - "signal_strength", "ccq", "client_count", + "cpu_load", + "memory_used_pct", + "disk_used_pct", + "temperature", + "signal_strength", + "ccq", + "client_count", } ALLOWED_OPERATORS = {"gt", "lt", "gte", "lte"} ALLOWED_SEVERITIES = {"critical", "warning", "info"} @@ -252,7 +255,9 @@ async def create_alert_rule( if body.operator not in ALLOWED_OPERATORS: raise HTTPException(422, f"operator must be one of: {', '.join(sorted(ALLOWED_OPERATORS))}") if body.severity not in ALLOWED_SEVERITIES: - raise HTTPException(422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}") + raise HTTPException( + 422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}" + ) rule_id = str(uuid.uuid4()) @@ -296,8 +301,12 @@ async def create_alert_rule( try: await log_action( - db, tenant_id, current_user.user_id, "alert_rule_create", - resource_type="alert_rule", resource_id=rule_id, + db, + tenant_id, + current_user.user_id, + "alert_rule_create", + resource_type="alert_rule", + resource_id=rule_id, details={"name": body.name, "metric": body.metric, "severity": body.severity}, ) except Exception: @@ -338,7 +347,9 @@ async def update_alert_rule( if body.operator not in ALLOWED_OPERATORS: raise HTTPException(422, f"operator must be one of: {', '.join(sorted(ALLOWED_OPERATORS))}") if body.severity not in ALLOWED_SEVERITIES: - raise HTTPException(422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}") + raise HTTPException( + 422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}" + ) result = await db.execute( text(""" @@ -384,8 +395,12 @@ async def update_alert_rule( try: await log_action( - db, tenant_id, current_user.user_id, "alert_rule_update", - resource_type="alert_rule", resource_id=str(rule_id), + db, + tenant_id, + current_user.user_id, + "alert_rule_update", + resource_type="alert_rule", + resource_id=str(rule_id), details={"name": body.name, "metric": body.metric, "severity": body.severity}, ) except Exception: @@ -439,8 +454,12 @@ async def delete_alert_rule( try: await log_action( - db, tenant_id, current_user.user_id, "alert_rule_delete", - resource_type="alert_rule", resource_id=str(rule_id), + db, + tenant_id, + current_user.user_id, + "alert_rule_delete", + resource_type="alert_rule", + resource_id=str(rule_id), ) except Exception: pass @@ -592,7 +611,8 @@ async def create_notification_channel( encrypted_password_transit = None if body.smtp_password: encrypted_password_transit = await encrypt_credentials_transit( - body.smtp_password, str(tenant_id), + body.smtp_password, + str(tenant_id), ) await db.execute( @@ -665,10 +685,14 @@ async def update_notification_channel( # Build SET clauses dynamically based on which secrets are provided set_parts = [ - "name = :name", "channel_type = :channel_type", - "smtp_host = :smtp_host", "smtp_port = :smtp_port", - "smtp_user = :smtp_user", "smtp_use_tls = :smtp_use_tls", - "from_address = :from_address", "to_address = :to_address", + "name = :name", + "channel_type = :channel_type", + "smtp_host = :smtp_host", + "smtp_port = :smtp_port", + "smtp_user = :smtp_user", + "smtp_use_tls = :smtp_use_tls", + "from_address = :from_address", + "to_address = :to_address", "webhook_url = :webhook_url", "slack_webhook_url = :slack_webhook_url", ] @@ -689,7 +713,8 @@ async def update_notification_channel( if body.smtp_password: set_parts.append("smtp_password_transit = :smtp_password_transit") params["smtp_password_transit"] = await encrypt_credentials_transit( - body.smtp_password, str(tenant_id), + body.smtp_password, + str(tenant_id), ) # Clear legacy column set_parts.append("smtp_password = NULL") @@ -799,6 +824,7 @@ async def test_notification_channel( } from app.services.notification_service import send_test_notification + try: success = await send_test_notification(channel) if success: diff --git a/backend/app/routers/audit_logs.py b/backend/app/routers/audit_logs.py index a769b3a..0370123 100644 --- a/backend/app/routers/audit_logs.py +++ b/backend/app/routers/audit_logs.py @@ -221,29 +221,38 @@ async def list_audit_logs( all_rows = result.mappings().all() # Decrypt encrypted details concurrently - decrypted_details = await _decrypt_details_batch( - all_rows, str(tenant_id) - ) + decrypted_details = await _decrypt_details_batch(all_rows, str(tenant_id)) output = io.StringIO() writer = csv.writer(output) - writer.writerow([ - "ID", "User Email", "Action", "Resource Type", - "Resource ID", "Device", "Details", "IP Address", "Timestamp", - ]) + writer.writerow( + [ + "ID", + "User Email", + "Action", + "Resource Type", + "Resource ID", + "Device", + "Details", + "IP Address", + "Timestamp", + ] + ) for row, details in zip(all_rows, decrypted_details): details_str = json.dumps(details) if details else "{}" - writer.writerow([ - str(row["id"]), - row["user_email"] or "", - row["action"], - row["resource_type"] or "", - row["resource_id"] or "", - row["device_name"] or "", - details_str, - row["ip_address"] or "", - str(row["created_at"]), - ]) + writer.writerow( + [ + str(row["id"]), + row["user_email"] or "", + row["action"], + row["resource_type"] or "", + row["resource_id"] or "", + row["device_name"] or "", + details_str, + row["ip_address"] or "", + str(row["created_at"]), + ] + ) output.seek(0) return StreamingResponse( diff --git a/backend/app/routers/auth.py b/backend/app/routers/auth.py index edb25c8..f92ae35 100644 --- a/backend/app/routers/auth.py +++ b/backend/app/routers/auth.py @@ -103,7 +103,11 @@ async def get_redis() -> aioredis.Redis: # ─── SRP Zero-Knowledge Authentication ─────────────────────────────────────── -@router.post("/srp/init", response_model=SRPInitResponse, summary="SRP Step 1: return salt and server ephemeral B") +@router.post( + "/srp/init", + response_model=SRPInitResponse, + summary="SRP Step 1: return salt and server ephemeral B", +) @limiter.limit("5/minute") async def srp_init_endpoint( request: StarletteRequest, @@ -137,9 +141,7 @@ async def srp_init_endpoint( # Generate server ephemeral try: - server_public, server_private = await srp_init( - user.email, user.srp_verifier.hex() - ) + server_public, server_private = await srp_init(user.email, user.srp_verifier.hex()) except ValueError as e: logger.error("SRP init failed for %s: %s", user.email, e) raise HTTPException( @@ -150,13 +152,15 @@ async def srp_init_endpoint( # Store session in Redis with 60s TTL session_id = secrets.token_urlsafe(16) redis = await get_redis() - session_data = json.dumps({ - "email": user.email, - "server_private": server_private, - "srp_verifier_hex": user.srp_verifier.hex(), - "srp_salt_hex": user.srp_salt.hex(), - "user_id": str(user.id), - }) + session_data = json.dumps( + { + "email": user.email, + "server_private": server_private, + "srp_verifier_hex": user.srp_verifier.hex(), + "srp_salt_hex": user.srp_salt.hex(), + "user_id": str(user.id), + } + ) await redis.set(f"srp:session:{session_id}", session_data, ex=60) return SRPInitResponse( @@ -168,7 +172,11 @@ async def srp_init_endpoint( ) -@router.post("/srp/verify", response_model=SRPVerifyResponse, summary="SRP Step 2: verify client proof and return tokens") +@router.post( + "/srp/verify", + response_model=SRPVerifyResponse, + summary="SRP Step 2: verify client proof and return tokens", +) @limiter.limit("5/minute") async def srp_verify_endpoint( request: StarletteRequest, @@ -236,7 +244,9 @@ async def srp_verify_endpoint( # Update last_login and clear upgrade flag on successful SRP login await db.execute( - update(User).where(User.id == user.id).values( + update(User) + .where(User.id == user.id) + .values( last_login=datetime.now(UTC), must_upgrade_auth=False, ) @@ -323,9 +333,7 @@ async def login( Rate limited to 5 requests per minute per IP. """ # Look up user by email (case-insensitive) - result = await db.execute( - select(User).where(User.email == body.email.lower()) - ) + result = await db.execute(select(User).where(User.email == body.email.lower())) user = result.scalar_one_or_none() # Generic error — do not reveal whether email exists (no user enumeration) @@ -389,7 +397,9 @@ async def login( # Update last_login await db.execute( - update(User).where(User.id == user.id).values( + update(User) + .where(User.id == user.id) + .values( last_login=datetime.now(UTC), ) ) @@ -404,7 +414,10 @@ async def login( user_id=user.id, action="login_upgrade" if user.must_upgrade_auth else "login", resource_type="auth", - details={"email": user.email, **({"upgrade": "bcrypt_to_srp"} if user.must_upgrade_auth else {})}, + details={ + "email": user.email, + **({"upgrade": "bcrypt_to_srp"} if user.must_upgrade_auth else {}), + }, ip_address=request.client.host if request.client else None, ) await audit_db.commit() @@ -440,7 +453,9 @@ async def refresh_token( Rate limited to 10 requests per minute per IP. """ # Resolve token: body takes precedence over cookie - raw_token = (body.refresh_token if body and body.refresh_token else None) or refresh_token_cookie + raw_token = ( + body.refresh_token if body and body.refresh_token else None + ) or refresh_token_cookie if not raw_token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -518,7 +533,9 @@ async def refresh_token( ) -@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT, summary="Log out and clear session cookie") +@router.post( + "/logout", status_code=status.HTTP_204_NO_CONTENT, summary="Log out and clear session cookie" +) @limiter.limit("10/minute") async def logout( request: StarletteRequest, @@ -535,7 +552,10 @@ async def logout( tenant_id = current_user.tenant_id or uuid.UUID(int=0) async with AdminAsyncSessionLocal() as audit_db: await log_action( - audit_db, tenant_id, current_user.user_id, "logout", + audit_db, + tenant_id, + current_user.user_id, + "logout", resource_type="auth", ip_address=request.client.host if request.client else None, ) @@ -558,7 +578,11 @@ async def logout( ) -@router.post("/change-password", response_model=MessageResponse, summary="Change password for authenticated user") +@router.post( + "/change-password", + response_model=MessageResponse, + summary="Change password for authenticated user", +) @limiter.limit("3/minute") async def change_password( request: StarletteRequest, @@ -602,7 +626,9 @@ async def change_password( existing_ks.hkdf_salt = base64.b64decode(body.hkdf_salt or "") else: # Legacy bcrypt user — verify current password - if not user.hashed_password or not verify_password(body.current_password, user.hashed_password): + if not user.hashed_password or not verify_password( + body.current_password, user.hashed_password + ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Current password is incorrect", @@ -822,7 +848,9 @@ async def get_emergency_kit_template( ) -@router.post("/register-srp", response_model=MessageResponse, summary="Register SRP credentials for a user") +@router.post( + "/register-srp", response_model=MessageResponse, summary="Register SRP credentials for a user" +) @limiter.limit("3/minute") async def register_srp( request: StarletteRequest, @@ -845,7 +873,9 @@ async def register_srp( # Update user with SRP credentials and clear upgrade flag await db.execute( - update(User).where(User.id == user.id).values( + update(User) + .where(User.id == user.id) + .values( srp_salt=bytes.fromhex(body.srp_salt), srp_verifier=bytes.fromhex(body.srp_verifier), auth_version=2, @@ -873,8 +903,11 @@ async def register_srp( try: async with AdminAsyncSessionLocal() as audit_db: await log_key_access( - audit_db, user.tenant_id or uuid.UUID(int=0), user.id, - "create_key_set", resource_type="user_key_set", + audit_db, + user.tenant_id or uuid.UUID(int=0), + user.id, + "create_key_set", + resource_type="user_key_set", ip_address=request.client.host if request.client else None, ) await audit_db.commit() @@ -901,11 +934,17 @@ async def create_sse_token( token = secrets.token_urlsafe(32) key = f"sse_token:{token}" # Store user context for the SSE endpoint to retrieve - await redis.set(key, json.dumps({ - "user_id": str(current_user.user_id), - "tenant_id": str(current_user.tenant_id) if current_user.tenant_id else None, - "role": current_user.role, - }), ex=30) # 30 second TTL + await redis.set( + key, + json.dumps( + { + "user_id": str(current_user.user_id), + "tenant_id": str(current_user.tenant_id) if current_user.tenant_id else None, + "role": current_user.role, + } + ), + ex=30, + ) # 30 second TTL return {"token": token} @@ -977,9 +1016,7 @@ async def forgot_password( """ generic_msg = "If an account with that email exists, a reset link has been sent." - result = await db.execute( - select(User).where(User.email == body.email.lower()) - ) + result = await db.execute(select(User).where(User.email == body.email.lower())) user = result.scalar_one_or_none() if not user or not user.is_active: @@ -988,9 +1025,7 @@ async def forgot_password( # Generate a secure token raw_token = secrets.token_urlsafe(32) token_hash = _hash_token(raw_token) - expires_at = datetime.now(UTC) + timedelta( - minutes=settings.PASSWORD_RESET_TOKEN_EXPIRE_MINUTES - ) + expires_at = datetime.now(UTC) + timedelta(minutes=settings.PASSWORD_RESET_TOKEN_EXPIRE_MINUTES) # Insert token record (using raw SQL to avoid importing the model globally) from sqlalchemy import text diff --git a/backend/app/routers/certificates.py b/backend/app/routers/certificates.py index effe93f..05a84dd 100644 --- a/backend/app/routers/certificates.py +++ b/backend/app/routers/certificates.py @@ -12,9 +12,7 @@ RBAC: viewer = read-only (GET); tenant_admin and above = mutating actions. """ import json -import logging import uuid -from datetime import datetime, timezone import nats import nats.aio.client @@ -30,7 +28,7 @@ from app.database import get_db, set_tenant_context from app.middleware.rate_limit import limiter from app.middleware.rbac import require_min_role from app.middleware.tenant_context import CurrentUser, get_current_user -from app.models.certificate import CertificateAuthority, DeviceCertificate +from app.models.certificate import DeviceCertificate from app.models.device import Device from app.schemas.certificate import ( BulkCertDeployRequest, @@ -87,13 +85,15 @@ async def _deploy_cert_via_nats( Dict with success, cert_name_on_device, and error fields. """ nc = await _get_nats() - payload = json.dumps({ - "device_id": device_id, - "cert_pem": cert_pem, - "key_pem": key_pem, - "cert_name": cert_name, - "ssh_port": ssh_port, - }).encode() + payload = json.dumps( + { + "device_id": device_id, + "cert_pem": cert_pem, + "key_pem": key_pem, + "cert_name": cert_name, + "ssh_port": ssh_port, + } + ).encode() try: reply = await nc.request( @@ -121,9 +121,7 @@ async def _get_device_for_tenant( db: AsyncSession, device_id: uuid.UUID, current_user: CurrentUser ) -> Device: """Fetch a device and verify tenant ownership.""" - result = await db.execute( - select(Device).where(Device.id == device_id) - ) + result = await db.execute(select(Device).where(Device.id == device_id)) device = result.scalar_one_or_none() if device is None: raise HTTPException( @@ -164,9 +162,7 @@ async def _get_cert_with_tenant_check( db: AsyncSession, cert_id: uuid.UUID, tenant_id: uuid.UUID ) -> DeviceCertificate: """Fetch a device certificate and verify tenant ownership.""" - result = await db.execute( - select(DeviceCertificate).where(DeviceCertificate.id == cert_id) - ) + result = await db.execute(select(DeviceCertificate).where(DeviceCertificate.id == cert_id)) cert = result.scalar_one_or_none() if cert is None: raise HTTPException( @@ -226,8 +222,12 @@ async def create_ca( try: await log_action( - db, tenant_id, current_user.user_id, "ca_create", - resource_type="certificate_authority", resource_id=str(ca.id), + db, + tenant_id, + current_user.user_id, + "ca_create", + resource_type="certificate_authority", + resource_id=str(ca.id), details={"common_name": body.common_name, "validity_years": body.validity_years}, ) except Exception: @@ -332,8 +332,12 @@ async def sign_cert( try: await log_action( - db, tenant_id, current_user.user_id, "cert_sign", - resource_type="device_certificate", resource_id=str(cert.id), + db, + tenant_id, + current_user.user_id, + "cert_sign", + resource_type="device_certificate", + resource_id=str(cert.id), device_id=body.device_id, details={"hostname": device.hostname, "validity_days": body.validity_days}, ) @@ -404,17 +408,19 @@ async def deploy_cert( await update_cert_status(db, cert_id, "deployed") # Update device tls_mode to portal_ca - device_result = await db.execute( - select(Device).where(Device.id == cert.device_id) - ) + device_result = await db.execute(select(Device).where(Device.id == cert.device_id)) device = device_result.scalar_one_or_none() if device is not None: device.tls_mode = "portal_ca" try: await log_action( - db, tenant_id, current_user.user_id, "cert_deploy", - resource_type="device_certificate", resource_id=str(cert_id), + db, + tenant_id, + current_user.user_id, + "cert_deploy", + resource_type="device_certificate", + resource_id=str(cert_id), device_id=cert.device_id, details={"cert_name_on_device": result.get("cert_name_on_device")}, ) @@ -528,36 +534,47 @@ async def bulk_deploy( await update_cert_status(db, issued_cert.id, "deployed") device.tls_mode = "portal_ca" - results.append(CertDeployResponse( - success=True, - device_id=device_id, - cert_name_on_device=result.get("cert_name_on_device"), - )) + results.append( + CertDeployResponse( + success=True, + device_id=device_id, + cert_name_on_device=result.get("cert_name_on_device"), + ) + ) else: await update_cert_status(db, issued_cert.id, "issued") - results.append(CertDeployResponse( - success=False, - device_id=device_id, - error=result.get("error"), - )) + results.append( + CertDeployResponse( + success=False, + device_id=device_id, + error=result.get("error"), + ) + ) except HTTPException as e: - results.append(CertDeployResponse( - success=False, - device_id=device_id, - error=e.detail, - )) + results.append( + CertDeployResponse( + success=False, + device_id=device_id, + error=e.detail, + ) + ) except Exception as e: logger.error("Bulk deploy error", device_id=str(device_id), error=str(e)) - results.append(CertDeployResponse( - success=False, - device_id=device_id, - error=str(e), - )) + results.append( + CertDeployResponse( + success=False, + device_id=device_id, + error=str(e), + ) + ) try: await log_action( - db, tenant_id, current_user.user_id, "cert_bulk_deploy", + db, + tenant_id, + current_user.user_id, + "cert_bulk_deploy", resource_type="device_certificate", details={ "device_count": len(body.device_ids), @@ -619,17 +636,19 @@ async def revoke_cert( ) # Reset device tls_mode to insecure - device_result = await db.execute( - select(Device).where(Device.id == cert.device_id) - ) + device_result = await db.execute(select(Device).where(Device.id == cert.device_id)) device = device_result.scalar_one_or_none() if device is not None: device.tls_mode = "insecure" try: await log_action( - db, tenant_id, current_user.user_id, "cert_revoke", - resource_type="device_certificate", resource_id=str(cert_id), + db, + tenant_id, + current_user.user_id, + "cert_revoke", + resource_type="device_certificate", + resource_id=str(cert_id), device_id=cert.device_id, ) except Exception: @@ -661,9 +680,7 @@ async def rotate_cert( old_cert = await _get_cert_with_tenant_check(db, cert_id, tenant_id) # Get the device for hostname/IP - device_result = await db.execute( - select(Device).where(Device.id == old_cert.device_id) - ) + device_result = await db.execute(select(Device).where(Device.id == old_cert.device_id)) device = device_result.scalar_one_or_none() if device is None: raise HTTPException( @@ -722,8 +739,12 @@ async def rotate_cert( try: await log_action( - db, tenant_id, current_user.user_id, "cert_rotate", - resource_type="device_certificate", resource_id=str(new_cert.id), + db, + tenant_id, + current_user.user_id, + "cert_rotate", + resource_type="device_certificate", + resource_id=str(new_cert.id), device_id=old_cert.device_id, details={ "old_cert_id": str(cert_id), diff --git a/backend/app/routers/clients.py b/backend/app/routers/clients.py index c66f096..ff3c678 100644 --- a/backend/app/routers/clients.py +++ b/backend/app/routers/clients.py @@ -43,6 +43,7 @@ async def _check_tenant_access( """Verify the current user is allowed to access the given tenant.""" if current_user.is_super_admin: from app.database import set_tenant_context + await set_tenant_context(db, str(tenant_id)) return if current_user.tenant_id != tenant_id: @@ -52,9 +53,7 @@ async def _check_tenant_access( ) -async def _check_device_online( - db: AsyncSession, device_id: uuid.UUID -) -> Device: +async def _check_device_online(db: AsyncSession, device_id: uuid.UUID) -> Device: """Verify the device exists and is online. Returns the Device object.""" result = await db.execute( select(Device).where(Device.id == device_id) # type: ignore[arg-type] diff --git a/backend/app/routers/config_backups.py b/backend/app/routers/config_backups.py index 682417e..d6d0b21 100644 --- a/backend/app/routers/config_backups.py +++ b/backend/app/routers/config_backups.py @@ -25,7 +25,6 @@ import asyncio import json import logging import uuid -from datetime import timezone, datetime from typing import Any from fastapi import APIRouter, Depends, HTTPException, Request, status @@ -67,6 +66,7 @@ async def _check_tenant_access( """ if current_user.is_super_admin: from app.database import set_tenant_context + await set_tenant_context(db, str(tenant_id)) return if current_user.tenant_id != tenant_id: @@ -291,14 +291,14 @@ async def get_export( try: from app.services.crypto import decrypt_data_transit - plaintext = await decrypt_data_transit( - content_bytes.decode("utf-8"), str(tenant_id) - ) + plaintext = await decrypt_data_transit(content_bytes.decode("utf-8"), str(tenant_id)) content_bytes = plaintext.encode("utf-8") except Exception as dec_err: logger.error( "Failed to decrypt export for device %s sha %s: %s", - device_id, commit_sha, dec_err, + device_id, + commit_sha, + dec_err, ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -370,7 +370,9 @@ async def get_binary( except Exception as dec_err: logger.error( "Failed to decrypt binary backup for device %s sha %s: %s", - device_id, commit_sha, dec_err, + device_id, + commit_sha, + dec_err, ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -380,9 +382,7 @@ async def get_binary( return Response( content=content_bytes, media_type="application/octet-stream", - headers={ - "Content-Disposition": f'attachment; filename="backup-{commit_sha[:8]}.bin"' - }, + headers={"Content-Disposition": f'attachment; filename="backup-{commit_sha[:8]}.bin"'}, ) @@ -445,6 +445,7 @@ async def preview_restore( key, ) import json + creds = json.loads(creds_json) current_text = await backup_service.capture_export( device.ip_address, @@ -578,9 +579,7 @@ async def emergency_rollback( .where( ConfigBackupRun.device_id == device_id, # type: ignore[arg-type] ConfigBackupRun.tenant_id == tenant_id, # type: ignore[arg-type] - ConfigBackupRun.trigger_type.in_( - ["pre-restore", "checkpoint", "pre-template-push"] - ), + ConfigBackupRun.trigger_type.in_(["pre-restore", "checkpoint", "pre-template-push"]), ) .order_by(ConfigBackupRun.created_at.desc()) .limit(1) @@ -735,6 +734,7 @@ async def update_schedule( # Hot-reload the scheduler so changes take effect immediately from app.services.backup_scheduler import on_schedule_change + await on_schedule_change(tenant_id, device_id) return { @@ -758,6 +758,7 @@ async def _get_nats(): Reuses the same lazy-init pattern as routeros_proxy._get_nats(). """ from app.services.routeros_proxy import _get_nats as _proxy_get_nats + return await _proxy_get_nats() @@ -839,6 +840,7 @@ async def trigger_config_snapshot( if reply_data.get("status") == "success": try: from app.services.audit_service import log_action + await log_action( db, tenant_id, diff --git a/backend/app/routers/config_editor.py b/backend/app/routers/config_editor.py index 2e2833a..735bc42 100644 --- a/backend/app/routers/config_editor.py +++ b/backend/app/routers/config_editor.py @@ -64,9 +64,7 @@ async def _check_tenant_access( await set_tenant_context(db, str(tenant_id)) -async def _check_device_online( - db: AsyncSession, device_id: uuid.UUID -) -> Device: +async def _check_device_online(db: AsyncSession, device_id: uuid.UUID) -> Device: """Verify the device exists and is online. Returns the Device object.""" result = await db.execute( select(Device).where(Device.id == device_id) # type: ignore[arg-type] @@ -201,8 +199,12 @@ async def add_entry( try: await log_action( - db, tenant_id, current_user.user_id, "config_add", - resource_type="config", resource_id=str(device_id), + db, + tenant_id, + current_user.user_id, + "config_add", + resource_type="config", + resource_id=str(device_id), device_id=device_id, details={"path": body.path, "properties": body.properties}, ) @@ -255,8 +257,12 @@ async def set_entry( try: await log_action( - db, tenant_id, current_user.user_id, "config_set", - resource_type="config", resource_id=str(device_id), + db, + tenant_id, + current_user.user_id, + "config_set", + resource_type="config", + resource_id=str(device_id), device_id=device_id, details={"path": body.path, "entry_id": body.entry_id, "properties": body.properties}, ) @@ -286,9 +292,7 @@ async def remove_entry( await _check_device_online(db, device_id) check_path_safety(body.path, write=True) - result = await routeros_proxy.remove_entry( - str(device_id), body.path, body.entry_id - ) + result = await routeros_proxy.remove_entry(str(device_id), body.path, body.entry_id) if not result.get("success"): raise HTTPException( @@ -309,8 +313,12 @@ async def remove_entry( try: await log_action( - db, tenant_id, current_user.user_id, "config_remove", - resource_type="config", resource_id=str(device_id), + db, + tenant_id, + current_user.user_id, + "config_remove", + resource_type="config", + resource_id=str(device_id), device_id=device_id, details={"path": body.path, "entry_id": body.entry_id}, ) @@ -360,8 +368,12 @@ async def execute_command( try: await log_action( - db, tenant_id, current_user.user_id, "config_execute", - resource_type="config", resource_id=str(device_id), + db, + tenant_id, + current_user.user_id, + "config_execute", + resource_type="config", + resource_id=str(device_id), device_id=device_id, details={"command": body.command}, ) diff --git a/backend/app/routers/config_history.py b/backend/app/routers/config_history.py index 7f2be99..e3335cb 100644 --- a/backend/app/routers/config_history.py +++ b/backend/app/routers/config_history.py @@ -43,6 +43,7 @@ async def _check_tenant_access( """ if current_user.is_super_admin: from app.database import set_tenant_context + await set_tenant_context(db, str(tenant_id)) return if current_user.tenant_id != tenant_id: @@ -115,9 +116,7 @@ async def view_snapshot( session=db, ) except Exception: - logger.exception( - "Failed to decrypt snapshot %s for device %s", snapshot_id, device_id - ) + logger.exception("Failed to decrypt snapshot %s for device %s", snapshot_id, device_id) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to decrypt snapshot content", diff --git a/backend/app/routers/device_logs.py b/backend/app/routers/device_logs.py index bdfe07c..27286ae 100644 --- a/backend/app/routers/device_logs.py +++ b/backend/app/routers/device_logs.py @@ -29,12 +29,14 @@ router = APIRouter(tags=["device-logs"]) # Helpers (same pattern as config_editor.py) # --------------------------------------------------------------------------- + async def _check_tenant_access( current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession ) -> None: """Verify the current user is allowed to access the given tenant.""" if current_user.is_super_admin: from app.database import set_tenant_context + await set_tenant_context(db, str(tenant_id)) return if current_user.tenant_id != tenant_id: @@ -44,16 +46,12 @@ async def _check_tenant_access( ) -async def _check_device_exists( - db: AsyncSession, device_id: uuid.UUID -) -> None: +async def _check_device_exists(db: AsyncSession, device_id: uuid.UUID) -> None: """Verify the device exists (does not require online status for logs).""" from sqlalchemy import select from app.models.device import Device - result = await db.execute( - select(Device).where(Device.id == device_id) - ) + result = await db.execute(select(Device).where(Device.id == device_id)) device = result.scalar_one_or_none() if device is None: raise HTTPException( @@ -66,6 +64,7 @@ async def _check_device_exists( # Response model # --------------------------------------------------------------------------- + class LogEntry(BaseModel): time: str topics: str @@ -82,6 +81,7 @@ class LogsResponse(BaseModel): # Endpoint # --------------------------------------------------------------------------- + @router.get( "/tenants/{tenant_id}/devices/{device_id}/logs", response_model=LogsResponse, diff --git a/backend/app/routers/device_tags.py b/backend/app/routers/device_tags.py index 523cca1..368b149 100644 --- a/backend/app/routers/device_tags.py +++ b/backend/app/routers/device_tags.py @@ -72,9 +72,7 @@ async def update_tag( ) -> DeviceTagResponse: """Update a device tag. Requires operator role or above.""" await _check_tenant_access(current_user, tenant_id, db) - return await device_service.update_tag( - db=db, tenant_id=tenant_id, tag_id=tag_id, data=data - ) + return await device_service.update_tag(db=db, tenant_id=tenant_id, tag_id=tag_id, data=data) @router.delete( diff --git a/backend/app/routers/devices.py b/backend/app/routers/devices.py index c3ac89b..ad8508d 100644 --- a/backend/app/routers/devices.py +++ b/backend/app/routers/devices.py @@ -23,7 +23,6 @@ from app.database import get_db from app.middleware.rate_limit import limiter from app.services.audit_service import log_action from app.middleware.rbac import ( - require_min_role, require_operator_or_above, require_scope, require_tenant_admin_or_above, @@ -57,6 +56,7 @@ async def _check_tenant_access( if current_user.is_super_admin: # Re-set tenant context to the target tenant so RLS allows the operation from app.database import set_tenant_context + await set_tenant_context(db, str(tenant_id)) return if current_user.tenant_id != tenant_id: @@ -138,8 +138,12 @@ async def create_device( ) try: await log_action( - db, tenant_id, current_user.user_id, "device_create", - resource_type="device", resource_id=str(result.id), + db, + tenant_id, + current_user.user_id, + "device_create", + resource_type="device", + resource_id=str(result.id), details={"hostname": data.hostname, "ip_address": data.ip_address}, ip_address=request.client.host if request.client else None, ) @@ -191,8 +195,12 @@ async def update_device( ) try: await log_action( - db, tenant_id, current_user.user_id, "device_update", - resource_type="device", resource_id=str(device_id), + db, + tenant_id, + current_user.user_id, + "device_update", + resource_type="device", + resource_id=str(device_id), device_id=device_id, details={"changes": data.model_dump(exclude_unset=True)}, ip_address=request.client.host if request.client else None, @@ -220,8 +228,12 @@ async def delete_device( await _check_tenant_access(current_user, tenant_id, db) try: await log_action( - db, tenant_id, current_user.user_id, "device_delete", - resource_type="device", resource_id=str(device_id), + db, + tenant_id, + current_user.user_id, + "device_delete", + resource_type="device", + resource_id=str(device_id), device_id=device_id, ip_address=request.client.host if request.client else None, ) @@ -262,14 +274,21 @@ async def scan_devices( discovered = await scan_subnet(data.cidr) import ipaddress + network = ipaddress.ip_network(data.cidr, strict=False) - total_scanned = network.num_addresses - 2 if network.num_addresses > 2 else network.num_addresses + total_scanned = ( + network.num_addresses - 2 if network.num_addresses > 2 else network.num_addresses + ) # Audit log the scan (fire-and-forget — never breaks the response) try: await log_action( - db, tenant_id, current_user.user_id, "subnet_scan", - resource_type="network", resource_id=data.cidr, + db, + tenant_id, + current_user.user_id, + "subnet_scan", + resource_type="network", + resource_id=data.cidr, details={ "cidr": data.cidr, "devices_found": len(discovered), @@ -322,10 +341,12 @@ async def bulk_add_devices( password = dev_data.password or data.shared_password if not username or not password: - failed.append({ - "ip_address": dev_data.ip_address, - "error": "No credentials provided (set per-device or shared credentials)", - }) + failed.append( + { + "ip_address": dev_data.ip_address, + "error": "No credentials provided (set per-device or shared credentials)", + } + ) continue create_data = DeviceCreate( @@ -347,9 +368,16 @@ async def bulk_add_devices( added.append(device) try: await log_action( - db, tenant_id, current_user.user_id, "device_adopt", - resource_type="device", resource_id=str(device.id), - details={"hostname": create_data.hostname, "ip_address": create_data.ip_address}, + db, + tenant_id, + current_user.user_id, + "device_adopt", + resource_type="device", + resource_id=str(device.id), + details={ + "hostname": create_data.hostname, + "ip_address": create_data.ip_address, + }, ip_address=request.client.host if request.client else None, ) except Exception: diff --git a/backend/app/routers/events.py b/backend/app/routers/events.py index 3ac9f19..82f2d77 100644 --- a/backend/app/routers/events.py +++ b/backend/app/routers/events.py @@ -90,16 +90,18 @@ async def list_events( for row in alert_result.fetchall(): alert_status = row[1] or "firing" metric = row[3] or "unknown" - events.append({ - "id": str(row[0]), - "event_type": "alert", - "severity": row[2], - "title": f"{alert_status}: {metric}", - "description": row[4] or f"Alert {alert_status} for {metric}", - "device_hostname": row[7], - "device_id": str(row[6]) if row[6] else None, - "timestamp": row[5].isoformat() if row[5] else None, - }) + events.append( + { + "id": str(row[0]), + "event_type": "alert", + "severity": row[2], + "title": f"{alert_status}: {metric}", + "description": row[4] or f"Alert {alert_status} for {metric}", + "device_hostname": row[7], + "device_id": str(row[6]) if row[6] else None, + "timestamp": row[5].isoformat() if row[5] else None, + } + ) # 2. Device status changes (inferred from current status + last_seen) if not event_type or event_type == "status_change": @@ -117,16 +119,18 @@ async def list_events( device_status = row[2] or "unknown" hostname = row[1] or "Unknown device" severity = "info" if device_status == "online" else "warning" - events.append({ - "id": f"status-{row[0]}", - "event_type": "status_change", - "severity": severity, - "title": f"Device {device_status}", - "description": f"{hostname} is now {device_status}", - "device_hostname": hostname, - "device_id": str(row[0]), - "timestamp": row[3].isoformat() if row[3] else None, - }) + events.append( + { + "id": f"status-{row[0]}", + "event_type": "status_change", + "severity": severity, + "title": f"Device {device_status}", + "description": f"{hostname} is now {device_status}", + "device_hostname": hostname, + "device_id": str(row[0]), + "timestamp": row[3].isoformat() if row[3] else None, + } + ) # 3. Config backup runs if not event_type or event_type == "config_backup": @@ -144,16 +148,18 @@ async def list_events( for row in backup_result.fetchall(): trigger_type = row[1] or "manual" hostname = row[4] or "Unknown device" - events.append({ - "id": str(row[0]), - "event_type": "config_backup", - "severity": "info", - "title": "Config backup", - "description": f"{trigger_type} backup completed for {hostname}", - "device_hostname": hostname, - "device_id": str(row[3]) if row[3] else None, - "timestamp": row[2].isoformat() if row[2] else None, - }) + events.append( + { + "id": str(row[0]), + "event_type": "config_backup", + "severity": "info", + "title": "Config backup", + "description": f"{trigger_type} backup completed for {hostname}", + "device_hostname": hostname, + "device_id": str(row[3]) if row[3] else None, + "timestamp": row[2].isoformat() if row[2] else None, + } + ) # Sort all events by timestamp descending, then apply final limit events.sort( diff --git a/backend/app/routers/firmware.py b/backend/app/routers/firmware.py index 278be84..c31e2eb 100644 --- a/backend/app/routers/firmware.py +++ b/backend/app/routers/firmware.py @@ -67,6 +67,7 @@ async def get_firmware_overview( await _check_tenant_access(current_user, tenant_id, db) from app.services.firmware_service import get_firmware_overview as _get_overview + return await _get_overview(str(tenant_id)) @@ -206,6 +207,7 @@ async def trigger_firmware_check( raise HTTPException(status_code=403, detail="Super admin only") from app.services.firmware_service import check_latest_versions + results = await check_latest_versions() return {"status": "ok", "versions_discovered": len(results), "versions": results} @@ -221,6 +223,7 @@ async def list_firmware_cache( raise HTTPException(status_code=403, detail="Super admin only") from app.services.firmware_service import get_cached_firmware + return await get_cached_firmware() @@ -236,6 +239,7 @@ async def download_firmware( raise HTTPException(status_code=403, detail="Super admin only") from app.services.firmware_service import download_firmware as _download + path = await _download(body.architecture, body.channel, body.version) return {"status": "ok", "path": path} @@ -324,15 +328,21 @@ async def start_firmware_upgrade( # Schedule or start immediately if body.scheduled_at: from app.services.upgrade_service import schedule_upgrade + schedule_upgrade(job_id, datetime.fromisoformat(body.scheduled_at)) else: from app.services.upgrade_service import start_upgrade + asyncio.create_task(start_upgrade(job_id)) try: await log_action( - db, tenant_id, current_user.user_id, "firmware_upgrade", - resource_type="firmware", resource_id=job_id, + db, + tenant_id, + current_user.user_id, + "firmware_upgrade", + resource_type="firmware", + resource_id=job_id, device_id=uuid.UUID(body.device_id), details={"target_version": body.target_version, "channel": body.channel}, ) @@ -406,9 +416,11 @@ async def start_mass_firmware_upgrade( # Schedule or start immediately if body.scheduled_at: from app.services.upgrade_service import schedule_mass_upgrade + schedule_mass_upgrade(rollout_group_id, datetime.fromisoformat(body.scheduled_at)) else: from app.services.upgrade_service import start_mass_upgrade + asyncio.create_task(start_mass_upgrade(rollout_group_id)) return { @@ -639,6 +651,7 @@ async def cancel_upgrade_endpoint( raise HTTPException(403, "Viewers cannot cancel upgrades") from app.services.upgrade_service import cancel_upgrade + await cancel_upgrade(str(job_id)) return {"status": "ok", "message": "Upgrade cancelled"} @@ -662,6 +675,7 @@ async def retry_upgrade_endpoint( raise HTTPException(403, "Viewers cannot retry upgrades") from app.services.upgrade_service import retry_failed_upgrade + await retry_failed_upgrade(str(job_id)) return {"status": "ok", "message": "Upgrade retry started"} @@ -685,6 +699,7 @@ async def resume_rollout_endpoint( raise HTTPException(403, "Viewers cannot resume rollouts") from app.services.upgrade_service import resume_mass_upgrade + await resume_mass_upgrade(str(rollout_group_id)) return {"status": "ok", "message": "Rollout resumed"} @@ -708,5 +723,6 @@ async def abort_rollout_endpoint( raise HTTPException(403, "Viewers cannot abort rollouts") from app.services.upgrade_service import abort_mass_upgrade + aborted = await abort_mass_upgrade(str(rollout_group_id)) return {"status": "ok", "aborted_count": aborted} diff --git a/backend/app/routers/maintenance_windows.py b/backend/app/routers/maintenance_windows.py index 61e5abf..2516f9b 100644 --- a/backend/app/routers/maintenance_windows.py +++ b/backend/app/routers/maintenance_windows.py @@ -299,9 +299,7 @@ async def delete_maintenance_window( _require_operator(current_user) result = await db.execute( - text( - "DELETE FROM maintenance_windows WHERE id = CAST(:id AS uuid) RETURNING id" - ), + text("DELETE FROM maintenance_windows WHERE id = CAST(:id AS uuid) RETURNING id"), {"id": str(window_id)}, ) if not result.fetchone(): diff --git a/backend/app/routers/metrics.py b/backend/app/routers/metrics.py index 92ae3ea..774c57d 100644 --- a/backend/app/routers/metrics.py +++ b/backend/app/routers/metrics.py @@ -65,6 +65,7 @@ async def _check_tenant_access( if current_user.is_super_admin: # Re-set tenant context to the target tenant so RLS allows the operation from app.database import set_tenant_context + await set_tenant_context(db, str(tenant_id)) return if current_user.tenant_id != tenant_id: diff --git a/backend/app/routers/remote_access.py b/backend/app/routers/remote_access.py index 1ad0d27..46bdaef 100644 --- a/backend/app/routers/remote_access.py +++ b/backend/app/routers/remote_access.py @@ -81,9 +81,12 @@ async def _get_device(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UU return device -async def _check_tenant_access(current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession) -> None: +async def _check_tenant_access( + current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession +) -> None: if current_user.is_super_admin: from app.database import set_tenant_context + await set_tenant_context(db, str(tenant_id)) return if current_user.tenant_id != tenant_id: @@ -124,8 +127,12 @@ async def open_winbox_session( try: await log_action( - db, tenant_id, current_user.user_id, "winbox_tunnel_open", - resource_type="device", resource_id=str(device_id), + db, + tenant_id, + current_user.user_id, + "winbox_tunnel_open", + resource_type="device", + resource_id=str(device_id), device_id=device_id, details={"source_ip": source_ip}, ip_address=source_ip, @@ -133,24 +140,31 @@ async def open_winbox_session( except Exception: pass - payload = json.dumps({ - "device_id": str(device_id), - "tenant_id": str(tenant_id), - "user_id": str(current_user.user_id), - "target_port": 8291, - }).encode() + payload = json.dumps( + { + "device_id": str(device_id), + "tenant_id": str(tenant_id), + "user_id": str(current_user.user_id), + "target_port": 8291, + } + ).encode() try: nc = await _get_nats() msg = await nc.request("tunnel.open", payload, timeout=10) except Exception as exc: logger.error("NATS tunnel.open failed: %s", exc) - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Tunnel service unavailable") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Tunnel service unavailable" + ) try: data = json.loads(msg.data) except Exception: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Invalid response from tunnel service") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Invalid response from tunnel service", + ) if "error" in data: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=data["error"]) @@ -158,11 +172,16 @@ async def open_winbox_session( port = data.get("local_port") tunnel_id = data.get("tunnel_id", "") if not isinstance(port, int) or not (49000 <= port <= 49100): - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Invalid port allocation from tunnel service") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Invalid port allocation from tunnel service", + ) # Derive the tunnel host from the request so remote clients get the server's # address rather than 127.0.0.1 (which would point to the user's own machine). - tunnel_host = (request.headers.get("x-forwarded-host") or request.headers.get("host") or "127.0.0.1") + tunnel_host = ( + request.headers.get("x-forwarded-host") or request.headers.get("host") or "127.0.0.1" + ) # Strip port from host header (e.g. "10.101.0.175:8001" → "10.101.0.175") tunnel_host = tunnel_host.split(":")[0] @@ -213,8 +232,12 @@ async def open_ssh_session( try: await log_action( - db, tenant_id, current_user.user_id, "ssh_session_open", - resource_type="device", resource_id=str(device_id), + db, + tenant_id, + current_user.user_id, + "ssh_session_open", + resource_type="device", + resource_id=str(device_id), device_id=device_id, details={"source_ip": source_ip, "cols": body.cols, "rows": body.rows}, ip_address=source_ip, @@ -223,22 +246,26 @@ async def open_ssh_session( pass token = secrets.token_urlsafe(32) - token_payload = json.dumps({ - "device_id": str(device_id), - "tenant_id": str(tenant_id), - "user_id": str(current_user.user_id), - "source_ip": source_ip, - "cols": body.cols, - "rows": body.rows, - "created_at": int(time.time()), - }) + token_payload = json.dumps( + { + "device_id": str(device_id), + "tenant_id": str(tenant_id), + "user_id": str(current_user.user_id), + "source_ip": source_ip, + "cols": body.cols, + "rows": body.rows, + "created_at": int(time.time()), + } + ) try: rd = await _get_redis() await rd.setex(f"ssh:token:{token}", 120, token_payload) except Exception as exc: logger.error("Redis setex failed for SSH token: %s", exc) - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Session store unavailable") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Session store unavailable" + ) return SSHSessionResponse( token=token, @@ -274,8 +301,12 @@ async def close_winbox_session( try: await log_action( - db, tenant_id, current_user.user_id, "winbox_tunnel_close", - resource_type="device", resource_id=str(device_id), + db, + tenant_id, + current_user.user_id, + "winbox_tunnel_close", + resource_type="device", + resource_id=str(device_id), device_id=device_id, details={"tunnel_id": tunnel_id, "source_ip": source_ip}, ip_address=source_ip, diff --git a/backend/app/routers/settings.py b/backend/app/routers/settings.py index 99716e2..842c36c 100644 --- a/backend/app/routers/settings.py +++ b/backend/app/routers/settings.py @@ -72,7 +72,11 @@ async def _set_system_settings(updates: dict, user_id: str) -> None: ON CONFLICT (key) DO UPDATE SET value = :value, updated_by = CAST(:user_id AS uuid), updated_at = now() """), - {"key": key, "value": str(value) if value is not None else None, "user_id": user_id}, + { + "key": key, + "value": str(value) if value is not None else None, + "user_id": user_id, + }, ) await session.commit() @@ -100,7 +104,8 @@ async def get_smtp_settings(user=Depends(require_role("super_admin"))): "smtp_host": db_settings.get("smtp_host") or settings.SMTP_HOST, "smtp_port": int(db_settings.get("smtp_port") or settings.SMTP_PORT), "smtp_user": db_settings.get("smtp_user") or settings.SMTP_USER or "", - "smtp_use_tls": (db_settings.get("smtp_use_tls") or str(settings.SMTP_USE_TLS)).lower() == "true", + "smtp_use_tls": (db_settings.get("smtp_use_tls") or str(settings.SMTP_USE_TLS)).lower() + == "true", "smtp_from_address": db_settings.get("smtp_from_address") or settings.SMTP_FROM_ADDRESS, "smtp_provider": db_settings.get("smtp_provider") or "custom", "smtp_password_set": bool(db_settings.get("smtp_password") or settings.SMTP_PASSWORD), diff --git a/backend/app/routers/sse.py b/backend/app/routers/sse.py index 8ea9ad6..5919d04 100644 --- a/backend/app/routers/sse.py +++ b/backend/app/routers/sse.py @@ -32,6 +32,7 @@ async def _get_sse_redis() -> aioredis.Redis: global _redis if _redis is None: from app.config import settings + _redis = aioredis.from_url(settings.REDIS_URL, decode_responses=True) return _redis @@ -70,7 +71,9 @@ async def _validate_sse_token(token: str) -> dict: async def event_stream( request: Request, tenant_id: uuid.UUID, - token: str = Query(..., description="Short-lived SSE exchange token (from POST /auth/sse-token)"), + token: str = Query( + ..., description="Short-lived SSE exchange token (from POST /auth/sse-token)" + ), ) -> EventSourceResponse: """Stream real-time events for a tenant via Server-Sent Events. @@ -87,7 +90,9 @@ async def event_stream( user_id = user_context.get("user_id", "") # Authorization: user must belong to the requested tenant or be super_admin - if user_role != "super_admin" and (user_tenant_id is None or str(user_tenant_id) != str(tenant_id)): + if user_role != "super_admin" and ( + user_tenant_id is None or str(user_tenant_id) != str(tenant_id) + ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized for this tenant", diff --git a/backend/app/routers/templates.py b/backend/app/routers/templates.py index eb56267..0541b59 100644 --- a/backend/app/routers/templates.py +++ b/backend/app/routers/templates.py @@ -21,7 +21,6 @@ RBAC: viewer = read (GET/preview); operator and above = write (POST/PUT/DELETE/p import asyncio import logging import uuid -from datetime import datetime, timezone from typing import Any, Optional from fastapi import APIRouter, Depends, HTTPException, Query, Request, status @@ -54,6 +53,7 @@ async def _check_tenant_access( """Verify the current user is allowed to access the given tenant.""" if current_user.is_super_admin: from app.database import set_tenant_context + await set_tenant_context(db, str(tenant_id)) return if current_user.tenant_id != tenant_id: @@ -191,7 +191,8 @@ async def create_template( if unmatched: logger.warning( "Template '%s' has undeclared variables: %s (auto-adding as string type)", - body.name, unmatched, + body.name, + unmatched, ) # Create template @@ -553,11 +554,13 @@ async def push_template( status="pending", ) db.add(job) - jobs_created.append({ - "job_id": str(job.id), - "device_id": str(device.id), - "device_hostname": device.hostname, - }) + jobs_created.append( + { + "job_id": str(job.id), + "device_id": str(device.id), + "device_hostname": device.hostname, + } + ) await db.flush() @@ -598,14 +601,16 @@ async def push_status( jobs = [] for job, hostname in rows: - jobs.append({ - "device_id": str(job.device_id), - "hostname": hostname, - "status": job.status, - "error_message": job.error_message, - "started_at": job.started_at.isoformat() if job.started_at else None, - "completed_at": job.completed_at.isoformat() if job.completed_at else None, - }) + jobs.append( + { + "device_id": str(job.device_id), + "hostname": hostname, + "status": job.status, + "error_message": job.error_message, + "started_at": job.started_at.isoformat() if job.started_at else None, + "completed_at": job.completed_at.isoformat() if job.completed_at else None, + } + ) return { "rollout_id": str(rollout_id), diff --git a/backend/app/routers/tenants.py b/backend/app/routers/tenants.py index 4532513..5ec6a2d 100644 --- a/backend/app/routers/tenants.py +++ b/backend/app/routers/tenants.py @@ -9,7 +9,6 @@ DELETE /api/tenants/{id} — delete tenant (super_admin only) """ import uuid -from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Request, status from sqlalchemy import func, select, text @@ -17,7 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.middleware.rate_limit import limiter -from app.database import get_admin_db, get_db +from app.database import get_admin_db from app.middleware.rbac import require_super_admin, require_tenant_admin_or_above from app.middleware.tenant_context import CurrentUser from app.models.device import Device @@ -70,15 +69,18 @@ async def list_tenants( else: if not current_user.tenant_id: return [] - result = await db.execute( - select(Tenant).where(Tenant.id == current_user.tenant_id) - ) + result = await db.execute(select(Tenant).where(Tenant.id == current_user.tenant_id)) tenants = result.scalars().all() return [await _get_tenant_response(tenant, db) for tenant in tenants] -@router.post("", response_model=TenantResponse, status_code=status.HTTP_201_CREATED, summary="Create a tenant") +@router.post( + "", + response_model=TenantResponse, + status_code=status.HTTP_201_CREATED, + summary="Create a tenant", +) @limiter.limit("20/minute") async def create_tenant( request: Request, @@ -108,13 +110,21 @@ async def create_tenant( ("Device Offline", "device_offline", "eq", 1, 1, "critical"), ] for name, metric, operator, threshold, duration, sev in default_rules: - await db.execute(text(""" + await db.execute( + text(""" INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default) VALUES (gen_random_uuid(), CAST(:tenant_id AS uuid), :name, :metric, :operator, :threshold, :duration, :severity, TRUE, TRUE) - """), { - "tenant_id": str(tenant.id), "name": name, "metric": metric, - "operator": operator, "threshold": threshold, "duration": duration, "severity": sev, - }) + """), + { + "tenant_id": str(tenant.id), + "name": name, + "metric": metric, + "operator": operator, + "threshold": threshold, + "duration": duration, + "severity": sev, + }, + ) await db.commit() # Seed starter config templates for new tenant @@ -131,6 +141,7 @@ async def create_tenant( await db.commit() except Exception as exc: import logging + logging.getLogger(__name__).warning( "OpenBao key provisioning failed for tenant %s (will be provisioned on next startup): %s", tenant.id, @@ -229,6 +240,7 @@ async def delete_tenant( # Check if tenant had VPN configured (before cascade deletes it) from app.services.vpn_service import get_vpn_config, sync_wireguard_config + had_vpn = await get_vpn_config(db, tenant_id) await db.delete(tenant) @@ -288,14 +300,54 @@ add chain=forward action=drop comment="Drop everything else" # Identity /system identity set name={{ device.hostname }}""", "variables": [ - {"name": "wan_interface", "type": "string", "default": "ether1", "description": "WAN-facing interface"}, - {"name": "lan_gateway", "type": "ip", "default": "192.168.88.1", "description": "LAN gateway IP"}, - {"name": "lan_cidr", "type": "integer", "default": "24", "description": "LAN subnet mask bits"}, - {"name": "lan_network", "type": "ip", "default": "192.168.88.0", "description": "LAN network address"}, - {"name": "dhcp_start", "type": "ip", "default": "192.168.88.100", "description": "DHCP pool start"}, - {"name": "dhcp_end", "type": "ip", "default": "192.168.88.254", "description": "DHCP pool end"}, - {"name": "dns_servers", "type": "string", "default": "8.8.8.8,8.8.4.4", "description": "Upstream DNS servers"}, - {"name": "ntp_server", "type": "string", "default": "pool.ntp.org", "description": "NTP server"}, + { + "name": "wan_interface", + "type": "string", + "default": "ether1", + "description": "WAN-facing interface", + }, + { + "name": "lan_gateway", + "type": "ip", + "default": "192.168.88.1", + "description": "LAN gateway IP", + }, + { + "name": "lan_cidr", + "type": "integer", + "default": "24", + "description": "LAN subnet mask bits", + }, + { + "name": "lan_network", + "type": "ip", + "default": "192.168.88.0", + "description": "LAN network address", + }, + { + "name": "dhcp_start", + "type": "ip", + "default": "192.168.88.100", + "description": "DHCP pool start", + }, + { + "name": "dhcp_end", + "type": "ip", + "default": "192.168.88.254", + "description": "DHCP pool end", + }, + { + "name": "dns_servers", + "type": "string", + "default": "8.8.8.8,8.8.4.4", + "description": "Upstream DNS servers", + }, + { + "name": "ntp_server", + "type": "string", + "default": "pool.ntp.org", + "description": "NTP server", + }, ], }, { @@ -311,8 +363,18 @@ add chain=forward connection-state=invalid action=drop add chain=forward src-address={{ allowed_network }} action=accept add chain=forward action=drop""", "variables": [ - {"name": "wan_interface", "type": "string", "default": "ether1", "description": "WAN-facing interface"}, - {"name": "allowed_network", "type": "subnet", "default": "192.168.88.0/24", "description": "Allowed source network"}, + { + "name": "wan_interface", + "type": "string", + "default": "ether1", + "description": "WAN-facing interface", + }, + { + "name": "allowed_network", + "type": "subnet", + "default": "192.168.88.0/24", + "description": "Allowed source network", + }, ], }, { @@ -322,11 +384,36 @@ add chain=forward action=drop""", /ip dhcp-server network add address={{ gateway }}/24 gateway={{ gateway }} dns-server={{ dns_server }} /ip dhcp-server add name=dhcp1 interface={{ interface }} address-pool=dhcp-pool disabled=no""", "variables": [ - {"name": "pool_start", "type": "ip", "default": "192.168.88.100", "description": "DHCP pool start address"}, - {"name": "pool_end", "type": "ip", "default": "192.168.88.254", "description": "DHCP pool end address"}, - {"name": "gateway", "type": "ip", "default": "192.168.88.1", "description": "Default gateway"}, - {"name": "dns_server", "type": "ip", "default": "8.8.8.8", "description": "DNS server address"}, - {"name": "interface", "type": "string", "default": "bridge-lan", "description": "Interface to serve DHCP on"}, + { + "name": "pool_start", + "type": "ip", + "default": "192.168.88.100", + "description": "DHCP pool start address", + }, + { + "name": "pool_end", + "type": "ip", + "default": "192.168.88.254", + "description": "DHCP pool end address", + }, + { + "name": "gateway", + "type": "ip", + "default": "192.168.88.1", + "description": "Default gateway", + }, + { + "name": "dns_server", + "type": "ip", + "default": "8.8.8.8", + "description": "DNS server address", + }, + { + "name": "interface", + "type": "string", + "default": "bridge-lan", + "description": "Interface to serve DHCP on", + }, ], }, { @@ -335,10 +422,30 @@ add chain=forward action=drop""", "content": """/interface wireless security-profiles add name=portal-wpa2 mode=dynamic-keys authentication-types=wpa2-psk wpa2-pre-shared-key={{ password }} /interface wireless set wlan1 mode=ap-bridge ssid={{ ssid }} security-profile=portal-wpa2 frequency={{ frequency }} channel-width={{ channel_width }} disabled=no""", "variables": [ - {"name": "ssid", "type": "string", "default": "MikroTik-AP", "description": "Wireless network name"}, - {"name": "password", "type": "string", "default": "", "description": "WPA2 pre-shared key (min 8 characters)"}, - {"name": "frequency", "type": "integer", "default": "2412", "description": "Wireless frequency in MHz"}, - {"name": "channel_width", "type": "string", "default": "20/40mhz-XX", "description": "Channel width setting"}, + { + "name": "ssid", + "type": "string", + "default": "MikroTik-AP", + "description": "Wireless network name", + }, + { + "name": "password", + "type": "string", + "default": "", + "description": "WPA2 pre-shared key (min 8 characters)", + }, + { + "name": "frequency", + "type": "integer", + "default": "2412", + "description": "Wireless frequency in MHz", + }, + { + "name": "channel_width", + "type": "string", + "default": "20/40mhz-XX", + "description": "Channel width setting", + }, ], }, { @@ -351,8 +458,18 @@ add chain=forward action=drop""", /ip service set ssh port=22 /ip service set winbox port=8291""", "variables": [ - {"name": "ntp_server", "type": "ip", "default": "pool.ntp.org", "description": "NTP server address"}, - {"name": "dns_servers", "type": "string", "default": "8.8.8.8,8.8.4.4", "description": "Comma-separated DNS servers"}, + { + "name": "ntp_server", + "type": "ip", + "default": "pool.ntp.org", + "description": "NTP server address", + }, + { + "name": "dns_servers", + "type": "string", + "default": "8.8.8.8,8.8.4.4", + "description": "Comma-separated DNS servers", + }, ], }, ] @@ -363,13 +480,16 @@ async def _seed_starter_templates(db, tenant_id) -> None: import json as _json for tmpl in _STARTER_TEMPLATES: - await db.execute(text(""" + await db.execute( + text(""" INSERT INTO config_templates (id, tenant_id, name, description, content, variables) VALUES (gen_random_uuid(), CAST(:tid AS uuid), :name, :desc, :content, CAST(:vars AS jsonb)) - """), { - "tid": str(tenant_id), - "name": tmpl["name"], - "desc": tmpl["description"], - "content": tmpl["content"], - "vars": _json.dumps(tmpl["variables"]), - }) + """), + { + "tid": str(tenant_id), + "name": tmpl["name"], + "desc": tmpl["description"], + "content": tmpl["content"], + "vars": _json.dumps(tmpl["variables"]), + }, + ) diff --git a/backend/app/routers/topology.py b/backend/app/routers/topology.py index ab928d1..a1d1740 100644 --- a/backend/app/routers/topology.py +++ b/backend/app/routers/topology.py @@ -14,7 +14,6 @@ Builds a topology graph of managed devices by: import asyncio import ipaddress import json -import logging import uuid from typing import Any @@ -265,7 +264,7 @@ async def get_topology( nodes: list[TopologyNode] = [] ip_to_device: dict[str, str] = {} online_device_ids: list[str] = [] - devices_by_id: dict[str, Any] = {} + _devices_by_id: dict[str, Any] = {} for row in rows: device_id = str(row.id) @@ -288,9 +287,7 @@ async def get_topology( if online_device_ids: tasks = [ - routeros_proxy.execute_command( - device_id, "/ip/neighbor/print", timeout=10.0 - ) + routeros_proxy.execute_command(device_id, "/ip/neighbor/print", timeout=10.0) for device_id in online_device_ids ] results = await asyncio.gather(*tasks, return_exceptions=True) diff --git a/backend/app/routers/transparency.py b/backend/app/routers/transparency.py index 06ad16c..8e8fdf8 100644 --- a/backend/app/routers/transparency.py +++ b/backend/app/routers/transparency.py @@ -164,9 +164,7 @@ async def list_transparency_logs( # Count total count_result = await db.execute( - select(func.count()) - .select_from(text("key_access_log k")) - .where(where_clause), + select(func.count()).select_from(text("key_access_log k")).where(where_clause), params, ) total = count_result.scalar() or 0 @@ -353,39 +351,41 @@ async def export_transparency_logs( output = io.StringIO() writer = csv.writer(output) - writer.writerow([ - "ID", - "Action", - "Device Name", - "Device ID", - "Justification", - "Operator Email", - "Correlation ID", - "Resource Type", - "Resource ID", - "IP Address", - "Timestamp", - ]) + writer.writerow( + [ + "ID", + "Action", + "Device Name", + "Device ID", + "Justification", + "Operator Email", + "Correlation ID", + "Resource Type", + "Resource ID", + "IP Address", + "Timestamp", + ] + ) for row in all_rows: - writer.writerow([ - str(row["id"]), - row["action"], - row["device_name"] or "", - str(row["device_id"]) if row["device_id"] else "", - row["justification"] or "", - row["operator_email"] or "", - row["correlation_id"] or "", - row["resource_type"] or "", - row["resource_id"] or "", - row["ip_address"] or "", - str(row["created_at"]), - ]) + writer.writerow( + [ + str(row["id"]), + row["action"], + row["device_name"] or "", + str(row["device_id"]) if row["device_id"] else "", + row["justification"] or "", + row["operator_email"] or "", + row["correlation_id"] or "", + row["resource_type"] or "", + row["resource_id"] or "", + row["ip_address"] or "", + str(row["created_at"]), + ] + ) output.seek(0) return StreamingResponse( iter([output.getvalue()]), media_type="text/csv", - headers={ - "Content-Disposition": "attachment; filename=transparency-logs.csv" - }, + headers={"Content-Disposition": "attachment; filename=transparency-logs.csv"}, ) diff --git a/backend/app/routers/users.py b/backend/app/routers/users.py index 0d85fe2..9237ac0 100644 --- a/backend/app/routers/users.py +++ b/backend/app/routers/users.py @@ -20,7 +20,7 @@ from app.database import get_admin_db from app.middleware.rbac import require_tenant_admin_or_above from app.middleware.tenant_context import CurrentUser from app.models.tenant import Tenant -from app.models.user import User, UserRole +from app.models.user import User from app.schemas.user import UserCreate, UserResponse, UserUpdate from app.services.auth import hash_password @@ -69,11 +69,7 @@ async def list_users( """ await _check_tenant_access(tenant_id, current_user, db) - result = await db.execute( - select(User) - .where(User.tenant_id == tenant_id) - .order_by(User.name) - ) + result = await db.execute(select(User).where(User.tenant_id == tenant_id).order_by(User.name)) users = result.scalars().all() return [UserResponse.model_validate(user) for user in users] @@ -103,9 +99,7 @@ async def create_user( await _check_tenant_access(tenant_id, current_user, db) # Check email uniqueness (global, not per-tenant) - existing = await db.execute( - select(User).where(User.email == data.email.lower()) - ) + existing = await db.execute(select(User).where(User.email == data.email.lower())) if existing.scalar_one_or_none(): raise HTTPException( status_code=status.HTTP_409_CONFLICT, @@ -138,9 +132,7 @@ async def get_user( """Get user detail.""" await _check_tenant_access(tenant_id, current_user, db) - result = await db.execute( - select(User).where(User.id == user_id, User.tenant_id == tenant_id) - ) + result = await db.execute(select(User).where(User.id == user_id, User.tenant_id == tenant_id)) user = result.scalar_one_or_none() if not user: @@ -168,9 +160,7 @@ async def update_user( """ await _check_tenant_access(tenant_id, current_user, db) - result = await db.execute( - select(User).where(User.id == user_id, User.tenant_id == tenant_id) - ) + result = await db.execute(select(User).where(User.id == user_id, User.tenant_id == tenant_id)) user = result.scalar_one_or_none() if not user: @@ -194,7 +184,11 @@ async def update_user( return UserResponse.model_validate(user) -@router.delete("/{tenant_id}/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Deactivate a user") +@router.delete( + "/{tenant_id}/users/{user_id}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Deactivate a user", +) @limiter.limit("5/minute") async def deactivate_user( request: Request, @@ -209,9 +203,7 @@ async def deactivate_user( """ await _check_tenant_access(tenant_id, current_user, db) - result = await db.execute( - select(User).where(User.id == user_id, User.tenant_id == tenant_id) - ) + result = await db.execute(select(User).where(User.id == user_id, User.tenant_id == tenant_id)) user = result.scalar_one_or_none() if not user: diff --git a/backend/app/routers/vpn.py b/backend/app/routers/vpn.py index 4c499f4..7511c01 100644 --- a/backend/app/routers/vpn.py +++ b/backend/app/routers/vpn.py @@ -68,7 +68,11 @@ async def get_vpn_config( return resp -@router.post("/tenants/{tenant_id}/vpn", response_model=VpnConfigResponse, status_code=status.HTTP_201_CREATED) +@router.post( + "/tenants/{tenant_id}/vpn", + response_model=VpnConfigResponse, + status_code=status.HTTP_201_CREATED, +) @limiter.limit("20/minute") async def setup_vpn( request: Request, @@ -177,7 +181,11 @@ async def list_peers( return responses -@router.post("/tenants/{tenant_id}/vpn/peers", response_model=VpnPeerResponse, status_code=status.HTTP_201_CREATED) +@router.post( + "/tenants/{tenant_id}/vpn/peers", + response_model=VpnPeerResponse, + status_code=status.HTTP_201_CREATED, +) @limiter.limit("20/minute") async def add_peer( request: Request, @@ -190,7 +198,9 @@ async def add_peer( await _check_tenant_access(current_user, tenant_id, db) _require_operator(current_user) try: - peer = await vpn_service.add_peer(db, tenant_id, body.device_id, additional_allowed_ips=body.additional_allowed_ips) + peer = await vpn_service.add_peer( + db, tenant_id, body.device_id, additional_allowed_ips=body.additional_allowed_ips + ) except ValueError as e: msg = str(e) if "must not overlap" in msg: @@ -208,7 +218,11 @@ async def add_peer( return resp -@router.post("/tenants/{tenant_id}/vpn/peers/onboard", response_model=VpnOnboardResponse, status_code=status.HTTP_201_CREATED) +@router.post( + "/tenants/{tenant_id}/vpn/peers/onboard", + response_model=VpnOnboardResponse, + status_code=status.HTTP_201_CREATED, +) @limiter.limit("10/minute") async def onboard_device( request: Request, @@ -222,7 +236,8 @@ async def onboard_device( _require_operator(current_user) try: result = await vpn_service.onboard_device( - db, tenant_id, + db, + tenant_id, hostname=body.hostname, username=body.username, password=body.password, diff --git a/backend/app/routers/winbox_remote.py b/backend/app/routers/winbox_remote.py index 3518514..5c804c2 100644 --- a/backend/app/routers/winbox_remote.py +++ b/backend/app/routers/winbox_remote.py @@ -146,16 +146,16 @@ async def _delete_session_from_redis(session_id: str) -> None: await rd.delete(f"{REDIS_PREFIX}{session_id}") -async def _open_tunnel( - device_id: uuid.UUID, tenant_id: uuid.UUID, user_id: uuid.UUID -) -> dict: +async def _open_tunnel(device_id: uuid.UUID, tenant_id: uuid.UUID, user_id: uuid.UUID) -> dict: """Open a TCP tunnel to device port 8291 via NATS request-reply.""" - payload = json.dumps({ - "device_id": str(device_id), - "tenant_id": str(tenant_id), - "user_id": str(user_id), - "target_port": 8291, - }).encode() + payload = json.dumps( + { + "device_id": str(device_id), + "tenant_id": str(tenant_id), + "user_id": str(user_id), + "target_port": 8291, + } + ).encode() try: nc = await _get_nats() @@ -176,9 +176,7 @@ async def _open_tunnel( ) if "error" in data: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=data["error"] - ) + raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=data["error"]) return data @@ -250,9 +248,7 @@ async def create_winbox_remote_session( except Exception: worker_info = None if worker_info is None: - logger.warning( - "Cleaning stale Redis session %s (worker 404)", stale_sid - ) + logger.warning("Cleaning stale Redis session %s (worker 404)", stale_sid) tunnel_id = sess.get("tunnel_id") if tunnel_id: await _close_tunnel(tunnel_id) @@ -333,12 +329,8 @@ async def create_winbox_remote_session( username = "" # noqa: F841 password = "" # noqa: F841 - expires_at = datetime.fromisoformat( - worker_resp.get("expires_at", now.isoformat()) - ) - max_expires_at = datetime.fromisoformat( - worker_resp.get("max_expires_at", now.isoformat()) - ) + expires_at = datetime.fromisoformat(worker_resp.get("expires_at", now.isoformat())) + max_expires_at = datetime.fromisoformat(worker_resp.get("max_expires_at", now.isoformat())) # Save session to Redis session_data = { @@ -375,8 +367,7 @@ async def create_winbox_remote_session( pass ws_path = ( - f"/api/tenants/{tenant_id}/devices/{device_id}" - f"/winbox-remote-sessions/{session_id}/ws" + f"/api/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/ws" ) return RemoteWinboxSessionResponse( @@ -425,14 +416,10 @@ async def get_winbox_remote_session( sess = await _get_session_from_redis(str(session_id)) if sess is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Session not found" - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found") if sess.get("tenant_id") != str(tenant_id) or sess.get("device_id") != str(device_id): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Session not found" - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found") return RemoteWinboxStatusResponse( session_id=uuid.UUID(sess["session_id"]), @@ -478,10 +465,7 @@ async def list_winbox_remote_sessions( sess = json.loads(raw) except Exception: continue - if ( - sess.get("tenant_id") == str(tenant_id) - and sess.get("device_id") == str(device_id) - ): + if sess.get("tenant_id") == str(tenant_id) and sess.get("device_id") == str(device_id): sessions.append( RemoteWinboxStatusResponse( session_id=uuid.UUID(sess["session_id"]), @@ -533,9 +517,7 @@ async def terminate_winbox_remote_session( ) if sess.get("tenant_id") != str(tenant_id): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Session not found" - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found") # Rollback order: worker -> tunnel -> redis -> audit await worker_terminate_session(str(session_id)) @@ -574,14 +556,12 @@ async def terminate_winbox_remote_session( @router.get( - "/tenants/{tenant_id}/devices/{device_id}" - "/winbox-remote-sessions/{session_id}/xpra/{path:path}", + "/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/xpra/{path:path}", summary="Proxy Xpra HTML5 client files", dependencies=[Depends(require_operator_or_above)], ) @router.get( - "/tenants/{tenant_id}/devices/{device_id}" - "/winbox-remote-sessions/{session_id}/xpra", + "/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/xpra", summary="Proxy Xpra HTML5 client (root)", dependencies=[Depends(require_operator_or_above)], ) @@ -626,7 +606,8 @@ async def proxy_xpra_html( content=proxy_resp.content, status_code=proxy_resp.status_code, headers={ - k: v for k, v in proxy_resp.headers.items() + k: v + for k, v in proxy_resp.headers.items() if k.lower() in ("content-type", "cache-control", "content-encoding") }, ) @@ -637,9 +618,7 @@ async def proxy_xpra_html( # --------------------------------------------------------------------------- -@router.websocket( - "/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/ws" -) +@router.websocket("/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/ws") async def winbox_remote_ws_proxy( websocket: WebSocket, tenant_id: uuid.UUID, diff --git a/backend/app/schemas/auth.py b/backend/app/schemas/auth.py index eff69e5..bb182f7 100644 --- a/backend/app/schemas/auth.py +++ b/backend/app/schemas/auth.py @@ -67,11 +67,13 @@ class MessageResponse(BaseModel): class SRPInitRequest(BaseModel): """Step 1 request: client sends email to begin SRP handshake.""" + email: EmailStr class SRPInitResponse(BaseModel): """Step 1 response: server returns ephemeral B and key derivation salts.""" + salt: str # hex-encoded SRP salt server_public: str # hex-encoded server ephemeral B session_id: str # Redis session key nonce @@ -81,6 +83,7 @@ class SRPInitResponse(BaseModel): class SRPVerifyRequest(BaseModel): """Step 2 request: client sends proof M1 to complete handshake.""" + email: EmailStr session_id: str client_public: str # hex-encoded client ephemeral A @@ -89,6 +92,7 @@ class SRPVerifyRequest(BaseModel): class SRPVerifyResponse(BaseModel): """Step 2 response: server returns tokens and proof M2.""" + access_token: str refresh_token: str token_type: str = "bearer" @@ -98,6 +102,7 @@ class SRPVerifyResponse(BaseModel): class SRPRegisterRequest(BaseModel): """Used during registration to store SRP verifier and key set.""" + srp_salt: str # hex-encoded srp_verifier: str # hex-encoded encrypted_private_key: str # base64-encoded @@ -114,10 +119,12 @@ class SRPRegisterRequest(BaseModel): class DeleteAccountRequest(BaseModel): """Request body for account self-deletion. User must type 'DELETE' to confirm.""" + confirmation: str # Must be "DELETE" to confirm class DeleteAccountResponse(BaseModel): """Response after successful account deletion.""" + message: str deleted: bool diff --git a/backend/app/schemas/certificate.py b/backend/app/schemas/certificate.py index 08aa9c1..e7b3e96 100644 --- a/backend/app/schemas/certificate.py +++ b/backend/app/schemas/certificate.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, ConfigDict # Request schemas # --------------------------------------------------------------------------- + class CACreateRequest(BaseModel): """Request to generate a new root CA for the tenant.""" @@ -34,6 +35,7 @@ class BulkCertDeployRequest(BaseModel): # Response schemas # --------------------------------------------------------------------------- + class CAResponse(BaseModel): """Public details of a tenant's Certificate Authority (no private key).""" diff --git a/backend/app/schemas/device.py b/backend/app/schemas/device.py index 1cf46f7..adc8878 100644 --- a/backend/app/schemas/device.py +++ b/backend/app/schemas/device.py @@ -117,6 +117,7 @@ class SubnetScanRequest(BaseModel): def validate_cidr(cls, v: str) -> str: """Validate that the value is a valid CIDR notation and RFC 1918 private range.""" import ipaddress + try: network = ipaddress.ip_network(v, strict=False) except ValueError as e: @@ -239,6 +240,7 @@ class DeviceTagCreate(BaseModel): if v is None: return v import re + if not re.match(r"^#[0-9A-Fa-f]{6}$", v): raise ValueError("Color must be a valid 6-digit hex color (e.g. #FF5733)") return v @@ -256,6 +258,7 @@ class DeviceTagUpdate(BaseModel): if v is None: return v import re + if not re.match(r"^#[0-9A-Fa-f]{6}$", v): raise ValueError("Color must be a valid 6-digit hex color (e.g. #FF5733)") return v diff --git a/backend/app/schemas/vpn.py b/backend/app/schemas/vpn.py index d36d872..0f383a8 100644 --- a/backend/app/schemas/vpn.py +++ b/backend/app/schemas/vpn.py @@ -12,11 +12,15 @@ from pydantic import BaseModel class VpnSetupRequest(BaseModel): """Request to enable VPN for a tenant.""" - endpoint: Optional[str] = None # public hostname:port — if blank, devices must be configured manually + + endpoint: Optional[str] = ( + None # public hostname:port — if blank, devices must be configured manually + ) class VpnConfigResponse(BaseModel): """VPN server configuration (never exposes private key).""" + model_config = {"from_attributes": True} id: uuid.UUID @@ -33,6 +37,7 @@ class VpnConfigResponse(BaseModel): class VpnConfigUpdate(BaseModel): """Update VPN configuration.""" + endpoint: Optional[str] = None is_enabled: Optional[bool] = None @@ -42,12 +47,14 @@ class VpnConfigUpdate(BaseModel): class VpnPeerCreate(BaseModel): """Add a device as a VPN peer.""" + device_id: uuid.UUID additional_allowed_ips: Optional[str] = None # comma-separated subnets for site-to-site routing class VpnPeerResponse(BaseModel): """VPN peer info (never exposes private key).""" + model_config = {"from_attributes": True} id: uuid.UUID @@ -66,6 +73,7 @@ class VpnPeerResponse(BaseModel): class VpnOnboardRequest(BaseModel): """Combined device creation + VPN peer onboarding.""" + hostname: str username: str password: str @@ -73,6 +81,7 @@ class VpnOnboardRequest(BaseModel): class VpnOnboardResponse(BaseModel): """Response from onboarding — device, peer, and RouterOS commands.""" + device_id: uuid.UUID peer_id: uuid.UUID hostname: str @@ -82,6 +91,7 @@ class VpnOnboardResponse(BaseModel): class VpnPeerConfig(BaseModel): """Full peer config for display/export — includes private key for device setup.""" + peer_private_key: str peer_public_key: str assigned_ip: str diff --git a/backend/app/schemas/winbox_remote.py b/backend/app/schemas/winbox_remote.py index 2e10331..b3d2c25 100644 --- a/backend/app/schemas/winbox_remote.py +++ b/backend/app/schemas/winbox_remote.py @@ -57,6 +57,7 @@ class RemoteWinboxDuplicateDetail(BaseModel): class RemoteWinboxSessionItem(BaseModel): """Used in the combined active sessions list.""" + session_id: uuid.UUID status: RemoteWinboxState created_at: datetime diff --git a/backend/app/services/account_service.py b/backend/app/services/account_service.py index 5339974..af46ddb 100644 --- a/backend/app/services/account_service.py +++ b/backend/app/services/account_service.py @@ -86,11 +86,7 @@ async def delete_user_account( # Null out encrypted_details (may contain encrypted PII) await db.execute( - text( - "UPDATE audit_logs " - "SET encrypted_details = NULL " - "WHERE user_id = :user_id" - ), + text("UPDATE audit_logs SET encrypted_details = NULL WHERE user_id = :user_id"), {"user_id": user_id}, ) @@ -177,16 +173,18 @@ async def export_user_data( ) api_keys = [] for row in result.mappings().all(): - api_keys.append({ - "id": str(row["id"]), - "name": row["name"], - "key_prefix": row["key_prefix"], - "scopes": row["scopes"], - "created_at": row["created_at"].isoformat() if row["created_at"] else None, - "expires_at": row["expires_at"].isoformat() if row["expires_at"] else None, - "revoked_at": row["revoked_at"].isoformat() if row["revoked_at"] else None, - "last_used_at": row["last_used_at"].isoformat() if row["last_used_at"] else None, - }) + api_keys.append( + { + "id": str(row["id"]), + "name": row["name"], + "key_prefix": row["key_prefix"], + "scopes": row["scopes"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + "expires_at": row["expires_at"].isoformat() if row["expires_at"] else None, + "revoked_at": row["revoked_at"].isoformat() if row["revoked_at"] else None, + "last_used_at": row["last_used_at"].isoformat() if row["last_used_at"] else None, + } + ) # ── Audit logs (limit 1000, most recent first) ─────────────────────── result = await db.execute( @@ -201,15 +199,17 @@ async def export_user_data( audit_logs = [] for row in result.mappings().all(): details = row["details"] if row["details"] else {} - audit_logs.append({ - "id": str(row["id"]), - "action": row["action"], - "resource_type": row["resource_type"], - "resource_id": row["resource_id"], - "details": details, - "ip_address": row["ip_address"], - "created_at": row["created_at"].isoformat() if row["created_at"] else None, - }) + audit_logs.append( + { + "id": str(row["id"]), + "action": row["action"], + "resource_type": row["resource_type"], + "resource_id": row["resource_id"], + "details": details, + "ip_address": row["ip_address"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + ) # ── Key access log (limit 1000, most recent first) ─────────────────── result = await db.execute( @@ -222,13 +222,15 @@ async def export_user_data( ) key_access_entries = [] for row in result.mappings().all(): - key_access_entries.append({ - "id": str(row["id"]), - "action": row["action"], - "resource_type": row["resource_type"], - "ip_address": row["ip_address"], - "created_at": row["created_at"].isoformat() if row["created_at"] else None, - }) + key_access_entries.append( + { + "id": str(row["id"]), + "action": row["action"], + "resource_type": row["resource_type"], + "ip_address": row["ip_address"], + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + ) return { "export_date": datetime.now(UTC).isoformat(), diff --git a/backend/app/services/alert_evaluator.py b/backend/app/services/alert_evaluator.py index e6d474b..13ff2b7 100644 --- a/backend/app/services/alert_evaluator.py +++ b/backend/app/services/alert_evaluator.py @@ -253,7 +253,9 @@ async def _get_device_groups(device_id: str) -> list[str]: """Get group IDs for a device.""" async with AdminAsyncSessionLocal() as session: result = await session.execute( - text("SELECT group_id FROM device_group_memberships WHERE device_id = CAST(:device_id AS uuid)"), + text( + "SELECT group_id FROM device_group_memberships WHERE device_id = CAST(:device_id AS uuid)" + ), {"device_id": device_id}, ) return [str(row[0]) for row in result.fetchall()] @@ -344,30 +346,36 @@ async def _create_alert_event( # Publish real-time event to NATS for SSE pipeline (fire-and-forget) if status in ("firing", "flapping"): - await publish_event(f"alert.fired.{tenant_id}", { - "event_type": "alert_fired", - "tenant_id": tenant_id, - "device_id": device_id, - "alert_event_id": alert_data["id"], - "severity": severity, - "metric": metric, - "current_value": value, - "threshold": threshold, - "message": message, - "is_flapping": is_flapping, - "fired_at": datetime.now(timezone.utc).isoformat(), - }) + await publish_event( + f"alert.fired.{tenant_id}", + { + "event_type": "alert_fired", + "tenant_id": tenant_id, + "device_id": device_id, + "alert_event_id": alert_data["id"], + "severity": severity, + "metric": metric, + "current_value": value, + "threshold": threshold, + "message": message, + "is_flapping": is_flapping, + "fired_at": datetime.now(timezone.utc).isoformat(), + }, + ) elif status == "resolved": - await publish_event(f"alert.resolved.{tenant_id}", { - "event_type": "alert_resolved", - "tenant_id": tenant_id, - "device_id": device_id, - "alert_event_id": alert_data["id"], - "severity": severity, - "metric": metric, - "message": message, - "resolved_at": datetime.now(timezone.utc).isoformat(), - }) + await publish_event( + f"alert.resolved.{tenant_id}", + { + "event_type": "alert_resolved", + "tenant_id": tenant_id, + "device_id": device_id, + "alert_event_id": alert_data["id"], + "severity": severity, + "metric": metric, + "message": message, + "resolved_at": datetime.now(timezone.utc).isoformat(), + }, + ) return alert_data @@ -470,6 +478,7 @@ async def _dispatch_async(alert_event: dict, channels: list[dict], device_hostna """Fire-and-forget notification dispatch.""" try: from app.services.notification_service import dispatch_notifications + await dispatch_notifications(alert_event, channels, device_hostname) except Exception as e: logger.warning("Notification dispatch failed: %s", e) @@ -500,7 +509,8 @@ async def evaluate( if await _is_device_in_maintenance(tenant_id, device_id): logger.debug( "Alert suppressed by maintenance window for device %s tenant %s", - device_id, tenant_id, + device_id, + tenant_id, ) return @@ -573,7 +583,8 @@ async def evaluate( if is_flapping: logger.info( "Alert %s for device %s is flapping — notifications suppressed", - rule["name"], device_id, + rule["name"], + device_id, ) else: channels = await _get_channels_for_rule(rule["id"]) diff --git a/backend/app/services/audit_service.py b/backend/app/services/audit_service.py index 05c9f26..809d493 100644 --- a/backend/app/services/audit_service.py +++ b/backend/app/services/audit_service.py @@ -56,9 +56,7 @@ async def log_action( try: from app.services.crypto import encrypt_data_transit - encrypted_details = await encrypt_data_transit( - details_json, str(tenant_id) - ) + encrypted_details = await encrypt_data_transit(details_json, str(tenant_id)) # Encryption succeeded — clear plaintext details details_json = _json.dumps({}) except Exception: diff --git a/backend/app/services/backup_scheduler.py b/backend/app/services/backup_scheduler.py index 6a20fe0..eff7d1d 100644 --- a/backend/app/services/backup_scheduler.py +++ b/backend/app/services/backup_scheduler.py @@ -32,8 +32,12 @@ def _cron_to_trigger(cron_expr: str) -> Optional[CronTrigger]: return None minute, hour, day, month, day_of_week = parts return CronTrigger( - minute=minute, hour=hour, day=day, month=month, - day_of_week=day_of_week, timezone="UTC", + minute=minute, + hour=hour, + day=day, + month=month, + day_of_week=day_of_week, + timezone="UTC", ) except Exception as e: logger.warning("Invalid cron expression '%s': %s", cron_expr, e) @@ -52,10 +56,12 @@ def build_schedule_map(schedules: list) -> dict[str, list[dict]]: cron = s.cron_expression or DEFAULT_CRON if cron not in schedule_map: schedule_map[cron] = [] - schedule_map[cron].append({ - "device_id": str(s.device_id), - "tenant_id": str(s.tenant_id), - }) + schedule_map[cron].append( + { + "device_id": str(s.device_id), + "tenant_id": str(s.tenant_id), + } + ) return schedule_map @@ -79,13 +85,15 @@ async def _run_scheduled_backups(devices: list[dict]) -> None: except Exception as e: logger.error( "Scheduled backup FAILED: device %s: %s", - dev_info["device_id"], e, + dev_info["device_id"], + e, ) failure_count += 1 logger.info( "Backup batch complete — %d succeeded, %d failed", - success_count, failure_count, + success_count, + failure_count, ) @@ -108,7 +116,7 @@ async def _load_effective_schedules() -> list: # Index: device-specific and tenant defaults device_schedules = {} # device_id -> schedule - tenant_defaults = {} # tenant_id -> schedule + tenant_defaults = {} # tenant_id -> schedule for s in schedules: if s.device_id: @@ -129,12 +137,14 @@ async def _load_effective_schedules() -> list: # No schedule configured — use system default sched = None - effective.append(SimpleNamespace( - device_id=dev_id, - tenant_id=tenant_id, - cron_expression=sched.cron_expression if sched else DEFAULT_CRON, - enabled=sched.enabled if sched else True, - )) + effective.append( + SimpleNamespace( + device_id=dev_id, + tenant_id=tenant_id, + cron_expression=sched.cron_expression if sched else DEFAULT_CRON, + enabled=sched.enabled if sched else True, + ) + ) return effective @@ -203,6 +213,7 @@ class _SchedulerProxy: Usage: `from app.services.backup_scheduler import backup_scheduler` then `backup_scheduler.add_job(...)`. """ + def __getattr__(self, name): if _scheduler is None: raise RuntimeError("Backup scheduler not started yet") diff --git a/backend/app/services/backup_service.py b/backend/app/services/backup_service.py index e9a50fd..5ad0d8e 100644 --- a/backend/app/services/backup_service.py +++ b/backend/app/services/backup_service.py @@ -255,7 +255,12 @@ async def run_backup( if prior_commits: try: prior_export_bytes = await loop.run_in_executor( - None, git_store.read_file, tenant_id, prior_commits[0]["sha"], device_id, "export.rsc" + None, + git_store.read_file, + tenant_id, + prior_commits[0]["sha"], + device_id, + "export.rsc", ) prior_text = prior_export_bytes.decode("utf-8", errors="replace") lines_added, lines_removed = await loop.run_in_executor( @@ -284,9 +289,7 @@ async def run_backup( try: from app.services.crypto import encrypt_data_transit - encrypted_export = await encrypt_data_transit( - export_text, tenant_id - ) + encrypted_export = await encrypt_data_transit(export_text, tenant_id) encrypted_binary = await encrypt_data_transit( base64.b64encode(binary_backup).decode(), tenant_id ) @@ -302,8 +305,7 @@ async def run_backup( except Exception as enc_err: # Transit unavailable — fall back to plaintext (non-fatal) logger.warning( - "Transit encryption failed for %s backup of device %s, " - "storing plaintext: %s", + "Transit encryption failed for %s backup of device %s, storing plaintext: %s", trigger_type, device_id, enc_err, @@ -313,9 +315,7 @@ async def run_backup( # ----------------------------------------------------------------------- # Step 6: Commit to git (wrapped in run_in_executor — pygit2 is sync C bindings) # ----------------------------------------------------------------------- - commit_message = ( - f"{trigger_type}: {hostname} ({ip}) at {ts}" - ) + commit_message = f"{trigger_type}: {hostname} ({ip}) at {ts}" commit_sha = await loop.run_in_executor( None, diff --git a/backend/app/services/ca_service.py b/backend/app/services/ca_service.py index ba5c3cf..6b14ddb 100644 --- a/backend/app/services/ca_service.py +++ b/backend/app/services/ca_service.py @@ -48,6 +48,7 @@ _VALID_TRANSITIONS: dict[str, set[str]] = { # CA Generation # --------------------------------------------------------------------------- + async def generate_ca( db: AsyncSession, tenant_id: UUID, @@ -84,10 +85,12 @@ async def generate_ca( now = datetime.datetime.now(datetime.timezone.utc) expiry = now + datetime.timedelta(days=365 * validity_years) - subject = issuer = x509.Name([ - x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"), - x509.NameAttribute(NameOID.COMMON_NAME, common_name), - ]) + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"), + x509.NameAttribute(NameOID.COMMON_NAME, common_name), + ] + ) ca_cert = ( x509.CertificateBuilder() @@ -97,9 +100,7 @@ async def generate_ca( .serial_number(x509.random_serial_number()) .not_valid_before(now) .not_valid_after(expiry) - .add_extension( - x509.BasicConstraints(ca=True, path_length=0), critical=True - ) + .add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True) .add_extension( x509.KeyUsage( digital_signature=True, @@ -166,6 +167,7 @@ async def generate_ca( # Device Certificate Signing # --------------------------------------------------------------------------- + async def sign_device_cert( db: AsyncSession, ca: CertificateAuthority, @@ -196,9 +198,7 @@ async def sign_device_cert( str(ca.tenant_id), encryption_key, ) - ca_key = serialization.load_pem_private_key( - ca_key_pem.encode("utf-8"), password=None - ) + ca_key = serialization.load_pem_private_key(ca_key_pem.encode("utf-8"), password=None) # Load CA certificate for issuer info and AuthorityKeyIdentifier ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8")) @@ -212,19 +212,19 @@ async def sign_device_cert( device_cert = ( x509.CertificateBuilder() .subject_name( - x509.Name([ - x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"), - x509.NameAttribute(NameOID.COMMON_NAME, hostname), - ]) + x509.Name( + [ + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"), + x509.NameAttribute(NameOID.COMMON_NAME, hostname), + ] + ) ) .issuer_name(ca_cert.subject) .public_key(device_key.public_key()) .serial_number(x509.random_serial_number()) .not_valid_before(now) .not_valid_after(expiry) - .add_extension( - x509.BasicConstraints(ca=False, path_length=None), critical=True - ) + .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True) .add_extension( x509.KeyUsage( digital_signature=True, @@ -244,17 +244,17 @@ async def sign_device_cert( critical=False, ) .add_extension( - x509.SubjectAlternativeName([ - x509.IPAddress(ipaddress.ip_address(ip_address)), - x509.DNSName(hostname), - ]), + x509.SubjectAlternativeName( + [ + x509.IPAddress(ipaddress.ip_address(ip_address)), + x509.DNSName(hostname), + ] + ), critical=False, ) .add_extension( x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier( - ca_cert.extensions.get_extension_for_class( - x509.SubjectKeyIdentifier - ).value + ca_cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier).value ), critical=False, ) @@ -308,15 +308,14 @@ async def sign_device_cert( # Queries # --------------------------------------------------------------------------- + async def get_ca_for_tenant( db: AsyncSession, tenant_id: UUID, ) -> CertificateAuthority | None: """Return the tenant's CA, or None if not yet initialized.""" result = await db.execute( - select(CertificateAuthority).where( - CertificateAuthority.tenant_id == tenant_id - ) + select(CertificateAuthority).where(CertificateAuthority.tenant_id == tenant_id) ) return result.scalar_one_or_none() @@ -352,6 +351,7 @@ async def get_device_certs( # Status Management # --------------------------------------------------------------------------- + async def update_cert_status( db: AsyncSession, cert_id: UUID, @@ -377,9 +377,7 @@ async def update_cert_status( Raises: ValueError: If the certificate is not found or the transition is invalid. """ - result = await db.execute( - select(DeviceCertificate).where(DeviceCertificate.id == cert_id) - ) + result = await db.execute(select(DeviceCertificate).where(DeviceCertificate.id == cert_id)) cert = result.scalar_one_or_none() if cert is None: raise ValueError(f"Device certificate {cert_id} not found") @@ -413,6 +411,7 @@ async def update_cert_status( # Cert Data for Deployment # --------------------------------------------------------------------------- + async def get_cert_for_deploy( db: AsyncSession, cert_id: UUID, @@ -434,18 +433,14 @@ async def get_cert_for_deploy( Raises: ValueError: If the certificate or its CA is not found. """ - result = await db.execute( - select(DeviceCertificate).where(DeviceCertificate.id == cert_id) - ) + result = await db.execute(select(DeviceCertificate).where(DeviceCertificate.id == cert_id)) cert = result.scalar_one_or_none() if cert is None: raise ValueError(f"Device certificate {cert_id} not found") # Fetch the CA for the ca_cert_pem ca_result = await db.execute( - select(CertificateAuthority).where( - CertificateAuthority.id == cert.ca_id - ) + select(CertificateAuthority).where(CertificateAuthority.id == cert.ca_id) ) ca = ca_result.scalar_one_or_none() if ca is None: diff --git a/backend/app/services/config_change_parser.py b/backend/app/services/config_change_parser.py index 7c18ea1..6ea3434 100644 --- a/backend/app/services/config_change_parser.py +++ b/backend/app/services/config_change_parser.py @@ -64,9 +64,7 @@ def parse_diff_changes(diff_text: str) -> list[dict]: current_section: str | None = None # Track per-component: adds, removes, raw_lines - components: dict[str, dict] = defaultdict( - lambda: {"adds": 0, "removes": 0, "raw_lines": []} - ) + components: dict[str, dict] = defaultdict(lambda: {"adds": 0, "removes": 0, "raw_lines": []}) for line in lines: # Skip unified diff headers diff --git a/backend/app/services/config_change_subscriber.py b/backend/app/services/config_change_subscriber.py index fec1969..d1bbd91 100644 --- a/backend/app/services/config_change_subscriber.py +++ b/backend/app/services/config_change_subscriber.py @@ -51,13 +51,15 @@ async def handle_config_changed(event: dict) -> None: if await _last_backup_within_dedup_window(device_id): logger.info( "Config change on device %s — skipping backup (within %dm dedup window)", - device_id, DEDUP_WINDOW_MINUTES, + device_id, + DEDUP_WINDOW_MINUTES, ) return logger.info( "Config change detected on device %s (tenant %s): %s -> %s", - device_id, tenant_id, + device_id, + tenant_id, event.get("old_timestamp", "?"), event.get("new_timestamp", "?"), ) diff --git a/backend/app/services/config_diff_service.py b/backend/app/services/config_diff_service.py index 336a4a5..0f1e44a 100644 --- a/backend/app/services/config_diff_service.py +++ b/backend/app/services/config_diff_service.py @@ -65,9 +65,7 @@ async def generate_and_store_diff( # 2. No previous snapshot = first snapshot for device if prev_row is None: - logger.debug( - "First snapshot for device %s, no diff to generate", device_id - ) + logger.debug("First snapshot for device %s, no diff to generate", device_id) return old_snapshot_id = prev_row._mapping["id"] @@ -103,9 +101,7 @@ async def generate_and_store_diff( # 6. Generate unified diff old_lines = old_plaintext.decode("utf-8").splitlines() new_lines = new_plaintext.decode("utf-8").splitlines() - diff_lines = list( - difflib.unified_diff(old_lines, new_lines, lineterm="", n=3) - ) + diff_lines = list(difflib.unified_diff(old_lines, new_lines, lineterm="", n=3)) # 7. If empty diff, skip INSERT diff_text = "\n".join(diff_lines) @@ -118,12 +114,10 @@ async def generate_and_store_diff( # 8. Count lines added/removed (exclude +++ and --- headers) lines_added = sum( - 1 for line in diff_lines - if line.startswith("+") and not line.startswith("++") + 1 for line in diff_lines if line.startswith("+") and not line.startswith("++") ) lines_removed = sum( - 1 for line in diff_lines - if line.startswith("-") and not line.startswith("--") + 1 for line in diff_lines if line.startswith("-") and not line.startswith("--") ) # 9. INSERT into router_config_diffs (RETURNING id for change parser) @@ -165,6 +159,7 @@ async def generate_and_store_diff( try: from app.services.audit_service import log_action import uuid as _uuid + await log_action( db=None, tenant_id=_uuid.UUID(tenant_id), @@ -207,12 +202,16 @@ async def generate_and_store_diff( await session.commit() logger.info( "Stored %d config changes for device %s diff %s", - len(changes), device_id, diff_id, + len(changes), + device_id, + diff_id, ) except Exception as exc: logger.warning( "Change parser error for device %s diff %s (non-fatal): %s", - device_id, diff_id, exc, + device_id, + diff_id, + exc, ) config_diff_errors_total.labels(error_type="change_parser").inc() diff --git a/backend/app/services/config_snapshot_subscriber.py b/backend/app/services/config_snapshot_subscriber.py index 63de766..17a8cba 100644 --- a/backend/app/services/config_snapshot_subscriber.py +++ b/backend/app/services/config_snapshot_subscriber.py @@ -89,9 +89,11 @@ async def handle_config_snapshot(msg) -> None: collected_at_raw = data.get("collected_at") try: - collected_at = datetime.fromisoformat( - collected_at_raw.replace("Z", "+00:00") - ) if collected_at_raw else datetime.now(timezone.utc) + collected_at = ( + datetime.fromisoformat(collected_at_raw.replace("Z", "+00:00")) + if collected_at_raw + else datetime.now(timezone.utc) + ) except (ValueError, AttributeError): collected_at = datetime.now(timezone.utc) @@ -131,9 +133,7 @@ async def handle_config_snapshot(msg) -> None: # --- Encrypt via OpenBao Transit --- openbao = OpenBaoTransitService() try: - encrypted_text = await openbao.encrypt( - tenant_id, config_text.encode("utf-8") - ) + encrypted_text = await openbao.encrypt(tenant_id, config_text.encode("utf-8")) except Exception as exc: logger.warning( "Transit encrypt failed for device %s tenant %s: %s", @@ -207,7 +207,8 @@ async def handle_config_snapshot(msg) -> None: except Exception as exc: logger.warning( "Diff generation failed for device %s (non-fatal): %s", - device_id, exc, + device_id, + exc, ) logger.info( @@ -233,9 +234,7 @@ async def _subscribe_with_retry(js) -> None: stream="DEVICE_EVENTS", manual_ack=True, ) - logger.info( - "NATS: subscribed to config.snapshot.> (durable: config_snapshot_ingest)" - ) + logger.info("NATS: subscribed to config.snapshot.> (durable: config_snapshot_ingest)") return except Exception as exc: if attempt < max_attempts: diff --git a/backend/app/services/device.py b/backend/app/services/device.py index 627ff49..0aef536 100644 --- a/backend/app/services/device.py +++ b/backend/app/services/device.py @@ -29,8 +29,6 @@ from app.models.device import ( DeviceTagAssignment, ) from app.schemas.device import ( - BulkAddRequest, - BulkAddResult, DeviceCreate, DeviceGroupCreate, DeviceGroupResponse, @@ -43,9 +41,7 @@ from app.schemas.device import ( ) from app.config import settings from app.services.crypto import ( - decrypt_credentials, decrypt_credentials_hybrid, - encrypt_credentials, encrypt_credentials_transit, ) @@ -58,9 +54,7 @@ from app.services.crypto import ( async def _tcp_reachable(ip: str, port: int, timeout: float = 3.0) -> bool: """Return True if a TCP connection to ip:port succeeds within timeout.""" try: - _, writer = await asyncio.wait_for( - asyncio.open_connection(ip, port), timeout=timeout - ) + _, writer = await asyncio.wait_for(asyncio.open_connection(ip, port), timeout=timeout) writer.close() try: await writer.wait_closed() @@ -151,6 +145,7 @@ async def create_device( if not api_reachable and not ssl_reachable: from fastapi import HTTPException, status + raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=( @@ -162,9 +157,7 @@ async def create_device( # Encrypt credentials via OpenBao Transit (new writes go through Transit) credentials_json = json.dumps({"username": data.username, "password": data.password}) - transit_ciphertext = await encrypt_credentials_transit( - credentials_json, str(tenant_id) - ) + transit_ciphertext = await encrypt_credentials_transit(credentials_json, str(tenant_id)) device = Device( tenant_id=tenant_id, @@ -180,9 +173,7 @@ async def create_device( await db.refresh(device) # Re-query with relationships loaded - result = await db.execute( - _device_with_relations().where(Device.id == device.id) - ) + result = await db.execute(_device_with_relations().where(Device.id == device.id)) device = result.scalar_one() return _build_device_response(device) @@ -223,9 +214,7 @@ async def get_devices( if tag_id: base_q = base_q.where( Device.id.in_( - select(DeviceTagAssignment.device_id).where( - DeviceTagAssignment.tag_id == tag_id - ) + select(DeviceTagAssignment.device_id).where(DeviceTagAssignment.tag_id == tag_id) ) ) @@ -274,9 +263,7 @@ async def get_device( """Get a single device by ID.""" from fastapi import HTTPException, status - result = await db.execute( - _device_with_relations().where(Device.id == device_id) - ) + result = await db.execute(_device_with_relations().where(Device.id == device_id)) device = result.scalar_one_or_none() if not device: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found") @@ -295,9 +282,7 @@ async def update_device( """ from fastapi import HTTPException, status - result = await db.execute( - _device_with_relations().where(Device.id == device_id) - ) + result = await db.execute(_device_with_relations().where(Device.id == device_id)) device = result.scalar_one_or_none() if not device: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found") @@ -323,7 +308,9 @@ async def update_device( if data.password is not None: # Decrypt existing to get current username if no new username given current_username: str = data.username or "" - if not current_username and (device.encrypted_credentials_transit or device.encrypted_credentials): + if not current_username and ( + device.encrypted_credentials_transit or device.encrypted_credentials + ): try: existing_json = await decrypt_credentials_hybrid( device.encrypted_credentials_transit, @@ -336,17 +323,21 @@ async def update_device( except Exception: current_username = "" - credentials_json = json.dumps({ - "username": data.username if data.username is not None else current_username, - "password": data.password, - }) + credentials_json = json.dumps( + { + "username": data.username if data.username is not None else current_username, + "password": data.password, + } + ) # New writes go through Transit device.encrypted_credentials_transit = await encrypt_credentials_transit( credentials_json, str(device.tenant_id) ) device.encrypted_credentials = None # Clear legacy (Transit is canonical) credentials_changed = True - elif data.username is not None and (device.encrypted_credentials_transit or device.encrypted_credentials): + elif data.username is not None and ( + device.encrypted_credentials_transit or device.encrypted_credentials + ): # Only username changed — update it without changing the password try: existing_json = await decrypt_credentials_hybrid( @@ -373,6 +364,7 @@ async def update_device( if credentials_changed: try: from app.services.event_publisher import publish_event + await publish_event( f"device.credential_changed.{device_id}", {"device_id": str(device_id), "tenant_id": str(tenant_id)}, @@ -380,9 +372,7 @@ async def update_device( except Exception: pass # Never fail the update due to NATS issues - result2 = await db.execute( - _device_with_relations().where(Device.id == device_id) - ) + result2 = await db.execute(_device_with_relations().where(Device.id == device_id)) device = result2.scalar_one() return _build_device_response(device) @@ -526,11 +516,7 @@ async def get_groups( tenant_id: uuid.UUID, ) -> list[DeviceGroupResponse]: """Return all device groups for the current tenant with device counts.""" - result = await db.execute( - select(DeviceGroup).options( - selectinload(DeviceGroup.memberships) - ) - ) + result = await db.execute(select(DeviceGroup).options(selectinload(DeviceGroup.memberships))) groups = result.scalars().all() return [ DeviceGroupResponse( @@ -554,9 +540,9 @@ async def update_group( from fastapi import HTTPException, status result = await db.execute( - select(DeviceGroup).options( - selectinload(DeviceGroup.memberships) - ).where(DeviceGroup.id == group_id) + select(DeviceGroup) + .options(selectinload(DeviceGroup.memberships)) + .where(DeviceGroup.id == group_id) ) group = result.scalar_one_or_none() if not group: @@ -571,9 +557,9 @@ async def update_group( await db.refresh(group) result2 = await db.execute( - select(DeviceGroup).options( - selectinload(DeviceGroup.memberships) - ).where(DeviceGroup.id == group_id) + select(DeviceGroup) + .options(selectinload(DeviceGroup.memberships)) + .where(DeviceGroup.id == group_id) ) group = result2.scalar_one() return DeviceGroupResponse( diff --git a/backend/app/services/emergency_kit_service.py b/backend/app/services/emergency_kit_service.py index 41171bf..abed9d5 100644 --- a/backend/app/services/emergency_kit_service.py +++ b/backend/app/services/emergency_kit_service.py @@ -48,7 +48,5 @@ async def generate_emergency_kit_template( # Run weasyprint in thread to avoid blocking the event loop from weasyprint import HTML - pdf_bytes = await asyncio.to_thread( - lambda: HTML(string=html_content).write_pdf() - ) + pdf_bytes = await asyncio.to_thread(lambda: HTML(string=html_content).write_pdf()) return pdf_bytes diff --git a/backend/app/services/firmware_service.py b/backend/app/services/firmware_service.py index 58cd7c2..815c71b 100644 --- a/backend/app/services/firmware_service.py +++ b/backend/app/services/firmware_service.py @@ -13,7 +13,6 @@ Version discovery comes from two sources: """ import logging -import os from pathlib import Path import httpx @@ -51,13 +50,12 @@ async def check_latest_versions() -> list[dict]: async with httpx.AsyncClient(timeout=30.0) as client: for channel_file, channel, major in _VERSION_SOURCES: try: - resp = await client.get( - f"https://download.mikrotik.com/routeros/{channel_file}" - ) + resp = await client.get(f"https://download.mikrotik.com/routeros/{channel_file}") if resp.status_code != 200: logger.warning( "MikroTik version check returned %d for %s", - resp.status_code, channel_file, + resp.status_code, + channel_file, ) continue @@ -72,12 +70,14 @@ async def check_latest_versions() -> list[dict]: f"https://download.mikrotik.com/routeros/" f"{version}/routeros-{version}-{arch}.npk" ) - results.append({ - "architecture": arch, - "channel": channel, - "version": version, - "npk_url": npk_url, - }) + results.append( + { + "architecture": arch, + "channel": channel, + "version": version, + "npk_url": npk_url, + } + ) except Exception as e: logger.warning("Failed to check %s: %s", channel_file, e) @@ -242,15 +242,15 @@ async def get_firmware_overview(tenant_id: str) -> dict: groups = [] for ver, devs in sorted(version_groups.items()): # A version is "latest" if it matches the latest for any arch/channel combo - is_latest = any( - v["version"] == ver for v in latest_versions.values() + is_latest = any(v["version"] == ver for v in latest_versions.values()) + groups.append( + { + "version": ver, + "count": len(devs), + "is_latest": is_latest, + "devices": devs, + } ) - groups.append({ - "version": ver, - "count": len(devs), - "is_latest": is_latest, - "devices": devs, - }) return { "devices": device_list, @@ -272,12 +272,14 @@ async def get_cached_firmware() -> list[dict]: continue for npk_file in sorted(version_dir.iterdir()): if npk_file.suffix == ".npk": - cached.append({ - "path": str(npk_file), - "version": version_dir.name, - "filename": npk_file.name, - "size_bytes": npk_file.stat().st_size, - }) + cached.append( + { + "path": str(npk_file), + "version": version_dir.name, + "filename": npk_file.name, + "size_bytes": npk_file.stat().st_size, + } + ) return cached diff --git a/backend/app/services/firmware_subscriber.py b/backend/app/services/firmware_subscriber.py index 36ed39c..112d065 100644 --- a/backend/app/services/firmware_subscriber.py +++ b/backend/app/services/firmware_subscriber.py @@ -41,7 +41,7 @@ async def on_device_firmware(msg) -> None: try: data = json.loads(msg.data) device_id = data.get("device_id") - tenant_id = data.get("tenant_id") + _tenant_id = data.get("tenant_id") architecture = data.get("architecture") installed_version = data.get("installed_version") latest_version = data.get("latest_version") @@ -126,9 +126,7 @@ async def _subscribe_with_retry(js: JetStreamContext) -> None: durable="api-firmware-consumer", stream="DEVICE_EVENTS", ) - logger.info( - "NATS: subscribed to device.firmware.> (durable: api-firmware-consumer)" - ) + logger.info("NATS: subscribed to device.firmware.> (durable: api-firmware-consumer)") return except Exception as exc: if attempt < max_attempts: diff --git a/backend/app/services/git_store.py b/backend/app/services/git_store.py index cc52e48..2a593a8 100644 --- a/backend/app/services/git_store.py +++ b/backend/app/services/git_store.py @@ -238,11 +238,13 @@ def list_device_commits( # Device wasn't in parent but is in this commit — it's the first entry. pass - results.append({ - "sha": str(commit.id), - "message": commit.message.strip(), - "timestamp": commit.commit_time, - }) + results.append( + { + "sha": str(commit.id), + "message": commit.message.strip(), + "timestamp": commit.commit_time, + } + ) return results diff --git a/backend/app/services/key_service.py b/backend/app/services/key_service.py index 8a7b278..127ea0e 100644 --- a/backend/app/services/key_service.py +++ b/backend/app/services/key_service.py @@ -52,6 +52,7 @@ async def store_user_key_set( """ # Remove any existing key set (e.g. from a failed prior upgrade attempt) from sqlalchemy import delete + await db.execute(delete(UserKeySet).where(UserKeySet.user_id == user_id)) key_set = UserKeySet( @@ -71,9 +72,7 @@ async def store_user_key_set( return key_set -async def get_user_key_set( - db: AsyncSession, user_id: UUID -) -> UserKeySet | None: +async def get_user_key_set(db: AsyncSession, user_id: UUID) -> UserKeySet | None: """Retrieve encrypted key bundle for login response. Args: @@ -83,9 +82,7 @@ async def get_user_key_set( Returns: The UserKeySet if found, None otherwise. """ - result = await db.execute( - select(UserKeySet).where(UserKeySet.user_id == user_id) - ) + result = await db.execute(select(UserKeySet).where(UserKeySet.user_id == user_id)) return result.scalar_one_or_none() @@ -163,9 +160,7 @@ async def provision_tenant_key(db: AsyncSession, tenant_id: UUID) -> str: await openbao.create_tenant_key(str(tenant_id)) # Update tenant record with key name - result = await db.execute( - select(Tenant).where(Tenant.id == tenant_id) - ) + result = await db.execute(select(Tenant).where(Tenant.id == tenant_id)) tenant = result.scalar_one_or_none() if tenant: tenant.openbao_key_name = key_name @@ -210,13 +205,18 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict: select(Device).where( Device.tenant_id == tenant_id, Device.encrypted_credentials.isnot(None), - (Device.encrypted_credentials_transit.is_(None) | (Device.encrypted_credentials_transit == "")), + ( + Device.encrypted_credentials_transit.is_(None) + | (Device.encrypted_credentials_transit == "") + ), ) ) for device in result.scalars().all(): try: plaintext = decrypt_credentials(device.encrypted_credentials, legacy_key) - device.encrypted_credentials_transit = await openbao.encrypt(tid, plaintext.encode("utf-8")) + device.encrypted_credentials_transit = await openbao.encrypt( + tid, plaintext.encode("utf-8") + ) counts["devices"] += 1 except Exception as e: logger.error("Failed to migrate device %s credentials: %s", device.id, e) @@ -227,7 +227,10 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict: select(CertificateAuthority).where( CertificateAuthority.tenant_id == tenant_id, CertificateAuthority.encrypted_private_key.isnot(None), - (CertificateAuthority.encrypted_private_key_transit.is_(None) | (CertificateAuthority.encrypted_private_key_transit == "")), + ( + CertificateAuthority.encrypted_private_key_transit.is_(None) + | (CertificateAuthority.encrypted_private_key_transit == "") + ), ) ) for ca in result.scalars().all(): @@ -244,13 +247,18 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict: select(DeviceCertificate).where( DeviceCertificate.tenant_id == tenant_id, DeviceCertificate.encrypted_private_key.isnot(None), - (DeviceCertificate.encrypted_private_key_transit.is_(None) | (DeviceCertificate.encrypted_private_key_transit == "")), + ( + DeviceCertificate.encrypted_private_key_transit.is_(None) + | (DeviceCertificate.encrypted_private_key_transit == "") + ), ) ) for cert in result.scalars().all(): try: plaintext = decrypt_credentials(cert.encrypted_private_key, legacy_key) - cert.encrypted_private_key_transit = await openbao.encrypt(tid, plaintext.encode("utf-8")) + cert.encrypted_private_key_transit = await openbao.encrypt( + tid, plaintext.encode("utf-8") + ) counts["certs"] += 1 except Exception as e: logger.error("Failed to migrate cert %s private key: %s", cert.id, e) @@ -303,7 +311,14 @@ async def provision_existing_tenants(db: AsyncSession) -> dict: result = await db.execute(select(Tenant)) tenants = result.scalars().all() - total = {"tenants": len(tenants), "devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0} + total = { + "tenants": len(tenants), + "devices": 0, + "cas": 0, + "certs": 0, + "channels": 0, + "errors": 0, + } for tenant in tenants: try: diff --git a/backend/app/services/metrics_subscriber.py b/backend/app/services/metrics_subscriber.py index f637c31..791a130 100644 --- a/backend/app/services/metrics_subscriber.py +++ b/backend/app/services/metrics_subscriber.py @@ -199,9 +199,7 @@ async def on_device_metrics(msg) -> None: device_id = data.get("device_id") if not metric_type or not device_id: - logger.warning( - "device.metrics event missing 'type' or 'device_id' — skipping" - ) + logger.warning("device.metrics event missing 'type' or 'device_id' — skipping") await msg.ack() return @@ -222,6 +220,7 @@ async def on_device_metrics(msg) -> None: # Alert evaluation — non-fatal; metric write is the primary operation try: from app.services import alert_evaluator + await alert_evaluator.evaluate( device_id=device_id, tenant_id=data.get("tenant_id", ""), @@ -265,9 +264,7 @@ async def _subscribe_with_retry(js: JetStreamContext) -> None: durable="api-metrics-consumer", stream="DEVICE_EVENTS", ) - logger.info( - "NATS: subscribed to device.metrics.> (durable: api-metrics-consumer)" - ) + logger.info("NATS: subscribed to device.metrics.> (durable: api-metrics-consumer)") return except Exception as exc: if attempt < max_attempts: diff --git a/backend/app/services/nats_subscriber.py b/backend/app/services/nats_subscriber.py index 123127e..511bee7 100644 --- a/backend/app/services/nats_subscriber.py +++ b/backend/app/services/nats_subscriber.py @@ -69,7 +69,9 @@ async def on_device_status(msg) -> None: uptime_seconds = _parse_uptime(data.get("uptime", "")) if not device_id or not status: - logger.warning("Received device.status event with missing device_id or status — skipping") + logger.warning( + "Received device.status event with missing device_id or status — skipping" + ) await msg.ack() return @@ -115,12 +117,15 @@ async def on_device_status(msg) -> None: # Alert evaluation for offline/online status changes — non-fatal try: from app.services import alert_evaluator + if status == "offline": await alert_evaluator.evaluate_offline(device_id, data.get("tenant_id", "")) elif status == "online": await alert_evaluator.evaluate_online(device_id, data.get("tenant_id", "")) except Exception as e: - logger.warning("Alert evaluation failed for device %s status=%s: %s", device_id, status, e) + logger.warning( + "Alert evaluation failed for device %s status=%s: %s", device_id, status, e + ) logger.info( "Device status updated", diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index e4e0b76..7e19ce2 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -39,7 +39,9 @@ async def dispatch_notifications( except Exception as e: logger.warning( "Notification delivery failed for channel %s (%s): %s", - channel.get("name"), channel.get("channel_type"), e, + channel.get("name"), + channel.get("channel_type"), + e, ) @@ -100,14 +102,18 @@ async def _send_email(channel: dict, alert_event: dict, device_hostname: str) -> if transit_cipher and tenant_id: try: from app.services.kms_service import decrypt_transit + smtp_password = await decrypt_transit(transit_cipher, tenant_id) except Exception: - logger.warning("Transit decryption failed for channel %s, trying legacy", channel.get("id")) + logger.warning( + "Transit decryption failed for channel %s, trying legacy", channel.get("id") + ) if not smtp_password and legacy_cipher: try: from app.config import settings as app_settings from cryptography.fernet import Fernet + raw = bytes(legacy_cipher) if isinstance(legacy_cipher, memoryview) else legacy_cipher f = Fernet(app_settings.CREDENTIAL_ENCRYPTION_KEY.encode()) smtp_password = f.decrypt(raw).decode() @@ -163,7 +169,8 @@ async def _send_webhook( response = await client.post(webhook_url, json=payload) logger.info( "Webhook notification sent to %s — status %d", - webhook_url, response.status_code, + webhook_url, + response.status_code, ) @@ -180,13 +187,18 @@ async def _send_slack( value = alert_event.get("value") threshold = alert_event.get("threshold") - color = {"CRITICAL": "#dc2626", "WARNING": "#f59e0b", "INFO": "#3b82f6"}.get(severity, "#6b7280") + color = {"CRITICAL": "#dc2626", "WARNING": "#f59e0b", "INFO": "#3b82f6"}.get( + severity, "#6b7280" + ) status_label = "RESOLVED" if status == "resolved" else status blocks = [ { "type": "header", - "text": {"type": "plain_text", "text": f"{'✅' if status == 'resolved' else '🚨'} [{severity}] {status_label.upper()}"}, + "text": { + "type": "plain_text", + "text": f"{'✅' if status == 'resolved' else '🚨'} [{severity}] {status_label.upper()}", + }, }, { "type": "section", @@ -205,7 +217,9 @@ async def _send_slack( blocks.append({"type": "section", "fields": fields}) if message_text: - blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": f"*Message:*\n{message_text}"}}) + blocks.append( + {"type": "section", "text": {"type": "mrkdwn", "text": f"*Message:*\n{message_text}"}} + ) blocks.append({"type": "context", "elements": [{"type": "mrkdwn", "text": "TOD Alert System"}]}) diff --git a/backend/app/services/openbao_service.py b/backend/app/services/openbao_service.py index a7d6f83..b4ccef3 100644 --- a/backend/app/services/openbao_service.py +++ b/backend/app/services/openbao_service.py @@ -7,6 +7,7 @@ material never leaves OpenBao -- the application only sees ciphertext. Ciphertext format: "vault:v1:base64..." (compatible with Vault Transit format) """ + import base64 import logging from typing import Optional diff --git a/backend/app/services/push_tracker.py b/backend/app/services/push_tracker.py index 41d209d..d622087 100644 --- a/backend/app/services/push_tracker.py +++ b/backend/app/services/push_tracker.py @@ -47,13 +47,15 @@ async def record_push( """ r = await _get_redis() key = f"push:recent:{device_id}" - value = json.dumps({ - "device_id": device_id, - "tenant_id": tenant_id, - "push_type": push_type, - "push_operation_id": push_operation_id, - "pre_push_commit_sha": pre_push_commit_sha, - }) + value = json.dumps( + { + "device_id": device_id, + "tenant_id": tenant_id, + "push_type": push_type, + "push_operation_id": push_operation_id, + "pre_push_commit_sha": pre_push_commit_sha, + } + ) await r.set(key, value, ex=PUSH_TTL_SECONDS) logger.debug( "Recorded push for device %s (type=%s, TTL=%ds)", diff --git a/backend/app/services/report_service.py b/backend/app/services/report_service.py index db9177a..27de757 100644 --- a/backend/app/services/report_service.py +++ b/backend/app/services/report_service.py @@ -164,16 +164,18 @@ async def _device_inventory( uptime_str = _format_uptime(row[6]) if row[6] else None last_seen_str = row[5].strftime("%Y-%m-%d %H:%M") if row[5] else None - devices.append({ - "hostname": row[0], - "ip_address": row[1], - "model": row[2], - "routeros_version": row[3], - "status": status, - "last_seen": last_seen_str, - "uptime": uptime_str, - "groups": row[7] if row[7] else None, - }) + devices.append( + { + "hostname": row[0], + "ip_address": row[1], + "model": row[2], + "routeros_version": row[3], + "status": status, + "last_seen": last_seen_str, + "uptime": uptime_str, + "groups": row[7] if row[7] else None, + } + ) return { "report_title": "Device Inventory", @@ -224,16 +226,18 @@ async def _metrics_summary( devices = [] for row in rows: - devices.append({ - "hostname": row[0], - "avg_cpu": float(row[1]) if row[1] is not None else None, - "peak_cpu": float(row[2]) if row[2] is not None else None, - "avg_mem": float(row[3]) if row[3] is not None else None, - "peak_mem": float(row[4]) if row[4] is not None else None, - "avg_disk": float(row[5]) if row[5] is not None else None, - "avg_temp": float(row[6]) if row[6] is not None else None, - "data_points": row[7], - }) + devices.append( + { + "hostname": row[0], + "avg_cpu": float(row[1]) if row[1] is not None else None, + "peak_cpu": float(row[2]) if row[2] is not None else None, + "avg_mem": float(row[3]) if row[3] is not None else None, + "peak_mem": float(row[4]) if row[4] is not None else None, + "avg_disk": float(row[5]) if row[5] is not None else None, + "avg_temp": float(row[6]) if row[6] is not None else None, + "data_points": row[7], + } + ) return { "report_title": "Metrics Summary", @@ -287,14 +291,16 @@ async def _alert_history( if duration_secs is not None: resolved_durations.append(duration_secs) - alerts.append({ - "fired_at": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-", - "hostname": row[5], - "severity": severity, - "status": row[3], - "message": row[4], - "duration": _format_duration(duration_secs) if duration_secs is not None else None, - }) + alerts.append( + { + "fired_at": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-", + "hostname": row[5], + "severity": severity, + "status": row[3], + "message": row[4], + "duration": _format_duration(duration_secs) if duration_secs is not None else None, + } + ) mttr_minutes = None mttr_display = None @@ -374,13 +380,15 @@ async def _change_log_from_audit( entries = [] for row in rows: - entries.append({ - "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-", - "user": row[1], - "action": row[2], - "device": row[3], - "details": row[4] or row[5] or "", - }) + entries.append( + { + "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-", + "user": row[1], + "action": row[2], + "device": row[3], + "details": row[4] or row[5] or "", + } + ) return { "report_title": "Change Log", @@ -436,21 +444,25 @@ async def _change_log_from_backups( # Merge and sort by timestamp descending entries = [] for row in backup_rows: - entries.append({ - "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-", - "user": row[1], - "action": row[2], - "device": row[3], - "details": row[4] or "", - }) + entries.append( + { + "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-", + "user": row[1], + "action": row[2], + "device": row[3], + "details": row[4] or "", + } + ) for row in alert_rows: - entries.append({ - "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-", - "user": row[1], - "action": row[2], - "device": row[3], - "details": row[4] or "", - }) + entries.append( + { + "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-", + "user": row[1], + "action": row[2], + "device": row[3], + "details": row[4] or "", + } + ) # Sort by timestamp string descending entries.sort(key=lambda e: e["timestamp"], reverse=True) @@ -486,54 +498,102 @@ def _render_csv(report_type: str, data: dict[str, Any]) -> bytes: writer = csv.writer(output) if report_type == "device_inventory": - writer.writerow([ - "Hostname", "IP Address", "Model", "RouterOS Version", - "Status", "Last Seen", "Uptime", "Groups", - ]) + writer.writerow( + [ + "Hostname", + "IP Address", + "Model", + "RouterOS Version", + "Status", + "Last Seen", + "Uptime", + "Groups", + ] + ) for d in data.get("devices", []): - writer.writerow([ - d["hostname"], d["ip_address"], d["model"] or "", - d["routeros_version"] or "", d["status"], - d["last_seen"] or "", d["uptime"] or "", - d["groups"] or "", - ]) + writer.writerow( + [ + d["hostname"], + d["ip_address"], + d["model"] or "", + d["routeros_version"] or "", + d["status"], + d["last_seen"] or "", + d["uptime"] or "", + d["groups"] or "", + ] + ) elif report_type == "metrics_summary": - writer.writerow([ - "Hostname", "Avg CPU %", "Peak CPU %", "Avg Memory %", - "Peak Memory %", "Avg Disk %", "Avg Temp", "Data Points", - ]) + writer.writerow( + [ + "Hostname", + "Avg CPU %", + "Peak CPU %", + "Avg Memory %", + "Peak Memory %", + "Avg Disk %", + "Avg Temp", + "Data Points", + ] + ) for d in data.get("devices", []): - writer.writerow([ - d["hostname"], - f"{d['avg_cpu']:.1f}" if d["avg_cpu"] is not None else "", - f"{d['peak_cpu']:.1f}" if d["peak_cpu"] is not None else "", - f"{d['avg_mem']:.1f}" if d["avg_mem"] is not None else "", - f"{d['peak_mem']:.1f}" if d["peak_mem"] is not None else "", - f"{d['avg_disk']:.1f}" if d["avg_disk"] is not None else "", - f"{d['avg_temp']:.1f}" if d["avg_temp"] is not None else "", - d["data_points"], - ]) + writer.writerow( + [ + d["hostname"], + f"{d['avg_cpu']:.1f}" if d["avg_cpu"] is not None else "", + f"{d['peak_cpu']:.1f}" if d["peak_cpu"] is not None else "", + f"{d['avg_mem']:.1f}" if d["avg_mem"] is not None else "", + f"{d['peak_mem']:.1f}" if d["peak_mem"] is not None else "", + f"{d['avg_disk']:.1f}" if d["avg_disk"] is not None else "", + f"{d['avg_temp']:.1f}" if d["avg_temp"] is not None else "", + d["data_points"], + ] + ) elif report_type == "alert_history": - writer.writerow([ - "Timestamp", "Device", "Severity", "Message", "Status", "Duration", - ]) + writer.writerow( + [ + "Timestamp", + "Device", + "Severity", + "Message", + "Status", + "Duration", + ] + ) for a in data.get("alerts", []): - writer.writerow([ - a["fired_at"], a["hostname"] or "", a["severity"], - a["message"] or "", a["status"], a["duration"] or "", - ]) + writer.writerow( + [ + a["fired_at"], + a["hostname"] or "", + a["severity"], + a["message"] or "", + a["status"], + a["duration"] or "", + ] + ) elif report_type == "change_log": - writer.writerow([ - "Timestamp", "User", "Action", "Device", "Details", - ]) + writer.writerow( + [ + "Timestamp", + "User", + "Action", + "Device", + "Details", + ] + ) for e in data.get("entries", []): - writer.writerow([ - e["timestamp"], e["user"] or "", e["action"], - e["device"] or "", e["details"] or "", - ]) + writer.writerow( + [ + e["timestamp"], + e["user"] or "", + e["action"], + e["device"] or "", + e["details"] or "", + ] + ) return output.getvalue().encode("utf-8") diff --git a/backend/app/services/restore_service.py b/backend/app/services/restore_service.py index 51287b8..8909f85 100644 --- a/backend/app/services/restore_service.py +++ b/backend/app/services/restore_service.py @@ -33,7 +33,6 @@ import asyncssh from sqlalchemy.ext.asyncio import AsyncSession from app.config import settings -from app.database import set_tenant_context, AdminAsyncSessionLocal from app.models.config_backup import ConfigPushOperation from app.models.device import Device from app.services import backup_service, git_store @@ -113,12 +112,11 @@ async def restore_config( raise ValueError(f"Device {device_id!r} not found") if not device.encrypted_credentials_transit and not device.encrypted_credentials: - raise ValueError( - f"Device {device_id!r} has no stored credentials — cannot perform restore" - ) + raise ValueError(f"Device {device_id!r} has no stored credentials — cannot perform restore") key = settings.get_encryption_key_bytes() from app.services.crypto import decrypt_credentials_hybrid + creds_json = await decrypt_credentials_hybrid( device.encrypted_credentials_transit, device.encrypted_credentials, @@ -133,7 +131,9 @@ async def restore_config( hostname = device.hostname or ip # Publish "started" progress event - await _publish_push_progress(tenant_id, device_id, "started", f"Config restore started for {hostname}") + await _publish_push_progress( + tenant_id, device_id, "started", f"Config restore started for {hostname}" + ) # ------------------------------------------------------------------ # Step 2: Read the target export.rsc from the backup commit @@ -157,7 +157,9 @@ async def restore_config( # ------------------------------------------------------------------ # Step 3: Mandatory pre-backup before push # ------------------------------------------------------------------ - await _publish_push_progress(tenant_id, device_id, "backing_up", f"Creating pre-restore backup for {hostname}") + await _publish_push_progress( + tenant_id, device_id, "backing_up", f"Creating pre-restore backup for {hostname}" + ) logger.info( "Starting pre-restore backup for device %s (%s) before pushing commit %s", @@ -198,7 +200,9 @@ async def restore_config( # Step 5: SSH to device — install panic-revert, push config # ------------------------------------------------------------------ push_op_id_str = str(push_op_id) - await _publish_push_progress(tenant_id, device_id, "pushing", f"Pushing config to {hostname}", push_op_id=push_op_id_str) + await _publish_push_progress( + tenant_id, device_id, "pushing", f"Pushing config to {hostname}", push_op_id=push_op_id_str + ) logger.info( "Pushing config to device %s (%s): installing panic-revert scheduler and uploading config", @@ -290,9 +294,12 @@ async def restore_config( # Update push operation to failed await _update_push_op_status(push_op_id, "failed", db_session) await _publish_push_progress( - tenant_id, device_id, "failed", + tenant_id, + device_id, + "failed", f"Config push failed for {hostname}: {push_err}", - push_op_id=push_op_id_str, error=str(push_err), + push_op_id=push_op_id_str, + error=str(push_err), ) return { "status": "failed", @@ -312,7 +319,13 @@ async def restore_config( # ------------------------------------------------------------------ # Step 6: Wait 60s for config to settle # ------------------------------------------------------------------ - await _publish_push_progress(tenant_id, device_id, "settling", f"Config pushed to {hostname} — waiting 60s for settle", push_op_id=push_op_id_str) + await _publish_push_progress( + tenant_id, + device_id, + "settling", + f"Config pushed to {hostname} — waiting 60s for settle", + push_op_id=push_op_id_str, + ) logger.info( "Config pushed to device %s — waiting 60s for config to settle", @@ -323,7 +336,13 @@ async def restore_config( # ------------------------------------------------------------------ # Step 7: Reachability check # ------------------------------------------------------------------ - await _publish_push_progress(tenant_id, device_id, "verifying", f"Verifying device {hostname} reachability", push_op_id=push_op_id_str) + await _publish_push_progress( + tenant_id, + device_id, + "verifying", + f"Verifying device {hostname} reachability", + push_op_id=push_op_id_str, + ) reachable = await _check_reachability(ip, ssh_username, ssh_password) @@ -362,7 +381,13 @@ async def restore_config( await _update_push_op_status(push_op_id, "committed", db_session) await clear_push(device_id) - await _publish_push_progress(tenant_id, device_id, "committed", f"Config restored successfully on {hostname}", push_op_id=push_op_id_str) + await _publish_push_progress( + tenant_id, + device_id, + "committed", + f"Config restored successfully on {hostname}", + push_op_id=push_op_id_str, + ) return { "status": "committed", @@ -384,7 +409,9 @@ async def restore_config( await _update_push_op_status(push_op_id, "reverted", db_session) await _publish_push_progress( - tenant_id, device_id, "reverted", + tenant_id, + device_id, + "reverted", f"Device {hostname} unreachable — auto-reverting via panic-revert scheduler", push_op_id=push_op_id_str, ) @@ -446,7 +473,7 @@ async def _update_push_op_status( new_status: New status value ('committed' | 'reverted' | 'failed'). db_session: Database session (must already have tenant context set). """ - from sqlalchemy import select, update + from sqlalchemy import update await db_session.execute( update(ConfigPushOperation) @@ -526,9 +553,7 @@ async def recover_stale_push_operations(db_session: AsyncSession) -> None: for op in stale_ops: try: # Load device - dev_result = await db_session.execute( - select(Device).where(Device.id == op.device_id) - ) + dev_result = await db_session.execute(select(Device).where(Device.id == op.device_id)) device = dev_result.scalar_one_or_none() if not device: logger.error("Device %s not found for stale op %s", op.device_id, op.id) @@ -547,9 +572,7 @@ async def recover_stale_push_operations(db_session: AsyncSession) -> None: ssh_password = creds.get("password", "") # Check reachability - reachable = await _check_reachability( - device.ip_address, ssh_username, ssh_password - ) + reachable = await _check_reachability(device.ip_address, ssh_username, ssh_password) if reachable: # Try to remove scheduler (if still there, push was good) diff --git a/backend/app/services/retention_service.py b/backend/app/services/retention_service.py index 68e7649..f52dc07 100644 --- a/backend/app/services/retention_service.py +++ b/backend/app/services/retention_service.py @@ -53,7 +53,9 @@ async def cleanup_expired_snapshots() -> int: deleted = result.rowcount config_snapshots_cleaned_total.inc(deleted) - logger.info("retention cleanup complete", extra={"deleted_snapshots": deleted, "retention_days": days}) + logger.info( + "retention cleanup complete", extra={"deleted_snapshots": deleted, "retention_days": days} + ) return deleted diff --git a/backend/app/services/routeros_proxy.py b/backend/app/services/routeros_proxy.py index eb00d90..e4ec147 100644 --- a/backend/app/services/routeros_proxy.py +++ b/backend/app/services/routeros_proxy.py @@ -88,9 +88,7 @@ async def browse_menu(device_id: str, path: str) -> dict[str, Any]: return await execute_command(device_id, command) -async def add_entry( - device_id: str, path: str, properties: dict[str, str] -) -> dict[str, Any]: +async def add_entry(device_id: str, path: str, properties: dict[str, str]) -> dict[str, Any]: """Add a new entry to a RouterOS menu path. Args: @@ -124,9 +122,7 @@ async def update_entry( return await execute_command(device_id, f"{path}/set", args) -async def remove_entry( - device_id: str, path: str, entry_id: str -) -> dict[str, Any]: +async def remove_entry(device_id: str, path: str, entry_id: str) -> dict[str, Any]: """Remove an entry from a RouterOS menu path. Args: diff --git a/backend/app/services/rsc_parser.py b/backend/app/services/rsc_parser.py index 1448b65..2c302b4 100644 --- a/backend/app/services/rsc_parser.py +++ b/backend/app/services/rsc_parser.py @@ -7,20 +7,33 @@ from typing import Any logger = logging.getLogger(__name__) HIGH_RISK_PATHS = { - "/ip address", "/ip route", "/ip firewall filter", "/ip firewall nat", - "/interface", "/interface bridge", "/interface vlan", - "/system identity", "/ip service", "/ip ssh", "/user", + "/ip address", + "/ip route", + "/ip firewall filter", + "/ip firewall nat", + "/interface", + "/interface bridge", + "/interface vlan", + "/system identity", + "/ip service", + "/ip ssh", + "/user", } MANAGEMENT_PATTERNS = [ - (re.compile(r"chain=input.*dst-port=(22|8291|8728|8729|443|80)", re.I), - "Modifies firewall rules for management ports (SSH/WinBox/API/Web)"), - (re.compile(r"chain=input.*action=drop", re.I), - "Adds drop rule on input chain — may block management access"), - (re.compile(r"/ip service", re.I), - "Modifies IP services — may disable API/SSH/WinBox access"), - (re.compile(r"/user.*set.*password", re.I), - "Changes user password — may affect automated access"), + ( + re.compile(r"chain=input.*dst-port=(22|8291|8728|8729|443|80)", re.I), + "Modifies firewall rules for management ports (SSH/WinBox/API/Web)", + ), + ( + re.compile(r"chain=input.*action=drop", re.I), + "Adds drop rule on input chain — may block management access", + ), + (re.compile(r"/ip service", re.I), "Modifies IP services — may disable API/SSH/WinBox access"), + ( + re.compile(r"/user.*set.*password", re.I), + "Changes user password — may affect automated access", + ), ] @@ -73,7 +86,14 @@ def parse_rsc(text: str) -> dict[str, Any]: else: # Check if second part starts with a known command verb cmd_check = parts[1].strip().split(None, 1) - if cmd_check and cmd_check[0] in ("add", "set", "remove", "print", "enable", "disable"): + if cmd_check and cmd_check[0] in ( + "add", + "set", + "remove", + "print", + "enable", + "disable", + ): current_path = parts[0] line = parts[1].strip() else: @@ -184,12 +204,14 @@ def compute_impact( risk = "none" if has_changes: risk = "high" if path in HIGH_RISK_PATHS else "low" - result_categories.append({ - "path": path, - "adds": added, - "removes": removed, - "risk": risk, - }) + result_categories.append( + { + "path": path, + "adds": added, + "removes": removed, + "risk": risk, + } + ) # Check target commands against management patterns target_text = "\n".join( diff --git a/backend/app/services/srp_service.py b/backend/app/services/srp_service.py index b2efa53..af10a84 100644 --- a/backend/app/services/srp_service.py +++ b/backend/app/services/srp_service.py @@ -16,9 +16,7 @@ from srptools.constants import PRIME_2048, PRIME_2048_GEN _SRP_HASH = hashlib.sha256 -async def create_srp_verifier( - salt_hex: str, verifier_hex: str -) -> tuple[bytes, bytes]: +async def create_srp_verifier(salt_hex: str, verifier_hex: str) -> tuple[bytes, bytes]: """Convert client-provided hex salt and verifier to bytes for storage. The client computes v = g^x mod N using 2SKD-derived SRP-x. @@ -31,9 +29,7 @@ async def create_srp_verifier( return bytes.fromhex(salt_hex), bytes.fromhex(verifier_hex) -async def srp_init( - email: str, srp_verifier_hex: str -) -> tuple[str, str]: +async def srp_init(email: str, srp_verifier_hex: str) -> tuple[str, str]: """SRP Step 1: Generate server ephemeral (B) and private key (b). Args: @@ -47,14 +43,15 @@ async def srp_init( Raises: ValueError: If SRP initialization fails for any reason. """ + def _init() -> tuple[str, str]: context = SRPContext( - email, prime=PRIME_2048, generator=PRIME_2048_GEN, + email, + prime=PRIME_2048, + generator=PRIME_2048_GEN, hash_func=_SRP_HASH, ) - server_session = SRPServerSession( - context, srp_verifier_hex - ) + server_session = SRPServerSession(context, srp_verifier_hex) return server_session.public, server_session.private try: @@ -85,26 +82,27 @@ async def srp_verify( Tuple of (is_valid, server_proof_hex_or_none). If valid, server_proof is M2 for the client to verify. """ + def _verify() -> tuple[bool, str | None]: - import logging - log = logging.getLogger("srp_debug") context = SRPContext( - email, prime=PRIME_2048, generator=PRIME_2048_GEN, + email, + prime=PRIME_2048, + generator=PRIME_2048_GEN, hash_func=_SRP_HASH, ) - server_session = SRPServerSession( - context, srp_verifier_hex, private=server_private - ) + server_session = SRPServerSession(context, srp_verifier_hex, private=server_private) _key, _key_proof, _key_proof_hash = server_session.process(client_public, srp_salt_hex) # srptools verify_proof has a Python 3 bug: hexlify() returns bytes # but client_proof is str, so bytes == str is always False. # Compare manually with consistent types. - server_m1 = _key_proof if isinstance(_key_proof, str) else _key_proof.decode('ascii') + server_m1 = _key_proof if isinstance(_key_proof, str) else _key_proof.decode("ascii") is_valid = client_proof.lower() == server_m1.lower() if not is_valid: return False, None # Return M2 (key_proof_hash), also fixing the bytes/str issue - m2 = _key_proof_hash if isinstance(_key_proof_hash, str) else _key_proof_hash.decode('ascii') + m2 = ( + _key_proof_hash if isinstance(_key_proof_hash, str) else _key_proof_hash.decode("ascii") + ) return True, m2 try: diff --git a/backend/app/services/sse_manager.py b/backend/app/services/sse_manager.py index db241b5..ab15064 100644 --- a/backend/app/services/sse_manager.py +++ b/backend/app/services/sse_manager.py @@ -137,7 +137,9 @@ class SSEConnectionManager: if last_event_id is not None: try: start_seq = int(last_event_id) + 1 - consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.BY_START_SEQUENCE, opt_start_seq=start_seq) + consumer_cfg = ConsumerConfig( + deliver_policy=DeliverPolicy.BY_START_SEQUENCE, opt_start_seq=start_seq + ) except (ValueError, TypeError): consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.NEW) else: @@ -173,18 +175,32 @@ class SSEConnectionManager: except Exception as exc: if "stream not found" in str(exc): try: - await js.add_stream(StreamConfig( - name="ALERT_EVENTS", - subjects=_ALERT_EVENT_SUBJECTS, - max_age=3600, - )) - sub = await js.subscribe(subject, stream="ALERT_EVENTS", config=consumer_cfg) + await js.add_stream( + StreamConfig( + name="ALERT_EVENTS", + subjects=_ALERT_EVENT_SUBJECTS, + max_age=3600, + ) + ) + sub = await js.subscribe( + subject, stream="ALERT_EVENTS", config=consumer_cfg + ) self._subscriptions.append(sub) logger.info("sse.stream_created_lazily", stream="ALERT_EVENTS") except Exception as retry_exc: - logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(retry_exc)) + logger.warning( + "sse.subscribe_failed", + subject=subject, + stream="ALERT_EVENTS", + error=str(retry_exc), + ) else: - logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(exc)) + logger.warning( + "sse.subscribe_failed", + subject=subject, + stream="ALERT_EVENTS", + error=str(exc), + ) # Subscribe to operation events (OPERATION_EVENTS stream) for subject in _OPERATION_EVENT_SUBJECTS: @@ -198,18 +214,32 @@ class SSEConnectionManager: except Exception as exc: if "stream not found" in str(exc): try: - await js.add_stream(StreamConfig( - name="OPERATION_EVENTS", - subjects=_OPERATION_EVENT_SUBJECTS, - max_age=3600, - )) - sub = await js.subscribe(subject, stream="OPERATION_EVENTS", config=consumer_cfg) + await js.add_stream( + StreamConfig( + name="OPERATION_EVENTS", + subjects=_OPERATION_EVENT_SUBJECTS, + max_age=3600, + ) + ) + sub = await js.subscribe( + subject, stream="OPERATION_EVENTS", config=consumer_cfg + ) self._subscriptions.append(sub) logger.info("sse.stream_created_lazily", stream="OPERATION_EVENTS") except Exception as retry_exc: - logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(retry_exc)) + logger.warning( + "sse.subscribe_failed", + subject=subject, + stream="OPERATION_EVENTS", + error=str(retry_exc), + ) else: - logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(exc)) + logger.warning( + "sse.subscribe_failed", + subject=subject, + stream="OPERATION_EVENTS", + error=str(exc), + ) # Start background task to pull messages from subscriptions into the queue asyncio.create_task(self._pump_messages()) diff --git a/backend/app/services/template_service.py b/backend/app/services/template_service.py index bbb02f4..8808f4c 100644 --- a/backend/app/services/template_service.py +++ b/backend/app/services/template_service.py @@ -12,22 +12,18 @@ separate scheduler and file names to avoid conflicts with restore operations. """ import asyncio -import io import ipaddress import json import logging -import uuid from datetime import datetime, timezone import asyncssh from jinja2 import meta from jinja2.sandbox import SandboxedEnvironment -from sqlalchemy import select, text +from sqlalchemy import text from app.config import settings from app.database import AdminAsyncSessionLocal -from app.models.config_template import TemplatePushJob -from app.models.device import Device logger = logging.getLogger(__name__) @@ -145,7 +141,9 @@ async def push_to_devices(rollout_id: str) -> dict: except Exception as exc: logger.error( "Uncaught exception in template push rollout %s: %s", - rollout_id, exc, exc_info=True, + rollout_id, + exc, + exc_info=True, ) return {"completed": 0, "failed": 1, "pending": 0} @@ -181,7 +179,9 @@ async def _run_push_rollout(rollout_id: str) -> dict: logger.info( "Template push rollout %s: pushing to device %s (job %s)", - rollout_id, hostname, job_id, + rollout_id, + hostname, + job_id, ) await push_single_device(job_id) @@ -200,7 +200,9 @@ async def _run_push_rollout(rollout_id: str) -> dict: failed = True logger.error( "Template push rollout %s paused: device %s %s", - rollout_id, hostname, row[0], + rollout_id, + hostname, + row[0], ) break @@ -232,7 +234,9 @@ async def push_single_device(job_id: str) -> None: except Exception as exc: logger.error( "Uncaught exception in template push job %s: %s", - job_id, exc, exc_info=True, + job_id, + exc, + exc_info=True, ) await _update_job(job_id, status="failed", error_message=f"Unexpected error: {exc}") @@ -260,8 +264,13 @@ async def _run_single_push(job_id: str) -> None: return ( - _, device_id, tenant_id, rendered_content, - ip_address, hostname, encrypted_credentials, + _, + device_id, + tenant_id, + rendered_content, + ip_address, + hostname, + encrypted_credentials, encrypted_credentials_transit, ) = row @@ -279,16 +288,21 @@ async def _run_single_push(job_id: str) -> None: try: from app.services.crypto import decrypt_credentials_hybrid + key = settings.get_encryption_key_bytes() creds_json = await decrypt_credentials_hybrid( - encrypted_credentials_transit, encrypted_credentials, tenant_id, key, + encrypted_credentials_transit, + encrypted_credentials, + tenant_id, + key, ) creds = json.loads(creds_json) ssh_username = creds.get("username", "") ssh_password = creds.get("password", "") except Exception as cred_err: await _update_job( - job_id, status="failed", + job_id, + status="failed", error_message=f"Failed to decrypt credentials: {cred_err}", ) return @@ -297,6 +311,7 @@ async def _run_single_push(job_id: str) -> None: logger.info("Running mandatory pre-push backup for device %s (%s)", hostname, ip_address) try: from app.services import backup_service + backup_result = await backup_service.run_backup( device_id=device_id, tenant_id=tenant_id, @@ -308,7 +323,8 @@ async def _run_single_push(job_id: str) -> None: except Exception as backup_err: logger.error("Pre-push backup failed for %s: %s", hostname, backup_err) await _update_job( - job_id, status="failed", + job_id, + status="failed", error_message=f"Pre-push backup failed: {backup_err}", ) return @@ -316,7 +332,8 @@ async def _run_single_push(job_id: str) -> None: # Step 5: SSH to device - install panic-revert, push config logger.info( "Pushing template to device %s (%s): installing panic-revert and uploading config", - hostname, ip_address, + hostname, + ip_address, ) try: @@ -359,7 +376,8 @@ async def _run_single_push(job_id: str) -> None: ) logger.info( "Template import result for device %s: exit_status=%s stdout=%r", - hostname, import_result.exit_status, + hostname, + import_result.exit_status, (import_result.stdout or "")[:200], ) @@ -369,16 +387,21 @@ async def _run_single_push(job_id: str) -> None: except Exception as cleanup_err: logger.warning( "Failed to clean up %s from device %s: %s", - _TEMPLATE_RSC, ip_address, cleanup_err, + _TEMPLATE_RSC, + ip_address, + cleanup_err, ) except Exception as push_err: logger.error( "SSH push phase failed for device %s (%s): %s", - hostname, ip_address, push_err, + hostname, + ip_address, + push_err, ) await _update_job( - job_id, status="failed", + job_id, + status="failed", error_message=f"Config push failed during SSH phase: {push_err}", ) return @@ -395,9 +418,12 @@ async def _run_single_push(job_id: str) -> None: logger.info("Device %s (%s) is reachable after push - committing", hostname, ip_address) try: async with asyncssh.connect( - ip_address, port=22, - username=ssh_username, password=ssh_password, - known_hosts=None, connect_timeout=30, + ip_address, + port=22, + username=ssh_username, + password=ssh_password, + known_hosts=None, + connect_timeout=30, ) as conn: await conn.run( f'/system scheduler remove "{_PANIC_REVERT_SCHEDULER}"', @@ -410,11 +436,13 @@ async def _run_single_push(job_id: str) -> None: except Exception as cleanup_err: logger.warning( "Failed to clean up panic-revert scheduler/backup on device %s: %s", - hostname, cleanup_err, + hostname, + cleanup_err, ) await _update_job( - job_id, status="committed", + job_id, + status="committed", completed_at=datetime.now(timezone.utc), ) else: @@ -422,10 +450,13 @@ async def _run_single_push(job_id: str) -> None: logger.warning( "Device %s (%s) is unreachable after push - panic-revert scheduler " "will auto-revert to %s.backup", - hostname, ip_address, _PRE_PUSH_BACKUP, + hostname, + ip_address, + _PRE_PUSH_BACKUP, ) await _update_job( - job_id, status="reverted", + job_id, + status="reverted", error_message="Device unreachable after push; auto-reverted via panic-revert scheduler", completed_at=datetime.now(timezone.utc), ) @@ -440,9 +471,12 @@ async def _check_reachability(ip: str, username: str, password: str) -> bool: """Check if a RouterOS device is reachable via SSH.""" try: async with asyncssh.connect( - ip, port=22, - username=username, password=password, - known_hosts=None, connect_timeout=30, + ip, + port=22, + username=username, + password=password, + known_hosts=None, + connect_timeout=30, ) as conn: result = await conn.run("/system identity print", check=True) logger.debug("Reachability check OK for %s: %r", ip, result.stdout[:50]) @@ -459,7 +493,12 @@ async def _update_job(job_id: str, **kwargs) -> None: for key, value in kwargs.items(): param_name = f"v_{key}" - if value is None and key in ("error_message", "started_at", "completed_at", "pre_push_backup_sha"): + if value is None and key in ( + "error_message", + "started_at", + "completed_at", + "pre_push_backup_sha", + ): sets.append(f"{key} = NULL") else: sets.append(f"{key} = :{param_name}") @@ -472,7 +511,7 @@ async def _update_job(job_id: str, **kwargs) -> None: await session.execute( text(f""" UPDATE template_push_jobs - SET {', '.join(sets)} + SET {", ".join(sets)} WHERE id = CAST(:job_id AS uuid) """), params, diff --git a/backend/app/services/upgrade_service.py b/backend/app/services/upgrade_service.py index ead083b..7de22d1 100644 --- a/backend/app/services/upgrade_service.py +++ b/backend/app/services/upgrade_service.py @@ -13,7 +13,6 @@ jobs may span multiple tenants and run in background asyncio tasks. """ import asyncio -import io import json import logging from datetime import datetime, timezone @@ -99,10 +98,19 @@ async def _run_upgrade(job_id: str) -> None: return ( - _, device_id, tenant_id, target_version, - architecture, channel, status, confirmed_major, - ip_address, hostname, encrypted_credentials, - current_version, encrypted_credentials_transit, + _, + device_id, + tenant_id, + target_version, + architecture, + channel, + status, + confirmed_major, + ip_address, + hostname, + encrypted_credentials, + current_version, + encrypted_credentials_transit, ) = row device_id = str(device_id) @@ -116,12 +124,22 @@ async def _run_upgrade(job_id: str) -> None: logger.info( "Starting firmware upgrade for %s (%s): %s -> %s", - hostname, ip_address, current_version, target_version, + hostname, + ip_address, + current_version, + target_version, ) # Step 2: Update status to downloading await _update_job(job_id, status="downloading", started_at=datetime.now(timezone.utc)) - await _publish_upgrade_progress(tenant_id, device_id, job_id, "downloading", target_version, f"Downloading firmware {target_version} for {hostname}") + await _publish_upgrade_progress( + tenant_id, + device_id, + job_id, + "downloading", + target_version, + f"Downloading firmware {target_version} for {hostname}", + ) # Step 3: Check major version upgrade confirmation if current_version and target_version: @@ -133,13 +151,22 @@ async def _run_upgrade(job_id: str) -> None: status="failed", error_message="Major version upgrade requires explicit confirmation", ) - await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Major version upgrade requires explicit confirmation for {hostname}", error="Major version upgrade requires explicit confirmation") + await _publish_upgrade_progress( + tenant_id, + device_id, + job_id, + "failed", + target_version, + f"Major version upgrade requires explicit confirmation for {hostname}", + error="Major version upgrade requires explicit confirmation", + ) return # Step 4: Mandatory config backup logger.info("Running mandatory pre-upgrade backup for %s", hostname) try: from app.services import backup_service + backup_result = await backup_service.run_backup( device_id=device_id, tenant_id=tenant_id, @@ -155,13 +182,22 @@ async def _run_upgrade(job_id: str) -> None: status="failed", error_message=f"Pre-upgrade backup failed: {backup_err}", ) - await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Pre-upgrade backup failed for {hostname}", error=str(backup_err)) + await _publish_upgrade_progress( + tenant_id, + device_id, + job_id, + "failed", + target_version, + f"Pre-upgrade backup failed for {hostname}", + error=str(backup_err), + ) return # Step 5: Download NPK logger.info("Downloading firmware %s for %s/%s", target_version, architecture, channel) try: from app.services.firmware_service import download_firmware + npk_path = await download_firmware(architecture, channel, target_version) logger.info("Firmware cached at %s", npk_path) except Exception as dl_err: @@ -171,24 +207,51 @@ async def _run_upgrade(job_id: str) -> None: status="failed", error_message=f"Firmware download failed: {dl_err}", ) - await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Firmware download failed for {hostname}", error=str(dl_err)) + await _publish_upgrade_progress( + tenant_id, + device_id, + job_id, + "failed", + target_version, + f"Firmware download failed for {hostname}", + error=str(dl_err), + ) return # Step 6: Upload NPK to device via SFTP await _update_job(job_id, status="uploading") - await _publish_upgrade_progress(tenant_id, device_id, job_id, "uploading", target_version, f"Uploading firmware to {hostname}") + await _publish_upgrade_progress( + tenant_id, + device_id, + job_id, + "uploading", + target_version, + f"Uploading firmware to {hostname}", + ) # Decrypt device credentials (dual-read: Transit preferred, legacy fallback) if not encrypted_credentials_transit and not encrypted_credentials: await _update_job(job_id, status="failed", error_message="Device has no stored credentials") - await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"No stored credentials for {hostname}", error="Device has no stored credentials") + await _publish_upgrade_progress( + tenant_id, + device_id, + job_id, + "failed", + target_version, + f"No stored credentials for {hostname}", + error="Device has no stored credentials", + ) return try: from app.services.crypto import decrypt_credentials_hybrid + key = settings.get_encryption_key_bytes() creds_json = await decrypt_credentials_hybrid( - encrypted_credentials_transit, encrypted_credentials, tenant_id, key, + encrypted_credentials_transit, + encrypted_credentials, + tenant_id, + key, ) creds = json.loads(creds_json) ssh_username = creds.get("username", "") @@ -199,7 +262,15 @@ async def _run_upgrade(job_id: str) -> None: status="failed", error_message=f"Failed to decrypt credentials: {cred_err}", ) - await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Failed to decrypt credentials for {hostname}", error=str(cred_err)) + await _publish_upgrade_progress( + tenant_id, + device_id, + job_id, + "failed", + target_version, + f"Failed to decrypt credentials for {hostname}", + error=str(cred_err), + ) return try: @@ -225,12 +296,27 @@ async def _run_upgrade(job_id: str) -> None: status="failed", error_message=f"NPK upload failed: {upload_err}", ) - await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"NPK upload failed for {hostname}", error=str(upload_err)) + await _publish_upgrade_progress( + tenant_id, + device_id, + job_id, + "failed", + target_version, + f"NPK upload failed for {hostname}", + error=str(upload_err), + ) return # Step 7: Trigger reboot await _update_job(job_id, status="rebooting") - await _publish_upgrade_progress(tenant_id, device_id, job_id, "rebooting", target_version, f"Rebooting {hostname} for firmware install") + await _publish_upgrade_progress( + tenant_id, + device_id, + job_id, + "rebooting", + target_version, + f"Rebooting {hostname} for firmware install", + ) try: async with asyncssh.connect( ip_address, @@ -245,7 +331,9 @@ async def _run_upgrade(job_id: str) -> None: logger.info("Reboot command sent to %s", hostname) except Exception as reboot_err: # Device may drop connection during reboot — this is expected - logger.info("Device %s dropped connection after reboot command (expected): %s", hostname, reboot_err) + logger.info( + "Device %s dropped connection after reboot command (expected): %s", hostname, reboot_err + ) # Step 8: Wait for reconnect logger.info("Waiting %ds before polling %s for reconnect", _INITIAL_WAIT, hostname) @@ -267,36 +355,69 @@ async def _run_upgrade(job_id: str) -> None: status="failed", error_message=f"Device did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes after reboot", ) - await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Device {hostname} did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes", error="Reconnect timeout") + await _publish_upgrade_progress( + tenant_id, + device_id, + job_id, + "failed", + target_version, + f"Device {hostname} did not reconnect within {_RECONNECT_TIMEOUT // 60} minutes", + error="Reconnect timeout", + ) return # Step 9: Verify upgrade await _update_job(job_id, status="verifying") - await _publish_upgrade_progress(tenant_id, device_id, job_id, "verifying", target_version, f"Verifying firmware version on {hostname}") + await _publish_upgrade_progress( + tenant_id, + device_id, + job_id, + "verifying", + target_version, + f"Verifying firmware version on {hostname}", + ) try: actual_version = await _get_device_version(ip_address, ssh_username, ssh_password) if actual_version and target_version in actual_version: logger.info( "Firmware upgrade verified for %s: %s", - hostname, actual_version, + hostname, + actual_version, ) await _update_job( job_id, status="completed", completed_at=datetime.now(timezone.utc), ) - await _publish_upgrade_progress(tenant_id, device_id, job_id, "completed", target_version, f"Firmware upgrade to {target_version} completed on {hostname}") + await _publish_upgrade_progress( + tenant_id, + device_id, + job_id, + "completed", + target_version, + f"Firmware upgrade to {target_version} completed on {hostname}", + ) else: logger.error( "Version mismatch for %s: expected %s, got %s", - hostname, target_version, actual_version, + hostname, + target_version, + actual_version, ) await _update_job( job_id, status="failed", error_message=f"Expected {target_version} but got {actual_version}", ) - await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Version mismatch on {hostname}: expected {target_version}, got {actual_version}", error=f"Expected {target_version} but got {actual_version}") + await _publish_upgrade_progress( + tenant_id, + device_id, + job_id, + "failed", + target_version, + f"Version mismatch on {hostname}: expected {target_version}, got {actual_version}", + error=f"Expected {target_version} but got {actual_version}", + ) except Exception as verify_err: logger.error("Post-upgrade verification failed for %s: %s", hostname, verify_err) await _update_job( @@ -304,7 +425,15 @@ async def _run_upgrade(job_id: str) -> None: status="failed", error_message=f"Post-upgrade verification failed: {verify_err}", ) - await _publish_upgrade_progress(tenant_id, device_id, job_id, "failed", target_version, f"Post-upgrade verification failed for {hostname}", error=str(verify_err)) + await _publish_upgrade_progress( + tenant_id, + device_id, + job_id, + "failed", + target_version, + f"Post-upgrade verification failed for {hostname}", + error=str(verify_err), + ) async def start_mass_upgrade(rollout_group_id: str) -> dict: @@ -457,7 +586,7 @@ async def resume_mass_upgrade(rollout_group_id: str) -> None: """Resume a paused mass rollout from where it left off.""" # Reset first paused job to pending, then restart sequential processing async with AdminAsyncSessionLocal() as session: - result = await session.execute( + await session.execute( text(""" UPDATE firmware_upgrade_jobs SET status = 'pending' @@ -519,7 +648,7 @@ async def _update_job(job_id: str, **kwargs) -> None: await session.execute( text(f""" UPDATE firmware_upgrade_jobs - SET {', '.join(sets)} + SET {", ".join(sets)} WHERE id = CAST(:job_id AS uuid) """), params, diff --git a/backend/app/services/vpn_service.py b/backend/app/services/vpn_service.py index 5966b5b..7be4dc1 100644 --- a/backend/app/services/vpn_service.py +++ b/backend/app/services/vpn_service.py @@ -26,7 +26,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.config import settings from app.models.device import Device from app.models.vpn import VpnConfig, VpnPeer -from app.services.crypto import decrypt_credentials, encrypt_credentials, encrypt_credentials_transit +from app.services.crypto import ( + decrypt_credentials, + encrypt_credentials, + encrypt_credentials_transit, +) logger = structlog.get_logger(__name__) @@ -62,7 +66,9 @@ async def _get_or_create_global_server_key(db: AsyncSession) -> tuple[str, str]: await db.execute(sa_text("SELECT pg_advisory_xact_lock(hashtext('vpn_server_keygen'))")) result = await db.execute( - sa_text("SELECT key, value, encrypted_value FROM system_settings WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')") + sa_text( + "SELECT key, value, encrypted_value FROM system_settings WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')" + ) ) rows = {row[0]: row for row in result.fetchall()} @@ -188,7 +194,9 @@ async def sync_wireguard_config() -> None: # Query ALL enabled VPN configs (admin session bypasses RLS) configs_result = await admin_db.execute( - select(VpnConfig).where(VpnConfig.is_enabled.is_(True)).order_by(VpnConfig.subnet_index) + select(VpnConfig) + .where(VpnConfig.is_enabled.is_(True)) + .order_by(VpnConfig.subnet_index) ) configs = configs_result.scalars().all() @@ -228,7 +236,9 @@ async def sync_wireguard_config() -> None: peer_ip = peer.assigned_ip.split("/")[0] allowed_ips = [f"{peer_ip}/32"] if peer.additional_allowed_ips: - extra = [s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip()] + extra = [ + s.strip() for s in peer.additional_allowed_ips.split(",") if s.strip() + ] allowed_ips.extend(extra) lines.append("[Peer]") lines.append(f"PublicKey = {peer.peer_public_key}") @@ -253,12 +263,13 @@ async def sync_wireguard_config() -> None: # Docker traffic (172.16.0.0/12) going to each tenant's subnet # gets SNATted to that tenant's gateway IP (.1) so the router # can route replies back through the tunnel. - nat_lines = ["#!/bin/sh", - "# Auto-generated per-tenant SNAT rules", - "# Remove old rules", - "iptables -t nat -F POSTROUTING 2>/dev/null", - "# Re-add Docker DNS rules", - ] + nat_lines = [ + "#!/bin/sh", + "# Auto-generated per-tenant SNAT rules", + "# Remove old rules", + "iptables -t nat -F POSTROUTING 2>/dev/null", + "# Re-add Docker DNS rules", + ] for config in configs: gateway_ip = config.server_address.split("/")[0] # e.g. 10.10.3.1 subnet = config.subnet # e.g. 10.10.3.0/24 @@ -275,12 +286,15 @@ async def sync_wireguard_config() -> None: reload_flag = wg_confs_dir / ".reload" reload_flag.write_text("1") - logger.info("wireguard_config_synced", audit=True, - tenants=len(configs), peers=total_peers) + logger.info( + "wireguard_config_synced", audit=True, tenants=len(configs), peers=total_peers + ) finally: # Release advisory lock explicitly (session-level lock, not xact-level) - await admin_db.execute(sa_text("SELECT pg_advisory_unlock(hashtext('wireguard_config'))")) + await admin_db.execute( + sa_text("SELECT pg_advisory_unlock(hashtext('wireguard_config'))") + ) # ── Live Status ── @@ -371,15 +385,23 @@ async def setup_vpn( db.add(config) await db.flush() - logger.info("vpn_subnet_allocated", audit=True, - tenant_id=str(tenant_id), subnet_index=subnet_index, subnet=subnet) + logger.info( + "vpn_subnet_allocated", + audit=True, + tenant_id=str(tenant_id), + subnet_index=subnet_index, + subnet=subnet, + ) await _commit_and_sync(db) return config async def update_vpn_config( - db: AsyncSession, tenant_id: uuid.UUID, endpoint: Optional[str] = None, is_enabled: Optional[bool] = None + db: AsyncSession, + tenant_id: uuid.UUID, + endpoint: Optional[str] = None, + is_enabled: Optional[bool] = None, ) -> VpnConfig: """Update VPN config settings.""" config = await get_vpn_config(db, tenant_id) @@ -422,14 +444,21 @@ async def _next_available_ip(db: AsyncSession, tenant_id: uuid.UUID, config: Vpn raise ValueError("No available IPs in VPN subnet") -async def add_peer(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UUID, additional_allowed_ips: Optional[str] = None) -> VpnPeer: +async def add_peer( + db: AsyncSession, + tenant_id: uuid.UUID, + device_id: uuid.UUID, + additional_allowed_ips: Optional[str] = None, +) -> VpnPeer: """Add a device as a VPN peer.""" config = await get_vpn_config(db, tenant_id) if not config: raise ValueError("VPN not configured — enable VPN first") # Check device exists - device = await db.execute(select(Device).where(Device.id == device_id, Device.tenant_id == tenant_id)) + device = await db.execute( + select(Device).where(Device.id == device_id, Device.tenant_id == tenant_id) + ) if not device.scalar_one_or_none(): raise ValueError("Device not found") @@ -497,13 +526,12 @@ async def get_peer_config(db: AsyncSession, tenant_id: uuid.UUID, peer_id: uuid. psk = decrypt_credentials(peer.preshared_key, key_bytes) if peer.preshared_key else None endpoint = config.endpoint or "YOUR_SERVER_IP:51820" - peer_ip_no_cidr = peer.assigned_ip.split("/")[0] routeros_commands = [ f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key}"', f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" ' - f'endpoint-address={endpoint.split(":")[0]} endpoint-port={endpoint.split(":")[-1]} ' - f'allowed-address=10.10.0.0/16 persistent-keepalive=25' + f"endpoint-address={endpoint.split(':')[0]} endpoint-port={endpoint.split(':')[-1]} " + f"allowed-address=10.10.0.0/16 persistent-keepalive=25" + (f' preshared-key="{psk}"' if psk else ""), f"/ip address add address={peer.assigned_ip} interface=wg-portal", ] @@ -583,8 +611,8 @@ async def onboard_device( routeros_commands = [ f'/interface wireguard add name=wg-portal listen-port=13231 private-key="{private_key_b64}"', f'/interface wireguard peers add interface=wg-portal public-key="{config.server_public_key}" ' - f'endpoint-address={endpoint.split(":")[0]} endpoint-port={endpoint.split(":")[-1]} ' - f'allowed-address=10.10.0.0/16 persistent-keepalive=25' + f"endpoint-address={endpoint.split(':')[0]} endpoint-port={endpoint.split(':')[-1]} " + f"allowed-address=10.10.0.0/16 persistent-keepalive=25" f' preshared-key="{psk_decrypted}"', f"/ip address add address={assigned_ip} interface=wg-portal", ] diff --git a/backend/pyproject.toml b/backend/pyproject.toml index f052c89..b80689e 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -57,3 +57,13 @@ markers = [ [tool.ruff] line-length = 100 target-version = "py312" + +[tool.ruff.lint] +select = ["E", "F", "W"] +ignore = [ + "F821", # SQLAlchemy uses forward-reference strings in relationships + "E501", # Line too long — pre-existing in migrations and SQL strings +] + +[tool.ruff.lint.per-file-ignores] +"alembic/versions/*.py" = ["E402"] # Alembic puts revision vars before imports diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 0f94577..331a030 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -6,8 +6,6 @@ Phase 10: Integration test fixtures added in tests/integration/conftest.py. Pytest marker registration and shared configuration lives here. """ -import pytest - def pytest_configure(config): """Register custom markers.""" diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py index 13d3d03..be2446f 100644 --- a/backend/tests/integration/conftest.py +++ b/backend/tests/integration/conftest.py @@ -241,6 +241,7 @@ async def test_app(admin_engine, app_engine): # Register rate limiter (auth endpoints use @limiter.limit) from app.middleware.rate_limit import setup_rate_limiting + setup_rate_limiting(app) # Create test session factories diff --git a/backend/tests/integration/test_alerts_api.py b/backend/tests/integration/test_alerts_api.py index 561de0d..9baf4df 100644 --- a/backend/tests/integration/test_alerts_api.py +++ b/backend/tests/integration/test_alerts_api.py @@ -258,9 +258,7 @@ class TestAlertEvents: ): """GET /api/tenants/{tenant_id}/devices/{device_id}/alerts returns paginated response.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] device = await create_test_device(admin_session, tenant.id) await admin_session.commit() diff --git a/backend/tests/integration/test_auth_api.py b/backend/tests/integration/test_auth_api.py index 591765b..d57b78a 100644 --- a/backend/tests/integration/test_auth_api.py +++ b/backend/tests/integration/test_auth_api.py @@ -19,7 +19,7 @@ from app.services.auth import hash_password pytestmark = pytest.mark.integration -from tests.integration.conftest import TEST_DATABASE_URL +from tests.integration.conftest import TEST_DATABASE_URL # noqa: E402 # --------------------------------------------------------------------------- @@ -93,7 +93,6 @@ async def test_login_success(client, admin_engine): assert len(body["refresh_token"]) > 0 # Verify httpOnly cookie is set - cookies = resp.cookies # Cookie may or may not appear in httpx depending on secure flag # Just verify the response contains Set-Cookie header set_cookie = resp.headers.get("set-cookie", "") @@ -193,7 +192,6 @@ async def test_token_refresh(client, admin_engine): assert login_resp.status_code == 200 tokens = login_resp.json() refresh_token = tokens["refresh_token"] - original_access = tokens["access_token"] # Use refresh token to get new access token refresh_resp = await client.post( diff --git a/backend/tests/integration/test_config_api.py b/backend/tests/integration/test_config_api.py index 4fcaeb6..073ebf1 100644 --- a/backend/tests/integration/test_config_api.py +++ b/backend/tests/integration/test_config_api.py @@ -32,9 +32,7 @@ class TestConfigBackups: ): """GET config backups for a device with no backups returns 200 + empty list.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] device = await create_test_device(admin_session, tenant.id) await admin_session.commit() @@ -58,9 +56,7 @@ class TestConfigBackups: ): """GET schedule returns synthetic default when no schedule configured.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] device = await create_test_device(admin_session, tenant.id) await admin_session.commit() @@ -118,9 +114,7 @@ class TestConfigBackups: ): """Config backup router responds (not 404) for expected paths.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] device = await create_test_device(admin_session, tenant.id) await admin_session.commit() @@ -143,7 +137,5 @@ class TestConfigBackups: """GET config backups without auth returns 401.""" tenant_id = str(uuid.uuid4()) device_id = str(uuid.uuid4()) - resp = await client.get( - f"/api/tenants/{tenant_id}/devices/{device_id}/config/backups" - ) + resp = await client.get(f"/api/tenants/{tenant_id}/devices/{device_id}/config/backups") assert resp.status_code == 401 diff --git a/backend/tests/integration/test_devices_api.py b/backend/tests/integration/test_devices_api.py index 555df7b..396062f 100644 --- a/backend/tests/integration/test_devices_api.py +++ b/backend/tests/integration/test_devices_api.py @@ -10,7 +10,6 @@ All tests are independent and create their own test data. import uuid import pytest -import pytest_asyncio pytestmark = pytest.mark.integration @@ -90,9 +89,7 @@ class TestDevicesCRUD: ): """GET /api/tenants/{tenant_id}/devices/{device_id} returns correct device.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] device = await create_test_device(admin_session, tenant.id) @@ -178,9 +175,7 @@ class TestDevicesCRUD: ): """GET /api/tenants/{tenant_id}/devices?status=online returns filtered results.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] # Create devices with different statuses diff --git a/backend/tests/integration/test_monitoring_api.py b/backend/tests/integration/test_monitoring_api.py index 738fb18..263818e 100644 --- a/backend/tests/integration/test_monitoring_api.py +++ b/backend/tests/integration/test_monitoring_api.py @@ -36,9 +36,7 @@ class TestHealthMetrics: ): """GET health metrics for a device with no data returns 200 + empty list.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] device = await create_test_device(admin_session, tenant.id) @@ -68,9 +66,7 @@ class TestHealthMetrics: ): """GET health metrics returns bucketed data when rows exist.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] device = await create_test_device(admin_session, tenant.id) await admin_session.flush() @@ -131,9 +127,7 @@ class TestInterfaceMetrics: ): """GET interface metrics for device with no data returns 200 + empty list.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] device = await create_test_device(admin_session, tenant.id) await admin_session.commit() @@ -160,9 +154,7 @@ class TestInterfaceMetrics: ): """GET interface list for device with no data returns 200 + empty list.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] device = await create_test_device(admin_session, tenant.id) await admin_session.commit() @@ -188,9 +180,7 @@ class TestSparkline: ): """GET sparkline for device with no data returns 200 + empty list.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] device = await create_test_device(admin_session, tenant.id) await admin_session.commit() @@ -215,9 +205,7 @@ class TestFleetSummary: ): """GET /api/tenants/{tenant_id}/fleet/summary returns 200 with empty fleet.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] resp = await client.get( @@ -238,9 +226,7 @@ class TestFleetSummary: ): """GET fleet summary returns device data when devices exist.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] await create_test_device(admin_session, tenant.id, hostname="fleet-dev-1") @@ -279,9 +265,7 @@ class TestWirelessMetrics: ): """GET wireless metrics for device with no data returns 200 + empty list.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] device = await create_test_device(admin_session, tenant.id) await admin_session.commit() @@ -308,9 +292,7 @@ class TestWirelessMetrics: ): """GET wireless latest for device with no data returns 200 + empty list.""" tenant = await create_test_tenant(admin_session) - auth = await auth_headers_factory( - admin_session, existing_tenant_id=tenant.id - ) + auth = await auth_headers_factory(admin_session, existing_tenant_id=tenant.id) tenant_id = auth["tenant_id"] device = await create_test_device(admin_session, tenant.id) await admin_session.commit() diff --git a/backend/tests/integration/test_rls_isolation.py b/backend/tests/integration/test_rls_isolation.py index bbd1366..0a45d24 100644 --- a/backend/tests/integration/test_rls_isolation.py +++ b/backend/tests/integration/test_rls_isolation.py @@ -28,7 +28,7 @@ from app.services.auth import hash_password pytestmark = pytest.mark.integration # Use the same test DB URLs as conftest -from tests.integration.conftest import TEST_APP_USER_DATABASE_URL, TEST_DATABASE_URL +from tests.integration.conftest import TEST_APP_USER_DATABASE_URL, TEST_DATABASE_URL # noqa: E402 # --------------------------------------------------------------------------- @@ -86,12 +86,20 @@ async def test_tenant_a_cannot_see_tenant_b_devices(): await session.flush() da = Device( - tenant_id=ta.id, hostname=f"rls-ra-{uid}", ip_address="10.1.1.1", - api_port=8728, api_ssl_port=8729, status="online", + tenant_id=ta.id, + hostname=f"rls-ra-{uid}", + ip_address="10.1.1.1", + api_port=8728, + api_ssl_port=8729, + status="online", ) db = Device( - tenant_id=tb.id, hostname=f"rls-rb-{uid}", ip_address="10.1.1.2", - api_port=8728, api_ssl_port=8729, status="online", + tenant_id=tb.id, + hostname=f"rls-rb-{uid}", + ip_address="10.1.1.2", + api_port=8728, + api_ssl_port=8729, + status="online", ) session.add_all([da, db]) await session.flush() @@ -129,12 +137,20 @@ async def test_tenant_a_cannot_see_tenant_b_alerts(): await session.flush() ra = AlertRule( - tenant_id=ta.id, name=f"CPU Alert A {uid}", - metric="cpu_load", operator=">", threshold=90.0, severity="warning", + tenant_id=ta.id, + name=f"CPU Alert A {uid}", + metric="cpu_load", + operator=">", + threshold=90.0, + severity="warning", ) rb = AlertRule( - tenant_id=tb.id, name=f"CPU Alert B {uid}", - metric="cpu_load", operator=">", threshold=85.0, severity="critical", + tenant_id=tb.id, + name=f"CPU Alert B {uid}", + metric="cpu_load", + operator=">", + threshold=85.0, + severity="critical", ) session.add_all([ra, rb]) await session.flush() @@ -256,12 +272,20 @@ async def test_super_admin_sees_all_tenants(): await session.flush() da = Device( - tenant_id=ta.id, hostname=f"sa-ra-{uid}", ip_address="10.2.1.1", - api_port=8728, api_ssl_port=8729, status="online", + tenant_id=ta.id, + hostname=f"sa-ra-{uid}", + ip_address="10.2.1.1", + api_port=8728, + api_ssl_port=8729, + status="online", ) db = Device( - tenant_id=tb.id, hostname=f"sa-rb-{uid}", ip_address="10.2.1.2", - api_port=8728, api_ssl_port=8729, status="online", + tenant_id=tb.id, + hostname=f"sa-rb-{uid}", + ip_address="10.2.1.2", + api_port=8728, + api_ssl_port=8729, + status="online", ) session.add_all([da, db]) await session.flush() @@ -311,31 +335,45 @@ async def test_api_rls_isolation_devices_endpoint(client, admin_engine): ua = User( email=f"api-ua-{uid}@example.com", hashed_password=hash_password("TestPass123!"), - name="User A", role="tenant_admin", - tenant_id=ta.id, is_active=True, + name="User A", + role="tenant_admin", + tenant_id=ta.id, + is_active=True, ) ub = User( email=f"api-ub-{uid}@example.com", hashed_password=hash_password("TestPass123!"), - name="User B", role="tenant_admin", - tenant_id=tb.id, is_active=True, + name="User B", + role="tenant_admin", + tenant_id=tb.id, + is_active=True, ) session.add_all([ua, ub]) await session.flush() da = Device( - tenant_id=ta.id, hostname=f"api-ra-{uid}", ip_address="10.3.1.1", - api_port=8728, api_ssl_port=8729, status="online", + tenant_id=ta.id, + hostname=f"api-ra-{uid}", + ip_address="10.3.1.1", + api_port=8728, + api_ssl_port=8729, + status="online", ) db = Device( - tenant_id=tb.id, hostname=f"api-rb-{uid}", ip_address="10.3.1.2", - api_port=8728, api_ssl_port=8729, status="online", + tenant_id=tb.id, + hostname=f"api-rb-{uid}", + ip_address="10.3.1.2", + api_port=8728, + api_ssl_port=8729, + status="online", ) session.add_all([da, db]) await session.flush() return { - "ta_id": str(ta.id), "tb_id": str(tb.id), - "ua_email": ua.email, "ub_email": ub.email, + "ta_id": str(ta.id), + "tb_id": str(tb.id), + "ua_email": ua.email, + "ub_email": ub.email, } ids = await _admin_commit(TEST_DATABASE_URL, setup) @@ -398,21 +436,29 @@ async def test_api_rls_isolation_cross_tenant_device_access(client, admin_engine ua = User( email=f"api-xt-ua-{uid}@example.com", hashed_password=hash_password("TestPass123!"), - name="User A", role="tenant_admin", - tenant_id=ta.id, is_active=True, + name="User A", + role="tenant_admin", + tenant_id=ta.id, + is_active=True, ) session.add(ua) await session.flush() db = Device( - tenant_id=tb.id, hostname=f"api-xt-rb-{uid}", ip_address="10.4.1.1", - api_port=8728, api_ssl_port=8729, status="online", + tenant_id=tb.id, + hostname=f"api-xt-rb-{uid}", + ip_address="10.4.1.1", + api_port=8728, + api_ssl_port=8729, + status="online", ) session.add(db) await session.flush() return { - "ta_id": str(ta.id), "tb_id": str(tb.id), - "ua_email": ua.email, "db_id": str(db.id), + "ta_id": str(ta.id), + "tb_id": str(tb.id), + "ua_email": ua.email, + "db_id": str(db.id), } ids = await _admin_commit(TEST_DATABASE_URL, setup) diff --git a/backend/tests/integration/test_vpn_isolation.py b/backend/tests/integration/test_vpn_isolation.py index 14e5025..476ceb5 100644 --- a/backend/tests/integration/test_vpn_isolation.py +++ b/backend/tests/integration/test_vpn_isolation.py @@ -5,22 +5,16 @@ tenant deletion cleanup, and allowed-IPs validation. """ import os -import uuid from unittest.mock import AsyncMock, patch import pytest -import pytest_asyncio -from sqlalchemy import select, text +from sqlalchemy import select -from app.models.vpn import VpnConfig, VpnPeer from app.services.vpn_service import ( add_peer, get_peer_config, get_vpn_config, - remove_peer, setup_vpn, - sync_wireguard_config, - _get_wg_config_path, ) pytestmark = pytest.mark.integration @@ -213,9 +207,8 @@ class TestTenantDeletion: # Delete tenant 2 from app.models.tenant import Tenant - result = await admin_session.execute( - select(Tenant).where(Tenant.id == t2.id) - ) + + result = await admin_session.execute(select(Tenant).where(Tenant.id == t2.id)) tenant_obj = result.scalar_one() await admin_session.delete(tenant_obj) await admin_session.flush() diff --git a/backend/tests/test_audit_config_backup.py b/backend/tests/test_audit_config_backup.py index 7c71407..6d0f79c 100644 --- a/backend/tests/test_audit_config_backup.py +++ b/backend/tests/test_audit_config_backup.py @@ -62,24 +62,31 @@ async def test_snapshot_created_audit_event(): mock_log_action = AsyncMock() - with patch( - "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", - return_value=mock_ctx, - ), patch( - "app.services.config_snapshot_subscriber.OpenBaoTransitService", - return_value=mock_openbao, - ), patch( - "app.services.config_snapshot_subscriber.generate_and_store_diff", - new_callable=AsyncMock, - ), patch( - "app.services.config_snapshot_subscriber.log_action", - mock_log_action, + with ( + patch( + "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", + return_value=mock_ctx, + ), + patch( + "app.services.config_snapshot_subscriber.OpenBaoTransitService", + return_value=mock_openbao, + ), + patch( + "app.services.config_snapshot_subscriber.generate_and_store_diff", + new_callable=AsyncMock, + ), + patch( + "app.services.config_snapshot_subscriber.log_action", + mock_log_action, + ), ): await handle_config_snapshot(msg) # log_action should have been called with config_snapshot_created - actions = [call.kwargs.get("action", call.args[4] if len(call.args) > 4 else None) - for call in mock_log_action.call_args_list] + actions = [ + call.kwargs.get("action", call.args[4] if len(call.args) > 4 else None) + for call in mock_log_action.call_args_list + ] assert "config_snapshot_created" in actions @@ -103,21 +110,27 @@ async def test_snapshot_skipped_duplicate_audit_event(): mock_log_action = AsyncMock() - with patch( - "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", - return_value=mock_ctx, - ), patch( - "app.services.config_snapshot_subscriber.OpenBaoTransitService", - return_value=AsyncMock(), - ), patch( - "app.services.config_snapshot_subscriber.log_action", - mock_log_action, + with ( + patch( + "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", + return_value=mock_ctx, + ), + patch( + "app.services.config_snapshot_subscriber.OpenBaoTransitService", + return_value=AsyncMock(), + ), + patch( + "app.services.config_snapshot_subscriber.log_action", + mock_log_action, + ), ): await handle_config_snapshot(msg) # log_action should have been called with config_snapshot_skipped_duplicate - actions = [call.kwargs.get("action", call.args[4] if len(call.args) > 4 else None) - for call in mock_log_action.call_args_list] + actions = [ + call.kwargs.get("action", call.args[4] if len(call.args) > 4 else None) + for call in mock_log_action.call_args_list + ] assert "config_snapshot_skipped_duplicate" in actions @@ -150,22 +163,28 @@ async def test_diff_generated_audit_event(): mock_session.commit = AsyncMock() mock_openbao = AsyncMock() - mock_openbao.decrypt = AsyncMock(side_effect=[ - old_config.encode("utf-8"), - new_config.encode("utf-8"), - ]) + mock_openbao.decrypt = AsyncMock( + side_effect=[ + old_config.encode("utf-8"), + new_config.encode("utf-8"), + ] + ) mock_log_action = AsyncMock() - with patch( - "app.services.config_diff_service.OpenBaoTransitService", - return_value=mock_openbao, - ), patch( - "app.services.config_diff_service.parse_diff_changes", - return_value=[], - ), patch( - "app.services.audit_service.log_action", - mock_log_action, + with ( + patch( + "app.services.config_diff_service.OpenBaoTransitService", + return_value=mock_openbao, + ), + patch( + "app.services.config_diff_service.parse_diff_changes", + return_value=[], + ), + patch( + "app.services.audit_service.log_action", + mock_log_action, + ), ): await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session) @@ -217,13 +236,21 @@ async def test_manual_trigger_audit_event(): original_enabled = limiter.enabled limiter.enabled = False try: - with patch.object( - cb_module, "_get_nats", return_value=mock_nc, - ), patch.object( - cb_module, "_check_tenant_access", new_callable=AsyncMock, - ), patch( - "app.services.audit_service.log_action", - mock_log_action, + with ( + patch.object( + cb_module, + "_get_nats", + return_value=mock_nc, + ), + patch.object( + cb_module, + "_check_tenant_access", + new_callable=AsyncMock, + ), + patch( + "app.services.audit_service.log_action", + mock_log_action, + ), ): result = await cb_module.trigger_config_snapshot( request=mock_request, diff --git a/backend/tests/test_backup_scheduler.py b/backend/tests/test_backup_scheduler.py index 1f278ba..e4df30e 100644 --- a/backend/tests/test_backup_scheduler.py +++ b/backend/tests/test_backup_scheduler.py @@ -1,7 +1,7 @@ """Tests for dynamic backup scheduling.""" import pytest -from unittest.mock import AsyncMock, patch, MagicMock +from unittest.mock import MagicMock from app.services.backup_scheduler import ( build_schedule_map, diff --git a/backend/tests/test_config_change_parser.py b/backend/tests/test_config_change_parser.py index db68682..9966abf 100644 --- a/backend/tests/test_config_change_parser.py +++ b/backend/tests/test_config_change_parser.py @@ -4,8 +4,6 @@ Tests the parse_diff_changes function that extracts structured RouterOS component changes from unified diffs. """ -import pytest - from app.services.config_change_parser import parse_diff_changes diff --git a/backend/tests/test_config_change_subscriber.py b/backend/tests/test_config_change_subscriber.py index 50168bc..c009a21 100644 --- a/backend/tests/test_config_change_subscriber.py +++ b/backend/tests/test_config_change_subscriber.py @@ -1,8 +1,7 @@ """Tests for config change NATS subscriber.""" import pytest -from datetime import datetime, timedelta, timezone -from unittest.mock import AsyncMock, patch, MagicMock +from unittest.mock import AsyncMock, patch from uuid import uuid4 from app.services.config_change_subscriber import handle_config_changed @@ -18,13 +17,16 @@ async def test_triggers_backup_on_config_change(): "new_timestamp": "2026-03-07 12:00:00", } - with patch( - "app.services.config_change_subscriber.backup_service.run_backup", - new_callable=AsyncMock, - ) as mock_backup, patch( - "app.services.config_change_subscriber._last_backup_within_dedup_window", - new_callable=AsyncMock, - return_value=False, + with ( + patch( + "app.services.config_change_subscriber.backup_service.run_backup", + new_callable=AsyncMock, + ) as mock_backup, + patch( + "app.services.config_change_subscriber._last_backup_within_dedup_window", + new_callable=AsyncMock, + return_value=False, + ), ): await handle_config_changed(event) @@ -42,13 +44,16 @@ async def test_skips_backup_within_dedup_window(): "new_timestamp": "2026-03-07 12:00:00", } - with patch( - "app.services.config_change_subscriber.backup_service.run_backup", - new_callable=AsyncMock, - ) as mock_backup, patch( - "app.services.config_change_subscriber._last_backup_within_dedup_window", - new_callable=AsyncMock, - return_value=True, + with ( + patch( + "app.services.config_change_subscriber.backup_service.run_backup", + new_callable=AsyncMock, + ) as mock_backup, + patch( + "app.services.config_change_subscriber._last_backup_within_dedup_window", + new_callable=AsyncMock, + return_value=True, + ), ): await handle_config_changed(event) diff --git a/backend/tests/test_config_checkpoint.py b/backend/tests/test_config_checkpoint.py index 31e6a2d..48f99ba 100644 --- a/backend/tests/test_config_checkpoint.py +++ b/backend/tests/test_config_checkpoint.py @@ -13,9 +13,7 @@ class TestCheckpointEndpointExists: from app.routers.config_backups import router paths = [r.path for r in router.routes] - assert any("checkpoint" in p for p in paths), ( - f"No checkpoint route found. Routes: {paths}" - ) + assert any("checkpoint" in p for p in paths), f"No checkpoint route found. Routes: {paths}" def test_checkpoint_route_is_post(self): from app.routers.config_backups import router @@ -53,16 +51,20 @@ class TestCheckpointFunction: mock_request = MagicMock() - with patch( - "app.routers.config_backups.backup_service.run_backup", - new_callable=AsyncMock, - return_value=mock_result, - ) as mock_backup, patch( - "app.routers.config_backups._check_tenant_access", - new_callable=AsyncMock, - ), patch( - "app.routers.config_backups.limiter.enabled", - False, + with ( + patch( + "app.routers.config_backups.backup_service.run_backup", + new_callable=AsyncMock, + return_value=mock_result, + ) as mock_backup, + patch( + "app.routers.config_backups._check_tenant_access", + new_callable=AsyncMock, + ), + patch( + "app.routers.config_backups.limiter.enabled", + False, + ), ): result = await create_checkpoint( request=mock_request, diff --git a/backend/tests/test_config_diff_service.py b/backend/tests/test_config_diff_service.py index 2bf6c89..3a3031a 100644 --- a/backend/tests/test_config_diff_service.py +++ b/backend/tests/test_config_diff_service.py @@ -4,9 +4,8 @@ Tests the generate_and_store_diff function with mocked DB sessions and OpenBao Transit service. """ -import json import pytest -from unittest.mock import AsyncMock, MagicMock, patch, call +from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 @@ -51,17 +50,22 @@ async def test_diff_generated_and_stored(): mock_session.commit = AsyncMock() mock_openbao = AsyncMock() - mock_openbao.decrypt = AsyncMock(side_effect=[ - old_config.encode("utf-8"), - new_config.encode("utf-8"), - ]) + mock_openbao.decrypt = AsyncMock( + side_effect=[ + old_config.encode("utf-8"), + new_config.encode("utf-8"), + ] + ) - with patch( - "app.services.config_diff_service.OpenBaoTransitService", - return_value=mock_openbao, - ), patch( - "app.services.config_diff_service.parse_diff_changes", - return_value=[], + with ( + patch( + "app.services.config_diff_service.OpenBaoTransitService", + return_value=mock_openbao, + ), + patch( + "app.services.config_diff_service.parse_diff_changes", + return_value=[], + ), ): await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session) @@ -178,17 +182,22 @@ async def test_line_counts_correct(): mock_session.commit = AsyncMock() mock_openbao = AsyncMock() - mock_openbao.decrypt = AsyncMock(side_effect=[ - old_config.encode("utf-8"), - new_config.encode("utf-8"), - ]) + mock_openbao.decrypt = AsyncMock( + side_effect=[ + old_config.encode("utf-8"), + new_config.encode("utf-8"), + ] + ) - with patch( - "app.services.config_diff_service.OpenBaoTransitService", - return_value=mock_openbao, - ), patch( - "app.services.config_diff_service.parse_diff_changes", - return_value=[], + with ( + patch( + "app.services.config_diff_service.OpenBaoTransitService", + return_value=mock_openbao, + ), + patch( + "app.services.config_diff_service.parse_diff_changes", + return_value=[], + ), ): await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session) @@ -222,10 +231,12 @@ async def test_empty_diff_skips_insert(): mock_session.commit = AsyncMock() mock_openbao = AsyncMock() - mock_openbao.decrypt = AsyncMock(side_effect=[ - same_config.encode("utf-8"), - same_config.encode("utf-8"), - ]) + mock_openbao.decrypt = AsyncMock( + side_effect=[ + same_config.encode("utf-8"), + same_config.encode("utf-8"), + ] + ) with patch( "app.services.config_diff_service.OpenBaoTransitService", @@ -271,22 +282,31 @@ async def test_change_parser_called_and_changes_stored(): mock_session.commit = AsyncMock() mock_openbao = AsyncMock() - mock_openbao.decrypt = AsyncMock(side_effect=[ - old_config.encode("utf-8"), - new_config.encode("utf-8"), - ]) + mock_openbao.decrypt = AsyncMock( + side_effect=[ + old_config.encode("utf-8"), + new_config.encode("utf-8"), + ] + ) mock_changes = [ - {"component": "ip/firewall/filter", "summary": "Added 1 firewall filter rule", "raw_line": "+add chain=forward action=drop"}, + { + "component": "ip/firewall/filter", + "summary": "Added 1 firewall filter rule", + "raw_line": "+add chain=forward action=drop", + }, ] - with patch( - "app.services.config_diff_service.OpenBaoTransitService", - return_value=mock_openbao, - ), patch( - "app.services.config_diff_service.parse_diff_changes", - return_value=mock_changes, - ) as mock_parser: + with ( + patch( + "app.services.config_diff_service.OpenBaoTransitService", + return_value=mock_openbao, + ), + patch( + "app.services.config_diff_service.parse_diff_changes", + return_value=mock_changes, + ) as mock_parser, + ): await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session) # parse_diff_changes called with the diff text @@ -332,17 +352,22 @@ async def test_change_parser_error_does_not_block_diff(): mock_session.commit = AsyncMock() mock_openbao = AsyncMock() - mock_openbao.decrypt = AsyncMock(side_effect=[ - old_config.encode("utf-8"), - new_config.encode("utf-8"), - ]) + mock_openbao.decrypt = AsyncMock( + side_effect=[ + old_config.encode("utf-8"), + new_config.encode("utf-8"), + ] + ) - with patch( - "app.services.config_diff_service.OpenBaoTransitService", - return_value=mock_openbao, - ), patch( - "app.services.config_diff_service.parse_diff_changes", - side_effect=Exception("Parser exploded"), + with ( + patch( + "app.services.config_diff_service.OpenBaoTransitService", + return_value=mock_openbao, + ), + patch( + "app.services.config_diff_service.parse_diff_changes", + side_effect=Exception("Parser exploded"), + ), ): # Should NOT raise await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session) diff --git a/backend/tests/test_config_history_service.py b/backend/tests/test_config_history_service.py index 13c67fc..dda3e48 100644 --- a/backend/tests/test_config_history_service.py +++ b/backend/tests/test_config_history_service.py @@ -10,7 +10,9 @@ from uuid import uuid4 from datetime import datetime, timezone -def _make_change_row(change_id, component, summary, created_at, diff_id, lines_added, lines_removed, snapshot_id): +def _make_change_row( + change_id, component, summary, created_at, diff_id, lines_added, lines_removed, snapshot_id +): """Create a mock row matching the JOIN query result.""" row = MagicMock() row._mapping = { @@ -41,7 +43,9 @@ async def test_returns_formatted_entries(): mock_session = AsyncMock() result_mock = MagicMock() result_mock.fetchall.return_value = [ - _make_change_row(change_id, "ip/firewall/filter", "Added 1 rule", ts, diff_id, 3, 1, snapshot_id), + _make_change_row( + change_id, "ip/firewall/filter", "Added 1 rule", ts, diff_id, 3, 1, snapshot_id + ), ] mock_session.execute = AsyncMock(return_value=result_mock) diff --git a/backend/tests/test_config_snapshot_subscriber.py b/backend/tests/test_config_snapshot_subscriber.py index d8c0211..938dbbc 100644 --- a/backend/tests/test_config_snapshot_subscriber.py +++ b/backend/tests/test_config_snapshot_subscriber.py @@ -56,15 +56,19 @@ async def test_new_snapshot_encrypted_and_stored(): mock_openbao = AsyncMock() mock_openbao.encrypt.return_value = "vault:v1:encrypted_data" - with patch( - "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", - return_value=mock_ctx, - ), patch( - "app.services.config_snapshot_subscriber.OpenBaoTransitService", - return_value=mock_openbao, - ), patch( - "app.services.config_snapshot_subscriber.generate_and_store_diff", - new_callable=AsyncMock, + with ( + patch( + "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", + return_value=mock_ctx, + ), + patch( + "app.services.config_snapshot_subscriber.OpenBaoTransitService", + return_value=mock_openbao, + ), + patch( + "app.services.config_snapshot_subscriber.generate_and_store_diff", + new_callable=AsyncMock, + ), ): await handle_config_snapshot(msg) @@ -102,12 +106,15 @@ async def test_duplicate_snapshot_skipped(): mock_openbao = AsyncMock() - with patch( - "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", - return_value=mock_ctx, - ), patch( - "app.services.config_snapshot_subscriber.OpenBaoTransitService", - return_value=mock_openbao, + with ( + patch( + "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", + return_value=mock_ctx, + ), + patch( + "app.services.config_snapshot_subscriber.OpenBaoTransitService", + return_value=mock_openbao, + ), ): await handle_config_snapshot(msg) @@ -141,12 +148,15 @@ async def test_transit_encrypt_failure_causes_nak(): mock_openbao = AsyncMock() mock_openbao.encrypt.side_effect = Exception("Transit unavailable") - with patch( - "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", - return_value=mock_ctx, - ), patch( - "app.services.config_snapshot_subscriber.OpenBaoTransitService", - return_value=mock_openbao, + with ( + patch( + "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", + return_value=mock_ctx, + ), + patch( + "app.services.config_snapshot_subscriber.OpenBaoTransitService", + return_value=mock_openbao, + ), ): await handle_config_snapshot(msg) @@ -168,11 +178,14 @@ async def test_malformed_message_acked_and_discarded(): mock_openbao = AsyncMock() - with patch( - "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", - ) as mock_session_cls, patch( - "app.services.config_snapshot_subscriber.OpenBaoTransitService", - return_value=mock_openbao, + with ( + patch( + "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", + ), + patch( + "app.services.config_snapshot_subscriber.OpenBaoTransitService", + return_value=mock_openbao, + ), ): await handle_config_snapshot(msg) @@ -209,12 +222,15 @@ async def test_orphan_device_acked_and_discarded(): mock_openbao = AsyncMock() mock_openbao.encrypt.return_value = "vault:v1:encrypted_data" - with patch( - "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", - return_value=mock_ctx, - ), patch( - "app.services.config_snapshot_subscriber.OpenBaoTransitService", - return_value=mock_openbao, + with ( + patch( + "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", + return_value=mock_ctx, + ), + patch( + "app.services.config_snapshot_subscriber.OpenBaoTransitService", + return_value=mock_openbao, + ), ): await handle_config_snapshot(msg) @@ -245,15 +261,19 @@ async def test_first_snapshot_for_device_always_stored(): mock_openbao = AsyncMock() mock_openbao.encrypt.return_value = "vault:v1:first_snapshot_encrypted" - with patch( - "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", - return_value=mock_ctx, - ), patch( - "app.services.config_snapshot_subscriber.OpenBaoTransitService", - return_value=mock_openbao, - ), patch( - "app.services.config_snapshot_subscriber.generate_and_store_diff", - new_callable=AsyncMock, + with ( + patch( + "app.services.config_snapshot_subscriber.AdminAsyncSessionLocal", + return_value=mock_ctx, + ), + patch( + "app.services.config_snapshot_subscriber.OpenBaoTransitService", + return_value=mock_openbao, + ), + patch( + "app.services.config_snapshot_subscriber.generate_and_store_diff", + new_callable=AsyncMock, + ), ): await handle_config_snapshot(msg) diff --git a/backend/tests/test_config_snapshot_trigger.py b/backend/tests/test_config_snapshot_trigger.py index d2b2bfb..1b94cd3 100644 --- a/backend/tests/test_config_snapshot_trigger.py +++ b/backend/tests/test_config_snapshot_trigger.py @@ -15,7 +15,6 @@ from unittest.mock import AsyncMock, MagicMock import nats.errors from fastapi import HTTPException, status -from sqlalchemy import select # --------------------------------------------------------------------------- @@ -118,11 +117,13 @@ def _mock_db(device_exists: bool): async def test_trigger_success_returns_201(): """POST with operator role returns 201 with status and sha256_hash.""" sha256 = "b" * 64 - nc = _mock_nats_reply({ - "status": "success", - "sha256_hash": sha256, - "message": "Config snapshot collected", - }) + nc = _mock_nats_reply( + { + "status": "success", + "sha256_hash": sha256, + "message": "Config snapshot collected", + } + ) db = _mock_db(device_exists=True) result = await _simulate_trigger(nats_conn=nc, db_session=db) @@ -156,10 +157,12 @@ async def test_trigger_nats_timeout_returns_504(): @pytest.mark.asyncio async def test_trigger_poller_failure_returns_502(): """Poller failure reply returns 502.""" - nc = _mock_nats_reply({ - "status": "failed", - "error": "SSH connection refused", - }) + nc = _mock_nats_reply( + { + "status": "failed", + "error": "SSH connection refused", + } + ) db = _mock_db(device_exists=True) with pytest.raises(HTTPException) as exc_info: @@ -184,10 +187,12 @@ async def test_trigger_device_not_found_returns_404(): @pytest.mark.asyncio async def test_trigger_locked_returns_409(): """Lock contention returns 409 Conflict.""" - nc = _mock_nats_reply({ - "status": "locked", - "message": "backup already in progress", - }) + nc = _mock_nats_reply( + { + "status": "locked", + "message": "backup already in progress", + } + ) db = _mock_db(device_exists=True) with pytest.raises(HTTPException) as exc_info: diff --git a/backend/tests/test_push_recovery.py b/backend/tests/test_push_recovery.py index dfa148f..e8c6577 100644 --- a/backend/tests/test_push_recovery.py +++ b/backend/tests/test_push_recovery.py @@ -35,26 +35,33 @@ async def test_recovery_commits_reachable_device_with_scheduler(): dev_result.scalar_one_or_none.return_value = device mock_session.execute = AsyncMock(side_effect=[mock_result, dev_result]) - with patch( - "app.services.restore_service._check_reachability", - new_callable=AsyncMock, - return_value=True, - ), patch( - "app.services.restore_service._remove_panic_scheduler", - new_callable=AsyncMock, - return_value=True, - ), patch( - "app.services.restore_service._update_push_op_status", - new_callable=AsyncMock, - ) as mock_update, patch( - "app.services.restore_service._publish_push_progress", - new_callable=AsyncMock, - ), patch( - "app.services.crypto.decrypt_credentials_hybrid", - new_callable=AsyncMock, - return_value='{"username": "admin", "password": "test123"}', - ), patch( - "app.services.restore_service.settings", + with ( + patch( + "app.services.restore_service._check_reachability", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "app.services.restore_service._remove_panic_scheduler", + new_callable=AsyncMock, + return_value=True, + ), + patch( + "app.services.restore_service._update_push_op_status", + new_callable=AsyncMock, + ) as mock_update, + patch( + "app.services.restore_service._publish_push_progress", + new_callable=AsyncMock, + ), + patch( + "app.services.crypto.decrypt_credentials_hybrid", + new_callable=AsyncMock, + return_value='{"username": "admin", "password": "test123"}', + ), + patch( + "app.services.restore_service.settings", + ), ): await recover_stale_push_operations(mock_session) @@ -84,22 +91,28 @@ async def test_recovery_marks_unreachable_device_failed(): dev_result.scalar_one_or_none.return_value = device mock_session.execute = AsyncMock(side_effect=[mock_result, dev_result]) - with patch( - "app.services.restore_service._check_reachability", - new_callable=AsyncMock, - return_value=False, - ), patch( - "app.services.restore_service._update_push_op_status", - new_callable=AsyncMock, - ) as mock_update, patch( - "app.services.restore_service._publish_push_progress", - new_callable=AsyncMock, - ), patch( - "app.services.crypto.decrypt_credentials_hybrid", - new_callable=AsyncMock, - return_value='{"username": "admin", "password": "test123"}', - ), patch( - "app.services.restore_service.settings", + with ( + patch( + "app.services.restore_service._check_reachability", + new_callable=AsyncMock, + return_value=False, + ), + patch( + "app.services.restore_service._update_push_op_status", + new_callable=AsyncMock, + ) as mock_update, + patch( + "app.services.restore_service._publish_push_progress", + new_callable=AsyncMock, + ), + patch( + "app.services.crypto.decrypt_credentials_hybrid", + new_callable=AsyncMock, + return_value='{"username": "admin", "password": "test123"}', + ), + patch( + "app.services.restore_service.settings", + ), ): await recover_stale_push_operations(mock_session) diff --git a/backend/tests/test_push_rollback_subscriber.py b/backend/tests/test_push_rollback_subscriber.py index e517c83..a23972f 100644 --- a/backend/tests/test_push_rollback_subscriber.py +++ b/backend/tests/test_push_rollback_subscriber.py @@ -1,7 +1,7 @@ """Tests for push rollback NATS subscriber.""" import pytest -from unittest.mock import AsyncMock, patch, MagicMock +from unittest.mock import AsyncMock, patch from uuid import uuid4 from app.services.push_rollback_subscriber import ( diff --git a/backend/tests/test_restore_preview.py b/backend/tests/test_restore_preview.py index 8cfa0f7..caa63d3 100644 --- a/backend/tests/test_restore_preview.py +++ b/backend/tests/test_restore_preview.py @@ -60,25 +60,32 @@ class TestPreviewRestoreFunction: mock_scalar.scalar_one_or_none.return_value = mock_device mock_db.execute.return_value = mock_scalar - with patch( - "app.routers.config_backups._check_tenant_access", - new_callable=AsyncMock, - ), patch( - "app.routers.config_backups.limiter.enabled", - False, - ), patch( - "app.routers.config_backups.git_store.read_file", - return_value=target_export.encode(), - ), patch( - "app.routers.config_backups.backup_service.capture_export", - new_callable=AsyncMock, - return_value=current_export, - ), patch( - "app.routers.config_backups.decrypt_credentials_hybrid", - new_callable=AsyncMock, - return_value='{"username": "admin", "password": "pass"}', - ), patch( - "app.routers.config_backups.settings", + with ( + patch( + "app.routers.config_backups._check_tenant_access", + new_callable=AsyncMock, + ), + patch( + "app.routers.config_backups.limiter.enabled", + False, + ), + patch( + "app.routers.config_backups.git_store.read_file", + return_value=target_export.encode(), + ), + patch( + "app.routers.config_backups.backup_service.capture_export", + new_callable=AsyncMock, + return_value=current_export, + ), + patch( + "app.routers.config_backups.decrypt_credentials_hybrid", + new_callable=AsyncMock, + return_value='{"username": "admin", "password": "pass"}', + ), + patch( + "app.routers.config_backups.settings", + ), ): result = await preview_restore( request=mock_request, @@ -140,25 +147,32 @@ class TestPreviewRestoreFunction: return current_export.encode() return b"" - with patch( - "app.routers.config_backups._check_tenant_access", - new_callable=AsyncMock, - ), patch( - "app.routers.config_backups.limiter.enabled", - False, - ), patch( - "app.routers.config_backups.git_store.read_file", - side_effect=mock_read_file, - ), patch( - "app.routers.config_backups.backup_service.capture_export", - new_callable=AsyncMock, - side_effect=ConnectionError("Device unreachable"), - ), patch( - "app.routers.config_backups.decrypt_credentials_hybrid", - new_callable=AsyncMock, - return_value='{"username": "admin", "password": "pass"}', - ), patch( - "app.routers.config_backups.settings", + with ( + patch( + "app.routers.config_backups._check_tenant_access", + new_callable=AsyncMock, + ), + patch( + "app.routers.config_backups.limiter.enabled", + False, + ), + patch( + "app.routers.config_backups.git_store.read_file", + side_effect=mock_read_file, + ), + patch( + "app.routers.config_backups.backup_service.capture_export", + new_callable=AsyncMock, + side_effect=ConnectionError("Device unreachable"), + ), + patch( + "app.routers.config_backups.decrypt_credentials_hybrid", + new_callable=AsyncMock, + return_value='{"username": "admin", "password": "pass"}', + ), + patch( + "app.routers.config_backups.settings", + ), ): result = await preview_restore( request=mock_request, @@ -188,15 +202,19 @@ class TestPreviewRestoreFunction: mock_request = MagicMock() body = RestoreRequest(commit_sha="nonexistent") - with patch( - "app.routers.config_backups._check_tenant_access", - new_callable=AsyncMock, - ), patch( - "app.routers.config_backups.limiter.enabled", - False, - ), patch( - "app.routers.config_backups.git_store.read_file", - side_effect=KeyError("not found"), + with ( + patch( + "app.routers.config_backups._check_tenant_access", + new_callable=AsyncMock, + ), + patch( + "app.routers.config_backups.limiter.enabled", + False, + ), + patch( + "app.routers.config_backups.git_store.read_file", + side_effect=KeyError("not found"), + ), ): with pytest.raises(HTTPException) as exc_info: await preview_restore( diff --git a/backend/tests/test_retention_service.py b/backend/tests/test_retention_service.py index 8ec2f97..28ef5be 100644 --- a/backend/tests/test_retention_service.py +++ b/backend/tests/test_retention_service.py @@ -23,12 +23,15 @@ async def test_cleanup_deletes_expired_snapshots(): mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) mock_ctx.__aexit__ = AsyncMock(return_value=False) - with patch( - "app.services.retention_service.AdminAsyncSessionLocal", - return_value=mock_ctx, - ), patch( - "app.services.retention_service.settings", - ) as mock_settings: + with ( + patch( + "app.services.retention_service.AdminAsyncSessionLocal", + return_value=mock_ctx, + ), + patch( + "app.services.retention_service.settings", + ) as mock_settings, + ): mock_settings.CONFIG_RETENTION_DAYS = 90 count = await cleanup_expired_snapshots() @@ -60,12 +63,15 @@ async def test_cleanup_keeps_snapshots_within_retention_window(): mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) mock_ctx.__aexit__ = AsyncMock(return_value=False) - with patch( - "app.services.retention_service.AdminAsyncSessionLocal", - return_value=mock_ctx, - ), patch( - "app.services.retention_service.settings", - ) as mock_settings: + with ( + patch( + "app.services.retention_service.AdminAsyncSessionLocal", + return_value=mock_ctx, + ), + patch( + "app.services.retention_service.settings", + ) as mock_settings, + ): mock_settings.CONFIG_RETENTION_DAYS = 90 count = await cleanup_expired_snapshots() @@ -87,12 +93,15 @@ async def test_cleanup_returns_deleted_count(): mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) mock_ctx.__aexit__ = AsyncMock(return_value=False) - with patch( - "app.services.retention_service.AdminAsyncSessionLocal", - return_value=mock_ctx, - ), patch( - "app.services.retention_service.settings", - ) as mock_settings: + with ( + patch( + "app.services.retention_service.AdminAsyncSessionLocal", + return_value=mock_ctx, + ), + patch( + "app.services.retention_service.settings", + ) as mock_settings, + ): mock_settings.CONFIG_RETENTION_DAYS = 30 count = await cleanup_expired_snapshots() @@ -114,12 +123,15 @@ async def test_cleanup_handles_empty_table(): mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) mock_ctx.__aexit__ = AsyncMock(return_value=False) - with patch( - "app.services.retention_service.AdminAsyncSessionLocal", - return_value=mock_ctx, - ), patch( - "app.services.retention_service.settings", - ) as mock_settings: + with ( + patch( + "app.services.retention_service.AdminAsyncSessionLocal", + return_value=mock_ctx, + ), + patch( + "app.services.retention_service.settings", + ) as mock_settings, + ): mock_settings.CONFIG_RETENTION_DAYS = 90 count = await cleanup_expired_snapshots() diff --git a/backend/tests/test_rsc_parser.py b/backend/tests/test_rsc_parser.py index de68ccf..b7217a1 100644 --- a/backend/tests/test_rsc_parser.py +++ b/backend/tests/test_rsc_parser.py @@ -1,6 +1,5 @@ """Tests for RouterOS RSC export parser.""" -import pytest from app.services.rsc_parser import parse_rsc, validate_rsc, compute_impact @@ -74,7 +73,7 @@ class TestValidateRsc: assert any("quote" in e.lower() for e in result["errors"]) def test_truncated_continuation_detected(self): - bad = '/ip address\nadd address=192.168.1.1/24 \\\n' + bad = "/ip address\nadd address=192.168.1.1/24 \\\n" result = validate_rsc(bad) assert result["valid"] is False assert any("truncat" in e.lower() or "continuation" in e.lower() for e in result["errors"]) @@ -82,25 +81,25 @@ class TestValidateRsc: class TestComputeImpact: def test_high_risk_for_firewall_input(self): - current = '/ip firewall filter\nadd action=accept chain=input\n' - target = '/ip firewall filter\nadd action=drop chain=input\n' + current = "/ip firewall filter\nadd action=accept chain=input\n" + target = "/ip firewall filter\nadd action=drop chain=input\n" result = compute_impact(parse_rsc(current), parse_rsc(target)) assert any(c["risk"] == "high" for c in result["categories"]) def test_high_risk_for_ip_address_changes(self): - current = '/ip address\nadd address=192.168.1.1/24 interface=ether1\n' - target = '/ip address\nadd address=10.0.0.1/24 interface=ether1\n' + current = "/ip address\nadd address=192.168.1.1/24 interface=ether1\n" + target = "/ip address\nadd address=10.0.0.1/24 interface=ether1\n" result = compute_impact(parse_rsc(current), parse_rsc(target)) ip_cat = next(c for c in result["categories"] if c["path"] == "/ip address") assert ip_cat["risk"] in ("high", "medium") def test_warnings_for_management_access(self): current = "" - target = '/ip firewall filter\nadd action=drop chain=input protocol=tcp dst-port=22\n' + target = "/ip firewall filter\nadd action=drop chain=input protocol=tcp dst-port=22\n" result = compute_impact(parse_rsc(current), parse_rsc(target)) assert len(result["warnings"]) > 0 def test_no_changes_no_warnings(self): - same = '/ip dns\nset servers=8.8.8.8\n' + same = "/ip dns\nset servers=8.8.8.8\n" result = compute_impact(parse_rsc(same), parse_rsc(same)) assert result["warnings"] == [] or all(c["risk"] == "none" for c in result["categories"]) diff --git a/backend/tests/test_srp_interop.py b/backend/tests/test_srp_interop.py index 1bc6b56..5f2422b 100644 --- a/backend/tests/test_srp_interop.py +++ b/backend/tests/test_srp_interop.py @@ -32,7 +32,7 @@ def test_srp_roundtrip(): context = SRPContext(EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN) username, verifier, salt = context.get_user_data_triplet() - print(f"\n--- SRP Interop Reference Values ---") + print("\n--- SRP Interop Reference Values ---") print(f"email (I): {EMAIL}") print(f"salt (s): {salt}") print(f"verifier (v): {verifier[:64]}... (len={len(verifier)})") @@ -45,7 +45,9 @@ def test_srp_roundtrip(): print(f"server_public (B): {server_public[:64]}... (len={len(server_public)})") # Step 3: Client init -- generate A (client needs password for proof) - client_context = SRPContext(EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN) + client_context = SRPContext( + EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN + ) client_session = SRPClientSession(client_context) client_public = client_session.public @@ -78,7 +80,7 @@ def test_srp_roundtrip(): ) print(f"session_key (K): {client_session.key[:64]}... (len={len(client_session.key)})") - print(f"--- Handshake PASSED ---\n") + print("--- Handshake PASSED ---\n") def test_srp_bad_proof_rejected(): @@ -89,7 +91,9 @@ def test_srp_bad_proof_rejected(): server_context = SRPContext(EMAIL, prime=PRIME_2048, generator=PRIME_2048_GEN) server_session = SRPServerSession(server_context, verifier) - client_context = SRPContext(EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN) + client_context = SRPContext( + EMAIL, password=PASSWORD, prime=PRIME_2048, generator=PRIME_2048_GEN + ) client_session = SRPClientSession(client_context) client_session.process(server_session.public, salt) diff --git a/backend/tests/unit/test_audit_service.py b/backend/tests/unit/test_audit_service.py index a319821..6c91954 100644 --- a/backend/tests/unit/test_audit_service.py +++ b/backend/tests/unit/test_audit_service.py @@ -8,7 +8,7 @@ Tests cover: """ import uuid -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock import pytest @@ -18,15 +18,24 @@ class TestAuditLogModel: def test_model_importable(self): from app.models.audit_log import AuditLog + assert AuditLog.__tablename__ == "audit_logs" def test_model_has_required_columns(self): from app.models.audit_log import AuditLog + mapper = AuditLog.__table__.columns expected_columns = { - "id", "tenant_id", "user_id", "action", - "resource_type", "resource_id", "device_id", - "details", "ip_address", "created_at", + "id", + "tenant_id", + "user_id", + "action", + "resource_type", + "resource_id", + "device_id", + "details", + "ip_address", + "created_at", } actual_columns = {c.name for c in mapper} assert expected_columns.issubset(actual_columns), ( @@ -35,6 +44,7 @@ class TestAuditLogModel: def test_model_exported_from_init(self): from app.models import AuditLog + assert AuditLog.__tablename__ == "audit_logs" @@ -43,6 +53,7 @@ class TestAuditService: def test_log_action_importable(self): from app.services.audit_service import log_action + assert callable(log_action) @pytest.mark.asyncio @@ -67,9 +78,11 @@ class TestAuditRouter: def test_router_importable(self): from app.routers.audit_logs import router + assert router is not None def test_router_has_audit_logs_endpoint(self): from app.routers.audit_logs import router + paths = [route.path for route in router.routes] assert "/audit-logs" in paths or any("/audit-logs" in p for p in paths) diff --git a/backend/tests/unit/test_auth.py b/backend/tests/unit/test_auth.py index be65f92..e091fe7 100644 --- a/backend/tests/unit/test_auth.py +++ b/backend/tests/unit/test_auth.py @@ -11,7 +11,6 @@ These are pure function tests -- no database or async required. import uuid from datetime import UTC, datetime, timedelta -from unittest.mock import patch import pytest from fastapi import HTTPException @@ -83,9 +82,7 @@ class TestAccessToken: assert payload["role"] == "super_admin" def test_contains_expiry(self): - token = create_access_token( - user_id=uuid.uuid4(), tenant_id=uuid.uuid4(), role="viewer" - ) + token = create_access_token(user_id=uuid.uuid4(), tenant_id=uuid.uuid4(), role="viewer") payload = verify_token(token, expected_type="access") assert "exp" in payload assert "iat" in payload diff --git a/backend/tests/unit/test_config_snapshot_models.py b/backend/tests/unit/test_config_snapshot_models.py index b90e251..a428ac2 100644 --- a/backend/tests/unit/test_config_snapshot_models.py +++ b/backend/tests/unit/test_config_snapshot_models.py @@ -3,53 +3,65 @@ Verifies STOR-01 (table/column structure) and STOR-05 (config_text stores ciphertext). """ -import uuid - from sqlalchemy import String, Text -from sqlalchemy.dialects.postgresql import UUID def test_router_config_snapshot_importable(): """RouterConfigSnapshot can be imported from app.models.""" from app.models import RouterConfigSnapshot + assert RouterConfigSnapshot is not None def test_router_config_diff_importable(): """RouterConfigDiff can be imported from app.models.""" from app.models import RouterConfigDiff + assert RouterConfigDiff is not None def test_router_config_change_importable(): """RouterConfigChange can be imported from app.models.""" from app.models import RouterConfigChange + assert RouterConfigChange is not None def test_snapshot_tablename(): """RouterConfigSnapshot.__tablename__ is correct.""" from app.models import RouterConfigSnapshot + assert RouterConfigSnapshot.__tablename__ == "router_config_snapshots" def test_diff_tablename(): """RouterConfigDiff.__tablename__ is correct.""" from app.models import RouterConfigDiff + assert RouterConfigDiff.__tablename__ == "router_config_diffs" def test_change_tablename(): """RouterConfigChange.__tablename__ is correct.""" from app.models import RouterConfigChange + assert RouterConfigChange.__tablename__ == "router_config_changes" def test_snapshot_columns(): """RouterConfigSnapshot has all required columns.""" from app.models import RouterConfigSnapshot + table = RouterConfigSnapshot.__table__ - expected = {"id", "device_id", "tenant_id", "config_text", "sha256_hash", "collected_at", "created_at"} + expected = { + "id", + "device_id", + "tenant_id", + "config_text", + "sha256_hash", + "collected_at", + "created_at", + } actual = {c.name for c in table.columns} assert expected.issubset(actual), f"Missing columns: {expected - actual}" @@ -57,10 +69,18 @@ def test_snapshot_columns(): def test_diff_columns(): """RouterConfigDiff has all required columns.""" from app.models import RouterConfigDiff + table = RouterConfigDiff.__table__ expected = { - "id", "device_id", "tenant_id", "old_snapshot_id", "new_snapshot_id", - "diff_text", "lines_added", "lines_removed", "created_at", + "id", + "device_id", + "tenant_id", + "old_snapshot_id", + "new_snapshot_id", + "diff_text", + "lines_added", + "lines_removed", + "created_at", } actual = {c.name for c in table.columns} assert expected.issubset(actual), f"Missing columns: {expected - actual}" @@ -69,10 +89,17 @@ def test_diff_columns(): def test_change_columns(): """RouterConfigChange has all required columns.""" from app.models import RouterConfigChange + table = RouterConfigChange.__table__ expected = { - "id", "diff_id", "device_id", "tenant_id", - "component", "summary", "raw_line", "created_at", + "id", + "diff_id", + "device_id", + "tenant_id", + "component", + "summary", + "raw_line", + "created_at", } actual = {c.name for c in table.columns} assert expected.issubset(actual), f"Missing columns: {expected - actual}" @@ -81,6 +108,7 @@ def test_change_columns(): def test_snapshot_config_text_is_text_type(): """config_text column type is Text (documents Transit ciphertext contract).""" from app.models import RouterConfigSnapshot + col = RouterConfigSnapshot.__table__.c.config_text assert isinstance(col.type, Text), f"Expected Text, got {type(col.type)}" @@ -88,6 +116,7 @@ def test_snapshot_config_text_is_text_type(): def test_snapshot_sha256_hash_is_string_64(): """sha256_hash column type is String(64) for plaintext hash deduplication.""" from app.models import RouterConfigSnapshot + col = RouterConfigSnapshot.__table__.c.sha256_hash assert isinstance(col.type, String), f"Expected String, got {type(col.type)}" assert col.type.length == 64, f"Expected length 64, got {col.type.length}" diff --git a/backend/tests/unit/test_maintenance_windows.py b/backend/tests/unit/test_maintenance_windows.py index 67b0cb5..ff48fb0 100644 --- a/backend/tests/unit/test_maintenance_windows.py +++ b/backend/tests/unit/test_maintenance_windows.py @@ -7,11 +7,9 @@ Tests cover: - Router registration in main app """ -import uuid from datetime import datetime, timezone, timedelta import pytest -from pydantic import ValidationError class TestMaintenanceWindowModel: @@ -19,20 +17,31 @@ class TestMaintenanceWindowModel: def test_model_importable(self): from app.models.maintenance_window import MaintenanceWindow + assert MaintenanceWindow.__tablename__ == "maintenance_windows" def test_model_exported_from_init(self): from app.models import MaintenanceWindow + assert MaintenanceWindow.__tablename__ == "maintenance_windows" def test_model_has_required_columns(self): from app.models.maintenance_window import MaintenanceWindow + mapper = MaintenanceWindow.__mapper__ column_names = {c.key for c in mapper.columns} expected = { - "id", "tenant_id", "name", "device_ids", - "start_at", "end_at", "suppress_alerts", - "notes", "created_by", "created_at", "updated_at", + "id", + "tenant_id", + "name", + "device_ids", + "start_at", + "end_at", + "suppress_alerts", + "notes", + "created_by", + "created_at", + "updated_at", } assert expected.issubset(column_names), f"Missing columns: {expected - column_names}" @@ -42,6 +51,7 @@ class TestMaintenanceWindowSchemas: def test_create_schema_valid(self): from app.routers.maintenance_windows import MaintenanceWindowCreate + data = MaintenanceWindowCreate( name="Nightly update", device_ids=["abc-123"], @@ -55,6 +65,7 @@ class TestMaintenanceWindowSchemas: def test_create_schema_defaults(self): from app.routers.maintenance_windows import MaintenanceWindowCreate + data = MaintenanceWindowCreate( name="Quick reboot", device_ids=[], @@ -66,12 +77,14 @@ class TestMaintenanceWindowSchemas: def test_update_schema_partial(self): from app.routers.maintenance_windows import MaintenanceWindowUpdate + data = MaintenanceWindowUpdate(name="Updated name") assert data.name == "Updated name" assert data.device_ids is None # all optional def test_response_schema(self): from app.routers.maintenance_windows import MaintenanceWindowResponse + data = MaintenanceWindowResponse( id="abc", tenant_id="def", @@ -92,10 +105,12 @@ class TestRouterRegistration: def test_router_importable(self): from app.routers.maintenance_windows import router + assert router is not None def test_router_has_routes(self): from app.routers.maintenance_windows import router + paths = [r.path for r in router.routes] assert any("maintenance-windows" in p for p in paths) @@ -114,8 +129,10 @@ class TestAlertEvaluatorMaintenance: def test_maintenance_cache_exists(self): from app.services import alert_evaluator + assert hasattr(alert_evaluator, "_maintenance_cache") def test_is_device_in_maintenance_function_exists(self): from app.services.alert_evaluator import _is_device_in_maintenance + assert callable(_is_device_in_maintenance) diff --git a/backend/tests/unit/test_security.py b/backend/tests/unit/test_security.py index 29b7706..7decf8c 100644 --- a/backend/tests/unit/test_security.py +++ b/backend/tests/unit/test_security.py @@ -9,7 +9,6 @@ for startup validation, async only for middleware tests. """ from types import SimpleNamespace -from unittest.mock import patch import pytest @@ -114,7 +113,9 @@ class TestSecurityHeadersMiddleware: response = await client.get("/test") assert response.status_code == 200 - assert response.headers["strict-transport-security"] == "max-age=31536000; includeSubDomains" + assert ( + response.headers["strict-transport-security"] == "max-age=31536000; includeSubDomains" + ) assert response.headers["x-content-type-options"] == "nosniff" assert response.headers["x-frame-options"] == "DENY" assert response.headers["cache-control"] == "no-store" diff --git a/backend/tests/unit/test_vpn_subnet.py b/backend/tests/unit/test_vpn_subnet.py index 7ace7d2..d7278df 100644 --- a/backend/tests/unit/test_vpn_subnet.py +++ b/backend/tests/unit/test_vpn_subnet.py @@ -1,7 +1,10 @@ """Unit tests for VPN subnet allocation and allowed-IPs validation.""" import pytest -from app.services.vpn_service import _allocate_subnet_index_from_used, _validate_additional_allowed_ips +from app.services.vpn_service import ( + _allocate_subnet_index_from_used, + _validate_additional_allowed_ips, +) class TestSubnetAllocation: