mirror of
https://github.com/guanzhi/GmSSL.git
synced 2026-05-06 16:36:16 +08:00
2285 lines
63 KiB
C
2285 lines
63 KiB
C
/*
|
||
* Copyright 2014-2022 The GmSSL Project. All Rights Reserved.
|
||
*
|
||
* Licensed under the Apache License, Version 2.0 (the License); you may
|
||
* not use this file except in compliance with the License.
|
||
*
|
||
* http://www.apache.org/licenses/LICENSE-2.0
|
||
*/
|
||
|
||
#include <assert.h>
|
||
#include <errno.h>
|
||
#include <gmssl/error.h>
|
||
#include <gmssl/mem.h>
|
||
#include <gmssl/pem.h>
|
||
#include <gmssl/rand.h>
|
||
#include <gmssl/sm2.h>
|
||
#include <gmssl/sm3.h>
|
||
#include <gmssl/sm4.h>
|
||
#include <gmssl/tls.h>
|
||
#include <gmssl/x509.h>
|
||
#include <stdio.h>
|
||
#include <stdlib.h>
|
||
#include <string.h>
|
||
#include <time.h>
|
||
|
||
/*
 * Append a single byte to the output cursor.
 * The byte is stored only when both out and *out are non-NULL; *outlen is
 * always increased by 1, so a NULL out can be used to compute sizes.
 */
void tls_uint8_to_bytes(uint8_t a, uint8_t **out, size_t *outlen) {
    *outlen += 1;
    if (out == NULL || *out == NULL) {
        return;
    }
    **out = a;
    *out += 1;
}
|
||
|
||
/*
 * Append a 16-bit value in big-endian (network) order.
 * Bytes are stored only when out and *out are non-NULL; *outlen always grows
 * by 2 so the call doubles as a length computation.
 */
void tls_uint16_to_bytes(uint16_t a, uint8_t **out, size_t *outlen) {
    if (out != NULL && *out != NULL) {
        uint8_t *dst = *out;

        dst[0] = (uint8_t)(a >> 8);
        dst[1] = (uint8_t)(a & 0xff);
        *out = dst + 2;
    }
    *outlen += 2;
}
|
||
|
||
/*
 * Append the low 24 bits of a in big-endian order.
 * Bytes are stored only when out and *out are non-NULL; *outlen always grows
 * by 3.
 */
void tls_uint24_to_bytes(uint24_t a, uint8_t **out, size_t *outlen) {
    if (out != NULL && *out != NULL) {
        uint8_t *dst = *out;
        int shift;

        for (shift = 16; shift >= 0; shift -= 8) {
            *dst++ = (uint8_t)(a >> shift);
        }
        *out = dst;
    }
    *outlen += 3;
}
|
||
|
||
/*
 * Append a 32-bit value in big-endian (network) order.
 * Bytes are stored only when out and *out are non-NULL; *outlen always grows
 * by 4 so the call can also be used to size a buffer.
 */
void tls_uint32_to_bytes(uint32_t a, uint8_t **out, size_t *outlen) {
    if (out != NULL && *out != NULL) {
        uint8_t *dst = *out;
        int shift;

        for (shift = 24; shift >= 0; shift -= 8) {
            *dst++ = (uint8_t)(a >> shift);
        }
        *out = dst;
    }
    *outlen += 4;
}
|
||
|
||
/*
 * Append datalen raw bytes to the output cursor.
 * When out and *out are non-NULL the cursor is advanced by datalen. The bytes
 * are copied only if data is non-NULL; a NULL data just reserves the space.
 * *outlen always grows by datalen.
 */
void tls_array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out,
                        size_t *outlen) {
    *outlen += datalen;
    if (out == NULL || *out == NULL) {
        return;
    }
    if (data != NULL) {
        memcpy(*out, data, datalen);
    }
    *out += datalen;
}
|
||
|
||
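/*
 * The *array_to_bytes functions below must distinguish two cases:
 *   data == NULL, datalen == 0: the data is empty, so only the length
 *     prefix is written.
 *   data == NULL, datalen != 0: the data is non-empty but is not written;
 *     only the length prefix is emitted, and the output pointer and total
 *     output length are advanced past datalen bytes.
 * The second case should be avoided!
 */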
|
||
|
||
/*
 * Emit an opaque<0..2^8-1> vector: a 1-byte length prefix (datalen truncated
 * to 8 bits) followed by the data bytes.
 */
void tls_uint8array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out,
                             size_t *outlen) {
    const uint8_t prefix = (uint8_t)datalen;

    tls_uint8_to_bytes(prefix, out, outlen);
    tls_array_to_bytes(data, datalen, out, outlen);
}
|
||
|
||
/*
 * Emit an opaque<0..2^16-1> vector: a 2-byte big-endian length prefix
 * (datalen truncated to 16 bits) followed by the data bytes.
 */
void tls_uint16array_to_bytes(const uint8_t *data, size_t datalen,
                              uint8_t **out, size_t *outlen) {
    const uint16_t prefix = (uint16_t)datalen;

    tls_uint16_to_bytes(prefix, out, outlen);
    tls_array_to_bytes(data, datalen, out, outlen);
}
|
||
|
||
void tls_uint24array_to_bytes(const uint8_t *data, size_t datalen,
|
||
uint8_t **out, size_t *outlen) {
|
||
tls_uint24_to_bytes((uint24_t)datalen, out, outlen);
|
||
tls_array_to_bytes(data, datalen, out, outlen);
|
||
}
|
||
|
||
/*
 * Consume one byte from the input cursor into *a.
 * Returns 1 on success, -1 (input untouched) if no bytes remain.
 */
int tls_uint8_from_bytes(uint8_t *a, const uint8_t **in, size_t *inlen) {
    const uint8_t *src;

    if (*inlen == 0) {
        error_print();
        return -1;
    }
    src = *in;
    *a = src[0];
    *in = src + 1;
    *inlen -= 1;
    return 1;
}
|
||
|
||
/*
 * Consume a big-endian 16-bit integer from the input cursor.
 * Returns 1 on success, -1 (input untouched) if fewer than 2 bytes remain.
 */
int tls_uint16_from_bytes(uint16_t *a, const uint8_t **in, size_t *inlen) {
    const uint8_t *src;

    if (*inlen < 2) {
        error_print();
        return -1;
    }
    src = *in;
    *a = (uint16_t)(((uint16_t)src[0] << 8) | (uint16_t)src[1]);
    *in = src + 2;
    *inlen -= 2;
    return 1;
}
|
||
|
||
/*
 * Consume a big-endian 24-bit integer from the input cursor.
 * Returns 1 on success, -1 (input untouched) if fewer than 3 bytes remain.
 */
int tls_uint24_from_bytes(uint24_t *a, const uint8_t **in, size_t *inlen) {
    const uint8_t *src;

    if (*inlen < 3) {
        error_print();
        return -1;
    }
    src = *in;
    *a = ((uint24_t)src[0] << 16) | ((uint24_t)src[1] << 8) | (uint24_t)src[2];
    *in = src + 3;
    *inlen -= 3;
    return 1;
}
|
||
|
||
/*
 * Consume a big-endian 32-bit integer from the input cursor.
 * Returns 1 on success, -1 (input untouched) if fewer than 4 bytes remain.
 */
int tls_uint32_from_bytes(uint32_t *a, const uint8_t **in, size_t *inlen) {
    const uint8_t *src;

    if (*inlen < 4) {
        error_print();
        return -1;
    }
    src = *in;
    *a = ((uint32_t)src[0] << 24) | ((uint32_t)src[1] << 16) |
         ((uint32_t)src[2] << 8) | (uint32_t)src[3];
    *in = src + 4;
    *inlen -= 4;
    return 1;
}
|
||
|
||
/*
 * Take a zero-copy view of the next datalen bytes of input.
 * On success *data points into the input buffer and the cursor is advanced.
 * Returns 1 on success, -1 (nothing modified) if the input is too short.
 */
int tls_array_from_bytes(const uint8_t **data, size_t datalen,
                         const uint8_t **in, size_t *inlen) {
    const uint8_t *start;

    if (datalen > *inlen) {
        error_print();
        return -1;
    }
    start = *in;
    *data = start;
    *in = start + datalen;
    *inlen -= datalen;
    return 1;
}
|
||
|
||
/*
 * Parse an opaque<0..2^8-1> vector (1-byte length prefix, then the bytes).
 * *data points into the input, or is NULL for an empty vector.
 * Returns 1 on success, -1 on truncated input.
 */
int tls_uint8array_from_bytes(const uint8_t **data, size_t *datalen,
                              const uint8_t **in, size_t *inlen) {
    uint8_t n;

    if (tls_uint8_from_bytes(&n, in, inlen) != 1) {
        error_print();
        return -1;
    }
    if (tls_array_from_bytes(data, n, in, inlen) != 1) {
        error_print();
        return -1;
    }
    if (n == 0) {
        *data = NULL;
    }
    *datalen = n;
    return 1;
}
|
||
|
||
/*
 * Parse an opaque<0..2^16-1> vector (2-byte big-endian length, then bytes).
 * *data points into the input, or is NULL for an empty vector.
 * Returns 1 on success, -1 on truncated input.
 */
int tls_uint16array_from_bytes(const uint8_t **data, size_t *datalen,
                               const uint8_t **in, size_t *inlen) {
    uint16_t n;

    if (tls_uint16_from_bytes(&n, in, inlen) != 1) {
        error_print();
        return -1;
    }
    if (tls_array_from_bytes(data, n, in, inlen) != 1) {
        error_print();
        return -1;
    }
    if (n == 0) {
        *data = NULL;
    }
    *datalen = n;
    return 1;
}
|
||
|
||
int tls_uint24array_from_bytes(const uint8_t **data, size_t *datalen,
|
||
const uint8_t **in, size_t *inlen) {
|
||
uint24_t len;
|
||
if (tls_uint24_from_bytes(&len, in, inlen) != 1 ||
|
||
tls_array_from_bytes(data, len, in, inlen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!len) {
|
||
*data = NULL;
|
||
}
|
||
*datalen = len;
|
||
return 1;
|
||
}
|
||
|
||
/*
 * Check that a parser consumed all of its input.
 * Returns 1 if len is zero, otherwise logs and returns -1.
 */
int tls_length_is_zero(size_t len) {
    if (len == 0) {
        return 1;
    }
    error_print();
    return -1;
}
|
||
|
||
int tls_record_set_type(uint8_t *record, int type) {
|
||
if (!tls_record_type_name(type)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
record[0] = (uint8_t)type;
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_set_protocol(uint8_t *record, int protocol) {
|
||
if (!tls_protocol_name(protocol)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
record[1] = (uint8_t)(protocol >> 8);
|
||
record[2] = (uint8_t)(protocol);
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_set_length(uint8_t *record, size_t length) {
|
||
uint8_t *p = record + 3;
|
||
size_t len;
|
||
if (length > TLS_MAX_CIPHERTEXT_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_uint16_to_bytes((uint16_t)length, &p, &len);
|
||
return 1;
|
||
}
|
||
|
||
/*
 * Set the record length field and copy datalen payload bytes after the header.
 * data may be NULL only when datalen is 0.
 * Returns 1 on success, -1 on invalid arguments or an oversized payload.
 */
int tls_record_set_data(uint8_t *record, const uint8_t *data, size_t datalen) {
    if (!record || (!data && datalen)) {
        error_print();
        return -1;
    }
    if (tls_record_set_length(record, datalen) != 1) {
        error_print();
        return -1;
    }
    /* memcpy with a NULL source is undefined even for a zero length, so the
     * copy is skipped for an empty payload. */
    if (datalen) {
        memcpy(tls_record_data(record), data, datalen);
    }
    return 1;
}
|
||
|
||
int tls_cbc_encrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *enc_key,
|
||
const uint8_t seq_num[8], const uint8_t header[5],
|
||
const uint8_t *in, size_t inlen, uint8_t *out,
|
||
size_t *outlen) {
|
||
SM3_HMAC_CTX hmac_ctx;
|
||
uint8_t last_blocks[32 + 16] = {0};
|
||
uint8_t *mac, *padding, *iv;
|
||
int rem, padding_len;
|
||
int i;
|
||
|
||
if (!inited_hmac_ctx || !enc_key || !seq_num || !header || (!in && inlen) ||
|
||
!out || !outlen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (inlen > (1 << 14)) {
|
||
error_print_msg("invalid tls record data length %zu\n", inlen);
|
||
return -1;
|
||
}
|
||
if ((((size_t)header[3]) << 8) + header[4] != inlen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
rem = (inlen + 32) % 16;
|
||
memcpy(last_blocks, in + inlen - rem, rem);
|
||
mac = last_blocks + rem;
|
||
|
||
memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
|
||
sm3_hmac_update(&hmac_ctx, seq_num, 8);
|
||
sm3_hmac_update(&hmac_ctx, header, 5);
|
||
sm3_hmac_update(&hmac_ctx, in, inlen);
|
||
sm3_hmac_finish(&hmac_ctx, mac);
|
||
|
||
padding = mac + 32;
|
||
padding_len = 16 - rem - 1;
|
||
for (i = 0; i <= padding_len; i++) {
|
||
padding[i] = (uint8_t)padding_len;
|
||
}
|
||
|
||
iv = out;
|
||
if (rand_bytes(iv, 16) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
out += 16;
|
||
|
||
if (inlen >= 16) {
|
||
sm4_cbc_encrypt(enc_key, iv, in, inlen / 16, out);
|
||
out += inlen - rem;
|
||
iv = out - 16;
|
||
}
|
||
sm4_cbc_encrypt(enc_key, iv, last_blocks, sizeof(last_blocks) / 16, out);
|
||
*outlen = 16 + inlen - rem + sizeof(last_blocks);
|
||
return 1;
|
||
}
|
||
|
||
int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key,
|
||
const uint8_t seq_num[8], const uint8_t enced_header[5],
|
||
const uint8_t *in, size_t inlen, uint8_t *out,
|
||
size_t *outlen) {
|
||
SM3_HMAC_CTX hmac_ctx;
|
||
const uint8_t *iv;
|
||
const uint8_t *padding;
|
||
const uint8_t *mac;
|
||
uint8_t header[5];
|
||
int padding_len;
|
||
uint8_t hmac[32];
|
||
int i;
|
||
|
||
if (!inited_hmac_ctx || !dec_key || !seq_num || !enced_header || !in ||
|
||
!inlen || !out || !outlen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (inlen % 16 || inlen < (16 + 0 + 32 + 16) // iv + data + mac + padding
|
||
|| inlen > (16 + (1 << 14) + 32 + 256)) {
|
||
error_print_msg("invalid tls cbc ciphertext length %zu\n", inlen);
|
||
return -1;
|
||
}
|
||
|
||
iv = in;
|
||
in += 16;
|
||
inlen -= 16;
|
||
|
||
sm4_cbc_decrypt(dec_key, iv, in, inlen / 16, out);
|
||
|
||
padding_len = out[inlen - 1];
|
||
padding = out + inlen - padding_len - 1;
|
||
if (padding < out + 32) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
for (i = 0; i < padding_len; i++) {
|
||
if (padding[i] != padding_len) {
|
||
error_puts("tls ciphertext cbc-padding check failure");
|
||
return -1;
|
||
}
|
||
}
|
||
|
||
*outlen = inlen - 32 - padding_len - 1;
|
||
|
||
header[0] = enced_header[0];
|
||
header[1] = enced_header[1];
|
||
header[2] = enced_header[2];
|
||
header[3] = (uint8_t)((*outlen) >> 8);
|
||
header[4] = (uint8_t)(*outlen);
|
||
mac = padding - 32;
|
||
|
||
memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
|
||
sm3_hmac_update(&hmac_ctx, seq_num, 8);
|
||
sm3_hmac_update(&hmac_ctx, header, 5);
|
||
sm3_hmac_update(&hmac_ctx, out, *outlen);
|
||
sm3_hmac_finish(&hmac_ctx, hmac);
|
||
if (gmssl_secure_memcmp(mac, hmac, sizeof(hmac)) != 0) {
|
||
error_puts("tls ciphertext mac check failure\n");
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_encrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key,
|
||
const uint8_t seq_num[8], const uint8_t *in,
|
||
size_t inlen, uint8_t *out, size_t *outlen) {
|
||
if (tls_cbc_encrypt(hmac_ctx, cbc_key, seq_num, in, in + 5, inlen - 5,
|
||
out + 5, outlen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
out[0] = in[0];
|
||
out[1] = in[1];
|
||
out[2] = in[2];
|
||
out[3] = (uint8_t)((*outlen) >> 8);
|
||
out[4] = (uint8_t)(*outlen);
|
||
(*outlen) += 5;
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_decrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key,
|
||
const uint8_t seq_num[8], const uint8_t *in,
|
||
size_t inlen, uint8_t *out, size_t *outlen) {
|
||
if (tls_cbc_decrypt(hmac_ctx, cbc_key, seq_num, in, in + 5, inlen - 5,
|
||
out + 5, outlen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
out[0] = in[0];
|
||
out[1] = in[1];
|
||
out[2] = in[2];
|
||
out[3] = (uint8_t)((*outlen) >> 8);
|
||
out[4] = (uint8_t)(*outlen);
|
||
(*outlen) += 5;
|
||
|
||
return 1;
|
||
}
|
||
|
||
/*
 * Fill a 32-byte hello random: the current time(NULL) truncated to 32 bits,
 * big-endian, followed by 28 bytes from rand_bytes().
 * Returns 1 on success, -1 if the RNG fails.
 */
int tls_random_generate(uint8_t random[32]) {
    uint8_t *cursor = random;
    size_t written = 0;

    tls_uint32_to_bytes((uint32_t)time(NULL), &cursor, &written);
    if (rand_bytes(random + 4, 28) != 1) {
        error_print();
        return -1;
    }
    return 1;
}
|
||
|
||
int tls_prf(const uint8_t *secret, size_t secretlen, const char *label,
|
||
const uint8_t *seed, size_t seedlen, const uint8_t *more,
|
||
size_t morelen, size_t outlen, uint8_t *out) {
|
||
SM3_HMAC_CTX inited_hmac_ctx;
|
||
SM3_HMAC_CTX hmac_ctx;
|
||
uint8_t A[32];
|
||
uint8_t hmac[32];
|
||
size_t len;
|
||
|
||
if (!secret || !secretlen || !label || !seed || !seedlen ||
|
||
(!more && morelen) || !outlen || !out) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
sm3_hmac_init(&inited_hmac_ctx, secret, secretlen);
|
||
|
||
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
|
||
sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
|
||
sm3_hmac_update(&hmac_ctx, seed, seedlen);
|
||
sm3_hmac_update(&hmac_ctx, more, morelen);
|
||
sm3_hmac_finish(&hmac_ctx, A);
|
||
|
||
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
|
||
sm3_hmac_update(&hmac_ctx, A, sizeof(A));
|
||
sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
|
||
sm3_hmac_update(&hmac_ctx, seed, seedlen);
|
||
sm3_hmac_update(&hmac_ctx, more, morelen);
|
||
sm3_hmac_finish(&hmac_ctx, hmac);
|
||
|
||
len = outlen < sizeof(hmac) ? outlen : sizeof(hmac);
|
||
memcpy(out, hmac, len);
|
||
out += len;
|
||
outlen -= len;
|
||
|
||
while (outlen) {
|
||
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
|
||
sm3_hmac_update(&hmac_ctx, A, sizeof(A));
|
||
sm3_hmac_finish(&hmac_ctx, A);
|
||
|
||
memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
|
||
sm3_hmac_update(&hmac_ctx, A, sizeof(A));
|
||
sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
|
||
sm3_hmac_update(&hmac_ctx, seed, seedlen);
|
||
sm3_hmac_update(&hmac_ctx, more, morelen);
|
||
sm3_hmac_finish(&hmac_ctx, hmac);
|
||
|
||
len = outlen < sizeof(hmac) ? outlen : sizeof(hmac);
|
||
memcpy(out, hmac, len);
|
||
out += len;
|
||
outlen -= len;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_pre_master_secret_generate(uint8_t pre_master_secret[48],
|
||
int protocol) {
|
||
if (!tls_protocol_name(protocol)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
pre_master_secret[0] = (uint8_t)(protocol >> 8);
|
||
pre_master_secret[1] = (uint8_t)(protocol);
|
||
if (rand_bytes(pre_master_secret + 2, 46) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
// 用于设置CertificateRequest
|
||
int tls_cert_type_from_oid(int oid) {
|
||
switch (oid) {
|
||
case OID_sm2sign_with_sm3:
|
||
case OID_ecdsa_with_sha1:
|
||
case OID_ecdsa_with_sha224:
|
||
case OID_ecdsa_with_sha256:
|
||
case OID_ecdsa_with_sha512:
|
||
return TLS_cert_type_ecdsa_sign;
|
||
case OID_rsasign_with_sm3:
|
||
case OID_rsasign_with_md5:
|
||
case OID_rsasign_with_sha1:
|
||
case OID_rsasign_with_sha224:
|
||
case OID_rsasign_with_sha256:
|
||
case OID_rsasign_with_sha384:
|
||
case OID_rsasign_with_sha512:
|
||
return TLS_cert_type_rsa_sign;
|
||
}
|
||
// TLS_cert_type_xxx 中没有为0的值
|
||
return 0;
|
||
}
|
||
|
||
// The following two functions have no TLCP counterpart.
|
||
int tls_sign_server_ecdh_params(const SM2_KEY *server_sign_key,
|
||
const uint8_t client_random[32],
|
||
const uint8_t server_random[32], int curve,
|
||
const SM2_POINT *point, uint8_t *sig,
|
||
size_t *siglen) {
|
||
uint8_t server_ecdh_params[69];
|
||
SM2_SIGN_CTX sign_ctx;
|
||
|
||
if (!server_sign_key || !client_random || !server_random ||
|
||
curve != TLS_curve_sm2p256v1 || !point || !sig || !siglen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
server_ecdh_params[0] = TLS_curve_type_named_curve;
|
||
server_ecdh_params[1] = (uint8_t)(curve >> 8);
|
||
server_ecdh_params[2] = (uint8_t)curve;
|
||
server_ecdh_params[3] = 65;
|
||
sm2_point_to_uncompressed_octets(point, server_ecdh_params + 4);
|
||
|
||
sm2_sign_init(&sign_ctx, server_sign_key, SM2_DEFAULT_ID,
|
||
SM2_DEFAULT_ID_LENGTH);
|
||
sm2_sign_update(&sign_ctx, client_random, 32);
|
||
sm2_sign_update(&sign_ctx, server_random, 32);
|
||
sm2_sign_update(&sign_ctx, server_ecdh_params, 69);
|
||
sm2_sign_finish(&sign_ctx, sig, siglen);
|
||
|
||
return 1;
|
||
}
|
||
|
||
int tls_verify_server_ecdh_params(const SM2_KEY *server_sign_key,
|
||
const uint8_t client_random[32],
|
||
const uint8_t server_random[32], int curve,
|
||
const SM2_POINT *point, const uint8_t *sig,
|
||
size_t siglen) {
|
||
int ret;
|
||
uint8_t server_ecdh_params[69];
|
||
SM2_SIGN_CTX verify_ctx;
|
||
|
||
if (!server_sign_key || !client_random || !server_random ||
|
||
curve != TLS_curve_sm2p256v1 || !point || !sig || !siglen ||
|
||
siglen > SM2_MAX_SIGNATURE_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
server_ecdh_params[0] = TLS_curve_type_named_curve;
|
||
server_ecdh_params[1] = (uint8_t)(curve >> 8);
|
||
server_ecdh_params[2] = (uint8_t)(curve);
|
||
server_ecdh_params[3] = 65;
|
||
sm2_point_to_uncompressed_octets(point, server_ecdh_params + 4);
|
||
|
||
sm2_verify_init(&verify_ctx, server_sign_key, SM2_DEFAULT_ID,
|
||
SM2_DEFAULT_ID_LENGTH);
|
||
sm2_verify_update(&verify_ctx, client_random, 32);
|
||
sm2_verify_update(&verify_ctx, server_random, 32);
|
||
sm2_verify_update(&verify_ctx, server_ecdh_params, 69);
|
||
ret = sm2_verify_finish(&verify_ctx, sig, siglen);
|
||
if (ret != 1) error_print();
|
||
return ret;
|
||
}
|
||
|
||
int tls_record_set_handshake(uint8_t *record, size_t *recordlen, int type,
|
||
const uint8_t *data, size_t datalen) {
|
||
size_t handshakelen;
|
||
|
||
if (!record || !recordlen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
// 由于ServerHelloDone没有负载数据,因此允许 data,datalen = NULL,0
|
||
if (datalen > TLS_MAX_PLAINTEXT_SIZE - TLS_HANDSHAKE_HEADER_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!tls_protocol_name(tls_record_protocol(record))) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!tls_handshake_type_name(type)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
handshakelen = TLS_HANDSHAKE_HEADER_SIZE + datalen;
|
||
record[0] = TLS_record_handshake;
|
||
record[3] = (uint8_t)(handshakelen >> 8);
|
||
record[4] = (uint8_t)(handshakelen);
|
||
record[5] = (uint8_t)(type);
|
||
record[6] = (uint8_t)(datalen >> 16);
|
||
record[7] = (uint8_t)(datalen >> 8);
|
||
record[8] = (uint8_t)(datalen);
|
||
if (data) {
|
||
memcpy(tls_handshake_data(tls_record_data(record)), data, datalen);
|
||
}
|
||
*recordlen = TLS_RECORD_HEADER_SIZE + handshakelen;
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_get_handshake(const uint8_t *record, int *type,
|
||
const uint8_t **data, size_t *datalen) {
|
||
const uint8_t *handshake;
|
||
size_t handshake_len;
|
||
uint24_t handshake_datalen;
|
||
|
||
if (!record || !type || !data || !datalen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!tls_protocol_name(tls_record_protocol(record))) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_record_type(record) != TLS_record_handshake) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
handshake = tls_record_data(record);
|
||
handshake_len = tls_record_data_length(record);
|
||
|
||
if (handshake_len < TLS_HANDSHAKE_HEADER_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (handshake_len > TLS_MAX_PLAINTEXT_SIZE) {
|
||
// 不支持证书长度超过记录长度的特殊情况
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
if (!tls_handshake_type_name(handshake[0])) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
*type = handshake[0];
|
||
|
||
handshake++;
|
||
handshake_len--;
|
||
if (tls_uint24_from_bytes(&handshake_datalen, &handshake, &handshake_len) !=
|
||
1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (handshake_len != handshake_datalen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
*data = handshake;
|
||
*datalen = handshake_datalen;
|
||
|
||
if (*datalen == 0) {
|
||
*data = NULL;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_set_handshake_client_hello(
|
||
uint8_t *record, size_t *recordlen, int protocol, const uint8_t random[32],
|
||
const uint8_t *session_id, size_t session_id_len, const int *cipher_suites,
|
||
size_t cipher_suites_count, const uint8_t *exts, size_t exts_len) {
|
||
uint8_t type = TLS_handshake_client_hello;
|
||
uint8_t *p;
|
||
size_t len;
|
||
|
||
if (!record || !recordlen || !random || !cipher_suites ||
|
||
!cipher_suites_count) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (session_id) {
|
||
if (!session_id_len || session_id_len < TLS_MAX_SESSION_ID_SIZE ||
|
||
session_id_len > TLS_MAX_SESSION_ID_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
}
|
||
if (cipher_suites_count > TLS_MAX_CIPHER_SUITES_COUNT) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (exts && !exts_len) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
p = tls_handshake_data(tls_record_data(record));
|
||
len = 0;
|
||
|
||
if (!tls_protocol_name(protocol)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_uint16_to_bytes((uint16_t)protocol, &p, &len);
|
||
tls_array_to_bytes(random, 32, &p, &len);
|
||
tls_uint8array_to_bytes(session_id, session_id_len, &p, &len);
|
||
tls_uint16_to_bytes((uint16_t)(cipher_suites_count * 2), &p, &len);
|
||
while (cipher_suites_count--) {
|
||
if (!tls_cipher_suite_name(*cipher_suites)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_uint16_to_bytes((uint16_t)*cipher_suites, &p, &len);
|
||
cipher_suites++;
|
||
}
|
||
tls_uint8_to_bytes(1, &p, &len);
|
||
tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len);
|
||
if (exts) {
|
||
size_t tmp_len = len;
|
||
if (protocol < TLS_protocol_tls12) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_uint16array_to_bytes(exts, exts_len, NULL, &tmp_len);
|
||
if (tmp_len > TLS_MAX_HANDSHAKE_DATA_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_uint16array_to_bytes(exts, exts_len, &p, &len);
|
||
}
|
||
if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_get_handshake_client_hello(
|
||
const uint8_t *record, int *protocol, const uint8_t **random,
|
||
const uint8_t **session_id, size_t *session_id_len,
|
||
const uint8_t **cipher_suites, size_t *cipher_suites_len,
|
||
const uint8_t **exts, size_t *exts_len) {
|
||
int type;
|
||
const uint8_t *p;
|
||
size_t len;
|
||
uint16_t ver;
|
||
const uint8_t *comp_meths;
|
||
size_t comp_meths_len;
|
||
|
||
if (!record || !protocol || !random || !session_id || !session_id_len ||
|
||
!cipher_suites || !cipher_suites_len || !exts || !exts_len) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_record_get_handshake(record, &type, &p, &len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (type != TLS_handshake_client_hello) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_uint16_from_bytes(&ver, &p, &len) != 1 ||
|
||
tls_array_from_bytes(random, 32, &p, &len) != 1 ||
|
||
tls_uint8array_from_bytes(session_id, session_id_len, &p, &len) != 1 ||
|
||
tls_uint16array_from_bytes(cipher_suites, cipher_suites_len, &p,
|
||
&len) != 1 ||
|
||
tls_uint8array_from_bytes(&comp_meths, &comp_meths_len, &p, &len) !=
|
||
1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
if (!tls_protocol_name(ver)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
*protocol = ver;
|
||
|
||
if (*session_id) {
|
||
if (*session_id_len == 0 || *session_id_len < TLS_MIN_SESSION_ID_SIZE ||
|
||
*session_id_len > TLS_MAX_SESSION_ID_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
}
|
||
|
||
if (!cipher_suites) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (*cipher_suites_len % 2) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
if (len) {
|
||
if (tls_uint16array_from_bytes(exts, exts_len, &p, &len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (*exts == NULL) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
} else {
|
||
*exts = NULL;
|
||
*exts_len = 0;
|
||
}
|
||
if (len) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_set_handshake_server_hello(
|
||
uint8_t *record, size_t *recordlen, int protocol, const uint8_t random[32],
|
||
const uint8_t *session_id, size_t session_id_len, int cipher_suite,
|
||
const uint8_t *exts, size_t exts_len) {
|
||
uint8_t type = TLS_handshake_server_hello;
|
||
uint8_t *p;
|
||
size_t len;
|
||
|
||
if (!record || !recordlen || !random) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (session_id) {
|
||
if (session_id_len == 0 || session_id_len < TLS_MIN_SESSION_ID_SIZE ||
|
||
session_id_len > TLS_MAX_SESSION_ID_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
}
|
||
if (!tls_protocol_name(protocol)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!tls_cipher_suite_name(cipher_suite)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
p = tls_handshake_data(tls_record_data(record));
|
||
len = 0;
|
||
|
||
tls_uint16_to_bytes((uint16_t)protocol, &p, &len);
|
||
tls_array_to_bytes(random, 32, &p, &len);
|
||
tls_uint8array_to_bytes(session_id, session_id_len, &p, &len);
|
||
tls_uint16_to_bytes((uint16_t)cipher_suite, &p, &len);
|
||
tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len);
|
||
if (exts) {
|
||
if (protocol < TLS_protocol_tls12) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_uint16array_to_bytes(exts, exts_len, &p, &len);
|
||
}
|
||
if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_get_handshake_server_hello(
|
||
const uint8_t *record, int *protocol, const uint8_t **random,
|
||
const uint8_t **session_id, size_t *session_id_len, int *cipher_suite,
|
||
const uint8_t **exts, size_t *exts_len) {
|
||
int type;
|
||
const uint8_t *p;
|
||
size_t len;
|
||
uint16_t ver;
|
||
uint16_t cipher;
|
||
uint8_t comp_meth;
|
||
|
||
if (!record || !protocol || !random || !session_id || !session_id_len ||
|
||
!cipher_suite || !exts || !exts_len) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_record_get_handshake(record, &type, &p, &len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (type != TLS_handshake_server_hello) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_uint16_from_bytes(&ver, &p, &len) != 1 ||
|
||
tls_array_from_bytes(random, 32, &p, &len) != 1 ||
|
||
tls_uint8array_from_bytes(session_id, session_id_len, &p, &len) != 1 ||
|
||
tls_uint16_from_bytes(&cipher, &p, &len) != 1 ||
|
||
tls_uint8_from_bytes(&comp_meth, &p, &len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
if (!tls_protocol_name(ver)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (ver < tls_record_protocol(record)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
*protocol = ver;
|
||
|
||
if (*session_id) {
|
||
if (*session_id == 0 || *session_id_len < TLS_MIN_SESSION_ID_SIZE ||
|
||
*session_id_len > TLS_MAX_SESSION_ID_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
}
|
||
|
||
if (!tls_cipher_suite_name(cipher)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
*cipher_suite = cipher;
|
||
|
||
if (comp_meth != TLS_compression_null) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
if (len) {
|
||
if (tls_uint16array_from_bytes(exts, exts_len, &p, &len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (*exts == NULL) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
} else {
|
||
*exts = NULL;
|
||
*exts_len = 0;
|
||
}
|
||
if (len) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_set_handshake_certificate(uint8_t *record, size_t *recordlen,
|
||
const uint8_t *certs,
|
||
size_t certslen) {
|
||
int type = TLS_handshake_certificate;
|
||
uint8_t *data;
|
||
size_t datalen;
|
||
uint8_t *p;
|
||
size_t len;
|
||
|
||
if (!record || !recordlen || !certs || !certslen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
data = tls_handshake_data(tls_record_data(record));
|
||
p = data + tls_uint24_size();
|
||
datalen = tls_uint24_size();
|
||
len = 0;
|
||
|
||
while (certslen) {
|
||
const uint8_t *cert;
|
||
size_t certlen;
|
||
|
||
if (x509_cert_from_der(&cert, &certlen, &certs, &certslen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_uint24array_to_bytes(cert, certlen, NULL, &datalen);
|
||
if (datalen > TLS_MAX_HANDSHAKE_DATA_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_uint24array_to_bytes(cert, certlen, &p, &len);
|
||
}
|
||
tls_uint24_to_bytes((uint24_t)len, &data, &len);
|
||
tls_record_set_handshake(record, recordlen, type, NULL, datalen);
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_get_handshake_certificate(const uint8_t *record, uint8_t *certs,
|
||
size_t *certslen) {
|
||
int type;
|
||
const uint8_t *data;
|
||
size_t datalen;
|
||
const uint8_t *cp;
|
||
size_t len;
|
||
|
||
if (tls_record_get_handshake(record, &type, &data, &datalen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (type != TLS_handshake_certificate) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_uint24array_from_bytes(&cp, &len, &data, &datalen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
*certslen = 0;
|
||
while (len) {
|
||
const uint8_t *a;
|
||
size_t alen;
|
||
const uint8_t *cert;
|
||
size_t certlen;
|
||
|
||
if (tls_uint24array_from_bytes(&a, &alen, &cp, &len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (x509_cert_from_der(&cert, &certlen, &a, &alen) != 1 ||
|
||
asn1_length_is_zero(alen) != 1 ||
|
||
x509_cert_to_der(cert, certlen, &certs, certslen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_set_handshake_certificate_request(
|
||
uint8_t *record, size_t *recordlen, const uint8_t *cert_types,
|
||
size_t cert_types_len, const uint8_t *ca_names, size_t ca_names_len) {
|
||
int type = TLS_handshake_certificate_request;
|
||
uint8_t *p;
|
||
size_t len = 0;
|
||
size_t datalen = 0;
|
||
|
||
if (!record || !recordlen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (cert_types) {
|
||
if (cert_types_len == 0 || cert_types_len > TLS_MAX_CERTIFICATE_TYPES) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
}
|
||
if (ca_names) {
|
||
if (ca_names_len == 0 || ca_names_len > TLS_MAX_CA_NAMES_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
}
|
||
tls_uint8array_to_bytes(cert_types, cert_types_len, NULL, &datalen);
|
||
tls_uint16array_to_bytes(ca_names, ca_names_len, NULL, &datalen);
|
||
if (datalen > TLS_MAX_HANDSHAKE_DATA_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
p = tls_handshake_data(tls_record_data(record));
|
||
tls_uint8array_to_bytes(cert_types, cert_types_len, &p, &len);
|
||
tls_uint16array_to_bytes(ca_names, ca_names_len, &p, &len);
|
||
tls_record_set_handshake(record, recordlen, type, NULL, datalen);
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_get_handshake_certificate_request(const uint8_t *record,
|
||
const uint8_t **cert_types,
|
||
size_t *cert_types_len,
|
||
const uint8_t **ca_names,
|
||
size_t *ca_names_len) {
|
||
int type;
|
||
const uint8_t *cp;
|
||
size_t len;
|
||
size_t i;
|
||
|
||
if (!record || !cert_types || !cert_types_len || !ca_names ||
|
||
!ca_names_len) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_record_get_handshake(record, &type, &cp, &len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (type != TLS_handshake_certificate_request) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_uint8array_from_bytes(cert_types, cert_types_len, &cp, &len) != 1 ||
|
||
tls_uint16array_from_bytes(ca_names, ca_names_len, &cp, &len) != 1 ||
|
||
tls_length_is_zero(len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
if (*cert_types == NULL) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
for (i = 0; i < *cert_types_len; i++) {
|
||
if (!tls_cert_type_name((*cert_types)[i])) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
}
|
||
if (*ca_names) {
|
||
const uint8_t *names = *ca_names;
|
||
size_t nameslen = *ca_names_len;
|
||
while (nameslen) {
|
||
if (tls_uint16array_from_bytes(&cp, &len, &names, &nameslen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
}
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_set_handshake_server_hello_done(uint8_t *record,
|
||
size_t *recordlen) {
|
||
int type = TLS_handshake_server_hello_done;
|
||
if (!record || !recordlen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_record_set_handshake(record, recordlen, type, NULL, 0);
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_get_handshake_server_hello_done(const uint8_t *record) {
|
||
int type;
|
||
const uint8_t *p;
|
||
size_t len;
|
||
|
||
if (!record) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_record_get_handshake(record, &type, &p, &len) != 1 ||
|
||
type != TLS_handshake_server_hello_done) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (p != NULL || len != 0) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_set_handshake_client_key_exchange_pke(uint8_t *record,
|
||
size_t *recordlen,
|
||
const uint8_t *enced_pms,
|
||
size_t enced_pms_len) {
|
||
int type = TLS_handshake_client_key_exchange;
|
||
uint8_t *p;
|
||
size_t len = 0;
|
||
|
||
if (!record || !recordlen || !enced_pms || !enced_pms_len) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (enced_pms_len > TLS_MAX_HANDSHAKE_DATA_SIZE - tls_uint16_size()) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
p = tls_handshake_data(tls_record_data(record));
|
||
tls_uint16array_to_bytes(enced_pms, enced_pms_len, &p, &len);
|
||
tls_record_set_handshake(record, recordlen, type, NULL, len);
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_get_handshake_client_key_exchange_pke(const uint8_t *record,
|
||
const uint8_t **enced_pms,
|
||
size_t *enced_pms_len) {
|
||
int type;
|
||
const uint8_t *cp;
|
||
size_t len;
|
||
|
||
if (!record || !enced_pms || !enced_pms_len) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_record_get_handshake(record, &type, &cp, &len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (type != TLS_handshake_client_key_exchange) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_uint16array_from_bytes(enced_pms, enced_pms_len, &cp, &len) != 1 ||
|
||
tls_length_is_zero(len) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_set_handshake_certificate_verify(uint8_t *record,
|
||
size_t *recordlen,
|
||
const uint8_t *sig,
|
||
size_t siglen) {
|
||
int type = TLS_handshake_certificate_verify;
|
||
|
||
if (!record || !recordlen || !sig || !siglen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (siglen > TLS_MAX_SIGNATURE_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_record_set_handshake(record, recordlen, type, sig, siglen);
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_get_handshake_certificate_verify(const uint8_t *record,
|
||
const uint8_t **sig,
|
||
size_t *siglen) {
|
||
int type;
|
||
|
||
if (!record || !sig || !siglen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_record_get_handshake(record, &type, sig, siglen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (type != TLS_handshake_certificate_verify) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (*sig == NULL || *siglen == 0) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (*siglen > TLS_MAX_SIGNATURE_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_set_handshake_finished(uint8_t *record, size_t *recordlen,
|
||
const uint8_t *verify_data,
|
||
size_t verify_data_len) {
|
||
int type = TLS_handshake_finished;
|
||
|
||
if (!record || !recordlen || !verify_data || !verify_data_len) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (verify_data_len != 12 && verify_data_len != 32) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
tls_record_set_handshake(record, recordlen, type, verify_data,
|
||
verify_data_len);
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_get_handshake_finished(const uint8_t *record,
|
||
const uint8_t **verify_data,
|
||
size_t *verify_data_len) {
|
||
int type;
|
||
|
||
if (!record || !verify_data || !verify_data_len) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_record_get_handshake(record, &type, verify_data, verify_data_len) !=
|
||
1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (type != TLS_handshake_finished) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (*verify_data == NULL || *verify_data_len == 0) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (*verify_data_len != 12 && *verify_data_len != 32) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_set_alert(uint8_t *record, size_t *recordlen, int alert_level,
|
||
int alert_description) {
|
||
if (!record || !recordlen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!tls_alert_level_name(alert_level)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!tls_alert_description_text(alert_description)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
record[0] = TLS_record_alert;
|
||
record[3] = 0; // length
|
||
record[4] = 2; // length
|
||
record[5] = (uint8_t)alert_level;
|
||
record[6] = (uint8_t)alert_description;
|
||
*recordlen = TLS_RECORD_HEADER_SIZE + 2;
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_get_alert(const uint8_t *record, int *alert_level,
|
||
int *alert_description) {
|
||
if (!record || !alert_level || !alert_description) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_record_type(record) != TLS_record_alert) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (record[3] != 0 || record[4] != 2) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
*alert_level = record[5];
|
||
*alert_description = record[6];
|
||
if (!tls_alert_level_name(*alert_level)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!tls_alert_description_text(*alert_description)) {
|
||
error_puts("warning");
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_set_change_cipher_spec(uint8_t *record, size_t *recordlen) {
|
||
if (!record || !recordlen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
record[0] = TLS_record_change_cipher_spec;
|
||
record[3] = 0;
|
||
record[4] = 1;
|
||
record[5] = TLS_change_cipher_spec;
|
||
*recordlen = TLS_RECORD_HEADER_SIZE + 1;
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_get_change_cipher_spec(const uint8_t *record) {
|
||
if (!record) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_record_type(record) != TLS_record_change_cipher_spec) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (record[3] != 0 || record[4] != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (record[5] != TLS_change_cipher_spec) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_set_application_data(uint8_t *record, size_t *recordlen,
|
||
const uint8_t *data, size_t datalen) {
|
||
if (!record || !recordlen || !data || !datalen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
record[0] = TLS_record_application_data;
|
||
record[3] = (datalen >> 8) & 0xff;
|
||
record[4] = datalen & 0xff;
|
||
memcpy(tls_record_data(record), data, datalen);
|
||
*recordlen = TLS_RECORD_HEADER_SIZE + datalen;
|
||
return 1;
|
||
}
|
||
|
||
int tls_record_get_application_data(uint8_t *record, const uint8_t **data,
|
||
size_t *datalen) {
|
||
if (!record || !data || !datalen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_record_type(record) != TLS_record_application_data) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
*datalen = ((size_t)record[3] << 8) | record[4];
|
||
*data = *datalen ? record + TLS_RECORD_HEADER_SIZE : 0;
|
||
return 1;
|
||
}
|
||
|
||
/*
 * Linear search for `cipher` in list[0..list_count).
 *
 * Returns 1 if found, 0 if not, -1 if the list is NULL or empty.
 */
int tls_cipher_suite_in_list(int cipher, const int *list, size_t list_count) {
  const int *it;
  const int *end;

  if (list == NULL || list_count == 0) {
    error_print();
    return -1;
  }
  for (it = list, end = list + list_count; it != end; ++it) {
    if (*it == cipher) {
      return 1;
    }
  }
  return 0;
}
|
||
|
||
int tls_record_send(const uint8_t *record, size_t recordlen,
|
||
tls_socket_t sock) {
|
||
tls_ret_t r;
|
||
|
||
if (!record) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (recordlen < TLS_RECORD_HEADER_SIZE) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (tls_record_length(record) != recordlen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if ((r = tls_socket_send(sock, record, recordlen, 0)) < 0) {
|
||
perror("tls_record_send");
|
||
error_print();
|
||
return -1;
|
||
} else if (r != recordlen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
/*
 * Read exactly `len` bytes from `sock` into `buf`, retrying on EAGAIN and on
 * short reads.
 *
 * Returns 1 when all bytes were read, 0 if the peer closed the connection
 * before that, -1 on a socket error.
 */
static int tls_socket_recv_full(tls_socket_t sock, uint8_t *buf, size_t len) {
  while (len > 0) {
    tls_ret_t r = tls_socket_recv(sock, buf, len, 0);
    if (r < 0) {
      if (errno == EAGAIN) {
        continue;
      }
      perror("tls_record_do_recv");
      return -1;
    }
    if (r == 0) {
      /* orderly shutdown by the peer; errno is not meaningful here */
      return 0;
    }
    buf += r;
    len -= (size_t)r;
  }
  return 1;
}

/*
 * Receive one raw (still protected) record into `record`.
 *
 * The header is read first and its type and protocol version are checked.
 * Then the body length is checked against TLS_MAX_RECORD_SIZE and the body
 * is read. Only the overall size limit is enforced here; length checks of
 * handshake messages are left to the upper-layer protocol.
 *
 * Returns:
 *   1  a complete record was read and *recordlen set;
 *   0  the peer closed the connection before a header arrived;
 *  -1  bad arguments, socket error, invalid header, oversized record, or the
 *      connection closed in the middle of a record.
 */
int tls_record_do_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock) {
  size_t datalen;
  int ret;

  if (!record || !recordlen) {
    error_print();
    return -1;
  }

  if ((ret = tls_socket_recv_full(sock, record, TLS_RECORD_HEADER_SIZE)) != 1) {
    if (ret < 0) {
      error_print();
    }
    return ret;
  }

  if (!tls_record_type_name(tls_record_type(record))) {
    error_print();
    return -1;
  }
  if (!tls_protocol_name(tls_record_protocol(record))) {
    error_print();
    return -1;
  }

  datalen = ((size_t)record[3] << 8) | record[4];
  *recordlen = TLS_RECORD_HEADER_SIZE + datalen;
  if (*recordlen > TLS_MAX_RECORD_SIZE) {
    error_print();
    return -1;
  }

  /* Unlike the header, EOF inside the body is a truncated record: an error. */
  if (tls_socket_recv_full(sock, record + TLS_RECORD_HEADER_SIZE, datalen) != 1) {
    error_print();
    return -1;
  }
  return 1;
}
|
||
|
||
/*
 * Receive the next non-Alert record, handling Alert records along the way.
 *
 * Returns:
 *   1  a non-Alert record is in `record` / *recordlen;
 *   0  an Alert ended the connection. A close_notify (at any level) is
 *      answered with a close_notify of our own; for any other fatal alert
 *      nothing is sent. Either way the caller must not send another alert;
 *  -1  receive error or malformed alert.
 * Warning-level alerts other than close_notify are traced and skipped.
 *
 * close_notify is checked before warnings are skipped because peers commonly
 * send it at warning level. Previously such a close_notify was ignored and
 * the call kept waiting for a record that would never arrive.
 */
int tls_record_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock) {
  for (;;) {
    int level;
    int alert;

    if (tls_record_do_recv(record, recordlen, sock) != 1) {
      error_print();
      return -1;
    }
    if (tls_record_type(record) != TLS_record_alert) {
      return 1;
    }

    if (tls_record_get_alert(record, &level, &alert) != 1) {
      error_print();
      return -1;
    }
    tls_record_trace(stderr, record, *recordlen, 0, 0);

    if (alert == TLS_alert_close_notify) {
      /* close_notify is the only alert that gets a reply */
      uint8_t alert_record[TLS_ALERT_RECORD_SIZE];
      size_t alert_record_len;

      tls_record_set_type(alert_record, TLS_record_alert);
      tls_record_set_protocol(alert_record, tls_record_protocol(record));
      if (tls_record_set_alert(alert_record, &alert_record_len,
                               TLS_alert_level_fatal,
                               TLS_alert_close_notify) == 1) {
        tls_trace("send Alert close_notifiy\n");
        tls_record_trace(stderr, alert_record, alert_record_len, 0, 0);
        tls_record_send(alert_record, alert_record_len, sock);
      }
      return 0;
    }

    if (level == TLS_alert_level_warning) {
      /* ignore the warning and read the next record */
      error_puts("Warning record received!\n");
      continue;
    }

    /* any other fatal alert: stop, no further alert needs to be sent */
    return 0;
  }
}
|
||
|
||
/*
 * Increment a 64-bit big-endian record sequence number in place.
 *
 * The carry is propagated through all eight bytes, including the most
 * significant byte seq_num[0]. Sequence numbers must not wrap (RFC 5246
 * section 6.1). If seq_num is already 0xffffffffffffffff it is left
 * unchanged and -1 is returned so the connection can be torn down.
 *
 * Returns 1 on success, -1 on overflow.
 */
int tls_seq_num_incr(uint8_t seq_num[8]) {
  int i;

  /* detect the all-0xff value that would wrap to zero */
  for (i = 0; i < 8; i++) {
    if (seq_num[i] != 0xff) {
      break;
    }
  }
  if (i == 8) {
    return -1;
  }

  for (i = 7; i >= 0; i--) {
    seq_num[i]++;
    if (seq_num[i] != 0) {
      break;
    }
  }
  return 1;
}
|
||
|
||
int tls_compression_methods_has_null_compression(const uint8_t *meths,
|
||
size_t methslen) {
|
||
if (!meths || !methslen) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
while (methslen--) {
|
||
if (*meths++ == TLS_compression_null) {
|
||
return 1;
|
||
}
|
||
}
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
/*
 * Send a fatal Alert with description `alert` on the connection's socket.
 * For TLS 1.3 connections the record-layer version is written as TLS 1.2.
 *
 * The alert record is only sent if it was built successfully. Previously an
 * unknown `alert` left the buffer and its length uninitialized and the
 * garbage was sent anyway.
 *
 * Returns 1 on success, -1 on failure.
 */
int tls_send_alert(TLS_CONNECT *conn, int alert) {
  uint8_t record[TLS_ALERT_RECORD_SIZE];
  size_t recordlen = 0;
  int record_protocol;

  if (!conn) {
    error_print();
    return -1;
  }

  record_protocol = (conn->protocol == TLS_protocol_tls13) ? TLS_protocol_tls12
                                                           : conn->protocol;
  if (tls_record_set_protocol(record, record_protocol) != 1 ||
      tls_record_set_alert(record, &recordlen, TLS_alert_level_fatal, alert) != 1) {
    error_print();
    return -1;
  }

  if (tls_record_send(record, recordlen, conn->sock) != 1) {
    error_print();
    return -1;
  }
  tls_record_trace(stderr, record, recordlen, 0, 0);
  return 1;
}
|
||
|
||
int tls_alert_level(int alert) {
|
||
switch (alert) {
|
||
case TLS_alert_bad_certificate:
|
||
case TLS_alert_unsupported_certificate:
|
||
case TLS_alert_certificate_revoked:
|
||
case TLS_alert_certificate_expired:
|
||
case TLS_alert_certificate_unknown:
|
||
return 0;
|
||
case TLS_alert_user_canceled:
|
||
case TLS_alert_no_renegotiation:
|
||
return TLS_alert_level_warning;
|
||
}
|
||
return TLS_alert_level_fatal;
|
||
}
|
||
|
||
/*
 * Send a warning-level Alert with description `alert`. Alerts whose level is
 * always fatal (per tls_alert_level) are refused. For TLS 1.3 the
 * record-layer version is written as TLS 1.2.
 *
 * The record is only sent if it was built successfully. Previously an
 * unknown `alert` caused uninitialized bytes to be sent.
 *
 * Returns 1 on success, -1 on failure.
 */
int tls_send_warning(TLS_CONNECT *conn, int alert) {
  uint8_t record[TLS_ALERT_RECORD_SIZE];
  size_t recordlen = 0;
  int record_protocol;

  if (!conn) {
    error_print();
    return -1;
  }
  if (tls_alert_level(alert) == TLS_alert_level_fatal) {
    error_print();
    return -1;
  }

  record_protocol = (conn->protocol == TLS_protocol_tls13) ? TLS_protocol_tls12
                                                           : conn->protocol;
  if (tls_record_set_protocol(record, record_protocol) != 1 ||
      tls_record_set_alert(record, &recordlen, TLS_alert_level_warning, alert) != 1) {
    error_print();
    return -1;
  }

  if (tls_record_send(record, recordlen, conn->sock) != 1) {
    error_print();
    return -1;
  }
  tls_record_trace(stderr, record, recordlen, 0, 0);
  return 1;
}
|
||
|
||
/*
 * Encrypt and send up to one record's worth of application data.
 *
 * At most TLS_MAX_PLAINTEXT_SIZE bytes of `in` are consumed per call. The
 * count actually sent is stored in *sentlen. The record is protected with
 * this side's write MAC context and SM4 key via tls_cbc_encrypt, and this
 * side's write sequence number is incremented before sending.
 *
 * Returns 1 on success, -1 on bad arguments, encryption or send failure.
 */
int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen,
             size_t *sentlen) {
  const SM3_HMAC_CTX *mac;
  const SM4_KEY *key;
  uint8_t *seq;
  uint8_t *rec;
  size_t cipherlen;

  if (!conn) {
    error_print();
    return -1;
  }
  if (!in || !inlen || !sentlen) {
    error_print();
    return -1;
  }

  /* one record per call */
  if (inlen > TLS_MAX_PLAINTEXT_SIZE) {
    inlen = TLS_MAX_PLAINTEXT_SIZE;
  }

  /* outgoing data uses our own write-direction state */
  mac = conn->is_client ? &conn->client_write_mac_ctx
                        : &conn->server_write_mac_ctx;
  key = conn->is_client ? &conn->client_write_enc_key
                        : &conn->server_write_enc_key;
  seq = conn->is_client ? conn->client_seq_num : conn->server_seq_num;
  rec = conn->record;

  tls_trace("send ApplicationData\n");

  /* The header (with the plaintext length) is passed to tls_cbc_encrypt,
     presumably as MAC input; the length is fixed up to the ciphertext size
     afterwards. */
  if (tls_record_set_type(rec, TLS_record_application_data) != 1
      || tls_record_set_protocol(rec, conn->protocol) != 1
      || tls_record_set_length(rec, inlen) != 1) {
    error_print();
    return -1;
  }
  if (tls_cbc_encrypt(mac, key, seq, tls_record_header(rec), in, inlen,
                      tls_record_data(rec), &cipherlen) != 1) {
    error_print();
    return -1;
  }
  if (tls_record_set_length(rec, cipherlen) != 1) {
    error_print();
    return -1;
  }
  tls_seq_num_incr(seq);

  if (tls_record_send(rec, tls_record_length(rec), conn->sock) != 1) {
    error_print();
    return -1;
  }
  *sentlen = inlen;
  tls_record_trace(stderr, rec, tls_record_length(rec), 0, 0);
  return 1;
}
|
||
|
||
/*
 * Receive one protected record and decrypt it into conn->databuf.
 *
 * On success conn->data points at the plaintext and conn->datalen holds its
 * length, and the peer's sequence number has been incremented. The return
 * value of tls_record_recv is passed through unchanged when it is not 1:
 * 0 means an alert ended the connection, -1 is an error.
 */
int tls_do_recv(TLS_CONNECT *conn) {
  int ret;
  const SM3_HMAC_CTX *mac;
  const SM4_KEY *key;
  uint8_t *seq;
  uint8_t *rec = conn->record;
  size_t reclen;

  /* incoming data was written by the peer, so use the peer's write state */
  if (!conn->is_client) {
    mac = &conn->client_write_mac_ctx;
    key = &conn->client_write_enc_key;
    seq = conn->client_seq_num;
  } else {
    mac = &conn->server_write_mac_ctx;
    key = &conn->server_write_enc_key;
    seq = conn->server_seq_num;
  }

  tls_trace("recv ApplicationData\n");
  ret = tls_record_recv(rec, &reclen, conn->sock);
  if (ret != 1) {
    if (ret < 0) {
      error_print();
    }
    return ret;
  }
  tls_record_trace(stderr, rec, reclen, 0, 0);

  if (tls_cbc_decrypt(mac, key, seq, rec, tls_record_data(rec),
                      tls_record_data_length(rec), conn->databuf,
                      &conn->datalen) != 1) {
    error_print();
    return -1;
  }
  conn->data = conn->databuf;
  tls_seq_num_incr(seq);

  /* put the plaintext back into the record buffer for tracing */
  tls_record_set_data(rec, conn->data, conn->datalen);
  tls_trace("decrypt ApplicationData\n");
  tls_record_trace(stderr, rec, tls_record_length(rec), 0, 0);
  return 1;
}
|
||
|
||
/*
 * Copy up to `outlen` bytes of decrypted application data into `out`.
 *
 * A new record is received only when the buffered plaintext is exhausted.
 * *recvlen receives the number of bytes copied.
 *
 * Returns 1 on success, 0 if an alert ended the connection, -1 on error.
 */
int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen) {
  size_t n;

  if (!conn || !out || !outlen || !recvlen) {
    error_print();
    return -1;
  }

  if (!conn->datalen) {
    int ret = tls_do_recv(conn);
    if (ret != 1) {
      if (ret) {
        error_print();
      }
      return ret;
    }
  }

  n = (conn->datalen < outlen) ? conn->datalen : outlen;
  *recvlen = n;
  memcpy(out, conn->data, n);
  conn->data += n;
  conn->datalen -= n;
  return 1;
}
|
||
|
||
/*
 * Send close_notify, then read and trace the next record from the peer.
 *
 * NOTE(review): the received record is accepted without checking that it is
 * actually a close_notify alert.
 *
 * Returns 1 on success, -1 on failure.
 */
int tls_shutdown(TLS_CONNECT *conn) {
  size_t reclen;

  if (conn == NULL) {
    error_print();
    return -1;
  }

  tls_trace("send Alert close_notify\n");
  if (tls_send_alert(conn, TLS_alert_close_notify) != 1) {
    error_print();
    return -1;
  }

  tls_trace("recv Alert close_notify\n");
  if (tls_record_do_recv(conn->record, &reclen, conn->sock) != 1) {
    error_print();
    return -1;
  }
  tls_record_trace(stderr, conn->record, reclen, 0, 0);
  return 1;
}
|
||
|
||
/*
 * Build a list of CA distinguished names from the subjects of every
 * certificate in the DER sequence `certs`.
 *
 * Each entry is written as a uint16 length followed by the DER SEQUENCE of
 * the subject Name. `names` is advanced as entries are written, and
 * *nameslen accumulates the total output length. `maxlen` is the capacity
 * of `names`. Every entry, including its uint16 length prefix, is charged
 * against it. Previously only the DER part was subtracted, so the capacity
 * check under-counted by two bytes per entry and could overflow `names`.
 *
 * Returns 1 on success, -1 on a parse error, a name longer than 65535 bytes,
 * or when the output would not fit in maxlen.
 */
int tls_authorities_from_certs(uint8_t *names, size_t *nameslen, size_t maxlen,
                               const uint8_t *certs, size_t certslen) {
  const uint8_t *cert;
  size_t certlen;
  const uint8_t *name;
  size_t namelen;

  if (!nameslen) {
    error_print();
    return -1;
  }

  *nameslen = 0;
  while (certslen) {
    size_t alen = 0;
    size_t entry_len;

    if (x509_cert_from_der(&cert, &certlen, &certs, &certslen) != 1 ||
        x509_cert_get_subject(cert, certlen, &name, &namelen) != 1 ||
        asn1_sequence_to_der(name, namelen, NULL, &alen) != 1) {
      error_print();
      return -1;
    }
    if (alen > UINT16_MAX) {
      error_print();
      return -1;
    }
    entry_len = tls_uint16_size() + alen;
    if (entry_len > maxlen) {
      error_print();
      return -1;
    }

    tls_uint16_to_bytes((uint16_t)alen, &names, nameslen);
    if (asn1_sequence_to_der(name, namelen, &names, nameslen) != 1) {
      error_print();
      return -1;
    }
    maxlen -= entry_len;
  }
  return 1;
}
|
||
|
||
/*
 * Check whether the issuer of the last certificate in `certs` appears in the
 * CA name list `ca_names`. The list is uint16-prefixed entries, each holding
 * exactly one DER SEQUENCE.
 *
 * Returns 1 if a matching name is found, 0 if none matches, -1 on a parse
 * error.
 */
int tls_authorities_issued_certificate(const uint8_t *ca_names,
                                       size_t ca_names_len,
                                       const uint8_t *certs, size_t certslen) {
  const uint8_t *last_cert;
  size_t last_cert_len;
  const uint8_t *issuer;
  size_t issuer_len;

  if (x509_certs_get_last(certs, certslen, &last_cert, &last_cert_len) != 1 ||
      x509_cert_get_issuer(last_cert, last_cert_len, &issuer, &issuer_len) != 1) {
    error_print();
    return -1;
  }

  while (ca_names_len > 0) {
    const uint8_t *entry;
    size_t entry_len;
    const uint8_t *dn;
    size_t dn_len;

    if (tls_uint16array_from_bytes(&entry, &entry_len, &ca_names,
                                   &ca_names_len) != 1) {
      error_print();
      return -1;
    }
    if (asn1_sequence_from_der(&dn, &dn_len, &entry, &entry_len) != 1 ||
        asn1_length_is_zero(entry_len) != 1) {
      error_print();
      return -1;
    }
    if (x509_name_equ(dn, dn_len, issuer, issuer_len) == 1) {
      return 1;
    }
  }

  error_print();
  return 0;
}
|
||
|
||
int tls_cert_types_accepted(const uint8_t *types, size_t types_len,
|
||
const uint8_t *client_certs,
|
||
size_t client_certs_len) {
|
||
const uint8_t *cert;
|
||
size_t certlen;
|
||
int sig_alg;
|
||
size_t i;
|
||
|
||
if (x509_certs_get_cert_by_index(client_certs, client_certs_len, 0, &cert,
|
||
&certlen) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if ((sig_alg = tls_cert_type_from_oid(OID_sm2sign_with_sm3)) < 0) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
for (i = 0; i < types_len; i++) {
|
||
if (sig_alg == types[i]) {
|
||
return 1;
|
||
}
|
||
}
|
||
return 0;
|
||
}
|
||
|
||
/*
 * Reset a client-verify context to all zeros (no buffered messages).
 * Returns 1 on success, -1 if ctx is NULL.
 */
int tls_client_verify_init(TLS_CLIENT_VERIFY_CTX *ctx) {
  if (ctx == NULL) {
    error_print();
    return -1;
  }
  memset(ctx, 0, sizeof(*ctx));
  return 1;
}
|
||
|
||
/*
 * Append a heap copy of one handshake message to the context. Up to 8
 * messages (slots 0..7) are buffered for tls_client_verify_finish. The
 * copies are released by tls_client_verify_cleanup.
 *
 * Returns 1 on success, -1 on bad arguments, a full context, or allocation
 * failure.
 */
int tls_client_verify_update(TLS_CLIENT_VERIFY_CTX *ctx,
                             const uint8_t *handshake, size_t handshake_len) {
  uint8_t *copy;

  if (ctx == NULL || handshake == NULL || handshake_len == 0) {
    error_print();
    return -1;
  }
  if (ctx->index < 0 || ctx->index >= 8) {
    error_print();
    return -1;
  }

  copy = malloc(handshake_len);
  if (copy == NULL) {
    error_print();
    return -1;
  }
  memcpy(copy, handshake, handshake_len);

  ctx->handshake[ctx->index] = copy;
  ctx->handshake_len[ctx->index] = handshake_len;
  ctx->index += 1;
  return 1;
}
|
||
|
||
/*
 * Verify the client's SM2 signature (default SM2 ID) over the 8 buffered
 * handshake messages, in the order they were added.
 *
 * Returns the result of sm2_verify_finish when it is non-negative
 * (presumably 1 = valid, 0 = invalid). Returns -1 on bad arguments, when
 * fewer or more than 8 messages were buffered, or on an internal error.
 */
int tls_client_verify_finish(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *sig,
                             size_t siglen, const SM2_KEY *public_key) {
  SM2_SIGN_CTX verifier;
  int k;
  int rv;

  if (!ctx || !sig || !siglen || !public_key) {
    error_print();
    return -1;
  }
  if (ctx->index != 8) {
    error_print();
    return -1;
  }

  if (sm2_verify_init(&verifier, public_key, SM2_DEFAULT_ID,
                      SM2_DEFAULT_ID_LENGTH) != 1) {
    error_print();
    return -1;
  }
  for (k = 0; k < 8; k++) {
    if (sm2_verify_update(&verifier, ctx->handshake[k],
                          ctx->handshake_len[k]) != 1) {
      error_print();
      return -1;
    }
  }

  rv = sm2_verify_finish(&verifier, sig, siglen);
  if (rv < 0) {
    error_print();
    return -1;
  }
  return rv;
}
|
||
|
||
/*
 * Free every buffered handshake copy and reset the context.
 *
 * `index` is reset to 0 as well. Previously it was left at its old value,
 * so a later tls_client_verify_finish/update would see NULL slots as if
 * they still held messages. Safe to call with NULL.
 */
void tls_client_verify_cleanup(TLS_CLIENT_VERIFY_CTX *ctx) {
  int i;

  if (!ctx) {
    return;
  }
  for (i = 0; i < ctx->index; i++) {
    free(ctx->handshake[i]);
    ctx->handshake[i] = NULL;
    ctx->handshake_len[i] = 0;
  }
  ctx->index = 0;
}
|
||
|
||
/*
 * Choose a cipher suite in server preference order: the first entry of
 * `server_ciphers` that also appears in the client's list of uint16 suites.
 *
 * Returns 1 and sets *selected_cipher on a match, 0 if there is no common
 * suite, -1 on bad arguments or a malformed (odd-length) client list.
 */
int tls_cipher_suites_select(const uint8_t *client_ciphers,
                             size_t client_ciphers_len,
                             const int *server_ciphers,
                             size_t server_ciphers_cnt, int *selected_cipher) {
  size_t s;

  if (!client_ciphers || !client_ciphers_len || !server_ciphers ||
      !server_ciphers_cnt || !selected_cipher) {
    error_print();
    return -1;
  }

  for (s = 0; s < server_ciphers_cnt; s++) {
    const uint8_t *cursor = client_ciphers;
    size_t left = client_ciphers_len;

    while (left > 0) {
      uint16_t offered;

      if (tls_uint16_from_bytes(&offered, &cursor, &left) != 1) {
        error_print();
        return -1;
      }
      if (offered == server_ciphers[s]) {
        *selected_cipher = server_ciphers[s];
        return 1;
      }
    }
  }
  return 0;
}
|
||
|
||
/*
 * Wipe the private keys, free the certificate buffers and zero the context.
 * Safe to call with NULL.
 */
void tls_ctx_cleanup(TLS_CTX *ctx) {
  if (ctx == NULL) {
    return;
  }
  gmssl_secure_clear(&ctx->signkey, sizeof(SM2_KEY));
  gmssl_secure_clear(&ctx->kenckey, sizeof(SM2_KEY));
  free(ctx->certs);   /* free(NULL) is a no-op */
  free(ctx->cacerts);
  memset(ctx, 0, sizeof(TLS_CTX));
}
|
||
|
||
/*
 * Zero `ctx` and set its protocol (TLCP, TLS 1.2 or TLS 1.3) and role.
 * The context is zeroed even when the protocol is rejected.
 *
 * Returns 1 on success, -1 on NULL ctx or an unsupported protocol.
 */
int tls_ctx_init(TLS_CTX *ctx, int protocol, int is_client) {
  int supported;

  if (!ctx) {
    error_print();
    return -1;
  }
  memset(ctx, 0, sizeof(*ctx));

  supported = protocol == TLS_protocol_tlcp
           || protocol == TLS_protocol_tls12
           || protocol == TLS_protocol_tls13;
  if (!supported) {
    error_print();
    return -1;
  }

  ctx->protocol = protocol;
  ctx->is_client = is_client ? 1 : 0;
  return 1;
}
|
||
|
||
/*
 * Install the list of enabled cipher suites in `ctx`.
 *
 * Every suite is validated with tls_cipher_suite_name before anything is
 * copied, so a bad list leaves ctx unchanged.
 *
 * Returns 1 on success, -1 on bad arguments, more than
 * TLS_MAX_CIPHER_SUITES_COUNT suites, or an unknown suite.
 */
int tls_ctx_set_cipher_suites(TLS_CTX *ctx, const int *cipher_suites,
                              size_t cipher_suites_cnt) {
  size_t n;

  if (!ctx || !cipher_suites || !cipher_suites_cnt) {
    error_print();
    return -1;
  }
  if (cipher_suites_cnt > TLS_MAX_CIPHER_SUITES_COUNT) {
    error_print();
    return -1;
  }

  for (n = 0; n < cipher_suites_cnt; n++) {
    if (!tls_cipher_suite_name(cipher_suites[n])) {
      error_print();
      return -1;
    }
  }

  n = 0;
  while (n < cipher_suites_cnt) {
    ctx->cipher_suites[n] = cipher_suites[n];
    n++;
  }
  ctx->cipher_suites_cnt = cipher_suites_cnt;
  return 1;
}
|
||
|
||
/*
 * Load trusted CA certificates from `cacertsfile` into ctx and set the
 * maximum verification depth.
 *
 * Can only be called once per context (ctx->cacerts must be unset). If the
 * file yields no certificates, the buffer is released and ctx->cacerts is
 * reset, so the context is not left half-configured and a later call can
 * retry. Previously the empty buffer stayed installed and blocked any retry.
 *
 * Returns 1 on success, -1 on bad arguments, an invalid depth or protocol,
 * CA certificates already set, or a load failure / empty file.
 */
int tls_ctx_set_ca_certificates(TLS_CTX *ctx, const char *cacertsfile,
                                int depth) {
  if (!ctx || !cacertsfile) {
    error_print();
    return -1;
  }
  if (depth < 0 || depth > TLS_MAX_VERIFY_DEPTH) {
    error_print();
    return -1;
  }
  if (!tls_protocol_name(ctx->protocol)) {
    error_print();
    return -1;
  }
  if (ctx->cacerts) {
    error_print();
    return -1;
  }

  if (x509_certs_new_from_file(&ctx->cacerts, &ctx->cacertslen,
                               cacertsfile) != 1) {
    error_print();
    return -1;
  }
  if (ctx->cacertslen == 0) {
    free(ctx->cacerts);
    ctx->cacerts = NULL;
    ctx->cacertslen = 0;
    error_print();
    return -1;
  }

  ctx->verify_depth = depth;
  return 1;
}
|
||
|
||
/*
 * Load a certificate chain and its (password-protected PEM) SM2 private key
 * into ctx.
 *
 * The private key must match the subject public key of the first
 * certificate in the chain. On success ctx takes ownership of the chain
 * buffer and a copy of the key. On any failure after the chain is loaded,
 * control goes through `end`, which frees the chain, closes the key file
 * and wipes the local key copy. Previously the two key-mismatch paths
 * returned directly and leaked all three.
 *
 * Returns 1 on success, -1 on failure.
 */
int tls_ctx_set_certificate_and_key(TLS_CTX *ctx, const char *chainfile,
                                    const char *keyfile, const char *keypass) {
  int ret = -1;
  uint8_t *certs = NULL;
  size_t certslen;
  FILE *keyfp = NULL;
  SM2_KEY key;
  const uint8_t *cert;
  size_t certlen;
  SM2_KEY public_key;

  if (!ctx || !chainfile || !keyfile || !keypass) {
    error_print();
    return -1;
  }
  if (!tls_protocol_name(ctx->protocol)) {
    error_print();
    return -1;
  }
  if (ctx->certs) {
    error_print();
    return -1;
  }

  if (x509_certs_new_from_file(&certs, &certslen, chainfile) != 1) {
    error_print();
    goto end;
  }
  if (!(keyfp = fopen(keyfile, "r"))) {
    error_print();
    goto end;
  }
  if (sm2_private_key_info_decrypt_from_pem(&key, keypass, keyfp) != 1) {
    error_print();
    goto end;
  }
  if (x509_certs_get_cert_by_index(certs, certslen, 0, &cert, &certlen) != 1 ||
      x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) {
    error_print();
    goto end;
  }
  if (sm2_public_key_equ(&key, &public_key) != 1) {
    error_print();
    goto end;
  }

  ctx->certs = certs;
  ctx->certslen = certslen;
  ctx->signkey = key;
  certs = NULL; /* ownership moved to ctx */
  ret = 1;

end:
  gmssl_secure_clear(&key, sizeof(key));
  if (certs) free(certs);
  if (keyfp) fclose(keyfp);
  return ret;
}
|
||
|
||
int tls_ctx_set_tlcp_server_certificate_and_keys(
|
||
TLS_CTX *ctx, const char *chainfile, const char *signkeyfile,
|
||
const char *signkeypass, const char *kenckeyfile, const char *kenckeypass) {
|
||
int ret = -1;
|
||
uint8_t *certs = NULL;
|
||
size_t certslen;
|
||
FILE *signkeyfp = NULL;
|
||
FILE *kenckeyfp = NULL;
|
||
SM2_KEY signkey;
|
||
SM2_KEY kenckey;
|
||
|
||
const uint8_t *cert;
|
||
size_t certlen;
|
||
SM2_KEY public_key;
|
||
|
||
if (!ctx || !chainfile || !signkeyfile || !signkeypass || !kenckeyfile ||
|
||
!kenckeypass) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (!tls_protocol_name(ctx->protocol)) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
if (ctx->certs) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
if (x509_certs_new_from_file(&certs, &certslen, chainfile) != 1) {
|
||
error_print();
|
||
return -1;
|
||
}
|
||
|
||
if (!(signkeyfp = fopen(signkeyfile, "r"))) {
|
||
error_print();
|
||
goto end;
|
||
}
|
||
if (sm2_private_key_info_decrypt_from_pem(&signkey, signkeypass,
|
||
signkeyfp) != 1) {
|
||
error_print();
|
||
goto end;
|
||
}
|
||
if (x509_certs_get_cert_by_index(certs, certslen, 0, &cert, &certlen) !=
|
||
1 ||
|
||
x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1 ||
|
||
sm2_public_key_equ(&signkey, &public_key) != 1) {
|
||
error_print();
|
||
goto end;
|
||
}
|
||
|
||
if (!(kenckeyfp = fopen(kenckeyfile, "r"))) {
|
||
error_print();
|
||
goto end;
|
||
}
|
||
if (sm2_private_key_info_decrypt_from_pem(&kenckey, kenckeypass,
|
||
kenckeyfp) != 1) {
|
||
error_print();
|
||
goto end;
|
||
}
|
||
if (x509_certs_get_cert_by_index(certs, certslen, 1, &cert, &certlen) !=
|
||
1 ||
|
||
x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1 ||
|
||
sm2_public_key_equ(&kenckey, &public_key) != 1) {
|
||
error_print();
|
||
goto end;
|
||
}
|
||
|
||
ctx->certs = certs;
|
||
ctx->certslen = certslen;
|
||
ctx->signkey = signkey;
|
||
ctx->kenckey = kenckey;
|
||
certs = NULL;
|
||
ret = 1;
|
||
|
||
end:
|
||
gmssl_secure_clear(&signkey, sizeof(signkey));
|
||
gmssl_secure_clear(&kenckey, sizeof(kenckey));
|
||
if (certs) free(certs);
|
||
if (signkeyfp) fclose(signkeyfp);
|
||
if (kenckeyfp) fclose(kenckeyfp);
|
||
return ret;
|
||
}
|
||
|
||
/*
 * Initialize a connection from a configured context.
 *
 * `conn` is zeroed, then the protocol, role, cipher suites, own certificate
 * chain (client_certs or server_certs by role), CA certificates and keys are
 * copied from ctx.
 *
 * The certificate copies are skipped when the source length is 0. The
 * source buffer is NULL in that case (for example a client without a
 * certificate, or no CA file), and memcpy from a NULL pointer is undefined
 * even for zero bytes.
 *
 * Returns 1 on success, -1 on NULL arguments or oversized certificate data.
 */
int tls_init(TLS_CONNECT *conn, const TLS_CTX *ctx) {
  size_t i;

  if (!conn || !ctx) {
    error_print();
    return -1;
  }
  memset(conn, 0, sizeof(*conn));

  conn->protocol = ctx->protocol;
  conn->is_client = ctx->is_client;
  for (i = 0; i < ctx->cipher_suites_cnt; i++) {
    conn->cipher_suites[i] = ctx->cipher_suites[i];
  }
  conn->cipher_suites_cnt = ctx->cipher_suites_cnt;

  if (ctx->certslen > TLS_MAX_CERTIFICATES_SIZE) {
    error_print();
    return -1;
  }
  if (conn->is_client) {
    if (ctx->certslen) {
      memcpy(conn->client_certs, ctx->certs, ctx->certslen);
    }
    conn->client_certs_len = ctx->certslen;
  } else {
    if (ctx->certslen) {
      memcpy(conn->server_certs, ctx->certs, ctx->certslen);
    }
    conn->server_certs_len = ctx->certslen;
  }

  if (ctx->cacertslen > TLS_MAX_CERTIFICATES_SIZE) {
    error_print();
    return -1;
  }
  if (ctx->cacertslen) {
    memcpy(conn->ca_certs, ctx->cacerts, ctx->cacertslen);
  }
  conn->ca_certs_len = ctx->cacertslen;

  conn->sign_key = ctx->signkey;
  conn->kenc_key = ctx->kenckey;
  return 1;
}
|
||
|
||
/*
 * Securely wipe the whole connection state, including keys and buffers.
 * Safe to call with NULL, matching tls_ctx_cleanup.
 */
void tls_cleanup(TLS_CONNECT *conn) {
  if (!conn) {
    return;
  }
  gmssl_secure_clear(conn, sizeof(TLS_CONNECT));
}
|
||
|
||
/*
 * Attach an already-connected socket to `conn`. The socket is used as-is;
 * its blocking mode is not changed. A disabled (#if 0) block that used to
 * force blocking mode via fcntl has been removed.
 *
 * Returns 1 on success, -1 if conn is NULL.
 */
int tls_set_socket(TLS_CONNECT *conn, tls_socket_t sock) {
  if (!conn) {
    error_print();
    return -1;
  }
  conn->sock = sock;
  return 1;
}
|
||
|
||
/*
 * Run the handshake for the connection's protocol and role: the *_do_connect
 * variant for clients, *_do_accept for servers.
 *
 * Returns the handshake routine's result, or -1 for an unknown protocol.
 */
int tls_do_handshake(TLS_CONNECT *conn) {
  switch (conn->protocol) {
    case TLS_protocol_tlcp:
      return conn->is_client ? tlcp_do_connect(conn) : tlcp_do_accept(conn);
    case TLS_protocol_tls12:
      return conn->is_client ? tls12_do_connect(conn) : tls12_do_accept(conn);
    case TLS_protocol_tls13:
      return conn->is_client ? tls13_do_connect(conn) : tls13_do_accept(conn);
    default:
      break;
  }
  error_print();
  return -1;
}
|
||
|
||
/*
 * Copy the peer-certificate verification result into *result.
 * Returns 1 on success, -1 if either argument is NULL.
 */
int tls_get_verify_result(TLS_CONNECT *conn, int *result) {
  if (!conn || !result) {
    error_print();
    return -1;
  }
  *result = conn->verify_result;
  return 1;
}
|