Clean TLS code

This commit is contained in:
Zhi Guan
2026-06-15 11:15:33 +08:00
parent 808d22e2a5
commit 823fe11897
6 changed files with 533 additions and 402 deletions

412
src/tls.c
View File

@@ -698,6 +698,21 @@ int tls_ccm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4],
return 1;
}
int tls_seq_num_incr(uint8_t seq_num[8])
{
int i;
for (i = 7; i > 0; i--) {
seq_num[i]++;
if (seq_num[i]) break;
}
return 1;
}
void tls_seq_num_reset(uint8_t seq_num[8])
{
memset(seq_num, 0, 8);
}
int tls_random_generate(uint8_t random[32])
{
uint32_t gmt_unix_time = (uint32_t)time(NULL);
@@ -1094,14 +1109,123 @@ int tls_cipher_suites_select(const uint8_t *client_ciphers, size_t client_cipher
return 0;
}
int tls_compression_methods_has_null_compression(const uint8_t *meths, size_t methslen)
{
if (!meths || !methslen) {
error_print();
return -1;
}
while (methslen--) {
if (*meths++ == TLS_compression_null) {
return 1;
}
}
error_print();
return -1;
}
int tls_record_encrypt(int cipher_suite,
const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4],
const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
uint8_t *out, size_t *outlen)
{
switch (cipher_suite) {
case TLS_cipher_ecc_sm4_cbc_sm3:
case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
if (tls_cbc_encrypt(hmac_ctx, key, seq_num, in,
in + 5, inlen - 5,
out + 5, outlen) != 1) {
error_print();
return -1;
}
break;
case TLS_cipher_ecc_sm4_gcm_sm3:
case TLS_cipher_ecdhe_sm4_gcm_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
if (tls_gcm_encrypt(key, fixed_iv, seq_num, in,
in + 5, inlen - 5,
out + 5, outlen) != 1) {
error_print();
return -1;
}
break;
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
if (tls_ccm_encrypt(key, fixed_iv, seq_num, in,
in + 5, inlen - 5,
out + 5, outlen) != 1) {
error_print();
return -1;
}
break;
#endif
default:
error_print();
return -1;
}
out[0] = in[0];
out[1] = in[1];
out[2] = in[2];
out[3] = (uint8_t)((*outlen) >> 8);
out[4] = (uint8_t)(*outlen);
(*outlen) += 5;
return 1;
}
int tls_record_decrypt(int cipher_suite, const HMAC_CTX *hmac_ctx,
const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4],
const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
uint8_t *out, size_t *outlen)
{
switch (cipher_suite) {
case TLS_cipher_ecc_sm4_cbc_sm3:
case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
if (tls_cbc_decrypt(hmac_ctx, key, seq_num, in,
in + 5, inlen - 5,
out + 5, outlen) != 1) {
error_print();
return -1;
}
break;
case TLS_cipher_ecc_sm4_gcm_sm3:
case TLS_cipher_ecdhe_sm4_gcm_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
if (tls_gcm_decrypt(key, fixed_iv, seq_num, in,
in + 5, inlen - 5,
out + 5, outlen) != 1) {
error_print();
return -1;
}
break;
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
if (tls_ccm_decrypt(key, fixed_iv, seq_num, in,
in + 5, inlen - 5,
out + 5, outlen) != 1) {
error_print();
return -1;
}
break;
#endif
default:
error_print();
return -1;
}
out[0] = in[0];
out[1] = in[1];
out[2] = in[2];
out[3] = (uint8_t)((*outlen) >> 8);
out[4] = (uint8_t)(*outlen);
(*outlen) += 5;
return 1;
}
@@ -1115,10 +1239,6 @@ int tls_record_set_handshake(uint8_t *record, size_t *recordlen,
error_print();
return -1;
}
if (!data && datalen) {
error_print();
return -1;
}
if (datalen > TLS_MAX_PLAINTEXT_SIZE - TLS_HANDSHAKE_HEADER_SIZE) {
error_print();
return -1;
@@ -1584,7 +1704,7 @@ int tls_server_ecdh_params_to_bytes(const X509_KEY *public_key, uint8_t **out, s
{
int named_curve;
uint8_t point[65];
uint8_t *point_ptr;
uint8_t *point_ptr = point;
size_t point_len = 0;
if (!public_key || !outlen) {
@@ -2178,35 +2298,6 @@ int tls_recv_record(TLS_CONNECT *conn)
}
int tls_seq_num_incr(uint8_t seq_num[8])
{
int i;
for (i = 7; i > 0; i--) {
seq_num[i]++;
if (seq_num[i]) break;
}
return 1;
}
void tls_seq_num_reset(uint8_t seq_num[8])
{
memset(seq_num, 0, 8);
}
int tls_compression_methods_has_null_compression(const uint8_t *meths, size_t methslen)
{
if (!meths || !methslen) {
error_print();
return -1;
}
while (methslen--) {
if (*meths++ == TLS_compression_null) {
return 1;
}
}
error_print();
return -1;
}
int tls_send_alert(TLS_CONNECT *conn, int alert)
{
@@ -2225,7 +2316,13 @@ int tls_send_alert(TLS_CONNECT *conn, int alert)
error_print();
return -1;
}
if(conn->verbose) tls_record_trace(stderr, record, sizeof(record), 0, 0);
if (conn->verbose) {
if (conn->protocol == TLS_protocol_tls12) {
tls12_record_print(stderr, record, sizeof(record), 0, 0);
} else {
tls_record_trace(stderr, record, sizeof(record), 0, 0);
}
}
return 1;
}
@@ -2274,153 +2371,18 @@ int tls_send_warning(TLS_CONNECT *conn, int alert)
error_print();
return -1;
}
if(conn->verbose) tls_record_trace(stderr, record, sizeof(record), 0, 0);
return 1;
}
static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *in, size_t inlen, size_t *sentlen)
{
const HMAC_CTX *hmac_ctx;
const BLOCK_CIPHER_KEY *enc_key;
const uint8_t *fixed_iv;
uint8_t *seq_num;
size_t recordlen;
int ret;
if (!conn) {
error_print();
return -1;
}
if (!in || !inlen || !sentlen) {
error_print();
return -1;
}
if (conn->recv_state) {
*sentlen = 0;
return TLS_ERROR_RECV_AGAIN;
}
if (conn->send_state && conn->send_state != TLS_state_send_record) {
error_print();
return -1;
}
*sentlen = 0;
if (!conn->recordlen) {
if (inlen > TLS_MAX_PLAINTEXT_SIZE) {
inlen = TLS_MAX_PLAINTEXT_SIZE;
}
if (conn->datalen) {
error_puts("recv all buffered data before send");
return -1;
}
if (conn->is_client) {
hmac_ctx = &conn->client_write_mac_ctx;
enc_key = &conn->client_write_key;
fixed_iv = conn->client_write_iv;
seq_num = conn->client_seq_num;
} else {
hmac_ctx = &conn->server_write_mac_ctx;
enc_key = &conn->server_write_key;
fixed_iv = conn->server_write_iv;
seq_num = conn->server_seq_num;
}
if (tls_record_set_type(conn->databuf, record_type) != 1
|| tls_record_set_protocol(conn->databuf, conn->protocol) != 1
|| tls_record_set_data(conn->databuf, in, inlen) != 1) {
error_print();
return -1;
}
if(conn->verbose) tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0);
if (conn->verbose) {
if (conn->protocol == TLS_protocol_tls12) {
switch (conn->cipher_suite) {
case TLS_cipher_ecdhe_sm4_gcm_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
if (tls_gcm_encrypt(enc_key, fixed_iv, seq_num, conn->databuf,
conn->databuf + 5, tls_record_data_length(conn->databuf),
conn->record + 5, &recordlen) != 1) {
error_print();
return -1;
}
break;
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
if (tls_ccm_encrypt(enc_key, fixed_iv, seq_num, conn->databuf,
conn->databuf + 5, tls_record_data_length(conn->databuf),
conn->record + 5, &recordlen) != 1) {
error_print();
return -1;
}
break;
#endif
case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
if (tls_cbc_encrypt(hmac_ctx, enc_key, seq_num, conn->databuf,
conn->databuf + 5, tls_record_data_length(conn->databuf),
conn->record + 5, &recordlen) != 1) {
error_print();
return -1;
}
break;
default:
error_print();
return -1;
}
conn->record[0] = conn->databuf[0];
conn->record[1] = conn->databuf[1];
conn->record[2] = conn->databuf[2];
conn->record[3] = (uint8_t)(recordlen >> 8);
conn->record[4] = (uint8_t)(recordlen);
recordlen += 5;
} else if (conn->protocol == TLS_protocol_tlcp) {
if (tlcp_record_encrypt(conn->cipher_suite, hmac_ctx, enc_key, fixed_iv, seq_num,
conn->databuf, tls_record_length(conn->databuf),
conn->record, &recordlen) != 1) {
error_print();
return -1;
}
tls12_record_print(stderr, record, sizeof(record), 0, 0);
} else {
if (tls_cbc_encrypt(hmac_ctx, enc_key, seq_num, conn->databuf,
conn->databuf + 5, tls_record_data_length(conn->databuf),
conn->record + 5, &recordlen) != 1) {
error_print();
return -1;
}
conn->record[0] = conn->databuf[0];
conn->record[1] = conn->databuf[1];
conn->record[2] = conn->databuf[2];
conn->record[3] = (uint8_t)(recordlen >> 8);
conn->record[4] = (uint8_t)(recordlen);
recordlen += 5;
tls_record_trace(stderr, record, sizeof(record), 0, 0);
}
tls_seq_num_incr(seq_num);
conn->recordlen = recordlen;
conn->record_offset = 0;
conn->sentlen = inlen;
conn->send_state = TLS_state_send_record;
if(conn->verbose) tls_encrypted_record_trace(stderr, conn->record, recordlen, 0, 0);
}
ret = tls_send_record(conn);
if (ret != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
*sentlen = conn->sentlen;
conn->send_state = 0;
tls_clean_record(conn);
return 1;
}
int tls_decrypt_recv(TLS_CONNECT *conn)
{
int ret;
@@ -2459,16 +2421,22 @@ int tls_decrypt_recv(TLS_CONNECT *conn)
}
conn->recv_state = 0;
recordlen = conn->recordlen;
if(conn->verbose) tls_encrypted_record_trace(stderr, record, recordlen, 0, 0);
if (conn->verbose) {
if (conn->protocol == TLS_protocol_tls12) {
tls_encrypted_record_print(stderr, record, recordlen, 0, 0);
} else {
tls_encrypted_record_trace(stderr, record, recordlen, 0, 0);
}
}
if (conn->protocol == TLS_protocol_tls12) {
if (tls12_record_decrypt(conn->cipher_suite, hmac_ctx, dec_key, fixed_iv, seq_num,
if (tls_record_decrypt(conn->cipher_suite, hmac_ctx, dec_key, fixed_iv, seq_num,
record, recordlen, conn->databuf, &conn->datalen) != 1) {
error_print();
return -1;
}
} else if (conn->protocol == TLS_protocol_tlcp) {
if (tlcp_record_decrypt(conn->cipher_suite, hmac_ctx, dec_key, fixed_iv, seq_num,
if (tls_record_decrypt(conn->cipher_suite, hmac_ctx, dec_key, fixed_iv, seq_num,
record, recordlen, conn->databuf, &conn->datalen) != 1) {
error_print();
return -1;
@@ -2492,17 +2460,17 @@ int tls_decrypt_recv(TLS_CONNECT *conn)
conn->data = tls_record_data(conn->databuf);
conn->datalen = tls_record_data_length(conn->databuf);
if(conn->verbose) tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0);
if (conn->verbose) {
if (conn->protocol == TLS_protocol_tls12) {
tls12_record_print(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0);
} else {
tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0);
}
}
return 1;
}
static int tls12_tlcp_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen)
{
if(conn->verbose) tls_trace("send ApplicationData\n");
return tls_encrypt_send(conn, TLS_record_application_data, in, inlen, sentlen);
}
int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen)
{
if (!conn) {
@@ -2512,8 +2480,9 @@ int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen
switch (conn->protocol) {
case TLS_protocol_tlcp:
return tlcp_send(conn, in, inlen, sentlen);
case TLS_protocol_tls12:
return tls12_tlcp_send(conn, in, inlen, sentlen);
return tls12_send(conn, in, inlen, sentlen);
case TLS_protocol_tls13:
return tls13_send(conn, in, inlen, sentlen);
default:
@@ -2600,6 +2569,61 @@ int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen)
}
}
static int tls12_send_close_notify(TLS_CONNECT *conn)
{
int ret;
const HMAC_CTX *hmac;
const BLOCK_CIPHER_KEY *key;
const uint8_t *iv;
uint8_t *seq_num;
if (!conn) {
error_print();
return -1;
}
if (!conn->recordlen) {
if (conn->is_client) {
hmac = &conn->client_write_mac_ctx;
key = &conn->client_write_key;
iv = conn->client_write_iv;
seq_num = conn->client_seq_num;
} else {
hmac = &conn->server_write_mac_ctx;
key = &conn->server_write_key;
iv = conn->server_write_iv;
seq_num = conn->server_seq_num;
}
if(conn->verbose) tls_trace("send Alert.close_notify\n");
tls_record_set_alert(conn->plain_record, &conn->plain_recordlen,
TLS_alert_level_warning, TLS_alert_close_notify);
if (tls_record_encrypt(conn->cipher_suite, hmac, key, iv, seq_num,
conn->plain_record, conn->plain_recordlen,
conn->record, &conn->recordlen) != 1) {
error_print();
return -1;
}
tls_seq_num_incr(seq_num);
conn->record_offset = 0;
conn->send_state = TLS_state_send_record;
}
ret = tls_send_record(conn);
if (ret != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
conn->send_state = 0;
tls_clean_record(conn);
return 1;
}
static int tls13_send_close_notify(TLS_CONNECT *conn)
{
int ret;
@@ -2662,14 +2686,15 @@ static int tls_send_close_notify(TLS_CONNECT *conn)
return -1;
}
if (conn->protocol == TLS_protocol_tls13) {
switch (conn->protocol) {
case TLS_protocol_tlcp:
case TLS_protocol_tls12:
return tls12_send_close_notify(conn);
case TLS_protocol_tls13:
return tls13_send_close_notify(conn);
}
alert[0] = TLS_alert_level_warning;
alert[1] = TLS_alert_close_notify;
if(conn->verbose) tls_trace("send Alert.close_notify\n");
return tls_encrypt_send(conn, TLS_record_alert, alert, sizeof(alert), &sentlen);
error_print();
return -1;
}
int tls_shutdown(TLS_CONNECT *conn)
@@ -3444,4 +3469,3 @@ int tls_get_verify_result(TLS_CONNECT *conn, int *result)
*result = conn->verify_result;
return 1;
}