gabriel / musehub public
test_coordination.py python
1,158 lines 46.7 KB
Raw
sha256:3ff9c9863a9891bdcde71b4a43228f66d0493e38b7cc1d09fe9eb7de774046b2 feat: add repair-commit wire endpoint (API parity with repa… Opus 4.8 minor ⚠ breaking 1 day ago
1 """Section 7 — Coordination (muse coord): 7-layer test suite.
2
3 Covers:
4 - musehub/api/routes/coord.py (push_coord, pull_coord, watch_coord HTTP handlers,
5 _resolve_repo, _assert_readable, _assert_writable)
6 - musehub/services/musehub_coord.py (coord_push, coord_pull, coord_watch_stream,
7 _row_to_out, write-once semantics, heartbeat upsert)
8 - musehub/services/musehub_coord_server.py (materialize_coord_record, list_reservations,
9 conflict_check, extend_reservation,
10 list_tasks, claim_task, complete_task, fail_task)
11 - musehub/models/coord.py (CoordRecordIn validators, CoordPushRequest,
12 CoordPollRequest)
13 - musehub/db/coord_models.py (MusehubCoordRecord, MusehubCoordReservation,
14 MusehubCoordTask)
15
16 Layers:
17 1. Unit — model validators, pure helpers, no DB
18 2. Integration — real DB (PostgreSQL), service-layer calls, no HTTP
19 3. End-to-End — full HTTP via AsyncClient, real DB
20 4. Stress — 500-record push, 100 tasks, cursor pagination
21 5. Data Integrity — write-once, heartbeat upsert, constraint enforcement,
22 task lifecycle state machine
23 6. Security — auth guards, ownership enforcement, private-repo 404,
24 invalid kind rejection
25 7. Performance — latency budgets for push/pull/materialize
26 """
27 import asyncio
28 import secrets
29 import time
30 from datetime import datetime, timedelta, timezone
31
32 import pytest
33 import pytest_asyncio
34 from httpx import AsyncClient
35 from sqlalchemy.ext.asyncio import AsyncSession
36
37 from musehub.core.genesis import compute_identity_id
38 from musehub.types.json_types import JSONObject, StrDict
39 from musehub.models.coord import (
40 CoordPollRequest,
41 CoordPushRequest,
42 CoordRecordIn,
43 _VALID_KINDS,
44 )
45 from muse.core.types import fake_id
46 from tests.factories import create_repo
47
48 # ---------------------------------------------------------------------------
49 # Local helpers
50 # ---------------------------------------------------------------------------
51
def _now() -> datetime:
    """Return the current instant as a timezone-aware datetime in UTC."""
    utc = timezone.utc
    return datetime.now(utc)
54
55
def _new_id() -> str:
    """Return a fresh random identifier: 16 random bytes as 32 lowercase hex chars."""
    raw = secrets.token_bytes(16)
    return raw.hex()
58
59
def _record(
    kind: str = "intent",
    record_id: str | None = None,
    run_id: str = "agent-1",
    payload: JSONObject | None = None,
    expires_at: datetime | None = None,
) -> CoordRecordIn:
    """Build a ``CoordRecordIn`` with test-friendly defaults.

    Args:
        kind: Record kind, passed to the model unchanged (the model validates it).
        record_id: Explicit record id; ``None`` generates a fresh random id.
        run_id: Agent/run identifier stored on the record.
        payload: Record payload; ``None`` yields a small default payload derived
            from *kind*. Any other value, including ``{}``, is used as-is.
        expires_at: Optional expiry timestamp.

    Returns:
        The constructed ``CoordRecordIn``; model validation errors propagate.
    """
    # Compare against None explicitly: ``value or default`` would silently
    # replace deliberately falsy arguments such as ``payload={}`` or
    # ``record_id=""``, so tests could never exercise the model with them.
    return CoordRecordIn(
        kind=kind,
        record_id=_new_id() if record_id is None else record_id,
        run_id=run_id,
        payload={"action": kind, "data": "test"} if payload is None else payload,
        expires_at=expires_at,
    )
74
75
def _push_body(*records: CoordRecordIn) -> JSONObject:
    """Serialise *records* (JSON mode) into the body for a ``coord/push`` request."""
    serialised = [item.model_dump(mode="json") for item in records]
    return {"records": serialised}
78
79
def _pull_body(since_id: int = 0, kinds: list[str] | None = None, limit: int = 500) -> JSONObject:
    """Build the JSON body for a ``coord/pull`` request.

    ``kinds`` of ``None`` (or an empty list) is sent as an empty list.
    """
    selected_kinds = kinds if kinds else []
    return dict(since_id=since_id, kinds=selected_kinds, limit=limit)
82
83
84 # ===========================================================================
85 # Layer 1 — Unit tests (model validators, pure helpers)
86 # ===========================================================================
87
class TestCoordRecordInValidators:
    """Field validation on ``CoordRecordIn``: kind, record_id, run_id, expires_at."""

    # NOTE(review): the ``pytest.raises(Exception)`` checks below are broad;
    # narrowing them to the model library's validation error would stop
    # unrelated exceptions from passing as "validation failed".

    def test_valid_kinds_accepted(self) -> None:
        """Every kind listed in ``_VALID_KINDS`` constructs a record."""
        for kind in _VALID_KINDS:
            r = _record(kind=kind)
            assert r.kind == kind

    def test_invalid_kind_raises(self) -> None:
        """An unknown kind is rejected at construction time."""
        # Uses the module-level ``pytest`` import.
        with pytest.raises(Exception):
            _record(kind="garbage")

    def test_invalid_record_id_rejected(self) -> None:
        """A record_id containing spaces and slashes is rejected."""
        with pytest.raises(Exception):
            _record(record_id="has spaces and slashes/bad")

    def test_uppercase_record_id_accepted(self) -> None:
        """Upper-case hex ids are accepted and stored verbatim."""
        uid = secrets.token_hex(16).upper()
        r = _record(record_id=uid)
        assert r.record_id == uid

    def test_run_id_empty_string_allowed(self) -> None:
        """An empty run_id is a valid value."""
        r = _record(run_id="")
        assert r.run_id == ""

    def test_expires_at_optional(self) -> None:
        """Omitting expires_at leaves it as None."""
        r = _record()
        assert r.expires_at is None

    def test_expires_at_accepted(self) -> None:
        """A future timezone-aware expiry is accepted."""
        exp = _now() + timedelta(seconds=300)
        r = _record(expires_at=exp)
        assert r.expires_at is not None
120
121
class TestCoordPushRequestValidators:
    """Batch-size bounds on ``CoordPushRequest.records`` (1 to 500 inclusive)."""

    def test_empty_records_rejected(self) -> None:
        with pytest.raises(Exception):
            CoordPushRequest(records=[])

    def test_max_500_records_accepted(self) -> None:
        batch = [_record() for _ in range(500)]
        assert len(CoordPushRequest(records=batch).records) == 500

    def test_501_records_rejected(self) -> None:
        # The list is built inside the ``raises`` block, exactly as before.
        with pytest.raises(Exception):
            CoordPushRequest(records=[_record() for _ in range(501)])

    def test_single_record_accepted(self) -> None:
        single = CoordPushRequest(records=[_record()])
        assert len(single.records) == 1
139
140
class TestCoordPollRequestValidators:
    """Defaults and bounds on ``CoordPollRequest`` (since_id, kinds, limit)."""

    def test_default_values(self) -> None:
        poll = CoordPollRequest()
        assert (poll.since_id, poll.kinds, poll.limit) == (0, [], 500)

    def test_since_id_must_be_non_negative(self) -> None:
        with pytest.raises(Exception):
            CoordPollRequest(since_id=-1)

    def test_invalid_kind_in_pull_rejected(self) -> None:
        with pytest.raises(Exception):
            CoordPollRequest(kinds=["garbage"])

    def test_valid_kinds_filter_accepted(self) -> None:
        poll = CoordPollRequest(kinds=["reservation", "task"])
        assert set(poll.kinds) == {"reservation", "task"}

    def test_limit_range(self) -> None:
        # Both inclusive edges are accepted...
        for edge in (1, 1000):
            assert CoordPollRequest(limit=edge).limit == edge
        # ...and one step outside either edge is rejected.
        for outside in (0, 1001):
            with pytest.raises(Exception):
                CoordPollRequest(limit=outside)
167
168
class TestValidKinds:
    """Pins the exact set of coordination record kinds."""

    def test_all_expected_kinds_present(self) -> None:
        expected = set(
            "reservation intent release heartbeat dependency task claim".split()
        )
        assert expected == _VALID_KINDS
174
175
176 # ===========================================================================
177 # Layer 2 — Integration tests (real DB, service layer, no HTTP)
178 # ===========================================================================
179
class TestCoordPushIntegration:
    """Service-level ``coord_push`` behaviour against a real database session."""

    @pytest.mark.asyncio
    async def test_push_inserts_records(self, db_session: AsyncSession) -> None:
        from musehub.services.musehub_coord import coord_push

        repo = await create_repo(db_session, slug="push-insert")
        batch = CoordPushRequest(records=[_record("intent"), _record("dependency")])
        outcome = await coord_push(db_session, repo.repo_id, batch)
        assert (outcome.inserted, outcome.skipped) == (2, 0)

    @pytest.mark.asyncio
    async def test_push_write_once_skips_duplicate(self, db_session: AsyncSession) -> None:
        from musehub.services.musehub_coord import coord_push

        repo = await create_repo(db_session, slug="push-writeonce")
        request = CoordPushRequest(records=[_record("intent")])

        first = await coord_push(db_session, repo.repo_id, request)
        assert first.inserted == 1

        # The identical request again: write-once means skip, not insert.
        second = await coord_push(db_session, repo.repo_id, request)
        assert second.skipped == 1
        assert second.inserted == 0

    @pytest.mark.asyncio
    async def test_push_heartbeat_upserts_payload(self, db_session: AsyncSession) -> None:
        from musehub.services.musehub_coord import coord_push, coord_pull

        repo = await create_repo(db_session, slug="push-hb-upsert")
        beat_id = _new_id()

        def beat(tick: int) -> CoordPushRequest:
            # Every heartbeat shares one record_id; only the payload changes.
            return CoordPushRequest(
                records=[_record("heartbeat", record_id=beat_id, payload={"tick": tick})]
            )

        first = await coord_push(db_session, repo.repo_id, beat(1))
        assert first.inserted == 1

        # Re-pushing the same heartbeat updates in place and counts as skipped.
        second = await coord_push(db_session, repo.repo_id, beat(2))
        assert second.skipped == 1

        pulled = await coord_pull(
            db_session, repo.repo_id, CoordPollRequest(kinds=["heartbeat"])
        )
        assert len(pulled.records) == 1
        assert pulled.records[0].payload["tick"] == 2

    @pytest.mark.asyncio
    async def test_push_all_valid_kinds(self, db_session: AsyncSession) -> None:
        from musehub.services.musehub_coord import coord_push

        repo = await create_repo(db_session, slug="push-all-kinds")
        outcome = await coord_push(
            db_session,
            repo.repo_id,
            CoordPushRequest(records=[_record(k) for k in _VALID_KINDS]),
        )
        assert outcome.inserted == len(_VALID_KINDS)

    @pytest.mark.asyncio
    async def test_push_materializes_reservation(self, db_session: AsyncSession) -> None:
        from musehub.services.musehub_coord import coord_push
        from musehub.services.musehub_coord_server import list_reservations

        repo = await create_repo(db_session, slug="push-materialize-res")
        source_id = _new_id()
        reservation_payload = {
            "reservation_id": fake_id(source_id),
            "run_id": "agent-42",
            "addresses": ["src/main.py::process"],
            "ttl_s": 300,
        }
        push = CoordPushRequest(
            records=[_record("reservation", record_id=source_id, payload=reservation_payload)]
        )
        await coord_push(db_session, repo.repo_id, push)

        reservations = await list_reservations(db_session, repo.repo_id)
        assert len(reservations) == 1
        only = reservations[0]
        assert only.symbol_address == "src/main.py::process"
        assert only.agent_id == "agent-42"

    @pytest.mark.asyncio
    async def test_push_materializes_task(self, db_session: AsyncSession) -> None:
        from musehub.services.musehub_coord import coord_push
        from musehub.services.musehub_coord_server import list_tasks

        repo = await create_repo(db_session, slug="push-materialize-task")
        source_id = _new_id()
        task_id = fake_id(source_id)
        task_payload = {
            "task_id": task_id,
            "queue": "ci",
            "priority": 10,
            "created_by": "dispatcher",
        }
        push = CoordPushRequest(
            records=[_record("task", record_id=source_id, payload=task_payload)]
        )
        await coord_push(db_session, repo.repo_id, push)

        tasks = await list_tasks(db_session, repo.repo_id)
        assert len(tasks) == 1
        task = tasks[0]
        assert (task.task_id, task.queue, task.priority, task.status) == (
            task_id,
            "ci",
            10,
            "pending",
        )
285
286
class TestCoordPullIntegration:
    """Service-level ``coord_pull``: empty repos, cursors, kind filters, order, limits."""

    @pytest.mark.asyncio
    async def test_pull_empty_returns_cursor_zero(self, db_session: AsyncSession) -> None:
        from musehub.services.musehub_coord import coord_pull

        repo = await create_repo(db_session, slug="pull-empty")
        page = await coord_pull(db_session, repo.repo_id, CoordPollRequest())
        assert page.records == []
        assert page.cursor == 0

    @pytest.mark.asyncio
    async def test_pull_returns_all_pushed_records(self, db_session: AsyncSession) -> None:
        from musehub.services.musehub_coord import coord_push, coord_pull

        repo = await create_repo(db_session, slug="pull-all")
        kinds = ("intent", "dependency", "heartbeat")
        await coord_push(
            db_session, repo.repo_id, CoordPushRequest(records=[_record(k) for k in kinds])
        )

        page = await coord_pull(db_session, repo.repo_id, CoordPollRequest())
        assert len(page.records) == 3
        # The returned cursor is the id of the last record on the page.
        assert page.cursor == page.records[-1].id

    @pytest.mark.asyncio
    async def test_pull_since_id_cursor_pagination(self, db_session: AsyncSession) -> None:
        from musehub.services.musehub_coord import coord_push, coord_pull

        repo = await create_repo(db_session, slug="pull-cursor")
        for _ in range(5):
            await coord_push(
                db_session, repo.repo_id, CoordPushRequest(records=[_record("intent")])
            )

        first_page = await coord_pull(db_session, repo.repo_id, CoordPollRequest(limit=3))
        assert len(first_page.records) == 3

        # Resuming from the first page's cursor yields the remaining two
        # records, all with ids strictly above that cursor.
        mark = first_page.cursor
        rest = await coord_pull(db_session, repo.repo_id, CoordPollRequest(since_id=mark))
        assert len(rest.records) == 2
        assert min(r.id for r in rest.records) > mark

    @pytest.mark.asyncio
    async def test_pull_kinds_filter(self, db_session: AsyncSession) -> None:
        from musehub.services.musehub_coord import coord_push, coord_pull

        repo = await create_repo(db_session, slug="pull-kinds-filter")
        mixed = [_record("intent"), _record("heartbeat"), _record("dependency")]
        await coord_push(db_session, repo.repo_id, CoordPushRequest(records=mixed))

        only_intents = await coord_pull(
            db_session, repo.repo_id, CoordPollRequest(kinds=["intent"])
        )
        assert [r.kind for r in only_intents.records] == ["intent"]

    @pytest.mark.asyncio
    async def test_pull_ordered_oldest_first(self, db_session: AsyncSession) -> None:
        from musehub.services.musehub_coord import coord_push, coord_pull

        repo = await create_repo(db_session, slug="pull-ordered")
        for _ in range(3):
            await coord_push(
                db_session, repo.repo_id, CoordPushRequest(records=[_record("intent")])
            )

        page = await coord_pull(db_session, repo.repo_id, CoordPollRequest())
        ids = [r.id for r in page.records]
        # Non-decreasing ids == ascending (oldest-first) order.
        assert all(earlier <= later for earlier, later in zip(ids, ids[1:]))

    @pytest.mark.asyncio
    async def test_pull_limit_respected(self, db_session: AsyncSession) -> None:
        from musehub.services.musehub_coord import coord_push, coord_pull

        repo = await create_repo(db_session, slug="pull-limit")
        for _ in range(10):
            await coord_push(
                db_session, repo.repo_id, CoordPushRequest(records=[_record("intent")])
            )

        limited = await coord_pull(db_session, repo.repo_id, CoordPollRequest(limit=4))
        assert len(limited.records) == 4
371
372
class TestCoordServerIntegration:
    """Materialised reservation and task views exposed by ``musehub_coord_server``."""

    @pytest.mark.asyncio
    async def test_conflict_check_no_reservations(self, db_session: AsyncSession) -> None:
        """With no reservations, nothing conflicts."""
        from musehub.services.musehub_coord_server import conflict_check

        repo = await create_repo(db_session, slug="conflict-empty")
        result = await conflict_check(db_session, repo.repo_id, ["a.py::Fn"])
        assert result == []

    @pytest.mark.asyncio
    async def test_conflict_check_finds_active_reservation(
        self, db_session: AsyncSession
    ) -> None:
        """A pushed, unexpired reservation is reported for its address."""
        from musehub.services.musehub_coord import coord_push
        from musehub.services.musehub_coord_server import conflict_check

        repo = await create_repo(db_session, slug="conflict-found")
        rec_id = _new_id()
        res_id = fake_id(rec_id)
        exp = _now() + timedelta(seconds=300)
        payload = {
            "reservation_id": res_id,
            "run_id": "worker-1",
            "addresses": ["a.py::MyFn"],
            "ttl_s": 300,
            "expires_at": exp.isoformat(),
        }
        await coord_push(
            db_session,
            repo.repo_id,
            CoordPushRequest(
                records=[_record("reservation", record_id=rec_id, payload=payload, expires_at=exp)]
            ),
        )

        conflicts = await conflict_check(db_session, repo.repo_id, ["a.py::MyFn"])
        assert len(conflicts) == 1
        assert conflicts[0]["symbol_address"] == "a.py::MyFn"

    @pytest.mark.asyncio
    async def test_conflict_check_ignores_expired_reservation(
        self, db_session: AsyncSession
    ) -> None:
        """A reservation whose expires_at is in the past is not a conflict."""
        from musehub.db import coord_models as _cm
        from musehub.services.musehub_coord_server import conflict_check

        repo = await create_repo(db_session, slug="conflict-expired")
        # Insert the reservation row directly so its expiry can be set in the past.
        past = _now() - timedelta(seconds=10)
        row = _cm.MusehubCoordReservation(
            reservation_id=fake_id("expired-reservation"),
            repo_id=repo.repo_id,
            symbol_address="a.py::OldFn",
            agent_id="old-agent",
            ttl_s=10,
            created_at=_now() - timedelta(seconds=20),
            expires_at=past,
        )
        db_session.add(row)
        await db_session.commit()

        conflicts = await conflict_check(db_session, repo.repo_id, ["a.py::OldFn"])
        assert conflicts == []

    @pytest.mark.asyncio
    async def test_extend_reservation(self, db_session: AsyncSession) -> None:
        """extend_reservation moves the expiry strictly later."""
        from musehub.services.musehub_coord import coord_push
        from musehub.services.musehub_coord_server import extend_reservation, list_reservations

        repo = await create_repo(db_session, slug="extend-reservation")
        rec_id = _new_id()
        res_id = fake_id(rec_id)
        exp = _now() + timedelta(seconds=60)
        payload = {
            "reservation_id": res_id,
            "run_id": "agent-ext",
            "addresses": ["b.py::Fn"],
            "ttl_s": 60,
            "expires_at": exp.isoformat(),
        }
        await coord_push(
            db_session,
            repo.repo_id,
            CoordPushRequest(
                records=[_record("reservation", record_id=rec_id, payload=payload, expires_at=exp)]
            ),
        )

        res_before = await list_reservations(db_session, repo.repo_id)
        old_exp = res_before[0].expires_at

        updated = await extend_reservation(db_session, repo.repo_id, res_id, extend_by_s=600)
        assert updated is not None
        new_exp = updated.expires_at
        # The stored timestamps may come back naive; treat naive values as UTC
        # so the comparison does not mix naive and aware datetimes.
        if old_exp.tzinfo is None:
            old_exp = old_exp.replace(tzinfo=timezone.utc)
        if new_exp.tzinfo is None:
            new_exp = new_exp.replace(tzinfo=timezone.utc)
        assert new_exp > old_exp

    @pytest.mark.asyncio
    async def test_task_lifecycle_claim_complete(self, db_session: AsyncSession) -> None:
        """pending -> claimed -> completed, with the result stored in the payload."""
        from musehub.services.musehub_coord import coord_push
        from musehub.services.musehub_coord_server import claim_task, complete_task

        repo = await create_repo(db_session, slug="task-lifecycle")
        rec_id = _new_id()
        task_id = fake_id(rec_id)
        payload = {"task_id": task_id, "queue": "default", "priority": 50,
                   "created_by": "dispatcher"}
        await coord_push(
            db_session,
            repo.repo_id,
            CoordPushRequest(records=[_record("task", record_id=rec_id, payload=payload)]),
        )

        claimed = await claim_task(db_session, repo.repo_id, task_id, "worker-1")
        assert claimed is not None
        assert claimed.status == "claimed"
        assert claimed.claimed_by == "worker-1"

        completed = await complete_task(db_session, repo.repo_id, task_id, "worker-1",
                                        result={"output": "done"})
        assert completed is not None
        assert completed.status == "completed"
        assert completed.payload.get("result") == {"output": "done"}

    @pytest.mark.asyncio
    async def test_task_lifecycle_claim_fail(self, db_session: AsyncSession) -> None:
        """pending -> claimed -> failed, with the reason stored in the payload."""
        from musehub.services.musehub_coord import coord_push
        from musehub.services.musehub_coord_server import claim_task, fail_task

        repo = await create_repo(db_session, slug="task-fail")
        rec_id = _new_id()
        task_id = fake_id(rec_id)
        payload = {"task_id": task_id, "queue": "default", "priority": 50}
        await coord_push(
            db_session,
            repo.repo_id,
            CoordPushRequest(records=[_record("task", record_id=rec_id, payload=payload)]),
        )

        await claim_task(db_session, repo.repo_id, task_id, "worker-2")
        failed = await fail_task(db_session, repo.repo_id, task_id, "worker-2",
                                 reason="OOM")
        assert failed is not None
        assert failed.status == "failed"
        assert failed.payload.get("failure_reason") == "OOM"

    @pytest.mark.asyncio
    async def test_claim_already_claimed_task_returns_none(
        self, db_session: AsyncSession
    ) -> None:
        """A second claimant gets None once the task is claimed."""
        from musehub.services.musehub_coord import coord_push
        from musehub.services.musehub_coord_server import claim_task

        repo = await create_repo(db_session, slug="double-claim")
        rec_id = _new_id()
        task_id = fake_id(rec_id)
        payload = {"task_id": task_id, "queue": "default"}
        await coord_push(
            db_session,
            repo.repo_id,
            CoordPushRequest(records=[_record("task", record_id=rec_id, payload=payload)]),
        )

        r1 = await claim_task(db_session, repo.repo_id, task_id, "worker-A")
        assert r1 is not None
        r2 = await claim_task(db_session, repo.repo_id, task_id, "worker-B")
        assert r2 is None  # Already claimed by worker-A

    @pytest.mark.asyncio
    async def test_list_tasks_filter_by_status(self, db_session: AsyncSession) -> None:
        """list_tasks(status=...) returns exactly the tasks in that state."""
        from musehub.services.musehub_coord import coord_push
        from musehub.services.musehub_coord_server import claim_task, list_tasks

        repo = await create_repo(db_session, slug="list-tasks-status")
        tids = []
        for i in range(3):
            rec_id = _new_id()
            tid = fake_id(rec_id)
            tids.append(tid)
            await coord_push(
                db_session,
                repo.repo_id,
                CoordPushRequest(records=[_record("task", record_id=rec_id,
                                                  payload={"task_id": tid})]),
            )
            if i == 0:
                await claim_task(db_session, repo.repo_id, tid, "worker-X")

        pending = await list_tasks(db_session, repo.repo_id, status="pending")
        claimed = await list_tasks(db_session, repo.repo_id, status="claimed")
        assert len(pending) == 2
        assert len(claimed) == 1
        # Check which tasks landed in each bucket, not only the counts: the
        # first task was claimed and the other two are still pending.
        assert claimed[0].task_id == tids[0]
        assert {t.task_id for t in pending} == set(tids[1:])
554
555
556 # ===========================================================================
557 # Layer 3 — End-to-End tests (full HTTP via AsyncClient, real DB)
558 # ===========================================================================
559
class TestCoordEndToEnd:
    """Full HTTP round-trips through the coord push/pull/watch routes."""

    @pytest.mark.asyncio
    async def test_push_404_unknown_repo(
        self, client: AsyncClient, db_session: AsyncSession, auth_headers: StrDict
    ) -> None:
        """Pushing to a repo that does not exist is a 404."""
        # NOTE(review): db_session is unused in the body; it is presumably
        # requested for database setup ordering - confirm before removing.
        resp = await client.post(
            "/ghost-owner/ghost-repo/coord/push",
            json=_push_body(_record()),
            headers=auth_headers,
        )
        assert resp.status_code == 404

    @pytest.mark.asyncio
    async def test_push_requires_auth(
        self, client: AsyncClient, db_session: AsyncSession
    ) -> None:
        """An unauthenticated push is a 401."""
        repo = await create_repo(db_session, slug="push-noauth")
        await db_session.commit()
        resp = await client.post(
            f"/{repo.owner}/{repo.slug}/coord/push",
            json=_push_body(_record()),
        )
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_push_403_non_owner(
        self, client: AsyncClient, db_session: AsyncSession, auth_headers: StrDict
    ) -> None:
        """An authenticated push to someone else's repo is a 403."""
        # The authenticated caller is _TEST_IDENTITY_ID; the repo is owned by
        # a different identity.
        repo = await create_repo(db_session, slug="push-nonowner",
                                 owner_user_id=compute_identity_id(b"other-user"))
        await db_session.commit()
        resp = await client.post(
            f"/{repo.owner}/{repo.slug}/coord/push",
            json=_push_body(_record()),
            headers=auth_headers,
        )
        assert resp.status_code == 403

    @pytest.mark.asyncio
    async def test_push_success_returns_counts(
        self, client: AsyncClient, db_session: AsyncSession, auth_headers: StrDict
    ) -> None:
        """An owner push returns inserted/skipped counts."""
        from tests.conftest import _TEST_IDENTITY_ID
        repo = await create_repo(db_session, slug="push-e2e-ok",
                                 owner_user_id=_TEST_IDENTITY_ID)
        await db_session.commit()
        resp = await client.post(
            f"/{repo.owner}/{repo.slug}/coord/push",
            json=_push_body(_record("intent"), _record("dependency")),
            headers=auth_headers,
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["inserted"] == 2
        assert data["skipped"] == 0

    @pytest.mark.asyncio
    async def test_pull_public_repo_no_auth(
        self, client: AsyncClient, db_session: AsyncSession
    ) -> None:
        """Anonymous pull from a public repo succeeds; a fresh repo has no records."""
        repo = await create_repo(db_session, slug="pull-e2e-pub", visibility="public")
        await db_session.commit()
        resp = await client.post(
            f"/{repo.owner}/{repo.slug}/coord/pull",
            json=_pull_body(),
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "records" in data
        assert "cursor" in data
        # Nothing has been pushed to this repo yet.
        assert data["records"] == []

    @pytest.mark.asyncio
    async def test_pull_private_repo_404_no_auth(
        self, client: AsyncClient, db_session: AsyncSession
    ) -> None:
        """Anonymous pull from a private repo looks like a missing repo (404)."""
        repo = await create_repo(db_session, slug="pull-e2e-priv", visibility="private")
        await db_session.commit()
        resp = await client.post(
            f"/{repo.owner}/{repo.slug}/coord/pull",
            json=_pull_body(),
        )
        assert resp.status_code == 404

    @pytest.mark.asyncio
    async def test_pull_returns_pushed_records(
        self, client: AsyncClient, db_session: AsyncSession, auth_headers: StrDict
    ) -> None:
        """Records pushed over HTTP come back from an HTTP pull."""
        from tests.conftest import _TEST_IDENTITY_ID
        repo = await create_repo(db_session, slug="pull-e2e-round",
                                 owner_user_id=_TEST_IDENTITY_ID, visibility="public")
        await db_session.commit()

        push_resp = await client.post(
            f"/{repo.owner}/{repo.slug}/coord/push",
            json=_push_body(_record("intent"), _record("heartbeat")),
            headers=auth_headers,
        )
        assert push_resp.status_code == 200

        pull_resp = await client.post(
            f"/{repo.owner}/{repo.slug}/coord/pull",
            json=_pull_body(),
        )
        assert pull_resp.status_code == 200
        data = pull_resp.json()
        assert len(data["records"]) == 2

    @pytest.mark.asyncio
    async def test_pull_kinds_filter_via_http(
        self, client: AsyncClient, db_session: AsyncSession, auth_headers: StrDict
    ) -> None:
        """The kinds filter in the pull body is honoured over HTTP."""
        from tests.conftest import _TEST_IDENTITY_ID
        repo = await create_repo(db_session, slug="pull-e2e-filter",
                                 owner_user_id=_TEST_IDENTITY_ID, visibility="public")
        await db_session.commit()

        push_resp = await client.post(
            f"/{repo.owner}/{repo.slug}/coord/push",
            json=_push_body(_record("intent"), _record("dependency"), _record("heartbeat")),
            headers=auth_headers,
        )
        # Fail here, with the real cause, if the push itself was rejected.
        assert push_resp.status_code == 200

        pull_resp = await client.post(
            f"/{repo.owner}/{repo.slug}/coord/pull",
            json=_pull_body(kinds=["heartbeat"]),
        )
        assert pull_resp.status_code == 200
        records = pull_resp.json()["records"]
        assert len(records) == 1
        assert records[0]["kind"] == "heartbeat"

    @pytest.mark.asyncio
    async def test_watch_invalid_kind_400(
        self, client: AsyncClient, db_session: AsyncSession
    ) -> None:
        """An unknown kind in the watch query string is a 400."""
        repo = await create_repo(db_session, slug="watch-invalid-kind", visibility="public")
        await db_session.commit()
        resp = await client.get(
            f"/{repo.owner}/{repo.slug}/coord/watch",
            params={"kinds": "garbage"},
        )
        assert resp.status_code == 400

    @pytest.mark.asyncio
    async def test_watch_404_unknown_repo(
        self, client: AsyncClient, db_session: AsyncSession
    ) -> None:
        """Watching a repo that does not exist is a 404."""
        resp = await client.get("/ghost/norepo/coord/watch")
        assert resp.status_code == 404
709
710
711 # ===========================================================================
712 # Layer 4 — Stress tests
713 # ===========================================================================
714
class TestStress:
    """Volume tests: max-size push batches, cursor paging, task queues, bulk conflicts."""

    @pytest.mark.asyncio
    async def test_push_500_records_single_call(self, db_session: AsyncSession) -> None:
        """A 500-record push inserts everything, and all of it is pullable."""
        from musehub.services.musehub_coord import coord_push, coord_pull

        repo = await create_repo(db_session, slug="stress-push-500")
        records = [_record("intent") for _ in range(500)]
        req = CoordPushRequest(records=records)
        resp = await coord_push(db_session, repo.repo_id, req)
        assert resp.inserted == 500
        assert resp.skipped == 0

        pull = await coord_pull(db_session, repo.repo_id,
                                CoordPollRequest(limit=1000))
        assert len(pull.records) == 500

    @pytest.mark.asyncio
    async def test_cursor_pagination_through_500_records(
        self, db_session: AsyncSession
    ) -> None:
        """Paging 100 at a time visits all 500 records exactly once, in 5 pages."""
        from musehub.services.musehub_coord import coord_push, coord_pull

        repo = await create_repo(db_session, slug="stress-cursor-500")
        records = [_record("dependency") for _ in range(500)]
        await coord_push(db_session, repo.repo_id, CoordPushRequest(records=records))

        cursor = 0
        fetched = 0
        pages = 0
        seen_ids = set()
        # Five full pages plus one empty page are expected. Bound the loop so
        # a cursor that never advances fails the test instead of hanging it.
        max_requests = 10
        for _ in range(max_requests):
            page = await coord_pull(db_session, repo.repo_id,
                                    CoordPollRequest(since_id=cursor, limit=100))
            if not page.records:
                break
            assert page.cursor > cursor, "pull cursor did not advance"
            fetched += len(page.records)
            seen_ids.update(r.id for r in page.records)
            cursor = page.cursor
            pages += 1
        else:
            pytest.fail(f"pagination did not terminate within {max_requests} requests")
        assert fetched == 500
        # No record may be delivered on more than one page.
        assert len(seen_ids) == 500
        assert pages == 5

    @pytest.mark.asyncio
    async def test_task_queue_100_tasks(self, db_session: AsyncSession) -> None:
        """100 pushed tasks are all listed, and 10 of them can be claimed."""
        from musehub.services.musehub_coord import coord_push
        from musehub.services.musehub_coord_server import list_tasks, claim_task

        repo = await create_repo(db_session, slug="stress-100-tasks")
        for _ in range(100):
            rec_id = _new_id()
            tid = fake_id(rec_id)
            payload = {"task_id": tid, "queue": "batch", "priority": 50}
            await coord_push(db_session, repo.repo_id,
                             CoordPushRequest(records=[_record("task", record_id=rec_id,
                                                               payload=payload)]))

        tasks = await list_tasks(db_session, repo.repo_id, queue="batch", limit=100)
        assert len(tasks) == 100

        claimed_count = 0
        for task in tasks[:10]:
            result = await claim_task(db_session, repo.repo_id, task.task_id, "batch-worker")
            if result is not None:
                claimed_count += 1
        assert claimed_count == 10

    @pytest.mark.asyncio
    async def test_conflict_check_100_reserved_symbols(
        self, db_session: AsyncSession
    ) -> None:
        """With 100 active reservations, checking 50 of them yields 50 conflicts."""
        from musehub.services.musehub_coord_server import conflict_check
        from musehub.db import coord_models as _cm

        repo = await create_repo(db_session, slug="stress-conflict-100")
        # Insert the reservations directly, all expiring in the future.
        exp = _now() + timedelta(seconds=300)
        for i in range(100):
            row = _cm.MusehubCoordReservation(
                reservation_id=fake_id(f"stress-res-{i}"),
                repo_id=repo.repo_id,
                symbol_address=f"module/file_{i}.py::Fn{i}",
                agent_id=f"agent-{i}",
                ttl_s=300,
                created_at=_now(),
                expires_at=exp,
            )
            db_session.add(row)
        await db_session.commit()

        addresses = [f"module/file_{i}.py::Fn{i}" for i in range(50, 100)]
        conflicts = await conflict_check(db_session, repo.repo_id, addresses)
        assert len(conflicts) == 50
809
810
811 # ===========================================================================
812 # Layer 5 — Data Integrity tests
813 # ===========================================================================
814
class TestDataIntegrity:
    """Layer 5 — data-integrity checks against the coordination service layer.

    Every test calls the service functions directly on ``db_session`` (no
    HTTP) and then inspects what was persisted.
    """

    @pytest.mark.asyncio
    async def test_write_once_constraint_enforced(self, db_session: AsyncSession) -> None:
        """The UniqueConstraint on (repo_id, kind, record_id) must hold.

        Pushing the same record twice must insert it once, report the second
        push as skipped (not an error), and leave exactly one readable row.
        """
        from musehub.services.musehub_coord import coord_pull, coord_push

        repo = await create_repo(db_session, slug="di-unique-constraint")
        uid = _new_id()
        rec = _record("intent", record_id=uid)

        r1 = await coord_push(db_session, repo.repo_id, CoordPushRequest(records=[rec]))
        r2 = await coord_push(db_session, repo.repo_id, CoordPushRequest(records=[rec]))
        # First → inserted, second → skipped (not error) and not re-inserted.
        assert r1.inserted == 1
        assert r2.inserted == 0
        assert r2.skipped == 1

        # The counters alone could be wrong; confirm only one row is stored.
        resp = await coord_pull(db_session, repo.repo_id,
                                CoordPollRequest(kinds=["intent"]))
        assert len(resp.records) == 1
        assert resp.records[0].record_id == uid

    @pytest.mark.asyncio
    async def test_heartbeat_upsert_does_not_create_new_row(
        self, db_session: AsyncSession
    ) -> None:
        """Re-pushing a heartbeat with the same record_id keeps a single row.

        Five pushes with ``tick`` payloads 0..4 must leave one heartbeat row
        carrying the payload of the last push.
        """
        from musehub.services.musehub_coord import coord_push, coord_pull

        repo = await create_repo(db_session, slug="di-hb-no-dup")
        uid = _new_id()
        for i in range(5):
            rec = _record("heartbeat", record_id=uid, payload={"tick": i})
            await coord_push(db_session, repo.repo_id, CoordPushRequest(records=[rec]))

        resp = await coord_pull(db_session, repo.repo_id,
                                CoordPollRequest(kinds=["heartbeat"]))
        # Only 1 row despite 5 pushes, and it holds the latest payload.
        assert len(resp.records) == 1
        assert resp.records[0].record_id == uid
        assert resp.records[0].payload["tick"] == 4

    @pytest.mark.asyncio
    async def test_coord_record_fields_complete(self, db_session: AsyncSession) -> None:
        """A pushed record round-trips its identifying and payload fields on pull."""
        from musehub.services.musehub_coord import coord_push, coord_pull

        repo = await create_repo(db_session, slug="di-record-fields")
        uid = _new_id()
        exp = _now() + timedelta(seconds=120)
        rec = _record("dependency", record_id=uid, run_id="run-99",
                      payload={"dep": "x"}, expires_at=exp)
        await coord_push(db_session, repo.repo_id, CoordPushRequest(records=[rec]))

        resp = await coord_pull(db_session, repo.repo_id,
                                CoordPollRequest(kinds=["dependency"]))
        # Guard the index below so an empty/duplicated result fails clearly.
        assert len(resp.records) == 1
        r = resp.records[0]
        assert r.kind == "dependency"
        assert r.record_id == uid
        assert r.run_id == "run-99"
        assert r.payload == {"dep": "x"}
        assert r.repo_id == repo.repo_id
        assert r.created_at is not None
        # NOTE(review): expires_at is pushed above but not asserted; confirm
        # whether the pulled record exposes it and add a check if so.

    @pytest.mark.asyncio
    async def test_task_depends_on_preserved(self, db_session: AsyncSession) -> None:
        """A task record's ``depends_on`` list is materialized in order."""
        from musehub.services.musehub_coord import coord_push
        from musehub.services.musehub_coord_server import list_tasks

        repo = await create_repo(db_session, slug="di-depends-on")
        dep_a = fake_id("dep-a")
        dep_b = fake_id("dep-b")
        rec_id = _new_id()
        task_id = fake_id(rec_id)
        payload = {"task_id": task_id, "queue": "default", "depends_on": [dep_a, dep_b]}
        await coord_push(db_session, repo.repo_id,
                         CoordPushRequest(records=[_record("task", record_id=rec_id,
                                                           payload=payload)]))
        tasks = await list_tasks(db_session, repo.repo_id)
        # Exactly one task was pushed into this fresh repo.
        assert len(tasks) == 1
        assert tasks[0].depends_on == [dep_a, dep_b]

    @pytest.mark.asyncio
    async def test_release_marks_reservation_released(self, db_session: AsyncSession) -> None:
        """A ``release`` record removes its reservation from the active list."""
        from musehub.services.musehub_coord import coord_push
        from musehub.services.musehub_coord_server import list_reservations

        repo = await create_repo(db_session, slug="di-release")
        rec_id = _new_id()
        res_id = fake_id(rec_id)
        exp = _now() + timedelta(seconds=300)
        res_payload = {
            "reservation_id": res_id,
            "run_id": "agent-r",
            "addresses": ["c.py::Fn"],
            "ttl_s": 300,
            "expires_at": exp.isoformat(),
        }
        await coord_push(db_session, repo.repo_id,
                         CoordPushRequest(records=[_record("reservation", record_id=rec_id,
                                                           payload=res_payload,
                                                           expires_at=exp)]))

        # Confirm reservation exists
        active = await list_reservations(db_session, repo.repo_id)
        assert len(active) == 1

        # Push a release record
        rel_id = _new_id()
        rel_payload = {"reservation_id": res_id}
        await coord_push(db_session, repo.repo_id,
                         CoordPushRequest(records=[_record("release", record_id=rel_id,
                                                           payload=rel_payload)]))

        # Reservation should now be gone from active list
        active_after = await list_reservations(db_session, repo.repo_id)
        assert len(active_after) == 0

    @pytest.mark.asyncio
    async def test_reservation_multi_address_stores_all_rows(self, db_session: AsyncSession) -> None:
        """A reservation covering N addresses must create N rows — one per address.

        This is a regression test for a bug where session.get(PK) on the second
        loop iteration found the first row and skipped the remaining addresses,
        leaving only the first address in the DB.
        """
        from musehub.services.musehub_coord import coord_push
        from musehub.services.musehub_coord_server import conflict_check
        from sqlalchemy import select
        from musehub.db import coord_models as _cm

        repo = await create_repo(db_session, slug="multi-addr-reservation")
        rec_id = _new_id()
        res_id = fake_id(rec_id)
        addresses = ["src/engine.py::AudioEngine", "src/mixer.py::Mixer", "src/output.py::Output"]
        exp = _now() + timedelta(seconds=300)
        res_payload = {
            "reservation_id": res_id,
            "run_id": "agent-multi",
            "addresses": addresses,
            "ttl_s": 300,
            "expires_at": exp.isoformat(),
        }
        await coord_push(
            db_session, repo.repo_id,
            CoordPushRequest(records=[_record("reservation", record_id=rec_id,
                                              payload=res_payload, expires_at=exp)])
        )

        # Fetch all rows for this reservation_id directly
        result = await db_session.execute(
            select(_cm.MusehubCoordReservation).where(
                _cm.MusehubCoordReservation.reservation_id == res_id
            )
        )
        rows = result.scalars().all()
        stored_addresses = {r.symbol_address for r in rows}

        # The set check alone would pass with duplicate rows; the docstring
        # promises exactly one row per address, so check the count too.
        assert len(rows) == len(addresses), (
            f"Expected {len(addresses)} rows, got {len(rows)}"
        )
        assert stored_addresses == set(addresses), (
            f"Expected all {len(addresses)} addresses stored, got: {stored_addresses}"
        )

        # Every address should show in conflict_check
        conflicts = await conflict_check(db_session, repo.repo_id, addresses)
        assert len(conflicts) == len(addresses), (
            f"Expected {len(addresses)} conflicts, got {len(conflicts)}"
        )
974
975
976 # ===========================================================================
977 # Layer 6 — Security tests
978 # ===========================================================================
979
class TestSecurity:
    """Layer 6 — auth guards, ownership checks, visibility and input validation."""

    @pytest.mark.asyncio
    async def test_push_requires_authentication(
        self, client: AsyncClient, db_session: AsyncSession
    ) -> None:
        """Pushing to a public repo without credentials returns 401."""
        repo = await create_repo(db_session, slug="sec-push-noauth", visibility="public")
        await db_session.commit()
        resp = await client.post(
            f"/{repo.owner}/{repo.slug}/coord/push",
            json=_push_body(_record()),
        )
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_push_403_for_non_owner_authenticated(
        self, client: AsyncClient, db_session: AsyncSession, auth_headers: StrDict
    ) -> None:
        """An authenticated caller who does not own the repo gets 403 on push."""
        repo = await create_repo(db_session, slug="sec-push-nonowner",
                                 owner_user_id=compute_identity_id(b"other-owner"))
        await db_session.commit()
        resp = await client.post(
            f"/{repo.owner}/{repo.slug}/coord/push",
            json=_push_body(_record()),
            headers=auth_headers,
        )
        assert resp.status_code == 403

    @pytest.mark.asyncio
    async def test_private_repo_pull_returns_404_unauthenticated(
        self, client: AsyncClient, db_session: AsyncSession
    ) -> None:
        """An unauthenticated pull from a private repo returns 404."""
        repo = await create_repo(db_session, slug="sec-priv-pull", visibility="private")
        await db_session.commit()
        resp = await client.post(
            f"/{repo.owner}/{repo.slug}/coord/pull",
            json=_pull_body(),
        )
        assert resp.status_code == 404

    @pytest.mark.asyncio
    async def test_push_invalid_kind_rejected(
        self, client: AsyncClient, db_session: AsyncSession, auth_headers: StrDict
    ) -> None:
        """The repo owner pushing a record with an unknown ``kind`` gets 422."""
        from tests.conftest import _TEST_IDENTITY_ID
        repo = await create_repo(db_session, slug="sec-invalid-kind",
                                 owner_user_id=_TEST_IDENTITY_ID)
        await db_session.commit()
        bad_payload = {
            "records": [{
                "kind": "INJECT_SQL",
                "record_id": secrets.token_hex(16),
                "run_id": "",
                "payload": {},
            }]
        }
        resp = await client.post(
            f"/{repo.owner}/{repo.slug}/coord/push",
            json=bad_payload,
            headers=auth_headers,
        )
        assert resp.status_code == 422

    @pytest.mark.asyncio
    async def test_watch_invalid_kind_query_param_400(
        self, client: AsyncClient, db_session: AsyncSession
    ) -> None:
        """An unknown ``kinds`` query parameter on the watch endpoint returns 400."""
        repo = await create_repo(db_session, slug="sec-watch-kind", visibility="public")
        await db_session.commit()
        resp = await client.get(
            f"/{repo.owner}/{repo.slug}/coord/watch",
            params={"kinds": "evil_kind"},
        )
        assert resp.status_code == 400

    @pytest.mark.asyncio
    async def test_complete_task_wrong_agent_rejected(
        self, db_session: AsyncSession
    ) -> None:
        """Only the agent that claimed a task may complete it."""
        from musehub.services.musehub_coord import coord_push
        from musehub.services.musehub_coord_server import claim_task, complete_task

        repo = await create_repo(db_session, slug="sec-complete-wrong-agent")
        rec_id = _new_id()
        task_id = fake_id(rec_id)
        payload = {"task_id": task_id, "queue": "default"}
        await coord_push(db_session, repo.repo_id,
                         CoordPushRequest(records=[_record("task", record_id=rec_id,
                                                           payload=payload)]))

        claimed = await claim_task(db_session, repo.repo_id, task_id, "worker-A")
        # Without this, a failed claim would make the rejection below pass
        # vacuously (an unclaimed task is not completable by anyone).
        assert claimed is not None, "worker-A should have claimed the task"

        result = await complete_task(db_session, repo.repo_id, task_id, "worker-B")
        # worker-B did not claim it — must return None
        assert result is None
1073
1074
1075 # ===========================================================================
1076 # Layer 7 — Performance tests
1077 # ===========================================================================
1078
class TestPerformance:
    """Layer 7 — wall-clock latency budgets for coordination operations.

    Fixture data is created before the timer starts; only the single call
    under test sits inside the ``time.perf_counter`` window.
    """

    @pytest.mark.asyncio
    async def test_push_100_records_under_500ms(self, db_session: AsyncSession) -> None:
        """Pushing one 100-record batch of intents must finish within 500 ms."""
        from musehub.services.musehub_coord import coord_push

        repo = await create_repo(db_session, slug="perf-push-100")
        batch = CoordPushRequest(records=[_record("intent") for _ in range(100)])

        started = time.perf_counter()
        outcome = await coord_push(db_session, repo.repo_id, batch)
        took_ms = 1000 * (time.perf_counter() - started)

        assert outcome.inserted == 100
        assert took_ms < 500, f"push 100 records took {took_ms:.1f}ms"

    @pytest.mark.asyncio
    async def test_pull_500_records_under_200ms(self, db_session: AsyncSession) -> None:
        """Pulling 500 stored dependency records (limit=1000) must finish within 200 ms."""
        from musehub.services.musehub_coord import coord_push, coord_pull

        repo = await create_repo(db_session, slug="perf-pull-500")
        seed = CoordPushRequest(records=[_record("dependency") for _ in range(500)])
        await coord_push(db_session, repo.repo_id, seed)

        started = time.perf_counter()
        page = await coord_pull(db_session, repo.repo_id, CoordPollRequest(limit=1000))
        took_ms = 1000 * (time.perf_counter() - started)

        assert len(page.records) == 500
        assert took_ms < 200, f"pull 500 records took {took_ms:.1f}ms"

    @pytest.mark.asyncio
    async def test_conflict_check_50_addresses_under_100ms(
        self, db_session: AsyncSession
    ) -> None:
        """conflict_check over 50 reserved addresses must report all 50 within 100 ms.

        Reservation rows are inserted directly through the ORM model rather
        than via coord_push, so only conflict_check is timed.
        """
        from musehub.services.musehub_coord_server import conflict_check
        from musehub.db import coord_models as _cm

        repo = await create_repo(db_session, slug="perf-conflict")
        exp = _now() + timedelta(seconds=300)
        addresses = [f"pkg/file_{i}.py::Fn{i}" for i in range(50)]
        for idx, address in enumerate(addresses):
            db_session.add(_cm.MusehubCoordReservation(
                reservation_id=fake_id(f"perf-res-{idx}"),
                repo_id=repo.repo_id,
                symbol_address=address,
                agent_id="agent",
                ttl_s=300,
                created_at=_now(),
                expires_at=exp,
            ))
        await db_session.commit()

        started = time.perf_counter()
        hits = await conflict_check(db_session, repo.repo_id, addresses)
        took_ms = 1000 * (time.perf_counter() - started)

        assert len(hits) == 50
        assert took_ms < 100, f"conflict_check 50 addresses took {took_ms:.1f}ms"

    @pytest.mark.asyncio
    async def test_task_queue_list_100_under_100ms(self, db_session: AsyncSession) -> None:
        """Listing 100 pushed tasks with limit=100 must finish within 100 ms."""
        from musehub.services.musehub_coord import coord_push
        from musehub.services.musehub_coord_server import list_tasks

        repo = await create_repo(db_session, slug="perf-tasklist-100")
        for _ in range(100):
            record_id = _new_id()
            task_payload = {"task_id": fake_id(record_id), "queue": "perf"}
            push = CoordPushRequest(
                records=[_record("task", record_id=record_id, payload=task_payload)]
            )
            await coord_push(db_session, repo.repo_id, push)

        started = time.perf_counter()
        listed = await list_tasks(db_session, repo.repo_id, limit=100)
        took_ms = 1000 * (time.perf_counter() - started)

        assert len(listed) == 100
        assert took_ms < 100, f"list_tasks 100 took {took_ms:.1f}ms"
File History 1 commit
sha256:3ff9c9863a9891bdcde71b4a43228f66d0493e38b7cc1d09fe9eb7de774046b2 feat: add repair-commit wire endpoint (API parity with repa… Opus 4.8 minor 1 day ago