From 6b36c51cdf5a80971bbc171b90fc675bb33a51fe Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Sun, 31 Mar 2024 16:54:24 +0800 Subject: [PATCH] Add GMUL ASM --- CMakeLists.txt | 9 +++ include/gmssl/gf128.h | 6 +- include/gmssl/sm2_p256.h | 131 --------------------------------------- src/gf128.c | 44 ++++++++----- tests/gf128test.c | 124 ++++++++++++++++++++++++++++++++++++ 5 files changed, 167 insertions(+), 147 deletions(-) delete mode 100644 include/gmssl/sm2_p256.h diff --git a/CMakeLists.txt b/CMakeLists.txt index e668360d..64072714 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -245,6 +245,14 @@ if (ENABLE_SM2_ALGOR_ID_ENCODE_NULL) endif() +option(ENABLE_GMUL_ARMV8 "Enable GF(2^128) Multiplication ARMv8 assembly" OFF) +if (ENABLE_GMUL_ARMV8) + message(STATUS "ENABLE_GMUL_ARMV8 is ON") + add_definitions(-DENABLE_GMUL_ARMV8) + enable_language(ASM) + list(APPEND src src/gf128_armv8.S) +endif() + option(ENABLE_SM2_Z256_ARMV8 "Enable SM2_Z256 ARMv8 assembly" OFF) if (ENABLE_SM2_Z256_ARMV8) message(STATUS "ENABLE_SM2_Z256_ARMV8 is ON") @@ -489,6 +497,7 @@ if (WIN32) elseif (APPLE) target_link_libraries(gmssl dl) target_link_libraries(gmssl "-framework Security") + #target_link_libraries(gmssl "-framework OpenCL") #target_link_libraries(gmssl "-framework CoreFoundation") # rand_apple.c CFRelease() elseif (MINGW) target_link_libraries(gmssl PRIVATE wsock32) diff --git a/include/gmssl/gf128.h b/include/gmssl/gf128.h index 8263e4ab..80e535f8 100644 --- a/include/gmssl/gf128.h +++ b/include/gmssl/gf128.h @@ -29,9 +29,11 @@ extern "C" { //typedef unsigned __int128 gf128_t; +// the least significant bit of lo is a_0, the most significant bit of hi is a_127 +// so x^7 + x^2 + x + 1 is 0x87 typedef struct { - uint64_t hi; uint64_t lo; + uint64_t hi; } gf128_t; @@ -45,7 +47,7 @@ gf128_t gf128_mul2(gf128_t a); gf128_t gf128_from_bytes(const uint8_t p[16]); void gf128_to_bytes(gf128_t a, uint8_t p[16]); int gf128_print(FILE *fp, int fmt ,int ind, const char *label, gf128_t a); - +void gf128_print_bits(gf128_t a); #ifdef __cplusplus } diff --git a/include/gmssl/sm2_p256.h b/include/gmssl/sm2_p256.h deleted file mode 100644 index 3508215e..00000000 --- a/include/gmssl/sm2_p256.h +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Copyright 2014-2024 The GmSSL Project. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the License); you may - * not use this file except in compliance with the License. - * - * http://www.apache.org/licenses/LICENSE-2.0 - */ - - - -#ifndef GMSSL_SM2_P256_H -#define GMSSL_SM2_P256_H - -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -typedef uint64_t SM2_BN[8]; - -int sm2_bn_is_zero(const SM2_BN a); -int sm2_bn_is_one(const SM2_BN a); -int sm2_bn_is_odd(const SM2_BN a); -int sm2_bn_cmp(const SM2_BN a, const SM2_BN b); -int sm2_bn_from_hex(SM2_BN r, const char hex[64]); -int sm2_bn_from_asn1_integer(SM2_BN r, const uint8_t *d, size_t dlen); -int sm2_bn_equ_hex(const SM2_BN a, const char *hex); -int sm2_bn_print(FILE *fp, int fmt, int ind, const char *label, const SM2_BN a); -int sm2_bn_rshift(SM2_BN ret, const SM2_BN a, unsigned int nbits); - -void sm2_bn_to_bytes(const SM2_BN a, uint8_t out[32]); -void sm2_bn_from_bytes(SM2_BN r, const uint8_t in[32]); -void sm2_bn_to_hex(const SM2_BN a, char hex[64]); -void sm2_bn_to_bits(const SM2_BN a, char bits[256]); -void sm2_bn_set_word(SM2_BN r, uint32_t a); -void sm2_bn_add(SM2_BN r, const SM2_BN a, const SM2_BN b); -void sm2_bn_sub(SM2_BN ret, const SM2_BN a, const SM2_BN b); -int sm2_bn_rand_range(SM2_BN r, const SM2_BN range); - -#define sm2_bn_init(r) memset((r),0,sizeof(SM2_BN)) -#define sm2_bn_set_zero(r) memset((r),0,sizeof(SM2_BN)) -#define sm2_bn_set_one(r) sm2_bn_set_word((r),1) -#define sm2_bn_copy(r,a) memcpy((r),(a),sizeof(SM2_BN)) -#define sm2_bn_clean(r) memset((r),0,sizeof(SM2_BN)) - - -// GF(p) -typedef SM2_BN SM2_Fp; - -void sm2_fp_add(SM2_Fp r, const SM2_Fp a, const SM2_Fp b); -void sm2_fp_sub(SM2_Fp r, const SM2_Fp a, const SM2_Fp b); -void sm2_fp_mul(SM2_Fp r, const SM2_Fp a, const SM2_Fp b); -void sm2_fp_exp(SM2_Fp r, const SM2_Fp a, const SM2_Fp e); -void sm2_fp_dbl(SM2_Fp r, const SM2_Fp a); -void sm2_fp_tri(SM2_Fp r, const SM2_Fp a); -void sm2_fp_div2(SM2_Fp r, const SM2_Fp a); -void sm2_fp_neg(SM2_Fp r, const SM2_Fp a); -void sm2_fp_sqr(SM2_Fp r, const SM2_Fp a); -void sm2_fp_inv(SM2_Fp r, const SM2_Fp a); -int sm2_fp_rand(SM2_Fp r); - -int sm2_fp_sqrt(SM2_Fp r, const SM2_Fp a); - -#define sm2_fp_init(r) sm2_bn_init(r) -#define sm2_fp_set_zero(r) sm2_bn_set_zero(r) -#define sm2_fp_set_one(r) sm2_bn_set_one(r) -#define sm2_fp_copy(r,a) sm2_bn_copy(r,a) -#define sm2_fp_clean(r) sm2_bn_clean(r) - -// GF(n) -typedef SM2_BN SM2_Fn; - -void sm2_fn_add(SM2_Fn r, const SM2_Fn a, const SM2_Fn b); -void sm2_fn_sub(SM2_Fn r, const SM2_Fn a, const SM2_Fn b); -void sm2_fn_mul(SM2_Fn r, const SM2_Fn a, const SM2_Fn b); -void sm2_fn_mul_word(SM2_Fn r, const SM2_Fn a, uint32_t b); -void sm2_fn_exp(SM2_Fn r, const SM2_Fn a, const SM2_Fn e); -void sm2_fn_neg(SM2_Fn r, const SM2_Fn a); -void sm2_fn_sqr(SM2_Fn r, const SM2_Fn a); -void sm2_fn_inv(SM2_Fn r, const SM2_Fn a); -int sm2_fn_rand(SM2_Fn r); - -#define sm2_fn_init(r) sm2_bn_init(r) -#define sm2_fn_set_zero(r) sm2_bn_set_zero(r) -#define sm2_fn_set_one(r) sm2_bn_set_one(r) -#define sm2_fn_copy(r,a) sm2_bn_copy(r,a) -#define sm2_fn_clean(r) sm2_bn_clean(r) - - -typedef struct { - SM2_BN X; - SM2_BN Y; - SM2_BN Z; -} SM2_JACOBIAN_POINT; - -void sm2_jacobian_point_init(SM2_JACOBIAN_POINT *R); -void sm2_jacobian_point_set_xy(SM2_JACOBIAN_POINT *R, const SM2_BN x, const SM2_BN y); -void sm2_jacobian_point_get_xy(const SM2_JACOBIAN_POINT *P, SM2_BN x, SM2_BN y); -void sm2_jacobian_point_neg(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P); -void sm2_jacobian_point_dbl(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P); -void sm2_jacobian_point_add(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P, const SM2_JACOBIAN_POINT *Q); -void sm2_jacobian_point_sub(SM2_JACOBIAN_POINT *R, const SM2_JACOBIAN_POINT *P, const SM2_JACOBIAN_POINT *Q); -void sm2_jacobian_point_mul(SM2_JACOBIAN_POINT *R, const SM2_BN k, const SM2_JACOBIAN_POINT *P); -void sm2_jacobian_point_to_bytes(const SM2_JACOBIAN_POINT *P, uint8_t out[64]); -void sm2_jacobian_point_from_bytes(SM2_JACOBIAN_POINT *P, const uint8_t in[64]); -void sm2_jacobian_point_mul_generator(SM2_JACOBIAN_POINT *R, const SM2_BN k); -void sm2_jacobian_point_mul_sum(SM2_JACOBIAN_POINT *R, const SM2_BN t, const SM2_JACOBIAN_POINT *P, const SM2_BN s); -void sm2_jacobian_point_from_hex(SM2_JACOBIAN_POINT *P, const char hex[64 * 2]); // for testing only - -int sm2_jacobian_point_is_at_infinity(const SM2_JACOBIAN_POINT *P); -int sm2_jacobian_point_is_on_curve(const SM2_JACOBIAN_POINT *P); -int sm2_jacobian_point_equ_hex(const SM2_JACOBIAN_POINT *P, const char hex[128]); // for testing only -int sm2_jacobian_point_print(FILE *fp, int fmt, int ind, const char *label, const SM2_JACOBIAN_POINT *P); - -#define sm2_jacobian_point_set_infinity(R) sm2_jacobian_point_init(R) -#define sm2_jacobian_point_copy(R, P) memcpy((R), (P), sizeof(SM2_JACOBIAN_POINT)) - -const uint64_t *sm2_bn_prime(void); -const uint64_t *sm2_bn_order(void); -const uint64_t *sm2_bn_one(void); - - -#ifdef __cplusplus -} -#endif -#endif - diff --git a/src/gf128.c b/src/gf128.c index 8f39969a..04537133 100644 --- a/src/gf128.c +++ b/src/gf128.c @@ -22,6 +22,19 @@ #include #include +static uint64_t reverse_bits(uint64_t a) +{ + uint64_t r = 0; + int i; + + for (i = 0; i < 63; i++) { + r |= a & 1; + r <<= 1; + a >>= 1; + } + r |= a & 1; + return r; +} gf128_t gf128_zero(void) { @@ -50,6 +63,10 @@ int gf128_equ_hex(gf128_t a, const char *s) void gf128_print_bits(gf128_t a) { int i; + + a.hi = reverse_bits(a.hi); + a.lo = reverse_bits(a.lo); + for (i = 0; i < 64; i++) { printf("%d", (int)(a.hi % 2)); a.hi >>= 1; @@ -75,20 +92,6 @@ int gf128_print(FILE *fp, int fmt, int ind, const char *label, gf128_t a) return 1; } -static uint64_t reverse_bits(uint64_t a) -{ - uint64_t r = 0; - int i; - - for (i = 0; i < 63; i++) { - r |= a & 1; - r <<= 1; - a >>= 1; - } - r |= a & 1; - return r; -} - gf128_t gf128_from_bytes(const uint8_t p[16]) { gf128_t r; @@ -117,6 +120,7 @@ gf128_t gf128_add(gf128_t a, gf128_t b) return r; } +#ifndef ENABLE_GMUL_ARMV8 gf128_t gf128_mul(gf128_t a, gf128_t b) { gf128_t r = {0, 0}; @@ -159,6 +163,18 @@ gf128_t gf128_mul(gf128_t a, gf128_t b) return r; } +#else + +extern void gmul(uint64_t r[2], const uint64_t a[2], const uint64_t b[2]); + +gf128_t gf128_mul(gf128_t a, gf128_t b) +{ + gf128_t r; + gmul(&r, &a, &b); + return r; +} + +#endif gf128_t gf128_mul2(gf128_t a) { diff --git a/tests/gf128test.c b/tests/gf128test.c index ad64956a..47a99ead 100644 --- a/tests/gf128test.c +++ b/tests/gf128test.c @@ -17,6 +17,125 @@ #include +int test_gf128_mul_more(void) +{ + struct { + char *label; + char *r; + char *a; + char *b; + } tests[] = { + { + "1 * 0", + "0000000000000000" "0000000000000000", + "8000000000000000" "0000000000000000", + "0000000000000000" "0000000000000000", + }, + // 这个现在显然是不对的 + { + "1 * 1", + "8000000000000000" "0000000000000000", + "8000000000000000" "0000000000000000", + "8000000000000000" "0000000000000000", + }, + { + " * 2", + "e1000000000000000000000000000000", + "00000000000000000000000000000001", + "40000000000000000000000000000000", + }, + { + "a * 2", + "8e1807c980d24cd4b2fc5fb3bf4cf406", + "de300f9301a499a965f8bf677e99e80d", + "40000000000000000000000000000000", + }, + { + "a * b", + "7d87dda57a20b0c51d9743071ab14010", + "de300f9301a499a965f8bf677e99e80d", + "14b267838ec9ef1bb7b5ce8c19e34bc6", + }, + }; + + gf128_t r; + gf128_t a; + gf128_t b; + size_t i; + + + gf128_t one = { 1, 0 }; + uint8_t buf[16]; + + gf128_to_bytes(one, buf); + + format_bytes(stderr, 0, 0, "one", buf, 16); + + printf("\n"); + + for (i = 0; i < sizeof(tests)/sizeof(tests[0]); i++) { + + printf("test %zu\n", i); + + a = gf128_from_hex(tests[i].a); + b = gf128_from_hex(tests[i].b); + + printf("a0 = %llx, a1 = %llx\n", a.lo, a.hi); + printf("b0 = %llx, b1 = %llx\n", b.lo, b.hi); + + r = gf128_mul(a, b); + + printf("r0 = %llx, r1 = %llx\n", r.lo, r.hi); + gf128_print_bits(a); + gf128_print_bits(b); + gf128_print_bits(r); + + gf128_to_bytes(r, buf); + format_bytes(stderr, 0, 0, "r" ,buf ,16); + + if (gf128_equ_hex(r, tests[i].r) != 1) { + error_print(); + //return -1; + } + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +int test_gf128_armv8(void) +{ + gf128_t a = { 1, 0 }; + gf128_t b = { 1, 0 }; + + a = gf128_from_hex("de300f9301a499a965f8bf677e99e80d"); + b = gf128_from_hex("14b267838ec9ef1bb7b5ce8c19e34bc6"); + + // pmull 是对低位做了乘法 + //gf128_print_bits(b); + + gf128_t c = gf128_mul(a, b); + + gf128_print_bits(c); + return 1; +} + + +int test_gf128_print(void) +{ + gf128_t a = { 0, 0x8000000000000000 }; // a = 1 + 0*x + ... + 0*x^127 + gf128_print(stderr, 0, 0, "1 + 0*x + ... + 0*x^127", a); + + // 这个函数打印的不对,因为真正的值是需要 reverse_bits 的,但是这里我们没有反转 + gf128_print_bits(a); + + // 看来这个比较奇怪了 + + + return 1; + +} + int test_gf128_from_hex(void) { char *tests[] = { @@ -67,6 +186,8 @@ int test_gf128_mul2(void) return 1; } + + int test_gf128_mul(void) { char *hex_a = "de300f9301a499a965f8bf677e99e80d"; @@ -96,6 +217,9 @@ int test_gf128_mul(void) int main(void) { + if (test_gf128_armv8() != 1) goto err; + if (test_gf128_mul_more() != 1) goto err; + if (test_gf128_print() != 1) goto err; if (test_gf128_from_hex() != 1) goto err; if (test_gf128_mul2() != 1) goto err; if (test_gf128_mul() != 1) goto err;