Update SM4 OpenCL implementation

This commit is contained in:
Zhi Guan
2024-04-12 16:07:30 +08:00
parent 9fd4464980
commit 8e2c4ebd2f
4 changed files with 185 additions and 85 deletions

View File

@@ -327,6 +327,15 @@ if (ENABLE_SM4_CTR_AESNI_AVX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native")
endif()
# Optional OpenCL build of SM4: adds src/sm4_cl.c to the library sources
# and appends sm4_cl to the tests list.
if (ENABLE_SM4_CL)
message(STATUS "ENABLE_SM4_CL is ON")
if (CMAKE_SYSTEM_NAME STREQUAL "Darwin")
# The C sources test the MACOS macro to pick the OpenCL header.
add_definitions(-DMACOS) # to include <OpenCL/OpenCL.h>
endif()
list(APPEND src src/sm4_cl.c)
list(APPEND tests sm4_cl)
endif()
if (ENABLE_SM4_ECB)
message(STATUS "ENABLE_SM4_ECB is ON")
add_definitions(-DENABLE_SM4_ECB)
@@ -516,7 +525,10 @@ if (WIN32)
elseif (APPLE)
target_link_libraries(gmssl dl)
target_link_libraries(gmssl "-framework Security")
#target_link_libraries(gmssl "-framework OpenCL")
if (ENABLE_SM4_CL)
# FIXME: different rules for cl and OpenCL framework
target_link_libraries(gmssl "-framework OpenCL")
endif()
#target_link_libraries(gmssl "-framework CoreFoundation") # rand_apple.c CFRelease()
elseif (MINGW)
target_link_libraries(gmssl PRIVATE wsock32)

View File

@@ -10,34 +10,31 @@
#ifndef GMSSL_SM4_CL_H
#define GMSSL_SM4_CL_H
#ifdef __cplusplus
extern "C" {
#endif
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <stdint.h>
#include <gmssl/sm4.h>
#ifdef APPLE
#ifdef MACOS
#include <OpenCL/OpenCL.h>
#else
#include <CL/cl.h>
#endif
#ifdef __cplusplus
extern "C" {
#endif
/*
 * State of one SM4 OpenCL cipher instance, as used by the sm4_cl_*
 * functions declared below.
 *
 * Each member is declared exactly once: declaring `workgroup_size` a second
 * time is a constraint violation and the struct would not compile.
 */
typedef struct {
	uint32_t rk[32];         /* presumably the 32 SM4 round keys -- confirm in sm4_cl_set_key() */
	cl_context context;      /* OpenCL context (created with clCreateContext) */
	cl_command_queue queue;  /* command queue the kernel is enqueued on */
	cl_program program;      /* OpenCL program object */
	cl_kernel kernel;        /* kernel object; released by sm4_cl_cleanup() */
	cl_mem mem_rk;           /* OpenCL memory object; by its name, backs rk -- confirm */
	cl_mem mem_io;           /* OpenCL memory object; by its name, backs the in/out blocks -- confirm */
	size_t workgroup_size;   /* work-group size associated with the kernel */
} SM4_CL_CTX;
@@ -46,8 +43,6 @@ int sm4_cl_set_decrypt_key(SM4_CL_CTX *ctx, const uint8_t key[16]);
int sm4_cl_encrypt(SM4_CL_CTX *ctx, const uint8_t *in, size_t nblocks, uint8_t *out);
void sm4_cl_cleanup(SM4_CL_CTX *ctx);
int test_sm4_cl_encrypt(void);
#ifdef __cplusplus
}

View File

@@ -1,17 +1,20 @@
/*
* Copyright 2014-2024 The GmSSL Project. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <stdint.h>
#include <gmssl/sm4.h>
#include <gmssl/sm4_cl.h>
#include <gmssl/error.h>
#define MACOS
#ifdef MACOS
#include <OpenCL/OpenCL.h>
#else
#include <CL/cl.h>
#endif
static char *clErrorString(cl_uint err)
{
@@ -68,16 +71,6 @@ static char *clErrorString(cl_uint err)
static const char *sm4_cl_src;
/*
 * State of one SM4 OpenCL cipher instance.
 *
 * NOTE(review): a typedef with the same name also appears in
 * gmssl/sm4_cl.h, which this file includes -- confirm that only one
 * definition remains in the tree.
 */
typedef struct {
uint32_t rk[32]; /* presumably the 32 SM4 round keys -- confirm in sm4_cl_set_key() */
cl_context context; /* OpenCL context (created with clCreateContext) */
cl_command_queue queue; /* command queue the kernel is enqueued on */
cl_program program; /* OpenCL program object */
cl_kernel kernel; /* kernel object; released by sm4_cl_cleanup() */
cl_mem mem_rk; /* OpenCL memory object; by its name, backs rk -- confirm */
cl_mem mem_io; /* OpenCL memory object; by its name, backs the in/out blocks -- confirm */
size_t workgroup_size; /* work-group size associated with the kernel */
} SM4_CL_CTX;
/*
 * Report an OpenCL error code on stderr as "<file>: <line>: <text>()",
 * where <text> is the string returned by clErrorString(e).
 * Wrapped in do { } while (0) so it behaves as a single statement.
 */
#define cl_error_print(e) \
do { fprintf(stderr, "%s: %d: %s()\n",__FILE__,__LINE__,clErrorString(e)); } while (0)
@@ -91,6 +84,62 @@ void sm4_cl_cleanup(SM4_CL_CTX *ctx)
clReleaseKernel(ctx->kernel);
}
/*
 * Print a human-readable summary of an OpenCL device to stdout.
 *
 * Debugging aid: queries a fixed set of properties with clGetDeviceInfo()
 * and prints one line per property.
 *
 * The return value of clGetDeviceInfo() is not inspected. Instead, every
 * output variable is zero-initialized, so a query that fails (for example
 * because the value does not fit the buffer) is printed as an empty string
 * or 0 instead of reading uninitialized memory.
 *
 * device: the OpenCL device to describe.
 */
static void clPrintDeviceInfo(cl_device_id device)
{
	char deviceName[128] = "";
	char vendorName[128] = "";
	char deviceVersion[128] = "";
	char driverVersion[128] = "";
	char deviceProfile[128] = "";
	cl_device_type deviceType = 0;
	cl_uint vendorID = 0;
	cl_ulong globalMemSize = 0;
	cl_ulong localMemSize = 0;
	size_t maxWorkGroupSize = 0;
	cl_uint maxWorkItemDimensions = 0;
	size_t maxWorkItemSizes[3] = {0};
	cl_uint maxComputeUnits = 0;
	char openclCVersion[128] = "";
	char extensions[4096] = "";

	clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(deviceName), deviceName, NULL);
	clGetDeviceInfo(device, CL_DEVICE_VENDOR, sizeof(vendorName), vendorName, NULL);
	clGetDeviceInfo(device, CL_DEVICE_VERSION, sizeof(deviceVersion), deviceVersion, NULL);
	clGetDeviceInfo(device, CL_DRIVER_VERSION, sizeof(driverVersion), driverVersion, NULL);
	clGetDeviceInfo(device, CL_DEVICE_PROFILE, sizeof(deviceProfile), deviceProfile, NULL);
	clGetDeviceInfo(device, CL_DEVICE_TYPE, sizeof(deviceType), &deviceType, NULL);
	clGetDeviceInfo(device, CL_DEVICE_VENDOR_ID, sizeof(vendorID), &vendorID, NULL);
	clGetDeviceInfo(device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(globalMemSize), &globalMemSize, NULL);
	clGetDeviceInfo(device, CL_DEVICE_LOCAL_MEM_SIZE, sizeof(localMemSize), &localMemSize, NULL);
	clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(maxWorkGroupSize), &maxWorkGroupSize, NULL);
	clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_ITEM_DIMENSIONS, sizeof(maxWorkItemDimensions), &maxWorkItemDimensions, NULL);
	clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(maxWorkItemSizes), maxWorkItemSizes, NULL);
	clGetDeviceInfo(device, CL_DEVICE_MAX_COMPUTE_UNITS, sizeof(maxComputeUnits), &maxComputeUnits, NULL);
	clGetDeviceInfo(device, CL_DEVICE_OPENCL_C_VERSION, sizeof(openclCVersion), openclCVersion, NULL);
	clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, sizeof(extensions), extensions, NULL);

	printf("clGetDeviceInfo\n");
	printf(" Device Name: %s\n", deviceName);
	printf(" Vendor: %s\n", vendorName);
	printf(" Device Version: %s\n", deviceVersion);
	printf(" Driver Version: %s\n", driverVersion);
	printf(" Device Profile: %s\n", deviceProfile);
	printf(" Device Type: %s\n",
		(deviceType == CL_DEVICE_TYPE_CPU) ? "CPU" :
		(deviceType == CL_DEVICE_TYPE_GPU) ? "GPU" :
		(deviceType == CL_DEVICE_TYPE_ACCELERATOR) ? "Accelerator" :
		(deviceType == CL_DEVICE_TYPE_DEFAULT) ? "Default" : "Unknown");
	printf(" Vendor ID: %u\n", vendorID);
	/* cl_ulong is not `unsigned long long` on every platform; cast so the
	 * argument type always matches the %llu conversion. */
	printf(" Global Memory Size: %llu bytes\n", (unsigned long long)globalMemSize);
	printf(" Local Memory Size: %llu bytes\n", (unsigned long long)localMemSize);
	printf(" Max Work Group Size: %zu\n", maxWorkGroupSize);
	printf(" Max Work Item Dimensions: %u\n", maxWorkItemDimensions);
	printf(" Max Work Item Sizes: (%zu, %zu, %zu)\n", maxWorkItemSizes[0], maxWorkItemSizes[1], maxWorkItemSizes[2]);
	printf(" Max Compute Units: %u\n", maxComputeUnits);
	printf(" OpenCL C Version: %s\n", openclCVersion);
	printf(" Extensions: %s\n", extensions);
}
static int sm4_cl_set_key(SM4_CL_CTX *ctx, const uint8_t key[16], int enc)
{
cl_platform_id platform;
@@ -104,7 +153,6 @@ static int sm4_cl_set_key(SM4_CL_CTX *ctx, const uint8_t key[16], int enc)
memset(ctx, 0, sizeof(*ctx));
if ((err = clGetPlatformIDs(1, &platform, NULL)) != CL_SUCCESS) {
cl_error_print(err);
return -1;
@@ -113,6 +161,8 @@ static int sm4_cl_set_key(SM4_CL_CTX *ctx, const uint8_t key[16], int enc)
cl_error_print(err);
return -1;
}
//clPrintDeviceInfo(device);
if (!(ctx->context = clCreateContext(NULL, 1, &device, NULL, NULL, &err))) {
cl_error_print(err);
return -1;
@@ -196,6 +246,8 @@ int sm4_cl_encrypt(SM4_CL_CTX *ctx, const uint8_t *in, size_t nblocks, uint8_t *
cl_int err;
size_t len = 16 * nblocks;
cl_uint dim = 1;
size_t global_work_size = nblocks;
size_t local_work_size = 32; //ctx->workgroup_size;
void *p;
if (out != in)
@@ -209,7 +261,10 @@ int sm4_cl_encrypt(SM4_CL_CTX *ctx, const uint8_t *in, size_t nblocks, uint8_t *
cl_error_print(err);
goto end;
}
if ((err = clEnqueueNDRangeKernel(ctx->queue, ctx->kernel, dim, NULL, &nblocks, &ctx->workgroup_size, 0, NULL, NULL)) != CL_SUCCESS) {
// on Apple M2, CL_KERNEL_WORK_GROUP_SIZE = 256
// but kernel will fail when local_work_size > 32.
if ((err = clEnqueueNDRangeKernel(ctx->queue, ctx->kernel,
dim, NULL, &global_work_size, &local_work_size, 0, NULL, NULL)) != CL_SUCCESS) {
cl_error_print(err);
goto end;
}
@@ -227,58 +282,6 @@ end:
return ret;
}
/*
 * Self-test for sm4_cl_encrypt(): fills a heap buffer with 1024 copies of
 * one fixed 16-byte block, encrypts the buffer in place and compares every
 * output block with the expected ciphertext.
 *
 * Returns 1 when all blocks match, -1 on allocation failure, key setup
 * failure, encryption failure or a mismatch.
 */
int test_sm4_cl_encrypt(void)
{
	const uint8_t key[16] = {
		0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
		0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10,
	};
	const uint8_t plaintext[16] = {
		0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
		0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10,
	};
	const uint8_t ciphertext[16] = {
		0x68, 0x1e, 0xdf, 0x34, 0xd2, 0x06, 0x96, 0x5e,
		0x86, 0xb3, 0xe9, 0x4f, 0x53, 0x6e, 0x42, 0x46,
	};
	const size_t nblocks = 1024;
	SM4_CL_CTX ctx;
	uint8_t *buf;
	uint8_t *block;
	uint8_t *buf_end;
	int ret = -1;

	buf = malloc(16 * nblocks);
	if (buf == NULL) {
		error_print();
		return -1;
	}
	buf_end = buf + 16 * nblocks;

	/* every block of the input is the same known plaintext */
	for (block = buf; block != buf_end; block += 16) {
		memcpy(block, plaintext, 16);
	}

	if (sm4_cl_set_encrypt_key(&ctx, key) != 1) {
		error_print();
		goto end;
	}
	/* in-place encryption: input and output are the same buffer */
	if (sm4_cl_encrypt(&ctx, buf, nblocks, buf) != 1) {
		error_print();
		goto end;
	}

	/* so every block of the output must be the same known ciphertext */
	for (block = buf; block != buf_end; block += 16) {
		if (memcmp(block, ciphertext, 16) != 0) {
			error_print();
			goto end;
		}
	}
	ret = 1;

end:
	free(buf);
	sm4_cl_cleanup(&ctx);
	return ret;
}
#define KERNEL(...) #__VA_ARGS__
@@ -303,7 +306,6 @@ __constant unsigned char SBOX[256] = {
0x18, 0xf0, 0x7d, 0xec, 0x3a, 0xdc, 0x4d, 0x20, 0x79, 0xee, 0x5f, 0x3e, 0xd7, 0xcb, 0x39, 0x48,
};
__kernel void sm4_encrypt(__global const unsigned int *rkey, __global unsigned char *data)
{
__local unsigned char S[256];

91
tests/sm4_cltest.c Normal file
View File

@@ -0,0 +1,91 @@
/*
* Copyright 2014-2024 The GmSSL Project. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <gmssl/sm4_cl.h>
#include <gmssl/hex.h>
#include <gmssl/rand.h>
#include <gmssl/error.h>
/*
 * Test sm4_cl_encrypt(): fills a heap buffer with 1024 copies of one fixed
 * 16-byte block, dumps the input to stderr, encrypts the buffer in place
 * and compares every output block with the expected ciphertext.
 *
 * Returns 1 when all blocks match, -1 on allocation failure, key setup
 * failure, encryption failure or a mismatch.
 */
int test_sm4_cl(void)
{
	const uint8_t key[16] = {
		0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
		0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10,
	};
	const uint8_t plaintext[16] = {
		0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
		0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10,
	};
	const uint8_t ciphertext[16] = {
		0x68, 0x1e, 0xdf, 0x34, 0xd2, 0x06, 0x96, 0x5e,
		0x86, 0xb3, 0xe9, 0x4f, 0x53, 0x6e, 0x42, 0x46,
	};
	const size_t nblocks = 1024;
	SM4_CL_CTX ctx;
	uint8_t *buf;
	uint8_t *block;
	uint8_t *buf_end;
	int ret = -1;

	buf = malloc(16 * nblocks);
	if (buf == NULL) {
		error_print();
		return -1;
	}
	buf_end = buf + 16 * nblocks;

	/* every block of the input is the same known plaintext */
	for (block = buf; block != buf_end; block += 16) {
		memcpy(block, plaintext, 16);
	}
	format_bytes(stderr, 0, 0, "in", buf, nblocks * 16);

	if (sm4_cl_set_encrypt_key(&ctx, key) != 1) {
		error_print();
		goto end;
	}
	/* in-place encryption: input and output are the same buffer */
	if (sm4_cl_encrypt(&ctx, buf, nblocks, buf) != 1) {
		error_print();
		goto end;
	}

	/* so every block of the output must be the same known ciphertext */
	for (block = buf; block != buf_end; block += 16) {
		/* debugging aid: dump the block under test
		 * format_bytes(stderr, 0, 0, "ciphertext", block, 16);
		 */
		if (memcmp(block, ciphertext, 16) != 0) {
			error_print();
			goto end;
		}
	}
	ret = 1;

end:
	free(buf);
	sm4_cl_cleanup(&ctx);
	return ret;
}
/*
 * Stub test: performs no checks and always reports success (1).
 * Presumably reserved for a future SM4-CTR OpenCL test -- confirm.
 */
static int test_sm4_cl_ctr(void)
{
	const int passed = 1;

	return passed;
}
/*
 * Test driver: runs test_sm4_cl() then test_sm4_cl_ctr(), stopping at the
 * first one that does not return 1.
 *
 * Exit status is 0 when both pass, 1 otherwise.
 */
int main(void)
{
	/* && short-circuits, so the second test only runs if the first passed */
	if (!(test_sm4_cl() == 1 && test_sm4_cl_ctr() == 1)) {
		error_print();
		return 1;
	}

	printf("%s all tests passed\n", __FILE__);
	return 0;
}