diff --git a/backend/app/services/config_history_service.py b/backend/app/services/config_history_service.py index f60a92f..4c0c5fb 100644 --- a/backend/app/services/config_history_service.py +++ b/backend/app/services/config_history_service.py @@ -3,12 +3,17 @@ Provides paginated query of config change entries for a device, joining router_config_changes with router_config_diffs to include diff metadata (lines_added, lines_removed, snapshot_id). + +Also provides single-snapshot retrieval (with Transit decrypt) and +diff retrieval by snapshot id. """ import logging from sqlalchemy import text +from app.services.openbao_service import OpenBaoTransitService + logger = logging.getLogger(__name__) @@ -62,3 +67,88 @@ async def get_config_history( } for row in rows ] + + +async def get_snapshot( + snapshot_id: str, + device_id: str, + tenant_id: str, + session, +) -> dict | None: + """Return decrypted config snapshot for a given snapshot, device, and tenant. + + Returns None if the snapshot does not exist or belongs to a different + device/tenant (RLS prevents cross-tenant access). + """ + result = await session.execute( + text( + "SELECT id, config_text, sha256_hash, collected_at " + "FROM router_config_snapshots " + "WHERE id = CAST(:snapshot_id AS uuid) " + "AND device_id = CAST(:device_id AS uuid) " + "AND tenant_id = CAST(:tenant_id AS uuid)" + ), + { + "snapshot_id": snapshot_id, + "device_id": device_id, + "tenant_id": tenant_id, + }, + ) + row = result.fetchone() + if row is None: + return None + + ciphertext = row._mapping["config_text"] + + openbao = OpenBaoTransitService() + try: + plaintext_bytes = await openbao.decrypt(tenant_id, ciphertext) + finally: + await openbao.close() + + return { + "id": str(row._mapping["id"]), + "config_text": plaintext_bytes.decode("utf-8"), + "sha256_hash": row._mapping["sha256_hash"], + "collected_at": row._mapping["collected_at"].isoformat(), + } + + +async def get_snapshot_diff( + snapshot_id: str, + device_id: str, + tenant_id: str, + session, +) -> dict | None: + """Return the diff associated with a snapshot (as the new_snapshot_id). + + Returns None if no diff exists for this snapshot (e.g., first snapshot). + """ + result = await session.execute( + text( + "SELECT id, diff_text, lines_added, lines_removed, " + "old_snapshot_id, new_snapshot_id, created_at " + "FROM router_config_diffs " + "WHERE new_snapshot_id = CAST(:snapshot_id AS uuid) " + "AND device_id = CAST(:device_id AS uuid) " + "AND tenant_id = CAST(:tenant_id AS uuid)" + ), + { + "snapshot_id": snapshot_id, + "device_id": device_id, + "tenant_id": tenant_id, + }, + ) + row = result.fetchone() + if row is None: + return None + + return { + "id": str(row._mapping["id"]), + "diff_text": row._mapping["diff_text"], + "lines_added": row._mapping["lines_added"], + "lines_removed": row._mapping["lines_removed"], + "old_snapshot_id": str(row._mapping["old_snapshot_id"]), + "new_snapshot_id": str(row._mapping["new_snapshot_id"]), + "created_at": row._mapping["created_at"].isoformat(), + } diff --git a/backend/tests/test_config_history_service.py b/backend/tests/test_config_history_service.py index 0bdfcf4..13c67fc 100644 --- a/backend/tests/test_config_history_service.py +++ b/backend/tests/test_config_history_service.py @@ -120,3 +120,125 @@ async def test_ordering_desc_by_created_at(): call_args = mock_session.execute.call_args query_text = str(call_args[0][0]) assert "DESC" in query_text.upper() + + +# --------------------------------------------------------------------------- +# Tests for get_snapshot +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_snapshot_returns_decrypted_content(): + """get_snapshot decrypts config_text via Transit and returns plaintext.""" + from unittest.mock import patch + from app.services.config_history_service import get_snapshot + + snapshot_id = str(uuid4()) + device_id = str(uuid4()) + tenant_id = str(uuid4()) + ts = datetime(2026, 3, 12, 10, 0, 0, tzinfo=timezone.utc) + sha = "abc123" * 10 + "abcd" + + mock_session = AsyncMock() + result_mock = MagicMock() + row = MagicMock() + row._mapping = { + "id": uuid4(), + "config_text": "vault:v1:encrypted_data", + "sha256_hash": sha, + "collected_at": ts, + } + result_mock.fetchone.return_value = row + mock_session.execute = AsyncMock(return_value=result_mock) + + plaintext_config = "/ip address\nadd address=10.0.0.1/24" + mock_openbao = AsyncMock() + mock_openbao.decrypt = AsyncMock(return_value=plaintext_config.encode("utf-8")) + + with patch( + "app.services.config_history_service.OpenBaoTransitService", + return_value=mock_openbao, + ): + result = await get_snapshot(snapshot_id, device_id, tenant_id, mock_session) + + assert result is not None + assert result["config_text"] == plaintext_config + assert result["sha256_hash"] == sha + assert result["collected_at"] == ts.isoformat() + mock_openbao.decrypt.assert_called_once_with(tenant_id, "vault:v1:encrypted_data") + mock_openbao.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_snapshot_not_found_returns_none(): + """get_snapshot returns None when snapshot not found (wrong id/device/tenant).""" + from app.services.config_history_service import get_snapshot + + mock_session = AsyncMock() + result_mock = MagicMock() + result_mock.fetchone.return_value = None + mock_session.execute = AsyncMock(return_value=result_mock) + + result = await get_snapshot(str(uuid4()), str(uuid4()), str(uuid4()), mock_session) + + assert result is None + + +# --------------------------------------------------------------------------- +# Tests for get_snapshot_diff +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_snapshot_diff_returns_diff_text(): + """get_snapshot_diff returns diff data for the given snapshot.""" + from app.services.config_history_service import get_snapshot_diff + + snapshot_id = str(uuid4()) + device_id = str(uuid4()) + tenant_id = str(uuid4()) + diff_id = uuid4() + old_snap = uuid4() + new_snap = uuid4() + ts = datetime(2026, 3, 12, 11, 0, 0, tzinfo=timezone.utc) + + mock_session = AsyncMock() + result_mock = MagicMock() + row = MagicMock() + row._mapping = { + "id": diff_id, + "diff_text": "--- old\n+++ new\n@@ -1 +1 @@\n-line1\n+line2", + "lines_added": 1, + "lines_removed": 1, + "old_snapshot_id": old_snap, + "new_snapshot_id": new_snap, + "created_at": ts, + } + result_mock.fetchone.return_value = row + mock_session.execute = AsyncMock(return_value=result_mock) + + result = await get_snapshot_diff(snapshot_id, device_id, tenant_id, mock_session) + + assert result is not None + assert result["id"] == str(diff_id) + assert "line2" in result["diff_text"] + assert result["lines_added"] == 1 + assert result["lines_removed"] == 1 + assert result["old_snapshot_id"] == str(old_snap) + assert result["new_snapshot_id"] == str(new_snap) + assert result["created_at"] == ts.isoformat() + + +@pytest.mark.asyncio +async def test_get_snapshot_diff_no_diff_returns_none(): + """get_snapshot_diff returns None when no diff exists (first snapshot).""" + from app.services.config_history_service import get_snapshot_diff + + mock_session = AsyncMock() + result_mock = MagicMock() + result_mock.fetchone.return_value = None + mock_session.execute = AsyncMock(return_value=result_mock) + + result = await get_snapshot_diff(str(uuid4()), str(uuid4()), str(uuid4()), mock_session) + + assert result is None