Update XMSS-SM3

XMSS is in developing, not fully tested
This commit is contained in:
Zhi Guan
2025-12-08 10:24:00 +08:00
parent d3dd07e885
commit bae8f54667
11 changed files with 573 additions and 276 deletions

View File

@@ -18,7 +18,7 @@
#include <gmssl/hex.h>
#include <gmssl/rand.h>
#include <gmssl/error.h>
#include <gmssl/sm3_xmss.h>
#include <gmssl/xmss.h>
#define uint32_from_bytes(ptr) \
@@ -134,6 +134,8 @@ static void hash256_prf_keygen_init(HASH256_CTX *hash256_ctx, const uint8_t key[
hash256_update(hash256_ctx, key, 32);
}
// compute wots+ chain, start from the secret x or a signature, output the final value
// 最终的这个结果应该不是PK啊
static void wots_chain(const uint8_t x[32], int start, int steps,
const HASH256_CTX *prf_seed_ctx, const uint8_t in_adrs[32], uint8_t pk[32])
{
@@ -173,7 +175,7 @@ static void wots_chain(const uint8_t x[32], int start, int steps,
memcpy(pk, state, 32);
}
void sm3_wots_derive_sk(const uint8_t secret[32], const uint8_t seed[32], const uint8_t in_adrs[32], hash256_bytes_t sk[67])
void sm3_wots_derive_sk(const uint8_t secret[32], const uint8_t seed[32], const uint8_t in_adrs[32], hash256_t sk[67])
{
HASH256_CTX prf_keygen_ctx;
HASH256_CTX prf_ctx;
@@ -196,9 +198,9 @@ void sm3_wots_derive_sk(const uint8_t secret[32], const uint8_t seed[32], const
}
}
void sm3_wots_derive_pk(const hash256_bytes_t sk[67],
void sm3_wots_derive_pk(const hash256_t sk[67],
const HASH256_CTX *prf_seed_ctx, const uint8_t in_adrs[32],
hash256_bytes_t pk[67])
hash256_t pk[67])
{
uint8_t adrs[32];
int i;
@@ -238,9 +240,9 @@ static void base_w_and_checksum(const uint8_t dgst[32], uint8_t msg[67])
msg[66] = csum_bytes[1] >> 4;
}
void sm3_wots_do_sign(const hash256_bytes_t sk[67],
void sm3_wots_do_sign(const hash256_t sk[67],
const HASH256_CTX *prf_seed_ctx, const uint8_t in_adrs[32],
const uint8_t dgst[32], hash256_bytes_t sig[67])
const uint8_t dgst[32], hash256_t sig[67])
{
uint8_t adrs[32];
uint8_t msg[67];
@@ -256,9 +258,9 @@ void sm3_wots_do_sign(const hash256_bytes_t sk[67],
}
}
void sm3_wots_sig_to_pk(const hash256_bytes_t sig[67], const uint8_t dgst[32],
void sm3_wots_sig_to_pk(const hash256_t sig[67], const uint8_t dgst[32],
const HASH256_CTX *prf_seed_ctx, const uint8_t in_adrs[32],
hash256_bytes_t pk[67])
hash256_t pk[67])
{
uint8_t adrs[32];
uint8_t msg[67];
@@ -311,11 +313,11 @@ static void randomized_hash(const uint8_t left[32], const uint8_t right[32],
hash256_finish(&hash256_ctx, out);
}
static void build_ltree(const hash256_bytes_t in_pk[67],
static void build_ltree(const hash256_t in_pk[67],
const HASH256_CTX *prf_seed_ctx, const uint8_t in_adrs[32],
uint8_t wots_root[32])
{
hash256_bytes_t pk[67];
hash256_t pk[67];
uint8_t adrs[32];
uint32_t tree_height = 0;
int len = 67;
@@ -344,9 +346,9 @@ static void build_ltree(const hash256_bytes_t in_pk[67],
// len(tree) = 2^h - 1
// root = tree[len(tree) - 1] = tree[2^h - 2]
static void build_hash_tree(const hash256_bytes_t *leaves, int height,
static void build_hash_tree(const hash256_t *leaves, int height,
const HASH256_CTX *prf_seed_ctx, const uint8_t in_adrs[32],
hash256_bytes_t *tree)
hash256_t *tree)
{
uint8_t adrs[32];
int n = 1 << height;
@@ -368,9 +370,9 @@ static void build_hash_tree(const hash256_bytes_t *leaves, int height,
}
}
void sm3_xmss_derive_root(const uint8_t xmss_secret[32], int height,
void xmss_derive_root(const uint8_t xmss_secret[32], int height,
const uint8_t seed[32],
hash256_bytes_t *tree, uint8_t xmss_root[32])
hash256_t *tree, uint8_t xmss_root[32])
{
HASH256_CTX prf_keygen_ctx;
HASH256_CTX prf_seed_ctx;
@@ -383,8 +385,8 @@ void sm3_xmss_derive_root(const uint8_t xmss_secret[32], int height,
// generate all the wots pk[]
for (i = 0; i < (uint32_t)(1<<height); i++) {
//HASH256_CTX prf_ctx = prf_keygen_ctx;
hash256_bytes_t wots_sk[67];
hash256_bytes_t wots_pk[67];
hash256_t wots_sk[67];
hash256_t wots_pk[67];
// xmss_secret => wots_sk[0..67] => wots_pk[0..67]
// follow github.com/XMSS/xmss-reference
@@ -405,7 +407,7 @@ void sm3_xmss_derive_root(const uint8_t xmss_secret[32], int height,
memcpy(xmss_root, tree + (1 << (height + 1)) - 2, 32);
}
static void build_auth_path(const hash256_bytes_t *tree, int height, int index, hash256_bytes_t *path)
static void build_auth_path(const hash256_t *tree, int height, int index, hash256_t *path)
{
int h;
for (h = 0; h < height; h++) {
@@ -415,16 +417,16 @@ static void build_auth_path(const hash256_bytes_t *tree, int height, int index,
}
}
void sm3_xmss_do_sign(const uint8_t xmss_secret[32], int index,
void xmss_do_sign(const uint8_t xmss_secret[32], int index,
const uint8_t seed[32], const uint8_t in_adrs[32], int height,
const hash256_bytes_t *tree,
const hash256_t *tree,
const uint8_t dgst[32],
hash256_bytes_t wots_sig[67],
hash256_bytes_t *auth_path)
hash256_t wots_sig[67],
hash256_t *auth_path)
{
HASH256_CTX prf_seed_ctx;
uint8_t adrs[32];
hash256_bytes_t wots_sk[67];
hash256_t wots_sk[67];
hash256_prf_init(&prf_seed_ctx, seed);
memcpy(adrs, in_adrs, 32);
@@ -439,14 +441,14 @@ void sm3_xmss_do_sign(const uint8_t xmss_secret[32], int index,
build_auth_path(tree, height, index, auth_path);
}
void sm3_xmss_sig_to_root(const hash256_bytes_t wots_sig[67], int index, const hash256_bytes_t *auth_path,
void xmss_sig_to_root(const hash256_t wots_sig[67], int index, const hash256_t *auth_path,
const uint8_t seed[32], const uint8_t in_adrs[32], int height,
const uint8_t dgst[32],
uint8_t xmss_root[32])
{
HASH256_CTX prf_seed_ctx;
uint8_t adrs[32];
hash256_bytes_t wots_pk[67];
hash256_t wots_pk[67];
uint8_t *node = xmss_root;
int h;
@@ -475,7 +477,7 @@ void sm3_xmss_sig_to_root(const hash256_bytes_t wots_sig[67], int index, const h
}
}
int sm3_xmss_height_from_oid(uint32_t *height, uint32_t id)
int xmss_height_from_oid(uint32_t *height, uint32_t id)
{
switch (id) {
case XMSS_SM3_10: *height = 10; break;
@@ -491,11 +493,27 @@ int sm3_xmss_height_from_oid(uint32_t *height, uint32_t id)
return 1;
}
int sm3_xmss_key_generate(SM3_XMSS_KEY *key, uint32_t oid)
int xmss_oid_to_height(uint32_t oid, size_t *height)
{
switch (oid) {
case XMSS_SM3_10: *height = 10; break;
case XMSS_SM3_16: *height = 16; break;
case XMSS_SM3_20: *height = 20; break;
case XMSS_SHA256_10: *height = 10; break;
case XMSS_SHA256_16: *height = 16; break;
case XMSS_SHA256_20: *height = 20; break;
default:
error_print();
return -1;
}
return 1;
}
int xmss_key_generate(XMSS_KEY *key, uint32_t oid)
{
uint32_t height;
if (sm3_xmss_height_from_oid(&height, oid) != 1) {
if (xmss_height_from_oid(&height, oid) != 1) {
error_print();
return -1;
}
@@ -509,12 +527,12 @@ int sm3_xmss_key_generate(SM3_XMSS_KEY *key, uint32_t oid)
error_print();
return -1;
}
sm3_xmss_derive_root(key->secret, height, key->seed, key->tree, key->root);
xmss_derive_root(key->secret, height, key->seed, key->tree, key->root);
return 1;
}
void sm3_xmss_key_cleanup(SM3_XMSS_KEY *key)
void xmss_key_cleanup(XMSS_KEY *key)
{
if (key->tree) {
free(key->tree);
@@ -522,7 +540,7 @@ void sm3_xmss_key_cleanup(SM3_XMSS_KEY *key)
gmssl_secure_clear(key, sizeof(*key));
}
int sm3_xmss_key_print(FILE *fp, int fmt, int ind, const char *label, const SM3_XMSS_KEY *key)
int xmss_key_print(FILE *fp, int fmt, int ind, const char *label, const XMSS_KEY *key)
{
format_print(fp, fmt, ind, "%s\n", label);
ind += 4;
@@ -535,16 +553,18 @@ int sm3_xmss_key_print(FILE *fp, int fmt, int ind, const char *label, const SM3_
return 1;
}
int sm3_xmss_key_get_height(const SM3_XMSS_KEY *key, uint32_t *height)
int xmss_key_get_height(const XMSS_KEY *key, uint32_t *height)
{
if (sm3_xmss_height_from_oid(height, key->oid) != 1) {
if (xmss_height_from_oid(height, key->oid) != 1) {
error_print();
return -1;
}
return 1;
}
int sm3_xmss_key_to_bytes(const SM3_XMSS_KEY *key, uint8_t *out, size_t *outlen)
// save the full tree, should use a flag to choose cache tree or not
int xmss_key_to_bytes(const XMSS_KEY *key, uint8_t *out, size_t *outlen)
{
uint32_t height;
size_t tree_size;
@@ -555,7 +575,7 @@ int sm3_xmss_key_to_bytes(const SM3_XMSS_KEY *key, uint8_t *out, size_t *outlen)
return -1;
}
if (sm3_xmss_height_from_oid(&height, key->oid) != 1) {
if (xmss_height_from_oid(&height, key->oid) != 1) {
error_print();
return -1;
}
@@ -583,7 +603,7 @@ int sm3_xmss_key_to_bytes(const SM3_XMSS_KEY *key, uint8_t *out, size_t *outlen)
return 1;
}
int sm3_xmss_key_from_bytes(SM3_XMSS_KEY *key, const uint8_t *in, size_t inlen)
int xmss_key_from_bytes(XMSS_KEY *key, const uint8_t *in, size_t inlen)
{
uint32_t height;
size_t tree_size;
@@ -596,7 +616,7 @@ int sm3_xmss_key_from_bytes(SM3_XMSS_KEY *key, const uint8_t *in, size_t inlen)
p = in;
key->oid = uint32_from_bytes(p); p += 4;
if (sm3_xmss_height_from_oid(&height, key->oid) != 1) {
if (xmss_height_from_oid(&height, key->oid) != 1) {
error_print();
return -1;
}
@@ -624,7 +644,7 @@ int sm3_xmss_key_from_bytes(SM3_XMSS_KEY *key, const uint8_t *in, size_t inlen)
return 1;
}
int sm3_xmss_public_key_to_bytes(const SM3_XMSS_KEY *key, uint8_t *out, size_t *outlen)
int xmss_public_key_to_bytes(const XMSS_KEY *key, uint8_t *out, size_t *outlen)
{
uint32_t height;
uint8_t *p;
@@ -634,7 +654,7 @@ int sm3_xmss_public_key_to_bytes(const SM3_XMSS_KEY *key, uint8_t *out, size_t *
return -1;
}
if (sm3_xmss_height_from_oid(&height, key->oid) != 1) {
if (xmss_height_from_oid(&height, key->oid) != 1) {
error_print();
return -1;
}
@@ -654,7 +674,7 @@ int sm3_xmss_public_key_to_bytes(const SM3_XMSS_KEY *key, uint8_t *out, size_t *
}
// FIXME: check input length
int sm3_xmss_public_key_from_bytes(SM3_XMSS_KEY *key, const uint8_t *in, size_t inlen)
int xmss_public_key_from_bytes(XMSS_KEY *key, const uint8_t *in, size_t inlen)
{
uint32_t height;
const uint8_t *p;
@@ -665,7 +685,7 @@ int sm3_xmss_public_key_from_bytes(SM3_XMSS_KEY *key, const uint8_t *in, size_t
}
p = in;
key->oid = uint32_from_bytes(p); p += 4;
if (sm3_xmss_height_from_oid(&height, key->oid) != 1) {
if (xmss_height_from_oid(&height, key->oid) != 1) {
error_print();
return -1;
}
@@ -674,10 +694,10 @@ int sm3_xmss_public_key_from_bytes(SM3_XMSS_KEY *key, const uint8_t *in, size_t
return 1;
}
int sm3_xmss_signature_print(FILE *fp, int fmt, int ind, const char *label, const uint8_t *in, size_t inlen)
int xmss_signature_print(FILE *fp, int fmt, int ind, const char *label, const uint8_t *in, size_t inlen)
{
uint32_t index;
SM3_XMSS_SIGNATURE *sig = (SM3_XMSS_SIGNATURE *)in;
XMSS_SIGNATURE *sig = (XMSS_SIGNATURE *)in;
int i;
format_print(fp, fmt, ind, "%s\n", label);
@@ -693,7 +713,7 @@ int sm3_xmss_signature_print(FILE *fp, int fmt, int ind, const char *label, cons
}
format_print(fp, fmt, ind, "auth_path\n");
assert(sizeof(SM3_XMSS_SIGNATURE) == 4 + 32 * (68 + 20));
assert(sizeof(XMSS_SIGNATURE) == 4 + 32 * (68 + 20));
inlen -= 4 + 32 * 68;
for (i = 0; i < 20 && inlen >= 32; i++) {
format_print(fp, fmt, ind+4, "%d ", i);
@@ -704,7 +724,31 @@ int sm3_xmss_signature_print(FILE *fp, int fmt, int ind, const char *label, cons
return 1;
}
int sm3_xmss_sign_init(SM3_XMSS_SIGN_CTX *ctx, const SM3_XMSS_KEY *key)
// (4 + n + (len + h) * n)
int xmss_signature_size(uint32_t oid, size_t *siglen)
{
size_t height;
if (!siglen) {
error_print();
return -1;
}
if (xmss_oid_to_height(oid, &height) != 1) {
error_print();
return -1;
}
*siglen = 4 // OID
+ 32 // random
+ 32 * 67 // WOTS signature size
+ 32 * height // path
;
return 1;
}
int xmss_sign_init(XMSS_SIGN_CTX *ctx, const XMSS_KEY *key)
{
HASH256_CTX prf_ctx;
uint8_t hash_id[32] = {0};
@@ -727,7 +771,7 @@ int sm3_xmss_sign_init(SM3_XMSS_SIGN_CTX *ctx, const SM3_XMSS_KEY *key)
return 1;
}
int sm3_xmss_sign_update(SM3_XMSS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen)
int xmss_sign_update(XMSS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen)
{
if (data && datalen) {
hash256_update(&ctx->hash256_ctx, data, datalen);
@@ -735,17 +779,17 @@ int sm3_xmss_sign_update(SM3_XMSS_SIGN_CTX *ctx, const uint8_t *data, size_t dat
return 1;
}
int sm3_xmss_sign_finish(SM3_XMSS_SIGN_CTX *ctx, const SM3_XMSS_KEY *key, uint8_t *sigbuf, size_t *siglen)
int xmss_sign_finish(XMSS_SIGN_CTX *ctx, const XMSS_KEY *key, uint8_t *sigbuf, size_t *siglen)
{
SM3_XMSS_SIGNATURE *sig = (SM3_XMSS_SIGNATURE *)sigbuf;
XMSS_SIGNATURE *sig = (XMSS_SIGNATURE *)sigbuf;
uint8_t adrs[32] = {0};
uint8_t dgst[32];
uint32_t height;
hash256_finish(&ctx->hash256_ctx, dgst);
sm3_xmss_key_get_height(key, &height);
sm3_xmss_do_sign(key->secret, key->index, key->seed, adrs, height, key->tree, dgst,
xmss_key_get_height(key, &height);
xmss_do_sign(key->secret, key->index, key->seed, adrs, height, key->tree, dgst,
sig->wots_sig, sig->auth_path);
uint32_to_bytes(key->index, sig->index);
@@ -755,9 +799,9 @@ int sm3_xmss_sign_finish(SM3_XMSS_SIGN_CTX *ctx, const SM3_XMSS_KEY *key, uint8_
return 1;
}
int sm3_xmss_verify_init(SM3_XMSS_SIGN_CTX *ctx, const SM3_XMSS_KEY *key, const uint8_t *sigbuf, size_t siglen)
int xmss_verify_init(XMSS_SIGN_CTX *ctx, const XMSS_KEY *key, const uint8_t *sigbuf, size_t siglen)
{
SM3_XMSS_SIGNATURE *sig = (SM3_XMSS_SIGNATURE *)sigbuf;
XMSS_SIGNATURE *sig = (XMSS_SIGNATURE *)sigbuf;
uint8_t hash_id[32] = {0};
uint8_t index_buf[32] = {0};
@@ -773,7 +817,7 @@ int sm3_xmss_verify_init(SM3_XMSS_SIGN_CTX *ctx, const SM3_XMSS_KEY *key, const
return 1;
}
int sm3_xmss_verify_update(SM3_XMSS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen)
int xmss_verify_update(XMSS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen)
{
if (data && datalen) {
hash256_update(&ctx->hash256_ctx, data, datalen);
@@ -781,10 +825,10 @@ int sm3_xmss_verify_update(SM3_XMSS_SIGN_CTX *ctx, const uint8_t *data, size_t d
return 1;
}
int sm3_xmss_verify_finish(SM3_XMSS_SIGN_CTX *ctx, const SM3_XMSS_KEY *key, const uint8_t *sigbuf, size_t siglen)
int xmss_verify_finish(XMSS_SIGN_CTX *ctx, const XMSS_KEY *key, const uint8_t *sigbuf, size_t siglen)
{
const SM3_XMSS_SIGNATURE *sig = (const SM3_XMSS_SIGNATURE *)sigbuf;
const XMSS_SIGNATURE *sig = (const XMSS_SIGNATURE *)sigbuf;
uint8_t adrs[32] = {0};
uint8_t dgst[32];
uint32_t index, height;
@@ -792,10 +836,10 @@ int sm3_xmss_verify_finish(SM3_XMSS_SIGN_CTX *ctx, const SM3_XMSS_KEY *key, cons
hash256_finish(&ctx->hash256_ctx, dgst);
sm3_xmss_key_get_height(key, &height);
xmss_key_get_height(key, &height);
index = uint32_from_bytes(sig->index);
sm3_xmss_sig_to_root(sig->wots_sig, index, sig->auth_path, key->seed, adrs, height, dgst, xmss_root);
xmss_sig_to_root(sig->wots_sig, index, sig->auth_path, key->seed, adrs, height, dgst, xmss_root);
if (memcmp(xmss_root, key->root, 32) != 0) {
error_print();