From 801d896d5ad819b10421f2a3959f57474ae41e28 Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Fri, 29 Jul 2022 17:13:10 +0800 Subject: [PATCH] Fix tls_do_recv bug and update SSL clients --- include/gmssl/tls.h | 7 +-- src/tls.c | 4 +- src/tls12.c | 4 +- src/tls13.c | 106 ++++++++++++++++++++++++++++++++++++------- tools/tlcp_client.c | 25 +++++----- tools/tls12_client.c | 67 ++++++++++++++++++--------- tools/tls13_client.c | 67 ++++++++++++++++++--------- tools/tls13_server.c | 2 +- 8 files changed, 202 insertions(+), 80 deletions(-) diff --git a/include/gmssl/tls.h b/include/gmssl/tls.h index aec93a05..f36bc396 100644 --- a/include/gmssl/tls.h +++ b/include/gmssl/tls.h @@ -789,7 +789,8 @@ typedef struct { uint8_t record[TLS_MAX_RECORD_SIZE]; // 其实这个就不太对了,还是应该有一个完整的密文记录 - uint8_t data[TLS_MAX_PLAINTEXT_SIZE]; + uint8_t databuf[TLS_MAX_PLAINTEXT_SIZE]; + uint8_t *data; size_t datalen; int cipher_suite; @@ -846,8 +847,8 @@ int tls13_do_accept(TLS_CONNECT *conn); int tls_send_alert(TLS_CONNECT *conn, int alert); int tls_send_warning(TLS_CONNECT *conn, int alert); -int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t padding_len); -int tls13_recv(TLS_CONNECT *conn, uint8_t *data, size_t *datalen); +int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t *sentlen); +int tls13_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen); int tls13_connect(TLS_CONNECT *conn, const char *hostname, int port, FILE *server_cacerts_fp, diff --git a/src/tls.c b/src/tls.c index 4616f8ae..02e47559 100644 --- a/src/tls.c +++ b/src/tls.c @@ -1775,10 +1775,11 @@ int tls_do_recv(TLS_CONNECT *conn) tls_record_trace(stderr, record, recordlen, 0, 0); if (tls_cbc_decrypt(hmac_ctx, dec_key, seq_num, record, tls_record_data(record), tls_record_data_length(record), - conn->data, &conn->datalen) != 1) { + conn->databuf, &conn->datalen) != 1) { error_print(); return -1; } + conn->data = conn->databuf; tls_seq_num_incr(seq_num); tls_record_set_data(record, conn->data, conn->datalen); @@ -1802,6 +1803,7 @@ int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen) } *recvlen = outlen <= conn->datalen ? outlen : conn->datalen; memcpy(out, conn->data, *recvlen); + conn->data += *recvlen; conn->datalen -= *recvlen; return 1; } diff --git a/src/tls12.c b/src/tls12.c index 35cf90c6..34672785 100644 --- a/src/tls12.c +++ b/src/tls12.c @@ -684,7 +684,7 @@ int tls12_do_connect(TLS_CONNECT *conn) tls_send_alert(conn, TLS_alert_decrypt_error); goto end; } - tls_trace("Connection established!\n"); + printf("Connection established!\n"); conn->protocol = conn->protocol; @@ -1117,7 +1117,7 @@ int tls12_do_accept(TLS_CONNECT *conn) conn->protocol = conn->protocol; - tls_trace("Connection Established!\n\n"); + printf("Connection Established!\n\n"); ret = 1; end: diff --git a/src/tls13.c b/src/tls13.c index ed4eeb65..d29fae2e 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -245,13 +245,14 @@ int tls13_record_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], return 1; } -int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t padding_len) +int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t *sentlen) { const BLOCK_CIPHER_KEY *key; const uint8_t *iv; uint8_t *seq_num; uint8_t *record = conn->record; size_t recordlen; + size_t padding_len = 0; //FIXME: 在conn中设置是否加随机填充,及设置该值 tls_trace("send {ApplicationData}\n"); @@ -284,9 +285,12 @@ int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t pa tls_seq_num_incr(seq_num); + *sentlen = datalen; + return 1; } +/* int tls13_recv(TLS_CONNECT *conn, uint8_t *data, size_t *datalen) { int record_type; @@ -303,8 +307,6 @@ int tls13_recv(TLS_CONNECT *conn, uint8_t *data, size_t *datalen) key = &conn->server_write_key; iv = conn->server_write_iv; seq_num = conn->server_seq_num; - - } else { key = &conn->client_write_key; iv = conn->client_write_iv; @@ -336,6 +338,77 @@ int tls13_recv(TLS_CONNECT *conn, uint8_t *data, size_t *datalen) } return 1; } +*/ + +int tls13_do_recv(TLS_CONNECT *conn) +{ + int ret; + const BLOCK_CIPHER_KEY *key; + const uint8_t *iv; + uint8_t *seq_num; + uint8_t *record = conn->record; + size_t recordlen; + int record_type; + + if (conn->is_client) { + key = &conn->server_write_key; + iv = conn->server_write_iv; + seq_num = conn->server_seq_num; + } else { + key = &conn->client_write_key; + iv = conn->client_write_iv; + seq_num = conn->client_seq_num; + } + + tls_trace("recv ApplicationData\n"); + if ((ret = tls_record_recv(record, &recordlen, conn->sock)) != 1) { + if (ret < 0) error_print(); + return ret; + } + tls_record_trace(stderr, record, recordlen, 0, 0); + // TODO: 是否需要检查record_type? record[0] != TLS_record_application_data + + if (tls13_gcm_decrypt(key, iv, + seq_num, record + 5, recordlen - 5, + &record_type, conn->databuf, &conn->datalen) != 1) { + error_print(); + return -1; + } + conn->data = conn->databuf; + tls_seq_num_incr(seq_num); + + tls_record_set_data(record, conn->data, conn->datalen); + tls_trace("decrypt ApplicationData\n"); + tls_record_trace(stderr, record, tls_record_length(record), 0, 0); + + + if (record_type != TLS_record_application_data) { + error_print(); + return -1; + } + return 1; +} + +int tls13_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen) +{ + if (!conn || !out || !outlen || !recvlen) { + error_print(); + return -1; + } + if (conn->datalen == 0) { + int ret; + if ((ret = tls13_do_recv(conn)) != 1) { + if (ret) error_print(); + return ret; + } + } + *recvlen = outlen <= conn->datalen ? outlen : conn->datalen; + memcpy(out, conn->data, *recvlen); + conn->data += *recvlen; + conn->datalen -= *recvlen; + return 1; +} + /* @@ -1551,13 +1624,13 @@ int tls13_do_connect(TLS_CONNECT *conn) tls13_hkdf_expand_label(digest, server_handshake_traffic_secret, "iv", NULL, 0, 12, conn->server_write_iv); block_cipher_set_encrypt_key(&conn->client_write_key, cipher, client_write_key); block_cipher_set_encrypt_key(&conn->server_write_key, cipher, server_write_key); - + /* format_bytes(stderr, 0, 4, "client_write_key", client_write_key, 16); format_bytes(stderr, 0, 4, "server_write_key", server_write_key, 16); format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12); format_bytes(stderr, 0, 4, "server_write_iv", conn->server_write_iv, 12); format_print(stderr, 0, 0, "\n"); - + */ // recv {EncryptedExtensions} printf("recv {EncryptedExtensions}\n"); @@ -1736,11 +1809,12 @@ int tls13_do_connect(TLS_CONNECT *conn) tls13_hkdf_expand_label(digest, server_application_traffic_secret, "key", NULL, 0, 16, server_write_key); block_cipher_set_encrypt_key(&conn->server_write_key, cipher, server_write_key); tls13_hkdf_expand_label(digest, server_application_traffic_secret, "iv", NULL, 0, 12, conn->server_write_iv); - + /* format_print(stderr, 0, 0, "update server secrets\n"); format_bytes(stderr, 0, 4, "server_write_key", server_write_key, 16); format_bytes(stderr, 0, 4, "server_write_iv", conn->server_write_iv, 12); format_print(stderr, 0, 0, "\n"); + */ if (conn->client_certs_len) { int client_sign_algor; @@ -1831,13 +1905,13 @@ int tls13_do_connect(TLS_CONNECT *conn) tls13_hkdf_expand_label(digest, client_application_traffic_secret, "key", NULL, 0, 16, client_write_key); tls13_hkdf_expand_label(digest, client_application_traffic_secret, "iv", NULL, 0, 12, conn->client_write_iv); block_cipher_set_encrypt_key(&conn->client_write_key, cipher, client_write_key); - + /* format_print(stderr, 0, 0, "update client secrets\n"); format_bytes(stderr, 0, 4, "client_write_key", client_write_key, 16); format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12); format_print(stderr, 0, 0, "\n"); - - tls_trace("++++ Connection established\n"); + */ + printf("++++ Connection established\n"); end: return 1; @@ -2003,14 +2077,14 @@ int tls13_do_accept(TLS_CONNECT *conn) tls13_hkdf_expand_label(digest, server_handshake_traffic_secret, "key", NULL, 0, 16, server_write_key); block_cipher_set_encrypt_key(&conn->server_write_key, cipher, server_write_key); tls13_hkdf_expand_label(digest, server_handshake_traffic_secret, "iv", NULL, 0, 12, conn->server_write_iv); - + /* format_print(stderr, 0, 0, "generate handshake secrets\n"); format_bytes(stderr, 0, 4, "client_write_key", client_write_key, 16); format_bytes(stderr, 0, 4, "server_write_key", server_write_key, 16); format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12); format_bytes(stderr, 0, 4, "server_write_iv", conn->server_write_iv, 12); format_print(stderr, 0, 0, "\n"); - + */ // 3. Send {EncryptedExtensions} tls_trace("send {EncryptedExtensions}\n"); @@ -2146,12 +2220,12 @@ int tls13_do_accept(TLS_CONNECT *conn) tls13_hkdf_expand_label(digest, server_application_traffic_secret, "key", NULL, 0, 16, server_write_key); block_cipher_set_encrypt_key(&conn->server_write_key, cipher, server_write_key); tls13_hkdf_expand_label(digest, server_application_traffic_secret, "iv", NULL, 0, 12, conn->server_write_iv); - + /* format_print(stderr, 0, 0, "update server secrets\n"); format_bytes(stderr, 0, 4, "server_write_key", server_write_key, 16); format_bytes(stderr, 0, 4, "server_write_iv", conn->server_write_iv, 12); format_print(stderr, 0, 0, "\n"); - + */ // Recv Client {Certificate*} if (client_verify) { @@ -2272,13 +2346,13 @@ int tls13_do_accept(TLS_CONNECT *conn) tls13_hkdf_expand_label(digest, client_application_traffic_secret, "key", NULL, 0, 16, client_write_key); tls13_hkdf_expand_label(digest, client_application_traffic_secret, "iv", NULL, 0, 12, conn->client_write_iv); block_cipher_set_encrypt_key(&conn->client_write_key, cipher, client_write_key); - + /* format_print(stderr, 0, 0, "update client secrets\n"); format_bytes(stderr, 0, 4, "client_write_key", client_write_key, 16); format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12); format_print(stderr, 0, 0, "\n"); - - tls_trace("Connection Established!\n\n"); + */ + printf("Connection Established!\n\n"); end: return 1; diff --git a/tools/tlcp_client.c b/tools/tlcp_client.c index d04e7477..5a6c52fd 100644 --- a/tools/tlcp_client.c +++ b/tools/tlcp_client.c @@ -50,7 +50,6 @@ #include #include #include - #include #include #include @@ -85,7 +84,7 @@ int tlcp_client_main(int argc, char *argv[]) int sock; TLS_CTX ctx; TLS_CONNECT conn; - char buf[1000] = {0}; + char buf[10] = {0}; size_t len = sizeof(buf); char send_buf[1024] = {0}; size_t send_len; @@ -180,16 +179,15 @@ bad: goto end; } - fd_set fds; for (;;) { + fd_set fds; size_t sentlen; - FD_ZERO(&fds); FD_SET(conn.sock, &fds); - FD_SET(0, &fds); + FD_SET(STDIN_FILENO, &fds); if (select(conn.sock + 1, &fds, NULL, NULL, NULL) < 0) { fprintf(stderr, "%s: select failed\n", prog); @@ -197,18 +195,17 @@ bad: } if (FD_ISSET(conn.sock, &fds)) { - memset(buf, 0, sizeof(buf)); - len = sizeof(buf); - - if (tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len) != 1) { - goto end; + for (;;) { + memset(buf, 0, sizeof(buf)); + if (tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len) != 1) { + goto end; + } + fwrite(buf, 1, len, stdout); + fflush(stdout); } - buf[len] = 0; - printf("%s\n", buf); - } - if (FD_ISSET(0, &fds)) { + if (FD_ISSET(STDIN_FILENO, &fds)) { memset(send_buf, 0, sizeof(send_buf)); if (!fgets(send_buf, sizeof(send_buf), stdin)) { diff --git a/tools/tls12_client.c b/tools/tls12_client.c index 8e031185..12cec233 100644 --- a/tools/tls12_client.c +++ b/tools/tls12_client.c @@ -87,7 +87,7 @@ int tls12_client_main(int argc, char *argv[]) int sock; TLS_CTX ctx; TLS_CONNECT conn; - char buf[100] = {0}; + char buf[1024] = {0}; size_t len = sizeof(buf); char send_buf[1024] = {0}; size_t send_len; @@ -158,12 +158,23 @@ bad: } if (tls_ctx_init(&ctx, TLS_protocol_tls12, TLS_client_mode) != 1 - || tls_ctx_set_cipher_suites(&ctx, client_ciphers, sizeof(client_ciphers)/sizeof(client_ciphers[0])) != 1 - || tls_ctx_set_ca_certificates(&ctx, cacertfile, TLS_DEFAULT_VERIFY_DEPTH) != 1 - || tls_ctx_set_certificate_and_key(&ctx, certfile, keyfile, pass) != 1) { + || tls_ctx_set_cipher_suites(&ctx, client_ciphers, sizeof(client_ciphers)/sizeof(client_ciphers[0])) != 1) { fprintf(stderr, "%s: context init error\n", prog); goto end; } + if (cacertfile) { + if (tls_ctx_set_ca_certificates(&ctx, cacertfile, TLS_DEFAULT_VERIFY_DEPTH) != 1) { + fprintf(stderr, "%s: context init error\n", prog); + goto end; + } + } + if (certfile) { + if (tls_ctx_set_certificate_and_key(&ctx, certfile, keyfile, pass) != 1) { + fprintf(stderr, "%s: context init error\n", prog); + goto end; + } + } + if (tls_init(&conn, &ctx) != 1 || tls_set_socket(&conn, sock) != 1 || tls_do_handshake(&conn) != 1) { @@ -172,30 +183,44 @@ bad: } for (;;) { + fd_set fds; size_t sentlen; - memset(send_buf, 0, sizeof(send_buf)); - if (!fgets(send_buf, sizeof(send_buf), stdin)) { - if (feof(stdin)) { - tls_shutdown(&conn); - goto end; - } else { - continue; - } - } - if (tls_send(&conn, (uint8_t *)send_buf, strlen(send_buf), &sentlen) != 1) { - fprintf(stderr, "%s: send error\n", prog); + FD_ZERO(&fds); + FD_SET(conn.sock, &fds); + FD_SET(STDIN_FILENO, &fds); + + if (select(conn.sock + 1, &fds, NULL, NULL, NULL) < 0) { + fprintf(stderr, "%s: select failed\n", prog); goto end; } - { - memset(buf, 0, sizeof(buf)); - len = sizeof(buf); - if (tls_recv(&conn, (uint8_t *)buf, sizeof(len), &len) != 1) { + if (FD_ISSET(conn.sock, &fds)) { + for (;;) { + memset(buf, 0, sizeof(buf)); + if (tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len) != 1) { + goto end; + } + fwrite(buf, 1, len, stdout); + fflush(stdout); + } + + } + if (FD_ISSET(STDIN_FILENO, &fds)) { + memset(send_buf, 0, sizeof(send_buf)); + + if (!fgets(send_buf, sizeof(send_buf), stdin)) { + if (feof(stdin)) { + tls_shutdown(&conn); + goto end; + } else { + continue; + } + } + if (tls_send(&conn, (uint8_t *)send_buf, strlen(send_buf), &sentlen) != 1) { + fprintf(stderr, "%s: send error\n", prog); goto end; } - buf[len] = 0; - printf("%s\n", buf); } } diff --git a/tools/tls13_client.c b/tools/tls13_client.c index 76f3634e..7add727b 100644 --- a/tools/tls13_client.c +++ b/tools/tls13_client.c @@ -87,7 +87,7 @@ int tls13_client_main(int argc, char *argv[]) int sock; TLS_CTX ctx; TLS_CONNECT conn; - char buf[100] = {0}; + char buf[1024] = {0}; size_t len = sizeof(buf); char send_buf[1024] = {0}; size_t send_len; @@ -158,12 +158,22 @@ bad: } if (tls_ctx_init(&ctx, TLS_protocol_tls13, TLS_client_mode) != 1 - || tls_ctx_set_cipher_suites(&ctx, client_ciphers, sizeof(client_ciphers)/sizeof(client_ciphers[0])) != 1 - || tls_ctx_set_ca_certificates(&ctx, cacertfile, TLS_DEFAULT_VERIFY_DEPTH) != 1 - || tls_ctx_set_certificate_and_key(&ctx, certfile, keyfile, pass) != 1) { + || tls_ctx_set_cipher_suites(&ctx, client_ciphers, sizeof(client_ciphers)/sizeof(client_ciphers[0])) != 1) { fprintf(stderr, "%s: context init error\n", prog); goto end; } + if (cacertfile) { + if (tls_ctx_set_ca_certificates(&ctx, cacertfile, TLS_DEFAULT_VERIFY_DEPTH) != 1) { + fprintf(stderr, "%s: context init error\n", prog); + goto end; + } + } + if (certfile) { + if (tls_ctx_set_certificate_and_key(&ctx, certfile, keyfile, pass) != 1) { + fprintf(stderr, "%s: context init error\n", prog); + goto end; + } + } if (tls_init(&conn, &ctx) != 1 || tls_set_socket(&conn, sock) != 1 || tls_do_handshake(&conn) != 1) { @@ -172,34 +182,47 @@ bad: } for (;;) { + fd_set fds; size_t sentlen; - memset(send_buf, 0, sizeof(send_buf)); - if (!fgets(send_buf, sizeof(send_buf), stdin)) { - if (feof(stdin)) { - tls_shutdown(&conn); - goto end; - } else { - continue; - } - } - if (tls13_send(&conn, (uint8_t *)send_buf, strlen(send_buf), 0 /*&sentlen*/) != 1) { - fprintf(stderr, "%s: send error\n", prog); + FD_ZERO(&fds); + FD_SET(conn.sock, &fds); + FD_SET(STDIN_FILENO, &fds); + + if (select(conn.sock + 1, &fds, NULL, NULL, NULL) < 0) { + fprintf(stderr, "%s: select failed\n", prog); goto end; } - { - memset(buf, 0, sizeof(buf)); - len = sizeof(buf); - if (tls13_recv(&conn, (uint8_t *)buf, /*sizeof(len),*/ &len) != 1) { + if (FD_ISSET(conn.sock, &fds)) { + for (;;) { + memset(buf, 0, sizeof(buf)); + if (tls13_recv(&conn, (uint8_t *)buf, sizeof(buf), &len) != 1) { + goto end; + } + fwrite(buf, 1, len, stdout); + fflush(stdout); + } + + } + if (FD_ISSET(STDIN_FILENO, &fds)) { + memset(send_buf, 0, sizeof(send_buf)); + + if (!fgets(send_buf, sizeof(send_buf), stdin)) { + if (feof(stdin)) { + tls_shutdown(&conn); + goto end; + } else { + continue; + } + } + if (tls13_send(&conn, (uint8_t *)send_buf, strlen(send_buf), &sentlen) != 1) { + fprintf(stderr, "%s: send error\n", prog); goto end; } - buf[len] = 0; - printf("%s\n", buf); } } - end: close(sock); tls_ctx_cleanup(&ctx); diff --git a/tools/tls13_server.c b/tools/tls13_server.c index 7619a43a..f2ad8637 100644 --- a/tools/tls13_server.c +++ b/tools/tls13_server.c @@ -199,7 +199,7 @@ restart: do { len = sizeof(buf); - if ((rv = tls13_recv(&conn, (uint8_t *)buf, /*sizeof(buf),*/ &len)) != 1) { + if ((rv = tls13_recv(&conn, (uint8_t *)buf, sizeof(buf), &len)) != 1) { if (rv < 0) fprintf(stderr, "%s: recv failure\n", prog); else fprintf(stderr, "%s: Disconnected by remote\n", prog);