"""Comprehensive tests for ``muse code patch``. Coverage -------- Unit _locate_symbol — found, not found, OSError, nested class.method _read_new_body — file path, stdin ("-"), missing file Integration patch basic — replaces symbol, leaves surrounding code intact patch --dry-run — does NOT write to disk, reports correctly patch --json — schema, dry-run flag, line counts patch stdin — reads replacement from stdin (body_arg == "-") patch missing sym — exits 1 with helpful message patch bad address — exits 1 when "::" missing patch syntax error — patched file invalid → rejected before writing patch newline — body without trailing newline gets one added Security path traversal — ../../etc/passwd::foo rejected absolute path — /etc/passwd::foo rejected unicode in body — UTF-8 round-trips correctly empty body — handled gracefully Stress 100 symbols in file — locate_symbol for each under 1 s total repeated patches — 20 sequential patches, all succeed """ from __future__ import annotations import json import pathlib import textwrap import time import pytest from tests.cli_test_helper import CliRunner from muse.cli.commands.patch import _locate_symbol, _read_new_body cli = None runner = CliRunner() # --------------------------------------------------------------------------- # Shared fixtures # --------------------------------------------------------------------------- @pytest.fixture def repo(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch) -> pathlib.Path: """Fresh code-domain repo with a small billing module.""" monkeypatch.chdir(tmp_path) monkeypatch.setenv("MUSE_REPO_ROOT", str(tmp_path)) r = runner.invoke(cli, ["init", "--domain", "code"]) assert r.exit_code == 0, r.output (tmp_path / "billing.py").write_text(textwrap.dedent("""\ class Invoice: def compute_total(self, items: list[int]) -> int: return sum(items) def apply_discount(self, total: float, pct: float) -> float: return total * (1 - pct) def validate_amount(amount: float) -> bool: return amount > 0 def 
format_receipt(amount: float) -> str: return f"Total: {amount:.2f}" """)) r2 = runner.invoke(cli, ["commit", "-m", "initial billing"]) assert r2.exit_code == 0, r2.output return tmp_path # --------------------------------------------------------------------------- # Unit — _locate_symbol # --------------------------------------------------------------------------- class TestLocateSymbol: def test_finds_top_level_function(self, tmp_path: pathlib.Path) -> None: src = tmp_path / "mod.py" src.write_text("def foo(x: int) -> int:\n return x + 1\n") result = _locate_symbol(src, "mod.py::foo") assert result is not None start, end = result assert start == 1 assert end >= 1 def test_returns_none_for_missing_symbol(self, tmp_path: pathlib.Path) -> None: src = tmp_path / "mod.py" src.write_text("def foo(): pass\n") result = _locate_symbol(src, "mod.py::bar") assert result is None def test_returns_none_for_nonexistent_file(self, tmp_path: pathlib.Path) -> None: missing = tmp_path / "nowhere.py" result = _locate_symbol(missing, "nowhere.py::foo") assert result is None def test_finds_method(self, tmp_path: pathlib.Path) -> None: src = tmp_path / "mod.py" src.write_text(textwrap.dedent("""\ class MyClass: def my_method(self) -> None: pass """)) result = _locate_symbol(src, "mod.py::MyClass.my_method") assert result is not None start, end = result assert start >= 2 def test_finds_class(self, tmp_path: pathlib.Path) -> None: src = tmp_path / "mod.py" src.write_text("class Foo:\n x: int = 1\n") result = _locate_symbol(src, "mod.py::Foo") assert result is not None def test_multiline_function(self, tmp_path: pathlib.Path) -> None: src = tmp_path / "mod.py" src.write_text(textwrap.dedent("""\ def big_func( a: int, b: int, ) -> int: result = a + b return result """)) result = _locate_symbol(src, "mod.py::big_func") assert result is not None start, end = result assert end > start # spans multiple lines # --------------------------------------------------------------------------- # Unit — 
# ---------------------------------------------------------------------------
# Unit — _read_new_body
# ---------------------------------------------------------------------------


class TestReadNewBody:
    """``_read_new_body`` loads replacement source from a path or from stdin."""

    def test_reads_from_file(self, tmp_path: pathlib.Path) -> None:
        body_file = tmp_path / "new_body.py"
        body_file.write_text("def foo(): return 42\n")
        result = _read_new_body(str(body_file))
        assert result == "def foo(): return 42\n"

    def test_returns_none_for_missing_file(self) -> None:
        result = _read_new_body("/nonexistent/path/body.py")
        assert result is None

    def test_reads_from_stdin_marker(self, monkeypatch: pytest.MonkeyPatch) -> None:
        import io
        import sys

        # "-" selects stdin as the source of the replacement body.
        monkeypatch.setattr(sys, "stdin", io.StringIO("def bar(): pass\n"))
        result = _read_new_body("-")
        assert result == "def bar(): pass\n"

    def test_reads_empty_file(self, tmp_path: pathlib.Path) -> None:
        body_file = tmp_path / "empty.py"
        body_file.write_text("")
        result = _read_new_body(str(body_file))
        assert result == ""

    def test_reads_unicode_content(self, tmp_path: pathlib.Path) -> None:
        body_file = tmp_path / "unicode.py"
        # Write explicitly as UTF-8: the platform default encoding is not
        # guaranteed to be UTF-8 (e.g. cp1252 on Windows).
        body_file.write_text("def café() -> str:\n    return 'café'\n", encoding="utf-8")
        result = _read_new_body(str(body_file))
        assert result is not None
        assert "café" in result


# ---------------------------------------------------------------------------
# Integration — basic patch
# ---------------------------------------------------------------------------


class TestPatchBasic:
    """End-to-end ``muse code patch`` behaviour on the billing fixture."""

    def test_patch_replaces_symbol(self, repo: pathlib.Path) -> None:
        body = repo / "new_validate.py"
        body.write_text("def validate_amount(amount: float) -> bool:\n    return amount >= 0\n")
        result = runner.invoke(cli, [
            "code", "patch", "--body", str(body), "billing.py::validate_amount",
        ])
        assert result.exit_code == 0, result.output
        # New body is in place and the old one is gone.
        src = (repo / "billing.py").read_text()
        assert "amount >= 0" in src
        assert "amount > 0" not in src

    def test_patch_leaves_other_symbols_intact(self, repo: pathlib.Path) -> None:
        body = repo / "new_validate.py"
        body.write_text("def validate_amount(amount: float) -> bool:\n    return amount >= 0\n")
        result = runner.invoke(cli, [
            "code", "patch", "--body", str(body), "billing.py::validate_amount",
        ])
        # Without this the test passes vacuously when the patch itself fails.
        assert result.exit_code == 0, result.output
        src = (repo / "billing.py").read_text()
        assert "amount >= 0" in src
        # Other definitions must still be present.
        assert "class Invoice" in src
        assert "compute_total" in src
        assert "apply_discount" in src
        assert "format_receipt" in src

    def test_patch_success_message(self, repo: pathlib.Path) -> None:
        body = repo / "new_validate.py"
        body.write_text("def validate_amount(amount: float) -> bool:\n    return amount >= 0\n")
        result = runner.invoke(cli, [
            "code", "patch", "--body", str(body), "billing.py::validate_amount",
        ])
        assert result.exit_code == 0, result.output
        assert "Patched" in result.output

    def test_patch_missing_symbol_exits_one(self, repo: pathlib.Path) -> None:
        body = repo / "new.py"
        body.write_text("def zzz_nonexistent(): pass\n")
        result = runner.invoke(cli, [
            "code", "patch", "--body", str(body), "billing.py::zzz_nonexistent",
        ])
        assert result.exit_code == 1

    def test_patch_bad_address_no_separator_exits_one(self, repo: pathlib.Path) -> None:
        body = repo / "new.py"
        body.write_text("def foo(): pass\n")
        result = runner.invoke(cli, [
            "code", "patch", "--body", str(body), "billing_validate_amount",
        ])
        assert result.exit_code == 1

    def test_patch_file_not_found_exits_one(self, repo: pathlib.Path) -> None:
        body = repo / "new.py"
        body.write_text("def foo(): pass\n")
        result = runner.invoke(cli, [
            "code", "patch", "--body", str(body), "nonexistent_file.py::foo",
        ])
        assert result.exit_code == 1

    def test_patch_body_file_missing_exits_one(self, repo: pathlib.Path) -> None:
        result = runner.invoke(cli, [
            "code", "patch", "--body", str(repo / "does_not_exist.py"),
            "billing.py::validate_amount",
        ])
        assert result.exit_code == 1

    def test_patch_without_trailing_newline_adds_one(self, repo: pathlib.Path) -> None:
        body = repo / "new_receipt.py"
        # The replacement has no trailing newline.  format_receipt is the last
        # definition in billing.py, so after the patch the file's final
        # newline can only come from the command normalising the body.
        # Patching a mid-file symbol would leave the original file ending in
        # place and make this check vacuous.
        body.write_text('def format_receipt(amount: float) -> str:\n    return "receipt"')
        result = runner.invoke(cli, [
            "code", "patch", "--body", str(body), "billing.py::format_receipt",
        ])
        assert result.exit_code == 0, result.output
        src = (repo / "billing.py").read_text()
        assert src.endswith('    return "receipt"\n')


# ---------------------------------------------------------------------------
# Integration — --dry-run
# ---------------------------------------------------------------------------


class TestPatchDryRun:
    """``--dry-run`` reports the change without touching the working tree."""

    def test_dry_run_does_not_write(self, repo: pathlib.Path) -> None:
        original = (repo / "billing.py").read_text()
        body = repo / "new.py"
        body.write_text("def validate_amount(x: float) -> bool:\n    return True\n")
        result = runner.invoke(cli, [
            "code", "patch", "--dry-run", "--body", str(body), "billing.py::validate_amount",
        ])
        assert result.exit_code == 0, result.output
        assert (repo / "billing.py").read_text() == original

    def test_dry_run_reports_intent(self, repo: pathlib.Path) -> None:
        body = repo / "new.py"
        body.write_text("def validate_amount(x: float) -> bool:\n    return True\n")
        result = runner.invoke(cli, [
            "code", "patch", "--dry-run", "--body", str(body), "billing.py::validate_amount",
        ])
        assert result.exit_code == 0, result.output
        output = result.output.lower()
        assert "dry-run" in output or "no changes" in output


# ---------------------------------------------------------------------------
# Integration — --json
# ---------------------------------------------------------------------------


class TestPatchJson:
    """``--json`` emits a machine-readable summary of the patch."""

    def test_json_dry_run_schema(self, repo: pathlib.Path) -> None:
        body = repo / "new.py"
        body.write_text("def validate_amount(x: float) -> bool:\n    return True\n")
        result = runner.invoke(cli, [
            "code", "patch", "--dry-run", "--json", "--body", str(body),
            "billing.py::validate_amount",
        ])
        assert result.exit_code == 0, result.output
        data = json.loads(result.output)
        assert data["address"] == "billing.py::validate_amount"
        assert data["dry_run"] is True
        assert "lines_replaced" in data
        assert "new_lines" in data

    def test_json_live_patch_schema(self, repo: pathlib.Path) -> None:
        body = repo / "new.py"
        body.write_text("def validate_amount(x: float) -> bool:\n    return True\n")
        result = runner.invoke(cli, [
            "code", "patch", "--json", "--body", str(body), "billing.py::validate_amount",
        ])
        assert result.exit_code == 0, result.output
        data = json.loads(result.output)
        assert data["dry_run"] is False
        assert data["file"] == "billing.py"
        # A live (non-dry-run) patch must actually reach the disk.
        assert "return True" in (repo / "billing.py").read_text()

    def test_json_traversal_error(self, repo: pathlib.Path) -> None:
        body = repo / "new.py"
        body.write_text("def foo(): pass\n")
        result = runner.invoke(cli, [
            "code", "patch", "--json", "--body", str(body), "../../etc/passwd::foo",
        ])
        assert result.exit_code == 1


# ---------------------------------------------------------------------------
# Integration — stdin body
# ---------------------------------------------------------------------------


class TestPatchStdin:
    """``--body -`` reads the replacement from stdin."""

    def test_stdin_replaces_symbol(self, repo: pathlib.Path) -> None:
        new_body = "def validate_amount(amount: float) -> bool:\n    return amount != 0\n"
        result = runner.invoke(
            cli,
            ["code", "patch", "--body", "-", "billing.py::validate_amount"],
            input=new_body,
        )
        assert result.exit_code == 0, result.output
        src = (repo / "billing.py").read_text()
        assert "amount != 0" in src


# ---------------------------------------------------------------------------
# Integration — syntax validation
# ---------------------------------------------------------------------------


class TestPatchSyntaxValidation:
    """A replacement that would make the file unparsable is rejected."""

    def test_syntax_error_in_replacement_rejected(self, repo: pathlib.Path) -> None:
        original = (repo / "billing.py").read_text()
        bad = repo / "bad.py"
        bad.write_text("def validate_amount(amount: float) -> bool:\n    return ((\n")
        result = runner.invoke(cli, [
            "code", "patch", "--body", str(bad), "billing.py::validate_amount",
        ])
        assert result.exit_code == 1
        # The original must be untouched.  A substring check on
        # "validate_amount" would pass even after a bad write, because the
        # rejected body contains that name too.
        assert (repo / "billing.py").read_text() == original

    def test_original_unchanged_after_rejection(self, repo: pathlib.Path) -> None:
        original = (repo / "billing.py").read_text()
        bad = repo / "bad.py"
        bad.write_text("def validate_amount(:\n")
        result = runner.invoke(cli, [
            "code", "patch", "--body", str(bad), "billing.py::validate_amount",
        ])
        assert result.exit_code == 1
        assert (repo / "billing.py").read_text() == original


# ---------------------------------------------------------------------------
# Security — path traversal
# ---------------------------------------------------------------------------


class TestPatchSecurity:
    """Addresses outside the repo are refused; text round-trips as UTF-8."""

    def test_dotdot_traversal_rejected(self, repo: pathlib.Path) -> None:
        original = (repo / "billing.py").read_text()
        body = repo / "body.py"
        body.write_text("def foo(): pass\n")
        result = runner.invoke(cli, [
            "code", "patch", "--body", str(body), "../../etc/passwd::foo",
        ])
        assert result.exit_code == 1
        # The rejected command must not have modified repository content.
        assert (repo / "billing.py").read_text() == original

    def test_absolute_path_in_address_rejected(self, repo: pathlib.Path) -> None:
        body = repo / "body.py"
        body.write_text("def foo(): pass\n")
        result = runner.invoke(cli, [
            "code", "patch", "--body", str(body), "/etc/passwd::foo",
        ])
        assert result.exit_code == 1

    def test_unicode_body_round_trips(self, repo: pathlib.Path) -> None:
        body = repo / "unicode.py"
        # Explicit UTF-8 on both write and read so the check does not depend
        # on the platform's default encoding.
        body.write_text(
            "def validate_amount(amount: float) -> bool:\n"
            "    # Vérifie le montant\n"
            "    return amount > 0\n",
            encoding="utf-8",
        )
        result = runner.invoke(cli, [
            "code", "patch", "--body", str(body), "billing.py::validate_amount",
        ])
        assert result.exit_code == 0, result.output
        src = (repo / "billing.py").read_text(encoding="utf-8")
        assert "Vérifie" in src

    def test_requires_repo(
        self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        # No ``repo`` fixture: CWD is a bare temp dir and MUSE_REPO_ROOT is unset.
        monkeypatch.chdir(tmp_path)
        monkeypatch.delenv("MUSE_REPO_ROOT", raising=False)
        body = tmp_path / "body.py"
        body.write_text("def foo(): pass\n")
        result = runner.invoke(cli, [
            "code", "patch", "--body", str(body), "foo.py::foo",
        ])
        assert result.exit_code != 0
# ---------------------------------------------------------------------------
# Stress — repeated patches and large files
# ---------------------------------------------------------------------------


class TestPatchStress:
    """Performance and repeated-patch stability on a 100-function file."""

    @pytest.fixture
    def large_repo(
        self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch
    ) -> pathlib.Path:
        """Repo with a committed 100-function Python file, ``large.py``.

        ``func_NNN(x)`` returns ``x + NNN``; definitions are separated by one
        blank line.
        """
        monkeypatch.chdir(tmp_path)
        monkeypatch.setenv("MUSE_REPO_ROOT", str(tmp_path))
        r_init = runner.invoke(cli, ["init", "--domain", "code"])
        assert r_init.exit_code == 0, r_init.output
        lines: list[str] = []
        for i in range(100):
            lines.append(f"def func_{i:03d}(x: int) -> int:")
            lines.append(f"    return x + {i}")
            lines.append("")
        (tmp_path / "large.py").write_text("\n".join(lines))
        r = runner.invoke(cli, ["commit", "-m", "large file"])
        assert r.exit_code == 0, r.output
        return tmp_path

    def test_locate_100_symbols_under_1s(self, large_repo: pathlib.Path) -> None:
        src = large_repo / "large.py"
        start = time.monotonic()
        for i in range(100):
            address = f"large.py::func_{i:03d}"
            result = _locate_symbol(src, address)
            assert result is not None, f"symbol {address!r} not found"
        elapsed = time.monotonic() - start
        assert elapsed < 1.0, f"locating 100 symbols took {elapsed:.2f}s"

    def test_20_sequential_patches_all_succeed(self, large_repo: pathlib.Path) -> None:
        """Apply 20 patches in sequence; each must succeed and not corrupt the file."""
        import ast

        body_file = large_repo / "body.py"
        for i in range(20):
            body_file.write_text(
                f"def func_{i:03d}(x: int) -> int:\n    return x * {i + 1}\n"
            )
            result = runner.invoke(cli, [
                "code", "patch", "--body", str(body_file), f"large.py::func_{i:03d}",
            ])
            assert result.exit_code == 0, f"patch {i} failed: {result.output}"

        src = (large_repo / "large.py").read_text()
        # The file must still be valid Python and still define all 100 functions.
        tree = ast.parse(src)  # raises SyntaxError if file is corrupt
        defined = {node.name for node in tree.body if isinstance(node, ast.FunctionDef)}
        assert defined == {f"func_{i:03d}" for i in range(100)}
        # Every patch took effect, and every untouched function kept its body.
        for i in range(20):
            assert f"    return x * {i + 1}\n" in src, f"patch {i} not applied"
            assert f"    return x + {i}\n" not in src, f"old body {i} still present"
        for i in range(20, 100):
            assert f"    return x + {i}\n" in src, f"func_{i:03d} was disturbed"

    def test_dry_run_100_times_no_disk_write(self, large_repo: pathlib.Path) -> None:
        """Repeated dry-runs succeed quickly and never modify the file.

        NOTE: despite the name, this performs 10 dry-runs (matching the 5 s
        budget below).
        """
        original = (large_repo / "large.py").read_text()
        body_file = large_repo / "body.py"
        body_file.write_text("def func_000(x: int) -> int:\n    return 0\n")
        start = time.monotonic()
        for run in range(10):
            result = runner.invoke(cli, [
                "code", "patch", "--dry-run", "--body", str(body_file), "large.py::func_000",
            ])
            # A failing dry-run would otherwise pass the "no disk write" check trivially.
            assert result.exit_code == 0, f"dry-run {run} failed: {result.output}"
        elapsed = time.monotonic() - start
        assert elapsed < 5.0, f"10 dry-runs took {elapsed:.2f}s"
        assert (large_repo / "large.py").read_text() == original


class TestRegisterFlags:
    """``register`` wires ``-j``/``--json`` to the ``json_out`` attribute."""

    def _parse(self, *args: str) -> "argparse.Namespace":
        """Parse ``patch dummy::sym --body /dev/null`` plus *args*.

        ``/dev/null`` is only a string argument here; nothing opens it.
        """
        import argparse

        from muse.cli.commands.patch import register

        p = argparse.ArgumentParser()
        subs = p.add_subparsers()
        register(subs)
        return p.parse_args(["patch", "dummy::sym", "--body", "/dev/null", *args])

    def test_json_short_flag(self) -> None:
        args = self._parse("-j")
        assert args.json_out is True

    def test_json_long_flag(self) -> None:
        args = self._parse("--json")
        assert args.json_out is True

    def test_default_no_json(self) -> None:
        args = self._parse()
        assert args.json_out is False