diff --git a/include/gmssl/sm9.h b/include/gmssl/sm9.h index 6e8d59c3..af74e2a4 100644 --- a/include/gmssl/sm9.h +++ b/include/gmssl/sm9.h @@ -448,6 +448,10 @@ int sm9_do_encrypt(const SM9_ENC_MASTER_KEY *mpk, const char *id, size_t idlen, const uint8_t *in, size_t inlen, sm9_point_t *C1, uint8_t *c2, uint8_t c3[SM3_HMAC_SIZE]); int sm9_do_decrypt(const SM9_ENC_KEY *key, const char *id, size_t idlen, const sm9_point_t *C1, const uint8_t *c2, size_t c2len, const uint8_t c3[SM3_HMAC_SIZE], uint8_t *out); +int sm9_encrypt(const SM9_ENC_MASTER_KEY *mpk, const char *id, size_t idlen, + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); +int sm9_decrypt(const SM9_ENC_KEY *key, const char *id, size_t idlen, + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); # ifdef __cplusplus } diff --git a/src/sm9_alg.c b/src/sm9_alg.c index 2b0b7273..b7ab3ab3 100644 --- a/src/sm9_alg.c +++ b/src/sm9_alg.c @@ -1702,11 +1702,15 @@ void sm9_twist_point_mul_generator(sm9_twist_point_t *R, const sm9_bn_t k) void sm9_eval_g_tangent(sm9_fp12_t num, sm9_fp12_t den, const sm9_twist_point_t *P, const sm9_point_t *Q) { + sm9_fp_t x; + sm9_fp_t y; + sm9_point_get_xy(Q, x, y); + const sm9_fp_t *XP = P->X; const sm9_fp_t *YP = P->Y; const sm9_fp_t *ZP = P->Z; - const uint64_t *xQ = Q->X; - const uint64_t *yQ = Q->Y; + const uint64_t *xQ = x; + const uint64_t *yQ = y; sm9_fp_t *a0 = num[0][0]; sm9_fp_t *a1 = num[0][1]; @@ -1743,14 +1747,18 @@ void sm9_eval_g_tangent(sm9_fp12_t num, sm9_fp12_t den, const sm9_twist_point_t void sm9_eval_g_line(sm9_fp12_t num, sm9_fp12_t den, const sm9_twist_point_t *T, const sm9_twist_point_t *P, const sm9_point_t *Q) { + sm9_fp_t x; + sm9_fp_t y; + sm9_point_get_xy(Q, x, y); + const sm9_fp_t *XT = T->X; const sm9_fp_t *YT = T->Y; const sm9_fp_t *ZT = T->Z; const sm9_fp_t *XP = P->X; const sm9_fp_t *YP = P->Y; const sm9_fp_t *ZP = P->Z; - const uint64_t *xQ = Q->X; - const uint64_t *yQ = Q->Y; + const uint64_t *xQ = x; + const uint64_t *yQ = y; sm9_fp_t *a0 = num[0][0]; sm9_fp_t *a1 = num[0][1]; @@ -2209,33 +2217,39 @@ int sm9_point_to_uncompressed_octets(const sm9_point_t *P, uint8_t octets[65]) sm9_fp_t x; sm9_fp_t y; sm9_point_get_xy(P, x, y); - sm9_bn_to_bytes(x, octets); - sm9_bn_to_bytes(y, octets + 32); + octets[0] = 0x04; + sm9_bn_to_bytes(x, octets + 1); + sm9_bn_to_bytes(y, octets + 32 + 1); return 1; } int sm9_point_from_uncompressed_octets(sm9_point_t *P, const uint8_t octets[65]) { - sm9_bn_from_bytes(P->X, octets); - sm9_bn_from_bytes(P->Y, octets + 32); + assert(octets[0] == 0x04); + sm9_bn_from_bytes(P->X, octets + 1); + sm9_bn_from_bytes(P->Y, octets + 32 + 1); sm9_fp_set_one(P->Z); + if (!sm9_point_is_on_curve(P)) return -1; return 1; } int sm9_twist_point_to_uncompressed_octets(const sm9_twist_point_t *P, uint8_t octets[129]) { + octets[0] = 0x04; sm9_fp2_t x; sm9_fp2_t y; sm9_twist_point_get_xy(P, x, y); - sm9_fp2_to_bytes(x, octets); - sm9_fp2_to_bytes(y, octets + 32 * 2); + sm9_fp2_to_bytes(x, octets + 1); + sm9_fp2_to_bytes(y, octets + 32 * 2 + 1); return 1; } int sm9_twist_point_from_uncompressed_octets(sm9_twist_point_t *P, const uint8_t octets[129]) { - sm9_fp2_from_bytes(P->X, octets); - sm9_fp2_from_bytes(P->Y, octets + 32 * 2); + assert(octets[0] == 0x04); + sm9_fp2_from_bytes(P->X, octets + 1); + sm9_fp2_from_bytes(P->Y, octets + 32 * 2 + 1); sm9_fp2_set_one(P->Z); + if (!sm9_twist_point_is_on_curve(P)) return -1; return 1; } diff --git a/src/sm9_lib.c b/src/sm9_lib.c index 6eda85ad..5da96d38 100644 --- a/src/sm9_lib.c +++ b/src/sm9_lib.c @@ -139,8 +139,6 @@ int sm9_sign_finish(SM9_SIGN_CTX *ctx, const SM9_SIGN_KEY *key, uint8_t *sig, si return 1; } -#define hex_r "00033C8616B06704813203DFD00965022ED15975C662337AED648835DC4B1CBE" - int sm9_do_sign(const SM9_SIGN_KEY *key, const SM3_CTX *sm3_ctx, SM9_SIGNATURE *sig) { sm9_fn_t r; @@ -158,7 +156,6 @@ int sm9_do_sign(const SM9_SIGN_KEY *key, const SM3_CTX *sm3_ctx, SM9_SIGNATURE * do { // A2: rand r in [1, N-1] sm9_fn_rand(r); - //sm9_bn_from_hex(r, hex_r); // A3: w = g^r sm9_fp12_pow(g, g, r); @@ -289,7 +286,7 @@ int sm9_kem_encrypt(const SM9_ENC_MASTER_KEY *mpk, const char *id, size_t idlen, SM3_KDF_CTX kdf_ctx; // A1: Q = H1(ID||hid,N) * P1 + Ppube - sm9_hash1(r, id, idlen, SM9_HID_EXCH); + sm9_hash1(r, id, idlen, SM9_HID_ENC); sm9_point_mul(C, r, SM9_P1); sm9_point_add(C, C, &mpk->Ppube); @@ -335,10 +332,11 @@ int sm9_kem_decrypt(const SM9_ENC_KEY *key, const char *id, size_t idlen, const SM3_KDF_CTX kdf_ctx; // B1: check C in G1 - sm9_point_to_uncompressed_octets(C, cbuf + 1); + sm9_point_to_uncompressed_octets(C, cbuf); // B2: w = e(C, de); sm9_pairing(w, &key->de, C); + sm9_fp12_to_bytes(w, wbuf); // B3: K = KDF(C || w || ID, klen) sm3_kdf_init(&kdf_ctx, klen); diff --git a/tests/sm9test.c b/tests/sm9test.c index 6ce5ac50..1dbc664f 100644 --- a/tests/sm9test.c +++ b/tests/sm9test.c @@ -503,37 +503,67 @@ err: } #define hex_ks "000130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4" -#define hex_t2 "291FE3CAC8F58AD2DC462C8D4D578A94DAFD5624DDC28E328D2936688A86CF1A" -#define hex_Ppubs "9f64080b3084f733e48aff4b41b565011ce0711c5e392cfb0ab1b6791b94c408\ --29dba116152d1f786ce843ed24a3b573414d2177386a92dd8f14d65696ea5e32\ --69850938abea0112b57329f447e3a0cbad3e2fdb1a77f335e89e1408d0ef1c25\ --41e00a53dda532da1a7ce027b7a46f741006e85f5cdff0730e75c05fb4e3216d" #define hex_ds "A5702F05CF1315305E2D6EB64B0DEB923DB1A0BCF0CAFF90523AC8754AA69820-78559A844411F9825C109F5EE3F52D720DD01785392A727BB1556952B2B013D3" -#define hex_h_ans "823C4B21E4BD2DFE1ED92C606653E996668563152FC33F55D7BFBB9BD9705ADB" -#define hex_S_ans "73BF96923CE58B6AD0E13E9643A406D8EB98417C50EF1B29CEF9ADB48B6D598C-856712F1C2E0968AB7769F42A99586AED139D5B8B3E15891827CC2ACED9BAA05" int test_sm9_sign() { SM9_SIGN_CTX ctx; SM9_SIGN_KEY key; - SM9_SIGNATURE sig; - sm9_twist_point_t ts; - sm9_point_t s; - sm9_bn_t k; + SM9_SIGN_MASTER_KEY mpk; + sm9_point_t ds; + uint8_t sig[1000] = {0}; + size_t siglen = 0; int j = 1; uint8_t data[20] = {0x43, 0x68, 0x69, 0x6E, 0x65, 0x73, 0x65, 0x20, 0x49, 0x42, 0x53, 0x20, 0x73, 0x74, 0x61, 0x6E, 0x64, 0x61, 0x72, 0x64}; + uint8_t IDA[5] = {0x41, 0x6C, 0x69, 0x63, 0x65}; - sm9_bn_from_hex(k, hex_ks); sm9_twist_point_mul_generator(&(key.Ppubs), k); - sm9_twist_point_from_hex(&ts, hex_Ppubs); if (!sm9_twist_point_equ(&ts, &(key.Ppubs))) goto err; ++j; - sm9_bn_from_hex(k, hex_t2); sm9_point_mul_generator(&(key.ds), k); - sm9_point_from_hex(&s, hex_ds); if (!sm9_point_equ(&(key.ds), &s)) goto err; ++j; + sm9_bn_from_hex(mpk.ks, hex_ks); sm9_twist_point_mul_generator(&(mpk.Ppubs), mpk.ks); + if (sm9_sign_master_key_extract_key(&mpk, IDA, sizeof(IDA), &key) < 0) goto err; ++j; + sm9_point_from_hex(&ds, hex_ds); if (!sm9_point_equ(&(key.ds), &ds)) goto err; ++j; sm9_sign_init(&ctx); sm9_sign_update(&ctx, data, sizeof(data)); - sm9_do_sign(&key, &(ctx.sm3_ctx), &sig); - sm9_bn_from_hex(k, hex_h_ans); if (!sm9_fn_equ(sig.h, k)) goto err; ++j; - sm9_point_from_hex(&s, hex_S_ans); if (!sm9_point_equ(&(sig.S), &s)) goto err; ++j; + if (sm9_sign_finish(&ctx, &key, sig, &siglen) < 0) goto err; ++j; + + sm9_verify_init(&ctx); + sm9_verify_update(&ctx, data, sizeof(data)); + if (sm9_verify_finish(&ctx, sig, siglen, &mpk, IDA, sizeof(IDA)) < 0) goto err; ++j; + + printf("%s() ok\n", __FUNCTION__); + return 1; +err: + printf("%s test %d failed\n", __FUNCTION__, j); + error_print(); + return -1; +} + +#define hex_ke "0001EDEE3778F441F8DEA3D9FA0ACC4E07EE36C93F9A08618AF4AD85CEDE1C22" +#define hex_de "94736ACD2C8C8796CC4785E938301A139A059D3537B6414140B2D31EECF41683\ +-115BAE85F5D8BC6C3DBD9E5342979ACCCF3C2F4F28420B1CB4F8C0B59A19B158\ +-7AA5E47570DA7600CD760A0CF7BEAF71C447F3844753FE74FA7BA92CA7D3B55F\ +-27538A62E7F7BFB51DCE08704796D94C9D56734F119EA44732B50E31CDEB75C1" + +int test_sm9_encrypt() { + SM9_ENC_MASTER_KEY msk; + SM9_ENC_KEY key; + sm9_twist_point_t de; + uint8_t out[1000] = {0}; + size_t outlen = 0; + int j = 1; + uint8_t data[20] = {0x43, 0x68, 0x69, 0x6E, 0x65, 0x73, 0x65, 0x20, 0x49, 0x42, 0x53, 0x20, 0x73, 0x74, 0x61, 0x6E, 0x64, 0x61, 0x72, 0x64}; + uint8_t dec[20] = {0}; + size_t declen = 20; + uint8_t IDB[3] = {0x42, 0x6F, 0x62}; + + sm9_bn_from_hex(msk.ke, hex_ke); sm9_point_mul_generator(&(msk.Ppube), msk.ke); + if (sm9_enc_master_key_extract_key(&msk, IDB, sizeof(IDB), &key) < 0) goto err; ++j; + sm9_twist_point_from_hex(&de, hex_de); if (!sm9_twist_point_equ(&(key.de), &de)) goto err; ++j; + + if (sm9_encrypt(&msk, IDB, sizeof(IDB), data, sizeof(data), out, &outlen) < 0) goto err; ++j; + if (sm9_decrypt(&key, IDB, sizeof(IDB), out, outlen, dec, &declen) < 0) goto err; ++j; + if (memcmp(data, dec, sizeof(data)) != 0) goto err; ++j; + printf("%s() ok\n", __FUNCTION__); return 1; err: @@ -551,7 +581,8 @@ int main(void) { if (test_sm9_point() != 1) goto err; if (test_sm9_twist_point() != 1) goto err; if (test_sm9_pairing() != 1) goto err; - // if (test_sm9_sign() != 1) goto err; /* Must open "#define hex_r" in sm9_lib.c */ + if (test_sm9_sign() != 1) goto err; + if (test_sm9_encrypt() != 1) goto err; printf("%s all tests passed\n", __FILE__); return 0;