"""Tests for config diff generation service.

Tests the generate_and_store_diff function with mocked DB sessions and
OpenBao Transit service.
"""

import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, call
from uuid import uuid4

# Import path patched so the service under test receives our mock Transit client.
_OPENBAO_TARGET = "app.services.config_diff_service.OpenBaoTransitService"


def _mock_snapshot_row(snapshot_id, config_text):
    """Create a mock row for snapshot query results."""
    row = MagicMock()
    row._mapping = {"id": snapshot_id, "config_text": config_text}
    return row


def _prev_snapshot_result(snapshot_id, config_text):
    """Mock result of the previous-snapshot query, read via ``fetchone()``."""
    result = MagicMock()
    result.fetchone.return_value = _mock_snapshot_row(snapshot_id, config_text)
    return result


def _new_snapshot_result(config_text):
    """Mock result of the new-snapshot query, read via ``scalar_one()``."""
    result = MagicMock()
    result.scalar_one.return_value = config_text
    return result


def _mock_session(*execute_results):
    """Build an AsyncMock DB session whose successive ``execute`` awaits return *execute_results*.

    ``side_effect`` is a finite list, so any ``execute`` call beyond the
    supplied results raises ``StopAsyncIteration``; the call is still counted
    in ``execute.call_count``.
    """
    session = AsyncMock()
    session.execute = AsyncMock(side_effect=list(execute_results))
    session.commit = AsyncMock()
    return session


def _mock_openbao(*plaintexts):
    """Build a mock Transit client whose ``decrypt`` returns each plaintext UTF-8 encoded, in order."""
    openbao = AsyncMock()
    openbao.decrypt = AsyncMock(side_effect=[text.encode("utf-8") for text in plaintexts])
    return openbao


def _patch_openbao(openbao):
    """Patch the service's ``OpenBaoTransitService`` so instantiating it returns *openbao*."""
    return patch(_OPENBAO_TARGET, return_value=openbao)


@pytest.mark.asyncio
async def test_diff_generated_and_stored():
    """Test 1: Two different configs produce a unified diff and INSERT into router_config_diffs."""
    from app.services.config_diff_service import generate_and_store_diff

    device_id = str(uuid4())
    tenant_id = str(uuid4())
    new_snapshot_id = str(uuid4())
    old_snapshot_id = str(uuid4())

    old_config = "line1\nline2\nline3"
    new_config = "line1\nline2_modified\nline3"

    # execute order: previous-snapshot query, new-snapshot query, INSERT.
    mock_session = _mock_session(
        _prev_snapshot_result(old_snapshot_id, "vault:v1:old_encrypted"),
        _new_snapshot_result("vault:v1:new_encrypted"),
        MagicMock(),
    )
    mock_openbao = _mock_openbao(old_config, new_config)

    with _patch_openbao(mock_openbao):
        await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session)

    # Should decrypt both configs
    assert mock_openbao.decrypt.call_count == 2
    # Should INSERT (3 executes: prev query, new query, INSERT)
    assert mock_session.execute.call_count == 3
    # Should commit
    mock_session.commit.assert_called_once()

    # The INSERT parameters are the second positional argument to execute().
    insert_params = mock_session.execute.call_args_list[2].args[1]
    assert insert_params["old_snapshot_id"] == old_snapshot_id
    assert insert_params["new_snapshot_id"] == new_snapshot_id
    assert insert_params["lines_added"] == 1
    assert insert_params["lines_removed"] == 1
    assert "line2_modified" in insert_params["diff_text"]


@pytest.mark.asyncio
async def test_first_snapshot_no_diff():
    """Test 2: First snapshot (no previous) skips diff generation gracefully."""
    from app.services.config_diff_service import generate_and_store_diff

    device_id = str(uuid4())
    tenant_id = str(uuid4())
    new_snapshot_id = str(uuid4())

    # Query 1: previous snapshot returns None
    prev_result = MagicMock()
    prev_result.fetchone.return_value = None

    mock_session = AsyncMock()
    # return_value (not a finite side_effect list) so any extra query is
    # counted by the call_count assertion below rather than raising.
    mock_session.execute = AsyncMock(return_value=prev_result)
    mock_session.commit = AsyncMock()
    mock_openbao = AsyncMock()

    with _patch_openbao(mock_openbao):
        await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session)

    # Should only query for previous snapshot, then return
    assert mock_session.execute.call_count == 1
    mock_session.commit.assert_not_called()
    # Nothing to compare against, so nothing should be decrypted.
    mock_openbao.decrypt.assert_not_called()


@pytest.mark.asyncio
async def test_decrypt_failure_logs_and_returns():
    """Test 3: Transit decrypt failure logs warning, does NOT raise."""
    from app.services.config_diff_service import generate_and_store_diff

    device_id = str(uuid4())
    tenant_id = str(uuid4())
    new_snapshot_id = str(uuid4())
    old_snapshot_id = str(uuid4())

    # Only the two snapshot queries are scripted; an INSERT would be a third call.
    mock_session = _mock_session(
        _prev_snapshot_result(old_snapshot_id, "vault:v1:old_encrypted"),
        _new_snapshot_result("vault:v1:new_encrypted"),
    )
    mock_openbao = AsyncMock()
    mock_openbao.decrypt = AsyncMock(side_effect=Exception("Transit unavailable"))

    with _patch_openbao(mock_openbao):
        # Should NOT raise
        await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session)

    # The failure path must actually have been exercised.
    mock_openbao.decrypt.assert_awaited()
    # No INSERT attempted: at most the two snapshot queries ran. Checked
    # explicitly because an extra execute() raises StopAsyncIteration, which
    # the error-tolerant code path would otherwise swallow unnoticed.
    assert mock_session.execute.call_count <= 2
    # Should not commit (no INSERT happened)
    mock_session.commit.assert_not_called()


@pytest.mark.asyncio
async def test_line_counts_correct():
    """Test 4: lines_added/lines_removed counts are correct."""
    from app.services.config_diff_service import generate_and_store_diff

    device_id = str(uuid4())
    tenant_id = str(uuid4())
    new_snapshot_id = str(uuid4())
    old_snapshot_id = str(uuid4())

    # 2 lines removed, 3 lines added
    old_config = "line1\nremoved1\nremoved2\nline4"
    new_config = "line1\nadded1\nadded2\nadded3\nline4"

    mock_session = _mock_session(
        _prev_snapshot_result(old_snapshot_id, "vault:v1:old"),
        _new_snapshot_result("vault:v1:new"),
        MagicMock(),
    )
    mock_openbao = _mock_openbao(old_config, new_config)

    with _patch_openbao(mock_openbao):
        await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session)

    insert_params = mock_session.execute.call_args_list[2].args[1]
    assert insert_params["lines_added"] == 3
    assert insert_params["lines_removed"] == 2


@pytest.mark.asyncio
async def test_empty_diff_skips_insert():
    """Test 5: Identical content (empty diff) skips INSERT."""
    from app.services.config_diff_service import generate_and_store_diff

    device_id = str(uuid4())
    tenant_id = str(uuid4())
    new_snapshot_id = str(uuid4())
    old_snapshot_id = str(uuid4())

    same_config = "line1\nline2\nline3"

    mock_session = _mock_session(
        _prev_snapshot_result(old_snapshot_id, "vault:v1:old"),
        _new_snapshot_result("vault:v1:new"),
    )
    mock_openbao = _mock_openbao(same_config, same_config)

    with _patch_openbao(mock_openbao):
        await generate_and_store_diff(device_id, tenant_id, new_snapshot_id, mock_session)

    # Only 2 queries (prev + new), no INSERT
    assert mock_session.execute.call_count == 2
    mock_session.commit.assert_not_called()