gabriel / musehub public
gen_type_contracts.py python
565 lines 18.6 KB
Raw
sha256:3ff9c9863a9891bdcde71b4a43228f66d0493e38b7cc1d09fe9eb7de774046b2 feat: add repair-commit wire endpoint (API parity with repa… Opus 4.8 minor ⚠ breaking 1 day ago
1 """Auto-generate docs/reference/type-contracts.md for a repository.
2
3 Walks Python source files, extracts every named type entity — Enum,
4 TypedDict, Pydantic BaseModel, SQLAlchemy ORM model — and generates a
5 structured Markdown reference document. The generated document replaces the
6 manually maintained ``type-contracts.md`` and is the single source of truth
7 for the repo's type surface.
8
9 Usage::
10
11 # Generate for the current repo (infers repo from cwd):
12 python scripts/gen_type_contracts.py
13
14 # Generate for a specific repo:
15 python scripts/gen_type_contracts.py --repo ~/ecosystem/musehub
16
17 # Check that committed doc matches current source (CI gate):
18 python scripts/gen_type_contracts.py --check
19
20 # Write to a custom path:
21 python scripts/gen_type_contracts.py --output path/to/out.md
22
23 # Target specific subdirectories:
24 python scripts/gen_type_contracts.py --dirs musehub/ tests/
25
26 Exit codes:
27 0 — OK (or --check passed with no drift).
28 1 — --repo is not a directory, or --check failed (document missing or drifted); see stderr.
29 """
30
31 from __future__ import annotations
32
33 import argparse
34 import ast
35 import datetime
36 import difflib
37 import sys
38 from pathlib import Path
39 from typing import TypedDict
40
41 # ---------------------------------------------------------------------------
42 # Type shapes
43 # ---------------------------------------------------------------------------
44
45
class EnumMember(TypedDict):
    """A single ``NAME = value`` entry collected from an Enum class body."""

    # Identifier on the left-hand side of the assignment.
    name: str
    # Source text of the assigned value expression.
    value: str
51
52
class TypeField(TypedDict):
    """An annotated attribute of a TypedDict, Pydantic model, or ORM model."""

    # Attribute name as written in the class body.
    name: str
    # Source text of the type annotation.
    annotation: str
    # Rendered default value; empty string when the field has none.
    default: str
    # True when the field must be supplied (it has no default).
    required: bool
60
61
class TypeRecord(TypedDict):
    """Everything extracted about one tracked class definition."""

    # One of: "enum" | "typeddict" | "pydantic" | "orm" | "protocol".
    kind: str
    # Class name.
    name: str
    # Path of the defining file, as reported by the scanner.
    file: str
    # Line number of the ``class`` statement.
    line: int
    # First paragraph of the class docstring; empty string when absent.
    docstring: str
    # Simple names of the direct base classes.
    bases: list[str]
    # Annotated attributes found in the class body.
    fields: list[TypeField]
    # Enum members; only populated for kind "enum".
    members: list[EnumMember]
    # ORM table name; empty string for non-ORM types.
    table: str
76
77
78 # ---------------------------------------------------------------------------
79 # Skip dirs (mirrors typing_audit.py)
80 # ---------------------------------------------------------------------------
81
# Directory names the scanner never descends into: virtualenvs, tool caches,
# VCS metadata, build output and vendored packages.
_SKIP_DIRS: frozenset[str] = frozenset(
    (
        # virtual environments
        "venv",
        ".venv",
        "env",
        ".env",
        # interpreter / tool caches and VCS metadata
        "__pycache__",
        ".git",
        ".muse",
        ".mypy_cache",
        ".ruff_cache",
        ".pytest_cache",
        ".tox",
        # build artefacts and installed third-party code
        "dist",
        "build",
        "site-packages",
        "__pypackages__",
        "node_modules",
    )
)

# Simple base-class names that place a class in each tracked category.
_ENUM_BASES: frozenset[str] = frozenset(
    ("Enum", "IntEnum", "StrEnum", "Flag", "IntFlag")
)
_PYDANTIC_BASES: frozenset[str] = frozenset(
    ("BaseModel", "CamelModel", "BaseSettings", "SQLModel")
)
_TYPEDDICT_BASES: frozenset[str] = frozenset(("TypedDict",))
_PROTOCOL_BASES: frozenset[str] = frozenset(("Protocol",))
# ORM models have no marker base here: they are recognised by a
# ``__tablename__`` attribute in the class body instead.
98
99
100 # ---------------------------------------------------------------------------
101 # AST helpers
102 # ---------------------------------------------------------------------------
103
104
def _node_to_str(node: ast.expr) -> str:
    """Render an annotation AST node back into compact source text."""
    rendered = ast.unparse(node)
    return rendered
108
109
def _get_docstring(node: ast.ClassDef) -> str:
    """Return the first paragraph of the class docstring on a single line.

    The docstring is the first body statement when it is a bare string
    constant.  Only the text up to the first blank line is kept, and every
    run of whitespace (newlines plus continuation-line indentation) is
    collapsed to one space.  Returns an empty string when the class has no
    docstring.
    """
    if not node.body:
        return ""
    first = node.body[0]
    if not (
        isinstance(first, ast.Expr)
        and isinstance(first.value, ast.Constant)
        and isinstance(first.value.value, str)
    ):
        return ""

    words: list[str] = []
    for line in first.value.value.strip().splitlines():
        # A whitespace-only line ends the first paragraph just like an empty one.
        if not line.strip():
            break
        # split() with no argument discards the leading indentation and any
        # repeated inner spaces, so the join below yields single spaces only.
        words.extend(line.split())
    return " ".join(words)
124
125
def _base_names(node: ast.ClassDef) -> list[str]:
    """Return the simple base-class names for *node*, in declaration order.

    ``Base`` yields ``"Base"``, ``pkg.Base`` yields ``"Base"``, and a
    subscripted base such as ``Protocol[T]`` or ``typing.Generic[T]`` yields
    the name of the class being subscripted.  Any other base expression
    (for example a call) is skipped.
    """
    names: list[str] = []
    for base in node.bases:
        # Strip type parameters so generic bases are still recognised by name.
        while isinstance(base, ast.Subscript):
            base = base.value
        if isinstance(base, ast.Name):
            names.append(base.id)
        elif isinstance(base, ast.Attribute):
            names.append(base.attr)
    return names
135
136
def _has_tablename(node: ast.ClassDef) -> str:
    """Look up a literal ``__tablename__`` assignment in the class body.

    Both plain (``__tablename__ = "x"``) and annotated
    (``__tablename__: str = "x"``) assignments are recognised, but only when
    the assigned value is a constant.  Returns the value converted with
    ``str``, or an empty string when no such assignment exists.
    """
    for stmt in node.body:
        if isinstance(stmt, ast.Assign):
            targets = stmt.targets
        elif isinstance(stmt, ast.AnnAssign):
            targets = [stmt.target]
        else:
            continue
        value = stmt.value
        # Computed table names cannot be resolved statically; keep looking.
        if not isinstance(value, ast.Constant):
            continue
        if any(isinstance(t, ast.Name) and t.id == "__tablename__" for t in targets):
            return str(value.value)
    return ""
150
151
def _extract_default(node: ast.AnnAssign) -> str:
    """Return the field's default value as display text, or an empty string.

    An empty string means "no default": either nothing is assigned, or the
    assigned value is the ``...`` placeholder (bare, as the first positional
    argument of ``Field``, or as ``Field(default=...)``).

    For a ``Field(...)`` call the explicit default is reported on its own,
    whether it is passed positionally or as ``default=``; constants are
    rendered with ``repr`` and any other expression as source text.  A
    ``Field`` that only has ``default_factory`` yields ``"(factory)"``.  A
    ``Field`` with neither, and every non-``Field`` expression, is rendered
    as its full source text.
    """
    val = node.value
    if val is None:
        return ""
    if isinstance(val, ast.Constant):
        # ``x: int = ...`` marks a required field, not a default of Ellipsis.
        return "" if val.value is ... else repr(val.value)

    is_field_call = (
        isinstance(val, ast.Call)
        and isinstance(val.func, ast.Name)
        and val.func.id == "Field"
    )
    if not is_field_call:
        return ast.unparse(val)

    # The default may be given positionally or via the ``default`` keyword.
    default_expr = val.args[0] if val.args else None
    if default_expr is None:
        default_expr = next(
            (kw.value for kw in val.keywords if kw.arg == "default"), None
        )
    if default_expr is not None:
        if isinstance(default_expr, ast.Constant):
            return "" if default_expr.value is ... else repr(default_expr.value)
        return ast.unparse(default_expr)

    if any(kw.arg == "default_factory" for kw in val.keywords):
        return "(factory)"
    return ast.unparse(val)
178
179
def _is_required(value: ast.expr | None, default_str: str, kind: str) -> bool:
    """Decide whether a field assigned *value* must be supplied by callers.

    Outside Pydantic models a field is required exactly when it has no
    rendered default.  Inside Pydantic models ``...`` marks a required field
    wherever it appears (bare, ``Field(...)`` or ``Field(default=...)``), a
    ``Field`` carrying any other default or a ``default_factory`` is
    optional, and a ``Field`` with neither is required.
    """
    if value is None:
        return True
    if kind != "pydantic":
        return default_str == ""
    if isinstance(value, ast.Constant):
        return value.value is ...
    if not isinstance(value, ast.Call):
        return False
    func = value.func
    if not (isinstance(func, ast.Name) and func.id == "Field"):
        # Some other call supplies the value; fall back to the rendered default.
        return default_str == ""

    # The explicit default is either the first positional argument or
    # the ``default`` keyword.
    explicit = value.args[0] if value.args else None
    if explicit is None:
        explicit = next(
            (kw.value for kw in value.keywords if kw.arg == "default"), None
        )
    if explicit is not None:
        return isinstance(explicit, ast.Constant) and explicit.value is ...
    return not any(kw.arg == "default_factory" for kw in value.keywords)


def _classify_and_extract(
    node: ast.ClassDef,
    filepath: str,
) -> TypeRecord | None:
    """Classify a class and extract its type information.

    The category is chosen from the base-class names, checked in the order
    enum, typeddict, pydantic, protocol.  A class matching none of those is
    treated as an ORM model when its body assigns ``__tablename__``.

    Args:
        node: The class definition to inspect.
        filepath: Path stored verbatim in the record's ``file`` entry.

    Returns:
        The populated record, or ``None`` for classes that don't match any
        tracked category.
    """
    bases = _base_names(node)

    kind = ""
    for candidate, markers in (
        ("enum", _ENUM_BASES),
        ("typeddict", _TYPEDDICT_BASES),
        ("pydantic", _PYDANTIC_BASES),
        ("protocol", _PROTOCOL_BASES),
    ):
        if markers.intersection(bases):
            kind = candidate
            break

    table = ""
    if not kind:
        # No marker base: only a literal __tablename__ makes this an ORM model.
        table = _has_tablename(node)
        if not table:
            return None
        kind = "orm"

    members: list[EnumMember] = []
    fields: list[TypeField] = []

    for stmt in node.body:
        # Enum members are plain assignments to public names.
        if kind == "enum" and isinstance(stmt, ast.Assign):
            value_text = ast.unparse(stmt.value)
            for target in stmt.targets:
                if isinstance(target, ast.Name) and not target.id.startswith("_"):
                    members.append(EnumMember(name=target.id, value=value_text))

        # Annotated attributes are recorded as fields for every kind.
        if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
            field_name = stmt.target.id
            # Skip dunder names such as __tablename__; single-underscore
            # names are kept.
            if field_name.startswith("__"):
                continue
            default_str = _extract_default(stmt)
            fields.append(TypeField(
                name=field_name,
                annotation=_node_to_str(stmt.annotation),
                default=default_str,
                required=_is_required(stmt.value, default_str, kind),
            ))

    return TypeRecord(
        kind=kind,
        name=node.name,
        file=filepath,
        line=node.lineno,
        docstring=_get_docstring(node),
        bases=bases,
        fields=fields,
        members=members,
        table=table,
    )
268
269
270 # ---------------------------------------------------------------------------
271 # File and directory scanner
272 # ---------------------------------------------------------------------------
273
274
def scan_file(filepath: Path, root: Path) -> list[TypeRecord]:
    """Extract all type records from a single Python source file.

    Scanning is best-effort: a file that cannot be read, decoded or parsed
    contributes no records instead of aborting the run.

    Args:
        filepath: File to parse.
        root: Directory that record paths are reported relative to.

    Returns:
        One record per tracked class found anywhere in the file, including
        nested classes.
    """
    try:
        source = filepath.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        return []

    try:
        tree = ast.parse(source)
    except (SyntaxError, ValueError):
        # Older interpreters raise ValueError (not SyntaxError) for source
        # that contains null bytes.
        return []

    # Use forward slashes on every platform so the generated document, which
    # CI compares byte-for-byte, does not depend on the operating system.
    try:
        relative = filepath.relative_to(root).as_posix()
    except ValueError:
        # The file lives outside *root*; report its path as given.
        relative = filepath.as_posix()

    records: list[TypeRecord] = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        record = _classify_and_extract(node, relative)
        if record is not None:
            records.append(record)
    return records
298
299
def scan_directories(dirs: list[Path], root: Path) -> list[TypeRecord]:
    """Scan all Python files in the given directories and return type records.

    Each entry of *dirs* may be a directory (searched recursively for
    ``*.py``) or a single ``.py`` file; entries that do not exist are
    ignored.  Inside a directory, files located under a folder named in
    ``_SKIP_DIRS`` are excluded.  A file reachable through several entries
    is scanned once.

    Args:
        dirs: Directories or files to scan.
        root: Directory that record paths are reported relative to.

    Returns:
        Records sorted by kind (enum, typeddict, pydantic, orm, protocol),
        then file, then class name.
    """
    all_records: list[TypeRecord] = []
    seen_files: set[Path] = set()

    for directory in dirs:
        if directory.is_file():
            candidates = [directory] if directory.suffix == ".py" else []
        elif directory.is_dir():
            # Test only the path *below* the scanned directory: a repository
            # checked out under e.g. ``/home/ci/build/`` must not be skipped
            # just because one of its ancestors has a skip-listed name.
            candidates = [
                py_file
                for py_file in sorted(directory.rglob("*.py"))
                if _SKIP_DIRS.isdisjoint(py_file.relative_to(directory).parts)
            ]
        else:
            continue

        for py_file in candidates:
            if py_file in seen_files:
                continue
            seen_files.add(py_file)
            all_records.extend(scan_file(py_file, root))

    # Unknown kinds sort after every known one.
    rank = {"enum": 0, "typeddict": 1, "pydantic": 2, "orm": 3, "protocol": 4}
    all_records.sort(key=lambda r: (rank.get(r["kind"], 9), r["file"], r["name"]))
    return all_records
325
326
327 # ---------------------------------------------------------------------------
328 # Markdown generation
329 # ---------------------------------------------------------------------------
330
# Section title used in the generated document for each record kind.
_KIND_HEADING: dict[str, str] = dict(
    enum="Enums",
    typeddict="TypedDicts",
    pydantic="Pydantic Models",
    orm="SQLAlchemy ORM Models",
    protocol="Protocols",
)

# Table-of-contents link target for each section: the section title
# lower-cased, with spaces turned into hyphens.
_KIND_ANCHOR: dict[str, str] = {
    kind: title.lower().replace(" ", "-")
    for kind, title in _KIND_HEADING.items()
}
346
347
def _field_table(fields: list[TypeField]) -> str:
    """Return a Markdown table for type fields.

    Each row shows the field name, its annotation, a check mark when the
    field is required, and its default (an em dash when there is none).
    Pipe characters in any cell are backslash-escaped so that annotations
    such as ``str | None`` do not split the row into extra columns.  An
    empty *fields* list yields a one-line placeholder instead of a table.
    """
    if not fields:
        return "_No annotated fields._\n"

    def cell(text: str) -> str:
        # A raw "|" ends a table cell even inside a code span.
        return text.replace("|", "\\|")

    rows = [
        "| Field | Type | Required | Default |",
        "|-------|------|:--------:|---------|",
    ]
    for f in fields:
        req = "✓" if f["required"] else ""
        default = f["default"] or "—"
        rows.append(
            f"| `{cell(f['name'])}` | `{cell(f['annotation'])}` "
            f"| {req} | `{cell(default)}` |"
        )
    return "\n".join(rows) + "\n"
359
360
def _enum_table(members: list[EnumMember]) -> str:
    """Return a Markdown table for Enum members.

    Each row shows the member name and the source text of its value.  Pipe
    characters are backslash-escaped so that flag combinations such as
    ``READ | WRITE`` do not split the row into extra columns.  An empty
    *members* list yields a one-line placeholder instead of a table.
    """
    if not members:
        return "_No members._\n"

    def cell(text: str) -> str:
        # A raw "|" ends a table cell even inside a code span.
        return text.replace("|", "\\|")

    rows = [
        "| Member | Value |",
        "|--------|-------|",
    ]
    rows.extend(f"| `{cell(m['name'])}` | `{cell(m['value'])}` |" for m in members)
    return "\n".join(rows) + "\n"
370
371
def generate_markdown(
    records: list[TypeRecord],
    repo_name: str,
    generated_by: str,
    *,
    generated_on: str | None = None,
) -> str:
    """Render *records* into a complete Markdown type-contracts document.

    The document has a header block, a table of contents with one entry per
    non-empty kind, and one section per kind in the fixed order enum,
    typeddict, pydantic, orm, protocol.  Records of any other kind are not
    rendered.

    Args:
        records: Type records to render; their order within a kind is kept.
        repo_name: Display name used in the document title.
        generated_by: Text shown on the ``Source:`` line of the header.
        generated_on: Date text for the ``Auto-generated:`` line.  Defaults
            to today's ISO date; pass a fixed value for reproducible output.

    Returns:
        The full Markdown text.
    """
    today = (
        generated_on
        if generated_on is not None
        else datetime.date.today().isoformat()
    )

    # Group by kind, preserving the incoming order inside each group.
    by_kind: dict[str, list[TypeRecord]] = {}
    for r in records:
        by_kind.setdefault(r["kind"], []).append(r)

    toc_lines: list[str] = []
    body_parts: list[str] = []
    for kind in ("enum", "typeddict", "pydantic", "orm", "protocol"):
        group = by_kind.get(kind)
        if not group:
            continue
        heading = _KIND_HEADING[kind]
        toc_lines.append(f"- [{heading}](#{_KIND_ANCHOR[kind]}) ({len(group)})")
        body_parts.append(f"## {heading}\n")

        for r in group:
            link = f"[`{r['file']}`]({r['file']})"
            body_parts.append(f"### `{r['name']}` — {link}\n")
            if r["docstring"]:
                body_parts.append(f"{r['docstring']}\n")
            if r["bases"]:
                body_parts.append(f"**Bases:** `{'`, `'.join(r['bases'])}`\n")
            if r["table"]:
                body_parts.append(f"**Table:** `{r['table']}`\n")
            if kind == "enum":
                body_parts.append(_enum_table(r["members"]))
            else:
                body_parts.append(_field_table(r["fields"]))
            # The empty part becomes a blank line between entries once joined.
            body_parts.append("")

    toc = "\n".join(toc_lines)
    body = "\n".join(body_parts)

    return f"""\
# {repo_name} — Type Contracts Reference

> Auto-generated: {today}
> Source: `{generated_by}`
> Check for drift: `python tools/gen_type_contracts.py --check`
>
> **Do not edit this file manually.** It is regenerated from source on every
> commit. To update, change the source types and re-run the generator.

---

## Table of Contents

{toc}

---

{body}
"""
440
441
442 # ---------------------------------------------------------------------------
443 # CLI
444 # ---------------------------------------------------------------------------
445
446
def _infer_repo_name(repo_root: Path) -> str:
    """Derive a display name from the repo directory name.

    Hyphens and underscores become spaces and the result is title-cased.
    """
    spaced = repo_root.name.translate(str.maketrans("-_", "  "))
    return spaced.title()
451
452
def _build_parser() -> argparse.ArgumentParser:
    """Construct the command-line parser for the generator."""
    parser = argparse.ArgumentParser(
        description=(
            "Auto-generate docs/reference/type-contracts.md from Python source."
        ),
    )
    parser.add_argument(
        "--repo",
        type=Path,
        default=Path.cwd(),
        metavar="DIR",
        help="Repo root directory. Default: current working directory.",
    )
    parser.add_argument(
        "--dirs",
        nargs="+",
        metavar="DIR",
        default=None,
        help=(
            "Source directories to scan (relative to --repo). "
            "Default: auto-detected from repo name."
        ),
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        metavar="PATH",
        help="Output path (relative to --repo). Default: docs/reference/type-contracts.md",
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="Check that the committed document matches generated output. Exit 1 on drift.",
    )
    return parser


def _without_date_stamp(text: str) -> list[str]:
    """Split *text* into lines, dropping the ``> Auto-generated:`` date line."""
    return [
        line
        for line in text.splitlines(keepends=True)
        if not line.startswith("> Auto-generated:")
    ]


def main(argv: list[str] | None = None) -> None:
    """Entry point — parse CLI flags, scan, generate, and optionally check.

    Args:
        argv: Command-line arguments without the program name.  ``None``
            (the default) reads them from ``sys.argv``.

    Exits with status 1 when ``--repo`` is not a directory, or when
    ``--check`` finds the document missing or out of date.  Otherwise
    returns normally after printing a one-line status message.
    """
    args = _build_parser().parse_args(argv)

    repo_root: Path = args.repo.resolve()
    if not repo_root.is_dir():
        print(f"ERROR: --repo {repo_root} is not a directory", file=sys.stderr)
        sys.exit(1)

    # Determine source directories to scan.
    if args.dirs is not None:
        scan_dirs = [repo_root / d for d in args.dirs]
    else:
        # Auto-detect: scan the package directory matching the repo name, plus tests/.
        repo_name_lc = repo_root.name.lower().replace("-", "").replace("_", "")
        candidates = [repo_root / repo_root.name, repo_root / "src", repo_root / repo_name_lc]
        detected = [d for d in candidates if d.is_dir()]
        tests_dir = repo_root / "tests"
        scan_dirs = (detected or [repo_root]) + ([tests_dir] if tests_dir.is_dir() else [])

    # Determine output path.
    out_rel = args.output or Path("docs/reference/type-contracts.md")
    out_path = (repo_root / out_rel).resolve()

    repo_name = _infer_repo_name(repo_root)
    generated_by = f"tools/gen_type_contracts.py (repo: {repo_root.name})"

    records = scan_directories(scan_dirs, repo_root)
    content = generate_markdown(records, repo_name, generated_by)

    if args.check:
        if not out_path.exists():
            print(
                f"ERROR: --check failed — {out_path} does not exist.\n"
                "Run without --check to generate it.",
                file=sys.stderr,
            )
            sys.exit(1)
        # The header embeds the generation date.  Comparing it would report
        # drift on every day after the document was committed, so the date
        # line is left out of both sides.
        committed = _without_date_stamp(out_path.read_text(encoding="utf-8"))
        generated = _without_date_stamp(content)
        if committed == generated:
            print(f"✅ type-contracts.md is up to date ({out_path})")
            return
        # Show a short diff summary.
        diff = list(difflib.unified_diff(
            committed,
            generated,
            fromfile="committed",
            tofile="generated",
            n=3,
        ))
        print(
            f"❌ type-contracts.md is out of date ({out_path}).\n"
            f"Regenerate with: python tools/gen_type_contracts.py\n"
            f"\nDiff ({len(diff)} lines):",
            file=sys.stderr,
        )
        sys.stderr.writelines(diff[:80])
        if len(diff) > 80:
            print(f"... ({len(diff) - 80} more diff lines)", file=sys.stderr)
        sys.exit(1)

    # Write output.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(content, encoding="utf-8")

    # Per-kind totals for the status line, listed in alphabetical kind order.
    counts: dict[str, int] = {}
    for r in records:
        counts[r["kind"]] = counts.get(r["kind"], 0) + 1
    summary = ", ".join(
        f"{n} {_KIND_HEADING.get(k, k)}"
        for k, n in sorted(counts.items())
    )
    print(f"✅ Generated {out_path} ({len(records)} types: {summary})")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
File History 1 commit
sha256:3ff9c9863a9891bdcde71b4a43228f66d0493e38b7cc1d09fe9eb7de774046b2 feat: add repair-commit wire endpoint (API parity with repa… Opus 4.8 minor 1 day ago