diff --git a/CMakeLists.txt b/CMakeLists.txt index c5a05d76..c07cc000 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -327,6 +327,15 @@ if (ENABLE_SM4_CTR_AESNI_AVX) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native") endif() +if (ENABLE_SM4_CL) + message(STATUS "ENABLE_SM4_CL is ON") + if (CMAKE_SYSTEM_NAME STREQUAL "Darwin") + add_definitions(-DMACOS) # to include + endif() + list(APPEND src src/sm4_cl.c) + list(APPEND tests sm4_cl) +endif() + if (ENABLE_SM4_ECB) message(STATUS "ENABLE_SM4_ECB is ON") add_definitions(-DENABLE_SM4_ECB) @@ -516,7 +525,10 @@ if (WIN32) elseif (APPLE) target_link_libraries(gmssl dl) target_link_libraries(gmssl "-framework Security") - #target_link_libraries(gmssl "-framework OpenCL") + if (ENABLE_SM4_CL) + # FIXME: different rules for cl and OpenCL framework + target_link_libraries(gmssl "-framework OpenCL") + endif() #target_link_libraries(gmssl "-framework CoreFoundation") # rand_apple.c CFRelease() elseif (MINGW) target_link_libraries(gmssl PRIVATE wsock32) diff --git a/include/gmssl/sm4_cl.h b/include/gmssl/sm4_cl.h index 5f1ca133..1394e2ef 100644 --- a/include/gmssl/sm4_cl.h +++ b/include/gmssl/sm4_cl.h @@ -10,34 +10,31 @@ #ifndef GMSSL_SM4_CL_H #define GMSSL_SM4_CL_H - -#ifdef __cplusplus -extern "C" { -#endif - #include #include #include #include #include - - -#ifdef APPLE +#ifdef MACOS #include #else #include #endif +#ifdef __cplusplus +extern "C" { +#endif + typedef struct { uint32_t rk[32]; + size_t workgroup_size; cl_context context; cl_command_queue queue; cl_program program; cl_kernel kernel; cl_mem mem_rk; cl_mem mem_io; - size_t workgroup_size; } SM4_CL_CTX; @@ -46,8 +43,6 @@ int sm4_cl_set_decrypt_key(SM4_CL_CTX *ctx, const uint8_t key[16]); int sm4_cl_encrypt(SM4_CL_CTX *ctx, const uint8_t *in, size_t nblocks, uint8_t *out); void sm4_cl_cleanup(SM4_CL_CTX *ctx); -int test_sm4_cl_encrypt(void); - #ifdef __cplusplus } diff --git a/src/sm4_cl.c b/src/sm4_cl.c index 5fbe0b3b..02ed565e 100644 --- a/src/sm4_cl.c +++ b/src/sm4_cl.c @@ -1,17 +1,20 @@ +/* + * Copyright 2014-2024 The GmSSL Project. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ + + #include #include #include #include -#include +#include #include -#define MACOS -#ifdef MACOS -#include -#else -#include -#endif - static char *clErrorString(cl_uint err) { @@ -68,16 +71,6 @@ static char *clErrorString(cl_uint err) static const char *sm4_cl_src; -typedef struct { - uint32_t rk[32]; - cl_context context; - cl_command_queue queue; - cl_program program; - cl_kernel kernel; - cl_mem mem_rk; - cl_mem mem_io; - size_t workgroup_size; -} SM4_CL_CTX; #define cl_error_print(e) \ do { fprintf(stderr, "%s: %d: %s()\n",__FILE__,__LINE__,clErrorString(e)); } while (0) @@ -91,6 +84,62 @@ void sm4_cl_cleanup(SM4_CL_CTX *ctx) clReleaseKernel(ctx->kernel); } +static void clPrintDeviceInfo(cl_device_id device) +{ + char deviceName[128]; + char vendorName[128]; + char deviceVersion[128]; + char driverVersion[128]; + char deviceProfile[128]; + cl_device_type deviceType; + cl_uint vendorID; + cl_ulong globalMemSize; + cl_ulong localMemSize; + size_t maxWorkGroupSize; + cl_uint maxWorkItemDimensions; + size_t maxWorkItemSizes[3]; + cl_uint maxComputeUnits; + char openclCVersion[128]; + char extensions[4096]; + + clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(deviceName), deviceName, NULL); + clGetDeviceInfo(device, CL_DEVICE_VENDOR, sizeof(vendorName), vendorName, NULL); + clGetDeviceInfo(device, CL_DEVICE_VERSION, sizeof(deviceVersion), deviceVersion, NULL); + clGetDeviceInfo(device, CL_DRIVER_VERSION, sizeof(driverVersion), driverVersion, NULL); + clGetDeviceInfo(device, CL_DEVICE_PROFILE, sizeof(deviceProfile), deviceProfile, NULL); + clGetDeviceInfo(device, CL_DEVICE_TYPE, sizeof(deviceType), &deviceType, NULL); + clGetDeviceInfo(device, CL_DEVICE_VENDOR_ID, sizeof(vendorID), &vendorID, NULL); + clGetDeviceInfo(device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(globalMemSize), &globalMemSize, NULL); + clGetDeviceInfo(device, CL_DEVICE_LOCAL_MEM_SIZE, sizeof(localMemSize), &localMemSize, NULL); + clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(maxWorkGroupSize), &maxWorkGroupSize, NULL); + clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_ITEM_DIMENSIONS, sizeof(maxWorkItemDimensions), &maxWorkItemDimensions, NULL); + clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(maxWorkItemSizes), maxWorkItemSizes, NULL); + clGetDeviceInfo(device, CL_DEVICE_MAX_COMPUTE_UNITS, sizeof(maxComputeUnits), &maxComputeUnits, NULL); + clGetDeviceInfo(device, CL_DEVICE_OPENCL_C_VERSION, sizeof(openclCVersion), openclCVersion, NULL); + clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, sizeof(extensions), extensions, NULL); + + printf("clGetDeviceInfo\n"); + printf(" Device Name: %s\n", deviceName); + printf(" Vendor: %s\n", vendorName); + printf(" Device Version: %s\n", deviceVersion); + printf(" Driver Version: %s\n", driverVersion); + printf(" Device Profile: %s\n", deviceProfile); + printf(" Device Type: %s\n", + (deviceType == CL_DEVICE_TYPE_CPU) ? "CPU" : + (deviceType == CL_DEVICE_TYPE_GPU) ? "GPU" : + (deviceType == CL_DEVICE_TYPE_ACCELERATOR) ? "Accelerator" : + (deviceType == CL_DEVICE_TYPE_DEFAULT) ? "Default" : "Unknown"); + printf(" Vendor ID: %u\n", vendorID); + printf(" Global Memory Size: %llu bytes\n", globalMemSize); + printf(" Local Memory Size: %llu bytes\n", localMemSize); + printf(" Max Work Group Size: %zu\n", maxWorkGroupSize); + printf(" Max Work Item Dimensions: %u\n", maxWorkItemDimensions); + printf(" Max Work Item Sizes: (%zu, %zu, %zu)\n", maxWorkItemSizes[0], maxWorkItemSizes[1], maxWorkItemSizes[2]); + printf(" Max Compute Units: %u\n", maxComputeUnits); + printf(" OpenCL C Version: %s\n", openclCVersion); + printf(" Extensions: %s\n", extensions); +} + static int sm4_cl_set_key(SM4_CL_CTX *ctx, const uint8_t key[16], int enc) { cl_platform_id platform; @@ -104,7 +153,6 @@ static int sm4_cl_set_key(SM4_CL_CTX *ctx, const uint8_t key[16], int enc) memset(ctx, 0, sizeof(*ctx)); - if ((err = clGetPlatformIDs(1, &platform, NULL)) != CL_SUCCESS) { cl_error_print(err); return -1; @@ -113,6 +161,8 @@ static int sm4_cl_set_key(SM4_CL_CTX *ctx, const uint8_t key[16], int enc) cl_error_print(err); return -1; } + //clPrintDeviceInfo(device); + if (!(ctx->context = clCreateContext(NULL, 1, &device, NULL, NULL, &err))) { cl_error_print(err); return -1; @@ -196,6 +246,8 @@ int sm4_cl_encrypt(SM4_CL_CTX *ctx, const uint8_t *in, size_t nblocks, uint8_t * cl_int err; size_t len = 16 * nblocks; cl_uint dim = 1; + size_t global_work_size = nblocks; + size_t local_work_size = 32; //ctx->workgroup_size; void *p; if (out != in) @@ -209,7 +261,10 @@ int sm4_cl_encrypt(SM4_CL_CTX *ctx, const uint8_t *in, size_t nblocks, uint8_t * cl_error_print(err); goto end; } - if ((err = clEnqueueNDRangeKernel(ctx->queue, ctx->kernel, dim, NULL, &nblocks, &ctx->workgroup_size, 0, NULL, NULL)) != CL_SUCCESS) { + // on Apple M2, CL_KERNEL_WORK_GROUP_SIZE = 256 + // but kernel will fail when local_work_size > 32. + if ((err = clEnqueueNDRangeKernel(ctx->queue, ctx->kernel, + dim, NULL, &global_work_size, &local_work_size, 0, NULL, NULL)) != CL_SUCCESS) { cl_error_print(err); goto end; } @@ -227,58 +282,6 @@ end: return ret; } -int test_sm4_cl_encrypt(void) -{ - const uint8_t key[16] = { - 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, - 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10, - }; - const uint8_t plaintext[16] = { - 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, - 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10, - }; - const uint8_t ciphertext[16] = { - 0x68, 0x1e, 0xdf, 0x34, 0xd2, 0x06, 0x96, 0x5e, - 0x86, 0xb3, 0xe9, 0x4f, 0x53, 0x6e, 0x42, 0x46, - }; - - int ret = -1; - SM4_CL_CTX ctx; - size_t nblocks = 1024; - uint8_t *buf = NULL; - size_t i; - - - if (!(buf = (uint8_t *)malloc(16 * nblocks))) { - error_print(); - return -1; - } - for (i = 0; i < nblocks; i++) { - memcpy(buf + 16 * i, plaintext, 16); - } - - if (sm4_cl_set_encrypt_key(&ctx, key) != 1) { - error_print(); - goto end; - } - if (sm4_cl_encrypt(&ctx, buf, nblocks, buf) != 1) { - error_print(); - goto end; - } - - for (i = 0; i < nblocks; i++) { - if (memcmp(buf + 16 * i, ciphertext, 16) != 0) { - error_print(); - goto end; - } - } - - ret = 1; -end: - if (buf) free(buf); - sm4_cl_cleanup(&ctx); - return ret; -} #define KERNEL(...) #__VA_ARGS__ @@ -303,7 +306,6 @@ __constant unsigned char SBOX[256] = { 0x18, 0xf0, 0x7d, 0xec, 0x3a, 0xdc, 0x4d, 0x20, 0x79, 0xee, 0x5f, 0x3e, 0xd7, 0xcb, 0x39, 0x48, }; - __kernel void sm4_encrypt(__global const unsigned int *rkey, __global unsigned char *data) { __local unsigned char S[256]; diff --git a/tests/sm4_cltest.c b/tests/sm4_cltest.c new file mode 100644 index 00000000..f35d9b27 --- /dev/null +++ b/tests/sm4_cltest.c @@ -0,0 +1,91 @@ +/* + * Copyright 2014-2024 The GmSSL Project. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ + + +#include +#include +#include +#include +#include +#include +#include + + + +int test_sm4_cl(void) +{ + const uint8_t key[16] = { + 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, + 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10, + }; + const uint8_t plaintext[16] = { + 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, + 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10, + }; + const uint8_t ciphertext[16] = { + 0x68, 0x1e, 0xdf, 0x34, 0xd2, 0x06, 0x96, 0x5e, + 0x86, 0xb3, 0xe9, 0x4f, 0x53, 0x6e, 0x42, 0x46, + }; + + int ret = -1; + SM4_CL_CTX ctx; + size_t nblocks = 1024; + uint8_t *buf = NULL; + size_t i; + + + if (!(buf = (uint8_t *)malloc(16 * nblocks))) { + error_print(); + return -1; + } + for (i = 0; i < nblocks; i++) { + memcpy(buf + 16 * i, plaintext, 16); + } + format_bytes(stderr, 0, 0, "in", buf, nblocks * 16); + + if (sm4_cl_set_encrypt_key(&ctx, key) != 1) { + error_print(); + goto end; + } + if (sm4_cl_encrypt(&ctx, buf, nblocks, buf) != 1) { + error_print(); + goto end; + } + + for (i = 0; i < nblocks; i++) { + //fprintf(stderr, "%zu ", i); + //format_bytes(stderr, 0, 0, "ciphertext", buf + 16*i, 16); + if (memcmp(buf + 16 * i, ciphertext, 16) != 0) { + error_print(); + goto end; + } + } + + ret = 1; +end: + if (buf) free(buf); + sm4_cl_cleanup(&ctx); + return ret; +} + +static int test_sm4_cl_ctr(void) +{ + return 1; +} + +int main(void) +{ + if (test_sm4_cl() != 1) goto err; + if (test_sm4_cl_ctr() != 1) goto err; + printf("%s all tests passed\n", __FILE__); + return 0; +err: + error_print(); + return 1; +}