gabriel / muse public
deps.py python
593 lines 21.8 KB
Raw
1 """muse code deps — import graph and call-graph analysis.
2
3 Answers questions that Git cannot answer structurally:
4
5 **File mode** (``muse code deps src/billing.py``):
6 What does this file import, and what files in the repo import it?
7
8 **Symbol mode** (``muse code deps "src/billing.py::compute_invoice_total"``):
9 What does this function call? (Python only; uses stdlib ``ast``.)
10 With ``--reverse``: what symbols in the repo call this function?
11 With ``--depth N``: walk the call chain transitively N hops deep.
12 With ``--transitive``: unlimited-depth BFS through the full call graph.
13
14 These relationships are *structural impossibilities* in Git: Git stores files
15 as blobs of text with no concept of imports or call-sites. Muse reads the
16 typed symbol graph and the AST to answer these questions in milliseconds.
17
18 Usage::
19
20 muse code deps src/billing.py # what does billing.py import?
21 muse code deps src/billing.py --reverse # who imports billing.py?
22 muse code deps src/billing.py --reverse --filter tests/ # only importers in tests/
23 muse code deps "src/billing.py::compute_total" # direct callees (Python)
24 muse code deps "src/billing.py::compute_total" --reverse # direct callers
25 muse code deps "src/billing.py::compute_total" --reverse --depth 3 # 3-hop callers
26 muse code deps "src/billing.py::compute_total" --transitive # full blast radius
27 muse code deps "src/billing.py::compute_total" --count # caller count only
28
29 Flags:
30
31 ``--commit, -c REF``
32 Inspect a historical snapshot instead of HEAD.
33
34 ``--reverse``
35 Invert the query: show callers (symbol mode) or importers (file mode).
36
37 ``--depth N``
38 With ``--reverse``: walk transitive callers up to N hops.
39 Without ``--reverse``: walk transitive callees up to N hops.
40 ``0`` means unlimited (same as ``--transitive``).
41
42 ``--transitive``
43 Walk the full call graph without a depth cap (equivalent to ``--depth 0``).
44
45 ``--filter PATTERN``
46 Restrict results to addresses or file paths containing PATTERN.
47
48 ``--count``
49 Emit only the total count of results.
50
51 ``--json``
52 Emit results as JSON.
53 """
54
55 import argparse
56 import fnmatch
57 import json
58 import logging
59 import pathlib
60 import sys
61 from typing import TypedDict
62
63 from muse.core.envelope import EnvelopeJson, make_envelope
64 from muse.core.errors import ExitCode
65 from muse.core.repo import require_repo
66 from muse.core.refs import read_current_branch
67 from muse.core.commits import resolve_commit_ref
68 from muse.core.snapshots import get_commit_snapshot_manifest
69 from muse.core.symbol_cache import load_symbol_cache
70 from muse.core.timing import start_timer
71 from muse.plugins.code._callgraph import (
72 build_forward_graph,
73 build_reverse_graph,
74 callees_for_symbol,
75 transitive_callees,
76 transitive_callers,
77 )
78 from muse.plugins.code._query import language_of, symbols_for_snapshot
79 from muse.plugins.code.ast_parser import SymbolTree
80 from muse.core.validation import clamp_int, sanitize_display
81 from muse.core.types import Manifest
82
83 logger = logging.getLogger(__name__)
84
85 # ---------------------------------------------------------------------------
86 # Typed output shapes
87 # ---------------------------------------------------------------------------
88
89 class _DepsFileJson(EnvelopeJson, total=False):
90 """JSON output for file-mode ``muse code deps``.
91
92 Inherits the 6 standard envelope fields from :class:`~muse.core.envelope.EnvelopeJson`.
93
94 Fields
95 ------
96 path The target file path.
97 imports Forward: modules imported by the file.
98 imported_by Reverse: files that import the target.
99 """
100
101 path: str
102 imports: list[str]
103 imported_by: list[str]
104
105 class _DepsSymbolJson(EnvelopeJson, total=False):
106 """JSON output for symbol-mode ``muse code deps``.
107
108 Inherits the 6 standard envelope fields from :class:`~muse.core.envelope.EnvelopeJson`.
109
110 Fields
111 ------
112 address The full symbol address (file::symbol).
113 target_name Bare symbol name used for reverse-graph lookup.
114 depth BFS depth limit used (0 = unlimited / transitive).
115 transitive True when --transitive was passed.
116 calls Forward depth=1: direct callees.
117 called_by Reverse depth=1: direct callers.
118 by_depth Multi-hop: depth (str) → sorted list of addresses.
119 """
120
121 address: str
122 target_name: str
123 depth: int
124 transitive: bool
125 calls: list[str]
126 called_by: list[str]
127 by_depth: dict[str, list[str]]
128
129 def _validate_file_rel(file_rel: str) -> None:
130 """Exit with USER_ERROR if *file_rel* looks like a path traversal attempt."""
131 if not file_rel:
132 print("❌ Target file path cannot be empty.", file=sys.stderr)
133 raise SystemExit(ExitCode.USER_ERROR)
134 p = pathlib.PurePosixPath(file_rel)
135 if p.is_absolute() or ".." in p.parts:
136 print(
137 f"❌ Target path '{file_rel}' must be a relative path"
138 " with no '..' components.",
139 file=sys.stderr,
140 )
141 raise SystemExit(ExitCode.USER_ERROR)
142
143 # ---------------------------------------------------------------------------
144 # Import graph helpers
145 # ---------------------------------------------------------------------------
146
147 def _imports_in_tree(tree: SymbolTree) -> list[str]:
148 """Return the list of module/symbol names imported by symbols in *tree*."""
149 return sorted(
150 rec["qualified_name"]
151 for rec in tree.values()
152 if rec["kind"] == "import"
153 )
154
155 def _file_imports(
156 root: pathlib.Path,
157 manifest: Manifest,
158 target_file: str,
159 *,
160 workdir: pathlib.Path | None = None,
161 ) -> list[str]:
162 """Return import names declared in *target_file* within *manifest*."""
163 cache = load_symbol_cache(root)
164 sym_map = symbols_for_snapshot(
165 root, manifest, file_filter=target_file, workdir=workdir, cache=cache
166 )
167 cache.save()
168 tree = sym_map.get(target_file, {})
169 return _imports_in_tree(tree)
170
171 def _reverse_imports(
172 root: pathlib.Path,
173 manifest: Manifest,
174 target_file: str,
175 file_filter: str | None = None,
176 ) -> list[str]:
177 """Return files in *manifest* that import a name matching *target_file*.
178
179 The heuristic: the target file's stem (e.g. ``billing`` for
180 ``src/billing.py``) is matched against each other file's import names.
181 This catches ``import billing``, ``from billing import X``, and fully-
182 qualified paths like ``src.billing``.
183
184 Args:
185 root: Repository root.
186 manifest: Snapshot manifest.
187 target_file: File path to find importers of.
188 file_filter: Optional glob pattern — only files matching this pattern
189 are included in the importer list.
190 """
191 target_stem = pathlib.PurePosixPath(target_file).stem
192 target_module = (
193 pathlib.PurePosixPath(target_file)
194 .with_suffix("")
195 .as_posix()
196 .replace("/", ".")
197 )
198 # Use the shared cache — one load for the entire manifest scan.
199 cache = load_symbol_cache(root)
200 sym_map = symbols_for_snapshot(root, manifest, cache=cache)
201 cache.save()
202
203 importers: list[str] = []
204 for file_path, tree in sym_map.items():
205 if file_path == target_file:
206 continue
207 if file_filter and not fnmatch.fnmatch(file_path, f"*{file_filter}*"):
208 continue
209 for imp_name in _imports_in_tree(tree):
210 if (
211 imp_name == target_stem
212 or imp_name == target_module
213 or imp_name.endswith(f".{target_stem}")
214 or imp_name.endswith(f".{target_module}")
215 or target_stem in imp_name.split(".")
216 ):
217 importers.append(file_path)
218 break
219 return sorted(importers)
220
221 # ---------------------------------------------------------------------------
222 # Registration
223 # ---------------------------------------------------------------------------
224
225 def register(
226 subparsers: "argparse._SubParsersAction[argparse.ArgumentParser]",
227 ) -> None:
228 """Register the deps subcommand."""
229 parser = subparsers.add_parser(
230 "deps",
231 help="Show the import graph or call graph for a file or symbol.",
232 description=__doc__,
233 formatter_class=argparse.RawDescriptionHelpFormatter,
234 )
235 parser.add_argument(
236 "target",
237 metavar="TARGET",
238 help=(
239 'File path (e.g. "src/billing.py") for import graph, or '
240 'symbol address (e.g. "src/billing.py::compute_invoice_total")'
241 " for call graph."
242 ),
243 )
244 parser.add_argument(
245 "--reverse", "-r",
246 action="store_true",
247 help="Show importers (file mode) or callers (symbol mode) instead.",
248 )
249 parser.add_argument(
250 "--depth",
251 type=int,
252 default=1,
253 metavar="N",
254 help=(
255 "Walk transitive callers/callees up to N hops"
256 " (symbol mode only; 0 = unlimited)."
257 ),
258 )
259 parser.add_argument(
260 "--transitive",
261 action="store_true",
262 help="Walk the full call graph without a depth cap.",
263 )
264 parser.add_argument(
265 "--filter",
266 dest="file_filter",
267 default=None,
268 metavar="PATTERN",
269 help="Restrict results to addresses/paths containing PATTERN.",
270 )
271 parser.add_argument(
272 "--count",
273 action="store_true",
274 help="Emit only the total count of results.",
275 )
276 parser.add_argument(
277 "--commit", "-c",
278 default=None,
279 metavar="REF",
280 dest="ref",
281 help="Inspect a historical commit instead of HEAD.",
282 )
283 parser.add_argument(
284 "--json", "-j",
285 action="store_true",
286 dest="json_out",
287 help="Emit results as JSON.",
288 )
289 parser.set_defaults(func=run)
290
291 # ---------------------------------------------------------------------------
292 # Run
293 # ---------------------------------------------------------------------------
294
295 def run(args: argparse.Namespace) -> None:
296 """Show the import graph or call graph for a file or symbol.
297
298 File mode (pass a file path) shows outbound imports or inbound importers.
299 Symbol mode (pass a ``file.py::Symbol`` address) walks the Python call
300 graph forward (callees) or reverse (callers) up to ``--depth`` hops, or
301 transitively with ``--transitive``.
302
303 Agent quickstart
304 ----------------
305 ::
306
307 muse code deps src/billing.py --json
308 muse code deps src/billing.py --reverse --json
309 muse code deps "src/billing.py::compute_total" --json
310 muse code deps "src/billing.py::compute_total" --transitive --json
311
312 JSON fields (file mode, forward)
313 ---------------------------------
314 path File path analysed.
315 imports List of module names imported by the file.
316
317 JSON fields (file mode, reverse)
318 ---------------------------------
319 path File path analysed.
320 imported_by List of file paths that import this file.
321
322 JSON fields (symbol mode, forward)
323 ------------------------------------
324 address Symbol address analysed.
325 depth Hop depth requested.
326 calls List of callee symbol addresses (depth-1) or by-depth map.
327
328 Exit codes
329 ----------
330 0 Analysis complete.
331 1 Symbol or file not found, invalid arguments.
332 2 Not inside a Muse repository.
333 """
334 elapsed = start_timer()
335 target: str = args.target
336 reverse: bool = args.reverse
337 ref: str | None = args.ref
338 json_out: bool = args.json_out
339 depth: int = clamp_int(args.depth, 1, 50, "depth")
340 transitive: bool = args.transitive
341 file_filter: str | None = args.file_filter
342 count_only: bool = args.count
343
344 # --transitive overrides --depth to unlimited (0 signals no BFS cap internally).
345 effective_depth = 0 if transitive else depth
346
347 root = require_repo()
348 branch = read_current_branch(root)
349
350 is_symbol_mode = "::" in target
351
352 # ── Symbol mode: call-graph (Python only) ─────────────────────────────────
353 if is_symbol_mode:
354 file_rel, sym_qualified = target.split("::", 1)
355 _validate_file_rel(file_rel)
356
357 lang = language_of(file_rel)
358 if lang != "Python":
359 print(
360 f"❌ Call-graph analysis is currently Python-only."
361 f" '{file_rel}' is {lang}.",
362 file=sys.stderr,
363 )
364 raise SystemExit(ExitCode.USER_ERROR)
365
366 commit = resolve_commit_ref(root, branch, ref)
367 if commit is None:
368 print(
369 f"❌ Commit '{ref or 'HEAD'}' not found.", file=sys.stderr
370 )
371 raise SystemExit(ExitCode.USER_ERROR)
372
373 manifest = get_commit_snapshot_manifest(root, commit.commit_id) or {}
374 shared_cache = load_symbol_cache(root)
375
376 if not reverse:
377 # ── Forward: what does this symbol call? ──────────────────────
378 if effective_depth == 1:
379 # Single hop — fast path using source bytes.
380 obj_id = manifest.get(file_rel)
381 if obj_id is None:
382 print(
383 f"❌ '{file_rel}' not found in snapshot"
384 f" '{commit.commit_id}'.",
385 file=sys.stderr,
386 )
387 raise SystemExit(ExitCode.USER_ERROR)
388 from muse.core.object_store import read_object # noqa: PLC0415
389 raw = read_object(root, obj_id)
390 if raw is None:
391 print(
392 f"❌ Object for '{file_rel}' missing from store.",
393 file=sys.stderr,
394 )
395 raise SystemExit(ExitCode.USER_ERROR)
396 callees = callees_for_symbol(raw, target)
397 if file_filter:
398 callees = [c for c in callees if file_filter in c]
399 shared_cache.save()
400
401 if count_only:
402 print(len(callees))
403 return
404 if json_out:
405 print(json.dumps(_DepsSymbolJson(
406 **make_envelope(elapsed),
407 address=target,
408 depth=1,
409 calls=callees,
410 )))
411 return
412 print(f"\nDirect callees of {sanitize_display(target)}")
413 print("─" * 62)
414 if not callees:
415 print(" (no function calls detected)")
416 else:
417 for name in callees:
418 print(f" {name}")
419 print(f"\n{len(callees)} callee(s)")
420 else:
421 # Multi-hop: build full forward graph and BFS.
422 forward = build_forward_graph(root, manifest, cache=shared_cache)
423 shared_cache.save()
424 by_depth = transitive_callees(target, forward, effective_depth)
425 flat = [
426 name
427 for d in sorted(by_depth)
428 for name in sorted(by_depth[d])
429 ]
430 if file_filter:
431 flat = [c for c in flat if file_filter in c]
432 total = sum(len(v) for v in by_depth.values())
433
434 if count_only:
435 print(total)
436 return
437 if json_out:
438 print(json.dumps(_DepsSymbolJson(
439 **make_envelope(elapsed),
440 address=target,
441 depth=effective_depth,
442 transitive=True,
443 by_depth={str(d): sorted(by_depth[d]) for d in sorted(by_depth)},
444 )))
445 return
446 depth_label = (
447 "∞" if effective_depth == 0 else str(effective_depth)
448 )
449 print(
450 f"\nTransitive callees of {target}"
451 f" (depth ≤ {depth_label})"
452 )
453 print("─" * 62)
454 if not by_depth:
455 print(" (no callees found)")
456 else:
457 for d in sorted(by_depth):
458 names = sorted(by_depth[d])
459 print(f"\n depth {d}:")
460 for name in names:
461 print(f" {sanitize_display(name)}")
462 print(f"\n{total} callee(s) across {len(by_depth)} depth(s)")
463
464 else:
465 # ── Reverse: who calls this symbol? ───────────────────────────
466 target_name = sym_qualified.split(".")[-1]
467 reverse_graph = build_reverse_graph(root, manifest, cache=shared_cache)
468 shared_cache.save()
469
470 if effective_depth == 1:
471 # Direct callers only.
472 callers = reverse_graph.get(target_name, [])
473 if file_filter:
474 callers = [c for c in callers if file_filter in c]
475 if count_only:
476 print(len(callers))
477 return
478 if json_out:
479 print(json.dumps(_DepsSymbolJson(
480 **make_envelope(elapsed),
481 address=target,
482 target_name=target_name,
483 depth=1,
484 called_by=callers,
485 )))
486 return
487 print(f"\nDirect callers of {sanitize_display(target)}")
488 print("─" * 62)
489 if not callers:
490 print(" (no callers found in snapshot)")
491 else:
492 for addr in callers:
493 print(f" {sanitize_display(addr)}")
494 print(f"\n{len(callers)} caller(s)")
495 else:
496 # Transitive callers via BFS.
497 by_depth = transitive_callers(target_name, reverse_graph, effective_depth)
498 if file_filter:
499 by_depth = {
500 d: [c for c in callers if file_filter in c]
501 for d, callers in by_depth.items()
502 }
503 by_depth = {d: v for d, v in by_depth.items() if v}
504 total = sum(len(v) for v in by_depth.values())
505
506 if count_only:
507 print(total)
508 return
509 if json_out:
510 print(json.dumps(_DepsSymbolJson(
511 **make_envelope(elapsed),
512 address=target,
513 target_name=target_name,
514 depth=effective_depth,
515 transitive=True,
516 by_depth={str(d): sorted(by_depth[d]) for d in sorted(by_depth)},
517 )))
518 return
519 depth_label = (
520 "∞" if effective_depth == 0 else str(effective_depth)
521 )
522 print(
523 f"\nTransitive callers of {target}"
524 f" (depth ≤ {depth_label})"
525 )
526 print(f" (matching bare name: {target_name!r})")
527 print("─" * 62)
528 if not by_depth:
529 print(" (no callers found)")
530 else:
531 for d in sorted(by_depth):
532 addrs = sorted(by_depth[d])
533 print(f"\n depth {d}:")
534 for addr in addrs:
535 print(f" {addr}")
536 print(f"\n{total} caller(s) across {len(by_depth)} depth(s)")
537 return
538
539 # ── File mode: import graph ────────────────────────────────────────────────
540 _validate_file_rel(target)
541
542 commit = resolve_commit_ref(root, branch, ref)
543 if commit is None:
544 print(f"❌ Commit '{ref or 'HEAD'}' not found.", file=sys.stderr)
545 raise SystemExit(ExitCode.USER_ERROR)
546
547 manifest = get_commit_snapshot_manifest(root, commit.commit_id) or {}
548
549 # When working-tree mode and the target is not yet committed, inject a
550 # synthetic manifest entry so symbols_for_snapshot can read it from disk.
551 working_tree = ref is None
552 if working_tree and target not in manifest:
553 candidate = root / target
554 if candidate.is_file():
555 manifest = dict(manifest)
556 manifest[target] = ""
557
558 if not reverse:
559 imports = _file_imports(root, manifest, target, workdir=root if working_tree else None)
560 if file_filter:
561 imports = [i for i in imports if file_filter in i]
562 if count_only:
563 print(len(imports))
564 return
565 if json_out:
566 print(json.dumps(_DepsFileJson(**make_envelope(elapsed), path=target, imports=imports)))
567 return
568 print(f"\nImports declared in {sanitize_display(target)}")
569 print("─" * 62)
570 if not imports:
571 print(" (no imports found)")
572 else:
573 for name in imports:
574 print(f" {sanitize_display(name)}")
575 print(f"\n{len(imports)} import(s)")
576 else:
577 importers = _reverse_imports(root, manifest, target, file_filter)
578 if count_only:
579 print(len(importers))
580 return
581 if json_out:
582 print(json.dumps(_DepsFileJson(**make_envelope(elapsed), path=target, imported_by=importers)))
583 return
584 print(f"\nFiles that import {sanitize_display(target)}")
585 if file_filter:
586 print(f" (filtered to: *{file_filter}*)")
587 print("─" * 62)
588 if not importers:
589 print(" (no files import this module in the committed snapshot)")
590 else:
591 for fp in importers:
592 print(f" {sanitize_display(fp)}")
593 print(f"\n{len(importers)} importer(s) found")
File History 1 commit