diff --git a/src/sm4_arm64.c b/src/sm4_arm64.c index 64ecee4e..d280d0d4 100644 --- a/src/sm4_arm64.c +++ b/src/sm4_arm64.c @@ -9,7 +9,7 @@ #include -#include +#include #include @@ -63,19 +63,6 @@ const uint8_t S[256] = { 0x79, 0xee, 0x5f, 0x3e, 0xd7, 0xcb, 0x39, 0x48, }; -#define GETU32(ptr) \ - ((uint32_t)(ptr)[0] << 24 | \ - (uint32_t)(ptr)[1] << 16 | \ - (uint32_t)(ptr)[2] << 8 | \ - (uint32_t)(ptr)[3]) - -#define PUTU32(ptr,X) \ - ((ptr)[0] = (uint8_t)((X) >> 24), \ - (ptr)[1] = (uint8_t)((X) >> 16), \ - (ptr)[2] = (uint8_t)((X) >> 8), \ - (ptr)[3] = (uint8_t)(X)) - -#define ROL32(X,n) (((X)<<(n)) | ((X)>>(32-(n)))) #define L32(X) \ ((X) ^ \ @@ -251,8 +238,8 @@ void sm4_ctr_encrypt_blocks(const SM4_KEY *key, uint8_t ctr[16], const uint8_t * } } -#define vrolq_n_u32(words, N) \ - vorrq_u32(vshlq_n_u32((words), (N)), vshrq_n_u32((words), 32 - (N))) +#define vrolq_n_u32(words, nbits) \ + vorrq_u32(vshlq_n_u32((words), (nbits)), vshrq_n_u32((words), 32 - (nbits))) void sm4_ctr32_encrypt_4blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t *in, size_t n4blks, uint8_t *out) { @@ -267,17 +254,12 @@ void sm4_ctr32_encrypt_4blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t uint32_t n; uint32x4_t ctr; uint32x4_t ctr0, ctr1, ctr2, ctr3; - uint32x4_t vi; - uint32x4_t fours; + uint32x4_t vi = vld1q_u32(incr); + uint32x4_t fours = vdupq_n_u32(4); uint32x4_t x0, x1, x2, x3, x4; uint32x4_t rk, xt; uint32x4x2_t x02, x13, x01, x23; int i; - error_print(); - - - vi = vld1q_u32(incr); - fours = vdupq_n_u32(4); // compute low ctr32 n = GETU32(iv + 12); @@ -286,13 +268,11 @@ void sm4_ctr32_encrypt_4blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t memcpy(buf, iv, 16); ctr = vld1q_u32(buf); ctr = vrev32q_u8(ctr); - error_print(); ctr0 = vdupq_n_u32(vgetq_lane_u32(ctr, 0)); ctr1 = vdupq_n_u32(vgetq_lane_u32(ctr, 1)); ctr2 = vdupq_n_u32(vgetq_lane_u32(ctr, 2)); ctr3 = vdupq_n_u32(vgetq_lane_u32(ctr, 3)); - error_print(); ctr3 = vaddq_u32(ctr3, vi); @@ -303,10 +283,8 @@ void sm4_ctr32_encrypt_4blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t x2 = ctr2; x3 = ctr3; - error_print(); for (i = 0; i < 32; i++) { - // X4 = X1 ^ X2 ^ X3 ^ RK[i] rk = vdupq_n_u32(key->rk[i]); x4 = veorq_u32(veorq_u32(x1, x2), veorq_u32(x3, rk)); @@ -339,19 +317,15 @@ void sm4_ctr32_encrypt_4blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t x0 = vrev32q_u8(x01.val[0]); vst1q_u32(buf, x0); - error_print(); x1 = vrev32q_u8(x01.val[1]); vst1q_u32(buf + 4, x1); - error_print(); x2 = vrev32q_u8(x23.val[0]); vst1q_u32(buf + 8, x2); - error_print(); x3 = vrev32q_u8(x23.val[1]); vst1q_u32(buf + 12, x3); - error_print(); // xor with plaintext for (i = 0; i < 16*4; i++) { @@ -384,16 +358,13 @@ void sm4_ctr32_encrypt_blocks(const SM4_KEY *key, uint8_t ctr[16], const uint8_t if (nblocks >= 4) { sm4_ctr32_encrypt_4blocks(key, ctr, in, nblocks/4, out); - in += 64 * (nblocks/4); out += 64 * (nblocks/4); nblocks %= 4; } while (nblocks--) { - sm4_encrypt(key, ctr, block); - ctr32_incr(ctr); for (i = 0; i < 16; i++) { out[i] = in[i] ^ block[i];