Merge TLS API

This commit is contained in:
Zhi Guan
2026-06-11 14:46:35 +08:00
parent b48f2c3772
commit e1c69d5633
3 changed files with 122 additions and 9 deletions

119
src/tls.c
View File

@@ -1853,13 +1853,32 @@ int tls_decrypt_recv(TLS_CONNECT *conn)
return 1; return 1;
} }
int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen) static int tls12_tlcp_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen)
{ {
tls_trace("send ApplicationData\n"); tls_trace("send ApplicationData\n");
return tls_encrypt_send(conn, TLS_record_application_data, in, inlen, sentlen); return tls_encrypt_send(conn, TLS_record_application_data, in, inlen, sentlen);
} }
int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen) int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen)
{
if (!conn) {
error_print();
return -1;
}
switch (conn->protocol) {
case TLS_protocol_tlcp:
case TLS_protocol_tls12:
return tls12_tlcp_send(conn, in, inlen, sentlen);
case TLS_protocol_tls13:
return tls13_send(conn, in, inlen, sentlen);
default:
error_print();
return -1;
}
}
static int tls12_tlcp_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen)
{ {
if (!conn || !out || !outlen || !recvlen) { if (!conn || !out || !outlen || !recvlen) {
error_print(); error_print();
@@ -1906,7 +1925,26 @@ int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen)
return 1; return 1;
} }
int tls_shutdown(TLS_CONNECT *conn) int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen)
{
if (!conn) {
error_print();
return -1;
}
switch (conn->protocol) {
case TLS_protocol_tlcp:
case TLS_protocol_tls12:
return tls12_tlcp_recv(conn, out, outlen, recvlen);
case TLS_protocol_tls13:
return tls13_recv(conn, out, outlen, recvlen);
default:
error_print();
return -1;
}
}
static int tls12_tlcp_shutdown(TLS_CONNECT *conn)
{ {
int ret; int ret;
size_t recordlen; size_t recordlen;
@@ -1938,6 +1976,72 @@ int tls_shutdown(TLS_CONNECT *conn)
return 1; return 1;
} }
static int tls13_shutdown(TLS_CONNECT *conn)
{
int ret;
const BLOCK_CIPHER_KEY *key;
const uint8_t *iv;
uint8_t *seq_num;
size_t padding_len;
if (!conn) {
error_print();
return -1;
}
if (conn->is_client) {
key = &conn->client_write_key;
iv = conn->client_write_iv;
seq_num = conn->client_seq_num;
} else {
key = &conn->server_write_key;
iv = conn->server_write_iv;
seq_num = conn->server_seq_num;
}
tls_trace("send Alert.close_notify\n");
tls_record_set_alert(conn->plain_record, &conn->plain_recordlen,
TLS_alert_level_warning, TLS_alert_close_notify);
tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(key, iv, seq_num, conn->plain_record, conn->plain_recordlen,
padding_len, conn->record, &conn->recordlen) != 1) {
error_print();
return -1;
}
tls_seq_num_incr(seq_num);
conn->record_offset = 0;
ret = tls_send_record(conn);
if (ret != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
return 1;
}
int tls_shutdown(TLS_CONNECT *conn)
{
if (!conn) {
error_print();
return -1;
}
switch (conn->protocol) {
case TLS_protocol_tlcp:
case TLS_protocol_tls12:
return tls12_tlcp_shutdown(conn);
case TLS_protocol_tls13:
return tls13_shutdown(conn);
default:
error_print();
return -1;
}
}
int tls_authorities_from_certs(uint8_t *names, size_t *nameslen, size_t maxlen, const uint8_t *certs, size_t certslen) int tls_authorities_from_certs(uint8_t *names, size_t *nameslen, size_t maxlen, const uint8_t *certs, size_t certslen)
{ {
const uint8_t *cert; const uint8_t *cert;
@@ -2716,6 +2820,15 @@ int tls_ctx_set_key_update_seq_num_limit(TLS_CTX *ctx, size_t max_seq_num)
int tls_init(TLS_CONNECT *conn, TLS_CTX *ctx) int tls_init(TLS_CONNECT *conn, TLS_CTX *ctx)
{ {
if (!conn || !ctx) {
error_print();
return -1;
}
if (ctx->protocol == TLS_protocol_tls13) {
return tls13_init(conn, ctx);
}
memset(conn, 0, sizeof(*conn)); memset(conn, 0, sizeof(*conn));
conn->is_client = ctx->is_client; // TODO: remove conn->is_client conn->is_client = ctx->is_client; // TODO: remove conn->is_client

View File

@@ -443,7 +443,7 @@ bad:
// TLS_CONNECT // TLS_CONNECT
if (tls13_init(&conn, &ctx) != 1) { if (tls_init(&conn, &ctx) != 1) {
error_print(); error_print();
goto end; goto end;
} }
@@ -669,7 +669,7 @@ bad:
if (FD_ISSET(conn.sock, &fds_recv)) { if (FD_ISSET(conn.sock, &fds_recv)) {
memset(buf, 0, sizeof(buf)); memset(buf, 0, sizeof(buf));
if ((ret = tls13_recv(&conn, (uint8_t *)buf, sizeof(buf), &len)) != 1) { if ((ret = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len)) != 1) {
if (ret == TLS_ERROR_SEND_AGAIN || ret == TLS_ERROR_RECV_AGAIN) { if (ret == TLS_ERROR_SEND_AGAIN || ret == TLS_ERROR_RECV_AGAIN) {
continue; continue;
} else { } else {
@@ -705,7 +705,7 @@ bad:
if (sent_len > 0 && FD_ISSET(conn.sock, &fds_send)) { if (sent_len > 0 && FD_ISSET(conn.sock, &fds_send)) {
// tls13_send 会返回一个 -1 , 但是没有打印错误信息!!!! // tls13_send 会返回一个 -1 , 但是没有打印错误信息!!!!
if ((ret = tls13_send(&conn, (uint8_t *)send_buf + sent_offset, sent_len, &sentlen)) != 1) { if ((ret = tls_send(&conn, (uint8_t *)send_buf + sent_offset, sent_len, &sentlen)) != 1) {
if (ret == TLS_ERROR_SEND_AGAIN || ret == TLS_ERROR_RECV_AGAIN) { if (ret == TLS_ERROR_SEND_AGAIN || ret == TLS_ERROR_RECV_AGAIN) {
continue; continue;
} else { } else {

View File

@@ -448,7 +448,7 @@ bad:
goto end; goto end;
} }
if (tls13_init(&conn, &ctx) != 1) { if (tls_init(&conn, &ctx) != 1) {
error_print(); error_print();
goto end; goto end;
} }
@@ -553,7 +553,7 @@ bad:
format_bytes(stderr, 0, 0, "tls13_send", (const uint8_t *)buf + send_offset, send_len); format_bytes(stderr, 0, 0, "tls13_send", (const uint8_t *)buf + send_offset, send_len);
if ((ret = tls13_send(&conn, (uint8_t *)buf + send_offset, send_len, &sentlen)) != 1) { if ((ret = tls_send(&conn, (uint8_t *)buf + send_offset, send_len, &sentlen)) != 1) {
if (ret == TLS_ERROR_SEND_AGAIN || ret == TLS_ERROR_RECV_AGAIN) { if (ret == TLS_ERROR_SEND_AGAIN || ret == TLS_ERROR_RECV_AGAIN) {
continue; continue;
} }
@@ -575,7 +575,7 @@ bad:
memset(buf, 0, sizeof(buf)); memset(buf, 0, sizeof(buf));
if ((ret = tls13_recv(&conn, (uint8_t *)buf, sizeof(buf), &len)) != 1) { if ((ret = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len)) != 1) {
if (ret == TLS_ERROR_SEND_AGAIN || ret == TLS_ERROR_RECV_AGAIN) { if (ret == TLS_ERROR_SEND_AGAIN || ret == TLS_ERROR_RECV_AGAIN) {
continue; continue;
} }