gabriel / muse public
entity.py python
573 lines 21.5 KB
Raw
sha256:2eaa5d95f9d9383498e76947410a26e5a3ba23d182f339910c424cf88fad412b fix: try fetch/presign before fetch/mpack to avoid Cloudfla… Sonnet 4.6 patch 8 days ago
1 """Stable entity identity for MIDI note objects in the Muse MIDI plugin.
2
3 The key insight
4 ---------------
5 Content hash ≠ entity identity.
6
7 The current note content ID — ``SHA-256(pitch:velocity:start_tick:duration_ticks:channel)``
8 — is correct for *content equality* but wrong for *entity identity*. When a
9 musician or agent changes a note's velocity from 80 to 100, the old model
10 produces a ``DeleteOp + InsertOp`` (two unrelated content hashes). With stable
11 entity identity, the diff produces a ``MutateOp`` — "velocity 80→100 on note
12 C4@bar4" — and the note's lineage is preserved through the edit.
13
14 A ``NoteEntity`` extends the five ``NoteKey`` fields with an optional
15 ``entity_id`` — a content-addressed ID assigned at first insertion that persists across all
16 subsequent mutations regardless of how the note's properties change.
17
18 Entity assignment heuristic
19 ----------------------------
20 The function :func:`assign_entity_ids` maps a new note list to the entity IDs
21 in a prior index using a three-tier match:
22
23 1. **Exact content match** — all five fields identical → same entity, no mutation.
2. **Fuzzy match** — same pitch, channel and duration, ``|Δtick| ≤ threshold``,
   and ``|Δvelocity| ≤ threshold`` → same entity, emit ``MutateOp``.
26 3. **No match** → new entity, new content-addressed ID, emit ``InsertOp``.
27
28 Notes in the prior index that matched nothing → emit ``DeleteOp``.
29
30 Storage
31 -------
32 Entity indexes live under ``.muse/entity_index/`` as derived artifacts:
33
    .muse/entity_index/<commit_id[:16]>/<track_safe_name>_<track_hash8>.json
35
36 They are fully rebuildable from commit history and should be added to the
37 ``[domain.midi]`` section of ``.museignore`` in agent automation scripts
38 to avoid accidental commits.
39
40 Public API
41 ----------
42 - :class:`NoteEntity` — ``NoteKey`` fields + optional entity metadata.
43 - :class:`EntityIndexEntry` — one entity's record in the index.
44 - :class:`EntityIndex` — the full per-track, per-commit index.
45 - :func:`assign_entity_ids` — map a new note list onto prior entity IDs.
46 - :func:`diff_with_entity_ids` — entity-aware diff → ``list[DomainOp]``.
47 - :func:`build_entity_index` — build an :class:`EntityIndex` from entities.
48 - :func:`write_entity_index` / :func:`read_entity_index` — I/O.
49 """
50
51 import hashlib
52 import json
53 import logging
54 import pathlib
55 from typing import TypedDict
56
57 from muse.core.paths import entity_index_dir
58 from muse.core.types import blob_id, content_hash, load_json_file, short_id, split_id
59 from muse.domain import (
60 DeleteOp,
61 DomainOp,
62 FieldMutation,
63 InsertOp,
64 MutateOp,
65 )
66 from muse.plugins.midi._midi_keys import NoteKey, _note_content_id, _note_summary
67
# Private aliases for the dictionary shapes used throughout this module.
type _ContentMap = dict[str, str]  # content_id -> entity_id
type _NoteKeyBuckets = dict[str, list[NoteKey]]  # content_id -> notes with that content
type _FieldMutationMap = dict[str, FieldMutation]  # field name -> old/new value pair
type _EntityMap = dict[str, "EntityIndexEntry"]  # entity_id -> index record
type _NoteEntityMap = dict[str, "NoteEntity"]  # entity_id -> note entity

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
75
76 # ---------------------------------------------------------------------------
77 # Types
78 # ---------------------------------------------------------------------------
79
class _NoteEntityRequired(TypedDict):
    """Required fields shared with NoteKey.

    Kept as a separate base class so that :class:`NoteEntity` can add its own
    keys with ``total=False`` while these five remain mandatory.
    """

    pitch: int
    velocity: int
    start_tick: int
    duration_ticks: int
    channel: int
88
class NoteEntity(_NoteEntityRequired, total=False):
    """A MIDI note with optional stable entity identity.

    The five note fields are inherited from :class:`_NoteEntityRequired` and
    are always required; every key declared here is optional.

    When ``entity_id`` is absent the note is treated as content-only (legacy
    behaviour). When present it is a content-addressed ID that persists across
    mutations to velocity, timing, or duration — enabling lineage tracking
    through edits.

    ``entity_id``
        Stable identifier of the note across commits.
    ``voice_id``
        Optional voice lane identifier (e.g. ``"soprano"``, ``"alto"``).
        Assigned by a voice-separation analysis pass; not required for basic
        entity tracking.
    ``origin_commit_id``
        Short-form commit ID where this entity was first created.
    ``origin_op_id``
        Op ID from the op log that created this entity.
        NOTE(review): :func:`assign_entity_ids` fills this with the ``op_id``
        of the current call even for entities carried over from a prior
        index, because the index does not store the original op ID.
    """

    entity_id: str
    voice_id: str
    origin_commit_id: str
    origin_op_id: str
110
class EntityIndexEntry(TypedDict):
    """One entity's record in the per-track entity index.

    Only a hash of the note is stored here, not the note's field values.

    ``content_id``
        SHA-256 content hash of the note's current fields (the five ``NoteKey``
        fields). Updated on every mutation.
    ``origin_commit_id``
        Commit where this entity was first inserted.
    ``voice_id``
        Voice stream assignment, or empty string if unassigned.
    """

    content_id: str
    origin_commit_id: str
    voice_id: str
126
class EntityIndex(TypedDict):
    """Complete entity index for one track at one commit.

    ``track_path``
        Workspace-relative path of the MIDI file the index describes.
    ``commit_id``
        Commit the index was built for.
    ``entities``
        Maps ``entity_id`` → :class:`EntityIndexEntry`. This is the lookup
        table used by :func:`assign_entity_ids` to re-identify notes across
        commits.
    """

    track_path: str
    commit_id: str
    entities: _EntityMap
138
139 # ---------------------------------------------------------------------------
140 # Entity ID assignment
141 # ---------------------------------------------------------------------------
142
#: Default threshold in MIDI ticks for fuzzy timing match (≈ 10 ms at 120 BPM
#: with 480 ticks/beat: 480 ticks/beat × 2 beats/s × 0.010 s = 9.6 ≈ 10 ticks).
#: Default for ``mutation_threshold_ticks`` in :func:`assign_entity_ids`.
_DEFAULT_TICK_THRESHOLD = 10

#: Default velocity difference threshold for fuzzy entity matching.
#: Default for ``mutation_threshold_velocity`` in :func:`assign_entity_ids`.
_DEFAULT_VEL_THRESHOLD = 20
149
def assign_entity_ids(
    notes: list[NoteKey],
    prior_index: EntityIndex | None,
    commit_id: str,
    op_id: str,
    *,
    mutation_threshold_ticks: int = _DEFAULT_TICK_THRESHOLD,
    mutation_threshold_velocity: int = _DEFAULT_VEL_THRESHOLD,
) -> list[NoteEntity]:
    """Assign stable entity IDs to a list of notes.

    Maps each note in *notes* to an entity ID from *prior_index* using a
    three-tier matching heuristic (exact → fuzzy → new).

    The prior index stores only content hashes, not field values, so the
    fuzzy tier works by enumeration: for each unmatched note it hashes every
    variant of that note whose velocity and start tick lie within the
    thresholds (pitch, channel and duration unchanged) and looks those hashes
    up in the prior index. Consequently a note whose pitch, channel or
    duration changed can never fuzzy-match and is treated as a new entity.

    Each prior entity is handed out at most once. When several prior entities
    fall inside a note's window, the one with the smallest
    ``|Δvelocity| + |Δtick|`` wins; ties go to the entity that appears first
    in the prior index.

    Args:
        notes: New note list (from the current commit).
        prior_index: Entity index from the parent commit.
            ``None`` means this is the first commit
            for this track (all notes get new IDs).
        commit_id: Current commit ID (stored as provenance on new entities).
        op_id: Op log entry ID that produced these notes. Stored as
            ``origin_op_id`` on every returned entity, including matched
            ones, because the index does not record the original op ID.
        mutation_threshold_ticks: Max |Δtick| for fuzzy timing match.
        mutation_threshold_velocity: Max |Δvelocity| for fuzzy match.

    Returns:
        List of :class:`NoteEntity` objects in the same order as *notes*,
        each with a populated ``entity_id``.
    """
    if prior_index is None:
        return [_new_entity(n, commit_id, op_id) for n in notes]

    prior_entities = prior_index["entities"]

    # content_id → entity_id for the exact tier (last entity wins when two
    # prior entities share a content hash), plus content_id → every entity
    # with that hash, tagged with its position in the index, for the fuzzy
    # tier's deterministic tie-break.
    content_to_entity: _ContentMap = {}
    prior_by_content: dict[str, list[tuple[int, str]]] = {}
    for rank, (entity_id, entry) in enumerate(prior_entities.items()):
        cid = entry["content_id"]
        content_to_entity[cid] = entity_id
        prior_by_content.setdefault(cid, []).append((rank, entity_id))

    assigned: dict[int, str] = {}  # index in *notes* → entity_id
    used_entities: set[str] = set()

    # --- Tier 1: exact content match ---
    for i, note in enumerate(notes):
        eid = content_to_entity.get(_note_content_id(note))
        if eid is not None and eid not in used_entities:
            assigned[i] = eid
            used_entities.add(eid)

    # --- Tier 2: fuzzy match for unassigned notes ---
    for i, note in enumerate(notes):
        if i in assigned:
            continue
        if len(used_entities) >= len(prior_entities):
            break  # every prior entity is taken; the rest are new entities

        # (score, index position, entity_id) of the best candidate so far.
        best: tuple[int, int, str] | None = None
        window = _fuzzy_content_ids(
            note, mutation_threshold_ticks, mutation_threshold_velocity
        )
        for cid, score in window.items():
            for rank, eid in prior_by_content.get(cid, ()):
                if eid in used_entities:
                    continue
                if best is None or (score, rank) < best[:2]:
                    best = (score, rank, eid)
                # Entries are in index order, so the first unused one is the
                # best this content hash can offer.
                break

        if best is not None:
            assigned[i] = best[2]
            used_entities.add(best[2])

    # --- Build output ---
    result: list[NoteEntity] = []
    for i, note in enumerate(notes):
        if i in assigned:
            prior_entry = prior_entities[assigned[i]]
            entity: NoteEntity = NoteEntity(
                pitch=note["pitch"],
                velocity=note["velocity"],
                start_tick=note["start_tick"],
                duration_ticks=note["duration_ticks"],
                channel=note["channel"],
                entity_id=assigned[i],
                origin_commit_id=prior_entry["origin_commit_id"],
                origin_op_id=op_id,
                voice_id=prior_entry.get("voice_id", ""),
            )
        else:
            entity = _new_entity(note, commit_id, op_id)
        result.append(entity)

    return result


def _fuzzy_content_ids(
    note: NoteKey,
    tick_threshold: int,
    velocity_threshold: int,
) -> dict[str, int]:
    """Map each content ID inside *note*'s fuzzy window to its distance.

    The window is every variant of *note* whose velocity and start tick differ
    by at most the given thresholds. Velocity is clamped to ``0..127`` and
    start tick to ``>= 0``. The distance is ``|Δvelocity| + |Δtick|`` using
    the smallest deltas that produce the variant.
    """
    # Clamping can make several deltas land on the same (velocity, tick)
    # pair; keep the smallest distance and hash each distinct pair only once.
    nearest: dict[tuple[int, int], int] = {}
    for vel_delta in range(-velocity_threshold, velocity_threshold + 1):
        velocity = max(0, min(127, note["velocity"] + vel_delta))
        for tick_delta in range(-tick_threshold, tick_threshold + 1):
            start_tick = max(0, note["start_tick"] + tick_delta)
            score = abs(vel_delta) + abs(tick_delta)
            pair = (velocity, start_tick)
            if pair not in nearest or score < nearest[pair]:
                nearest[pair] = score

    return {
        _note_content_id(
            NoteKey(
                pitch=note["pitch"],
                velocity=velocity,
                start_tick=start_tick,
                duration_ticks=note["duration_ticks"],
                channel=note["channel"],
            )
        ): score
        for (velocity, start_tick), score in nearest.items()
    }
292
def _new_entity(note: NoteKey, commit_id: str, op_id: str) -> NoteEntity:
    """Wrap *note* as a brand-new :class:`NoteEntity`.

    The ``entity_id`` is the content hash of the note's pitch, start tick and
    channel together with *op_id*; velocity and duration do not contribute.
    *commit_id* and *op_id* are recorded as the entity's origin and the voice
    is left unassigned (empty string).
    """
    fresh: NoteEntity = {
        "pitch": note["pitch"],
        "velocity": note["velocity"],
        "start_tick": note["start_tick"],
        "duration_ticks": note["duration_ticks"],
        "channel": note["channel"],
        # Key order of the hashed mapping is kept as-is in case the hash
        # helper is sensitive to it.
        "entity_id": content_hash({
            "pitch": note["pitch"],
            "start_tick": note["start_tick"],
            "channel": note["channel"],
            "op_id": op_id,
        }),
        "origin_commit_id": commit_id,
        "origin_op_id": op_id,
        "voice_id": "",
    }
    return fresh
311
312 # ---------------------------------------------------------------------------
313 # Entity-aware diff
314 # ---------------------------------------------------------------------------
315
def diff_with_entity_ids(
    base_entities: list[NoteEntity],
    target_entities: list[NoteEntity],
    ticks_per_beat: int,
) -> list[DomainOp]:
    """Diff two note lists by entity identity rather than by content alone.

    Notes carrying a non-empty ``entity_id`` are "tracked" and compared by
    that ID:

    * present on both sides with different content → ``MutateOp``
    * present on both sides with identical content → no op
    * present only in *base_entities* → ``DeleteOp``
    * present only in *target_entities* → ``InsertOp``

    Notes without an ``entity_id`` are "untracked" and fall back to a
    content-hash set comparison that yields only deletes and inserts.

    If one list contains the same ``entity_id`` more than once, the last
    occurrence is the one that is compared.

    Args:
        base_entities: Notes from the ancestor commit.
        target_entities: Notes from the current commit.
        ticks_per_beat: Passed through to the human-readable summaries.

    Returns:
        Ops for tracked notes in ascending ``entity_id`` order, followed by
        untracked deletes (base order), then untracked inserts (target order).
    """

    def split_by_tracking(
        entities: list[NoteEntity],
    ) -> tuple[dict[str, NoteEntity], list[NoteEntity]]:
        """Partition *entities* into an id-keyed map and an untracked list."""
        tracked: dict[str, NoteEntity] = {}
        untracked: list[NoteEntity] = []
        for item in entities:
            ident = item.get("entity_id")
            if ident:
                tracked[ident] = item
            else:
                untracked.append(item)
        return tracked, untracked

    base_tracked, base_untracked = split_by_tracking(base_entities)
    target_tracked, target_untracked = split_by_tracking(target_entities)

    ops: list[DomainOp] = []

    # --- Tracked notes: one decision per entity ID, in sorted order ---
    for eid in sorted(base_tracked.keys() | target_tracked.keys()):
        before = base_tracked.get(eid)
        after = target_tracked.get(eid)
        address = f"note:entity:{eid}"

        if before is None:
            # Only the target has it: insertion.
            assert after is not None
            after_key = _entity_to_key(after)
            ops.append(
                InsertOp(
                    op="insert",
                    address=address,
                    position=None,
                    content_id=_note_content_id(after_key),
                    content_summary=_note_summary(after_key, ticks_per_beat),
                )
            )
            continue

        before_key = _entity_to_key(before)
        before_cid = _note_content_id(before_key)

        if after is None:
            # Only the base has it: deletion.
            ops.append(
                DeleteOp(
                    op="delete",
                    address=address,
                    position=None,
                    content_id=before_cid,
                    content_summary=_note_summary(before_key, ticks_per_beat),
                )
            )
            continue

        after_key = _entity_to_key(after)
        after_cid = _note_content_id(after_key)
        if before_cid == after_cid:
            continue  # same entity, same content: nothing to report

        changed = _field_diff(before, after)
        ops.append(
            MutateOp(
                op="mutate",
                address=address,
                entity_id=eid,
                old_content_id=before_cid,
                new_content_id=after_cid,
                fields=changed,
                old_summary=_note_summary(before_key, ticks_per_beat),
                new_summary=_note_summary(after_key, ticks_per_beat),
                position=None,
            )
        )

    # --- Untracked notes: content-hash membership only ---
    base_loose = [
        (_note_content_id(key), key) for key in map(_entity_to_key, base_untracked)
    ]
    target_loose = [
        (_note_content_id(key), key) for key in map(_entity_to_key, target_untracked)
    ]
    base_hashes = {cid for cid, _ in base_loose}
    target_hashes = {cid for cid, _ in target_loose}

    for cid, key in base_loose:
        if cid in target_hashes:
            continue
        ops.append(
            DeleteOp(
                op="delete",
                address="note:untracked",
                position=None,
                content_id=cid,
                content_summary=_note_summary(key, ticks_per_beat),
            )
        )

    for cid, key in target_loose:
        if cid in base_hashes:
            continue
        ops.append(
            InsertOp(
                op="insert",
                address="note:untracked",
                position=None,
                content_id=cid,
                content_summary=_note_summary(key, ticks_per_beat),
            )
        )

    return ops
450
def _entity_to_key(entity: NoteEntity) -> NoteKey:
    """Project *entity* down to a plain ``NoteKey``.

    Copies only pitch, velocity, start tick, duration and channel; the entity
    metadata (IDs, origin, voice) is dropped.
    """
    pitch, velocity = entity["pitch"], entity["velocity"]
    start_tick, duration_ticks = entity["start_tick"], entity["duration_ticks"]
    channel = entity["channel"]
    return NoteKey(
        pitch=pitch,
        velocity=velocity,
        start_tick=start_tick,
        duration_ticks=duration_ticks,
        channel=channel,
    )
460
def _field_diff(base: NoteEntity, target: NoteEntity) -> _FieldMutationMap:
    """Describe which of the five note fields differ between two entities.

    Returns a mapping from field name to a ``FieldMutation`` holding the old
    and new values as strings. Fields with equal values are omitted, so the
    result is empty when the notes have identical content.
    """
    # Read each field by literal key (no variable-key TypedDict access).
    before = (
        base["pitch"],
        base["velocity"],
        base["start_tick"],
        base["duration_ticks"],
        base["channel"],
    )
    after = (
        target["pitch"],
        target["velocity"],
        target["start_tick"],
        target["duration_ticks"],
        target["channel"],
    )
    labels = ("pitch", "velocity", "start_tick", "duration_ticks", "channel")
    return {
        label: FieldMutation(old=str(old), new=str(new))
        for label, old, new in zip(labels, before, after)
        if old != new
    }
476
477 # ---------------------------------------------------------------------------
478 # Entity index I/O
479 # ---------------------------------------------------------------------------
480
def build_entity_index(
    entities: list[NoteEntity],
    track_path: str,
    commit_id: str,
) -> EntityIndex:
    """Summarise *entities* into an :class:`EntityIndex` for one track.

    Each note with a non-empty ``entity_id`` contributes one entry holding its
    current content hash, its origin commit (defaulting to *commit_id* when
    the note carries none) and its voice (defaulting to ``""``). Notes with a
    missing or empty ``entity_id`` are left out. When an ID repeats, the last
    note with that ID supplies the entry.

    Args:
        entities: Note entities from the current commit.
        track_path: Workspace-relative MIDI file path.
        commit_id: Current commit ID.

    Returns:
        A populated :class:`EntityIndex`.
    """
    records: _EntityMap = {
        eid: EntityIndexEntry(
            content_id=_note_content_id(_entity_to_key(note)),
            origin_commit_id=note.get("origin_commit_id", commit_id),
            voice_id=note.get("voice_id", ""),
        )
        for note in entities
        if (eid := note.get("entity_id", ""))
    }
    return EntityIndex(track_path=track_path, commit_id=commit_id, entities=records)
514
def _index_path(repo_root: pathlib.Path, commit_id: str, track_path: str) -> pathlib.Path:
    """Return the on-disk location of the index for one track at one commit.

    Layout: ``<entity_index_dir>/<commit_id[:16]>/<flattened track>_<hash8>.json``.
    The track path is flattened by turning every ``/`` and ``.`` into ``_``.
    Because flattening can make different paths collide, eight characters
    derived from the track path's blob ID are appended to tell them apart.
    """
    # One pass that maps both "/" and "." to "_".
    flattened = track_path.translate(str.maketrans("/.", "__"))
    # presumably split_id returns (prefix, digest) — only element 1 is used.
    digest = split_id(blob_id(track_path.encode()))[1]
    filename = f"{flattened}_{digest[:8]}.json"
    return entity_index_dir(repo_root) / commit_id[:16] / filename
523
def write_entity_index(
    repo_root: pathlib.Path,
    commit_id: str,
    track_path: str,
    index: EntityIndex,
) -> None:
    """Persist *index* as pretty-printed JSON under the entity index directory.

    The file location is computed by :func:`_index_path`:
    ``<commit_id[:16]>/<flattened track>_<hash8>.json`` inside the entity
    index directory. Creates parent directories as needed. Safe to call
    multiple times — an existing file is overwritten.

    Args:
        repo_root: Repository root.
        commit_id: Commit ID for the snapshot this index belongs to.
        track_path: Workspace-relative MIDI file path.
        index: The entity index to persist.
    """
    path = _index_path(repo_root, commit_id, track_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Pin the encoding so the file does not depend on the platform locale.
    path.write_text(f"{json.dumps(index, indent=2)}\n", encoding="utf-8")
    logger.debug(
        "✅ Entity index written: %d entities for %r @ %s",
        len(index["entities"]),
        track_path,
        short_id(commit_id, strip=True),
    )
550
def read_entity_index(
    repo_root: pathlib.Path,
    commit_id: str,
    track_path: str,
) -> EntityIndex | None:
    """Load the stored entity index for *track_path* at *commit_id*, if any.

    Args:
        repo_root: Repository root.
        commit_id: Commit ID.
        track_path: Workspace-relative MIDI file path.

    Returns:
        The :class:`EntityIndex`, or ``None`` when no index file exists or
        when the loader returns ``None`` for it (a warning is logged in the
        latter case). The loaded data is returned as-is, without checking its
        shape.
    """
    index_file = _index_path(repo_root, commit_id, track_path)
    if index_file.exists():
        loaded: EntityIndex | None = load_json_file(index_file)
        if loaded is not None:
            return loaded
        # The file is there but the loader could not produce a value.
        logger.warning("⚠️ Corrupt entity index %s: unreadable or invalid JSON", index_file)
    return None
File History 1 commit
sha256:2eaa5d95f9d9383498e76947410a26e5a3ba23d182f339910c424cf88fad412b fix: try fetch/presign before fetch/mpack to avoid Cloudfla… Sonnet 4.6 patch 8 days ago