From 1dd920c1986e1121695b6197554b5874fe05711e Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Thu, 11 Jun 2026 20:28:49 +0800 Subject: [PATCH] Update tls12.c --- src/tls12.c | 463 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 429 insertions(+), 34 deletions(-) diff --git a/src/tls12.c b/src/tls12.c index c9429528..87a7ac38 100644 --- a/src/tls12.c +++ b/src/tls12.c @@ -757,6 +757,259 @@ const size_t server_ciphers_cnt = 1; */ const int curve = TLS_curve_sm2p256v1; +static int tls12_cipher_suite_get(int cipher_suite, const BLOCK_CIPHER **cipher, const DIGEST **digest) +{ + switch (cipher_suite) { + case TLS_cipher_ecdhe_sm4_cbc_sm3: + case TLS_cipher_ecdhe_sm4_gcm_sm3: + *cipher = BLOCK_CIPHER_sm4(); + *digest = DIGEST_sm3(); + break; + case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: + *cipher = BLOCK_CIPHER_aes128(); + *digest = DIGEST_sha256(); + break; + default: + error_print(); + return -1; + } + return 1; +} + +static int tls12_cipher_suite_match_cert_group(int cipher_suite, int cert_group) +{ + switch (cipher_suite) { + case TLS_cipher_ecdhe_sm4_cbc_sm3: + case TLS_cipher_ecdhe_sm4_gcm_sm3: + return cert_group == TLS_curve_sm2p256v1; + case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: + return cert_group == TLS_curve_secp256r1; + default: + return 0; + } +} + +static int tls12_signature_scheme_match_cert_group(int sig_alg, int cert_group) +{ + return tls_signature_scheme_group_oid(sig_alg) == tls_named_curve_oid(cert_group); +} + +static int tls12_signature_scheme_match_cipher_suite(int sig_alg, int cipher_suite) +{ + switch (sig_alg) { + case TLS_sig_sm2sig_sm3: + switch (cipher_suite) { + case TLS_cipher_ecdhe_sm4_cbc_sm3: + case TLS_cipher_ecdhe_sm4_gcm_sm3: + return 1; + } + break; + case TLS_sig_ecdsa_secp256r1_sha256: + if (cipher_suite == TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256) { + return 1; + } + break; + } + return 0; +} + +static int tls12_key_exchange_group_match_cipher_suite(int group, int cipher_suite) +{ + switch (cipher_suite) { + case TLS_cipher_ecdhe_sm4_cbc_sm3: + case TLS_cipher_ecdhe_sm4_gcm_sm3: + return group == TLS_curve_sm2p256v1; + case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: + return group == TLS_curve_secp256r1; + default: + return 0; + } +} + +static int tls12_select_common_cipher_suites(const uint8_t *client_ciphers, size_t client_ciphers_len, + const int *server_ciphers, size_t server_ciphers_cnt, + int *common_ciphers, size_t *common_ciphers_cnt, size_t max_cnt) +{ + size_t i; + + if (!client_ciphers || !client_ciphers_len + || !server_ciphers || !server_ciphers_cnt + || !common_ciphers || !common_ciphers_cnt || !max_cnt) { + error_print(); + return -1; + } + + *common_ciphers_cnt = 0; + for (i = 0; i < server_ciphers_cnt && *common_ciphers_cnt < max_cnt; i++) { + const uint8_t *p = client_ciphers; + size_t len = client_ciphers_len; + while (len) { + uint16_t cipher; + if (tls_uint16_from_bytes(&cipher, &p, &len) != 1) { + error_print(); + return -1; + } + if (cipher == server_ciphers[i]) { + common_ciphers[(*common_ciphers_cnt)++] = server_ciphers[i]; + break; + } + } + } + + return *common_ciphers_cnt ? 1 : 0; +} + +static int tls12_cert_chain_get_end_entity_group(const uint8_t *cert_chain, size_t cert_chain_len, int *group) +{ + const uint8_t *cert; + size_t certlen; + X509_KEY public_key; + + if (!cert_chain || !cert_chain_len || !group) { + error_print(); + return -1; + } + if (x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 0, &cert, &certlen) != 1 + || x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) { + error_print(); + return -1; + } + if (public_key.algor != OID_ec_public_key) { + error_print(); + return -1; + } + if ((*group = tls_named_curve_from_oid(public_key.algor_param)) == 0) { + error_print(); + return -1; + } + return 1; +} + +static int tls12_select_key_exchange_group(const int *groups, size_t groups_cnt, + int cipher_suite, int *selected_group) +{ + size_t i; + + if (!groups || !groups_cnt || !selected_group) { + error_print(); + return -1; + } + for (i = 0; i < groups_cnt; i++) { + if (tls12_key_exchange_group_match_cipher_suite(groups[i], cipher_suite)) { + *selected_group = groups[i]; + return 1; + } + } + return 0; +} + +// 这个函数的名字最好换一下 +static int tls12_select_parameters(TLS_CONNECT *conn, + const int *common_cipher_suites, size_t common_cipher_suites_cnt, + const int *common_supported_groups, size_t common_supported_groups_cnt, + const int *common_signature_algorithms, size_t common_signature_algorithms_cnt, + const int *signature_algorithms_cert, size_t signature_algorithms_cert_cnt, + const uint8_t *host_name, size_t host_name_len) +{ + const uint8_t *cert_chains = conn->ctx->cert_chains; + size_t cert_chains_len = conn->ctx->cert_chains_len; + size_t cert_chain_idx; + + if (!conn || !common_cipher_suites || !common_cipher_suites_cnt + || !common_supported_groups || !common_supported_groups_cnt + || !common_signature_algorithms || !common_signature_algorithms_cnt) { + error_print(); + return -1; + } + if (!cert_chains || !cert_chains_len) { + error_print(); + return -1; + } + + for (cert_chain_idx = 1; cert_chains_len; cert_chain_idx++) { + const uint8_t *cert_chain; + size_t cert_chain_len; + const uint8_t *cert; + size_t certlen; + int cert_group; + size_t i; + int ret; + + if (tls_uint24array_from_bytes(&cert_chain, &cert_chain_len, + &cert_chains, &cert_chains_len) != 1) { + error_print(); + return -1; + } + if (tls12_cert_chain_get_end_entity_group(cert_chain, cert_chain_len, &cert_group) != 1) { + error_print(); + return -1; + } + if (!tls_type_is_in_list(cert_group, common_supported_groups, common_supported_groups_cnt)) { + continue; + } + if (x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 0, &cert, &certlen) != 1) { + error_print(); + return -1; + } + if (host_name && host_name_len) { + if ((ret = tls_cert_match_server_name(cert, certlen, host_name, host_name_len)) < 0) { + error_print(); + return -1; + } else if (ret == 0) { + continue; + } + } + if (signature_algorithms_cert && signature_algorithms_cert_cnt) { + if ((ret = tls_cert_chain_match_signature_algorithms_cert(cert_chain, cert_chain_len, + signature_algorithms_cert, signature_algorithms_cert_cnt)) < 0) { + error_print(); + return -1; + } else if (ret == 0) { + continue; + } + } + + for (i = 0; i < common_cipher_suites_cnt; i++) { + size_t j; + int cipher_suite = common_cipher_suites[i]; + int key_exchange_group; + + if (!tls12_cipher_suite_match_cert_group(cipher_suite, cert_group)) { + continue; + } + if ((ret = tls12_select_key_exchange_group(common_supported_groups, + common_supported_groups_cnt, cipher_suite, &key_exchange_group)) < 0) { + error_print(); + return -1; + } else if (ret == 0) { + continue; + } + + for (j = 0; j < common_signature_algorithms_cnt; j++) { + int sig_alg = common_signature_algorithms[j]; + + if (!tls12_signature_scheme_match_cert_group(sig_alg, cert_group)) { + continue; + } + if (!tls12_signature_scheme_match_cipher_suite(sig_alg, cipher_suite)) { + continue; + } + + conn->cipher_suite = cipher_suite; + conn->cert_chain = cert_chain; + conn->cert_chain_len = cert_chain_len; + conn->cert_chain_idx = cert_chain_idx; + conn->sig_alg = sig_alg; + conn->key_exchange_group = key_exchange_group; + return 1; + } + } + } + + warning_print(); + return 0; +} + int tls_recv_client_hello(TLS_CONNECT *conn) { @@ -772,6 +1025,26 @@ int tls_recv_client_hello(TLS_CONNECT *conn) size_t cipher_suites_len; const uint8_t *exts; size_t extslen; + const uint8_t *supported_groups = NULL; + size_t supported_groups_len = 0; + const uint8_t *signature_algorithms = NULL; + size_t signature_algorithms_len = 0; + const uint8_t *signature_algorithms_cert = NULL; + size_t signature_algorithms_cert_len = 0; + const uint8_t *server_name = NULL; + size_t server_name_len = 0; + int common_cipher_suites[TLS_MAX_CIPHER_SUITES_COUNT]; + size_t common_cipher_suites_cnt = 0; + int common_supported_groups[32]; + size_t common_supported_groups_cnt = 0; + int common_signature_algorithms[32]; + size_t common_signature_algorithms_cnt = 0; + int common_signature_algorithms_cert[32]; + size_t common_signature_algorithms_cert_cnt = 0; + const int *cert_signature_algorithms = NULL; + size_t cert_signature_algorithms_cnt = 0; + const uint8_t *host_name = NULL; + size_t host_name_len = 0; /* if (client_verify) @@ -816,25 +1089,6 @@ int tls_recv_client_hello(TLS_CONNECT *conn) memcpy(conn->client_random, client_random, 32); - - if ((ret = tls_cipher_suites_select(cipher_suites, cipher_suites_len, - conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt, - &conn->cipher_suite)) < 0) { - error_print(); - tls13_send_alert(conn, TLS_alert_decode_error); - return -1; - } else if (ret == 0) { - error_print(); - tls13_send_alert(conn, TLS_alert_handshake_failure); - return -1; - } - - - conn->cipher = BLOCK_CIPHER_aes128(); - - conn->digest = DIGEST_sha256(); - - while (extslen) { int ext_type; const uint8_t *ext_data; @@ -846,30 +1100,171 @@ int tls_recv_client_hello(TLS_CONNECT *conn) return -1; } + switch (ext_type) { + case TLS_extension_supported_groups: + case TLS_extension_signature_algorithms: + case TLS_extension_signature_algorithms_cert: + case TLS_extension_server_name: + if (!ext_data) { + error_print(); + tls_send_alert(conn, TLS_alert_decode_error); + return -1; + } + break; + } - // 这些扩展都不是必须的 + switch (ext_type) { + case TLS_extension_supported_groups: + if (supported_groups) { + error_print(); + tls_send_alert(conn, TLS_alert_illegal_parameter); + return -1; + } + supported_groups = ext_data; + supported_groups_len = ext_datalen; + break; + case TLS_extension_signature_algorithms: + if (signature_algorithms) { + error_print(); + tls_send_alert(conn, TLS_alert_illegal_parameter); + return -1; + } + signature_algorithms = ext_data; + signature_algorithms_len = ext_datalen; + break; + case TLS_extension_signature_algorithms_cert: + if (signature_algorithms_cert) { + error_print(); + tls_send_alert(conn, TLS_alert_illegal_parameter); + return -1; + } + signature_algorithms_cert = ext_data; + signature_algorithms_cert_len = ext_datalen; + break; + case TLS_extension_server_name: + if (server_name) { + error_print(); + tls_send_alert(conn, TLS_alert_illegal_parameter); + return -1; + } + server_name = ext_data; + server_name_len = ext_datalen; + break; + default: + warning_print(); + } } - - // select server certificate - // 这里的逻辑还没有实现,sig_alg也要处理 - - if (tls12_cert_chains_select(conn->ctx->cert_chains, conn->ctx->cert_chains_len, - NULL, 0, - NULL, 0, - NULL, 0, - NULL, 0, - &conn->cert_chain, &conn->cert_chain_len, &conn->cert_chain_idx, &conn->sig_alg) != 1) { + if ((ret = tls12_select_common_cipher_suites(cipher_suites, cipher_suites_len, + conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt, + common_cipher_suites, &common_cipher_suites_cnt, + sizeof(common_cipher_suites)/sizeof(common_cipher_suites[0]))) < 0) { error_print(); + tls_send_alert(conn, TLS_alert_decode_error); + return -1; + } else if (ret == 0) { + error_print(); + tls_send_alert(conn, TLS_alert_handshake_failure); return -1; } - // 上面这个逻辑里面根本找不到合适的sig_alg xxxxx - conn->sig_alg = TLS_sig_ecdsa_secp256r1_sha256; + if (supported_groups) { + if ((ret = tls_process_supported_groups(supported_groups, supported_groups_len, + conn->ctx->supported_groups, conn->ctx->supported_groups_cnt, + common_supported_groups, &common_supported_groups_cnt, + sizeof(common_supported_groups)/sizeof(common_supported_groups[0]))) < 0) { + error_print(); + tls_send_alert(conn, TLS_alert_decode_error); + return -1; + } else if (ret == 0) { + error_print(); + tls_send_alert(conn, TLS_alert_handshake_failure); + return -1; + } + } else { + if (!conn->ctx->supported_groups_cnt) { + error_print(); + tls_send_alert(conn, TLS_alert_handshake_failure); + return -1; + } + memcpy(common_supported_groups, conn->ctx->supported_groups, + conn->ctx->supported_groups_cnt * sizeof(conn->ctx->supported_groups[0])); + common_supported_groups_cnt = conn->ctx->supported_groups_cnt; + } - // 还要设置密钥交换的算法 - conn->key_exchange_group = TLS_curve_secp256r1; + if (signature_algorithms) { + if ((ret = tls_process_signature_algorithms(signature_algorithms, signature_algorithms_len, + conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt, + common_signature_algorithms, &common_signature_algorithms_cnt, + sizeof(common_signature_algorithms)/sizeof(common_signature_algorithms[0]))) < 0) { + error_print(); + tls_send_alert(conn, TLS_alert_decode_error); + return -1; + } else if (ret == 0) { + error_print(); + tls_send_alert(conn, TLS_alert_handshake_failure); + return -1; + } + } else { + if (!conn->ctx->signature_algorithms_cnt) { + error_print(); + tls13_send_alert(conn, TLS_alert_handshake_failure); + return -1; + } + memcpy(common_signature_algorithms, conn->ctx->signature_algorithms, + conn->ctx->signature_algorithms_cnt * sizeof(conn->ctx->signature_algorithms[0])); + common_signature_algorithms_cnt = conn->ctx->signature_algorithms_cnt; + } + if (signature_algorithms_cert) { + if ((ret = tls_process_signature_algorithms(signature_algorithms_cert, signature_algorithms_cert_len, + conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt, + common_signature_algorithms_cert, &common_signature_algorithms_cert_cnt, + sizeof(common_signature_algorithms_cert)/sizeof(common_signature_algorithms_cert[0]))) < 0) { + error_print(); + tls_send_alert(conn, TLS_alert_decode_error); + return -1; + } else if (ret == 0) { + error_print(); + tls_send_alert(conn, TLS_alert_handshake_failure); + return -1; + } + cert_signature_algorithms = common_signature_algorithms_cert; + cert_signature_algorithms_cnt = common_signature_algorithms_cert_cnt; + } else if (signature_algorithms) { + cert_signature_algorithms = common_signature_algorithms; + cert_signature_algorithms_cnt = common_signature_algorithms_cnt; + } + + if (server_name) { + if (tls_server_name_from_bytes(&host_name, &host_name_len, server_name, server_name_len) != 1) { + error_print(); + tls13_send_alert(conn, TLS_alert_decode_error); + return -1; + } + conn->server_name = 1; + } + + if ((ret = tls12_select_parameters(conn, + common_cipher_suites, common_cipher_suites_cnt, + common_supported_groups, common_supported_groups_cnt, + common_signature_algorithms, common_signature_algorithms_cnt, + cert_signature_algorithms, cert_signature_algorithms_cnt, + host_name, host_name_len)) < 0) { + error_print(); + tls_send_alert(conn, TLS_alert_decode_error); + return -1; + } else if (ret == 0) { + error_print(); + tls_send_alert(conn, TLS_alert_handshake_failure); + return -1; + } + + if (tls12_cipher_suite_get(conn->cipher_suite, &conn->cipher, &conn->digest) != 1) { + error_print(); + tls13_send_alert(conn, TLS_alert_internal_error); + return -1; + } if (digest_init(&conn->dgst_ctx, conn->digest) != 1) { error_print();