Files
GmSSL/src/tls.c
2026-04-13 11:34:16 +08:00

2864 lines
63 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/*
* Copyright 2014-2026 The GmSSL Project. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
#include <time.h>
#include <stdio.h>
#include <fcntl.h>
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <gmssl/rand.h>
#include <gmssl/x509.h>
#include <gmssl/error.h>
#include <gmssl/endian.h>
#include <gmssl/mem.h>
#include <gmssl/sm2.h>
#include <gmssl/sm3.h>
#include <gmssl/sm4.h>
#include <gmssl/pem.h>
#include <gmssl/tls.h>
/*
 * Append one byte to the output cursor.
 * If out or *out is NULL nothing is written, but *outlen is still advanced,
 * so the same call sequence can be used to measure a message first.
 */
void tls_uint8_to_bytes(uint8_t a, uint8_t **out, size_t *outlen)
{
	*outlen += 1;
	if (out == NULL || *out == NULL) {
		return;
	}
	**out = a;
	*out += 1;
}
/*
 * Append a 16-bit value in network (big-endian) byte order.
 * With a NULL out / *out only *outlen is advanced (length-only pass).
 */
void tls_uint16_to_bytes(uint16_t a, uint8_t **out, size_t *outlen)
{
	if (out != NULL && *out != NULL) {
		uint8_t *dst = *out;

		dst[0] = (uint8_t)(a >> 8);
		dst[1] = (uint8_t)(a & 0xff);
		*out = dst + 2;
	}
	*outlen += 2;
}
/*
 * Append the low 24 bits of a in big-endian order (TLS uint24).
 * With a NULL out / *out only *outlen is advanced (length-only pass).
 */
void tls_uint24_to_bytes(uint24_t a, uint8_t **out, size_t *outlen)
{
	if (out != NULL && *out != NULL) {
		uint8_t *dst = *out;
		int shift;

		for (shift = 16; shift >= 0; shift -= 8) {
			*dst++ = (uint8_t)(a >> shift);
		}
		*out = dst;
	}
	*outlen += 3;
}
/*
 * Append a 32-bit value in big-endian order.
 * With a NULL out / *out only *outlen is advanced (length-only pass).
 */
void tls_uint32_to_bytes(uint32_t a, uint8_t **out, size_t *outlen)
{
	if (out != NULL && *out != NULL) {
		uint8_t *dst = *out;
		int shift;

		for (shift = 24; shift >= 0; shift -= 8) {
			*dst++ = (uint8_t)(a >> shift);
		}
		*out = dst;
	}
	*outlen += 4;
}
/*
 * Append a 64-bit value in big-endian order and advance the output cursor.
 * With a NULL out / *out only *outlen is advanced (length-only pass).
 *
 * The cursor must move by 8 bytes like every other tls_uintN_to_bytes()
 * writer; otherwise a following field would overwrite this one while
 * *outlen already counts it.
 */
void tls_uint64_to_bytes(uint64_t a, uint8_t **out, size_t *outlen)
{
	if (out && *out) {
		int shift;

		for (shift = 56; shift >= 0; shift -= 8) {
			*(*out)++ = (uint8_t)(a >> shift);
		}
	}
	(*outlen) += 8;
}
/*
 * Append datalen raw bytes.
 * - out / *out NULL: only *outlen is advanced.
 * - data NULL with an output buffer: the cursor is still advanced by
 *   datalen but nothing is copied (the bytes are left as they were).
 */
void tls_array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen)
{
	*outlen += datalen;
	if (!out || !*out) {
		return;
	}
	if (data != NULL) {
		memcpy(*out, data, datalen);
	}
	*out += datalen;
}
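/*
The uintNarray_to_bytes helpers below must distinguish (data = NULL, datalen = 0)
from (data = NULL, datalen != 0).
The former means the field is empty, so only the length prefix is written.
The latter means the field is not empty but the caller does not want the data
copied: only the length prefix is written, while *out and the total output
length are still advanced by datalen. This case should be avoided!
*/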
/*
 * Emit an opaque vector with a 1-byte length prefix followed by the data.
 * NOTE: datalen is truncated to 8 bits; callers must keep it <= 255.
 */
void tls_uint8array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen)
{
	const uint8_t prefix = (uint8_t)datalen;

	tls_uint8_to_bytes(prefix, out, outlen);
	tls_array_to_bytes(data, datalen, out, outlen);
}
/*
 * Emit an opaque vector with a 2-byte big-endian length prefix.
 * NOTE: datalen is truncated to 16 bits; callers must keep it <= 65535.
 */
void tls_uint16array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen)
{
	const uint16_t prefix = (uint16_t)datalen;

	tls_uint16_to_bytes(prefix, out, outlen);
	tls_array_to_bytes(data, datalen, out, outlen);
}
void tls_uint24array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen)
{
tls_uint24_to_bytes((uint24_t)datalen, out, outlen);
tls_array_to_bytes(data, datalen, out, outlen);
}
/*
 * Read one byte from the input cursor.
 * Returns 1 and advances *in / decrements *inlen, or -1 if the input is empty.
 */
int tls_uint8_from_bytes(uint8_t *a, const uint8_t **in, size_t *inlen)
{
	if (*inlen == 0) {
		error_print();
		return -1;
	}
	*a = (*in)[0];
	*in += 1;
	*inlen -= 1;
	return 1;
}
/*
 * Read a big-endian 16-bit value from the input cursor.
 * Returns 1 on success, -1 if fewer than 2 bytes remain.
 */
int tls_uint16_from_bytes(uint16_t *a, const uint8_t **in, size_t *inlen)
{
	const uint8_t *src;

	if (*inlen < 2) {
		error_print();
		return -1;
	}
	src = *in;
	*a = (uint16_t)(((uint16_t)src[0] << 8) | src[1]);
	*in = src + 2;
	*inlen -= 2;
	return 1;
}
/*
 * Read a big-endian 24-bit value from the input cursor.
 * Returns 1 on success, -1 if fewer than 3 bytes remain.
 */
int tls_uint24_from_bytes(uint24_t *a, const uint8_t **in, size_t *inlen)
{
	const uint8_t *src;

	if (*inlen < 3) {
		error_print();
		return -1;
	}
	src = *in;
	*a = ((uint24_t)src[0] << 16) | ((uint24_t)src[1] << 8) | (uint24_t)src[2];
	*in = src + 3;
	*inlen -= 3;
	return 1;
}
/*
 * Read a big-endian 32-bit value from the input cursor.
 * Returns 1 on success, -1 if fewer than 4 bytes remain.
 */
int tls_uint32_from_bytes(uint32_t *a, const uint8_t **in, size_t *inlen)
{
	const uint8_t *src = *in;
	uint32_t v = 0;
	size_t i;

	if (*inlen < 4) {
		error_print();
		return -1;
	}
	for (i = 0; i < 4; i++) {
		v = (v << 8) | src[i];
	}
	*a = v;
	*in = src + 4;
	*inlen -= 4;
	return 1;
}
/*
 * Read a big-endian 64-bit value from the input cursor.
 * Returns 1 on success, -1 if fewer than 8 bytes remain.
 */
int tls_uint64_from_bytes(uint64_t *a, const uint8_t **in, size_t *inlen)
{
	const uint8_t *src = *in;
	uint64_t v = 0;
	size_t i;

	if (*inlen < 8) {
		error_print();
		return -1;
	}
	for (i = 0; i < 8; i++) {
		v = (v << 8) | src[i];
	}
	*a = v;
	*in = src + 8;
	*inlen -= 8;
	return 1;
}
/*
 * Take datalen bytes from the input cursor without copying: *data is set to
 * point into the input buffer. Returns -1 if fewer than datalen bytes remain.
 */
int tls_array_from_bytes(const uint8_t **data, size_t datalen, const uint8_t **in, size_t *inlen)
{
	const uint8_t *start = *in;

	if (datalen > *inlen) {
		error_print();
		return -1;
	}
	*in = start + datalen;
	*inlen -= datalen;
	*data = start;
	return 1;
}
/*
 * Parse a vector with a 1-byte length prefix.
 * *data points into the input (NULL when the vector is empty) and
 * *datalen receives the length. Returns 1 on success, -1 on truncation.
 */
int tls_uint8array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen)
{
	uint8_t n;

	if (tls_uint8_from_bytes(&n, in, inlen) != 1) {
		error_print();
		return -1;
	}
	if (tls_array_from_bytes(data, n, in, inlen) != 1) {
		error_print();
		return -1;
	}
	*data = n ? *data : NULL;
	*datalen = n;
	return 1;
}
/*
 * Parse a vector with a 2-byte big-endian length prefix.
 * *data points into the input (NULL when the vector is empty) and
 * *datalen receives the length. Returns 1 on success, -1 on truncation.
 */
int tls_uint16array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen)
{
	uint16_t n;

	if (tls_uint16_from_bytes(&n, in, inlen) != 1) {
		error_print();
		return -1;
	}
	if (tls_array_from_bytes(data, n, in, inlen) != 1) {
		error_print();
		return -1;
	}
	*data = n ? *data : NULL;
	*datalen = n;
	return 1;
}
/*
 * Parse a vector with a 3-byte big-endian length prefix.
 * *data points into the input (NULL when the vector is empty) and
 * *datalen receives the length. Returns 1 on success, -1 on truncation.
 */
int tls_uint24array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen)
{
	uint24_t n;

	if (tls_uint24_from_bytes(&n, in, inlen) != 1) {
		error_print();
		return -1;
	}
	if (tls_array_from_bytes(data, n, in, inlen) != 1) {
		error_print();
		return -1;
	}
	*data = n ? *data : NULL;
	*datalen = n;
	return 1;
}
/*
 * Check that no input bytes are left over after parsing.
 * Returns 1 when len is 0, otherwise reports an error and returns -1.
 */
int tls_length_is_zero(size_t len)
{
	if (len == 0) {
		return 1;
	}
	error_print();
	return -1;
}
/*
 * Store the content type in byte 0 of a record header.
 * Returns -1 if type has no known name (tls_record_type_name).
 */
int tls_record_set_type(uint8_t *record, int type)
{
	const int known = tls_record_type_name(type) ? 1 : 0;

	if (!known) {
		error_print();
		return -1;
	}
	record[0] = (uint8_t)type;
	return 1;
}
/*
 * Store the protocol version big-endian in bytes 1..2 of a record header.
 * Returns -1 if protocol has no known name (tls_protocol_name).
 */
int tls_record_set_protocol(uint8_t *record, int protocol)
{
	const int known = tls_protocol_name(protocol) ? 1 : 0;

	if (!known) {
		error_print();
		return -1;
	}
	record[1] = (uint8_t)((protocol >> 8) & 0xff);
	record[2] = (uint8_t)(protocol & 0xff);
	return 1;
}
/*
 * Store the fragment length big-endian in bytes 3..4 of a record header.
 * Returns -1 if length exceeds TLS_MAX_CIPHERTEXT_SIZE.
 */
int tls_record_set_length(uint8_t *record, size_t length)
{
	uint8_t *p = record + 3;
	// tls_uint16_to_bytes() adds 2 to this counter, so it must be
	// initialized; reading an indeterminate value is a bug
	size_t len = 0;

	if (length > TLS_MAX_CIPHERTEXT_SIZE) {
		error_print();
		return -1;
	}
	tls_uint16_to_bytes((uint16_t)length, &p, &len);
	return 1;
}
/*
 * Set the record length field to datalen and copy data into the record body.
 * The record buffer must have room for the header plus datalen bytes.
 * Returns -1 if datalen is rejected by tls_record_set_length().
 */
int tls_record_set_data(uint8_t *record, const uint8_t *data, size_t datalen)
{
	int ret = tls_record_set_length(record, datalen);

	if (ret != 1) {
		error_print();
		return -1;
	}
	memcpy(tls_record_data(record), data, datalen);
	return 1;
}
/*
 * MAC-then-encrypt a record fragment with SM3-HMAC and SM4-CBC.
 *
 * inited_hmac_ctx: HMAC context already keyed with the MAC secret; it is
 *                  copied and never modified.
 * enc_key:         SM4 encryption key.
 * seq_num:         8-byte record sequence number fed into the MAC.
 * header:          5-byte record header; header[3..4] must equal inlen.
 * in, inlen:       plaintext fragment, at most 2^14 bytes. in may be NULL
 *                  only when inlen == 0.
 * out:             receives IV(16) || CBC(in || MAC(32) || padding); it
 *                  needs 16 + (inlen - inlen % 16) + 48 bytes.
 * outlen:          set to the number of bytes written.
 *
 * Returns 1 on success, -1 on error.
 */
int tls_cbc_encrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *enc_key,
	const uint8_t seq_num[8], const uint8_t header[5],
	const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
	SM3_HMAC_CTX hmac_ctx;
	// Final 3 CBC blocks: trailing (inlen % 16) plaintext bytes,
	// the 32-byte MAC, then padding up to 48 bytes.
	uint8_t last_blocks[32 + 16] = {0};
	uint8_t iv[16];
	uint8_t *mac, *padding;
	int rem, padding_len;
	int i;

	if (!inited_hmac_ctx || !enc_key || !seq_num || !header || (!in && inlen) || !out || !outlen) {
		error_print();
		return -1;
	}
	if (inlen > (1 << 14)) {
		error_print_msg("invalid tls record data length %zu\n", inlen);
		return -1;
	}
	if ((((size_t)header[3]) << 8) + header[4] != inlen) {
		error_print();
		return -1;
	}

	// plaintext bytes that do not fill a whole block (MAC is 2 whole blocks)
	rem = (inlen + 32) % 16;
	// in may be NULL when inlen == 0; pointer arithmetic or memcpy on a
	// NULL pointer is undefined even with a zero length, so skip it
	if (rem > 0) {
		memcpy(last_blocks, in + inlen - rem, rem);
	}
	mac = last_blocks + rem;

	// MAC = HMAC(seq_num || header || plaintext)
	memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
	sm3_hmac_update(&hmac_ctx, seq_num, 8);
	sm3_hmac_update(&hmac_ctx, header, 5);
	if (inlen > 0) {
		sm3_hmac_update(&hmac_ctx, in, inlen);
	}
	sm3_hmac_finish(&hmac_ctx, mac);

	// padding_len + 1 bytes, each equal to padding_len, ending at last_blocks[47]
	padding = mac + 32;
	padding_len = 16 - rem - 1;
	for (i = 0; i <= padding_len; i++) {
		padding[i] = (uint8_t)padding_len;
	}

	// explicit random IV, sent in front of the ciphertext
	if (rand_bytes(iv, 16) != 1) {
		error_print();
		return -1;
	}
	memcpy(out, iv, 16);
	out += 16;
	// NOTE(review): the same iv buffer is passed to both calls; this relies on
	// sm4_cbc_encrypt_blocks() updating it with the last ciphertext block so
	// the CBC chain continues - confirm against sm4 implementation.
	if (inlen >= 16) {
		sm4_cbc_encrypt_blocks(enc_key, iv, in, inlen/16, out);
		out += inlen - rem;
	}
	sm4_cbc_encrypt_blocks(enc_key, iv, last_blocks, sizeof(last_blocks)/16, out);
	*outlen = 16 + inlen - rem + sizeof(last_blocks);
	return 1;
}
/*
 * Decrypt and authenticate a fragment produced by tls_cbc_encrypt().
 *
 * enced_header: header of the encrypted record; bytes 0..2 are reused to
 *               rebuild the MAC header, whose length field is set to the
 *               recovered plaintext length.
 * in, inlen:    IV(16) || CBC ciphertext. inlen must be a multiple of 16
 *               in [16 + 32 + 16, 16 + 2^14 + 32 + 256].
 * out:          receives the whole decrypted body (plaintext, MAC and
 *               padding), so it needs inlen - 16 bytes. On success the first
 *               *outlen bytes are the plaintext.
 *
 * Returns 1 on success, -1 on a length, padding or MAC failure.
 *
 * NOTE(review): padding and MAC failures return at different points, so they
 * may be distinguishable by timing (padding-oracle class issue); this code is
 * not constant-time.
 * TODO: on failure, print all digest inputs so it is easy to tell which
 * input was wrong.
 */
int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key,
	const uint8_t seq_num[8], const uint8_t enced_header[5],
	const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
	SM3_HMAC_CTX hmac_ctx;
	uint8_t iv[16];
	const uint8_t *padding;
	const uint8_t *mac;
	uint8_t header[5];
	int padding_len;
	uint8_t hmac[32];
	int i;

	if (!inited_hmac_ctx || !dec_key || !seq_num || !enced_header || !in || !inlen || !out || !outlen) {
		error_print();
		return -1;
	}
	if (inlen % 16
		|| inlen < (16 + 0 + 32 + 16) // iv + data + mac + padding
		|| inlen > (16 + (1<<14) + 32 + 256)) {
		error_print_msg("invalid tls cbc ciphertext length %zu\n", inlen);
		return -1;
	}
	memcpy(iv, in, 16);
	in += 16;
	inlen -= 16;
	sm4_cbc_decrypt_blocks(dec_key, iv, in, inlen/16, out);

	padding_len = out[inlen - 1];
	// The body must hold the MAC (32) plus padding_len + 1 padding bytes.
	// Check with sizes first: computing out + inlen - padding_len - 1 when it
	// would point before out is undefined behavior.
	if (inlen < 32 + (size_t)padding_len + 1) {
		error_print();
		return -1;
	}
	padding = out + inlen - padding_len - 1;
	// padding[padding_len] is the length byte itself, already equal
	for (i = 0; i < padding_len; i++) {
		if (padding[i] != padding_len) {
			error_puts("tls ciphertext cbc-padding check failure");
			return -1;
		}
	}
	*outlen = inlen - 32 - padding_len - 1;

	// MAC header: original type/version with the plaintext length
	header[0] = enced_header[0];
	header[1] = enced_header[1];
	header[2] = enced_header[2];
	header[3] = (uint8_t)((*outlen) >> 8);
	header[4] = (uint8_t)(*outlen);
	mac = padding - 32;

	memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
	sm3_hmac_update(&hmac_ctx, seq_num, 8);
	sm3_hmac_update(&hmac_ctx, header, 5);
	sm3_hmac_update(&hmac_ctx, out, *outlen);
	sm3_hmac_finish(&hmac_ctx, hmac);
	if (gmssl_secure_memcmp(mac, hmac, sizeof(hmac)) != 0) {
		error_puts("tls ciphertext mac check failure\n");
		return -1;
	}
	return 1;
}
int tls_record_encrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key,
const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
uint8_t *out, size_t *outlen)
{
if (tls_cbc_encrypt(hmac_ctx, cbc_key, seq_num, in,
in + 5, inlen - 5,
out + 5, outlen) != 1) {
error_print();
return -1;
}
out[0] = in[0];
out[1] = in[1];
out[2] = in[2];
out[3] = (uint8_t)((*outlen) >> 8);
out[4] = (uint8_t)(*outlen);
(*outlen) += 5;
return 1;
}
int tls_record_decrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key,
const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
uint8_t *out, size_t *outlen)
{
if (tls_cbc_decrypt(hmac_ctx, cbc_key, seq_num, in,
in + 5, inlen - 5,
out + 5, outlen) != 1) {
error_print();
return -1;
}
out[0] = in[0];
out[1] = in[1];
out[2] = in[2];
out[3] = (uint8_t)((*outlen) >> 8);
out[4] = (uint8_t)(*outlen);
(*outlen) += 5;
return 1;
}
/*
 * Fill a 32-byte hello random: the current time(NULL) as a big-endian
 * 32-bit value, followed by 28 bytes from rand_bytes().
 * Returns 1 on success, -1 if the random generator fails.
 */
int tls_random_generate(uint8_t random[32])
{
	uint8_t *cursor = random;
	size_t written = 0;

	tls_uint32_to_bytes((uint32_t)time(NULL), &cursor, &written);
	if (rand_bytes(&random[4], 28) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * TLS PRF built on SM3-HMAC (P_hash expansion):
 *   A(1)   = HMAC(secret, label || seed || more)
 *   A(i+1) = HMAC(secret, A(i))
 *   out    = HMAC(secret, A(1) || label || seed || more)
 *         || HMAC(secret, A(2) || label || seed || more) || ...
 * truncated to outlen bytes. more/morelen is an optional second seed part.
 * Returns 1 on success, -1 on invalid arguments.
 */
int tls_prf(const uint8_t *secret, size_t secretlen, const char *label,
	const uint8_t *seed, size_t seedlen,
	const uint8_t *more, size_t morelen,
	size_t outlen, uint8_t *out)
{
	SM3_HMAC_CTX keyed_ctx;
	SM3_HMAC_CTX ctx;
	uint8_t a[32];
	uint8_t block[32];
	size_t label_len;

	if (!secret || !secretlen || !label || !seed || !seedlen
		|| (!more && morelen) || !outlen || !out) {
		error_print();
		return -1;
	}

	label_len = strlen(label);
	sm3_hmac_init(&keyed_ctx, secret, secretlen);

	// A(1)
	memcpy(&ctx, &keyed_ctx, sizeof(ctx));
	sm3_hmac_update(&ctx, (uint8_t *)label, label_len);
	sm3_hmac_update(&ctx, seed, seedlen);
	sm3_hmac_update(&ctx, more, morelen);
	sm3_hmac_finish(&ctx, a);

	for (;;) {
		size_t n;

		// output block for the current A(i)
		memcpy(&ctx, &keyed_ctx, sizeof(ctx));
		sm3_hmac_update(&ctx, a, sizeof(a));
		sm3_hmac_update(&ctx, (uint8_t *)label, label_len);
		sm3_hmac_update(&ctx, seed, seedlen);
		sm3_hmac_update(&ctx, more, morelen);
		sm3_hmac_finish(&ctx, block);

		n = (outlen < sizeof(block)) ? outlen : sizeof(block);
		memcpy(out, block, n);
		out += n;
		outlen -= n;
		if (outlen == 0) {
			break;
		}

		// advance to A(i+1)
		memcpy(&ctx, &keyed_ctx, sizeof(ctx));
		sm3_hmac_update(&ctx, a, sizeof(a));
		sm3_hmac_finish(&ctx, a);
	}
	return 1;
}
/*
 * Build a 48-byte pre-master secret: the 2-byte big-endian protocol version
 * followed by 46 bytes from rand_bytes().
 * Returns -1 if protocol is unknown or the random generator fails.
 */
int tls_pre_master_secret_generate(uint8_t pre_master_secret[48], int protocol)
{
	if (tls_protocol_name(protocol) == NULL) {
		error_print();
		return -1;
	}
	pre_master_secret[0] = (uint8_t)((protocol >> 8) & 0xff);
	pre_master_secret[1] = (uint8_t)(protocol & 0xff);
	if (rand_bytes(&pre_master_secret[2], 46) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Map a certificate signature-algorithm OID to a ClientCertificateType,
 * used when building CertificateRequest.
 * SM2 and ECDSA signatures map to TLS_cert_type_ecdsa_sign, RSA signatures
 * to TLS_cert_type_rsa_sign. Unrecognized OIDs yield 0 (no TLS_cert_type_xxx
 * value is 0).
 */
int tls_cert_type_from_oid(int oid)
{
	int cert_type = 0;

	switch (oid) {
	case OID_sm2sign_with_sm3:
	case OID_ecdsa_with_sha1:
	case OID_ecdsa_with_sha224:
	case OID_ecdsa_with_sha256:
	case OID_ecdsa_with_sha512:
		cert_type = TLS_cert_type_ecdsa_sign;
		break;
	case OID_rsasign_with_sm3:
	case OID_rsasign_with_md5:
	case OID_rsasign_with_sha1:
	case OID_rsasign_with_sha224:
	case OID_rsasign_with_sha256:
	case OID_rsasign_with_sha384:
	case OID_rsasign_with_sha512:
		cert_type = TLS_cert_type_rsa_sign;
		break;
	default:
		break;
	}
	return cert_type;
}
/*
 * Sign the server ECDH parameters of a ServerKeyExchange:
 *   sig = SM2-Sign(client_random || server_random || params)
 *   params = curve_type(1) || named_curve(2) || 65 || uncompressed point(65)
 * with SM2_DEFAULT_ID. Only TLS_curve_sm2p256v1 is accepted.
 *
 * This sign/verify pair has no TLCP counterpart; an "ex" version now exists.
 *
 * Returns 1 on success, -1 on invalid arguments or if any SM2 signing step
 * fails (previously such failures were ignored and 1 was returned with an
 * undefined signature).
 */
int tls_sign_server_ecdh_params(const SM2_KEY *server_sign_key,
	const uint8_t client_random[32], const uint8_t server_random[32],
	int curve, const SM2_Z256_POINT *point, uint8_t *sig, size_t *siglen)
{
	uint8_t server_ecdh_params[69];
	SM2_SIGN_CTX sign_ctx;

	if (!server_sign_key || !client_random || !server_random
		|| curve != TLS_curve_sm2p256v1 || !point || !sig || !siglen) {
		error_print();
		return -1;
	}
	server_ecdh_params[0] = TLS_curve_type_named_curve;
	server_ecdh_params[1] = (uint8_t)(curve >> 8);
	server_ecdh_params[2] = (uint8_t)curve;
	server_ecdh_params[3] = 65;
	sm2_z256_point_to_uncompressed_octets(point, server_ecdh_params + 4);

	if (sm2_sign_init(&sign_ctx, server_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1
		|| sm2_sign_update(&sign_ctx, client_random, 32) != 1
		|| sm2_sign_update(&sign_ctx, server_random, 32) != 1
		|| sm2_sign_update(&sign_ctx, server_ecdh_params, 69) != 1
		|| sm2_sign_finish(&sign_ctx, sig, siglen) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Verify a signature produced by tls_sign_server_ecdh_params() over
 * client_random || server_random || params (same params encoding, SM2
 * with SM2_DEFAULT_ID). Only TLS_curve_sm2p256v1 is accepted and siglen
 * must not exceed SM2_MAX_SIGNATURE_SIZE.
 *
 * Returns the result of sm2_verify_finish() (1 means valid), or -1 on
 * invalid arguments or if the verify context cannot be set up/updated.
 */
int tls_verify_server_ecdh_params(const SM2_KEY *server_sign_key,
	const uint8_t client_random[32], const uint8_t server_random[32],
	int curve, const SM2_Z256_POINT *point, const uint8_t *sig, size_t siglen)
{
	int ret;
	uint8_t server_ecdh_params[69];
	SM2_VERIFY_CTX verify_ctx;

	if (!server_sign_key || !client_random || !server_random
		|| curve != TLS_curve_sm2p256v1 || !point || !sig || !siglen
		|| siglen > SM2_MAX_SIGNATURE_SIZE) {
		error_print();
		return -1;
	}
	server_ecdh_params[0] = TLS_curve_type_named_curve;
	server_ecdh_params[1] = (uint8_t)(curve >> 8);
	server_ecdh_params[2] = (uint8_t)(curve);
	server_ecdh_params[3] = 65;
	sm2_z256_point_to_uncompressed_octets(point, server_ecdh_params + 4);

	// a failed init/update must not fall through to a verify on a bad context
	if (sm2_verify_init(&verify_ctx, server_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1
		|| sm2_verify_update(&verify_ctx, client_random, 32) != 1
		|| sm2_verify_update(&verify_ctx, server_random, 32) != 1
		|| sm2_verify_update(&verify_ctx, server_ecdh_params, 69) != 1) {
		error_print();
		return -1;
	}
	ret = sm2_verify_finish(&verify_ctx, sig, siglen);
	if (ret != 1) error_print();
	return ret;
}
/*
 * Turn record into a handshake record carrying one handshake message.
 *
 * Sets the record type and length, the handshake msg_type and the 24-bit
 * body length. If data is non-NULL it is copied into the body; with
 * data == NULL the caller is expected to have written the body already
 * (datalen may be 0, e.g. ServerHelloDone has no body).
 * The record's protocol bytes must already hold a known version.
 * On success *recordlen is the total record size. Returns 1 or -1.
 */
int tls_record_set_handshake(uint8_t *record, size_t *recordlen,
	int type, const uint8_t *data, size_t datalen)
{
	size_t handshakelen;

	if (record == NULL || recordlen == NULL) {
		error_print();
		return -1;
	}
	if (datalen > TLS_MAX_PLAINTEXT_SIZE - TLS_HANDSHAKE_HEADER_SIZE) {
		error_print();
		return -1;
	}
	if (!tls_protocol_name(tls_record_protocol(record))) {
		error_print();
		return -1;
	}
	if (!tls_handshake_type_name(type)) {
		error_print();
		return -1;
	}

	handshakelen = datalen + TLS_HANDSHAKE_HEADER_SIZE;

	// record header: type || version (left as is) || uint16 length
	record[0] = TLS_record_handshake;
	record[3] = (uint8_t)((handshakelen >> 8) & 0xff);
	record[4] = (uint8_t)(handshakelen & 0xff);
	// handshake header: msg_type || uint24 length
	record[5] = (uint8_t)type;
	record[6] = (uint8_t)((datalen >> 16) & 0xff);
	record[7] = (uint8_t)((datalen >> 8) & 0xff);
	record[8] = (uint8_t)(datalen & 0xff);

	if (data != NULL) {
		memcpy(tls_handshake_data(tls_record_data(record)), data, datalen);
	}
	*recordlen = handshakelen + TLS_RECORD_HEADER_SIZE;
	return 1;
}
/*
 * Write only the record and handshake headers for a body of `length` bytes
 * that the caller has already placed (or will place) in the record.
 * Returns 1 on success, -1 if tls_record_set_handshake() rejects the input.
 */
int tls_record_set_handshake_header(uint8_t *record, size_t *recordlen,
	int type, int length)
{
	int ret = tls_record_set_handshake(record, recordlen, type, NULL, length);

	if (ret != 1) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Extract the single handshake message carried by a handshake record.
 *
 * Checks the record protocol and type, that the fragment holds a full
 * handshake header, that the msg_type is known, and that the 24-bit body
 * length matches the rest of the fragment exactly.
 * On success *type is the msg_type, *data points at the body inside record
 * (NULL for an empty body) and *datalen is its length. Returns 1 or -1.
 */
int tls_record_get_handshake(const uint8_t *record,
	int *type, const uint8_t **data, size_t *datalen)
{
	const uint8_t *p;
	size_t left;
	uint24_t body_len;

	if (!record || !type || !data || !datalen) {
		error_print();
		return -1;
	}
	if (!tls_protocol_name(tls_record_protocol(record))) {
		error_print();
		return -1;
	}
	if (tls_record_type(record) != TLS_record_handshake) {
		error_print();
		return -1;
	}

	p = tls_record_data(record);
	left = tls_record_data_length(record);
	if (left < TLS_HANDSHAKE_HEADER_SIZE) {
		error_print();
		return -1;
	}
	if (left > TLS_MAX_PLAINTEXT_SIZE) {
		// a message (e.g. a certificate chain) longer than one record is not supported
		error_print();
		return -1;
	}
	if (!tls_handshake_type_name(p[0])) {
		error_print();
		return -1;
	}
	*type = p[0];
	p += 1;
	left -= 1;

	if (tls_uint24_from_bytes(&body_len, &p, &left) != 1) {
		error_print();
		return -1;
	}
	if (left != body_len) {
		error_print();
		return -1;
	}
	*datalen = body_len;
	*data = body_len ? p : NULL;
	return 1;
}
int tls_record_set_handshake_client_hello(uint8_t *record, size_t *recordlen,
int protocol, const uint8_t random[32],
const uint8_t *session_id, size_t session_id_len,
const int *cipher_suites, size_t cipher_suites_count,
const uint8_t *exts, size_t exts_len)
{
uint8_t type = TLS_handshake_client_hello;
uint8_t *p;
size_t len;
if (!record || !recordlen || !random || !cipher_suites || !cipher_suites_count) {
error_print();
return -1;
}
if (session_id) {
if (!session_id_len
|| session_id_len < TLS_MAX_SESSION_ID_SIZE
|| session_id_len > TLS_MAX_SESSION_ID_SIZE) {
error_print();
return -1;
}
}
if (cipher_suites_count > TLS_MAX_CIPHER_SUITES_COUNT) {
error_print();
return -1;
}
if (exts && !exts_len) {
error_print();
return -1;
}
p = tls_handshake_data(tls_record_data(record));
len = 0;
if (!tls_protocol_name(protocol)) {
error_print();
return -1;
}
tls_uint16_to_bytes((uint16_t)protocol, &p, &len);
tls_array_to_bytes(random, 32, &p, &len);
tls_uint8array_to_bytes(session_id, session_id_len, &p, &len);
tls_uint16_to_bytes((uint16_t)(cipher_suites_count * 2), &p, &len);
while (cipher_suites_count--) {
if (!tls_cipher_suite_name(*cipher_suites)) {
error_print();
return -1;
}
tls_uint16_to_bytes((uint16_t)*cipher_suites, &p, &len);
cipher_suites++;
}
tls_uint8_to_bytes(1, &p, &len);
tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len);
if (exts) {
size_t tmp_len = len;
if (protocol < TLS_protocol_tls12) {
error_print();
return -1;
}
tls_uint16array_to_bytes(exts, exts_len, NULL, &tmp_len);
if (tmp_len > TLS_MAX_HANDSHAKE_DATA_SIZE) {
error_print();
return -1;
}
tls_uint16array_to_bytes(exts, exts_len, &p, &len);
}
if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) {
error_print();
return -1;
}
return 1;
}
int tls_record_get_handshake_client_hello(const uint8_t *record,
int *protocol, const uint8_t **random,
const uint8_t **session_id, size_t *session_id_len,
const uint8_t **cipher_suites, size_t *cipher_suites_len,
const uint8_t **exts, size_t *exts_len)
{
int type;
const uint8_t *p;
size_t len;
uint16_t ver;
const uint8_t *comp_meths;
size_t comp_meths_len;
if (!record || !protocol || !random
|| !session_id || !session_id_len
|| !cipher_suites || !cipher_suites_len
|| !exts || !exts_len) {
error_print();
return -1;
}
if (tls_record_get_handshake(record, &type, &p, &len) != 1) {
error_print();
return -1;
}
if (type != TLS_handshake_client_hello) {
error_print();
return -1;
}
if (tls_uint16_from_bytes(&ver, &p, &len) != 1
|| tls_array_from_bytes(random, 32, &p, &len) != 1
|| tls_uint8array_from_bytes(session_id, session_id_len, &p, &len) != 1
|| tls_uint16array_from_bytes(cipher_suites, cipher_suites_len, &p, &len) != 1
|| tls_uint8array_from_bytes(&comp_meths, &comp_meths_len, &p, &len) != 1) {
error_print();
return -1;
}
if (!tls_protocol_name(ver)) {
error_print();
return -1;
}
*protocol = ver;
if (*session_id) {
if (*session_id_len == 0
|| *session_id_len < TLS_MIN_SESSION_ID_SIZE
|| *session_id_len > TLS_MAX_SESSION_ID_SIZE) {
error_print();
return -1;
}
}
if (!cipher_suites) {
error_print();
return -1;
}
if (*cipher_suites_len % 2) {
error_print();
return -1;
}
if (len) {
if (tls_uint16array_from_bytes(exts, exts_len, &p, &len) != 1) {
error_print();
return -1;
}
if (*exts == NULL) {
error_print();
return -1;
}
} else {
*exts = NULL;
*exts_len = 0;
}
if (len) {
error_print();
return -1;
}
return 1;
}
int tls_record_set_handshake_server_hello(uint8_t *record, size_t *recordlen,
int protocol, const uint8_t random[32],
const uint8_t *session_id, size_t session_id_len, int cipher_suite,
const uint8_t *exts, size_t exts_len)
{
uint8_t type = TLS_handshake_server_hello;
uint8_t *p;
size_t len;
if (!record || !recordlen || !random) {
error_print();
return -1;
}
if (session_id) {
if (session_id_len == 0
|| session_id_len < TLS_MIN_SESSION_ID_SIZE
|| session_id_len > TLS_MAX_SESSION_ID_SIZE) {
error_print();
return -1;
}
}
if (!tls_protocol_name(protocol)) {
error_print();
return -1;
}
if (!tls_cipher_suite_name(cipher_suite)) {
error_print();
return -1;
}
p = tls_handshake_data(tls_record_data(record));
len = 0;
tls_uint16_to_bytes((uint16_t)protocol, &p, &len);
tls_array_to_bytes(random, 32, &p, &len);
tls_uint8array_to_bytes(session_id, session_id_len, &p, &len);
tls_uint16_to_bytes((uint16_t)cipher_suite, &p, &len);
tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len);
if (exts) {
if (protocol < TLS_protocol_tls12) {
error_print();
return -1;
}
tls_uint16array_to_bytes(exts, exts_len, &p, &len);
}
if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) {
error_print();
return -1;
}
return 1;
}
/*
If a message is well-formed but carries illegal data, TLS_alert_illegal_parameter
should be returned, e.g. when the server's choice is not in the list offered in
the ClientHello. In other words, semantic errors map to this alert.
If we do not understand a value's meaning but its format is correct, it should be ignored.
*/
/*
 * Parse a ServerHello handshake record. Output pointers point into record.
 *
 *   protocol          server_version; must be known and not lower than the
 *                     record-layer version
 *   random            32-byte server random
 *   session_id(_len)  NULL/0 when empty, otherwise
 *                     TLS_MIN_SESSION_ID_SIZE..TLS_MAX_SESSION_ID_SIZE bytes
 *   cipher_suite      must be a known cipher suite
 *   exts(_len)        raw extensions block, or NULL/0 when absent
 * The compression method must be null.
 *
 * Returns 1 on success, 0 if the record holds a different handshake type,
 * -1 on any other error.
 */
int tls_record_get_handshake_server_hello(const uint8_t *record,
	int *protocol, const uint8_t **random, const uint8_t **session_id, size_t *session_id_len,
	int *cipher_suite, const uint8_t **exts, size_t *exts_len)
{
	int type;
	const uint8_t *p;
	size_t len;
	uint16_t ver;
	uint16_t cipher;
	uint8_t comp_meth;

	if (!record || !protocol || !random || !session_id || !session_id_len
		|| !cipher_suite || !exts || !exts_len) {
		error_print();
		return -1;
	}
	if (tls_record_get_handshake(record, &type, &p, &len) != 1) {
		error_print();
		return -1;
	}
	if (type != TLS_handshake_server_hello) {
		error_print();
		return 0;
	}
	if (tls_uint16_from_bytes(&ver, &p, &len) != 1
		|| tls_array_from_bytes(random, 32, &p, &len) != 1
		|| tls_uint8array_from_bytes(session_id, session_id_len, &p, &len) != 1
		|| tls_uint16_from_bytes(&cipher, &p, &len) != 1
		|| tls_uint8_from_bytes(&comp_meth, &p, &len) != 1) {
		error_print();
		return -1;
	}
	if (!tls_protocol_name(ver)) {
		error_print();
		return -1;
	}
	if (ver < tls_record_protocol(record)) {
		error_print();
		return -1;
	}
	*protocol = ver;
	if (*session_id) {
		// test the length, not the pointer (which is known non-NULL here)
		if (*session_id_len == 0
			|| *session_id_len < TLS_MIN_SESSION_ID_SIZE
			|| *session_id_len > TLS_MAX_SESSION_ID_SIZE) {
			error_print();
			return -1;
		}
	}
	if (!tls_cipher_suite_name(cipher)) {
		error_print();
		return -1;
	}
	*cipher_suite = cipher;
	// TODO: this condition should be reported to the caller so it can send
	// the correct alert
	if (comp_meth != TLS_compression_null) {
		error_print();
		return -1;
	}
	if (len) {
		if (tls_uint16array_from_bytes(exts, exts_len, &p, &len) != 1) {
			error_print();
			return -1;
		}
		if (*exts == NULL) {
			error_print();
			return -1;
		}
	} else {
		*exts = NULL;
		*exts_len = 0;
	}
	if (len) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Build a Certificate handshake record from a concatenation of DER-encoded
 * X.509 certificates (certs, certslen). Each certificate becomes a
 * uint24-length-prefixed entry inside a uint24-length-prefixed list.
 * Returns 1 on success, -1 on a malformed certificate, when the message
 * would exceed TLS_MAX_HANDSHAKE_DATA_SIZE, or when the handshake header
 * cannot be set (previously that last failure was ignored).
 */
int tls_record_set_handshake_certificate(uint8_t *record, size_t *recordlen,
	const uint8_t *certs, size_t certslen)
{
	int type = TLS_handshake_certificate;
	uint8_t *data;
	size_t datalen;
	uint8_t *p;
	size_t len;
	size_t prefixlen = 0;

	if (!record || !recordlen || !certs || !certslen) {
		error_print();
		return -1;
	}
	data = tls_handshake_data(tls_record_data(record));
	// entries start after the 3-byte certificate_list length written below
	p = data + tls_uint24_size();
	datalen = tls_uint24_size();
	len = 0;
	while (certslen) {
		const uint8_t *cert;
		size_t certlen;

		if (x509_cert_from_der(&cert, &certlen, &certs, &certslen) != 1) {
			error_print();
			return -1;
		}
		// measure the entry first so nothing is written past the limit
		tls_uint24array_to_bytes(cert, certlen, NULL, &datalen);
		if (datalen > TLS_MAX_HANDSHAKE_DATA_SIZE) {
			error_print();
			return -1;
		}
		tls_uint24array_to_bytes(cert, certlen, &p, &len);
	}
	// certificate_list length prefix; prefixlen is only a scratch counter
	tls_uint24_to_bytes((uint24_t)len, &data, &prefixlen);
	if (tls_record_set_handshake(record, recordlen, type, NULL, datalen) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Parse a Certificate handshake record and append every certificate, as
 * DER, to certs via x509_cert_to_der(); *certslen receives the total
 * length. Each entry must contain exactly one certificate, and no bytes may
 * follow the certificate_list.
 * Returns 1 on success, -1 on error.
 */
int tls_record_get_handshake_certificate(const uint8_t *record, uint8_t *certs, size_t *certslen)
{
	int type;
	const uint8_t *data;
	size_t datalen;
	const uint8_t *cp;
	size_t len;

	if (!record || !certslen) {
		error_print();
		return -1;
	}
	if (tls_record_get_handshake(record, &type, &data, &datalen) != 1) {
		error_print();
		return -1;
	}
	if (type != TLS_handshake_certificate) {
		error_print();
		return -1;
	}
	// the message body is exactly one certificate_list
	if (tls_uint24array_from_bytes(&cp, &len, &data, &datalen) != 1
		|| tls_length_is_zero(datalen) != 1) {
		error_print();
		return -1;
	}
	*certslen = 0;
	while (len) {
		const uint8_t *a;
		size_t alen;
		const uint8_t *cert;
		size_t certlen;

		if (tls_uint24array_from_bytes(&a, &alen, &cp, &len) != 1) {
			error_print();
			return -1;
		}
		if (x509_cert_from_der(&cert, &certlen, &a, &alen) != 1
			|| asn1_length_is_zero(alen) != 1
			|| x509_cert_to_der(cert, certlen, &certs, certslen) != 1) {
			error_print();
			return -1;
		}
	}
	return 1;
}
/*
struct {
	ClientCertificateType certificate_types<1..2^8-1>;
	SignatureAndHashAlgorithm supported_signature_algorithms<2^16-1>; // may be absent; not written or parsed by the functions below
	DistinguishedName certificate_authorities<0..2^16-1>;
} CertificateRequest;
*/
/*
 * Build a CertificateRequest handshake record containing
 * certificate_types (uint8 vector) and certificate_authorities
 * (uint16 vector of pre-encoded names, may be NULL/empty).
 * Returns 1 on success, -1 on invalid sizes or if the handshake header
 * cannot be set (previously that failure was ignored).
 */
int tls_record_set_handshake_certificate_request(uint8_t *record, size_t *recordlen,
	const uint8_t *cert_types, size_t cert_types_len,
	const uint8_t *ca_names, size_t ca_names_len)
{
	int type = TLS_handshake_certificate_request;
	uint8_t *p;
	size_t len = 0;
	size_t datalen = 0;

	if (!record || !recordlen) {
		error_print();
		return -1;
	}
	if (cert_types) {
		if (cert_types_len == 0 || cert_types_len > TLS_MAX_CERTIFICATE_TYPES) {
			error_print();
			return -1;
		}
	}
	if (ca_names) {
		if (ca_names_len == 0 || ca_names_len > TLS_MAX_CA_NAMES_SIZE) {
			error_print();
			return -1;
		}
	}
	// length-only pass before writing anything into the record
	tls_uint8array_to_bytes(cert_types, cert_types_len, NULL, &datalen);
	tls_uint16array_to_bytes(ca_names, ca_names_len, NULL, &datalen);
	if (datalen > TLS_MAX_HANDSHAKE_DATA_SIZE) {
		error_print();
		return -1;
	}
	p = tls_handshake_data(tls_record_data(record));
	tls_uint8array_to_bytes(cert_types, cert_types_len, &p, &len);
	tls_uint16array_to_bytes(ca_names, ca_names_len, &p, &len);
	if (tls_record_set_handshake(record, recordlen, type, NULL, datalen) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Parse a CertificateRequest handshake record.
 * *cert_types must be non-empty and every entry must be a known type.
 * *ca_names may be NULL; when present it must split cleanly into
 * uint16-length-prefixed names (their content is not inspected).
 * Output pointers point into record. Returns 1 on success, -1 on error.
 */
int tls_record_get_handshake_certificate_request(const uint8_t *record,
	const uint8_t **cert_types, size_t *cert_types_len,
	const uint8_t **ca_names, size_t *ca_names_len)
{
	int type;
	const uint8_t *p;
	size_t left;
	size_t idx;

	if (!record || !cert_types || !cert_types_len || !ca_names || !ca_names_len) {
		error_print();
		return -1;
	}
	if (tls_record_get_handshake(record, &type, &p, &left) != 1) {
		error_print();
		return -1;
	}
	if (type != TLS_handshake_certificate_request) {
		error_print();
		return -1;
	}
	if (tls_uint8array_from_bytes(cert_types, cert_types_len, &p, &left) != 1
		|| tls_uint16array_from_bytes(ca_names, ca_names_len, &p, &left) != 1
		|| tls_length_is_zero(left) != 1) {
		error_print();
		return -1;
	}
	if (*cert_types == NULL) {
		error_print();
		return -1;
	}
	for (idx = 0; idx < *cert_types_len; idx++) {
		if (!tls_cert_type_name((*cert_types)[idx])) {
			error_print();
			return -1;
		}
	}
	if (*ca_names != NULL) {
		const uint8_t *cur = *ca_names;
		size_t remaining = *ca_names_len;
		const uint8_t *dn;
		size_t dnlen;

		while (remaining > 0) {
			if (tls_uint16array_from_bytes(&dn, &dnlen, &cur, &remaining) != 1) {
				error_print();
				return -1;
			}
		}
	}
	return 1;
}
/*
 * Build a ServerHelloDone handshake record (empty body).
 * Returns 1 on success, -1 on NULL arguments or if the handshake header
 * cannot be set (previously that failure was ignored).
 */
int tls_record_set_handshake_server_hello_done(uint8_t *record, size_t *recordlen)
{
	int type = TLS_handshake_server_hello_done;

	if (!record || !recordlen) {
		error_print();
		return -1;
	}
	if (tls_record_set_handshake(record, recordlen, type, NULL, 0) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Check that record is a ServerHelloDone handshake with an empty body.
 * Returns 1 if so, -1 otherwise.
 */
int tls_record_get_handshake_server_hello_done(const uint8_t *record)
{
	int type;
	const uint8_t *body;
	size_t bodylen;

	if (record == NULL) {
		error_print();
		return -1;
	}
	if (tls_record_get_handshake(record, &type, &body, &bodylen) != 1
		|| type != TLS_handshake_server_hello_done) {
		error_print();
		return -1;
	}
	if (bodylen != 0 || body != NULL) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Build a ClientKeyExchange record carrying an encrypted pre-master secret
 * as a uint16-length-prefixed vector.
 * Returns 1 on success, -1 on invalid arguments or if the handshake header
 * cannot be set (previously that failure was ignored).
 */
int tls_record_set_handshake_client_key_exchange_pke(uint8_t *record, size_t *recordlen,
	const uint8_t *enced_pms, size_t enced_pms_len)
{
	int type = TLS_handshake_client_key_exchange;
	uint8_t *p;
	size_t len = 0;

	if (!record || !recordlen || !enced_pms || !enced_pms_len) {
		error_print();
		return -1;
	}
	if (enced_pms_len > TLS_MAX_HANDSHAKE_DATA_SIZE - tls_uint16_size()) {
		error_print();
		return -1;
	}
	p = tls_handshake_data(tls_record_data(record));
	tls_uint16array_to_bytes(enced_pms, enced_pms_len, &p, &len);
	if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Parse a ClientKeyExchange record whose body is exactly one
 * uint16-length-prefixed encrypted pre-master secret.
 * *enced_pms points into record (NULL if the vector is empty).
 * Returns 1 on success, -1 on error.
 */
int tls_record_get_handshake_client_key_exchange_pke(const uint8_t *record,
	const uint8_t **enced_pms, size_t *enced_pms_len)
{
	int type;
	const uint8_t *body;
	size_t bodylen;

	if (!record || !enced_pms || !enced_pms_len) {
		error_print();
		return -1;
	}
	if (tls_record_get_handshake(record, &type, &body, &bodylen) != 1
		|| type != TLS_handshake_client_key_exchange) {
		error_print();
		return -1;
	}
	if (tls_uint16array_from_bytes(enced_pms, enced_pms_len, &body, &bodylen) != 1
		|| tls_length_is_zero(bodylen) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Build a CertificateVerify record carrying sig as a uint16-length-prefixed
 * vector (siglen at most TLS_MAX_SIGNATURE_SIZE).
 * Returns 1 on success, -1 on invalid arguments or if the handshake header
 * cannot be set (previously that failure was ignored).
 */
int tls_record_set_handshake_certificate_verify(uint8_t *record, size_t *recordlen,
	const uint8_t *sig, size_t siglen)
{
	int type = TLS_handshake_certificate_verify;
	uint8_t *p;
	size_t len = 0;

	if (!record || !recordlen || !sig || !siglen) {
		error_print();
		return -1;
	}
	if (siglen > TLS_MAX_SIGNATURE_SIZE) {
		error_print();
		return -1;
	}
	p = tls_handshake_data(tls_record_data(record));
	tls_uint16array_to_bytes(sig, siglen, &p, &len);
	if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Parse a CertificateVerify record whose body is exactly one
 * uint16-length-prefixed signature; *sig points into record.
 * Returns 1 on success, -1 on error.
 */
int tls_record_get_handshake_certificate_verify(const uint8_t *record,
	const uint8_t **sig, size_t *siglen)
{
	int type;
	const uint8_t *body;
	size_t bodylen;

	if (!record || !sig || !siglen) {
		error_print();
		return -1;
	}
	if (tls_record_get_handshake(record, &type, &body, &bodylen) != 1
		|| type != TLS_handshake_certificate_verify) {
		error_print();
		return -1;
	}
	if (tls_uint16array_from_bytes(sig, siglen, &body, &bodylen) != 1
		|| tls_length_is_zero(bodylen) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Build a Finished record whose body is the raw verify_data (12 or 32 bytes).
 * Returns 1 on success, -1 on invalid arguments or if the handshake header
 * cannot be set (previously that failure was ignored).
 */
int tls_record_set_handshake_finished(uint8_t *record, size_t *recordlen,
	const uint8_t *verify_data, size_t verify_data_len)
{
	int type = TLS_handshake_finished;

	if (!record || !recordlen || !verify_data || !verify_data_len) {
		error_print();
		return -1;
	}
	if (verify_data_len != 12 && verify_data_len != 32) {
		error_print();
		return -1;
	}
	if (tls_record_set_handshake(record, recordlen, type, verify_data, verify_data_len) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Parse a Finished record; *verify_data points at its body inside record.
 * The body must be non-empty and 12 or 32 bytes long.
 * TODO: this should be changed to accept only the 12-byte TLS 1.2 length.
 * Returns 1 on success, -1 on error.
 */
int tls_record_get_handshake_finished(const uint8_t *record, const uint8_t **verify_data, size_t *verify_data_len)
{
	int type;

	if (!record || !verify_data || !verify_data_len) {
		error_print();
		return -1;
	}
	if (tls_record_get_handshake(record, &type, verify_data, verify_data_len) != 1) {
		error_print();
		return -1;
	}
	if (type != TLS_handshake_finished) {
		error_print();
		return -1;
	}
	if (*verify_data == NULL || *verify_data_len == 0) {
		error_print();
		return -1;
	}
	switch (*verify_data_len) {
	case 12:
	case 32:
		break;
	default:
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Build an Alert record: level(1) || description(1).
 * Record bytes 1..2 (protocol version) are not touched and must be set by
 * the caller. Returns -1 for an unknown level or description.
 */
int tls_record_set_alert(uint8_t *record, size_t *recordlen,
	int alert_level,
	int alert_description)
{
	if (record == NULL || recordlen == NULL) {
		error_print();
		return -1;
	}
	if (!tls_alert_level_name(alert_level)) {
		error_print();
		return -1;
	}
	if (!tls_alert_description_text(alert_description)) {
		error_print();
		return -1;
	}
	// header: type, version (left as is), length = 2
	record[0] = TLS_record_alert;
	record[3] = 0x00;
	record[4] = 0x02;
	// body
	record[5] = (uint8_t)alert_level;
	record[6] = (uint8_t)alert_description;
	*recordlen = 2 + TLS_RECORD_HEADER_SIZE;
	return 1;
}
/*
 * Parse an Alert record whose body must be exactly 2 bytes.
 * *alert_level and *alert_description are written before they are
 * validated; -1 is returned if either value is unknown.
 */
int tls_record_get_alert(const uint8_t *record,
	int *alert_level,
	int *alert_description)
{
	if (record == NULL || alert_level == NULL || alert_description == NULL) {
		error_print();
		return -1;
	}
	if (tls_record_type(record) != TLS_record_alert) {
		error_print();
		return -1;
	}
	if (!(record[3] == 0 && record[4] == 2)) {
		error_print();
		return -1;
	}
	*alert_level = record[5];
	*alert_description = record[6];
	if (!tls_alert_level_name(*alert_level)) {
		error_print();
		return -1;
	}
	if (!tls_alert_description_text(*alert_description)) {
		error_puts("warning");
		return -1;
	}
	return 1;
}
/*
 * Fill `record` with a ChangeCipherSpec record (one-byte payload).
 * The protocol version bytes record[1..2] are not written here.
 * Returns 1 on success, -1 on NULL arguments.
 */
int tls_record_set_change_cipher_spec(uint8_t *record, size_t *recordlen)
{
	if (record == NULL || recordlen == NULL) {
		error_print();
		return -1;
	}
	*recordlen = TLS_RECORD_HEADER_SIZE + 1;
	record[0] = TLS_record_change_cipher_spec;
	record[3] = 0x00;	// fragment length = 1
	record[4] = 0x01;
	record[5] = TLS_change_cipher_spec;
	return 1;
}
/*
 * Validate a ChangeCipherSpec record: correct content type, a 1-byte
 * fragment, and the expected payload byte.
 * Returns 1 if valid, -1 otherwise.
 */
int tls_record_get_change_cipher_spec(const uint8_t *record)
{
	if (record == NULL) {
		error_print();
		return -1;
	}
	if (tls_record_type(record) != TLS_record_change_cipher_spec
		|| record[3] != 0 || record[4] != 1
		|| record[5] != TLS_change_cipher_spec) {
		error_print();
		return -1;
	}
	return 1;
}
int tls_record_set_application_data(uint8_t *record, size_t *recordlen,
const uint8_t *data, size_t datalen)
{
if (!record || !recordlen || !data || !datalen) {
error_print();
return -1;
}
record[0] = TLS_record_application_data;
record[3] = (datalen >> 8) & 0xff;
record[4] = datalen & 0xff;
memcpy(tls_record_data(record), data, datalen);
*recordlen = TLS_RECORD_HEADER_SIZE + datalen;
return 1;
}
/*
 * Locate the fragment of an ApplicationData record.
 *
 * *datalen is taken from the header length field; *data points just past the
 * header, or is NULL when the fragment is empty. No data is copied.
 * Returns 1 on success, -1 on error.
 */
int tls_record_get_application_data(uint8_t *record,
	const uint8_t **data, size_t *datalen)
{
	size_t fragment_len;

	if (record == NULL || data == NULL || datalen == NULL) {
		error_print();
		return -1;
	}
	if (tls_record_type(record) != TLS_record_application_data) {
		error_print();
		return -1;
	}
	fragment_len = ((size_t)record[3] << 8) | record[4];
	*datalen = fragment_len;
	if (fragment_len > 0) {
		*data = record + TLS_RECORD_HEADER_SIZE;
	} else {
		*data = NULL;
	}
	return 1;
}
/*
 * Return 1 if `type` appears among the first `list_count` entries of `list`,
 * 0 otherwise. `list` is not dereferenced when list_count is 0.
 */
int tls_type_is_in_list(int type, const int *list, size_t list_count)
{
	while (list_count--) {
		if (*list++ == type) {
			return 1;
		}
	}
	return 0;
}
/*
 * Cipher suites associated with each protocol version. These tables are
 * selected by protocol in tls_cipher_suite_support_protocol().
 */
// TLCP cipher suites
static const int tlcp_ciphers[] = {
	TLS_cipher_ecc_sm4_cbc_sm3,
	TLS_cipher_ecc_sm4_gcm_sm3,
	TLS_cipher_ibc_sm4_cbc_sm3,
	TLS_cipher_ibc_sm4_gcm_sm3,
};
// TLS 1.2 cipher suites
static const int tls12_ciphers[] = {
	TLS_cipher_ecdhe_sm4_cbc_sm3,
	TLS_cipher_ecdhe_sm4_gcm_sm3,
	TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256,
};
// TLS 1.3 cipher suites
static const int tls13_ciphers[] = {
	TLS_cipher_sm4_gcm_sm3,
};
/*
 * Check whether `cipher` may be used with `protocol`.
 *
 * Returns -1 for an unknown protocol, otherwise 1.
 * NOTE(review): the membership test against the per-protocol table is
 * commented out. As a result, any cipher is currently accepted for a known
 * protocol, and `ciphers` / `ciphers_cnt` are computed but unused. The
 * disabled code calls tls_cipher_suite_in_list(); tls_type_is_in_list()
 * in this file has the same shape if the check is re-enabled. Confirm the
 * callers before turning it back on.
 */
int tls_cipher_suite_support_protocol(int cipher, int protocol)
{
	const int *ciphers;
	size_t ciphers_cnt;
	switch (protocol) {
	case TLS_protocol_tlcp:
		ciphers = tlcp_ciphers;
		ciphers_cnt = sizeof(tlcp_ciphers)/sizeof(tlcp_ciphers[0]);
		break;
	case TLS_protocol_tls12:
		ciphers = tls12_ciphers;
		ciphers_cnt = sizeof(tls12_ciphers)/sizeof(tls12_ciphers[0]);
		break;
	case TLS_protocol_tls13:
		ciphers = tls13_ciphers;
		ciphers_cnt = sizeof(tls13_ciphers)/sizeof(tls13_ciphers[0]);
		break;
	default:
		error_print();
		return -1;
	}
	/*
	if (!tls_cipher_suite_in_list(cipher, ciphers, ciphers_cnt)) {
		error_print();
		return 0;
	}
	*/
	return 1;
}
/*
 * Send one complete TLS record over `sock`.
 *
 * The length in the record header must equal `recordlen`. The loop calls
 * tls_socket_send() until every byte is written. On EAGAIN/EWOULDBLOCK it
 * calls tls_socket_wait() and retries instead of returning to the caller.
 *
 * Returns 1 on success, 0 if the connection is reported closed, -1 on error.
 *
 * Design notes (translated from the original comments):
 *   - Send as much data as possible, until the whole record is sent or
 *     send() returns an error.
 *   - If send() returns EAGAIN, the intent is to report WANT_WRITE to the
 *     upper layer. This is not implemented: the loop waits and retries.
 *   - Normally a side can always send any amount of data. When the send
 *     buffer is full, send() returns EAGAIN, and if the lower layer never
 *     drains it there is nothing more this function can do.
 *   - If this function ever returns to the upper layer after EAGAIN, it
 *     must also report how much data was actually sent.
 */
int tls_record_send(const uint8_t *record, size_t recordlen, tls_socket_t sock)
{
	tls_ret_t n;
	if (!record) {
		error_print();
		return -1;
	}
	if (recordlen < TLS_RECORD_HEADER_SIZE) {
		error_print();
		return -1;
	}
	// the header length field must describe exactly `recordlen` bytes
	if (tls_record_length(record) != recordlen) {
		error_print();
		return -1;
	}
	while (recordlen) {
		if ((n = tls_socket_send(sock, record, recordlen, 0)) > 0) {
			record += n;
			recordlen -= n;
		} else if (n == 0) {
			// NOTE(review): a send() result of 0 does not necessarily set errno,
			// so this check may see a stale value; confirm intended behavior.
			if (errno == EAGAIN || errno == EWOULDBLOCK) {
				tls_socket_wait();
			} else {
				error_puts("TCP connection closed");
				return 0;
			}
		} else {
			if (errno == EAGAIN || errno == EWOULDBLOCK) {
				tls_socket_wait();
			} else {
				error_print();
				return -1;
			}
		}
	}
	return 1;
}
/*
 * Receive one complete TLS record from `sock` into `record`.
 *
 * First reads the 5-byte header and validates the content type and protocol
 * version. It then reads the fragment whose length is in header bytes 3-4.
 * A total length above TLS_MAX_RECORD_SIZE is rejected, so `record`
 * presumably needs room for TLS_MAX_RECORD_SIZE bytes (confirm with callers).
 *
 * Returns:
 *   1        a full record was read; *recordlen holds its total length
 *   0        the peer closed the connection; *recordlen is set to 0
 *   -EAGAIN  no header byte was available yet (non-blocking socket)
 *   -1       any other error
 * Once any part of a record has arrived, EAGAIN/EWOULDBLOCK is handled by
 * tls_socket_wait() and retrying, so a partial record is never returned.
 */
int tls_record_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock)
{
	uint8_t *p = record;
	size_t len;
	tls_ret_t n;
	// read the fixed-size record header
	len = 5;
	while (len) {
		if ((n = tls_socket_recv(sock, p, len, 0)) > 0) {
			p += n;
			len -= n;
		} else if (n == 0) {
			tls_trace("TCP connection closed");
			*recordlen = 0;
			return 0;
		} else {
			if (errno == EAGAIN || errno == EWOULDBLOCK) {
				// nothing received yet: let the caller retry later
				if (len == 5) {
					return -EAGAIN;
				}
				tls_socket_wait();
			} else {
				perror("recv");
				error_print();
				return -1;
			}
		}
	}
	if (!tls_record_type_name(tls_record_type(record))) {
		error_print();
		return -1;
	}
	if (!tls_protocol_name(tls_record_protocol(record))) {
		error_print();
		return -1;
	}
	// fragment length from the header, big-endian
	len = (size_t)record[3] << 8 | record[4];
	*recordlen = 5 + len;
	if (*recordlen > TLS_MAX_RECORD_SIZE) {
		error_print();
		return -1;
	}
	// read the fragment right after the header
	while (len) {
		if ((n = tls_socket_recv(sock, p, len, 0)) > 0) {
			p += n;
			len -= n;
		} else if (n == 0) {
			tls_trace("connection closed");
			*recordlen = 0;
			return 0;
		} else {
			if (errno == EAGAIN || errno == EWOULDBLOCK) {
				tls_socket_wait();
			} else {
				perror("recv");
				error_print();
				return -1;
			}
		}
	}
	return 1;
}
/*
 * Increment the 64-bit big-endian record sequence number in place.
 *
 * Returns 1 on success. If seq_num is already 2^64 - 1, it returns -1 and
 * leaves seq_num unchanged. TLS sequence numbers must not wrap, so the
 * caller has to stop using the current keys in that case.
 */
int tls_seq_num_incr(uint8_t seq_num[8])
{
	int i;

	// refuse to wrap around: every byte already 0xff means the maximum value
	for (i = 0; i < 8; i++) {
		if (seq_num[i] != 0xff) {
			break;
		}
	}
	if (i == 8) {
		return -1;
	}
	// propagate the carry through every byte, including the most significant one
	for (i = 7; i >= 0; i--) {
		seq_num[i]++;
		if (seq_num[i]) {
			break;
		}
	}
	return 1;
}
/* Set the 8-byte record sequence number back to zero. */
void tls_seq_num_reset(uint8_t seq_num[8])
{
	size_t k;
	for (k = 0; k < 8; k++) {
		seq_num[k] = 0;
	}
}
/*
 * Check that the offered compression methods include null compression.
 * Returns 1 if TLS_compression_null is present; -1 if the list is empty,
 * NULL, or does not contain it.
 */
int tls_compression_methods_has_null_compression(const uint8_t *meths, size_t methslen)
{
	size_t idx;

	if (meths == NULL || methslen == 0) {
		error_print();
		return -1;
	}
	for (idx = 0; idx < methslen; idx++) {
		if (meths[idx] == TLS_compression_null) {
			return 1;
		}
	}
	error_print();
	return -1;
}
/*
 * Send a plaintext fatal-level Alert record with description `alert`.
 *
 * TLS 1.3 connections put the TLS 1.2 version in the record header.
 * Returns 1 on success, -1 on error (including an unknown alert
 * description, which previously went unchecked and produced a partly
 * uninitialized record on the wire).
 */
int tls_send_alert(TLS_CONNECT *conn, int alert)
{
	uint8_t record[5 + 2];
	size_t recordlen;

	if (!conn) {
		error_print();
		return -1;
	}
	if (tls_record_set_protocol(record, conn->protocol == TLS_protocol_tls13 ? TLS_protocol_tls12 : conn->protocol) != 1
		|| tls_record_set_alert(record, &recordlen, TLS_alert_level_fatal, alert) != 1) {
		error_print();
		return -1;
	}
	if (tls_record_send(record, recordlen, conn->sock) != 1) {
		error_print();
		return -1;
	}
	tls_record_trace(stderr, record, recordlen, 0, 0);
	return 1;
}
/*
 * Map an alert description to the level it is sent with.
 * Returns TLS_alert_level_warning for user_canceled and no_renegotiation,
 * TLS_alert_level_fatal for the listed error alerts, and
 * TLS_alert_level_undefined for anything else (including close_notify).
 */
int tls_alert_level(int alert)
{
	int level;

	switch (alert) {
	case TLS_alert_user_canceled:
	case TLS_alert_no_renegotiation:
		level = TLS_alert_level_warning;
		break;
	case TLS_alert_unexpected_message:
	case TLS_alert_bad_record_mac:
	case TLS_alert_record_overflow:
	case TLS_alert_decompression_failure:
	case TLS_alert_handshake_failure:
	case TLS_alert_illegal_parameter:
	case TLS_alert_unknown_ca:
	case TLS_alert_access_denied:
	case TLS_alert_decode_error:
	case TLS_alert_decrypt_error:
	case TLS_alert_protocol_version:
	case TLS_alert_insufficient_security:
	case TLS_alert_internal_error:
	case TLS_alert_unsupported_extension:
		level = TLS_alert_level_fatal;
		break;
	default:
		level = TLS_alert_level_undefined;
		break;
	}
	return level;
}
/*
 * Send a plaintext warning-level Alert record with description `alert`.
 *
 * Alerts that tls_alert_level() classifies as fatal are refused. TLS 1.3
 * connections put the TLS 1.2 version in the record header.
 * Returns 1 on success, -1 on error. The builders' results are checked
 * because on failure the record would be left partly uninitialized.
 */
int tls_send_warning(TLS_CONNECT *conn, int alert)
{
	uint8_t record[5 + 2];
	size_t recordlen;

	if (!conn) {
		error_print();
		return -1;
	}
	if (tls_alert_level(alert) == TLS_alert_level_fatal) {
		error_print();
		return -1;
	}
	if (tls_record_set_protocol(record, conn->protocol == TLS_protocol_tls13 ? TLS_protocol_tls12 : conn->protocol) != 1
		|| tls_record_set_alert(record, &recordlen, TLS_alert_level_warning, alert) != 1) {
		error_print();
		return -1;
	}
	if (tls_record_send(record, recordlen, conn->sock) != 1) {
		error_print();
		return -1;
	}
	tls_record_trace(stderr, record, recordlen, 0, 0);
	return 1;
}
/*
 * Build a record of type `record_type` from `in`, encrypt it with this
 * side's write keys (client or server), and send it.
 *
 * At most TLS_MAX_PLAINTEXT_SIZE bytes of `in` are used per call. The
 * number actually consumed is returned in *sentlen. conn->databuf holds
 * the plaintext record and conn->record the ciphertext. Sending is refused
 * while received data is still buffered (conn->datalen != 0), because that
 * data lives in conn->databuf.
 *
 * Returns 1 on success, -1 on error.
 */
static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *in, size_t inlen, size_t *sentlen)
{
	const SM3_HMAC_CTX *hmac_ctx;
	const SM4_KEY *enc_key;
	uint8_t *seq_num;
	size_t recordlen;
	if (!conn) {
		error_print();
		return -1;
	}
	if (!in || !inlen || !sentlen) {
		error_print();
		return -1;
	}
	// send at most one record's worth of plaintext
	if (inlen > TLS_MAX_PLAINTEXT_SIZE) {
		inlen = TLS_MAX_PLAINTEXT_SIZE;
	}
	if (conn->datalen) {
		error_puts("recv all buffered data before send");
		return -1;
	}
	// select this side's write state
	if (conn->is_client) {
		hmac_ctx = &conn->client_write_mac_ctx;
		enc_key = &conn->client_write_enc_key;
		seq_num = conn->client_seq_num;
	} else {
		hmac_ctx = &conn->server_write_mac_ctx;
		enc_key = &conn->server_write_enc_key;
		seq_num = conn->server_seq_num;
	}
	if (tls_record_set_type(conn->databuf, record_type) != 1
		|| tls_record_set_protocol(conn->databuf, conn->protocol) != 1
		|| tls_record_set_data(conn->databuf, in, inlen) != 1) {
		error_print();
		return -1;
	}
	tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0);
	if (tls_record_encrypt(hmac_ctx, enc_key, seq_num,
		conn->databuf, tls_record_length(conn->databuf),
		conn->record, &recordlen) != 1) {
		error_print();
		return -1;
	}
	// NOTE(review): the return value of tls_seq_num_incr() is not checked here
	tls_seq_num_incr(seq_num);
	if (tls_record_send(conn->record, recordlen, conn->sock) != 1) {
		error_print();
		return -1;
	}
	tls_encrypted_record_trace(stderr, conn->record, recordlen, 0, 0);
	*sentlen = inlen;
	return 1;
}
/*
 * Receive one encrypted record, decrypt it with the peer's write keys, and
 * expose the plaintext fragment via conn->data / conn->datalen.
 *
 * Returns the tls_record_recv() result (0 on close, -EAGAIN, or -1) if no
 * record could be read, -1 if decryption fails, and 1 on success.
 * NOTE(review): conn is dereferenced without a NULL check.
 */
int tls_decrypt_recv(TLS_CONNECT *conn)
{
	int ret;
	const SM3_HMAC_CTX *hmac_ctx;
	const SM4_KEY *dec_key;
	uint8_t *seq_num;
	uint8_t *record = conn->record;
	size_t recordlen;
	// decrypt with the peer's write state
	if (conn->is_client) {
		hmac_ctx = &conn->server_write_mac_ctx;
		dec_key = &conn->server_write_enc_key;
		seq_num = conn->server_seq_num;
	} else {
		hmac_ctx = &conn->client_write_mac_ctx;
		dec_key = &conn->client_write_enc_key;
		seq_num = conn->client_seq_num;
	}
	tls_trace("recv Encrypted Record\n");
	if ((ret = tls_record_recv(record, &recordlen, conn->sock)) != 1) {
		if (ret < 0 && ret != -EAGAIN) error_print();
		return ret;
	}
	tls_encrypted_record_trace(stderr, record, recordlen, 0, 0);
	// conn->datalen first receives the decrypted record length...
	if (tls_record_decrypt(hmac_ctx, dec_key, seq_num,
		record, recordlen,
		conn->databuf, &conn->datalen) != 1) {
		error_print();
		return -1;
	}
	tls_seq_num_incr(seq_num);
	// ...and is then replaced by the length of the fragment alone
	conn->data = tls_record_data(conn->databuf);
	conn->datalen = tls_record_data_length(conn->databuf);
	tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0);
	return 1;
}
/*
 * Encrypt and send one ApplicationData record from `in`.
 * At most TLS_MAX_PLAINTEXT_SIZE bytes are consumed per call; the count is
 * reported through *sentlen. Returns 1 on success, -1 on error.
 */
int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen)
{
	int rv;

	tls_trace("send ApplicationData\n");
	rv = tls_encrypt_send(conn, TLS_record_application_data, in, inlen, sentlen);
	return rv;
}
/*
 * Read decrypted application data into `out`.
 *
 * If no plaintext is buffered, one record is received and decrypted first.
 * Up to `outlen` bytes are copied and *recvlen is set to the count. Any
 * remaining application data stays buffered in conn for the next call.
 * The payload of a non-application record is discarded. Otherwise it would
 * later be returned as application data, and tls_encrypt_send() would
 * refuse to send while it is buffered.
 *
 * Returns 1 when data was copied, 0 on close_notify or transport close,
 * -EAGAIN if no record is available on a non-blocking socket, and -1 on
 * error, a malformed alert, or any other alert.
 */
int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen)
{
	if (!conn || !out || !outlen || !recvlen) {
		error_print();
		return -1;
	}
	if (conn->datalen == 0) {
		int ret;
		int type;
		if ((ret = tls_decrypt_recv(conn)) != 1) {
			if (ret < 0 && ret != -EAGAIN) error_print();
			return ret;
		}
		type = tls_record_type(conn->record);
		if (type != TLS_record_application_data) {
			// never hand out a control record's payload as application data
			conn->datalen = 0;
		}
		switch (type) {
		case TLS_record_application_data:
			break;
		case TLS_record_change_cipher_spec:
			error_print();
			return -1;
		case TLS_record_alert:
			{
				// should call tls_process_alert()
				int level;
				int alert;
				// level/alert are only meaningful if the alert record parses
				if (tls_record_get_alert(conn->databuf, &level, &alert) != 1) {
					error_print();
					return -1;
				}
				if (alert == TLS_alert_close_notify) {
					tls_trace("recv Alert.close_notify\n");
					return 0;
				}
				tls_trace("alert received\n");
				return -1;
			}
		default:
			error_print();
			return -1;
		}
	}
	*recvlen = outlen <= conn->datalen ? outlen : conn->datalen;
	memcpy(out, conn->data, *recvlen);
	conn->data += *recvlen;
	conn->datalen -= *recvlen;
	return 1;
}
/*
 * Send an encrypted close_notify alert, then wait for one record from the
 * peer.
 *
 * close_notify is a closure alert and is sent at warning level.
 * Conventional TLS stacks treat a fatal-level alert as a connection error
 * rather than an orderly shutdown.
 *
 * Returns 1 once a record has been received after sending close_notify,
 * and -1 otherwise. This includes the peer closing the transport without
 * replying.
 */
int tls_shutdown(TLS_CONNECT *conn)
{
	int ret;
	size_t sentlen;
	uint8_t alert[2];

	if (!conn) {
		error_print();
		return -1;
	}
	alert[0] = TLS_alert_level_warning;
	alert[1] = TLS_alert_close_notify;

	tls_trace("send Alert.close_notify\n");
	if (tls_encrypt_send(conn, TLS_record_alert, alert, sizeof(alert), &sentlen) != 1) {
		error_print();
		return -1;
	}
	tls_trace("recv Alert.close_notify\n");
	if ((ret = tls_decrypt_recv(conn)) != 1) {
		if (ret == 0) tls_trace("Connection closed by remote without close_notify\n");
		else if (ret == -EAGAIN) tls_trace("-EAGAIN\n");
		else error_print();
		return -1;
	}
	return 1;
}
/*
 * Encode the subject names of every certificate in `certs` as a list of
 * uint16-length-prefixed DER Names (e.g. for a certificate_authorities list).
 *
 * `names` receives the encoding and *nameslen its total length. `maxlen` is
 * the capacity of `names`. Each entry consumes the 2-byte length prefix plus
 * the DER name, and both are charged against `maxlen`. Previously only the
 * DER part was subtracted, which allowed `names` to overflow.
 *
 * Returns 1 on success, -1 on parse error or insufficient space.
 */
int tls_authorities_from_certs(uint8_t *names, size_t *nameslen, size_t maxlen, const uint8_t *certs, size_t certslen)
{
	const uint8_t *cert;
	size_t certlen;
	const uint8_t *name;
	size_t namelen;

	if (!nameslen) {
		error_print();
		return -1;
	}
	*nameslen = 0;
	while (certslen) {
		size_t alen = 0;
		size_t entry_len;

		if (x509_cert_from_der(&cert, &certlen, &certs, &certslen) != 1
			|| x509_cert_get_subject(cert, certlen, &name, &namelen) != 1
			|| asn1_sequence_to_der(name, namelen, NULL, &alen) != 1) {
			error_print();
			return -1;
		}
		if (alen > UINT16_MAX) {
			error_print();
			return -1;
		}
		entry_len = tls_uint16_size() + alen;
		if (entry_len > maxlen) {
			error_print();
			return -1;
		}
		tls_uint16_to_bytes((uint16_t)alen, &names, nameslen);
		if (asn1_sequence_to_der(name, namelen, &names, nameslen) != 1) {
			error_print();
			return -1;
		}
		maxlen -= entry_len;
	}
	return 1;
}
/*
 * Check whether the issuer of the last certificate in `certs` matches one of
 * the uint16-length-prefixed DER Names in `ca_names`.
 *
 * Returns 1 on a match, 0 if none matches, -1 on parse error.
 */
int tls_authorities_issued_certificate(const uint8_t *ca_names, size_t ca_names_len, const uint8_t *certs, size_t certslen)
{
	const uint8_t *last_cert;
	size_t last_cert_len;
	const uint8_t *issuer;
	size_t issuer_len;

	if (x509_certs_get_last(certs, certslen, &last_cert, &last_cert_len) != 1
		|| x509_cert_get_issuer(last_cert, last_cert_len, &issuer, &issuer_len) != 1) {
		error_print();
		return -1;
	}

	while (ca_names_len > 0) {
		const uint8_t *entry;
		size_t entry_len;
		const uint8_t *dn;
		size_t dn_len;

		// each entry must hold exactly one DER Name
		if (tls_uint16array_from_bytes(&entry, &entry_len, &ca_names, &ca_names_len) != 1
			|| asn1_sequence_from_der(&dn, &dn_len, &entry, &entry_len) != 1
			|| asn1_length_is_zero(entry_len) != 1) {
			error_print();
			return -1;
		}
		if (x509_name_equ(dn, dn_len, issuer, issuer_len) == 1) {
			return 1;
		}
	}
	error_print();
	return 0;
}
/*
 * Check whether the certificate type derived from OID_sm2sign_with_sm3 is
 * among the `types` list.
 *
 * Returns 1 if accepted, 0 if not, -1 if client_certs has no first
 * certificate or the type cannot be derived.
 * NOTE(review): the first client certificate is fetched but not inspected;
 * the type always comes from OID_sm2sign_with_sm3. Confirm whether it
 * should be derived from the certificate itself.
 */
int tls_cert_types_accepted(const uint8_t *types, size_t types_len, const uint8_t *client_certs, size_t client_certs_len)
{
	const uint8_t *leaf;
	size_t leaf_len;
	int cert_type;
	size_t k;

	if (x509_certs_get_cert_by_index(client_certs, client_certs_len, 0, &leaf, &leaf_len) != 1) {
		error_print();
		return -1;
	}
	cert_type = tls_cert_type_from_oid(OID_sm2sign_with_sm3);
	if (cert_type < 0) {
		error_print();
		return -1;
	}
	for (k = 0; k != types_len; k++) {
		if (types[k] == cert_type) {
			return 1;
		}
	}
	return 0;
}
/* Zero a TLS_CLIENT_VERIFY_CTX. Returns 1 on success, -1 if ctx is NULL. */
int tls_client_verify_init(TLS_CLIENT_VERIFY_CTX *ctx)
{
	if (ctx == NULL) {
		error_print();
		return -1;
	}
	memset(ctx, 0, sizeof(*ctx));
	return 1;
}
// FIXME: remove malloc!
/*
 * Append a heap copy of one handshake message to ctx. At most 8 messages
 * are stored; they are consumed by tls_client_verify_finish() and released
 * by tls_client_verify_cleanup(). Returns 1 on success, -1 on error.
 */
int tls_client_verify_update(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *handshake, size_t handshake_len)
{
	uint8_t *copy;
	int slot;

	if (ctx == NULL || handshake == NULL || handshake_len == 0) {
		error_print();
		return -1;
	}
	slot = ctx->index;
	if (slot < 0 || slot >= 8) {
		error_print();
		return -1;
	}
	copy = malloc(handshake_len);
	if (copy == NULL) {
		error_print();
		return -1;
	}
	memcpy(copy, handshake, handshake_len);
	ctx->handshake[slot] = copy;
	ctx->handshake_len[slot] = handshake_len;
	ctx->index = slot + 1;
	return 1;
}
/*
 * Verify the client's SM2 signature over the 8 buffered handshake messages.
 *
 * Returns the sm2_verify_finish() result when it is non-negative, and -1 on
 * bad arguments, a wrong message count, or a verification error.
 */
int tls_client_verify_finish(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *sig, size_t siglen, const SM2_KEY *public_key)
{
	SM2_VERIFY_CTX vctx;
	int n;
	int rv;

	if (!ctx || !sig || !siglen || !public_key) {
		error_print();
		return -1;
	}
	if (ctx->index != 8) {
		error_print();
		return -1;
	}
	// SM2 verification takes Z (computed from the public key and signer ID) as
	// input, so it cannot start until the client's public key is known.
	if (sm2_verify_init(&vctx, public_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1) {
		error_print();
		return -1;
	}
	for (n = 0; n < 8; n++) {
		if (sm2_verify_update(&vctx, ctx->handshake[n], ctx->handshake_len[n]) != 1) {
			error_print();
			return -1;
		}
	}
	rv = sm2_verify_finish(&vctx, sig, siglen);
	if (rv < 0) {
		error_print();
		return -1;
	}
	return rv;
}
/* Free every buffered handshake copy in ctx. Safe to call with NULL. */
void tls_client_verify_cleanup(TLS_CLIENT_VERIFY_CTX *ctx)
{
	int slot;

	if (ctx == NULL) {
		return;
	}
	for (slot = 0; slot < ctx->index; slot++) {
		if (ctx->handshake[slot] == NULL) {
			continue;
		}
		free(ctx->handshake[slot]);
		ctx->handshake[slot] = NULL;
		ctx->handshake_len[slot] = 0;
	}
}
/*
 * Choose a cipher suite, with server preference: the first entry of
 * `server_ciphers` that also appears in the client's uint16 list wins.
 *
 * Returns 1 and sets *selected_cipher on a match, 0 if there is no common
 * suite, and -1 on bad arguments or a malformed client list.
 */
int tls_cipher_suites_select(const uint8_t *client_ciphers, size_t client_ciphers_len,
	const int *server_ciphers, size_t server_ciphers_cnt,
	int *selected_cipher)
{
	size_t i;

	if (!client_ciphers || !client_ciphers_len
		|| !server_ciphers || !server_ciphers_cnt || !selected_cipher) {
		error_print();
		return -1;
	}
	for (i = 0; i < server_ciphers_cnt; i++) {
		const uint8_t *cp = client_ciphers;
		size_t left = client_ciphers_len;

		while (left > 0) {
			uint16_t offered;
			if (tls_uint16_from_bytes(&offered, &cp, &left) != 1) {
				error_print();
				return -1;
			}
			if (offered == server_ciphers[i]) {
				*selected_cipher = server_ciphers[i];
				return 1;
			}
		}
	}
	return 0;
}
/*
 * Release resources owned by ctx and wipe its key material.
 *
 * The key fields are cleared with their own sizes. They are X509_KEY
 * values, so the previous sizeof(SM2_KEY) clear did not match the field.
 * The x509_keys[] array, which also holds private keys, is wiped too.
 * Safe to call with NULL.
 */
void tls_ctx_cleanup(TLS_CTX *ctx)
{
	if (!ctx) {
		return;
	}
	gmssl_secure_clear(&ctx->signkey, sizeof(ctx->signkey));
	gmssl_secure_clear(&ctx->kenckey, sizeof(ctx->kenckey));
	gmssl_secure_clear(ctx->x509_keys, sizeof(ctx->x509_keys));
	free(ctx->certs);
	free(ctx->cacerts);
	memset(ctx, 0, sizeof(TLS_CTX));
}
/* Placeholder: printing a TLS_CTX is not implemented; always returns 0. */
int tls_ctx_print(FILE *fp, int fmt, int ind, const char *label, const TLS_CTX *ctx)
{
	(void)fp;
	(void)fmt;
	(void)ind;
	(void)label;
	(void)ctx;
	return 0;
}
/*
 * Zero ctx and set defaults for `protocol` (TLCP, TLS 1.2 or TLS 1.3).
 *
 * Defaults: supported_versions = {TLS 1.3, TLS 1.2, TLCP}, verify depth 5,
 * and 2 ClientHello key exchanges (see
 * tls13_ctx_set_client_hello_key_exchanges_cnt). ctx is zeroed before the
 * protocol is validated. Returns 1 on success, -1 on error.
 */
int tls_ctx_init(TLS_CTX *ctx, int protocol, int is_client)
{
	const int default_versions[] = {
		TLS_protocol_tls13,
		TLS_protocol_tls12,
		TLS_protocol_tlcp,
	};

	if (ctx == NULL) {
		error_print();
		return -1;
	}
	memset(ctx, 0, sizeof(*ctx));

	if (protocol != TLS_protocol_tlcp
		&& protocol != TLS_protocol_tls12
		&& protocol != TLS_protocol_tls13) {
		error_print();
		return -1;
	}
	ctx->protocol = protocol;
	ctx->is_client = (is_client != 0);

	memcpy(ctx->supported_versions, default_versions, sizeof(default_versions));
	ctx->supported_versions_cnt = sizeof(default_versions) / sizeof(default_versions[0]);
	ctx->verify_depth = 5;
	ctx->key_exchanges_cnt = 2;
	return 1;
}
/*
 * Replace ctx's supported_versions list. Only TLS 1.3, TLS 1.2 and TLCP are
 * allowed. Entries before an invalid one have already been written when -1
 * is returned, but supported_versions_cnt is updated only on success.
 */
int tls_ctx_set_supported_versions(TLS_CTX *ctx, const int *versions, size_t versions_cnt)
{
	size_t n;

	if (ctx == NULL || versions == NULL || versions_cnt == 0) {
		error_print();
		return -1;
	}
	if (versions_cnt > sizeof(ctx->supported_versions) / sizeof(ctx->supported_versions[0])) {
		error_print();
		return -1;
	}
	for (n = 0; n < versions_cnt; n++) {
		int v = versions[n];
		if (v != TLS_protocol_tls13 && v != TLS_protocol_tls12 && v != TLS_protocol_tlcp) {
			error_print();
			return -1;
		}
		ctx->supported_versions[n] = v;
	}
	ctx->supported_versions_cnt = versions_cnt;
	return 1;
}
/*
 * Replace ctx's cipher suite list. Every suite must have a known name
 * (tls_cipher_suite_name). cipher_suites_cnt is updated only on success.
 */
int tls_ctx_set_cipher_suites(TLS_CTX *ctx, const int *cipher_suites, size_t cipher_suites_cnt)
{
	size_t n;

	if (ctx == NULL || cipher_suites == NULL || cipher_suites_cnt == 0) {
		error_print();
		return -1;
	}
	if (cipher_suites_cnt > sizeof(ctx->cipher_suites) / sizeof(ctx->cipher_suites[0])) {
		error_print();
		return -1;
	}
	for (n = 0; n < cipher_suites_cnt; n++) {
		int suite = cipher_suites[n];
		if (!tls_cipher_suite_name(suite)) {
			error_print();
			return -1;
		}
		ctx->cipher_suites[n] = suite;
	}
	ctx->cipher_suites_cnt = cipher_suites_cnt;
	return 1;
}
/*
 * Set the key exchange mode bitmask. Only TLS_KE_CERT_DHE, TLS_KE_PSK_DHE
 * and TLS_KE_PSK bits may be set. Returns 1 on success, -1 on error.
 */
int tls_ctx_set_key_exchange_modes(TLS_CTX *ctx, int modes)
{
	const int allowed = TLS_KE_CERT_DHE | TLS_KE_PSK_DHE | TLS_KE_PSK;

	if (ctx == NULL) {
		error_print();
		return -1;
	}
	if ((modes & ~allowed) != 0) {
		error_print();
		return -1;
	}
	ctx->key_exchange_modes = modes;
	return 1;
}
// NOTE: taking a file name directly is not an ideal interface.
/*
 * Load trusted CA certificates from `cacertsfile` into ctx and set the
 * verification depth (0..TLS_MAX_VERIFY_DEPTH).
 *
 * Must be called at most once per ctx (ctx->cacerts must be NULL). If the
 * file holds no certificates, the buffer is released and ctx->cacerts is
 * reset. Otherwise the ctx->cacerts check would reject every later call.
 * Returns 1 on success, -1 on error.
 */
int tls_ctx_set_ca_certificates(TLS_CTX *ctx, const char *cacertsfile, int depth)
{
	if (!ctx || !cacertsfile) {
		error_print();
		return -1;
	}
	if (depth < 0 || depth > TLS_MAX_VERIFY_DEPTH) {
		error_print();
		return -1;
	}
	if (!tls_protocol_name(ctx->protocol)) {
		error_print();
		return -1;
	}
	if (ctx->cacerts) {
		error_print();
		return -1;
	}
	if (x509_certs_new_from_file(&ctx->cacerts, &ctx->cacertslen, cacertsfile) != 1) {
		error_print();
		return -1;
	}
	if (ctx->cacertslen == 0) {
		free(ctx->cacerts);
		ctx->cacerts = NULL;
		error_print();
		return -1;
	}
	ctx->verify_depth = depth;
	return 1;
}
/*
 * Append a certificate chain and its private key to ctx.
 *
 * The PEM chain from `chainfile` is stored in ctx->cert_chains as a uint24
 * length prefix followed by the DER certificates. The private key is read
 * with the first certificate's public-key algorithm and stored in
 * ctx->x509_keys[ctx->x509_keys_cnt]. Both are committed only if every step
 * succeeds, so a failure leaves ctx->cert_chains_len and ctx->x509_keys_cnt
 * unchanged and both files closed.
 *
 * Open questions from the original notes:
 *   - this should also establish a default certificate chain;
 *   - no status_request / sct_list data is attached to the chain;
 *   - it is undecided whether a cert_chain entry is a bare chain or also
 *     carries extensions.
 *
 * Returns 1 on success, -1 on error.
 */
int tls_ctx_add_certificate_chain_and_key(TLS_CTX *ctx, const char *chainfile,
	const char *keyfile, const char *keypass)
{
	int ret = -1;
	uint8_t *cert_chain;
	size_t cert_chain_len = 0;
	size_t prefix_len = 0;
	FILE *certfp = NULL;
	FILE *keyfp = NULL;
	const uint8_t *cert;
	size_t certlen;
	X509_KEY public_key;

	if (!ctx || !chainfile || !keyfile || !keypass) {
		error_print();
		return -1;
	}
	// no space in ctx->cert_chains[]
	if (sizeof(ctx->cert_chains) <= ctx->cert_chains_len + tls_uint24_size()) {
		error_print();
		return -1;
	}
	// no space in ctx->x509_keys[]
	if (sizeof(ctx->x509_keys)/sizeof(ctx->x509_keys[0]) <= ctx->x509_keys_cnt) {
		error_print();
		return -1;
	}
	if (!(certfp = fopen(chainfile, "r"))) {
		error_print();
		return -1;
	}

	// parse the chain into the free space right after the reserved uint24 prefix
	cert_chain = ctx->cert_chains + ctx->cert_chains_len;
	if (x509_certs_from_pem(cert_chain + tls_uint24_size(), &cert_chain_len,
		sizeof(ctx->cert_chains) - ctx->cert_chains_len - tls_uint24_size(),
		certfp) != 1) {
		error_print();
		goto end;
	}
	if (cert_chain_len > 0xffffff) {
		error_print();
		goto end;
	}
	if (x509_certs_get_cert_by_index(cert_chain + tls_uint24_size(), cert_chain_len, 0, &cert, &certlen) != 1
		|| x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) {
		error_print();
		goto end;
	}

	// EC keys are opened read-only; other key types are opened read/write ("rb+")
	keyfp = fopen(keyfile, public_key.algor == OID_ec_public_key ? "r" : "rb+");
	if (!keyfp) {
		error_print();
		goto end;
	}
	if (x509_private_key_from_file(&ctx->x509_keys[ctx->x509_keys_cnt], public_key.algor, keypass, keyfp) != 1) {
		x509_key_cleanup(&ctx->x509_keys[ctx->x509_keys_cnt]);
		error_print();
		goto end;
	}

	// commit: write the uint24 length prefix and account for chain and key
	tls_uint24_to_bytes((uint24_t)cert_chain_len, &cert_chain, &prefix_len);
	ctx->cert_chains_len += prefix_len + cert_chain_len;
	ctx->x509_keys_cnt++;
	ret = 1;

end:
	if (keyfp) fclose(keyfp);
	fclose(certfp);
	return ret;
}
// Kept on purpose: this is the way to initialize ctx's single certificate chain.
/*
 * Load a certificate chain and the private key matching its first
 * certificate into ctx (ctx->certs must still be NULL).
 *
 * EC keys are opened read-only; other key types are opened read/write
 * ("rb+"). Every failure path now releases the chain buffer and closes the
 * key file. Previously several early returns leaked both.
 * Returns 1 on success, -1 on error.
 */
int tls_ctx_set_certificate_and_key(TLS_CTX *ctx, const char *chainfile,
	const char *keyfile, const char *keypass)
{
	int ret = -1;
	uint8_t *certs = NULL;
	size_t certslen;
	FILE *keyfp = NULL;
	const uint8_t *cert;
	size_t certlen;
	X509_KEY public_key;

	if (!ctx || !chainfile || !keyfile || !keypass) {
		error_print();
		return -1;
	}
	if (!tls_protocol_name(ctx->protocol)) {
		error_print();
		return -1;
	}
	if (ctx->certs) {
		error_print();
		return -1;
	}
	if (x509_certs_new_from_file(&certs, &certslen, chainfile) != 1) {
		error_print();
		goto end;
	}
	if (x509_certs_get_cert_by_index(certs, certslen, 0, &cert, &certlen) != 1
		|| x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) {
		error_print();
		goto end;
	}
	keyfp = fopen(keyfile, public_key.algor == OID_ec_public_key ? "r" : "rb+");
	if (!keyfp) {
		error_print();
		goto end;
	}
	if (x509_private_key_from_file(&ctx->signkey, public_key.algor, keypass, keyfp) != 1) {
		error_print();
		goto end;
	}
	// ownership of the chain buffer moves to ctx
	ctx->certs = certs;
	ctx->certslen = certslen;
	certs = NULL;
	ret = 1;
end:
	if (certs) free(certs);
	if (keyfp) fclose(keyfp);
	return ret;
}
/*
 * Load a TLCP server certificate chain plus its signing and encryption
 * private keys (both read as OID_ec_public_key keys).
 *
 * The first certificate in `chainfile` must match the signing key and the
 * second must match the encryption key. ctx->certs must still be NULL. On
 * success ctx takes ownership of the chain buffer. Once the chain has been
 * read, any later failure frees it, closes the key files, and cleans up
 * ctx->signkey and ctx->kenckey.
 * Returns 1 on success, -1 on error.
 * NOTE(review): `algor_param` (OID_sm2) is declared but never used.
 */
int tls_ctx_set_tlcp_server_certificate_and_keys(TLS_CTX *ctx, const char *chainfile,
	const char *signkeyfile, const char *signkeypass,
	const char *kenckeyfile, const char *kenckeypass)
{
	int ret = -1;
	const int algor = OID_ec_public_key;
	const int algor_param = OID_sm2;
	uint8_t *certs = NULL;
	size_t certslen;
	FILE *signkeyfp = NULL;
	FILE *kenckeyfp = NULL;
	const uint8_t *cert;
	size_t certlen;
	X509_KEY public_key;
	if (!ctx || !chainfile || !signkeyfile || !signkeypass || !kenckeyfile || !kenckeypass) {
		error_print();
		return -1;
	}
	if (!tls_protocol_name(ctx->protocol)) {
		error_print();
		return -1;
	}
	if (ctx->certs) {
		error_print();
		return -1;
	}
	if (x509_certs_new_from_file(&certs, &certslen, chainfile) != 1) {
		error_print();
		return -1;
	}
	// load sign key
	if (!(signkeyfp = fopen(signkeyfile, "r"))) {
		error_print();
		goto end;
	}
	if (x509_private_key_from_file(&ctx->signkey, algor, signkeypass, signkeyfp) != 1) {
		error_print();
		goto end;
	}
	// the signing key must match certificate #0
	if (x509_certs_get_cert_by_index(certs, certslen, 0, &cert, &certlen) != 1
		|| x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) {
		error_print();
		goto end;
	}
	if (x509_public_key_equ(&ctx->signkey, &public_key) != 1) {
		error_print();
		goto end;
	}
	// load enc key
	if (!(kenckeyfp = fopen(kenckeyfile, "r"))) {
		error_print();
		goto end;
	}
	if (x509_private_key_from_file(&ctx->kenckey, algor, kenckeypass, kenckeyfp) != 1) {
		error_print();
		goto end;
	}
	// the encryption key must match certificate #1
	if (x509_certs_get_cert_by_index(certs, certslen, 1, &cert, &certlen) != 1
		|| x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) {
		error_print();
		goto end;
	}
	if (x509_public_key_equ(&ctx->kenckey, &public_key) != 1) {
		error_print();
		goto end;
	}
	// ownership of the chain buffer moves to ctx
	ctx->certs = certs;
	ctx->certslen = certslen;
	certs = NULL;
	ret = 1;
end:
	if (ret != 1) x509_key_cleanup(&ctx->signkey);
	if (ret != 1) x509_key_cleanup(&ctx->kenckey);
	if (certs) free(certs);
	if (signkeyfp) fclose(signkeyfp);
	if (kenckeyfp) fclose(kenckeyfp);
	return ret;
}
/*
 * Replace ctx's supported_groups list. Only sm2p256v1 and secp256r1 are
 * accepted. supported_groups_cnt is updated only on success.
 */
int tls_ctx_set_supported_groups(TLS_CTX *ctx, const int *groups, size_t groups_cnt)
{
	size_t n;

	if (ctx == NULL || groups == NULL || groups_cnt == 0) {
		error_print();
		return -1;
	}
	if (groups_cnt > sizeof(ctx->supported_groups) / sizeof(ctx->supported_groups[0])) {
		error_print();
		return -1;
	}
	for (n = 0; n < groups_cnt; n++) {
		int g = groups[n];
		if (g != TLS_curve_sm2p256v1 && g != TLS_curve_secp256r1) {
			error_print();
			return -1;
		}
		ctx->supported_groups[n] = g;
	}
	ctx->supported_groups_cnt = groups_cnt;
	return 1;
}
/*
 * Replace ctx's signature_algorithms list. Only SM2-with-SM3 and
 * ECDSA secp256r1/SHA-256 are accepted. signature_algorithms_cnt is updated
 * only on success.
 */
int tls_ctx_set_signature_algorithms(TLS_CTX *ctx, const int *sig_algs, size_t sig_algs_cnt)
{
	size_t n;

	if (ctx == NULL || sig_algs == NULL || sig_algs_cnt == 0) {
		error_print();
		return -1;
	}
	if (sig_algs_cnt > sizeof(ctx->signature_algorithms) / sizeof(ctx->signature_algorithms[0])) {
		error_print();
		return -1;
	}
	for (n = 0; n < sig_algs_cnt; n++) {
		int alg = sig_algs[n];
		if (alg != TLS_sig_sm2sig_sm3 && alg != TLS_sig_ecdsa_secp256r1_sha256) {
			error_print();
			return -1;
		}
		ctx->signature_algorithms[n] = alg;
	}
	ctx->signature_algorithms_cnt = sig_algs_cnt;
	return 1;
}
/*
 * Set how many key exchanges a TLS 1.3 ClientHello offers. The count may
 * not exceed the capacity of TLS_CONNECT.key_exchanges.
 * Returns 1 on success, -1 on error.
 */
int tls13_ctx_set_client_hello_key_exchanges_cnt(TLS_CTX *ctx, size_t cnt)
{
	size_t capacity;

	if (ctx == NULL) {
		error_print();
		return -1;
	}
	// sizeof does not evaluate its operand, so the NULL pointer is never dereferenced
	capacity = sizeof(((TLS_CONNECT *)NULL)->key_exchanges)
		/ sizeof(((TLS_CONNECT *)NULL)->key_exchanges[0]);
	if (cnt > capacity) {
		error_print();
		return -1;
	}
	ctx->key_exchanges_cnt = cnt;
	return 1;
}
/*
 * Initialize conn from ctx.
 *
 * conn is zeroed first. This side's certificate chain is copied into
 * client_certs or server_certs, and the key copies and session settings are
 * taken from ctx. conn->ctx keeps a pointer to ctx, so ctx must outlive conn.
 * Cipher suites and CA certificates are not copied here.
 * Returns 1 on success, -1 on NULL arguments or an oversize chain.
 */
int tls_init(TLS_CONNECT *conn, TLS_CTX *ctx)
{
	if (!conn || !ctx) {
		error_print();
		return -1;
	}
	memset(conn, 0, sizeof(*conn));
	conn->is_client = ctx->is_client;
	conn->protocol = ctx->protocol;

	if (ctx->certslen > TLS_MAX_CERTIFICATES_SIZE) {
		error_print();
		return -1;
	}
	// ctx->certs may be NULL when no certificate is configured; memcpy from
	// NULL is undefined even with length 0, and the lengths are already zero.
	if (ctx->certslen > 0) {
		if (conn->is_client) {
			memcpy(conn->client_certs, ctx->certs, ctx->certslen);
			conn->client_certs_len = ctx->certslen;
		} else {
			memcpy(conn->server_certs, ctx->certs, ctx->certslen);
			conn->server_certs_len = ctx->certslen;
		}
	}

	conn->sign_key = ctx->signkey;
	conn->kenc_key = ctx->kenckey;
	conn->quiet = ctx->quiet;
	conn->ctx = ctx;
	conn->key_exchanges_cnt = ctx->key_exchanges_cnt;
	conn->new_session_ticket = ctx->new_session_ticket;
	conn->key_exchange_modes = ctx->key_exchange_modes;
	// early_data
	conn->early_data = ctx->early_data;
	conn->max_early_data_size = ctx->max_early_data_size;
	return 1;
}
/*
 * Securely wipe all connection state, including the key copies taken from
 * TLS_CTX by tls_init(). Safe to call with NULL, like the other *_cleanup
 * functions in this file.
 */
void tls_cleanup(TLS_CONNECT *conn)
{
	if (!conn) {
		return;
	}
	gmssl_secure_clear(conn, sizeof(TLS_CONNECT));
}
/*
int tls_set_hostname(TLS_CONNECT *conn, const char *hostname)
{
if (strlen(hostname) > 255) {
error_print();
return -1;
}
conn->hostname = hostname;
return 1;
}
*/
/*
 * Attach `sock` to conn.
 *
 * Blocking sockets are expected. A non-blocking socket (POSIX), or a failed
 * FIONBIO call (Windows), only prints a message and is still accepted,
 * because callers such as nginx pass non-blocking sockets.
 * Returns 1, or -1 if fcntl(F_GETFL) fails.
 */
int tls_set_socket(TLS_CONNECT *conn, tls_socket_t sock)
{
#ifdef WIN32
	// NOTE(review): FIONBIO with flags == 0 switches the socket to blocking
	// mode rather than only querying it; the message below is printed when
	// that call fails.
	u_long flags = 0; // TODO: 0 == blocking, 1 == non-blocking
	if(ioctlsocket(sock, FIONBIO, &flags) != 0) {
		error_puts("socket in non-blocking mode");
		//nginx will pass a socket in non-blocking mode
		//return -1; // FIXME
	}
#else
	int flags = 0;
	if ((flags = fcntl(sock, F_GETFL)) == -1) {
		error_print();
		perror("fcntl error");
		return -1;
	}
	if (flags & O_NONBLOCK) {
		error_puts("socket in non-blocking mode");
		//nginx will pass a socket in non-blocking mode
		//return -1; // FIXME
	}
#endif
	conn->sock = sock;
	return 1;
}
/*
 * Run the client (connect) or server (accept) handshake for conn's
 * protocol. Returns the handshake function's result, or -1 for an unknown
 * protocol.
 */
int tls_do_handshake(TLS_CONNECT *conn)
{
	int is_client = conn->is_client;

	switch (conn->protocol) {
	case TLS_protocol_tlcp:
		return is_client ? tlcp_do_connect(conn) : tlcp_do_accept(conn);
	case TLS_protocol_tls12:
		return is_client ? tls12_do_connect(conn) : tls12_do_accept(conn);
	case TLS_protocol_tls13:
		return is_client ? tls13_do_connect(conn) : tls13_do_accept(conn);
	default:
		break;
	}
	error_print();
	return -1;
}
/*
 * Copy conn's certificate verification result into *result.
 * Returns 1 on success, -1 if either argument is NULL.
 */
int tls_get_verify_result(TLS_CONNECT *conn, int *result)
{
	if (!conn || !result) {
		error_print();
		return -1;
	}
	*result = conn->verify_result;
	return 1;
}
/*
 * Read a uint16-length-prefixed byte array from `fp` into `arr`, including
 * the 2-byte prefix.
 *
 * *arrlen is set to 2 + body length as soon as the prefix is parsed. If that
 * exceeds `maxlen`, 0 is returned and the body is not read, so the caller
 * learns the required size. Returns 1 on success, -1 on error.
 */
int tls_uint16array_from_file(uint8_t *arr, size_t *arrlen, size_t maxlen, FILE *fp)
{
	uint16_t body_len;
	const uint8_t *hdr;
	size_t hdr_len;

	if (arr == NULL || arrlen == NULL || fp == NULL) {
		error_print();
		return -1;
	}
	if (maxlen < 2) {
		error_print();
		return -1;
	}
	if (fread(arr, 1, 2, fp) != 2) {
		error_print();
		return -1;
	}

	hdr = arr;
	hdr_len = 2;
	if (tls_uint16_from_bytes(&body_len, &hdr, &hdr_len) != 1
		|| tls_length_is_zero(hdr_len) != 1) {
		error_print();
		return -1;
	}

	*arrlen = 2 + (size_t)body_len;
	if (*arrlen > maxlen) {
		error_print();
		return 0;
	}
	if (fread(arr + 2, 1, body_len, fp) != body_len) {
		error_print();
		return -1;
	}
	return 1;
}
/*
 * Store the SNI host name on a client connection as a NUL-terminated copy.
 * The name must be non-empty and shorter than conn->server_name.
 * Returns 1 on success, -1 on error (including use on a server connection).
 */
int tls_set_server_name(TLS_CONNECT *conn, const uint8_t *host_name, size_t host_name_len)
{
	if (conn == NULL || host_name == NULL || host_name_len == 0) {
		error_print();
		return -1;
	}
	if (!conn->is_client) {
		error_print();
		return -1;
	}
	// leave room for the terminating NUL
	if (host_name_len >= sizeof(conn->server_name)) {
		error_print();
		return -1;
	}
	memcpy(conn->server_name, host_name, host_name_len);
	conn->server_name[host_name_len] = 0;
	conn->server_name_len = host_name_len;
	return 1;
}