/* * Copyright 2014-2022 The GmSSL Project. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the License); you may * not use this file except in compliance with the License. * * http://www.apache.org/licenses/LICENSE-2.0 */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include void tls_uint8_to_bytes(uint8_t a, uint8_t **out, size_t *outlen) { if (out && *out) { *(*out)++ = a; } (*outlen)++; } void tls_uint16_to_bytes(uint16_t a, uint8_t **out, size_t *outlen) { if (out && *out) { *(*out)++ = (uint8_t)(a >> 8); *(*out)++ = (uint8_t)a; } *outlen += 2; } void tls_uint24_to_bytes(uint24_t a, uint8_t **out, size_t *outlen) { if (out && *out) { *(*out)++ = (uint8_t)(a >> 16); *(*out)++ = (uint8_t)(a >> 8); *(*out)++ = (uint8_t)(a); } (*outlen) += 3; } void tls_uint32_to_bytes(uint32_t a, uint8_t **out, size_t *outlen) { if (out && *out) { *(*out)++ = (uint8_t)(a >> 24); *(*out)++ = (uint8_t)(a >> 16); *(*out)++ = (uint8_t)(a >> 8); *(*out)++ = (uint8_t)(a); } (*outlen) += 4; } void tls_array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen) { if (out && *out) { if (data) { memcpy(*out, data, datalen); } *out += datalen; } *outlen += datalen; } /* 这几个函数要区分data = NULL, datalen = 0 和 data = NULL, datalen != 0的情况 前者意味着数据为空,因此输出的就是一个长度 后者意味着数据不为空,只是我们不想输出数据,只输出头部的长度,并且更新整个的输出长度。 这种情况应该避免! */ void tls_uint8array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen) { tls_uint8_to_bytes((uint8_t)datalen, out, outlen); tls_array_to_bytes(data, datalen, out, outlen); } void tls_uint16array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen) { tls_uint16_to_bytes((uint16_t)datalen, out, outlen); tls_array_to_bytes(data, datalen, out, outlen); } void tls_uint24array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen) { tls_uint24_to_bytes((uint24_t)datalen, out, outlen); tls_array_to_bytes(data, datalen, out, outlen); } int tls_uint8_from_bytes(uint8_t *a, const uint8_t **in, size_t *inlen) { if (*inlen < 1) { error_print(); return -1; } *a = *(*in)++; (*inlen)--; return 1; } int tls_uint16_from_bytes(uint16_t *a, const uint8_t **in, size_t *inlen) { if (*inlen < 2) { error_print(); return -1; } *a = *(*in)++; *a <<= 8; *a |= *(*in)++; *inlen -= 2; return 1; } int tls_uint24_from_bytes(uint24_t *a, const uint8_t **in, size_t *inlen) { if (*inlen < 3) { error_print(); return -1; } *a = *(*in)++; *a <<= 8; *a |= *(*in)++; *a <<= 8; *a |= *(*in)++; *inlen -= 3; return 1; } int tls_uint32_from_bytes(uint32_t *a, const uint8_t **in, size_t *inlen) { if (*inlen < 4) { error_print(); return -1; } *a = *(*in)++; *a <<= 8; *a |= *(*in)++; *a <<= 8; *a |= *(*in)++; *a <<= 8; *a |= *(*in)++; *inlen -= 4; return 1; } int tls_array_from_bytes(const uint8_t **data, size_t datalen, const uint8_t **in, size_t *inlen) { if (*inlen < datalen) { error_print(); return -1; } *data = *in; *in += datalen; *inlen -= datalen; return 1; } int tls_uint8array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen) { uint8_t len; if (tls_uint8_from_bytes(&len, in, inlen) != 1 || tls_array_from_bytes(data, len, in, inlen) != 1) { error_print(); return -1; } if (!len) { *data = NULL; } *datalen = len; return 1; } int tls_uint16array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen) { uint16_t len; if (tls_uint16_from_bytes(&len, in, inlen) != 1 || tls_array_from_bytes(data, len, in, inlen) != 1) { error_print(); return -1; } if (!len) { *data = NULL; } *datalen = len; return 1; } int tls_uint24array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen) { uint24_t len; if (tls_uint24_from_bytes(&len, in, inlen) != 1 || tls_array_from_bytes(data, len, in, inlen) != 1) { error_print(); return -1; } if (!len) { *data = NULL; } *datalen = len; return 1; } int tls_length_is_zero(size_t len) { if (len) { error_print(); return -1; } return 1; } int tls_record_set_type(uint8_t *record, int type) { if (!tls_record_type_name(type)) { error_print(); return -1; } record[0] = (uint8_t)type; return 1; } int tls_record_set_protocol(uint8_t *record, int protocol) { if (!tls_protocol_name(protocol)) { error_print(); return -1; } record[1] = (uint8_t)(protocol >> 8); record[2] = (uint8_t)(protocol); return 1; } int tls_record_set_length(uint8_t *record, size_t length) { uint8_t *p = record + 3; size_t len; if (length > TLS_MAX_CIPHERTEXT_SIZE) { error_print(); return -1; } tls_uint16_to_bytes((uint16_t)length, &p, &len); return 1; } int tls_record_set_data(uint8_t *record, const uint8_t *data, size_t datalen) { if (tls_record_set_length(record, datalen) != 1) { error_print(); return -1; } memcpy(tls_record_data(record), data, datalen); return 1; } int tls_cbc_encrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *enc_key, const uint8_t seq_num[8], const uint8_t header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { SM3_HMAC_CTX hmac_ctx; uint8_t last_blocks[32 + 16] = {0}; uint8_t *mac, *padding, *iv; int rem, padding_len; int i; if (!inited_hmac_ctx || !enc_key || !seq_num || !header || (!in && inlen) || !out || !outlen) { error_print(); return -1; } if (inlen > (1 << 14)) { error_print_msg("invalid tls record data length %zu\n", inlen); return -1; } if ((((size_t)header[3]) << 8) + header[4] != inlen) { error_print(); return -1; } rem = (inlen + 32) % 16; memcpy(last_blocks, in + inlen - rem, rem); mac = last_blocks + rem; memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); sm3_hmac_update(&hmac_ctx, seq_num, 8); sm3_hmac_update(&hmac_ctx, header, 5); sm3_hmac_update(&hmac_ctx, in, inlen); sm3_hmac_finish(&hmac_ctx, mac); padding = mac + 32; padding_len = 16 - rem - 1; for (i = 0; i <= padding_len; i++) { padding[i] = (uint8_t)padding_len; } iv = out; if (rand_bytes(iv, 16) != 1) { error_print(); return -1; } out += 16; if (inlen >= 16) { sm4_cbc_encrypt(enc_key, iv, in, inlen / 16, out); out += inlen - rem; iv = out - 16; } sm4_cbc_encrypt(enc_key, iv, last_blocks, sizeof(last_blocks) / 16, out); *outlen = 16 + inlen - rem + sizeof(last_blocks); return 1; } int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key, const uint8_t seq_num[8], const uint8_t enced_header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { SM3_HMAC_CTX hmac_ctx; const uint8_t *iv; const uint8_t *padding; const uint8_t *mac; uint8_t header[5]; int padding_len; uint8_t hmac[32]; int i; if (!inited_hmac_ctx || !dec_key || !seq_num || !enced_header || !in || !inlen || !out || !outlen) { error_print(); return -1; } if (inlen % 16 || inlen < (16 + 0 + 32 + 16) // iv + data + mac + padding || inlen > (16 + (1 << 14) + 32 + 256)) { error_print_msg("invalid tls cbc ciphertext length %zu\n", inlen); return -1; } iv = in; in += 16; inlen -= 16; sm4_cbc_decrypt(dec_key, iv, in, inlen / 16, out); padding_len = out[inlen - 1]; padding = out + inlen - padding_len - 1; if (padding < out + 32) { error_print(); return -1; } for (i = 0; i < padding_len; i++) { if (padding[i] != padding_len) { error_puts("tls ciphertext cbc-padding check failure"); return -1; } } *outlen = inlen - 32 - padding_len - 1; header[0] = enced_header[0]; header[1] = enced_header[1]; header[2] = enced_header[2]; header[3] = (uint8_t)((*outlen) >> 8); header[4] = (uint8_t)(*outlen); mac = padding - 32; memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); sm3_hmac_update(&hmac_ctx, seq_num, 8); sm3_hmac_update(&hmac_ctx, header, 5); sm3_hmac_update(&hmac_ctx, out, *outlen); sm3_hmac_finish(&hmac_ctx, hmac); if (gmssl_secure_memcmp(mac, hmac, sizeof(hmac)) != 0) { error_puts("tls ciphertext mac check failure\n"); return -1; } return 1; } int tls_record_encrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key, const uint8_t seq_num[8], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { if (tls_cbc_encrypt(hmac_ctx, cbc_key, seq_num, in, in + 5, inlen - 5, out + 5, outlen) != 1) { error_print(); return -1; } out[0] = in[0]; out[1] = in[1]; out[2] = in[2]; out[3] = (uint8_t)((*outlen) >> 8); out[4] = (uint8_t)(*outlen); (*outlen) += 5; return 1; } int tls_record_decrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key, const uint8_t seq_num[8], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { if (tls_cbc_decrypt(hmac_ctx, cbc_key, seq_num, in, in + 5, inlen - 5, out + 5, outlen) != 1) { error_print(); return -1; } out[0] = in[0]; out[1] = in[1]; out[2] = in[2]; out[3] = (uint8_t)((*outlen) >> 8); out[4] = (uint8_t)(*outlen); (*outlen) += 5; return 1; } int tls_random_generate(uint8_t random[32]) { uint32_t gmt_unix_time = (uint32_t)time(NULL); uint8_t *p = random; size_t len = 0; tls_uint32_to_bytes(gmt_unix_time, &p, &len); if (rand_bytes(random + 4, 28) != 1) { error_print(); return -1; } return 1; } int tls_prf(const uint8_t *secret, size_t secretlen, const char *label, const uint8_t *seed, size_t seedlen, const uint8_t *more, size_t morelen, size_t outlen, uint8_t *out) { SM3_HMAC_CTX inited_hmac_ctx; SM3_HMAC_CTX hmac_ctx; uint8_t A[32]; uint8_t hmac[32]; size_t len; if (!secret || !secretlen || !label || !seed || !seedlen || (!more && morelen) || !outlen || !out) { error_print(); return -1; } sm3_hmac_init(&inited_hmac_ctx, secret, secretlen); memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); sm3_hmac_update(&hmac_ctx, seed, seedlen); sm3_hmac_update(&hmac_ctx, more, morelen); sm3_hmac_finish(&hmac_ctx, A); memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); sm3_hmac_update(&hmac_ctx, A, sizeof(A)); sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); sm3_hmac_update(&hmac_ctx, seed, seedlen); sm3_hmac_update(&hmac_ctx, more, morelen); sm3_hmac_finish(&hmac_ctx, hmac); len = outlen < sizeof(hmac) ? outlen : sizeof(hmac); memcpy(out, hmac, len); out += len; outlen -= len; while (outlen) { memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); sm3_hmac_update(&hmac_ctx, A, sizeof(A)); sm3_hmac_finish(&hmac_ctx, A); memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); sm3_hmac_update(&hmac_ctx, A, sizeof(A)); sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); sm3_hmac_update(&hmac_ctx, seed, seedlen); sm3_hmac_update(&hmac_ctx, more, morelen); sm3_hmac_finish(&hmac_ctx, hmac); len = outlen < sizeof(hmac) ? outlen : sizeof(hmac); memcpy(out, hmac, len); out += len; outlen -= len; } return 1; } int tls_pre_master_secret_generate(uint8_t pre_master_secret[48], int protocol) { if (!tls_protocol_name(protocol)) { error_print(); return -1; } pre_master_secret[0] = (uint8_t)(protocol >> 8); pre_master_secret[1] = (uint8_t)(protocol); if (rand_bytes(pre_master_secret + 2, 46) != 1) { error_print(); return -1; } return 1; } // 用于设置CertificateRequest int tls_cert_type_from_oid(int oid) { switch (oid) { case OID_sm2sign_with_sm3: case OID_ecdsa_with_sha1: case OID_ecdsa_with_sha224: case OID_ecdsa_with_sha256: case OID_ecdsa_with_sha512: return TLS_cert_type_ecdsa_sign; case OID_rsasign_with_sm3: case OID_rsasign_with_md5: case OID_rsasign_with_sha1: case OID_rsasign_with_sha224: case OID_rsasign_with_sha256: case OID_rsasign_with_sha384: case OID_rsasign_with_sha512: return TLS_cert_type_rsa_sign; } // TLS_cert_type_xxx 中没有为0的值 return 0; } // 这两个函数没有对应的TLCP版本 int tls_sign_server_ecdh_params(const SM2_KEY *server_sign_key, const uint8_t client_random[32], const uint8_t server_random[32], int curve, const SM2_POINT *point, uint8_t *sig, size_t *siglen) { uint8_t server_ecdh_params[69]; SM2_SIGN_CTX sign_ctx; if (!server_sign_key || !client_random || !server_random || curve != TLS_curve_sm2p256v1 || !point || !sig || !siglen) { error_print(); return -1; } server_ecdh_params[0] = TLS_curve_type_named_curve; server_ecdh_params[1] = (uint8_t)(curve >> 8); server_ecdh_params[2] = (uint8_t)curve; server_ecdh_params[3] = 65; sm2_point_to_uncompressed_octets(point, server_ecdh_params + 4); sm2_sign_init(&sign_ctx, server_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH); sm2_sign_update(&sign_ctx, client_random, 32); sm2_sign_update(&sign_ctx, server_random, 32); sm2_sign_update(&sign_ctx, server_ecdh_params, 69); sm2_sign_finish(&sign_ctx, sig, siglen); return 1; } int tls_verify_server_ecdh_params(const SM2_KEY *server_sign_key, const uint8_t client_random[32], const uint8_t server_random[32], int curve, const SM2_POINT *point, const uint8_t *sig, size_t siglen) { int ret; uint8_t server_ecdh_params[69]; SM2_SIGN_CTX verify_ctx; if (!server_sign_key || !client_random || !server_random || curve != TLS_curve_sm2p256v1 || !point || !sig || !siglen || siglen > SM2_MAX_SIGNATURE_SIZE) { error_print(); return -1; } server_ecdh_params[0] = TLS_curve_type_named_curve; server_ecdh_params[1] = (uint8_t)(curve >> 8); server_ecdh_params[2] = (uint8_t)(curve); server_ecdh_params[3] = 65; sm2_point_to_uncompressed_octets(point, server_ecdh_params + 4); sm2_verify_init(&verify_ctx, server_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH); sm2_verify_update(&verify_ctx, client_random, 32); sm2_verify_update(&verify_ctx, server_random, 32); sm2_verify_update(&verify_ctx, server_ecdh_params, 69); ret = sm2_verify_finish(&verify_ctx, sig, siglen); if (ret != 1) error_print(); return ret; } int tls_record_set_handshake(uint8_t *record, size_t *recordlen, int type, const uint8_t *data, size_t datalen) { size_t handshakelen; if (!record || !recordlen) { error_print(); return -1; } // 由于ServerHelloDone没有负载数据,因此允许 data,datalen = NULL,0 if (datalen > TLS_MAX_PLAINTEXT_SIZE - TLS_HANDSHAKE_HEADER_SIZE) { error_print(); return -1; } if (!tls_protocol_name(tls_record_protocol(record))) { error_print(); return -1; } if (!tls_handshake_type_name(type)) { error_print(); return -1; } handshakelen = TLS_HANDSHAKE_HEADER_SIZE + datalen; record[0] = TLS_record_handshake; record[3] = (uint8_t)(handshakelen >> 8); record[4] = (uint8_t)(handshakelen); record[5] = (uint8_t)(type); record[6] = (uint8_t)(datalen >> 16); record[7] = (uint8_t)(datalen >> 8); record[8] = (uint8_t)(datalen); if (data) { memcpy(tls_handshake_data(tls_record_data(record)), data, datalen); } *recordlen = TLS_RECORD_HEADER_SIZE + handshakelen; return 1; } int tls_record_get_handshake(const uint8_t *record, int *type, const uint8_t **data, size_t *datalen) { const uint8_t *handshake; size_t handshake_len; uint24_t handshake_datalen; if (!record || !type || !data || !datalen) { error_print(); return -1; } if (!tls_protocol_name(tls_record_protocol(record))) { error_print(); return -1; } if (tls_record_type(record) != TLS_record_handshake) { error_print(); return -1; } handshake = tls_record_data(record); handshake_len = tls_record_data_length(record); if (handshake_len < TLS_HANDSHAKE_HEADER_SIZE) { error_print(); return -1; } if (handshake_len > TLS_MAX_PLAINTEXT_SIZE) { // 不支持证书长度超过记录长度的特殊情况 error_print(); return -1; } if (!tls_handshake_type_name(handshake[0])) { error_print(); return -1; } *type = handshake[0]; handshake++; handshake_len--; if (tls_uint24_from_bytes(&handshake_datalen, &handshake, &handshake_len) != 1) { error_print(); return -1; } if (handshake_len != handshake_datalen) { error_print(); return -1; } *data = handshake; *datalen = handshake_datalen; if (*datalen == 0) { *data = NULL; } return 1; } int tls_record_set_handshake_client_hello( uint8_t *record, size_t *recordlen, int protocol, const uint8_t random[32], const uint8_t *session_id, size_t session_id_len, const int *cipher_suites, size_t cipher_suites_count, const uint8_t *exts, size_t exts_len) { uint8_t type = TLS_handshake_client_hello; uint8_t *p; size_t len; if (!record || !recordlen || !random || !cipher_suites || !cipher_suites_count) { error_print(); return -1; } if (session_id) { if (!session_id_len || session_id_len < TLS_MAX_SESSION_ID_SIZE || session_id_len > TLS_MAX_SESSION_ID_SIZE) { error_print(); return -1; } } if (cipher_suites_count > TLS_MAX_CIPHER_SUITES_COUNT) { error_print(); return -1; } if (exts && !exts_len) { error_print(); return -1; } p = tls_handshake_data(tls_record_data(record)); len = 0; if (!tls_protocol_name(protocol)) { error_print(); return -1; } tls_uint16_to_bytes((uint16_t)protocol, &p, &len); tls_array_to_bytes(random, 32, &p, &len); tls_uint8array_to_bytes(session_id, session_id_len, &p, &len); tls_uint16_to_bytes((uint16_t)(cipher_suites_count * 2), &p, &len); while (cipher_suites_count--) { if (!tls_cipher_suite_name(*cipher_suites)) { error_print(); return -1; } tls_uint16_to_bytes((uint16_t)*cipher_suites, &p, &len); cipher_suites++; } tls_uint8_to_bytes(1, &p, &len); tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len); if (exts) { size_t tmp_len = len; if (protocol < TLS_protocol_tls12) { error_print(); return -1; } tls_uint16array_to_bytes(exts, exts_len, NULL, &tmp_len); if (tmp_len > TLS_MAX_HANDSHAKE_DATA_SIZE) { error_print(); return -1; } tls_uint16array_to_bytes(exts, exts_len, &p, &len); } if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) { error_print(); return -1; } return 1; } int tls_record_get_handshake_client_hello( const uint8_t *record, int *protocol, const uint8_t **random, const uint8_t **session_id, size_t *session_id_len, const uint8_t **cipher_suites, size_t *cipher_suites_len, const uint8_t **exts, size_t *exts_len) { int type; const uint8_t *p; size_t len; uint16_t ver; const uint8_t *comp_meths; size_t comp_meths_len; if (!record || !protocol || !random || !session_id || !session_id_len || !cipher_suites || !cipher_suites_len || !exts || !exts_len) { error_print(); return -1; } if (tls_record_get_handshake(record, &type, &p, &len) != 1) { error_print(); return -1; } if (type != TLS_handshake_client_hello) { error_print(); return -1; } if (tls_uint16_from_bytes(&ver, &p, &len) != 1 || tls_array_from_bytes(random, 32, &p, &len) != 1 || tls_uint8array_from_bytes(session_id, session_id_len, &p, &len) != 1 || tls_uint16array_from_bytes(cipher_suites, cipher_suites_len, &p, &len) != 1 || tls_uint8array_from_bytes(&comp_meths, &comp_meths_len, &p, &len) != 1) { error_print(); return -1; } if (!tls_protocol_name(ver)) { error_print(); return -1; } *protocol = ver; if (*session_id) { if (*session_id_len == 0 || *session_id_len < TLS_MIN_SESSION_ID_SIZE || *session_id_len > TLS_MAX_SESSION_ID_SIZE) { error_print(); return -1; } } if (!cipher_suites) { error_print(); return -1; } if (*cipher_suites_len % 2) { error_print(); return -1; } if (len) { if (tls_uint16array_from_bytes(exts, exts_len, &p, &len) != 1) { error_print(); return -1; } if (*exts == NULL) { error_print(); return -1; } } else { *exts = NULL; *exts_len = 0; } if (len) { error_print(); return -1; } return 1; } int tls_record_set_handshake_server_hello( uint8_t *record, size_t *recordlen, int protocol, const uint8_t random[32], const uint8_t *session_id, size_t session_id_len, int cipher_suite, const uint8_t *exts, size_t exts_len) { uint8_t type = TLS_handshake_server_hello; uint8_t *p; size_t len; if (!record || !recordlen || !random) { error_print(); return -1; } if (session_id) { if (session_id_len == 0 || session_id_len < TLS_MIN_SESSION_ID_SIZE || session_id_len > TLS_MAX_SESSION_ID_SIZE) { error_print(); return -1; } } if (!tls_protocol_name(protocol)) { error_print(); return -1; } if (!tls_cipher_suite_name(cipher_suite)) { error_print(); return -1; } p = tls_handshake_data(tls_record_data(record)); len = 0; tls_uint16_to_bytes((uint16_t)protocol, &p, &len); tls_array_to_bytes(random, 32, &p, &len); tls_uint8array_to_bytes(session_id, session_id_len, &p, &len); tls_uint16_to_bytes((uint16_t)cipher_suite, &p, &len); tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len); if (exts) { if (protocol < TLS_protocol_tls12) { error_print(); return -1; } tls_uint16array_to_bytes(exts, exts_len, &p, &len); } if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) { error_print(); return -1; } return 1; } int tls_record_get_handshake_server_hello( const uint8_t *record, int *protocol, const uint8_t **random, const uint8_t **session_id, size_t *session_id_len, int *cipher_suite, const uint8_t **exts, size_t *exts_len) { int type; const uint8_t *p; size_t len; uint16_t ver; uint16_t cipher; uint8_t comp_meth; if (!record || !protocol || !random || !session_id || !session_id_len || !cipher_suite || !exts || !exts_len) { error_print(); return -1; } if (tls_record_get_handshake(record, &type, &p, &len) != 1) { error_print(); return -1; } if (type != TLS_handshake_server_hello) { error_print(); return -1; } if (tls_uint16_from_bytes(&ver, &p, &len) != 1 || tls_array_from_bytes(random, 32, &p, &len) != 1 || tls_uint8array_from_bytes(session_id, session_id_len, &p, &len) != 1 || tls_uint16_from_bytes(&cipher, &p, &len) != 1 || tls_uint8_from_bytes(&comp_meth, &p, &len) != 1) { error_print(); return -1; } if (!tls_protocol_name(ver)) { error_print(); return -1; } if (ver < tls_record_protocol(record)) { error_print(); return -1; } *protocol = ver; if (*session_id) { if (*session_id == 0 || *session_id_len < TLS_MIN_SESSION_ID_SIZE || *session_id_len > TLS_MAX_SESSION_ID_SIZE) { error_print(); return -1; } } if (!tls_cipher_suite_name(cipher)) { error_print(); return -1; } *cipher_suite = cipher; if (comp_meth != TLS_compression_null) { error_print(); return -1; } if (len) { if (tls_uint16array_from_bytes(exts, exts_len, &p, &len) != 1) { error_print(); return -1; } if (*exts == NULL) { error_print(); return -1; } } else { *exts = NULL; *exts_len = 0; } if (len) { error_print(); return -1; } return 1; } int tls_record_set_handshake_certificate(uint8_t *record, size_t *recordlen, const uint8_t *certs, size_t certslen) { int type = TLS_handshake_certificate; uint8_t *data; size_t datalen; uint8_t *p; size_t len; if (!record || !recordlen || !certs || !certslen) { error_print(); return -1; } data = tls_handshake_data(tls_record_data(record)); p = data + tls_uint24_size(); datalen = tls_uint24_size(); len = 0; while (certslen) { const uint8_t *cert; size_t certlen; if (x509_cert_from_der(&cert, &certlen, &certs, &certslen) != 1) { error_print(); return -1; } tls_uint24array_to_bytes(cert, certlen, NULL, &datalen); if (datalen > TLS_MAX_HANDSHAKE_DATA_SIZE) { error_print(); return -1; } tls_uint24array_to_bytes(cert, certlen, &p, &len); } tls_uint24_to_bytes((uint24_t)len, &data, &len); tls_record_set_handshake(record, recordlen, type, NULL, datalen); return 1; } int tls_record_get_handshake_certificate(const uint8_t *record, uint8_t *certs, size_t *certslen) { int type; const uint8_t *data; size_t datalen; const uint8_t *cp; size_t len; if (tls_record_get_handshake(record, &type, &data, &datalen) != 1) { error_print(); return -1; } if (type != TLS_handshake_certificate) { error_print(); return -1; } if (tls_uint24array_from_bytes(&cp, &len, &data, &datalen) != 1) { error_print(); return -1; } *certslen = 0; while (len) { const uint8_t *a; size_t alen; const uint8_t *cert; size_t certlen; if (tls_uint24array_from_bytes(&a, &alen, &cp, &len) != 1) { error_print(); return -1; } if (x509_cert_from_der(&cert, &certlen, &a, &alen) != 1 || asn1_length_is_zero(alen) != 1 || x509_cert_to_der(cert, certlen, &certs, certslen) != 1) { error_print(); return -1; } } return 1; } int tls_record_set_handshake_certificate_request( uint8_t *record, size_t *recordlen, const uint8_t *cert_types, size_t cert_types_len, const uint8_t *ca_names, size_t ca_names_len) { int type = TLS_handshake_certificate_request; uint8_t *p; size_t len = 0; size_t datalen = 0; if (!record || !recordlen) { error_print(); return -1; } if (cert_types) { if (cert_types_len == 0 || cert_types_len > TLS_MAX_CERTIFICATE_TYPES) { error_print(); return -1; } } if (ca_names) { if (ca_names_len == 0 || ca_names_len > TLS_MAX_CA_NAMES_SIZE) { error_print(); return -1; } } tls_uint8array_to_bytes(cert_types, cert_types_len, NULL, &datalen); tls_uint16array_to_bytes(ca_names, ca_names_len, NULL, &datalen); if (datalen > TLS_MAX_HANDSHAKE_DATA_SIZE) { error_print(); return -1; } p = tls_handshake_data(tls_record_data(record)); tls_uint8array_to_bytes(cert_types, cert_types_len, &p, &len); tls_uint16array_to_bytes(ca_names, ca_names_len, &p, &len); tls_record_set_handshake(record, recordlen, type, NULL, datalen); return 1; } int tls_record_get_handshake_certificate_request(const uint8_t *record, const uint8_t **cert_types, size_t *cert_types_len, const uint8_t **ca_names, size_t *ca_names_len) { int type; const uint8_t *cp; size_t len; size_t i; if (!record || !cert_types || !cert_types_len || !ca_names || !ca_names_len) { error_print(); return -1; } if (tls_record_get_handshake(record, &type, &cp, &len) != 1) { error_print(); return -1; } if (type != TLS_handshake_certificate_request) { error_print(); return -1; } if (tls_uint8array_from_bytes(cert_types, cert_types_len, &cp, &len) != 1 || tls_uint16array_from_bytes(ca_names, ca_names_len, &cp, &len) != 1 || tls_length_is_zero(len) != 1) { error_print(); return -1; } if (*cert_types == NULL) { error_print(); return -1; } for (i = 0; i < *cert_types_len; i++) { if (!tls_cert_type_name((*cert_types)[i])) { error_print(); return -1; } } if (*ca_names) { const uint8_t *names = *ca_names; size_t nameslen = *ca_names_len; while (nameslen) { if (tls_uint16array_from_bytes(&cp, &len, &names, &nameslen) != 1) { error_print(); return -1; } } } return 1; } int tls_record_set_handshake_server_hello_done(uint8_t *record, size_t *recordlen) { int type = TLS_handshake_server_hello_done; if (!record || !recordlen) { error_print(); return -1; } tls_record_set_handshake(record, recordlen, type, NULL, 0); return 1; } int tls_record_get_handshake_server_hello_done(const uint8_t *record) { int type; const uint8_t *p; size_t len; if (!record) { error_print(); return -1; } if (tls_record_get_handshake(record, &type, &p, &len) != 1 || type != TLS_handshake_server_hello_done) { error_print(); return -1; } if (p != NULL || len != 0) { error_print(); return -1; } return 1; } int tls_record_set_handshake_client_key_exchange_pke(uint8_t *record, size_t *recordlen, const uint8_t *enced_pms, size_t enced_pms_len) { int type = TLS_handshake_client_key_exchange; uint8_t *p; size_t len = 0; if (!record || !recordlen || !enced_pms || !enced_pms_len) { error_print(); return -1; } if (enced_pms_len > TLS_MAX_HANDSHAKE_DATA_SIZE - tls_uint16_size()) { error_print(); return -1; } p = tls_handshake_data(tls_record_data(record)); tls_uint16array_to_bytes(enced_pms, enced_pms_len, &p, &len); tls_record_set_handshake(record, recordlen, type, NULL, len); return 1; } int tls_record_get_handshake_client_key_exchange_pke(const uint8_t *record, const uint8_t **enced_pms, size_t *enced_pms_len) { int type; const uint8_t *cp; size_t len; if (!record || !enced_pms || !enced_pms_len) { error_print(); return -1; } if (tls_record_get_handshake(record, &type, &cp, &len) != 1) { error_print(); return -1; } if (type != TLS_handshake_client_key_exchange) { error_print(); return -1; } if (tls_uint16array_from_bytes(enced_pms, enced_pms_len, &cp, &len) != 1 || tls_length_is_zero(len) != 1) { error_print(); return -1; } return 1; } int tls_record_set_handshake_certificate_verify(uint8_t *record, size_t *recordlen, const uint8_t *sig, size_t siglen) { int type = TLS_handshake_certificate_verify; if (!record || !recordlen || !sig || !siglen) { error_print(); return -1; } if (siglen > TLS_MAX_SIGNATURE_SIZE) { error_print(); return -1; } tls_record_set_handshake(record, recordlen, type, sig, siglen); return 1; } int tls_record_get_handshake_certificate_verify(const uint8_t *record, const uint8_t **sig, size_t *siglen) { int type; if (!record || !sig || !siglen) { error_print(); return -1; } if (tls_record_get_handshake(record, &type, sig, siglen) != 1) { error_print(); return -1; } if (type != TLS_handshake_certificate_verify) { error_print(); return -1; } if (*sig == NULL || *siglen == 0) { error_print(); return -1; } if (*siglen > TLS_MAX_SIGNATURE_SIZE) { error_print(); return -1; } return 1; } int tls_record_set_handshake_finished(uint8_t *record, size_t *recordlen, const uint8_t *verify_data, size_t verify_data_len) { int type = TLS_handshake_finished; if (!record || !recordlen || !verify_data || !verify_data_len) { error_print(); return -1; } if (verify_data_len != 12 && verify_data_len != 32) { error_print(); return -1; } tls_record_set_handshake(record, recordlen, type, verify_data, verify_data_len); return 1; } int tls_record_get_handshake_finished(const uint8_t *record, const uint8_t **verify_data, size_t *verify_data_len) { int type; if (!record || !verify_data || !verify_data_len) { error_print(); return -1; } if (tls_record_get_handshake(record, &type, verify_data, verify_data_len) != 1) { error_print(); return -1; } if (type != TLS_handshake_finished) { error_print(); return -1; } if (*verify_data == NULL || *verify_data_len == 0) { error_print(); return -1; } if (*verify_data_len != 12 && *verify_data_len != 32) { error_print(); return -1; } return 1; } int tls_record_set_alert(uint8_t *record, size_t *recordlen, int alert_level, int alert_description) { if (!record || !recordlen) { error_print(); return -1; } if (!tls_alert_level_name(alert_level)) { error_print(); return -1; } if (!tls_alert_description_text(alert_description)) { error_print(); return -1; } record[0] = TLS_record_alert; record[3] = 0; // length record[4] = 2; // length record[5] = (uint8_t)alert_level; record[6] = (uint8_t)alert_description; *recordlen = TLS_RECORD_HEADER_SIZE + 2; return 1; } int tls_record_get_alert(const uint8_t *record, int *alert_level, int *alert_description) { if (!record || !alert_level || !alert_description) { error_print(); return -1; } if (tls_record_type(record) != TLS_record_alert) { error_print(); return -1; } if (record[3] != 0 || record[4] != 2) { error_print(); return -1; } *alert_level = record[5]; *alert_description = record[6]; if (!tls_alert_level_name(*alert_level)) { error_print(); return -1; } if (!tls_alert_description_text(*alert_description)) { error_puts("warning"); return -1; } return 1; } int tls_record_set_change_cipher_spec(uint8_t *record, size_t *recordlen) { if (!record || !recordlen) { error_print(); return -1; } record[0] = TLS_record_change_cipher_spec; record[3] = 0; record[4] = 1; record[5] = TLS_change_cipher_spec; *recordlen = TLS_RECORD_HEADER_SIZE + 1; return 1; } int tls_record_get_change_cipher_spec(const uint8_t *record) { if (!record) { error_print(); return -1; } if (tls_record_type(record) != TLS_record_change_cipher_spec) { error_print(); return -1; } if (record[3] != 0 || record[4] != 1) { error_print(); return -1; } if (record[5] != TLS_change_cipher_spec) { error_print(); return -1; } return 1; } int tls_record_set_application_data(uint8_t *record, size_t *recordlen, const uint8_t *data, size_t datalen) { if (!record || !recordlen || !data || !datalen) { error_print(); return -1; } record[0] = TLS_record_application_data; record[3] = (datalen >> 8) & 0xff; record[4] = datalen & 0xff; memcpy(tls_record_data(record), data, datalen); *recordlen = TLS_RECORD_HEADER_SIZE + datalen; return 1; } int tls_record_get_application_data(uint8_t *record, const uint8_t **data, size_t *datalen) { if (!record || !data || !datalen) { error_print(); return -1; } if (tls_record_type(record) != TLS_record_application_data) { error_print(); return -1; } *datalen = ((size_t)record[3] << 8) | record[4]; *data = *datalen ? record + TLS_RECORD_HEADER_SIZE : 0; return 1; } int tls_cipher_suite_in_list(int cipher, const int *list, size_t list_count) { size_t i; if (!list || !list_count) { error_print(); return -1; } for (i = 0; i < list_count; i++) { if (cipher == list[i]) { return 1; } } return 0; } int tls_record_send(const uint8_t *record, size_t recordlen, tls_socket_t sock) { tls_ret_t r; if (!record) { error_print(); return -1; } if (recordlen < TLS_RECORD_HEADER_SIZE) { error_print(); return -1; } if (tls_record_length(record) != recordlen) { error_print(); return -1; } if ((r = tls_socket_send(sock, record, recordlen, 0)) < 0) { perror("tls_record_send"); error_print(); return -1; } else if (r != recordlen) { error_print(); return -1; } return 1; } int tls_record_do_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock) { tls_ret_t r; size_t len; len = 5; while (len) { while ((r = tls_socket_recv(sock, record + 5 - len, len, 0)) < 0) { if (errno == EAGAIN) { continue; } else { error_print(); return -1; } } error_print(); if (r == 0) { perror("tls_record_do_recv"); error_print(); return 0; } len -= r; } if (!tls_record_type_name(tls_record_type(record))) { error_print(); return -1; } if (!tls_protocol_name(tls_record_protocol(record))) { error_print(); return -1; } len = (size_t)record[3] << 8 | record[4]; *recordlen = 5 + len; if (*recordlen > TLS_MAX_RECORD_SIZE) { // 这里只检查是否超过最大长度,握手协议的长度检查由上层协议完成 error_print(); return -1; } while (len) { if ((r = recv(sock, record + *recordlen - len, len, 0)) < 0) { perror("tls_record_do_recv"); error_print(); return -1; } len -= r; } return 1; } int tls_record_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock) { retry: if (tls_record_do_recv(record, recordlen, sock) != 1) { error_print(); return -1; } if (tls_record_type(record) == TLS_record_alert) { int level; int alert; if (tls_record_get_alert(record, &level, &alert) != 1) { error_print(); return -1; } tls_record_trace(stderr, record, *recordlen, 0, 0); if (level == TLS_alert_level_warning) { // 忽略Warning,读取下一个记录 error_puts("Warning record received!\n"); goto retry; } if (alert == TLS_alert_close_notify) { // close_notify是唯一需要提供反馈的Fatal Alert,其他直接中止连接 uint8_t alert_record[TLS_ALERT_RECORD_SIZE]; size_t alert_record_len; tls_record_set_type(alert_record, TLS_record_alert); tls_record_set_protocol(alert_record, tls_record_protocol(record)); tls_record_set_alert(alert_record, &alert_record_len, TLS_alert_level_fatal, TLS_alert_close_notify); tls_trace("send Alert close_notifiy\n"); tls_record_trace(stderr, alert_record, alert_record_len, 0, 0); tls_record_send(alert_record, alert_record_len, sock); } // 返回错误0通知调用方不再做任何处理(无需再发送Alert) return 0; } return 1; } int tls_seq_num_incr(uint8_t seq_num[8]) { int i; for (i = 7; i > 0; i--) { seq_num[i]++; if (seq_num[i]) break; } // FIXME: 检查溢出 return 1; } int tls_compression_methods_has_null_compression(const uint8_t *meths, size_t methslen) { if (!meths || !methslen) { error_print(); return -1; } while (methslen--) { if (*meths++ == TLS_compression_null) { return 1; } } error_print(); return -1; } int tls_send_alert(TLS_CONNECT *conn, int alert) { uint8_t record[5 + 2]; size_t recordlen; if (!conn) { error_print(); return -1; } tls_record_set_protocol(record, conn->protocol == TLS_protocol_tls13 ? TLS_protocol_tls12 : conn->protocol); tls_record_set_alert(record, &recordlen, TLS_alert_level_fatal, alert); if (tls_record_send(record, sizeof(record), conn->sock) != 1) { error_print(); return -1; } tls_record_trace(stderr, record, sizeof(record), 0, 0); return 1; } int tls_alert_level(int alert) { switch (alert) { case TLS_alert_bad_certificate: case TLS_alert_unsupported_certificate: case TLS_alert_certificate_revoked: case TLS_alert_certificate_expired: case TLS_alert_certificate_unknown: return 0; case TLS_alert_user_canceled: case TLS_alert_no_renegotiation: return TLS_alert_level_warning; } return TLS_alert_level_fatal; } int tls_send_warning(TLS_CONNECT *conn, int alert) { uint8_t record[5 + 2]; size_t recordlen; if (!conn) { error_print(); return -1; } if (tls_alert_level(alert) == TLS_alert_level_fatal) { error_print(); return -1; } tls_record_set_protocol(record, conn->protocol == TLS_protocol_tls13 ? TLS_protocol_tls12 : conn->protocol); tls_record_set_alert(record, &recordlen, TLS_alert_level_warning, alert); if (tls_record_send(record, sizeof(record), conn->sock) != 1) { error_print(); return -1; } tls_record_trace(stderr, record, sizeof(record), 0, 0); return 1; } int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen) { const SM3_HMAC_CTX *hmac_ctx; const SM4_KEY *enc_key; uint8_t *seq_num; uint8_t *record; size_t datalen; if (!conn) { error_print(); return -1; } if (!in || !inlen || !sentlen) { error_print(); return -1; } if (inlen > TLS_MAX_PLAINTEXT_SIZE) { inlen = TLS_MAX_PLAINTEXT_SIZE; } if (conn->is_client) { hmac_ctx = &conn->client_write_mac_ctx; enc_key = &conn->client_write_enc_key; seq_num = conn->client_seq_num; } else { hmac_ctx = &conn->server_write_mac_ctx; enc_key = &conn->server_write_enc_key; seq_num = conn->server_seq_num; } record = conn->record; tls_trace("send ApplicationData\n"); if (tls_record_set_type(record, TLS_record_application_data) != 1 || tls_record_set_protocol(record, conn->protocol) != 1 || tls_record_set_length(record, inlen) != 1) { error_print(); return -1; } if (tls_cbc_encrypt(hmac_ctx, enc_key, seq_num, tls_record_header(record), in, inlen, tls_record_data(record), &datalen) != 1) { error_print(); return -1; } if (tls_record_set_length(record, datalen) != 1) { error_print(); return -1; } tls_seq_num_incr(seq_num); if (tls_record_send(record, tls_record_length(record), conn->sock) != 1) { error_print(); return -1; } *sentlen = inlen; tls_record_trace(stderr, record, tls_record_length(record), 0, 0); return 1; } int tls_do_recv(TLS_CONNECT *conn) { int ret; const SM3_HMAC_CTX *hmac_ctx; const SM4_KEY *dec_key; uint8_t *seq_num; uint8_t *record = conn->record; size_t recordlen; if (conn->is_client) { hmac_ctx = &conn->server_write_mac_ctx; dec_key = &conn->server_write_enc_key; seq_num = conn->server_seq_num; } else { hmac_ctx = &conn->client_write_mac_ctx; dec_key = &conn->client_write_enc_key; seq_num = conn->client_seq_num; } tls_trace("recv ApplicationData\n"); if ((ret = tls_record_recv(record, &recordlen, conn->sock)) != 1) { if (ret < 0) error_print(); return ret; } tls_record_trace(stderr, record, recordlen, 0, 0); if (tls_cbc_decrypt(hmac_ctx, dec_key, seq_num, record, tls_record_data(record), tls_record_data_length(record), conn->databuf, &conn->datalen) != 1) { error_print(); return -1; } conn->data = conn->databuf; tls_seq_num_incr(seq_num); tls_record_set_data(record, conn->data, conn->datalen); tls_trace("decrypt ApplicationData\n"); tls_record_trace(stderr, record, tls_record_length(record), 0, 0); return 1; } int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen) { if (!conn || !out || !outlen || !recvlen) { error_print(); return -1; } if (conn->datalen == 0) { int ret; if ((ret = tls_do_recv(conn)) != 1) { if (ret) error_print(); return ret; } } *recvlen = outlen <= conn->datalen ? outlen : conn->datalen; memcpy(out, conn->data, *recvlen); conn->data += *recvlen; conn->datalen -= *recvlen; return 1; } int tls_shutdown(TLS_CONNECT *conn) { size_t recordlen; if (!conn) { error_print(); return -1; } tls_trace("send Alert close_notify\n"); if (tls_send_alert(conn, TLS_alert_close_notify) != 1) { error_print(); return -1; } tls_trace("recv Alert close_notify\n"); if (tls_record_do_recv(conn->record, &recordlen, conn->sock) != 1) { error_print(); return -1; } tls_record_trace(stderr, conn->record, recordlen, 0, 0); return 1; } int tls_authorities_from_certs(uint8_t *names, size_t *nameslen, size_t maxlen, const uint8_t *certs, size_t certslen) { const uint8_t *cert; size_t certlen; const uint8_t *name; size_t namelen; *nameslen = 0; while (certslen) { size_t alen = 0; if (x509_cert_from_der(&cert, &certlen, &certs, &certslen) != 1 || x509_cert_get_subject(cert, certlen, &name, &namelen) != 1 || asn1_sequence_to_der(name, namelen, NULL, &alen) != 1) { error_print(); return -1; } if (tls_uint16_size() + alen > maxlen) { error_print(); return -1; } if (alen > UINT16_MAX) { error_print(); return -1; } tls_uint16_to_bytes((uint16_t)alen, &names, nameslen); if (asn1_sequence_to_der(name, namelen, &names, nameslen) != 1) { error_print(); return -1; } maxlen -= alen; } return 1; } int tls_authorities_issued_certificate(const uint8_t *ca_names, size_t ca_names_len, const uint8_t *certs, size_t certslen) { const uint8_t *cert; size_t certlen; const uint8_t *issuer; size_t issuer_len; if (x509_certs_get_last(certs, certslen, &cert, &certlen) != 1 || x509_cert_get_issuer(cert, certlen, &issuer, &issuer_len) != 1) { error_print(); return -1; } while (ca_names_len) { const uint8_t *p; size_t len; const uint8_t *name; size_t namelen; if (tls_uint16array_from_bytes(&p, &len, &ca_names, &ca_names_len) != 1) { error_print(); return -1; } if (asn1_sequence_from_der(&name, &namelen, &p, &len) != 1 || asn1_length_is_zero(len) != 1) { error_print(); return -1; } if (x509_name_equ(name, namelen, issuer, issuer_len) == 1) { return 1; } } error_print(); return 0; } int tls_cert_types_accepted(const uint8_t *types, size_t types_len, const uint8_t *client_certs, size_t client_certs_len) { const uint8_t *cert; size_t certlen; int sig_alg; size_t i; if (x509_certs_get_cert_by_index(client_certs, client_certs_len, 0, &cert, &certlen) != 1) { error_print(); return -1; } if ((sig_alg = tls_cert_type_from_oid(OID_sm2sign_with_sm3)) < 0) { error_print(); return -1; } for (i = 0; i < types_len; i++) { if (sig_alg == types[i]) { return 1; } } return 0; } int tls_client_verify_init(TLS_CLIENT_VERIFY_CTX *ctx) { if (!ctx) { error_print(); return -1; } memset(ctx, 0, sizeof(TLS_CLIENT_VERIFY_CTX)); return 1; } int tls_client_verify_update(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *handshake, size_t handshake_len) { uint8_t *buf; if (!ctx || !handshake || !handshake_len) { error_print(); return -1; } if (ctx->index < 0 || ctx->index > 7) { error_print(); return -1; } if (!(buf = malloc(handshake_len))) { error_print(); return -1; } memcpy(buf, handshake, handshake_len); ctx->handshake[ctx->index] = buf; ctx->handshake_len[ctx->index] = handshake_len; ctx->index++; return 1; } int tls_client_verify_finish(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *sig, size_t siglen, const SM2_KEY *public_key) { int ret; SM2_SIGN_CTX sm2_ctx; int i; if (!ctx || !sig || !siglen || !public_key) { error_print(); return -1; } if (ctx->index != 8) { error_print(); return -1; } if (sm2_verify_init(&sm2_ctx, public_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1) { error_print(); return -1; } for (i = 0; i < 8; i++) { if (sm2_verify_update(&sm2_ctx, ctx->handshake[i], ctx->handshake_len[i]) != 1) { error_print(); return -1; } } if ((ret = sm2_verify_finish(&sm2_ctx, sig, siglen)) < 0) { error_print(); return -1; } return ret; } void tls_client_verify_cleanup(TLS_CLIENT_VERIFY_CTX *ctx) { if (ctx) { int i; for (i = 0; i < ctx->index; i++) { if (ctx->handshake[i]) { free(ctx->handshake[i]); ctx->handshake[i] = NULL; ctx->handshake_len[i] = 0; } } } } int tls_cipher_suites_select(const uint8_t *client_ciphers, size_t client_ciphers_len, const int *server_ciphers, size_t server_ciphers_cnt, int *selected_cipher) { if (!client_ciphers || !client_ciphers_len || !server_ciphers || !server_ciphers_cnt || !selected_cipher) { error_print(); return -1; } while (server_ciphers_cnt--) { const uint8_t *p = client_ciphers; size_t len = client_ciphers_len; while (len) { uint16_t cipher; if (tls_uint16_from_bytes(&cipher, &p, &len) != 1) { error_print(); return -1; } if (cipher == *server_ciphers) { *selected_cipher = *server_ciphers; return 1; } } server_ciphers++; } return 0; } void tls_ctx_cleanup(TLS_CTX *ctx) { if (ctx) { gmssl_secure_clear(&ctx->signkey, sizeof(SM2_KEY)); gmssl_secure_clear(&ctx->kenckey, sizeof(SM2_KEY)); if (ctx->certs) free(ctx->certs); if (ctx->cacerts) free(ctx->cacerts); memset(ctx, 0, sizeof(TLS_CTX)); } } int tls_ctx_init(TLS_CTX *ctx, int protocol, int is_client) { if (!ctx) { error_print(); return -1; } memset(ctx, 0, sizeof(*ctx)); switch (protocol) { case TLS_protocol_tlcp: case TLS_protocol_tls12: case TLS_protocol_tls13: ctx->protocol = protocol; break; default: error_print(); return -1; } ctx->is_client = is_client ? 1 : 0; return 1; } int tls_ctx_set_cipher_suites(TLS_CTX *ctx, const int *cipher_suites, size_t cipher_suites_cnt) { size_t i; if (!ctx || !cipher_suites || !cipher_suites_cnt) { error_print(); return -1; } if (cipher_suites_cnt < 1 || cipher_suites_cnt > TLS_MAX_CIPHER_SUITES_COUNT) { error_print(); return -1; } for (i = 0; i < cipher_suites_cnt; i++) { if (!tls_cipher_suite_name(cipher_suites[i])) { error_print(); return -1; } } for (i = 0; i < cipher_suites_cnt; i++) { ctx->cipher_suites[i] = cipher_suites[i]; } ctx->cipher_suites_cnt = cipher_suites_cnt; return 1; } int tls_ctx_set_ca_certificates(TLS_CTX *ctx, const char *cacertsfile, int depth) { if (!ctx || !cacertsfile) { error_print(); return -1; } if (depth < 0 || depth > TLS_MAX_VERIFY_DEPTH) { error_print(); return -1; } if (!tls_protocol_name(ctx->protocol)) { error_print(); return -1; } if (ctx->cacerts) { error_print(); return -1; } if (x509_certs_new_from_file(&ctx->cacerts, &ctx->cacertslen, cacertsfile) != 1) { error_print(); return -1; } if (ctx->cacertslen == 0) { error_print(); return -1; } ctx->verify_depth = depth; return 1; } int tls_ctx_set_certificate_and_key(TLS_CTX *ctx, const char *chainfile, const char *keyfile, const char *keypass) { int ret = -1; uint8_t *certs = NULL; size_t certslen; FILE *keyfp = NULL; SM2_KEY key; const uint8_t *cert; size_t certlen; SM2_KEY public_key; if (!ctx || !chainfile || !keyfile || !keypass) { error_print(); return -1; } if (!tls_protocol_name(ctx->protocol)) { error_print(); return -1; } if (ctx->certs) { error_print(); return -1; } if (x509_certs_new_from_file(&certs, &certslen, chainfile) != 1) { error_print(); goto end; } if (!(keyfp = fopen(keyfile, "r"))) { error_print(); goto end; } if (sm2_private_key_info_decrypt_from_pem(&key, keypass, keyfp) != 1) { error_print(); goto end; } if (x509_certs_get_cert_by_index(certs, certslen, 0, &cert, &certlen) != 1 || x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) { error_print(); return -1; } if (sm2_public_key_equ(&key, &public_key) != 1) { error_print(); return -1; } ctx->certs = certs; ctx->certslen = certslen; ctx->signkey = key; certs = NULL; ret = 1; end: gmssl_secure_clear(&key, sizeof(key)); if (certs) free(certs); if (keyfp) fclose(keyfp); return ret; } int tls_ctx_set_tlcp_server_certificate_and_keys( TLS_CTX *ctx, const char *chainfile, const char *signkeyfile, const char *signkeypass, const char *kenckeyfile, const char *kenckeypass) { int ret = -1; uint8_t *certs = NULL; size_t certslen; FILE *signkeyfp = NULL; FILE *kenckeyfp = NULL; SM2_KEY signkey; SM2_KEY kenckey; const uint8_t *cert; size_t certlen; SM2_KEY public_key; if (!ctx || !chainfile || !signkeyfile || !signkeypass || !kenckeyfile || !kenckeypass) { error_print(); return -1; } if (!tls_protocol_name(ctx->protocol)) { error_print(); return -1; } if (ctx->certs) { error_print(); return -1; } if (x509_certs_new_from_file(&certs, &certslen, chainfile) != 1) { error_print(); return -1; } if (!(signkeyfp = fopen(signkeyfile, "r"))) { error_print(); goto end; } if (sm2_private_key_info_decrypt_from_pem(&signkey, signkeypass, signkeyfp) != 1) { error_print(); goto end; } if (x509_certs_get_cert_by_index(certs, certslen, 0, &cert, &certlen) != 1 || x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1 || sm2_public_key_equ(&signkey, &public_key) != 1) { error_print(); goto end; } if (!(kenckeyfp = fopen(kenckeyfile, "r"))) { error_print(); goto end; } if (sm2_private_key_info_decrypt_from_pem(&kenckey, kenckeypass, kenckeyfp) != 1) { error_print(); goto end; } if (x509_certs_get_cert_by_index(certs, certslen, 1, &cert, &certlen) != 1 || x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1 || sm2_public_key_equ(&kenckey, &public_key) != 1) { error_print(); goto end; } ctx->certs = certs; ctx->certslen = certslen; ctx->signkey = signkey; ctx->kenckey = kenckey; certs = NULL; ret = 1; end: gmssl_secure_clear(&signkey, sizeof(signkey)); gmssl_secure_clear(&kenckey, sizeof(kenckey)); if (certs) free(certs); if (signkeyfp) fclose(signkeyfp); if (kenckeyfp) fclose(kenckeyfp); return ret; } int tls_init(TLS_CONNECT *conn, const TLS_CTX *ctx) { size_t i; memset(conn, 0, sizeof(*conn)); conn->protocol = ctx->protocol; conn->is_client = ctx->is_client; for (i = 0; i < ctx->cipher_suites_cnt; i++) { conn->cipher_suites[i] = ctx->cipher_suites[i]; } conn->cipher_suites_cnt = ctx->cipher_suites_cnt; if (ctx->certslen > TLS_MAX_CERTIFICATES_SIZE) { error_print(); return -1; } if (conn->is_client) { memcpy(conn->client_certs, ctx->certs, ctx->certslen); conn->client_certs_len = ctx->certslen; } else { memcpy(conn->server_certs, ctx->certs, ctx->certslen); conn->server_certs_len = ctx->certslen; } if (ctx->cacertslen > TLS_MAX_CERTIFICATES_SIZE) { error_print(); return -1; } memcpy(conn->ca_certs, ctx->cacerts, ctx->cacertslen); conn->ca_certs_len = ctx->cacertslen; conn->sign_key = ctx->signkey; conn->kenc_key = ctx->kenckey; return 1; } void tls_cleanup(TLS_CONNECT *conn) { gmssl_secure_clear(conn, sizeof(TLS_CONNECT)); } int tls_set_socket(TLS_CONNECT *conn, tls_socket_t sock) { #if 0 int opts; // FIXME: do we still need this? when using select? if ((opts = fcntl(sock, F_GETFL)) < 0) { error_print(); perror("tls_set_socket"); return -1; } opts &= ~O_NONBLOCK; if (fcntl(sock, F_SETFL, opts) < 0) { error_print(); return -1; } #endif conn->sock = sock; return 1; } int tls_do_handshake(TLS_CONNECT *conn) { switch (conn->protocol) { case TLS_protocol_tlcp: if (conn->is_client) return tlcp_do_connect(conn); else return tlcp_do_accept(conn); case TLS_protocol_tls12: if (conn->is_client) return tls12_do_connect(conn); else return tls12_do_accept(conn); case TLS_protocol_tls13: if (conn->is_client) return tls13_do_connect(conn); else return tls13_do_accept(conn); } error_print(); return -1; } int tls_get_verify_result(TLS_CONNECT *conn, int *result) { *result = conn->verify_result; return 1; }