Fix tls.c

This commit is contained in:
Zhi Guan
2026-06-11 21:03:14 +08:00
parent 94279854f8
commit 40e00284a2
4 changed files with 108 additions and 68 deletions

140
src/tls.c
View File

@@ -279,14 +279,53 @@ int tls_record_set_data(uint8_t *record, const uint8_t *data, size_t datalen)
return 1;
}
int tls_cbc_encrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *enc_key,
static void tls_cbc_encrypt_blocks(const BLOCK_CIPHER_KEY *key, uint8_t iv[16],
const uint8_t *in, size_t nblocks, uint8_t *out)
{
const uint8_t *piv = iv;
while (nblocks--) {
size_t i;
for (i = 0; i < 16; i++) {
out[i] = in[i] ^ piv[i];
}
block_cipher_encrypt(key, out, out);
piv = out;
in += 16;
out += 16;
}
memcpy(iv, piv, 16);
}
static void tls_cbc_decrypt_blocks(const BLOCK_CIPHER_KEY *key, uint8_t iv[16],
const uint8_t *in, size_t nblocks, uint8_t *out)
{
const uint8_t *piv = iv;
while (nblocks--) {
size_t i;
block_cipher_decrypt(key, in, out);
for (i = 0; i < 16; i++) {
out[i] ^= piv[i];
}
piv = in;
in += 16;
out += 16;
}
memcpy(iv, piv, 16);
}
int tls_cbc_encrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *enc_key,
const uint8_t seq_num[8], const uint8_t header[5],
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
SM3_HMAC_CTX hmac_ctx;
HMAC_CTX hmac_ctx;
uint8_t last_blocks[32 + 16] = {0};
uint8_t iv[16];
uint8_t *mac, *padding;
size_t maclen;
int rem, padding_len;
int i;
@@ -307,11 +346,11 @@ int tls_cbc_encrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *enc_key,
memcpy(last_blocks, in + inlen - rem, rem);
mac = last_blocks + rem;
memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
sm3_hmac_update(&hmac_ctx, seq_num, 8);
sm3_hmac_update(&hmac_ctx, header, 5);
sm3_hmac_update(&hmac_ctx, in, inlen);
sm3_hmac_finish(&hmac_ctx, mac);
memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(HMAC_CTX));
hmac_update(&hmac_ctx, seq_num, 8);
hmac_update(&hmac_ctx, header, 5);
hmac_update(&hmac_ctx, in, inlen);
hmac_finish(&hmac_ctx, mac, &maclen);
padding = mac + 32;
padding_len = 16 - rem - 1;
@@ -327,25 +366,26 @@ int tls_cbc_encrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *enc_key,
out += 16;
if (inlen >= 16) {
sm4_cbc_encrypt_blocks(enc_key, iv, in, inlen/16, out);
tls_cbc_encrypt_blocks(enc_key, iv, in, inlen/16, out);
out += inlen - rem;
}
sm4_cbc_encrypt_blocks(enc_key, iv, last_blocks, sizeof(last_blocks)/16, out);
tls_cbc_encrypt_blocks(enc_key, iv, last_blocks, sizeof(last_blocks)/16, out);
*outlen = 16 + inlen - rem + sizeof(last_blocks);
return 1;
}
int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key,
int tls_cbc_decrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *dec_key,
const uint8_t seq_num[8], const uint8_t enced_header[5],
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
SM3_HMAC_CTX hmac_ctx;
HMAC_CTX hmac_ctx;
uint8_t iv[16];
const uint8_t *padding;
const uint8_t *mac;
uint8_t header[5];
int padding_len;
uint8_t hmac[32];
size_t hmaclen;
int i;
if (!inited_hmac_ctx || !dec_key || !seq_num || !enced_header || !in || !inlen || !out || !outlen) {
@@ -363,7 +403,7 @@ int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key,
in += 16;
inlen -= 16;
sm4_cbc_decrypt_blocks(dec_key, iv, in, inlen/16, out);
tls_cbc_decrypt_blocks(dec_key, iv, in, inlen/16, out);
padding_len = out[inlen - 1];
padding = out + inlen - padding_len - 1;
@@ -387,11 +427,11 @@ int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key,
header[4] = (uint8_t)(*outlen);
mac = padding - 32;
memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
sm3_hmac_update(&hmac_ctx, seq_num, 8);
sm3_hmac_update(&hmac_ctx, header, 5);
sm3_hmac_update(&hmac_ctx, out, *outlen);
sm3_hmac_finish(&hmac_ctx, hmac);
memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(HMAC_CTX));
hmac_update(&hmac_ctx, seq_num, 8);
hmac_update(&hmac_ctx, header, 5);
hmac_update(&hmac_ctx, out, *outlen);
hmac_finish(&hmac_ctx, hmac, &hmaclen);
if (gmssl_secure_memcmp(mac, hmac, sizeof(hmac)) != 0) {
error_puts("tls ciphertext mac check failure\n");
return -1;
@@ -399,7 +439,7 @@ int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key,
return 1;
}
int tls_record_encrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key,
int tls_record_encrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key,
const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
uint8_t *out, size_t *outlen)
{
@@ -419,7 +459,7 @@ int tls_record_encrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key,
return 1;
}
int tls_record_decrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key,
int tls_record_decrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key,
const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
uint8_t *out, size_t *outlen)
{
@@ -458,8 +498,8 @@ int tls_prf(const uint8_t *secret, size_t secretlen, const char *label,
const uint8_t *more, size_t morelen,
size_t outlen, uint8_t *out)
{
SM3_HMAC_CTX inited_hmac_ctx;
SM3_HMAC_CTX hmac_ctx;
HMAC_CTX inited_hmac_ctx;
HMAC_CTX hmac_ctx;
uint8_t A[32];
uint8_t hmac[32];
size_t len;
@@ -470,20 +510,20 @@ int tls_prf(const uint8_t *secret, size_t secretlen, const char *label,
return -1;
}
sm3_hmac_init(&inited_hmac_ctx, secret, secretlen);
hmac_init(&inited_hmac_ctx, DIGEST_sm3(), secret, secretlen);
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
sm3_hmac_update(&hmac_ctx, seed, seedlen);
sm3_hmac_update(&hmac_ctx, more, morelen);
sm3_hmac_finish(&hmac_ctx, A);
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX));
hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
hmac_update(&hmac_ctx, seed, seedlen);
hmac_update(&hmac_ctx, more, morelen);
hmac_finish(&hmac_ctx, A, &len);
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
sm3_hmac_update(&hmac_ctx, A, sizeof(A));
sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
sm3_hmac_update(&hmac_ctx, seed, seedlen);
sm3_hmac_update(&hmac_ctx, more, morelen);
sm3_hmac_finish(&hmac_ctx, hmac);
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX));
hmac_update(&hmac_ctx, A, sizeof(A));
hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
hmac_update(&hmac_ctx, seed, seedlen);
hmac_update(&hmac_ctx, more, morelen);
hmac_finish(&hmac_ctx, hmac, &len);
len = outlen < sizeof(hmac) ? outlen : sizeof(hmac);
memcpy(out, hmac, len);
@@ -491,16 +531,16 @@ int tls_prf(const uint8_t *secret, size_t secretlen, const char *label,
outlen -= len;
while (outlen) {
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
sm3_hmac_update(&hmac_ctx, A, sizeof(A));
sm3_hmac_finish(&hmac_ctx, A);
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX));
hmac_update(&hmac_ctx, A, sizeof(A));
hmac_finish(&hmac_ctx, A, &len);
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
sm3_hmac_update(&hmac_ctx, A, sizeof(A));
sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
sm3_hmac_update(&hmac_ctx, seed, seedlen);
sm3_hmac_update(&hmac_ctx, more, morelen);
sm3_hmac_finish(&hmac_ctx, hmac);
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX));
hmac_update(&hmac_ctx, A, sizeof(A));
hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
hmac_update(&hmac_ctx, seed, seedlen);
hmac_update(&hmac_ctx, more, morelen);
hmac_finish(&hmac_ctx, hmac, &len);
len = outlen < sizeof(hmac) ? outlen : sizeof(hmac);
memcpy(out, hmac, len);
@@ -1751,8 +1791,8 @@ int tls_send_warning(TLS_CONNECT *conn, int alert)
static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *in, size_t inlen, size_t *sentlen)
{
const SM3_HMAC_CTX *hmac_ctx;
const SM4_KEY *enc_key;
const HMAC_CTX *hmac_ctx;
const BLOCK_CIPHER_KEY *enc_key;
uint8_t *seq_num;
size_t recordlen;
@@ -1776,11 +1816,11 @@ static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *i
if (conn->is_client) {
hmac_ctx = &conn->client_write_mac_ctx;
enc_key = &conn->client_write_enc_key;
enc_key = &conn->client_write_key;
seq_num = conn->client_seq_num;
} else {
hmac_ctx = &conn->server_write_mac_ctx;
enc_key = &conn->server_write_enc_key;
enc_key = &conn->server_write_key;
seq_num = conn->server_seq_num;
}
@@ -1813,8 +1853,8 @@ static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *i
int tls_decrypt_recv(TLS_CONNECT *conn)
{
int ret;
const SM3_HMAC_CTX *hmac_ctx;
const SM4_KEY *dec_key;
const HMAC_CTX *hmac_ctx;
const BLOCK_CIPHER_KEY *dec_key;
uint8_t *seq_num;
uint8_t *record = conn->record;
@@ -1822,11 +1862,11 @@ int tls_decrypt_recv(TLS_CONNECT *conn)
if (conn->is_client) {
hmac_ctx = &conn->server_write_mac_ctx;
dec_key = &conn->server_write_enc_key;
dec_key = &conn->server_write_key;
seq_num = conn->server_seq_num;
} else {
hmac_ctx = &conn->client_write_mac_ctx;
dec_key = &conn->client_write_enc_key;
dec_key = &conn->client_write_key;
seq_num = conn->client_seq_num;
}