diff --git a/include/gmssl/sm2.h b/include/gmssl/sm2.h index 5c4964cf..f14b583e 100644 --- a/include/gmssl/sm2.h +++ b/include/gmssl/sm2.h @@ -128,8 +128,16 @@ int sm2_do_sign(const SM2_KEY *key, const uint8_t dgst[32], SM2_SIGNATURE *sig); int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATURE *sig); int sm2_fast_sign_compute_key(const SM2_KEY *key, sm2_z256_t fast_private); -int sm2_fast_sign_pre_compute(sm2_z256_t k, sm2_z256_t x1_modn); -int sm2_fast_sign(const sm2_z256_t fast_private, const sm2_z256_t k, const sm2_z256_t x1_modn, + +typedef struct { + sm2_z256_t k; + sm2_z256_t x1_modn; +} SM2_SIGN_PRE_COMP; + +#define SM2_SIGN_PRE_COMP_COUNT 32 + +int sm2_fast_sign_pre_compute(SM2_SIGN_PRE_COMP pre_comp[32]); +int sm2_fast_sign(const sm2_z256_t fast_private, SM2_SIGN_PRE_COMP *pre_comp, const uint8_t dgst[32], SM2_SIGNATURE *sig); @@ -159,12 +167,6 @@ int sm2_sign_fixlen(const SM2_KEY *key, const uint8_t dgst[32], size_t siglen, u int sm2_compute_z(uint8_t z[32], const SM2_Z256_POINT *pub, const char *id, size_t idlen); -typedef struct { - sm2_z256_t k; - sm2_z256_t x1; // x1 (mod n) -} SM2_SIGN_PRE_COMP; - -#define SM2_SIGN_PRE_COMP_COUNT 32 typedef struct { SM3_CTX sm3_ctx; diff --git a/src/sm2_sign.c b/src/sm2_sign.c index 15e90061..fbcbf721 100644 --- a/src/sm2_sign.c +++ b/src/sm2_sign.c @@ -101,35 +101,81 @@ int sm2_fast_sign_compute_key(const SM2_KEY *key, sm2_z256_t fast_private) return 1; } -// (x1, y1) = [k]G -int sm2_fast_sign_pre_compute(sm2_z256_t k, sm2_z256_t x1_modn) +// use Montgomery's Trick to inverse Z coordinates on multiple (x1, y1) = [k]G +int sm2_fast_sign_pre_compute(SM2_SIGN_PRE_COMP pre_comp[32]) { - SM2_Z256_POINT P; + SM2_Z256_POINT P[32]; + sm2_z256_t f[32]; + sm2_z256_t g[32]; + int i; - // rand k in [1, n - 1] - do { - if (sm2_z256_rand_range(k, sm2_z256_order()) != 1) { - error_print(); - return -1; - } - } while (sm2_z256_is_zero(k)); + for (i = 0; i < 32; i++) { - // (x1, y1) = kG - sm2_z256_point_mul_generator(&P, k); - sm2_z256_point_get_xy(&P, x1_modn, NULL); + // rand k in [1, n - 1] + do { + if (sm2_z256_rand_range(pre_comp[i].k, sm2_z256_order()) != 1) { + error_print(); + return -1; + } + } while (sm2_z256_is_zero(pre_comp[i].k)); - // x1 mod n - if (sm2_z256_cmp(x1_modn, sm2_z256_order()) >= 0) { - sm2_z256_sub(x1_modn, x1_modn, sm2_z256_order()); + // (x1, y1) = kG + sm2_z256_point_mul_generator(&P[i], pre_comp[i].k); } + + // f[0] = Z[0] + // f[1] = Z[0] * Z[1] + // ... + // f[31] = Z[0] * Z[1] * ... * Z[31] + sm2_z256_copy(f[0], P[0].Z); + for (i = 1; i < 32; i++) { + sm2_z256_modp_mont_mul(f[i], f[i - 1], P[i].Z); + } + + // f[31]^-1 = (Z[0] * ... * Z[31])^-1 + sm2_z256_modp_mont_inv(f[31], f[31]); + + // g[31] = Z[31] + // g[30] = Z[30] * Z[31] + // ... + // g[1] = Z[1] * Z[2] * ... * Z[31] + // + sm2_z256_copy(g[31], P[31].Z); + for (i = 30; i >= 1; i--) { + sm2_z256_modp_mont_mul(g[i], g[i + 1], P[i].Z); + } + + // Z[0]^-1 = g[1] * f[31]^-1 + // Z[1]^-1 = g[2] * f[0] * f[31]^-1 + // Z[2]^-1 = g[3] * f[1] * f[31]^-1 + // ... + // Z[30]^-1 = g[31] * f[29] * f[31]^-1 + // Z[31]^-1 = f[30] * f[31]^-1 + sm2_z256_modp_mont_mul(P[0].Z, g[1], f[31]); + for (i = 1; i <= 30; i++) { + sm2_z256_modp_mont_mul(P[i].Z, g[i + 1], f[i - 1]); + sm2_z256_modp_mont_mul(P[i].Z, P[i].Z, f[31]); + } + sm2_z256_modp_mont_mul(P[31].Z, f[30], f[31]); + + // x[i] = X[i] * Z[i]^-2 (mod n) + for (i = 0; i < 32; i++) { + sm2_z256_modp_mont_sqr(P[i].Z, P[i].Z); + sm2_z256_modp_mont_mul(pre_comp[i].x1_modn, P[i].X, P[i].Z); + sm2_z256_modp_from_mont(pre_comp[i].x1_modn, pre_comp[i].x1_modn); + if (sm2_z256_cmp(pre_comp[i].x1_modn, sm2_z256_order()) >= 0) { + sm2_z256_sub(pre_comp[i].x1_modn, pre_comp[i].x1_modn, sm2_z256_order()); + } + } + return 1; } + // s = (k - r * d)/(1 + d) // = -r + (k + r)*(1 + d)^-1 // = -r + (k + r) * d' -int sm2_fast_sign(const sm2_z256_t fast_private, - const sm2_z256_t k, const sm2_z256_t x1_modn, +int sm2_fast_sign(const sm2_z256_t fast_private, SM2_SIGN_PRE_COMP *pre_comp, const uint8_t dgst[32], SM2_SIGNATURE *sig) { SM2_Z256_POINT R; @@ -144,10 +190,10 @@ int sm2_fast_sign(const sm2_z256_t fast_private, } // r = e + x1 (mod n) - sm2_z256_modn_add(r, e, x1_modn); + sm2_z256_modn_add(r, e, pre_comp->x1_modn); // s = (k + r) * d' - r - sm2_z256_modn_add(s, k, r); + sm2_z256_modn_add(s, pre_comp->k, r); sm2_z256_modn_mul(s, s, fast_private); sm2_z256_modn_sub(s, s, r); @@ -412,11 +458,9 @@ int sm2_sign_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_t } ctx->saved_sm3_ctx = ctx->sm3_ctx; - for (i = 0; i < SM2_SIGN_PRE_COMP_COUNT; i++) { - if (sm2_fast_sign_pre_compute(ctx->pre_comp[i].k, ctx->pre_comp[i].x1) != 1) { - error_print(); - return -1; - } + if (sm2_fast_sign_pre_compute(ctx->pre_comp) != 1) { + error_print(); + return -1; } ctx->num_pre_comp = SM2_SIGN_PRE_COMP_COUNT; @@ -458,19 +502,15 @@ int sm2_sign_finish(SM2_SIGN_CTX *ctx, uint8_t *sig, size_t *siglen) sm3_finish(&ctx->sm3_ctx, dgst); if (ctx->num_pre_comp == 0) { - size_t i; - for (i = 0; i < SM2_SIGN_PRE_COMP_COUNT; i++) { - if (sm2_fast_sign_pre_compute(ctx->pre_comp[i].k, ctx->pre_comp[i].x1) != 1) { - error_print(); - return -1; - } + if (sm2_fast_sign_pre_compute(ctx->pre_comp) != 1) { + error_print(); + return -1; } ctx->num_pre_comp = SM2_SIGN_PRE_COMP_COUNT; } ctx->num_pre_comp--; - if (sm2_fast_sign(ctx->fast_sign_private, - ctx->pre_comp[ctx->num_pre_comp].k, ctx->pre_comp[ctx->num_pre_comp].x1, + if (sm2_fast_sign(ctx->fast_sign_private, &ctx->pre_comp[ctx->num_pre_comp], dgst, &signature) != 1) { error_print(); return -1; diff --git a/tests/sm2_signtest.c b/tests/sm2_signtest.c index f14523d4..3a1d163d 100644 --- a/tests/sm2_signtest.c +++ b/tests/sm2_signtest.c @@ -106,6 +106,7 @@ static int test_sm2_fast_sign(void) { SM2_KEY sm2_key; sm2_z256_t fast_private; + SM2_SIGN_PRE_COMP pre_comp[32]; uint8_t dgst[32]; SM2_SIGNATURE sig; size_t i; @@ -118,17 +119,15 @@ static int test_sm2_fast_sign(void) error_print(); return -1; } + if (sm2_fast_sign_pre_compute(pre_comp) != 1) { + error_print(); + return -1; + } rand_bytes(dgst, sizeof(dgst)); - for (i = 0; i < TEST_COUNT; i++) { - sm2_z256_t k; - sm2_z256_t x1_modn; + for (i = 0; i < TEST_COUNT && i < sizeof(pre_comp)/sizeof(pre_comp[0]); i++) { - if (sm2_fast_sign_pre_compute(k, x1_modn) != 1) { - error_print(); - return -1; - } - if (sm2_fast_sign(fast_private, k, x1_modn, dgst, &sig) != 1) { + if (sm2_fast_sign(fast_private, &pre_comp[i], dgst, &sig) != 1) { error_print(); return -1; } @@ -143,45 +142,6 @@ static int test_sm2_fast_sign(void) return 1; } -static int test_sm2_do_sign_pre_compute(void) -{ - SM2_KEY sm2_key; - uint64_t d[4]; - - uint64_t k[4]; - uint64_t x1[4]; - uint8_t dgst[32]; - SM2_SIGNATURE sig; - - - sm2_key_generate(&sm2_key); - - const uint64_t *one = sm2_z256_one(); - sm2_z256_copy(d, sm2_key.private_key); - sm2_z256_modn_add(d, d, one); - sm2_z256_modn_inv(d, d); - - if (sm2_fast_sign_pre_compute(k, x1) != 1) { - error_print(); - return -1; - } - - rand_bytes(dgst, sizeof(dgst)); - - if (sm2_fast_sign(d, k, x1, dgst, &sig) != 1) { - error_print(); - return -1; - } - - if (sm2_do_verify(&sm2_key, dgst, &sig) != 1) { - error_print(); - return -1; - } - - printf("%s() ok\n", __FUNCTION__); - return 1; -} - static int test_sm2_sign(void) { SM2_KEY sm2_key; @@ -334,7 +294,6 @@ int main(void) { if (test_sm2_signature() != 1) goto err; if (test_sm2_do_sign() != 1) goto err; - if (test_sm2_do_sign_pre_compute() != 1) goto err; if (test_sm2_fast_sign() != 1) goto err; if (test_sm2_sign() != 1) goto err; if (test_sm2_sign_ctx() != 1) goto err;