Update TLS state machine

This commit is contained in:
Zhi Guan
2026-06-12 13:28:10 +08:00
parent fb93fba5ff
commit 6f42fdf31f
6 changed files with 195 additions and 111 deletions

View File

@@ -1004,6 +1004,7 @@ enum {
TLS_state_handshake_over, TLS_state_handshake_over,
TLS_state_send_record,
TLS_state_recv_record_header, TLS_state_recv_record_header,
TLS_state_recv_record_data, TLS_state_recv_record_data,
}; };
@@ -1017,8 +1018,10 @@ typedef struct {
TLS_CTX *ctx; TLS_CTX *ctx;
// handshake state for state machine // states for state machines
int state; int handshake_state;
int send_state;
int recv_state;
uint8_t record[TLS_MAX_RECORD_SIZE]; uint8_t record[TLS_MAX_RECORD_SIZE];

View File

@@ -2261,7 +2261,7 @@ int tlcp_do_client_handshake(TLS_CONNECT *conn)
int ret; int ret;
int next_state; int next_state;
switch (conn->state) { switch (conn->handshake_state) {
case TLS_state_client_hello: case TLS_state_client_hello:
ret = tlcp_send_client_hello(conn); ret = tlcp_send_client_hello(conn);
next_state = TLS_state_server_hello; next_state = TLS_state_server_hello;
@@ -2346,7 +2346,7 @@ int tlcp_do_client_handshake(TLS_CONNECT *conn)
} }
} }
conn->state = next_state; conn->handshake_state = next_state;
// ret == 0 means this step is bypassed // ret == 0 means this step is bypassed
if (ret == 1) { if (ret == 1) {
@@ -2361,7 +2361,7 @@ int tlcp_do_server_handshake(TLS_CONNECT *conn)
int ret; int ret;
int next_state; int next_state;
switch (conn->state) { switch (conn->handshake_state) {
case TLS_state_client_hello: case TLS_state_client_hello:
ret = tlcp_recv_client_hello(conn); ret = tlcp_recv_client_hello(conn);
next_state = TLS_state_server_hello; next_state = TLS_state_server_hello;
@@ -2448,7 +2448,7 @@ int tlcp_do_server_handshake(TLS_CONNECT *conn)
} }
conn->state = next_state; conn->handshake_state = next_state;
tls_clean_record(conn); tls_clean_record(conn);
@@ -2459,7 +2459,7 @@ int tlcp_client_handshake(TLS_CONNECT *conn)
{ {
int ret; int ret;
while (conn->state != TLS_state_handshake_over) { while (conn->handshake_state != TLS_state_handshake_over) {
ret = tlcp_do_client_handshake(conn); ret = tlcp_do_client_handshake(conn);
@@ -2481,7 +2481,7 @@ int tlcp_server_handshake(TLS_CONNECT *conn)
int ret; int ret;
while (conn->state != TLS_state_handshake_over) { while (conn->handshake_state != TLS_state_handshake_over) {
ret = tlcp_do_server_handshake(conn); ret = tlcp_do_server_handshake(conn);
@@ -2506,7 +2506,7 @@ int tlcp_do_connect(TLS_CONNECT *conn)
// 应该把protocol_version的初始化放在这里 // 应该把protocol_version的初始化放在这里
conn->state = TLS_state_client_hello; conn->handshake_state = TLS_state_client_hello;
//sm3_init(&conn->sm3_ctx); //sm3_init(&conn->sm3_ctx);
while (1) { while (1) {
@@ -2542,7 +2542,7 @@ int tlcp_do_accept(TLS_CONNECT *conn)
fd_set rfds; fd_set rfds;
fd_set wfds; fd_set wfds;
conn->state = TLS_state_client_hello; conn->handshake_state = TLS_state_client_hello;
//sm3_init(&conn->sm3_ctx); //sm3_init(&conn->sm3_ctx);

View File

@@ -3536,7 +3536,7 @@ int tls12_do_client_handshake(TLS_CONNECT *conn)
int ret; int ret;
int next_state; int next_state;
switch (conn->state) { switch (conn->handshake_state) {
case TLS_state_client_hello: case TLS_state_client_hello:
ret = tls_send_client_hello(conn); ret = tls_send_client_hello(conn);
next_state = TLS_state_server_hello; next_state = TLS_state_server_hello;
@@ -3626,7 +3626,7 @@ int tls12_do_client_handshake(TLS_CONNECT *conn)
} }
} }
conn->state = next_state; conn->handshake_state = next_state;
// ret == 0 means this step is bypassed // ret == 0 means this step is bypassed
if (ret == 1) { if (ret == 1) {
@@ -3641,7 +3641,7 @@ int tls12_do_server_handshake(TLS_CONNECT *conn)
int ret; int ret;
int next_state; int next_state;
switch (conn->state) { switch (conn->handshake_state) {
case TLS_state_client_hello: case TLS_state_client_hello:
ret = tls_recv_client_hello(conn); ret = tls_recv_client_hello(conn);
next_state = TLS_state_server_hello; next_state = TLS_state_server_hello;
@@ -3728,7 +3728,7 @@ int tls12_do_server_handshake(TLS_CONNECT *conn)
} }
conn->state = next_state; conn->handshake_state = next_state;
tls_clean_record(conn); tls_clean_record(conn);
@@ -3741,7 +3741,7 @@ int tls12_client_handshake(TLS_CONNECT *conn)
{ {
int ret; int ret;
while (conn->state != TLS_state_handshake_over) { while (conn->handshake_state != TLS_state_handshake_over) {
ret = tls12_do_client_handshake(conn); ret = tls12_do_client_handshake(conn);
@@ -3763,7 +3763,7 @@ int tls12_server_handshake(TLS_CONNECT *conn)
int ret; int ret;
while (conn->state != TLS_state_handshake_over) { while (conn->handshake_state != TLS_state_handshake_over) {
ret = tls12_do_server_handshake(conn); ret = tls12_do_server_handshake(conn);
@@ -3786,7 +3786,7 @@ int tls12_do_connect(TLS_CONNECT *conn)
fd_set rfds; fd_set rfds;
fd_set wfds; fd_set wfds;
conn->state = TLS_state_client_hello; conn->handshake_state = TLS_state_client_hello;
//sm3_init(&conn->sm3_ctx); //sm3_init(&conn->sm3_ctx);
@@ -3825,7 +3825,7 @@ int tls12_do_accept(TLS_CONNECT *conn)
fd_set rfds; fd_set rfds;
fd_set wfds; fd_set wfds;
conn->state = TLS_state_client_hello; conn->handshake_state = TLS_state_client_hello;
//sm3_init(&conn->sm3_ctx); //sm3_init(&conn->sm3_ctx);
digest_init(&conn->dgst_ctx, DIGEST_sm3()); digest_init(&conn->dgst_ctx, DIGEST_sm3());

View File

@@ -1091,6 +1091,7 @@ int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t *s
conn->recordlen = 5 + record_datalen; conn->recordlen = 5 + record_datalen;
conn->record_offset = 0; conn->record_offset = 0;
conn->send_state = TLS_state_send_record;
// 需要记录密文对应的明文是什么,当完整的报文发送之后,这些信息要返回给调用方 // 需要记录密文对应的明文是什么,当完整的报文发送之后,这些信息要返回给调用方
//conn->plain_recordlen = datalen + 5; //conn->plain_recordlen = datalen + 5;
@@ -1123,6 +1124,8 @@ int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t *s
//*sentlen = conn->plain_recordlen - 5; //*sentlen = conn->plain_recordlen - 5;
*sentlen = conn->sentlen; *sentlen = conn->sentlen;
conn->record_offset = 0;
conn->send_state = 0;
return ret; return ret;
} }
@@ -1135,14 +1138,11 @@ int tls13_do_recv(TLS_CONNECT *conn)
uint8_t *seq_num; uint8_t *seq_num;
// 在接收EarlyData的时候当前的状态有问题啊 switch (conn->recv_state) {
switch (conn->state) {
case 0: case 0:
case TLS_state_early_data:
conn->record_offset = 0; conn->record_offset = 0;
conn->recordlen = TLS_RECORD_HEADER_SIZE; conn->recordlen = TLS_RECORD_HEADER_SIZE;
conn->state = TLS_state_recv_record_header; conn->recv_state = TLS_state_recv_record_header;
case TLS_state_recv_record_header: case TLS_state_recv_record_header:
while (conn->recordlen) { while (conn->recordlen) {
@@ -1167,7 +1167,7 @@ int tls13_do_recv(TLS_CONNECT *conn)
return -1; return -1;
} }
conn->recordlen = tls_record_data_length(conn->record); conn->recordlen = tls_record_data_length(conn->record);
conn->state = TLS_state_recv_record_data; conn->recv_state = TLS_state_recv_record_data;
case TLS_state_recv_record_data: case TLS_state_recv_record_data:
while (conn->recordlen) { while (conn->recordlen) {
@@ -1182,11 +1182,11 @@ int tls13_do_recv(TLS_CONNECT *conn)
conn->recordlen -= n; conn->recordlen -= n;
conn->record_offset += n; conn->record_offset += n;
} }
conn->state = 0; conn->recv_state = 0;
break; break;
default: default:
fprintf(stderr, "conn->state = %d\n", conn->state); fprintf(stderr, "conn->recv_state = %d\n", conn->recv_state);
error_print(); error_print();
return -1; return -1;
} }
@@ -1354,11 +1354,15 @@ int tls13_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen)
// 这里需要考虑max_early_data_size的问题 // 这里需要考虑max_early_data_size的问题
int tls13_recv_early_data(TLS_CONNECT *conn) int tls13_recv_early_data(TLS_CONNECT *conn)
{ {
int ret;
tls_trace("recv EarlyData\n"); tls_trace("recv EarlyData\n");
if (tls13_do_recv(conn) != 1) { if ((ret = tls13_do_recv(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN && ret != TLS_ERROR_SEND_AGAIN) {
error_print(); error_print();
return -1; }
return ret;
} }
memcpy(conn->early_data_buf, conn->data, conn->datalen); memcpy(conn->early_data_buf, conn->data, conn->datalen);
conn->early_data_len = conn->datalen; conn->early_data_len = conn->datalen;
@@ -7792,10 +7796,10 @@ int tls13_send_alert(TLS_CONNECT *conn, int alert)
tls13_record_print(stderr, 0, 0, conn->plain_record, conn->plain_recordlen); tls13_record_print(stderr, 0, 0, conn->plain_record, conn->plain_recordlen);
switch (conn->state) { switch (conn->handshake_state) {
case TLS_handshake_client_hello: case TLS_state_client_hello:
case TLS_handshake_server_hello: case TLS_state_server_hello:
case TLS_handshake_hello_retry_request: case TLS_state_hello_retry_request:
tls_socket_send(conn->sock, conn->plain_record, conn->plain_recordlen, 0); tls_socket_send(conn->sock, conn->plain_record, conn->plain_recordlen, 0);
break; break;
default: default:
@@ -8589,7 +8593,7 @@ int tls13_do_client_handshake(TLS_CONNECT *conn)
int ret; int ret;
int next_state; int next_state;
switch (conn->state) { switch (conn->handshake_state) {
case TLS_state_hello_retry_request: case TLS_state_hello_retry_request:
case TLS_state_client_hello_again: case TLS_state_client_hello_again:
case TLS_state_server_hello: case TLS_state_server_hello:
@@ -8604,7 +8608,7 @@ int tls13_do_client_handshake(TLS_CONNECT *conn)
break; break;
} }
switch (conn->state) { switch (conn->handshake_state) {
case TLS_state_client_hello: case TLS_state_client_hello:
ret = tls13_send_client_hello(conn); ret = tls13_send_client_hello(conn);
next_state = TLS_state_hello_retry_request; next_state = TLS_state_hello_retry_request;
@@ -8711,7 +8715,7 @@ int tls13_do_client_handshake(TLS_CONNECT *conn)
} }
} }
conn->state = next_state; conn->handshake_state = next_state;
// ret == 0 means this step is bypassed // ret == 0 means this step is bypassed
if (ret == 1) { if (ret == 1) {
@@ -8727,7 +8731,7 @@ int tls13_do_server_handshake(TLS_CONNECT *conn)
int ret; int ret;
int next_state; int next_state;
switch (conn->state) { switch (conn->handshake_state) {
case TLS_state_client_hello: case TLS_state_client_hello:
ret = tls13_recv_client_hello(conn); ret = tls13_recv_client_hello(conn);
if (conn->early_data) if (conn->early_data)
@@ -8854,7 +8858,7 @@ int tls13_do_server_handshake(TLS_CONNECT *conn)
} }
conn->state = next_state; conn->handshake_state = next_state;
tls_clean_record(conn); tls_clean_record(conn);
@@ -8865,7 +8869,7 @@ int tls13_client_handshake(TLS_CONNECT *conn)
{ {
int ret; int ret;
while (conn->state != TLS_state_handshake_over) { while (conn->handshake_state != TLS_state_handshake_over) {
ret = tls13_do_client_handshake(conn); ret = tls13_do_client_handshake(conn);
@@ -8879,8 +8883,6 @@ int tls13_client_handshake(TLS_CONNECT *conn)
} }
conn->state = 0;
// TODO: cleanup conn? // TODO: cleanup conn?
return 1; return 1;
@@ -8890,7 +8892,7 @@ int tls13_server_handshake(TLS_CONNECT *conn)
{ {
int ret; int ret;
while (conn->state != TLS_state_handshake_over) { while (conn->handshake_state != TLS_state_handshake_over) {
ret = tls13_do_server_handshake(conn); ret = tls13_do_server_handshake(conn);
@@ -8902,8 +8904,6 @@ int tls13_server_handshake(TLS_CONNECT *conn)
} }
} }
conn->state = 0;
// TODO: cleanup conn? // TODO: cleanup conn?
return 1; return 1;
@@ -8912,82 +8912,59 @@ int tls13_server_handshake(TLS_CONNECT *conn)
int tls13_do_connect(TLS_CONNECT *conn) int tls13_do_connect(TLS_CONNECT *conn)
{ {
int ret; int ret;
fd_set rfds;
fd_set wfds;
// 应该把protocol_version的初始化放在这里 if (!conn || !conn->is_client) {
conn->state = TLS_state_client_hello;
//sm3_init(&conn->sm3_ctx);
while (1) {
ret = tls13_client_handshake(conn);
if (ret == 1) {
break;
} else if (ret == TLS_ERROR_SEND_AGAIN) {
FD_ZERO(&rfds);
FD_ZERO(&wfds);
FD_SET(conn->sock, &rfds);
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
} else if (ret == TLS_ERROR_RECV_AGAIN) {
FD_ZERO(&rfds);
FD_ZERO(&wfds);
FD_SET(conn->sock, &wfds);
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
} else {
error_print(); error_print();
return -1; return -1;
} }
if (conn->handshake_state == TLS_state_handshake_over) {
return 1;
} }
fprintf(stderr, "tls13_do_connect: connected\n"); if (conn->handshake_state == TLS_state_handshake_init) {
conn->handshake_state = TLS_state_client_hello;
}
ret = tls13_client_handshake(conn);
if (ret == 1) {
conn->handshake_state = TLS_state_handshake_over;
return 1; return 1;
}
if (ret == TLS_ERROR_RECV_AGAIN || ret == TLS_ERROR_SEND_AGAIN) {
return ret;
}
error_print();
return -1;
} }
int tls13_do_accept(TLS_CONNECT *conn) int tls13_do_accept(TLS_CONNECT *conn)
{ {
int ret; int ret;
fd_set rfds;
fd_set wfds;
conn->state = TLS_state_client_hello; if (!conn || conn->is_client) {
//sm3_init(&conn->sm3_ctx);
fprintf(stderr, "tls13_do_accept\n");
while (1) {
ret = tls13_server_handshake(conn);
if (ret == 1) {
break;
} else if (ret == TLS_ERROR_SEND_AGAIN) {
FD_ZERO(&rfds);
FD_ZERO(&wfds);
FD_SET(conn->sock, &rfds);
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
} else if (ret == TLS_ERROR_RECV_AGAIN) {
FD_ZERO(&rfds);
FD_ZERO(&wfds);
FD_SET(conn->sock, &wfds);
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
} else {
error_print(); error_print();
return -1; return -1;
} }
if (conn->handshake_state == TLS_state_handshake_over) {
return 1;
} }
fprintf(stderr, "tls13_do_accept: connected\n"); if (conn->handshake_state == TLS_state_handshake_init) {
conn->handshake_state = TLS_state_client_hello;
}
ret = tls13_server_handshake(conn);
if (ret == 1) {
conn->handshake_state = TLS_state_handshake_over;
return 1; return 1;
}
if (ret == TLS_ERROR_RECV_AGAIN || ret == TLS_ERROR_SEND_AGAIN) {
return ret;
}
error_print();
return -1;
} }

View File

@@ -23,6 +23,53 @@
#endif #endif
static int set_socket_nonblocking(tls_socket_t sock)
{
#ifdef WIN32
u_long mode = 1;
if (ioctlsocket(sock, FIONBIO, &mode) != 0) {
error_print();
return -1;
}
#else
int flags;
if ((flags = fcntl(sock, F_GETFL)) < 0
|| fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) {
error_print();
return -1;
}
#endif
return 1;
}
static int do_handshake_select(TLS_CONNECT *conn)
{
int ret;
fd_set rfds;
fd_set wfds;
for (;;) {
ret = tls_do_handshake(conn);
if (ret == 1) {
return 1;
}
FD_ZERO(&rfds);
FD_ZERO(&wfds);
if (ret == TLS_ERROR_RECV_AGAIN) {
FD_SET(conn->sock, &rfds);
} else if (ret == TLS_ERROR_SEND_AGAIN) {
FD_SET(conn->sock, &wfds);
} else {
error_print();
return -1;
}
if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) {
error_print();
return -1;
}
}
}
static const char *http_get = static const char *http_get =
"GET / HTTP/1.1\r\n" "GET / HTTP/1.1\r\n"
"Hostname: aaa\r\n" "Hostname: aaa\r\n"
@@ -629,7 +676,12 @@ bad:
goto end; goto end;
} }
if (tls_do_handshake(&conn) != 1) { if (set_socket_nonblocking(sock) != 1) {
error_print();
goto end;
}
if (do_handshake_select(&conn) != 1) {
fprintf(stderr, "%s: error\n", prog); fprintf(stderr, "%s: error\n", prog);
goto end; goto end;
} }

View File

@@ -52,6 +52,53 @@ static const char *help =
"\n"; "\n";
static int set_socket_nonblocking(tls_socket_t sock)
{
#ifdef WIN32
u_long mode = 1;
if (ioctlsocket(sock, FIONBIO, &mode) != 0) {
error_print();
return -1;
}
#else
int flags;
if ((flags = fcntl(sock, F_GETFL)) < 0
|| fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) {
error_print();
return -1;
}
#endif
return 1;
}
static int do_handshake_select(TLS_CONNECT *conn)
{
int ret;
fd_set rfds;
fd_set wfds;
for (;;) {
ret = tls_do_handshake(conn);
if (ret == 1) {
return 1;
}
FD_ZERO(&rfds);
FD_ZERO(&wfds);
if (ret == TLS_ERROR_RECV_AGAIN) {
FD_SET(conn->sock, &rfds);
} else if (ret == TLS_ERROR_SEND_AGAIN) {
FD_SET(conn->sock, &wfds);
} else {
error_print();
return -1;
}
if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) {
error_print();
return -1;
}
}
}
int tls13_server_main(int argc , char **argv) int tls13_server_main(int argc , char **argv)
{ {
@@ -468,6 +515,11 @@ bad:
goto end; goto end;
} }
if (set_socket_nonblocking(conn_sock) != 1) {
error_print();
goto end;
}
// pre_shared_key from external // pre_shared_key from external
if (psk_keys_cnt) { if (psk_keys_cnt) {
@@ -511,7 +563,7 @@ bad:
tls13_enable_pre_shared_key(&conn, 1); tls13_enable_pre_shared_key(&conn, 1);
} }
if (tls_do_handshake(&conn) != 1) { if (do_handshake_select(&conn) != 1) {
error_print(); error_print();
goto end; goto end;
} }