gabriel / musehub public
gen_type_contracts.py python
566 lines 18.6 KB
Raw
sha256:3ff9c9863a9891bdcde71b4a43228f66d0493e38b7cc1d09fe9eb7de774046b2 feat: add repair-commit wire endpoint (API parity with repa… Opus 4.8 minor ⚠ breaking 22 hours ago
1 #!/usr/bin/env python3
2 """Auto-generate docs/reference/type-contracts.md for a repository.
3
Walks Python source files, extracts every named type entity — Enum,
TypedDict, Pydantic BaseModel, SQLAlchemy ORM model, Protocol — and generates a
6 structured Markdown reference document. The generated document replaces the
7 manually maintained ``type-contracts.md`` and is the single source of truth
8 for the repo's type surface.
9
10 Usage::
11
12 # Generate for the current repo (infers repo from cwd):
13 python scripts/gen_type_contracts.py
14
15 # Generate for a specific repo:
16 python scripts/gen_type_contracts.py --repo ~/ecosystem/musehub
17
18 # Check that committed doc matches current source (CI gate):
19 python scripts/gen_type_contracts.py --check
20
21 # Write to a custom path:
22 python scripts/gen_type_contracts.py --output path/to/out.md
23
24 # Target specific subdirectories:
25 python scripts/gen_type_contracts.py --dirs musehub/ tests/
26
27 Exit codes:
28 0 — OK (or --check passed with no drift).
29 1 — --check detected drift; see stderr for diff summary.
30 """
31
32 from __future__ import annotations
33
34 import argparse
35 import ast
36 import datetime
37 import difflib
38 import sys
39 from pathlib import Path
40 from typing import TypedDict
41
42 # ---------------------------------------------------------------------------
43 # Type shapes
44 # ---------------------------------------------------------------------------
45
46
class EnumMember(TypedDict):
    """Name/value pair describing a single member of an extracted Enum."""

    # Identifier of the member as written in the class body.
    name: str
    # Source text of the expression assigned to the member.
    value: str
52
53
class TypeField(TypedDict):
    """Description of one annotated attribute on a TypedDict, Pydantic or ORM model."""

    # Attribute name as declared in the class body.
    name: str
    # Annotation rendered back to source text.
    annotation: str
    # Rendered default value; empty string when the field has none.
    default: str
    # Whether a value must be supplied for this field.
    required: bool
61
62
class TypeRecord(TypedDict):
    """Everything extracted about one type entity found in the source tree."""

    # One of: "enum" | "typeddict" | "pydantic" | "orm" | "protocol".
    kind: str
    # Class name.
    name: str
    # Path of the defining file.
    file: str
    # Line number of the class statement.
    line: int
    # First paragraph of the class docstring (may be empty).
    docstring: str
    # Simple names of the base classes.
    bases: list[str]
    # Annotated fields found in the class body.
    fields: list[TypeField]
    # Enum members found in the class body.
    members: list[EnumMember]
    # ORM table name; empty string for non-ORM types.
    table: str
77
78
79 # ---------------------------------------------------------------------------
80 # Skip dirs (mirrors typing_audit.py)
81 # ---------------------------------------------------------------------------
82
# Directory names that the scanner never descends into.
_SKIP_DIRS: frozenset[str] = frozenset((
    # virtual environments
    "venv",
    ".venv",
    "env",
    ".env",
    # caches and VCS metadata
    "__pycache__",
    ".git",
    ".muse",
    ".mypy_cache",
    ".ruff_cache",
    ".pytest_cache",
    ".tox",
    # build artefacts and vendored packages
    "dist",
    "build",
    "site-packages",
    "__pypackages__",
    "node_modules",
))

# Simple base-class names that place a class in each tracked category.
_ENUM_BASES: frozenset[str] = frozenset(
    ("Enum", "IntEnum", "StrEnum", "Flag", "IntFlag")
)
_PYDANTIC_BASES: frozenset[str] = frozenset(
    ("BaseModel", "CamelModel", "BaseSettings", "SQLModel")
)
_TYPEDDICT_BASES: frozenset[str] = frozenset(("TypedDict",))
_PROTOCOL_BASES: frozenset[str] = frozenset(("Protocol",))
98 # ORM detection: class with __tablename__ attribute in the body.
99
100
101 # ---------------------------------------------------------------------------
102 # AST helpers
103 # ---------------------------------------------------------------------------
104
105
def _node_to_str(node: ast.expr) -> str:
    """Render an annotation expression back to source text via ``ast.unparse``."""
    rendered = ast.unparse(node)
    return rendered
109
110
def _get_docstring(node: ast.ClassDef) -> str:
    """Return the first paragraph of the class docstring as a single line.

    The docstring is the first statement of the class body when that statement
    is a bare string constant.  Only the text up to the first blank line is
    kept, and every run of whitespace in it (newlines plus continuation-line
    indentation) is collapsed to one space.

    Returns:
        The normalised first paragraph, or an empty string when the class has
        no docstring.
    """
    if not node.body:
        return ""
    first = node.body[0]
    if not (
        isinstance(first, ast.Expr)
        and isinstance(first.value, ast.Constant)
        and isinstance(first.value.value, str)
    ):
        return ""

    words: list[str] = []
    for line in first.value.value.strip().splitlines():
        # A line holding only whitespace ends the first paragraph, even when
        # the editor left indentation on the "blank" line.
        if not line.strip():
            break
        # split() drops leading indentation and squeezes inner whitespace runs.
        words.extend(line.split())
    return " ".join(words)
125
126
def _base_names(node: ast.ClassDef) -> list[str]:
    """Return the simple base-class names for *node*.

    ``Name`` bases contribute their identifier and dotted bases contribute the
    final attribute (``enum.Enum`` -> ``"Enum"``).  Subscripted bases such as
    ``Protocol[T]`` or ``typing.Generic[T]`` are unwrapped first so that
    parametrised bases are still recognised.  Any other expression (for
    example a call) is ignored.
    """
    names: list[str] = []
    for base in node.bases:
        # Look through a subscript to the thing being parametrised.
        target = base.value if isinstance(base, ast.Subscript) else base
        if isinstance(target, ast.Name):
            names.append(target.id)
        elif isinstance(target, ast.Attribute):
            names.append(target.attr)
    return names
136
137
def _has_tablename(node: ast.ClassDef) -> str:
    """Look up a constant ``__tablename__`` assignment in the class body.

    Both plain (``__tablename__ = "x"``) and annotated
    (``__tablename__: str = "x"``) assignments are considered.  The constant
    is passed through ``str``; an empty string is returned when no such
    assignment with a constant value exists.
    """
    for stmt in node.body:
        if isinstance(stmt, ast.Assign):
            targets = stmt.targets
        elif isinstance(stmt, ast.AnnAssign):
            targets = [stmt.target]
        else:
            continue

        value = stmt.value
        if not isinstance(value, ast.Constant):
            # Missing or computed value: keep looking at later statements.
            continue
        if any(
            isinstance(target, ast.Name) and target.id == "__tablename__"
            for target in targets
        ):
            return str(value.value)
    return ""
151
152
def _extract_default(node: ast.AnnAssign) -> str:
    """Return a string representation of the default value, or empty string.

    Rules:

    * no assigned value -> ``""``
    * ``Field(...)`` / ``Field(default=...)`` (Ellipsis) -> ``""`` (required)
    * ``Field(<value>, ...)`` or ``Field(default=<value>)`` -> the value
      itself (``repr`` for constants, source text otherwise), not the whole
      ``Field`` call
    * ``Field(default_factory=...)`` -> ``"(factory)"``
    * ``Field(...)`` carrying no default at all -> ``""``
    * any other constant -> its ``repr``; any other expression -> source text

    ``Field`` is recognised both as a bare name and as a dotted attribute
    (for example ``pydantic.Field``).
    """
    val = node.value
    if val is None:
        return ""

    if isinstance(val, ast.Call):
        func = val.func
        is_field = (isinstance(func, ast.Name) and func.id == "Field") or (
            isinstance(func, ast.Attribute) and func.attr == "Field"
        )
        if not is_field:
            return ast.unparse(val)

        # The default is the first positional argument, else ``default=``.
        explicit: ast.expr | None = val.args[0] if val.args else None
        if explicit is None:
            for kw in val.keywords:
                if kw.arg == "default":
                    explicit = kw.value
                    break

        if explicit is not None:
            if isinstance(explicit, ast.Constant):
                # Ellipsis marks the field as required: no default to show.
                return "" if explicit.value is ... else repr(explicit.value)
            return ast.unparse(explicit)

        if any(kw.arg == "default_factory" for kw in val.keywords):
            return "(factory)"
        # Field(description=...) and friends: metadata only, no default.
        return ""

    if isinstance(val, ast.Constant):
        return repr(val.value)
    return ast.unparse(val)
179
180
def _pydantic_required(value: ast.expr, fallback: bool) -> bool:
    """Decide whether a Pydantic field with assigned *value* is required.

    ``Field(...)`` and ``Field(default=...)`` (Ellipsis) are required, any
    other positional or ``default=`` argument makes the field optional, and so
    does ``default_factory=``.  A ``Field`` call with no default at all is
    required.  A bare ``...`` constant is required, other constants and
    expressions are optional.  Calls to anything other than ``Field`` keep
    *fallback*.
    """
    if isinstance(value, ast.Call):
        func = value.func
        is_field = (isinstance(func, ast.Name) and func.id == "Field") or (
            isinstance(func, ast.Attribute) and func.attr == "Field"
        )
        if not is_field:
            return fallback
        if value.args:
            first = value.args[0]
            # A non-constant positional argument is still a default value.
            return isinstance(first, ast.Constant) and first.value is ...
        for kw in value.keywords:
            if kw.arg == "default":
                return isinstance(kw.value, ast.Constant) and kw.value.value is ...
            if kw.arg == "default_factory":
                return False
        return True
    if isinstance(value, ast.Constant):
        return value.value is ...
    return False


def _classify_and_extract(
    node: ast.ClassDef,
    filepath: str,
) -> TypeRecord | None:
    """Classify a class and extract its type information.

    The category is chosen from the simple base-class names, checked in the
    order enum, TypedDict, Pydantic, Protocol.  A class matching none of them
    is treated as an ORM model when its body assigns ``__tablename__``.

    Args:
        node: The class definition to inspect.
        filepath: Path stored verbatim in the record's ``file`` entry.

    Returns:
        A populated ``TypeRecord``, or ``None`` for classes that don't match
        any tracked category.
    """
    bases = _base_names(node)
    bases_set = set(bases)

    table = ""
    if bases_set & _ENUM_BASES:
        kind = "enum"
    elif bases_set & _TYPEDDICT_BASES:
        kind = "typeddict"
    elif bases_set & _PYDANTIC_BASES:
        kind = "pydantic"
    elif bases_set & _PROTOCOL_BASES:
        kind = "protocol"
    else:
        # ORM: identified by a __tablename__ assignment, looked up only once.
        table = _has_tablename(node)
        if not table:
            return None
        kind = "orm"

    docstring = _get_docstring(node)

    members: list[EnumMember] = []
    fields: list[TypeField] = []

    for stmt in node.body:
        # --- Enum members: plain assignments to non-underscore names ---
        if kind == "enum" and isinstance(stmt, ast.Assign):
            for target in stmt.targets:
                if isinstance(target, ast.Name) and not target.id.startswith("_"):
                    val_str = ast.unparse(stmt.value) if stmt.value else ""
                    members.append(EnumMember(name=target.id, value=val_str))

        # --- Annotated fields (collected for every kind) ---
        if not (isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name)):
            continue
        field_name = stmt.target.id
        # Skip dunder names; single-underscore names are kept.
        if field_name.startswith("__"):
            continue

        annotation = _node_to_str(stmt.annotation) if stmt.annotation else ""
        default_str = _extract_default(stmt)
        required = (stmt.value is None) or (default_str == "")
        if kind == "pydantic" and stmt.value is not None:
            required = _pydantic_required(stmt.value, required)

        fields.append(TypeField(
            name=field_name,
            annotation=annotation,
            default=default_str,
            required=required,
        ))

    return TypeRecord(
        kind=kind,
        name=node.name,
        file=filepath,
        line=node.lineno,
        docstring=docstring,
        bases=bases,
        fields=fields,
        members=members,
        table=table,
    )
269
270
271 # ---------------------------------------------------------------------------
272 # File and directory scanner
273 # ---------------------------------------------------------------------------
274
275
def scan_file(filepath: Path, root: Path) -> list[TypeRecord]:
    """Extract all type records from a single Python source file.

    Args:
        filepath: File to parse.
        root: Directory that record paths are expressed relative to.

    Returns:
        One record per tracked class found anywhere in the file (nested
        classes included).  Files that cannot be read, decoded or parsed
        yield an empty list instead of raising.
    """
    try:
        source = filepath.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        return []

    try:
        tree = ast.parse(source)
    except (SyntaxError, ValueError):
        # ValueError: older interpreters reject NUL bytes this way.
        return []

    # POSIX separators keep the generated document identical on every OS.
    try:
        relative = filepath.relative_to(root).as_posix()
    except ValueError:
        # File lives outside *root*; fall back to the path as given.
        relative = filepath.as_posix()

    records: list[TypeRecord] = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        record = _classify_and_extract(node, relative)
        if record is not None:
            records.append(record)

    return records
299
300
def scan_directories(dirs: list[Path], root: Path) -> list[TypeRecord]:
    """Scan all Python files in the given directories and return type records.

    Args:
        dirs: Directories to walk recursively.  An entry may also be a single
            ``.py`` file.  Entries that do not exist are ignored.
        root: Directory that record paths are expressed relative to.

    Returns:
        Records sorted by kind (enum < typeddict < pydantic < orm < protocol),
        then file, then name.  Each file is scanned at most once even when the
        given directories overlap.
    """
    all_records: list[TypeRecord] = []
    seen_files: set[Path] = set()

    for directory in dirs:
        if not directory.exists():
            continue
        if directory.is_file():
            if directory.suffix == ".py" and directory not in seen_files:
                seen_files.add(directory)
                all_records.extend(scan_file(directory, root))
            continue
        for py_file in sorted(directory.rglob("*.py")):
            if py_file in seen_files:
                continue
            # Only components below the scanned directory count: a checkout
            # living under e.g. ~/build/ must not exclude the whole repository.
            inner_parts = py_file.relative_to(directory).parts
            if any(part in _SKIP_DIRS for part in inner_parts):
                continue
            seen_files.add(py_file)
            all_records.extend(scan_file(py_file, root))

    # Sort: by kind (enum < typeddict < pydantic < orm < protocol), then file, then name.
    _kind_order = {"enum": 0, "typeddict": 1, "pydantic": 2, "orm": 3, "protocol": 4}
    all_records.sort(key=lambda r: (_kind_order.get(r["kind"], 9), r["file"], r["name"]))
    return all_records
326
327
328 # ---------------------------------------------------------------------------
329 # Markdown generation
330 # ---------------------------------------------------------------------------
331
# Section title shown for each record kind, in document order.
_KIND_HEADING: dict[str, str] = dict(
    enum="Enums",
    typeddict="TypedDicts",
    pydantic="Pydantic Models",
    orm="SQLAlchemy ORM Models",
    protocol="Protocols",
)

# Link anchor for each section: the title lower-cased with spaces as hyphens.
_KIND_ANCHOR: dict[str, str] = {
    kind: title.lower().replace(" ", "-")
    for kind, title in _KIND_HEADING.items()
}
347
348
def _field_table(fields: list[TypeField]) -> str:
    """Return a Markdown table for type fields.

    Each row shows the field name, its annotation, a check mark when the
    field is required, and the default (an em dash when there is none).
    Pipe characters in annotations and defaults are backslash-escaped: a bare
    ``|`` (as in ``str | None``) would otherwise end the table cell, even
    inside a code span.

    Returns:
        The table text ending in a newline, or a placeholder line when
        *fields* is empty.
    """
    if not fields:
        return "_No annotated fields._\n"
    lines = ["| Field | Type | Required | Default |",
             "|-------|------|:--------:|---------|"]
    for f in fields:
        req = "✓" if f["required"] else ""
        default = f["default"] if f["default"] else "—"
        annotation = f["annotation"].replace("|", "\\|")
        default = default.replace("|", "\\|")
        lines.append(f"| `{f['name']}` | `{annotation}` | {req} | `{default}` |")
    return "\n".join(lines) + "\n"
360
361
def _enum_table(members: list[EnumMember]) -> str:
    """Return a Markdown table for Enum members.

    Pipe characters in member values (for example the flag combination
    ``R | W``) are backslash-escaped so they do not split the table cell.

    Returns:
        The table text ending in a newline, or a placeholder line when
        *members* is empty.
    """
    if not members:
        return "_No members._\n"
    lines = ["| Member | Value |",
             "|--------|-------|"]
    for m in members:
        value = m["value"].replace("|", "\\|")
        lines.append(f"| `{m['name']}` | `{value}` |")
    return "\n".join(lines) + "\n"
371
372
def generate_markdown(
    records: list[TypeRecord],
    repo_name: str,
    generated_by: str,
) -> str:
    """Build the full Markdown type-contracts document from *records*.

    The document consists of a header stamped with today's date and
    *generated_by*, a table of contents with one entry per kind that has
    records, and one section per kind listing each record with its docstring,
    bases, table name and member/field table.
    """
    today = datetime.date.today().isoformat()

    # Bucket the records by kind, keeping their incoming order.
    grouped: dict[str, list[TypeRecord]] = {}
    for record in records:
        grouped.setdefault(record["kind"], []).append(record)

    toc_lines: list[str] = []
    body_parts: list[str] = []
    for kind in ("enum", "typeddict", "pydantic", "orm", "protocol"):
        entries = grouped.get(kind)
        if entries is None:
            continue
        heading = _KIND_HEADING[kind]
        anchor = _KIND_ANCHOR[kind]
        toc_lines.append(f"- [{heading}](#{anchor}) ({len(entries)})")

        body_parts.append(f"## {heading}\n")
        for entry in entries:
            path = entry["file"]
            link = f"[`{path}`]({path})"
            section = [f"### `{entry['name']}` — {link}\n"]
            if entry["docstring"]:
                section.append(f"{entry['docstring']}\n")
            if entry["bases"]:
                section.append(f"**Bases:** `{'`, `'.join(entry['bases'])}`\n")
            if entry["table"]:
                section.append(f"**Table:** `{entry['table']}`\n")
            # Enums list their members; every other kind lists its fields.
            if kind == "enum":
                section.append(_enum_table(entry["members"]))
            else:
                section.append(_field_table(entry["fields"]))
            section.append("")
            body_parts.extend(section)

    toc = "\n".join(toc_lines)
    body = "\n".join(body_parts)

    return f"""\
# {repo_name} — Type Contracts Reference

> Auto-generated: {today}
> Source: `{generated_by}`
> Check for drift: `python scripts/gen_type_contracts.py --check`
>
> **Do not edit this file manually.** It is regenerated from source on every
> commit. To update, change the source types and re-run the generator.

---

## Table of Contents

{toc}

---

{body}
"""
441
442
443 # ---------------------------------------------------------------------------
444 # CLI
445 # ---------------------------------------------------------------------------
446
447
def _infer_repo_name(repo_root: Path) -> str:
    """Derive a display name from the directory name of *repo_root*.

    Hyphens and underscores become spaces and the result is title-cased.
    """
    separators_to_spaces = str.maketrans("-_", "  ")
    return repo_root.name.translate(separators_to_spaces).title()
452
453
def _mask_generation_date(text: str) -> list[str]:
    """Split *text* into lines with the ``> Auto-generated:`` date masked out."""
    return [
        "> Auto-generated: <date>\n" if line.startswith("> Auto-generated:") else line
        for line in text.splitlines(keepends=True)
    ]


def main() -> None:
    """Entry point — parse CLI flags, scan, generate, and optionally check.

    Without ``--check`` the document is written to the output path and a
    one-line summary is printed.  With ``--check`` nothing is written: the
    committed document is compared with freshly generated content and the
    process exits with status 1 when it is missing or differs.  The
    generation date in the header is ignored by that comparison, since it
    changes every day without any change to the source types.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Auto-generate docs/reference/type-contracts.md from Python source."
        ),
    )
    parser.add_argument(
        "--repo",
        type=Path,
        default=Path.cwd(),
        metavar="DIR",
        help="Repo root directory. Default: current working directory.",
    )
    parser.add_argument(
        "--dirs",
        nargs="+",
        metavar="DIR",
        default=None,
        help=(
            "Source directories to scan (relative to --repo). "
            "Default: auto-detected from repo name."
        ),
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        metavar="PATH",
        help="Output path (relative to --repo). Default: docs/reference/type-contracts.md",
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="Check that the committed document matches generated output. Exit 1 on drift.",
    )
    args = parser.parse_args()

    repo_root: Path = args.repo.resolve()
    if not repo_root.is_dir():
        print(f"ERROR: --repo {repo_root} is not a directory", file=sys.stderr)
        sys.exit(1)

    # Determine source directories to scan.
    if args.dirs is not None:
        scan_dirs = [repo_root / d for d in args.dirs]
    else:
        # Auto-detect: scan the package directory matching the repo name, plus tests/.
        repo_name_lc = repo_root.name.lower().replace("-", "").replace("_", "")
        candidates = [repo_root / repo_root.name, repo_root / "src", repo_root / repo_name_lc]
        detected = [d for d in candidates if d.is_dir()]
        tests_dir = repo_root / "tests"
        scan_dirs = (detected or [repo_root]) + ([tests_dir] if tests_dir.is_dir() else [])

    # Determine output path.
    out_rel = args.output or Path("docs/reference/type-contracts.md")
    out_path = (repo_root / out_rel).resolve()

    repo_name = _infer_repo_name(repo_root)
    generated_by = f"scripts/gen_type_contracts.py (repo: {repo_root.name})"

    records = scan_directories(scan_dirs, repo_root)
    content = generate_markdown(records, repo_name, generated_by)

    if args.check:
        if not out_path.exists():
            print(
                f"ERROR: --check failed — {out_path} does not exist.\n"
                "Run without --check to generate it.",
                file=sys.stderr,
            )
            sys.exit(1)
        # Compare with the date line masked so only real content counts as drift.
        committed_lines = _mask_generation_date(out_path.read_text(encoding="utf-8"))
        generated_lines = _mask_generation_date(content)
        if committed_lines == generated_lines:
            print(f"✅ type-contracts.md is up to date ({out_path})")
            return
        # Show a short diff summary.
        diff = list(difflib.unified_diff(
            committed_lines,
            generated_lines,
            fromfile="committed",
            tofile="generated",
            n=3,
        ))
        print(
            f"❌ type-contracts.md is out of date ({out_path}).\n"
            f"Regenerate with: python scripts/gen_type_contracts.py\n"
            f"\nDiff ({len(diff)} lines):",
            file=sys.stderr,
        )
        sys.stderr.writelines(diff[:80])
        if len(diff) > 80:
            print(f"... ({len(diff) - 80} more diff lines)", file=sys.stderr)
        sys.exit(1)

    # Write output.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(content, encoding="utf-8")

    # Per-kind tally for the summary line, listed in alphabetical kind order.
    counts: dict[str, int] = {}
    for r in records:
        counts[r["kind"]] = counts.get(r["kind"], 0) + 1
    summary = ", ".join(
        f"{counts[kind]} {_KIND_HEADING.get(kind, kind)}"
        for kind in sorted(counts)
    )
    print(f"✅ Generated {out_path} ({len(records)} types: {summary})")


if __name__ == "__main__":
    main()
File History 1 commit
sha256:3ff9c9863a9891bdcde71b4a43228f66d0493e38b7cc1d09fe9eb7de774046b2 feat: add repair-commit wire endpoint (API parity with repa… Opus 4.8 minor 22 hours ago