"""Tests for DerivedKey memory zeroing after use. DerivedKey.private_bytes and DerivedKey.chain_code used to be immutable ``bytes`` — they could not be zeroed, so raw key material lingered in the Python heap indefinitely after derivation. Fix: - Fields changed to ``bytearray`` so contents can be overwritten. - ``DerivedKey.zero()`` sets both fields to all-zero bytes. - ``derive_path`` zeroes each intermediate DerivedKey after deriving the next child. - ``derive_hd_public_info`` zeroes the final DerivedKey after the Ed25519 PrivateKey object has been created. Coverage -------- I DerivedKey fields are bytearray I1 private_bytes is bytearray, not bytes I2 chain_code is bytearray, not bytes II DerivedKey.zero() wipes both fields II1 after zero(), private_bytes is all-zero II2 after zero(), chain_code is all-zero II3 zero() does not affect the length (still 32 bytes) III derive_hd_public_info zeroes the final DerivedKey III1 private_bytes is all-zero in the DerivedKey after derive_hd_public_info returns III2 chain_code is all-zero in the DerivedKey after derive_hd_public_info returns IV Derivation still correct after zeroing changes IV1 same seed → same fingerprint (deterministic derivation unchanged) """ from __future__ import annotations from unittest.mock import patch import pytest from muse.core import hdkeys as _hdkeys from muse.core.slip010 import master_key, DerivedKey from muse.core.bip39 import mnemonic_to_seed from muse.core.keypair import derive_hd_public_info _MNEMONIC = ( "abandon abandon abandon abandon abandon abandon abandon abandon " "abandon abandon abandon about" ) _SEED = mnemonic_to_seed(_MNEMONIC) # --------------------------------------------------------------------------- # I DerivedKey fields are bytearray # --------------------------------------------------------------------------- class TestDerivedKeyFieldTypes: def test_I1_private_bytes_is_bytearray(self) -> None: """I1: DerivedKey.private_bytes must be bytearray, not bytes.""" dk = master_key(_SEED) assert isinstance(dk.private_bytes, bytearray), ( f"private_bytes must be bytearray, got {type(dk.private_bytes).__name__}" ) def test_I2_chain_code_is_bytearray(self) -> None: """I2: DerivedKey.chain_code must be bytearray, not bytes.""" dk = master_key(_SEED) assert isinstance(dk.chain_code, bytearray), ( f"chain_code must be bytearray, got {type(dk.chain_code).__name__}" ) # --------------------------------------------------------------------------- # II DerivedKey.zero() wipes both fields # --------------------------------------------------------------------------- class TestDerivedKeyZero: def test_II1_zero_wipes_private_bytes(self) -> None: """II1: after zero(), private_bytes contains only null bytes.""" dk = master_key(_SEED) assert any(b != 0 for b in dk.private_bytes), "pre-condition: key must not already be zero" dk.zero() assert dk.private_bytes == bytearray(32), "private_bytes must be all-zero after zero()" def test_II2_zero_wipes_chain_code(self) -> None: """II2: after zero(), chain_code contains only null bytes.""" dk = master_key(_SEED) assert any(b != 0 for b in dk.chain_code), "pre-condition: chain_code must not already be zero" dk.zero() assert dk.chain_code == bytearray(32), "chain_code must be all-zero after zero()" def test_II3_zero_preserves_length(self) -> None: """II3: zero() does not change the field lengths.""" dk = master_key(_SEED) dk.zero() assert len(dk.private_bytes) == 32 assert len(dk.chain_code) == 32 # --------------------------------------------------------------------------- # III derive_hd_public_info zeroes the final DerivedKey # --------------------------------------------------------------------------- class TestDeriveHdPublicInfoZeroing: def test_III1_private_bytes_zeroed_after_derive(self) -> None: """III1: the DerivedKey's private_bytes are all-zero after derive_hd_public_info.""" captured: list[DerivedKey] = [] original_derive = _hdkeys.derive_identity_key def capturing_derive(*args: int | bytes, **kwargs: int) -> DerivedKey: dk = original_derive(*args, **kwargs) captured.append(dk) return dk with patch.object(_hdkeys, "derive_identity_key", side_effect=capturing_derive): derive_hd_public_info(_SEED) assert captured, "derive_identity_key was not called" dk = captured[0] assert dk.private_bytes == bytearray(32), ( "private_bytes must be zeroed after derive_hd_public_info" ) def test_III2_chain_code_zeroed_after_derive(self) -> None: """III2: the DerivedKey's chain_code is all-zero after derive_hd_public_info.""" captured: list[DerivedKey] = [] original_derive = _hdkeys.derive_identity_key def capturing_derive(*args: int | bytes, **kwargs: int) -> DerivedKey: dk = original_derive(*args, **kwargs) captured.append(dk) return dk with patch.object(_hdkeys, "derive_identity_key", side_effect=capturing_derive): derive_hd_public_info(_SEED) dk = captured[0] assert dk.chain_code == bytearray(32), ( "chain_code must be zeroed after derive_hd_public_info" ) # --------------------------------------------------------------------------- # IV Derivation still correct # --------------------------------------------------------------------------- class TestDerivedKeyZeroingCorrectness: def test_IV1_same_seed_same_fingerprint(self) -> None: """IV1: zeroing does not affect determinism — same seed → same fingerprint.""" _, fp1 = derive_hd_public_info(_SEED) _, fp2 = derive_hd_public_info(_SEED) assert fp1 == fp2, "Zeroing must not break deterministic derivation"