mirror of
https://github.com/guanzhi/GmSSL.git
synced 2026-06-19 19:33:38 +08:00
Update TLS state machine
This commit is contained in:
@@ -1004,6 +1004,7 @@ enum {
|
||||
TLS_state_handshake_over,
|
||||
|
||||
|
||||
TLS_state_send_record,
|
||||
TLS_state_recv_record_header,
|
||||
TLS_state_recv_record_data,
|
||||
};
|
||||
@@ -1017,8 +1018,10 @@ typedef struct {
|
||||
|
||||
TLS_CTX *ctx;
|
||||
|
||||
// handshake state for state machine
|
||||
int state;
|
||||
// states for state machines
|
||||
int handshake_state;
|
||||
int send_state;
|
||||
int recv_state;
|
||||
|
||||
|
||||
uint8_t record[TLS_MAX_RECORD_SIZE];
|
||||
|
||||
16
src/tlcp.c
16
src/tlcp.c
@@ -2261,7 +2261,7 @@ int tlcp_do_client_handshake(TLS_CONNECT *conn)
|
||||
int ret;
|
||||
int next_state;
|
||||
|
||||
switch (conn->state) {
|
||||
switch (conn->handshake_state) {
|
||||
case TLS_state_client_hello:
|
||||
ret = tlcp_send_client_hello(conn);
|
||||
next_state = TLS_state_server_hello;
|
||||
@@ -2346,7 +2346,7 @@ int tlcp_do_client_handshake(TLS_CONNECT *conn)
|
||||
}
|
||||
}
|
||||
|
||||
conn->state = next_state;
|
||||
conn->handshake_state = next_state;
|
||||
|
||||
// ret == 0 means this step is bypassed
|
||||
if (ret == 1) {
|
||||
@@ -2361,7 +2361,7 @@ int tlcp_do_server_handshake(TLS_CONNECT *conn)
|
||||
int ret;
|
||||
int next_state;
|
||||
|
||||
switch (conn->state) {
|
||||
switch (conn->handshake_state) {
|
||||
case TLS_state_client_hello:
|
||||
ret = tlcp_recv_client_hello(conn);
|
||||
next_state = TLS_state_server_hello;
|
||||
@@ -2448,7 +2448,7 @@ int tlcp_do_server_handshake(TLS_CONNECT *conn)
|
||||
|
||||
}
|
||||
|
||||
conn->state = next_state;
|
||||
conn->handshake_state = next_state;
|
||||
|
||||
tls_clean_record(conn);
|
||||
|
||||
@@ -2459,7 +2459,7 @@ int tlcp_client_handshake(TLS_CONNECT *conn)
|
||||
{
|
||||
int ret;
|
||||
|
||||
while (conn->state != TLS_state_handshake_over) {
|
||||
while (conn->handshake_state != TLS_state_handshake_over) {
|
||||
|
||||
ret = tlcp_do_client_handshake(conn);
|
||||
|
||||
@@ -2481,7 +2481,7 @@ int tlcp_server_handshake(TLS_CONNECT *conn)
|
||||
int ret;
|
||||
|
||||
|
||||
while (conn->state != TLS_state_handshake_over) {
|
||||
while (conn->handshake_state != TLS_state_handshake_over) {
|
||||
|
||||
ret = tlcp_do_server_handshake(conn);
|
||||
|
||||
@@ -2506,7 +2506,7 @@ int tlcp_do_connect(TLS_CONNECT *conn)
|
||||
|
||||
// 应该把protocol_version的初始化放在这里
|
||||
|
||||
conn->state = TLS_state_client_hello;
|
||||
conn->handshake_state = TLS_state_client_hello;
|
||||
//sm3_init(&conn->sm3_ctx);
|
||||
|
||||
while (1) {
|
||||
@@ -2542,7 +2542,7 @@ int tlcp_do_accept(TLS_CONNECT *conn)
|
||||
fd_set rfds;
|
||||
fd_set wfds;
|
||||
|
||||
conn->state = TLS_state_client_hello;
|
||||
conn->handshake_state = TLS_state_client_hello;
|
||||
|
||||
//sm3_init(&conn->sm3_ctx);
|
||||
|
||||
|
||||
16
src/tls12.c
16
src/tls12.c
@@ -3536,7 +3536,7 @@ int tls12_do_client_handshake(TLS_CONNECT *conn)
|
||||
int ret;
|
||||
int next_state;
|
||||
|
||||
switch (conn->state) {
|
||||
switch (conn->handshake_state) {
|
||||
case TLS_state_client_hello:
|
||||
ret = tls_send_client_hello(conn);
|
||||
next_state = TLS_state_server_hello;
|
||||
@@ -3626,7 +3626,7 @@ int tls12_do_client_handshake(TLS_CONNECT *conn)
|
||||
}
|
||||
}
|
||||
|
||||
conn->state = next_state;
|
||||
conn->handshake_state = next_state;
|
||||
|
||||
// ret == 0 means this step is bypassed
|
||||
if (ret == 1) {
|
||||
@@ -3641,7 +3641,7 @@ int tls12_do_server_handshake(TLS_CONNECT *conn)
|
||||
int ret;
|
||||
int next_state;
|
||||
|
||||
switch (conn->state) {
|
||||
switch (conn->handshake_state) {
|
||||
case TLS_state_client_hello:
|
||||
ret = tls_recv_client_hello(conn);
|
||||
next_state = TLS_state_server_hello;
|
||||
@@ -3728,7 +3728,7 @@ int tls12_do_server_handshake(TLS_CONNECT *conn)
|
||||
|
||||
}
|
||||
|
||||
conn->state = next_state;
|
||||
conn->handshake_state = next_state;
|
||||
|
||||
tls_clean_record(conn);
|
||||
|
||||
@@ -3741,7 +3741,7 @@ int tls12_client_handshake(TLS_CONNECT *conn)
|
||||
{
|
||||
int ret;
|
||||
|
||||
while (conn->state != TLS_state_handshake_over) {
|
||||
while (conn->handshake_state != TLS_state_handshake_over) {
|
||||
|
||||
ret = tls12_do_client_handshake(conn);
|
||||
|
||||
@@ -3763,7 +3763,7 @@ int tls12_server_handshake(TLS_CONNECT *conn)
|
||||
int ret;
|
||||
|
||||
|
||||
while (conn->state != TLS_state_handshake_over) {
|
||||
while (conn->handshake_state != TLS_state_handshake_over) {
|
||||
|
||||
ret = tls12_do_server_handshake(conn);
|
||||
|
||||
@@ -3786,7 +3786,7 @@ int tls12_do_connect(TLS_CONNECT *conn)
|
||||
fd_set rfds;
|
||||
fd_set wfds;
|
||||
|
||||
conn->state = TLS_state_client_hello;
|
||||
conn->handshake_state = TLS_state_client_hello;
|
||||
//sm3_init(&conn->sm3_ctx);
|
||||
|
||||
|
||||
@@ -3825,7 +3825,7 @@ int tls12_do_accept(TLS_CONNECT *conn)
|
||||
fd_set rfds;
|
||||
fd_set wfds;
|
||||
|
||||
conn->state = TLS_state_client_hello;
|
||||
conn->handshake_state = TLS_state_client_hello;
|
||||
|
||||
//sm3_init(&conn->sm3_ctx);
|
||||
digest_init(&conn->dgst_ctx, DIGEST_sm3());
|
||||
|
||||
159
src/tls13.c
159
src/tls13.c
@@ -1091,6 +1091,7 @@ int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t *s
|
||||
|
||||
conn->recordlen = 5 + record_datalen;
|
||||
conn->record_offset = 0;
|
||||
conn->send_state = TLS_state_send_record;
|
||||
|
||||
// 需要记录密文对应的明文是什么,当完整的报文发送之后,这些信息要返回给调用方
|
||||
//conn->plain_recordlen = datalen + 5;
|
||||
@@ -1123,6 +1124,8 @@ int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t *s
|
||||
|
||||
//*sentlen = conn->plain_recordlen - 5;
|
||||
*sentlen = conn->sentlen;
|
||||
conn->record_offset = 0;
|
||||
conn->send_state = 0;
|
||||
|
||||
return ret;
|
||||
}
|
||||
@@ -1135,14 +1138,11 @@ int tls13_do_recv(TLS_CONNECT *conn)
|
||||
uint8_t *seq_num;
|
||||
|
||||
|
||||
// 在接收EarlyData的时候,当前的状态有问题啊
|
||||
|
||||
switch (conn->state) {
|
||||
switch (conn->recv_state) {
|
||||
case 0:
|
||||
case TLS_state_early_data:
|
||||
conn->record_offset = 0;
|
||||
conn->recordlen = TLS_RECORD_HEADER_SIZE;
|
||||
conn->state = TLS_state_recv_record_header;
|
||||
conn->recv_state = TLS_state_recv_record_header;
|
||||
|
||||
case TLS_state_recv_record_header:
|
||||
while (conn->recordlen) {
|
||||
@@ -1167,7 +1167,7 @@ int tls13_do_recv(TLS_CONNECT *conn)
|
||||
return -1;
|
||||
}
|
||||
conn->recordlen = tls_record_data_length(conn->record);
|
||||
conn->state = TLS_state_recv_record_data;
|
||||
conn->recv_state = TLS_state_recv_record_data;
|
||||
|
||||
case TLS_state_recv_record_data:
|
||||
while (conn->recordlen) {
|
||||
@@ -1182,11 +1182,11 @@ int tls13_do_recv(TLS_CONNECT *conn)
|
||||
conn->recordlen -= n;
|
||||
conn->record_offset += n;
|
||||
}
|
||||
conn->state = 0;
|
||||
conn->recv_state = 0;
|
||||
break;
|
||||
|
||||
default:
|
||||
fprintf(stderr, "conn->state = %d\n", conn->state);
|
||||
fprintf(stderr, "conn->recv_state = %d\n", conn->recv_state);
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
@@ -1354,11 +1354,15 @@ int tls13_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen)
|
||||
// 这里需要考虑max_early_data_size的问题
|
||||
int tls13_recv_early_data(TLS_CONNECT *conn)
|
||||
{
|
||||
int ret;
|
||||
|
||||
tls_trace("recv EarlyData\n");
|
||||
|
||||
if (tls13_do_recv(conn) != 1) {
|
||||
error_print();
|
||||
return -1;
|
||||
if ((ret = tls13_do_recv(conn)) != 1) {
|
||||
if (ret != TLS_ERROR_RECV_AGAIN && ret != TLS_ERROR_SEND_AGAIN) {
|
||||
error_print();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
memcpy(conn->early_data_buf, conn->data, conn->datalen);
|
||||
conn->early_data_len = conn->datalen;
|
||||
@@ -7792,10 +7796,10 @@ int tls13_send_alert(TLS_CONNECT *conn, int alert)
|
||||
tls13_record_print(stderr, 0, 0, conn->plain_record, conn->plain_recordlen);
|
||||
|
||||
|
||||
switch (conn->state) {
|
||||
case TLS_handshake_client_hello:
|
||||
case TLS_handshake_server_hello:
|
||||
case TLS_handshake_hello_retry_request:
|
||||
switch (conn->handshake_state) {
|
||||
case TLS_state_client_hello:
|
||||
case TLS_state_server_hello:
|
||||
case TLS_state_hello_retry_request:
|
||||
tls_socket_send(conn->sock, conn->plain_record, conn->plain_recordlen, 0);
|
||||
break;
|
||||
default:
|
||||
@@ -8589,7 +8593,7 @@ int tls13_do_client_handshake(TLS_CONNECT *conn)
|
||||
int ret;
|
||||
int next_state;
|
||||
|
||||
switch (conn->state) {
|
||||
switch (conn->handshake_state) {
|
||||
case TLS_state_hello_retry_request:
|
||||
case TLS_state_client_hello_again:
|
||||
case TLS_state_server_hello:
|
||||
@@ -8604,7 +8608,7 @@ int tls13_do_client_handshake(TLS_CONNECT *conn)
|
||||
break;
|
||||
}
|
||||
|
||||
switch (conn->state) {
|
||||
switch (conn->handshake_state) {
|
||||
case TLS_state_client_hello:
|
||||
ret = tls13_send_client_hello(conn);
|
||||
next_state = TLS_state_hello_retry_request;
|
||||
@@ -8711,7 +8715,7 @@ int tls13_do_client_handshake(TLS_CONNECT *conn)
|
||||
}
|
||||
}
|
||||
|
||||
conn->state = next_state;
|
||||
conn->handshake_state = next_state;
|
||||
|
||||
// ret == 0 means this step is bypassed
|
||||
if (ret == 1) {
|
||||
@@ -8727,7 +8731,7 @@ int tls13_do_server_handshake(TLS_CONNECT *conn)
|
||||
int ret;
|
||||
int next_state;
|
||||
|
||||
switch (conn->state) {
|
||||
switch (conn->handshake_state) {
|
||||
case TLS_state_client_hello:
|
||||
ret = tls13_recv_client_hello(conn);
|
||||
if (conn->early_data)
|
||||
@@ -8854,7 +8858,7 @@ int tls13_do_server_handshake(TLS_CONNECT *conn)
|
||||
|
||||
}
|
||||
|
||||
conn->state = next_state;
|
||||
conn->handshake_state = next_state;
|
||||
|
||||
tls_clean_record(conn);
|
||||
|
||||
@@ -8865,7 +8869,7 @@ int tls13_client_handshake(TLS_CONNECT *conn)
|
||||
{
|
||||
int ret;
|
||||
|
||||
while (conn->state != TLS_state_handshake_over) {
|
||||
while (conn->handshake_state != TLS_state_handshake_over) {
|
||||
|
||||
ret = tls13_do_client_handshake(conn);
|
||||
|
||||
@@ -8879,8 +8883,6 @@ int tls13_client_handshake(TLS_CONNECT *conn)
|
||||
|
||||
}
|
||||
|
||||
conn->state = 0;
|
||||
|
||||
// TODO: cleanup conn?
|
||||
|
||||
return 1;
|
||||
@@ -8890,7 +8892,7 @@ int tls13_server_handshake(TLS_CONNECT *conn)
|
||||
{
|
||||
int ret;
|
||||
|
||||
while (conn->state != TLS_state_handshake_over) {
|
||||
while (conn->handshake_state != TLS_state_handshake_over) {
|
||||
|
||||
ret = tls13_do_server_handshake(conn);
|
||||
|
||||
@@ -8902,8 +8904,6 @@ int tls13_server_handshake(TLS_CONNECT *conn)
|
||||
}
|
||||
}
|
||||
|
||||
conn->state = 0;
|
||||
|
||||
// TODO: cleanup conn?
|
||||
|
||||
return 1;
|
||||
@@ -8912,82 +8912,59 @@ int tls13_server_handshake(TLS_CONNECT *conn)
|
||||
int tls13_do_connect(TLS_CONNECT *conn)
|
||||
{
|
||||
int ret;
|
||||
fd_set rfds;
|
||||
fd_set wfds;
|
||||
|
||||
// 应该把protocol_version的初始化放在这里
|
||||
|
||||
conn->state = TLS_state_client_hello;
|
||||
//sm3_init(&conn->sm3_ctx);
|
||||
|
||||
while (1) {
|
||||
|
||||
ret = tls13_client_handshake(conn);
|
||||
if (ret == 1) {
|
||||
break;
|
||||
|
||||
} else if (ret == TLS_ERROR_SEND_AGAIN) {
|
||||
FD_ZERO(&rfds);
|
||||
FD_ZERO(&wfds);
|
||||
FD_SET(conn->sock, &rfds);
|
||||
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
|
||||
|
||||
} else if (ret == TLS_ERROR_RECV_AGAIN) {
|
||||
FD_ZERO(&rfds);
|
||||
FD_ZERO(&wfds);
|
||||
FD_SET(conn->sock, &wfds);
|
||||
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
|
||||
|
||||
} else {
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
if (!conn || !conn->is_client) {
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "tls13_do_connect: connected\n");
|
||||
if (conn->handshake_state == TLS_state_handshake_over) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 1;
|
||||
if (conn->handshake_state == TLS_state_handshake_init) {
|
||||
conn->handshake_state = TLS_state_client_hello;
|
||||
}
|
||||
|
||||
ret = tls13_client_handshake(conn);
|
||||
if (ret == 1) {
|
||||
conn->handshake_state = TLS_state_handshake_over;
|
||||
return 1;
|
||||
}
|
||||
if (ret == TLS_ERROR_RECV_AGAIN || ret == TLS_ERROR_SEND_AGAIN) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
|
||||
int tls13_do_accept(TLS_CONNECT *conn)
|
||||
{
|
||||
int ret;
|
||||
fd_set rfds;
|
||||
fd_set wfds;
|
||||
|
||||
conn->state = TLS_state_client_hello;
|
||||
|
||||
//sm3_init(&conn->sm3_ctx);
|
||||
|
||||
fprintf(stderr, "tls13_do_accept\n");
|
||||
|
||||
|
||||
while (1) {
|
||||
|
||||
ret = tls13_server_handshake(conn);
|
||||
|
||||
if (ret == 1) {
|
||||
break;
|
||||
|
||||
} else if (ret == TLS_ERROR_SEND_AGAIN) {
|
||||
FD_ZERO(&rfds);
|
||||
FD_ZERO(&wfds);
|
||||
FD_SET(conn->sock, &rfds);
|
||||
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
|
||||
|
||||
} else if (ret == TLS_ERROR_RECV_AGAIN) {
|
||||
FD_ZERO(&rfds);
|
||||
FD_ZERO(&wfds);
|
||||
FD_SET(conn->sock, &wfds);
|
||||
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
|
||||
|
||||
} else {
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
if (!conn || conn->is_client) {
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "tls13_do_accept: connected\n");
|
||||
if (conn->handshake_state == TLS_state_handshake_over) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 1;
|
||||
if (conn->handshake_state == TLS_state_handshake_init) {
|
||||
conn->handshake_state = TLS_state_client_hello;
|
||||
}
|
||||
|
||||
ret = tls13_server_handshake(conn);
|
||||
if (ret == 1) {
|
||||
conn->handshake_state = TLS_state_handshake_over;
|
||||
return 1;
|
||||
}
|
||||
if (ret == TLS_ERROR_RECV_AGAIN || ret == TLS_ERROR_SEND_AGAIN) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
|
||||
@@ -23,6 +23,53 @@
|
||||
#endif
|
||||
|
||||
|
||||
static int set_socket_nonblocking(tls_socket_t sock)
|
||||
{
|
||||
#ifdef WIN32
|
||||
u_long mode = 1;
|
||||
if (ioctlsocket(sock, FIONBIO, &mode) != 0) {
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
int flags;
|
||||
if ((flags = fcntl(sock, F_GETFL)) < 0
|
||||
|| fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) {
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
#endif
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int do_handshake_select(TLS_CONNECT *conn)
|
||||
{
|
||||
int ret;
|
||||
fd_set rfds;
|
||||
fd_set wfds;
|
||||
|
||||
for (;;) {
|
||||
ret = tls_do_handshake(conn);
|
||||
if (ret == 1) {
|
||||
return 1;
|
||||
}
|
||||
FD_ZERO(&rfds);
|
||||
FD_ZERO(&wfds);
|
||||
if (ret == TLS_ERROR_RECV_AGAIN) {
|
||||
FD_SET(conn->sock, &rfds);
|
||||
} else if (ret == TLS_ERROR_SEND_AGAIN) {
|
||||
FD_SET(conn->sock, &wfds);
|
||||
} else {
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) {
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static const char *http_get =
|
||||
"GET / HTTP/1.1\r\n"
|
||||
"Hostname: aaa\r\n"
|
||||
@@ -629,7 +676,12 @@ bad:
|
||||
goto end;
|
||||
}
|
||||
|
||||
if (tls_do_handshake(&conn) != 1) {
|
||||
if (set_socket_nonblocking(sock) != 1) {
|
||||
error_print();
|
||||
goto end;
|
||||
}
|
||||
|
||||
if (do_handshake_select(&conn) != 1) {
|
||||
fprintf(stderr, "%s: error\n", prog);
|
||||
goto end;
|
||||
}
|
||||
|
||||
@@ -52,6 +52,53 @@ static const char *help =
|
||||
"\n";
|
||||
|
||||
|
||||
static int set_socket_nonblocking(tls_socket_t sock)
|
||||
{
|
||||
#ifdef WIN32
|
||||
u_long mode = 1;
|
||||
if (ioctlsocket(sock, FIONBIO, &mode) != 0) {
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
int flags;
|
||||
if ((flags = fcntl(sock, F_GETFL)) < 0
|
||||
|| fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) {
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
#endif
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int do_handshake_select(TLS_CONNECT *conn)
|
||||
{
|
||||
int ret;
|
||||
fd_set rfds;
|
||||
fd_set wfds;
|
||||
|
||||
for (;;) {
|
||||
ret = tls_do_handshake(conn);
|
||||
if (ret == 1) {
|
||||
return 1;
|
||||
}
|
||||
FD_ZERO(&rfds);
|
||||
FD_ZERO(&wfds);
|
||||
if (ret == TLS_ERROR_RECV_AGAIN) {
|
||||
FD_SET(conn->sock, &rfds);
|
||||
} else if (ret == TLS_ERROR_SEND_AGAIN) {
|
||||
FD_SET(conn->sock, &wfds);
|
||||
} else {
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) {
|
||||
error_print();
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int tls13_server_main(int argc , char **argv)
|
||||
{
|
||||
@@ -468,6 +515,11 @@ bad:
|
||||
goto end;
|
||||
}
|
||||
|
||||
if (set_socket_nonblocking(conn_sock) != 1) {
|
||||
error_print();
|
||||
goto end;
|
||||
}
|
||||
|
||||
|
||||
// pre_shared_key from external
|
||||
if (psk_keys_cnt) {
|
||||
@@ -511,7 +563,7 @@ bad:
|
||||
tls13_enable_pre_shared_key(&conn, 1);
|
||||
}
|
||||
|
||||
if (tls_do_handshake(&conn) != 1) {
|
||||
if (do_handshake_select(&conn) != 1) {
|
||||
error_print();
|
||||
goto end;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user