Update SM2 key sharing

This commit is contained in:
Zhi Guan
2022-11-07 13:42:18 +08:00
parent 57a3c9682d
commit 13e6f5138f
4 changed files with 313 additions and 116 deletions

View File

@@ -76,7 +76,7 @@ void sm2_bn_to_bits(const SM2_BN a, char bits[256]);
void sm2_bn_set_word(SM2_BN r, uint32_t a);
void sm2_bn_add(SM2_BN r, const SM2_BN a, const SM2_BN b);
void sm2_bn_sub(SM2_BN ret, const SM2_BN a, const SM2_BN b);
void sm2_bn_rand_range(SM2_BN r, const SM2_BN range); // 这个函数需要修改一下,从外部引入随机数
int sm2_bn_rand_range(SM2_BN r, const SM2_BN range);
#define sm2_bn_init(r) memset((r),0,sizeof(SM2_BN))
#define sm2_bn_set_zero(r) memset((r),0,sizeof(SM2_BN))
@@ -98,7 +98,7 @@ void sm2_fp_div2(SM2_Fp r, const SM2_Fp a);
void sm2_fp_neg(SM2_Fp r, const SM2_Fp a);
void sm2_fp_sqr(SM2_Fp r, const SM2_Fp a);
void sm2_fp_inv(SM2_Fp r, const SM2_Fp a);
void sm2_fp_rand(SM2_Fp r); // 外部提供随机性,如果满足条件就输出,如果不满足条件就哈希一下再输出
int sm2_fp_rand(SM2_Fp r);
int sm2_fp_sqrt(SM2_Fp r, const SM2_Fp a);
@@ -114,11 +114,12 @@ typedef SM2_BN SM2_Fn;
void sm2_fn_add(SM2_Fn r, const SM2_Fn a, const SM2_Fn b);
void sm2_fn_sub(SM2_Fn r, const SM2_Fn a, const SM2_Fn b);
void sm2_fn_mul(SM2_Fn r, const SM2_Fn a, const SM2_Fn b);
void sm2_fn_mul_word(SM2_Fn r, const SM2_Fn a, uint32_t b);
void sm2_fn_exp(SM2_Fn r, const SM2_Fn a, const SM2_Fn e);
void sm2_fn_neg(SM2_Fn r, const SM2_Fn a);
void sm2_fn_sqr(SM2_Fn r, const SM2_Fn a);
void sm2_fn_inv(SM2_Fn r, const SM2_Fn a);
void sm2_fn_rand(SM2_Fn r);
int sm2_fn_rand(SM2_Fn r);
#define sm2_fn_init(r) sm2_bn_init(r)
#define sm2_fn_set_zero(r) sm2_bn_set_zero(r)
@@ -298,6 +299,19 @@ int sm2_private_key_info_decrypt_from_der(SM2_KEY *key, const uint8_t **attrs, s
int sm2_private_key_info_encrypt_to_pem(const SM2_KEY *key, const char *pass, FILE *fp);
int sm2_private_key_info_decrypt_from_pem(SM2_KEY *key, const char *pass, FILE *fp);
// SM2 Key Shamir Secret Sharing
typedef struct {
SM2_KEY key;
size_t index;
size_t total_cnt;
} SM2_KEY_SHARE;
int sm2_key_split(const SM2_KEY *key, size_t recover_cnt, size_t total_cnt, SM2_KEY_SHARE *shares);
int sm2_key_recover(SM2_KEY *key, const SM2_KEY_SHARE *shares, size_t shares_cnt);
int sm2_key_share_encrypt_to_file(const SM2_KEY_SHARE *share, const char *pass, const char *path_prefix);
int sm2_key_share_decrypt_from_file(SM2_KEY_SHARE *share, const char *pass, const char *file);
int sm2_key_share_print(FILE *fp, int fmt, int ind, const char *label, const SM2_KEY_SHARE *share);
typedef struct {
uint8_t r[32];

View File

@@ -317,14 +317,26 @@ void sm2_bn_sub(SM2_BN ret, const SM2_BN a, const SM2_BN b)
sm2_bn_copy(ret, r);
}
// FIXME: get random from outside
void sm2_bn_rand_range(SM2_BN r, const SM2_BN range)
int sm2_bn_rand_range(SM2_BN r, const SM2_BN range)
{
uint8_t buf[32];
do {
(void)rand_bytes(buf, sizeof(buf));
if (rand_bytes(buf, sizeof(buf)) != 1) {
error_print();
return -1;
}
sm2_bn_from_bytes(r, buf);
} while (sm2_bn_cmp(r, range) >= 0);
return 1;
}
int sm2_fp_rand(SM2_Fp r)
{
if (sm2_bn_rand_range(r, SM2_P) != 1) {
error_print();
return -1;
}
return 1;
}
void sm2_fp_add(SM2_Fp r, const SM2_Fp a, const SM2_Fp b)
@@ -618,21 +630,20 @@ static void sm2_bn288_sub(uint64_t ret[9], const uint64_t a[9], const uint64_t b
}
}
void sm2_fn_mul(SM2_BN r, const SM2_BN a, const SM2_BN b)
void sm2_fn_mul(SM2_BN ret, const SM2_BN a, const SM2_BN b)
{
static const uint64_t mu[8] = {
0xf15149a0, 0x12ac6361, 0xfa323c01, 0x8dfc2096,
1, 1, 1, 0x100000001,
SM2_BN r;
static const uint64_t mu[9] = {
0xf15149a0, 0x12ac6361, 0xfa323c01, 0x8dfc2096, 1, 1, 1, 1, 1,
};
uint64_t s[17];
uint64_t s[18];
uint64_t zh[9];
uint64_t zl[9];
uint64_t q[9];
uint64_t w;
int i, j;
/* z = a * b */
for (i = 0; i < 8; i++) {
s[i] = 0;
@@ -662,21 +673,20 @@ void sm2_fn_mul(SM2_BN r, const SM2_BN a, const SM2_BN b)
}
for (i = 0; i < 9; i++) {
w = 0;
for (j = 0; j < 8; j++) {
for (j = 0; j < 9; j++) {
w += s[i + j] + zh[i] * mu[j];
s[i + j] = w & 0xffffffff;
w >>= 32;
}
s[i + 8] = w;
s[i + 9] = w;
}
for (i = 0; i < 8; i++) {
q[i] = s[9 + i];
}
//printf("q = "); for (i = 7; i >= 0; i--) printf("%08x", (uint32_t)q[i]); printf("\n");
/* q = q * n mod (2^32)^9 */
for (i = 0; i < 8; i++) {
for (i = 0; i < 17; i++) {
s[i] = 0;
}
for (i = 0; i < 8; i++) {
@@ -691,7 +701,7 @@ void sm2_fn_mul(SM2_BN r, const SM2_BN a, const SM2_BN b)
for (i = 0; i < 9; i++) {
q[i] = s[i];
}
//printf("qn = "); for (i = 8; i >= 0; i--) printf("%08x", (uint32_t)q[i]); printf("\n");
//printf("qn = "); for (i = 8; i >= 0; i--) printf("%08x ", (uint32_t)q[i]); printf("\n");
/* r = zl - q (mod (2^32)^9) */
@@ -702,9 +712,7 @@ void sm2_fn_mul(SM2_BN r, const SM2_BN a, const SM2_BN b)
sm2_bn288_sub(q, c, q);
sm2_bn288_add(zl, q, zl);
}
//printf("r = "); for (i = 8; i >= 0; i--) printf("%08x", (uint32_t)zl[i]); printf("\n");
//printf("zl = "); for (i = 8; i >= 0; i--) printf("%08x ", (uint32_t)zl[i]); printf("\n");
for (i = 0; i < 8; i++) {
r[i] = zl[i];
}
@@ -713,8 +721,16 @@ void sm2_fn_mul(SM2_BN r, const SM2_BN a, const SM2_BN b)
/* while r >= p do: r = r - n */
while (sm2_bn_cmp(r, SM2_N) >= 0) {
sm2_bn_sub(r, r, SM2_N);
//printf("r = r -n = "); for (i = 8; i >= 0; i--) printf("%08x", (uint32_t)zl[i]); printf("\n");
//printf("r-n = "); for (i = 7; i >= 0; i--) printf("%16llx ", r[i]); printf("\n");
}
sm2_bn_copy(ret, r);
}
void sm2_fn_mul_word(SM2_Fn r, const SM2_Fn a, uint32_t b)
{
SM2_Fn t;
sm2_bn_set_word(t, b);
sm2_fn_mul(r, a, t);
}
void sm2_fn_sqr(SM2_BN r, const SM2_BN a)
@@ -739,7 +755,6 @@ void sm2_fn_exp(SM2_BN r, const SM2_BN a, const SM2_BN e)
w <<= 1;
}
}
sm2_bn_copy(r, t);
}
@@ -750,9 +765,13 @@ void sm2_fn_inv(SM2_BN r, const SM2_BN a)
sm2_fn_exp(r, a, e);
}
void sm2_fn_rand(SM2_BN r)
int sm2_fn_rand(SM2_BN r)
{
sm2_bn_rand_range(r, SM2_N);
if (sm2_bn_rand_range(r, SM2_N) != 1) {
error_print();
return -1;
}
return 1;
}

View File

@@ -12,144 +12,279 @@
#include <stdlib.h>
#include <stdint.h>
#include <gmssl/sm2.h>
#include <gmssl/mem.h>
#include <gmssl/error.h>
typedef struct {
SM2_KEY key;
unsigned int index;
unsigned int total_cnt;
} SM2_KEY_SHARE;
static int sm2_fn_mul_word(SM2_Fn r, const SM2_Fn a, uint32_t b)
int sm2_key_share_print(FILE *fp, int fmt, int ind, const char *label, const SM2_KEY_SHARE *share)
{
SM2_Fn t;
sm2_bn_set_word(t, b);
sm2_fn_mul(r, a, t);
format_print(fp, fmt, ind, "%s\n", label);
ind += 4;
format_print(fp, fmt, ind, "%zu/%zu\n", share->index, share->total_cnt);
format_print(fp, fmt, ind, "key", &share->key);
return 1;
}
static int eval_univariate_poly(const SM2_Fn *coeffs, size_t coeffs_cnt, unsigned int x, SM2_Fn out)
{
sm2_bn_set_zero(out);
// y = f(x)
static void eval_univariate_poly(SM2_Fn y, const SM2_Fn *coeffs, size_t coeffs_cnt, uint32_t x)
{
sm2_bn_set_zero(y);
while (coeffs_cnt--) {
sm2_fn_mul_word(out, out, x);
sm2_fn_add(out, out, coeffs[coeffs_cnt]);
sm2_fn_mul_word(y, y, x);
sm2_fn_add(y, y, coeffs[coeffs_cnt]);
}
return 1;
}
#define SM2_KEY_MAX_SHARES 12 // 12! = 479001600 < 2^31 = 2147483648
int sm2_key_split(const SM2_KEY *key, size_t recover_cnt, size_t total_cnt, SM2_KEY_SHARE *shares)
{
SM2_Fn *coeffs = NULL;
size_t coeffs_cnt = recover_cnt;
size_t x;
size_t i;
SM2_Fn coeffs[SM2_KEY_MAX_SHARES];
SM2_Fn y;
uint8_t y_bytes[32];
size_t i;
// f(x) = a_0 + a_1 * x + ... + a_(k-1) * x^(k-1)
// a_0 = private_key, a_i = rand(1, n-1)
if (!(coeffs = (SM2_Fn *)malloc(sizeof(SM2_Fn) * coeffs_cnt))) {
if (!key || !shares) {
error_print();
return -1;
}
sm2_bn_from_bytes(coeffs[0], key->private_key);
if (!total_cnt || total_cnt > SM2_KEY_MAX_SHARES) {
error_print();
return -1;
}
if (!recover_cnt || recover_cnt > total_cnt) {
error_print();
return -1;
}
// try to access mem
memset(shares, 0, sizeof(SM2_KEY_SHARE) * total_cnt);
for (i = 1; i < recover_cnt; i++) {
sm2_fn_rand(coeffs[i]); // FIXME: check return value
if (sm2_fn_rand(coeffs[i]) != 1) {
error_print();
return -1;
}
}
sm2_bn_from_bytes(coeffs[0], key->private_key);
for (x = 1; x <= total_cnt; x++) {
SM2_KEY *key = &(shares[i].key);
// y = f(x)
eval_univariate_poly(coeffs, coeffs_cnt, x, y);
for (i = 0; i < total_cnt; i++) {
uint32_t x = (uint32_t)(i + 1);
eval_univariate_poly(y, coeffs, recover_cnt, x);
sm2_bn_to_bytes(y, y_bytes);
sm2_key_set_private_key(key, y_bytes);
shares[i].index = x - 1;
sm2_key_set_private_key(&(shares[i].key), y_bytes);
shares[i].index = i;
shares[i].total_cnt = total_cnt;
}
memset(y, 0, sizeof(SM2_Fn));
memset(y_bytes, 0, sizeof(y_bytes));
memset(coeffs, 0, sizeof(SM2_Fn) * coeffs_cnt);
free(coeffs);
gmssl_secure_clear(coeffs, sizeof(coeffs));
gmssl_secure_clear(y, sizeof(y));
gmssl_secure_clear(y_bytes, sizeof(y_bytes));
return 1;
}
// n is total_cnt, out is delta[] array
// for i=1..n, delta[i] = prod(-j/(i - j)) in GF(N), j = 1..n, j != i
int generate_delta_list(size_t total_cnt, SM2_Fn *out)
int sm2_key_recover(SM2_KEY *key, const SM2_KEY_SHARE *shares, size_t shares_cnt)
{
SM2_Fn a;
size_t i, j;
SM2_Fn s;
uint8_t s_bytes[32];
int x_i;
SM2_Fn y_i;
size_t i, j, k, n;
for (i = 0; i < total_cnt; i++) {
sm2_bn_set_one(out[i]);
for (j = 0; j < total_cnt; j++) {
// Here i, j start from 0, so (i+1) and (j+1) is the needed value
// a = -(j + 1)/((i + 1) - (j + 1)) = -(j + 1)/(i - j), i != j
if (i < j) {
sm2_bn_set_word(a, j - i);
} else if (i > j) {
sm2_bn_set_word(a, i - j);
sm2_fn_neg(a, a);
}
sm2_fn_inv(a, a);
sm2_fn_mul_word(a, a, j + 1);
sm2_fn_mul(out[i], out[i], a);
}
if (!shares || !shares_cnt || !key) {
error_print();
return -1;
}
return 1;
}
k = shares_cnt;
n = shares[0].total_cnt;
int sm2_key_recover(const SM2_KEY_SHARE *shares, size_t shares_cnt, SM2_KEY *key)
{
SM2_Fn a;
SM2_Fn s;
size_t i;
size_t total_cnt;
SM2_Fn *delta = NULL;
uint8_t a_bytes[32];
total_cnt = shares[0].total_cnt;
for (i = 1; i < shares_cnt; i++) {
if (shares[i].total_cnt != total_cnt
|| shares[i].index > total_cnt) {
if (n > SM2_KEY_MAX_SHARES) {
error_print();
return -1;
}
for (i = 0; i < k; i++) {
if (shares[i].total_cnt != n
|| shares[i].index >= n) {
error_print();
return -1;
}
}
if (!(delta = (SM2_Fn *)malloc(sizeof(SM2_Fn) * total_cnt))) {
error_print();
return -1;
}
generate_delta_list(total_cnt, delta);
sm2_bn_set_zero(s);
sm2_bn_set_zero(a);
for (i = 0; i < k; i++) {
// delta_i
SM2_Fn d;
int num = 1;
int den = 1;
int sign = 1;
for (i = 0; i < shares_cnt; i++) {
const SM2_KEY *key = &shares[i].key;
x_i = (int)(shares[i].index + 1);
sm2_bn_from_bytes(s, key->private_key);
sm2_fn_mul(s, s, delta[shares[i].index]);
sm2_fn_add(a, a, s);
for (j = 0; j < k; j++) {
if (i != j) {
int x_j = (int)(shares[j].index + 1);
num *= -x_j;
den *= x_i - x_j;
}
}
if (num < 0) {
num = -num;
sign = -sign;
}
if (den < 0) {
den = -den;
sign = -sign;
}
// delta_i = Fn( num / den )
sm2_bn_set_word(d, den);
sm2_fn_inv(d, d);
sm2_fn_mul_word(d, d, num);
if (sign < 0) {
sm2_fn_neg(d, d);
}
// s += delta_i * y_i
sm2_bn_from_bytes(y_i, shares[i].key.private_key);
if (sm2_bn_cmp(y_i, SM2_N) >= 0) {
gmssl_secure_clear(y_i, sizeof(y_i));
gmssl_secure_clear(s, sizeof(s));
error_print();
return -1;
}
sm2_fn_mul(y_i, y_i, d);
sm2_fn_add(s, s, y_i);
}
sm2_bn_to_bytes(a, a_bytes);
sm2_key_set_private_key(key, a_bytes);
sm2_bn_to_bytes(s, s_bytes);
sm2_key_set_private_key(key, s_bytes);
memset(a, 0, sizeof(a));
memset(a_bytes, 0, sizeof(a_bytes));
gmssl_secure_clear(y_i, sizeof(y_i));
gmssl_secure_clear(s, sizeof(s));
gmssl_secure_clear(s_bytes, sizeof(s_bytes));
return 1;
}
int sm2_key_share_encrypt_to_file(const SM2_KEY_SHARE *share, const char *pass, const char *path_prefix)
{
int ret;
char *path = NULL;
FILE *fp = NULL;
int len;
if (!share || !pass || !path_prefix) {
error_print();
return -1;
}
if (!share->total_cnt || share->total_cnt > 12 || share->index >= share->total_cnt) {
sm2_key_share_print(stderr, 0, 0, "share", share);
error_print();
return -1;
}
if ((len = snprintf(NULL, 0, "%s-%zu-of-%zu.pem", path_prefix, share->index + 1, share->total_cnt)) <= 0) {
error_print();
return -1;
}
if (!(path = malloc(len + 1))) {
error_print();
return -1;
}
snprintf(path, len+1, "%s-%zu-of-%zu.pem", path_prefix, share->index + 1, share->total_cnt);
if (!(fp = fopen(path, "wb"))) {
error_print();
goto end;
}
if (sm2_private_key_info_encrypt_to_pem(&share->key, pass, fp) != 1) {
error_print();
goto end;
}
ret = 1;
end:
if (path) free(path);
if (fp) fclose(fp);
return ret;
}
int sm2_key_share_decrypt_from_file(SM2_KEY_SHARE *share, const char *pass, const char *file)
{
error_print();
return -1;
}
int test_sm2_key_share_args(size_t k, size_t n)
{
SM2_KEY key;
SM2_KEY key_;
SM2_KEY_SHARE shares[SM2_KEY_MAX_SHARES];
if (sm2_key_generate(&key) != 1) {
error_print();
return -1;
}
if (sm2_key_split(&key, k, n, shares) != 1) {
error_print();
return -1;
}
// recover from 0 .. k
if (sm2_key_recover(&key_, shares, k) != 1) {
error_print();
return -1;
}
if (memcmp(&key_, &key, sizeof(SM2_KEY)) != 0) {
error_print();
return -1;
}
// recover from n-k .. n
memset(&key_, 0, sizeof(key_));
if (sm2_key_recover(&key_, shares + n - k, k) != 1) {
error_print();
return -1;
}
if (memcmp(&key_, &key, sizeof(SM2_KEY)) != 0) {
error_print();
return -1;
}
return 1;
}
int test_sm2_key_share(void)
{
if (test_sm2_key_share_args(1, 1) != 1) { error_print(); return -1; }
if (test_sm2_key_share_args(1, 3) != 1) { error_print(); return -1; }
if (test_sm2_key_share_args(2, 3) != 1) { error_print(); return -1; }
if (test_sm2_key_share_args(3, 5) != 1) { error_print(); return -1; }
if (test_sm2_key_share_args(4, 5) != 1) { error_print(); return -1; }
if (test_sm2_key_share_args(5, 5) != 1) { error_print(); return -1; }
if (test_sm2_key_share_args(11, 12) != 1) { error_print(); return -1; }
if (test_sm2_key_share_args(12, 12) != 1) { error_print(); return -1; }
return 1;
}
int test_sm2_key_share_file(void)
{
SM2_KEY key;
SM2_KEY_SHARE shares[SM2_KEY_MAX_SHARES];
if (sm2_key_generate(&key) != 1) {
error_print();
return -1;
}
if (sm2_key_split(&key, 2, 3, shares) != 1) {
error_print();
return -1;
}
if (sm2_key_share_encrypt_to_file(&shares[0], "123456", "sm2key") != 1
|| sm2_key_share_encrypt_to_file(&shares[1], "123456", "sm2key") != 1
|| sm2_key_share_encrypt_to_file(&shares[2], "123456", "sm2key") != 1) {
error_print();
return -1;
}
return 1;
}

29
src/sm2_key_share.h Normal file
View File

@@ -0,0 +1,29 @@
/*
* Copyright 2014-2022 The GmSSL Project. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <stdint.h>
#include <gmssl/sm2.h>
#include <gmssl/mem.h>
#include <gmssl/error.h>
typedef struct {
SM2_KEY key;
size_t index;
size_t total_cnt;
} SM2_KEY_SHARE;
int sm2_key_split(const SM2_KEY *key, size_t recover_cnt, size_t total_cnt, SM2_KEY_SHARE *shares)
int sm2_key_recover(SM2_KEY *key, const SM2_KEY_SHARE *shares, size_t shares_cnt)
int sm2_key_share_encrypt_to_file(const SM2_KEY_SHARE *share, const char *pass, const char *path_prefix)
int sm2_key_share_decrypt_from_file(SM2_KEY_SHARE *share, const char *pass, const char *file)
int sm2_key_share_print(FILE *fp, int fmt, int ind, const char *label, const SM2_KEY_SHARE *share)