From 6f42fdf31f1ff4e4a26af51120c8a8d6c5f579cf Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Fri, 12 Jun 2026 13:28:10 +0800 Subject: [PATCH] Update TLS state machine --- include/gmssl/tls.h | 7 +- src/tlcp.c | 16 ++--- src/tls12.c | 16 ++--- src/tls13.c | 159 ++++++++++++++++++------------------------- tools/tls13_client.c | 54 ++++++++++++++- tools/tls13_server.c | 54 ++++++++++++++- 6 files changed, 195 insertions(+), 111 deletions(-) diff --git a/include/gmssl/tls.h b/include/gmssl/tls.h index e15ec268..22b889d7 100644 --- a/include/gmssl/tls.h +++ b/include/gmssl/tls.h @@ -1004,6 +1004,7 @@ enum { TLS_state_handshake_over, + TLS_state_send_record, TLS_state_recv_record_header, TLS_state_recv_record_data, }; @@ -1017,8 +1018,10 @@ typedef struct { TLS_CTX *ctx; - // handshake state for state machine - int state; + // states for state machines + int handshake_state; + int send_state; + int recv_state; uint8_t record[TLS_MAX_RECORD_SIZE]; diff --git a/src/tlcp.c b/src/tlcp.c index 03cffcac..363d3289 100644 --- a/src/tlcp.c +++ b/src/tlcp.c @@ -2261,7 +2261,7 @@ int tlcp_do_client_handshake(TLS_CONNECT *conn) int ret; int next_state; - switch (conn->state) { + switch (conn->handshake_state) { case TLS_state_client_hello: ret = tlcp_send_client_hello(conn); next_state = TLS_state_server_hello; @@ -2346,7 +2346,7 @@ int tlcp_do_client_handshake(TLS_CONNECT *conn) } } - conn->state = next_state; + conn->handshake_state = next_state; // ret == 0 means this step is bypassed if (ret == 1) { @@ -2361,7 +2361,7 @@ int tlcp_do_server_handshake(TLS_CONNECT *conn) int ret; int next_state; - switch (conn->state) { + switch (conn->handshake_state) { case TLS_state_client_hello: ret = tlcp_recv_client_hello(conn); next_state = TLS_state_server_hello; @@ -2448,7 +2448,7 @@ int tlcp_do_server_handshake(TLS_CONNECT *conn) } - conn->state = next_state; + conn->handshake_state = next_state; tls_clean_record(conn); @@ -2459,7 +2459,7 @@ int tlcp_client_handshake(TLS_CONNECT *conn) { int ret; - while (conn->state != TLS_state_handshake_over) { + while (conn->handshake_state != TLS_state_handshake_over) { ret = tlcp_do_client_handshake(conn); @@ -2481,7 +2481,7 @@ int tlcp_server_handshake(TLS_CONNECT *conn) int ret; - while (conn->state != TLS_state_handshake_over) { + while (conn->handshake_state != TLS_state_handshake_over) { ret = tlcp_do_server_handshake(conn); @@ -2506,7 +2506,7 @@ int tlcp_do_connect(TLS_CONNECT *conn) // 应该把protocol_version的初始化放在这里 - conn->state = TLS_state_client_hello; + conn->handshake_state = TLS_state_client_hello; //sm3_init(&conn->sm3_ctx); while (1) { @@ -2542,7 +2542,7 @@ int tlcp_do_accept(TLS_CONNECT *conn) fd_set rfds; fd_set wfds; - conn->state = TLS_state_client_hello; + conn->handshake_state = TLS_state_client_hello; //sm3_init(&conn->sm3_ctx); diff --git a/src/tls12.c b/src/tls12.c index d9285080..5485544d 100644 --- a/src/tls12.c +++ b/src/tls12.c @@ -3536,7 +3536,7 @@ int tls12_do_client_handshake(TLS_CONNECT *conn) int ret; int next_state; - switch (conn->state) { + switch (conn->handshake_state) { case TLS_state_client_hello: ret = tls_send_client_hello(conn); next_state = TLS_state_server_hello; @@ -3626,7 +3626,7 @@ int tls12_do_client_handshake(TLS_CONNECT *conn) } } - conn->state = next_state; + conn->handshake_state = next_state; // ret == 0 means this step is bypassed if (ret == 1) { @@ -3641,7 +3641,7 @@ int tls12_do_server_handshake(TLS_CONNECT *conn) int ret; int next_state; - switch (conn->state) { + switch (conn->handshake_state) { case TLS_state_client_hello: ret = tls_recv_client_hello(conn); next_state = TLS_state_server_hello; @@ -3728,7 +3728,7 @@ int tls12_do_server_handshake(TLS_CONNECT *conn) } - conn->state = next_state; + conn->handshake_state = next_state; tls_clean_record(conn); @@ -3741,7 +3741,7 @@ int tls12_client_handshake(TLS_CONNECT *conn) { int ret; - while (conn->state != TLS_state_handshake_over) { + while (conn->handshake_state != TLS_state_handshake_over) { ret = tls12_do_client_handshake(conn); @@ -3763,7 +3763,7 @@ int tls12_server_handshake(TLS_CONNECT *conn) int ret; - while (conn->state != TLS_state_handshake_over) { + while (conn->handshake_state != TLS_state_handshake_over) { ret = tls12_do_server_handshake(conn); @@ -3786,7 +3786,7 @@ int tls12_do_connect(TLS_CONNECT *conn) fd_set rfds; fd_set wfds; - conn->state = TLS_state_client_hello; + conn->handshake_state = TLS_state_client_hello; //sm3_init(&conn->sm3_ctx); @@ -3825,7 +3825,7 @@ int tls12_do_accept(TLS_CONNECT *conn) fd_set rfds; fd_set wfds; - conn->state = TLS_state_client_hello; + conn->handshake_state = TLS_state_client_hello; //sm3_init(&conn->sm3_ctx); digest_init(&conn->dgst_ctx, DIGEST_sm3()); diff --git a/src/tls13.c b/src/tls13.c index cf1f8a98..043ce6ab 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -1091,6 +1091,7 @@ int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t *s conn->recordlen = 5 + record_datalen; conn->record_offset = 0; + conn->send_state = TLS_state_send_record; // 需要记录密文对应的明文是什么,当完整的报文发送之后,这些信息要返回给调用方 //conn->plain_recordlen = datalen + 5; @@ -1123,6 +1124,8 @@ int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t *s //*sentlen = conn->plain_recordlen - 5; *sentlen = conn->sentlen; + conn->record_offset = 0; + conn->send_state = 0; return ret; } @@ -1135,14 +1138,11 @@ int tls13_do_recv(TLS_CONNECT *conn) uint8_t *seq_num; - // 在接收EarlyData的时候,当前的状态有问题啊 - - switch (conn->state) { + switch (conn->recv_state) { case 0: - case TLS_state_early_data: conn->record_offset = 0; conn->recordlen = TLS_RECORD_HEADER_SIZE; - conn->state = TLS_state_recv_record_header; + conn->recv_state = TLS_state_recv_record_header; case TLS_state_recv_record_header: while (conn->recordlen) { @@ -1167,7 +1167,7 @@ int tls13_do_recv(TLS_CONNECT *conn) return -1; } conn->recordlen = tls_record_data_length(conn->record); - conn->state = TLS_state_recv_record_data; + conn->recv_state = TLS_state_recv_record_data; case TLS_state_recv_record_data: while (conn->recordlen) { @@ -1182,11 +1182,11 @@ int tls13_do_recv(TLS_CONNECT *conn) conn->recordlen -= n; conn->record_offset += n; } - conn->state = 0; + conn->recv_state = 0; break; default: - fprintf(stderr, "conn->state = %d\n", conn->state); + fprintf(stderr, "conn->recv_state = %d\n", conn->recv_state); error_print(); return -1; } @@ -1354,11 +1354,15 @@ int tls13_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen) // 这里需要考虑max_early_data_size的问题 int tls13_recv_early_data(TLS_CONNECT *conn) { + int ret; + tls_trace("recv EarlyData\n"); - if (tls13_do_recv(conn) != 1) { - error_print(); - return -1; + if ((ret = tls13_do_recv(conn)) != 1) { + if (ret != TLS_ERROR_RECV_AGAIN && ret != TLS_ERROR_SEND_AGAIN) { + error_print(); + } + return ret; } memcpy(conn->early_data_buf, conn->data, conn->datalen); conn->early_data_len = conn->datalen; @@ -7792,10 +7796,10 @@ int tls13_send_alert(TLS_CONNECT *conn, int alert) tls13_record_print(stderr, 0, 0, conn->plain_record, conn->plain_recordlen); - switch (conn->state) { - case TLS_handshake_client_hello: - case TLS_handshake_server_hello: - case TLS_handshake_hello_retry_request: + switch (conn->handshake_state) { + case TLS_state_client_hello: + case TLS_state_server_hello: + case TLS_state_hello_retry_request: tls_socket_send(conn->sock, conn->plain_record, conn->plain_recordlen, 0); break; default: @@ -8589,7 +8593,7 @@ int tls13_do_client_handshake(TLS_CONNECT *conn) int ret; int next_state; - switch (conn->state) { + switch (conn->handshake_state) { case TLS_state_hello_retry_request: case TLS_state_client_hello_again: case TLS_state_server_hello: @@ -8604,7 +8608,7 @@ int tls13_do_client_handshake(TLS_CONNECT *conn) break; } - switch (conn->state) { + switch (conn->handshake_state) { case TLS_state_client_hello: ret = tls13_send_client_hello(conn); next_state = TLS_state_hello_retry_request; @@ -8711,7 +8715,7 @@ int tls13_do_client_handshake(TLS_CONNECT *conn) } } - conn->state = next_state; + conn->handshake_state = next_state; // ret == 0 means this step is bypassed if (ret == 1) { @@ -8727,7 +8731,7 @@ int tls13_do_server_handshake(TLS_CONNECT *conn) int ret; int next_state; - switch (conn->state) { + switch (conn->handshake_state) { case TLS_state_client_hello: ret = tls13_recv_client_hello(conn); if (conn->early_data) @@ -8854,7 +8858,7 @@ int tls13_do_server_handshake(TLS_CONNECT *conn) } - conn->state = next_state; + conn->handshake_state = next_state; tls_clean_record(conn); @@ -8865,7 +8869,7 @@ int tls13_client_handshake(TLS_CONNECT *conn) { int ret; - while (conn->state != TLS_state_handshake_over) { + while (conn->handshake_state != TLS_state_handshake_over) { ret = tls13_do_client_handshake(conn); @@ -8879,8 +8883,6 @@ int tls13_client_handshake(TLS_CONNECT *conn) } - conn->state = 0; - // TODO: cleanup conn? return 1; @@ -8890,7 +8892,7 @@ int tls13_server_handshake(TLS_CONNECT *conn) { int ret; - while (conn->state != TLS_state_handshake_over) { + while (conn->handshake_state != TLS_state_handshake_over) { ret = tls13_do_server_handshake(conn); @@ -8902,8 +8904,6 @@ int tls13_server_handshake(TLS_CONNECT *conn) } } - conn->state = 0; - // TODO: cleanup conn? return 1; @@ -8912,82 +8912,59 @@ int tls13_server_handshake(TLS_CONNECT *conn) int tls13_do_connect(TLS_CONNECT *conn) { int ret; - fd_set rfds; - fd_set wfds; - // 应该把protocol_version的初始化放在这里 - - conn->state = TLS_state_client_hello; - //sm3_init(&conn->sm3_ctx); - - while (1) { - - ret = tls13_client_handshake(conn); - if (ret == 1) { - break; - - } else if (ret == TLS_ERROR_SEND_AGAIN) { - FD_ZERO(&rfds); - FD_ZERO(&wfds); - FD_SET(conn->sock, &rfds); - select(conn->sock + 1, &rfds, &wfds, NULL, NULL); - - } else if (ret == TLS_ERROR_RECV_AGAIN) { - FD_ZERO(&rfds); - FD_ZERO(&wfds); - FD_SET(conn->sock, &wfds); - select(conn->sock + 1, &rfds, &wfds, NULL, NULL); - - } else { - error_print(); - return -1; - } + if (!conn || !conn->is_client) { + error_print(); + return -1; } - fprintf(stderr, "tls13_do_connect: connected\n"); + if (conn->handshake_state == TLS_state_handshake_over) { + return 1; + } - return 1; + if (conn->handshake_state == TLS_state_handshake_init) { + conn->handshake_state = TLS_state_client_hello; + } + + ret = tls13_client_handshake(conn); + if (ret == 1) { + conn->handshake_state = TLS_state_handshake_over; + return 1; + } + if (ret == TLS_ERROR_RECV_AGAIN || ret == TLS_ERROR_SEND_AGAIN) { + return ret; + } + + error_print(); + return -1; } int tls13_do_accept(TLS_CONNECT *conn) { int ret; - fd_set rfds; - fd_set wfds; - conn->state = TLS_state_client_hello; - - //sm3_init(&conn->sm3_ctx); - - fprintf(stderr, "tls13_do_accept\n"); - - - while (1) { - - ret = tls13_server_handshake(conn); - - if (ret == 1) { - break; - - } else if (ret == TLS_ERROR_SEND_AGAIN) { - FD_ZERO(&rfds); - FD_ZERO(&wfds); - FD_SET(conn->sock, &rfds); - select(conn->sock + 1, &rfds, &wfds, NULL, NULL); - - } else if (ret == TLS_ERROR_RECV_AGAIN) { - FD_ZERO(&rfds); - FD_ZERO(&wfds); - FD_SET(conn->sock, &wfds); - select(conn->sock + 1, &rfds, &wfds, NULL, NULL); - - } else { - error_print(); - return -1; - } + if (!conn || conn->is_client) { + error_print(); + return -1; } - fprintf(stderr, "tls13_do_accept: connected\n"); + if (conn->handshake_state == TLS_state_handshake_over) { + return 1; + } - return 1; + if (conn->handshake_state == TLS_state_handshake_init) { + conn->handshake_state = TLS_state_client_hello; + } + + ret = tls13_server_handshake(conn); + if (ret == 1) { + conn->handshake_state = TLS_state_handshake_over; + return 1; + } + if (ret == TLS_ERROR_RECV_AGAIN || ret == TLS_ERROR_SEND_AGAIN) { + return ret; + } + + error_print(); + return -1; } diff --git a/tools/tls13_client.c b/tools/tls13_client.c index b74c20ed..2c882f57 100644 --- a/tools/tls13_client.c +++ b/tools/tls13_client.c @@ -23,6 +23,53 @@ #endif +static int set_socket_nonblocking(tls_socket_t sock) +{ +#ifdef WIN32 + u_long mode = 1; + if (ioctlsocket(sock, FIONBIO, &mode) != 0) { + error_print(); + return -1; + } +#else + int flags; + if ((flags = fcntl(sock, F_GETFL)) < 0 + || fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) { + error_print(); + return -1; + } +#endif + return 1; +} + +static int do_handshake_select(TLS_CONNECT *conn) +{ + int ret; + fd_set rfds; + fd_set wfds; + + for (;;) { + ret = tls_do_handshake(conn); + if (ret == 1) { + return 1; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } +} + static const char *http_get = "GET / HTTP/1.1\r\n" "Hostname: aaa\r\n" @@ -629,7 +676,12 @@ bad: goto end; } - if (tls_do_handshake(&conn) != 1) { + if (set_socket_nonblocking(sock) != 1) { + error_print(); + goto end; + } + + if (do_handshake_select(&conn) != 1) { fprintf(stderr, "%s: error\n", prog); goto end; } diff --git a/tools/tls13_server.c b/tools/tls13_server.c index 599efee6..68225780 100644 --- a/tools/tls13_server.c +++ b/tools/tls13_server.c @@ -52,6 +52,53 @@ static const char *help = "\n"; +static int set_socket_nonblocking(tls_socket_t sock) +{ +#ifdef WIN32 + u_long mode = 1; + if (ioctlsocket(sock, FIONBIO, &mode) != 0) { + error_print(); + return -1; + } +#else + int flags; + if ((flags = fcntl(sock, F_GETFL)) < 0 + || fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) { + error_print(); + return -1; + } +#endif + return 1; +} + +static int do_handshake_select(TLS_CONNECT *conn) +{ + int ret; + fd_set rfds; + fd_set wfds; + + for (;;) { + ret = tls_do_handshake(conn); + if (ret == 1) { + return 1; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } +} + int tls13_server_main(int argc , char **argv) { @@ -468,6 +515,11 @@ bad: goto end; } + if (set_socket_nonblocking(conn_sock) != 1) { + error_print(); + goto end; + } + // pre_shared_key from external if (psk_keys_cnt) { @@ -511,7 +563,7 @@ bad: tls13_enable_pre_shared_key(&conn, 1); } - if (tls_do_handshake(&conn) != 1) { + if (do_handshake_select(&conn) != 1) { error_print(); goto end; }