mirror of
https://github.com/guanzhi/GmSSL.git
synced 2026-06-19 19:33:38 +08:00
3418 lines
85 KiB
C
3418 lines
85 KiB
C
/*
|
||
* Copyright 2014-2026 The GmSSL Project. All Rights Reserved.
|
||
*
|
||
* Licensed under the Apache License, Version 2.0 (the License); you may
|
||
* not use this file except in compliance with the License.
|
||
*
|
||
* http://www.apache.org/licenses/LICENSE-2.0
|
||
*/
|
||
|
||
|
||
#include <time.h>
|
||
#include <errno.h>
|
||
#include <stdio.h>
|
||
#include <stdlib.h>
|
||
#include <string.h>
|
||
#include <assert.h>
|
||
#include <gmssl/rand.h>
|
||
#include <gmssl/x509.h>
|
||
#include <gmssl/error.h>
|
||
#include <gmssl/sm2.h>
|
||
#include <gmssl/sm3.h>
|
||
#include <gmssl/sm4.h>
|
||
#include <gmssl/pem.h>
|
||
#include <gmssl/mem.h>
|
||
#include <gmssl/tls.h>
|
||
|
||
|
||
|
||
|
||
static const int tls12_ciphers[] = {
|
||
TLS_cipher_ecdhe_sm4_cbc_sm3,
|
||
TLS_cipher_ecdhe_sm4_gcm_sm3,
|
||
TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256,
|
||
};
|
||
|
||
|
||
int tls12_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent)
|
||
{
|
||
// 目前只支持TLCP的ECC公钥加密套件,因此不论用哪个套件解析都是一样的
|
||
// 如果未来支持ECDHE套件,可以将函数改为宏,直接传入 (conn->cipher_suite << 8)
|
||
format |= tls12_ciphers[0] << 8; // 应该是KeyExchange需要这个参数
|
||
return tls_record_print(fp, record, recordlen, format, indent);
|
||
}
|
||
|
||
// The block-cipher API provides no multi-block CBC helpers, so
// cbc_encrypt_blocks() and cbc_decrypt_blocks() are defined locally below.
|
||
|
||
|
||
/*
 * CBC-encrypt `nblocks` 16-byte blocks from `in` into `out`.
 *
 * key     : block cipher key used with block_cipher_encrypt().
 * iv      : chaining value; the IV on entry and the last ciphertext block on
 *           return, so consecutive calls continue one CBC stream.
 * in, out : nblocks * 16 bytes each; in == out (in-place) is allowed.
 *
 * The chaining value is carried in `iv` block by block. The previous version
 * did memcpy(iv, iv, 16) when nblocks == 0, which is an overlapping memcpy
 * and therefore undefined behavior.
 */
void cbc_encrypt_blocks(const BLOCK_CIPHER_KEY *key, uint8_t iv[16],
	const uint8_t *in, size_t nblocks, uint8_t *out)
{
	size_t i;

	while (nblocks--) {
		for (i = 0; i < 16; i++) {
			out[i] = in[i] ^ iv[i];
		}
		block_cipher_encrypt(key, out, out);
		memcpy(iv, out, 16); // this ciphertext block chains into the next one
		in += 16;
		out += 16;
	}
}
|
||
|
||
/*
 * CBC-decrypt `nblocks` 16-byte blocks from `in` into `out`.
 *
 * key     : block cipher key used with block_cipher_decrypt().
 * iv      : chaining value; the IV on entry and the last ciphertext block on
 *           return, so consecutive calls continue one CBC stream.
 * in, out : nblocks * 16 bytes each. in == out (in-place) is supported,
 *           because each ciphertext block is saved before it is overwritten.
 *
 * The previous version kept a pointer into `in` as the chaining value, which
 * broke in-place use. It also did memcpy(iv, iv, 16) when nblocks == 0, an
 * overlapping memcpy and therefore undefined behavior.
 */
void cbc_decrypt_blocks(const BLOCK_CIPHER_KEY *key, uint8_t iv[16],
	const uint8_t *in, size_t nblocks, uint8_t *out)
{
	uint8_t next_iv[16];
	size_t i;

	while (nblocks--) {
		memcpy(next_iv, in, 16); // save ciphertext before `out` may overwrite it
		block_cipher_decrypt(key, in, out);
		for (i = 0; i < 16; i++) {
			out[i] ^= iv[i];
		}
		memcpy(iv, next_iv, 16);
		in += 16;
		out += 16;
	}
}
|
||
|
||
|
||
// 这个函数只有在哈希函数为HASH256时才是正确的
|
||
int tls12_cbc_encrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *enc_key,
|
||
const uint8_t seq_num[8], const uint8_t header[5],
|
||
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
|
||
{
|
||
HMAC_CTX hmac_ctx;
|
||
uint8_t last_blocks[32 + 16] = {0};
|
||
uint8_t iv[16];
|
||
uint8_t *mac, *padding;
|
||
size_t maclen;
|
||
int rem, padding_len;
|
||
int i;
|
||
|
||
if (!inited_hmac_ctx || !enc_key || !seq_num || !header || (!in && inlen) || !out || !outlen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (inlen > (1 << 14)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if ((((size_t)header[3]) << 8) + header[4] != inlen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
rem = (inlen + 32) % 16;
|
||
memcpy(last_blocks, in + inlen - rem, rem);
|
||
mac = last_blocks + rem;
|
||
|
||
memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(HMAC_CTX));
|
||
hmac_update(&hmac_ctx, seq_num, 8);
|
||
hmac_update(&hmac_ctx, header, 5);
|
||
hmac_update(&hmac_ctx, in, inlen);
|
||
hmac_finish(&hmac_ctx, mac, &maclen);
|
||
|
||
padding = mac + 32;
|
||
padding_len = 16 - rem - 1;
|
||
for (i = 0; i <= padding_len; i++) {
|
||
padding[i] = (uint8_t)padding_len;
|
||
}
|
||
|
||
if (rand_bytes(iv, 16) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
memcpy(out, iv, 16);
|
||
out += 16;
|
||
|
||
if (inlen >= 16) {
|
||
cbc_encrypt_blocks(enc_key, iv, in, inlen/16, out);
|
||
out += inlen - rem;
|
||
}
|
||
cbc_encrypt_blocks(enc_key, iv, last_blocks, sizeof(last_blocks)/16, out);
|
||
|
||
*outlen = 16 + inlen - rem + sizeof(last_blocks);
|
||
return 1;
|
||
}
|
||
|
||
int tls12_cbc_decrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *dec_key,
|
||
const uint8_t seq_num[8], const uint8_t enced_header[5],
|
||
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
|
||
{
|
||
HMAC_CTX hmac_ctx;
|
||
uint8_t iv[16];
|
||
const uint8_t *padding;
|
||
const uint8_t *mac;
|
||
uint8_t header[5];
|
||
int padding_len;
|
||
uint8_t hmac[32];
|
||
size_t hmaclen;
|
||
int i;
|
||
|
||
if (!inited_hmac_ctx || !dec_key || !seq_num || !enced_header || !in || !inlen || !out || !outlen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (inlen % 16
|
||
|| inlen < (16 + 0 + 32 + 16) // iv + data + mac + padding
|
||
|| inlen > (16 + (1<<14) + 32 + 256)) {
|
||
error_print_msg("invalid tls cbc ciphertext length %zu\n", inlen);
|
||
return -1;
|
||
}
|
||
|
||
memcpy(iv, in, 16);
|
||
|
||
format_bytes(stderr, 0, 0, "itls12_cbc_decrypt: iv", iv, 16);
|
||
|
||
|
||
in += 16;
|
||
inlen -= 16;
|
||
|
||
cbc_decrypt_blocks(dec_key, iv, in, inlen/16, out);
|
||
|
||
format_bytes(stderr, 0, 0, "cbc_decrypt out", out, inlen);
|
||
|
||
|
||
padding_len = out[inlen - 1];
|
||
padding = out + inlen - padding_len - 1;
|
||
if (padding < out + 32) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
for (i = 0; i < padding_len; i++) {
|
||
if (padding[i] != padding_len) {
|
||
error_puts("tls ciphertext cbc-padding check failure");
|
||
return -1;
|
||
}
|
||
}
|
||
|
||
*outlen = inlen - 32 - padding_len - 1;
|
||
|
||
header[0] = enced_header[0];
|
||
header[1] = enced_header[1];
|
||
header[2] = enced_header[2];
|
||
header[3] = (uint8_t)((*outlen) >> 8);
|
||
header[4] = (uint8_t)(*outlen);
|
||
mac = padding - 32;
|
||
|
||
memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(HMAC_CTX));
|
||
hmac_update(&hmac_ctx, seq_num, 8);
|
||
hmac_update(&hmac_ctx, header, 5);
|
||
hmac_update(&hmac_ctx, out, *outlen);
|
||
hmac_finish(&hmac_ctx, hmac, &hmaclen);
|
||
|
||
if (gmssl_secure_memcmp(mac, hmac, sizeof(hmac)) != 0) {
|
||
error_puts("tls ciphertext mac check failure\n");
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
/*
 * Encrypt a complete TLS 1.2 record (5-byte header || fragment) with
 * tls12_cbc_encrypt().
 *
 * The output record reuses the input's type and version (bytes 0..2) and
 * gets a new length field for the protected fragment.
 * *outlen receives the total output record length, including the 5-byte header.
 * Returns 1 on success, -1 on error.
 */
int tls12_record_encrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key,
	const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
	uint8_t *out, size_t *outlen)
{
	// Validate before computing in + 5 / inlen - 5 (NULL arithmetic, size_t underflow).
	if (!in || inlen < 5 || !out || !outlen) {
		error_print();
		return -1;
	}

	if (tls12_cbc_encrypt(hmac_ctx, cbc_key, seq_num, in,
		in + 5, inlen - 5,
		out + 5, outlen) != 1) {
		error_print();
		return -1;
	}

	out[0] = in[0];
	out[1] = in[1];
	out[2] = in[2];
	out[3] = (uint8_t)((*outlen) >> 8);
	out[4] = (uint8_t)(*outlen);
	(*outlen) += 5;
	return 1;
}
|
||
|
||
/*
 * Decrypt a complete protected TLS 1.2 record (5-byte header || fragment)
 * with tls12_cbc_decrypt().
 *
 * The output record reuses the input's type and version (bytes 0..2), and
 * its length field is set to the recovered plaintext length.
 * *outlen receives the total output record length, including the 5-byte header.
 * Returns 1 on success, -1 on error.
 */
int tls12_record_decrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key,
	const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
	uint8_t *out, size_t *outlen)
{
	// Validate before computing in + 5 / inlen - 5 (NULL arithmetic, size_t underflow).
	if (!in || inlen < 5 || !out || !outlen) {
		error_print();
		return -1;
	}

	if (tls12_cbc_decrypt(hmac_ctx, cbc_key, seq_num, in,
		in + 5, inlen - 5,
		out + 5, outlen) != 1) {
		error_print();
		return -1;
	}

	out[0] = in[0];
	out[1] = in[1];
	out[2] = in[2];
	out[3] = (uint8_t)((*outlen) >> 8);
	out[4] = (uint8_t)(*outlen);
	(*outlen) += 5;

	return 1;
}
|
||
|
||
int tls12_prf(const DIGEST *digest, const uint8_t *secret, size_t secretlen, const char *label,
|
||
const uint8_t *seed, size_t seedlen,
|
||
const uint8_t *more, size_t morelen,
|
||
size_t outlen, uint8_t *out)
|
||
{
|
||
HMAC_CTX inited_hmac_ctx;
|
||
HMAC_CTX hmac_ctx;
|
||
uint8_t A[32];
|
||
uint8_t hmac[32];
|
||
size_t len;
|
||
|
||
|
||
if (!secret || !secretlen || !label || !seed || !seedlen
|
||
|| (!more && morelen) || !outlen || !out) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
hmac_init(&inited_hmac_ctx, digest, secret, secretlen);
|
||
|
||
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX));
|
||
hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
|
||
hmac_update(&hmac_ctx, seed, seedlen);
|
||
hmac_update(&hmac_ctx, more, morelen);
|
||
hmac_finish(&hmac_ctx, A, &len); // 检查或者使用长度len
|
||
|
||
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX));
|
||
hmac_update(&hmac_ctx, A, sizeof(A));
|
||
hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
|
||
hmac_update(&hmac_ctx, seed, seedlen);
|
||
hmac_update(&hmac_ctx, more, morelen);
|
||
hmac_finish(&hmac_ctx, hmac, &len);
|
||
|
||
len = outlen < sizeof(hmac) ? outlen : sizeof(hmac);
|
||
memcpy(out, hmac, len);
|
||
out += len;
|
||
outlen -= len;
|
||
|
||
while (outlen) {
|
||
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX));
|
||
hmac_update(&hmac_ctx, A, sizeof(A));
|
||
hmac_finish(&hmac_ctx, A, &len);
|
||
|
||
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX));
|
||
hmac_update(&hmac_ctx, A, sizeof(A));
|
||
hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
|
||
hmac_update(&hmac_ctx, seed, seedlen);
|
||
hmac_update(&hmac_ctx, more, morelen);
|
||
hmac_finish(&hmac_ctx, hmac, &len);
|
||
|
||
len = outlen < sizeof(hmac) ? outlen : sizeof(hmac);
|
||
memcpy(out, hmac, len);
|
||
out += len;
|
||
outlen -= len;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
|
||
|
||
|
||
// modify: conn->record_offset
|
||
int tls_send_record(TLS_CONNECT *conn)
|
||
{
|
||
size_t left;
|
||
tls_ret_t n;
|
||
|
||
left = tls_record_length(conn->record) - conn->record_offset;
|
||
while (left) {
|
||
n = tls_socket_send(conn->sock, conn->record + conn->record_offset, left, 0);
|
||
if (n < 0) {
|
||
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||
return TLS_ERROR_SEND_AGAIN;
|
||
} else if (errno == EINTR) {
|
||
continue;
|
||
} else {
|
||
fprintf(stderr, "%s %d: send() error: %s\n", __FILE__, __LINE__, strerror(errno));
|
||
error_print();
|
||
return -1;
|
||
}
|
||
}
|
||
conn->record_offset += n;
|
||
left -= n;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
/*
 * Read exactly `need` more bytes into conn->record at conn->record_offset,
 * advancing record_offset. EINTR is retried.
 * Returns 1 on success, TLS_ERROR_RECV_AGAIN if the socket would block,
 * TLS_ERROR_SYSCALL on other receive errors, TLS_ERROR_TCP_CLOSED on EOF.
 */
static int tls_recv_record_bytes(TLS_CONNECT *conn, size_t need)
{
	while (need > 0) {
		tls_ret_t got = tls_socket_recv(conn->sock, conn->record + conn->record_offset, need, 0);

		if (got == 0) {
			error_print();
			return TLS_ERROR_TCP_CLOSED;
		}
		if (got < 0) {
			if (errno == EINTR) {
				continue;
			}
			if (errno == EAGAIN || errno == EWOULDBLOCK) {
				return TLS_ERROR_RECV_AGAIN;
			}
			error_print();
			// TODO: check the usage of OpenSSL SSL_ERR_SYSCALL; if applications
			// such as Nginx or HTTPD do not use this error, just return -1.
			return TLS_ERROR_SYSCALL;
		}
		conn->record_offset += got;
		need -= got;
	}
	return 1;
}

/*
 * Receive one complete TLS record into conn->record. Resumable: progress is
 * kept in conn->record_offset, so a call that returned TLS_ERROR_RECV_AGAIN
 * can be repeated. When the record is complete, conn->recordlen is set to its
 * total length. If conn->recordlen is already non-zero, nothing is read.
 *
 * The 5-byte header's type, protocol and length are validated once the
 * header has been fully read.
 * Returns 1 when a full record is available, a TLS_ERROR_* code from the
 * socket layer, or -1 on an invalid header.
 * TODO: Alert records are not treated specially here.
 */
int tls_recv_record(TLS_CONNECT *conn)
{
	int ret;

	if (conn->recordlen != 0) {
		return 1;
	}

	// step 1: the 5-byte record header
	if (conn->record_offset < 5) {
		if ((ret = tls_recv_record_bytes(conn, 5 - conn->record_offset)) != 1) {
			return ret;
		}
	}

	if (conn->record_offset == 5) {
		if (!tls_record_type_name(tls_record_type(conn->record))
			|| !tls_protocol_name(tls_record_protocol(conn->record))
			|| tls_record_length(conn->record) > TLS_MAX_RECORD_SIZE) {
			error_print();
			return -1;
		}
	}

	// step 2: the record body (an empty body is rejected)
	if (conn->record_offset >= tls_record_length(conn->record)) {
		error_print();
		return -1;
	}
	if ((ret = tls_recv_record_bytes(conn, tls_record_length(conn->record) - conn->record_offset)) != 1) {
		return ret;
	}

	conn->recordlen = conn->record_offset;
	return 1;
}
|
||
|
||
int tls_named_curve_oid(int named_curve)
|
||
{
|
||
switch (named_curve) {
|
||
case TLS_curve_secp256r1: return OID_secp256r1;
|
||
case TLS_curve_sm2p256v1: return OID_sm2;
|
||
}
|
||
return OID_undef;
|
||
}
|
||
|
||
int tls_named_curve_from_oid(int oid)
|
||
{
|
||
switch (oid) {
|
||
case OID_secp256r1: return TLS_curve_secp256r1;
|
||
case OID_sm2: return TLS_curve_sm2p256v1;
|
||
}
|
||
return 0;
|
||
}
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
// 这个是必选的
|
||
|
||
// 服务器通常推荐返回这个值
|
||
const int supported_groups[] = {
|
||
TLS_curve_sm2p256v1,
|
||
TLS_curve_secp256r1,
|
||
};
|
||
size_t supported_groups_cnt = sizeof(supported_groups)/sizeof(supported_groups[0]);
|
||
|
||
// 仍旧是不可设置的
|
||
const int signature_algors[] = {
|
||
TLS_sig_sm2sig_sm3,
|
||
TLS_sig_ecdsa_secp256r1_sha256,
|
||
};
|
||
size_t signature_algors_cnt = sizeof(signature_algors)/sizeof(signature_algors[0]);
|
||
|
||
|
||
|
||
int tls_record_set_handshake_server_key_exchange(uint8_t *record, size_t *recordlen,
|
||
const uint8_t *server_ecdh_params, size_t server_ecdh_params_len,
|
||
uint16_t sig_alg, const uint8_t *sig, size_t siglen)
|
||
{
|
||
const int type = TLS_handshake_server_key_exchange;
|
||
uint8_t *p = tls_handshake_data(tls_record_data(record));
|
||
size_t len = 0;
|
||
|
||
if (server_ecdh_params_len != 69) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (siglen > TLS_MAX_SIGNATURE_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_array_to_bytes(server_ecdh_params, server_ecdh_params_len, &p, &len);
|
||
tls_uint16_to_bytes(sig_alg, &p, &len);
|
||
tls_uint16array_to_bytes(sig, siglen, &p, &len);
|
||
tls_record_set_handshake(record, recordlen, type, NULL, len);
|
||
return 1;
|
||
}
|
||
|
||
// 这个函数是有问题的,因为tlcp的格式和TLS不一样
|
||
int tls_record_get_handshake_server_key_exchange(const uint8_t *record,
|
||
uint8_t *curve_type, uint16_t *named_curve,
|
||
const uint8_t **point_octets, size_t *point_octets_len,
|
||
const uint8_t **server_ecdh_params, size_t *server_ecdh_params_len,
|
||
uint16_t *sig_alg, const uint8_t **sig, size_t *siglen)
|
||
{
|
||
int type;
|
||
const uint8_t *p;
|
||
size_t len;
|
||
|
||
if (tls_record_get_handshake(record, &type, &p, &len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (type != TLS_handshake_server_key_exchange) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
*server_ecdh_params = p;
|
||
if (tls_uint8_from_bytes(curve_type, &p, &len) != 1
|
||
|| tls_uint16_from_bytes(named_curve, &p, &len) != 1
|
||
|| tls_uint8array_from_bytes(point_octets, point_octets_len, &p, &len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
*server_ecdh_params_len = p - *server_ecdh_params;
|
||
if (*server_ecdh_params_len != 69) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_uint16_from_bytes(sig_alg, &p, &len) != 1
|
||
|| tls_uint16array_from_bytes(sig, siglen, &p, &len) != 1
|
||
|| tls_length_is_zero(len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (*curve_type != TLS_curve_type_named_curve) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!tls_named_curve_name(*named_curve)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!tls_signature_scheme_name(*sig_alg)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_set_handshake_client_key_exchange(uint8_t *record, size_t *recordlen,
|
||
const uint8_t *point_octets, size_t point_octets_len)
|
||
{
|
||
int type = TLS_handshake_client_key_exchange;
|
||
uint8_t *p = tls_handshake_data(tls_record_data(record));
|
||
size_t len = 0;
|
||
|
||
if (point_octets_len != 65) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_uint8array_to_bytes(point_octets, (uint8_t)point_octets_len, &p, &len);
|
||
tls_record_set_handshake(record, recordlen, type, NULL, len);
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_get_handshake_client_key_exchange(const uint8_t *record,
|
||
const uint8_t **point_octets, size_t *point_octets_len)
|
||
{
|
||
int type;
|
||
const uint8_t *p;
|
||
size_t len;
|
||
|
||
if (tls_record_get_handshake(record, &type, &p, &len) != 1
|
||
|| type != TLS_handshake_client_key_exchange) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_uint8array_from_bytes(point_octets, point_octets_len, &p, &len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (*point_octets_len != 65) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (len) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
|
||
|
||
|
||
int tls12_cert_chains_select(const uint8_t *cert_chains, size_t cert_chains_len,
|
||
const int *supported_groups, size_t supported_groups_cnt, // optional
|
||
const int *signature_algorithms, size_t signature_algorithms_cnt, // optional
|
||
const uint8_t *ca_names, size_t ca_names_len, // certificate_authorities optional
|
||
const uint8_t *host_name, size_t host_name_len, // optional, only in ClientHello
|
||
const uint8_t **certs, size_t *certs_len, size_t *certs_idx, int *prefered_sig_alg) // optional
|
||
{
|
||
size_t i;
|
||
|
||
if (!cert_chains || !cert_chains_len) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
for (i = 1; cert_chains_len; i++) {
|
||
const uint8_t *cert_chain;
|
||
size_t cert_chain_len;
|
||
int sig_alg;
|
||
int ret;
|
||
|
||
if (tls_uint24array_from_bytes(&cert_chain, &cert_chain_len,
|
||
&cert_chains, &cert_chains_len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
if (certs) *certs = cert_chain;
|
||
if (certs_len) *certs_len = cert_chain_len;
|
||
if (certs_idx) *certs_idx = i;
|
||
if (prefered_sig_alg) *prefered_sig_alg = sig_alg;
|
||
return 1;
|
||
}
|
||
|
||
return 0;
|
||
}
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
/*
 * Mark conn->record as empty: no buffered record and no partial
 * send/receive progress.
 */
void tls_clean_record(TLS_CONNECT *conn)
{
	conn->recordlen = 0;
	conn->record_offset = 0;
}
|
||
|
||
|
||
/*
 * Reset per-handshake state by starting a fresh SM3 digest in conn->dgst_ctx.
 *
 * NOTE(review): dgst_ctx appears to be the handshake transcript hash. It is
 * always SM3 here, even if an AES/SHA-256 suite is negotiated later; confirm.
 * Returns 1 on success, -1 on error. The digest_init() result used to be ignored.
 */
int tls_handshake_init(TLS_CONNECT *conn)
{
	if (!conn) {
		error_print();
		return -1;
	}
	if (digest_init(&conn->dgst_ctx, DIGEST_sm3()) != 1) {
		error_print();
		return -1;
	}

	// TODO: when conn->client_certs_len != 0, initialize the client signing
	// context here (previously sketched as sm2_sign_init(&conn->sign_ctx, ...)).
	return 1;
}
|
||
|
||
// EC point formats sent in the ClientHello ec_point_formats extension:
// only the uncompressed format.
const int ec_point_formats[] = {
	TLS_point_uncompressed,
};
size_t ec_point_formats_cnt = sizeof ec_point_formats / sizeof *ec_point_formats;
|
||
|
||
|
||
|
||
// 有可能需要支持SNI
|
||
|
||
int tls_send_client_hello(TLS_CONNECT *conn)
|
||
{
|
||
int ret;
|
||
|
||
if (!conn->recordlen) {
|
||
uint8_t exts[TLS_MAX_EXTENSIONS_SIZE];
|
||
uint8_t *pexts = exts;
|
||
size_t extslen = 0;
|
||
|
||
tls_trace("send ClientHello\n");
|
||
|
||
tls_record_set_protocol(conn->record, TLS_protocol_tls1);
|
||
|
||
if (tls_random_generate(conn->client_random) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
// ec_point_formats
|
||
if (tls_ec_point_formats_ext_to_bytes(
|
||
ec_point_formats, ec_point_formats_cnt, &pexts, &extslen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
// supported_groups
|
||
if (conn->ctx->supported_groups_cnt) {
|
||
if (tls_supported_groups_ext_to_bytes(conn->ctx->supported_groups,
|
||
conn->ctx->supported_groups_cnt, &pexts, &extslen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
}
|
||
|
||
// signature_algorithms
|
||
if (conn->ctx->signature_algorithms_cnt) {
|
||
if (tls_signature_algorithms_ext_to_bytes(conn->ctx->signature_algorithms,
|
||
conn->ctx->signature_algorithms_cnt, &pexts, &extslen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
}
|
||
|
||
if (tls_record_set_handshake_client_hello(conn->record, &conn->recordlen,
|
||
conn->protocol, conn->client_random, NULL, 0,
|
||
conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt,
|
||
exts, extslen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
tls12_record_print(stderr, conn->record, conn->recordlen, 0, 0);
|
||
|
||
// backup ClientHello
|
||
memcpy(conn->plain_record, conn->record, conn->recordlen);
|
||
conn->plain_recordlen = conn->recordlen;
|
||
}
|
||
|
||
/*
|
||
if (conn->client_certificate_verify) {
|
||
sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
|
||
}
|
||
*/
|
||
|
||
if ((ret = tls_send_record(conn)) != 1) {
|
||
if (ret != TLS_ERROR_SEND_AGAIN) {
|
||
error_print();
|
||
}
|
||
return ret;
|
||
}
|
||
|
||
tls_clean_record(conn);
|
||
return 1;
|
||
}
|
||
|
||
/*
|
||
const int server_ciphers[] = { TLS_cipher_ecdhe_sm4_cbc_sm3 };
|
||
const size_t server_ciphers_cnt = 1;
|
||
*/
|
||
const int curve = TLS_curve_sm2p256v1;
|
||
|
||
static int tls12_cipher_suite_get(int cipher_suite, const BLOCK_CIPHER **cipher, const DIGEST **digest)
|
||
{
|
||
switch (cipher_suite) {
|
||
case TLS_cipher_ecdhe_sm4_cbc_sm3:
|
||
case TLS_cipher_ecdhe_sm4_gcm_sm3:
|
||
*cipher = BLOCK_CIPHER_sm4();
|
||
*digest = DIGEST_sm3();
|
||
break;
|
||
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
|
||
*cipher = BLOCK_CIPHER_aes128();
|
||
*digest = DIGEST_sha256();
|
||
break;
|
||
default:
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
static int tls12_cipher_suite_match_cert_group(int cipher_suite, int cert_group)
|
||
{
|
||
switch (cipher_suite) {
|
||
case TLS_cipher_ecdhe_sm4_cbc_sm3:
|
||
case TLS_cipher_ecdhe_sm4_gcm_sm3:
|
||
return cert_group == TLS_curve_sm2p256v1;
|
||
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
|
||
return cert_group == TLS_curve_secp256r1;
|
||
default:
|
||
return 0;
|
||
}
|
||
}
|
||
|
||
/*
 * Whether signature scheme `sig_alg` works with a key on `cert_group`:
 * true when both map to the same curve OID.
 */
static int tls12_signature_scheme_match_cert_group(int sig_alg, int cert_group)
{
	int scheme_curve_oid = tls_signature_scheme_group_oid(sig_alg);
	int cert_curve_oid = tls_named_curve_oid(cert_group);

	return scheme_curve_oid == cert_curve_oid;
}
|
||
|
||
static int tls12_signature_scheme_match_cipher_suite(int sig_alg, int cipher_suite)
|
||
{
|
||
switch (sig_alg) {
|
||
case TLS_sig_sm2sig_sm3:
|
||
switch (cipher_suite) {
|
||
case TLS_cipher_ecdhe_sm4_cbc_sm3:
|
||
case TLS_cipher_ecdhe_sm4_gcm_sm3:
|
||
return 1;
|
||
}
|
||
break;
|
||
case TLS_sig_ecdsa_secp256r1_sha256:
|
||
if (cipher_suite == TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256) {
|
||
return 1;
|
||
}
|
||
break;
|
||
}
|
||
return 0;
|
||
}
|
||
|
||
static int tls12_key_exchange_group_match_cipher_suite(int group, int cipher_suite)
|
||
{
|
||
switch (cipher_suite) {
|
||
case TLS_cipher_ecdhe_sm4_cbc_sm3:
|
||
case TLS_cipher_ecdhe_sm4_gcm_sm3:
|
||
return group == TLS_curve_sm2p256v1;
|
||
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
|
||
return group == TLS_curve_secp256r1;
|
||
default:
|
||
return 0;
|
||
}
|
||
}
|
||
|
||
/*
 * Intersect the client's cipher-suite list (raw uint16 big-endian values from
 * the ClientHello) with `server_ciphers`, keeping the server's order.
 * At most max_cnt suites are written to common_ciphers. *common_ciphers_cnt is
 * updated as suites are added.
 * Returns 1 if at least one suite is shared, 0 if none, -1 on bad arguments or
 * a malformed client list.
 */
static int tls12_select_common_cipher_suites(const uint8_t *client_ciphers, size_t client_ciphers_len,
	const int *server_ciphers, size_t server_ciphers_cnt,
	int *common_ciphers, size_t *common_ciphers_cnt, size_t max_cnt)
{
	size_t s;

	if (client_ciphers == NULL || client_ciphers_len == 0
		|| server_ciphers == NULL || server_ciphers_cnt == 0
		|| common_ciphers == NULL || common_ciphers_cnt == NULL || max_cnt == 0) {
		error_print();
		return -1;
	}

	*common_ciphers_cnt = 0;
	for (s = 0; s < server_ciphers_cnt; s++) {
		const uint8_t *cursor = client_ciphers;
		size_t remaining = client_ciphers_len;
		int offered = 0;

		if (*common_ciphers_cnt >= max_cnt) {
			break;
		}
		while (remaining && !offered) {
			uint16_t suite;

			if (tls_uint16_from_bytes(&suite, &cursor, &remaining) != 1) {
				error_print();
				return -1;
			}
			offered = (suite == server_ciphers[s]);
		}
		if (offered) {
			common_ciphers[(*common_ciphers_cnt)++] = server_ciphers[s];
		}
	}

	return (*common_ciphers_cnt > 0) ? 1 : 0;
}
|
||
|
||
// support_uncompressed
|
||
static int tls_ec_point_formats_support_uncompressed(const uint8_t *ext_data, size_t ext_datalen)
|
||
{
|
||
const uint8_t *formats;
|
||
size_t formats_len;
|
||
int uncompressed = 0;
|
||
|
||
if (tls_uint8array_from_bytes(&formats, &formats_len, &ext_data, &ext_datalen) != 1
|
||
|| tls_length_is_zero(ext_datalen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!formats_len) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
while (formats_len) {
|
||
uint8_t format;
|
||
if (tls_uint8_from_bytes(&format, &formats, &formats_len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!tls_ec_point_format_name(format)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (format == TLS_point_uncompressed) {
|
||
uncompressed = 1;
|
||
}
|
||
}
|
||
|
||
if (!uncompressed) {
|
||
error_print();
|
||
return 0;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
static int tls12_cert_chain_get_end_entity_group(const uint8_t *cert_chain, size_t cert_chain_len, int *group)
|
||
{
|
||
const uint8_t *cert;
|
||
size_t certlen;
|
||
X509_KEY public_key;
|
||
|
||
if (!cert_chain || !cert_chain_len || !group) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 0, &cert, &certlen) != 1
|
||
|| x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (public_key.algor != OID_ec_public_key) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if ((*group = tls_named_curve_from_oid(public_key.algor_param)) == 0) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
static int tls12_public_key_get_group(const X509_KEY *public_key, int *group)
|
||
{
|
||
if (!public_key || !group) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (public_key->algor != OID_ec_public_key) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if ((*group = tls_named_curve_from_oid(public_key->algor_param)) == 0) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
/*
 * Select the first group in `groups` that is a valid ECDHE curve for
 * `cipher_suite`.
 * Returns 1 with *selected_group set, 0 if none matches, -1 on bad arguments.
 */
static int tls12_select_key_exchange_group(const int *groups, size_t groups_cnt,
	int cipher_suite, int *selected_group)
{
	const int *it;
	const int *end;

	if (!groups || !groups_cnt || !selected_group) {
		error_print();
		return -1;
	}

	for (it = groups, end = groups + groups_cnt; it != end; ++it) {
		if (tls12_key_exchange_group_match_cipher_suite(*it, cipher_suite)) {
			*selected_group = *it;
			return 1;
		}
	}
	return 0;
}
|
||
|
||
// 这个函数的名字最好换一下
|
||
static int tls12_select_parameters(TLS_CONNECT *conn,
|
||
const int *common_cipher_suites, size_t common_cipher_suites_cnt,
|
||
const int *common_supported_groups, size_t common_supported_groups_cnt,
|
||
const int *common_signature_algorithms, size_t common_signature_algorithms_cnt,
|
||
const int *signature_algorithms_cert, size_t signature_algorithms_cert_cnt,
|
||
const uint8_t *host_name, size_t host_name_len)
|
||
{
|
||
const uint8_t *cert_chains = conn->ctx->cert_chains;
|
||
size_t cert_chains_len = conn->ctx->cert_chains_len;
|
||
size_t cert_chain_idx;
|
||
|
||
if (!conn || !common_cipher_suites || !common_cipher_suites_cnt
|
||
|| !common_supported_groups || !common_supported_groups_cnt
|
||
|| !common_signature_algorithms || !common_signature_algorithms_cnt) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!cert_chains || !cert_chains_len) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
for (cert_chain_idx = 1; cert_chains_len; cert_chain_idx++) {
|
||
const uint8_t *cert_chain;
|
||
size_t cert_chain_len;
|
||
const uint8_t *cert;
|
||
size_t certlen;
|
||
int cert_group;
|
||
size_t i;
|
||
int ret;
|
||
|
||
if (tls_uint24array_from_bytes(&cert_chain, &cert_chain_len,
|
||
&cert_chains, &cert_chains_len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls12_cert_chain_get_end_entity_group(cert_chain, cert_chain_len, &cert_group) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!tls_type_is_in_list(cert_group, common_supported_groups, common_supported_groups_cnt)) {
|
||
continue;
|
||
}
|
||
if (x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 0, &cert, &certlen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (host_name && host_name_len) {
|
||
if ((ret = tls_cert_match_server_name(cert, certlen, host_name, host_name_len)) < 0) {
|
||
error_print();
|
||
return -1;
|
||
} else if (ret == 0) {
|
||
continue;
|
||
}
|
||
}
|
||
if (signature_algorithms_cert && signature_algorithms_cert_cnt) {
|
||
if ((ret = tls_cert_chain_match_signature_algorithms_cert(cert_chain, cert_chain_len,
|
||
signature_algorithms_cert, signature_algorithms_cert_cnt)) < 0) {
|
||
error_print();
|
||
return -1;
|
||
} else if (ret == 0) {
|
||
continue;
|
||
}
|
||
}
|
||
|
||
for (i = 0; i < common_cipher_suites_cnt; i++) {
|
||
size_t j;
|
||
int cipher_suite = common_cipher_suites[i];
|
||
int key_exchange_group;
|
||
|
||
if (!tls12_cipher_suite_match_cert_group(cipher_suite, cert_group)) {
|
||
continue;
|
||
}
|
||
if ((ret = tls12_select_key_exchange_group(common_supported_groups,
|
||
common_supported_groups_cnt, cipher_suite, &key_exchange_group)) < 0) {
|
||
error_print();
|
||
return -1;
|
||
} else if (ret == 0) {
|
||
continue;
|
||
}
|
||
|
||
for (j = 0; j < common_signature_algorithms_cnt; j++) {
|
||
int sig_alg = common_signature_algorithms[j];
|
||
|
||
if (!tls12_signature_scheme_match_cert_group(sig_alg, cert_group)) {
|
||
continue;
|
||
}
|
||
if (!tls12_signature_scheme_match_cipher_suite(sig_alg, cipher_suite)) {
|
||
continue;
|
||
}
|
||
|
||
conn->cipher_suite = cipher_suite;
|
||
conn->cert_chain = cert_chain;
|
||
conn->cert_chain_len = cert_chain_len;
|
||
conn->cert_chain_idx = cert_chain_idx;
|
||
conn->sig_alg = sig_alg;
|
||
conn->key_exchange_group = key_exchange_group;
|
||
return 1;
|
||
}
|
||
}
|
||
}
|
||
|
||
warning_print();
|
||
return 0;
|
||
}
|
||
|
||
|
||
int tls_recv_client_hello(TLS_CONNECT *conn)
|
||
{
|
||
int ret;
|
||
|
||
int client_verify = 0;
|
||
|
||
int protocol;
|
||
const uint8_t *client_random;
|
||
const uint8_t *session_id;
|
||
size_t session_id_len;
|
||
const uint8_t *cipher_suites;
|
||
size_t cipher_suites_len;
|
||
const uint8_t *exts;
|
||
size_t extslen;
|
||
const uint8_t *ec_point_formats = NULL;
|
||
size_t ec_point_formats_len = 0;
|
||
const uint8_t *supported_groups = NULL;
|
||
size_t supported_groups_len = 0;
|
||
const uint8_t *signature_algorithms = NULL;
|
||
size_t signature_algorithms_len = 0;
|
||
const uint8_t *signature_algorithms_cert = NULL;
|
||
size_t signature_algorithms_cert_len = 0;
|
||
const uint8_t *server_name = NULL;
|
||
size_t server_name_len = 0;
|
||
int common_cipher_suites[TLS_MAX_CIPHER_SUITES_COUNT];
|
||
size_t common_cipher_suites_cnt = 0;
|
||
int common_supported_groups[32];
|
||
size_t common_supported_groups_cnt = 0;
|
||
int common_signature_algorithms[32];
|
||
size_t common_signature_algorithms_cnt = 0;
|
||
int common_signature_algorithms_cert[32];
|
||
size_t common_signature_algorithms_cert_cnt = 0;
|
||
const int *cert_signature_algorithms = NULL;
|
||
size_t cert_signature_algorithms_cnt = 0;
|
||
const uint8_t *host_name = NULL;
|
||
size_t host_name_len = 0;
|
||
|
||
/*
|
||
if (client_verify)
|
||
tls_client_verify_init(&conn->client_verify_ctx);
|
||
*/
|
||
|
||
|
||
tls_trace("recv ClientHello\n");
|
||
|
||
if ((ret = tls_recv_record(conn)) != 1) {
|
||
if (ret != TLS_ERROR_RECV_AGAIN) {
|
||
error_print();
|
||
}
|
||
return ret;
|
||
}
|
||
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
|
||
|
||
if (tls_record_protocol(conn->record) != TLS_protocol_tls1) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_protocol_version);
|
||
return -1;
|
||
}
|
||
|
||
if ((ret = tls_record_get_handshake_client_hello(conn->record,
|
||
&protocol, &client_random, &session_id, &session_id_len,
|
||
&cipher_suites, &cipher_suites_len, &exts, &extslen)) < 0) {
|
||
error_print();
|
||
tls13_send_alert(conn, TLS_alert_decode_error);
|
||
return -1;
|
||
} else if (ret == 0) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_unexpected_message);
|
||
return -1;
|
||
}
|
||
|
||
if (protocol != conn->protocol) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_protocol_version);
|
||
return -1;
|
||
}
|
||
|
||
memcpy(conn->client_random, client_random, 32);
|
||
|
||
|
||
while (extslen) {
|
||
int ext_type;
|
||
const uint8_t *ext_data;
|
||
size_t ext_datalen;
|
||
|
||
if (tls_ext_from_bytes(&ext_type, &ext_data, &ext_datalen, &exts, &extslen) != 1) {
|
||
error_print();
|
||
tls13_send_alert(conn, TLS_alert_decode_error);
|
||
return -1;
|
||
}
|
||
|
||
switch (ext_type) {
|
||
case TLS_extension_ec_point_formats:
|
||
case TLS_extension_supported_groups:
|
||
case TLS_extension_signature_algorithms:
|
||
case TLS_extension_signature_algorithms_cert:
|
||
case TLS_extension_server_name:
|
||
if (!ext_data) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_decode_error);
|
||
return -1;
|
||
}
|
||
break;
|
||
}
|
||
|
||
switch (ext_type) {
|
||
case TLS_extension_ec_point_formats:
|
||
if (ec_point_formats) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_illegal_parameter);
|
||
return -1;
|
||
}
|
||
ec_point_formats = ext_data;
|
||
ec_point_formats_len = ext_datalen;
|
||
break;
|
||
case TLS_extension_supported_groups:
|
||
if (supported_groups) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_illegal_parameter);
|
||
return -1;
|
||
}
|
||
supported_groups = ext_data;
|
||
supported_groups_len = ext_datalen;
|
||
break;
|
||
case TLS_extension_signature_algorithms:
|
||
if (signature_algorithms) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_illegal_parameter);
|
||
return -1;
|
||
}
|
||
signature_algorithms = ext_data;
|
||
signature_algorithms_len = ext_datalen;
|
||
break;
|
||
case TLS_extension_signature_algorithms_cert:
|
||
if (signature_algorithms_cert) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_illegal_parameter);
|
||
return -1;
|
||
}
|
||
signature_algorithms_cert = ext_data;
|
||
signature_algorithms_cert_len = ext_datalen;
|
||
break;
|
||
case TLS_extension_server_name:
|
||
if (server_name) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_illegal_parameter);
|
||
return -1;
|
||
}
|
||
server_name = ext_data;
|
||
server_name_len = ext_datalen;
|
||
break;
|
||
default:
|
||
warning_print();
|
||
}
|
||
}
|
||
|
||
if (ec_point_formats) {
|
||
if ((ret = tls_ec_point_formats_support_uncompressed(ec_point_formats, ec_point_formats_len)) < 0) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_decode_error);
|
||
return -1;
|
||
} else if (ret == 0) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_illegal_parameter);
|
||
return -1;
|
||
}
|
||
conn->ec_point_formats = 1;
|
||
}
|
||
|
||
if ((ret = tls12_select_common_cipher_suites(cipher_suites, cipher_suites_len,
|
||
conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt,
|
||
common_cipher_suites, &common_cipher_suites_cnt,
|
||
sizeof(common_cipher_suites)/sizeof(common_cipher_suites[0]))) < 0) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_decode_error);
|
||
return -1;
|
||
} else if (ret == 0) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_handshake_failure);
|
||
return -1;
|
||
}
|
||
|
||
if (supported_groups) {
|
||
if ((ret = tls_process_supported_groups(supported_groups, supported_groups_len,
|
||
conn->ctx->supported_groups, conn->ctx->supported_groups_cnt,
|
||
common_supported_groups, &common_supported_groups_cnt,
|
||
sizeof(common_supported_groups)/sizeof(common_supported_groups[0]))) < 0) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_decode_error);
|
||
return -1;
|
||
} else if (ret == 0) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_handshake_failure);
|
||
return -1;
|
||
}
|
||
} else {
|
||
if (!conn->ctx->supported_groups_cnt) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_handshake_failure);
|
||
return -1;
|
||
}
|
||
memcpy(common_supported_groups, conn->ctx->supported_groups,
|
||
conn->ctx->supported_groups_cnt * sizeof(conn->ctx->supported_groups[0]));
|
||
common_supported_groups_cnt = conn->ctx->supported_groups_cnt;
|
||
}
|
||
|
||
if (signature_algorithms) {
|
||
if ((ret = tls_process_signature_algorithms(signature_algorithms, signature_algorithms_len,
|
||
conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt,
|
||
common_signature_algorithms, &common_signature_algorithms_cnt,
|
||
sizeof(common_signature_algorithms)/sizeof(common_signature_algorithms[0]))) < 0) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_decode_error);
|
||
return -1;
|
||
} else if (ret == 0) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_handshake_failure);
|
||
return -1;
|
||
}
|
||
} else {
|
||
if (!conn->ctx->signature_algorithms_cnt) {
|
||
error_print();
|
||
tls13_send_alert(conn, TLS_alert_handshake_failure);
|
||
return -1;
|
||
}
|
||
memcpy(common_signature_algorithms, conn->ctx->signature_algorithms,
|
||
conn->ctx->signature_algorithms_cnt * sizeof(conn->ctx->signature_algorithms[0]));
|
||
common_signature_algorithms_cnt = conn->ctx->signature_algorithms_cnt;
|
||
}
|
||
|
||
if (signature_algorithms_cert) {
|
||
if ((ret = tls_process_signature_algorithms(signature_algorithms_cert, signature_algorithms_cert_len,
|
||
conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt,
|
||
common_signature_algorithms_cert, &common_signature_algorithms_cert_cnt,
|
||
sizeof(common_signature_algorithms_cert)/sizeof(common_signature_algorithms_cert[0]))) < 0) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_decode_error);
|
||
return -1;
|
||
} else if (ret == 0) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_handshake_failure);
|
||
return -1;
|
||
}
|
||
cert_signature_algorithms = common_signature_algorithms_cert;
|
||
cert_signature_algorithms_cnt = common_signature_algorithms_cert_cnt;
|
||
} else if (signature_algorithms) {
|
||
cert_signature_algorithms = common_signature_algorithms;
|
||
cert_signature_algorithms_cnt = common_signature_algorithms_cnt;
|
||
}
|
||
|
||
if (server_name) {
|
||
if (tls_server_name_from_bytes(&host_name, &host_name_len, server_name, server_name_len) != 1) {
|
||
error_print();
|
||
tls13_send_alert(conn, TLS_alert_decode_error);
|
||
return -1;
|
||
}
|
||
conn->server_name = 1;
|
||
}
|
||
|
||
if ((ret = tls12_select_parameters(conn,
|
||
common_cipher_suites, common_cipher_suites_cnt,
|
||
common_supported_groups, common_supported_groups_cnt,
|
||
common_signature_algorithms, common_signature_algorithms_cnt,
|
||
cert_signature_algorithms, cert_signature_algorithms_cnt,
|
||
host_name, host_name_len)) < 0) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_decode_error);
|
||
return -1;
|
||
} else if (ret == 0) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_handshake_failure);
|
||
return -1;
|
||
}
|
||
|
||
if (tls12_cipher_suite_get(conn->cipher_suite, &conn->cipher, &conn->digest) != 1) {
|
||
error_print();
|
||
tls13_send_alert(conn, TLS_alert_internal_error);
|
||
return -1;
|
||
}
|
||
|
||
if (digest_init(&conn->dgst_ctx, conn->digest) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_handshake_digest_print(stderr, 0, 0, "ClientHello", &conn->dgst_ctx);
|
||
|
||
|
||
|
||
|
||
/*
|
||
if (client_verify)
|
||
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
|
||
|
||
*/
|
||
|
||
fprintf(stderr, "end of recv_client_hello\n");
|
||
tls_clean_record(conn);
|
||
return 1;
|
||
}
|
||
|
||
/*
 * Build (once) and send the ServerHello message (server side).
 *
 * On the first call (conn->recordlen == 0):
 * - a fresh server_random is generated;
 * - the ec_point_formats extension is added when the ClientHello carried one
 *   (conn->ec_point_formats);
 * - the record is appended to the handshake transcript conn->dgst_ctx.
 *
 * If tls_send_record() does not complete, its return value is passed back and
 * the built record is left in conn->record. A later call with
 * conn->recordlen != 0 then only retries the send.
 *
 * Returns 1 on success, TLS_ERROR_SEND_AGAIN (or another non-1 value from
 * tls_send_record()) when sending did not complete, -1 on error.
 */
int tls_send_server_hello(TLS_CONNECT *conn)
{
	int ret;

	tls_trace("send ServerHello\n");

	if (conn->recordlen == 0) {
		uint8_t exts[512];
		uint8_t *pexts = exts;
		size_t extslen = 0;

		tls_record_set_protocol(conn->record, conn->protocol);

		if (tls_random_generate(conn->server_random) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}

		// ServerHello extensions: only ec_point_formats, echoed back when the
		// client offered it. NOTE(review): ec_point_formats/ec_point_formats_cnt
		// are not declared in this function; presumably a file-scope list of
		// the formats this server supports.
		if (conn->ec_point_formats) {
			if (tls_ec_point_formats_ext_to_bytes(ec_point_formats, ec_point_formats_cnt,
				&pexts, &extslen) != 1) {
				error_print();
				tls_send_alert(conn, TLS_alert_internal_error);
				return -1;
			}
		}

		// Session resumption is not supported: empty session_id.
		if (tls_record_set_handshake_server_hello(conn->record, &conn->recordlen,
			conn->protocol, conn->server_random, NULL, 0,
			conn->cipher_suite,
			exts, extslen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

		// The transcript hashes the handshake message, skipping the 5-byte
		// record header.
		if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
			error_print();
			return -1;
		}
		tls_handshake_digest_print(stderr, 0, 0, "ServerHello", &conn->dgst_ctx);
	}

	if ((ret = tls_send_record(conn)) != 1) {
		if (ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}

	// Client-authentication transcript, maintained only when client
	// certificate verification is enabled for this connection.
	if (conn->client_certificate_verify) {
		tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
	}

	tls_clean_record(conn);
	return 1;
}
|
||
|
||
/*
 * Receive and validate the ServerHello message (client side).
 *
 * Checks performed:
 * - the record version and the ServerHello version equal conn->protocol;
 * - the selected cipher suite is one the client offered
 *   (conn->ctx->cipher_suites);
 * - the only accepted extension is ec_point_formats, at most once, and it
 *   must include the uncompressed format.
 *
 * On success:
 * - server_random, the session_id and cipher_suite are stored in conn;
 * - conn->cipher and conn->digest are set for the cipher suite;
 * - conn->dgst_ctx is (re)initialised and fed the saved ClientHello
 *   (conn->plain_record) followed by this ServerHello.
 *
 * Returns 1 on success, TLS_ERROR_RECV_AGAIN (or another non-1 value from
 * tls_recv_record()) if no complete record is available yet, -1 on error
 * (an alert has been sent where appropriate).
 */
int tls_recv_server_hello(TLS_CONNECT *conn)
{
	int ret;
	int protocol;
	int cipher_suite;
	const uint8_t *server_random;
	const uint8_t *session_id;
	size_t session_id_len;
	const uint8_t *exts;
	size_t extslen;
	const uint8_t *ec_point_formats = NULL;
	size_t ec_point_formats_len = 0;

	tls_trace("recv ServerHello\n");

	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_protocol_version);
		return -1;
	}

	if ((ret = tls_record_get_handshake_server_hello(conn->record,
		&protocol, &server_random, &session_id, &session_id_len, &cipher_suite,
		&exts, &extslen)) < 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_decode_error);
		return -1;
	} else if (ret == 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}

	// version
	if (protocol != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_protocol_version);
		return -1;
	}

	// random
	memcpy(conn->server_random, server_random, 32);

	// session_id: never trust the peer-supplied length for the copy into the
	// fixed-size buffer, and skip memcpy entirely for an empty id (the source
	// pointer may be NULL then).
	if (session_id_len > sizeof(conn->session_id)) {
		error_print();
		tls_send_alert(conn, TLS_alert_decode_error);
		return -1;
	}
	if (session_id_len) {
		memcpy(conn->session_id, session_id, session_id_len);
	}

	// cipher_suite must be one we offered
	if (tls_type_is_in_list(cipher_suite, conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_handshake_failure);
		return -1;
	}
	conn->cipher_suite = cipher_suite;

	if (tls12_cipher_suite_get(conn->cipher_suite, &conn->cipher, &conn->digest) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		return -1;
	}

	// The transcript hash depends on the negotiated suite, so it can only be
	// initialised now.
	if (digest_init(&conn->dgst_ctx, conn->digest) != 1) {
		error_print();
		return -1;
	}

	while (extslen) {
		int ext_type;
		const uint8_t *ext_data;
		size_t ext_datalen;

		if (tls_ext_from_bytes(&ext_type, &ext_data, &ext_datalen, &exts, &extslen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_decode_error);
			return -1;
		}

		switch (ext_type) {
		case TLS_extension_ec_point_formats:
			// An empty list is malformed. Without this check a NULL
			// ext_data would also defeat the duplicate detection below.
			if (!ext_data) {
				error_print();
				tls_send_alert(conn, TLS_alert_decode_error);
				return -1;
			}
			if (ec_point_formats) {
				error_print();
				tls_send_alert(conn, TLS_alert_illegal_parameter);
				return -1;
			}
			ec_point_formats = ext_data;
			ec_point_formats_len = ext_datalen;
			break;
		default:
			// RFC 5246 7.4.1.4: an extension the client did not request
			// must abort the handshake with unsupported_extension.
			error_print();
			tls_send_alert(conn, TLS_alert_unsupported_extension);
			return -1;
		}
	}

	if (ec_point_formats) {
		if ((ret = tls_ec_point_formats_support_uncompressed(ec_point_formats, ec_point_formats_len)) < 0) {
			error_print();
			tls_send_alert(conn, TLS_alert_decode_error);
			return -1;
		} else if (ret == 0) {
			error_print();
			tls_send_alert(conn, TLS_alert_illegal_parameter);
			return -1;
		}
	}

	// NOTE(review): conn->plain_record is assumed to still hold the
	// ClientHello sent by this client; confirm against the send path.
	if (digest_update(&conn->dgst_ctx, conn->plain_record + 5, conn->plain_recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "ClientHello", &conn->dgst_ctx);

	if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "ServerHello", &conn->dgst_ctx);

	// Legacy transcript for the client CertificateVerify signature.
	if (conn->client_certs_len) {
		sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
	}

	return 1;
}
|
||
|
||
// TLS12 发送的是常规的证书链
|
||
// TLCP SM2 发送的是SM2的双证书链,但是在数据格式上没有区别
|
||
// TLCP SM9 发送的是服务器的ID和SM9公开参数(这个格式是不同的),但是存储上可能也是一样的
|
||
// 我不确定SM2和SM9的格式是否是相容的
|
||
int tls_send_server_certificate(TLS_CONNECT *conn)
|
||
{
|
||
int ret;
|
||
tls_trace("send ServerCertificate\n");
|
||
|
||
if (conn->recordlen == 0) {
|
||
if (tls_record_set_handshake_certificate(conn->record, &conn->recordlen,
|
||
conn->cert_chain, conn->cert_chain_len) != 1) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_internal_error);
|
||
return -1;
|
||
}
|
||
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
|
||
|
||
if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_handshake_digest_print(stderr, 0, 0, "Certificate", &conn->dgst_ctx);
|
||
|
||
}
|
||
|
||
if ((ret = tls_send_record(conn)) != 1) {
|
||
if (ret != TLS_ERROR_SEND_AGAIN) {
|
||
error_print();
|
||
}
|
||
return ret;
|
||
}
|
||
|
||
if (conn->client_certificate_verify) {
|
||
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
/*
 * Receive the server Certificate message (client side) and check the server's
 * leaf certificate against the negotiated handshake parameters.
 *
 * Checks, in order:
 * - the leaf certificate's key group fits conn->cipher_suite;
 * - the group is in the client's supported_groups (when configured);
 * - the leaf certificate matches the client's signature_algorithms, and the
 *   resulting scheme fits both the group and the cipher suite (when configured);
 * - every signature in the chain uses an algorithm from the client's
 *   signature_algorithms list (when configured);
 * - the leaf certificate matches the requested server name (when SNI is in use);
 * - the chain verifies against conn->ctx->cacerts (when CA certs are configured).
 *
 * On success the selected server signature scheme is stored as
 * conn->signature_algorithms[0] (count 1), and conn->verify_result holds the
 * chain verification result.
 *
 * Returns 1 on success, TLS_ERROR_RECV_AGAIN (or another non-1 value from
 * tls_recv_record()) if no complete record is available yet, -1 on error.
 */
int tls_recv_server_certificate(TLS_CONNECT *conn)
{
	int ret;
	int verify_result = 0;
	const uint8_t *server_cert;
	size_t server_cert_len;
	X509_KEY server_sign_key;
	int server_sig_alg = 0;
	int server_group;
	int cert_sig_alg = 0;

	tls_trace("recv server Certificate\n");

	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}

	if ((ret = tls_record_get_handshake_certificate(conn->record,
		conn->peer_cert_chain, &conn->peer_cert_chain_len)) < 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_decode_error);
		return -1;
	} else if (ret == 0) {
		// The server Certificate is mandatory here and a fatal alert has
		// been sent, so this is a hard error. Returning 0 could be mistaken
		// for "optional message absent" by the caller.
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	if (!conn->peer_cert_chain_len) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}

	if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "Certificate", &conn->dgst_ctx);

	// server_sign_key: public key of the leaf (index 0) certificate
	if (x509_certs_get_cert_by_index(conn->peer_cert_chain, conn->peer_cert_chain_len, 0,
		&server_cert, &server_cert_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}
	if (x509_cert_get_subject_public_key(server_cert, server_cert_len, &server_sign_key) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}

	if (tls12_public_key_get_group(&server_sign_key, &server_group) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}

	// check server certificate matches negotiated cipher_suite
	if (!tls12_cipher_suite_match_cert_group(conn->cipher_suite, server_group)) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}

	// Default server signature scheme implied by the cipher suite; may be
	// refined below from the client's signature_algorithms.
	switch (conn->cipher_suite) {
	case TLS_cipher_ecdhe_sm4_cbc_sm3:
	case TLS_cipher_ecdhe_sm4_gcm_sm3:
	case TLS_cipher_ecc_sm4_cbc_sm3:
	case TLS_cipher_ecc_sm4_gcm_sm3:
		server_sig_alg = TLS_sig_sm2sig_sm3;
		break;
	case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
		server_sig_alg = TLS_sig_ecdsa_secp256r1_sha256;
		break;
	default:
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}

	// check server certificate matches ClientHello.supported_groups
	if (conn->ctx->supported_groups_cnt) {
		if (!tls_type_is_in_list(server_group, conn->ctx->supported_groups,
			conn->ctx->supported_groups_cnt)) {
			error_print();
			tls_send_alert(conn, TLS_alert_bad_certificate);
			return -1;
		}
	}

	// check server certificate matches ClientHello.signature_algorithms
	if (conn->ctx->signature_algorithms_cnt) {
		if ((ret = tls_cert_match_signature_algorithms(server_cert, server_cert_len,
			conn->ctx->signature_algorithms,
			conn->ctx->signature_algorithms_cnt,
			&cert_sig_alg)) <= 0) {
			error_print();
			tls_send_alert(conn, TLS_alert_bad_certificate);
			return -1;
		}
		if (!tls12_signature_scheme_match_cert_group(cert_sig_alg, server_group)
			|| !tls12_signature_scheme_match_cipher_suite(cert_sig_alg, conn->cipher_suite)) {
			error_print();
			tls_send_alert(conn, TLS_alert_bad_certificate);
			return -1;
		}
		server_sig_alg = cert_sig_alg;
	}

	// Check the chain signatures against the algorithms the client offered.
	// The signature_algorithms_cert case and the fallback case both used
	// conn->ctx->signature_algorithms, so only that list is consulted.
	// TODO(review): if a separate signature_algorithms_cert list is added to
	// the context, use it here when configured.
	if (conn->ctx->signature_algorithms_cnt) {
		if ((ret = tls_cert_chain_match_signature_algorithms_cert(
			conn->peer_cert_chain, conn->peer_cert_chain_len,
			conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt)) <= 0) {
			error_print();
			tls_send_alert(conn, TLS_alert_bad_certificate);
			return -1;
		}
	}

	// check server certificate matches ClientHello.server_name
	if (conn->server_name) {
		if ((ret = tls_cert_match_server_name(server_cert, server_cert_len,
			conn->host_name, conn->host_name_len)) <= 0) {
			error_print();
			tls_send_alert(conn, TLS_alert_bad_certificate);
			return -1;
		}
	}

	conn->signature_algorithms[0] = server_sig_alg;
	conn->signature_algorithms_cnt = 1;

	// Legacy transcript for the client CertificateVerify signature.
	if (conn->client_certs_len) {
		sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
	}

	assert(conn->ctx->verify_depth > 0 && conn->ctx->verify_depth < 10);

	// verify server Certificate chain
	if (conn->ctx->cacertslen) {
		if (x509_certs_verify(conn->peer_cert_chain, conn->peer_cert_chain_len, X509_cert_chain_server,
			conn->ctx->cacerts, conn->ctx->cacertslen, conn->ctx->verify_depth, &verify_result) != 1) {
			error_print();
			conn->verify_result = verify_result;
			tls_send_alert(conn, TLS_alert_bad_certificate);
			return -1;
		}
	}
	conn->verify_result = verify_result;

	return 1;
}
|
||
|
||
|
||
|
||
/*
 * Build (once) and send the ECDHE ServerKeyExchange message (server side).
 *
 * A fresh EC key pair is generated on conn->key_exchange_group and stored in
 * conn->key_exchanges[0]. The message contains ServerECDHParams
 *     curve_type(1) || named_curve(2) || point_len(1) || point(65)
 * and a signature made with the server's certificate key over
 *     client_random || server_random || ServerECDHParams
 * using the scheme conn->sig_alg. SM2 keys sign with SM2_DEFAULT_ID.
 *
 * Returns 1 on success, the non-1 value of tls_send_record() if sending did
 * not complete, -1 on error.
 */
int tls_send_server_key_exchange(TLS_CONNECT *conn)
{
	int ret;
	uint8_t server_ecdh_params[69]; // 4-byte header + 65-byte uncompressed point
	uint8_t *p = server_ecdh_params + 4;
	size_t len = 0;
	X509_SIGN_CTX sign_ctx;
	const void *sign_args = NULL;
	size_t sign_argslen = 0;
	uint8_t sig[X509_SIGNATURE_MAX_SIZE];
	size_t siglen;

	tls_trace("send ServerKeyExchange\n");

	if (conn->recordlen == 0) {
		int curve_oid;
		X509_KEY *sign_key;

		// Only 256-bit curves fit: their uncompressed point is exactly 65
		// bytes, which is all server_ecdh_params can hold. The group is
		// checked here because x509_public_key_to_bytes() writes into the
		// buffer before the length can be checked.
		switch (conn->key_exchange_group) {
		case TLS_curve_sm2p256v1:
		case TLS_curve_secp256r1:
			break;
		default:
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		curve_oid = tls_named_curve_oid(conn->key_exchange_group);

		// generate server ecdh_key
		if (x509_key_generate(&conn->key_exchanges[0], OID_ec_public_key, &curve_oid, sizeof(curve_oid)) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}

		// build server_ecdh_params
		server_ecdh_params[0] = TLS_curve_type_named_curve;
		server_ecdh_params[1] = conn->key_exchange_group >> 8;
		server_ecdh_params[2] = (uint8_t)conn->key_exchange_group;
		server_ecdh_params[3] = 65;
		if (x509_public_key_to_bytes(&conn->key_exchanges[0], &p, &len) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		if (len != 65) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}

		// The signing key is indexed with cert_chain_idx - 1, so an index of
		// 0 (no chain selected) would read before the start of x509_keys.
		if (conn->cert_chain_idx < 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		sign_key = &conn->ctx->x509_keys[conn->cert_chain_idx - 1];

		// sign client_random || server_random || server_ecdh_params
		if (sign_key->algor == OID_ec_public_key && sign_key->algor_param == OID_sm2) {
			sign_args = SM2_DEFAULT_ID;
			sign_argslen = SM2_DEFAULT_ID_LENGTH;
		}
		if (x509_sign_init(&sign_ctx, sign_key, sign_args, sign_argslen) != 1
			|| x509_sign_update(&sign_ctx, conn->client_random, 32) != 1
			|| x509_sign_update(&sign_ctx, conn->server_random, 32) != 1
			|| x509_sign_update(&sign_ctx, server_ecdh_params, sizeof(server_ecdh_params)) != 1
			|| x509_sign_finish(&sign_ctx, sig, &siglen) != 1) {
			x509_sign_ctx_cleanup(&sign_ctx);
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		x509_sign_ctx_cleanup(&sign_ctx);

		if (tls_record_set_handshake_server_key_exchange(conn->record, &conn->recordlen,
			server_ecdh_params, sizeof(server_ecdh_params),
			conn->sig_alg, sig, siglen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

		if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
			error_print();
			return -1;
		}
		tls_handshake_digest_print(stderr, 0, 0, "ServerKeyExchange", &conn->dgst_ctx);
	}

	if ((ret = tls_send_record(conn)) != 1) {
		if (ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}

	if (conn->client_certificate_verify) {
		tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
	}

	return 1;
}
|
||
|
||
// match the ecdhe of cipher_suite
|
||
int tls_curve_match_cipher_suite(int named_curve, int cipher_suite)
|
||
{
|
||
switch (named_curve) {
|
||
case TLS_curve_sm2p256v1:
|
||
switch (cipher_suite) {
|
||
case TLS_cipher_ecdhe_sm4_cbc_sm3:
|
||
case TLS_cipher_ecdhe_sm4_gcm_sm3:
|
||
break;
|
||
default:
|
||
error_print();
|
||
return -1;
|
||
}
|
||
break;
|
||
case TLS_curve_secp256r1:
|
||
if (cipher_suite != TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
break;
|
||
default:
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_signature_scheme_match_cipher_suite(int sig_alg, int cipher_suite)
|
||
{
|
||
switch (sig_alg) {
|
||
case TLS_sig_sm2sig_sm3:
|
||
switch (cipher_suite) {
|
||
case TLS_cipher_ecdhe_sm4_cbc_sm3:
|
||
case TLS_cipher_ecdhe_sm4_gcm_sm3:
|
||
case TLS_cipher_ecc_sm4_cbc_sm3:
|
||
case TLS_cipher_ecc_sm4_gcm_sm3:
|
||
break;
|
||
default:
|
||
error_print();
|
||
return -1;
|
||
}
|
||
break;
|
||
case TLS_sig_ecdsa_secp256r1_sha256:
|
||
switch (cipher_suite) {
|
||
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
|
||
break;
|
||
default:
|
||
error_print();
|
||
return -1;
|
||
}
|
||
break;
|
||
default:
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
/*
 * Receive and verify the ECDHE ServerKeyExchange message (client side).
 *
 * The message carries ServerECDHParams (curve_type, named_curve, point)
 * followed by a signature, made with the key of the server's leaf
 * certificate, over
 *     client_random || server_random || ServerECDHParams
 *
 * Checks performed:
 * - curve_type is named_curve;
 * - the curve is in the client's supported_groups (when configured) and
 *   fits the cipher suite;
 * - the point is 65 bytes;
 * - the signature scheme fits both the cipher suite and the certificate key;
 * - the signature verifies.
 *
 * On success the server's ephemeral point is stored in conn->peer_key_exchange
 * and the curve in conn->key_exchange_group.
 *
 * Returns 1 on success, TLS_ERROR_RECV_AGAIN (or another non-1 value from
 * tls_recv_record()) if no complete record is available yet, -1 on error.
 */
int tls_recv_server_key_exchange(TLS_CONNECT *conn)
{
	int ret;
	uint8_t curve_type;
	uint16_t named_curve;
	const uint8_t *point_octets;
	size_t point_octets_len;
	const uint8_t *server_ecdh_params;
	size_t server_ecdh_params_len;
	uint16_t sig_alg;
	const uint8_t *sig;
	size_t siglen;

	// verification of the ServerKeyExchange signature
	X509_KEY server_sign_key;
	int server_cert_index = 0;
	const uint8_t *server_cert;
	size_t server_cert_len;
	X509_SIGN_CTX sign_ctx;
	const void *sign_args = NULL;
	size_t sign_argslen = 0;

	tls_trace("recv ServerKeyExchange\n");

	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

	// NOTE(review): this parser assumes the ECDHE ServerKeyExchange layout;
	// other cipher suites may need a different ServerKeyExchange format.
	if ((ret = tls_record_get_handshake_server_key_exchange(conn->record,
		&curve_type, &named_curve, &point_octets, &point_octets_len,
		&server_ecdh_params, &server_ecdh_params_len,
		&sig_alg, &sig, &siglen)) < 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_decode_error);
		return -1;
	} else if (ret == 0) {
		// A fatal alert has been sent, so this is a hard error, not "absent".
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}

	if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "ServerKeyExchange", &conn->dgst_ctx);

	if (curve_type != TLS_curve_type_named_curve) {
		error_print();
		tls_send_alert(conn, TLS_alert_illegal_parameter);
		return -1;
	}

	// The curve must be one the client offered in supported_groups.
	if (conn->ctx->supported_groups_cnt
		&& tls_type_is_in_list(named_curve, conn->ctx->supported_groups,
			conn->ctx->supported_groups_cnt) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_illegal_parameter);
		return -1;
	}
	if (tls_curve_match_cipher_suite(named_curve, conn->cipher_suite) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_illegal_parameter);
		return -1;
	}

	// Validate the peer-supplied length BEFORE copying into the fixed-size
	// conn->peer_key_exchange buffer. The supported curves use 65-byte
	// uncompressed points.
	if (point_octets_len != 65) {
		error_print();
		tls_send_alert(conn, TLS_alert_illegal_parameter);
		return -1;
	}
	conn->key_exchange_group = named_curve;
	memcpy(conn->peer_key_exchange, point_octets, point_octets_len);
	conn->peer_key_exchange_len = point_octets_len;
	// TODO: the point is not validated as lying on named_curve here.

	if (tls_signature_scheme_match_cipher_suite(sig_alg, conn->cipher_suite) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_illegal_parameter);
		return -1;
	}

	// Legacy transcript for the client CertificateVerify signature.
	if (conn->client_certs_len) {
		sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
	}

	if (x509_certs_get_cert_by_index(conn->peer_cert_chain, conn->peer_cert_chain_len,
		server_cert_index, &server_cert, &server_cert_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}
	if (x509_cert_get_subject_public_key(server_cert, server_cert_len, &server_sign_key) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}

	// The signature scheme fixes both the algorithm and the curve. The
	// certificate key only identifies the curve (a curve can serve several
	// signature algorithms), so the scheme must agree with it.
	switch (sig_alg) {
	case TLS_sig_sm2sig_sm3:
		if (server_sign_key.algor != OID_ec_public_key
			|| server_sign_key.algor_param != OID_sm2) {
			error_print();
			tls_send_alert(conn, TLS_alert_bad_certificate);
			return -1;
		}
		break;
	case TLS_sig_ecdsa_secp256r1_sha256:
		if (server_sign_key.algor != OID_ec_public_key
			|| server_sign_key.algor_param != OID_secp256r1) {
			error_print();
			tls_send_alert(conn, TLS_alert_bad_certificate);
			return -1;
		}
		break;
	default:
		error_print();
		tls_send_alert(conn, TLS_alert_illegal_parameter);
		return -1;
	}

	if (server_sign_key.algor == OID_ec_public_key && server_sign_key.algor_param == OID_sm2) {
		sign_args = SM2_DEFAULT_ID;
		sign_argslen = SM2_DEFAULT_ID_LENGTH;
	}

	// Signed data: client_random || server_random || ServerECDHParams
	// (4-byte header + 65-byte point = 69 bytes).
	if (x509_verify_init(&sign_ctx, &server_sign_key, sign_args, sign_argslen, sig, siglen) != 1
		|| x509_verify_update(&sign_ctx, conn->client_random, 32) != 1
		|| x509_verify_update(&sign_ctx, conn->server_random, 32) != 1
		|| x509_verify_update(&sign_ctx, server_ecdh_params, 69) != 1
		|| x509_verify_finish(&sign_ctx) != 1) {
		x509_sign_ctx_cleanup(&sign_ctx);
		error_print();
		tls_send_alert(conn, TLS_alert_decrypt_error);
		return -1;
	}
	x509_sign_ctx_cleanup(&sign_ctx);

	fprintf(stderr, ">>>>>> ServerKeyExchange verify success\n");

	return 1;
}
|
||
|
||
|
||
/*
 * Build (once) and send the CertificateRequest message (server side).
 *
 * Only ecdsa_sign certificates are requested. The certificate_authorities
 * list is derived from conn->ctx->cacerts, the CAs the server uses to
 * verify client certificates. The message is appended to the handshake
 * transcript like every other handshake message, so that the server's
 * Finished computation matches the client's. The client hashes
 * CertificateRequest when it receives it.
 *
 * Must only be called when conn->client_certificate_verify is set.
 * Returns 1 on success, the non-1 value of tls_send_record() if sending did
 * not complete, -1 on error.
 */
int tls_send_certificate_request(TLS_CONNECT *conn)
{
	int ret;
	const uint8_t cert_types[] = { TLS_cert_type_ecdsa_sign };
	// TODO: size this from the client-verification CA list, or write the
	// names directly into the record buffer.
	uint8_t ca_names[TLS_MAX_CA_NAMES_SIZE] = {0};
	size_t ca_names_len = 0;

	if (!conn->client_certificate_verify) {
		error_print();
		return -1;
	}

	if (conn->recordlen == 0) {
		tls_trace("send CertificateRequest\n");
		if (tls_authorities_from_certs(ca_names, &ca_names_len, sizeof(ca_names),
			conn->ctx->cacerts, conn->ctx->cacertslen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		if (tls_record_set_handshake_certificate_request(conn->record, &conn->recordlen,
			cert_types, sizeof(cert_types),
			ca_names, ca_names_len) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

		// CertificateRequest is part of the handshake transcript.
		if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
			error_print();
			return -1;
		}
		tls_handshake_digest_print(stderr, 0, 0, "CertificateRequest", &conn->dgst_ctx);
	}

	if ((ret = tls_send_record(conn)) != 1) {
		if (ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}

	tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);

	return 1;
}
|
||
|
||
/*
 * Receive an optional CertificateRequest message (client side).
 *
 * The next record is read and its handshake type inspected:
 * - If it is not a CertificateRequest, 0 is returned and the record is left
 *   in conn->record (presumably for the ServerHelloDone state to consume).
 * - Otherwise the message is parsed and added to the handshake transcript,
 *   and the record is consumed (conn->recordlen = 0).
 *
 * Returns 1 if a CertificateRequest was processed, 0 if the server did not
 * send one, TLS_ERROR_RECV_AGAIN (or another non-1 value from
 * tls_recv_record()) if no complete record is available yet, -1 on error.
 */
int tls_recv_certificate_request(TLS_CONNECT *conn)
{
	int ret;
	const uint8_t *cp;
	size_t len;
	int handshake_type;
	const uint8_t *cert_types;
	size_t cert_types_len;
	const uint8_t *ca_names;
	size_t ca_names_len;

	tls_trace("recv CertificateRequest*\n");

	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}

	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	// Read conn->record only after tls_recv_record(), never through a pointer
	// cached before the receive.
	if (tls_record_get_handshake(conn->record, &handshake_type, &cp, &len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}

	if (handshake_type != TLS_handshake_certificate_request) {
		tls_trace(" no CertificateRequest\n");
		return 0; // the peer did not send this optional message
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

	// The handshake type is already known to be correct, so a parse failure
	// means a malformed message.
	if (tls_record_get_handshake_certificate_request(conn->record,
		&cert_types, &cert_types_len, &ca_names, &ca_names_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_decode_error);
		return -1;
	}

	if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "CertificateRequest", &conn->dgst_ctx);

	// TODO: check that the request is consistent with the local client
	// certificates, e.g.:
	/*
	if (!conn->client_certs_len) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		return -1;
	}
	if (tls_cert_types_accepted(cert_types, cert_types_len, conn->client_certs, conn->client_certs_len) != 1
		|| tls_authorities_issued_certificate(ca_names, ca_names_len, conn->client_certs, conn->client_certs_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unsupported_certificate);
		return -1;
	}
	*/

	// Legacy transcript for the client CertificateVerify signature. sign_ctx
	// is only fed when the client has certificates, as in the other handlers.
	if (conn->client_certs_len) {
		sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
	}

	conn->recordlen = 0;
	return 1;
}
|
||
|
||
/*
 * Send ServerHelloDone (server side).
 *
 * The record is built, traced and added to the handshake transcript only while
 * conn->recordlen is 0, so re-entry after TLS_ERROR_SEND_AGAIN reuses the
 * already-built record.
 *
 * Returns 1 on success, TLS_ERROR_SEND_AGAIN when the send must be retried,
 * or -1 on error.
 */
int tls_send_server_hello_done(TLS_CONNECT *conn)
{
	int ret;

	tls_trace("send ServerHelloDone\n");

	if (conn->recordlen == 0) {
		if (tls_record_set_handshake_server_hello_done(conn->record, &conn->recordlen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

		if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
			error_print();
			return -1;
		}
		tls_handshake_digest_print(stderr, 0, 0, "ServerHelloDone", &conn->dgst_ctx);
	}

	if ((ret = tls_send_record(conn)) != 1) {
		if (ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}

	// On the server, conn->client_certs is only filled by
	// tls_recv_client_certificate(), which runs after this message, so
	// client_certs_len is always 0 here. Key the CertificateVerify transcript
	// update on whether client authentication was requested instead.
	if (conn->client_certificate_verify) {
		tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
	}
	return 1;
}
|
||
|
||
|
||
// 这是一个非常特殊的状态,其他的所有recv状态都是要读取的
|
||
// 但是这个状态在大多数情况下,之前已经读取完了,但是我们无法判断这个信息
|
||
int tls_recv_server_hello_done(TLS_CONNECT *conn)
|
||
{
|
||
int ret;
|
||
tls_trace("recv ServerHelloDone\n");
|
||
|
||
if ((ret = tls_recv_record(conn)) != 1) {
|
||
if (ret != TLS_ERROR_RECV_AGAIN) {
|
||
error_print();
|
||
}
|
||
return ret;
|
||
}
|
||
if (tls_record_protocol(conn->record) != conn->protocol) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_unexpected_message);
|
||
return -1;
|
||
}
|
||
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
|
||
|
||
if (tls_record_get_handshake_server_hello_done(conn->record) != 1) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_unexpected_message);
|
||
return -1;
|
||
}
|
||
|
||
if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_handshake_digest_print(stderr, 0, 0, "ServerHelloDone", &conn->dgst_ctx);
|
||
|
||
|
||
|
||
if (conn->client_certs_len)
|
||
sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
|
||
|
||
|
||
return 1;
|
||
}
|
||
|
||
/*
 * Send the client Certificate message (client authentication).
 *
 * The record is built, traced and appended to the Finished transcript
 * (conn->dgst_ctx) only while conn->recordlen is 0, so re-entry after
 * TLS_ERROR_SEND_AGAIN reuses it. RFC 5246 section 7.4.9 requires every
 * handshake message, including this one, in the Finished hash. After a
 * successful send the message is also fed into the CertificateVerify signing
 * context.
 *
 * Returns 1 on success, TLS_ERROR_SEND_AGAIN to retry, or -1 on error.
 */
int tls_send_client_certificate(TLS_CONNECT *conn)
{
	int ret;

	tls_trace("send ClientCertificate\n");

	if (conn->client_certs_len == 0) {
		error_print();
		return -1;
	}

	if (conn->recordlen == 0) {
		if (tls_record_set_handshake_certificate(conn->record, &conn->recordlen,
			conn->client_certs, conn->client_certs_len) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

		if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
			error_print();
			return -1;
		}
		tls_handshake_digest_print(stderr, 0, 0, "ClientCertificate", &conn->dgst_ctx);
	}

	if ((ret = tls_send_record(conn)) != 1) {
		if (ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}

	sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
	return 1;
}
|
||
|
||
// 只有在需要验证客户端证书的时候这个函数才执行,是否内部要判断一下
|
||
int tls_recv_client_certificate(TLS_CONNECT *conn)
|
||
{
|
||
int ret;
|
||
const int verify_depth = 5;
|
||
int verify_result;
|
||
|
||
tls_trace("recv ClientCertificate\n");
|
||
|
||
if (conn->ctx->cacertslen == 0) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
if ((ret = tls_recv_record(conn)) != 1) {
|
||
if (ret != TLS_ERROR_RECV_AGAIN) {
|
||
error_print();
|
||
}
|
||
return ret;
|
||
}
|
||
if (tls_record_protocol(conn->record) != conn->protocol) { // protocol检查应该在trace之后
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_unexpected_message);
|
||
return -1;
|
||
}
|
||
tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
|
||
if (tls_record_get_handshake_certificate(conn->record, conn->client_certs, &conn->client_certs_len) != 1) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_unexpected_message);
|
||
return -1;
|
||
}
|
||
if (x509_certs_verify(conn->client_certs, conn->client_certs_len, X509_cert_chain_client,
|
||
conn->ctx->cacerts, conn->ctx->cacertslen, verify_depth, &verify_result) != 1) {
|
||
error_print();
|
||
tls_send_alert(conn, TLS_alert_bad_certificate);
|
||
return -1;
|
||
}
|
||
//sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5);
|
||
tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
|
||
|
||
return 1;
|
||
}
|
||
|
||
/*
 * Derive the TLS 1.2 session keys from the completed ECDHE exchange.
 *
 * Computes the pre-master secret from conn->key_exchanges[0] and the peer's
 * public point. It then derives conn->master_secret (48 bytes) and
 * conn->key_block (96 bytes) with tls12_prf(), initializes both HMAC
 * contexts, installs the block-cipher keys for this side's direction, and
 * resets both sequence numbers.
 *
 * key_block layout (96 bytes):
 *   [ 0..31] client_write_MAC_key
 *   [32..63] server_write_MAC_key
 *   [64..79] client_write_key
 *   [80..95] server_write_key
 *
 * No write IVs are derived. TLS 1.2 CBC records carry an explicit random IV
 * in every record (RFC 5246 section 6.2.3.2), so the key block does not need
 * one.
 * NOTE(review): AEAD suites (e.g. TLS_cipher_ecdhe_sm4_gcm_sm3) require an
 * implicit client/server_write_IV from the key block; this code does not
 * derive one — confirm before enabling GCM.
 *
 * The on-stack pre-master secret is wiped on every exit path.
 *
 * Returns 1 on success, -1 on error.
 */
int tls_generate_keys(TLS_CONNECT *conn)
{
	int ret = -1;
	uint8_t pre_master_secret[32];
	size_t pre_master_secret_len = 0;

	// The pre-master secret comes entirely from ECDHE
	if (x509_key_exchange(&conn->key_exchanges[0],
		conn->peer_key_exchange, conn->peer_key_exchange_len,
		pre_master_secret, &pre_master_secret_len) != 1) {
		error_print();
		goto end;
	}
	if (pre_master_secret_len != sizeof(pre_master_secret)) {
		error_print();
		goto end;
	}
	format_bytes(stderr, 0, 0, "pre_master_secret", pre_master_secret, pre_master_secret_len);

	// The master secret depends only on the randoms, not on the transcript hash
	if (tls12_prf(conn->digest,
		pre_master_secret, sizeof(pre_master_secret),
		"master secret",
		conn->client_random, 32,
		conn->server_random, 32,
		48, conn->master_secret) != 1) {
		error_print();
		goto end;
	}
	format_bytes(stderr, 0, 0, "master_secret", conn->master_secret, 48);

	// Key expansion seeds with server_random first, then client_random
	if (tls12_prf(conn->digest, conn->master_secret, 48, "key expansion",
		conn->server_random, 32,
		conn->client_random, 32,
		96, conn->key_block) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		goto end;
	}
	format_bytes(stderr, 0, 0, "key_blocks", conn->key_block, 96);

	if (hmac_init(&conn->client_write_mac_ctx, conn->digest, conn->key_block, 32) != 1) {
		error_print();
		goto end;
	}
	if (hmac_init(&conn->server_write_mac_ctx, conn->digest, conn->key_block + 32, 32) != 1) {
		error_print();
		goto end;
	}

	format_bytes(stderr, 0, 0, "client_write_mac_key", conn->key_block, 32);
	format_bytes(stderr, 0, 0, "server_write_mac_key", conn->key_block + 32, 32);
	format_bytes(stderr, 0, 0, "client_write_key", conn->key_block + 64, 16);
	format_bytes(stderr, 0, 0, "server_write_key", conn->key_block + 80, 16);

	// Each side encrypts with its own write key and decrypts with the peer's
	if (conn->is_client) {
		if (block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, conn->key_block + 64) != 1
			|| block_cipher_set_decrypt_key(&conn->server_write_key, conn->cipher, conn->key_block + 80) != 1) {
			error_print();
			goto end;
		}
	} else {
		if (block_cipher_set_decrypt_key(&conn->client_write_key, conn->cipher, conn->key_block + 64) != 1
			|| block_cipher_set_encrypt_key(&conn->server_write_key, conn->cipher, conn->key_block + 80) != 1) {
			error_print();
			goto end;
		}
	}

	tls_seq_num_reset(conn->client_seq_num);
	tls_seq_num_reset(conn->server_seq_num);
	ret = 1;

end:
	gmssl_secure_clear(pre_master_secret, sizeof(pre_master_secret));
	return ret;
}
|
||
|
||
|
||
/*
 * Send ClientKeyExchange with a freshly generated ephemeral EC public key.
 *
 * The client's ECDHE key must be on the same group as the server's, so the
 * curve is taken from conn->key_exchange_group (presumably recorded while
 * processing ServerKeyExchange — TODO confirm). The key pair is kept in
 * conn->key_exchanges[0] for tls_generate_keys().
 *
 * The key, record and transcript update are produced only while
 * conn->recordlen is 0, so re-entry after TLS_ERROR_SEND_AGAIN reuses them.
 *
 * Returns 1 on success, TLS_ERROR_SEND_AGAIN to retry, or -1 on error.
 */
int tls_send_client_key_exchange(TLS_CONNECT *conn)
{
	int rv;

	if (!conn->recordlen) {
		int group_oid = tls_named_curve_oid(conn->key_exchange_group);
		uint8_t pub[65]; // 65 bytes: uncompressed point of a 256-bit curve
		uint8_t *cursor = pub;
		size_t publen = 0;

		if (x509_key_generate(&conn->key_exchanges[0], OID_ec_public_key,
			&group_oid, sizeof(group_oid)) != 1) {
			error_print();
			return -1;
		}
		if (x509_public_key_to_bytes(&conn->key_exchanges[0], &cursor, &publen) != 1) {
			error_print();
			return -1;
		}
		if (publen != sizeof(pub)) {
			error_print();
			return -1;
		}

		tls_trace("send ClientKeyExchange\n");
		if (tls_record_set_handshake_client_key_exchange(conn->record, &conn->recordlen,
			pub, publen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

		if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
			error_print();
			return -1;
		}
		tls_handshake_digest_print(stderr, 0, 0, "ClientKeyExchange", &conn->dgst_ctx);
	}

	rv = tls_send_record(conn);
	if (rv != 1) {
		if (rv != TLS_ERROR_SEND_AGAIN)
			error_print();
		return rv;
	}

	// Only a client with a certificate will sign the transcript
	if (conn->client_certs_len != 0) {
		sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5);
	}
	return 1;
}
|
||
|
||
/*
 * Receive ClientKeyExchange (server side) and store the client's ECDHE point.
 *
 * The point is copied into conn->peer_key_exchange for tls_generate_keys().
 * The message is added to the Finished transcript and, when CA certificates
 * are configured (client authentication possible), to the CertificateVerify
 * transcript.
 *
 * Returns 1 on success, the tls_recv_record() code when more data is needed,
 * or -1 on error.
 */
int tls_recv_client_key_exchange(TLS_CONNECT *conn)
{
	int ret;
	const uint8_t *point_octets;
	size_t point_octets_len;

	tls_trace("recv ClientKeyExchange\n");
	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

	if (tls_record_get_handshake_client_key_exchange(conn->record,
		&point_octets, &point_octets_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	// Only 65-byte (uncompressed 256-bit curve) points are accepted. The
	// peer must be told why the handshake is aborted instead of seeing a
	// silent failure.
	if (point_octets_len != 65) {
		error_print();
		tls_send_alert(conn, TLS_alert_illegal_parameter);
		return -1;
	}

	memcpy(conn->peer_key_exchange, point_octets, point_octets_len);
	conn->peer_key_exchange_len = point_octets_len;

	if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "ClientKeyExchange", &conn->dgst_ctx);

	if (conn->ctx->cacertslen) {
		tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5);
	}
	return 1;
}
|
||
|
||
|
||
|
||
|
||
/*
 * Send CertificateVerify: an SM2 signature over the handshake messages
 * accumulated in conn->sign_ctx.
 *
 * The signature, record and Finished-transcript update are produced only while
 * conn->recordlen is 0, so re-entry after TLS_ERROR_SEND_AGAIN reuses them.
 * CertificateVerify is itself part of the Finished hash (RFC 5246 section
 * 7.4.9), so it is appended to conn->dgst_ctx.
 *
 * Returns 1 on success, TLS_ERROR_SEND_AGAIN to retry, or -1 on error
 * (including when client authentication was not requested).
 */
int tls_send_certificate_verify(TLS_CONNECT *conn)
{
	int ret;
	uint8_t sig[SM2_MAX_SIGNATURE_SIZE];
	size_t siglen;

	tls_trace("send CertificateVerify\n");

	if (!conn->client_certificate_verify) {
		error_print();
		return -1;
	}

	if (conn->recordlen == 0) {
		if (sm2_sign_finish(&conn->sign_ctx, sig, &siglen) != 1) {
			error_print();
			return -1;
		}
		if (tls_record_set_handshake_certificate_verify(conn->record, &conn->recordlen, sig, siglen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

		if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
			error_print();
			return -1;
		}
		tls_handshake_digest_print(stderr, 0, 0, "CertificateVerify", &conn->dgst_ctx);
	}

	if ((ret = tls_send_record(conn)) != 1) {
		if (ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}
	return 1;
}
|
||
|
||
/*
 * Receive CertificateVerify (server side) and check the client's signature.
 *
 * The verification key is the subject public key of the first certificate in
 * conn->client_certs, and it must be an SM2 key. After a successful check
 * the message is appended to the Finished transcript (RFC 5246 section 7.4.9).
 *
 * Returns 1 on success, the tls_recv_record() code when more data is needed,
 * or -1 on error (including when client authentication was not requested).
 */
int tls_recv_certificate_verify(TLS_CONNECT *conn)
{
	int ret;
	X509_KEY client_sign_key;
	const uint8_t *sig;
	size_t siglen;
	const uint8_t *client_cert;
	size_t client_cert_len;

	if (!conn->client_certificate_verify) {
		error_print();
		return -1;
	}

	tls_trace("recv CertificateVerify\n");
	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	if (tls_record_protocol(conn->record) != conn->protocol) {
		tls_send_alert(conn, TLS_alert_unexpected_message);
		error_print();
		return -1;
	}
	tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);

	// get signature from certificate_verify
	if (tls_record_get_handshake_certificate_verify(conn->record, &sig, &siglen) != 1) {
		tls_send_alert(conn, TLS_alert_unexpected_message);
		error_print();
		return -1;
	}

	// get sign_key from client certificate
	if (x509_certs_get_cert_by_index(conn->client_certs, conn->client_certs_len, 0,
		&client_cert, &client_cert_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}
	if (x509_cert_get_subject_public_key(client_cert, client_cert_len, &client_sign_key) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}
	// Only SM2 signing keys are supported for now
	if (client_sign_key.algor != OID_ec_public_key
		|| client_sign_key.algor_param != OID_sm2) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		return -1;
	}

	if (tls_client_verify_finish(&conn->client_verify_ctx, sig, siglen, &client_sign_key.u.sm2_key) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_decrypt_error);
		return -1;
	}

	if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "CertificateVerify", &conn->dgst_ctx);

	return 1;
}
|
||
|
||
/*
 * Send a ChangeCipherSpec record.
 *
 * The record is built only while conn->recordlen is 0, so a call that
 * re-enters after TLS_ERROR_SEND_AGAIN reuses the record already in
 * conn->record.
 *
 * Returns 1 on success, TLS_ERROR_SEND_AGAIN to retry, or -1 on error.
 */
int tls_send_change_cipher_spec(TLS_CONNECT *conn)
{
	int rv;

	if (!conn->recordlen) {
		tls_trace("send [ChangeCipherSpec]\n");
		if (tls_record_set_change_cipher_spec(conn->record, &conn->recordlen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
	}

	rv = tls_send_record(conn);
	if (rv == 1)
		return 1;
	if (rv != TLS_ERROR_SEND_AGAIN)
		error_print();
	return rv;
}
|
||
|
||
/*
 * Receive and validate a ChangeCipherSpec record.
 *
 * Returns 1 on success, the tls_recv_record() code when more data is needed,
 * or -1 (after sending unexpected_message) on a protocol-version mismatch or
 * a malformed record.
 */
int tls_recv_change_cipher_spec(TLS_CONNECT *conn)
{
	int rv;

	tls_trace("recv [ChangeCipherSpec]\n");

	rv = tls_recv_record(conn);
	if (rv != 1) {
		if (rv != TLS_ERROR_RECV_AGAIN)
			error_print();
		return rv;
	}

	if (tls_record_protocol(conn->record) == conn->protocol) {
		tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0);
		if (tls_record_get_change_cipher_spec(conn->record) == 1)
			return 1;
	}

	error_print();
	tls_send_alert(conn, TLS_alert_unexpected_message);
	return -1;
}
|
||
|
||
/*
 * Build, encrypt and send the client Finished message.
 *
 * verify_data = PRF(master_secret, "client finished", Hash(transcript)),
 * 12 bytes. The hash is computed on a copy of conn->dgst_ctx so the live
 * transcript can still absorb this Finished message for the server's Finished.
 * Everything is produced once (while conn->recordlen is 0), and the client
 * sequence number is advanced after encryption.
 *
 * Returns 1 on success, TLS_ERROR_SEND_AGAIN to retry, or -1 on error.
 */
int tls_send_client_finished(TLS_CONNECT *conn)
{
	int ret;

	if (conn->recordlen == 0) {
		uint8_t local_verify_data[12];
		DIGEST_CTX tmp_ctx;
		uint8_t dgst[32];
		size_t dgstlen;

		tls_trace("send client {Finished}\n");

		tmp_ctx = conn->dgst_ctx;
		if (digest_finish(&tmp_ctx, dgst, &dgstlen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}

		if (tls12_prf(conn->digest,
			conn->master_secret, 48,
			"client finished", dgst, dgstlen, NULL, 0,
			sizeof(local_verify_data), local_verify_data) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}

		tls_record_set_protocol(conn->plain_record, conn->protocol);
		if (tls_record_set_handshake_finished(conn->plain_record, &conn->plain_recordlen,
			local_verify_data, sizeof(local_verify_data)) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->plain_record, conn->plain_recordlen, 0, 0);

		if (digest_update(&conn->dgst_ctx, conn->plain_record + 5, conn->plain_recordlen - 5) != 1) {
			error_print();
			return -1;
		}
		tls_handshake_digest_print(stderr, 0, 0, "Finished", &conn->dgst_ctx);

		if (tls12_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_key,
			conn->client_seq_num, conn->plain_record, conn->plain_recordlen,
			conn->record, &conn->recordlen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls_seq_num_incr(conn->client_seq_num);

		format_bytes(stderr, 0, 0, "encrypted finsished ..... ", conn->record, conn->recordlen);
	}

	if ((ret = tls_send_record(conn)) != 1) {
		if (ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}
	return 1;
}
|
||
|
||
/*
 * Receive, decrypt and verify the client Finished message (server side).
 *
 * The expected verify_data is computed from a copy of conn->dgst_ctx, so a
 * re-entry after TLS_ERROR_RECV_AGAIN recomputes the same value. The decrypted
 * Finished is then appended to the live transcript for the server's Finished.
 *
 * Returns 1 on success, the tls_recv_record() code when more data is needed,
 * or -1 on error.
 */
int tls_recv_client_finished(TLS_CONNECT *conn)
{
	int ret;
	const uint8_t *verify_data;
	size_t verify_data_len;
	uint8_t local_verify_data[12];
	DIGEST_CTX tmp_ctx;
	uint8_t dgst[32];
	size_t dgstlen;

	tmp_ctx = conn->dgst_ctx;
	if (digest_finish(&tmp_ctx, dgst, &dgstlen) != 1) {
		error_print();
		return -1;
	}
	if (tls12_prf(conn->digest, conn->master_secret, 48, "client finished", dgst, dgstlen, NULL, 0,
		sizeof(local_verify_data), local_verify_data) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		return -1;
	}
	format_bytes(stderr, 0, 0, "verify_data", local_verify_data, 12);

	// recv ClientFinished
	tls_trace("recv client {Finished}\n");
	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	format_bytes(stderr, 0, 0, "Finished", conn->record, conn->recordlen);

	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}

	// decrypt ClientFinished
	tls_trace(">>>>>>>decrypt Finished\n");
	format_bytes(stderr, 0, 0, "client_seq_num", conn->client_seq_num, 8);
	if (tls12_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_key,
		conn->client_seq_num, conn->record, conn->recordlen,
		conn->plain_record, &conn->plain_recordlen) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_record_mac);
		return -1;
	}
	tls_seq_num_incr(conn->client_seq_num);
	tls12_record_trace(stderr, conn->plain_record, conn->plain_recordlen, 0, 0);

	// The record already decrypted and authenticated, so a malformed body is
	// an unexpected message, not a MAC failure (same as tls_recv_server_finished)
	if (tls_record_get_handshake_finished(conn->plain_record, &verify_data, &verify_data_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	if (verify_data_len != sizeof(local_verify_data)) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}

	if (digest_update(&conn->dgst_ctx, conn->plain_record + 5, conn->plain_recordlen - 5) != 1) {
		error_print();
		return -1;
	}
	tls_handshake_digest_print(stderr, 0, 0, "client Finished", &conn->dgst_ctx);

	// verify ClientFinished
	if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) {
		error_puts("client_finished.verify_data verification failure");
		tls_send_alert(conn, TLS_alert_decrypt_error);
		return -1;
	}
	return 1;
}
|
||
|
||
/*
 * Build, encrypt and send the server Finished message.
 *
 * verify_data = PRF(master_secret, "server finished", Hash(transcript)),
 * 12 bytes, computed once while conn->recordlen is 0. The hash is taken on a
 * copy of conn->dgst_ctx, matching the client side, so the live transcript is
 * not consumed. The server sequence number is advanced after encryption.
 *
 * Returns 1 on success, TLS_ERROR_SEND_AGAIN to retry, or -1 on error.
 */
int tls_send_server_finished(TLS_CONNECT *conn)
{
	int ret;

	tls_record_set_protocol(conn->plain_record, conn->protocol);

	if (conn->recordlen == 0) {
		uint8_t local_verify_data[12];
		DIGEST_CTX tmp_ctx;
		uint8_t dgst[32];
		size_t dgstlen;

		tls_trace("send server Finished\n");

		tmp_ctx = conn->dgst_ctx;
		if (digest_finish(&tmp_ctx, dgst, &dgstlen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		if (tls12_prf(conn->digest, conn->master_secret, 48, "server finished", dgst, dgstlen, NULL, 0,
			sizeof(local_verify_data), local_verify_data) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		format_bytes(stderr, 0, 0, "server verify_data", local_verify_data, 12);

		if (tls_record_set_handshake_finished(conn->plain_record, &conn->plain_recordlen,
			local_verify_data, sizeof(local_verify_data)) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls12_record_trace(stderr, conn->plain_record, conn->plain_recordlen, 0, 0);

		if (tls12_record_encrypt(&conn->server_write_mac_ctx, &conn->server_write_key,
			conn->server_seq_num, conn->plain_record, conn->plain_recordlen,
			conn->record, &conn->recordlen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			return -1;
		}
		tls_seq_num_incr(conn->server_seq_num);
	}

	if ((ret = tls_send_record(conn)) != 1) {
		if (ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}
	return 1;
}
|
||
|
||
/*
 * Receive, decrypt and verify the server Finished message (client side).
 *
 * The expected verify_data is computed before the record is read. The hash is
 * finished on a COPY of conn->dgst_ctx: if tls_recv_record() returns
 * TLS_ERROR_RECV_AGAIN, this function is re-entered and must not finish an
 * already-finished context, which would yield a wrong verify_data.
 *
 * Returns 1 on success, the tls_recv_record() code when more data is needed,
 * or -1 on error.
 */
int tls_recv_server_finished(TLS_CONNECT *conn)
{
	int ret;
	DIGEST_CTX tmp_ctx;
	uint8_t dgst[32];
	size_t dgstlen;
	const uint8_t *verify_data;
	size_t verify_data_len;
	uint8_t local_verify_data[12];

	tmp_ctx = conn->dgst_ctx;
	if (digest_finish(&tmp_ctx, dgst, &dgstlen) != 1) {
		error_print();
		return -1;
	}
	if (tls12_prf(conn->digest, conn->master_secret, 48, "server finished",
		dgst, dgstlen, NULL, 0,
		sizeof(local_verify_data), local_verify_data) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		return -1;
	}
	format_bytes(stderr, 0, 0, ">>> verify_data", local_verify_data, 12);

	// Finished
	tls_trace("recv server Finished\n");
	if ((ret = tls_recv_record(conn)) != 1) {
		if (ret != TLS_ERROR_RECV_AGAIN) {
			error_print();
		}
		return ret;
	}
	if (tls_record_protocol(conn->record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}

	tls_trace("decrypt Finished\n");
	format_bytes(stderr, 0, 0, "server_seq_num", conn->server_seq_num, 8);
	if (tls12_record_decrypt(&conn->server_write_mac_ctx, &conn->server_write_key,
		conn->server_seq_num, conn->record, conn->recordlen,
		conn->plain_record, &conn->plain_recordlen) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_record_mac);
		return -1;
	}
	tls12_record_trace(stderr, conn->plain_record, conn->plain_recordlen, 0, 0);
	tls_seq_num_incr(conn->server_seq_num);

	if (tls_record_get_handshake_finished(conn->plain_record, &verify_data, &verify_data_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	if (verify_data_len != sizeof(local_verify_data)) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		return -1;
	}
	if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) {
		error_puts("server_finished.verify_data verification failure");
		tls_send_alert(conn, TLS_alert_decrypt_error);
		return -1;
	}

	if (!conn->ctx->quiet)
		fprintf(stderr, "Connection established!\n");

	return 1;
}
|
||
|
||
|
||
|
||
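/*
 * TLS 1.2 full handshake (RFC 5246, Figure 1); * marks optional messages.
 *
 *   Client                                               Server
 *
 *   ClientHello                  -------->
 *                                                   ServerHello
 *                                                  Certificate*
 *                                            ServerKeyExchange*
 *                                           CertificateRequest*
 *                                <--------      ServerHelloDone
 *   Certificate*
 *   ClientKeyExchange
 *   CertificateVerify*
 *   [ChangeCipherSpec]
 *   Finished                     -------->
 *                                            [ChangeCipherSpec]
 *                                <--------             Finished
 *   Application Data             <------->     Application Data
 *
 * In this implementation Certificate and ServerKeyExchange are always sent;
 * only CertificateRequest (and the client Certificate/CertificateVerify
 * that depend on it) is optional.
 */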
|
||
|
||
/*
 * Execute one step of the client handshake state machine.
 *
 * Runs the handler for conn->state. On success it advances conn->state to the
 * next state. The record buffer is cleaned only when the handler returned 1.
 * A return of 0 means the step was skipped (the optional CertificateRequest
 * was absent), and the buffered record is kept for the next state.
 *
 * Returns 1 when the step completed or was skipped, TLS_ERROR_RECV_AGAIN /
 * TLS_ERROR_SEND_AGAIN when I/O would block (state unchanged), or a negative
 * error code.
 */
int tls12_do_client_handshake(TLS_CONNECT *conn)
{
	int ret;
	int next_state;

	switch (conn->state) {
	case TLS_state_client_hello:
		ret = tls_send_client_hello(conn);
		next_state = TLS_state_server_hello;
		break;

	case TLS_state_server_hello:
		ret = tls_recv_server_hello(conn);
		next_state = TLS_state_server_certificate;
		break;

	case TLS_state_server_certificate:
		ret = tls_recv_server_certificate(conn);
		next_state = TLS_state_server_key_exchange;
		break;

	case TLS_state_server_key_exchange:
		ret = tls_recv_server_key_exchange(conn);
		next_state = TLS_state_certificate_request;
		break;

	// the only optional state
	case TLS_state_certificate_request:
		fprintf(stderr, "TLS_state_certificate_request\n");
		ret = tls_recv_certificate_request(conn);
		fprintf(stderr, "  ret = %d\n", ret);
		if (ret == 1) conn->client_certificate_verify = 1;
		next_state = TLS_state_server_hello_done;
		break;

	case TLS_state_server_hello_done:
		fprintf(stderr, "TLS_state_server_hello_done\n");
		ret = tls_recv_server_hello_done(conn);
		if (conn->client_certificate_verify)
			next_state = TLS_state_client_certificate;
		else next_state = TLS_state_client_key_exchange;
		break;

	case TLS_state_client_certificate:
		ret = tls_send_client_certificate(conn);
		next_state = TLS_state_client_key_exchange;
		break;

	case TLS_state_client_key_exchange:
		ret = tls_send_client_key_exchange(conn);
		next_state = TLS_state_generate_keys;
		break;

	case TLS_state_generate_keys:
		ret = tls_generate_keys(conn);
		if (conn->client_certificate_verify)
			next_state = TLS_state_certificate_verify;
		else next_state = TLS_state_client_change_cipher_spec;
		break;

	case TLS_state_certificate_verify:
		ret = tls_send_certificate_verify(conn);
		next_state = TLS_state_client_change_cipher_spec;
		// Without this break the CertificateVerify result would be discarded
		// and ChangeCipherSpec sent in the same step.
		break;

	case TLS_state_client_change_cipher_spec:
		ret = tls_send_change_cipher_spec(conn);
		next_state = TLS_state_client_finished;
		break;

	case TLS_state_client_finished:
		ret = tls_send_client_finished(conn);
		next_state = TLS_state_server_change_cipher_spec;
		break;

	case TLS_state_server_change_cipher_spec:
		ret = tls_recv_change_cipher_spec(conn);
		next_state = TLS_state_server_finished;
		break;

	case TLS_state_server_finished:
		ret = tls_recv_server_finished(conn);
		next_state = TLS_state_handshake_over;
		break;

	default:
		error_print();
		return -1;
	}

	if (ret < 0) {
		if (ret != TLS_ERROR_RECV_AGAIN && ret != TLS_ERROR_SEND_AGAIN) {
			error_print();
		}
		return ret;
	}

	conn->state = next_state;

	// ret == 0 means this step is bypassed; keep the record for the next state
	if (ret == 1) {
		tls_clean_record(conn);
	}
	return 1;
}
|
||
|
||
/*
 * Execute one step of the server handshake state machine.
 *
 * Runs the handler for conn->state. Only when it returns exactly 1 does the
 * state advance and the record buffer get cleaned. Whether CertificateRequest
 * / client Certificate / CertificateVerify are visited depends on
 * conn->client_certificate_verify.
 *
 * Returns 1 after a completed step, the handler's TLS_ERROR_RECV_AGAIN /
 * TLS_ERROR_SEND_AGAIN when I/O would block, or any other non-1 handler
 * result (reported with error_print()).
 */
int tls12_do_server_handshake(TLS_CONNECT *conn)
{
	int rv;
	int following;

	switch (conn->state) {
	case TLS_state_client_hello:
		rv = tls_recv_client_hello(conn);
		following = TLS_state_server_hello;
		break;
	case TLS_state_server_hello:
		rv = tls_send_server_hello(conn);
		following = TLS_state_server_certificate;
		break;
	case TLS_state_server_certificate:
		rv = tls_send_server_certificate(conn);
		following = TLS_state_server_key_exchange;
		break;
	case TLS_state_server_key_exchange:
		rv = tls_send_server_key_exchange(conn);
		following = conn->client_certificate_verify
			? TLS_state_certificate_request
			: TLS_state_server_hello_done;
		break;
	case TLS_state_certificate_request:
		rv = tls_send_certificate_request(conn);
		following = TLS_state_server_hello_done;
		break;
	case TLS_state_server_hello_done:
		rv = tls_send_server_hello_done(conn);
		following = conn->client_certificate_verify
			? TLS_state_client_certificate
			: TLS_state_client_key_exchange;
		break;
	case TLS_state_client_certificate:
		rv = tls_recv_client_certificate(conn);
		following = TLS_state_client_key_exchange;
		break;
	case TLS_state_client_key_exchange:
		rv = tls_recv_client_key_exchange(conn);
		following = conn->client_certificate_verify
			? TLS_state_certificate_verify
			: TLS_state_generate_keys;
		break;
	case TLS_state_certificate_verify:
		rv = tls_recv_certificate_verify(conn);
		following = TLS_state_generate_keys;
		break;
	case TLS_state_generate_keys:
		rv = tls_generate_keys(conn);
		following = TLS_state_client_change_cipher_spec;
		break;
	case TLS_state_client_change_cipher_spec:
		rv = tls_recv_change_cipher_spec(conn);
		following = TLS_state_client_finished;
		break;
	case TLS_state_client_finished:
		rv = tls_recv_client_finished(conn);
		following = TLS_state_server_change_cipher_spec;
		break;
	case TLS_state_server_change_cipher_spec:
		rv = tls_send_change_cipher_spec(conn);
		following = TLS_state_server_finished;
		break;
	case TLS_state_server_finished:
		rv = tls_send_server_finished(conn);
		following = TLS_state_handshake_over;
		break;
	default:
		error_print();
		return -1;
	}

	if (rv != 1) {
		if (rv != TLS_ERROR_RECV_AGAIN && rv != TLS_ERROR_SEND_AGAIN)
			error_print();
		return rv;
	}

	conn->state = following;
	tls_clean_record(conn);
	return 1;
}
|
||
|
||
|
||
// 这个函数显然是不对的,因为这个函数就是一个重入的函数,重入函数不应该自己设置状态啊
|
||
int tls12_client_handshake(TLS_CONNECT *conn)
|
||
{
|
||
int ret;
|
||
|
||
while (conn->state != TLS_state_handshake_over) {
|
||
|
||
ret = tls12_do_client_handshake(conn);
|
||
|
||
if (ret != 1) {
|
||
if (ret != TLS_ERROR_RECV_AGAIN && ret != TLS_ERROR_SEND_AGAIN) {
|
||
error_print();
|
||
}
|
||
return ret;
|
||
}
|
||
}
|
||
|
||
// TODO: cleanup conn?
|
||
|
||
return 1;
|
||
}
|
||
|
||
/*
 * Drive the server handshake until conn->state reaches
 * TLS_state_handshake_over, resuming from whatever state conn is in.
 *
 * Returns 1 when the handshake is over, TLS_ERROR_RECV_AGAIN /
 * TLS_ERROR_SEND_AGAIN when a step would block, or another error code.
 * TODO: clean up conn once the handshake is over?
 */
int tls12_server_handshake(TLS_CONNECT *conn)
{
	for (;;) {
		int rv;

		if (conn->state == TLS_state_handshake_over)
			return 1;

		rv = tls12_do_server_handshake(conn);
		if (rv == 1)
			continue;

		if (rv != TLS_ERROR_RECV_AGAIN && rv != TLS_ERROR_SEND_AGAIN)
			error_print();
		return rv;
	}
}
|
||
|
||
/*
 * Run a complete client handshake over conn->sock.
 *
 * Resets the state machine to TLS_state_client_hello and starts an SM3
 * transcript hash. It then repeatedly calls tls12_client_handshake(), using
 * select() to wait for writability after TLS_ERROR_SEND_AGAIN and for
 * readability after TLS_ERROR_RECV_AGAIN.
 *
 * Returns 1 once the handshake completes, -1 on any other error.
 */
int tls12_do_connect(TLS_CONNECT *conn)
{
	fd_set readable;
	fd_set writable;

	conn->state = TLS_state_client_hello;
	digest_init(&conn->dgst_ctx, DIGEST_sm3());

	for (;;) {
		int rv = tls12_client_handshake(conn);
		fd_set *wait_set;

		if (rv == 1)
			return 1;
		if (rv != TLS_ERROR_SEND_AGAIN && rv != TLS_ERROR_RECV_AGAIN) {
			error_print();
			return -1;
		}

		// Block until the socket is ready in the direction the step needs
		wait_set = (rv == TLS_ERROR_SEND_AGAIN) ? &writable : &readable;
		FD_ZERO(&readable);
		FD_ZERO(&writable);
		FD_SET(conn->sock, wait_set);
		select(conn->sock + 1, &readable, &writable, NULL, NULL);
	}
}
|
||
|
||
/*
 * Run a complete server handshake over conn->sock.
 *
 * Resets the state machine to TLS_state_client_hello and starts an SM3
 * transcript hash. It then repeatedly calls tls12_server_handshake(). When a
 * step reports TLS_ERROR_SEND_AGAIN it waits (select) until the socket is
 * writable, and on TLS_ERROR_RECV_AGAIN until it is readable. The two cases
 * were previously swapped, so a pending read waited for writability, which is
 * almost always true immediately, and busy-looped.
 *
 * Returns 1 once the handshake completes, -1 on any other error.
 */
int tls12_do_accept(TLS_CONNECT *conn)
{
	int ret;
	fd_set rfds;
	fd_set wfds;

	conn->state = TLS_state_client_hello;
	digest_init(&conn->dgst_ctx, DIGEST_sm3());

	while (1) {
		ret = tls12_server_handshake(conn);

		if (ret == 1) {
			break;

		} else if (ret == TLS_ERROR_SEND_AGAIN) {
			// blocked on write: wait until the socket is writable
			FD_ZERO(&rfds);
			FD_ZERO(&wfds);
			FD_SET(conn->sock, &wfds);
			select(conn->sock + 1, &rfds, &wfds, NULL, NULL);

		} else if (ret == TLS_ERROR_RECV_AGAIN) {
			// blocked on read: wait until the socket is readable
			FD_ZERO(&rfds);
			FD_ZERO(&wfds);
			FD_SET(conn->sock, &rfds);
			select(conn->sock + 1, &rfds, &wfds, NULL, NULL);

		} else {
			error_print();
			return -1;
		}
	}

	return 1;
}
|