Add sm2_encrypt_pre_compute

This commit is contained in:
Zhi Guan
2024-04-25 16:34:03 +08:00
parent f0859a1f04
commit 3f1fdc147a
5 changed files with 270 additions and 98 deletions

View File

@@ -58,22 +58,118 @@ int sm2_kdf(const uint8_t *in, size_t inlen, size_t outlen, uint8_t *out)
return 1;
}
int sm2_do_encrypt_pre_compute(sm2_z256_t k, uint8_t C1[64])
// use Montgomery's Trick to inverse Z coordinates on multiple (x1, y1) = [k]G
int sm2_encrypt_pre_compute(SM2_ENC_PRE_COMP pre_comp[SM2_ENC_PRE_COMP_NUM])
{
SM2_Z256_POINT P;
SM2_Z256_POINT P[SM2_ENC_PRE_COMP_NUM];
sm2_z256_t f[SM2_ENC_PRE_COMP_NUM];
sm2_z256_t g[SM2_ENC_PRE_COMP_NUM];
int i;
// rand k in [1, n - 1]
do {
if (sm2_z256_rand_range(k, sm2_z256_order()) != 1) {
error_print();
return -1;
}
} while (sm2_z256_is_zero(k));
for (i = 0; i < SM2_ENC_PRE_COMP_NUM; i++) {
// output C1 = k * G = (x1, y1)
sm2_z256_point_mul_generator(&P, k);
sm2_z256_point_to_bytes(&P, C1);
// rand k in [1, n - 1]
do {
if (sm2_z256_rand_range(pre_comp[i].k, sm2_z256_order()) != 1) {
error_print();
return -1;
}
} while (sm2_z256_is_zero(pre_comp[i].k));
// (x1, y1) = kG
sm2_z256_point_mul_generator(&P[i], pre_comp[i].k);
}
// f[0] = Z[0]
// f[1] = Z[0] * Z[1]
// ...
// f[31] = Z[0] * Z[1] * ... * Z[31]
sm2_z256_copy(f[0], P[0].Z);
for (i = 1; i < SM2_ENC_PRE_COMP_NUM; i++) {
sm2_z256_modp_mont_mul(f[i], f[i - 1], P[i].Z);
}
// f[31]^-1 = (Z[0] * ... * Z[31])^-1
sm2_z256_modp_mont_inv(f[SM2_ENC_PRE_COMP_NUM - 1], f[SM2_ENC_PRE_COMP_NUM - 1]);
// g[31] = Z[31]
// g[30] = Z[30] * Z[31]
// ...
// g[1] = Z[1] * Z[2] * ... * Z[31]
//
sm2_z256_copy(g[SM2_ENC_PRE_COMP_NUM - 1], P[SM2_ENC_PRE_COMP_NUM - 1].Z);
for (i = SM2_ENC_PRE_COMP_NUM - 2; i >= 1; i--) {
sm2_z256_modp_mont_mul(g[i], g[i + 1], P[i].Z);
}
// Z[0]^-1 = g[1] * f[31]^-1
// Z[1]^-1 = g[2] * f[0] * f[31]^-1
// Z[2]^-1 = g[3] * f[1] * f[31]^-1
// ...
// Z[30]^-1 = g[31] * f[29] * f[31]^-1
// Z[31]^-1 = f[30] * f[31]^-1
sm2_z256_modp_mont_mul(P[0].Z, g[1], f[SM2_ENC_PRE_COMP_NUM - 1]);
for (i = 1; i < SM2_ENC_PRE_COMP_NUM - 1; i++) {
sm2_z256_modp_mont_mul(P[i].Z, g[i + 1], f[i - 1]);
sm2_z256_modp_mont_mul(P[i].Z, P[i].Z, f[SM2_ENC_PRE_COMP_NUM - 1]);
}
sm2_z256_modp_mont_mul(P[SM2_ENC_PRE_COMP_NUM - 1].Z,
f[SM2_ENC_PRE_COMP_NUM - 2], f[SM2_ENC_PRE_COMP_NUM - 1]);
// y[i] = Y[i] * Z[i]^-3 (mod n)
// x[i] = X[i] * Z[i]^-2 (mod n)
for (i = 0; i < SM2_ENC_PRE_COMP_NUM; i++) {
sm2_z256_modp_mont_mul(P[i].Y, P[i].Y, P[i].Z);
sm2_z256_modp_mont_sqr(P[i].Z, P[i].Z);
sm2_z256_modp_mont_mul(P[i].Y, P[i].Y, P[i].Z);
sm2_z256_modp_mont_mul(P[i].X, P[i].X, P[i].Z);
sm2_z256_modp_from_mont(P[i].X, P[i].X);
sm2_z256_modp_from_mont(P[i].Y, P[i].Y);
sm2_z256_to_bytes(P[i].X, pre_comp[i].C1.x);
sm2_z256_to_bytes(P[i].Y, pre_comp[i].C1.y);
}
return 1;
}
int sm2_do_encrypt_ex(const SM2_KEY *key, const SM2_ENC_PRE_COMP *pre_comp,
const uint8_t *in, size_t inlen, SM2_CIPHERTEXT *out)
{
SM2_Z256_POINT kP;
uint8_t x2y2[64];
SM3_CTX sm3_ctx;
// output C1
out->point = pre_comp->C1;
// k * P = (x2, y2)
sm2_z256_point_mul(&kP, pre_comp->k, &key->public_key);
sm2_z256_point_to_bytes(&kP, x2y2);
// t = KDF(x2 || y2, inlen)
sm2_kdf(x2y2, 64, inlen, out->ciphertext);
// if t is all zero, return 0, caller should change pre_comp and retry
if (all_zero(out->ciphertext, inlen)) {
return 0;
}
// output C2 = M xor t
gmssl_memxor(out->ciphertext, out->ciphertext, in, inlen);
out->ciphertext_size = (uint32_t)inlen;
// output C3 = Hash(x2 || m || y2)
sm3_init(&sm3_ctx);
sm3_update(&sm3_ctx, x2y2, 32);
sm3_update(&sm3_ctx, in, inlen);
sm3_update(&sm3_ctx, x2y2 + 32, 32);
sm3_finish(&sm3_ctx, out->hash);
gmssl_secure_clear(&kP, sizeof(SM2_Z256_POINT));
gmssl_secure_clear(x2y2, sizeof(x2y2));
return 1;
}
@@ -422,22 +518,30 @@ int sm2_decrypt(const SM2_KEY *key, const uint8_t *in, size_t inlen, uint8_t *ou
}
return 1;
}
int sm2_encrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key)
int sm2_encrypt_init(SM2_ENC_CTX *ctx)
{
if (!ctx || !sm2_key) {
if (!ctx) {
error_print();
return -1;
}
memset(ctx, 0, sizeof(*ctx));
ctx->sm2_key = *sm2_key;
#define ENABLE_SM2_ENC_PRE_COMPUTE 1
#if ENABLE_SM2_ENC_PRE_COMPUTE
if (sm2_encrypt_pre_compute(ctx->pre_comp) != 1) {
error_print();
return -1;
}
ctx->pre_comp_num = SM2_ENC_PRE_COMP_NUM;
#endif
ctx->buf_size = 0;
return 1;
}
int sm2_encrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
int sm2_encrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen)
{
if (!ctx || !outlen) {
if (!ctx) {
error_print();
return -1;
}
@@ -447,11 +551,6 @@ int sm2_encrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_
return -1;
}
if (!out) {
*outlen = 0;
return 1;
}
if (in) {
if (inlen > SM2_MAX_PLAINTEXT_SIZE - ctx->buf_size) {
error_print();
@@ -462,13 +561,14 @@ int sm2_encrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_
ctx->buf_size += inlen;
}
*outlen = 0;
return 1;
}
int sm2_encrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
int sm2_encrypt_finish(SM2_ENC_CTX *ctx, const SM2_KEY *public_key, uint8_t *out, size_t *outlen)
{
if (!ctx || !outlen) {
SM2_CIPHERTEXT ciphertext;
if (!ctx || !public_key || !outlen) {
error_print();
return -1;
}
@@ -477,55 +577,72 @@ int sm2_encrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_
error_print();
return -1;
}
if (ctx->buf_size == 0) {
error_print();
return -1;
}
if (!out) {
*outlen = SM2_MAX_CIPHERTEXT_SIZE;
return 1;
}
if (ctx->buf_size) {
if (in) {
if (inlen > SM2_MAX_PLAINTEXT_SIZE - ctx->buf_size) {
error_print();
return -1;
}
memcpy(ctx->buf + ctx->buf_size, in, inlen);
ctx->buf_size += inlen;
}
if (sm2_encrypt(&ctx->sm2_key, ctx->buf, ctx->buf_size, out, outlen) != 1) {
error_print();
return -1;
}
} else {
if (!in || !inlen || inlen > SM2_MAX_PLAINTEXT_SIZE) {
error_print();
return -1;
}
if (sm2_encrypt(&ctx->sm2_key, in, inlen, out, outlen) != 1) {
#if ENABLE_SM2_ENC_PRE_COMPUTE
if (ctx->pre_comp_num == 0) {
if (sm2_encrypt_pre_compute(ctx->pre_comp) != 1) {
error_print();
return -1;
}
ctx->pre_comp_num = SM2_ENC_PRE_COMP_NUM;
}
return 1;
}
int sm2_decrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key)
{
if (!ctx || !sm2_key) {
ctx->pre_comp_num--;
if (sm2_do_encrypt_ex(public_key, &ctx->pre_comp[ctx->pre_comp_num], ctx->buf, ctx->buf_size, &ciphertext) != 1) {
error_print();
return -1;
}
memset(ctx, 0, sizeof(*ctx));
ctx->sm2_key = *sm2_key;
*outlen = 0;
if (sm2_ciphertext_to_der(&ciphertext, &out, outlen) != 1) {
error_print();
return -1;
}
#else
if (sm2_encrypt(public_key, ctx->buf, ctx->buf_size, out, outlen) != 1) {
error_print();
return -1;
}
#endif
return 1;
}
int sm2_decrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
int sm2_encrypt_reset(SM2_ENC_CTX *ctx)
{
if (!ctx || !outlen) {
if (!ctx) {
error_print();
return -1;
}
ctx->buf_size = 0;
return 1;
}
int sm2_decrypt_init(SM2_DEC_CTX *ctx)
{
if (!ctx) {
error_print();
return -1;
}
ctx->buf_size = 0;
return 1;
}
int sm2_decrypt_update(SM2_DEC_CTX *ctx, const uint8_t *in, size_t inlen)
{
if (!ctx) {
error_print();
return -1;
}
@@ -535,11 +652,6 @@ int sm2_decrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_
return -1;
}
if (!out) {
*outlen = 0;
return 1;
}
if (in) {
if (inlen > SM2_MAX_CIPHERTEXT_SIZE - ctx->buf_size) {
error_print();
@@ -550,13 +662,12 @@ int sm2_decrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_
ctx->buf_size += inlen;
}
*outlen = 0;
return 1;
}
int sm2_decrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
int sm2_decrypt_finish(SM2_DEC_CTX *ctx, const SM2_KEY *key, uint8_t *out, size_t *outlen)
{
if (!ctx || !outlen) {
if (!ctx || !key || !outlen) {
error_print();
return -1;
}
@@ -565,35 +676,31 @@ int sm2_decrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_
error_print();
return -1;
}
if (ctx->buf_size < SM2_MIN_CIPHERTEXT_SIZE) {
error_print();
return -1;
}
if (!out) {
*outlen = SM2_MAX_PLAINTEXT_SIZE;
return 1;
}
if (ctx->buf_size) {
if (in) {
if (inlen > SM2_MAX_CIPHERTEXT_SIZE - ctx->buf_size) {
error_print();
return -1;
}
memcpy(ctx->buf + ctx->buf_size, in, inlen);
ctx->buf_size += inlen;
}
if (sm2_decrypt(&ctx->sm2_key, ctx->buf, ctx->buf_size, out, outlen) != 1) {
error_print();
return -1;
}
} else {
if (!in || !inlen || inlen > SM2_MAX_CIPHERTEXT_SIZE) {
error_print();
return -1;
}
if (sm2_decrypt(&ctx->sm2_key, in, inlen, out, outlen) != 1) {
error_print();
return -1;
}
if (sm2_decrypt(key, ctx->buf, ctx->buf_size, out, outlen) != 1) {
error_print();
return -1;
}
return 1;
}
int sm2_decrypt_reset(SM2_DEC_CTX *ctx)
{
if (!ctx) {
error_print();
return -1;
}
ctx->buf_size = 0;
return 1;
}