Update SM2

This commit is contained in:
Zhi Guan
2024-03-10 22:34:43 +08:00
parent cfdcd0c0e3
commit 33baa3df92
15 changed files with 2410 additions and 268 deletions

View File

@@ -18,8 +18,11 @@ set(src
src/sm3_hmac.c src/sm3_hmac.c
src/sm3_kdf.c src/sm3_kdf.c
src/sm3_digest.c src/sm3_digest.c
src/sm2_alg.c #src/sm2_alg.c
src/sm2_key.c src/sm2_point.c
src/sm2_z256.c
src/sm2_z256_table.c
src/sm2_z256_key.c
src/sm2_z256_sign.c src/sm2_z256_sign.c
src/sm2_lib.c src/sm2_lib.c
src/sm2_ctx.c src/sm2_ctx.c
@@ -122,7 +125,8 @@ set(tools
set(tests set(tests
sm4 sm4
sm3 sm3
sm2 # sm2
sm2_z256
sm2_sign sm2_sign
sm2_enc sm2_enc
sm9 sm9
@@ -238,6 +242,15 @@ if (ENABLE_SM2_ALGOR_ID_ENCODE_NULL)
endif() endif()
option(ENABLE_SM2_Z256_ARMV8 "Enable SM2_Z256 ARMv8 assembly" OFF)
if (ENABLE_SM2_Z256_ARMV8)
message(STATUS "ENABLE_SM2_Z256_ARMV8 is ON")
add_definitions(-DENABLE_SM2_Z256_ARMV8)
enable_language(ASM)
list(APPEND src src/sm2_z256_armv8.S)
endif()
option(ENABLE_SM2_PRIVATE_KEY_EXPORT "Enable export un-encrypted SM2 private key" OFF) option(ENABLE_SM2_PRIVATE_KEY_EXPORT "Enable export un-encrypted SM2 private key" OFF)
if (ENABLE_SM2_PRIVATE_KEY_EXPORT) if (ENABLE_SM2_PRIVATE_KEY_EXPORT)
message(STATUS "ENABLE_SM2_PRIVATE_KEY_EXPORT is ON") message(STATUS "ENABLE_SM2_PRIVATE_KEY_EXPORT is ON")
@@ -316,15 +329,6 @@ if (ENABLE_SM4_XTS)
endif() endif()
option(ENABLE_SM2_Z256 "Enable SM2 z256 implementation" OFF)
if (ENABLE_SM2_Z256)
message(STATUS "ENABLE_SM2_Z256 is ON")
add_definitions(-DENABLE_SM2_Z256)
list(APPEND src src/sm2_z256.c src/sm2_z256_table.c)
list(APPEND tests sm2_z256)
endif()
option(ENABLE_SM2_EXTS "Enable SM2 Extensions" OFF) option(ENABLE_SM2_EXTS "Enable SM2 Extensions" OFF)
if (ENABLE_SM2_EXTS) if (ENABLE_SM2_EXTS)
message(STATUS "ENABLE_SM4_AESNI_AVX") message(STATUS "ENABLE_SM4_AESNI_AVX")

View File

@@ -15,113 +15,15 @@
#include <stdio.h> #include <stdio.h>
#include <stdint.h> #include <stdint.h>
#include <stdlib.h> #include <stdlib.h>
#include <gmssl/sm3.h>
#include <gmssl/api.h> #include <gmssl/api.h>
#include <gmssl/sm3.h>
#include <gmssl/sm2_z256.h>
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
typedef uint64_t SM2_BN[8];
int sm2_bn_is_zero(const SM2_BN a);
int sm2_bn_is_one(const SM2_BN a);
int sm2_bn_is_odd(const SM2_BN a);
int sm2_bn_cmp(const SM2_BN a, const SM2_BN b);
int sm2_bn_from_hex(SM2_BN r, const char hex[64]);
int sm2_bn_from_asn1_integer(SM2_BN r, const uint8_t *d, size_t dlen);
int sm2_bn_equ_hex(const SM2_BN a, const char *hex);
int sm2_bn_print(FILE *fp, int fmt, int ind, const char *label, const SM2_BN a);
int sm2_bn_rshift(SM2_BN ret, const SM2_BN a, unsigned int nbits);
void sm2_bn_to_bytes(const SM2_BN a, uint8_t out[32]);
void sm2_bn_from_bytes(SM2_BN r, const uint8_t in[32]);
void sm2_bn_to_hex(const SM2_BN a, char hex[64]);
void sm2_bn_to_bits(const SM2_BN a, char bits[256]);
void sm2_bn_set_word(SM2_BN r, uint32_t a);
void sm2_bn_add(SM2_BN r, const SM2_BN a, const SM2_BN b);
void sm2_bn_sub(SM2_BN ret, const SM2_BN a, const SM2_BN b);
int sm2_bn_rand_range(SM2_BN r, const SM2_BN range);
#define sm2_bn_init(r) memset((r),0,sizeof(SM2_BN))
#define sm2_bn_set_zero(r) memset((r),0,sizeof(SM2_BN))
#define sm2_bn_set_one(r) sm2_bn_set_word((r),1)
#define sm2_bn_copy(r,a) memcpy((r),(a),sizeof(SM2_BN))
#define sm2_bn_clean(r) memset((r),0,sizeof(SM2_BN))
// GF(p)
typedef SM2_BN SM2_Fp;
void sm2_fp_add(SM2_Fp r, const SM2_Fp a, const SM2_Fp b);
void sm2_fp_sub(SM2_Fp r, const SM2_Fp a, const SM2_Fp b);
void sm2_fp_mul(SM2_Fp r, const SM2_Fp a, const SM2_Fp b);
void sm2_fp_exp(SM2_Fp r, const SM2_Fp a, const SM2_Fp e);
void sm2_fp_dbl(SM2_Fp r, const SM2_Fp a);
void sm2_fp_tri(SM2_Fp r, const SM2_Fp a);
void sm2_fp_div2(SM2_Fp r, const SM2_Fp a);
void sm2_fp_neg(SM2_Fp r, const SM2_Fp a);
void sm2_fp_sqr(SM2_Fp r, const SM2_Fp a);
void sm2_fp_inv(SM2_Fp r, const SM2_Fp a);
int sm2_fp_rand(SM2_Fp r);
int sm2_fp_sqrt(SM2_Fp r, const SM2_Fp a);
#define sm2_fp_init(r) sm2_bn_init(r)
#define sm2_fp_set_zero(r) sm2_bn_set_zero(r)
#define sm2_fp_set_one(r) sm2_bn_set_one(r)
#define sm2_fp_copy(r,a) sm2_bn_copy(r,a)
#define sm2_fp_clean(r) sm2_bn_clean(r)
// GF(n)
typedef SM2_BN SM2_Fn;
void sm2_fn_add(SM2_Fn r, const SM2_Fn a, const SM2_Fn b);
void sm2_fn_sub(SM2_Fn r, const SM2_Fn a, const SM2_Fn b);
void sm2_fn_mul(SM2_Fn r, const SM2_Fn a, const SM2_Fn b);
void sm2_fn_mul_word(SM2_Fn r, const SM2_Fn a, uint32_t b);
void sm2_fn_exp(SM2_Fn r, const SM2_Fn a, const SM2_Fn e);
void sm2_fn_neg(SM2_Fn r, const SM2_Fn a);
void sm2_fn_sqr(SM2_Fn r, const SM2_Fn a);
void sm2_fn_inv(SM2_Fn r, const SM2_Fn a);
int sm2_fn_rand(SM2_Fn r);
#define sm2_fn_init(r) sm2_bn_init(r)
#define sm2_fn_set_zero(r) sm2_bn_set_zero(r)
#define sm2_fn_set_one(r) sm2_bn_set_one(r)
#define sm2_fn_copy(r,a) sm2_bn_copy(r,a)
#define sm2_fn_clean(r) sm2_bn_clean(r)
typedef struct {
SM2_BN X;
SM2_BN Y;
SM2_BN Z;
} SM2_JACOBIAN_POINT;
void sm2_jacobian_point_init(SM2_JACOBIAN_POINT *R);
void sm2_jacobian_point_set_xy(SM2_JACOBIAN_POINT *R, const SM2_BN x, const SM2_BN y);
void sm2_jacobian_point_get_xy(const SM2_JACOBIAN_POINT *P, SM2_BN x, SM2_BN y);
void sm2_jacobian_point_neg(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P);
void sm2_jacobian_point_dbl(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P);
void sm2_jacobian_point_add(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P, const SM2_JACOBIAN_POINT *Q);
void sm2_jacobian_point_sub(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P, const SM2_JACOBIAN_POINT *Q);
void sm2_jacobian_point_mul(SM2_JACOBIAN_POINT *R, const SM2_BN k, const SM2_JACOBIAN_POINT *P);
void sm2_jacobian_point_to_bytes(const SM2_JACOBIAN_POINT *P, uint8_t out[64]);
void sm2_jacobian_point_from_bytes(SM2_JACOBIAN_POINT *P, const uint8_t in[64]);
void sm2_jacobian_point_mul_generator(SM2_JACOBIAN_POINT *R, const SM2_BN k);
void sm2_jacobian_point_mul_sum(SM2_JACOBIAN_POINT *R, const SM2_BN t, const SM2_JACOBIAN_POINT *P, const SM2_BN s);
void sm2_jacobian_point_from_hex(SM2_JACOBIAN_POINT *P, const char hex[64 * 2]); // for testing only
int sm2_jacobian_point_is_at_infinity(const SM2_JACOBIAN_POINT *P);
int sm2_jacobian_point_is_on_curve(const SM2_JACOBIAN_POINT *P);
int sm2_jacobian_point_equ_hex(const SM2_JACOBIAN_POINT *P, const char hex[128]); // for testing only
int sm2_jacobian_point_print(FILE *fp, int fmt, int ind, const char *label, const SM2_JACOBIAN_POINT *P);
#define sm2_jacobian_point_set_infinity(R) sm2_jacobian_point_init(R)
#define sm2_jacobian_point_copy(R, P) memcpy((R), (P), sizeof(SM2_JACOBIAN_POINT))
typedef uint8_t sm2_bn_t[32]; typedef uint8_t sm2_bn_t[32];
typedef struct { typedef struct {
@@ -131,6 +33,8 @@ typedef struct {
#define sm2_point_init(P) memset((P),0,sizeof(SM2_POINT)) #define sm2_point_init(P) memset((P),0,sizeof(SM2_POINT))
#define sm2_point_set_infinity(P) sm2_point_init(P) #define sm2_point_set_infinity(P) sm2_point_init(P)
int sm2_point_from_octets(SM2_POINT *P, const uint8_t *in, size_t inlen); int sm2_point_from_octets(SM2_POINT *P, const uint8_t *in, size_t inlen);
void sm2_point_to_compressed_octets(const SM2_POINT *P, uint8_t out[33]); void sm2_point_to_compressed_octets(const SM2_POINT *P, uint8_t out[33]);
void sm2_point_to_uncompressed_octets(const SM2_POINT *P, uint8_t out[65]); void sm2_point_to_uncompressed_octets(const SM2_POINT *P, uint8_t out[65]);
@@ -147,6 +51,7 @@ int sm2_point_mul(SM2_POINT *R, const uint8_t k[32], const SM2_POINT *P);
int sm2_point_mul_generator(SM2_POINT *R, const uint8_t k[32]); int sm2_point_mul_generator(SM2_POINT *R, const uint8_t k[32]);
int sm2_point_mul_sum(SM2_POINT *R, const uint8_t k[32], const SM2_POINT *P, const uint8_t s[32]); // R = k * P + s * G int sm2_point_mul_sum(SM2_POINT *R, const uint8_t k[32], const SM2_POINT *P, const uint8_t s[32]); // R = k * P + s * G
/* /*
RFC 5480 Elliptic Curve Cryptography Subject Public Key Information RFC 5480 Elliptic Curve Cryptography Subject Public Key Information
ECPoint ::= OCTET STRING ECPoint ::= OCTET STRING
@@ -163,7 +68,6 @@ typedef struct {
uint8_t private_key[32]; uint8_t private_key[32];
} SM2_KEY; } SM2_KEY;
_gmssl_export int sm2_key_generate(SM2_KEY *key); _gmssl_export int sm2_key_generate(SM2_KEY *key);
int sm2_key_set_private_key(SM2_KEY *key, const uint8_t private_key[32]); // key->public_key will be replaced int sm2_key_set_private_key(SM2_KEY *key, const uint8_t private_key[32]); // key->public_key will be replaced
int sm2_key_set_public_key(SM2_KEY *key, const SM2_POINT *public_key); // key->private_key will be cleared // FIXME: support octets as input? int sm2_key_set_public_key(SM2_KEY *key, const SM2_POINT *public_key); // key->private_key will be cleared // FIXME: support octets as input?
@@ -174,6 +78,7 @@ int sm2_public_key_equ(const SM2_KEY *sm2_key, const SM2_KEY *pub_key);
int sm2_public_key_digest(const SM2_KEY *key, uint8_t dgst[32]); int sm2_public_key_digest(const SM2_KEY *key, uint8_t dgst[32]);
int sm2_public_key_print(FILE *fp, int fmt, int ind, const char *label, const SM2_KEY *pub_key); int sm2_public_key_print(FILE *fp, int fmt, int ind, const char *label, const SM2_KEY *pub_key);
/* /*
from RFC 5915 from RFC 5915
@@ -258,9 +163,14 @@ typedef struct {
} SM2_SIGNATURE; } SM2_SIGNATURE;
int sm2_do_sign(const SM2_KEY *key, const uint8_t dgst[32], SM2_SIGNATURE *sig); int sm2_do_sign(const SM2_KEY *key, const uint8_t dgst[32], SM2_SIGNATURE *sig);
int sm2_do_sign_fast(const SM2_Fn d, const uint8_t dgst[32], SM2_SIGNATURE *sig); int sm2_do_sign_fast(const uint64_t d[4], const uint8_t dgst[32], SM2_SIGNATURE *sig);
int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATURE *sig); int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATURE *sig);
int sm2_do_sign_pre_compute(uint64_t k[4], uint64_t x1[4]);
int sm2_do_sign_fast_ex(const uint64_t d[4], const uint64_t k[4], const uint64_t x1[4], const uint8_t dgst[32], SM2_SIGNATURE *sig);
int sm2_do_verify_fast(const SM2_Z256_POINT *P, const uint8_t dgst[32], const SM2_SIGNATURE *sig);
#define SM2_MIN_SIGNATURE_SIZE 8 #define SM2_MIN_SIGNATURE_SIZE 8
#define SM2_MAX_SIGNATURE_SIZE 72 #define SM2_MAX_SIGNATURE_SIZE 72
@@ -277,6 +187,8 @@ enum {
}; };
int sm2_sign_fixlen(const SM2_KEY *key, const uint8_t dgst[32], size_t siglen, uint8_t *sig); int sm2_sign_fixlen(const SM2_KEY *key, const uint8_t dgst[32], size_t siglen, uint8_t *sig);
#define SM2_DEFAULT_ID "1234567812345678" #define SM2_DEFAULT_ID "1234567812345678"
#define SM2_DEFAULT_ID_LENGTH (sizeof(SM2_DEFAULT_ID) - 1) // LENGTH for string and SIZE for bytes #define SM2_DEFAULT_ID_LENGTH (sizeof(SM2_DEFAULT_ID) - 1) // LENGTH for string and SIZE for bytes
#define SM2_DEFAULT_ID_BITS (SM2_DEFAULT_ID_LENGTH * 8) #define SM2_DEFAULT_ID_BITS (SM2_DEFAULT_ID_LENGTH * 8)
@@ -286,9 +198,23 @@ int sm2_sign_fixlen(const SM2_KEY *key, const uint8_t dgst[32], size_t siglen, u
int sm2_compute_z(uint8_t z[32], const SM2_POINT *pub, const char *id, size_t idlen); int sm2_compute_z(uint8_t z[32], const SM2_POINT *pub, const char *id, size_t idlen);
typedef struct {
uint64_t k[4];
uint64_t x1[4];
} SM2_SIGN_PRE_COMP;
typedef struct { typedef struct {
SM3_CTX sm3_ctx; SM3_CTX sm3_ctx;
SM2_KEY key; SM2_KEY key;
// FIXME: change `key` to SM2_Z256_POINT and uint64_t[4], inner type, faster sign/verify
SM2_Z256_POINT public_key; // z256 only
uint64_t sign_key[8]; // u64[8] to support SM2_BN
SM3_CTX inited_sm3_ctx;
SM2_SIGN_PRE_COMP pre_comp[32];
unsigned int num_pre_comp;
} SM2_SIGN_CTX; } SM2_SIGN_CTX;
_gmssl_export int sm2_sign_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_t idlen); _gmssl_export int sm2_sign_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_t idlen);
@@ -300,6 +226,8 @@ _gmssl_export int sm2_verify_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const c
_gmssl_export int sm2_verify_update(SM2_SIGN_CTX *ctx, const uint8_t *data, size_t datalen); _gmssl_export int sm2_verify_update(SM2_SIGN_CTX *ctx, const uint8_t *data, size_t datalen);
_gmssl_export int sm2_verify_finish(SM2_SIGN_CTX *ctx, const uint8_t *sig, size_t siglen); _gmssl_export int sm2_verify_finish(SM2_SIGN_CTX *ctx, const uint8_t *sig, size_t siglen);
_gmssl_export int sm2_sign_ctx_reset(SM2_SIGN_CTX *ctx);
/* /*
SM2Cipher ::= SEQUENCE { SM2Cipher ::= SEQUENCE {
XCoordinate INTEGER, XCoordinate INTEGER,
@@ -356,10 +284,6 @@ _gmssl_export int sm2_decrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key);
_gmssl_export int sm2_decrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); _gmssl_export int sm2_decrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
_gmssl_export int sm2_decrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); _gmssl_export int sm2_decrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
const uint64_t *sm2_bn_prime(void);
const uint64_t *sm2_bn_order(void);
const uint64_t *sm2_bn_one(void);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

131
include/gmssl/sm2_p256.h Normal file
View File

@@ -0,0 +1,131 @@
/*
* Copyright 2014-2024 The GmSSL Project. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
#ifndef GMSSL_SM2_P256_H
#define GMSSL_SM2_P256_H
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#ifdef __cplusplus
extern "C" {
#endif
typedef uint64_t SM2_BN[8];
int sm2_bn_is_zero(const SM2_BN a);
int sm2_bn_is_one(const SM2_BN a);
int sm2_bn_is_odd(const SM2_BN a);
int sm2_bn_cmp(const SM2_BN a, const SM2_BN b);
int sm2_bn_from_hex(SM2_BN r, const char hex[64]);
int sm2_bn_from_asn1_integer(SM2_BN r, const uint8_t *d, size_t dlen);
int sm2_bn_equ_hex(const SM2_BN a, const char *hex);
int sm2_bn_print(FILE *fp, int fmt, int ind, const char *label, const SM2_BN a);
int sm2_bn_rshift(SM2_BN ret, const SM2_BN a, unsigned int nbits);
void sm2_bn_to_bytes(const SM2_BN a, uint8_t out[32]);
void sm2_bn_from_bytes(SM2_BN r, const uint8_t in[32]);
void sm2_bn_to_hex(const SM2_BN a, char hex[64]);
void sm2_bn_to_bits(const SM2_BN a, char bits[256]);
void sm2_bn_set_word(SM2_BN r, uint32_t a);
void sm2_bn_add(SM2_BN r, const SM2_BN a, const SM2_BN b);
void sm2_bn_sub(SM2_BN ret, const SM2_BN a, const SM2_BN b);
int sm2_bn_rand_range(SM2_BN r, const SM2_BN range);
#define sm2_bn_init(r) memset((r),0,sizeof(SM2_BN))
#define sm2_bn_set_zero(r) memset((r),0,sizeof(SM2_BN))
#define sm2_bn_set_one(r) sm2_bn_set_word((r),1)
#define sm2_bn_copy(r,a) memcpy((r),(a),sizeof(SM2_BN))
#define sm2_bn_clean(r) memset((r),0,sizeof(SM2_BN))
// GF(p)
typedef SM2_BN SM2_Fp;
void sm2_fp_add(SM2_Fp r, const SM2_Fp a, const SM2_Fp b);
void sm2_fp_sub(SM2_Fp r, const SM2_Fp a, const SM2_Fp b);
void sm2_fp_mul(SM2_Fp r, const SM2_Fp a, const SM2_Fp b);
void sm2_fp_exp(SM2_Fp r, const SM2_Fp a, const SM2_Fp e);
void sm2_fp_dbl(SM2_Fp r, const SM2_Fp a);
void sm2_fp_tri(SM2_Fp r, const SM2_Fp a);
void sm2_fp_div2(SM2_Fp r, const SM2_Fp a);
void sm2_fp_neg(SM2_Fp r, const SM2_Fp a);
void sm2_fp_sqr(SM2_Fp r, const SM2_Fp a);
void sm2_fp_inv(SM2_Fp r, const SM2_Fp a);
int sm2_fp_rand(SM2_Fp r);
int sm2_fp_sqrt(SM2_Fp r, const SM2_Fp a);
#define sm2_fp_init(r) sm2_bn_init(r)
#define sm2_fp_set_zero(r) sm2_bn_set_zero(r)
#define sm2_fp_set_one(r) sm2_bn_set_one(r)
#define sm2_fp_copy(r,a) sm2_bn_copy(r,a)
#define sm2_fp_clean(r) sm2_bn_clean(r)
// GF(n)
typedef SM2_BN SM2_Fn;
void sm2_fn_add(SM2_Fn r, const SM2_Fn a, const SM2_Fn b);
void sm2_fn_sub(SM2_Fn r, const SM2_Fn a, const SM2_Fn b);
void sm2_fn_mul(SM2_Fn r, const SM2_Fn a, const SM2_Fn b);
void sm2_fn_mul_word(SM2_Fn r, const SM2_Fn a, uint32_t b);
void sm2_fn_exp(SM2_Fn r, const SM2_Fn a, const SM2_Fn e);
void sm2_fn_neg(SM2_Fn r, const SM2_Fn a);
void sm2_fn_sqr(SM2_Fn r, const SM2_Fn a);
void sm2_fn_inv(SM2_Fn r, const SM2_Fn a);
int sm2_fn_rand(SM2_Fn r);
#define sm2_fn_init(r) sm2_bn_init(r)
#define sm2_fn_set_zero(r) sm2_bn_set_zero(r)
#define sm2_fn_set_one(r) sm2_bn_set_one(r)
#define sm2_fn_copy(r,a) sm2_bn_copy(r,a)
#define sm2_fn_clean(r) sm2_bn_clean(r)
typedef struct {
SM2_BN X;
SM2_BN Y;
SM2_BN Z;
} SM2_JACOBIAN_POINT;
void sm2_jacobian_point_init(SM2_JACOBIAN_POINT *R);
void sm2_jacobian_point_set_xy(SM2_JACOBIAN_POINT *R, const SM2_BN x, const SM2_BN y);
void sm2_jacobian_point_get_xy(const SM2_JACOBIAN_POINT *P, SM2_BN x, SM2_BN y);
void sm2_jacobian_point_neg(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P);
void sm2_jacobian_point_dbl(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P);
void sm2_jacobian_point_add(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P, const SM2_JACOBIAN_POINT *Q);
void sm2_jacobian_point_sub(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P, const SM2_JACOBIAN_POINT *Q);
void sm2_jacobian_point_mul(SM2_JACOBIAN_POINT *R, const SM2_BN k, const SM2_JACOBIAN_POINT *P);
void sm2_jacobian_point_to_bytes(const SM2_JACOBIAN_POINT *P, uint8_t out[64]);
void sm2_jacobian_point_from_bytes(SM2_JACOBIAN_POINT *P, const uint8_t in[64]);
void sm2_jacobian_point_mul_generator(SM2_JACOBIAN_POINT *R, const SM2_BN k);
void sm2_jacobian_point_mul_sum(SM2_JACOBIAN_POINT *R, const SM2_BN t, const SM2_JACOBIAN_POINT *P, const SM2_BN s);
void sm2_jacobian_point_from_hex(SM2_JACOBIAN_POINT *P, const char hex[64 * 2]); // for testing only
int sm2_jacobian_point_is_at_infinity(const SM2_JACOBIAN_POINT *P);
int sm2_jacobian_point_is_on_curve(const SM2_JACOBIAN_POINT *P);
int sm2_jacobian_point_equ_hex(const SM2_JACOBIAN_POINT *P, const char hex[128]); // for testing only
int sm2_jacobian_point_print(FILE *fp, int fmt, int ind, const char *label, const SM2_JACOBIAN_POINT *P);
#define sm2_jacobian_point_set_infinity(R) sm2_jacobian_point_init(R)
#define sm2_jacobian_point_copy(R, P) memcpy((R), (P), sizeof(SM2_JACOBIAN_POINT))
const uint64_t *sm2_bn_prime(void);
const uint64_t *sm2_bn_order(void);
const uint64_t *sm2_bn_one(void);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -30,6 +30,7 @@ void sm2_z256_to_bytes(const uint64_t a[4], uint8_t out[32]);
int sm2_z256_cmp(const uint64_t a[4], const uint64_t b[4]); int sm2_z256_cmp(const uint64_t a[4], const uint64_t b[4]);
uint64_t sm2_z256_is_zero(const uint64_t a[4]); uint64_t sm2_z256_is_zero(const uint64_t a[4]);
uint64_t sm2_z256_equ(const uint64_t a[4], const uint64_t b[4]); uint64_t sm2_z256_equ(const uint64_t a[4], const uint64_t b[4]);
void sm2_z256_rshift(uint64_t r[4], const uint64_t a[4], unsigned int nbits);
uint64_t sm2_z256_add(uint64_t r[4], const uint64_t a[4], const uint64_t b[4]); uint64_t sm2_z256_add(uint64_t r[4], const uint64_t a[4], const uint64_t b[4]);
uint64_t sm2_z256_sub(uint64_t r[4], const uint64_t a[4], const uint64_t b[4]); uint64_t sm2_z256_sub(uint64_t r[4], const uint64_t a[4], const uint64_t b[4]);
void sm2_z256_mul(uint64_t r[8], const uint64_t a[4], const uint64_t b[4]); void sm2_z256_mul(uint64_t r[8], const uint64_t a[4], const uint64_t b[4]);
@@ -53,6 +54,7 @@ void sm2_z256_modp_mont_mul(uint64_t r[4], const uint64_t a[4], const uint64_t b
void sm2_z256_modp_mont_sqr(uint64_t r[4], const uint64_t a[4]); void sm2_z256_modp_mont_sqr(uint64_t r[4], const uint64_t a[4]);
void sm2_z256_modp_mont_exp(uint64_t r[4], const uint64_t a[4], const uint64_t e[4]); void sm2_z256_modp_mont_exp(uint64_t r[4], const uint64_t a[4], const uint64_t e[4]);
void sm2_z256_modp_mont_inv(uint64_t r[4], const uint64_t a[4]); void sm2_z256_modp_mont_inv(uint64_t r[4], const uint64_t a[4]);
int sm2_z256_modp_mont_sqrt(uint64_t r[4], const uint64_t a[4]);
int sm2_z256_modp_mont_print(FILE *fp, int ind, int fmt, const char *label, const uint64_t a[4]); int sm2_z256_modp_mont_print(FILE *fp, int ind, int fmt, const char *label, const uint64_t a[4]);
int sm2_z256_modn_rand(uint64_t r[4]); int sm2_z256_modn_rand(uint64_t r[4]);
@@ -79,11 +81,13 @@ typedef struct {
uint64_t Z[4]; uint64_t Z[4];
} SM2_Z256_POINT; } SM2_Z256_POINT;
void sm2_z256_point_set_infinity(SM2_Z256_POINT *P);
void sm2_z256_point_from_bytes(SM2_Z256_POINT *P, const uint8_t in[64]); void sm2_z256_point_from_bytes(SM2_Z256_POINT *P, const uint8_t in[64]);
void sm2_z256_point_to_bytes(const SM2_Z256_POINT *P, uint8_t out[64]); void sm2_z256_point_to_bytes(const SM2_Z256_POINT *P, uint8_t out[64]);
int sm2_z256_point_is_at_infinity(const SM2_Z256_POINT *P); int sm2_z256_point_is_at_infinity(const SM2_Z256_POINT *P);
int sm2_z256_point_is_on_curve(const SM2_Z256_POINT *P); int sm2_z256_point_is_on_curve(const SM2_Z256_POINT *P);
int sm2_z256_point_equ(const SM2_Z256_POINT *P, const SM2_Z256_POINT *Q);
void sm2_z256_point_get_xy(const SM2_Z256_POINT *P, uint64_t x[4], uint64_t y[4]); void sm2_z256_point_get_xy(const SM2_Z256_POINT *P, uint64_t x[4], uint64_t y[4]);
void sm2_z256_point_dbl(SM2_Z256_POINT *R, const SM2_Z256_POINT *A); void sm2_z256_point_dbl(SM2_Z256_POINT *R, const SM2_Z256_POINT *A);
@@ -110,11 +114,25 @@ void sm2_z256_point_mul_sum(SM2_Z256_POINT *R, const uint64_t t[4], const SM2_Z2
const uint64_t *sm2_z256_prime(void); const uint64_t *sm2_z256_prime(void);
const uint64_t *sm2_z256_order(void); const uint64_t *sm2_z256_order(void);
const uint64_t *sm2_z256_order_minus_one(void);
const uint64_t *sm2_z256_one(void); const uint64_t *sm2_z256_one(void);
void sm2_z256_point_from_hex(SM2_Z256_POINT *P, const char *hex); void sm2_z256_point_from_hex(SM2_Z256_POINT *P, const char *hex);
int sm2_z256_point_equ_hex(const SM2_Z256_POINT *P, const char *hex); int sm2_z256_point_equ_hex(const SM2_Z256_POINT *P, const char *hex);
enum {
SM2_point_at_infinity = 0x00,
SM2_point_compressed_y_even = 0x02,
SM2_point_compressed_y_odd = 0x03,
SM2_point_uncompressed = 0x04,
SM2_point_uncompressed_y_even = 0x06,
SM2_point_uncompressed_y_odd = 0x07,
};
int sm2_z256_point_from_x_bytes(SM2_Z256_POINT *P, const uint8_t x_bytes[32], int y_is_odd);
int sm2_z256_point_from_hash(SM2_Z256_POINT *R, const uint8_t *data, size_t datalen, int y_is_odd);
int sm2_z256_point_from_octets(SM2_Z256_POINT *P, const uint8_t *in, size_t inlen);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@@ -539,10 +539,17 @@ int sm2_fp_sqrt(SM2_Fp r, const SM2_Fp a)
SM2_BN u; SM2_BN u;
SM2_BN y; // temp result, prevent call sm2_fp_sqrt(a, a) SM2_BN y; // temp result, prevent call sm2_fp_sqrt(a, a)
printf("sm2_fp_sqrt\n");
sm2_bn_print(stderr, 0, 4, "a", a);
// r = a^((p + 1)/4) when p = 3 (mod 4) // r = a^((p + 1)/4) when p = 3 (mod 4)
sm2_bn_add(u, SM2_P, SM2_ONE); sm2_bn_add(u, SM2_P, SM2_ONE);
sm2_bn_rshift(u, u, 2); sm2_bn_rshift(u, u, 2);
sm2_bn_print(stderr, 0, 4, "u", u);
sm2_fp_exp(y, a, u); sm2_fp_exp(y, a, u);
sm2_bn_print(stderr, 0, 4, "y", y);
// check r^2 == a // check r^2 == a
sm2_fp_sqr(u, y); sm2_fp_sqr(u, y);
@@ -1087,6 +1094,7 @@ int sm2_point_is_at_infinity(const SM2_POINT *P)
return mem_is_zero((uint8_t *)P, sizeof(SM2_POINT)); return mem_is_zero((uint8_t *)P, sizeof(SM2_POINT));
} }
// 这个函数和 sm2_z256_point_from_x_bytes 不一样
int sm2_point_from_x(SM2_POINT *P, const uint8_t x[32], int y) int sm2_point_from_x(SM2_POINT *P, const uint8_t x[32], int y)
{ {
SM2_BN _x, _y, _g, _z; SM2_BN _x, _y, _g, _z;

View File

@@ -14,6 +14,7 @@
#include <stdlib.h> #include <stdlib.h>
#include <gmssl/mem.h> #include <gmssl/mem.h>
#include <gmssl/sm2.h> #include <gmssl/sm2.h>
#include <gmssl/sm2_z256.h>
#include <gmssl/sm3.h> #include <gmssl/sm3.h>
#include <gmssl/asn1.h> #include <gmssl/asn1.h>
#include <gmssl/error.h> #include <gmssl/error.h>
@@ -22,11 +23,19 @@
int sm2_sign_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_t idlen) int sm2_sign_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_t idlen)
{ {
size_t i;
if (!ctx || !key) { if (!ctx || !key) {
error_print(); error_print();
return -1; return -1;
} }
ctx->key = *key; ctx->key = *key;
// d' = (d + 1)^-1 (mod n)
sm2_z256_from_bytes(ctx->sign_key, key->private_key);
sm2_z256_modn_add(ctx->sign_key, ctx->sign_key, sm2_z256_one());
sm2_z256_modn_inv(ctx->sign_key, ctx->sign_key);
sm3_init(&ctx->sm3_ctx); sm3_init(&ctx->sm3_ctx);
if (id) { if (id) {
@@ -38,6 +47,24 @@ int sm2_sign_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_t
sm2_compute_z(z, &key->public_key, id, idlen); sm2_compute_z(z, &key->public_key, id, idlen);
sm3_update(&ctx->sm3_ctx, z, sizeof(z)); sm3_update(&ctx->sm3_ctx, z, sizeof(z));
} }
ctx->inited_sm3_ctx = ctx->sm3_ctx;
// pre compute (k, x = [k]G.x)
for (i = 0; i < 32; i++) {
if (sm2_do_sign_pre_compute(ctx->pre_comp[i].k, ctx->pre_comp[i].x1) != 1) {
error_print();
return -1;
}
}
ctx->num_pre_comp = 32;
return 1;
}
int sm2_sign_ctx_reset(SM2_SIGN_CTX *ctx)
{
ctx->sm3_ctx = ctx->inited_sm3_ctx;
return 1; return 1;
} }
@@ -56,16 +83,39 @@ int sm2_sign_update(SM2_SIGN_CTX *ctx, const uint8_t *data, size_t datalen)
int sm2_sign_finish(SM2_SIGN_CTX *ctx, uint8_t *sig, size_t *siglen) int sm2_sign_finish(SM2_SIGN_CTX *ctx, uint8_t *sig, size_t *siglen)
{ {
uint8_t dgst[SM3_DIGEST_SIZE]; uint8_t dgst[SM3_DIGEST_SIZE];
SM2_SIGNATURE signature;
if (!ctx || !sig || !siglen) { if (!ctx || !sig || !siglen) {
error_print(); error_print();
return -1; return -1;
} }
sm3_finish(&ctx->sm3_ctx, dgst); sm3_finish(&ctx->sm3_ctx, dgst);
if (sm2_sign(&ctx->key, dgst, sig, siglen) != 1) {
if (ctx->num_pre_comp == 0) {
size_t i;
for (i = 0; i < 32; i++) {
if (sm2_do_sign_pre_compute(ctx->pre_comp[i].k, ctx->pre_comp[i].x1) != 1) {
error_print(); error_print();
return -1; return -1;
} }
}
ctx->num_pre_comp = 32;
}
ctx->num_pre_comp--;
if (sm2_do_sign_fast_ex(ctx->sign_key,
ctx->pre_comp[ctx->num_pre_comp].k, ctx->pre_comp[ctx->num_pre_comp].x1,
dgst, &signature) != 1) {
error_print();
return -1;
}
*siglen = 0;
if (sm2_signature_to_der(&signature, &sig, siglen) != 1) {
error_print();
return -1;
}
return 1; return 1;
} }
@@ -93,6 +143,9 @@ int sm2_verify_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_
} }
memset(ctx, 0, sizeof(*ctx)); memset(ctx, 0, sizeof(*ctx));
ctx->key.public_key = key->public_key; ctx->key.public_key = key->public_key;
sm2_z256_point_from_bytes(&ctx->public_key, (const uint8_t *)&key->public_key);
sm3_init(&ctx->sm3_ctx); sm3_init(&ctx->sm3_ctx);
if (id) { if (id) {
@@ -104,6 +157,9 @@ int sm2_verify_init(SM2_SIGN_CTX *ctx, const SM2_KEY *key, const char *id, size_
sm2_compute_z(z, &key->public_key, id, idlen); sm2_compute_z(z, &key->public_key, id, idlen);
sm3_update(&ctx->sm3_ctx, z, sizeof(z)); sm3_update(&ctx->sm3_ctx, z, sizeof(z));
} }
ctx->inited_sm3_ctx = ctx->sm3_ctx;
return 1; return 1;
} }
@@ -135,9 +191,6 @@ int sm2_verify_finish(SM2_SIGN_CTX *ctx, const uint8_t *sig, size_t siglen)
return 1; return 1;
} }
int sm2_encrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key) int sm2_encrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key)
{ {
if (!ctx || !sm2_key) { if (!ctx || !sm2_key) {

282
src/sm2_point.c Normal file
View File

@@ -0,0 +1,282 @@
/*
* Copyright 2014-2024 The GmSSL Project. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
#include <stdio.h>
#include <string.h>
#include <assert.h>
#include <gmssl/sm2.h>
#include <gmssl/sm2_z256.h>
#include <gmssl/mem.h>
#include <gmssl/asn1.h>
#include <gmssl/rand.h>
#include <gmssl/error.h>
#include <gmssl/endian.h>
int sm2_point_is_on_curve(const SM2_POINT *P)
{
SM2_Z256_POINT T;
sm2_z256_point_from_bytes(&T, (const uint8_t *)P);
if (sm2_z256_point_is_on_curve(&T) == 1) {
return 1;
} else {
return 0;
}
}
int sm2_point_is_at_infinity(const SM2_POINT *P)
{
return mem_is_zero((uint8_t *)P, sizeof(SM2_POINT));
}
int sm2_point_from_x(SM2_POINT *P, const uint8_t x[32], int y_is_odd)
{
SM2_Z256_POINT T;
if (sm2_z256_point_from_x_bytes(&T, x, y_is_odd) != 1) {
error_print();
return -1;
}
sm2_z256_point_to_bytes(&T, (uint8_t *)P);
return 1;
}
int sm2_point_from_xy(SM2_POINT *P, const uint8_t x[32], const uint8_t y[32])
{
memcpy(P->x, x, 32);
memcpy(P->y, y, 32);
return sm2_point_is_on_curve(P);
}
int sm2_point_add(SM2_POINT *R, const SM2_POINT *P, const SM2_POINT *Q)
{
SM2_Z256_POINT P_;
SM2_Z256_POINT Q_;
sm2_z256_point_from_bytes(&P_, (uint8_t *)P);
sm2_z256_point_from_bytes(&Q_, (uint8_t *)Q);
sm2_z256_point_add(&P_, &P_, &Q_);
sm2_z256_point_to_bytes(&P_, (uint8_t *)R);
return 1;
}
int sm2_point_sub(SM2_POINT *R, const SM2_POINT *P, const SM2_POINT *Q)
{
SM2_Z256_POINT P_;
SM2_Z256_POINT Q_;
sm2_z256_point_from_bytes(&P_, (uint8_t *)P);
sm2_z256_point_from_bytes(&Q_, (uint8_t *)Q);
sm2_z256_point_sub(&P_, &P_, &Q_);
sm2_z256_point_to_bytes(&P_, (uint8_t *)R);
return 1;
}
int sm2_point_neg(SM2_POINT *R, const SM2_POINT *P)
{
SM2_Z256_POINT P_;
sm2_z256_point_from_bytes(&P_, (uint8_t *)P);
sm2_z256_point_neg(&P_, &P_);
sm2_z256_point_to_bytes(&P_, (uint8_t *)R);
return 1;
}
int sm2_point_dbl(SM2_POINT *R, const SM2_POINT *P)
{
SM2_Z256_POINT P_;
sm2_z256_point_from_bytes(&P_, (uint8_t *)P);
sm2_z256_point_dbl(&P_, &P_);
sm2_z256_point_to_bytes(&P_, (uint8_t *)R);
return 1;
}
int sm2_point_mul(SM2_POINT *R, const uint8_t k[32], const SM2_POINT *P)
{
uint64_t _k[4];
SM2_Z256_POINT _P;
sm2_z256_from_bytes(_k, k);
sm2_z256_point_from_bytes(&_P, (uint8_t *)P);
sm2_z256_point_mul(&_P, _k, &_P);
sm2_z256_point_to_bytes(&_P, (uint8_t *)R);
memset(_k, 0, sizeof(_k));
return 1;
}
int sm2_point_mul_generator(SM2_POINT *R, const uint8_t k[32])
{
uint64_t _k[4];
SM2_Z256_POINT _R;
sm2_z256_from_bytes(_k, k);
sm2_z256_point_mul_generator(&_R, _k);
sm2_z256_point_to_bytes(&_R, (uint8_t *)R);
memset(_k, 0, sizeof(_k));
return 1;
}
int sm2_point_mul_sum(SM2_POINT *R, const uint8_t k[32], const SM2_POINT *P, const uint8_t s[32])
{
uint64_t _k[4];
SM2_Z256_POINT _P;
uint64_t _s[4];
sm2_z256_from_bytes(_k, k);
sm2_z256_point_from_bytes(&_P, (uint8_t *)P);
sm2_z256_from_bytes(_s, s);
sm2_z256_point_mul_sum(&_P, _k, &_P, _s);
sm2_z256_point_to_bytes(&_P, (uint8_t *)R);
memset(_k, 0, sizeof(_k));
memset(_s, 0, sizeof(_s));
return 1;
}
int sm2_point_print(FILE *fp, int fmt, int ind, const char *label, const SM2_POINT *P)
{
format_print(fp, fmt, ind, "%s\n", label);
ind += 4;
format_bytes(fp, fmt, ind, "x", P->x, 32);
format_bytes(fp, fmt, ind, "y", P->y, 32);
return 1;
}
void sm2_point_to_compressed_octets(const SM2_POINT *P, uint8_t out[33])
{
*out++ = (P->y[31] & 0x01) ? 0x03 : 0x02;
memcpy(out, P->x, 32);
}
void sm2_point_to_uncompressed_octets(const SM2_POINT *P, uint8_t out[65])
{
*out++ = 0x04;
memcpy(out, P, 64);
}
int sm2_z256_point_from_octets(SM2_Z256_POINT *P, const uint8_t *in, size_t inlen)
{
switch (*in) {
case SM2_point_at_infinity:
if (inlen != 1) {
error_print();
return -1;
}
sm2_z256_point_set_infinity(P);
break;
case SM2_point_compressed_y_even:
if (inlen != 33) {
error_print();
return -1;
}
if (sm2_z256_point_from_x_bytes(P, in + 1, 0) != 1) {
error_print();
return -1;
}
break;
case SM2_point_compressed_y_odd:
if (inlen != 33) {
error_print();
return -1;
}
if (sm2_z256_point_from_x_bytes(P, in + 1, 1) != 1) {
error_print();
return -1;
}
break;
case SM2_point_uncompressed:
if (inlen != 65) {
error_print();
return -1;
}
sm2_z256_point_from_bytes(P, in + 1);
if (sm2_z256_point_is_on_curve(P) != 1) {
error_print();
return -1;
}
break;
default:
error_print();
return -1;
}
return 1;
}
int sm2_point_from_octets(SM2_POINT *P, const uint8_t *in, size_t inlen)
{
if ((*in == 0x02 || *in == 0x03) && inlen == 33) {
if (sm2_point_from_x(P, in + 1, *in) != 1) {
error_print();
return -1;
}
} else if (*in == 0x04 && inlen == 65) {
if (sm2_point_from_xy(P, in + 1, in + 33) != 1) {
error_print();
return -1;
}
} else {
error_print();
return -1;
}
return 1;
}
int sm2_point_to_der(const SM2_POINT *P, uint8_t **out, size_t *outlen)
{
uint8_t octets[65];
if (!P) {
return 0;
}
sm2_point_to_uncompressed_octets(P, octets);
if (asn1_octet_string_to_der(octets, sizeof(octets), out, outlen) != 1) {
error_print();
return -1;
}
return 1;
}
int sm2_point_from_der(SM2_POINT *P, const uint8_t **in, size_t *inlen)
{
int ret;
const uint8_t *d;
size_t dlen;
if ((ret = asn1_octet_string_from_der(&d, &dlen, in, inlen)) != 1) {
if (ret < 0) error_print();
return ret;
}
if (dlen != 65) {
error_print();
return -1;
}
if (sm2_point_from_octets(P, d, dlen) != 1) {
error_print();
return -1;
}
return 1;
}
int sm2_point_from_hash(SM2_POINT *R, const uint8_t *data, size_t datalen)
{
return 1;
}

View File

@@ -52,6 +52,8 @@
#include <gmssl/rand.h> #include <gmssl/rand.h>
#include <gmssl/endian.h> #include <gmssl/endian.h>
#include <gmssl/sm2_z256.h> #include <gmssl/sm2_z256.h>
#include <gmssl/sm3.h>
/* /*
SM2 parameters SM2 parameters
@@ -71,7 +73,10 @@ const uint64_t *sm2_z256_one(void) {
return &SM2_Z256_ONE[0]; return &SM2_Z256_ONE[0];
} }
void sm2_z256_set_zero(uint64_t a[4])
{
a[0] = a[1] = a[2] = a[3] = 0;
}
int sm2_z256_rand_range(uint64_t r[4], const uint64_t range[4]) int sm2_z256_rand_range(uint64_t r[4], const uint64_t range[4])
{ {
@@ -161,6 +166,23 @@ uint64_t sm2_z256_is_zero(const uint64_t a[4])
is_zero(a[3]); is_zero(a[3]);
} }
void sm2_z256_rshift(uint64_t r[4], const uint64_t a[4], unsigned int nbits)
{
nbits &= 0x3f;
if (nbits) {
r[0] = a[0] >> nbits;
r[0] |= a[1] << (64 - nbits);
r[1] = a[1] >> nbits;
r[1] |= a[2] << (64 - nbits);
r[2] = a[2] >> nbits;
r[2] |= a[3] << (64 - nbits);
r[3] = a[3] >> nbits;
} else {
sm2_z256_copy(r, a);
}
}
uint64_t sm2_z256_add(uint64_t r[4], const uint64_t a[4], const uint64_t b[4]) uint64_t sm2_z256_add(uint64_t r[4], const uint64_t a[4], const uint64_t b[4])
{ {
uint64_t t, c = 0; uint64_t t, c = 0;
@@ -351,6 +373,9 @@ int sm2_z512_print(FILE *fp, int ind, int fmt, const char *label, const uint64_t
const uint64_t SM2_Z256_P[4] = { const uint64_t SM2_Z256_P[4] = {
0xffffffffffffffff, 0xffffffff00000000, 0xffffffffffffffff, 0xfffffffeffffffff, 0xffffffffffffffff, 0xffffffff00000000, 0xffffffffffffffff, 0xfffffffeffffffff,
}; };
// 注意这里 SM2_Z256_P[0] 和 SM2_Z256_P[2] 是特殊值,在汇编中可以根据这个特殊值做特定的实现
const uint64_t *sm2_z256_prime(void) { const uint64_t *sm2_z256_prime(void) {
return &SM2_Z256_P[0]; return &SM2_Z256_P[0];
@@ -362,6 +387,7 @@ const uint64_t SM2_Z256_NEG_P[4] = {
1, ((uint64_t)1 << 32) - 1, 0, ((uint64_t)1 << 32), 1, ((uint64_t)1 << 32) - 1, 0, ((uint64_t)1 << 32),
}; };
#ifndef ENABLE_SM2_Z256_ARMV8
void sm2_z256_modp_add(uint64_t r[4], const uint64_t a[4], const uint64_t b[4]) void sm2_z256_modp_add(uint64_t r[4], const uint64_t a[4], const uint64_t b[4])
{ {
uint64_t c; uint64_t c;
@@ -404,6 +430,11 @@ void sm2_z256_modp_mul_by_3(uint64_t r[4], const uint64_t a[4])
sm2_z256_modp_add(r, t, a); sm2_z256_modp_add(r, t, a);
} }
void sm2_z256_modp_neg(uint64_t r[4], const uint64_t a[4])
{
(void)sm2_z256_sub(r, SM2_Z256_P, a);
}
void sm2_z256_modp_div_by_2(uint64_t r[4], const uint64_t a[4]) void sm2_z256_modp_div_by_2(uint64_t r[4], const uint64_t a[4])
{ {
uint64_t c = 0; uint64_t c = 0;
@@ -422,11 +453,9 @@ void sm2_z256_modp_div_by_2(uint64_t r[4], const uint64_t a[4])
r[2] = (r[2] >> 1) | ((r[3] & 1) << 63); r[2] = (r[2] >> 1) | ((r[3] & 1) << 63);
r[3] = (r[3] >> 1) | ((c & 1) << 63); r[3] = (r[3] >> 1) | ((c & 1) << 63);
} }
#endif
void sm2_z256_modp_neg(uint64_t r[4], const uint64_t a[4]) // p' * p = -1 mod 2^256
{
(void)sm2_z256_sub(r, SM2_Z256_P, a);
}
// p' = -p^(-1) mod 2^256 // p' = -p^(-1) mod 2^256
// = fffffffc00000001fffffffe00000000ffffffff000000010000000000000001 // = fffffffc00000001fffffffe00000000ffffffff000000010000000000000001
@@ -435,10 +464,12 @@ const uint64_t SM2_Z256_P_PRIME[4] = {
0x0000000000000001, 0xffffffff00000001, 0xfffffffe00000000, 0xfffffffc00000001, 0x0000000000000001, 0xffffffff00000001, 0xfffffffe00000000, 0xfffffffc00000001,
}; };
// mont(1) (mod p) = 2^256 mod p = 2^256 - p // mont(1) (mod p) = 2^256 mod p = 2^256 - p
const uint64_t *SM2_Z256_MODP_MONT_ONE = SM2_Z256_NEG_P; const uint64_t *SM2_Z256_MODP_MONT_ONE = SM2_Z256_NEG_P;
// z = xy #ifndef ENABLE_SM2_Z256_ARMV8
// z = a*b
// c = (z + (z * p' mod 2^256) * p)/2^256 // c = (z + (z * p' mod 2^256) * p)/2^256
void sm2_z256_modp_mont_mul(uint64_t r[4], const uint64_t a[4], const uint64_t b[4]) void sm2_z256_modp_mont_mul(uint64_t r[4], const uint64_t a[4], const uint64_t b[4])
{ {
@@ -484,6 +515,24 @@ void sm2_z256_modp_mont_sqr(uint64_t r[4], const uint64_t a[4])
sm2_z256_modp_mont_mul(r, a, a); sm2_z256_modp_mont_mul(r, a, a);
} }
// mont(mont(a), 1) = aR * 1 * R^-1 (mod p) = a (mod p)
void sm2_z256_modp_from_mont(uint64_t r[4], const uint64_t a[4])
{
sm2_z256_modp_mont_mul(r, a, SM2_Z256_ONE);
}
// 2^512 (mod p)
const uint64_t SM2_Z256_2e512modp[4] = {
0x0000000200000003, 0x00000002ffffffff, 0x0000000100000001, 0x0000000400000002
};
// mont(a) = a * 2^256 (mod p) = mont_mul(a, 2^512 mod p)
void sm2_z256_modp_to_mont(const uint64_t a[4], uint64_t r[4])
{
sm2_z256_modp_mont_mul(r, a, SM2_Z256_2e512modp);
}
#endif
void sm2_z256_modp_mont_exp(uint64_t r[4], const uint64_t a[4], const uint64_t e[4]) void sm2_z256_modp_mont_exp(uint64_t r[4], const uint64_t a[4], const uint64_t e[4])
{ {
uint64_t t[4]; uint64_t t[4];
@@ -589,21 +638,30 @@ void sm2_z256_modp_mont_inv(uint64_t r[4], const uint64_t a[4])
sm2_z256_modp_mont_mul(r, a4, a5); sm2_z256_modp_mont_mul(r, a4, a5);
} }
// mont(mont(a), 1) = aR * 1 * R^-1 (mod p) = a (mod p) // (p+1)/4 = 3fffffffbfffffffffffffffffffffffffffffffc00000004000000000000000
void sm2_z256_modp_from_mont(uint64_t r[4], const uint64_t a[4]) const uint64_t SM2_Z256_SQRT_EXP[4] = {
{ 0x4000000000000000, 0xffffffffc0000000, 0xffffffffffffffff, 0x3fffffffbfffffff,
sm2_z256_modp_mont_mul(r, a, SM2_Z256_ONE);
}
// 2^512 (mod p)
const uint64_t SM2_Z256_2e512modp[4] = {
0x0000000200000003, 0x00000002ffffffff, 0x0000000100000001, 0x0000000400000002
}; };
// mont(a) = a * 2^256 (mod p) = mont_mul(a, 2^512 mod p) // -r (mod p), i.e. (p - r) is also a square root of a
void sm2_z256_modp_to_mont(const uint64_t a[4], uint64_t r[4]) int sm2_z256_modp_mont_sqrt(uint64_t r[4], const uint64_t a[4])
{ {
sm2_z256_modp_mont_mul(r, a, SM2_Z256_2e512modp); uint64_t a_[4];
uint64_t r_[4]; // temp result, prevent call sm2_fp_sqrt(a, a)
// r = a^((p + 1)/4) when p = 3 (mod 4)
sm2_z256_modp_mont_exp(r_, a, SM2_Z256_SQRT_EXP);
// check r^2 == a
sm2_z256_modp_mont_sqr(a_, r_);
if (sm2_z256_cmp(a_, a) != 0) {
// not every number has a square root, so it is not an error
// `sm2_z256_point_from_hash` need a non-negative return value
return 0;
}
sm2_z256_copy(r, r_);
return 1;
} }
int sm2_z256_modp_mont_print(FILE *fp, int ind, int fmt, const char *label, const uint64_t a[4]) int sm2_z256_modp_mont_print(FILE *fp, int ind, int fmt, const char *label, const uint64_t a[4])
@@ -621,6 +679,11 @@ const uint64_t SM2_Z256_N[4] = {
0x53bbf40939d54123, 0x7203df6b21c6052b, 0xffffffffffffffff, 0xfffffffeffffffff, 0x53bbf40939d54123, 0x7203df6b21c6052b, 0xffffffffffffffff, 0xfffffffeffffffff,
}; };
const uint64_t SM2_Z256_N_MINUS_ONE[4] = {
0x53bbf40939d54122, 0x7203df6b21c6052b, 0xffffffffffffffff, 0xfffffffeffffffff,
};
// 2^256 - n = 0x10000000000000000000000008dfc2094de39fad4ac440bf6c62abedd // 2^256 - n = 0x10000000000000000000000008dfc2094de39fad4ac440bf6c62abedd
const uint64_t SM2_Z256_NEG_N[4] = { const uint64_t SM2_Z256_NEG_N[4] = {
0xac440bf6c62abedd, 0x8dfc2094de39fad4, 0x0000000000000000, 0x0000000100000000, 0xac440bf6c62abedd, 0x8dfc2094de39fad4, 0x0000000000000000, 0x0000000100000000,
@@ -680,6 +743,10 @@ const uint64_t *sm2_z256_order(void) {
return &SM2_Z256_N[0]; return &SM2_Z256_N[0];
} }
const uint64_t *sm2_z256_order_minus_one(void) {
return &SM2_Z256_N_MINUS_ONE[0];
}
// mont(1) (mod n) = 2^256 - n // mont(1) (mod n) = 2^256 - n
const uint64_t *SM2_Z256_MODN_MONT_ONE = SM2_Z256_NEG_N; const uint64_t *SM2_Z256_MODN_MONT_ONE = SM2_Z256_NEG_N;
@@ -784,10 +851,45 @@ void sm2_z256_modn_exp(uint64_t r[4], const uint64_t a[4], const uint64_t e[4])
const uint64_t SM2_Z256_N_MINUS_TWO[4] = { const uint64_t SM2_Z256_N_MINUS_TWO[4] = {
0x53bbf40939d54121, 0x7203df6b21c6052b, 0xffffffffffffffff, 0xfffffffeffffffff, 0x53bbf40939d54121, 0x7203df6b21c6052b, 0xffffffffffffffff, 0xfffffffeffffffff,
}; };
// exp都是从高位开始的如果都是1的话那么就是都要平方和乘
void sm2_z256_modn_mont_inv(uint64_t r[4], const uint64_t a[4]) void sm2_z256_modn_mont_inv(uint64_t r[4], const uint64_t a[4])
{ {
sm2_z256_modn_mont_exp(r, a, SM2_Z256_N_MINUS_TWO); // expand sm2_z256_modn_mont_exp(r, a, SM2_Z256_N_MINUS_TWO)
uint64_t t[4];
uint64_t w;
int i;
int k = 0;
sm2_z256_copy(t, a);
for (i = 0; i < 30; i++) {
sm2_z256_modn_mont_sqr(t, t);
sm2_z256_modn_mont_mul(t, t, a);
}
sm2_z256_modn_mont_sqr(t, t);
for (i = 0; i < 96; i++) {
sm2_z256_modn_mont_sqr(t, t);
sm2_z256_modn_mont_mul(t, t, a);
}
w = SM2_Z256_N_MINUS_TWO[1];
for (i = 0; i < 64; i++) {
sm2_z256_modn_mont_sqr(t, t);
if (w & 0x8000000000000000) {
sm2_z256_modn_mont_mul(t, t, a);
}
w <<= 1;
}
w = SM2_Z256_N_MINUS_TWO[0];
for (i = 0; i < 64; i++) {
sm2_z256_modn_mont_sqr(t, t);
if (w & 0x8000000000000000) {
sm2_z256_modn_mont_mul(t, t, a);
}
w <<= 1;
}
sm2_z256_copy(r, t);
} }
void sm2_z256_modn_inv(uint64_t r[4], const uint64_t a[4]) void sm2_z256_modn_inv(uint64_t r[4], const uint64_t a[4])
@@ -805,7 +907,6 @@ void sm2_z256_modn_from_mont(uint64_t r[4], const uint64_t a[4])
sm2_z256_modn_mont_mul(r, a, SM2_Z256_ONE); sm2_z256_modn_mont_mul(r, a, SM2_Z256_ONE);
} }
// 2^512 (mod n) = 0x1eb5e412a22b3d3b620fc84c3affe0d43464504ade6fa2fa901192af7c114f20 // 2^512 (mod n) = 0x1eb5e412a22b3d3b620fc84c3affe0d43464504ade6fa2fa901192af7c114f20
const uint64_t SM2_Z256_2e512modn[4] = { const uint64_t SM2_Z256_2e512modn[4] = {
0x901192af7c114f20, 0x3464504ade6fa2fa, 0x620fc84c3affe0d4, 0x1eb5e412a22b3d3b, 0x901192af7c114f20, 0x3464504ade6fa2fa, 0x620fc84c3affe0d4, 0x1eb5e412a22b3d3b,
@@ -828,11 +929,29 @@ int sm2_z256_modn_mont_print(FILE *fp, int ind, int fmt, const char *label, cons
// Jacobian Point with Montgomery coordinates // Jacobian Point with Montgomery coordinates
void sm2_z256_point_set_infinity(SM2_Z256_POINT *P)
{
sm2_z256_copy(P->X, SM2_Z256_MODP_MONT_ONE);
sm2_z256_copy(P->Y, SM2_Z256_MODP_MONT_ONE);
sm2_z256_set_zero(P->Z);
}
// 这里还应该检查X == Y == mont(1) // point at infinity should be like (k^2 : k^3 : 0), k in [0, p-1]
int sm2_z256_point_is_at_infinity(const SM2_Z256_POINT *P) int sm2_z256_point_is_at_infinity(const SM2_Z256_POINT *P)
{ {
if (sm2_z256_is_zero(P->Z)) { if (sm2_z256_is_zero(P->Z)) {
uint64_t X_cub[4];
uint64_t Y_sqr[4];
sm2_z256_modp_mont_sqr(X_cub, P->X);
sm2_z256_modp_mont_mul(X_cub, X_cub, P->X);
sm2_z256_modp_mont_sqr(Y_sqr, P->Y);
if (sm2_z256_cmp(X_cub, Y_sqr) != 0) {
error_print();
return 0;
}
return 1; return 1;
} else { } else {
return 0; return 0;
@@ -907,6 +1026,34 @@ void sm2_z256_point_get_xy(const SM2_Z256_POINT *P, uint64_t x[4], uint64_t y[4]
} }
} }
// impl with modified jacobian coordinates
void sm2_z256_point_dbl_x5(SM2_Z256_POINT *R, const SM2_Z256_POINT *A)
{
sm2_z256_point_dbl(R, A);
sm2_z256_point_dbl(R, R);
sm2_z256_point_dbl(R, R);
sm2_z256_point_dbl(R, R);
sm2_z256_point_dbl(R, R);
}
void sm2_z256_point_multi_dbl(SM2_Z256_POINT *R, const SM2_Z256_POINT *P, unsigned int i)
{
const uint64_t *X1 = P->X;
const uint64_t *Y1 = P->Y;
const uint64_t *Z1 = P->Z;
uint64_t *X3 = R->X;
uint64_t *Y3 = R->Y;
uint64_t *Z3 = R->Z;
uint64_t A[4];
uint64_t B[4];
uint64_t C[4];
uint64_t D[4];
uint64_t E[4];
// A = Z1^2
}
void sm2_z256_point_dbl(SM2_Z256_POINT *R, const SM2_Z256_POINT *A) void sm2_z256_point_dbl(SM2_Z256_POINT *R, const SM2_Z256_POINT *A)
{ {
const uint64_t *X1 = A->X; const uint64_t *X1 = A->X;
@@ -922,77 +1069,75 @@ void sm2_z256_point_dbl(SM2_Z256_POINT *R, const SM2_Z256_POINT *A)
// S = 2*Y1 // S = 2*Y1
sm2_z256_modp_mul_by_2(S, Y1); sm2_z256_modp_mul_by_2(S, Y1);
//sm2_z256_modp_mont_print(stderr, 0, 0, "1", S);
// Zsqr = Z1^2 // Zsqr = Z1^2
sm2_z256_modp_mont_sqr(Zsqr, Z1); sm2_z256_modp_mont_sqr(Zsqr, Z1);
//sm2_z256_modp_mont_print(stderr, 0, 0, "2", Zsqr);
// S = S^2 = 4*Y1^2 // S = S^2 = 4*Y1^2
sm2_z256_modp_mont_sqr(S, S); sm2_z256_modp_mont_sqr(S, S);
//sm2_z256_modp_mont_print(stderr, 0, 0, "3", S);
// Z3 = Z1 * Y1 // Z3 = Z1 * Y1
sm2_z256_modp_mont_mul(Z3, Z1, Y1); sm2_z256_modp_mont_mul(Z3, Z1, Y1);
//sm2_z256_modp_mont_print(stderr, 0, 0, "4", Z3);
// Z3 = 2 * Z3 = 2*Y1*Z1 // Z3 = 2 * Z3 = 2*Y1*Z1
sm2_z256_modp_mul_by_2(Z3, Z3); sm2_z256_modp_mul_by_2(Z3, Z3);
//sm2_z256_modp_mont_print(stderr, 0, 0, "5", Z3);
// M = X1 + Zsqr = X1 + Z1^2 // M = X1 + Zsqr = X1 + Z1^2
sm2_z256_modp_add(M, X1, Zsqr); sm2_z256_modp_add(M, X1, Zsqr);
//sm2_z256_modp_mont_print(stderr, 0, 0, "6", M);
// Zsqr = X1 - Zsqr = X1 - Z1^2 // Zsqr = X1 - Zsqr = X1 - Z1^2
sm2_z256_modp_sub(Zsqr, X1, Zsqr); sm2_z256_modp_sub(Zsqr, X1, Zsqr);
//sm2_z256_modp_mont_print(stderr, 0, 0, "7", Zsqr);
// Y3 = S^2 = 16 * Y1^4 // Y3 = S^2 = 16 * Y1^4
sm2_z256_modp_mont_sqr(Y3, S); sm2_z256_modp_mont_sqr(Y3, S);
//sm2_z256_modp_mont_print(stderr, 0, 0, "8", Y3);
// Y3 = Y3/2 = 8 * Y1^4 // Y3 = Y3/2 = 8 * Y1^4
sm2_z256_modp_div_by_2(Y3, Y3); sm2_z256_modp_div_by_2(Y3, Y3);
//sm2_z256_modp_mont_print(stderr, 0, 0, "9", Y3);
// M = M * Zsqr = (X1 + Z1^2)(X1 - Z1^2) // M = M * Zsqr = (X1 + Z1^2)(X1 - Z1^2)
sm2_z256_modp_mont_mul(M, M, Zsqr); sm2_z256_modp_mont_mul(M, M, Zsqr);
//sm2_z256_modp_mont_print(stderr, 0, 0, "10", M);
// M = 3*M = 3(X1 + Z1^2)(X1 - Z1^2) // M = 3*M = 3(X1 + Z1^2)(X1 - Z1^2)
sm2_z256_modp_mul_by_3(M, M); sm2_z256_modp_mul_by_3(M, M);
//sm2_z256_modp_mont_print(stderr, 0, 0, "11", M);
// S = S * X1 = 4 * X1 * Y1^2 // S = S * X1 = 4 * X1 * Y1^2
sm2_z256_modp_mont_mul(S, S, X1); sm2_z256_modp_mont_mul(S, S, X1);
//sm2_z256_modp_mont_print(stderr, 0, 0, "12", S);
// tmp0 = 2 * S = 8 * X1 * Y1^2 // tmp0 = 2 * S = 8 * X1 * Y1^2
sm2_z256_modp_mul_by_2(tmp0, S); sm2_z256_modp_mul_by_2(tmp0, S);
//sm2_z256_modp_mont_print(stderr, 0, 0, "13", tmp0);
// X3 = M^2 = (3(X1 + Z1^2)(X1 - Z1^2))^2 // X3 = M^2 = (3(X1 + Z1^2)(X1 - Z1^2))^2
sm2_z256_modp_mont_sqr(X3, M); sm2_z256_modp_mont_sqr(X3, M);
//sm2_z256_modp_mont_print(stderr, 0, 0, "14", X3);
// X3 = X3 - tmp0 = (3(X1 + Z1^2)(X1 - Z1^2))^2 - 8 * X1 * Y1^2 // X3 = X3 - tmp0 = (3(X1 + Z1^2)(X1 - Z1^2))^2 - 8 * X1 * Y1^2
sm2_z256_modp_sub(X3, X3, tmp0); sm2_z256_modp_sub(X3, X3, tmp0);
//sm2_z256_modp_mont_print(stderr, 0, 0, "15", X3);
// S = S - X3 = 4 * X1 * Y1^2 - X3 // S = S - X3 = 4 * X1 * Y1^2 - X3
sm2_z256_modp_sub(S, S, X3); sm2_z256_modp_sub(S, S, X3);
//sm2_z256_modp_mont_print(stderr, 0, 0, "16", S);
// S = S * M = 3(X1 + Z1^2)(X1 - Z1^2)(4 * X1 * Y1^2 - X3) // S = S * M = 3(X1 + Z1^2)(X1 - Z1^2)(4 * X1 * Y1^2 - X3)
sm2_z256_modp_mont_mul(S, S, M); sm2_z256_modp_mont_mul(S, S, M);
//sm2_z256_modp_mont_print(stderr, 0, 0, "17", S);
// Y3 = S - Y3 = 3(X1 + Z1^2)(X1 - Z1^2)(4 * X1 * Y1^2 - X3) - 8 * Y1^4 // Y3 = S - Y3 = 3(X1 + Z1^2)(X1 - Z1^2)(4 * X1 * Y1^2 - X3) - 8 * Y1^4
sm2_z256_modp_sub(Y3, S, Y3); sm2_z256_modp_sub(Y3, S, Y3);
//sm2_z256_modp_mont_print(stderr, 0, 0, "18", Y3);
} }
/*
(X1:Y1:Z1) + (X2:Y2:Z2) => (X3:Y3:Z3)
A = Y2 * Z1^3 - Y1 * Z2^3
B = X2 * Z1^2 - X1 * Z2^2
X3 = A^2 - B^2 * (X2 * Z1^2 + X1 * Z2^2)
= A^2 - B^3 - 2 * B^2 * X1 * Z2^2
Y3 = A * (X1 * B^2 * Z2^2 - X3) - Y1 * B^3 * Z2^3
Z3 = B * Z1 * Z2
P + (-P) = (X:Y:Z) + (k^2*X : k^3*Y : k*Z) => (0:0:0)
感觉点加也有很好的并行性
*/
void sm2_z256_point_add(SM2_Z256_POINT *r, const SM2_Z256_POINT *a, const SM2_Z256_POINT *b) void sm2_z256_point_add(SM2_Z256_POINT *r, const SM2_Z256_POINT *a, const SM2_Z256_POINT *b)
{ {
uint64_t U2[4], S2[4]; uint64_t U2[4], S2[4];
@@ -1028,6 +1173,7 @@ void sm2_z256_point_add(SM2_Z256_POINT *r, const SM2_Z256_POINT *a, const SM2_Z2
in1infty = is_zero(in1infty); in1infty = is_zero(in1infty);
in2infty = is_zero(in2infty); in2infty = is_zero(in2infty);
// 这里很明显有极好的并行性
sm2_z256_modp_mont_sqr(Z2sqr, in2_z); /* Z2^2 */ sm2_z256_modp_mont_sqr(Z2sqr, in2_z); /* Z2^2 */
sm2_z256_modp_mont_sqr(Z1sqr, in1_z); /* Z1^2 */ sm2_z256_modp_mont_sqr(Z1sqr, in1_z); /* Z1^2 */
@@ -1057,11 +1203,13 @@ void sm2_z256_point_add(SM2_Z256_POINT *r, const SM2_Z256_POINT *a, const SM2_Z2
sm2_z256_modp_mont_sqr(Rsqr, R); /* R^2 */ sm2_z256_modp_mont_sqr(Rsqr, R); /* R^2 */
sm2_z256_modp_mont_mul(res_z, H, in1_z); /* Z3 = H*Z1*Z2 */ sm2_z256_modp_mont_mul(res_z, H, in1_z); /* Z3 = H*Z1*Z2 */
sm2_z256_modp_mont_sqr(Hsqr, H); /* H^2 */ sm2_z256_modp_mont_sqr(Hsqr, H); /* H^2 */
sm2_z256_modp_mont_mul(res_z, res_z, in2_z); /* Z3 = H*Z1*Z2 */ sm2_z256_modp_mont_mul(res_z, res_z, in2_z); /* Z3 = H*Z1*Z2 */
sm2_z256_modp_mont_mul(Hcub, Hsqr, H); /* H^3 */
sm2_z256_modp_mont_mul(Hcub, Hsqr, H); /* H^3 */
sm2_z256_modp_mont_mul(U2, U1, Hsqr); /* U1*H^2 */ sm2_z256_modp_mont_mul(U2, U1, Hsqr); /* U1*H^2 */
sm2_z256_modp_mul_by_2(Hsqr, U2); /* 2*U1*H^2 */ sm2_z256_modp_mul_by_2(Hsqr, U2); /* 2*U1*H^2 */
sm2_z256_modp_sub(res_x, Rsqr, Hsqr); sm2_z256_modp_sub(res_x, Rsqr, Hsqr);
@@ -1071,6 +1219,7 @@ void sm2_z256_point_add(SM2_Z256_POINT *r, const SM2_Z256_POINT *a, const SM2_Z2
sm2_z256_modp_mont_mul(S2, S1, Hcub); sm2_z256_modp_mont_mul(S2, S1, Hcub);
sm2_z256_modp_mont_mul(res_y, R, res_y); sm2_z256_modp_mont_mul(res_y, R, res_y);
sm2_z256_modp_sub(res_y, res_y, S2); sm2_z256_modp_sub(res_y, res_y, S2);
sm2_z256_copy_conditional(res_x, in2_x, in1infty); sm2_z256_copy_conditional(res_x, in2_x, in1infty);
@@ -1093,7 +1242,6 @@ void sm2_z256_point_neg(SM2_Z256_POINT *R, const SM2_Z256_POINT *P)
sm2_z256_copy(R->Z, P->Z); sm2_z256_copy(R->Z, P->Z);
} }
// point_mul 中用到
void sm2_z256_point_sub(SM2_Z256_POINT *R, const SM2_Z256_POINT *A, const SM2_Z256_POINT *B) void sm2_z256_point_sub(SM2_Z256_POINT *R, const SM2_Z256_POINT *A, const SM2_Z256_POINT *B)
{ {
SM2_Z256_POINT neg_B; SM2_Z256_POINT neg_B;
@@ -1109,8 +1257,28 @@ void sm2_z256_point_mul(SM2_Z256_POINT *R, const uint64_t k[4], const SM2_Z256_P
int n = (256 + window_size - 1)/window_size; int n = (256 + window_size - 1)/window_size;
int i; int i;
// 这相当于做了一个预计算表
/*
P 2P 4P 8P // 这实际上是一个连续的dbl
3P 6P, 12P
5P, 10P,
7P, 14P
15P
...
// 如果一次能并行计算4组点加法那么这部分与计算表的计算量可以降低
// 这个连续计算中dbl的数量越多计算量越低
*/
// T[i] = (i + 1) * P // T[i] = (i + 1) * P
memcpy(&T[0], P, sizeof(SM2_Z256_POINT)); memcpy(&T[0], P, sizeof(SM2_Z256_POINT));
// 这个计算大概是有并行能力的!
/*
sm2_z256_point_dbl(&T[ 1], &T[ 0]); sm2_z256_point_dbl(&T[ 1], &T[ 0]);
sm2_z256_point_add(&T[ 2], &T[ 1], P); sm2_z256_point_add(&T[ 2], &T[ 1], P);
sm2_z256_point_dbl(&T[ 3], &T[ 1]); sm2_z256_point_dbl(&T[ 3], &T[ 1]);
@@ -1126,6 +1294,24 @@ void sm2_z256_point_mul(SM2_Z256_POINT *R, const uint64_t k[4], const SM2_Z256_P
sm2_z256_point_dbl(&T[13], &T[ 6]); sm2_z256_point_dbl(&T[13], &T[ 6]);
sm2_z256_point_add(&T[14], &T[13], P); sm2_z256_point_add(&T[14], &T[13], P);
sm2_z256_point_dbl(&T[15], &T[ 7]); sm2_z256_point_dbl(&T[15], &T[ 7]);
*/
sm2_z256_point_dbl(&T[2-1], &T[1-1]);
sm2_z256_point_dbl(&T[4-1], &T[2-1]);
sm2_z256_point_dbl(&T[8-1], &T[4-1]);
sm2_z256_point_dbl(&T[16-1], &T[8-1]);
sm2_z256_point_add(&T[3-1], &T[2-1], P);
sm2_z256_point_dbl(&T[6-1], &T[3-1]);
sm2_z256_point_dbl(&T[12-1], &T[6-1]);
sm2_z256_point_add(&T[5-1], &T[3-1], &T[2-1]);
sm2_z256_point_dbl(&T[10-1], &T[5-1]);
sm2_z256_point_add(&T[7-1], &T[4-1], &T[3-1]);
sm2_z256_point_dbl(&T[14-1], &T[7-1]);
sm2_z256_point_add(&T[9-1], &T[4-1], &T[5-1]);
sm2_z256_point_add(&T[11-1], &T[6-1], &T[5-1]);
sm2_z256_point_add(&T[13-1], &T[7-1], &T[6-1]);
sm2_z256_point_add(&T[15-1], &T[8-1], &T[7-1]);
for (i = n - 1; i >= 0; i--) { for (i = n - 1; i >= 0; i--) {
int booth = sm2_z256_get_booth(k, window_size, i); int booth = sm2_z256_get_booth(k, window_size, i);
@@ -1136,11 +1322,9 @@ void sm2_z256_point_mul(SM2_Z256_POINT *R, const uint64_t k[4], const SM2_Z256_P
R_infinity = 0; R_infinity = 0;
} }
} else { } else {
sm2_z256_point_dbl(R, R); // 这个重复dbl的计算可以适当降低吗
sm2_z256_point_dbl(R, R); // 这说明对dbl的优化还是很有意义的因为这里面dbl的数量最多
sm2_z256_point_dbl(R, R); sm2_z256_point_dbl_x5(R, R);
sm2_z256_point_dbl(R, R);
sm2_z256_point_dbl(R, R);
if (booth > 0) { if (booth > 0) {
sm2_z256_point_add(R, R, &T[booth - 1]); sm2_z256_point_add(R, R, &T[booth - 1]);
@@ -1177,6 +1361,8 @@ void sm2_z256_point_copy_affine(SM2_Z256_POINT *R, const SM2_Z256_POINT_AFFINE *
sm2_z256_copy(R->Z, SM2_Z256_MODP_MONT_ONE); sm2_z256_copy(R->Z, SM2_Z256_MODP_MONT_ONE);
} }
// 这是一个比较容易并行的算法
// r, a, b 都转换为实际输入的值
void sm2_z256_point_add_affine(SM2_Z256_POINT *r, const SM2_Z256_POINT *a, const SM2_Z256_POINT_AFFINE *b) void sm2_z256_point_add_affine(SM2_Z256_POINT *r, const SM2_Z256_POINT *a, const SM2_Z256_POINT_AFFINE *b)
{ {
uint64_t U2[4], S2[4]; uint64_t U2[4], S2[4];
@@ -1287,44 +1473,60 @@ int sm2_z256_point_affine_print(FILE *fp, int fmt, int ind, const char *label, c
extern const uint64_t sm2_z256_pre_comp[37][64 * 4 * 2]; extern const uint64_t sm2_z256_pre_comp[37][64 * 4 * 2];
static SM2_Z256_POINT_AFFINE (*g_pre_comp)[64] = (SM2_Z256_POINT_AFFINE (*)[64])sm2_z256_pre_comp; static SM2_Z256_POINT_AFFINE (*g_pre_comp)[64] = (SM2_Z256_POINT_AFFINE (*)[64])sm2_z256_pre_comp;
/*
这个函数的粗粒度并行算法
输出的R应该有多个输入的k也有多个
轮数是一样的
需要用一个数组表示这个值是否还是无穷远点
在签名、加密的时候参与计算的k都是秘密值因此需要考虑cache攻击的问题
但是在验签的时候其中s*G计算其中s是公开值因此不需要考虑cache攻击
应该提供一个专用的常量时间的gather函数
*/
void sm2_z256_point_mul_generator(SM2_Z256_POINT *R, const uint64_t k[4]) void sm2_z256_point_mul_generator(SM2_Z256_POINT *R, const uint64_t k[4])
{ {
size_t window_size = 7; size_t window_size = 7;
int R_infinity = 1; int R_infinity = 1; // 开始的时候点
int n = (256 + window_size - 1)/window_size; int n = (256 + window_size - 1)/window_size;
int i; int i;
for (i = n - 1; i >= 0; i--) { for (i = n - 1; i >= 0; i--) {
int booth = sm2_z256_get_booth(k, window_size, i); int booth = sm2_z256_get_booth(k, window_size, i);
// 下面的计算应该改为并行化
if (R_infinity) { if (R_infinity) {
if (booth != 0) { if (booth != 0) {
sm2_z256_point_copy_affine(R, &g_pre_comp[i][booth - 1]); sm2_z256_point_copy_affine(R, &g_pre_comp[i][booth - 1]);
R_infinity = 0; R_infinity = 0;
} }
} else { } else {
// 可以先把那个点从内存复制到当前空间中
// 如果booth < 0则把这个点改为 -P
// 然后再加上这个点,得到一个新的结果
if (booth > 0) { if (booth > 0) {
sm2_z256_point_add_affine(R, R, &g_pre_comp[i][booth - 1]); sm2_z256_point_add_affine(R, R, &g_pre_comp[i][booth - 1]);
} else if (booth < 0) { } else if (booth < 0) {
sm2_z256_point_sub_affine(R, R, &g_pre_comp[i][-booth - 1]); sm2_z256_point_sub_affine(R, R, &g_pre_comp[i][-booth - 1]);
} }
// booth == 0的时候意味应该加入的affine是一个无穷远点
// 如果是无穷远点,读入的值,以及计算结果就没有用了。
} }
} }
if (R_infinity) { if (R_infinity) {
memset(R, 0, sizeof(*R)); sm2_z256_point_set_infinity(R);
} }
} }
// R = t*P + s*G // R = t*P + s*G
void sm2_z256_point_mul_sum(SM2_Z256_POINT *R, const uint64_t t[4], const SM2_Z256_POINT *P, const uint64_t s[4]) void sm2_z256_point_mul_sum(SM2_Z256_POINT *R, const uint64_t t[4], const SM2_Z256_POINT *P, const uint64_t s[4])
{ {
@@ -1334,8 +1536,6 @@ void sm2_z256_point_mul_sum(SM2_Z256_POINT *R, const uint64_t t[4], const SM2_Z2
sm2_z256_point_add(R, R, &Q); sm2_z256_point_add(R, R, &Q);
} }
// 这个是否要检查点是否在曲线上?
void sm2_z256_point_from_bytes(SM2_Z256_POINT *P, const uint8_t in[64]) void sm2_z256_point_from_bytes(SM2_Z256_POINT *P, const uint8_t in[64])
{ {
sm2_z256_from_bytes(P->X, in); sm2_z256_from_bytes(P->X, in);
@@ -1364,6 +1564,35 @@ void sm2_z256_point_to_bytes(const SM2_Z256_POINT *P, uint8_t out[64])
sm2_z256_to_bytes(y, out + 32); sm2_z256_to_bytes(y, out + 32);
} }
int sm2_z256_point_equ(const SM2_Z256_POINT *P, const SM2_Z256_POINT *Q)
{
uint64_t Z1[4] = {0};
uint64_t Z2[4] = {0};
uint64_t V1[4] = {0};
uint64_t V2[4] = {0};
// X1 * Z2^2 == X2 * Z1^2
sm2_z256_modp_mont_sqr(Z1, P->Z);
sm2_z256_modp_mont_sqr(Z2, Q->Z);
sm2_z256_modp_mont_mul(V1, P->X, Z2);
sm2_z256_modp_mont_mul(V2, Q->X, Z1);
if (sm2_z256_cmp(V1, V2) != 0) {
error_print();
return 0;
}
// Y1 * Z2^3 == Y2 * Z1^3
sm2_z256_modp_mont_mul(Z1, Z1, P->Z);
sm2_z256_modp_mont_mul(Z2, Z2, Q->Z);
sm2_z256_modp_mont_mul(V1, P->Y, Z2);
sm2_z256_modp_mont_mul(V2, Q->Y, Z1);
if (sm2_z256_cmp(V1, V2) != 0) {
error_print();
return 0;
}
return 1;
}
int sm2_z256_point_equ_hex(const SM2_Z256_POINT *P, const char *hex) int sm2_z256_point_equ_hex(const SM2_Z256_POINT *P, const char *hex)
{ {
@@ -1379,8 +1608,96 @@ int sm2_z256_point_equ_hex(const SM2_Z256_POINT *P, const char *hex)
return 0; return 0;
} }
return 1; return 1;
} }
int sm2_z256_is_odd(const uint64_t a[4])
{
return a[0] & 0x01;
}
int sm2_z256_point_from_x_bytes(SM2_Z256_POINT *P, const uint8_t x_bytes[32], int y_is_odd)
{
uint64_t x[4];
uint64_t y_sqr[4];
uint64_t y[4];
int ret;
uint64_t SM2_Z256_MODP_MONT_THREE[4] = { 3,0,0,0 };
sm2_z256_modp_to_mont(SM2_Z256_MODP_MONT_THREE, SM2_Z256_MODP_MONT_THREE);
sm2_z256_from_bytes(x, x_bytes);
if (sm2_z256_cmp(x, SM2_Z256_P) >= 0) {
error_print();
return -1;
}
sm2_z256_modp_to_mont(x, x);
sm2_z256_copy(P->X, x);
// y^2 = x^3 - 3x + b = (x^2 - 3)*x + b
sm2_z256_modp_mont_sqr(y_sqr, x);
sm2_z256_modp_sub(y_sqr, y_sqr, SM2_Z256_MODP_MONT_THREE);
sm2_z256_modp_mont_mul(y_sqr, y_sqr, x);
sm2_z256_modp_add(y_sqr, y_sqr, SM2_Z256_MODP_MONT_B);
// y = sqrt(y^2)
if ((ret = sm2_z256_modp_mont_sqrt(y, y_sqr)) != 1) {
if (ret < 0) error_print();
return ret;
}
sm2_z256_copy(P->Y , y); // mont(y)
sm2_z256_modp_from_mont(y, y);
if (y_is_odd) {
if (!sm2_z256_is_odd(y)) {
sm2_z256_modp_neg(P->Y, P->Y);
}
} else {
if (sm2_z256_is_odd(y)) {
sm2_z256_modp_neg(P->Y, P->Y);
}
}
sm2_z256_copy(P->Z, SM2_Z256_MODP_MONT_ONE);
return 1;
}
int sm2_z256_point_from_hash(SM2_Z256_POINT *R, const uint8_t *data, size_t datalen, int y_is_odd)
{
uint64_t x[4];
uint8_t x_bytes[32];
uint8_t dgst[32];
int ret;
do {
// x = sm3(data) mod p
sm3_digest(data, datalen, dgst);
sm2_z256_from_bytes(x, dgst);
if (sm2_z256_cmp(x, SM2_Z256_P) >= 0) {
sm2_z256_sub(x, x, SM2_Z256_P);
}
sm2_z256_to_bytes(x, x_bytes);
// compute y
if ((ret = sm2_z256_point_from_x_bytes(R, x_bytes, y_is_odd)) == 1) {
break;
}
if (ret < 0) {
error_print();
return -1;
}
// data = sm3(data), try again
data = dgst;
datalen = sizeof(dgst);
} while (1);
return 1;
}

706
src/sm2_z256_key.c Normal file
View File

@@ -0,0 +1,706 @@
/*
* Copyright 2014-2024 The GmSSL Project. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
#include <string.h>
#include <gmssl/sm2_z256.h>
#include <gmssl/oid.h>
#include <gmssl/asn1.h>
#include <gmssl/pem.h>
#include <gmssl/sm4.h>
#include <gmssl/rand.h>
#include <gmssl/pbkdf2.h>
#include <gmssl/pkcs8.h>
#include <gmssl/error.h>
#include <gmssl/ec.h>
#include <gmssl/mem.h>
#include <gmssl/x509_alg.h>
int sm2_key_generate(SM2_KEY *key)
{
uint64_t d[4];
SM2_Z256_POINT P;
if (!key) {
error_print();
return -1;
}
do {
if (sm2_z256_rand_range(d, sm2_z256_order_minus_one()) != 1) {
error_print();
return -1;
}
} while (sm2_z256_is_zero(d));
sm2_z256_point_mul_generator(&P, d);
sm2_z256_to_bytes(d, key->private_key);
sm2_z256_point_to_bytes(&P, (uint8_t *)&key->public_key);
gmssl_secure_clear(d, sizeof(d));
return 1;
}
int sm2_key_set_private_key(SM2_KEY *key, const uint8_t private_key[32])
{
uint64_t d[4];
SM2_Z256_POINT P;
int ret = -1;
if (!key || !private_key) {
error_print();
return -1;
}
sm2_z256_from_bytes(d, private_key);
if (sm2_z256_is_zero(d)) {
error_print();
goto end;
}
if (sm2_z256_cmp(d, sm2_z256_order_minus_one()) >= 0) {
error_print();
goto end;
}
sm2_z256_point_mul_generator(&P, d);
sm2_z256_to_bytes(d, key->private_key);
sm2_z256_point_to_bytes(&P, (uint8_t *)&key->public_key);
ret = 1;
end:
gmssl_secure_clear(d, sizeof(d));
return ret;
}
int sm2_key_set_public_key(SM2_KEY *key, const SM2_POINT *public_key)
{
uint64_t d[4] = {0};
SM2_Z256_POINT P;
if (!key || !public_key) {
error_print();
return -1;
}
sm2_z256_point_from_bytes(&P, (uint8_t *)public_key);
if (sm2_z256_point_is_on_curve(&P) != 1) {
error_print();
return -1;
}
sm2_z256_to_bytes(d, key->private_key);
sm2_z256_point_to_bytes(&P, (uint8_t *)&key->public_key);
return 1;
}
int sm2_key_print(FILE *fp, int fmt, int ind, const char *label, const SM2_KEY *key)
{
format_print(fp, fmt, ind, "%s\n", label);
ind += 4;
sm2_public_key_print(fp, fmt, ind, "publicKey", key);
format_bytes(fp, fmt, ind, "privateKey", key->private_key, 32);
return 1;
}
int sm2_public_key_to_der(const SM2_KEY *key, uint8_t **out, size_t *outlen)
{
uint8_t buf[65];
size_t len = 0;
if (!key) {
return 0;
}
sm2_point_to_uncompressed_octets(&key->public_key, buf);
if (asn1_bit_octets_to_der(buf, sizeof(buf), out, outlen) != 1) {
error_print();
return -1;
}
return 1;
}
int sm2_public_key_from_der(SM2_KEY *key, const uint8_t **in, size_t *inlen)
{
int ret;
const uint8_t *d;
size_t dlen;
SM2_POINT P;
if ((ret = asn1_bit_octets_from_der(&d, &dlen, in, inlen)) != 1) {
if (ret < 0) error_print();
return ret;
}
if (dlen != 65) {
error_print();
return -1;
}
// 这里不太对SM2_POINT 被反复检查了
if (sm2_point_from_octets(&P, d, dlen) != 1) {
error_print();
return -1;
}
if (sm2_key_set_public_key(key, &P) != 1) {
error_print();
return -1;
}
return 1;
}
int sm2_public_key_print(FILE *fp, int fmt, int ind, const char *label, const SM2_KEY *pub_key)
{
return sm2_point_print(fp, fmt, ind, label, &pub_key->public_key);
}
int sm2_public_key_algor_to_der(uint8_t **out, size_t *outlen)
{
if (x509_public_key_algor_to_der(OID_ec_public_key, OID_sm2, out, outlen) != 1) {
error_print();
return -1;
}
return 1;
}
int sm2_public_key_algor_from_der(const uint8_t **in, size_t *inlen)
{
int ret;
int oid;
int curve;
if ((ret = x509_public_key_algor_from_der(&oid, &curve, in, inlen)) != 1) {
if (ret < 0) error_print();
return ret;
}
if (oid != OID_ec_public_key) {
error_print();
return -1;
}
if (curve != OID_sm2) {
error_print();
return -1;
}
return 1;
}
#define SM2_PRIVATE_KEY_DER_SIZE 121
int sm2_private_key_to_der(const SM2_KEY *key, uint8_t **out, size_t *outlen)
{
size_t len = 0;
uint8_t params[64];
uint8_t pubkey[128];
uint8_t *params_ptr = params;
uint8_t *pubkey_ptr = pubkey;
size_t params_len = 0;
size_t pubkey_len = 0;
if (!key) {
error_print();
return -1;
}
if (ec_named_curve_to_der(OID_sm2, &params_ptr, &params_len) != 1
|| sm2_public_key_to_der(key, &pubkey_ptr, &pubkey_len) != 1) {
error_print();
return -1;
}
if (asn1_int_to_der(EC_private_key_version, NULL, &len) != 1
|| asn1_octet_string_to_der(key->private_key, 32, NULL, &len) != 1
|| asn1_explicit_to_der(0, params, params_len, NULL, &len) != 1
|| asn1_explicit_to_der(1, pubkey, pubkey_len, NULL, &len) != 1
|| asn1_sequence_header_to_der(len, out, outlen) != 1
|| asn1_int_to_der(EC_private_key_version, out, outlen) != 1
|| asn1_octet_string_to_der(key->private_key, 32, out, outlen) != 1
|| asn1_explicit_to_der(0, params, params_len, out, outlen) != 1
|| asn1_explicit_to_der(1, pubkey, pubkey_len, out, outlen) != 1) {
error_print();
return -1;
}
return 1;
}
int sm2_private_key_from_der(SM2_KEY *key, const uint8_t **in, size_t *inlen)
{
int ret;
const uint8_t *d;
size_t dlen;
int ver;
const uint8_t *prikey;
const uint8_t *params;
const uint8_t *pubkey;
size_t prikey_len, params_len, pubkey_len;
if ((ret = asn1_sequence_from_der(&d, &dlen, in, inlen)) != 1) {
if (ret < 0) error_print();
return ret;
}
if (asn1_int_from_der(&ver, &d, &dlen) != 1
|| asn1_octet_string_from_der(&prikey, &prikey_len, &d, &dlen) != 1
|| asn1_explicit_from_der(0, &params, &params_len, &d, &dlen) != 1
|| asn1_explicit_from_der(1, &pubkey, &pubkey_len, &d, &dlen) != 1
|| asn1_check(ver == EC_private_key_version) != 1
|| asn1_length_is_zero(dlen) != 1) {
error_print();
return -1;
}
if (params) {
int curve;
if (ec_named_curve_from_der(&curve, &params, &params_len) != 1
|| asn1_check(curve == OID_sm2) != 1
|| asn1_length_is_zero(params_len) != 1) {
error_print();
return -1;
}
}
if (asn1_check(prikey_len == 32) != 1
|| sm2_key_set_private_key(key, prikey) != 1) {
error_print();
return -1;
}
// check if the public key is correct
if (pubkey) {
SM2_KEY tmp_key;
if (sm2_public_key_from_der(&tmp_key, &pubkey, &pubkey_len) != 1
|| asn1_length_is_zero(pubkey_len) != 1) {
error_print();
return -1;
}
if (sm2_public_key_equ(key, &tmp_key) != 1) {
error_print();
return -1;
}
}
return 1;
}
int sm2_private_key_print(FILE *fp, int fmt, int ind, const char *label, const uint8_t *d, size_t dlen)
{
return ec_private_key_print(fp, fmt, ind, label, d, dlen);
}
#define SM2_PRIVATE_KEY_INFO_DER_SIZE 150
int sm2_private_key_info_to_der(const SM2_KEY *sm2_key, uint8_t **out, size_t *outlen)
{
size_t len = 0;
uint8_t prikey[SM2_PRIVATE_KEY_DER_SIZE];
uint8_t *p = prikey;
size_t prikey_len = 0;
if (sm2_private_key_to_der(sm2_key, &p, &prikey_len) != 1) {
error_print();
return -1;
}
if (asn1_int_to_der(PKCS8_private_key_info_version, NULL, &len) != 1
|| sm2_public_key_algor_to_der(NULL, &len) != 1
|| asn1_octet_string_to_der(prikey, prikey_len, NULL, &len) != 1
|| asn1_sequence_header_to_der(len, out, outlen) != 1
|| asn1_int_to_der(PKCS8_private_key_info_version, out, outlen) != 1
|| sm2_public_key_algor_to_der(out, outlen) != 1
|| asn1_octet_string_to_der(prikey, prikey_len, out, outlen) != 1) {
memset(prikey, 0, sizeof(prikey));
error_print();
return -1;
}
memset(prikey, 0, sizeof(prikey));
return 1;
}
int sm2_private_key_info_from_der(SM2_KEY *sm2_key, const uint8_t **attrs, size_t *attrslen,
const uint8_t **in, size_t *inlen)
{
int ret;
const uint8_t *d;
size_t dlen;
int version;
const uint8_t *prikey;
size_t prikey_len;
if ((ret = asn1_sequence_from_der(&d, &dlen, in, inlen)) != 1) {
if (ret < 0) error_print();
return ret;
}
if (asn1_int_from_der(&version, &d, &dlen) != 1
|| sm2_public_key_algor_from_der(&d, &dlen) != 1
|| asn1_octet_string_from_der(&prikey, &prikey_len, &d, &dlen) != 1
|| asn1_implicit_set_from_der(0, attrs, attrslen, &d, &dlen) < 0
|| asn1_length_is_zero(dlen) != 1) {
error_print();
return -1;
}
if (asn1_check(version == PKCS8_private_key_info_version) != 1
|| sm2_private_key_from_der(sm2_key, &prikey, &prikey_len) != 1
|| asn1_length_is_zero(prikey_len) != 1) {
error_print();
return -1;
}
return 1;
}
int sm2_private_key_info_print(FILE *fp, int fmt, int ind, const char *label, const uint8_t *d, size_t dlen)
{
int ret;
const uint8_t *p;
size_t len;
int val;
const uint8_t *prikey;
size_t prikey_len;
format_print(fp, fmt, ind, "%s\n", label);
ind += 4;
if (asn1_int_from_der(&val, &d, &dlen) != 1) goto err;
format_print(fp, fmt, ind, "version: %d\n", val);
if (asn1_sequence_from_der(&p, &len, &d, &dlen) != 1) goto err;
x509_public_key_algor_print(fp, fmt, ind, "privateKeyAlgorithm", p, len);
if (asn1_octet_string_from_der(&p, &len, &d, &dlen) != 1) goto err;
if (asn1_sequence_from_der(&prikey, &prikey_len, &p, &len) != 1) goto err;
ec_private_key_print(fp, fmt, ind + 4, "privateKey", prikey, prikey_len);
if (asn1_length_is_zero(len) != 1) goto err;
if ((ret = asn1_implicit_set_from_der(0, &p, &len, &d, &dlen)) < 0) goto err;
else if (ret) format_bytes(fp, fmt, ind, "attributes", p, len);
if (asn1_length_is_zero(dlen) != 1) goto err;
return 1;
err:
error_print();
return -1;
}
#ifdef ENABLE_SM2_PRIVATE_KEY_EXPORT
int sm2_private_key_info_to_pem(const SM2_KEY *key, FILE *fp)
{
int ret = -1;
uint8_t buf[SM2_PRIVATE_KEY_INFO_DER_SIZE];
uint8_t *p = buf;
size_t len = 0;
if (!key || !fp) {
error_print();
return -1;
}
if (sm2_private_key_info_to_der(key, &p, &len) != 1) {
error_print();
goto end;
}
if (len != sizeof(buf)) {
error_print();
goto end;
}
if (pem_write(fp, "PRIVATE KEY", buf, len) != 1) {
error_print();
goto end;
}
ret = 1;
end:
gmssl_secure_clear(buf, sizeof(buf));
return ret;
}
int sm2_private_key_info_from_pem(SM2_KEY *sm2_key, FILE *fp)
{
uint8_t buf[512];
const uint8_t *cp = buf;
size_t len;
const uint8_t *attrs;
size_t attrs_len;
if (pem_read(fp, "PRIVATE KEY", buf, &len, sizeof(buf)) != 1
|| sm2_private_key_info_from_der(sm2_key, &attrs, &attrs_len, &cp, &len) != 1
|| asn1_length_is_zero(len) != 1) {
error_print();
return -1;
}
if (attrs_len) {
error_print();
}
return 1;
}
#endif
int sm2_public_key_info_to_der(const SM2_KEY *pub_key, uint8_t **out, size_t *outlen)
{
size_t len = 0;
if (sm2_public_key_algor_to_der(NULL, &len) != 1
|| sm2_public_key_to_der(pub_key, NULL, &len) != 1
|| asn1_sequence_header_to_der(len, out, outlen) != 1
|| sm2_public_key_algor_to_der(out, outlen) != 1
|| sm2_public_key_to_der(pub_key, out, outlen) != 1) {
error_print();
return -1;
}
return 1;
}
int sm2_public_key_info_from_der(SM2_KEY *pub_key, const uint8_t **in, size_t *inlen)
{
int ret;
const uint8_t *d;
size_t dlen;
if ((ret = asn1_sequence_from_der(&d, &dlen, in, inlen)) != 1) {
if (ret < 0) error_print();
return ret;
}
if (sm2_public_key_algor_from_der(&d, &dlen) != 1
|| sm2_public_key_from_der(pub_key, &d, &dlen) != 1
|| asn1_length_is_zero(dlen) != 1) {
error_print();
return -1;
}
return 1;
}
#ifdef ENABLE_SM2_PRIVATE_KEY_EXPORT
// FIXME: side-channel of Base64
int sm2_private_key_to_pem(const SM2_KEY *a, FILE *fp)
{
uint8_t buf[512];
uint8_t *p = buf;
size_t len = 0;
if (sm2_private_key_to_der(a, &p, &len) != 1) {
error_print();
return -1;
}
if (pem_write(fp, "EC PRIVATE KEY", buf, len) <= 0) {
error_print();
return -1;
}
return 1;
}
int sm2_private_key_from_pem(SM2_KEY *a, FILE *fp)
{
uint8_t buf[512];
const uint8_t *cp = buf;
size_t len;
if (pem_read(fp, "EC PRIVATE KEY", buf, &len, sizeof(buf)) != 1) {
error_print();
return -1;
}
if (sm2_private_key_from_der(a, &cp, &len) != 1
|| len > 0) {
error_print();
return -1;
}
return 1;
}
#endif
int sm2_public_key_info_to_pem(const SM2_KEY *a, FILE *fp)
{
uint8_t buf[512];
uint8_t *p = buf;
size_t len = 0;
if (sm2_public_key_info_to_der(a, &p, &len) != 1) {
error_print();
return -1;
}
if (pem_write(fp, "PUBLIC KEY", buf, len) <= 0) {
error_print();
return -1;
}
return 1;
}
int sm2_public_key_info_from_pem(SM2_KEY *a, FILE *fp)
{
uint8_t buf[512];
const uint8_t *cp = buf;
size_t len;
if (pem_read(fp, "PUBLIC KEY", buf, &len, sizeof(buf)) != 1) {
error_print();
return -1;
}
if (sm2_public_key_info_from_der(a, &cp, &len) != 1
|| asn1_length_is_zero(len) != 1) {
error_print();
return -1;
}
return 1;
}
int sm2_public_key_equ(const SM2_KEY *sm2_key, const SM2_KEY *pub_key)
{
if (memcmp(sm2_key, pub_key, sizeof(SM2_POINT)) == 0) {
return 1;
}
return 0;
}
int sm2_public_key_copy(SM2_KEY *sm2_key, const SM2_KEY *pub_key)
{
return sm2_key_set_public_key(sm2_key, &pub_key->public_key);
}
int sm2_public_key_digest(const SM2_KEY *sm2_key, uint8_t dgst[32])
{
uint8_t bits[65];
sm2_point_to_uncompressed_octets(&sm2_key->public_key, bits);
sm3_digest(bits, sizeof(bits), dgst);
return 1;
}
int sm2_private_key_info_encrypt_to_der(const SM2_KEY *sm2_key, const char *pass,
uint8_t **out, size_t *outlen)
{
int ret = -1;
uint8_t pkey_info[SM2_PRIVATE_KEY_INFO_DER_SIZE];
uint8_t *p = pkey_info;
size_t pkey_info_len = 0;
uint8_t salt[16];
int iter = 65536;
uint8_t iv[16];
uint8_t key[16];
SM4_KEY sm4_key;
uint8_t enced_pkey_info[sizeof(pkey_info) + 32];
size_t enced_pkey_info_len;
if (!sm2_key || !pass || !outlen) {
error_print();
return -1;
}
if (sm2_private_key_info_to_der(sm2_key, &p, &pkey_info_len) != 1
|| rand_bytes(salt, sizeof(salt)) != 1
|| rand_bytes(iv, sizeof(iv)) != 1
|| pbkdf2_genkey(DIGEST_sm3(), pass, strlen(pass),
salt, sizeof(salt), iter, sizeof(key), key) != 1) {
error_print();
goto end;
}
/*
if (pkey_info_len != sizeof(pkey_info)) {
error_print();
goto end;
}
*/
sm4_set_encrypt_key(&sm4_key, key);
if (sm4_cbc_padding_encrypt(
&sm4_key, iv, pkey_info, pkey_info_len,
enced_pkey_info, &enced_pkey_info_len) != 1
|| pkcs8_enced_private_key_info_to_der(
salt, sizeof(salt), iter, sizeof(key), OID_hmac_sm3,
OID_sm4_cbc, iv, sizeof(iv),
enced_pkey_info, enced_pkey_info_len, out, outlen) != 1) {
error_print();
goto end;
}
ret = 1;
end:
gmssl_secure_clear(pkey_info, sizeof(pkey_info));
gmssl_secure_clear(key, sizeof(key));
gmssl_secure_clear(&sm4_key, sizeof(sm4_key));
return ret;
}
int sm2_private_key_info_decrypt_from_der(SM2_KEY *sm2,
const uint8_t **attrs, size_t *attrs_len,
const char *pass, const uint8_t **in, size_t *inlen)
{
int ret = -1;
const uint8_t *salt;
size_t saltlen;
int iter;
int keylen;
int prf;
int cipher;
const uint8_t *iv;
size_t ivlen;
uint8_t key[16];
SM4_KEY sm4_key;
const uint8_t *enced_pkey_info;
size_t enced_pkey_info_len;
uint8_t pkey_info[256];
const uint8_t *cp = pkey_info;
size_t pkey_info_len;
if (!sm2 || !attrs || !attrs_len || !pass || !in || !(*in) || !inlen) {
error_print();
return -1;
}
if (pkcs8_enced_private_key_info_from_der(&salt, &saltlen, &iter, &keylen, &prf,
&cipher, &iv, &ivlen, &enced_pkey_info, &enced_pkey_info_len, in, inlen) != 1
|| asn1_check(keylen == -1 || keylen == 16) != 1
|| asn1_check(prf == - 1 || prf == OID_hmac_sm3) != 1
|| asn1_check(cipher == OID_sm4_cbc) != 1
|| asn1_check(ivlen == 16) != 1
|| asn1_length_le(enced_pkey_info_len, sizeof(pkey_info)) != 1) {
error_print();
return -1;
}
if (pbkdf2_genkey(DIGEST_sm3(), pass, strlen(pass), salt, saltlen, iter, sizeof(key), key) != 1) {
error_print();
goto end;
}
sm4_set_decrypt_key(&sm4_key, key);
if (sm4_cbc_padding_decrypt(&sm4_key, iv, enced_pkey_info, enced_pkey_info_len,
pkey_info, &pkey_info_len) != 1
|| sm2_private_key_info_from_der(sm2, attrs, attrs_len, &cp, &pkey_info_len) != 1
|| asn1_length_is_zero(pkey_info_len) != 1) {
error_print();
goto end;
}
ret = 1;
end:
gmssl_secure_clear(&sm4_key, sizeof(sm4_key));
gmssl_secure_clear(key, sizeof(key));
gmssl_secure_clear(pkey_info, sizeof(pkey_info));
return ret;
}
int sm2_private_key_info_encrypt_to_pem(const SM2_KEY *sm2_key, const char *pass, FILE *fp)
{
uint8_t buf[1024];
uint8_t *p = buf;
size_t len = 0;
if (!fp) {
error_print();
return -1;
}
if (sm2_private_key_info_encrypt_to_der(sm2_key, pass, &p, &len) != 1) {
error_print();
return -1;
}
if (pem_write(fp, "ENCRYPTED PRIVATE KEY", buf, len) != 1) {
error_print();
return -1;
}
return 1;
}
int sm2_private_key_info_decrypt_from_pem(SM2_KEY *key, const char *pass, FILE *fp)
{
uint8_t buf[512];
const uint8_t *cp = buf;
size_t len;
const uint8_t *attrs;
size_t attrs_len;
if (!key || !pass || !fp) {
error_print();
return -1;
}
if (pem_read(fp, "ENCRYPTED PRIVATE KEY", buf, &len, sizeof(buf)) != 1
|| sm2_private_key_info_decrypt_from_der(key, &attrs, &attrs_len, pass, &cp, &len) != 1
|| asn1_length_is_zero(len) != 1) {
error_print();
return -1;
}
return 1;
}

View File

@@ -20,7 +20,6 @@
#include <gmssl/endian.h> #include <gmssl/endian.h>
typedef SM2_Z256 SM2_U256; typedef SM2_Z256 SM2_U256;
#define sm2_u256_one() sm2_z256_one() #define sm2_u256_one() sm2_z256_one()
@@ -52,7 +51,6 @@ typedef SM2_Z256_POINT SM2_U256_POINT;
#define sm2_u256_point_get_xy(P,x,y) sm2_z256_point_get_xy((P),(x),(y)) #define sm2_u256_point_get_xy(P,x,y) sm2_z256_point_get_xy((P),(x),(y))
int sm2_do_sign(const SM2_KEY *key, const uint8_t dgst[32], SM2_SIGNATURE *sig) int sm2_do_sign(const SM2_KEY *key, const uint8_t dgst[32], SM2_SIGNATURE *sig)
{ {
SM2_U256_POINT _P, *P = &_P; SM2_U256_POINT _P, *P = &_P;
@@ -82,6 +80,10 @@ int sm2_do_sign(const SM2_KEY *key, const uint8_t dgst[32], SM2_SIGNATURE *sig)
sm2_u256_from_bytes(e, dgst); //sm2_bn_print(stderr, 0, 4, "e", e); sm2_u256_from_bytes(e, dgst); //sm2_bn_print(stderr, 0, 4, "e", e);
retry: retry:
// >>>>>>>>>> BEGIN PRECOMP
// rand k in [1, n - 1] // rand k in [1, n - 1]
do { do {
if (sm2_u256_modn_rand(k) != 1) { if (sm2_u256_modn_rand(k) != 1) {
@@ -96,6 +98,11 @@ retry:
//sm2_bn_print(stderr, 0, 4, "x", x); //sm2_bn_print(stderr, 0, 4, "x", x);
// 如果我们提前计算了 (k, x) 那么我们在真正做签名的时候就可以利用到这个与计算的表了,直接从表中读取 (k, x)
// 当然这些计算都可以放在sign_fast里面
// >>>>>>>>>>> END PRECOMP
// r = e + x (mod n) // r = e + x (mod n)
if (sm2_u256_cmp(e, order) >= 0) { if (sm2_u256_cmp(e, order) >= 0) {
sm2_u256_sub(e, e, order); sm2_u256_sub(e, e, order);
@@ -132,13 +139,65 @@ retry:
return 1; return 1;
} }
// k 和 x1 都是要参与计算的,因此我们返回的是内部格式
int sm2_do_sign_pre_compute(uint64_t k[4], uint64_t x1[4])
{
SM2_Z256_POINT P;
// rand k in [1, n - 1]
do {
if (sm2_z256_modn_rand(k) != 1) {
error_print();
return -1;
}
} while (sm2_z256_is_zero(k));
// (x1, y1) = kG
sm2_u256_point_mul_generator(&P, k); // 这个函数要粗力度并行,这要怎么做?
sm2_u256_point_get_xy(&P, x1, NULL);
return 1;
}
// 实际上这里只有一次mod n的乘法用barret就可以了
int sm2_do_sign_fast_ex(const uint64_t d[4], const uint64_t k[4], const uint64_t x1[4], const uint8_t dgst[32], SM2_SIGNATURE *sig)
{
SM2_Z256_POINT R;
uint64_t e[4];
uint64_t r[4];
uint64_t s[4];
const uint64_t *order = sm2_z256_order();
// e = H(M)
sm2_z256_from_bytes(e, dgst);
if (sm2_z256_cmp(e, order) >= 0) {
sm2_z256_sub(e, e, order);
}
// r = e + x1 (mod n)
sm2_z256_modn_add(r, e, x1);
// s = (k + r) * d' - r
sm2_z256_modn_add(s, k, r);
sm2_z256_modn_mul(s, s, d);
sm2_z256_modn_sub(s, s, r);
sm2_u256_to_bytes(r, sig->r);
sm2_u256_to_bytes(s, sig->s);
return 1;
}
// (x1, y1) = k * G // (x1, y1) = k * G
// r = e + x1 // r = e + x1
// s = (k - r * d)/(1 + d) = (k +r - r * d - r)/(1 + d) = (k + r - r(1 +d))/(1 + d) = (k + r)/(1 + d) - r // s = (k - r * d)/(1 + d) = (k +r - r * d - r)/(1 + d) = (k + r - r(1 +d))/(1 + d) = (k + r)/(1 + d) - r
// = -r + (k + r)*(1 + d)^-1 // = -r + (k + r)*(1 + d)^-1
// = -r + (k + r) * d' // = -r + (k + r) * d'
int sm2_do_sign_fast(const SM2_Fn d, const uint8_t dgst[32], SM2_SIGNATURE *sig) // 这个函数是我们真正要调用的,甚至可以替代原来的函数
int sm2_do_sign_fast(const uint64_t d[4], const uint8_t dgst[32], SM2_SIGNATURE *sig)
{ {
SM2_U256_POINT R; SM2_U256_POINT R;
SM2_U256 e; SM2_U256 e;
@@ -155,6 +214,8 @@ int sm2_do_sign_fast(const SM2_Fn d, const uint8_t dgst[32], SM2_SIGNATURE *sig)
sm2_u256_sub(e, e, order); sm2_u256_sub(e, e, order);
} }
/// <<<<<<<<<<< 这里的 (k, x1) 应该是从外部输入的!!,这样才是最快的。
// rand k in [1, n - 1] // rand k in [1, n - 1]
do { do {
if (sm2_u256_modn_rand(k) != 1) { if (sm2_u256_modn_rand(k) != 1) {
@@ -164,16 +225,22 @@ int sm2_do_sign_fast(const SM2_Fn d, const uint8_t dgst[32], SM2_SIGNATURE *sig)
} while (sm2_u256_is_zero(k)); } while (sm2_u256_is_zero(k));
// (x1, y1) = kG // (x1, y1) = kG
sm2_u256_point_mul_generator(&R, k); sm2_u256_point_mul_generator(&R, k); // 这个函数要粗力度并行,这要怎么做?
sm2_u256_point_get_xy(&R, x1, NULL); sm2_u256_point_get_xy(&R, x1, NULL);
/// >>>>>>>>>>>>>>>>>>
// r = e + x1 (mod n) // r = e + x1 (mod n)
sm2_u256_modn_add(r, e, x1); sm2_u256_modn_add(r, e, x1);
// 对于快速实现来说,只需要一次乘法 // 对于快速实现来说,只需要一次乘法
// 如果 (k, x) 是预计算的,这意味着我们可以并行这个操作
// 也就是随机产生一些k然后执行粗力度并行的点乘
// s = (k + r) * d' - r // s = (k + r) * d' - r
sm2_u256_add(s, k, r); sm2_u256_modn_add(s, k, r);
sm2_u256_modn_mul(s, s, d); sm2_u256_modn_mul(s, s, d);
sm2_u256_modn_sub(s, s, r); sm2_u256_modn_sub(s, s, r);
@@ -182,6 +249,72 @@ int sm2_do_sign_fast(const SM2_Fn d, const uint8_t dgst[32], SM2_SIGNATURE *sig)
return 1; return 1;
} }
// 这个其实并没有更快无非就是降低了解析公钥椭圆曲线点的计算量这个点要转换为内部的Mont格式
// 这里根本没有modn的乘法
int sm2_do_verify_fast(const SM2_Z256_POINT *P, const uint8_t dgst[32], const SM2_SIGNATURE *sig)
{
SM2_U256_POINT R;
SM2_U256 r;
SM2_U256 s;
SM2_U256 e;
SM2_U256 x;
SM2_U256 t;
const uint64_t *order = sm2_u256_order();
sm2_u256_from_bytes(r, sig->r);
// check r in [1, n-1]
if (sm2_u256_is_zero(r) == 1) {
error_print();
return -1;
}
if (sm2_u256_cmp(r, order) >= 0) {
error_print();
return -1;
}
sm2_u256_from_bytes(s, sig->s);
// check s in [1, n-1]
if (sm2_u256_is_zero(s) == 1) {
error_print();
return -1;
}
if (sm2_u256_cmp(s, order) >= 0) {
error_print();
return -1;
}
// e = H(M)
sm2_u256_from_bytes(e, dgst);
// t = r + s (mod n), check t != 0
sm2_u256_modn_add(t, r, s);
if (sm2_u256_is_zero(t)) {
error_print();
return -1;
}
// Q = s * G + t * P
sm2_u256_point_mul_sum(&R, t, P, s);
sm2_u256_point_get_xy(&R, x, NULL);
// r' = e + x (mod n)
if (sm2_u256_cmp(e, order) >= 0) {
sm2_u256_sub(e, e, order);
}
if (sm2_u256_cmp(x, order) >= 0) {
sm2_u256_sub(x, x, order);
}
sm2_u256_modn_add(e, e, x);
// check if r == r'
if (sm2_u256_cmp(e, r) != 0) {
error_print();
return -1;
}
return 1;
}
int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATURE *sig) int sm2_do_verify(const SM2_KEY *key, const uint8_t dgst[32], const SM2_SIGNATURE *sig)
{ {
SM2_U256_POINT _P, *P = &_P; SM2_U256_POINT _P, *P = &_P;
@@ -277,6 +410,27 @@ static int all_zero(const uint8_t *buf, size_t len)
return 1; return 1;
} }
int sm2_do_encrypt_pre_compute(uint64_t k[4], uint8_t C1[64])
{
SM2_Z256_POINT P;
// rand k in [1, n - 1]
do {
if (sm2_z256_modn_rand(k) != 1) {
error_print();
return -1;
}
} while (sm2_z256_is_zero(k));
// output C1 = k * G = (x1, y1)
sm2_z256_point_mul_generator(&P, k);
sm2_z256_point_to_bytes(&P, C1);
return 1;
}
// 和签名不一样,加密的时候要生成 (k, (x1, y1)) 也就是y坐标也是需要的
// 其中k是要参与计算的但是 (x1, y1) 不参与计算,输出为 bytes 就可以了
int sm2_do_encrypt(const SM2_KEY *key, const uint8_t *in, size_t inlen, SM2_CIPHERTEXT *out) int sm2_do_encrypt(const SM2_KEY *key, const uint8_t *in, size_t inlen, SM2_CIPHERTEXT *out)
{ {
SM2_U256 k; SM2_U256 k;

View File

@@ -16,27 +16,45 @@
#include <gmssl/sm2.h> #include <gmssl/sm2.h>
#include <gmssl/pkcs8.h> #include <gmssl/pkcs8.h>
// 应该整理出不同编码长度的椭圆曲线点可以由x求出y
// 由于当前Ciphertext中椭圆曲线点数据不正确因此无法通过测试
static int test_sm2_ciphertext(void) static int test_sm2_ciphertext(void)
{ {
struct {
char *label;
size_t ciphertext_size;
} tests[] = {
{ "null ciphertext", 0 },
{ "min ciphertext size", SM2_MIN_PLAINTEXT_SIZE },
{ "max ciphertext size", SM2_MAX_PLAINTEXT_SIZE },
};
SM2_CIPHERTEXT C; SM2_CIPHERTEXT C;
SM2_KEY sm2_key;
uint8_t buf[1024]; uint8_t buf[1024];
size_t i;
rand_bytes(C.hash, 32);
rand_bytes(C.ciphertext, SM2_MAX_PLAINTEXT_SIZE);
for (i = 0; i < sizeof(tests)/sizeof(tests[0]); i++) {
uint8_t *p = buf; uint8_t *p = buf;
const uint8_t *cp = buf; const uint8_t *cp = buf;
size_t len = 0; size_t len = 0;
memset(&C, 0, sizeof(SM2_CIPHERTEXT)); if (sm2_key_generate(&sm2_key) != 1) {
error_print();
return -1;
}
C.point = sm2_key.public_key;
C.ciphertext_size = tests[i].ciphertext_size;
cp = p = buf; len = 0;
if (sm2_ciphertext_to_der(&C, &p, &len) != 1) { if (sm2_ciphertext_to_der(&C, &p, &len) != 1) {
error_print(); error_print();
return -1; return -1;
} }
format_print(stderr, 0, 4, "SM2_NULL_CIPHERTEXT_SIZE: %zu\n", len);
format_bytes(stderr, 0, 4, "", buf, len);
printf("Plaintext size = %zu, SM2Ciphertext DER size %zu\n", tests[i].ciphertext_size, len);
if (sm2_ciphertext_from_der(&C, &cp, &len) != 1 if (sm2_ciphertext_from_der(&C, &cp, &len) != 1
|| asn1_length_is_zero(len) != 1) { || asn1_length_is_zero(len) != 1) {
@@ -44,60 +62,6 @@ static int test_sm2_ciphertext(void)
return -1; return -1;
} }
// {0, 0, Hash, MinLen}
C.ciphertext_size = SM2_MIN_PLAINTEXT_SIZE;
cp = p = buf; len = 0;
if (sm2_ciphertext_to_der(&C, &p, &len) != 1) {
error_print();
return -1;
}
format_print(stderr, 0, 4, "SM2_MIN_PLAINTEXT_SIZE: %zu\n", SM2_MIN_PLAINTEXT_SIZE);
format_print(stderr, 0, 4, "SM2_MIN_CIPHERTEXT_SIZE: %zu\n", len);
format_bytes(stderr, 0, 4, "", buf, len);
if (len != SM2_MIN_CIPHERTEXT_SIZE) {
error_print();
return -1;
}
if (sm2_ciphertext_from_der(&C, &cp, &len) != 1
|| asn1_length_is_zero(len) != 1) {
error_print();
return -1;
}
// { 33, 33, Hash, NULL }
memset(&C, 0x80, sizeof(SM2_POINT));
cp = p = buf; len = 0;
if (sm2_ciphertext_to_der(&C, &p, &len) != 1) {
error_print();
return -1;
}
format_print(stderr, 0, 4, "ciphertext len: %zu\n", len);
format_bytes(stderr, 0, 4, "", buf, len);
if (sm2_ciphertext_from_der(&C, &cp, &len) != 1
|| asn1_length_is_zero(len) != 1) {
error_print();
return -1;
}
// { 33, 33, Hash, MaxLen }
C.ciphertext_size = SM2_MAX_PLAINTEXT_SIZE;//SM2_MAX_PLAINTEXT_SIZE;
cp = p = buf; len = 0;
if (sm2_ciphertext_to_der(&C, &p, &len) != 1) {
error_print();
return -1;
}
format_print(stderr, 0, 4, "SM2_MAX_PLAINTEXT_SIZE: %zu\n", SM2_MAX_PLAINTEXT_SIZE);
format_print(stderr, 0, 4, "SM2_MAX_CIPHERTEXT_SIZE: %zu\n", len);
format_bytes(stderr, 0, 4, "", buf, len);
if (len != SM2_MAX_CIPHERTEXT_SIZE) {
error_print();
return -1;
}
if (sm2_ciphertext_from_der(&C, &cp, &len) != 1
|| asn1_length_is_zero(len) != 1) {
error_print();
return -1;
} }
printf("%s() ok\n", __FUNCTION__); printf("%s() ok\n", __FUNCTION__);
@@ -265,14 +229,6 @@ static int test_sm2_encrypt_fixlen(void)
} }
// 应该生成不同情况下的密文!
static int test_sm2_encrypt(void) static int test_sm2_encrypt(void)
{ {
SM2_KEY sm2_key; SM2_KEY sm2_key;
@@ -327,7 +283,7 @@ static int test_sm2_encrypt(void)
int main(void) int main(void)
{ {
//if (test_sm2_ciphertext() != 1) goto err; // 需要正确的Ciphertext数据 if (test_sm2_ciphertext() != 1) goto err;
if (test_sm2_do_encrypt() != 1) goto err; if (test_sm2_do_encrypt() != 1) goto err;
if (test_sm2_do_encrypt_fixlen() != 1) goto err; if (test_sm2_do_encrypt_fixlen() != 1) goto err;
if (test_sm2_encrypt() != 1) goto err; if (test_sm2_encrypt() != 1) goto err;

View File

@@ -109,6 +109,7 @@ static int test_sm2_do_sign(void)
#define sm2_u256_modn_add sm2_z256_modn_add #define sm2_u256_modn_add sm2_z256_modn_add
#define sm2_u256_modn_inv sm2_z256_modn_inv #define sm2_u256_modn_inv sm2_z256_modn_inv
static int test_sm2_do_sign_fast(void) static int test_sm2_do_sign_fast(void)
{ {
SM2_KEY sm2_key; SM2_KEY sm2_key;
@@ -141,6 +142,45 @@ static int test_sm2_do_sign_fast(void)
return 1; return 1;
} }
static int test_sm2_do_sign_pre_compute(void)
{
SM2_KEY sm2_key;
uint64_t d[4];
uint64_t k[4];
uint64_t x1[4];
uint8_t dgst[32];
SM2_SIGNATURE sig;
sm2_key_generate(&sm2_key);
const uint64_t *one = sm2_z256_one();
sm2_z256_from_bytes(d, sm2_key.private_key);
sm2_z256_modn_add(d, d, one);
sm2_z256_modn_inv(d, d);
if (sm2_do_sign_pre_compute(k, x1) != 1) {
error_print();
return -1;
}
rand_bytes(dgst, sizeof(dgst));
if (sm2_do_sign_fast_ex(d, k, x1, dgst, &sig) != 1) {
error_print();
return -1;
}
if (sm2_do_verify(&sm2_key, dgst, &sig) != 1) {
error_print();
return -1;
}
printf("%s() ok\n", __FUNCTION__);
return 1;
}
static int test_sm2_sign(void) static int test_sm2_sign(void)
{ {
SM2_KEY sm2_key; SM2_KEY sm2_key;
@@ -209,12 +249,95 @@ static int test_sm2_sign_ctx(void)
return 1; return 1;
} }
static int test_sm2_sign_ctx_reset(void)
{
SM2_KEY sm2_key;
SM2_SIGN_CTX sign_ctx;
SM2_SIGN_CTX vrfy_ctx;
uint8_t msg[64];
uint8_t sig[SM2_MAX_SIGNATURE_SIZE];
size_t siglen;
if (sm2_key_generate(&sm2_key) != 1) {
error_print();
return -1;
}
// init sign_ctx and sign a message
rand_bytes(msg, sizeof(msg));
if (sm2_sign_init(&sign_ctx, &sm2_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1) {
error_print();
return -1;
}
if (sm2_sign_update(&sign_ctx, msg, sizeof(msg)) != 1) {
error_print();
return -1;
}
if (sm2_sign_finish(&sign_ctx, sig, &siglen) != 1) {
error_print();
return -1;
}
// reset sign_ctx and sign another message
rand_bytes(msg, sizeof(msg));
if (sm2_sign_ctx_reset(&sign_ctx) != 1) {
error_print();
return -1;
}
if (sm2_sign_update(&sign_ctx, msg, sizeof(msg)) != 1) {
error_print();
return -1;
}
if (sm2_sign_finish(&sign_ctx, sig, &siglen) != 1) {
error_print();
return -1;
}
// verify, check whether reset works
if (sm2_verify_init(&vrfy_ctx, &sm2_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1) {
error_print();
return -1;
}
if (sm2_verify_update(&vrfy_ctx, msg, sizeof(msg)) != 1) {
error_print();
return -1;
}
if (sm2_verify_finish(&vrfy_ctx, sig, siglen) != 1) {
format_bytes(stderr, 0, 4, "signature", sig, siglen);
error_print();
return -1;
}
// reset ctx and verify again
if (sm2_sign_ctx_reset(&vrfy_ctx) != 1) {
error_print();
return -1;
}
if (sm2_verify_update(&vrfy_ctx, msg, sizeof(msg)) != 1) {
error_print();
return -1;
}
if (sm2_verify_finish(&vrfy_ctx, sig, siglen) != 1) {
format_bytes(stderr, 0, 4, "signature", sig, siglen);
error_print();
return -1;
}
printf("%s() ok\n", __FUNCTION__);
return 1;
}
int main(void) int main(void)
{ {
if (test_sm2_do_sign_fast() != 1) goto err;
if (test_sm2_signature() != 1) goto err; if (test_sm2_signature() != 1) goto err;
if (test_sm2_do_sign() != 1) goto err; if (test_sm2_do_sign() != 1) goto err;
if (test_sm2_do_sign_pre_compute() != 1) goto err;
if (test_sm2_sign() != 1) goto err; if (test_sm2_sign() != 1) goto err;
if (test_sm2_sign_ctx() != 1) goto err; if (test_sm2_sign_ctx() != 1) goto err;
if (test_sm2_sign_ctx_reset() != 1) goto err;
printf("%s all tests passed\n", __FILE__); printf("%s all tests passed\n", __FILE__);
return 0; return 0;
err: err:

View File

@@ -12,11 +12,20 @@
#include <string.h> #include <string.h>
#include <stdlib.h> #include <stdlib.h>
#include <stdint.h> #include <stdint.h>
#include <gmssl/sm2.h>
#include <gmssl/sm2_z256.h> #include <gmssl/sm2_z256.h>
#include <gmssl/sm3.h>
#include <gmssl/sm3_digest.h>
#include <gmssl/hex.h> #include <gmssl/hex.h>
#include <gmssl/rand.h> #include <gmssl/rand.h>
#include <gmssl/error.h> #include <gmssl/error.h>
/*
TODO: 验证点加、倍点等计算是否支持无穷远点、共轭点等特殊形势
*/
enum { enum {
OP_ADD, OP_ADD,
OP_DBL, OP_DBL,
@@ -28,6 +37,67 @@ enum {
OP_INV, OP_INV,
}; };
#define TEST_COUNT 10
static int test_sm2_z256_rshift(void)
{
uint64_t r[4];
uint64_t a[4];
uint64_t b[4];
unsigned int i;
sm2_z256_modn_rand(a);
sm2_z256_rshift(r, a, 0);
sm2_z256_copy(b, a);
if (sm2_z256_cmp(r, b) != 0) {
error_print();
return -1;
}
sm2_z256_rshift(r, a, 63);
for (i = 0; i < 63; i++) {
sm2_z256_rshift(a, a, 1);
}
if (sm2_z256_cmp(r, a) != 0) {
error_print();
return -1;
}
printf("%s() ok\n", __FUNCTION__);
return 1;
}
static int test_sm2_z256_modp_mont_sqrt(void)
{
uint64_t a[4];
uint64_t neg_a[4];
uint64_t mont_a[4];
uint64_t mont_sqr_a[4];
uint64_t mont_a_[4];
uint64_t a_[4];
int i;
for (i = 0; i < 6; i++) {
sm2_z256_modn_rand(a);
sm2_z256_modp_neg(neg_a, a);
sm2_z256_modp_to_mont(a, mont_a);
sm2_z256_modp_mont_sqr(mont_sqr_a, mont_a);
sm2_z256_modp_mont_sqrt(mont_a_, mont_sqr_a);
sm2_z256_modp_from_mont(a_, mont_a_);
// a_ = sqrt(a^2), a_ should be a or -a
if (sm2_z256_cmp(a_, a) != 0 && sm2_z256_cmp(a_, neg_a) != 0) {
error_print();
return -1;
}
}
printf("%s() ok\n", __FUNCTION__);
return 1;
}
static int test_sm2_z256_modp(void) static int test_sm2_z256_modp(void)
{ {
struct { struct {
@@ -389,6 +459,95 @@ static int test_sm2_z256_point_get_xy(void)
return 1; return 1;
} }
static int test_sm2_z256_point_from_x_bytes(void)
{
struct {
char *label;
char *xy;
int y_is_odd;
} tests[] = {
{
"G (y is even)",
"32c4ae2c1f1981195f9904466a39c9948fe30bbff2660be1715a4589334c74c7"
"bc3736a2f4f6779c59bdcee36b692153d0a9877cc62a474002df32e52139f0a0",
0,
},
{
"2G (y is odd)",
"56cefd60d7c87c000d58ef57fa73ba4d9c0dfa08c08a7331495c2e1da3f2bd52"
"31b7e7e6cc8189f668535ce0f8eaf1bd6de84c182f6c8e716f780d3a970a23c3",
1,
},
};
SM2_Z256_POINT P;
uint8_t x_bytes[32];
size_t i, len;
for (i = 0; i < sizeof(tests)/sizeof(tests[0]); i++) {
hex_to_bytes(tests[i].xy, 64, x_bytes, &len);
sm2_z256_point_from_x_bytes(&P, x_bytes, tests[i].y_is_odd);
if (sm2_z256_point_equ_hex(&P, tests[i].xy) != 1) {
error_print();
return -1;
}
}
printf("%s() ok\n", __FUNCTION__);
return 1;
}
static int test_sm2_z256_point_add_conjugate(void)
{
char *hex_G =
"32c4ae2c1f1981195f9904466a39c9948fe30bbff2660be1715a4589334c74c7"
"bc3736a2f4f6779c59bdcee36b692153d0a9877cc62a474002df32e52139f0a0";
char *hex_negG =
"32c4ae2c1f1981195f9904466a39c9948fe30bbff2660be1715a4589334c74c7"
"43c8c95c0b098863a642311c9496deac2f56788239d5b8c0fd20cd1adec60f5f";
SM2_Z256_POINT R;
SM2_Z256_POINT P;
SM2_Z256_POINT Q;
sm2_z256_point_from_hex(&P, hex_G);
sm2_z256_point_from_hex(&Q, hex_negG);
sm2_z256_point_add(&R, &P, &Q);
// P + (-P) = (0:0:0)
if (!sm2_z256_is_zero(R.X)
|| !sm2_z256_is_zero(R.Y)
|| !sm2_z256_is_zero(R.Z)) {
error_print();
return -1;
}
printf("%s() ok\n", __FUNCTION__);
return 1;
}
static int test_sm2_z256_point_dbl_infinity(void)
{
SM2_Z256_POINT P_infinity;
SM2_Z256_POINT R;
sm2_z256_point_set_infinity(&P_infinity);
sm2_z256_point_dbl(&R, &P_infinity); // 显然这个计算就会出错了!
if (!sm2_z256_point_is_at_infinity(&R)) {
error_print(); // 这个会出错
return -1;
}
printf("%s() ok\n", __FUNCTION__);
return 1;
}
static int test_sm2_z256_point_ops(void) static int test_sm2_z256_point_ops(void)
{ {
char *hex_G = char *hex_G =
@@ -472,6 +631,8 @@ static int test_sm2_z256_point_ops(void)
return -1; return -1;
} }
if (sm2_z256_point_equ_hex(&P, tests[i].R) != 1) {
fprintf(stderr, "%s\n", tests[i].label); fprintf(stderr, "%s\n", tests[i].label);
sm2_z256_point_print(stderr, 0, 4, "R", &P); sm2_z256_point_print(stderr, 0, 4, "R", &P);
fprintf(stderr, " R: %s\n", tests[i].R); fprintf(stderr, " R: %s\n", tests[i].R);
@@ -479,8 +640,6 @@ static int test_sm2_z256_point_ops(void)
fprintf(stderr, " A: %s\n", tests[i].A); fprintf(stderr, " A: %s\n", tests[i].A);
fprintf(stderr, " B: %s\n", tests[i].B); fprintf(stderr, " B: %s\n", tests[i].B);
if (sm2_z256_point_equ_hex(&P, tests[i].R) != 1) {
error_print(); error_print();
return -1; return -1;
} }
@@ -614,14 +773,108 @@ static int test_sm2_z256_point_mul_generator(void)
return 1; return 1;
} }
static int test_sm2_z256_point_equ(void)
{
struct {
char *label;
char *mont_X1;
char *mont_Y1;
char *mont_Z1;
char *mont_X2;
char *mont_Y2;
char *mont_Z2;
} tests[] = {
{
"Point at Infinity (1:1:0)",
"0000000100000000000000000000000000000000ffffffff0000000000000001", // mont(1)
"0000000100000000000000000000000000000000ffffffff0000000000000001", // mont(1)
"0000000000000000000000000000000000000000000000000000000000000000", // 0
"0000000100000000000000000000000000000000ffffffff0000000000000001", // mont(1)
"0000000100000000000000000000000000000000ffffffff0000000000000001", // mont(1)
"0000000000000000000000000000000000000000000000000000000000000000", // 0
},
{
"[2]2G == 2G + G + G",
"87b2ca9ded2487c6efdbc69303258763a0b5520fc63cf40154f6c059b945acf2",
"dc86353bc72db45ebb5b2d03cec4614b164688f19f12dd857fd007e181457b59",
"050653f8579d1d2d930d7346e31bad56b5a4654d6a9f2c5022434941744ced3a",
"e8457905838420a51366f7fe174ce34dc3579fefc188f0b5124e7537526ae99e",
"48c3374ab1d5fde0276bebb81b8ff0baa9805cc2d0f487e18d7b3a4352f4ae21",
"79f76fd57f22f1e282d64ff809a53f1f729f6b89c6f626b96725a9d05704e681",
}
};
SM2_Z256_POINT P;
SM2_Z256_POINT Q;
size_t i;
for (i = 0; i < sizeof(tests)/sizeof(tests[0]); i++) {
sm2_z256_from_hex(P.X, tests[i].mont_X1);
sm2_z256_from_hex(P.Y, tests[i].mont_Y1);
sm2_z256_from_hex(P.Z, tests[i].mont_Z1);
sm2_z256_from_hex(Q.X, tests[i].mont_X2);
sm2_z256_from_hex(Q.Y, tests[i].mont_Y2);
sm2_z256_from_hex(Q.Z, tests[i].mont_Z2);
if (sm2_z256_point_equ(&P, &Q) != 1) {
error_print();
return -1;
}
}
printf("%s() ok\n", __FUNCTION__);
return 1;
}
static int test_sm2_z256_point_from_hash(void)
{
SM2_Z256_POINT P;
uint8_t data[64];
size_t datalen = sizeof(data);
int y_is_odd = 1;
int y_is_even = 0;
size_t i;
for (i = 0; i < 5; i++) {
rand_bytes(data, datalen);
if (sm2_z256_point_from_hash(&P, data, datalen, y_is_odd) != 1) {
error_print();
return -1;
}
if (sm2_z256_point_from_hash(&P, data, datalen, y_is_even) != 1) {
error_print();
return -1;
}
}
printf("%s() ok\n", __FUNCTION__);
return 1;
}
int main(void) int main(void)
{ {
if (test_sm2_z256_rshift() != 1) goto err;
if (test_sm2_z256_modp() != 1) goto err; if (test_sm2_z256_modp() != 1) goto err;
if (test_sm2_z256_modn() != 1) goto err; if (test_sm2_z256_modn() != 1) goto err;
if (test_sm2_z256_point_is_on_curve() != 1) goto err; if (test_sm2_z256_point_is_on_curve() != 1) goto err;
if (test_sm2_z256_point_equ() != 1) goto err;
if (test_sm2_z256_point_get_xy() != 1) goto err; if (test_sm2_z256_point_get_xy() != 1) goto err;
if (test_sm2_z256_point_add_conjugate() != 1) goto err;
if (test_sm2_z256_point_dbl_infinity() != 1) goto err;
if (test_sm2_z256_point_ops() != 1) goto err; if (test_sm2_z256_point_ops() != 1) goto err;
if (test_sm2_z256_point_mul_generator() != 1) goto err; if (test_sm2_z256_point_mul_generator() != 1) goto err;
if (test_sm2_z256_point_from_hash() != 1) goto err;
if (test_sm2_z256_point_from_x_bytes() != 1) goto err;
if (test_sm2_z256_modp_mont_sqrt() != 1) goto err;
printf("%s all tests passed\n", __FILE__); printf("%s all tests passed\n", __FILE__);
return 0; return 0;
err: err:

59
tools/sm2speed.c Normal file
View File

@@ -0,0 +1,59 @@
/*
* Copyright 2014-2024 The GmSSL Project. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <gmssl/sm2.h>
#include <gmssl/error.h>
int sm2sign_speed(void)
{
SM2_KEY sm2_key;
SM2_SIGN_CTX sign_ctx;
uint8_t msg[32];
uint8_t sig[SM2_MAX_SIGNATURE_SIZE];
size_t siglen;
size_t i;
if (sm2_key_generate(&sm2_key) != 1) {
error_print();
return -1;
}
if (sm2_sign_init(&sign_ctx, &sm2_key, SM2_DEFAULT_ID, strlen(SM2_DEFAULT_ID)) != 1) {
error_print();
return -1;
}
for (i = 0; i < 10000; i++) {
/*
if (sm2_sign_init(&sign_ctx, &sm2_key, SM2_DEFAULT_ID, strlen(SM2_DEFAULT_ID)) != 1) {
error_print();
return -1;
}
*/
if (sm2_sign_ctx_reset(&sign_ctx) != 1
|| sm2_sign_update(&sign_ctx, msg, sizeof(msg)) != 1
|| sm2_sign_finish(&sign_ctx, sig, &siglen) != 1) {
error_print();
return -1;
}
}
return 0;
}
int main(void)
{
sm2sign_speed();
return 0;
}

154
tools/sm3xmss_sign.c Normal file
View File

@@ -0,0 +1,154 @@
/*
* Copyright 2014-2023 The GmSSL Project. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
#include <stdio.h>
#include <errno.h>
#include <string.h>
#include <stdlib.h>
#include <gmssl/error.h>
#include <gmssl/sm3_xmss.h>
static const char *usage = "-key file [-in file] [-out file]\n";
static const char *help =
"Options\n"
" -key file Input private key file\n"
" -in file Input data file (if not using stdin)\n"
" -out file Output signature file\n"
"\n";
int sm3xmss_sign_main(int argc, char **argv)
{
int ret = 1;
char *prog = argv[0];
char *keyfile = NULL;
char *infile = NULL;
char *outfile = NULL;
FILE *keyfp = NULL;
FILE *infp = stdin;
FILE *outfp = stdout;
SM3_XMSS_KEY key;
SM3_XMSS_SIGN_CTX sign_ctx;
uint8_t *sigbuf = NULL;
size_t siglen;
argc--;
argv++;
if (argc < 1) {
fprintf(stderr, "usage: %s %s\n", prog, usage);
return 1;
}
while (argc > 0) {
if (!strcmp(*argv, "-help")) {
printf("usage: %s %s\n", prog, usage);
printf("%s\n", help);
ret = 0;
goto end;
} else if (!strcmp(*argv, "-key")) {
if (--argc < 1) goto bad;
keyfile = *(++argv);
if (!(keyfp = fopen(keyfile, "rb"))) {
fprintf(stderr, "%s: open '%s' failure: %s\n", prog, keyfile, strerror(errno));
goto end;
}
} else if (!strcmp(*argv, "-in")) {
if (--argc < 1) goto bad;
infile = *(++argv);
if (!(infp = fopen(infile, "rb"))) {
fprintf(stderr, "%s: open '%s' failure: %s\n", prog, infile, strerror(errno));
goto end;
}
} else if (!strcmp(*argv, "-out")) {
if (--argc < 1) goto bad;
outfile = *(++argv);
if (!(outfp = fopen(outfile, "wb"))) {
fprintf(stderr, "%s: open '%s' failure: %s\n", prog, outfile, strerror(errno));
goto end;
}
} else {
fprintf(stderr, "%s: illegal option '%s'\n", prog, *argv);
goto end;
bad:
fprintf(stderr, "%s: `%s` option value missing\n", prog, *argv);
goto end;
}
argc--;
argv++;
}
if (!keyfile) {
fprintf(stderr, "%s: `-key` option required\n", prog);
goto end;
}
if (sm3_xmss_key_from_bytes(&key, NULL, 0) != 1) {
error_print();
goto end;
}
if (fread(&key, 1, sizeof(key), keyfp) != sizeof(key)) {
fprintf(stderr, "%s: read private key failure\n", prog);
goto end;
}
if (sm3_xmss_sign_init(&sign_ctx, &key) != 1) {
error_print();
goto end;
}
while (1) {
uint8_t buf[1024];
size_t len = fread(buf, 1, sizeof(buf), infp);
if (len == 0) {
break;
}
if (sm3_xmss_sign_update(&sign_ctx, buf, len) != 1) {
error_print();
goto end;
}
}
if (sm3_xmss_sign_finish(&sign_ctx, &key, NULL, &siglen) != 1) {
error_print();
goto end;
}
if (!(sigbuf = malloc(siglen))) {
fprintf(stderr, "%s: malloc failure\n", prog);
goto end;
}
if (sm3_xmss_sign_finish(&sign_ctx, &key, sigbuf, &siglen) != 1) {
error_print();
goto end;
}
if (fwrite(sigbuf, 1, siglen, outfp) != siglen) {
error_print();
goto end;
}
ret = 0;
end:
gmssl_secure_clear(&key, sizeof(key));
gmssl_secure_clear(&sign_ctx, sizeof(sign_ctx));
if (sigbuf) {
gmssl_secure_clear(sigbuf, siglen);
free(sigbuf);
}
if (keyfp) fclose(keyfp);
if (infp && infp != stdin) fclose(infp);
if (outfp && outfp != stdout) fclose(outfp);
return ret;
}