gabriel / muse public
test_runner.py python
626 lines 20.7 KB
Raw
sha256:81ae324db5ad375fbfe4834c6fcb378312cafad3cc92dec5d3e5c427306621a2 fix: remove commit_exists filter from have anchors — server… Sonnet 4.6 patch 21 days ago
1 """Subprocess pytest adapter for ``muse code test``.
2
3 Executes a targeted set of pytest node IDs in an isolated subprocess
4 environment, captures structured results via pytest-json-report, and
5 returns a typed :class:`RunResult`.
6
7 Design principles
8 -----------------
9 **Isolation** — Each ``run_tests`` call spawns a fresh subprocess. Only
10 the variables named in *env_allowlist* are passed through from the parent
11 environment. This prevents secrets (tokens, credentials, API keys) from
12 leaking into the test environment and ensures test results are reproducible.
13
**Parallelism** — When ``workers > 1``, the target list is split into
roughly ``workers`` contiguous partitions, each executed in its own
independent pytest subprocess. Per-partition results are merged, in
partition order, into a single :class:`RunResult`. This is a coarse
partition strategy. Per-test parallelism would need ``pytest-xdist``
(``-n``/``--dist``), but those flags are not on the *extra_args*
allowlist and are stripped before pytest is invoked.
19
20 **Budget** — A wall-clock timeout (``timeout_s``) is enforced via
21 ``subprocess.run(timeout=...)``. When exceeded the subprocess is killed and
22 a ``timed_out=True`` flag is set on the :class:`RunResult`. Individual test
23 budgets can be enforced with ``pytest-timeout`` in *extra_args*
24 (``--timeout=N``).
25
26 **Fallback** — When pytest-json-report is not installed, the adapter falls
27 back to parsing pytest's ``--tb=short`` text output. The fallback extracts
28 pass/fail counts but cannot produce per-test :class:`CaseResult` records;
29 it returns a summary-only result with an empty ``results`` list and sets
30 ``json_report_available=False``.
31
32 Security
33 --------
34 * ``subprocess.run`` is called with ``shell=False`` — the command is a
35 ``list[str]`` never passed to a shell interpreter.
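* The subprocess receives no environment variables except those explicitly
  listed in *env_allowlist* plus a fixed set of interpreter, temp-directory
  and locale variables (``_MANDATORY_ENV_VARS``: ``PATH``, ``HOME``,
  ``PYTHONPATH``, ``PYTHONHOME``, ``PYTHONDONTWRITEBYTECODE``,
  ``VIRTUAL_ENV``, ``TMPDIR``, ``TMP``, ``TEMP``, ``LANG``, ``LC_ALL``,
  ``LC_CTYPE``).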
39 * pytest is invoked as ``sys.executable -m pytest`` so the test runner uses
40 the same interpreter as the calling process — it is not resolved from
41 ``PATH`` and cannot be hijacked by a malicious ``pytest`` script on the
42 user's ``PATH``.
43 * The JSON report is written to a ``tempfile.NamedTemporaryFile`` in the
44 system temp directory and deleted after reading. It is not world-writable.
45 """
46
import json
import logging
import os
import pathlib
import re
import subprocess
import sys
import tempfile
import time
from collections.abc import Callable, Iterator, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Literal, NotRequired, TypedDict

from muse.core.timing import start_timer
from muse.core.types import Manifest, content_hash, now_utc_iso
59
logger = logging.getLogger(__name__)

# Flag prefixes that callers (agents or users) may pass through extra_args.
# _filter_extra_args drops, with a debug log, every flag that matches none of
# these. This keeps out options that could redirect or hijack the run, such
# as --rootdir, -p (plugin loading) or --import-mode.
_SAFE_PYTEST_FLAG_PREFIXES: frozenset[str] = frozenset(
    [
        # Verbosity, output capture and fail-fast switches.
        "-v",
        "-vv",
        "-s",
        "-x",
        # Selection by keyword / marker expression.
        "-k",
        "-m",
        # Reporting and output shaping.
        "--tb=",
        "--durations=",
        "--durations",
        "--no-header",
        # Spelled out in full on purpose: the short alias "--co" would also
        # prefix-match "--confcutdir".
        "--collect-only",
        # Limits and logging.
        "--maxfail=",
        "--timeout=",
        "--log-level=",
        "--log-cli-level=",
    ]
)
82
def _filter_extra_args(extra_args: list[str]) -> list[str]:
    """Return only the safe subset of *extra_args*.

    Strips any flag not in :data:`_SAFE_PYTEST_FLAG_PREFIXES` to prevent
    injection via ``--rootdir``, ``--import-mode``, ``-p``, ``--override-ini``,
    ``--confcutdir``, or other subprocess-escaping vectors.

    Matching rules:

    * Long options (``--name...``) are kept when they start with one of the
      allowlisted ``--`` prefixes.
    * Short options are matched strictly. argparse lets single-dash flags be
      clustered and lets the last flag of a cluster take a glued-on value, so
      plain prefix matching would accept ``-vpevil`` (``-v -p evil``),
      ``-sc/tmp/x.ini`` (``-s -c /tmp/x.ini``) or ``-voaddopts=...``.
      A short argument is therefore kept only when it is:

      - exactly an allowlisted flag,
      - a cluster made solely of allowlisted value-less letters (``-vvs``), or
      - an allowlisted value-taking flag with its value attached
        (``-kexpr``, ``-mslow``).

    Positional values that immediately follow a dropped flag are also dropped,
    because they are the value of the banned flag (e.g. ``--rootdir /etc``).
    Positionals that do not directly follow a kept value-taking flag are
    dropped as well.

    Args:
        extra_args: Raw extra arguments supplied by the caller.

    Returns:
        The arguments that passed the allowlist, in their original order.
    """
    # Allowlisted flags that consume the *next* argv entry as their value.
    value_flags = {"-k", "-m", "--durations", "--maxfail", "--timeout"}

    long_prefixes = tuple(p for p in _SAFE_PYTEST_FLAG_PREFIXES if p.startswith("--"))
    short_flags = {p for p in _SAFE_PYTEST_FLAG_PREFIXES if not p.startswith("--")}
    short_value_flags = tuple(short_flags & value_flags)
    # Letters of allowlisted short flags that take no value (v, s, x). A
    # cluster is safe only if every letter in it is one of these.
    bare_letters = {
        ch for flag in short_flags if flag not in value_flags for ch in flag[1:]
    }

    def is_allowed(arg: str) -> bool:
        if arg.startswith("--"):
            return arg.startswith(long_prefixes)
        if arg in short_flags:
            return True
        if arg.startswith(short_value_flags):
            # "-kexpr" / "-mmarker": argparse treats the rest as the value.
            return True
        letters = arg[1:]
        return bool(letters) and set(letters) <= bare_letters

    safe: list[str] = []
    # None = neutral, True = prev flag was kept and expects a value,
    # False = prev flag was dropped (its value must be dropped too).
    prev_kept_expects_value: bool | None = None

    for arg in extra_args:
        if not arg.startswith("-"):
            # Positional: keep only if it is the value of a kept flag.
            if prev_kept_expects_value is True:
                safe.append(arg)
            elif prev_kept_expects_value is False:
                logger.debug("test_runner: dropping positional after unsafe flag %r", arg)
            else:
                logger.debug("test_runner: dropping stray positional extra_arg %r", arg)
            prev_kept_expects_value = None
            continue

        if is_allowed(arg):
            safe.append(arg)
            prev_kept_expects_value = arg in value_flags
        else:
            logger.debug("test_runner: dropping unsafe extra_arg %r", arg)
            # Mark as dropped so the next positional (if any) is also dropped.
            prev_kept_expects_value = False

    return safe
118
119 # ---------------------------------------------------------------------------
120 # Public type definitions
121 # ---------------------------------------------------------------------------
122
# Normalised outcome label attached to every test case.
Outcome = Literal["passed", "failed", "error", "skipped"]

# Always forwarded to the pytest subprocess, independent of any allowlist.
# These cover interpreter/import-path resolution, the temp-directory lookup
# used by ``tempfile``, and locale/text-encoding settings.
_MANDATORY_ENV_VARS: frozenset[str] = frozenset(
    (
        # Interpreter and import-path resolution.
        "PATH",
        "HOME",
        "PYTHONPATH",
        "PYTHONHOME",
        "PYTHONDONTWRITEBYTECODE",
        "VIRTUAL_ENV",
        # Temporary-directory resolution.
        "TMPDIR",
        "TEMP",
        "TMP",
        # Locale / text encoding.
        "LANG",
        "LC_ALL",
        "LC_CTYPE",
    )
)

# Non-secret variables forwarded by default: repo context, the CI marker,
# and terminal size / colour settings.
DEFAULT_ENV_ALLOWLIST: list[str] = [
    "MUSE_REPO_ROOT",
    "MUSE_TEST_ENV",
    "CI",
    "TERM",
    "COLORTERM",
    "COLUMNS",
    "ROWS",
    "NO_COLOR",
]
155
156 class RunConfig(TypedDict):
157 """Configuration for a single ``run_tests`` invocation."""
158
159 targets: list[str]
160 """Pytest node IDs or file paths to execute.
161
162 Each entry may be:
163 * A file path: ``"tests/test_foo.py"``
164 * A node ID: ``"tests/test_foo.py::TestBar::test_baz"``
165 * Empty list → pytest discovers all tests under ``testpaths`` from
166 ``pytest.ini``/``pyproject.toml``.
167 """
168
169 workers: int
170 """Number of parallel subprocess partitions. ``1`` = single process."""
171
172 timeout_s: float
173 """Wall-clock budget per *partition* in seconds. ``0`` = unlimited."""
174
175 extra_args: list[str]
176 """Additional arguments forwarded verbatim to pytest after node IDs."""
177
178 env_allowlist: list[str]
179 """Additional env var names to forward (beyond mandatory vars)."""
180
181 cwd: pathlib.Path | None
182 """Working directory for the subprocess. ``None`` = inherit CWD."""
183
184 stream_output: bool
185 """When ``True``, pytest's stdout/stderr are inherited by the parent
186 process so output streams live to the terminal. ``False`` (the default)
187 captures both streams for machine-readable processing (``--json`` mode)."""
188
class CaseResult(TypedDict):
    """Outcome of a single pytest test, as recorded in the JSON report."""

    # Pytest node ID, e.g. "tests/test_foo.py::TestBar::test_baz".
    node_id: str

    # Normalised test outcome (one of the Outcome literals).
    outcome: Outcome

    # Wall-clock execution time in milliseconds.
    duration_ms: float

    # Output pytest captured from the test. Empty when the test printed
    # nothing, and also when capturing was disabled (``-s``): pytest then
    # writes the output straight to the terminal instead of recording it.
    stdout: str

    # Captured stderr; same capture caveats as ``stdout``.
    stderr: str

    # Short failure representation. Absent when the report supplied none
    # (e.g. for passing tests).
    longrepr: NotRequired[str]
209
class RunResult(TypedDict):
    """Aggregated outcome of one ``run_tests`` call across all partitions."""

    # Content-addressed ``sha256:`` run identifier (the same ID used for the
    # run's RunRecord in history).
    run_id: str

    # The targets that were handed to pytest.
    targets: list[str]

    # pytest's exit status (0 = all passed, 1 = failures, 2 = interrupted,
    # ...). With several partitions, the largest status is reported.
    exit_code: int

    # Total wall-clock time of the run, in milliseconds.
    duration_ms: float

    # One record per test; empty when no JSON report was available.
    results: list[CaseResult]

    # Number of test cases counted in the run.
    total: int

    # Tests that passed.
    passed: int

    # Tests that failed.
    failed: int

    # Tests that pytest reported with the "error" outcome.
    errored: int

    # Tests that were skipped.
    skipped: int

    # True if any partition was killed for exceeding *timeout_s*.
    timed_out: bool

    # True if pytest-json-report was available and used for this run.
    json_report_available: bool

    # Captured stdout of every partition, newline-joined (used by fallback
    # mode).
    stdout: str

    # Captured stderr of every partition, newline-joined.
    stderr: str
254
255 # ---------------------------------------------------------------------------
256 # Internal helpers
257 # ---------------------------------------------------------------------------
258
def _build_env(allowlist: Sequence[str]) -> Manifest:
    """Return the environment mapping handed to the pytest subprocess.

    A variable of the current process is copied over only when its name is
    mandatory (:data:`_MANDATORY_ENV_VARS`) or appears in *allowlist*.
    Everything else, such as tokens, credentials and API keys, is left behind.
    """
    permitted = set(_MANDATORY_ENV_VARS)
    permitted.update(allowlist)

    env = {}
    for name, value in os.environ.items():
        if name in permitted:
            env[name] = value
    return env
272
def _check_json_report() -> bool:
    """Tell whether the ``pytest_jsonreport`` plugin module can be found.

    The lookup runs in this interpreter, which is also the interpreter the
    pytest subprocess is launched with (``sys.executable``). Any exception
    raised while probing counts as "not installed".
    """
    try:
        import importlib.util

        spec = importlib.util.find_spec("pytest_jsonreport")
    except Exception:
        return False
    return spec is not None
280
class _ParsedReport(TypedDict):
    """Per-test records plus aggregate counts read from one JSON report."""

    results: list[CaseResult]  # one entry per test listed in the report
    total: int  # the report summary's total
    passed: int  # tests counted as passed
    failed: int  # tests counted as failed
    errored: int  # tests counted as errors
    skipped: int  # tests counted as skipped
290
# pytest-json-report per-test outcome -> CaseResult outcome. Expected-failure
# outcomes are folded into the closest plain outcome instead of being labelled
# errors; any outcome not listed here (including "error") maps to "error".
_JSON_OUTCOME_MAP: dict[str, Outcome] = {
    "passed": "passed",
    "failed": "failed",
    "skipped": "skipped",
    "xpassed": "passed",
    "xfailed": "skipped",
}


def _as_int(value: object, default: int = 0) -> int:
    """Return ``int(value)``, or *default* if *value* is missing or non-numeric."""
    try:
        return int(value)  # type: ignore[call-overload]
    except (TypeError, ValueError):
        return default


def _as_float(value: object, default: float = 0.0) -> float:
    """Return ``float(value)``, or *default* if *value* is missing or non-numeric."""
    try:
        return float(value)  # type: ignore[arg-type]
    except (TypeError, ValueError):
        return default


def _parse_json_report(path: str) -> _ParsedReport:
    """Parse a pytest-json-report output file into a :class:`_ParsedReport`.

    The parser is best-effort and does not raise for bad report content:

    * An unreadable or undecodable file logs a warning and yields an
      all-zero report.
    * A non-object document yields an all-zero report.
    * Malformed numeric fields fall back to 0, or to the number of parsed
      tests for ``total``.

    Outcome mapping (per test and in the aggregate counts):

    * ``passed`` / ``failed`` / ``skipped`` map to themselves.
    * ``xpassed`` counts as ``passed`` and ``xfailed`` as ``skipped``.
    * Anything else maps to ``error``.

    ``duration_ms``, ``stdout`` and ``stderr`` come from the ``call`` phase.
    ``longrepr`` comes from the ``call`` phase. For failed or errored tests
    without one, it falls back to the ``setup`` and then the ``teardown``
    phase, where fixture errors are recorded.

    Args:
        path: Path of the JSON report written by pytest-json-report.

    Returns:
        The per-test records and aggregate counts.
    """
    try:
        with open(path, encoding="utf-8") as fh:
            doc = json.load(fh)
    except (OSError, ValueError) as exc:
        logger.warning("⚠️ test_runner: failed to read JSON report %s: %s", path, exc)
        return _ParsedReport(
            results=[], total=0, passed=0, failed=0, errored=0, skipped=0
        )

    if not isinstance(doc, dict):
        return _ParsedReport(
            results=[], total=0, passed=0, failed=0, errored=0, skipped=0
        )

    summary = doc.get("summary")
    if not isinstance(summary, dict):
        summary = {}

    raw_tests = doc.get("tests")
    if not isinstance(raw_tests, list):
        raw_tests = []

    results: list[CaseResult] = []
    for test in raw_tests:
        if not isinstance(test, dict):
            continue

        outcome: Outcome = _JSON_OUTCOME_MAP.get(
            str(test.get("outcome", "error")), "error"
        )
        call_info = test.get("call")
        if not isinstance(call_info, dict):
            call_info = {}

        rec = CaseResult(
            node_id=str(test.get("nodeid", "")),
            outcome=outcome,
            duration_ms=_as_float(call_info.get("duration")) * 1000.0,
            stdout=str(call_info.get("stdout", "") or ""),
            stderr=str(call_info.get("stderr", "") or ""),
        )

        longrepr = str(call_info.get("longrepr") or "")
        if not longrepr and outcome in ("failed", "error"):
            # Errors raised outside the test body (fixture setup/teardown)
            # carry their traceback on that phase, not on "call".
            for phase in ("setup", "teardown"):
                info = test.get(phase)
                if isinstance(info, dict) and info.get("longrepr"):
                    longrepr = str(info["longrepr"])
                    break
        if longrepr:
            rec["longrepr"] = longrepr
        results.append(rec)

    return _ParsedReport(
        results=results,
        total=_as_int(summary.get("total"), default=len(results)),
        passed=_as_int(summary.get("passed")) + _as_int(summary.get("xpassed")),
        failed=_as_int(summary.get("failed")),
        errored=_as_int(summary.get("error")),
        skipped=_as_int(summary.get("skipped")) + _as_int(summary.get("xfailed")),
    )
363
class _FallbackCounts(TypedDict):
    """Aggregate counts scraped from pytest's text output (no JSON report)."""

    total: int  # sum of the four counts below
    passed: int  # tests reported as passed
    failed: int  # tests reported as failed
    errored: int  # tests reported as errors
    skipped: int  # tests reported as skipped
372
# "<count> <outcome>" pairs in pytest's final summary line, e.g.
# "1 failed, 2 passed, 1 skipped, 1 warning in 0.12s".
_SUMMARY_COUNT_RE = re.compile(r"(\d+) (passed|failed|errors?|skipped|xpassed|xfailed)\b")
# The " in 0.12s" / " in 1.23 seconds" duration that terminates the summary.
_SUMMARY_DURATION_RE = re.compile(r"\bin \d+(?:\.\d+)?\s*s(?:econds)?\b")
# Summary word -> _FallbackCounts key. Expected-failure outcomes are folded
# into the closest plain outcome.
_SUMMARY_WORD_TO_KEY: dict[str, str] = {
    "passed": "passed",
    "xpassed": "passed",
    "failed": "failed",
    "error": "errored",
    "errors": "errored",
    "skipped": "skipped",
    "xfailed": "skipped",
}


def _parse_text_output(stdout: str) -> _FallbackCounts:
    """Extract aggregate counts from pytest's plain-text output.

    Only ONE line is parsed: the last line that both contains an outcome
    count and ends with pytest's "in N.NNs" duration, i.e. the final summary
    line. Earlier lines are ignored even if they contain "N passed"-like
    text, e.g. output printed by a test. When no terminated summary exists,
    as with truncated output after a timeout, the last line containing any
    count is used instead.

    Summary lines that report only skips (e.g. "3 skipped in 0.01s") are
    recognised. ``xpassed`` counts as passed and ``xfailed`` as skipped, and
    ``total`` is the sum of the four counts.

    Args:
        stdout: Captured pytest stdout.

    Returns:
        The parsed counts; all zero if no summary could be found.
    """
    lines = stdout.splitlines()
    summary = next(
        (
            line
            for line in reversed(lines)
            if _SUMMARY_COUNT_RE.search(line) and _SUMMARY_DURATION_RE.search(line)
        ),
        None,
    )
    if summary is None:
        summary = next(
            (line for line in reversed(lines) if _SUMMARY_COUNT_RE.search(line)), ""
        )

    counts = {"passed": 0, "failed": 0, "errored": 0, "skipped": 0}
    for number, word in _SUMMARY_COUNT_RE.findall(summary):
        counts[_SUMMARY_WORD_TO_KEY[word]] += int(number)

    return _FallbackCounts(
        total=sum(counts.values()),
        passed=counts["passed"],
        failed=counts["failed"],
        errored=counts["errored"],
        skipped=counts["skipped"],
    )
410
class _PartitionResult(TypedDict):
    """Everything one pytest subprocess partition produced."""

    # pytest exit status; 124 after a timeout, 127 if pytest could not start.
    exit_code: int
    # Wall-clock time spent on the partition, in milliseconds.
    duration_ms: float
    # True when the partition exceeded its timeout.
    timed_out: bool
    # Parsed JSON report, or None when none was used.
    report: _ParsedReport | None
    # Captured output of the subprocess ("" when output was streamed).
    stdout: str
    stderr: str
    # Counts scraped from text output when no JSON report was used.
    fallback_counts: _FallbackCounts | None
421
def _neutralize_targets(targets: list[str]) -> list[str]:
    """Return *targets* with every option-looking entry turned into a path.

    Targets go on the pytest command line as positionals, so an entry such
    as ``-pevil`` or ``--rootdir=/`` would otherwise be parsed as an option
    and bypass :func:`_filter_extra_args`. Prefixing ``./`` keeps such an
    entry a path (which pytest reports as not found if it does not exist).
    """
    neutralized: list[str] = []
    for target in targets:
        if target.startswith("-"):
            logger.warning(
                "⚠️ test_runner: target %r starts with '-'; passing it as a path",
                target,
            )
            neutralized.append(f"./{target}")
        else:
            neutralized.append(target)
    return neutralized


def _report_written(report_path: str) -> bool:
    """Return True if *report_path* exists and is non-empty.

    The report file is pre-created empty by the caller, so a zero-length
    file means pytest never wrote a report (e.g. it could not be launched).
    """
    try:
        return os.path.getsize(report_path) > 0
    except OSError:
        return False


def _run_partition(
    targets: list[str],
    config: RunConfig,
    json_report: bool,
    report_path: str,
) -> _PartitionResult:
    """Execute one pytest subprocess for *targets* and return its result.

    Args:
        targets: Node IDs / paths for this partition. Entries starting with
            ``-`` are passed as ``./<entry>`` so they cannot act as options.
        config: Run configuration (extra args, env allowlist, cwd, timeout,
            stream mode).
        json_report: Whether to ask pytest-json-report to write *report_path*.
        report_path: File the JSON report is written to. The caller creates
            and deletes it.

    Returns:
        A :class:`_PartitionResult`. ``report`` is set only when a JSON
        report was requested, the run did not time out, and pytest actually
        wrote a non-empty report. Otherwise ``fallback_counts`` is parsed
        from captured stdout, when there is any.
    """
    stream = config.get("stream_output", False)

    cmd: list[str] = [sys.executable, "-m", "pytest"]
    if json_report:
        cmd += ["--json-report", f"--json-report-file={report_path}"]
    # Streamed runs show per-test lines (-v); captured runs use quiet mode.
    cmd += ["--tb=short", "-v"] if stream else ["--tb=short", "-q"]
    cmd += _neutralize_targets(targets)
    cmd += _filter_extra_args(config["extra_args"])

    env = _build_env(config["env_allowlist"])
    cwd_str: str | None = str(config["cwd"]) if config["cwd"] else None
    # Non-positive budgets mean "no limit".
    timeout: float | None = config["timeout_s"] if config["timeout_s"] > 0 else None

    elapsed = start_timer()
    timed_out = False
    exit_code = 0
    stdout = ""
    stderr = ""

    try:
        if stream:
            # Inherit the parent's stdout/stderr so pytest output flows live.
            stream_proc: subprocess.CompletedProcess[bytes] = subprocess.run(
                cmd,
                timeout=timeout,
                env=env,
                cwd=cwd_str,
            )
            exit_code = stream_proc.returncode
        else:
            cap_proc: subprocess.CompletedProcess[str] = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=timeout,
                env=env,
                cwd=cwd_str,
            )
            exit_code = cap_proc.returncode
            stdout = cap_proc.stdout or ""
            stderr = cap_proc.stderr or ""
    except subprocess.TimeoutExpired as exc:
        timed_out = True
        exit_code = 124  # same as GNU timeout convention
        if not stream:
            # TimeoutExpired carries partial output as bytes even with text=True.
            raw_out = exc.stdout or b""
            raw_err = exc.stderr or b""
            stdout = raw_out.decode(errors="replace") if isinstance(raw_out, bytes) else raw_out
            stderr = raw_err.decode(errors="replace") if isinstance(raw_err, bytes) else raw_err
        logger.warning("⚠️ test_runner: partition timed out after %.1f s", config["timeout_s"])
    except OSError as exc:
        exit_code = 127
        stderr = f"failed to launch pytest: {exc}"
        logger.error("❌ test_runner: subprocess error: %s", exc)

    parsed_report: _ParsedReport | None = None
    fallback: _FallbackCounts | None = None

    # An empty report file means pytest never wrote one; parsing it would
    # produce a misleading all-zero report and skip the text fallback.
    if json_report and not timed_out and _report_written(report_path):
        try:
            parsed_report = _parse_json_report(report_path)
        except Exception as exc:
            logger.debug("test_runner: json report parse failed: %s", exc)

    # Only parse text output when we actually captured it. Stream mode
    # inherits the parent's stdout/stderr, so `stdout` is always "" there,
    # and parsing it would produce a misleading "0 passed 0 failed" summary.
    if parsed_report is None and stdout:
        fallback = _parse_text_output(stdout)

    return _PartitionResult(
        exit_code=exit_code,
        duration_ms=elapsed(),
        timed_out=timed_out,
        report=parsed_report,
        stdout=stdout,
        stderr=stderr,
        fallback_counts=fallback,
    )
516
def _partition(items: list[str], n: int) -> list[list[str]]:
    """Split *items* into at most *n* contiguous, near-equal partitions.

    Partition sizes differ by at most one, and concatenating the partitions
    gives back *items* in order. Example: 5 items over 2 workers gives sizes
    [3, 2], never 3 partitions.

    Edge cases:
        * ``n <= 1`` or empty *items* returns ``[items]``, a single partition.
          For an empty target list, this lets pytest discover tests itself.
        * ``n > len(items)`` returns one single-item partition per item.

    Args:
        items: Targets to split.
        n: Desired number of partitions (the worker count).

    Returns:
        The list of partitions.
    """
    if n <= 1 or not items:
        return [items]

    count = min(n, len(items))
    base, extra = divmod(len(items), count)
    parts: list[list[str]] = []
    start = 0
    for index in range(count):
        # The first `extra` partitions absorb the remainder, one item each.
        size = base + (1 if index < extra else 0)
        parts.append(items[start : start + size])
        start += size
    return parts
526
527 # ---------------------------------------------------------------------------
528 # Public API
529 # ---------------------------------------------------------------------------
530
def _normalize_exit_code(code: int) -> int:
    """Map a signal-killed child's negative returncode to ``128 + signum``.

    ``subprocess`` reports "terminated by signal N" as ``-N`` on POSIX.
    Merging partitions with ``max()`` would otherwise make such a crash look
    like success (``max(0, -11) == 0``).
    """
    return 128 - code if code < 0 else code


def _execute_partition(
    part_targets: list[str],
    config: RunConfig,
    json_report: bool,
) -> _PartitionResult:
    """Run one partition with its own temp JSON report file, always removing it."""
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w") as tmp:
        report_path = tmp.name
    try:
        return _run_partition(part_targets, config, json_report, report_path)
    finally:
        try:
            os.unlink(report_path)
        except OSError:
            pass


def _iter_partition_results(
    partitions: list[list[str]],
    config: RunConfig,
    json_report: bool,
) -> Iterator[_PartitionResult]:
    """Yield one result per partition, in partition order.

    Partitions run concurrently, one thread per partition, each blocking on
    its own pytest subprocess. The exceptions are a single partition and
    ``stream_output`` runs. Streamed runs stay sequential, so several pytest
    processes do not interleave their live output on one terminal.
    """
    if len(partitions) <= 1 or config.get("stream_output", False):
        for part_targets in partitions:
            yield _execute_partition(part_targets, config, json_report)
        return

    with ThreadPoolExecutor(max_workers=len(partitions)) as pool:
        futures = [
            pool.submit(_execute_partition, part_targets, config, json_report)
            for part_targets in partitions
        ]
        # Waiting in submission order keeps the merged result deterministic
        # while still handing back early partitions as soon as they finish.
        for future in futures:
            yield future.result()


def run_tests(
    config: RunConfig,
    *,
    progress_cb: Callable[[CaseResult], None] | None = None,
) -> RunResult:
    """Execute the tests described by *config* and return structured results.

    The targets are split into partitions (see :func:`_partition`), and each
    partition runs in its own pytest subprocess. With more than one partition
    and ``stream_output`` off, the subprocesses run concurrently. Streamed
    runs execute one partition at a time. Results are merged in partition
    order.

    Args:
        config: See :class:`RunConfig` for full documentation.
        progress_cb: Optional callback invoked once per :class:`CaseResult`
            (only for partitions that produced a JSON report). It is called
            synchronously, on the calling thread, as each partition's
            results are merged.

    Returns:
        A :class:`RunResult` with per-test outcomes and aggregate counts.
        ``exit_code`` is the largest partition exit status. A partition
        killed by a signal (negative returncode) counts as ``128 + signal``.
    """
    targets = config["targets"]
    workers = max(1, config["workers"])
    _started_at = now_utc_iso()
    run_id = content_hash({"started_at": _started_at, "targets": sorted(targets), "workers": workers})
    json_report = _check_json_report()

    if not json_report:
        logger.info(
            "ℹ️ pytest-jsonreport not installed — running in fallback mode "
            "(aggregate counts only, no per-test results)"
        )

    elapsed = start_timer()
    partitions = _partition(targets, workers)

    all_results: list[CaseResult] = []
    all_stdout: list[str] = []
    all_stderr: list[str] = []
    combined_exit_code = 0
    any_timed_out = False
    total = passed = failed = errored = skipped = 0

    for part in _iter_partition_results(partitions, config, json_report):
        combined_exit_code = max(combined_exit_code, _normalize_exit_code(part["exit_code"]))
        any_timed_out = any_timed_out or part["timed_out"]
        all_stdout.append(part["stdout"])
        all_stderr.append(part["stderr"])

        report = part["report"]
        fallback = part["fallback_counts"]
        if report is not None:
            total += report["total"]
            passed += report["passed"]
            failed += report["failed"]
            errored += report["errored"]
            skipped += report["skipped"]
            for res in report["results"]:
                all_results.append(res)
                if progress_cb is not None:
                    progress_cb(res)
        elif fallback is not None:
            total += fallback["total"]
            passed += fallback["passed"]
            failed += fallback["failed"]
            errored += fallback["errored"]
            skipped += fallback["skipped"]

    return RunResult(
        run_id=run_id,
        targets=targets,
        exit_code=combined_exit_code,
        duration_ms=elapsed(),
        results=all_results,
        total=total,
        passed=passed,
        failed=failed,
        errored=errored,
        skipped=skipped,
        timed_out=any_timed_out,
        json_report_available=json_report,
        stdout="\n".join(all_stdout),
        stderr="\n".join(all_stderr),
    )
File History 4 commits
sha256:81ae324db5ad375fbfe4834c6fcb378312cafad3cc92dec5d3e5c427306621a2 fix: remove commit_exists filter from have anchors — server… Sonnet 4.6 patch 21 days ago
sha256:36c3cb3e76619d4c30a6d9bf81b5ec4ff148e30dcfed913e3114ca7b43b81c7e fix: rename objects→blobs in push client and all stale test… Sonnet 4.6 patch 23 days ago
sha256:c06a9b9b9fee26c68ea725b44d54b2c0a171301ce9de746d5b656617b4463a9a fix: repair four test failures from post-migration audit Sonnet 4.6 patch 29 days ago
sha256:1900655993c83c4107067375548a7be823e471d2515830842f1a12cba4bd3cdf fix: unified object store migration — idempotent writes, JS… Sonnet 4.6 minor 29 days ago