"""Hardening tests for ``muse reset`` — security, schema, error routing, ordering. These tests cover the issues fixed in the security/correctness/agent-UX audit. They are intentionally distinct from the existing test_cmd_reset_revert.py and test_cli_reset_revert.py suites, which cover the core reset algorithm. Coverage tiers -------------- Unit — parser flags, dead-code removal. Integration — error routing to stderr, JSON schema, --dry-run, ordering safety. End-to-end — full CLI: security, branch-name sanitization. Security — ANSI injection, ref sanitization, exc sanitization. Stress — large repos, concurrent repos, reset-and-verify cycles. """ from __future__ import annotations import json import os import pathlib import subprocess import threading import time import pytest from tests.cli_test_helper import CliRunner, InvokeResult from muse.core.refs import ( get_head_commit_id, read_current_branch, ) from muse.core.object_store import object_path as snapshot_path from muse.core.types import split_id runner = CliRunner() # ────────────────────────────────────────────────────────────────────────────── # Helpers # ────────────────────────────────────────────────────────────────────────────── JSON_REQUIRED_KEYS = { "branch", "ref", "old_commit_id", "new_commit_id", "snapshot_id", "mode", "dry_run", } def _invoke(repo: pathlib.Path, args: list[str]) -> InvokeResult: saved = os.getcwd() try: os.chdir(repo) return runner.invoke(None, args) finally: os.chdir(saved) def _reset(repo: pathlib.Path, *extra: str) -> InvokeResult: return _invoke(repo, ["reset", *extra]) def _commit(repo: pathlib.Path, message: str) -> str: """Stage all changes and commit, returning the full commit ID.""" import re _invoke(repo, ["code", "add", "."]) result = _invoke(repo, ["commit", "-m", message]) m = re.search(r'\[(?:main|[^ ]+) (sha256:[0-9a-f]{64})', result.output) return m.group(1) if m else "" @pytest.fixture() def repo(tmp_path: pathlib.Path) -> pathlib.Path: """Initialised repo with 
two commits on ``main``.""" saved = os.getcwd() try: os.chdir(tmp_path) runner.invoke(None, ["init"]) finally: os.chdir(saved) (tmp_path / "a.py").write_text("x = 1\n") _commit(tmp_path, "initial") (tmp_path / "b.py").write_text("y = 2\n") _commit(tmp_path, "add b") return tmp_path @pytest.fixture() def c1_id(repo: pathlib.Path) -> str: """Full commit ID of the first commit (HEAD~1).""" from muse.core.commits import read_commit head_id = get_head_commit_id(repo, "main") or "" head = read_commit(repo, head_id) return (head.parent_commit_id or "") if head else "" # ────────────────────────────────────────────────────────────────────────────── # Unit — parser flags # ────────────────────────────────────────────────────────────────────────────── class TestRegisterFlags: def _parse(self, *args: str) -> "object": import argparse from muse.cli.commands.reset import register p = argparse.ArgumentParser() sub = p.add_subparsers() register(sub) return p.parse_args(["reset", *args]) def test_default_json_out_is_false(self) -> None: ns = self._parse("HEAD~1") assert ns.json_out is False def test_json_flag_sets_json_out(self) -> None: ns = self._parse("HEAD~1", "--json") assert ns.json_out is True def test_j_shorthand_sets_json_out(self) -> None: ns = self._parse("HEAD~1", "-j") assert ns.json_out is True def test_dry_run_default_false(self) -> None: import argparse from muse.cli.commands.reset import register p = argparse.ArgumentParser() sub = p.add_subparsers() register(sub) ns = p.parse_args(["reset", "HEAD~1"]) assert ns.dry_run is False def test_dry_run_flag(self) -> None: import argparse from muse.cli.commands.reset import register p = argparse.ArgumentParser() sub = p.add_subparsers() register(sub) ns = p.parse_args(["reset", "HEAD~1", "--dry-run"]) assert ns.dry_run is True def test_hard_default_false(self) -> None: import argparse from muse.cli.commands.reset import register p = argparse.ArgumentParser() sub = p.add_subparsers() register(sub) ns = 
p.parse_args(["reset", "HEAD~1"]) assert ns.hard is False def test_hard_flag(self) -> None: import argparse from muse.cli.commands.reset import register p = argparse.ArgumentParser() sub = p.add_subparsers() register(sub) ns = p.parse_args(["reset", "HEAD~1", "--hard"]) assert ns.hard is True def test_force_default_false(self) -> None: import argparse from muse.cli.commands.reset import register p = argparse.ArgumentParser() sub = p.add_subparsers() register(sub) ns = p.parse_args(["reset", "HEAD~1"]) assert ns.force is False # ────────────────────────────────────────────────────────────────────────────── # Unit — dead-code removal # ────────────────────────────────────────────────────────────────────────────── class TestDeadCodeRemoved: def test_read_branch_wrapper_removed(self) -> None: import muse.cli.commands.reset as m assert not hasattr(m, "_read_branch"), ( "_read_branch was a dead one-liner wrapper and must be deleted" ) # ────────────────────────────────────────────────────────────────────────────── # Integration — error routing to stderr # ────────────────────────────────────────────────────────────────────────────── class TestErrorRouting: def test_unknown_ref_error_to_stderr(self, repo: pathlib.Path) -> None: result = _reset(repo, "bogus-ref") assert result.exit_code == 1 assert "not found" in (result.stderr or "").lower() assert "not found" not in result.output.replace(result.stderr or "", "") def test_unknown_flag_exits_nonzero(self, repo: pathlib.Path) -> None: result = _reset(repo, "HEAD~1", "--format", "xml") assert result.exit_code != 0 def test_missing_snapshot_error_to_stderr(self, repo: pathlib.Path, c1_id: str) -> None: """When --hard reset target snapshot is missing, error goes to stderr.""" from muse.core.commits import read_commit commit = read_commit(repo, c1_id) if commit is None: pytest.skip("Could not read c1 commit") snap_id = commit.snapshot_id snap_path = snapshot_path(repo, snap_id) snap_path.unlink(missing_ok=True) result = 
_reset(repo, c1_id, "--hard") assert result.exit_code != 0 assert "not found" in (result.stderr or "").lower() or "snapshot" in (result.stderr or "").lower() def test_snapshot_pre_validated_before_branch_ref_written( self, repo: pathlib.Path, c1_id: str ) -> None: """Critical ordering fix: branch ref must NOT advance when snapshot is missing. Before the fix, write_branch_ref() was called BEFORE read_snapshot(), so a missing snapshot would leave the branch pointer at the new commit with an unrestored working tree — an inconsistent, unrecoverable state. """ from muse.core.commits import read_commit commit = read_commit(repo, c1_id) if commit is None: pytest.skip("Could not read c1 commit") snap_id = commit.snapshot_id snap_path = snapshot_path(repo, snap_id) snap_path.unlink(missing_ok=True) before_head = get_head_commit_id(repo, "main") _reset(repo, c1_id, "--hard") after_head = get_head_commit_id(repo, "main") # Branch ref must remain at the original commit — not advanced to c1. assert before_head == after_head, ( "Branch ref was advanced even though snapshot was missing — " "this is the pre-fix ordering bug" ) def test_snapshot_source_in_run_before_write_branch_ref(self) -> None: """Source inspection: read_snapshot must appear before write_branch_ref. We skip comment lines (those starting with #) to avoid false matches from documentation comments that reference function names. """ import inspect from muse.cli.commands.reset import run src = inspect.getsource(run) # Only consider non-comment executable lines. 
code_lines = [ (i, l) for i, l in enumerate(src.split("\n")) if not l.lstrip().startswith("#") ] snap_lineno = next((i for i, l in code_lines if "read_snapshot(" in l), -1) write_lineno = next((i for i, l in code_lines if "write_branch_ref(" in l), -1) assert snap_lineno != -1, "read_snapshot not found in run()" assert write_lineno != -1, "write_branch_ref not found in run()" assert snap_lineno < write_lineno, ( f"read_snapshot (line {snap_lineno}) must appear before " f"write_branch_ref (line {write_lineno}) in run() — " "this is the critical ordering fix that prevents orphaned branch refs" ) # ────────────────────────────────────────────────────────────────────────────── # Integration — JSON schema stability # ────────────────────────────────────────────────────────────────────────────── class TestJsonSchema: def test_soft_reset_has_all_keys(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--json") assert result.exit_code == 0 data = json.loads(result.output) missing = JSON_REQUIRED_KEYS - set(data) assert not missing, f"Missing keys in soft reset JSON: {missing}" def test_hard_reset_has_all_keys(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--hard", "--json") assert result.exit_code == 0 data = json.loads(result.output) missing = JSON_REQUIRED_KEYS - set(data) assert not missing, f"Missing keys in hard reset JSON: {missing}" def test_dry_run_has_all_keys(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--dry-run", "--json") assert result.exit_code == 0 data = json.loads(result.output) missing = JSON_REQUIRED_KEYS - set(data) assert not missing, f"Missing keys in dry-run JSON: {missing}" def test_soft_mode_is_soft(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--json") data = json.loads(result.output) assert data["mode"] == "soft" def test_hard_mode_is_hard(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--hard", 
"--json") data = json.loads(result.output) assert data["mode"] == "hard" def test_dry_run_flag_is_true(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--dry-run", "--json") data = json.loads(result.output) assert data["dry_run"] is True def test_live_reset_dry_run_flag_is_false(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--json") data = json.loads(result.output) assert data["dry_run"] is False def test_ref_field_matches_input(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--json") data = json.loads(result.output) assert data["ref"] == c1_id def test_snapshot_id_is_sha256(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--json") data = json.loads(result.output) sid = data["snapshot_id"] _, hex_part = split_id(sid) assert len(hex_part) == 64, f"Expected 64-char hex after prefix, got {len(hex_part)}: {sid!r}" assert all(c in "0123456789abcdef" for c in hex_part) def test_new_commit_id_is_sha256(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--json") data = json.loads(result.output) nid = data["new_commit_id"] _, hex_part = split_id(nid) assert len(hex_part) == 64, f"Expected 64-char hex after prefix, got {len(hex_part)}: {nid!r}" def test_old_commit_id_was_head(self, repo: pathlib.Path, c1_id: str) -> None: head_before = get_head_commit_id(repo, "main") result = _reset(repo, c1_id, "--json") data = json.loads(result.output) assert data["old_commit_id"] == head_before def test_branch_field_is_current_branch(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--json") data = json.loads(result.output) assert data["branch"] == "main" # ────────────────────────────────────────────────────────────────────────────── # Integration — --dry-run # ────────────────────────────────────────────────────────────────────────────── class TestDryRun: def test_dry_run_does_not_advance_branch(self, repo: 
pathlib.Path, c1_id: str) -> None: before = get_head_commit_id(repo, "main") _reset(repo, c1_id, "--dry-run") after = get_head_commit_id(repo, "main") assert before == after def test_dry_run_does_not_modify_workdir(self, repo: pathlib.Path, c1_id: str) -> None: b_content = (repo / "b.py").read_text() _reset(repo, c1_id, "--dry-run", "--hard") assert (repo / "b.py").read_text() == b_content def test_dry_run_exit_code_zero(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--dry-run") assert result.exit_code == 0 def test_dry_run_invalid_ref_exits_1(self, repo: pathlib.Path) -> None: result = _reset(repo, "nonexistent-ref", "--dry-run") assert result.exit_code == 1 def test_dry_run_json_shows_would_be_commit(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--dry-run", "--json") data = json.loads(result.output) assert data["new_commit_id"] == c1_id or data["new_commit_id"].startswith(c1_id) def test_dry_run_text_mentions_would(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--dry-run") assert "dry-run" in result.output.lower() or "would" in result.output.lower() def test_dry_run_does_not_write_reflog(self, repo: pathlib.Path, c1_id: str) -> None: from muse.core.reflog import read_reflog before_count = len(read_reflog(repo, "main")) _reset(repo, c1_id, "--dry-run") after_count = len(read_reflog(repo, "main")) assert before_count == after_count # ────────────────────────────────────────────────────────────────────────────── # Integration — soft reset # ────────────────────────────────────────────────────────────────────────────── class TestSoftReset: def test_soft_advances_branch_to_target(self, repo: pathlib.Path, c1_id: str) -> None: _reset(repo, c1_id) head = get_head_commit_id(repo, "main") assert head is not None and head.startswith(c1_id) def test_soft_preserves_working_tree(self, repo: pathlib.Path, c1_id: str) -> None: before = (repo / "b.py").read_text() _reset(repo, c1_id) 
assert (repo / "b.py").read_text() == before def test_soft_reflog_entry_written(self, repo: pathlib.Path, c1_id: str) -> None: from muse.core.reflog import read_reflog before_count = len(read_reflog(repo, "main")) _reset(repo, c1_id) after_count = len(read_reflog(repo, "main")) assert after_count > before_count def test_soft_text_output_has_commit_id(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id) assert c1_id[:8] in result.output # ────────────────────────────────────────────────────────────────────────────── # Integration — hard reset # ────────────────────────────────────────────────────────────────────────────── class TestHardReset: def test_hard_advances_branch_to_target(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--hard") assert result.exit_code == 0 head = get_head_commit_id(repo, "main") assert head is not None and head.startswith(c1_id) def test_hard_restores_workdir(self, repo: pathlib.Path, c1_id: str) -> None: assert (repo / "b.py").exists() result = _reset(repo, c1_id, "--hard") assert result.exit_code == 0 # b.py was added in the second commit; resetting to c1 removes it assert not (repo / "b.py").exists() def test_hard_text_output_shows_head_is_now(self, repo: pathlib.Path, c1_id: str) -> None: result = _reset(repo, c1_id, "--hard") assert "HEAD is now at" in result.output or c1_id[:8] in result.output def test_hard_uses_message_first_line(self, repo: pathlib.Path, c1_id: str) -> None: """Text output shows only the first line of a multiline commit message.""" (repo / "x.py").write_text("x=1\n") multiline_id = _commit(repo, "first line\n\nmore detail here") # Go back to c1 so we can reset to the multiline commit _reset(repo, c1_id) result = _reset(repo, multiline_id, "--hard") assert "more detail here" not in result.output # ────────────────────────────────────────────────────────────────────────────── # Security — ANSI injection # 
# ──────────────────────────────────────────────────────────────────────────────


class TestSecurityAnsi:
    """Terminal escape sequences supplied by users must never reach the terminal."""

    ESC = "\x1b["

    def test_unknown_ref_sanitized_in_stderr(self, repo: pathlib.Path) -> None:
        malicious_ref = f"{self.ESC}31mmalicious{self.ESC}0m"
        result = _reset(repo, malicious_ref)
        assert self.ESC not in (result.stderr or "")

    def test_unknown_flag_with_ansi_exits_nonzero(self, repo: pathlib.Path) -> None:
        malicious_fmt = f"{self.ESC}31mxml{self.ESC}0m"
        result = _reset(repo, "HEAD~1", "--format", malicious_fmt)
        assert result.exit_code != 0

    def test_no_ansi_in_stdout_on_error(self, repo: pathlib.Path) -> None:
        malicious_ref = f"{self.ESC}31mmalicious{self.ESC}0m"
        result = _reset(repo, malicious_ref)
        # stdout must be clean — errors go to stderr
        stdout_only = result.output.replace(result.stderr or "", "")
        assert self.ESC not in stdout_only

    def test_exc_sanitized_in_branch_validation(self, repo: pathlib.Path) -> None:
        """sanitize_display(str(exc)) must be used, not bare f'{exc}'."""
        import inspect

        from muse.cli.commands.reset import run

        src = inspect.getsource(run)
        # Confirm the pattern sanitize_display(str(exc)) is used, not bare {exc}
        assert "sanitize_display(str(exc))" in src

    def test_ref_sanitized_in_not_found_message(self) -> None:
        """sanitize_display(ref) must be used in the not-found error message."""
        import inspect

        from muse.cli.commands.reset import run

        src = inspect.getsource(run)
        assert "sanitize_display(ref)" in src

    def test_soft_text_output_sanitizes_branch(self, repo: pathlib.Path, c1_id: str) -> None:
        """Smoke check: a plain soft reset on ``main`` emits no escape sequences.

        NOTE(review): the branch name here is the benign "main", so this only
        guards against the command itself introducing escapes; injecting a
        malicious branch name depends on branch-creation rules not visible here.
        """
        result = _reset(repo, c1_id)
        assert self.ESC not in result.output

    def test_hard_text_output_sanitizes_message(self, repo: pathlib.Path) -> None:
        """A commit message containing ANSI escapes is not echoed raw by --hard.

        A malicious-message commit is created and then covered by one more
        commit, so the working tree is clean when hard-resetting back onto the
        malicious commit (HEAD~1), whose message the text output may display.
        """
        (repo / "evil.py").write_text("e = 1\n")
        _commit(repo, f"{self.ESC}31mevil message{self.ESC}0m")
        (repo / "after.py").write_text("z = 3\n")
        _commit(repo, "cover the malicious commit")

        result = _reset(repo, "HEAD~1", "--hard")
        assert result.exit_code == 0
        assert self.ESC not in result.output


# ──────────────────────────────────────────────────────────────────────────────
# Integration — get_head_commit_id replaces ref_file.read_text()
# ──────────────────────────────────────────────────────────────────────────────
class TestRefAbstraction:
    def test_no_direct_ref_file_read(self) -> None:
        """run() must use get_head_commit_id(), not read ref_file directly."""
        import inspect

        from muse.cli.commands.reset import run

        src = inspect.getsource(run)
        assert "ref_file.read_text" not in src, (
            "Direct ref_file.read_text() bypasses the get_head_commit_id "
            "abstraction layer and is a TOCTOU vulnerability"
        )
        assert "get_head_commit_id" in src

    def test_old_commit_id_correct_on_first_commit(self, repo: pathlib.Path) -> None:
        """old_commit_id in JSON must match HEAD before the reset."""
        head = get_head_commit_id(repo, "main")
        result = _reset(repo, "HEAD~1", "--json")
        data = json.loads(result.output)
        assert data["old_commit_id"] == head


# ──────────────────────────────────────────────────────────────────────────────
# Stress
# ──────────────────────────────────────────────────────────────────────────────


@pytest.mark.slow
class TestStress:
    def test_soft_reset_across_50_commits(self, repo: pathlib.Path) -> None:
        """Soft-reset across 50 commits must complete under 5s."""
        # Add 48 more commits (we already have 2)
        for i in range(48):
            (repo / f"f{i:03d}.py").write_text(f"x={i}\n")
            _commit(repo, f"commit {i}")

        from muse.core.commits import read_commit

        # Walk to the 10th commit from the end
        current_id = get_head_commit_id(repo, "main") or ""
        target_id = current_id
        for _ in range(10):
            c = read_commit(repo, target_id)
            if c and c.parent_commit_id:
                target_id = c.parent_commit_id

        t0 = time.perf_counter()
        result = _reset(repo, target_id)
        elapsed = (time.perf_counter() - t0) * 1000
        assert result.exit_code == 0
        assert elapsed < 5000, f"50-commit soft reset took {elapsed:.0f}ms (limit 5s)"

    def test_hard_reset_with_100_files(self, repo: pathlib.Path, c1_id: str) -> None:
        """Hard-reset restoring 100 files must complete under 5s."""
        for i in range(100):
            (repo / f"g{i:03d}.py").write_text(f"y={i}\n")
        _commit(repo, "add 100 files")
        head_id = get_head_commit_id(repo, "main") or ""

        # Reset back to c1 (removes 101 files) then back to head (restores them)
        t0 = time.perf_counter()
        r1 = _reset(repo, c1_id, "--hard")
        r2 = _reset(repo, head_id, "--hard")
        elapsed = (time.perf_counter() - t0) * 1000
        assert r1.exit_code == 0
        assert r2.exit_code == 0
        assert elapsed < 5000, f"100-file hard reset cycle took {elapsed:.0f}ms"

    def test_dry_run_50_commits_fast(self, repo: pathlib.Path) -> None:
        """Dry-run across 50 commits must complete under 2s."""
        for i in range(48):
            (repo / f"h{i:03d}.py").write_text(f"z={i}\n")
            _commit(repo, f"commit {i}")

        t0 = time.perf_counter()
        result = _reset(repo, "HEAD~1", "--dry-run")
        elapsed = (time.perf_counter() - t0) * 1000
        assert result.exit_code == 0
        assert elapsed < 2000, f"dry-run took {elapsed:.0f}ms (limit 2s)"

    def test_concurrent_resets_separate_repos(self, tmp_path: pathlib.Path) -> None:
        """Multiple repos resetting concurrently must not interfere.

        Each worker drives the installed ``muse`` executable via subprocess in
        its own repository. Any exception raised inside a worker is recorded in
        ``errors``. Without that, an uncaught exception in a
        ``threading.Thread`` only reaches ``threading.excepthook``, and the test
        would pass without having exercised anything.
        """
        import shutil

        if shutil.which("muse") is None:
            pytest.skip("the `muse` executable is not on PATH")

        errors: list[str] = []

        def build_and_reset(idx: int) -> None:
            """Create a two-commit repo, soft-reset HEAD~1, and validate the JSON."""
            repo_dir = tmp_path / f"repo_{idx}"
            repo_dir.mkdir()
            subprocess.run(["muse", "init"], cwd=str(repo_dir), capture_output=True)
            (repo_dir / "a.py").write_text(f"x={idx}\n")
            subprocess.run(["muse", "code", "add", "."], cwd=str(repo_dir), capture_output=True)
            subprocess.run(
                ["muse", "commit", "-m", f"c1_{idx}"],
                cwd=str(repo_dir),
                capture_output=True,
            )
            (repo_dir / "b.py").write_text(f"y={idx}\n")
            subprocess.run(["muse", "code", "add", "."], cwd=str(repo_dir), capture_output=True)
            subprocess.run(
                ["muse", "commit", "-m", f"c2_{idx}"],
                cwd=str(repo_dir),
                capture_output=True,
            )
            r = subprocess.run(
                ["muse", "reset", "HEAD~1", "--json"],
                cwd=str(repo_dir),
                capture_output=True,
                text=True,
            )
            if r.returncode != 0:
                errors.append(f"repo_{idx}: exit={r.returncode}, err={r.stderr[:60]}")
                return
            try:
                data = json.loads(r.stdout)
                if data["mode"] != "soft":
                    errors.append(f"repo_{idx}: unexpected mode {data['mode']}")
                if data["dry_run"] is not False:
                    errors.append(f"repo_{idx}: dry_run not False")
            except (ValueError, KeyError, TypeError) as e:
                # ValueError covers json.JSONDecodeError; KeyError/TypeError
                # cover a payload of the wrong shape.
                errors.append(f"repo_{idx}: parse error {e}")

        def do_reset(idx: int) -> None:
            """Thread target: run one scenario and record any unexpected exception."""
            try:
                build_and_reset(idx)
            except Exception as exc:  # surface worker failures to the main thread
                errors.append(f"repo_{idx}: {type(exc).__name__}: {exc}")

        threads = [threading.Thread(target=do_reset, args=(i,)) for i in range(6)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Build the message outside the f-string: a backslash inside an f-string
        # replacement field is a SyntaxError before Python 3.12 (PEP 701).
        details = "\n".join(errors)
        assert not errors, f"Concurrent reset errors:\n{details}"