fix(lint): resolve all ruff lint errors

Add ruff config to exclude alembic E402, SQLAlchemy F821, and pre-existing
E501 line-length issues. Auto-fix 69 unused imports and 2 f-strings without
placeholders. Manually fix 8 unused variables. Apply ruff format to 127 files.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Jason Staack
2026-03-14 22:17:50 -05:00
parent 2ad0367c91
commit 06a41ca9bf
133 changed files with 2927 additions and 1890 deletions

View File

@@ -220,7 +220,8 @@ def upgrade() -> None:
# Super admin sees all; tenant users see only their tenant # Super admin sees all; tenant users see only their tenant
conn.execute(sa.text("ALTER TABLE tenants ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE tenants ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text("ALTER TABLE tenants FORCE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE tenants FORCE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON tenants CREATE POLICY tenant_isolation ON tenants
USING ( USING (
id::text = current_setting('app.current_tenant', true) id::text = current_setting('app.current_tenant', true)
@@ -230,13 +231,15 @@ def upgrade() -> None:
id::text = current_setting('app.current_tenant', true) id::text = current_setting('app.current_tenant', true)
OR current_setting('app.current_tenant', true) = 'super_admin' OR current_setting('app.current_tenant', true) = 'super_admin'
) )
""")) """)
)
# --- USERS RLS --- # --- USERS RLS ---
# Users see only other users in their tenant; super_admin sees all # Users see only other users in their tenant; super_admin sees all
conn.execute(sa.text("ALTER TABLE users ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE users ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text("ALTER TABLE users FORCE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE users FORCE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON users CREATE POLICY tenant_isolation ON users
USING ( USING (
tenant_id::text = current_setting('app.current_tenant', true) tenant_id::text = current_setting('app.current_tenant', true)
@@ -246,41 +249,49 @@ def upgrade() -> None:
tenant_id::text = current_setting('app.current_tenant', true) tenant_id::text = current_setting('app.current_tenant', true)
OR current_setting('app.current_tenant', true) = 'super_admin' OR current_setting('app.current_tenant', true) = 'super_admin'
) )
""")) """)
)
# --- DEVICES RLS --- # --- DEVICES RLS ---
conn.execute(sa.text("ALTER TABLE devices ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE devices ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text("ALTER TABLE devices FORCE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE devices FORCE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON devices CREATE POLICY tenant_isolation ON devices
USING (tenant_id::text = current_setting('app.current_tenant', true)) USING (tenant_id::text = current_setting('app.current_tenant', true))
WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true)) WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true))
""")) """)
)
# --- DEVICE GROUPS RLS --- # --- DEVICE GROUPS RLS ---
conn.execute(sa.text("ALTER TABLE device_groups ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE device_groups ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text("ALTER TABLE device_groups FORCE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE device_groups FORCE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON device_groups CREATE POLICY tenant_isolation ON device_groups
USING (tenant_id::text = current_setting('app.current_tenant', true)) USING (tenant_id::text = current_setting('app.current_tenant', true))
WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true)) WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true))
""")) """)
)
# --- DEVICE TAGS RLS --- # --- DEVICE TAGS RLS ---
conn.execute(sa.text("ALTER TABLE device_tags ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE device_tags ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text("ALTER TABLE device_tags FORCE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE device_tags FORCE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON device_tags CREATE POLICY tenant_isolation ON device_tags
USING (tenant_id::text = current_setting('app.current_tenant', true)) USING (tenant_id::text = current_setting('app.current_tenant', true))
WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true)) WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true))
""")) """)
)
# --- DEVICE GROUP MEMBERSHIPS RLS --- # --- DEVICE GROUP MEMBERSHIPS RLS ---
# These are filtered by joining through devices/groups (which already have RLS) # These are filtered by joining through devices/groups (which already have RLS)
# But we also add direct RLS via a join to the devices table # But we also add direct RLS via a join to the devices table
conn.execute(sa.text("ALTER TABLE device_group_memberships ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE device_group_memberships ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text("ALTER TABLE device_group_memberships FORCE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE device_group_memberships FORCE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON device_group_memberships CREATE POLICY tenant_isolation ON device_group_memberships
USING ( USING (
EXISTS ( EXISTS (
@@ -296,12 +307,14 @@ def upgrade() -> None:
AND d.tenant_id::text = current_setting('app.current_tenant', true) AND d.tenant_id::text = current_setting('app.current_tenant', true)
) )
) )
""")) """)
)
# --- DEVICE TAG ASSIGNMENTS RLS --- # --- DEVICE TAG ASSIGNMENTS RLS ---
conn.execute(sa.text("ALTER TABLE device_tag_assignments ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE device_tag_assignments ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text("ALTER TABLE device_tag_assignments FORCE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE device_tag_assignments FORCE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON device_tag_assignments CREATE POLICY tenant_isolation ON device_tag_assignments
USING ( USING (
EXISTS ( EXISTS (
@@ -317,7 +330,8 @@ def upgrade() -> None:
AND d.tenant_id::text = current_setting('app.current_tenant', true) AND d.tenant_id::text = current_setting('app.current_tenant', true)
) )
) )
""")) """)
)
# ========================================================================= # =========================================================================
# GRANT PERMISSIONS TO app_user (RLS-enforcing application role) # GRANT PERMISSIONS TO app_user (RLS-enforcing application role)
@@ -336,9 +350,7 @@ def upgrade() -> None:
] ]
for table in tables: for table in tables:
conn.execute(sa.text( conn.execute(sa.text(f"GRANT SELECT, INSERT, UPDATE, DELETE ON {table} TO app_user"))
f"GRANT SELECT, INSERT, UPDATE, DELETE ON {table} TO app_user"
))
# Grant sequence usage for UUID generation (gen_random_uuid is built-in, but just in case) # Grant sequence usage for UUID generation (gen_random_uuid is built-in, but just in case)
conn.execute(sa.text("GRANT USAGE ON SCHEMA public TO app_user")) conn.execute(sa.text("GRANT USAGE ON SCHEMA public TO app_user"))

View File

@@ -46,7 +46,8 @@ def upgrade() -> None:
# to read all devices across all tenants, which is required for polling. # to read all devices across all tenants, which is required for polling.
conn = op.get_bind() conn = op.get_bind()
conn.execute(sa.text(""" conn.execute(
sa.text("""
DO $$ DO $$
BEGIN BEGIN
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'poller_user') THEN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'poller_user') THEN
@@ -54,7 +55,8 @@ def upgrade() -> None:
END IF; END IF;
END END
$$ $$
""")) """)
)
conn.execute(sa.text("GRANT CONNECT ON DATABASE tod TO poller_user")) conn.execute(sa.text("GRANT CONNECT ON DATABASE tod TO poller_user"))
conn.execute(sa.text("GRANT USAGE ON SCHEMA public TO poller_user")) conn.execute(sa.text("GRANT USAGE ON SCHEMA public TO poller_user"))

View File

@@ -34,7 +34,8 @@ def upgrade() -> None:
# Stores per-interface byte counters from /interface/print on every poll cycle. # Stores per-interface byte counters from /interface/print on every poll cycle.
# rx_bps/tx_bps are stored as NULL — computed at query time via LAG() window # rx_bps/tx_bps are stored as NULL — computed at query time via LAG() window
# function to avoid delta state in the poller. # function to avoid delta state in the poller.
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS interface_metrics ( CREATE TABLE IF NOT EXISTS interface_metrics (
time TIMESTAMPTZ NOT NULL, time TIMESTAMPTZ NOT NULL,
device_id UUID NOT NULL, device_id UUID NOT NULL,
@@ -45,23 +46,28 @@ def upgrade() -> None:
rx_bps BIGINT, rx_bps BIGINT,
tx_bps BIGINT tx_bps BIGINT
) )
""")) """)
)
conn.execute(sa.text( conn.execute(
"SELECT create_hypertable('interface_metrics', 'time', if_not_exists => TRUE)" sa.text("SELECT create_hypertable('interface_metrics', 'time', if_not_exists => TRUE)")
)) )
conn.execute(sa.text( conn.execute(
"CREATE INDEX IF NOT EXISTS idx_interface_metrics_device_time " sa.text(
"ON interface_metrics (device_id, time DESC)" "CREATE INDEX IF NOT EXISTS idx_interface_metrics_device_time "
)) "ON interface_metrics (device_id, time DESC)"
)
)
conn.execute(sa.text("ALTER TABLE interface_metrics ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE interface_metrics ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON interface_metrics CREATE POLICY tenant_isolation ON interface_metrics
USING (tenant_id::text = current_setting('app.current_tenant')) USING (tenant_id::text = current_setting('app.current_tenant'))
""")) """)
)
conn.execute(sa.text("GRANT SELECT, INSERT ON interface_metrics TO app_user")) conn.execute(sa.text("GRANT SELECT, INSERT ON interface_metrics TO app_user"))
conn.execute(sa.text("GRANT SELECT, INSERT ON interface_metrics TO poller_user")) conn.execute(sa.text("GRANT SELECT, INSERT ON interface_metrics TO poller_user"))
@@ -72,7 +78,8 @@ def upgrade() -> None:
# Stores per-device system health metrics from /system/resource/print and # Stores per-device system health metrics from /system/resource/print and
# /system/health/print on every poll cycle. # /system/health/print on every poll cycle.
# temperature is nullable — not all RouterOS devices have temperature sensors. # temperature is nullable — not all RouterOS devices have temperature sensors.
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS health_metrics ( CREATE TABLE IF NOT EXISTS health_metrics (
time TIMESTAMPTZ NOT NULL, time TIMESTAMPTZ NOT NULL,
device_id UUID NOT NULL, device_id UUID NOT NULL,
@@ -84,23 +91,28 @@ def upgrade() -> None:
total_disk BIGINT, total_disk BIGINT,
temperature SMALLINT temperature SMALLINT
) )
""")) """)
)
conn.execute(sa.text( conn.execute(
"SELECT create_hypertable('health_metrics', 'time', if_not_exists => TRUE)" sa.text("SELECT create_hypertable('health_metrics', 'time', if_not_exists => TRUE)")
)) )
conn.execute(sa.text( conn.execute(
"CREATE INDEX IF NOT EXISTS idx_health_metrics_device_time " sa.text(
"ON health_metrics (device_id, time DESC)" "CREATE INDEX IF NOT EXISTS idx_health_metrics_device_time "
)) "ON health_metrics (device_id, time DESC)"
)
)
conn.execute(sa.text("ALTER TABLE health_metrics ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE health_metrics ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON health_metrics CREATE POLICY tenant_isolation ON health_metrics
USING (tenant_id::text = current_setting('app.current_tenant')) USING (tenant_id::text = current_setting('app.current_tenant'))
""")) """)
)
conn.execute(sa.text("GRANT SELECT, INSERT ON health_metrics TO app_user")) conn.execute(sa.text("GRANT SELECT, INSERT ON health_metrics TO app_user"))
conn.execute(sa.text("GRANT SELECT, INSERT ON health_metrics TO poller_user")) conn.execute(sa.text("GRANT SELECT, INSERT ON health_metrics TO poller_user"))
@@ -113,7 +125,8 @@ def upgrade() -> None:
# /interface/wifi/registration-table/print (v7). # /interface/wifi/registration-table/print (v7).
# ccq may be 0 on RouterOS v7 (not available in the WiFi API path). # ccq may be 0 on RouterOS v7 (not available in the WiFi API path).
# avg_signal is dBm (negative integer, e.g. -67). # avg_signal is dBm (negative integer, e.g. -67).
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS wireless_metrics ( CREATE TABLE IF NOT EXISTS wireless_metrics (
time TIMESTAMPTZ NOT NULL, time TIMESTAMPTZ NOT NULL,
device_id UUID NOT NULL, device_id UUID NOT NULL,
@@ -124,23 +137,28 @@ def upgrade() -> None:
ccq SMALLINT, ccq SMALLINT,
frequency INTEGER frequency INTEGER
) )
""")) """)
)
conn.execute(sa.text( conn.execute(
"SELECT create_hypertable('wireless_metrics', 'time', if_not_exists => TRUE)" sa.text("SELECT create_hypertable('wireless_metrics', 'time', if_not_exists => TRUE)")
)) )
conn.execute(sa.text( conn.execute(
"CREATE INDEX IF NOT EXISTS idx_wireless_metrics_device_time " sa.text(
"ON wireless_metrics (device_id, time DESC)" "CREATE INDEX IF NOT EXISTS idx_wireless_metrics_device_time "
)) "ON wireless_metrics (device_id, time DESC)"
)
)
conn.execute(sa.text("ALTER TABLE wireless_metrics ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE wireless_metrics ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON wireless_metrics CREATE POLICY tenant_isolation ON wireless_metrics
USING (tenant_id::text = current_setting('app.current_tenant')) USING (tenant_id::text = current_setting('app.current_tenant'))
""")) """)
)
conn.execute(sa.text("GRANT SELECT, INSERT ON wireless_metrics TO app_user")) conn.execute(sa.text("GRANT SELECT, INSERT ON wireless_metrics TO app_user"))
conn.execute(sa.text("GRANT SELECT, INSERT ON wireless_metrics TO poller_user")) conn.execute(sa.text("GRANT SELECT, INSERT ON wireless_metrics TO poller_user"))

View File

@@ -32,7 +32,8 @@ def upgrade() -> None:
# Stores metadata for each backup run. The actual config content lives in # Stores metadata for each backup run. The actual config content lives in
# the tenant's bare git repository (GIT_STORE_PATH). This table provides # the tenant's bare git repository (GIT_STORE_PATH). This table provides
# the timeline view and change tracking without duplicating file content. # the timeline view and change tracking without duplicating file content.
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS config_backup_runs ( CREATE TABLE IF NOT EXISTS config_backup_runs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
device_id UUID NOT NULL REFERENCES devices(id) ON DELETE CASCADE, device_id UUID NOT NULL REFERENCES devices(id) ON DELETE CASCADE,
@@ -43,19 +44,24 @@ def upgrade() -> None:
lines_removed INT, lines_removed INT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
) )
""")) """)
)
conn.execute(sa.text( conn.execute(
"CREATE INDEX IF NOT EXISTS idx_config_backup_runs_device_created " sa.text(
"ON config_backup_runs (device_id, created_at DESC)" "CREATE INDEX IF NOT EXISTS idx_config_backup_runs_device_created "
)) "ON config_backup_runs (device_id, created_at DESC)"
)
)
conn.execute(sa.text("ALTER TABLE config_backup_runs ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE config_backup_runs ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON config_backup_runs CREATE POLICY tenant_isolation ON config_backup_runs
USING (tenant_id::text = current_setting('app.current_tenant')) USING (tenant_id::text = current_setting('app.current_tenant'))
""")) """)
)
conn.execute(sa.text("GRANT SELECT, INSERT ON config_backup_runs TO app_user")) conn.execute(sa.text("GRANT SELECT, INSERT ON config_backup_runs TO app_user"))
conn.execute(sa.text("GRANT SELECT ON config_backup_runs TO poller_user")) conn.execute(sa.text("GRANT SELECT ON config_backup_runs TO poller_user"))
@@ -68,7 +74,8 @@ def upgrade() -> None:
# A per-device row with a specific device_id overrides the tenant default. # A per-device row with a specific device_id overrides the tenant default.
# UNIQUE(tenant_id, device_id) allows one entry per (tenant, device) pair # UNIQUE(tenant_id, device_id) allows one entry per (tenant, device) pair
# where device_id NULL is the tenant-level default. # where device_id NULL is the tenant-level default.
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS config_backup_schedules ( CREATE TABLE IF NOT EXISTS config_backup_schedules (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE,
@@ -78,14 +85,17 @@ def upgrade() -> None:
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(tenant_id, device_id) UNIQUE(tenant_id, device_id)
) )
""")) """)
)
conn.execute(sa.text("ALTER TABLE config_backup_schedules ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE config_backup_schedules ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON config_backup_schedules CREATE POLICY tenant_isolation ON config_backup_schedules
USING (tenant_id::text = current_setting('app.current_tenant')) USING (tenant_id::text = current_setting('app.current_tenant'))
""")) """)
)
conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE ON config_backup_schedules TO app_user")) conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE ON config_backup_schedules TO app_user"))
@@ -97,7 +107,8 @@ def upgrade() -> None:
# startup handler checks for 'pending_verification' rows and either verifies # startup handler checks for 'pending_verification' rows and either verifies
# connectivity (clean up the RouterOS scheduler job) or marks as failed. # connectivity (clean up the RouterOS scheduler job) or marks as failed.
# See Pitfall 6 in 04-RESEARCH.md. # See Pitfall 6 in 04-RESEARCH.md.
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS config_push_operations ( CREATE TABLE IF NOT EXISTS config_push_operations (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
device_id UUID NOT NULL REFERENCES devices(id) ON DELETE CASCADE, device_id UUID NOT NULL REFERENCES devices(id) ON DELETE CASCADE,
@@ -108,14 +119,17 @@ def upgrade() -> None:
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
completed_at TIMESTAMPTZ completed_at TIMESTAMPTZ
) )
""")) """)
)
conn.execute(sa.text("ALTER TABLE config_push_operations ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE config_push_operations ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON config_push_operations CREATE POLICY tenant_isolation ON config_push_operations
USING (tenant_id::text = current_setting('app.current_tenant')) USING (tenant_id::text = current_setting('app.current_tenant'))
""")) """)
)
conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE ON config_push_operations TO app_user")) conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE ON config_push_operations TO app_user"))

View File

@@ -31,24 +31,27 @@ def upgrade() -> None:
# ========================================================================= # =========================================================================
# ALTER devices TABLE — add architecture and preferred_channel columns # ALTER devices TABLE — add architecture and preferred_channel columns
# ========================================================================= # =========================================================================
conn.execute(sa.text( conn.execute(sa.text("ALTER TABLE devices ADD COLUMN IF NOT EXISTS architecture TEXT"))
"ALTER TABLE devices ADD COLUMN IF NOT EXISTS architecture TEXT" conn.execute(
)) sa.text(
conn.execute(sa.text( "ALTER TABLE devices ADD COLUMN IF NOT EXISTS preferred_channel TEXT DEFAULT 'stable' NOT NULL"
"ALTER TABLE devices ADD COLUMN IF NOT EXISTS preferred_channel TEXT DEFAULT 'stable' NOT NULL" )
)) )
# ========================================================================= # =========================================================================
# ALTER device_groups TABLE — add preferred_channel column # ALTER device_groups TABLE — add preferred_channel column
# ========================================================================= # =========================================================================
conn.execute(sa.text( conn.execute(
"ALTER TABLE device_groups ADD COLUMN IF NOT EXISTS preferred_channel TEXT DEFAULT 'stable' NOT NULL" sa.text(
)) "ALTER TABLE device_groups ADD COLUMN IF NOT EXISTS preferred_channel TEXT DEFAULT 'stable' NOT NULL"
)
)
# ========================================================================= # =========================================================================
# CREATE alert_rules TABLE # CREATE alert_rules TABLE
# ========================================================================= # =========================================================================
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS alert_rules ( CREATE TABLE IF NOT EXISTS alert_rules (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE,
@@ -64,27 +67,31 @@ def upgrade() -> None:
is_default BOOLEAN NOT NULL DEFAULT FALSE, is_default BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
) )
""")) """)
)
conn.execute(sa.text( conn.execute(
"CREATE INDEX IF NOT EXISTS idx_alert_rules_tenant_enabled " sa.text(
"ON alert_rules (tenant_id, enabled)" "CREATE INDEX IF NOT EXISTS idx_alert_rules_tenant_enabled "
)) "ON alert_rules (tenant_id, enabled)"
)
)
conn.execute(sa.text("ALTER TABLE alert_rules ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE alert_rules ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON alert_rules CREATE POLICY tenant_isolation ON alert_rules
USING (tenant_id::text = current_setting('app.current_tenant')) USING (tenant_id::text = current_setting('app.current_tenant'))
""")) """)
conn.execute(sa.text( )
"GRANT SELECT, INSERT, UPDATE, DELETE ON alert_rules TO app_user" conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE, DELETE ON alert_rules TO app_user"))
))
conn.execute(sa.text("GRANT ALL ON alert_rules TO poller_user")) conn.execute(sa.text("GRANT ALL ON alert_rules TO poller_user"))
# ========================================================================= # =========================================================================
# CREATE notification_channels TABLE # CREATE notification_channels TABLE
# ========================================================================= # =========================================================================
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS notification_channels ( CREATE TABLE IF NOT EXISTS notification_channels (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE,
@@ -100,52 +107,60 @@ def upgrade() -> None:
webhook_url TEXT, webhook_url TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
) )
""")) """)
)
conn.execute(sa.text( conn.execute(
"CREATE INDEX IF NOT EXISTS idx_notification_channels_tenant " sa.text(
"ON notification_channels (tenant_id)" "CREATE INDEX IF NOT EXISTS idx_notification_channels_tenant "
)) "ON notification_channels (tenant_id)"
)
)
conn.execute(sa.text("ALTER TABLE notification_channels ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE notification_channels ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON notification_channels CREATE POLICY tenant_isolation ON notification_channels
USING (tenant_id::text = current_setting('app.current_tenant')) USING (tenant_id::text = current_setting('app.current_tenant'))
""")) """)
conn.execute(sa.text( )
"GRANT SELECT, INSERT, UPDATE, DELETE ON notification_channels TO app_user" conn.execute(
)) sa.text("GRANT SELECT, INSERT, UPDATE, DELETE ON notification_channels TO app_user")
)
conn.execute(sa.text("GRANT ALL ON notification_channels TO poller_user")) conn.execute(sa.text("GRANT ALL ON notification_channels TO poller_user"))
# ========================================================================= # =========================================================================
# CREATE alert_rule_channels TABLE (M2M association) # CREATE alert_rule_channels TABLE (M2M association)
# ========================================================================= # =========================================================================
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS alert_rule_channels ( CREATE TABLE IF NOT EXISTS alert_rule_channels (
rule_id UUID NOT NULL REFERENCES alert_rules(id) ON DELETE CASCADE, rule_id UUID NOT NULL REFERENCES alert_rules(id) ON DELETE CASCADE,
channel_id UUID NOT NULL REFERENCES notification_channels(id) ON DELETE CASCADE, channel_id UUID NOT NULL REFERENCES notification_channels(id) ON DELETE CASCADE,
PRIMARY KEY (rule_id, channel_id) PRIMARY KEY (rule_id, channel_id)
) )
""")) """)
)
conn.execute(sa.text("ALTER TABLE alert_rule_channels ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE alert_rule_channels ENABLE ROW LEVEL SECURITY"))
# RLS for M2M: join through parent table's tenant_id via rule_id # RLS for M2M: join through parent table's tenant_id via rule_id
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON alert_rule_channels CREATE POLICY tenant_isolation ON alert_rule_channels
USING (rule_id IN ( USING (rule_id IN (
SELECT id FROM alert_rules SELECT id FROM alert_rules
WHERE tenant_id::text = current_setting('app.current_tenant') WHERE tenant_id::text = current_setting('app.current_tenant')
)) ))
""")) """)
conn.execute(sa.text( )
"GRANT SELECT, INSERT, UPDATE, DELETE ON alert_rule_channels TO app_user" conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE, DELETE ON alert_rule_channels TO app_user"))
))
conn.execute(sa.text("GRANT ALL ON alert_rule_channels TO poller_user")) conn.execute(sa.text("GRANT ALL ON alert_rule_channels TO poller_user"))
# ========================================================================= # =========================================================================
# CREATE alert_events TABLE # CREATE alert_events TABLE
# ========================================================================= # =========================================================================
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS alert_events ( CREATE TABLE IF NOT EXISTS alert_events (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
rule_id UUID REFERENCES alert_rules(id) ON DELETE SET NULL, rule_id UUID REFERENCES alert_rules(id) ON DELETE SET NULL,
@@ -164,31 +179,37 @@ def upgrade() -> None:
fired_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), fired_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
resolved_at TIMESTAMPTZ resolved_at TIMESTAMPTZ
) )
""")) """)
)
conn.execute(sa.text( conn.execute(
"CREATE INDEX IF NOT EXISTS idx_alert_events_device_rule_status " sa.text(
"ON alert_events (device_id, rule_id, status)" "CREATE INDEX IF NOT EXISTS idx_alert_events_device_rule_status "
)) "ON alert_events (device_id, rule_id, status)"
conn.execute(sa.text( )
"CREATE INDEX IF NOT EXISTS idx_alert_events_tenant_fired " )
"ON alert_events (tenant_id, fired_at)" conn.execute(
)) sa.text(
"CREATE INDEX IF NOT EXISTS idx_alert_events_tenant_fired "
"ON alert_events (tenant_id, fired_at)"
)
)
conn.execute(sa.text("ALTER TABLE alert_events ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE alert_events ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON alert_events CREATE POLICY tenant_isolation ON alert_events
USING (tenant_id::text = current_setting('app.current_tenant')) USING (tenant_id::text = current_setting('app.current_tenant'))
""")) """)
conn.execute(sa.text( )
"GRANT SELECT, INSERT, UPDATE, DELETE ON alert_events TO app_user" conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE, DELETE ON alert_events TO app_user"))
))
conn.execute(sa.text("GRANT ALL ON alert_events TO poller_user")) conn.execute(sa.text("GRANT ALL ON alert_events TO poller_user"))
# ========================================================================= # =========================================================================
# CREATE firmware_versions TABLE (global — NOT tenant-scoped) # CREATE firmware_versions TABLE (global — NOT tenant-scoped)
# ========================================================================= # =========================================================================
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS firmware_versions ( CREATE TABLE IF NOT EXISTS firmware_versions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
architecture TEXT NOT NULL, architecture TEXT NOT NULL,
@@ -200,23 +221,25 @@ def upgrade() -> None:
checked_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), checked_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(architecture, channel, version) UNIQUE(architecture, channel, version)
) )
""")) """)
)
conn.execute(sa.text( conn.execute(
"CREATE INDEX IF NOT EXISTS idx_firmware_versions_arch_channel " sa.text(
"ON firmware_versions (architecture, channel)" "CREATE INDEX IF NOT EXISTS idx_firmware_versions_arch_channel "
)) "ON firmware_versions (architecture, channel)"
)
)
# No RLS on firmware_versions — global cache table # No RLS on firmware_versions — global cache table
conn.execute(sa.text( conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE ON firmware_versions TO app_user"))
"GRANT SELECT, INSERT, UPDATE ON firmware_versions TO app_user"
))
conn.execute(sa.text("GRANT ALL ON firmware_versions TO poller_user")) conn.execute(sa.text("GRANT ALL ON firmware_versions TO poller_user"))
# ========================================================================= # =========================================================================
# CREATE firmware_upgrade_jobs TABLE # CREATE firmware_upgrade_jobs TABLE
# ========================================================================= # =========================================================================
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS firmware_upgrade_jobs ( CREATE TABLE IF NOT EXISTS firmware_upgrade_jobs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE,
@@ -234,16 +257,19 @@ def upgrade() -> None:
confirmed_major_upgrade BOOLEAN NOT NULL DEFAULT FALSE, confirmed_major_upgrade BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
) )
""")) """)
)
conn.execute(sa.text("ALTER TABLE firmware_upgrade_jobs ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE firmware_upgrade_jobs ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON firmware_upgrade_jobs CREATE POLICY tenant_isolation ON firmware_upgrade_jobs
USING (tenant_id::text = current_setting('app.current_tenant')) USING (tenant_id::text = current_setting('app.current_tenant'))
""")) """)
conn.execute(sa.text( )
"GRANT SELECT, INSERT, UPDATE, DELETE ON firmware_upgrade_jobs TO app_user" conn.execute(
)) sa.text("GRANT SELECT, INSERT, UPDATE, DELETE ON firmware_upgrade_jobs TO app_user")
)
conn.execute(sa.text("GRANT ALL ON firmware_upgrade_jobs TO poller_user")) conn.execute(sa.text("GRANT ALL ON firmware_upgrade_jobs TO poller_user"))
# ========================================================================= # =========================================================================
@@ -252,21 +278,27 @@ def upgrade() -> None:
# Note: New tenant creation (in the tenants API router) should also seed # Note: New tenant creation (in the tenants API router) should also seed
# these three default rules. A _seed_default_alert_rules(tenant_id) helper # these three default rules. A _seed_default_alert_rules(tenant_id) helper
# should be created in the alerts router or a shared service for this. # should be created in the alerts router or a shared service for this.
conn.execute(sa.text(""" conn.execute(
sa.text("""
INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default) INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default)
SELECT gen_random_uuid(), t.id, 'High CPU Usage', 'cpu_load', 'gt', 90, 5, 'warning', TRUE, TRUE SELECT gen_random_uuid(), t.id, 'High CPU Usage', 'cpu_load', 'gt', 90, 5, 'warning', TRUE, TRUE
FROM tenants t FROM tenants t
""")) """)
conn.execute(sa.text(""" )
conn.execute(
sa.text("""
INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default) INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default)
SELECT gen_random_uuid(), t.id, 'High Memory Usage', 'memory_used_pct', 'gt', 90, 5, 'warning', TRUE, TRUE SELECT gen_random_uuid(), t.id, 'High Memory Usage', 'memory_used_pct', 'gt', 90, 5, 'warning', TRUE, TRUE
FROM tenants t FROM tenants t
""")) """)
conn.execute(sa.text(""" )
conn.execute(
sa.text("""
INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default) INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default)
SELECT gen_random_uuid(), t.id, 'High Disk Usage', 'disk_used_pct', 'gt', 85, 3, 'warning', TRUE, TRUE SELECT gen_random_uuid(), t.id, 'High Disk Usage', 'disk_used_pct', 'gt', 85, 3, 'warning', TRUE, TRUE
FROM tenants t FROM tenants t
""")) """)
)
def downgrade() -> None: def downgrade() -> None:

View File

@@ -31,17 +31,14 @@ def upgrade() -> None:
# ========================================================================= # =========================================================================
# ALTER devices TABLE — add latitude and longitude columns # ALTER devices TABLE — add latitude and longitude columns
# ========================================================================= # =========================================================================
conn.execute(sa.text( conn.execute(sa.text("ALTER TABLE devices ADD COLUMN IF NOT EXISTS latitude DOUBLE PRECISION"))
"ALTER TABLE devices ADD COLUMN IF NOT EXISTS latitude DOUBLE PRECISION" conn.execute(sa.text("ALTER TABLE devices ADD COLUMN IF NOT EXISTS longitude DOUBLE PRECISION"))
))
conn.execute(sa.text(
"ALTER TABLE devices ADD COLUMN IF NOT EXISTS longitude DOUBLE PRECISION"
))
# ========================================================================= # =========================================================================
# CREATE config_templates TABLE # CREATE config_templates TABLE
# ========================================================================= # =========================================================================
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS config_templates ( CREATE TABLE IF NOT EXISTS config_templates (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE,
@@ -53,12 +50,14 @@ def upgrade() -> None:
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
UNIQUE(tenant_id, name) UNIQUE(tenant_id, name)
) )
""")) """)
)
# ========================================================================= # =========================================================================
# CREATE config_template_tags TABLE # CREATE config_template_tags TABLE
# ========================================================================= # =========================================================================
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS config_template_tags ( CREATE TABLE IF NOT EXISTS config_template_tags (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE,
@@ -66,12 +65,14 @@ def upgrade() -> None:
template_id UUID NOT NULL REFERENCES config_templates(id) ON DELETE CASCADE, template_id UUID NOT NULL REFERENCES config_templates(id) ON DELETE CASCADE,
UNIQUE(template_id, name) UNIQUE(template_id, name)
) )
""")) """)
)
# ========================================================================= # =========================================================================
# CREATE template_push_jobs TABLE # CREATE template_push_jobs TABLE
# ========================================================================= # =========================================================================
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS template_push_jobs ( CREATE TABLE IF NOT EXISTS template_push_jobs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE,
@@ -86,48 +87,57 @@ def upgrade() -> None:
completed_at TIMESTAMPTZ, completed_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT now() created_at TIMESTAMPTZ NOT NULL DEFAULT now()
) )
""")) """)
)
# ========================================================================= # =========================================================================
# RLS POLICIES # RLS POLICIES
# ========================================================================= # =========================================================================
for table in ("config_templates", "config_template_tags", "template_push_jobs"): for table in ("config_templates", "config_template_tags", "template_push_jobs"):
conn.execute(sa.text(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text(f""" conn.execute(
sa.text(f"""
CREATE POLICY {table}_tenant_isolation ON {table} CREATE POLICY {table}_tenant_isolation ON {table}
USING (tenant_id = current_setting('app.current_tenant')::uuid) USING (tenant_id = current_setting('app.current_tenant')::uuid)
""")) """)
conn.execute(sa.text( )
f"GRANT SELECT, INSERT, UPDATE, DELETE ON {table} TO app_user" conn.execute(sa.text(f"GRANT SELECT, INSERT, UPDATE, DELETE ON {table} TO app_user"))
))
conn.execute(sa.text(f"GRANT ALL ON {table} TO poller_user")) conn.execute(sa.text(f"GRANT ALL ON {table} TO poller_user"))
# ========================================================================= # =========================================================================
# INDEXES # INDEXES
# ========================================================================= # =========================================================================
conn.execute(sa.text( conn.execute(
"CREATE INDEX IF NOT EXISTS idx_config_templates_tenant " sa.text(
"ON config_templates (tenant_id)" "CREATE INDEX IF NOT EXISTS idx_config_templates_tenant ON config_templates (tenant_id)"
)) )
conn.execute(sa.text( )
"CREATE INDEX IF NOT EXISTS idx_config_template_tags_template " conn.execute(
"ON config_template_tags (template_id)" sa.text(
)) "CREATE INDEX IF NOT EXISTS idx_config_template_tags_template "
conn.execute(sa.text( "ON config_template_tags (template_id)"
"CREATE INDEX IF NOT EXISTS idx_template_push_jobs_tenant_rollout " )
"ON template_push_jobs (tenant_id, rollout_id)" )
)) conn.execute(
conn.execute(sa.text( sa.text(
"CREATE INDEX IF NOT EXISTS idx_template_push_jobs_device_status " "CREATE INDEX IF NOT EXISTS idx_template_push_jobs_tenant_rollout "
"ON template_push_jobs (device_id, status)" "ON template_push_jobs (tenant_id, rollout_id)"
)) )
)
conn.execute(
sa.text(
"CREATE INDEX IF NOT EXISTS idx_template_push_jobs_device_status "
"ON template_push_jobs (device_id, status)"
)
)
# ========================================================================= # =========================================================================
# SEED STARTER TEMPLATES for all existing tenants # SEED STARTER TEMPLATES for all existing tenants
# ========================================================================= # =========================================================================
# 1. Basic Firewall # 1. Basic Firewall
conn.execute(sa.text(""" conn.execute(
sa.text("""
INSERT INTO config_templates (id, tenant_id, name, description, content, variables) INSERT INTO config_templates (id, tenant_id, name, description, content, variables)
SELECT SELECT
gen_random_uuid(), gen_random_uuid(),
@@ -146,10 +156,12 @@ add chain=forward action=drop',
'[{"name":"wan_interface","type":"string","default":"ether1","description":"WAN-facing interface"},{"name":"allowed_network","type":"subnet","default":"192.168.1.0/24","description":"Allowed source network"}]'::jsonb '[{"name":"wan_interface","type":"string","default":"ether1","description":"WAN-facing interface"},{"name":"allowed_network","type":"subnet","default":"192.168.1.0/24","description":"Allowed source network"}]'::jsonb
FROM tenants t FROM tenants t
ON CONFLICT DO NOTHING ON CONFLICT DO NOTHING
""")) """)
)
# 2. DHCP Server Setup # 2. DHCP Server Setup
conn.execute(sa.text(""" conn.execute(
sa.text("""
INSERT INTO config_templates (id, tenant_id, name, description, content, variables) INSERT INTO config_templates (id, tenant_id, name, description, content, variables)
SELECT SELECT
gen_random_uuid(), gen_random_uuid(),
@@ -162,10 +174,12 @@ add chain=forward action=drop',
'[{"name":"pool_start","type":"ip","default":"192.168.1.100","description":"DHCP pool start address"},{"name":"pool_end","type":"ip","default":"192.168.1.254","description":"DHCP pool end address"},{"name":"gateway","type":"ip","default":"192.168.1.1","description":"Default gateway"},{"name":"dns_server","type":"ip","default":"8.8.8.8","description":"DNS server address"},{"name":"interface","type":"string","default":"bridge1","description":"Interface to serve DHCP on"}]'::jsonb '[{"name":"pool_start","type":"ip","default":"192.168.1.100","description":"DHCP pool start address"},{"name":"pool_end","type":"ip","default":"192.168.1.254","description":"DHCP pool end address"},{"name":"gateway","type":"ip","default":"192.168.1.1","description":"Default gateway"},{"name":"dns_server","type":"ip","default":"8.8.8.8","description":"DNS server address"},{"name":"interface","type":"string","default":"bridge1","description":"Interface to serve DHCP on"}]'::jsonb
FROM tenants t FROM tenants t
ON CONFLICT DO NOTHING ON CONFLICT DO NOTHING
""")) """)
)
# 3. Wireless AP Config # 3. Wireless AP Config
conn.execute(sa.text(""" conn.execute(
sa.text("""
INSERT INTO config_templates (id, tenant_id, name, description, content, variables) INSERT INTO config_templates (id, tenant_id, name, description, content, variables)
SELECT SELECT
gen_random_uuid(), gen_random_uuid(),
@@ -177,10 +191,12 @@ add chain=forward action=drop',
'[{"name":"ssid","type":"string","default":"MikroTik-AP","description":"Wireless network name"},{"name":"password","type":"string","default":"","description":"WPA2 pre-shared key (min 8 characters)"},{"name":"frequency","type":"integer","default":"2412","description":"Wireless frequency in MHz"},{"name":"channel_width","type":"string","default":"20/40mhz-XX","description":"Channel width setting"}]'::jsonb '[{"name":"ssid","type":"string","default":"MikroTik-AP","description":"Wireless network name"},{"name":"password","type":"string","default":"","description":"WPA2 pre-shared key (min 8 characters)"},{"name":"frequency","type":"integer","default":"2412","description":"Wireless frequency in MHz"},{"name":"channel_width","type":"string","default":"20/40mhz-XX","description":"Channel width setting"}]'::jsonb
FROM tenants t FROM tenants t
ON CONFLICT DO NOTHING ON CONFLICT DO NOTHING
""")) """)
)
# 4. Initial Device Setup # 4. Initial Device Setup
conn.execute(sa.text(""" conn.execute(
sa.text("""
INSERT INTO config_templates (id, tenant_id, name, description, content, variables) INSERT INTO config_templates (id, tenant_id, name, description, content, variables)
SELECT SELECT
gen_random_uuid(), gen_random_uuid(),
@@ -196,7 +212,8 @@ add chain=forward action=drop',
'[{"name":"ntp_server","type":"ip","default":"pool.ntp.org","description":"NTP server address"},{"name":"dns_servers","type":"string","default":"8.8.8.8,8.8.4.4","description":"Comma-separated DNS servers"}]'::jsonb '[{"name":"ntp_server","type":"ip","default":"pool.ntp.org","description":"NTP server address"},{"name":"dns_servers","type":"string","default":"8.8.8.8,8.8.4.4","description":"Comma-separated DNS servers"}]'::jsonb
FROM tenants t FROM tenants t
ON CONFLICT DO NOTHING ON CONFLICT DO NOTHING
""")) """)
)
def downgrade() -> None: def downgrade() -> None:

View File

@@ -29,7 +29,8 @@ def upgrade() -> None:
# ========================================================================= # =========================================================================
# CREATE audit_logs TABLE # CREATE audit_logs TABLE
# ========================================================================= # =========================================================================
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS audit_logs ( CREATE TABLE IF NOT EXISTS audit_logs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE,
@@ -42,39 +43,40 @@ def upgrade() -> None:
ip_address VARCHAR(45), ip_address VARCHAR(45),
created_at TIMESTAMPTZ NOT NULL DEFAULT now() created_at TIMESTAMPTZ NOT NULL DEFAULT now()
) )
""")) """)
)
# ========================================================================= # =========================================================================
# RLS POLICY # RLS POLICY
# ========================================================================= # =========================================================================
conn.execute(sa.text( conn.execute(sa.text("ALTER TABLE audit_logs ENABLE ROW LEVEL SECURITY"))
"ALTER TABLE audit_logs ENABLE ROW LEVEL SECURITY" conn.execute(
)) sa.text("""
conn.execute(sa.text("""
CREATE POLICY audit_logs_tenant_isolation ON audit_logs CREATE POLICY audit_logs_tenant_isolation ON audit_logs
USING (tenant_id = current_setting('app.current_tenant')::uuid) USING (tenant_id = current_setting('app.current_tenant')::uuid)
""")) """)
)
# Grant SELECT + INSERT to app_user (no UPDATE/DELETE -- audit logs are immutable) # Grant SELECT + INSERT to app_user (no UPDATE/DELETE -- audit logs are immutable)
conn.execute(sa.text( conn.execute(sa.text("GRANT SELECT, INSERT ON audit_logs TO app_user"))
"GRANT SELECT, INSERT ON audit_logs TO app_user"
))
# Poller user gets full access for cross-tenant audit logging # Poller user gets full access for cross-tenant audit logging
conn.execute(sa.text( conn.execute(sa.text("GRANT ALL ON audit_logs TO poller_user"))
"GRANT ALL ON audit_logs TO poller_user"
))
# ========================================================================= # =========================================================================
# INDEXES # INDEXES
# ========================================================================= # =========================================================================
conn.execute(sa.text( conn.execute(
"CREATE INDEX IF NOT EXISTS idx_audit_logs_tenant_created " sa.text(
"ON audit_logs (tenant_id, created_at DESC)" "CREATE INDEX IF NOT EXISTS idx_audit_logs_tenant_created "
)) "ON audit_logs (tenant_id, created_at DESC)"
conn.execute(sa.text( )
"CREATE INDEX IF NOT EXISTS idx_audit_logs_tenant_action " )
"ON audit_logs (tenant_id, action)" conn.execute(
)) sa.text(
"CREATE INDEX IF NOT EXISTS idx_audit_logs_tenant_action "
"ON audit_logs (tenant_id, action)"
)
)
def downgrade() -> None: def downgrade() -> None:

View File

@@ -28,7 +28,8 @@ def upgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
# ── 1. Create maintenance_windows table ──────────────────────────────── # ── 1. Create maintenance_windows table ────────────────────────────────
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS maintenance_windows ( CREATE TABLE IF NOT EXISTS maintenance_windows (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE,
@@ -44,18 +45,22 @@ def upgrade() -> None:
CONSTRAINT chk_maintenance_window_dates CHECK (end_at > start_at) CONSTRAINT chk_maintenance_window_dates CHECK (end_at > start_at)
) )
""")) """)
)
# ── 2. Composite index for active window queries ─────────────────────── # ── 2. Composite index for active window queries ───────────────────────
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE INDEX IF NOT EXISTS idx_maintenance_windows_tenant_time CREATE INDEX IF NOT EXISTS idx_maintenance_windows_tenant_time
ON maintenance_windows (tenant_id, start_at, end_at) ON maintenance_windows (tenant_id, start_at, end_at)
""")) """)
)
# ── 3. RLS policy ───────────────────────────────────────────────────── # ── 3. RLS policy ─────────────────────────────────────────────────────
conn.execute(sa.text("ALTER TABLE maintenance_windows ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE maintenance_windows ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
DO $$ DO $$
BEGIN BEGIN
IF NOT EXISTS ( IF NOT EXISTS (
@@ -67,10 +72,12 @@ def upgrade() -> None:
END IF; END IF;
END END
$$ $$
""")) """)
)
# ── 4. Grant permissions to app_user ─────────────────────────────────── # ── 4. Grant permissions to app_user ───────────────────────────────────
conn.execute(sa.text(""" conn.execute(
sa.text("""
DO $$ DO $$
BEGIN BEGIN
IF EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'app_user') THEN IF EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'app_user') THEN
@@ -78,7 +85,8 @@ def upgrade() -> None:
END IF; END IF;
END END
$$ $$
""")) """)
)
def downgrade() -> None: def downgrade() -> None:

View File

@@ -28,34 +28,81 @@ def upgrade() -> None:
# ── vpn_config: one row per tenant ── # ── vpn_config: one row per tenant ──
op.create_table( op.create_table(
"vpn_config", "vpn_config",
sa.Column("id", UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), primary_key=True), sa.Column(
sa.Column("tenant_id", UUID(as_uuid=True), sa.ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False, unique=True), "id", UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), primary_key=True
),
sa.Column(
"tenant_id",
UUID(as_uuid=True),
sa.ForeignKey("tenants.id", ondelete="CASCADE"),
nullable=False,
unique=True,
),
sa.Column("server_private_key", sa.LargeBinary(), nullable=False), # AES-256-GCM encrypted sa.Column("server_private_key", sa.LargeBinary(), nullable=False), # AES-256-GCM encrypted
sa.Column("server_public_key", sa.String(64), nullable=False), sa.Column("server_public_key", sa.String(64), nullable=False),
sa.Column("subnet", sa.String(32), nullable=False, server_default="10.10.0.0/24"), sa.Column("subnet", sa.String(32), nullable=False, server_default="10.10.0.0/24"),
sa.Column("server_port", sa.Integer(), nullable=False, server_default="51820"), sa.Column("server_port", sa.Integer(), nullable=False, server_default="51820"),
sa.Column("server_address", sa.String(32), nullable=False, server_default="10.10.0.1/24"), sa.Column("server_address", sa.String(32), nullable=False, server_default="10.10.0.1/24"),
sa.Column("endpoint", sa.String(255), nullable=True), # public hostname:port for devices to connect to sa.Column(
"endpoint", sa.String(255), nullable=True
), # public hostname:port for devices to connect to
sa.Column("is_enabled", sa.Boolean(), nullable=False, server_default="false"), sa.Column("is_enabled", sa.Boolean(), nullable=False, server_default="false"),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), sa.Column(
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), "created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
) )
# ── vpn_peers: one per device VPN connection ── # ── vpn_peers: one per device VPN connection ──
op.create_table( op.create_table(
"vpn_peers", "vpn_peers",
sa.Column("id", UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), primary_key=True), sa.Column(
sa.Column("tenant_id", UUID(as_uuid=True), sa.ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False), "id", UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), primary_key=True
sa.Column("device_id", UUID(as_uuid=True), sa.ForeignKey("devices.id", ondelete="CASCADE"), nullable=False, unique=True), ),
sa.Column(
"tenant_id",
UUID(as_uuid=True),
sa.ForeignKey("tenants.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"device_id",
UUID(as_uuid=True),
sa.ForeignKey("devices.id", ondelete="CASCADE"),
nullable=False,
unique=True,
),
sa.Column("peer_private_key", sa.LargeBinary(), nullable=False), # AES-256-GCM encrypted sa.Column("peer_private_key", sa.LargeBinary(), nullable=False), # AES-256-GCM encrypted
sa.Column("peer_public_key", sa.String(64), nullable=False), sa.Column("peer_public_key", sa.String(64), nullable=False),
sa.Column("preshared_key", sa.LargeBinary(), nullable=True), # AES-256-GCM encrypted, optional sa.Column(
"preshared_key", sa.LargeBinary(), nullable=True
), # AES-256-GCM encrypted, optional
sa.Column("assigned_ip", sa.String(32), nullable=False), # e.g. 10.10.0.2/24 sa.Column("assigned_ip", sa.String(32), nullable=False), # e.g. 10.10.0.2/24
sa.Column("additional_allowed_ips", sa.String(512), nullable=True), # comma-separated subnets for site-to-site sa.Column(
"additional_allowed_ips", sa.String(512), nullable=True
), # comma-separated subnets for site-to-site
sa.Column("is_enabled", sa.Boolean(), nullable=False, server_default="true"), sa.Column("is_enabled", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("last_handshake", sa.DateTime(timezone=True), nullable=True), sa.Column("last_handshake", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), sa.Column(
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), "created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
) )
# Indexes # Indexes

View File

@@ -22,7 +22,8 @@ def upgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
# 1. Basic Router — comprehensive starter for a typical SOHO/branch router # 1. Basic Router — comprehensive starter for a typical SOHO/branch router
conn.execute(sa.text(""" conn.execute(
sa.text("""
INSERT INTO config_templates (id, tenant_id, name, description, content, variables) INSERT INTO config_templates (id, tenant_id, name, description, content, variables)
SELECT SELECT
gen_random_uuid(), gen_random_uuid(),
@@ -75,10 +76,12 @@ add chain=forward action=drop comment="Drop everything else"
SELECT 1 FROM config_templates ct SELECT 1 FROM config_templates ct
WHERE ct.tenant_id = t.id AND ct.name = 'Basic Router' WHERE ct.tenant_id = t.id AND ct.name = 'Basic Router'
) )
""")) """)
)
# 2. Re-seed Basic Firewall (for tenants missing it) # 2. Re-seed Basic Firewall (for tenants missing it)
conn.execute(sa.text(""" conn.execute(
sa.text("""
INSERT INTO config_templates (id, tenant_id, name, description, content, variables) INSERT INTO config_templates (id, tenant_id, name, description, content, variables)
SELECT SELECT
gen_random_uuid(), gen_random_uuid(),
@@ -100,10 +103,12 @@ add chain=forward action=drop',
SELECT 1 FROM config_templates ct SELECT 1 FROM config_templates ct
WHERE ct.tenant_id = t.id AND ct.name = 'Basic Firewall' WHERE ct.tenant_id = t.id AND ct.name = 'Basic Firewall'
) )
""")) """)
)
# 3. Re-seed DHCP Server Setup # 3. Re-seed DHCP Server Setup
conn.execute(sa.text(""" conn.execute(
sa.text("""
INSERT INTO config_templates (id, tenant_id, name, description, content, variables) INSERT INTO config_templates (id, tenant_id, name, description, content, variables)
SELECT SELECT
gen_random_uuid(), gen_random_uuid(),
@@ -119,10 +124,12 @@ add chain=forward action=drop',
SELECT 1 FROM config_templates ct SELECT 1 FROM config_templates ct
WHERE ct.tenant_id = t.id AND ct.name = 'DHCP Server Setup' WHERE ct.tenant_id = t.id AND ct.name = 'DHCP Server Setup'
) )
""")) """)
)
# 4. Re-seed Wireless AP Config # 4. Re-seed Wireless AP Config
conn.execute(sa.text(""" conn.execute(
sa.text("""
INSERT INTO config_templates (id, tenant_id, name, description, content, variables) INSERT INTO config_templates (id, tenant_id, name, description, content, variables)
SELECT SELECT
gen_random_uuid(), gen_random_uuid(),
@@ -137,10 +144,12 @@ add chain=forward action=drop',
SELECT 1 FROM config_templates ct SELECT 1 FROM config_templates ct
WHERE ct.tenant_id = t.id AND ct.name = 'Wireless AP Config' WHERE ct.tenant_id = t.id AND ct.name = 'Wireless AP Config'
) )
""")) """)
)
# 5. Re-seed Initial Device Setup # 5. Re-seed Initial Device Setup
conn.execute(sa.text(""" conn.execute(
sa.text("""
INSERT INTO config_templates (id, tenant_id, name, description, content, variables) INSERT INTO config_templates (id, tenant_id, name, description, content, variables)
SELECT SELECT
gen_random_uuid(), gen_random_uuid(),
@@ -159,11 +168,10 @@ add chain=forward action=drop',
SELECT 1 FROM config_templates ct SELECT 1 FROM config_templates ct
WHERE ct.tenant_id = t.id AND ct.name = 'Initial Device Setup' WHERE ct.tenant_id = t.id AND ct.name = 'Initial Device Setup'
) )
""")) """)
)
def downgrade() -> None: def downgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
conn.execute(sa.text( conn.execute(sa.text("DELETE FROM config_templates WHERE name = 'Basic Router'"))
"DELETE FROM config_templates WHERE name = 'Basic Router'"
))

View File

@@ -138,62 +138,44 @@ def upgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
# certificate_authorities RLS # certificate_authorities RLS
conn.execute(sa.text( conn.execute(sa.text("ALTER TABLE certificate_authorities ENABLE ROW LEVEL SECURITY"))
"ALTER TABLE certificate_authorities ENABLE ROW LEVEL SECURITY" conn.execute(
)) sa.text("GRANT SELECT, INSERT, UPDATE, DELETE ON certificate_authorities TO app_user")
conn.execute(sa.text( )
"GRANT SELECT, INSERT, UPDATE, DELETE ON certificate_authorities TO app_user" conn.execute(
)) sa.text(
conn.execute(sa.text( "CREATE POLICY tenant_isolation ON certificate_authorities FOR ALL "
"CREATE POLICY tenant_isolation ON certificate_authorities FOR ALL " "USING (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid) "
"USING (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid) " "WITH CHECK (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid)"
"WITH CHECK (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid)" )
)) )
conn.execute(sa.text( conn.execute(sa.text("GRANT SELECT ON certificate_authorities TO poller_user"))
"GRANT SELECT ON certificate_authorities TO poller_user"
))
# device_certificates RLS # device_certificates RLS
conn.execute(sa.text( conn.execute(sa.text("ALTER TABLE device_certificates ENABLE ROW LEVEL SECURITY"))
"ALTER TABLE device_certificates ENABLE ROW LEVEL SECURITY" conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE, DELETE ON device_certificates TO app_user"))
)) conn.execute(
conn.execute(sa.text( sa.text(
"GRANT SELECT, INSERT, UPDATE, DELETE ON device_certificates TO app_user" "CREATE POLICY tenant_isolation ON device_certificates FOR ALL "
)) "USING (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid) "
conn.execute(sa.text( "WITH CHECK (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid)"
"CREATE POLICY tenant_isolation ON device_certificates FOR ALL " )
"USING (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid) " )
"WITH CHECK (tenant_id = NULLIF(current_setting('app.current_tenant', true), '')::uuid)" conn.execute(sa.text("GRANT SELECT ON device_certificates TO poller_user"))
))
conn.execute(sa.text(
"GRANT SELECT ON device_certificates TO poller_user"
))
def downgrade() -> None: def downgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
# Drop RLS policies # Drop RLS policies
conn.execute(sa.text( conn.execute(sa.text("DROP POLICY IF EXISTS tenant_isolation ON device_certificates"))
"DROP POLICY IF EXISTS tenant_isolation ON device_certificates" conn.execute(sa.text("DROP POLICY IF EXISTS tenant_isolation ON certificate_authorities"))
))
conn.execute(sa.text(
"DROP POLICY IF EXISTS tenant_isolation ON certificate_authorities"
))
# Revoke grants # Revoke grants
conn.execute(sa.text( conn.execute(sa.text("REVOKE ALL ON device_certificates FROM app_user"))
"REVOKE ALL ON device_certificates FROM app_user" conn.execute(sa.text("REVOKE ALL ON device_certificates FROM poller_user"))
)) conn.execute(sa.text("REVOKE ALL ON certificate_authorities FROM app_user"))
conn.execute(sa.text( conn.execute(sa.text("REVOKE ALL ON certificate_authorities FROM poller_user"))
"REVOKE ALL ON device_certificates FROM poller_user"
))
conn.execute(sa.text(
"REVOKE ALL ON certificate_authorities FROM app_user"
))
conn.execute(sa.text(
"REVOKE ALL ON certificate_authorities FROM poller_user"
))
# Drop tls_mode column from devices # Drop tls_mode column from devices
op.drop_column("devices", "tls_mode") op.drop_column("devices", "tls_mode")

View File

@@ -35,9 +35,7 @@ def upgrade() -> None:
for table in HYPERTABLES: for table in HYPERTABLES:
# Drop chunks older than 90 days # Drop chunks older than 90 days
conn.execute(sa.text( conn.execute(sa.text(f"SELECT add_retention_policy('{table}', INTERVAL '90 days')"))
f"SELECT add_retention_policy('{table}', INTERVAL '90 days')"
))
def downgrade() -> None: def downgrade() -> None:
@@ -45,6 +43,4 @@ def downgrade() -> None:
for table in HYPERTABLES: for table in HYPERTABLES:
# Remove retention policy # Remove retention policy
conn.execute(sa.text( conn.execute(sa.text(f"SELECT remove_retention_policy('{table}', if_exists => true)"))
f"SELECT remove_retention_policy('{table}', if_exists => true)"
))

View File

@@ -147,46 +147,36 @@ def upgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
# user_key_sets RLS # user_key_sets RLS
conn.execute(sa.text( conn.execute(sa.text("ALTER TABLE user_key_sets ENABLE ROW LEVEL SECURITY"))
"ALTER TABLE user_key_sets ENABLE ROW LEVEL SECURITY" conn.execute(
)) sa.text(
conn.execute(sa.text( "CREATE POLICY user_key_sets_tenant_isolation ON user_key_sets "
"CREATE POLICY user_key_sets_tenant_isolation ON user_key_sets " "USING (tenant_id::text = current_setting('app.current_tenant', true) "
"USING (tenant_id::text = current_setting('app.current_tenant', true) " "OR current_setting('app.current_tenant', true) = 'super_admin')"
"OR current_setting('app.current_tenant', true) = 'super_admin')" )
)) )
conn.execute(sa.text( conn.execute(sa.text("GRANT SELECT, INSERT, UPDATE ON user_key_sets TO app_user"))
"GRANT SELECT, INSERT, UPDATE ON user_key_sets TO app_user"
))
# key_access_log RLS (append-only: INSERT+SELECT only, no UPDATE/DELETE) # key_access_log RLS (append-only: INSERT+SELECT only, no UPDATE/DELETE)
conn.execute(sa.text( conn.execute(sa.text("ALTER TABLE key_access_log ENABLE ROW LEVEL SECURITY"))
"ALTER TABLE key_access_log ENABLE ROW LEVEL SECURITY" conn.execute(
)) sa.text(
conn.execute(sa.text( "CREATE POLICY key_access_log_tenant_isolation ON key_access_log "
"CREATE POLICY key_access_log_tenant_isolation ON key_access_log " "USING (tenant_id::text = current_setting('app.current_tenant', true) "
"USING (tenant_id::text = current_setting('app.current_tenant', true) " "OR current_setting('app.current_tenant', true) = 'super_admin')"
"OR current_setting('app.current_tenant', true) = 'super_admin')" )
)) )
conn.execute(sa.text( conn.execute(sa.text("GRANT INSERT, SELECT ON key_access_log TO app_user"))
"GRANT INSERT, SELECT ON key_access_log TO app_user"
))
# poller_user needs INSERT to log key access events when decrypting credentials # poller_user needs INSERT to log key access events when decrypting credentials
conn.execute(sa.text( conn.execute(sa.text("GRANT INSERT, SELECT ON key_access_log TO poller_user"))
"GRANT INSERT, SELECT ON key_access_log TO poller_user"
))
def downgrade() -> None: def downgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
# Drop RLS policies # Drop RLS policies
conn.execute(sa.text( conn.execute(sa.text("DROP POLICY IF EXISTS key_access_log_tenant_isolation ON key_access_log"))
"DROP POLICY IF EXISTS key_access_log_tenant_isolation ON key_access_log" conn.execute(sa.text("DROP POLICY IF EXISTS user_key_sets_tenant_isolation ON user_key_sets"))
))
conn.execute(sa.text(
"DROP POLICY IF EXISTS user_key_sets_tenant_isolation ON user_key_sets"
))
# Revoke grants # Revoke grants
conn.execute(sa.text("REVOKE ALL ON key_access_log FROM app_user")) conn.execute(sa.text("REVOKE ALL ON key_access_log FROM app_user"))

View File

@@ -77,9 +77,7 @@ def upgrade() -> None:
def downgrade() -> None: def downgrade() -> None:
op.drop_constraint( op.drop_constraint("fk_key_access_log_device_id", "key_access_log", type_="foreignkey")
"fk_key_access_log_device_id", "key_access_log", type_="foreignkey"
)
op.drop_column("key_access_log", "correlation_id") op.drop_column("key_access_log", "correlation_id")
op.drop_column("key_access_log", "justification") op.drop_column("key_access_log", "justification")
op.drop_column("key_access_log", "device_id") op.drop_column("key_access_log", "device_id")

View File

@@ -33,8 +33,7 @@ def upgrade() -> None:
# Flag all bcrypt-only users for upgrade (auth_version=1 and no SRP verifier) # Flag all bcrypt-only users for upgrade (auth_version=1 and no SRP verifier)
op.execute( op.execute(
"UPDATE users SET must_upgrade_auth = true " "UPDATE users SET must_upgrade_auth = true WHERE auth_version = 1 AND srp_verifier IS NULL"
"WHERE auth_version = 1 AND srp_verifier IS NULL"
) )
# Make hashed_password nullable (SRP users don't need it) # Make hashed_password nullable (SRP users don't need it)
@@ -44,8 +43,7 @@ def upgrade() -> None:
def downgrade() -> None: def downgrade() -> None:
# Restore NOT NULL (set a dummy value for any NULLs first) # Restore NOT NULL (set a dummy value for any NULLs first)
op.execute( op.execute(
"UPDATE users SET hashed_password = '$2b$12$placeholder' " "UPDATE users SET hashed_password = '$2b$12$placeholder' WHERE hashed_password IS NULL"
"WHERE hashed_password IS NULL"
) )
op.alter_column("users", "hashed_password", nullable=False) op.alter_column("users", "hashed_password", nullable=False)

View File

@@ -14,7 +14,6 @@ Existing 'insecure' devices become 'auto' since the old behavior was
an implicit auto-fallback. portal_ca devices keep their mode. an implicit auto-fallback. portal_ca devices keep their mode.
""" """
import sqlalchemy as sa
from alembic import op from alembic import op
revision = "020" revision = "020"

View File

@@ -25,7 +25,8 @@ def upgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
for table in _TABLES: for table in _TABLES:
conn.execute(sa.text(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")) conn.execute(sa.text(f"DROP POLICY IF EXISTS tenant_isolation ON {table}"))
conn.execute(sa.text(f""" conn.execute(
sa.text(f"""
CREATE POLICY tenant_isolation ON {table} CREATE POLICY tenant_isolation ON {table}
USING ( USING (
tenant_id::text = current_setting('app.current_tenant', true) tenant_id::text = current_setting('app.current_tenant', true)
@@ -35,15 +36,18 @@ def upgrade() -> None:
tenant_id::text = current_setting('app.current_tenant', true) tenant_id::text = current_setting('app.current_tenant', true)
OR current_setting('app.current_tenant', true) = 'super_admin' OR current_setting('app.current_tenant', true) = 'super_admin'
) )
""")) """)
)
def downgrade() -> None: def downgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
for table in _TABLES: for table in _TABLES:
conn.execute(sa.text(f"DROP POLICY IF EXISTS tenant_isolation ON {table}")) conn.execute(sa.text(f"DROP POLICY IF EXISTS tenant_isolation ON {table}"))
conn.execute(sa.text(f""" conn.execute(
sa.text(f"""
CREATE POLICY tenant_isolation ON {table} CREATE POLICY tenant_isolation ON {table}
USING (tenant_id::text = current_setting('app.current_tenant', true)) USING (tenant_id::text = current_setting('app.current_tenant', true))
WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true)) WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true))
""")) """)
)

View File

@@ -19,7 +19,8 @@ def upgrade() -> None:
op.add_column("tenants", sa.Column("contact_email", sa.String(255), nullable=True)) op.add_column("tenants", sa.Column("contact_email", sa.String(255), nullable=True))
# 2. Seed device_offline default alert rule for all existing tenants # 2. Seed device_offline default alert rule for all existing tenants
conn.execute(sa.text(""" conn.execute(
sa.text("""
INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default) INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default)
SELECT gen_random_uuid(), t.id, 'Device Offline', 'device_offline', 'eq', 1, 1, 'critical', TRUE, TRUE SELECT gen_random_uuid(), t.id, 'Device Offline', 'device_offline', 'eq', 1, 1, 'critical', TRUE, TRUE
FROM tenants t FROM tenants t
@@ -28,14 +29,17 @@ def upgrade() -> None:
SELECT 1 FROM alert_rules ar SELECT 1 FROM alert_rules ar
WHERE ar.tenant_id = t.id AND ar.metric = 'device_offline' AND ar.is_default = TRUE WHERE ar.tenant_id = t.id AND ar.metric = 'device_offline' AND ar.is_default = TRUE
) )
""")) """)
)
def downgrade() -> None: def downgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
conn.execute(sa.text(""" conn.execute(
sa.text("""
DELETE FROM alert_rules WHERE metric = 'device_offline' AND is_default = TRUE DELETE FROM alert_rules WHERE metric = 'device_offline' AND is_default = TRUE
""")) """)
)
op.drop_column("tenants", "contact_email") op.drop_column("tenants", "contact_email")

View File

@@ -11,9 +11,7 @@ down_revision = "024"
def upgrade() -> None: def upgrade() -> None:
op.drop_constraint( op.drop_constraint("fk_key_access_log_device_id", "key_access_log", type_="foreignkey")
"fk_key_access_log_device_id", "key_access_log", type_="foreignkey"
)
op.create_foreign_key( op.create_foreign_key(
"fk_key_access_log_device_id", "fk_key_access_log_device_id",
"key_access_log", "key_access_log",
@@ -25,9 +23,7 @@ def upgrade() -> None:
def downgrade() -> None: def downgrade() -> None:
op.drop_constraint( op.drop_constraint("fk_key_access_log_device_id", "key_access_log", type_="foreignkey")
"fk_key_access_log_device_id", "key_access_log", type_="foreignkey"
)
op.create_foreign_key( op.create_foreign_key(
"fk_key_access_log_device_id", "fk_key_access_log_device_id",
"key_access_log", "key_access_log",

View File

@@ -25,7 +25,8 @@ def upgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
# ── router_config_snapshots ────────────────────────────────────────── # ── router_config_snapshots ──────────────────────────────────────────
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE router_config_snapshots ( CREATE TABLE router_config_snapshots (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
device_id UUID NOT NULL REFERENCES devices(id) ON DELETE CASCADE, device_id UUID NOT NULL REFERENCES devices(id) ON DELETE CASCADE,
@@ -35,30 +36,38 @@ def upgrade() -> None:
collected_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), collected_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
) )
""")) """)
)
# RLS # RLS
conn.execute(sa.text("ALTER TABLE router_config_snapshots ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE router_config_snapshots ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text("ALTER TABLE router_config_snapshots FORCE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE router_config_snapshots FORCE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON router_config_snapshots CREATE POLICY tenant_isolation ON router_config_snapshots
USING (tenant_id::text = current_setting('app.current_tenant', true)) USING (tenant_id::text = current_setting('app.current_tenant', true))
WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true)) WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true))
""")) """)
)
# Grants # Grants
conn.execute(sa.text("GRANT SELECT, INSERT, DELETE ON router_config_snapshots TO app_user")) conn.execute(sa.text("GRANT SELECT, INSERT, DELETE ON router_config_snapshots TO app_user"))
# Indexes # Indexes
conn.execute(sa.text( conn.execute(
"CREATE INDEX idx_rcs_device_collected ON router_config_snapshots (device_id, collected_at DESC)" sa.text(
)) "CREATE INDEX idx_rcs_device_collected ON router_config_snapshots (device_id, collected_at DESC)"
conn.execute(sa.text( )
"CREATE INDEX idx_rcs_device_hash ON router_config_snapshots (device_id, sha256_hash)" )
)) conn.execute(
sa.text(
"CREATE INDEX idx_rcs_device_hash ON router_config_snapshots (device_id, sha256_hash)"
)
)
# ── router_config_diffs ────────────────────────────────────────────── # ── router_config_diffs ──────────────────────────────────────────────
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE router_config_diffs ( CREATE TABLE router_config_diffs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
device_id UUID NOT NULL REFERENCES devices(id) ON DELETE CASCADE, device_id UUID NOT NULL REFERENCES devices(id) ON DELETE CASCADE,
@@ -70,27 +79,33 @@ def upgrade() -> None:
lines_removed INT NOT NULL DEFAULT 0, lines_removed INT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
) )
""")) """)
)
# RLS # RLS
conn.execute(sa.text("ALTER TABLE router_config_diffs ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE router_config_diffs ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text("ALTER TABLE router_config_diffs FORCE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE router_config_diffs FORCE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON router_config_diffs CREATE POLICY tenant_isolation ON router_config_diffs
USING (tenant_id::text = current_setting('app.current_tenant', true)) USING (tenant_id::text = current_setting('app.current_tenant', true))
WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true)) WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true))
""")) """)
)
# Grants # Grants
conn.execute(sa.text("GRANT SELECT, INSERT, DELETE ON router_config_diffs TO app_user")) conn.execute(sa.text("GRANT SELECT, INSERT, DELETE ON router_config_diffs TO app_user"))
# Indexes # Indexes
conn.execute(sa.text( conn.execute(
"CREATE UNIQUE INDEX idx_rcd_snapshot_pair ON router_config_diffs (old_snapshot_id, new_snapshot_id)" sa.text(
)) "CREATE UNIQUE INDEX idx_rcd_snapshot_pair ON router_config_diffs (old_snapshot_id, new_snapshot_id)"
)
)
# ── router_config_changes ──────────────────────────────────────────── # ── router_config_changes ────────────────────────────────────────────
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE TABLE router_config_changes ( CREATE TABLE router_config_changes (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
diff_id UUID NOT NULL REFERENCES router_config_diffs(id) ON DELETE CASCADE, diff_id UUID NOT NULL REFERENCES router_config_diffs(id) ON DELETE CASCADE,
@@ -101,24 +116,25 @@ def upgrade() -> None:
raw_line TEXT, raw_line TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
) )
""")) """)
)
# RLS # RLS
conn.execute(sa.text("ALTER TABLE router_config_changes ENABLE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE router_config_changes ENABLE ROW LEVEL SECURITY"))
conn.execute(sa.text("ALTER TABLE router_config_changes FORCE ROW LEVEL SECURITY")) conn.execute(sa.text("ALTER TABLE router_config_changes FORCE ROW LEVEL SECURITY"))
conn.execute(sa.text(""" conn.execute(
sa.text("""
CREATE POLICY tenant_isolation ON router_config_changes CREATE POLICY tenant_isolation ON router_config_changes
USING (tenant_id::text = current_setting('app.current_tenant', true)) USING (tenant_id::text = current_setting('app.current_tenant', true))
WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true)) WITH CHECK (tenant_id::text = current_setting('app.current_tenant', true))
""")) """)
)
# Grants # Grants
conn.execute(sa.text("GRANT SELECT, INSERT, DELETE ON router_config_changes TO app_user")) conn.execute(sa.text("GRANT SELECT, INSERT, DELETE ON router_config_changes TO app_user"))
# Indexes # Indexes
conn.execute(sa.text( conn.execute(sa.text("CREATE INDEX idx_rcc_diff_id ON router_config_changes (diff_id)"))
"CREATE INDEX idx_rcc_diff_id ON router_config_changes (diff_id)"
))
def downgrade() -> None: def downgrade() -> None:

View File

@@ -25,40 +25,28 @@ import sqlalchemy as sa
def upgrade() -> None: def upgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
conn.execute(sa.text( conn.execute(sa.text("ALTER TABLE devices ADD COLUMN ssh_port INTEGER DEFAULT 22"))
"ALTER TABLE devices ADD COLUMN ssh_port INTEGER DEFAULT 22" conn.execute(sa.text("ALTER TABLE devices ADD COLUMN ssh_host_key_fingerprint TEXT"))
)) conn.execute(sa.text("ALTER TABLE devices ADD COLUMN ssh_host_key_first_seen TIMESTAMPTZ"))
conn.execute(sa.text( conn.execute(sa.text("ALTER TABLE devices ADD COLUMN ssh_host_key_last_verified TIMESTAMPTZ"))
"ALTER TABLE devices ADD COLUMN ssh_host_key_fingerprint TEXT"
))
conn.execute(sa.text(
"ALTER TABLE devices ADD COLUMN ssh_host_key_first_seen TIMESTAMPTZ"
))
conn.execute(sa.text(
"ALTER TABLE devices ADD COLUMN ssh_host_key_last_verified TIMESTAMPTZ"
))
# Grant poller_user UPDATE on SSH columns for TOFU host key persistence # Grant poller_user UPDATE on SSH columns for TOFU host key persistence
conn.execute(sa.text( conn.execute(
"GRANT UPDATE (ssh_host_key_fingerprint, ssh_host_key_first_seen, ssh_host_key_last_verified) ON devices TO poller_user" sa.text(
)) "GRANT UPDATE (ssh_host_key_fingerprint, ssh_host_key_first_seen, ssh_host_key_last_verified) ON devices TO poller_user"
)
)
def downgrade() -> None: def downgrade() -> None:
conn = op.get_bind() conn = op.get_bind()
conn.execute(sa.text( conn.execute(
"REVOKE UPDATE (ssh_host_key_fingerprint, ssh_host_key_first_seen, ssh_host_key_last_verified) ON devices FROM poller_user" sa.text(
)) "REVOKE UPDATE (ssh_host_key_fingerprint, ssh_host_key_first_seen, ssh_host_key_last_verified) ON devices FROM poller_user"
conn.execute(sa.text( )
"ALTER TABLE devices DROP COLUMN ssh_host_key_last_verified" )
)) conn.execute(sa.text("ALTER TABLE devices DROP COLUMN ssh_host_key_last_verified"))
conn.execute(sa.text( conn.execute(sa.text("ALTER TABLE devices DROP COLUMN ssh_host_key_first_seen"))
"ALTER TABLE devices DROP COLUMN ssh_host_key_first_seen" conn.execute(sa.text("ALTER TABLE devices DROP COLUMN ssh_host_key_fingerprint"))
)) conn.execute(sa.text("ALTER TABLE devices DROP COLUMN ssh_port"))
conn.execute(sa.text(
"ALTER TABLE devices DROP COLUMN ssh_host_key_fingerprint"
))
conn.execute(sa.text(
"ALTER TABLE devices DROP COLUMN ssh_port"
))

View File

@@ -16,7 +16,12 @@ import base64
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
from cryptography.hazmat.primitives.serialization import Encoding, NoEncryption, PrivateFormat, PublicFormat from cryptography.hazmat.primitives.serialization import (
Encoding,
NoEncryption,
PrivateFormat,
PublicFormat,
)
from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.ciphers.aead import AESGCM
@@ -120,5 +125,9 @@ def downgrade() -> None:
op.alter_column("vpn_config", "subnet", server_default="10.10.0.0/24") op.alter_column("vpn_config", "subnet", server_default="10.10.0.0/24")
op.alter_column("vpn_config", "server_address", server_default="10.10.0.1/24") op.alter_column("vpn_config", "server_address", server_default="10.10.0.1/24")
conn = op.get_bind() conn = op.get_bind()
conn.execute(sa.text("DELETE FROM system_settings WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')")) conn.execute(
sa.text(
"DELETE FROM system_settings WHERE key IN ('vpn_server_public_key', 'vpn_server_private_key')"
)
)
# NOTE: downgrade does not remap peer IPs back. Manual cleanup may be needed. # NOTE: downgrade does not remap peer IPs back. Manual cleanup may be needed.

View File

@@ -42,8 +42,8 @@ def validate_production_settings(settings: "Settings") -> None:
print( print(
f"FATAL: {field} uses a known insecure default in '{settings.ENVIRONMENT}' environment.\n" f"FATAL: {field} uses a known insecure default in '{settings.ENVIRONMENT}' environment.\n"
f"Generate a secure value and set it in your .env.prod file.\n" f"Generate a secure value and set it in your .env.prod file.\n"
f"For JWT_SECRET_KEY: python -c \"import secrets; print(secrets.token_urlsafe(64))\"\n" f'For JWT_SECRET_KEY: python -c "import secrets; print(secrets.token_urlsafe(64))"\n'
f"For CREDENTIAL_ENCRYPTION_KEY: python -c \"import secrets, base64; print(base64.b64encode(secrets.token_bytes(32)).decode())\"\n" f'For CREDENTIAL_ENCRYPTION_KEY: python -c "import secrets, base64; print(base64.b64encode(secrets.token_bytes(32)).decode())"\n'
f"For OPENBAO_TOKEN: use the token from your OpenBao server (not the dev token)", f"For OPENBAO_TOKEN: use the token from your OpenBao server (not the dev token)",
file=sys.stderr, file=sys.stderr,
) )

View File

@@ -17,6 +17,7 @@ from app.config import settings
class Base(DeclarativeBase): class Base(DeclarativeBase):
"""Base class for all SQLAlchemy ORM models.""" """Base class for all SQLAlchemy ORM models."""
pass pass

View File

@@ -82,7 +82,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
from app.services.retention_service import start_retention_scheduler, stop_retention_scheduler from app.services.retention_service import start_retention_scheduler, stop_retention_scheduler
from app.services.metrics_subscriber import start_metrics_subscriber, stop_metrics_subscriber from app.services.metrics_subscriber import start_metrics_subscriber, stop_metrics_subscriber
from app.services.nats_subscriber import start_nats_subscriber, stop_nats_subscriber from app.services.nats_subscriber import start_nats_subscriber, stop_nats_subscriber
from app.services.session_audit_subscriber import start_session_audit_subscriber, stop_session_audit_subscriber from app.services.session_audit_subscriber import (
start_session_audit_subscriber,
stop_session_audit_subscriber,
)
from app.services.sse_manager import ensure_sse_streams from app.services.sse_manager import ensure_sse_streams
# Configure structured logging FIRST -- before any other startup work # Configure structured logging FIRST -- before any other startup work
@@ -201,6 +204,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
start_config_change_subscriber, start_config_change_subscriber,
stop_config_change_subscriber, stop_config_change_subscriber,
) )
config_change_nc = await start_config_change_subscriber() config_change_nc = await start_config_change_subscriber()
except Exception as e: except Exception as e:
logger.error("Config change subscriber failed to start (non-fatal): %s", e) logger.error("Config change subscriber failed to start (non-fatal): %s", e)
@@ -212,6 +216,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
start_push_rollback_subscriber, start_push_rollback_subscriber,
stop_push_rollback_subscriber, stop_push_rollback_subscriber,
) )
push_rollback_nc = await start_push_rollback_subscriber() push_rollback_nc = await start_push_rollback_subscriber()
except Exception as e: except Exception as e:
logger.error("Push rollback subscriber failed to start (non-fatal): %s", e) logger.error("Push rollback subscriber failed to start (non-fatal): %s", e)
@@ -223,6 +228,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
start_config_snapshot_subscriber, start_config_snapshot_subscriber,
stop_config_snapshot_subscriber, stop_config_snapshot_subscriber,
) )
config_snapshot_nc = await start_config_snapshot_subscriber() config_snapshot_nc = await start_config_snapshot_subscriber()
except Exception as e: except Exception as e:
logger.error("Config snapshot subscriber failed to start (non-fatal): %s", e) logger.error("Config snapshot subscriber failed to start (non-fatal): %s", e)
@@ -231,14 +237,16 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
try: try:
await start_retention_scheduler() await start_retention_scheduler()
except Exception as exc: except Exception as exc:
logger.warning("retention scheduler could not start (API will run without it)", error=str(exc)) logger.warning(
"retention scheduler could not start (API will run without it)", error=str(exc)
)
# Start Remote WinBox session reconciliation loop (60s interval). # Start Remote WinBox session reconciliation loop (60s interval).
# Detects orphaned sessions (worker lost them) and cleans up Redis + tunnels. # Detects orphaned sessions (worker lost them) and cleans up Redis + tunnels.
winbox_reconcile_task: Optional[asyncio.Task] = None # type: ignore[type-arg] winbox_reconcile_task: Optional[asyncio.Task] = None # type: ignore[type-arg]
try: try:
from app.routers.winbox_remote import _get_redis as _wb_get_redis, _close_tunnel from app.routers.winbox_remote import _get_redis as _wb_get_redis, _close_tunnel
from app.services.winbox_remote import get_session as _wb_worker_get, health_check as _wb_health from app.services.winbox_remote import get_session as _wb_worker_get
async def _winbox_reconcile_loop() -> None: async def _winbox_reconcile_loop() -> None:
"""Scan Redis for winbox-remote:* keys and reconcile with worker.""" """Scan Redis for winbox-remote:* keys and reconcile with worker."""

View File

@@ -49,6 +49,7 @@ def require_role(*allowed_roles: str) -> Callable:
Returns: Returns:
FastAPI dependency that raises 403 if the role is insufficient FastAPI dependency that raises 403 if the role is insufficient
""" """
async def dependency( async def dependency(
current_user: CurrentUser = Depends(get_current_user), current_user: CurrentUser = Depends(get_current_user),
) -> CurrentUser: ) -> CurrentUser:
@@ -56,7 +57,7 @@ def require_role(*allowed_roles: str) -> Callable:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied. Required roles: {', '.join(allowed_roles)}. " detail=f"Access denied. Required roles: {', '.join(allowed_roles)}. "
f"Your role: {current_user.role}", f"Your role: {current_user.role}",
) )
return current_user return current_user
@@ -82,7 +83,7 @@ def require_min_role(min_role: str) -> Callable:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied. Minimum required role: {min_role}. " detail=f"Access denied. Minimum required role: {min_role}. "
f"Your role: {current_user.role}", f"Your role: {current_user.role}",
) )
return current_user return current_user
@@ -96,6 +97,7 @@ def require_write_access() -> Callable:
Viewers are NOT allowed on POST/PUT/PATCH/DELETE endpoints. Viewers are NOT allowed on POST/PUT/PATCH/DELETE endpoints.
Call this on any mutating endpoint to deny viewers. Call this on any mutating endpoint to deny viewers.
""" """
async def dependency( async def dependency(
request: Request, request: Request,
current_user: CurrentUser = Depends(get_current_user), current_user: CurrentUser = Depends(get_current_user),
@@ -105,7 +107,7 @@ def require_write_access() -> Callable:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Viewers have read-only access. " detail="Viewers have read-only access. "
"Contact your administrator to request elevated permissions.", "Contact your administrator to request elevated permissions.",
) )
return current_user return current_user
@@ -127,6 +129,7 @@ def require_scope(scope: str) -> DependsClass:
Raises: Raises:
HTTPException 403 if the API key is missing the required scope. HTTPException 403 if the API key is missing the required scope.
""" """
async def _check_scope( async def _check_scope(
current_user: CurrentUser = Depends(get_current_user), current_user: CurrentUser = Depends(get_current_user),
) -> CurrentUser: ) -> CurrentUser:
@@ -143,6 +146,7 @@ def require_scope(scope: str) -> DependsClass:
# Pre-built convenience dependencies # Pre-built convenience dependencies
async def require_super_admin( async def require_super_admin(
current_user: CurrentUser = Depends(get_current_user), current_user: CurrentUser = Depends(get_current_user),
) -> CurrentUser: ) -> CurrentUser:

View File

@@ -20,32 +20,36 @@ from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
# Production CSP: strict -- no inline scripts allowed # Production CSP: strict -- no inline scripts allowed
_CSP_PRODUCTION = "; ".join([ _CSP_PRODUCTION = "; ".join(
"default-src 'self'", [
"script-src 'self'", "default-src 'self'",
"style-src 'self' 'unsafe-inline'", "script-src 'self'",
"img-src 'self' data: blob:", "style-src 'self' 'unsafe-inline'",
"font-src 'self'", "img-src 'self' data: blob:",
"connect-src 'self' wss: ws:", "font-src 'self'",
"worker-src 'self'", "connect-src 'self' wss: ws:",
"frame-ancestors 'none'", "worker-src 'self'",
"base-uri 'self'", "frame-ancestors 'none'",
"form-action 'self'", "base-uri 'self'",
]) "form-action 'self'",
]
)
# Development CSP: relaxed for Vite HMR (hot module replacement) # Development CSP: relaxed for Vite HMR (hot module replacement)
_CSP_DEV = "; ".join([ _CSP_DEV = "; ".join(
"default-src 'self'", [
"script-src 'self' 'unsafe-inline' 'unsafe-eval'", "default-src 'self'",
"style-src 'self' 'unsafe-inline'", "script-src 'self' 'unsafe-inline' 'unsafe-eval'",
"img-src 'self' data: blob:", "style-src 'self' 'unsafe-inline'",
"font-src 'self'", "img-src 'self' data: blob:",
"connect-src 'self' http://localhost:* ws://localhost:* wss:", "font-src 'self'",
"worker-src 'self' blob:", "connect-src 'self' http://localhost:* ws://localhost:* wss:",
"frame-ancestors 'none'", "worker-src 'self' blob:",
"base-uri 'self'", "frame-ancestors 'none'",
"form-action 'self'", "base-uri 'self'",
]) "form-action 'self'",
]
)
class SecurityHeadersMiddleware(BaseHTTPMiddleware): class SecurityHeadersMiddleware(BaseHTTPMiddleware):
@@ -72,8 +76,6 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
# HSTS only in production (plain HTTP in dev would be blocked) # HSTS only in production (plain HTTP in dev would be blocked)
if self.is_production: if self.is_production:
response.headers["Strict-Transport-Security"] = ( response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
"max-age=31536000; includeSubDomains"
)
return response return response

View File

@@ -178,7 +178,6 @@ async def get_current_user_ws(
Raises: Raises:
WebSocketException 1008: If no token is provided or the token is invalid. WebSocketException 1008: If no token is provided or the token is invalid.
""" """
from starlette.websockets import WebSocket, WebSocketState
from fastapi import WebSocketException from fastapi import WebSocketException
# 1. Try cookie # 1. Try cookie

View File

@@ -2,7 +2,14 @@
from app.models.tenant import Tenant from app.models.tenant import Tenant
from app.models.user import User, UserRole from app.models.user import User, UserRole
from app.models.device import Device, DeviceGroup, DeviceTag, DeviceGroupMembership, DeviceTagAssignment, DeviceStatus from app.models.device import (
Device,
DeviceGroup,
DeviceTag,
DeviceGroupMembership,
DeviceTagAssignment,
DeviceStatus,
)
from app.models.alert import AlertRule, NotificationChannel, AlertRuleChannel, AlertEvent from app.models.alert import AlertRule, NotificationChannel, AlertRuleChannel, AlertEvent
from app.models.firmware import FirmwareVersion, FirmwareUpgradeJob from app.models.firmware import FirmwareVersion, FirmwareUpgradeJob
from app.models.config_template import ConfigTemplate, ConfigTemplateTag, TemplatePushJob from app.models.config_template import ConfigTemplate, ConfigTemplateTag, TemplatePushJob

View File

@@ -26,6 +26,7 @@ class AlertRule(Base):
When a metric breaches the threshold for duration_polls consecutive polls, When a metric breaches the threshold for duration_polls consecutive polls,
an alert fires. an alert fires.
""" """
__tablename__ = "alert_rules" __tablename__ = "alert_rules"
id: Mapped[uuid.UUID] = mapped_column( id: Mapped[uuid.UUID] = mapped_column(
@@ -53,10 +54,16 @@ class AlertRule(Base):
metric: Mapped[str] = mapped_column(Text, nullable=False) metric: Mapped[str] = mapped_column(Text, nullable=False)
operator: Mapped[str] = mapped_column(Text, nullable=False) operator: Mapped[str] = mapped_column(Text, nullable=False)
threshold: Mapped[float] = mapped_column(Numeric, nullable=False) threshold: Mapped[float] = mapped_column(Numeric, nullable=False)
duration_polls: Mapped[int] = mapped_column(Integer, nullable=False, default=1, server_default="1") duration_polls: Mapped[int] = mapped_column(
Integer, nullable=False, default=1, server_default="1"
)
severity: Mapped[str] = mapped_column(Text, nullable=False) severity: Mapped[str] = mapped_column(Text, nullable=False)
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, server_default="true") enabled: Mapped[bool] = mapped_column(
is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false") Boolean, nullable=False, default=True, server_default="true"
)
is_default: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False, server_default="false"
)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), DateTime(timezone=True),
server_default=func.now(), server_default=func.now(),
@@ -69,6 +76,7 @@ class AlertRule(Base):
class NotificationChannel(Base): class NotificationChannel(Base):
"""Email, webhook, or Slack notification destination.""" """Email, webhook, or Slack notification destination."""
__tablename__ = "notification_channels" __tablename__ = "notification_channels"
id: Mapped[uuid.UUID] = mapped_column( id: Mapped[uuid.UUID] = mapped_column(
@@ -83,12 +91,16 @@ class NotificationChannel(Base):
nullable=False, nullable=False,
) )
name: Mapped[str] = mapped_column(Text, nullable=False) name: Mapped[str] = mapped_column(Text, nullable=False)
channel_type: Mapped[str] = mapped_column(Text, nullable=False) # "email", "webhook", or "slack" channel_type: Mapped[str] = mapped_column(
Text, nullable=False
) # "email", "webhook", or "slack"
# SMTP fields (email channels) # SMTP fields (email channels)
smtp_host: Mapped[str | None] = mapped_column(Text, nullable=True) smtp_host: Mapped[str | None] = mapped_column(Text, nullable=True)
smtp_port: Mapped[int | None] = mapped_column(Integer, nullable=True) smtp_port: Mapped[int | None] = mapped_column(Integer, nullable=True)
smtp_user: Mapped[str | None] = mapped_column(Text, nullable=True) smtp_user: Mapped[str | None] = mapped_column(Text, nullable=True)
smtp_password: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True) # AES-256-GCM encrypted smtp_password: Mapped[bytes | None] = mapped_column(
LargeBinary, nullable=True
) # AES-256-GCM encrypted
smtp_use_tls: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false") smtp_use_tls: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false")
from_address: Mapped[str | None] = mapped_column(Text, nullable=True) from_address: Mapped[str | None] = mapped_column(Text, nullable=True)
to_address: Mapped[str | None] = mapped_column(Text, nullable=True) to_address: Mapped[str | None] = mapped_column(Text, nullable=True)
@@ -110,6 +122,7 @@ class NotificationChannel(Base):
class AlertRuleChannel(Base): class AlertRuleChannel(Base):
"""Many-to-many association between alert rules and notification channels.""" """Many-to-many association between alert rules and notification channels."""
__tablename__ = "alert_rule_channels" __tablename__ = "alert_rule_channels"
rule_id: Mapped[uuid.UUID] = mapped_column( rule_id: Mapped[uuid.UUID] = mapped_column(
@@ -129,6 +142,7 @@ class AlertEvent(Base):
rule_id is NULL for system-level alerts (e.g., device offline). rule_id is NULL for system-level alerts (e.g., device offline).
""" """
__tablename__ = "alert_events" __tablename__ = "alert_events"
id: Mapped[uuid.UUID] = mapped_column( id: Mapped[uuid.UUID] = mapped_column(
@@ -158,7 +172,9 @@ class AlertEvent(Base):
value: Mapped[float | None] = mapped_column(Numeric, nullable=True) value: Mapped[float | None] = mapped_column(Numeric, nullable=True)
threshold: Mapped[float | None] = mapped_column(Numeric, nullable=True) threshold: Mapped[float | None] = mapped_column(Numeric, nullable=True)
message: Mapped[str | None] = mapped_column(Text, nullable=True) message: Mapped[str | None] = mapped_column(Text, nullable=True)
is_flapping: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="false") is_flapping: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False, server_default="false"
)
acknowledged_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) acknowledged_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
acknowledged_by: Mapped[uuid.UUID | None] = mapped_column( acknowledged_by: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), UUID(as_uuid=True),

View File

@@ -41,20 +41,14 @@ class ApiKey(Base):
key_prefix: Mapped[str] = mapped_column(Text, nullable=False) key_prefix: Mapped[str] = mapped_column(Text, nullable=False)
key_hash: Mapped[str] = mapped_column(Text, nullable=False, unique=True) key_hash: Mapped[str] = mapped_column(Text, nullable=False, unique=True)
scopes: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'::jsonb") scopes: Mapped[list] = mapped_column(JSONB, nullable=False, server_default="'[]'::jsonb")
expires_at: Mapped[Optional[datetime]] = mapped_column( expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
DateTime(timezone=True), nullable=True last_used_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
)
last_used_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True
)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), DateTime(timezone=True),
server_default=func.now(), server_default=func.now(),
nullable=False, nullable=False,
) )
revoked_at: Mapped[Optional[datetime]] = mapped_column( revoked_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
DateTime(timezone=True), nullable=True
)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<ApiKey id={self.id} name={self.name} prefix={self.key_prefix}>" return f"<ApiKey id={self.id} name={self.name} prefix={self.key_prefix}>"

View File

@@ -39,21 +39,13 @@ class CertificateAuthority(Base):
) )
common_name: Mapped[str] = mapped_column(String(255), nullable=False) common_name: Mapped[str] = mapped_column(String(255), nullable=False)
cert_pem: Mapped[str] = mapped_column(Text, nullable=False) cert_pem: Mapped[str] = mapped_column(Text, nullable=False)
encrypted_private_key: Mapped[bytes] = mapped_column( encrypted_private_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
LargeBinary, nullable=False
)
serial_number: Mapped[str] = mapped_column(String(64), nullable=False) serial_number: Mapped[str] = mapped_column(String(64), nullable=False)
fingerprint_sha256: Mapped[str] = mapped_column(String(95), nullable=False) fingerprint_sha256: Mapped[str] = mapped_column(String(95), nullable=False)
not_valid_before: Mapped[datetime] = mapped_column( not_valid_before: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
DateTime(timezone=True), nullable=False not_valid_after: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
)
not_valid_after: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False
)
# OpenBao Transit ciphertext (dual-write migration) # OpenBao Transit ciphertext (dual-write migration)
encrypted_private_key_transit: Mapped[str | None] = mapped_column( encrypted_private_key_transit: Mapped[str | None] = mapped_column(Text, nullable=True)
Text, nullable=True
)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), DateTime(timezone=True),
server_default=func.now(), server_default=func.now(),
@@ -62,8 +54,7 @@ class CertificateAuthority(Base):
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"<CertificateAuthority id={self.id} " f"<CertificateAuthority id={self.id} cn={self.common_name!r} tenant={self.tenant_id}>"
f"cn={self.common_name!r} tenant={self.tenant_id}>"
) )
@@ -103,25 +94,13 @@ class DeviceCertificate(Base):
serial_number: Mapped[str] = mapped_column(String(64), nullable=False) serial_number: Mapped[str] = mapped_column(String(64), nullable=False)
fingerprint_sha256: Mapped[str] = mapped_column(String(95), nullable=False) fingerprint_sha256: Mapped[str] = mapped_column(String(95), nullable=False)
cert_pem: Mapped[str] = mapped_column(Text, nullable=False) cert_pem: Mapped[str] = mapped_column(Text, nullable=False)
encrypted_private_key: Mapped[bytes] = mapped_column( encrypted_private_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
LargeBinary, nullable=False not_valid_before: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
) not_valid_after: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
not_valid_before: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False
)
not_valid_after: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False
)
# OpenBao Transit ciphertext (dual-write migration) # OpenBao Transit ciphertext (dual-write migration)
encrypted_private_key_transit: Mapped[str | None] = mapped_column( encrypted_private_key_transit: Mapped[str | None] = mapped_column(Text, nullable=True)
Text, nullable=True status: Mapped[str] = mapped_column(String(20), nullable=False, server_default="issued")
) deployed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
status: Mapped[str] = mapped_column(
String(20), nullable=False, server_default="issued"
)
deployed_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), DateTime(timezone=True),
server_default=func.now(), server_default=func.now(),
@@ -134,7 +113,4 @@ class DeviceCertificate(Base):
) )
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return f"<DeviceCertificate id={self.id} cn={self.common_name!r} status={self.status}>"
f"<DeviceCertificate id={self.id} "
f"cn={self.common_name!r} status={self.status}>"
)

View File

@@ -3,9 +3,20 @@
import uuid import uuid
from datetime import datetime from datetime import datetime
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, LargeBinary, SmallInteger, String, Text, UniqueConstraint, func from sqlalchemy import (
Boolean,
DateTime,
ForeignKey,
Integer,
LargeBinary,
SmallInteger,
String,
Text,
UniqueConstraint,
func,
)
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base from app.database import Base
@@ -115,7 +126,9 @@ class ConfigBackupSchedule(Base):
def __repr__(self) -> str: def __repr__(self) -> str:
scope = f"device={self.device_id}" if self.device_id else f"tenant={self.tenant_id}" scope = f"device={self.device_id}" if self.device_id else f"tenant={self.tenant_id}"
return f"<ConfigBackupSchedule {scope} cron={self.cron_expression!r} enabled={self.enabled}>" return (
f"<ConfigBackupSchedule {scope} cron={self.cron_expression!r} enabled={self.enabled}>"
)
class ConfigPushOperation(Base): class ConfigPushOperation(Base):
@@ -173,8 +186,7 @@ class ConfigPushOperation(Base):
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"<ConfigPushOperation id={self.id} device_id={self.device_id} " f"<ConfigPushOperation id={self.id} device_id={self.device_id} status={self.status!r}>"
f"status={self.status!r}>"
) )
@@ -272,7 +284,9 @@ class RouterConfigDiff(Base):
) )
diff_text: Mapped[str] = mapped_column(Text, nullable=False) diff_text: Mapped[str] = mapped_column(Text, nullable=False)
lines_added: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0") lines_added: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
lines_removed: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0") lines_removed: Mapped[int] = mapped_column(
Integer, nullable=False, default=0, server_default="0"
)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), DateTime(timezone=True),
server_default=func.now(), server_default=func.now(),
@@ -334,6 +348,5 @@ class RouterConfigChange(Base):
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"<RouterConfigChange id={self.id} diff_id={self.diff_id} " f"<RouterConfigChange id={self.id} diff_id={self.diff_id} component={self.component!r}>"
f"component={self.component!r}>"
) )

View File

@@ -5,7 +5,6 @@ from datetime import datetime
from sqlalchemy import ( from sqlalchemy import (
DateTime, DateTime,
Float,
ForeignKey, ForeignKey,
String, String,
Text, Text,
@@ -88,9 +87,7 @@ class ConfigTemplateTag(Base):
) )
# Relationships # Relationships
template: Mapped["ConfigTemplate"] = relationship( template: Mapped["ConfigTemplate"] = relationship("ConfigTemplate", back_populates="tags")
"ConfigTemplate", back_populates="tags"
)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<ConfigTemplateTag id={self.id} name={self.name!r} template_id={self.template_id}>" return f"<ConfigTemplateTag id={self.id} name={self.name!r} template_id={self.template_id}>"
@@ -133,12 +130,8 @@ class TemplatePushJob(Base):
) )
pre_push_backup_sha: Mapped[str | None] = mapped_column(Text, nullable=True) pre_push_backup_sha: Mapped[str | None] = mapped_column(Text, nullable=True)
error_message: Mapped[str | None] = mapped_column(Text, nullable=True) error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
started_at: Mapped[datetime | None] = mapped_column( started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
DateTime(timezone=True), nullable=True completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
)
completed_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), DateTime(timezone=True),
server_default=func.now(), server_default=func.now(),

View File

@@ -5,7 +5,6 @@ from datetime import datetime
from enum import Enum from enum import Enum
from sqlalchemy import ( from sqlalchemy import (
Boolean,
DateTime, DateTime,
Float, Float,
ForeignKey, ForeignKey,
@@ -24,6 +23,7 @@ from app.database import Base
class DeviceStatus(str, Enum): class DeviceStatus(str, Enum):
"""Device connection status.""" """Device connection status."""
UNKNOWN = "unknown" UNKNOWN = "unknown"
ONLINE = "online" ONLINE = "online"
OFFLINE = "offline" OFFLINE = "offline"
@@ -31,9 +31,7 @@ class DeviceStatus(str, Enum):
class Device(Base): class Device(Base):
__tablename__ = "devices" __tablename__ = "devices"
__table_args__ = ( __table_args__ = (UniqueConstraint("tenant_id", "hostname", name="uq_devices_tenant_hostname"),)
UniqueConstraint("tenant_id", "hostname", name="uq_devices_tenant_hostname"),
)
id: Mapped[uuid.UUID] = mapped_column( id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), UUID(as_uuid=True),
@@ -59,7 +57,9 @@ class Device(Base):
uptime_seconds: Mapped[int | None] = mapped_column(Integer, nullable=True) uptime_seconds: Mapped[int | None] = mapped_column(Integer, nullable=True)
last_cpu_load: Mapped[int | None] = mapped_column(Integer, nullable=True) last_cpu_load: Mapped[int | None] = mapped_column(Integer, nullable=True)
last_memory_used_pct: Mapped[int | None] = mapped_column(Integer, nullable=True) last_memory_used_pct: Mapped[int | None] = mapped_column(Integer, nullable=True)
architecture: Mapped[str | None] = mapped_column(Text, nullable=True) # CPU arch (arm, arm64, mipsbe, etc.) architecture: Mapped[str | None] = mapped_column(
Text, nullable=True
) # CPU arch (arm, arm64, mipsbe, etc.)
preferred_channel: Mapped[str] = mapped_column( preferred_channel: Mapped[str] = mapped_column(
Text, default="stable", server_default="stable", nullable=False Text, default="stable", server_default="stable", nullable=False
) # Firmware release channel ) # Firmware release channel
@@ -108,9 +108,7 @@ class Device(Base):
class DeviceGroup(Base): class DeviceGroup(Base):
__tablename__ = "device_groups" __tablename__ = "device_groups"
__table_args__ = ( __table_args__ = (UniqueConstraint("tenant_id", "name", name="uq_device_groups_tenant_name"),)
UniqueConstraint("tenant_id", "name", name="uq_device_groups_tenant_name"),
)
id: Mapped[uuid.UUID] = mapped_column( id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), UUID(as_uuid=True),
@@ -147,9 +145,7 @@ class DeviceGroup(Base):
class DeviceTag(Base): class DeviceTag(Base):
__tablename__ = "device_tags" __tablename__ = "device_tags"
__table_args__ = ( __table_args__ = (UniqueConstraint("tenant_id", "name", name="uq_device_tags_tenant_name"),)
UniqueConstraint("tenant_id", "name", name="uq_device_tags_tenant_name"),
)
id: Mapped[uuid.UUID] = mapped_column( id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), UUID(as_uuid=True),

View File

@@ -7,7 +7,6 @@ from sqlalchemy import (
BigInteger, BigInteger,
Boolean, Boolean,
DateTime, DateTime,
Integer,
Text, Text,
UniqueConstraint, UniqueConstraint,
func, func,
@@ -24,10 +23,9 @@ class FirmwareVersion(Base):
Not tenant-scoped — firmware versions are global data shared across all tenants. Not tenant-scoped — firmware versions are global data shared across all tenants.
""" """
__tablename__ = "firmware_versions" __tablename__ = "firmware_versions"
__table_args__ = ( __table_args__ = (UniqueConstraint("architecture", "channel", "version"),)
UniqueConstraint("architecture", "channel", "version"),
)
id: Mapped[uuid.UUID] = mapped_column( id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), UUID(as_uuid=True),
@@ -56,6 +54,7 @@ class FirmwareUpgradeJob(Base):
Multiple jobs can share a rollout_group_id for mass upgrades. Multiple jobs can share a rollout_group_id for mass upgrades.
""" """
__tablename__ = "firmware_upgrade_jobs" __tablename__ = "firmware_upgrade_jobs"
id: Mapped[uuid.UUID] = mapped_column( id: Mapped[uuid.UUID] = mapped_column(
@@ -99,4 +98,6 @@ class FirmwareUpgradeJob(Base):
) )
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<FirmwareUpgradeJob id={self.id} status={self.status} target={self.target_version}>" return (
f"<FirmwareUpgradeJob id={self.id} status={self.status} target={self.target_version}>"
)

View File

@@ -37,32 +37,18 @@ class UserKeySet(Base):
ForeignKey("tenants.id", ondelete="CASCADE"), ForeignKey("tenants.id", ondelete="CASCADE"),
nullable=True, # NULL for super_admin nullable=True, # NULL for super_admin
) )
encrypted_private_key: Mapped[bytes] = mapped_column( encrypted_private_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
LargeBinary, nullable=False private_key_nonce: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
) encrypted_vault_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
private_key_nonce: Mapped[bytes] = mapped_column( vault_key_nonce: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
LargeBinary, nullable=False public_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
)
encrypted_vault_key: Mapped[bytes] = mapped_column(
LargeBinary, nullable=False
)
vault_key_nonce: Mapped[bytes] = mapped_column(
LargeBinary, nullable=False
)
public_key: Mapped[bytes] = mapped_column(
LargeBinary, nullable=False
)
pbkdf2_iterations: Mapped[int] = mapped_column( pbkdf2_iterations: Mapped[int] = mapped_column(
Integer, Integer,
server_default=func.literal_column("650000"), server_default=func.literal_column("650000"),
nullable=False, nullable=False,
) )
pbkdf2_salt: Mapped[bytes] = mapped_column( pbkdf2_salt: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
LargeBinary, nullable=False hkdf_salt: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
)
hkdf_salt: Mapped[bytes] = mapped_column(
LargeBinary, nullable=False
)
key_version: Mapped[int] = mapped_column( key_version: Mapped[int] = mapped_column(
Integer, Integer,
server_default=func.literal_column("1"), server_default=func.literal_column("1"),

View File

@@ -20,6 +20,7 @@ class MaintenanceWindow(Base):
device_ids is a JSONB array of device UUID strings. device_ids is a JSONB array of device UUID strings.
An empty array means "all devices in tenant". An empty array means "all devices in tenant".
""" """
__tablename__ = "maintenance_windows" __tablename__ = "maintenance_windows"
id: Mapped[uuid.UUID] = mapped_column( id: Mapped[uuid.UUID] = mapped_column(

View File

@@ -40,10 +40,18 @@ class Tenant(Base):
openbao_key_name: Mapped[str | None] = mapped_column(Text, nullable=True) openbao_key_name: Mapped[str | None] = mapped_column(Text, nullable=True)
# Relationships — passive_deletes=True lets the DB ON DELETE CASCADE handle cleanup # Relationships — passive_deletes=True lets the DB ON DELETE CASCADE handle cleanup
users: Mapped[list["User"]] = relationship("User", back_populates="tenant", passive_deletes=True) # type: ignore[name-defined] users: Mapped[list["User"]] = relationship(
devices: Mapped[list["Device"]] = relationship("Device", back_populates="tenant", passive_deletes=True) # type: ignore[name-defined] "User", back_populates="tenant", passive_deletes=True
device_groups: Mapped[list["DeviceGroup"]] = relationship("DeviceGroup", back_populates="tenant", passive_deletes=True) # type: ignore[name-defined] ) # type: ignore[name-defined]
device_tags: Mapped[list["DeviceTag"]] = relationship("DeviceTag", back_populates="tenant", passive_deletes=True) # type: ignore[name-defined] devices: Mapped[list["Device"]] = relationship(
"Device", back_populates="tenant", passive_deletes=True
) # type: ignore[name-defined]
device_groups: Mapped[list["DeviceGroup"]] = relationship(
"DeviceGroup", back_populates="tenant", passive_deletes=True
) # type: ignore[name-defined]
device_tags: Mapped[list["DeviceTag"]] = relationship(
"DeviceTag", back_populates="tenant", passive_deletes=True
) # type: ignore[name-defined]
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Tenant id={self.id} name={self.name!r}>" return f"<Tenant id={self.id} name={self.name!r}>"

View File

@@ -13,6 +13,7 @@ from app.database import Base
class UserRole(str, Enum): class UserRole(str, Enum):
"""User roles with increasing privilege levels.""" """User roles with increasing privilege levels."""
SUPER_ADMIN = "super_admin" SUPER_ADMIN = "super_admin"
TENANT_ADMIN = "tenant_admin" TENANT_ADMIN = "tenant_admin"
OPERATOR = "operator" OPERATOR = "operator"

View File

@@ -75,7 +75,9 @@ class VpnPeer(Base):
assigned_ip: Mapped[str] = mapped_column(String(32), nullable=False) assigned_ip: Mapped[str] = mapped_column(String(32), nullable=False)
additional_allowed_ips: Mapped[Optional[str]] = mapped_column(String(512), nullable=True) additional_allowed_ips: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
is_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="true") is_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="true")
last_handshake: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) last_handshake: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True
)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False DateTime(timezone=True), server_default=func.now(), nullable=False
) )

View File

@@ -12,10 +12,8 @@ RLS enforced via get_db() (app_user engine with tenant context).
RBAC: viewer = read-only (GET); operator and above = write (POST/PUT/PATCH/DELETE). RBAC: viewer = read-only (GET); operator and above = write (POST/PUT/PATCH/DELETE).
""" """
import base64
import logging import logging
import uuid import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Optional from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
@@ -66,8 +64,13 @@ def _require_write(current_user: CurrentUser) -> None:
EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
ALLOWED_METRICS = { ALLOWED_METRICS = {
"cpu_load", "memory_used_pct", "disk_used_pct", "temperature", "cpu_load",
"signal_strength", "ccq", "client_count", "memory_used_pct",
"disk_used_pct",
"temperature",
"signal_strength",
"ccq",
"client_count",
} }
ALLOWED_OPERATORS = {"gt", "lt", "gte", "lte"} ALLOWED_OPERATORS = {"gt", "lt", "gte", "lte"}
ALLOWED_SEVERITIES = {"critical", "warning", "info"} ALLOWED_SEVERITIES = {"critical", "warning", "info"}
@@ -252,7 +255,9 @@ async def create_alert_rule(
if body.operator not in ALLOWED_OPERATORS: if body.operator not in ALLOWED_OPERATORS:
raise HTTPException(422, f"operator must be one of: {', '.join(sorted(ALLOWED_OPERATORS))}") raise HTTPException(422, f"operator must be one of: {', '.join(sorted(ALLOWED_OPERATORS))}")
if body.severity not in ALLOWED_SEVERITIES: if body.severity not in ALLOWED_SEVERITIES:
raise HTTPException(422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}") raise HTTPException(
422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}"
)
rule_id = str(uuid.uuid4()) rule_id = str(uuid.uuid4())
@@ -296,8 +301,12 @@ async def create_alert_rule(
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "alert_rule_create", db,
resource_type="alert_rule", resource_id=rule_id, tenant_id,
current_user.user_id,
"alert_rule_create",
resource_type="alert_rule",
resource_id=rule_id,
details={"name": body.name, "metric": body.metric, "severity": body.severity}, details={"name": body.name, "metric": body.metric, "severity": body.severity},
) )
except Exception: except Exception:
@@ -338,7 +347,9 @@ async def update_alert_rule(
if body.operator not in ALLOWED_OPERATORS: if body.operator not in ALLOWED_OPERATORS:
raise HTTPException(422, f"operator must be one of: {', '.join(sorted(ALLOWED_OPERATORS))}") raise HTTPException(422, f"operator must be one of: {', '.join(sorted(ALLOWED_OPERATORS))}")
if body.severity not in ALLOWED_SEVERITIES: if body.severity not in ALLOWED_SEVERITIES:
raise HTTPException(422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}") raise HTTPException(
422, f"severity must be one of: {', '.join(sorted(ALLOWED_SEVERITIES))}"
)
result = await db.execute( result = await db.execute(
text(""" text("""
@@ -384,8 +395,12 @@ async def update_alert_rule(
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "alert_rule_update", db,
resource_type="alert_rule", resource_id=str(rule_id), tenant_id,
current_user.user_id,
"alert_rule_update",
resource_type="alert_rule",
resource_id=str(rule_id),
details={"name": body.name, "metric": body.metric, "severity": body.severity}, details={"name": body.name, "metric": body.metric, "severity": body.severity},
) )
except Exception: except Exception:
@@ -439,8 +454,12 @@ async def delete_alert_rule(
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "alert_rule_delete", db,
resource_type="alert_rule", resource_id=str(rule_id), tenant_id,
current_user.user_id,
"alert_rule_delete",
resource_type="alert_rule",
resource_id=str(rule_id),
) )
except Exception: except Exception:
pass pass
@@ -592,7 +611,8 @@ async def create_notification_channel(
encrypted_password_transit = None encrypted_password_transit = None
if body.smtp_password: if body.smtp_password:
encrypted_password_transit = await encrypt_credentials_transit( encrypted_password_transit = await encrypt_credentials_transit(
body.smtp_password, str(tenant_id), body.smtp_password,
str(tenant_id),
) )
await db.execute( await db.execute(
@@ -665,10 +685,14 @@ async def update_notification_channel(
# Build SET clauses dynamically based on which secrets are provided # Build SET clauses dynamically based on which secrets are provided
set_parts = [ set_parts = [
"name = :name", "channel_type = :channel_type", "name = :name",
"smtp_host = :smtp_host", "smtp_port = :smtp_port", "channel_type = :channel_type",
"smtp_user = :smtp_user", "smtp_use_tls = :smtp_use_tls", "smtp_host = :smtp_host",
"from_address = :from_address", "to_address = :to_address", "smtp_port = :smtp_port",
"smtp_user = :smtp_user",
"smtp_use_tls = :smtp_use_tls",
"from_address = :from_address",
"to_address = :to_address",
"webhook_url = :webhook_url", "webhook_url = :webhook_url",
"slack_webhook_url = :slack_webhook_url", "slack_webhook_url = :slack_webhook_url",
] ]
@@ -689,7 +713,8 @@ async def update_notification_channel(
if body.smtp_password: if body.smtp_password:
set_parts.append("smtp_password_transit = :smtp_password_transit") set_parts.append("smtp_password_transit = :smtp_password_transit")
params["smtp_password_transit"] = await encrypt_credentials_transit( params["smtp_password_transit"] = await encrypt_credentials_transit(
body.smtp_password, str(tenant_id), body.smtp_password,
str(tenant_id),
) )
# Clear legacy column # Clear legacy column
set_parts.append("smtp_password = NULL") set_parts.append("smtp_password = NULL")
@@ -799,6 +824,7 @@ async def test_notification_channel(
} }
from app.services.notification_service import send_test_notification from app.services.notification_service import send_test_notification
try: try:
success = await send_test_notification(channel) success = await send_test_notification(channel)
if success: if success:

View File

@@ -221,29 +221,38 @@ async def list_audit_logs(
all_rows = result.mappings().all() all_rows = result.mappings().all()
# Decrypt encrypted details concurrently # Decrypt encrypted details concurrently
decrypted_details = await _decrypt_details_batch( decrypted_details = await _decrypt_details_batch(all_rows, str(tenant_id))
all_rows, str(tenant_id)
)
output = io.StringIO() output = io.StringIO()
writer = csv.writer(output) writer = csv.writer(output)
writer.writerow([ writer.writerow(
"ID", "User Email", "Action", "Resource Type", [
"Resource ID", "Device", "Details", "IP Address", "Timestamp", "ID",
]) "User Email",
"Action",
"Resource Type",
"Resource ID",
"Device",
"Details",
"IP Address",
"Timestamp",
]
)
for row, details in zip(all_rows, decrypted_details): for row, details in zip(all_rows, decrypted_details):
details_str = json.dumps(details) if details else "{}" details_str = json.dumps(details) if details else "{}"
writer.writerow([ writer.writerow(
str(row["id"]), [
row["user_email"] or "", str(row["id"]),
row["action"], row["user_email"] or "",
row["resource_type"] or "", row["action"],
row["resource_id"] or "", row["resource_type"] or "",
row["device_name"] or "", row["resource_id"] or "",
details_str, row["device_name"] or "",
row["ip_address"] or "", details_str,
str(row["created_at"]), row["ip_address"] or "",
]) str(row["created_at"]),
]
)
output.seek(0) output.seek(0)
return StreamingResponse( return StreamingResponse(

View File

@@ -103,7 +103,11 @@ async def get_redis() -> aioredis.Redis:
# ─── SRP Zero-Knowledge Authentication ─────────────────────────────────────── # ─── SRP Zero-Knowledge Authentication ───────────────────────────────────────
@router.post("/srp/init", response_model=SRPInitResponse, summary="SRP Step 1: return salt and server ephemeral B") @router.post(
"/srp/init",
response_model=SRPInitResponse,
summary="SRP Step 1: return salt and server ephemeral B",
)
@limiter.limit("5/minute") @limiter.limit("5/minute")
async def srp_init_endpoint( async def srp_init_endpoint(
request: StarletteRequest, request: StarletteRequest,
@@ -137,9 +141,7 @@ async def srp_init_endpoint(
# Generate server ephemeral # Generate server ephemeral
try: try:
server_public, server_private = await srp_init( server_public, server_private = await srp_init(user.email, user.srp_verifier.hex())
user.email, user.srp_verifier.hex()
)
except ValueError as e: except ValueError as e:
logger.error("SRP init failed for %s: %s", user.email, e) logger.error("SRP init failed for %s: %s", user.email, e)
raise HTTPException( raise HTTPException(
@@ -150,13 +152,15 @@ async def srp_init_endpoint(
# Store session in Redis with 60s TTL # Store session in Redis with 60s TTL
session_id = secrets.token_urlsafe(16) session_id = secrets.token_urlsafe(16)
redis = await get_redis() redis = await get_redis()
session_data = json.dumps({ session_data = json.dumps(
"email": user.email, {
"server_private": server_private, "email": user.email,
"srp_verifier_hex": user.srp_verifier.hex(), "server_private": server_private,
"srp_salt_hex": user.srp_salt.hex(), "srp_verifier_hex": user.srp_verifier.hex(),
"user_id": str(user.id), "srp_salt_hex": user.srp_salt.hex(),
}) "user_id": str(user.id),
}
)
await redis.set(f"srp:session:{session_id}", session_data, ex=60) await redis.set(f"srp:session:{session_id}", session_data, ex=60)
return SRPInitResponse( return SRPInitResponse(
@@ -168,7 +172,11 @@ async def srp_init_endpoint(
) )
@router.post("/srp/verify", response_model=SRPVerifyResponse, summary="SRP Step 2: verify client proof and return tokens") @router.post(
"/srp/verify",
response_model=SRPVerifyResponse,
summary="SRP Step 2: verify client proof and return tokens",
)
@limiter.limit("5/minute") @limiter.limit("5/minute")
async def srp_verify_endpoint( async def srp_verify_endpoint(
request: StarletteRequest, request: StarletteRequest,
@@ -236,7 +244,9 @@ async def srp_verify_endpoint(
# Update last_login and clear upgrade flag on successful SRP login # Update last_login and clear upgrade flag on successful SRP login
await db.execute( await db.execute(
update(User).where(User.id == user.id).values( update(User)
.where(User.id == user.id)
.values(
last_login=datetime.now(UTC), last_login=datetime.now(UTC),
must_upgrade_auth=False, must_upgrade_auth=False,
) )
@@ -323,9 +333,7 @@ async def login(
Rate limited to 5 requests per minute per IP. Rate limited to 5 requests per minute per IP.
""" """
# Look up user by email (case-insensitive) # Look up user by email (case-insensitive)
result = await db.execute( result = await db.execute(select(User).where(User.email == body.email.lower()))
select(User).where(User.email == body.email.lower())
)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
# Generic error — do not reveal whether email exists (no user enumeration) # Generic error — do not reveal whether email exists (no user enumeration)
@@ -389,7 +397,9 @@ async def login(
# Update last_login # Update last_login
await db.execute( await db.execute(
update(User).where(User.id == user.id).values( update(User)
.where(User.id == user.id)
.values(
last_login=datetime.now(UTC), last_login=datetime.now(UTC),
) )
) )
@@ -404,7 +414,10 @@ async def login(
user_id=user.id, user_id=user.id,
action="login_upgrade" if user.must_upgrade_auth else "login", action="login_upgrade" if user.must_upgrade_auth else "login",
resource_type="auth", resource_type="auth",
details={"email": user.email, **({"upgrade": "bcrypt_to_srp"} if user.must_upgrade_auth else {})}, details={
"email": user.email,
**({"upgrade": "bcrypt_to_srp"} if user.must_upgrade_auth else {}),
},
ip_address=request.client.host if request.client else None, ip_address=request.client.host if request.client else None,
) )
await audit_db.commit() await audit_db.commit()
@@ -440,7 +453,9 @@ async def refresh_token(
Rate limited to 10 requests per minute per IP. Rate limited to 10 requests per minute per IP.
""" """
# Resolve token: body takes precedence over cookie # Resolve token: body takes precedence over cookie
raw_token = (body.refresh_token if body and body.refresh_token else None) or refresh_token_cookie raw_token = (
body.refresh_token if body and body.refresh_token else None
) or refresh_token_cookie
if not raw_token: if not raw_token:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@@ -518,7 +533,9 @@ async def refresh_token(
) )
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT, summary="Log out and clear session cookie") @router.post(
"/logout", status_code=status.HTTP_204_NO_CONTENT, summary="Log out and clear session cookie"
)
@limiter.limit("10/minute") @limiter.limit("10/minute")
async def logout( async def logout(
request: StarletteRequest, request: StarletteRequest,
@@ -535,7 +552,10 @@ async def logout(
tenant_id = current_user.tenant_id or uuid.UUID(int=0) tenant_id = current_user.tenant_id or uuid.UUID(int=0)
async with AdminAsyncSessionLocal() as audit_db: async with AdminAsyncSessionLocal() as audit_db:
await log_action( await log_action(
audit_db, tenant_id, current_user.user_id, "logout", audit_db,
tenant_id,
current_user.user_id,
"logout",
resource_type="auth", resource_type="auth",
ip_address=request.client.host if request.client else None, ip_address=request.client.host if request.client else None,
) )
@@ -558,7 +578,11 @@ async def logout(
) )
@router.post("/change-password", response_model=MessageResponse, summary="Change password for authenticated user") @router.post(
"/change-password",
response_model=MessageResponse,
summary="Change password for authenticated user",
)
@limiter.limit("3/minute") @limiter.limit("3/minute")
async def change_password( async def change_password(
request: StarletteRequest, request: StarletteRequest,
@@ -602,7 +626,9 @@ async def change_password(
existing_ks.hkdf_salt = base64.b64decode(body.hkdf_salt or "") existing_ks.hkdf_salt = base64.b64decode(body.hkdf_salt or "")
else: else:
# Legacy bcrypt user — verify current password # Legacy bcrypt user — verify current password
if not user.hashed_password or not verify_password(body.current_password, user.hashed_password): if not user.hashed_password or not verify_password(
body.current_password, user.hashed_password
):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is incorrect", detail="Current password is incorrect",
@@ -822,7 +848,9 @@ async def get_emergency_kit_template(
) )
@router.post("/register-srp", response_model=MessageResponse, summary="Register SRP credentials for a user") @router.post(
"/register-srp", response_model=MessageResponse, summary="Register SRP credentials for a user"
)
@limiter.limit("3/minute") @limiter.limit("3/minute")
async def register_srp( async def register_srp(
request: StarletteRequest, request: StarletteRequest,
@@ -845,7 +873,9 @@ async def register_srp(
# Update user with SRP credentials and clear upgrade flag # Update user with SRP credentials and clear upgrade flag
await db.execute( await db.execute(
update(User).where(User.id == user.id).values( update(User)
.where(User.id == user.id)
.values(
srp_salt=bytes.fromhex(body.srp_salt), srp_salt=bytes.fromhex(body.srp_salt),
srp_verifier=bytes.fromhex(body.srp_verifier), srp_verifier=bytes.fromhex(body.srp_verifier),
auth_version=2, auth_version=2,
@@ -873,8 +903,11 @@ async def register_srp(
try: try:
async with AdminAsyncSessionLocal() as audit_db: async with AdminAsyncSessionLocal() as audit_db:
await log_key_access( await log_key_access(
audit_db, user.tenant_id or uuid.UUID(int=0), user.id, audit_db,
"create_key_set", resource_type="user_key_set", user.tenant_id or uuid.UUID(int=0),
user.id,
"create_key_set",
resource_type="user_key_set",
ip_address=request.client.host if request.client else None, ip_address=request.client.host if request.client else None,
) )
await audit_db.commit() await audit_db.commit()
@@ -901,11 +934,17 @@ async def create_sse_token(
token = secrets.token_urlsafe(32) token = secrets.token_urlsafe(32)
key = f"sse_token:{token}" key = f"sse_token:{token}"
# Store user context for the SSE endpoint to retrieve # Store user context for the SSE endpoint to retrieve
await redis.set(key, json.dumps({ await redis.set(
"user_id": str(current_user.user_id), key,
"tenant_id": str(current_user.tenant_id) if current_user.tenant_id else None, json.dumps(
"role": current_user.role, {
}), ex=30) # 30 second TTL "user_id": str(current_user.user_id),
"tenant_id": str(current_user.tenant_id) if current_user.tenant_id else None,
"role": current_user.role,
}
),
ex=30,
) # 30 second TTL
return {"token": token} return {"token": token}
@@ -977,9 +1016,7 @@ async def forgot_password(
""" """
generic_msg = "If an account with that email exists, a reset link has been sent." generic_msg = "If an account with that email exists, a reset link has been sent."
result = await db.execute( result = await db.execute(select(User).where(User.email == body.email.lower()))
select(User).where(User.email == body.email.lower())
)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user or not user.is_active: if not user or not user.is_active:
@@ -988,9 +1025,7 @@ async def forgot_password(
# Generate a secure token # Generate a secure token
raw_token = secrets.token_urlsafe(32) raw_token = secrets.token_urlsafe(32)
token_hash = _hash_token(raw_token) token_hash = _hash_token(raw_token)
expires_at = datetime.now(UTC) + timedelta( expires_at = datetime.now(UTC) + timedelta(minutes=settings.PASSWORD_RESET_TOKEN_EXPIRE_MINUTES)
minutes=settings.PASSWORD_RESET_TOKEN_EXPIRE_MINUTES
)
# Insert token record (using raw SQL to avoid importing the model globally) # Insert token record (using raw SQL to avoid importing the model globally)
from sqlalchemy import text from sqlalchemy import text

View File

@@ -12,9 +12,7 @@ RBAC: viewer = read-only (GET); tenant_admin and above = mutating actions.
""" """
import json import json
import logging
import uuid import uuid
from datetime import datetime, timezone
import nats import nats
import nats.aio.client import nats.aio.client
@@ -30,7 +28,7 @@ from app.database import get_db, set_tenant_context
from app.middleware.rate_limit import limiter from app.middleware.rate_limit import limiter
from app.middleware.rbac import require_min_role from app.middleware.rbac import require_min_role
from app.middleware.tenant_context import CurrentUser, get_current_user from app.middleware.tenant_context import CurrentUser, get_current_user
from app.models.certificate import CertificateAuthority, DeviceCertificate from app.models.certificate import DeviceCertificate
from app.models.device import Device from app.models.device import Device
from app.schemas.certificate import ( from app.schemas.certificate import (
BulkCertDeployRequest, BulkCertDeployRequest,
@@ -87,13 +85,15 @@ async def _deploy_cert_via_nats(
Dict with success, cert_name_on_device, and error fields. Dict with success, cert_name_on_device, and error fields.
""" """
nc = await _get_nats() nc = await _get_nats()
payload = json.dumps({ payload = json.dumps(
"device_id": device_id, {
"cert_pem": cert_pem, "device_id": device_id,
"key_pem": key_pem, "cert_pem": cert_pem,
"cert_name": cert_name, "key_pem": key_pem,
"ssh_port": ssh_port, "cert_name": cert_name,
}).encode() "ssh_port": ssh_port,
}
).encode()
try: try:
reply = await nc.request( reply = await nc.request(
@@ -121,9 +121,7 @@ async def _get_device_for_tenant(
db: AsyncSession, device_id: uuid.UUID, current_user: CurrentUser db: AsyncSession, device_id: uuid.UUID, current_user: CurrentUser
) -> Device: ) -> Device:
"""Fetch a device and verify tenant ownership.""" """Fetch a device and verify tenant ownership."""
result = await db.execute( result = await db.execute(select(Device).where(Device.id == device_id))
select(Device).where(Device.id == device_id)
)
device = result.scalar_one_or_none() device = result.scalar_one_or_none()
if device is None: if device is None:
raise HTTPException( raise HTTPException(
@@ -164,9 +162,7 @@ async def _get_cert_with_tenant_check(
db: AsyncSession, cert_id: uuid.UUID, tenant_id: uuid.UUID db: AsyncSession, cert_id: uuid.UUID, tenant_id: uuid.UUID
) -> DeviceCertificate: ) -> DeviceCertificate:
"""Fetch a device certificate and verify tenant ownership.""" """Fetch a device certificate and verify tenant ownership."""
result = await db.execute( result = await db.execute(select(DeviceCertificate).where(DeviceCertificate.id == cert_id))
select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
)
cert = result.scalar_one_or_none() cert = result.scalar_one_or_none()
if cert is None: if cert is None:
raise HTTPException( raise HTTPException(
@@ -226,8 +222,12 @@ async def create_ca(
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "ca_create", db,
resource_type="certificate_authority", resource_id=str(ca.id), tenant_id,
current_user.user_id,
"ca_create",
resource_type="certificate_authority",
resource_id=str(ca.id),
details={"common_name": body.common_name, "validity_years": body.validity_years}, details={"common_name": body.common_name, "validity_years": body.validity_years},
) )
except Exception: except Exception:
@@ -332,8 +332,12 @@ async def sign_cert(
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "cert_sign", db,
resource_type="device_certificate", resource_id=str(cert.id), tenant_id,
current_user.user_id,
"cert_sign",
resource_type="device_certificate",
resource_id=str(cert.id),
device_id=body.device_id, device_id=body.device_id,
details={"hostname": device.hostname, "validity_days": body.validity_days}, details={"hostname": device.hostname, "validity_days": body.validity_days},
) )
@@ -404,17 +408,19 @@ async def deploy_cert(
await update_cert_status(db, cert_id, "deployed") await update_cert_status(db, cert_id, "deployed")
# Update device tls_mode to portal_ca # Update device tls_mode to portal_ca
device_result = await db.execute( device_result = await db.execute(select(Device).where(Device.id == cert.device_id))
select(Device).where(Device.id == cert.device_id)
)
device = device_result.scalar_one_or_none() device = device_result.scalar_one_or_none()
if device is not None: if device is not None:
device.tls_mode = "portal_ca" device.tls_mode = "portal_ca"
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "cert_deploy", db,
resource_type="device_certificate", resource_id=str(cert_id), tenant_id,
current_user.user_id,
"cert_deploy",
resource_type="device_certificate",
resource_id=str(cert_id),
device_id=cert.device_id, device_id=cert.device_id,
details={"cert_name_on_device": result.get("cert_name_on_device")}, details={"cert_name_on_device": result.get("cert_name_on_device")},
) )
@@ -528,36 +534,47 @@ async def bulk_deploy(
await update_cert_status(db, issued_cert.id, "deployed") await update_cert_status(db, issued_cert.id, "deployed")
device.tls_mode = "portal_ca" device.tls_mode = "portal_ca"
results.append(CertDeployResponse( results.append(
success=True, CertDeployResponse(
device_id=device_id, success=True,
cert_name_on_device=result.get("cert_name_on_device"), device_id=device_id,
)) cert_name_on_device=result.get("cert_name_on_device"),
)
)
else: else:
await update_cert_status(db, issued_cert.id, "issued") await update_cert_status(db, issued_cert.id, "issued")
results.append(CertDeployResponse( results.append(
success=False, CertDeployResponse(
device_id=device_id, success=False,
error=result.get("error"), device_id=device_id,
)) error=result.get("error"),
)
)
except HTTPException as e: except HTTPException as e:
results.append(CertDeployResponse( results.append(
success=False, CertDeployResponse(
device_id=device_id, success=False,
error=e.detail, device_id=device_id,
)) error=e.detail,
)
)
except Exception as e: except Exception as e:
logger.error("Bulk deploy error", device_id=str(device_id), error=str(e)) logger.error("Bulk deploy error", device_id=str(device_id), error=str(e))
results.append(CertDeployResponse( results.append(
success=False, CertDeployResponse(
device_id=device_id, success=False,
error=str(e), device_id=device_id,
)) error=str(e),
)
)
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "cert_bulk_deploy", db,
tenant_id,
current_user.user_id,
"cert_bulk_deploy",
resource_type="device_certificate", resource_type="device_certificate",
details={ details={
"device_count": len(body.device_ids), "device_count": len(body.device_ids),
@@ -619,17 +636,19 @@ async def revoke_cert(
) )
# Reset device tls_mode to insecure # Reset device tls_mode to insecure
device_result = await db.execute( device_result = await db.execute(select(Device).where(Device.id == cert.device_id))
select(Device).where(Device.id == cert.device_id)
)
device = device_result.scalar_one_or_none() device = device_result.scalar_one_or_none()
if device is not None: if device is not None:
device.tls_mode = "insecure" device.tls_mode = "insecure"
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "cert_revoke", db,
resource_type="device_certificate", resource_id=str(cert_id), tenant_id,
current_user.user_id,
"cert_revoke",
resource_type="device_certificate",
resource_id=str(cert_id),
device_id=cert.device_id, device_id=cert.device_id,
) )
except Exception: except Exception:
@@ -661,9 +680,7 @@ async def rotate_cert(
old_cert = await _get_cert_with_tenant_check(db, cert_id, tenant_id) old_cert = await _get_cert_with_tenant_check(db, cert_id, tenant_id)
# Get the device for hostname/IP # Get the device for hostname/IP
device_result = await db.execute( device_result = await db.execute(select(Device).where(Device.id == old_cert.device_id))
select(Device).where(Device.id == old_cert.device_id)
)
device = device_result.scalar_one_or_none() device = device_result.scalar_one_or_none()
if device is None: if device is None:
raise HTTPException( raise HTTPException(
@@ -722,8 +739,12 @@ async def rotate_cert(
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "cert_rotate", db,
resource_type="device_certificate", resource_id=str(new_cert.id), tenant_id,
current_user.user_id,
"cert_rotate",
resource_type="device_certificate",
resource_id=str(new_cert.id),
device_id=old_cert.device_id, device_id=old_cert.device_id,
details={ details={
"old_cert_id": str(cert_id), "old_cert_id": str(cert_id),

View File

@@ -43,6 +43,7 @@ async def _check_tenant_access(
"""Verify the current user is allowed to access the given tenant.""" """Verify the current user is allowed to access the given tenant."""
if current_user.is_super_admin: if current_user.is_super_admin:
from app.database import set_tenant_context from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id)) await set_tenant_context(db, str(tenant_id))
return return
if current_user.tenant_id != tenant_id: if current_user.tenant_id != tenant_id:
@@ -52,9 +53,7 @@ async def _check_tenant_access(
) )
async def _check_device_online( async def _check_device_online(db: AsyncSession, device_id: uuid.UUID) -> Device:
db: AsyncSession, device_id: uuid.UUID
) -> Device:
"""Verify the device exists and is online. Returns the Device object.""" """Verify the device exists and is online. Returns the Device object."""
result = await db.execute( result = await db.execute(
select(Device).where(Device.id == device_id) # type: ignore[arg-type] select(Device).where(Device.id == device_id) # type: ignore[arg-type]

View File

@@ -25,7 +25,6 @@ import asyncio
import json import json
import logging import logging
import uuid import uuid
from datetime import timezone, datetime
from typing import Any from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
@@ -67,6 +66,7 @@ async def _check_tenant_access(
""" """
if current_user.is_super_admin: if current_user.is_super_admin:
from app.database import set_tenant_context from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id)) await set_tenant_context(db, str(tenant_id))
return return
if current_user.tenant_id != tenant_id: if current_user.tenant_id != tenant_id:
@@ -291,14 +291,14 @@ async def get_export(
try: try:
from app.services.crypto import decrypt_data_transit from app.services.crypto import decrypt_data_transit
plaintext = await decrypt_data_transit( plaintext = await decrypt_data_transit(content_bytes.decode("utf-8"), str(tenant_id))
content_bytes.decode("utf-8"), str(tenant_id)
)
content_bytes = plaintext.encode("utf-8") content_bytes = plaintext.encode("utf-8")
except Exception as dec_err: except Exception as dec_err:
logger.error( logger.error(
"Failed to decrypt export for device %s sha %s: %s", "Failed to decrypt export for device %s sha %s: %s",
device_id, commit_sha, dec_err, device_id,
commit_sha,
dec_err,
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -370,7 +370,9 @@ async def get_binary(
except Exception as dec_err: except Exception as dec_err:
logger.error( logger.error(
"Failed to decrypt binary backup for device %s sha %s: %s", "Failed to decrypt binary backup for device %s sha %s: %s",
device_id, commit_sha, dec_err, device_id,
commit_sha,
dec_err,
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -380,9 +382,7 @@ async def get_binary(
return Response( return Response(
content=content_bytes, content=content_bytes,
media_type="application/octet-stream", media_type="application/octet-stream",
headers={ headers={"Content-Disposition": f'attachment; filename="backup-{commit_sha[:8]}.bin"'},
"Content-Disposition": f'attachment; filename="backup-{commit_sha[:8]}.bin"'
},
) )
@@ -445,6 +445,7 @@ async def preview_restore(
key, key,
) )
import json import json
creds = json.loads(creds_json) creds = json.loads(creds_json)
current_text = await backup_service.capture_export( current_text = await backup_service.capture_export(
device.ip_address, device.ip_address,
@@ -578,9 +579,7 @@ async def emergency_rollback(
.where( .where(
ConfigBackupRun.device_id == device_id, # type: ignore[arg-type] ConfigBackupRun.device_id == device_id, # type: ignore[arg-type]
ConfigBackupRun.tenant_id == tenant_id, # type: ignore[arg-type] ConfigBackupRun.tenant_id == tenant_id, # type: ignore[arg-type]
ConfigBackupRun.trigger_type.in_( ConfigBackupRun.trigger_type.in_(["pre-restore", "checkpoint", "pre-template-push"]),
["pre-restore", "checkpoint", "pre-template-push"]
),
) )
.order_by(ConfigBackupRun.created_at.desc()) .order_by(ConfigBackupRun.created_at.desc())
.limit(1) .limit(1)
@@ -735,6 +734,7 @@ async def update_schedule(
# Hot-reload the scheduler so changes take effect immediately # Hot-reload the scheduler so changes take effect immediately
from app.services.backup_scheduler import on_schedule_change from app.services.backup_scheduler import on_schedule_change
await on_schedule_change(tenant_id, device_id) await on_schedule_change(tenant_id, device_id)
return { return {
@@ -758,6 +758,7 @@ async def _get_nats():
Reuses the same lazy-init pattern as routeros_proxy._get_nats(). Reuses the same lazy-init pattern as routeros_proxy._get_nats().
""" """
from app.services.routeros_proxy import _get_nats as _proxy_get_nats from app.services.routeros_proxy import _get_nats as _proxy_get_nats
return await _proxy_get_nats() return await _proxy_get_nats()
@@ -839,6 +840,7 @@ async def trigger_config_snapshot(
if reply_data.get("status") == "success": if reply_data.get("status") == "success":
try: try:
from app.services.audit_service import log_action from app.services.audit_service import log_action
await log_action( await log_action(
db, db,
tenant_id, tenant_id,

View File

@@ -64,9 +64,7 @@ async def _check_tenant_access(
await set_tenant_context(db, str(tenant_id)) await set_tenant_context(db, str(tenant_id))
async def _check_device_online( async def _check_device_online(db: AsyncSession, device_id: uuid.UUID) -> Device:
db: AsyncSession, device_id: uuid.UUID
) -> Device:
"""Verify the device exists and is online. Returns the Device object.""" """Verify the device exists and is online. Returns the Device object."""
result = await db.execute( result = await db.execute(
select(Device).where(Device.id == device_id) # type: ignore[arg-type] select(Device).where(Device.id == device_id) # type: ignore[arg-type]
@@ -201,8 +199,12 @@ async def add_entry(
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "config_add", db,
resource_type="config", resource_id=str(device_id), tenant_id,
current_user.user_id,
"config_add",
resource_type="config",
resource_id=str(device_id),
device_id=device_id, device_id=device_id,
details={"path": body.path, "properties": body.properties}, details={"path": body.path, "properties": body.properties},
) )
@@ -255,8 +257,12 @@ async def set_entry(
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "config_set", db,
resource_type="config", resource_id=str(device_id), tenant_id,
current_user.user_id,
"config_set",
resource_type="config",
resource_id=str(device_id),
device_id=device_id, device_id=device_id,
details={"path": body.path, "entry_id": body.entry_id, "properties": body.properties}, details={"path": body.path, "entry_id": body.entry_id, "properties": body.properties},
) )
@@ -286,9 +292,7 @@ async def remove_entry(
await _check_device_online(db, device_id) await _check_device_online(db, device_id)
check_path_safety(body.path, write=True) check_path_safety(body.path, write=True)
result = await routeros_proxy.remove_entry( result = await routeros_proxy.remove_entry(str(device_id), body.path, body.entry_id)
str(device_id), body.path, body.entry_id
)
if not result.get("success"): if not result.get("success"):
raise HTTPException( raise HTTPException(
@@ -309,8 +313,12 @@ async def remove_entry(
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "config_remove", db,
resource_type="config", resource_id=str(device_id), tenant_id,
current_user.user_id,
"config_remove",
resource_type="config",
resource_id=str(device_id),
device_id=device_id, device_id=device_id,
details={"path": body.path, "entry_id": body.entry_id}, details={"path": body.path, "entry_id": body.entry_id},
) )
@@ -360,8 +368,12 @@ async def execute_command(
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "config_execute", db,
resource_type="config", resource_id=str(device_id), tenant_id,
current_user.user_id,
"config_execute",
resource_type="config",
resource_id=str(device_id),
device_id=device_id, device_id=device_id,
details={"command": body.command}, details={"command": body.command},
) )

View File

@@ -43,6 +43,7 @@ async def _check_tenant_access(
""" """
if current_user.is_super_admin: if current_user.is_super_admin:
from app.database import set_tenant_context from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id)) await set_tenant_context(db, str(tenant_id))
return return
if current_user.tenant_id != tenant_id: if current_user.tenant_id != tenant_id:
@@ -115,9 +116,7 @@ async def view_snapshot(
session=db, session=db,
) )
except Exception: except Exception:
logger.exception( logger.exception("Failed to decrypt snapshot %s for device %s", snapshot_id, device_id)
"Failed to decrypt snapshot %s for device %s", snapshot_id, device_id
)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to decrypt snapshot content", detail="Failed to decrypt snapshot content",

View File

@@ -29,12 +29,14 @@ router = APIRouter(tags=["device-logs"])
# Helpers (same pattern as config_editor.py) # Helpers (same pattern as config_editor.py)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def _check_tenant_access( async def _check_tenant_access(
current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None: ) -> None:
"""Verify the current user is allowed to access the given tenant.""" """Verify the current user is allowed to access the given tenant."""
if current_user.is_super_admin: if current_user.is_super_admin:
from app.database import set_tenant_context from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id)) await set_tenant_context(db, str(tenant_id))
return return
if current_user.tenant_id != tenant_id: if current_user.tenant_id != tenant_id:
@@ -44,16 +46,12 @@ async def _check_tenant_access(
) )
async def _check_device_exists( async def _check_device_exists(db: AsyncSession, device_id: uuid.UUID) -> None:
db: AsyncSession, device_id: uuid.UUID
) -> None:
"""Verify the device exists (does not require online status for logs).""" """Verify the device exists (does not require online status for logs)."""
from sqlalchemy import select from sqlalchemy import select
from app.models.device import Device from app.models.device import Device
result = await db.execute( result = await db.execute(select(Device).where(Device.id == device_id))
select(Device).where(Device.id == device_id)
)
device = result.scalar_one_or_none() device = result.scalar_one_or_none()
if device is None: if device is None:
raise HTTPException( raise HTTPException(
@@ -66,6 +64,7 @@ async def _check_device_exists(
# Response model # Response model
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class LogEntry(BaseModel): class LogEntry(BaseModel):
time: str time: str
topics: str topics: str
@@ -82,6 +81,7 @@ class LogsResponse(BaseModel):
# Endpoint # Endpoint
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get( @router.get(
"/tenants/{tenant_id}/devices/{device_id}/logs", "/tenants/{tenant_id}/devices/{device_id}/logs",
response_model=LogsResponse, response_model=LogsResponse,

View File

@@ -72,9 +72,7 @@ async def update_tag(
) -> DeviceTagResponse: ) -> DeviceTagResponse:
"""Update a device tag. Requires operator role or above.""" """Update a device tag. Requires operator role or above."""
await _check_tenant_access(current_user, tenant_id, db) await _check_tenant_access(current_user, tenant_id, db)
return await device_service.update_tag( return await device_service.update_tag(db=db, tenant_id=tenant_id, tag_id=tag_id, data=data)
db=db, tenant_id=tenant_id, tag_id=tag_id, data=data
)
@router.delete( @router.delete(

View File

@@ -23,7 +23,6 @@ from app.database import get_db
from app.middleware.rate_limit import limiter from app.middleware.rate_limit import limiter
from app.services.audit_service import log_action from app.services.audit_service import log_action
from app.middleware.rbac import ( from app.middleware.rbac import (
require_min_role,
require_operator_or_above, require_operator_or_above,
require_scope, require_scope,
require_tenant_admin_or_above, require_tenant_admin_or_above,
@@ -57,6 +56,7 @@ async def _check_tenant_access(
if current_user.is_super_admin: if current_user.is_super_admin:
# Re-set tenant context to the target tenant so RLS allows the operation # Re-set tenant context to the target tenant so RLS allows the operation
from app.database import set_tenant_context from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id)) await set_tenant_context(db, str(tenant_id))
return return
if current_user.tenant_id != tenant_id: if current_user.tenant_id != tenant_id:
@@ -138,8 +138,12 @@ async def create_device(
) )
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "device_create", db,
resource_type="device", resource_id=str(result.id), tenant_id,
current_user.user_id,
"device_create",
resource_type="device",
resource_id=str(result.id),
details={"hostname": data.hostname, "ip_address": data.ip_address}, details={"hostname": data.hostname, "ip_address": data.ip_address},
ip_address=request.client.host if request.client else None, ip_address=request.client.host if request.client else None,
) )
@@ -191,8 +195,12 @@ async def update_device(
) )
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "device_update", db,
resource_type="device", resource_id=str(device_id), tenant_id,
current_user.user_id,
"device_update",
resource_type="device",
resource_id=str(device_id),
device_id=device_id, device_id=device_id,
details={"changes": data.model_dump(exclude_unset=True)}, details={"changes": data.model_dump(exclude_unset=True)},
ip_address=request.client.host if request.client else None, ip_address=request.client.host if request.client else None,
@@ -220,8 +228,12 @@ async def delete_device(
await _check_tenant_access(current_user, tenant_id, db) await _check_tenant_access(current_user, tenant_id, db)
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "device_delete", db,
resource_type="device", resource_id=str(device_id), tenant_id,
current_user.user_id,
"device_delete",
resource_type="device",
resource_id=str(device_id),
device_id=device_id, device_id=device_id,
ip_address=request.client.host if request.client else None, ip_address=request.client.host if request.client else None,
) )
@@ -262,14 +274,21 @@ async def scan_devices(
discovered = await scan_subnet(data.cidr) discovered = await scan_subnet(data.cidr)
import ipaddress import ipaddress
network = ipaddress.ip_network(data.cidr, strict=False) network = ipaddress.ip_network(data.cidr, strict=False)
total_scanned = network.num_addresses - 2 if network.num_addresses > 2 else network.num_addresses total_scanned = (
network.num_addresses - 2 if network.num_addresses > 2 else network.num_addresses
)
# Audit log the scan (fire-and-forget — never breaks the response) # Audit log the scan (fire-and-forget — never breaks the response)
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "subnet_scan", db,
resource_type="network", resource_id=data.cidr, tenant_id,
current_user.user_id,
"subnet_scan",
resource_type="network",
resource_id=data.cidr,
details={ details={
"cidr": data.cidr, "cidr": data.cidr,
"devices_found": len(discovered), "devices_found": len(discovered),
@@ -322,10 +341,12 @@ async def bulk_add_devices(
password = dev_data.password or data.shared_password password = dev_data.password or data.shared_password
if not username or not password: if not username or not password:
failed.append({ failed.append(
"ip_address": dev_data.ip_address, {
"error": "No credentials provided (set per-device or shared credentials)", "ip_address": dev_data.ip_address,
}) "error": "No credentials provided (set per-device or shared credentials)",
}
)
continue continue
create_data = DeviceCreate( create_data = DeviceCreate(
@@ -347,9 +368,16 @@ async def bulk_add_devices(
added.append(device) added.append(device)
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "device_adopt", db,
resource_type="device", resource_id=str(device.id), tenant_id,
details={"hostname": create_data.hostname, "ip_address": create_data.ip_address}, current_user.user_id,
"device_adopt",
resource_type="device",
resource_id=str(device.id),
details={
"hostname": create_data.hostname,
"ip_address": create_data.ip_address,
},
ip_address=request.client.host if request.client else None, ip_address=request.client.host if request.client else None,
) )
except Exception: except Exception:

View File

@@ -90,16 +90,18 @@ async def list_events(
for row in alert_result.fetchall(): for row in alert_result.fetchall():
alert_status = row[1] or "firing" alert_status = row[1] or "firing"
metric = row[3] or "unknown" metric = row[3] or "unknown"
events.append({ events.append(
"id": str(row[0]), {
"event_type": "alert", "id": str(row[0]),
"severity": row[2], "event_type": "alert",
"title": f"{alert_status}: {metric}", "severity": row[2],
"description": row[4] or f"Alert {alert_status} for {metric}", "title": f"{alert_status}: {metric}",
"device_hostname": row[7], "description": row[4] or f"Alert {alert_status} for {metric}",
"device_id": str(row[6]) if row[6] else None, "device_hostname": row[7],
"timestamp": row[5].isoformat() if row[5] else None, "device_id": str(row[6]) if row[6] else None,
}) "timestamp": row[5].isoformat() if row[5] else None,
}
)
# 2. Device status changes (inferred from current status + last_seen) # 2. Device status changes (inferred from current status + last_seen)
if not event_type or event_type == "status_change": if not event_type or event_type == "status_change":
@@ -117,16 +119,18 @@ async def list_events(
device_status = row[2] or "unknown" device_status = row[2] or "unknown"
hostname = row[1] or "Unknown device" hostname = row[1] or "Unknown device"
severity = "info" if device_status == "online" else "warning" severity = "info" if device_status == "online" else "warning"
events.append({ events.append(
"id": f"status-{row[0]}", {
"event_type": "status_change", "id": f"status-{row[0]}",
"severity": severity, "event_type": "status_change",
"title": f"Device {device_status}", "severity": severity,
"description": f"{hostname} is now {device_status}", "title": f"Device {device_status}",
"device_hostname": hostname, "description": f"{hostname} is now {device_status}",
"device_id": str(row[0]), "device_hostname": hostname,
"timestamp": row[3].isoformat() if row[3] else None, "device_id": str(row[0]),
}) "timestamp": row[3].isoformat() if row[3] else None,
}
)
# 3. Config backup runs # 3. Config backup runs
if not event_type or event_type == "config_backup": if not event_type or event_type == "config_backup":
@@ -144,16 +148,18 @@ async def list_events(
for row in backup_result.fetchall(): for row in backup_result.fetchall():
trigger_type = row[1] or "manual" trigger_type = row[1] or "manual"
hostname = row[4] or "Unknown device" hostname = row[4] or "Unknown device"
events.append({ events.append(
"id": str(row[0]), {
"event_type": "config_backup", "id": str(row[0]),
"severity": "info", "event_type": "config_backup",
"title": "Config backup", "severity": "info",
"description": f"{trigger_type} backup completed for {hostname}", "title": "Config backup",
"device_hostname": hostname, "description": f"{trigger_type} backup completed for {hostname}",
"device_id": str(row[3]) if row[3] else None, "device_hostname": hostname,
"timestamp": row[2].isoformat() if row[2] else None, "device_id": str(row[3]) if row[3] else None,
}) "timestamp": row[2].isoformat() if row[2] else None,
}
)
# Sort all events by timestamp descending, then apply final limit # Sort all events by timestamp descending, then apply final limit
events.sort( events.sort(

View File

@@ -67,6 +67,7 @@ async def get_firmware_overview(
await _check_tenant_access(current_user, tenant_id, db) await _check_tenant_access(current_user, tenant_id, db)
from app.services.firmware_service import get_firmware_overview as _get_overview from app.services.firmware_service import get_firmware_overview as _get_overview
return await _get_overview(str(tenant_id)) return await _get_overview(str(tenant_id))
@@ -206,6 +207,7 @@ async def trigger_firmware_check(
raise HTTPException(status_code=403, detail="Super admin only") raise HTTPException(status_code=403, detail="Super admin only")
from app.services.firmware_service import check_latest_versions from app.services.firmware_service import check_latest_versions
results = await check_latest_versions() results = await check_latest_versions()
return {"status": "ok", "versions_discovered": len(results), "versions": results} return {"status": "ok", "versions_discovered": len(results), "versions": results}
@@ -221,6 +223,7 @@ async def list_firmware_cache(
raise HTTPException(status_code=403, detail="Super admin only") raise HTTPException(status_code=403, detail="Super admin only")
from app.services.firmware_service import get_cached_firmware from app.services.firmware_service import get_cached_firmware
return await get_cached_firmware() return await get_cached_firmware()
@@ -236,6 +239,7 @@ async def download_firmware(
raise HTTPException(status_code=403, detail="Super admin only") raise HTTPException(status_code=403, detail="Super admin only")
from app.services.firmware_service import download_firmware as _download from app.services.firmware_service import download_firmware as _download
path = await _download(body.architecture, body.channel, body.version) path = await _download(body.architecture, body.channel, body.version)
return {"status": "ok", "path": path} return {"status": "ok", "path": path}
@@ -324,15 +328,21 @@ async def start_firmware_upgrade(
# Schedule or start immediately # Schedule or start immediately
if body.scheduled_at: if body.scheduled_at:
from app.services.upgrade_service import schedule_upgrade from app.services.upgrade_service import schedule_upgrade
schedule_upgrade(job_id, datetime.fromisoformat(body.scheduled_at)) schedule_upgrade(job_id, datetime.fromisoformat(body.scheduled_at))
else: else:
from app.services.upgrade_service import start_upgrade from app.services.upgrade_service import start_upgrade
asyncio.create_task(start_upgrade(job_id)) asyncio.create_task(start_upgrade(job_id))
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "firmware_upgrade", db,
resource_type="firmware", resource_id=job_id, tenant_id,
current_user.user_id,
"firmware_upgrade",
resource_type="firmware",
resource_id=job_id,
device_id=uuid.UUID(body.device_id), device_id=uuid.UUID(body.device_id),
details={"target_version": body.target_version, "channel": body.channel}, details={"target_version": body.target_version, "channel": body.channel},
) )
@@ -406,9 +416,11 @@ async def start_mass_firmware_upgrade(
# Schedule or start immediately # Schedule or start immediately
if body.scheduled_at: if body.scheduled_at:
from app.services.upgrade_service import schedule_mass_upgrade from app.services.upgrade_service import schedule_mass_upgrade
schedule_mass_upgrade(rollout_group_id, datetime.fromisoformat(body.scheduled_at)) schedule_mass_upgrade(rollout_group_id, datetime.fromisoformat(body.scheduled_at))
else: else:
from app.services.upgrade_service import start_mass_upgrade from app.services.upgrade_service import start_mass_upgrade
asyncio.create_task(start_mass_upgrade(rollout_group_id)) asyncio.create_task(start_mass_upgrade(rollout_group_id))
return { return {
@@ -639,6 +651,7 @@ async def cancel_upgrade_endpoint(
raise HTTPException(403, "Viewers cannot cancel upgrades") raise HTTPException(403, "Viewers cannot cancel upgrades")
from app.services.upgrade_service import cancel_upgrade from app.services.upgrade_service import cancel_upgrade
await cancel_upgrade(str(job_id)) await cancel_upgrade(str(job_id))
return {"status": "ok", "message": "Upgrade cancelled"} return {"status": "ok", "message": "Upgrade cancelled"}
@@ -662,6 +675,7 @@ async def retry_upgrade_endpoint(
raise HTTPException(403, "Viewers cannot retry upgrades") raise HTTPException(403, "Viewers cannot retry upgrades")
from app.services.upgrade_service import retry_failed_upgrade from app.services.upgrade_service import retry_failed_upgrade
await retry_failed_upgrade(str(job_id)) await retry_failed_upgrade(str(job_id))
return {"status": "ok", "message": "Upgrade retry started"} return {"status": "ok", "message": "Upgrade retry started"}
@@ -685,6 +699,7 @@ async def resume_rollout_endpoint(
raise HTTPException(403, "Viewers cannot resume rollouts") raise HTTPException(403, "Viewers cannot resume rollouts")
from app.services.upgrade_service import resume_mass_upgrade from app.services.upgrade_service import resume_mass_upgrade
await resume_mass_upgrade(str(rollout_group_id)) await resume_mass_upgrade(str(rollout_group_id))
return {"status": "ok", "message": "Rollout resumed"} return {"status": "ok", "message": "Rollout resumed"}
@@ -708,5 +723,6 @@ async def abort_rollout_endpoint(
raise HTTPException(403, "Viewers cannot abort rollouts") raise HTTPException(403, "Viewers cannot abort rollouts")
from app.services.upgrade_service import abort_mass_upgrade from app.services.upgrade_service import abort_mass_upgrade
aborted = await abort_mass_upgrade(str(rollout_group_id)) aborted = await abort_mass_upgrade(str(rollout_group_id))
return {"status": "ok", "aborted_count": aborted} return {"status": "ok", "aborted_count": aborted}

View File

@@ -299,9 +299,7 @@ async def delete_maintenance_window(
_require_operator(current_user) _require_operator(current_user)
result = await db.execute( result = await db.execute(
text( text("DELETE FROM maintenance_windows WHERE id = CAST(:id AS uuid) RETURNING id"),
"DELETE FROM maintenance_windows WHERE id = CAST(:id AS uuid) RETURNING id"
),
{"id": str(window_id)}, {"id": str(window_id)},
) )
if not result.fetchone(): if not result.fetchone():

View File

@@ -65,6 +65,7 @@ async def _check_tenant_access(
if current_user.is_super_admin: if current_user.is_super_admin:
# Re-set tenant context to the target tenant so RLS allows the operation # Re-set tenant context to the target tenant so RLS allows the operation
from app.database import set_tenant_context from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id)) await set_tenant_context(db, str(tenant_id))
return return
if current_user.tenant_id != tenant_id: if current_user.tenant_id != tenant_id:

View File

@@ -81,9 +81,12 @@ async def _get_device(db: AsyncSession, tenant_id: uuid.UUID, device_id: uuid.UU
return device return device
async def _check_tenant_access(current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession) -> None: async def _check_tenant_access(
current_user: CurrentUser, tenant_id: uuid.UUID, db: AsyncSession
) -> None:
if current_user.is_super_admin: if current_user.is_super_admin:
from app.database import set_tenant_context from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id)) await set_tenant_context(db, str(tenant_id))
return return
if current_user.tenant_id != tenant_id: if current_user.tenant_id != tenant_id:
@@ -124,8 +127,12 @@ async def open_winbox_session(
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "winbox_tunnel_open", db,
resource_type="device", resource_id=str(device_id), tenant_id,
current_user.user_id,
"winbox_tunnel_open",
resource_type="device",
resource_id=str(device_id),
device_id=device_id, device_id=device_id,
details={"source_ip": source_ip}, details={"source_ip": source_ip},
ip_address=source_ip, ip_address=source_ip,
@@ -133,24 +140,31 @@ async def open_winbox_session(
except Exception: except Exception:
pass pass
payload = json.dumps({ payload = json.dumps(
"device_id": str(device_id), {
"tenant_id": str(tenant_id), "device_id": str(device_id),
"user_id": str(current_user.user_id), "tenant_id": str(tenant_id),
"target_port": 8291, "user_id": str(current_user.user_id),
}).encode() "target_port": 8291,
}
).encode()
try: try:
nc = await _get_nats() nc = await _get_nats()
msg = await nc.request("tunnel.open", payload, timeout=10) msg = await nc.request("tunnel.open", payload, timeout=10)
except Exception as exc: except Exception as exc:
logger.error("NATS tunnel.open failed: %s", exc) logger.error("NATS tunnel.open failed: %s", exc)
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Tunnel service unavailable") raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Tunnel service unavailable"
)
try: try:
data = json.loads(msg.data) data = json.loads(msg.data)
except Exception: except Exception:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Invalid response from tunnel service") raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Invalid response from tunnel service",
)
if "error" in data: if "error" in data:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=data["error"]) raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=data["error"])
@@ -158,11 +172,16 @@ async def open_winbox_session(
port = data.get("local_port") port = data.get("local_port")
tunnel_id = data.get("tunnel_id", "") tunnel_id = data.get("tunnel_id", "")
if not isinstance(port, int) or not (49000 <= port <= 49100): if not isinstance(port, int) or not (49000 <= port <= 49100):
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Invalid port allocation from tunnel service") raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Invalid port allocation from tunnel service",
)
# Derive the tunnel host from the request so remote clients get the server's # Derive the tunnel host from the request so remote clients get the server's
# address rather than 127.0.0.1 (which would point to the user's own machine). # address rather than 127.0.0.1 (which would point to the user's own machine).
tunnel_host = (request.headers.get("x-forwarded-host") or request.headers.get("host") or "127.0.0.1") tunnel_host = (
request.headers.get("x-forwarded-host") or request.headers.get("host") or "127.0.0.1"
)
# Strip port from host header (e.g. "10.101.0.175:8001" → "10.101.0.175") # Strip port from host header (e.g. "10.101.0.175:8001" → "10.101.0.175")
tunnel_host = tunnel_host.split(":")[0] tunnel_host = tunnel_host.split(":")[0]
@@ -213,8 +232,12 @@ async def open_ssh_session(
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "ssh_session_open", db,
resource_type="device", resource_id=str(device_id), tenant_id,
current_user.user_id,
"ssh_session_open",
resource_type="device",
resource_id=str(device_id),
device_id=device_id, device_id=device_id,
details={"source_ip": source_ip, "cols": body.cols, "rows": body.rows}, details={"source_ip": source_ip, "cols": body.cols, "rows": body.rows},
ip_address=source_ip, ip_address=source_ip,
@@ -223,22 +246,26 @@ async def open_ssh_session(
pass pass
token = secrets.token_urlsafe(32) token = secrets.token_urlsafe(32)
token_payload = json.dumps({ token_payload = json.dumps(
"device_id": str(device_id), {
"tenant_id": str(tenant_id), "device_id": str(device_id),
"user_id": str(current_user.user_id), "tenant_id": str(tenant_id),
"source_ip": source_ip, "user_id": str(current_user.user_id),
"cols": body.cols, "source_ip": source_ip,
"rows": body.rows, "cols": body.cols,
"created_at": int(time.time()), "rows": body.rows,
}) "created_at": int(time.time()),
}
)
try: try:
rd = await _get_redis() rd = await _get_redis()
await rd.setex(f"ssh:token:{token}", 120, token_payload) await rd.setex(f"ssh:token:{token}", 120, token_payload)
except Exception as exc: except Exception as exc:
logger.error("Redis setex failed for SSH token: %s", exc) logger.error("Redis setex failed for SSH token: %s", exc)
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Session store unavailable") raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Session store unavailable"
)
return SSHSessionResponse( return SSHSessionResponse(
token=token, token=token,
@@ -274,8 +301,12 @@ async def close_winbox_session(
try: try:
await log_action( await log_action(
db, tenant_id, current_user.user_id, "winbox_tunnel_close", db,
resource_type="device", resource_id=str(device_id), tenant_id,
current_user.user_id,
"winbox_tunnel_close",
resource_type="device",
resource_id=str(device_id),
device_id=device_id, device_id=device_id,
details={"tunnel_id": tunnel_id, "source_ip": source_ip}, details={"tunnel_id": tunnel_id, "source_ip": source_ip},
ip_address=source_ip, ip_address=source_ip,

View File

@@ -72,7 +72,11 @@ async def _set_system_settings(updates: dict, user_id: str) -> None:
ON CONFLICT (key) DO UPDATE ON CONFLICT (key) DO UPDATE
SET value = :value, updated_by = CAST(:user_id AS uuid), updated_at = now() SET value = :value, updated_by = CAST(:user_id AS uuid), updated_at = now()
"""), """),
{"key": key, "value": str(value) if value is not None else None, "user_id": user_id}, {
"key": key,
"value": str(value) if value is not None else None,
"user_id": user_id,
},
) )
await session.commit() await session.commit()
@@ -100,7 +104,8 @@ async def get_smtp_settings(user=Depends(require_role("super_admin"))):
"smtp_host": db_settings.get("smtp_host") or settings.SMTP_HOST, "smtp_host": db_settings.get("smtp_host") or settings.SMTP_HOST,
"smtp_port": int(db_settings.get("smtp_port") or settings.SMTP_PORT), "smtp_port": int(db_settings.get("smtp_port") or settings.SMTP_PORT),
"smtp_user": db_settings.get("smtp_user") or settings.SMTP_USER or "", "smtp_user": db_settings.get("smtp_user") or settings.SMTP_USER or "",
"smtp_use_tls": (db_settings.get("smtp_use_tls") or str(settings.SMTP_USE_TLS)).lower() == "true", "smtp_use_tls": (db_settings.get("smtp_use_tls") or str(settings.SMTP_USE_TLS)).lower()
== "true",
"smtp_from_address": db_settings.get("smtp_from_address") or settings.SMTP_FROM_ADDRESS, "smtp_from_address": db_settings.get("smtp_from_address") or settings.SMTP_FROM_ADDRESS,
"smtp_provider": db_settings.get("smtp_provider") or "custom", "smtp_provider": db_settings.get("smtp_provider") or "custom",
"smtp_password_set": bool(db_settings.get("smtp_password") or settings.SMTP_PASSWORD), "smtp_password_set": bool(db_settings.get("smtp_password") or settings.SMTP_PASSWORD),

View File

@@ -32,6 +32,7 @@ async def _get_sse_redis() -> aioredis.Redis:
global _redis global _redis
if _redis is None: if _redis is None:
from app.config import settings from app.config import settings
_redis = aioredis.from_url(settings.REDIS_URL, decode_responses=True) _redis = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
return _redis return _redis
@@ -70,7 +71,9 @@ async def _validate_sse_token(token: str) -> dict:
async def event_stream( async def event_stream(
request: Request, request: Request,
tenant_id: uuid.UUID, tenant_id: uuid.UUID,
token: str = Query(..., description="Short-lived SSE exchange token (from POST /auth/sse-token)"), token: str = Query(
..., description="Short-lived SSE exchange token (from POST /auth/sse-token)"
),
) -> EventSourceResponse: ) -> EventSourceResponse:
"""Stream real-time events for a tenant via Server-Sent Events. """Stream real-time events for a tenant via Server-Sent Events.
@@ -87,7 +90,9 @@ async def event_stream(
user_id = user_context.get("user_id", "") user_id = user_context.get("user_id", "")
# Authorization: user must belong to the requested tenant or be super_admin # Authorization: user must belong to the requested tenant or be super_admin
if user_role != "super_admin" and (user_tenant_id is None or str(user_tenant_id) != str(tenant_id)): if user_role != "super_admin" and (
user_tenant_id is None or str(user_tenant_id) != str(tenant_id)
):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized for this tenant", detail="Not authorized for this tenant",

View File

@@ -21,7 +21,6 @@ RBAC: viewer = read (GET/preview); operator and above = write (POST/PUT/DELETE/p
import asyncio import asyncio
import logging import logging
import uuid import uuid
from datetime import datetime, timezone
from typing import Any, Optional from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
@@ -54,6 +53,7 @@ async def _check_tenant_access(
"""Verify the current user is allowed to access the given tenant.""" """Verify the current user is allowed to access the given tenant."""
if current_user.is_super_admin: if current_user.is_super_admin:
from app.database import set_tenant_context from app.database import set_tenant_context
await set_tenant_context(db, str(tenant_id)) await set_tenant_context(db, str(tenant_id))
return return
if current_user.tenant_id != tenant_id: if current_user.tenant_id != tenant_id:
@@ -191,7 +191,8 @@ async def create_template(
if unmatched: if unmatched:
logger.warning( logger.warning(
"Template '%s' has undeclared variables: %s (auto-adding as string type)", "Template '%s' has undeclared variables: %s (auto-adding as string type)",
body.name, unmatched, body.name,
unmatched,
) )
# Create template # Create template
@@ -553,11 +554,13 @@ async def push_template(
status="pending", status="pending",
) )
db.add(job) db.add(job)
jobs_created.append({ jobs_created.append(
"job_id": str(job.id), {
"device_id": str(device.id), "job_id": str(job.id),
"device_hostname": device.hostname, "device_id": str(device.id),
}) "device_hostname": device.hostname,
}
)
await db.flush() await db.flush()
@@ -598,14 +601,16 @@ async def push_status(
jobs = [] jobs = []
for job, hostname in rows: for job, hostname in rows:
jobs.append({ jobs.append(
"device_id": str(job.device_id), {
"hostname": hostname, "device_id": str(job.device_id),
"status": job.status, "hostname": hostname,
"error_message": job.error_message, "status": job.status,
"started_at": job.started_at.isoformat() if job.started_at else None, "error_message": job.error_message,
"completed_at": job.completed_at.isoformat() if job.completed_at else None, "started_at": job.started_at.isoformat() if job.started_at else None,
}) "completed_at": job.completed_at.isoformat() if job.completed_at else None,
}
)
return { return {
"rollout_id": str(rollout_id), "rollout_id": str(rollout_id),

View File

@@ -9,7 +9,6 @@ DELETE /api/tenants/{id} — delete tenant (super_admin only)
""" """
import uuid import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy import func, select, text from sqlalchemy import func, select, text
@@ -17,7 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.middleware.rate_limit import limiter from app.middleware.rate_limit import limiter
from app.database import get_admin_db, get_db from app.database import get_admin_db
from app.middleware.rbac import require_super_admin, require_tenant_admin_or_above from app.middleware.rbac import require_super_admin, require_tenant_admin_or_above
from app.middleware.tenant_context import CurrentUser from app.middleware.tenant_context import CurrentUser
from app.models.device import Device from app.models.device import Device
@@ -70,15 +69,18 @@ async def list_tenants(
else: else:
if not current_user.tenant_id: if not current_user.tenant_id:
return [] return []
result = await db.execute( result = await db.execute(select(Tenant).where(Tenant.id == current_user.tenant_id))
select(Tenant).where(Tenant.id == current_user.tenant_id)
)
tenants = result.scalars().all() tenants = result.scalars().all()
return [await _get_tenant_response(tenant, db) for tenant in tenants] return [await _get_tenant_response(tenant, db) for tenant in tenants]
@router.post("", response_model=TenantResponse, status_code=status.HTTP_201_CREATED, summary="Create a tenant") @router.post(
"",
response_model=TenantResponse,
status_code=status.HTTP_201_CREATED,
summary="Create a tenant",
)
@limiter.limit("20/minute") @limiter.limit("20/minute")
async def create_tenant( async def create_tenant(
request: Request, request: Request,
@@ -108,13 +110,21 @@ async def create_tenant(
("Device Offline", "device_offline", "eq", 1, 1, "critical"), ("Device Offline", "device_offline", "eq", 1, 1, "critical"),
] ]
for name, metric, operator, threshold, duration, sev in default_rules: for name, metric, operator, threshold, duration, sev in default_rules:
await db.execute(text(""" await db.execute(
text("""
INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default) INSERT INTO alert_rules (id, tenant_id, name, metric, operator, threshold, duration_polls, severity, enabled, is_default)
VALUES (gen_random_uuid(), CAST(:tenant_id AS uuid), :name, :metric, :operator, :threshold, :duration, :severity, TRUE, TRUE) VALUES (gen_random_uuid(), CAST(:tenant_id AS uuid), :name, :metric, :operator, :threshold, :duration, :severity, TRUE, TRUE)
"""), { """),
"tenant_id": str(tenant.id), "name": name, "metric": metric, {
"operator": operator, "threshold": threshold, "duration": duration, "severity": sev, "tenant_id": str(tenant.id),
}) "name": name,
"metric": metric,
"operator": operator,
"threshold": threshold,
"duration": duration,
"severity": sev,
},
)
await db.commit() await db.commit()
# Seed starter config templates for new tenant # Seed starter config templates for new tenant
@@ -131,6 +141,7 @@ async def create_tenant(
await db.commit() await db.commit()
except Exception as exc: except Exception as exc:
import logging import logging
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(
"OpenBao key provisioning failed for tenant %s (will be provisioned on next startup): %s", "OpenBao key provisioning failed for tenant %s (will be provisioned on next startup): %s",
tenant.id, tenant.id,
@@ -229,6 +240,7 @@ async def delete_tenant(
# Check if tenant had VPN configured (before cascade deletes it) # Check if tenant had VPN configured (before cascade deletes it)
from app.services.vpn_service import get_vpn_config, sync_wireguard_config from app.services.vpn_service import get_vpn_config, sync_wireguard_config
had_vpn = await get_vpn_config(db, tenant_id) had_vpn = await get_vpn_config(db, tenant_id)
await db.delete(tenant) await db.delete(tenant)
@@ -288,14 +300,54 @@ add chain=forward action=drop comment="Drop everything else"
# Identity # Identity
/system identity set name={{ device.hostname }}""", /system identity set name={{ device.hostname }}""",
"variables": [ "variables": [
{"name": "wan_interface", "type": "string", "default": "ether1", "description": "WAN-facing interface"}, {
{"name": "lan_gateway", "type": "ip", "default": "192.168.88.1", "description": "LAN gateway IP"}, "name": "wan_interface",
{"name": "lan_cidr", "type": "integer", "default": "24", "description": "LAN subnet mask bits"}, "type": "string",
{"name": "lan_network", "type": "ip", "default": "192.168.88.0", "description": "LAN network address"}, "default": "ether1",
{"name": "dhcp_start", "type": "ip", "default": "192.168.88.100", "description": "DHCP pool start"}, "description": "WAN-facing interface",
{"name": "dhcp_end", "type": "ip", "default": "192.168.88.254", "description": "DHCP pool end"}, },
{"name": "dns_servers", "type": "string", "default": "8.8.8.8,8.8.4.4", "description": "Upstream DNS servers"}, {
{"name": "ntp_server", "type": "string", "default": "pool.ntp.org", "description": "NTP server"}, "name": "lan_gateway",
"type": "ip",
"default": "192.168.88.1",
"description": "LAN gateway IP",
},
{
"name": "lan_cidr",
"type": "integer",
"default": "24",
"description": "LAN subnet mask bits",
},
{
"name": "lan_network",
"type": "ip",
"default": "192.168.88.0",
"description": "LAN network address",
},
{
"name": "dhcp_start",
"type": "ip",
"default": "192.168.88.100",
"description": "DHCP pool start",
},
{
"name": "dhcp_end",
"type": "ip",
"default": "192.168.88.254",
"description": "DHCP pool end",
},
{
"name": "dns_servers",
"type": "string",
"default": "8.8.8.8,8.8.4.4",
"description": "Upstream DNS servers",
},
{
"name": "ntp_server",
"type": "string",
"default": "pool.ntp.org",
"description": "NTP server",
},
], ],
}, },
{ {
@@ -311,8 +363,18 @@ add chain=forward connection-state=invalid action=drop
add chain=forward src-address={{ allowed_network }} action=accept add chain=forward src-address={{ allowed_network }} action=accept
add chain=forward action=drop""", add chain=forward action=drop""",
"variables": [ "variables": [
{"name": "wan_interface", "type": "string", "default": "ether1", "description": "WAN-facing interface"}, {
{"name": "allowed_network", "type": "subnet", "default": "192.168.88.0/24", "description": "Allowed source network"}, "name": "wan_interface",
"type": "string",
"default": "ether1",
"description": "WAN-facing interface",
},
{
"name": "allowed_network",
"type": "subnet",
"default": "192.168.88.0/24",
"description": "Allowed source network",
},
], ],
}, },
{ {
@@ -322,11 +384,36 @@ add chain=forward action=drop""",
/ip dhcp-server network add address={{ gateway }}/24 gateway={{ gateway }} dns-server={{ dns_server }} /ip dhcp-server network add address={{ gateway }}/24 gateway={{ gateway }} dns-server={{ dns_server }}
/ip dhcp-server add name=dhcp1 interface={{ interface }} address-pool=dhcp-pool disabled=no""", /ip dhcp-server add name=dhcp1 interface={{ interface }} address-pool=dhcp-pool disabled=no""",
"variables": [ "variables": [
{"name": "pool_start", "type": "ip", "default": "192.168.88.100", "description": "DHCP pool start address"}, {
{"name": "pool_end", "type": "ip", "default": "192.168.88.254", "description": "DHCP pool end address"}, "name": "pool_start",
{"name": "gateway", "type": "ip", "default": "192.168.88.1", "description": "Default gateway"}, "type": "ip",
{"name": "dns_server", "type": "ip", "default": "8.8.8.8", "description": "DNS server address"}, "default": "192.168.88.100",
{"name": "interface", "type": "string", "default": "bridge-lan", "description": "Interface to serve DHCP on"}, "description": "DHCP pool start address",
},
{
"name": "pool_end",
"type": "ip",
"default": "192.168.88.254",
"description": "DHCP pool end address",
},
{
"name": "gateway",
"type": "ip",
"default": "192.168.88.1",
"description": "Default gateway",
},
{
"name": "dns_server",
"type": "ip",
"default": "8.8.8.8",
"description": "DNS server address",
},
{
"name": "interface",
"type": "string",
"default": "bridge-lan",
"description": "Interface to serve DHCP on",
},
], ],
}, },
{ {
@@ -335,10 +422,30 @@ add chain=forward action=drop""",
"content": """/interface wireless security-profiles add name=portal-wpa2 mode=dynamic-keys authentication-types=wpa2-psk wpa2-pre-shared-key={{ password }} "content": """/interface wireless security-profiles add name=portal-wpa2 mode=dynamic-keys authentication-types=wpa2-psk wpa2-pre-shared-key={{ password }}
/interface wireless set wlan1 mode=ap-bridge ssid={{ ssid }} security-profile=portal-wpa2 frequency={{ frequency }} channel-width={{ channel_width }} disabled=no""", /interface wireless set wlan1 mode=ap-bridge ssid={{ ssid }} security-profile=portal-wpa2 frequency={{ frequency }} channel-width={{ channel_width }} disabled=no""",
"variables": [ "variables": [
{"name": "ssid", "type": "string", "default": "MikroTik-AP", "description": "Wireless network name"}, {
{"name": "password", "type": "string", "default": "", "description": "WPA2 pre-shared key (min 8 characters)"}, "name": "ssid",
{"name": "frequency", "type": "integer", "default": "2412", "description": "Wireless frequency in MHz"}, "type": "string",
{"name": "channel_width", "type": "string", "default": "20/40mhz-XX", "description": "Channel width setting"}, "default": "MikroTik-AP",
"description": "Wireless network name",
},
{
"name": "password",
"type": "string",
"default": "",
"description": "WPA2 pre-shared key (min 8 characters)",
},
{
"name": "frequency",
"type": "integer",
"default": "2412",
"description": "Wireless frequency in MHz",
},
{
"name": "channel_width",
"type": "string",
"default": "20/40mhz-XX",
"description": "Channel width setting",
},
], ],
}, },
{ {
@@ -351,8 +458,18 @@ add chain=forward action=drop""",
/ip service set ssh port=22 /ip service set ssh port=22
/ip service set winbox port=8291""", /ip service set winbox port=8291""",
"variables": [ "variables": [
{"name": "ntp_server", "type": "ip", "default": "pool.ntp.org", "description": "NTP server address"}, {
{"name": "dns_servers", "type": "string", "default": "8.8.8.8,8.8.4.4", "description": "Comma-separated DNS servers"}, "name": "ntp_server",
"type": "ip",
"default": "pool.ntp.org",
"description": "NTP server address",
},
{
"name": "dns_servers",
"type": "string",
"default": "8.8.8.8,8.8.4.4",
"description": "Comma-separated DNS servers",
},
], ],
}, },
] ]
@@ -363,13 +480,16 @@ async def _seed_starter_templates(db, tenant_id) -> None:
import json as _json import json as _json
for tmpl in _STARTER_TEMPLATES: for tmpl in _STARTER_TEMPLATES:
await db.execute(text(""" await db.execute(
text("""
INSERT INTO config_templates (id, tenant_id, name, description, content, variables) INSERT INTO config_templates (id, tenant_id, name, description, content, variables)
VALUES (gen_random_uuid(), CAST(:tid AS uuid), :name, :desc, :content, CAST(:vars AS jsonb)) VALUES (gen_random_uuid(), CAST(:tid AS uuid), :name, :desc, :content, CAST(:vars AS jsonb))
"""), { """),
"tid": str(tenant_id), {
"name": tmpl["name"], "tid": str(tenant_id),
"desc": tmpl["description"], "name": tmpl["name"],
"content": tmpl["content"], "desc": tmpl["description"],
"vars": _json.dumps(tmpl["variables"]), "content": tmpl["content"],
}) "vars": _json.dumps(tmpl["variables"]),
},
)

View File

@@ -14,7 +14,6 @@ Builds a topology graph of managed devices by:
import asyncio import asyncio
import ipaddress import ipaddress
import json import json
import logging
import uuid import uuid
from typing import Any from typing import Any
@@ -265,7 +264,7 @@ async def get_topology(
nodes: list[TopologyNode] = [] nodes: list[TopologyNode] = []
ip_to_device: dict[str, str] = {} ip_to_device: dict[str, str] = {}
online_device_ids: list[str] = [] online_device_ids: list[str] = []
devices_by_id: dict[str, Any] = {} _devices_by_id: dict[str, Any] = {}
for row in rows: for row in rows:
device_id = str(row.id) device_id = str(row.id)
@@ -288,9 +287,7 @@ async def get_topology(
if online_device_ids: if online_device_ids:
tasks = [ tasks = [
routeros_proxy.execute_command( routeros_proxy.execute_command(device_id, "/ip/neighbor/print", timeout=10.0)
device_id, "/ip/neighbor/print", timeout=10.0
)
for device_id in online_device_ids for device_id in online_device_ids
] ]
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)

View File

@@ -164,9 +164,7 @@ async def list_transparency_logs(
# Count total # Count total
count_result = await db.execute( count_result = await db.execute(
select(func.count()) select(func.count()).select_from(text("key_access_log k")).where(where_clause),
.select_from(text("key_access_log k"))
.where(where_clause),
params, params,
) )
total = count_result.scalar() or 0 total = count_result.scalar() or 0
@@ -353,39 +351,41 @@ async def export_transparency_logs(
output = io.StringIO() output = io.StringIO()
writer = csv.writer(output) writer = csv.writer(output)
writer.writerow([ writer.writerow(
"ID", [
"Action", "ID",
"Device Name", "Action",
"Device ID", "Device Name",
"Justification", "Device ID",
"Operator Email", "Justification",
"Correlation ID", "Operator Email",
"Resource Type", "Correlation ID",
"Resource ID", "Resource Type",
"IP Address", "Resource ID",
"Timestamp", "IP Address",
]) "Timestamp",
]
)
for row in all_rows: for row in all_rows:
writer.writerow([ writer.writerow(
str(row["id"]), [
row["action"], str(row["id"]),
row["device_name"] or "", row["action"],
str(row["device_id"]) if row["device_id"] else "", row["device_name"] or "",
row["justification"] or "", str(row["device_id"]) if row["device_id"] else "",
row["operator_email"] or "", row["justification"] or "",
row["correlation_id"] or "", row["operator_email"] or "",
row["resource_type"] or "", row["correlation_id"] or "",
row["resource_id"] or "", row["resource_type"] or "",
row["ip_address"] or "", row["resource_id"] or "",
str(row["created_at"]), row["ip_address"] or "",
]) str(row["created_at"]),
]
)
output.seek(0) output.seek(0)
return StreamingResponse( return StreamingResponse(
iter([output.getvalue()]), iter([output.getvalue()]),
media_type="text/csv", media_type="text/csv",
headers={ headers={"Content-Disposition": "attachment; filename=transparency-logs.csv"},
"Content-Disposition": "attachment; filename=transparency-logs.csv"
},
) )

View File

@@ -20,7 +20,7 @@ from app.database import get_admin_db
from app.middleware.rbac import require_tenant_admin_or_above from app.middleware.rbac import require_tenant_admin_or_above
from app.middleware.tenant_context import CurrentUser from app.middleware.tenant_context import CurrentUser
from app.models.tenant import Tenant from app.models.tenant import Tenant
from app.models.user import User, UserRole from app.models.user import User
from app.schemas.user import UserCreate, UserResponse, UserUpdate from app.schemas.user import UserCreate, UserResponse, UserUpdate
from app.services.auth import hash_password from app.services.auth import hash_password
@@ -69,11 +69,7 @@ async def list_users(
""" """
await _check_tenant_access(tenant_id, current_user, db) await _check_tenant_access(tenant_id, current_user, db)
result = await db.execute( result = await db.execute(select(User).where(User.tenant_id == tenant_id).order_by(User.name))
select(User)
.where(User.tenant_id == tenant_id)
.order_by(User.name)
)
users = result.scalars().all() users = result.scalars().all()
return [UserResponse.model_validate(user) for user in users] return [UserResponse.model_validate(user) for user in users]
@@ -103,9 +99,7 @@ async def create_user(
await _check_tenant_access(tenant_id, current_user, db) await _check_tenant_access(tenant_id, current_user, db)
# Check email uniqueness (global, not per-tenant) # Check email uniqueness (global, not per-tenant)
existing = await db.execute( existing = await db.execute(select(User).where(User.email == data.email.lower()))
select(User).where(User.email == data.email.lower())
)
if existing.scalar_one_or_none(): if existing.scalar_one_or_none():
raise HTTPException( raise HTTPException(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
@@ -138,9 +132,7 @@ async def get_user(
"""Get user detail.""" """Get user detail."""
await _check_tenant_access(tenant_id, current_user, db) await _check_tenant_access(tenant_id, current_user, db)
result = await db.execute( result = await db.execute(select(User).where(User.id == user_id, User.tenant_id == tenant_id))
select(User).where(User.id == user_id, User.tenant_id == tenant_id)
)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user: if not user:
@@ -168,9 +160,7 @@ async def update_user(
""" """
await _check_tenant_access(tenant_id, current_user, db) await _check_tenant_access(tenant_id, current_user, db)
result = await db.execute( result = await db.execute(select(User).where(User.id == user_id, User.tenant_id == tenant_id))
select(User).where(User.id == user_id, User.tenant_id == tenant_id)
)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user: if not user:
@@ -194,7 +184,11 @@ async def update_user(
return UserResponse.model_validate(user) return UserResponse.model_validate(user)
@router.delete("/{tenant_id}/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Deactivate a user") @router.delete(
"/{tenant_id}/users/{user_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Deactivate a user",
)
@limiter.limit("5/minute") @limiter.limit("5/minute")
async def deactivate_user( async def deactivate_user(
request: Request, request: Request,
@@ -209,9 +203,7 @@ async def deactivate_user(
""" """
await _check_tenant_access(tenant_id, current_user, db) await _check_tenant_access(tenant_id, current_user, db)
result = await db.execute( result = await db.execute(select(User).where(User.id == user_id, User.tenant_id == tenant_id))
select(User).where(User.id == user_id, User.tenant_id == tenant_id)
)
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user: if not user:

View File

@@ -68,7 +68,11 @@ async def get_vpn_config(
return resp return resp
@router.post("/tenants/{tenant_id}/vpn", response_model=VpnConfigResponse, status_code=status.HTTP_201_CREATED) @router.post(
"/tenants/{tenant_id}/vpn",
response_model=VpnConfigResponse,
status_code=status.HTTP_201_CREATED,
)
@limiter.limit("20/minute") @limiter.limit("20/minute")
async def setup_vpn( async def setup_vpn(
request: Request, request: Request,
@@ -177,7 +181,11 @@ async def list_peers(
return responses return responses
@router.post("/tenants/{tenant_id}/vpn/peers", response_model=VpnPeerResponse, status_code=status.HTTP_201_CREATED) @router.post(
"/tenants/{tenant_id}/vpn/peers",
response_model=VpnPeerResponse,
status_code=status.HTTP_201_CREATED,
)
@limiter.limit("20/minute") @limiter.limit("20/minute")
async def add_peer( async def add_peer(
request: Request, request: Request,
@@ -190,7 +198,9 @@ async def add_peer(
await _check_tenant_access(current_user, tenant_id, db) await _check_tenant_access(current_user, tenant_id, db)
_require_operator(current_user) _require_operator(current_user)
try: try:
peer = await vpn_service.add_peer(db, tenant_id, body.device_id, additional_allowed_ips=body.additional_allowed_ips) peer = await vpn_service.add_peer(
db, tenant_id, body.device_id, additional_allowed_ips=body.additional_allowed_ips
)
except ValueError as e: except ValueError as e:
msg = str(e) msg = str(e)
if "must not overlap" in msg: if "must not overlap" in msg:
@@ -208,7 +218,11 @@ async def add_peer(
return resp return resp
@router.post("/tenants/{tenant_id}/vpn/peers/onboard", response_model=VpnOnboardResponse, status_code=status.HTTP_201_CREATED) @router.post(
"/tenants/{tenant_id}/vpn/peers/onboard",
response_model=VpnOnboardResponse,
status_code=status.HTTP_201_CREATED,
)
@limiter.limit("10/minute") @limiter.limit("10/minute")
async def onboard_device( async def onboard_device(
request: Request, request: Request,
@@ -222,7 +236,8 @@ async def onboard_device(
_require_operator(current_user) _require_operator(current_user)
try: try:
result = await vpn_service.onboard_device( result = await vpn_service.onboard_device(
db, tenant_id, db,
tenant_id,
hostname=body.hostname, hostname=body.hostname,
username=body.username, username=body.username,
password=body.password, password=body.password,

View File

@@ -146,16 +146,16 @@ async def _delete_session_from_redis(session_id: str) -> None:
await rd.delete(f"{REDIS_PREFIX}{session_id}") await rd.delete(f"{REDIS_PREFIX}{session_id}")
async def _open_tunnel( async def _open_tunnel(device_id: uuid.UUID, tenant_id: uuid.UUID, user_id: uuid.UUID) -> dict:
device_id: uuid.UUID, tenant_id: uuid.UUID, user_id: uuid.UUID
) -> dict:
"""Open a TCP tunnel to device port 8291 via NATS request-reply.""" """Open a TCP tunnel to device port 8291 via NATS request-reply."""
payload = json.dumps({ payload = json.dumps(
"device_id": str(device_id), {
"tenant_id": str(tenant_id), "device_id": str(device_id),
"user_id": str(user_id), "tenant_id": str(tenant_id),
"target_port": 8291, "user_id": str(user_id),
}).encode() "target_port": 8291,
}
).encode()
try: try:
nc = await _get_nats() nc = await _get_nats()
@@ -176,9 +176,7 @@ async def _open_tunnel(
) )
if "error" in data: if "error" in data:
raise HTTPException( raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=data["error"])
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=data["error"]
)
return data return data
@@ -250,9 +248,7 @@ async def create_winbox_remote_session(
except Exception: except Exception:
worker_info = None worker_info = None
if worker_info is None: if worker_info is None:
logger.warning( logger.warning("Cleaning stale Redis session %s (worker 404)", stale_sid)
"Cleaning stale Redis session %s (worker 404)", stale_sid
)
tunnel_id = sess.get("tunnel_id") tunnel_id = sess.get("tunnel_id")
if tunnel_id: if tunnel_id:
await _close_tunnel(tunnel_id) await _close_tunnel(tunnel_id)
@@ -333,12 +329,8 @@ async def create_winbox_remote_session(
username = "" # noqa: F841 username = "" # noqa: F841
password = "" # noqa: F841 password = "" # noqa: F841
expires_at = datetime.fromisoformat( expires_at = datetime.fromisoformat(worker_resp.get("expires_at", now.isoformat()))
worker_resp.get("expires_at", now.isoformat()) max_expires_at = datetime.fromisoformat(worker_resp.get("max_expires_at", now.isoformat()))
)
max_expires_at = datetime.fromisoformat(
worker_resp.get("max_expires_at", now.isoformat())
)
# Save session to Redis # Save session to Redis
session_data = { session_data = {
@@ -375,8 +367,7 @@ async def create_winbox_remote_session(
pass pass
ws_path = ( ws_path = (
f"/api/tenants/{tenant_id}/devices/{device_id}" f"/api/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/ws"
f"/winbox-remote-sessions/{session_id}/ws"
) )
return RemoteWinboxSessionResponse( return RemoteWinboxSessionResponse(
@@ -425,14 +416,10 @@ async def get_winbox_remote_session(
sess = await _get_session_from_redis(str(session_id)) sess = await _get_session_from_redis(str(session_id))
if sess is None: if sess is None:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Session not found"
)
if sess.get("tenant_id") != str(tenant_id) or sess.get("device_id") != str(device_id): if sess.get("tenant_id") != str(tenant_id) or sess.get("device_id") != str(device_id):
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Session not found"
)
return RemoteWinboxStatusResponse( return RemoteWinboxStatusResponse(
session_id=uuid.UUID(sess["session_id"]), session_id=uuid.UUID(sess["session_id"]),
@@ -478,10 +465,7 @@ async def list_winbox_remote_sessions(
sess = json.loads(raw) sess = json.loads(raw)
except Exception: except Exception:
continue continue
if ( if sess.get("tenant_id") == str(tenant_id) and sess.get("device_id") == str(device_id):
sess.get("tenant_id") == str(tenant_id)
and sess.get("device_id") == str(device_id)
):
sessions.append( sessions.append(
RemoteWinboxStatusResponse( RemoteWinboxStatusResponse(
session_id=uuid.UUID(sess["session_id"]), session_id=uuid.UUID(sess["session_id"]),
@@ -533,9 +517,7 @@ async def terminate_winbox_remote_session(
) )
if sess.get("tenant_id") != str(tenant_id): if sess.get("tenant_id") != str(tenant_id):
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Session not found"
)
# Rollback order: worker -> tunnel -> redis -> audit # Rollback order: worker -> tunnel -> redis -> audit
await worker_terminate_session(str(session_id)) await worker_terminate_session(str(session_id))
@@ -574,14 +556,12 @@ async def terminate_winbox_remote_session(
@router.get( @router.get(
"/tenants/{tenant_id}/devices/{device_id}" "/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/xpra/{path:path}",
"/winbox-remote-sessions/{session_id}/xpra/{path:path}",
summary="Proxy Xpra HTML5 client files", summary="Proxy Xpra HTML5 client files",
dependencies=[Depends(require_operator_or_above)], dependencies=[Depends(require_operator_or_above)],
) )
@router.get( @router.get(
"/tenants/{tenant_id}/devices/{device_id}" "/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/xpra",
"/winbox-remote-sessions/{session_id}/xpra",
summary="Proxy Xpra HTML5 client (root)", summary="Proxy Xpra HTML5 client (root)",
dependencies=[Depends(require_operator_or_above)], dependencies=[Depends(require_operator_or_above)],
) )
@@ -626,7 +606,8 @@ async def proxy_xpra_html(
content=proxy_resp.content, content=proxy_resp.content,
status_code=proxy_resp.status_code, status_code=proxy_resp.status_code,
headers={ headers={
k: v for k, v in proxy_resp.headers.items() k: v
for k, v in proxy_resp.headers.items()
if k.lower() in ("content-type", "cache-control", "content-encoding") if k.lower() in ("content-type", "cache-control", "content-encoding")
}, },
) )
@@ -637,9 +618,7 @@ async def proxy_xpra_html(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.websocket( @router.websocket("/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/ws")
"/tenants/{tenant_id}/devices/{device_id}/winbox-remote-sessions/{session_id}/ws"
)
async def winbox_remote_ws_proxy( async def winbox_remote_ws_proxy(
websocket: WebSocket, websocket: WebSocket,
tenant_id: uuid.UUID, tenant_id: uuid.UUID,

View File

@@ -67,11 +67,13 @@ class MessageResponse(BaseModel):
class SRPInitRequest(BaseModel): class SRPInitRequest(BaseModel):
"""Step 1 request: client sends email to begin SRP handshake.""" """Step 1 request: client sends email to begin SRP handshake."""
email: EmailStr email: EmailStr
class SRPInitResponse(BaseModel): class SRPInitResponse(BaseModel):
"""Step 1 response: server returns ephemeral B and key derivation salts.""" """Step 1 response: server returns ephemeral B and key derivation salts."""
salt: str # hex-encoded SRP salt salt: str # hex-encoded SRP salt
server_public: str # hex-encoded server ephemeral B server_public: str # hex-encoded server ephemeral B
session_id: str # Redis session key nonce session_id: str # Redis session key nonce
@@ -81,6 +83,7 @@ class SRPInitResponse(BaseModel):
class SRPVerifyRequest(BaseModel): class SRPVerifyRequest(BaseModel):
"""Step 2 request: client sends proof M1 to complete handshake.""" """Step 2 request: client sends proof M1 to complete handshake."""
email: EmailStr email: EmailStr
session_id: str session_id: str
client_public: str # hex-encoded client ephemeral A client_public: str # hex-encoded client ephemeral A
@@ -89,6 +92,7 @@ class SRPVerifyRequest(BaseModel):
class SRPVerifyResponse(BaseModel): class SRPVerifyResponse(BaseModel):
"""Step 2 response: server returns tokens and proof M2.""" """Step 2 response: server returns tokens and proof M2."""
access_token: str access_token: str
refresh_token: str refresh_token: str
token_type: str = "bearer" token_type: str = "bearer"
@@ -98,6 +102,7 @@ class SRPVerifyResponse(BaseModel):
class SRPRegisterRequest(BaseModel): class SRPRegisterRequest(BaseModel):
"""Used during registration to store SRP verifier and key set.""" """Used during registration to store SRP verifier and key set."""
srp_salt: str # hex-encoded srp_salt: str # hex-encoded
srp_verifier: str # hex-encoded srp_verifier: str # hex-encoded
encrypted_private_key: str # base64-encoded encrypted_private_key: str # base64-encoded
@@ -114,10 +119,12 @@ class SRPRegisterRequest(BaseModel):
class DeleteAccountRequest(BaseModel): class DeleteAccountRequest(BaseModel):
"""Request body for account self-deletion. User must type 'DELETE' to confirm.""" """Request body for account self-deletion. User must type 'DELETE' to confirm."""
confirmation: str # Must be "DELETE" to confirm confirmation: str # Must be "DELETE" to confirm
class DeleteAccountResponse(BaseModel): class DeleteAccountResponse(BaseModel):
"""Response after successful account deletion.""" """Response after successful account deletion."""
message: str message: str
deleted: bool deleted: bool

View File

@@ -10,6 +10,7 @@ from pydantic import BaseModel, ConfigDict
# Request schemas # Request schemas
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class CACreateRequest(BaseModel): class CACreateRequest(BaseModel):
"""Request to generate a new root CA for the tenant.""" """Request to generate a new root CA for the tenant."""
@@ -34,6 +35,7 @@ class BulkCertDeployRequest(BaseModel):
# Response schemas # Response schemas
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class CAResponse(BaseModel): class CAResponse(BaseModel):
"""Public details of a tenant's Certificate Authority (no private key).""" """Public details of a tenant's Certificate Authority (no private key)."""

View File

@@ -117,6 +117,7 @@ class SubnetScanRequest(BaseModel):
def validate_cidr(cls, v: str) -> str: def validate_cidr(cls, v: str) -> str:
"""Validate that the value is a valid CIDR notation and RFC 1918 private range.""" """Validate that the value is a valid CIDR notation and RFC 1918 private range."""
import ipaddress import ipaddress
try: try:
network = ipaddress.ip_network(v, strict=False) network = ipaddress.ip_network(v, strict=False)
except ValueError as e: except ValueError as e:
@@ -239,6 +240,7 @@ class DeviceTagCreate(BaseModel):
if v is None: if v is None:
return v return v
import re import re
if not re.match(r"^#[0-9A-Fa-f]{6}$", v): if not re.match(r"^#[0-9A-Fa-f]{6}$", v):
raise ValueError("Color must be a valid 6-digit hex color (e.g. #FF5733)") raise ValueError("Color must be a valid 6-digit hex color (e.g. #FF5733)")
return v return v
@@ -256,6 +258,7 @@ class DeviceTagUpdate(BaseModel):
if v is None: if v is None:
return v return v
import re import re
if not re.match(r"^#[0-9A-Fa-f]{6}$", v): if not re.match(r"^#[0-9A-Fa-f]{6}$", v):
raise ValueError("Color must be a valid 6-digit hex color (e.g. #FF5733)") raise ValueError("Color must be a valid 6-digit hex color (e.g. #FF5733)")
return v return v

View File

@@ -12,11 +12,15 @@ from pydantic import BaseModel
class VpnSetupRequest(BaseModel): class VpnSetupRequest(BaseModel):
"""Request to enable VPN for a tenant.""" """Request to enable VPN for a tenant."""
endpoint: Optional[str] = None # public hostname:port — if blank, devices must be configured manually
endpoint: Optional[str] = (
None # public hostname:port — if blank, devices must be configured manually
)
class VpnConfigResponse(BaseModel): class VpnConfigResponse(BaseModel):
"""VPN server configuration (never exposes private key).""" """VPN server configuration (never exposes private key)."""
model_config = {"from_attributes": True} model_config = {"from_attributes": True}
id: uuid.UUID id: uuid.UUID
@@ -33,6 +37,7 @@ class VpnConfigResponse(BaseModel):
class VpnConfigUpdate(BaseModel): class VpnConfigUpdate(BaseModel):
"""Update VPN configuration.""" """Update VPN configuration."""
endpoint: Optional[str] = None endpoint: Optional[str] = None
is_enabled: Optional[bool] = None is_enabled: Optional[bool] = None
@@ -42,12 +47,14 @@ class VpnConfigUpdate(BaseModel):
class VpnPeerCreate(BaseModel): class VpnPeerCreate(BaseModel):
"""Add a device as a VPN peer.""" """Add a device as a VPN peer."""
device_id: uuid.UUID device_id: uuid.UUID
additional_allowed_ips: Optional[str] = None # comma-separated subnets for site-to-site routing additional_allowed_ips: Optional[str] = None # comma-separated subnets for site-to-site routing
class VpnPeerResponse(BaseModel): class VpnPeerResponse(BaseModel):
"""VPN peer info (never exposes private key).""" """VPN peer info (never exposes private key)."""
model_config = {"from_attributes": True} model_config = {"from_attributes": True}
id: uuid.UUID id: uuid.UUID
@@ -66,6 +73,7 @@ class VpnPeerResponse(BaseModel):
class VpnOnboardRequest(BaseModel): class VpnOnboardRequest(BaseModel):
"""Combined device creation + VPN peer onboarding.""" """Combined device creation + VPN peer onboarding."""
hostname: str hostname: str
username: str username: str
password: str password: str
@@ -73,6 +81,7 @@ class VpnOnboardRequest(BaseModel):
class VpnOnboardResponse(BaseModel): class VpnOnboardResponse(BaseModel):
"""Response from onboarding — device, peer, and RouterOS commands.""" """Response from onboarding — device, peer, and RouterOS commands."""
device_id: uuid.UUID device_id: uuid.UUID
peer_id: uuid.UUID peer_id: uuid.UUID
hostname: str hostname: str
@@ -82,6 +91,7 @@ class VpnOnboardResponse(BaseModel):
class VpnPeerConfig(BaseModel): class VpnPeerConfig(BaseModel):
"""Full peer config for display/export — includes private key for device setup.""" """Full peer config for display/export — includes private key for device setup."""
peer_private_key: str peer_private_key: str
peer_public_key: str peer_public_key: str
assigned_ip: str assigned_ip: str

View File

@@ -57,6 +57,7 @@ class RemoteWinboxDuplicateDetail(BaseModel):
class RemoteWinboxSessionItem(BaseModel): class RemoteWinboxSessionItem(BaseModel):
"""Used in the combined active sessions list.""" """Used in the combined active sessions list."""
session_id: uuid.UUID session_id: uuid.UUID
status: RemoteWinboxState status: RemoteWinboxState
created_at: datetime created_at: datetime

View File

@@ -86,11 +86,7 @@ async def delete_user_account(
# Null out encrypted_details (may contain encrypted PII) # Null out encrypted_details (may contain encrypted PII)
await db.execute( await db.execute(
text( text("UPDATE audit_logs SET encrypted_details = NULL WHERE user_id = :user_id"),
"UPDATE audit_logs "
"SET encrypted_details = NULL "
"WHERE user_id = :user_id"
),
{"user_id": user_id}, {"user_id": user_id},
) )
@@ -177,16 +173,18 @@ async def export_user_data(
) )
api_keys = [] api_keys = []
for row in result.mappings().all(): for row in result.mappings().all():
api_keys.append({ api_keys.append(
"id": str(row["id"]), {
"name": row["name"], "id": str(row["id"]),
"key_prefix": row["key_prefix"], "name": row["name"],
"scopes": row["scopes"], "key_prefix": row["key_prefix"],
"created_at": row["created_at"].isoformat() if row["created_at"] else None, "scopes": row["scopes"],
"expires_at": row["expires_at"].isoformat() if row["expires_at"] else None, "created_at": row["created_at"].isoformat() if row["created_at"] else None,
"revoked_at": row["revoked_at"].isoformat() if row["revoked_at"] else None, "expires_at": row["expires_at"].isoformat() if row["expires_at"] else None,
"last_used_at": row["last_used_at"].isoformat() if row["last_used_at"] else None, "revoked_at": row["revoked_at"].isoformat() if row["revoked_at"] else None,
}) "last_used_at": row["last_used_at"].isoformat() if row["last_used_at"] else None,
}
)
# ── Audit logs (limit 1000, most recent first) ─────────────────────── # ── Audit logs (limit 1000, most recent first) ───────────────────────
result = await db.execute( result = await db.execute(
@@ -201,15 +199,17 @@ async def export_user_data(
audit_logs = [] audit_logs = []
for row in result.mappings().all(): for row in result.mappings().all():
details = row["details"] if row["details"] else {} details = row["details"] if row["details"] else {}
audit_logs.append({ audit_logs.append(
"id": str(row["id"]), {
"action": row["action"], "id": str(row["id"]),
"resource_type": row["resource_type"], "action": row["action"],
"resource_id": row["resource_id"], "resource_type": row["resource_type"],
"details": details, "resource_id": row["resource_id"],
"ip_address": row["ip_address"], "details": details,
"created_at": row["created_at"].isoformat() if row["created_at"] else None, "ip_address": row["ip_address"],
}) "created_at": row["created_at"].isoformat() if row["created_at"] else None,
}
)
# ── Key access log (limit 1000, most recent first) ─────────────────── # ── Key access log (limit 1000, most recent first) ───────────────────
result = await db.execute( result = await db.execute(
@@ -222,13 +222,15 @@ async def export_user_data(
) )
key_access_entries = [] key_access_entries = []
for row in result.mappings().all(): for row in result.mappings().all():
key_access_entries.append({ key_access_entries.append(
"id": str(row["id"]), {
"action": row["action"], "id": str(row["id"]),
"resource_type": row["resource_type"], "action": row["action"],
"ip_address": row["ip_address"], "resource_type": row["resource_type"],
"created_at": row["created_at"].isoformat() if row["created_at"] else None, "ip_address": row["ip_address"],
}) "created_at": row["created_at"].isoformat() if row["created_at"] else None,
}
)
return { return {
"export_date": datetime.now(UTC).isoformat(), "export_date": datetime.now(UTC).isoformat(),

View File

@@ -253,7 +253,9 @@ async def _get_device_groups(device_id: str) -> list[str]:
"""Get group IDs for a device.""" """Get group IDs for a device."""
async with AdminAsyncSessionLocal() as session: async with AdminAsyncSessionLocal() as session:
result = await session.execute( result = await session.execute(
text("SELECT group_id FROM device_group_memberships WHERE device_id = CAST(:device_id AS uuid)"), text(
"SELECT group_id FROM device_group_memberships WHERE device_id = CAST(:device_id AS uuid)"
),
{"device_id": device_id}, {"device_id": device_id},
) )
return [str(row[0]) for row in result.fetchall()] return [str(row[0]) for row in result.fetchall()]
@@ -344,30 +346,36 @@ async def _create_alert_event(
# Publish real-time event to NATS for SSE pipeline (fire-and-forget) # Publish real-time event to NATS for SSE pipeline (fire-and-forget)
if status in ("firing", "flapping"): if status in ("firing", "flapping"):
await publish_event(f"alert.fired.{tenant_id}", { await publish_event(
"event_type": "alert_fired", f"alert.fired.{tenant_id}",
"tenant_id": tenant_id, {
"device_id": device_id, "event_type": "alert_fired",
"alert_event_id": alert_data["id"], "tenant_id": tenant_id,
"severity": severity, "device_id": device_id,
"metric": metric, "alert_event_id": alert_data["id"],
"current_value": value, "severity": severity,
"threshold": threshold, "metric": metric,
"message": message, "current_value": value,
"is_flapping": is_flapping, "threshold": threshold,
"fired_at": datetime.now(timezone.utc).isoformat(), "message": message,
}) "is_flapping": is_flapping,
"fired_at": datetime.now(timezone.utc).isoformat(),
},
)
elif status == "resolved": elif status == "resolved":
await publish_event(f"alert.resolved.{tenant_id}", { await publish_event(
"event_type": "alert_resolved", f"alert.resolved.{tenant_id}",
"tenant_id": tenant_id, {
"device_id": device_id, "event_type": "alert_resolved",
"alert_event_id": alert_data["id"], "tenant_id": tenant_id,
"severity": severity, "device_id": device_id,
"metric": metric, "alert_event_id": alert_data["id"],
"message": message, "severity": severity,
"resolved_at": datetime.now(timezone.utc).isoformat(), "metric": metric,
}) "message": message,
"resolved_at": datetime.now(timezone.utc).isoformat(),
},
)
return alert_data return alert_data
@@ -470,6 +478,7 @@ async def _dispatch_async(alert_event: dict, channels: list[dict], device_hostna
"""Fire-and-forget notification dispatch.""" """Fire-and-forget notification dispatch."""
try: try:
from app.services.notification_service import dispatch_notifications from app.services.notification_service import dispatch_notifications
await dispatch_notifications(alert_event, channels, device_hostname) await dispatch_notifications(alert_event, channels, device_hostname)
except Exception as e: except Exception as e:
logger.warning("Notification dispatch failed: %s", e) logger.warning("Notification dispatch failed: %s", e)
@@ -500,7 +509,8 @@ async def evaluate(
if await _is_device_in_maintenance(tenant_id, device_id): if await _is_device_in_maintenance(tenant_id, device_id):
logger.debug( logger.debug(
"Alert suppressed by maintenance window for device %s tenant %s", "Alert suppressed by maintenance window for device %s tenant %s",
device_id, tenant_id, device_id,
tenant_id,
) )
return return
@@ -573,7 +583,8 @@ async def evaluate(
if is_flapping: if is_flapping:
logger.info( logger.info(
"Alert %s for device %s is flapping — notifications suppressed", "Alert %s for device %s is flapping — notifications suppressed",
rule["name"], device_id, rule["name"],
device_id,
) )
else: else:
channels = await _get_channels_for_rule(rule["id"]) channels = await _get_channels_for_rule(rule["id"])

View File

@@ -56,9 +56,7 @@ async def log_action(
try: try:
from app.services.crypto import encrypt_data_transit from app.services.crypto import encrypt_data_transit
encrypted_details = await encrypt_data_transit( encrypted_details = await encrypt_data_transit(details_json, str(tenant_id))
details_json, str(tenant_id)
)
# Encryption succeeded — clear plaintext details # Encryption succeeded — clear plaintext details
details_json = _json.dumps({}) details_json = _json.dumps({})
except Exception: except Exception:

View File

@@ -32,8 +32,12 @@ def _cron_to_trigger(cron_expr: str) -> Optional[CronTrigger]:
return None return None
minute, hour, day, month, day_of_week = parts minute, hour, day, month, day_of_week = parts
return CronTrigger( return CronTrigger(
minute=minute, hour=hour, day=day, month=month, minute=minute,
day_of_week=day_of_week, timezone="UTC", hour=hour,
day=day,
month=month,
day_of_week=day_of_week,
timezone="UTC",
) )
except Exception as e: except Exception as e:
logger.warning("Invalid cron expression '%s': %s", cron_expr, e) logger.warning("Invalid cron expression '%s': %s", cron_expr, e)
@@ -52,10 +56,12 @@ def build_schedule_map(schedules: list) -> dict[str, list[dict]]:
cron = s.cron_expression or DEFAULT_CRON cron = s.cron_expression or DEFAULT_CRON
if cron not in schedule_map: if cron not in schedule_map:
schedule_map[cron] = [] schedule_map[cron] = []
schedule_map[cron].append({ schedule_map[cron].append(
"device_id": str(s.device_id), {
"tenant_id": str(s.tenant_id), "device_id": str(s.device_id),
}) "tenant_id": str(s.tenant_id),
}
)
return schedule_map return schedule_map
@@ -79,13 +85,15 @@ async def _run_scheduled_backups(devices: list[dict]) -> None:
except Exception as e: except Exception as e:
logger.error( logger.error(
"Scheduled backup FAILED: device %s: %s", "Scheduled backup FAILED: device %s: %s",
dev_info["device_id"], e, dev_info["device_id"],
e,
) )
failure_count += 1 failure_count += 1
logger.info( logger.info(
"Backup batch complete — %d succeeded, %d failed", "Backup batch complete — %d succeeded, %d failed",
success_count, failure_count, success_count,
failure_count,
) )
@@ -108,7 +116,7 @@ async def _load_effective_schedules() -> list:
# Index: device-specific and tenant defaults # Index: device-specific and tenant defaults
device_schedules = {} # device_id -> schedule device_schedules = {} # device_id -> schedule
tenant_defaults = {} # tenant_id -> schedule tenant_defaults = {} # tenant_id -> schedule
for s in schedules: for s in schedules:
if s.device_id: if s.device_id:
@@ -129,12 +137,14 @@ async def _load_effective_schedules() -> list:
# No schedule configured — use system default # No schedule configured — use system default
sched = None sched = None
effective.append(SimpleNamespace( effective.append(
device_id=dev_id, SimpleNamespace(
tenant_id=tenant_id, device_id=dev_id,
cron_expression=sched.cron_expression if sched else DEFAULT_CRON, tenant_id=tenant_id,
enabled=sched.enabled if sched else True, cron_expression=sched.cron_expression if sched else DEFAULT_CRON,
)) enabled=sched.enabled if sched else True,
)
)
return effective return effective
@@ -203,6 +213,7 @@ class _SchedulerProxy:
Usage: `from app.services.backup_scheduler import backup_scheduler` Usage: `from app.services.backup_scheduler import backup_scheduler`
then `backup_scheduler.add_job(...)`. then `backup_scheduler.add_job(...)`.
""" """
def __getattr__(self, name): def __getattr__(self, name):
if _scheduler is None: if _scheduler is None:
raise RuntimeError("Backup scheduler not started yet") raise RuntimeError("Backup scheduler not started yet")

View File

@@ -255,7 +255,12 @@ async def run_backup(
if prior_commits: if prior_commits:
try: try:
prior_export_bytes = await loop.run_in_executor( prior_export_bytes = await loop.run_in_executor(
None, git_store.read_file, tenant_id, prior_commits[0]["sha"], device_id, "export.rsc" None,
git_store.read_file,
tenant_id,
prior_commits[0]["sha"],
device_id,
"export.rsc",
) )
prior_text = prior_export_bytes.decode("utf-8", errors="replace") prior_text = prior_export_bytes.decode("utf-8", errors="replace")
lines_added, lines_removed = await loop.run_in_executor( lines_added, lines_removed = await loop.run_in_executor(
@@ -284,9 +289,7 @@ async def run_backup(
try: try:
from app.services.crypto import encrypt_data_transit from app.services.crypto import encrypt_data_transit
encrypted_export = await encrypt_data_transit( encrypted_export = await encrypt_data_transit(export_text, tenant_id)
export_text, tenant_id
)
encrypted_binary = await encrypt_data_transit( encrypted_binary = await encrypt_data_transit(
base64.b64encode(binary_backup).decode(), tenant_id base64.b64encode(binary_backup).decode(), tenant_id
) )
@@ -302,8 +305,7 @@ async def run_backup(
except Exception as enc_err: except Exception as enc_err:
# Transit unavailable — fall back to plaintext (non-fatal) # Transit unavailable — fall back to plaintext (non-fatal)
logger.warning( logger.warning(
"Transit encryption failed for %s backup of device %s, " "Transit encryption failed for %s backup of device %s, storing plaintext: %s",
"storing plaintext: %s",
trigger_type, trigger_type,
device_id, device_id,
enc_err, enc_err,
@@ -313,9 +315,7 @@ async def run_backup(
# ----------------------------------------------------------------------- # -----------------------------------------------------------------------
# Step 6: Commit to git (wrapped in run_in_executor — pygit2 is sync C bindings) # Step 6: Commit to git (wrapped in run_in_executor — pygit2 is sync C bindings)
# ----------------------------------------------------------------------- # -----------------------------------------------------------------------
commit_message = ( commit_message = f"{trigger_type}: {hostname} ({ip}) at {ts}"
f"{trigger_type}: {hostname} ({ip}) at {ts}"
)
commit_sha = await loop.run_in_executor( commit_sha = await loop.run_in_executor(
None, None,

View File

@@ -48,6 +48,7 @@ _VALID_TRANSITIONS: dict[str, set[str]] = {
# CA Generation # CA Generation
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def generate_ca( async def generate_ca(
db: AsyncSession, db: AsyncSession,
tenant_id: UUID, tenant_id: UUID,
@@ -84,10 +85,12 @@ async def generate_ca(
now = datetime.datetime.now(datetime.timezone.utc) now = datetime.datetime.now(datetime.timezone.utc)
expiry = now + datetime.timedelta(days=365 * validity_years) expiry = now + datetime.timedelta(days=365 * validity_years)
subject = issuer = x509.Name([ subject = issuer = x509.Name(
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"), [
x509.NameAttribute(NameOID.COMMON_NAME, common_name), x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
]) x509.NameAttribute(NameOID.COMMON_NAME, common_name),
]
)
ca_cert = ( ca_cert = (
x509.CertificateBuilder() x509.CertificateBuilder()
@@ -97,9 +100,7 @@ async def generate_ca(
.serial_number(x509.random_serial_number()) .serial_number(x509.random_serial_number())
.not_valid_before(now) .not_valid_before(now)
.not_valid_after(expiry) .not_valid_after(expiry)
.add_extension( .add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True)
x509.BasicConstraints(ca=True, path_length=0), critical=True
)
.add_extension( .add_extension(
x509.KeyUsage( x509.KeyUsage(
digital_signature=True, digital_signature=True,
@@ -166,6 +167,7 @@ async def generate_ca(
# Device Certificate Signing # Device Certificate Signing
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def sign_device_cert( async def sign_device_cert(
db: AsyncSession, db: AsyncSession,
ca: CertificateAuthority, ca: CertificateAuthority,
@@ -196,9 +198,7 @@ async def sign_device_cert(
str(ca.tenant_id), str(ca.tenant_id),
encryption_key, encryption_key,
) )
ca_key = serialization.load_pem_private_key( ca_key = serialization.load_pem_private_key(ca_key_pem.encode("utf-8"), password=None)
ca_key_pem.encode("utf-8"), password=None
)
# Load CA certificate for issuer info and AuthorityKeyIdentifier # Load CA certificate for issuer info and AuthorityKeyIdentifier
ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8")) ca_cert = x509.load_pem_x509_certificate(ca.cert_pem.encode("utf-8"))
@@ -212,19 +212,19 @@ async def sign_device_cert(
device_cert = ( device_cert = (
x509.CertificateBuilder() x509.CertificateBuilder()
.subject_name( .subject_name(
x509.Name([ x509.Name(
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"), [
x509.NameAttribute(NameOID.COMMON_NAME, hostname), x509.NameAttribute(NameOID.ORGANIZATION_NAME, "The Other Dude"),
]) x509.NameAttribute(NameOID.COMMON_NAME, hostname),
]
)
) )
.issuer_name(ca_cert.subject) .issuer_name(ca_cert.subject)
.public_key(device_key.public_key()) .public_key(device_key.public_key())
.serial_number(x509.random_serial_number()) .serial_number(x509.random_serial_number())
.not_valid_before(now) .not_valid_before(now)
.not_valid_after(expiry) .not_valid_after(expiry)
.add_extension( .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
x509.BasicConstraints(ca=False, path_length=None), critical=True
)
.add_extension( .add_extension(
x509.KeyUsage( x509.KeyUsage(
digital_signature=True, digital_signature=True,
@@ -244,17 +244,17 @@ async def sign_device_cert(
critical=False, critical=False,
) )
.add_extension( .add_extension(
x509.SubjectAlternativeName([ x509.SubjectAlternativeName(
x509.IPAddress(ipaddress.ip_address(ip_address)), [
x509.DNSName(hostname), x509.IPAddress(ipaddress.ip_address(ip_address)),
]), x509.DNSName(hostname),
]
),
critical=False, critical=False,
) )
.add_extension( .add_extension(
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier( x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
ca_cert.extensions.get_extension_for_class( ca_cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier).value
x509.SubjectKeyIdentifier
).value
), ),
critical=False, critical=False,
) )
@@ -308,15 +308,14 @@ async def sign_device_cert(
# Queries # Queries
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def get_ca_for_tenant( async def get_ca_for_tenant(
db: AsyncSession, db: AsyncSession,
tenant_id: UUID, tenant_id: UUID,
) -> CertificateAuthority | None: ) -> CertificateAuthority | None:
"""Return the tenant's CA, or None if not yet initialized.""" """Return the tenant's CA, or None if not yet initialized."""
result = await db.execute( result = await db.execute(
select(CertificateAuthority).where( select(CertificateAuthority).where(CertificateAuthority.tenant_id == tenant_id)
CertificateAuthority.tenant_id == tenant_id
)
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
@@ -352,6 +351,7 @@ async def get_device_certs(
# Status Management # Status Management
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def update_cert_status( async def update_cert_status(
db: AsyncSession, db: AsyncSession,
cert_id: UUID, cert_id: UUID,
@@ -377,9 +377,7 @@ async def update_cert_status(
Raises: Raises:
ValueError: If the certificate is not found or the transition is invalid. ValueError: If the certificate is not found or the transition is invalid.
""" """
result = await db.execute( result = await db.execute(select(DeviceCertificate).where(DeviceCertificate.id == cert_id))
select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
)
cert = result.scalar_one_or_none() cert = result.scalar_one_or_none()
if cert is None: if cert is None:
raise ValueError(f"Device certificate {cert_id} not found") raise ValueError(f"Device certificate {cert_id} not found")
@@ -413,6 +411,7 @@ async def update_cert_status(
# Cert Data for Deployment # Cert Data for Deployment
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def get_cert_for_deploy( async def get_cert_for_deploy(
db: AsyncSession, db: AsyncSession,
cert_id: UUID, cert_id: UUID,
@@ -434,18 +433,14 @@ async def get_cert_for_deploy(
Raises: Raises:
ValueError: If the certificate or its CA is not found. ValueError: If the certificate or its CA is not found.
""" """
result = await db.execute( result = await db.execute(select(DeviceCertificate).where(DeviceCertificate.id == cert_id))
select(DeviceCertificate).where(DeviceCertificate.id == cert_id)
)
cert = result.scalar_one_or_none() cert = result.scalar_one_or_none()
if cert is None: if cert is None:
raise ValueError(f"Device certificate {cert_id} not found") raise ValueError(f"Device certificate {cert_id} not found")
# Fetch the CA for the ca_cert_pem # Fetch the CA for the ca_cert_pem
ca_result = await db.execute( ca_result = await db.execute(
select(CertificateAuthority).where( select(CertificateAuthority).where(CertificateAuthority.id == cert.ca_id)
CertificateAuthority.id == cert.ca_id
)
) )
ca = ca_result.scalar_one_or_none() ca = ca_result.scalar_one_or_none()
if ca is None: if ca is None:

View File

@@ -64,9 +64,7 @@ def parse_diff_changes(diff_text: str) -> list[dict]:
current_section: str | None = None current_section: str | None = None
# Track per-component: adds, removes, raw_lines # Track per-component: adds, removes, raw_lines
components: dict[str, dict] = defaultdict( components: dict[str, dict] = defaultdict(lambda: {"adds": 0, "removes": 0, "raw_lines": []})
lambda: {"adds": 0, "removes": 0, "raw_lines": []}
)
for line in lines: for line in lines:
# Skip unified diff headers # Skip unified diff headers

View File

@@ -51,13 +51,15 @@ async def handle_config_changed(event: dict) -> None:
if await _last_backup_within_dedup_window(device_id): if await _last_backup_within_dedup_window(device_id):
logger.info( logger.info(
"Config change on device %s — skipping backup (within %dm dedup window)", "Config change on device %s — skipping backup (within %dm dedup window)",
device_id, DEDUP_WINDOW_MINUTES, device_id,
DEDUP_WINDOW_MINUTES,
) )
return return
logger.info( logger.info(
"Config change detected on device %s (tenant %s): %s -> %s", "Config change detected on device %s (tenant %s): %s -> %s",
device_id, tenant_id, device_id,
tenant_id,
event.get("old_timestamp", "?"), event.get("old_timestamp", "?"),
event.get("new_timestamp", "?"), event.get("new_timestamp", "?"),
) )

View File

@@ -65,9 +65,7 @@ async def generate_and_store_diff(
# 2. No previous snapshot = first snapshot for device # 2. No previous snapshot = first snapshot for device
if prev_row is None: if prev_row is None:
logger.debug( logger.debug("First snapshot for device %s, no diff to generate", device_id)
"First snapshot for device %s, no diff to generate", device_id
)
return return
old_snapshot_id = prev_row._mapping["id"] old_snapshot_id = prev_row._mapping["id"]
@@ -103,9 +101,7 @@ async def generate_and_store_diff(
# 6. Generate unified diff # 6. Generate unified diff
old_lines = old_plaintext.decode("utf-8").splitlines() old_lines = old_plaintext.decode("utf-8").splitlines()
new_lines = new_plaintext.decode("utf-8").splitlines() new_lines = new_plaintext.decode("utf-8").splitlines()
diff_lines = list( diff_lines = list(difflib.unified_diff(old_lines, new_lines, lineterm="", n=3))
difflib.unified_diff(old_lines, new_lines, lineterm="", n=3)
)
# 7. If empty diff, skip INSERT # 7. If empty diff, skip INSERT
diff_text = "\n".join(diff_lines) diff_text = "\n".join(diff_lines)
@@ -118,12 +114,10 @@ async def generate_and_store_diff(
# 8. Count lines added/removed (exclude +++ and --- headers) # 8. Count lines added/removed (exclude +++ and --- headers)
lines_added = sum( lines_added = sum(
1 for line in diff_lines 1 for line in diff_lines if line.startswith("+") and not line.startswith("++")
if line.startswith("+") and not line.startswith("++")
) )
lines_removed = sum( lines_removed = sum(
1 for line in diff_lines 1 for line in diff_lines if line.startswith("-") and not line.startswith("--")
if line.startswith("-") and not line.startswith("--")
) )
# 9. INSERT into router_config_diffs (RETURNING id for change parser) # 9. INSERT into router_config_diffs (RETURNING id for change parser)
@@ -165,6 +159,7 @@ async def generate_and_store_diff(
try: try:
from app.services.audit_service import log_action from app.services.audit_service import log_action
import uuid as _uuid import uuid as _uuid
await log_action( await log_action(
db=None, db=None,
tenant_id=_uuid.UUID(tenant_id), tenant_id=_uuid.UUID(tenant_id),
@@ -207,12 +202,16 @@ async def generate_and_store_diff(
await session.commit() await session.commit()
logger.info( logger.info(
"Stored %d config changes for device %s diff %s", "Stored %d config changes for device %s diff %s",
len(changes), device_id, diff_id, len(changes),
device_id,
diff_id,
) )
except Exception as exc: except Exception as exc:
logger.warning( logger.warning(
"Change parser error for device %s diff %s (non-fatal): %s", "Change parser error for device %s diff %s (non-fatal): %s",
device_id, diff_id, exc, device_id,
diff_id,
exc,
) )
config_diff_errors_total.labels(error_type="change_parser").inc() config_diff_errors_total.labels(error_type="change_parser").inc()

View File

@@ -89,9 +89,11 @@ async def handle_config_snapshot(msg) -> None:
collected_at_raw = data.get("collected_at") collected_at_raw = data.get("collected_at")
try: try:
collected_at = datetime.fromisoformat( collected_at = (
collected_at_raw.replace("Z", "+00:00") datetime.fromisoformat(collected_at_raw.replace("Z", "+00:00"))
) if collected_at_raw else datetime.now(timezone.utc) if collected_at_raw
else datetime.now(timezone.utc)
)
except (ValueError, AttributeError): except (ValueError, AttributeError):
collected_at = datetime.now(timezone.utc) collected_at = datetime.now(timezone.utc)
@@ -131,9 +133,7 @@ async def handle_config_snapshot(msg) -> None:
# --- Encrypt via OpenBao Transit --- # --- Encrypt via OpenBao Transit ---
openbao = OpenBaoTransitService() openbao = OpenBaoTransitService()
try: try:
encrypted_text = await openbao.encrypt( encrypted_text = await openbao.encrypt(tenant_id, config_text.encode("utf-8"))
tenant_id, config_text.encode("utf-8")
)
except Exception as exc: except Exception as exc:
logger.warning( logger.warning(
"Transit encrypt failed for device %s tenant %s: %s", "Transit encrypt failed for device %s tenant %s: %s",
@@ -207,7 +207,8 @@ async def handle_config_snapshot(msg) -> None:
except Exception as exc: except Exception as exc:
logger.warning( logger.warning(
"Diff generation failed for device %s (non-fatal): %s", "Diff generation failed for device %s (non-fatal): %s",
device_id, exc, device_id,
exc,
) )
logger.info( logger.info(
@@ -233,9 +234,7 @@ async def _subscribe_with_retry(js) -> None:
stream="DEVICE_EVENTS", stream="DEVICE_EVENTS",
manual_ack=True, manual_ack=True,
) )
logger.info( logger.info("NATS: subscribed to config.snapshot.> (durable: config_snapshot_ingest)")
"NATS: subscribed to config.snapshot.> (durable: config_snapshot_ingest)"
)
return return
except Exception as exc: except Exception as exc:
if attempt < max_attempts: if attempt < max_attempts:

View File

@@ -29,8 +29,6 @@ from app.models.device import (
DeviceTagAssignment, DeviceTagAssignment,
) )
from app.schemas.device import ( from app.schemas.device import (
BulkAddRequest,
BulkAddResult,
DeviceCreate, DeviceCreate,
DeviceGroupCreate, DeviceGroupCreate,
DeviceGroupResponse, DeviceGroupResponse,
@@ -43,9 +41,7 @@ from app.schemas.device import (
) )
from app.config import settings from app.config import settings
from app.services.crypto import ( from app.services.crypto import (
decrypt_credentials,
decrypt_credentials_hybrid, decrypt_credentials_hybrid,
encrypt_credentials,
encrypt_credentials_transit, encrypt_credentials_transit,
) )
@@ -58,9 +54,7 @@ from app.services.crypto import (
async def _tcp_reachable(ip: str, port: int, timeout: float = 3.0) -> bool: async def _tcp_reachable(ip: str, port: int, timeout: float = 3.0) -> bool:
"""Return True if a TCP connection to ip:port succeeds within timeout.""" """Return True if a TCP connection to ip:port succeeds within timeout."""
try: try:
_, writer = await asyncio.wait_for( _, writer = await asyncio.wait_for(asyncio.open_connection(ip, port), timeout=timeout)
asyncio.open_connection(ip, port), timeout=timeout
)
writer.close() writer.close()
try: try:
await writer.wait_closed() await writer.wait_closed()
@@ -151,6 +145,7 @@ async def create_device(
if not api_reachable and not ssl_reachable: if not api_reachable and not ssl_reachable:
from fastapi import HTTPException, status from fastapi import HTTPException, status
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=( detail=(
@@ -162,9 +157,7 @@ async def create_device(
# Encrypt credentials via OpenBao Transit (new writes go through Transit) # Encrypt credentials via OpenBao Transit (new writes go through Transit)
credentials_json = json.dumps({"username": data.username, "password": data.password}) credentials_json = json.dumps({"username": data.username, "password": data.password})
transit_ciphertext = await encrypt_credentials_transit( transit_ciphertext = await encrypt_credentials_transit(credentials_json, str(tenant_id))
credentials_json, str(tenant_id)
)
device = Device( device = Device(
tenant_id=tenant_id, tenant_id=tenant_id,
@@ -180,9 +173,7 @@ async def create_device(
await db.refresh(device) await db.refresh(device)
# Re-query with relationships loaded # Re-query with relationships loaded
result = await db.execute( result = await db.execute(_device_with_relations().where(Device.id == device.id))
_device_with_relations().where(Device.id == device.id)
)
device = result.scalar_one() device = result.scalar_one()
return _build_device_response(device) return _build_device_response(device)
@@ -223,9 +214,7 @@ async def get_devices(
if tag_id: if tag_id:
base_q = base_q.where( base_q = base_q.where(
Device.id.in_( Device.id.in_(
select(DeviceTagAssignment.device_id).where( select(DeviceTagAssignment.device_id).where(DeviceTagAssignment.tag_id == tag_id)
DeviceTagAssignment.tag_id == tag_id
)
) )
) )
@@ -274,9 +263,7 @@ async def get_device(
"""Get a single device by ID.""" """Get a single device by ID."""
from fastapi import HTTPException, status from fastapi import HTTPException, status
result = await db.execute( result = await db.execute(_device_with_relations().where(Device.id == device_id))
_device_with_relations().where(Device.id == device_id)
)
device = result.scalar_one_or_none() device = result.scalar_one_or_none()
if not device: if not device:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
@@ -295,9 +282,7 @@ async def update_device(
""" """
from fastapi import HTTPException, status from fastapi import HTTPException, status
result = await db.execute( result = await db.execute(_device_with_relations().where(Device.id == device_id))
_device_with_relations().where(Device.id == device_id)
)
device = result.scalar_one_or_none() device = result.scalar_one_or_none()
if not device: if not device:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Device not found")
@@ -323,7 +308,9 @@ async def update_device(
if data.password is not None: if data.password is not None:
# Decrypt existing to get current username if no new username given # Decrypt existing to get current username if no new username given
current_username: str = data.username or "" current_username: str = data.username or ""
if not current_username and (device.encrypted_credentials_transit or device.encrypted_credentials): if not current_username and (
device.encrypted_credentials_transit or device.encrypted_credentials
):
try: try:
existing_json = await decrypt_credentials_hybrid( existing_json = await decrypt_credentials_hybrid(
device.encrypted_credentials_transit, device.encrypted_credentials_transit,
@@ -336,17 +323,21 @@ async def update_device(
except Exception: except Exception:
current_username = "" current_username = ""
credentials_json = json.dumps({ credentials_json = json.dumps(
"username": data.username if data.username is not None else current_username, {
"password": data.password, "username": data.username if data.username is not None else current_username,
}) "password": data.password,
}
)
# New writes go through Transit # New writes go through Transit
device.encrypted_credentials_transit = await encrypt_credentials_transit( device.encrypted_credentials_transit = await encrypt_credentials_transit(
credentials_json, str(device.tenant_id) credentials_json, str(device.tenant_id)
) )
device.encrypted_credentials = None # Clear legacy (Transit is canonical) device.encrypted_credentials = None # Clear legacy (Transit is canonical)
credentials_changed = True credentials_changed = True
elif data.username is not None and (device.encrypted_credentials_transit or device.encrypted_credentials): elif data.username is not None and (
device.encrypted_credentials_transit or device.encrypted_credentials
):
# Only username changed — update it without changing the password # Only username changed — update it without changing the password
try: try:
existing_json = await decrypt_credentials_hybrid( existing_json = await decrypt_credentials_hybrid(
@@ -373,6 +364,7 @@ async def update_device(
if credentials_changed: if credentials_changed:
try: try:
from app.services.event_publisher import publish_event from app.services.event_publisher import publish_event
await publish_event( await publish_event(
f"device.credential_changed.{device_id}", f"device.credential_changed.{device_id}",
{"device_id": str(device_id), "tenant_id": str(tenant_id)}, {"device_id": str(device_id), "tenant_id": str(tenant_id)},
@@ -380,9 +372,7 @@ async def update_device(
except Exception: except Exception:
pass # Never fail the update due to NATS issues pass # Never fail the update due to NATS issues
result2 = await db.execute( result2 = await db.execute(_device_with_relations().where(Device.id == device_id))
_device_with_relations().where(Device.id == device_id)
)
device = result2.scalar_one() device = result2.scalar_one()
return _build_device_response(device) return _build_device_response(device)
@@ -526,11 +516,7 @@ async def get_groups(
tenant_id: uuid.UUID, tenant_id: uuid.UUID,
) -> list[DeviceGroupResponse]: ) -> list[DeviceGroupResponse]:
"""Return all device groups for the current tenant with device counts.""" """Return all device groups for the current tenant with device counts."""
result = await db.execute( result = await db.execute(select(DeviceGroup).options(selectinload(DeviceGroup.memberships)))
select(DeviceGroup).options(
selectinload(DeviceGroup.memberships)
)
)
groups = result.scalars().all() groups = result.scalars().all()
return [ return [
DeviceGroupResponse( DeviceGroupResponse(
@@ -554,9 +540,9 @@ async def update_group(
from fastapi import HTTPException, status from fastapi import HTTPException, status
result = await db.execute( result = await db.execute(
select(DeviceGroup).options( select(DeviceGroup)
selectinload(DeviceGroup.memberships) .options(selectinload(DeviceGroup.memberships))
).where(DeviceGroup.id == group_id) .where(DeviceGroup.id == group_id)
) )
group = result.scalar_one_or_none() group = result.scalar_one_or_none()
if not group: if not group:
@@ -571,9 +557,9 @@ async def update_group(
await db.refresh(group) await db.refresh(group)
result2 = await db.execute( result2 = await db.execute(
select(DeviceGroup).options( select(DeviceGroup)
selectinload(DeviceGroup.memberships) .options(selectinload(DeviceGroup.memberships))
).where(DeviceGroup.id == group_id) .where(DeviceGroup.id == group_id)
) )
group = result2.scalar_one() group = result2.scalar_one()
return DeviceGroupResponse( return DeviceGroupResponse(

View File

@@ -48,7 +48,5 @@ async def generate_emergency_kit_template(
# Run weasyprint in thread to avoid blocking the event loop # Run weasyprint in thread to avoid blocking the event loop
from weasyprint import HTML from weasyprint import HTML
pdf_bytes = await asyncio.to_thread( pdf_bytes = await asyncio.to_thread(lambda: HTML(string=html_content).write_pdf())
lambda: HTML(string=html_content).write_pdf()
)
return pdf_bytes return pdf_bytes

View File

@@ -13,7 +13,6 @@ Version discovery comes from two sources:
""" """
import logging import logging
import os
from pathlib import Path from pathlib import Path
import httpx import httpx
@@ -51,13 +50,12 @@ async def check_latest_versions() -> list[dict]:
async with httpx.AsyncClient(timeout=30.0) as client: async with httpx.AsyncClient(timeout=30.0) as client:
for channel_file, channel, major in _VERSION_SOURCES: for channel_file, channel, major in _VERSION_SOURCES:
try: try:
resp = await client.get( resp = await client.get(f"https://download.mikrotik.com/routeros/{channel_file}")
f"https://download.mikrotik.com/routeros/{channel_file}"
)
if resp.status_code != 200: if resp.status_code != 200:
logger.warning( logger.warning(
"MikroTik version check returned %d for %s", "MikroTik version check returned %d for %s",
resp.status_code, channel_file, resp.status_code,
channel_file,
) )
continue continue
@@ -72,12 +70,14 @@ async def check_latest_versions() -> list[dict]:
f"https://download.mikrotik.com/routeros/" f"https://download.mikrotik.com/routeros/"
f"{version}/routeros-{version}-{arch}.npk" f"{version}/routeros-{version}-{arch}.npk"
) )
results.append({ results.append(
"architecture": arch, {
"channel": channel, "architecture": arch,
"version": version, "channel": channel,
"npk_url": npk_url, "version": version,
}) "npk_url": npk_url,
}
)
except Exception as e: except Exception as e:
logger.warning("Failed to check %s: %s", channel_file, e) logger.warning("Failed to check %s: %s", channel_file, e)
@@ -242,15 +242,15 @@ async def get_firmware_overview(tenant_id: str) -> dict:
groups = [] groups = []
for ver, devs in sorted(version_groups.items()): for ver, devs in sorted(version_groups.items()):
# A version is "latest" if it matches the latest for any arch/channel combo # A version is "latest" if it matches the latest for any arch/channel combo
is_latest = any( is_latest = any(v["version"] == ver for v in latest_versions.values())
v["version"] == ver for v in latest_versions.values() groups.append(
{
"version": ver,
"count": len(devs),
"is_latest": is_latest,
"devices": devs,
}
) )
groups.append({
"version": ver,
"count": len(devs),
"is_latest": is_latest,
"devices": devs,
})
return { return {
"devices": device_list, "devices": device_list,
@@ -272,12 +272,14 @@ async def get_cached_firmware() -> list[dict]:
continue continue
for npk_file in sorted(version_dir.iterdir()): for npk_file in sorted(version_dir.iterdir()):
if npk_file.suffix == ".npk": if npk_file.suffix == ".npk":
cached.append({ cached.append(
"path": str(npk_file), {
"version": version_dir.name, "path": str(npk_file),
"filename": npk_file.name, "version": version_dir.name,
"size_bytes": npk_file.stat().st_size, "filename": npk_file.name,
}) "size_bytes": npk_file.stat().st_size,
}
)
return cached return cached

View File

@@ -41,7 +41,7 @@ async def on_device_firmware(msg) -> None:
try: try:
data = json.loads(msg.data) data = json.loads(msg.data)
device_id = data.get("device_id") device_id = data.get("device_id")
tenant_id = data.get("tenant_id") _tenant_id = data.get("tenant_id")
architecture = data.get("architecture") architecture = data.get("architecture")
installed_version = data.get("installed_version") installed_version = data.get("installed_version")
latest_version = data.get("latest_version") latest_version = data.get("latest_version")
@@ -126,9 +126,7 @@ async def _subscribe_with_retry(js: JetStreamContext) -> None:
durable="api-firmware-consumer", durable="api-firmware-consumer",
stream="DEVICE_EVENTS", stream="DEVICE_EVENTS",
) )
logger.info( logger.info("NATS: subscribed to device.firmware.> (durable: api-firmware-consumer)")
"NATS: subscribed to device.firmware.> (durable: api-firmware-consumer)"
)
return return
except Exception as exc: except Exception as exc:
if attempt < max_attempts: if attempt < max_attempts:

View File

@@ -238,11 +238,13 @@ def list_device_commits(
# Device wasn't in parent but is in this commit — it's the first entry. # Device wasn't in parent but is in this commit — it's the first entry.
pass pass
results.append({ results.append(
"sha": str(commit.id), {
"message": commit.message.strip(), "sha": str(commit.id),
"timestamp": commit.commit_time, "message": commit.message.strip(),
}) "timestamp": commit.commit_time,
}
)
return results return results

View File

@@ -52,6 +52,7 @@ async def store_user_key_set(
""" """
# Remove any existing key set (e.g. from a failed prior upgrade attempt) # Remove any existing key set (e.g. from a failed prior upgrade attempt)
from sqlalchemy import delete from sqlalchemy import delete
await db.execute(delete(UserKeySet).where(UserKeySet.user_id == user_id)) await db.execute(delete(UserKeySet).where(UserKeySet.user_id == user_id))
key_set = UserKeySet( key_set = UserKeySet(
@@ -71,9 +72,7 @@ async def store_user_key_set(
return key_set return key_set
async def get_user_key_set( async def get_user_key_set(db: AsyncSession, user_id: UUID) -> UserKeySet | None:
db: AsyncSession, user_id: UUID
) -> UserKeySet | None:
"""Retrieve encrypted key bundle for login response. """Retrieve encrypted key bundle for login response.
Args: Args:
@@ -83,9 +82,7 @@ async def get_user_key_set(
Returns: Returns:
The UserKeySet if found, None otherwise. The UserKeySet if found, None otherwise.
""" """
result = await db.execute( result = await db.execute(select(UserKeySet).where(UserKeySet.user_id == user_id))
select(UserKeySet).where(UserKeySet.user_id == user_id)
)
return result.scalar_one_or_none() return result.scalar_one_or_none()
@@ -163,9 +160,7 @@ async def provision_tenant_key(db: AsyncSession, tenant_id: UUID) -> str:
await openbao.create_tenant_key(str(tenant_id)) await openbao.create_tenant_key(str(tenant_id))
# Update tenant record with key name # Update tenant record with key name
result = await db.execute( result = await db.execute(select(Tenant).where(Tenant.id == tenant_id))
select(Tenant).where(Tenant.id == tenant_id)
)
tenant = result.scalar_one_or_none() tenant = result.scalar_one_or_none()
if tenant: if tenant:
tenant.openbao_key_name = key_name tenant.openbao_key_name = key_name
@@ -210,13 +205,18 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
select(Device).where( select(Device).where(
Device.tenant_id == tenant_id, Device.tenant_id == tenant_id,
Device.encrypted_credentials.isnot(None), Device.encrypted_credentials.isnot(None),
(Device.encrypted_credentials_transit.is_(None) | (Device.encrypted_credentials_transit == "")), (
Device.encrypted_credentials_transit.is_(None)
| (Device.encrypted_credentials_transit == "")
),
) )
) )
for device in result.scalars().all(): for device in result.scalars().all():
try: try:
plaintext = decrypt_credentials(device.encrypted_credentials, legacy_key) plaintext = decrypt_credentials(device.encrypted_credentials, legacy_key)
device.encrypted_credentials_transit = await openbao.encrypt(tid, plaintext.encode("utf-8")) device.encrypted_credentials_transit = await openbao.encrypt(
tid, plaintext.encode("utf-8")
)
counts["devices"] += 1 counts["devices"] += 1
except Exception as e: except Exception as e:
logger.error("Failed to migrate device %s credentials: %s", device.id, e) logger.error("Failed to migrate device %s credentials: %s", device.id, e)
@@ -227,7 +227,10 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
select(CertificateAuthority).where( select(CertificateAuthority).where(
CertificateAuthority.tenant_id == tenant_id, CertificateAuthority.tenant_id == tenant_id,
CertificateAuthority.encrypted_private_key.isnot(None), CertificateAuthority.encrypted_private_key.isnot(None),
(CertificateAuthority.encrypted_private_key_transit.is_(None) | (CertificateAuthority.encrypted_private_key_transit == "")), (
CertificateAuthority.encrypted_private_key_transit.is_(None)
| (CertificateAuthority.encrypted_private_key_transit == "")
),
) )
) )
for ca in result.scalars().all(): for ca in result.scalars().all():
@@ -244,13 +247,18 @@ async def migrate_tenant_credentials(db: AsyncSession, tenant_id: UUID) -> dict:
select(DeviceCertificate).where( select(DeviceCertificate).where(
DeviceCertificate.tenant_id == tenant_id, DeviceCertificate.tenant_id == tenant_id,
DeviceCertificate.encrypted_private_key.isnot(None), DeviceCertificate.encrypted_private_key.isnot(None),
(DeviceCertificate.encrypted_private_key_transit.is_(None) | (DeviceCertificate.encrypted_private_key_transit == "")), (
DeviceCertificate.encrypted_private_key_transit.is_(None)
| (DeviceCertificate.encrypted_private_key_transit == "")
),
) )
) )
for cert in result.scalars().all(): for cert in result.scalars().all():
try: try:
plaintext = decrypt_credentials(cert.encrypted_private_key, legacy_key) plaintext = decrypt_credentials(cert.encrypted_private_key, legacy_key)
cert.encrypted_private_key_transit = await openbao.encrypt(tid, plaintext.encode("utf-8")) cert.encrypted_private_key_transit = await openbao.encrypt(
tid, plaintext.encode("utf-8")
)
counts["certs"] += 1 counts["certs"] += 1
except Exception as e: except Exception as e:
logger.error("Failed to migrate cert %s private key: %s", cert.id, e) logger.error("Failed to migrate cert %s private key: %s", cert.id, e)
@@ -303,7 +311,14 @@ async def provision_existing_tenants(db: AsyncSession) -> dict:
result = await db.execute(select(Tenant)) result = await db.execute(select(Tenant))
tenants = result.scalars().all() tenants = result.scalars().all()
total = {"tenants": len(tenants), "devices": 0, "cas": 0, "certs": 0, "channels": 0, "errors": 0} total = {
"tenants": len(tenants),
"devices": 0,
"cas": 0,
"certs": 0,
"channels": 0,
"errors": 0,
}
for tenant in tenants: for tenant in tenants:
try: try:

View File

@@ -199,9 +199,7 @@ async def on_device_metrics(msg) -> None:
device_id = data.get("device_id") device_id = data.get("device_id")
if not metric_type or not device_id: if not metric_type or not device_id:
logger.warning( logger.warning("device.metrics event missing 'type' or 'device_id' — skipping")
"device.metrics event missing 'type' or 'device_id' — skipping"
)
await msg.ack() await msg.ack()
return return
@@ -222,6 +220,7 @@ async def on_device_metrics(msg) -> None:
# Alert evaluation — non-fatal; metric write is the primary operation # Alert evaluation — non-fatal; metric write is the primary operation
try: try:
from app.services import alert_evaluator from app.services import alert_evaluator
await alert_evaluator.evaluate( await alert_evaluator.evaluate(
device_id=device_id, device_id=device_id,
tenant_id=data.get("tenant_id", ""), tenant_id=data.get("tenant_id", ""),
@@ -265,9 +264,7 @@ async def _subscribe_with_retry(js: JetStreamContext) -> None:
durable="api-metrics-consumer", durable="api-metrics-consumer",
stream="DEVICE_EVENTS", stream="DEVICE_EVENTS",
) )
logger.info( logger.info("NATS: subscribed to device.metrics.> (durable: api-metrics-consumer)")
"NATS: subscribed to device.metrics.> (durable: api-metrics-consumer)"
)
return return
except Exception as exc: except Exception as exc:
if attempt < max_attempts: if attempt < max_attempts:

View File

@@ -69,7 +69,9 @@ async def on_device_status(msg) -> None:
uptime_seconds = _parse_uptime(data.get("uptime", "")) uptime_seconds = _parse_uptime(data.get("uptime", ""))
if not device_id or not status: if not device_id or not status:
logger.warning("Received device.status event with missing device_id or status — skipping") logger.warning(
"Received device.status event with missing device_id or status — skipping"
)
await msg.ack() await msg.ack()
return return
@@ -115,12 +117,15 @@ async def on_device_status(msg) -> None:
# Alert evaluation for offline/online status changes — non-fatal # Alert evaluation for offline/online status changes — non-fatal
try: try:
from app.services import alert_evaluator from app.services import alert_evaluator
if status == "offline": if status == "offline":
await alert_evaluator.evaluate_offline(device_id, data.get("tenant_id", "")) await alert_evaluator.evaluate_offline(device_id, data.get("tenant_id", ""))
elif status == "online": elif status == "online":
await alert_evaluator.evaluate_online(device_id, data.get("tenant_id", "")) await alert_evaluator.evaluate_online(device_id, data.get("tenant_id", ""))
except Exception as e: except Exception as e:
logger.warning("Alert evaluation failed for device %s status=%s: %s", device_id, status, e) logger.warning(
"Alert evaluation failed for device %s status=%s: %s", device_id, status, e
)
logger.info( logger.info(
"Device status updated", "Device status updated",

View File

@@ -39,7 +39,9 @@ async def dispatch_notifications(
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"Notification delivery failed for channel %s (%s): %s", "Notification delivery failed for channel %s (%s): %s",
channel.get("name"), channel.get("channel_type"), e, channel.get("name"),
channel.get("channel_type"),
e,
) )
@@ -100,14 +102,18 @@ async def _send_email(channel: dict, alert_event: dict, device_hostname: str) ->
if transit_cipher and tenant_id: if transit_cipher and tenant_id:
try: try:
from app.services.kms_service import decrypt_transit from app.services.kms_service import decrypt_transit
smtp_password = await decrypt_transit(transit_cipher, tenant_id) smtp_password = await decrypt_transit(transit_cipher, tenant_id)
except Exception: except Exception:
logger.warning("Transit decryption failed for channel %s, trying legacy", channel.get("id")) logger.warning(
"Transit decryption failed for channel %s, trying legacy", channel.get("id")
)
if not smtp_password and legacy_cipher: if not smtp_password and legacy_cipher:
try: try:
from app.config import settings as app_settings from app.config import settings as app_settings
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
raw = bytes(legacy_cipher) if isinstance(legacy_cipher, memoryview) else legacy_cipher raw = bytes(legacy_cipher) if isinstance(legacy_cipher, memoryview) else legacy_cipher
f = Fernet(app_settings.CREDENTIAL_ENCRYPTION_KEY.encode()) f = Fernet(app_settings.CREDENTIAL_ENCRYPTION_KEY.encode())
smtp_password = f.decrypt(raw).decode() smtp_password = f.decrypt(raw).decode()
@@ -163,7 +169,8 @@ async def _send_webhook(
response = await client.post(webhook_url, json=payload) response = await client.post(webhook_url, json=payload)
logger.info( logger.info(
"Webhook notification sent to %s — status %d", "Webhook notification sent to %s — status %d",
webhook_url, response.status_code, webhook_url,
response.status_code,
) )
@@ -180,13 +187,18 @@ async def _send_slack(
value = alert_event.get("value") value = alert_event.get("value")
threshold = alert_event.get("threshold") threshold = alert_event.get("threshold")
color = {"CRITICAL": "#dc2626", "WARNING": "#f59e0b", "INFO": "#3b82f6"}.get(severity, "#6b7280") color = {"CRITICAL": "#dc2626", "WARNING": "#f59e0b", "INFO": "#3b82f6"}.get(
severity, "#6b7280"
)
status_label = "RESOLVED" if status == "resolved" else status status_label = "RESOLVED" if status == "resolved" else status
blocks = [ blocks = [
{ {
"type": "header", "type": "header",
"text": {"type": "plain_text", "text": f"{'' if status == 'resolved' else '🚨'} [{severity}] {status_label.upper()}"}, "text": {
"type": "plain_text",
"text": f"{'' if status == 'resolved' else '🚨'} [{severity}] {status_label.upper()}",
},
}, },
{ {
"type": "section", "type": "section",
@@ -205,7 +217,9 @@ async def _send_slack(
blocks.append({"type": "section", "fields": fields}) blocks.append({"type": "section", "fields": fields})
if message_text: if message_text:
blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": f"*Message:*\n{message_text}"}}) blocks.append(
{"type": "section", "text": {"type": "mrkdwn", "text": f"*Message:*\n{message_text}"}}
)
blocks.append({"type": "context", "elements": [{"type": "mrkdwn", "text": "TOD Alert System"}]}) blocks.append({"type": "context", "elements": [{"type": "mrkdwn", "text": "TOD Alert System"}]})

View File

@@ -7,6 +7,7 @@ material never leaves OpenBao -- the application only sees ciphertext.
Ciphertext format: "vault:v1:base64..." (compatible with Vault Transit format) Ciphertext format: "vault:v1:base64..." (compatible with Vault Transit format)
""" """
import base64 import base64
import logging import logging
from typing import Optional from typing import Optional

View File

@@ -47,13 +47,15 @@ async def record_push(
""" """
r = await _get_redis() r = await _get_redis()
key = f"push:recent:{device_id}" key = f"push:recent:{device_id}"
value = json.dumps({ value = json.dumps(
"device_id": device_id, {
"tenant_id": tenant_id, "device_id": device_id,
"push_type": push_type, "tenant_id": tenant_id,
"push_operation_id": push_operation_id, "push_type": push_type,
"pre_push_commit_sha": pre_push_commit_sha, "push_operation_id": push_operation_id,
}) "pre_push_commit_sha": pre_push_commit_sha,
}
)
await r.set(key, value, ex=PUSH_TTL_SECONDS) await r.set(key, value, ex=PUSH_TTL_SECONDS)
logger.debug( logger.debug(
"Recorded push for device %s (type=%s, TTL=%ds)", "Recorded push for device %s (type=%s, TTL=%ds)",

View File

@@ -164,16 +164,18 @@ async def _device_inventory(
uptime_str = _format_uptime(row[6]) if row[6] else None uptime_str = _format_uptime(row[6]) if row[6] else None
last_seen_str = row[5].strftime("%Y-%m-%d %H:%M") if row[5] else None last_seen_str = row[5].strftime("%Y-%m-%d %H:%M") if row[5] else None
devices.append({ devices.append(
"hostname": row[0], {
"ip_address": row[1], "hostname": row[0],
"model": row[2], "ip_address": row[1],
"routeros_version": row[3], "model": row[2],
"status": status, "routeros_version": row[3],
"last_seen": last_seen_str, "status": status,
"uptime": uptime_str, "last_seen": last_seen_str,
"groups": row[7] if row[7] else None, "uptime": uptime_str,
}) "groups": row[7] if row[7] else None,
}
)
return { return {
"report_title": "Device Inventory", "report_title": "Device Inventory",
@@ -224,16 +226,18 @@ async def _metrics_summary(
devices = [] devices = []
for row in rows: for row in rows:
devices.append({ devices.append(
"hostname": row[0], {
"avg_cpu": float(row[1]) if row[1] is not None else None, "hostname": row[0],
"peak_cpu": float(row[2]) if row[2] is not None else None, "avg_cpu": float(row[1]) if row[1] is not None else None,
"avg_mem": float(row[3]) if row[3] is not None else None, "peak_cpu": float(row[2]) if row[2] is not None else None,
"peak_mem": float(row[4]) if row[4] is not None else None, "avg_mem": float(row[3]) if row[3] is not None else None,
"avg_disk": float(row[5]) if row[5] is not None else None, "peak_mem": float(row[4]) if row[4] is not None else None,
"avg_temp": float(row[6]) if row[6] is not None else None, "avg_disk": float(row[5]) if row[5] is not None else None,
"data_points": row[7], "avg_temp": float(row[6]) if row[6] is not None else None,
}) "data_points": row[7],
}
)
return { return {
"report_title": "Metrics Summary", "report_title": "Metrics Summary",
@@ -287,14 +291,16 @@ async def _alert_history(
if duration_secs is not None: if duration_secs is not None:
resolved_durations.append(duration_secs) resolved_durations.append(duration_secs)
alerts.append({ alerts.append(
"fired_at": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-", {
"hostname": row[5], "fired_at": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"severity": severity, "hostname": row[5],
"status": row[3], "severity": severity,
"message": row[4], "status": row[3],
"duration": _format_duration(duration_secs) if duration_secs is not None else None, "message": row[4],
}) "duration": _format_duration(duration_secs) if duration_secs is not None else None,
}
)
mttr_minutes = None mttr_minutes = None
mttr_display = None mttr_display = None
@@ -374,13 +380,15 @@ async def _change_log_from_audit(
entries = [] entries = []
for row in rows: for row in rows:
entries.append({ entries.append(
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-", {
"user": row[1], "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"action": row[2], "user": row[1],
"device": row[3], "action": row[2],
"details": row[4] or row[5] or "", "device": row[3],
}) "details": row[4] or row[5] or "",
}
)
return { return {
"report_title": "Change Log", "report_title": "Change Log",
@@ -436,21 +444,25 @@ async def _change_log_from_backups(
# Merge and sort by timestamp descending # Merge and sort by timestamp descending
entries = [] entries = []
for row in backup_rows: for row in backup_rows:
entries.append({ entries.append(
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-", {
"user": row[1], "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"action": row[2], "user": row[1],
"device": row[3], "action": row[2],
"details": row[4] or "", "device": row[3],
}) "details": row[4] or "",
}
)
for row in alert_rows: for row in alert_rows:
entries.append({ entries.append(
"timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-", {
"user": row[1], "timestamp": row[0].strftime("%Y-%m-%d %H:%M") if row[0] else "-",
"action": row[2], "user": row[1],
"device": row[3], "action": row[2],
"details": row[4] or "", "device": row[3],
}) "details": row[4] or "",
}
)
# Sort by timestamp string descending # Sort by timestamp string descending
entries.sort(key=lambda e: e["timestamp"], reverse=True) entries.sort(key=lambda e: e["timestamp"], reverse=True)
@@ -486,54 +498,102 @@ def _render_csv(report_type: str, data: dict[str, Any]) -> bytes:
writer = csv.writer(output) writer = csv.writer(output)
if report_type == "device_inventory": if report_type == "device_inventory":
writer.writerow([ writer.writerow(
"Hostname", "IP Address", "Model", "RouterOS Version", [
"Status", "Last Seen", "Uptime", "Groups", "Hostname",
]) "IP Address",
"Model",
"RouterOS Version",
"Status",
"Last Seen",
"Uptime",
"Groups",
]
)
for d in data.get("devices", []): for d in data.get("devices", []):
writer.writerow([ writer.writerow(
d["hostname"], d["ip_address"], d["model"] or "", [
d["routeros_version"] or "", d["status"], d["hostname"],
d["last_seen"] or "", d["uptime"] or "", d["ip_address"],
d["groups"] or "", d["model"] or "",
]) d["routeros_version"] or "",
d["status"],
d["last_seen"] or "",
d["uptime"] or "",
d["groups"] or "",
]
)
elif report_type == "metrics_summary": elif report_type == "metrics_summary":
writer.writerow([ writer.writerow(
"Hostname", "Avg CPU %", "Peak CPU %", "Avg Memory %", [
"Peak Memory %", "Avg Disk %", "Avg Temp", "Data Points", "Hostname",
]) "Avg CPU %",
"Peak CPU %",
"Avg Memory %",
"Peak Memory %",
"Avg Disk %",
"Avg Temp",
"Data Points",
]
)
for d in data.get("devices", []): for d in data.get("devices", []):
writer.writerow([ writer.writerow(
d["hostname"], [
f"{d['avg_cpu']:.1f}" if d["avg_cpu"] is not None else "", d["hostname"],
f"{d['peak_cpu']:.1f}" if d["peak_cpu"] is not None else "", f"{d['avg_cpu']:.1f}" if d["avg_cpu"] is not None else "",
f"{d['avg_mem']:.1f}" if d["avg_mem"] is not None else "", f"{d['peak_cpu']:.1f}" if d["peak_cpu"] is not None else "",
f"{d['peak_mem']:.1f}" if d["peak_mem"] is not None else "", f"{d['avg_mem']:.1f}" if d["avg_mem"] is not None else "",
f"{d['avg_disk']:.1f}" if d["avg_disk"] is not None else "", f"{d['peak_mem']:.1f}" if d["peak_mem"] is not None else "",
f"{d['avg_temp']:.1f}" if d["avg_temp"] is not None else "", f"{d['avg_disk']:.1f}" if d["avg_disk"] is not None else "",
d["data_points"], f"{d['avg_temp']:.1f}" if d["avg_temp"] is not None else "",
]) d["data_points"],
]
)
elif report_type == "alert_history": elif report_type == "alert_history":
writer.writerow([ writer.writerow(
"Timestamp", "Device", "Severity", "Message", "Status", "Duration", [
]) "Timestamp",
"Device",
"Severity",
"Message",
"Status",
"Duration",
]
)
for a in data.get("alerts", []): for a in data.get("alerts", []):
writer.writerow([ writer.writerow(
a["fired_at"], a["hostname"] or "", a["severity"], [
a["message"] or "", a["status"], a["duration"] or "", a["fired_at"],
]) a["hostname"] or "",
a["severity"],
a["message"] or "",
a["status"],
a["duration"] or "",
]
)
elif report_type == "change_log": elif report_type == "change_log":
writer.writerow([ writer.writerow(
"Timestamp", "User", "Action", "Device", "Details", [
]) "Timestamp",
"User",
"Action",
"Device",
"Details",
]
)
for e in data.get("entries", []): for e in data.get("entries", []):
writer.writerow([ writer.writerow(
e["timestamp"], e["user"] or "", e["action"], [
e["device"] or "", e["details"] or "", e["timestamp"],
]) e["user"] or "",
e["action"],
e["device"] or "",
e["details"] or "",
]
)
return output.getvalue().encode("utf-8") return output.getvalue().encode("utf-8")

View File

@@ -33,7 +33,6 @@ import asyncssh
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings from app.config import settings
from app.database import set_tenant_context, AdminAsyncSessionLocal
from app.models.config_backup import ConfigPushOperation from app.models.config_backup import ConfigPushOperation
from app.models.device import Device from app.models.device import Device
from app.services import backup_service, git_store from app.services import backup_service, git_store
@@ -113,12 +112,11 @@ async def restore_config(
raise ValueError(f"Device {device_id!r} not found") raise ValueError(f"Device {device_id!r} not found")
if not device.encrypted_credentials_transit and not device.encrypted_credentials: if not device.encrypted_credentials_transit and not device.encrypted_credentials:
raise ValueError( raise ValueError(f"Device {device_id!r} has no stored credentials — cannot perform restore")
f"Device {device_id!r} has no stored credentials — cannot perform restore"
)
key = settings.get_encryption_key_bytes() key = settings.get_encryption_key_bytes()
from app.services.crypto import decrypt_credentials_hybrid from app.services.crypto import decrypt_credentials_hybrid
creds_json = await decrypt_credentials_hybrid( creds_json = await decrypt_credentials_hybrid(
device.encrypted_credentials_transit, device.encrypted_credentials_transit,
device.encrypted_credentials, device.encrypted_credentials,
@@ -133,7 +131,9 @@ async def restore_config(
hostname = device.hostname or ip hostname = device.hostname or ip
# Publish "started" progress event # Publish "started" progress event
await _publish_push_progress(tenant_id, device_id, "started", f"Config restore started for {hostname}") await _publish_push_progress(
tenant_id, device_id, "started", f"Config restore started for {hostname}"
)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Step 2: Read the target export.rsc from the backup commit # Step 2: Read the target export.rsc from the backup commit
@@ -157,7 +157,9 @@ async def restore_config(
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Step 3: Mandatory pre-backup before push # Step 3: Mandatory pre-backup before push
# ------------------------------------------------------------------ # ------------------------------------------------------------------
await _publish_push_progress(tenant_id, device_id, "backing_up", f"Creating pre-restore backup for {hostname}") await _publish_push_progress(
tenant_id, device_id, "backing_up", f"Creating pre-restore backup for {hostname}"
)
logger.info( logger.info(
"Starting pre-restore backup for device %s (%s) before pushing commit %s", "Starting pre-restore backup for device %s (%s) before pushing commit %s",
@@ -198,7 +200,9 @@ async def restore_config(
# Step 5: SSH to device — install panic-revert, push config # Step 5: SSH to device — install panic-revert, push config
# ------------------------------------------------------------------ # ------------------------------------------------------------------
push_op_id_str = str(push_op_id) push_op_id_str = str(push_op_id)
await _publish_push_progress(tenant_id, device_id, "pushing", f"Pushing config to {hostname}", push_op_id=push_op_id_str) await _publish_push_progress(
tenant_id, device_id, "pushing", f"Pushing config to {hostname}", push_op_id=push_op_id_str
)
logger.info( logger.info(
"Pushing config to device %s (%s): installing panic-revert scheduler and uploading config", "Pushing config to device %s (%s): installing panic-revert scheduler and uploading config",
@@ -290,9 +294,12 @@ async def restore_config(
# Update push operation to failed # Update push operation to failed
await _update_push_op_status(push_op_id, "failed", db_session) await _update_push_op_status(push_op_id, "failed", db_session)
await _publish_push_progress( await _publish_push_progress(
tenant_id, device_id, "failed", tenant_id,
device_id,
"failed",
f"Config push failed for {hostname}: {push_err}", f"Config push failed for {hostname}: {push_err}",
push_op_id=push_op_id_str, error=str(push_err), push_op_id=push_op_id_str,
error=str(push_err),
) )
return { return {
"status": "failed", "status": "failed",
@@ -312,7 +319,13 @@ async def restore_config(
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Step 6: Wait 60s for config to settle # Step 6: Wait 60s for config to settle
# ------------------------------------------------------------------ # ------------------------------------------------------------------
await _publish_push_progress(tenant_id, device_id, "settling", f"Config pushed to {hostname} — waiting 60s for settle", push_op_id=push_op_id_str) await _publish_push_progress(
tenant_id,
device_id,
"settling",
f"Config pushed to {hostname} — waiting 60s for settle",
push_op_id=push_op_id_str,
)
logger.info( logger.info(
"Config pushed to device %s — waiting 60s for config to settle", "Config pushed to device %s — waiting 60s for config to settle",
@@ -323,7 +336,13 @@ async def restore_config(
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Step 7: Reachability check # Step 7: Reachability check
# ------------------------------------------------------------------ # ------------------------------------------------------------------
await _publish_push_progress(tenant_id, device_id, "verifying", f"Verifying device {hostname} reachability", push_op_id=push_op_id_str) await _publish_push_progress(
tenant_id,
device_id,
"verifying",
f"Verifying device {hostname} reachability",
push_op_id=push_op_id_str,
)
reachable = await _check_reachability(ip, ssh_username, ssh_password) reachable = await _check_reachability(ip, ssh_username, ssh_password)
@@ -362,7 +381,13 @@ async def restore_config(
await _update_push_op_status(push_op_id, "committed", db_session) await _update_push_op_status(push_op_id, "committed", db_session)
await clear_push(device_id) await clear_push(device_id)
await _publish_push_progress(tenant_id, device_id, "committed", f"Config restored successfully on {hostname}", push_op_id=push_op_id_str) await _publish_push_progress(
tenant_id,
device_id,
"committed",
f"Config restored successfully on {hostname}",
push_op_id=push_op_id_str,
)
return { return {
"status": "committed", "status": "committed",
@@ -384,7 +409,9 @@ async def restore_config(
await _update_push_op_status(push_op_id, "reverted", db_session) await _update_push_op_status(push_op_id, "reverted", db_session)
await _publish_push_progress( await _publish_push_progress(
tenant_id, device_id, "reverted", tenant_id,
device_id,
"reverted",
f"Device {hostname} unreachable — auto-reverting via panic-revert scheduler", f"Device {hostname} unreachable — auto-reverting via panic-revert scheduler",
push_op_id=push_op_id_str, push_op_id=push_op_id_str,
) )
@@ -446,7 +473,7 @@ async def _update_push_op_status(
new_status: New status value ('committed' | 'reverted' | 'failed'). new_status: New status value ('committed' | 'reverted' | 'failed').
db_session: Database session (must already have tenant context set). db_session: Database session (must already have tenant context set).
""" """
from sqlalchemy import select, update from sqlalchemy import update
await db_session.execute( await db_session.execute(
update(ConfigPushOperation) update(ConfigPushOperation)
@@ -526,9 +553,7 @@ async def recover_stale_push_operations(db_session: AsyncSession) -> None:
for op in stale_ops: for op in stale_ops:
try: try:
# Load device # Load device
dev_result = await db_session.execute( dev_result = await db_session.execute(select(Device).where(Device.id == op.device_id))
select(Device).where(Device.id == op.device_id)
)
device = dev_result.scalar_one_or_none() device = dev_result.scalar_one_or_none()
if not device: if not device:
logger.error("Device %s not found for stale op %s", op.device_id, op.id) logger.error("Device %s not found for stale op %s", op.device_id, op.id)
@@ -547,9 +572,7 @@ async def recover_stale_push_operations(db_session: AsyncSession) -> None:
ssh_password = creds.get("password", "") ssh_password = creds.get("password", "")
# Check reachability # Check reachability
reachable = await _check_reachability( reachable = await _check_reachability(device.ip_address, ssh_username, ssh_password)
device.ip_address, ssh_username, ssh_password
)
if reachable: if reachable:
# Try to remove scheduler (if still there, push was good) # Try to remove scheduler (if still there, push was good)

View File

@@ -53,7 +53,9 @@ async def cleanup_expired_snapshots() -> int:
deleted = result.rowcount deleted = result.rowcount
config_snapshots_cleaned_total.inc(deleted) config_snapshots_cleaned_total.inc(deleted)
logger.info("retention cleanup complete", extra={"deleted_snapshots": deleted, "retention_days": days}) logger.info(
"retention cleanup complete", extra={"deleted_snapshots": deleted, "retention_days": days}
)
return deleted return deleted

View File

@@ -88,9 +88,7 @@ async def browse_menu(device_id: str, path: str) -> dict[str, Any]:
return await execute_command(device_id, command) return await execute_command(device_id, command)
async def add_entry( async def add_entry(device_id: str, path: str, properties: dict[str, str]) -> dict[str, Any]:
device_id: str, path: str, properties: dict[str, str]
) -> dict[str, Any]:
"""Add a new entry to a RouterOS menu path. """Add a new entry to a RouterOS menu path.
Args: Args:
@@ -124,9 +122,7 @@ async def update_entry(
return await execute_command(device_id, f"{path}/set", args) return await execute_command(device_id, f"{path}/set", args)
async def remove_entry( async def remove_entry(device_id: str, path: str, entry_id: str) -> dict[str, Any]:
device_id: str, path: str, entry_id: str
) -> dict[str, Any]:
"""Remove an entry from a RouterOS menu path. """Remove an entry from a RouterOS menu path.
Args: Args:

View File

@@ -7,20 +7,33 @@ from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
HIGH_RISK_PATHS = { HIGH_RISK_PATHS = {
"/ip address", "/ip route", "/ip firewall filter", "/ip firewall nat", "/ip address",
"/interface", "/interface bridge", "/interface vlan", "/ip route",
"/system identity", "/ip service", "/ip ssh", "/user", "/ip firewall filter",
"/ip firewall nat",
"/interface",
"/interface bridge",
"/interface vlan",
"/system identity",
"/ip service",
"/ip ssh",
"/user",
} }
MANAGEMENT_PATTERNS = [ MANAGEMENT_PATTERNS = [
(re.compile(r"chain=input.*dst-port=(22|8291|8728|8729|443|80)", re.I), (
"Modifies firewall rules for management ports (SSH/WinBox/API/Web)"), re.compile(r"chain=input.*dst-port=(22|8291|8728|8729|443|80)", re.I),
(re.compile(r"chain=input.*action=drop", re.I), "Modifies firewall rules for management ports (SSH/WinBox/API/Web)",
"Adds drop rule on input chain — may block management access"), ),
(re.compile(r"/ip service", re.I), (
"Modifies IP services — may disable API/SSH/WinBox access"), re.compile(r"chain=input.*action=drop", re.I),
(re.compile(r"/user.*set.*password", re.I), "Adds drop rule on input chain — may block management access",
"Changes user password — may affect automated access"), ),
(re.compile(r"/ip service", re.I), "Modifies IP services — may disable API/SSH/WinBox access"),
(
re.compile(r"/user.*set.*password", re.I),
"Changes user password — may affect automated access",
),
] ]
@@ -73,7 +86,14 @@ def parse_rsc(text: str) -> dict[str, Any]:
else: else:
# Check if second part starts with a known command verb # Check if second part starts with a known command verb
cmd_check = parts[1].strip().split(None, 1) cmd_check = parts[1].strip().split(None, 1)
if cmd_check and cmd_check[0] in ("add", "set", "remove", "print", "enable", "disable"): if cmd_check and cmd_check[0] in (
"add",
"set",
"remove",
"print",
"enable",
"disable",
):
current_path = parts[0] current_path = parts[0]
line = parts[1].strip() line = parts[1].strip()
else: else:
@@ -184,12 +204,14 @@ def compute_impact(
risk = "none" risk = "none"
if has_changes: if has_changes:
risk = "high" if path in HIGH_RISK_PATHS else "low" risk = "high" if path in HIGH_RISK_PATHS else "low"
result_categories.append({ result_categories.append(
"path": path, {
"adds": added, "path": path,
"removes": removed, "adds": added,
"risk": risk, "removes": removed,
}) "risk": risk,
}
)
# Check target commands against management patterns # Check target commands against management patterns
target_text = "\n".join( target_text = "\n".join(

View File

@@ -16,9 +16,7 @@ from srptools.constants import PRIME_2048, PRIME_2048_GEN
_SRP_HASH = hashlib.sha256 _SRP_HASH = hashlib.sha256
async def create_srp_verifier( async def create_srp_verifier(salt_hex: str, verifier_hex: str) -> tuple[bytes, bytes]:
salt_hex: str, verifier_hex: str
) -> tuple[bytes, bytes]:
"""Convert client-provided hex salt and verifier to bytes for storage. """Convert client-provided hex salt and verifier to bytes for storage.
The client computes v = g^x mod N using 2SKD-derived SRP-x. The client computes v = g^x mod N using 2SKD-derived SRP-x.
@@ -31,9 +29,7 @@ async def create_srp_verifier(
return bytes.fromhex(salt_hex), bytes.fromhex(verifier_hex) return bytes.fromhex(salt_hex), bytes.fromhex(verifier_hex)
async def srp_init( async def srp_init(email: str, srp_verifier_hex: str) -> tuple[str, str]:
email: str, srp_verifier_hex: str
) -> tuple[str, str]:
"""SRP Step 1: Generate server ephemeral (B) and private key (b). """SRP Step 1: Generate server ephemeral (B) and private key (b).
Args: Args:
@@ -47,14 +43,15 @@ async def srp_init(
Raises: Raises:
ValueError: If SRP initialization fails for any reason. ValueError: If SRP initialization fails for any reason.
""" """
def _init() -> tuple[str, str]: def _init() -> tuple[str, str]:
context = SRPContext( context = SRPContext(
email, prime=PRIME_2048, generator=PRIME_2048_GEN, email,
prime=PRIME_2048,
generator=PRIME_2048_GEN,
hash_func=_SRP_HASH, hash_func=_SRP_HASH,
) )
server_session = SRPServerSession( server_session = SRPServerSession(context, srp_verifier_hex)
context, srp_verifier_hex
)
return server_session.public, server_session.private return server_session.public, server_session.private
try: try:
@@ -85,26 +82,27 @@ async def srp_verify(
Tuple of (is_valid, server_proof_hex_or_none). Tuple of (is_valid, server_proof_hex_or_none).
If valid, server_proof is M2 for the client to verify. If valid, server_proof is M2 for the client to verify.
""" """
def _verify() -> tuple[bool, str | None]: def _verify() -> tuple[bool, str | None]:
import logging
log = logging.getLogger("srp_debug")
context = SRPContext( context = SRPContext(
email, prime=PRIME_2048, generator=PRIME_2048_GEN, email,
prime=PRIME_2048,
generator=PRIME_2048_GEN,
hash_func=_SRP_HASH, hash_func=_SRP_HASH,
) )
server_session = SRPServerSession( server_session = SRPServerSession(context, srp_verifier_hex, private=server_private)
context, srp_verifier_hex, private=server_private
)
_key, _key_proof, _key_proof_hash = server_session.process(client_public, srp_salt_hex) _key, _key_proof, _key_proof_hash = server_session.process(client_public, srp_salt_hex)
# srptools verify_proof has a Python 3 bug: hexlify() returns bytes # srptools verify_proof has a Python 3 bug: hexlify() returns bytes
# but client_proof is str, so bytes == str is always False. # but client_proof is str, so bytes == str is always False.
# Compare manually with consistent types. # Compare manually with consistent types.
server_m1 = _key_proof if isinstance(_key_proof, str) else _key_proof.decode('ascii') server_m1 = _key_proof if isinstance(_key_proof, str) else _key_proof.decode("ascii")
is_valid = client_proof.lower() == server_m1.lower() is_valid = client_proof.lower() == server_m1.lower()
if not is_valid: if not is_valid:
return False, None return False, None
# Return M2 (key_proof_hash), also fixing the bytes/str issue # Return M2 (key_proof_hash), also fixing the bytes/str issue
m2 = _key_proof_hash if isinstance(_key_proof_hash, str) else _key_proof_hash.decode('ascii') m2 = (
_key_proof_hash if isinstance(_key_proof_hash, str) else _key_proof_hash.decode("ascii")
)
return True, m2 return True, m2
try: try:

View File

@@ -137,7 +137,9 @@ class SSEConnectionManager:
if last_event_id is not None: if last_event_id is not None:
try: try:
start_seq = int(last_event_id) + 1 start_seq = int(last_event_id) + 1
consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.BY_START_SEQUENCE, opt_start_seq=start_seq) consumer_cfg = ConsumerConfig(
deliver_policy=DeliverPolicy.BY_START_SEQUENCE, opt_start_seq=start_seq
)
except (ValueError, TypeError): except (ValueError, TypeError):
consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.NEW) consumer_cfg = ConsumerConfig(deliver_policy=DeliverPolicy.NEW)
else: else:
@@ -173,18 +175,32 @@ class SSEConnectionManager:
except Exception as exc: except Exception as exc:
if "stream not found" in str(exc): if "stream not found" in str(exc):
try: try:
await js.add_stream(StreamConfig( await js.add_stream(
name="ALERT_EVENTS", StreamConfig(
subjects=_ALERT_EVENT_SUBJECTS, name="ALERT_EVENTS",
max_age=3600, subjects=_ALERT_EVENT_SUBJECTS,
)) max_age=3600,
sub = await js.subscribe(subject, stream="ALERT_EVENTS", config=consumer_cfg) )
)
sub = await js.subscribe(
subject, stream="ALERT_EVENTS", config=consumer_cfg
)
self._subscriptions.append(sub) self._subscriptions.append(sub)
logger.info("sse.stream_created_lazily", stream="ALERT_EVENTS") logger.info("sse.stream_created_lazily", stream="ALERT_EVENTS")
except Exception as retry_exc: except Exception as retry_exc:
logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(retry_exc)) logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="ALERT_EVENTS",
error=str(retry_exc),
)
else: else:
logger.warning("sse.subscribe_failed", subject=subject, stream="ALERT_EVENTS", error=str(exc)) logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="ALERT_EVENTS",
error=str(exc),
)
# Subscribe to operation events (OPERATION_EVENTS stream) # Subscribe to operation events (OPERATION_EVENTS stream)
for subject in _OPERATION_EVENT_SUBJECTS: for subject in _OPERATION_EVENT_SUBJECTS:
@@ -198,18 +214,32 @@ class SSEConnectionManager:
except Exception as exc: except Exception as exc:
if "stream not found" in str(exc): if "stream not found" in str(exc):
try: try:
await js.add_stream(StreamConfig( await js.add_stream(
name="OPERATION_EVENTS", StreamConfig(
subjects=_OPERATION_EVENT_SUBJECTS, name="OPERATION_EVENTS",
max_age=3600, subjects=_OPERATION_EVENT_SUBJECTS,
)) max_age=3600,
sub = await js.subscribe(subject, stream="OPERATION_EVENTS", config=consumer_cfg) )
)
sub = await js.subscribe(
subject, stream="OPERATION_EVENTS", config=consumer_cfg
)
self._subscriptions.append(sub) self._subscriptions.append(sub)
logger.info("sse.stream_created_lazily", stream="OPERATION_EVENTS") logger.info("sse.stream_created_lazily", stream="OPERATION_EVENTS")
except Exception as retry_exc: except Exception as retry_exc:
logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(retry_exc)) logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="OPERATION_EVENTS",
error=str(retry_exc),
)
else: else:
logger.warning("sse.subscribe_failed", subject=subject, stream="OPERATION_EVENTS", error=str(exc)) logger.warning(
"sse.subscribe_failed",
subject=subject,
stream="OPERATION_EVENTS",
error=str(exc),
)
# Start background task to pull messages from subscriptions into the queue # Start background task to pull messages from subscriptions into the queue
asyncio.create_task(self._pump_messages()) asyncio.create_task(self._pump_messages())

View File

@@ -12,22 +12,18 @@ separate scheduler and file names to avoid conflicts with restore operations.
""" """
import asyncio import asyncio
import io
import ipaddress import ipaddress
import json import json
import logging import logging
import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
import asyncssh import asyncssh
from jinja2 import meta from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy import select, text from sqlalchemy import text
from app.config import settings from app.config import settings
from app.database import AdminAsyncSessionLocal from app.database import AdminAsyncSessionLocal
from app.models.config_template import TemplatePushJob
from app.models.device import Device
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -145,7 +141,9 @@ async def push_to_devices(rollout_id: str) -> dict:
except Exception as exc: except Exception as exc:
logger.error( logger.error(
"Uncaught exception in template push rollout %s: %s", "Uncaught exception in template push rollout %s: %s",
rollout_id, exc, exc_info=True, rollout_id,
exc,
exc_info=True,
) )
return {"completed": 0, "failed": 1, "pending": 0} return {"completed": 0, "failed": 1, "pending": 0}
@@ -181,7 +179,9 @@ async def _run_push_rollout(rollout_id: str) -> dict:
logger.info( logger.info(
"Template push rollout %s: pushing to device %s (job %s)", "Template push rollout %s: pushing to device %s (job %s)",
rollout_id, hostname, job_id, rollout_id,
hostname,
job_id,
) )
await push_single_device(job_id) await push_single_device(job_id)
@@ -200,7 +200,9 @@ async def _run_push_rollout(rollout_id: str) -> dict:
failed = True failed = True
logger.error( logger.error(
"Template push rollout %s paused: device %s %s", "Template push rollout %s paused: device %s %s",
rollout_id, hostname, row[0], rollout_id,
hostname,
row[0],
) )
break break
@@ -232,7 +234,9 @@ async def push_single_device(job_id: str) -> None:
except Exception as exc: except Exception as exc:
logger.error( logger.error(
"Uncaught exception in template push job %s: %s", "Uncaught exception in template push job %s: %s",
job_id, exc, exc_info=True, job_id,
exc,
exc_info=True,
) )
await _update_job(job_id, status="failed", error_message=f"Unexpected error: {exc}") await _update_job(job_id, status="failed", error_message=f"Unexpected error: {exc}")
@@ -260,8 +264,13 @@ async def _run_single_push(job_id: str) -> None:
return return
( (
_, device_id, tenant_id, rendered_content, _,
ip_address, hostname, encrypted_credentials, device_id,
tenant_id,
rendered_content,
ip_address,
hostname,
encrypted_credentials,
encrypted_credentials_transit, encrypted_credentials_transit,
) = row ) = row
@@ -279,16 +288,21 @@ async def _run_single_push(job_id: str) -> None:
try: try:
from app.services.crypto import decrypt_credentials_hybrid from app.services.crypto import decrypt_credentials_hybrid
key = settings.get_encryption_key_bytes() key = settings.get_encryption_key_bytes()
creds_json = await decrypt_credentials_hybrid( creds_json = await decrypt_credentials_hybrid(
encrypted_credentials_transit, encrypted_credentials, tenant_id, key, encrypted_credentials_transit,
encrypted_credentials,
tenant_id,
key,
) )
creds = json.loads(creds_json) creds = json.loads(creds_json)
ssh_username = creds.get("username", "") ssh_username = creds.get("username", "")
ssh_password = creds.get("password", "") ssh_password = creds.get("password", "")
except Exception as cred_err: except Exception as cred_err:
await _update_job( await _update_job(
job_id, status="failed", job_id,
status="failed",
error_message=f"Failed to decrypt credentials: {cred_err}", error_message=f"Failed to decrypt credentials: {cred_err}",
) )
return return
@@ -297,6 +311,7 @@ async def _run_single_push(job_id: str) -> None:
logger.info("Running mandatory pre-push backup for device %s (%s)", hostname, ip_address) logger.info("Running mandatory pre-push backup for device %s (%s)", hostname, ip_address)
try: try:
from app.services import backup_service from app.services import backup_service
backup_result = await backup_service.run_backup( backup_result = await backup_service.run_backup(
device_id=device_id, device_id=device_id,
tenant_id=tenant_id, tenant_id=tenant_id,
@@ -308,7 +323,8 @@ async def _run_single_push(job_id: str) -> None:
except Exception as backup_err: except Exception as backup_err:
logger.error("Pre-push backup failed for %s: %s", hostname, backup_err) logger.error("Pre-push backup failed for %s: %s", hostname, backup_err)
await _update_job( await _update_job(
job_id, status="failed", job_id,
status="failed",
error_message=f"Pre-push backup failed: {backup_err}", error_message=f"Pre-push backup failed: {backup_err}",
) )
return return
@@ -316,7 +332,8 @@ async def _run_single_push(job_id: str) -> None:
# Step 5: SSH to device - install panic-revert, push config # Step 5: SSH to device - install panic-revert, push config
logger.info( logger.info(
"Pushing template to device %s (%s): installing panic-revert and uploading config", "Pushing template to device %s (%s): installing panic-revert and uploading config",
hostname, ip_address, hostname,
ip_address,
) )
try: try:
@@ -359,7 +376,8 @@ async def _run_single_push(job_id: str) -> None:
) )
logger.info( logger.info(
"Template import result for device %s: exit_status=%s stdout=%r", "Template import result for device %s: exit_status=%s stdout=%r",
hostname, import_result.exit_status, hostname,
import_result.exit_status,
(import_result.stdout or "")[:200], (import_result.stdout or "")[:200],
) )
@@ -369,16 +387,21 @@ async def _run_single_push(job_id: str) -> None:
except Exception as cleanup_err: except Exception as cleanup_err:
logger.warning( logger.warning(
"Failed to clean up %s from device %s: %s", "Failed to clean up %s from device %s: %s",
_TEMPLATE_RSC, ip_address, cleanup_err, _TEMPLATE_RSC,
ip_address,
cleanup_err,
) )
except Exception as push_err: except Exception as push_err:
logger.error( logger.error(
"SSH push phase failed for device %s (%s): %s", "SSH push phase failed for device %s (%s): %s",
hostname, ip_address, push_err, hostname,
ip_address,
push_err,
) )
await _update_job( await _update_job(
job_id, status="failed", job_id,
status="failed",
error_message=f"Config push failed during SSH phase: {push_err}", error_message=f"Config push failed during SSH phase: {push_err}",
) )
return return
@@ -395,9 +418,12 @@ async def _run_single_push(job_id: str) -> None:
logger.info("Device %s (%s) is reachable after push - committing", hostname, ip_address) logger.info("Device %s (%s) is reachable after push - committing", hostname, ip_address)
try: try:
async with asyncssh.connect( async with asyncssh.connect(
ip_address, port=22, ip_address,
username=ssh_username, password=ssh_password, port=22,
known_hosts=None, connect_timeout=30, username=ssh_username,
password=ssh_password,
known_hosts=None,
connect_timeout=30,
) as conn: ) as conn:
await conn.run( await conn.run(
f'/system scheduler remove "{_PANIC_REVERT_SCHEDULER}"', f'/system scheduler remove "{_PANIC_REVERT_SCHEDULER}"',
@@ -410,11 +436,13 @@ async def _run_single_push(job_id: str) -> None:
except Exception as cleanup_err: except Exception as cleanup_err:
logger.warning( logger.warning(
"Failed to clean up panic-revert scheduler/backup on device %s: %s", "Failed to clean up panic-revert scheduler/backup on device %s: %s",
hostname, cleanup_err, hostname,
cleanup_err,
) )
await _update_job( await _update_job(
job_id, status="committed", job_id,
status="committed",
completed_at=datetime.now(timezone.utc), completed_at=datetime.now(timezone.utc),
) )
else: else:
@@ -422,10 +450,13 @@ async def _run_single_push(job_id: str) -> None:
logger.warning( logger.warning(
"Device %s (%s) is unreachable after push - panic-revert scheduler " "Device %s (%s) is unreachable after push - panic-revert scheduler "
"will auto-revert to %s.backup", "will auto-revert to %s.backup",
hostname, ip_address, _PRE_PUSH_BACKUP, hostname,
ip_address,
_PRE_PUSH_BACKUP,
) )
await _update_job( await _update_job(
job_id, status="reverted", job_id,
status="reverted",
error_message="Device unreachable after push; auto-reverted via panic-revert scheduler", error_message="Device unreachable after push; auto-reverted via panic-revert scheduler",
completed_at=datetime.now(timezone.utc), completed_at=datetime.now(timezone.utc),
) )
@@ -440,9 +471,12 @@ async def _check_reachability(ip: str, username: str, password: str) -> bool:
"""Check if a RouterOS device is reachable via SSH.""" """Check if a RouterOS device is reachable via SSH."""
try: try:
async with asyncssh.connect( async with asyncssh.connect(
ip, port=22, ip,
username=username, password=password, port=22,
known_hosts=None, connect_timeout=30, username=username,
password=password,
known_hosts=None,
connect_timeout=30,
) as conn: ) as conn:
result = await conn.run("/system identity print", check=True) result = await conn.run("/system identity print", check=True)
logger.debug("Reachability check OK for %s: %r", ip, result.stdout[:50]) logger.debug("Reachability check OK for %s: %r", ip, result.stdout[:50])
@@ -459,7 +493,12 @@ async def _update_job(job_id: str, **kwargs) -> None:
for key, value in kwargs.items(): for key, value in kwargs.items():
param_name = f"v_{key}" param_name = f"v_{key}"
if value is None and key in ("error_message", "started_at", "completed_at", "pre_push_backup_sha"): if value is None and key in (
"error_message",
"started_at",
"completed_at",
"pre_push_backup_sha",
):
sets.append(f"{key} = NULL") sets.append(f"{key} = NULL")
else: else:
sets.append(f"{key} = :{param_name}") sets.append(f"{key} = :{param_name}")
@@ -472,7 +511,7 @@ async def _update_job(job_id: str, **kwargs) -> None:
await session.execute( await session.execute(
text(f""" text(f"""
UPDATE template_push_jobs UPDATE template_push_jobs
SET {', '.join(sets)} SET {", ".join(sets)}
WHERE id = CAST(:job_id AS uuid) WHERE id = CAST(:job_id AS uuid)
"""), """),
params, params,

Some files were not shown because too many files have changed in this diff Show More