"""TDD — walk_dag generic DAG walker (Phase 1). Tests are written against the planned public API before implementation. All tests in this file are expected to FAIL until walk_dag is implemented in muse/core/graph.py. Coverage -------- walk_dag - empty starts → yields nothing - single node, no neighbours → yields just that node - linear chain BFS order (root last) - linear chain DFS order (root last — same for a chain) - diamond DAG: shared node visited exactly once (BFS) - diamond DAG: shared node visited exactly once (DFS) - BFS yields level-order (breadth before depth) - DFS yields depth-first (depth before breadth) - prune predicate stops subtree at matched node (node itself NOT yielded) - prune at root → yields nothing - exclude pre-seeds visited set (excluded nodes and subtrees skipped) - max_nodes cap stops iteration after N nodes - max_nodes=0 → yields nothing - generic type: works with ints, not just strings - multi-source starts: seeds from multiple start nodes - multi-source: shared ancestor visited once - adjacency returning empty → walk terminates at leaves - invalid order value → raises ValueError iter_ancestors rewired - iter_ancestors routes through walk_dag (structural check) """ from __future__ import annotations from collections.abc import Iterable, Iterator from typing import Callable import pytest from muse.core.graph import walk_dag # --------------------------------------------------------------------------- # Shared test graph topologies (pure in-memory dicts) # --------------------------------------------------------------------------- # Linear: A → B → C → D _LINEAR: dict[str, list[str]] = { "A": ["B"], "B": ["C"], "C": ["D"], "D": [], } # Diamond: A → B → D # A → C → D _DIAMOND: dict[str, list[str]] = { "A": ["B", "C"], "B": ["D"], "C": ["D"], "D": [], } # Wide tree: root → [L, R] # L → [LL, LR] # R → [RL, RR] _TREE: dict[str, list[str]] = { "root": ["L", "R"], "L": ["LL", "LR"], "R": ["RL", "RR"], "LL": [], "LR": [], "RL": [], "RR": [], } type _AdjGraph = dict[str, list[str]] def _adj(graph: _AdjGraph) -> Callable[[str], Iterable[str]]: """Return an adjacency function for a dict-based graph.""" return lambda node: graph.get(node, []) # --------------------------------------------------------------------------- # Basic traversal # --------------------------------------------------------------------------- class TestWalkDagBasic: def test_empty_starts_yields_nothing(self) -> None: result = list(walk_dag([], _adj(_LINEAR))) assert result == [] def test_single_node_no_neighbours(self) -> None: graph = {"X": []} result = list(walk_dag("X", _adj(graph))) assert result == ["X"] def test_single_node_iterable_start(self) -> None: graph = {"X": []} result = list(walk_dag(["X"], _adj(graph))) assert result == ["X"] def test_linear_chain_bfs_all_visited(self) -> None: result = list(walk_dag("A", _adj(_LINEAR), order="bfs")) assert set(result) == {"A", "B", "C", "D"} def test_linear_chain_dfs_all_visited(self) -> None: result = list(walk_dag("A", _adj(_LINEAR), order="dfs")) assert set(result) == {"A", "B", "C", "D"} def test_linear_chain_bfs_start_first(self) -> None: result = list(walk_dag("A", _adj(_LINEAR), order="bfs")) assert result[0] == "A" def test_linear_chain_dfs_start_first(self) -> None: result = list(walk_dag("A", _adj(_LINEAR), order="dfs")) assert result[0] == "A" def test_leaf_node_as_start(self) -> None: """Starting from a leaf yields only that leaf.""" result = list(walk_dag("D", _adj(_LINEAR))) assert result == ["D"] def test_invalid_order_raises(self) -> None: with pytest.raises(ValueError, match="order"): list(walk_dag("A", _adj(_LINEAR), order="zigzag")) # type: ignore[arg-type] # --------------------------------------------------------------------------- # BFS vs DFS order # --------------------------------------------------------------------------- class TestWalkDagOrder: def test_bfs_is_level_order(self) -> None: """BFS on wide tree: root before L/R, L/R before leaves.""" result = list(walk_dag("root", _adj(_TREE), order="bfs")) root_idx = result.index("root") l_idx = result.index("L") r_idx = result.index("R") ll_idx = result.index("LL") lr_idx = result.index("LR") # root before both children assert root_idx < l_idx assert root_idx < r_idx # both children before any grandchild assert l_idx < ll_idx assert l_idx < lr_idx assert r_idx < ll_idx def test_dfs_is_depth_first(self) -> None: """DFS on wide tree: root first, then one branch fully explored before other.""" result = list(walk_dag("root", _adj(_TREE), order="dfs")) root_idx = result.index("root") # root must be first assert root_idx == 0 # At least one subtree's leaf comes before the other subtree's root. # DFS: root → first-child branch fully → second-child branch. # _TREE adjacency: root → [L, R] # DFS stack pops R first, then L. So actual order depends on stack behaviour. # The invariant we care about: root is first, and no interleaving of L/R subtrees. l_idx = result.index("L") r_idx = result.index("R") ll_idx = result.index("LL") lr_idx = result.index("LR") # Whichever branch comes first, its leaves must be before the other branch's root. if l_idx < r_idx: # L branch explored first assert ll_idx < r_idx or lr_idx < r_idx else: # R branch explored first rl_idx = result.index("RL") rr_idx = result.index("RR") assert rl_idx < l_idx or rr_idx < l_idx def test_bfs_vs_dfs_different_on_tree(self) -> None: """BFS and DFS produce different orderings on a non-trivial tree.""" bfs = list(walk_dag("root", _adj(_TREE), order="bfs")) dfs = list(walk_dag("root", _adj(_TREE), order="dfs")) assert set(bfs) == set(dfs) # same nodes assert bfs != dfs # different order # --------------------------------------------------------------------------- # Diamond (deduplication) # --------------------------------------------------------------------------- class TestWalkDagDiamond: def test_diamond_bfs_visits_shared_once(self) -> None: result = list(walk_dag("A", _adj(_DIAMOND), order="bfs")) assert result.count("D") == 1 assert set(result) == {"A", "B", "C", "D"} def test_diamond_dfs_visits_shared_once(self) -> None: result = list(walk_dag("A", _adj(_DIAMOND), order="dfs")) assert result.count("D") == 1 assert set(result) == {"A", "B", "C", "D"} # --------------------------------------------------------------------------- # prune predicate # --------------------------------------------------------------------------- class TestWalkDagPrune: def test_prune_stops_at_matched_node(self) -> None: """prune("B") → B and its subtree [C, D] are never yielded.""" result = list(walk_dag("A", _adj(_LINEAR), prune=lambda n: n == "B")) assert "A" in result assert "B" not in result assert "C" not in result assert "D" not in result def test_prune_at_root_yields_nothing(self) -> None: """Pruning the start node itself → nothing yielded.""" result = list(walk_dag("A", _adj(_LINEAR), prune=lambda n: n == "A")) assert result == [] def test_prune_at_leaf_no_effect(self) -> None: """Pruning a leaf doesn't affect earlier nodes.""" result = set(walk_dag("A", _adj(_LINEAR), prune=lambda n: n == "D")) assert result == {"A", "B", "C"} def test_prune_one_branch_of_diamond(self) -> None: """Pruning B in diamond: C path still reaches D; B subtree skipped.""" # A → B (pruned), A → C → D result = list(walk_dag("A", _adj(_DIAMOND), prune=lambda n: n == "B")) assert "A" in result assert "B" not in result assert "C" in result assert "D" in result # D reachable via C def test_prune_both_branches_of_diamond(self) -> None: """Pruning both B and C → D never reached.""" result = list(walk_dag("A", _adj(_DIAMOND), prune=lambda n: n in {"B", "C"})) assert result == ["A"] # --------------------------------------------------------------------------- # exclude # --------------------------------------------------------------------------- class TestWalkDagExclude: def test_exclude_pre_seeds_visited(self) -> None: """Excluded nodes are treated as already-visited; not yielded.""" result = list(walk_dag("A", _adj(_LINEAR), exclude={"B"})) assert "A" in result assert "B" not in result assert "C" not in result # C only reachable via B assert "D" not in result def test_exclude_start_node_yields_nothing(self) -> None: """Excluding the start node → nothing yielded (start treated as visited).""" result = list(walk_dag("A", _adj(_LINEAR), exclude={"A"})) assert result == [] def test_exclude_leaf(self) -> None: """Excluding a leaf: other nodes still visited.""" result = set(walk_dag("A", _adj(_LINEAR), exclude={"D"})) assert result == {"A", "B", "C"} def test_exclude_does_not_mutate_input(self) -> None: """walk_dag must not mutate the caller's exclude set.""" excl: set[str] = set() list(walk_dag("A", _adj(_LINEAR), exclude=excl)) assert excl == set() # --------------------------------------------------------------------------- # max_nodes cap # --------------------------------------------------------------------------- class TestWalkDagMaxNodes: def test_max_nodes_zero_yields_nothing(self) -> None: result = list(walk_dag("A", _adj(_LINEAR), max_nodes=0)) assert result == [] def test_max_nodes_one_yields_start(self) -> None: result = list(walk_dag("A", _adj(_LINEAR), max_nodes=1)) assert result == ["A"] def test_max_nodes_caps_count(self) -> None: result = list(walk_dag("A", _adj(_LINEAR), max_nodes=2)) assert len(result) == 2 def test_max_nodes_larger_than_graph(self) -> None: """max_nodes larger than the graph → all nodes yielded normally.""" result = list(walk_dag("A", _adj(_LINEAR), max_nodes=999)) assert set(result) == {"A", "B", "C", "D"} # --------------------------------------------------------------------------- # Generic type (ints) # --------------------------------------------------------------------------- class TestWalkDagGenericType: def test_works_with_integers(self) -> None: """walk_dag is generic — must work with int nodes, not just strings.""" graph: dict[int, list[int]] = {1: [2, 3], 2: [4], 3: [4], 4: []} result = list(walk_dag(1, lambda n: graph.get(n, []))) assert set(result) == {1, 2, 3, 4} assert result.count(4) == 1 # shared node visited once def test_works_with_tuples(self) -> None: """Nodes can be any hashable — tuples work fine (wrap in list to start).""" graph: dict[tuple[int, int], list[tuple[int, int]]] = { (0, 0): [(1, 0), (0, 1)], (1, 0): [], (0, 1): [], } result = list(walk_dag([(0, 0)], lambda n: graph.get(n, []))) assert set(result) == {(0, 0), (1, 0), (0, 1)} # --------------------------------------------------------------------------- # Multi-source starts # --------------------------------------------------------------------------- class TestWalkDagMultiSource: def test_multi_source_visits_all_reachable(self) -> None: graph: dict[str, list[str]] = { "shared": [], "a": ["shared"], "b": ["shared"], } result = list(walk_dag(["a", "b"], _adj(graph))) assert set(result) == {"a", "b", "shared"} def test_multi_source_shared_ancestor_once(self) -> None: graph: dict[str, list[str]] = { "shared": [], "a": ["shared"], "b": ["shared"], } result = list(walk_dag(["a", "b"], _adj(graph))) assert result.count("shared") == 1 def test_multi_source_exclude_applies_to_all(self) -> None: graph: dict[str, list[str]] = { "shared": [], "a": ["shared"], "b": ["shared"], } result = list(walk_dag(["a", "b"], _adj(graph), exclude={"shared"})) assert "shared" not in result assert set(result) == {"a", "b"} # --------------------------------------------------------------------------- # Structural: iter_ancestors routes through walk_dag # --------------------------------------------------------------------------- class TestIterAncestorsRoutesThroughWalkDag: def test_iter_ancestors_calls_walk_dag(self) -> None: """iter_ancestors must delegate to walk_dag (implementation check).""" import inspect from muse.core import graph as graph_module source = inspect.getsource(graph_module.iter_ancestors) assert "walk_dag" in source, ( "iter_ancestors must call walk_dag after Phase 1 refactor" )