From b0c5208a687daaac25f59c6aeee40945a5f67504 Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Thu, 4 Jan 2024 09:37:12 +0800 Subject: [PATCH] Support SM3 third-party implementation --- CMakeLists.txt | 1 + include/gmssl/sm3_digest.h | 40 +++++++++++++++++++ src/sm3_digest.c | 80 ++++++++++++++++++++++++++++++++++++++ tools/sm3.c | 30 ++++++++++---- tools/sm3hmac.c | 25 +++++++++--- 5 files changed, 163 insertions(+), 13 deletions(-) create mode 100644 include/gmssl/sm3_digest.h create mode 100644 src/sm3_digest.c diff --git a/CMakeLists.txt b/CMakeLists.txt index 931c3f65..41c02181 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,6 +17,7 @@ set(src src/sm3.c src/sm3_hmac.c src/sm3_kdf.c + src/sm3_digest.c src/sm2_alg.c src/sm2_key.c src/sm2_lib.c diff --git a/include/gmssl/sm3_digest.h b/include/gmssl/sm3_digest.h new file mode 100644 index 00000000..e8d0b564 --- /dev/null +++ b/include/gmssl/sm3_digest.h @@ -0,0 +1,40 @@ +/* + * Copyright 2014-2024 The GmSSL Project. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ + + +#ifndef GMSSL_SM3_DIGEST_H +#define GMSSL_SM3_DIGEST_H + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + + +typedef struct { + union { + SM3_CTX sm3_ctx; + SM3_HMAC_CTX hmac_ctx; + void *handle; + }; + int state; +} SM3_DIGEST_CTX; + +int sm3_digest_init(SM3_DIGEST_CTX *ctx, const uint8_t *key, size_t keylen); +int sm3_digest_update(SM3_DIGEST_CTX *ctx, const uint8_t *data, size_t datalen); +int sm3_digest_finish(SM3_DIGEST_CTX *ctx, uint8_t dgst[SM3_DIGEST_SIZE]); + + +#ifdef __cplusplus +} +#endif +#endif diff --git a/src/sm3_digest.c b/src/sm3_digest.c new file mode 100644 index 00000000..e1faf06a --- /dev/null +++ b/src/sm3_digest.c @@ -0,0 +1,80 @@ +/* + * Copyright 2014-2024 The GmSSL Project. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ + + +#include +#include +#include + + +int sm3_digest_init(SM3_DIGEST_CTX *ctx, const uint8_t *key, size_t keylen) +{ + if (!ctx) { + error_print(); + return -1; + } + + memset(ctx, 0, sizeof(*ctx)); + + if (!key) { + sm3_init(&ctx->sm3_ctx); + ctx->state = 1; + } else { + if (keylen < 12 || keylen > 64) { + error_print(); + return -1; + } + sm3_hmac_init(&ctx->hmac_ctx, key, keylen); + ctx->state = 2; + } + + return 1; +} + +int sm3_digest_update(SM3_DIGEST_CTX *ctx, const uint8_t *data, size_t datalen) +{ + if (!ctx) { + error_print(); + return -1; + } + if (!data || !datalen) { + error_print(); + return -1; + } + + if (ctx->state == 1) { + sm3_update(&ctx->sm3_ctx, data, datalen); + } else if (ctx->state == 2) { + sm3_hmac_update(&ctx->hmac_ctx, data, datalen); + } else { + error_print(); + return -1; + } + return 1; +} + +int sm3_digest_finish(SM3_DIGEST_CTX *ctx, uint8_t dgst[SM3_DIGEST_SIZE]) +{ + if (!ctx || !dgst) { + error_print(); + return -1; + } + + if (ctx->state == 1) { + sm3_finish(&ctx->sm3_ctx, dgst); + } else if (ctx->state == 2) { + sm3_hmac_finish(&ctx->hmac_ctx, dgst); + } else { + error_print(); + return -1; + } + + memset(ctx, 0, sizeof(*ctx)); + return 1; +} diff --git a/tools/sm3.c b/tools/sm3.c index 380c01f3..d125b739 100644 --- a/tools/sm3.c +++ b/tools/sm3.c @@ -13,7 +13,7 @@ #include #include #include -#include +#include #include #include @@ -75,7 +75,7 @@ int sm3_main(int argc, char **argv) FILE *outfp = stdout; uint8_t id_bin[64]; size_t id_bin_len; - SM3_CTX sm3_ctx; + SM3_DIGEST_CTX sm3_ctx; uint8_t dgst[32]; int i; @@ -171,7 +171,10 @@ bad: goto end; } - sm3_init(&sm3_ctx); + if (sm3_digest_init(&sm3_ctx, NULL, 0) != 1) { + fprintf(stderr, "%s: inner error\n", prog); + goto end; + } if (pubkeyfile) { SM2_KEY sm2_key; @@ -191,20 +194,33 @@ bad: sm2_compute_z(z, (SM2_POINT *)&sm2_key, id, strlen(id)); } - sm3_update(&sm3_ctx, z, sizeof(z)); + if (sm3_digest_update(&sm3_ctx, z, sizeof(z)) != 1) { + fprintf(stderr, "%s: inner error\n", prog); + goto end; + } } if (in_str) { - sm3_update(&sm3_ctx, (uint8_t *)in_str, strlen(in_str)); + if (sm3_digest_update(&sm3_ctx, (uint8_t *)in_str, strlen(in_str)) != 1) { + fprintf(stderr, "%s: inner error\n", prog); + goto end; + } + } else { uint8_t buf[4096]; size_t len; while ((len = fread(buf, 1, sizeof(buf), infp)) > 0) { - sm3_update(&sm3_ctx, buf, len); + if (sm3_digest_update(&sm3_ctx, buf, len) != 1) { + fprintf(stderr, "%s: inner error\n", prog); + goto end; + } } memset(buf, 0, sizeof(buf)); } - sm3_finish(&sm3_ctx, dgst); + if (sm3_digest_finish(&sm3_ctx, dgst) != 1) { + fprintf(stderr, "%s: inner error\n", prog); + goto end; + } memset(&sm3_ctx, 0, sizeof(sm3_ctx)); if (outformat > 1) { diff --git a/tools/sm3hmac.c b/tools/sm3hmac.c index 28f1cb11..90d57e23 100644 --- a/tools/sm3hmac.c +++ b/tools/sm3hmac.c @@ -14,7 +14,7 @@ #include #include #include -#include +#include static const char *usage = "-key hex [-in file | -in_str str] [-bin|-hex] [-out file]"; @@ -63,7 +63,7 @@ int sm3hmac_main(int argc, char **argv) size_t keylen; FILE *infp = stdin; FILE *outfp = stdout; - SM3_HMAC_CTX ctx; + SM3_DIGEST_CTX ctx; uint8_t mac[SM3_HMAC_SIZE]; size_t i; @@ -142,18 +142,31 @@ bad: goto end; } - sm3_hmac_init(&ctx, key, keylen); + if (sm3_digest_init(&ctx, key, keylen) != 1) { + fprintf(stderr, "%s: inner error\n", prog); + goto end; + } + if (in_str) { - sm3_hmac_update(&ctx, (uint8_t *)in_str, strlen(in_str)); + if (sm3_digest_update(&ctx, (uint8_t *)in_str, strlen(in_str)) != 1) { + fprintf(stderr, "%s: inner error\n", prog); + goto end; + } } else { uint8_t buf[4096]; size_t len; while ((len = fread(buf, 1, sizeof(buf), infp)) > 0) { - sm3_hmac_update(&ctx, buf, len); + if (sm3_digest_update(&ctx, buf, len) != 1) { + fprintf(stderr, "%s: inner error\n", prog); + goto end; + } } memset(buf, 0, sizeof(buf)); } - sm3_hmac_finish(&ctx, mac); + if (sm3_digest_finish(&ctx, mac) != 1) { + fprintf(stderr, "%s: inner error\n", prog); + goto end; + } if (outformat > 1) { if (fwrite(mac, 1, sizeof(mac), outfp) != sizeof(mac)) {