"""Tests for derive_agent_sub_seed memory zeroing. derive_agent_sub_seed manually chains master_key → child_key × 4 without zeroing any intermediate DerivedKey, and returns the final key material as immutable bytes. This means: 1. Four intermediate DerivedKey objects (each with 32-byte private_bytes and chain_code) linger in the Python heap after reassignment. 2. The returned 64-byte bytearray lives in the caller's frame until manually zeroed — callers must now do this. Fix: - Zero each intermediate DerivedKey before reassigning dk. - Zero the final DerivedKey after building the return value. - Return bytearray so callers can zero it after use. Coverage -------- I Return type is bytearray I1 derive_agent_sub_seed returns bytearray, not bytes II Callers can zero the returned value II1 returned bytearray can be overwritten in-place III Intermediate keys are zeroed (via monkey-patch inspection) III1 all DerivedKey objects created during derivation are zeroed by the time the function returns IV Correctness unaffected IV1 same inputs → same 64-byte output (determinism preserved) IV2 output length is always 64 bytes """ from __future__ import annotations from unittest.mock import patch, call import pytest from muse.core.bip39 import mnemonic_to_seed from muse.core.hdkeys import derive_agent_sub_seed, DOMAIN_IDENTITY from muse.core.slip010 import DerivedKey _MNEMONIC = ( "abandon abandon abandon abandon abandon abandon abandon abandon " "abandon abandon abandon about" ) _SEED = mnemonic_to_seed(_MNEMONIC) # --------------------------------------------------------------------------- # I Return type is bytearray # --------------------------------------------------------------------------- class TestReturnType: def test_I1_returns_bytearray(self) -> None: """I1: derive_agent_sub_seed must return bytearray so callers can zero it.""" result = derive_agent_sub_seed(_SEED, domain=DOMAIN_IDENTITY, agent_id=0) assert isinstance(result, bytearray), ( f"Expected bytearray, got {type(result).__name__}" ) # --------------------------------------------------------------------------- # II Callers can zero the returned value # --------------------------------------------------------------------------- class TestCallerCanZero: def test_II1_returned_bytearray_is_mutable(self) -> None: """II1: the returned bytearray can be overwritten in place.""" result = derive_agent_sub_seed(_SEED, domain=DOMAIN_IDENTITY, agent_id=0) assert any(b != 0 for b in result), "pre-condition: result must not already be zero" result[:] = b"\x00" * len(result) assert result == bytearray(64), "caller must be able to zero the returned bytearray" # --------------------------------------------------------------------------- # III Intermediate keys are zeroed # --------------------------------------------------------------------------- class TestIntermediatesZeroed: def test_III1_all_derived_keys_zeroed_on_return(self) -> None: """III1: every DerivedKey created during derivation is zeroed by the time derive_agent_sub_seed returns.""" import muse.core.hdkeys as hdkeys_mod from muse.core.slip010 import master_key as real_master, child_key as real_child created: list[DerivedKey] = [] def tracking_master(seed: bytes) -> DerivedKey: dk = real_master(seed) created.append(dk) return dk def tracking_child(parent: DerivedKey, index: int) -> DerivedKey: dk = real_child(parent, index) created.append(dk) return dk # Patch in the hdkeys namespace where the names are bound with patch.object(hdkeys_mod, "master_key", side_effect=tracking_master), \ patch.object(hdkeys_mod, "child_key", side_effect=tracking_child): derive_agent_sub_seed(_SEED, domain=DOMAIN_IDENTITY, agent_id=0) assert created, "No DerivedKey objects were tracked — something is wrong" for i, dk in enumerate(created): assert dk.private_bytes == bytearray(32), ( f"DerivedKey #{i} private_bytes not zeroed after derive_agent_sub_seed" ) assert dk.chain_code == bytearray(32), ( f"DerivedKey #{i} chain_code not zeroed after derive_agent_sub_seed" ) # --------------------------------------------------------------------------- # IV Correctness unaffected # --------------------------------------------------------------------------- class TestCorrectness: def test_IV1_deterministic(self) -> None: """IV1: same inputs always produce the same 64 bytes.""" r1 = derive_agent_sub_seed(_SEED, domain=DOMAIN_IDENTITY, agent_id=0) r2 = derive_agent_sub_seed(_SEED, domain=DOMAIN_IDENTITY, agent_id=0) # Compare before zeroing assert bytes(r1) == bytes(r2), "derive_agent_sub_seed must be deterministic" def test_IV2_length_64(self) -> None: """IV2: output is always exactly 64 bytes.""" result = derive_agent_sub_seed(_SEED, domain=DOMAIN_IDENTITY, agent_id=0) assert len(result) == 64