"""Unit tests for the musehub.crypto.keys abstraction layer. Covers every public function, every error path, every algorithm boundary, and every security-critical property documented in keys.py. Red-team coverage: - Bit-flip attacks on signature bytes - Bit-flip attacks on public key bytes - Zero-length and over-length inputs - Cross-algorithm key/signature confusion - Constant-time fingerprint comparison side-channel - b64url padding stripping (both directions) """ from __future__ import annotations import os import time from muse.core.types import public_key_fingerprint import pytest from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey from musehub.crypto.keys import ( AlgorithmNotImplementedError, DEFAULT_ALGORITHM, IMPLEMENTED_ALGORITHMS, SIGNATURE_SIZES, PUBLIC_KEY_SIZES, InvalidKeyError, KeyAlgorithm, SignatureError, b64url_decode, b64url_encode, fingerprints_equal, key_fingerprint, verify_signature, ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _ed25519_keypair() -> tuple[Ed25519PrivateKey, bytes]: priv = Ed25519PrivateKey.generate() pub = priv.public_key().public_bytes_raw() return priv, pub def _sign_ed25519(priv: Ed25519PrivateKey, msg: bytes) -> bytes: return priv.sign(msg) # --------------------------------------------------------------------------- # KeyAlgorithm enum # --------------------------------------------------------------------------- class TestKeyAlgorithmEnum: def test_ed25519_value(self) -> None: assert KeyAlgorithm.ED25519.value == "ed25519" def test_ml_dsa_65_value(self) -> None: assert KeyAlgorithm.ML_DSA_65.value == "ml-dsa-65" def test_round_trip_from_string(self) -> None: assert KeyAlgorithm("ed25519") is KeyAlgorithm.ED25519 def test_unknown_string_raises(self) -> None: with pytest.raises(ValueError): KeyAlgorithm("rsa-2048") def test_default_algorithm_is_ed25519(self) -> None: assert DEFAULT_ALGORITHM is KeyAlgorithm.ED25519 def test_ed25519_is_implemented(self) -> None: assert KeyAlgorithm.ED25519 in IMPLEMENTED_ALGORITHMS def test_ml_dsa_65_is_not_yet_implemented(self) -> None: # When this test fails, it means ML-DSA-65 was added — good! # Update IMPLEMENTED_ALGORITHMS and remove this assert. assert KeyAlgorithm.ML_DSA_65 not in IMPLEMENTED_ALGORITHMS # --------------------------------------------------------------------------- # Key size registry # --------------------------------------------------------------------------- class TestKeySizes: def test_ed25519_public_key_is_32_bytes(self) -> None: assert PUBLIC_KEY_SIZES[KeyAlgorithm.ED25519] == 32 def test_ml_dsa_65_public_key_is_1952_bytes(self) -> None: assert PUBLIC_KEY_SIZES[KeyAlgorithm.ML_DSA_65] == 1952 def test_ed25519_signature_is_64_bytes(self) -> None: assert SIGNATURE_SIZES[KeyAlgorithm.ED25519] == 64 def test_ml_dsa_65_signature_is_3309_bytes(self) -> None: assert SIGNATURE_SIZES[KeyAlgorithm.ML_DSA_65] == 3309 def test_all_algorithms_have_key_and_sig_size(self) -> None: for algo in KeyAlgorithm: assert algo in PUBLIC_KEY_SIZES, f"Missing PUBLIC_KEY_SIZES entry for {algo}" assert algo in SIGNATURE_SIZES, f"Missing SIGNATURE_SIZES entry for {algo}" # --------------------------------------------------------------------------- # key_fingerprint # --------------------------------------------------------------------------- class TestKeyFingerprint: def test_is_sha256_prefixed_hex(self) -> None: raw = os.urandom(32) expected = public_key_fingerprint(raw) assert key_fingerprint(raw) == expected def test_output_is_71_chars(self) -> None: assert len(key_fingerprint(os.urandom(32))) == 71 def test_starts_with_sha256_prefix(self) -> None: fp = key_fingerprint(os.urandom(32)) assert fp.startswith("sha256:") def test_hex_part_is_lowercase(self) -> None: fp = key_fingerprint(os.urandom(32)) hex_part = fp[len("sha256:"):] assert hex_part == hex_part.lower() def test_different_keys_have_different_fingerprints(self) -> None: a = os.urandom(32) b = os.urandom(32) assert key_fingerprint(a) != key_fingerprint(b) def test_same_key_always_same_fingerprint(self) -> None: raw = os.urandom(32) assert key_fingerprint(raw) == key_fingerprint(raw) def test_empty_bytes_does_not_crash(self) -> None: fp = key_fingerprint(b"") assert len(fp) == 71 assert fp.startswith("sha256:") def test_large_key_bytes_work(self) -> None: # ML-DSA-65 key: 1952 bytes fp = key_fingerprint(os.urandom(1952)) assert len(fp) == 71 assert fp.startswith("sha256:") # --------------------------------------------------------------------------- # fingerprints_equal — constant-time comparison # --------------------------------------------------------------------------- class TestFingerprintsEqual: def test_equal_fingerprints(self) -> None: raw = os.urandom(32) fp = key_fingerprint(raw) assert fingerprints_equal(fp, fp) is True def test_different_fingerprints(self) -> None: fp_a = key_fingerprint(os.urandom(32)) fp_b = key_fingerprint(os.urandom(32)) assert fingerprints_equal(fp_a, fp_b) is False def test_case_insensitive(self) -> None: fp = key_fingerprint(os.urandom(32)) assert fingerprints_equal(fp.upper(), fp.lower()) is True def test_timing_is_not_short_circuit(self) -> None: """ Both equal and unequal comparisons must take approximately the same time — hmac.compare_digest processes all bytes regardless of mismatch. This test is probabilistic; flakiness indicates a timing leak. """ raw = os.urandom(32) fp = key_fingerprint(raw) fp_wrong = key_fingerprint(os.urandom(32)) samples = 1000 times_equal = [] times_unequal = [] for _ in range(samples): t0 = time.perf_counter_ns() fingerprints_equal(fp, fp) times_equal.append(time.perf_counter_ns() - t0) t0 = time.perf_counter_ns() fingerprints_equal(fp, fp_wrong) times_unequal.append(time.perf_counter_ns() - t0) # Median times should be within 10× of each other (very lenient — # the real guarantee comes from hmac.compare_digest itself). median_eq = sorted(times_equal)[samples // 2] median_ne = sorted(times_unequal)[samples // 2] ratio = max(median_eq, median_ne) / max(min(median_eq, median_ne), 1) assert ratio < 10, ( f"Suspicious timing gap: equal={median_eq}ns unequal={median_ne}ns ratio={ratio:.1f}x" ) # --------------------------------------------------------------------------- # b64url_encode / b64url_decode # --------------------------------------------------------------------------- class TestB64url: def test_round_trip(self) -> None: for _ in range(50): raw = os.urandom(64) assert b64url_decode(b64url_encode(raw)) == raw def test_no_padding_in_encoded(self) -> None: for length in range(1, 40): assert "=" not in b64url_encode(os.urandom(length)) def test_url_safe_chars_only(self) -> None: import string allowed = set(string.ascii_letters + string.digits + "-_") for _ in range(50): encoded = b64url_encode(os.urandom(64)) assert set(encoded) <= allowed, f"Non-url-safe chars in: {encoded}" def test_decode_with_padding(self) -> None: raw = os.urandom(10) encoded_with_padding = f"{b64url_encode(raw)}==" assert b64url_decode(encoded_with_padding) == raw def test_decode_without_padding(self) -> None: raw = os.urandom(10) encoded = b64url_encode(raw) assert b64url_decode(encoded) == raw def test_empty_bytes(self) -> None: assert b64url_encode(b"") == "" assert b64url_decode("") == b"" def test_known_vector(self) -> None: # RFC 4648 §10: bytes [0xFB, 0xFF, 0xFE] → "+//+" in standard base64 # → "-__-" in base64url raw = bytes([0xFB, 0xFF, 0xFE]) assert b64url_encode(raw) == "-__-" assert b64url_decode("-__-") == raw # --------------------------------------------------------------------------- # verify_signature — Ed25519 # --------------------------------------------------------------------------- class TestVerifySignatureEd25519: def test_valid_signature(self) -> None: priv, pub = _ed25519_keypair() msg = os.urandom(32) sig = _sign_ed25519(priv, msg) verify_signature( algorithm=KeyAlgorithm.ED25519, public_key_bytes=pub, message=msg, signature_bytes=sig, ) # must not raise def test_wrong_message_rejected(self) -> None: priv, pub = _ed25519_keypair() msg = os.urandom(32) sig = _sign_ed25519(priv, msg) with pytest.raises(SignatureError): verify_signature( algorithm=KeyAlgorithm.ED25519, public_key_bytes=pub, message=msg + b"\x00", # one extra byte signature_bytes=sig, ) def test_wrong_key_rejected(self) -> None: priv_a, pub_a = _ed25519_keypair() priv_b, pub_b = _ed25519_keypair() msg = os.urandom(32) sig = _sign_ed25519(priv_a, msg) with pytest.raises(SignatureError): verify_signature( algorithm=KeyAlgorithm.ED25519, public_key_bytes=pub_b, # wrong key message=msg, signature_bytes=sig, ) def test_bit_flip_in_signature_rejected(self) -> None: priv, pub = _ed25519_keypair() msg = os.urandom(32) sig = bytearray(_sign_ed25519(priv, msg)) sig[0] ^= 0xFF # flip first byte with pytest.raises(SignatureError): verify_signature( algorithm=KeyAlgorithm.ED25519, public_key_bytes=pub, message=msg, signature_bytes=bytes(sig), ) def test_bit_flip_last_byte_rejected(self) -> None: priv, pub = _ed25519_keypair() msg = os.urandom(32) sig = bytearray(_sign_ed25519(priv, msg)) sig[-1] ^= 0x01 # flip single bit at end with pytest.raises(SignatureError): verify_signature( algorithm=KeyAlgorithm.ED25519, public_key_bytes=pub, message=msg, signature_bytes=bytes(sig), ) def test_bit_flip_in_public_key_rejected(self) -> None: priv, pub = _ed25519_keypair() msg = os.urandom(32) sig = _sign_ed25519(priv, msg) bad_pub = bytearray(pub) bad_pub[0] ^= 0x01 with pytest.raises((SignatureError, InvalidKeyError)): verify_signature( algorithm=KeyAlgorithm.ED25519, public_key_bytes=bytes(bad_pub), message=msg, signature_bytes=sig, ) def test_zeroed_signature_rejected(self) -> None: priv, pub = _ed25519_keypair() msg = os.urandom(32) with pytest.raises(SignatureError): verify_signature( algorithm=KeyAlgorithm.ED25519, public_key_bytes=pub, message=msg, signature_bytes=bytes(64), ) def test_zeroed_public_key_rejected(self) -> None: priv, pub = _ed25519_keypair() msg = os.urandom(32) sig = _sign_ed25519(priv, msg) with pytest.raises((SignatureError, InvalidKeyError)): verify_signature( algorithm=KeyAlgorithm.ED25519, public_key_bytes=bytes(32), message=msg, signature_bytes=sig, ) def test_short_public_key_rejected(self) -> None: priv, pub = _ed25519_keypair() msg = os.urandom(32) sig = _sign_ed25519(priv, msg) with pytest.raises(InvalidKeyError): verify_signature( algorithm=KeyAlgorithm.ED25519, public_key_bytes=pub[:31], # one byte short message=msg, signature_bytes=sig, ) def test_long_public_key_rejected(self) -> None: priv, pub = _ed25519_keypair() msg = os.urandom(32) sig = _sign_ed25519(priv, msg) with pytest.raises(InvalidKeyError): verify_signature( algorithm=KeyAlgorithm.ED25519, public_key_bytes=pub + b"\x00", # one byte extra message=msg, signature_bytes=sig, ) def test_short_signature_rejected(self) -> None: priv, pub = _ed25519_keypair() msg = os.urandom(32) sig = _sign_ed25519(priv, msg) with pytest.raises(SignatureError): verify_signature( algorithm=KeyAlgorithm.ED25519, public_key_bytes=pub, message=msg, signature_bytes=sig[:63], ) def test_long_signature_rejected(self) -> None: priv, pub = _ed25519_keypair() msg = os.urandom(32) sig = _sign_ed25519(priv, msg) with pytest.raises(SignatureError): verify_signature( algorithm=KeyAlgorithm.ED25519, public_key_bytes=pub, message=msg, signature_bytes=sig + b"\x00", ) def test_empty_message_is_allowed(self) -> None: """Ed25519 is defined for all-length messages including empty.""" priv, pub = _ed25519_keypair() sig = _sign_ed25519(priv, b"") verify_signature( algorithm=KeyAlgorithm.ED25519, public_key_bytes=pub, message=b"", signature_bytes=sig, ) def test_large_message(self) -> None: priv, pub = _ed25519_keypair() msg = os.urandom(1024 * 1024) # 1 MB sig = _sign_ed25519(priv, msg) verify_signature( algorithm=KeyAlgorithm.ED25519, public_key_bytes=pub, message=msg, signature_bytes=sig, ) # --------------------------------------------------------------------------- # verify_signature — ML-DSA-65 (not yet implemented) # --------------------------------------------------------------------------- class TestVerifySignatureMlDsa65: def test_raises_not_implemented(self) -> None: with pytest.raises(AlgorithmNotImplementedError) as exc_info: verify_signature( algorithm=KeyAlgorithm.ML_DSA_65, public_key_bytes=os.urandom(1952), message=b"hello", signature_bytes=os.urandom(3309), ) assert "ml-dsa-65" in str(exc_info.value).lower() def test_error_message_mentions_upgrade_path(self) -> None: with pytest.raises(AlgorithmNotImplementedError) as exc_info: verify_signature( algorithm=KeyAlgorithm.ML_DSA_65, public_key_bytes=os.urandom(1952), message=b"hello", signature_bytes=os.urandom(3309), ) msg = str(exc_info.value) assert "keys.py" in msg or "defined" in msg # --------------------------------------------------------------------------- # b64url_decode — canonical algo-prefixed values # --------------------------------------------------------------------------- class TestBase64UrlCodecContracts: """Strict contracts for base64url codec functions. ``b64url_decode`` — bare-only, for the MSign header ``sig=`` field. ``decode_pubkey`` / ``decode_sig`` — canonical prefixed, for all stored values. There is no backward compatibility: every cryptographic value is either explicitly bare (MSign sig= by protocol design) or canonically prefixed. Mixing these up is a programming error, not a supported usage. """ def test_b64url_decode_bare_value(self) -> None: """b64url_decode correctly decodes bare base64url (its only valid input).""" raw = os.urandom(32) encoded = b64url_encode(raw) assert b64url_decode(encoded) == raw def test_b64url_decode_bare_64_byte_signature(self) -> None: """b64url_decode decodes a bare 64-byte signature (MSign sig= use case).""" raw = os.urandom(64) encoded = b64url_encode(raw) assert b64url_decode(encoded) == raw def test_decode_pubkey_extracts_raw_bytes(self) -> None: """decode_pubkey correctly decodes a canonical ``ed25519:`` public key.""" from muse.core.types import decode_pubkey, encode_pubkey raw = os.urandom(32) prefixed = encode_pubkey("ed25519", raw) algo, decoded = decode_pubkey(prefixed) assert algo == "ed25519" assert decoded == raw def test_decode_sig_extracts_raw_bytes(self) -> None: """decode_sig correctly decodes a canonical ``ed25519:`` signature.""" from muse.core.types import decode_sig, encode_sig raw = os.urandom(64) prefixed = encode_sig("ed25519", raw) algo, decoded = decode_sig(prefixed) assert algo == "ed25519" assert decoded == raw def test_decode_pubkey_round_trip_with_real_key(self) -> None: """encode_pubkey → decode_pubkey round-trips a real Ed25519 public key.""" from muse.core.types import encode_pubkey, decode_pubkey from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat priv = Ed25519PrivateKey.generate() raw = priv.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw) prefixed = encode_pubkey("ed25519", raw) algo, decoded = decode_pubkey(prefixed) assert algo == "ed25519" assert decoded == raw