gabriel / muse public
midi_shard.py python
169 lines 6.5 KB
Raw
1 """muse shard — partition a MIDI composition into bar-range shards for parallel agents.
2
3 Splits a track into N non-overlapping bar-range segments and writes each
4 shard as a separate MIDI file. An agent swarm can then work on shards in
5 parallel with zero risk of note-level conflicts, merging the shards back
6 together with ``muse mix``.
7
8 Usage::
9
10 muse shard tracks/full.mid --shards 4
11 muse shard tracks/full.mid --shards 8 --output-dir shards/
12 muse shard tracks/full.mid --bars-per-shard 16
13 muse shard tracks/full.mid --shards 4 --dry-run
14
15 Output::
16
17 Shard plan: tracks/full.mid → 4 shards
18 Total bars: 32 · ~8 bars per shard
19
20 Shard 0 bars 1– 8 → shards/full_shard_0.mid (28 notes)
21 Shard 1 bars 9–16 → shards/full_shard_1.mid (31 notes)
22 Shard 2 bars 17–24 → shards/full_shard_2.mid (24 notes)
23 Shard 3 bars 25–32 → shards/full_shard_3.mid (19 notes)
24
25 ✅ 4 shards written to shards/
26 """
27
28 import argparse
29 import logging
30 import pathlib
31 import sys
32
33 from muse.core.errors import ExitCode
34 from muse.core.repo import require_repo
35 from muse.core.validation import contain_path
36 from muse.plugins.midi._query import NoteInfo, load_track_from_workdir, notes_by_bar, notes_to_midi_bytes
37
38 logger = logging.getLogger(__name__)
39
40 def _shard_notes(
41 notes: list[NoteInfo],
42 bar_ranges: list[tuple[int, int]],
43 ) -> list[list[NoteInfo]]:
44 """Partition notes into groups by bar range, rebasing start ticks to 0."""
45 bars = notes_by_bar(notes)
46 if not notes:
47 return [[] for _ in bar_ranges]
48
49 tpb = notes[0].ticks_per_beat
50
51 shards: list[list[NoteInfo]] = []
52 for lo_bar, hi_bar in bar_ranges:
53 shard_notes: list[NoteInfo] = []
54 bar_offset = (lo_bar - 1) * 4 * tpb
55 for bar_num in range(lo_bar, hi_bar + 1):
56 for note in bars.get(bar_num, []):
57 rebased_tick = max(0, note.start_tick - bar_offset)
58 shard_notes.append(NoteInfo(
59 pitch=note.pitch,
60 velocity=note.velocity,
61 start_tick=rebased_tick,
62 duration_ticks=note.duration_ticks,
63 channel=note.channel,
64 ticks_per_beat=note.ticks_per_beat,
65 ))
66 shards.append(shard_notes)
67 return shards
68
def register(subparsers: "argparse._SubParsersAction[argparse.ArgumentParser]") -> None:
    """Attach the ``shard`` subcommand and its options to *subparsers*.

    The module docstring is used as the long description (rendered with the
    raw formatter so its layout is kept), and ``run`` is installed as the
    handler through the ``func`` default.
    """
    shard_cmd = subparsers.add_parser(
        "shard",
        help="Split a MIDI track into N bar-range shards for parallel agent work.",
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # (names/flags, keyword options), listed in the order they are registered.
    argument_specs = (
        (
            ("track",),
            {"metavar": "TRACK", "help": "Workspace-relative path to a .mid file."},
        ),
        (
            ("--shards", "-n"),
            {
                "metavar": "N",
                "type": int,
                "default": None,
                "dest": "num_shards",
                "help": "Number of shards to split into (mutually exclusive with --bars-per-shard).",
            },
        ),
        (
            ("--bars-per-shard", "-b"),
            {
                "metavar": "N",
                "type": int,
                "default": None,
                "help": "Bars per shard (mutually exclusive with --shards).",
            },
        ),
        (
            ("--output-dir", "-o"),
            {
                "metavar": "DIR",
                "default": "shards",
                "help": "Directory to write shard files (default: shards/).",
            },
        ),
        (
            ("--dry-run",),
            {"action": "store_true", "help": "Preview shard plan without writing."},
        ),
    )
    for flags, options in argument_specs:
        shard_cmd.add_argument(*flags, **options)

    shard_cmd.set_defaults(func=run)
78
def run(args: argparse.Namespace) -> None:
    """Split a MIDI track into N bar-range shards for parallel agent work.

    ``muse shard`` is the musical equivalent of partitioning a codebase for
    a parallelised agent swarm. Each shard is a valid MIDI file covering a
    non-overlapping bar range. Agents work on shards independently, then
    the shards are recombined with ``muse mix``.

    Specify either ``--shards N`` (the last shard absorbs any remainder
    bars) or ``--bars-per-shard N`` (fixed shard size with a remainder shard
    at the end). With neither option the track is split into 4 shards.
    Values below 1 are clamped to 1, and fewer shards than requested are
    produced when the track spans fewer bars than ``--shards``.

    Args:
        args: Parsed command-line namespace providing ``track``,
            ``num_shards``, ``bars_per_shard``, ``output_dir`` and
            ``dry_run``.

    Raises:
        SystemExit: With ``ExitCode.USER_ERROR`` when both sizing options
            are given, when the track cannot be loaded, or when
            ``--output-dir`` is rejected by ``contain_path``.
    """
    track: str = args.track
    num_shards: int | None = args.num_shards
    bars_per_shard: int | None = args.bars_per_shard
    output_dir: str = args.output_dir
    dry_run: bool = args.dry_run

    if num_shards is not None and bars_per_shard is not None:
        print("❌ --shards and --bars-per-shard are mutually exclusive.", file=sys.stderr)
        raise SystemExit(ExitCode.USER_ERROR)
    if num_shards is None and bars_per_shard is None:
        num_shards = 4

    root = require_repo()

    # Validate the destination up front so a bad --output-dir fails before
    # any part of the plan is printed, and report it on stderr like the
    # other user errors.
    try:
        out_dir = contain_path(root, output_dir)
    except ValueError as exc:
        print(f"❌ Invalid --output-dir: {exc}", file=sys.stderr)
        raise SystemExit(ExitCode.USER_ERROR) from exc

    result = load_track_from_workdir(root, track)
    if result is None:
        print(f"❌ Track '{track}' not found or not a valid MIDI file.", file=sys.stderr)
        raise SystemExit(ExitCode.USER_ERROR)

    notes, tpb = result
    if not notes:
        print(f" (track '{track}' contains no notes — nothing to shard)")
        return

    all_bars = sorted(notes_by_bar(notes))
    if not all_bars:
        print(" (no bars detected)")
        return

    # The partition runs from the first to the last bar that contains notes;
    # empty bars in between still count towards the span.
    first_bar = all_bars[0]
    last_bar = all_bars[-1]
    bar_span = last_bar - first_bar + 1

    # Determine bar-range splits
    if num_shards is not None:
        n = max(1, num_shards)
        bps = max(1, bar_span // n)
    else:
        bps = max(1, bars_per_shard or 1)
        n = (bar_span + bps - 1) // bps  # ceiling division

    bar_ranges: list[tuple[int, int]] = []
    cur = first_bar
    for i in range(n):
        lo = cur
        # The final planned shard always extends to the last bar.
        hi = lo + bps - 1 if i < n - 1 else last_bar
        bar_ranges.append((lo, hi))
        cur = hi + 1
        if cur > last_bar:
            # Ran out of bars before producing all requested shards.
            break

    track_stem = pathlib.Path(track).stem
    # Avoid a doubled slash in messages when the user passes e.g. "shards/".
    shown_dir = output_dir.rstrip("/") or output_dir

    print(f"\nShard plan: {track} → {len(bar_ranges)} shards")
    # Report the span actually being partitioned, not just the number of
    # bars that happen to contain notes.
    print(f"Total bars: {bar_span} · ~{bps} bars per shard\n")

    shard_notes_list = _shard_notes(notes, bar_ranges)

    if not dry_run:
        out_dir.mkdir(parents=True, exist_ok=True)

    for idx, ((lo, hi), shard_notes) in enumerate(zip(bar_ranges, shard_notes_list)):
        out_name = f"{track_stem}_shard_{idx}.mid"
        print(
            f" Shard {idx} bars {lo:>3}–{hi:>3} → {shown_dir}/{out_name}"
            f" ({len(shard_notes)} notes)"
        )
        if not dry_run:
            # Empty shards are still written (with no notes) so shard indices
            # stay contiguous on disk.
            (out_dir / out_name).write_bytes(notes_to_midi_bytes(shard_notes, tpb))

    if dry_run:
        print("\n No files written (--dry-run).")
    else:
        print(f"\n✅ {len(bar_ranges)} shards written to {shown_dir}/")
File History 1 commit