Add LMS key_update callback

This commit is contained in:
Zhi Guan
2026-01-18 12:12:45 +08:00
parent 47639a9e23
commit 9488128154
9 changed files with 355 additions and 374 deletions

334
src/lms.c
View File

@@ -14,8 +14,6 @@
#include <gmssl/lms.h>
#include <gmssl/x509_alg.h>
/*
* TODO:
* 1. add key_update callback
@@ -343,18 +341,31 @@ void lms_derive_merkle_root(const lms_hash256_t seed, const uint8_t I[16], int h
int lms_public_key_to_bytes(const LMS_KEY *key, uint8_t **out, size_t *outlen)
{
if (!key || !outlen) {
if (!key) {
error_print();
return -1;
}
if (lms_public_key_to_bytes_ex(&key->public_key, out, outlen) != 1) {
error_print();
return -1;
}
return 1;
}
int lms_public_key_to_bytes_ex(const LMS_PUBLIC_KEY *public_key, uint8_t **out, size_t *outlen)
{
if (!public_key || !outlen) {
error_print();
return -1;
}
if (out && *out) {
PUTU32(*out, key->public_key.lms_type);
PUTU32(*out, public_key->lms_type);
*out += 4;
PUTU32(*out, key->public_key.lmots_type);
PUTU32(*out, public_key->lmots_type);
*out += 4;
memcpy(*out, key->public_key.I, 16);
memcpy(*out, public_key->I, 16);
*out += 16;
memcpy(*out, key->public_key.root, 32);
memcpy(*out, public_key->root, 32);
*out += 32;
}
*outlen += LMS_PUBLIC_KEY_SIZE;
@@ -377,9 +388,9 @@ int lms_private_key_to_bytes(const LMS_KEY *key, uint8_t **out, size_t *outlen)
return 1;
}
int lms_public_key_from_bytes(LMS_KEY *key, const uint8_t **in, size_t *inlen)
int lms_public_key_from_bytes_ex(LMS_PUBLIC_KEY *public_key, const uint8_t **in, size_t *inlen)
{
if (!key || !in || !(*in) || !inlen) {
if (!public_key || !in || !(*in) || !inlen) {
error_print();
return -1;
}
@@ -388,35 +399,49 @@ int lms_public_key_from_bytes(LMS_KEY *key, const uint8_t **in, size_t *inlen)
return -1;
}
memset(key, 0, sizeof(*key));
memset(public_key, 0, sizeof(LMS_PUBLIC_KEY));
key->public_key.lms_type = GETU32(*in);
if (!lms_type_name(key->public_key.lms_type)) {
public_key->lms_type = GETU32(*in);
if (!lms_type_name(public_key->lms_type)) {
error_print();
return -1;
}
*in += 4;
*inlen -= 4;
key->public_key.lmots_type = GETU32(*in);
if (!lmots_type_name(key->public_key.lmots_type)) {
public_key->lmots_type = GETU32(*in);
if (!lmots_type_name(public_key->lmots_type)) {
error_print();
return -1;
}
*in += 4;
*inlen -= 4;
memcpy(key->public_key.I, *in, 16);
memcpy(public_key->I, *in, 16);
*in += 16;
*inlen -= 16;
memcpy(key->public_key.root, *in, 32);
memcpy(public_key->root, *in, 32);
*in += 32;
*inlen -= 32;
return 1;
}
int lms_public_key_from_bytes(LMS_KEY *key, const uint8_t **in, size_t *inlen)
{
if (!key) {
error_print();
return -1;
}
memset(key, 0, sizeof(LMS_KEY));
if (lms_public_key_from_bytes_ex(&key->public_key, in, inlen) != 1) {
error_print();
return -1;
}
return 1;
}
int lms_key_check(const LMS_KEY *key, const LMS_PUBLIC_KEY *pub)
{
// FIXME: implement this
@@ -494,22 +519,22 @@ int lms_private_key_from_bytes(LMS_KEY *key, const uint8_t **in, size_t *inlen)
return 1;
}
int lms_public_key_print(FILE *fp, int fmt, int ind, const char *label, const LMS_PUBLIC_KEY *pub)
int lms_public_key_print(FILE *fp, int fmt, int ind, const char *label, const LMS_KEY *pub)
{
format_print(fp, fmt, ind, "%s\n", label);
ind += 4;
format_print(fp, fmt, ind, "lms_type: %s\n", lms_type_name(pub->lms_type));
format_print(fp, fmt, ind, "lmots_type: %s\n", lmots_type_name(pub->lmots_type));
format_bytes(fp, fmt, ind, "I", pub->I, 16);
format_bytes(fp, fmt, ind, "root", pub->root, 32);
format_print(fp, fmt, ind, "lms_type: %s\n", lms_type_name(pub->public_key.lms_type));
format_print(fp, fmt, ind, "lmots_type: %s\n", lmots_type_name(pub->public_key.lmots_type));
format_bytes(fp, fmt, ind, "I", pub->public_key.I, 16);
format_bytes(fp, fmt, ind, "root", pub->public_key.root, 32);
return 1;
}
int lms_key_print(FILE *fp, int fmt, int ind, const char *label, const LMS_KEY *key)
int lms_private_key_print(FILE *fp, int fmt, int ind, const char *label, const LMS_KEY *key)
{
format_print(fp, fmt, ind, "%s\n", label);
ind += 4;
lms_public_key_print(fp, fmt, ind, "lms_public_key", &key->public_key);
lms_public_key_print(fp, fmt, ind, "lms_public_key", key);
format_bytes(fp, fmt, ind, "seed", key->seed, 32);
format_print(fp, fmt, ind, "q = %d\n", key->q);
if (key->tree && fmt) {
@@ -550,6 +575,8 @@ int lms_key_generate_ex(LMS_KEY *key, int lms_type, const lms_hash256_t seed, co
}
n = 1 << h;
memset(key, 0, sizeof(LMS_KEY));
key->public_key.lms_type = lms_type;
key->public_key.lmots_type = LMOTS_HASH256_N32_W8;
@@ -593,6 +620,47 @@ int lms_key_generate(LMS_KEY *key, int lms_type)
return 1;
}
int lms_key_set_update_callback(LMS_KEY *key, lms_key_update_callback update_cb, void *param)
{
if (!key) {
error_print();
return -1;
}
key->update_callback = update_cb;
key->update_param = param;
return 1;
}
int lms_key_update(LMS_KEY *key)
{
size_t height;
if (!key) {
error_print();
return -1;
}
if (lms_type_to_height(key->public_key.lms_type, &height) != 1) {
error_print();
return -1;
}
if (key->q < 0 || key->q > (1 << height)) {
error_print();
return -1;
}
if (key->q == (1 << height)) {
return 0;
}
key->q++;
if (key->update_callback) {
if (key->update_callback(key) != 1) {
error_print();
return -1;
}
}
return 1;
}
int lms_signature_size(int lms_type, size_t *len)
{
size_t height;
@@ -915,7 +983,10 @@ int lms_sign_init(LMS_SIGN_CTX *ctx, LMS_KEY *key)
lmots_derive_secrets(key->seed, key->public_key.I, key->q, lms_sig->lmots_sig.y);
// update key state, SHOULD not use the updated key->q
(key->q)++;
if (lms_key_update(key) != 1) {
error_print();
return -1;
}
lms_sig->lms_type = key->public_key.lms_type;
@@ -1121,29 +1192,12 @@ int lms_verify_finish(LMS_SIGN_CTX *ctx)
}
}
int hss_public_key_digest(const HSS_KEY *key, uint8_t dgst[32])
{
SM3_CTX ctx;
uint8_t bytes[HSS_PUBLIC_KEY_SIZE];
uint8_t *p = bytes;
size_t len;
if (hss_public_key_to_bytes(key, &p, &len) != 1) {
error_print();
return -1;
}
sm3_init(&ctx);
sm3_update(&ctx, bytes, sizeof(bytes));
sm3_finish(&ctx, dgst);
return 1;
}
int hss_public_key_print(FILE *fp, int fmt, int ind, const char *label, const HSS_KEY *key)
{
format_print(fp, fmt, ind, "%s\n", label);
ind += 4;
format_print(fp, fmt, ind, "levels: %d\n", key->levels);
lms_public_key_print(fp, fmt, ind, "lms_public_key", &key->lms_key[0].public_key);
lms_public_key_print(fp, fmt, ind, "lms_public_key", &key->lms_key[0]);
return 1;
}
@@ -1297,7 +1351,7 @@ int hss_private_key_from_bytes(HSS_KEY *key, const uint8_t **in, size_t *inlen)
return 1;
}
int hss_key_print(FILE *fp, int fmt, int ind, const char *label, const HSS_KEY *key)
int hss_private_key_print(FILE *fp, int fmt, int ind, const char *label, const HSS_KEY *key)
{
int i;
@@ -1305,14 +1359,14 @@ int hss_key_print(FILE *fp, int fmt, int ind, const char *label, const HSS_KEY *
ind += 4;
format_print(fp, fmt, ind, "levels: %d\n", key->levels);
lms_key_print(fp, fmt, ind, "lms_key[0]", &key->lms_key[0]);
lms_private_key_print(fp, fmt, ind, "lms_key[0]", &key->lms_key[0]);
for (i = 1; i < key->levels; i++) {
char title[64];
snprintf(title, sizeof(title), "lms_signature[%d]", i - 1);
lms_signature_print_ex(fp, fmt, ind, title, &key->lms_sig[i - 1]);
snprintf(title, sizeof(title), "lms_key[%d]", i);
lms_key_print(fp, fmt, ind, title, &key->lms_key[i]);
lms_private_key_print(fp, fmt, ind, title, &key->lms_key[i]);
}
return 1;
@@ -1485,13 +1539,13 @@ int hss_signature_from_bytes(HSS_SIGNATURE *sig, const uint8_t **in, size_t *inl
for (i = 0; i < sig->num_signed_public_keys; i++) {
LMS_SIGNATURE *lms_sig = &sig->signed_public_keys[i].lms_sig;
LMS_KEY *lms_key = (LMS_KEY *)&sig->signed_public_keys[i].lms_public_key;
LMS_PUBLIC_KEY *lms_key = &sig->signed_public_keys[i].lms_public_key;
if (lms_signature_from_bytes(lms_sig, in, inlen) != 1) {
error_print();
return -1;
}
if (lms_public_key_from_bytes(lms_key, in, inlen) != 1) {
if (lms_public_key_from_bytes_ex(lms_key, in, inlen) != 1) {
error_print();
return -1;
}
@@ -1528,7 +1582,7 @@ int hss_signature_to_bytes(const HSS_SIGNATURE *sig, uint8_t **out, size_t *outl
error_print();
return -1;
}
if (lms_public_key_to_bytes((LMS_KEY *)&sig->signed_public_keys[i].lms_public_key, out, &len) != 1) {
if (lms_public_key_to_bytes_ex(&sig->signed_public_keys[i].lms_public_key, out, &len) != 1) {
error_print();
return -1;
}
@@ -1544,6 +1598,16 @@ int hss_signature_to_bytes(const HSS_SIGNATURE *sig, uint8_t **out, size_t *outl
return 1;
}
int hss_key_set_update_callback(HSS_KEY *key, hss_key_update_callback update_cb, void *param)
{
if (!key) {
error_print();
return -1;
}
key->update_callback = update_cb;
key->update_param = param;
return 1;
}
int hss_key_update(HSS_KEY *key)
{
@@ -1564,11 +1628,6 @@ int hss_key_update(HSS_KEY *key)
}
// the lowest level is not out of keys
if (level >= key->levels) {
fprintf(stderr, "key->levels = %d\n", key->levels);
fprintf(stderr, "level out of key = %d\n", level);
error_print();
return -1;
}
@@ -1609,6 +1668,12 @@ int hss_key_update(HSS_KEY *key)
}
}
if (key->update_callback) {
if (key->update_callback(key) != 1) {
error_print();
return -1;
}
}
return 1;
}
@@ -1628,7 +1693,7 @@ int hss_sign_init(HSS_SIGN_CTX *ctx, HSS_KEY *key)
memset(ctx, 0, sizeof(*ctx));
if (lms_sign_init(&ctx->lms_ctx, &key->lms_key[key->levels - 1]) != 1) {
if (lms_sign_init(&ctx->lms_sign_ctx, &key->lms_key[key->levels - 1]) != 1) {
error_print();
return -1;
}
@@ -1661,7 +1726,7 @@ int hss_sign_update(HSS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen)
return -1;
}
if (data && datalen) {
if (lms_sign_update(&ctx->lms_ctx, data, datalen) != 1) {
if (lms_sign_update(&ctx->lms_sign_ctx, data, datalen) != 1) {
error_print();
return -1;
}
@@ -1707,7 +1772,7 @@ int hss_sign_finish_ex(HSS_SIGN_CTX *ctx, HSS_SIGNATURE *sig)
sig->signed_public_keys[i].lms_public_key = ctx->lms_public_keys[i];
}
if (lms_sign_finish_ex(&ctx->lms_ctx, &sig->msg_lms_sig) != 1) {
if (lms_sign_finish_ex(&ctx->lms_sign_ctx, &sig->msg_lms_sig) != 1) {
error_print();
return -1;
}
@@ -1739,7 +1804,7 @@ int hss_verify_init_ex(HSS_SIGN_CTX *ctx, const HSS_KEY *key, const HSS_SIGNATUR
}
if (sig->num_signed_public_keys == 0) {
if (lms_verify_init_ex(&ctx->lms_ctx, &key->lms_key[0],
if (lms_verify_init_ex(&ctx->lms_sign_ctx, &key->lms_key[0],
&sig->msg_lms_sig) != 1) {
error_print();
return -1;
@@ -1753,17 +1818,17 @@ int hss_verify_init_ex(HSS_SIGN_CTX *ctx, const HSS_KEY *key, const HSS_SIGNATUR
return -1;
}
if (lms_verify_init_ex(&ctx->lms_ctx, &key->lms_key[0],
if (lms_verify_init_ex(&ctx->lms_sign_ctx, &key->lms_key[0],
&sig->signed_public_keys[0].lms_sig) != 1) {
error_print();
return -1;
}
if (lms_verify_update(&ctx->lms_ctx, buf, len) != 1) {
if (lms_verify_update(&ctx->lms_sign_ctx, buf, len) != 1) {
error_print();
return -1;
}
if (lms_verify_finish(&ctx->lms_ctx) != 1) {
if (lms_verify_finish(&ctx->lms_sign_ctx) != 1) {
error_print();
return -1;
}
@@ -1777,25 +1842,25 @@ int hss_verify_init_ex(HSS_SIGN_CTX *ctx, const HSS_KEY *key, const HSS_SIGNATUR
return -1;
}
if (lms_verify_init_ex(&ctx->lms_ctx,
if (lms_verify_init_ex(&ctx->lms_sign_ctx,
(LMS_KEY *)&sig->signed_public_keys[i - 1].lms_public_key,
&sig->signed_public_keys[i].lms_sig) != 1) {
error_print();
return -1;
}
if (lms_verify_update(&ctx->lms_ctx, buf, len) != 1) {
if (lms_verify_update(&ctx->lms_sign_ctx, buf, len) != 1) {
error_print();
return -1;
}
if (lms_verify_finish(&ctx->lms_ctx) != 1) {
if (lms_verify_finish(&ctx->lms_sign_ctx) != 1) {
error_print();
return -1;
}
}
// verify(pk[last], msg, msg_sig)
if (lms_verify_init_ex(&ctx->lms_ctx,
if (lms_verify_init_ex(&ctx->lms_sign_ctx,
(LMS_KEY *)&sig->signed_public_keys[sig->num_signed_public_keys - 1].lms_public_key,
&sig->msg_lms_sig) != 1) {
error_print();
@@ -1837,7 +1902,7 @@ int hss_verify_update(HSS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen)
return -1;
}
if (data && datalen) {
if (lms_verify_update(&ctx->lms_ctx, data, datalen) != 1) {
if (lms_verify_update(&ctx->lms_sign_ctx, data, datalen) != 1) {
error_print();
return -1;
}
@@ -1852,7 +1917,7 @@ int hss_verify_finish(HSS_SIGN_CTX *ctx)
error_print();
return -1;
}
if ((ret = lms_verify_finish(&ctx->lms_ctx)) != 1) {
if ((ret = lms_verify_finish(&ctx->lms_sign_ctx)) != 1) {
error_print();
return ret;
}
@@ -1873,7 +1938,7 @@ int hss_signature_print_ex(FILE *fp, int fmt, int ind, const char *label, const
snprintf(title, sizeof(title), "lms_signature[%zu]", i);
lms_signature_print_ex(fp, fmt, ind, title, &sig->signed_public_keys[0].lms_sig);
snprintf(title, sizeof(title), "lms_public_key[%zu]", i + 1);
lms_public_key_print(fp, fmt, ind, title, &sig->signed_public_keys[0].lms_public_key);
lms_public_key_print(fp, fmt, ind, title, (LMS_KEY *)&sig->signed_public_keys[0]);
}
lms_signature_print_ex(fp, fmt, ind, "message_signature", &sig->msg_lms_sig);
@@ -1884,7 +1949,7 @@ int hss_signature_print(FILE *fp, int fmt, int ind, const char *label, const uin
{
LMS_SIGNATURE lms_sig;
size_t lms_siglen;
LMS_PUBLIC_KEY lms_pub;
LMS_KEY lms_key;
int num;
int i;
@@ -1913,12 +1978,12 @@ int hss_signature_print(FILE *fp, int fmt, int ind, const char *label, const uin
snprintf(title, sizeof(title), "lms_signature[%d]", i);
lms_signature_print_ex(fp, fmt, ind, title, &lms_sig);
if (lms_public_key_from_bytes((LMS_KEY *)&lms_pub, &sig, &siglen) != 1) {
if (lms_public_key_from_bytes(&lms_key, &sig, &siglen) != 1) {
error_print();
return -1;
}
snprintf(title, sizeof(title), "lms_public_key[%d]", i + 1);
lms_public_key_print(fp, fmt, ind, title, &lms_pub);
lms_public_key_print(fp, fmt, ind, title, &lms_key);
}
if (lms_signature_from_bytes(&lms_sig, &sig, &siglen) != 1) {
@@ -1956,132 +2021,9 @@ int hss_private_key_size(const int *lms_types, size_t levels, size_t *len)
return 1;
}
// X.509 related
/*
SubjectPublicKeyInfo ::= SEQUENCE {
algorithm AlgorithmIdentifier,
subjectPublicKey BIT STRING
}
in RFC 8708 (HSS/LMS in CMS)
only HSS has OID id-alg-hss-lms-hashsig (1.2.840.113549.1.9.16.3.17)
so only HSS public key is supported in x509_ functions
hss_public_key_to_bytes encode HSS_PUBLIC_KEY to SubjectPublicKeyInfo.subjectPublicKey
the public_key raw data (not OCTET STRING TLV) is encoded into a BIT STRING TLV
same as EC_POINT
*/
int hss_public_key_to_der(const HSS_KEY *key, uint8_t **out, size_t *outlen)
void hss_sign_ctx_cleanup(HSS_SIGN_CTX *ctx)
{
uint8_t octets[HSS_PUBLIC_KEY_SIZE];
uint8_t *p = octets;
size_t len = 0;
if (!key) {
return 0;
if (ctx) {
lms_sign_ctx_cleanup(&ctx->lms_sign_ctx);
}
if (hss_public_key_to_bytes(key, &p, &len) != 1) {
error_print();
return -1;
}
if (len != sizeof(octets)) {
error_print();
return -1;
}
if (asn1_bit_octets_to_der(octets, sizeof(octets), out, outlen) != 1) {
error_print();
return -1;
}
return 1;
}
int hss_public_key_from_der(HSS_KEY *key, const uint8_t **in, size_t *inlen)
{
int ret;
const uint8_t *d;
size_t dlen;
if ((ret = asn1_bit_octets_from_der(&d, &dlen, in, inlen)) != 1) {
if (ret < 0) error_print();
return ret;
}
if (dlen != HSS_PUBLIC_KEY_SIZE) {
error_print();
return -1;
}
if (hss_public_key_from_bytes(key, &d, &dlen) != 1) {
error_print();
return -1;
}
if (dlen) {
error_print();
return -1;
}
return 1;
}
int hss_public_key_algor_to_der(uint8_t **out, size_t *outlen)
{
if (x509_public_key_algor_to_der(OID_hss_lms_hashsig, OID_undef, out, outlen) != 1) {
error_print();
return -1;
}
return 1;
}
int hss_public_key_algor_from_der(const uint8_t **in, size_t *inlen)
{
int ret;
int oid;
int param = OID_undef;
if ((ret = x509_public_key_algor_from_der(&oid, &param, in, inlen)) != 1) {
if (ret < 0) error_print();
return ret;
}
if (oid != OID_hss_lms_hashsig) {
error_print();
return -1;
}
// param == 0: parameter is empty
// param == 1: parameter is null object
// param == other values, x509_public_key_algor_from_der fail
return 1;
}
int hss_public_key_info_to_der(const HSS_KEY *key, uint8_t **out, size_t *outlen)
{
size_t len = 0;
if (hss_public_key_algor_to_der(NULL, &len) != 1
|| hss_public_key_to_der(key, NULL, &len) != 1
|| asn1_sequence_header_to_der(len, out, outlen) != 1
|| hss_public_key_algor_to_der(out, outlen) != 1
|| hss_public_key_to_der(key, out, outlen) != 1) {
error_print();
return -1;
}
return 1;
}
int hss_public_key_info_from_der(HSS_KEY *key, const uint8_t **in, size_t *inlen)
{
int ret;
const uint8_t *d;
size_t dlen;
if ((ret = asn1_sequence_from_der(&d, &dlen, in, inlen)) != 1) {
if (ret < 0) error_print();
return ret;
}
if (hss_public_key_algor_from_der(&d, &dlen) != 1
|| hss_public_key_from_der(key, &d, &dlen) != 1
|| asn1_length_is_zero(dlen) != 1) {
error_print();
return -1;
}
return 1;
}

View File

@@ -162,7 +162,7 @@ int x509_public_key_print(FILE *fp, int fmt, int ind, const char *label, const X
}
break;
case OID_lms_hashsig:
if (lms_public_key_print(fp, fmt, ind, label, &key->u.lms_key.public_key) != 1) {
if (lms_public_key_print(fp, fmt, ind, label, &key->u.lms_key) != 1) {
error_print();
return -1;
}