diff --git a/src/sm4_ccm.c b/src/sm4_ccm.c index 6270809c..4dc69185 100644 --- a/src/sm4_ccm.c +++ b/src/sm4_ccm.c @@ -14,8 +14,6 @@ #include - - static void length_to_bytes(size_t len, size_t nbytes, uint8_t *out) { uint8_t *p = out + nbytes - 1; @@ -25,6 +23,32 @@ static void length_to_bytes(size_t len, size_t nbytes, uint8_t *out) } } +static void ctr_n_incr(uint8_t a[16], size_t n) +{ + size_t i; + for (i = 15; i >= 16 - n; i--) { + a[i]++; + if (a[i]) break; + } +} + +// TODO: add test vectors for counter overflow +static void sm4_ctr_n_encrypt(const SM4_KEY *key, uint8_t ctr[16], size_t n, const uint8_t *in, size_t inlen, uint8_t *out) +{ + uint8_t block[16]; + size_t len; + + while (inlen) { + len = inlen < 16 ? inlen : 16; + sm4_encrypt(key, ctr, block); + gmssl_memxor(out, in, block, len); + ctr_n_incr(ctr, n); + in += len; + out += len; + inlen -= len; + } +} + int sm4_ccm_encrypt(const SM4_KEY *sm4_key, const uint8_t *iv, size_t ivlen, const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, uint8_t *out, size_t taglen, uint8_t *tag) @@ -97,7 +121,7 @@ int sm4_ccm_encrypt(const SM4_KEY *sm4_key, const uint8_t *iv, size_t ivlen, sm4_encrypt(sm4_key, ctr, block); ctr[15] = 1; - sm4_ctr_encrypt(sm4_key, ctr, in, inlen, out); + sm4_ctr_n_encrypt(sm4_key, ctr, 15 - ivlen, in, inlen, out); sm4_cbc_mac_update(&mac_ctx, in, inlen); if (inlen % 16) { @@ -182,7 +206,7 @@ int sm4_ccm_decrypt(const SM4_KEY *sm4_key, const uint8_t *iv, size_t ivlen, sm4_encrypt(sm4_key, ctr, block); ctr[15] = 1; - sm4_ctr_encrypt(sm4_key, ctr, in, inlen, out); + sm4_ctr_n_encrypt(sm4_key, ctr, 15 - ivlen, in, inlen, out); sm4_cbc_mac_update(&mac_ctx, out, inlen); // diff from encrypt if (inlen % 16) {