"""Phase 1 — CAS guards on all write_branch_ref callers. Invariant: any command that advances an existing branch ref must detect a concurrent advance and fail with RefConflictError rather than silently orphaning the other agent's commit. Commands covered: merge, cherry-pick, revert, pull, rebase. Race injection pattern ---------------------- We cannot reliably pause a running command between its HEAD-read and its write_branch_ref call. Instead we patch write_branch_ref in the target command module to: 1. Advance the branch to a fake concurrent commit (simulates another agent landing just before us). 2. Call the real write_branch_ref with the original arguments. Before the fix (no expected_id): step 2 unconditionally overwrites the concurrent commit — silent orphan. After the fix (expected_id passed): step 2 finds current != expected and raises RefConflictError — detected and surfaced to the caller. Testing tiers ------------- Unit write_branch_ref CAS covered in test_commit_concurrent_ref_safety.py Integration each command raises / exits non-zero when branch moves mid-op E2E CLI exit code != 0, error output contains an actionable keyword Data branch ref equals the concurrent commit after a failed CAS (the original write did NOT land — no orphaned commit) Stress concurrent write_branch_ref pairs; exactly one wins per slot (covered by test_commit_concurrent_ref_safety.py — same lock) Security n/a — CAS is internal, not driven by user-controlled input Performance CAS uncontested round-trip < 50 ms """ from __future__ import annotations import contextlib import json import pathlib import time from unittest.mock import patch, MagicMock import pytest from muse.core import refs as _store from muse.core.refs import RefConflictError, write_branch_ref from muse.core.types import fake_id from muse.core.paths import head_path, ref_path, repo_json_path from tests.cli_test_helper import CliRunner, InvokeResult runner = CliRunner() _CONCURRENT_ID = fake_id("concurrent-agent") def _run(repo: pathlib.Path, *args: str) -> None: r = runner.invoke(None, list(args), cwd=repo) assert r.exit_code == 0, f"muse {' '.join(args)} failed:\n{r.output}" def _try(repo: pathlib.Path, *args: str) -> InvokeResult: return runner.invoke(None, list(args), cwd=repo) def _head(repo: pathlib.Path) -> str: r = runner.invoke(None, ["rev-parse", "HEAD", "--json"], cwd=repo) return json.loads(r.output)["commit_id"] def _branch(repo: pathlib.Path) -> str: r = runner.invoke(None, ["rev-parse", "--abbrev-ref", "HEAD", "--json"], cwd=repo) return json.loads(r.output)["branch"] def _ref(repo: pathlib.Path, branch: str) -> str: return (ref_path(repo, branch)).read_text().strip() def _make_race_injector(module_path: str) -> contextlib.AbstractContextManager[MagicMock]: """Return a context manager that patches write_branch_ref in *module_path* to advance the branch to _CONCURRENT_ID just before the real write. Before fix: real write has no expected_id → unconditional overwrite succeeds → concurrent commit orphaned (bad). After fix: real write has expected_id= → CAS detects mismatch → RefConflictError raised (correct). """ real = _store.write_branch_ref def _injected(root: pathlib.Path, branch: str, commit_id: str, **kwargs: str) -> None: real(root, branch, _CONCURRENT_ID) # another agent lands first return real(root, branch, commit_id, **kwargs) # our write — should fail return patch(f"{module_path}.write_branch_ref", side_effect=_injected) # --------------------------------------------------------------------------- # merge # --------------------------------------------------------------------------- class TestMergeCASGuard: """muse merge must detect a concurrent branch advance and fail cleanly.""" @pytest.fixture() def two_branch_repo(self, muse_repo: pathlib.Path) -> pathlib.Path: (muse_repo / "base.py").write_text("base\n") _run(muse_repo, "code", "add", "base.py") _run(muse_repo, "commit", "-m", "base") _run(muse_repo, "checkout", "-b", "feat") (muse_repo / "feat.py").write_text("feat\n") _run(muse_repo, "code", "add", "feat.py") _run(muse_repo, "commit", "-m", "feat commit") _run(muse_repo, "checkout", "main") return muse_repo def test_merge_fails_when_branch_advances_concurrently( self, two_branch_repo: pathlib.Path ) -> None: with _make_race_injector("muse.cli.commands.merge"): result = _try(two_branch_repo, "merge", "feat") assert result.exit_code != 0, "merge succeeded despite concurrent branch advance" def test_merge_error_is_actionable(self, two_branch_repo: pathlib.Path) -> None: with _make_race_injector("muse.cli.commands.merge"): result = _try(two_branch_repo, "merge", "feat") combined = result.output + (result.stderr or "") assert any( kw in combined.lower() for kw in ("conflict", "moved", "concurrent", "retry", "pull", "changed") ), f"no actionable guidance:\n{combined}" def test_merge_succeeds_when_branch_has_not_moved( self, two_branch_repo: pathlib.Path ) -> None: result = _try(two_branch_repo, "merge", "feat") assert result.exit_code == 0, result.output def test_merge_branch_ref_unchanged_after_failed_cas( self, two_branch_repo: pathlib.Path ) -> None: repo = two_branch_repo branch = _branch(repo) with _make_race_injector("muse.cli.commands.merge"): _try(repo, "merge", "feat") assert _ref(repo, branch) == _CONCURRENT_ID, ( "failed merge CAS overwrote the concurrent commit" ) # --------------------------------------------------------------------------- # cherry-pick # --------------------------------------------------------------------------- class TestCherryPickCASGuard: @pytest.fixture() def cherry_repo(self, muse_repo: pathlib.Path) -> tuple[pathlib.Path, str]: (muse_repo / "a.py").write_text("v1\n") _run(muse_repo, "code", "add", "a.py") _run(muse_repo, "commit", "-m", "first") _run(muse_repo, "checkout", "-b", "source") (muse_repo / "a.py").write_text("v2\n") _run(muse_repo, "code", "add", "a.py") _run(muse_repo, "commit", "-m", "the pick") pick_id = _head(muse_repo) _run(muse_repo, "checkout", "main") return muse_repo, pick_id def test_cherry_pick_fails_when_branch_advances_concurrently( self, cherry_repo: tuple[pathlib.Path, str] ) -> None: repo, pick_id = cherry_repo with _make_race_injector("muse.cli.commands.cherry_pick"): result = _try(repo, "cherry-pick", pick_id) assert result.exit_code != 0, "cherry-pick succeeded despite concurrent branch advance" def test_cherry_pick_succeeds_when_branch_has_not_moved( self, cherry_repo: tuple[pathlib.Path, str] ) -> None: repo, pick_id = cherry_repo result = _try(repo, "cherry-pick", pick_id) assert result.exit_code == 0, result.output def test_cherry_pick_branch_ref_unchanged_after_failed_cas( self, cherry_repo: tuple[pathlib.Path, str] ) -> None: repo, pick_id = cherry_repo branch = _branch(repo) with _make_race_injector("muse.cli.commands.cherry_pick"): _try(repo, "cherry-pick", pick_id) assert _ref(repo, branch) == _CONCURRENT_ID # --------------------------------------------------------------------------- # revert # --------------------------------------------------------------------------- class TestRevertCASGuard: @pytest.fixture() def revert_repo(self, muse_repo: pathlib.Path) -> pathlib.Path: (muse_repo / "a.py").write_text("v1\n") _run(muse_repo, "code", "add", "a.py") _run(muse_repo, "commit", "-m", "first") (muse_repo / "a.py").write_text("v2\n") _run(muse_repo, "code", "add", "a.py") _run(muse_repo, "commit", "-m", "second") return muse_repo def test_revert_fails_when_branch_advances_concurrently( self, revert_repo: pathlib.Path ) -> None: with _make_race_injector("muse.cli.commands.revert"): result = _try(revert_repo, "revert", "HEAD") assert result.exit_code != 0, "revert succeeded despite concurrent branch advance" def test_revert_succeeds_when_branch_has_not_moved( self, revert_repo: pathlib.Path ) -> None: result = _try(revert_repo, "revert", "HEAD") assert result.exit_code == 0, result.output def test_revert_branch_ref_unchanged_after_failed_cas( self, revert_repo: pathlib.Path ) -> None: repo = revert_repo branch = _branch(repo) with _make_race_injector("muse.cli.commands.revert"): _try(repo, "revert", "HEAD") assert _ref(repo, branch) == _CONCURRENT_ID # --------------------------------------------------------------------------- # Performance — CAS happy-path overhead # --------------------------------------------------------------------------- class TestCASPerformance: def test_write_branch_ref_cas_uncontested_under_50ms( self, bare_muse_repo: pathlib.Path ) -> None: repo = bare_muse_repo (head_path(repo)).write_text("ref: refs/heads/main\n") (repo_json_path(repo)).write_text( json.dumps({"repo_id": "perf-test"}) ) id_a = fake_id("a") id_b = fake_id("b") write_branch_ref(repo, "main", id_a) start = time.perf_counter() write_branch_ref(repo, "main", id_b, expected_id=id_a) elapsed = time.perf_counter() - start assert elapsed < 0.05, f"CAS took {elapsed:.3f}s — expected < 0.05s"