diff --git a/include/gmssl/sm4.h b/include/gmssl/sm4.h index 18b17475..c21d4483 100644 --- a/include/gmssl/sm4.h +++ b/include/gmssl/sm4.h @@ -35,9 +35,9 @@ void sm4_set_decrypt_key(SM4_KEY *key, const uint8_t raw_key[SM4_KEY_SIZE]); void sm4_encrypt(const SM4_KEY *key, const uint8_t in[SM4_BLOCK_SIZE], uint8_t out[SM4_BLOCK_SIZE]); void sm4_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, uint8_t *out); -void sm4_cbc_encrypt_blocks(const SM4_KEY *key, const uint8_t iv[SM4_BLOCK_SIZE], +void sm4_cbc_encrypt_blocks(const SM4_KEY *key, uint8_t iv[SM4_BLOCK_SIZE], const uint8_t *in, size_t nblocks, uint8_t *out); -void sm4_cbc_decrypt_blocks(const SM4_KEY *key, const uint8_t iv[SM4_BLOCK_SIZE], +void sm4_cbc_decrypt_blocks(const SM4_KEY *key, uint8_t iv[SM4_BLOCK_SIZE], const uint8_t *in, size_t nblocks, uint8_t *out); void sm4_ctr_encrypt_blocks(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t nblocks, uint8_t *out); void sm4_ctr32_encrypt_blocks(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t nblocks, uint8_t *out); diff --git a/src/sm4.c b/src/sm4.c index d26a6e73..472b7807 100644 --- a/src/sm4.c +++ b/src/sm4.c @@ -168,34 +168,42 @@ void sm4_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, u } } -void sm4_cbc_encrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], +void sm4_cbc_encrypt_blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) { + uint8_t *piv = iv; + while (nblocks--) { size_t i; for (i = 0; i < 16; i++) { - out[i] = in[i] ^ iv[i]; + out[i] = in[i] ^ piv[i]; } sm4_encrypt(key, out, out); - iv = out; + piv = out; in += 16; out += 16; } + + memcpy(iv, piv, 16); } -void sm4_cbc_decrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], +void sm4_cbc_decrypt_blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) { + uint8_t *piv = iv; + while (nblocks--) { size_t i; sm4_encrypt(key, in, out); for (i = 0; i < 16; i++) { - out[i] ^= iv[i]; + out[i] ^= piv[i]; } - iv = in; + piv = in; in += 16; out += 16; } + + memcpy(iv, piv, 16); } static void ctr_incr(uint8_t a[16]) { @@ -630,7 +638,7 @@ void sm4_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, u } } -void sm4_cbc_encrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) +void sm4_cbc_encrypt_blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) { const uint32_t *rk = key->rk; uint32_t X0, X1, X2, X3, X4; @@ -690,9 +698,14 @@ void sm4_cbc_encrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], const uint in += 16; out += 16; } + + PUTU32(iv , X0); + PUTU32(iv + 4, X4); + PUTU32(iv + 8, X3); + PUTU32(iv + 12, X2); } -void sm4_cbc_decrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) +void sm4_cbc_decrypt_blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) { const uint32_t *rk = key->rk; uint32_t IV0, IV1, IV2, IV3; @@ -756,6 +769,11 @@ void sm4_cbc_decrypt_blocks(const SM4_KEY *key, const uint8_t iv[16], const uint in += 16; out += 16; } + + PUTU32(iv , IV0); + PUTU32(iv + 4, IV1); + PUTU32(iv + 8, IV2); + PUTU32(iv + 12, IV3); } void sm4_ctr_encrypt_blocks(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t nblocks, uint8_t *out) diff --git a/src/sm4_cbc.c b/src/sm4_cbc.c index 95e5e565..2349153c 100644 --- a/src/sm4_cbc.c +++ b/src/sm4_cbc.c @@ -13,36 +13,42 @@ #include -int sm4_cbc_padding_encrypt(const SM4_KEY *key, const uint8_t iv[16], +int sm4_cbc_padding_encrypt(const SM4_KEY *key, const uint8_t piv[16], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { + uint8_t iv[16]; uint8_t block[16]; size_t rem = inlen % 16; int padding = 16 - inlen % 16; + memcpy(iv, piv, 16); + if (in) { memcpy(block, in + inlen - rem, rem); } memset(block + rem, padding, padding); + if (inlen/16) { sm4_cbc_encrypt_blocks(key, iv, in, inlen/16, out); out += inlen - rem; - iv = out - 16; } sm4_cbc_encrypt_blocks(key, iv, block, 1, out); *outlen = inlen - rem + 16; return 1; } -int sm4_cbc_padding_decrypt(const SM4_KEY *key, const uint8_t iv[16], +int sm4_cbc_padding_decrypt(const SM4_KEY *key, const uint8_t piv[16], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { + uint8_t iv[16]; uint8_t block[16]; size_t len = sizeof(block); int padding; + memcpy(iv, piv, 16); + if (inlen == 0) { error_puts("warning: input lenght = 0"); return 0; @@ -53,8 +59,8 @@ int sm4_cbc_padding_decrypt(const SM4_KEY *key, const uint8_t iv[16], } if (inlen > 16) { sm4_cbc_decrypt_blocks(key, iv, in, inlen/16 - 1, out); - iv = in + inlen - 32; } + sm4_cbc_decrypt_blocks(key, iv, in + inlen - 16, 1, block); padding = block[15]; @@ -111,7 +117,7 @@ int sm4_cbc_encrypt_update(SM4_CBC_CTX *ctx, } memcpy(ctx->block + ctx->block_nbytes, in, left); sm4_cbc_encrypt_blocks(&ctx->sm4_key, ctx->iv, ctx->block, 1, out); - memcpy(ctx->iv, out, SM4_BLOCK_SIZE); + //memcpy(ctx->iv, out, SM4_BLOCK_SIZE); in += left; inlen -= left; out += SM4_BLOCK_SIZE; @@ -121,7 +127,7 @@ int sm4_cbc_encrypt_update(SM4_CBC_CTX *ctx, nblocks = inlen / SM4_BLOCK_SIZE; len = nblocks * SM4_BLOCK_SIZE; sm4_cbc_encrypt_blocks(&ctx->sm4_key, ctx->iv, in, nblocks, out); - memcpy(ctx->iv, out + len - SM4_BLOCK_SIZE, SM4_BLOCK_SIZE); + //memcpy(ctx->iv, out + len - SM4_BLOCK_SIZE, SM4_BLOCK_SIZE); in += len; inlen -= len; out += len; @@ -197,7 +203,6 @@ int sm4_cbc_decrypt_update(SM4_CBC_CTX *ctx, } memcpy(ctx->block + ctx->block_nbytes, in, left); sm4_cbc_decrypt_blocks(&ctx->sm4_key, ctx->iv, ctx->block, 1, out); - memcpy(ctx->iv, ctx->block, SM4_BLOCK_SIZE); in += left; inlen -= left; out += SM4_BLOCK_SIZE; @@ -207,7 +212,6 @@ int sm4_cbc_decrypt_update(SM4_CBC_CTX *ctx, nblocks = (inlen-1) / SM4_BLOCK_SIZE; len = nblocks * SM4_BLOCK_SIZE; sm4_cbc_decrypt_blocks(&ctx->sm4_key, ctx->iv, in, nblocks, out); - memcpy(ctx->iv, in + len - SM4_BLOCK_SIZE, SM4_BLOCK_SIZE); in += len; inlen -= len; out += len; diff --git a/src/tls.c b/src/tls.c index 296ef947..b2df6783 100644 --- a/src/tls.c +++ b/src/tls.c @@ -326,7 +326,7 @@ int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { SM3_HMAC_CTX hmac_ctx; - const uint8_t *iv; + uint8_t iv[16]; const uint8_t *padding; const uint8_t *mac; uint8_t header[5]; @@ -345,7 +345,7 @@ int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key, return -1; } - iv = in; + memcpy(iv, in, 16); in += 16; inlen -= 16; diff --git a/tests/sm4_cbctest.c b/tests/sm4_cbctest.c index d657671c..0bd375d5 100644 --- a/tests/sm4_cbctest.c +++ b/tests/sm4_cbctest.c @@ -21,15 +21,19 @@ static int test_sm4_cbc(void) { SM4_KEY sm4_key; - uint8_t key[16] = {0}; - uint8_t iv[16] = {0}; + const uint8_t key[16] = {0}; + const uint8_t civ[16] = {0}; + uint8_t iv[16]; uint8_t buf1[32] = {0}; uint8_t buf2[32] = {0}; uint8_t buf3[32] = {0}; sm4_set_encrypt_key(&sm4_key, key); + memcpy(iv, civ, 16); sm4_cbc_encrypt_blocks(&sm4_key, iv, buf1, 2, buf2); + sm4_set_decrypt_key(&sm4_key, key); + memcpy(iv, civ, 16); sm4_cbc_decrypt_blocks(&sm4_key, iv, buf2, 2, buf3); if (memcmp(buf1, buf3, sizeof(buf3)) != 0) { @@ -107,6 +111,7 @@ static int test_sm4_cbc_test_vectors(void) } sm4_set_encrypt_key(&sm4_key, key); + hex_to_bytes(tests[i].iv, strlen(tests[i].iv), iv, &iv_len); sm4_cbc_encrypt_blocks(&sm4_key, iv, plaintext, plaintext_len/16, encrypted); if (memcmp(encrypted, ciphertext, ciphertext_len) != 0) {