Update sm4_arm64.c

This commit is contained in:
Zhi Guan
2024-05-12 23:25:39 +08:00
parent 7a94496355
commit f5fb0a5ae9

View File

@@ -9,7 +9,7 @@
#include <gmssl/sm4.h>
#include <gmssl/error.h>
#include <gmssl/endian.h>
#include <arm_neon.h>
@@ -63,19 +63,6 @@ const uint8_t S[256] = {
0x79, 0xee, 0x5f, 0x3e, 0xd7, 0xcb, 0x39, 0x48,
};
/*
 * GETU32(ptr): assemble a 32-bit word from the four bytes at ptr[0..3],
 * with ptr[0] as the most significant byte (big-endian load).
 * ptr is evaluated four times, so pass an expression without side effects.
 * NOTE(review): each byte is cast straight to uint32_t, so this presumably
 * expects unsigned bytes (uint8_t) -- confirm against callers.
 */
#define GETU32(ptr) \
	(((uint32_t)(ptr)[3]) | \
	 ((uint32_t)(ptr)[2] <<  8) | \
	 ((uint32_t)(ptr)[1] << 16) | \
	 ((uint32_t)(ptr)[0] << 24))
/*
 * PUTU32(ptr, X): write X to ptr[0..3] as four bytes, most significant
 * byte first (big-endian store).
 * The macro is a single comma expression whose value is the last byte
 * stored (the low byte of X). Both ptr and X are evaluated four times,
 * so neither argument should carry side effects.
 */
#define PUTU32(ptr,X) \
	((ptr)[0] = (uint8_t)(((X) >> 24) & 0xff), \
	 (ptr)[1] = (uint8_t)(((X) >> 16) & 0xff), \
	 (ptr)[2] = (uint8_t)(((X) >>  8) & 0xff), \
	 (ptr)[3] = (uint8_t)((X) & 0xff))
/*
 * ROL32(X, n): rotate the 32-bit value X left by n bits.
 * Both shift counts are reduced modulo 32, so n == 0 yields X unchanged
 * instead of shifting right by 32, which is undefined behaviour for a
 * 32-bit operand. For n in 1..31 the result is the plain
 * (X << n) | (X >> (32 - n)).
 * Intended for uint32_t operands; X and n are each evaluated twice.
 */
#define ROL32(X,n) (((X) << ((n) & 31)) | ((X) >> ((32 - (n)) & 31)))
#define L32(X) \
((X) ^ \
@@ -251,8 +238,8 @@ void sm4_ctr_encrypt_blocks(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *
}
}
/*
 * vrolq_n_u32(vec, shift): rotate every 32-bit lane of the NEON vector
 * `vec` left by `shift` bits, built from a lane-wise left shift OR-ed with
 * the complementary lane-wise right shift.
 * `shift` is handed to the immediate-form shift intrinsics, so it must be
 * a compile-time constant they accept. `vec` is evaluated twice.
 */
#define vrolq_n_u32(vec, shift) \
	vorrq_u32(vshrq_n_u32((vec), 32 - (shift)), vshlq_n_u32((vec), (shift)))
/*
 * vrolq_n_u32(vec, shift): rotate every 32-bit lane of the NEON vector
 * `vec` left by `shift` bits, built from a lane-wise left shift OR-ed with
 * the complementary lane-wise right shift.
 * `shift` is handed to the immediate-form shift intrinsics, so it must be
 * a compile-time constant they accept. `vec` is evaluated twice.
 */
#define vrolq_n_u32(vec, shift) \
	vorrq_u32(vshrq_n_u32((vec), 32 - (shift)), vshlq_n_u32((vec), (shift)))
void sm4_ctr32_encrypt_4blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t *in, size_t n4blks, uint8_t *out)
{
@@ -267,17 +254,12 @@ void sm4_ctr32_encrypt_4blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t
uint32_t n;
uint32x4_t ctr;
uint32x4_t ctr0, ctr1, ctr2, ctr3;
uint32x4_t vi;
uint32x4_t fours;
uint32x4_t vi = vld1q_u32(incr);
uint32x4_t fours = vdupq_n_u32(4);
uint32x4_t x0, x1, x2, x3, x4;
uint32x4_t rk, xt;
uint32x4x2_t x02, x13, x01, x23;
int i;
error_print();
vi = vld1q_u32(incr);
fours = vdupq_n_u32(4);
// compute low ctr32
n = GETU32(iv + 12);
@@ -286,13 +268,11 @@ void sm4_ctr32_encrypt_4blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t
memcpy(buf, iv, 16);
ctr = vld1q_u32(buf);
ctr = vrev32q_u8(ctr);
error_print();
ctr0 = vdupq_n_u32(vgetq_lane_u32(ctr, 0));
ctr1 = vdupq_n_u32(vgetq_lane_u32(ctr, 1));
ctr2 = vdupq_n_u32(vgetq_lane_u32(ctr, 2));
ctr3 = vdupq_n_u32(vgetq_lane_u32(ctr, 3));
error_print();
ctr3 = vaddq_u32(ctr3, vi);
@@ -303,10 +283,8 @@ void sm4_ctr32_encrypt_4blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t
x2 = ctr2;
x3 = ctr3;
error_print();
for (i = 0; i < 32; i++) {
// X4 = X1 ^ X2 ^ X3 ^ RK[i]
rk = vdupq_n_u32(key->rk[i]);
x4 = veorq_u32(veorq_u32(x1, x2), veorq_u32(x3, rk));
@@ -339,19 +317,15 @@ void sm4_ctr32_encrypt_4blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t
x0 = vrev32q_u8(x01.val[0]);
vst1q_u32(buf, x0);
error_print();
x1 = vrev32q_u8(x01.val[1]);
vst1q_u32(buf + 4, x1);
error_print();
x2 = vrev32q_u8(x23.val[0]);
vst1q_u32(buf + 8, x2);
error_print();
x3 = vrev32q_u8(x23.val[1]);
vst1q_u32(buf + 12, x3);
error_print();
// xor with plaintext
for (i = 0; i < 16*4; i++) {
@@ -384,16 +358,13 @@ void sm4_ctr32_encrypt_blocks(const SM4_KEY *key, uint8_t ctr[16], const uint8_t
if (nblocks >= 4) {
sm4_ctr32_encrypt_4blocks(key, ctr, in, nblocks/4, out);
in += 64 * (nblocks/4);
out += 64 * (nblocks/4);
nblocks %= 4;
}
while (nblocks--) {
sm4_encrypt(key, ctr, block);
ctr32_incr(ctr);
for (i = 0; i < 16; i++) {
out[i] = in[i] ^ block[i];