From 866e3aef675876900f9bc9d3c5bbb9ab72ec9cdb Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Sun, 21 Jun 2026 00:01:32 +0800 Subject: [PATCH] Fix #1900 --- CMakeLists.txt | 2 +- include/gmssl/version.h | 2 +- src/sm2_enc.c | 16 +++++++++++----- tests/sm2_enctest.c | 36 +++++++++++++++++++++++++++++++++++- 4 files changed, 48 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 18b84348..a46bceba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -874,7 +874,7 @@ endif() # set(CPACK_PACKAGE_NAME "GmSSL") set(CPACK_PACKAGE_VENDOR "GmSSL develop team") -set(CPACK_PACKAGE_VERSION "3.2.0-dev.1131") +set(CPACK_PACKAGE_VERSION "3.2.0-dev.1132") set(CPACK_PACKAGE_DESCRIPTION_FILE ${PROJECT_SOURCE_DIR}/README.md) set(CPACK_NSIS_MODIFY_PATH ON) include(CPack) diff --git a/include/gmssl/version.h b/include/gmssl/version.h index d20f4847..9baeec5e 100644 --- a/include/gmssl/version.h +++ b/include/gmssl/version.h @@ -18,7 +18,7 @@ extern "C" { #define GMSSL_VERSION_NUM 30200 -#define GMSSL_VERSION_STR "GmSSL 3.2.0-dev.1131" +#define GMSSL_VERSION_STR "GmSSL 3.2.0-dev.1132" int gmssl_version_num(void); const char *gmssl_version_str(void); diff --git a/src/sm2_enc.c b/src/sm2_enc.c index 97bf0d3d..e11a4438 100644 --- a/src/sm2_enc.c +++ b/src/sm2_enc.c @@ -322,6 +322,9 @@ int sm2_do_decrypt(const SM2_KEY *key, const SM2_CIPHERTEXT *in, uint8_t *out, s uint8_t x2y2[64]; SM3_CTX sm3_ctx; uint8_t hash[32]; + uint8_t plaintext[SM2_MAX_PLAINTEXT_SIZE]; + + *outlen = 0; // check C1 is on sm2 curve if (sm2_z256_point_from_bytes(&C1, (uint8_t *)&in->point) != 1) { @@ -334,20 +337,19 @@ int sm2_do_decrypt(const SM2_KEY *key, const SM2_CIPHERTEXT *in, uint8_t *out, s // t = KDF(x2 || y2, klen) and check t is not all zeros sm2_z256_point_to_bytes(&C1, x2y2); - sm2_kdf(x2y2, 64, in->ciphertext_size, out); - if (all_zero(out, in->ciphertext_size)) { + sm2_kdf(x2y2, 64, in->ciphertext_size, plaintext); + if (all_zero(plaintext, in->ciphertext_size)) { error_print(); goto end; } // M = C2 xor t - gmssl_memxor(out, out, in->ciphertext, in->ciphertext_size); - *outlen = in->ciphertext_size; + gmssl_memxor(plaintext, plaintext, in->ciphertext, in->ciphertext_size); // u = Hash(x2 || M || y2) sm3_init(&sm3_ctx); sm3_update(&sm3_ctx, x2y2, 32); - sm3_update(&sm3_ctx, out, in->ciphertext_size); + sm3_update(&sm3_ctx, plaintext, in->ciphertext_size); sm3_update(&sm3_ctx, x2y2 + 32, 32); sm3_finish(&sm3_ctx, hash); @@ -356,11 +358,15 @@ int sm2_do_decrypt(const SM2_KEY *key, const SM2_CIPHERTEXT *in, uint8_t *out, s error_print(); goto end; } + memcpy(out, plaintext, in->ciphertext_size); + *outlen = in->ciphertext_size; ret = 1; end: gmssl_secure_clear(&C1, sizeof(SM2_Z256_POINT)); gmssl_secure_clear(x2y2, sizeof(x2y2)); + gmssl_secure_clear(hash, sizeof(hash)); + gmssl_secure_clear(plaintext, sizeof(plaintext)); return ret; } diff --git a/tests/sm2_enctest.c b/tests/sm2_enctest.c index 7c6db322..2963ce88 100644 --- a/tests/sm2_enctest.c +++ b/tests/sm2_enctest.c @@ -113,6 +113,40 @@ static int test_sm2_do_encrypt(void) return 1; } +static int test_sm2_do_decrypt_bad_hash_does_not_output_plaintext(void) +{ + SM2_KEY sm2_key; + uint8_t plaintext[] = "Hello World!"; + SM2_CIPHERTEXT ciphertext; + uint8_t out[SM2_MAX_PLAINTEXT_SIZE]; + size_t outlen = sizeof(out); + size_t i; + + if (sm2_key_generate(&sm2_key) != 1 + || sm2_do_encrypt(&sm2_key, plaintext, sizeof(plaintext), &ciphertext) != 1) { + error_print(); + return -1; + } + + ciphertext.hash[0] ^= 0x01; + memset(out, 0xa5, sizeof(out)); + + if (sm2_do_decrypt(&sm2_key, &ciphertext, out, &outlen) != -1 + || outlen != 0) { + error_print(); + return -1; + } + for (i = 0; i < sizeof(out); i++) { + if (out[i] != 0xa5) { + error_print(); + return -1; + } + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + static int test_sm2_do_encrypt_fixlen(void) { struct { @@ -323,6 +357,7 @@ int main(void) { if (test_sm2_ciphertext() != 1) goto err; if (test_sm2_do_encrypt() != 1) goto err; + if (test_sm2_do_decrypt_bad_hash_does_not_output_plaintext() != 1) goto err; if (test_sm2_do_encrypt_fixlen() != 1) goto err; if (test_sm2_encrypt() != 1) goto err; if (test_sm2_encrypt_fixlen() != 1) goto err; @@ -335,4 +370,3 @@ err: error_print(); return -1; } -