Add gf128_mul arm64 intrinsics

This commit is contained in:
Zhi Guan
2024-05-31 21:01:49 +08:00
parent e9e2d27021
commit 26750fbb1d
3 changed files with 111 additions and 1 deletions

View File

@@ -256,7 +256,8 @@ if (ENABLE_GMUL_ARM64)
message(STATUS "ENABLE_GMUL_ARM64 is ON") message(STATUS "ENABLE_GMUL_ARM64 is ON")
add_definitions(-DENABLE_GMUL_ARM64) add_definitions(-DENABLE_GMUL_ARM64)
enable_language(ASM) enable_language(ASM)
list(APPEND src src/gf128_arm64.S) #list(APPEND src src/gf128_arm64.S)
list(APPEND src src/gf128_arm64.c)
endif() endif()

74
src/gf128_arm64.c Normal file
View File

@@ -0,0 +1,74 @@
/*
* Copyright 2014-2024 The GmSSL Project. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
#include <stdint.h>
#include <arm_neon.h>
#include <gmssl/gf128.h>
// this version is converted from the gf128_arm64.S by ChatGPT 4
// a little slower than the asm version
void gf128_mul(gf128_t r, const gf128_t a, const gf128_t b)
{
// Prepare zero
uint8x16_t v0;
uint8x16_t vzero = veorq_u8(v0, v0);
// Set f(x) = x^7 + x^2 + x + 1 (0x87) and prepare it by shifting right
uint8x16_t v7 = vdupq_n_u8(0x87);
uint64x2_t v7_shifted = vreinterpretq_u64_u8(v7);
v7_shifted = vshrq_n_u64(v7_shifted, 56);
// Load (a0, a1) and (b0, b1)
uint8x16_t va = vld1q_u8((const uint8_t*) a);
uint8x16_t vb = vld1q_u8((const uint8_t*) b);
// c = a0 * b0
poly64x2_t v3 = (poly64x2_t)vmull_p64(vget_low_p64(vreinterpretq_p64_u8(va)), vget_low_p64(vreinterpretq_p64_u8(vb)));
// a0 + a1 and b0 + b1
uint8x16_t va0a1 = vextq_u8(va, va, 8);
va0a1 = veorq_u8(va0a1, va);
uint8x16_t vb0b1 = vextq_u8(vb, vb, 8);
vb0b1 = veorq_u8(vb0b1, vb);
// d' = a1 * b1
poly64x2_t v4 = (poly64x2_t)vmull_high_p64(vreinterpretq_p64_u8(va), vreinterpretq_p64_u8(vb));
// e = (a0 + a1) * (b0 + b1) - a0 * b0 - a1 * b1
poly64x2_t v5 = (poly64x2_t)vmull_p64(vget_low_p64(vreinterpretq_p64_u8(va0a1)), vget_low_p64(vreinterpretq_p64_u8(vb0b1)));
v5 = veorq_u64(v5, v3);
v5 = veorq_u64(v5, v4);
// d = d' + e1
uint8x16_t ve1 = vextq_u8(vreinterpretq_u8_u64(v5), vzero, 8);
v4 = veorq_u64(v4, vreinterpretq_u64_u8(ve1));
// w = d1 * f0
poly64x2_t v6 = (poly64x2_t)vmull_high_p64(vreinterpretq_p64_u8(v4), v7_shifted);
// (e0 + w0) * x^64
v5 = veorq_u64(v5, v6);
uint8x16_t ve0w0 = vextq_u8(vzero, vreinterpretq_u8_u64(v5), 8);
// c = c + (e0 + w0) * x^64
v3 = veorq_u64(v3, vreinterpretq_u64_u8(ve0w0));
// (d0 + w1) * f0
uint8x16_t vw1 = vextq_u8(vreinterpretq_u8_u64(v6), vzero, 8);
v4 = veorq_u64(v4, vreinterpretq_u64_u8(vw1));
v4 = (poly64x2_t)vmull_p64(vget_low_p64(vreinterpretq_p64_u64(v4)), vget_low_p64(v7_shifted));
// c += (d0 + w1) * f0
v3 = veorq_u64(v3, v4);
// Output the result
vst1q_u8((uint8_t*) r, vreinterpretq_u8_u64(v3));
}

View File

@@ -11,6 +11,7 @@
#include <stdio.h> #include <stdio.h>
#include <string.h> #include <string.h>
#include <stdlib.h> #include <stdlib.h>
#include <time.h>
#include <gmssl/ghash.h> #include <gmssl/ghash.h>
#include <gmssl/hex.h> #include <gmssl/hex.h>
#include <gmssl/rand.h> #include <gmssl/rand.h>
@@ -180,10 +181,44 @@ int test_gcm(void)
} }
#endif #endif
static int speed_ghash(void)
{
GHASH_CTX ghash_ctx;
uint8_t h[16] = {0};
uint8_t aad[20] = {0};
uint8_t blocks[4096];
uint8_t ghash[16];
clock_t start, end;
double seconds;
int i;
ghash_init(&ghash_ctx, h, aad, sizeof(aad));
for (i = 0; i < 4096; i++) {
ghash_update(&ghash_ctx, blocks, sizeof(blocks));
}
start = clock();
ghash_init(&ghash_ctx, h, aad, sizeof(aad));
for (i = 0; i < 4096; i++) {
ghash_update(&ghash_ctx, blocks, sizeof(blocks));
}
ghash_finish(&ghash_ctx, ghash);
end = clock();
seconds = (double)(end - start)/CLOCKS_PER_SEC;
fprintf(stderr, "%s: %f MiB per second\n", __FUNCTION__, 16/seconds);
return 1;
}
int main(int argc, char **argv) int main(int argc, char **argv)
{ {
if (test_ghash() != 1) goto err; if (test_ghash() != 1) goto err;
#if ENABLE_TEST_SPEED
speed_ghash();
#endif
printf("%s all tests passed\n", __FILE__); printf("%s all tests passed\n", __FILE__);
return 0; return 0;
err: err: