diff --git a/src/sm9_alg.c b/src/sm9_alg.c index 2b0b7273..b7ab3ab3 100644 --- a/src/sm9_alg.c +++ b/src/sm9_alg.c @@ -1702,11 +1702,15 @@ void sm9_twist_point_mul_generator(sm9_twist_point_t *R, const sm9_bn_t k) void sm9_eval_g_tangent(sm9_fp12_t num, sm9_fp12_t den, const sm9_twist_point_t *P, const sm9_point_t *Q) { + sm9_fp_t x; + sm9_fp_t y; + sm9_point_get_xy(Q, x, y); + const sm9_fp_t *XP = P->X; const sm9_fp_t *YP = P->Y; const sm9_fp_t *ZP = P->Z; - const uint64_t *xQ = Q->X; - const uint64_t *yQ = Q->Y; + const uint64_t *xQ = x; + const uint64_t *yQ = y; sm9_fp_t *a0 = num[0][0]; sm9_fp_t *a1 = num[0][1]; @@ -1743,14 +1747,18 @@ void sm9_eval_g_tangent(sm9_fp12_t num, sm9_fp12_t den, const sm9_twist_point_t void sm9_eval_g_line(sm9_fp12_t num, sm9_fp12_t den, const sm9_twist_point_t *T, const sm9_twist_point_t *P, const sm9_point_t *Q) { + sm9_fp_t x; + sm9_fp_t y; + sm9_point_get_xy(Q, x, y); + const sm9_fp_t *XT = T->X; const sm9_fp_t *YT = T->Y; const sm9_fp_t *ZT = T->Z; const sm9_fp_t *XP = P->X; const sm9_fp_t *YP = P->Y; const sm9_fp_t *ZP = P->Z; - const uint64_t *xQ = Q->X; - const uint64_t *yQ = Q->Y; + const uint64_t *xQ = x; + const uint64_t *yQ = y; sm9_fp_t *a0 = num[0][0]; sm9_fp_t *a1 = num[0][1]; @@ -2209,33 +2217,39 @@ int sm9_point_to_uncompressed_octets(const sm9_point_t *P, uint8_t octets[65]) sm9_fp_t x; sm9_fp_t y; sm9_point_get_xy(P, x, y); - sm9_bn_to_bytes(x, octets); - sm9_bn_to_bytes(y, octets + 32); + octets[0] = 0x04; + sm9_bn_to_bytes(x, octets + 1); + sm9_bn_to_bytes(y, octets + 32 + 1); return 1; } int sm9_point_from_uncompressed_octets(sm9_point_t *P, const uint8_t octets[65]) { - sm9_bn_from_bytes(P->X, octets); - sm9_bn_from_bytes(P->Y, octets + 32); + assert(octets[0] == 0x04); + sm9_bn_from_bytes(P->X, octets + 1); + sm9_bn_from_bytes(P->Y, octets + 32 + 1); sm9_fp_set_one(P->Z); + if (!sm9_point_is_on_curve(P)) return -1; return 1; } int sm9_twist_point_to_uncompressed_octets(const sm9_twist_point_t *P, uint8_t octets[129]) { + octets[0] = 0x04; sm9_fp2_t x; sm9_fp2_t y; sm9_twist_point_get_xy(P, x, y); - sm9_fp2_to_bytes(x, octets); - sm9_fp2_to_bytes(y, octets + 32 * 2); + sm9_fp2_to_bytes(x, octets + 1); + sm9_fp2_to_bytes(y, octets + 32 * 2 + 1); return 1; } int sm9_twist_point_from_uncompressed_octets(sm9_twist_point_t *P, const uint8_t octets[129]) { - sm9_fp2_from_bytes(P->X, octets); - sm9_fp2_from_bytes(P->Y, octets + 32 * 2); + assert(octets[0] == 0x04); + sm9_fp2_from_bytes(P->X, octets + 1); + sm9_fp2_from_bytes(P->Y, octets + 32 * 2 + 1); sm9_fp2_set_one(P->Z); + if (!sm9_twist_point_is_on_curve(P)) return -1; return 1; } diff --git a/src/sm9_lib.c b/src/sm9_lib.c index 6eda85ad..5da96d38 100644 --- a/src/sm9_lib.c +++ b/src/sm9_lib.c @@ -139,8 +139,6 @@ int sm9_sign_finish(SM9_SIGN_CTX *ctx, const SM9_SIGN_KEY *key, uint8_t *sig, si return 1; } -#define hex_r "00033C8616B06704813203DFD00965022ED15975C662337AED648835DC4B1CBE" - int sm9_do_sign(const SM9_SIGN_KEY *key, const SM3_CTX *sm3_ctx, SM9_SIGNATURE *sig) { sm9_fn_t r; @@ -158,7 +156,6 @@ int sm9_do_sign(const SM9_SIGN_KEY *key, const SM3_CTX *sm3_ctx, SM9_SIGNATURE * do { // A2: rand r in [1, N-1] sm9_fn_rand(r); - //sm9_bn_from_hex(r, hex_r); // A3: w = g^r sm9_fp12_pow(g, g, r); @@ -289,7 +286,7 @@ int sm9_kem_encrypt(const SM9_ENC_MASTER_KEY *mpk, const char *id, size_t idlen, SM3_KDF_CTX kdf_ctx; // A1: Q = H1(ID||hid,N) * P1 + Ppube - sm9_hash1(r, id, idlen, SM9_HID_EXCH); + sm9_hash1(r, id, idlen, SM9_HID_ENC); sm9_point_mul(C, r, SM9_P1); sm9_point_add(C, C, &mpk->Ppube); @@ -335,10 +332,11 @@ int sm9_kem_decrypt(const SM9_ENC_KEY *key, const char *id, size_t idlen, const SM3_KDF_CTX kdf_ctx; // B1: check C in G1 - sm9_point_to_uncompressed_octets(C, cbuf + 1); + sm9_point_to_uncompressed_octets(C, cbuf); // B2: w = e(C, de); sm9_pairing(w, &key->de, C); + sm9_fp12_to_bytes(w, wbuf); // B3: K = KDF(C || w || ID, klen) sm3_kdf_init(&kdf_ctx, klen);