/*
 * Snapshot of GmSSL/src/tls12.c (3321 lines, 82 KiB, C), dated 2026-06-11 20:28:49 +08:00.
 * The repository viewer warned that this file contains Unicode characters
 * that might be confused with other characters.
 */
/*
* Copyright 2014-2026 The GmSSL Project. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
#include <time.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <gmssl/rand.h>
#include <gmssl/x509.h>
#include <gmssl/error.h>
#include <gmssl/sm2.h>
#include <gmssl/sm3.h>
#include <gmssl/sm4.h>
#include <gmssl/pem.h>
#include <gmssl/mem.h>
#include <gmssl/tls.h>
/*
 * TLS 1.2 cipher suites listed for this file. In the code below only the
 * first entry is used: tls12_record_print() passes it to tls_record_print()
 * as the cipher-suite hint for decoding records.
 */
static const int tls12_ciphers[] = {
TLS_cipher_ecdhe_sm4_cbc_sm3,
TLS_cipher_ecdhe_sm4_gcm_sm3,
TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256,
};
/*
 * Pretty-print a TLS 1.2 record to `fp`.
 *
 * tls_record_print() takes a cipher-suite hint in bits 8 and up of `format`
 * (presumably needed to decode KeyExchange messages). The same hint, the
 * first entry of tls12_ciphers[], is used for every record.
 * NOTE(review): to decode with the negotiated suite, pass
 * (conn->cipher_suite << 8) instead.
 */
int tls12_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent)
{
	const int suite_hint = tls12_ciphers[0] << 8;

	return tls_record_print(fp, record, recordlen, format | suite_hint, indent);
}
/*
 * CBC-encrypt `nblocks` 16-byte blocks from `in` into `out` under `key`.
 *
 * `iv` holds the chaining value on entry. On return it holds the last
 * ciphertext block, so consecutive calls continue one CBC stream.
 * `in` and `out` may be the same buffer: each input byte is read before the
 * output byte at the same position is written.
 *
 * This helper exists because the block-cipher API used here offers no
 * multi-block CBC primitive.
 */
void cbc_encrypt_blocks(const BLOCK_CIPHER_KEY *key, uint8_t iv[16],
	const uint8_t *in, size_t nblocks, uint8_t *out)
{
	const uint8_t *chain = iv;
	size_t i;

	if (nblocks == 0) {
		// Nothing to encrypt. Returning early also avoids memcpy(iv, iv, 16),
		// an overlapping copy, which is undefined behavior.
		return;
	}
	while (nblocks--) {
		for (i = 0; i < 16; i++) {
			out[i] = in[i] ^ chain[i];
		}
		block_cipher_encrypt(key, out, out);
		chain = out;	// next block chains on this ciphertext block
		in += 16;
		out += 16;
	}
	memcpy(iv, chain, 16);
}
/*
 * CBC-decrypt `nblocks` 16-byte blocks from `in` into `out` under `key`.
 *
 * `iv` holds the chaining value on entry. On return it holds the last
 * ciphertext block, so consecutive calls continue one CBC stream.
 * Each ciphertext block is copied before it is decrypted, so in-place
 * decryption (`in == out`) chains correctly. Previously the chaining pointer
 * referred into `in`, which in-place decryption had already overwritten
 * with plaintext.
 */
void cbc_decrypt_blocks(const BLOCK_CIPHER_KEY *key, uint8_t iv[16],
	const uint8_t *in, size_t nblocks, uint8_t *out)
{
	uint8_t chain[16];
	uint8_t next_chain[16];
	size_t i;

	memcpy(chain, iv, 16);
	while (nblocks--) {
		memcpy(next_chain, in, 16);	// save ciphertext before `out` may overwrite it
		block_cipher_decrypt(key, in, out);
		for (i = 0; i < 16; i++) {
			out[i] ^= chain[i];
		}
		memcpy(chain, next_chain, 16);
		in += 16;
		out += 16;
	}
	memcpy(iv, chain, 16);
}
/*
 * Encrypt one TLS 1.2 CBC fragment (MAC-then-encrypt, RFC 5246 6.2.3.2).
 *
 * Output layout: explicit IV (16) || CBC(in || HMAC (32) || padding).
 * The HMAC covers seq_num || header || in. `header` is the 5-byte plaintext
 * record header, and its length field must equal `inlen`.
 *
 * Only 32-byte MACs (e.g. HMAC-SM3, HMAC-SHA256) fit the fixed tail layout.
 * Any other MAC size is rejected instead of overflowing the tail buffer.
 *
 * inited_hmac_ctx  HMAC context already keyed with the MAC key (copied, not modified)
 * enc_key          block cipher key (16-byte blocks)
 * seq_num          record sequence number
 * header           plaintext record header
 * in, inlen        plaintext fragment, inlen <= 2^14
 * out, outlen      receives IV || ciphertext; the caller must provide
 *                  16 + (inlen - inlen % 16) + 48 bytes
 * Returns 1 on success, -1 on error.
 */
int tls12_cbc_encrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *enc_key,
	const uint8_t seq_num[8], const uint8_t header[5],
	const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
	// 64 bytes holds any HMAC output (assumes GmSSL DIGEST_MAX_SIZE == 64)
	enum { TLS12_CBC_MAC_SIZE = 32, TLS12_MAX_HMAC_SIZE = 64 };
	HMAC_CTX hmac_ctx;
	// plaintext tail (< 16 bytes) || MAC (32) || padding (1..16): exactly 48 bytes
	uint8_t last_blocks[TLS12_CBC_MAC_SIZE + 16] = {0};
	uint8_t mac[TLS12_MAX_HMAC_SIZE];
	uint8_t iv[16];
	uint8_t *padding;
	size_t maclen = 0;
	size_t rem, full_len, padding_len;
	size_t i;
	int ret = -1;

	if (!inited_hmac_ctx || !enc_key || !seq_num || !header || (!in && inlen) || !out || !outlen) {
		error_print();
		return -1;
	}
	if (inlen > (1 << 14)) {
		error_print();
		return -1;
	}
	if ((((size_t)header[3]) << 8) + header[4] != inlen) {
		error_print();
		return -1;
	}

	rem = inlen % 16;
	full_len = inlen - rem;
	if (rem) {
		memcpy(last_blocks, in + full_len, rem);
	}

	memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(HMAC_CTX));
	hmac_update(&hmac_ctx, seq_num, 8);
	hmac_update(&hmac_ctx, header, 5);
	hmac_update(&hmac_ctx, in, inlen);
	hmac_finish(&hmac_ctx, mac, &maclen);
	if (maclen != TLS12_CBC_MAC_SIZE) {
		error_print();
		goto end;
	}
	memcpy(last_blocks + rem, mac, TLS12_CBC_MAC_SIZE);

	// TLS padding: padding_len + 1 bytes, each equal to padding_len
	padding = last_blocks + rem + TLS12_CBC_MAC_SIZE;
	padding_len = 16 - rem - 1;
	for (i = 0; i <= padding_len; i++) {
		padding[i] = (uint8_t)padding_len;
	}

	if (rand_bytes(iv, 16) != 1) {
		error_print();
		goto end;
	}
	memcpy(out, iv, 16);
	out += 16;
	if (full_len) {
		cbc_encrypt_blocks(enc_key, iv, in, full_len / 16, out);
		out += full_len;
	}
	cbc_encrypt_blocks(enc_key, iv, last_blocks, sizeof(last_blocks) / 16, out);
	*outlen = 16 + full_len + sizeof(last_blocks);
	ret = 1;

end:
	// wipe keyed HMAC state, plaintext tail and MAC from the stack
	gmssl_secure_clear(&hmac_ctx, sizeof(hmac_ctx));
	gmssl_secure_clear(last_blocks, sizeof(last_blocks));
	gmssl_secure_clear(mac, sizeof(mac));
	return ret;
}
/*
 * Decrypt and authenticate one TLS 1.2 CBC fragment (RFC 5246 6.2.3.2).
 *
 * Input layout: explicit IV (16) || CBC(fragment || HMAC (32) || padding).
 * On success `out` holds the fragment and *outlen its length. The HMAC is
 * recomputed over seq_num || header || fragment. The header is
 * `enced_header` with its length field replaced by the plaintext length.
 *
 * Padding failures do not cause an early return, so the function does not
 * act as a padding oracle. The MAC is still computed (over a zero-length pad
 * if the padding length is impossible, as RFC 5246 suggests), and the record
 * is rejected afterwards. This is not fully constant-time (cf. Lucky
 * Thirteen).
 *
 * `out` must hold inlen - 16 bytes. It is zeroed on failure so
 * unauthenticated plaintext is never left behind.
 * Returns 1 on success, -1 on error.
 */
int tls12_cbc_decrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *dec_key,
	const uint8_t seq_num[8], const uint8_t enced_header[5],
	const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
	// 64 bytes holds any HMAC output (assumes GmSSL DIGEST_MAX_SIZE == 64)
	enum { TLS12_CBC_MAC_SIZE = 32, TLS12_MAX_HMAC_SIZE = 64 };
	HMAC_CTX hmac_ctx;
	uint8_t iv[16];
	uint8_t header[5];
	uint8_t expected_mac[TLS12_MAX_HMAC_SIZE];
	size_t maclen = 0;
	const uint8_t *mac;
	size_t padding_len;
	size_t datalen;
	size_t i;
	uint8_t bad_padding = 0;
	int mac_ok;
	int ret = -1;

	if (!inited_hmac_ctx || !dec_key || !seq_num || !enced_header || !in || !inlen || !out || !outlen) {
		error_print();
		return -1;
	}
	if (inlen % 16
		|| inlen < (16 + 0 + 32 + 16) // iv + data + mac + padding
		|| inlen > (16 + (1<<14) + 32 + 256)) {
		error_print_msg("invalid tls cbc ciphertext length %zu\n", inlen);
		return -1;
	}

	memcpy(iv, in, 16);
	in += 16;
	inlen -= 16;
	cbc_decrypt_blocks(dec_key, iv, in, inlen/16, out);

	// Validate the padding with integer arithmetic. The old pointer expression
	// could point before `out`, which is undefined behavior.
	padding_len = out[inlen - 1];
	if (TLS12_CBC_MAC_SIZE + padding_len + 1 > inlen) {
		bad_padding = 1;
		padding_len = 0;
	}
	for (i = 0; i < padding_len; i++) {
		bad_padding |= (uint8_t)(out[inlen - 1 - padding_len + i] ^ padding_len);
	}

	datalen = inlen - TLS12_CBC_MAC_SIZE - padding_len - 1;
	mac = out + datalen;

	header[0] = enced_header[0];
	header[1] = enced_header[1];
	header[2] = enced_header[2];
	header[3] = (uint8_t)(datalen >> 8);
	header[4] = (uint8_t)datalen;

	memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(HMAC_CTX));
	hmac_update(&hmac_ctx, seq_num, 8);
	hmac_update(&hmac_ctx, header, 5);
	hmac_update(&hmac_ctx, out, datalen);
	hmac_finish(&hmac_ctx, expected_mac, &maclen);
	if (maclen != TLS12_CBC_MAC_SIZE) {
		error_print();
		goto end;
	}
	mac_ok = (gmssl_secure_memcmp(mac, expected_mac, TLS12_CBC_MAC_SIZE) == 0);

	// decide only after the MAC has been computed in every case
	if (bad_padding) {
		error_puts("tls ciphertext cbc-padding check failure");
		goto end;
	}
	if (!mac_ok) {
		error_puts("tls ciphertext mac check failure\n");
		goto end;
	}
	*outlen = datalen;
	ret = 1;

end:
	if (ret != 1) {
		gmssl_secure_clear(out, inlen);
	}
	gmssl_secure_clear(&hmac_ctx, sizeof(hmac_ctx));
	gmssl_secure_clear(expected_mac, sizeof(expected_mac));
	return ret;
}
/*
 * Encrypt a complete plaintext TLS 1.2 record with CBC + HMAC.
 *
 * `in` is the 5-byte record header followed by the fragment. The encrypted
 * fragment is written at out + 5. The output header copies content type and
 * version from `in` and carries the new fragment length.
 * *outlen receives the total output length, header included.
 * Returns 1 on success, -1 on error.
 */
int tls12_record_encrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key,
	const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
	uint8_t *out, size_t *outlen)
{
	const uint8_t *plain_header = in;
	int k;

	if (tls12_cbc_encrypt(hmac_ctx, cbc_key, seq_num, plain_header,
		in + 5, inlen - 5, out + 5, outlen) != 1) {
		error_print();
		return -1;
	}
	// content type (1) and protocol version (2) are carried over unchanged
	for (k = 0; k < 3; k++) {
		out[k] = plain_header[k];
	}
	out[3] = (uint8_t)(*outlen >> 8);
	out[4] = (uint8_t)(*outlen);
	*outlen += 5;
	return 1;
}
/*
 * Decrypt and verify a complete TLS 1.2 CBC record.
 *
 * `in` is the 5-byte record header followed by the encrypted fragment. The
 * plaintext fragment is written at out + 5. The output header copies content
 * type and version from `in` and carries the plaintext length.
 * *outlen receives the total output length, header included.
 * Returns 1 on success, -1 on error.
 */
int tls12_record_decrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key,
	const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
	uint8_t *out, size_t *outlen)
{
	const uint8_t *enced_header = in;
	int k;

	if (tls12_cbc_decrypt(hmac_ctx, cbc_key, seq_num, enced_header,
		in + 5, inlen - 5, out + 5, outlen) != 1) {
		error_print();
		return -1;
	}
	// content type (1) and protocol version (2) are carried over unchanged
	for (k = 0; k < 3; k++) {
		out[k] = enced_header[k];
	}
	out[3] = (uint8_t)(*outlen >> 8);
	out[4] = (uint8_t)(*outlen);
	*outlen += 5;
	return 1;
}
/*
 * TLS 1.2 PRF (RFC 5246 section 5): P_hash(secret, label || seed || more).
 *
 *   A(0) = label || seed || more
 *   A(i) = HMAC(secret, A(i-1))
 *   out  = HMAC(secret, A(1) || A(0)) || HMAC(secret, A(2) || A(0)) || ...
 *
 * `more` is an optional second seed part and may be NULL when morelen is 0.
 * Works with any digest whose output fits in 64 bytes. The old code assumed
 * 32 and overflowed its buffers for longer digests.
 * Returns 1 on success, -1 on error.
 */
int tls12_prf(const DIGEST *digest, const uint8_t *secret, size_t secretlen, const char *label,
	const uint8_t *seed, size_t seedlen,
	const uint8_t *more, size_t morelen,
	size_t outlen, uint8_t *out)
{
	// assumes GmSSL DIGEST_MAX_SIZE == 64
	enum { TLS12_PRF_MAX_MAC_SIZE = 64 };
	HMAC_CTX inited_hmac_ctx;
	HMAC_CTX hmac_ctx;
	uint8_t A[TLS12_PRF_MAX_MAC_SIZE];
	uint8_t block[TLS12_PRF_MAX_MAC_SIZE];
	size_t Alen = 0;
	size_t blocklen = 0;
	size_t labellen;
	size_t len;
	int ret = -1;

	if (!digest || !secret || !secretlen || !label || !seed || !seedlen
		|| (!more && morelen) || !outlen || !out) {
		error_print();
		return -1;
	}
	labellen = strlen(label);

	if (hmac_init(&inited_hmac_ctx, digest, secret, secretlen) != 1) {
		error_print();
		goto end;
	}

	// A(1) = HMAC(secret, label || seed || more)
	memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX));
	hmac_update(&hmac_ctx, (const uint8_t *)label, labellen);
	hmac_update(&hmac_ctx, seed, seedlen);
	hmac_update(&hmac_ctx, more, morelen);
	hmac_finish(&hmac_ctx, A, &Alen);
	if (!Alen || Alen > sizeof(A)) {
		error_print();
		goto end;
	}

	for (;;) {
		// next output block = HMAC(secret, A(i) || label || seed || more)
		memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX));
		hmac_update(&hmac_ctx, A, Alen);
		hmac_update(&hmac_ctx, (const uint8_t *)label, labellen);
		hmac_update(&hmac_ctx, seed, seedlen);
		hmac_update(&hmac_ctx, more, morelen);
		hmac_finish(&hmac_ctx, block, &blocklen);
		if (!blocklen) {
			// would never make progress
			error_print();
			goto end;
		}
		len = outlen < blocklen ? outlen : blocklen;
		memcpy(out, block, len);
		out += len;
		outlen -= len;
		if (!outlen) {
			break;
		}
		// A(i+1) = HMAC(secret, A(i))
		memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX));
		hmac_update(&hmac_ctx, A, Alen);
		hmac_finish(&hmac_ctx, A, &Alen);
	}
	ret = 1;

end:
	gmssl_secure_clear(&inited_hmac_ctx, sizeof(inited_hmac_ctx));
	gmssl_secure_clear(&hmac_ctx, sizeof(hmac_ctx));
	gmssl_secure_clear(A, sizeof(A));
	gmssl_secure_clear(block, sizeof(block));
	return ret;
}
/*
 * Send whatever remains of conn->record, resuming at conn->record_offset.
 * Side effect: advances conn->record_offset by the number of bytes written.
 *
 * Returns 1 once the whole record has been sent, TLS_ERROR_SEND_AGAIN when
 * the socket would block (call again later to continue), or -1 on any other
 * send() error. EINTR is retried transparently.
 */
int tls_send_record(TLS_CONNECT *conn)
{
	size_t remaining = tls_record_length(conn->record) - conn->record_offset;

	while (remaining > 0) {
		tls_ret_t sent = tls_socket_send(conn->sock,
			conn->record + conn->record_offset, remaining, 0);

		if (sent >= 0) {
			conn->record_offset += sent;
			remaining -= sent;
			continue;
		}
		if (errno == EINTR) {
			continue;
		}
		if (errno == EAGAIN || errno == EWOULDBLOCK) {
			return TLS_ERROR_SEND_AGAIN;
		}
		fprintf(stderr, "%s %d: send() error: %s\n", __FILE__, __LINE__, strerror(errno));
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Receive one complete TLS record into conn->record.
 *
 * Reading is resumable: conn->record_offset counts the bytes received so
 * far, so after TLS_ERROR_RECV_AGAIN the caller just calls this again.
 * If conn->recordlen is already non-zero, the previous record has not been
 * released (see tls_clean_record), and 1 is returned without reading.
 *
 * Returns 1 when a whole record is available (conn->recordlen is set),
 * TLS_ERROR_RECV_AGAIN if the socket would block, TLS_ERROR_TCP_CLOSED if
 * the peer closed the connection, TLS_ERROR_SYSCALL on other socket errors,
 * and -1 for an invalid record header.
 */
int tls_recv_record(TLS_CONNECT *conn)
{
size_t left;
tls_ret_t n;
if (conn->recordlen) {
return 1;
}
// Phase 1: read the 5-byte record header (type, version, length).
if (conn->record_offset < 5) {
left = 5 - conn->record_offset;
while (left) {
n = tls_socket_recv(conn->sock, conn->record + conn->record_offset, left, 0);
if (n < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
return TLS_ERROR_RECV_AGAIN;
} else if (errno == EINTR) {
continue;
} else {
error_print();
// TODO: check the usage of OpenSSL SSL_ERR_SYSCALL
// if applications such as Nginx, HTTPD do not use this error, we just return -1
return TLS_ERROR_SYSCALL;
}
} else if (n == 0) {
error_print();
return TLS_ERROR_TCP_CLOSED;
}
conn->record_offset += n;
left -= n;
}
}
// Validate the header while no body byte has been read yet: known content
// type, known protocol version, and a length within TLS_MAX_RECORD_SIZE.
if (conn->record_offset == 5) {
if (!tls_record_type_name(tls_record_type(conn->record))) {
error_print();
return -1;
}
if (!tls_protocol_name(tls_record_protocol(conn->record))) {
error_print();
return -1;
}
if (tls_record_length(conn->record) > TLS_MAX_RECORD_SIZE) {
error_print();
return -1;
}
}
// NOTE(review): the read loop below treats tls_record_length() as the total
// record size including the 5-byte header; if so, a record with an empty
// fragment is rejected here. Confirm whether that is intended.
if (conn->record_offset >= tls_record_length(conn->record)) {
error_print();
return -1;
}
// Phase 2: read the remainder of the record.
left = tls_record_length(conn->record) - conn->record_offset;
while (left) {
n = tls_socket_recv(conn->sock, conn->record + conn->record_offset, left, 0);
if (n < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
return TLS_ERROR_RECV_AGAIN;
} else if (errno == EINTR) {
continue;
} else {
error_print();
return TLS_ERROR_SYSCALL;
}
} else if (n == 0) {
error_print();
return TLS_ERROR_TCP_CLOSED;
}
conn->record_offset += n;
left -= n;
}
conn->recordlen = conn->record_offset;
// TODO: should detect exceptional conditions such as an Alert record here
return 1;
}
// Map a TLS NamedCurve code point to the matching GmSSL curve OID constant.
// Returns OID_undef for curves this implementation does not handle.
int tls_named_curve_oid(int named_curve)
{
	int oid = OID_undef;

	if (named_curve == TLS_curve_secp256r1) {
		oid = OID_secp256r1;
	} else if (named_curve == TLS_curve_sm2p256v1) {
		oid = OID_sm2;
	}
	return oid;
}
// Inverse of tls_named_curve_oid(): map a GmSSL curve OID to its TLS
// NamedCurve code point, or return 0 when there is no supported mapping.
int tls_named_curve_from_oid(int oid)
{
	if (oid == OID_secp256r1) {
		return TLS_curve_secp256r1;
	}
	if (oid == OID_sm2) {
		return TLS_curve_sm2p256v1;
	}
	return 0;
}
// Supported groups (named curves): SM2 and NIST P-256.
// Original notes: this extension is mandatory; servers usually recommend
// returning this value.
// NOTE(review): non-static global; the name is shadowed by parameters and
// locals named supported_groups elsewhere in this file.
const int supported_groups[] = {
TLS_curve_sm2p256v1,
TLS_curve_secp256r1,
};
size_t supported_groups_cnt = sizeof(supported_groups)/sizeof(supported_groups[0]);
// Signature algorithms (SM2-with-SM3, ECDSA P-256 with SHA-256).
// Still not configurable.
const int signature_algors[] = {
TLS_sig_sm2sig_sm3,
TLS_sig_ecdsa_secp256r1_sha256,
};
size_t signature_algors_cnt = sizeof(signature_algors)/sizeof(signature_algors[0]);
/*
 * Build an ECDHE ServerKeyExchange handshake record:
 *   server_ecdh_params || signature_algorithm (uint16) || signature (uint16-length-prefixed)
 *
 * server_ecdh_params must be the 69-byte ServerECDHParams encoding:
 *   curve_type (1) || named_curve (2) || point length (1) || uncompressed point (65)
 * `sig` may be NULL only if siglen is 0.
 * Returns 1 on success, -1 on error. Previously a failure of
 * tls_record_set_handshake() was silently ignored.
 */
int tls_record_set_handshake_server_key_exchange(uint8_t *record, size_t *recordlen,
	const uint8_t *server_ecdh_params, size_t server_ecdh_params_len,
	uint16_t sig_alg, const uint8_t *sig, size_t siglen)
{
	const int type = TLS_handshake_server_key_exchange;
	const size_t ecdh_params_size = 1 + 2 + 1 + 65;
	uint8_t *p;
	size_t len = 0;

	if (!record || !recordlen || !server_ecdh_params || (!sig && siglen)) {
		error_print();
		return -1;
	}
	if (server_ecdh_params_len != ecdh_params_size) {
		error_print();
		return -1;
	}
	if (siglen > TLS_MAX_SIGNATURE_SIZE) {
		error_print();
		return -1;
	}
	// locate the handshake body only after `record` has been validated
	p = tls_handshake_data(tls_record_data(record));
	tls_array_to_bytes(server_ecdh_params, server_ecdh_params_len, &p, &len);
	tls_uint16_to_bytes(sig_alg, &p, &len);
	tls_uint16array_to_bytes(sig, siglen, &p, &len);
	if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Parse an ECDHE ServerKeyExchange handshake record.
 *
 * On success the out-parameters point into `record`:
 *   curve_type, named_curve, point_octets  fields of ServerECDHParams
 *   server_ecdh_params                     the whole 69-byte ServerECDHParams
 *   sig_alg, sig                           signature algorithm and signature
 * Returns 1 on success, -1 if the message is malformed or uses an unknown
 * curve type, curve or signature scheme.
 *
 * NOTE: this function is problematic for TLCP, whose ServerKeyExchange
 * format differs from TLS. Only the TLS layout is parsed here.
 */
int tls_record_get_handshake_server_key_exchange(const uint8_t *record,
	uint8_t *curve_type, uint16_t *named_curve,
	const uint8_t **point_octets, size_t *point_octets_len,
	const uint8_t **server_ecdh_params, size_t *server_ecdh_params_len,
	uint16_t *sig_alg, const uint8_t **sig, size_t *siglen)
{
	int handshake_type;
	const uint8_t *cur;
	size_t left;

	if (tls_record_get_handshake(record, &handshake_type, &cur, &left) != 1) {
		error_print();
		return -1;
	}
	if (handshake_type != TLS_handshake_server_key_exchange) {
		error_print();
		return -1;
	}

	// ServerECDHParams: curve_type, named_curve, opaque point<1..2^8-1>
	*server_ecdh_params = cur;
	if (tls_uint8_from_bytes(curve_type, &cur, &left) != 1
		|| tls_uint16_from_bytes(named_curve, &cur, &left) != 1
		|| tls_uint8array_from_bytes(point_octets, point_octets_len, &cur, &left) != 1) {
		error_print();
		return -1;
	}
	*server_ecdh_params_len = (size_t)(cur - *server_ecdh_params);
	if (*server_ecdh_params_len != 69) {
		error_print();
		return -1;
	}

	// signature_algorithm and signature must consume the rest of the body
	if (tls_uint16_from_bytes(sig_alg, &cur, &left) != 1
		|| tls_uint16array_from_bytes(sig, siglen, &cur, &left) != 1
		|| tls_length_is_zero(left) != 1) {
		error_print();
		return -1;
	}

	if (*curve_type != TLS_curve_type_named_curve
		|| !tls_named_curve_name(*named_curve)
		|| !tls_signature_scheme_name(*sig_alg)) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Build an ECDHE ClientKeyExchange handshake record carrying the client's
 * ephemeral public key as a uint8-length-prefixed, 65-byte uncompressed point.
 * Returns 1 on success, -1 on error. Previously a failure of
 * tls_record_set_handshake() was silently ignored.
 */
int tls_record_set_handshake_client_key_exchange(uint8_t *record, size_t *recordlen,
	const uint8_t *point_octets, size_t point_octets_len)
{
	const int type = TLS_handshake_client_key_exchange;
	uint8_t *p;
	size_t len = 0;

	if (!record || !recordlen || !point_octets) {
		error_print();
		return -1;
	}
	if (point_octets_len != 65) {
		error_print();
		return -1;
	}
	// locate the handshake body only after `record` has been validated
	p = tls_handshake_data(tls_record_data(record));
	tls_uint8array_to_bytes(point_octets, point_octets_len, &p, &len);
	if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Parse an ECDHE ClientKeyExchange handshake record.
 * On success *point_octets points into `record` at the client's 65-byte
 * public point, and the body contains nothing else.
 * Returns 1 on success, -1 on any type, length or format mismatch.
 */
int tls_record_get_handshake_client_key_exchange(const uint8_t *record,
	const uint8_t **point_octets, size_t *point_octets_len)
{
	int handshake_type;
	const uint8_t *body;
	size_t bodylen;

	if (tls_record_get_handshake(record, &handshake_type, &body, &bodylen) != 1
		|| handshake_type != TLS_handshake_client_key_exchange
		|| tls_uint8array_from_bytes(point_octets, point_octets_len, &body, &bodylen) != 1
		|| *point_octets_len != 65
		|| bodylen != 0) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Select a certificate chain from `cert_chains`, a sequence of
 * uint24-length-prefixed chains.
 *
 * NOTE(review): the filtering inputs (supported_groups, signature_algorithms,
 * ca_names, host_name) are accepted for interface compatibility but are not
 * consulted yet, so the first chain is always returned.
 * tls12_select_parameters() contains the real matching logic.
 *
 * Optional outputs: the chain, its length, its 1-based index, and a
 * preferred signature algorithm. No signature algorithm is chosen here, so
 * *prefered_sig_alg is set to 0. The old code copied an uninitialized local
 * into it, which is undefined behavior.
 * Returns 1 on success, -1 on invalid input or a malformed chain list.
 */
int tls12_cert_chains_select(const uint8_t *cert_chains, size_t cert_chains_len,
	const int *supported_groups, size_t supported_groups_cnt, // optional
	const int *signature_algorithms, size_t signature_algorithms_cnt, // optional
	const uint8_t *ca_names, size_t ca_names_len, // certificate_authorities optional
	const uint8_t *host_name, size_t host_name_len, // optional, only in ClientHello
	const uint8_t **certs, size_t *certs_len, size_t *certs_idx, int *prefered_sig_alg) // optional
{
	const uint8_t *cert_chain;
	size_t cert_chain_len;

	(void)supported_groups;
	(void)supported_groups_cnt;
	(void)signature_algorithms;
	(void)signature_algorithms_cnt;
	(void)ca_names;
	(void)ca_names_len;
	(void)host_name;
	(void)host_name_len;

	if (!cert_chains || !cert_chains_len) {
		error_print();
		return -1;
	}
	if (tls_uint24array_from_bytes(&cert_chain, &cert_chain_len,
		&cert_chains, &cert_chains_len) != 1) {
		error_print();
		return -1;
	}
	if (certs) *certs = cert_chain;
	if (certs_len) *certs_len = cert_chain_len;
	if (certs_idx) *certs_idx = 1;
	if (prefered_sig_alg) *prefered_sig_alg = 0;
	return 1;
}
// Mark conn->record as consumed: the next tls_recv_record() reads a new
// record, and the next send function builds a new one.
void tls_clean_record(TLS_CONNECT *conn)
{
	conn->recordlen = 0;
	conn->record_offset = 0;
}
/*
 * Start the handshake transcript hash (conn->dgst_ctx) with SM3.
 * Returns 1 on success, -1 on error. The digest_init() result used to be
 * ignored.
 * NOTE(review): tls_recv_client_hello() and tls_recv_server_hello() later
 * re-initialize dgst_ctx with their own digest.
 */
int tls_handshake_init(TLS_CONNECT *conn)
{
	if (!conn) {
		error_print();
		return -1;
	}
	if (digest_init(&conn->dgst_ctx, DIGEST_sm3()) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
// EC point formats advertised in the ec_point_formats extension by
// tls_send_client_hello() and tls_send_server_hello(): uncompressed only.
const int ec_point_formats[] = { TLS_point_uncompressed };
size_t ec_point_formats_cnt = sizeof(ec_point_formats)/sizeof(ec_point_formats[0]);
/*
 * Client side: build (on the first call) and send the ClientHello.
 *
 * The record is built only while conn->recordlen is 0. A call that returned
 * TLS_ERROR_SEND_AGAIN can therefore be repeated to resume sending the same
 * record. Extensions: ec_point_formats always; supported_groups and
 * signature_algorithms when the context configures them. No session id.
 * TODO: may need to support SNI (server_name extension).
 *
 * Returns 1 once the record is fully sent, TLS_ERROR_SEND_AGAIN if the
 * socket would block, -1 on error.
 */
int tls_send_client_hello(TLS_CONNECT *conn)
{
int ret;
if (!conn->recordlen) {
uint8_t exts[TLS_MAX_EXTENSIONS_SIZE];
uint8_t *pexts = exts;
size_t extslen = 0;
tls_trace("send ClientHello\n");
// record-layer version is TLS 1.0; the ClientHello body carries conn->protocol
tls_record_set_protocol(conn->record, TLS_protocol_tls1);
if (tls_random_generate(conn->client_random) != 1) {
error_print();
return -1;
}
// ec_point_formats
if (tls_ec_point_formats_ext_to_bytes(
ec_point_formats, ec_point_formats_cnt, &pexts, &extslen) != 1) {
error_print();
return -1;
}
// supported_groups
if (conn->ctx->supported_groups_cnt) {
if (tls_supported_groups_ext_to_bytes(conn->ctx->supported_groups,
conn->ctx->supported_groups_cnt, &pexts, &extslen) != 1) {
error_print();
return -1;
}
}
// signature_algorithms
if (conn->ctx->signature_algorithms_cnt) {
if (tls_signature_algorithms_ext_to_bytes(conn->ctx->signature_algorithms,
conn->ctx->signature_algorithms_cnt, &pexts, &extslen) != 1) {
error_print();
return -1;
}
}
if (tls_record_set_handshake_client_hello(conn->record, &conn->recordlen,
conn->protocol, conn->client_random, NULL, 0,
conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt,
exts, extslen) != 1) {
error_print();
return -1;
}
tls12_record_print(stderr, conn->record, conn->recordlen, 0, 0);
// backup ClientHello (presumably so it can be hashed into the transcript
// once the server's choice of digest is known; TODO confirm)
memcpy(conn->plain_record, conn->record, conn->recordlen);
conn->plain_recordlen = conn->recordlen;
}
/*
if (conn->client_certificate_verify) {
sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
}
*/
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
tls_clean_record(conn);
return 1;
}
/*
const int server_ciphers[] = { TLS_cipher_ecdhe_sm4_cbc_sm3 };
const size_t server_ciphers_cnt = 1;
*/
// SM2 curve constant. NOTE(review): not referenced in the code above or
// below in this chunk, and its generic non-static name may clash with other
// symbols at link time.
const int curve = TLS_curve_sm2p256v1;
/*
 * Look up the block cipher and hash algorithm of a TLS 1.2 cipher suite.
 * SM suites use SM4/SM3; the ECDSA suite uses AES-128/SHA-256.
 * Returns 1 on success, -1 for an unknown suite (outputs untouched).
 */
static int tls12_cipher_suite_get(int cipher_suite, const BLOCK_CIPHER **cipher, const DIGEST **digest)
{
	if (cipher_suite == TLS_cipher_ecdhe_sm4_cbc_sm3
		|| cipher_suite == TLS_cipher_ecdhe_sm4_gcm_sm3) {
		*cipher = BLOCK_CIPHER_sm4();
		*digest = DIGEST_sm3();
		return 1;
	}
	if (cipher_suite == TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256) {
		*cipher = BLOCK_CIPHER_aes128();
		*digest = DIGEST_sha256();
		return 1;
	}
	error_print();
	return -1;
}
// Return 1 if a certificate key on curve `cert_group` suits `cipher_suite`
// (SM suites need an SM2 key, the AES/ECDSA suite a P-256 key), else 0.
static int tls12_cipher_suite_match_cert_group(int cipher_suite, int cert_group)
{
	int required_group;

	switch (cipher_suite) {
	case TLS_cipher_ecdhe_sm4_cbc_sm3:
	case TLS_cipher_ecdhe_sm4_gcm_sm3:
		required_group = TLS_curve_sm2p256v1;
		break;
	case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
		required_group = TLS_curve_secp256r1;
		break;
	default:
		return 0;
	}
	return cert_group == required_group;
}
// Return 1 if signature scheme `sig_alg` is bound to the same curve OID as
// the certificate's group `cert_group`, else 0.
static int tls12_signature_scheme_match_cert_group(int sig_alg, int cert_group)
{
	const int scheme_curve_oid = tls_signature_scheme_group_oid(sig_alg);
	const int cert_curve_oid = tls_named_curve_oid(cert_group);

	return scheme_curve_oid == cert_curve_oid;
}
// Return 1 if signature scheme `sig_alg` may be used with `cipher_suite`:
// SM2-SM3 signatures with the SM4 suites, ECDSA-P256-SHA256 with the AES
// suite. Returns 0 otherwise.
static int tls12_signature_scheme_match_cipher_suite(int sig_alg, int cipher_suite)
{
	if (sig_alg == TLS_sig_sm2sig_sm3) {
		return cipher_suite == TLS_cipher_ecdhe_sm4_cbc_sm3
			|| cipher_suite == TLS_cipher_ecdhe_sm4_gcm_sm3;
	}
	if (sig_alg == TLS_sig_ecdsa_secp256r1_sha256) {
		return cipher_suite == TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256;
	}
	return 0;
}
// Return 1 if ECDHE group `group` may be used for key exchange under
// `cipher_suite` (SM2 for the SM4 suites, P-256 for the AES suite), else 0.
static int tls12_key_exchange_group_match_cipher_suite(int group, int cipher_suite)
{
	if (cipher_suite == TLS_cipher_ecdhe_sm4_cbc_sm3
		|| cipher_suite == TLS_cipher_ecdhe_sm4_gcm_sm3) {
		return group == TLS_curve_sm2p256v1;
	}
	if (cipher_suite == TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256) {
		return group == TLS_curve_secp256r1;
	}
	return 0;
}
/*
 * Intersect the client's encoded cipher_suites list (uint16 entries) with
 * the server's list, keeping the server's preference order.
 * At most max_cnt suites are written to common_ciphers.
 * Returns 1 if at least one suite is shared, 0 if none, -1 on bad arguments
 * or a malformed client list.
 */
static int tls12_select_common_cipher_suites(const uint8_t *client_ciphers, size_t client_ciphers_len,
	const int *server_ciphers, size_t server_ciphers_cnt,
	int *common_ciphers, size_t *common_ciphers_cnt, size_t max_cnt)
{
	size_t s;

	if (!client_ciphers || !client_ciphers_len
		|| !server_ciphers || !server_ciphers_cnt
		|| !common_ciphers || !common_ciphers_cnt || !max_cnt) {
		error_print();
		return -1;
	}

	*common_ciphers_cnt = 0;
	for (s = 0; s < server_ciphers_cnt; s++) {
		const uint8_t *cur = client_ciphers;
		size_t left = client_ciphers_len;

		if (*common_ciphers_cnt >= max_cnt) {
			break;
		}
		// scan the client's offer for this server suite
		while (left > 0) {
			uint16_t offered;

			if (tls_uint16_from_bytes(&offered, &cur, &left) != 1) {
				error_print();
				return -1;
			}
			if (offered == server_ciphers[s]) {
				common_ciphers[(*common_ciphers_cnt)++] = server_ciphers[s];
				break;
			}
		}
	}
	return *common_ciphers_cnt > 0 ? 1 : 0;
}
/*
 * Find the TLS named curve of the end-entity (index 0) certificate's key.
 * Fails (-1) if the certificate cannot be read, the key is not an EC key,
 * or its curve has no mapping in tls_named_curve_from_oid().
 * Note: *group is written even on the unknown-curve failure (with 0).
 */
static int tls12_cert_chain_get_end_entity_group(const uint8_t *cert_chain, size_t cert_chain_len, int *group)
{
	const uint8_t *leaf;
	size_t leaflen;
	X509_KEY leaf_key;
	int named_curve;

	if (!cert_chain || !cert_chain_len || !group) {
		error_print();
		return -1;
	}
	if (x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 0, &leaf, &leaflen) != 1) {
		error_print();
		return -1;
	}
	if (x509_cert_get_subject_public_key(leaf, leaflen, &leaf_key) != 1) {
		error_print();
		return -1;
	}
	if (leaf_key.algor != OID_ec_public_key) {
		error_print();
		return -1;
	}
	named_curve = tls_named_curve_from_oid(leaf_key.algor_param);
	*group = named_curve;
	if (named_curve == 0) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Pick the first group in `groups` usable for ECDHE with `cipher_suite`.
 * Returns 1 and sets *selected_group on success, 0 if none matches,
 * -1 on bad arguments.
 */
static int tls12_select_key_exchange_group(const int *groups, size_t groups_cnt,
	int cipher_suite, int *selected_group)
{
	const int *g;
	const int *end;

	if (!groups || !groups_cnt || !selected_group) {
		error_print();
		return -1;
	}
	end = groups + groups_cnt;
	for (g = groups; g != end; g++) {
		if (tls12_key_exchange_group_match_cipher_suite(*g, cipher_suite)) {
			*selected_group = *g;
			return 1;
		}
	}
	return 0;
}
/*
 * Choose the server's TLS 1.2 handshake parameters.
 * (NOTE: this function deserves a more descriptive name.)
 *
 * Walks conn->ctx->cert_chains in order and takes the first chain that
 * meets all of these conditions:
 *   - its end-entity key curve is in common_supported_groups;
 *   - its certificate matches host_name (when given);
 *   - the chain satisfies signature_algorithms_cert (when given).
 * For that chain, the first cipher suite in common_cipher_suites that has a
 * compatible certificate curve, key-exchange group and signature scheme is
 * selected.
 *
 * On success sets conn->cipher_suite, cert_chain, cert_chain_len,
 * cert_chain_idx (1-based), sig_alg and key_exchange_group, and returns 1.
 * Returns 0 if no combination works, -1 on invalid input or malformed chains.
 */
static int tls12_select_parameters(TLS_CONNECT *conn,
	const int *common_cipher_suites, size_t common_cipher_suites_cnt,
	const int *common_supported_groups, size_t common_supported_groups_cnt,
	const int *common_signature_algorithms, size_t common_signature_algorithms_cnt,
	const int *signature_algorithms_cert, size_t signature_algorithms_cert_cnt,
	const uint8_t *host_name, size_t host_name_len)
{
	const uint8_t *cert_chains;
	size_t cert_chains_len;
	size_t cert_chain_idx;

	if (!conn || !conn->ctx || !common_cipher_suites || !common_cipher_suites_cnt
		|| !common_supported_groups || !common_supported_groups_cnt
		|| !common_signature_algorithms || !common_signature_algorithms_cnt) {
		error_print();
		return -1;
	}
	// Read the context only after conn has been validated. The old code
	// dereferenced conn->ctx before the !conn check.
	cert_chains = conn->ctx->cert_chains;
	cert_chains_len = conn->ctx->cert_chains_len;
	if (!cert_chains || !cert_chains_len) {
		error_print();
		return -1;
	}

	for (cert_chain_idx = 1; cert_chains_len; cert_chain_idx++) {
		const uint8_t *cert_chain;
		size_t cert_chain_len;
		const uint8_t *cert;
		size_t certlen;
		int cert_group;
		size_t i;
		int ret;

		if (tls_uint24array_from_bytes(&cert_chain, &cert_chain_len,
			&cert_chains, &cert_chains_len) != 1) {
			error_print();
			return -1;
		}
		if (tls12_cert_chain_get_end_entity_group(cert_chain, cert_chain_len, &cert_group) != 1) {
			error_print();
			return -1;
		}
		if (!tls_type_is_in_list(cert_group, common_supported_groups, common_supported_groups_cnt)) {
			continue;
		}
		if (x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 0, &cert, &certlen) != 1) {
			error_print();
			return -1;
		}
		// SNI: skip chains whose end-entity certificate does not match the host name
		if (host_name && host_name_len) {
			if ((ret = tls_cert_match_server_name(cert, certlen, host_name, host_name_len)) < 0) {
				error_print();
				return -1;
			} else if (ret == 0) {
				continue;
			}
		}
		if (signature_algorithms_cert && signature_algorithms_cert_cnt) {
			if ((ret = tls_cert_chain_match_signature_algorithms_cert(cert_chain, cert_chain_len,
				signature_algorithms_cert, signature_algorithms_cert_cnt)) < 0) {
				error_print();
				return -1;
			} else if (ret == 0) {
				continue;
			}
		}
		// first (server-preferred) cipher suite with a compatible group and signature scheme
		for (i = 0; i < common_cipher_suites_cnt; i++) {
			size_t j;
			int cipher_suite = common_cipher_suites[i];
			int key_exchange_group;

			if (!tls12_cipher_suite_match_cert_group(cipher_suite, cert_group)) {
				continue;
			}
			if ((ret = tls12_select_key_exchange_group(common_supported_groups,
				common_supported_groups_cnt, cipher_suite, &key_exchange_group)) < 0) {
				error_print();
				return -1;
			} else if (ret == 0) {
				continue;
			}
			for (j = 0; j < common_signature_algorithms_cnt; j++) {
				int sig_alg = common_signature_algorithms[j];

				if (!tls12_signature_scheme_match_cert_group(sig_alg, cert_group)) {
					continue;
				}
				if (!tls12_signature_scheme_match_cipher_suite(sig_alg, cipher_suite)) {
					continue;
				}
				conn->cipher_suite = cipher_suite;
				conn->cert_chain = cert_chain;
				conn->cert_chain_len = cert_chain_len;
				conn->cert_chain_idx = cert_chain_idx;
				conn->sig_alg = sig_alg;
				conn->key_exchange_group = key_exchange_group;
				return 1;
			}
		}
	}
	warning_print();
	return 0;
}
/*
 * Server side: receive and process the ClientHello.
 *
 * Reads one record (returns TLS_ERROR_RECV_AGAIN if the socket would block,
 * so the caller can retry), parses the ClientHello, and negotiates:
 *   - cipher suite: intersection with conn->ctx->cipher_suites (server order);
 *   - supported groups / signature algorithms: intersection with the context
 *     lists, or the context lists as-is when the client omits the extension;
 *   - certificate chain, signature scheme and key-exchange group, via
 *     tls12_select_parameters(). This honours signature_algorithms_cert
 *     (falling back to signature_algorithms) and SNI.
 * It then starts the transcript hash with the suite's digest and feeds the
 * ClientHello into it.
 *
 * Fixes over the previous version:
 *   - the context lists are bounds-checked before being copied into the
 *     fixed-size stack arrays;
 *   - all alerts go through tls_send_alert (not tls13_send_alert), as this is
 *     a TLS 1.2 handshake;
 *   - local (server-side) failures send internal_error.
 * On failure a fatal alert is sent and -1 is returned. Receive errors from
 * tls_recv_record are returned unchanged.
 */
int tls_recv_client_hello(TLS_CONNECT *conn)
{
	int ret;
	int protocol;
	const uint8_t *client_random;
	const uint8_t *session_id;
	size_t session_id_len;
	const uint8_t *cipher_suites;
	size_t cipher_suites_len;
	const uint8_t *exts;
	size_t extslen;
	const uint8_t *supported_groups = NULL;
	size_t supported_groups_len = 0;
	const uint8_t *signature_algorithms = NULL;
	size_t signature_algorithms_len = 0;
	const uint8_t *signature_algorithms_cert = NULL;
	size_t signature_algorithms_cert_len = 0;
	const uint8_t *server_name = NULL;
	size_t server_name_len = 0;
	int common_cipher_suites[TLS_MAX_CIPHER_SUITES_COUNT];
	size_t common_cipher_suites_cnt = 0;
	int common_supported_groups[32];
	size_t common_supported_groups_cnt = 0;
	int common_signature_algorithms[32];
	size_t common_signature_algorithms_cnt = 0;
	int common_signature_algorithms_cert[32];
	size_t common_signature_algorithms_cert_cnt = 0;
	const int *cert_signature_algorithms = NULL;
	size_t cert_signature_algorithms_cnt = 0;
	const uint8_t *host_name = NULL;
	size_t host_name_len = 0;
	const size_t max_groups = sizeof(common_supported_groups)/sizeof(common_supported_groups[0]);
	const size_t max_sig_algs = sizeof(common_signature_algorithms)/sizeof(common_signature_algorithms[0]);
	const size_t max_sig_algs_cert = sizeof(common_signature_algorithms_cert)/sizeof(common_signature_algorithms_cert[0]);

	tls_trace("recv ClientHello\n");
	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

	// the record-layer version of the ClientHello must be TLS 1.0
	if (tls_record_protocol(conn->record) != TLS_protocol_tls1) {
		error_print();
		tls_send_alert(conn, TLS_alert_protocol_version);
		return -1;
	}
	if ((ret = tls_record_get_handshake_client_hello(conn->record,
		&protocol, &client_random, &session_id, &session_id_len,
		&cipher_suites, &cipher_suites_len, &exts, &extslen)) < 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_decode_error);
		return -1;
	} else if (ret == 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	// client_version must equal the version configured for this connection
	if (protocol != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_protocol_version);
		return -1;
	}
	memcpy(conn->client_random, client_random, 32);

	// Collect the extensions acted upon below. Duplicates are rejected;
	// other extensions are ignored with a warning.
	while (extslen) {
		int ext_type;
		const uint8_t *ext_data;
		size_t ext_datalen;

		if (tls_ext_from_bytes(&ext_type, &ext_data, &ext_datalen, &exts, &extslen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_decode_error);
			return -1;
		}
		switch (ext_type) {
		case TLS_extension_supported_groups:
		case TLS_extension_signature_algorithms:
		case TLS_extension_signature_algorithms_cert:
		case TLS_extension_server_name:
			// these extensions must not be empty
			if (!ext_data) {
				error_print();
				tls_send_alert(conn, TLS_alert_decode_error);
				return -1;
			}
			break;
		}
		switch (ext_type) {
		case TLS_extension_supported_groups:
			if (supported_groups) {
				error_print();
				tls_send_alert(conn, TLS_alert_illegal_parameter);
				return -1;
			}
			supported_groups = ext_data;
			supported_groups_len = ext_datalen;
			break;
		case TLS_extension_signature_algorithms:
			if (signature_algorithms) {
				error_print();
				tls_send_alert(conn, TLS_alert_illegal_parameter);
				return -1;
			}
			signature_algorithms = ext_data;
			signature_algorithms_len = ext_datalen;
			break;
		case TLS_extension_signature_algorithms_cert:
			if (signature_algorithms_cert) {
				error_print();
				tls_send_alert(conn, TLS_alert_illegal_parameter);
				return -1;
			}
			signature_algorithms_cert = ext_data;
			signature_algorithms_cert_len = ext_datalen;
			break;
		case TLS_extension_server_name:
			if (server_name) {
				error_print();
				tls_send_alert(conn, TLS_alert_illegal_parameter);
				return -1;
			}
			server_name = ext_data;
			server_name_len = ext_datalen;
			break;
		default:
			warning_print();
		}
	}

	// cipher suites: server preference order
	if ((ret = tls12_select_common_cipher_suites(cipher_suites, cipher_suites_len,
		conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt,
		common_cipher_suites, &common_cipher_suites_cnt,
		sizeof(common_cipher_suites)/sizeof(common_cipher_suites[0]))) < 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_decode_error);
		return -1;
	} else if (ret == 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_handshake_failure);
		return -1;
	}

	// supported_groups: intersect, or fall back to the server's own list
	if (supported_groups) {
		if ((ret = tls_process_supported_groups(supported_groups, supported_groups_len,
			conn->ctx->supported_groups, conn->ctx->supported_groups_cnt,
			common_supported_groups, &common_supported_groups_cnt,
			max_groups)) < 0) {
			error_print();
			tls_send_alert(conn, TLS_alert_decode_error);
			return -1;
		} else if (ret == 0) {
			error_print();
			tls_send_alert(conn, TLS_alert_handshake_failure);
			return -1;
		}
	} else {
		if (!conn->ctx->supported_groups_cnt) {
			error_print();
			tls_send_alert(conn, TLS_alert_handshake_failure);
			return -1;
		}
		if (conn->ctx->supported_groups_cnt > max_groups) {
			// would overflow common_supported_groups
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		memcpy(common_supported_groups, conn->ctx->supported_groups,
			conn->ctx->supported_groups_cnt * sizeof(conn->ctx->supported_groups[0]));
		common_supported_groups_cnt = conn->ctx->supported_groups_cnt;
	}

	// signature_algorithms: intersect, or fall back to the server's own list
	if (signature_algorithms) {
		if ((ret = tls_process_signature_algorithms(signature_algorithms, signature_algorithms_len,
			conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt,
			common_signature_algorithms, &common_signature_algorithms_cnt,
			max_sig_algs)) < 0) {
			error_print();
			tls_send_alert(conn, TLS_alert_decode_error);
			return -1;
		} else if (ret == 0) {
			error_print();
			tls_send_alert(conn, TLS_alert_handshake_failure);
			return -1;
		}
	} else {
		if (!conn->ctx->signature_algorithms_cnt) {
			error_print();
			tls_send_alert(conn, TLS_alert_handshake_failure);
			return -1;
		}
		if (conn->ctx->signature_algorithms_cnt > max_sig_algs) {
			// would overflow common_signature_algorithms
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		memcpy(common_signature_algorithms, conn->ctx->signature_algorithms,
			conn->ctx->signature_algorithms_cnt * sizeof(conn->ctx->signature_algorithms[0]));
		common_signature_algorithms_cnt = conn->ctx->signature_algorithms_cnt;
	}

	// Certificate signature constraints: use signature_algorithms_cert if
	// present, else the client's signature_algorithms, else none.
	if (signature_algorithms_cert) {
		if ((ret = tls_process_signature_algorithms(signature_algorithms_cert, signature_algorithms_cert_len,
			conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt,
			common_signature_algorithms_cert, &common_signature_algorithms_cert_cnt,
			max_sig_algs_cert)) < 0) {
			error_print();
			tls_send_alert(conn, TLS_alert_decode_error);
			return -1;
		} else if (ret == 0) {
			error_print();
			tls_send_alert(conn, TLS_alert_handshake_failure);
			return -1;
		}
		cert_signature_algorithms = common_signature_algorithms_cert;
		cert_signature_algorithms_cnt = common_signature_algorithms_cert_cnt;
	} else if (signature_algorithms) {
		cert_signature_algorithms = common_signature_algorithms;
		cert_signature_algorithms_cnt = common_signature_algorithms_cnt;
	}

	// SNI
	if (server_name) {
		if (tls_server_name_from_bytes(&host_name, &host_name_len, server_name, server_name_len) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_decode_error);
			return -1;
		}
		conn->server_name = 1;
	}

	if ((ret = tls12_select_parameters(conn,
		common_cipher_suites, common_cipher_suites_cnt,
		common_supported_groups, common_supported_groups_cnt,
		common_signature_algorithms, common_signature_algorithms_cnt,
		cert_signature_algorithms, cert_signature_algorithms_cnt,
		host_name, host_name_len)) < 0) {
		// failures here come from server-side input (arguments, certificate chains)
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		return -1;
	} else if (ret == 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_handshake_failure);
		return -1;
	}
	if (tls12_cipher_suite_get(conn->cipher_suite, &conn->cipher, &conn->digest) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		return -1;
	}

	// start the handshake transcript with the suite's digest
	if (digest_init(&conn->dgst_ctx, conn->digest) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		return -1;
	}
	if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "ClientHello", &conn->dgst_ctx);

	fprintf(stderr, "end of recv_client_hello\n");
	tls_clean_record(conn);
	return 1;
}
/*
 * Server side: build (on the first call) and send the ServerHello.
 *
 * The record is built, and hashed into conn->dgst_ctx, only while
 * conn->recordlen is 0. A call that returned TLS_ERROR_SEND_AGAIN can
 * therefore be repeated without hashing the message twice. The ServerHello
 * carries a fresh server random, no session id, the cipher suite chosen in
 * tls_recv_client_hello(), and the ec_point_formats extension.
 *
 * Returns 1 once the record is fully sent, TLS_ERROR_SEND_AGAIN if the
 * socket would block, -1 on error.
 */
int tls_send_server_hello(TLS_CONNECT *conn)
{
int ret;
tls_trace("send ServerHello\n");
if (conn->recordlen == 0) {
uint8_t exts[512];
uint8_t *pexts = exts;
size_t extslen = 0;
tls_record_set_protocol(conn->record, conn->protocol);
if (tls_random_generate(conn->server_random) != 1) {
error_print();
return -1;
}
// Extensions in ServerHello: only ec_point_formats is sent at present;
// supported_groups and signature_algorithms are disabled below.
// ec_point_formats
// supported_groups
// signature_algorithms
if (tls_ec_point_formats_ext_to_bytes(ec_point_formats, ec_point_formats_cnt, &pexts, &extslen) != 1) {
error_print();
return -1;
}
/*
if (tls_supported_groups_ext_to_bytes(conn->ctx->supported_groups, conn->ctx->supported_groups_cnt,
&pexts, &extslen) != 1) {
error_print();
return -1;
}
if (tls_signature_algorithms_ext_to_bytes(conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt,
&pexts, &extslen) != 1) {
error_print();
return -1;
}
*/
if (tls_record_set_handshake_server_hello(conn->record, &conn->recordlen,
conn->protocol, conn->server_random, NULL, 0,
conn->cipher_suite,
exts, extslen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
return -1;
}
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
// add the ServerHello handshake message (record header excluded) to the transcript
if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
error_print();
return -1;
}
tls_handshake_digest_print(stderr, 0, 0, "ServerHello", &conn->dgst_ctx);
}
if ((ret = tls_send_record(conn)) != 1) {
if (ret != TLS_ERROR_SEND_AGAIN) {
error_print();
}
return ret;
}
//sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
// When CA certificates are configured (presumably meaning client
// authentication is enabled; TODO confirm), also feed the ServerHello into
// the client-verify context.
if (conn->ctx->cacertslen) {
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
}
tls_clean_record(conn);
return 1;
}
/*
 * Client side: receive and process the ServerHello.
 *
 * - checks the record and handshake protocol versions against conn->protocol;
 * - stores server_random and the (bounded) session_id;
 * - checks that the selected cipher suite is one of conn->ctx->cipher_suites;
 * - parses the ServerHello extensions: ec_point_formats is required,
 *   supported_groups and signature_algorithms are accepted but not yet used,
 *   any other extension aborts with unsupported_extension (RFC 5246, 7.4.1.4);
 * - initialises the transcript digest and feeds it the ClientHello kept in
 *   conn->plain_record followed by this ServerHello.
 *
 * Returns 1 on success; otherwise the tls_recv_record() result
 * (TLS_ERROR_RECV_AGAIN means "call again") or -1.
 */
int tls_recv_server_hello(TLS_CONNECT *conn)
{
	int ret;
	int protocol;
	int cipher_suite;
	const uint8_t *server_random;
	const uint8_t *session_id;
	size_t session_id_len;
	const uint8_t *exts;
	size_t extslen;
	const uint8_t *ec_point_formats = NULL;
	size_t ec_point_formats_len;
	const uint8_t *supported_groups = NULL;
	size_t supported_groups_len;
	const uint8_t *signature_algorithms = NULL;
	size_t signature_algorithms_len;

	tls_trace("recv ServerHello\n");
	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	tls12_record_print(stderr, conn->record, conn->recordlen, 0, 0);
	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_protocol_version);
		return -1;
	}
	if ((ret = tls_record_get_handshake_server_hello(conn->record,
		&protocol, &server_random, &session_id, &session_id_len, &cipher_suite,
		&exts, &extslen)) < 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_decode_error);
		return -1;
	} else if (ret == 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}

	// version
	if (protocol != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_protocol_version);
		return -1;
	}

	// random
	memcpy(conn->server_random, server_random, 32);

	// session_id: the length is peer-controlled, so bound it before copying.
	// RFC 5246 limits SessionID to 0..32 bytes (assumes conn->session_id holds 32 bytes).
	if (session_id_len > 32) {
		error_print();
		tls_send_alert(conn, TLS_alert_decode_error);
		return -1;
	}
	if (session_id_len) {
		memcpy(conn->session_id, session_id, session_id_len);
	}

	// cipher_suite must be one we offered
	if (tls_type_is_in_list(cipher_suite, conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_handshake_failure);
		return -1;
	}
	conn->cipher_suite = cipher_suite;

	// Initialise the transcript digest.
	// NOTE(review): SHA-256 / AES-128 are selected for every suite, including the SM4/SM3 ones.
	conn->digest = DIGEST_sha256();
	conn->cipher = BLOCK_CIPHER_aes128();
	if (digest_init(&conn->dgst_ctx, conn->digest) != 1) {
		error_print();
		return -1;
	}

	while (extslen) {
		int ext_type;
		const uint8_t *ext_data;
		size_t ext_datalen;

		if (tls_ext_from_bytes(&ext_type, &ext_data, &ext_datalen, &exts, &extslen) != 1) {
			error_print();
			tls13_send_alert(conn, TLS_alert_decode_error);
			return -1;
		}
		// Extensions accepted in ServerHello:
		//   * ec_point_formats
		//   * supported_groups
		//   * signature_algorithms
		switch (ext_type) {
		case TLS_extension_ec_point_formats:
		case TLS_extension_supported_groups:
		case TLS_extension_signature_algorithms:
			if (!ext_data) {
				error_print();
				tls13_send_alert(conn, TLS_alert_illegal_parameter);
				return -1;
			}
			break;
		default:
			// RFC 5246 7.4.1.4: an extension the client did not request must
			// abort the handshake with a fatal unsupported_extension alert.
			// NOTE(review): this loop uses the TLS 1.3 alert helper throughout.
			error_print();
			tls13_send_alert(conn, TLS_alert_unsupported_extension);
			return -1;
		}
		// Each extension may appear at most once.
		switch (ext_type) {
		case TLS_extension_ec_point_formats:
			if (ec_point_formats) {
				error_print();
				tls13_send_alert(conn, TLS_alert_illegal_parameter);
				return -1;
			}
			ec_point_formats = ext_data;
			ec_point_formats_len = ext_datalen;
			break;
		case TLS_extension_supported_groups:
			if (supported_groups) {
				error_print();
				tls13_send_alert(conn, TLS_alert_illegal_parameter);
				return -1;
			}
			supported_groups = ext_data;
			supported_groups_len = ext_datalen;
			break;
		case TLS_extension_signature_algorithms:
			if (signature_algorithms) {
				error_print();
				tls13_send_alert(conn, TLS_alert_illegal_parameter);
				return -1;
			}
			signature_algorithms = ext_data;
			signature_algorithms_len = ext_datalen;
			break;
		}
	}
	if (!ec_point_formats) {
		error_print();
		tls13_send_alert(conn, TLS_alert_missing_extension);
		return -1;
	}
	// TODO: supported_groups and signature_algorithms are accepted but not validated yet.

	// The digest algorithm is only known now, so the ClientHello kept in
	// conn->plain_record is hashed first, then this ServerHello.
	if (digest_update(&conn->dgst_ctx, conn->plain_record + 5, conn->plain_recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "ClientHello", &conn->dgst_ctx);
	if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "ServerHello", &conn->dgst_ctx);
	//sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
	if (conn->client_certs_len) {
		sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
	}
	return 1;
}
/*
 * Server side: send the server's Certificate message.
 *
 * What goes into the certificate list depends on the protocol:
 *   - TLS 1.2:   an ordinary certificate chain;
 *   - TLCP/SM2:  the SM2 double (signing + encryption) certificate chain, which
 *                uses the same wire format;
 *   - TLCP/SM9:  the server ID and SM9 public parameters, whose format differs
 *                although it may be stored the same way.
 * TODO: it is not yet settled whether the SM2 and SM9 encodings are compatible.
 *
 * The record is built only while conn->recordlen == 0 (a call after
 * TLS_ERROR_SEND_AGAIN re-sends it). The message is added to the transcript
 * digest when built and, once sent, to the client-verify context when client
 * authentication is enabled.
 *
 * Returns 1 on success; otherwise the tls_send_record() result or -1.
 */
int tls_send_server_certificate(TLS_CONNECT *conn)
{
	int send_rv;

	tls_trace("send ServerCertificate\n");

	if (!conn->recordlen) {
		// First attempt: encode the chain and hash the handshake body.
		if (tls_record_set_handshake_certificate(conn->record, &conn->recordlen,
			conn->cert_chain, conn->cert_chain_len) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
		if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
			error_print();
			return -1;
		}
		tls_handshake_digest_print(stderr, 0, 0, "Certificate", &conn->dgst_ctx);
	}

	send_rv = tls_send_record(conn);
	if (send_rv != 1) {
		if (send_rv != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return send_rv;
	}

	if (conn->client_certificate_verify) {
		tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
	}
	return 1;
}
/*
 * Client side: receive the server's Certificate message.
 *
 * Stores the peer chain in conn->peer_cert_chain and adds the message to the
 * transcript digest. It then checks that the leaf certificate's public key is
 * an EC key on the curve required by the negotiated cipher suite, and records
 * the matching signature scheme in conn->signature_algorithms[0]. If CA
 * certificates are configured, the chain is verified with
 * x509_certs_verify_tlcp().
 *
 * NOTE(review): a chain verification failure is only logged (the alert/return
 * are commented out), so the server is NOT authenticated when it fails.
 *
 * Returns 1 on success, TLS_ERROR_RECV_AGAIN if more data is needed, 0 when the
 * certificate parser returns 0 (an unexpected_message alert is sent;
 * presumably the record is not a Certificate), or -1 on other errors.
 */
int tls_recv_server_certificate(TLS_CONNECT *conn)
{
	int ret;
	int verify_result;
	const uint8_t *server_cert;
	size_t server_cert_len;
	X509_KEY server_sign_key;
	tls_trace("recv server Certificate\n");
	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	if ((ret = tls_record_get_handshake_certificate(conn->record,
		conn->peer_cert_chain, &conn->peer_cert_chain_len)) < 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_decode_error);
		return -1;
	} else if (ret == 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return 0;
	}
	if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "Certificate", &conn->dgst_ctx);
	// Take the leaf (index 0) certificate; its public key decides the server
	// signature scheme below.
	if (x509_certs_get_cert_by_index(conn->peer_cert_chain, conn->peer_cert_chain_len, 0,
		&server_cert, &server_cert_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}
	if (x509_cert_get_subject_public_key(server_cert, server_cert_len, &server_sign_key) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}
	// TODO: the relationship between cipher_suite, extensions and the
	// certificate still needs to be designed as a whole.
	// set conn->server_sig_alg (decided by cipher_suite and server_cert.sign_key.algor, algor_param)
	if (server_sign_key.algor != OID_ec_public_key) {
		error_print();
		return -1;
	}
	switch (conn->cipher_suite) {
	case TLS_cipher_ecdhe_sm4_cbc_sm3:
	case TLS_cipher_ecdhe_sm4_gcm_sm3:
	case TLS_cipher_ecc_sm4_cbc_sm3:
	case TLS_cipher_ecc_sm4_gcm_sm3:
		// SM suites require an SM2 key and SM2-with-SM3 signatures.
		if (server_sign_key.algor_param != OID_sm2) {
			error_print();
			return -1;
		}
		conn->signature_algorithms[0] = TLS_sig_sm2sig_sm3;
		break;
	case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
		// ECDSA suite requires a P-256 key and ECDSA-P256-SHA256 signatures.
		if (server_sign_key.algor_param != OID_secp256r1) {
			error_print();
			return -1;
		}
		conn->signature_algorithms[0] = TLS_sig_ecdsa_secp256r1_sha256;
		break;
	default:
		error_print();
		return -1;
	}
	// Client authentication: also feed this message into the CertificateVerify signature context.
	if (conn->client_certs_len) {
		sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
	}
	assert(conn->ctx->verify_depth > 0 && conn->ctx->verify_depth < 10);
	// verify ServerCertificate
	if (conn->ctx->cacertslen) {
		// In principle this only verifies the certificate chain; a failure is
		// currently logged but not treated as fatal.
		if (x509_certs_verify_tlcp(conn->peer_cert_chain, conn->peer_cert_chain_len, X509_cert_chain_server,
			conn->ctx->cacerts, conn->ctx->cacertslen, conn->ctx->verify_depth, &verify_result) != 1) {
			error_print();
			//tls_send_alert(conn, TLS_alert_bad_certificate);
			//return -1;
		}
	}
	return 1;
}
/*
 * Server side: generate an ephemeral ECDHE key and send ServerKeyExchange.
 *
 * ServerECDHParams (69 bytes) built here:
 *   [0]      curve_type = named_curve
 *   [1..2]   named curve id (conn->key_exchange_group, big-endian)
 *   [3]      point length = 65
 *   [4..68]  public point written by x509_public_key_to_bytes()
 * The signature covers client_random || server_random || ServerECDHParams and
 * uses conn->ctx->x509_keys[conn->cert_chain_idx - 1]; SM2 keys sign with the
 * default SM2 signer ID.
 *
 * The record is built only while conn->recordlen == 0 (a call after
 * TLS_ERROR_SEND_AGAIN re-sends it).
 *
 * Returns 1 on success; otherwise the tls_send_record() result or -1.
 */
int tls_send_server_key_exchange(TLS_CONNECT *conn)
{
	int ret;
	uint8_t server_ecdh_params[69];
	uint8_t *p = server_ecdh_params + 4; // the public point starts after the 4-byte header
	size_t len = 0;
	X509_SIGN_CTX sign_ctx;
	const void *sign_args = NULL;
	size_t sign_argslen = 0;
	uint8_t sig[X509_SIGNATURE_MAX_SIZE];
	size_t siglen;
	tls_trace("send ServerKeyExchange\n");
	if (conn->recordlen == 0) {
		int curve_oid = tls_named_curve_oid(conn->key_exchange_group);
		// generate server ecdh_key
		if (x509_key_generate(&conn->key_exchanges[0], OID_ec_public_key, &curve_oid, sizeof(curve_oid)) != 1) {
			error_print();
			return -1;
		}
		// build server_ecdh_params
		server_ecdh_params[0] = TLS_curve_type_named_curve;
		server_ecdh_params[1] = conn->key_exchange_group >> 8;
		server_ecdh_params[2] = (uint8_t)conn->key_exchange_group;
		server_ecdh_params[3] = 65;
		// NOTE(review): only 65 bytes of room remain; the length is checked after
		// writing, so this relies on the key being an uncompressed 256-bit point.
		if (x509_public_key_to_bytes(&conn->key_exchanges[0], &p, &len) != 1) {
			error_print();
			return -1;
		}
		if (len != 65) {
			error_print();
			return -1;
		}
		// NOTE(review): assumes cert_chain_idx is 1-based; 0 would index x509_keys[-1] — confirm.
		X509_KEY *sign_key = &conn->ctx->x509_keys[conn->cert_chain_idx - 1];
		// sign server_ecdh_params
		if (sign_key->algor == OID_ec_public_key && sign_key->algor_param == OID_sm2) {
			sign_args = SM2_DEFAULT_ID;
			sign_argslen = SM2_DEFAULT_ID_LENGTH;
		}
		if (x509_sign_init(&sign_ctx, sign_key, sign_args, sign_argslen) != 1
			|| x509_sign_update(&sign_ctx, conn->client_random, 32) != 1
			|| x509_sign_update(&sign_ctx, conn->server_random, 32) != 1
			|| x509_sign_update(&sign_ctx, server_ecdh_params, 69) != 1
			|| x509_sign_finish(&sign_ctx, sig, &siglen) != 1) {
			x509_sign_ctx_cleanup(&sign_ctx);
			error_print();
			return -1;
		}
		x509_sign_ctx_cleanup(&sign_ctx);
		if (tls_record_set_handshake_server_key_exchange(conn->record, &conn->recordlen,
			server_ecdh_params, sizeof(server_ecdh_params),
			conn->sig_alg, sig, siglen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
		if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
			error_print();
			return -1;
		}
		tls_handshake_digest_print(stderr, 0, 0, "ServerKeyExchange", &conn->dgst_ctx);
	}
	if ((ret = tls_send_record(conn)) != 1) {
		if (ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}
	//sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
	if (conn->client_certificate_verify) {
		tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
	}
	return 1;
}
/*
 * Check that an ECDHE named curve may be used with a cipher suite:
 *   sm2p256v1 -> ECDHE-SM4-CBC-SM3 or ECDHE-SM4-GCM-SM3
 *   secp256r1 -> ECDHE-ECDSA-AES128-CBC-SHA256
 *
 * Returns 1 when the pair matches, -1 (after error_print) otherwise.
 */
int tls_curve_match_cipher_suite(int named_curve, int cipher_suite)
{
	int matched = 0;

	if (named_curve == TLS_curve_sm2p256v1) {
		matched = (cipher_suite == TLS_cipher_ecdhe_sm4_cbc_sm3
			|| cipher_suite == TLS_cipher_ecdhe_sm4_gcm_sm3);
	} else if (named_curve == TLS_curve_secp256r1) {
		matched = (cipher_suite == TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256);
	}

	if (!matched) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Check that a signature scheme may be used with a cipher suite:
 *   sm2sig_sm3                -> any of the SM4/SM3 suites (ECDHE or ECC)
 *   ecdsa_secp256r1_sha256    -> ECDHE-ECDSA-AES128-CBC-SHA256
 *
 * Returns 1 when the pair matches, -1 (after error_print) otherwise.
 */
int tls_signature_scheme_match_cipher_suite(int sig_alg, int cipher_suite)
{
	int allowed;

	switch (sig_alg) {
	case TLS_sig_sm2sig_sm3:
		allowed = cipher_suite == TLS_cipher_ecdhe_sm4_cbc_sm3
			|| cipher_suite == TLS_cipher_ecdhe_sm4_gcm_sm3
			|| cipher_suite == TLS_cipher_ecc_sm4_cbc_sm3
			|| cipher_suite == TLS_cipher_ecc_sm4_gcm_sm3;
		break;
	case TLS_sig_ecdsa_secp256r1_sha256:
		allowed = cipher_suite == TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256;
		break;
	default:
		allowed = 0;
		break;
	}

	if (allowed) {
		return 1;
	}
	error_print();
	return -1;
}
/*
 * Client side: receive, validate and verify the ServerKeyExchange.
 *
 * Parses the ECDHE parameters (curve_type, named_curve, public point) and the
 * signature. It checks that the curve and signature scheme fit the negotiated
 * cipher suite and that the point has the expected 65-byte length. Only then
 * is the server's ephemeral point stored in conn->peer_key_exchange.
 *
 * The signature over
 *     client_random || server_random || ServerECDHParams
 * is then verified with the public key of the first certificate in
 * conn->peer_cert_chain.
 *
 * Returns 1 on success, TLS_ERROR_RECV_AGAIN if more data is needed, 0 when the
 * parser returns 0 (an unexpected_message alert is sent), or -1 on error.
 */
int tls_recv_server_key_exchange(TLS_CONNECT *conn)
{
	int ret;
	uint8_t curve_type;
	uint16_t named_curve;
	const uint8_t *point_octets;
	size_t point_octets_len;
	const uint8_t *server_ecdh_params;
	size_t server_ecdh_params_len;
	uint16_t sig_alg;
	const uint8_t *sig;
	size_t siglen;
	// verify ServerKeyExchange
	X509_KEY server_sign_key;
	int server_cert_index = 0;
	const uint8_t *server_cert;
	size_t server_cert_len;
	X509_SIGN_CTX sign_ctx;
	const void *sign_args = NULL;
	size_t sign_argslen = 0;

	tls_trace("recv ServerKeyExchange\n");
	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
	// NOTE: this parser assumes an ECDHE ServerKeyExchange; other key exchange
	// methods would need a different layout.
	if ((ret = tls_record_get_handshake_server_key_exchange(conn->record,
		&curve_type, &named_curve, &point_octets, &point_octets_len,
		&server_ecdh_params, &server_ecdh_params_len,
		&sig_alg, &sig, &siglen)) < 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_decode_error);
		return -1;
	} else if (ret == 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return 0;
	}
	if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "ServerKeyExchange", &conn->dgst_ctx);

	// Validate every peer-supplied field before anything is stored in conn.
	if (curve_type != TLS_curve_type_named_curve) {
		error_print();
		tls_send_alert(conn, TLS_alert_illegal_parameter);
		return -1;
	}
	// TODO: named_curve should also be checked against the offered supported_groups.
	if (tls_curve_match_cipher_suite(named_curve, conn->cipher_suite) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_illegal_parameter);
		return -1;
	}
	// Only uncompressed 256-bit points (65 bytes) are supported. Checking the
	// length before the memcpy keeps an oversized peer point from overflowing
	// conn->peer_key_exchange.
	if (point_octets_len != 65) {
		error_print();
		tls_send_alert(conn, TLS_alert_illegal_parameter);
		return -1;
	}
	if (tls_signature_scheme_match_cipher_suite(sig_alg, conn->cipher_suite) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_illegal_parameter);
		return -1;
	}
	conn->key_exchange_group = named_curve;
	memcpy(conn->peer_key_exchange, point_octets, point_octets_len);
	conn->peer_key_exchange_len = point_octets_len;
	// TODO: verify that the point actually lies on named_curve.

	//sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
	if (conn->client_certs_len)
		sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);

	// The signing key is the subject public key of the server leaf certificate.
	if (x509_certs_get_cert_by_index(conn->peer_cert_chain, conn->peer_cert_chain_len,
		server_cert_index, &server_cert, &server_cert_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}
	if (x509_cert_get_subject_public_key(server_cert, server_cert_len, &server_sign_key) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}
	// sig_alg combines a signature algorithm with a curve. The cipher suite only
	// fixes the algorithm family and the certificate only fixes the curve, so
	// both must be checked against the key.
	switch (sig_alg) {
	case TLS_sig_sm2sig_sm3:
		if (server_sign_key.algor != OID_ec_public_key
			|| server_sign_key.algor_param != OID_sm2) {
			error_print();
			tls_send_alert(conn, TLS_alert_bad_certificate);
			return -1;
		}
		break;
	case TLS_sig_ecdsa_secp256r1_sha256:
		if (server_sign_key.algor != OID_ec_public_key
			|| server_sign_key.algor_param != OID_secp256r1) {
			error_print();
			tls_send_alert(conn, TLS_alert_bad_certificate);
			return -1;
		}
		break;
	default:
		error_print();
		return -1;
	}
	if (server_sign_key.algor == OID_ec_public_key && server_sign_key.algor_param == OID_sm2) {
		sign_args = SM2_DEFAULT_ID;
		sign_argslen = SM2_DEFAULT_ID_LENGTH;
	}
	// Verify exactly the bytes the server signed, using the parsed params length
	// rather than a hard-coded size.
	if (x509_verify_init(&sign_ctx, &server_sign_key, sign_args, sign_argslen, sig, siglen) != 1
		|| x509_verify_update(&sign_ctx, conn->client_random, 32) != 1
		|| x509_verify_update(&sign_ctx, conn->server_random, 32) != 1
		|| x509_verify_update(&sign_ctx, server_ecdh_params, server_ecdh_params_len) != 1
		|| x509_verify_finish(&sign_ctx) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_decrypt_error);
		return -1;
	}
	fprintf(stderr, ">>>>>> ServerKeyExchange verify success\n");
	return 1;
}
/*
 * Server side: send a CertificateRequest.
 *
 * The request asks for an ECDSA-sign (EC/SM2 signing) client certificate
 * issued by one of the CAs in conn->ctx->cacerts. It must only be called when
 * client authentication is enabled (conn->client_certificate_verify);
 * otherwise it returns -1.
 *
 * The message is added to the handshake transcript conn->dgst_ctx, which the
 * Finished verify_data is computed from. The client side
 * (tls_recv_certificate_request) hashes it too, so the two transcripts stay
 * identical. After sending, it is also added to the client-verify context.
 *
 * Returns 1 on success; otherwise the tls_send_record() result
 * (TLS_ERROR_SEND_AGAIN means "call again") or -1.
 */
int tls_send_certificate_request(TLS_CONNECT *conn)
{
	int ret;
	// Only signing-capable EC certificates are requested.
	const uint8_t cert_types[] = { TLS_cert_type_ecdsa_sign };
	// TODO: size this from the CA list, or encode the names straight into the record buffer.
	uint8_t ca_names[TLS_MAX_CA_NAMES_SIZE] = {0};
	size_t ca_names_len = 0;

	if (!conn->client_certificate_verify) {
		error_print();
		return -1;
	}
	if (conn->recordlen == 0) {
		tls_trace("send CertificateRequest\n");
		if (tls_authorities_from_certs(ca_names, &ca_names_len, sizeof(ca_names),
			conn->ctx->cacerts, conn->ctx->cacertslen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		if (tls_record_set_handshake_certificate_request(conn->record, &conn->recordlen,
			cert_types, sizeof(cert_types),
			ca_names, ca_names_len) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
		// Keep the server transcript in step with the client, which hashes this message.
		if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
			error_print();
			return -1;
		}
		tls_handshake_digest_print(stderr, 0, 0, "CertificateRequest", &conn->dgst_ctx);
	}
	if ((ret = tls_send_record(conn)) != 1) {
		if (ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}
	//sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
	tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
	return 1;
}
/*
 * Client side: receive the optional CertificateRequest.
 *
 * Returns:
 *   1                     a CertificateRequest was received and processed;
 *   TLS_ERROR_RECV_AGAIN  more data is needed;
 *   0                     the record holds a different handshake message (the
 *                         server sent no CertificateRequest); the record is
 *                         left in conn->record, presumably for the next state
 *                         — confirm;
 *   -1 (or another non-1 tls_recv_record() result) on error.
 *
 * On success the message is added to the transcript digest and to the
 * CertificateVerify signature context, and conn->recordlen is reset to 0.
 */
int tls_recv_certificate_request(TLS_CONNECT *conn)
{
	int ret;
	uint8_t *record = conn->record;
	const uint8_t *cp;
	size_t len;
	int handshake_type;
	const uint8_t *cert_types;
	size_t cert_types_len;
	const uint8_t *ca_names;
	size_t ca_names_len;
	tls_trace("recv CertificateRequest*\n");
	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	// Peek at the handshake type first: CertificateRequest is optional.
	if (tls_record_get_handshake(record, &handshake_type, &cp, &len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	if (handshake_type != TLS_handshake_certificate_request) {
		tls_trace(" no CertificateRequest\n");
		return 0; // the peer did not send this (optional) message
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
	if (tls_record_get_handshake_certificate_request(conn->record,
		&cert_types, &cert_types_len, &ca_names, &ca_names_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	// TODO: check that what the server requested is consistent with the local
	// client certificate (disabled below).
	/*
	if(!conn->client_certs_len) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		return -1;
	}
	if (tls_cert_types_accepted(cert_types, cert_types_len, conn->client_certs, conn->client_certs_len) != 1
		|| tls_authorities_issued_certificate(ca_names, ca_names_len, conn->client_certs, conn->client_certs_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unsupported_certificate);
		return -1;
	}
	*/
	// NOTE(review): unlike the other client-side receivers, this update is not
	// gated on conn->client_certs_len.
	sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
	conn->recordlen = 0;
	return 1;
}
/*
 * Server side: send ServerHelloDone.
 *
 * The record is built only while conn->recordlen == 0 (a call after
 * TLS_ERROR_SEND_AGAIN re-sends it) and is added to the transcript digest when
 * built. Once sent, it is also added to the client-verify context when client
 * authentication is enabled, mirroring the client, which signs ServerHelloDone
 * into its CertificateVerify.
 *
 * Returns 1 on success; otherwise the tls_send_record() result or -1.
 */
int tls_send_server_hello_done(TLS_CONNECT *conn)
{
	int ret;
	tls_trace("send ServerHelloDone\n");
	if (conn->recordlen == 0) {
		if (tls_record_set_handshake_server_hello_done(conn->record, &conn->recordlen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
		if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
			error_print();
			return -1;
		}
		tls_handshake_digest_print(stderr, 0, 0, "ServerHelloDone", &conn->dgst_ctx);
	}
	if ((ret = tls_send_record(conn)) != 1) {
		if (ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}
	// Gate on client_certificate_verify, like the other server-side senders.
	// conn->client_certs_len is only filled when ClientCertificate arrives,
	// which is after this message, so it cannot be used as the gate here.
	if (conn->client_certificate_verify) {
		tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
	}
	return 1;
}
/*
 * Client side: receive ServerHelloDone.
 *
 * This state is unusual. Every other recv state must read a new record, but
 * here the ServerHelloDone has in most cases already been read by the preceding
 * (optional) CertificateRequest state, and this function cannot tell whether
 * that happened.
 *
 * On success the message is added to the transcript digest and, when the client
 * has a certificate, to the CertificateVerify signature context.
 *
 * Returns 1 on success; otherwise the tls_recv_record() result
 * (TLS_ERROR_RECV_AGAIN means "call again") or -1.
 */
int tls_recv_server_hello_done(TLS_CONNECT *conn)
{
	int rv;

	tls_trace("recv ServerHelloDone\n");

	rv = tls_recv_record(conn);
	if (rv != 1) {
		if (rv != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return rv;
	}

	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
	if (tls_record_get_handshake_server_hello_done(conn->record) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}

	if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "ServerHelloDone", &conn->dgst_ctx);

	if (conn->client_certs_len) {
		sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
	}
	return 1;
}
/*
 * Client side: send the client's Certificate message (conn->client_certs).
 *
 * Fails with -1 when the client has no certificate chain. The record is built
 * only while conn->recordlen == 0 (a call after TLS_ERROR_SEND_AGAIN re-sends
 * it). Once sent, the message is added to the CertificateVerify signature
 * context.
 * NOTE(review): it is not added to conn->dgst_ctx; the server does not hash it
 * either, so both sides agree — confirm this is intended.
 *
 * Returns 1 on success; otherwise the tls_send_record() result or -1.
 */
int tls_send_client_certificate(TLS_CONNECT *conn)
{
	int rv;

	tls_trace("send ClientCertificate\n");

	if (!conn->client_certs_len) {
		error_print();
		return -1;
	}

	if (!conn->recordlen) {
		if (tls_record_set_handshake_certificate(conn->record, &conn->recordlen,
			conn->client_certs, conn->client_certs_len) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
	}

	rv = tls_send_record(conn);
	if (rv == 1) {
		sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
		return 1;
	}
	if (rv != TLS_ERROR_SEND_AGAIN) {
		error_print();
	}
	return rv;
}
/*
 * Server side: receive and verify the client's Certificate message.
 *
 * This function is only meant to run when client certificates must be verified.
 * It returns -1 when no CA certificates are configured.
 * TODO: decide whether it should also check conn->client_certificate_verify itself.
 *
 * The chain is stored in conn->client_certs and verified against
 * conn->ctx->cacerts with a maximum depth of 5. It is then added to the
 * client-verify context.
 * NOTE(review): the depth is hard-coded while tls_recv_server_certificate uses
 * conn->ctx->verify_depth.
 * NOTE(review): the message is not added to conn->dgst_ctx; the client sender
 * does not hash it either, so both sides agree — confirm this is intended.
 *
 * Returns 1 on success; otherwise the tls_recv_record() result
 * (TLS_ERROR_RECV_AGAIN means "call again") or -1.
 */
int tls_recv_client_certificate(TLS_CONNECT *conn)
{
	int ret;
	const int verify_depth = 5;
	int verify_result;
	tls_trace("recv ClientCertificate\n");
	if (conn->ctx->cacertslen == 0) {
		error_print();
		return -1;
	}
	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	if (tls_record_protocol(conn->record) != conn->protocol) { // TODO: the protocol check should come after the trace
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
	if (tls_record_get_handshake_certificate(conn->record, conn->client_certs, &conn->client_certs_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	if (x509_certs_verify(conn->client_certs, conn->client_certs_len, X509_cert_chain_client,
		conn->ctx->cacerts, conn->ctx->cacertslen, verify_depth, &verify_result) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}
	//sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
	tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
	return 1;
}
/*
 * Derive the TLS 1.2 session keys from the ECDHE exchange.
 *
 *   pre_master_secret = ECDH(conn->key_exchanges[0], conn->peer_key_exchange)  (32 bytes)
 *   master_secret     = PRF(pre_master_secret, "master secret",
 *                           client_random || server_random)                     (48 bytes)
 *   key_block         = PRF(master_secret, "key expansion",
 *                           server_random || client_random)                     (96 bytes)
 *
 * key_block layout:
 *   [ 0..32)  client_write_MAC_key
 *   [32..64)  server_write_MAC_key
 *   [64..80)  client_write_key
 *   [80..96)  server_write_key
 *
 * No fixed IVs are derived. TLS 1.2 CBC records carry an explicit random IV in
 * every record, and RFC 5246 section 6.3 only derives client/server_write_IV
 * for implicit-nonce (AEAD) ciphers.
 * NOTE(review): an AEAD suite such as ECDHE-SM4-GCM-SM3 would need those IVs.
 *
 * The client uses the client_write_key to encrypt and the server_write_key to
 * decrypt; the server does the opposite. Both sequence numbers are reset.
 * The pre_master_secret is wiped from the stack on every exit path.
 * NOTE(review): secrets are dumped to stderr for debugging; remove before production use.
 *
 * Returns 1 on success, -1 on error.
 */
int tls_generate_keys(TLS_CONNECT *conn)
{
	uint8_t pre_master_secret[32];
	size_t pre_master_secret_len = 0;
	int ret = -1;

	// The pre-master secret comes entirely from the ECDHE exchange.
	if (x509_key_exchange(&conn->key_exchanges[0],
		conn->peer_key_exchange, conn->peer_key_exchange_len,
		pre_master_secret, &pre_master_secret_len) != 1) {
		error_print();
		goto end;
	}
	if (pre_master_secret_len != sizeof(pre_master_secret)) {
		error_print();
		goto end;
	}
	format_bytes(stderr, 0, 0, "pre_master_secret", pre_master_secret, pre_master_secret_len);

	// master_secret does not depend on the handshake transcript here.
	if (tls12_prf(conn->digest,
		pre_master_secret, pre_master_secret_len,
		"master secret",
		conn->client_random, 32,
		conn->server_random, 32,
		48, conn->master_secret) != 1) {
		error_print();
		goto end;
	}
	format_bytes(stderr, 0, 0, "master_secret", conn->master_secret, 48);

	// OpenSSL's tls1_prf may expand more bytes (including IVs); only MAC and
	// cipher keys are needed for CBC suites, see the header comment.
	if (tls12_prf(conn->digest, conn->master_secret, 48, "key expansion",
		conn->server_random, 32,
		conn->client_random, 32,
		96, conn->key_block) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		goto end;
	}
	format_bytes(stderr, 0, 0, "key_blocks", conn->key_block, 96);

	if (hmac_init(&conn->client_write_mac_ctx, conn->digest, conn->key_block, 32) != 1) {
		error_print();
		goto end;
	}
	if (hmac_init(&conn->server_write_mac_ctx, conn->digest, conn->key_block + 32, 32) != 1) {
		error_print();
		goto end;
	}
	format_bytes(stderr, 0, 0, "client_write_mac_key", conn->key_block, 32);
	format_bytes(stderr, 0, 0, "server_write_mac_key", conn->key_block + 32, 32);
	format_bytes(stderr, 0, 0, "client_write_key", conn->key_block + 64, 16);
	format_bytes(stderr, 0, 0, "server_write_key", conn->key_block + 80, 16);

	if (conn->is_client) {
		block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, conn->key_block + 64);
		block_cipher_set_decrypt_key(&conn->server_write_key, conn->cipher, conn->key_block + 80);
	} else {
		block_cipher_set_decrypt_key(&conn->client_write_key, conn->cipher, conn->key_block + 64);
		block_cipher_set_encrypt_key(&conn->server_write_key, conn->cipher, conn->key_block + 80);
	}
	tls_seq_num_reset(conn->client_seq_num);
	tls_seq_num_reset(conn->server_seq_num);
	/*
	tls_secrets_print(stderr,
		pre_master_secret, 32,
		conn->client_random, conn->server_random,
		conn->master_secret,
		conn->key_block, 96,
		0, 4);
	*/
	ret = 1;

end:
	// Never leave the shared secret on the stack.
	gmssl_secure_clear(pre_master_secret, sizeof(pre_master_secret));
	return ret;
}
/*
 * Client side: generate an ephemeral ECDHE key and send ClientKeyExchange.
 *
 * The client's ECDHE key must be on the same group as the server's. That group
 * is taken from conn->key_exchange_group, which tls_recv_server_key_exchange
 * sets from the ServerKeyExchange. The public point (65 bytes) is sent, and the
 * record is added to the transcript digest when built. Once sent, it is added
 * to the CertificateVerify signature context when the client has a certificate.
 *
 * The record is built only while conn->recordlen == 0 (a call after
 * TLS_ERROR_SEND_AGAIN re-sends it).
 *
 * Returns 1 on success; otherwise the tls_send_record() result or -1.
 */
int tls_send_client_key_exchange(TLS_CONNECT *conn)
{
	int ret;
	if (conn->recordlen == 0) {
		uint8_t point_octets[65];
		uint8_t *p = point_octets;
		size_t len = 0;
		int curve_oid = tls_named_curve_oid(conn->key_exchange_group);
		if (x509_key_generate(&conn->key_exchanges[0], OID_ec_public_key, &curve_oid, sizeof(curve_oid)) != 1) {
			error_print();
			return -1;
		}
		// NOTE(review): point_octets only has room for 65 bytes and the length is
		// checked after writing; this relies on a 256-bit curve.
		if (x509_public_key_to_bytes(&conn->key_exchanges[0], &p, &len) != 1) {
			error_print();
			return -1;
		}
		if (len != sizeof(point_octets)) {
			error_print();
			return -1;
		}
		tls_trace("send ClientKeyExchange\n");
		if (tls_record_set_handshake_client_key_exchange(conn->record, &conn->recordlen,
			point_octets, len) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
		if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
			error_print();
			return -1;
		}
		tls_handshake_digest_print(stderr, 0, 0, "ClientKeyExchange", &conn->dgst_ctx);
	}
	if ((ret = tls_send_record(conn)) != 1) {
		if (ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}
	if (conn->client_certs_len)
		sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
	return 1;
}
/*
 * Server side: receive the client's ECDHE public point (ClientKeyExchange).
 *
 * The point must be exactly 65 bytes (uncompressed 256-bit point); it is
 * validated before being copied into conn->peer_key_exchange. The message is
 * added to the transcript digest and, when CA certificates are configured, to
 * the client-verify context.
 *
 * Returns 1 on success; otherwise the tls_recv_record() result
 * (TLS_ERROR_RECV_AGAIN means "call again") or -1.
 */
int tls_recv_client_key_exchange(TLS_CONNECT *conn)
{
	int ret;
	const uint8_t *point_octets;
	size_t point_octets_len;

	tls_trace("recv ClientKeyExchange\n");
	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
	if (tls_record_get_handshake_client_key_exchange(conn->record,
		&point_octets, &point_octets_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	// Reject a wrongly sized peer point and tell the peer why, as the other
	// peer-input checks in this file do.
	if (point_octets_len != 65) {
		error_print();
		tls_send_alert(conn, TLS_alert_illegal_parameter);
		return -1;
	}
	memcpy(conn->peer_key_exchange, point_octets, point_octets_len);
	conn->peer_key_exchange_len = point_octets_len;
	if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "ClientKeyExchange", &conn->dgst_ctx);
	if (conn->ctx->cacertslen)
		tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
	return 1;
}
/*
 * Client side: sign the accumulated handshake messages (conn->sign_ctx) with
 * SM2 and send the CertificateVerify.
 *
 * Requires conn->client_certificate_verify; returns -1 otherwise. The signature
 * is produced only while conn->recordlen == 0, so a call after
 * TLS_ERROR_SEND_AGAIN re-sends the record without signing again.
 *
 * Returns 1 on success; otherwise the tls_send_record() result or -1.
 */
int tls_send_certificate_verify(TLS_CONNECT *conn)
{
	int rv;
	uint8_t sig[SM2_MAX_SIGNATURE_SIZE];
	size_t siglen;

	tls_trace("send CertificateVerify\n");

	if (!conn->client_certificate_verify) {
		error_print();
		return -1;
	}

	if (!conn->recordlen) {
		if (sm2_sign_finish(&conn->sign_ctx, sig, &siglen) != 1) {
			error_print();
			return -1;
		}
		if (tls_record_set_handshake_certificate_verify(conn->record, &conn->recordlen, sig, siglen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
	}

	rv = tls_send_record(conn);
	if (rv != 1 && rv != TLS_ERROR_SEND_AGAIN) {
		error_print();
	}
	return rv;
}
/*
 * Server side: receive the client's CertificateVerify and check its signature.
 *
 * The signature is checked by tls_client_verify_finish() against
 * conn->client_verify_ctx. It uses the SM2 public key of the first certificate
 * in conn->client_certs, and only SM2 client keys are accepted. Requires
 * conn->client_certificate_verify; returns -1 otherwise.
 *
 * Returns 1 on success; otherwise the tls_recv_record() result
 * (TLS_ERROR_RECV_AGAIN means "call again") or -1.
 */
int tls_recv_certificate_verify(TLS_CONNECT *conn)
{
	int ret;
	X509_KEY client_sign_key;
	const uint8_t *sig;
	size_t siglen;
	const uint8_t *client_cert;
	size_t client_cert_len;
	if (!conn->client_certificate_verify) {
		error_print();
		return -1;
	}
	tls_trace("recv CertificateVerify\n");
	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	if (tls_record_protocol(conn->record) != conn->protocol) {
		tls_send_alert(conn, TLS_alert_unexpected_message);
		error_print();
		return -1;
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
	// get signature from certificate_verify
	if (tls_record_get_handshake_certificate_verify(conn->record, &sig, &siglen) != 1) {
		tls_send_alert(conn, TLS_alert_unexpected_message);
		error_print();
		return -1;
	}
	// get sign_key from client certificate
	if (x509_certs_get_cert_by_index(conn->client_certs, conn->client_certs_len, 0,
		&client_cert, &client_cert_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}
	if (x509_cert_get_subject_public_key(client_cert, client_cert_len, &client_sign_key) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}
	// Only SM2 client keys are accepted; other signature algorithms are not supported yet.
	// TODO: decide whether the certificate type should be checked more thoroughly here.
	if (client_sign_key.algor != OID_ec_public_key
		|| client_sign_key.algor_param != OID_sm2) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}
	if (tls_client_verify_finish(&conn->client_verify_ctx, sig, siglen, &client_sign_key.u.sm2_key) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_decrypt_error);
		return -1;
	}
	//sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
	return 1;
}
/*
 * Send a ChangeCipherSpec record.
 *
 * The record is built only while conn->recordlen == 0, so a call following
 * TLS_ERROR_SEND_AGAIN re-sends the already-built record.
 *
 * Returns 1 on success; otherwise the tls_send_record() result or -1.
 */
int tls_send_change_cipher_spec(TLS_CONNECT *conn)
{
	int rv;

	if (!conn->recordlen) {
		tls_trace("send [ChangeCipherSpec]\n");
		if (tls_record_set_change_cipher_spec(conn->record, &conn->recordlen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
	}

	rv = tls_send_record(conn);
	if (rv == 1) {
		return 1;
	}
	if (rv != TLS_ERROR_SEND_AGAIN) {
		error_print();
	}
	return rv;
}
/*
 * Receive a ChangeCipherSpec record and check its version and contents.
 *
 * Returns 1 on success; otherwise the tls_recv_record() result
 * (TLS_ERROR_RECV_AGAIN means "call again") or -1 after an
 * unexpected_message alert.
 */
int tls_recv_change_cipher_spec(TLS_CONNECT *conn)
{
	int rv;

	tls_trace("recv [ChangeCipherSpec]\n");

	rv = tls_recv_record(conn);
	if (rv != 1) {
		if (rv != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return rv;
	}

	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

	if (tls_record_get_change_cipher_spec(conn->record) == 1) {
		return 1;
	}
	error_print();
	tls_send_alert(conn, TLS_alert_unexpected_message);
	return -1;
}
/*
 * Client side: build, encrypt and send the client Finished message.
 *
 *   verify_data = PRF(master_secret, "client finished", Hash(transcript))[0..12)
 *
 * The hash is taken from a copy of conn->dgst_ctx, so the live transcript can
 * still absorb the Finished message itself. The plaintext record is encrypted
 * with the client-write MAC/cipher keys under the current client sequence
 * number, which is then incremented. Everything is built only while
 * conn->recordlen == 0 (a call after TLS_ERROR_SEND_AGAIN re-sends the
 * encrypted record).
 *
 * Returns 1 on success; otherwise the tls_send_record() result or -1.
 */
int tls_send_client_finished(TLS_CONNECT *conn)
{
	int ret;
	if (conn->recordlen == 0) {
		uint8_t local_verify_data[12];
		DIGEST_CTX tmp_ctx;
		uint8_t dgst[32]; // NOTE(review): sized for 256-bit digests (SHA-256/SM3)
		size_t dgstlen;

		tls_trace("send client {Finished}\n");
		// Finish a copy so conn->dgst_ctx keeps accumulating the transcript.
		tmp_ctx = conn->dgst_ctx;
		if (digest_finish(&tmp_ctx, dgst, &dgstlen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		if (tls12_prf(conn->digest,
			conn->master_secret, 48,
			"client finished", dgst, dgstlen, NULL, 0,
			sizeof(local_verify_data), local_verify_data) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls_record_set_protocol(conn->plain_record, conn->protocol);
		if (tls_record_set_handshake_finished(conn->plain_record, &conn->plain_recordlen,
			local_verify_data, sizeof(local_verify_data)) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->plain_record, conn->plain_recordlen, 0, 0);
		if (digest_update(&conn->dgst_ctx, conn->plain_record + 5, conn->plain_recordlen - 5) != 1) {
			error_print();
			return -1;
		}
		tls_handshake_digest_print(stderr, 0, 0, "Finished", &conn->dgst_ctx);
		if (tls12_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_key,
			conn->client_seq_num, conn->plain_record, conn->plain_recordlen,
			conn->record, &conn->recordlen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls_seq_num_incr(conn->client_seq_num);
		format_bytes(stderr, 0, 0, "encrypted finsished ..... ", conn->record, conn->recordlen);
	}
	if ((ret = tls_send_record(conn)) != 1) {
		if (ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}
	return 1;
}
/*
 * Server side: receive, decrypt and check the client's Finished message.
 *
 * The expected verify_data is computed as follows, from a copy of the
 * transcript digest taken before the Finished record is processed:
 *     PRF(master_secret, "client finished", Hash(transcript))[0..12)
 * The record is decrypted with the client-write MAC/cipher keys under the
 * client sequence number, which is then incremented. The decrypted Finished is
 * appended to the transcript, and its verify_data is compared in constant time.
 *
 * Alerts: a record that fails decryption/MAC gets bad_record_mac, a malformed
 * Finished body gets decode_error, and a verify_data mismatch gets
 * decrypt_error (RFC 5246, 7.2.2 and 7.4.9).
 *
 * Returns 1 on success; otherwise the tls_recv_record() result
 * (TLS_ERROR_RECV_AGAIN means "call again") or -1.
 */
int tls_recv_client_finished(TLS_CONNECT *conn)
{
	int ret;
	const uint8_t *verify_data;
	size_t verify_data_len;
	uint8_t local_verify_data[12];
	DIGEST_CTX tmp_ctx;
	uint8_t dgst[32]; // NOTE(review): sized for 256-bit digests (SHA-256/SM3)
	size_t dgstlen;
	uint8_t diff = 0;
	size_t i;

	// Finish a copy so conn->dgst_ctx keeps accumulating the transcript.
	tmp_ctx = conn->dgst_ctx;
	if (digest_finish(&tmp_ctx, dgst, &dgstlen) != 1) {
		error_print();
		return -1;
	}
	if (tls12_prf(conn->digest, conn->master_secret, 48, "client finished", dgst, dgstlen, NULL, 0,
		sizeof(local_verify_data), local_verify_data) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		return -1;
	}
	format_bytes(stderr, 0, 0, "verify_data", local_verify_data, 12);

	// recv ClientFinished
	tls_trace("recv client {Finished}\n");
	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	//tls12_record_print(stderr, conn->record, conn->recordlen, 0, 0);
	format_bytes(stderr, 0, 0, "Finished", conn->record, conn->recordlen);
	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}

	// decrypt ClientFinished
	tls_trace(">>>>>>>decrypt Finished\n");
	format_bytes(stderr, 0, 0, "client_seq_num", conn->client_seq_num, 8);
	if (tls12_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_key,
		conn->client_seq_num, conn->record, conn->recordlen,
		conn->plain_record, &conn->plain_recordlen) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_record_mac);
		return -1;
	}
	tls_seq_num_incr(conn->client_seq_num);
	tls12_record_trace(stderr, conn->plain_record, conn->plain_recordlen, 0, 0);

	// The record decrypted successfully, so a malformed body is a decode error
	// rather than a MAC failure.
	if (tls_record_get_handshake_finished(conn->plain_record, &verify_data, &verify_data_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_decode_error);
		return -1;
	}
	if (verify_data_len != sizeof(local_verify_data)) {
		error_print();
		tls_send_alert(conn, TLS_alert_decode_error);
		return -1;
	}
	if (digest_update(&conn->dgst_ctx, conn->plain_record + 5, conn->plain_recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "client Finished", &conn->dgst_ctx);

	// verify ClientFinished: constant-time comparison (no early exit on the
	// first mismatching byte, unlike memcmp).
	for (i = 0; i < sizeof(local_verify_data); i++) {
		diff |= (uint8_t)(verify_data[i] ^ local_verify_data[i]);
	}
	if (diff != 0) {
		error_puts("client_finished.verify_data verification failure");
		tls_send_alert(conn, TLS_alert_decrypt_error);
		return -1;
	}
	return 1;
}
/*
 * Server side: build, encrypt and send the server Finished message.
 *
 * The record is built only when conn->recordlen == 0. If tls_send_record()
 * reports TLS_ERROR_SEND_AGAIN, the next call skips the build step and hands
 * the pending conn->record to tls_send_record() again. This way the PRF is
 * not recomputed and server_seq_num is advanced only once.
 *
 * verify_data = tls12_prf(master_secret, "server finished", Hash(transcript))
 * with a 12-byte output. The transcript hash is finished on a copy of
 * conn->dgst_ctx.
 *
 * Returns 1 on success, TLS_ERROR_SEND_AGAIN if the socket would block,
 * or -1 on error.
 */
int tls_send_server_finished(TLS_CONNECT *conn)
{
	int ret;

	if (conn->recordlen == 0) {
		uint8_t local_verify_data[12];
		DIGEST_CTX tmp_ctx;
		uint8_t dgst[32];
		size_t dgstlen;

		tls_trace("send server Finished\n");

		tmp_ctx = conn->dgst_ctx;
		if (digest_finish(&tmp_ctx, dgst, &dgstlen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		if (tls12_prf(conn->digest, conn->master_secret, 48, "server finished", dgst, dgstlen, NULL, 0,
			sizeof(local_verify_data), local_verify_data) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		format_bytes(stderr, 0, 0, "server verify_data", local_verify_data, 12);

		tls_record_set_protocol(conn->plain_record, conn->protocol);
		if (tls_record_set_handshake_finished(conn->plain_record, &conn->plain_recordlen,
			local_verify_data, sizeof(local_verify_data)) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->plain_record, conn->plain_recordlen, 0, 0);

		if (tls12_record_encrypt(&conn->server_write_mac_ctx, &conn->server_write_key,
			conn->server_seq_num, conn->plain_record, conn->plain_recordlen,
			conn->record, &conn->recordlen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls_seq_num_incr(conn->server_seq_num);
	}
	if ((ret = tls_send_record(conn)) != 1) {
		if (ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}
	return 1;
}
/*
 * Client side: receive, decrypt and verify the server's Finished message.
 *
 * The expected verify_data is tls12_prf(master_secret, "server finished",
 * Hash(transcript)) with a 12-byte output. It is computed before the record is
 * read, so on a non-blocking socket this function is re-entered after
 * TLS_ERROR_RECV_AGAIN. The transcript hash is therefore finished on a *copy*
 * of conn->dgst_ctx. Finishing the live context would make the second pass
 * call digest_finish() on an already-finalized context and derive a different
 * (wrong) verify_data.
 *
 * Returns 1 when the Finished verifies (handshake complete),
 * TLS_ERROR_RECV_AGAIN when more input is needed, or -1 on error
 * (an alert is sent to the peer).
 */
int tls_recv_server_finished(TLS_CONNECT *conn)
{
	int ret;
	DIGEST_CTX tmp_ctx;
	uint8_t dgst[32];
	size_t dgstlen;
	const uint8_t *verify_data;
	size_t verify_data_len;
	uint8_t local_verify_data[12];

	tmp_ctx = conn->dgst_ctx;
	if (digest_finish(&tmp_ctx, dgst, &dgstlen) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		return -1;
	}
	if (tls12_prf(conn->digest, conn->master_secret, 48, "server finished",
		dgst, dgstlen, NULL, 0,
		sizeof(local_verify_data), local_verify_data) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		return -1;
	}
	format_bytes(stderr, 0, 0, ">>> verify_data", local_verify_data, 12);

	// Finished
	tls_trace("recv server Finished\n");
	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}

	tls_trace("decrypt Finished\n");
	format_bytes(stderr, 0, 0, "server_seq_num", conn->server_seq_num, 8);
	if (tls12_record_decrypt(&conn->server_write_mac_ctx, &conn->server_write_key,
		conn->server_seq_num, conn->record, conn->recordlen,
		conn->plain_record, &conn->plain_recordlen) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_record_mac);
		return -1;
	}
	tls12_record_print(stderr, conn->plain_record, conn->plain_recordlen, 0, 0);
	tls_seq_num_incr(conn->server_seq_num);

	if (tls_record_get_handshake_finished(conn->plain_record, &verify_data, &verify_data_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	if (verify_data_len != sizeof(local_verify_data)) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) {
		error_puts("server_finished.verify_data verification failure");
		tls_send_alert(conn, TLS_alert_decrypt_error);
		return -1;
	}
	if (!conn->ctx->quiet)
		fprintf(stderr, "Connection established!\n");
	return 1;
}
/*
Client Server
ClientHello -------->
ServerHello
Certificate
ServerKeyExchange
CertificateRequest*
<-------- ServerHelloDone
Certificate*
ClientKeyExchange
CertificateVerify*
[ChangeCipherSpec]
Finished -------->
[ChangeCipherSpec]
<-------- Finished
Application Data <-------> Application Data
*/
/*
 * Perform one step of the TLS 1.2 client handshake state machine.
 *
 * Dispatches on conn->state, runs the send/recv routine for that state and,
 * if it did not fail, advances conn->state. Order of states:
 *   ClientHello -> ServerHello -> server Certificate -> ServerKeyExchange
 *   -> [CertificateRequest] -> ServerHelloDone -> [client Certificate]
 *   -> ClientKeyExchange -> generate keys -> [CertificateVerify]
 *   -> ChangeCipherSpec -> Finished -> server ChangeCipherSpec -> Finished
 * The bracketed client-auth states are entered only when
 * conn->client_certificate_verify is set. That flag is set when a
 * CertificateRequest was received.
 *
 * A step returning 0 means "bypassed" (the optional CertificateRequest). The
 * state still advances, but the current record is not cleaned, so the next
 * state can consume it.
 *
 * Returns 1 when a step completed or was bypassed. Returns
 * TLS_ERROR_RECV_AGAIN / TLS_ERROR_SEND_AGAIN when the caller must wait on
 * the socket and call again. Returns another negative value on error.
 */
int tls12_do_client_handshake(TLS_CONNECT *conn)
{
	int ret;
	int next_state;

	switch (conn->state) {
	case TLS_state_client_hello:
		ret = tls_send_client_hello(conn);
		next_state = TLS_state_server_hello;
		break;
	case TLS_state_server_hello:
		ret = tls_recv_server_hello(conn);
		next_state = TLS_state_server_certificate;
		break;
	case TLS_state_server_certificate:
		ret = tls_recv_server_certificate(conn);
		next_state = TLS_state_server_key_exchange;
		break;
	case TLS_state_server_key_exchange:
		ret = tls_recv_server_key_exchange(conn);
		next_state = TLS_state_certificate_request;
		break;
	// the only optional state
	case TLS_state_certificate_request:
		fprintf(stderr, "TLS_state_certificate_request\n");
		ret = tls_recv_certificate_request(conn);
		fprintf(stderr, " ret = %d\n", ret);
		if (ret == 1) conn->client_certificate_verify = 1;
		next_state = TLS_state_server_hello_done;
		break;
	case TLS_state_server_hello_done:
		fprintf(stderr, "TLS_state_server_hello_done\n");
		ret = tls_recv_server_hello_done(conn);
		if (conn->client_certificate_verify)
			next_state = TLS_state_client_certificate;
		else next_state = TLS_state_client_key_exchange;
		break;
	case TLS_state_client_certificate:
		ret = tls_send_client_certificate(conn);
		next_state = TLS_state_client_key_exchange;
		break;
	case TLS_state_client_key_exchange:
		ret = tls_send_client_key_exchange(conn);
		next_state = TLS_state_generate_keys;
		break;
	case TLS_state_generate_keys:
		ret = tls_generate_keys(conn);
		if (conn->client_certificate_verify)
			next_state = TLS_state_certificate_verify;
		else next_state = TLS_state_client_change_cipher_spec;
		break;
	case TLS_state_certificate_verify:
		ret = tls_send_certificate_verify(conn);
		next_state = TLS_state_client_change_cipher_spec;
		// Must not fall through: ChangeCipherSpec is a separate step, and
		// the CertificateVerify result (incl. SEND_AGAIN) must be honored
		// and its record cleaned before the next message is built.
		break;
	case TLS_state_client_change_cipher_spec:
		ret = tls_send_change_cipher_spec(conn);
		next_state = TLS_state_client_finished;
		break;
	case TLS_state_client_finished:
		ret = tls_send_client_finished(conn);
		next_state = TLS_state_server_change_cipher_spec;
		break;
	case TLS_state_server_change_cipher_spec:
		ret = tls_recv_change_cipher_spec(conn);
		next_state = TLS_state_server_finished;
		break;
	case TLS_state_server_finished:
		ret = tls_recv_server_finished(conn);
		next_state = TLS_state_handshake_over;
		break;
	default:
		error_print();
		return -1;
	}

	if (ret < 0) {
		if (ret != TLS_ERROR_RECV_AGAIN && ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}
	conn->state = next_state;
	// ret == 0 means this step is bypassed: keep the record for the next state
	if (ret == 1) {
		tls_clean_record(conn);
	}
	return 1;
}
/*
 * Perform one step of the TLS 1.2 server handshake state machine.
 *
 * Looks at conn->state, runs the matching send/recv routine and, only if that
 * routine returns exactly 1, moves to the next state and cleans the record
 * buffer. When conn->client_certificate_verify is set, the client
 * authentication states (CertificateRequest, client Certificate,
 * CertificateVerify) are visited as well.
 *
 * Returns 1 after a completed step. Returns TLS_ERROR_RECV_AGAIN /
 * TLS_ERROR_SEND_AGAIN unchanged so the caller can wait on the socket.
 * Any other result is logged and returned as-is.
 */
int tls12_do_server_handshake(TLS_CONNECT *conn)
{
	int rv;
	int next;

	switch (conn->state) {
	case TLS_state_client_hello:
		rv = tls_recv_client_hello(conn);
		next = TLS_state_server_hello;
		break;
	case TLS_state_server_hello:
		rv = tls_send_server_hello(conn);
		next = TLS_state_server_certificate;
		break;
	case TLS_state_server_certificate:
		rv = tls_send_server_certificate(conn);
		next = TLS_state_server_key_exchange;
		break;
	case TLS_state_server_key_exchange:
		rv = tls_send_server_key_exchange(conn);
		next = conn->client_certificate_verify
			? TLS_state_certificate_request
			: TLS_state_server_hello_done;
		break;
	case TLS_state_certificate_request:
		rv = tls_send_certificate_request(conn);
		next = TLS_state_server_hello_done;
		break;
	case TLS_state_server_hello_done:
		rv = tls_send_server_hello_done(conn);
		next = conn->client_certificate_verify
			? TLS_state_client_certificate
			: TLS_state_client_key_exchange;
		break;
	case TLS_state_client_certificate:
		rv = tls_recv_client_certificate(conn);
		next = TLS_state_client_key_exchange;
		break;
	case TLS_state_client_key_exchange:
		rv = tls_recv_client_key_exchange(conn);
		next = conn->client_certificate_verify
			? TLS_state_certificate_verify
			: TLS_state_generate_keys;
		break;
	case TLS_state_certificate_verify:
		rv = tls_recv_certificate_verify(conn);
		next = TLS_state_generate_keys;
		break;
	case TLS_state_generate_keys:
		rv = tls_generate_keys(conn);
		next = TLS_state_client_change_cipher_spec;
		break;
	case TLS_state_client_change_cipher_spec:
		rv = tls_recv_change_cipher_spec(conn);
		next = TLS_state_client_finished;
		break;
	case TLS_state_client_finished:
		rv = tls_recv_client_finished(conn);
		next = TLS_state_server_change_cipher_spec;
		break;
	case TLS_state_server_change_cipher_spec:
		rv = tls_send_change_cipher_spec(conn);
		next = TLS_state_server_finished;
		break;
	case TLS_state_server_finished:
		rv = tls_send_server_finished(conn);
		next = TLS_state_handshake_over;
		break;
	default:
		error_print();
		return -1;
	}

	if (rv == 1) {
		conn->state = next;
		tls_clean_record(conn);
		return 1;
	}
	// Would-block results are expected on non-blocking sockets; anything
	// else is a real failure worth logging.
	if (rv != TLS_ERROR_RECV_AGAIN && rv != TLS_ERROR_SEND_AGAIN) {
		error_print();
	}
	return rv;
}
/*
 * Drive the TLS 1.2 client handshake until conn->state reaches
 * TLS_state_handshake_over, one tls12_do_client_handshake() step at a time.
 *
 * The function is meant to be re-entered after a would-block result.
 * NOTE(review): an earlier comment said a re-entrant driver should not set
 * the connection state itself. As written, this loop only reads conn->state;
 * the initial state is assigned by the caller (e.g. tls12_do_connect).
 *
 * Returns 1 when the handshake is over. Returns TLS_ERROR_RECV_AGAIN /
 * TLS_ERROR_SEND_AGAIN when the socket must be polled before calling again.
 * Any other step result is logged and returned.
 */
int tls12_client_handshake(TLS_CONNECT *conn)
{
	for (;;) {
		int rv;

		if (conn->state == TLS_state_handshake_over) {
			// TODO: cleanup conn?
			return 1;
		}
		rv = tls12_do_client_handshake(conn);
		if (rv == 1) {
			continue;
		}
		if (rv != TLS_ERROR_RECV_AGAIN && rv != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return rv;
	}
}
/*
 * Drive the TLS 1.2 server handshake until conn->state reaches
 * TLS_state_handshake_over, one tls12_do_server_handshake() step at a time.
 *
 * Returns 1 when the handshake is over. Returns TLS_ERROR_RECV_AGAIN /
 * TLS_ERROR_SEND_AGAIN when the socket must be polled before calling again.
 * Any other step result is logged and returned.
 */
int tls12_server_handshake(TLS_CONNECT *conn)
{
	int rv = 1;

	// Keep stepping while each step succeeds and the handshake is not done
	while (rv == 1 && conn->state != TLS_state_handshake_over) {
		rv = tls12_do_server_handshake(conn);
	}
	if (rv == 1) {
		// TODO: cleanup conn?
		return 1;
	}
	if (rv != TLS_ERROR_RECV_AGAIN && rv != TLS_ERROR_SEND_AGAIN) {
		error_print();
	}
	return rv;
}
/*
 * Run a complete TLS 1.2 client handshake on conn->sock.
 *
 * Resets conn->state to TLS_state_client_hello and initializes the handshake
 * transcript digest with SM3. It then repeatedly calls
 * tls12_client_handshake(). Whenever that reports a would-block condition,
 * it waits in select() for the socket to become writable (SEND_AGAIN) or
 * readable (RECV_AGAIN).
 *
 * NOTE(review): the transcript digest is always SM3 here although
 * tls12_ciphers also lists an AES/SHA-256 suite — confirm whether
 * conn->dgst_ctx is re-initialized once the cipher suite is negotiated.
 *
 * Returns 1 on success, -1 on failure.
 */
int tls12_do_connect(TLS_CONNECT *conn)
{
	int ret;
	fd_set rfds;
	fd_set wfds;

	conn->state = TLS_state_client_hello;
	if (digest_init(&conn->dgst_ctx, DIGEST_sm3()) != 1) {
		error_print();
		return -1;
	}
	for (;;) {
		ret = tls12_client_handshake(conn);
		if (ret == 1) {
			break;
		}
		FD_ZERO(&rfds);
		FD_ZERO(&wfds);
		if (ret == TLS_ERROR_SEND_AGAIN) {
			FD_SET(conn->sock, &wfds);
		} else if (ret == TLS_ERROR_RECV_AGAIN) {
			FD_SET(conn->sock, &rfds);
		} else {
			error_print();
			return -1;
		}
		// Retry on EINTR; on any other select() error give up instead of
		// spinning forever on a broken descriptor.
		if (select(conn->sock + 1, &rfds, &wfds, NULL, NULL) < 0 && errno != EINTR) {
			error_print();
			return -1;
		}
	}
	return 1;
}
/*
 * Run a complete TLS 1.2 server handshake on conn->sock.
 *
 * Resets conn->state to TLS_state_client_hello and initializes the handshake
 * transcript digest with SM3. It then repeatedly calls
 * tls12_server_handshake(). Whenever that reports a would-block condition,
 * it waits in select() for the socket to become writable (SEND_AGAIN) or
 * readable (RECV_AGAIN).
 *
 * The fd sets were previously swapped: SEND_AGAIN waited for readability
 * (could block forever while the peer waits for our data), and RECV_AGAIN
 * waited for writability (returns immediately, busy-looping).
 *
 * Returns 1 on success, -1 on failure.
 */
int tls12_do_accept(TLS_CONNECT *conn)
{
	int ret;
	fd_set rfds;
	fd_set wfds;

	conn->state = TLS_state_client_hello;
	if (digest_init(&conn->dgst_ctx, DIGEST_sm3()) != 1) {
		error_print();
		return -1;
	}
	for (;;) {
		ret = tls12_server_handshake(conn);
		if (ret == 1) {
			break;
		}
		FD_ZERO(&rfds);
		FD_ZERO(&wfds);
		if (ret == TLS_ERROR_SEND_AGAIN) {
			FD_SET(conn->sock, &wfds);
		} else if (ret == TLS_ERROR_RECV_AGAIN) {
			FD_SET(conn->sock, &rfds);
		} else {
			error_print();
			return -1;
		}
		// Retry on EINTR; on any other select() error give up instead of
		// spinning forever on a broken descriptor.
		if (select(conn->sock + 1, &rfds, &wfds, NULL, NULL) < 0 && errno != EINTR) {
			error_print();
			return -1;
		}
	}
	return 1;
}