gabriel / muse public
test_midi_diff.py python
345 lines 13.5 KB
Raw
sha256:1c4b3e3a9a1f300774c3ee662b572a698d5fd405bf765a71e6011a2e9c3eaaaa feat: Muse — version control for the agent era Human 73 days ago
1 """Tests for muse.plugins.midi.midi_diff — Myers LCS on MIDI note sequences.
2
3 Covers:
4 - NoteKey extraction from MIDI bytes.
5 - LCS edit script correctness (keep/insert/delete).
6 - LCS minimality (length of keep steps == LCS length).
7 - diff_midi_notes() produces correct StructuredDelta.
8 - Content IDs are deterministic and unique per note.
9 - Human-readable summaries and content_summary strings.
10 """
11
12 import io
13 import struct
14
15 import mido
16 import pytest
17
18 pytestmark = pytest.mark.midi
19
20 from muse.plugins.midi.midi_diff import (
21 NoteKey,
22 diff_midi_notes,
23 extract_notes,
24 lcs_edit_script,
25 )
26
27
28 # ---------------------------------------------------------------------------
29 # MIDI builder helpers
30 # ---------------------------------------------------------------------------
31
def _build_midi(notes: list[tuple[int, int, int, int]]) -> bytes:
    """Serialize ``(pitch, velocity, start, duration)`` tuples as type-0 MIDI bytes.

    ``start`` and ``duration`` are absolute tick counts; the file is written
    with ``ticks_per_beat=480``.  Every note becomes a ``note_on`` at ``start``
    and a ``note_off`` (velocity 0) at ``start + duration`` on a single track,
    which is closed with an ``end_of_track`` meta message.
    """
    # Flatten the notes into (absolute_tick, message_type, pitch, velocity).
    timeline: list[tuple[int, str, int, int]] = [
        event
        for pitch, velocity, start, duration in notes
        for event in (
            (start, "note_on", pitch, velocity),
            (start + duration, "note_off", pitch, 0),
        )
    ]
    # Stable sort on the tick alone: events sharing a tick keep input order.
    timeline = sorted(timeline, key=lambda event: event[0])

    track = mido.MidiTrack()
    cursor = 0  # absolute tick of the previously emitted message
    for abs_tick, kind, pitch, velocity in timeline:
        # MIDI messages carry the time elapsed since the previous message.
        track.append(
            mido.Message(kind, note=pitch, velocity=velocity, time=abs_tick - cursor)
        )
        cursor = abs_tick
    track.append(mido.MetaMessage("end_of_track", time=0))

    midi_file = mido.MidiFile(type=0, ticks_per_beat=480)
    midi_file.tracks.append(track)

    with io.BytesIO() as sink:
        midi_file.save(file=sink)
        return sink.getvalue()
60
61
def _note(pitch: int, velocity: int = 80, start: int = 0, duration: int = 480) -> NoteKey:
    """Build a channel-0 ``NoteKey``; everything except ``pitch`` has a default."""
    return NoteKey(
        pitch=pitch,
        velocity=velocity,
        start_tick=start,
        duration_ticks=duration,
        channel=0,
    )
67
68
69 # ---------------------------------------------------------------------------
70 # extract_notes
71 # ---------------------------------------------------------------------------
72
class TestExtractNotes:
    """Checks on ``extract_notes``, which returns ``(notes, ticks_per_beat)``."""

    def test_empty_midi_returns_no_notes(self) -> None:
        """A file built from zero notes yields an empty list and 480 ticks per beat."""
        extracted, ticks_per_beat = extract_notes(_build_midi([]))
        assert extracted == []
        assert ticks_per_beat == 480

    def test_single_note_extracted(self) -> None:
        """One note round-trips with its pitch, velocity, start and duration intact."""
        payload = _build_midi([(60, 80, 0, 480)])  # C4
        extracted, _tpb = extract_notes(payload)
        assert len(extracted) == 1
        only = extracted[0]
        assert only["pitch"] == 60
        assert only["velocity"] == 80
        assert only["start_tick"] == 0
        assert only["duration_ticks"] == 480

    def test_multiple_notes_extracted(self) -> None:
        """Three consecutive notes produce three extracted records."""
        melody = [
            (60, 80, 0, 480),
            (64, 90, 480, 480),
            (67, 70, 960, 480),
        ]
        extracted, _ = extract_notes(_build_midi(melody))
        assert len(extracted) == 3

    def test_notes_sorted_by_start_tick(self) -> None:
        """Notes supplied out of order come back in non-decreasing start order."""
        shuffled = [
            (67, 70, 960, 240),
            (60, 80, 0, 480),
            (64, 90, 480, 480),
        ]
        extracted, _ = extract_notes(_build_midi(shuffled))
        start_ticks = [record["start_tick"] for record in extracted]
        assert start_ticks == sorted(start_ticks)

    def test_invalid_bytes_raises_value_error(self) -> None:
        """Bytes that are not a MIDI file are rejected with ``ValueError``."""
        garbage = b"not a midi file"
        with pytest.raises(ValueError):
            extract_notes(garbage)

    def test_ticks_per_beat_is_returned(self) -> None:
        """The second element of the result is the file's ticks-per-beat."""
        result = extract_notes(_build_midi([(60, 80, 0, 480)]))
        _, ticks_per_beat = result
        assert ticks_per_beat == 480
116
117
118 # ---------------------------------------------------------------------------
119 # lcs_edit_script
120 # ---------------------------------------------------------------------------
121
class TestLCSEditScript:
    """Edit-script tests for ``lcs_edit_script``.

    Most sequences are built with ``_seq``, which sets ``start_tick`` equal to
    the pitch.  NoteKey equality requires all five fields to match, so tying
    the tick to the pitch means two notes with the same pitch are always the
    same key, wherever they sit in the list.  That keeps the expected edit
    scripts intuitive.
    """

    def _nk(self, pitch: int) -> NoteKey:
        """Return a NoteKey whose ``start_tick`` equals its own pitch."""
        # Same pitch -> same tick -> equal keys; other fields use _note's defaults.
        return _note(pitch, start=pitch)

    def _seq(self, pitches: list[int]) -> list[NoteKey]:
        """Turn a list of pitches into the matching list of ``_nk`` keys."""
        return list(map(self._nk, pitches))

    def test_identical_sequences_keeps_all(self) -> None:
        """Diffing a sequence against itself yields only keep steps."""
        same = self._seq([60, 62, 64])
        script = lcs_edit_script(same, same)
        assert [step.kind for step in script] == ["keep"] * 3

    def test_empty_to_sequence_all_inserts(self) -> None:
        """Empty base: every target note is an insert."""
        script = lcs_edit_script([], self._seq([60, 62]))
        assert all(step.kind == "insert" for step in script)
        assert len(script) == 2

    def test_sequence_to_empty_all_deletes(self) -> None:
        """Empty target: every base note is a delete."""
        script = lcs_edit_script(self._seq([60, 62]), [])
        assert all(step.kind == "delete" for step in script)
        assert len(script) == 2

    def test_single_insert_at_end(self) -> None:
        """[60, 62] -> [60, 62, 64]: keep 60, keep 62, insert 64."""
        script = lcs_edit_script(self._seq([60, 62]), self._seq([60, 62, 64]))
        kept = [step for step in script if step.kind == "keep"]
        inserted = [step for step in script if step.kind == "insert"]
        assert len(kept) == 2
        assert len(inserted) == 1
        assert inserted[0].note["pitch"] == 64

    def test_single_delete_from_middle(self) -> None:
        """[60, 62, 64] -> [60, 64]: keep 60, delete 62, keep 64."""
        # With start_tick == pitch, 64 in base equals 64 in target.
        script = lcs_edit_script(self._seq([60, 62, 64]), self._seq([60, 64]))
        deleted = [step for step in script if step.kind == "delete"]
        assert len(deleted) == 1
        assert deleted[0].note["pitch"] == 62

    def test_pitch_change_is_delete_plus_insert(self) -> None:
        """A changed pitch shows up as one delete plus one insert, never a keep."""
        script = lcs_edit_script([_note(60)], [_note(62)])
        seen_kinds = {step.kind for step in script}
        assert {"delete", "insert"} <= seen_kinds
        assert "keep" not in seen_kinds

    def test_lcs_is_minimal_keeps_equal_lcs_length(self) -> None:
        """The number of keep steps equals the LCS length."""
        # LCS of [60, 62, 64, 65] and [60, 64, 65, 67] is [60, 64, 65]:
        # those three keys exist, in order, on both sides.
        script = lcs_edit_script(
            self._seq([60, 62, 64, 65]),
            self._seq([60, 64, 65, 67]),
        )
        keep_count = sum(1 for step in script if step.kind == "keep")
        assert keep_count == 3

    def test_empty_both_returns_empty(self) -> None:
        """Two empty sequences produce an empty script."""
        assert lcs_edit_script([], []) == []

    def test_step_indices_are_consistent(self) -> None:
        """Base and target indices each appear in non-decreasing order."""
        script = lcs_edit_script(self._seq([60, 62, 64]), self._seq([60, 64]))
        # Inserts have no meaningful base index; deletes have no target index.
        base_side = [step.base_index for step in script if step.kind != "insert"]
        target_side = [step.target_index for step in script if step.kind != "delete"]
        assert base_side == sorted(base_side)
        assert target_side == sorted(target_side)

    def test_reorder_detected_as_delete_insert(self) -> None:
        """Swapping two pitches between fixed ticks leaves nothing to keep."""
        # Base plays 60 at tick 0 and 62 at tick 480; the target swaps the
        # pitches.  No (pitch, start_tick) pair exists on both sides, so no
        # note matches exactly and every step is a delete or an insert.
        base = [_note(60, start=0), _note(62, start=480)]
        target = [_note(62, start=0), _note(60, start=480)]
        script = lcs_edit_script(base, target)
        kept = [step for step in script if step.kind == "keep"]
        assert len(kept) == 0
224
225
226 # ---------------------------------------------------------------------------
227 # diff_midi_notes
228 # ---------------------------------------------------------------------------
229
class TestDiffMidiNotes:
    """End-to-end checks of ``diff_midi_notes`` on MIDI files built in memory."""

    # (pitch, velocity, start, duration) fixtures shared by the tests below.
    _C4 = (60, 80, 0, 480)
    _D4 = (62, 80, 0, 480)
    _E4 = (64, 80, 480, 480)

    def test_no_change_returns_empty_ops(self) -> None:
        """Diffing a file against itself produces no ops."""
        unchanged = _build_midi([self._C4])
        assert diff_midi_notes(unchanged, unchanged)["ops"] == []

    def test_no_change_summary(self) -> None:
        """The summary of an empty diff says there were no note changes."""
        unchanged = _build_midi([self._C4])
        result = diff_midi_notes(unchanged, unchanged)
        assert "no note changes" in result["summary"]

    def test_add_note_returns_insert_op(self) -> None:
        """One extra note in the target yields exactly one insert op."""
        before = _build_midi([self._C4])
        after = _build_midi([self._C4, self._E4])
        result = diff_midi_notes(before, after)
        insert_ops = [op for op in result["ops"] if op["op"] == "insert"]
        assert len(insert_ops) == 1

    def test_remove_note_returns_delete_op(self) -> None:
        """One note missing from the target yields exactly one delete op."""
        before = _build_midi([self._C4, self._E4])
        after = _build_midi([self._C4])
        result = diff_midi_notes(before, after)
        delete_ops = [op for op in result["ops"] if op["op"] == "delete"]
        assert len(delete_ops) == 1

    def test_change_pitch_produces_delete_and_insert(self) -> None:
        """Changing a note's pitch shows up as both a delete and an insert."""
        result = diff_midi_notes(_build_midi([self._C4]), _build_midi([self._D4]))
        op_kinds = {op["op"] for op in result["ops"]}
        assert "delete" in op_kinds
        assert "insert" in op_kinds

    def test_summary_mentions_added_notes(self) -> None:
        """The summary contains "added" when a note was added."""
        before = _build_midi([self._C4])
        after = _build_midi([self._C4, self._E4])
        assert "added" in diff_midi_notes(before, after)["summary"]

    def test_summary_mentions_removed_notes(self) -> None:
        """The summary contains "removed" when a note was removed."""
        before = _build_midi([self._C4, self._E4])
        after = _build_midi([self._C4])
        assert "removed" in diff_midi_notes(before, after)["summary"]

    def test_summary_singular_for_one_note(self) -> None:
        """A single added note is reported with the singular noun."""
        result = diff_midi_notes(_build_midi([]), _build_midi([self._C4]))
        assert "1 note added" in result["summary"]

    def test_summary_plural_for_multiple_notes(self) -> None:
        """Two added notes are reported with the plural noun."""
        result = diff_midi_notes(
            _build_midi([]),
            _build_midi([self._C4, self._E4]),
        )
        assert "2 notes added" in result["summary"]

    def test_content_id_is_deterministic(self) -> None:
        """Running the same diff twice gives the same content IDs in the same order."""
        one_note = _build_midi([self._C4])
        silence = _build_midi([])
        first_run = diff_midi_notes(silence, one_note)
        second_run = diff_midi_notes(silence, one_note)
        first_ids = [op["content_id"] for op in first_run["ops"]]
        second_ids = [op["content_id"] for op in second_run["ops"]]
        assert first_ids == second_ids

    def test_content_ids_differ_for_different_notes(self) -> None:
        """Notes differing only in pitch receive different content IDs."""
        silence = _build_midi([])
        c4_delta = diff_midi_notes(silence, _build_midi([self._C4]))
        d4_delta = diff_midi_notes(silence, _build_midi([self._D4]))
        assert c4_delta["ops"][0]["content_id"] != d4_delta["ops"][0]["content_id"]

    def test_content_summary_is_human_readable(self) -> None:
        """An op's content summary names the note (C4) and its velocity."""
        result = diff_midi_notes(_build_midi([]), _build_midi([self._C4]))
        text = result["ops"][0]["content_summary"]
        assert "C4" in text
        assert "vel=80" in text

    def test_domain_is_midi_notes(self) -> None:
        """The delta is tagged with the ``midi_notes`` domain."""
        one_note = _build_midi([self._C4])
        silence = _build_midi([])
        assert diff_midi_notes(silence, one_note)["domain"] == "midi_notes"

    def test_invalid_base_raises_value_error(self) -> None:
        """Unparseable base bytes raise ``ValueError``."""
        good = _build_midi([self._C4])
        with pytest.raises(ValueError):
            diff_midi_notes(b"garbage", good)

    def test_invalid_target_raises_value_error(self) -> None:
        """Unparseable target bytes raise ``ValueError``."""
        good = _build_midi([self._C4])
        with pytest.raises(ValueError):
            diff_midi_notes(good, b"garbage")

    def test_file_path_appears_in_content_summary_context(self) -> None:
        """Passing ``file_path`` must not crash and must not change the op count."""
        # Only the number of ops is checked here, not how file_path is used.
        result = diff_midi_notes(
            _build_midi([]),
            _build_midi([self._C4]),
            file_path="tracks/piano.mid",
        )
        assert len(result["ops"]) == 1

    def test_position_reflects_sequence_index(self) -> None:
        """Two inserted notes carry positions 0 and 1."""
        silence = _build_midi([])
        pair = _build_midi([self._C4, self._E4])
        result = diff_midi_notes(silence, pair)
        positions = [op["position"] for op in result["ops"]]
        assert all(index in positions for index in (0, 1))
File History 1 commit
sha256:1c4b3e3a9a1f300774c3ee662b572a698d5fd405bf765a71e6011a2e9c3eaaaa feat: Muse — version control for the agent era Human 73 days ago