"""Phase 3 mpack unpack invariants — idempotency and rollback safety. wire_push_unpack_mpack is synchronous and inline (no background job). These tests verify the invariants that matter for the current design: 1. Idempotency — calling unpack-mpack twice with the same mpack produces no duplicate rows and leaves the DB in the correct final state. 2. Correctness — a single unpack-mpack call writes all commits, snapshots, and objects with the right storage_uri and counts. 3. Rollback + retry — rolling back after a failed unpack and retrying produces the same correct final state as a clean run. """ from __future__ import annotations import datetime import hashlib import pathlib import pytest import pytest_asyncio pytestmark = pytest.mark.skip(reason="muse wire protocol in flux") from httpx import AsyncClient, ASGITransport from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from musehub.auth.request_signing import MSignContext, require_signed_request, optional_signed_request from musehub.db.musehub_repo_models import MusehubObject, MusehubCommit, MusehubSnapshot from musehub.db.database import get_db from musehub.main import app from muse.core.object_store import write_object from muse.core.mpack import build_mpack, build_wire_mpack from muse.core.paths import muse_dir from muse.core.snapshot import compute_commit_id, compute_snapshot_id from muse.core.commits import CommitRecord, write_commit from muse.core.refs import write_branch_ref from muse.core.snapshots import SnapshotRecord, write_snapshot from muse.core.types import blob_id from musehub.types.json_types import JSONObject _AUTH_CTX = MSignContext( handle="gabriel", identity_id="sha256:" + "0" * 64, is_agent=False, is_admin=True, ) _N_FILES = 8 _N_COMMITS = 4 _FILES_CHANGED = 2 _BLOB_SIZE = 128 # ── fixtures ──────────────────────────────────────────────────────────────── @pytest_asyncio.fixture() async def client(db_session: AsyncSession) -> None: async def _override_get_db() -> None: yield db_session app.dependency_overrides[get_db] = _override_get_db app.dependency_overrides[require_signed_request] = lambda: _AUTH_CTX app.dependency_overrides[optional_signed_request] = lambda: _AUTH_CTX async with AsyncClient( transport=ASGITransport(app=app), base_url="https://localhost:1337", ) as c: yield c app.dependency_overrides.clear() @pytest_asyncio.fixture() async def repo(client: AsyncClient) -> None: resp = await client.post( "/api/repos", json={"owner": "gabriel", "name": "phase3-retry-test", "visibility": "public", "initialize": False}, ) assert resp.status_code in (200, 201), resp.text data = resp.json() yield data["slug"] await client.delete(f"/api/repos/{data['repoId']}") def _make_repo(tmp: pathlib.Path) -> tuple[pathlib.Path, str, bytes, dict]: """Build a local repo and return (path, head_commit_id, wire_bytes, raw_mpack).""" tmp.mkdir(parents=True, exist_ok=True) dot = muse_dir(tmp) dot.mkdir() (dot / "repo.json").write_text('{"repo_id":"phase3-test","owner":"gabriel"}') for d in ("commits", "snapshots", "objects"): (dot / d).mkdir() (dot / "refs" / "heads").mkdir(parents=True) (dot / "HEAD").write_text("ref: refs/heads/main\n") (dot / "config.toml").write_text("") blob_ids: list[str] = [] for i in range(_N_FILES): data = f"base-{i:04d}".encode() + b"x" * _BLOB_SIZE oid = blob_id(data) write_object(tmp, oid, data) blob_ids.append(oid) base_manifest = {f"src/file_{i:04d}.py": blob_ids[i] for i in range(_N_FILES)} parent = None tip = "" ts = datetime.datetime(2026, 1, 1, tzinfo=datetime.timezone.utc) for i in range(_N_COMMITS): manifest = dict(base_manifest) for j in range(_FILES_CHANGED): idx = (i * _FILES_CHANGED + j) % _N_FILES raw = f"c{i:04d}-f{j}".encode() + b"y" * _BLOB_SIZE oid = blob_id(raw) write_object(tmp, oid, raw) manifest[f"src/file_{idx:04d}.py"] = oid sid = compute_snapshot_id(manifest) write_snapshot(tmp, SnapshotRecord(snapshot_id=sid, manifest=manifest)) msg = f"commit-{i:05d}" cid = compute_commit_id( parent_ids=[parent] if parent else [], snapshot_id=sid, message=msg, committed_at_iso=ts.isoformat(), author="gabriel", ) write_commit(tmp, CommitRecord( commit_id=cid, branch="main", snapshot_id=sid, message=msg, committed_at=ts, parent_commit_id=parent, parent2_commit_id=None, author="gabriel", metadata={}, structured_delta=None, sem_ver_bump="none", breaking_changes=[], agent_id="", model_id="", toolchain_id="", prompt_hash="", signature="", signer_key_id="", )) parent = cid tip = cid ts += datetime.timedelta(seconds=60) write_branch_ref(tmp, "main", tip) raw_mpack = build_mpack(tmp, [tip], have=[]) wire_bytes = build_wire_mpack(raw_mpack) return tmp, tip, wire_bytes, raw_mpack async def _push_and_unpack(client: AsyncClient, repo_slug: str, wire_bytes: bytes, head: str) -> JSONObject: """Presign, PUT, and unpack a mpack. Returns the unpack-mpack response dict.""" import httpx as _httpx mpack_key = "sha256:" + hashlib.sha256(wire_bytes).hexdigest() pr = await client.post( f"/gabriel/{repo_slug}/push/mpack-presign", content=__import__("msgpack").packb( {"mpack_key": mpack_key, "size_bytes": len(wire_bytes)}, use_bin_type=True, ), headers={"Content-Type": "application/x-msgpack"}, ) assert pr.status_code == 200, pr.text upload_url = pr.json().get("upload_url") or pr.json().get("uploadUrl") assert upload_url async with _httpx.AsyncClient() as raw: put = await raw.put(upload_url, content=wire_bytes) assert put.status_code in (200, 204) ur = await client.post( f"/gabriel/{repo_slug}/push/unpack-mpack", content=__import__("msgpack").packb( {"mpack_key": mpack_key, "branch": "main", "head": head}, use_bin_type=True, ), headers={"Content-Type": "application/x-msgpack"}, ) assert ur.status_code == 200, ur.text return ur.json() # ── tests ──────────────────────────────────────────────────────────────────── @pytest.mark.asyncio async def test_unpack_mpack_idempotent( client: AsyncClient, repo: str, tmp_path: pathlib.Path, db_session: AsyncSession, ) -> None: """Calling unpack-mpack twice with the same mpack produces no duplicate rows. wire_push_unpack_mpack uses ON CONFLICT DO NOTHING / ON CONFLICT DO UPDATE throughout, so a second identical push must leave the DB in the same state as a single push — no extra rows, no constraint violations. """ _, head, wire_bytes, raw_mpack = _make_repo(tmp_path / "repo") result1 = await _push_and_unpack(client, repo, wire_bytes, head) assert result1.get("commits_written", 0) == _N_COMMITS, result1 # Second push — must not raise and must not duplicate rows. result2 = await _push_and_unpack(client, repo, wire_bytes, head) assert result2.get("commits_written", 0) == 0, ( "second push of same mpack must write 0 new commits (all already exist)" ) all_oids = [obj["object_id"] for obj in (raw_mpack.get("objects") or [])] rows = (await db_session.execute( select(MusehubObject).where(MusehubObject.object_id.in_(all_oids)) )).scalars().all() assert len(rows) == len(all_oids), ( f"expected exactly {len(all_oids)} object rows after double push, got {len(rows)}" ) @pytest.mark.asyncio async def test_unpack_mpack_writes_all_entities( client: AsyncClient, repo: str, tmp_path: pathlib.Path, db_session: AsyncSession, ) -> None: """A single unpack-mpack call correctly writes all commits, snapshots, and objects. Verifies counts in the response and presence in the DB with correct storage_uri. """ _, head, wire_bytes, raw_mpack = _make_repo(tmp_path / "repo") result = await _push_and_unpack(client, repo, wire_bytes, head) assert result.get("commits_written") == _N_COMMITS, ( f"expected {_N_COMMITS} commits_written, got {result}" ) assert result.get("snapshots_written") == _N_COMMITS, ( f"expected {_N_COMMITS} snapshots_written (one per commit), got {result}" ) n_objects = len(raw_mpack.get("objects") or []) assert result.get("objects_written") == n_objects, ( f"expected {n_objects} objects_written, got {result}" ) all_oids = [obj["object_id"] for obj in (raw_mpack.get("objects") or [])] rows = (await db_session.execute( select(MusehubObject).where(MusehubObject.object_id.in_(all_oids)) )).scalars().all() assert len(rows) == n_objects, ( f"expected {n_objects} musehub_objects rows, got {len(rows)}" ) # All objects must have a real storage_uri (mpack:// or mem://) — nothing pending. bad = [r for r in rows if r.storage_uri == "pending"] assert not bad, f"{len(bad)} objects still have storage_uri='pending'" @pytest.mark.asyncio async def test_unpack_mpack_retry_after_rollback( client: AsyncClient, repo: str, tmp_path: pathlib.Path, db_session: AsyncSession, ) -> None: """Rolling back after a failed unpack and retrying produces correct final state. Simulates a mid-flight DB failure: the session is rolled back after the first presign+PUT (mpack is already in MinIO), then unpack-mpack is called again. The retry must complete all writes correctly. """ _, head, wire_bytes, raw_mpack = _make_repo(tmp_path / "repo") mpack_key = "sha256:" + hashlib.sha256(wire_bytes).hexdigest() import httpx as _httpx, msgpack as _msgpack # Presign and PUT — mpack lands in MinIO. pr = await client.post( f"/gabriel/{repo}/push/mpack-presign", content=_msgpack.packb({"mpack_key": mpack_key, "size_bytes": len(wire_bytes)}, use_bin_type=True), headers={"Content-Type": "application/x-msgpack"}, ) assert pr.status_code == 200, pr.text upload_url = pr.json().get("upload_url") or pr.json().get("uploadUrl") async with _httpx.AsyncClient() as raw: put = await raw.put(upload_url, content=wire_bytes) assert put.status_code in (200, 204) # Simulate a crash: roll back whatever the presign step may have written. await db_session.rollback() # Retry unpack — must succeed and leave DB in correct state. ur = await client.post( f"/gabriel/{repo}/push/unpack-mpack", content=_msgpack.packb({"mpack_key": mpack_key, "branch": "main", "head": head}, use_bin_type=True), headers={"Content-Type": "application/x-msgpack"}, ) assert ur.status_code == 200, ur.text result = ur.json() assert result.get("commits_written") == _N_COMMITS, ( f"expected {_N_COMMITS} commits after rollback+retry, got {result}" ) all_oids = [obj["object_id"] for obj in (raw_mpack.get("objects") or [])] rows = (await db_session.execute( select(MusehubObject).where(MusehubObject.object_id.in_(all_oids)) )).scalars().all() assert len(rows) == len(all_oids), ( f"expected {len(all_oids)} objects after rollback+retry, got {len(rows)}" )