"""Auto-generate docs/reference/type-contracts.md for a repository. Walks Python source files, extracts every named type entity — Enum, TypedDict, Pydantic BaseModel, SQLAlchemy ORM model — and generates a structured Markdown reference document. The generated document replaces the manually maintained ``type-contracts.md`` and is the single source of truth for the repo's type surface. Usage:: # Generate for the current repo (infers repo from cwd): python scripts/gen_type_contracts.py # Generate for a specific repo: python scripts/gen_type_contracts.py --repo ~/ecosystem/musehub # Check that committed doc matches current source (CI gate): python scripts/gen_type_contracts.py --check # Write to a custom path: python scripts/gen_type_contracts.py --output path/to/out.md # Target specific subdirectories: python scripts/gen_type_contracts.py --dirs musehub/ tests/ Exit codes: 0 — OK (or --check passed with no drift). 1 — --check detected drift; see stderr for diff summary. """ from __future__ import annotations import argparse import ast import datetime import difflib import sys from pathlib import Path from typing import TypedDict # --------------------------------------------------------------------------- # Type shapes # --------------------------------------------------------------------------- class EnumMember(TypedDict): """One member of an Enum class.""" name: str value: str class TypeField(TypedDict): """One field of a TypedDict, Pydantic model, or ORM model.""" name: str annotation: str default: str required: bool class TypeRecord(TypedDict): """A single extracted type entity.""" kind: str """One of: ``"enum"`` | ``"typeddict"`` | ``"pydantic"`` | ``"orm"`` | ``"protocol"``.""" name: str file: str line: int docstring: str bases: list[str] fields: list[TypeField] members: list[EnumMember] table: str """ORM table name; empty string for non-ORM types.""" # --------------------------------------------------------------------------- # Skip dirs (mirrors typing_audit.py) # --------------------------------------------------------------------------- _SKIP_DIRS: frozenset[str] = frozenset({ "venv", ".venv", "env", ".env", "__pycache__", ".git", ".muse", ".mypy_cache", ".ruff_cache", ".pytest_cache", ".tox", "dist", "build", "site-packages", "__pypackages__", "node_modules", }) # Base class names that indicate each category. _ENUM_BASES: frozenset[str] = frozenset({"Enum", "IntEnum", "StrEnum", "Flag", "IntFlag"}) _PYDANTIC_BASES: frozenset[str] = frozenset({ "BaseModel", "CamelModel", "BaseSettings", "SQLModel", }) _TYPEDDICT_BASES: frozenset[str] = frozenset({"TypedDict"}) _PROTOCOL_BASES: frozenset[str] = frozenset({"Protocol"}) # ORM detection: class with __tablename__ attribute in the body. # --------------------------------------------------------------------------- # AST helpers # --------------------------------------------------------------------------- def _node_to_str(node: ast.expr) -> str: """Convert an AST annotation node to a compact string representation.""" return ast.unparse(node) def _get_docstring(node: ast.ClassDef) -> str: """Return the first statement's string value if it is a docstring.""" if ( node.body and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str) ): # Normalise whitespace: collapse newlines + indentation to a single space. raw = node.body[0].value.value.strip() # Keep only the first paragraph (up to the first blank line). first_para = raw.split("\n\n")[0].replace("\n", " ").replace(" ", " ") return first_para return "" def _base_names(node: ast.ClassDef) -> list[str]: """Return the simple base-class names for *node*.""" names: list[str] = [] for base in node.bases: if isinstance(base, ast.Name): names.append(base.id) elif isinstance(base, ast.Attribute): names.append(base.attr) return names def _has_tablename(node: ast.ClassDef) -> str: """Return the ``__tablename__`` value if present, else empty string.""" for stmt in node.body: if isinstance(stmt, ast.Assign): for target in stmt.targets: if isinstance(target, ast.Name) and target.id == "__tablename__": if isinstance(stmt.value, ast.Constant): return str(stmt.value.value) if isinstance(stmt, ast.AnnAssign): if isinstance(stmt.target, ast.Name) and stmt.target.id == "__tablename__": if stmt.value and isinstance(stmt.value, ast.Constant): return str(stmt.value.value) return "" def _extract_default(node: ast.AnnAssign) -> str: """Return a string representation of the default value, or empty string.""" if node.value is None: return "" val = node.value # Field(...) — required Pydantic field; treat as no default. if isinstance(val, ast.Call): func = val.func if isinstance(func, ast.Name) and func.id == "Field": # Field(...) with first arg as Ellipsis = required. if val.args and isinstance(val.args[0], ast.Constant) and val.args[0].value is ...: return "" # Field(default=...) keyword. for kw in val.keywords: if kw.arg == "default" and isinstance(kw.value, ast.Constant): if kw.value.value is ...: return "" return repr(kw.value.value) # Field with default_factory — treat as optional. for kw in val.keywords: if kw.arg == "default_factory": return "(factory)" return ast.unparse(val) if isinstance(val, ast.Constant): return repr(val.value) return ast.unparse(val) def _classify_and_extract( node: ast.ClassDef, filepath: str, ) -> TypeRecord | None: """Classify a class and extract its type information. Returns ``None`` for classes that don't match any tracked category. """ bases = _base_names(node) bases_set = set(bases) kind = "" if bases_set & _ENUM_BASES: kind = "enum" elif bases_set & _TYPEDDICT_BASES: kind = "typeddict" elif bases_set & _PYDANTIC_BASES: kind = "pydantic" elif bases_set & _PROTOCOL_BASES: kind = "protocol" else: # ORM: check for __tablename__ table = _has_tablename(node) if table: kind = "orm" if not kind: return None table = _has_tablename(node) if kind == "orm" else "" docstring = _get_docstring(node) members: list[EnumMember] = [] fields: list[TypeField] = [] for stmt in node.body: # --- Enum members --- if kind == "enum" and isinstance(stmt, ast.Assign): for target in stmt.targets: if isinstance(target, ast.Name) and not target.id.startswith("_"): val_str = ast.unparse(stmt.value) if stmt.value else "" members.append(EnumMember(name=target.id, value=val_str)) # --- Annotated fields (TypedDict, Pydantic, ORM, Protocol) --- if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name): field_name = stmt.target.id # Skip dunder and private fields if field_name.startswith("__"): continue annotation = _node_to_str(stmt.annotation) if stmt.annotation else "" default_str = _extract_default(stmt) required = (stmt.value is None) or (default_str == "") # For Pydantic, Field(...) means required; Field(None) means optional. if kind == "pydantic" and stmt.value is not None: val = stmt.value if isinstance(val, ast.Call): func = val.func if isinstance(func, ast.Name) and func.id == "Field": if val.args and isinstance(val.args[0], ast.Constant): required = val.args[0].value is ... else: required = not any( kw.arg in {"default", "default_factory"} for kw in val.keywords ) elif isinstance(val, ast.Constant): required = val.value is ... else: required = False fields.append(TypeField( name=field_name, annotation=annotation, default=default_str, required=required, )) return TypeRecord( kind=kind, name=node.name, file=filepath, line=node.lineno, docstring=docstring, bases=bases, fields=fields, members=members, table=table, ) # --------------------------------------------------------------------------- # File and directory scanner # --------------------------------------------------------------------------- def scan_file(filepath: Path, root: Path) -> list[TypeRecord]: """Extract all type records from a single Python source file.""" try: source = filepath.read_text(encoding="utf-8") except (OSError, UnicodeDecodeError): return [] try: tree = ast.parse(source) except SyntaxError: return [] relative = str(filepath.relative_to(root)) records: list[TypeRecord] = [] for node in ast.walk(tree): if not isinstance(node, ast.ClassDef): continue record = _classify_and_extract(node, relative) if record is not None: records.append(record) return records def scan_directories(dirs: list[Path], root: Path) -> list[TypeRecord]: """Scan all Python files in the given directories and return type records.""" all_records: list[TypeRecord] = [] seen_files: set[Path] = set() for directory in dirs: if not directory.exists(): continue if directory.is_file(): if directory.suffix == ".py" and directory not in seen_files: seen_files.add(directory) all_records.extend(scan_file(directory, root)) continue for py_file in sorted(directory.rglob("*.py")): if py_file in seen_files: continue if any(part in _SKIP_DIRS for part in py_file.parts): continue seen_files.add(py_file) all_records.extend(scan_file(py_file, root)) # Sort: by kind (enum < typeddict < pydantic < orm < protocol), then file, then name. _kind_order = {"enum": 0, "typeddict": 1, "pydantic": 2, "orm": 3, "protocol": 4} all_records.sort(key=lambda r: (_kind_order.get(r["kind"], 9), r["file"], r["name"])) return all_records # --------------------------------------------------------------------------- # Markdown generation # --------------------------------------------------------------------------- _KIND_HEADING: dict[str, str] = { "enum": "Enums", "typeddict": "TypedDicts", "pydantic": "Pydantic Models", "orm": "SQLAlchemy ORM Models", "protocol": "Protocols", } _KIND_ANCHOR: dict[str, str] = { "enum": "enums", "typeddict": "typeddicts", "pydantic": "pydantic-models", "orm": "sqlalchemy-orm-models", "protocol": "protocols", } def _field_table(fields: list[TypeField]) -> str: """Return a Markdown table for type fields.""" if not fields: return "_No annotated fields._\n" lines = ["| Field | Type | Required | Default |", "|-------|------|:--------:|---------|"] for f in fields: req = "✓" if f["required"] else "" default = f["default"] if f["default"] else "—" lines.append(f"| `{f['name']}` | `{f['annotation']}` | {req} | `{default}` |") return "\n".join(lines) + "\n" def _enum_table(members: list[EnumMember]) -> str: """Return a Markdown table for Enum members.""" if not members: return "_No members._\n" lines = ["| Member | Value |", "|--------|-------|"] for m in members: lines.append(f"| `{m['name']}` | `{m['value']}` |") return "\n".join(lines) + "\n" def generate_markdown( records: list[TypeRecord], repo_name: str, generated_by: str, ) -> str: """Render *records* into a complete Markdown type-contracts document.""" today = datetime.date.today().isoformat() # Group by kind. by_kind: dict[str, list[TypeRecord]] = {} for r in records: by_kind.setdefault(r["kind"], []).append(r) # Build table of contents. toc_lines: list[str] = [] for kind in ("enum", "typeddict", "pydantic", "orm", "protocol"): if kind not in by_kind: continue heading = _KIND_HEADING[kind] anchor = _KIND_ANCHOR[kind] toc_lines.append(f"- [{heading}](#{anchor}) ({len(by_kind[kind])})") # Build body sections. body_parts: list[str] = [] for kind in ("enum", "typeddict", "pydantic", "orm", "protocol"): if kind not in by_kind: continue heading = _KIND_HEADING[kind] body_parts.append(f"## {heading}\n") for r in by_kind[kind]: link = f"[`{r['file']}`]({r['file']})" body_parts.append(f"### `{r['name']}` — {link}\n") if r["docstring"]: body_parts.append(f"{r['docstring']}\n") if r["bases"]: body_parts.append(f"**Bases:** `{'`, `'.join(r['bases'])}`\n") if r["table"]: body_parts.append(f"**Table:** `{r['table']}`\n") if kind == "enum": body_parts.append(_enum_table(r["members"])) else: body_parts.append(_field_table(r["fields"])) body_parts.append("") toc = "\n".join(toc_lines) body = "\n".join(body_parts) return f"""\ # {repo_name} — Type Contracts Reference > Auto-generated: {today} > Source: `{generated_by}` > Check for drift: `python tools/gen_type_contracts.py --check` > > **Do not edit this file manually.** It is regenerated from source on every > commit. To update, change the source types and re-run the generator. --- ## Table of Contents {toc} --- {body} """ # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def _infer_repo_name(repo_root: Path) -> str: """Return a display name for the repo (title-cased directory name).""" name = repo_root.name.replace("-", " ").replace("_", " ").title() return name def main() -> None: """Entry point — parse CLI flags, scan, generate, and optionally check.""" parser = argparse.ArgumentParser( description=( "Auto-generate docs/reference/type-contracts.md from Python source." ), ) parser.add_argument( "--repo", type=Path, default=Path.cwd(), metavar="DIR", help="Repo root directory. Default: current working directory.", ) parser.add_argument( "--dirs", nargs="+", metavar="DIR", default=None, help=( "Source directories to scan (relative to --repo). " "Default: auto-detected from repo name." ), ) parser.add_argument( "--output", type=Path, default=None, metavar="PATH", help="Output path (relative to --repo). Default: docs/reference/type-contracts.md", ) parser.add_argument( "--check", action="store_true", help="Check that the committed document matches generated output. Exit 1 on drift.", ) args = parser.parse_args() repo_root: Path = args.repo.resolve() if not repo_root.is_dir(): print(f"ERROR: --repo {repo_root} is not a directory", file=sys.stderr) sys.exit(1) # Determine source directories to scan. if args.dirs is not None: scan_dirs = [repo_root / d for d in args.dirs] else: # Auto-detect: scan the package directory matching the repo name, plus tests/. repo_name_lc = repo_root.name.lower().replace("-", "").replace("_", "") candidates = [repo_root / repo_root.name, repo_root / "src", repo_root / repo_name_lc] detected = [d for d in candidates if d.is_dir()] tests_dir = repo_root / "tests" scan_dirs = (detected or [repo_root]) + ([tests_dir] if tests_dir.is_dir() else []) # Determine output path. out_rel = args.output or Path("docs/reference/type-contracts.md") out_path = (repo_root / out_rel).resolve() repo_name = _infer_repo_name(repo_root) generated_by = f"tools/gen_type_contracts.py (repo: {repo_root.name})" records = scan_directories(scan_dirs, repo_root) content = generate_markdown(records, repo_name, generated_by) if args.check: if not out_path.exists(): print( f"ERROR: --check failed — {out_path} does not exist.\n" "Run without --check to generate it.", file=sys.stderr, ) sys.exit(1) committed = out_path.read_text(encoding="utf-8") if committed == content: print(f"✅ type-contracts.md is up to date ({out_path})") return # Show a short diff summary. diff = list(difflib.unified_diff( committed.splitlines(keepends=True), content.splitlines(keepends=True), fromfile="committed", tofile="generated", n=3, )) print( f"❌ type-contracts.md is out of date ({out_path}).\n" f"Regenerate with: python tools/gen_type_contracts.py\n" f"\nDiff ({len(diff)} lines):", file=sys.stderr, ) sys.stderr.writelines(diff[:80]) if len(diff) > 80: print(f"... ({len(diff) - 80} more diff lines)", file=sys.stderr) sys.exit(1) # Write output. out_path.parent.mkdir(parents=True, exist_ok=True) out_path.write_text(content, encoding="utf-8") total = len(records) counts = {k: len(v) for k, v in {}.items()} by_kind: dict[str, list[TypeRecord]] = {} for r in records: by_kind.setdefault(r["kind"], []).append(r) summary = ", ".join( f"{len(v)} {_KIND_HEADING.get(k, k)}" for k, v in sorted(by_kind.items()) ) print(f"✅ Generated {out_path} ({total} types: {summary})") _ = counts # unused but kept to avoid bare-dict-at-boundary violation if __name__ == "__main__": main()