Update TLS 1.3 0-RTT

This commit is contained in:
Zhi Guan
2026-03-23 20:50:55 +08:00
parent 5efe2005d4
commit e996e72537
2 changed files with 268 additions and 15 deletions

View File

@@ -315,6 +315,20 @@ int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t *s
return 1;
}
int tls13_send_early_data(TLS_CONNECT *conn)
{
size_t sentlen;
tls_trace("send EarlyData\n");
conn->early_data_len = 66; //xxx 先用来测试吧
if (tls13_send(conn, conn->early_data_buf, conn->early_data_len, &sentlen) != 1) {
error_print();
return -1;
}
return 1;
}
int tls13_do_recv(TLS_CONNECT *conn)
{
int ret;
@@ -387,6 +401,21 @@ int tls13_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen)
return 1;
}
int tls13_recv_early_data(TLS_CONNECT *conn)
{
tls_trace("recv EarlyData\n");
if (tls13_do_recv(conn) != 1) {
error_print();
return -1;
}
memcpy(conn->early_data_buf, conn->data, conn->datalen);
conn->early_data_len = conn->datalen;
format_bytes(stderr, 0, 4, "EarlyData", conn->early_data_buf, conn->early_data_len);
return 1;
}
/*
@@ -1551,7 +1580,6 @@ int tls13_process_client_pre_shared_key(TLS_CONNECT *conn, const uint8_t *ext_da
const uint8_t *truncated_binders = conn->plain_record + (binders - conn->record);
size_t truncated_binderslen = binderslen;
while (truncated_binderslen) {
const uint8_t *truncated_binder;
size_t truncated_binderlen;
@@ -1562,8 +1590,6 @@ int tls13_process_client_pre_shared_key(TLS_CONNECT *conn, const uint8_t *ext_da
memset((uint8_t *)truncated_binder, 0, truncated_binderlen);
}
*selected_psk_identity = -1;
for (i = 0; identitieslen; i++) {
@@ -1664,6 +1690,9 @@ int tls13_process_client_pre_shared_key(TLS_CONNECT *conn, const uint8_t *ext_da
error_print();
return 0;
}
return 1;
}
@@ -3030,7 +3059,7 @@ int tls13_record_set_handshake_end_of_early_data(uint8_t *record, size_t *record
return 1;
}
int tls13_record_get_handshake_end_of_early_data(uint8_t *record, size_t *recordlen)
int tls13_record_get_handshake_end_of_early_data(uint8_t *record)
{
int type;
const uint8_t *cp;
@@ -3690,6 +3719,8 @@ int tls13_send_client_hello(TLS_CONNECT *conn)
tls13_record_print(stderr, 0, 0, conn->record, conn->recordlen);
conn->digest = DIGEST_sm3();
conn->cipher = BLOCK_CIPHER_sm4();
DIGEST_CTX tmp_dgst_ctx;
@@ -3728,6 +3759,9 @@ int tls13_send_client_hello(TLS_CONNECT *conn)
}
conn->recordlen = 0;
conn->early_data = 1;
}
if (tls_record_set_handshake_client_hello(conn->record, &conn->recordlen,
@@ -3738,9 +3772,39 @@ int tls13_send_client_hello(TLS_CONNECT *conn)
error_print();
return -1;
}
tls13_record_print(stderr, 0, 0, conn->record, conn->recordlen);
if (conn->early_data) {
DIGEST_CTX early_dgst_ctx;
uint8_t client_early_traffic_secret[32];
digest_init(&early_dgst_ctx, conn->digest);
digest_update(&early_dgst_ctx, conn->record + 5, conn->recordlen - 5);
uint8_t zeros[32] = {0};
uint8_t early_secret[32];
uint8_t client_write_key[16];
// [1]
tls13_hkdf_extract(conn->digest, zeros, conn->psk, early_secret);
// [2]
tls13_derive_secret(early_secret, "c e traffic", &early_dgst_ctx, client_early_traffic_secret);
tls13_hkdf_expand_label(conn->digest, client_early_traffic_secret, "key", NULL, 0, 16, client_write_key);
block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, client_write_key);
tls13_hkdf_expand_label(conn->digest, client_early_traffic_secret, "iv", NULL, 0, 12, conn->client_write_iv);
tls_seq_num_reset(conn->client_seq_num);
// client_early_traffic_secret 用来加密early_data, end_of_early_data
format_print(stderr, 0, 0, "client_write_key/iv <= client_early_traffic_secret\n");
format_bytes(stderr, 0, 4, "client_early_traffic_secret", client_early_traffic_secret, 32);
format_bytes(stderr, 0, 4, "client_write_key", client_write_key, 16);
format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12);
}
}
if (conn->client_certificate_verify) {
@@ -4262,6 +4326,107 @@ int tls13_recv_server_hello(TLS_CONNECT *conn)
return 1;
}
int tls13_send_end_of_early_data(TLS_CONNECT *conn)
{
int ret;
if (!conn->recordlen) {
if (tls13_record_set_handshake_end_of_early_data(conn->plain_record, &conn->plain_recordlen) != 1) {
error_print();
return -1;
}
tls_trace("send EndOfEarlyData\n");
format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12);
size_t padding_len;
tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tls13_record_print(stderr, 0, 0, conn->record, conn->recordlen);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
uint8_t client_write_key[16];
tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "key", NULL, 0, 16, client_write_key);
block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, client_write_key);
tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "iv", NULL, 0, 12, conn->client_write_iv);
tls_seq_num_reset(conn->client_seq_num);
// client_early_traffic_secret 用来加密early_data, end_of_early_data
format_print(stderr, 0, 0, "client_write_key/iv <= client_handshake_traffic_secret\n");
format_bytes(stderr, 0, 4, "client_early_traffic_secret", conn->client_handshake_traffic_secret, 32);
format_bytes(stderr, 0, 4, "client_write_key", client_write_key, 16);
format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12);
tls_clean_record(conn);
return 1;
}
int tls13_recv_end_of_early_data(TLS_CONNECT *conn)
{
int ret;
tls_trace("recv {EndOfEarlyData}\n");
if ((ret = tls_recv_record(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN) {
error_print();
}
return ret;
}
format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12);
if (tls13_record_decrypt(&conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_record_mac);
return -1;
}
tls13_record_print(stderr, 0, 0, conn->plain_record, conn->plain_recordlen);
if ((ret = tls13_record_get_handshake_end_of_early_data(conn->plain_record)) < 0) {
error_print();
tls_send_alert(conn, TLS_alert_decode_error);
return -1;
} else if (ret == 0) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
uint8_t client_write_key[16];
tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "key", NULL, 0, 16, client_write_key);
block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, client_write_key);
tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "iv", NULL, 0, 12, conn->client_write_iv);
tls_seq_num_reset(conn->client_seq_num);
format_print(stderr, 0, 0, "client_write_key/iv <= client_handshake_traffic_secret\n");
format_bytes(stderr, 0, 4, "client_handshake_traffic_secret", conn->client_handshake_traffic_secret, 32);
format_bytes(stderr, 0, 4, "client_write_key", client_write_key, 16);
format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12);
return 1;
}
/*
0
@@ -4379,11 +4544,12 @@ int tls13_generate_keys(TLS_CONNECT *conn)
tls13_hkdf_expand_label(conn->digest, conn->server_handshake_traffic_secret, "iv", NULL, 0, 12, conn->server_write_iv);
tls_seq_num_reset(conn->server_seq_num);
tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "key", NULL, 0, 16, client_write_key);
block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, client_write_key);
tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "iv", NULL, 0, 12, conn->client_write_iv);
tls_seq_num_reset(conn->client_seq_num);
if (!conn->early_data) {
tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "key", NULL, 0, 16, client_write_key);
block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, client_write_key);
tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "iv", NULL, 0, 12, conn->client_write_iv);
tls_seq_num_reset(conn->client_seq_num);
}
format_bytes(stderr, 0, 4, "client_write_key", client_write_key, 16);
format_bytes(stderr, 0, 4, "server_write_key", server_write_key, 16);
@@ -5383,6 +5549,39 @@ int tls13_recv_client_hello(TLS_CONNECT *conn)
// ServerHello依赖这个标志
}
conn->early_data = 1;
// 这部分应该单独放在一个括号里面
// 这里是设置early_data的加密密钥这个加密密钥几乎完全是用PSK生成的
DIGEST_CTX early_dgst_ctx;
uint8_t client_early_traffic_secret[32];
digest_init(&early_dgst_ctx, conn->digest);
digest_update(&early_dgst_ctx, conn->record + 5, conn->recordlen - 5);
uint8_t zeros[32] = {0};
uint8_t early_secret[32];
uint8_t client_write_key[16];
// [1]
tls13_hkdf_extract(conn->digest, zeros, conn->psk, early_secret);
// [2]
tls13_derive_secret(early_secret, "c e traffic", &early_dgst_ctx, client_early_traffic_secret);
tls13_hkdf_expand_label(conn->digest, client_early_traffic_secret, "key", NULL, 0, 16, client_write_key);
block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, client_write_key);
tls13_hkdf_expand_label(conn->digest, client_early_traffic_secret, "iv", NULL, 0, 12, conn->client_write_iv);
tls_seq_num_reset(conn->client_seq_num);
format_print(stderr, 0, 0, "client_write_key/iv <= client_early_traffic_secret\n");
format_bytes(stderr, 0, 4, "client_early_traffic_secret", client_early_traffic_secret, 32);
format_bytes(stderr, 0, 4, "client_write_key", client_write_key, 16);
format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12);
break;
default:
@@ -6725,6 +6924,29 @@ int tls13_recv_new_session_ticket(TLS_CONNECT *conn)
send_client_certificate
send_certificate_verify
send_client_finished
PSK模式
send_client_hello
recv_server_hello
recv_server_finished
send_client_finished
0-RTT
send_client_hello
send_application_data*
recv_server_hello
recv_server_finished
send_client_finished
客户端在发送完ClientHello之后
可以同时发送0-RTT数据和接收ServerHello
也就是说这两个状态是可以并发的
*/
@@ -6736,9 +6958,15 @@ int tls13_do_client_handshake(TLS_CONNECT *conn)
switch (conn->state) {
case TLS_state_client_hello:
ret = tls13_send_client_hello(conn);
next_state = TLS_state_hello_retry_request;
if (conn->early_data)
next_state = TLS_state_early_data;
else next_state = TLS_state_hello_retry_request;
break;
case TLS_state_early_data:
ret = tls13_send_early_data(conn);
next_state = TLS_state_hello_retry_request;
case TLS_state_hello_retry_request: // optional
ret = tls13_recv_hello_retry_request(conn);
if (conn->hello_retry_request)
@@ -6785,11 +7013,18 @@ int tls13_do_client_handshake(TLS_CONNECT *conn)
case TLS_state_server_finished:
ret = tls13_recv_server_finished(conn);
if (conn->client_certificate_verify)
if (conn->early_data)
next_state = TLS_state_end_of_early_data;
else if (conn->client_certificate_verify)
next_state = TLS_state_client_certificate;
else next_state = TLS_state_client_finished;
break;
case TLS_state_end_of_early_data:
ret = tls13_send_end_of_early_data(conn);
next_state = TLS_state_client_finished;
break;
case TLS_state_client_certificate:
ret = tls13_send_client_certificate(conn);
next_state = TLS_state_client_certificate_verify;
@@ -6843,11 +7078,18 @@ int tls13_do_server_handshake(TLS_CONNECT *conn)
switch (conn->state) {
case TLS_state_client_hello:
ret = tls13_recv_client_hello(conn);
if (conn->hello_retry_request)
if (conn->early_data)
next_state = TLS_state_early_data;
else if (conn->hello_retry_request)
next_state = TLS_state_hello_retry_request;
else next_state = TLS_state_server_hello;
break;
case TLS_state_early_data:
ret = tls13_recv_early_data(conn);
next_state = TLS_state_server_hello;
break;
case TLS_state_hello_retry_request:
ret = tls13_send_hello_retry_request(conn);
next_state = TLS_state_client_hello_again;
@@ -6893,11 +7135,18 @@ int tls13_do_server_handshake(TLS_CONNECT *conn)
case TLS_state_server_finished:
ret = tls13_send_server_finished(conn);
if (conn->certificate_request)
if (conn->early_data)
next_state = TLS_state_end_of_early_data;
else if (conn->certificate_request)
next_state = TLS_state_client_certificate;
else next_state = TLS_state_client_finished;
break;
case TLS_state_end_of_early_data:
ret = tls13_recv_end_of_early_data(conn);
next_state = TLS_state_client_finished;
break;
case TLS_state_client_certificate:
ret = tls13_recv_client_certificate(conn);
next_state = TLS_state_client_certificate_verify;