diff --git a/include/gmssl/xmss.h b/include/gmssl/xmss.h index 216f5596..5688f673 100644 --- a/include/gmssl/xmss.h +++ b/include/gmssl/xmss.h @@ -207,12 +207,18 @@ typedef struct { #define XMSS_PUBLIC_KEY_SIZE (4 + 32 + 32) // = 68 -typedef struct { +typedef struct XMSS_KEY_st XMSS_KEY; + +typedef int (*xmss_key_update_callback)(XMSS_KEY *key); + +typedef struct XMSS_KEY_st { XMSS_PUBLIC_KEY public_key; uint32_t index; xmss_hash256_t secret; xmss_hash256_t sk_prf; xmss_hash256_t *tree; // xmss_hash256_t[2^(h + 1) - 1] + xmss_key_update_callback update_callback; + void *update_param; } XMSS_KEY; // XMSS_SHA2_10_256: 65,640 @@ -224,13 +230,17 @@ int xmss_private_key_size(uint32_t xmss_type, size_t *keysize); int xmss_key_generate(XMSS_KEY *key, uint32_t xmss_type); int xmss_key_remaining_signs(const XMSS_KEY *key, size_t *count); +int xmss_key_set_update_callback(XMSS_KEY *key, xmss_key_update_callback update_cb, void *param); +int xmss_key_update(XMSS_KEY *key); +void xmss_key_cleanup(XMSS_KEY *key); + int xmss_public_key_to_bytes(const XMSS_KEY *key, uint8_t **out, size_t *outlen); int xmss_public_key_from_bytes(XMSS_KEY *key, const uint8_t **in, size_t *inlen); int xmss_public_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSS_KEY *key); int xmss_private_key_to_bytes(const XMSS_KEY *key, uint8_t **out, size_t *outlen); int xmss_private_key_from_bytes(XMSS_KEY *key, const uint8_t **in, size_t *inlen); +int xmss_private_key_from_file(XMSS_KEY *key, FILE *fp); int xmss_private_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSS_KEY *key); -void xmss_key_cleanup(XMSS_KEY *key); typedef struct { @@ -345,13 +355,19 @@ typedef struct { #define XMSSMT_PUBLIC_KEY_SIZE (4 + sizeof(xmss_hash256_t) + sizeof(xmss_hash256_t)) // = 68 bytes -typedef struct { +typedef struct XMSSMT_KEY_st XMSSMT_KEY; + +typedef int (*xmssmt_key_update_callback)(XMSSMT_KEY *key); + +typedef struct XMSSMT_KEY_st { XMSSMT_PUBLIC_KEY public_key; uint64_t index; // in [0, 2^60 - 1] xmss_hash256_t secret; xmss_hash256_t sk_prf; xmss_hash256_t *trees; xmss_wots_sig_t wots_sigs[XMSSMT_MAX_LAYERS - 1]; + xmssmt_key_update_callback update_callback; + void *update_param; } XMSSMT_KEY; /* @@ -368,12 +384,14 @@ int xmssmt_private_key_size(uint32_t xmssmt_type, size_t *len); int xmssmt_build_auth_path(const xmss_hash256_t *tree, size_t height, size_t layers, uint64_t index, xmss_hash256_t *auth_path); int xmssmt_key_generate(XMSSMT_KEY *key, uint32_t xmssmt_type); +int xmssmt_key_set_update_callback(XMSSMT_KEY *key, xmssmt_key_update_callback update_cb, void *param); int xmssmt_key_update(XMSSMT_KEY *key); int xmssmt_public_key_to_bytes(const XMSSMT_KEY *key, uint8_t **out, size_t *outlen); int xmssmt_public_key_from_bytes(XMSSMT_KEY *key, const uint8_t **in, size_t *inlen); int xmssmt_public_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSSMT_KEY *key); int xmssmt_private_key_to_bytes(const XMSSMT_KEY *key, uint8_t **out, size_t *outlen); int xmssmt_private_key_from_bytes(XMSSMT_KEY *key, const uint8_t **in, size_t *inlen); +int xmssmt_private_key_from_file(XMSSMT_KEY *key, FILE *fp); int xmssmt_private_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSSMT_KEY *key); void xmssmt_key_cleanup(XMSSMT_KEY *key); @@ -388,8 +406,7 @@ typedef struct { int xmssmt_index_to_bytes(uint64_t index, uint32_t xmssmt_type, uint8_t **out, size_t *outlen); int xmssmt_index_from_bytes(uint64_t *index, uint32_t xmssmt_type, const uint8_t **in, size_t *inlen); -#define XMSSMT_SIGNATURE_MAX_SIZE \ - (sizeof(uint64_t) + sizeof(xmss_hash256_t) + sizeof(xmss_wots_sig_t)*XMSSMT_MAX_LAYERS + sizeof(xmss_hash256_t)*XMSSMT_MAX_HEIGHT) // = 27688 bytes +#define XMSSMT_SIGNATURE_MAX_SIZE sizeof(XMSSMT_SIGNATURE) // >= 27688 bytes int xmssmt_key_get_signature_size(const XMSSMT_KEY *key, size_t *siglen); int xmssmt_signature_size(uint32_t xmssmt_type, size_t *siglen); diff --git a/src/xmss.c b/src/xmss.c index 9cd131b4..c04c7905 100644 --- a/src/xmss.c +++ b/src/xmss.c @@ -754,6 +754,17 @@ end: return ret; } +int xmss_key_set_update_callback(XMSS_KEY *key, xmss_key_update_callback update_cb, void *param) +{ + if (!key) { + error_print(); + return -1; + } + key->update_callback = update_cb; + key->update_param = param; + return 1; +} + int xmss_key_update(XMSS_KEY *key) { size_t height; @@ -774,6 +785,13 @@ int xmss_key_update(XMSS_KEY *key) return 0; } key->index++; + + if (key->update_callback) { + if (key->update_callback(key) != 1) { + error_print(); + return -1; + } + } return 1; } @@ -1186,12 +1204,15 @@ int xmss_sign_init(XMSS_SIGN_CTX *ctx, XMSS_KEY *key) xmss_adrs_set_ots_address(adrs, key->index); xmss_wots_derive_sk(key->secret, key->public_key.seed, adrs, ctx->xmss_sig.wots_sig); + // update key->index + if (xmss_key_update(key) != 1) { + error_print(); + return -1; + } + // xmss_sig.auth_path xmss_build_auth_path(key->tree, height, key->index, ctx->xmss_sig.auth_path); - // update key->index - key->index++; - // H_msg(M) := HASH256(toByte(2, 32) || r || XMSS_ROOT || toByte(idx_sig, 32) || M) xmss_hash256_init(&ctx->hash256_ctx); xmss_hash256_update(&ctx->hash256_ctx, xmss_hash256_two, sizeof(xmss_hash256_t)); @@ -1575,6 +1596,17 @@ int xmssmt_private_key_from_bytes(XMSSMT_KEY *key, const uint8_t **in, size_t *i return 1; } +int xmssmt_key_set_update_callback(XMSSMT_KEY *key, xmssmt_key_update_callback update_cb, void *param) +{ + if (!key) { + error_print(); + return -1; + } + key->update_callback = update_cb; + key->update_param = param; + return 1; +} + int xmssmt_key_update(XMSSMT_KEY *key) { size_t height; @@ -1628,6 +1660,12 @@ int xmssmt_key_update(XMSSMT_KEY *key) key->index++; + if (key->update_callback) { + if (key->update_callback(key) != 1) { + error_print(); + return -1; + } + } return 1; } @@ -2447,3 +2485,138 @@ int xmssmt_verify_finish(XMSSMT_SIGN_CTX *ctx) return 1; } + +int xmss_private_key_from_file(XMSS_KEY *key, FILE *fp) +{ + uint8_t pubkeybuf[XMSS_PUBLIC_KEY_SIZE]; + uint8_t *keybuf = NULL; + size_t keylen; + const uint8_t *cp; + size_t len; + + if (!key || !fp) { + error_print(); + return -1; + } + + // load xmss_public_key and get xmss_private_key_size + len = sizeof(pubkeybuf); + if (fread(pubkeybuf, 1, len, fp) != len) { + error_print(); + return -1; + } + cp = pubkeybuf; + if (xmss_public_key_from_bytes(key, &cp, &len) != 1) { + error_print(); + return -1; + } + if (len) { + error_print(); + return -1; + } + if (xmss_private_key_size(key->public_key.xmss_type, &keylen) != 1) { + error_print(); + return -1; + } + if (keylen <= sizeof(pubkeybuf)) { + error_print(); + return -1; + } + + // malloc and load full xmss_private_key + if (!(keybuf = malloc(keylen))) { + error_print(); + return -1; + } + memcpy(keybuf, pubkeybuf, sizeof(pubkeybuf)); + + len = keylen - sizeof(pubkeybuf); + if (fread(keybuf + sizeof(pubkeybuf), 1, len, fp) != len) { + free(keybuf); + error_print(); + return -1; + } + + cp = keybuf; + if (xmss_private_key_from_bytes(key, &cp, &keylen) != 1) { + free(keybuf); + error_print(); + return -1; + } + if (keylen) { + free(keybuf); + error_print(); + return -1; + } + + free(keybuf); + return 1; +} + +int xmssmt_private_key_from_file(XMSSMT_KEY *key, FILE *fp) +{ + uint8_t pubkeybuf[XMSSMT_PUBLIC_KEY_SIZE]; + uint8_t *keybuf = NULL; + size_t keylen; + const uint8_t *cp; + size_t len; + + if (!key || !fp) { + error_print(); + return -1; + } + + // load xmss_public_key and get xmss_private_key_size + len = sizeof(pubkeybuf); + if (fread(pubkeybuf, 1, len, fp) != len) { + error_print(); + return -1; + } + cp = pubkeybuf; + if (xmssmt_public_key_from_bytes(key, &cp, &len) != 1) { + error_print(); + return -1; + } + if (len) { + error_print(); + return -1; + } + if (xmssmt_private_key_size(key->public_key.xmssmt_type, &keylen) != 1) { + error_print(); + return -1; + } + if (keylen <= sizeof(pubkeybuf)) { + error_print(); + return -1; + } + + // malloc and load full xmss_private_key + if (!(keybuf = malloc(keylen))) { + error_print(); + return -1; + } + memcpy(keybuf, pubkeybuf, sizeof(pubkeybuf)); + + len = keylen - sizeof(pubkeybuf); + if (fread(keybuf + sizeof(pubkeybuf), 1, len, fp) != len) { + free(keybuf); + error_print(); + return -1; + } + + cp = keybuf; + if (xmssmt_private_key_from_bytes(key, &cp, &keylen) != 1) { + free(keybuf); + error_print(); + return -1; + } + if (keylen) { + free(keybuf); + error_print(); + return -1; + } + + free(keybuf); + return 1; +} + diff --git a/tools/xmssmtsign.c b/tools/xmssmtsign.c index 95ef9231..a711a1b4 100644 --- a/tools/xmssmtsign.c +++ b/tools/xmssmtsign.c @@ -26,6 +26,34 @@ static const char *options = " -verbose Print public key and signature\n" "\n"; +static int key_update_cb(XMSSMT_KEY *key) +{ + FILE *fp; + uint8_t index_buf[8]; + uint8_t *p = index_buf; + size_t len = 0; + + if (!key->update_param) { + error_print(); + return -1; + } + fp = (FILE *)key->update_param; + + // write index only + xmssmt_index_to_bytes(key->index, key->public_key.xmssmt_type, &p, &len); + + if (fseek(fp, XMSSMT_PUBLIC_KEY_SIZE, SEEK_SET) != 0) { + error_print(); + return -1; + } + if (fwrite(index_buf, 1, len, fp) != len + || fflush(fp) != 0) { + error_print(); + return -1; + } + return 1; +} + int xmssmtsign_main(int argc, char **argv) { int ret = 1; @@ -37,11 +65,6 @@ int xmssmtsign_main(int argc, char **argv) FILE *keyfp = NULL; FILE *infp = stdin; FILE *outfp = stdout; - uint8_t pubkey[XMSSMT_PUBLIC_KEY_SIZE]; - uint8_t *keybuf = NULL; - size_t keylen; - const uint8_t *cp; - uint8_t *p; XMSSMT_KEY key; XMSSMT_SIGN_CTX ctx; uint8_t sig[XMSSMT_SIGNATURE_MAX_SIZE]; @@ -103,77 +126,24 @@ bad: goto end; } - if (fread(pubkey, 1, sizeof(pubkey), keyfp) != sizeof(pubkey)) { - error_print(); - goto end; - } - cp = pubkey; - keylen = sizeof(pubkey); - if (xmssmt_public_key_from_bytes(&key, &cp, &keylen) != 1 ) { - error_print(); - goto end; - } - - - if (xmssmt_private_key_size(key.public_key.xmssmt_type, &keylen) != 1) { - error_print(); - goto end; - } - if (!(keybuf = malloc(keylen))) { - error_print(); - goto end; - } - memcpy(keybuf, pubkey, sizeof(pubkey)); - - - if (fread(keybuf + sizeof(pubkey), 1, keylen - sizeof(pubkey), keyfp) != keylen - sizeof(pubkey)) { + if (xmssmt_private_key_from_file(&key, keyfp) != 1) { fprintf(stderr, "%s: read private key failure\n", prog); goto end; } - cp = keybuf; - if (xmssmt_private_key_from_bytes(&key, &cp, &keylen) != 1) { - error_print(); - goto end; - } - if (keylen) { - error_print(); - return -1; - } if (verbose) { xmssmt_public_key_print(stderr, 0, 0, "lms_public_key", &key); } - if (xmssmt_sign_init(&ctx, &key) != 1) { + if (xmssmt_key_set_update_callback(&key, key_update_cb, keyfp) != 1) { error_print(); goto end; } -#if 0 - // write updated key back to file - // TODO: write back `q` only - p = keybuf; - keylen = 0; - if (xmssmt_private_key_to_bytes(&key, &p, &keylen) != 1) { - error_print(); - return -1; - } - rewind(keyfp); - if (fwrite(keybuf, 1, keylen, keyfp) != keylen) { - error_print(); - return -1; - } -#else - if (fseek(keyfp, XMSSMT_PUBLIC_KEY_SIZE, SEEK_SET) != 0) { + if (xmssmt_sign_init(&ctx, &key) != 1) { error_print(); goto end; } - uint8_t index_buf[8]; - uint8_t *pindex = index_buf; - size_t index_len = 0; - xmssmt_index_to_bytes(key.index, key.public_key.xmssmt_type, &pindex, &index_len); - fwrite(index_buf, 1, index_len, keyfp); -#endif while (1) { uint8_t buf[1024]; @@ -202,7 +172,6 @@ bad: end: xmssmt_key_cleanup(&key); - gmssl_secure_clear(keybuf, keylen); gmssl_secure_clear(&ctx, sizeof(ctx)); if (keyfp) fclose(keyfp); if (infp && infp != stdin) fclose(infp); diff --git a/tools/xmsssign.c b/tools/xmsssign.c index c8775864..b75c4248 100644 --- a/tools/xmsssign.c +++ b/tools/xmsssign.c @@ -27,6 +27,32 @@ static const char *options = " -verbose Print public key and signature\n" "\n"; +static int key_update_cb(XMSS_KEY *key) +{ + FILE *fp; + uint8_t index_buf[4]; + + if (!key->update_param) { + error_print(); + return -1; + } + fp = (FILE *)key->update_param; + + // write index only + PUTU32(index_buf, key->index); + + if (fseek(fp, XMSS_PUBLIC_KEY_SIZE, SEEK_SET) != 0) { + error_print(); + return -1; + } + if (fwrite(index_buf, 1, sizeof(index_buf), fp) != sizeof(index_buf) + || fflush(fp) != 0) { + error_print(); + return -1; + } + return 1; +} + int xmsssign_main(int argc, char **argv) { int ret = 1; @@ -38,12 +64,6 @@ int xmsssign_main(int argc, char **argv) FILE *keyfp = NULL; FILE *infp = stdin; FILE *outfp = stdout; - uint8_t pubkeybuf[XMSS_PUBLIC_KEY_SIZE]; - uint8_t *keybuf = NULL; - size_t keylen; - const uint8_t *cp; - uint8_t *p; - size_t len; XMSS_KEY key; XMSS_SIGN_CTX ctx; uint8_t sig[XMSS_SIGNATURE_MAX_SIZE]; @@ -105,81 +125,24 @@ bad: goto end; } - // load xmss_public_key - if (fread(pubkeybuf, 1, sizeof(pubkeybuf), keyfp) != sizeof(pubkeybuf)) { - error_print(); - goto end; - } - cp = pubkeybuf; - len = sizeof(pubkeybuf); - if (xmss_public_key_from_bytes(&key, &cp, &len) != 1) { - error_print(); - goto end; - } - if (len) { - error_print(); - goto end; - } - - // xmss_private_key_size - if (xmss_private_key_size(key.public_key.xmss_type, &keylen) != 1) { - error_print(); - goto end; - } - if (!(keybuf = malloc(keylen))) { - error_print(); - goto end; - } - memcpy(keybuf, pubkeybuf, sizeof(pubkeybuf)); - - len = keylen - sizeof(pubkeybuf); - if (fread(keybuf + sizeof(pubkeybuf), 1, len, keyfp) != len) { + if (xmss_private_key_from_file(&key, keyfp) != 1) { fprintf(stderr, "%s: read private key failure\n", prog); goto end; } - cp = keybuf; - if (xmss_private_key_from_bytes(&key, &cp, &keylen) != 1) { - error_print(); - goto end; - } - if (keylen) { - error_print(); - return -1; - } - if (verbose) { xmss_public_key_print(stderr, 0, 0, "lms_public_key", &key); } - if (xmss_sign_init(&ctx, &key) != 1) { + if (xmss_key_set_update_callback(&key, key_update_cb, keyfp) != 1) { error_print(); goto end; } -#if 0 - // write updated key back to file - // TODO: write back `q` only - p = keybuf; - if (xmss_private_key_to_bytes(&key, &p, &keylen) != 1) { - error_print(); - return -1; - } - rewind(keyfp); - if (fwrite(keybuf, 1, keylen, keyfp) != keylen) { - error_print(); - return -1; - } -#else - // write index only - if (fseek(keyfp, XMSS_PUBLIC_KEY_SIZE, SEEK_SET) != 0) { + if (xmss_sign_init(&ctx, &key) != 1) { error_print(); goto end; } - uint8_t index_buf[4]; - PUTU32(index_buf, key.index); - fwrite(index_buf, 1, 4, keyfp); -#endif while (1) { uint8_t buf[1024]; @@ -209,10 +172,6 @@ bad: end: xmss_key_cleanup(&key); gmssl_secure_clear(&ctx, sizeof(ctx)); - if (keybuf) { - gmssl_secure_clear(keybuf, keylen); - free(keybuf); - } if (keyfp) fclose(keyfp); if (infp && infp != stdin) fclose(infp); if (outfp && outfp != stdout) fclose(outfp);