"""Fetch path — wire_fetch_presign, wire_fetch_mpack, wire_fetch, process_mpack_gc_job.""" import asyncio import hashlib import logging import msgpack as _msgpack import time as _time_module from datetime import datetime, timezone from typing import TypedDict from sqlalchemy import func, select, text as _sa_text from sqlalchemy.dialects.postgresql import insert as _pg_insert from sqlalchemy.ext.asyncio import AsyncSession from musehub.db.musehub_repo_models import ( MusehubBranch, MusehubCommit, MusehubCommitGraph, MusehubFetchMPackCache, MusehubMPackIndex, MusehubObject, MusehubObjectRef, MusehubRepo, MusehubSnapshot, ) from musehub.models.wire import WireFetchRequest from muse.core.types import blob_id from musehub.storage import get_backend from musehub.services.musehub_wire_shared import ( FetchMPackResult, FetchNotIndexedError, FetchPresignResult, MPackValidationError, _reconstruct_manifest, _snap_row_to_wire_s3, _commit_to_wire_s3, _to_wire_commit, _utc_now, logger, ) type _CommitDeltaMap = dict[str, MusehubCommitGraph | MusehubCommit] async def _walk_commit_delta( session: AsyncSession, want: list[str] | set[str], have: list[str] | set[str], ) -> _CommitDeltaMap: _wcd_t0 = _time_module.perf_counter() _want_list = list(want) _have_list = list(have) logger.info("[_walk_commit_delta] START want=%d have=%d want_ids=%s", len(_want_list), len(_have_list), [cid[:16] for cid in _want_list[:5]]) have_set: frozenset[str] = frozenset(_have_list) starts = [cid for cid in _want_list if cid not in have_set] if not starts: logger.info("[_walk_commit_delta] SKIP — all want in have, 0ms") return {} if True: # fast path always active — commit graph is global (no repo_id) from sqlalchemy import func as _func want_gen_q = await session.execute( select(_func.max(MusehubCommitGraph.generation)) .where(MusehubCommitGraph.commit_id.in_(starts)) ) _want_gen_raw = want_gen_q.scalar() max_want_gen: int = _want_gen_raw or 0 _missing_from_graph: list[str] = [] _found_in_graph: list[tuple[str, int]] = [] for _scid in starts[:10]: _sg = await session.execute( select(MusehubCommitGraph.generation) .where(MusehubCommitGraph.commit_id == _scid) ) _sg_val = _sg.scalar_one_or_none() if _sg_val is None: _missing_from_graph.append(_scid[:16]) else: _found_in_graph.append((_scid[:16], _sg_val)) logger.info( "[_walk_commit_delta] want_gen_raw=%s max_want_gen=%d " "starts_in_graph=%s starts_missing_from_graph=%s", _want_gen_raw, max_want_gen, _found_in_graph, _missing_from_graph, ) min_have_gen: int = -1 if have_set: have_gen_q = await session.execute( select(_func.max(MusehubCommitGraph.generation)) .where(MusehubCommitGraph.commit_id.in_(list(have_set))) ) min_have_gen = have_gen_q.scalar() or -1 range_q = await session.execute( select( MusehubCommitGraph.commit_id, MusehubCommitGraph.parent_ids, MusehubCommitGraph.snapshot_id, ) .where(MusehubCommitGraph.generation > min_have_gen) .where(MusehubCommitGraph.generation <= max_want_gen) ) graph_map: dict[str, tuple[list[str], str | None]] = { cid: (pids or [], sid) for cid, pids, sid in range_q } logger.info( "[_walk_commit_delta] range_scan gen=(%d,%d] returned %d rows " "starts_in_map=%s", min_have_gen, max_want_gen, len(graph_map), [cid[:16] for cid in starts if cid in graph_map], ) visited_mem: set[str] = set(have_set) frontier_mem = [cid for cid in starts if cid not in visited_mem] reachable_cids: set[str] = set() while frontier_mem: next_mem: list[str] = [] for cid in frontier_mem: if cid in visited_mem: continue visited_mem.add(cid) reachable_cids.add(cid) pids_for_cid, _ = graph_map.get(cid, ([], None)) for p in pids_for_cid: if p not in visited_mem and p not in have_set: next_mem.append(p) frontier_mem = next_mem from types import SimpleNamespace as _SN needed_graph: dict[str, _SN] = {} for cid in reachable_cids: pids_ns, sid_ns = graph_map.get(cid, ([], None)) needed_graph[cid] = _SN(commit_id=cid, snapshot_id=sid_ns, parent_ids=pids_ns) _wcd_elapsed = (_time_module.perf_counter() - _wcd_t0) * 1000 logger.info( "[_walk_commit_delta] DONE (graph) commits=%d elapsed=%.1fms (%.3fms/commit) " "gen_range=(%d,%d] graph_rows=%d reachable=%d", len(needed_graph), _wcd_elapsed, _wcd_elapsed / max(len(needed_graph), 1), min_have_gen, max_want_gen, len(graph_map), len(reachable_cids), ) return needed_graph # type: ignore[return-value] # Legacy fallback from musehub.graph.walk import walk_dag_async _row_cache: dict[str, MusehubCommit] = {} _db_calls = 0 async def _adj(cid: str) -> list[str]: nonlocal _db_calls _db_calls += 1 row = await session.get(MusehubCommit, cid) if row is not None: _row_cache[cid] = row return row.parent_ids or [] if row else [] needed_legacy: dict[str, MusehubCommit] = {} async for cid in walk_dag_async(starts, _adj, exclude=have_set): if cid in _row_cache: needed_legacy[cid] = _row_cache[cid] _wcd_elapsed = (_time_module.perf_counter() - _wcd_t0) * 1000 logger.info( "[_walk_commit_delta] DONE (legacy) commits=%d db_calls=%d elapsed=%.1fms (%.2fms/commit)", len(needed_legacy), _db_calls, _wcd_elapsed, _wcd_elapsed / max(len(needed_legacy), 1), ) return needed_legacy async def wire_fetch_presign( session: AsyncSession, repo_id: str, req: WireFetchRequest, ttl_seconds: int = 3600, ) -> FetchPresignResult: import asyncio from datetime import timedelta _empty: FetchPresignResult = { "presign": False, "blob_urls": {}, "commits": [], "snapshots": [], "branch_heads": {}, "repo_id": repo_id, "domain": "", "default_branch": "main", "expires_at": None, "commit_count": 0, "blob_count": 0, } if not req.want: return _empty repo_row = await session.get(MusehubRepo, repo_id) if repo_row is None: return _empty _domain: str = repo_row.domain_id or "" _default_branch: str = repo_row.default_branch if repo_row.default_branch else "main" _empty["domain"] = _domain _empty["default_branch"] = _default_branch _empty["repo_id"] = repo_id have_set = set(req.have) needed_rows = await _walk_commit_delta(session, req.want, have_set) if not needed_rows: return {**_empty, "domain": _domain, "default_branch": _default_branch} _presign_commit_rows: dict[str, MusehubCommit] = {} _presign_any = next(iter(needed_rows.values())) if not isinstance(_presign_any, MusehubCommit): _PRESIGN_WIRE_BATCH = 2000 _presign_cids = list(needed_rows.keys()) for _pi in range(0, len(_presign_cids), _PRESIGN_WIRE_BATCH): _pq = await session.execute( select(MusehubCommit).where(MusehubCommit.commit_id.in_(_presign_cids[_pi : _pi + _PRESIGN_WIRE_BATCH])) ) for _pr in _pq.scalars(): _presign_commit_rows[_pr.commit_id] = _pr else: _presign_commit_rows = needed_rows # type: ignore[assignment] snap_ids = [r.snapshot_id for r in needed_rows.values() if r.snapshot_id] all_oids: set[str] = set() if snap_ids: snaps_q = await session.execute( select(MusehubSnapshot).where(MusehubSnapshot.snapshot_id.in_(snap_ids)) ) for snap in snaps_q.scalars().all(): manifest = ( _msgpack.unpackb(snap.manifest_blob, raw=False) if snap.manifest_blob else await _reconstruct_manifest(session, snap.snapshot_id) ) all_oids.update(v for v in manifest.values() if v) have_snap_ids: list[str] = [] if have_set: have_commits_q = await session.execute( select(MusehubCommit).where(MusehubCommit.commit_id.in_(have_set)) ) have_snap_ids = [r.snapshot_id for r in have_commits_q.scalars().all() if r.snapshot_id] have_oids: set[str] = set() if have_snap_ids: have_snaps_q = await session.execute( select(MusehubSnapshot).where(MusehubSnapshot.snapshot_id.in_(have_snap_ids)) ) for snap in have_snaps_q.scalars().all(): manifest = ( _msgpack.unpackb(snap.manifest_blob, raw=False) if snap.manifest_blob else await _reconstruct_manifest(session, snap.snapshot_id) ) have_oids.update(v for v in manifest.values() if v) new_oids = all_oids - have_oids n_objects = len(new_oids) n_commits = len(needed_rows) total_size = 0 if new_oids: size_q = await session.execute( select(func.coalesce(func.sum(MusehubObject.size_bytes), 0)).where( MusehubObject.object_id.in_(list(new_oids)) ) ) total_size = int(size_q.scalar() or 0) backend = get_backend() wire_commits = [ (await _commit_to_wire_s3(row, backend)).model_dump() for row in _presign_commit_rows.values() ] snap_rows_q = await session.execute( select(MusehubSnapshot).where(MusehubSnapshot.snapshot_id.in_(snap_ids)) ) wire_snaps = [ await _snap_row_to_wire_s3(snap, backend, session=session) for snap in snap_rows_q.scalars().all() ] branch_rows_q = await session.execute( select(MusehubBranch).where(MusehubBranch.repo_id == repo_id) ) branch_heads = { b.name: b.head_commit_id for b in branch_rows_q.scalars().all() if b.head_commit_id } sem = asyncio.Semaphore(50) logger.info( "fetch/presign: generating %d presigned GET URLs repo=%s/%s", len(new_oids), repo_row.owner, repo_row.slug, ) async def _presign_one(oid: str) -> tuple[str, str]: async with sem: url = await backend.presign_get(oid, ttl_seconds) logger.debug("fetch/presign: presigned oid=%s", oid) return oid, url pairs = await asyncio.gather(*(_presign_one(oid) for oid in new_oids)) blob_urls = {oid: url for oid, url in pairs} expires_at = (_utc_now() + timedelta(seconds=ttl_seconds)).isoformat() return { "presign": True, "blob_urls": blob_urls, "commits": wire_commits, "snapshots": wire_snaps, "branch_heads": branch_heads, "repo_id": repo_id, "domain": _domain, "default_branch": _default_branch, "expires_at": expires_at, "commit_count": n_commits, "blob_count": n_objects, } async def wire_fetch_mpack( session: AsyncSession, repo_id: str, want: list[str], have: list[str], ttl_seconds: int = 3600, ) -> FetchMPackResult: import msgpack as _msgpack_local _t0 = _time_module.perf_counter() def _ms() -> float: return (_time_module.perf_counter() - _t0) * 1000 logger.info("[wire_fetch_mpack] START repo_id=%s want=%d have=%d want_ids=%s", repo_id, len(want), len(have), [cid[:16] for cid in want[:5]]) _up_to_date: FetchMPackResult = { "mpack_url": None, "mpack_id": None, "commit_count": 0, "blob_count": 0, } if not want: logger.info("[wire_fetch_mpack] SKIP — want is empty") return _up_to_date backend = get_backend() have_set = set(have) logger.info("[wire_fetch_mpack] step=1 DAG walk starting t=%.1fms", _ms()) needed_rows = await _walk_commit_delta(session, want, have_set) logger.info("[wire_fetch_mpack] step=1 DAG walk done commits=%d t=%.1fms", len(needed_rows), _ms()) if not needed_rows: logger.info("[wire_fetch_mpack] SKIP — client already up-to-date (needed_rows empty)") return _up_to_date commit_rows: dict[str, MusehubCommit] = {} _any = next(iter(needed_rows.values())) _is_proxy = not isinstance(_any, MusehubCommit) logger.info("[wire_fetch_mpack] step=1b needed_rows=%d is_proxy=%s t=%.1fms", len(needed_rows), _is_proxy, _ms()) if _is_proxy: _cids = list(needed_rows.keys()) _q = await session.execute( select(MusehubCommit).where( _sa_text("commit_id = ANY(:ids)").bindparams(ids=_cids) ) ) for _row in _q.scalars(): commit_rows[_row.commit_id] = _row _missing_from_db = set(_cids) - set(commit_rows.keys()) logger.info( "[wire_fetch_mpack] step=1b bulk fetch done commits_in_db=%d missing_from_db=%d " "missing_ids=%s t=%.1fms", len(commit_rows), len(_missing_from_db), [cid[:16] for cid in list(_missing_from_db)[:5]], _ms(), ) else: commit_rows = needed_rows # type: ignore[assignment] logger.info("[wire_fetch_mpack] step=1b using MusehubCommit rows directly commits=%d t=%.1fms", len(commit_rows), _ms()) _proxy_snap_ids_raw = [r.snapshot_id for r in needed_rows.values()] _graph_snap_ids = [sid for sid in _proxy_snap_ids_raw if sid] _proxy_snap_none_count = sum(1 for s in _proxy_snap_ids_raw if not s) _commit_row_snap_ids = [r.snapshot_id for r in commit_rows.values() if r.snapshot_id] # Defensive fallback: CommitGraph rows from server-side merges may have snapshot_id=None # (pre-fix state). The MusehubCommit row always has the correct snapshot_id — merge both. snap_ids = list({*_graph_snap_ids, *_commit_row_snap_ids}) logger.info( "[wire_fetch_mpack] step=2 snap_ids_from_graph=%d snap_ids_none_in_graph=%d " "snap_ids_from_commit_rows=%d snap_ids_total=%d t=%.1fms", len(_graph_snap_ids), _proxy_snap_none_count, len(_commit_row_snap_ids), len(snap_ids), _ms(), ) snap_map: dict[str, dict] = {} if snap_ids: snaps_q = await session.execute( select(MusehubSnapshot).where( _sa_text("snapshot_id = ANY(:ids)").bindparams(ids=snap_ids) ) ) for snap in snaps_q.scalars().all(): snap_map[snap.snapshot_id] = await _snap_row_to_wire_s3(snap, backend, session=session) logger.info("[wire_fetch_mpack] step=2 snap_map loaded=%d t=%.1fms", len(snap_map), _ms()) all_oids: set[str] = set() _needed_cids = list(needed_rows.keys()) logger.warning("[GRAPH-DEBUG] wire_fetch_mpack: needed_rows=%d needed_cids_sample=%s", len(_needed_cids), [c[:16] for c in _needed_cids[:3]]) _debug_graph_q = await session.execute( select(MusehubCommitGraph.commit_id, MusehubCommitGraph.generation, MusehubCommitGraph.snapshot_id) .where(MusehubCommitGraph.commit_id.in_(_needed_cids)) .order_by(MusehubCommitGraph.generation.desc()) .limit(5) ) _debug_graph_rows = _debug_graph_q.all() logger.warning("[GRAPH-DEBUG] wire_fetch_mpack: CommitGraph has %d rows for needed_cids (top 5 by gen): %s", len(_debug_graph_rows), [(r[1], r[0][:16]) for r in _debug_graph_rows]) want_tip_snap_q = await session.execute( select(MusehubCommitGraph.snapshot_id) .where(MusehubCommitGraph.commit_id.in_(_needed_cids)) .order_by(MusehubCommitGraph.generation.desc()) .limit(1) ) want_tip_snap_id = want_tip_snap_q.scalar_one_or_none() logger.warning("[GRAPH-DEBUG] wire_fetch_mpack: want_tip_snap_id=%s", want_tip_snap_id[:20] if want_tip_snap_id else "NONE") logger.info("[wire_fetch_mpack] step=2 want_tip_snap_id=%s (from CommitGraph) t=%.1fms", want_tip_snap_id[:16] if want_tip_snap_id else None, _ms()) if want_tip_snap_id: wt_blob_q = await session.execute( select(MusehubSnapshot.manifest_blob) .where(MusehubSnapshot.snapshot_id == want_tip_snap_id) ) wt_blob = wt_blob_q.scalar_one_or_none() if wt_blob: all_oids.update(v for v in _msgpack_local.unpackb(wt_blob, raw=False).values() if v) logger.warning("[GRAPH-DEBUG] wire_fetch_mpack: want_tip manifest all_oids=%d wt_blob_present=%s", len(all_oids), wt_blob is not None) logger.info("[wire_fetch_mpack] step=2 want_tip manifest all_oids=%d wt_blob_present=%s t=%.1fms", len(all_oids), wt_blob is not None, _ms()) else: logger.warning( "[wire_fetch_mpack] step=2 WARN want_tip_snap_id=None — CommitGraph missing tip " "needed_cids=%s commit_rows_snap_ids=%s", [cid[:16] for cid in list(needed_rows.keys())[:5]], [sid[:16] for sid in _commit_row_snap_ids[:5]], ) have_oids: set[str] = set() if have_set: ht_snap_q = await session.execute( select(MusehubCommitGraph.snapshot_id) .where(MusehubCommitGraph.commit_id.in_(list(have_set))) .order_by(MusehubCommitGraph.generation.desc()) .limit(1) ) have_tip_snap_id = ht_snap_q.scalar_one_or_none() if have_tip_snap_id: ht_blob_q = await session.execute( select(MusehubSnapshot.manifest_blob) .where(MusehubSnapshot.snapshot_id == have_tip_snap_id) ) ht_blob = ht_blob_q.scalar_one_or_none() if ht_blob: have_oids.update(v for v in _msgpack_local.unpackb(ht_blob, raw=False).values() if v) new_oids = all_oids - have_oids logger.info( "[wire_fetch_mpack] step=2 done snap_map=%d all_oids=%d have_oids=%d new_oids=%d t=%.1fms", len(snap_map), len(all_oids), len(have_oids), len(new_oids), _ms(), ) if new_oids: indexed_q = await session.execute( select(MusehubMPackIndex.entity_id) .where(MusehubMPackIndex.entity_id.in_(list(new_oids))) .where(MusehubMPackIndex.entity_type == "object") ) indexed_oids = {row[0] for row in indexed_q} missing = new_oids - indexed_oids if missing: logger.warning( "[wire_fetch_mpack] step=3 NOT INDEXED %d/%d objects — raising FetchNotIndexedError t=%.1fms", len(missing), len(new_oids), _ms(), ) raise FetchNotIndexedError(len(missing)) logger.info("[wire_fetch_mpack] step=3 index coverage OK oids=%d t=%.1fms", len(new_oids), _ms()) cache_hits: dict[str, bytes] = {} if new_oids: _CACHE_CHUNK = 10000 _new_oid_list = list(new_oids) for _ci in range(0, len(_new_oid_list), _CACHE_CHUNK): _chunk = _new_oid_list[_ci : _ci + _CACHE_CHUNK] _cache_q = await session.execute( select(MusehubObject.object_id, MusehubObject.content_cache) .where(MusehubObject.object_id.in_(_chunk)) .where(MusehubObject.content_cache.isnot(None)) ) for _oid, _cached in _cache_q: if _cached: cache_hits[_oid] = bytes(_cached) cache_miss_oids = [oid for oid in new_oids if oid not in cache_hits] oid_to_mpack: dict[str, str] = {} if cache_miss_oids: _MIDX_CHUNK = 10000 for _ci in range(0, len(cache_miss_oids), _MIDX_CHUNK): _chunk = cache_miss_oids[_ci : _ci + _MIDX_CHUNK] _midx_q = await session.execute( select(MusehubMPackIndex.entity_id, MusehubMPackIndex.mpack_id) .where(MusehubMPackIndex.entity_id.in_(_chunk)) .where(MusehubMPackIndex.entity_type == "object") ) for _oid, _mid in _midx_q: oid_to_mpack[_oid] = _mid mpack_to_oids: dict[str, list[str]] = {} no_mpack_oids: list[str] = [] for oid in cache_miss_oids: mid = oid_to_mpack.get(oid) if mid: mpack_to_oids.setdefault(mid, []).append(oid) else: no_mpack_oids.append(oid) mpack_hits: dict[str, bytes] = {} mpack_miss_oids: list[str] = [] _sem_mpack = asyncio.Semaphore(8) async def _extract_from_mpack(mpack_id: str, oids: list[str]) -> None: async with _sem_mpack: raw = await backend.get_mpack(mpack_id) if raw is None: mpack_miss_oids.extend(oids) return import zstandard as _zstd_phase1 _dctx_phase1 = _zstd_phase1.ZstdDecompressor() try: if raw[:4] == b"MUSE": from muse.core.mpack import parse_wire_mpack as _parse_wire_fetch payload = _parse_wire_fetch(raw) else: payload = _msgpack_local.unpackb(raw, raw=False) except Exception as _parse_err: logger.warning( "[_extract_from_mpack] failed to parse mpack=%s: %s", mpack_id[:20], _parse_err, ) mpack_miss_oids.extend(oids) return obj_index: dict[str, bytes] = {} for o in payload.get("blobs", []): oid_entry = o.get("object_id", "") content = o.get("content") or b"" if not isinstance(content, bytes): content = bytes(content) _ZSTD_MAGIC = b"\x28\xb5\x2f\xfd" if (o.get("encoding") == "zstd" or content[:4] == _ZSTD_MAGIC) and content: try: content = _dctx_phase1.decompress(content) except Exception as _decomp_err: logger.warning( "[_extract_from_mpack] zstd decompress failed oid=%s: %s", oid_entry[:20], _decomp_err, ) continue obj_index[oid_entry] = content for oid in oids: content = obj_index.get(oid) if content is not None: mpack_hits[oid] = content else: mpack_miss_oids.append(oid) if mpack_to_oids: await asyncio.gather( *(_extract_from_mpack(mid, oids) for mid, oids in mpack_to_oids.items()) ) legacy_hits: dict[str, bytes] = {} _fallback_oids = no_mpack_oids + mpack_miss_oids if _fallback_oids: _sem_legacy = asyncio.Semaphore(50) async def _get_legacy(oid: str) -> None: async with _sem_legacy: data = await backend.get(oid) if data: legacy_hits[oid] = data await asyncio.gather(*(_get_legacy(oid) for oid in _fallback_oids)) _all_blob_bytes: dict[str, bytes] = {**legacy_hits, **mpack_hits, **cache_hits} blob_pairs = [(oid, _all_blob_bytes[oid]) for oid in new_oids if oid in _all_blob_bytes] logger.info( "[wire_fetch_mpack] step=4 fetched %d blobs (cache=%d mpack=%d legacy=%d) t=%.1fms", len(blob_pairs), len(cache_hits), len(mpack_hits), len(legacy_hits), _ms(), ) wire_commits = [ (await _commit_to_wire_s3(row, backend)).model_dump() for row in commit_rows.values() ] wire_snaps = [snap_map[sid] for sid in snap_ids if sid in snap_map] wire_blobs = [ {"object_id": oid, "content": data} for oid, data in blob_pairs if data ] logger.info( "[wire_fetch_mpack] step=5 assembly: wire_commits=%d wire_snaps=%d wire_blobs=%d " "snap_ids_total=%d snap_ids_in_map=%d commit_rows=%d t=%.1fms", len(wire_commits), len(wire_snaps), len(wire_blobs), len(snap_ids), sum(1 for sid in snap_ids if sid in snap_map), len(commit_rows), _ms(), ) from muse.core.mpack import build_wire_mpack as _build_wire_mpack _head_commit_id = want[0] if want else "" mpack_bytes = _build_wire_mpack( { "commits": wire_commits, "snapshots": wire_snaps, "blobs": wire_blobs, "tags": [], }, meta={"repo_id": repo_id, "head_commit_id": _head_commit_id}, ) mpack_id = blob_id(mpack_bytes) n_commits = len(wire_commits) n_blobs = len(wire_blobs) logger.info( "[wire_fetch_mpack] step=5 assembled commits=%d snapshots=%d blobs=%d bytes=%d t=%.1fms", n_commits, len(wire_snaps), n_blobs, len(mpack_bytes), _ms(), ) await backend.put_mpack(mpack_id, mpack_bytes) mpack_url = await backend.presign_mpack_get(mpack_id, ttl_seconds) logger.info( "[wire_fetch_mpack] step=6 mpack_id=%s mpack_url=%s t=%.1fms", mpack_id[:20], mpack_url[:80] if mpack_url else None, _ms(), ) logger.info("[wire_fetch_mpack] RETURN commits=%d blobs=%d TOTAL=%.1fms", n_commits, n_blobs, _ms()) async def _cleanup() -> None: await asyncio.sleep(ttl_seconds) try: await backend.delete(mpack_id) except Exception: pass asyncio.ensure_future(_cleanup()) return { "mpack_url": mpack_url, "mpack_id": mpack_id, "commit_count": n_commits, "blob_count": n_blobs, } async def _check_missing_objects( session: AsyncSession, needs_check: set[str], ) -> set[str]: if not needs_check: return set() from musehub.db.musehub_repo_models import MusehubObject registered: set[str] = set( (await session.execute( select(MusehubObject.object_id).where( MusehubObject.object_id.in_(list(needs_check)), MusehubObject.deleted_at.is_(None), ) )).scalars().all() ) return needs_check - registered class MPackGCResult(TypedDict): skipped: bool packs_before: int packs_after: int consolidated_key: str async def process_mpack_gc_job(session: AsyncSession, repo_id: str) -> MPackGCResult: import msgpack as _mp _skipped: MPackGCResult = { "skipped": True, "packs_before": 0, "packs_after": 0, "consolidated_key": "", } repo_oids_q = await session.execute( select(MusehubObjectRef.object_id) .where(MusehubObjectRef.repo_id == repo_id) ) repo_oid_set = {row[0] for row in repo_oids_q} mpack_q = await session.execute( select(MusehubMPackIndex.mpack_id) .where(MusehubMPackIndex.entity_id.in_(list(repo_oid_set))) .where(MusehubMPackIndex.entity_type == "object") .distinct() ) mpack_ids = [row[0] for row in mpack_q] packs_before = len(mpack_ids) if packs_before <= 1: _skipped["packs_before"] = packs_before if mpack_ids: _skipped["consolidated_key"] = mpack_ids[0] return _skipped import musehub.storage.backends as _backends_mod backend = _backends_mod.get_backend() merged_objects: dict[str, bytes] = {} async def _download(pid: str) -> None: raw = await backend.get_mpack(pid) if not raw: logger.warning("[mpack_gc] mpack not found in storage: %s", pid) return if raw[:4] == b"MUSE": from muse.core.mpack import parse_wire_mpack as _parse_gc _parsed = _parse_gc(raw) else: _parsed = _mp.unpackb(raw, raw=False) for obj in _parsed.get("blobs", []): oid = obj.get("object_id", "") content = obj.get("content", b"") if oid and oid not in merged_objects: merged_objects[oid] = content await asyncio.gather(*(_download(pid) for pid in mpack_ids)) from muse.core.mpack import build_wire_mpack as _build_gc_mpack consolidated_bytes = _build_gc_mpack({ "commits": [], "snapshots": [], "blobs": [ {"object_id": oid, "content": merged_objects[oid]} for oid in sorted(merged_objects) ], "tags": [], }) consolidated_key = "sha256:" + hashlib.sha256(consolidated_bytes).hexdigest() await backend.put_mpack(consolidated_key, consolidated_bytes) old_mpack_ids = [p for p in mpack_ids if p != consolidated_key] if old_mpack_ids: from sqlalchemy import delete as sa_delete await session.execute( sa_delete(MusehubMPackIndex) .where(MusehubMPackIndex.mpack_id.in_(old_mpack_ids)) .where(MusehubMPackIndex.entity_type == "object") ) _gc_now = datetime.now(timezone.utc) new_rows = [ { "entity_id": oid, "mpack_id": consolidated_key, "entity_type": "object", "created_at": _gc_now, } for oid in merged_objects ] if new_rows: _GC_MIDX_CHUNK = 5000 for _gmi in range(0, len(new_rows), _GC_MIDX_CHUNK): await session.execute( _pg_insert(MusehubMPackIndex) .values(new_rows[_gmi : _gmi + _GC_MIDX_CHUNK]) .on_conflict_do_nothing(index_elements=["entity_id", "mpack_id"]) ) logger.info( "[mpack_gc] repo=%s consolidated %d mpacks → 1 (objects=%d key=%s)", repo_id, packs_before, len(merged_objects), consolidated_key, ) return { "skipped": False, "packs_before": packs_before, "packs_after": 1, "consolidated_key": consolidated_key, } class FetchResult(TypedDict): mpack_id: str mpack_url: str | None commit_count: int blob_count: int class FetchMPackPrebuildResult(TypedDict): tips_requested: int tips_built: int tips_skipped: int elapsed_ms: float async def process_fetch_mpack_prebuild_job( session: AsyncSession, job_id: str, ) -> FetchMPackPrebuildResult: """Build and cache a fetch mpack for every branch tip in the job payload. Called by the background worker after every push. For each tip commit ID in ``payload["tip_commit_ids"]``, checks whether a fresh cache entry already exists in ``musehub_fetch_mpack_cache``; skips tips that are cached and builds the rest by calling ``wire_fetch_mpack``. The mpack_id returned by ``wire_fetch_mpack`` is written (or upserted) into ``musehub_fetch_mpack_cache`` so that subsequent fetch requests hit the cache and return a presigned URL in under a second. """ from datetime import timedelta from musehub.db.musehub_jobs_models import MusehubBackgroundJob from sqlalchemy.dialects.postgresql import insert as _upsert _t0 = _time_module.monotonic() def _ms() -> float: return (_time_module.monotonic() - _t0) * 1000 job_row = (await session.execute( select(MusehubBackgroundJob).where(MusehubBackgroundJob.job_id == job_id) )).scalar_one_or_none() if job_row is None: raise ValueError(f"fetch.mpack.prebuild job not found: {job_id}") repo_id: str = job_row.repo_id payload = job_row.payload or {} tip_commit_ids: list[str] = [str(t) for t in (payload.get("tip_commit_ids") or [])] if not tip_commit_ids: logger.warning("[fetch.mpack.prebuild] job=%s repo=%s no tip_commit_ids in payload", job_id[:16], repo_id) return {"tips_requested": 0, "tips_built": 0, "tips_skipped": 0, "elapsed_ms": 0.0} # Find which tips already have a fresh (non-expired) cache entry. now = _utc_now() cached_q = await session.execute( select(MusehubFetchMPackCache.tip_commit_id) .where(MusehubFetchMPackCache.repo_id == repo_id) .where(MusehubFetchMPackCache.tip_commit_id.in_(tip_commit_ids)) .where(MusehubFetchMPackCache.expires_at > now) ) already_cached: set[str] = {row[0] for row in cached_q} tips_to_build = [t for t in tip_commit_ids if t not in already_cached] tips_skipped = len(already_cached) logger.warning( "[fetch.mpack.prebuild] job=%s repo=%s tips=%d cached=%d to_build=%d t=%.1fms", job_id[:16], repo_id, len(tip_commit_ids), tips_skipped, len(tips_to_build), _ms(), ) tips_built = 0 for tip in tips_to_build: _tip_t0 = _time_module.monotonic() try: result = await wire_fetch_mpack(session, repo_id, want=[tip], have=[]) mpack_id = result.get("mpack_id") or "" if not mpack_id: logger.warning( "[fetch.mpack.prebuild] tip=%s produced no mpack_id — skipping cache write", tip[:20], ) continue cache_id = blob_id((repo_id + tip).encode()).replace("sha256:", "") expires_at = now + timedelta(days=7) await session.execute( _upsert(MusehubFetchMPackCache) .values( cache_id=cache_id, repo_id=repo_id, tip_commit_id=tip, mpack_id=mpack_id, created_at=now, expires_at=expires_at, ) .on_conflict_do_update( index_elements=["repo_id", "tip_commit_id"], set_={"mpack_id": mpack_id, "expires_at": expires_at}, ) ) tips_built += 1 _tip_ms = (_time_module.monotonic() - _tip_t0) * 1000 logger.warning( "[fetch.mpack.prebuild] built tip=%s mpack_id=%s t=%.1fms", tip[:20], mpack_id[:20], _tip_ms, ) except Exception as exc: logger.error( "[fetch.mpack.prebuild] tip=%s FAILED: %s", tip[:20], exc, exc_info=True, ) total_ms = _ms() logger.warning( "[fetch.mpack.prebuild] DONE job=%s repo=%s tips=%d built=%d skipped=%d TOTAL=%.1fms", job_id[:16], repo_id, len(tip_commit_ids), tips_built, tips_skipped, total_ms, ) return { "tips_requested": len(tip_commit_ids), "tips_built": tips_built, "tips_skipped": tips_skipped, "elapsed_ms": total_ms, } class FetchCommitNotFound(Exception): """A want commit_id does not exist in musehub_commits.""" class FetchNotReady(Exception): """Needed objects are absent from musehub_mpack_index — client must retry.""" async def wire_fetch( session: AsyncSession, repo_id: str, want: list[str], have: list[str], ttl_seconds: int = 3600, ) -> FetchResult: import msgpack as _mp _empty: FetchResult = {"mpack_id": "", "mpack_url": None, "commit_count": 0, "blob_count": 0} for entry in want: if not (isinstance(entry, str) and entry.startswith("sha256:")): raise MPackValidationError(f"want entry is not a sha256: id: {entry!r}") for entry in have: if not (isinstance(entry, str) and entry.startswith("sha256:")): raise MPackValidationError(f"have entry is not a sha256: id: {entry!r}") if want: existing_q = await session.execute( select(MusehubCommit.commit_id).where(MusehubCommit.commit_id.in_(want)) ) found = {row[0] for row in existing_q} missing_want = [cid for cid in want if cid not in found] if missing_want: raise FetchCommitNotFound(missing_want[0]) have_set = set(have) needed = await _walk_commit_delta(session, want, have_set) if not needed: return _empty cids = list(needed.keys()) commit_rows: dict[str, MusehubCommit] = {} for i in range(0, len(cids), 2000): q = await session.execute( select(MusehubCommit).where(MusehubCommit.commit_id.in_(cids[i:i + 2000])) ) for row in q.scalars(): commit_rows[row.commit_id] = row want_snap_ids = {r.snapshot_id for r in needed.values() if r.snapshot_id} have_snap_ids: set[str] = set() if have_set: have_commits_q = await session.execute( select(MusehubCommit.snapshot_id).where(MusehubCommit.commit_id.in_(list(have_set))) ) have_snap_ids = {row[0] for row in have_commits_q if row[0]} new_snap_ids = want_snap_ids - have_snap_ids snap_map: dict[str, dict] = {} new_oids: set[str] = set() backend = get_backend() if new_snap_ids: snaps_q = await session.execute( select(MusehubSnapshot).where(MusehubSnapshot.snapshot_id.in_(list(new_snap_ids))) ) for snap in snaps_q.scalars(): manifest = ( _mp.unpackb(snap.manifest_blob, raw=False) if snap.manifest_blob else await _reconstruct_manifest(session, snap.snapshot_id) ) new_oids.update(v for v in manifest.values() if v) snap_map[snap.snapshot_id] = await _snap_row_to_wire_s3(snap, backend, session=session) if have_snap_ids: have_snaps_q = await session.execute( select(MusehubSnapshot).where(MusehubSnapshot.snapshot_id.in_(list(have_snap_ids))) ) for snap in have_snaps_q.scalars(): m = ( _mp.unpackb(snap.manifest_blob, raw=False) if snap.manifest_blob else await _reconstruct_manifest(session, snap.snapshot_id) ) new_oids -= {v for v in m.values() if v} if new_oids: idx_q = await session.execute( select(MusehubMPackIndex.entity_id).where( MusehubMPackIndex.entity_id.in_(list(new_oids)), MusehubMPackIndex.entity_type == "object", ) ) indexed = {row[0] for row in idx_q} unindexed = new_oids - indexed if unindexed: raise FetchNotReady(f"{len(unindexed)} object(s) not yet in mpack_index") objects: list[dict] = [] if new_oids: obj_q = await session.execute( select(MusehubObject).where(MusehubObject.object_id.in_(list(new_oids))) ) for obj_row in obj_q.scalars(): if obj_row.content_cache is not None: content = obj_row.content_cache else: content = await backend.get(obj_row.object_id) or b"" objects.append({"object_id": obj_row.object_id, "content": content}) wire_commits = [_to_wire_commit(r).model_dump() for r in commit_rows.values()] from muse.core.mpack import build_wire_mpack as _build_fetch_mpack wire_bytes = _build_fetch_mpack({ "commits": wire_commits, "snapshots": list(snap_map.values()), "blobs": objects, "tags": [], }) mpack_id = blob_id(wire_bytes) await backend.put_mpack(mpack_id, wire_bytes) mpack_url = await backend.presign_mpack_get(mpack_id, ttl_seconds) return { "mpack_id": mpack_id, "mpack_url": mpack_url, "commit_count": len(commit_rows), "blob_count": len(objects), }