diff --git a/include/gmssl/sm9.h b/include/gmssl/sm9.h index e687f661..6e8d59c3 100644 --- a/include/gmssl/sm9.h +++ b/include/gmssl/sm9.h @@ -303,7 +303,7 @@ int sm9_fn_equ(const sm9_fn_t a, const sm9_fn_t b); void sm9_fn_rand(sm9_fn_t r); void sm9_fp12_to_bytes(const sm9_fp12_t a, uint8_t buf[32 * 12]); -int sm9_fn_from_hash(sm9_fn_t h, const uint8_t Ha[40]); +void sm9_fn_from_hash(sm9_fn_t h, const uint8_t Ha[40]); int sm9_hash1(sm9_bn_t h1, const char *id, size_t idlen, uint8_t hid); diff --git a/src/sm9_alg.c b/src/sm9_alg.c index c5ba8b2f..2b0b7273 100644 --- a/src/sm9_alg.c +++ b/src/sm9_alg.c @@ -72,6 +72,7 @@ const sm9_bn_t SM9_N = {0xd69ecf25, 0xe56ee19c, 0x18ea8bee, 0x49f2934b, 0xf58ec7 const sm9_bn_t SM9_N_MINUS_ONE = {0xd69ecf24, 0xe56ee19c, 0x18ea8bee, 0x49f2934b, 0xf58ec744, 0xd603ab4f, 0x02a3a6f1, 0xb6400000}; const sm9_barrett_bn_t SM9_MU_P = {0xd5c22146, 0x71188f90, 0x1e36081c, 0xf2665f6d, 0xdcd1312a, 0x55f73aeb, 0xeb5759a6, 0x67980e0b, 0x00000001}; const sm9_barrett_bn_t SM9_MU_N = {0xdfc97c2f, 0x74df4fd4, 0xc9c073b0, 0x9c95d85e, 0xdcd1312c, 0x55f73aeb, 0xeb5759a6, 0x67980e0b, 0x00000001}; +const sm9_barrett_bn_t SM9_MU_N_MINUS_ONE = {0xdfc97c31, 0x74df4fd4, 0xc9c073b0, 0x9c95d85e, 0xdcd1312c, 0x55f73aeb, 0xeb5759a6, 0x67980e0b, 0x00000001}; // P1.X 0x93DE051D62BF718FF5ED0704487D01D6E1E4086909DC3280E8C4E4817C66DDDD @@ -1350,11 +1351,15 @@ void sm9_point_dbl(sm9_point_t *R, const sm9_point_t *P) void sm9_point_add(sm9_point_t *R, const sm9_point_t *P, const sm9_point_t *Q) { + sm9_fp_t x; + sm9_fp_t y; + sm9_point_get_xy(Q, x, y); + const uint64_t *X1 = P->X; const uint64_t *Y1 = P->Y; const uint64_t *Z1 = P->Z; - const uint64_t *x2 = Q->X; - const uint64_t *y2 = Q->Y; + const uint64_t *x2 = x; + const uint64_t *y2 = y; sm9_fp_t X3, Y3, Z3, T1, T2, T3, T4; if (sm9_point_is_at_infinity(Q)) { @@ -1613,7 +1618,7 @@ void sm9_twist_point_sub(sm9_twist_point_t *R, const sm9_twist_point_t *P, const { sm9_twist_point_t _T, *T = &_T; sm9_twist_point_neg(T, Q); - sm9_twist_point_add(R, P, T); + sm9_twist_point_add_full(R, P, T); } void sm9_twist_point_add_full(sm9_twist_point_t *R, const sm9_twist_point_t *P, const sm9_twist_point_t *Q) @@ -1684,7 +1689,7 @@ void sm9_twist_point_mul(sm9_twist_point_t *R, const sm9_bn_t k, const sm9_twist for (i = 0; i < 256; i++) { sm9_twist_point_dbl(Q, Q); if (kbits[i] == '1') { - sm9_twist_point_add(Q, Q, P); + sm9_twist_point_add_full(Q, Q, P); } } sm9_twist_point_copy(R, Q); @@ -1907,7 +1912,7 @@ void sm9_pairing(sm9_fp12_t r, const sm9_twist_point_t *Q, const sm9_point_t *P) sm9_eval_g_line(g_num, g_den, T, Q, P); sm9_fp12_mul(f_num, f_num, g_num); sm9_fp12_mul(f_den, f_den, g_den); - sm9_twist_point_add(T, T, Q); + sm9_twist_point_add_full(T, T, Q); } } @@ -2082,45 +2087,155 @@ int sm9_fn_equ(const sm9_fn_t a, const sm9_fn_t b) // for H1() and H2() // h = (Ha mod (n-1)) + 1; h in [1, n-1], n is the curve order, Ha is 40 bytes from hash -int sm9_fn_from_hash(sm9_fn_t h, const uint8_t Ha[40]) +void sm9_fn_from_hash(sm9_fn_t h, const uint8_t Ha[40]) { - return 1; + uint64_t s[18] = {0}; + sm9_barrett_bn_t zh, zl, q; + uint64_t w; + int i, j; + + /* s = Ha -> int */ + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 4; j++) { + s[i] <<= 8; + s[i] += Ha[4 * (9-i) + j]; + } + } + + /* zl = z mod (2^32)^9 = z[0..8] + * zh = z // (2^32)^7 = z[7..15] */ + for (i = 0; i < 9; i++) { + zl[i] = s[i]; + zh[i] = s[7 + i]; + } + + /* q = zh * mu // (2^32)^9 */ + for (i = 0; i < 18; i++) { + s[i] = 0; + } + for (i = 0; i < 9; i++) { + w = 0; + for (j = 0; j < 9; j++) { + w += s[i + j] + zh[i] * SM9_MU_N_MINUS_ONE[j]; // + s[i + j] = w & 0xffffffff; + w >>= 32; + } + s[i + 9] = w; + } + for (i = 0; i < 9; i++) { + q[i] = s[9 + i]; + } + + /* q = q * p mod (2^32)^9 */ + for (i = 0; i < 18; i++) { + s[i] = 0; + } + for (i = 0; i < 9; i++) { + w = 0; + for (j = 0; j < 8; j++) { + w += s[i + j] + q[i] * SM9_N_MINUS_ONE[j]; + s[i + j] = w & 0xffffffff; + w >>= 32; + } + s[i + 8] = w; + } + for (i = 0; i < 9; i++) { + q[i] = s[i]; + } + + /* h = zl - q (mod (2^32)^9) */ + + if (sm9_barrett_bn_cmp(zl, q)) { + sm9_barrett_bn_sub(zl, zl, q); + } else { + sm9_barrett_bn_t c = {0,0,0,0,0,0,0,0,0x100000000}; + sm9_barrett_bn_sub(q, c, q); + sm9_barrett_bn_add(zl, q, zl); + } + + for (i = 0; i < 8; i++) { + h[i] = zl[i]; + } + + h[7] += (zl[8] << 32); + + /* while h >= (n-1) do: h = h - (n-1) */ + while (sm9_bn_cmp(h, SM9_N_MINUS_ONE) >= 0) { + sm9_bn_sub(h, h, SM9_N_MINUS_ONE); + } + + sm9_fn_add(h, h, SM9_ONE); +} + +void sm9_fp2_from_bytes(sm9_fp2_t r, const uint8_t in[32 * 2]) +{ + sm9_bn_from_bytes(r[1], in); + sm9_bn_from_bytes(r[0], in + 32); +} + +void sm9_fp2_to_bytes(const sm9_fp2_t a, uint8_t buf[32 * 2]) +{ + sm9_bn_to_bytes(a[1], buf); + sm9_bn_to_bytes(a[0], buf + 32); +} + +void sm9_fp4_to_bytes(const sm9_fp4_t a, uint8_t buf[32 * 4]) +{ + sm9_fp2_to_bytes(a[1], buf); + sm9_fp2_to_bytes(a[0], buf + 32 * 2); } void sm9_fp12_to_bytes(const sm9_fp12_t a, uint8_t buf[32 * 12]) { - // FIXME: add impl + sm9_fp4_to_bytes(a[2], buf); + sm9_fp4_to_bytes(a[1], buf + 32 * 4); + sm9_fp4_to_bytes(a[0], buf + 32 * 8); } void sm9_fn_to_bytes(const sm9_fn_t a, uint8_t out[32]) { + sm9_bn_to_bytes(a, out); + return; } int sm9_fn_from_bytes(sm9_fn_t a, const uint8_t in[32]) { - // FIXME: impl - return -1; + sm9_bn_from_bytes(a, in); + return 1; } int sm9_point_to_uncompressed_octets(const sm9_point_t *P, uint8_t octets[65]) { - //FIXME: impl - return -1; + sm9_fp_t x; + sm9_fp_t y; + sm9_point_get_xy(P, x, y); + sm9_bn_to_bytes(x, octets); + sm9_bn_to_bytes(y, octets + 32); + return 1; } int sm9_point_from_uncompressed_octets(sm9_point_t *P, const uint8_t octets[65]) { - //FIXME: impl - return -1; + sm9_bn_from_bytes(P->X, octets); + sm9_bn_from_bytes(P->Y, octets + 32); + sm9_fp_set_one(P->Z); + return 1; } int sm9_twist_point_to_uncompressed_octets(const sm9_twist_point_t *P, uint8_t octets[129]) { - return -1; + sm9_fp2_t x; + sm9_fp2_t y; + sm9_twist_point_get_xy(P, x, y); + sm9_fp2_to_bytes(x, octets); + sm9_fp2_to_bytes(y, octets + 32 * 2); + return 1; } int sm9_twist_point_from_uncompressed_octets(sm9_twist_point_t *P, const uint8_t octets[129]) { - return -1; + sm9_fp2_from_bytes(P->X, octets); + sm9_fp2_from_bytes(P->Y, octets + 32 * 2); + sm9_fp2_set_one(P->Z); + return 1; } - diff --git a/src/sm9_lib.c b/src/sm9_lib.c index 1fa958c1..6eda85ad 100644 --- a/src/sm9_lib.c +++ b/src/sm9_lib.c @@ -139,6 +139,8 @@ int sm9_sign_finish(SM9_SIGN_CTX *ctx, const SM9_SIGN_KEY *key, uint8_t *sig, si return 1; } +#define hex_r "00033C8616B06704813203DFD00965022ED15975C662337AED648835DC4B1CBE" + int sm9_do_sign(const SM9_SIGN_KEY *key, const SM3_CTX *sm3_ctx, SM9_SIGNATURE *sig) { sm9_fn_t r; @@ -156,6 +158,7 @@ int sm9_do_sign(const SM9_SIGN_KEY *key, const SM3_CTX *sm3_ctx, SM9_SIGNATURE * do { // A2: rand r in [1, N-1] sm9_fn_rand(r); + //sm9_bn_from_hex(r, hex_r); // A3: w = g^r sm9_fp12_pow(g, g, r); diff --git a/tests/sm9test.c b/tests/sm9test.c index 28867d99..6ce5ac50 100644 --- a/tests/sm9test.c +++ b/tests/sm9test.c @@ -502,6 +502,46 @@ err: return -1; } +#define hex_ks "000130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4" +#define hex_t2 "291FE3CAC8F58AD2DC462C8D4D578A94DAFD5624DDC28E328D2936688A86CF1A" +#define hex_Ppubs "9f64080b3084f733e48aff4b41b565011ce0711c5e392cfb0ab1b6791b94c408\ +-29dba116152d1f786ce843ed24a3b573414d2177386a92dd8f14d65696ea5e32\ +-69850938abea0112b57329f447e3a0cbad3e2fdb1a77f335e89e1408d0ef1c25\ +-41e00a53dda532da1a7ce027b7a46f741006e85f5cdff0730e75c05fb4e3216d" +#define hex_ds "A5702F05CF1315305E2D6EB64B0DEB923DB1A0BCF0CAFF90523AC8754AA69820-78559A844411F9825C109F5EE3F52D720DD01785392A727BB1556952B2B013D3" +#define hex_h_ans "823C4B21E4BD2DFE1ED92C606653E996668563152FC33F55D7BFBB9BD9705ADB" +#define hex_S_ans "73BF96923CE58B6AD0E13E9643A406D8EB98417C50EF1B29CEF9ADB48B6D598C-856712F1C2E0968AB7769F42A99586AED139D5B8B3E15891827CC2ACED9BAA05" + +int test_sm9_sign() { + SM9_SIGN_CTX ctx; + SM9_SIGN_KEY key; + SM9_SIGNATURE sig; + sm9_twist_point_t ts; + sm9_point_t s; + sm9_bn_t k; + int j = 1; + + uint8_t data[20] = {0x43, 0x68, 0x69, 0x6E, 0x65, 0x73, 0x65, 0x20, 0x49, 0x42, 0x53, 0x20, 0x73, 0x74, 0x61, 0x6E, 0x64, 0x61, 0x72, 0x64}; + + sm9_bn_from_hex(k, hex_ks); sm9_twist_point_mul_generator(&(key.Ppubs), k); + sm9_twist_point_from_hex(&ts, hex_Ppubs); if (!sm9_twist_point_equ(&ts, &(key.Ppubs))) goto err; ++j; + sm9_bn_from_hex(k, hex_t2); sm9_point_mul_generator(&(key.ds), k); + sm9_point_from_hex(&s, hex_ds); if (!sm9_point_equ(&(key.ds), &s)) goto err; ++j; + + sm9_sign_init(&ctx); + sm9_sign_update(&ctx, data, sizeof(data)); + sm9_do_sign(&key, &(ctx.sm3_ctx), &sig); + sm9_bn_from_hex(k, hex_h_ans); if (!sm9_fn_equ(sig.h, k)) goto err; ++j; + sm9_point_from_hex(&s, hex_S_ans); if (!sm9_point_equ(&(sig.S), &s)) goto err; ++j; + + printf("%s() ok\n", __FUNCTION__); + return 1; +err: + printf("%s test %d failed\n", __FUNCTION__, j); + error_print(); + return -1; +} + int main(void) { if (test_sm9_fp() != 1) goto err; if (test_sm9_fn() != 1) goto err; @@ -511,6 +551,7 @@ int main(void) { if (test_sm9_point() != 1) goto err; if (test_sm9_twist_point() != 1) goto err; if (test_sm9_pairing() != 1) goto err; + // if (test_sm9_sign() != 1) goto err; /* Must open "#define hex_r" in sm9_lib.c */ printf("%s all tests passed\n", __FILE__); return 0;