"""Supercharged tests for ``muse rebase`` — TDD for all gaps. Covers every JSON output path for: - ``duration_ms`` (float, milliseconds) - ``exit_code`` (int, 0/1/3) - ``replayed_commit_ids`` (list[str], sha256:-prefixed) Covers sha256:-prefix correctness: - ``_resolve_ref_to_id`` with sha256:-prefixed content in ref files - ``_short_id`` keeps the sha256: prefix and truncates only the hex portion - ``new_head``/``onto`` in JSON are sha256:-prefixed Covers all integration and lifecycle paths: - completed (normal), aborted, up_to_date, conflict, dry_run, status, squash Security, performance, and stress: - symlink guard on REBASE_STATE.json (load, save, clear) - size cap on REBASE_STATE.json - 50-commit dry-run, concurrent status reads """ from __future__ import annotations from collections.abc import Mapping import datetime import argparse import json import pathlib import threading import time import pytest from tests.cli_test_helper import CliRunner, InvokeResult from muse.core.object_store import write_object from muse.core.rebase import ( RebaseState, _MAX_STATE_BYTES, clear_rebase_state, collect_commits_to_replay, get_rebase_progress, load_rebase_state, save_rebase_state, ) from muse.core.paths import muse_dir, rebase_state_path, ref_path from muse.core.ids import hash_commit as compute_commit_id, hash_snapshot as compute_snapshot_id from muse.core.commits import ( CommitRecord, write_commit, ) from muse.core.snapshots import ( SnapshotRecord, write_snapshot, ) from muse.core.types import Manifest, blob_id, long_id, short_id runner = CliRunner() _REPO_ID = "rebase-supercharge-test" # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _oid(content: bytes) -> str: """Return a sha256:-prefixed object ID.""" return blob_id(content) _counter = 0 _counter_lock = threading.Lock() def _make_commit( root: pathlib.Path, parent_id: str | None = None, 
content: bytes = b"data", branch: str = "main", ) -> str: """Create a commit with correct sha256:-prefixed object IDs. Returns the commit ID.""" global _counter with _counter_lock: _counter += 1 c_val = _counter c = content + str(c_val).encode() obj_id = _oid(c) write_object(root, obj_id, c) manifest: Manifest = {f"f_{c_val}.txt": obj_id} snap_id = compute_snapshot_id(manifest) write_snapshot(root, SnapshotRecord(snapshot_id=snap_id, manifest=manifest)) committed_at = datetime.datetime.now(datetime.timezone.utc) parent_ids = [parent_id] if parent_id else [] commit_id = compute_commit_id( parent_ids=parent_ids, snapshot_id=snap_id, message=f"commit {c_val}", committed_at_iso=committed_at.isoformat(), ) write_commit(root, CommitRecord( commit_id=commit_id, branch=branch, snapshot_id=snap_id, message=f"commit {c_val}", committed_at=committed_at, parent_commit_id=parent_id, )) (ref_path(root, branch)).write_text(commit_id, encoding="utf-8") return commit_id def _init_repo(path: pathlib.Path) -> pathlib.Path: muse = muse_dir(path) for d in ("commits", "snapshots", "objects", "refs/heads"): (muse / d).mkdir(parents=True, exist_ok=True) (muse / "HEAD").write_text("ref: refs/heads/main", encoding="utf-8") (muse / "repo.json").write_text( json.dumps({"repo_id": _REPO_ID, "domain": "code"}), encoding="utf-8" ) return path def _env(repo: pathlib.Path) -> Mapping[str, str]: return {"MUSE_REPO_ROOT": str(repo)} def _invoke(args: list[str], repo: pathlib.Path) -> InvokeResult: return runner.invoke(None, args, env=_env(repo)) def _json_from(output: str) -> Mapping[str, object]: for line in output.splitlines(): line = line.strip() if line.startswith("{"): return json.loads(line) return json.loads(output.strip()) # --------------------------------------------------------------------------- # _short_id helper — prefix is canonical, only hex portion is truncated # --------------------------------------------------------------------------- class TestShortId: """_short_id keeps the 
sha256: prefix and truncates only the hex portion.""" def test_short_id_keeps_prefix(self, tmp_path: pathlib.Path) -> None: """_short_id must keep the sha256: prefix — it is canonical in Muse.""" cid = long_id("a" * 64) result = short_id(cid) assert result.startswith("sha256:"), f"Expected sha256: prefix, got {result!r}" def test_short_id_truncates_hex_to_12(self, tmp_path: pathlib.Path) -> None: """_short_id returns sha256: + first 12 hex chars.""" cid = long_id("deadbeef" * 8) result = short_id(cid) assert result == "sha256:deadbeefdead" # prefix + 12 hex chars def test_short_id_total_length(self, tmp_path: pathlib.Path) -> None: """sha256: (7) + 12 hex chars = 19 total chars.""" cid = long_id("cafebabe" * 8) result = short_id(cid) assert len(result) == 19 # "sha256:" (7) + 12 hex chars def test_short_id_bare_hex_passthrough(self, tmp_path: pathlib.Path) -> None: """_short_id with a bare hex string (no prefix) returns first 12 chars.""" bare = f"1234567890ab{'cd' * 26}" # 64 chars total result = short_id(bare) assert result == "1234567890ab" def test_text_output_shows_sha256_short_id(self, tmp_path: pathlib.Path) -> None: """Text output must show sha256:<12 hex chars> short IDs, not bare hex.""" _init_repo(tmp_path) base = _make_commit(tmp_path, content=b"base") (ref_path(tmp_path, "upstream")).write_text(base, encoding="utf-8") c1 = _make_commit(tmp_path, parent_id=base, content=b"c1") result = _invoke(["rebase", "--dry-run", "upstream"], tmp_path) assert result.exit_code == 0 # The output must contain sha256: expected_short = long_id(c1[7:19])# prefix + 12 hex chars assert expected_short in result.output, ( f"Expected {expected_short!r} in dry-run text output.\n" f"Got: {result.output!r}" ) # --------------------------------------------------------------------------- # _resolve_ref_to_id — must handle sha256:-prefixed content in ref files # --------------------------------------------------------------------------- class TestResolveRefToId: 
"""_resolve_ref_to_id must handle ref files whose content has sha256: prefix.""" def test_resolves_sha256_prefixed_ref_file(self, tmp_path: pathlib.Path) -> None: """Bug: len(raw) == 64 check fails when ref file contains sha256:-prefixed ID (71 chars).""" from muse.cli.commands.rebase import _resolve_ref_to_id _init_repo(tmp_path) commit_id = _make_commit(tmp_path, content=b"sha256-prefix-test") # commit_id is sha256:-prefixed (71 chars) — the ref file already has this resolved = _resolve_ref_to_id(tmp_path, "main", "main") assert resolved == commit_id, ( f"Expected {commit_id!r}, got {resolved!r}. " "Bug: _resolve_ref_to_id len check fails for sha256:-prefixed IDs." ) def test_resolves_head(self, tmp_path: pathlib.Path) -> None: """HEAD resolves to the current branch's commit.""" from muse.cli.commands.rebase import _resolve_ref_to_id _init_repo(tmp_path) commit_id = _make_commit(tmp_path, content=b"head-test") result = _resolve_ref_to_id(tmp_path, "main", "HEAD") assert result == commit_id def test_returns_none_for_missing_branch(self, tmp_path: pathlib.Path) -> None: """Unknown branch name resolves to None.""" from muse.cli.commands.rebase import _resolve_ref_to_id _init_repo(tmp_path) _make_commit(tmp_path) result = _resolve_ref_to_id(tmp_path, "main", "nonexistent-branch") assert result is None def test_resolved_id_is_sha256_prefixed(self, tmp_path: pathlib.Path) -> None: """The resolved commit ID must be sha256:-prefixed.""" from muse.cli.commands.rebase import _resolve_ref_to_id _init_repo(tmp_path) _make_commit(tmp_path, content=b"prefix-check") result = _resolve_ref_to_id(tmp_path, "main", "main") assert result is not None assert result.startswith("sha256:") # --------------------------------------------------------------------------- # duration_ms — all JSON output paths must include it # --------------------------------------------------------------------------- class TestJsonSchemaDurationMs: """Every JSON output path must include duration_ms.""" def 
test_status_inactive_has_duration_ms(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) _make_commit(tmp_path) result = _invoke(["rebase", "--status", "--json"], tmp_path) assert result.exit_code == 0 data = _json_from(result.output) assert "duration_ms" in data, f"Missing duration_ms in status JSON: {data}" assert isinstance(data["duration_ms"], (int, float)) assert data["duration_ms"] >= 0 def test_status_active_has_duration_ms(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) state = RebaseState( original_branch="main", original_head="a" * 64, onto="b" * 64, remaining=["c" * 64], completed=[], squash=False, ) save_rebase_state(tmp_path, state) result = _invoke(["rebase", "--status", "--json"], tmp_path) assert result.exit_code == 0 data = _json_from(result.output) assert "duration_ms" in data assert isinstance(data["duration_ms"], (int, float)) def test_abort_json_has_duration_ms(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) state = RebaseState( original_branch="main", original_head=base, onto=base, remaining=[], completed=[], squash=False, ) save_rebase_state(tmp_path, state) result = _invoke(["rebase", "--abort", "--json"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert "duration_ms" in data, f"Missing duration_ms in abort JSON: {data}" assert isinstance(data["duration_ms"], (int, float)) def test_up_to_date_json_has_duration_ms(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) cid = _make_commit(tmp_path) (ref_path(tmp_path, "up")).write_text(cid, encoding="utf-8") result = _invoke(["rebase", "--json", "up"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert "duration_ms" in data, f"Missing duration_ms in up_to_date JSON: {data}" assert isinstance(data["duration_ms"], (int, float)) def test_dry_run_json_has_duration_ms(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = 
_make_commit(tmp_path) (ref_path(tmp_path, "upstream")).write_text(base, encoding="utf-8") _make_commit(tmp_path, parent_id=base) result = _invoke(["rebase", "--dry-run", "--json", "upstream"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert "duration_ms" in data, f"Missing duration_ms in dry_run JSON: {data}" assert isinstance(data["duration_ms"], (int, float)) def test_completed_json_has_duration_ms(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) (ref_path(tmp_path, "upstream")).write_text(base, encoding="utf-8") _make_commit(tmp_path, parent_id=base) result = _invoke(["rebase", "--json", "upstream"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert "duration_ms" in data, f"Missing duration_ms in completed JSON: {data}" assert isinstance(data["duration_ms"], (int, float)) # --------------------------------------------------------------------------- # exit_code — all JSON output paths must include it # --------------------------------------------------------------------------- class TestJsonSchemaExitCode: """Every JSON output path must include exit_code.""" def test_status_json_has_exit_code_0(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) _make_commit(tmp_path) result = _invoke(["rebase", "--status", "--json"], tmp_path) assert result.exit_code == 0 data = _json_from(result.output) assert "exit_code" in data, f"Missing exit_code: {data}" assert data["exit_code"] == 0 def test_abort_json_has_exit_code_0(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) state = RebaseState( original_branch="main", original_head=base, onto=base, remaining=[], completed=[], squash=False, ) save_rebase_state(tmp_path, state) result = _invoke(["rebase", "--abort", "--json"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert "exit_code" in data, f"Missing 
exit_code: {data}" assert data["exit_code"] == 0 def test_up_to_date_json_has_exit_code_0(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) cid = _make_commit(tmp_path) (ref_path(tmp_path, "up")).write_text(cid, encoding="utf-8") result = _invoke(["rebase", "--json", "up"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert "exit_code" in data assert data["exit_code"] == 0 def test_dry_run_json_has_exit_code_0(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) (ref_path(tmp_path, "upstream")).write_text(base, encoding="utf-8") _make_commit(tmp_path, parent_id=base) result = _invoke(["rebase", "--dry-run", "--json", "upstream"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert "exit_code" in data assert data["exit_code"] == 0 def test_completed_json_has_exit_code_0(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) (ref_path(tmp_path, "upstream")).write_text(base, encoding="utf-8") _make_commit(tmp_path, parent_id=base) result = _invoke(["rebase", "--json", "upstream"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert "exit_code" in data assert data["exit_code"] == 0 def test_duration_ms_is_nonnegative_float(self, tmp_path: pathlib.Path) -> None: """duration_ms must be a non-negative number.""" _init_repo(tmp_path) cid = _make_commit(tmp_path) (ref_path(tmp_path, "up")).write_text(cid, encoding="utf-8") result = _invoke(["rebase", "--json", "up"], tmp_path) data = _json_from(result.output) assert data["duration_ms"] >= 0.0 # --------------------------------------------------------------------------- # replayed_commit_ids — completed result JSON must list new commit IDs # --------------------------------------------------------------------------- class TestReplayedCommitIds: """Completed rebase JSON must include replayed_commit_ids.""" def 
test_completed_has_replayed_commit_ids(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) (ref_path(tmp_path, "upstream")).write_text(base, encoding="utf-8") _make_commit(tmp_path, parent_id=base) result = _invoke(["rebase", "--json", "upstream"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert "replayed_commit_ids" in data, f"Missing replayed_commit_ids: {data}" assert isinstance(data["replayed_commit_ids"], list) def test_replayed_commit_ids_count_matches_replayed(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) (ref_path(tmp_path, "upstream")).write_text(base, encoding="utf-8") c1 = _make_commit(tmp_path, parent_id=base) c2 = _make_commit(tmp_path, parent_id=c1) result = _invoke(["rebase", "--json", "upstream"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert data["replayed"] == 2 assert len(data["replayed_commit_ids"]) == 2 def test_replayed_commit_ids_are_sha256_prefixed(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) (ref_path(tmp_path, "upstream")).write_text(base, encoding="utf-8") _make_commit(tmp_path, parent_id=base) result = _invoke(["rebase", "--json", "upstream"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) for cid in data["replayed_commit_ids"]: assert cid.startswith("sha256:"), f"Not sha256:-prefixed: {cid!r}" def test_abort_has_replayed_commit_ids_empty(self, tmp_path: pathlib.Path) -> None: """Aborted rebase has no new commits — replayed_commit_ids must be empty list.""" _init_repo(tmp_path) base = _make_commit(tmp_path) state = RebaseState( original_branch="main", original_head=base, onto=base, remaining=[], completed=[], squash=False, ) save_rebase_state(tmp_path, state) result = _invoke(["rebase", "--abort", "--json"], tmp_path) assert result.exit_code == 0, result.output data = 
_json_from(result.output) assert "replayed_commit_ids" in data assert data["replayed_commit_ids"] == [] def test_up_to_date_has_replayed_commit_ids_empty(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) cid = _make_commit(tmp_path) (ref_path(tmp_path, "up")).write_text(cid, encoding="utf-8") result = _invoke(["rebase", "--json", "up"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert "replayed_commit_ids" in data assert data["replayed_commit_ids"] == [] # --------------------------------------------------------------------------- # Data integrity — IDs in JSON must be sha256:-prefixed # --------------------------------------------------------------------------- class TestDataIntegrity: """All commit IDs in JSON output must be sha256:-prefixed.""" def test_new_head_is_sha256_prefixed(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) (ref_path(tmp_path, "upstream")).write_text(base, encoding="utf-8") _make_commit(tmp_path, parent_id=base) result = _invoke(["rebase", "--json", "upstream"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert data["new_head"].startswith("sha256:"), f"new_head not sha256:-prefixed: {data['new_head']!r}" def test_onto_is_sha256_prefixed(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) (ref_path(tmp_path, "upstream")).write_text(base, encoding="utf-8") _make_commit(tmp_path, parent_id=base) result = _invoke(["rebase", "--json", "upstream"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert data["onto"].startswith("sha256:"), f"onto not sha256:-prefixed: {data['onto']!r}" def test_up_to_date_new_head_is_sha256_prefixed(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) cid = _make_commit(tmp_path) (ref_path(tmp_path, "upstream")).write_text(cid, encoding="utf-8") result = _invoke(["rebase", "--json", "upstream"], 
tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert data["new_head"].startswith("sha256:") def test_abort_new_head_is_sha256_prefixed(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) state = RebaseState( original_branch="main", original_head=base, onto=base, remaining=[], completed=[], squash=False, ) save_rebase_state(tmp_path, state) result = _invoke(["rebase", "--abort", "--json"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) assert data["new_head"].startswith("sha256:"), ( f"abort new_head not sha256:-prefixed: {data['new_head']!r}" ) def test_dry_run_commit_ids_are_sha256_prefixed(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) (ref_path(tmp_path, "upstream")).write_text(base, encoding="utf-8") _make_commit(tmp_path, parent_id=base) result = _invoke(["rebase", "--dry-run", "--json", "upstream"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) for entry in data["commits"]: assert entry["commit_id"].startswith("sha256:"), ( f"dry_run commit_id not sha256:-prefixed: {entry['commit_id']!r}" ) def test_status_original_head_is_sha256_prefixed(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) state = RebaseState( original_branch="main", original_head=base, onto=base, remaining=[], completed=[], squash=False, ) save_rebase_state(tmp_path, state) result = _invoke(["rebase", "--status", "--json"], tmp_path) assert result.exit_code == 0 data = _json_from(result.output) assert data["original_head"].startswith("sha256:"), ( f"status original_head not sha256:-prefixed: {data['original_head']!r}" ) # --------------------------------------------------------------------------- # Full JSON schema — all fields present on each path # --------------------------------------------------------------------------- class TestJsonSchemaComplete: 
"""Verify all required fields exist in each output path.""" _RESULT_FIELDS = { "status", "branch", "new_head", "onto", "squash", "replayed", "replayed_commit_ids", "conflicts", "duration_ms", "exit_code", } _STATUS_FIELDS = { "active", "original_branch", "original_head", "onto", "total", "done", "remaining", "squash", "duration_ms", "exit_code", } _DRY_RUN_FIELDS = { "branch", "onto", "commits", "count", "squash", "duration_ms", "exit_code", } def test_completed_schema(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) (ref_path(tmp_path, "upstream")).write_text(base, encoding="utf-8") _make_commit(tmp_path, parent_id=base) result = _invoke(["rebase", "--json", "upstream"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) missing = self._RESULT_FIELDS - set(data) assert not missing, f"completed JSON missing fields: {missing}" assert data["status"] == "completed" assert data["exit_code"] == 0 assert data["replayed"] == 1 assert len(data["replayed_commit_ids"]) == 1 def test_aborted_schema(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) tip = _make_commit(tmp_path, parent_id=base) state = RebaseState( original_branch="main", original_head=base, onto=base, remaining=[tip], completed=[], squash=False, ) save_rebase_state(tmp_path, state) result = _invoke(["rebase", "--abort", "--json"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) missing = self._RESULT_FIELDS - set(data) assert not missing, f"aborted JSON missing fields: {missing}" assert data["status"] == "aborted" assert data["exit_code"] == 0 assert data["new_head"] == base assert data["replayed_commit_ids"] == [] def test_up_to_date_schema(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) cid = _make_commit(tmp_path) (ref_path(tmp_path, "upstream")).write_text(cid, encoding="utf-8") result = _invoke(["rebase", "--json", "upstream"], tmp_path) 
assert result.exit_code == 0, result.output data = _json_from(result.output) missing = self._RESULT_FIELDS - set(data) assert not missing, f"up_to_date JSON missing fields: {missing}" assert data["status"] == "up_to_date" assert data["exit_code"] == 0 assert data["replayed"] == 0 assert data["replayed_commit_ids"] == [] def test_dry_run_schema(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) (ref_path(tmp_path, "upstream")).write_text(base, encoding="utf-8") c1 = _make_commit(tmp_path, parent_id=base) result = _invoke(["rebase", "--dry-run", "--json", "upstream"], tmp_path) assert result.exit_code == 0, result.output data = _json_from(result.output) missing = self._DRY_RUN_FIELDS - set(data) assert not missing, f"dry_run JSON missing fields: {missing}" assert data["count"] == 1 assert data["commits"][0]["commit_id"] == c1 assert data["exit_code"] == 0 def test_status_schema_inactive(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) _make_commit(tmp_path) result = _invoke(["rebase", "--status", "--json"], tmp_path) assert result.exit_code == 0 data = _json_from(result.output) missing = self._STATUS_FIELDS - set(data) assert not missing, f"status JSON missing fields: {missing}" assert data["active"] is False assert data["exit_code"] == 0 def test_status_schema_active(self, tmp_path: pathlib.Path) -> None: _init_repo(tmp_path) base = _make_commit(tmp_path) state = RebaseState( original_branch="feat/x", original_head=base, onto=base, remaining=[base], completed=[], squash=True, ) save_rebase_state(tmp_path, state) result = _invoke(["rebase", "--status", "--json"], tmp_path) assert result.exit_code == 0 data = _json_from(result.output) missing = self._STATUS_FIELDS - set(data) assert not missing, f"status (active) JSON missing fields: {missing}" assert data["active"] is True assert data["original_branch"] == "feat/x" assert data["exit_code"] == 0 # 
# ---------------------------------------------------------------------------
# Lifecycle integration tests
# ---------------------------------------------------------------------------


class TestRebaseLifecycle:
    """Full lifecycle: init → rebase → result; abort restores HEAD."""

    def test_simple_rebase_completed(self, tmp_path: pathlib.Path) -> None:
        """A branch one commit ahead of its upstream must report a completed rebase.

        The identical setup yields ``status == "completed"`` in JSON mode, so the
        text output must report completion as well; accepting "up to date" here
        would let a silently skipped rebase pass.
        """
        _init_repo(tmp_path)
        base = _make_commit(tmp_path)
        ref_path(tmp_path, "upstream").write_text(base, encoding="utf-8")
        _make_commit(tmp_path, parent_id=base)
        result = _invoke(["rebase", "upstream"], tmp_path)
        assert result.exit_code == 0, result.output
        assert "complete" in result.output.lower(), result.output

    def test_abort_restores_head(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        base = _make_commit(tmp_path)
        tip = _make_commit(tmp_path, parent_id=base)
        state = RebaseState(
            original_branch="main",
            original_head=base,
            onto=base,
            remaining=[tip],
            completed=[],
            squash=False,
        )
        save_rebase_state(tmp_path, state)
        result = _invoke(["rebase", "--abort"], tmp_path)
        assert result.exit_code == 0
        assert "aborted" in result.output.lower()
        assert load_rebase_state(tmp_path) is None
        restored = ref_path(tmp_path, "main").read_text(encoding="utf-8").strip()
        assert restored == base

    def test_abort_text_shows_sha256_short_id(self, tmp_path: pathlib.Path) -> None:
        """Abort text output must show sha256:<12 hex chars>, not bare hex."""
        _init_repo(tmp_path)
        base = _make_commit(tmp_path)
        state = RebaseState(
            original_branch="main",
            original_head=base,
            onto=base,
            remaining=[],
            completed=[],
            squash=False,
        )
        save_rebase_state(tmp_path, state)
        result = _invoke(["rebase", "--abort"], tmp_path)
        assert result.exit_code == 0
        # Skip the 7-char "sha256:" prefix, keep 12 hex chars, then re-prefix.
        expected_short = long_id(base[7:19])
        assert expected_short in result.output, (
            f"Expected {expected_short!r} in abort text output: {result.output!r}"
        )

    def test_already_up_to_date_text(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        cid = _make_commit(tmp_path)
        ref_path(tmp_path, "up").write_text(cid, encoding="utf-8")
        result = _invoke(["rebase", "up"], tmp_path)
        assert result.exit_code == 0
        assert "up to date" in result.output.lower()

    def test_dry_run_no_side_effects(self, tmp_path: pathlib.Path) -> None:
        """--dry-run must not write REBASE_STATE.json or modify branch refs."""
        _init_repo(tmp_path)
        base = _make_commit(tmp_path)
        ref_path(tmp_path, "upstream").write_text(base, encoding="utf-8")
        c1 = _make_commit(tmp_path, parent_id=base)
        original_head = ref_path(tmp_path, "main").read_text(encoding="utf-8").strip()
        result = _invoke(["rebase", "--dry-run", "upstream"], tmp_path)
        assert result.exit_code == 0
        assert not rebase_state_path(tmp_path).exists()
        new_head = ref_path(tmp_path, "main").read_text(encoding="utf-8").strip()
        assert new_head == original_head
        expected_short = long_id(c1[7:19])
        assert expected_short in result.output

    def test_dry_run_squash_flag(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        base = _make_commit(tmp_path)
        ref_path(tmp_path, "upstream").write_text(base, encoding="utf-8")
        _make_commit(tmp_path, parent_id=base)
        result = _invoke(["rebase", "--dry-run", "--squash", "--json", "upstream"], tmp_path)
        assert result.exit_code == 0, result.output
        data = _json_from(result.output)
        assert data["squash"] is True

    def test_status_text_inactive(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        _make_commit(tmp_path)
        result = _invoke(["rebase", "--status"], tmp_path)
        assert result.exit_code == 0
        assert "No rebase" in result.output

    def test_status_text_active(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        base = _make_commit(tmp_path)
        state = RebaseState(
            original_branch="feat/y",
            original_head=base,
            onto=base,
            remaining=[base],
            completed=[],
            squash=True,
        )
        save_rebase_state(tmp_path, state)
        result = _invoke(["rebase", "--status"], tmp_path)
        assert result.exit_code == 0
        assert "feat/y" in result.output

    def test_completed_clears_state_file(self, tmp_path: pathlib.Path) -> None:
        """After a clean rebase, REBASE_STATE.json must be removed."""
        _init_repo(tmp_path)
        base = _make_commit(tmp_path)
        ref_path(tmp_path, "upstream").write_text(base, encoding="utf-8")
        _make_commit(tmp_path, parent_id=base)
        result = _invoke(["rebase", "upstream"], tmp_path)
        assert result.exit_code == 0, result.output
        assert load_rebase_state(tmp_path) is None

    def test_max_commits_cap(self, tmp_path: pathlib.Path) -> None:
        """--max-commits 2 on a 5-commit chain reports at most 2."""
        _init_repo(tmp_path)
        base = _make_commit(tmp_path)
        ref_path(tmp_path, "upstream").write_text(base, encoding="utf-8")
        prev = base
        for _ in range(5):
            prev = _make_commit(tmp_path, parent_id=prev)
        result = _invoke(
            ["rebase", "--dry-run", "--json", "--max-commits", "2", "upstream"], tmp_path
        )
        assert result.exit_code == 0, result.output
        data = _json_from(result.output)
        assert data["count"] <= 2


# ---------------------------------------------------------------------------
# Error paths
# ---------------------------------------------------------------------------


class TestErrors:
    """Error conditions must exit non-zero."""

    def test_no_upstream_exits_nonzero(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        _make_commit(tmp_path)
        result = _invoke(["rebase"], tmp_path)
        assert result.exit_code != 0

    def test_unknown_upstream_exits_nonzero(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        _make_commit(tmp_path)
        result = _invoke(["rebase", "nonexistent-branch-xyz"], tmp_path)
        assert result.exit_code != 0
        assert "not found" in result.stderr.lower()

    def test_abort_no_state_exits_nonzero(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        result = _invoke(["rebase", "--abort"], tmp_path)
        assert result.exit_code != 0

    def test_continue_no_state_exits_nonzero(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        result = _invoke(["rebase", "--continue"], tmp_path)
        assert result.exit_code != 0

    def test_rebase_in_progress_exits_nonzero(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        base = _make_commit(tmp_path)
        state = RebaseState(
            original_branch="main",
            original_head=base,
            onto=base,
            remaining=[],
            completed=[],
            squash=False,
        )
        save_rebase_state(tmp_path, state)
        result = _invoke(["rebase", "main"], tmp_path)
        assert result.exit_code != 0
        assert "--continue" in result.stderr or "--abort" in result.stderr


# ---------------------------------------------------------------------------
# Security — symlink and size guards (from hardening tests)
# ---------------------------------------------------------------------------


class TestSecurity:
    """Symlink and size-cap guards on REBASE_STATE.json."""

    def test_load_rebase_state_symlink_rejected(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        state_path = rebase_state_path(tmp_path)
        target = tmp_path / "sensitive.json"
        target.write_text(
            json.dumps({
                "original_branch": "main",
                "original_head": "a" * 64,
                "onto": "b" * 64,
                "remaining": [],
                "completed": [],
                "squash": False,
            }),
            encoding="utf-8",
        )
        state_path.symlink_to(target)
        result = load_rebase_state(tmp_path)
        assert result is None, "Symlinked state file must be rejected"

    def test_save_rebase_state_symlink_rejected(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        state_path = rebase_state_path(tmp_path)
        target = tmp_path / "victim.json"
        target.write_text("{}", encoding="utf-8")
        state_path.symlink_to(target)
        state = RebaseState(
            original_branch="main",
            original_head="a" * 64,
            onto="b" * 64,
            remaining=[],
            completed=[],
            squash=False,
        )
        with pytest.raises(OSError, match="symlink"):
            save_rebase_state(tmp_path, state)
        assert target.read_text(encoding="utf-8") == "{}"

    def test_clear_rebase_state_symlink_not_deleted(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        state_path = rebase_state_path(tmp_path)
        target = tmp_path / "do_not_delete.json"
        target.write_text("important", encoding="utf-8")
        state_path.symlink_to(target)
        clear_rebase_state(tmp_path)
        assert target.exists()
        # The symlink target must be left untouched, not just present.
        assert target.read_text(encoding="utf-8") == "important"

    def test_load_rebase_state_size_cap_rejected(self, tmp_path: pathlib.Path) -> None:
        """A state file larger than _MAX_STATE_BYTES is rejected even when it is valid JSON.

        The payload is a genuinely saved state padded with trailing whitespace,
        which is still valid JSON, so the size cap is the only reason the load
        can return None.
        """
        _init_repo(tmp_path)
        state = RebaseState(
            original_branch="main",
            original_head="a" * 64,
            onto="b" * 64,
            remaining=[],
            completed=[],
            squash=False,
        )
        save_rebase_state(tmp_path, state)
        # Control: the unpadded file must load; otherwise the final assertion
        # could pass for reasons unrelated to the size cap.
        assert load_rebase_state(tmp_path) is not None
        state_path = rebase_state_path(tmp_path)
        payload = state_path.read_bytes()
        state_path.write_bytes(payload + b" " * (_MAX_STATE_BYTES + 1))
        result = load_rebase_state(tmp_path)
        assert result is None

    def test_load_rebase_state_exactly_at_cap_rejected(self, tmp_path: pathlib.Path) -> None:
        """A file of exactly _MAX_STATE_BYTES bytes is handled without raising.

        NOTE(review): the content is not valid JSON, so None is expected whether
        or not the size check fires; this test does not pin whether the cap is
        inclusive. Confirm the intended boundary semantics before tightening it.
        """
        _init_repo(tmp_path)
        state_path = rebase_state_path(tmp_path)
        state_path.write_bytes(b"y" * _MAX_STATE_BYTES)
        result = load_rebase_state(tmp_path)
        assert result is None


# ---------------------------------------------------------------------------
# Performance
# ---------------------------------------------------------------------------


class TestPerformance:
    """Timing guards — key operations must complete quickly."""

    def test_status_completes_within_200ms(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        _make_commit(tmp_path)
        t0 = time.monotonic()
        result = _invoke(["rebase", "--status", "--json"], tmp_path)
        elapsed = time.monotonic() - t0
        assert result.exit_code == 0
        assert elapsed < 0.2, f"--status took {elapsed*1000:.1f}ms (expected <200ms)"

    def test_dry_run_50_commits_completes_within_5s(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        base = _make_commit(tmp_path, content=b"perf-base")
        ref_path(tmp_path, "upstream").write_text(base, encoding="utf-8")
        prev = base
        for i in range(50):
            prev = _make_commit(tmp_path, parent_id=prev, content=f"p{i}".encode())
        t0 = time.monotonic()
        result = _invoke(["rebase", "--dry-run", "--json", "upstream"], tmp_path)
        elapsed = time.monotonic() - t0
        assert result.exit_code == 0, result.output
        data = _json_from(result.output)
        assert data["count"] == 50
        assert elapsed < 5.0, f"dry-run 50 commits took {elapsed:.2f}s (expected <5s)"

    def test_duration_ms_is_positive(self, tmp_path: pathlib.Path) -> None:
        """duration_ms is a non-negative number on the up-to-date path.

        Despite the name, 0.0 is accepted: very fast systems can legitimately
        report a zero duration.
        """
        _init_repo(tmp_path)
        cid = _make_commit(tmp_path)
        ref_path(tmp_path, "up").write_text(cid, encoding="utf-8")
        result = _invoke(["rebase", "--json", "up"], tmp_path)
        assert result.exit_code == 0, result.output
        data = _json_from(result.output)
        assert isinstance(data["duration_ms"], (int, float))
        assert data["duration_ms"] >= 0


# ---------------------------------------------------------------------------
# Stress
# ---------------------------------------------------------------------------


class TestStress:
    """Large rebase chains and concurrent operations."""

    def test_collect_20_commits(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        base = _make_commit(tmp_path, content=b"stress-base")
        prev = base
        ids: list[str] = []
        for i in range(20):
            prev = _make_commit(tmp_path, parent_id=prev, content=f"s{i}".encode())
            ids.append(prev)
        result = collect_commits_to_replay(tmp_path, stop_at=base, tip=prev)
        assert len(result) == 20
        assert result[0].commit_id == ids[0]
        assert result[-1].commit_id == ids[-1]

    def test_50_commit_dry_run_json(self, tmp_path: pathlib.Path) -> None:
        _init_repo(tmp_path)
        base = _make_commit(tmp_path, content=b"fifty-base")
        ref_path(tmp_path, "upstream").write_text(base, encoding="utf-8")
        prev = base
        ids: list[str] = []
        for i in range(50):
            prev = _make_commit(tmp_path, parent_id=prev, content=f"t{i}".encode())
            ids.append(prev)
        result = _invoke(["rebase", "--dry-run", "--json", "upstream"], tmp_path)
        assert result.exit_code == 0, result.output
        data = _json_from(result.output)
        assert data["count"] == 50
        assert len(data["commits"]) == 50
        assert data["commits"][0]["commit_id"] == ids[0]
        assert data["commits"][-1]["commit_id"] == ids[-1]
        # All IDs must be sha256:-prefixed
        for entry in data["commits"]:
            assert entry["commit_id"].startswith("sha256:")

    def test_concurrent_status_reads(self, tmp_path: pathlib.Path) -> None:
        """Multiple threads calling get_rebase_progress must not crash."""
        _init_repo(tmp_path)
        state = RebaseState(
            original_branch="main",
            original_head="a" * 64,
            onto="b" * 64,
            remaining=["c" * 64] * 10,
            completed=["d" * 64] * 5,
            squash=False,
        )
        save_rebase_state(tmp_path, state)
        errors: list[str] = []

        def _read() -> None:
            try:
                p = get_rebase_progress(tmp_path)
                assert p["active"] is True
            except Exception as exc:
                # repr() keeps the exception type; str() of a bare
                # AssertionError is "" and would hide what went wrong.
                errors.append(repr(exc))

        threads = [threading.Thread(target=_read) for _ in range(20)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        assert not errors, f"Concurrent status failures: {errors}"

    def test_status_1000_element_state(self, tmp_path: pathlib.Path) -> None:
        """get_rebase_progress reports correct counts for a 1000-element state."""
        _init_repo(tmp_path)
        state = RebaseState(
            original_branch="main",
            original_head="a" * 64,
            onto="b" * 64,
            remaining=["c" * 64] * 500,
            completed=["d" * 64] * 500,
            squash=False,
        )
        save_rebase_state(tmp_path, state)
        p = get_rebase_progress(tmp_path)
        assert p["total"] == 1000
        assert p["done"] == 500
        assert p["remaining"] == 500


# ---------------------------------------------------------------------------
# TestRegisterFlags — argparse-level verification
# ---------------------------------------------------------------------------


class TestRegisterFlags:
    """Verify that register() wires --json / -j correctly."""

    def _make_parser(self) -> argparse.ArgumentParser:
        """Build a top-level parser with only the ``rebase`` subcommand registered."""
        from muse.cli.commands.rebase import register

        ap = argparse.ArgumentParser()
        subs = ap.add_subparsers()
        register(subs)
        return ap

    def test_json_flag_long(self) -> None:
        ns = self._make_parser().parse_args(["rebase", "--json"])
        assert ns.json_out is True

    def test_j_alias(self) -> None:
        ns = self._make_parser().parse_args(["rebase", "-j"])
        assert ns.json_out is True

    def test_default_is_text(self) -> None:
        ns = self._make_parser().parse_args(["rebase"])
        assert ns.json_out is False

    def test_dest_is_json_out(self) -> None:
        ns = self._make_parser().parse_args(["rebase", "-j"])
        assert hasattr(ns, "json_out")
        assert not hasattr(ns, "fmt")