gabriel / musehub public
test_schema_check.py python
455 lines 17.7 KB
Raw
sha256:3ff9c9863a9891bdcde71b4a43228f66d0493e38b7cc1d09fe9eb7de774046b2 feat: add repair-commit wire endpoint (API parity with repa… Opus 4.8 minor ⚠ breaking 1 day ago
1 from __future__ import annotations
2
3 """Tests for musehub.db.schema_check.
4
5 Covers:
6 - assert_schema_matches_orm: passes for a correct schema
7 - assert_schema_matches_orm: raises RuntimeError for a missing table
8 - assert_schema_matches_orm: raises RuntimeError for a missing column
9 - assert_schema_matches_orm: raises RuntimeError for a nullable mismatch
10 - assert_schema_matches_orm: raises RuntimeError for JSONB→JSON type drift
11 - assert_schema_matches_orm: raises RuntimeError for VARCHAR length mismatch
12 - assert_schema_matches_orm: raises RuntimeError for a missing index
13 - assert_schema_matches_orm: raises RuntimeError for a missing FK
14 - assert_schema_matches_orm: skips silently for SQLite engines
15 - assert_no_orm_column_aliases: passes when all keys match column names
16 - assert_no_orm_column_aliases: raises AssertionError when an alias exists
17 - assert_no_orm_column_aliases: musehub Base has zero aliases (regression)
18 """
19
20 from unittest.mock import AsyncMock, MagicMock
21
22 import pytest
23 from sqlalchemy import ForeignKey, Index, Integer, String, Table
24 from sqlalchemy.dialects.postgresql import JSONB
25 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
26
27 from musehub.types.json_types import JSONObject, JSONValue # noqa: F401 — needed for ForwardRef resolution in Mapped[JSONObject]
28
29 type _ColumnsByTable = dict[str, list[JSONObject]]
30 from musehub.db.schema_check import (
31 _build_alias_map,
32 _inspect_schema,
33 assert_no_orm_column_aliases,
34 assert_schema_matches_orm,
35 )
36
37
38 # ---------------------------------------------------------------------------
39 # Helpers
40 # ---------------------------------------------------------------------------
41
42
43 class _TestBase(DeclarativeBase):
44 pass
45
46
47 class _Gadget(_TestBase):
48 __tablename__ = "gadgets"
49 __table_args__ = (
50 Index("ix_gadgets_serial", "serial"),
51 )
52 id: Mapped[int] = mapped_column(Integer, primary_key=True)
53 serial: Mapped[str] = mapped_column(String(64), nullable=False)
54 notes: Mapped[str | None] = mapped_column(String(256), nullable=True)
55
56
57 class _Widget(_TestBase):
58 __tablename__ = "widgets"
59 id: Mapped[int] = mapped_column(Integer, primary_key=True)
60 gadget_id: Mapped[int] = mapped_column(
61 Integer,
62 ForeignKey("gadgets.id", ondelete="CASCADE"),
63 nullable=False,
64 )
65 data: Mapped[JSONObject | None] = mapped_column(JSONB, nullable=True)
66
67
68 def _make_engine(url: str = "postgresql+asyncpg://user:pass@localhost/test") -> MagicMock:
69 engine = MagicMock()
70 engine.url = MagicMock()
71 engine.url.__str__ = lambda _: url
72 return engine
73
74
75 def _make_inspector(
76 tables: list[str],
77 columns_by_table: _ColumnsByTable,
78 indexes_by_table: dict[str, list[dict]] | None = None,
79 fks_by_table: dict[str, list[dict]] | None = None,
80 ) -> MagicMock:
81 inspector = MagicMock()
82 inspector.get_table_names.return_value = tables
83 inspector.get_columns.side_effect = lambda table: columns_by_table.get(table, [])
84 inspector.get_indexes.side_effect = lambda table: (indexes_by_table or {}).get(table, [])
85 inspector.get_foreign_keys.side_effect = lambda table: (fks_by_table or {}).get(table, [])
86 return inspector
87
88
89 def _good_inspector() -> MagicMock:
90 """Inspector that satisfies _TestBase (gadgets + widgets) fully."""
91 return _make_inspector(
92 tables=["gadgets", "widgets"],
93 columns_by_table={
94 "gadgets": [
95 {"name": "id", "nullable": False, "type": Integer()},
96 {"name": "serial", "nullable": False, "type": String(64)},
97 {"name": "notes", "nullable": True, "type": String(256)},
98 ],
99 "widgets": [
100 {"name": "id", "nullable": False, "type": Integer()},
101 {"name": "gadget_id", "nullable": False, "type": Integer()},
102 {"name": "data", "nullable": True, "type": JSONB()},
103 ],
104 },
105 indexes_by_table={
106 "gadgets": [{"name": "ix_gadgets_serial", "column_names": ["serial"], "unique": False}],
107 "widgets": [],
108 },
109 fks_by_table={
110 "gadgets": [],
111 "widgets": [
112 {
113 "constrained_columns": ["gadget_id"],
114 "referred_table": "gadgets",
115 "referred_columns": ["id"],
116 "name": "fk_widgets_gadget_id",
117 }
118 ],
119 },
120 )
121
122
123 def _alias_map() -> JSONObject:
124 return _build_alias_map(_TestBase)
125
126
127 # ---------------------------------------------------------------------------
128 # SQLite skip
129 # ---------------------------------------------------------------------------
130
131
132 class TestSqliteSkip:
133 async def test_sqlite_engine_is_skipped(self) -> None:
134 engine = _make_engine("sqlite+aiosqlite:///./test.db")
135 with MagicMock() as connect_mock:
136 engine.connect = connect_mock
137 connect_mock.side_effect = AssertionError("should not connect for SQLite")
138 await assert_schema_matches_orm(engine, _TestBase) # must not raise
139
140
141 # ---------------------------------------------------------------------------
142 # Passing cases
143 # ---------------------------------------------------------------------------
144
145
146 def _wired_engine(mismatches: list[str]) -> MagicMock:
147 """Return a mock async engine whose run_sync returns *mismatches* directly."""
148 engine = _make_engine()
149 conn = AsyncMock()
150 conn.run_sync = AsyncMock(return_value=mismatches)
151 engine.connect = MagicMock()
152 engine.connect.return_value.__aenter__ = AsyncMock(return_value=conn)
153 engine.connect.return_value.__aexit__ = AsyncMock(return_value=False)
154 return engine
155
156
157 class TestPassingCases:
158 async def test_exact_match_passes(self) -> None:
159 await assert_schema_matches_orm(_wired_engine([]), _TestBase)
160
161 async def test_extra_db_columns_are_allowed(self) -> None:
162 await assert_schema_matches_orm(_wired_engine([]), _TestBase)
163
164 def test_fully_correct_schema_produces_no_mismatches(self) -> None:
165 mismatches = _inspect_schema(_good_inspector(), _TestBase, _alias_map())
166 assert mismatches == []
167
168
169 # ---------------------------------------------------------------------------
170 # Table / column / nullability failure detection
171 # ---------------------------------------------------------------------------
172
173
174 class TestTableAndColumnChecks:
175 def test_missing_table_detected(self) -> None:
176 inspector = _make_inspector(tables=[], columns_by_table={})
177 mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
178 assert any("missing table" in m and "gadgets" in m for m in mismatches)
179
180 def test_missing_column_detected(self) -> None:
181 inspector = _make_inspector(
182 tables=["gadgets", "widgets"],
183 columns_by_table={
184 "gadgets": [
185 {"name": "id", "nullable": False, "type": Integer()},
186 # "serial" is missing
187 {"name": "notes", "nullable": True, "type": String(256)},
188 ],
189 "widgets": [],
190 },
191 )
192 mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
193 assert any("serial" in m and "missing from DB" in m for m in mismatches)
194
195 def test_nullable_mismatch_detected(self) -> None:
196 inspector = _make_inspector(
197 tables=["gadgets", "widgets"],
198 columns_by_table={
199 "gadgets": [
200 {"name": "id", "nullable": False, "type": Integer()},
201 {"name": "serial", "nullable": True, "type": String(64)}, # ORM says NOT NULL
202 {"name": "notes", "nullable": True, "type": String(256)},
203 ],
204 "widgets": [],
205 },
206 )
207 mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
208 assert any("serial" in m and "nullable mismatch" in m for m in mismatches)
209
210 def test_alias_column_error_includes_orm_key(self) -> None:
211 class _ABase(DeclarativeBase):
212 pass
213
214 class _A(_ABase):
215 __tablename__ = "a_table"
216 id: Mapped[int] = mapped_column(Integer, primary_key=True)
217 python_name: Mapped[str | None] = mapped_column("db_name", String(64), nullable=True)
218
219 inspector = _make_inspector(
220 tables=["a_table"],
221 columns_by_table={"a_table": [{"name": "id", "nullable": False, "type": Integer()}]},
222 )
223 alias_map = _build_alias_map(_ABase)
224 mismatches = _inspect_schema(inspector, _ABase, alias_map)
225 assert any("db_name" in m and "python_name" in m for m in mismatches)
226
227 async def test_raises_runtime_error_on_mismatch(self) -> None:
228 engine = _wired_engine(["missing table: 'gadgets'"])
229 with pytest.raises(RuntimeError, match="Schema drift detected"):
230 await assert_schema_matches_orm(engine, _TestBase)
231
232
233 # ---------------------------------------------------------------------------
234 # JSONB type drift
235 # ---------------------------------------------------------------------------
236
237
238 class TestJsonbTypeCheck:
239 def test_jsonb_in_db_passes(self) -> None:
240 inspector = _good_inspector()
241 mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
242 assert not any("JSONB" in m for m in mismatches)
243
244 def test_json_instead_of_jsonb_detected(self) -> None:
245 from sqlalchemy import JSON
246
247 inspector = _make_inspector(
248 tables=["gadgets", "widgets"],
249 columns_by_table={
250 "gadgets": [
251 {"name": "id", "nullable": False, "type": Integer()},
252 {"name": "serial", "nullable": False, "type": String(64)},
253 {"name": "notes", "nullable": True, "type": String(256)},
254 ],
255 "widgets": [
256 {"name": "id", "nullable": False, "type": Integer()},
257 {"name": "gadget_id", "nullable": False, "type": Integer()},
258 {"name": "data", "nullable": True, "type": JSON()}, # ← wrong: should be JSONB
259 ],
260 },
261 fks_by_table={
262 "gadgets": [],
263 "widgets": [
264 {
265 "constrained_columns": ["gadget_id"],
266 "referred_table": "gadgets",
267 "referred_columns": ["id"],
268 "name": "fk_widgets_gadget_id",
269 }
270 ],
271 },
272 )
273 mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
274 assert any("data" in m and "JSONB" in m and "JSON" in m for m in mismatches)
275
276
277 # ---------------------------------------------------------------------------
278 # VARCHAR length mismatch
279 # ---------------------------------------------------------------------------
280
281
282 class TestVarcharLengthCheck:
283 def test_matching_length_passes(self) -> None:
284 mismatches = _inspect_schema(_good_inspector(), _TestBase, _alias_map())
285 assert not any("varchar length" in m for m in mismatches)
286
287 def test_length_mismatch_detected(self) -> None:
288 inspector = _make_inspector(
289 tables=["gadgets", "widgets"],
290 columns_by_table={
291 "gadgets": [
292 {"name": "id", "nullable": False, "type": Integer()},
293 {"name": "serial", "nullable": False, "type": String(128)}, # ORM says 64
294 {"name": "notes", "nullable": True, "type": String(256)},
295 ],
296 "widgets": [
297 {"name": "id", "nullable": False, "type": Integer()},
298 {"name": "gadget_id", "nullable": False, "type": Integer()},
299 {"name": "data", "nullable": True, "type": JSONB()},
300 ],
301 },
302 indexes_by_table={
303 "gadgets": [{"name": "ix_gadgets_serial", "column_names": ["serial"], "unique": False}],
304 "widgets": [],
305 },
306 fks_by_table={
307 "gadgets": [],
308 "widgets": [
309 {
310 "constrained_columns": ["gadget_id"],
311 "referred_table": "gadgets",
312 "referred_columns": ["id"],
313 "name": "fk_widgets_gadget_id",
314 }
315 ],
316 },
317 )
318 mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
319 assert any("serial" in m and "varchar length" in m and "64" in m and "128" in m for m in mismatches)
320
321
322 # ---------------------------------------------------------------------------
323 # Index existence
324 # ---------------------------------------------------------------------------
325
326
327 class TestIndexCheck:
328 def test_present_index_passes(self) -> None:
329 mismatches = _inspect_schema(_good_inspector(), _TestBase, _alias_map())
330 assert not any("index" in m and "missing" in m for m in mismatches)
331
332 def test_missing_index_detected(self) -> None:
333 inspector = _make_inspector(
334 tables=["gadgets", "widgets"],
335 columns_by_table={
336 "gadgets": [
337 {"name": "id", "nullable": False, "type": Integer()},
338 {"name": "serial", "nullable": False, "type": String(64)},
339 {"name": "notes", "nullable": True, "type": String(256)},
340 ],
341 "widgets": [
342 {"name": "id", "nullable": False, "type": Integer()},
343 {"name": "gadget_id", "nullable": False, "type": Integer()},
344 {"name": "data", "nullable": True, "type": JSONB()},
345 ],
346 },
347 indexes_by_table={
348 "gadgets": [], # ← ix_gadgets_serial is missing
349 "widgets": [],
350 },
351 fks_by_table={
352 "gadgets": [],
353 "widgets": [
354 {
355 "constrained_columns": ["gadget_id"],
356 "referred_table": "gadgets",
357 "referred_columns": ["id"],
358 "name": "fk_widgets_gadget_id",
359 }
360 ],
361 },
362 )
363 mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
364 assert any("ix_gadgets_serial" in m and "missing from DB" in m for m in mismatches)
365
366
367 # ---------------------------------------------------------------------------
368 # FK existence
369 # ---------------------------------------------------------------------------
370
371
372 class TestFkCheck:
373 def test_present_fk_passes(self) -> None:
374 mismatches = _inspect_schema(_good_inspector(), _TestBase, _alias_map())
375 assert not any("FK" in m and "missing" in m for m in mismatches)
376
377 def test_missing_fk_detected(self) -> None:
378 inspector = _make_inspector(
379 tables=["gadgets", "widgets"],
380 columns_by_table={
381 "gadgets": [
382 {"name": "id", "nullable": False, "type": Integer()},
383 {"name": "serial", "nullable": False, "type": String(64)},
384 {"name": "notes", "nullable": True, "type": String(256)},
385 ],
386 "widgets": [
387 {"name": "id", "nullable": False, "type": Integer()},
388 {"name": "gadget_id", "nullable": False, "type": Integer()},
389 {"name": "data", "nullable": True, "type": JSONB()},
390 ],
391 },
392 indexes_by_table={
393 "gadgets": [{"name": "ix_gadgets_serial", "column_names": ["serial"], "unique": False}],
394 "widgets": [],
395 },
396 fks_by_table={
397 "gadgets": [],
398 "widgets": [], # ← FK gadget_id → gadgets is missing
399 },
400 )
401 mismatches = _inspect_schema(inspector, _TestBase, _alias_map())
402 assert any("gadget_id" in m and "gadgets" in m and "missing from DB" in m for m in mismatches)
403
404
405 # ---------------------------------------------------------------------------
406 # assert_no_orm_column_aliases
407 # ---------------------------------------------------------------------------
408
409
410 class TestAssertNoOrmColumnAliases:
411 def test_clean_base_passes(self) -> None:
412 class _CleanBase(DeclarativeBase):
413 pass
414
415 class _Clean(_CleanBase):
416 __tablename__ = "clean"
417 id: Mapped[int] = mapped_column(Integer, primary_key=True)
418 value: Mapped[str] = mapped_column(String(64), nullable=False)
419
420 assert_no_orm_column_aliases(_CleanBase) # must not raise
421
422 def test_aliased_column_raises(self) -> None:
423 class _ABase(DeclarativeBase):
424 pass
425
426 class _A(_ABase):
427 __tablename__ = "atbl"
428 id: Mapped[int] = mapped_column(Integer, primary_key=True)
429 py_name: Mapped[str | None] = mapped_column("db_name", String(64), nullable=True)
430
431 with pytest.raises(AssertionError, match="py_name"):
432 assert_no_orm_column_aliases(_ABase)
433
434 def test_error_message_shows_both_names(self) -> None:
435 class _BBase(DeclarativeBase):
436 pass
437
438 class _B(_BBase):
439 __tablename__ = "btbl"
440 id: Mapped[int] = mapped_column(Integer, primary_key=True)
441 new_name: Mapped[str | None] = mapped_column("old_name", String(64), nullable=True)
442
443 with pytest.raises(AssertionError) as exc_info:
444 assert_no_orm_column_aliases(_BBase)
445
446 msg = str(exc_info.value)
447 assert "new_name" in msg
448 assert "old_name" in msg
449
450 def test_musehub_base_has_zero_aliases(self) -> None:
451 """Regression: musehub's production Base must have no column aliases."""
452 from musehub.db import muse_cli_models # noqa: F401
453 from musehub.db.database import Base
454
455 assert_no_orm_column_aliases(Base) # raises AssertionError if any alias exists
File History 1 commit
sha256:3ff9c9863a9891bdcde71b4a43228f66d0493e38b7cc1d09fe9eb7de774046b2 feat: add repair-commit wire endpoint (API parity with repa… Opus 4.8 minor 1 day ago