Clean TLS code

This commit is contained in:
Zhi Guan
2026-06-13 23:52:29 +08:00
parent a73c303339
commit 5d12858d41
8 changed files with 231 additions and 308 deletions

169
src/tls.c
View File

@@ -458,44 +458,123 @@ int tls_cbc_decrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *dec
return 1;
}
int tls_record_cbc_encrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key,
const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
uint8_t *out, size_t *outlen)
int tls_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4],
const uint8_t seq_num[8], const uint8_t header[5],
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
if (tls_cbc_encrypt(hmac_ctx, cbc_key, seq_num, in,
in + 5, inlen - 5,
out + 5, outlen) != 1) {
uint8_t nonce[12];
uint8_t aad[13];
uint8_t *explicit_nonce;
uint8_t *gmac;
if (!key || !fixed_iv || !seq_num || !header || (!in && inlen) || !out || !outlen) {
error_print();
return -1;
}
if (inlen > TLS_MAX_PLAINTEXT_SIZE) {
error_print();
return -1;
}
if ((((size_t)header[3]) << 8) + header[4] != inlen) {
error_print();
return -1;
}
out[0] = in[0];
out[1] = in[1];
out[2] = in[2];
out[3] = (uint8_t)((*outlen) >> 8);
out[4] = (uint8_t)(*outlen);
(*outlen) += 5;
memcpy(nonce, fixed_iv, 4);
memcpy(nonce + 4, seq_num, 8);
memcpy(aad, seq_num, 8);
memcpy(aad + 8, header, 5);
explicit_nonce = out;
memcpy(explicit_nonce, seq_num, 8);
out += 8;
gmac = out + inlen;
switch (key->cipher->oid) {
case OID_sm4:
if (sm4_gcm_encrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad),
in, inlen, out, GHASH_SIZE, gmac) != 1) {
error_print();
return -1;
}
break;
#ifdef ENABLE_AES
case OID_aes128:
if (aes_gcm_encrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad),
in, inlen, out, GHASH_SIZE, gmac) != 1) {
error_print();
return -1;
}
break;
#endif
default:
error_print();
return -1;
}
*outlen = 8 + inlen + GHASH_SIZE;
return 1;
}
int tls_record_cbc_decrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key,
const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
uint8_t *out, size_t *outlen)
int tls_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4],
const uint8_t seq_num[8], const uint8_t header[5],
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
if (tls_cbc_decrypt(hmac_ctx, cbc_key, seq_num, in,
in + 5, inlen - 5,
out + 5, outlen) != 1) {
uint8_t nonce[12];
uint8_t aad[13];
const uint8_t *explicit_nonce;
const uint8_t *gmac;
size_t mlen;
if (inlen < 8 + GHASH_SIZE) {
error_print();
return -1;
}
out[0] = in[0];
out[1] = in[1];
out[2] = in[2];
out[3] = (uint8_t)((*outlen) >> 8);
out[4] = (uint8_t)(*outlen);
(*outlen) += 5;
explicit_nonce = in;
in += 8;
inlen -= 8;
if (inlen < GHASH_SIZE) {
error_print();
return -1;
}
mlen = inlen - GHASH_SIZE;
gmac = in + mlen;
memcpy(nonce, fixed_iv, 4);
memcpy(nonce + 4, explicit_nonce, 8);
memcpy(aad, seq_num, 8);
memcpy(aad + 8, header, 5);
aad[11] = (uint8_t)(mlen >> 8);
aad[12] = (uint8_t)mlen;
switch (key->cipher->oid) {
case OID_sm4:
if (sm4_gcm_decrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad),
in, mlen, gmac, GHASH_SIZE, out) != 1) {
error_print();
return -1;
}
break;
#ifdef ENABLE_AES
case OID_aes128:
if (aes_gcm_decrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad),
in, mlen, gmac, GHASH_SIZE, out) != 1) {
error_print();
return -1;
}
break;
#endif
default:
error_print();
return -1;
}
*outlen = mlen;
return 1;
}
@@ -1876,18 +1955,18 @@ static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *i
switch (conn->cipher_suite) {
case TLS_cipher_ecdhe_sm4_gcm_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
if (tls12_record_gcm_encrypt(enc_key, fixed_iv, seq_num,
conn->databuf, tls_record_length(conn->databuf),
conn->record, &recordlen) != 1) {
if (tls_gcm_encrypt(enc_key, fixed_iv, seq_num, conn->databuf,
conn->databuf + 5, tls_record_data_length(conn->databuf),
conn->record + 5, &recordlen) != 1) {
error_print();
return -1;
}
break;
case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
if (tls_record_cbc_encrypt(hmac_ctx, enc_key, seq_num,
conn->databuf, tls_record_length(conn->databuf),
conn->record, &recordlen) != 1) {
if (tls_cbc_encrypt(hmac_ctx, enc_key, seq_num, conn->databuf,
conn->databuf + 5, tls_record_data_length(conn->databuf),
conn->record + 5, &recordlen) != 1) {
error_print();
return -1;
}
@@ -1896,6 +1975,12 @@ static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *i
error_print();
return -1;
}
conn->record[0] = conn->databuf[0];
conn->record[1] = conn->databuf[1];
conn->record[2] = conn->databuf[2];
conn->record[3] = (uint8_t)(recordlen >> 8);
conn->record[4] = (uint8_t)(recordlen);
recordlen += 5;
} else if (conn->protocol == TLS_protocol_tlcp) {
if (tlcp_record_encrypt(conn->cipher_suite, hmac_ctx, enc_key, fixed_iv, seq_num,
conn->databuf, tls_record_length(conn->databuf),
@@ -1904,12 +1989,18 @@ static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *i
return -1;
}
} else {
if (tls_record_cbc_encrypt(hmac_ctx, enc_key, seq_num,
conn->databuf, tls_record_length(conn->databuf),
conn->record, &recordlen) != 1) {
if (tls_cbc_encrypt(hmac_ctx, enc_key, seq_num, conn->databuf,
conn->databuf + 5, tls_record_data_length(conn->databuf),
conn->record + 5, &recordlen) != 1) {
error_print();
return -1;
}
conn->record[0] = conn->databuf[0];
conn->record[1] = conn->databuf[1];
conn->record[2] = conn->databuf[2];
conn->record[3] = (uint8_t)(recordlen >> 8);
conn->record[4] = (uint8_t)(recordlen);
recordlen += 5;
}
tls_seq_num_incr(seq_num);
@@ -1987,12 +2078,18 @@ int tls_decrypt_recv(TLS_CONNECT *conn)
return -1;
}
} else {
if (tls_record_cbc_decrypt(hmac_ctx, dec_key, seq_num,
record, recordlen,
conn->databuf, &conn->datalen) != 1) {
if (tls_cbc_decrypt(hmac_ctx, dec_key, seq_num, record,
record + 5, recordlen - 5,
conn->databuf + 5, &conn->datalen) != 1) {
error_print();
return -1;
}
conn->databuf[0] = record[0];
conn->databuf[1] = record[1];
conn->databuf[2] = record[2];
conn->databuf[3] = (uint8_t)(conn->datalen >> 8);
conn->databuf[4] = (uint8_t)(conn->datalen);
conn->datalen += 5;
}
tls_seq_num_incr(seq_num);