"""Tests for the ``app.database.get_db`` session dependency."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.database import get_db


def _make_session_mocks():
    """Build a mock session and a mock ``AsyncSessionLocal`` factory.

    Returns:
        A ``(session, factory)`` tuple. Calling ``factory()`` returns an
        async context manager whose ``__aenter__`` yields ``session``. The
        session's ``commit``, ``rollback`` and ``close`` are ``AsyncMock``
        instances, so tests can assert that they were awaited, not merely
        called.
    """
    session = MagicMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.close = AsyncMock()

    # AsyncSessionLocal is used as an async context manager, so the object
    # returned by calling the factory must support ``async with``.
    # MagicMock (3.8+) provides __aenter__/__aexit__ as AsyncMocks.
    factory = MagicMock()
    factory.return_value.__aenter__.return_value = session
    return session, factory


@pytest.mark.asyncio
async def test_get_db_generator():
    """On normal completion, get_db commits and closes the session."""
    mock_session, mock_session_factory = _make_session_mocks()

    with patch("app.database.AsyncSessionLocal", mock_session_factory):
        generator = get_db()

        # First step: the dependency yields the session it opened.
        session = await generator.__anext__()
        assert session is mock_session

        # Resuming runs the post-yield commit/cleanup. The generator must
        # then be exhausted rather than yield again.
        with pytest.raises(StopAsyncIteration):
            await generator.__anext__()

    # assert_awaited_* (not assert_called_*) also catches a missing ``await``.
    mock_session.commit.assert_awaited_once()
    mock_session.close.assert_awaited_once()


@pytest.mark.asyncio
async def test_get_db_generator_exception():
    """When an error is raised at the yield, get_db rolls back, closes, re-raises."""
    mock_session, mock_session_factory = _make_session_mocks()

    with patch("app.database.AsyncSessionLocal", mock_session_factory):
        generator = get_db()
        await generator.__anext__()

        # Inject an error at the yield point to simulate a failure while the
        # session is in use. get_db must propagate it to the caller.
        with pytest.raises(ValueError):
            await generator.athrow(ValueError("test error"))

    mock_session.rollback.assert_awaited_once()
    mock_session.close.assert_awaited_once()