Update XMSS

Add key_update callback and private_key_from_file
This commit is contained in:
Zhi Guan
2026-01-18 17:09:27 +08:00
parent 2e8d3abbc9
commit 9db11c6d06
4 changed files with 258 additions and 140 deletions

View File

@@ -207,12 +207,18 @@ typedef struct {
#define XMSS_PUBLIC_KEY_SIZE (4 + 32 + 32) // = 68 #define XMSS_PUBLIC_KEY_SIZE (4 + 32 + 32) // = 68
typedef struct { typedef struct XMSS_KEY_st XMSS_KEY;
typedef int (*xmss_key_update_callback)(XMSS_KEY *key);
typedef struct XMSS_KEY_st {
XMSS_PUBLIC_KEY public_key; XMSS_PUBLIC_KEY public_key;
uint32_t index; uint32_t index;
xmss_hash256_t secret; xmss_hash256_t secret;
xmss_hash256_t sk_prf; xmss_hash256_t sk_prf;
xmss_hash256_t *tree; // xmss_hash256_t[2^(h + 1) - 1] xmss_hash256_t *tree; // xmss_hash256_t[2^(h + 1) - 1]
xmss_key_update_callback update_callback;
void *update_param;
} XMSS_KEY; } XMSS_KEY;
// XMSS_SHA2_10_256: 65,640 // XMSS_SHA2_10_256: 65,640
@@ -224,13 +230,17 @@ int xmss_private_key_size(uint32_t xmss_type, size_t *keysize);
int xmss_key_generate(XMSS_KEY *key, uint32_t xmss_type); int xmss_key_generate(XMSS_KEY *key, uint32_t xmss_type);
int xmss_key_remaining_signs(const XMSS_KEY *key, size_t *count); int xmss_key_remaining_signs(const XMSS_KEY *key, size_t *count);
int xmss_key_set_update_callback(XMSS_KEY *key, xmss_key_update_callback update_cb, void *param);
int xmss_key_update(XMSS_KEY *key);
void xmss_key_cleanup(XMSS_KEY *key);
int xmss_public_key_to_bytes(const XMSS_KEY *key, uint8_t **out, size_t *outlen); int xmss_public_key_to_bytes(const XMSS_KEY *key, uint8_t **out, size_t *outlen);
int xmss_public_key_from_bytes(XMSS_KEY *key, const uint8_t **in, size_t *inlen); int xmss_public_key_from_bytes(XMSS_KEY *key, const uint8_t **in, size_t *inlen);
int xmss_public_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSS_KEY *key); int xmss_public_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSS_KEY *key);
int xmss_private_key_to_bytes(const XMSS_KEY *key, uint8_t **out, size_t *outlen); int xmss_private_key_to_bytes(const XMSS_KEY *key, uint8_t **out, size_t *outlen);
int xmss_private_key_from_bytes(XMSS_KEY *key, const uint8_t **in, size_t *inlen); int xmss_private_key_from_bytes(XMSS_KEY *key, const uint8_t **in, size_t *inlen);
int xmss_private_key_from_file(XMSS_KEY *key, FILE *fp);
int xmss_private_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSS_KEY *key); int xmss_private_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSS_KEY *key);
void xmss_key_cleanup(XMSS_KEY *key);
typedef struct { typedef struct {
@@ -345,13 +355,19 @@ typedef struct {
#define XMSSMT_PUBLIC_KEY_SIZE (4 + sizeof(xmss_hash256_t) + sizeof(xmss_hash256_t)) // = 68 bytes #define XMSSMT_PUBLIC_KEY_SIZE (4 + sizeof(xmss_hash256_t) + sizeof(xmss_hash256_t)) // = 68 bytes
typedef struct { typedef struct XMSSMT_KEY_st XMSSMT_KEY;
typedef int (*xmssmt_key_update_callback)(XMSSMT_KEY *key);
typedef struct XMSSMT_KEY_st {
XMSSMT_PUBLIC_KEY public_key; XMSSMT_PUBLIC_KEY public_key;
uint64_t index; // in [0, 2^60 - 1] uint64_t index; // in [0, 2^60 - 1]
xmss_hash256_t secret; xmss_hash256_t secret;
xmss_hash256_t sk_prf; xmss_hash256_t sk_prf;
xmss_hash256_t *trees; xmss_hash256_t *trees;
xmss_wots_sig_t wots_sigs[XMSSMT_MAX_LAYERS - 1]; xmss_wots_sig_t wots_sigs[XMSSMT_MAX_LAYERS - 1];
xmssmt_key_update_callback update_callback;
void *update_param;
} XMSSMT_KEY; } XMSSMT_KEY;
/* /*
@@ -368,12 +384,14 @@ int xmssmt_private_key_size(uint32_t xmssmt_type, size_t *len);
int xmssmt_build_auth_path(const xmss_hash256_t *tree, size_t height, size_t layers, uint64_t index, xmss_hash256_t *auth_path); int xmssmt_build_auth_path(const xmss_hash256_t *tree, size_t height, size_t layers, uint64_t index, xmss_hash256_t *auth_path);
int xmssmt_key_generate(XMSSMT_KEY *key, uint32_t xmssmt_type); int xmssmt_key_generate(XMSSMT_KEY *key, uint32_t xmssmt_type);
int xmssmt_key_set_update_callback(XMSSMT_KEY *key, xmssmt_key_update_callback update_cb, void *param);
int xmssmt_key_update(XMSSMT_KEY *key); int xmssmt_key_update(XMSSMT_KEY *key);
int xmssmt_public_key_to_bytes(const XMSSMT_KEY *key, uint8_t **out, size_t *outlen); int xmssmt_public_key_to_bytes(const XMSSMT_KEY *key, uint8_t **out, size_t *outlen);
int xmssmt_public_key_from_bytes(XMSSMT_KEY *key, const uint8_t **in, size_t *inlen); int xmssmt_public_key_from_bytes(XMSSMT_KEY *key, const uint8_t **in, size_t *inlen);
int xmssmt_public_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSSMT_KEY *key); int xmssmt_public_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSSMT_KEY *key);
int xmssmt_private_key_to_bytes(const XMSSMT_KEY *key, uint8_t **out, size_t *outlen); int xmssmt_private_key_to_bytes(const XMSSMT_KEY *key, uint8_t **out, size_t *outlen);
int xmssmt_private_key_from_bytes(XMSSMT_KEY *key, const uint8_t **in, size_t *inlen); int xmssmt_private_key_from_bytes(XMSSMT_KEY *key, const uint8_t **in, size_t *inlen);
int xmssmt_private_key_from_file(XMSSMT_KEY *key, FILE *fp);
int xmssmt_private_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSSMT_KEY *key); int xmssmt_private_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSSMT_KEY *key);
void xmssmt_key_cleanup(XMSSMT_KEY *key); void xmssmt_key_cleanup(XMSSMT_KEY *key);
@@ -388,8 +406,7 @@ typedef struct {
int xmssmt_index_to_bytes(uint64_t index, uint32_t xmssmt_type, uint8_t **out, size_t *outlen); int xmssmt_index_to_bytes(uint64_t index, uint32_t xmssmt_type, uint8_t **out, size_t *outlen);
int xmssmt_index_from_bytes(uint64_t *index, uint32_t xmssmt_type, const uint8_t **in, size_t *inlen); int xmssmt_index_from_bytes(uint64_t *index, uint32_t xmssmt_type, const uint8_t **in, size_t *inlen);
#define XMSSMT_SIGNATURE_MAX_SIZE \ #define XMSSMT_SIGNATURE_MAX_SIZE sizeof(XMSSMT_SIGNATURE) // >= 27688 bytes
(sizeof(uint64_t) + sizeof(xmss_hash256_t) + sizeof(xmss_wots_sig_t)*XMSSMT_MAX_LAYERS + sizeof(xmss_hash256_t)*XMSSMT_MAX_HEIGHT) // = 27688 bytes
int xmssmt_key_get_signature_size(const XMSSMT_KEY *key, size_t *siglen); int xmssmt_key_get_signature_size(const XMSSMT_KEY *key, size_t *siglen);
int xmssmt_signature_size(uint32_t xmssmt_type, size_t *siglen); int xmssmt_signature_size(uint32_t xmssmt_type, size_t *siglen);

View File

@@ -754,6 +754,17 @@ end:
return ret; return ret;
} }
int xmss_key_set_update_callback(XMSS_KEY *key, xmss_key_update_callback update_cb, void *param)
{
if (!key) {
error_print();
return -1;
}
key->update_callback = update_cb;
key->update_param = param;
return 1;
}
int xmss_key_update(XMSS_KEY *key) int xmss_key_update(XMSS_KEY *key)
{ {
size_t height; size_t height;
@@ -774,6 +785,13 @@ int xmss_key_update(XMSS_KEY *key)
return 0; return 0;
} }
key->index++; key->index++;
if (key->update_callback) {
if (key->update_callback(key) != 1) {
error_print();
return -1;
}
}
return 1; return 1;
} }
@@ -1186,12 +1204,15 @@ int xmss_sign_init(XMSS_SIGN_CTX *ctx, XMSS_KEY *key)
xmss_adrs_set_ots_address(adrs, key->index); xmss_adrs_set_ots_address(adrs, key->index);
xmss_wots_derive_sk(key->secret, key->public_key.seed, adrs, ctx->xmss_sig.wots_sig); xmss_wots_derive_sk(key->secret, key->public_key.seed, adrs, ctx->xmss_sig.wots_sig);
// update key->index
if (xmss_key_update(key) != 1) {
error_print();
return -1;
}
// xmss_sig.auth_path // xmss_sig.auth_path
xmss_build_auth_path(key->tree, height, key->index, ctx->xmss_sig.auth_path); xmss_build_auth_path(key->tree, height, key->index, ctx->xmss_sig.auth_path);
// update key->index
key->index++;
// H_msg(M) := HASH256(toByte(2, 32) || r || XMSS_ROOT || toByte(idx_sig, 32) || M) // H_msg(M) := HASH256(toByte(2, 32) || r || XMSS_ROOT || toByte(idx_sig, 32) || M)
xmss_hash256_init(&ctx->hash256_ctx); xmss_hash256_init(&ctx->hash256_ctx);
xmss_hash256_update(&ctx->hash256_ctx, xmss_hash256_two, sizeof(xmss_hash256_t)); xmss_hash256_update(&ctx->hash256_ctx, xmss_hash256_two, sizeof(xmss_hash256_t));
@@ -1575,6 +1596,17 @@ int xmssmt_private_key_from_bytes(XMSSMT_KEY *key, const uint8_t **in, size_t *i
return 1; return 1;
} }
int xmssmt_key_set_update_callback(XMSSMT_KEY *key, xmssmt_key_update_callback update_cb, void *param)
{
if (!key) {
error_print();
return -1;
}
key->update_callback = update_cb;
key->update_param = param;
return 1;
}
int xmssmt_key_update(XMSSMT_KEY *key) int xmssmt_key_update(XMSSMT_KEY *key)
{ {
size_t height; size_t height;
@@ -1628,6 +1660,12 @@ int xmssmt_key_update(XMSSMT_KEY *key)
key->index++; key->index++;
if (key->update_callback) {
if (key->update_callback(key) != 1) {
error_print();
return -1;
}
}
return 1; return 1;
} }
@@ -2447,3 +2485,138 @@ int xmssmt_verify_finish(XMSSMT_SIGN_CTX *ctx)
return 1; return 1;
} }
int xmss_private_key_from_file(XMSS_KEY *key, FILE *fp)
{
uint8_t pubkeybuf[XMSS_PUBLIC_KEY_SIZE];
uint8_t *keybuf = NULL;
size_t keylen;
const uint8_t *cp;
size_t len;
if (!key || !fp) {
error_print();
return -1;
}
// load xmss_public_key and get xmss_private_key_size
len = sizeof(pubkeybuf);
if (fread(pubkeybuf, 1, len, fp) != len) {
error_print();
return -1;
}
cp = pubkeybuf;
if (xmss_public_key_from_bytes(key, &cp, &len) != 1) {
error_print();
return -1;
}
if (len) {
error_print();
return -1;
}
if (xmss_private_key_size(key->public_key.xmss_type, &keylen) != 1) {
error_print();
return -1;
}
if (keylen <= sizeof(pubkeybuf)) {
error_print();
return -1;
}
// malloc and load full xmss_private_key
if (!(keybuf = malloc(keylen))) {
error_print();
return -1;
}
memcpy(keybuf, pubkeybuf, sizeof(pubkeybuf));
len = keylen - sizeof(pubkeybuf);
if (fread(keybuf + sizeof(pubkeybuf), 1, len, fp) != len) {
free(keybuf);
error_print();
return -1;
}
cp = keybuf;
if (xmss_private_key_from_bytes(key, &cp, &keylen) != 1) {
free(keybuf);
error_print();
return -1;
}
if (keylen) {
free(keybuf);
error_print();
return -1;
}
free(keybuf);
return 1;
}
int xmssmt_private_key_from_file(XMSSMT_KEY *key, FILE *fp)
{
uint8_t pubkeybuf[XMSSMT_PUBLIC_KEY_SIZE];
uint8_t *keybuf = NULL;
size_t keylen;
const uint8_t *cp;
size_t len;
if (!key || !fp) {
error_print();
return -1;
}
// load xmss_public_key and get xmss_private_key_size
len = sizeof(pubkeybuf);
if (fread(pubkeybuf, 1, len, fp) != len) {
error_print();
return -1;
}
cp = pubkeybuf;
if (xmssmt_public_key_from_bytes(key, &cp, &len) != 1) {
error_print();
return -1;
}
if (len) {
error_print();
return -1;
}
if (xmssmt_private_key_size(key->public_key.xmssmt_type, &keylen) != 1) {
error_print();
return -1;
}
if (keylen <= sizeof(pubkeybuf)) {
error_print();
return -1;
}
// malloc and load full xmss_private_key
if (!(keybuf = malloc(keylen))) {
error_print();
return -1;
}
memcpy(keybuf, pubkeybuf, sizeof(pubkeybuf));
len = keylen - sizeof(pubkeybuf);
if (fread(keybuf + sizeof(pubkeybuf), 1, len, fp) != len) {
free(keybuf);
error_print();
return -1;
}
cp = keybuf;
if (xmssmt_private_key_from_bytes(key, &cp, &keylen) != 1) {
free(keybuf);
error_print();
return -1;
}
if (keylen) {
free(keybuf);
error_print();
return -1;
}
free(keybuf);
return 1;
}

View File

@@ -26,6 +26,34 @@ static const char *options =
" -verbose Print public key and signature\n" " -verbose Print public key and signature\n"
"\n"; "\n";
static int key_update_cb(XMSSMT_KEY *key)
{
FILE *fp;
uint8_t index_buf[8];
uint8_t *p = index_buf;
size_t len = 0;
if (!key->update_param) {
error_print();
return -1;
}
fp = (FILE *)key->update_param;
// write index only
xmssmt_index_to_bytes(key->index, key->public_key.xmssmt_type, &p, &len);
if (fseek(fp, XMSSMT_PUBLIC_KEY_SIZE, SEEK_SET) != 0) {
error_print();
return -1;
}
if (fwrite(index_buf, 1, len, fp) != len
|| fflush(fp) != 0) {
error_print();
return -1;
}
return 1;
}
int xmssmtsign_main(int argc, char **argv) int xmssmtsign_main(int argc, char **argv)
{ {
int ret = 1; int ret = 1;
@@ -37,11 +65,6 @@ int xmssmtsign_main(int argc, char **argv)
FILE *keyfp = NULL; FILE *keyfp = NULL;
FILE *infp = stdin; FILE *infp = stdin;
FILE *outfp = stdout; FILE *outfp = stdout;
uint8_t pubkey[XMSSMT_PUBLIC_KEY_SIZE];
uint8_t *keybuf = NULL;
size_t keylen;
const uint8_t *cp;
uint8_t *p;
XMSSMT_KEY key; XMSSMT_KEY key;
XMSSMT_SIGN_CTX ctx; XMSSMT_SIGN_CTX ctx;
uint8_t sig[XMSSMT_SIGNATURE_MAX_SIZE]; uint8_t sig[XMSSMT_SIGNATURE_MAX_SIZE];
@@ -103,77 +126,24 @@ bad:
goto end; goto end;
} }
if (fread(pubkey, 1, sizeof(pubkey), keyfp) != sizeof(pubkey)) { if (xmssmt_private_key_from_file(&key, keyfp) != 1) {
error_print();
goto end;
}
cp = pubkey;
keylen = sizeof(pubkey);
if (xmssmt_public_key_from_bytes(&key, &cp, &keylen) != 1 ) {
error_print();
goto end;
}
if (xmssmt_private_key_size(key.public_key.xmssmt_type, &keylen) != 1) {
error_print();
goto end;
}
if (!(keybuf = malloc(keylen))) {
error_print();
goto end;
}
memcpy(keybuf, pubkey, sizeof(pubkey));
if (fread(keybuf + sizeof(pubkey), 1, keylen - sizeof(pubkey), keyfp) != keylen - sizeof(pubkey)) {
fprintf(stderr, "%s: read private key failure\n", prog); fprintf(stderr, "%s: read private key failure\n", prog);
goto end; goto end;
} }
cp = keybuf;
if (xmssmt_private_key_from_bytes(&key, &cp, &keylen) != 1) {
error_print();
goto end;
}
if (keylen) {
error_print();
return -1;
}
if (verbose) { if (verbose) {
xmssmt_public_key_print(stderr, 0, 0, "lms_public_key", &key); xmssmt_public_key_print(stderr, 0, 0, "lms_public_key", &key);
} }
if (xmssmt_sign_init(&ctx, &key) != 1) { if (xmssmt_key_set_update_callback(&key, key_update_cb, keyfp) != 1) {
error_print(); error_print();
goto end; goto end;
} }
#if 0 if (xmssmt_sign_init(&ctx, &key) != 1) {
// write updated key back to file
// TODO: write back `q` only
p = keybuf;
keylen = 0;
if (xmssmt_private_key_to_bytes(&key, &p, &keylen) != 1) {
error_print();
return -1;
}
rewind(keyfp);
if (fwrite(keybuf, 1, keylen, keyfp) != keylen) {
error_print();
return -1;
}
#else
if (fseek(keyfp, XMSSMT_PUBLIC_KEY_SIZE, SEEK_SET) != 0) {
error_print(); error_print();
goto end; goto end;
} }
uint8_t index_buf[8];
uint8_t *pindex = index_buf;
size_t index_len = 0;
xmssmt_index_to_bytes(key.index, key.public_key.xmssmt_type, &pindex, &index_len);
fwrite(index_buf, 1, index_len, keyfp);
#endif
while (1) { while (1) {
uint8_t buf[1024]; uint8_t buf[1024];
@@ -202,7 +172,6 @@ bad:
end: end:
xmssmt_key_cleanup(&key); xmssmt_key_cleanup(&key);
gmssl_secure_clear(keybuf, keylen);
gmssl_secure_clear(&ctx, sizeof(ctx)); gmssl_secure_clear(&ctx, sizeof(ctx));
if (keyfp) fclose(keyfp); if (keyfp) fclose(keyfp);
if (infp && infp != stdin) fclose(infp); if (infp && infp != stdin) fclose(infp);

View File

@@ -27,6 +27,32 @@ static const char *options =
" -verbose Print public key and signature\n" " -verbose Print public key and signature\n"
"\n"; "\n";
static int key_update_cb(XMSS_KEY *key)
{
FILE *fp;
uint8_t index_buf[4];
if (!key->update_param) {
error_print();
return -1;
}
fp = (FILE *)key->update_param;
// write index only
PUTU32(index_buf, key->index);
if (fseek(fp, XMSS_PUBLIC_KEY_SIZE, SEEK_SET) != 0) {
error_print();
return -1;
}
if (fwrite(index_buf, 1, sizeof(index_buf), fp) != sizeof(index_buf)
|| fflush(fp) != 0) {
error_print();
return -1;
}
return 1;
}
int xmsssign_main(int argc, char **argv) int xmsssign_main(int argc, char **argv)
{ {
int ret = 1; int ret = 1;
@@ -38,12 +64,6 @@ int xmsssign_main(int argc, char **argv)
FILE *keyfp = NULL; FILE *keyfp = NULL;
FILE *infp = stdin; FILE *infp = stdin;
FILE *outfp = stdout; FILE *outfp = stdout;
uint8_t pubkeybuf[XMSS_PUBLIC_KEY_SIZE];
uint8_t *keybuf = NULL;
size_t keylen;
const uint8_t *cp;
uint8_t *p;
size_t len;
XMSS_KEY key; XMSS_KEY key;
XMSS_SIGN_CTX ctx; XMSS_SIGN_CTX ctx;
uint8_t sig[XMSS_SIGNATURE_MAX_SIZE]; uint8_t sig[XMSS_SIGNATURE_MAX_SIZE];
@@ -105,81 +125,24 @@ bad:
goto end; goto end;
} }
// load xmss_public_key if (xmss_private_key_from_file(&key, keyfp) != 1) {
if (fread(pubkeybuf, 1, sizeof(pubkeybuf), keyfp) != sizeof(pubkeybuf)) {
error_print();
goto end;
}
cp = pubkeybuf;
len = sizeof(pubkeybuf);
if (xmss_public_key_from_bytes(&key, &cp, &len) != 1) {
error_print();
goto end;
}
if (len) {
error_print();
goto end;
}
// xmss_private_key_size
if (xmss_private_key_size(key.public_key.xmss_type, &keylen) != 1) {
error_print();
goto end;
}
if (!(keybuf = malloc(keylen))) {
error_print();
goto end;
}
memcpy(keybuf, pubkeybuf, sizeof(pubkeybuf));
len = keylen - sizeof(pubkeybuf);
if (fread(keybuf + sizeof(pubkeybuf), 1, len, keyfp) != len) {
fprintf(stderr, "%s: read private key failure\n", prog); fprintf(stderr, "%s: read private key failure\n", prog);
goto end; goto end;
} }
cp = keybuf;
if (xmss_private_key_from_bytes(&key, &cp, &keylen) != 1) {
error_print();
goto end;
}
if (keylen) {
error_print();
return -1;
}
if (verbose) { if (verbose) {
xmss_public_key_print(stderr, 0, 0, "lms_public_key", &key); xmss_public_key_print(stderr, 0, 0, "lms_public_key", &key);
} }
if (xmss_sign_init(&ctx, &key) != 1) { if (xmss_key_set_update_callback(&key, key_update_cb, keyfp) != 1) {
error_print(); error_print();
goto end; goto end;
} }
#if 0 if (xmss_sign_init(&ctx, &key) != 1) {
// write updated key back to file
// TODO: write back `q` only
p = keybuf;
if (xmss_private_key_to_bytes(&key, &p, &keylen) != 1) {
error_print();
return -1;
}
rewind(keyfp);
if (fwrite(keybuf, 1, keylen, keyfp) != keylen) {
error_print();
return -1;
}
#else
// write index only
if (fseek(keyfp, XMSS_PUBLIC_KEY_SIZE, SEEK_SET) != 0) {
error_print(); error_print();
goto end; goto end;
} }
uint8_t index_buf[4];
PUTU32(index_buf, key.index);
fwrite(index_buf, 1, 4, keyfp);
#endif
while (1) { while (1) {
uint8_t buf[1024]; uint8_t buf[1024];
@@ -209,10 +172,6 @@ bad:
end: end:
xmss_key_cleanup(&key); xmss_key_cleanup(&key);
gmssl_secure_clear(&ctx, sizeof(ctx)); gmssl_secure_clear(&ctx, sizeof(ctx));
if (keybuf) {
gmssl_secure_clear(keybuf, keylen);
free(keybuf);
}
if (keyfp) fclose(keyfp); if (keyfp) fclose(keyfp);
if (infp && infp != stdin) fclose(infp); if (infp && infp != stdin) fclose(infp);
if (outfp && outfp != stdout) fclose(outfp); if (outfp && outfp != stdout) fclose(outfp);