diff --git a/CMakeLists.txt b/CMakeLists.txt index 8318180a..bbc2f0e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,8 +18,11 @@ set(src src/sm3_hmac.c src/sm3_kdf.c src/sm3_digest.c - src/sm2_alg.c - src/sm2_key.c + #src/sm2_alg.c + src/sm2_point.c + src/sm2_z256.c + src/sm2_z256_table.c + src/sm2_z256_key.c src/sm2_z256_sign.c src/sm2_lib.c src/sm2_ctx.c @@ -122,7 +125,8 @@ set(tools set(tests sm4 sm3 - sm2 +# sm2 + sm2_z256 sm2_sign sm2_enc sm9 @@ -238,6 +242,15 @@ if (ENABLE_SM2_ALGOR_ID_ENCODE_NULL) endif() +option(ENABLE_SM2_Z256_ARMV8 "Enable SM2_Z256 ARMv8 assembly" OFF) +if (ENABLE_SM2_Z256_ARMV8) + message(STATUS "ENABLE_SM2_Z256_ARMV8 is ON") + add_definitions(-DENABLE_SM2_Z256_ARMV8) + enable_language(ASM) + list(APPEND src src/sm2_z256_armv8.S) +endif() + + option(ENABLE_SM2_PRIVATE_KEY_EXPORT "Enable export un-encrypted SM2 private key" OFF) if (ENABLE_SM2_PRIVATE_KEY_EXPORT) message(STATUS "ENABLE_SM2_PRIVATE_KEY_EXPORT is ON") @@ -316,15 +329,6 @@ if (ENABLE_SM4_XTS) endif() -option(ENABLE_SM2_Z256 "Enable SM2 z256 implementation" OFF) -if (ENABLE_SM2_Z256) - message(STATUS "ENABLE_SM2_Z256 is ON") - add_definitions(-DENABLE_SM2_Z256) - list(APPEND src src/sm2_z256.c src/sm2_z256_table.c) - list(APPEND tests sm2_z256) -endif() - - option(ENABLE_SM2_EXTS "Enable SM2 Extensions" OFF) if (ENABLE_SM2_EXTS) message(STATUS "ENABLE_SM4_AESNI_AVX") diff --git a/include/gmssl/sm2.h b/include/gmssl/sm2.h index 6edd229b..b5c28f8c 100644 --- a/include/gmssl/sm2.h +++ b/include/gmssl/sm2.h @@ -15,113 +15,15 @@ #include #include #include -#include #include +#include +#include #ifdef __cplusplus extern "C" { #endif -typedef uint64_t SM2_BN[8]; - -int sm2_bn_is_zero(const SM2_BN a); -int sm2_bn_is_one(const SM2_BN a); -int sm2_bn_is_odd(const SM2_BN a); -int sm2_bn_cmp(const SM2_BN a, const SM2_BN b); -int sm2_bn_from_hex(SM2_BN r, const char hex[64]); -int sm2_bn_from_asn1_integer(SM2_BN r, const uint8_t *d, size_t dlen); -int sm2_bn_equ_hex(const SM2_BN a, const char *hex); -int sm2_bn_print(FILE *fp, int fmt, int ind, const char *label, const SM2_BN a); -int sm2_bn_rshift(SM2_BN ret, const SM2_BN a, unsigned int nbits); - -void sm2_bn_to_bytes(const SM2_BN a, uint8_t out[32]); -void sm2_bn_from_bytes(SM2_BN r, const uint8_t in[32]); -void sm2_bn_to_hex(const SM2_BN a, char hex[64]); -void sm2_bn_to_bits(const SM2_BN a, char bits[256]); -void sm2_bn_set_word(SM2_BN r, uint32_t a); -void sm2_bn_add(SM2_BN r, const SM2_BN a, const SM2_BN b); -void sm2_bn_sub(SM2_BN ret, const SM2_BN a, const SM2_BN b); -int sm2_bn_rand_range(SM2_BN r, const SM2_BN range); - -#define sm2_bn_init(r) memset((r),0,sizeof(SM2_BN)) -#define sm2_bn_set_zero(r) memset((r),0,sizeof(SM2_BN)) -#define sm2_bn_set_one(r) sm2_bn_set_word((r),1) -#define sm2_bn_copy(r,a) memcpy((r),(a),sizeof(SM2_BN)) -#define sm2_bn_clean(r) memset((r),0,sizeof(SM2_BN)) - - -// GF(p) -typedef SM2_BN SM2_Fp; - -void sm2_fp_add(SM2_Fp r, const SM2_Fp a, const SM2_Fp b); -void sm2_fp_sub(SM2_Fp r, const SM2_Fp a, const SM2_Fp b); -void sm2_fp_mul(SM2_Fp r, const SM2_Fp a, const SM2_Fp b); -void sm2_fp_exp(SM2_Fp r, const SM2_Fp a, const SM2_Fp e); -void sm2_fp_dbl(SM2_Fp r, const SM2_Fp a); -void sm2_fp_tri(SM2_Fp r, const SM2_Fp a); -void sm2_fp_div2(SM2_Fp r, const SM2_Fp a); -void sm2_fp_neg(SM2_Fp r, const SM2_Fp a); -void sm2_fp_sqr(SM2_Fp r, const SM2_Fp a); -void sm2_fp_inv(SM2_Fp r, const SM2_Fp a); -int sm2_fp_rand(SM2_Fp r); - -int sm2_fp_sqrt(SM2_Fp r, const SM2_Fp a); - -#define sm2_fp_init(r) sm2_bn_init(r) -#define sm2_fp_set_zero(r) sm2_bn_set_zero(r) -#define sm2_fp_set_one(r) sm2_bn_set_one(r) -#define sm2_fp_copy(r,a) sm2_bn_copy(r,a) -#define sm2_fp_clean(r) sm2_bn_clean(r) - -// GF(n) -typedef SM2_BN SM2_Fn; - -void sm2_fn_add(SM2_Fn r, const SM2_Fn a, const SM2_Fn b); -void sm2_fn_sub(SM2_Fn r, const SM2_Fn a, const SM2_Fn b); -void sm2_fn_mul(SM2_Fn r, const SM2_Fn a, const SM2_Fn b); -void sm2_fn_mul_word(SM2_Fn r, const SM2_Fn a, uint32_t b); -void sm2_fn_exp(SM2_Fn r, const SM2_Fn a, const SM2_Fn e); -void sm2_fn_neg(SM2_Fn r, const SM2_Fn a); -void sm2_fn_sqr(SM2_Fn r, const SM2_Fn a); -void sm2_fn_inv(SM2_Fn r, const SM2_Fn a); -int sm2_fn_rand(SM2_Fn r); - -#define sm2_fn_init(r) sm2_bn_init(r) -#define sm2_fn_set_zero(r) sm2_bn_set_zero(r) -#define sm2_fn_set_one(r) sm2_bn_set_one(r) -#define sm2_fn_copy(r,a) sm2_bn_copy(r,a) -#define sm2_fn_clean(r) sm2_bn_clean(r) - - -typedef struct { - SM2_BN X; - SM2_BN Y; - SM2_BN Z; -} SM2_JACOBIAN_POINT; - -void sm2_jacobian_point_init(SM2_JACOBIAN_POINT *R); -void sm2_jacobian_point_set_xy(SM2_JACOBIAN_POINT *R, const SM2_BN x, const SM2_BN y); -void sm2_jacobian_point_get_xy(const SM2_JACOBIAN_POINT *P, SM2_BN x, SM2_BN y); -void sm2_jacobian_point_neg(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P); -void sm2_jacobian_point_dbl(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P); -void sm2_jacobian_point_add(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P, const SM2_JACOBIAN_POINT *Q); -void sm2_jacobian_point_sub(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P, const SM2_JACOBIAN_POINT *Q); -void sm2_jacobian_point_mul(SM2_JACOBIAN_POINT *R, const SM2_BN k, const SM2_JACOBIAN_POINT *P); -void sm2_jacobian_point_to_bytes(const SM2_JACOBIAN_POINT *P, uint8_t out[64]); -void sm2_jacobian_point_from_bytes(SM2_JACOBIAN_POINT *P, const uint8_t in[64]); -void sm2_jacobian_point_mul_generator(SM2_JACOBIAN_POINT *R, const SM2_BN k); -void sm2_jacobian_point_mul_sum(SM2_JACOBIAN_POINT *R, const SM2_BN t, const SM2_JACOBIAN_POINT *P, const SM2_BN s); -void sm2_jacobian_point_from_hex(SM2_JACOBIAN_POINT *P, const char hex[64 * 2]); // for testing only - -int sm2_jacobian_point_is_at_infinity(const SM2_JACOBIAN_POINT *P); -int sm2_jacobian_point_is_on_curve(const SM2_JACOBIAN_POINT *P); -int sm2_jacobian_point_equ_hex(const SM2_JACOBIAN_POINT *P, const char hex[128]); // for testing only -int sm2_jacobian_point_print(FILE *fp, int fmt, int ind, const char *label, const SM2_JACOBIAN_POINT *P); - -#define sm2_jacobian_point_set_infinity(R) sm2_jacobian_point_init(R) -#define sm2_jacobian_point_copy(R, P) memcpy((R), (P), sizeof(SM2_JACOBIAN_POINT)) - typedef uint8_t sm2_bn_t[32]; typedef struct { @@ -131,6 +33,8 @@ typedef struct { #define sm2_point_init(P) memset((P),0,sizeof(SM2_POINT)) #define sm2_point_set_infinity(P) sm2_point_init(P) + + int sm2_point_from_octets(SM2_POINT *P, const uint8_t *in, size_t inlen); void sm2_point_to_compressed_octets(const SM2_POINT *P, uint8_t out[33]); void sm2_point_to_uncompressed_octets(const SM2_POINT *P, uint8_t out[65]); @@ -147,6 +51,7 @@ int sm2_point_mul(SM2_POINT *R, const uint8_t k[32], const SM2_POINT *P); int sm2_point_mul_generator(SM2_POINT *R, const uint8_t k[32]); int sm2_point_mul_sum(SM2_POINT *R, const uint8_t k[32], const SM2_POINT *P, const uint8_t s[32]); // R = k * P + s * G + /* RFC 5480 Elliptic Curve Cryptography Subject Public Key Information ECPoint ::= OCTET STRING @@ -163,7 +68,6 @@ typedef struct { uint8_t private_key[32]; } SM2_KEY; - _gmssl_export int sm2_key_generate(SM2_KEY *key); int sm2_key_set_private_key(SM2_KEY *key, const uint8_t private_key[32]); // key->public_key will be replaced int sm2_key_set_public_key(SM2_KEY *key, const SM2_POINT *public_key); // key->private_key will be cleared // FIXME: support octets as input? @@ -174,6 +78,7 @@ int sm2_public_key_equ(const SM2_KEY *sm2_key, const SM2_KEY *pub_key); int sm2_public_key_digest(const SM2_KEY *key, uint8_t dgst[32]); int sm2_public_key_print(FILE *fp, int fmt, int ind, const char *label, const SM2_KEY *pub_key); + /* from RFC 5915 @@ -258,9 +163,14 @@ typedef struct { } SM2_SIGNATURE; int sm2_do_sign(const SM2_KEY *key, const uint8_t dgst[32], SM2_SIGNATURE *sig); -int sm2_do_sign_fast(const SM2_Fn d, const uint8_t dgst[32], SM2_SIGNATURE *sig); +int sm2_do_sign_fast(const uint64_t d[4], const uint8_t dgst[32], SM2_SIGNATURE *sig); int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATURE *sig); +int sm2_do_sign_pre_compute(uint64_t k[4], uint64_t x1[4]); + +int sm2_do_sign_fast_ex(const uint64_t d[4], const uint64_t k[4], const uint64_t x1[4], const uint8_t dgst[32], SM2_SIGNATURE *sig); +int sm2_do_verify_fast(const SM2_Z256_POINT *P, const uint8_t dgst[32], const SM2_SIGNATURE *sig); + #define SM2_MIN_SIGNATURE_SIZE 8 #define SM2_MAX_SIGNATURE_SIZE 72 @@ -277,6 +187,8 @@ enum { }; int sm2_sign_fixlen(const SM2_KEY *key, const uint8_t dgst[32], size_t siglen, uint8_t *sig); + + #define SM2_DEFAULT_ID "1234567812345678" #define SM2_DEFAULT_ID_LENGTH (sizeof(SM2_DEFAULT_ID) - 1) // LENGTH for string and SIZE for bytes #define SM2_DEFAULT_ID_BITS (SM2_DEFAULT_ID_LENGTH * 8) @@ -286,9 +198,23 @@ int sm2_sign_fixlen(const SM2_KEY *key, const uint8_t dgst[32], size_t siglen, u int sm2_compute_z(uint8_t z[32], const SM2_POINT *pub, const char *id, size_t idlen); +typedef struct { + uint64_t k[4]; + uint64_t x1[4]; +} SM2_SIGN_PRE_COMP; + typedef struct { SM3_CTX sm3_ctx; SM2_KEY key; + // FIXME: change `key` to SM2_Z256_POINT and uint64_t[4], inner type, faster sign/verify + + SM2_Z256_POINT public_key; // z256 only + uint64_t sign_key[8]; // u64[8] to support SM2_BN + SM3_CTX inited_sm3_ctx; + + SM2_SIGN_PRE_COMP pre_comp[32]; + unsigned int num_pre_comp; + } SM2_SIGN_CTX; _gmssl_export int sm2_sign_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_t idlen); @@ -300,6 +226,8 @@ _gmssl_export int sm2_verify_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const c _gmssl_export int sm2_verify_update(SM2_SIGN_CTX *ctx, const uint8_t *data, size_t datalen); _gmssl_export int sm2_verify_finish(SM2_SIGN_CTX *ctx, const uint8_t *sig, size_t siglen); +_gmssl_export int sm2_sign_ctx_reset(SM2_SIGN_CTX *ctx); + /* SM2Cipher ::= SEQUENCE { XCoordinate INTEGER, @@ -356,10 +284,6 @@ _gmssl_export int sm2_decrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key); _gmssl_export int sm2_decrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); _gmssl_export int sm2_decrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); -const uint64_t *sm2_bn_prime(void); -const uint64_t *sm2_bn_order(void); -const uint64_t *sm2_bn_one(void); - #ifdef __cplusplus } #endif diff --git a/include/gmssl/sm2_p256.h b/include/gmssl/sm2_p256.h new file mode 100644 index 00000000..3508215e --- /dev/null +++ b/include/gmssl/sm2_p256.h @@ -0,0 +1,131 @@ +/* + * Copyright 2014-2024 The GmSSL Project. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ + + + +#ifndef GMSSL_SM2_P256_H +#define GMSSL_SM2_P256_H + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef uint64_t SM2_BN[8]; + +int sm2_bn_is_zero(const SM2_BN a); +int sm2_bn_is_one(const SM2_BN a); +int sm2_bn_is_odd(const SM2_BN a); +int sm2_bn_cmp(const SM2_BN a, const SM2_BN b); +int sm2_bn_from_hex(SM2_BN r, const char hex[64]); +int sm2_bn_from_asn1_integer(SM2_BN r, const uint8_t *d, size_t dlen); +int sm2_bn_equ_hex(const SM2_BN a, const char *hex); +int sm2_bn_print(FILE *fp, int fmt, int ind, const char *label, const SM2_BN a); +int sm2_bn_rshift(SM2_BN ret, const SM2_BN a, unsigned int nbits); + +void sm2_bn_to_bytes(const SM2_BN a, uint8_t out[32]); +void sm2_bn_from_bytes(SM2_BN r, const uint8_t in[32]); +void sm2_bn_to_hex(const SM2_BN a, char hex[64]); +void sm2_bn_to_bits(const SM2_BN a, char bits[256]); +void sm2_bn_set_word(SM2_BN r, uint32_t a); +void sm2_bn_add(SM2_BN r, const SM2_BN a, const SM2_BN b); +void sm2_bn_sub(SM2_BN ret, const SM2_BN a, const SM2_BN b); +int sm2_bn_rand_range(SM2_BN r, const SM2_BN range); + +#define sm2_bn_init(r) memset((r),0,sizeof(SM2_BN)) +#define sm2_bn_set_zero(r) memset((r),0,sizeof(SM2_BN)) +#define sm2_bn_set_one(r) sm2_bn_set_word((r),1) +#define sm2_bn_copy(r,a) memcpy((r),(a),sizeof(SM2_BN)) +#define sm2_bn_clean(r) memset((r),0,sizeof(SM2_BN)) + + +// GF(p) +typedef SM2_BN SM2_Fp; + +void sm2_fp_add(SM2_Fp r, const SM2_Fp a, const SM2_Fp b); +void sm2_fp_sub(SM2_Fp r, const SM2_Fp a, const SM2_Fp b); +void sm2_fp_mul(SM2_Fp r, const SM2_Fp a, const SM2_Fp b); +void sm2_fp_exp(SM2_Fp r, const SM2_Fp a, const SM2_Fp e); +void sm2_fp_dbl(SM2_Fp r, const SM2_Fp a); +void sm2_fp_tri(SM2_Fp r, const SM2_Fp a); +void sm2_fp_div2(SM2_Fp r, const SM2_Fp a); +void sm2_fp_neg(SM2_Fp r, const SM2_Fp a); +void sm2_fp_sqr(SM2_Fp r, const SM2_Fp a); +void sm2_fp_inv(SM2_Fp r, const SM2_Fp a); +int sm2_fp_rand(SM2_Fp r); + +int sm2_fp_sqrt(SM2_Fp r, const SM2_Fp a); + +#define sm2_fp_init(r) sm2_bn_init(r) +#define sm2_fp_set_zero(r) sm2_bn_set_zero(r) +#define sm2_fp_set_one(r) sm2_bn_set_one(r) +#define sm2_fp_copy(r,a) sm2_bn_copy(r,a) +#define sm2_fp_clean(r) sm2_bn_clean(r) + +// GF(n) +typedef SM2_BN SM2_Fn; + +void sm2_fn_add(SM2_Fn r, const SM2_Fn a, const SM2_Fn b); +void sm2_fn_sub(SM2_Fn r, const SM2_Fn a, const SM2_Fn b); +void sm2_fn_mul(SM2_Fn r, const SM2_Fn a, const SM2_Fn b); +void sm2_fn_mul_word(SM2_Fn r, const SM2_Fn a, uint32_t b); +void sm2_fn_exp(SM2_Fn r, const SM2_Fn a, const SM2_Fn e); +void sm2_fn_neg(SM2_Fn r, const SM2_Fn a); +void sm2_fn_sqr(SM2_Fn r, const SM2_Fn a); +void sm2_fn_inv(SM2_Fn r, const SM2_Fn a); +int sm2_fn_rand(SM2_Fn r); + +#define sm2_fn_init(r) sm2_bn_init(r) +#define sm2_fn_set_zero(r) sm2_bn_set_zero(r) +#define sm2_fn_set_one(r) sm2_bn_set_one(r) +#define sm2_fn_copy(r,a) sm2_bn_copy(r,a) +#define sm2_fn_clean(r) sm2_bn_clean(r) + + +typedef struct { + SM2_BN X; + SM2_BN Y; + SM2_BN Z; +} SM2_JACOBIAN_POINT; + +void sm2_jacobian_point_init(SM2_JACOBIAN_POINT *R); +void sm2_jacobian_point_set_xy(SM2_JACOBIAN_POINT *R, const SM2_BN x, const SM2_BN y); +void sm2_jacobian_point_get_xy(const SM2_JACOBIAN_POINT *P, SM2_BN x, SM2_BN y); +void sm2_jacobian_point_neg(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P); +void sm2_jacobian_point_dbl(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P); +void sm2_jacobian_point_add(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P, const SM2_JACOBIAN_POINT *Q); +void sm2_jacobian_point_sub(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P, const SM2_JACOBIAN_POINT *Q); +void sm2_jacobian_point_mul(SM2_JACOBIAN_POINT *R, const SM2_BN k, const SM2_JACOBIAN_POINT *P); +void sm2_jacobian_point_to_bytes(const SM2_JACOBIAN_POINT *P, uint8_t out[64]); +void sm2_jacobian_point_from_bytes(SM2_JACOBIAN_POINT *P, const uint8_t in[64]); +void sm2_jacobian_point_mul_generator(SM2_JACOBIAN_POINT *R, const SM2_BN k); +void sm2_jacobian_point_mul_sum(SM2_JACOBIAN_POINT *R, const SM2_BN t, const SM2_JACOBIAN_POINT *P, const SM2_BN s); +void sm2_jacobian_point_from_hex(SM2_JACOBIAN_POINT *P, const char hex[64 * 2]); // for testing only + +int sm2_jacobian_point_is_at_infinity(const SM2_JACOBIAN_POINT *P); +int sm2_jacobian_point_is_on_curve(const SM2_JACOBIAN_POINT *P); +int sm2_jacobian_point_equ_hex(const SM2_JACOBIAN_POINT *P, const char hex[128]); // for testing only +int sm2_jacobian_point_print(FILE *fp, int fmt, int ind, const char *label, const SM2_JACOBIAN_POINT *P); + +#define sm2_jacobian_point_set_infinity(R) sm2_jacobian_point_init(R) +#define sm2_jacobian_point_copy(R, P) memcpy((R), (P), sizeof(SM2_JACOBIAN_POINT)) + +const uint64_t *sm2_bn_prime(void); +const uint64_t *sm2_bn_order(void); +const uint64_t *sm2_bn_one(void); + + +#ifdef __cplusplus +} +#endif +#endif + diff --git a/include/gmssl/sm2_z256.h b/include/gmssl/sm2_z256.h index 5d6e25ed..b0191d4b 100644 --- a/include/gmssl/sm2_z256.h +++ b/include/gmssl/sm2_z256.h @@ -30,6 +30,7 @@ void sm2_z256_to_bytes(const uint64_t a[4], uint8_t out[32]); int sm2_z256_cmp(const uint64_t a[4], const uint64_t b[4]); uint64_t sm2_z256_is_zero(const uint64_t a[4]); uint64_t sm2_z256_equ(const uint64_t a[4], const uint64_t b[4]); +void sm2_z256_rshift(uint64_t r[4], const uint64_t a[4], unsigned int nbits); uint64_t sm2_z256_add(uint64_t r[4], const uint64_t a[4], const uint64_t b[4]); uint64_t sm2_z256_sub(uint64_t r[4], const uint64_t a[4], const uint64_t b[4]); void sm2_z256_mul(uint64_t r[8], const uint64_t a[4], const uint64_t b[4]); @@ -53,6 +54,7 @@ void sm2_z256_modp_mont_mul(uint64_t r[4], const uint64_t a[4], const uint64_t b void sm2_z256_modp_mont_sqr(uint64_t r[4], const uint64_t a[4]); void sm2_z256_modp_mont_exp(uint64_t r[4], const uint64_t a[4], const uint64_t e[4]); void sm2_z256_modp_mont_inv(uint64_t r[4], const uint64_t a[4]); +int sm2_z256_modp_mont_sqrt(uint64_t r[4], const uint64_t a[4]); int sm2_z256_modp_mont_print(FILE *fp, int ind, int fmt, const char *label, const uint64_t a[4]); int sm2_z256_modn_rand(uint64_t r[4]); @@ -79,11 +81,13 @@ typedef struct { uint64_t Z[4]; } SM2_Z256_POINT; +void sm2_z256_point_set_infinity(SM2_Z256_POINT *P); void sm2_z256_point_from_bytes(SM2_Z256_POINT *P, const uint8_t in[64]); void sm2_z256_point_to_bytes(const SM2_Z256_POINT *P, uint8_t out[64]); int sm2_z256_point_is_at_infinity(const SM2_Z256_POINT *P); int sm2_z256_point_is_on_curve(const SM2_Z256_POINT *P); +int sm2_z256_point_equ(const SM2_Z256_POINT *P, const SM2_Z256_POINT *Q); void sm2_z256_point_get_xy(const SM2_Z256_POINT *P, uint64_t x[4], uint64_t y[4]); void sm2_z256_point_dbl(SM2_Z256_POINT *R, const SM2_Z256_POINT *A); @@ -110,11 +114,25 @@ void sm2_z256_point_mul_sum(SM2_Z256_POINT *R, const uint64_t t[4], const SM2_Z2 const uint64_t *sm2_z256_prime(void); const uint64_t *sm2_z256_order(void); +const uint64_t *sm2_z256_order_minus_one(void); const uint64_t *sm2_z256_one(void); void sm2_z256_point_from_hex(SM2_Z256_POINT *P, const char *hex); int sm2_z256_point_equ_hex(const SM2_Z256_POINT *P, const char *hex); +enum { + SM2_point_at_infinity = 0x00, + SM2_point_compressed_y_even = 0x02, + SM2_point_compressed_y_odd = 0x03, + SM2_point_uncompressed = 0x04, + SM2_point_uncompressed_y_even = 0x06, + SM2_point_uncompressed_y_odd = 0x07, +}; + +int sm2_z256_point_from_x_bytes(SM2_Z256_POINT *P, const uint8_t x_bytes[32], int y_is_odd); +int sm2_z256_point_from_hash(SM2_Z256_POINT *R, const uint8_t *data, size_t datalen, int y_is_odd); + +int sm2_z256_point_from_octets(SM2_Z256_POINT *P, const uint8_t *in, size_t inlen); #ifdef __cplusplus } diff --git a/src/sm2_alg.c b/src/sm2_alg.c index 15681ead..59d4fd9a 100644 --- a/src/sm2_alg.c +++ b/src/sm2_alg.c @@ -539,10 +539,17 @@ int sm2_fp_sqrt(SM2_Fp r, const SM2_Fp a) SM2_BN u; SM2_BN y; // temp result, prevent call sm2_fp_sqrt(a, a) + printf("sm2_fp_sqrt\n"); + sm2_bn_print(stderr, 0, 4, "a", a); + // r = a^((p + 1)/4) when p = 3 (mod 4) sm2_bn_add(u, SM2_P, SM2_ONE); sm2_bn_rshift(u, u, 2); + + sm2_bn_print(stderr, 0, 4, "u", u); + sm2_fp_exp(y, a, u); + sm2_bn_print(stderr, 0, 4, "y", y); // check r^2 == a sm2_fp_sqr(u, y); @@ -1087,6 +1094,7 @@ int sm2_point_is_at_infinity(const SM2_POINT *P) return mem_is_zero((uint8_t *)P, sizeof(SM2_POINT)); } +// 这个函数和 sm2_z256_point_from_x_bytes 不一样 int sm2_point_from_x(SM2_POINT *P, const uint8_t x[32], int y) { SM2_BN _x, _y, _g, _z; diff --git a/src/sm2_ctx.c b/src/sm2_ctx.c index 0e4ced42..beecbaf4 100644 --- a/src/sm2_ctx.c +++ b/src/sm2_ctx.c @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -22,11 +23,19 @@ int sm2_sign_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_t idlen) { + size_t i; + if (!ctx || !key) { error_print(); return -1; } ctx->key = *key; + + // d' = (d + 1)^-1 (mod n) + sm2_z256_from_bytes(ctx->sign_key, key->private_key); + sm2_z256_modn_add(ctx->sign_key, ctx->sign_key, sm2_z256_one()); + sm2_z256_modn_inv(ctx->sign_key, ctx->sign_key); + sm3_init(&ctx->sm3_ctx); if (id) { @@ -38,6 +47,24 @@ int sm2_sign_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_t sm2_compute_z(z, &key->public_key, id, idlen); sm3_update(&ctx->sm3_ctx, z, sizeof(z)); } + + ctx->inited_sm3_ctx = ctx->sm3_ctx; + + // pre compute (k, x = [k]G.x) + for (i = 0; i < 32; i++) { + if (sm2_do_sign_pre_compute(ctx->pre_comp[i].k, ctx->pre_comp[i].x1) != 1) { + error_print(); + return -1; + } + } + ctx->num_pre_comp = 32; + + return 1; +} + +int sm2_sign_ctx_reset(SM2_SIGN_CTX *ctx) +{ + ctx->sm3_ctx = ctx->inited_sm3_ctx; return 1; } @@ -56,16 +83,39 @@ int sm2_sign_update(SM2_SIGN_CTX *ctx, const uint8_t *data, size_t datalen) int sm2_sign_finish(SM2_SIGN_CTX *ctx, uint8_t *sig, size_t *siglen) { uint8_t dgst[SM3_DIGEST_SIZE]; + SM2_SIGNATURE signature; if (!ctx || !sig || !siglen) { error_print(); return -1; } sm3_finish(&ctx->sm3_ctx, dgst); - if (sm2_sign(&ctx->key, dgst, sig, siglen) != 1) { + + if (ctx->num_pre_comp == 0) { + size_t i; + for (i = 0; i < 32; i++) { + if (sm2_do_sign_pre_compute(ctx->pre_comp[i].k, ctx->pre_comp[i].x1) != 1) { + error_print(); + return -1; + } + } + ctx->num_pre_comp = 32; + } + + ctx->num_pre_comp--; + if (sm2_do_sign_fast_ex(ctx->sign_key, + ctx->pre_comp[ctx->num_pre_comp].k, ctx->pre_comp[ctx->num_pre_comp].x1, + dgst, &signature) != 1) { error_print(); return -1; } + + *siglen = 0; + if (sm2_signature_to_der(&signature, &sig, siglen) != 1) { + error_print(); + return -1; + } + return 1; } @@ -93,6 +143,9 @@ int sm2_verify_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_ } memset(ctx, 0, sizeof(*ctx)); ctx->key.public_key = key->public_key; + + sm2_z256_point_from_bytes(&ctx->public_key, (const uint8_t *)&key->public_key); + sm3_init(&ctx->sm3_ctx); if (id) { @@ -104,6 +157,9 @@ int sm2_verify_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_ sm2_compute_z(z, &key->public_key, id, idlen); sm3_update(&ctx->sm3_ctx, z, sizeof(z)); } + + ctx->inited_sm3_ctx = ctx->sm3_ctx; + return 1; } @@ -135,9 +191,6 @@ int sm2_verify_finish(SM2_SIGN_CTX *ctx, const uint8_t *sig, size_t siglen) return 1; } - - - int sm2_encrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key) { if (!ctx || !sm2_key) { diff --git a/src/sm2_point.c b/src/sm2_point.c new file mode 100644 index 00000000..0e517915 --- /dev/null +++ b/src/sm2_point.c @@ -0,0 +1,282 @@ +/* + * Copyright 2014-2024 The GmSSL Project. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ + + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +int sm2_point_is_on_curve(const SM2_POINT *P) +{ + SM2_Z256_POINT T; + sm2_z256_point_from_bytes(&T, (const uint8_t *)P); + + if (sm2_z256_point_is_on_curve(&T) == 1) { + return 1; + } else { + return 0; + } +} + +int sm2_point_is_at_infinity(const SM2_POINT *P) +{ + return mem_is_zero((uint8_t *)P, sizeof(SM2_POINT)); +} + +int sm2_point_from_x(SM2_POINT *P, const uint8_t x[32], int y_is_odd) +{ + + SM2_Z256_POINT T; + + if (sm2_z256_point_from_x_bytes(&T, x, y_is_odd) != 1) { + error_print(); + return -1; + } + + sm2_z256_point_to_bytes(&T, (uint8_t *)P); + return 1; +} + +int sm2_point_from_xy(SM2_POINT *P, const uint8_t x[32], const uint8_t y[32]) +{ + memcpy(P->x, x, 32); + memcpy(P->y, y, 32); + return sm2_point_is_on_curve(P); +} + +int sm2_point_add(SM2_POINT *R, const SM2_POINT *P, const SM2_POINT *Q) +{ + SM2_Z256_POINT P_; + SM2_Z256_POINT Q_; + + sm2_z256_point_from_bytes(&P_, (uint8_t *)P); + sm2_z256_point_from_bytes(&Q_, (uint8_t *)Q); + sm2_z256_point_add(&P_, &P_, &Q_); + sm2_z256_point_to_bytes(&P_, (uint8_t *)R); + + return 1; +} + +int sm2_point_sub(SM2_POINT *R, const SM2_POINT *P, const SM2_POINT *Q) +{ + SM2_Z256_POINT P_; + SM2_Z256_POINT Q_; + + sm2_z256_point_from_bytes(&P_, (uint8_t *)P); + sm2_z256_point_from_bytes(&Q_, (uint8_t *)Q); + sm2_z256_point_sub(&P_, &P_, &Q_); + sm2_z256_point_to_bytes(&P_, (uint8_t *)R); + + return 1; +} + +int sm2_point_neg(SM2_POINT *R, const SM2_POINT *P) +{ + SM2_Z256_POINT P_; + + sm2_z256_point_from_bytes(&P_, (uint8_t *)P); + sm2_z256_point_neg(&P_, &P_); + sm2_z256_point_to_bytes(&P_, (uint8_t *)R); + + return 1; +} + +int sm2_point_dbl(SM2_POINT *R, const SM2_POINT *P) +{ + SM2_Z256_POINT P_; + + sm2_z256_point_from_bytes(&P_, (uint8_t *)P); + sm2_z256_point_dbl(&P_, &P_); + sm2_z256_point_to_bytes(&P_, (uint8_t *)R); + + return 1; +} + +int sm2_point_mul(SM2_POINT *R, const uint8_t k[32], const SM2_POINT *P) +{ + uint64_t _k[4]; + SM2_Z256_POINT _P; + + sm2_z256_from_bytes(_k, k); + sm2_z256_point_from_bytes(&_P, (uint8_t *)P); + sm2_z256_point_mul(&_P, _k, &_P); + sm2_z256_point_to_bytes(&_P, (uint8_t *)R); + + memset(_k, 0, sizeof(_k)); + return 1; +} + +int sm2_point_mul_generator(SM2_POINT *R, const uint8_t k[32]) +{ + uint64_t _k[4]; + SM2_Z256_POINT _R; + + sm2_z256_from_bytes(_k, k); + sm2_z256_point_mul_generator(&_R, _k); + sm2_z256_point_to_bytes(&_R, (uint8_t *)R); + + memset(_k, 0, sizeof(_k)); + return 1; +} + +int sm2_point_mul_sum(SM2_POINT *R, const uint8_t k[32], const SM2_POINT *P, const uint8_t s[32]) +{ + uint64_t _k[4]; + SM2_Z256_POINT _P; + uint64_t _s[4]; + + sm2_z256_from_bytes(_k, k); + sm2_z256_point_from_bytes(&_P, (uint8_t *)P); + sm2_z256_from_bytes(_s, s); + sm2_z256_point_mul_sum(&_P, _k, &_P, _s); + sm2_z256_point_to_bytes(&_P, (uint8_t *)R); + + memset(_k, 0, sizeof(_k)); + memset(_s, 0, sizeof(_s)); + return 1; +} + +int sm2_point_print(FILE *fp, int fmt, int ind, const char *label, const SM2_POINT *P) +{ + format_print(fp, fmt, ind, "%s\n", label); + ind += 4; + format_bytes(fp, fmt, ind, "x", P->x, 32); + format_bytes(fp, fmt, ind, "y", P->y, 32); + return 1; +} + +void sm2_point_to_compressed_octets(const SM2_POINT *P, uint8_t out[33]) +{ + *out++ = (P->y[31] & 0x01) ? 0x03 : 0x02; + memcpy(out, P->x, 32); +} + +void sm2_point_to_uncompressed_octets(const SM2_POINT *P, uint8_t out[65]) +{ + *out++ = 0x04; + memcpy(out, P, 64); +} + +int sm2_z256_point_from_octets(SM2_Z256_POINT *P, const uint8_t *in, size_t inlen) +{ + switch (*in) { + case SM2_point_at_infinity: + if (inlen != 1) { + error_print(); + return -1; + } + sm2_z256_point_set_infinity(P); + break; + case SM2_point_compressed_y_even: + if (inlen != 33) { + error_print(); + return -1; + } + if (sm2_z256_point_from_x_bytes(P, in + 1, 0) != 1) { + error_print(); + return -1; + } + break; + case SM2_point_compressed_y_odd: + if (inlen != 33) { + error_print(); + return -1; + } + if (sm2_z256_point_from_x_bytes(P, in + 1, 1) != 1) { + error_print(); + return -1; + } + break; + case SM2_point_uncompressed: + if (inlen != 65) { + error_print(); + return -1; + } + sm2_z256_point_from_bytes(P, in + 1); + if (sm2_z256_point_is_on_curve(P) != 1) { + error_print(); + return -1; + } + break; + default: + error_print(); + return -1; + } + + return 1; +} + +int sm2_point_from_octets(SM2_POINT *P, const uint8_t *in, size_t inlen) +{ + if ((*in == 0x02 || *in == 0x03) && inlen == 33) { + if (sm2_point_from_x(P, in + 1, *in) != 1) { + error_print(); + return -1; + } + } else if (*in == 0x04 && inlen == 65) { + if (sm2_point_from_xy(P, in + 1, in + 33) != 1) { + error_print(); + return -1; + } + } else { + error_print(); + return -1; + } + return 1; +} + +int sm2_point_to_der(const SM2_POINT *P, uint8_t **out, size_t *outlen) +{ + uint8_t octets[65]; + if (!P) { + return 0; + } + sm2_point_to_uncompressed_octets(P, octets); + if (asn1_octet_string_to_der(octets, sizeof(octets), out, outlen) != 1) { + error_print(); + return -1; + } + return 1; +} + +int sm2_point_from_der(SM2_POINT *P, const uint8_t **in, size_t *inlen) +{ + int ret; + const uint8_t *d; + size_t dlen; + + if ((ret = asn1_octet_string_from_der(&d, &dlen, in, inlen)) != 1) { + if (ret < 0) error_print(); + return ret; + } + if (dlen != 65) { + error_print(); + return -1; + } + if (sm2_point_from_octets(P, d, dlen) != 1) { + error_print(); + return -1; + } + return 1; +} + +int sm2_point_from_hash(SM2_POINT *R, const uint8_t *data, size_t datalen) +{ + return 1; +} + diff --git a/src/sm2_z256.c b/src/sm2_z256.c index dd2ae29d..bf47a96f 100644 --- a/src/sm2_z256.c +++ b/src/sm2_z256.c @@ -52,6 +52,8 @@ #include #include #include +#include + /* SM2 parameters @@ -71,7 +73,10 @@ const uint64_t *sm2_z256_one(void) { return &SM2_Z256_ONE[0]; } - +void sm2_z256_set_zero(uint64_t a[4]) +{ + a[0] = a[1] = a[2] = a[3] = 0; +} int sm2_z256_rand_range(uint64_t r[4], const uint64_t range[4]) { @@ -161,6 +166,23 @@ uint64_t sm2_z256_is_zero(const uint64_t a[4]) is_zero(a[3]); } +void sm2_z256_rshift(uint64_t r[4], const uint64_t a[4], unsigned int nbits) +{ + nbits &= 0x3f; + + if (nbits) { + r[0] = a[0] >> nbits; + r[0] |= a[1] << (64 - nbits); + r[1] = a[1] >> nbits; + r[1] |= a[2] << (64 - nbits); + r[2] = a[2] >> nbits; + r[2] |= a[3] << (64 - nbits); + r[3] = a[3] >> nbits; + } else { + sm2_z256_copy(r, a); + } +} + uint64_t sm2_z256_add(uint64_t r[4], const uint64_t a[4], const uint64_t b[4]) { uint64_t t, c = 0; @@ -351,6 +373,9 @@ int sm2_z512_print(FILE *fp, int ind, int fmt, const char *label, const uint64_t const uint64_t SM2_Z256_P[4] = { 0xffffffffffffffff, 0xffffffff00000000, 0xffffffffffffffff, 0xfffffffeffffffff, }; +// 注意这里 SM2_Z256_P[0] 和 SM2_Z256_P[2] 是特殊值,在汇编中可以根据这个特殊值做特定的实现 + + const uint64_t *sm2_z256_prime(void) { return &SM2_Z256_P[0]; @@ -362,6 +387,7 @@ const uint64_t SM2_Z256_NEG_P[4] = { 1, ((uint64_t)1 << 32) - 1, 0, ((uint64_t)1 << 32), }; +#ifndef ENABLE_SM2_Z256_ARMV8 void sm2_z256_modp_add(uint64_t r[4], const uint64_t a[4], const uint64_t b[4]) { uint64_t c; @@ -404,6 +430,11 @@ void sm2_z256_modp_mul_by_3(uint64_t r[4], const uint64_t a[4]) sm2_z256_modp_add(r, t, a); } +void sm2_z256_modp_neg(uint64_t r[4], const uint64_t a[4]) +{ + (void)sm2_z256_sub(r, SM2_Z256_P, a); +} + void sm2_z256_modp_div_by_2(uint64_t r[4], const uint64_t a[4]) { uint64_t c = 0; @@ -422,11 +453,9 @@ void sm2_z256_modp_div_by_2(uint64_t r[4], const uint64_t a[4]) r[2] = (r[2] >> 1) | ((r[3] & 1) << 63); r[3] = (r[3] >> 1) | ((c & 1) << 63); } +#endif -void sm2_z256_modp_neg(uint64_t r[4], const uint64_t a[4]) -{ - (void)sm2_z256_sub(r, SM2_Z256_P, a); -} +// p' * p = -1 mod 2^256 // p' = -p^(-1) mod 2^256 // = fffffffc00000001fffffffe00000000ffffffff000000010000000000000001 @@ -435,10 +464,12 @@ const uint64_t SM2_Z256_P_PRIME[4] = { 0x0000000000000001, 0xffffffff00000001, 0xfffffffe00000000, 0xfffffffc00000001, }; + // mont(1) (mod p) = 2^256 mod p = 2^256 - p const uint64_t *SM2_Z256_MODP_MONT_ONE = SM2_Z256_NEG_P; -// z = xy +#ifndef ENABLE_SM2_Z256_ARMV8 +// z = a*b // c = (z + (z * p' mod 2^256) * p)/2^256 void sm2_z256_modp_mont_mul(uint64_t r[4], const uint64_t a[4], const uint64_t b[4]) { @@ -484,6 +515,24 @@ void sm2_z256_modp_mont_sqr(uint64_t r[4], const uint64_t a[4]) sm2_z256_modp_mont_mul(r, a, a); } +// mont(mont(a), 1) = aR * 1 * R^-1 (mod p) = a (mod p) +void sm2_z256_modp_from_mont(uint64_t r[4], const uint64_t a[4]) +{ + sm2_z256_modp_mont_mul(r, a, SM2_Z256_ONE); +} + +// 2^512 (mod p) +const uint64_t SM2_Z256_2e512modp[4] = { + 0x0000000200000003, 0x00000002ffffffff, 0x0000000100000001, 0x0000000400000002 +}; + +// mont(a) = a * 2^256 (mod p) = mont_mul(a, 2^512 mod p) +void sm2_z256_modp_to_mont(const uint64_t a[4], uint64_t r[4]) +{ + sm2_z256_modp_mont_mul(r, a, SM2_Z256_2e512modp); +} +#endif + void sm2_z256_modp_mont_exp(uint64_t r[4], const uint64_t a[4], const uint64_t e[4]) { uint64_t t[4]; @@ -589,21 +638,30 @@ void sm2_z256_modp_mont_inv(uint64_t r[4], const uint64_t a[4]) sm2_z256_modp_mont_mul(r, a4, a5); } -// mont(mont(a), 1) = aR * 1 * R^-1 (mod p) = a (mod p) -void sm2_z256_modp_from_mont(uint64_t r[4], const uint64_t a[4]) -{ - sm2_z256_modp_mont_mul(r, a, SM2_Z256_ONE); -} - -// 2^512 (mod p) -const uint64_t SM2_Z256_2e512modp[4] = { - 0x0000000200000003, 0x00000002ffffffff, 0x0000000100000001, 0x0000000400000002 +// (p+1)/4 = 3fffffffbfffffffffffffffffffffffffffffffc00000004000000000000000 +const uint64_t SM2_Z256_SQRT_EXP[4] = { + 0x4000000000000000, 0xffffffffc0000000, 0xffffffffffffffff, 0x3fffffffbfffffff, }; -// mont(a) = a * 2^256 (mod p) = mont_mul(a, 2^512 mod p) -void sm2_z256_modp_to_mont(const uint64_t a[4], uint64_t r[4]) +// -r (mod p), i.e. (p - r) is also a square root of a +int sm2_z256_modp_mont_sqrt(uint64_t r[4], const uint64_t a[4]) { - sm2_z256_modp_mont_mul(r, a, SM2_Z256_2e512modp); + uint64_t a_[4]; + uint64_t r_[4]; // temp result, prevent call sm2_fp_sqrt(a, a) + + // r = a^((p + 1)/4) when p = 3 (mod 4) + sm2_z256_modp_mont_exp(r_, a, SM2_Z256_SQRT_EXP); + + // check r^2 == a + sm2_z256_modp_mont_sqr(a_, r_); + if (sm2_z256_cmp(a_, a) != 0) { + // not every number has a square root, so it is not an error + // `sm2_z256_point_from_hash` need a non-negative return value + return 0; + } + + sm2_z256_copy(r, r_); + return 1; } int sm2_z256_modp_mont_print(FILE *fp, int ind, int fmt, const char *label, const uint64_t a[4]) @@ -621,6 +679,11 @@ const uint64_t SM2_Z256_N[4] = { 0x53bbf40939d54123, 0x7203df6b21c6052b, 0xffffffffffffffff, 0xfffffffeffffffff, }; +const uint64_t SM2_Z256_N_MINUS_ONE[4] = { + 0x53bbf40939d54122, 0x7203df6b21c6052b, 0xffffffffffffffff, 0xfffffffeffffffff, +}; + + // 2^256 - n = 0x10000000000000000000000008dfc2094de39fad4ac440bf6c62abedd const uint64_t SM2_Z256_NEG_N[4] = { 0xac440bf6c62abedd, 0x8dfc2094de39fad4, 0x0000000000000000, 0x0000000100000000, @@ -680,6 +743,10 @@ const uint64_t *sm2_z256_order(void) { return &SM2_Z256_N[0]; } +const uint64_t *sm2_z256_order_minus_one(void) { + return &SM2_Z256_N_MINUS_ONE[0]; +} + // mont(1) (mod n) = 2^256 - n const uint64_t *SM2_Z256_MODN_MONT_ONE = SM2_Z256_NEG_N; @@ -784,10 +851,45 @@ void sm2_z256_modn_exp(uint64_t r[4], const uint64_t a[4], const uint64_t e[4]) const uint64_t SM2_Z256_N_MINUS_TWO[4] = { 0x53bbf40939d54121, 0x7203df6b21c6052b, 0xffffffffffffffff, 0xfffffffeffffffff, }; +// exp都是从高位开始的,如果都是1的话,那么就是都要平方和乘 void sm2_z256_modn_mont_inv(uint64_t r[4], const uint64_t a[4]) { - sm2_z256_modn_mont_exp(r, a, SM2_Z256_N_MINUS_TWO); + // expand sm2_z256_modn_mont_exp(r, a, SM2_Z256_N_MINUS_TWO) + uint64_t t[4]; + uint64_t w; + int i; + int k = 0; + + sm2_z256_copy(t, a); + + for (i = 0; i < 30; i++) { + sm2_z256_modn_mont_sqr(t, t); + sm2_z256_modn_mont_mul(t, t, a); + } + sm2_z256_modn_mont_sqr(t, t); + for (i = 0; i < 96; i++) { + sm2_z256_modn_mont_sqr(t, t); + sm2_z256_modn_mont_mul(t, t, a); + } + w = SM2_Z256_N_MINUS_TWO[1]; + for (i = 0; i < 64; i++) { + sm2_z256_modn_mont_sqr(t, t); + if (w & 0x8000000000000000) { + sm2_z256_modn_mont_mul(t, t, a); + } + w <<= 1; + } + w = SM2_Z256_N_MINUS_TWO[0]; + for (i = 0; i < 64; i++) { + sm2_z256_modn_mont_sqr(t, t); + if (w & 0x8000000000000000) { + sm2_z256_modn_mont_mul(t, t, a); + } + w <<= 1; + } + + sm2_z256_copy(r, t); } void sm2_z256_modn_inv(uint64_t r[4], const uint64_t a[4]) @@ -805,7 +907,6 @@ void sm2_z256_modn_from_mont(uint64_t r[4], const uint64_t a[4]) sm2_z256_modn_mont_mul(r, a, SM2_Z256_ONE); } - // 2^512 (mod n) = 0x1eb5e412a22b3d3b620fc84c3affe0d43464504ade6fa2fa901192af7c114f20 const uint64_t SM2_Z256_2e512modn[4] = { 0x901192af7c114f20, 0x3464504ade6fa2fa, 0x620fc84c3affe0d4, 0x1eb5e412a22b3d3b, @@ -828,11 +929,29 @@ int sm2_z256_modn_mont_print(FILE *fp, int ind, int fmt, const char *label, cons // Jacobian Point with Montgomery coordinates +void sm2_z256_point_set_infinity(SM2_Z256_POINT *P) +{ + sm2_z256_copy(P->X, SM2_Z256_MODP_MONT_ONE); + sm2_z256_copy(P->Y, SM2_Z256_MODP_MONT_ONE); + sm2_z256_set_zero(P->Z); +} -// 这里还应该检查X == Y == mont(1) +// point at infinity should be like (k^2 : k^3 : 0), k in [0, p-1] int sm2_z256_point_is_at_infinity(const SM2_Z256_POINT *P) { if (sm2_z256_is_zero(P->Z)) { + uint64_t X_cub[4]; + uint64_t Y_sqr[4]; + + sm2_z256_modp_mont_sqr(X_cub, P->X); + sm2_z256_modp_mont_mul(X_cub, X_cub, P->X); + sm2_z256_modp_mont_sqr(Y_sqr, P->Y); + + if (sm2_z256_cmp(X_cub, Y_sqr) != 0) { + error_print(); + return 0; + } + return 1; } else { return 0; @@ -907,6 +1026,34 @@ void sm2_z256_point_get_xy(const SM2_Z256_POINT *P, uint64_t x[4], uint64_t y[4] } } +// impl with modified jacobian coordinates +void sm2_z256_point_dbl_x5(SM2_Z256_POINT *R, const SM2_Z256_POINT *A) + +{ + sm2_z256_point_dbl(R, A); + sm2_z256_point_dbl(R, R); + sm2_z256_point_dbl(R, R); + sm2_z256_point_dbl(R, R); + sm2_z256_point_dbl(R, R); +} + +void sm2_z256_point_multi_dbl(SM2_Z256_POINT *R, const SM2_Z256_POINT *P, unsigned int i) +{ + const uint64_t *X1 = P->X; + const uint64_t *Y1 = P->Y; + const uint64_t *Z1 = P->Z; + uint64_t *X3 = R->X; + uint64_t *Y3 = R->Y; + uint64_t *Z3 = R->Z; + uint64_t A[4]; + uint64_t B[4]; + uint64_t C[4]; + uint64_t D[4]; + uint64_t E[4]; + + // A = Z1^2 +} + void sm2_z256_point_dbl(SM2_Z256_POINT *R, const SM2_Z256_POINT *A) { const uint64_t *X1 = A->X; @@ -922,77 +1069,75 @@ void sm2_z256_point_dbl(SM2_Z256_POINT *R, const SM2_Z256_POINT *A) // S = 2*Y1 sm2_z256_modp_mul_by_2(S, Y1); - //sm2_z256_modp_mont_print(stderr, 0, 0, "1", S); + // Zsqr = Z1^2 sm2_z256_modp_mont_sqr(Zsqr, Z1); - //sm2_z256_modp_mont_print(stderr, 0, 0, "2", Zsqr); // S = S^2 = 4*Y1^2 sm2_z256_modp_mont_sqr(S, S); - //sm2_z256_modp_mont_print(stderr, 0, 0, "3", S); // Z3 = Z1 * Y1 sm2_z256_modp_mont_mul(Z3, Z1, Y1); - //sm2_z256_modp_mont_print(stderr, 0, 0, "4", Z3); // Z3 = 2 * Z3 = 2*Y1*Z1 sm2_z256_modp_mul_by_2(Z3, Z3); - //sm2_z256_modp_mont_print(stderr, 0, 0, "5", Z3); // M = X1 + Zsqr = X1 + Z1^2 sm2_z256_modp_add(M, X1, Zsqr); - //sm2_z256_modp_mont_print(stderr, 0, 0, "6", M); // Zsqr = X1 - Zsqr = X1 - Z1^2 sm2_z256_modp_sub(Zsqr, X1, Zsqr); - //sm2_z256_modp_mont_print(stderr, 0, 0, "7", Zsqr); // Y3 = S^2 = 16 * Y1^4 sm2_z256_modp_mont_sqr(Y3, S); - //sm2_z256_modp_mont_print(stderr, 0, 0, "8", Y3); // Y3 = Y3/2 = 8 * Y1^4 sm2_z256_modp_div_by_2(Y3, Y3); - //sm2_z256_modp_mont_print(stderr, 0, 0, "9", Y3); // M = M * Zsqr = (X1 + Z1^2)(X1 - Z1^2) sm2_z256_modp_mont_mul(M, M, Zsqr); - //sm2_z256_modp_mont_print(stderr, 0, 0, "10", M); // M = 3*M = 3(X1 + Z1^2)(X1 - Z1^2) sm2_z256_modp_mul_by_3(M, M); - //sm2_z256_modp_mont_print(stderr, 0, 0, "11", M); // S = S * X1 = 4 * X1 * Y1^2 sm2_z256_modp_mont_mul(S, S, X1); - //sm2_z256_modp_mont_print(stderr, 0, 0, "12", S); // tmp0 = 2 * S = 8 * X1 * Y1^2 sm2_z256_modp_mul_by_2(tmp0, S); - //sm2_z256_modp_mont_print(stderr, 0, 0, "13", tmp0); // X3 = M^2 = (3(X1 + Z1^2)(X1 - Z1^2))^2 sm2_z256_modp_mont_sqr(X3, M); - //sm2_z256_modp_mont_print(stderr, 0, 0, "14", X3); // X3 = X3 - tmp0 = (3(X1 + Z1^2)(X1 - Z1^2))^2 - 8 * X1 * Y1^2 sm2_z256_modp_sub(X3, X3, tmp0); - //sm2_z256_modp_mont_print(stderr, 0, 0, "15", X3); // S = S - X3 = 4 * X1 * Y1^2 - X3 sm2_z256_modp_sub(S, S, X3); - //sm2_z256_modp_mont_print(stderr, 0, 0, "16", S); // S = S * M = 3(X1 + Z1^2)(X1 - Z1^2)(4 * X1 * Y1^2 - X3) sm2_z256_modp_mont_mul(S, S, M); - //sm2_z256_modp_mont_print(stderr, 0, 0, "17", S); // Y3 = S - Y3 = 3(X1 + Z1^2)(X1 - Z1^2)(4 * X1 * Y1^2 - X3) - 8 * Y1^4 sm2_z256_modp_sub(Y3, S, Y3); - //sm2_z256_modp_mont_print(stderr, 0, 0, "18", Y3); } +/* + (X1:Y1:Z1) + (X2:Y2:Z2) => (X3:Y3:Z3) + + A = Y2 * Z1^3 - Y1 * Z2^3 + B = X2 * Z1^2 - X1 * Z2^2 + + X3 = A^2 - B^2 * (X2 * Z1^2 + X1 * Z2^2) + = A^2 - B^3 - 2 * B^2 * X1 * Z2^2 + Y3 = A * (X1 * B^2 * Z2^2 - X3) - Y1 * B^3 * Z2^3 + Z3 = B * Z1 * Z2 + + P + (-P) = (X:Y:Z) + (k^2*X : k^3*Y : k*Z) => (0:0:0) + +感觉点加也有很好的并行性 +*/ void sm2_z256_point_add(SM2_Z256_POINT *r, const SM2_Z256_POINT *a, const SM2_Z256_POINT *b) { uint64_t U2[4], S2[4]; @@ -1028,6 +1173,7 @@ void sm2_z256_point_add(SM2_Z256_POINT *r, const SM2_Z256_POINT *a, const SM2_Z2 in1infty = is_zero(in1infty); in2infty = is_zero(in2infty); + // 这里很明显有极好的并行性 sm2_z256_modp_mont_sqr(Z2sqr, in2_z); /* Z2^2 */ sm2_z256_modp_mont_sqr(Z1sqr, in1_z); /* Z1^2 */ @@ -1057,11 +1203,13 @@ void sm2_z256_point_add(SM2_Z256_POINT *r, const SM2_Z256_POINT *a, const SM2_Z2 sm2_z256_modp_mont_sqr(Rsqr, R); /* R^2 */ sm2_z256_modp_mont_mul(res_z, H, in1_z); /* Z3 = H*Z1*Z2 */ + sm2_z256_modp_mont_sqr(Hsqr, H); /* H^2 */ sm2_z256_modp_mont_mul(res_z, res_z, in2_z); /* Z3 = H*Z1*Z2 */ - sm2_z256_modp_mont_mul(Hcub, Hsqr, H); /* H^3 */ + sm2_z256_modp_mont_mul(Hcub, Hsqr, H); /* H^3 */ sm2_z256_modp_mont_mul(U2, U1, Hsqr); /* U1*H^2 */ + sm2_z256_modp_mul_by_2(Hsqr, U2); /* 2*U1*H^2 */ sm2_z256_modp_sub(res_x, Rsqr, Hsqr); @@ -1071,6 +1219,7 @@ void sm2_z256_point_add(SM2_Z256_POINT *r, const SM2_Z256_POINT *a, const SM2_Z2 sm2_z256_modp_mont_mul(S2, S1, Hcub); sm2_z256_modp_mont_mul(res_y, R, res_y); + sm2_z256_modp_sub(res_y, res_y, S2); sm2_z256_copy_conditional(res_x, in2_x, in1infty); @@ -1093,7 +1242,6 @@ void sm2_z256_point_neg(SM2_Z256_POINT *R, const SM2_Z256_POINT *P) sm2_z256_copy(R->Z, P->Z); } -// point_mul 中用到 void sm2_z256_point_sub(SM2_Z256_POINT *R, const SM2_Z256_POINT *A, const SM2_Z256_POINT *B) { SM2_Z256_POINT neg_B; @@ -1109,8 +1257,28 @@ void sm2_z256_point_mul(SM2_Z256_POINT *R, const uint64_t k[4], const SM2_Z256_P int n = (256 + window_size - 1)/window_size; int i; + // 这相当于做了一个预计算表 + /* + P 2P 4P 8P // 这实际上是一个连续的dbl + + 3P 6P, 12P + + 5P, 10P, + + 7P, 14P + + 15P + ... + + // 如果一次能并行计算4组点加法,那么这部分与计算表的计算量可以降低 + // 这个连续计算中,dbl的数量越多,计算量越低 + */ + // T[i] = (i + 1) * P memcpy(&T[0], P, sizeof(SM2_Z256_POINT)); + + // 这个计算大概是有并行能力的! + /* sm2_z256_point_dbl(&T[ 1], &T[ 0]); sm2_z256_point_add(&T[ 2], &T[ 1], P); sm2_z256_point_dbl(&T[ 3], &T[ 1]); @@ -1126,6 +1294,24 @@ void sm2_z256_point_mul(SM2_Z256_POINT *R, const uint64_t k[4], const SM2_Z256_P sm2_z256_point_dbl(&T[13], &T[ 6]); sm2_z256_point_add(&T[14], &T[13], P); sm2_z256_point_dbl(&T[15], &T[ 7]); + */ + + sm2_z256_point_dbl(&T[2-1], &T[1-1]); + sm2_z256_point_dbl(&T[4-1], &T[2-1]); + sm2_z256_point_dbl(&T[8-1], &T[4-1]); + sm2_z256_point_dbl(&T[16-1], &T[8-1]); + sm2_z256_point_add(&T[3-1], &T[2-1], P); + sm2_z256_point_dbl(&T[6-1], &T[3-1]); + sm2_z256_point_dbl(&T[12-1], &T[6-1]); + sm2_z256_point_add(&T[5-1], &T[3-1], &T[2-1]); + sm2_z256_point_dbl(&T[10-1], &T[5-1]); + sm2_z256_point_add(&T[7-1], &T[4-1], &T[3-1]); + sm2_z256_point_dbl(&T[14-1], &T[7-1]); + sm2_z256_point_add(&T[9-1], &T[4-1], &T[5-1]); + sm2_z256_point_add(&T[11-1], &T[6-1], &T[5-1]); + sm2_z256_point_add(&T[13-1], &T[7-1], &T[6-1]); + sm2_z256_point_add(&T[15-1], &T[8-1], &T[7-1]); + for (i = n - 1; i >= 0; i--) { int booth = sm2_z256_get_booth(k, window_size, i); @@ -1136,11 +1322,9 @@ void sm2_z256_point_mul(SM2_Z256_POINT *R, const uint64_t k[4], const SM2_Z256_P R_infinity = 0; } } else { - sm2_z256_point_dbl(R, R); - sm2_z256_point_dbl(R, R); - sm2_z256_point_dbl(R, R); - sm2_z256_point_dbl(R, R); - sm2_z256_point_dbl(R, R); + // 这个重复dbl的计算可以适当降低吗? + // 这说明对dbl的优化还是很有意义的,因为这里面dbl的数量最多 + sm2_z256_point_dbl_x5(R, R); if (booth > 0) { sm2_z256_point_add(R, R, &T[booth - 1]); @@ -1177,6 +1361,8 @@ void sm2_z256_point_copy_affine(SM2_Z256_POINT *R, const SM2_Z256_POINT_AFFINE * sm2_z256_copy(R->Z, SM2_Z256_MODP_MONT_ONE); } +// 这是一个比较容易并行的算法 +// r, a, b 都转换为实际输入的值 void sm2_z256_point_add_affine(SM2_Z256_POINT *r, const SM2_Z256_POINT *a, const SM2_Z256_POINT_AFFINE *b) { uint64_t U2[4], S2[4]; @@ -1287,44 +1473,60 @@ int sm2_z256_point_affine_print(FILE *fp, int fmt, int ind, const char *label, c extern const uint64_t sm2_z256_pre_comp[37][64 * 4 * 2]; static SM2_Z256_POINT_AFFINE (*g_pre_comp)[64] = (SM2_Z256_POINT_AFFINE (*)[64])sm2_z256_pre_comp; + +/* +这个函数的粗粒度并行算法 + + 输出的R应该有多个,输入的k也有多个 + + 轮数是一样的 + + 需要用一个数组表示这个值是否还是无穷远点 + +在签名、加密的时候,参与计算的k都是秘密值,因此需要考虑cache攻击的问题 + +但是在验签的时候,其中s*G计算,其中s是公开值,因此不需要考虑cache攻击 + +应该提供一个专用的常量时间的gather函数 + +*/ void sm2_z256_point_mul_generator(SM2_Z256_POINT *R, const uint64_t k[4]) { size_t window_size = 7; - int R_infinity = 1; + int R_infinity = 1; // 开始的时候点 int n = (256 + window_size - 1)/window_size; int i; for (i = n - 1; i >= 0; i--) { int booth = sm2_z256_get_booth(k, window_size, i); + // 下面的计算应该改为并行化 if (R_infinity) { if (booth != 0) { sm2_z256_point_copy_affine(R, &g_pre_comp[i][booth - 1]); R_infinity = 0; } } else { + + // 可以先把那个点从内存复制到当前空间中 + // 如果booth < 0,则把这个点改为 -P + // 然后再加上这个点,得到一个新的结果 if (booth > 0) { sm2_z256_point_add_affine(R, R, &g_pre_comp[i][booth - 1]); } else if (booth < 0) { sm2_z256_point_sub_affine(R, R, &g_pre_comp[i][-booth - 1]); } + + // booth == 0的时候意味应该加入的affine是一个无穷远点 + // 如果是无穷远点,读入的值,以及计算结果就没有用了。 } } if (R_infinity) { - memset(R, 0, sizeof(*R)); + sm2_z256_point_set_infinity(R); } } - - - - - - - - - // R = t*P + s*G void sm2_z256_point_mul_sum(SM2_Z256_POINT *R, const uint64_t t[4], const SM2_Z256_POINT *P, const uint64_t s[4]) { @@ -1334,8 +1536,6 @@ void sm2_z256_point_mul_sum(SM2_Z256_POINT *R, const uint64_t t[4], const SM2_Z2 sm2_z256_point_add(R, R, &Q); } - -// 这个是否要检查点是否在曲线上? void sm2_z256_point_from_bytes(SM2_Z256_POINT *P, const uint8_t in[64]) { sm2_z256_from_bytes(P->X, in); @@ -1364,6 +1564,35 @@ void sm2_z256_point_to_bytes(const SM2_Z256_POINT *P, uint8_t out[64]) sm2_z256_to_bytes(y, out + 32); } +int sm2_z256_point_equ(const SM2_Z256_POINT *P, const SM2_Z256_POINT *Q) +{ + uint64_t Z1[4] = {0}; + uint64_t Z2[4] = {0}; + uint64_t V1[4] = {0}; + uint64_t V2[4] = {0}; + + // X1 * Z2^2 == X2 * Z1^2 + sm2_z256_modp_mont_sqr(Z1, P->Z); + sm2_z256_modp_mont_sqr(Z2, Q->Z); + sm2_z256_modp_mont_mul(V1, P->X, Z2); + sm2_z256_modp_mont_mul(V2, Q->X, Z1); + if (sm2_z256_cmp(V1, V2) != 0) { + error_print(); + return 0; + } + + // Y1 * Z2^3 == Y2 * Z1^3 + sm2_z256_modp_mont_mul(Z1, Z1, P->Z); + sm2_z256_modp_mont_mul(Z2, Z2, Q->Z); + sm2_z256_modp_mont_mul(V1, P->Y, Z2); + sm2_z256_modp_mont_mul(V2, Q->Y, Z1); + if (sm2_z256_cmp(V1, V2) != 0) { + error_print(); + return 0; + } + + return 1; +} int sm2_z256_point_equ_hex(const SM2_Z256_POINT *P, const char *hex) { @@ -1379,8 +1608,96 @@ int sm2_z256_point_equ_hex(const SM2_Z256_POINT *P, const char *hex) return 0; } return 1; - } +int sm2_z256_is_odd(const uint64_t a[4]) +{ + return a[0] & 0x01; +} +int sm2_z256_point_from_x_bytes(SM2_Z256_POINT *P, const uint8_t x_bytes[32], int y_is_odd) +{ + uint64_t x[4]; + uint64_t y_sqr[4]; + uint64_t y[4]; + int ret; + + uint64_t SM2_Z256_MODP_MONT_THREE[4] = { 3,0,0,0 }; + + sm2_z256_modp_to_mont(SM2_Z256_MODP_MONT_THREE, SM2_Z256_MODP_MONT_THREE); + + sm2_z256_from_bytes(x, x_bytes); + if (sm2_z256_cmp(x, SM2_Z256_P) >= 0) { + error_print(); + return -1; + } + + sm2_z256_modp_to_mont(x, x); + + sm2_z256_copy(P->X, x); + + // y^2 = x^3 - 3x + b = (x^2 - 3)*x + b + sm2_z256_modp_mont_sqr(y_sqr, x); + sm2_z256_modp_sub(y_sqr, y_sqr, SM2_Z256_MODP_MONT_THREE); + sm2_z256_modp_mont_mul(y_sqr, y_sqr, x); + sm2_z256_modp_add(y_sqr, y_sqr, SM2_Z256_MODP_MONT_B); + + // y = sqrt(y^2) + if ((ret = sm2_z256_modp_mont_sqrt(y, y_sqr)) != 1) { + if (ret < 0) error_print(); + return ret; + } + + sm2_z256_copy(P->Y , y); // mont(y) + + sm2_z256_modp_from_mont(y, y); + if (y_is_odd) { + if (!sm2_z256_is_odd(y)) { + sm2_z256_modp_neg(P->Y, P->Y); + } + } else { + if (sm2_z256_is_odd(y)) { + sm2_z256_modp_neg(P->Y, P->Y); + } + } + + sm2_z256_copy(P->Z, SM2_Z256_MODP_MONT_ONE); + + return 1; +} + +int sm2_z256_point_from_hash(SM2_Z256_POINT *R, const uint8_t *data, size_t datalen, int y_is_odd) +{ + uint64_t x[4]; + uint8_t x_bytes[32]; + uint8_t dgst[32]; + int ret; + + do { + // x = sm3(data) mod p + sm3_digest(data, datalen, dgst); + + sm2_z256_from_bytes(x, dgst); + if (sm2_z256_cmp(x, SM2_Z256_P) >= 0) { + sm2_z256_sub(x, x, SM2_Z256_P); + } + sm2_z256_to_bytes(x, x_bytes); + + // compute y + if ((ret = sm2_z256_point_from_x_bytes(R, x_bytes, y_is_odd)) == 1) { + break; + } + if (ret < 0) { + error_print(); + return -1; + } + + // data = sm3(data), try again + data = dgst; + datalen = sizeof(dgst); + + } while (1); + + return 1; +} diff --git a/src/sm2_z256_key.c b/src/sm2_z256_key.c new file mode 100644 index 00000000..c2859ff7 --- /dev/null +++ b/src/sm2_z256_key.c @@ -0,0 +1,706 @@ +/* + * Copyright 2014-2024 The GmSSL Project. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +int sm2_key_generate(SM2_KEY *key) +{ + uint64_t d[4]; + SM2_Z256_POINT P; + + if (!key) { + error_print(); + return -1; + } + + do { + if (sm2_z256_rand_range(d, sm2_z256_order_minus_one()) != 1) { + error_print(); + return -1; + } + } while (sm2_z256_is_zero(d)); + + sm2_z256_point_mul_generator(&P, d); + + sm2_z256_to_bytes(d, key->private_key); + sm2_z256_point_to_bytes(&P, (uint8_t *)&key->public_key); + + gmssl_secure_clear(d, sizeof(d)); + return 1; +} + +int sm2_key_set_private_key(SM2_KEY *key, const uint8_t private_key[32]) +{ + uint64_t d[4]; + SM2_Z256_POINT P; + int ret = -1; + + if (!key || !private_key) { + error_print(); + return -1; + } + + sm2_z256_from_bytes(d, private_key); + + if (sm2_z256_is_zero(d)) { + error_print(); + goto end; + } + if (sm2_z256_cmp(d, sm2_z256_order_minus_one()) >= 0) { + error_print(); + goto end; + } + + sm2_z256_point_mul_generator(&P, d); + + sm2_z256_to_bytes(d, key->private_key); + sm2_z256_point_to_bytes(&P, (uint8_t *)&key->public_key); + + ret = 1; +end: + gmssl_secure_clear(d, sizeof(d)); + return ret; +} + +int sm2_key_set_public_key(SM2_KEY *key, const SM2_POINT *public_key) +{ + uint64_t d[4] = {0}; + SM2_Z256_POINT P; + + if (!key || !public_key) { + error_print(); + return -1; + } + + sm2_z256_point_from_bytes(&P, (uint8_t *)public_key); + if (sm2_z256_point_is_on_curve(&P) != 1) { + error_print(); + return -1; + } + + sm2_z256_to_bytes(d, key->private_key); + sm2_z256_point_to_bytes(&P, (uint8_t *)&key->public_key); + + return 1; +} + +int sm2_key_print(FILE *fp, int fmt, int ind, const char *label, const SM2_KEY *key) +{ + format_print(fp, fmt, ind, "%s\n", label); + ind += 4; + sm2_public_key_print(fp, fmt, ind, "publicKey", key); + format_bytes(fp, fmt, ind, "privateKey", key->private_key, 32); + return 1; +} + + +int sm2_public_key_to_der(const SM2_KEY *key, uint8_t **out, size_t *outlen) +{ + uint8_t buf[65]; + size_t len = 0; + + if (!key) { + return 0; + } + sm2_point_to_uncompressed_octets(&key->public_key, buf); + if (asn1_bit_octets_to_der(buf, sizeof(buf), out, outlen) != 1) { + error_print(); + return -1; + } + return 1; +} + +int sm2_public_key_from_der(SM2_KEY *key, const uint8_t **in, size_t *inlen) +{ + int ret; + const uint8_t *d; + size_t dlen; + SM2_POINT P; + + if ((ret = asn1_bit_octets_from_der(&d, &dlen, in, inlen)) != 1) { + if (ret < 0) error_print(); + return ret; + } + if (dlen != 65) { + error_print(); + return -1; + } + + // 这里不太对,SM2_POINT 被反复检查了 + if (sm2_point_from_octets(&P, d, dlen) != 1) { + error_print(); + return -1; + } + if (sm2_key_set_public_key(key, &P) != 1) { + error_print(); + return -1; + } + return 1; +} + +int sm2_public_key_print(FILE *fp, int fmt, int ind, const char *label, const SM2_KEY *pub_key) +{ + return sm2_point_print(fp, fmt, ind, label, &pub_key->public_key); +} + +int sm2_public_key_algor_to_der(uint8_t **out, size_t *outlen) +{ + if (x509_public_key_algor_to_der(OID_ec_public_key, OID_sm2, out, outlen) != 1) { + error_print(); + return -1; + } + return 1; +} + +int sm2_public_key_algor_from_der(const uint8_t **in, size_t *inlen) +{ + int ret; + int oid; + int curve; + + if ((ret = x509_public_key_algor_from_der(&oid, &curve, in, inlen)) != 1) { + if (ret < 0) error_print(); + return ret; + } + if (oid != OID_ec_public_key) { + error_print(); + return -1; + } + if (curve != OID_sm2) { + error_print(); + return -1; + } + return 1; +} + +#define SM2_PRIVATE_KEY_DER_SIZE 121 +int sm2_private_key_to_der(const SM2_KEY *key, uint8_t **out, size_t *outlen) +{ + size_t len = 0; + uint8_t params[64]; + uint8_t pubkey[128]; + uint8_t *params_ptr = params; + uint8_t *pubkey_ptr = pubkey; + size_t params_len = 0; + size_t pubkey_len = 0; + + if (!key) { + error_print(); + return -1; + } + if (ec_named_curve_to_der(OID_sm2, ¶ms_ptr, ¶ms_len) != 1 + || sm2_public_key_to_der(key, &pubkey_ptr, &pubkey_len) != 1) { + error_print(); + return -1; + } + if (asn1_int_to_der(EC_private_key_version, NULL, &len) != 1 + || asn1_octet_string_to_der(key->private_key, 32, NULL, &len) != 1 + || asn1_explicit_to_der(0, params, params_len, NULL, &len) != 1 + || asn1_explicit_to_der(1, pubkey, pubkey_len, NULL, &len) != 1 + || asn1_sequence_header_to_der(len, out, outlen) != 1 + || asn1_int_to_der(EC_private_key_version, out, outlen) != 1 + || asn1_octet_string_to_der(key->private_key, 32, out, outlen) != 1 + || asn1_explicit_to_der(0, params, params_len, out, outlen) != 1 + || asn1_explicit_to_der(1, pubkey, pubkey_len, out, outlen) != 1) { + error_print(); + return -1; + } + return 1; +} + +int sm2_private_key_from_der(SM2_KEY *key, const uint8_t **in, size_t *inlen) +{ + int ret; + const uint8_t *d; + size_t dlen; + int ver; + const uint8_t *prikey; + const uint8_t *params; + const uint8_t *pubkey; + size_t prikey_len, params_len, pubkey_len; + + if ((ret = asn1_sequence_from_der(&d, &dlen, in, inlen)) != 1) { + if (ret < 0) error_print(); + return ret; + } + if (asn1_int_from_der(&ver, &d, &dlen) != 1 + || asn1_octet_string_from_der(&prikey, &prikey_len, &d, &dlen) != 1 + || asn1_explicit_from_der(0, ¶ms, ¶ms_len, &d, &dlen) != 1 + || asn1_explicit_from_der(1, &pubkey, &pubkey_len, &d, &dlen) != 1 + || asn1_check(ver == EC_private_key_version) != 1 + || asn1_length_is_zero(dlen) != 1) { + error_print(); + return -1; + } + if (params) { + int curve; + if (ec_named_curve_from_der(&curve, ¶ms, ¶ms_len) != 1 + || asn1_check(curve == OID_sm2) != 1 + || asn1_length_is_zero(params_len) != 1) { + error_print(); + return -1; + } + } + if (asn1_check(prikey_len == 32) != 1 + || sm2_key_set_private_key(key, prikey) != 1) { + error_print(); + return -1; + } + + // check if the public key is correct + if (pubkey) { + SM2_KEY tmp_key; + if (sm2_public_key_from_der(&tmp_key, &pubkey, &pubkey_len) != 1 + || asn1_length_is_zero(pubkey_len) != 1) { + error_print(); + return -1; + } + if (sm2_public_key_equ(key, &tmp_key) != 1) { + error_print(); + return -1; + } + } + return 1; +} + +int sm2_private_key_print(FILE *fp, int fmt, int ind, const char *label, const uint8_t *d, size_t dlen) +{ + return ec_private_key_print(fp, fmt, ind, label, d, dlen); +} + +#define SM2_PRIVATE_KEY_INFO_DER_SIZE 150 + +int sm2_private_key_info_to_der(const SM2_KEY *sm2_key, uint8_t **out, size_t *outlen) +{ + size_t len = 0; + uint8_t prikey[SM2_PRIVATE_KEY_DER_SIZE]; + uint8_t *p = prikey; + size_t prikey_len = 0; + + if (sm2_private_key_to_der(sm2_key, &p, &prikey_len) != 1) { + error_print(); + return -1; + } + if (asn1_int_to_der(PKCS8_private_key_info_version, NULL, &len) != 1 + || sm2_public_key_algor_to_der(NULL, &len) != 1 + || asn1_octet_string_to_der(prikey, prikey_len, NULL, &len) != 1 + || asn1_sequence_header_to_der(len, out, outlen) != 1 + || asn1_int_to_der(PKCS8_private_key_info_version, out, outlen) != 1 + || sm2_public_key_algor_to_der(out, outlen) != 1 + || asn1_octet_string_to_der(prikey, prikey_len, out, outlen) != 1) { + memset(prikey, 0, sizeof(prikey)); + error_print(); + return -1; + } + memset(prikey, 0, sizeof(prikey)); + return 1; +} + +int sm2_private_key_info_from_der(SM2_KEY *sm2_key, const uint8_t **attrs, size_t *attrslen, + const uint8_t **in, size_t *inlen) +{ + int ret; + const uint8_t *d; + size_t dlen; + int version; + const uint8_t *prikey; + size_t prikey_len; + + if ((ret = asn1_sequence_from_der(&d, &dlen, in, inlen)) != 1) { + if (ret < 0) error_print(); + return ret; + } + if (asn1_int_from_der(&version, &d, &dlen) != 1 + || sm2_public_key_algor_from_der(&d, &dlen) != 1 + || asn1_octet_string_from_der(&prikey, &prikey_len, &d, &dlen) != 1 + || asn1_implicit_set_from_der(0, attrs, attrslen, &d, &dlen) < 0 + || asn1_length_is_zero(dlen) != 1) { + error_print(); + return -1; + } + if (asn1_check(version == PKCS8_private_key_info_version) != 1 + || sm2_private_key_from_der(sm2_key, &prikey, &prikey_len) != 1 + || asn1_length_is_zero(prikey_len) != 1) { + error_print(); + return -1; + } + return 1; +} + +int sm2_private_key_info_print(FILE *fp, int fmt, int ind, const char *label, const uint8_t *d, size_t dlen) +{ + int ret; + const uint8_t *p; + size_t len; + int val; + const uint8_t *prikey; + size_t prikey_len; + + format_print(fp, fmt, ind, "%s\n", label); + ind += 4; + + if (asn1_int_from_der(&val, &d, &dlen) != 1) goto err; + format_print(fp, fmt, ind, "version: %d\n", val); + if (asn1_sequence_from_der(&p, &len, &d, &dlen) != 1) goto err; + x509_public_key_algor_print(fp, fmt, ind, "privateKeyAlgorithm", p, len); + if (asn1_octet_string_from_der(&p, &len, &d, &dlen) != 1) goto err; + if (asn1_sequence_from_der(&prikey, &prikey_len, &p, &len) != 1) goto err; + ec_private_key_print(fp, fmt, ind + 4, "privateKey", prikey, prikey_len); + if (asn1_length_is_zero(len) != 1) goto err; + if ((ret = asn1_implicit_set_from_der(0, &p, &len, &d, &dlen)) < 0) goto err; + else if (ret) format_bytes(fp, fmt, ind, "attributes", p, len); + if (asn1_length_is_zero(dlen) != 1) goto err; + return 1; +err: + error_print(); + return -1; +} + +#ifdef ENABLE_SM2_PRIVATE_KEY_EXPORT +int sm2_private_key_info_to_pem(const SM2_KEY *key, FILE *fp) +{ + int ret = -1; + uint8_t buf[SM2_PRIVATE_KEY_INFO_DER_SIZE]; + uint8_t *p = buf; + size_t len = 0; + + if (!key || !fp) { + error_print(); + return -1; + } + if (sm2_private_key_info_to_der(key, &p, &len) != 1) { + error_print(); + goto end; + } + if (len != sizeof(buf)) { + error_print(); + goto end; + } + if (pem_write(fp, "PRIVATE KEY", buf, len) != 1) { + error_print(); + goto end; + } + ret = 1; +end: + gmssl_secure_clear(buf, sizeof(buf)); + return ret; +} + +int sm2_private_key_info_from_pem(SM2_KEY *sm2_key, FILE *fp) +{ + uint8_t buf[512]; + const uint8_t *cp = buf; + size_t len; + const uint8_t *attrs; + size_t attrs_len; + + if (pem_read(fp, "PRIVATE KEY", buf, &len, sizeof(buf)) != 1 + || sm2_private_key_info_from_der(sm2_key, &attrs, &attrs_len, &cp, &len) != 1 + || asn1_length_is_zero(len) != 1) { + error_print(); + return -1; + } + if (attrs_len) { + error_print(); + } + return 1; +} +#endif + +int sm2_public_key_info_to_der(const SM2_KEY *pub_key, uint8_t **out, size_t *outlen) +{ + size_t len = 0; + if (sm2_public_key_algor_to_der(NULL, &len) != 1 + || sm2_public_key_to_der(pub_key, NULL, &len) != 1 + || asn1_sequence_header_to_der(len, out, outlen) != 1 + || sm2_public_key_algor_to_der(out, outlen) != 1 + || sm2_public_key_to_der(pub_key, out, outlen) != 1) { + error_print(); + return -1; + } + return 1; +} + +int sm2_public_key_info_from_der(SM2_KEY *pub_key, const uint8_t **in, size_t *inlen) +{ + int ret; + const uint8_t *d; + size_t dlen; + + if ((ret = asn1_sequence_from_der(&d, &dlen, in, inlen)) != 1) { + if (ret < 0) error_print(); + return ret; + } + if (sm2_public_key_algor_from_der(&d, &dlen) != 1 + || sm2_public_key_from_der(pub_key, &d, &dlen) != 1 + || asn1_length_is_zero(dlen) != 1) { + error_print(); + return -1; + } + return 1; +} + +#ifdef ENABLE_SM2_PRIVATE_KEY_EXPORT + +// FIXME: side-channel of Base64 +int sm2_private_key_to_pem(const SM2_KEY *a, FILE *fp) +{ + uint8_t buf[512]; + uint8_t *p = buf; + size_t len = 0; + + if (sm2_private_key_to_der(a, &p, &len) != 1) { + error_print(); + return -1; + } + if (pem_write(fp, "EC PRIVATE KEY", buf, len) <= 0) { + error_print(); + return -1; + } + return 1; +} + +int sm2_private_key_from_pem(SM2_KEY *a, FILE *fp) +{ + uint8_t buf[512]; + const uint8_t *cp = buf; + size_t len; + + if (pem_read(fp, "EC PRIVATE KEY", buf, &len, sizeof(buf)) != 1) { + error_print(); + return -1; + } + if (sm2_private_key_from_der(a, &cp, &len) != 1 + || len > 0) { + error_print(); + return -1; + } + return 1; +} +#endif + +int sm2_public_key_info_to_pem(const SM2_KEY *a, FILE *fp) +{ + uint8_t buf[512]; + uint8_t *p = buf; + size_t len = 0; + + if (sm2_public_key_info_to_der(a, &p, &len) != 1) { + error_print(); + return -1; + } + if (pem_write(fp, "PUBLIC KEY", buf, len) <= 0) { + error_print(); + return -1; + } + return 1; +} + +int sm2_public_key_info_from_pem(SM2_KEY *a, FILE *fp) +{ + uint8_t buf[512]; + const uint8_t *cp = buf; + size_t len; + + if (pem_read(fp, "PUBLIC KEY", buf, &len, sizeof(buf)) != 1) { + error_print(); + return -1; + } + if (sm2_public_key_info_from_der(a, &cp, &len) != 1 + || asn1_length_is_zero(len) != 1) { + error_print(); + return -1; + } + return 1; +} + +int sm2_public_key_equ(const SM2_KEY *sm2_key, const SM2_KEY *pub_key) +{ + if (memcmp(sm2_key, pub_key, sizeof(SM2_POINT)) == 0) { + return 1; + } + return 0; +} + +int sm2_public_key_copy(SM2_KEY *sm2_key, const SM2_KEY *pub_key) +{ + return sm2_key_set_public_key(sm2_key, &pub_key->public_key); +} + +int sm2_public_key_digest(const SM2_KEY *sm2_key, uint8_t dgst[32]) +{ + uint8_t bits[65]; + sm2_point_to_uncompressed_octets(&sm2_key->public_key, bits); + sm3_digest(bits, sizeof(bits), dgst); + return 1; +} + +int sm2_private_key_info_encrypt_to_der(const SM2_KEY *sm2_key, const char *pass, + uint8_t **out, size_t *outlen) +{ + int ret = -1; + uint8_t pkey_info[SM2_PRIVATE_KEY_INFO_DER_SIZE]; + uint8_t *p = pkey_info; + size_t pkey_info_len = 0; + uint8_t salt[16]; + int iter = 65536; + uint8_t iv[16]; + uint8_t key[16]; + SM4_KEY sm4_key; + uint8_t enced_pkey_info[sizeof(pkey_info) + 32]; + size_t enced_pkey_info_len; + + if (!sm2_key || !pass || !outlen) { + error_print(); + return -1; + } + if (sm2_private_key_info_to_der(sm2_key, &p, &pkey_info_len) != 1 + || rand_bytes(salt, sizeof(salt)) != 1 + || rand_bytes(iv, sizeof(iv)) != 1 + || pbkdf2_genkey(DIGEST_sm3(), pass, strlen(pass), + salt, sizeof(salt), iter, sizeof(key), key) != 1) { + error_print(); + goto end; + } + /* + if (pkey_info_len != sizeof(pkey_info)) { + error_print(); + goto end; + } + */ + sm4_set_encrypt_key(&sm4_key, key); + if (sm4_cbc_padding_encrypt( + &sm4_key, iv, pkey_info, pkey_info_len, + enced_pkey_info, &enced_pkey_info_len) != 1 + || pkcs8_enced_private_key_info_to_der( + salt, sizeof(salt), iter, sizeof(key), OID_hmac_sm3, + OID_sm4_cbc, iv, sizeof(iv), + enced_pkey_info, enced_pkey_info_len, out, outlen) != 1) { + error_print(); + goto end; + } + + ret = 1; +end: + gmssl_secure_clear(pkey_info, sizeof(pkey_info)); + gmssl_secure_clear(key, sizeof(key)); + gmssl_secure_clear(&sm4_key, sizeof(sm4_key)); + return ret; +} + +int sm2_private_key_info_decrypt_from_der(SM2_KEY *sm2, + const uint8_t **attrs, size_t *attrs_len, + const char *pass, const uint8_t **in, size_t *inlen) +{ + int ret = -1; + const uint8_t *salt; + size_t saltlen; + int iter; + int keylen; + int prf; + int cipher; + const uint8_t *iv; + size_t ivlen; + uint8_t key[16]; + SM4_KEY sm4_key; + const uint8_t *enced_pkey_info; + size_t enced_pkey_info_len; + uint8_t pkey_info[256]; + const uint8_t *cp = pkey_info; + size_t pkey_info_len; + + if (!sm2 || !attrs || !attrs_len || !pass || !in || !(*in) || !inlen) { + error_print(); + return -1; + } + if (pkcs8_enced_private_key_info_from_der(&salt, &saltlen, &iter, &keylen, &prf, + &cipher, &iv, &ivlen, &enced_pkey_info, &enced_pkey_info_len, in, inlen) != 1 + || asn1_check(keylen == -1 || keylen == 16) != 1 + || asn1_check(prf == - 1 || prf == OID_hmac_sm3) != 1 + || asn1_check(cipher == OID_sm4_cbc) != 1 + || asn1_check(ivlen == 16) != 1 + || asn1_length_le(enced_pkey_info_len, sizeof(pkey_info)) != 1) { + error_print(); + return -1; + } + if (pbkdf2_genkey(DIGEST_sm3(), pass, strlen(pass), salt, saltlen, iter, sizeof(key), key) != 1) { + error_print(); + goto end; + } + sm4_set_decrypt_key(&sm4_key, key); + if (sm4_cbc_padding_decrypt(&sm4_key, iv, enced_pkey_info, enced_pkey_info_len, + pkey_info, &pkey_info_len) != 1 + || sm2_private_key_info_from_der(sm2, attrs, attrs_len, &cp, &pkey_info_len) != 1 + || asn1_length_is_zero(pkey_info_len) != 1) { + error_print(); + goto end; + } + ret = 1; +end: + gmssl_secure_clear(&sm4_key, sizeof(sm4_key)); + gmssl_secure_clear(key, sizeof(key)); + gmssl_secure_clear(pkey_info, sizeof(pkey_info)); + return ret; +} + +int sm2_private_key_info_encrypt_to_pem(const SM2_KEY *sm2_key, const char *pass, FILE *fp) +{ + uint8_t buf[1024]; + uint8_t *p = buf; + size_t len = 0; + + if (!fp) { + error_print(); + return -1; + } + if (sm2_private_key_info_encrypt_to_der(sm2_key, pass, &p, &len) != 1) { + error_print(); + return -1; + } + if (pem_write(fp, "ENCRYPTED PRIVATE KEY", buf, len) != 1) { + error_print(); + return -1; + } + return 1; +} + +int sm2_private_key_info_decrypt_from_pem(SM2_KEY *key, const char *pass, FILE *fp) +{ + uint8_t buf[512]; + const uint8_t *cp = buf; + size_t len; + const uint8_t *attrs; + size_t attrs_len; + + if (!key || !pass || !fp) { + error_print(); + return -1; + } + if (pem_read(fp, "ENCRYPTED PRIVATE KEY", buf, &len, sizeof(buf)) != 1 + || sm2_private_key_info_decrypt_from_der(key, &attrs, &attrs_len, pass, &cp, &len) != 1 + || asn1_length_is_zero(len) != 1) { + error_print(); + return -1; + } + return 1; +} diff --git a/src/sm2_z256_sign.c b/src/sm2_z256_sign.c index 49d1d488..d413c236 100644 --- a/src/sm2_z256_sign.c +++ b/src/sm2_z256_sign.c @@ -20,7 +20,6 @@ #include - typedef SM2_Z256 SM2_U256; #define sm2_u256_one() sm2_z256_one() @@ -52,7 +51,6 @@ typedef SM2_Z256_POINT SM2_U256_POINT; #define sm2_u256_point_get_xy(P,x,y) sm2_z256_point_get_xy((P),(x),(y)) - int sm2_do_sign(const SM2_KEY *key, const uint8_t dgst[32], SM2_SIGNATURE *sig) { SM2_U256_POINT _P, *P = &_P; @@ -82,6 +80,10 @@ int sm2_do_sign(const SM2_KEY *key, const uint8_t dgst[32], SM2_SIGNATURE *sig) sm2_u256_from_bytes(e, dgst); //sm2_bn_print(stderr, 0, 4, "e", e); retry: + + // >>>>>>>>>> BEGIN PRECOMP + + // rand k in [1, n - 1] do { if (sm2_u256_modn_rand(k) != 1) { @@ -96,6 +98,11 @@ retry: //sm2_bn_print(stderr, 0, 4, "x", x); + // 如果我们提前计算了 (k, x) 那么我们在真正做签名的时候就可以利用到这个与计算的表了,直接从表中读取 (k, x) + // 当然这些计算都可以放在sign_fast里面 + + // >>>>>>>>>>> END PRECOMP + // r = e + x (mod n) if (sm2_u256_cmp(e, order) >= 0) { sm2_u256_sub(e, e, order); @@ -132,13 +139,65 @@ retry: return 1; } +// k 和 x1 都是要参与计算的,因此我们返回的是内部格式 +int sm2_do_sign_pre_compute(uint64_t k[4], uint64_t x1[4]) +{ + SM2_Z256_POINT P; + + // rand k in [1, n - 1] + do { + if (sm2_z256_modn_rand(k) != 1) { + error_print(); + return -1; + } + } while (sm2_z256_is_zero(k)); + + // (x1, y1) = kG + sm2_u256_point_mul_generator(&P, k); // 这个函数要粗力度并行,这要怎么做? + sm2_u256_point_get_xy(&P, x1, NULL); + + return 1; +} + +// 实际上这里只有一次mod n的乘法,用barret就可以了 +int sm2_do_sign_fast_ex(const uint64_t d[4], const uint64_t k[4], const uint64_t x1[4], const uint8_t dgst[32], SM2_SIGNATURE *sig) +{ + SM2_Z256_POINT R; + uint64_t e[4]; + uint64_t r[4]; + uint64_t s[4]; + + const uint64_t *order = sm2_z256_order(); + + // e = H(M) + sm2_z256_from_bytes(e, dgst); + if (sm2_z256_cmp(e, order) >= 0) { + sm2_z256_sub(e, e, order); + } + + // r = e + x1 (mod n) + sm2_z256_modn_add(r, e, x1); + + // s = (k + r) * d' - r + sm2_z256_modn_add(s, k, r); + sm2_z256_modn_mul(s, s, d); + sm2_z256_modn_sub(s, s, r); + + sm2_u256_to_bytes(r, sig->r); + sm2_u256_to_bytes(s, sig->s); + + return 1; +} + + // (x1, y1) = k * G // r = e + x1 // s = (k - r * d)/(1 + d) = (k +r - r * d - r)/(1 + d) = (k + r - r(1 +d))/(1 + d) = (k + r)/(1 + d) - r // = -r + (k + r)*(1 + d)^-1 // = -r + (k + r) * d' -int sm2_do_sign_fast(const SM2_Fn d, const uint8_t dgst[32], SM2_SIGNATURE *sig) +// 这个函数是我们真正要调用的,甚至可以替代原来的函数 +int sm2_do_sign_fast(const uint64_t d[4], const uint8_t dgst[32], SM2_SIGNATURE *sig) { SM2_U256_POINT R; SM2_U256 e; @@ -155,6 +214,8 @@ int sm2_do_sign_fast(const SM2_Fn d, const uint8_t dgst[32], SM2_SIGNATURE *sig) sm2_u256_sub(e, e, order); } + /// <<<<<<<<<<< 这里的 (k, x1) 应该是从外部输入的!!,这样才是最快的。 + // rand k in [1, n - 1] do { if (sm2_u256_modn_rand(k) != 1) { @@ -164,16 +225,22 @@ int sm2_do_sign_fast(const SM2_Fn d, const uint8_t dgst[32], SM2_SIGNATURE *sig) } while (sm2_u256_is_zero(k)); // (x1, y1) = kG - sm2_u256_point_mul_generator(&R, k); + sm2_u256_point_mul_generator(&R, k); // 这个函数要粗力度并行,这要怎么做? sm2_u256_point_get_xy(&R, x1, NULL); + /// >>>>>>>>>>>>>>>>>> + // r = e + x1 (mod n) sm2_u256_modn_add(r, e, x1); // 对于快速实现来说,只需要一次乘法 + // 如果 (k, x) 是预计算的,这意味着我们可以并行这个操作 + // 也就是随机产生一些k,然后执行粗力度并行的点乘 + + // s = (k + r) * d' - r - sm2_u256_add(s, k, r); + sm2_u256_modn_add(s, k, r); sm2_u256_modn_mul(s, s, d); sm2_u256_modn_sub(s, s, r); @@ -182,6 +249,72 @@ int sm2_do_sign_fast(const SM2_Fn d, const uint8_t dgst[32], SM2_SIGNATURE *sig) return 1; } +// 这个其实并没有更快,无非就是降低了解析公钥椭圆曲线点的计算量,这个点要转换为内部的Mont格式 +// 这里根本没有modn的乘法 +int sm2_do_verify_fast(const SM2_Z256_POINT *P, const uint8_t dgst[32], const SM2_SIGNATURE *sig) +{ + SM2_U256_POINT R; + SM2_U256 r; + SM2_U256 s; + SM2_U256 e; + SM2_U256 x; + SM2_U256 t; + + const uint64_t *order = sm2_u256_order(); + + sm2_u256_from_bytes(r, sig->r); + // check r in [1, n-1] + if (sm2_u256_is_zero(r) == 1) { + error_print(); + return -1; + } + if (sm2_u256_cmp(r, order) >= 0) { + error_print(); + return -1; + } + + sm2_u256_from_bytes(s, sig->s); + // check s in [1, n-1] + if (sm2_u256_is_zero(s) == 1) { + error_print(); + return -1; + } + if (sm2_u256_cmp(s, order) >= 0) { + error_print(); + return -1; + } + + // e = H(M) + sm2_u256_from_bytes(e, dgst); + + // t = r + s (mod n), check t != 0 + sm2_u256_modn_add(t, r, s); + if (sm2_u256_is_zero(t)) { + error_print(); + return -1; + } + + // Q = s * G + t * P + sm2_u256_point_mul_sum(&R, t, P, s); + sm2_u256_point_get_xy(&R, x, NULL); + + // r' = e + x (mod n) + if (sm2_u256_cmp(e, order) >= 0) { + sm2_u256_sub(e, e, order); + } + if (sm2_u256_cmp(x, order) >= 0) { + sm2_u256_sub(x, x, order); + } + sm2_u256_modn_add(e, e, x); + + // check if r == r' + if (sm2_u256_cmp(e, r) != 0) { + error_print(); + return -1; + } + return 1; +} + int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATURE *sig) { SM2_U256_POINT _P, *P = &_P; @@ -277,6 +410,27 @@ static int all_zero(const uint8_t *buf, size_t len) return 1; } +int sm2_do_encrypt_pre_compute(uint64_t k[4], uint8_t C1[64]) +{ + SM2_Z256_POINT P; + + // rand k in [1, n - 1] + do { + if (sm2_z256_modn_rand(k) != 1) { + error_print(); + return -1; + } + } while (sm2_z256_is_zero(k)); + + // output C1 = k * G = (x1, y1) + sm2_z256_point_mul_generator(&P, k); + sm2_z256_point_to_bytes(&P, C1); + + return 1; +} + +// 和签名不一样,加密的时候要生成 (k, (x1, y1)) ,也就是y坐标也是需要的 +// 其中k是要参与计算的,但是 (x1, y1) 不参与计算,输出为 bytes 就可以了 int sm2_do_encrypt(const SM2_KEY *key, const uint8_t *in, size_t inlen, SM2_CIPHERTEXT *out) { SM2_U256 k; diff --git a/tests/sm2_enctest.c b/tests/sm2_enctest.c index c4c64851..088fdb5f 100644 --- a/tests/sm2_enctest.c +++ b/tests/sm2_enctest.c @@ -16,88 +16,52 @@ #include #include - - -// 由于当前Ciphertext中椭圆曲线点数据不正确,因此无法通过测试 +// 应该整理出不同编码长度的椭圆曲线点,可以由x求出y static int test_sm2_ciphertext(void) { + struct { + char *label; + size_t ciphertext_size; + } tests[] = { + { "null ciphertext", 0 }, + { "min ciphertext size", SM2_MIN_PLAINTEXT_SIZE }, + { "max ciphertext size", SM2_MAX_PLAINTEXT_SIZE }, + }; + SM2_CIPHERTEXT C; + SM2_KEY sm2_key; uint8_t buf[1024]; - uint8_t *p = buf; - const uint8_t *cp = buf; - size_t len = 0; + size_t i; - memset(&C, 0, sizeof(SM2_CIPHERTEXT)); + rand_bytes(C.hash, 32); + rand_bytes(C.ciphertext, SM2_MAX_PLAINTEXT_SIZE); - cp = p = buf; len = 0; - if (sm2_ciphertext_to_der(&C, &p, &len) != 1) { - error_print(); - return -1; - } - format_print(stderr, 0, 4, "SM2_NULL_CIPHERTEXT_SIZE: %zu\n", len); - format_bytes(stderr, 0, 4, "", buf, len); + for (i = 0; i < sizeof(tests)/sizeof(tests[0]); i++) { + uint8_t *p = buf; + const uint8_t *cp = buf; + size_t len = 0; - if (sm2_ciphertext_from_der(&C, &cp, &len) != 1 - || asn1_length_is_zero(len) != 1) { - error_print(); - return -1; - } + if (sm2_key_generate(&sm2_key) != 1) { + error_print(); + return -1; + } + C.point = sm2_key.public_key; + C.ciphertext_size = tests[i].ciphertext_size; + if (sm2_ciphertext_to_der(&C, &p, &len) != 1) { + error_print(); + return -1; + } - // {0, 0, Hash, MinLen} - C.ciphertext_size = SM2_MIN_PLAINTEXT_SIZE; - cp = p = buf; len = 0; - if (sm2_ciphertext_to_der(&C, &p, &len) != 1) { - error_print(); - return -1; - } - format_print(stderr, 0, 4, "SM2_MIN_PLAINTEXT_SIZE: %zu\n", SM2_MIN_PLAINTEXT_SIZE); - format_print(stderr, 0, 4, "SM2_MIN_CIPHERTEXT_SIZE: %zu\n", len); - format_bytes(stderr, 0, 4, "", buf, len); - if (len != SM2_MIN_CIPHERTEXT_SIZE) { - error_print(); - return -1; - } - if (sm2_ciphertext_from_der(&C, &cp, &len) != 1 - || asn1_length_is_zero(len) != 1) { - error_print(); - return -1; - } + printf("Plaintext size = %zu, SM2Ciphertext DER size %zu\n", tests[i].ciphertext_size, len); - // { 33, 33, Hash, NULL } - memset(&C, 0x80, sizeof(SM2_POINT)); - cp = p = buf; len = 0; - if (sm2_ciphertext_to_der(&C, &p, &len) != 1) { - error_print(); - return -1; - } - format_print(stderr, 0, 4, "ciphertext len: %zu\n", len); - format_bytes(stderr, 0, 4, "", buf, len); - if (sm2_ciphertext_from_der(&C, &cp, &len) != 1 - || asn1_length_is_zero(len) != 1) { - error_print(); - return -1; - } + if (sm2_ciphertext_from_der(&C, &cp, &len) != 1 + || asn1_length_is_zero(len) != 1) { + error_print(); + return -1; + } - // { 33, 33, Hash, MaxLen } - C.ciphertext_size = SM2_MAX_PLAINTEXT_SIZE;//SM2_MAX_PLAINTEXT_SIZE; - cp = p = buf; len = 0; - if (sm2_ciphertext_to_der(&C, &p, &len) != 1) { - error_print(); - return -1; - } - format_print(stderr, 0, 4, "SM2_MAX_PLAINTEXT_SIZE: %zu\n", SM2_MAX_PLAINTEXT_SIZE); - format_print(stderr, 0, 4, "SM2_MAX_CIPHERTEXT_SIZE: %zu\n", len); - format_bytes(stderr, 0, 4, "", buf, len); - if (len != SM2_MAX_CIPHERTEXT_SIZE) { - error_print(); - return -1; - } - if (sm2_ciphertext_from_der(&C, &cp, &len) != 1 - || asn1_length_is_zero(len) != 1) { - error_print(); - return -1; } printf("%s() ok\n", __FUNCTION__); @@ -265,14 +229,6 @@ static int test_sm2_encrypt_fixlen(void) } - -// 应该生成不同情况下的密文! - - - - - - static int test_sm2_encrypt(void) { SM2_KEY sm2_key; @@ -327,7 +283,7 @@ static int test_sm2_encrypt(void) int main(void) { - //if (test_sm2_ciphertext() != 1) goto err; // 需要正确的Ciphertext数据 + if (test_sm2_ciphertext() != 1) goto err; if (test_sm2_do_encrypt() != 1) goto err; if (test_sm2_do_encrypt_fixlen() != 1) goto err; if (test_sm2_encrypt() != 1) goto err; diff --git a/tests/sm2_signtest.c b/tests/sm2_signtest.c index 1a0d10b5..5a2005ea 100644 --- a/tests/sm2_signtest.c +++ b/tests/sm2_signtest.c @@ -109,6 +109,7 @@ static int test_sm2_do_sign(void) #define sm2_u256_modn_add sm2_z256_modn_add #define sm2_u256_modn_inv sm2_z256_modn_inv + static int test_sm2_do_sign_fast(void) { SM2_KEY sm2_key; @@ -141,6 +142,45 @@ static int test_sm2_do_sign_fast(void) return 1; } +static int test_sm2_do_sign_pre_compute(void) +{ + SM2_KEY sm2_key; + uint64_t d[4]; + + uint64_t k[4]; + uint64_t x1[4]; + uint8_t dgst[32]; + SM2_SIGNATURE sig; + + + sm2_key_generate(&sm2_key); + + const uint64_t *one = sm2_z256_one(); + sm2_z256_from_bytes(d, sm2_key.private_key); + sm2_z256_modn_add(d, d, one); + sm2_z256_modn_inv(d, d); + + if (sm2_do_sign_pre_compute(k, x1) != 1) { + error_print(); + return -1; + } + + rand_bytes(dgst, sizeof(dgst)); + + if (sm2_do_sign_fast_ex(d, k, x1, dgst, &sig) != 1) { + error_print(); + return -1; + } + + if (sm2_do_verify(&sm2_key, dgst, &sig) != 1) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + static int test_sm2_sign(void) { SM2_KEY sm2_key; @@ -209,12 +249,95 @@ static int test_sm2_sign_ctx(void) return 1; } +static int test_sm2_sign_ctx_reset(void) +{ + SM2_KEY sm2_key; + SM2_SIGN_CTX sign_ctx; + SM2_SIGN_CTX vrfy_ctx; + uint8_t msg[64]; + uint8_t sig[SM2_MAX_SIGNATURE_SIZE]; + size_t siglen; + + if (sm2_key_generate(&sm2_key) != 1) { + error_print(); + return -1; + } + + // init sign_ctx and sign a message + rand_bytes(msg, sizeof(msg)); + + if (sm2_sign_init(&sign_ctx, &sm2_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1) { + error_print(); + return -1; + } + if (sm2_sign_update(&sign_ctx, msg, sizeof(msg)) != 1) { + error_print(); + return -1; + } + if (sm2_sign_finish(&sign_ctx, sig, &siglen) != 1) { + error_print(); + return -1; + } + + // reset sign_ctx and sign another message + rand_bytes(msg, sizeof(msg)); + + if (sm2_sign_ctx_reset(&sign_ctx) != 1) { + error_print(); + return -1; + } + if (sm2_sign_update(&sign_ctx, msg, sizeof(msg)) != 1) { + error_print(); + return -1; + } + if (sm2_sign_finish(&sign_ctx, sig, &siglen) != 1) { + error_print(); + return -1; + } + + // verify, check whether reset works + if (sm2_verify_init(&vrfy_ctx, &sm2_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1) { + error_print(); + return -1; + } + if (sm2_verify_update(&vrfy_ctx, msg, sizeof(msg)) != 1) { + error_print(); + return -1; + } + if (sm2_verify_finish(&vrfy_ctx, sig, siglen) != 1) { + format_bytes(stderr, 0, 4, "signature", sig, siglen); + error_print(); + return -1; + } + + // reset ctx and verify again + if (sm2_sign_ctx_reset(&vrfy_ctx) != 1) { + error_print(); + return -1; + } + if (sm2_verify_update(&vrfy_ctx, msg, sizeof(msg)) != 1) { + error_print(); + return -1; + } + if (sm2_verify_finish(&vrfy_ctx, sig, siglen) != 1) { + format_bytes(stderr, 0, 4, "signature", sig, siglen); + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + int main(void) { + if (test_sm2_do_sign_fast() != 1) goto err; if (test_sm2_signature() != 1) goto err; if (test_sm2_do_sign() != 1) goto err; + if (test_sm2_do_sign_pre_compute() != 1) goto err; if (test_sm2_sign() != 1) goto err; if (test_sm2_sign_ctx() != 1) goto err; + if (test_sm2_sign_ctx_reset() != 1) goto err; printf("%s all tests passed\n", __FILE__); return 0; err: diff --git a/tests/sm2_z256test.c b/tests/sm2_z256test.c index c7016e38..e29cb68f 100644 --- a/tests/sm2_z256test.c +++ b/tests/sm2_z256test.c @@ -12,11 +12,20 @@ #include #include #include +#include #include +#include +#include #include #include #include + +/* +TODO: 验证点加、倍点等计算是否支持无穷远点、共轭点等特殊形势 + +*/ + enum { OP_ADD, OP_DBL, @@ -28,6 +37,67 @@ enum { OP_INV, }; +#define TEST_COUNT 10 + +static int test_sm2_z256_rshift(void) +{ + uint64_t r[4]; + uint64_t a[4]; + uint64_t b[4]; + unsigned int i; + + sm2_z256_modn_rand(a); + + sm2_z256_rshift(r, a, 0); + sm2_z256_copy(b, a); + if (sm2_z256_cmp(r, b) != 0) { + error_print(); + return -1; + } + + sm2_z256_rshift(r, a, 63); + for (i = 0; i < 63; i++) { + sm2_z256_rshift(a, a, 1); + } + if (sm2_z256_cmp(r, a) != 0) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +static int test_sm2_z256_modp_mont_sqrt(void) +{ + uint64_t a[4]; + uint64_t neg_a[4]; + uint64_t mont_a[4]; + uint64_t mont_sqr_a[4]; + uint64_t mont_a_[4]; + uint64_t a_[4]; + int i; + + for (i = 0; i < 6; i++) { + sm2_z256_modn_rand(a); + sm2_z256_modp_neg(neg_a, a); + + sm2_z256_modp_to_mont(a, mont_a); + sm2_z256_modp_mont_sqr(mont_sqr_a, mont_a); + sm2_z256_modp_mont_sqrt(mont_a_, mont_sqr_a); + sm2_z256_modp_from_mont(a_, mont_a_); + + // a_ = sqrt(a^2), a_ should be a or -a + if (sm2_z256_cmp(a_, a) != 0 && sm2_z256_cmp(a_, neg_a) != 0) { + error_print(); + return -1; + } + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + static int test_sm2_z256_modp(void) { struct { @@ -389,6 +459,95 @@ static int test_sm2_z256_point_get_xy(void) return 1; } +static int test_sm2_z256_point_from_x_bytes(void) +{ + struct { + char *label; + char *xy; + int y_is_odd; + } tests[] = { + { + "G (y is even)", + "32c4ae2c1f1981195f9904466a39c9948fe30bbff2660be1715a4589334c74c7" + "bc3736a2f4f6779c59bdcee36b692153d0a9877cc62a474002df32e52139f0a0", + 0, + }, + { + "2G (y is odd)", + "56cefd60d7c87c000d58ef57fa73ba4d9c0dfa08c08a7331495c2e1da3f2bd52" + "31b7e7e6cc8189f668535ce0f8eaf1bd6de84c182f6c8e716f780d3a970a23c3", + 1, + }, + }; + + SM2_Z256_POINT P; + uint8_t x_bytes[32]; + size_t i, len; + + for (i = 0; i < sizeof(tests)/sizeof(tests[0]); i++) { + + hex_to_bytes(tests[i].xy, 64, x_bytes, &len); + + sm2_z256_point_from_x_bytes(&P, x_bytes, tests[i].y_is_odd); + + if (sm2_z256_point_equ_hex(&P, tests[i].xy) != 1) { + error_print(); + return -1; + } + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +static int test_sm2_z256_point_add_conjugate(void) +{ + char *hex_G = + "32c4ae2c1f1981195f9904466a39c9948fe30bbff2660be1715a4589334c74c7" + "bc3736a2f4f6779c59bdcee36b692153d0a9877cc62a474002df32e52139f0a0"; + char *hex_negG = + "32c4ae2c1f1981195f9904466a39c9948fe30bbff2660be1715a4589334c74c7" + "43c8c95c0b098863a642311c9496deac2f56788239d5b8c0fd20cd1adec60f5f"; + + SM2_Z256_POINT R; + SM2_Z256_POINT P; + SM2_Z256_POINT Q; + + sm2_z256_point_from_hex(&P, hex_G); + sm2_z256_point_from_hex(&Q, hex_negG); + sm2_z256_point_add(&R, &P, &Q); + + // P + (-P) = (0:0:0) + if (!sm2_z256_is_zero(R.X) + || !sm2_z256_is_zero(R.Y) + || !sm2_z256_is_zero(R.Z)) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +static int test_sm2_z256_point_dbl_infinity(void) +{ + SM2_Z256_POINT P_infinity; + SM2_Z256_POINT R; + + sm2_z256_point_set_infinity(&P_infinity); + sm2_z256_point_dbl(&R, &P_infinity); // 显然这个计算就会出错了! + + if (!sm2_z256_point_is_at_infinity(&R)) { + error_print(); // 这个会出错 + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; + +} + + static int test_sm2_z256_point_ops(void) { char *hex_G = @@ -472,6 +631,8 @@ static int test_sm2_z256_point_ops(void) return -1; } + if (sm2_z256_point_equ_hex(&P, tests[i].R) != 1) { + fprintf(stderr, "%s\n", tests[i].label); sm2_z256_point_print(stderr, 0, 4, "R", &P); fprintf(stderr, " R: %s\n", tests[i].R); @@ -479,8 +640,6 @@ static int test_sm2_z256_point_ops(void) fprintf(stderr, " A: %s\n", tests[i].A); fprintf(stderr, " B: %s\n", tests[i].B); - if (sm2_z256_point_equ_hex(&P, tests[i].R) != 1) { - error_print(); return -1; } @@ -614,14 +773,108 @@ static int test_sm2_z256_point_mul_generator(void) return 1; } +static int test_sm2_z256_point_equ(void) +{ + struct { + char *label; + char *mont_X1; + char *mont_Y1; + char *mont_Z1; + char *mont_X2; + char *mont_Y2; + char *mont_Z2; + } tests[] = { + { + "Point at Infinity (1:1:0)", + "0000000100000000000000000000000000000000ffffffff0000000000000001", // mont(1) + "0000000100000000000000000000000000000000ffffffff0000000000000001", // mont(1) + "0000000000000000000000000000000000000000000000000000000000000000", // 0 + "0000000100000000000000000000000000000000ffffffff0000000000000001", // mont(1) + "0000000100000000000000000000000000000000ffffffff0000000000000001", // mont(1) + "0000000000000000000000000000000000000000000000000000000000000000", // 0 + }, + { + "[2]2G == 2G + G + G", + "87b2ca9ded2487c6efdbc69303258763a0b5520fc63cf40154f6c059b945acf2", + "dc86353bc72db45ebb5b2d03cec4614b164688f19f12dd857fd007e181457b59", + "050653f8579d1d2d930d7346e31bad56b5a4654d6a9f2c5022434941744ced3a", + "e8457905838420a51366f7fe174ce34dc3579fefc188f0b5124e7537526ae99e", + "48c3374ab1d5fde0276bebb81b8ff0baa9805cc2d0f487e18d7b3a4352f4ae21", + "79f76fd57f22f1e282d64ff809a53f1f729f6b89c6f626b96725a9d05704e681", + } + }; + + SM2_Z256_POINT P; + SM2_Z256_POINT Q; + size_t i; + + for (i = 0; i < sizeof(tests)/sizeof(tests[0]); i++) { + + sm2_z256_from_hex(P.X, tests[i].mont_X1); + sm2_z256_from_hex(P.Y, tests[i].mont_Y1); + sm2_z256_from_hex(P.Z, tests[i].mont_Z1); + + sm2_z256_from_hex(Q.X, tests[i].mont_X2); + sm2_z256_from_hex(Q.Y, tests[i].mont_Y2); + sm2_z256_from_hex(Q.Z, tests[i].mont_Z2); + + if (sm2_z256_point_equ(&P, &Q) != 1) { + error_print(); + return -1; + } + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +static int test_sm2_z256_point_from_hash(void) +{ + SM2_Z256_POINT P; + uint8_t data[64]; + size_t datalen = sizeof(data); + int y_is_odd = 1; + int y_is_even = 0; + size_t i; + + for (i = 0; i < 5; i++) { + + rand_bytes(data, datalen); + + if (sm2_z256_point_from_hash(&P, data, datalen, y_is_odd) != 1) { + error_print(); + return -1; + } + if (sm2_z256_point_from_hash(&P, data, datalen, y_is_even) != 1) { + error_print(); + return -1; + } + } + + printf("%s() ok\n", __FUNCTION__); + return 1; + + +} + + int main(void) { + + if (test_sm2_z256_rshift() != 1) goto err; if (test_sm2_z256_modp() != 1) goto err; if (test_sm2_z256_modn() != 1) goto err; if (test_sm2_z256_point_is_on_curve() != 1) goto err; + if (test_sm2_z256_point_equ() != 1) goto err; if (test_sm2_z256_point_get_xy() != 1) goto err; + if (test_sm2_z256_point_add_conjugate() != 1) goto err; + if (test_sm2_z256_point_dbl_infinity() != 1) goto err; if (test_sm2_z256_point_ops() != 1) goto err; if (test_sm2_z256_point_mul_generator() != 1) goto err; + if (test_sm2_z256_point_from_hash() != 1) goto err; + if (test_sm2_z256_point_from_x_bytes() != 1) goto err; + if (test_sm2_z256_modp_mont_sqrt() != 1) goto err; + printf("%s all tests passed\n", __FILE__); return 0; err: diff --git a/tools/sm2speed.c b/tools/sm2speed.c new file mode 100644 index 00000000..d8942a28 --- /dev/null +++ b/tools/sm2speed.c @@ -0,0 +1,59 @@ +/* + * Copyright 2014-2024 The GmSSL Project. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ + + +#include +#include +#include +#include +#include + +int sm2sign_speed(void) +{ + SM2_KEY sm2_key; + SM2_SIGN_CTX sign_ctx; + uint8_t msg[32]; + uint8_t sig[SM2_MAX_SIGNATURE_SIZE]; + size_t siglen; + size_t i; + + if (sm2_key_generate(&sm2_key) != 1) { + error_print(); + return -1; + } + + if (sm2_sign_init(&sign_ctx, &sm2_key, SM2_DEFAULT_ID, strlen(SM2_DEFAULT_ID)) != 1) { + error_print(); + return -1; + } + + for (i = 0; i < 10000; i++) { + /* + if (sm2_sign_init(&sign_ctx, &sm2_key, SM2_DEFAULT_ID, strlen(SM2_DEFAULT_ID)) != 1) { + error_print(); + return -1; + } + */ + if (sm2_sign_ctx_reset(&sign_ctx) != 1 + || sm2_sign_update(&sign_ctx, msg, sizeof(msg)) != 1 + || sm2_sign_finish(&sign_ctx, sig, &siglen) != 1) { + error_print(); + return -1; + } + } + + return 0; +} + +int main(void) +{ + sm2sign_speed(); + return 0; +} + diff --git a/tools/sm3xmss_sign.c b/tools/sm3xmss_sign.c new file mode 100644 index 00000000..b62d02c0 --- /dev/null +++ b/tools/sm3xmss_sign.c @@ -0,0 +1,154 @@ +/* + * Copyright 2014-2023 The GmSSL Project. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ + + +#include +#include +#include +#include +#include +#include + +static const char *usage = "-key file [-in file] [-out file]\n"; + +static const char *help = +"Options\n" +" -key file Input private key file\n" +" -in file Input data file (if not using stdin)\n" +" -out file Output signature file\n" +"\n"; + +int sm3xmss_sign_main(int argc, char **argv) +{ + int ret = 1; + char *prog = argv[0]; + char *keyfile = NULL; + char *infile = NULL; + char *outfile = NULL; + FILE *keyfp = NULL; + FILE *infp = stdin; + FILE *outfp = stdout; + SM3_XMSS_KEY key; + SM3_XMSS_SIGN_CTX sign_ctx; + uint8_t *sigbuf = NULL; + size_t siglen; + + argc--; + argv++; + + if (argc < 1) { + fprintf(stderr, "usage: %s %s\n", prog, usage); + return 1; + } + + while (argc > 0) { + if (!strcmp(*argv, "-help")) { + printf("usage: %s %s\n", prog, usage); + printf("%s\n", help); + ret = 0; + goto end; + } else if (!strcmp(*argv, "-key")) { + if (--argc < 1) goto bad; + keyfile = *(++argv); + if (!(keyfp = fopen(keyfile, "rb"))) { + fprintf(stderr, "%s: open '%s' failure: %s\n", prog, keyfile, strerror(errno)); + goto end; + } + } else if (!strcmp(*argv, "-in")) { + if (--argc < 1) goto bad; + infile = *(++argv); + if (!(infp = fopen(infile, "rb"))) { + fprintf(stderr, "%s: open '%s' failure: %s\n", prog, infile, strerror(errno)); + goto end; + } + } else if (!strcmp(*argv, "-out")) { + if (--argc < 1) goto bad; + outfile = *(++argv); + if (!(outfp = fopen(outfile, "wb"))) { + fprintf(stderr, "%s: open '%s' failure: %s\n", prog, outfile, strerror(errno)); + goto end; + } + } else { + fprintf(stderr, "%s: illegal option '%s'\n", prog, *argv); + goto end; +bad: + fprintf(stderr, "%s: `%s` option value missing\n", prog, *argv); + goto end; + } + + argc--; + argv++; + } + + if (!keyfile) { + fprintf(stderr, "%s: `-key` option required\n", prog); + goto end; + } + + if (sm3_xmss_key_from_bytes(&key, NULL, 0) != 1) { + error_print(); + goto end; + } + + if (fread(&key, 1, sizeof(key), keyfp) != sizeof(key)) { + fprintf(stderr, "%s: read private key failure\n", prog); + goto end; + } + + if (sm3_xmss_sign_init(&sign_ctx, &key) != 1) { + error_print(); + goto end; + } + + while (1) { + uint8_t buf[1024]; + size_t len = fread(buf, 1, sizeof(buf), infp); + if (len == 0) { + break; + } + if (sm3_xmss_sign_update(&sign_ctx, buf, len) != 1) { + error_print(); + goto end; + } + } + + if (sm3_xmss_sign_finish(&sign_ctx, &key, NULL, &siglen) != 1) { + error_print(); + goto end; + } + + if (!(sigbuf = malloc(siglen))) { + fprintf(stderr, "%s: malloc failure\n", prog); + goto end; + } + + if (sm3_xmss_sign_finish(&sign_ctx, &key, sigbuf, &siglen) != 1) { + error_print(); + goto end; + } + + if (fwrite(sigbuf, 1, siglen, outfp) != siglen) { + error_print(); + goto end; + } + + ret = 0; + +end: + gmssl_secure_clear(&key, sizeof(key)); + gmssl_secure_clear(&sign_ctx, sizeof(sign_ctx)); + if (sigbuf) { + gmssl_secure_clear(sigbuf, siglen); + free(sigbuf); + } + if (keyfp) fclose(keyfp); + if (infp && infp != stdin) fclose(infp); + if (outfp && outfp != stdout) fclose(outfp); + return ret; +}