"""Tests for ``muse restore`` — working-tree and stage file restoration. Coverage tiers: - Unit: _resolve_source_manifest, _resolve_file_path helpers - Integration: restore worktree from HEAD (default), restore --staged (unstage), restore --staged --worktree (full reset), --source , multiple paths, glob patterns, --dry-run, --json - End-to-end: full CLI via CliRunner - Security: path traversal rejected, outside-repo paths rejected - Edge cases: file not in HEAD/source, staged-only file restore - Stress: restore 100 modified files """ from __future__ import annotations from collections.abc import Mapping import datetime import json import pathlib import pytest from tests.cli_test_helper import CliRunner from muse.core.object_store import write_object from muse.core.ids import hash_commit, hash_snapshot from muse.core.commits import ( CommitRecord, write_commit, ) from muse.core.snapshots import ( SnapshotRecord, write_snapshot, ) from muse.core.types import Manifest, blob_id from muse.plugins.code.stage import StagedFileMap, make_entry, read_stage, write_stage from muse.core.paths import muse_dir, ref_path runner = CliRunner() _REPO_ID = "restore-test" # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- _counter = 0 def _init_repo(path: pathlib.Path) -> pathlib.Path: dot_muse = muse_dir(path) for d in ("commits", "snapshots", "objects", "refs/heads", "code"): (dot_muse / d).mkdir(parents=True, exist_ok=True) (dot_muse / "HEAD").write_text("ref: refs/heads/main", encoding="utf-8") (dot_muse / "repo.json").write_text( json.dumps({"repo_id": _REPO_ID, "domain": "code"}), encoding="utf-8" ) return path def _env(repo: pathlib.Path) -> Mapping[str, str]: return {"MUSE_REPO_ROOT": str(repo)} def _commit_files(root: pathlib.Path, files: Mapping[str, bytes], branch: str = "main") -> str: global _counter _counter += 1 manifest: Manifest = {} for rel_path, content in files.items(): obj_id = blob_id(content) write_object(root, obj_id, content) manifest[rel_path] = obj_id abs_path = root / rel_path abs_path.parent.mkdir(parents=True, exist_ok=True) abs_path.write_bytes(content) snap_id = hash_snapshot(manifest) write_snapshot(root, SnapshotRecord(snapshot_id=snap_id, manifest=manifest)) committed_at = datetime.datetime.now(datetime.timezone.utc) commit_id = hash_commit( parent_ids=[], snapshot_id=snap_id, message=f"commit {_counter}", committed_at_iso=committed_at.isoformat(), ) write_commit( root, CommitRecord( commit_id=commit_id, branch=branch, snapshot_id=snap_id, message=f"commit {_counter}", committed_at=committed_at, ), ) (ref_path(root, branch)).write_text(commit_id, encoding="utf-8") return commit_id def _invoke(repo: pathlib.Path, *args: str) -> "InvokeResult": from muse.cli.app import main as cli return runner.invoke(cli, ["restore", *args], env=_env(repo)) # --------------------------------------------------------------------------- # Unit — helpers # --------------------------------------------------------------------------- def test_resolve_source_manifest_from_head(tmp_path: pathlib.Path) -> None: from muse.cli.commands.restore import _resolve_source_manifest root = _init_repo(tmp_path) content = b"hello\n" _commit_files(root, {"f.py": content}) manifest = _resolve_source_manifest(root, source_ref=None) assert "f.py" in manifest assert manifest["f.py"] == blob_id(content) def test_resolve_source_manifest_empty_repo(tmp_path: pathlib.Path) -> None: from muse.cli.commands.restore import _resolve_source_manifest root = _init_repo(tmp_path) # No commits yet — should return empty dict, not raise manifest = _resolve_source_manifest(root, source_ref=None) assert manifest == {} def test_resolve_file_path_inside_repo(tmp_path: pathlib.Path) -> None: from muse.cli.commands.restore import _resolve_file_path root = _init_repo(tmp_path) rel = _resolve_file_path(root, "src/main.py") assert rel == "src/main.py" def test_resolve_file_path_traversal_raises(tmp_path: pathlib.Path) -> None: from muse.cli.commands.restore import _resolve_file_path root = _init_repo(tmp_path) with pytest.raises(SystemExit) as exc: _resolve_file_path(root, "../../../etc/passwd") assert exc.value.code != 0 # --------------------------------------------------------------------------- # Integration — restore worktree (default) # --------------------------------------------------------------------------- def test_restore_worktree_overwrites_modified_file(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) original = b"# original\n" _commit_files(root, {"a.py": original}) # Modify on disk (root / "a.py").write_bytes(b"# dirty\n") result = _invoke(root, "a.py") assert result.exit_code == 0 assert (root / "a.py").read_bytes() == original def test_restore_worktree_does_not_touch_stage(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) _commit_files(root, {"a.py": b"# orig\n"}) # Stage a modification new_content = b"# staged\n" obj_id = blob_id(new_content) write_object(root, obj_id, new_content) stage = read_stage(root) stage["a.py"] = make_entry(obj_id, "M") write_stage(root, stage) # Dirty the disk (root / "a.py").write_bytes(b"# dirty\n") _invoke(root, "a.py") # Stage must be untouched stage_after = read_stage(root) assert "a.py" in stage_after assert stage_after["a.py"]["mode"] == "M" def test_restore_worktree_from_staged_content(tmp_path: pathlib.Path) -> None: """When a file is staged, default restore pulls from the staged object_id.""" root = _init_repo(tmp_path) _commit_files(root, {"b.py": b"# head\n"}) staged_content = b"# staged version\n" obj_id = blob_id(staged_content) write_object(root, obj_id, staged_content) stage = read_stage(root) stage["b.py"] = make_entry(obj_id, "M") write_stage(root, stage) (root / "b.py").write_bytes(b"# dirty\n") _invoke(root, "b.py") assert (root / "b.py").read_bytes() == staged_content def test_restore_exit_zero_on_success(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) _commit_files(root, {"a.py": b"# a\n"}) (root / "a.py").write_bytes(b"# dirty\n") result = _invoke(root, "a.py") assert result.exit_code == 0 def test_restore_file_not_in_head_exits_nonzero(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) _commit_files(root, {"other.py": b"# o\n"}) result = _invoke(root, "ghost.py") assert result.exit_code != 0 # --------------------------------------------------------------------------- # Integration — restore --staged # --------------------------------------------------------------------------- def test_restore_staged_removes_modification(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) _commit_files(root, {"a.py": b"# orig\n"}) obj_id = blob_id(b"# modified\n") write_object(root, obj_id, b"# modified\n") stage: StagedFileMap = {"a.py": make_entry(obj_id, "M")} write_stage(root, stage) result = _invoke(root, "--staged", "a.py") assert result.exit_code == 0 stage_after = read_stage(root) assert "a.py" not in stage_after def test_restore_staged_does_not_touch_disk(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) _commit_files(root, {"a.py": b"# orig\n"}) modified_content = b"# modified\n" obj_id = blob_id(modified_content) write_object(root, obj_id, modified_content) stage: StagedFileMap = {"a.py": make_entry(obj_id, "M")} write_stage(root, stage) (root / "a.py").write_bytes(modified_content) _invoke(root, "--staged", "a.py") # Disk still has the modified content assert (root / "a.py").read_bytes() == modified_content def test_restore_staged_removes_added_file(tmp_path: pathlib.Path) -> None: """Unstaging a brand-new file (mode 'A', not in HEAD) removes it from stage.""" root = _init_repo(tmp_path) _commit_files(root, {"anchor.py": b"# anchor\n"}) content = b"# new\n" obj_id = blob_id(content) write_object(root, obj_id, content) (root / "new.py").write_bytes(content) stage: StagedFileMap = {"new.py": make_entry(obj_id, "A")} write_stage(root, stage) result = _invoke(root, "--staged", "new.py") assert result.exit_code == 0 stage_after = read_stage(root) assert "new.py" not in stage_after # Disk file untouched assert (root / "new.py").exists() def test_restore_staged_undeletes_from_stage(tmp_path: pathlib.Path) -> None: """Restoring --staged a deleted file removes the 'D' tombstone.""" root = _init_repo(tmp_path) _commit_files(root, {"gone.py": b"# original\n"}) stage: StagedFileMap = {"gone.py": make_entry("", "D")} write_stage(root, stage) result = _invoke(root, "--staged", "gone.py") assert result.exit_code == 0 stage_after = read_stage(root) assert "gone.py" not in stage_after def test_restore_staged_not_staged_is_noop(tmp_path: pathlib.Path) -> None: """Restoring --staged a file that isn't staged is a clean no-op.""" root = _init_repo(tmp_path) _commit_files(root, {"a.py": b"# a\n"}) result = _invoke(root, "--staged", "a.py") assert result.exit_code == 0 # --------------------------------------------------------------------------- # Integration — restore --staged --worktree (full reset) # --------------------------------------------------------------------------- def test_restore_staged_worktree_resets_both(tmp_path: pathlib.Path) -> None: """--staged --worktree restores disk and clears stage entry.""" root = _init_repo(tmp_path) original = b"# original\n" _commit_files(root, {"f.py": original}) modified = b"# modified\n" obj_id = blob_id(modified) write_object(root, obj_id, modified) stage: StagedFileMap = {"f.py": make_entry(obj_id, "M")} write_stage(root, stage) (root / "f.py").write_bytes(modified) result = _invoke(root, "--staged", "--worktree", "f.py") assert result.exit_code == 0 assert (root / "f.py").read_bytes() == original stage_after = read_stage(root) assert "f.py" not in stage_after # --------------------------------------------------------------------------- # Integration — --source # --------------------------------------------------------------------------- def test_restore_source_ref_restores_from_commit(tmp_path: pathlib.Path) -> None: """--source restores file from that commit's manifest.""" root = _init_repo(tmp_path) v1 = b"# version 1\n" commit_v1 = _commit_files(root, {"versioned.py": v1}) # Now update the file in HEAD v2 = b"# version 2\n" _commit_files(root, {"versioned.py": v2}) # Disk now has v2; restore to v1 using the first commit id (root / "versioned.py").write_bytes(b"# dirty\n") result = _invoke(root, "--source", commit_v1, "versioned.py") assert result.exit_code == 0 assert (root / "versioned.py").read_bytes() == v1 def test_restore_source_ref_not_found_exits_nonzero(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) _commit_files(root, {"a.py": b"# a\n"}) result = _invoke(root, "--source", "nonexistent-ref", "a.py") assert result.exit_code != 0 def test_restore_source_file_not_in_that_commit_exits_nonzero(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) v1_commit = _commit_files(root, {"only_in_v1.py": b"# v1\n"}) _commit_files(root, {"v2_only.py": b"# v2\n"}) result = _invoke(root, "--source", v1_commit, "v2_only.py") assert result.exit_code != 0 # --------------------------------------------------------------------------- # Integration -- multiple paths # --------------------------------------------------------------------------- def test_restore_multiple_paths(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) orig_a = b"# a orig\n" orig_b = b"# b orig\n" _commit_files(root, {"a.py": orig_a, "b.py": orig_b}) (root / "a.py").write_bytes(b"# a dirty\n") (root / "b.py").write_bytes(b"# b dirty\n") result = _invoke(root, "a.py", "b.py") assert result.exit_code == 0 assert (root / "a.py").read_bytes() == orig_a assert (root / "b.py").read_bytes() == orig_b # --------------------------------------------------------------------------- # Integration — --dry-run # --------------------------------------------------------------------------- def test_restore_dry_run_no_disk_change(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) _commit_files(root, {"a.py": b"# orig\n"}) dirty = b"# dirty\n" (root / "a.py").write_bytes(dirty) result = _invoke(root, "--dry-run", "a.py") assert result.exit_code == 0 assert (root / "a.py").read_bytes() == dirty def test_restore_dry_run_no_stage_change(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) _commit_files(root, {"a.py": b"# orig\n"}) obj_id = blob_id(b"# modified\n") write_object(root, obj_id, b"# modified\n") stage: StagedFileMap = {"a.py": make_entry(obj_id, "M")} write_stage(root, stage) _invoke(root, "--dry-run", "--staged", "a.py") stage_after = read_stage(root) assert "a.py" in stage_after # stage unchanged def test_restore_dry_run_json(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) _commit_files(root, {"a.py": b"# orig\n"}) (root / "a.py").write_bytes(b"# dirty\n") result = _invoke(root, "--dry-run", "--json", "a.py") assert result.exit_code == 0 data = json.loads(result.stdout) assert data["dry_run"] is True assert "a.py" in data.get("restored", []) or len(data.get("paths", [])) >= 1 # --------------------------------------------------------------------------- # Integration — --json # --------------------------------------------------------------------------- def test_restore_json_output_structure(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) _commit_files(root, {"j.py": b"# j\n"}) (root / "j.py").write_bytes(b"# dirty\n") result = _invoke(root, "--json", "j.py") assert result.exit_code == 0 data = json.loads(result.stdout) assert "restored" in data assert "j.py" in data["restored"] assert data["dry_run"] is False # --------------------------------------------------------------------------- # Security # --------------------------------------------------------------------------- def test_restore_path_traversal_rejected(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) _commit_files(root, {"anchor.py": b"# a\n"}) result = _invoke(root, "../../../etc/passwd") assert result.exit_code != 0 def test_restore_staged_path_traversal_rejected(tmp_path: pathlib.Path) -> None: root = _init_repo(tmp_path) _commit_files(root, {"anchor.py": b"# a\n"}) result = _invoke(root, "--staged", "../../malicious.py") assert result.exit_code != 0 # --------------------------------------------------------------------------- # Stress # --------------------------------------------------------------------------- def test_restore_100_modified_files(tmp_path: pathlib.Path) -> None: """Restore 100 modified files in one invocation.""" root = _init_repo(tmp_path) originals = {f"f{i}.py": f"# orig {i}\n".encode() for i in range(100)} _commit_files(root, originals) # Dirty all 100 for name in originals: (root / name).write_bytes(b"# dirty\n") result = _invoke(root, *originals.keys()) assert result.exit_code == 0 for name, orig_content in originals.items(): assert (root / name).read_bytes() == orig_content, f"{name} not restored"