Add GMUL ASM

This commit is contained in:
Zhi Guan
2024-03-31 16:54:24 +08:00
parent 6de0e0229b
commit 6b36c51cdf
5 changed files with 167 additions and 147 deletions

View File

@@ -245,6 +245,14 @@ if (ENABLE_SM2_ALGOR_ID_ENCODE_NULL)
endif() endif()
# Optional ARMv8 assembly implementation of GF(2^128) multiplication (off by default).
# When switched on this defines the ENABLE_GMUL_ARMV8 preprocessor macro, enables
# the ASM language for the project and appends the assembly source to the `src` list.
option(ENABLE_GMUL_ARMV8 "Enable GF(2^128) Multiplication ARMv8 assembly" OFF)
if (ENABLE_GMUL_ARMV8)
message(STATUS "ENABLE_GMUL_ARMV8 is ON")
add_definitions(-DENABLE_GMUL_ARMV8)
# ASM must be enabled so that the .S file below is assembled.
enable_language(ASM)
list(APPEND src src/gf128_armv8.S)
endif()
option(ENABLE_SM2_Z256_ARMV8 "Enable SM2_Z256 ARMv8 assembly" OFF) option(ENABLE_SM2_Z256_ARMV8 "Enable SM2_Z256 ARMv8 assembly" OFF)
if (ENABLE_SM2_Z256_ARMV8) if (ENABLE_SM2_Z256_ARMV8)
message(STATUS "ENABLE_SM2_Z256_ARMV8 is ON") message(STATUS "ENABLE_SM2_Z256_ARMV8 is ON")
@@ -489,6 +497,7 @@ if (WIN32)
elseif (APPLE) elseif (APPLE)
target_link_libraries(gmssl dl) target_link_libraries(gmssl dl)
target_link_libraries(gmssl "-framework Security") target_link_libraries(gmssl "-framework Security")
#target_link_libraries(gmssl "-framework OpenCL")
#target_link_libraries(gmssl "-framework CoreFoundation") # rand_apple.c CFRelease() #target_link_libraries(gmssl "-framework CoreFoundation") # rand_apple.c CFRelease()
elseif (MINGW) elseif (MINGW)
target_link_libraries(gmssl PRIVATE wsock32) target_link_libraries(gmssl PRIVATE wsock32)

View File

@@ -29,9 +29,11 @@ extern "C" {
//typedef unsigned __int128 gf128_t; //typedef unsigned __int128 gf128_t;
// the least significant bit of lo is a_0, the most significant bit of hi is a_127
// so x^7 + x^2 + x + 1 is 0x87
typedef struct { typedef struct {
uint64_t hi;
uint64_t lo; uint64_t lo;
uint64_t hi;
} gf128_t; } gf128_t;
@@ -45,7 +47,7 @@ gf128_t gf128_mul2(gf128_t a);
gf128_t gf128_from_bytes(const uint8_t p[16]); gf128_t gf128_from_bytes(const uint8_t p[16]);
void gf128_to_bytes(gf128_t a, uint8_t p[16]); void gf128_to_bytes(gf128_t a, uint8_t p[16]);
int gf128_print(FILE *fp, int fmt ,int ind, const char *label, gf128_t a); int gf128_print(FILE *fp, int fmt ,int ind, const char *label, gf128_t a);
void gf128_print_bits(gf128_t a);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@@ -1,131 +0,0 @@
/*
* Copyright 2014-2024 The GmSSL Project. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
#ifndef GMSSL_SM2_P256_H
#define GMSSL_SM2_P256_H
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#ifdef __cplusplus
extern "C" {
#endif
typedef uint64_t SM2_BN[8];
int sm2_bn_is_zero(const SM2_BN a);
int sm2_bn_is_one(const SM2_BN a);
int sm2_bn_is_odd(const SM2_BN a);
int sm2_bn_cmp(const SM2_BN a, const SM2_BN b);
int sm2_bn_from_hex(SM2_BN r, const char hex[64]);
int sm2_bn_from_asn1_integer(SM2_BN r, const uint8_t *d, size_t dlen);
int sm2_bn_equ_hex(const SM2_BN a, const char *hex);
int sm2_bn_print(FILE *fp, int fmt, int ind, const char *label, const SM2_BN a);
int sm2_bn_rshift(SM2_BN ret, const SM2_BN a, unsigned int nbits);
void sm2_bn_to_bytes(const SM2_BN a, uint8_t out[32]);
void sm2_bn_from_bytes(SM2_BN r, const uint8_t in[32]);
void sm2_bn_to_hex(const SM2_BN a, char hex[64]);
void sm2_bn_to_bits(const SM2_BN a, char bits[256]);
void sm2_bn_set_word(SM2_BN r, uint32_t a);
void sm2_bn_add(SM2_BN r, const SM2_BN a, const SM2_BN b);
void sm2_bn_sub(SM2_BN ret, const SM2_BN a, const SM2_BN b);
int sm2_bn_rand_range(SM2_BN r, const SM2_BN range);
// Convenience macros over a whole SM2_BN (an array of eight uint64_t).
// init / set_zero / clean all zero-fill the array; copy duplicates `a` into `r`.
// NOTE(review): these expand to memset()/memcpy(), but this header includes only
// <stdio.h>, <stdint.h> and <stdlib.h>; users presumably have to include <string.h>.
#define sm2_bn_init(r) memset((r),0,sizeof(SM2_BN))
#define sm2_bn_set_zero(r) memset((r),0,sizeof(SM2_BN))
#define sm2_bn_set_one(r) sm2_bn_set_word((r),1)
#define sm2_bn_copy(r,a) memcpy((r),(a),sizeof(SM2_BN))
#define sm2_bn_clean(r) memset((r),0,sizeof(SM2_BN))
// GF(p)
typedef SM2_BN SM2_Fp;
void sm2_fp_add(SM2_Fp r, const SM2_Fp a, const SM2_Fp b);
void sm2_fp_sub(SM2_Fp r, const SM2_Fp a, const SM2_Fp b);
void sm2_fp_mul(SM2_Fp r, const SM2_Fp a, const SM2_Fp b);
void sm2_fp_exp(SM2_Fp r, const SM2_Fp a, const SM2_Fp e);
void sm2_fp_dbl(SM2_Fp r, const SM2_Fp a);
void sm2_fp_tri(SM2_Fp r, const SM2_Fp a);
void sm2_fp_div2(SM2_Fp r, const SM2_Fp a);
void sm2_fp_neg(SM2_Fp r, const SM2_Fp a);
void sm2_fp_sqr(SM2_Fp r, const SM2_Fp a);
void sm2_fp_inv(SM2_Fp r, const SM2_Fp a);
int sm2_fp_rand(SM2_Fp r);
int sm2_fp_sqrt(SM2_Fp r, const SM2_Fp a);
// SM2_Fp is a typedef of SM2_BN, so these helpers simply forward to the
// corresponding sm2_bn_* macros.
#define sm2_fp_init(r) sm2_bn_init(r)
#define sm2_fp_set_zero(r) sm2_bn_set_zero(r)
#define sm2_fp_set_one(r) sm2_bn_set_one(r)
#define sm2_fp_copy(r,a) sm2_bn_copy(r,a)
#define sm2_fp_clean(r) sm2_bn_clean(r)
// GF(n)
typedef SM2_BN SM2_Fn;
void sm2_fn_add(SM2_Fn r, const SM2_Fn a, const SM2_Fn b);
void sm2_fn_sub(SM2_Fn r, const SM2_Fn a, const SM2_Fn b);
void sm2_fn_mul(SM2_Fn r, const SM2_Fn a, const SM2_Fn b);
void sm2_fn_mul_word(SM2_Fn r, const SM2_Fn a, uint32_t b);
void sm2_fn_exp(SM2_Fn r, const SM2_Fn a, const SM2_Fn e);
void sm2_fn_neg(SM2_Fn r, const SM2_Fn a);
void sm2_fn_sqr(SM2_Fn r, const SM2_Fn a);
void sm2_fn_inv(SM2_Fn r, const SM2_Fn a);
int sm2_fn_rand(SM2_Fn r);
// SM2_Fn is a typedef of SM2_BN, so these helpers simply forward to the
// corresponding sm2_bn_* macros.
#define sm2_fn_init(r) sm2_bn_init(r)
#define sm2_fn_set_zero(r) sm2_bn_set_zero(r)
#define sm2_fn_set_one(r) sm2_bn_set_one(r)
#define sm2_fn_copy(r,a) sm2_bn_copy(r,a)
#define sm2_fn_clean(r) sm2_bn_clean(r)
// A point stored as three SM2_BN values (each an array of eight uint64_t).
// NOTE(review): going by the type name, X/Y/Z are presumably Jacobian projective
// coordinates -- confirm against the implementation.
typedef struct {
SM2_BN X;
SM2_BN Y;
SM2_BN Z;
} SM2_JACOBIAN_POINT;
void sm2_jacobian_point_init(SM2_JACOBIAN_POINT *R);
void sm2_jacobian_point_set_xy(SM2_JACOBIAN_POINT *R, const SM2_BN x, const SM2_BN y);
void sm2_jacobian_point_get_xy(const SM2_JACOBIAN_POINT *P, SM2_BN x, SM2_BN y);
void sm2_jacobian_point_neg(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P);
void sm2_jacobian_point_dbl(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P);
void sm2_jacobian_point_add(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P, const SM2_JACOBIAN_POINT *Q);
void sm2_jacobian_point_sub(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P, const SM2_JACOBIAN_POINT *Q);
void sm2_jacobian_point_mul(SM2_JACOBIAN_POINT *R, const SM2_BN k, const SM2_JACOBIAN_POINT *P);
void sm2_jacobian_point_to_bytes(const SM2_JACOBIAN_POINT *P, uint8_t out[64]);
void sm2_jacobian_point_from_bytes(SM2_JACOBIAN_POINT *P, const uint8_t in[64]);
void sm2_jacobian_point_mul_generator(SM2_JACOBIAN_POINT *R, const SM2_BN k);
void sm2_jacobian_point_mul_sum(SM2_JACOBIAN_POINT *R, const SM2_BN t, const SM2_JACOBIAN_POINT *P, const SM2_BN s);
void sm2_jacobian_point_from_hex(SM2_JACOBIAN_POINT *P, const char hex[64 * 2]); // for testing only
int sm2_jacobian_point_is_at_infinity(const SM2_JACOBIAN_POINT *P);
int sm2_jacobian_point_is_on_curve(const SM2_JACOBIAN_POINT *P);
int sm2_jacobian_point_equ_hex(const SM2_JACOBIAN_POINT *P, const char hex[128]); // for testing only
int sm2_jacobian_point_print(FILE *fp, int fmt, int ind, const char *label, const SM2_JACOBIAN_POINT *P);
// set_infinity is an alias for sm2_jacobian_point_init(); copy is a plain
// memcpy() of the whole struct from P to R.
#define sm2_jacobian_point_set_infinity(R) sm2_jacobian_point_init(R)
#define sm2_jacobian_point_copy(R, P) memcpy((R), (P), sizeof(SM2_JACOBIAN_POINT))
const uint64_t *sm2_bn_prime(void);
const uint64_t *sm2_bn_order(void);
const uint64_t *sm2_bn_one(void);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -22,6 +22,19 @@
#include <gmssl/endian.h> #include <gmssl/endian.h>
#include <gmssl/error.h> #include <gmssl/error.h>
/*
 * Mirror the bit order of a 64-bit word: bit k of the input becomes
 * bit (63 - k) of the returned value.
 */
static uint64_t reverse_bits(uint64_t a)
{
	uint64_t mirrored = 0;
	unsigned int pos;

	/* Pick each input bit and drop it straight into its mirrored slot. */
	for (pos = 0; pos < 64; pos++) {
		mirrored |= ((a >> pos) & 1) << (63 - pos);
	}
	return mirrored;
}
gf128_t gf128_zero(void) gf128_t gf128_zero(void)
{ {
@@ -50,6 +63,10 @@ int gf128_equ_hex(gf128_t a, const char *s)
void gf128_print_bits(gf128_t a) void gf128_print_bits(gf128_t a)
{ {
int i; int i;
a.hi = reverse_bits(a.hi);
a.lo = reverse_bits(a.lo);
for (i = 0; i < 64; i++) { for (i = 0; i < 64; i++) {
printf("%d", (int)(a.hi % 2)); printf("%d", (int)(a.hi % 2));
a.hi >>= 1; a.hi >>= 1;
@@ -75,20 +92,6 @@ int gf128_print(FILE *fp, int fmt, int ind, const char *label, gf128_t a)
return 1; return 1;
} }
/*
 * Mirror the bit order of a 64-bit word: bit k of the input becomes
 * bit (63 - k) of the returned value.
 */
static uint64_t reverse_bits(uint64_t a)
{
	uint64_t mirrored = 0;
	unsigned int pos;

	/* Pick each input bit and drop it straight into its mirrored slot. */
	for (pos = 0; pos < 64; pos++) {
		mirrored |= ((a >> pos) & 1) << (63 - pos);
	}
	return mirrored;
}
gf128_t gf128_from_bytes(const uint8_t p[16]) gf128_t gf128_from_bytes(const uint8_t p[16])
{ {
gf128_t r; gf128_t r;
@@ -117,6 +120,7 @@ gf128_t gf128_add(gf128_t a, gf128_t b)
return r; return r;
} }
#ifndef ENABLE_GMUL_ARMV8
gf128_t gf128_mul(gf128_t a, gf128_t b) gf128_t gf128_mul(gf128_t a, gf128_t b)
{ {
gf128_t r = {0, 0}; gf128_t r = {0, 0};
@@ -159,6 +163,18 @@ gf128_t gf128_mul(gf128_t a, gf128_t b)
return r; return r;
} }
#else
/*
 * External multiplication routine used when ENABLE_GMUL_ARMV8 is defined.
 * Each operand and the result are passed as two 64-bit words.
 */
extern void gmul(uint64_t r[2], const uint64_t a[2], const uint64_t b[2]);

/*
 * Multiply two GF(2^128) elements by delegating to gmul().
 *
 * The operands are marshalled through plain uint64_t[2] arrays rather than
 * passing `gf128_t *` where gmul() declares `uint64_t *`: those pointer types
 * are incompatible. Word 0 carries `lo` and word 1 carries `hi`.
 * NOTE(review): this matches what gmul() saw before only if `lo` is the first
 * member of gf128_t -- confirm against the struct definition.
 *
 * Returns the product as a gf128_t.
 */
gf128_t gf128_mul(gf128_t a, gf128_t b)
{
	uint64_t a_words[2];
	uint64_t b_words[2];
	uint64_t r_words[2];
	gf128_t r;

	a_words[0] = a.lo;
	a_words[1] = a.hi;
	b_words[0] = b.lo;
	b_words[1] = b.hi;

	gmul(r_words, a_words, b_words);

	r.lo = r_words[0];
	r.hi = r_words[1];
	return r;
}
#endif
gf128_t gf128_mul2(gf128_t a) gf128_t gf128_mul2(gf128_t a)
{ {

View File

@@ -17,6 +17,125 @@
#include <gmssl/error.h> #include <gmssl/error.h>
/*
 * Exploratory test for gf128_mul(): runs a table of hex vectors through the
 * multiplication and dumps operands and product in several forms.
 *
 * A product that differs from the expected value is reported through
 * error_print() but does not fail the test (the failure return is disabled),
 * so this function always returns 1.
 */
int test_gf128_mul_more(void)
{
	struct {
		char *label;
		char *r;
		char *a;
		char *b;
	} tests[] = {
		{
			"1 * 0",
			"0000000000000000" "0000000000000000",
			"8000000000000000" "0000000000000000",
			"0000000000000000" "0000000000000000",
		},
		// Author's note: this vector is obviously not right at the moment.
		{
			"1 * 1",
			"8000000000000000" "0000000000000000",
			"8000000000000000" "0000000000000000",
			"8000000000000000" "0000000000000000",
		},
		{
			" * 2",
			"e1000000000000000000000000000000",
			"00000000000000000000000000000001",
			"40000000000000000000000000000000",
		},
		{
			"a * 2",
			"8e1807c980d24cd4b2fc5fb3bf4cf406",
			"de300f9301a499a965f8bf677e99e80d",
			"40000000000000000000000000000000",
		},
		{
			"a * b",
			"7d87dda57a20b0c51d9743071ab14010",
			"de300f9301a499a965f8bf677e99e80d",
			"14b267838ec9ef1bb7b5ce8c19e34bc6",
		},
	};
	gf128_t r;
	gf128_t a;
	gf128_t b;
	size_t i;

	gf128_t one = { 1, 0 };
	uint8_t buf[16];

	// Show how the element { 1, 0 } serializes to bytes.
	gf128_to_bytes(one, buf);
	format_bytes(stderr, 0, 0, "one", buf, 16);
	printf("\n");

	for (i = 0; i < sizeof(tests)/sizeof(tests[0]); i++) {
		printf("test %zu\n", i);

		a = gf128_from_hex(tests[i].a);
		b = gf128_from_hex(tests[i].b);

		// uint64_t is not necessarily `unsigned long long`; cast so that the
		// arguments match the %llx conversions exactly.
		printf("a0 = %llx, a1 = %llx\n",
			(unsigned long long)a.lo, (unsigned long long)a.hi);
		printf("b0 = %llx, b1 = %llx\n",
			(unsigned long long)b.lo, (unsigned long long)b.hi);

		r = gf128_mul(a, b);

		printf("r0 = %llx, r1 = %llx\n",
			(unsigned long long)r.lo, (unsigned long long)r.hi);

		gf128_print_bits(a);
		gf128_print_bits(b);
		gf128_print_bits(r);

		gf128_to_bytes(r, buf);
		format_bytes(stderr, 0, 0, "r" ,buf ,16);

		if (gf128_equ_hex(r, tests[i].r) != 1) {
			error_print();
			// Failure return intentionally disabled: mismatches are only reported.
			//return -1;
		}
	}

	printf("%s() ok\n", __FUNCTION__);
	return 1;
}
/*
 * Smoke test for gf128_mul(): multiplies two fixed elements and prints the
 * bits of the product. Nothing is compared against an expected value, so
 * the function always returns 1.
 */
int test_gf128_armv8(void)
{
	// The operands are read straight from hex; the previous `{ 1, 0 }`
	// initializers were overwritten before any use and have been dropped.
	gf128_t a = gf128_from_hex("de300f9301a499a965f8bf677e99e80d");
	gf128_t b = gf128_from_hex("14b267838ec9ef1bb7b5ce8c19e34bc6");
	gf128_t c;

	// Author's note: pmull performs the multiplication on the low-order bits.
	c = gf128_mul(a, b);
	gf128_print_bits(c);

	return 1;
}
/*
 * Print one sample element through both gf128_print() and
 * gf128_print_bits(). No output is checked; the function always returns 1.
 */
int test_gf128_print(void)
{
	// Intended value per the author: 1 + 0*x + ... + 0*x^127
	gf128_t elem = { 0, 0x8000000000000000 };

	gf128_print(stderr, 0, 0, "1 + 0*x + ... + 0*x^127", elem);

	// Author's note: this call prints the wrong thing, because the real value
	// would need to go through reverse_bits and no reversal is applied here.
	gf128_print_bits(elem);

	// Author's note: the result looks rather odd.
	return 1;
}
int test_gf128_from_hex(void) int test_gf128_from_hex(void)
{ {
char *tests[] = { char *tests[] = {
@@ -67,6 +186,8 @@ int test_gf128_mul2(void)
return 1; return 1;
} }
int test_gf128_mul(void) int test_gf128_mul(void)
{ {
char *hex_a = "de300f9301a499a965f8bf677e99e80d"; char *hex_a = "de300f9301a499a965f8bf677e99e80d";
@@ -96,6 +217,9 @@ int test_gf128_mul(void)
int main(void) int main(void)
{ {
if (test_gf128_armv8() != 1) goto err;
if (test_gf128_mul_more() != 1) goto err;
if (test_gf128_print() != 1) goto err;
if (test_gf128_from_hex() != 1) goto err; if (test_gf128_from_hex() != 1) goto err;
if (test_gf128_mul2() != 1) goto err; if (test_gf128_mul2() != 1) goto err;
if (test_gf128_mul() != 1) goto err; if (test_gf128_mul() != 1) goto err;