"""Concurrent ref-update safety for muse commit. Before this fix, write_branch_ref was an unconditional overwrite. Two agents committing to the same branch simultaneously would silently orphan one commit — the second write_branch_ref call would overwrite the first with no detection, dropping a commit from the branch history. Fix: write_branch_ref gains an optional expected_id parameter. When provided, the write only proceeds if the current ref value matches. If it doesn't match (another agent advanced the ref first), it raises RefConflictError with a clear message. commit passes parent_id as expected_id so any concurrent advance of the branch ref between the parent_id read and the ref write is caught and reported as a retryable error rather than silently overwriting. Test structure: Tier 1 — write_branch_ref CAS unit tests Tier 2 — commit detects branch-moved condition Tier 3 — concurrent commit threads: only one succeeds per slot """ from __future__ import annotations import pathlib import threading from typing import TYPE_CHECKING import pytest from muse.core.refs import write_branch_ref from muse.core.types import fake_id from muse.core.paths import heads_dir, muse_dir from tests.cli_test_helper import CliRunner, InvokeResult if TYPE_CHECKING: pass runner = CliRunner() def _run(repo: pathlib.Path, *args: str) -> None: r = runner.invoke(None, list(args), cwd=repo) assert r.exit_code == 0, f"muse {' '.join(args)} failed:\n{r.output}" def _try(repo: pathlib.Path, *args: str) -> InvokeResult: return runner.invoke(None, list(args), cwd=repo) def _make_bare_refs(tmp_path: pathlib.Path) -> pathlib.Path: """Minimal .muse/refs/heads/ tree — enough for write_branch_ref.""" import json muse = muse_dir(tmp_path) for d in ("objects", "commits", "snapshots", "refs/heads"): (muse / d).mkdir(parents=True, exist_ok=True) (muse / "repo.json").write_text(json.dumps({"repo_id": "test-repo"})) (muse / "HEAD").write_text("ref: refs/heads/main\n") return tmp_path # 
# ---------------------------------------------------------------------------
# Shared helpers for the tiers below
# ---------------------------------------------------------------------------

# Upper bound (seconds) for Tier 3 threads to meet at the barrier or to be
# joined; a broken CAS implementation then fails the test instead of hanging it.
_THREAD_TIMEOUT = 10.0


def _read_ref(repo: pathlib.Path, branch: str) -> str:
    """Return the commit id stored in ``refs/heads/<branch>`` of *repo*."""
    return (heads_dir(repo) / branch).read_text(encoding="utf-8").strip()


def _head_commit_id(repo: pathlib.Path) -> str:
    """Return the commit id HEAD resolves to, via ``muse rev-parse HEAD --json``."""
    import json

    r = _try(repo, "rev-parse", "HEAD", "--json")
    assert r.exit_code == 0, f"muse rev-parse HEAD failed:\n{r.output}"
    return json.loads(r.output)["commit_id"]


def _current_branch(repo: pathlib.Path) -> str:
    """Return the branch HEAD points at, via ``muse rev-parse --abbrev-ref HEAD --json``."""
    import json

    r = _try(repo, "rev-parse", "--abbrev-ref", "HEAD", "--json")
    assert r.exit_code == 0, f"muse rev-parse --abbrev-ref HEAD failed:\n{r.output}"
    return json.loads(r.output)["branch"]


def _stage_concurrent_advance(
    repo: pathlib.Path, filename: str, concurrent_label: str
) -> tuple[str, str, str]:
    """Set up the Tier 2 race: a staged change plus a branch ref moved behind commit's back.

    Commits *filename* with content ``v1``, stages a ``v2`` edit, then advances
    the current branch ref to ``fake_id(concurrent_label)`` to simulate another
    agent committing first.

    Returns ``(stale_parent, branch, concurrent_id)``. ``stale_parent`` is HEAD
    as it was after the first commit. ``concurrent_id`` is what the branch ref
    now holds on disk.
    """
    path = repo / filename
    path.write_text("v1\n")
    _run(repo, "code", "add", filename)
    _run(repo, "commit", "-m", "first")

    stale_parent = _head_commit_id(repo)
    branch = _current_branch(repo)

    path.write_text("v2\n")
    _run(repo, "code", "add", filename)

    concurrent_id = fake_id(concurrent_label)
    write_branch_ref(repo, branch, concurrent_id)
    return stale_parent, branch, concurrent_id


# ---------------------------------------------------------------------------
# Tier 1 — write_branch_ref CAS unit tests
# ---------------------------------------------------------------------------


class TestWriteBranchRefCAS:
    """Unit tests for the ``expected_id`` compare-and-swap guard on ``write_branch_ref``."""

    def test_unconditional_write_succeeds(self, tmp_path: pathlib.Path) -> None:
        repo = _make_bare_refs(tmp_path)
        id_a = fake_id("a")
        write_branch_ref(repo, "main", id_a)
        assert _read_ref(repo, "main") == id_a

    def test_cas_succeeds_when_expected_matches(self, tmp_path: pathlib.Path) -> None:
        repo = _make_bare_refs(tmp_path)
        id_a = fake_id("a")
        id_b = fake_id("b")
        write_branch_ref(repo, "main", id_a)
        write_branch_ref(repo, "main", id_b, expected_id=id_a)
        assert _read_ref(repo, "main") == id_b

    def test_cas_fails_when_expected_does_not_match(self, tmp_path: pathlib.Path) -> None:
        from muse.core.refs import RefConflictError

        repo = _make_bare_refs(tmp_path)
        id_a = fake_id("a")
        id_b = fake_id("b")
        id_c = fake_id("c")
        write_branch_ref(repo, "main", id_a)
        with pytest.raises(RefConflictError):
            write_branch_ref(repo, "main", id_c, expected_id=id_b)
        # A failed CAS must leave the ref untouched.
        assert _read_ref(repo, "main") == id_a

    def test_cas_fails_when_ref_is_missing_and_expected_is_set(
        self, tmp_path: pathlib.Path
    ) -> None:
        from muse.core.refs import RefConflictError

        repo = _make_bare_refs(tmp_path)
        id_a = fake_id("a")
        id_b = fake_id("b")
        # No ref written yet: the current value is "missing", not id_a, so CAS must fail.
        with pytest.raises(RefConflictError):
            write_branch_ref(repo, "main", id_b, expected_id=id_a)
        # The failed CAS must not have created the ref either.
        assert not (heads_dir(repo) / "main").exists()

    def test_cas_succeeds_on_first_write_when_expected_is_none(
        self, tmp_path: pathlib.Path
    ) -> None:
        repo = _make_bare_refs(tmp_path)
        id_a = fake_id("a")
        # expected_id=None is meant as "no prior value", which is valid for the first
        # commit on a branch.
        # NOTE(review): if expected_id also defaults to None, this call cannot be told
        # apart from an unconditional write. It would then not prove "must be absent"
        # semantics. TODO confirm write_branch_ref's signature (a sentinel default would
        # make the distinction testable).
        write_branch_ref(repo, "main", id_a, expected_id=None)
        assert _read_ref(repo, "main") == id_a

    def test_error_contains_branch_name_and_ids(self, tmp_path: pathlib.Path) -> None:
        from muse.core.refs import RefConflictError

        repo = _make_bare_refs(tmp_path)
        id_a = fake_id("a")
        id_b = fake_id("b")
        id_c = fake_id("c")
        write_branch_ref(repo, "dev", id_a)
        with pytest.raises(RefConflictError) as excinfo:
            write_branch_ref(repo, "dev", id_c, expected_id=id_b)

        message = str(excinfo.value)
        assert "dev" in message, f"branch name missing from error: {message!r}"
        # Messages may abbreviate ids, so only require a short prefix of each:
        # the id the caller expected (id_b) and the id actually on disk (id_a).
        assert id_b[:7] in message, f"expected id missing from error: {message!r}"
        assert id_a[:7] in message, f"current id missing from error: {message!r}"
        assert _read_ref(repo, "dev") == id_a


# ---------------------------------------------------------------------------
# Tier 2 — commit detects branch-moved condition
#
# The race window is between get_head_commit_id (parent read) and
# write_branch_ref (ref advance). We simulate it by patching
# get_head_commit_id to return a stale ID while the branch ref has already
# been advanced to a different commit by a "concurrent" agent.
# ---------------------------------------------------------------------------


class TestCommitDetectsBranchMoved:
    """``muse commit`` must refuse to overwrite a branch ref that moved after it read HEAD."""

    def test_commit_fails_when_branch_advanced_concurrently(
        self, muse_repo: pathlib.Path
    ) -> None:
        """If the branch ref advances between parent-read and ref-write, commit
        must fail with a clear error rather than silently orphaning a commit."""
        from unittest.mock import patch

        stale_parent, branch, other_id = _stage_concurrent_advance(
            muse_repo, "a.py", "other"
        )

        # commit sees the STALE parent while the ref on disk is already other_id.
        # This reproduces the race: parent_id read → branch advances → write fails.
        with patch(
            "muse.cli.commands.commit.get_head_commit_id",
            return_value=stale_parent,
        ):
            result = _try(muse_repo, "commit", "-m", "ours — should fail")

        assert result.exit_code != 0, (
            "commit succeeded despite branch moving concurrently — "
            "this would have orphaned the concurrent commit"
        )
        # A non-zero exit alone could come from an unrelated failure. The actual
        # invariant is that the concurrent agent's ref was not overwritten.
        assert _read_ref(muse_repo, branch) == other_id, (
            "branch ref was overwritten — the concurrent commit was orphaned"
        )

    def test_commit_succeeds_when_branch_has_not_moved(
        self, muse_repo: pathlib.Path
    ) -> None:
        """Normal single-agent commit still works after the CAS fix."""
        f = muse_repo / "b.py"
        f.write_text("hello\n")
        _run(muse_repo, "code", "add", "b.py")
        r = _try(muse_repo, "commit", "-m", "normal commit")
        assert r.exit_code == 0, r.output

    def test_commit_error_message_is_actionable(
        self, muse_repo: pathlib.Path
    ) -> None:
        """The error surfaced to the user must tell them to pull/retry."""
        from unittest.mock import patch

        stale_parent, _branch, _concurrent_id = _stage_concurrent_advance(
            muse_repo, "c.py", "concurrent"
        )

        with patch(
            "muse.cli.commands.commit.get_head_commit_id",
            return_value=stale_parent,
        ):
            result = _try(muse_repo, "commit", "-m", "should fail")

        combined = result.output + (result.stderr or "")
        # Check failure first: a *successful* commit summary (e.g. "1 file changed")
        # would otherwise satisfy the keyword check below.
        assert result.exit_code != 0, f"commit unexpectedly succeeded:\n{combined}"
        assert any(
            kw in combined.lower()
            for kw in ("conflict", "moved", "concurrent", "retry", "pull", "changed")
        ), f"error output gave no actionable guidance:\n{combined}"


# ---------------------------------------------------------------------------
# Tier 3 — concurrent CAS writers: only one succeeds per slot
#
# Two threads both hold parent_id = A and both call write_branch_ref with
# expected_id=A at (nearly) the same moment. Whichever write lands first wins;
# the other must observe current != A and raise RefConflictError. (This tier
# exercises write_branch_ref directly; it does not build real commits.)
# ---------------------------------------------------------------------------


class TestConcurrentCommitThreads:
    """Racing CAS writers on one branch: exactly one may advance the ref."""

    def test_only_one_of_two_concurrent_commits_wins(
        self, muse_repo: pathlib.Path
    ) -> None:
        """Two threads racing on write_branch_ref with the same expected_id:
        exactly one must succeed and one must raise RefConflictError."""
        from muse.core.refs import RefConflictError

        (muse_repo / "base.py").write_text("base\n")
        _run(muse_repo, "code", "add", "base.py")
        _run(muse_repo, "commit", "-m", "base")

        current = _head_commit_id(muse_repo)
        branch = _current_branch(muse_repo)

        id_a = fake_id("commit-a")
        id_b = fake_id("commit-b")

        # new_id -> "ok" | "conflict" | "error: <repr>" for anything unexpected.
        outcomes: dict[str, str] = {}
        lock = threading.Lock()
        barrier = threading.Barrier(2, timeout=_THREAD_TIMEOUT)

        def do_cas(new_id: str) -> None:
            try:
                barrier.wait()  # both threads start the CAS at the same moment
                write_branch_ref(muse_repo, branch, new_id, expected_id=current)
                outcome = "ok"
            except RefConflictError:
                outcome = "conflict"
            except Exception as exc:  # report it; a thread would otherwise swallow it
                outcome = f"error: {exc!r}"
            with lock:
                outcomes[new_id] = outcome

        threads = [
            threading.Thread(target=do_cas, args=(new_id,), daemon=True)
            for new_id in (id_a, id_b)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=_THREAD_TIMEOUT)
            assert not t.is_alive(), "CAS thread did not finish in time"

        errors = {k: v for k, v in outcomes.items() if v.startswith("error")}
        assert not errors, f"unexpected exceptions in CAS threads: {errors}"
        assert len(outcomes) == 2, f"missing thread outcomes: {outcomes}"

        winners = [new_id for new_id, outcome in outcomes.items() if outcome == "ok"]
        assert len(winners) == 1, (
            f"expected exactly 1 CAS success, got {len(winners)}. results: {outcomes}"
        )
        assert list(outcomes.values()).count("conflict") == 1
        # The ref on disk must hold the winner's id, not the loser's.
        assert _read_ref(muse_repo, branch) == winners[0]