From 3f1fdc147a18132b9ec401a70126d40ec6d7e770 Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Thu, 25 Apr 2024 16:34:03 +0800 Subject: [PATCH] Add sm2_encrypt_pre_compute --- include/gmssl/sm2.h | 35 ++++-- src/sm2_enc.c | 277 ++++++++++++++++++++++++++++++-------------- tests/sm2_enctest.c | 38 ++++++ tools/sm2decrypt.c | 10 +- tools/sm2encrypt.c | 8 +- 5 files changed, 270 insertions(+), 98 deletions(-) diff --git a/include/gmssl/sm2.h b/include/gmssl/sm2.h index 72f13681..545722f9 100644 --- a/include/gmssl/sm2.h +++ b/include/gmssl/sm2.h @@ -251,17 +251,36 @@ _gmssl_export int sm2_ecdh(const SM2_KEY *key, const uint8_t *peer_public, size_ typedef struct { - SM2_KEY sm2_key; - uint8_t buf[SM2_MAX_CIPHERTEXT_SIZE]; + sm2_z256_t k; + SM2_POINT C1; +} SM2_ENC_PRE_COMP; + +#define SM2_ENC_PRE_COMP_NUM 8 +int sm2_encrypt_pre_compute(SM2_ENC_PRE_COMP pre_comp[SM2_ENC_PRE_COMP_NUM]); +int sm2_do_encrypt_ex(const SM2_KEY *key, const SM2_ENC_PRE_COMP *pre_comp, + const uint8_t *in, size_t inlen, SM2_CIPHERTEXT *out); + +typedef struct { + SM2_ENC_PRE_COMP pre_comp[SM2_ENC_PRE_COMP_NUM]; + size_t pre_comp_num; + uint8_t buf[SM2_MAX_PLAINTEXT_SIZE]; size_t buf_size; } SM2_ENC_CTX; -_gmssl_export int sm2_encrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key); -_gmssl_export int sm2_encrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); -_gmssl_export int sm2_encrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); -_gmssl_export int sm2_decrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key); -_gmssl_export int sm2_decrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); -_gmssl_export int sm2_decrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); +_gmssl_export int sm2_encrypt_init(SM2_ENC_CTX *ctx); +_gmssl_export int sm2_encrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen); +_gmssl_export int sm2_encrypt_finish(SM2_ENC_CTX *ctx, const SM2_KEY *public_key, uint8_t *out, size_t *outlen); +_gmssl_export int sm2_encrypt_reset(SM2_ENC_CTX *ctx); + +typedef struct { + uint8_t buf[SM2_MAX_CIPHERTEXT_SIZE]; + size_t buf_size; +} SM2_DEC_CTX; + +_gmssl_export int sm2_decrypt_init(SM2_DEC_CTX *ctx); +_gmssl_export int sm2_decrypt_update(SM2_DEC_CTX *ctx, const uint8_t *in, size_t inlen); +_gmssl_export int sm2_decrypt_finish(SM2_DEC_CTX *ctx, const SM2_KEY *key, uint8_t *out, size_t *outlen); +_gmssl_export int sm2_decrypt_reset(SM2_DEC_CTX *ctx); #ifdef __cplusplus diff --git a/src/sm2_enc.c b/src/sm2_enc.c index a03e0e7b..3c4c2ea1 100644 --- a/src/sm2_enc.c +++ b/src/sm2_enc.c @@ -58,22 +58,118 @@ int sm2_kdf(const uint8_t *in, size_t inlen, size_t outlen, uint8_t *out) return 1; } -int sm2_do_encrypt_pre_compute(sm2_z256_t k, uint8_t C1[64]) +// use Montgomery's Trick to inverse Z coordinates on multiple (x1, y1) = [k]G +int sm2_encrypt_pre_compute(SM2_ENC_PRE_COMP pre_comp[SM2_ENC_PRE_COMP_NUM]) { - SM2_Z256_POINT P; + SM2_Z256_POINT P[SM2_ENC_PRE_COMP_NUM]; + sm2_z256_t f[SM2_ENC_PRE_COMP_NUM]; + sm2_z256_t g[SM2_ENC_PRE_COMP_NUM]; + int i; - // rand k in [1, n - 1] - do { - if (sm2_z256_rand_range(k, sm2_z256_order()) != 1) { - error_print(); - return -1; - } - } while (sm2_z256_is_zero(k)); + for (i = 0; i < SM2_ENC_PRE_COMP_NUM; i++) { - // output C1 = k * G = (x1, y1) - sm2_z256_point_mul_generator(&P, k); - sm2_z256_point_to_bytes(&P, C1); + // rand k in [1, n - 1] + do { + if (sm2_z256_rand_range(pre_comp[i].k, sm2_z256_order()) != 1) { + error_print(); + return -1; + } + } while (sm2_z256_is_zero(pre_comp[i].k)); + // (x1, y1) = kG + sm2_z256_point_mul_generator(&P[i], pre_comp[i].k); + } + + // f[0] = Z[0] + // f[1] = Z[0] * Z[1] + // ... + // f[31] = Z[0] * Z[1] * ... * Z[31] + sm2_z256_copy(f[0], P[0].Z); + for (i = 1; i < SM2_ENC_PRE_COMP_NUM; i++) { + sm2_z256_modp_mont_mul(f[i], f[i - 1], P[i].Z); + } + + // f[31]^-1 = (Z[0] * ... * Z[31])^-1 + sm2_z256_modp_mont_inv(f[SM2_ENC_PRE_COMP_NUM - 1], f[SM2_ENC_PRE_COMP_NUM - 1]); + + // g[31] = Z[31] + // g[30] = Z[30] * Z[31] + // ... + // g[1] = Z[1] * Z[2] * ... * Z[31] + // + sm2_z256_copy(g[SM2_ENC_PRE_COMP_NUM - 1], P[SM2_ENC_PRE_COMP_NUM - 1].Z); + for (i = SM2_ENC_PRE_COMP_NUM - 2; i >= 1; i--) { + sm2_z256_modp_mont_mul(g[i], g[i + 1], P[i].Z); + } + + // Z[0]^-1 = g[1] * f[31]^-1 + // Z[1]^-1 = g[2] * f[0] * f[31]^-1 + // Z[2]^-1 = g[3] * f[1] * f[31]^-1 + // ... + // Z[30]^-1 = g[31] * f[29] * f[31]^-1 + // Z[31]^-1 = f[30] * f[31]^-1 + sm2_z256_modp_mont_mul(P[0].Z, g[1], f[SM2_ENC_PRE_COMP_NUM - 1]); + for (i = 1; i < SM2_ENC_PRE_COMP_NUM - 1; i++) { + sm2_z256_modp_mont_mul(P[i].Z, g[i + 1], f[i - 1]); + sm2_z256_modp_mont_mul(P[i].Z, P[i].Z, f[SM2_ENC_PRE_COMP_NUM - 1]); + } + sm2_z256_modp_mont_mul(P[SM2_ENC_PRE_COMP_NUM - 1].Z, + f[SM2_ENC_PRE_COMP_NUM - 2], f[SM2_ENC_PRE_COMP_NUM - 1]); + + // y[i] = Y[i] * Z[i]^-3 (mod n) + // x[i] = X[i] * Z[i]^-2 (mod n) + for (i = 0; i < SM2_ENC_PRE_COMP_NUM; i++) { + + sm2_z256_modp_mont_mul(P[i].Y, P[i].Y, P[i].Z); + sm2_z256_modp_mont_sqr(P[i].Z, P[i].Z); + sm2_z256_modp_mont_mul(P[i].Y, P[i].Y, P[i].Z); + sm2_z256_modp_mont_mul(P[i].X, P[i].X, P[i].Z); + + sm2_z256_modp_from_mont(P[i].X, P[i].X); + sm2_z256_modp_from_mont(P[i].Y, P[i].Y); + + sm2_z256_to_bytes(P[i].X, pre_comp[i].C1.x); + sm2_z256_to_bytes(P[i].Y, pre_comp[i].C1.y); + } + + return 1; +} + +int sm2_do_encrypt_ex(const SM2_KEY *key, const SM2_ENC_PRE_COMP *pre_comp, + const uint8_t *in, size_t inlen, SM2_CIPHERTEXT *out) +{ + SM2_Z256_POINT kP; + uint8_t x2y2[64]; + SM3_CTX sm3_ctx; + + // output C1 + out->point = pre_comp->C1; + + // k * P = (x2, y2) + sm2_z256_point_mul(&kP, pre_comp->k, &key->public_key); + sm2_z256_point_to_bytes(&kP, x2y2); + + // t = KDF(x2 || y2, inlen) + sm2_kdf(x2y2, 64, inlen, out->ciphertext); + + // if t is all zero, return 0, caller should change pre_comp and retry + if (all_zero(out->ciphertext, inlen)) { + return 0; + } + + // output C2 = M xor t + gmssl_memxor(out->ciphertext, out->ciphertext, in, inlen); + out->ciphertext_size = (uint32_t)inlen; + + // output C3 = Hash(x2 || m || y2) + sm3_init(&sm3_ctx); + sm3_update(&sm3_ctx, x2y2, 32); + sm3_update(&sm3_ctx, in, inlen); + sm3_update(&sm3_ctx, x2y2 + 32, 32); + sm3_finish(&sm3_ctx, out->hash); + + gmssl_secure_clear(&kP, sizeof(SM2_Z256_POINT)); + gmssl_secure_clear(x2y2, sizeof(x2y2)); return 1; } @@ -422,22 +518,30 @@ int sm2_decrypt(const SM2_KEY *key, const uint8_t *in, size_t inlen, uint8_t *ou } return 1; } -int sm2_encrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key) +int sm2_encrypt_init(SM2_ENC_CTX *ctx) { - if (!ctx || !sm2_key) { + if (!ctx) { error_print(); return -1; } - memset(ctx, 0, sizeof(*ctx)); - ctx->sm2_key = *sm2_key; +#define ENABLE_SM2_ENC_PRE_COMPUTE 1 +#if ENABLE_SM2_ENC_PRE_COMPUTE + if (sm2_encrypt_pre_compute(ctx->pre_comp) != 1) { + error_print(); + return -1; + } + ctx->pre_comp_num = SM2_ENC_PRE_COMP_NUM; +#endif + + ctx->buf_size = 0; return 1; } -int sm2_encrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +int sm2_encrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen) { - if (!ctx || !outlen) { + if (!ctx) { error_print(); return -1; } @@ -447,11 +551,6 @@ int sm2_encrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_ return -1; } - if (!out) { - *outlen = 0; - return 1; - } - if (in) { if (inlen > SM2_MAX_PLAINTEXT_SIZE - ctx->buf_size) { error_print(); @@ -462,13 +561,14 @@ int sm2_encrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_ ctx->buf_size += inlen; } - *outlen = 0; return 1; } -int sm2_encrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +int sm2_encrypt_finish(SM2_ENC_CTX *ctx, const SM2_KEY *public_key, uint8_t *out, size_t *outlen) { - if (!ctx || !outlen) { + SM2_CIPHERTEXT ciphertext; + + if (!ctx || !public_key || !outlen) { error_print(); return -1; } @@ -477,55 +577,72 @@ int sm2_encrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_ error_print(); return -1; } + if (ctx->buf_size == 0) { + error_print(); + return -1; + } if (!out) { *outlen = SM2_MAX_CIPHERTEXT_SIZE; return 1; } - if (ctx->buf_size) { - if (in) { - if (inlen > SM2_MAX_PLAINTEXT_SIZE - ctx->buf_size) { - error_print(); - return -1; - } - memcpy(ctx->buf + ctx->buf_size, in, inlen); - ctx->buf_size += inlen; - } - if (sm2_encrypt(&ctx->sm2_key, ctx->buf, ctx->buf_size, out, outlen) != 1) { - error_print(); - return -1; - } - } else { - if (!in || !inlen || inlen > SM2_MAX_PLAINTEXT_SIZE) { - error_print(); - return -1; - } - if (sm2_encrypt(&ctx->sm2_key, in, inlen, out, outlen) != 1) { +#if ENABLE_SM2_ENC_PRE_COMPUTE + if (ctx->pre_comp_num == 0) { + if (sm2_encrypt_pre_compute(ctx->pre_comp) != 1) { error_print(); return -1; } + ctx->pre_comp_num = SM2_ENC_PRE_COMP_NUM; } - return 1; -} - -int sm2_decrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key) -{ - if (!ctx || !sm2_key) { + ctx->pre_comp_num--; + if (sm2_do_encrypt_ex(public_key, &ctx->pre_comp[ctx->pre_comp_num], ctx->buf, ctx->buf_size, &ciphertext) != 1) { error_print(); return -1; } - memset(ctx, 0, sizeof(*ctx)); - ctx->sm2_key = *sm2_key; + *outlen = 0; + if (sm2_ciphertext_to_der(&ciphertext, &out, outlen) != 1) { + error_print(); + return -1; + } +#else + if (sm2_encrypt(public_key, ctx->buf, ctx->buf_size, out, outlen) != 1) { + error_print(); + return -1; + } +#endif return 1; } -int sm2_decrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +int sm2_encrypt_reset(SM2_ENC_CTX *ctx) { - if (!ctx || !outlen) { + if (!ctx) { + error_print(); + return -1; + } + + ctx->buf_size = 0; + return 1; +} + +int sm2_decrypt_init(SM2_DEC_CTX *ctx) +{ + if (!ctx) { + error_print(); + return -1; + } + + ctx->buf_size = 0; + + return 1; +} + +int sm2_decrypt_update(SM2_DEC_CTX *ctx, const uint8_t *in, size_t inlen) +{ + if (!ctx) { error_print(); return -1; } @@ -535,11 +652,6 @@ int sm2_decrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_ return -1; } - if (!out) { - *outlen = 0; - return 1; - } - if (in) { if (inlen > SM2_MAX_CIPHERTEXT_SIZE - ctx->buf_size) { error_print(); @@ -550,13 +662,12 @@ int sm2_decrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_ ctx->buf_size += inlen; } - *outlen = 0; return 1; } -int sm2_decrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +int sm2_decrypt_finish(SM2_DEC_CTX *ctx, const SM2_KEY *key, uint8_t *out, size_t *outlen) { - if (!ctx || !outlen) { + if (!ctx || !key || !outlen) { error_print(); return -1; } @@ -565,35 +676,31 @@ int sm2_decrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_ error_print(); return -1; } + if (ctx->buf_size < SM2_MIN_CIPHERTEXT_SIZE) { + error_print(); + return -1; + } if (!out) { *outlen = SM2_MAX_PLAINTEXT_SIZE; return 1; } - if (ctx->buf_size) { - if (in) { - if (inlen > SM2_MAX_CIPHERTEXT_SIZE - ctx->buf_size) { - error_print(); - return -1; - } - memcpy(ctx->buf + ctx->buf_size, in, inlen); - ctx->buf_size += inlen; - } - if (sm2_decrypt(&ctx->sm2_key, ctx->buf, ctx->buf_size, out, outlen) != 1) { - error_print(); - return -1; - } - } else { - if (!in || !inlen || inlen > SM2_MAX_CIPHERTEXT_SIZE) { - error_print(); - return -1; - } - if (sm2_decrypt(&ctx->sm2_key, in, inlen, out, outlen) != 1) { - error_print(); - return -1; - } + if (sm2_decrypt(key, ctx->buf, ctx->buf_size, out, outlen) != 1) { + error_print(); + return -1; } return 1; } + +int sm2_decrypt_reset(SM2_DEC_CTX *ctx) +{ + if (!ctx) { + error_print(); + return -1; + } + + ctx->buf_size = 0; + return 1; +} diff --git a/tests/sm2_enctest.c b/tests/sm2_enctest.c index 64039a25..d575b536 100644 --- a/tests/sm2_enctest.c +++ b/tests/sm2_enctest.c @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -280,6 +281,42 @@ static int test_sm2_encrypt(void) return 1; } +static int test_sm2_encrypt_ctx_speed(void) +{ + SM2_KEY sm2_key; + SM2_ENC_CTX enc_ctx; + uint8_t plaintext[32]; + uint8_t ciphertext[SM2_MAX_CIPHERTEXT_SIZE]; + size_t ciphertext_len; + clock_t begin, end; + double seconds; + int i; + + sm2_key_generate(&sm2_key); + + if (sm2_encrypt_init(&enc_ctx) != 1) { + error_print(); + return -1; + } + + begin = clock(); + for (i = 0; i < 4096; i++) { + if (sm2_encrypt_update(&enc_ctx, plaintext, sizeof(plaintext)) != 1) { + error_print(); + return -1; + } + if (sm2_encrypt_finish(&enc_ctx, &sm2_key, ciphertext, &ciphertext_len) != 1) { + error_print(); + return -1; + } + sm2_encrypt_reset(&enc_ctx); + } + end = clock(); + seconds = (double)(end - begin)/CLOCKS_PER_SEC; + + printf("%s: %f encryptions per second\n", __FUNCTION__, 4096/seconds); + return 1; +} int main(void) @@ -289,6 +326,7 @@ int main(void) if (test_sm2_do_encrypt_fixlen() != 1) goto err; if (test_sm2_encrypt() != 1) goto err; if (test_sm2_encrypt_fixlen() != 1) goto err; + if (test_sm2_encrypt_ctx_speed() != 1) goto err; printf("%s all tests passed\n", __FILE__); return 0; err: diff --git a/tools/sm2decrypt.c b/tools/sm2decrypt.c index b4aac58d..9471e281 100644 --- a/tools/sm2decrypt.c +++ b/tools/sm2decrypt.c @@ -30,7 +30,7 @@ int sm2decrypt_main(int argc, char **argv) FILE *infp = stdin; FILE *outfp = stdout; SM2_KEY key; - SM2_ENC_CTX ctx; + SM2_DEC_CTX ctx; uint8_t inbuf[SM2_MAX_CIPHERTEXT_SIZE]; uint8_t outbuf[SM2_MAX_CIPHERTEXT_SIZE]; size_t inlen, outlen; @@ -103,11 +103,15 @@ bad: goto end; } - if (sm2_decrypt_init(&ctx, &key) != 1) { + if (sm2_decrypt_init(&ctx) != 1) { fprintf(stderr, "%s: sm2_decrypt_init failed\n", prog); goto end; } - if (sm2_decrypt_finish(&ctx, inbuf, inlen, outbuf, &outlen) != 1) { + if (sm2_decrypt_update(&ctx, inbuf, inlen) != 1) { + fprintf(stderr, "%s: sm2_decyrpt_update failed\n", prog); + goto end; + } + if (sm2_decrypt_finish(&ctx, &key, outbuf, &outlen) != 1) { fprintf(stderr, "%s: decryption failure\n", prog); goto end; } diff --git a/tools/sm2encrypt.c b/tools/sm2encrypt.c index 2f00842d..e8eec0a7 100644 --- a/tools/sm2encrypt.c +++ b/tools/sm2encrypt.c @@ -125,11 +125,15 @@ bad: goto end; } - if (sm2_encrypt_init(&ctx, &key) != 1) { + if (sm2_encrypt_init(&ctx) != 1) { fprintf(stderr, "%s: sm2_encrypt_init failed\n", prog); goto end; } - if (sm2_encrypt_finish(&ctx, inbuf, inlen, outbuf, &outlen) != 1) { + if (sm2_encrypt_update(&ctx, inbuf, inlen) != 1) { + fprintf(stderr, "%s: sm2_encrypt_update failed\n", prog); + return -1; + } + if (sm2_encrypt_finish(&ctx, &key, outbuf, &outlen) != 1) { fprintf(stderr, "%s: sm2_encrypt_finish error\n", prog); goto end; }