From 94881281543bd44734ae437fed4807fd2b7e3cbe Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Sun, 18 Jan 2026 12:12:45 +0800 Subject: [PATCH] Add LMS key_update callback --- include/gmssl/lms.h | 266 +++++++++++++++++------------------ src/lms.c | 334 ++++++++++++++++++-------------------------- src/x509_key.c | 2 +- tests/lmstest.c | 19 +-- tools/gmssl.c | 6 +- tools/hsssign.c | 47 +++++-- tools/lmskeygen.c | 2 +- tools/lmssign.c | 51 +++++-- tools/lmsverify.c | 2 +- 9 files changed, 355 insertions(+), 374 deletions(-) diff --git a/include/gmssl/lms.h b/include/gmssl/lms.h index f64ca9b7..f0b04974 100644 --- a/include/gmssl/lms.h +++ b/include/gmssl/lms.h @@ -1,5 +1,5 @@ /* - * Copyright 2014-2025 The GmSSL Project. All Rights Reserved. + * Copyright 2014-2026 The GmSSL Project. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the License); you may * not use this file except in compliance with the License. @@ -31,79 +31,48 @@ extern "C" { typedef uint8_t lms_hash256_t[32]; - // Crosscheck with data from LMS-reference (SHA-256), except the LMS signature. -#if defined(ENABLE_LMS_CROSSCHECK) && defined(ENABLE_SHA2) && !defined(LMS_HASH256_CTX) -# define LMS_HASH256_CTX SHA256_CTX -# define lms_hash256_init sha256_init -# define lms_hash256_update sha256_update -# define lms_hash256_finish sha256_finish +#if defined(ENABLE_LMS_CROSSCHECK) && defined(ENABLE_SHA2) +#define LMS_HASH256_CTX SHA256_CTX +#define lms_hash256_init sha256_init +#define lms_hash256_update sha256_update +#define lms_hash256_finish sha256_finish #else -# define LMS_HASH256_CTX SM3_CTX -# define lms_hash256_init sm3_init -# define lms_hash256_update sm3_update -# define lms_hash256_finish sm3_finish +#define LMS_HASH256_CTX SM3_CTX +#define lms_hash256_init sm3_init +#define lms_hash256_update sm3_update +#define lms_hash256_finish sm3_finish #endif #if defined(ENABLE_LMS_CROSSCHECK) && defined(ENABLE_SHA2) -# define LMOTS_HASH256_N32_W8 LMOTS_SHA256_N32_W8 -# define LMOTS_HASH256_N32_W8_NAME "LMOTS_SHA256_N32_W8" -# define LMS_HASH256_M32_H5 LMS_SHA256_M32_H5 -# define LMS_HASH256_M32_H10 LMS_SHA256_M32_H10 -# define LMS_HASH256_M32_H15 LMS_SHA256_M32_H15 -# define LMS_HASH256_M32_H20 LMS_SHA256_M32_H20 -# define LMS_HASH256_M32_H25 LMS_SHA256_M32_H25 -# define LMS_HASH256_M32_H5_NAME "LMS_SHA256_M32_H5" -# define LMS_HASH256_M32_H10_NAME "LMS_SHA256_M32_H10" -# define LMS_HASH256_M32_H15_NAME "LMS_SHA256_M32_H15" -# define LMS_HASH256_M32_H20_NAME "LMS_SHA256_M32_H20" -# define LMS_HASH256_M32_H25_NAME "LMS_SHA256_M32_H25" -#else -# define LMOTS_HASH256_N32_W8 LMOTS_SM3_N32_W8 -# define LMOTS_HASH256_N32_W8_NAME "LMOTS_SM3_N32_W8" -# define LMS_HASH256_M32_H5 LMS_SM3_M32_H5 -# define LMS_HASH256_M32_H10 LMS_SM3_M32_H10 -# define LMS_HASH256_M32_H15 LMS_SM3_M32_H15 -# define LMS_HASH256_M32_H20 LMS_SM3_M32_H20 -# define LMS_HASH256_M32_H25 LMS_SM3_M32_H25 -# define LMS_HASH256_M32_H5_NAME "LMS_SM3_M32_H5" -# define LMS_HASH256_M32_H10_NAME "LMS_SM3_M32_H10" -# define LMS_HASH256_M32_H15_NAME "LMS_SM3_M32_H15" -# define LMS_HASH256_M32_H20_NAME "LMS_SM3_M32_H20" -# define LMS_HASH256_M32_H25_NAME "LMS_SM3_M32_H25" -#endif - - enum { - LMOTS_RESERVED = 0, - LMOTS_SHA256_N32_W1 = 1, - LMOTS_SHA256_N32_W2 = 2, - LMOTS_SHA256_N32_W4 = 3, + //LMOTS_SHA256_N32_W1 = 1, + //LMOTS_SHA256_N32_W2 = 2, + //LMOTS_SHA256_N32_W4 = 3, LMOTS_SHA256_N32_W8 = 4, - LMOTS_SM3_N32_W1 = 11, - LMOTS_SM3_N32_W2 = 12, - LMOTS_SM3_N32_W4 = 13, +}; +#define LMOTS_HASH256_N32_W8 LMOTS_SHA256_N32_W8 +#define LMOTS_HASH256_N32_W8_NAME "LMOTS_SHA256_N32_W8" +#else +enum { + //LMOTS_SM3_N32_W1 = 11, + //LMOTS_SM3_N32_W2 = 12, + //LMOTS_SM3_N32_W4 = 13, LMOTS_SM3_N32_W8 = 14, }; - -enum { -#if defined(ENABLE_LMS_CROSSCHECK) && defined(ENABLE_SHA2) - LMS_SHA256_M32_H5 = 5, - LMS_SHA256_M32_H10 = 6, - LMS_SHA256_M32_H15 = 7, - LMS_SHA256_M32_H20 = 8, - LMS_SHA256_M32_H25 = 9, -#else - // TODO: submit to IETF - LMS_SM3_M32_H5 = 5, - LMS_SM3_M32_H10 = 6, - LMS_SM3_M32_H15 = 7, - LMS_SM3_M32_H20 = 8, - LMS_SM3_M32_H25 = 9, +#define LMOTS_HASH256_N32_W8 LMOTS_SM3_N32_W8 +#define LMOTS_HASH256_N32_W8_NAME "LMOTS_SM3_N32_W8" #endif -}; +// in LMS, we use Winternitz w = 2^8 = 256 +// represent 256-bit hash as 256/8 = 32 base_w numbers +// max checksum is 255 * 32 = 8160 < 2^13 = 8192, so checksum need two 8-bit base_w number +// so total hash chains is 32 + 2 = 34 +#define LMOTS_NUM_CHAINS 34 + +typedef lms_hash256_t lmots_key_t[34]; +typedef lms_hash256_t lmots_sig_t[34]; char *lmots_type_name(int lmots_type); void lmots_derive_secrets(const lms_hash256_t seed, const uint8_t I[16], int q, lms_hash256_t x[34]); @@ -112,12 +81,54 @@ void lmots_compute_signature(const uint8_t I[16], int q, const lms_hash256_t dgs void lmots_signature_to_public_hash(const uint8_t I[16], int q, const lms_hash256_t y[34], const lms_hash256_t dgst, lms_hash256_t pub); -char *lms_type_name(int lms_type); -int lms_type_from_name(const char *name); -int lms_type_to_height(int type, size_t *height); -void lms_derive_merkle_tree(const lms_hash256_t seed, const uint8_t I[16], int height, lms_hash256_t *tree); -void lms_derive_merkle_root(const lms_hash256_t seed, const uint8_t I[16], int height, lms_hash256_t root); +#if defined(ENABLE_LMS_CROSSCHECK) && defined(ENABLE_SHA2) +enum { + LMS_SHA256_M32_H5 = 5, + LMS_SHA256_M32_H10 = 6, + LMS_SHA256_M32_H15 = 7, + LMS_SHA256_M32_H20 = 8, + LMS_SHA256_M32_H25 = 9, +}; +#else +// TODO: submit to IETF +enum { + LMS_SM3_M32_H5 = 5, + LMS_SM3_M32_H10 = 6, + LMS_SM3_M32_H15 = 7, + LMS_SM3_M32_H20 = 8, + LMS_SM3_M32_H25 = 9, +}; +#endif +#if defined(ENABLE_LMS_CROSSCHECK) && defined(ENABLE_SHA2) +# define LMS_HASH256_M32_H5 LMS_SHA256_M32_H5 +# define LMS_HASH256_M32_H5_NAME "LMS_SHA256_M32_H5" +# define LMS_HASH256_M32_H10 LMS_SHA256_M32_H10 +# define LMS_HASH256_M32_H10_NAME "LMS_SHA256_M32_H10" +# define LMS_HASH256_M32_H15 LMS_SHA256_M32_H15 +# define LMS_HASH256_M32_H15_NAME "LMS_SHA256_M32_H15" +# define LMS_HASH256_M32_H20 LMS_SHA256_M32_H20 +# define LMS_HASH256_M32_H20_NAME "LMS_SHA256_M32_H20" +# define LMS_HASH256_M32_H25 LMS_SHA256_M32_H25 +# define LMS_HASH256_M32_H25_NAME "LMS_SHA256_M32_H25" +#else +# define LMS_HASH256_M32_H5 LMS_SM3_M32_H5 +# define LMS_HASH256_M32_H5_NAME "LMS_SM3_M32_H5" +# define LMS_HASH256_M32_H10 LMS_SM3_M32_H10 +# define LMS_HASH256_M32_H10_NAME "LMS_SM3_M32_H10" +# define LMS_HASH256_M32_H15 LMS_SM3_M32_H15 +# define LMS_HASH256_M32_H15_NAME "LMS_SM3_M32_H15" +# define LMS_HASH256_M32_H20 LMS_SM3_M32_H20 +# define LMS_HASH256_M32_H20_NAME "LMS_SM3_M32_H20" +# define LMS_HASH256_M32_H25 LMS_SM3_M32_H25 +# define LMS_HASH256_M32_H25_NAME "LMS_SM3_M32_H25" +#endif + +char *lms_type_name(int lms_type); +int lms_type_from_name(const char *name); +int lms_type_to_height(int type, size_t *height); +void lms_derive_merkle_tree(const lms_hash256_t seed, const uint8_t I[16], int height, lms_hash256_t *tree); +void lms_derive_merkle_root(const lms_hash256_t seed, const uint8_t I[16], int height, lms_hash256_t root); typedef struct { int lms_type; @@ -128,75 +139,73 @@ typedef struct { #define LMS_PUBLIC_KEY_SIZE (4 + 4 + 16 + 32) // = 56 bytes -typedef struct { +typedef struct LMS_KEY_st LMS_KEY; + +typedef int (*lms_key_update_callback)(LMS_KEY *key); + +typedef struct LMS_KEY_st { LMS_PUBLIC_KEY public_key; - lms_hash256_t *tree; lms_hash256_t seed; - uint32_t q; // in [0, 2^h - 1], q++ after every sign // 应该改为index + uint32_t q; // key index + lms_hash256_t *tree; + lms_key_update_callback update_callback; + void *update_param; } LMS_KEY; #define LMS_PRIVATE_KEY_SIZE (LMS_PUBLIC_KEY_SIZE + 32 + 4) // = 92 bytes -// FIXME: do we need a function to update lms_key->q ? - int lms_key_generate_ex(LMS_KEY *key, int lms_type, const lms_hash256_t seed, const uint8_t I[16], int cache_tree); int lms_key_generate(LMS_KEY *key, int lms_type); -int lms_key_check(const LMS_KEY *key, const LMS_PUBLIC_KEY *pub); +int lms_key_set_update_callback(LMS_KEY *key, lms_key_update_callback update_cb, void *param); +int lms_key_update(LMS_KEY *key); int lms_key_remaining_signs(const LMS_KEY *key, size_t *count); - -int lms_public_key_to_bytes(const LMS_KEY *key, uint8_t **out, size_t *outlen); -int lms_public_key_from_bytes_ex(const LMS_PUBLIC_KEY **key, const uint8_t **in, size_t *inlen); // 这个函数需要修改 -int lms_public_key_from_bytes(LMS_KEY *key, const uint8_t **in, size_t *inlen); -int lms_private_key_to_bytes(const LMS_KEY *key, uint8_t **out, size_t *outlen); -int lms_private_key_from_bytes(LMS_KEY *key, const uint8_t **in, size_t *inlen); -int lms_public_key_print(FILE *fp, int fmt, int ind, const char *label, const LMS_PUBLIC_KEY *pub); -int lms_key_print(FILE *fp, int fmt, int ind, const char *label, const LMS_KEY *key); // +int lms_key_get_signature_size(const LMS_KEY *key, size_t *siglen); +int lms_key_check(const LMS_KEY *key, const LMS_PUBLIC_KEY *pub); void lms_key_cleanup(LMS_KEY *key); +int lms_public_key_to_bytes_ex(const LMS_PUBLIC_KEY *public_key, uint8_t **out, size_t *outlen); +int lms_public_key_from_bytes_ex(LMS_PUBLIC_KEY *public_key, const uint8_t **in, size_t *inlen); +int lms_public_key_to_bytes(const LMS_KEY *key, uint8_t **out, size_t *outlen); +int lms_public_key_from_bytes(LMS_KEY *key, const uint8_t **in, size_t *inlen); +int lms_public_key_print(FILE *fp, int fmt, int ind, const char *label, const LMS_KEY *pub); +int lms_private_key_to_bytes(const LMS_KEY *key, uint8_t **out, size_t *outlen); +int lms_private_key_from_bytes(LMS_KEY *key, const uint8_t **in, size_t *inlen); +int lms_private_key_print(FILE *fp, int fmt, int ind, const char *label, const LMS_KEY *key); typedef struct { - int q; // index of LMS tree leaf, in [0, 2^h - 1] + int q; // key index struct { - int lmots_type; // LMOTS_SM3_N32_W8 or LMOTS_SHA256_N32_W8 in compile time - lms_hash256_t C; // randomness of every LMOTS signature - lms_hash256_t y[34]; // for w = 8 and hash256, 34 winternitz chains + int lmots_type; + lms_hash256_t C; // signature random + lms_hash256_t y[34]; } lmots_sig; int lms_type; - lms_hash256_t path[25]; // max tree height = 25 when LMS_SM3_M32_H25 + lms_hash256_t path[LMS_MAX_HEIGHT]; } LMS_SIGNATURE; -// encoded size, SHOULD be changed when supporting text/der encoding -#define LMS_SIGNATURE_MIN_SIZE (4 + 4 + 32 + 32*34 + 4 + 32*5) // = 1292 bytes -#define LMS_SIGNATURE_MAX_SIZE (4 + 4 + 32 + 32*34 + 4 + 32*25) // = 1932 bytes - - int lms_signature_to_merkle_root(const uint8_t I[16], size_t h, int q, const lms_hash256_t y[34], const lms_hash256_t *path, const lms_hash256_t dgst, lms_hash256_t root); +#define LMS_HASH256_M32_H5_SIGNATURE_SIZE 1292 +#define LMS_HASH256_M32_H10_SIGNATURE_SIZE 1452 +#define LMS_HASH256_M32_H15_SIGNATURE_SIZE 1612 +#define LMS_HASH256_M32_H20_SIGNATURE_SIZE 1772 +#define LMS_HASH256_M32_H25_SIGNATURE_SIZE 1932 +#define LMS_SIGNATURE_MIN_SIZE LMS_HASH256_M32_H5_SIGNATURE_SIZE // = 4 + 4 + 32 + 32*34 + 4 + 32*5 = 1292 bytes +#define LMS_SIGNATURE_MAX_SIZE LMS_HASH256_M32_H25_SIGNATURE_SIZE // = 4 + 4 + 32 + 32*34 + 4 + 32*25 = 1932 bytes -/* - * LMS_HASH256_M32_H5 1292 - * LMS_HASH256_M32_H10 1452 - * LMS_HASH256_M32_H15 1612 - * LMS_HASH256_M32_H20 1772 - * LMS_HASH256_M32_H25 1932 - */ int lms_signature_size(int lms_type, size_t *siglen); -int lms_key_get_signature_size(const LMS_KEY *key, size_t *siglen); - int lms_signature_to_bytes(const LMS_SIGNATURE *sig, uint8_t **out, size_t *outlen); -int lms_signature_from_bytes_ex(const LMS_SIGNATURE **sig, size_t *siglen, const uint8_t **in, size_t *inlen);// 这个接口有点奇怪,siglen? int lms_signature_from_bytes(LMS_SIGNATURE *sig, const uint8_t **in, size_t *inlen); int lms_signature_print_ex(FILE *fp, int fmt, int ind, const char *label, const LMS_SIGNATURE *sig); int lms_signature_print(FILE *fp, int fmt, int ind, const char *label, const uint8_t *sig, size_t siglen); - typedef struct { LMS_HASH256_CTX lms_hash256_ctx; - LMS_PUBLIC_KEY lms_public_key; // FIXME: or use LMS_PUBLIC_KEY to re-use tree? - LMS_SIGNATURE lms_sig; + LMS_PUBLIC_KEY lms_public_key; + LMS_SIGNATURE lms_sig; // cache lmots x[34] } LMS_SIGN_CTX; int lms_sign_init(LMS_SIGN_CTX *ctx, LMS_KEY *key); @@ -204,49 +213,48 @@ int lms_sign_update(LMS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen); int lms_sign_finish_ex(LMS_SIGN_CTX *ctx, LMS_SIGNATURE *sig); int lms_sign_finish(LMS_SIGN_CTX *ctx, uint8_t *sig, size_t *siglen); int lms_verify_init_ex(LMS_SIGN_CTX *ctx, const LMS_KEY *key, const LMS_SIGNATURE *sig); -int lms_verify_init(LMS_SIGN_CTX *ctx, const LMS_KEY *key, const uint8_t *sigbuf, size_t siglen); +int lms_verify_init(LMS_SIGN_CTX *ctx, const LMS_KEY *key, const uint8_t *sig, size_t siglen); int lms_verify_update(LMS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen); int lms_verify_finish(LMS_SIGN_CTX *ctx); - -// `lms_sign_init` copy lmots private to ctx->lms_sig.y -// call `lms_sign_ctx_cleanup` incase `lms_sign_finish` not called nor finished void lms_sign_ctx_cleanup(LMS_SIGN_CTX *ctx); + // just for reference, HSS_PUBLIC_KEY memory layout might not compatible with HSS_KEY typedef struct { uint32_t levels; LMS_PUBLIC_KEY lms_public_key; } HSS_PUBLIC_KEY; -// HSS_PUBLIC_KEY: { level, lms_key[0].public_key } #define HSS_PUBLIC_KEY_SIZE (4 + LMS_PUBLIC_KEY_SIZE) +typedef struct HSS_KEY_st HSS_KEY; -// TODO: LMS_KEY should be a tree other than a vector -// when updated, low level lms keys will lost, maybe a good feature -typedef struct { - uint32_t levels; // should be checked to prevent memory error +typedef int (*hss_key_update_callback)(HSS_KEY *key); + +typedef struct HSS_KEY_st { + uint32_t levels; LMS_KEY lms_key[5]; LMS_SIGNATURE lms_sig[4]; + hss_key_update_callback update_callback; + void *update_param; } HSS_KEY; - #define HSS_PRIVATE_KEY_MAX_SIZE sizeof(HSS_KEY) int hss_private_key_size(const int *lms_types, size_t levels, size_t *len); int hss_key_generate(HSS_KEY *key, const int *lms_types, size_t levels); +int hss_key_set_update_callback(HSS_KEY *key, hss_key_update_callback update_cb, void *param); int hss_key_update(HSS_KEY *key); +int hss_key_get_signature_size(const HSS_KEY *key, size_t *siglen); +void hss_key_cleanup(HSS_KEY *key); -int hss_public_key_digest(const HSS_KEY *key, uint8_t dgst[32]); int hss_public_key_to_bytes(const HSS_KEY *key, uint8_t **out, size_t *outlen); int hss_private_key_to_bytes(const HSS_KEY *key, uint8_t **out, size_t *outlen); int hss_public_key_from_bytes(HSS_KEY *key, const uint8_t **in, size_t *inlen); int hss_private_key_from_bytes(HSS_KEY *key, const uint8_t **in, size_t *inlen); int hss_public_key_print(FILE *fp, int fmt, int ind, const char *label, const HSS_KEY *key); -int hss_key_print(FILE *fp, int fmt, int ind, const char *label, const HSS_KEY *key); -void hss_key_cleanup(HSS_KEY *key); - +int hss_private_key_print(FILE *fp, int fmt, int ind, const char *label, const HSS_KEY *key); typedef struct { uint32_t num_signed_public_keys; // = hss_key->levels - 1 @@ -257,45 +265,29 @@ typedef struct { LMS_SIGNATURE msg_lms_sig; // = sign(hss->lms_key[levels-1], msg) } HSS_SIGNATURE; - #define HSS_SIGNATURE_MAX_SIZE sizeof(HSS_SIGNATURE) int hss_signature_size(const int *lms_types, size_t levels, size_t *len); -int hss_key_get_signature_size(const HSS_KEY *key, size_t *siglen); - int hss_signature_to_bytes(const HSS_SIGNATURE *sig, uint8_t **out, size_t *outlen); int hss_signature_from_bytes(HSS_SIGNATURE *sig, const uint8_t **in, size_t *inlen); int hss_signature_print_ex(FILE *fp, int fmt, int ind, const char *label, const HSS_SIGNATURE *sig); int hss_signature_print(FILE *fp, int fmt, int ind, const char *label, const uint8_t *sig, size_t siglen); - typedef struct { - LMS_SIGN_CTX lms_ctx; + LMS_SIGN_CTX lms_sign_ctx; uint32_t levels; LMS_SIGNATURE lms_sigs[HSS_MAX_LEVELS - 1]; LMS_PUBLIC_KEY lms_public_keys[HSS_MAX_LEVELS - 1]; } HSS_SIGN_CTX; - int hss_sign_init(HSS_SIGN_CTX *ctx, HSS_KEY *key); int hss_sign_update(HSS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen); -int hss_sign_finish(HSS_SIGN_CTX *ctx, uint8_t *sig, size_t *siglen); int hss_sign_finish_ex(HSS_SIGN_CTX *ctx, HSS_SIGNATURE *sig); +int hss_sign_finish(HSS_SIGN_CTX *ctx, uint8_t *sig, size_t *siglen); int hss_verify_init_ex(HSS_SIGN_CTX *ctx, const HSS_KEY *key, const HSS_SIGNATURE *sig); int hss_verify_init(HSS_SIGN_CTX *ctx, const HSS_KEY *key, const uint8_t *sigbuf, size_t siglen); int hss_verify_update(HSS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen); int hss_verify_finish(HSS_SIGN_CTX *ctx); - - -// X.509 related -#define HSS_PUBLIC_KEY_DER_SIZE 63 -#define HSS_PUBLIC_KEY_INFO_SIZE 82 - -int hss_public_key_to_der(const HSS_KEY *key, uint8_t **out, size_t *outlen); -int hss_public_key_from_der(HSS_KEY *key, const uint8_t **in, size_t *inlen); -int hss_public_key_algor_to_der(uint8_t **out, size_t *outlen); -int hss_public_key_algor_from_der(const uint8_t **in, size_t *inlen); -int hss_public_key_info_to_der(const HSS_KEY *key, uint8_t **out, size_t *outlen); -int hss_public_key_info_from_der(HSS_KEY *key, const uint8_t **in, size_t *inlen); +void hss_sign_ctx_cleanup(HSS_SIGN_CTX *ctx); #ifdef __cplusplus diff --git a/src/lms.c b/src/lms.c index fa760a7c..cbf4657c 100644 --- a/src/lms.c +++ b/src/lms.c @@ -14,8 +14,6 @@ #include #include - - /* * TODO: * 1. add key_update callback @@ -343,18 +341,31 @@ void lms_derive_merkle_root(const lms_hash256_t seed, const uint8_t I[16], int h int lms_public_key_to_bytes(const LMS_KEY *key, uint8_t **out, size_t *outlen) { - if (!key || !outlen) { + if (!key) { + error_print(); + return -1; + } + if (lms_public_key_to_bytes_ex(&key->public_key, out, outlen) != 1) { + error_print(); + return -1; + } + return 1; +} + +int lms_public_key_to_bytes_ex(const LMS_PUBLIC_KEY *public_key, uint8_t **out, size_t *outlen) +{ + if (!public_key || !outlen) { error_print(); return -1; } if (out && *out) { - PUTU32(*out, key->public_key.lms_type); + PUTU32(*out, public_key->lms_type); *out += 4; - PUTU32(*out, key->public_key.lmots_type); + PUTU32(*out, public_key->lmots_type); *out += 4; - memcpy(*out, key->public_key.I, 16); + memcpy(*out, public_key->I, 16); *out += 16; - memcpy(*out, key->public_key.root, 32); + memcpy(*out, public_key->root, 32); *out += 32; } *outlen += LMS_PUBLIC_KEY_SIZE; @@ -377,9 +388,9 @@ int lms_private_key_to_bytes(const LMS_KEY *key, uint8_t **out, size_t *outlen) return 1; } -int lms_public_key_from_bytes(LMS_KEY *key, const uint8_t **in, size_t *inlen) +int lms_public_key_from_bytes_ex(LMS_PUBLIC_KEY *public_key, const uint8_t **in, size_t *inlen) { - if (!key || !in || !(*in) || !inlen) { + if (!public_key || !in || !(*in) || !inlen) { error_print(); return -1; } @@ -388,35 +399,49 @@ int lms_public_key_from_bytes(LMS_KEY *key, const uint8_t **in, size_t *inlen) return -1; } - memset(key, 0, sizeof(*key)); + memset(public_key, 0, sizeof(LMS_PUBLIC_KEY)); - key->public_key.lms_type = GETU32(*in); - if (!lms_type_name(key->public_key.lms_type)) { + public_key->lms_type = GETU32(*in); + if (!lms_type_name(public_key->lms_type)) { error_print(); return -1; } *in += 4; *inlen -= 4; - key->public_key.lmots_type = GETU32(*in); - if (!lmots_type_name(key->public_key.lmots_type)) { + public_key->lmots_type = GETU32(*in); + if (!lmots_type_name(public_key->lmots_type)) { error_print(); return -1; } *in += 4; *inlen -= 4; - memcpy(key->public_key.I, *in, 16); + memcpy(public_key->I, *in, 16); *in += 16; *inlen -= 16; - memcpy(key->public_key.root, *in, 32); + memcpy(public_key->root, *in, 32); *in += 32; *inlen -= 32; return 1; } +int lms_public_key_from_bytes(LMS_KEY *key, const uint8_t **in, size_t *inlen) +{ + if (!key) { + error_print(); + return -1; + } + memset(key, 0, sizeof(LMS_KEY)); + if (lms_public_key_from_bytes_ex(&key->public_key, in, inlen) != 1) { + error_print(); + return -1; + } + return 1; +} + int lms_key_check(const LMS_KEY *key, const LMS_PUBLIC_KEY *pub) { // FIXME: implement this @@ -494,22 +519,22 @@ int lms_private_key_from_bytes(LMS_KEY *key, const uint8_t **in, size_t *inlen) return 1; } -int lms_public_key_print(FILE *fp, int fmt, int ind, const char *label, const LMS_PUBLIC_KEY *pub) +int lms_public_key_print(FILE *fp, int fmt, int ind, const char *label, const LMS_KEY *pub) { format_print(fp, fmt, ind, "%s\n", label); ind += 4; - format_print(fp, fmt, ind, "lms_type: %s\n", lms_type_name(pub->lms_type)); - format_print(fp, fmt, ind, "lmots_type: %s\n", lmots_type_name(pub->lmots_type)); - format_bytes(fp, fmt, ind, "I", pub->I, 16); - format_bytes(fp, fmt, ind, "root", pub->root, 32); + format_print(fp, fmt, ind, "lms_type: %s\n", lms_type_name(pub->public_key.lms_type)); + format_print(fp, fmt, ind, "lmots_type: %s\n", lmots_type_name(pub->public_key.lmots_type)); + format_bytes(fp, fmt, ind, "I", pub->public_key.I, 16); + format_bytes(fp, fmt, ind, "root", pub->public_key.root, 32); return 1; } -int lms_key_print(FILE *fp, int fmt, int ind, const char *label, const LMS_KEY *key) +int lms_private_key_print(FILE *fp, int fmt, int ind, const char *label, const LMS_KEY *key) { format_print(fp, fmt, ind, "%s\n", label); ind += 4; - lms_public_key_print(fp, fmt, ind, "lms_public_key", &key->public_key); + lms_public_key_print(fp, fmt, ind, "lms_public_key", key); format_bytes(fp, fmt, ind, "seed", key->seed, 32); format_print(fp, fmt, ind, "q = %d\n", key->q); if (key->tree && fmt) { @@ -550,6 +575,8 @@ int lms_key_generate_ex(LMS_KEY *key, int lms_type, const lms_hash256_t seed, co } n = 1 << h; + memset(key, 0, sizeof(LMS_KEY)); + key->public_key.lms_type = lms_type; key->public_key.lmots_type = LMOTS_HASH256_N32_W8; @@ -593,6 +620,47 @@ int lms_key_generate(LMS_KEY *key, int lms_type) return 1; } +int lms_key_set_update_callback(LMS_KEY *key, lms_key_update_callback update_cb, void *param) +{ + if (!key) { + error_print(); + return -1; + } + key->update_callback = update_cb; + key->update_param = param; + return 1; +} + +int lms_key_update(LMS_KEY *key) +{ + size_t height; + + if (!key) { + error_print(); + return -1; + } + if (lms_type_to_height(key->public_key.lms_type, &height) != 1) { + error_print(); + return -1; + } + if (key->q < 0 || key->q > (1 << height)) { + error_print(); + return -1; + } + if (key->q == (1 << height)) { + return 0; + } + key->q++; + + if (key->update_callback) { + if (key->update_callback(key) != 1) { + error_print(); + return -1; + } + } + return 1; +} + int lms_signature_size(int lms_type, size_t *len) { size_t height; @@ -915,7 +983,10 @@ int lms_sign_init(LMS_SIGN_CTX *ctx, LMS_KEY *key) lmots_derive_secrets(key->seed, key->public_key.I, key->q, lms_sig->lmots_sig.y); // update key state, SHOULD not use the updated key->q - (key->q)++; + if (lms_key_update(key) != 1) { + error_print(); + return -1; + } lms_sig->lms_type = key->public_key.lms_type; @@ -1121,29 +1192,12 @@ int lms_verify_finish(LMS_SIGN_CTX *ctx) } } -int hss_public_key_digest(const HSS_KEY *key, uint8_t dgst[32]) -{ - SM3_CTX ctx; - uint8_t bytes[HSS_PUBLIC_KEY_SIZE]; - uint8_t *p = bytes; - size_t len; - - if (hss_public_key_to_bytes(key, &p, &len) != 1) { - error_print(); - return -1; - } - sm3_init(&ctx); - sm3_update(&ctx, bytes, sizeof(bytes)); - sm3_finish(&ctx, dgst); - return 1; -} - int hss_public_key_print(FILE *fp, int fmt, int ind, const char *label, const HSS_KEY *key) { format_print(fp, fmt, ind, "%s\n", label); ind += 4; format_print(fp, fmt, ind, "levels: %d\n", key->levels); - lms_public_key_print(fp, fmt, ind, "lms_public_key", &key->lms_key[0].public_key); + lms_public_key_print(fp, fmt, ind, "lms_public_key", &key->lms_key[0]); return 1; } @@ -1297,7 +1351,7 @@ int hss_private_key_from_bytes(HSS_KEY *key, const uint8_t **in, size_t *inlen) return 1; } -int hss_key_print(FILE *fp, int fmt, int ind, const char *label, const HSS_KEY *key) +int hss_private_key_print(FILE *fp, int fmt, int ind, const char *label, const HSS_KEY *key) { int i; @@ -1305,14 +1359,14 @@ int hss_key_print(FILE *fp, int fmt, int ind, const char *label, const HSS_KEY * ind += 4; format_print(fp, fmt, ind, "levels: %d\n", key->levels); - lms_key_print(fp, fmt, ind, "lms_key[0]", &key->lms_key[0]); + lms_private_key_print(fp, fmt, ind, "lms_key[0]", &key->lms_key[0]); for (i = 1; i < key->levels; i++) { char title[64]; snprintf(title, sizeof(title), "lms_signature[%d]", i - 1); lms_signature_print_ex(fp, fmt, ind, title, &key->lms_sig[i - 1]); snprintf(title, sizeof(title), "lms_key[%d]", i); - lms_key_print(fp, fmt, ind, title, &key->lms_key[i]); + lms_private_key_print(fp, fmt, ind, title, &key->lms_key[i]); } return 1; @@ -1485,13 +1539,13 @@ int hss_signature_from_bytes(HSS_SIGNATURE *sig, const uint8_t **in, size_t *inl for (i = 0; i < sig->num_signed_public_keys; i++) { LMS_SIGNATURE *lms_sig = &sig->signed_public_keys[i].lms_sig; - LMS_KEY *lms_key = (LMS_KEY *)&sig->signed_public_keys[i].lms_public_key; + LMS_PUBLIC_KEY *lms_key = &sig->signed_public_keys[i].lms_public_key; if (lms_signature_from_bytes(lms_sig, in, inlen) != 1) { error_print(); return -1; } - if (lms_public_key_from_bytes(lms_key, in, inlen) != 1) { + if (lms_public_key_from_bytes_ex(lms_key, in, inlen) != 1) { error_print(); return -1; } @@ -1528,7 +1582,7 @@ int hss_signature_to_bytes(const HSS_SIGNATURE *sig, uint8_t **out, size_t *outl error_print(); return -1; } - if (lms_public_key_to_bytes((LMS_KEY *)&sig->signed_public_keys[i].lms_public_key, out, &len) != 1) { + if (lms_public_key_to_bytes_ex(&sig->signed_public_keys[i].lms_public_key, out, &len) != 1) { error_print(); return -1; } @@ -1544,6 +1598,16 @@ int hss_signature_to_bytes(const HSS_SIGNATURE *sig, uint8_t **out, size_t *outl return 1; } +int hss_key_set_update_callback(HSS_KEY *key, hss_key_update_callback update_cb, void *param) +{ + if (!key) { + error_print(); + return -1; + } + key->update_callback = update_cb; + key->update_param = param; + return 1; +} int hss_key_update(HSS_KEY *key) { @@ -1564,11 +1628,6 @@ int hss_key_update(HSS_KEY *key) } // the lowest level is not out of keys if (level >= key->levels) { - - fprintf(stderr, "key->levels = %d\n", key->levels); - fprintf(stderr, "level out of key = %d\n", level); - - error_print(); return -1; } @@ -1609,6 +1668,12 @@ int hss_key_update(HSS_KEY *key) } } + if (key->update_callback) { + if (key->update_callback(key) != 1) { + error_print(); + return -1; + } + } return 1; } @@ -1628,7 +1693,7 @@ int hss_sign_init(HSS_SIGN_CTX *ctx, HSS_KEY *key) memset(ctx, 0, sizeof(*ctx)); - if (lms_sign_init(&ctx->lms_ctx, &key->lms_key[key->levels - 1]) != 1) { + if (lms_sign_init(&ctx->lms_sign_ctx, &key->lms_key[key->levels - 1]) != 1) { error_print(); return -1; } @@ -1661,7 +1726,7 @@ int hss_sign_update(HSS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen) return -1; } if (data && datalen) { - if (lms_sign_update(&ctx->lms_ctx, data, datalen) != 1) { + if (lms_sign_update(&ctx->lms_sign_ctx, data, datalen) != 1) { error_print(); return -1; } @@ -1707,7 +1772,7 @@ int hss_sign_finish_ex(HSS_SIGN_CTX *ctx, HSS_SIGNATURE *sig) sig->signed_public_keys[i].lms_public_key = ctx->lms_public_keys[i]; } - if (lms_sign_finish_ex(&ctx->lms_ctx, &sig->msg_lms_sig) != 1) { + if (lms_sign_finish_ex(&ctx->lms_sign_ctx, &sig->msg_lms_sig) != 1) { error_print(); return -1; } @@ -1739,7 +1804,7 @@ int hss_verify_init_ex(HSS_SIGN_CTX *ctx, const HSS_KEY *key, const HSS_SIGNATUR } if (sig->num_signed_public_keys == 0) { - if (lms_verify_init_ex(&ctx->lms_ctx, &key->lms_key[0], + if (lms_verify_init_ex(&ctx->lms_sign_ctx, &key->lms_key[0], &sig->msg_lms_sig) != 1) { error_print(); return -1; @@ -1753,17 +1818,17 @@ int hss_verify_init_ex(HSS_SIGN_CTX *ctx, const HSS_KEY *key, const HSS_SIGNATUR return -1; } - if (lms_verify_init_ex(&ctx->lms_ctx, &key->lms_key[0], + if (lms_verify_init_ex(&ctx->lms_sign_ctx, &key->lms_key[0], &sig->signed_public_keys[0].lms_sig) != 1) { error_print(); return -1; } - if (lms_verify_update(&ctx->lms_ctx, buf, len) != 1) { + if (lms_verify_update(&ctx->lms_sign_ctx, buf, len) != 1) { error_print(); return -1; } - if (lms_verify_finish(&ctx->lms_ctx) != 1) { + if (lms_verify_finish(&ctx->lms_sign_ctx) != 1) { error_print(); return -1; } @@ -1777,25 +1842,25 @@ int hss_verify_init_ex(HSS_SIGN_CTX *ctx, const HSS_KEY *key, const HSS_SIGNATUR return -1; } - if (lms_verify_init_ex(&ctx->lms_ctx, + if (lms_verify_init_ex(&ctx->lms_sign_ctx, (LMS_KEY *)&sig->signed_public_keys[i - 1].lms_public_key, &sig->signed_public_keys[i].lms_sig) != 1) { error_print(); return -1; } - if (lms_verify_update(&ctx->lms_ctx, buf, len) != 1) { + if (lms_verify_update(&ctx->lms_sign_ctx, buf, len) != 1) { error_print(); return -1; } - if (lms_verify_finish(&ctx->lms_ctx) != 1) { + if (lms_verify_finish(&ctx->lms_sign_ctx) != 1) { error_print(); return -1; } } // verify(pk[last], msg, msg_sig) - if (lms_verify_init_ex(&ctx->lms_ctx, + if (lms_verify_init_ex(&ctx->lms_sign_ctx, (LMS_KEY *)&sig->signed_public_keys[sig->num_signed_public_keys - 1].lms_public_key, &sig->msg_lms_sig) != 1) { error_print(); @@ -1837,7 +1902,7 @@ int hss_verify_update(HSS_SIGN_CTX *ctx, const uint8_t *data, size_t datalen) return -1; } if (data && datalen) { - if (lms_verify_update(&ctx->lms_ctx, data, datalen) != 1) { + if (lms_verify_update(&ctx->lms_sign_ctx, data, datalen) != 1) { error_print(); return -1; } @@ -1852,7 +1917,7 @@ int hss_verify_finish(HSS_SIGN_CTX *ctx) error_print(); return -1; } - if ((ret = lms_verify_finish(&ctx->lms_ctx)) != 1) { + if ((ret = lms_verify_finish(&ctx->lms_sign_ctx)) != 1) { error_print(); return ret; } @@ -1873,7 +1938,7 @@ int hss_signature_print_ex(FILE *fp, int fmt, int ind, const char *label, const snprintf(title, sizeof(title), "lms_signature[%zu]", i); lms_signature_print_ex(fp, fmt, ind, title, &sig->signed_public_keys[0].lms_sig); snprintf(title, sizeof(title), "lms_public_key[%zu]", i + 1); - lms_public_key_print(fp, fmt, ind, title, &sig->signed_public_keys[0].lms_public_key); + lms_public_key_print(fp, fmt, ind, title, (LMS_KEY *)&sig->signed_public_keys[0]); } lms_signature_print_ex(fp, fmt, ind, "message_signature", &sig->msg_lms_sig); @@ -1884,7 +1949,7 @@ int hss_signature_print(FILE *fp, int fmt, int ind, const char *label, const uin { LMS_SIGNATURE lms_sig; size_t lms_siglen; - LMS_PUBLIC_KEY lms_pub; + LMS_KEY lms_key; int num; int i; @@ -1913,12 +1978,12 @@ int hss_signature_print(FILE *fp, int fmt, int ind, const char *label, const uin snprintf(title, sizeof(title), "lms_signature[%d]", i); lms_signature_print_ex(fp, fmt, ind, title, &lms_sig); - if (lms_public_key_from_bytes((LMS_KEY *)&lms_pub, &sig, &siglen) != 1) { + if (lms_public_key_from_bytes(&lms_key, &sig, &siglen) != 1) { error_print(); return -1; } snprintf(title, sizeof(title), "lms_public_key[%d]", i + 1); - lms_public_key_print(fp, fmt, ind, title, &lms_pub); + lms_public_key_print(fp, fmt, ind, title, &lms_key); } if (lms_signature_from_bytes(&lms_sig, &sig, &siglen) != 1) { @@ -1956,132 +2021,9 @@ int hss_private_key_size(const int *lms_types, size_t levels, size_t *len) return 1; } -// X.509 related - -/* - SubjectPublicKeyInfo ::= SEQUENCE { - algorithm AlgorithmIdentifier, - subjectPublicKey BIT STRING - } - - in RFC 8708 (HSS/LMS in CMS) - only HSS has OID id-alg-hss-lms-hashsig (1.2.840.113549.1.9.16.3.17) - so only HSS public key is supported in x509_ functions - - hss_public_key_to_bytes encode HSS_PUBLIC_KEY to SubjectPublicKeyInfo.subjectPublicKey - the public_key raw data (not OCTET STRING TLV) is encoded into a BIT STRING TLV - same as EC_POINT -*/ -int hss_public_key_to_der(const HSS_KEY *key, uint8_t **out, size_t *outlen) +void hss_sign_ctx_cleanup(HSS_SIGN_CTX *ctx) { - uint8_t octets[HSS_PUBLIC_KEY_SIZE]; - uint8_t *p = octets; - size_t len = 0; - - if (!key) { - return 0; + if (ctx) { + lms_sign_ctx_cleanup(&ctx->lms_sign_ctx); } - - if (hss_public_key_to_bytes(key, &p, &len) != 1) { - error_print(); - return -1; - } - if (len != sizeof(octets)) { - error_print(); - return -1; - } - if (asn1_bit_octets_to_der(octets, sizeof(octets), out, outlen) != 1) { - error_print(); - return -1; - } - return 1; -} - -int hss_public_key_from_der(HSS_KEY *key, const uint8_t **in, size_t *inlen) -{ - int ret; - const uint8_t *d; - size_t dlen; - - if ((ret = asn1_bit_octets_from_der(&d, &dlen, in, inlen)) != 1) { - if (ret < 0) error_print(); - return ret; - } - if (dlen != HSS_PUBLIC_KEY_SIZE) { - error_print(); - return -1; - } - - if (hss_public_key_from_bytes(key, &d, &dlen) != 1) { - error_print(); - return -1; - } - if (dlen) { - error_print(); - return -1; - } - - return 1; -} - -int hss_public_key_algor_to_der(uint8_t **out, size_t *outlen) -{ - if (x509_public_key_algor_to_der(OID_hss_lms_hashsig, OID_undef, out, outlen) != 1) { - error_print(); - return -1; - } - return 1; -} - -int hss_public_key_algor_from_der(const uint8_t **in, size_t *inlen) -{ - int ret; - int oid; - int param = OID_undef; - - if ((ret = x509_public_key_algor_from_der(&oid, ¶m, in, inlen)) != 1) { - if (ret < 0) error_print(); - return ret; - } - if (oid != OID_hss_lms_hashsig) { - error_print(); - return -1; - } - // param == 0: parameter is empty - // param == 1: parameter is null object - // param == other values, x509_public_key_algor_from_der fail - return 1; -} - -int hss_public_key_info_to_der(const HSS_KEY *key, uint8_t **out, size_t *outlen) -{ - size_t len = 0; - if (hss_public_key_algor_to_der(NULL, &len) != 1 - || hss_public_key_to_der(key, NULL, &len) != 1 - || asn1_sequence_header_to_der(len, out, outlen) != 1 - || hss_public_key_algor_to_der(out, outlen) != 1 - || hss_public_key_to_der(key, out, outlen) != 1) { - error_print(); - return -1; - } - return 1; -} - -int hss_public_key_info_from_der(HSS_KEY *key, const uint8_t **in, size_t *inlen) -{ - int ret; - const uint8_t *d; - size_t dlen; - - if ((ret = asn1_sequence_from_der(&d, &dlen, in, inlen)) != 1) { - if (ret < 0) error_print(); - return ret; - } - if (hss_public_key_algor_from_der(&d, &dlen) != 1 - || hss_public_key_from_der(key, &d, &dlen) != 1 - || asn1_length_is_zero(dlen) != 1) { - error_print(); - return -1; - } - return 1; } diff --git a/src/x509_key.c b/src/x509_key.c index 06e2c057..645cf970 100644 --- a/src/x509_key.c +++ b/src/x509_key.c @@ -162,7 +162,7 @@ int x509_public_key_print(FILE *fp, int fmt, int ind, const char *label, const X } break; case OID_lms_hashsig: - if (lms_public_key_print(fp, fmt, ind, label, &key->u.lms_key.public_key) != 1) { + if (lms_public_key_print(fp, fmt, ind, label, &key->u.lms_key) != 1) { error_print(); return -1; } diff --git a/tests/lmstest.c b/tests/lmstest.c index 7004bdef..d6fe59d5 100644 --- a/tests/lmstest.c +++ b/tests/lmstest.c @@ -294,7 +294,7 @@ static int test_lms_key_generate(void) error_print(); return -1; } - //lms_key_print(stdout, 0, 0, "lms_key", &lms_key); + lms_private_key_print(stdout, 0, 0, "lms_private_key", &lms_key); printf("%s() ok\n", __FUNCTION__); return 1; @@ -341,13 +341,13 @@ static int test_lms_key_to_bytes(void) error_print(); return -1; } - lms_key_print(stdout, 0, 4, "lms_public_key", &key); + lms_public_key_print(stdout, 0, 4, "lms_public_key", &key); if (lms_private_key_from_bytes(&key, &cp, &len) != 1) { error_print(); return -1; } - lms_key_print(stdout, 0, 4, "lms_private_key", &key); + lms_private_key_print(stdout, 0, 4, "lms_private_key", &key); if (len != 0) { error_print(); return -1; @@ -539,7 +539,7 @@ static int test_hss_key_generate(void) } hss_public_key_print(stdout, 0, 4, "hss_public_key", &key); - hss_key_print(stdout, 0, 4, "hss_key", &key); + hss_private_key_print(stdout, 0, 4, "hss_key", &key); printf("%s() ok\n", __FUNCTION__); return 1; @@ -799,7 +799,7 @@ static int test_hss_key_to_bytes(void) error_print(); return -1; } - hss_key_print(stdout, 0, 4, "lms_private_key", &key); + hss_private_key_print(stdout, 0, 4, "lms_private_key", &key); if (len != 0) { error_print(); return -1; @@ -868,7 +868,7 @@ static int test_hss_sign_level2(void) error_print(); return -1; } - hss_key_print(stderr, 0, 4, "hss_key", &key); + hss_private_key_print(stderr, 0, 4, "hss_key", &key); if (hss_sign_init(&ctx, &key) != 1) { @@ -916,7 +916,7 @@ static int test_hss_sign(void) error_print(); return -1; } - hss_key_print(stderr, 0, 4, "hss_key", &key); + hss_private_key_print(stderr, 0, 4, "hss_key", &key); if (hss_sign_init(&ctx, &key) != 1) { @@ -951,6 +951,7 @@ static int test_hss_sign(void) return 1; } +/* static int test_hss_public_key_algor(void) { int lms_types[] = { @@ -1025,7 +1026,7 @@ static int test_hss_public_key_algor(void) return 1; } - +*/ int main(void) { @@ -1048,7 +1049,7 @@ int main(void) if (test_hss_sign_level1() != 1) goto err; if (test_hss_sign_level2() != 1) goto err; if (test_hss_sign() != 1) goto err; - if (test_hss_public_key_algor() != 1) goto err; +// if (test_hss_public_key_algor() != 1) goto err; printf("%s all tests passed\n", __FILE__); return 0; diff --git a/tools/gmssl.c b/tools/gmssl.c index 460a0fee..d2fe4576 100644 --- a/tools/gmssl.c +++ b/tools/gmssl.c @@ -64,7 +64,7 @@ extern int tls12_client_main(int argc, char **argv); extern int tls12_server_main(int argc, char **argv); extern int tls13_client_main(int argc, char **argv); extern int tls13_server_main(int argc, char **argv); -#ifdef ENABLE_LMS_HSS +#ifdef ENABLE_LMS extern int lmskeygen_main(int argc, char **argv); extern int lmssign_main(int argc, char **argv); extern int lmsverify_main(int argc, char **argv); @@ -154,7 +154,7 @@ static const char *options = " cmsdecrypt Decrypt CMS EnvelopedData\n" " cmssign Generate CMS SignedData\n" " cmsverify Verify CMS SignedData\n" -#ifdef ENABLE_LMS_HSS +#ifdef ENABLE_LMS " lmskeygen Generate LMS-SM3 (Leighton-Micali Signature) keypair\n" " lmssign Generate LMS-SM3 signature\n" " lmsverify Verify LMS-SM3 signature\n" @@ -334,7 +334,7 @@ int main(int argc, char **argv) return tls13_client_main(argc, argv); } else if (!strcmp(*argv, "tls13_server")) { return tls13_server_main(argc, argv); -#ifdef ENABLE_LMS_HSS +#ifdef ENABLE_LMS } else if (!strcmp(*argv, "lmskeygen")) { return lmskeygen_main(argc, argv); } else if (!strcmp(*argv, "lmssign")) { diff --git a/tools/hsssign.c b/tools/hsssign.c index f73f09ee..0c8b0c51 100644 --- a/tools/hsssign.c +++ b/tools/hsssign.c @@ -1,5 +1,5 @@ /* - * Copyright 2014-2025 The GmSSL Project. All Rights Reserved. + * Copyright 2014-2026 The GmSSL Project. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the License); you may * not use this file except in compliance with the License. @@ -26,6 +26,36 @@ static const char *options = " -verbose Print public key and signature\n" "\n"; +static int key_update_cb(HSS_KEY *key) +{ + FILE *fp; + uint8_t buf[HSS_PRIVATE_KEY_MAX_SIZE]; + uint8_t *p = buf; + size_t len = 0; + + if (!key->update_param) { + error_print(); + return -1; + } + fp = (FILE *)key->update_param; + + if (hss_private_key_to_bytes(key, &p, &len) != 1) { + error_print(); + return -1; + } + rewind(fp); + if (fwrite(buf, 1, len, fp) != len + || fflush(fp) != 0) { + gmssl_secure_clear(buf, sizeof(buf)); + error_print(); + return -1; + } + // TODO: need fsync to make sure data is written to disk + // but fsync need , not std C + gmssl_secure_clear(buf, sizeof(buf)); + return 1; +} + int hsssign_main(int argc, char **argv) { int ret = 1; @@ -112,28 +142,21 @@ bad: } if (keylen) { error_print(); - return -1; + goto end; } if (verbose) { hss_public_key_print(stderr, 0, 0, "hss_public_key", &key); } - if (hss_sign_init(&ctx, &key) != 1) { + if (hss_key_set_update_callback(&key, key_update_cb, keyfp) != 1) { error_print(); goto end; } - // write updated key back to file - // TODO: write back `q` only - if (hss_private_key_to_bytes(&key, &p, &keylen) != 1) { + if (hss_sign_init(&ctx, &key) != 1) { error_print(); - return -1; - } - rewind(keyfp); - if (fwrite(keybuf, 1, keylen, keyfp) != keylen) { - error_print(); - return -1; + goto end; } while (1) { diff --git a/tools/lmskeygen.c b/tools/lmskeygen.c index 2b949c72..21a4ed41 100644 --- a/tools/lmskeygen.c +++ b/tools/lmskeygen.c @@ -115,7 +115,7 @@ bad: return -1; } if (verbose) { - lms_public_key_print(stderr, 0, 0, "lms_public_key", &key.public_key); + lms_public_key_print(stderr, 0, 0, "lms_public_key", &key); } if (lms_private_key_to_bytes(&key, &pout, &outlen) != 1) { diff --git a/tools/lmssign.c b/tools/lmssign.c index 741d3db5..d8eb93ea 100644 --- a/tools/lmssign.c +++ b/tools/lmssign.c @@ -1,5 +1,5 @@ /* - * Copyright 2014-2025 The GmSSL Project. All Rights Reserved. + * Copyright 2014-2026 The GmSSL Project. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the License); you may * not use this file except in compliance with the License. @@ -26,6 +26,36 @@ static const char *options = " -verbose Print public key and signature\n" "\n"; +static int key_update_cb(LMS_KEY *key) +{ + FILE *fp; + uint8_t buf[LMS_PRIVATE_KEY_SIZE]; + uint8_t *p = buf; + size_t len = 0; + + if (!key->update_param) { + error_print(); + return -1; + } + fp = (FILE *)key->update_param; + + if (lms_private_key_to_bytes(key, &p, &len) != 1) { + error_print(); + return -1; + } + rewind(fp); + if (fwrite(buf, 1, len, fp) != len + || fflush(fp) != 0) { + gmssl_secure_clear(buf, sizeof(buf)); + error_print(); + return -1; + } + // TODO: need fsync to make sure data is written to disk + // but fsync need , not std C + gmssl_secure_clear(buf, sizeof(buf)); + return 1; +} + int lmssign_main(int argc, char **argv) { int ret = 1; @@ -116,7 +146,12 @@ bad: } if (verbose) { - lms_public_key_print(stderr, 0, 0, "lms_public_key", &key.public_key); + lms_public_key_print(stderr, 0, 0, "lms_public_key", &key); + } + + if (lms_key_set_update_callback(&key, key_update_cb, keyfp) != 1) { + error_print(); + goto end; } if (lms_sign_init(&ctx, &key) != 1) { @@ -124,18 +159,6 @@ bad: goto end; } - // write updated key back to file - // TODO: write back `q` only - if (lms_private_key_to_bytes(&key, &p, &keylen) != 1) { - error_print(); - return -1; - } - rewind(keyfp); - if (fwrite(keybuf, 1, keylen, keyfp) != keylen) { - error_print(); - return -1; - } - while (1) { uint8_t buf[1024]; size_t len = fread(buf, 1, sizeof(buf), infp); diff --git a/tools/lmsverify.c b/tools/lmsverify.c index e68d0638..d0433c97 100644 --- a/tools/lmsverify.c +++ b/tools/lmsverify.c @@ -113,7 +113,7 @@ bad: goto end; } if (verbose) { - lms_public_key_print(stderr, 0, 0, "lms_public_key", &key.public_key); + lms_public_key_print(stderr, 0, 0, "lms_public_key", &key); } // read signature even if signature not compatible with the public key