From 40e00284a2d4a774ca825418ee3b9f5abed168aa Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Thu, 11 Jun 2026 21:03:14 +0800 Subject: [PATCH] Fix tls.c --- include/gmssl/tls.h | 8 +-- src/tlcp.c | 16 ++--- src/tls.c | 140 ++++++++++++++++++++++++++++---------------- tests/tlstest.c | 12 ++-- 4 files changed, 108 insertions(+), 68 deletions(-) diff --git a/include/gmssl/tls.h b/include/gmssl/tls.h index 4dc1b72f..4bf515e7 100644 --- a/include/gmssl/tls.h +++ b/include/gmssl/tls.h @@ -404,16 +404,16 @@ int tls_prf(const uint8_t *secret, size_t secretlen, const char *label, const uint8_t *more, size_t morelen, size_t outlen, uint8_t *out); -int tls_cbc_encrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *enc_key, +int tls_cbc_encrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *enc_key, const uint8_t seq_num[8], const uint8_t header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); -int tls_cbc_decrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *dec_key, +int tls_cbc_decrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *dec_key, const uint8_t seq_num[8], const uint8_t header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); -int tls_record_encrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key, +int tls_record_encrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key, const uint8_t seq_num[8], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); -int tls_record_decrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key, +int tls_record_decrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key, const uint8_t seq_num[8], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); diff --git a/src/tlcp.c b/src/tlcp.c index 8eda2759..0274482a 100644 --- a/src/tlcp.c +++ b/src/tlcp.c @@ -925,7 +925,7 @@ int tlcp_send_client_finished(TLS_CONNECT *conn) //tlcp_handshake_digest_print(stderr, 0, 0, "client Finished", &conn->sm3_ctx); - if (tls_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key, + if (tls_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_key, conn->client_seq_num, conn->plain_record, conn->plain_recordlen, conn->record, &conn->recordlen) != 1) { @@ -976,7 +976,7 @@ int tlcp_recv_server_finished(TLS_CONNECT *conn) return -1; } - if (tls_record_decrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key, + if (tls_record_decrypt(&conn->server_write_mac_ctx, &conn->server_write_key, conn->server_seq_num, conn->record, conn->recordlen, conn->plain_record, &conn->plain_recordlen) != 1) { error_print(); @@ -1467,15 +1467,15 @@ int tlcp_generate_keys(TLS_CONNECT *conn) } // 主力这里是不对的,需要为client, server设定不同的加密密钥 - sm3_hmac_init(&conn->client_write_mac_ctx, conn->key_block, 32); - sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32); + hmac_init(&conn->client_write_mac_ctx, DIGEST_sm3(), conn->key_block, 32); + hmac_init(&conn->server_write_mac_ctx, DIGEST_sm3(), conn->key_block + 32, 32); if (conn->is_client) { - sm4_set_encrypt_key(&conn->client_write_enc_key, conn->key_block + 64); - sm4_set_decrypt_key(&conn->server_write_enc_key, conn->key_block + 80); + block_cipher_set_encrypt_key(&conn->client_write_key, BLOCK_CIPHER_sm4(), conn->key_block + 64); + block_cipher_set_decrypt_key(&conn->server_write_key, BLOCK_CIPHER_sm4(), conn->key_block + 80); } else { - sm4_set_decrypt_key(&conn->client_write_enc_key, conn->key_block + 64); - sm4_set_encrypt_key(&conn->server_write_enc_key, conn->key_block + 80); + block_cipher_set_decrypt_key(&conn->client_write_key, BLOCK_CIPHER_sm4(), conn->key_block + 64); + block_cipher_set_encrypt_key(&conn->server_write_key, BLOCK_CIPHER_sm4(), conn->key_block + 80); } tls_secrets_print(stderr, diff --git a/src/tls.c b/src/tls.c index 6185fb96..41f2d1e9 100644 --- a/src/tls.c +++ b/src/tls.c @@ -279,14 +279,53 @@ int tls_record_set_data(uint8_t *record, const uint8_t *data, size_t datalen) return 1; } -int tls_cbc_encrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *enc_key, +static void tls_cbc_encrypt_blocks(const BLOCK_CIPHER_KEY *key, uint8_t iv[16], + const uint8_t *in, size_t nblocks, uint8_t *out) +{ + const uint8_t *piv = iv; + + while (nblocks--) { + size_t i; + for (i = 0; i < 16; i++) { + out[i] = in[i] ^ piv[i]; + } + block_cipher_encrypt(key, out, out); + piv = out; + in += 16; + out += 16; + } + + memcpy(iv, piv, 16); +} + +static void tls_cbc_decrypt_blocks(const BLOCK_CIPHER_KEY *key, uint8_t iv[16], + const uint8_t *in, size_t nblocks, uint8_t *out) +{ + const uint8_t *piv = iv; + + while (nblocks--) { + size_t i; + block_cipher_decrypt(key, in, out); + for (i = 0; i < 16; i++) { + out[i] ^= piv[i]; + } + piv = in; + in += 16; + out += 16; + } + + memcpy(iv, piv, 16); +} + +int tls_cbc_encrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *enc_key, const uint8_t seq_num[8], const uint8_t header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { - SM3_HMAC_CTX hmac_ctx; + HMAC_CTX hmac_ctx; uint8_t last_blocks[32 + 16] = {0}; uint8_t iv[16]; uint8_t *mac, *padding; + size_t maclen; int rem, padding_len; int i; @@ -307,11 +346,11 @@ int tls_cbc_encrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *enc_key, memcpy(last_blocks, in + inlen - rem, rem); mac = last_blocks + rem; - memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); - sm3_hmac_update(&hmac_ctx, seq_num, 8); - sm3_hmac_update(&hmac_ctx, header, 5); - sm3_hmac_update(&hmac_ctx, in, inlen); - sm3_hmac_finish(&hmac_ctx, mac); + memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(HMAC_CTX)); + hmac_update(&hmac_ctx, seq_num, 8); + hmac_update(&hmac_ctx, header, 5); + hmac_update(&hmac_ctx, in, inlen); + hmac_finish(&hmac_ctx, mac, &maclen); padding = mac + 32; padding_len = 16 - rem - 1; @@ -327,25 +366,26 @@ int tls_cbc_encrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *enc_key, out += 16; if (inlen >= 16) { - sm4_cbc_encrypt_blocks(enc_key, iv, in, inlen/16, out); + tls_cbc_encrypt_blocks(enc_key, iv, in, inlen/16, out); out += inlen - rem; } - sm4_cbc_encrypt_blocks(enc_key, iv, last_blocks, sizeof(last_blocks)/16, out); + tls_cbc_encrypt_blocks(enc_key, iv, last_blocks, sizeof(last_blocks)/16, out); *outlen = 16 + inlen - rem + sizeof(last_blocks); return 1; } -int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key, +int tls_cbc_decrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *dec_key, const uint8_t seq_num[8], const uint8_t enced_header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { - SM3_HMAC_CTX hmac_ctx; + HMAC_CTX hmac_ctx; uint8_t iv[16]; const uint8_t *padding; const uint8_t *mac; uint8_t header[5]; int padding_len; uint8_t hmac[32]; + size_t hmaclen; int i; if (!inited_hmac_ctx || !dec_key || !seq_num || !enced_header || !in || !inlen || !out || !outlen) { @@ -363,7 +403,7 @@ int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key, in += 16; inlen -= 16; - sm4_cbc_decrypt_blocks(dec_key, iv, in, inlen/16, out); + tls_cbc_decrypt_blocks(dec_key, iv, in, inlen/16, out); padding_len = out[inlen - 1]; padding = out + inlen - padding_len - 1; @@ -387,11 +427,11 @@ int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key, header[4] = (uint8_t)(*outlen); mac = padding - 32; - memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); - sm3_hmac_update(&hmac_ctx, seq_num, 8); - sm3_hmac_update(&hmac_ctx, header, 5); - sm3_hmac_update(&hmac_ctx, out, *outlen); - sm3_hmac_finish(&hmac_ctx, hmac); + memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(HMAC_CTX)); + hmac_update(&hmac_ctx, seq_num, 8); + hmac_update(&hmac_ctx, header, 5); + hmac_update(&hmac_ctx, out, *outlen); + hmac_finish(&hmac_ctx, hmac, &hmaclen); if (gmssl_secure_memcmp(mac, hmac, sizeof(hmac)) != 0) { error_puts("tls ciphertext mac check failure\n"); return -1; @@ -399,7 +439,7 @@ int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key, return 1; } -int tls_record_encrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key, +int tls_record_encrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key, const uint8_t seq_num[8], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { @@ -419,7 +459,7 @@ int tls_record_encrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key, return 1; } -int tls_record_decrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key, +int tls_record_decrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key, const uint8_t seq_num[8], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { @@ -458,8 +498,8 @@ int tls_prf(const uint8_t *secret, size_t secretlen, const char *label, const uint8_t *more, size_t morelen, size_t outlen, uint8_t *out) { - SM3_HMAC_CTX inited_hmac_ctx; - SM3_HMAC_CTX hmac_ctx; + HMAC_CTX inited_hmac_ctx; + HMAC_CTX hmac_ctx; uint8_t A[32]; uint8_t hmac[32]; size_t len; @@ -470,20 +510,20 @@ int tls_prf(const uint8_t *secret, size_t secretlen, const char *label, return -1; } - sm3_hmac_init(&inited_hmac_ctx, secret, secretlen); + hmac_init(&inited_hmac_ctx, DIGEST_sm3(), secret, secretlen); - memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); - sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); - sm3_hmac_update(&hmac_ctx, seed, seedlen); - sm3_hmac_update(&hmac_ctx, more, morelen); - sm3_hmac_finish(&hmac_ctx, A); + memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); + hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); + hmac_update(&hmac_ctx, seed, seedlen); + hmac_update(&hmac_ctx, more, morelen); + hmac_finish(&hmac_ctx, A, &len); - memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); - sm3_hmac_update(&hmac_ctx, A, sizeof(A)); - sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); - sm3_hmac_update(&hmac_ctx, seed, seedlen); - sm3_hmac_update(&hmac_ctx, more, morelen); - sm3_hmac_finish(&hmac_ctx, hmac); + memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); + hmac_update(&hmac_ctx, A, sizeof(A)); + hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); + hmac_update(&hmac_ctx, seed, seedlen); + hmac_update(&hmac_ctx, more, morelen); + hmac_finish(&hmac_ctx, hmac, &len); len = outlen < sizeof(hmac) ? outlen : sizeof(hmac); memcpy(out, hmac, len); @@ -491,16 +531,16 @@ int tls_prf(const uint8_t *secret, size_t secretlen, const char *label, outlen -= len; while (outlen) { - memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); - sm3_hmac_update(&hmac_ctx, A, sizeof(A)); - sm3_hmac_finish(&hmac_ctx, A); + memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); + hmac_update(&hmac_ctx, A, sizeof(A)); + hmac_finish(&hmac_ctx, A, &len); - memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); - sm3_hmac_update(&hmac_ctx, A, sizeof(A)); - sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); - sm3_hmac_update(&hmac_ctx, seed, seedlen); - sm3_hmac_update(&hmac_ctx, more, morelen); - sm3_hmac_finish(&hmac_ctx, hmac); + memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); + hmac_update(&hmac_ctx, A, sizeof(A)); + hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); + hmac_update(&hmac_ctx, seed, seedlen); + hmac_update(&hmac_ctx, more, morelen); + hmac_finish(&hmac_ctx, hmac, &len); len = outlen < sizeof(hmac) ? outlen : sizeof(hmac); memcpy(out, hmac, len); @@ -1751,8 +1791,8 @@ int tls_send_warning(TLS_CONNECT *conn, int alert) static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *in, size_t inlen, size_t *sentlen) { - const SM3_HMAC_CTX *hmac_ctx; - const SM4_KEY *enc_key; + const HMAC_CTX *hmac_ctx; + const BLOCK_CIPHER_KEY *enc_key; uint8_t *seq_num; size_t recordlen; @@ -1776,11 +1816,11 @@ static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *i if (conn->is_client) { hmac_ctx = &conn->client_write_mac_ctx; - enc_key = &conn->client_write_enc_key; + enc_key = &conn->client_write_key; seq_num = conn->client_seq_num; } else { hmac_ctx = &conn->server_write_mac_ctx; - enc_key = &conn->server_write_enc_key; + enc_key = &conn->server_write_key; seq_num = conn->server_seq_num; } @@ -1813,8 +1853,8 @@ static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *i int tls_decrypt_recv(TLS_CONNECT *conn) { int ret; - const SM3_HMAC_CTX *hmac_ctx; - const SM4_KEY *dec_key; + const HMAC_CTX *hmac_ctx; + const BLOCK_CIPHER_KEY *dec_key; uint8_t *seq_num; uint8_t *record = conn->record; @@ -1822,11 +1862,11 @@ int tls_decrypt_recv(TLS_CONNECT *conn) if (conn->is_client) { hmac_ctx = &conn->server_write_mac_ctx; - dec_key = &conn->server_write_enc_key; + dec_key = &conn->server_write_key; seq_num = conn->server_seq_num; } else { hmac_ctx = &conn->client_write_mac_ctx; - dec_key = &conn->client_write_enc_key; + dec_key = &conn->client_write_key; seq_num = conn->client_seq_num; } diff --git a/tests/tlstest.c b/tests/tlstest.c index 85e1e032..0622beba 100644 --- a/tests/tlstest.c +++ b/tests/tlstest.c @@ -82,8 +82,8 @@ static int test_tls_null_to_bytes(void) static int test_tls_cbc(void) { uint8_t key[32] = {0}; - SM3_HMAC_CTX hmac_ctx; - SM4_KEY sm4_key; + HMAC_CTX hmac_ctx; + BLOCK_CIPHER_KEY sm4_key; uint8_t seq_num[8] = { 0,0,0,0,0,0,0,1 }; uint8_t header[5]; uint8_t in[] = "hello world"; @@ -98,12 +98,12 @@ static int test_tls_cbc(void) header[3] = sizeof(in) >> 8; header[4] = sizeof(in) & 0xff; - sm3_hmac_init(&hmac_ctx, key, 32); - sm4_set_encrypt_key(&sm4_key, key); + hmac_init(&hmac_ctx, DIGEST_sm3(), key, 32); + block_cipher_set_encrypt_key(&sm4_key, BLOCK_CIPHER_sm4(), key); tls_cbc_encrypt(&hmac_ctx, &sm4_key, seq_num, header, in, sizeof(in), out, &len); - sm3_hmac_init(&hmac_ctx, key, 32); - sm4_set_decrypt_key(&sm4_key, key); + hmac_init(&hmac_ctx, DIGEST_sm3(), key, 32); + block_cipher_set_decrypt_key(&sm4_key, BLOCK_CIPHER_sm4(), key); tls_cbc_decrypt(&hmac_ctx, &sm4_key, seq_num, header, out, len, buf, &buflen);