Let sm4_cbc_encrypt_blocks update iv

If in == out, then after encryptions the input (i.e. iv) is changed
This commit is contained in:
Zhi Guan
2024-05-13 21:44:06 +08:00
parent 3b6c2a3e9b
commit 7f3072e917
5 changed files with 49 additions and 22 deletions

View File

@@ -35,9 +35,9 @@ void sm4_set_decrypt_key(SM4_KEY *key, const uint8_t raw_key[SM4_KEY_SIZE]);
void sm4_encrypt(const SM4_KEY *key, const uint8_t in[SM4_BLOCK_SIZE], uint8_t out[SM4_BLOCK_SIZE]);
void sm4_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, uint8_t *out);
void sm4_cbc_encrypt_blocks(const SM4_KEY *key, const uint8_t iv[SM4_BLOCK_SIZE],
void sm4_cbc_encrypt_blocks(const SM4_KEY *key, uint8_t iv[SM4_BLOCK_SIZE],
const uint8_t *in, size_t nblocks, uint8_t *out);
void sm4_cbc_decrypt_blocks(const SM4_KEY *key, const uint8_t iv[SM4_BLOCK_SIZE],
void sm4_cbc_decrypt_blocks(const SM4_KEY *key, uint8_t iv[SM4_BLOCK_SIZE],
const uint8_t *in, size_t nblocks, uint8_t *out);
void sm4_ctr_encrypt_blocks(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t nblocks, uint8_t *out);
void sm4_ctr32_encrypt_blocks(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t nblocks, uint8_t *out);

View File

@@ -168,34 +168,42 @@ void sm4_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, u
}
}
void sm4_cbc_encrypt_blocks(const SM4_KEY *key, const uint8_t iv[16],
void sm4_cbc_encrypt_blocks(const SM4_KEY *key, uint8_t iv[16],
const uint8_t *in, size_t nblocks, uint8_t *out)
{
uint8_t *piv = iv;
while (nblocks--) {
size_t i;
for (i = 0; i < 16; i++) {
out[i] = in[i] ^ iv[i];
out[i] = in[i] ^ piv[i];
}
sm4_encrypt(key, out, out);
iv = out;
piv = out;
in += 16;
out += 16;
}
memcpy(iv, piv, 16);
}
void sm4_cbc_decrypt_blocks(const SM4_KEY *key, const uint8_t iv[16],
void sm4_cbc_decrypt_blocks(const SM4_KEY *key, uint8_t iv[16],
const uint8_t *in, size_t nblocks, uint8_t *out)
{
uint8_t *piv = iv;
while (nblocks--) {
size_t i;
sm4_encrypt(key, in, out);
for (i = 0; i < 16; i++) {
out[i] ^= iv[i];
out[i] ^= piv[i];
}
iv = in;
piv = in;
in += 16;
out += 16;
}
memcpy(iv, piv, 16);
}
static void ctr_incr(uint8_t a[16]) {
@@ -630,7 +638,7 @@ void sm4_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, u
}
}
void sm4_cbc_encrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out)
void sm4_cbc_encrypt_blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out)
{
const uint32_t *rk = key->rk;
uint32_t X0, X1, X2, X3, X4;
@@ -690,9 +698,14 @@ void sm4_cbc_encrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], const uint
in += 16;
out += 16;
}
PUTU32(iv , X0);
PUTU32(iv + 4, X4);
PUTU32(iv + 8, X3);
PUTU32(iv + 12, X2);
}
void sm4_cbc_decrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out)
void sm4_cbc_decrypt_blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out)
{
const uint32_t *rk = key->rk;
uint32_t IV0, IV1, IV2, IV3;
@@ -756,6 +769,11 @@ void sm4_cbc_decrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], const uint
in += 16;
out += 16;
}
PUTU32(iv , IV0);
PUTU32(iv + 4, IV1);
PUTU32(iv + 8, IV2);
PUTU32(iv + 12, IV3);
}
void sm4_ctr_encrypt_blocks(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t nblocks, uint8_t *out)

View File

@@ -13,36 +13,42 @@
#include <gmssl/error.h>
int sm4_cbc_padding_encrypt(const SM4_KEY *key, const uint8_t iv[16],
int sm4_cbc_padding_encrypt(const SM4_KEY *key, const uint8_t piv[16],
const uint8_t *in, size_t inlen,
uint8_t *out, size_t *outlen)
{
uint8_t iv[16];
uint8_t block[16];
size_t rem = inlen % 16;
int padding = 16 - inlen % 16;
memcpy(iv, piv, 16);
if (in) {
memcpy(block, in + inlen - rem, rem);
}
memset(block + rem, padding, padding);
if (inlen/16) {
sm4_cbc_encrypt_blocks(key, iv, in, inlen/16, out);
out += inlen - rem;
iv = out - 16;
}
sm4_cbc_encrypt_blocks(key, iv, block, 1, out);
*outlen = inlen - rem + 16;
return 1;
}
int sm4_cbc_padding_decrypt(const SM4_KEY *key, const uint8_t iv[16],
int sm4_cbc_padding_decrypt(const SM4_KEY *key, const uint8_t piv[16],
const uint8_t *in, size_t inlen,
uint8_t *out, size_t *outlen)
{
uint8_t iv[16];
uint8_t block[16];
size_t len = sizeof(block);
int padding;
memcpy(iv, piv, 16);
if (inlen == 0) {
error_puts("warning: input lenght = 0");
return 0;
@@ -53,8 +59,8 @@ int sm4_cbc_padding_decrypt(const SM4_KEY *key, const uint8_t iv[16],
}
if (inlen > 16) {
sm4_cbc_decrypt_blocks(key, iv, in, inlen/16 - 1, out);
iv = in + inlen - 32;
}
sm4_cbc_decrypt_blocks(key, iv, in + inlen - 16, 1, block);
padding = block[15];
@@ -111,7 +117,7 @@ int sm4_cbc_encrypt_update(SM4_CBC_CTX *ctx,
}
memcpy(ctx->block + ctx->block_nbytes, in, left);
sm4_cbc_encrypt_blocks(&ctx->sm4_key, ctx->iv, ctx->block, 1, out);
memcpy(ctx->iv, out, SM4_BLOCK_SIZE);
//memcpy(ctx->iv, out, SM4_BLOCK_SIZE);
in += left;
inlen -= left;
out += SM4_BLOCK_SIZE;
@@ -121,7 +127,7 @@ int sm4_cbc_encrypt_update(SM4_CBC_CTX *ctx,
nblocks = inlen / SM4_BLOCK_SIZE;
len = nblocks * SM4_BLOCK_SIZE;
sm4_cbc_encrypt_blocks(&ctx->sm4_key, ctx->iv, in, nblocks, out);
memcpy(ctx->iv, out + len - SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);
//memcpy(ctx->iv, out + len - SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);
in += len;
inlen -= len;
out += len;
@@ -197,7 +203,6 @@ int sm4_cbc_decrypt_update(SM4_CBC_CTX *ctx,
}
memcpy(ctx->block + ctx->block_nbytes, in, left);
sm4_cbc_decrypt_blocks(&ctx->sm4_key, ctx->iv, ctx->block, 1, out);
memcpy(ctx->iv, ctx->block, SM4_BLOCK_SIZE);
in += left;
inlen -= left;
out += SM4_BLOCK_SIZE;
@@ -207,7 +212,6 @@ int sm4_cbc_decrypt_update(SM4_CBC_CTX *ctx,
nblocks = (inlen-1) / SM4_BLOCK_SIZE;
len = nblocks * SM4_BLOCK_SIZE;
sm4_cbc_decrypt_blocks(&ctx->sm4_key, ctx->iv, in, nblocks, out);
memcpy(ctx->iv, in + len - SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);
in += len;
inlen -= len;
out += len;

View File

@@ -326,7 +326,7 @@ int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key,
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
SM3_HMAC_CTX hmac_ctx;
const uint8_t *iv;
uint8_t iv[16];
const uint8_t *padding;
const uint8_t *mac;
uint8_t header[5];
@@ -345,7 +345,7 @@ int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key,
return -1;
}
iv = in;
memcpy(iv, in, 16);
in += 16;
inlen -= 16;

View File

@@ -21,15 +21,19 @@
static int test_sm4_cbc(void)
{
SM4_KEY sm4_key;
uint8_t key[16] = {0};
uint8_t iv[16] = {0};
const uint8_t key[16] = {0};
const uint8_t civ[16] = {0};
uint8_t iv[16];
uint8_t buf1[32] = {0};
uint8_t buf2[32] = {0};
uint8_t buf3[32] = {0};
sm4_set_encrypt_key(&sm4_key, key);
memcpy(iv, civ, 16);
sm4_cbc_encrypt_blocks(&sm4_key, iv, buf1, 2, buf2);
sm4_set_decrypt_key(&sm4_key, key);
memcpy(iv, civ, 16);
sm4_cbc_decrypt_blocks(&sm4_key, iv, buf2, 2, buf3);
if (memcmp(buf1, buf3, sizeof(buf3)) != 0) {
@@ -107,6 +111,7 @@ static int test_sm4_cbc_test_vectors(void)
}
sm4_set_encrypt_key(&sm4_key, key);
hex_to_bytes(tests[i].iv, strlen(tests[i].iv), iv, &iv_len);
sm4_cbc_encrypt_blocks(&sm4_key, iv, plaintext, plaintext_len/16, encrypted);
if (memcmp(encrypted, ciphertext, ciphertext_len) != 0) {