diff --git a/CMakeLists.txt b/CMakeLists.txt index f81dbd99..54b13c84 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,6 +45,7 @@ add_library( src/sm2_lib.c src/sm3.c src/sm3_hmac.c + src/sm3_kdf.c src/sm4_common.c src/sm4_enc.c src/sm4_modes.c diff --git a/include/gmssl/mem.h b/include/gmssl/mem.h index 2cb2cf98..3f912cce 100644 --- a/include/gmssl/mem.h +++ b/include/gmssl/mem.h @@ -58,6 +58,7 @@ void gmssl_memxor(void *r, const void *a, const void *b, size_t len); int gmssl_secure_memcmp(const volatile void * volatile in_a, const volatile void * volatile in_b, size_t len); void gmssl_secure_clear(void *ptr, size_t len); +int mem_is_zero(const uint8_t *buf, size_t len); #endif diff --git a/include/gmssl/sm3.h b/include/gmssl/sm3.h index aa2b3dca..cacc75ae 100644 --- a/include/gmssl/sm3.h +++ b/include/gmssl/sm3.h @@ -91,6 +91,16 @@ void sm3_hmac(const uint8_t *key, size_t keylen, uint8_t mac[SM3_HMAC_SIZE]); +typedef struct { + SM3_CTX sm3_ctx; + size_t outlen; +} SM3_KDF_CTX; + +void sm3_kdf_init(SM3_KDF_CTX *ctx, size_t outlen); +void sm3_kdf_update(SM3_KDF_CTX *ctx, const uint8_t *data, size_t datalen); +void sm3_kdf_finish(SM3_KDF_CTX *ctx, uint8_t *out); + + #ifdef __cplusplus } #endif diff --git a/include/gmssl/sm9.h b/include/gmssl/sm9.h index 8a6ff468..0ba06d31 100644 --- a/include/gmssl/sm9.h +++ b/include/gmssl/sm9.h @@ -268,7 +268,7 @@ void sm9_twist_point_add(sm9_twist_point_t *R, const sm9_twist_point_t *P, const void sm9_twist_point_sub(sm9_twist_point_t *R, const sm9_twist_point_t *P, const sm9_twist_point_t *Q); void sm9_twist_point_add_full(sm9_twist_point_t *R, const sm9_twist_point_t *P, const sm9_twist_point_t *Q); void sm9_twist_point_mul(sm9_twist_point_t *R, const sm9_bn_t k, const sm9_twist_point_t *P); -void sm9_twist_point_mul_G(sm9_twist_point_t *R, const sm9_bn_t k); +void sm9_twist_point_mul_generator(sm9_twist_point_t *R, const sm9_bn_t k); void sm9_eval_g_tangent(sm9_fp12_t num, sm9_fp12_t den, const sm9_twist_point_t *P, const sm9_point_t *Q); void sm9_eval_g_line(sm9_fp12_t num, sm9_fp12_t den, const sm9_twist_point_t *T, const sm9_twist_point_t *P, const sm9_point_t *Q); @@ -302,9 +302,20 @@ int sm9_fn_equ(const sm9_fn_t a, const sm9_fn_t b); void sm9_fn_rand(sm9_fn_t r); void sm9_fp12_to_bytes(const sm9_fp12_t a, uint8_t buf[32 * 12]); +int sm9_fn_from_hash(sm9_fn_t h, const uint8_t Ha[40]); int sm9_hash1(sm9_bn_t h1, const char *id, size_t idlen, uint8_t hid); +int sm9_point_to_bytes(const sm9_point_t *P, uint8_t out[32 * 2]); +int sm9_point_from_bytes(sm9_point_t *P, const uint8_t in[32 * 2]); +int sm9_twist_point_to_bytes(const sm9_twist_point_t *P, uint8_t out[32 * 2]); +int sm9_twist_point_from_bytes(sm9_twist_point_t *P, const uint8_t in[32 * 2]); + + + + + + // set the same value as sm2 #define SM9_MAX_ID_BITS 65535 #define SM9_MAX_ID_SIZE (SM9_MAX_ID_BITS/8) @@ -332,24 +343,29 @@ typedef struct { int sm9_sign_master_key_generate(SM9_SIGN_MASTER_KEY *master); int sm9_sign_master_key_extract_key(SM9_SIGN_MASTER_KEY *master, const char *id, size_t idlen, SM9_SIGN_KEY *key); -typedef struct { - SM3_CTX sm3_ctx; - SM9_SIGN_KEY key; -} SM9_SIGN_CTX; typedef struct { sm9_fn_t h; sm9_point_t S; } SM9_SIGNATURE; +int sm9_do_sign(const SM9_SIGN_KEY *key, const SM3_CTX *sm3_ctx, SM9_SIGNATURE *sig); +int sm9_do_verify(const SM9_SIGN_MASTER_KEY *mpk, const char *id, size_t idlen, + const SM3_CTX *sm3_ctx, const SM9_SIGNATURE *sig); + + +typedef struct { + SM3_CTX sm3_ctx; +} SM9_SIGN_CTX; + int sm9_sign_init(SM9_SIGN_CTX *ctx); int sm9_sign_update(SM9_SIGN_CTX *ctx, const uint8_t *data, size_t datalen); -int sm9_sign_finish(SM9_SIGN_CTX *ctx, SM9_SIGN_KEY *key, SM9_SIGNATURE *sig); +int sm9_sign_finish(SM9_SIGN_CTX *ctx, const SM9_SIGN_KEY *key, uint8_t *sig, size_t *siglen); int sm9_verify_init(SM9_SIGN_CTX *ctx); int sm9_verify_update(SM9_SIGN_CTX *ctx, const uint8_t *data, size_t datalen); -int sm9_verify_finish(SM9_SIGN_CTX *ctx, const SM9_SIGNATURE *sig, - const SM9_SIGN_MASTER_KEY *master_public, const char *id, size_t idlen); +int sm9_verify_finish(SM9_SIGN_CTX *ctx, const uint8_t *sig, size_t siglen, + const SM9_SIGN_MASTER_KEY *mpk, const char *id, size_t idlen); typedef struct { @@ -365,7 +381,12 @@ typedef struct { int sm9_enc_master_key_generate(SM9_ENC_MASTER_KEY *master); int sm9_enc_master_key_extract_key(SM9_ENC_MASTER_KEY *master, const char *id, size_t idlen, SM9_ENC_KEY *key); - +int sm9_kem_encrypt(const SM9_ENC_MASTER_KEY *mpk, const char *id, size_t idlen, size_t klen, uint8_t *kbuf, uint8_t cbuf[64]); +int sm9_kem_decrypt(const SM9_ENC_KEY *key, const char *id, size_t idlen, const uint8_t cbuf[64], size_t klen, uint8_t *kbuf); +int sm9_do_encrypt(const SM9_ENC_MASTER_KEY *mpk, const char *id, size_t idlen, + const uint8_t *in, size_t inlen, uint8_t C1[64], uint8_t *C2, uint8_t C3[32]); +int sm9_do_decrypt(const SM9_ENC_KEY *key, const char *id, size_t idlen, + const uint8_t C1[64], const uint8_t *C2, size_t C2len, const uint8_t C3[32], uint8_t *out); # ifdef __cplusplus } diff --git a/src/hex.c b/src/hex.c index 563bed4c..f3285ade 100644 --- a/src/hex.c +++ b/src/hex.c @@ -240,3 +240,18 @@ void gmssl_secure_clear(void *ptr, size_t len) { memset_func(ptr, 0, len); } + +int mem_is_zero(const uint8_t *buf, size_t len) +{ + int ret = 1; + size_t i; + for (i = 0; i < len; i++) { + if (buf[i]) ret = 0; + } + return ret; +} + + + + + diff --git a/src/sm3_kdf.c b/src/sm3_kdf.c new file mode 100644 index 00000000..89ef929c --- /dev/null +++ b/src/sm3_kdf.c @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2014 - 2021 The GmSSL Project. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * + * 3. All advertising materials mentioning features or use of this + * software must display the following acknowledgment: + * "This product includes software developed by the GmSSL Project. + * (http://gmssl.org/)" + * + * 4. The name "GmSSL Project" must not be used to endorse or promote + * products derived from this software without prior written + * permission. For written permission, please contact + * guanzhi1980@gmail.com. + * + * 5. Products derived from this software may not be called "GmSSL" + * nor may "GmSSL" appear in their names without prior written + * permission of the GmSSL Project. + * + * 6. Redistributions of any form whatsoever must retain the following + * acknowledgment: + * "This product includes software developed by the GmSSL Project + * (http://gmssl.org/)" + * + * THIS SOFTWARE IS PROVIDED BY THE GmSSL PROJECT ``AS IS'' AND ANY + * EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE GmSSL PROJECT OR + * ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT + * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED + * OF THE POSSIBILITY OF SUCH DAMAGE. + */ + + +#include +#include +#include +#include + + +void sm3_kdf_init(SM3_KDF_CTX *ctx, size_t outlen) +{ + sm3_init(&ctx->sm3_ctx); + ctx->outlen = outlen; +} + +void sm3_kdf_update(SM3_KDF_CTX *ctx, const uint8_t *data, size_t datalen) +{ + sm3_update(&ctx->sm3_ctx, data, datalen); +} + +void sm3_kdf_finish(SM3_KDF_CTX *ctx, uint8_t *out) +{ + SM3_CTX sm3_ctx; + size_t outlen = ctx->outlen; + uint8_t counter_be[4]; + uint8_t dgst[SM3_DIGEST_SIZE]; + uint32_t counter = 1; + size_t len; + + while (outlen) { + PUTU32(counter_be, counter); + counter++; + + sm3_ctx = ctx->sm3_ctx; + sm3_update(&sm3_ctx, counter_be, sizeof(counter_be)); + sm3_finish(&sm3_ctx, dgst); + + len = outlen < SM3_DIGEST_SIZE ? outlen : SM3_DIGEST_SIZE; + memcpy(out, dgst, len); + out += len; + outlen -= len; + } + + memset(&sm3_ctx, 0, sizeof(SM3_CTX)); + memset(dgst, 0, sizeof(dgst)); +} diff --git a/src/sm9_alg.c b/src/sm9_alg.c index 9dd9626e..979b156c 100644 --- a/src/sm9_alg.c +++ b/src/sm9_alg.c @@ -1687,7 +1687,7 @@ void sm9_twist_point_mul(sm9_twist_point_t *R, const sm9_bn_t k, const sm9_twist sm9_twist_point_copy(R, Q); } -void sm9_twist_point_mul_G(sm9_twist_point_t *R, const sm9_bn_t k) +void sm9_twist_point_mul_generator(sm9_twist_point_t *R, const sm9_bn_t k) { sm9_twist_point_mul(R, k, SM9_P2); } @@ -1927,3 +1927,78 @@ void sm9_pairing(sm9_fp12_t r, const sm9_twist_point_t *Q, const sm9_point_t *P) sm9_final_exponent(r, r); } +void sm9_fn_add(sm9_fn_t r, const sm9_fn_t a, const sm9_fn_t b) +{ +} + +void sm9_fn_sub(sm9_fn_t r, const sm9_fn_t a, const sm9_fn_t b) +{ +} + +void sm9_fn_mul(sm9_fn_t r, const sm9_fn_t a, const sm9_fn_t b) +{ +} + +void sm9_fn_inv(sm9_fn_t r, const sm9_fn_t a) +{ +} + +int sm9_fn_is_zero(const sm9_fn_t a) +{ + return 0; +} +void sm9_fn_rand(sm9_fn_t r) +{ + // FIXME: add impl +} + +int sm9_fn_equ(const sm9_fn_t a, const sm9_fn_t b) +{ + // FIXME: add impl + return 1; +} + +// for H1() and H2() +// h = (Ha mod (n-1)) + 1; h in [1, n-1], n is the curve order, Ha is 40 bytes from hash +int sm9_fn_from_hash(sm9_fn_t h, const uint8_t Ha[40]) +{ + return 1; +} + +void sm9_fp12_to_bytes(const sm9_fp12_t a, uint8_t buf[32 * 12]) +{ + // FIXME: add impl +} + +int sm9_point_to_bytes(const sm9_point_t *P, uint8_t out[32 * 2]) +{ + // FIXME + return 1; +} + +int sm9_point_from_bytes(sm9_point_t *P, const uint8_t in[32 * 2]) +{ + // FIXME + return 1; +} + +int sm9_twist_point_to_bytes(const sm9_twist_point_t *P, uint8_t out[32 * 2]) +{ + // FIXME + return 1; +} + +int sm9_twist_point_from_bytes(sm9_twist_point_t *P, const uint8_t in[32 * 2]) +{ + // FIXME + return 1; +} + + + + + + + + + diff --git a/src/sm9_key.c b/src/sm9_key.c index 696bb662..01891781 100644 --- a/src/sm9_key.c +++ b/src/sm9_key.c @@ -54,116 +54,99 @@ #include - - - // generate h1 in [1, n-1] int sm9_hash1(sm9_bn_t h1, const char *id, size_t idlen, uint8_t hid) { - sm9_fn_t h; - SM3_CTX ctx1; - SM3_CTX ctx2; - + SM3_CTX ctx; uint8_t prefix[1] = {0x01}; uint8_t ct1[4] = {0x00, 0x00, 0x00, 0x01}; uint8_t ct2[4] = {0x00, 0x00, 0x00, 0x02}; - uint8_t buf[64]; + uint8_t Ha[64]; - sm3_init(&ctx1); - sm3_update(&ctx1, prefix, sizeof(prefix)); - sm3_update(&ctx1, (uint8_t *)id, idlen); - sm3_update(&ctx1, &hid, 1); - ctx2 = ctx1; - sm3_update(&ctx1, ct1, sizeof(ct1)); - sm3_update(&ctx2, ct2, sizeof(ct2)); - sm3_finish(&ctx1, buf); - sm3_finish(&ctx2, buf + 32); + sm3_init(&ctx); + sm3_update(&ctx, prefix, sizeof(prefix)); + sm3_update(&ctx, (uint8_t *)id, idlen); + sm3_update(&ctx, &hid, 1); + sm3_update(&ctx, ct1, sizeof(ct1)); + sm3_finish(&ctx, Ha); - // 这个buflen == 64,我们要将长为40的部分取出来,模 N-1 再加1 - return -1; + sm3_init(&ctx); + sm3_update(&ctx, prefix, sizeof(prefix)); + sm3_update(&ctx, (uint8_t *)id, idlen); + sm3_update(&ctx, &hid, 1); + sm3_update(&ctx, ct2, sizeof(ct2)); + sm3_finish(&ctx, Ha + 32); + + sm9_fn_from_hash(h1, Ha); + return 1; } -void sm9_fn_add(sm9_fn_t r, const sm9_fn_t a, const sm9_fn_t b) -{ -} - -void sm9_fn_sub(sm9_fn_t r, const sm9_fn_t a, const sm9_fn_t b) -{ -} - -void sm9_fn_mul(sm9_fn_t r, const sm9_fn_t a, const sm9_fn_t b) -{ -} - -void sm9_fn_inv(sm9_fn_t r, const sm9_fn_t a) -{ -} - -int sm9_fn_is_zero(const sm9_fn_t a) -{ - return 0; -} - -int sm9_sign_master_key_generate(SM9_SIGN_MASTER_KEY *master) +int sm9_sign_master_key_generate(SM9_SIGN_MASTER_KEY *msk) { // k = rand(1, n-1) - //sm9_bn_rand_range(master->ks, SM9_N); + sm9_fn_rand(msk->ks); // Ppubs = k * P2 in E'(F_p^2) - sm9_twist_point_mul_G(&master->Ppubs, master->ks); + sm9_twist_point_mul_generator(&msk->Ppubs, msk->ks); return 1; } -int sm9_enc_master_key_generate(SM9_ENC_MASTER_KEY *master) +int sm9_enc_master_key_generate(SM9_ENC_MASTER_KEY *msk) { // k = rand(1, n-1) - //sm9_bn_rand_range(master->ke, SM9_N); + sm9_fn_rand(msk->ke); // Ppube = ke * P1 in E(F_p) - sm9_point_mul_generator(&master->Ppube, master->ke); + sm9_point_mul_generator(&msk->Ppube, msk->ke); return 1; } -int sm9_sign_master_key_extract_key(SM9_SIGN_MASTER_KEY *master, const char *id, size_t idlen, SM9_SIGN_KEY *key) +int sm9_sign_master_key_extract_key(SM9_SIGN_MASTER_KEY *msk, const char *id, size_t idlen, SM9_SIGN_KEY *key) { sm9_fn_t t; + // t1 = H1(ID || hid, N) + ks sm9_hash1(t, id, idlen, SM9_HID_SIGN); - sm9_fn_add(t, t, master->ks); + sm9_fn_add(t, t, msk->ks); if (sm9_fn_is_zero(t)) { + // 这是一个严重问题,意味着整个msk都需要作废了 error_print(); return -1; } + + // t2 = ks * t1^-1 sm9_fn_inv(t, t); - sm9_fn_mul(t, t, master->ks); + sm9_fn_mul(t, t, msk->ks); + + // ds = t2 * P1 sm9_point_mul_generator(&key->ds, t); + key->Ppubs = msk->Ppubs; + return 1; } -int sm9_enc_master_key_extract_key(SM9_ENC_MASTER_KEY *master, const char *id, size_t idlen, +int sm9_enc_master_key_extract_key(SM9_ENC_MASTER_KEY *msk, const char *id, size_t idlen, SM9_ENC_KEY *key) { sm9_fn_t t; + // t1 = H1(ID || hid, N) + ke sm9_hash1(t, id, idlen, SM9_HID_ENC); - sm9_fn_add(t, t, master->ke); + sm9_fn_add(t, t, msk->ke); if (sm9_fn_is_zero(t)) { error_print(); return -1; } + + // t2 = ke * t1^-1 sm9_fn_inv(t, t); - sm9_fn_mul(t, t, master->ke); - sm9_twist_point_mul_G(&key->de, t); + sm9_fn_mul(t, t, msk->ke); + + // de = t2 * P2 + sm9_twist_point_mul_generator(&key->de, t); + key->Ppube = msk->Ppube; + return 1; } - - - - - - - - - diff --git a/src/sm9_lib.c b/src/sm9_lib.c index 3e69b30f..086ea94c 100644 --- a/src/sm9_lib.c +++ b/src/sm9_lib.c @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2014 - 2020 The GmSSL Project. All rights reserved. * * Redistribution and use in source and binary forms, with or without @@ -50,27 +50,12 @@ #include #include #include +#include #include #include #include -void sm9_fn_rand(sm9_fn_t r) -{ - // FIXME: add impl -} - -int sm9_fn_equ(const sm9_fn_t a, const sm9_fn_t b) -{ - // FIXME: add impl - return 1; -} - -void sm9_fp12_to_bytes(const sm9_fp12_t a, uint8_t buf[32 * 12]) -{ - // FIXME: add impl -} - int sm9_sign_init(SM9_SIGN_CTX *ctx) { const uint8_t prefix[1] = {0x02}; @@ -85,26 +70,56 @@ int sm9_sign_update(SM9_SIGN_CTX *ctx, const uint8_t *data, size_t datalen) return 1; } -int sm9_sign_finish(SM9_SIGN_CTX *ctx, SM9_SIGN_KEY *key, SM9_SIGNATURE *sig) +int sm9_sign_finish(SM9_SIGN_CTX *ctx, const SM9_SIGN_KEY *key, uint8_t *sig, size_t *siglen) +{ + return -1; +} + +int sm9_do_sign(const SM9_SIGN_KEY *key, const SM3_CTX *sm3_ctx, SM9_SIGNATURE *sig) { sm9_fn_t r; - sm9_fn_t h; sm9_fp12_t g; - sm9_fp12_t w; uint8_t wbuf[32 * 12]; - uint8_t dgst[32]; + SM3_CTX ctx = *sm3_ctx; + SM3_CTX tmp_ctx; + uint8_t ct1[4] = {0,0,0,1}; + uint8_t ct2[4] = {0,0,0,2}; + uint8_t Ha[64]; + // A1: g = e(P1, Ppubs) sm9_pairing(g, &key->Ppubs, SM9_P1); + do { + // A2: rand r in [1, N-1] sm9_fn_rand(r); - sm9_fp12_pow(w, g, r); - sm9_fp12_to_bytes(w, wbuf); - sm3_update(&ctx->sm3_ctx, wbuf, sizeof(wbuf)); - sm3_finish(&ctx->sm3_ctx, dgst); - // do H2() staff, generate output sig->h - sm9_fn_sub(r, r, h); + + // A3: w = g^r + sm9_fp12_pow(g, g, r); + sm9_fp12_to_bytes(g, wbuf); + + // A4: h = H2(M || w, N) + sm3_update(&ctx, wbuf, sizeof(wbuf)); + tmp_ctx = ctx; + sm3_update(&ctx, ct1, sizeof(ct1)); + sm3_finish(&ctx, Ha); + sm3_update(&tmp_ctx, ct2, sizeof(ct2)); + sm3_finish(&tmp_ctx, Ha + 32); + sm9_fn_from_hash(sig->h, Ha); + + // A5: l = (r - h) mod N, if l = 0, goto A2 + sm9_fn_sub(r, r, sig->h); + } while (sm9_fn_is_zero(r)); + + // A6: S = l * dsA sm9_point_mul(&sig->S, r, &key->ds); + + gmssl_secure_clear(&r, sizeof(r)); + gmssl_secure_clear(&g, sizeof(g)); + gmssl_secure_clear(wbuf, sizeof(wbuf)); + gmssl_secure_clear(&tmp_ctx, sizeof(tmp_ctx)); + gmssl_secure_clear(Ha, sizeof(Ha)); + return 1; } @@ -122,9 +137,14 @@ int sm9_verify_update(SM9_SIGN_CTX *ctx, const uint8_t *data, size_t datalen) return 1; } -// 签名的时候 -int sm9_verify_finish(SM9_SIGN_CTX *ctx, const SM9_SIGNATURE *sig, - const SM9_SIGN_MASTER_KEY *master_public, const char *id, size_t idlen) +int sm9_verify_finish(SM9_SIGN_CTX *ctx, const uint8_t *sig, size_t siglen, + const SM9_SIGN_MASTER_KEY *mpk, const char *id, size_t idlen) +{ + return -1; +} + +int sm9_do_verify(const SM9_SIGN_MASTER_KEY *mpk, const char *id, size_t idlen, + const SM3_CTX *sm3_ctx, const SM9_SIGNATURE *sig) { sm9_fn_t h1; sm9_fn_t h2; @@ -134,23 +154,163 @@ int sm9_verify_finish(SM9_SIGN_CTX *ctx, const SM9_SIGNATURE *sig, sm9_fp12_t w; sm9_twist_point_t P; uint8_t wbuf[32 * 12]; + SM3_CTX ctx = *sm3_ctx; + SM3_CTX tmp_ctx; + uint8_t ct1[4] = {0,0,0,1}; + uint8_t ct2[4] = {0,0,0,2}; + uint8_t Ha[64]; - sm9_pairing(g, &master_public->Ppubs, SM9_P1); + // B1: check h in [1, N-1] + + // B2: check S in G1 + + // B3: g = e(P1, Ppubs) + sm9_pairing(g, &mpk->Ppubs, SM9_P1); + + // B4: t = g^h sm9_fp12_pow(t, g, sig->h); + + // B5: h1 = H1(ID || hid, N) sm9_hash1(h1, id, idlen, SM9_HID_SIGN); - sm9_twist_point_mul_G(&P, h1); - sm9_twist_point_add(&P, &P, &master_public->Ppubs); + + // B6: P = h1 * P2 + Ppubs + sm9_twist_point_mul_generator(&P, h1); + sm9_twist_point_add(&P, &P, &mpk->Ppubs); + + // B7: u = e(S, P) sm9_pairing(u, &P, &sig->S); + + // B8: w = u * t sm9_fp12_mul(w, u, t); sm9_fp12_to_bytes(w, wbuf); - sm3_update(&ctx->sm3_ctx, wbuf, sizeof(wbuf)); - // convert h2 - + // B9: h2 = H2(M || w, N), check h2 == h + sm3_update(&ctx, wbuf, sizeof(wbuf)); + tmp_ctx = ctx; + sm3_update(&ctx, ct1, sizeof(ct1)); + sm3_finish(&ctx, Ha); + sm3_update(&tmp_ctx, ct2, sizeof(ct2)); + sm3_finish(&tmp_ctx, Ha + 32); + sm9_fn_from_hash(h2, Ha); if (sm9_fn_equ(h2, sig->h) != 1) { return 0; } + return 1; } +int sm9_kem_encrypt(const SM9_ENC_MASTER_KEY *mpk, const char *id, size_t idlen, + size_t klen, uint8_t *kbuf, uint8_t cbuf[64]) +{ + sm9_fn_t r; + sm9_fp12_t w; + sm9_point_t C; + uint8_t wbuf[32 * 12]; + SM3_KDF_CTX kdf_ctx; + // A1: Q = H1(ID||hid,N) * P1 + Ppube + sm9_hash1(r, id, idlen, SM9_HID_EXCH); + sm9_point_mul(&C, r, SM9_P1); + sm9_point_add(&C, &C, &mpk->Ppube); + + do { + // A2: rand r in [1, N-1] + sm9_fn_rand(r); + + // A3: C1 = r * Q + sm9_point_mul(&C, r, &C); + sm9_point_to_bytes(&C, cbuf); + + // A4: g = e(Ppube, P2) + sm9_pairing(w, SM9_P2, &mpk->Ppube); + + // A5: w = g^r + sm9_fp12_pow(w, w, r); + sm9_fp12_to_bytes(w, wbuf); + + // A6: K = KDF(C || w || ID_B, klen), if K == 0, goto A2 + sm3_kdf_init(&kdf_ctx, klen); + sm3_kdf_update(&kdf_ctx, cbuf, 64); + sm3_kdf_update(&kdf_ctx, wbuf, sizeof(wbuf)); + sm3_kdf_update(&kdf_ctx, (uint8_t *)id, idlen); + sm3_kdf_finish(&kdf_ctx, kbuf); + + } while (mem_is_zero(kbuf, klen) == 1); + + gmssl_secure_clear(&r, sizeof(r)); + gmssl_secure_clear(&w, sizeof(w)); + gmssl_secure_clear(&C, sizeof(C)); + gmssl_secure_clear(wbuf, sizeof(wbuf)); + gmssl_secure_clear(&kdf_ctx, sizeof(kdf_ctx)); + + // A7: output (K, C) + return 1; +} + +int sm9_kem_decrypt(const SM9_ENC_KEY *key, const char *id, size_t idlen, const uint8_t cbuf[64], + size_t klen, uint8_t *kbuf) +{ + sm9_fp12_t w; + sm9_point_t C; + uint8_t wbuf[32 * 12]; + SM3_KDF_CTX kdf_ctx; + + // B1: check C in G1 + if (sm9_point_from_bytes(&C, cbuf) != 1) { + error_print(); + return -1; + } + + // B2: w = e(C, de); + sm9_pairing(w, &key->de, &C); + + // B3: K = KDF(C || w || ID, klen) + sm3_kdf_init(&kdf_ctx, klen); + sm3_kdf_update(&kdf_ctx, cbuf, 64); + sm3_kdf_update(&kdf_ctx, wbuf, sizeof(wbuf)); + sm3_kdf_update(&kdf_ctx, (uint8_t *)id, idlen); + sm3_kdf_finish(&kdf_ctx, kbuf); + + if (mem_is_zero(kbuf, klen) != 1) { + error_print(); + return -1; + } + + gmssl_secure_clear(&w, sizeof(w)); + gmssl_secure_clear(&C, sizeof(C)); + gmssl_secure_clear(wbuf, sizeof(wbuf)); + gmssl_secure_clear(&kdf_ctx, sizeof(kdf_ctx)); + + // B4: output K + return 1; +} + +int sm9_do_encrypt(const SM9_ENC_MASTER_KEY *mpk, const char *id, size_t idlen, + const uint8_t *in, size_t inlen, + uint8_t C1[64], uint8_t *C2, uint8_t C3[32]) +{ + uint8_t K[inlen + 32]; + + sm9_kem_encrypt(mpk, id, idlen, sizeof(K), K, C1); + gmssl_memxor(C2, K, in, inlen); + sm3_hmac(K + inlen, 32, C2, inlen, C3); + + return 1; +} + +int sm9_do_decrypt(const SM9_ENC_KEY *key, const char *id, size_t idlen, + const uint8_t C1[64], const uint8_t *C2, size_t C2len, const uint8_t C3[32], + uint8_t *out) +{ + uint8_t K[C2len + 32]; + uint8_t mac[32]; + + sm9_kem_decrypt(key, id, idlen, C1, sizeof(K), K); + sm3_hmac(K + C2len, 32, C2, C2len, mac); + if (gmssl_secure_memcmp(C3, mac, sizeof(mac)) != 0) { + error_print(); + return -1; + } + gmssl_memxor(out, K, C2, C2len); + return 1; +} diff --git a/tests/hash_drbgtest.c b/tests/hash_drbgtest.c index 23adde47..d1fcef0f 100644 --- a/tests/hash_drbgtest.c +++ b/tests/hash_drbgtest.c @@ -99,7 +99,6 @@ int main(void) hex_to_bytes(PR1, strlen(PR1), pr1, &pr1_len); hex_to_bytes(PR2, strlen(PR2), pr2, &pr2_len); - hash_drbg_init(&drbg, DIGEST_sha1(), entropy, entropy_len, nonce, nonce_len, @@ -111,6 +110,7 @@ int main(void) || memcmp(drbg.C, c, clen) != 0 || drbg.reseed_counter != 1) { printf("failed\n"); + return 1; } else { printf("ok\n"); } @@ -118,7 +118,6 @@ int main(void) hash_drbg_reseed(&drbg, pr1, pr1_len, NULL, 0); hash_drbg_generate(&drbg, NULL, 0, 640/8, out); - hash_drbg_reseed(&drbg, pr2, pr2_len, NULL, 0); hash_drbg_generate(&drbg, NULL, 0, 640/8, out); diff --git a/tests/sm9test.c b/tests/sm9test.c index 25dd0d8e..84743301 100644 --- a/tests/sm9test.c +++ b/tests/sm9test.c @@ -388,7 +388,7 @@ int test_sm9_twist_point() { sm9_twist_point_add_full(&r, &p, &q); if (!sm9_twist_point_equ(&r, &s)) goto err; ++j; sm9_twist_point_sub(&r, &p, &q); sm9_twist_point_from_hex(&s, hex_tpoint_sub); if (!sm9_twist_point_equ(&r, &s)) goto err; ++j; sm9_twist_point_mul(&r, k, &p); sm9_twist_point_from_hex(&s, hex_tpoint_mul); if (!sm9_twist_point_equ(&r, &s)) goto err; ++j; - sm9_twist_point_mul_G(&r, k); sm9_twist_point_from_hex(&s, hex_tpoint_mulg); if (!sm9_twist_point_equ(&r, &s)) goto err; ++j; + sm9_twist_point_mul_generator(&r, k); sm9_twist_point_from_hex(&s, hex_tpoint_mulg); if (!sm9_twist_point_equ(&r, &s)) goto err; ++j; printf("%s() ok\n", __FUNCTION__); return 1;