Files
GmSSL/src/tls12.c
2026-05-27 18:17:43 +08:00

2782 lines
71 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/*
* Copyright 2014-2026 The GmSSL Project. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
#include <time.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <gmssl/rand.h>
#include <gmssl/x509.h>
#include <gmssl/error.h>
#include <gmssl/sm2.h>
#include <gmssl/sm3.h>
#include <gmssl/sm4.h>
#include <gmssl/pem.h>
#include <gmssl/mem.h>
#include <gmssl/tls.h>
#include <fcntl.h>
#include <errno.h>
#include <sys/select.h>
// 现在client_certificate_verify做的是不好的
/*
是否要求客户端提供证书是服务器决定的服务器方需要提供相应的CA证书
对于服务器来说CONN中保存的CA证书是用于验证客户端的
对于客户端来说,这些证书是用于验证服务器的
服务器知道是否验证客户端证书是通过是否有CA证书判断的
客户端是通过什么判断的?
实际上客户端可以提供备选的证书,因此应该有一个标识符来标记
现在通盘考虑一下ECDHE过程中双方需要准备什么
Client: X509_KEY私钥自己生成
从服务器那边拿到的对方的X509_KEY公钥
服务器:自己的私钥,以及对方的公钥
每次如果发现当前的缓冲区是有数据的record_left > 0说明已经组装完了还没发送完
然后我们就要调用一个函数来发送数据
*/
// 实际上这个功能本质上是把缓冲区的数据发出去
static const int tls12_ciphers[] = {
TLS_cipher_ecdhe_sm4_cbc_sm3,
TLS_cipher_ecdhe_sm4_gcm_sm3,
TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256,
};
int tls12_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent)
{
// 目前只支持TLCP的ECC公钥加密套件因此不论用哪个套件解析都是一样的
// 如果未来支持ECDHE套件可以将函数改为宏直接传入 (conn->cipher_suite << 8)
format |= tls12_ciphers[0] << 8; // 应该是KeyExchange需要这个参数
return tls_record_print(fp, record, recordlen, format, indent);
}
// modify: conn->record_offset
int tls_send_record(TLS_CONNECT *conn)
{
size_t left;
tls_ret_t n;
left = tls_record_length(conn->record) - conn->record_offset;
while (left) {
n = tls_socket_send(conn->sock, conn->record + conn->record_offset, left, 0);
if (n < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
return TLS_ERROR_SEND_AGAIN;
} else if (errno == EINTR) {
continue;
} else {
fprintf(stderr, "%s %d: send() error: %s\n", __FILE__, __LINE__, strerror(errno));
error_print();
return -1;
}
}
conn->record_offset += n;
left -= n;
}
return 1;
}
int tls_recv_record(TLS_CONNECT *conn)
{
size_t left;
tls_ret_t n;
if (conn->recordlen) {
return 1;
}
if (conn->record_offset < 5) {
left = 5 - conn->record_offset;
while (left) {
n = tls_socket_recv(conn->sock, conn->record + conn->record_offset, left, 0);
if (n < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
return TLS_ERROR_RECV_AGAIN;
} else if (errno == EINTR) {
continue;
} else {
error_print();
// TODO: check the usage of OpenSSL SSL_ERR_SYSCALL
// if applications such as Nginx, HTTPD do not use this error, we just return -1
return TLS_ERROR_SYSCALL;
}
} else if (n == 0) {
error_print();
return TLS_ERROR_TCP_CLOSED;
}
conn->record_offset += n;
left -= n;
}
}
if (conn->record_offset == 5) {
if (!tls_record_type_name(tls_record_type(conn->record))) {
error_print();
return -1;
}
if (!tls_protocol_name(tls_record_protocol(conn->record))) {
error_print();
return -1;
}
if (tls_record_length(conn->record) > TLS_MAX_RECORD_SIZE) {
error_print();
return -1;
}
}
if (conn->record_offset >= tls_record_length(conn->record)) {
error_print();
return -1;
}
left = tls_record_length(conn->record) - conn->record_offset;
while (left) {
n = tls_socket_recv(conn->sock, conn->record + conn->record_offset, left, 0);
if (n < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
return TLS_ERROR_RECV_AGAIN;
} else if (errno == EINTR) {
continue;
} else {
error_print();
return TLS_ERROR_SYSCALL;
}
} else if (n == 0) {
error_print();
return TLS_ERROR_TCP_CLOSED;
}
conn->record_offset += n;
left -= n;
}
conn->recordlen = conn->record_offset;
// 应该判断是否为Alert这种异常状况
return 1;
}
int tls_named_curve_oid(int named_curve)
{
switch (named_curve) {
case TLS_curve_secp256r1: return OID_secp256r1;
case TLS_curve_sm2p256v1: return OID_sm2;
}
return OID_undef;
}
int tls_named_curve_from_oid(int oid)
{
switch (oid) {
case OID_secp256r1: return TLS_curve_secp256r1;
case OID_sm2: return TLS_curve_sm2p256v1;
}
return 0;
}
// 这个是必选的
const int ec_point_formats[] = { TLS_point_uncompressed };
size_t ec_point_formats_cnt = sizeof(ec_point_formats)/sizeof(ec_point_formats[0]);
// 服务器通常推荐返回这个值
const int supported_groups[] = {
TLS_curve_sm2p256v1,
TLS_curve_secp256r1,
};
size_t supported_groups_cnt = sizeof(supported_groups)/sizeof(supported_groups[0]);
// 仍旧是不可设置的
const int signature_algors[] = {
TLS_sig_sm2sig_sm3,
TLS_sig_ecdsa_secp256r1_sha256,
};
size_t signature_algors_cnt = sizeof(signature_algors)/sizeof(signature_algors[0]);
int tls_record_set_handshake_server_key_exchange(uint8_t *record, size_t *recordlen,
const uint8_t *server_ecdh_params, size_t server_ecdh_params_len,
uint16_t sig_alg, const uint8_t *sig, size_t siglen)
{
const int type = TLS_handshake_server_key_exchange;
uint8_t *p = tls_handshake_data(tls_record_data(record));
size_t len = 0;
if (server_ecdh_params_len != 69) {
error_print();
return -1;
}
if (siglen > TLS_MAX_SIGNATURE_SIZE) {
error_print();
return -1;
}
tls_array_to_bytes(server_ecdh_params, server_ecdh_params_len, &p, &len);
tls_uint16_to_bytes(sig_alg, &p, &len);
tls_uint16array_to_bytes(sig, siglen, &p, &len);
tls_record_set_handshake(record, recordlen, type, NULL, len);
return 1;
}
// 这个函数是有问题的因为tlcp的格式和TLS不一样
int tls_record_get_handshake_server_key_exchange(const uint8_t *record,
uint8_t *curve_type, uint16_t *named_curve,
const uint8_t **point_octets, size_t *point_octets_len,
const uint8_t **server_ecdh_params, size_t *server_ecdh_params_len,
uint16_t *sig_alg, const uint8_t **sig, size_t *siglen)
{
int type;
const uint8_t *p;
size_t len;
if (tls_record_get_handshake(record, &type, &p, &len) != 1) {
error_print();
return -1;
}
if (type != TLS_handshake_server_key_exchange) {
error_print();
return -1;
}
*server_ecdh_params = p;
if (tls_uint8_from_bytes(curve_type, &p, &len) != 1
|| tls_uint16_from_bytes(named_curve, &p, &len) != 1
|| tls_uint8array_from_bytes(point_octets, point_octets_len, &p, &len) != 1) {
error_print();
return -1;
}
*server_ecdh_params_len = p - *server_ecdh_params;
if (*server_ecdh_params_len != 69) {
error_print();
return -1;
}
if (tls_uint16_from_bytes(sig_alg, &p, &len) != 1
|| tls_uint16array_from_bytes(sig, siglen, &p, &len) != 1
|| tls_length_is_zero(len) != 1) {
error_print();
return -1;
}
if (*curve_type != TLS_curve_type_named_curve) {
error_print();
return -1;
}
if (!tls_named_curve_name(*named_curve)) {
error_print();
return -1;
}
if (!tls_signature_scheme_name(*sig_alg)) {
error_print();
return -1;
}
return 1;
}
int tls_record_set_handshake_client_key_exchange(uint8_t *record, size_t *recordlen,
const uint8_t *point_octets, size_t point_octets_len)
{
int type = TLS_handshake_client_key_exchange;
uint8_t *p = tls_handshake_data(tls_record_data(record));
size_t len = 0;
if (point_octets_len != 65) {
error_print();
return -1;
}
tls_uint8array_to_bytes(point_octets, (uint8_t)point_octets_len, &p, &len);
tls_record_set_handshake(record, recordlen, type, NULL, len);
return 1;
}
int tls_record_get_handshake_client_key_exchange(const uint8_t *record,
const uint8_t **point_octets, size_t *point_octets_len)
{
int type;
const uint8_t *p;
size_t len;
if (tls_record_get_handshake(record, &type, &p, &len) != 1
|| type != TLS_handshake_client_key_exchange) {
error_print();
return -1;
}
if (tls_uint8array_from_bytes(point_octets, point_octets_len, &p, &len) != 1) {
error_print();
return -1;
}
if (*point_octets_len != 65) {
error_print();
return -1;
}
if (len) {
error_print();
return -1;
}
return 1;
}
void tls_clean_record(TLS_CONNECT *conn)
{
conn->record_offset = 0;
conn->recordlen = 0;
}
int tls_handshake_init(TLS_CONNECT *conn)
{
sm3_init(&conn->sm3_ctx);
digest_init(&conn->dgst_ctx, DIGEST_sm3());
if (conn->client_certs_len) {
//sm2_sign_init(&conn->sign_ctx, &conn->sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH);
}
return 1;
}
/*
TLCP协议中ClientHello中不包含扩展并且cipher_suites使用的是TLCP的cipher_suites
ciphers我觉得应该在设置ctx的时候设置好
exts
*/
int tls_send_client_hello(TLS_CONNECT *conn)
{
int ret;
uint8_t *record = conn->record;
if (!conn->recordlen) {
uint8_t client_exts[TLS_MAX_EXTENSIONS_SIZE];
uint8_t *p = client_exts;
size_t client_exts_len = 0;
switch (conn->protocol) {
case TLS_protocol_tls12:
tls_record_set_protocol(record, TLS_protocol_tls1);
break;
case TLS_protocol_tlcp:
tls_record_set_protocol(record, TLS_protocol_tlcp);
break;
default:
error_print();
return -1;
}
if (tls_random_generate(conn->client_random) != 1) {
error_print();
return -1;
}
if (tls_ec_point_formats_ext_to_bytes(ec_point_formats, ec_point_formats_cnt, &p, &client_exts_len) != 1
|| tls_supported_groups_ext_to_bytes(supported_groups, supported_groups_cnt, &p, &client_exts_len) != 1
|| tls_signature_algorithms_ext_to_bytes(signature_algors, signature_algors_cnt, &p, &client_exts_len) != 1) {
error_print();
return -1;
}
if (tls_record_set_handshake_client_hello(conn->record, &conn->recordlen,
conn->protocol, conn->client_random, NULL, 0,
conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt,
client_exts, client_exts_len) != 1) {
error_print();
return -1;
}
// offset = 0, recordlen > 0
tls_trace("send ClientHello\n");
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
}
// 客户端一开始不知道是否要进行客户端验证
// 如果用户提供了客户端证书那么就准备,如果没提供就完全不准确
// 客户端的证书可以通过回调函数来设置
if (conn->client_certificate_verify) {
sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
tls_clean_record(conn);
return 1;
}
//static const int tlcp_ciphers[] = { TLS_cipher_ecc_sm4_cbc_sm3 };
int tlcp_send_client_hello(TLS_CONNECT *conn)
{
int ret;
uint8_t *record = conn->record;
if (!conn->recordlen) {
tls_record_set_protocol(record, TLS_protocol_tlcp);
if (tls_random_generate(conn->client_random) != 1) {
error_print();
return -1;
}
if (tls_record_set_handshake_client_hello(conn->record, &conn->recordlen,
conn->protocol, conn->client_random, NULL, 0,
conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt,
NULL, 0) != 1) {
error_print();
return -1;
}
// offset = 0, recordlen > 0
tls_trace("send ClientHello\n");
tlcp_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
}
if (conn->client_certificate_verify) {
sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
tls_clean_record(conn);
return 1;
}
int tlcp_recv_client_hello(TLS_CONNECT *conn)
{
int ret;
uint8_t *record = conn->record;
size_t recordlen;
int client_verify = 0;
int protocol;
const uint8_t *client_random;
const uint8_t *session_id;
size_t session_id_len;
const uint8_t *client_ciphers;
size_t client_ciphers_len;
const uint8_t *client_exts;
size_t client_exts_len;
//sm3_init(&conn->sm3_ctx);
// 服务器端如果设置了CA
if (conn->ctx->cacertslen)
client_verify = 1;
// 这个判断应该改为一个函数
if (client_verify)
tls_client_verify_init(&conn->client_verify_ctx);
tls_trace("recv ClientHello\n");
if ((ret = tls_recv_record(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN) {
error_print();
}
return ret;
}
tlcp_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
// 这里TLCP和TLS12是不一样的
if (tls_record_protocol(record) != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_protocol_version);
return -1;
}
if (tls_record_get_handshake_client_hello(record,
&protocol, &client_random, &session_id, &session_id_len,
&client_ciphers, &client_ciphers_len,
&client_exts, &client_exts_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
if (protocol != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_protocol_version);
return -1;
}
memcpy(conn->client_random, client_random, 32);
if (tls_cipher_suites_select(client_ciphers, client_ciphers_len,
conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt,
&conn->cipher_suite) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_insufficient_security);
return -1;
}
switch (conn->cipher_suite) {
case TLS_cipher_ecc_sm4_cbc_sm3:
conn->signature_algorithms[0] = TLS_sig_sm2sig_sm3;
conn->ecdh_named_curve = 0;
break;
case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_sm4_gcm_sm3:
default:
error_print();
return -1;
}
if (client_exts) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (client_verify)
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
fprintf(stderr, "end of recv_client_hello\n");
tls_clean_record(conn);
return 1;
}
/*
const int server_ciphers[] = { TLS_cipher_ecdhe_sm4_cbc_sm3 };
const size_t server_ciphers_cnt = 1;
*/
const int curve = TLS_curve_sm2p256v1;
// 服务器在收到ClientHello之后
int tls_recv_client_hello(TLS_CONNECT *conn)
{
int ret;
uint8_t *record = conn->record;
size_t recordlen;
int client_verify = 0;
int protocol;
const uint8_t *client_random;
const uint8_t *session_id;
size_t session_id_len;
const uint8_t *client_ciphers;
size_t client_ciphers_len;
const uint8_t *client_exts;
size_t client_exts_len;
sm3_init(&conn->sm3_ctx);
// 服务器端如果设置了CA
if (conn->ctx->cacertslen)
client_verify = 1;
// 这个判断应该改为一个函数
if (client_verify)
tls_client_verify_init(&conn->client_verify_ctx);
tls_trace("recv ClientHello\n");
if ((ret = tls_recv_record(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN) {
error_print();
}
return ret;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
if (tls_record_protocol(record) != conn->protocol
&& tls_record_protocol(record) != TLS_protocol_tls1) {
error_print();
tls_send_alert(conn, TLS_alert_protocol_version);
return -1;
}
if (tls_record_get_handshake_client_hello(record,
&protocol, &client_random, &session_id, &session_id_len,
&client_ciphers, &client_ciphers_len,
&client_exts, &client_exts_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
if (protocol != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_protocol_version);
return -1;
}
memcpy(conn->client_random, client_random, 32);
// 服务器选择的cipher_suites需要和服务器准备的证书和公钥匹配
if (tls_cipher_suites_select(client_ciphers, client_ciphers_len,
conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt,
&conn->cipher_suite) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_insufficient_security);
return -1;
}
switch (conn->cipher_suite) {
case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_sm4_gcm_sm3:
conn->signature_algorithms[0] = TLS_sig_sm2sig_sm3;
conn->ecdh_named_curve = TLS_curve_sm2p256v1;
break;
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
conn->signature_algorithms[0] = TLS_sig_ecdsa_secp256r1_sha256;
conn->ecdh_named_curve = TLS_curve_secp256r1;
break;
default:
error_print();
return -1;
}
if (client_exts) {
// 这些函数需要能够访问conn的内部变量
// 修改处理扩展的逻辑
//tls_process_client_hello_exts(client_exts, client_exts_len,
// conn->server_exts, &conn->server_exts_len, sizeof(conn->server_exts));
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (client_verify)
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
fprintf(stderr, "end of recv_client_hello\n");
tls_clean_record(conn);
return 1;
}
int tls_send_server_hello(TLS_CONNECT *conn)
{
int ret;
tls_trace("send ServerHello\n");
if (conn->recordlen == 0) {
const uint8_t *server_exts = NULL;
size_t server_exts_len = 0;
tls_record_set_protocol(conn->record, conn->protocol);
if (tls_random_generate(conn->server_random) != 1) {
error_print();
return -1;
}
// 修改处理扩展的逻辑把ClientHello的每个扩展处理结果分别放在conn的各个变量中
/*
if (conn->server_exts_len) {
server_exts = conn->server_exts;
server_exts_len = conn->server_exts_len;
}
*/
if (tls_record_set_handshake_server_hello(conn->record, &conn->recordlen,
conn->protocol, conn->server_random, NULL, 0,
conn->cipher_suite,
server_exts, server_exts_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (conn->ctx->cacertslen) {
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
}
tls_clean_record(conn);
return 1;
}
int tls_recv_server_hello(TLS_CONNECT *conn)
{
int ret;
int protocol;
int cipher_suite;
const uint8_t *server_random;
const uint8_t *session_id;
size_t session_id_len;
const uint8_t *server_exts;
size_t server_exts_len;
// 扩展的协商结果,-1 表示服务器不支持该扩展(未给出响应)
int ec_point_format = -1;
int supported_group = -1;
int signature_algor = -1;
// 实际上当前的record已经有完整的数据了但是我们不知道啊
if ((ret = tls_recv_record(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN) {
error_print();
}
return ret;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
if (tls_record_protocol(conn->record) != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_protocol_version);
return -1;
}
if (tls_record_get_handshake_server_hello(conn->record,
&protocol, &server_random, &session_id, &session_id_len, &cipher_suite,
&server_exts, &server_exts_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
if (protocol != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_protocol_version);
return -1;
}
/*
if (tls_cipher_suite_in_list(cipher_suite, conn->cipher_suites, conn->cipher_suites_cnt) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_handshake_failure);
return -1;
}
*/
/*
对于扩展的处理
首先扩展是由客户端ClientHello进行设定服务器选择最后由Client验证的一个过程。
因此客户端和服务器端都需要存储扩展相应的数据。
扩展的初始化是如何实现的我觉得也应该在CTX中完成。
现在对扩展的处理逻辑是有问题的服务器ServerHello是否包含扩展是取决于ClientHello的
*/
if (server_exts) {
if (tls_process_server_hello_exts(server_exts, server_exts_len, &ec_point_format, &supported_group, &signature_algor) != 1
|| ec_point_format < 0
|| supported_group < 0
|| signature_algor < 0) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
}
memcpy(conn->server_random, server_random, 32);
memcpy(conn->session_id, session_id, session_id_len);
conn->cipher_suite = cipher_suite;
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (conn->client_certs_len) {
sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
}
return 1;
}
// TLS12 发送的是常规的证书链
// TLCP SM2 发送的是SM2的双证书链但是在数据格式上没有区别
// TLCP SM9 发送的是服务器的ID和SM9公开参数这个格式是不同的但是存储上可能也是一样的
// 我不确定SM2和SM9的格式是否是相容的
int tls_send_server_certificate(TLS_CONNECT *conn)
{
int ret;
tls_trace("send ServerCertificate\n");
if (conn->recordlen == 0) {
if (tls_record_set_handshake_certificate(conn->record, &conn->recordlen,
conn->server_certs, conn->server_certs_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (conn->client_certificate_verify) {
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
}
return 1;
}
int tls_recv_server_certificate(TLS_CONNECT *conn)
{
int ret;
int verify_result;
const uint8_t *server_cert;
size_t server_cert_len;
X509_KEY server_sign_key;
tls_trace("recv ServerCertificate\n");
if ((ret = tls_recv_record(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN) {
error_print();
}
return ret;
}
if (tls_record_protocol(conn->record) != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
if (tls_record_get_handshake_certificate(conn->record,
conn->server_certs, &conn->server_certs_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
// 这下面是对获取的证书链的处理
if (x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, 0,
&server_cert, &server_cert_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
if (x509_cert_get_subject_public_key(server_cert, server_cert_len, &server_sign_key) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
// 这里的逻辑需要统筹考虑
// cipher_suite扩展证书之间的关系
// set conn->server_sig_alg (decided by cipher_suite and server_cert.sign_key.algor, algor_param)
if (server_sign_key.algor != OID_ec_public_key) {
error_print();
return -1;
}
switch (conn->cipher_suite) {
case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_sm4_gcm_sm3:
case TLS_cipher_ecc_sm4_cbc_sm3:
case TLS_cipher_ecc_sm4_gcm_sm3:
if (server_sign_key.algor_param != OID_sm2) {
error_print();
return -1;
}
conn->signature_algorithms[0] = TLS_sig_sm2sig_sm3;
break;
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
if (server_sign_key.algor_param != OID_secp256r1) {
error_print();
return -1;
}
conn->signature_algorithms[0] = TLS_sig_ecdsa_secp256r1_sha256;
break;
default:
error_print();
return -1;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (conn->client_certs_len) {
sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
}
assert(conn->ctx->verify_depth > 0 && conn->ctx->verify_depth < 10);
// verify ServerCertificate
switch (conn->protocol) {
case TLS_protocol_tls12:
if (x509_certs_verify(conn->server_certs, conn->server_certs_len, X509_cert_chain_server,
conn->ctx->cacerts, conn->ctx->cacertslen, conn->ctx->verify_depth, &verify_result) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
break;
case TLS_protocol_tlcp:
if (!conn->ctx->cacertslen) {
error_print();
return -1;
}
if (x509_certs_verify_tlcp(conn->server_certs, conn->server_certs_len, X509_cert_chain_server,
conn->ctx->cacerts, conn->ctx->cacertslen, conn->ctx->verify_depth, &verify_result) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
break;
default:
error_print();
return -1;
}
return 1;
}
int tlcp_send_server_key_exchange(TLS_CONNECT *conn)
{
SM2_SIGN_CTX sign_ctx;
uint8_t sigbuf[SM2_MAX_SIGNATURE_SIZE];
size_t siglen;
const uint8_t *server_enc_cert;
size_t server_enc_cert_len;
uint8_t server_enc_cert_lenbuf[3];
uint8_t *p;
size_t len;
int ret;
tls_trace("send ServerKeyExchange\n");
if (conn->recordlen == 0) {
if (x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, 1,
&server_enc_cert, &server_enc_cert_len) != 1) {
error_print();
return -1;
}
p = server_enc_cert_lenbuf;
len = 0;
tls_uint24_to_bytes((uint24_t)server_enc_cert_len, &p, &len);
if (sm2_sign_init(&sign_ctx, &conn->sign_key.u.sm2_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1
|| sm2_sign_update(&sign_ctx, conn->client_random, 32) != 1
|| sm2_sign_update(&sign_ctx, conn->server_random, 32) != 1
|| sm2_sign_update(&sign_ctx, server_enc_cert_lenbuf, 3) != 1
|| sm2_sign_update(&sign_ctx, server_enc_cert, server_enc_cert_len) != 1
|| sm2_sign_finish(&sign_ctx, sigbuf, &siglen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
if (tlcp_record_set_handshake_server_key_exchange_pke(conn->record, &conn->recordlen, sigbuf, siglen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tlcp_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (conn->client_certificate_verify) {
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
}
return 1;
}
int tls_send_server_key_exchange(TLS_CONNECT *conn)
{
int ret;
uint8_t server_ecdh_params[69];
uint8_t *p = server_ecdh_params + 4;
size_t len = 0;
X509_SIGN_CTX sign_ctx;
const void *sign_args = NULL;
size_t sign_argslen = 0;
uint8_t sig[X509_SIGNATURE_MAX_SIZE];
size_t siglen;
tls_trace("send ServerKeyExchange\n");
if (conn->recordlen == 0) {
int curve_oid = tls_named_curve_oid(conn->ecdh_named_curve);
// generate server ecdh_key
if (x509_key_generate(&conn->ecdh_key, OID_ec_public_key, &curve_oid, sizeof(curve_oid)) != 1) {
error_print();
return -1;
}
// build server_ecdh_params
server_ecdh_params[0] = TLS_curve_type_named_curve;
server_ecdh_params[1] = conn->ecdh_named_curve >> 8;
server_ecdh_params[2] = conn->ecdh_named_curve;
server_ecdh_params[3] = 65;
if (x509_public_key_to_bytes(&conn->ecdh_key, &p, &len) != 1) {
error_print();
return -1;
}
if (len != 65) {
error_print();
return -1;
}
// sign server_ecdh_params
if (conn->sign_key.algor == OID_ec_public_key && conn->sign_key.algor_param == OID_sm2) {
sign_args = SM2_DEFAULT_ID;
sign_argslen = SM2_DEFAULT_ID_LENGTH;
}
if (x509_sign_init(&sign_ctx, &conn->sign_key, sign_args, sign_argslen) != 1
|| x509_sign_update(&sign_ctx, conn->client_random, 32) != 1
|| x509_sign_update(&sign_ctx, conn->server_random, 32) != 1
|| x509_sign_update(&sign_ctx, server_ecdh_params, 69) != 1
|| x509_sign_finish(&sign_ctx, sig, &siglen) != 1) {
x509_sign_ctx_cleanup(&sign_ctx);
error_print();
return -1;
}
x509_sign_ctx_cleanup(&sign_ctx);
if (tls_record_set_handshake_server_key_exchange(conn->record, &conn->recordlen,
server_ecdh_params, sizeof(server_ecdh_params),
conn->signature_algorithms[0], sig, siglen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (conn->client_certificate_verify) {
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
}
return 1;
}
// match the ecdhe of cipher_suite
int tls_curve_match_cipher_suite(int named_curve, int cipher_suite)
{
switch (named_curve) {
case TLS_curve_sm2p256v1:
switch (cipher_suite) {
case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_sm4_gcm_sm3:
break;
default:
error_print();
return -1;
}
break;
case TLS_curve_secp256r1:
if (cipher_suite != TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256) {
error_print();
return -1;
}
break;
default:
error_print();
return -1;
}
return 1;
}
int tls_signature_scheme_match_cipher_suite(int sig_alg, int cipher_suite)
{
switch (sig_alg) {
case TLS_sig_sm2sig_sm3:
switch (cipher_suite) {
case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_sm4_gcm_sm3:
case TLS_cipher_ecc_sm4_cbc_sm3:
case TLS_cipher_ecc_sm4_gcm_sm3:
break;
default:
error_print();
return -1;
}
break;
case TLS_sig_ecdsa_secp256r1_sha256:
switch (cipher_suite) {
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
break;
default:
error_print();
return -1;
}
break;
default:
error_print();
return -1;
}
return 1;
}
// 这里应该给一个新的key_exchange
int tls_recv_server_key_exchange(TLS_CONNECT *conn)
{
int ret;
uint8_t curve_type;
uint16_t named_curve;
const uint8_t *point_octets;
size_t point_octets_len;
const uint8_t *server_ecdh_params;
size_t server_ecdh_params_len;
uint16_t sig_alg;
const uint8_t *sig;
size_t siglen;
// verify ServerKeyExchange
X509_KEY server_sign_key;
int server_cert_index = 0;
const uint8_t *server_cert;
size_t server_cert_len;
uint16_t tls_sig_alg; // 这个值没有初始化
// 这属于握手过程中决定的具体算法,因此握手完成之后就应该确定下来了
// 这应该是由cipher_suite和服务器证书中公钥named_curve)共同决定的
X509_SIGN_CTX sign_ctx;
const void *sign_args = NULL;
size_t sign_argslen = 0;
tls_trace("recv ServerKeyExchange\n");
if (tls_record_recv(conn->record, &conn->recordlen, conn->sock) != 1
|| tls_record_protocol(conn->record) != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
// 这里应该改为首先获取一个基础的ServerKeyExchange的数据已经经过了验证签名
// 然后再用不同的函数(不同协议)去分解
if (tls_record_get_handshake_server_key_exchange(conn->record,
&curve_type, &named_curve, &point_octets, &point_octets_len,
&server_ecdh_params, &server_ecdh_params_len,
&sig_alg, &sig, &siglen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
if (curve_type != TLS_curve_type_named_curve) {
error_print();
return -1;
}
// named_curve应该在supported_groups里面
conn->ecdh_named_curve = named_curve;
if (point_octets_len != 65) {
error_print();
return -1;
}
memcpy(conn->peer_ecdh_point, point_octets, point_octets_len);
conn->peer_ecdh_point_len = point_octets_len;
if (tls_curve_match_cipher_suite(named_curve, conn->cipher_suite) != 1) {
error_print();
return -1;
}
if (point_octets_len != 65) {
error_print();
return -1;
}
if (tls_signature_scheme_match_cipher_suite(sig_alg, conn->cipher_suite) != 1) {
error_print();
return -1;
}
// 解析server_key_exchange, curve_type, curve_name, point 这三个信息
// 判断curve_type == named_curve
// 判断curve_name在supported_groups中并记录这个信息
// 验证point确实在curve_name的group中
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (conn->client_certs_len)
sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
// get sign_key from first cert of server_certs
if (x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len,
server_cert_index, &server_cert, &server_cert_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
if (x509_cert_get_subject_public_key(server_cert, server_cert_len, &server_sign_key) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
// 这个检查是否是多余的?
// 这个值是签名算法和椭圆曲线名字的结合
// cipher_suite只能决定签名算法类型
// 公钥证书里面的公钥实际上只包含曲线的类型(而不决定签名算法,因为一个椭圆曲线本质上支持多种不同的签名算法)
switch (sig_alg) {
case TLS_sig_sm2sig_sm3:
if (server_sign_key.algor != OID_ec_public_key
|| server_sign_key.algor_param != OID_sm2) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
break;
case TLS_sig_ecdsa_secp256r1_sha256:
if (server_sign_key.algor != OID_ec_public_key
|| server_sign_key.algor_param != OID_secp256r1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
break;
default:
error_print();
return -1;
}
if (server_sign_key.algor == OID_ec_public_key && server_sign_key.algor_param == OID_sm2) {
sign_args = SM2_DEFAULT_ID;
sign_argslen = SM2_DEFAULT_ID_LENGTH;
}
if (x509_verify_init(&sign_ctx, &server_sign_key, sign_args, sign_argslen, sig, siglen) != 1
|| x509_verify_update(&sign_ctx, conn->client_random, 32) != 1
|| x509_verify_update(&sign_ctx, conn->server_random, 32) != 1
|| x509_verify_update(&sign_ctx, server_ecdh_params, 69) != 1
|| x509_verify_finish(&sign_ctx) != 1) {
error_print();
return -1;
}
// xxxx
// 这里的签名错了肯定是sign_ctx就是不对的因此是不可能正确的
// 现在要做的是必须确定server_key_exchange中都包括了哪些被签名的消息
return 1;
}
int tlcp_recv_server_key_exchange(TLS_CONNECT *conn)
{
const uint8_t *sig;
size_t siglen;
const uint8_t *cp;
size_t len;
X509_KEY server_sign_key;
const uint8_t *server_enc_cert;
size_t server_enc_cert_len;
uint8_t server_enc_cert_lenbuf[3];
uint8_t *p;
SM2_VERIFY_CTX verify_ctx;
tls_trace("recv ServerKeyExchange\n");
if (tls_record_recv(conn->record, &conn->recordlen, conn->sock) != 1
|| tls_record_protocol(conn->record) != TLS_protocol_tlcp) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
tlcp_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
if (tlcp_record_get_handshake_server_key_exchange_pke(conn->record, &sig, &siglen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
// verify ServerKeyExchange
if (x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, 0, &cp, &len) != 1
|| x509_cert_get_subject_public_key(cp, len, &server_sign_key) != 1
|| x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, 1, &server_enc_cert, &server_enc_cert_len) != 1
|| x509_cert_get_subject_public_key(server_enc_cert, server_enc_cert_len, &conn->server_enc_key) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
if (server_sign_key.algor != OID_ec_public_key
|| server_sign_key.algor_param != OID_sm2
|| conn->server_enc_key.algor != OID_ec_public_key
|| conn->server_enc_key.algor_param != OID_sm2) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
p = server_enc_cert_lenbuf;
len = 0;
tls_uint24_to_bytes((uint24_t)server_enc_cert_len, &p, &len);
if (sm2_verify_init(&verify_ctx, &server_sign_key.u.sm2_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1
|| sm2_verify_update(&verify_ctx, conn->client_random, 32) != 1
|| sm2_verify_update(&verify_ctx, conn->server_random, 32) != 1
|| sm2_verify_update(&verify_ctx, server_enc_cert_lenbuf, 3) != 1
|| sm2_verify_update(&verify_ctx, server_enc_cert, server_enc_cert_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
if (sm2_verify_finish(&verify_ctx, sig, siglen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_decrypt_error);
return -1;
}
return 1;
}
int tls_send_certificate_request(TLS_CONNECT *conn)
{
int ret;
// 如果要进行客户端证书验证,服务器要提供验证的证书,但是所有证书的
const uint8_t cert_types[] = { TLS_cert_type_ecdsa_sign };
uint8_t ca_names[TLS_MAX_CA_NAMES_SIZE] = {0}; // TODO: 根据客户端验证CA证书列计算缓冲大小或直接输出到record缓冲
size_t ca_names_len = 0;
if (!conn->client_certificate_verify) {
error_print();
return -1;
}
if (conn->recordlen == 0) {
tls_trace("send CertificateRequest\n");
if (tls_authorities_from_certs(ca_names, &ca_names_len, sizeof(ca_names),
conn->ctx->cacerts, conn->ctx->cacertslen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
if (tls_record_set_handshake_certificate_request(conn->record, &conn->recordlen,
cert_types, sizeof(cert_types),
ca_names, ca_names_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
return 1;
}
// 如果收到的是预期的报文就处理然后将recordlen = 0
// 否则保留recordlen
int tls_recv_certificate_request(TLS_CONNECT *conn)
{
int ret;
uint8_t *record = conn->record;
size_t recordlen;
const uint8_t *cp;
size_t len;
int handshake_type;
const uint8_t *cert_types;
size_t cert_types_len;
const uint8_t *ca_names;
size_t ca_names_len;
// recv CertificateRequest or ServerHelloDone
if ((ret = tls_recv_record(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN) {
error_print();
}
return ret;
}
if (tls_record_protocol(conn->record) != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
if (tls_record_get_handshake(record, &handshake_type, &cp, &len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
if (handshake_type != TLS_handshake_certificate_request) {
conn->client_certs_len = 0;
fprintf(stderr, "%s %d: no certificate_request\n", __FILE__, __LINE__);
fprintf(stderr, "recordlen = %zu\n", conn->recordlen);
return 0; // 表明对方没有发送预期的报文
}
tls_trace("recv CertificateRequest\n");
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
if (tls_record_get_handshake_certificate_request(conn->record,
&cert_types, &cert_types_len, &ca_names, &ca_names_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
if(!conn->client_certs_len) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
if (tls_cert_types_accepted(cert_types, cert_types_len, conn->client_certs, conn->client_certs_len) != 1
|| tls_authorities_issued_certificate(ca_names, ca_names_len, conn->client_certs, conn->client_certs_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unsupported_certificate);
return -1;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
conn->recordlen = 0;
return 1;
}
int tls_send_server_hello_done(TLS_CONNECT *conn)
{
int ret;
tls_trace("send ServerHelloDone\n");
if (conn->recordlen == 0) {
tls_record_set_handshake_server_hello_done(conn->record, &conn->recordlen);
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (conn->client_certs_len) {
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
}
return 1;
}
// 这是一个非常特殊的状态其他的所有recv状态都是要读取的
// 但是这个状态在大多数情况下,之前已经读取完了,但是我们无法判断这个信息
int tls_recv_server_hello_done(TLS_CONNECT *conn)
{
int ret;
tls_trace("recv ServerHelloDone\n");
if ((ret = tls_recv_record(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN) {
error_print();
}
return ret;
}
if (tls_record_protocol(conn->record) != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
if (tls_record_get_handshake_server_hello_done(conn->record) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (conn->client_certs_len)
sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
return 1;
}
int tls_send_client_certificate(TLS_CONNECT *conn)
{
int ret;
tls_trace("send ClientCertificate\n");
if (conn->client_certs_len == 0) {
error_print();
return -1;
}
if (conn->recordlen == 0) {
if (tls_record_set_handshake_certificate(conn->record, &conn->recordlen,
conn->client_certs, conn->client_certs_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
return 1;
}
// 只有在需要验证客户端证书的时候这个函数才执行,是否内部要判断一下
int tls_recv_client_certificate(TLS_CONNECT *conn)
{
int ret;
const int verify_depth = 5;
int verify_result;
tls_trace("recv ClientCertificate\n");
if (conn->ctx->cacertslen == 0) {
error_print();
return -1;
}
if ((ret = tls_recv_record(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN) {
error_print();
}
return ret;
}
if (tls_record_protocol(conn->record) != conn->protocol) { // protocol检查应该在trace之后
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
if (tls_record_get_handshake_certificate(conn->record, conn->client_certs, &conn->client_certs_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
if (x509_certs_verify(conn->client_certs, conn->client_certs_len, X509_cert_chain_client,
conn->ctx->cacerts, conn->ctx->cacertslen, verify_depth, &verify_result) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
return 1;
}
/*
不同密码套件中使用的密钥生成方法不一样,需要的输入也不一样。
应该考虑在CONN中维护union版本的ServerKeyExchange和ClientKeyExchange
ServerKeyExchange ClientKeyExchange
ECDHE Server X509_KEY X509_KEY public
ECDHE Client X509_KEY X509_KEY public
ECC Server N/A SM2Cipher
ECC Client N/A N/A 客户端可能需要服务器公钥的类型
我们现在还不支持SM2的ECDH呢
SM9相关的密码套件呢
*/
int tls_generate_keys(TLS_CONNECT *conn)
{
uint8_t pre_master_secret[32];
size_t pre_master_secret_len;
uint8_t key_block[96];
// 此时已经获得了ServerKeyExchange和ClientKeyExchange
// 但是不同密码套件中这些KeyExchange的数据其实是不一样的
// 我们需要根据不同的套件去解析数据,并且根据不同的数据类型去生成密钥
// 还需要检查 TLS 1.3 的协议
if (x509_key_exchange(&conn->ecdh_key,
conn->peer_ecdh_point, conn->peer_ecdh_point_len,
pre_master_secret, &pre_master_secret_len) != 1) {
error_print();
return -1;
}
if (pre_master_secret_len != sizeof(pre_master_secret)) {
error_print();
return -1;
}
if (tls_prf(pre_master_secret, 32, "master secret",
conn->client_random, 32,
conn->server_random, 32,
48, conn->master_secret) != 1) {
error_print();
return -1;
}
if (tls_prf(conn->master_secret, 48, "key expansion",
conn->server_random, 32,
conn->client_random, 32,
96, conn->key_block) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
sm3_hmac_init(&conn->client_write_mac_ctx, conn->key_block, 32);
sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32);
if (conn->is_client) {
sm4_set_encrypt_key(&conn->client_write_enc_key, conn->key_block + 64);
sm4_set_decrypt_key(&conn->server_write_enc_key, conn->key_block + 80);
} else {
sm4_set_decrypt_key(&conn->client_write_enc_key, conn->key_block + 64);
sm4_set_encrypt_key(&conn->server_write_enc_key, conn->key_block + 80);
}
tls_secrets_print(stderr,
pre_master_secret, 48,
conn->client_random, conn->server_random,
conn->master_secret,
conn->key_block, 96,
0, 4);
return 1;
}
// 对于客户端是先发送client_key_exchange在generate_keys
int tlcp_generate_keys(TLS_CONNECT *conn)
{
uint8_t enced_pre_master_secret[SM2_MAX_CIPHERTEXT_SIZE];
size_t enced_pre_master_secret_len;
tls_trace("generate secrets\n");
if (tls_prf(conn->pre_master_secret, 48, "master secret",
conn->client_random, 32,
conn->server_random, 32,
48, conn->master_secret) != 1
|| tls_prf(conn->master_secret, 48, "key expansion",
conn->server_random, 32, // 这里顺序为什么是反的
conn->client_random, 32,
96, conn->key_block) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
// 主力这里是不对的需要为client, server设定不同的加密密钥
sm3_hmac_init(&conn->client_write_mac_ctx, conn->key_block, 32);
sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32);
if (conn->is_client) {
sm4_set_encrypt_key(&conn->client_write_enc_key, conn->key_block + 64);
sm4_set_decrypt_key(&conn->server_write_enc_key, conn->key_block + 80);
} else {
sm4_set_decrypt_key(&conn->client_write_enc_key, conn->key_block + 64);
sm4_set_encrypt_key(&conn->server_write_enc_key, conn->key_block + 80);
}
tls_secrets_print(stderr,
conn->pre_master_secret, 48,
conn->client_random, conn->server_random,
conn->master_secret,
conn->key_block, 96,
0, 4);
return 1;
}
int tls_send_client_key_exchange(TLS_CONNECT *conn)
{
int ret;
uint8_t point_octets[65];
uint8_t *p = point_octets;
size_t len = 0;
// 客户端的ECDHE的公钥肯定和服务器是保持一致的
// 因此在接收到服务器的公钥之后,应该保存这个信息
if (conn->recordlen == 0) {
int curve_oid = tls_named_curve_oid(conn->ecdh_named_curve);
if (x509_key_generate(&conn->ecdh_key, OID_ec_public_key, &curve_oid, sizeof(curve_oid)) != 1) {
error_print();
return -1;
}
if (x509_public_key_to_bytes(&conn->ecdh_key, &p, &len) != 1) {
error_print();
return -1;
}
if (len != sizeof(point_octets)) {
error_print();
return -1;
}
tls_trace("send ClientKeyExchange\n");
if (tls_record_set_handshake_client_key_exchange(conn->record, &conn->recordlen,
point_octets, len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (conn->client_certs_len)
sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
return 1;
}
int tls_recv_client_key_exchange(TLS_CONNECT *conn)
{
int ret;
const uint8_t *point_octets;
size_t point_octets_len;
tls_trace("recv ClientKeyExchange\n");
if ((ret = tls_recv_record(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN) {
error_print();
}
return ret;
}
if (tls_record_protocol(conn->record) != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
if (tls_record_get_handshake_client_key_exchange(conn->record,
&point_octets, &point_octets_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
if (point_octets_len != 65) {
error_print();
return -1;
}
memcpy(conn->peer_ecdh_point, point_octets, point_octets_len);
conn->peer_ecdh_point_len = point_octets_len;
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
if (conn->ctx->cacertslen)
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
return 1;
}
// 对于TLCP是否应该先执行send_client_key_exchange再生成密钥呢
int tlcp_send_client_key_exchange(TLS_CONNECT *conn)
{
uint8_t enced_pre_master_secret[SM2_MAX_CIPHERTEXT_SIZE];
size_t enced_pre_master_secret_len;
tls_trace("send ClientKeyExchange\n");
if (tls_pre_master_secret_generate(conn->pre_master_secret, TLS_protocol_tlcp) != 1) {
error_print();
return -1;
}
if (sm2_encrypt(&conn->server_enc_key.u.sm2_key, conn->pre_master_secret, 48,
enced_pre_master_secret, &enced_pre_master_secret_len) != 1
|| tls_record_set_handshake_client_key_exchange_pke(conn->record, &conn->recordlen,
enced_pre_master_secret, enced_pre_master_secret_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tlcp_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
if (tls_record_send(conn->record, conn->recordlen, conn->sock) != 1) {
error_print();
return -1;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
return 1;
}
int tlcp_recv_client_key_exchange(TLS_CONNECT *conn)
{
const uint8_t *enced_pms;
size_t enced_pms_len;
size_t pre_master_secret_len;
tls_trace("recv ClientKeyExchange\n");
if (tls_record_recv(conn->record, &conn->recordlen, conn->sock) != 1
|| tls_record_protocol(conn->record) != TLS_protocol_tlcp) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
tlcp_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
if (tls_record_get_handshake_client_key_exchange_pke(conn->record, &enced_pms, &enced_pms_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
// FIXME:
// 这里需要检查一下密钥的长度,因为输入的长度是确定的,因此输出的密文长度应该也是确定的
if (sm2_decrypt(&conn->kenc_key.u.sm2_key, enced_pms, enced_pms_len,
conn->pre_master_secret, &pre_master_secret_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_decrypt_error);
return -1;
}
if (pre_master_secret_len != 48) {
error_print();
tls_send_alert(conn, TLS_alert_decrypt_error);
return -1;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
return 1;
}
int tls_send_certificate_verify(TLS_CONNECT *conn)
{
int ret;
uint8_t sig[SM2_MAX_SIGNATURE_SIZE];
size_t siglen;
tls_trace("send CertificateVerify\n");
if (!conn->client_certificate_verify) {
error_print();
return -1;
}
if (conn->recordlen == 0) {
if (sm2_sign_finish(&conn->sign_ctx, sig, &siglen) != 1) {
error_print();
return -1;
}
if (tls_record_set_handshake_certificate_verify(conn->record, &conn->recordlen, sig, siglen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
return 1;
}
int tls_recv_certificate_verify(TLS_CONNECT *conn)
{
int ret;
X509_KEY client_sign_key;
const uint8_t *sig;
size_t siglen;
const uint8_t *client_cert;
size_t client_cert_len;
if (!conn->client_certificate_verify) {
error_print();
return -1;
}
tls_trace("recv CertificateVerify\n");
if ((ret = tls_recv_record(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN) {
error_print();
}
return ret;
}
if (tls_record_protocol(conn->record) != conn->protocol) {
tls_send_alert(conn, TLS_alert_unexpected_message);
error_print();
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
// get signature from certificate_verify
if (tls_record_get_handshake_certificate_verify(conn->record, &sig, &siglen) != 1) {
tls_send_alert(conn, TLS_alert_unexpected_message);
error_print();
return -1;
}
// get sign_key from client certificate
if (x509_certs_get_cert_by_index(conn->client_certs, conn->client_certs_len, 0,
&client_cert, &client_cert_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
if (x509_cert_get_subject_public_key(client_cert, client_cert_len, &client_sign_key) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
// 这里是否要验证证书的类型呢?我们现在还不支持其他签名算法
if (client_sign_key.algor != OID_ec_public_key
|| client_sign_key.algor_param != OID_sm2) {
error_print();
tls_send_alert(conn, TLS_alert_bad_certificate);
return -1;
}
if (tls_client_verify_finish(&conn->client_verify_ctx, sig, siglen, &client_sign_key.u.sm2_key) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_decrypt_error);
return -1;
}
sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
return 1;
}
int tls_send_change_cipher_spec(TLS_CONNECT *conn)
{
int ret;
if (conn->recordlen == 0) {
tls_trace("send [ChangeCipherSpec]\n");
if (tls_record_set_change_cipher_spec(conn->record, &conn->recordlen) !=1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
return 1;
}
int tls_recv_change_cipher_spec(TLS_CONNECT *conn)
{
int ret;
tls_trace("recv [ChangeCipherSpec]\n");
if ((ret = tls_recv_record(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN) {
error_print();
}
return ret;
}
if (tls_record_protocol(conn->record) != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
if (tls_record_get_change_cipher_spec(conn->record) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
return 1;
}
int tls_send_client_finished(TLS_CONNECT *conn)
{
int ret;
SM3_CTX tmp_sm3_ctx;
uint8_t sm3_hash[32];
uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE];
size_t finished_record_len;
uint8_t local_verify_data[12];
tls_record_set_protocol(finished_record, conn->protocol);
if (conn->recordlen == 0) {
tls_trace("send Finished\n");
// 到目前为止所有消息的哈希
memcpy(&tmp_sm3_ctx, &conn->sm3_ctx, sizeof(SM3_CTX));
sm3_finish(&tmp_sm3_ctx, sm3_hash);
if (tls_prf(conn->master_secret, 48,
"client finished", sm3_hash, 32, NULL, 0,
sizeof(local_verify_data), local_verify_data) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
// finished_record是没有问题的
if (tls_record_set_handshake_finished(finished_record, &finished_record_len,
local_verify_data, sizeof(local_verify_data)) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
// 此时finished_record中的头部应该是完整的
tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
sm3_update(&conn->sm3_ctx, finished_record + 5, finished_record_len - 5);
// encrypt Client Finished
// 此时finished_record中的头部应该是完整的
// encrypt Client Finished
// 但是conn->record并没有设置
// 但是conn->record并没有设置
// 这个会把finished的头部copy过来, 但是握手的头部并没有copy
if (tls_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key,
conn->client_seq_num, finished_record, finished_record_len,
conn->record, &conn->recordlen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
// 最后输出的负载数据是80-bytes而实际上应该是64字节多了16字节这显然是不对的
// 为什么会有这样的结果呢?
tls_encrypted_record_trace(stderr, conn->record, conn->recordlen, (1<<24), 0); // 强制打印密文原数据
// 这里直接就出错了
/*
if (tls_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key,
conn->client_seq_num, conn->record, conn->recordlen, finished_record, &finished_record_len) != 1) {
error_print();
return -1;
}
*/
tls_trace("encrypted ClientFinished\n");
tls_encrypted_record_trace(stderr, conn->record, conn->recordlen, (1<<24), 0); // 强制打印密文原数据
tls_seq_num_incr(conn->client_seq_num);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
return 1;
}
int tls_recv_client_finished(TLS_CONNECT *conn)
{
int ret;
uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE];
size_t finished_record_len;
const uint8_t *verify_data;
size_t verify_data_len;
uint8_t local_verify_data[12];
SM3_CTX tmp_sm3_ctx;
uint8_t sm3_hash[32];
// recv ClientFinished
tls_trace("recv Finished\n");
if ((ret = tls_recv_record(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN) {
error_print();
}
return ret;
}
if (tls_record_protocol(conn->record) != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
if (conn->recordlen > sizeof(finished_record)) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
tls_encrypted_record_trace(stderr, conn->record, conn->recordlen, (1<<24), 0); // 强制打印密文原数据
// decrypt ClientFinished
tls_trace("decrypt Finished\n");
if (tls_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key,
conn->client_seq_num, conn->record, conn->recordlen,
finished_record, &finished_record_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_record_mac);
return -1;
}
tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
tls_seq_num_incr(conn->client_seq_num);
if (tls_record_get_handshake_finished(finished_record, &verify_data, &verify_data_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_record_mac);
return -1;
}
if (verify_data_len != sizeof(local_verify_data)) {
error_print();
tls_send_alert(conn, TLS_alert_bad_record_mac);
return -1;
}
// verify ClientFinished
memcpy(&tmp_sm3_ctx, &conn->sm3_ctx, sizeof(SM3_CTX));
sm3_update(&conn->sm3_ctx, finished_record + 5, finished_record_len - 5);
sm3_finish(&tmp_sm3_ctx, sm3_hash);
if (tls_prf(conn->master_secret, 48, "client finished", sm3_hash, 32, NULL, 0,
sizeof(local_verify_data), local_verify_data) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) {
error_puts("client_finished.verify_data verification failure");
tls_send_alert(conn, TLS_alert_decrypt_error);
return -1;
}
return 1;
}
int tls_send_server_finished(TLS_CONNECT *conn)
{
int ret;
uint8_t *record = conn->record;
size_t recordlen;
uint8_t sm3_hash[32];
uint8_t local_verify_data[12];
uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE];
size_t finished_record_len;
tls_record_set_protocol(finished_record, conn->protocol);
if (conn->recordlen == 0) {
tls_trace("send Finished\n");
sm3_finish(&conn->sm3_ctx, sm3_hash);
if (tls_prf(conn->master_secret, 48, "server finished", sm3_hash, 32, NULL, 0,
sizeof(local_verify_data), local_verify_data) != 1
|| tls_record_set_handshake_finished(finished_record, &finished_record_len,
local_verify_data, sizeof(local_verify_data)) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
if (tls_record_encrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key,
conn->server_seq_num, finished_record, finished_record_len, record, &recordlen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tls_trace("encrypt Finished\n");
tls_encrypted_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
tls_seq_num_incr(conn->server_seq_num);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
return 1;
}
int tls_recv_server_finished(TLS_CONNECT *conn)
{
int ret;
uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE];
size_t finished_record_len;
uint8_t sm3_hash[32];
const uint8_t *verify_data;
size_t verify_data_len;
uint8_t local_verify_data[12];
// Finished
tls_trace("recv Finished\n");
if ((ret = tls_recv_record(conn)) != 1) {
if (ret != TLS_ERROR_RECV_AGAIN) {
error_print();
}
return ret;
}
if (tls_record_protocol(conn->record) != conn->protocol) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
if (conn->recordlen > sizeof(finished_record)) {
error_print(); // 解密可能导致 finished_record 溢出
tls_send_alert(conn, TLS_alert_bad_record_mac);
return -1;
}
tls_encrypted_record_trace(stderr, conn->record, conn->recordlen, (1<<24), 0); // 强制打印密文原数据
tls_trace("decrypt Finished\n");
if (tls_record_decrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key,
conn->server_seq_num, conn->record, conn->recordlen, finished_record, &finished_record_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_bad_record_mac);
return -1;
}
tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
tls_seq_num_incr(conn->server_seq_num);
if (tls_record_get_handshake_finished(finished_record, &verify_data, &verify_data_len) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
if (verify_data_len != sizeof(local_verify_data)) {
error_print();
tls_send_alert(conn, TLS_alert_unexpected_message);
return -1;
}
sm3_finish(&conn->sm3_ctx, sm3_hash);
if (tls_prf(conn->master_secret, 48, "server finished",
sm3_hash, 32, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) {
error_puts("server_finished.verify_data verification failure");
tls_send_alert(conn, TLS_alert_decrypt_error);
return -1;
}
if (!conn->ctx->quiet)
fprintf(stderr, "Connection established!\n");
return 1;
}
/*
Client Server
ClientHello -------->
ServerHello
Certificate
ServerKeyExchange
CertificateRequest*
<-------- ServerHelloDone
Certificate*
ClientKeyExchange
CertificateVerify*
[ChangeCipherSpec]
Finished -------->
[ChangeCipherSpec]
<-------- Finished
Application Data <-------> Application Data
*/
int tls12_do_client_handshake(TLS_CONNECT *conn)
{
int ret;
int next_state;
switch (conn->state) {
case TLS_state_client_hello:
ret = tls_send_client_hello(conn);
next_state = TLS_state_server_hello;
break;
case TLS_state_server_hello:
ret = tls_recv_server_hello(conn);
next_state = TLS_state_server_certificate;
break;
case TLS_state_server_certificate:
ret = tls_recv_server_certificate(conn);
next_state = TLS_state_server_key_exchange;
break;
case TLS_state_server_key_exchange:
ret = tls_recv_server_key_exchange(conn);
next_state = TLS_state_certificate_request;
break;
// the only optional state
case TLS_state_certificate_request:
fprintf(stderr, "TLS_state_certificate_request\n");
ret = tls_recv_certificate_request(conn);
fprintf(stderr, " ret = %d\n", ret);
if (ret == 1) conn->client_certificate_verify = 1;
next_state = TLS_state_server_hello_done;
break;
case TLS_state_server_hello_done:
fprintf(stderr, "TLS_state_server_hello_done\n");
ret = tls_recv_server_hello_done(conn);
if (conn->client_certificate_verify)
next_state = TLS_state_client_certificate;
else next_state = TLS_state_client_key_exchange;
break;
case TLS_state_client_certificate:
ret = tls_send_client_certificate(conn);
next_state = TLS_state_client_key_exchange;
break;
case TLS_state_client_key_exchange:
ret = tls_send_client_key_exchange(conn);
next_state = TLS_state_generate_keys;
break;
case TLS_state_generate_keys:
ret = tls_generate_keys(conn);
if (conn->client_certificate_verify)
next_state = TLS_state_certificate_verify;
else next_state = TLS_state_client_change_cipher_spec;
break;
case TLS_state_certificate_verify:
ret = tls_send_certificate_verify(conn);
next_state = TLS_state_client_change_cipher_spec;
case TLS_state_client_change_cipher_spec:
ret = tls_send_change_cipher_spec(conn);
next_state = TLS_state_client_finished;
break;
case TLS_state_client_finished:
ret = tls_send_client_finished(conn);
next_state = TLS_state_server_change_cipher_spec;
break;
case TLS_state_server_change_cipher_spec:
ret = tls_recv_change_cipher_spec(conn);
next_state = TLS_state_server_finished;
break;
case TLS_state_server_finished:
ret = tls_recv_server_finished(conn);
next_state = TLS_state_handshake_over;
break;
default:
error_print();
return -1;
}
if (ret < 0) {
if (ret == TLS_ERROR_RECV_AGAIN || ret == TLS_ERROR_SEND_AGAIN) {
return ret;
} else {
error_print();
return ret;
}
}
conn->state = next_state;
// ret == 0 means this step is bypassed
if (ret == 1) {
tls_clean_record(conn);
}
return 1;
}
int tls12_do_server_handshake(TLS_CONNECT *conn)
{
int ret;
int next_state;
switch (conn->state) {
case TLS_state_client_hello:
ret = tls_recv_client_hello(conn);
next_state = TLS_state_server_hello;
break;
case TLS_state_server_hello:
ret = tls_send_server_hello(conn);
next_state = TLS_state_server_certificate;
break;
case TLS_state_server_certificate:
ret = tls_send_server_certificate(conn);
next_state = TLS_state_server_key_exchange;
break;
case TLS_state_server_key_exchange:
ret = tls_send_server_key_exchange(conn);
if (conn->client_certificate_verify)
next_state = TLS_state_certificate_request;
else next_state = TLS_state_server_hello_done;
break;
case TLS_state_certificate_request:
ret = tls_send_certificate_request(conn);
next_state = TLS_state_server_hello_done;
break;
case TLS_state_server_hello_done:
ret = tls_send_server_hello_done(conn);
if (conn->client_certificate_verify)
next_state = TLS_state_client_certificate;
else next_state = TLS_state_client_key_exchange;
break;
case TLS_state_client_certificate:
ret = tls_recv_client_certificate(conn);
next_state = TLS_state_client_key_exchange;
break;
case TLS_state_client_key_exchange:
ret = tls_recv_client_key_exchange(conn);
if (conn->client_certificate_verify)
next_state = TLS_state_certificate_verify;
else next_state = TLS_state_generate_keys;
break;
case TLS_state_certificate_verify:
ret = tls_recv_certificate_verify(conn);
next_state = TLS_state_generate_keys;
break;
case TLS_state_generate_keys:
ret = tls_generate_keys(conn);
next_state = TLS_state_client_change_cipher_spec;
break;
case TLS_state_client_change_cipher_spec:
ret = tls_recv_change_cipher_spec(conn);
next_state = TLS_state_client_finished;
break;
case TLS_state_client_finished:
ret = tls_recv_client_finished(conn);
next_state = TLS_state_server_change_cipher_spec;
break;
case TLS_state_server_change_cipher_spec:
ret = tls_send_change_cipher_spec(conn);
next_state = TLS_state_server_finished;
break;
case TLS_state_server_finished:
ret = tls_send_server_finished(conn);
next_state = TLS_state_handshake_over;
break;
default:
error_print();
return -1;
}
if (ret != 1) {
if (ret == TLS_ERROR_RECV_AGAIN || ret == TLS_ERROR_SEND_AGAIN) {
return ret;
} else {
error_print();
return ret;
}
}
conn->state = next_state;
tls_clean_record(conn);
return 1;
}
// 这个函数显然是不对的,因为这个函数就是一个重入的函数,重入函数不应该自己设置状态啊
int tls12_client_handshake(TLS_CONNECT *conn)
{
int ret;
while (conn->state != TLS_state_handshake_over) {
ret = tls12_do_client_handshake(conn);
if (ret != 1) {
if (ret != TLS_ERROR_RECV_AGAIN && ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
}
// TODO: cleanup conn?
return 1;
}
int tls12_server_handshake(TLS_CONNECT *conn)
{
int ret;
while (conn->state != TLS_state_handshake_over) {
ret = tls12_do_server_handshake(conn);
if (ret != 1) {
if (ret != TLS_ERROR_RECV_AGAIN && ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
}
// TODO: cleanup conn?
return 1;
}
int tls12_do_connect(TLS_CONNECT *conn)
{
int ret;
fd_set rfds;
fd_set wfds;
conn->state = TLS_state_client_hello;
sm3_init(&conn->sm3_ctx);
digest_init(&conn->dgst_ctx, DIGEST_sm3());
while (1) {
ret = tls12_client_handshake(conn);
if (ret == 1) {
break;
} else if (ret == TLS_ERROR_SEND_AGAIN) {
FD_ZERO(&rfds);
FD_ZERO(&wfds);
FD_SET(conn->sock, &wfds);
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
} else if (ret == TLS_ERROR_RECV_AGAIN) {
FD_ZERO(&rfds);
FD_ZERO(&wfds);
FD_SET(conn->sock, &rfds);
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
} else {
error_print();
return -1;
}
}
return 1;
}
int tls12_do_accept(TLS_CONNECT *conn)
{
int ret;
fd_set rfds;
fd_set wfds;
conn->state = TLS_state_client_hello;
sm3_init(&conn->sm3_ctx);
digest_init(&conn->dgst_ctx, DIGEST_sm3());
while (1) {
ret = tls12_server_handshake(conn);
if (ret == 1) {
break;
} else if (ret == TLS_ERROR_SEND_AGAIN) {
FD_ZERO(&rfds);
FD_ZERO(&wfds);
FD_SET(conn->sock, &rfds);
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
} else if (ret == TLS_ERROR_RECV_AGAIN) {
FD_ZERO(&rfds);
FD_ZERO(&wfds);
FD_SET(conn->sock, &wfds);
select(conn->sock + 1, &rfds, &wfds, NULL, NULL);
} else {
error_print();
return -1;
}
}
return 1;
}