From e919690d6a549e30cab67c9c555b9c020a066862 Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Mon, 5 Jan 2026 12:02:24 +0800 Subject: [PATCH] Update XMSS --- include/gmssl/xmss.h | 11 +- src/xmss.c | 618 ++++++++++++++++++++++--------------------- 2 files changed, 322 insertions(+), 307 deletions(-) diff --git a/include/gmssl/xmss.h b/include/gmssl/xmss.h index b00322f7..c41a2caf 100644 --- a/include/gmssl/xmss.h +++ b/include/gmssl/xmss.h @@ -237,8 +237,8 @@ int xmss_private_key_print(FILE *fp, int fmt, int ind, const char *label, const typedef struct { uint32_t index; // < 2^(XMSS_MAX_HEIGHT) = 2^20, always encode to 4 bytes - uint8_t random[32]; - hash256_t wots_sig[67]; + hash256_t random; + wots_sig_t wots_sig; hash256_t auth_path[XMSS_MAX_HEIGHT]; } XMSS_SIGNATURE; @@ -248,23 +248,24 @@ typedef struct { #define XMSS_SIGNATURE_MIN_SIZE (4 + 32 + 32*67 + 32 * XMSS_MIN_HEIGHT) // = 2500 bytes #define XMSS_SIGNATURE_MAX_SIZE (4 + 32 + 32*67 + 32 * XMSS_MAX_HEIGHT) // = 2820 bytes int xmss_signature_size(uint32_t xmss_type, size_t *siglen); +int xmss_signature_to_bytes(const XMSS_SIGNATURE *sig, uint32_t xmss_type, uint8_t **out, size_t *outlen); +int xmss_signature_from_bytes(XMSS_SIGNATURE *sig, uint32_t xmss_type, const uint8_t **in, size_t *inlen); int xmss_signature_print(FILE *fp, int fmt, int ind, const char *label, const uint8_t *in, size_t inlen); int xmss_signature_print_ex(FILE *fp, int fmt, int ind, const char *label, const XMSS_SIGNATURE *sig); - typedef struct { XMSS_PUBLIC_KEY xmss_public_key; XMSS_SIGNATURE xmss_sig; HASH256_CTX hash256_ctx; } XMSS_SIGN_CTX; - int xmss_sign_init(XMSS_SIGN_CTX *ctx, XMSS_KEY *key); int xmss_sign_update(XMSS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen); int xmss_sign_finish(XMSS_SIGN_CTX *ctx, uint8_t *sigbuf, size_t *siglen); int xmss_verify_init(XMSS_SIGN_CTX *ctx, const XMSS_KEY *key, const uint8_t *sigbuf, size_t siglen); int xmss_verify_update(XMSS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen); int xmss_verify_finish(XMSS_SIGN_CTX *ctx); +void xmss_sign_ctx_cleanup(XMSS_SIGN_CTX *ctx); enum { @@ -412,7 +413,7 @@ int xmssmt_verify_init_ex(XMSSMT_SIGN_CTX *ctx, const XMSSMT_KEY *key, const XMS int xmssmt_verify_init(XMSSMT_SIGN_CTX *ctx, const XMSSMT_KEY *key, const uint8_t *sig, size_t siglen); int xmssmt_verify_update(XMSSMT_SIGN_CTX *ctx, const uint8_t *data, size_t datalen); int xmssmt_verify_finish(XMSSMT_SIGN_CTX *ctx); - +void xmssmt_sign_ctx_cleanup(XMSSMT_SIGN_CTX *ctx); #ifdef __cplusplus } diff --git a/src/xmss.c b/src/xmss.c index 5f955c97..66107d6f 100644 --- a/src/xmss.c +++ b/src/xmss.c @@ -175,11 +175,11 @@ int xmss_adrs_print(FILE *fp, int fmt, int ind, const char *label, const hash256 tree_address = GETU64(adrs); adrs += 8; - format_print(fp, fmt, ind, "tree_address: %"PRIu64"\n", tree_address); + format_print(fp, fmt, ind, "tree_address : %"PRIu64"\n", tree_address); type = GETU32(adrs); adrs += 4; - format_print(fp, fmt, ind, "type: %"PRIu32"\n", type); + format_print(fp, fmt, ind, "type : %"PRIu32"\n", type); if (type == XMSS_ADRS_TYPE_OTS) { uint32_t ots_address; @@ -188,13 +188,13 @@ int xmss_adrs_print(FILE *fp, int fmt, int ind, const char *label, const hash256 ots_address = GETU32(adrs); adrs += 4; - format_print(fp, fmt, ind, "ots_address: %"PRIu32"\n", ots_address); + format_print(fp, fmt, ind, "ots_address : %"PRIu32"\n", ots_address); chain_address = GETU32(adrs); adrs += 4; format_print(fp, fmt, ind, "chain_address: %"PRIu32"\n", chain_address); hash_address = GETU32(adrs); adrs += 4; - format_print(fp, fmt, ind, "hash_address: %"PRIu32"\n", hash_address); + format_print(fp, fmt, ind, "hash_address : %"PRIu32"\n", hash_address); } else if (type == XMSS_ADRS_TYPE_LTREE) { uint32_t ltree_address; uint32_t tree_height; @@ -205,10 +205,10 @@ int xmss_adrs_print(FILE *fp, int fmt, int ind, const char *label, const hash256 format_print(fp, fmt, ind, "ltree_address: %"PRIu32"\n", ltree_address); tree_height = GETU32(adrs); adrs += 4; - format_print(fp, fmt, ind, "tree_height: %"PRIu32"\n", tree_height); + format_print(fp, fmt, ind, "tree_height : %"PRIu32"\n", tree_height); tree_index = GETU32(adrs); adrs += 4; - format_print(fp, fmt, ind, "tree_index: %"PRIu32"\n", tree_index); + format_print(fp, fmt, ind, "tree_index : %"PRIu32"\n", tree_index); } else if (type == XMSS_ADRS_TYPE_HASHTREE) { uint32_t padding; uint32_t tree_height; @@ -216,20 +216,20 @@ int xmss_adrs_print(FILE *fp, int fmt, int ind, const char *label, const hash256 padding = GETU32(adrs); adrs += 4; - format_print(fp, fmt, ind, "padding: %"PRIu32"\n", padding); + format_print(fp, fmt, ind, "padding : %"PRIu32"\n", padding); tree_height = GETU32(adrs); adrs += 4; - format_print(fp, fmt, ind, "tree_height: %"PRIu32"\n", tree_height); + format_print(fp, fmt, ind, "tree_height : %"PRIu32"\n", tree_height); tree_index = GETU32(adrs); adrs += 4; - format_print(fp, fmt, ind, "tree_index: %"PRIu32"\n", tree_index); + format_print(fp, fmt, ind, "tree_index : %"PRIu32"\n", tree_index); } else { error_print(); } key_and_mask = GETU32(adrs); adrs += 4; - format_print(fp, fmt, ind, "key_and_mask: %"PRIu32"\n", key_and_mask); + format_print(fp, fmt, ind, "key_and_mask : %"PRIu32"\n", key_and_mask); return 1; } @@ -246,78 +246,73 @@ void wots_derive_sk(const hash256_t secret, }; HASH256_CTX ctx; xmss_adrs_t adrs; - int i; + int chain; adrs_copy_layer_address(adrs, ots_adrs); adrs_copy_tree_address(adrs, ots_adrs); adrs_copy_type(adrs, ots_adrs); adrs_copy_ots_address(adrs, ots_adrs); - for (i = 0; i < 67; i++) { - adrs_set_chain_address(adrs, i); + for (chain = 0; chain < WOTS_NUM_CHAINS; chain++) { + adrs_set_chain_address(adrs, chain); adrs_set_hash_address(adrs, 0); - adrs_set_key_and_mask(adrs, 0); + adrs_set_key_and_mask(adrs, XMSS_ADRS_GENERATE_KEY); hash256_init(&ctx); hash256_update(&ctx, hash256_four, sizeof(hash256_t)); hash256_update(&ctx, secret, sizeof(hash256_t)); hash256_update(&ctx, seed, sizeof(hash256_t)); hash256_update(&ctx, adrs, sizeof(xmss_adrs_t)); - hash256_finish(&ctx, sk[i]); + hash256_finish(&ctx, sk[chain]); } } void wots_chain(const hash256_t x, const hash256_t seed, const xmss_adrs_t ots_adrs, - int start, int steps, hash256_t pk) + int start, int steps, hash256_t y) { const hash256_t hash256_zero = {0}; HASH256_CTX ctx; - uint8_t adrs[32]; + xmss_adrs_t adrs; + hash256_t key; + hash256_t bitmask; int i; - //assert(start >= 0 && start <= 15); - //assert(steps >= 0 && steps <= 15); - //assert(start + steps <= 15); + // tmp = x + memcpy(y, x, sizeof(hash256_t)); - memcpy(pk, x, 32); - - // 4 * 6 = 24, copy 24 bytes adrs_copy_layer_address(adrs, ots_adrs); adrs_copy_tree_address(adrs, ots_adrs); adrs_copy_type(adrs, ots_adrs); adrs_copy_ots_address(adrs, ots_adrs); adrs_copy_chain_address(adrs, ots_adrs); - for (i = start; i < start + steps; i++) { - uint8_t key[32]; - uint8_t bitmask[32]; - - adrs_set_hash_address(adrs, i); + for (i = 0; i < steps; i++) { + adrs_set_hash_address(adrs, start + i); // key = prf(seed, adrs) adrs_set_key_and_mask(adrs, XMSS_ADRS_GENERATE_KEY); hash256_init(&ctx); - hash256_update(&ctx, hash256_three, 32); - hash256_update(&ctx, seed, 32); - hash256_update(&ctx, adrs, 32); + hash256_update(&ctx, hash256_three, sizeof(hash256_t)); + hash256_update(&ctx, seed, sizeof(hash256_t)); + hash256_update(&ctx, adrs, sizeof(xmss_adrs_t)); hash256_finish(&ctx, key); // bitmask = prf(seed, adrs) adrs_set_key_and_mask(adrs, XMSS_ADRS_GENERATE_BITMASK); hash256_init(&ctx); - hash256_update(&ctx, hash256_three, 32); - hash256_update(&ctx, seed, 32); - hash256_update(&ctx, adrs, 32); + hash256_update(&ctx, hash256_three, sizeof(hash256_t)); + hash256_update(&ctx, seed, sizeof(hash256_t)); + hash256_update(&ctx, adrs, sizeof(xmss_adrs_t)); hash256_finish(&ctx, bitmask); // tmp = f(key, tmp xor bitmask) - gmssl_memxor(pk, pk, bitmask, 32); + gmssl_memxor(y, y, bitmask, sizeof(hash256_t)); hash256_init(&ctx); - hash256_update(&ctx, hash256_zero, 32); - hash256_update(&ctx, key, 32); - hash256_update(&ctx, pk, 32); - hash256_finish(&ctx, pk); + hash256_update(&ctx, hash256_zero, sizeof(hash256_t)); + hash256_update(&ctx, key, sizeof(hash256_t)); + hash256_update(&ctx, y, sizeof(hash256_t)); + hash256_finish(&ctx, y); } } @@ -326,23 +321,26 @@ void wots_sk_to_pk(const wots_key_t sk, const hash256_t seed, const xmss_adrs_t ots_adrs, wots_key_t pk) { + const int start = 0; + const int steps = WOTS_WINTERNITZ_W - 1; xmss_adrs_t adrs; - int i; + int chain; adrs_copy_layer_address(adrs, ots_adrs); adrs_copy_tree_address(adrs, ots_adrs); adrs_copy_type(adrs, ots_adrs); adrs_copy_ots_address(adrs, ots_adrs); - for (i = 0; i < 67; i++) { - adrs_set_chain_address(adrs, i); - wots_chain(sk[i], seed, adrs, 0, 15, pk[i]); + for (chain = 0; chain < WOTS_NUM_CHAINS; chain++) { + adrs_set_chain_address(adrs, chain); + wots_chain(sk[chain], seed, adrs, start, steps, pk[chain]); } } // seperate 256 bit digest into 256/4 = 64 step values, generate 3 checksum step values // output steps[i] in [0, w-1] = [0, 16-1] -static void base_w_and_checksum(const hash256_t dgst, uint8_t steps[67]) +// this implementation is for hash256 and w=16 only! +static void base_w_and_checksum(const hash256_t dgst, int steps[67]) { int csum = 0; int sbits; @@ -370,33 +368,13 @@ static void base_w_and_checksum(const hash256_t dgst, uint8_t steps[67]) } void wots_sign(const wots_key_t sk, - const hash256_t seed, const xmss_adrs_t wots_adrs, + const hash256_t seed, const xmss_adrs_t ots_adrs, const hash256_t dgst, wots_key_t sig) { hash256_t adrs; - uint8_t steps[WOTS_NUM_CHAINS]; - int i; - - adrs_copy_layer_address(adrs, wots_adrs); - adrs_copy_tree_address(adrs, wots_adrs); - adrs_copy_type(adrs, wots_adrs); - adrs_copy_ots_address(adrs, wots_adrs); - - base_w_and_checksum(dgst, steps); - - for (i = 0; i < WOTS_NUM_CHAINS; i++) { - adrs_set_chain_address(adrs, i); - wots_chain(sk[i], seed, adrs, 0, steps[i], sig[i]); - } -} - -void wots_sig_to_pk(const wots_sig_t sig, - const hash256_t seed, const xmss_adrs_t ots_adrs, - const hash256_t dgst, wots_key_t pk) -{ - hash256_t adrs; - uint8_t steps[67]; - int i; + const int start = 0; + int steps[WOTS_NUM_CHAINS]; + int chain; adrs_copy_layer_address(adrs, ots_adrs); adrs_copy_tree_address(adrs, ots_adrs); @@ -405,9 +383,30 @@ void wots_sig_to_pk(const wots_sig_t sig, base_w_and_checksum(dgst, steps); - for (i = 0; i < 67; i++) { - adrs_set_chain_address(adrs, i); - wots_chain(sig[i], seed, adrs, steps[i], 15 - steps[i], pk[i]); + for (chain = 0; chain < WOTS_NUM_CHAINS; chain++) { + adrs_set_chain_address(adrs, chain); + wots_chain(sk[chain], seed, adrs, start, steps[chain], sig[chain]); + } +} + +void wots_sig_to_pk(const wots_sig_t sig, + const hash256_t seed, const xmss_adrs_t ots_adrs, + const hash256_t dgst, wots_key_t pk) +{ + hash256_t adrs; + int steps[67]; + int chain; + + adrs_copy_layer_address(adrs, ots_adrs); + adrs_copy_tree_address(adrs, ots_adrs); + adrs_copy_type(adrs, ots_adrs); + adrs_copy_ots_address(adrs, ots_adrs); + + base_w_and_checksum(dgst, steps); + + for (chain = 0; chain < WOTS_NUM_CHAINS; chain++) { + adrs_set_chain_address(adrs, chain); + wots_chain(sig[chain], seed, adrs, steps[chain], 15 - steps[chain], pk[chain]); } } @@ -425,8 +424,8 @@ static void randomized_tree_hash(const hash256_t left_child, const hash256_t rig HASH256_CTX ctx; xmss_adrs_t adrs; hash256_t key; - hash256_t bm_0; - hash256_t bm_1; + hash256_t bm0; + hash256_t bm1; // copy adrs (and set the last key_and_mask) adrs_copy_layer_address(adrs, tree_adrs); @@ -436,38 +435,38 @@ static void randomized_tree_hash(const hash256_t left_child, const hash256_t rig adrs_copy_tree_height(adrs, tree_adrs); adrs_copy_tree_index(adrs, tree_adrs); - adrs_set_key_and_mask(adrs, 0); // key = prf(seed, adrs) + adrs_set_key_and_mask(adrs, 0); hash256_init(&ctx); hash256_update(&ctx, hash256_three, sizeof(hash256_t)); hash256_update(&ctx, seed, sizeof(hash256_t)); hash256_update(&ctx, adrs, sizeof(xmss_adrs_t)); hash256_finish(&ctx, key); - adrs_set_key_and_mask(adrs, 1); // bm_0 = prf(seed, adrs) + adrs_set_key_and_mask(adrs, 1); hash256_init(&ctx); hash256_update(&ctx, hash256_three, sizeof(hash256_t)); hash256_update(&ctx, seed, sizeof(hash256_t)); hash256_update(&ctx, adrs, sizeof(xmss_adrs_t)); - hash256_finish(&ctx, bm_0); + hash256_finish(&ctx, bm0); - adrs_set_key_and_mask(adrs, 2); // bm_1 = prf(seed, adrs) + adrs_set_key_and_mask(adrs, 2); hash256_init(&ctx); hash256_update(&ctx, hash256_three, sizeof(hash256_t)); hash256_update(&ctx, seed, sizeof(hash256_t)); hash256_update(&ctx, adrs, sizeof(xmss_adrs_t)); - hash256_finish(&ctx, bm_1); + hash256_finish(&ctx, bm1); // parent = Hash( tobyte(1, 32) || key || (left xor bm_0) || (right xor bm_1) ) - gmssl_memxor(bm_0, bm_0, left_child, sizeof(hash256_t)); - gmssl_memxor(bm_1, bm_1, right_child, sizeof(hash256_t)); + gmssl_memxor(bm0, bm0, left_child, sizeof(hash256_t)); + gmssl_memxor(bm1, bm1, right_child, sizeof(hash256_t)); hash256_init(&ctx); hash256_update(&ctx, hash256_one, sizeof(hash256_t)); hash256_update(&ctx, key, sizeof(hash256_t)); - hash256_update(&ctx, bm_0, sizeof(hash256_t)); - hash256_update(&ctx, bm_1, sizeof(hash256_t)); + hash256_update(&ctx, bm0, sizeof(hash256_t)); + hash256_update(&ctx, bm1, sizeof(hash256_t)); hash256_finish(&ctx, parent); } @@ -480,19 +479,20 @@ void wots_pk_to_root(const wots_key_t in_pk, xmss_adrs_t adrs; uint32_t tree_height = 0; int len = WOTS_NUM_CHAINS; + uint32_t i; memcpy(pk, in_pk, sizeof(wots_key_t)); adrs_copy_layer_address(adrs, in_adrs); adrs_copy_tree_address(adrs, in_adrs); - adrs_copy_type(adrs, in_adrs); // type must be LTREE + adrs_copy_type(adrs, in_adrs); adrs_copy_ltree_address(adrs, in_adrs); adrs_set_tree_height(adrs, tree_height++); while (len > 1) { - for (i = 0; i < (uint32_t)len/2; i++) { + for (i = 0; i < len/2; i++) { adrs_set_tree_index(adrs, i); randomized_tree_hash(pk[2 * i], pk[2 * i + 1], seed, adrs, pk[i]); } @@ -508,92 +508,83 @@ void wots_pk_to_root(const wots_key_t in_pk, } int wots_verify(const hash256_t wots_root, - const hash256_t seed, const xmss_adrs_t in_adrs, + const hash256_t seed, const xmss_adrs_t ots_adrs, const hash256_t dgst, const wots_sig_t sig) { xmss_adrs_t adrs; wots_key_t pk; hash256_t root; - adrs_copy_layer_address(adrs, in_adrs); - adrs_copy_tree_address(adrs, in_adrs); - - adrs_set_type(adrs, XMSS_ADRS_TYPE_OTS); - adrs_copy_ots_address(adrs, in_adrs); + adrs_copy_layer_address(adrs, ots_adrs); + adrs_copy_tree_address(adrs, ots_adrs); + adrs_copy_type(adrs, ots_adrs); + adrs_copy_ots_address(adrs, ots_adrs); wots_sig_to_pk(sig, seed, adrs, dgst, pk); adrs_set_type(adrs, XMSS_ADRS_TYPE_LTREE); - adrs_copy_ltree_address(adrs, in_adrs); + adrs_copy_ltree_address(adrs, ots_adrs); // ltree_address offset is same as ots_address wots_pk_to_root(pk, seed, adrs, root); - if (memcmp(root, wots_root, sizeof(hash256_t)) == 0) { - return 1; - } else { + if (memcmp(root, wots_root, sizeof(hash256_t)) != 0) { + //error_print(); return 0; } + return 1; } -// adrs: layer_address, tree_address, ots_address or ltree_address should be set void wots_derive_root(const hash256_t secret, - const hash256_t seed, const xmss_adrs_t adrs, + const hash256_t seed, const xmss_adrs_t ots_adrs, hash256_t wots_root) { + xmss_adrs_t adrs; wots_key_t wots_key; - xmss_adrs_t wots_adrs; - xmss_adrs_t ltree_adrs; - adrs_copy_layer_address(wots_adrs, adrs); - adrs_copy_tree_address(wots_adrs, adrs); - adrs_set_type(wots_adrs, XMSS_ADRS_TYPE_OTS); - adrs_copy_ots_address(wots_adrs, adrs); + adrs_copy_layer_address(adrs, ots_adrs); + adrs_copy_tree_address(adrs, ots_adrs); + adrs_copy_type(adrs, ots_adrs); + adrs_copy_ots_address(adrs, ots_adrs); + wots_derive_sk(secret, seed, adrs, wots_key); + wots_sk_to_pk(wots_key, seed, adrs, wots_key); - wots_derive_sk(secret, seed, wots_adrs, wots_key); - wots_sk_to_pk(wots_key, seed, wots_adrs, wots_key); - - adrs_copy_layer_address(ltree_adrs, adrs); - adrs_copy_tree_address(ltree_adrs, adrs); - adrs_set_type(ltree_adrs, XMSS_ADRS_TYPE_LTREE); - adrs_copy_ltree_address(ltree_adrs, adrs); // ltree_address == ots_address - - wots_pk_to_root(wots_key, seed, ltree_adrs, wots_root); + adrs_set_type(adrs, XMSS_ADRS_TYPE_LTREE); + adrs_copy_ltree_address(adrs, ots_adrs); // ltree_address offset is same as ots_address + wots_pk_to_root(wots_key, seed, adrs, wots_root); } -static size_t tree_root_offset(size_t height) { + + +static size_t xmss_tree_root_offset(size_t height) { return (1 << (height + 1)) - 2; } -// 2^(height + 1) - 1 +size_t xmss_num_tree_nodes(size_t height) { + return (1 << (height + 1)) - 1; +} + void xmss_build_tree(const hash256_t secret, - const hash256_t seed, const xmss_adrs_t tree_adrs, + const hash256_t seed, const xmss_adrs_t xmss_adrs, size_t height, hash256_t *tree) { xmss_adrs_t adrs; hash256_t *children; hash256_t *parents; size_t n = 1 << height; - size_t h; - size_t i; + uint32_t h; // as tree_height + uint32_t i; // as tree_index - adrs_copy_layer_address(adrs, tree_adrs); - adrs_copy_tree_address(adrs, tree_adrs); + adrs_copy_layer_address(adrs, xmss_adrs); + adrs_copy_tree_address(adrs, xmss_adrs); // derive 2^h wots+ roots as leaves of xmss tree adrs_set_type(adrs, XMSS_ADRS_TYPE_OTS); - //fprintf(stderr, "xmss_build_tree() progress\n"); for (i = 0; i < n; i++) { adrs_set_ots_address(adrs, i); wots_derive_root(secret, seed, adrs, tree[i]); - /* - if (i % (n/100) == 0 && i/(n/100) <= 100) { - fprintf(stderr, " %zu%%\n", i/(n/100) ); - } - */ } // build xmss tree adrs_set_type(adrs, XMSS_ADRS_TYPE_HASHTREE); adrs_set_padding(adrs, 0); - adrs_set_key_and_mask(adrs, 0); children = tree; parents = tree + n; @@ -609,21 +600,6 @@ void xmss_build_tree(const hash256_t secret, } } -void xmss_do_sign(const hash256_t secret, uint32_t index, - const hash256_t seed, const xmss_adrs_t in_adrs, - const hash256_t dgst, wots_sig_t wots_sig) -{ - xmss_adrs_t adrs; - - adrs_copy_layer_address(adrs, in_adrs); - adrs_copy_tree_address(adrs, in_adrs); - adrs_set_type(adrs, XMSS_ADRS_TYPE_OTS); - adrs_set_ots_address(adrs, index); - - wots_derive_sk(secret, seed, adrs, wots_sig); - wots_sign(wots_sig, seed, adrs, dgst, wots_sig); -} - void xmss_build_auth_path(const hash256_t *tree, size_t height, uint32_t tree_index, hash256_t *auth_path) { size_t h; @@ -634,30 +610,18 @@ void xmss_build_auth_path(const hash256_t *tree, size_t height, uint32_t tree_in } } -static uint64_t xmssmt_tree_address(uint64_t index, size_t height, size_t layers, size_t layer) { - return (index >> (height/layers) * (layer + 1)); -} - -static uint64_t xmssmt_tree_index(uint64_t index, size_t height, size_t layers, size_t layer) { - return (index >> (height/layers) * layer) % (1 << (height/layers)); -} - - - void xmss_build_root(const hash256_t wots_root, uint32_t tree_index, - const hash256_t seed, const xmss_adrs_t in_adrs, + const hash256_t seed, const xmss_adrs_t xmss_adrs, const hash256_t *auth_path, size_t height, hash256_t root) { xmss_adrs_t adrs; - size_t h; - - adrs_copy_layer_address(adrs, in_adrs); - adrs_copy_tree_address(adrs, in_adrs); + uint32_t h; + adrs_copy_layer_address(adrs, xmss_adrs); + adrs_copy_tree_address(adrs, xmss_adrs); adrs_set_type(adrs, XMSS_ADRS_TYPE_HASHTREE); adrs_set_padding(adrs, 0); - adrs_set_key_and_mask(adrs, 0); memcpy(root, wots_root, sizeof(hash256_t)); @@ -666,55 +630,13 @@ void xmss_build_root(const hash256_t wots_root, uint32_t tree_index, tree_index >>= 1; adrs_set_tree_height(adrs, h + 1); adrs_set_tree_index(adrs, tree_index); + if (right_child) randomized_tree_hash(auth_path[h], root, seed, adrs, root); else randomized_tree_hash(root, auth_path[h], seed, adrs, root); } } -// remove this function -void xmss_sig_to_root(const hash256_t wots_sig[67], - const uint8_t seed[32], const uint8_t in_adrs[32], - const uint8_t dgst[32], - const hash256_t *auth_path, int height, - uint8_t xmss_root[32]) -{ - xmss_adrs_t adrs; - wots_key_t wots_pk; - uint8_t *node = xmss_root; - int h; - uint32_t index; - - // wots_sig to wots_pk - adrs_copy_layer_address(adrs, in_adrs); - adrs_copy_tree_address(adrs, in_adrs); - adrs_set_type(adrs, XMSS_ADRS_TYPE_OTS); - adrs_copy_ots_address(adrs, in_adrs); - wots_sig_to_pk(wots_sig, seed, adrs, dgst, wots_pk); - - // wots_pk to wots_root - adrs_set_type(adrs, XMSS_ADRS_TYPE_LTREE); - adrs_copy_ltree_address(adrs, in_adrs); - wots_pk_to_root(wots_pk, seed, adrs, xmss_root); - - index = GETU32(in_adrs + 16); - - // wots_root, auth_path => xmss_root - adrs_set_type(adrs, XMSS_ADRS_TYPE_HASHTREE); - adrs_set_padding(adrs, 0); - adrs_set_key_and_mask(adrs, 0); - - for (h = 0; h < height; h++) { - int right = index & 1; - index >>= 1; - adrs_set_tree_height(adrs, h); - adrs_set_tree_index(adrs, index); - if (right) - randomized_tree_hash(auth_path[h], node, seed, adrs, node); - else randomized_tree_hash(node, auth_path[h], seed, adrs, node); - } -} - int xmss_type_to_height(uint32_t xmss_type, size_t *height) { switch (xmss_type) { @@ -750,13 +672,13 @@ uint32_t xmss_type_from_name(const char *name) return 0; } -int xmss_key_generate(XMSS_KEY *key, uint32_t xmss_type) +int xmss_key_generate_ex(XMSS_KEY *key, uint32_t xmss_type, + const hash256_t seed, const hash256_t secret, const hash256_t sk_prf) { size_t height; - size_t tree_nodes; // = 2^(h + 1) - 1 xmss_adrs_t adrs; - if (!key) { + if (!key || !seed || !secret || !sk_prf) { error_print(); return -1; } @@ -765,29 +687,76 @@ int xmss_key_generate(XMSS_KEY *key, uint32_t xmss_type) return -1; } memset(key, 0, sizeof(*key)); + if (!(key->tree = malloc(sizeof(hash256_t) * xmss_num_tree_nodes(height)))) { + error_print(); + return -1; + } key->public_key.xmss_type = xmss_type; - - if (rand_bytes(key->public_key.seed, 32) != 1 - || rand_bytes(key->secret, 32) != 1 - || rand_bytes(key->sk_prf, 32) != 1) { - error_print(); - return -1; - } - tree_nodes = (1 << height) * 2 - 1; - if (!(key->tree = malloc(sizeof(hash256_t) * tree_nodes))) { - error_print(); - return -1; - } + memcpy(key->public_key.seed, seed, sizeof(hash256_t)); + memcpy(key->secret, secret, sizeof(hash256_t)); + memcpy(key->sk_prf, sk_prf, sizeof(hash256_t)); adrs_set_layer_address(adrs, 0); adrs_set_tree_address(adrs, 0); xmss_build_tree(key->secret, key->public_key.seed, adrs, height, key->tree); - memcpy(key->public_key.root, key->tree[tree_root_offset(height)], sizeof(hash256_t)); + memcpy(key->public_key.root, key->tree[xmss_tree_root_offset(height)], sizeof(hash256_t)); key->index = 0; return 1; } +int xmss_key_generate(XMSS_KEY *key, uint32_t xmss_type) +{ + int ret = -1; + hash256_t seed; + hash256_t secret; + hash256_t sk_prf; + + if (!key) { + error_print(); + return -1; + } + if (rand_bytes(seed, sizeof(hash256_t)) != 1 + || rand_bytes(secret, sizeof(hash256_t)) != 1 + || rand_bytes(sk_prf, sizeof(hash256_t)) != 1) { + error_print(); + goto end; + } + if (xmss_key_generate_ex(key, xmss_type, seed, secret, sk_prf) != 1) { + error_print(); + goto end; + } + ret = 1; +end: + gmssl_secure_clear(seed, sizeof(seed)); // clear all RNG outputs + gmssl_secure_clear(secret, sizeof(secret)); + gmssl_secure_clear(sk_prf, sizeof(sk_prf)); + return ret; +} + +int xmss_key_update(XMSS_KEY *key) +{ + size_t height; + + if (!key) { + error_print(); + return -1; + } + if (xmss_type_to_height(key->public_key.xmss_type, &height) != 1) { + error_print(); + return -1; + } + if (key->index > (1 << height)) { + error_print(); + return -1; + } + if (key->index == (1 << height)) { + return 0; + } + key->index++; + return 1; +} + int xmss_key_remaining_signs(const XMSS_KEY *key, size_t *count) { size_t height; @@ -813,12 +782,13 @@ int xmss_key_remaining_signs(const XMSS_KEY *key, size_t *count) void xmss_key_cleanup(XMSS_KEY *key) { if (key) { + gmssl_secure_clear(key->public_key.seed, sizeof(hash256_t)); // clear all RNG outputs gmssl_secure_clear(key->secret, sizeof(hash256_t)); gmssl_secure_clear(key->sk_prf, sizeof(hash256_t)); if (key->tree) { free(key->tree); + key->tree = NULL; } - memset(key, 0, sizeof(*key)); } } @@ -861,21 +831,21 @@ int xmss_public_key_print(FILE *fp, int fmt, int ind, const char *label, const X format_print(fp, fmt, ind, "%s\n", label); ind += 4; format_print(fp, fmt, ind, "type: %s\n", xmss_type_name(key->public_key.xmss_type)); - format_bytes(fp, fmt, ind, "seed", key->public_key.seed, 32); - format_bytes(fp, fmt, ind, "root", key->public_key.root, 32); + format_bytes(fp, fmt, ind, "seed", key->public_key.seed, sizeof(hash256_t)); + format_bytes(fp, fmt, ind, "root", key->public_key.root, sizeof(hash256_t)); return 1; } - int xmss_private_key_to_bytes(const XMSS_KEY *key, uint8_t **out, size_t *outlen) { if (!key || !outlen) { error_print(); return -1; } - uint32_to_bytes(key->public_key.xmss_type, out, outlen); - hash256_to_bytes(key->public_key.root, out, outlen); - hash256_to_bytes(key->public_key.seed, out, outlen); + if (xmss_public_key_to_bytes(key, out, outlen) != 1) { + error_print(); + return -1; + } uint32_to_bytes(key->index, out, outlen); hash256_to_bytes(key->secret, out, outlen); hash256_to_bytes(key->sk_prf, out, outlen); @@ -896,18 +866,15 @@ int xmss_private_key_from_bytes(XMSS_KEY *key, const uint8_t **in, size_t *inlen error_print(); return -1; } - memset(key, 0, sizeof(*key)); - // xmss_type - uint32_from_bytes(&key->public_key.xmss_type, in, inlen); + if (xmss_public_key_from_bytes(key, in, inlen) != 1) { + error_print(); + return -1; + } if (xmss_type_to_height(key->public_key.xmss_type, &height) != 1) { error_print(); return -1; } - // root - hash256_from_bytes(key->public_key.root, in, inlen); - // seed - hash256_from_bytes(key->public_key.seed, in, inlen); // index, allow index == 2^h, which means out-of-keys uint32_from_bytes(&key->index, in, inlen); @@ -932,7 +899,7 @@ int xmss_private_key_from_bytes(XMSS_KEY *key, const uint8_t **in, size_t *inlen adrs_set_tree_address(adrs, 0); xmss_build_tree(key->secret, key->public_key.seed, adrs, height, key->tree); // check - if (memcmp(key->tree[tree_root_offset(height)], + if (memcmp(key->tree[xmss_tree_root_offset(height)], key->public_key.root, sizeof(hash256_t)) != 0) { xmss_key_cleanup(key); error_print(); @@ -945,12 +912,10 @@ int xmss_private_key_print(FILE *fp, int fmt, int ind, const char *label, const { format_print(fp, fmt, ind, "%s\n", label); ind += 4; - format_print(fp, fmt, ind, "type: %s\n", xmss_type_name(key->public_key.xmss_type)); - format_bytes(fp, fmt, ind, "seed", key->public_key.seed, 32); - format_bytes(fp, fmt, ind, "root", key->public_key.root, 32); - format_bytes(fp, fmt, ind, "secret", key->secret, 32); - format_bytes(fp, fmt, ind, "sk_prf", key->sk_prf, 32); - format_print(fp, fmt, ind, "index: %u\n", key->index); + xmss_public_key_print(fp, fmt, ind, "public_key", key); + format_bytes(fp, fmt, ind, "secret", key->secret, sizeof(hash256_t)); + format_bytes(fp, fmt, ind, "sk_prf", key->sk_prf, sizeof(hash256_t)); + format_print(fp, fmt, ind, "index: %"PRIu32"\n", key->index); return 1; } @@ -974,7 +939,7 @@ int xmss_signature_size(uint32_t xmss_type, size_t *siglen) return 1; } -int xmss_signature_from_bytes(uint32_t xmss_type, XMSS_SIGNATURE *sig, const uint8_t **in, size_t *inlen) +int xmss_signature_from_bytes(XMSS_SIGNATURE *sig, uint32_t xmss_type, const uint8_t **in, size_t *inlen) { size_t height; size_t siglen; @@ -995,15 +960,16 @@ int xmss_signature_from_bytes(uint32_t xmss_type, XMSS_SIGNATURE *sig, const uin uint32_from_bytes(&sig->index, in, inlen); hash256_from_bytes(sig->random, in, inlen); - for (i = 0; i < 67; i++) + for (i = 0; i < WOTS_NUM_CHAINS; i++) { hash256_from_bytes(sig->wots_sig[i], in, inlen); - for (i = 0; i < height; i++) + } + for (i = 0; i < height; i++) { hash256_from_bytes(sig->auth_path[i], in, inlen); - + } return 1; } -int xmss_signature_to_bytes(uint32_t xmss_type, const XMSS_SIGNATURE *sig, uint8_t **out, size_t *outlen) +int xmss_signature_to_bytes(const XMSS_SIGNATURE *sig, uint32_t xmss_type, uint8_t **out, size_t *outlen) { size_t height; size_t i; @@ -1018,7 +984,7 @@ int xmss_signature_to_bytes(uint32_t xmss_type, const XMSS_SIGNATURE *sig, uint8 } uint32_to_bytes(sig->index, out, outlen); hash256_to_bytes(sig->random, out, outlen); - for (i = 0; i < 67; i++) { + for (i = 0; i < WOTS_NUM_CHAINS; i++) { hash256_to_bytes(sig->wots_sig[i], out, outlen); } for (i = 0; i < height; i++) { @@ -1045,12 +1011,12 @@ int xmss_signature_print_ex(FILE *fp, int fmt, int ind, const char *label, const format_bytes(fp, fmt, ind, "random", sig->random, 32); format_print(fp, fmt, ind, "wots_sig\n"); for (i = 0; i < 67; i++) { - format_print(fp, fmt, ind+4, "%d ", i); + format_print(fp, fmt, ind+4, "%d", i); format_bytes(fp, fmt, 0, "", sig->wots_sig[i], 32); } format_print(fp, fmt, ind, "auth_path\n"); for (i = 0; i < height; i++) { - format_print(fp, fmt, ind+4, "%d ", i); + format_print(fp, fmt, ind+4, "%d", i); format_bytes(fp, fmt, 0, "", sig->auth_path[i], 32); } return 1; @@ -1088,7 +1054,7 @@ int xmss_signature_print(FILE *fp, int fmt, int ind, const char *label, const ui error_print(); return -1; } - format_print(fp, fmt, ind+4, "%d ", i); + format_print(fp, fmt, ind+4, "%d", i); format_bytes(fp, fmt, 0, "", sig, 32); sig += 32; siglen -= 32; @@ -1096,21 +1062,29 @@ int xmss_signature_print(FILE *fp, int fmt, int ind, const char *label, const ui format_print(fp, fmt, ind, "auth_path\n"); for (i = 0; i < XMSS_MAX_HEIGHT && siglen >= 32; i++) { - format_print(fp, fmt, ind+4, "%d ", i); + format_print(fp, fmt, ind+4, "%d", i); format_bytes(fp, fmt, 0, "", sig, 32); sig += 32; siglen -= 32; } - format_print(fp, fmt, ind, "[left %zu bytes]\n", siglen); + format_print(fp, fmt, ind, "[%zu bytes left]\n", siglen); return 1; } +void xmss_sign_ctx_cleanup(XMSS_SIGN_CTX *ctx) +{ + if (ctx) { + gmssl_secure_clear(ctx->xmss_sig.random, sizeof(hash256_t)); + gmssl_secure_clear(ctx->xmss_sig.wots_sig, sizeof(wots_sig_t)); // might cache wots_sk + } +} + int xmss_sign_init(XMSS_SIGN_CTX *ctx, XMSS_KEY *key) { - uint8_t index_buf[32] = {0}; - uint8_t adrs[32]; + hash256_t hash256_index = {0}; + xmss_adrs_t adrs; size_t height; if (!ctx || !key) { @@ -1121,6 +1095,11 @@ int xmss_sign_init(XMSS_SIGN_CTX *ctx, XMSS_KEY *key) error_print(); return -1; } + // check if out of keys + if (key->index >= (1 << height)) { + error_print(); + return -1; + } memset(ctx, 0, sizeof(*ctx)); // cache public key @@ -1130,12 +1109,12 @@ int xmss_sign_init(XMSS_SIGN_CTX *ctx, XMSS_KEY *key) ctx->xmss_sig.index = key->index; // derive ctx->xmss_sig.random - PUTU32(index_buf + 28, key->index); + PUTU32(hash256_index + 28, key->index); // r = PRF(SK_PRF, toByte(idx_sig, 32)); hash256_init(&ctx->hash256_ctx); - hash256_update(&ctx->hash256_ctx, hash256_three, 32); - hash256_update(&ctx->hash256_ctx, key->sk_prf, 32); - hash256_update(&ctx->hash256_ctx, index_buf, 32); + hash256_update(&ctx->hash256_ctx, hash256_three, sizeof(hash256_t)); + hash256_update(&ctx->hash256_ctx, key->sk_prf, sizeof(hash256_t)); + hash256_update(&ctx->hash256_ctx, hash256_index, sizeof(hash256_t)); hash256_finish(&ctx->hash256_ctx, ctx->xmss_sig.random); // wots_sk => ctx->xmss_sig.wots_sig @@ -1153,26 +1132,36 @@ int xmss_sign_init(XMSS_SIGN_CTX *ctx, XMSS_KEY *key) // H_msg(M) := HASH256(toByte(2, 32) || r || XMSS_ROOT || toByte(idx_sig, 32) || M) hash256_init(&ctx->hash256_ctx); - hash256_update(&ctx->hash256_ctx, hash256_two, 32); - hash256_update(&ctx->hash256_ctx, ctx->xmss_sig.random, 32); - hash256_update(&ctx->hash256_ctx, key->public_key.root, 32); - hash256_update(&ctx->hash256_ctx, index_buf, 32); + hash256_update(&ctx->hash256_ctx, hash256_two, sizeof(hash256_t)); + hash256_update(&ctx->hash256_ctx, ctx->xmss_sig.random, sizeof(hash256_t)); + hash256_update(&ctx->hash256_ctx, key->public_key.root, sizeof(hash256_t)); + hash256_update(&ctx->hash256_ctx, hash256_index, sizeof(hash256_t)); return 1; } int xmss_sign_update(XMSS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen) { + if (!ctx) { + error_print(); + return -1; + } if (data && datalen) { hash256_update(&ctx->hash256_ctx, data, datalen); } return 1; } -int xmss_sign_finish(XMSS_SIGN_CTX *ctx, uint8_t *sigbuf, size_t *siglen) +// TODO: support output *siglen only +int xmss_sign_finish(XMSS_SIGN_CTX *ctx, uint8_t *sig, size_t *siglen) { xmss_adrs_t adrs; - uint8_t dgst[32]; + hash256_t dgst; + + if (!ctx || !sig || !siglen) { + error_print(); + return -1; + } hash256_finish(&ctx->hash256_ctx, dgst); @@ -1185,47 +1174,70 @@ int xmss_sign_finish(XMSS_SIGN_CTX *ctx, uint8_t *sigbuf, size_t *siglen) ctx->xmss_sig.wots_sig); *siglen = 0; - if (xmss_signature_to_bytes(ctx->xmss_public_key.xmss_type, &ctx->xmss_sig, &sigbuf, siglen) != 1) { + if (xmss_signature_to_bytes(&ctx->xmss_sig, ctx->xmss_public_key.xmss_type, &sig, siglen) != 1) { error_print(); return -1; } + return 1; +} +int xmss_verify_init_ex(XMSS_SIGN_CTX *ctx, const XMSS_KEY *key, const XMSS_SIGNATURE *sig) +{ + hash256_t hash256_index = {0}; + if (!ctx || !key || !sig) { + error_print(); + return -1; + } + // cache xmss_public_key + ctx->xmss_public_key = key->public_key; + + // cache xmss_sig + ctx->xmss_sig = *sig; + + // hash256_init + PUTU32(hash256_index + 28, ctx->xmss_sig.index); + hash256_init(&ctx->hash256_ctx); + hash256_update(&ctx->hash256_ctx, hash256_two, sizeof(hash256_t)); + hash256_update(&ctx->hash256_ctx, ctx->xmss_sig.random, sizeof(hash256_t)); + hash256_update(&ctx->hash256_ctx, key->public_key.root, sizeof(hash256_t)); + hash256_update(&ctx->hash256_ctx, hash256_index, sizeof(hash256_t)); return 1; } int xmss_verify_init(XMSS_SIGN_CTX *ctx, const XMSS_KEY *key, const uint8_t *sig, size_t siglen) { - uint8_t sig_index[32]; + hash256_t hash256_index = {0}; if (!ctx || !key || !sig || !siglen) { error_print(); return -1; } - // cache xmss_public_key ctx->xmss_public_key = key->public_key; // parse signature - if (xmss_signature_from_bytes(key->public_key.xmss_type, &ctx->xmss_sig, &sig, &siglen) != 1) { + if (xmss_signature_from_bytes(&ctx->xmss_sig, key->public_key.xmss_type, &sig, &siglen) != 1) { error_print(); return -1; } - memset(sig_index, 0, 28); - PUTU32(sig_index + 28, ctx->xmss_sig.index); - + // hash256_init + PUTU32(hash256_index + 28, ctx->xmss_sig.index); hash256_init(&ctx->hash256_ctx); - hash256_update(&ctx->hash256_ctx, hash256_two, 32); - hash256_update(&ctx->hash256_ctx, ctx->xmss_sig.random, 32); - hash256_update(&ctx->hash256_ctx, key->public_key.root, 32); - hash256_update(&ctx->hash256_ctx, sig_index, 32); - + hash256_update(&ctx->hash256_ctx, hash256_two, sizeof(hash256_t)); + hash256_update(&ctx->hash256_ctx, ctx->xmss_sig.random, sizeof(hash256_t)); + hash256_update(&ctx->hash256_ctx, key->public_key.root, sizeof(hash256_t)); + hash256_update(&ctx->hash256_ctx, hash256_index, sizeof(hash256_t)); return 1; } int xmss_verify_update(XMSS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen) { + if (!ctx) { + error_print(); + return -1; + } if (data && datalen) { hash256_update(&ctx->hash256_ctx, data, datalen); } @@ -1239,7 +1251,6 @@ int xmss_verify_finish(XMSS_SIGN_CTX *ctx) hash256_t dgst; xmss_adrs_t adrs; hash256_t root; - int right; if (!ctx) { error_print(); @@ -1273,13 +1284,12 @@ int xmss_verify_finish(XMSS_SIGN_CTX *ctx) // wots_root (index), auth_path => xmss_root adrs_set_type(adrs, XMSS_ADRS_TYPE_HASHTREE); adrs_set_padding(adrs, 0); - adrs_set_key_and_mask(adrs, 0); for (h = 0; h < height; h++) { - right = index & 1; + int right_child = index & 1; index >>= 1; adrs_set_tree_height(adrs, h + 1); adrs_set_tree_index(adrs, index); - if (right) + if (right_child) randomized_tree_hash(ctx->xmss_sig.auth_path[h], root, ctx->xmss_public_key.seed, adrs, root); else randomized_tree_hash(root, ctx->xmss_sig.auth_path[h], ctx->xmss_public_key.seed, adrs, root); } @@ -1350,9 +1360,12 @@ int xmssmt_type_to_height_and_layers(uint32_t xmssmt_type, size_t *height, size_ return 1; } -size_t xmss_num_tree_nodes(size_t height) -{ - return (1 << (height + 1)) - 1; +static uint64_t xmssmt_tree_address(uint64_t index, size_t height, size_t layers, size_t layer) { + return (index >> (height/layers) * (layer + 1)); +} + +static uint64_t xmssmt_tree_index(uint64_t index, size_t height, size_t layers, size_t layer) { + return (index >> (height/layers) * layer) % (1 << (height/layers)); } size_t xmssmt_num_trees_nodes(size_t height, size_t layers) @@ -1404,12 +1417,11 @@ int xmssmt_private_key_size(uint32_t xmssmt_type, size_t *len) error_print(); return -1; } - *len = XMSSMT_PUBLIC_KEY_SIZE; *len += sizeof(hash256_t); *len += sizeof(hash256_t); xmssmt_index_to_bytes(index, xmssmt_type, NULL, len); - *len += sizeof(hash256_t) * ((1 << (height/layers + 1)) - 1) * layers; + *len += sizeof(hash256_t) * xmssmt_num_trees_nodes(height, layers); *len += sizeof(wots_sig_t) * (layers - 1); return 1; } @@ -1501,36 +1513,33 @@ int xmssmt_private_key_from_bytes(XMSSMT_KEY *key, const uint8_t **in, size_t *i return 1; } - - int xmssmt_key_update(XMSSMT_KEY *key) { size_t height; size_t layers; size_t layer; hash256_t *tree; - uint64_t next_index; - xmss_adrs_t adrs; + uint8_t *xmss_root; // FIXME: use hash256_t* - uint8_t *xmss_root; - + if (!key) { + error_print(); + return -1; + } if (xmssmt_type_to_height_and_layers(key->public_key.xmssmt_type, &height, &layers) != 1) { error_print(); return -1; } - - if (key->index > (1 << height)) { + if (key->index >= (1 << height)) { + if (key->index == (1 << height)) { + return 0; + } error_print(); return -1; } - if (key->index == (1 << height)) { - return 0; - } next_index = key->index + 1; - tree = key->trees; for (layer = 0; layer < layers - 1; layer++) { @@ -1550,7 +1559,7 @@ int xmssmt_key_update(XMSSMT_KEY *key) adrs_set_type(adrs, XMSS_ADRS_TYPE_OTS); adrs_set_ots_address(adrs, xmssmt_tree_index(next_index, height, layers, layer + 1)); wots_derive_sk(key->secret, key->public_key.seed, adrs, key->wots_sigs[layer]); - xmss_root = tree[tree_root_offset(height/layers)]; + xmss_root = tree[xmss_tree_root_offset(height/layers)]; wots_sign(key->wots_sigs[layer], key->public_key.seed, adrs, xmss_root, key->wots_sigs[layer]); tree += xmss_num_tree_nodes(height/layers); } @@ -1563,6 +1572,7 @@ int xmssmt_key_update(XMSSMT_KEY *key) void xmssmt_key_cleanup(XMSSMT_KEY *key) { if (key) { + gmssl_secure_clear(key->public_key.seed, sizeof(hash256_t)); // clear all RNG outputs gmssl_secure_clear(key->secret, sizeof(hash256_t)); gmssl_secure_clear(key->sk_prf, sizeof(hash256_t)); if (key->trees) { @@ -1629,7 +1639,7 @@ int xmssmt_key_generate_ex(XMSSMT_KEY *key, uint32_t xmssmt_type, xmss_build_tree(key->secret, key->public_key.seed, adrs, height/layers, tree); - xmss_root = tree[tree_root_offset(height/layers)]; + xmss_root = tree[xmss_tree_root_offset(height/layers)]; tree += xmss_num_tree_nodes(height/layers); // sign xmss_root with higher layer @@ -1643,7 +1653,7 @@ int xmssmt_key_generate_ex(XMSSMT_KEY *key, uint32_t xmssmt_type, /* hash256_t *tree2 = key->trees + xmss_num_tree_nodes(height/layers) * layer; - hash256_t xmss_root2 = tree2[tree_root_offset(height/layers)]; + hash256_t xmss_root2 = tree2[xmss_tree_root_offset(height/layers)]; fprintf(stderr, "%p %p\n", tree, tree2); @@ -1668,7 +1678,7 @@ int xmssmt_key_generate_ex(XMSSMT_KEY *key, uint32_t xmssmt_type, // extra check for (layer = 0; layer < layers - 1; layer++) { - uint8_t *dgst = tree[tree_root_offset(height/layers)]; + uint8_t *dgst = tree[xmss_tree_root_offset(height/layers)]; tree_address = xmssmt_tree_address(index, height, layers, layer + 1); tree_index = xmssmt_tree_index(index, height, layers, layer + 1); @@ -1789,15 +1799,13 @@ int xmssmt_private_key_print(FILE *fp, int fmt, int ind, const char *label, cons for (i = 0; i < layers; i++) { char label[64]; snprintf(label, sizeof(label), "xmss_root[%zu]", i); - format_bytes(fp, fmt, ind, label, tree[tree_root_offset(height/layers)], 32); + format_bytes(fp, fmt, ind, label, tree[xmss_tree_root_offset(height/layers)], 32); tree += xmss_num_tree_nodes(height/layers); } return 1; } - - int xmssmt_index_to_bytes(uint64_t index, uint32_t xmssmt_type, uint8_t **out, size_t *outlen) { size_t height; @@ -1989,8 +1997,7 @@ int xmssmt_signature_print_ex(FILE *fp, int fmt, int ind, const char *label, con format_print(fp, fmt, ind, "%s\n", label); ind += 4; - //format_print(fp, fmt, ind, "index: %"PRIu64"----\n", sig->index); - format_print(fp, fmt, ind, "index: %llu\n", (unsigned long long)sig->index); + format_print(fp, fmt, ind, "index: %"PRIu64"\n", sig->index); format_bytes(fp, fmt, ind, "random", sig->random, 32); for (layer = 0; layer < layers; layer++) { @@ -2073,6 +2080,14 @@ int xmssmt_signature_print(FILE *fp, int fmt, int ind, const char *label, const return 1; } +void xmssmt_sign_ctx_cleanup(XMSSMT_SIGN_CTX *ctx) +{ + if (ctx) { + gmssl_secure_clear(ctx->xmssmt_sig.random, sizeof(hash256_t)); + gmssl_secure_clear(ctx->xmssmt_sig.wots_sigs[0], sizeof(wots_sig_t)); + } +} + int xmssmt_sign_init(XMSSMT_SIGN_CTX *ctx, XMSSMT_KEY *key) { size_t height; @@ -2256,7 +2271,6 @@ int xmssmt_verify_init_ex(XMSSMT_SIGN_CTX *ctx, const XMSSMT_KEY *key, const XMS return 1; } - // check compatible publickey and sig int xmssmt_verify_init(XMSSMT_SIGN_CTX *ctx, const XMSSMT_KEY *key, const uint8_t *sig, size_t siglen) {