gabriel / muse public
_invariants.py python
576 lines 19.6 KB
Raw
sha256:2eaa5d95f9d9383498e76947410a26e5a3ba23d182f339910c424cf88fad412b fix: try fetch/presign before fetch/mpack to avoid Cloudfla… Sonnet 4.6 patch 7 days ago
1 """Musical invariants engine for the Muse MIDI plugin.
2
3 Invariants are semantic rules that a MIDI track must satisfy. They are
4 evaluated at commit time, merge time, or on-demand via ``muse music-check``.
5 Violations are reported with human-readable descriptions, severity levels,
6 and structured addresses for programmatic consumers.
7
8 Rule file format (TOML)
9 -----------------------
10 Rules are declared in ``.muse/midi_invariants.toml`` (default path).
11 Example::
12
13 [[rule]]
14 name = "max_polyphony"
15 severity = "error"
16 scope = "track"
17 rule_type = "max_polyphony"
18
19 [rule.params]
20 max_simultaneous = 6
21
22 [[rule]]
23 name = "keep_in_range"
24 severity = "warning"
25 scope = "track"
26 rule_type = "pitch_range"
27
28 [rule.params]
29 min_pitch = 24
30 max_pitch = 108
31
32 [[rule]]
33 name = "no_fifths"
34 severity = "warning"
35 scope = "voice_pair"
36 rule_type = "no_parallel_fifths"
37
38 [[rule]]
39 name = "consistent_key"
40 severity = "info"
41 scope = "track"
42 rule_type = "key_consistency"
43
44 [rule.params]
45 threshold = 0.15
46
47 Built-in rule types
48 -------------------
49
50 ``max_polyphony``
51 Detects bars where more than *max_simultaneous* notes overlap at any
52 tick position. Uses a sweep-line algorithm over start/end tick events.
53
54 ``pitch_range``
55 Detects any note with ``pitch < min_pitch`` or ``pitch > max_pitch``.
56
57 ``key_consistency``
58 Detects notes whose pitch class is highly inconsistent with the key
59 estimated by the Krumhansl-Schmuckler algorithm. Fires when the ratio
60 of "foreign" pitch classes exceeds *threshold*.
61
62 ``no_parallel_fifths``
63 Detects consecutive bars where the two lowest voices are a perfect fifth
64 apart in both bars and move in the same direction (a classical counterpoint
65 violation). Best-effort heuristic — voice assignment is implicit.
66
67 Severity levels
68 ---------------
69 - ``"error"`` — must be resolved before committing (when ``--strict`` is set).
70 - ``"warning"`` — reported but does not block commits.
71 - ``"info"`` — informational; surfaced in ``muse music-check`` output only.
72
73 Public API
74 ----------
75 - :class:`InvariantRule` — rule declaration TypedDict.
76 - :class:`InvariantViolation` — single violation record TypedDict.
77 - :class:`InvariantReport` — full report for one commit / track.
78 - :func:`load_invariant_rules` — load from TOML file with defaults fallback.
79 - :func:`run_invariants` — evaluate all rules against a commit.
80 """
81
82 import logging
83 import pathlib
84 from typing import Literal, TypedDict
85
86 from muse.core.invariants import BaseReport, BaseViolation, make_report
87 from muse.core.object_store import read_object
88 from muse.core.snapshots import get_commit_snapshot_manifest
89 from muse.plugins.midi._query import NoteInfo, key_signature_guess, notes_by_bar
90 from muse.plugins.midi.midi_diff import extract_notes
91
92 type _ScopeMap = dict[str, "Literal['track', 'bar', 'voice_pair', 'global']"]
93 type _SeverityMap = dict[str, "Literal['info', 'warning', 'error']"]
94 type _ParamMap = dict[str, str | int | float]
95 logger = logging.getLogger(__name__)
96
97
98 # ---------------------------------------------------------------------------
99 # Types
100 # ---------------------------------------------------------------------------
101
class _InvariantRuleRequired(TypedDict):
    """Keys that every invariant rule declaration must provide."""

    # Rule identifier, copied onto each violation the rule produces.
    name: str
    # Severity attached to violations of this rule.
    severity: Literal["info", "warning", "error"]
    # Granularity the rule is declared at.
    scope: Literal["track", "bar", "voice_pair", "global"]
    # Selects which built-in check evaluates the rule.
    rule_type: str
107
class InvariantRule(_InvariantRuleRequired, total=False):
    """One MIDI invariant rule as declared in a rule set.

    Required keys (inherited):

    ``name``       Human-readable rule identifier (unique within a rule set).
    ``severity``   One of ``"info"``, ``"warning"``, ``"error"``.
    ``scope``      One of ``"track"``, ``"bar"``, ``"voice_pair"``, ``"global"``.
    ``rule_type``  Built-in type string: ``"max_polyphony"``, ``"pitch_range"``,
                   ``"key_consistency"``, ``"no_parallel_fifths"``.

    Optional key:

    ``params``     Rule-specific parameter dict.
    """

    # Optional because the class is declared with ``total=False``.
    params: _ParamMap
120
class InvariantViolation(TypedDict):
    """Record describing one violation of one rule.

    ``rule_name``    Name of the rule that fired.
    ``severity``     Severity level taken from the rule declaration.
    ``track``        Workspace-relative MIDI file path.
    ``bar``          1-indexed bar number (0 for track-level violations).
    ``description``  Human-readable explanation of what was violated.
    ``addresses``    Note addresses or other domain addresses involved.
    """

    # Identification of the rule and how serious the finding is.
    rule_name: str
    severity: Literal["info", "warning", "error"]
    # Where the violation occurred.
    track: str
    bar: int
    # Human-readable text plus structured addresses for programmatic use.
    description: str
    addresses: list[str]
138
class InvariantReport(TypedDict):
    """Result of checking one commit against a rule set.

    ``commit_id``      The commit that was checked.
    ``violations``     All violations found, sorted by track then bar.
    ``rules_checked``  Number of rules evaluated.
    ``has_errors``     True when any violation has severity ``"error"``.
    ``has_warnings``   True when any violation has severity ``"warning"``.
    """

    commit_id: str
    violations: list[InvariantViolation]
    rules_checked: int
    # Summary flags derived from the severities present in ``violations``.
    has_errors: bool
    has_warnings: bool
154
155 # ---------------------------------------------------------------------------
156 # Built-in rule implementations
157 # ---------------------------------------------------------------------------
158
def check_max_polyphony(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
    *,
    max_simultaneous: int = 6,
) -> list[InvariantViolation]:
    """Report every bar whose peak polyphony exceeds *max_simultaneous*.

    Each bar's notes are turned into note-on / note-off tick events and swept
    in tick order while tracking the running count of sounding notes. At most
    one violation is emitted per bar, pointing at the first tick where the
    bar's peak was reached.

    Args:
        notes: All notes in the track.
        track: Track file path for violation records.
        rule_name: Rule identifier string.
        severity: Violation severity.
        max_simultaneous: Maximum allowed simultaneous notes.

    Returns:
        List of :class:`InvariantViolation` records, in ascending bar order.
    """
    found: list[InvariantViolation] = []
    per_bar = notes_by_bar(notes)

    # NOTE(review): only the notes that notes_by_bar files under a bar are
    # swept for it. If grouping is by onset bar, a note still sounding from an
    # earlier bar is not counted here — confirm against notes_by_bar.
    for bar_num in sorted(per_bar):
        # (tick, +1) for a note starting, (tick, -1) for a note ending.
        # Plain tuple ordering puts -1 before +1 on equal ticks, so a note
        # ending exactly where another starts does not count as an overlap.
        timeline = sorted(
            event
            for n in per_bar[bar_num]
            for event in (
                (n.start_tick, 1),
                (n.start_tick + n.duration_ticks, -1),
            )
        )

        sounding = 0
        peak, peak_tick = 0, 0
        for tick, delta in timeline:
            sounding += delta
            if sounding > peak:
                peak, peak_tick = sounding, tick

        if peak <= max_simultaneous:
            continue

        found.append(
            InvariantViolation(
                rule_name=rule_name,
                severity=severity,
                track=track,
                bar=bar_num,
                description=(
                    f"Polyphony reached {peak} simultaneous notes at tick {peak_tick} "
                    f"(max allowed: {max_simultaneous})"
                ),
                addresses=[f"bar:{bar_num}:tick:{peak_tick}"],
            )
        )

    return found
218
def check_pitch_range(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
    *,
    min_pitch: int = 0,
    max_pitch: int = 127,
) -> list[InvariantViolation]:
    """Flag each note whose MIDI pitch lies outside ``[min_pitch, max_pitch]``.

    Args:
        notes: All notes in the track.
        track: Track file path.
        rule_name: Rule identifier.
        severity: Violation severity.
        min_pitch: Lowest allowed MIDI pitch (inclusive).
        max_pitch: Highest allowed MIDI pitch (inclusive).

    Returns:
        One :class:`InvariantViolation` per out-of-range note, in input order.
    """
    return [
        InvariantViolation(
            rule_name=rule_name,
            severity=severity,
            track=track,
            bar=n.bar,
            description=(
                f"Note {n.pitch_name} (MIDI {n.pitch}) is outside "
                f"allowed range [{min_pitch}, {max_pitch}]"
            ),
            addresses=[f"bar:{n.bar}:pitch:{n.pitch}"],
        )
        for n in notes
        # Both bounds are inclusive, so only pitches strictly outside fire.
        if not (min_pitch <= n.pitch <= max_pitch)
    ]
258
def check_key_consistency(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
    *,
    threshold: float = 0.15,
) -> list[InvariantViolation]:
    """Flag a track whose notes stray too far from its guessed key.

    The key comes from :func:`key_signature_guess`. The fraction of notes
    whose pitch class is not diatonic to that key is then compared with
    *threshold*; the rule fires only when the fraction is strictly greater.

    Args:
        notes: All notes in the track.
        track: Track file path.
        rule_name: Rule identifier.
        severity: Violation severity.
        threshold: Maximum allowed ratio of foreign pitch classes (0.0–1.0).

    Returns:
        Zero or one :class:`InvariantViolation` for the track. Empty when
        there are no notes or the key guess cannot be interpreted.
    """
    if not notes:
        return []

    key_guess = key_signature_guess(notes)
    # Expected shape is "<root> <mode>", e.g. "G major" or "D minor".
    tokens = key_guess.split()
    if len(tokens) < 2:
        return []
    root_name, mode = tokens[0], tokens[1]

    # Only sharp spellings are recognised; any other root name bails out.
    sharp_names = ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B")
    try:
        root_idx = sharp_names.index(root_name)
    except ValueError:
        return []

    # Semitone offsets of the major scale; every other mode is treated as
    # natural minor.
    if mode == "major":
        steps = [0, 2, 4, 5, 7, 9, 11]
    else:
        steps = [0, 2, 3, 5, 7, 8, 10]
    diatonic_pcs = frozenset((root_idx + step) % 12 for step in steps)

    total = len(notes)
    foreign = len([n for n in notes if n.pitch_class not in diatonic_pcs])
    ratio = foreign / total

    if not ratio > threshold:
        return []

    return [
        InvariantViolation(
            rule_name=rule_name,
            severity=severity,
            track=track,
            bar=0,
            description=(
                f"{foreign}/{len(notes)} notes ({ratio:.0%}) use pitch classes "
                f"foreign to estimated key {key_guess} "
                f"(threshold: {threshold:.0%})"
            ),
            addresses=[track],
        )
    ]
325
def check_no_parallel_fifths(
    notes: list[NoteInfo],
    track: str,
    rule_name: str,
    severity: Literal["info", "warning", "error"],
) -> list[InvariantViolation]:
    """Detect parallel perfect fifths between neighbouring bars.

    Heuristic: in each bar the two lowest-pitched notes stand in for the bass
    and tenor voices. A violation is reported for a pair of neighbouring bars
    when those two voices are a perfect fifth apart (7 semitones modulo the
    octave) in both bars and both voices move in the same direction.

    Bars are paired in sorted order of the keys returned by
    :func:`notes_by_bar`. This is a best-effort approximation — accurate
    voice separation would require dedicated voice-leading analysis.

    Args:
        notes: All notes in the track.
        track: Track file path.
        rule_name: Rule identifier.
        severity: Violation severity.

    Returns:
        One :class:`InvariantViolation` per detected parallel-fifth bar pair,
        attributed to the second bar of the pair.
    """
    found: list[InvariantViolation] = []
    per_bar = notes_by_bar(notes)
    bar_order = sorted(per_bar.keys())

    for bar_a, bar_b in zip(bar_order, bar_order[1:]):
        low_a = sorted(per_bar[bar_a], key=lambda n: n.pitch)
        low_b = sorted(per_bar[bar_b], key=lambda n: n.pitch)

        # Two voices are needed in each bar to form an interval.
        if len(low_a) < 2 or len(low_b) < 2:
            continue

        bass_a, tenor_a = low_a[0].pitch, low_a[1].pitch
        bass_b, tenor_b = low_b[0].pitch, low_b[1].pitch

        # The voices must be a fifth apart (octave-reduced) in both bars.
        if abs(tenor_a - bass_a) % 12 != 7:
            continue
        if abs(tenor_b - bass_b) % 12 != 7:
            continue

        # A positive product of the two integer motions means both are
        # non-zero and share a sign, i.e. both voices rose or both fell.
        bass_motion = bass_b - bass_a
        tenor_motion = tenor_b - tenor_a
        if bass_motion * tenor_motion <= 0:
            continue

        found.append(
            InvariantViolation(
                rule_name=rule_name,
                severity=severity,
                track=track,
                bar=bar_b,
                description=(
                    f"Parallel fifths between bars {bar_a} and {bar_b}: "
                    f"lower voice {low_a[0].pitch_name}→{low_b[0].pitch_name}, "
                    f"upper voice {low_a[1].pitch_name}→{low_b[1].pitch_name}"
                ),
                addresses=[f"bar:{bar_a}", f"bar:{bar_b}"],
            )
        )

    return found
393
394 # ---------------------------------------------------------------------------
395 # Rule loading
396 # ---------------------------------------------------------------------------
397
# Rules applied when no rule file is given, or when it is missing, empty or
# unreadable: a polyphony cap of 8 and the full MIDI pitch range 0-127, both
# reported as warnings.
_DEFAULT_RULE_SET: list[InvariantRule] = [
    {
        "name": "max_polyphony",
        "severity": "warning",
        "scope": "track",
        "rule_type": "max_polyphony",
        "params": {"max_simultaneous": 8},
    },
    {
        "name": "pitch_range",
        "severity": "warning",
        "scope": "track",
        "rule_type": "pitch_range",
        "params": {"min_pitch": 0, "max_pitch": 127},
    },
]
414
def load_invariant_rules(rules_file: pathlib.Path | None = None) -> list[InvariantRule]:
    """Load invariant rules from a TOML file, falling back to defaults.

    Requires ``tomllib`` (Python 3.11+) for TOML parsing. The default rule
    set is returned when *rules_file* is ``None``, does not exist, cannot be
    read or parsed, or declares no ``[[rule]]`` entries.

    Individual entries are loaded leniently. A missing ``name`` becomes
    ``"unnamed"``, a missing ``severity`` becomes ``"warning"`` and a missing
    ``scope`` becomes ``"track"``. A ``severity`` or ``scope`` value that is
    present but unrecognised gets the same fallback and is logged as a
    warning, so a typo cannot silently change how a rule is enforced. A
    ``params`` value that is not a table is logged and ignored.

    Args:
        rules_file: Path to the TOML rule file. ``None`` means use defaults.

    Returns:
        List of :class:`InvariantRule` dicts.
    """
    if rules_file is None or not rules_file.exists():
        return list(_DEFAULT_RULE_SET)

    # Identity maps narrow an arbitrary string to the Literal type, or to
    # None when unrecognised. Built once per call, not once per rule.
    valid_severities: _SeverityMap = {
        "info": "info", "warning": "warning", "error": "error",
    }
    valid_scopes: _ScopeMap = {
        "track": "track", "bar": "bar", "voice_pair": "voice_pair", "global": "global",
    }

    try:
        import tomllib

        with rules_file.open("rb") as fh:
            data = tomllib.load(fh)

        rules: list[InvariantRule] = []
        for raw in data.get("rule", []):
            name = str(raw.get("name", "unnamed"))

            sev = valid_severities.get(str(raw.get("severity", "")))
            if sev is None:
                # Stay quiet when the key is simply absent; warn on a bad value.
                if "severity" in raw:
                    logger.warning(
                        "⚠️ Rule %r in %s has unknown severity %r — using 'warning'",
                        name, rules_file, raw["severity"],
                    )
                sev = "warning"

            scope = valid_scopes.get(str(raw.get("scope", "")))
            if scope is None:
                if "scope" in raw:
                    logger.warning(
                        "⚠️ Rule %r in %s has unknown scope %r — using 'track'",
                        name, rules_file, raw["scope"],
                    )
                scope = "track"

            rule = InvariantRule(
                name=name,
                severity=sev,
                scope=scope,
                rule_type=str(raw.get("rule_type", "")),
            )
            if "params" in raw:
                if isinstance(raw["params"], dict):
                    rule["params"] = raw["params"]
                else:
                    # Consumers call ``.get`` on params, so a scalar here
                    # would only fail later, at check time.
                    logger.warning(
                        "⚠️ Rule %r in %s has non-table params %r — ignored",
                        name, rules_file, raw["params"],
                    )
            rules.append(rule)
        return rules if rules else list(_DEFAULT_RULE_SET)

    except Exception as exc:
        # Deliberately broad: rule loading is best-effort and must never
        # prevent a check from running with the defaults.
        logger.warning("⚠️ Could not load invariant rules from %s: %s", rules_file, exc)
        return list(_DEFAULT_RULE_SET)
460
461 # ---------------------------------------------------------------------------
462 # Main runner
463 # ---------------------------------------------------------------------------
464
def run_invariants(
    root: "pathlib.Path",
    commit_id: str,
    rules: list[InvariantRule],
    *,
    track_filter: str | None = None,
) -> InvariantReport:
    """Evaluate all *rules* against every MIDI track in *commit_id*.

    Tracks are the manifest paths ending in ``.mid`` (case-insensitive),
    optionally narrowed to the single path equal to *track_filter*. A track
    whose object is missing, or whose bytes fail to parse as MIDI, is skipped.
    Rules with an unrecognised ``rule_type`` are skipped with a debug log.

    Args:
        root: Repository root.
        commit_id: Commit to check.
        rules: List of :class:`InvariantRule` declarations.
        track_filter: Restrict check to a single MIDI file path.

    Returns:
        An :class:`InvariantReport` with all violations found, sorted by
        track then bar. ``rules_checked`` is ``len(rules)`` multiplied by the
        number of matching MIDI paths, including any that were skipped.
    """
    all_violations: list[InvariantViolation] = []
    manifest = get_commit_snapshot_manifest(root, commit_id) or {}

    midi_paths = [
        p for p in manifest
        if p.lower().endswith(".mid")
        and (track_filter is None or p == track_filter)
    ]

    for track_path in sorted(midi_paths):
        obj_hash = manifest.get(track_path)
        if obj_hash is None:
            continue
        raw = read_object(root, obj_hash)
        if raw is None:
            continue
        try:
            keys, tpb = extract_notes(raw)
        except ValueError as exc:
            logger.debug("Cannot parse MIDI %r: %s", track_path, exc)
            continue

        notes = [NoteInfo.from_note_key(k, tpb) for k in keys]

        for rule in rules:
            rt = rule["rule_type"]
            sev = rule["severity"]
            params = rule.get("params", {})
            name = rule["name"]

            # The fallbacks below apply when a rule omits a parameter.
            if rt == "max_polyphony":
                max_sim = int(params.get("max_simultaneous", 8))
                all_violations.extend(
                    check_max_polyphony(notes, track_path, name, sev, max_simultaneous=max_sim)
                )
            elif rt == "pitch_range":
                min_p = int(params.get("min_pitch", 0))
                max_p = int(params.get("max_pitch", 127))
                all_violations.extend(
                    check_pitch_range(notes, track_path, name, sev, min_pitch=min_p, max_pitch=max_p)
                )
            elif rt == "key_consistency":
                thresh = float(params.get("threshold", 0.15))
                all_violations.extend(
                    check_key_consistency(notes, track_path, name, sev, threshold=thresh)
                )
            elif rt == "no_parallel_fifths":
                all_violations.extend(
                    check_no_parallel_fifths(notes, track_path, name, sev)
                )
            else:
                logger.debug("Unknown rule_type %r in rule %r — skipped", rt, name)

    all_violations.sort(key=lambda v: (v["track"], v["bar"]))
    has_errors = any(v["severity"] == "error" for v in all_violations)
    has_warnings = any(v["severity"] == "warning" for v in all_violations)

    return InvariantReport(
        commit_id=commit_id,
        violations=all_violations,
        rules_checked=len(rules) * len(midi_paths),
        has_errors=has_errors,
        has_warnings=has_warnings,
    )
549
class MidiChecker:
    """MIDI implementation of :class:`~muse.core.invariants.InvariantChecker`.

    Adapts :func:`run_invariants` to the generic interface so that the
    ``muse check`` command can dispatch to the MIDI checker without knowing
    MIDI internals.
    """

    def check(
        self,
        repo_root: pathlib.Path,
        commit_id: str,
        *,
        rules_file: pathlib.Path | None = None,
    ) -> BaseReport:
        """Check *commit_id* and return a :class:`~muse.core.invariants.BaseReport`.

        Rules come from :func:`load_invariant_rules` given *rules_file*. Each
        MIDI violation is reduced to a ``BaseViolation`` whose address is the
        track path; the bar number and note addresses are not carried over.
        """
        report = run_invariants(repo_root, commit_id, load_invariant_rules(rules_file))

        converted: list[BaseViolation] = []
        for violation in report["violations"]:
            converted.append(
                BaseViolation(
                    rule_name=violation["rule_name"],
                    severity=violation["severity"],
                    address=violation["track"],
                    description=violation["description"],
                )
            )

        return make_report(commit_id, "midi", converted, report["rules_checked"])
File History 1 commit
sha256:2eaa5d95f9d9383498e76947410a26e5a3ba23d182f339910c424cf88fad412b fix: try fetch/presign before fetch/mpack to avoid Cloudfla… Sonnet 4.6 patch 7 days ago