#!/usr/bin/env python3 """One-time backfill: decompress zlib-stored objects in R2 and replace with plain bytes. Objects pushed via the old wire path were stored zlib-compressed in R2 under the SHA-256 of their *plain* content. This violates content-addressing: the declared identity (SHA-256 of plain bytes) does not match the stored bytes (compressed). This script corrects all such objects: 1. Pages through musehub_objects rows where storage_uri starts with "s3://". 2. Fetches each object from R2. 3. Skips objects that are already plain bytes. 4. For zlib-compressed objects: a. Decompresses. b. Verifies SHA-256(decompressed) == object_id. Skips on mismatch. c. Re-uploads plain bytes to R2 (same key — idempotent). d. Updates size_bytes in DB (content_cache stays NULL). 5. Reports totals and any errors. After a successful run, decompress_if_needed() is no longer needed on the read path — all objects in R2 are guaranteed to be plain bytes. Run inside Docker on the target instance: docker exec musehub-blue python3 /app/deploy/decompress_objects.py [--dry-run] [--batch 200] [--concurrency 16] Options: --dry-run Print what would be changed without touching R2 or the DB. --batch N DB page size (default 200). --concurrency N Parallel R2 fetches per batch (default 16). --repo-id UUID Limit to a single repo (targeted fix). """ from __future__ import annotations import argparse import asyncio import sys import time import zlib import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker from muse.core.types import blob_id, split_id from musehub.config import settings from musehub.db.musehub_repo_models import MusehubObject, MusehubObjectRef from musehub.storage import get_backend _ZLIB_MAGIC = (b"\x78\x01", b"\x78\x9c", b"\x78\xda", b"\x78\x5e") def _is_zlib(data: bytes) -> bool: return len(data) >= 2 and data[:2] in _ZLIB_MAGIC def _decompress(data: bytes) -> bytes | None: try: return zlib.decompress(data) except zlib.error: return None def _fmt_eta(seconds: float) -> str: if seconds < 60: return f"{seconds:.0f}s" if seconds < 3600: return f"{seconds / 60:.1f}m" return f"{seconds / 3600:.1f}h" async def _get_header(backend: object, object_id: str) -> bytes | None: """Return the first 2 bytes of an object using a Range GET.""" client = backend._get_client() # type: ignore[attr-defined] key = backend._key(object_id) # type: ignore[attr-defined] def _range_get() -> bytes | None: try: resp = client.get_object( Bucket=backend._bucket, Key=key, Range="bytes=0-1" # type: ignore[attr-defined] ) return resp["Body"].read(2) except Exception: return None return await asyncio.to_thread(_range_get) async def backfill(dry_run: bool, quiet: bool, batch_size: int, concurrency: int, repo_id: str | None) -> int: engine = create_async_engine(settings.database_url, echo=False) async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) backend = get_backend() # ── count total objects up front so we know the denominator ────────────── async with async_session() as session: count_stmt = ( sa.select(sa.func.count()).select_from(MusehubObject) .where( MusehubObject.storage_uri.like("s3://%"), MusehubObject.deleted_at.is_(None), ) ) if repo_id: count_stmt = ( sa.select(sa.func.count()).select_from(MusehubObject) .join(MusehubObjectRef, MusehubObject.object_id == MusehubObjectRef.object_id) .where( MusehubObjectRef.repo_id == repo_id, MusehubObject.storage_uri.like("s3://%"), MusehubObject.deleted_at.is_(None), ) ) total_objects: int = (await session.execute(count_stmt)).scalar_one() scope = f"repo_id={repo_id}" if repo_id else "all repos" print(f"Backfill scope: {scope}") print(f"Total objects to scan: {total_objects:,}") if total_objects == 0: print("Nothing to do.") return 0 print() # ── shared progress state (updated inside asyncio tasks) ───────────────── done_count = 0 plain_count = 0 decompressed_count = 0 hash_mismatch_count = 0 error_count = 0 start_time = time.monotonic() progress_lock = asyncio.Lock() def _progress_line(extra: str = "") -> None: if quiet: return elapsed = time.monotonic() - start_time rate = done_count / elapsed if elapsed > 0 else 0 remaining = total_objects - done_count eta_str = _fmt_eta(remaining / rate) if rate > 0 else "?" pct = 100 * done_count / total_objects if total_objects else 100 print( f"\r [{done_count:>{len(str(total_objects))}}/{total_objects}]" f" {pct:5.1f}%" f" {remaining:,} remaining" f" {rate:.1f} obj/s" f" ETA {eta_str}" + (f" {extra}" if extra else ""), end="", flush=True, ) total_checked = 0 total_errors = 0 async with async_session() as session: offset = 0 while True: obj_stmt = ( sa.select(MusehubObject.object_id) .where( MusehubObject.storage_uri.like("s3://%"), MusehubObject.deleted_at.is_(None), ) ) if repo_id: obj_stmt = ( sa.select(MusehubObject.object_id) .join(MusehubObjectRef, MusehubObject.object_id == MusehubObjectRef.object_id) .where( MusehubObjectRef.repo_id == repo_id, MusehubObject.storage_uri.like("s3://%"), MusehubObject.deleted_at.is_(None), ) ) rows = (await session.execute( obj_stmt .order_by(MusehubObject.object_id) .offset(offset) .limit(batch_size) )).scalars().all() if not rows: break offset += len(rows) total_checked += len(rows) sem = asyncio.Semaphore(concurrency) async def _process(oid: str) -> tuple[str, str, int]: nonlocal done_count, plain_count, decompressed_count nonlocal hash_mismatch_count, error_count async with sem: status = "plain" new_size = 0 detail = "" try: header = await _get_header(backend, oid) except Exception as exc: print(f"\n ERROR fetching header {oid}: {exc}", file=sys.stderr) async with progress_lock: done_count += 1 error_count += 1 _progress_line() return oid, "error", 0 if header is None or not _is_zlib(header): async with progress_lock: done_count += 1 plain_count += 1 _progress_line() return oid, "plain", 0 # Has zlib header — fetch full object. try: data = await backend.get(oid) except Exception as exc: print(f"\n ERROR fetching {oid}: {exc}", file=sys.stderr) async with progress_lock: done_count += 1 error_count += 1 _progress_line() return oid, "error", 0 if data is None: async with progress_lock: done_count += 1 plain_count += 1 _progress_line() return oid, "plain", 0 decompressed = _decompress(data) if decompressed is None: detail = f"zlib header but decompress failed — skipping" async with progress_lock: done_count += 1 error_count += 1 _progress_line(f"WARN {oid} {detail}") return oid, "error", 0 _, bare_hex = split_id(oid) if blob_id(decompressed) != oid: _, actual = split_id(blob_id(decompressed)) detail = f"hash mismatch (declared={bare_hex[:12]}… actual={actual[:12]}…)" async with progress_lock: done_count += 1 hash_mismatch_count += 1 _progress_line(f"WARN {oid} {detail}") return oid, "hash_mismatch", 0 new_size = len(decompressed) verb = "[dry] decompress" if dry_run else "decompress" async with progress_lock: done_count += 1 decompressed_count += 1 _progress_line(f"{verb} {oid} ({len(data)} → {new_size} bytes)") if dry_run: return oid, "decompressed", new_size try: await backend.put(oid, decompressed) except Exception as exc: print(f"\n ERROR re-uploading {oid}: {exc}", file=sys.stderr) async with progress_lock: error_count += 1 return oid, "error", 0 return oid, "decompressed", new_size r2_results = await asyncio.gather(*(_process(oid) for oid in rows)) # DB updates for successfully decompressed objects. if not dry_run: for oid, status, new_size in r2_results: if status != "decompressed": continue try: await session.execute( sa.update(MusehubObject) .where(MusehubObject.object_id == oid) .values(size_bytes=new_size) ) await session.commit() except Exception as exc: print(f"\n ERROR updating DB for {oid}: {exc}", file=sys.stderr) await session.rollback() async with progress_lock: error_count += 1 # Final newline after the inline progress line. if not quiet: print() elapsed = time.monotonic() - start_time prefix = "[dry-run] " if dry_run else "" print( f"\n{prefix}Backfill complete ({elapsed:.1f}s):\n" f" {total_checked:6,} objects checked\n" f" {plain_count:6,} already plain (skipped)\n" f" {decompressed_count:6,} decompressed and re-uploaded\n" f" {hash_mismatch_count:6,} skipped (hash mismatch after decompress)\n" f" {error_count:6,} errors" ) return error_count def main() -> None: parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument("--dry-run", action="store_true", help="Print what would change without touching R2 or the DB") parser.add_argument("--quiet", action="store_true", help="Suppress per-object progress; only print final summary") parser.add_argument("--batch", type=int, default=200, metavar="N", help="DB page size (default 200)") parser.add_argument("--concurrency", type=int, default=16, metavar="N", help="Parallel R2 fetches per batch (default 16)") parser.add_argument("--repo-id", default=None, metavar="UUID", help="Limit to a single repo_id (for targeted testing)") args = parser.parse_args() errors = asyncio.run(backfill( dry_run=args.dry_run, quiet=args.quiet, batch_size=args.batch, concurrency=args.concurrency, repo_id=args.repo_id, )) sys.exit(1 if errors else 0) if __name__ == "__main__": main()