"""TDD — MPack wire format encode/decode (issue #70 Phase 3).

Phase 3 replaces the msgpack bundle encoding with the MPack binary format on
both client and server.

Wire format::

    [4B]     b"MUSE"
    [1B]     version: 1
    [1B]     section_count
    [N*17B]  section table: each (1B type, 8B offset, 8B length)
    [...]    section data
    [32B]    SHA-256 footer, computed over every preceding byte

Sections:
    OBJECTS   (type=1): raw _build_pack() bytes — byte-identical to Phase 1 .mpack
    COMMITS   (type=2): [8B count] + N × [8B len + msgpack bytes]
    SNAPSHOTS (type=3): same layout as COMMITS
    TAGS      (type=4): same layout as COMMITS
    META      (type=5): optional metadata dict passed as
                        ``build_wire_mpack(..., meta=...)``; when omitted,
                        parse_wire_mpack yields no "meta" key or an empty dict.
                        The byte-level encoding is not pinned by these tests.

Legacy (pre-Phase 3) bundles are one plain msgpack map. These tests check
that such bytes never begin with the b"MUSE" magic and that a decoded legacy
dict can still be passed to apply_mpack.
"""
from __future__ import annotations

import datetime
import hashlib
import json
import pathlib
import struct

import pytest

from muse.core.mpack import MPack, apply_mpack, build_wire_mpack, parse_wire_mpack
from muse.core.object_store import read_object
from muse.core.paths import muse_dir
from muse.core.ids import hash_commit as compute_commit_id, hash_snapshot as compute_snapshot_id
from muse.core.store import CommitRecord
from muse.core.types import blob_id

# Fixed, timezone-aware timestamp so commit ids are reproducible across runs.
_DT = datetime.datetime(2026, 1, 1, tzinfo=datetime.timezone.utc)
# Leading 4 bytes of every wire MPack (see module docstring).
_MAGIC = b"MUSE"

# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


def _init_repo(root: pathlib.Path) -> pathlib.Path:
    """Create a minimal on-disk Muse repository under *root* and return *root*.

    Inside ``muse_dir(root)`` this writes:

    - ``repo.json`` with repo_id "wire-test";
    - the ``commits``, ``snapshots``, ``objects`` and ``refs/heads`` directories;
    - a ``HEAD`` pointing at ``refs/heads/main``;
    - an empty ``config.toml``.

    ``dot.mkdir(parents=True)`` is called without ``exist_ok``, so this raises
    FileExistsError if the Muse directory already exists.
    """
    dot = muse_dir(root)
    dot.mkdir(parents=True)
    (dot / "repo.json").write_text(json.dumps({"repo_id": "wire-test"}))
    for d in ("commits", "snapshots", "objects", "refs/heads"):
        (dot / d).mkdir(parents=True, exist_ok=True)
    (dot / "HEAD").write_text("ref: refs/heads/main\n")
    (dot / "config.toml").write_text("")
    return root


def _make_mpack(n_objects: int = 3) -> tuple[MPack, list[tuple[str, bytes]]]:
    """Build an in-memory MPack with *n_objects* blobs, one snapshot and one commit.

    Blob ``i`` has content ``f"wire-content-{i}"`` repeated 16 times. It is listed
    in the snapshot's ``delta_add`` under the path ``file_{i}.txt``. The snapshot
    has no parent. The commit is a root commit on branch "main", stamped with
    ``_DT``. The bundle carries no tags.

    Returns:
        ``(mpack, objects)``. *objects* is the ``(object_id, content)`` list in
        creation order, which tests use as ground truth after a round trip.
    """
    objects: list[tuple[str, bytes]] = []
    manifest: dict[str, str] = {}
    for i in range(n_objects):
        content = f"wire-content-{i}".encode() * 16
        oid = blob_id(content)
        objects.append((oid, content))
        manifest[f"file_{i}.txt"] = oid
    # Snapshot and commit ids are computed with the project's hashing helpers
    # (muse.core.ids) rather than hard-coded.
    sid = compute_snapshot_id(manifest)
    cid = compute_commit_id(
        parent_ids=[],
        snapshot_id=sid,
        message="wire test",
        committed_at_iso=_DT.isoformat(),
    )
    mpack: MPack = {
        "objects": [{"object_id": oid, "content": raw} for oid, raw in objects],
        "snapshots": [{
            "snapshot_id": sid,
            "parent_snapshot_id": None,
            "delta_add": manifest,
            "delta_remove": [],
        }],
        # Built through CommitRecord and serialized with to_dict(), so the
        # commit entry has whatever shape to_dict() produces.
        "commits": [CommitRecord(
            commit_id=cid,
            branch="main",
            snapshot_id=sid,
            message="wire test",
            committed_at=_DT,
            parent_commit_id=None,
            parent2_commit_id=None,
            author="",
            metadata={},
            structured_delta=None,
            sem_ver_bump="none",
            breaking_changes=[],
            agent_id="",
            model_id="",
            toolchain_id="",
            prompt_hash="",
            signature="",
            signer_key_id="",
        ).to_dict()],
        "tags": [],
    }
    return mpack, objects


# ---------------------------------------------------------------------------
# Magic / format
# ---------------------------------------------------------------------------


def test_wire_mpack_starts_with_muse_magic() -> None:
    """The first four bytes of an encoded bundle are b"MUSE"."""
    mpack, _ = _make_mpack(2)
    wire = build_wire_mpack(mpack)
    assert wire[:4] == _MAGIC


def test_wire_mpack_version_byte_is_one() -> None:
    """Byte 4, immediately after the magic, is the format version 1."""
    mpack, _ = _make_mpack(1)
    wire = build_wire_mpack(mpack)
    assert wire[4] == 1


def test_wire_mpack_footer_sha256_is_valid() -> None:
    """The last 32 bytes are the SHA-256 digest of all bytes before them."""
    mpack, _ = _make_mpack(2)
    wire = build_wire_mpack(mpack)
    body = wire[:-32]
    stored = wire[-32:]
    assert hashlib.sha256(body).digest() == stored


def test_parse_wire_mpack_raises_on_bad_magic() -> None:
    """Input not starting with b"MUSE" is rejected with ValueError or OSError."""
    # Layout: wrong magic, version 1, section_count 4, four zeroed 17-byte
    # section-table entries, then 100 zero bytes of padding.
    # NOTE(review): the trailing 32 bytes are zeros, not a valid SHA-256 footer,
    # so this input would also fail a footer check. OSError is the type the
    # corrupted-footer test below expects, and it is accepted here, so this test
    # may pass without the magic check ever running. Confirm which check fires.
    bad = b"XXXX" + b"\x01\x04" + b"\x00" * 17 * 4 + b"\x00" * 100
    with pytest.raises((ValueError, OSError)):
        parse_wire_mpack(bad)


def test_parse_wire_mpack_raises_on_corrupted_footer() -> None:
    """A footer that no longer matches the body makes parsing raise OSError."""
    mpack, _ = _make_mpack(1)
    wire = bytearray(build_wire_mpack(mpack))
    wire[-1] ^= 0xFF  # flip footer byte
    with pytest.raises(OSError):
        parse_wire_mpack(bytes(wire))


# ---------------------------------------------------------------------------
# Round-trip — objects
# ---------------------------------------------------------------------------


def test_wire_mpack_round_trip_objects_byte_exact() -> None:
    """Every object's content survives encode → decode unchanged."""
    mpack, objects = _make_mpack(3)
    wire = build_wire_mpack(mpack)
    parsed = parse_wire_mpack(wire)
    parsed_by_oid = {o["object_id"]: o["content"] for o in parsed["objects"]}
    for oid, content in objects:
        assert parsed_by_oid[oid] == content


def test_wire_mpack_round_trip_objects_readable_after_apply(tmp_path: pathlib.Path) -> None:
    """After apply_mpack, every object can be read back from the repo store."""
    repo = _init_repo(tmp_path)
    mpack, objects = _make_mpack(4)
    wire = build_wire_mpack(mpack)
    parsed = parse_wire_mpack(wire)
    apply_mpack(repo, parsed)
    for oid, content in objects:
        assert read_object(repo, oid) == content


def test_wire_mpack_round_trip_object_count() -> None:
    """Decoding yields exactly as many objects as were encoded."""
    mpack, objects = _make_mpack(7)
    wire = build_wire_mpack(mpack)
    parsed = parse_wire_mpack(wire)
    assert len(parsed["objects"]) == 7


# ---------------------------------------------------------------------------
# Round-trip — commits
# ---------------------------------------------------------------------------


def test_wire_mpack_round_trip_commits() -> None:
    """The encoded commit id appears among the decoded commits."""
    mpack, _ = _make_mpack(2)
    wire = build_wire_mpack(mpack)
    parsed = parse_wire_mpack(wire)
    orig_cid = mpack["commits"][0]["commit_id"]
    assert any(c["commit_id"] == orig_cid for c in parsed["commits"])


def test_wire_mpack_round_trip_commit_fields() -> None:
    """Commit id, message, snapshot id and branch survive the round trip."""
    mpack, _ = _make_mpack(1)
    wire = build_wire_mpack(mpack)
    parsed = parse_wire_mpack(wire)
    orig = mpack["commits"][0]
    got = parsed["commits"][0]
    assert got["commit_id"] == orig["commit_id"]
    assert got["message"] == orig["message"]
    assert got["snapshot_id"] == orig["snapshot_id"]
    assert got["branch"] == orig["branch"]


# ---------------------------------------------------------------------------
# Round-trip — snapshots
# ---------------------------------------------------------------------------


def test_wire_mpack_round_trip_snapshots() -> None:
    """The encoded snapshot id appears among the decoded snapshots."""
    mpack, _ = _make_mpack(2)
    wire = build_wire_mpack(mpack)
    parsed = parse_wire_mpack(wire)
    orig_sid = mpack["snapshots"][0]["snapshot_id"]
    assert any(s["snapshot_id"] == orig_sid for s in parsed["snapshots"])


def test_wire_mpack_round_trip_snapshot_delta_add() -> None:
    """The snapshot's path → object-id manifest (delta_add) survives the round trip."""
    mpack, objects = _make_mpack(3)
    wire = build_wire_mpack(mpack)
    parsed = parse_wire_mpack(wire)
    orig_delta = mpack["snapshots"][0]["delta_add"]
    got_delta = parsed["snapshots"][0]["delta_add"]
    assert got_delta == orig_delta


# ---------------------------------------------------------------------------
# Round-trip — tags
# ---------------------------------------------------------------------------


def test_wire_mpack_round_trip_tags() -> None:
    """A single tag record survives the round trip with its name and id intact."""
    mpack, _ = _make_mpack(1)
    # tag_id is a placeholder value; only round-trip fidelity is checked.
    mpack["tags"] = [{
        "tag_id": "sha256:" + "a" * 64,
        "repo_id": "wire-test",
        "commit_id": mpack["commits"][0]["commit_id"],
        "tag": "v1.0.0",
        "created_at": _DT.isoformat(),
    }]
    wire = build_wire_mpack(mpack)
    parsed = parse_wire_mpack(wire)
    assert len(parsed["tags"]) == 1
    assert parsed["tags"][0]["tag"] == "v1.0.0"
    assert parsed["tags"][0]["tag_id"] == "sha256:" + "a" * 64


# ---------------------------------------------------------------------------
# Objects section is byte-identical to Phase 1 local pack
# ---------------------------------------------------------------------------


def test_objects_section_identical_to_local_pack_bytes() -> None:
    """The OBJECTS section in the wire bundle is byte-identical to _build_pack output."""
    # Private helper, imported locally so only this test depends on it.
    from muse.core.pack_store import _build_pack
    mpack, objects = _make_mpack(4)
    wire = build_wire_mpack(mpack)
    # Parse section table to find OBJECTS section
    section_count = wire[5]
    # The section table starts after magic (4B) + version (1B) + section_count (1B).
    cursor = 6
    objects_offset = objects_length = None
    for _ in range(section_count):
        sec_type = wire[cursor]
        # NOTE(review): SOURCE IS CORRUPTED ON THE NEXT LINE. The text between the
        # opening quote of the struct format string and "None:" appears to have
        # been lost, most likely stripped as markup because the format began
        # with "<".
        #
        # The missing text included:
        # - the rest of this test (presumably recording the type=1 entry's
        #   offset/length and comparing that slice with _build_pack output);
        # - the `def` line of the following test, which applies a 5-object wire
        #   MPack to a fresh repo and reads every object back.
        #
        # Restore both from version control. The module does not parse until then.
        sec_offset, sec_length = struct.unpack_from(" None:
    repo = _init_repo(tmp_path)
    mpack, object_pairs = _make_mpack(5)
    wire = build_wire_mpack(mpack)
    parsed = parse_wire_mpack(wire)
    apply_mpack(repo, parsed)
    for oid, content in object_pairs:
        result = read_object(repo, oid)
        assert result is not None, f"object {oid} not found after apply_mpack"
        assert result == content


# ---------------------------------------------------------------------------
# Legacy msgpack detection
# ---------------------------------------------------------------------------
def test_legacy_msgpack_not_mistaken_for_wire_mpack() -> None:
    """A plain msgpack-encoded bundle never begins with the b"MUSE" magic."""
    import msgpack

    encoded = msgpack.packb(
        {"commits": [], "snapshots": [], "objects": []},
        use_bin_type=True,
    )
    assert encoded[:4] != _MAGIC


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------


def test_wire_mpack_empty_all_sections(tmp_path: pathlib.Path) -> None:
    """An MPack with every section empty still encodes, decodes and applies as a no-op."""
    repo = _init_repo(tmp_path)
    empty: MPack = {"objects": [], "snapshots": [], "commits": [], "tags": []}
    encoded = build_wire_mpack(empty)
    assert encoded[:4] == _MAGIC
    stats = apply_mpack(repo, parse_wire_mpack(encoded))
    assert stats["objects_written"] == 0
    assert stats["commits_written"] == 0


def test_wire_mpack_large_object_count(tmp_path: pathlib.Path) -> None:
    """Two hundred objects survive encode → decode → apply and are readable afterwards."""
    repo = _init_repo(tmp_path)
    bundle, expected = _make_mpack(200)
    apply_mpack(repo, parse_wire_mpack(build_wire_mpack(bundle)))
    for object_id, payload in expected:
        assert read_object(repo, object_id) == payload


# ---------------------------------------------------------------------------
# META section (type=5)
# ---------------------------------------------------------------------------


def test_wire_mpack_meta_round_trip() -> None:
    """A meta dict handed to build_wire_mpack is returned unchanged by parse_wire_mpack."""
    bundle, _ = _make_mpack(1)
    expected_meta = {"repo_id": "wire-test", "head_commit_id": "sha256:" + "a" * 64}
    decoded = parse_wire_mpack(build_wire_mpack(bundle, meta=expected_meta))
    assert decoded.get("meta") == expected_meta


def test_wire_mpack_no_meta_when_omitted() -> None:
    """Without meta=, the decoded bundle has either no "meta" key or an empty one."""
    bundle, _ = _make_mpack(1)
    decoded = parse_wire_mpack(build_wire_mpack(bundle))
    # A missing key defaults to {} here, so both accepted shapes compare equal.
    assert decoded.get("meta", {}) == {}


# ---------------------------------------------------------------------------
# Legacy msgpack bundle still applies
# ---------------------------------------------------------------------------
def test_legacy_msgpack_bundle_still_applies(tmp_path: pathlib.Path) -> None:
    """apply_mpack still accepts a bundle decoded from the pre-Phase 3 msgpack format.

    The bundle is encoded as one plain msgpack map (no MUSE header, section
    table or footer), decoded with msgpack, and the resulting dict is applied.
    Every object must then be readable from the repo store.
    """
    import msgpack

    repo = _init_repo(tmp_path)
    bundle, expected = _make_mpack(2)

    # Re-express the bundle the way the pre-Phase 3 transport sent it.
    legacy_objects = [
        {"object_id": entry["object_id"], "content": entry["content"]}
        for entry in bundle["objects"]
    ]
    encoded = msgpack.packb(
        {
            "objects": legacy_objects,
            "snapshots": bundle["snapshots"],
            "commits": bundle["commits"],
            "tags": [],
        },
        use_bin_type=True,
    )

    # Magic sniffing must not treat these bytes as a wire MPack.
    assert encoded[:4] != _MAGIC

    # Decode as plain msgpack (str keys/values, raw=False) and hand the dict over.
    apply_mpack(repo, msgpack.unpackb(encoded, raw=False))

    for object_id, payload in expected:
        assert read_object(repo, object_id) == payload