gabriel / musehub public

test_mcp_protocol.py file-level

at sha256:5 · View file ↗ · Intel ↗

History
1 files
1 commits
0 hotspots
0 🧊 dead
0 💥 blast risk
sha256:0 fix: fall back to any indexed mpack in read_object_bytes when push mpac… · gabriel · Jun 17, 2026
1 """Section 13 — MCP Protocol Layer: 7-layer test suite.
2
3 Covers gaps not addressed by the existing 102 tests in:
4 - test_mcp_dispatcher.py (protocol correctness, tools, resources, prompts)
5 - test_mcp_streamable_http.py (transport, session CRUD, origin, lifecycle)
6 - test_mcp_elicitation.py (elicitation flows, progress, interactive tools)
7
8 New coverage in this file:
9
10 Layer 1 Unit:
11 - _validate_origin: missing Origin, localhost, 127.0.0.1, URL with path (path stripped), 'localhost' smuggled into a foreign URL's path, malformed URL
12 - MCPSession.is_expired(): elapsed > TTL → True; elapsed < TTL → False
13 - MCPSession.touch(): resets last_active, deferring expiry
14 - MCPSession.supports_elicitation_form(): empty-dict variant (backwards compat)
15 - push_to_session ring buffer: capped at 50 — oldest dropped
16 - create_session stores anonymous user_id correctly
17
18 Layer 2 Integration:
19 - create_session returns unique session IDs for each call
20 - delete_session puts None sentinel to all registered SSE queues
21 - delete_session cancels pending asyncio Futures
22 - get_session evicts expired session, returns None
23 - push_to_session broadcasts to multiple queues simultaneously
24
25 Layer 3 E2E (HTTP):
26 - Full lifecycle: initialize → ping (with session) → DELETE
27 - 127.0.0.1 origin always allowed
28 - Origin containing allow-listed domain as a path component is rejected
29 - Batch with notification mixed: response list excludes notification
30 - Empty batch array returns 202 Accepted (no responses to send)
31
32 Layer 4 Stress:
33 - 50-item ping batch β†’ 50 responses
34 - push_to_session to 10 queues simultaneously — all receive event
35
36 Layer 5 Data Integrity:
37 - 100 sessions have 100 distinct IDs (no collisions)
38 - Ring buffer stays at ≤50 entries after 60 pushes
39 - Session user_id and client_capabilities preserved
40
41 Layer 6 Security:
42 - http://127.0.0.1 origin always allowed (_ALWAYS_ALLOW_ORIGINS)
43 - Non-allow-listed origin (https://attacker.com) is rejected with 403
44 - Look-alike subdomain (http://localhost.attacker.com) is rejected with 403
45 - Missing Origin header is allowed; create_session raises SessionCapacityError when the store is full
46
47 Layer 7 Performance:
48 - 100× handle_request("ping") under 100 ms
49 - 1000× MCPSession.is_expired() under 10 ms
50 - 100× push_to_session under 50 ms
51 """
52 from __future__ import annotations
53
import asyncio
import time
from collections.abc import AsyncIterator
from unittest.mock import patch

import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession

from musehub.api.routes.mcp import _validate_origin
from musehub.main import app
from musehub.mcp.dispatcher import handle_request
from musehub.mcp.session import (
    MCPSession,
    SessionCapacityError,
    create_pending_elicitation,
    create_session,
    delete_session,
    get_session,
    push_to_session,
)
from musehub.types.json_types import JSONObject, StrDict
76
77
78 # ── Fixtures ──────────────────────────────────────────────────────────────────
79
80
@pytest.fixture
def anyio_backend() -> str:
    """Select the asyncio backend for anyio-driven tests in this module."""
    backend_name = "asyncio"
    return backend_name
84
85
@pytest_asyncio.fixture
async def http_client(db_session: AsyncSession) -> AsyncIterator[AsyncClient]:
    """Yield an httpx client wired in-process to the musehub ASGI ``app``.

    ``db_session`` is requested only so the fixture depends on it; the body
    never uses it. Presumably it provisions the test database the app talks
    to — TODO confirm against conftest.

    Yields:
        An ``AsyncClient`` using ``ASGITransport(app=app)`` with base URL
        ``http://localhost``; the client is closed when the test finishes.
    """
    # Async-generator fixture: the return annotation must be an async
    # iterator of the yielded type, not the yielded type itself.
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://localhost") as client:
        yield client
93
94
def _init_body() -> JSONObject:
    """Return a fresh JSON-RPC ``initialize`` request (protocol 2025-11-25).

    The client advertises ``elicitation.form`` in its capabilities. A new
    dict is built on every call, so callers may mutate the result freely.
    """
    client_info = {"name": "test-client", "version": "1.0"}
    capabilities = {"elicitation": {"form": {}}}
    params = {
        "protocolVersion": "2025-11-25",
        "clientInfo": client_info,
        "capabilities": capabilities,
    }
    return {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": params}
106
107
def _req(method: str, params: JSONObject | None = None, req_id: int = 1) -> JSONObject:
    """Build a JSON-RPC 2.0 request for ``method`` with id ``req_id``.

    The ``params`` member is attached only when ``params`` is not ``None``;
    an explicitly empty dict is still included.
    """
    message: JSONObject = {"jsonrpc": "2.0", "id": req_id, "method": method}
    if params is None:
        return message
    message["params"] = params
    return message
113
114
class _FakeRequest:
    """Stand-in for the request object handed to ``_validate_origin``.

    It exposes only ``headers``, holding a single ``Origin`` entry when an
    origin is given and nothing otherwise.
    """

    def __init__(self, origin: str | None) -> None:
        raw: StrDict = {} if origin is None else {"Origin": origin}
        self.headers = _CaseInsensitiveDict(raw)
123
124
class _CaseInsensitiveDict(dict[str, str]):
    """A ``dict`` of header names to values whose key lookups ignore case.

    Keys are lower-cased both when stored (including those passed to the
    constructor) and when looked up. As a result, ``get("Origin")``,
    ``get("origin")``, ``["ORIGIN"]`` and ``"origin" in d`` all address the
    same entry. A stored value is always returned as-is, even when it is
    falsy. ``default`` is only used when the key is absent.

    Only construction, item get/set, ``in`` and ``get`` are normalised; bulk
    mutators such as ``update``/``setdefault`` are not overridden.
    """

    def __init__(self, data: dict[str, str] | None = None) -> None:
        super().__init__()
        # dict.__init__ bypasses __setitem__, so copy entries one by one to
        # normalise every key.
        for key, value in (data or {}).items():
            self[key] = value

    def __setitem__(self, key: str, value: str) -> None:
        super().__setitem__(key.lower(), value)

    def __getitem__(self, key: str) -> str:
        return super().__getitem__(key.lower())

    def __contains__(self, key: object) -> bool:
        return isinstance(key, str) and super().__contains__(key.lower())

    def get(self, key: str, default: str | None = None) -> str | None:
        """Return the value stored under ``key`` (any case), else ``default``."""
        return super().get(key.lower(), default)
128
129
130 # ── Layer 1 β€” Unit ────────────────────────────────────────────────────────────
131
132
class TestUnitValidateOrigin:
    """Edge cases of ``_validate_origin`` that other MCP suites leave uncovered."""

    def test_no_origin_returns_true(self) -> None:
        """A request with no Origin header at all is accepted."""
        allowed = _validate_origin(_FakeRequest(None))
        assert allowed is True

    def test_localhost_no_port_returns_true(self) -> None:
        """Plain http://localhost is accepted."""
        allowed = _validate_origin(_FakeRequest("http://localhost"))
        assert allowed is True

    def test_127_0_0_1_returns_true(self) -> None:
        """The loopback IP origin http://127.0.0.1 is accepted."""
        allowed = _validate_origin(_FakeRequest("http://127.0.0.1"))
        assert allowed is True

    def test_localhost_with_path_still_resolves(self) -> None:
        """A path after localhost is stripped; the origin stays allow-listed."""
        allowed = _validate_origin(_FakeRequest("http://localhost/some/path"))
        assert allowed is True

    def test_evil_origin_returns_false(self) -> None:
        """An unrelated third-party origin is rejected."""
        allowed = _validate_origin(_FakeRequest("https://evil.example.com"))
        assert allowed is False

    def test_evil_with_localhost_path_returns_false(self) -> None:
        """'localhost' appearing only in the path must not bypass the check."""
        spoofed = _FakeRequest("https://evil.example.com/localhost")
        assert _validate_origin(spoofed) is False

    def test_malformed_url_treated_as_allowed_or_false(self) -> None:
        """A string that is not a URL is rejected.

        Despite the historical test name, the expected result is strictly
        ``False``. Presumably ``urlparse`` yields an empty scheme and netloc
        for this input, which matches no allow-listed origin. TODO confirm in
        ``_validate_origin``.
        """
        allowed = _validate_origin(_FakeRequest("not a url at all"))
        assert allowed is False
167
168
class TestUnitMCPSessionExpiry:
    """Time-based semantics of MCPSession.is_expired() and touch()."""

    # Seconds by which last_active is rewound to force expiry. These tests
    # assume this exceeds the session TTL — TODO confirm against
    # musehub.mcp.session.
    _STALE_BY = 3700

    def test_fresh_session_is_not_expired(self) -> None:
        """A session that was just created is live."""
        session = create_session(None, {})
        try:
            assert session.is_expired() is False
        finally:
            delete_session(session.session_id)

    def test_last_active_in_past_is_expired(self) -> None:
        """Rewinding last_active beyond the TTL makes the session expired."""
        session = create_session(None, {})
        try:
            session.last_active = time.monotonic() - self._STALE_BY
            assert session.is_expired() is True
        finally:
            delete_session(session.session_id)

    def test_touch_defers_expiry(self) -> None:
        """touch() refreshes last_active, so an expired session reads as live again."""
        session = create_session(None, {})
        try:
            session.last_active = time.monotonic() - self._STALE_BY
            assert session.is_expired() is True
            session.touch()
            assert session.is_expired() is False
        finally:
            delete_session(session.session_id)
198
199
class TestUnitElicitationCapabilities:
    """Edge cases of supports_elicitation_form() and supports_elicitation_url()."""

    def test_empty_elicitation_dict_counts_as_form(self) -> None:
        """An empty ``elicitation`` dict means form-only (spec backward compat)."""
        session = create_session(None, {"elicitation": {}})
        try:
            assert session.supports_elicitation_form() is True
        finally:
            delete_session(session.session_id)

    def test_elicitation_not_a_dict_returns_false_for_form(self) -> None:
        """A non-dict ``elicitation`` value (True) does not enable form mode."""
        session = create_session(None, {"elicitation": True})
        try:
            assert session.supports_elicitation_form() is False
        finally:
            delete_session(session.session_id)

    def test_elicitation_not_a_dict_returns_false_for_url(self) -> None:
        """A non-dict ``elicitation`` value ("url") does not enable URL mode."""
        session = create_session(None, {"elicitation": "url"})
        try:
            assert session.supports_elicitation_url() is False
        finally:
            delete_session(session.session_id)

    def test_no_elicitation_key_returns_false(self) -> None:
        """Without an ``elicitation`` key neither mode is supported."""
        session = create_session(None, {})
        try:
            assert session.supports_elicitation_form() is False
            assert session.supports_elicitation_url() is False
        finally:
            delete_session(session.session_id)
232
233
class TestUnitRingBuffer:
    """Capping behaviour of the per-session event buffer fed by push_to_session."""

    def test_ring_buffer_capped_at_50(self) -> None:
        """Sixty pushes leave exactly fifty buffered events."""
        session = create_session(None, {})
        try:
            for i in range(60):
                push_to_session(session, f"data: event-{i}\n\n")
            assert len(session.event_buffer) == 50
        finally:
            delete_session(session.session_id)

    def test_ring_buffer_drops_oldest(self) -> None:
        """After 55 pushes, events 0-4 are evicted: event-5 is first, event-54 last."""
        session = create_session(None, {})
        try:
            for i in range(55):
                push_to_session(session, f"data: event-{i}\n\n")
            first_text = session.event_buffer[0][1]
            last_text = session.event_buffer[-1][1]
            # Anchor on the trailing newline: a bare "event-5" substring would
            # also match event-50..event-54 and let a wrong eviction pass.
            assert "event-5\n" in first_text
            assert "event-54\n" in last_text
        finally:
            delete_session(session.session_id)
256
257
class TestUnitCreateSessionAnonymous:
    """create_session stores the caller identity and capabilities it is given."""

    def test_anonymous_session_user_id_is_none(self) -> None:
        """Passing None as the user yields a session with user_id None."""
        session = create_session(None, {})
        try:
            assert session.user_id is None
        finally:
            delete_session(session.session_id)

    def test_authenticated_session_stores_user_id(self) -> None:
        """A concrete user id and the capability dict are kept on the session."""
        caps = {"elicitation": {"form": {}}}
        session = create_session("user-xyz", caps)
        try:
            assert session.user_id == "user-xyz"
            assert session.client_capabilities == {"elicitation": {"form": {}}}
        finally:
            delete_session(session.session_id)
273
274
275 # ── Layer 2 β€” Integration ─────────────────────────────────────────────────────
276
277
class TestIntegrationSessionUniqueness:
    """Every create_session call issues a new session ID."""

    def test_create_session_returns_unique_ids(self) -> None:
        """Ten consecutive sessions carry ten distinct IDs."""
        created = [create_session(None, {}) for _ in range(10)]
        unique_ids = {session.session_id for session in created}
        try:
            assert len(unique_ids) == 10
        finally:
            for session in created:
                delete_session(session.session_id)
287
288
class TestIntegrationDeleteSessionSSE:
    """delete_session tears down SSE listeners and pending elicitations."""

    async def test_delete_sends_none_sentinel_to_queues(self) -> None:
        """Every registered SSE queue receives a None sentinel on delete."""
        session = create_session(None, {})
        listeners: list[asyncio.Queue[str | None]] = [asyncio.Queue(), asyncio.Queue()]
        session.sse_queues.extend(listeners)

        delete_session(session.session_id)

        for listener in listeners:
            assert listener.get_nowait() is None

    async def test_delete_cancels_pending_futures(self) -> None:
        """Unresolved elicitation Futures are cancelled when the session goes away."""
        session = create_session(None, {})
        pending = create_pending_elicitation(session, "elicit-99")
        assert not pending.done()

        delete_session(session.session_id)

        assert pending.cancelled()
311
312
class TestIntegrationExpiredSessionEviction:
    """get_session must not hand out sessions whose TTL has lapsed."""

    def test_expired_session_evicted_by_get_session(self) -> None:
        """An expired session is reported missing, and stays missing."""
        stale = create_session(None, {})
        sid = stale.session_id
        stale.last_active = time.monotonic() - 3700  # rewind past the TTL

        assert get_session(sid) is None
        # A repeat lookup must also come back empty, cleanly.
        assert get_session(sid) is None
324
325
class TestIntegrationMultiQueueBroadcast:
    """push_to_session fans one event out to every registered SSE queue."""

    async def test_push_broadcasts_to_all_queues(self) -> None:
        """Each of five queues receives the pushed text unchanged."""
        session = create_session(None, {})
        try:
            subscribers: list[asyncio.Queue[str | None]] = [
                asyncio.Queue() for _ in range(5)
            ]
            session.sse_queues.extend(subscribers)

            push_to_session(session, "data: hello\n\n")

            for subscriber in subscribers:
                assert subscriber.get_nowait() == "data: hello\n\n"
        finally:
            delete_session(session.session_id)
342
343
344 # ── Layer 3 β€” End-to-End ──────────────────────────────────────────────────────
345
346
class TestE2EFullLifecycle:
    """Session lifecycle driven end-to-end through the /mcp HTTP endpoint."""

    async def test_initialize_ping_delete(self, http_client: AsyncClient) -> None:
        """initialize → ping on the issued session → DELETE removes the session."""
        json_headers = {"Content-Type": "application/json"}

        # 1. initialize: the server issues a session id in the response headers.
        init_resp = await http_client.post("/mcp", json=_init_body(), headers=json_headers)
        assert init_resp.status_code == 200
        session_id = init_resp.headers["mcp-session-id"]

        # 2. ping bound to that session.
        ping_resp = await http_client.post(
            "/mcp",
            json=_req("ping"),
            headers={**json_headers, "Mcp-Session-Id": session_id},
        )
        assert ping_resp.status_code == 200
        assert ping_resp.json()["result"] == {}

        # 3. DELETE ends the session server-side.
        del_resp = await http_client.delete("/mcp", headers={"Mcp-Session-Id": session_id})
        assert del_resp.status_code == 200
        assert get_session(session_id) is None
378
379
class TestE2EOriginEdgeCases:
    """Origin-header handling as observed through the HTTP transport."""

    async def test_127_0_0_1_origin_allowed(self, http_client: AsyncClient) -> None:
        """http://127.0.0.1 must always be accepted."""
        resp = await http_client.post(
            "/mcp",
            json=_init_body(),
            headers={"Content-Type": "application/json", "Origin": "http://127.0.0.1"},
        )
        assert resp.status_code == 200
        # Release the session the successful initialize created.
        new_sid = resp.headers.get("mcp-session-id")
        if new_sid is not None:
            delete_session(new_sid)

    async def test_origin_with_localhost_in_path_rejected(
        self, http_client: AsyncClient
    ) -> None:
        """'localhost' smuggled into a foreign origin's path must still get 403."""
        spoofed_headers = {
            "Content-Type": "application/json",
            "Origin": "https://evil.example.com/localhost",
        }
        resp = await http_client.post("/mcp", json=_init_body(), headers=spoofed_headers)
        assert resp.status_code == 403
409
410
class TestE2EBatchWithNotification:
    """JSON-RPC batch handling when notifications are involved."""

    async def test_batch_excludes_notification(self, http_client: AsyncClient) -> None:
        """A notification inside a batch yields no entry in the response list."""
        notification = {"jsonrpc": "2.0", "method": "notifications/initialized"}  # no "id"
        batch = [_req("ping", req_id=1), notification, _req("ping", req_id=3)]

        resp = await http_client.post(
            "/mcp",
            json=batch,
            headers={"Content-Type": "application/json"},
        )

        assert resp.status_code == 200
        data = resp.json()
        assert isinstance(data, list)
        # Only the two pings are answered.
        assert len(data) == 2
        assert {item["id"] for item in data} == {1, 3}

    async def test_empty_batch_returns_202(
        self, http_client: AsyncClient
    ) -> None:
        """An empty batch yields nothing to answer, so it is acknowledged with 202."""
        resp = await http_client.post(
            "/mcp",
            json=[],
            headers={"Content-Type": "application/json"},
        )
        assert resp.status_code == 202
443
444
445 # ── Layer 4 β€” Stress ──────────────────────────────────────────────────────────
446
447
class TestStressBatch:
    """Large JSON-RPC batches over HTTP."""

    async def test_50_item_ping_batch(self, http_client: AsyncClient) -> None:
        """A 50-request ping batch must yield exactly 50 empty results."""
        pings = [_req("ping", req_id=n) for n in range(50)]
        resp = await http_client.post(
            "/mcp",
            json=pings,
            headers={"Content-Type": "application/json"},
        )
        assert resp.status_code == 200

        responses = resp.json()
        assert isinstance(responses, list)
        assert len(responses) == 50
        for response in responses:
            assert response["result"] == {}
463
464
class TestStressPushToMultipleQueues:
    """Fan-out of a single push to many SSE queues."""

    async def test_push_to_10_queues(self) -> None:
        """All ten registered queues receive the pushed event."""
        session = create_session(None, {})
        try:
            sinks: list[asyncio.Queue[str | None]] = [asyncio.Queue() for _ in range(10)]
            session.sse_queues.extend(sinks)

            push_to_session(session, "data: stress\n\n")

            for sink in sinks:
                assert sink.get_nowait() == "data: stress\n\n"
        finally:
            delete_session(session.session_id)
481
482
483 # ── Layer 5 β€” Data Integrity ──────────────────────────────────────────────────
484
485
class TestDataIntegritySessionIDs:
    """Session IDs do not collide at moderate volume."""

    def test_100_sessions_have_unique_ids(self) -> None:
        """One hundred sessions carry one hundred distinct IDs."""
        created = [create_session(None, {}) for _ in range(100)]
        distinct = {session.session_id for session in created}
        try:
            assert len(distinct) == 100
        finally:
            for session in created:
                delete_session(session.session_id)
495
496
class TestDataIntegrityRingBuffer:
    """Size bound and ordering of the per-session event buffer."""

    def test_ring_buffer_never_exceeds_50(self) -> None:
        """Sixty pushes never grow the buffer past fifty entries."""
        session = create_session(None, {})
        try:
            for n in range(60):
                push_to_session(session, f"data: {n}\n\n")
            assert len(session.event_buffer) <= 50
        finally:
            delete_session(session.session_id)

    def test_ring_buffer_content_after_exact_50_pushes(self) -> None:
        """Exactly fifty pushes fill the buffer in order with nothing dropped."""
        session = create_session(None, {})
        try:
            for n in range(50):
                push_to_session(session, f"data: {n}\n\n")
            buffered = session.event_buffer
            assert len(buffered) == 50
            assert "data: 0" in buffered[0][1]
            assert "data: 49" in buffered[-1][1]
        finally:
            delete_session(session.session_id)
517
518
class TestDataIntegritySessionAttributes:
    """Session attributes survive a round-trip through the session store."""

    def test_session_user_id_preserved(self) -> None:
        """get_session returns the stored user_id and client capabilities."""
        session = create_session("preserved-user", {"cap": "x"})
        try:
            looked_up = get_session(session.session_id)
            assert looked_up is not None
            assert looked_up.user_id == "preserved-user"
            assert looked_up.client_capabilities == {"cap": "x"}
        finally:
            delete_session(session.session_id)

    def test_delete_removes_session_from_store(self) -> None:
        """After delete_session the ID no longer resolves."""
        sid = create_session(None, {}).session_id
        delete_session(sid)
        assert get_session(sid) is None
535
536
537 # ── Layer 6 β€” Security ────────────────────────────────────────────────────────
538
539
class TestSecurityOrigin:
    """Origin allow-list enforcement on POST /mcp initialize requests."""

    @staticmethod
    def _headers(origin: str | None = None) -> dict[str, str]:
        """Return JSON content-type headers, plus ``Origin`` when one is given."""
        headers = {"Content-Type": "application/json"}
        if origin is not None:
            headers["Origin"] = origin
        return headers

    async def test_127_0_0_1_always_allowed(self, http_client: AsyncClient) -> None:
        """The loopback IP origin is always accepted."""
        resp = await http_client.post(
            "/mcp", json=_init_body(), headers=self._headers("http://127.0.0.1")
        )
        assert resp.status_code == 200
        new_sid = resp.headers.get("mcp-session-id")
        if new_sid is not None:
            delete_session(new_sid)

    async def test_non_allowlisted_origin_rejected(
        self, http_client: AsyncClient
    ) -> None:
        """An origin outside the allow list is refused with 403."""
        resp = await http_client.post(
            "/mcp", json=_init_body(), headers=self._headers("https://attacker.com")
        )
        assert resp.status_code == 403

    async def test_origin_with_subdomain_rejected(
        self, http_client: AsyncClient
    ) -> None:
        """localhost.attacker.com must not be confused with localhost."""
        resp = await http_client.post(
            "/mcp", json=_init_body(), headers=self._headers("http://localhost.attacker.com")
        )
        assert resp.status_code == 403

    async def test_no_origin_non_browser_allowed(
        self, http_client: AsyncClient
    ) -> None:
        """Clients that send no Origin (curl, stdio bridges) are permitted."""
        resp = await http_client.post("/mcp", json=_init_body(), headers=self._headers())
        assert resp.status_code == 200
        new_sid = resp.headers.get("mcp-session-id")
        if new_sid is not None:
            delete_session(new_sid)
593
594
class TestSecuritySessionCapacity:
    """The in-memory session store refuses to grow past its configured cap."""

    def test_session_capacity_error_on_overflow(self) -> None:
        """create_session raises SessionCapacityError once the store is full.

        ``_MAX_SESSIONS`` is patched to 0 so the very next create_session call
        must fail. ``patch.object`` restores the original value on exit, even
        when the assertion fails.
        """
        from musehub.mcp import session as _session_mod

        with patch.object(_session_mod, "_MAX_SESSIONS", 0):
            with pytest.raises(SessionCapacityError):
                create_session(None, {})
607
608
609 # ── Layer 7 β€” Performance ─────────────────────────────────────────────────────
610
611
class TestPerformanceDispatcher:
    """Latency budget for the in-process JSON-RPC dispatcher."""

    async def test_100_ping_requests_under_100ms(self) -> None:
        """100× handle_request('ping') must complete in under 100 ms."""
        session = create_session(None, {})
        try:
            req = {"jsonrpc": "2.0", "id": 1, "method": "ping"}
            start = time.perf_counter()
            for _ in range(100):
                await handle_request(req, session=session)
            elapsed_ms = (time.perf_counter() - start) * 1000
        finally:
            # Release the session even if the dispatcher raises mid-loop.
            delete_session(session.session_id)
        assert elapsed_ms < 100, f"100 pings took {elapsed_ms:.1f} ms"
623
624
class TestPerformanceSessionOps:
    """Latency budgets for cheap per-session operations."""

    def test_1000_is_expired_calls_under_10ms(self) -> None:
        """1000× MCPSession.is_expired() must complete in under 10 ms."""
        s = create_session(None, {})
        try:
            start = time.perf_counter()
            for _ in range(1000):
                s.is_expired()
            elapsed_ms = (time.perf_counter() - start) * 1000
            # Message repaired from mis-encoded "Γ—" to the intended "×".
            assert elapsed_ms < 10, f"1000× is_expired took {elapsed_ms:.1f} ms"
        finally:
            delete_session(s.session_id)

    def test_100_push_to_session_under_50ms(self) -> None:
        """100× push_to_session on a session with no SSE queues must take < 50 ms."""
        s = create_session(None, {})
        try:
            start = time.perf_counter()
            for i in range(100):
                push_to_session(s, f"data: {i}\n\n")
            elapsed_ms = (time.perf_counter() - start) * 1000
            assert elapsed_ms < 50, f"100× push_to_session took {elapsed_ms:.1f} ms"
        finally:
            delete_session(s.session_id)
648 delete_session(s.session_id)