diff --git a/src/sm4.c b/src/sm4.c index 472b7807..32c64b21 100644 --- a/src/sm4.c +++ b/src/sm4.c @@ -171,7 +171,7 @@ void sm4_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, u void sm4_cbc_encrypt_blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) { - uint8_t *piv = iv; + const uint8_t *piv = iv; while (nblocks--) { size_t i; @@ -190,7 +190,7 @@ void sm4_cbc_encrypt_blocks(const SM4_KEY *key, uint8_t iv[16], void sm4_cbc_decrypt_blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) { - uint8_t *piv = iv; + const uint8_t *piv = iv; while (nblocks--) { size_t i; diff --git a/src/sm4_arm64.c b/src/sm4_arm64.c index d280d0d4..8e45b7b5 100644 --- a/src/sm4_arm64.c +++ b/src/sm4_arm64.c @@ -184,34 +184,42 @@ void sm4_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, u } } -void sm4_cbc_encrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], +void sm4_cbc_encrypt_blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) { + const uint8_t *piv = iv; + while (nblocks--) { size_t i; for (i = 0; i < 16; i++) { - out[i] = in[i] ^ iv[i]; + out[i] = in[i] ^ piv[i]; } sm4_encrypt(key, out, out); - iv = out; + piv = out; in += 16; out += 16; } + + memcpy(iv, piv, 16); } -void sm4_cbc_decrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], +void sm4_cbc_decrypt_blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) { + const uint8_t *piv = iv; + while (nblocks--) { size_t i; sm4_encrypt(key, in, out); for (i = 0; i < 16; i++) { - out[i] ^= iv[i]; + out[i] ^= piv[i]; } - iv = in; + piv = in; in += 16; out += 16; } + + memcpy(iv, piv, 16); } static void ctr_incr(uint8_t a[16]) { diff --git a/tests/sm4_cbc_mactest.c b/tests/sm4_cbc_mactest.c index ccba24e4..e9e258cd 100644 --- a/tests/sm4_cbc_mactest.c +++ b/tests/sm4_cbc_mactest.c @@ -21,7 +21,8 @@ static int test_sm4_cbc_mac(void) SM4_KEY sm4_key; SM4_CBC_MAC_CTX ctx; uint8_t key[16]; - uint8_t iv[16] = {0}; + const uint8_t civ[16] = {0}; + uint8_t iv[16]; uint8_t m[128]; uint8_t c[128]; uint8_t mac1[16]; @@ -34,6 +35,7 @@ static int test_sm4_cbc_mac(void) sm4_set_encrypt_key(&sm4_key, key); // test 1 + memcpy(iv, civ, 16); sm4_cbc_encrypt_blocks(&sm4_key, iv, m, sizeof(m)/16, c); memcpy(mac1, c + sizeof(m) - 16, 16); @@ -56,6 +58,7 @@ static int test_sm4_cbc_mac(void) // test 2 m[sizeof(m) - 1] = 0; + memcpy(iv, civ, 16); sm4_cbc_encrypt_blocks(&sm4_key, iv, m, sizeof(m)/16, c); memcpy(mac1, c + sizeof(m) - 16, 16);