From 8b586d4299abdc350fea69a35201cf0512698f5e Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Fri, 12 Jun 2026 14:09:42 +0800 Subject: [PATCH] Update TLS/TLCP shutdown --- include/gmssl/tls.h | 5 + src/tls.c | 314 ++++++++++++++++++++++++++----------------- src/tls12.c | 3 + src/tls13.c | 24 +++- tools/tlcp_client.c | 81 ++++++++++- tools/tlcp_server.c | 74 +++++++++- tools/tls12_client.c | 74 +++++++++- tools/tls12_server.c | 74 +++++++++- tools/tls13_client.c | 36 ++++- tools/tls13_server.c | 31 +++++ 10 files changed, 570 insertions(+), 146 deletions(-) diff --git a/include/gmssl/tls.h b/include/gmssl/tls.h index 22b889d7..106cf070 100644 --- a/include/gmssl/tls.h +++ b/include/gmssl/tls.h @@ -1007,6 +1007,9 @@ enum { TLS_state_send_record, TLS_state_recv_record_header, TLS_state_recv_record_data, + TLS_state_shutdown_send_close_notify, + TLS_state_shutdown_recv_close_notify, + TLS_state_shutdown_over, }; @@ -1022,6 +1025,8 @@ typedef struct { int handshake_state; int send_state; int recv_state; + int shutdown_state; + int close_notify_received; uint8_t record[TLS_MAX_RECORD_SIZE]; diff --git a/src/tls.c b/src/tls.c index 549312b6..dfde5f4f 100644 --- a/src/tls.c +++ b/src/tls.c @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -1797,6 +1798,7 @@ static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *i const uint8_t *fixed_iv; uint8_t *seq_num; size_t recordlen; + int ret; if (!conn) { error_print(); @@ -1806,84 +1808,107 @@ static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *i error_print(); return -1; } - - if (inlen > TLS_MAX_PLAINTEXT_SIZE) { - inlen = TLS_MAX_PLAINTEXT_SIZE; + if (conn->recv_state) { + *sentlen = 0; + return TLS_ERROR_RECV_AGAIN; } - - if (conn->datalen) { - error_puts("recv all buffered data before send"); - return -1; - } - - if (conn->is_client) { - hmac_ctx = &conn->client_write_mac_ctx; - enc_key = &conn->client_write_key; - fixed_iv = conn->client_write_iv; - seq_num = conn->client_seq_num; - } else { - hmac_ctx = &conn->server_write_mac_ctx; - enc_key = &conn->server_write_key; - fixed_iv = conn->server_write_iv; - seq_num = conn->server_seq_num; - } - - if (tls_record_set_type(conn->databuf, record_type) != 1 - || tls_record_set_protocol(conn->databuf, conn->protocol) != 1 - || tls_record_set_data(conn->databuf, in, inlen) != 1) { + if (conn->send_state && conn->send_state != TLS_state_send_record) { error_print(); return -1; } - tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0); - if (conn->protocol == TLS_protocol_tls12) { - switch (conn->cipher_suite) { - case TLS_cipher_ecdhe_sm4_gcm_sm3: - case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: - if (tls12_record_gcm_encrypt(enc_key, fixed_iv, seq_num, + *sentlen = 0; + + if (!conn->recordlen) { + + if (inlen > TLS_MAX_PLAINTEXT_SIZE) { + inlen = TLS_MAX_PLAINTEXT_SIZE; + } + + if (conn->datalen) { + error_puts("recv all buffered data before send"); + return -1; + } + + if (conn->is_client) { + hmac_ctx = &conn->client_write_mac_ctx; + enc_key = &conn->client_write_key; + fixed_iv = conn->client_write_iv; + seq_num = conn->client_seq_num; + } else { + hmac_ctx = &conn->server_write_mac_ctx; + enc_key = &conn->server_write_key; + fixed_iv = conn->server_write_iv; + seq_num = conn->server_seq_num; + } + + if (tls_record_set_type(conn->databuf, record_type) != 1 + || tls_record_set_protocol(conn->databuf, conn->protocol) != 1 + || tls_record_set_data(conn->databuf, in, inlen) != 1) { + error_print(); + return -1; + } + tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0); + + if (conn->protocol == TLS_protocol_tls12) { + switch (conn->cipher_suite) { + case TLS_cipher_ecdhe_sm4_gcm_sm3: + case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: + if (tls12_record_gcm_encrypt(enc_key, fixed_iv, seq_num, + conn->databuf, tls_record_length(conn->databuf), + conn->record, &recordlen) != 1) { + error_print(); + return -1; + } + break; + case TLS_cipher_ecdhe_sm4_cbc_sm3: + case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: + if (tls_record_cbc_encrypt(hmac_ctx, enc_key, seq_num, + conn->databuf, tls_record_length(conn->databuf), + conn->record, &recordlen) != 1) { + error_print(); + return -1; + } + break; + default: + error_print(); + return -1; + } + } else if (conn->protocol == TLS_protocol_tlcp) { + if (tlcp_record_encrypt(conn->cipher_suite, hmac_ctx, enc_key, fixed_iv, seq_num, conn->databuf, tls_record_length(conn->databuf), conn->record, &recordlen) != 1) { error_print(); return -1; } - break; - case TLS_cipher_ecdhe_sm4_cbc_sm3: - case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: + } else { if (tls_record_cbc_encrypt(hmac_ctx, enc_key, seq_num, conn->databuf, tls_record_length(conn->databuf), conn->record, &recordlen) != 1) { error_print(); return -1; } - break; - default: - error_print(); - return -1; } - } else if (conn->protocol == TLS_protocol_tlcp) { - if (tlcp_record_encrypt(conn->cipher_suite, hmac_ctx, enc_key, fixed_iv, seq_num, - conn->databuf, tls_record_length(conn->databuf), - conn->record, &recordlen) != 1) { - error_print(); - return -1; - } - } else { - if (tls_record_cbc_encrypt(hmac_ctx, enc_key, seq_num, - conn->databuf, tls_record_length(conn->databuf), - conn->record, &recordlen) != 1) { - error_print(); - return -1; - } - } - tls_seq_num_incr(seq_num); + tls_seq_num_incr(seq_num); - if (tls_record_send(conn->record, recordlen, conn->sock) != 1) { - error_print(); - return -1; + conn->recordlen = recordlen; + conn->record_offset = 0; + conn->sentlen = inlen; + conn->send_state = TLS_state_send_record; + tls_encrypted_record_trace(stderr, conn->record, recordlen, 0, 0); } - tls_encrypted_record_trace(stderr, conn->record, recordlen, 0, 0); - *sentlen = inlen; + ret = tls_send_record(conn); + if (ret != 1) { + if (ret != TLS_ERROR_SEND_AGAIN) { + error_print(); + } + return ret; + } + + *sentlen = conn->sentlen; + conn->send_state = 0; + tls_clean_record(conn); return 1; } @@ -1911,10 +1936,20 @@ int tls_decrypt_recv(TLS_CONNECT *conn) } tls_trace("recv Encrypted Record\n"); - if ((ret = tls_record_recv(record, &recordlen, conn->sock)) != 1) { - if (ret < 0 && ret != -EAGAIN) error_print(); + if (conn->send_state) { + return TLS_ERROR_SEND_AGAIN; + } + conn->recv_state = TLS_state_recv_record_header; + if ((ret = tls_recv_record(conn)) != 1) { + if (ret != TLS_ERROR_RECV_AGAIN) { + conn->recv_state = 0; + tls_clean_record(conn); + error_print(); + } return ret; } + conn->recv_state = 0; + recordlen = conn->recordlen; tls_encrypted_record_trace(stderr, record, recordlen, 0, 0); if (conn->protocol == TLS_protocol_tls12) { @@ -1982,14 +2017,18 @@ static int tls12_tlcp_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_ if (conn->datalen == 0) { int ret; if ((ret = tls_decrypt_recv(conn)) != 1) { - if (ret < 0 && ret != -EAGAIN) error_print(); + if (ret != TLS_ERROR_RECV_AGAIN && ret != TLS_ERROR_SEND_AGAIN) { + error_print(); + } return ret; } switch (tls_record_type(conn->record)) { case TLS_record_application_data: + tls_clean_record(conn); break; case TLS_record_change_cipher_spec: + tls_clean_record(conn); error_print(); return -1; case TLS_record_alert: @@ -2000,12 +2039,20 @@ static int tls12_tlcp_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_ tls_record_get_alert(conn->databuf, &level, &alert); if (alert == TLS_alert_close_notify) { tls_trace("recv Alert.close_notify\n"); + conn->close_notify_received = 1; + conn->data = NULL; + conn->datalen = 0; + tls_clean_record(conn); return 0; } tls_trace("alert received\n"); + conn->data = NULL; + conn->datalen = 0; + tls_clean_record(conn); return -1; } default: + tls_clean_record(conn); error_print(); return -1; } @@ -2038,39 +2085,7 @@ int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen) } } -static int tls12_tlcp_shutdown(TLS_CONNECT *conn) -{ - int ret; - size_t recordlen; - uint8_t alert[2]; - alert[0] = TLS_alert_level_fatal; - alert[1] = TLS_alert_close_notify; - - if (!conn) { - error_print(); - return -1; - } - - tls_trace("send Alert.close_notify\n"); - - if (tls_encrypt_send(conn, TLS_record_alert, alert, sizeof(alert), &recordlen) != 1) { - error_print(); - return -1; - } - - tls_trace("recv Alert.close_notify\n"); - - if ((ret = tls_decrypt_recv(conn)) != 1) { - if (ret == 0) tls_trace("Connection closed by remote without close_notify\n"); - else if (ret == -EAGAIN) tls_trace("-EAGAIN\n"); - else error_print(); - return -1; - } - - return 1; -} - -static int tls13_shutdown(TLS_CONNECT *conn) +static int tls13_send_close_notify(TLS_CONNECT *conn) { int ret; const BLOCK_CIPHER_KEY *key; @@ -2083,28 +2098,31 @@ static int tls13_shutdown(TLS_CONNECT *conn) return -1; } - if (conn->is_client) { - key = &conn->client_write_key; - iv = conn->client_write_iv; - seq_num = conn->client_seq_num; - } else { - key = &conn->server_write_key; - iv = conn->server_write_iv; - seq_num = conn->server_seq_num; - } + if (!conn->recordlen) { + if (conn->is_client) { + key = &conn->client_write_key; + iv = conn->client_write_iv; + seq_num = conn->client_seq_num; + } else { + key = &conn->server_write_key; + iv = conn->server_write_iv; + seq_num = conn->server_seq_num; + } - tls_trace("send Alert.close_notify\n"); + tls_trace("send Alert.close_notify\n"); - tls_record_set_alert(conn->plain_record, &conn->plain_recordlen, - TLS_alert_level_warning, TLS_alert_close_notify); - tls13_padding_len_rand(&padding_len); - if (tls13_record_encrypt(key, iv, seq_num, conn->plain_record, conn->plain_recordlen, - padding_len, conn->record, &conn->recordlen) != 1) { - error_print(); - return -1; + tls_record_set_alert(conn->plain_record, &conn->plain_recordlen, + TLS_alert_level_warning, TLS_alert_close_notify); + tls13_padding_len_rand(&padding_len); + if (tls13_record_encrypt(key, iv, seq_num, conn->plain_record, conn->plain_recordlen, + padding_len, conn->record, &conn->recordlen) != 1) { + error_print(); + return -1; + } + tls_seq_num_incr(seq_num); + conn->record_offset = 0; + conn->send_state = TLS_state_send_record; } - tls_seq_num_incr(seq_num); - conn->record_offset = 0; ret = tls_send_record(conn); if (ret != 1) { @@ -2114,26 +2132,80 @@ static int tls13_shutdown(TLS_CONNECT *conn) return ret; } + conn->send_state = 0; + tls_clean_record(conn); return 1; } -int tls_shutdown(TLS_CONNECT *conn) +static int tls_send_close_notify(TLS_CONNECT *conn) { + size_t sentlen; + uint8_t alert[2]; + if (!conn) { error_print(); return -1; } - switch (conn->protocol) { - case TLS_protocol_tlcp: - case TLS_protocol_tls12: - return tls12_tlcp_shutdown(conn); - case TLS_protocol_tls13: - return tls13_shutdown(conn); - default: + if (conn->protocol == TLS_protocol_tls13) { + return tls13_send_close_notify(conn); + } + + alert[0] = TLS_alert_level_warning; + alert[1] = TLS_alert_close_notify; + tls_trace("send Alert.close_notify\n"); + return tls_encrypt_send(conn, TLS_record_alert, alert, sizeof(alert), &sentlen); +} + +int tls_shutdown(TLS_CONNECT *conn) +{ + int ret; + uint8_t buf[1]; + size_t len; + + if (!conn) { error_print(); return -1; } + + if (conn->shutdown_state == TLS_state_shutdown_over) { + return 1; + } + if (!conn->shutdown_state) { + conn->shutdown_state = TLS_state_shutdown_send_close_notify; + } + + if (conn->shutdown_state == TLS_state_shutdown_send_close_notify) { + if ((ret = tls_send_close_notify(conn)) != 1) { + return ret; + } + if (conn->close_notify_received) { + conn->shutdown_state = TLS_state_shutdown_over; + return 1; + } + conn->shutdown_state = TLS_state_shutdown_recv_close_notify; + } + + if (conn->shutdown_state == TLS_state_shutdown_recv_close_notify) { + tls_trace("recv Alert.close_notify\n"); + ret = tls_recv(conn, buf, sizeof(buf), &len); + if (ret == 0 && conn->close_notify_received) { + conn->shutdown_state = TLS_state_shutdown_over; + return 1; + } + if (ret == TLS_ERROR_RECV_AGAIN || ret == TLS_ERROR_SEND_AGAIN) { + return ret; + } + if (ret == TLS_ERROR_TCP_CLOSED) { + tls_trace("Connection closed by remote without close_notify\n"); + return ret; + } + error_print(); + return -1; + } + + error_print(); + return -1; } int tls_authorities_from_certs(uint8_t *names, size_t *nameslen, size_t maxlen, const uint8_t *certs, size_t certslen) diff --git a/src/tls12.c b/src/tls12.c index 2efb145b..6743aaf8 100644 --- a/src/tls12.c +++ b/src/tls12.c @@ -303,6 +303,9 @@ int tls_send_record(TLS_CONNECT *conn) error_print(); return -1; } + } else if (n == 0) { + error_print(); + return TLS_ERROR_TCP_CLOSED; } conn->record_offset += n; left -= n; diff --git a/src/tls13.c b/src/tls13.c index da620a08..d013be70 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -1147,13 +1147,17 @@ int tls13_do_recv(TLS_CONNECT *conn) case TLS_state_recv_record_header: while (conn->recordlen) { - if ((n = tls_socket_recv(conn->sock, conn->record + conn->record_offset, conn->recordlen, 0)) <= 0) { + n = tls_socket_recv(conn->sock, conn->record + conn->record_offset, conn->recordlen, 0); + if (n < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) { return TLS_ERROR_RECV_AGAIN; } else { error_print(); return -1; } + } else if (n == 0) { + error_print(); + return TLS_ERROR_TCP_CLOSED; } conn->recordlen -= n; conn->record_offset += n; @@ -1171,13 +1175,17 @@ int tls13_do_recv(TLS_CONNECT *conn) case TLS_state_recv_record_data: while (conn->recordlen) { - if ((n = tls_socket_recv(conn->sock, conn->record + conn->record_offset, conn->recordlen, 0)) <= 0) { + n = tls_socket_recv(conn->sock, conn->record + conn->record_offset, conn->recordlen, 0); + if (n < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) { return TLS_ERROR_RECV_AGAIN; } else { error_print(); return -1; } + } else if (n == 0) { + error_print(); + return TLS_ERROR_TCP_CLOSED; } conn->recordlen -= n; conn->record_offset += n; @@ -1310,6 +1318,16 @@ int tls13_do_recv(TLS_CONNECT *conn) error_print(); return -1; } + if (alert_description == TLS_alert_close_notify) { + tls_trace("recv Alert.close_notify\n"); + conn->close_notify_received = 1; + conn->data = NULL; + conn->datalen = 0; + conn->record_offset = 0; + conn->recordlen = 0; + conn->plain_recordlen = 0; + return 0; + } } return -1; @@ -1331,7 +1349,7 @@ int tls13_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen) if (conn->datalen == 0) { int ret; if ((ret = tls13_do_recv(conn)) != 1) { - if (ret != TLS_ERROR_RECV_AGAIN && ret != TLS_ERROR_SEND_AGAIN) { + if (ret != 0 && ret != TLS_ERROR_RECV_AGAIN && ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; diff --git a/tools/tlcp_client.c b/tools/tlcp_client.c index 4c8079b9..84ac6ade 100644 --- a/tools/tlcp_client.c +++ b/tools/tlcp_client.c @@ -101,6 +101,67 @@ static int do_handshake_select(TLS_CONNECT *conn) } } +static int do_shutdown_select(TLS_CONNECT *conn) +{ + int ret; + fd_set rfds; + fd_set wfds; + + for (;;) { + ret = tls_shutdown(conn); + if (ret == 1) { + return 1; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } +} + +static int do_send_select(TLS_CONNECT *conn, const uint8_t *buf, size_t len) +{ + int ret; + size_t offset = 0; + fd_set rfds; + fd_set wfds; + + while (offset < len) { + size_t sentlen = 0; + + ret = tls_send(conn, buf + offset, len - offset, &sentlen); + if (ret == 1) { + offset += sentlen; + continue; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } + return 1; +} + int tlcp_client_main(int argc, char *argv[]) { int ret = -1; @@ -396,7 +457,7 @@ bad: snprintf(buf, sizeof(buf), "GET %s HTTP/1.1\r\nHost: %s\r\n\r\n", get, host); - if (tls_send(&conn, (uint8_t *)buf, strlen(buf), &len) != 1) { + if (do_send_select(&conn, (uint8_t *)buf, strlen(buf)) != 1) { fprintf(stderr, "%s: send error\n", prog); goto end; } @@ -419,6 +480,8 @@ bad: fflush(stdout); } else if (rv == 0) { fprintf(stderr, "%s: TLCP connection is closed by remote host\n", prog); + do_shutdown_select(&conn); + ret = 0; goto end; } else if (rv == -EAGAIN || rv == TLS_ERROR_RECV_AGAIN @@ -436,7 +499,7 @@ bad: fprintf(stderr, "%s: select error\n", prog); goto end; } else if (sel == 0) { - tls_shutdown(&conn); + do_shutdown_select(&conn); ret = 0; goto end; } @@ -458,7 +521,7 @@ bad: if (read_stdin) { #ifdef WIN32 if (fgets(buf, sizeof(buf), stdin)) { - if (tls_send(&conn, (uint8_t *)buf, strlen(buf), &len) != 1) { + if (do_send_select(&conn, (uint8_t *)buf, strlen(buf)) != 1) { fprintf(stderr, "%s: send error\n", prog); goto end; } @@ -467,7 +530,9 @@ bad: fprintf(stderr, "%s: length of input line exceeds buffer size\n", prog); goto end; } - read_stdin = 0; + do_shutdown_select(&conn); + ret = 0; + goto end; } #else FD_SET(STDIN_FILENO, &fds); // in POSIX, first arg type is int @@ -483,7 +548,7 @@ bad: if (read_stdin && FD_ISSET(STDIN_FILENO, &fds)) { if (fgets(buf, sizeof(buf), stdin)) { - if (tls_send(&conn, (uint8_t *)buf, strlen(buf), &len) != 1) { + if (do_send_select(&conn, (uint8_t *)buf, strlen(buf)) != 1) { fprintf(stderr, "%s: send error\n", prog); goto end; } @@ -492,7 +557,9 @@ bad: fprintf(stderr, "%s: length of input line exceeds buffer size\n", prog); goto end; } - read_stdin = 0; + do_shutdown_select(&conn); + ret = 0; + goto end; } } #endif @@ -508,6 +575,8 @@ bad: fflush(stdout); } else if (rv == 0) { fprintf(stderr, "Connection closed by remote host\n"); + do_shutdown_select(&conn); + ret = 0; goto end; } else if (rv == -EAGAIN || rv == TLS_ERROR_RECV_AGAIN diff --git a/tools/tlcp_server.c b/tools/tlcp_server.c index c06a7ed1..50630365 100644 --- a/tools/tlcp_server.c +++ b/tools/tlcp_server.c @@ -81,6 +81,67 @@ static int do_handshake_select(TLS_CONNECT *conn) } } +static int do_shutdown_select(TLS_CONNECT *conn) +{ + int ret; + fd_set rfds; + fd_set wfds; + + for (;;) { + ret = tls_shutdown(conn); + if (ret == 1) { + return 1; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } +} + +static int do_send_select(TLS_CONNECT *conn, const uint8_t *buf, size_t len) +{ + int ret; + size_t offset = 0; + fd_set rfds; + fd_set wfds; + + while (offset < len) { + size_t sentlen = 0; + + ret = tls_send(conn, buf + offset, len - offset, &sentlen); + if (ret == 1) { + offset += sentlen; + continue; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } + return 1; +} + int tlcp_server_main(int argc , char **argv) { int ret = 1; @@ -267,7 +328,6 @@ restart: for (;;) { int rv; - size_t sentlen; fd_set fds; do { @@ -286,8 +346,14 @@ restart: || rv == TLS_ERROR_SEND_AGAIN) { continue; } - if (rv < 0) fprintf(stderr, "%s: recv failure\n", prog); - else fprintf(stderr, "%s: Disconnected by remote\n", prog); + if (rv < 0) { + fprintf(stderr, "%s: recv failure\n", prog); + } else { + if (do_shutdown_select(&conn) != 1) { + fprintf(stderr, "%s: shutdown failure\n", prog); + } + fprintf(stderr, "%s: Disconnected by remote\n", prog); + } //tls_socket_close(conn.sock); // FIXME: tls_cleanup(&conn); @@ -295,7 +361,7 @@ restart: } } while (!len); - if (tls_send(&conn, (uint8_t *)buf, len, &sentlen) != 1) { + if (do_send_select(&conn, (uint8_t *)buf, len) != 1) { fprintf(stderr, "%s: send failure, close connection\n", prog); tls_socket_close(conn.sock); goto end; diff --git a/tools/tls12_client.c b/tools/tls12_client.c index 83238f1f..1714b9f8 100644 --- a/tools/tls12_client.c +++ b/tools/tls12_client.c @@ -96,6 +96,67 @@ static int do_handshake_select(TLS_CONNECT *conn) } } +static int do_shutdown_select(TLS_CONNECT *conn) +{ + int ret; + fd_set rfds; + fd_set wfds; + + for (;;) { + ret = tls_shutdown(conn); + if (ret == 1) { + return 1; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } +} + +static int do_send_select(TLS_CONNECT *conn, const uint8_t *buf, size_t len) +{ + int ret; + size_t offset = 0; + fd_set rfds; + fd_set wfds; + + while (offset < len) { + size_t sentlen = 0; + + ret = tls_send(conn, buf + offset, len - offset, &sentlen); + if (ret == 1) { + offset += sentlen; + continue; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } + return 1; +} + int tls12_client_main(int argc, char *argv[]) { int ret = -1; @@ -357,17 +418,15 @@ bad: for (;;) { fd_set fds; - size_t sentlen; - if (!fgets(send_buf, sizeof(send_buf), stdin)) { if (feof(stdin)) { - tls_shutdown(&conn); + do_shutdown_select(&conn); goto end; } else { continue; } } - if (tls_send(&conn, (uint8_t *)send_buf, strlen(send_buf), &sentlen) != 1) { + if (do_send_select(&conn, (uint8_t *)send_buf, strlen(send_buf)) != 1) { fprintf(stderr, "%s: send error\n", prog); goto end; } @@ -396,6 +455,9 @@ bad: || rv == TLS_ERROR_SEND_AGAIN) { break; } + if (rv == 0) { + do_shutdown_select(&conn); + } goto end; } fwrite(buf, 1, len, stdout); @@ -415,13 +477,13 @@ bad: if (!fgets(send_buf, sizeof(send_buf), stdin)) { if (feof(stdin)) { - tls_shutdown(&conn); + do_shutdown_select(&conn); goto end; } else { continue; } } - if (tls_send(&conn, (uint8_t *)send_buf, strlen(send_buf), &sentlen) != 1) { + if (do_send_select(&conn, (uint8_t *)send_buf, strlen(send_buf)) != 1) { fprintf(stderr, "%s: send error\n", prog); goto end; } diff --git a/tools/tls12_server.c b/tools/tls12_server.c index 822625aa..94de3b3d 100644 --- a/tools/tls12_server.c +++ b/tools/tls12_server.c @@ -87,6 +87,67 @@ static int do_handshake_select(TLS_CONNECT *conn) } } +static int do_shutdown_select(TLS_CONNECT *conn) +{ + int ret; + fd_set rfds; + fd_set wfds; + + for (;;) { + ret = tls_shutdown(conn); + if (ret == 1) { + return 1; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } +} + +static int do_send_select(TLS_CONNECT *conn, const uint8_t *buf, size_t len) +{ + int ret; + size_t offset = 0; + fd_set rfds; + fd_set wfds; + + while (offset < len) { + size_t sentlen = 0; + + ret = tls_send(conn, buf + offset, len - offset, &sentlen); + if (ret == 1) { + offset += sentlen; + continue; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } + return 1; +} + int tls12_server_main(int argc , char **argv) { int ret = 1; @@ -362,7 +423,6 @@ restart: for (;;) { int rv; - size_t sentlen; fd_set fds; do { @@ -381,8 +441,14 @@ restart: || rv == TLS_ERROR_SEND_AGAIN) { continue; } - if (rv < 0) fprintf(stderr, "%s: recv failure\n", prog); - else fprintf(stderr, "%s: Disconnected by remote\n", prog); + if (rv < 0) { + fprintf(stderr, "%s: recv failure\n", prog); + } else { + if (do_shutdown_select(&conn) != 1) { + fprintf(stderr, "%s: shutdown failure\n", prog); + } + fprintf(stderr, "%s: Disconnected by remote\n", prog); + } //tls_socket_close(conn.sock); // FIXME: tls_cleanup(&conn); @@ -390,7 +456,7 @@ restart: } } while (!len); - if (tls_send(&conn, (uint8_t *)buf, len, &sentlen) != 1) { + if (do_send_select(&conn, (uint8_t *)buf, len) != 1) { fprintf(stderr, "%s: send failure, close connection\n", prog); tls_socket_close(conn.sock); goto end; diff --git a/tools/tls13_client.c b/tools/tls13_client.c index 2c882f57..7a63bf69 100644 --- a/tools/tls13_client.c +++ b/tools/tls13_client.c @@ -70,6 +70,34 @@ static int do_handshake_select(TLS_CONNECT *conn) } } +static int do_shutdown_select(TLS_CONNECT *conn) +{ + int ret; + fd_set rfds; + fd_set wfds; + + for (;;) { + ret = tls_shutdown(conn); + if (ret == 1) { + return 1; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } +} + static const char *http_get = "GET / HTTP/1.1\r\n" "Hostname: aaa\r\n" @@ -725,6 +753,10 @@ bad: if ((ret = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len)) != 1) { if (ret == TLS_ERROR_SEND_AGAIN || ret == TLS_ERROR_RECV_AGAIN) { continue; + } else if (ret == 0) { + do_shutdown_select(&conn); + ret = 0; + goto end; } else { error_print(); goto end; @@ -742,9 +774,9 @@ bad: if (!fgets(send_buf, sizeof(send_buf), stdin)) { if (feof(stdin)) { - error_print(); fprintf(stderr, "client shutdown\n"); - tls_shutdown(&conn); + do_shutdown_select(&conn); + ret = 0; goto end; } else { continue; diff --git a/tools/tls13_server.c b/tools/tls13_server.c index 68225780..75de46fb 100644 --- a/tools/tls13_server.c +++ b/tools/tls13_server.c @@ -99,6 +99,34 @@ static int do_handshake_select(TLS_CONNECT *conn) } } +static int do_shutdown_select(TLS_CONNECT *conn) +{ + int ret; + fd_set rfds; + fd_set wfds; + + for (;;) { + ret = tls_shutdown(conn); + if (ret == 1) { + return 1; + } + FD_ZERO(&rfds); + FD_ZERO(&wfds); + if (ret == TLS_ERROR_RECV_AGAIN) { + FD_SET(conn->sock, &rfds); + } else if (ret == TLS_ERROR_SEND_AGAIN) { + FD_SET(conn->sock, &wfds); + } else { + error_print(); + return -1; + } + if (select((int)(conn->sock + 1), &rfds, &wfds, NULL, NULL) < 0) { + error_print(); + return -1; + } + } +} + int tls13_server_main(int argc , char **argv) { @@ -630,6 +658,9 @@ bad: if ((ret = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len)) != 1) { if (ret == TLS_ERROR_SEND_AGAIN || ret == TLS_ERROR_RECV_AGAIN) { continue; + } else if (ret == 0) { + do_shutdown_select(&conn); + goto end; } error_print(); goto end;