From c93206922488c02865b332af0d2c06d5156c18e1 Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Fri, 13 Jan 2023 01:44:55 +0800 Subject: [PATCH] Update SM2 lib Add more checks Rename _ex functions to _fixlen Move extensions to sm2_recover.c --- include/gmssl/sm2.h | 21 +- include/gmssl/sm2_recover.h | 2 +- src/cms.c | 9 +- src/sm2_lib.c | 510 ++++++++++++++++++++++-------------- src/sm2_recover.c | 17 ++ 5 files changed, 350 insertions(+), 209 deletions(-) diff --git a/include/gmssl/sm2.h b/include/gmssl/sm2.h index a4f63abc..c9ad12cb 100644 --- a/include/gmssl/sm2.h +++ b/include/gmssl/sm2.h @@ -290,18 +290,21 @@ int sm2_do_sign(const SM2_KEY *key, const uint8_t dgst[32], SM2_SIGNATURE *sig); int sm2_do_sign_fast(const SM2_Fn d, const uint8_t dgst[32], SM2_SIGNATURE *sig); int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATURE *sig); + #define SM2_MIN_SIGNATURE_SIZE 8 #define SM2_MAX_SIGNATURE_SIZE 72 int sm2_signature_to_der(const SM2_SIGNATURE *sig, uint8_t **out, size_t *outlen); int sm2_signature_from_der(SM2_SIGNATURE *sig, const uint8_t **in, size_t *inlen); -int sm2_signature_to_public_key_points(const SM2_SIGNATURE *sig, const uint8_t dgst[32], - SM2_POINT points[4], size_t *points_cnt); - int sm2_signature_print(FILE *fp, int fmt, int ind, const char *label, const uint8_t *sig, size_t siglen); -int sm2_sign_ex(const SM2_KEY *key, int flags, const uint8_t dgst[32], uint8_t *sig, size_t *siglen); int sm2_sign(const SM2_KEY *key, const uint8_t dgst[32], uint8_t *sig, size_t *siglen); int sm2_verify(const SM2_KEY *key, const uint8_t dgst[32], const uint8_t *sig, size_t siglen); +enum { + SM2_signature_compact_size = 70, + SM2_signature_typical_size = 71, + SM2_signature_max_size = 72, +}; +int sm2_sign_fixlen(const SM2_KEY *key, const uint8_t dgst[32], size_t siglen, uint8_t *sig); #define SM2_DEFAULT_ID "1234567812345678" #define SM2_DEFAULT_ID_LENGTH (sizeof(SM2_DEFAULT_ID) - 1) // LENGTH for string and SIZE for bytes @@ -342,7 +345,6 @@ typedef struct { uint8_t ciphertext[SM2_MAX_PLAINTEXT_SIZE]; } SM2_CIPHERTEXT; -int sm2_do_encrypt_ex(const SM2_KEY *key, int flags, const uint8_t *in, size_t inlen, SM2_CIPHERTEXT *out); int sm2_do_encrypt(const SM2_KEY *key, const uint8_t *in, size_t inlen, SM2_CIPHERTEXT *out); int sm2_do_decrypt(const SM2_KEY *key, const SM2_CIPHERTEXT *in, uint8_t *out, size_t *outlen); @@ -351,10 +353,17 @@ int sm2_do_decrypt(const SM2_KEY *key, const SM2_CIPHERTEXT *in, uint8_t *out, s int sm2_ciphertext_to_der(const SM2_CIPHERTEXT *c, uint8_t **out, size_t *outlen); int sm2_ciphertext_from_der(SM2_CIPHERTEXT *c, const uint8_t **in, size_t *inlen); int sm2_ciphertext_print(FILE *fp, int fmt, int ind, const char *label, const uint8_t *a, size_t alen); -int sm2_encrypt_ex(const SM2_KEY *key, int flags, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); int sm2_encrypt(const SM2_KEY *key, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); int sm2_decrypt(const SM2_KEY *key, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); +enum { + SM2_ciphertext_compact_point_size = 68, + SM2_ciphertext_typical_point_size = 69, + SM2_ciphertext_max_point_size = 70, +}; +int sm2_do_encrypt_fixlen(const SM2_KEY *key, const uint8_t *in, size_t inlen, int point_size, SM2_CIPHERTEXT *out); +int sm2_encrypt_fixlen(const SM2_KEY *key, const uint8_t *in, size_t inlen, int point_size, uint8_t *out, size_t *outlen); + int sm2_ecdh(const SM2_KEY *key, const SM2_POINT *peer_public, SM2_POINT *out); diff --git a/include/gmssl/sm2_recover.h b/include/gmssl/sm2_recover.h index b89807cd..92d2c01d 100644 --- a/include/gmssl/sm2_recover.h +++ b/include/gmssl/sm2_recover.h @@ -23,7 +23,7 @@ extern "C" { int sm2_signature_to_public_key_points(const SM2_SIGNATURE *sig, const uint8_t dgst[32], SM2_POINT points[4], size_t *points_cnt); - +int sm2_signature_conjugate(const SM2_SIGNATURE *sig, SM2_SIGNATURE *new_sig); #ifdef __cplusplus } diff --git a/src/cms.c b/src/cms.c index 3383343e..fb183d03 100644 --- a/src/cms.c +++ b/src/cms.c @@ -746,14 +746,12 @@ int cms_signer_info_sign_to_der( int fixed_outlen = 1; uint8_t dgst[SM3_DIGEST_SIZE]; uint8_t sig[SM2_MAX_SIGNATURE_SIZE]; - size_t siglen; + size_t siglen = SM2_signature_typical_size; sm3_update(&ctx, authed_attrs, authed_attrs_len); sm3_finish(&ctx, dgst); - - - if (sm2_sign_ex(sign_key, fixed_outlen, dgst, sig, &siglen) != 1) { + if (sm2_sign_fixlen(sign_key, dgst, siglen, sig) != 1) { error_print(); return -1; } @@ -1311,7 +1309,8 @@ int cms_recipient_info_encrypt_to_der( return -1; } - if (sm2_encrypt_ex(public_key, fixed_outlen, in, inlen, enced_key, &enced_key_len) != 1) { + if (sm2_encrypt_fixlen(public_key, in, inlen, SM2_ciphertext_typical_point_size, + enced_key, &enced_key_len) != 1) { error_print(); return -1; } diff --git a/src/sm2_lib.c b/src/sm2_lib.c index 7d26e626..4767763a 100644 --- a/src/sm2_lib.c +++ b/src/sm2_lib.c @@ -23,82 +23,82 @@ extern const SM2_BN SM2_N; extern const SM2_BN SM2_ONE; -#define print_bn(str,a) sm2_bn_print(stderr,0,4,str,a) - -int sm2_do_sign_ex(const SM2_KEY *key, int fixed_outlen, const uint8_t dgst[32], SM2_SIGNATURE *sig) +int sm2_do_sign(const SM2_KEY *key, const uint8_t dgst[32], SM2_SIGNATURE *sig) { SM2_JACOBIAN_POINT _P, *P = &_P; SM2_BN d; + SM2_BN d_inv; SM2_BN e; SM2_BN k; SM2_BN x; + SM2_BN t; SM2_BN r; SM2_BN s; -retry: + //fprintf(stderr, "sm2_do_sign\n"); sm2_bn_from_bytes(d, key->private_key); - // e = H(M) - sm2_bn_from_bytes(e, dgst); //print_bn("e", e); - // e被重用了,注意retry的位置! + // compute (d + 1)^-1 (mod n) + sm2_fn_add(d_inv, d, SM2_ONE); //sm2_bn_print(stderr, 0, 4, "(1+d)", d_inv); + if (sm2_bn_is_zero(d_inv)) { + error_print(); + return -1; + } + sm2_fn_inv(d_inv, d_inv); //sm2_bn_print(stderr, 0, 4, "(1+d)^-1", d_inv); + // e = H(M) + sm2_bn_from_bytes(e, dgst); //sm2_bn_print(stderr, 0, 4, "e", e); + +retry: // rand k in [1, n - 1] do { - sm2_fn_rand(k); - } while (sm2_bn_is_zero(k)); - //print_bn("k", k); + if (sm2_fn_rand(k) != 1) { + error_print(); + return -1; + } + } while (sm2_bn_is_zero(k)); //sm2_bn_print(stderr, 0, 4, "k", k); // (x, y) = kG sm2_jacobian_point_mul_generator(P, k); sm2_jacobian_point_get_xy(P, x, NULL); - //print_bn("x", x); - + //sm2_bn_print(stderr, 0, 4, "x", x); // r = e + x (mod n) - sm2_fn_add(r, e, x); //print_bn("r = e + x (mod n)", r); - - /* if r == 0 or r + k == n re-generate k */ - if (sm2_bn_is_zero(r)) { - goto retry; + if (sm2_bn_cmp(e, SM2_N) >= 0) { + sm2_bn_sub(e, e, SM2_N); } - sm2_bn_add(x, r, k); - if (sm2_bn_cmp(x, SM2_N) == 0) { + if (sm2_bn_cmp(x, SM2_N) >= 0) { + sm2_bn_sub(x, x, SM2_N); + } + sm2_fn_add(r, e, x); //sm2_bn_print(stderr, 0, 4, "r = e + x (mod n)", r); + + // if r == 0 or r + k == n re-generate k + sm2_bn_add(t, r, k); + if (sm2_bn_is_zero(r) || sm2_bn_cmp(t, SM2_N) == 0) { + //sm2_bn_print(stderr, 0, 4, "r + k", t); goto retry; } - /* s = ((1 + d)^-1 * (k - r * d)) mod n */ + // s = ((1 + d)^-1 * (k - r * d)) mod n + sm2_fn_mul(t, r, d); //sm2_bn_print(stderr, 0, 4, "r*d", t); + sm2_fn_sub(k, k, t); //sm2_bn_print(stderr, 0, 4, "k-r*d", k); + sm2_fn_mul(s, d_inv, k); //sm2_bn_print(stderr, 0, 4, "s = ((1 + d)^-1 * (k - r * d)) mod n", s); - sm2_fn_mul(e, r, d); //print_bn("r*d", e); - sm2_fn_sub(k, k, e); //print_bn("k-r*d", k); - sm2_fn_add(e, SM2_ONE, d); //print_bn("1 +d", e); - sm2_fn_inv(e, e); //print_bn("(1+d)^-1", e); - sm2_fn_mul(s, e, k); //print_bn("s = ((1 + d)^-1 * (k - r * d)) mod n", s); - - sm2_bn_to_bytes(r, sig->r); //print_bn("r", r); - sm2_bn_to_bytes(s, sig->s); //print_bn("s", s); - - if (fixed_outlen) { - uint8_t buf[72]; - uint8_t *p = buf; - size_t len = 0; - sm2_signature_to_der(sig, &p, &len); - if (len != 71) { - goto retry; - } + // check s != 0 + if (sm2_bn_is_zero(s)) { + goto retry; } + sm2_bn_to_bytes(r, sig->r); //sm2_bn_print_bn(stderr, 0, 4, "r", r); + sm2_bn_to_bytes(s, sig->s); //sm2_bn_print_bn(stderr, 0, 4, "s", s); + gmssl_secure_clear(d, sizeof(d)); - gmssl_secure_clear(e, sizeof(e)); + gmssl_secure_clear(d_inv, sizeof(d_inv )); gmssl_secure_clear(k, sizeof(k)); - gmssl_secure_clear(x, sizeof(x)); + gmssl_secure_clear(t, sizeof(t)); return 1; } -int sm2_do_sign(const SM2_KEY *key, const uint8_t dgst[32], SM2_SIGNATURE *sig) -{ - return sm2_do_sign_ex(key, 0, dgst, sig); -} - int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATURE *sig) { SM2_JACOBIAN_POINT _P, *P = &_P; @@ -109,9 +109,15 @@ int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATUR SM2_BN x; SM2_BN t; + // parse public key + sm2_jacobian_point_from_bytes(P, (const uint8_t *)&key->public_key); + //sm2_jacobian_point_print(stderr, 0, 4, "P", P); + // parse signature values - sm2_bn_from_bytes(r, sig->r); //print_bn("r", r); - sm2_bn_from_bytes(s, sig->s); //print_bn("s", s); + sm2_bn_from_bytes(r, sig->r); //sm2_bn_print(stderr, 0, 4, "r", r); + sm2_bn_from_bytes(s, sig->s); //sm2_bn_print(stderr, 0, 4, "s", s); + + // check r, s in [1, n-1] if (sm2_bn_is_zero(r) == 1 || sm2_bn_cmp(r, SM2_N) >= 0 || sm2_bn_is_zero(s) == 1 @@ -120,13 +126,11 @@ int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATUR return -1; } - // parse public key - sm2_jacobian_point_from_bytes(P, (const uint8_t *)&key->public_key); - //print_point("P", P); + // e = H(M) + sm2_bn_from_bytes(e, dgst); //sm2_bn_print(stderr, 0, 4, "e = H(M)", e); - // t = r + s (mod n) - // check t != 0 - sm2_fn_add(t, r, s); //print_bn("t = r + s (mod n)", t); + // t = r + s (mod n), check t != 0 + sm2_fn_add(t, r, s); //sm2_bn_print(stderr, 0, 4, "t = r + s (mod n)", t); if (sm2_bn_is_zero(t)) { error_print(); return -1; @@ -135,13 +139,16 @@ int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATUR // Q = s * G + t * P sm2_jacobian_point_mul_sum(R, t, P, s); sm2_jacobian_point_get_xy(R, x, NULL); - //print_bn("x", x); + //sm2_bn_print(stderr, 0, 4, "x", x); - // e = H(M) // r' = e + x (mod n) - sm2_bn_from_bytes(e, dgst); //print_bn("e = H(M)", e); - sm2_fn_add(e, e, x); //print_bn("e + x (mod n)", e); - + if (sm2_bn_cmp(e, SM2_N) >= 0) { + sm2_bn_sub(e, e, SM2_N); + } + if (sm2_bn_cmp(x, SM2_N) >= 0) { + sm2_bn_sub(x, x, SM2_N); + } + sm2_fn_add(e, e, x); //sm2_bn_print(stderr, 0, 4, "e + x (mod n)", e); // check if r == r' if (sm2_bn_cmp(e, r) != 0) { @@ -150,23 +157,6 @@ int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATUR return 1; } -// verify the xR of R = s * G + (s + r) * P -// so (-r, -s) is also a valid SM2 signature -int sm2_signature_conjugate(const SM2_SIGNATURE *sig, SM2_SIGNATURE *new_sig) -{ - SM2_Fn r; - SM2_Fn s; - - sm2_bn_from_bytes(r, sig->r); - sm2_bn_from_bytes(s, sig->s); - sm2_fn_neg(r, r); - sm2_fn_neg(s, s); - sm2_bn_to_bytes(r, new_sig->r); - sm2_bn_to_bytes(s, new_sig->s); - - return 1; -} - int sm2_signature_to_der(const SM2_SIGNATURE *sig, uint8_t **out, size_t *outlen) { size_t len = 0; @@ -207,7 +197,7 @@ int sm2_signature_from_der(SM2_SIGNATURE *sig, const uint8_t **in, size_t *inlen return -1; } memset(sig, 0, sizeof(*sig)); - memcpy(sig->r + 32 - rlen, r, rlen); // 需要测试当r, s是比较小的整数时 + memcpy(sig->r + 32 - rlen, r, rlen); memcpy(sig->s + 32 - slen, s, slen); return 1; } @@ -227,65 +217,84 @@ int sm2_signature_print(FILE *fp, int fmt, int ind, const char *label, const uin return 1; } -#define SM2_SIGNATURE_MAX_DER_SIZE 77 - -int sm2_sign_ex(const SM2_KEY *key, int fixed_outlen, const uint8_t dgst[32], uint8_t *sig, size_t *siglen) +int sm2_sign(const SM2_KEY *key, const uint8_t dgst[32], uint8_t *sigbuf, size_t *siglen) { - SM2_SIGNATURE signature; + SM2_SIGNATURE sig; uint8_t *p; - if (!key - || !dgst - || !sig - || !siglen) { + if (!key || !dgst || !sigbuf || !siglen) { + error_print(); + return -1; + } + + if (sm2_do_sign(key, dgst, &sig) != 1) { error_print(); return -1; } - p = sig; *siglen = 0; - if (sm2_do_sign_ex(key, fixed_outlen, dgst, &signature) != 1 - || sm2_signature_to_der(&signature, &p, siglen) != 1) { + if (sm2_signature_to_der(&sig, &sigbuf, siglen) != 1) { error_print(); return -1; } return 1; } -int sm2_sign(const SM2_KEY *key, const uint8_t dgst[32], uint8_t *sig, size_t *siglen) +int sm2_sign_fixlen(const SM2_KEY *key, const uint8_t dgst[32], size_t siglen, uint8_t *sig) { - return sm2_sign_ex(key, 0, dgst, sig, siglen); -} + unsigned int trys = 200; // 200 trys is engouh + uint8_t buf[SM2_MAX_SIGNATURE_SIZE]; + size_t len; -int sm2_verify(const SM2_KEY *key, const uint8_t dgst[32], const uint8_t *sig, size_t siglen) -{ - int ret; - SM2_SIGNATURE signature; - const uint8_t *p; - - if (!key - || !dgst - || !sig - || !siglen) { + switch (siglen) { + case SM2_signature_compact_size: + case SM2_signature_typical_size: + case SM2_signature_max_size: + break; + default: error_print(); return -1; } - p = sig; - if (sm2_signature_from_der(&signature, &p, &siglen) != 1 + while (trys--) { + if (sm2_sign(key, dgst, buf, &len) != 1) { + error_print(); + return -1; + } + if (len == siglen) { + memcpy(sig, buf, len); + return 1; + } + } + + // might caused by bad randomness + error_print(); + return -1; +} + +int sm2_verify(const SM2_KEY *key, const uint8_t dgst[32], const uint8_t *sigbuf, size_t siglen) +{ + int ret; + SM2_SIGNATURE sig; + const uint8_t *p; + + if (!key || !dgst || !sigbuf || !siglen) { + error_print(); + return -1; + } + + if (sm2_signature_from_der(&sig, &sigbuf, &siglen) != 1 || asn1_length_is_zero(siglen) != 1) { error_print(); return -1; } - if ((ret = sm2_do_verify(key, dgst, &signature)) != 1) { + if ((ret = sm2_do_verify(key, dgst, &sig)) != 1) { if (ret < 0) error_print(); return ret; } return 1; } -extern void sm3_compress_blocks(uint32_t digest[8], const uint8_t *data, size_t blocks); - int sm2_compute_z(uint8_t z[32], const SM2_POINT *pub, const char *id, size_t idlen) { SM3_CTX ctx; @@ -453,120 +462,223 @@ int sm2_kdf(const uint8_t *in, size_t inlen, size_t outlen, uint8_t *out) return 1; } -int sm2_do_encrypt_ex(const SM2_KEY *key, int fixed_outlen, const uint8_t *in, size_t inlen, SM2_CIPHERTEXT *out) +static int all_zero(const uint8_t *buf, size_t len) { - SM2_BN k; - SM2_JACOBIAN_POINT _P, *P = &_P; - SM3_CTX sm3_ctx; - uint8_t buf[64]; - int i; - -retry: - // rand k in [1, n - 1] - sm2_bn_rand_range(k, SM2_N); - if (sm2_bn_is_zero(k)) goto retry; - - // C1 = k * G = (x1, y1) - sm2_jacobian_point_mul_generator(P, k); - sm2_jacobian_point_to_bytes(P, (uint8_t *)&out->point); - - if (fixed_outlen) { - size_t xlen = 0, ylen = 0; - asn1_integer_to_der(out->point.x, 32, NULL, &xlen); - if (xlen != 34) goto retry; - asn1_integer_to_der(out->point.y, 32, NULL, &ylen); - if (ylen != 34) goto retry; + size_t i; + for (i = 0; i < len; i++) { + if (buf[i]) { + return 0; + } } - - // Q = k * P = (x2, y2) - sm2_jacobian_point_from_bytes(P, (uint8_t *)&key->public_key); - - sm2_jacobian_point_mul(P, k, P); - - sm2_jacobian_point_to_bytes(P, buf); - - - // t = KDF(x2 || y2, klen) - sm2_kdf(buf, sizeof(buf), inlen, out->ciphertext); - - - // C2 = M xor t - for (i = 0; i < inlen; i++) { - out->ciphertext[i] ^= in[i]; - } - out->ciphertext_size = (uint32_t)inlen; - - // C3 = Hash(x2 || m || y2) - sm3_init(&sm3_ctx); - sm3_update(&sm3_ctx, buf, 32); - sm3_update(&sm3_ctx, in, inlen); - sm3_update(&sm3_ctx, buf + 32, 32); - sm3_finish(&sm3_ctx, out->hash); - return 1; } int sm2_do_encrypt(const SM2_KEY *key, const uint8_t *in, size_t inlen, SM2_CIPHERTEXT *out) { - return sm2_do_encrypt_ex(key, 0, in, inlen, out); + SM2_BN k; + SM2_JACOBIAN_POINT _P, *P = &_P; + SM2_JACOBIAN_POINT _C1, *C1 = &_C1; + SM2_JACOBIAN_POINT _kP, *kP = &_kP; + uint8_t x2y2[64]; + SM3_CTX sm3_ctx; + size_t i; + + if (!(SM2_MIN_PLAINTEXT_SIZE <= inlen && inlen <= SM2_MAX_PLAINTEXT_SIZE)) { + error_print(); + return -1; + } + + sm2_jacobian_point_from_bytes(P, (uint8_t *)&key->public_key); + + // S = h * P, check S != O + // for sm2 curve, h == 1 and S == P + // SM2_POINT can not present point at infinity, do do nothing here + +retry: + // rand k in [1, n - 1] + do { + if (sm2_fn_rand(k) != 1) { + error_print(); + return -1; + } + } while (sm2_bn_is_zero(k)); //sm2_bn_print(stderr, 0, 4, "k", k); + + // output C1 = k * G = (x1, y1) + sm2_jacobian_point_mul_generator(C1, k); + sm2_jacobian_point_to_bytes(C1, (uint8_t *)&out->point); + + // k * P = (x2, y2) + sm2_jacobian_point_mul(kP, k, P); + sm2_jacobian_point_to_bytes(kP, x2y2); + + // t = KDF(x2 || y2, inlen) + sm2_kdf(x2y2, 64, inlen, out->ciphertext); + + // if t is all zero, retry + if (all_zero(out->ciphertext, inlen)) { + goto retry; + } + + // output C2 = M xor t + gmssl_memxor(out->ciphertext, out->ciphertext, in, inlen); + out->ciphertext_size = (uint32_t)inlen; + + // output C3 = Hash(x2 || m || y2) + sm3_init(&sm3_ctx); + sm3_update(&sm3_ctx, x2y2, 32); + sm3_update(&sm3_ctx, in, inlen); + sm3_update(&sm3_ctx, x2y2 + 32, 32); + sm3_finish(&sm3_ctx, out->hash); + + gmssl_secure_clear(k, sizeof(k)); + gmssl_secure_clear(kP, sizeof(SM2_JACOBIAN_POINT)); + gmssl_secure_clear(x2y2, sizeof(x2y2)); + return 1; +} + +int sm2_do_encrypt_fixlen(const SM2_KEY *key, const uint8_t *in, size_t inlen, int point_size, SM2_CIPHERTEXT *out) +{ + unsigned int trys = 200; + SM2_BN k; + SM2_JACOBIAN_POINT _P, *P = &_P; + SM2_JACOBIAN_POINT _C1, *C1 = &_C1; + SM2_JACOBIAN_POINT _kP, *kP = &_kP; + uint8_t x2y2[64]; + SM3_CTX sm3_ctx; + size_t i; + + if (!(SM2_MIN_PLAINTEXT_SIZE <= inlen && inlen <= SM2_MAX_PLAINTEXT_SIZE)) { + error_print(); + return -1; + } + + switch (point_size) { + case SM2_ciphertext_compact_point_size: + case SM2_ciphertext_typical_point_size: + case SM2_ciphertext_max_point_size: + break; + default: + error_print(); + return -1; + } + + sm2_jacobian_point_from_bytes(P, (uint8_t *)&key->public_key); + + // S = h * P, check S != O + // for sm2 curve, h == 1 and S == P + // SM2_POINT can not present point at infinity, do do nothing here + +retry: + // rand k in [1, n - 1] + do { + if (sm2_fn_rand(k) != 1) { + error_print(); + return -1; + } + } while (sm2_bn_is_zero(k)); //sm2_bn_print(stderr, 0, 4, "k", k); + + // output C1 = k * G = (x1, y1) + sm2_jacobian_point_mul_generator(C1, k); + sm2_jacobian_point_to_bytes(C1, (uint8_t *)&out->point); + + // check fixlen + if (trys) { + size_t len = 0; + asn1_integer_to_der(out->point.x, 32, NULL, &len); + asn1_integer_to_der(out->point.y, 32, NULL, &len); + if (len != point_size) { + trys--; + goto retry; + } + } else { + gmssl_secure_clear(k, sizeof(k)); + error_print(); + return -1; + } + + // k * P = (x2, y2) + sm2_jacobian_point_mul(kP, k, P); + sm2_jacobian_point_to_bytes(kP, x2y2); + + // t = KDF(x2 || y2, inlen) + sm2_kdf(x2y2, 64, inlen, out->ciphertext); + + // if t is all zero, retry + if (all_zero(out->ciphertext, inlen)) { + goto retry; + } + + // output C2 = M xor t + gmssl_memxor(out->ciphertext, out->ciphertext, in, inlen); + out->ciphertext_size = (uint32_t)inlen; + + // output C3 = Hash(x2 || m || y2) + sm3_init(&sm3_ctx); + sm3_update(&sm3_ctx, x2y2, 32); + sm3_update(&sm3_ctx, in, inlen); + sm3_update(&sm3_ctx, x2y2 + 32, 32); + sm3_finish(&sm3_ctx, out->hash); + + gmssl_secure_clear(k, sizeof(k)); + gmssl_secure_clear(kP, sizeof(SM2_JACOBIAN_POINT)); + gmssl_secure_clear(x2y2, sizeof(x2y2)); + return 1; } int sm2_do_decrypt(const SM2_KEY *key, const SM2_CIPHERTEXT *in, uint8_t *out, size_t *outlen) { - uint32_t inlen, i; + int ret = -1; SM2_BN d; - SM2_JACOBIAN_POINT _P, *P = &_P; + SM2_JACOBIAN_POINT _C1, *C1 = &_C1; + uint8_t x2y2[64]; SM3_CTX sm3_ctx; - uint8_t buf[64]; uint8_t hash[32]; - // FIXME: check SM2_CIPHERTEXT format - - // check C1 - sm2_jacobian_point_from_bytes(P, (uint8_t *)&in->point); - //point_print(stdout, P, 0, 2); - - /* - if (!sm2_jacobian_point_is_on_curve(P)) { - fprintf(stderr, "%s %d: invalid ciphertext\n", __FILE__, __LINE__); + // check C1 is on sm2 curve + sm2_jacobian_point_from_bytes(C1, (uint8_t *)&in->point); + if (!sm2_jacobian_point_is_on_curve(C1)) { + error_print(); return -1; } - */ + + // check if S = h * C1 is point at infinity + // this will not happen, as SM2_POINT can not present point at infinity // d * C1 = (x2, y2) sm2_bn_from_bytes(d, key->private_key); - sm2_jacobian_point_mul(P, d, P); - sm2_bn_clean(d); - sm2_jacobian_point_to_bytes(P, buf); + sm2_jacobian_point_mul(C1, d, C1); - // t = KDF(x2 || y2, klen) - if ((inlen = in->ciphertext_size) <= 0) { - fprintf(stderr, "%s %d: invalid ciphertext\n", __FILE__, __LINE__); - return -1; + // t = KDF(x2 || y2, klen) and check t is not all zeros + sm2_jacobian_point_to_bytes(C1, x2y2); + sm2_kdf(x2y2, 64, in->ciphertext_size, out); + if (all_zero(out, in->ciphertext_size)) { + error_print(); + goto end; } - sm2_kdf(buf, sizeof(buf), inlen, out); - // M = C2 xor t - for (i = 0; i < inlen; i++) { - out[i] ^= in->ciphertext[i]; - } - *outlen = inlen; + gmssl_memxor(out, out, in->ciphertext, in->ciphertext_size); + *outlen = in->ciphertext_size; // u = Hash(x2 || M || y2) sm3_init(&sm3_ctx); - sm3_update(&sm3_ctx, buf, 32); - sm3_update(&sm3_ctx, out, inlen); - sm3_update(&sm3_ctx, buf + 32, 32); + sm3_update(&sm3_ctx, x2y2, 32); + sm3_update(&sm3_ctx, out, in->ciphertext_size); + sm3_update(&sm3_ctx, x2y2 + 32, 32); sm3_finish(&sm3_ctx, hash); // check if u == C3 if (memcmp(in->hash, hash, sizeof(hash)) != 0) { - fprintf(stderr, "%s %d: invalid ciphertext\n", __FILE__, __LINE__); - return -1; + error_print(); + goto end; } + ret = 1; - return 1; +end: + gmssl_secure_clear(d, sizeof(d)); + gmssl_secure_clear(C1, sizeof(SM2_JACOBIAN_POINT)); + gmssl_secure_clear(x2y2, sizeof(x2y2)); + return ret; } int sm2_ciphertext_to_der(const SM2_CIPHERTEXT *C, uint8_t **out, size_t *outlen) @@ -613,6 +725,7 @@ int sm2_ciphertext_from_der(SM2_CIPHERTEXT *C, const uint8_t **in, size_t *inlen || asn1_length_le(ylen, 32) != 1 || asn1_check(hashlen == 32) != 1 || asn1_length_le(clen, SM2_MAX_PLAINTEXT_SIZE) != 1 + || asn1_length_is_zero(clen) == 1 || asn1_length_is_zero(dlen) != 1) { error_print(); return -1; @@ -649,19 +762,11 @@ int sm2_ciphertext_print(FILE *fp, int fmt, int ind, const char *label, const ui return 1; } -int sm2_encrypt_ex(const SM2_KEY *key, int fixed_outlen, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +int sm2_encrypt(const SM2_KEY *key, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { SM2_CIPHERTEXT C; - if (!key || !in || !out || !outlen) { - error_print(); - return -1; - } - if (inlen < SM2_MIN_PLAINTEXT_SIZE || inlen > SM2_MAX_PLAINTEXT_SIZE) { - error_print(); - return -1; - } - if (sm2_do_encrypt_ex(key, fixed_outlen, in, inlen, &C) != 1) { + if (sm2_do_encrypt(key, in, inlen, &C) != 1) { error_print(); return -1; } @@ -673,9 +778,20 @@ int sm2_encrypt_ex(const SM2_KEY *key, int fixed_outlen, const uint8_t *in, size return 1; } -int sm2_encrypt(const SM2_KEY *key, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +int sm2_encrypt_fixlen(const SM2_KEY *key, const uint8_t *in, size_t inlen, int point_size, uint8_t *out, size_t *outlen) { - return sm2_encrypt_ex(key, 0, in, inlen, out, outlen); + SM2_CIPHERTEXT C; + + if (sm2_do_encrypt_fixlen(key, in, inlen, point_size, &C) != 1) { + error_print(); + return -1; + } + *outlen = 0; + if (sm2_ciphertext_to_der(&C, &out, outlen) != 1) { + error_print(); + return -1; + } + return 1; } int sm2_decrypt(const SM2_KEY *key, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) diff --git a/src/sm2_recover.c b/src/sm2_recover.c index 386dcfc3..e550ce47 100644 --- a/src/sm2_recover.c +++ b/src/sm2_recover.c @@ -112,3 +112,20 @@ int sm2_signature_to_public_key_points(const SM2_SIGNATURE *sig, const uint8_t d return 1; } + +// verify the xR of R = s * G + (s + r) * P +// so (-r, -s) is also a valid SM2 signature +int sm2_signature_conjugate(const SM2_SIGNATURE *sig, SM2_SIGNATURE *new_sig) +{ + SM2_Fn r; + SM2_Fn s; + + sm2_bn_from_bytes(r, sig->r); + sm2_bn_from_bytes(s, sig->s); + sm2_fn_neg(r, r); + sm2_fn_neg(s, s); + sm2_bn_to_bytes(r, new_sig->r); + sm2_bn_to_bytes(s, new_sig->s); + + return 1; +}