From b2e334bfb9f4a670935ee1a807feaaec26e33546 Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Sun, 21 Apr 2024 23:58:01 +0800 Subject: [PATCH] FIX GCM counter bug, change inc128 to inc32 --- src/aes_modes.c | 34 +++++++++++++-- src/sm4_gcm.c | 110 +++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 125 insertions(+), 19 deletions(-) diff --git a/src/aes_modes.c b/src/aes_modes.c index 7a521837..ce694e44 100644 --- a/src/aes_modes.c +++ b/src/aes_modes.c @@ -121,6 +121,32 @@ void aes_ctr_encrypt(const AES_KEY *key, uint8_t ctr[16], const uint8_t *in, siz } } + +static void ctr32_incr(uint8_t a[16]) +{ + int i; + for (i = 15; i >= 12; i--) { + a[i]++; + if (a[i]) break; + } +} + +static void aes_ctr32_encrypt(const AES_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out) +{ + uint8_t block[16]; + size_t len; + + while (inlen) { + len = inlen < 16 ? inlen : 16; + aes_encrypt(key, ctr, block); + gmssl_memxor(out, in, block, len); + ctr32_incr(ctr); + in += len; + out += len; + inlen -= len; + } +} + int aes_gcm_encrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen, const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, uint8_t *out, size_t taglen, uint8_t *tag) @@ -149,8 +175,8 @@ int aes_gcm_encrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen, aes_encrypt(key, Y, T); - ctr_incr(Y); - aes_ctr_encrypt(key, Y, in, inlen, out); + ctr32_incr(Y); + aes_ctr32_encrypt(key, Y, in, inlen, out); ghash(H, aad, aadlen, out, inlen, H); gmssl_memxor(tag, T, H, taglen); @@ -186,8 +212,8 @@ int aes_gcm_decrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen, return -1; } - ctr_incr(Y); - aes_ctr_encrypt(key, Y, in, inlen, out); + ctr32_incr(Y); + aes_ctr32_encrypt(key, Y, in, inlen, out); return 1; } diff --git a/src/sm4_gcm.c b/src/sm4_gcm.c index 6cdd40c1..92c22ee3 100644 --- a/src/sm4_gcm.c +++ b/src/sm4_gcm.c @@ -12,17 +12,35 @@ #include #include #include +#include -static void ctr_incr(uint8_t a[16]) +// inc32() in nist-sp800-38d +static void ctr32_incr(uint8_t a[16]) { int i; - for (i = 15; i >= 0; i--) { + for (i = 15; i >= 12; i--) { a[i]++; if (a[i]) break; } } +static void sm4_ctr32_encrypt(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out) +{ + uint8_t block[16]; + size_t len; + + while (inlen) { + len = inlen < 16 ? inlen : 16; + sm4_encrypt(key, ctr, block); + gmssl_memxor(out, in, block, len); + ctr32_incr(ctr); + in += len; + out += len; + inlen -= len; + } +} + int sm4_gcm_encrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen, const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, uint8_t *out, size_t taglen, uint8_t *tag) @@ -48,8 +66,8 @@ int sm4_gcm_encrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen, sm4_encrypt(key, Y, T); - ctr_incr(Y); - sm4_ctr_encrypt(key, Y, in, inlen, out); + ctr32_incr(Y); + sm4_ctr32_encrypt(key, Y, in, inlen, out); ghash(H, aad, aadlen, out, inlen, H); gmssl_memxor(tag, T, H, taglen); @@ -84,12 +102,75 @@ int sm4_gcm_decrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen, return -1; } - ctr_incr(Y); - sm4_ctr_encrypt(key, Y, in, inlen, out); + ctr32_incr(Y); + sm4_ctr32_encrypt(key, Y, in, inlen, out); return 1; } +static int sm4_ctr32_encrypt_init(SM4_CTR_CTX *ctx, + const uint8_t key[SM4_BLOCK_SIZE], const uint8_t ctr[SM4_BLOCK_SIZE]) +{ + sm4_set_encrypt_key(&ctx->sm4_key, key); + memcpy(ctx->ctr, ctr, SM4_BLOCK_SIZE); + memset(ctx->block, 0, SM4_BLOCK_SIZE); + ctx->block_nbytes = 0; + return 1; +} + +static int sm4_ctr32_encrypt_update(SM4_CTR_CTX *ctx, + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +{ + size_t left; + size_t nblocks; + size_t len; + + if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { + error_print(); + return -1; + } + *outlen = 0; + if (ctx->block_nbytes) { + left = SM4_BLOCK_SIZE - ctx->block_nbytes; + if (inlen < left) { + memcpy(ctx->block + ctx->block_nbytes, in, inlen); + ctx->block_nbytes += inlen; + return 1; + } + memcpy(ctx->block + ctx->block_nbytes, in, left); + sm4_ctr32_encrypt(&ctx->sm4_key, ctx->ctr, ctx->block, SM4_BLOCK_SIZE, out); + in += left; + inlen -= left; + out += SM4_BLOCK_SIZE; + *outlen += SM4_BLOCK_SIZE; + } + if (inlen >= SM4_BLOCK_SIZE) { + nblocks = inlen / SM4_BLOCK_SIZE; + len = nblocks * SM4_BLOCK_SIZE; + sm4_ctr32_encrypt(&ctx->sm4_key, ctx->ctr, in, len, out); + in += len; + inlen -= len; + out += len; + *outlen += len; + } + if (inlen) { + memcpy(ctx->block, in, inlen); + } + ctx->block_nbytes = inlen; + return 1; +} + +static int sm4_ctr32_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen) +{ + if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { + error_print(); + return -1; + } + sm4_ctr32_encrypt(&ctx->sm4_key, ctx->ctr, ctx->block, ctx->block_nbytes, out); + *outlen = ctx->block_nbytes; + return 1; +} + int sm4_gcm_encrypt_init(SM4_GCM_CTX *ctx, const uint8_t *key, size_t keylen, const uint8_t *iv, size_t ivlen, const uint8_t *aad, size_t aadlen, size_t taglen) @@ -117,7 +198,7 @@ int sm4_gcm_encrypt_init(SM4_GCM_CTX *ctx, memset(ctx, 0, sizeof(*ctx)); ctx->taglen = taglen; - if (sm4_ctr_encrypt_init(&ctx->enc_ctx, key, H) != 1) { + if (sm4_ctr32_encrypt_init(&ctx->enc_ctx, key, H) != 1) { error_print(); return -1; } @@ -136,7 +217,7 @@ int sm4_gcm_encrypt_init(SM4_GCM_CTX *ctx, sm4_encrypt(&ctx->enc_ctx.sm4_key, Y, ctx->Y); - ctr_incr(Y); + ctr32_incr(Y); memcpy(ctx->enc_ctx.ctr, Y, 16); gmssl_secure_clear(H, sizeof(H)); @@ -150,7 +231,7 @@ int sm4_gcm_encrypt_update(SM4_GCM_CTX *ctx, const uint8_t *in, size_t inlen, ui error_print(); return -1; } - if (sm4_ctr_encrypt_update(&ctx->enc_ctx, in, inlen, out, outlen) != 1) { + if (sm4_ctr32_encrypt_update(&ctx->enc_ctx, in, inlen, out, outlen) != 1) { error_print(); return -1; } @@ -166,7 +247,7 @@ int sm4_gcm_encrypt_finish(SM4_GCM_CTX *ctx, uint8_t *out, size_t *outlen) error_print(); return -1; } - if (sm4_ctr_encrypt_finish(&ctx->enc_ctx, out, outlen) != 1) { + if (sm4_ctr32_encrypt_finish(&ctx->enc_ctx, out, outlen) != 1) { error_print(); return -1; } @@ -217,7 +298,7 @@ int sm4_gcm_decrypt_update(SM4_GCM_CTX *ctx, const uint8_t *in, size_t inlen, ui if (inlen <= ctx->taglen) { uint8_t tmp[GHASH_SIZE]; ghash_update(&ctx->mac_ctx, ctx->mac, inlen); - if (sm4_ctr_encrypt_update(&ctx->enc_ctx, ctx->mac, inlen, out, outlen) != 1) { + if (sm4_ctr32_encrypt_update(&ctx->enc_ctx, ctx->mac, inlen, out, outlen) != 1) { error_print(); return -1; } @@ -227,7 +308,7 @@ int sm4_gcm_decrypt_update(SM4_GCM_CTX *ctx, const uint8_t *in, size_t inlen, ui memcpy(ctx->mac, tmp, GHASH_SIZE); } else { ghash_update(&ctx->mac_ctx, ctx->mac, ctx->taglen); - if (sm4_ctr_encrypt_update(&ctx->enc_ctx, ctx->mac, ctx->taglen, out, outlen) != 1) { + if (sm4_ctr32_encrypt_update(&ctx->enc_ctx, ctx->mac, ctx->taglen, out, outlen) != 1) { error_print(); return -1; } @@ -235,7 +316,7 @@ int sm4_gcm_decrypt_update(SM4_GCM_CTX *ctx, const uint8_t *in, size_t inlen, ui inlen -= ctx->taglen; ghash_update(&ctx->mac_ctx, in, inlen); - if (sm4_ctr_encrypt_update(&ctx->enc_ctx, in, inlen, out, &len) != 1) { + if (sm4_ctr32_encrypt_update(&ctx->enc_ctx, in, inlen, out, &len) != 1) { error_print(); return -1; } @@ -258,7 +339,7 @@ int sm4_gcm_decrypt_finish(SM4_GCM_CTX *ctx, uint8_t *out, size_t *outlen) return -1; } ghash_finish(&ctx->mac_ctx, mac); - if (sm4_ctr_encrypt_finish(&ctx->enc_ctx, out, outlen) != 1) { + if (sm4_ctr32_encrypt_finish(&ctx->enc_ctx, out, outlen) != 1) { error_print(); return -1; } @@ -272,4 +353,3 @@ int sm4_gcm_decrypt_finish(SM4_GCM_CTX *ctx, uint8_t *out, size_t *outlen) ctx->maclen = 0; return 1; } -