gabriel / muse public

tree_edit.py file-level

at sha256:1 · View file ↗ · Intel ↗

History
1 files
1 commits
0 hotspots
0 🧊 dead
0 💥 blast risk
sha256:0 chore: trigger prebuild on 068c4d6f deployment · gabriel · Jun 21, 2026
1 """LCS-based tree edit algorithm for labeled ordered trees.
2
3 Implements a correct tree diff that produces ``InsertOp``, ``DeleteOp``,
4 ``ReplaceOp``, and ``MoveOp`` entries for labeled ordered trees.
5
6 Algorithm
7 ---------
8 The diff proceeds top-down recursively:
9
10 1. Compare root nodes by ``content_id``. Different content_id → ``ReplaceOp``
11 on the root node.
12 2. Diff the children sequences using the same LCS algorithm as
13 :mod:`~muse.core.diff_algorithms.lcs`:
14
15 - Matched child pairs (same ``id``) → recurse into subtree.
16 - Unmatched inserts → ``InsertOp`` (entire subtree added).
17 - Unmatched deletes → ``DeleteOp`` (entire subtree removed).
18 - Paired insert+delete of same ``id`` at different positions →
19 ``MoveOp``.
20
21 This approach is O(nm) per tree level where n, m are child counts. It does
22 not find the globally minimal edit script (Zhang-Shasha is optimal), but it
23 is correct: every change is accounted for, and applying the script to the base
24 tree produces the target tree. For the bounded tree sizes typical of domain
25 objects (scenes, tracks, ASTs ≲ 10k nodes), this is more than adequate.
26 A Zhang-Shasha optimisation is a drop-in replacement once needed.
27
28 ``TreeNode`` is defined here and re-exported by the package ``__init__``.
29
30 Public API
31 ----------
32 - :class:`TreeNode` — labeled ordered tree node (frozen dataclass).
33 - :func:`diff` — ``TreeNode`` × ``TreeNode`` → ``StructuredDelta``.
34 """
35
36 import logging
37 from dataclasses import dataclass
38 from typing import Literal
39
40 from muse.core.schema import TreeSchema
41 from muse.domain import (
42
43 DeleteOp,
44 DomainOp,
45 InsertOp,
46 MoveOp,
47 ReplaceOp,
48 StructuredDelta,
49 )
50
# Maps a child node ``id`` to its ``(base_idx, node)`` pair; ``_diff_nodes``
# uses it to pair unmatched deletes with unmatched inserts for move detection.
type _DeleteMap = dict[str, tuple[int, "TreeNode"]]

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
54
55 # ---------------------------------------------------------------------------
56 # TreeNode — the unit of tree-edit comparison
57 # ---------------------------------------------------------------------------
58
@dataclass(frozen=True)
class TreeNode:
    """A node in a labeled ordered tree for domain tree-edit algorithms.

    Instances are immutable (frozen dataclass) and compare by value,
    including their children.

    Attributes:
        id: Stable unique identifier for the node (e.g. content-addressed ID
            or path).
        label: Human-readable name (e.g. element tag, node type).
        content_id: SHA-256 of this node's own value — excluding its
            children. When two paired nodes have different ``content_id``
            values, the diff emits a ``ReplaceOp``.
        children: Ordered tuple of child nodes.
    """

    id: str
    label: str
    content_id: str
    # Quoted forward reference: the class name is not bound yet while the
    # class body executes, so an unquoted ``TreeNode`` raises ``NameError``
    # on interpreters that evaluate annotations eagerly.
    children: tuple["TreeNode", ...]
75
76 # ---------------------------------------------------------------------------
77 # Internal helpers
78 # ---------------------------------------------------------------------------
79
80 def _subtree_nodes(node: TreeNode) -> list[TreeNode]:
81 """Return all nodes in *node*'s subtree (postorder)."""
82 result: list[TreeNode] = []
83
84 def _visit(n: TreeNode) -> None:
85 for child in n.children:
86 _visit(child)
87 result.append(n)
88
89 _visit(node)
90 return result
91
92 def _lcs_children(
93 base_children: tuple[TreeNode, ...],
94 target_children: tuple[TreeNode, ...],
95 ) -> list[tuple[Literal["keep", "insert", "delete"], int, int]]:
96 """LCS shortest-edit script on two sequences of child nodes.
97
98 Comparison is by ``id`` β€” children with the same id are matched (a "keep"),
99 even if their ``content_id`` differs. A kept pair that has a different
100 ``content_id`` will produce a ``ReplaceOp`` when recursed into by
101 :func:`_diff_nodes`.
102
103 Unmatched children produce insert / delete ops.
104
105 Returns a list of ``(kind, base_idx, target_idx)`` triples.
106 """
107 n, m = len(base_children), len(target_children)
108 base_ids = [c.id for c in base_children]
109 target_ids = [c.id for c in target_children]
110
111 dp: list[list[int]] = [[0] * (m + 1) for _ in range(n + 1)]
112 for i in range(n - 1, -1, -1):
113 for j in range(m - 1, -1, -1):
114 if base_ids[i] == target_ids[j]:
115 dp[i][j] = dp[i + 1][j + 1] + 1
116 else:
117 dp[i][j] = max(dp[i + 1][j], dp[i][j + 1])
118
119 result: list[tuple[Literal["keep", "insert", "delete"], int, int]] = []
120 i, j = 0, 0
121 while i < n or j < m:
122 if i < n and j < m and base_ids[i] == target_ids[j]:
123 result.append(("keep", i, j))
124 i += 1
125 j += 1
126 elif j < m and (i >= n or dp[i][j + 1] >= dp[i + 1][j]):
127 result.append(("insert", i, j))
128 j += 1
129 else:
130 result.append(("delete", i, j))
131 i += 1
132
133 return result
134
def _diff_nodes(
    base: TreeNode,
    target: TreeNode,
    *,
    domain: str,
    address: str,
) -> list[DomainOp]:
    """Recursively diff two paired tree nodes, returning a flat op list.

    The ops are emitted in this order:

    1. A ``ReplaceOp`` for the node itself when the ``content_id`` values
       differ.
    2. The ops from recursing into every child pair kept by
       :func:`_lcs_children`, in script order.
    3. For each unmatched target child: a ``MoveOp`` (followed by the ops
       from diffing the moved pair) when an unmatched base child with the
       same ``id`` sits at a different index, otherwise one ``InsertOp``
       per node of the inserted subtree.
    4. One ``DeleteOp`` per node of every unmatched base child that was not
       consumed by a move.

    Args:
        base: Node from the base tree.
        target: Node from the target tree that is paired with *base*.
        domain: Domain tag; only forwarded to recursive calls.
        address: Address of the parent; this node's ops are addressed at
            ``"{address}/{base.id}"`` (or ``base.id`` when *address* is
            empty).

    Returns:
        The ops describing how *base*'s subtree differs from *target*'s.
    """
    ops: list[DomainOp] = []
    node_addr = f"{address}/{base.id}" if address else base.id

    # Compare the paired nodes themselves.
    if base.content_id != target.content_id:
        ops.append(
            ReplaceOp(
                op="replace",
                address=node_addr,
                position=None,
                old_content_id=base.content_id,
                new_content_id=target.content_id,
                old_summary=f"{base.label} (prev)",
                new_summary=f"{target.label} (new)",
            )
        )

    if not base.children and not target.children:
        return ops

    # Diff children via LCS.
    script = _lcs_children(base.children, target.children)

    raw_inserts: list[tuple[int, TreeNode]] = []  # (target_idx, node)
    raw_deletes: list[tuple[int, TreeNode]] = []  # (base_idx, node)

    for kind, bi, ti in script:
        if kind == "keep":
            # Recurse into the matched child pair.
            ops.extend(
                _diff_nodes(
                    base.children[bi],
                    target.children[ti],
                    domain=domain,
                    address=node_addr,
                )
            )
        elif kind == "insert":
            raw_inserts.append((ti, target.children[ti]))
        else:
            raw_deletes.append((bi, base.children[bi]))

    # Move detection: paired insert+delete of the same node id at different
    # positions. Node identity is tracked by id, not content_id, so a
    # repositioned node is detected as a move even if its content also
    # changed. Only the first unmatched delete per id is a move candidate.
    delete_by_id: _DeleteMap = {}
    for bi, node in raw_deletes:
        delete_by_id.setdefault(node.id, (bi, node))

    # Base indices consumed by a MoveOp. Tracking the index (not the id)
    # ensures another deleted sibling that happens to share the id is still
    # reported as deleted.
    moved_from: set[int] = set()

    for ti, node in raw_inserts:
        candidate = delete_by_id.get(node.id)
        if candidate is not None and candidate[0] != ti:
            from_idx, moved_base = candidate
            ops.append(
                MoveOp(
                    op="move",
                    address=node_addr,
                    from_position=from_idx,
                    to_position=ti,
                    content_id=node.content_id,
                )
            )
            # Each delete can back at most one move.
            del delete_by_id[node.id]
            moved_from.add(from_idx)
            # A moved node may also have changed: diff the pair so content
            # and descendant changes are not lost. This adds nothing when
            # the moved subtree is unchanged.
            ops.extend(
                _diff_nodes(
                    moved_base,
                    node,
                    domain=domain,
                    address=node_addr,
                )
            )
            continue
        # True insert — add every node of the inserted subtree.
        for sub_node in _subtree_nodes(node):
            ops.append(
                InsertOp(
                    op="insert",
                    address=node_addr,
                    position=ti,
                    content_id=sub_node.content_id,
                    content_summary=f"{sub_node.label} added",
                )
            )

    for bi, node in raw_deletes:
        if bi in moved_from:
            continue
        # True delete — remove every node of the deleted subtree.
        for sub_node in _subtree_nodes(node):
            ops.append(
                DeleteOp(
                    op="delete",
                    address=node_addr,
                    position=bi,
                    content_id=sub_node.content_id,
                    content_summary=f"{sub_node.label} removed",
                )
            )

    return ops
238
239 # ---------------------------------------------------------------------------
240 # Top-level diff entry point
241 # ---------------------------------------------------------------------------
242
def diff(
    schema: TreeSchema,
    base: TreeNode,
    target: TreeNode,
    *,
    domain: str,
    address: str = "",
) -> StructuredDelta:
    """Diff two labeled ordered trees and package the result.

    The returned delta contains ``ReplaceOp`` entries for nodes whose
    content changed, ``InsertOp`` / ``DeleteOp`` entries for added and
    removed subtrees, and ``MoveOp`` entries for repositioned children
    (an unmatched delete and insert of the same node id).

    Args:
        schema: The ``TreeSchema``; its ``node_type`` entry is used in the
            "no changes" summary text.
        base: Root of the base (ancestor) tree.
        target: Root of the target (newer) tree.
        domain: Domain tag for the returned ``StructuredDelta``.
        address: Address prefix for generated op entries.

    Returns:
        A ``StructuredDelta`` with typed ops and a human-readable summary.
    """
    # Fast path: same root content and equal children means nothing to do.
    roots_match = base.content_id == target.content_id
    if roots_match and base.children == target.children:
        return StructuredDelta(
            domain=domain,
            ops=[],
            summary=f"no {schema['node_type']} changes",
        )

    ops = _diff_nodes(base, target, domain=domain, address=address)

    # Tally ops per kind in one pass; other op kinds are ignored.
    tally = {"replace": 0, "insert": 0, "delete": 0, "move": 0}
    for entry in ops:
        kind = entry["op"]
        if kind in tally:
            tally[kind] += 1

    # Summary fragments appear in this fixed order, skipping zero counts.
    wording = (
        ("replace", "relabelled"),
        ("insert", "added"),
        ("delete", "removed"),
        ("move", "moved"),
    )
    fragments = [f"{tally[kind]} {word}" for kind, word in wording if tally[kind]]
    summary = ", ".join(fragments) or f"no {schema['node_type']} changes"

    logger.debug(
        "tree_edit.diff: +%d -%d ~%d r%d ops on %r",
        tally["insert"],
        tally["delete"],
        tally["move"],
        tally["replace"],
        address,
    )

    return StructuredDelta(domain=domain, ops=ops, summary=summary)