"""Tests for the code-domain invariants engine."""

import pathlib
import tempfile

import pytest

from muse.core.invariants import InvariantChecker
from muse.plugins.code._invariants import (
    CodeChecker,
    CodeInvariantRule,
    check_max_complexity,
    check_no_circular_imports,
    check_no_dead_exports,
    check_test_coverage_floor,
    load_invariant_rules,
    run_invariants,
)
from muse.core.object_store import object_path
from muse.core.paths import code_invariants_path, muse_dir


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_repo(tmp_path: pathlib.Path) -> pathlib.Path:
    """Set up a minimal .muse/ structure under *tmp_path* and return *tmp_path*."""
    dot_muse = muse_dir(tmp_path)
    dot_muse.mkdir()
    (dot_muse / "repo.json").write_text('{"repo_id":"test"}', encoding="utf-8")
    (dot_muse / "HEAD").write_text("ref: refs/heads/main", encoding="utf-8")
    (dot_muse / "commits").mkdir()
    (dot_muse / "snapshots").mkdir()
    (dot_muse / "refs" / "heads").mkdir(parents=True)
    (dot_muse / "objects").mkdir()
    return tmp_path


def _write_object(root: pathlib.Path, content: bytes) -> str:
    """Store *content* in the object store under *root* and return its blob id.

    The returned id is what a snapshot manifest maps a file path to.
    """
    from muse.core.types import blob_id
    from muse.core.object_store import write_object

    oid = blob_id(content)
    write_object(root, oid, content)
    return oid


# ---------------------------------------------------------------------------
# _estimate_complexity (via check_max_complexity)
# ---------------------------------------------------------------------------


class TestMaxComplexity:
    def test_simple_function_no_violation(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            src = b"def simple():\n    return 1\n"
            h = _write_object(root, src)
            manifest = {"mod.py": h}
            violations = check_max_complexity(manifest, root, "test", "error", threshold=10)
            assert violations == []

    def test_complex_function_triggers_violation(self) -> None:
        # Seven ``if`` branches in one function; expected to exceed threshold 5.
        src = b"""
def complex():
    if True: pass
    if True: pass
    if True: pass
    if True: pass
    if True: pass
    if True: pass
    if True: pass
    return 1
"""
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            h = _write_object(root, src)
            manifest = {"mod.py": h}
            violations = check_max_complexity(manifest, root, "gate", "error", threshold=5)
            assert len(violations) >= 1
            assert violations[0]["rule_name"] == "gate"
            assert "complexity" in violations[0]["description"].lower()

    def test_non_python_file_skipped(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            src = b"def hello() { return 1; }"
            h = _write_object(root, src)
            manifest = {"mod.js": h}
            # Threshold 1 would flag almost anything, so an empty result shows
            # the .js file was not analysed at all.
            violations = check_max_complexity(manifest, root, "c", "error", threshold=1)
            assert violations == []


# ---------------------------------------------------------------------------
# check_no_circular_imports
# ---------------------------------------------------------------------------


class TestNoCircularImports:
    def test_no_cycle_returns_empty(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            a = b"import b\n"
            b_src = b"x = 1\n"
            ha = _write_object(root, a)
            hb = _write_object(root, b_src)
            manifest = {"a.py": ha, "b.py": hb}
            violations = check_no_circular_imports(manifest, root, "no_cycles", "error")
            assert violations == []

    def test_cycle_detected(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            # a imports b, b imports a → cycle
            a = b"import b\n"
            b_src = b"import a\n"
            ha = _write_object(root, a)
            hb = _write_object(root, b_src)
            manifest = {"a.py": ha, "b.py": hb}
            violations = check_no_circular_imports(manifest, root, "no_cycles", "error")
            assert len(violations) >= 1
            assert "cycle" in violations[0]["description"].lower()

    def test_three_file_cycle_detected(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            # a → b → c → a: a cycle that only appears transitively.
            a = b"import b\n"
            b_src = b"import c\n"
            c_src = b"import a\n"
            ha = _write_object(root, a)
            hb = _write_object(root, b_src)
            hc = _write_object(root, c_src)
            manifest = {"a.py": ha, "b.py": hb, "c.py": hc}
            violations = check_no_circular_imports(manifest, root, "cycles", "error")
            assert len(violations) >= 1
            assert "cycle" in violations[0]["description"].lower()


# ---------------------------------------------------------------------------
# check_no_dead_exports
# ---------------------------------------------------------------------------


class TestNoDeadExports:
    def test_used_function_not_reported(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            lib = b"def my_func():\n    return 1\n"
            main = b"from lib import my_func\n"
            hl = _write_object(root, lib)
            hm = _write_object(root, main)
            manifest = {"lib.py": hl, "main.py": hm}
            violations = check_no_dead_exports(manifest, root, "dead", "warning")
            # lib.my_func is imported by main.py → should not be reported.
            addresses = [v["address"] for v in violations]
            assert "lib.py::my_func" not in addresses

    def test_unused_function_reported(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            lib = b"def orphan_fn():\n    return 1\n"
            other = b"x = 1\n"
            hl = _write_object(root, lib)
            ho = _write_object(root, other)
            manifest = {"lib.py": hl, "other.py": ho}
            violations = check_no_dead_exports(manifest, root, "dead", "warning")
            addresses = [v["address"] for v in violations]
            assert "lib.py::orphan_fn" in addresses

    def test_private_function_exempt(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            lib = b"def _private():\n    return 1\n"
            h = _write_object(root, lib)
            manifest = {"lib.py": h}
            violations = check_no_dead_exports(manifest, root, "dead", "warning")
            # Private functions are exempt.
            assert all("_private" not in v["address"] for v in violations)

    def test_test_file_exempt(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            lib = b"def test_something():\n    assert True\n"
            h = _write_object(root, lib)
            manifest = {"test_stuff.py": h}
            violations = check_no_dead_exports(manifest, root, "dead", "warning")
            assert violations == []


# ---------------------------------------------------------------------------
# check_test_coverage_floor
# ---------------------------------------------------------------------------


class TestTestCoverageFloor:
    def test_well_covered_code_no_violation(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            src = b"def foo():\n    return 1\n"
            test_src = b"def test_foo():\n    assert True\n"
            hs = _write_object(root, src)
            ht = _write_object(root, test_src)
            manifest = {"src.py": hs, "test_src.py": ht}
            violations = check_test_coverage_floor(
                manifest, root, "coverage", "warning", min_ratio=0.5
            )
            assert violations == []

    def test_uncovered_code_violates(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            # Three functions and no test file at all.
            src = b"def foo():\n    pass\ndef bar():\n    pass\ndef baz():\n    pass\n"
            h = _write_object(root, src)
            manifest = {"src.py": h}
            violations = check_test_coverage_floor(
                manifest, root, "coverage", "warning", min_ratio=0.5
            )
            assert len(violations) == 1
            assert "coverage floor" in violations[0]["description"].lower()

    def test_no_functions_no_violation(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            src = b"X = 1\n"
            h = _write_object(root, src)
            manifest = {"config.py": h}
            violations = check_test_coverage_floor(
                manifest, root, "coverage", "warning", min_ratio=0.5
            )
            assert violations == []


# ---------------------------------------------------------------------------
# load_invariant_rules
# ---------------------------------------------------------------------------


class TestLoadInvariantRules:
    def test_no_file_returns_defaults(self) -> None:
        """load_invariant_rules(repo_root, None) with no file on disk returns built-in defaults."""
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            rules = load_invariant_rules(root, None)
            assert len(rules) >= 1
            rule_types = {r["rule_type"] for r in rules}
            assert "max_complexity" in rule_types

    def test_default_path_from_repo_root(self) -> None:
        """When rules_file is None, the default is .muse/code_invariants.toml inside repo_root."""
        toml = "[[rule]]\nname='repo_rule'\nseverity='error'\nscope='function'\nrule_type='max_complexity'\n"
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            code_invariants_path(root).write_text(toml, encoding="utf-8")
            rules = load_invariant_rules(root, None)
            assert any(r["name"] == "repo_rule" for r in rules)

    def test_explicit_missing_path_returns_empty(self) -> None:
        """An explicit path that does not exist yields no rules (caller opts out of defaults)."""
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            # A path inside the fresh temp dir is guaranteed not to exist,
            # unlike a hard-coded absolute path on the host.
            missing = root / "no_such_rules.toml"
            assert not missing.exists()
            rules = load_invariant_rules(root, missing)
            assert rules == []

    def test_toml_file_loaded(self) -> None:
        """An explicit rules file is parsed and its rules are returned."""
        toml = "[[rule]]\nname='r1'\nseverity='error'\nscope='function'\nrule_type='max_complexity'\n"
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            # Written inside the temp dir so it is removed with it.
            path = root / "custom_rules.toml"
            path.write_text(toml, encoding="utf-8")
            rules = load_invariant_rules(root, path)
            # The built-in defaults also contain a max_complexity rule, so only
            # the custom name proves this file was actually read.
            assert any(r["name"] == "r1" for r in rules)
            assert any(r["rule_type"] == "max_complexity" for r in rules)


# ---------------------------------------------------------------------------
# CodeChecker (protocol)
# ---------------------------------------------------------------------------


class TestCodeChecker:
    def test_satisfies_invariant_checker_protocol(self) -> None:
        checker = CodeChecker()
        assert isinstance(checker, InvariantChecker)

    def test_check_returns_base_report(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            root = _make_repo(pathlib.Path(tmp))
            # Commit an empty snapshot; check() should return a well-formed
            # report for that commit.
            from muse.core.commits import (
                CommitRecord,
                write_commit,
            )
            from muse.core.snapshots import (
                SnapshotRecord,
                write_snapshot,
            )
            from muse.core.ids import hash_commit as compute_commit_id, hash_snapshot as compute_snapshot_id
            import datetime

            snap_id = compute_snapshot_id({})
            snap = SnapshotRecord(snapshot_id=snap_id, manifest={})
            write_snapshot(root, snap)
            ts = datetime.datetime(2026, 1, 1, tzinfo=datetime.timezone.utc)
            cid = compute_commit_id(
                parent_ids=[],
                snapshot_id=snap_id,
                message="init",
                committed_at_iso=ts.isoformat(),
            )
            commit = CommitRecord(
                commit_id=cid,
                branch="main",
                snapshot_id=snap_id,
                message="init",
                committed_at=ts,
            )
            write_commit(root, commit)
            report = CodeChecker().check(root, cid)
            assert report["commit_id"] == cid
            assert report["domain"] == "code"
            assert isinstance(report["violations"], list)