Clean code

This commit is contained in:
Zhi Guan
2026-05-28 10:00:53 +08:00
parent d90e7638fb
commit 164561ee94
3 changed files with 325 additions and 1268 deletions

1143
src/tlcp.c

File diff suppressed because it is too large Load Diff

View File

@@ -1503,39 +1503,31 @@ static const int tls12_ciphers[] = {
static const int tls13_ciphers[] = {
TLS_cipher_sm4_gcm_sm3,
TLS_cipher_aes_128_gcm_sha256,
};
int tls_cipher_suite_support_protocol(int cipher, int protocol)
int tls_cipher_suite_match_protocol(int cipher, int protocol)
{
const int *ciphers;
size_t ciphers_cnt;
switch (protocol) {
case TLS_protocol_tlcp:
ciphers = tlcp_ciphers;
ciphers_cnt = sizeof(tlcp_ciphers)/sizeof(tlcp_ciphers[0]);
if (!tls_type_is_in_list(cipher, tlcp_ciphers, sizeof(tlcp_ciphers)/sizeof(tlcp_ciphers[0]))) {
return 0;
}
break;
case TLS_protocol_tls12:
ciphers = tls12_ciphers;
ciphers_cnt = sizeof(tls12_ciphers)/sizeof(tls12_ciphers[0]);
if (!tls_type_is_in_list(cipher, tls12_ciphers, sizeof(tls12_ciphers)/sizeof(tls12_ciphers[0]))) {
return 0;
}
break;
case TLS_protocol_tls13:
ciphers = tls13_ciphers;
ciphers_cnt = sizeof(tls13_ciphers)/sizeof(tls13_ciphers[0]);
if (!tls_type_is_in_list(cipher, tls13_ciphers, sizeof(tls13_ciphers)/sizeof(tls13_ciphers[0]))) {
return 0;
}
break;
default:
error_print();
return -1;
}
/*
if (!tls_cipher_suite_in_list(cipher, ciphers, ciphers_cnt)) {
error_print();
return 0;
}
*/
return 1;
}

View File

@@ -9,6 +9,7 @@
#include <time.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
@@ -23,35 +24,8 @@
#include <gmssl/mem.h>
#include <gmssl/tls.h>
#include <errno.h>
// 现在client_certificate_verify做的是不好的
/*
是否要求客户端提供证书是服务器决定的服务器方需要提供相应的CA证书
对于服务器来说CONN中保存的CA证书是用于验证客户端的
对于客户端来说,这些证书是用于验证服务器的
服务器知道是否验证客户端证书是通过是否有CA证书判断的
客户端是通过什么判断的?
实际上客户端可以提供备选的证书,因此应该有一个标识符来标记
现在通盘考虑一下ECDHE过程中双方需要准备什么
Client: X509_KEY私钥自己生成
从服务器那边拿到的对方的X509_KEY公钥
服务器:自己的私钥,以及对方的公钥
每次如果发现当前的缓冲区是有数据的record_left > 0说明已经组装完了还没发送完
然后我们就要调用一个函数来发送数据
*/
// 实际上这个功能本质上是把缓冲区的数据发出去
static const int tls12_ciphers[] = {
TLS_cipher_ecdhe_sm4_cbc_sm3,
@@ -360,9 +334,6 @@ void tls_clean_record(TLS_CONNECT *conn)
}
int tls_handshake_init(TLS_CONNECT *conn)
{
@@ -449,49 +420,6 @@ int tls_send_client_hello(TLS_CONNECT *conn)
return 1;
}
//static const int tlcp_ciphers[] = { TLS_cipher_ecc_sm4_cbc_sm3 };
int tlcp_send_client_hello(TLS_CONNECT *conn)
{
int ret;
uint8_t *record = conn->record;
if (!conn->recordlen) {
tls_record_set_protocol(record, TLS_protocol_tlcp);
if (tls_random_generate(conn->client_random) != 1) {
error_print();
return -1;
}
if (tls_record_set_handshake_client_hello(conn->record, &conn->recordlen,
conn->protocol, conn->client_random, NULL, 0,
conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt,
NULL, 0) != 1) {
error_print();
return -1;
}
// offset = 0, recordlen > 0
tls_trace("send ClientHello\n");
tlcp_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
}
if (conn->client_certificate_verify) {
sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
tls_clean_record(conn);
return 1;
}
@@ -503,107 +431,6 @@ int tlcp_send_client_hello(TLS_CONNECT *conn)
int tlcp_recv_client_hello(TLS_CONNECT *conn)
{
int ret;
uint8_t *record = conn->record;
int client_verify = 0;
int protocol;
const uint8_t *client_random;
const uint8_t *session_id;
size_t session_id_len;
const uint8_t *client_ciphers;
size_t client_ciphers_len;
const uint8_t *client_exts;
size_t client_exts_len;
//sm3_init(&conn->sm3_ctx);
// 服务器端如果设置了CA
if (conn->ctx->cacertslen)
client_verify = 1;
// 这个判断应该改为一个函数
if (client_verify)
tls_client_verify_init(&conn->client_verify_ctx);
tls_trace("recv ClientHello\n");
if ((ret = tls_recv_record(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN) {
error_print();
}
return ret;
}
tlcp_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
// 这里TLCP和TLS12是不一样的
if (tls_record_protocol(record) != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_protocol_version);
return -1;
}
if (tls_record_get_handshake_client_hello(record,
&protocol, &client_random, &session_id, &session_id_len,
&client_ciphers, &client_ciphers_len,
&client_exts, &client_exts_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
if (protocol != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_protocol_version);
return -1;
}
memcpy(conn->client_random, client_random, 32);
if (tls_cipher_suites_select(client_ciphers, client_ciphers_len,
conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt,
&conn->cipher_suite) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_insufficient_security);
return -1;
}
switch (conn->cipher_suite) {
case TLS_cipher_ecc_sm4_cbc_sm3:
conn->signature_algorithms[0] = TLS_sig_sm2sig_sm3;
conn->ecdh_named_curve = 0;
break;
case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_sm4_gcm_sm3:
default:
error_print();
return -1;
}
if (client_exts) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (client_verify)
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
fprintf(stderr, "end of recv_client_hello\n");
tls_clean_record(conn);
return 1;
}
@@ -1005,65 +832,7 @@ int tls_recv_server_certificate(TLS_CONNECT *conn)
return 1;
}
int tlcp_send_server_key_exchange(TLS_CONNECT *conn)
{
SM2_SIGN_CTX sign_ctx;
uint8_t sigbuf[SM2_MAX_SIGNATURE_SIZE];
size_t siglen;
const uint8_t *server_enc_cert;
size_t server_enc_cert_len;
uint8_t server_enc_cert_lenbuf[3];
uint8_t *p;
size_t len;
int ret;
tls_trace("send ServerKeyExchange\n");
if (conn->recordlen == 0) {
if (x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, 1,
&server_enc_cert, &server_enc_cert_len) != 1) {
error_print();
return -1;
}
p = server_enc_cert_lenbuf;
len = 0;
tls_uint24_to_bytes((uint24_t)server_enc_cert_len, &p, &len);
if (sm2_sign_init(&sign_ctx, &conn->sign_key.u.sm2_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1
|| sm2_sign_update(&sign_ctx, conn->client_random, 32) != 1
|| sm2_sign_update(&sign_ctx, conn->server_random, 32) != 1
|| sm2_sign_update(&sign_ctx, server_enc_cert_lenbuf, 3) != 1
|| sm2_sign_update(&sign_ctx, server_enc_cert, server_enc_cert_len) != 1
|| sm2_sign_finish(&sign_ctx, sigbuf, &siglen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
if (tlcp_record_set_handshake_server_key_exchange_pke(conn->record, &conn->recordlen, sigbuf, siglen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tlcp_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (conn->client_certificate_verify) {
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
}
return 1;
}
int tls_send_server_key_exchange(TLS_CONNECT *conn)
{
@@ -1347,78 +1116,6 @@ int tls_recv_server_key_exchange(TLS_CONNECT *conn)
return 1;
}
int tlcp_recv_server_key_exchange(TLS_CONNECT *conn)
{
const uint8_t *sig;
size_t siglen;
const uint8_t *cp;
size_t len;
X509_KEY server_sign_key;
const uint8_t *server_enc_cert;
size_t server_enc_cert_len;
uint8_t server_enc_cert_lenbuf[3];
uint8_t *p;
SM2_VERIFY_CTX verify_ctx;
tls_trace("recv ServerKeyExchange\n");
if (tls_record_recv(conn->record, &conn->recordlen, conn->sock) != 1
|| tls_record_protocol(conn->record) != TLS_protocol_tlcp) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
tlcp_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
if (tlcp_record_get_handshake_server_key_exchange_pke(conn->record, &sig, &siglen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
// verify ServerKeyExchange
if (x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, 0, &cp, &len) != 1
|| x509_cert_get_subject_public_key(cp, len, &server_sign_key) != 1
|| x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, 1, &server_enc_cert, &server_enc_cert_len) != 1
|| x509_cert_get_subject_public_key(server_enc_cert, server_enc_cert_len, &conn->server_enc_key) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
if (server_sign_key.algor != OID_ec_public_key
|| server_sign_key.algor_param != OID_sm2
|| conn->server_enc_key.algor != OID_ec_public_key
|| conn->server_enc_key.algor_param != OID_sm2) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
p = server_enc_cert_lenbuf;
len = 0;
tls_uint24_to_bytes((uint24_t)server_enc_cert_len, &p, &len);
if (sm2_verify_init(&verify_ctx, &server_sign_key.u.sm2_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1
|| sm2_verify_update(&verify_ctx, conn->client_random, 32) != 1
|| sm2_verify_update(&verify_ctx, conn->server_random, 32) != 1
|| sm2_verify_update(&verify_ctx, server_enc_cert_lenbuf, 3) != 1
|| sm2_verify_update(&verify_ctx, server_enc_cert, server_enc_cert_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
if (sm2_verify_finish(&verify_ctx, sig, siglen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_decrypt_error);
return -1;
}
return 1;
}
int tls_send_certificate_request(TLS_CONNECT *conn)
{
@@ -1750,48 +1447,8 @@ int tls_generate_keys(TLS_CONNECT *conn)
return 1;
}
// 对于客户端是先发送client_key_exchange在generate_keys
int tlcp_generate_keys(TLS_CONNECT *conn)
{
tls_trace("generate secrets\n");
if (tls_prf(conn->pre_master_secret, 48, "master secret",
conn->client_random, 32,
conn->server_random, 32,
48, conn->master_secret) != 1
|| tls_prf(conn->master_secret, 48, "key expansion",
conn->server_random, 32, // 这里顺序为什么是反的
conn->client_random, 32,
96, conn->key_block) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
// 主力这里是不对的需要为client, server设定不同的加密密钥
sm3_hmac_init(&conn->client_write_mac_ctx, conn->key_block, 32);
sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32);
if (conn->is_client) {
sm4_set_encrypt_key(&conn->client_write_enc_key, conn->key_block + 64);
sm4_set_decrypt_key(&conn->server_write_enc_key, conn->key_block + 80);
} else {
sm4_set_decrypt_key(&conn->client_write_enc_key, conn->key_block + 64);
sm4_set_encrypt_key(&conn->server_write_enc_key, conn->key_block + 80);
}
tls_secrets_print(stderr,
conn->pre_master_secret, 48,
conn->client_random, conn->server_random,
conn->master_secret,
conn->key_block, 96,
0, 4);
return 1;
}
int tls_send_client_key_exchange(TLS_CONNECT *conn)
{
int ret;
@@ -1882,80 +1539,9 @@ int tls_recv_client_key_exchange(TLS_CONNECT *conn)
return 1;
}
// 对于TLCP是否应该先执行send_client_key_exchange再生成密钥呢
int tlcp_send_client_key_exchange(TLS_CONNECT *conn)
{
uint8_t enced_pre_master_secret[SM2_MAX_CIPHERTEXT_SIZE];
size_t enced_pre_master_secret_len;
tls_trace("send ClientKeyExchange\n");
if (tls_pre_master_secret_generate(conn->pre_master_secret, TLS_protocol_tlcp) != 1) {
error_print();
return -1;
}
if (sm2_encrypt(&conn->server_enc_key.u.sm2_key, conn->pre_master_secret, 48,
enced_pre_master_secret, &enced_pre_master_secret_len) != 1
|| tls_record_set_handshake_client_key_exchange_pke(conn->record, &conn->recordlen,
enced_pre_master_secret, enced_pre_master_secret_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tlcp_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
if (tls_record_send(conn->record, conn->recordlen, conn->sock) != 1) {
error_print();
return -1;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
return 1;
}
int tlcp_recv_client_key_exchange(TLS_CONNECT *conn)
{
const uint8_t *enced_pms;
size_t enced_pms_len;
size_t pre_master_secret_len;
tls_trace("recv ClientKeyExchange\n");
if (tls_record_recv(conn->record, &conn->recordlen, conn->sock) != 1
|| tls_record_protocol(conn->record) != TLS_protocol_tlcp) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
tlcp_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
if (tls_record_get_handshake_client_key_exchange_pke(conn->record, &enced_pms, &enced_pms_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
// FIXME:
// 这里需要检查一下密钥的长度,因为输入的长度是确定的,因此输出的密文长度应该也是确定的
if (sm2_decrypt(&conn->kenc_key.u.sm2_key, enced_pms, enced_pms_len,
conn->pre_master_secret, &pre_master_secret_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_decrypt_error);
return -1;
}
if (pre_master_secret_len != 48) {
error_print();
tls_send_alert(conn, TLS_alert_decrypt_error);
return -1;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
return 1;
}
int tls_send_certificate_verify(TLS_CONNECT *conn)
{
int ret;
@@ -2423,10 +2009,6 @@ int tls_recv_server_finished(TLS_CONNECT *conn)
<-------- Finished
Application Data <-------> Application Data
*/
int tls12_do_client_handshake(TLS_CONNECT *conn)