Use z256 and jacobian coordinates as inner presentation of SM2 point

This commit is contained in:
Zhi Guan
2024-04-18 21:27:00 +08:00
parent 549c68d2df
commit e9bbcf5490
18 changed files with 312 additions and 646 deletions

View File

@@ -23,7 +23,6 @@
int sm2_do_sign(const SM2_KEY *key, const uint8_t dgst[32], SM2_SIGNATURE *sig)
{
SM2_Z256_POINT P;
sm2_z256_t d;
sm2_z256_t d_inv;
sm2_z256_t e;
sm2_z256_t k;
@@ -32,10 +31,8 @@ int sm2_do_sign(const SM2_KEY *key, const uint8_t dgst[32], SM2_SIGNATURE *sig)
sm2_z256_t r;
sm2_z256_t s;
sm2_z256_from_bytes(d, key->private_key);
// compute (d + 1)^-1 (mod n)
sm2_z256_modn_add(d_inv, d, sm2_z256_one());
sm2_z256_modn_add(d_inv, key->private_key, sm2_z256_one());
if (sm2_z256_is_zero(d_inv)) {
error_print();
return -1;
@@ -75,7 +72,7 @@ retry:
}
// s = ((1 + d)^-1 * (k - r * d)) mod n
sm2_z256_modn_mul(t, r, d);
sm2_z256_modn_mul(t, r, key->private_key);
sm2_z256_modn_sub(k, k, t);
sm2_z256_modn_mul(s, d_inv, k);
@@ -87,14 +84,22 @@ retry:
sm2_z256_to_bytes(r, sig->r);
sm2_z256_to_bytes(s, sig->s);
gmssl_secure_clear(d, sizeof(d));
gmssl_secure_clear(d_inv, sizeof(d_inv));
gmssl_secure_clear(k, sizeof(k));
gmssl_secure_clear(t, sizeof(t));
return 1;
}
int sm2_do_sign_pre_compute(uint64_t k[4], uint64_t x1[4])
// d' = (d + 1)^-1 (mod n)
int sm2_fast_sign_compute_key(const SM2_KEY *key, sm2_z256_t fast_private)
{
sm2_z256_modn_add(fast_private, key->private_key, sm2_z256_one());
sm2_z256_modn_inv(fast_private, fast_private);
return 1;
}
// (x1, y1) = [k]G
int sm2_fast_sign_pre_compute(sm2_z256_t k, sm2_z256_t x1_modn)
{
SM2_Z256_POINT P;
@@ -108,12 +113,21 @@ int sm2_do_sign_pre_compute(uint64_t k[4], uint64_t x1[4])
// (x1, y1) = kG
sm2_z256_point_mul_generator(&P, k);
sm2_z256_point_get_xy(&P, x1, NULL);
sm2_z256_point_get_xy(&P, x1_modn, NULL);
// x1 mod n
if (sm2_z256_cmp(x1_modn, sm2_z256_order()) >= 0) {
sm2_z256_sub(x1_modn, x1_modn, sm2_z256_order());
}
return 1;
}
int sm2_do_sign_fast_ex(const uint64_t d[4], const uint64_t k[4], const uint64_t x1[4], const uint8_t dgst[32], SM2_SIGNATURE *sig)
// s = (k - r * d)/(1 + d)
// = -r + (k + r)*(1 + d)^-1
// = -r + (k + r) * d'
int sm2_fast_sign(const sm2_z256_t fast_private,
const sm2_z256_t k, const sm2_z256_t x1,
const uint8_t dgst[32], SM2_SIGNATURE *sig)
{
SM2_Z256_POINT R;
sm2_z256_t e;
@@ -131,7 +145,7 @@ int sm2_do_sign_fast_ex(const uint64_t d[4], const uint64_t k[4], const uint64_t
// s = (k + r) * d' - r
sm2_z256_modn_add(s, k, r);
sm2_z256_modn_mul(s, s, d);
sm2_z256_modn_mul(s, s, fast_private);
sm2_z256_modn_sub(s, s, r);
sm2_z256_to_bytes(r, sig->r);
@@ -140,67 +154,7 @@ int sm2_do_sign_fast_ex(const uint64_t d[4], const uint64_t k[4], const uint64_t
return 1;
}
// (x1, y1) = k * G
// r = e + x1
// s = (k - r * d)/(1 + d) = (k +r - r * d - r)/(1 + d) = (k + r - r(1 +d))/(1 + d) = (k + r)/(1 + d) - r
// = -r + (k + r)*(1 + d)^-1
// = -r + (k + r) * d'
int sm2_do_sign_fast(const uint64_t d[4], const uint8_t dgst[32], SM2_SIGNATURE *sig)
{
SM2_Z256_POINT R;
sm2_z256_t e;
sm2_z256_t k;
sm2_z256_t x1;
sm2_z256_t r;
sm2_z256_t s;
const uint64_t *order = sm2_z256_order();
// e = H(M)
sm2_z256_from_bytes(e, dgst);
if (sm2_z256_cmp(e, order) >= 0) {
sm2_z256_sub(e, e, order);
}
/// <<<<<<<<<<< 这里的 (k, x1) 应该是从外部输入的!!,这样才是最快的。
// rand k in [1, n - 1]
do {
if (sm2_z256_rand_range(k, sm2_z256_order()) != 1) {
error_print();
return -1;
}
} while (sm2_z256_is_zero(k));
// (x1, y1) = kG
sm2_z256_point_mul_generator(&R, k); // 这个函数要粗力度并行,这要怎么做?
sm2_z256_point_get_xy(&R, x1, NULL);
/// >>>>>>>>>>>>>>>>>>
// r = e + x1 (mod n)
sm2_z256_modn_add(r, e, x1);
// 对于快速实现来说,只需要一次乘法
// 如果 (k, x) 是预计算的,这意味着我们可以并行这个操作
// 也就是随机产生一些k然后执行粗力度并行的点乘
// s = (k + r) * d' - r
sm2_z256_modn_add(s, k, r);
sm2_z256_modn_mul(s, s, d);
sm2_z256_modn_sub(s, s, r);
sm2_z256_to_bytes(r, sig->r);
sm2_z256_to_bytes(s, sig->s);
return 1;
}
// 这个其实并没有更快无非就是降低了解析公钥椭圆曲线点的计算量这个点要转换为内部的Mont格式
// 这里根本没有modn的乘法
int sm2_do_verify_fast(const SM2_Z256_POINT *P, const uint8_t dgst[32], const SM2_SIGNATURE *sig)
int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATURE *sig)
{
SM2_Z256_POINT R;
sm2_z256_t r;
@@ -209,33 +163,26 @@ int sm2_do_verify_fast(const SM2_Z256_POINT *P, const uint8_t dgst[32], const SM
sm2_z256_t x;
sm2_z256_t t;
const uint64_t *order = sm2_z256_order();
// check r, s in [1, n-1]
sm2_z256_from_bytes(r, sig->r);
// check r in [1, n-1]
if (sm2_z256_is_zero(r) == 1) {
error_print();
return -1;
}
if (sm2_z256_cmp(r, order) >= 0) {
if (sm2_z256_cmp(r, sm2_z256_order()) >= 0) {
error_print();
return -1;
}
sm2_z256_from_bytes(s, sig->s);
// check s in [1, n-1]
if (sm2_z256_is_zero(s) == 1) {
error_print();
return -1;
}
if (sm2_z256_cmp(s, order) >= 0) {
if (sm2_z256_cmp(s, sm2_z256_order()) >= 0) {
error_print();
return -1;
}
// e = H(M)
sm2_z256_from_bytes(e, dgst);
// t = r + s (mod n), check t != 0
sm2_z256_modn_add(t, r, s);
if (sm2_z256_is_zero(t)) {
@@ -243,16 +190,19 @@ int sm2_do_verify_fast(const SM2_Z256_POINT *P, const uint8_t dgst[32], const SM
return -1;
}
// Q = s * G + t * P
sm2_z256_point_mul_sum(&R, t, P, s);
// Q(x,y) = s * G + t * P
sm2_z256_point_mul_sum(&R, t, &key->public_key, s);
sm2_z256_point_get_xy(&R, x, NULL);
// r' = e + x (mod n)
if (sm2_z256_cmp(e, order) >= 0) {
sm2_z256_sub(e, e, order);
// e = H(M)
sm2_z256_from_bytes(e, dgst);
if (sm2_z256_cmp(e, sm2_z256_order()) >= 0) {
sm2_z256_sub(e, e, sm2_z256_order());
}
if (sm2_z256_cmp(x, order) >= 0) {
sm2_z256_sub(x, x, order);
// r' = e + x (mod n)
if (sm2_z256_cmp(x, sm2_z256_order()) >= 0) {
sm2_z256_sub(x, x, sm2_z256_order());
}
sm2_z256_modn_add(e, e, x);
@@ -264,90 +214,6 @@ int sm2_do_verify_fast(const SM2_Z256_POINT *P, const uint8_t dgst[32], const SM
return 1;
}
int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATURE *sig)
{
SM2_Z256_POINT _P, *P = &_P;
SM2_Z256_POINT _R, *R = &_R;
sm2_z256_t r;
sm2_z256_t s;
sm2_z256_t e;
sm2_z256_t x;
sm2_z256_t t;
const uint64_t *order = sm2_z256_order();
sm2_z256_print(stderr, 0, 4, "n", order);
// parse public key
sm2_z256_point_from_bytes(P, (const uint8_t *)&key->public_key);
//sm2_z256_point_from_bytes(P, (const uint8_t *)&key->public_key);
//sm2_jacobian_point_print(stderr, 0, 4, "P", P);
// parse signature values
sm2_z256_from_bytes(r, sig->r); sm2_z256_print(stderr, 0, 4, "r", r);
sm2_z256_from_bytes(s, sig->s); sm2_z256_print(stderr, 0, 4, "s", s);
// check r, s in [1, n-1]
if (sm2_z256_is_zero(r) == 1) {
error_print();
return -1;
}
if (sm2_z256_cmp(r, order) >= 0) {
sm2_z256_print(stderr, 0, 4, "err: r", r);
sm2_z256_print(stderr, 0, 4, "err: order", order);
error_print();
return -1;
}
if (sm2_z256_is_zero(s) == 1) {
error_print();
return -1;
}
if (sm2_z256_cmp(s, order) >= 0) {
sm2_z256_print(stderr, 0, 4, "err: s", s);
sm2_z256_print(stderr, 0, 4, "err: order", order);
printf(">>>>>\n");
int r = sm2_z256_cmp(s, order);
fprintf(stderr, "cmp ret = %d\n", r);
printf(">>>>>\n");
error_print();
return -1;
}
// e = H(M)
sm2_z256_from_bytes(e, dgst); //sm2_bn_print(stderr, 0, 4, "e = H(M)", e);
// t = r + s (mod n), check t != 0
sm2_z256_modn_add(t, r, s); //sm2_bn_print(stderr, 0, 4, "t = r + s (mod n)", t);
if (sm2_z256_is_zero(t)) {
error_print();
return -1;
}
// Q = s * G + t * P
sm2_z256_point_mul_sum(R, t, P, s);
sm2_z256_point_get_xy(R, x, NULL);
//sm2_bn_print(stderr, 0, 4, "x", x);
// r' = e + x (mod n)
if (sm2_z256_cmp(e, order) >= 0) {
sm2_z256_sub(e, e, order);
}
if (sm2_z256_cmp(x, order) >= 0) {
sm2_z256_sub(x, x, order);
}
sm2_z256_modn_add(e, e, x); //sm2_bn_print(stderr, 0, 4, "e + x (mod n)", e);
// check if r == r'
if (sm2_z256_cmp(e, r) != 0) {
error_print();
return -1;
}
return 1;
}
int sm2_signature_to_der(const SM2_SIGNATURE *sig, uint8_t **out, size_t *outlen)
{
size_t len = 0;
@@ -483,7 +349,7 @@ int sm2_verify(const SM2_KEY *key, const uint8_t dgst[32], const uint8_t *sigbuf
return 1;
}
int sm2_compute_z(uint8_t z[32], const SM2_POINT *pub, const char *id, size_t idlen)
int sm2_compute_z(uint8_t z[32], const SM2_Z256_POINT *pub, const char *id, size_t idlen)
{
SM3_CTX ctx;
uint8_t zin[18 + 32 * 6] = {
@@ -504,8 +370,7 @@ int sm2_compute_z(uint8_t z[32], const SM2_POINT *pub, const char *id, size_t id
return -1;
}
memcpy(&zin[18 + 32 * 4], pub->x, 32);
memcpy(&zin[18 + 32 * 5], pub->y, 32);
sm2_z256_point_to_bytes(pub, &zin[18 + 32 * 4]);
sm3_init(&ctx);
if (strcmp(id, SM2_DEFAULT_ID) == 0) {
@@ -550,6 +415,7 @@ int sm2_kdf(const uint8_t *in, size_t inlen, size_t outlen, uint8_t *out)
return 1;
}
int sm2_sign_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_t idlen)
{
size_t i;
@@ -558,17 +424,11 @@ int sm2_sign_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_t
error_print();
return -1;
}
ctx->key = *key;
// d' = (d + 1)^-1 (mod n)
sm2_z256_from_bytes(ctx->sign_key, key->private_key);
sm2_z256_modn_add(ctx->sign_key, ctx->sign_key, sm2_z256_one());
sm2_z256_modn_inv(ctx->sign_key, ctx->sign_key);
sm3_init(&ctx->sm3_ctx);
if (id) {
uint8_t z[SM3_DIGEST_SIZE];
if (idlen <= 0 || idlen > SM2_MAX_ID_LENGTH) {
error_print();
return -1;
@@ -576,24 +436,26 @@ int sm2_sign_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_t
sm2_compute_z(z, &key->public_key, id, idlen);
sm3_update(&ctx->sm3_ctx, z, sizeof(z));
}
ctx->saved_sm3_ctx = ctx->sm3_ctx;
ctx->inited_sm3_ctx = ctx->sm3_ctx;
// pre compute (k, x = [k]G.x)
for (i = 0; i < 32; i++) {
if (sm2_do_sign_pre_compute(ctx->pre_comp[i].k, ctx->pre_comp[i].x1) != 1) {
for (i = 0; i < SM2_SIGN_PRE_COMP_COUNT; i++) {
if (sm2_fast_sign_pre_compute(ctx->pre_comp[i].k, ctx->pre_comp[i].x1) != 1) {
error_print();
return -1;
}
}
ctx->num_pre_comp = 32;
ctx->num_pre_comp = SM2_SIGN_PRE_COMP_COUNT;
// copy private key at last
ctx->key = *key;
sm2_fast_sign_compute_key(key, ctx->fast_sign_private);
return 1;
}
int sm2_sign_ctx_reset(SM2_SIGN_CTX *ctx)
{
ctx->sm3_ctx = ctx->inited_sm3_ctx;
ctx->sm3_ctx = ctx->saved_sm3_ctx;
return 1;
}
@@ -618,21 +480,22 @@ int sm2_sign_finish(SM2_SIGN_CTX *ctx, uint8_t *sig, size_t *siglen)
error_print();
return -1;
}
sm3_finish(&ctx->sm3_ctx, dgst);
if (ctx->num_pre_comp == 0) {
size_t i;
for (i = 0; i < 32; i++) {
if (sm2_do_sign_pre_compute(ctx->pre_comp[i].k, ctx->pre_comp[i].x1) != 1) {
for (i = 0; i < SM2_SIGN_PRE_COMP_COUNT; i++) {
if (sm2_fast_sign_pre_compute(ctx->pre_comp[i].k, ctx->pre_comp[i].x1) != 1) {
error_print();
return -1;
}
}
ctx->num_pre_comp = 32;
ctx->num_pre_comp = SM2_SIGN_PRE_COMP_COUNT;
}
ctx->num_pre_comp--;
if (sm2_do_sign_fast_ex(ctx->sign_key,
if (sm2_fast_sign(ctx->fast_sign_private,
ctx->pre_comp[ctx->num_pre_comp].k, ctx->pre_comp[ctx->num_pre_comp].x1,
dgst, &signature) != 1) {
error_print();
@@ -670,15 +533,11 @@ int sm2_verify_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_
error_print();
return -1;
}
memset(ctx, 0, sizeof(*ctx));
ctx->key.public_key = key->public_key;
sm2_z256_point_from_bytes((SM2_Z256_POINT *)&ctx->public_key, (const uint8_t *)&key->public_key);
sm3_init(&ctx->sm3_ctx);
if (id) {
uint8_t z[SM3_DIGEST_SIZE];
if (idlen <= 0 || idlen > SM2_MAX_ID_LENGTH) {
error_print();
return -1;
@@ -686,8 +545,15 @@ int sm2_verify_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_
sm2_compute_z(z, &key->public_key, id, idlen);
sm3_update(&ctx->sm3_ctx, z, sizeof(z));
}
ctx->saved_sm3_ctx = ctx->sm3_ctx;
ctx->inited_sm3_ctx = ctx->sm3_ctx;
if (sm2_key_set_public_key(&ctx->key, &key->public_key) != 1) {
error_print();
return -1;
}
sm2_z256_set_zero(ctx->fast_sign_private);
memset(ctx->pre_comp, 0, sizeof(SM2_SIGN_PRE_COMP) * SM2_SIGN_PRE_COMP_COUNT);
return 1;
}