from __future__ import annotations """Tests for musehub.db.schema_check. Covers: - assert_schema_matches_orm: passes for a correct schema - assert_schema_matches_orm: raises RuntimeError for a missing table - assert_schema_matches_orm: raises RuntimeError for a missing column - assert_schema_matches_orm: raises RuntimeError for a nullable mismatch - assert_schema_matches_orm: raises RuntimeError for JSONB→JSON type drift - assert_schema_matches_orm: raises RuntimeError for VARCHAR length mismatch - assert_schema_matches_orm: raises RuntimeError for a missing index - assert_schema_matches_orm: raises RuntimeError for a missing FK - assert_schema_matches_orm: skips silently for SQLite engines - assert_no_orm_column_aliases: passes when all keys match column names - assert_no_orm_column_aliases: raises AssertionError when an alias exists - assert_no_orm_column_aliases: musehub Base has zero aliases (regression) """ from unittest.mock import AsyncMock, MagicMock import pytest from sqlalchemy import ForeignKey, Index, Integer, String, Table from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from musehub.types.json_types import JSONObject, JSONValue # noqa: F401 — needed for ForwardRef resolution in Mapped[JSONObject] type _ColumnsByTable = dict[str, list[JSONObject]] from musehub.db.schema_check import ( _build_alias_map, _inspect_schema, assert_no_orm_column_aliases, assert_schema_matches_orm, ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- class _TestBase(DeclarativeBase): pass class _Gadget(_TestBase): __tablename__ = "gadgets" __table_args__ = ( Index("ix_gadgets_serial", "serial"), ) id: Mapped[int] = mapped_column(Integer, primary_key=True) serial: Mapped[str] = mapped_column(String(64), nullable=False) notes: Mapped[str | None] = mapped_column(String(256), nullable=True) class 
_Widget(_TestBase): __tablename__ = "widgets" id: Mapped[int] = mapped_column(Integer, primary_key=True) gadget_id: Mapped[int] = mapped_column( Integer, ForeignKey("gadgets.id", ondelete="CASCADE"), nullable=False, ) data: Mapped[JSONObject | None] = mapped_column(JSONB, nullable=True) def _make_engine(url: str = "postgresql+asyncpg://user:pass@localhost/test") -> MagicMock: engine = MagicMock() engine.url = MagicMock() engine.url.__str__ = lambda _: url return engine def _make_inspector( tables: list[str], columns_by_table: _ColumnsByTable, indexes_by_table: dict[str, list[dict]] | None = None, fks_by_table: dict[str, list[dict]] | None = None, ) -> MagicMock: inspector = MagicMock() inspector.get_table_names.return_value = tables inspector.get_columns.side_effect = lambda table: columns_by_table.get(table, []) inspector.get_indexes.side_effect = lambda table: (indexes_by_table or {}).get(table, []) inspector.get_foreign_keys.side_effect = lambda table: (fks_by_table or {}).get(table, []) return inspector def _good_inspector() -> MagicMock: """Inspector that satisfies _TestBase (gadgets + widgets) fully.""" return _make_inspector( tables=["gadgets", "widgets"], columns_by_table={ "gadgets": [ {"name": "id", "nullable": False, "type": Integer()}, {"name": "serial", "nullable": False, "type": String(64)}, {"name": "notes", "nullable": True, "type": String(256)}, ], "widgets": [ {"name": "id", "nullable": False, "type": Integer()}, {"name": "gadget_id", "nullable": False, "type": Integer()}, {"name": "data", "nullable": True, "type": JSONB()}, ], }, indexes_by_table={ "gadgets": [{"name": "ix_gadgets_serial", "column_names": ["serial"], "unique": False}], "widgets": [], }, fks_by_table={ "gadgets": [], "widgets": [ { "constrained_columns": ["gadget_id"], "referred_table": "gadgets", "referred_columns": ["id"], "name": "fk_widgets_gadget_id", } ], }, ) def _alias_map() -> JSONObject: return _build_alias_map(_TestBase) # 
# ---------------------------------------------------------------------------
# SQLite skip
# ---------------------------------------------------------------------------


class TestSqliteSkip:
    """assert_schema_matches_orm must not touch the database for SQLite URLs."""

    async def test_sqlite_engine_is_skipped(self) -> None:
        """A SQLite engine returns without connecting and without raising."""
        engine = _make_engine("sqlite+aiosqlite:///./test.db")
        # Any attempt to open a connection fails the test immediately.
        engine.connect = MagicMock(
            side_effect=AssertionError("should not connect for SQLite")
        )

        await assert_schema_matches_orm(engine, _TestBase)  # must not raise

        # Also catches a connect() call whose AssertionError was swallowed.
        engine.connect.assert_not_called()


# ---------------------------------------------------------------------------
# Passing cases
# ---------------------------------------------------------------------------


def _wired_engine(mismatches: list[str]) -> MagicMock:
    """Return a mock async engine whose run_sync returns *mismatches* directly.

    ``engine.connect()`` yields an async context manager whose connection
    returns *mismatches* from ``run_sync``, so no real inspection takes place.
    """
    engine = _make_engine()
    conn = AsyncMock()
    conn.run_sync = AsyncMock(return_value=mismatches)
    engine.connect = MagicMock()
    engine.connect.return_value.__aenter__ = AsyncMock(return_value=conn)
    engine.connect.return_value.__aexit__ = AsyncMock(return_value=False)
    return engine


class TestPassingCases:
    """Schemas that satisfy the ORM must produce no error and no mismatch."""

    async def test_exact_match_passes(self) -> None:
        """An empty mismatch list from the connection does not raise."""
        await assert_schema_matches_orm(_wired_engine([]), _TestBase)

    def test_extra_db_columns_are_allowed(self) -> None:
        """A DB column unknown to the ORM is not reported as a mismatch.

        Goes through ``_inspect_schema`` rather than ``_wired_engine``,
        because the wired engine bypasses inspection and so could never
        see the extra column.
        """
        inspector = _make_inspector(
            tables=["gadgets", "widgets"],
            columns_by_table={
                "gadgets": [
                    {"name": "id", "nullable": False, "type": Integer()},
                    {"name": "serial", "nullable": False, "type": String(64)},
                    {"name": "notes", "nullable": True, "type": String(256)},
                    {"name": "legacy_flag", "nullable": True, "type": Integer()},  # ← not in the ORM
                ],
                "widgets": [
                    {"name": "id", "nullable": False, "type": Integer()},
                    {"name": "gadget_id", "nullable": False, "type": Integer()},
                    {"name": "data", "nullable": True, "type": JSONB()},
                ],
            },
            indexes_by_table={
                "gadgets": [{"name": "ix_gadgets_serial", "column_names": ["serial"], "unique": False}],
                "widgets": [],
            },
            fks_by_table={
                "gadgets": [],
                "widgets": [
                    {
                        "constrained_columns": ["gadget_id"],
                        "referred_table": "gadgets",
                        "referred_columns": ["id"],
                        "name": "fk_widgets_gadget_id",
                    }
                ],
            },
        )
        mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
        assert mismatches == []

    def test_fully_correct_schema_produces_no_mismatches(self) -> None:
        """The fully matching inspector yields an empty mismatch list."""
        mismatches = _inspect_schema(_good_inspector(), _TestBase, _alias_map())
        assert mismatches == []


# ---------------------------------------------------------------------------
# Table / column / nullability failure detection
# ---------------------------------------------------------------------------


class TestTableAndColumnChecks:
    """Missing tables, missing columns and nullability drift are reported."""

    def test_missing_table_detected(self) -> None:
        """An empty database reports the ORM table as missing."""
        inspector = _make_inspector(tables=[], columns_by_table={})
        mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
        assert any("missing table" in m and "gadgets" in m for m in mismatches)

    def test_missing_column_detected(self) -> None:
        """A column mapped by the ORM but absent from the DB is reported."""
        inspector = _make_inspector(
            tables=["gadgets", "widgets"],
            columns_by_table={
                "gadgets": [
                    {"name": "id", "nullable": False, "type": Integer()},
                    # "serial" is missing
                    {"name": "notes", "nullable": True, "type": String(256)},
                ],
                "widgets": [],
            },
        )
        mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
        assert any("serial" in m and "missing from DB" in m for m in mismatches)

    def test_nullable_mismatch_detected(self) -> None:
        """A column nullable in the DB but NOT NULL in the ORM is reported."""
        inspector = _make_inspector(
            tables=["gadgets", "widgets"],
            columns_by_table={
                "gadgets": [
                    {"name": "id", "nullable": False, "type": Integer()},
                    {"name": "serial", "nullable": True, "type": String(64)},  # ORM says NOT NULL
                    {"name": "notes", "nullable": True, "type": String(256)},
                ],
                "widgets": [],
            },
        )
        mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
        assert any("serial" in m and "nullable mismatch" in m for m in mismatches)

    def test_alias_column_error_includes_orm_key(self) -> None:
        """A missing aliased column is reported with both its DB name and ORM key."""

        class _ABase(DeclarativeBase):
            pass

        class _A(_ABase):
            __tablename__ = "a_table"

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            python_name: Mapped[str | None] = mapped_column("db_name", String(64), nullable=True)

        inspector = _make_inspector(
            tables=["a_table"],
            columns_by_table={"a_table": [{"name": "id", "nullable": False, "type": Integer()}]},
        )
        alias_map = _build_alias_map(_ABase)
        mismatches = _inspect_schema(inspector, _ABase, alias_map)
        assert any("db_name" in m and "python_name" in m for m in mismatches)

    async def test_raises_runtime_error_on_mismatch(self) -> None:
        """Any reported mismatch surfaces as a RuntimeError from the public API."""
        engine = _wired_engine(["missing table: 'gadgets'"])
        with pytest.raises(RuntimeError, match="Schema drift detected"):
            await assert_schema_matches_orm(engine, _TestBase)


# ---------------------------------------------------------------------------
# JSONB type drift
# ---------------------------------------------------------------------------


class TestJsonbTypeCheck:
    """A JSONB column in the ORM must be JSONB in the database as well."""

    def test_jsonb_in_db_passes(self) -> None:
        """A reflected JSONB column produces no JSONB-related mismatch."""
        inspector = _good_inspector()
        mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
        assert not any("JSONB" in m for m in mismatches)

    def test_json_instead_of_jsonb_detected(self) -> None:
        """A plain JSON column where the ORM declares JSONB is reported."""
        from sqlalchemy import JSON

        inspector = _make_inspector(
            tables=["gadgets", "widgets"],
            columns_by_table={
                "gadgets": [
                    {"name": "id", "nullable": False, "type": Integer()},
                    {"name": "serial", "nullable": False, "type": String(64)},
                    {"name": "notes", "nullable": True, "type": String(256)},
                ],
                "widgets": [
                    {"name": "id", "nullable": False, "type": Integer()},
                    {"name": "gadget_id", "nullable": False, "type": Integer()},
                    {"name": "data", "nullable": True, "type": JSON()},  # ← wrong: should be JSONB
                ],
            },
            fks_by_table={
                "gadgets": [],
                "widgets": [
                    {
                        "constrained_columns": ["gadget_id"],
                        "referred_table": "gadgets",
                        "referred_columns": ["id"],
                        "name": "fk_widgets_gadget_id",
                    }
                ],
            },
        )
        mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
        # "JSONB" already contains "JSON", so one substring check is enough.
        assert any("data" in m and "JSONB" in m for m in mismatches)


# ---------------------------------------------------------------------------
# VARCHAR length mismatch
# ---------------------------------------------------------------------------


class TestVarcharLengthCheck:
    """VARCHAR lengths in the database must equal the ORM-declared lengths."""

    def test_matching_length_passes(self) -> None:
        """Identical lengths produce no varchar-length mismatch."""
        mismatches = _inspect_schema(_good_inspector(), _TestBase, _alias_map())
        assert not any("varchar length" in m for m in mismatches)

    def test_length_mismatch_detected(self) -> None:
        """A differing length is reported, naming both lengths."""
        inspector = _make_inspector(
            tables=["gadgets", "widgets"],
            columns_by_table={
                "gadgets": [
                    {"name": "id", "nullable": False, "type": Integer()},
                    {"name": "serial", "nullable": False, "type": String(128)},  # ORM says 64
                    {"name": "notes", "nullable": True, "type": String(256)},
                ],
                "widgets": [
                    {"name": "id", "nullable": False, "type": Integer()},
                    {"name": "gadget_id", "nullable": False, "type": Integer()},
                    {"name": "data", "nullable": True, "type": JSONB()},
                ],
            },
            indexes_by_table={
                "gadgets": [{"name": "ix_gadgets_serial", "column_names": ["serial"], "unique": False}],
                "widgets": [],
            },
            fks_by_table={
                "gadgets": [],
                "widgets": [
                    {
                        "constrained_columns": ["gadget_id"],
                        "referred_table": "gadgets",
                        "referred_columns": ["id"],
                        "name": "fk_widgets_gadget_id",
                    }
                ],
            },
        )
        mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
        assert any(
            "serial" in m and "varchar length" in m and "64" in m and "128" in m
            for m in mismatches
        )


# ---------------------------------------------------------------------------
# Index existence
# ---------------------------------------------------------------------------


class TestIndexCheck:
    """Indexes declared on the ORM must exist in the database."""

    def test_present_index_passes(self) -> None:
        """A reflected index matching the ORM produces no missing-index mismatch."""
        mismatches = _inspect_schema(_good_inspector(), _TestBase, _alias_map())
        assert not any("index" in m and "missing" in m for m in mismatches)

    def test_missing_index_detected(self) -> None:
        """An ORM index absent from the database is reported by name."""
        inspector = _make_inspector(
            tables=["gadgets", "widgets"],
            columns_by_table={
                "gadgets": [
                    {"name": "id", "nullable": False, "type": Integer()},
                    {"name": "serial", "nullable": False, "type": String(64)},
                    {"name": "notes", "nullable": True, "type": String(256)},
                ],
                "widgets": [
                    {"name": "id", "nullable": False, "type": Integer()},
                    {"name": "gadget_id", "nullable": False, "type": Integer()},
                    {"name": "data", "nullable": True, "type": JSONB()},
                ],
            },
            indexes_by_table={
                "gadgets": [],  # ← ix_gadgets_serial is missing
                "widgets": [],
            },
            fks_by_table={
                "gadgets": [],
                "widgets": [
                    {
                        "constrained_columns": ["gadget_id"],
                        "referred_table": "gadgets",
                        "referred_columns": ["id"],
                        "name": "fk_widgets_gadget_id",
                    }
                ],
            },
        )
        mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
        assert any("ix_gadgets_serial" in m and "missing from DB" in m for m in mismatches)


# ---------------------------------------------------------------------------
# FK existence
# ---------------------------------------------------------------------------


class TestFkCheck:
    """Foreign keys declared on the ORM must exist in the database."""

    def test_present_fk_passes(self) -> None:
        """A reflected FK matching the ORM produces no missing-FK mismatch."""
        mismatches = _inspect_schema(_good_inspector(), _TestBase, _alias_map())
        assert not any("FK" in m and "missing" in m for m in mismatches)

    def test_missing_fk_detected(self) -> None:
        """An ORM foreign key absent from the database is reported."""
        inspector = _make_inspector(
            tables=["gadgets", "widgets"],
            columns_by_table={
                "gadgets": [
                    {"name": "id", "nullable": False, "type": Integer()},
                    {"name": "serial", "nullable": False, "type": String(64)},
                    {"name": "notes", "nullable": True, "type": String(256)},
                ],
                "widgets": [
                    {"name": "id", "nullable": False, "type": Integer()},
                    {"name": "gadget_id", "nullable": False, "type": Integer()},
                    {"name": "data", "nullable": True, "type": JSONB()},
                ],
            },
            indexes_by_table={
                "gadgets": [{"name": "ix_gadgets_serial", "column_names": ["serial"], "unique": False}],
                "widgets": [],
            },
            fks_by_table={
                "gadgets": [],
                "widgets": [],  # ← FK gadget_id → gadgets is missing
            },
        )
        mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
        assert any(
            "gadget_id" in m and "gadgets" in m and "missing from DB" in m
            for m in mismatches
        )


# ---------------------------------------------------------------------------
# assert_no_orm_column_aliases
# ---------------------------------------------------------------------------


class TestAssertNoOrmColumnAliases:
    """assert_no_orm_column_aliases rejects keys that differ from column names."""

    def test_clean_base_passes(self) -> None:
        """A base whose attribute keys equal its column names does not raise."""

        class _CleanBase(DeclarativeBase):
            pass

        class _Clean(_CleanBase):
            __tablename__ = "clean"

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            value: Mapped[str] = mapped_column(String(64), nullable=False)

        assert_no_orm_column_aliases(_CleanBase)  # must not raise

    def test_aliased_column_raises(self) -> None:
        """An aliased column raises AssertionError naming the ORM key."""

        class _ABase(DeclarativeBase):
            pass

        class _A(_ABase):
            __tablename__ = "atbl"

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            py_name: Mapped[str | None] = mapped_column("db_name", String(64), nullable=True)

        with pytest.raises(AssertionError, match="py_name"):
            assert_no_orm_column_aliases(_ABase)

    def test_error_message_shows_both_names(self) -> None:
        """The error message contains both the ORM key and the DB column name."""

        class _BBase(DeclarativeBase):
            pass

        class _B(_BBase):
            __tablename__ = "btbl"

            id: Mapped[int] = mapped_column(Integer, primary_key=True)
            new_name: Mapped[str | None] = mapped_column("old_name", String(64), nullable=True)

        with pytest.raises(AssertionError) as exc_info:
            assert_no_orm_column_aliases(_BBase)

        msg = str(exc_info.value)
        assert "new_name" in msg
        assert "old_name" in msg

    def test_musehub_base_has_zero_aliases(self) -> None:
        """Regression: musehub's production Base must have no column aliases."""
        from musehub.db import muse_cli_models  # noqa: F401
        from musehub.db.database import Base

        assert_no_orm_column_aliases(Base)  # raises AssertionError if any alias exists