From 8a90eb9c91a8220bd12ac1469253910c61bc93f0 Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Fri, 12 Jun 2026 09:08:49 +0800 Subject: [PATCH] Update TLCP to support SNI --- include/gmssl/tls.h | 3 + src/tlcp.c | 315 +++++++++++++++++++++++++++++++++++++++++--- src/tls.c | 94 +++++++++---- tools/tlcp_client.c | 12 +- tools/tlcp_server.c | 82 +++++++++--- 5 files changed, 441 insertions(+), 65 deletions(-) diff --git a/include/gmssl/tls.h b/include/gmssl/tls.h index fbd2ee56..b6ff82d4 100644 --- a/include/gmssl/tls.h +++ b/include/gmssl/tls.h @@ -920,6 +920,9 @@ int tls_ctx_set_signature_algorithms(TLS_CTX *ctx, const int *sig_algs, size_t s int tls_ctx_set_ca_certificates(TLS_CTX *ctx, const char *cacertsfile, int depth); int tls_ctx_set_certificate_and_key(TLS_CTX *ctx, const char *chainfile, const char *keyfile, const char *keypass); +int tlcp_ctx_add_server_certificate_and_keys(TLS_CTX *ctx, const char *chainfile, + const char *signkeyfile, const char *signkeypass, + const char *kenckeyfile, const char *kenckeypass); int tls_ctx_set_tlcp_server_certificate_and_keys(TLS_CTX *ctx, const char *chainfile, const char *signkeyfile, const char *signkeypass, const char *kenckeyfile, const char *kenckeypass); diff --git a/src/tlcp.c b/src/tlcp.c index ce01e406..7b1495bf 100644 --- a/src/tlcp.c +++ b/src/tlcp.c @@ -498,6 +498,8 @@ int tlcp_recv_server_certificate(TLS_CONNECT *conn) { int ret; int verify_result; + const uint8_t *server_cert; + size_t server_cert_len; tls_trace("recv server Certificate\n"); @@ -538,9 +540,35 @@ int tlcp_recv_server_certificate(TLS_CONNECT *conn) tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } + if (conn->peer_cert_chain_len > sizeof(conn->server_certs)) { + error_print(); + tls_send_alert(conn, TLS_alert_bad_certificate); + return -1; + } + memcpy(conn->server_certs, conn->peer_cert_chain, conn->peer_cert_chain_len); + conn->server_certs_len = conn->peer_cert_chain_len; + + if (x509_certs_get_cert_by_index(conn->peer_cert_chain, conn->peer_cert_chain_len, + 0, &server_cert, &server_cert_len) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_bad_certificate); + return -1; + } + if (conn->server_name) { + if ((ret = tls_cert_match_server_name(server_cert, server_cert_len, + conn->host_name, conn->host_name_len)) < 0) { + error_print(); + tls_send_alert(conn, TLS_alert_bad_certificate); + return -1; + } else if (ret == 0) { + error_print(); + tls_send_alert(conn, TLS_alert_bad_certificate); + return -1; + } + } if (conn->ctx->cacertslen) { - if (x509_certs_verify_tlcp(conn->server_certs, conn->server_certs_len, X509_cert_chain_server, + if (x509_certs_verify_tlcp(conn->peer_cert_chain, conn->peer_cert_chain_len, X509_cert_chain_server, conn->ctx->cacerts, conn->ctx->cacertslen, conn->ctx->verify_depth, &verify_result) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); @@ -602,7 +630,7 @@ int tlcp_recv_server_key_exchange(TLS_CONNECT *conn) // verify ServerKeyExchange if (x509_certs_get_cert_by_index(conn->peer_cert_chain, conn->peer_cert_chain_len, 0, &cp, &len) != 1 || x509_cert_get_subject_public_key(cp, len, &server_sign_key) != 1 - || x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, 1, &server_enc_cert, &server_enc_cert_len) != 1 + || x509_certs_get_cert_by_index(conn->peer_cert_chain, conn->peer_cert_chain_len, 1, &server_enc_cert, &server_enc_cert_len) != 1 || x509_cert_get_subject_public_key(server_enc_cert, server_enc_cert_len, &conn->server_enc_key) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); @@ -917,6 +945,8 @@ int tlcp_send_client_finished(TLS_CONNECT *conn) return -1; } + tls_record_set_protocol(conn->plain_record, conn->protocol); + if (tls_record_set_handshake_finished(conn->plain_record, &conn->plain_recordlen, verify_data, sizeof(verify_data)) != 1) { error_print(); @@ -926,8 +956,11 @@ int tlcp_send_client_finished(TLS_CONNECT *conn) tlcp_record_print(stderr, 0, 0, conn->plain_record, conn->plain_recordlen); - //sm3_update(&conn->sm3_ctx, conn->plain_record + 5, conn->plain_recordlen - 5); - //tlcp_handshake_digest_print(stderr, 0, 0, "client Finished", &conn->sm3_ctx); + if (digest_update(&conn->dgst_ctx, conn->plain_record + 5, conn->plain_recordlen - 5) != 1) { + error_print(); + return -1; + } + tls_handshake_digest_print(stderr, 0, 0, "client Finished", &conn->dgst_ctx); if (tls_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_key, @@ -1054,6 +1087,74 @@ int tlcp_recv_server_finished(TLS_CONNECT *conn) // Server +static int tlcp_cert_chains_select(TLS_CONNECT *conn, + const uint8_t *host_name, size_t host_name_len) +{ + const uint8_t *cert_chains; + size_t cert_chains_len; + size_t cert_chain_idx; + + if (!conn || !conn->ctx || !conn->ctx->cert_chains_len) { + error_print(); + return -1; + } + + cert_chains = conn->ctx->cert_chains; + cert_chains_len = conn->ctx->cert_chains_len; + + for (cert_chain_idx = 1; cert_chains_len; cert_chain_idx++) { + const uint8_t *cert_chain; + size_t cert_chain_len; + const uint8_t *sign_cert; + size_t sign_cert_len; + const uint8_t *enc_cert; + size_t enc_cert_len; + int ret; + + if (tls_uint24array_from_bytes(&cert_chain, &cert_chain_len, + &cert_chains, &cert_chains_len) != 1) { + error_print(); + return -1; + } + if (x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 0, &sign_cert, &sign_cert_len) != 1 + || x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 1, &enc_cert, &enc_cert_len) != 1) { + error_print(); + return -1; + } + + if (host_name && host_name_len) { + if ((ret = tls_cert_match_server_name(sign_cert, sign_cert_len, host_name, host_name_len)) < 0) { + error_print(); + return -1; + } else if (ret == 0) { + continue; + } + } + + if (cert_chain_len > sizeof(conn->server_certs)) { + error_print(); + return -1; + } + + conn->cert_chain = cert_chain; + conn->cert_chain_len = cert_chain_len; + conn->cert_chain_idx = cert_chain_idx; + conn->sign_key = conn->ctx->x509_keys[cert_chain_idx - 1]; + conn->kenc_key = conn->ctx->enc_keys[cert_chain_idx - 1]; + memcpy(conn->server_certs, cert_chain, cert_chain_len); + conn->server_certs_len = cert_chain_len; + conn->signature_algorithms[0] = TLS_sig_sm2sig_sm3; + conn->signature_algorithms_cnt = 1; + return 1; + } + + conn->cert_chain = NULL; + conn->cert_chain_len = 0; + conn->cert_chain_idx = 0; + warning_print(); + return 0; +} + int tlcp_recv_client_hello(TLS_CONNECT *conn) { int ret; @@ -1068,19 +1169,21 @@ int tlcp_recv_client_hello(TLS_CONNECT *conn) // extensions const uint8_t *server_name = NULL; - size_t server_name_len; + size_t server_name_len = 0; const uint8_t *trusted_ca_keys = NULL; - size_t trusted_ca_keys_len; + size_t trusted_ca_keys_len = 0; const uint8_t *status_request = NULL; - size_t status_request_len; + size_t status_request_len = 0; const uint8_t *supported_groups = NULL; - size_t supported_groups_len; + size_t supported_groups_len = 0; const uint8_t *signature_algorithms = NULL; - size_t signature_algorithms_len; + size_t signature_algorithms_len = 0; const uint8_t *application_layer_protocol_negotiation = NULL; - size_t application_layer_protocol_negotiation_len; + size_t application_layer_protocol_negotiation_len = 0; const uint8_t *client_id = NULL; - size_t client_id_len; + size_t client_id_len = 0; + const uint8_t *host_name = NULL; + size_t host_name_len = 0; @@ -1245,6 +1348,39 @@ int tlcp_recv_client_hello(TLS_CONNECT *conn) } } + if (server_name) { + if (tls_server_name_from_bytes(&host_name, &host_name_len, server_name, server_name_len) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_decode_error); + return -1; + } + if (host_name_len > sizeof(conn->host_name)) { + error_print(); + tls_send_alert(conn, TLS_alert_illegal_parameter); + return -1; + } + memcpy(conn->host_name, host_name, host_name_len); + conn->host_name_len = host_name_len; + conn->server_name = 1; + } + + if ((ret = tlcp_cert_chains_select(conn, host_name, host_name_len)) < 0) { + error_print(); + tls_send_alert(conn, TLS_alert_decode_error); + return -1; + } else if (ret == 0) { + error_print(); + tls_send_alert(conn, TLS_alert_handshake_failure); + return -1; + } + + if (digest_init(&conn->dgst_ctx, DIGEST_sm3()) != 1 + || digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { + error_print(); + return -1; + } + tls_handshake_digest_print(stderr, 0, 0, "ClientHello", &conn->dgst_ctx); + //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); //tlcp_handshake_digest_print(stderr, 0, 0, "ClientHello", &conn->sm3_ctx); @@ -1336,6 +1472,12 @@ int tlcp_send_server_hello(TLS_CONNECT *conn) return -1; } tlcp_record_trace(stderr, conn->record, conn->recordlen, 0, 0); + + if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { + error_print(); + return -1; + } + tls_handshake_digest_print(stderr, 0, 0, "ServerHello", &conn->dgst_ctx); } if ((ret = tls_send_record(conn)) != 1) { @@ -1364,7 +1506,39 @@ int tlcp_send_server_certificate(TLS_CONNECT *conn) { int ret; - // 根据套件不同选择SM2证书还是SM9的 + tls_trace("send ServerCertificate\n"); + + if (conn->recordlen == 0) { + if (!conn->cert_chain || !conn->cert_chain_len) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + return -1; + } + if (tls_record_set_handshake_certificate(conn->record, &conn->recordlen, + conn->cert_chain, conn->cert_chain_len) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + return -1; + } + tlcp_record_print(stderr, 0, 0, conn->record, conn->recordlen); + + if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { + error_print(); + return -1; + } + tls_handshake_digest_print(stderr, 0, 0, "Certificate", &conn->dgst_ctx); + } + + if ((ret = tls_send_record(conn)) != 1) { + if (ret != TLS_ERROR_SEND_AGAIN) { + error_print(); + } + return ret; + } + + if (conn->client_certificate_verify) { + tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); + } return 1; } @@ -1386,8 +1560,13 @@ int tlcp_send_server_key_exchange(TLS_CONNECT *conn) tls_trace("send ServerKeyExchange\n"); if (conn->recordlen == 0) { + if (!conn->cert_chain || !conn->cert_chain_len || !conn->cert_chain_idx) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + return -1; + } - if (x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, 1, + if (x509_certs_get_cert_by_index(conn->cert_chain, conn->cert_chain_len, 1, &server_enc_cert, &server_enc_cert_len) != 1) { error_print(); return -1; @@ -1414,6 +1593,12 @@ int tlcp_send_server_key_exchange(TLS_CONNECT *conn) return -1; } tlcp_record_print(stderr, 0, 0, conn->record, conn->recordlen); + + if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { + error_print(); + return -1; + } + tls_handshake_digest_print(stderr, 0, 0, "ServerKeyExchange", &conn->dgst_ctx); } if ((ret = tls_send_record(conn)) != 1) { @@ -1423,9 +1608,6 @@ int tlcp_send_server_key_exchange(TLS_CONNECT *conn) return ret; } - ///sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); - //tlcp_handshake_digest_print(stderr, 0, 0, "ServerKeyExchange", &conn->sm3_ctx); - if (conn->client_certificate_verify) { tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); } @@ -1619,6 +1801,12 @@ int tlcp_recv_client_key_exchange(TLS_CONNECT *conn) tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } + if (!conn->cert_chain_idx || conn->kenc_key.algor != OID_ec_public_key + || conn->kenc_key.algor_param != OID_sm2) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + return -1; + } // FIXME: // 这里需要检查一下密钥的长度,因为输入的长度是确定的,因此输出的密文长度应该也是确定的 @@ -1634,8 +1822,11 @@ int tlcp_recv_client_key_exchange(TLS_CONNECT *conn) tls_send_alert(conn, TLS_alert_decrypt_error); return -1; } - //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); - //tlcp_handshake_digest_print(stderr, 0, 0, "ClientKeyExchange", &conn->sm3_ctx);i + if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { + error_print(); + return -1; + } + tls_handshake_digest_print(stderr, 0, 0, "ClientKeyExchange", &conn->dgst_ctx); if (tlcp_generate_keys(conn) != 1) { error_print(); @@ -1659,22 +1850,104 @@ int tlcp_recv_certificate_verify(TLS_CONNECT *conn) int tlcp_recv_client_finished(TLS_CONNECT *conn) { int ret; + const uint8_t *verify_data; + size_t verify_data_len; + uint8_t local_verify_data[12]; - if ((ret = tls_recv_client_finished(conn)) != 1) { + if (tls_compute_verify_data(conn->master_secret, "client finished", + &conn->dgst_ctx, local_verify_data) != 1) { error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + return -1; + } + + tls_trace("recv client {Finished}\n"); + + if ((ret = tls_recv_record(conn)) != 1) { + if (ret != TLS_ERROR_RECV_AGAIN) { + error_print(); + } return ret; } + if (tls_record_protocol(conn->record) != conn->protocol) { + error_print(); + tls_send_alert(conn, TLS_alert_unexpected_message); + return -1; + } + if (tls_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_key, + conn->client_seq_num, conn->record, conn->recordlen, + conn->plain_record, &conn->plain_recordlen) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_bad_record_mac); + return -1; + } + tls_seq_num_incr(conn->client_seq_num); + + tlcp_record_print(stderr, 0, 0, conn->plain_record, conn->plain_recordlen); + + if (tls_record_get_handshake_finished(conn->plain_record, &verify_data, &verify_data_len) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_unexpected_message); + return -1; + } + if (verify_data_len != sizeof(local_verify_data) + || memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) { + error_print(); + tls_send_alert(conn, TLS_alert_decrypt_error); + return -1; + } + + if (digest_update(&conn->dgst_ctx, conn->plain_record + 5, conn->plain_recordlen - 5) != 1) { + error_print(); + return -1; + } + tls_handshake_digest_print(stderr, 0, 0, "client Finished", &conn->dgst_ctx); + return 1; } int tlcp_send_server_finished(TLS_CONNECT *conn) { int ret; + uint8_t verify_data[12]; - if ((ret = tls_send_server_finished(conn)) != 1) { - error_print(); + if (conn->recordlen == 0) { + tls_trace("send server {Finished}\n"); + + if (tls_compute_verify_data(conn->master_secret, "server finished", + &conn->dgst_ctx, verify_data) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + return -1; + } + + tls_record_set_protocol(conn->plain_record, conn->protocol); + + if (tls_record_set_handshake_finished(conn->plain_record, &conn->plain_recordlen, + verify_data, sizeof(verify_data)) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + return -1; + } + tlcp_record_print(stderr, 0, 0, conn->plain_record, conn->plain_recordlen); + + if (tls_record_encrypt(&conn->server_write_mac_ctx, &conn->server_write_key, + conn->server_seq_num, conn->plain_record, conn->plain_recordlen, + conn->record, &conn->recordlen) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + return -1; + } + tls_seq_num_incr(conn->server_seq_num); + } + + if ((ret = tls_send_record(conn)) != 1) { + if (ret != TLS_ERROR_SEND_AGAIN) { + error_print(); + } return ret; } + return 1; } diff --git a/src/tls.c b/src/tls.c index c12bcd44..b249f8eb 100644 --- a/src/tls.c +++ b/src/tls.c @@ -804,7 +804,7 @@ int tls_record_set_handshake_client_hello(uint8_t *record, size_t *recordlen, tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len); if (exts) { size_t tmp_len = len; - if (protocol < TLS_protocol_tls12) { + if (protocol != TLS_protocol_tlcp && protocol < TLS_protocol_tls12) { error_print(); return -1; } @@ -940,7 +940,7 @@ int tls_record_set_handshake_server_hello(uint8_t *record, size_t *recordlen, tls_uint16_to_bytes((uint16_t)cipher_suite, &p, &len); tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len); if (exts) { - if (protocol < TLS_protocol_tls12) { + if (protocol != TLS_protocol_tlcp && protocol < TLS_protocol_tls12) { error_print(); return -1; } @@ -2391,8 +2391,14 @@ int tls_ctx_init(TLS_CTX *ctx, int protocol, int is_client) void tls_ctx_cleanup(TLS_CTX *ctx) { if (ctx) { - gmssl_secure_clear(&ctx->signkey, sizeof(SM2_KEY)); - gmssl_secure_clear(&ctx->kenckey, sizeof(SM2_KEY)); + size_t i; + + for (i = 0; i < ctx->x509_keys_cnt; i++) { + x509_key_cleanup(&ctx->x509_keys[i]); + x509_key_cleanup(&ctx->enc_keys[i]); + } + x509_key_cleanup(&ctx->signkey); + x509_key_cleanup(&ctx->kenckey); if (ctx->certs) free(ctx->certs); if (ctx->cacerts) free(ctx->cacerts); memset(ctx, 0, sizeof(TLS_CTX)); @@ -2707,15 +2713,18 @@ end: return ret; } -int tls_ctx_set_tlcp_server_certificate_and_keys(TLS_CTX *ctx, const char *chainfile, +int tlcp_ctx_add_server_certificate_and_keys(TLS_CTX *ctx, const char *chainfile, const char *signkeyfile, const char *signkeypass, const char *kenckeyfile, const char *kenckeypass) { int ret = -1; const int algor = OID_ec_public_key; const int algor_param = OID_sm2; - uint8_t *certs = NULL; - size_t certslen; + uint8_t *cert_chain; + size_t cert_chain_len; + size_t cert_chains_len; + size_t key_idx; + FILE *certfp = NULL; FILE *signkeyfp = NULL; FILE *kenckeyfp = NULL; @@ -2732,32 +2741,50 @@ int tls_ctx_set_tlcp_server_certificate_and_keys(TLS_CTX *ctx, const char *chain error_print(); return -1; } - if (ctx->certs) { + if (ctx->protocol != TLS_protocol_tlcp || ctx->is_client) { error_print(); return -1; } - - if (x509_certs_new_from_file(&certs, &certslen, chainfile) != 1) { + if (ctx->x509_keys_cnt >= sizeof(ctx->x509_keys)/sizeof(ctx->x509_keys[0])) { error_print(); return -1; } + key_idx = ctx->x509_keys_cnt; + if (sizeof(ctx->cert_chains) <= ctx->cert_chains_len + tls_uint24_size()) { + error_print(); + return -1; + } + if (!(certfp = fopen(chainfile, "r"))) { + error_print(); + goto end; + } + cert_chain = ctx->cert_chains + ctx->cert_chains_len; + if (x509_certs_from_pem(cert_chain + tls_uint24_size(), &cert_chain_len, + sizeof(ctx->cert_chains) - ctx->cert_chains_len - tls_uint24_size(), certfp) != 1) { + error_print(); + goto end; + } + cert_chains_len = 0; + tls_uint24_to_bytes((uint24_t)cert_chain_len, &cert_chain, &cert_chains_len); + cert_chains_len += cert_chain_len; // load sign key if (!(signkeyfp = fopen(signkeyfile, "r"))) { error_print(); goto end; } - if (x509_private_key_from_file(&ctx->signkey, algor, signkeypass, signkeyfp) != 1) { + if (x509_private_key_from_file(&ctx->x509_keys[key_idx], algor, signkeypass, signkeyfp) != 1) { error_print(); goto end; } - if (x509_certs_get_cert_by_index(certs, certslen, 0, &cert, &certlen) != 1 + if (x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 0, &cert, &certlen) != 1 || x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) { error_print(); goto end; } - if (x509_public_key_equ(&ctx->signkey, &public_key) != 1) { + if (public_key.algor != algor || public_key.algor_param != algor_param + || x509_public_key_equ(&ctx->x509_keys[key_idx], &public_key) != 1) { error_print(); goto end; } @@ -2767,34 +2794,57 @@ int tls_ctx_set_tlcp_server_certificate_and_keys(TLS_CTX *ctx, const char *chain error_print(); goto end; } - if (x509_private_key_from_file(&ctx->kenckey, algor, kenckeypass, kenckeyfp) != 1) { + if (x509_private_key_from_file(&ctx->enc_keys[key_idx], algor, kenckeypass, kenckeyfp) != 1) { error_print(); goto end; } - if (x509_certs_get_cert_by_index(certs, certslen, 1, &cert, &certlen) != 1 + if (x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 1, &cert, &certlen) != 1 || x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) { error_print(); goto end; } - if (x509_public_key_equ(&ctx->kenckey, &public_key) != 1) { + if (public_key.algor != algor || public_key.algor_param != algor_param + || x509_public_key_equ(&ctx->enc_keys[key_idx], &public_key) != 1) { error_print(); goto end; } - ctx->certs = certs; - ctx->certslen = certslen; - certs = NULL; + ctx->cert_chains_len += cert_chains_len; + ctx->cert_chains_cnt++; + ctx->x509_keys_cnt++; + if (key_idx == 0) { + ctx->signkey = ctx->x509_keys[0]; + ctx->kenckey = ctx->enc_keys[0]; + } ret = 1; end: - if (ret != 1) x509_key_cleanup(&ctx->signkey); - if (ret != 1) x509_key_cleanup(&ctx->kenckey); - if (certs) free(certs); + if (ret != 1) { + x509_key_cleanup(&ctx->x509_keys[key_idx]); + x509_key_cleanup(&ctx->enc_keys[key_idx]); + } + if (certfp) fclose(certfp); if (signkeyfp) fclose(signkeyfp); if (kenckeyfp) fclose(kenckeyfp); return ret; } +int tls_ctx_set_tlcp_server_certificate_and_keys(TLS_CTX *ctx, const char *chainfile, + const char *signkeyfile, const char *signkeypass, + const char *kenckeyfile, const char *kenckeypass) +{ + if (!ctx || ctx->cert_chains_len || ctx->x509_keys_cnt) { + error_print(); + return -1; + } + if (tlcp_ctx_add_server_certificate_and_keys(ctx, chainfile, + signkeyfile, signkeypass, kenckeyfile, kenckeypass) != 1) { + error_print(); + return -1; + } + return 1; +} + int tls_ctx_set_supported_groups(TLS_CTX *ctx, const int *groups, size_t groups_cnt) { size_t i; diff --git a/tools/tlcp_client.c b/tools/tlcp_client.c index 19665024..070e6878 100644 --- a/tools/tlcp_client.c +++ b/tools/tlcp_client.c @@ -42,7 +42,7 @@ static const char *help = " -client_cert_optional Allow client send empty Certificate\n" " -get path Send a GET request with given path of URI\n" " -outcerts file Save server certificates to a PEM file\n" -" -server_name Send server_name (SNI) request\n" +" -server_name str Send server_name (SNI) request\n" " -status_request Send status_request (OCSP Stapling) request\n" " -quiet Without printing any status message\n" "\n" @@ -68,6 +68,7 @@ int tlcp_client_main(int argc, char *argv[]) char *certfile = NULL; char *keyfile = NULL; char *pass = NULL; + char *server_name = NULL; int client_cert_optional = 0; char *get = NULL; char *outcertsfile = NULL; @@ -162,6 +163,9 @@ int tlcp_client_main(int argc, char *argv[]) } else if (!strcmp(*argv, "-pass")) { if (--argc < 1) goto bad; pass = *(++argv); + } else if (!strcmp(*argv, "-server_name")) { + if (--argc < 1) goto bad; + server_name = *(++argv); } else if (!strcmp(*argv, "-client_cert_optional")) { client_cert_optional = 1; } else if (!strcmp(*argv, "-get")) { @@ -252,6 +256,12 @@ bad: error_print(); goto end; } + if (server_name) { + if (tls_set_server_name(&conn, (uint8_t *)server_name, strlen(server_name)) != 1) { + error_print(); + goto end; + } + } if (tls_socket_create(&sock, AF_INET, SOCK_STREAM, 0) != 1) { fprintf(stderr, "%s: faild to open socket\n", prog); diff --git a/tools/tlcp_server.c b/tools/tlcp_server.c index c2ca0fc5..ed650e5a 100644 --- a/tools/tlcp_server.c +++ b/tools/tlcp_server.c @@ -18,16 +18,18 @@ #include -static const char *options = "[-port num] -cert file -key file [-pass str] -ex_key file [-ex_pass str] [-cacert file]"; +static const char *options = "[-port num] -cert file -key file -pass str -ex_key file -ex_pass str [-cacert file]"; static const char *help = "Options\n" "\n" " -port num Listening port number, default 443\n" -" -cert file Server's certificate chain in PEM format\n" -" -key file Server's encrypted private key in PEM format\n" -" -pass str Password to decrypt private key\n" +" -cert file Server's certificate chain in PEM format, may appear multiple times\n" +" -key file Server's signing private key in PEM format, may appear multiple times\n" +" -pass str Password to decrypt signing private key, may appear multiple times\n" +" -ex_key file Server's encryption private key in PEM format, may appear multiple times\n" +" -ex_pass str Password to decrypt encryption private key, may appear multiple times\n" " -cacert file CA certificate for client certificate verification\n" "\n" #include "tlcp_help.h" @@ -38,11 +40,16 @@ int tlcp_server_main(int argc , char **argv) int ret = 1; char *prog = argv[0]; int port = 443; - char *certfile = NULL; - char *signkeyfile = NULL; - char *signpass = NULL; - char *enckeyfile = NULL; - char *encpass = NULL; + char *certfiles[4]; + size_t certfiles_cnt = 0; + char *signkeyfiles[sizeof(certfiles)/sizeof(certfiles[0])]; + size_t signkeyfiles_cnt = 0; + char *signpasses[sizeof(certfiles)/sizeof(certfiles[0])]; + size_t signpasses_cnt = 0; + char *enckeyfiles[sizeof(certfiles)/sizeof(certfiles[0])]; + size_t enckeyfiles_cnt = 0; + char *encpasses[sizeof(certfiles)/sizeof(certfiles[0])]; + size_t encpasses_cnt = 0; char *cacertfile = NULL; int server_ciphers[] = { TLS_cipher_ecc_sm4_cbc_sm3, }; @@ -56,6 +63,7 @@ int tlcp_server_main(int argc , char **argv) struct sockaddr_in server_addr; struct sockaddr_in client_addr; tls_socklen_t client_addrlen; + size_t i; argc--; argv++; @@ -74,20 +82,40 @@ int tlcp_server_main(int argc , char **argv) if (--argc < 1) goto bad; port = atoi(*(++argv)); } else if (!strcmp(*argv, "-cert")) { + if (certfiles_cnt >= sizeof(certfiles)/sizeof(certfiles[0])) { + fprintf(stderr, "%s: too many -cert options\n", prog); + return -1; + } if (--argc < 1) goto bad; - certfile = *(++argv); + certfiles[certfiles_cnt++] = *(++argv); } else if (!strcmp(*argv, "-key")) { + if (signkeyfiles_cnt >= sizeof(signkeyfiles)/sizeof(signkeyfiles[0])) { + fprintf(stderr, "%s: too many -key options\n", prog); + return -1; + } if (--argc < 1) goto bad; - signkeyfile = *(++argv); + signkeyfiles[signkeyfiles_cnt++] = *(++argv); } else if (!strcmp(*argv, "-pass")) { + if (signpasses_cnt >= sizeof(signpasses)/sizeof(signpasses[0])) { + fprintf(stderr, "%s: too many -pass options\n", prog); + return -1; + } if (--argc < 1) goto bad; - signpass = *(++argv); + signpasses[signpasses_cnt++] = *(++argv); } else if (!strcmp(*argv, "-ex_key")) { + if (enckeyfiles_cnt >= sizeof(enckeyfiles)/sizeof(enckeyfiles[0])) { + fprintf(stderr, "%s: too many -ex_key options\n", prog); + return -1; + } if (--argc < 1) goto bad; - enckeyfile = *(++argv); + enckeyfiles[enckeyfiles_cnt++] = *(++argv); } else if (!strcmp(*argv, "-ex_pass")) { + if (encpasses_cnt >= sizeof(encpasses)/sizeof(encpasses[0])) { + fprintf(stderr, "%s: too many -ex_pass options\n", prog); + return -1; + } if (--argc < 1) goto bad; - encpass = *(++argv); + encpasses[encpasses_cnt++] = *(++argv); } else if (!strcmp(*argv, "-cacert")) { if (--argc < 1) goto bad; cacertfile = *(++argv); @@ -101,36 +129,48 @@ bad: argc--; argv++; } - if (!certfile) { + if (!certfiles_cnt) { fprintf(stderr, "%s: '-cert' option required\n", prog); return 1; } - if (!signkeyfile) { + if (!signkeyfiles_cnt) { fprintf(stderr, "%s: '-key' option required\n", prog); return 1; } - if (!signpass) { + if (!signpasses_cnt) { fprintf(stderr, "%s: '-pass' option required\n", prog); return 1; } - if (!enckeyfile) { + if (!enckeyfiles_cnt) { fprintf(stderr, "%s: '-ex_key' option required\n", prog); return 1; } - if (!encpass) { + if (!encpasses_cnt) { fprintf(stderr, "%s: '-ex_pass' option required\n", prog); return 1; } + if (certfiles_cnt != signkeyfiles_cnt || signkeyfiles_cnt != signpasses_cnt + || signpasses_cnt != enckeyfiles_cnt || enckeyfiles_cnt != encpasses_cnt) { + fprintf(stderr, "%s: -cert/-key/-pass/-ex_key/-ex_pass counts mismatch\n", prog); + return 1; + } memset(&ctx, 0, sizeof(ctx)); memset(&conn, 0, sizeof(conn)); if (tls_ctx_init(&ctx, TLS_protocol_tlcp, TLS_server_mode) != 1 - || tls_ctx_set_cipher_suites(&ctx, server_ciphers, sizeof(server_ciphers)/sizeof(int)) != 1 - || tls_ctx_set_tlcp_server_certificate_and_keys(&ctx, certfile, signkeyfile, signpass, enckeyfile, encpass) != 1) { + || tls_ctx_set_cipher_suites(&ctx, server_ciphers, sizeof(server_ciphers)/sizeof(int)) != 1) { error_print(); return -1; } + for (i = 0; i < certfiles_cnt; i++) { + if (tlcp_ctx_add_server_certificate_and_keys(&ctx, + certfiles[i], signkeyfiles[i], signpasses[i], + enckeyfiles[i], encpasses[i]) != 1) { + error_print(); + return -1; + } + } if (cacertfile) { if (tls_ctx_set_ca_certificates(&ctx, cacertfile, TLS_DEFAULT_VERIFY_DEPTH) != 1) { error_print();