diff --git a/include/gmssl/sm2.h b/include/gmssl/sm2.h index adf78b60..49d7ebb1 100644 --- a/include/gmssl/sm2.h +++ b/include/gmssl/sm2.h @@ -76,7 +76,7 @@ void sm2_bn_to_bits(const SM2_BN a, char bits[256]); void sm2_bn_set_word(SM2_BN r, uint32_t a); void sm2_bn_add(SM2_BN r, const SM2_BN a, const SM2_BN b); void sm2_bn_sub(SM2_BN ret, const SM2_BN a, const SM2_BN b); -void sm2_bn_rand_range(SM2_BN r, const SM2_BN range); // 这个函数需要修改一下,从外部引入随机数 +int sm2_bn_rand_range(SM2_BN r, const SM2_BN range); #define sm2_bn_init(r) memset((r),0,sizeof(SM2_BN)) #define sm2_bn_set_zero(r) memset((r),0,sizeof(SM2_BN)) @@ -98,7 +98,7 @@ void sm2_fp_div2(SM2_Fp r, const SM2_Fp a); void sm2_fp_neg(SM2_Fp r, const SM2_Fp a); void sm2_fp_sqr(SM2_Fp r, const SM2_Fp a); void sm2_fp_inv(SM2_Fp r, const SM2_Fp a); -void sm2_fp_rand(SM2_Fp r); // 外部提供随机性,如果满足条件就输出,如果不满足条件就哈希一下再输出 +int sm2_fp_rand(SM2_Fp r); int sm2_fp_sqrt(SM2_Fp r, const SM2_Fp a); @@ -114,11 +114,12 @@ typedef SM2_BN SM2_Fn; void sm2_fn_add(SM2_Fn r, const SM2_Fn a, const SM2_Fn b); void sm2_fn_sub(SM2_Fn r, const SM2_Fn a, const SM2_Fn b); void sm2_fn_mul(SM2_Fn r, const SM2_Fn a, const SM2_Fn b); +void sm2_fn_mul_word(SM2_Fn r, const SM2_Fn a, uint32_t b); void sm2_fn_exp(SM2_Fn r, const SM2_Fn a, const SM2_Fn e); void sm2_fn_neg(SM2_Fn r, const SM2_Fn a); void sm2_fn_sqr(SM2_Fn r, const SM2_Fn a); void sm2_fn_inv(SM2_Fn r, const SM2_Fn a); -void sm2_fn_rand(SM2_Fn r); +int sm2_fn_rand(SM2_Fn r); #define sm2_fn_init(r) sm2_bn_init(r) #define sm2_fn_set_zero(r) sm2_bn_set_zero(r) @@ -298,6 +299,19 @@ int sm2_private_key_info_decrypt_from_der(SM2_KEY *key, const uint8_t **attrs, s int sm2_private_key_info_encrypt_to_pem(const SM2_KEY *key, const char *pass, FILE *fp); int sm2_private_key_info_decrypt_from_pem(SM2_KEY *key, const char *pass, FILE *fp); +// SM2 Key Shamir Secret Sharing +typedef struct { + SM2_KEY key; + size_t index; + size_t total_cnt; +} SM2_KEY_SHARE; + +int sm2_key_split(const SM2_KEY *key, size_t recover_cnt, size_t total_cnt, SM2_KEY_SHARE *shares); +int sm2_key_recover(SM2_KEY *key, const SM2_KEY_SHARE *shares, size_t shares_cnt); +int sm2_key_share_encrypt_to_file(const SM2_KEY_SHARE *share, const char *pass, const char *path_prefix); +int sm2_key_share_decrypt_from_file(SM2_KEY_SHARE *share, const char *pass, const char *file); +int sm2_key_share_print(FILE *fp, int fmt, int ind, const char *label, const SM2_KEY_SHARE *share); + typedef struct { uint8_t r[32]; diff --git a/src/sm2_alg.c b/src/sm2_alg.c index 4aa7ef38..612744d4 100644 --- a/src/sm2_alg.c +++ b/src/sm2_alg.c @@ -317,14 +317,26 @@ void sm2_bn_sub(SM2_BN ret, const SM2_BN a, const SM2_BN b) sm2_bn_copy(ret, r); } -// FIXME: get random from outside -void sm2_bn_rand_range(SM2_BN r, const SM2_BN range) +int sm2_bn_rand_range(SM2_BN r, const SM2_BN range) { uint8_t buf[32]; do { - (void)rand_bytes(buf, sizeof(buf)); + if (rand_bytes(buf, sizeof(buf)) != 1) { + error_print(); + return -1; + } sm2_bn_from_bytes(r, buf); } while (sm2_bn_cmp(r, range) >= 0); + return 1; +} + +int sm2_fp_rand(SM2_Fp r) +{ + if (sm2_bn_rand_range(r, SM2_P) != 1) { + error_print(); + return -1; + } + return 1; } void sm2_fp_add(SM2_Fp r, const SM2_Fp a, const SM2_Fp b) @@ -618,21 +630,20 @@ static void sm2_bn288_sub(uint64_t ret[9], const uint64_t a[9], const uint64_t b } } -void sm2_fn_mul(SM2_BN r, const SM2_BN a, const SM2_BN b) +void sm2_fn_mul(SM2_BN ret, const SM2_BN a, const SM2_BN b) { - static const uint64_t mu[8] = { - 0xf15149a0, 0x12ac6361, 0xfa323c01, 0x8dfc2096, - 1, 1, 1, 0x100000001, + SM2_BN r; + static const uint64_t mu[9] = { + 0xf15149a0, 0x12ac6361, 0xfa323c01, 0x8dfc2096, 1, 1, 1, 1, 1, }; - uint64_t s[17]; + uint64_t s[18]; uint64_t zh[9]; uint64_t zl[9]; uint64_t q[9]; uint64_t w; int i, j; - /* z = a * b */ for (i = 0; i < 8; i++) { s[i] = 0; @@ -662,21 +673,20 @@ void sm2_fn_mul(SM2_BN r, const SM2_BN a, const SM2_BN b) } for (i = 0; i < 9; i++) { w = 0; - for (j = 0; j < 8; j++) { + for (j = 0; j < 9; j++) { w += s[i + j] + zh[i] * mu[j]; s[i + j] = w & 0xffffffff; w >>= 32; } - s[i + 8] = w; + s[i + 9] = w; } for (i = 0; i < 8; i++) { q[i] = s[9 + i]; } //printf("q = "); for (i = 7; i >= 0; i--) printf("%08x", (uint32_t)q[i]); printf("\n"); - /* q = q * n mod (2^32)^9 */ - for (i = 0; i < 8; i++) { + for (i = 0; i < 17; i++) { s[i] = 0; } for (i = 0; i < 8; i++) { @@ -691,7 +701,7 @@ void sm2_fn_mul(SM2_BN r, const SM2_BN a, const SM2_BN b) for (i = 0; i < 9; i++) { q[i] = s[i]; } - //printf("qn = "); for (i = 8; i >= 0; i--) printf("%08x", (uint32_t)q[i]); printf("\n"); + //printf("qn = "); for (i = 8; i >= 0; i--) printf("%08x ", (uint32_t)q[i]); printf("\n"); /* r = zl - q (mod (2^32)^9) */ @@ -702,9 +712,7 @@ void sm2_fn_mul(SM2_BN r, const SM2_BN a, const SM2_BN b) sm2_bn288_sub(q, c, q); sm2_bn288_add(zl, q, zl); } - - //printf("r = "); for (i = 8; i >= 0; i--) printf("%08x", (uint32_t)zl[i]); printf("\n"); - + //printf("zl = "); for (i = 8; i >= 0; i--) printf("%08x ", (uint32_t)zl[i]); printf("\n"); for (i = 0; i < 8; i++) { r[i] = zl[i]; } @@ -713,8 +721,16 @@ void sm2_fn_mul(SM2_BN r, const SM2_BN a, const SM2_BN b) /* while r >= p do: r = r - n */ while (sm2_bn_cmp(r, SM2_N) >= 0) { sm2_bn_sub(r, r, SM2_N); - //printf("r = r -n = "); for (i = 8; i >= 0; i--) printf("%08x", (uint32_t)zl[i]); printf("\n"); + //printf("r-n = "); for (i = 7; i >= 0; i--) printf("%16llx ", r[i]); printf("\n"); } + sm2_bn_copy(ret, r); +} + +void sm2_fn_mul_word(SM2_Fn r, const SM2_Fn a, uint32_t b) +{ + SM2_Fn t; + sm2_bn_set_word(t, b); + sm2_fn_mul(r, a, t); } void sm2_fn_sqr(SM2_BN r, const SM2_BN a) @@ -739,7 +755,6 @@ void sm2_fn_exp(SM2_BN r, const SM2_BN a, const SM2_BN e) w <<= 1; } } - sm2_bn_copy(r, t); } @@ -750,9 +765,13 @@ void sm2_fn_inv(SM2_BN r, const SM2_BN a) sm2_fn_exp(r, a, e); } -void sm2_fn_rand(SM2_BN r) +int sm2_fn_rand(SM2_BN r) { - sm2_bn_rand_range(r, SM2_N); + if (sm2_bn_rand_range(r, SM2_N) != 1) { + error_print(); + return -1; + } + return 1; } @@ -1088,7 +1107,7 @@ int sm2_point_from_x(SM2_POINT *P, const uint8_t x[32], int y) error_print(); return -1; } - + if ((y == 0x02 && sm2_bn_is_odd(_y)) || (y == 0x03) && !sm2_bn_is_odd(_y)) { sm2_fp_neg(_y, _y); } diff --git a/src/sm2_key_share.c b/src/sm2_key_share.c index d6f10144..ba55b507 100644 --- a/src/sm2_key_share.c +++ b/src/sm2_key_share.c @@ -12,144 +12,279 @@ #include #include #include +#include #include -typedef struct { - SM2_KEY key; - unsigned int index; - unsigned int total_cnt; -} SM2_KEY_SHARE; - - -static int sm2_fn_mul_word(SM2_Fn r, const SM2_Fn a, uint32_t b) +int sm2_key_share_print(FILE *fp, int fmt, int ind, const char *label, const SM2_KEY_SHARE *share) { - SM2_Fn t; - sm2_bn_set_word(t, b); - sm2_fn_mul(r, a, t); + format_print(fp, fmt, ind, "%s\n", label); + ind += 4; + format_print(fp, fmt, ind, "%zu/%zu\n", share->index, share->total_cnt); + format_print(fp, fmt, ind, "key", &share->key); return 1; } -static int eval_univariate_poly(const SM2_Fn *coeffs, size_t coeffs_cnt, unsigned int x, SM2_Fn out) -{ - sm2_bn_set_zero(out); +// y = f(x) +static void eval_univariate_poly(SM2_Fn y, const SM2_Fn *coeffs, size_t coeffs_cnt, uint32_t x) +{ + sm2_bn_set_zero(y); while (coeffs_cnt--) { - sm2_fn_mul_word(out, out, x); - sm2_fn_add(out, out, coeffs[coeffs_cnt]); + sm2_fn_mul_word(y, y, x); + sm2_fn_add(y, y, coeffs[coeffs_cnt]); } - return 1; } +#define SM2_KEY_MAX_SHARES 12 // 12! = 479001600 < 2^31 = 2147483648 + int sm2_key_split(const SM2_KEY *key, size_t recover_cnt, size_t total_cnt, SM2_KEY_SHARE *shares) { - SM2_Fn *coeffs = NULL; - size_t coeffs_cnt = recover_cnt; - size_t x; - size_t i; - + SM2_Fn coeffs[SM2_KEY_MAX_SHARES]; SM2_Fn y; uint8_t y_bytes[32]; + size_t i; - // f(x) = a_0 + a_1 * x + ... + a_(k-1) * x^(k-1) - // a_0 = private_key, a_i = rand(1, n-1) - if (!(coeffs = (SM2_Fn *)malloc(sizeof(SM2_Fn) * coeffs_cnt))) { + if (!key || !shares) { error_print(); return -1; } - sm2_bn_from_bytes(coeffs[0], key->private_key); + if (!total_cnt || total_cnt > SM2_KEY_MAX_SHARES) { + error_print(); + return -1; + } + if (!recover_cnt || recover_cnt > total_cnt) { + error_print(); + return -1; + } + // try to access mem + memset(shares, 0, sizeof(SM2_KEY_SHARE) * total_cnt); for (i = 1; i < recover_cnt; i++) { - sm2_fn_rand(coeffs[i]); // FIXME: check return value + if (sm2_fn_rand(coeffs[i]) != 1) { + error_print(); + return -1; + } } + sm2_bn_from_bytes(coeffs[0], key->private_key); - for (x = 1; x <= total_cnt; x++) { - SM2_KEY *key = &(shares[i].key); - // y = f(x) - eval_univariate_poly(coeffs, coeffs_cnt, x, y); - + for (i = 0; i < total_cnt; i++) { + uint32_t x = (uint32_t)(i + 1); + eval_univariate_poly(y, coeffs, recover_cnt, x); sm2_bn_to_bytes(y, y_bytes); - sm2_key_set_private_key(key, y_bytes); - - shares[i].index = x - 1; + sm2_key_set_private_key(&(shares[i].key), y_bytes); + shares[i].index = i; shares[i].total_cnt = total_cnt; } - memset(y, 0, sizeof(SM2_Fn)); - memset(y_bytes, 0, sizeof(y_bytes)); - memset(coeffs, 0, sizeof(SM2_Fn) * coeffs_cnt); - free(coeffs); + gmssl_secure_clear(coeffs, sizeof(coeffs)); + gmssl_secure_clear(y, sizeof(y)); + gmssl_secure_clear(y_bytes, sizeof(y_bytes)); return 1; } -// n is total_cnt, out is delta[] array -// for i=1..n, delta[i] = prod(-j/(i - j)) in GF(N), j = 1..n, j != i -int generate_delta_list(size_t total_cnt, SM2_Fn *out) +int sm2_key_recover(SM2_KEY *key, const SM2_KEY_SHARE *shares, size_t shares_cnt) { - SM2_Fn a; - size_t i, j; + SM2_Fn s; + uint8_t s_bytes[32]; + int x_i; + SM2_Fn y_i; + size_t i, j, k, n; - for (i = 0; i < total_cnt; i++) { - sm2_bn_set_one(out[i]); - - for (j = 0; j < total_cnt; j++) { - // Here i, j start from 0, so (i+1) and (j+1) is the needed value - // a = -(j + 1)/((i + 1) - (j + 1)) = -(j + 1)/(i - j), i != j - if (i < j) { - sm2_bn_set_word(a, j - i); - } else if (i > j) { - sm2_bn_set_word(a, i - j); - sm2_fn_neg(a, a); - } - sm2_fn_inv(a, a); - sm2_fn_mul_word(a, a, j + 1); - sm2_fn_mul(out[i], out[i], a); - } + if (!shares || !shares_cnt || !key) { + error_print(); + return -1; } - return 1; -} + k = shares_cnt; + n = shares[0].total_cnt; -int sm2_key_recover(const SM2_KEY_SHARE *shares, size_t shares_cnt, SM2_KEY *key) -{ - SM2_Fn a; - SM2_Fn s; - size_t i; - size_t total_cnt; - SM2_Fn *delta = NULL; - uint8_t a_bytes[32]; - - total_cnt = shares[0].total_cnt; - for (i = 1; i < shares_cnt; i++) { - if (shares[i].total_cnt != total_cnt - || shares[i].index > total_cnt) { + if (n > SM2_KEY_MAX_SHARES) { + error_print(); + return -1; + } + for (i = 0; i < k; i++) { + if (shares[i].total_cnt != n + || shares[i].index >= n) { error_print(); return -1; } } - if (!(delta = (SM2_Fn *)malloc(sizeof(SM2_Fn) * total_cnt))) { - error_print(); - return -1; - } - generate_delta_list(total_cnt, delta); + sm2_bn_set_zero(s); - sm2_bn_set_zero(a); + for (i = 0; i < k; i++) { + // delta_i + SM2_Fn d; + int num = 1; + int den = 1; + int sign = 1; - for (i = 0; i < shares_cnt; i++) { - const SM2_KEY *key = &shares[i].key; + x_i = (int)(shares[i].index + 1); - sm2_bn_from_bytes(s, key->private_key); - sm2_fn_mul(s, s, delta[shares[i].index]); - sm2_fn_add(a, a, s); + for (j = 0; j < k; j++) { + if (i != j) { + int x_j = (int)(shares[j].index + 1); + num *= -x_j; + den *= x_i - x_j; + } + } + if (num < 0) { + num = -num; + sign = -sign; + } + if (den < 0) { + den = -den; + sign = -sign; + } + + // delta_i = Fn( num / den ) + sm2_bn_set_word(d, den); + sm2_fn_inv(d, d); + sm2_fn_mul_word(d, d, num); + if (sign < 0) { + sm2_fn_neg(d, d); + } + + // s += delta_i * y_i + sm2_bn_from_bytes(y_i, shares[i].key.private_key); + if (sm2_bn_cmp(y_i, SM2_N) >= 0) { + gmssl_secure_clear(y_i, sizeof(y_i)); + gmssl_secure_clear(s, sizeof(s)); + error_print(); + return -1; + } + sm2_fn_mul(y_i, y_i, d); + sm2_fn_add(s, s, y_i); } - sm2_bn_to_bytes(a, a_bytes); - sm2_key_set_private_key(key, a_bytes); + sm2_bn_to_bytes(s, s_bytes); + sm2_key_set_private_key(key, s_bytes); - - memset(a, 0, sizeof(a)); - memset(a_bytes, 0, sizeof(a_bytes)); + gmssl_secure_clear(y_i, sizeof(y_i)); + gmssl_secure_clear(s, sizeof(s)); + gmssl_secure_clear(s_bytes, sizeof(s_bytes)); + return 1; +} + +int sm2_key_share_encrypt_to_file(const SM2_KEY_SHARE *share, const char *pass, const char *path_prefix) +{ + int ret; + char *path = NULL; + FILE *fp = NULL; + int len; + + if (!share || !pass || !path_prefix) { + error_print(); + return -1; + } + if (!share->total_cnt || share->total_cnt > 12 || share->index >= share->total_cnt) { + sm2_key_share_print(stderr, 0, 0, "share", share); + error_print(); + return -1; + } + if ((len = snprintf(NULL, 0, "%s-%zu-of-%zu.pem", path_prefix, share->index + 1, share->total_cnt)) <= 0) { + error_print(); + return -1; + } + if (!(path = malloc(len + 1))) { + error_print(); + return -1; + } + snprintf(path, len+1, "%s-%zu-of-%zu.pem", path_prefix, share->index + 1, share->total_cnt); + + + if (!(fp = fopen(path, "wb"))) { + error_print(); + goto end; + } + if (sm2_private_key_info_encrypt_to_pem(&share->key, pass, fp) != 1) { + error_print(); + goto end; + } + ret = 1; +end: + if (path) free(path); + if (fp) fclose(fp); + return ret; +} + +int sm2_key_share_decrypt_from_file(SM2_KEY_SHARE *share, const char *pass, const char *file) +{ + error_print(); + return -1; +} + +int test_sm2_key_share_args(size_t k, size_t n) +{ + SM2_KEY key; + SM2_KEY key_; + SM2_KEY_SHARE shares[SM2_KEY_MAX_SHARES]; + + if (sm2_key_generate(&key) != 1) { + error_print(); + return -1; + } + if (sm2_key_split(&key, k, n, shares) != 1) { + error_print(); + return -1; + } + + // recover from 0 .. k + if (sm2_key_recover(&key_, shares, k) != 1) { + error_print(); + return -1; + } + if (memcmp(&key_, &key, sizeof(SM2_KEY)) != 0) { + error_print(); + return -1; + } + + // recover from n-k .. n + memset(&key_, 0, sizeof(key_)); + if (sm2_key_recover(&key_, shares + n - k, k) != 1) { + error_print(); + return -1; + } + if (memcmp(&key_, &key, sizeof(SM2_KEY)) != 0) { + error_print(); + return -1; + } + return 1; +} + +int test_sm2_key_share(void) +{ + if (test_sm2_key_share_args(1, 1) != 1) { error_print(); return -1; } + if (test_sm2_key_share_args(1, 3) != 1) { error_print(); return -1; } + if (test_sm2_key_share_args(2, 3) != 1) { error_print(); return -1; } + if (test_sm2_key_share_args(3, 5) != 1) { error_print(); return -1; } + if (test_sm2_key_share_args(4, 5) != 1) { error_print(); return -1; } + if (test_sm2_key_share_args(5, 5) != 1) { error_print(); return -1; } + if (test_sm2_key_share_args(11, 12) != 1) { error_print(); return -1; } + if (test_sm2_key_share_args(12, 12) != 1) { error_print(); return -1; } + return 1; +} + +int test_sm2_key_share_file(void) +{ + SM2_KEY key; + SM2_KEY_SHARE shares[SM2_KEY_MAX_SHARES]; + + if (sm2_key_generate(&key) != 1) { + error_print(); + return -1; + } + if (sm2_key_split(&key, 2, 3, shares) != 1) { + error_print(); + return -1; + } + if (sm2_key_share_encrypt_to_file(&shares[0], "123456", "sm2key") != 1 + || sm2_key_share_encrypt_to_file(&shares[1], "123456", "sm2key") != 1 + || sm2_key_share_encrypt_to_file(&shares[2], "123456", "sm2key") != 1) { + error_print(); + return -1; + } return 1; } diff --git a/src/sm2_key_share.h b/src/sm2_key_share.h new file mode 100644 index 00000000..ef4b1582 --- /dev/null +++ b/src/sm2_key_share.h @@ -0,0 +1,29 @@ +/* + * Copyright 2014-2022 The GmSSL Project. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +#include +#include +#include +#include +#include +#include +#include + + +typedef struct { + SM2_KEY key; + size_t index; + size_t total_cnt; +} SM2_KEY_SHARE; + +int sm2_key_split(const SM2_KEY *key, size_t recover_cnt, size_t total_cnt, SM2_KEY_SHARE *shares) +int sm2_key_recover(SM2_KEY *key, const SM2_KEY_SHARE *shares, size_t shares_cnt) +int sm2_key_share_encrypt_to_file(const SM2_KEY_SHARE *share, const char *pass, const char *path_prefix) +int sm2_key_share_decrypt_from_file(SM2_KEY_SHARE *share, const char *pass, const char *file) +int sm2_key_share_print(FILE *fp, int fmt, int ind, const char *label, const SM2_KEY_SHARE *share)