diff --git a/src/tlcp.c b/src/tlcp.c index 363d3289..e961c444 100644 --- a/src/tlcp.c +++ b/src/tlcp.c @@ -983,33 +983,41 @@ int tlcp_send_client_key_exchange(TLS_CONNECT *conn) { uint8_t enced_pre_master_secret[SM2_MAX_CIPHERTEXT_SIZE]; size_t enced_pre_master_secret_len; + int ret; tls_trace("send ClientKeyExchange\n"); - if (tls_pre_master_secret_generate(conn->pre_master_secret, TLS_protocol_tlcp) != 1) { - error_print(); - return -1; + if (!conn->recordlen) { + if (tls_pre_master_secret_generate(conn->pre_master_secret, TLS_protocol_tlcp) != 1) { + error_print(); + return -1; + } + + if (sm2_encrypt(&conn->server_enc_key.u.sm2_key, conn->pre_master_secret, 48, + enced_pre_master_secret, &enced_pre_master_secret_len) != 1 + || tls_record_set_handshake_client_key_exchange_pke(conn->record, &conn->recordlen, + enced_pre_master_secret, enced_pre_master_secret_len) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + return -1; + } + tlcp_record_print(stderr, 0, 0, conn->record, conn->recordlen); + + if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { + error_print(); + return -1; + } + tls_handshake_digest_print(stderr, 0, 0, "ClientKeyExchange", &conn->dgst_ctx); } - if (sm2_encrypt(&conn->server_enc_key.u.sm2_key, conn->pre_master_secret, 48, - enced_pre_master_secret, &enced_pre_master_secret_len) != 1 - || tls_record_set_handshake_client_key_exchange_pke(conn->record, &conn->recordlen, - enced_pre_master_secret, enced_pre_master_secret_len) != 1) { - error_print(); - tls_send_alert(conn, TLS_alert_internal_error); - return -1; - } - tlcp_record_print(stderr, 0, 0, conn->record, conn->recordlen); - if (tls_record_send(conn->record, conn->recordlen, conn->sock) != 1) { - error_print(); - return -1; + if ((ret = tls_send_record(conn)) != 1) { + if (ret != TLS_ERROR_SEND_AGAIN) { + error_print(); + } + return ret; } - if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { - error_print(); - return -1; - } - tls_handshake_digest_print(stderr, 0, 0, "ClientKeyExchange", &conn->dgst_ctx); + tls_clean_record(conn); if (tlcp_generate_keys(conn) != 1) { error_print(); @@ -2059,11 +2067,18 @@ int tlcp_recv_client_key_exchange(TLS_CONNECT *conn) size_t enced_pms_len; size_t pre_master_secret_len; X509_KEY *enc_key; + int ret; tls_trace("recv ClientKeyExchange\n"); - if (tls_record_recv(conn->record, &conn->recordlen, conn->sock) != 1 - || tls_record_protocol(conn->record) != TLS_protocol_tlcp) { + if ((ret = tls_recv_record(conn)) != 1) { + if (ret != TLS_ERROR_RECV_AGAIN) { + error_print(); + } + return ret; + } + + if (tls_record_protocol(conn->record) != TLS_protocol_tlcp) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; @@ -2501,75 +2516,45 @@ int tlcp_server_handshake(TLS_CONNECT *conn) int tlcp_do_connect(TLS_CONNECT *conn) { int ret; - fd_set rfds; - fd_set wfds; // 应该把protocol_version的初始化放在这里 - conn->handshake_state = TLS_state_client_hello; - //sm3_init(&conn->sm3_ctx); - - while (1) { - - ret = tlcp_client_handshake(conn); - if (ret == 1) { - break; - - } else if (ret == TLS_ERROR_SEND_AGAIN) { - FD_ZERO(&rfds); - FD_ZERO(&wfds); - FD_SET(conn->sock, &rfds); - select(conn->sock + 1, &rfds, &wfds, NULL, NULL); - - } else if (ret == TLS_ERROR_RECV_AGAIN) { - FD_ZERO(&rfds); - FD_ZERO(&wfds); - FD_SET(conn->sock, &wfds); - select(conn->sock + 1, &rfds, &wfds, NULL, NULL); - - } else { - error_print(); - return -1; - } + if (conn->handshake_state == TLS_state_handshake_over) { + return 1; } - return 1; + if (conn->handshake_state == TLS_state_handshake_init) { + conn->handshake_state = TLS_state_client_hello; + } + + ret = tlcp_client_handshake(conn); + if (ret == 1 + || ret == TLS_ERROR_RECV_AGAIN + || ret == TLS_ERROR_SEND_AGAIN) { + return ret; + } + error_print(); + return -1; } int tlcp_do_accept(TLS_CONNECT *conn) { int ret; - fd_set rfds; - fd_set wfds; - conn->handshake_state = TLS_state_client_hello; - - //sm3_init(&conn->sm3_ctx); - - while (1) { - - ret = tlcp_server_handshake(conn); - - if (ret == 1) { - break; - - } else if (ret == TLS_ERROR_SEND_AGAIN) { - FD_ZERO(&rfds); - FD_ZERO(&wfds); - FD_SET(conn->sock, &rfds); - select(conn->sock + 1, &rfds, &wfds, NULL, NULL); - - } else if (ret == TLS_ERROR_RECV_AGAIN) { - FD_ZERO(&rfds); - FD_ZERO(&wfds); - FD_SET(conn->sock, &wfds); - select(conn->sock + 1, &rfds, &wfds, NULL, NULL); - - } else { - error_print(); - return -1; - } + if (conn->handshake_state == TLS_state_handshake_over) { + return 1; } - return 1; + if (conn->handshake_state == TLS_state_handshake_init) { + conn->handshake_state = TLS_state_client_hello; + } + + ret = tlcp_server_handshake(conn); + if (ret == 1 + || ret == TLS_ERROR_RECV_AGAIN + || ret == TLS_ERROR_SEND_AGAIN) { + return ret; + } + error_print(); + return -1; } diff --git a/src/tls12.c b/src/tls12.c index 5485544d..2efb145b 100644 --- a/src/tls12.c +++ b/src/tls12.c @@ -3783,77 +3783,45 @@ int tls12_server_handshake(TLS_CONNECT *conn) int tls12_do_connect(TLS_CONNECT *conn) { int ret; - fd_set rfds; - fd_set wfds; - conn->handshake_state = TLS_state_client_hello; - //sm3_init(&conn->sm3_ctx); - - - digest_init(&conn->dgst_ctx, DIGEST_sm3()); - - while (1) { - - ret = tls12_client_handshake(conn); - if (ret == 1) { - break; - - } else if (ret == TLS_ERROR_SEND_AGAIN) { - FD_ZERO(&rfds); - FD_ZERO(&wfds); - FD_SET(conn->sock, &wfds); - select(conn->sock + 1, &rfds, &wfds, NULL, NULL); - - } else if (ret == TLS_ERROR_RECV_AGAIN) { - FD_ZERO(&rfds); - FD_ZERO(&wfds); - FD_SET(conn->sock, &rfds); - select(conn->sock + 1, &rfds, &wfds, NULL, NULL); - - } else { - error_print(); - return -1; - } + if (conn->handshake_state == TLS_state_handshake_over) { + return 1; } - return 1; + if (conn->handshake_state == TLS_state_handshake_init) { + conn->handshake_state = TLS_state_client_hello; + digest_init(&conn->dgst_ctx, DIGEST_sm3()); + } + + ret = tls12_client_handshake(conn); + if (ret == 1 + || ret == TLS_ERROR_RECV_AGAIN + || ret == TLS_ERROR_SEND_AGAIN) { + return ret; + } + error_print(); + return -1; } int tls12_do_accept(TLS_CONNECT *conn) { int ret; - fd_set rfds; - fd_set wfds; - conn->handshake_state = TLS_state_client_hello; - - //sm3_init(&conn->sm3_ctx); - digest_init(&conn->dgst_ctx, DIGEST_sm3()); - - while (1) { - - ret = tls12_server_handshake(conn); - - if (ret == 1) { - break; - - } else if (ret == TLS_ERROR_SEND_AGAIN) { - FD_ZERO(&rfds); - FD_ZERO(&wfds); - FD_SET(conn->sock, &rfds); - select(conn->sock + 1, &rfds, &wfds, NULL, NULL); - - } else if (ret == TLS_ERROR_RECV_AGAIN) { - FD_ZERO(&rfds); - FD_ZERO(&wfds); - FD_SET(conn->sock, &wfds); - select(conn->sock + 1, &rfds, &wfds, NULL, NULL); - - } else { - error_print(); - return -1; - } + if (conn->handshake_state == TLS_state_handshake_over) { + return 1; } - return 1; + if (conn->handshake_state == TLS_state_handshake_init) { + conn->handshake_state = TLS_state_client_hello; + digest_init(&conn->dgst_ctx, DIGEST_sm3()); + } + + ret = tls12_server_handshake(conn); + if (ret == 1 + || ret == TLS_ERROR_RECV_AGAIN + || ret == TLS_ERROR_SEND_AGAIN) { + return ret; + } + error_print(); + return -1; } diff --git a/tools/tlcp_client.c b/tools/tlcp_client.c index c3071ac2..4c8079b9 100644 --- a/tools/tlcp_client.c +++ b/tools/tlcp_client.c @@ -54,6 +54,53 @@ static const char *help = "\n"; +static int set_socket_nonblocking(tls_socket_t sock) +{ +#ifdef WIN32 + u_long mode = 1; + if (ioctlsocket(sock, FIONBIO, &mode) != 0) { + error_print(); + return -1; + } +#else + int flags; + if ((flags = fcntl(sock, F_GETFL)) < 0 + || fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) { + error_print(); + return -1; + } +#endif + return 1; +} + +static int do_handshake_select(TLS_CONNECT *conn) +{ + int ret; + fd_set rfds; + fd_set wfds; + + for (;;) { + ret = tls_do_handshake(conn); + if (ret == 1) { + return 1; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } +} + int tlcp_client_main(int argc, char *argv[]) { int ret = -1; @@ -317,7 +364,12 @@ bad: goto end; } - if (tls_do_handshake(&conn) != 1) { + if (set_socket_nonblocking(sock) != 1) { + error_print(); + goto end; + } + + if (do_handshake_select(&conn) != 1) { fprintf(stderr, "%s: error\n", prog); goto end; } @@ -359,6 +411,7 @@ bad: for (;;) { int rv; + len = sizeof(buf); rv = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len); if (rv == 1) { @@ -367,11 +420,26 @@ bad: } else if (rv == 0) { fprintf(stderr, "%s: TLCP connection is closed by remote host\n", prog); goto end; - } else if (rv == -EAGAIN) { - // when timeout, tls_recv return -EAGAIN (-11) - tls_shutdown(&conn); - ret = 0; - goto end; + } else if (rv == -EAGAIN + || rv == TLS_ERROR_RECV_AGAIN + || rv == TLS_ERROR_SEND_AGAIN) { + fd_set fds; + struct timeval timeout; + int sel; + + timeout.tv_sec = TIMEOUT_SECONDS; + timeout.tv_usec = 0; + FD_ZERO(&fds); + FD_SET(conn.sock, &fds); + sel = select(conn.sock + 1, &fds, NULL, NULL, &timeout); + if (sel < 0) { + fprintf(stderr, "%s: select error\n", prog); + goto end; + } else if (sel == 0) { + tls_shutdown(&conn); + ret = 0; + goto end; + } } else { fprintf(stderr, "%s: tls_recv error\n", prog); goto end; @@ -432,6 +500,7 @@ bad: if (FD_ISSET(conn.sock, &fds)) { int rv; + len = sizeof(buf); rv = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len); if (rv == 1) { @@ -440,10 +509,10 @@ bad: } else if (rv == 0) { fprintf(stderr, "Connection closed by remote host\n"); goto end; - } else if (rv == -EAGAIN) { - // should not happen - error_print(); - goto end; + } else if (rv == -EAGAIN + || rv == TLS_ERROR_RECV_AGAIN + || rv == TLS_ERROR_SEND_AGAIN) { + continue; } else { error_print(); fprintf(stderr, "%s: tls_recv error\n", prog); diff --git a/tools/tlcp_server.c b/tools/tlcp_server.c index fda0c29a..c06a7ed1 100644 --- a/tools/tlcp_server.c +++ b/tools/tlcp_server.c @@ -34,6 +34,53 @@ static const char *help = #include "tlcp_help.h" "\n"; +static int set_socket_nonblocking(tls_socket_t sock) +{ +#ifdef WIN32 + u_long mode = 1; + if (ioctlsocket(sock, FIONBIO, &mode) != 0) { + error_print(); + return -1; + } +#else + int flags; + if ((flags = fcntl(sock, F_GETFL)) < 0 + || fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) { + error_print(); + return -1; + } +#endif + return 1; +} + +static int do_handshake_select(TLS_CONNECT *conn) +{ + int ret; + fd_set rfds; + fd_set wfds; + + for (;;) { + ret = tls_do_handshake(conn); + if (ret == 1) { + return 1; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } +} + int tlcp_server_main(int argc , char **argv) { int ret = 1; @@ -207,7 +254,12 @@ restart: return -1; } - if (tls_do_handshake(&conn) != 1) { + if (set_socket_nonblocking(conn_sock) != 1) { + error_print(); + return -1; + } + + if (do_handshake_select(&conn) != 1) { error_print(); return -1; } @@ -216,10 +268,24 @@ restart: int rv; size_t sentlen; + fd_set fds; do { + FD_ZERO(&fds); + FD_SET(conn.sock, &fds); + + if (select((int)(conn.sock + 1), &fds, NULL, NULL, NULL) < 0) { + fprintf(stderr, "%s: select failed\n", prog); + goto end; + } + len = sizeof(buf); if ((rv = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len)) != 1) { + if (rv == -EAGAIN + || rv == TLS_ERROR_RECV_AGAIN + || rv == TLS_ERROR_SEND_AGAIN) { + continue; + } if (rv < 0) fprintf(stderr, "%s: recv failure\n", prog); else fprintf(stderr, "%s: Disconnected by remote\n", prog); diff --git a/tools/tls12_client.c b/tools/tls12_client.c index 45c73f9e..83238f1f 100644 --- a/tools/tls12_client.c +++ b/tools/tls12_client.c @@ -49,6 +49,53 @@ static const char *help = #include "tls12_help.h" "\n"; +static int set_socket_nonblocking(tls_socket_t sock) +{ +#ifdef WIN32 + u_long mode = 1; + if (ioctlsocket(sock, FIONBIO, &mode) != 0) { + error_print(); + return -1; + } +#else + int flags; + if ((flags = fcntl(sock, F_GETFL)) < 0 + || fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) { + error_print(); + return -1; + } +#endif + return 1; +} + +static int do_handshake_select(TLS_CONNECT *conn) +{ + int ret; + fd_set rfds; + fd_set wfds; + + for (;;) { + ret = tls_do_handshake(conn); + if (ret == 1) { + return 1; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } +} + int tls12_client_main(int argc, char *argv[]) { int ret = -1; @@ -298,7 +345,12 @@ bad: goto end; } - if (tls_do_handshake(&conn) != 1) { + if (set_socket_nonblocking(sock) != 1) { + error_print(); + goto end; + } + + if (do_handshake_select(&conn) != 1) { fprintf(stderr, "%s: error\n", prog); goto end; } @@ -334,8 +386,16 @@ bad: if (FD_ISSET(conn.sock, &fds)) { for (;;) { + int rv; + memset(buf, 0, sizeof(buf)); - if (tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len) != 1) { + len = sizeof(buf); + if ((rv = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len)) != 1) { + if (rv == -EAGAIN + || rv == TLS_ERROR_RECV_AGAIN + || rv == TLS_ERROR_SEND_AGAIN) { + break; + } goto end; } fwrite(buf, 1, len, stdout); diff --git a/tools/tls12_server.c b/tools/tls12_server.c index 9f6e70be..822625aa 100644 --- a/tools/tls12_server.c +++ b/tools/tls12_server.c @@ -40,6 +40,53 @@ static const char *help = "\n"; +static int set_socket_nonblocking(tls_socket_t sock) +{ +#ifdef WIN32 + u_long mode = 1; + if (ioctlsocket(sock, FIONBIO, &mode) != 0) { + error_print(); + return -1; + } +#else + int flags; + if ((flags = fcntl(sock, F_GETFL)) < 0 + || fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) { + error_print(); + return -1; + } +#endif + return 1; +} + +static int do_handshake_select(TLS_CONNECT *conn) +{ + int ret; + fd_set rfds; + fd_set wfds; + + for (;;) { + ret = tls_do_handshake(conn); + if (ret == 1) { + return 1; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } +} + int tls12_server_main(int argc , char **argv) { int ret = 1; @@ -302,7 +349,12 @@ restart: goto end; } - if (tls_do_handshake(&conn) != 1) { + if (set_socket_nonblocking(conn_sock) != 1) { + error_print(); + goto end; + } + + if (do_handshake_select(&conn) != 1) { error_print(); goto end; } @@ -311,10 +363,24 @@ restart: int rv; size_t sentlen; + fd_set fds; do { + FD_ZERO(&fds); + FD_SET(conn.sock, &fds); + + if (select((int)(conn.sock + 1), &fds, NULL, NULL, NULL) < 0) { + fprintf(stderr, "%s: select failed\n", prog); + goto end; + } + len = sizeof(buf); if ((rv = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len)) != 1) { + if (rv == -EAGAIN + || rv == TLS_ERROR_RECV_AGAIN + || rv == TLS_ERROR_SEND_AGAIN) { + continue; + } if (rv < 0) fprintf(stderr, "%s: recv failure\n", prog); else fprintf(stderr, "%s: Disconnected by remote\n", prog);