From 38451da6a842dbc9aec566a6fd9753b9a7c8e1f5 Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Mon, 5 Jan 2026 21:19:23 +0800 Subject: [PATCH] Update XMSS --- include/gmssl/xmss.h | 23 +++++------- src/xmss.c | 87 ++++++++++++++++++++++++++++++++++---------- tests/xmsstest.c | 63 ++++++++++++++++++-------------- tools/xmsskeygen.c | 33 ++++++++++++++--- tools/xmssmtsign.c | 14 ++++++- tools/xmsssign.c | 60 +++++++++++++++++++++++++++--- 6 files changed, 206 insertions(+), 74 deletions(-) diff --git a/include/gmssl/xmss.h b/include/gmssl/xmss.h index c41a2caf..13c80d70 100644 --- a/include/gmssl/xmss.h +++ b/include/gmssl/xmss.h @@ -149,19 +149,9 @@ int wots_verify(const hash256_t wots_root, // from RFC 8391 table 7 enum { - XMSS_RESERVED = 0x00000000, XMSS_SHA2_10_256 = 0x00000001, XMSS_SHA2_16_256 = 0x00000002, XMSS_SHA2_20_256 = 0x00000003, - XMSS_SHA2_10_512 = 0x00000004, - XMSS_SHA2_16_512 = 0x00000005, - XMSS_SHA2_20_512 = 0x00000006, - XMSS_SHAKE_10_256 = 0x00000007, - XMSS_SHAKE_16_256 = 0x00000008, - XMSS_SHAKE_20_256 = 0x00000009, - XMSS_SHAKE_10_512 = 0x0000000A, - XMSS_SHAKE_16_512 = 0x0000000B, - XMSS_SHAKE_20_512 = 0x0000000C, }; enum { @@ -216,23 +206,28 @@ typedef struct { typedef struct { XMSS_PUBLIC_KEY public_key; + uint32_t index; hash256_t secret; hash256_t sk_prf; - uint32_t index; hash256_t *tree; // hash256_t[2^(h + 1) - 1] } XMSS_KEY; -#define XMSS_PRIVATE_KEY_SIZE (XMSS_PUBLIC_KEY_SIZE + 32 + 32 + 4) // = 136 +// XMSS_SHA2_10_256: 65,640 +// XMSS_SHA2_16_256: 4,194,408 +// XMSS_SHA2_20_256: 67,108,968 +int xmss_private_key_size(uint32_t xmss_type, size_t *keysize); + +//#define XMSS_PRIVATE_KEY_SIZE (XMSS_PUBLIC_KEY_SIZE + 32 + 32 + 4) // = 136 int xmss_key_generate(XMSS_KEY *key, uint32_t xmss_type); int xmss_key_remaining_signs(const XMSS_KEY *key, size_t *count); -void xmss_key_cleanup(XMSS_KEY *key); int xmss_public_key_to_bytes(const XMSS_KEY *key, uint8_t **out, size_t *outlen); int xmss_public_key_from_bytes(XMSS_KEY *key, const uint8_t **in, size_t *inlen); int xmss_public_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSS_KEY *key); int xmss_private_key_to_bytes(const XMSS_KEY *key, uint8_t **out, size_t *outlen); int xmss_private_key_from_bytes(XMSS_KEY *key, const uint8_t **in, size_t *inlen); int xmss_private_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSS_KEY *key); +void xmss_key_cleanup(XMSS_KEY *key); typedef struct { @@ -348,9 +343,9 @@ typedef struct { typedef struct { XMSSMT_PUBLIC_KEY public_key; + uint64_t index; // in [0, 2^60 - 1] hash256_t secret; hash256_t sk_prf; - uint64_t index; // in [0, 2^60 - 1] hash256_t *trees; wots_sig_t wots_sigs[XMSSMT_MAX_LAYERS - 1]; } XMSSMT_KEY; diff --git a/src/xmss.c b/src/xmss.c index 66107d6f..420dbdf9 100644 --- a/src/xmss.c +++ b/src/xmss.c @@ -672,6 +672,26 @@ uint32_t xmss_type_from_name(const char *name) return 0; } +int xmss_private_key_size(uint32_t xmss_type, size_t *keysize) +{ + size_t height; + + if (!keysize) { + error_print(); + return -1; + } + if (xmss_type_to_height(xmss_type, &height) != 1) { + error_print(); + return -1; + } + *keysize = XMSS_PUBLIC_KEY_SIZE + + sizeof(hash256_t) + + sizeof(hash256_t) + + sizeof(uint32_t) + + sizeof(hash256_t) * xmss_num_tree_nodes(height); + return 1; +} + int xmss_key_generate_ex(XMSS_KEY *key, uint32_t xmss_type, const hash256_t seed, const hash256_t secret, const hash256_t sk_prf) { @@ -838,6 +858,9 @@ int xmss_public_key_print(FILE *fp, int fmt, int ind, const char *label, const X int xmss_private_key_to_bytes(const XMSS_KEY *key, uint8_t **out, size_t *outlen) { + size_t height; + size_t tree_size; + if (!key || !outlen) { error_print(); return -1; @@ -849,32 +872,55 @@ int xmss_private_key_to_bytes(const XMSS_KEY *key, uint8_t **out, size_t *outlen uint32_to_bytes(key->index, out, outlen); hash256_to_bytes(key->secret, out, outlen); hash256_to_bytes(key->sk_prf, out, outlen); + + if (key->tree == NULL) { + error_print(); + return -1; + } + if (xmss_type_to_height(key->public_key.xmss_type, &height) != 1) { + error_print(); + return -1; + } + tree_size = sizeof(hash256_t) * xmss_num_tree_nodes(height); + if (out && *out) { + memcpy(*out, key->tree, tree_size); + *out += tree_size; + } + *outlen += tree_size; return 1; } int xmss_private_key_from_bytes(XMSS_KEY *key, const uint8_t **in, size_t *inlen) { size_t height; - size_t tree_nodes; + size_t tree_size; xmss_adrs_t adrs; if (!key || !in || !(*in) || !inlen) { error_print(); return -1; } - if (*inlen < XMSS_PRIVATE_KEY_SIZE) { - error_print(); - return -1; - } - if (xmss_public_key_from_bytes(key, in, inlen) != 1) { error_print(); return -1; } + // check inlen without tree + if (*inlen < sizeof(uint32_t) + sizeof(hash256_t)*2) { + error_print(); + return -1; + } + if (xmss_type_to_height(key->public_key.xmss_type, &height) != 1) { error_print(); return -1; } + tree_size = sizeof(hash256_t) * xmss_num_tree_nodes(height); + + // prepare buffer (might failure ops) before load secrets + if (!(key->tree = malloc(tree_size))) { + error_print(); + return -1; + } // index, allow index == 2^h, which means out-of-keys uint32_from_bytes(&key->index, in, inlen); @@ -882,22 +928,25 @@ int xmss_private_key_from_bytes(XMSS_KEY *key, const uint8_t **in, size_t *inlen error_print(); return -1; } - // prepare buffer (might failure ops) before load secrets - tree_nodes = (1 << (height + 1)) - 1; - if (!(key->tree = malloc(sizeof(hash256_t) * tree_nodes))) { - error_print(); - return -1; - } - - // secret hash256_from_bytes(key->secret, in, inlen); - // sk_prf hash256_from_bytes(key->sk_prf, in, inlen); - // build_tree - adrs_set_layer_address(adrs, 0); - adrs_set_tree_address(adrs, 0); - xmss_build_tree(key->secret, key->public_key.seed, adrs, height, key->tree); + if (*inlen) { + // load tree + if (*inlen < tree_size) { + error_print(); + return -1; + } + memcpy(key->tree, *in, tree_size); + *in += tree_size; + *inlen -= tree_size; + } else { + // build_tree + adrs_set_layer_address(adrs, 0); + adrs_set_tree_address(adrs, 0); + xmss_build_tree(key->secret, key->public_key.seed, adrs, height, key->tree); + } + // check if (memcmp(key->tree[xmss_tree_root_offset(height)], key->public_key.root, sizeof(hash256_t)) != 0) { diff --git a/tests/xmsstest.c b/tests/xmsstest.c index 2d2a49be..38c15d8f 100644 --- a/tests/xmsstest.c +++ b/tests/xmsstest.c @@ -306,6 +306,36 @@ static int test_xmss_build_root(void) return 1; } +static int test_xmss_private_key_size(void) +{ + struct { + uint32_t xmss_type; + size_t keylen; + } tests[] = { + { XMSS_HASH256_10_256, 65640 }, + { XMSS_HASH256_16_256, 4194408 }, + { XMSS_HASH256_20_256, 67108968 }, + }; + size_t keylen; + size_t i; + + format_print(stderr, 0, 4, "xmss_private_key_size\n"); + for (i = 0; i < sizeof(tests)/sizeof(tests[0]); i++) { + if (xmss_private_key_size(tests[i].xmss_type, &keylen) != 1) { + error_print(); + return -1; + } + if (keylen != tests[i].keylen) { + error_print(); + return -1; + } + format_print(stderr, 0, 8, "%s: %zu\n", xmss_type_name(tests[i].xmss_type), keylen); + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + static int test_xmss_key_generate(void) { uint32_t xmss_type = XMSS_HASH256_10_256; @@ -342,12 +372,12 @@ static int test_xmss_key_generate(void) return 1; } -static int test_xmss_key_to_bytes(void) +static int test_xmss_public_key_to_bytes(void) { uint32_t xmss_type = XMSS_HASH256_10_256; XMSS_KEY key; XMSS_KEY pub; - uint8_t buf[XMSS_PUBLIC_KEY_SIZE + XMSS_PRIVATE_KEY_SIZE]; + uint8_t buf[XMSS_PUBLIC_KEY_SIZE]; uint8_t *p = buf; const uint8_t *cp = buf; size_t len = 0; @@ -365,16 +395,6 @@ static int test_xmss_key_to_bytes(void) error_print(); return -1; } - - if (xmss_private_key_to_bytes(&key, &p, &len) != 1) { - error_print(); - return -1; - } - if (len != XMSS_PUBLIC_KEY_SIZE + XMSS_PRIVATE_KEY_SIZE) { - error_print(); - return -1; - } - if (xmss_public_key_from_bytes(&pub, &cp, &len) != 1) { error_print(); return -1; @@ -383,18 +403,6 @@ static int test_xmss_key_to_bytes(void) error_print(); return -1; } - if (xmss_private_key_from_bytes(&pub, &cp, &len) != 1) { - error_print(); - return -1; - } - - // FIXME: compare trees - /* - if (memcmp(&key, &pub, sizeof(XMSS_KEY)) != 0) { - error_print(); - return -1; - } - */ if (len) { error_print(); return -1; @@ -1100,13 +1108,13 @@ int main(void) if (test_xmss_adrs() != 1) goto err; if (test_xmss_build_tree() != 1) goto err; if (test_xmss_build_root() != 1) goto err; - if (test_xmss_key_generate() != 1) goto err; - if (test_xmss_key_to_bytes() != 1) goto err; + if (test_xmss_public_key_to_bytes() != 1) goto err; + if (test_xmss_private_key_size() != 1) goto err; + //if (test_xmss_private_key_to_bytes() != 1) goto err; if (test_xmss_signature_size() != 1) goto err; if (test_xmss_sign() != 1) goto err; if (test_xmss_sign_init() != 1) goto err; - if (test_xmssmt_key_generate() != 1) goto err; if (test_xmssmt_index_to_bytes() != 1) goto err; if (test_xmssmt_signature_to_bytes() != 1) goto err; @@ -1117,7 +1125,6 @@ int main(void) if (test_xmssmt_private_key_to_bytes() != 1) goto err; if (test_xmssmt_sign() != 1) goto err; if (test_xmssmt_sign_update() != 1) goto err; - printf("%s all tests passed\n", __FILE__); return 0; err: diff --git a/tools/xmsskeygen.c b/tools/xmsskeygen.c index 034ba1b2..5dff4eac 100644 --- a/tools/xmsskeygen.c +++ b/tools/xmsskeygen.c @@ -1,5 +1,5 @@ /* - * Copyright 2014-2025 The GmSSL Project. All Rights Reserved. + * Copyright 2014-2026 The GmSSL Project. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the License); you may * not use this file except in compliance with the License. @@ -42,11 +42,12 @@ int xmsskeygen_main(int argc, char **argv) FILE *outfp = NULL; FILE *puboutfp = stdout; XMSS_KEY key; - uint8_t out[XMSS_PRIVATE_KEY_SIZE]; + uint8_t *out = NULL; + uint8_t *pout; + size_t outlen; uint8_t pubout[XMSS_PUBLIC_KEY_SIZE]; - uint8_t *pout = out; - uint8_t *ppubout = pubout; - size_t outlen = 0, puboutlen = 0; + uint8_t *ppubout; + size_t puboutlen ; memset(&key, 0, sizeof(key)); @@ -116,6 +117,17 @@ bad: xmss_public_key_print(stderr, 0, 0, "xmss_public_key", &key); } + outlen = 0; + if (xmss_private_key_to_bytes(&key, NULL, &outlen) != 1) { + error_print(); + goto end; + } + if (!(out = malloc(outlen))) { + error_print(); + goto end; + } + pout = out; + outlen = 0; if (xmss_private_key_to_bytes(&key, &pout, &outlen) != 1) { error_print(); goto end; @@ -125,10 +137,16 @@ bad: goto end; } + ppubout = pubout; + puboutlen = 0; if (xmss_public_key_to_bytes(&key, &ppubout, &puboutlen) != 1) { error_print(); goto end; } + if (puboutlen != sizeof(pubout)) { + error_print(); + goto end; + } if (fwrite(pubout, 1, puboutlen, puboutfp) != puboutlen) { error_print(); goto end; @@ -137,7 +155,10 @@ bad: ret = 0; end: xmss_key_cleanup(&key); - gmssl_secure_clear(out, outlen); + if (out) { + gmssl_secure_clear(out, outlen); + free(out); + } if (outfile && outfp) fclose(outfp); if (puboutfile && puboutfp) fclose(puboutfp); return ret; diff --git a/tools/xmssmtsign.c b/tools/xmssmtsign.c index 17a0e605..95ef9231 100644 --- a/tools/xmssmtsign.c +++ b/tools/xmssmtsign.c @@ -149,6 +149,7 @@ bad: goto end; } +#if 0 // write updated key back to file // TODO: write back `q` only p = keybuf; @@ -162,6 +163,17 @@ bad: error_print(); return -1; } +#else + if (fseek(keyfp, XMSSMT_PUBLIC_KEY_SIZE, SEEK_SET) != 0) { + error_print(); + goto end; + } + uint8_t index_buf[8]; + uint8_t *pindex = index_buf; + size_t index_len = 0; + xmssmt_index_to_bytes(key.index, key.public_key.xmssmt_type, &pindex, &index_len); + fwrite(index_buf, 1, index_len, keyfp); +#endif while (1) { uint8_t buf[1024]; @@ -189,7 +201,7 @@ bad: ret = 0; end: - //xmss_key_cleanup(&key); + xmssmt_key_cleanup(&key); gmssl_secure_clear(keybuf, keylen); gmssl_secure_clear(&ctx, sizeof(ctx)); if (keyfp) fclose(keyfp); diff --git a/tools/xmsssign.c b/tools/xmsssign.c index 08366cad..c8775864 100644 --- a/tools/xmsssign.c +++ b/tools/xmsssign.c @@ -14,6 +14,7 @@ #include #include #include +#include #include static const char *usage = "-key file [-in file] [-out file] [-verbose]\n"; @@ -37,10 +38,12 @@ int xmsssign_main(int argc, char **argv) FILE *keyfp = NULL; FILE *infp = stdin; FILE *outfp = stdout; - uint8_t keybuf[XMSS_PRIVATE_KEY_SIZE]; - size_t keylen = XMSS_PRIVATE_KEY_SIZE; - const uint8_t *cp = keybuf; - uint8_t *p = keybuf; + uint8_t pubkeybuf[XMSS_PUBLIC_KEY_SIZE]; + uint8_t *keybuf = NULL; + size_t keylen; + const uint8_t *cp; + uint8_t *p; + size_t len; XMSS_KEY key; XMSS_SIGN_CTX ctx; uint8_t sig[XMSS_SIGNATURE_MAX_SIZE]; @@ -102,10 +105,40 @@ bad: goto end; } - if (fread(keybuf, 1, keylen, keyfp) != keylen) { + // load xmss_public_key + if (fread(pubkeybuf, 1, sizeof(pubkeybuf), keyfp) != sizeof(pubkeybuf)) { + error_print(); + goto end; + } + cp = pubkeybuf; + len = sizeof(pubkeybuf); + if (xmss_public_key_from_bytes(&key, &cp, &len) != 1) { + error_print(); + goto end; + } + if (len) { + error_print(); + goto end; + } + + // xmss_private_key_size + if (xmss_private_key_size(key.public_key.xmss_type, &keylen) != 1) { + error_print(); + goto end; + } + if (!(keybuf = malloc(keylen))) { + error_print(); + goto end; + } + memcpy(keybuf, pubkeybuf, sizeof(pubkeybuf)); + + len = keylen - sizeof(pubkeybuf); + if (fread(keybuf + sizeof(pubkeybuf), 1, len, keyfp) != len) { fprintf(stderr, "%s: read private key failure\n", prog); goto end; } + + cp = keybuf; if (xmss_private_key_from_bytes(&key, &cp, &keylen) != 1) { error_print(); goto end; @@ -124,8 +157,10 @@ bad: goto end; } +#if 0 // write updated key back to file // TODO: write back `q` only + p = keybuf; if (xmss_private_key_to_bytes(&key, &p, &keylen) != 1) { error_print(); return -1; @@ -135,6 +170,16 @@ bad: error_print(); return -1; } +#else + // write index only + if (fseek(keyfp, XMSS_PUBLIC_KEY_SIZE, SEEK_SET) != 0) { + error_print(); + goto end; + } + uint8_t index_buf[4]; + PUTU32(index_buf, key.index); + fwrite(index_buf, 1, 4, keyfp); +#endif while (1) { uint8_t buf[1024]; @@ -163,8 +208,11 @@ bad: end: xmss_key_cleanup(&key); - gmssl_secure_clear(keybuf, sizeof(keybuf)); gmssl_secure_clear(&ctx, sizeof(ctx)); + if (keybuf) { + gmssl_secure_clear(keybuf, keylen); + free(keybuf); + } if (keyfp) fclose(keyfp); if (infp && infp != stdin) fclose(infp); if (outfp && outfp != stdout) fclose(outfp);