Update tls12.c

This commit is contained in:
Zhi Guan
2026-06-11 20:28:49 +08:00
parent e1c69d5633
commit 1dd920c198

View File

@@ -757,6 +757,259 @@ const size_t server_ciphers_cnt = 1;
*/ */
const int curve = TLS_curve_sm2p256v1; const int curve = TLS_curve_sm2p256v1;
static int tls12_cipher_suite_get(int cipher_suite, const BLOCK_CIPHER **cipher, const DIGEST **digest)
{
switch (cipher_suite) {
case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_sm4_gcm_sm3:
*cipher = BLOCK_CIPHER_sm4();
*digest = DIGEST_sm3();
break;
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
*cipher = BLOCK_CIPHER_aes128();
*digest = DIGEST_sha256();
break;
default:
error_print();
return -1;
}
return 1;
}
static int tls12_cipher_suite_match_cert_group(int cipher_suite, int cert_group)
{
switch (cipher_suite) {
case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_sm4_gcm_sm3:
return cert_group == TLS_curve_sm2p256v1;
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
return cert_group == TLS_curve_secp256r1;
default:
return 0;
}
}
static int tls12_signature_scheme_match_cert_group(int sig_alg, int cert_group)
{
return tls_signature_scheme_group_oid(sig_alg) == tls_named_curve_oid(cert_group);
}
static int tls12_signature_scheme_match_cipher_suite(int sig_alg, int cipher_suite)
{
switch (sig_alg) {
case TLS_sig_sm2sig_sm3:
switch (cipher_suite) {
case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_sm4_gcm_sm3:
return 1;
}
break;
case TLS_sig_ecdsa_secp256r1_sha256:
if (cipher_suite == TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256) {
return 1;
}
break;
}
return 0;
}
static int tls12_key_exchange_group_match_cipher_suite(int group, int cipher_suite)
{
switch (cipher_suite) {
case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_sm4_gcm_sm3:
return group == TLS_curve_sm2p256v1;
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
return group == TLS_curve_secp256r1;
default:
return 0;
}
}
static int tls12_select_common_cipher_suites(const uint8_t *client_ciphers, size_t client_ciphers_len,
const int *server_ciphers, size_t server_ciphers_cnt,
int *common_ciphers, size_t *common_ciphers_cnt, size_t max_cnt)
{
size_t i;
if (!client_ciphers || !client_ciphers_len
|| !server_ciphers || !server_ciphers_cnt
|| !common_ciphers || !common_ciphers_cnt || !max_cnt) {
error_print();
return -1;
}
*common_ciphers_cnt = 0;
for (i = 0; i < server_ciphers_cnt && *common_ciphers_cnt < max_cnt; i++) {
const uint8_t *p = client_ciphers;
size_t len = client_ciphers_len;
while (len) {
uint16_t cipher;
if (tls_uint16_from_bytes(&cipher, &p, &len) != 1) {
error_print();
return -1;
}
if (cipher == server_ciphers[i]) {
common_ciphers[(*common_ciphers_cnt)++] = server_ciphers[i];
break;
}
}
}
return *common_ciphers_cnt ? 1 : 0;
}
static int tls12_cert_chain_get_end_entity_group(const uint8_t *cert_chain, size_t cert_chain_len, int *group)
{
const uint8_t *cert;
size_t certlen;
X509_KEY public_key;
if (!cert_chain || !cert_chain_len || !group) {
error_print();
return -1;
}
if (x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 0, &cert, &certlen) != 1
|| x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) {
error_print();
return -1;
}
if (public_key.algor != OID_ec_public_key) {
error_print();
return -1;
}
if ((*group = tls_named_curve_from_oid(public_key.algor_param)) == 0) {
error_print();
return -1;
}
return 1;
}
static int tls12_select_key_exchange_group(const int *groups, size_t groups_cnt,
int cipher_suite, int *selected_group)
{
size_t i;
if (!groups || !groups_cnt || !selected_group) {
error_print();
return -1;
}
for (i = 0; i < groups_cnt; i++) {
if (tls12_key_exchange_group_match_cipher_suite(groups[i], cipher_suite)) {
*selected_group = groups[i];
return 1;
}
}
return 0;
}
// 这个函数的名字最好换一下
static int tls12_select_parameters(TLS_CONNECT *conn,
const int *common_cipher_suites, size_t common_cipher_suites_cnt,
const int *common_supported_groups, size_t common_supported_groups_cnt,
const int *common_signature_algorithms, size_t common_signature_algorithms_cnt,
const int *signature_algorithms_cert, size_t signature_algorithms_cert_cnt,
const uint8_t *host_name, size_t host_name_len)
{
const uint8_t *cert_chains = conn->ctx->cert_chains;
size_t cert_chains_len = conn->ctx->cert_chains_len;
size_t cert_chain_idx;
if (!conn || !common_cipher_suites || !common_cipher_suites_cnt
|| !common_supported_groups || !common_supported_groups_cnt
|| !common_signature_algorithms || !common_signature_algorithms_cnt) {
error_print();
return -1;
}
if (!cert_chains || !cert_chains_len) {
error_print();
return -1;
}
for (cert_chain_idx = 1; cert_chains_len; cert_chain_idx++) {
const uint8_t *cert_chain;
size_t cert_chain_len;
const uint8_t *cert;
size_t certlen;
int cert_group;
size_t i;
int ret;
if (tls_uint24array_from_bytes(&cert_chain, &cert_chain_len,
&cert_chains, &cert_chains_len) != 1) {
error_print();
return -1;
}
if (tls12_cert_chain_get_end_entity_group(cert_chain, cert_chain_len, &cert_group) != 1) {
error_print();
return -1;
}
if (!tls_type_is_in_list(cert_group, common_supported_groups, common_supported_groups_cnt)) {
continue;
}
if (x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 0, &cert, &certlen) != 1) {
error_print();
return -1;
}
if (host_name && host_name_len) {
if ((ret = tls_cert_match_server_name(cert, certlen, host_name, host_name_len)) < 0) {
error_print();
return -1;
} else if (ret == 0) {
continue;
}
}
if (signature_algorithms_cert && signature_algorithms_cert_cnt) {
if ((ret = tls_cert_chain_match_signature_algorithms_cert(cert_chain, cert_chain_len,
signature_algorithms_cert, signature_algorithms_cert_cnt)) < 0) {
error_print();
return -1;
} else if (ret == 0) {
continue;
}
}
for (i = 0; i < common_cipher_suites_cnt; i++) {
size_t j;
int cipher_suite = common_cipher_suites[i];
int key_exchange_group;
if (!tls12_cipher_suite_match_cert_group(cipher_suite, cert_group)) {
continue;
}
if ((ret = tls12_select_key_exchange_group(common_supported_groups,
common_supported_groups_cnt, cipher_suite, &key_exchange_group)) < 0) {
error_print();
return -1;
} else if (ret == 0) {
continue;
}
for (j = 0; j < common_signature_algorithms_cnt; j++) {
int sig_alg = common_signature_algorithms[j];
if (!tls12_signature_scheme_match_cert_group(sig_alg, cert_group)) {
continue;
}
if (!tls12_signature_scheme_match_cipher_suite(sig_alg, cipher_suite)) {
continue;
}
conn->cipher_suite = cipher_suite;
conn->cert_chain = cert_chain;
conn->cert_chain_len = cert_chain_len;
conn->cert_chain_idx = cert_chain_idx;
conn->sig_alg = sig_alg;
conn->key_exchange_group = key_exchange_group;
return 1;
}
}
}
warning_print();
return 0;
}
int tls_recv_client_hello(TLS_CONNECT *conn) int tls_recv_client_hello(TLS_CONNECT *conn)
{ {
@@ -772,6 +1025,26 @@ int tls_recv_client_hello(TLS_CONNECT *conn)
size_t cipher_suites_len; size_t cipher_suites_len;
const uint8_t *exts; const uint8_t *exts;
size_t extslen; size_t extslen;
const uint8_t *supported_groups = NULL;
size_t supported_groups_len = 0;
const uint8_t *signature_algorithms = NULL;
size_t signature_algorithms_len = 0;
const uint8_t *signature_algorithms_cert = NULL;
size_t signature_algorithms_cert_len = 0;
const uint8_t *server_name = NULL;
size_t server_name_len = 0;
int common_cipher_suites[TLS_MAX_CIPHER_SUITES_COUNT];
size_t common_cipher_suites_cnt = 0;
int common_supported_groups[32];
size_t common_supported_groups_cnt = 0;
int common_signature_algorithms[32];
size_t common_signature_algorithms_cnt = 0;
int common_signature_algorithms_cert[32];
size_t common_signature_algorithms_cert_cnt = 0;
const int *cert_signature_algorithms = NULL;
size_t cert_signature_algorithms_cnt = 0;
const uint8_t *host_name = NULL;
size_t host_name_len = 0;
/* /*
if (client_verify) if (client_verify)
@@ -816,25 +1089,6 @@ int tls_recv_client_hello(TLS_CONNECT *conn)
memcpy(conn->client_random, client_random, 32); memcpy(conn->client_random, client_random, 32);
if ((ret = tls_cipher_suites_select(cipher_suites, cipher_suites_len,
conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt,
&conn->cipher_suite)) < 0) {
error_print();
tls13_send_alert(conn, TLS_alert_decode_error);
return -1;
} else if (ret == 0) {
error_print();
tls13_send_alert(conn, TLS_alert_handshake_failure);
return -1;
}
conn->cipher = BLOCK_CIPHER_aes128();
conn->digest = DIGEST_sha256();
while (extslen) { while (extslen) {
int ext_type; int ext_type;
const uint8_t *ext_data; const uint8_t *ext_data;
@@ -846,30 +1100,171 @@ int tls_recv_client_hello(TLS_CONNECT *conn)
return -1; return -1;
} }
switch (ext_type) {
case TLS_extension_supported_groups:
case TLS_extension_signature_algorithms:
case TLS_extension_signature_algorithms_cert:
case TLS_extension_server_name:
if (!ext_data) {
error_print();
tls_send_alert(conn, TLS_alert_decode_error);
return -1;
}
break;
}
// 这些扩展都不是必须的 switch (ext_type) {
case TLS_extension_supported_groups:
if (supported_groups) {
error_print();
tls_send_alert(conn, TLS_alert_illegal_parameter);
return -1;
}
supported_groups = ext_data;
supported_groups_len = ext_datalen;
break;
case TLS_extension_signature_algorithms:
if (signature_algorithms) {
error_print();
tls_send_alert(conn, TLS_alert_illegal_parameter);
return -1;
}
signature_algorithms = ext_data;
signature_algorithms_len = ext_datalen;
break;
case TLS_extension_signature_algorithms_cert:
if (signature_algorithms_cert) {
error_print();
tls_send_alert(conn, TLS_alert_illegal_parameter);
return -1;
}
signature_algorithms_cert = ext_data;
signature_algorithms_cert_len = ext_datalen;
break;
case TLS_extension_server_name:
if (server_name) {
error_print();
tls_send_alert(conn, TLS_alert_illegal_parameter);
return -1;
}
server_name = ext_data;
server_name_len = ext_datalen;
break;
default:
warning_print();
}
} }
if ((ret = tls12_select_common_cipher_suites(cipher_suites, cipher_suites_len,
// select server certificate conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt,
// 这里的逻辑还没有实现sig_alg也要处理 common_cipher_suites, &common_cipher_suites_cnt,
sizeof(common_cipher_suites)/sizeof(common_cipher_suites[0]))) < 0) {
if (tls12_cert_chains_select(conn->ctx->cert_chains, conn->ctx->cert_chains_len,
NULL, 0,
NULL, 0,
NULL, 0,
NULL, 0,
&conn->cert_chain, &conn->cert_chain_len, &conn->cert_chain_idx, &conn->sig_alg) != 1) {
error_print(); error_print();
tls_send_alert(conn, TLS_alert_decode_error);
return -1;
} else if (ret == 0) {
error_print();
tls_send_alert(conn, TLS_alert_handshake_failure);
return -1; return -1;
} }
// 上面这个逻辑里面根本找不到合适的sig_alg xxxxx
conn->sig_alg = TLS_sig_ecdsa_secp256r1_sha256;
if (supported_groups) {
if ((ret = tls_process_supported_groups(supported_groups, supported_groups_len,
conn->ctx->supported_groups, conn->ctx->supported_groups_cnt,
common_supported_groups, &common_supported_groups_cnt,
sizeof(common_supported_groups)/sizeof(common_supported_groups[0]))) < 0) {
error_print();
tls_send_alert(conn, TLS_alert_decode_error);
return -1;
} else if (ret == 0) {
error_print();
tls_send_alert(conn, TLS_alert_handshake_failure);
return -1;
}
} else {
if (!conn->ctx->supported_groups_cnt) {
error_print();
tls_send_alert(conn, TLS_alert_handshake_failure);
return -1;
}
memcpy(common_supported_groups, conn->ctx->supported_groups,
conn->ctx->supported_groups_cnt * sizeof(conn->ctx->supported_groups[0]));
common_supported_groups_cnt = conn->ctx->supported_groups_cnt;
}
// 还要设置密钥交换的算法 if (signature_algorithms) {
conn->key_exchange_group = TLS_curve_secp256r1; if ((ret = tls_process_signature_algorithms(signature_algorithms, signature_algorithms_len,
conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt,
common_signature_algorithms, &common_signature_algorithms_cnt,
sizeof(common_signature_algorithms)/sizeof(common_signature_algorithms[0]))) < 0) {
error_print();
tls_send_alert(conn, TLS_alert_decode_error);
return -1;
} else if (ret == 0) {
error_print();
tls_send_alert(conn, TLS_alert_handshake_failure);
return -1;
}
} else {
if (!conn->ctx->signature_algorithms_cnt) {
error_print();
tls13_send_alert(conn, TLS_alert_handshake_failure);
return -1;
}
memcpy(common_signature_algorithms, conn->ctx->signature_algorithms,
conn->ctx->signature_algorithms_cnt * sizeof(conn->ctx->signature_algorithms[0]));
common_signature_algorithms_cnt = conn->ctx->signature_algorithms_cnt;
}
if (signature_algorithms_cert) {
if ((ret = tls_process_signature_algorithms(signature_algorithms_cert, signature_algorithms_cert_len,
conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt,
common_signature_algorithms_cert, &common_signature_algorithms_cert_cnt,
sizeof(common_signature_algorithms_cert)/sizeof(common_signature_algorithms_cert[0]))) < 0) {
error_print();
tls_send_alert(conn, TLS_alert_decode_error);
return -1;
} else if (ret == 0) {
error_print();
tls_send_alert(conn, TLS_alert_handshake_failure);
return -1;
}
cert_signature_algorithms = common_signature_algorithms_cert;
cert_signature_algorithms_cnt = common_signature_algorithms_cert_cnt;
} else if (signature_algorithms) {
cert_signature_algorithms = common_signature_algorithms;
cert_signature_algorithms_cnt = common_signature_algorithms_cnt;
}
if (server_name) {
if (tls_server_name_from_bytes(&host_name, &host_name_len, server_name, server_name_len) != 1) {
error_print();
tls13_send_alert(conn, TLS_alert_decode_error);
return -1;
}
conn->server_name = 1;
}
if ((ret = tls12_select_parameters(conn,
common_cipher_suites, common_cipher_suites_cnt,
common_supported_groups, common_supported_groups_cnt,
common_signature_algorithms, common_signature_algorithms_cnt,
cert_signature_algorithms, cert_signature_algorithms_cnt,
host_name, host_name_len)) < 0) {
error_print();
tls_send_alert(conn, TLS_alert_decode_error);
return -1;
} else if (ret == 0) {
error_print();
tls_send_alert(conn, TLS_alert_handshake_failure);
return -1;
}
if (tls12_cipher_suite_get(conn->cipher_suite, &conn->cipher, &conn->digest) != 1) {
error_print();
tls13_send_alert(conn, TLS_alert_internal_error);
return -1;
}
if (digest_init(&conn->dgst_ctx, conn->digest) != 1) { if (digest_init(&conn->dgst_ctx, conn->digest) != 1) {
error_print(); error_print();