Update TLS state machine

This commit is contained in:
Zhi Guan
2026-06-12 13:28:10 +08:00
parent fb93fba5ff
commit 6f42fdf31f
6 changed files with 195 additions and 111 deletions

View File

@@ -1004,6 +1004,7 @@ enum {
TLS_state_handshake_over,
TLS_state_send_record,
TLS_state_recv_record_header,
TLS_state_recv_record_data,
};
@@ -1017,8 +1018,10 @@ typedef struct {
TLS_CTX *ctx;
// handshake state for state machine
int state;
// states for state machines
int handshake_state;
int send_state;
int recv_state;
uint8_t record[TLS_MAX_RECORD_SIZE];

View File

@@ -2261,7 +2261,7 @@ int tlcp_do_client_handshake(TLS_CONNECT *conn)
int ret;
int next_state;
switch (conn->state) {
switch (conn->handshake_state) {
case TLS_state_client_hello:
ret = tlcp_send_client_hello(conn);
next_state = TLS_state_server_hello;
@@ -2346,7 +2346,7 @@ int tlcp_do_client_handshake(TLS_CONNECT *conn)
}
}
conn->state = next_state;
conn->handshake_state = next_state;
// ret == 0 means this step is bypassed
if (ret == 1) {
@@ -2361,7 +2361,7 @@ int tlcp_do_server_handshake(TLS_CONNECT *conn)
int ret;
int next_state;
switch (conn->state) {
switch (conn->handshake_state) {
case TLS_state_client_hello:
ret = tlcp_recv_client_hello(conn);
next_state = TLS_state_server_hello;
@@ -2448,7 +2448,7 @@ int tlcp_do_server_handshake(TLS_CONNECT *conn)
}
conn->state = next_state;
conn->handshake_state = next_state;
tls_clean_record(conn);
@@ -2459,7 +2459,7 @@ int tlcp_client_handshake(TLS_CONNECT *conn)
{
int ret;
while (conn->state != TLS_state_handshake_over) {
while (conn->handshake_state != TLS_state_handshake_over) {
ret = tlcp_do_client_handshake(conn);
@@ -2481,7 +2481,7 @@ int tlcp_server_handshake(TLS_CONNECT *conn)
int ret;
while (conn->state != TLS_state_handshake_over) {
while (conn->handshake_state != TLS_state_handshake_over) {
ret = tlcp_do_server_handshake(conn);
@@ -2506,7 +2506,7 @@ int tlcp_do_connect(TLS_CONNECT *conn)
// 应该把protocol_version的初始化放在这里
conn->state = TLS_state_client_hello;
conn->handshake_state = TLS_state_client_hello;
//sm3_init(&conn->sm3_ctx);
while (1) {
@@ -2542,7 +2542,7 @@ int tlcp_do_accept(TLS_CONNECT *conn)
fd_set rfds;
fd_set wfds;
conn->state = TLS_state_client_hello;
conn->handshake_state = TLS_state_client_hello;
//sm3_init(&conn->sm3_ctx);

View File

@@ -3536,7 +3536,7 @@ int tls12_do_client_handshake(TLS_CONNECT *conn)
int ret;
int next_state;
switch (conn->state) {
switch (conn->handshake_state) {
case TLS_state_client_hello:
ret = tls_send_client_hello(conn);
next_state = TLS_state_server_hello;
@@ -3626,7 +3626,7 @@ int tls12_do_client_handshake(TLS_CONNECT *conn)
}
}
conn->state = next_state;
conn->handshake_state = next_state;
// ret == 0 means this step is bypassed
if (ret == 1) {
@@ -3641,7 +3641,7 @@ int tls12_do_server_handshake(TLS_CONNECT *conn)
int ret;
int next_state;
switch (conn->state) {
switch (conn->handshake_state) {
case TLS_state_client_hello:
ret = tls_recv_client_hello(conn);
next_state = TLS_state_server_hello;
@@ -3728,7 +3728,7 @@ int tls12_do_server_handshake(TLS_CONNECT *conn)
}
conn->state = next_state;
conn->handshake_state = next_state;
tls_clean_record(conn);
@@ -3741,7 +3741,7 @@ int tls12_client_handshake(TLS_CONNECT *conn)
{
int ret;
while (conn->state != TLS_state_handshake_over) {
while (conn->handshake_state != TLS_state_handshake_over) {
ret = tls12_do_client_handshake(conn);
@@ -3763,7 +3763,7 @@ int tls12_server_handshake(TLS_CONNECT *conn)
int ret;
while (conn->state != TLS_state_handshake_over) {
while (conn->handshake_state != TLS_state_handshake_over) {
ret = tls12_do_server_handshake(conn);
@@ -3786,7 +3786,7 @@ int tls12_do_connect(TLS_CONNECT *conn)
fd_set rfds;
fd_set wfds;
conn->state = TLS_state_client_hello;
conn->handshake_state = TLS_state_client_hello;
//sm3_init(&conn->sm3_ctx);
@@ -3825,7 +3825,7 @@ int tls12_do_accept(TLS_CONNECT *conn)
fd_set rfds;
fd_set wfds;
conn->state = TLS_state_client_hello;
conn->handshake_state = TLS_state_client_hello;
//sm3_init(&conn->sm3_ctx);
digest_init(&conn->dgst_ctx, DIGEST_sm3());

View File

@@ -1091,6 +1091,7 @@ int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t *s
conn->recordlen = 5 + record_datalen;
conn->record_offset = 0;
conn->send_state = TLS_state_send_record;
// 需要记录密文对应的明文是什么,当完整的报文发送之后,这些信息要返回给调用方
//conn->plain_recordlen = datalen + 5;
@@ -1123,6 +1124,8 @@ int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t *s
//*sentlen = conn->plain_recordlen - 5;
*sentlen = conn->sentlen;
conn->record_offset = 0;
conn->send_state = 0;
return ret;
}
@@ -1135,14 +1138,11 @@ int tls13_do_recv(TLS_CONNECT *conn)
uint8_t *seq_num;
// 在接收EarlyData的时候当前的状态有问题啊
switch (conn->state) {
switch (conn->recv_state) {
case 0:
case TLS_state_early_data:
conn->record_offset = 0;
conn->recordlen = TLS_RECORD_HEADER_SIZE;
conn->state = TLS_state_recv_record_header;
conn->recv_state = TLS_state_recv_record_header;
case TLS_state_recv_record_header:
while (conn->recordlen) {
@@ -1167,7 +1167,7 @@ int tls13_do_recv(TLS_CONNECT *conn)
return -1;
}
conn->recordlen = tls_record_data_length(conn->record);
conn->state = TLS_state_recv_record_data;
conn->recv_state = TLS_state_recv_record_data;
case TLS_state_recv_record_data:
while (conn->recordlen) {
@@ -1182,11 +1182,11 @@ int tls13_do_recv(TLS_CONNECT *conn)
conn->recordlen -= n;
conn->record_offset += n;
}
conn->state = 0;
conn->recv_state = 0;
break;
default:
fprintf(stderr, "conn->state = %d\n", conn->state);
fprintf(stderr, "conn->recv_state = %d\n", conn->recv_state);
error_print();
return -1;
}
@@ -1354,11 +1354,15 @@ int tls13_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen)
// 这里需要考虑max_early_data_size的问题
int tls13_recv_early_data(TLS_CONNECT *conn)
{
int ret;
tls_trace("recv EarlyData\n");
if (tls13_do_recv(conn) != 1) {
error_print();
return -1;
if ((ret = tls13_do_recv(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN && ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
memcpy(conn->early_data_buf, conn->data, conn->datalen);
conn->early_data_len = conn->datalen;
@@ -7792,10 +7796,10 @@ int tls13_send_alert(TLS_CONNECT *conn, int alert)
tls13_record_print(stderr, 0, 0, conn->plain_record, conn->plain_recordlen);
switch (conn->state) {
case TLS_handshake_client_hello:
case TLS_handshake_server_hello:
case TLS_handshake_hello_retry_request:
switch (conn->handshake_state) {
case TLS_state_client_hello:
case TLS_state_server_hello:
case TLS_state_hello_retry_request:
tls_socket_send(conn->sock, conn->plain_record, conn->plain_recordlen, 0);
break;
default:
@@ -8589,7 +8593,7 @@ int tls13_do_client_handshake(TLS_CONNECT *conn)
int ret;
int next_state;
switch (conn->state) {
switch (conn->handshake_state) {
case TLS_state_hello_retry_request:
case TLS_state_client_hello_again:
case TLS_state_server_hello:
@@ -8604,7 +8608,7 @@ int tls13_do_client_handshake(TLS_CONNECT *conn)
break;
}
switch (conn->state) {
switch (conn->handshake_state) {
case TLS_state_client_hello:
ret = tls13_send_client_hello(conn);
next_state = TLS_state_hello_retry_request;
@@ -8711,7 +8715,7 @@ int tls13_do_client_handshake(TLS_CONNECT *conn)
}
}
conn->state = next_state;
conn->handshake_state = next_state;
// ret == 0 means this step is bypassed
if (ret == 1) {
@@ -8727,7 +8731,7 @@ int tls13_do_server_handshake(TLS_CONNECT *conn)
int ret;
int next_state;
switch (conn->state) {
switch (conn->handshake_state) {
case TLS_state_client_hello:
ret = tls13_recv_client_hello(conn);
if (conn->early_data)
@@ -8854,7 +8858,7 @@ int tls13_do_server_handshake(TLS_CONNECT *conn)
}
conn->state = next_state;
conn->handshake_state = next_state;
tls_clean_record(conn);
@@ -8865,7 +8869,7 @@ int tls13_client_handshake(TLS_CONNECT *conn)
{
int ret;
while (conn->state != TLS_state_handshake_over) {
while (conn->handshake_state != TLS_state_handshake_over) {
ret = tls13_do_client_handshake(conn);
@@ -8879,8 +8883,6 @@ int tls13_client_handshake(TLS_CONNECT *conn)
}
conn->state = 0;
// TODO: cleanup conn?
return 1;
@@ -8890,7 +8892,7 @@ int tls13_server_handshake(TLS_CONNECT *conn)
{
int ret;
while (conn->state != TLS_state_handshake_over) {
while (conn->handshake_state != TLS_state_handshake_over) {
ret = tls13_do_server_handshake(conn);
@@ -8902,8 +8904,6 @@ int tls13_server_handshake(TLS_CONNECT *conn)
}
}
conn->state = 0;
// TODO: cleanup conn?
return 1;
@@ -8912,82 +8912,59 @@ int tls13_server_handshake(TLS_CONNECT *conn)
int tls13_do_connect(TLS_CONNECT *conn)
{
int ret;
fd_set rfds;
fd_set wfds;
// 应该把protocol_version的初始化放在这里
conn->state = TLS_state_client_hello;
//sm3_init(&conn->sm3_ctx);
while (1) {
ret = tls13_client_handshake(conn);
if (ret == 1) {
break;
} else if (ret == TLS_ERROR_SEND_AGAIN) {
FD_ZERO(&rfds);
FD_ZERO(&wfds);
FD_SET(conn->sock, &rfds);
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
} else if (ret == TLS_ERROR_RECV_AGAIN) {
FD_ZERO(&rfds);
FD_ZERO(&wfds);
FD_SET(conn->sock, &wfds);
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
} else {
error_print();
return -1;
}
if (!conn || !conn->is_client) {
error_print();
return -1;
}
fprintf(stderr, "tls13_do_connect: connected\n");
if (conn->handshake_state == TLS_state_handshake_over) {
return 1;
}
return 1;
if (conn->handshake_state == TLS_state_handshake_init) {
conn->handshake_state = TLS_state_client_hello;
}
ret = tls13_client_handshake(conn);
if (ret == 1) {
conn->handshake_state = TLS_state_handshake_over;
return 1;
}
if (ret == TLS_ERROR_RECV_AGAIN || ret == TLS_ERROR_SEND_AGAIN) {
return ret;
}
error_print();
return -1;
}
int tls13_do_accept(TLS_CONNECT *conn)
{
int ret;
fd_set rfds;
fd_set wfds;
conn->state = TLS_state_client_hello;
//sm3_init(&conn->sm3_ctx);
fprintf(stderr, "tls13_do_accept\n");
while (1) {
ret = tls13_server_handshake(conn);
if (ret == 1) {
break;
} else if (ret == TLS_ERROR_SEND_AGAIN) {
FD_ZERO(&rfds);
FD_ZERO(&wfds);
FD_SET(conn->sock, &rfds);
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
} else if (ret == TLS_ERROR_RECV_AGAIN) {
FD_ZERO(&rfds);
FD_ZERO(&wfds);
FD_SET(conn->sock, &wfds);
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
} else {
error_print();
return -1;
}
if (!conn || conn->is_client) {
error_print();
return -1;
}
fprintf(stderr, "tls13_do_accept: connected\n");
if (conn->handshake_state == TLS_state_handshake_over) {
return 1;
}
return 1;
if (conn->handshake_state == TLS_state_handshake_init) {
conn->handshake_state = TLS_state_client_hello;
}
ret = tls13_server_handshake(conn);
if (ret == 1) {
conn->handshake_state = TLS_state_handshake_over;
return 1;
}
if (ret == TLS_ERROR_RECV_AGAIN || ret == TLS_ERROR_SEND_AGAIN) {
return ret;
}
error_print();
return -1;
}

View File

@@ -23,6 +23,53 @@
#endif
static int set_socket_nonblocking(tls_socket_t sock)
{
#ifdef WIN32
u_long mode = 1;
if (ioctlsocket(sock, FIONBIO, &mode) != 0) {
error_print();
return -1;
}
#else
int flags;
if ((flags = fcntl(sock, F_GETFL)) < 0
|| fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) {
error_print();
return -1;
}
#endif
return 1;
}
static int do_handshake_select(TLS_CONNECT *conn)
{
int ret;
fd_set rfds;
fd_set wfds;
for (;;) {
ret = tls_do_handshake(conn);
if (ret == 1) {
return 1;
}
FD_ZERO(&rfds);
FD_ZERO(&wfds);
if (ret == TLS_ERROR_RECV_AGAIN) {
FD_SET(conn->sock, &rfds);
} else if (ret == TLS_ERROR_SEND_AGAIN) {
FD_SET(conn->sock, &wfds);
} else {
error_print();
return -1;
}
if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) {
error_print();
return -1;
}
}
}
static const char *http_get =
"GET / HTTP/1.1\r\n"
"Hostname: aaa\r\n"
@@ -629,7 +676,12 @@ bad:
goto end;
}
if (tls_do_handshake(&conn) != 1) {
if (set_socket_nonblocking(sock) != 1) {
error_print();
goto end;
}
if (do_handshake_select(&conn) != 1) {
fprintf(stderr, "%s: error\n", prog);
goto end;
}

View File

@@ -52,6 +52,53 @@ static const char *help =
"\n";
static int set_socket_nonblocking(tls_socket_t sock)
{
#ifdef WIN32
u_long mode = 1;
if (ioctlsocket(sock, FIONBIO, &mode) != 0) {
error_print();
return -1;
}
#else
int flags;
if ((flags = fcntl(sock, F_GETFL)) < 0
|| fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) {
error_print();
return -1;
}
#endif
return 1;
}
static int do_handshake_select(TLS_CONNECT *conn)
{
int ret;
fd_set rfds;
fd_set wfds;
for (;;) {
ret = tls_do_handshake(conn);
if (ret == 1) {
return 1;
}
FD_ZERO(&rfds);
FD_ZERO(&wfds);
if (ret == TLS_ERROR_RECV_AGAIN) {
FD_SET(conn->sock, &rfds);
} else if (ret == TLS_ERROR_SEND_AGAIN) {
FD_SET(conn->sock, &wfds);
} else {
error_print();
return -1;
}
if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) {
error_print();
return -1;
}
}
}
int tls13_server_main(int argc , char **argv)
{
@@ -468,6 +515,11 @@ bad:
goto end;
}
if (set_socket_nonblocking(conn_sock) != 1) {
error_print();
goto end;
}
// pre_shared_key from external
if (psk_keys_cnt) {
@@ -511,7 +563,7 @@ bad:
tls13_enable_pre_shared_key(&conn, 1);
}
if (tls_do_handshake(&conn) != 1) {
if (do_handshake_select(&conn) != 1) {
error_print();
goto end;
}