gabriel / muse public
test_mpack_wire.py python
369 lines 12.4 KB
Raw
1 """TDD — MPack wire format encode/decode (issue #70 Phase 3).
2
3 Phase 3 replaces the msgpack bundle encoding with the MPack binary format
4 on both client and server.
5
6 Wire format:
7 [4B] b"MUSE"
8 [1B] version: 1
9 [1B] section_count
10 [N*17B] section table: each (1B type, 8B offset, 8B length)
11 [...] section data
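    [32B] SHA-256 footer, computed over every byte that precedes it

Sections:
    OBJECTS   (type=1): raw _build_pack() bytes — byte-identical to Phase 1 .mpack
    COMMITS   (type=2): [8B count] + N × [8B len + msgpack bytes]
    SNAPSHOTS (type=3): same layout as COMMITS
    TAGS      (type=4): same layout as COMMITS
    META      (type=5): optional metadata mapping, supplied via build_wire_mpack(..., meta=...)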
19 """
20 from __future__ import annotations
21
22 import datetime
23 import hashlib
24 import json
25 import pathlib
26 import struct
27
28 import pytest
29
30 from muse.core.mpack import MPack, apply_mpack, build_wire_mpack, parse_wire_mpack
31 from muse.core.object_store import read_object
32 from muse.core.paths import muse_dir
33 from muse.core.ids import hash_commit as compute_commit_id, hash_snapshot as compute_snapshot_id
34 from muse.core.commits import CommitRecord
35 from muse.core.types import blob_id
36
# Fixed, timezone-aware timestamp: the fixtures feed it into compute_commit_id,
# so commit ids come out identical on every run.
_DT = datetime.datetime(year=2026, month=1, day=1, tzinfo=datetime.timezone.utc)

# Four-byte magic that every wire MPack starts with.
_MAGIC = b"MUSE"
39
40
41 # ---------------------------------------------------------------------------
42 # Fixtures
43 # ---------------------------------------------------------------------------
44
45
def _init_repo(root: pathlib.Path) -> pathlib.Path:
    """Lay out a minimal Muse repository under *root* and return *root*.

    Inside the directory returned by ``muse_dir(root)`` this creates the
    ``commits``, ``snapshots``, ``objects`` and ``refs/heads`` directories,
    a ``repo.json`` whose ``repo_id`` is ``"wire-test"``, a ``HEAD`` that
    points at ``refs/heads/main``, and an empty ``config.toml``.
    """
    meta_dir = muse_dir(root)
    # No exist_ok: each test expects a brand-new repository.
    meta_dir.mkdir(parents=True)

    for sub in ("commits", "snapshots", "objects", "refs/heads"):
        (meta_dir / sub).mkdir(parents=True, exist_ok=True)

    seed_files = {
        "repo.json": json.dumps({"repo_id": "wire-test"}),
        "HEAD": "ref: refs/heads/main\n",
        "config.toml": "",
    }
    for file_name, text in seed_files.items():
        (meta_dir / file_name).write_text(text)

    return root
55
56
def _make_mpack(n_objects: int = 3) -> tuple[MPack, list[tuple[str, bytes]]]:
    """Build a one-commit MPack that carries *n_objects* distinct blobs.

    Blob ``i`` holds ``f"wire-content-{i}".encode()`` repeated 16 times and is
    mapped to ``file_{i}.txt`` in the snapshot manifest. The snapshot has no
    parent. The commit has no parents, sits on ``main``, is stamped with
    ``_DT`` and carries the message ``"wire test"``. ``tags`` is left empty.

    Returns:
        ``(mpack, objects)``, where *objects* lists the ``(object_id, content)``
        pairs in the same order as ``mpack["blobs"]``.
    """
    payloads = [f"wire-content-{i}".encode() * 16 for i in range(n_objects)]
    objects: list[tuple[str, bytes]] = [(blob_id(data), data) for data in payloads]
    manifest: dict[str, str] = {
        f"file_{index}.txt": object_id
        for index, (object_id, _) in enumerate(objects)
    }

    snapshot_id = compute_snapshot_id(manifest)
    commit_id = compute_commit_id(
        parent_ids=[],
        snapshot_id=snapshot_id,
        message="wire test",
        committed_at_iso=_DT.isoformat(),
    )
    commit = CommitRecord(
        commit_id=commit_id,
        branch="main",
        snapshot_id=snapshot_id,
        message="wire test",
        committed_at=_DT,
        parent_commit_id=None,
        parent2_commit_id=None,
        author="",
        metadata={},
        structured_delta=None,
        sem_ver_bump="none",
        breaking_changes=[],
        agent_id="",
        model_id="",
        toolchain_id="",
        prompt_hash="",
        signature="",
        signer_key_id="",
    )

    mpack: MPack = {
        "blobs": [{"object_id": object_id, "content": data} for object_id, data in objects],
        "snapshots": [
            {
                "snapshot_id": snapshot_id,
                "parent_snapshot_id": None,
                "delta_upsert": manifest,
                "delta_remove": [],
            }
        ],
        "commits": [commit.to_dict()],
        "tags": [],
    }
    return mpack, objects
91
92
93 # ---------------------------------------------------------------------------
94 # Magic / format
95 # ---------------------------------------------------------------------------
96
97
def test_wire_mpack_starts_with_muse_magic() -> None:
    """A freshly built wire MPack opens with the four-byte magic."""
    mpack, _ = _make_mpack(2)
    header = build_wire_mpack(mpack)[:4]
    assert header == _MAGIC
102
103
def test_wire_mpack_version_byte_is_one() -> None:
    """The byte right after the magic (offset 4) carries format version 1."""
    mpack, _ = _make_mpack(1)
    version = build_wire_mpack(mpack)[4]
    assert version == 1
108
109
def test_wire_mpack_footer_sha256_is_valid() -> None:
    """The trailing 32 bytes are the SHA-256 digest of everything before them."""
    mpack, _ = _make_mpack(2)
    wire = build_wire_mpack(mpack)
    payload, footer = wire[:-32], wire[-32:]
    assert footer == hashlib.sha256(payload).digest()
116
117
def test_parse_wire_mpack_raises_on_bad_magic() -> None:
    """parse_wire_mpack rejects input whose first four bytes are not b"MUSE".

    Two inputs are checked:

    * A hand-made garbage blob behind a bad magic. Its all-zero trailing
      32 bytes also fail the footer check, so on its own this case cannot
      show that the magic is validated at all.
    * A genuine wire MPack with its magic overwritten and its SHA-256 footer
      recomputed. The footer is valid, so the only defect left is the magic,
      and the parser must reject it for that reason alone.
    """
    garbage = b"XXXX" + b"\x01\x04" + b"\x00" * 17 * 4 + b"\x00" * 100
    with pytest.raises((ValueError, OSError)):
        parse_wire_mpack(garbage)

    mpack, _ = _make_mpack(1)
    forged = bytearray(build_wire_mpack(mpack))
    forged[:4] = b"XXXX"
    # Re-seal the footer (SHA-256 over every preceding byte) so it stays valid.
    forged[-32:] = hashlib.sha256(bytes(forged[:-32])).digest()
    with pytest.raises((ValueError, OSError)):
        parse_wire_mpack(bytes(forged))
122
123
def test_parse_wire_mpack_raises_on_corrupted_footer() -> None:
    """Flipping a single footer byte makes parse_wire_mpack raise OSError."""
    mpack, _ = _make_mpack(1)
    tampered = bytearray(build_wire_mpack(mpack))
    last = len(tampered) - 1
    tampered[last] = tampered[last] ^ 0xFF  # invert every bit of the final footer byte
    with pytest.raises(OSError):
        parse_wire_mpack(bytes(tampered))
130
131
132 # ---------------------------------------------------------------------------
133 # Round-trip — objects
134 # ---------------------------------------------------------------------------
135
136
def test_wire_mpack_round_trip_objects_byte_exact() -> None:
    """Each blob's content comes back from the wire byte-for-byte."""
    mpack, objects = _make_mpack(3)
    decoded = parse_wire_mpack(build_wire_mpack(mpack))
    content_by_oid = {blob["object_id"]: blob["content"] for blob in decoded["blobs"]}
    for object_id, expected in objects:
        assert content_by_oid[object_id] == expected
144
145
def test_wire_mpack_round_trip_objects_readable_after_apply(tmp_path: pathlib.Path) -> None:
    """Blobs decoded from the wire and applied to a repo are readable via read_object."""
    repo = _init_repo(tmp_path)
    mpack, objects = _make_mpack(4)
    decoded = parse_wire_mpack(build_wire_mpack(mpack))
    apply_mpack(repo, decoded)
    for object_id, expected in objects:
        assert read_object(repo, object_id) == expected
154
155
def test_wire_mpack_round_trip_object_count() -> None:
    """Exactly the blobs that went in come back out: none lost, none extra."""
    mpack, objects = _make_mpack(7)
    wire = build_wire_mpack(mpack)
    parsed = parse_wire_mpack(wire)
    # Compare against the fixture's own list rather than a re-typed literal, so
    # changing the requested count cannot silently desynchronise the check.
    assert len(parsed["blobs"]) == len(objects)
    # The counts alone would not catch a blob swapped for another; compare ids too.
    assert {blob["object_id"] for blob in parsed["blobs"]} == {oid for oid, _ in objects}
161
162
163 # ---------------------------------------------------------------------------
164 # Round-trip — commits
165 # ---------------------------------------------------------------------------
166
167
def test_wire_mpack_round_trip_commits() -> None:
    """The fixture commit's id is among the commits decoded from the wire."""
    mpack, _ = _make_mpack(2)
    decoded = parse_wire_mpack(build_wire_mpack(mpack))
    expected_id = mpack["commits"][0]["commit_id"]
    decoded_ids = [commit["commit_id"] for commit in decoded["commits"]]
    assert expected_id in decoded_ids
174
175
def test_wire_mpack_round_trip_commit_fields() -> None:
    """Id, message, snapshot id and branch of the commit survive the round-trip."""
    mpack, _ = _make_mpack(1)
    decoded = parse_wire_mpack(build_wire_mpack(mpack))
    original, received = mpack["commits"][0], decoded["commits"][0]
    for field in ("commit_id", "message", "snapshot_id", "branch"):
        assert received[field] == original[field]
186
187
188 # ---------------------------------------------------------------------------
189 # Round-trip — snapshots
190 # ---------------------------------------------------------------------------
191
192
def test_wire_mpack_round_trip_snapshots() -> None:
    """The fixture snapshot's id is among the snapshots decoded from the wire."""
    mpack, _ = _make_mpack(2)
    decoded = parse_wire_mpack(build_wire_mpack(mpack))
    expected_id = mpack["snapshots"][0]["snapshot_id"]
    decoded_ids = [snap["snapshot_id"] for snap in decoded["snapshots"]]
    assert expected_id in decoded_ids
199
200
def test_wire_mpack_round_trip_snapshot_delta_upsert() -> None:
    """The snapshot's delta_upsert path→object-id map comes back unchanged."""
    mpack, _ = _make_mpack(3)
    decoded = parse_wire_mpack(build_wire_mpack(mpack))
    sent = mpack["snapshots"][0]["delta_upsert"]
    received = decoded["snapshots"][0]["delta_upsert"]
    assert received == sent
208
209
210 # ---------------------------------------------------------------------------
211 # Round-trip — tags
212 # ---------------------------------------------------------------------------
213
214
def test_wire_mpack_round_trip_tags() -> None:
    """A single tag attached to the fixture commit round-trips with its name and id."""
    mpack, _ = _make_mpack(1)
    tag_id = "sha256:" + "a" * 64
    tag = {
        "tag_id": tag_id,
        "repo_id": "wire-test",
        "commit_id": mpack["commits"][0]["commit_id"],
        "tag": "v1.0.0",
        "created_at": _DT.isoformat(),
    }
    mpack["tags"] = [tag]

    received_tags = parse_wire_mpack(build_wire_mpack(mpack))["tags"]
    assert len(received_tags) == 1
    assert received_tags[0]["tag"] == "v1.0.0"
    assert received_tags[0]["tag_id"] == tag_id
229
230
231 # ---------------------------------------------------------------------------
232 # Objects section is byte-identical to Phase 1 local pack
233 # ---------------------------------------------------------------------------
234
235
def test_objects_section_identical_to_local_pack_bytes() -> None:
    """The OBJECTS section in the wire bundle is byte-identical to _build_pack output.

    The section table is walked by hand rather than through parse_wire_mpack,
    so the test pins the on-wire layout itself:

    * 4B magic, 1B version, 1B section_count;
    * section_count 17-byte entries of (1B type, 8B offset, 8B length);
    * the section data;
    * a 32-byte SHA-256 footer.

    Besides the byte comparison, the test also checks two things that would
    otherwise surface only indirectly as a byte mismatch, or not at all:

    * exactly one OBJECTS entry exists;
    * its span lies between the end of the table and the start of the footer.
    """
    from muse.core.pack_store import _build_pack

    mpack, _ = _make_mpack(4)
    wire = build_wire_mpack(mpack)

    entry = struct.Struct("<BQQ")  # 1B type + 8B offset + 8B length, no padding (17 bytes)
    table_start = 6  # after 4B magic, 1B version, 1B section_count
    section_count = wire[5]
    data_start = table_start + entry.size * section_count
    data_end = len(wire) - 32  # the SHA-256 footer occupies the last 32 bytes

    objects_spans = [
        (offset, length)
        for sec_type, offset, length in entry.iter_unpack(wire[table_start:data_start])
        if sec_type == 1  # OBJECTS
    ]
    assert len(objects_spans) == 1, (
        f"expected exactly one OBJECTS section, found {len(objects_spans)}"
    )

    objects_offset, objects_length = objects_spans[0]
    objects_end = objects_offset + objects_length
    assert data_start <= objects_offset <= objects_end <= data_end, (
        f"OBJECTS section [{objects_offset}, {objects_end}) lies outside the "
        f"data region [{data_start}, {data_end})"
    )

    objects_section_bytes = wire[objects_offset:objects_end]
    expected = _build_pack([(blob["object_id"], blob["content"]) for blob in mpack["blobs"]])
    assert objects_section_bytes == expected
258
259
260 # ---------------------------------------------------------------------------
# Objects accessible after apply (a zero-loose-object count is not yet asserted here)
262 # ---------------------------------------------------------------------------
263
264
def test_apply_wire_mpack_objects_accessible_after_apply(tmp_path: pathlib.Path) -> None:
    """After apply_mpack, every blob from the wire bundle is present and intact."""
    repo = _init_repo(tmp_path)
    mpack, object_pairs = _make_mpack(5)
    decoded = parse_wire_mpack(build_wire_mpack(mpack))
    apply_mpack(repo, decoded)
    for oid, content in object_pairs:
        stored = read_object(repo, oid)
        assert stored is not None, f"object {oid} not found after apply_mpack"
        assert stored == content
275
276
277 # ---------------------------------------------------------------------------
278 # Legacy msgpack detection
279 # ---------------------------------------------------------------------------
280
281
def test_legacy_msgpack_not_mistaken_for_wire_mpack() -> None:
    """A plain msgpack-encoded bundle does not begin with the b"MUSE" magic."""
    import msgpack

    empty_bundle = {"commits": [], "snapshots": [], "blobs": []}
    encoded = msgpack.packb(empty_bundle, use_bin_type=True)
    assert not encoded.startswith(_MAGIC)
290
291
292 # ---------------------------------------------------------------------------
293 # Edge cases
294 # ---------------------------------------------------------------------------
295
296
def test_wire_mpack_empty_all_sections(tmp_path: pathlib.Path) -> None:
    """An MPack with every section empty still encodes, decodes and applies as a no-op."""
    repo = _init_repo(tmp_path)
    empty: MPack = {"blobs": [], "snapshots": [], "commits": [], "tags": []}
    wire = build_wire_mpack(empty)
    assert wire[:4] == _MAGIC

    summary = apply_mpack(repo, parse_wire_mpack(wire))
    assert summary["blobs_written"] == 0
    assert summary["commits_written"] == 0
306
307
def test_wire_mpack_large_object_count(tmp_path: pathlib.Path) -> None:
    """Two hundred blobs survive encode → decode → apply and read back exactly."""
    repo = _init_repo(tmp_path)
    mpack, objects = _make_mpack(200)
    decoded = parse_wire_mpack(build_wire_mpack(mpack))
    apply_mpack(repo, decoded)
    for object_id, expected in objects:
        assert read_object(repo, object_id) == expected
316
317
318 # ---------------------------------------------------------------------------
319 # META section (type=5)
320 # ---------------------------------------------------------------------------
321
322
def test_wire_mpack_meta_round_trip() -> None:
    """A meta mapping passed to build_wire_mpack is returned unchanged by the parser."""
    mpack, _ = _make_mpack(1)
    meta = {"repo_id": "wire-test", "head_commit_id": "sha256:" + "a" * 64}
    decoded = parse_wire_mpack(build_wire_mpack(mpack, meta=meta))
    assert decoded.get("meta") == meta
329
330
def test_wire_mpack_no_meta_when_omitted() -> None:
    """Without a meta argument, the parsed result has no meta or an empty one."""
    mpack, _ = _make_mpack(1)
    decoded = parse_wire_mpack(build_wire_mpack(mpack))
    # A missing key and an explicit empty mapping are both acceptable.
    assert decoded.get("meta", {}) == {}
336
337
338 # ---------------------------------------------------------------------------
339 # Legacy msgpack bundle still applies
340 # ---------------------------------------------------------------------------
341
342
def test_legacy_msgpack_bundle_still_applies(tmp_path: pathlib.Path) -> None:
    """A bundle in the pre-Phase 3 plain-msgpack encoding is still accepted by apply_mpack.

    The bundle is packed with msgpack and shown not to start with the wire
    magic. It is then unpacked back into a plain dict and applied, and every
    blob must be readable from the object store afterwards.
    """
    import msgpack

    repo = _init_repo(tmp_path)
    mpack, objects = _make_mpack(2)

    # Hand-assemble the legacy payload: an ordinary dict encoded with msgpack.
    legacy_blobs = [
        {"object_id": blob["object_id"], "content": blob["content"]}
        for blob in mpack["blobs"]
    ]
    legacy_dict = {
        "blobs": legacy_blobs,
        "snapshots": mpack["snapshots"],
        "commits": mpack["commits"],
        "tags": [],
    }
    legacy_bytes = msgpack.packb(legacy_dict, use_bin_type=True)

    # Must not be mistaken for a wire MPack.
    assert legacy_bytes[:4] != _MAGIC

    # NOTE(review): intended to mirror the legacy decode path in transport.py — confirm.
    decoded = msgpack.unpackb(legacy_bytes, raw=False)

    apply_mpack(repo, decoded)
    for object_id, expected in objects:
        assert read_object(repo, object_id) == expected
File History 1 commit