From e996e72537b873feb9cca572a1cc7d724798cb1e Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Mon, 23 Mar 2026 20:50:55 +0800 Subject: [PATCH] Update TLS 1.3 0-RTT --- include/gmssl/tls.h | 4 + src/tls13.c | 279 +++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 268 insertions(+), 15 deletions(-) diff --git a/include/gmssl/tls.h b/include/gmssl/tls.h index 8266fd70..f29a6ddb 100644 --- a/include/gmssl/tls.h +++ b/include/gmssl/tls.h @@ -841,6 +841,8 @@ void tls_ctx_cleanup(TLS_CTX *ctx); enum { TLS_state_handshake_init = 0, TLS_state_client_hello, + TLS_state_early_data, + TLS_state_end_of_early_data, TLS_state_hello_retry_request, TLS_state_client_hello_again, TLS_state_server_hello, @@ -1033,6 +1035,8 @@ typedef struct { FILE *in_session; + uint8_t early_data_buf[8192]; + size_t early_data_len; } TLS_CONNECT; diff --git a/src/tls13.c b/src/tls13.c index cf10a14e..d7302ab1 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -315,6 +315,20 @@ int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t *s return 1; } +int tls13_send_early_data(TLS_CONNECT *conn) +{ + size_t sentlen; + + tls_trace("send EarlyData\n"); + + conn->early_data_len = 66; //xxx 先用来测试吧 + if (tls13_send(conn, conn->early_data_buf, conn->early_data_len, &sentlen) != 1) { + error_print(); + return -1; + } + return 1; +} + int tls13_do_recv(TLS_CONNECT *conn) { int ret; @@ -387,6 +401,21 @@ int tls13_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen) return 1; } +int tls13_recv_early_data(TLS_CONNECT *conn) +{ + tls_trace("recv EarlyData\n"); + + if (tls13_do_recv(conn) != 1) { + error_print(); + return -1; + } + memcpy(conn->early_data_buf, conn->data, conn->datalen); + conn->early_data_len = conn->datalen; + + format_bytes(stderr, 0, 4, "EarlyData", conn->early_data_buf, conn->early_data_len); + + return 1; +} /* @@ -1551,7 +1580,6 @@ int tls13_process_client_pre_shared_key(TLS_CONNECT *conn, const uint8_t *ext_da const uint8_t *truncated_binders = conn->plain_record + (binders - conn->record); size_t truncated_binderslen = binderslen; - while (truncated_binderslen) { const uint8_t *truncated_binder; size_t truncated_binderlen; @@ -1562,8 +1590,6 @@ int tls13_process_client_pre_shared_key(TLS_CONNECT *conn, const uint8_t *ext_da memset((uint8_t *)truncated_binder, 0, truncated_binderlen); } - - *selected_psk_identity = -1; for (i = 0; identitieslen; i++) { @@ -1664,6 +1690,9 @@ int tls13_process_client_pre_shared_key(TLS_CONNECT *conn, const uint8_t *ext_da error_print(); return 0; } + + + return 1; } @@ -3030,7 +3059,7 @@ int tls13_record_set_handshake_end_of_early_data(uint8_t *record, size_t *record return 1; } -int tls13_record_get_handshake_end_of_early_data(uint8_t *record, size_t *recordlen) +int tls13_record_get_handshake_end_of_early_data(uint8_t *record) { int type; const uint8_t *cp; @@ -3690,6 +3719,8 @@ int tls13_send_client_hello(TLS_CONNECT *conn) tls13_record_print(stderr, 0, 0, conn->record, conn->recordlen); conn->digest = DIGEST_sm3(); + conn->cipher = BLOCK_CIPHER_sm4(); + DIGEST_CTX tmp_dgst_ctx; @@ -3728,6 +3759,9 @@ int tls13_send_client_hello(TLS_CONNECT *conn) } conn->recordlen = 0; + + + conn->early_data = 1; } if (tls_record_set_handshake_client_hello(conn->record, &conn->recordlen, @@ -3738,9 +3772,39 @@ int tls13_send_client_hello(TLS_CONNECT *conn) error_print(); return -1; } - - tls13_record_print(stderr, 0, 0, conn->record, conn->recordlen); + + + + if (conn->early_data) { + + DIGEST_CTX early_dgst_ctx; + uint8_t client_early_traffic_secret[32]; + + digest_init(&early_dgst_ctx, conn->digest); + digest_update(&early_dgst_ctx, conn->record + 5, conn->recordlen - 5); + + uint8_t zeros[32] = {0}; + uint8_t early_secret[32]; + uint8_t client_write_key[16]; + + // [1] + tls13_hkdf_extract(conn->digest, zeros, conn->psk, early_secret); + // [2] + tls13_derive_secret(early_secret, "c e traffic", &early_dgst_ctx, client_early_traffic_secret); + + tls13_hkdf_expand_label(conn->digest, client_early_traffic_secret, "key", NULL, 0, 16, client_write_key); + block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, client_write_key); + tls13_hkdf_expand_label(conn->digest, client_early_traffic_secret, "iv", NULL, 0, 12, conn->client_write_iv); + tls_seq_num_reset(conn->client_seq_num); + + // client_early_traffic_secret 用来加密early_data, end_of_early_data + format_print(stderr, 0, 0, "client_write_key/iv <= client_early_traffic_secret\n"); + format_bytes(stderr, 0, 4, "client_early_traffic_secret", client_early_traffic_secret, 32); + format_bytes(stderr, 0, 4, "client_write_key", client_write_key, 16); + format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12); + + } } if (conn->client_certificate_verify) { @@ -4262,6 +4326,107 @@ int tls13_recv_server_hello(TLS_CONNECT *conn) return 1; } +int tls13_send_end_of_early_data(TLS_CONNECT *conn) +{ + int ret; + + if (!conn->recordlen) { + if (tls13_record_set_handshake_end_of_early_data(conn->plain_record, &conn->plain_recordlen) != 1) { + error_print(); + return -1; + } + tls_trace("send EndOfEarlyData\n"); + + format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12); + + size_t padding_len; + tls13_padding_len_rand(&padding_len); + if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv, + conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, + conn->record, &conn->recordlen) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + return -1; + } + + tls13_record_print(stderr, 0, 0, conn->record, conn->recordlen); + } + + if ((ret = tls_send_record(conn)) != 1) { + if (ret != TLS_ERROR_SEND_AGAIN) { + error_print(); + } + return ret; + } + + + uint8_t client_write_key[16]; + tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "key", NULL, 0, 16, client_write_key); + block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, client_write_key); + tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "iv", NULL, 0, 12, conn->client_write_iv); + tls_seq_num_reset(conn->client_seq_num); + + // client_early_traffic_secret 用来加密early_data, end_of_early_data + format_print(stderr, 0, 0, "client_write_key/iv <= client_handshake_traffic_secret\n"); + format_bytes(stderr, 0, 4, "client_early_traffic_secret", conn->client_handshake_traffic_secret, 32); + format_bytes(stderr, 0, 4, "client_write_key", client_write_key, 16); + format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12); + + + tls_clean_record(conn); + return 1; + +} + +int tls13_recv_end_of_early_data(TLS_CONNECT *conn) +{ + int ret; + tls_trace("recv {EndOfEarlyData}\n"); + + if ((ret = tls_recv_record(conn)) != 1) { + if (ret != TLS_ERROR_RECV_AGAIN) { + error_print(); + } + return ret; + } + + + format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12); + + if (tls13_record_decrypt(&conn->client_write_key, conn->client_write_iv, + conn->client_seq_num, conn->record, conn->recordlen, + conn->plain_record, &conn->plain_recordlen) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_bad_record_mac); + return -1; + } + tls13_record_print(stderr, 0, 0, conn->plain_record, conn->plain_recordlen); + + if ((ret = tls13_record_get_handshake_end_of_early_data(conn->plain_record)) < 0) { + error_print(); + tls_send_alert(conn, TLS_alert_decode_error); + return -1; + } else if (ret == 0) { + error_print(); + tls_send_alert(conn, TLS_alert_unexpected_message); + return -1; + } + + uint8_t client_write_key[16]; + tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "key", NULL, 0, 16, client_write_key); + block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, client_write_key); + tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "iv", NULL, 0, 12, conn->client_write_iv); + tls_seq_num_reset(conn->client_seq_num); + + format_print(stderr, 0, 0, "client_write_key/iv <= client_handshake_traffic_secret\n"); + format_bytes(stderr, 0, 4, "client_handshake_traffic_secret", conn->client_handshake_traffic_secret, 32); + format_bytes(stderr, 0, 4, "client_write_key", client_write_key, 16); + format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12); + + return 1; +} + + /* 0 @@ -4379,11 +4544,12 @@ int tls13_generate_keys(TLS_CONNECT *conn) tls13_hkdf_expand_label(conn->digest, conn->server_handshake_traffic_secret, "iv", NULL, 0, 12, conn->server_write_iv); tls_seq_num_reset(conn->server_seq_num); - - tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "key", NULL, 0, 16, client_write_key); - block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, client_write_key); - tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "iv", NULL, 0, 12, conn->client_write_iv); - tls_seq_num_reset(conn->client_seq_num); + if (!conn->early_data) { + tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "key", NULL, 0, 16, client_write_key); + block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, client_write_key); + tls13_hkdf_expand_label(conn->digest, conn->client_handshake_traffic_secret, "iv", NULL, 0, 12, conn->client_write_iv); + tls_seq_num_reset(conn->client_seq_num); + } format_bytes(stderr, 0, 4, "client_write_key", client_write_key, 16); format_bytes(stderr, 0, 4, "server_write_key", server_write_key, 16); @@ -5383,6 +5549,39 @@ int tls13_recv_client_hello(TLS_CONNECT *conn) // ServerHello依赖这个标志 } + + conn->early_data = 1; + + + + // 这部分应该单独放在一个括号里面 + // 这里是设置early_data的加密密钥,这个加密密钥几乎完全是用PSK生成的 + + DIGEST_CTX early_dgst_ctx; + uint8_t client_early_traffic_secret[32]; + + digest_init(&early_dgst_ctx, conn->digest); + digest_update(&early_dgst_ctx, conn->record + 5, conn->recordlen - 5); + + uint8_t zeros[32] = {0}; + uint8_t early_secret[32]; + uint8_t client_write_key[16]; + + // [1] + tls13_hkdf_extract(conn->digest, zeros, conn->psk, early_secret); + // [2] + tls13_derive_secret(early_secret, "c e traffic", &early_dgst_ctx, client_early_traffic_secret); + + tls13_hkdf_expand_label(conn->digest, client_early_traffic_secret, "key", NULL, 0, 16, client_write_key); + block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, client_write_key); + tls13_hkdf_expand_label(conn->digest, client_early_traffic_secret, "iv", NULL, 0, 12, conn->client_write_iv); + tls_seq_num_reset(conn->client_seq_num); + + format_print(stderr, 0, 0, "client_write_key/iv <= client_early_traffic_secret\n"); + format_bytes(stderr, 0, 4, "client_early_traffic_secret", client_early_traffic_secret, 32); + format_bytes(stderr, 0, 4, "client_write_key", client_write_key, 16); + format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12); + break; default: @@ -6725,6 +6924,29 @@ int tls13_recv_new_session_ticket(TLS_CONNECT *conn) send_client_certificate send_certificate_verify send_client_finished + + + +PSK模式 + + send_client_hello + recv_server_hello + recv_server_finished + send_client_finished + +0-RTT + + send_client_hello + send_application_data* + recv_server_hello + recv_server_finished + send_client_finished + + 客户端在发送完ClientHello之后 + 可以同时发送0-RTT数据和接收ServerHello + 也就是说这两个状态是可以并发的 + + */ @@ -6736,9 +6958,15 @@ int tls13_do_client_handshake(TLS_CONNECT *conn) switch (conn->state) { case TLS_state_client_hello: ret = tls13_send_client_hello(conn); - next_state = TLS_state_hello_retry_request; + if (conn->early_data) + next_state = TLS_state_early_data; + else next_state = TLS_state_hello_retry_request; break; + case TLS_state_early_data: + ret = tls13_send_early_data(conn); + next_state = TLS_state_hello_retry_request; + case TLS_state_hello_retry_request: // optional ret = tls13_recv_hello_retry_request(conn); if (conn->hello_retry_request) @@ -6785,11 +7013,18 @@ int tls13_do_client_handshake(TLS_CONNECT *conn) case TLS_state_server_finished: ret = tls13_recv_server_finished(conn); - if (conn->client_certificate_verify) + if (conn->early_data) + next_state = TLS_state_end_of_early_data; + else if (conn->client_certificate_verify) next_state = TLS_state_client_certificate; else next_state = TLS_state_client_finished; break; + case TLS_state_end_of_early_data: + ret = tls13_send_end_of_early_data(conn); + next_state = TLS_state_client_finished; + break; + case TLS_state_client_certificate: ret = tls13_send_client_certificate(conn); next_state = TLS_state_client_certificate_verify; @@ -6843,11 +7078,18 @@ int tls13_do_server_handshake(TLS_CONNECT *conn) switch (conn->state) { case TLS_state_client_hello: ret = tls13_recv_client_hello(conn); - if (conn->hello_retry_request) + if (conn->early_data) + next_state = TLS_state_early_data; + else if (conn->hello_retry_request) next_state = TLS_state_hello_retry_request; else next_state = TLS_state_server_hello; break; + case TLS_state_early_data: + ret = tls13_recv_early_data(conn); + next_state = TLS_state_server_hello; + break; + case TLS_state_hello_retry_request: ret = tls13_send_hello_retry_request(conn); next_state = TLS_state_client_hello_again; @@ -6893,11 +7135,18 @@ int tls13_do_server_handshake(TLS_CONNECT *conn) case TLS_state_server_finished: ret = tls13_send_server_finished(conn); - if (conn->certificate_request) + if (conn->early_data) + next_state = TLS_state_end_of_early_data; + else if (conn->certificate_request) next_state = TLS_state_client_certificate; else next_state = TLS_state_client_finished; break; + case TLS_state_end_of_early_data: + ret = tls13_recv_end_of_early_data(conn); + next_state = TLS_state_client_finished; + break; + case TLS_state_client_certificate: ret = tls13_recv_client_certificate(conn); next_state = TLS_state_client_certificate_verify;