gabriel / muse public
semantic_cherry_pick.py python
427 lines 16.1 KB
Raw
sha256:06dba78c2a78e251b580422dd1fd547f3c8357ff18f7709a860873b2d24dbbbf chore: bump version to 0.2.0rc14 Sonnet 4.6 patch 1 day ago
1 """muse code semantic-cherry-pick — cherry-pick specific symbols, not files.
2
3 Extracts named symbols from a source commit and applies them to the current
4 working tree, replacing only those symbols. All other code is left untouched.
5
6 This is the semantic counterpart to ``git cherry-pick``, which operates at the
7 file-hunk level. ``muse code semantic-cherry-pick`` operates at the symbol level:
8 you name the exact functions, classes, or methods you want to bring forward.
9
10 Multiple symbols can be cherry-picked in a single invocation. They are
11 applied left-to-right. Failures are recorded and processing continues with
12 the remaining symbols — the full result set is always returned.
13
14 Security note: every file path extracted from ADDRESS arguments is validated
15 via ``contain_path()`` before any disk access or directory creation. Paths
16 that escape the repo root (e.g. ``../../etc/shadow::foo``) are rejected and
17 the symbol is recorded as ``not_found`` with an appropriate detail message.
18
19 Usage::
20
21 muse code semantic-cherry-pick "src/billing.py::compute_total" --from abc12345
22 muse code semantic-cherry-pick \\
23 "src/auth.py::validate_token" \\
24 "src/auth.py::refresh_token" \\
25 --from feature-branch
26 muse code semantic-cherry-pick "src/core.py::hash_content" --from HEAD~5 --dry-run
27 muse code semantic-cherry-pick "src/billing.py::Invoice.pay" --from v1.0 --json
28
29 Output::
30
31 Semantic cherry-pick from commit abc12345
32 ──────────────────────────────────────────────────────────────
33
34 ✅ src/auth.py::validate_token applied (lines 12–34 → 29 lines)
35 ✅ src/auth.py::refresh_token applied (lines 36–58 → 17 lines)
36 ❌ src/billing.py::compute_total not_found (symbol not found in source commit)
37
38 2 applied, 1 failed
39
40 Flags:
41
42 ``--from REF``
43 Required. Commit or branch to cherry-pick from.
44
45 ``--dry-run``
46 Print what would change without writing anything. Each result still
47 includes ``diff_lines`` and ``verified`` so agents can gate on output
48 quality before committing to a write.
49
50 ``--json``
51 Emit per-symbol results as JSON (includes ``diff_lines`` and ``verified``).
52 """
53
54 import argparse
55 import difflib
56 import json
57 import logging
58 import pathlib
59 from typing import Literal, TypedDict
60
61 from muse.core.envelope import EnvelopeJson, make_envelope
62 from muse.core.errors import ExitCode
63 from muse.core.object_store import read_object
64 from muse.core.repo import require_repo
65 from muse.core.refs import read_current_branch
66 from muse.core.commits import resolve_commit_ref
67 from muse.core.snapshots import get_commit_snapshot_manifest
68 from muse.core.timing import start_timer
69 from muse.core.validation import contain_path, sanitize_display
70 from muse.plugins.code.ast_parser import parse_symbols
71 from muse.core.types import Manifest
72
73 type _BlobCache = dict[str, bytes]
74 logger = logging.getLogger(__name__)
75
76 ApplyStatus = Literal["applied", "not_found", "file_missing", "parse_error", "already_current"]
77
78 class _PickResultDict(TypedDict):
79 """JSON schema for one cherry-pick result."""
80
81 address: str
82 status: str
83 detail: str
84 old_lines: int
85 new_lines: int
86 diff_lines: list[str]
87 verified: bool
88
89 class _CherryPickResultJson(EnvelopeJson):
90 """Top-level JSON envelope emitted by ``muse code semantic-cherry-pick --json``."""
91
92 branch: str
93 from_commit: str
94 dry_run: bool
95 results: list[_PickResultDict]
96 applied: int
97 already_current: int
98 failed: int
99 unverified: list[str]
100
101 class _PickResult:
102 """Result for one cherry-picked symbol."""
103
104 def __init__(
105 self,
106 address: str,
107 status: ApplyStatus,
108 detail: str = "",
109 old_lines: int = 0,
110 new_lines: int = 0,
111 diff_lines: list[str] | None = None,
112 verified: bool = True,
113 ) -> None:
114 self.address = address
115 self.status = status
116 self.detail = detail
117 self.old_lines = old_lines
118 self.new_lines = new_lines
119 self.diff_lines: list[str] = diff_lines if diff_lines is not None else []
120 self.verified = verified
121
122 def to_dict(self) -> _PickResultDict:
123 return {
124 "address": self.address,
125 "status": self.status,
126 "detail": self.detail,
127 "old_lines": self.old_lines,
128 "new_lines": self.new_lines,
129 "diff_lines": self.diff_lines,
130 "verified": self.verified,
131 }
132
133 def _verify_symbol(working_file: pathlib.Path, file_rel: str, address: str) -> bool:
134 """Re-parse *working_file* and confirm *address* is locatable after a write.
135
136 Returns ``False`` if the write produced syntactically invalid output or
137 if the symbol is no longer addressable — a useful signal to agents that
138 the splice needs human review.
139 """
140 try:
141 tree = parse_symbols(working_file.read_bytes(), file_rel)
142 return tree.get(address) is not None
143 except Exception:
144 return False
145
146 def _apply_symbol(
147 root: pathlib.Path,
148 address: str,
149 src_manifest: Manifest,
150 dry_run: bool,
151 src_cache: _BlobCache,
152 ) -> _PickResult:
153 """Apply one symbol from *src_manifest* to the working tree.
154
155 *src_cache* is a content-addressed blob cache keyed by object ID.
156 Passing the same dict across calls for a single invocation ensures that
157 multiple addresses targeting the same source file only fetch the blob once.
158 """
159 if "::" not in address:
160 return _PickResult(address, "not_found", "address has no '::' separator")
161
162 file_rel = address.split("::")[0]
163
164 # Validate the file path stays inside the repo root before any I/O.
165 try:
166 working_file = contain_path(root, file_rel)
167 except ValueError as exc:
168 return _PickResult(address, "not_found", str(exc))
169
170 # Read historical blob — use src_cache so repeated addresses targeting
171 # the same source file pay the fetch + decode cost only once.
172 obj_id = src_manifest.get(file_rel)
173 if obj_id is None:
174 return _PickResult(address, "file_missing", f"'{file_rel}' not in source snapshot")
175
176 if obj_id not in src_cache:
177 raw = read_object(root, obj_id)
178 if raw is None:
179 return _PickResult(address, "file_missing", f"blob {obj_id} missing from object store")
180 src_cache[obj_id] = raw
181 src_raw = src_cache[obj_id]
182
183 try:
184 src_tree = parse_symbols(src_raw, file_rel)
185 except Exception as exc:
186 return _PickResult(address, "parse_error", str(exc))
187
188 src_rec = src_tree.get(address)
189 if src_rec is None:
190 return _PickResult(address, "not_found", "symbol not found in source commit")
191
192 src_text = src_raw.decode("utf-8", errors="replace")
193 src_lines_all = src_text.splitlines(keepends=True)
194 src_symbol_lines = src_lines_all[src_rec["lineno"] - 1 : src_rec["end_lineno"]]
195
196 # Read current working tree.
197 if not working_file.exists():
198 diff_lines = list(difflib.unified_diff(
199 [], src_symbol_lines, fromfile="current", tofile="historical", lineterm="",
200 ))
201 if not dry_run:
202 working_file.parent.mkdir(parents=True, exist_ok=True)
203 working_file.write_text("".join(src_symbol_lines), encoding="utf-8")
204 verified = _verify_symbol(working_file, file_rel, address)
205 else:
206 # Dry-run: simulate verification in memory.
207 try:
208 sim_tree = parse_symbols("".join(src_symbol_lines).encode(), file_rel)
209 verified = sim_tree.get(address) is not None
210 except Exception:
211 verified = False
212 return _PickResult(address, "applied", "created file", 0, len(src_symbol_lines), diff_lines, verified)
213
214 current_text = working_file.read_text(encoding="utf-8", errors="replace")
215 current_lines = current_text.splitlines(keepends=True)
216 current_raw = current_text.encode("utf-8")
217
218 try:
219 current_tree = parse_symbols(current_raw, file_rel)
220 except Exception as exc:
221 return _PickResult(address, "parse_error", f"current file: {exc}")
222
223 current_rec = current_tree.get(address)
224
225 if current_rec is not None:
226 if current_rec["content_id"] == src_rec["content_id"]:
227 return _PickResult(address, "already_current", "content identical", 0, 0)
228 old_start = current_rec["lineno"] - 1
229 old_end = current_rec["end_lineno"]
230 current_symbol_lines = current_lines[old_start:old_end]
231 new_lines = current_lines[:old_start] + src_symbol_lines + current_lines[old_end:]
232 detail = f"lines {current_rec['lineno']}–{current_rec['end_lineno']} → {len(src_symbol_lines)} lines"
233 else:
234 # Symbol not in current tree — append at end.
235 current_symbol_lines = []
236 new_lines = current_lines + ["\n"] + src_symbol_lines
237 detail = "appended at end (symbol not found in current tree)"
238
239 diff_lines = list(difflib.unified_diff(
240 current_symbol_lines, src_symbol_lines,
241 fromfile="current", tofile="historical", lineterm="",
242 ))
243
244 if not dry_run:
245 working_file.write_text("".join(new_lines), encoding="utf-8")
246 verified = _verify_symbol(working_file, file_rel, address)
247 else:
248 # Dry-run: simulate verification in memory.
249 try:
250 sim_tree = parse_symbols("".join(new_lines).encode(), file_rel)
251 verified = sim_tree.get(address) is not None
252 except Exception:
253 verified = False
254
255 return _PickResult(address, "applied", detail, len(current_symbol_lines), len(src_symbol_lines), diff_lines, verified)
256
257 def register(subparsers: "argparse._SubParsersAction[argparse.ArgumentParser]") -> None:
258 """Register the ``semantic-cherry-pick`` subcommand and all its arguments.
259
260 Arguments registered
261 --------------------
262 addresses One or more ``file.py::Symbol`` addresses to cherry-pick (positional).
263 --from REF Required. Commit or branch to cherry-pick symbols from.
264 --dry-run / -n Preview diffs without writing anything to disk.
265 --json / -j Emit per-symbol results as a JSON envelope (includes
266 ``diff_lines``, ``verified``, ``exit_code``, ``duration_ms``).
267 """
268 parser = subparsers.add_parser(
269 "semantic-cherry-pick",
270 help="Cherry-pick specific named symbols from a historical commit.",
271 description=__doc__,
272 formatter_class=argparse.RawDescriptionHelpFormatter,
273 )
274 parser.add_argument(
275 "addresses",
276 nargs="+",
277 metavar="ADDRESS",
278 help='Symbol addresses to cherry-pick, e.g. "src/auth.py::validate_token".',
279 )
280 parser.add_argument(
281 "--from",
282 dest="from_ref",
283 required=True,
284 metavar="REF",
285 help="Commit or branch to cherry-pick symbols from (required).",
286 )
287 parser.add_argument(
288 "--dry-run", "-n",
289 action="store_true",
290 help="Print what would change without writing anything.",
291 )
292 parser.add_argument(
293 "--json", "-j",
294 dest="json_out",
295 action="store_true",
296 help="Emit per-symbol results as JSON (includes diff_lines and verified).",
297 )
298 parser.set_defaults(func=run)
299
300 def run(args: argparse.Namespace) -> None:
301 """Cherry-pick specific named symbols from a historical commit.
302
303 Extracts each listed symbol from the source commit and splices it into
304 the current working-tree file at the symbol's current location. Only
305 the target symbol's lines change; all surrounding code is preserved.
306
307 If the symbol does not exist in the current working tree, the historical
308 version is appended to the end of the file.
309
310 Failures are recorded and processing continues with remaining symbols;
311 the full result set is always returned.
312
313 ``--dry-run`` (``-n``) shows what would change without writing anything.
314 ``--json`` (``-j``) emits per-symbol results for machine consumption.
315
316 Agent quickstart::
317
318 muse code semantic-cherry-pick "src/billing.py::compute_total" --from abc12345 --json
319 muse code semantic-cherry-pick "src/auth.py::validate_token" --from feature-branch --json
320 muse code semantic-cherry-pick "src/core.py::hash_content" --from HEAD~5 --dry-run --json
321
322 JSON fields::
323
324 branch str Current branch name
325 from_commit str Short ID of the source commit
326 dry_run bool True when no files were written
327 results list Per-symbol: address, status, detail, old_lines, new_lines, diff_lines, verified
328 applied int Count of successfully applied symbols
329 already_current int Count of symbols already matching the source
330 failed int Count of symbols that could not be applied
331 unverified list[str] Addresses where post-write re-parse failed
332
333 Exit codes::
334
335 0 Success (even if some symbols failed — check ``failed`` field).
336 1 User error (bad address, ref not found).
337 3 Internal error (repo ID unreadable, snapshot missing).
338 """
339 elapsed = start_timer()
340 addresses: list[str] = args.addresses
341 from_ref: str = args.from_ref
342 dry_run: bool = args.dry_run
343 json_out: bool = args.json_out
344
345 root = require_repo()
346
347 try:
348 branch = read_current_branch(root)
349 except Exception as exc:
350 logger.error("❌ Could not read current branch: %s", exc)
351 raise SystemExit(ExitCode.INTERNAL_ERROR) from exc
352
353 if not addresses:
354 logger.error("❌ At least one ADDRESS is required.")
355 raise SystemExit(ExitCode.USER_ERROR)
356
357 from_commit = resolve_commit_ref(root, branch, from_ref)
358 if from_commit is None:
359 logger.error("❌ --from ref '%s' not found.", from_ref)
360 raise SystemExit(ExitCode.USER_ERROR)
361
362 src_manifest = get_commit_snapshot_manifest(root, from_commit.commit_id)
363 if src_manifest is None:
364 logger.error(
365 "❌ Snapshot for commit %s is missing from the object store.",
366 from_commit.commit_id,
367 )
368 raise SystemExit(ExitCode.INTERNAL_ERROR)
369
370 # src_cache avoids re-fetching the same blob when multiple addresses
371 # target the same source file within a single invocation.
372 src_cache: _BlobCache = {}
373 results: list[_PickResult] = []
374 for address in addresses:
375 result = _apply_symbol(root, address, src_manifest, dry_run, src_cache)
376 results.append(result)
377
378 n_applied = sum(1 for r in results if r.status == "applied")
379 n_already = sum(1 for r in results if r.status == "already_current")
380 n_failed = sum(1 for r in results if r.status not in ("applied", "already_current"))
381 unverified = [r.address for r in results if r.status == "applied" and not r.verified]
382
383 if json_out:
384 print(json.dumps(_CherryPickResultJson(
385 **make_envelope(elapsed),
386 branch=branch,
387 from_commit=from_commit.commit_id,
388 dry_run=dry_run,
389 results=[r.to_dict() for r in results],
390 applied=n_applied,
391 already_current=n_already,
392 failed=n_failed,
393 unverified=unverified,
394 )))
395 return
396
397 action = "Dry-run" if dry_run else "Semantic cherry-pick"
398 print(f"\n{action} from commit {from_commit.commit_id}")
399 print("─" * 62)
400
401 max_addr = max(len(r.address) for r in results)
402
403 for r in results:
404 if r.status == "applied":
405 icon = "✅"
406 label = f"applied ({r.detail})"
407 if not r.verified:
408 label += " ⚠️ unverified — re-parse failed after write"
409 elif r.status == "already_current":
410 icon = "ℹ️ "
411 label = "already current — no change needed"
412 else:
413 icon = "❌"
414 label = f"{r.status} ({r.detail})"
415 print(f"\n {icon} {sanitize_display(r.address):<{max_addr}} {label}")
416
417 print(f"\n {n_applied} applied, {n_failed} failed")
418 if n_already:
419 print(f" {n_already} already current")
420 if dry_run:
421 print(" (dry run — no files were written)")
422 if unverified:
423 logger.warning(
424 "⚠️ %d symbol(s) could not be verified after write: %s",
425 len(unverified),
426 ", ".join(unverified),
427 )
File History 1 commit
sha256:06dba78c2a78e251b580422dd1fd547f3c8357ff18f7709a860873b2d24dbbbf chore: bump version to 0.2.0rc14 Sonnet 4.6 patch 1 day ago