Clean TLS_CTX/CONNECT

This commit is contained in:
Zhi Guan
2026-06-12 12:01:30 +08:00
parent 30bc6a2a4c
commit fb93fba5ff
5 changed files with 73 additions and 116 deletions

View File

@@ -786,10 +786,6 @@ typedef struct {
uint8_t cert_chains[8192]; uint8_t cert_chains[8192];
size_t cert_chains_len; size_t cert_chains_len;
size_t cert_chains_cnt; // 这是一个多余的值,不应该存储多余的值
// size_t cert_chain_idx; // == 1 mean the first certificate
uint8_t *certs; // 这里应该改为cert_chain我们将certs表示为互相独立的证书
size_t certslen;
// 每个证书链都应该有附带的status_request和sct信息 // 每个证书链都应该有附带的status_request和sct信息
@@ -804,8 +800,6 @@ typedef struct {
X509_KEY enc_keys[4]; X509_KEY enc_keys[4];
size_t x509_keys_cnt; size_t x509_keys_cnt;
X509_KEY signkey;
X509_KEY kenckey;
// 对于客户端来说需要提供所有的CA证书注意这里不是证书链而是一个个独立的证书 // 对于客户端来说需要提供所有的CA证书注意这里不是证书链而是一个个独立的证书
// 对于服务器来说在certificate_request中需要从这些证书中提取dn_names并发送给客户端然后再验证客户端证书 // 对于服务器来说在certificate_request中需要从这些证书中提取dn_names并发送给客户端然后再验证客户端证书
@@ -1057,8 +1051,6 @@ typedef struct {
// 一般来说我们只要保存对方发过来的证书因为己方的证书都在CTX中对吗 // 一般来说我们只要保存对方发过来的证书因为己方的证书都在CTX中对吗
uint8_t server_certs[TLS_MAX_CERTIFICATES_SIZE]; // TODO: use ptr and malloc
size_t server_certs_len;
uint8_t client_certs[TLS_MAX_CERTIFICATES_SIZE]; uint8_t client_certs[TLS_MAX_CERTIFICATES_SIZE];
size_t client_certs_len; size_t client_certs_len;
@@ -1086,8 +1078,6 @@ typedef struct {
size_t peer_cert_chain_len; size_t peer_cert_chain_len;
X509_KEY sign_key;
X509_KEY kenc_key; // 应该作为服务器的SM2加密
X509_KEY server_enc_key; X509_KEY server_enc_key;
int verify_result; int verify_result;
@@ -1102,9 +1092,6 @@ typedef struct {
HMAC_CTX client_write_mac_ctx; HMAC_CTX client_write_mac_ctx;
HMAC_CTX server_write_mac_ctx; HMAC_CTX server_write_mac_ctx;
SM4_KEY client_write_enc_key;
SM4_KEY server_write_enc_key;
@@ -1138,15 +1125,6 @@ typedef struct {
SM2_SIGN_CTX sign_ctx; SM2_SIGN_CTX sign_ctx;
TLS_CLIENT_VERIFY_CTX client_verify_ctx; TLS_CLIENT_VERIFY_CTX client_verify_ctx;
// 所有这些命名为ecdh的都需要替换掉
uint16_t ecdh_named_curve;
X509_KEY ecdh_keys[2];
size_t ecdh_keys_cnt;
X509_KEY ecdh_key;
uint8_t peer_ecdh_point[65];
size_t peer_ecdh_point_len;
// HelloRetryRequest // HelloRetryRequest
int hello_retry_request; int hello_retry_request;

View File

@@ -666,13 +666,6 @@ int tlcp_recv_server_certificate(TLS_CONNECT *conn)
tls_send_alert(conn, TLS_alert_unexpected_message); tls_send_alert(conn, TLS_alert_unexpected_message);
return -1; return -1;
} }
if (conn->peer_cert_chain_len > sizeof(conn->server_certs)) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
memcpy(conn->server_certs, conn->peer_cert_chain, conn->peer_cert_chain_len);
conn->server_certs_len = conn->peer_cert_chain_len;
if (x509_certs_get_cert_by_index(conn->peer_cert_chain, conn->peer_cert_chain_len, if (x509_certs_get_cert_by_index(conn->peer_cert_chain, conn->peer_cert_chain_len,
0, &server_cert, &server_cert_len) != 1) { 0, &server_cert, &server_cert_len) != 1) {
@@ -1279,18 +1272,9 @@ static int tlcp_cert_chains_select(TLS_CONNECT *conn,
} }
} }
if (cert_chain_len > sizeof(conn->server_certs)) {
error_print();
return -1;
}
conn->cert_chain = cert_chain; conn->cert_chain = cert_chain;
conn->cert_chain_len = cert_chain_len; conn->cert_chain_len = cert_chain_len;
conn->cert_chain_idx = cert_chain_idx; conn->cert_chain_idx = cert_chain_idx;
conn->sign_key = conn->ctx->x509_keys[cert_chain_idx - 1];
conn->kenc_key = conn->ctx->enc_keys[cert_chain_idx - 1];
memcpy(conn->server_certs, cert_chain, cert_chain_len);
conn->server_certs_len = cert_chain_len;
conn->signature_algorithms[0] = TLS_sig_sm2sig_sm3; conn->signature_algorithms[0] = TLS_sig_sm2sig_sm3;
conn->signature_algorithms_cnt = 1; conn->signature_algorithms_cnt = 1;
return 1; return 1;
@@ -1392,7 +1376,6 @@ int tlcp_recv_client_hello(TLS_CONNECT *conn)
case TLS_cipher_ecc_sm4_cbc_sm3: case TLS_cipher_ecc_sm4_cbc_sm3:
case TLS_cipher_ecc_sm4_gcm_sm3: case TLS_cipher_ecc_sm4_gcm_sm3:
conn->signature_algorithms[0] = TLS_sig_sm2sig_sm3; conn->signature_algorithms[0] = TLS_sig_sm2sig_sm3;
conn->ecdh_named_curve = 0;
break; break;
case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_sm4_gcm_sm3: case TLS_cipher_ecdhe_sm4_gcm_sm3:
@@ -1700,7 +1683,21 @@ int tlcp_send_server_hello(TLS_CONNECT *conn)
} }
// 因为这个就是发送证书而已和tls12是没有什么区别的 /*
-- IBC_SM4_CBC_SM3 和 IBC_SM4_GCM_SM3 套件的服务器Certificate消息格式
opaque ASN.1IBCParam<1..2^24-1>
struct {
opaque ibc_id<1..2^16-1>;
ASN.1IBCParam ibc_parameter;
} Certificate;
其中ibc_id是服务器的SM9的ID这个ID暂时是一个没有内部结构的字节串后续有可能是一个DER结构的字节串
ibc_parameter 是SM9 sm9_enc_master_key_to_der 输出的DER数据
*/
int tlcp_send_server_certificate(TLS_CONNECT *conn) int tlcp_send_server_certificate(TLS_CONNECT *conn)
{ {
int ret; int ret;
@@ -1759,6 +1756,8 @@ int tlcp_send_server_key_exchange(TLS_CONNECT *conn)
tls_trace("send ServerKeyExchange\n"); tls_trace("send ServerKeyExchange\n");
if (conn->recordlen == 0) { if (conn->recordlen == 0) {
X509_KEY *sign_key;
if (!conn->cert_chain || !conn->cert_chain_len || !conn->cert_chain_idx) { if (!conn->cert_chain || !conn->cert_chain_len || !conn->cert_chain_idx) {
error_print(); error_print();
tls_send_alert(conn, TLS_alert_internal_error); tls_send_alert(conn, TLS_alert_internal_error);
@@ -1780,7 +1779,13 @@ int tlcp_send_server_key_exchange(TLS_CONNECT *conn)
return -1; return -1;
} }
if (sm2_sign_init(&sign_ctx, &conn->sign_key.u.sm2_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1 sign_key = &conn->ctx->x509_keys[conn->cert_chain_idx - 1];
if (sign_key->algor != OID_ec_public_key || sign_key->algor_param != OID_sm2) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
if (sm2_sign_init(&sign_ctx, &sign_key->u.sm2_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1
|| sm2_sign_update(&sign_ctx, conn->client_random, 32) != 1 || sm2_sign_update(&sign_ctx, conn->client_random, 32) != 1
|| sm2_sign_update(&sign_ctx, conn->server_random, 32) != 1 || sm2_sign_update(&sign_ctx, conn->server_random, 32) != 1
|| sm2_sign_update(&sign_ctx, server_ecc_params, server_ecc_params_len) != 1 || sm2_sign_update(&sign_ctx, server_ecc_params, server_ecc_params_len) != 1
@@ -2053,6 +2058,7 @@ int tlcp_recv_client_key_exchange(TLS_CONNECT *conn)
const uint8_t *enced_pms; const uint8_t *enced_pms;
size_t enced_pms_len; size_t enced_pms_len;
size_t pre_master_secret_len; size_t pre_master_secret_len;
X509_KEY *enc_key;
tls_trace("recv ClientKeyExchange\n"); tls_trace("recv ClientKeyExchange\n");
@@ -2069,8 +2075,13 @@ int tlcp_recv_client_key_exchange(TLS_CONNECT *conn)
tls_send_alert(conn, TLS_alert_unexpected_message); tls_send_alert(conn, TLS_alert_unexpected_message);
return -1; return -1;
} }
if (!conn->cert_chain_idx || conn->kenc_key.algor != OID_ec_public_key if (!conn->cert_chain_idx) {
|| conn->kenc_key.algor_param != OID_sm2) { error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
enc_key = &conn->ctx->enc_keys[conn->cert_chain_idx - 1];
if (enc_key->algor != OID_ec_public_key || enc_key->algor_param != OID_sm2) {
error_print(); error_print();
tls_send_alert(conn, TLS_alert_internal_error); tls_send_alert(conn, TLS_alert_internal_error);
return -1; return -1;
@@ -2079,7 +2090,7 @@ int tlcp_recv_client_key_exchange(TLS_CONNECT *conn)
// FIXME: // FIXME:
// 这里需要检查一下密钥的长度,因为输入的长度是确定的,因此输出的密文长度应该也是确定的 // 这里需要检查一下密钥的长度,因为输入的长度是确定的,因此输出的密文长度应该也是确定的
if (sm2_decrypt(&conn->kenc_key.u.sm2_key, enced_pms, enced_pms_len, if (sm2_decrypt(&enc_key->u.sm2_key, enced_pms, enced_pms_len,
conn->pre_master_secret, &pre_master_secret_len) != 1) { conn->pre_master_secret, &pre_master_secret_len) != 1) {
error_print(); error_print();
tls_send_alert(conn, TLS_alert_decrypt_error); tls_send_alert(conn, TLS_alert_decrypt_error);

103
src/tls.c
View File

@@ -2410,9 +2410,6 @@ void tls_ctx_cleanup(TLS_CTX *ctx)
x509_key_cleanup(&ctx->x509_keys[i]); x509_key_cleanup(&ctx->x509_keys[i]);
x509_key_cleanup(&ctx->enc_keys[i]); x509_key_cleanup(&ctx->enc_keys[i]);
} }
x509_key_cleanup(&ctx->signkey);
x509_key_cleanup(&ctx->kenckey);
if (ctx->certs) free(ctx->certs);
if (ctx->cacerts) free(ctx->cacerts); if (ctx->cacerts) free(ctx->cacerts);
memset(ctx, 0, sizeof(TLS_CTX)); memset(ctx, 0, sizeof(TLS_CTX));
} }
@@ -2678,18 +2675,9 @@ int tls_ctx_add_certificate_chain_and_key(TLS_CTX *ctx, const char *chainfile,
} }
// 保留这个函数,相当于是对证书链的初始化
int tls_ctx_set_certificate_and_key(TLS_CTX *ctx, const char *chainfile, int tls_ctx_set_certificate_and_key(TLS_CTX *ctx, const char *chainfile,
const char *keyfile, const char *keypass) const char *keyfile, const char *keypass)
{ {
int ret = -1;
uint8_t *certs = NULL;
size_t certslen;
FILE *keyfp = NULL;
const uint8_t *cert;
size_t certlen;
X509_KEY public_key;
if (!ctx || !chainfile || !keyfile || !keypass) { if (!ctx || !chainfile || !keyfile || !keypass) {
error_print(); error_print();
return -1; return -1;
@@ -2698,47 +2686,15 @@ int tls_ctx_set_certificate_and_key(TLS_CTX *ctx, const char *chainfile,
error_print(); error_print();
return -1; return -1;
} }
if (ctx->certs) { if (ctx->cert_chains_len || ctx->x509_keys_cnt) {
error_print(); error_print();
return -1; return -1;
} }
if (tls_ctx_add_certificate_chain_and_key(ctx, chainfile, keyfile, keypass) != 1) {
if (x509_certs_new_from_file(&certs, &certslen, chainfile) != 1) {
error_print();
goto end;
}
if (x509_certs_get_cert_by_index(certs, certslen, 0, &cert, &certlen) != 1
|| x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) {
error_print(); error_print();
return -1; return -1;
} }
return 1;
if (public_key.algor == OID_ec_public_key) {
if (!(keyfp = fopen(keyfile, "r"))) {
error_print();
return -1;
}
} else {
if (!(keyfp = fopen(keyfile, "rb+"))) {
error_print();
return -1;
}
}
if (x509_private_key_from_file(&ctx->signkey, public_key.algor, keypass, keyfp) != 1) {
error_print();
return -1;
}
ctx->certs = certs;
ctx->certslen = certslen;
certs = NULL;
ret = 1;
end:
if (certs) free(certs);
if (keyfp) fclose(keyfp);
return ret;
} }
int tlcp_ctx_add_server_certificate_and_keys(TLS_CTX *ctx, const char *chainfile, int tlcp_ctx_add_server_certificate_and_keys(TLS_CTX *ctx, const char *chainfile,
@@ -2832,12 +2788,7 @@ int tlcp_ctx_add_server_certificate_and_keys(TLS_CTX *ctx, const char *chainfile
} }
ctx->cert_chains_len += cert_chains_len; ctx->cert_chains_len += cert_chains_len;
ctx->cert_chains_cnt++;
ctx->x509_keys_cnt++; ctx->x509_keys_cnt++;
if (key_idx == 0) {
ctx->signkey = ctx->x509_keys[0];
ctx->kenckey = ctx->enc_keys[0];
}
ret = 1; ret = 1;
end: end:
@@ -2969,6 +2920,29 @@ int tls_ctx_set_key_update_seq_num_limit(TLS_CTX *ctx, size_t max_seq_num)
} }
static int tls_ctx_get_certificate_chain(const TLS_CTX *ctx, size_t idx,
const uint8_t **cert_chain, size_t *cert_chain_len)
{
const uint8_t *p;
size_t len;
size_t i;
if (!ctx || !cert_chain || !cert_chain_len || !idx) {
error_print();
return -1;
}
p = ctx->cert_chains;
len = ctx->cert_chains_len;
for (i = 1; i <= idx; i++) {
if (tls_uint24array_from_bytes(cert_chain, cert_chain_len, &p, &len) != 1) {
error_print();
return -1;
}
}
return 1;
}
int tls_init(TLS_CONNECT *conn, TLS_CTX *ctx) int tls_init(TLS_CONNECT *conn, TLS_CTX *ctx)
{ {
if (!conn || !ctx) { if (!conn || !ctx) {
@@ -2986,21 +2960,22 @@ int tls_init(TLS_CONNECT *conn, TLS_CTX *ctx)
conn->protocol = ctx->protocol; conn->protocol = ctx->protocol;
if (ctx->certslen > TLS_MAX_CERTIFICATES_SIZE) { if (conn->is_client && ctx->cert_chains_len) {
error_print(); if (tls_ctx_get_certificate_chain(ctx, 1,
return -1; &conn->cert_chain, &conn->cert_chain_len) != 1) {
} error_print();
if (conn->is_client) { return -1;
memcpy(conn->client_certs, ctx->certs, ctx->certslen); }
conn->client_certs_len = ctx->certslen; if (conn->cert_chain_len > sizeof(conn->client_certs)) {
} else { error_print();
memcpy(conn->server_certs, ctx->certs, ctx->certslen); return -1;
conn->server_certs_len = ctx->certslen; }
memcpy(conn->client_certs, conn->cert_chain, conn->cert_chain_len);
conn->client_certs_len = conn->cert_chain_len;
conn->cert_chain_idx = 1;
} }
conn->sign_key = ctx->signkey;
conn->kenc_key = ctx->kenckey;
conn->ctx = ctx; conn->ctx = ctx;
conn->key_exchanges_cnt = ctx->key_exchanges_cnt; conn->key_exchanges_cnt = ctx->key_exchanges_cnt;

View File

@@ -757,10 +757,6 @@ int tls_handshake_init(TLS_CONNECT *conn)
digest_init(&conn->dgst_ctx, DIGEST_sm3()); digest_init(&conn->dgst_ctx, DIGEST_sm3());
if (conn->client_certs_len) {
//sm2_sign_init(&conn->sign_ctx, &conn->sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH);
}
return 1; return 1;
} }
@@ -2378,9 +2374,6 @@ int tls_recv_server_key_exchange(TLS_CONNECT *conn)
} }
// named_curve应该在supported_groups里面 // named_curve应该在supported_groups里面
//conn->ecdh_named_curve = named_curve;
conn->key_exchange_group = named_curve; conn->key_exchange_group = named_curve;
memcpy(conn->peer_key_exchange, point_octets, point_octets_len); memcpy(conn->peer_key_exchange, point_octets, point_octets_len);
conn->peer_key_exchange_len = point_octets_len; conn->peer_key_exchange_len = point_octets_len;

View File

@@ -329,7 +329,7 @@ bad:
perror("fopen"); perror("fopen");
goto end; goto end;
} }
if (x509_certs_to_pem(conn.server_certs, conn.server_certs_len, certoutfp) != 1) { if (x509_certs_to_pem(conn.peer_cert_chain, conn.peer_cert_chain_len, certoutfp) != 1) {
fprintf(stderr, "%s: x509_certs_to_pem error\n", prog); fprintf(stderr, "%s: x509_certs_to_pem error\n", prog);
fclose(certoutfp); fclose(certoutfp);
goto end; goto end;