diff --git a/include/gmssl/sm4_cl.h b/include/gmssl/sm4_cl.h index c3a5fe64..ac73f971 100644 --- a/include/gmssl/sm4_cl.h +++ b/include/gmssl/sm4_cl.h @@ -1,5 +1,5 @@ /* - * Copyright 2014-2022 The GmSSL Project. All Rights Reserved. + * Copyright 2014-2024 The GmSSL Project. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the License); you may * not use this file except in compliance with the License. @@ -40,7 +40,7 @@ typedef struct { int sm4_cl_set_encrypt_key(SM4_CL_CTX *ctx, const uint8_t key[16]); int sm4_cl_set_decrypt_key(SM4_CL_CTX *ctx, const uint8_t key[16]); -int sm4_cl_ctr32_encrypt(SM4_CL_CTX *ctx, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out); +int sm4_cl_ctr32_encrypt_blocks(SM4_CL_CTX *ctx, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out); void sm4_cl_cleanup(SM4_CL_CTX *ctx); diff --git a/src/sm4_cl.c b/src/sm4_cl.c index 52ad3b52..05ade0fb 100644 --- a/src/sm4_cl.c +++ b/src/sm4_cl.c @@ -198,7 +198,7 @@ static int sm4_cl_set_key(SM4_CL_CTX *ctx, const uint8_t key[16], int enc) free(log); goto end; } - if (!(ctx->kernel = clCreateKernel(ctx->program, "sm4_ctr32_encrypt", &err))) { + if (!(ctx->kernel = clCreateKernel(ctx->program, "sm4_ctr32_encrypt_blocks", &err))) { cl_error_print(err); goto end; } @@ -243,7 +243,7 @@ int sm4_cl_set_decrypt_key(SM4_CL_CTX *ctx, const uint8_t key[16]) return sm4_cl_set_key(ctx, key, 0); } -int sm4_cl_ctr32_encrypt(SM4_CL_CTX *ctx, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) +int sm4_cl_ctr32_encrypt_blocks(SM4_CL_CTX *ctx, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) { int ret = -1; cl_int err; @@ -330,7 +330,7 @@ __constant unsigned char SBOX[256] = { 0x18, 0xf0, 0x7d, 0xec, 0x3a, 0xdc, 0x4d, 0x20, 0x79, 0xee, 0x5f, 0x3e, 0xd7, 0xcb, 0x39, 0x48, }; -__kernel void sm4_ctr32_encrypt(__global const unsigned int *rkey, __global const unsigned int *ctr, __global unsigned char *data) +__kernel void sm4_ctr32_encrypt_blocks(__global const unsigned int *rkey, __global const unsigned int *ctr, __global unsigned char *data) { unsigned int x0, x1, x2, x3, x4, i, t; uint global_id = get_global_id(0); diff --git a/tests/sm4_cltest.c b/tests/sm4_cltest.c index af861325..2999c342 100644 --- a/tests/sm4_cltest.c +++ b/tests/sm4_cltest.c @@ -18,7 +18,7 @@ #include -static int test_sm4_cl_ctr32_encrypt(void) +static int test_sm4_cl_ctr32_encrypt_blocks(void) { const char *key_hex = "0123456789abcdeffedcba9876543210"; const char *iv_hex = "0123456789abcdeffedcba9876543210"; @@ -56,7 +56,7 @@ static int test_sm4_cl_ctr32_encrypt(void) } memcpy(ctr, iv, sizeof(iv)); - if (sm4_cl_ctr32_encrypt(&ctx, ctr, buf, nblocks, buf) != 1) { + if (sm4_cl_ctr32_encrypt_blocks(&ctx, ctr, buf, nblocks, buf) != 1) { error_print(); goto end; } @@ -75,7 +75,7 @@ end: return ret; } -static int test_sm4_cl_ctr32_encrypt_speed(void) +static int speed_sm4_cl_ctr32_encrypt_blocks(void) { const uint8_t key[16] = { 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, @@ -96,7 +96,7 @@ static int test_sm4_cl_ctr32_encrypt_speed(void) uint8_t ctr[16]; size_t nblocks = 1024*1024; uint8_t *buf = NULL; - clock_t start, end; + clock_t begin, end; double seconds; size_t i; @@ -113,15 +113,15 @@ static int test_sm4_cl_ctr32_encrypt_speed(void) goto end; } - start = clock(); - if (sm4_cl_ctr32_encrypt(&ctx, ctr, buf, nblocks, buf) != 1) { + begin = clock(); + if (sm4_cl_ctr32_encrypt_blocks(&ctx, ctr, buf, nblocks, buf) != 1) { error_print(); goto end; } end = clock(); - seconds = (double)(end - start)/CLOCKS_PER_SEC; - fprintf(stderr, "sm4_cl_encrypt: %f-MiB per seconds\n", 16/seconds); + seconds = (double)(end - begin)/CLOCKS_PER_SEC; + fprintf(stderr, "%s: %f-MiB per seconds\n", __FUNCTION__, 16/seconds); ret = 1; end: @@ -132,8 +132,10 @@ end: int main(void) { - if (test_sm4_cl_ctr32_encrypt() != 1) goto err; - if (test_sm4_cl_ctr32_encrypt_speed() != 1) goto err; + if (test_sm4_cl_ctr32_encrypt_blocks() != 1) goto err; +#if ENABLE_TEST_SPEED + if (speed_sm4_cl_ctr32_encrypt_blocks() != 1) goto err; +#endif printf("%s all tests passed\n", __FILE__); return 0; err: