/* * Copyright 2014-2026 The GmSSL Project. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the License); you may * not use this file except in compliance with the License. * * http://www.apache.org/licenses/LICENSE-2.0 */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include void tls_uint8_to_bytes(uint8_t a, uint8_t **out, size_t *outlen) { if (out && *out) { *(*out)++ = a; } (*outlen)++; } void tls_uint16_to_bytes(uint16_t a, uint8_t **out, size_t *outlen) { if (out && *out) { *(*out)++ = (uint8_t)(a >> 8); *(*out)++ = (uint8_t)a; } *outlen += 2; } void tls_uint24_to_bytes(uint24_t a, uint8_t **out, size_t *outlen) { if (out && *out) { *(*out)++ = (uint8_t)(a >> 16); *(*out)++ = (uint8_t)(a >> 8); *(*out)++ = (uint8_t)(a); } (*outlen) += 3; } void tls_uint32_to_bytes(uint32_t a, uint8_t **out, size_t *outlen) { if (out && *out) { *(*out)++ = (uint8_t)(a >> 24); *(*out)++ = (uint8_t)(a >> 16); *(*out)++ = (uint8_t)(a >> 8); *(*out)++ = (uint8_t)(a ); } (*outlen) += 4; } void tls_uint64_to_bytes(uint64_t a, uint8_t **out, size_t *outlen) { if (out && *out) { PUTU64(*out, a); } (*outlen) += 8; } void tls_array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen) { if (out && *out) { if (data) { memcpy(*out, data, datalen); } *out += datalen; } *outlen += datalen; } /* 这几个函数要区分data = NULL, datalen = 0 和 data = NULL, datalen != 0的情况 前者意味着数据为空,因此输出的就是一个长度 后者意味着数据不为空,只是我们不想输出数据,只输出头部的长度,并且更新整个的输出长度。 这种情况应该避免! */ void tls_uint8array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen) { tls_uint8_to_bytes((uint8_t)datalen, out, outlen); tls_array_to_bytes(data, datalen, out, outlen); } void tls_uint16array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen) { tls_uint16_to_bytes((uint16_t)datalen, out, outlen); tls_array_to_bytes(data, datalen, out, outlen); } void tls_uint24array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen) { tls_uint24_to_bytes((uint24_t)datalen, out, outlen); tls_array_to_bytes(data, datalen, out, outlen); } int tls_uint8_from_bytes(uint8_t *a, const uint8_t **in, size_t *inlen) { if (*inlen < 1) { error_print(); return -1; } *a = *(*in)++; (*inlen)--; return 1; } int tls_uint16_from_bytes(uint16_t *a, const uint8_t **in, size_t *inlen) { if (*inlen < 2) { error_print(); return -1; } *a = *(*in)++; *a <<= 8; *a |= *(*in)++; *inlen -= 2; return 1; } int tls_uint24_from_bytes(uint24_t *a, const uint8_t **in, size_t *inlen) { if (*inlen < 3) { error_print(); return -1; } *a = *(*in)++; *a <<= 8; *a |= *(*in)++; *a <<= 8; *a |= *(*in)++; *inlen -= 3; return 1; } int tls_uint32_from_bytes(uint32_t *a, const uint8_t **in, size_t *inlen) { if (*inlen < 4) { error_print(); return -1; } *a = *(*in)++; *a <<= 8; *a |= *(*in)++; *a <<= 8; *a |= *(*in)++; *a <<= 8; *a |= *(*in)++; *inlen -= 4; return 1; } int tls_uint64_from_bytes(uint64_t *a, const uint8_t **in, size_t *inlen) { if (*inlen < 8) { error_print(); return -1; } *a = GETU64(*in); *in += 8; *inlen -= 8; return 1; } int tls_array_from_bytes(const uint8_t **data, size_t datalen, const uint8_t **in, size_t *inlen) { if (*inlen < datalen) { error_print(); return -1; } *data = *in; *in += datalen; *inlen -= datalen; return 1; } int tls_uint8array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen) { uint8_t len; if (tls_uint8_from_bytes(&len, in, inlen) != 1 || tls_array_from_bytes(data, len, in, inlen) != 1) { error_print(); return -1; } if (!len) { *data = NULL; } *datalen = len; return 1; } int tls_uint16array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen) { uint16_t len; if (tls_uint16_from_bytes(&len, in, inlen) != 1 || tls_array_from_bytes(data, len, in, inlen) != 1) { error_print(); return -1; } if (!len) { *data = NULL; } *datalen = len; return 1; } int tls_uint24array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen) { uint24_t len; if (tls_uint24_from_bytes(&len, in, inlen) != 1 || tls_array_from_bytes(data, len, in, inlen) != 1) { error_print(); return -1; } if (!len) { *data = NULL; } *datalen = len; return 1; } int tls_length_is_zero(size_t len) { if (len) { error_print(); return -1; } return 1; } int tls_record_set_type(uint8_t *record, int type) { if (!tls_record_type_name(type)) { error_print(); return -1; } record[0] = (uint8_t)type; return 1; } int tls_record_set_protocol(uint8_t *record, int protocol) { if (!tls_protocol_name(protocol)) { error_print(); return -1; } record[1] = (uint8_t)(protocol >> 8); record[2] = (uint8_t)(protocol); return 1; } int tls_record_set_length(uint8_t *record, size_t length) { uint8_t *p = record + 3; size_t len; if (length > TLS_MAX_CIPHERTEXT_SIZE) { error_print(); return -1; } tls_uint16_to_bytes((uint16_t)length, &p, &len); return 1; } int tls_record_set_data(uint8_t *record, const uint8_t *data, size_t datalen) { if (tls_record_set_length(record, datalen) != 1) { error_print(); return -1; } memcpy(tls_record_data(record), data, datalen); return 1; } int tls_cbc_encrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *enc_key, const uint8_t seq_num[8], const uint8_t header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { SM3_HMAC_CTX hmac_ctx; uint8_t last_blocks[32 + 16] = {0}; uint8_t iv[16]; uint8_t *mac, *padding; int rem, padding_len; int i; if (!inited_hmac_ctx || !enc_key || !seq_num || !header || (!in && inlen) || !out || !outlen) { error_print(); return -1; } if (inlen > (1 << 14)) { error_print_msg("invalid tls record data length %zu\n", inlen); return -1; } if ((((size_t)header[3]) << 8) + header[4] != inlen) { error_print(); return -1; } rem = (inlen + 32) % 16; memcpy(last_blocks, in + inlen - rem, rem); mac = last_blocks + rem; memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); sm3_hmac_update(&hmac_ctx, seq_num, 8); sm3_hmac_update(&hmac_ctx, header, 5); sm3_hmac_update(&hmac_ctx, in, inlen); sm3_hmac_finish(&hmac_ctx, mac); padding = mac + 32; padding_len = 16 - rem - 1; for (i = 0; i <= padding_len; i++) { padding[i] = (uint8_t)padding_len; } if (rand_bytes(iv, 16) != 1) { error_print(); return -1; } memcpy(out, iv, 16); out += 16; if (inlen >= 16) { sm4_cbc_encrypt_blocks(enc_key, iv, in, inlen/16, out); out += inlen - rem; } sm4_cbc_encrypt_blocks(enc_key, iv, last_blocks, sizeof(last_blocks)/16, out); *outlen = 16 + inlen - rem + sizeof(last_blocks); return 1; } // 这个函数应该把所有的输入的dgst都打印出来!这样就可以容易判断出到底是哪个输入错了 int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key, const uint8_t seq_num[8], const uint8_t enced_header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { SM3_HMAC_CTX hmac_ctx; uint8_t iv[16]; const uint8_t *padding; const uint8_t *mac; uint8_t header[5]; int padding_len; uint8_t hmac[32]; int i; if (!inited_hmac_ctx || !dec_key || !seq_num || !enced_header || !in || !inlen || !out || !outlen) { error_print(); return -1; } if (inlen % 16 || inlen < (16 + 0 + 32 + 16) // iv + data + mac + padding || inlen > (16 + (1<<14) + 32 + 256)) { error_print_msg("invalid tls cbc ciphertext length %zu\n", inlen); return -1; } memcpy(iv, in, 16); in += 16; inlen -= 16; sm4_cbc_decrypt_blocks(dec_key, iv, in, inlen/16, out); padding_len = out[inlen - 1]; padding = out + inlen - padding_len - 1; if (padding < out + 32) { error_print(); return -1; } for (i = 0; i < padding_len; i++) { if (padding[i] != padding_len) { error_puts("tls ciphertext cbc-padding check failure"); return -1; } } *outlen = inlen - 32 - padding_len - 1; header[0] = enced_header[0]; header[1] = enced_header[1]; header[2] = enced_header[2]; header[3] = (uint8_t)((*outlen) >> 8); header[4] = (uint8_t)(*outlen); mac = padding - 32; memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); sm3_hmac_update(&hmac_ctx, seq_num, 8); sm3_hmac_update(&hmac_ctx, header, 5); sm3_hmac_update(&hmac_ctx, out, *outlen); sm3_hmac_finish(&hmac_ctx, hmac); if (gmssl_secure_memcmp(mac, hmac, sizeof(hmac)) != 0) { error_puts("tls ciphertext mac check failure\n"); return -1; } return 1; } int tls_record_encrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key, const uint8_t seq_num[8], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { if (tls_cbc_encrypt(hmac_ctx, cbc_key, seq_num, in, in + 5, inlen - 5, out + 5, outlen) != 1) { error_print(); return -1; } out[0] = in[0]; out[1] = in[1]; out[2] = in[2]; out[3] = (uint8_t)((*outlen) >> 8); out[4] = (uint8_t)(*outlen); (*outlen) += 5; return 1; } int tls_record_decrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key, const uint8_t seq_num[8], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { if (tls_cbc_decrypt(hmac_ctx, cbc_key, seq_num, in, in + 5, inlen - 5, out + 5, outlen) != 1) { error_print(); return -1; } out[0] = in[0]; out[1] = in[1]; out[2] = in[2]; out[3] = (uint8_t)((*outlen) >> 8); out[4] = (uint8_t)(*outlen); (*outlen) += 5; return 1; } int tls_random_generate(uint8_t random[32]) { uint32_t gmt_unix_time = (uint32_t)time(NULL); uint8_t *p = random; size_t len = 0; tls_uint32_to_bytes(gmt_unix_time, &p, &len); if (rand_bytes(random + 4, 28) != 1) { error_print(); return -1; } return 1; } int tls_prf(const uint8_t *secret, size_t secretlen, const char *label, const uint8_t *seed, size_t seedlen, const uint8_t *more, size_t morelen, size_t outlen, uint8_t *out) { SM3_HMAC_CTX inited_hmac_ctx; SM3_HMAC_CTX hmac_ctx; uint8_t A[32]; uint8_t hmac[32]; size_t len; if (!secret || !secretlen || !label || !seed || !seedlen || (!more && morelen) || !outlen || !out) { error_print(); return -1; } sm3_hmac_init(&inited_hmac_ctx, secret, secretlen); memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); sm3_hmac_update(&hmac_ctx, seed, seedlen); sm3_hmac_update(&hmac_ctx, more, morelen); sm3_hmac_finish(&hmac_ctx, A); memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); sm3_hmac_update(&hmac_ctx, A, sizeof(A)); sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); sm3_hmac_update(&hmac_ctx, seed, seedlen); sm3_hmac_update(&hmac_ctx, more, morelen); sm3_hmac_finish(&hmac_ctx, hmac); len = outlen < sizeof(hmac) ? outlen : sizeof(hmac); memcpy(out, hmac, len); out += len; outlen -= len; while (outlen) { memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); sm3_hmac_update(&hmac_ctx, A, sizeof(A)); sm3_hmac_finish(&hmac_ctx, A); memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX)); sm3_hmac_update(&hmac_ctx, A, sizeof(A)); sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); sm3_hmac_update(&hmac_ctx, seed, seedlen); sm3_hmac_update(&hmac_ctx, more, morelen); sm3_hmac_finish(&hmac_ctx, hmac); len = outlen < sizeof(hmac) ? outlen : sizeof(hmac); memcpy(out, hmac, len); out += len; outlen -= len; } return 1; } int tls_pre_master_secret_generate(uint8_t pre_master_secret[48], int protocol) { if (!tls_protocol_name(protocol)) { error_print(); return -1; } pre_master_secret[0] = (uint8_t)(protocol >> 8); pre_master_secret[1] = (uint8_t)(protocol); if (rand_bytes(pre_master_secret + 2, 46) != 1) { error_print(); return -1; } return 1; } // 用于设置CertificateRequest int tls_cert_type_from_oid(int oid) { switch (oid) { case OID_sm2sign_with_sm3: case OID_ecdsa_with_sha1: case OID_ecdsa_with_sha224: case OID_ecdsa_with_sha256: case OID_ecdsa_with_sha512: return TLS_cert_type_ecdsa_sign; case OID_rsasign_with_sm3: case OID_rsasign_with_md5: case OID_rsasign_with_sha1: case OID_rsasign_with_sha224: case OID_rsasign_with_sha256: case OID_rsasign_with_sha384: case OID_rsasign_with_sha512: return TLS_cert_type_rsa_sign; } // TLS_cert_type_xxx 中没有为0的值 return 0; } // 这两个函数没有对应的TLCP版本, 这个现在已经有了ex版本了 int tls_sign_server_ecdh_params(const SM2_KEY *server_sign_key, const uint8_t client_random[32], const uint8_t server_random[32], int curve, const SM2_Z256_POINT *point, uint8_t *sig, size_t *siglen) { uint8_t server_ecdh_params[69]; SM2_SIGN_CTX sign_ctx; if (!server_sign_key || !client_random || !server_random || curve != TLS_curve_sm2p256v1 || !point || !sig || !siglen) { error_print(); return -1; } server_ecdh_params[0] = TLS_curve_type_named_curve; server_ecdh_params[1] = (uint8_t)(curve >> 8); server_ecdh_params[2] = (uint8_t)curve; server_ecdh_params[3] = 65; sm2_z256_point_to_uncompressed_octets(point, server_ecdh_params + 4); sm2_sign_init(&sign_ctx, server_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH); sm2_sign_update(&sign_ctx, client_random, 32); sm2_sign_update(&sign_ctx, server_random, 32); sm2_sign_update(&sign_ctx, server_ecdh_params, 69); sm2_sign_finish(&sign_ctx, sig, siglen); return 1; } int tls_verify_server_ecdh_params(const SM2_KEY *server_sign_key, const uint8_t client_random[32], const uint8_t server_random[32], int curve, const SM2_Z256_POINT *point, const uint8_t *sig, size_t siglen) { int ret; uint8_t server_ecdh_params[69]; SM2_VERIFY_CTX verify_ctx; if (!server_sign_key || !client_random || !server_random || curve != TLS_curve_sm2p256v1 || !point || !sig || !siglen || siglen > SM2_MAX_SIGNATURE_SIZE) { error_print(); return -1; } server_ecdh_params[0] = TLS_curve_type_named_curve; server_ecdh_params[1] = (uint8_t)(curve >> 8); server_ecdh_params[2] = (uint8_t)(curve); server_ecdh_params[3] = 65; sm2_z256_point_to_uncompressed_octets(point, server_ecdh_params + 4); sm2_verify_init(&verify_ctx, server_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH); sm2_verify_update(&verify_ctx, client_random, 32); sm2_verify_update(&verify_ctx, server_random, 32); sm2_verify_update(&verify_ctx, server_ecdh_params, 69); ret = sm2_verify_finish(&verify_ctx, sig, siglen); if (ret != 1) error_print(); return ret; } int tls_record_set_handshake(uint8_t *record, size_t *recordlen, int type, const uint8_t *data, size_t datalen) { size_t handshakelen; if (!record || !recordlen) { error_print(); return -1; } // 由于ServerHelloDone没有负载数据,因此允许 data,datalen = NULL,0 if (datalen > TLS_MAX_PLAINTEXT_SIZE - TLS_HANDSHAKE_HEADER_SIZE) { error_print(); return -1; } if (!tls_protocol_name(tls_record_protocol(record))) { error_print(); return -1; } if (!tls_handshake_type_name(type)) { error_print(); return -1; } handshakelen = TLS_HANDSHAKE_HEADER_SIZE + datalen; record[0] = TLS_record_handshake; record[3] = (uint8_t)(handshakelen >> 8); record[4] = (uint8_t)(handshakelen); record[5] = (uint8_t)(type); record[6] = (uint8_t)(datalen >> 16); record[7] = (uint8_t)(datalen >> 8); record[8] = (uint8_t)(datalen); if (data) { memcpy(tls_handshake_data(tls_record_data(record)), data, datalen); } *recordlen = TLS_RECORD_HEADER_SIZE + handshakelen; return 1; } int tls_record_set_handshake_header(uint8_t *record, size_t *recordlen, int type, int length) { if (tls_record_set_handshake(record, recordlen, type, NULL, length) != 1) { error_print(); return -1; } return 1; } int tls_record_get_handshake(const uint8_t *record, int *type, const uint8_t **data, size_t *datalen) { const uint8_t *handshake; size_t handshake_len; uint24_t handshake_datalen; if (!record || !type || !data || !datalen) { error_print(); return -1; } if (!tls_protocol_name(tls_record_protocol(record))) { error_print(); return -1; } if (tls_record_type(record) != TLS_record_handshake) { error_print(); return -1; } handshake = tls_record_data(record); handshake_len = tls_record_data_length(record); if (handshake_len < TLS_HANDSHAKE_HEADER_SIZE) { error_print(); return -1; } if (handshake_len > TLS_MAX_PLAINTEXT_SIZE) { // 不支持证书长度超过记录长度的特殊情况 error_print(); return -1; } if (!tls_handshake_type_name(handshake[0])) { error_print(); return -1; } *type = handshake[0]; handshake++; handshake_len--; if (tls_uint24_from_bytes(&handshake_datalen, &handshake, &handshake_len) != 1) { error_print(); return -1; } if (handshake_len != handshake_datalen) { error_print(); return -1; } *data = handshake; *datalen = handshake_datalen; if (*datalen == 0) { *data = NULL; } return 1; } int tls_record_set_handshake_client_hello(uint8_t *record, size_t *recordlen, int protocol, const uint8_t random[32], const uint8_t *session_id, size_t session_id_len, const int *cipher_suites, size_t cipher_suites_count, const uint8_t *exts, size_t exts_len) { uint8_t type = TLS_handshake_client_hello; uint8_t *p; size_t len; if (!record || !recordlen || !random || !cipher_suites || !cipher_suites_count) { error_print(); return -1; } if (session_id) { if (!session_id_len || session_id_len < TLS_MAX_SESSION_ID_SIZE || session_id_len > TLS_MAX_SESSION_ID_SIZE) { error_print(); return -1; } } if (cipher_suites_count > TLS_MAX_CIPHER_SUITES_COUNT) { error_print(); return -1; } if (exts && !exts_len) { error_print(); return -1; } p = tls_handshake_data(tls_record_data(record)); len = 0; if (!tls_protocol_name(protocol)) { error_print(); return -1; } tls_uint16_to_bytes((uint16_t)protocol, &p, &len); tls_array_to_bytes(random, 32, &p, &len); tls_uint8array_to_bytes(session_id, session_id_len, &p, &len); tls_uint16_to_bytes((uint16_t)(cipher_suites_count * 2), &p, &len); while (cipher_suites_count--) { if (!tls_cipher_suite_name(*cipher_suites)) { error_print(); return -1; } tls_uint16_to_bytes((uint16_t)*cipher_suites, &p, &len); cipher_suites++; } tls_uint8_to_bytes(1, &p, &len); tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len); if (exts) { size_t tmp_len = len; if (protocol < TLS_protocol_tls12) { error_print(); return -1; } tls_uint16array_to_bytes(exts, exts_len, NULL, &tmp_len); if (tmp_len > TLS_MAX_HANDSHAKE_DATA_SIZE) { error_print(); return -1; } tls_uint16array_to_bytes(exts, exts_len, &p, &len); } if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) { error_print(); return -1; } return 1; } int tls_record_get_handshake_client_hello(const uint8_t *record, int *protocol, const uint8_t **random, const uint8_t **session_id, size_t *session_id_len, const uint8_t **cipher_suites, size_t *cipher_suites_len, const uint8_t **exts, size_t *exts_len) { int type; const uint8_t *p; size_t len; uint16_t ver; const uint8_t *comp_meths; size_t comp_meths_len; if (!record || !protocol || !random || !session_id || !session_id_len || !cipher_suites || !cipher_suites_len || !exts || !exts_len) { error_print(); return -1; } if (tls_record_get_handshake(record, &type, &p, &len) != 1) { error_print(); return -1; } if (type != TLS_handshake_client_hello) { error_print(); return -1; } if (tls_uint16_from_bytes(&ver, &p, &len) != 1 || tls_array_from_bytes(random, 32, &p, &len) != 1 || tls_uint8array_from_bytes(session_id, session_id_len, &p, &len) != 1 || tls_uint16array_from_bytes(cipher_suites, cipher_suites_len, &p, &len) != 1 || tls_uint8array_from_bytes(&comp_meths, &comp_meths_len, &p, &len) != 1) { error_print(); return -1; } if (!tls_protocol_name(ver)) { error_print(); return -1; } *protocol = ver; if (*session_id) { if (*session_id_len == 0 || *session_id_len < TLS_MIN_SESSION_ID_SIZE || *session_id_len > TLS_MAX_SESSION_ID_SIZE) { error_print(); return -1; } } if (!cipher_suites) { error_print(); return -1; } if (*cipher_suites_len % 2) { error_print(); return -1; } if (len) { if (tls_uint16array_from_bytes(exts, exts_len, &p, &len) != 1) { error_print(); return -1; } if (*exts == NULL) { error_print(); return -1; } } else { *exts = NULL; *exts_len = 0; } if (len) { error_print(); return -1; } return 1; } int tls_record_set_handshake_server_hello(uint8_t *record, size_t *recordlen, int protocol, const uint8_t random[32], const uint8_t *session_id, size_t session_id_len, int cipher_suite, const uint8_t *exts, size_t exts_len) { uint8_t type = TLS_handshake_server_hello; uint8_t *p; size_t len; if (!record || !recordlen || !random) { error_print(); return -1; } if (session_id) { if (session_id_len == 0 || session_id_len < TLS_MIN_SESSION_ID_SIZE || session_id_len > TLS_MAX_SESSION_ID_SIZE) { error_print(); return -1; } } if (!tls_protocol_name(protocol)) { error_print(); return -1; } if (!tls_cipher_suite_name(cipher_suite)) { error_print(); return -1; } p = tls_handshake_data(tls_record_data(record)); len = 0; tls_uint16_to_bytes((uint16_t)protocol, &p, &len); tls_array_to_bytes(random, 32, &p, &len); tls_uint8array_to_bytes(session_id, session_id_len, &p, &len); tls_uint16_to_bytes((uint16_t)cipher_suite, &p, &len); tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len); if (exts) { if (protocol < TLS_protocol_tls12) { error_print(); return -1; } tls_uint16array_to_bytes(exts, exts_len, &p, &len); } if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) { error_print(); return -1; } return 1; } /* 如果报文的结构正确,但是数据不合法的时候,应该返回TLS_alert_illegal_parameter 例如服务器的选择不在ClientHello提供的列表中 因此涉及到语义错误的,应该返回这个错误。 如果语义我们不能理解,但是格式正确,那么应该忽略 */ int tls_record_get_handshake_server_hello(const uint8_t *record, int *protocol, const uint8_t **random, const uint8_t **session_id, size_t *session_id_len, int *cipher_suite, const uint8_t **exts, size_t *exts_len) { int type; const uint8_t *p; size_t len; uint16_t ver; uint16_t cipher; uint8_t comp_meth; if (!record || !protocol || !random || !session_id || !session_id_len || !cipher_suite || !exts || !exts_len) { error_print(); return -1; } if (tls_record_get_handshake(record, &type, &p, &len) != 1) { error_print(); return -1; } if (type != TLS_handshake_server_hello) { error_print(); return 0; } if (tls_uint16_from_bytes(&ver, &p, &len) != 1 || tls_array_from_bytes(random, 32, &p, &len) != 1 || tls_uint8array_from_bytes(session_id, session_id_len, &p, &len) != 1 || tls_uint16_from_bytes(&cipher, &p, &len) != 1 || tls_uint8_from_bytes(&comp_meth, &p, &len) != 1) { error_print(); return -1; } if (!tls_protocol_name(ver)) { error_print(); return -1; } if (ver < tls_record_protocol(record)) { error_print(); return -1; } *protocol = ver; if (*session_id) { if (*session_id == 0 || *session_id_len < TLS_MIN_SESSION_ID_SIZE || *session_id_len > TLS_MAX_SESSION_ID_SIZE) { error_print(); return -1; } } if (!tls_cipher_suite_name(cipher)) { error_print(); return -1; } *cipher_suite = cipher; // 这个值应该返回给上一层,这样上面才能够返回正确的错误值 if (comp_meth != TLS_compression_null) { error_print(); return -1; } if (len) { if (tls_uint16array_from_bytes(exts, exts_len, &p, &len) != 1) { error_print(); return -1; } if (*exts == NULL) { error_print(); return -1; } } else { *exts = NULL; *exts_len = 0; } if (len) { error_print(); return -1; } return 1; } int tls_record_set_handshake_certificate(uint8_t *record, size_t *recordlen, const uint8_t *certs, size_t certslen) { int type = TLS_handshake_certificate; uint8_t *data; size_t datalen; uint8_t *p; size_t len; if (!record || !recordlen || !certs || !certslen) { error_print(); return -1; } data = tls_handshake_data(tls_record_data(record)); p = data + tls_uint24_size(); datalen = tls_uint24_size(); len = 0; while (certslen) { const uint8_t *cert; size_t certlen; if (x509_cert_from_der(&cert, &certlen, &certs, &certslen) != 1) { error_print(); return -1; } tls_uint24array_to_bytes(cert, certlen, NULL, &datalen); if (datalen > TLS_MAX_HANDSHAKE_DATA_SIZE) { error_print(); return -1; } tls_uint24array_to_bytes(cert, certlen, &p, &len); } tls_uint24_to_bytes((uint24_t)len, &data, &len); tls_record_set_handshake(record, recordlen, type, NULL, datalen); return 1; } int tls_record_get_handshake_certificate(const uint8_t *record, uint8_t *certs, size_t *certslen) { int type; const uint8_t *data; size_t datalen; const uint8_t *cp; size_t len; if (tls_record_get_handshake(record, &type, &data, &datalen) != 1) { error_print(); return -1; } if (type != TLS_handshake_certificate) { error_print(); return -1; } if (tls_uint24array_from_bytes(&cp, &len, &data, &datalen) != 1) { error_print(); return -1; } *certslen = 0; while (len) { const uint8_t *a; size_t alen; const uint8_t *cert; size_t certlen; if (tls_uint24array_from_bytes(&a, &alen, &cp, &len) != 1) { error_print(); return -1; } if (x509_cert_from_der(&cert, &certlen, &a, &alen) != 1 || asn1_length_is_zero(alen) != 1 || x509_cert_to_der(cert, certlen, &certs, certslen) != 1) { error_print(); return -1; } } return 1; } /* struct { ClientCertificateType certificate_types<1..2^8-1>; SignatureAndHashAlgorithm supported_signature_algorithms<2^16-1>; // 可能缺少这个参数 DistinguishedName certificate_authorities<0..2^16-1>; } CertificateRequest; */ int tls_record_set_handshake_certificate_request(uint8_t *record, size_t *recordlen, const uint8_t *cert_types, size_t cert_types_len, const uint8_t *ca_names, size_t ca_names_len) { int type = TLS_handshake_certificate_request; uint8_t *p; size_t len =0; size_t datalen = 0; if (!record || !recordlen) { error_print(); return -1; } if (cert_types) { if (cert_types_len == 0 || cert_types_len > TLS_MAX_CERTIFICATE_TYPES) { error_print(); return -1; } } if (ca_names) { if (ca_names_len == 0 || ca_names_len > TLS_MAX_CA_NAMES_SIZE) { error_print(); return -1; } } tls_uint8array_to_bytes(cert_types, cert_types_len, NULL, &datalen); tls_uint16array_to_bytes(ca_names, ca_names_len, NULL, &datalen); if (datalen > TLS_MAX_HANDSHAKE_DATA_SIZE) { error_print(); return -1; } p = tls_handshake_data(tls_record_data(record)); tls_uint8array_to_bytes(cert_types, cert_types_len, &p, &len); tls_uint16array_to_bytes(ca_names, ca_names_len, &p, &len); tls_record_set_handshake(record, recordlen, type, NULL, datalen); return 1; } int tls_record_get_handshake_certificate_request(const uint8_t *record, const uint8_t **cert_types, size_t *cert_types_len, const uint8_t **ca_names, size_t *ca_names_len) { int type; const uint8_t *cp; size_t len; size_t i; if (!record || !cert_types || !cert_types_len || !ca_names || !ca_names_len) { error_print(); return -1; } if (tls_record_get_handshake(record, &type, &cp, &len) != 1) { error_print(); return -1; } if (type != TLS_handshake_certificate_request) { error_print(); return -1; } if (tls_uint8array_from_bytes(cert_types, cert_types_len, &cp, &len) != 1 || tls_uint16array_from_bytes(ca_names, ca_names_len, &cp, &len) != 1 || tls_length_is_zero(len) != 1) { error_print(); return -1; } if (*cert_types == NULL) { error_print(); return -1; } for (i = 0; i < *cert_types_len; i++) { if (!tls_cert_type_name((*cert_types)[i])) { error_print(); return -1; } } if (*ca_names) { const uint8_t *names = *ca_names; size_t nameslen = *ca_names_len; while (nameslen) { if (tls_uint16array_from_bytes(&cp, &len, &names, &nameslen) != 1) { error_print(); return -1; } } } return 1; } int tls_record_set_handshake_server_hello_done(uint8_t *record, size_t *recordlen) { int type = TLS_handshake_server_hello_done; if (!record || !recordlen) { error_print(); return -1; } tls_record_set_handshake(record, recordlen, type, NULL, 0); return 1; } int tls_record_get_handshake_server_hello_done(const uint8_t *record) { int type; const uint8_t *p; size_t len; if (!record) { error_print(); return -1; } if (tls_record_get_handshake(record, &type, &p, &len) != 1 || type != TLS_handshake_server_hello_done) { error_print(); return -1; } if (p != NULL || len != 0) { error_print(); return -1; } return 1; } int tls_record_set_handshake_client_key_exchange_pke(uint8_t *record, size_t *recordlen, const uint8_t *enced_pms, size_t enced_pms_len) { int type = TLS_handshake_client_key_exchange; uint8_t *p; size_t len = 0; if (!record || !recordlen || !enced_pms || !enced_pms_len) { error_print(); return -1; } if (enced_pms_len > TLS_MAX_HANDSHAKE_DATA_SIZE - tls_uint16_size()) { error_print(); return -1; } p = tls_handshake_data(tls_record_data(record)); tls_uint16array_to_bytes(enced_pms, enced_pms_len, &p, &len); tls_record_set_handshake(record, recordlen, type, NULL, len); return 1; } int tls_record_get_handshake_client_key_exchange_pke(const uint8_t *record, const uint8_t **enced_pms, size_t *enced_pms_len) { int type; const uint8_t *cp; size_t len; if (!record || !enced_pms || !enced_pms_len) { error_print(); return -1; } if (tls_record_get_handshake(record, &type, &cp, &len) != 1) { error_print(); return -1; } if (type != TLS_handshake_client_key_exchange) { error_print(); return -1; } if (tls_uint16array_from_bytes(enced_pms, enced_pms_len, &cp, &len) != 1 || tls_length_is_zero(len) != 1) { error_print(); return -1; } return 1; } int tls_record_set_handshake_certificate_verify(uint8_t *record, size_t *recordlen, const uint8_t *sig, size_t siglen) { int type = TLS_handshake_certificate_verify; uint8_t *p; size_t len = 0; if (!record || !recordlen || !sig || !siglen) { error_print(); return -1; } if (siglen > TLS_MAX_SIGNATURE_SIZE) { error_print(); return -1; } p = tls_handshake_data(tls_record_data(record)); tls_uint16array_to_bytes(sig, siglen, &p, &len); tls_record_set_handshake(record, recordlen, type, NULL, len); return 1; } int tls_record_get_handshake_certificate_verify(const uint8_t *record, const uint8_t **sig, size_t *siglen) { int type; const uint8_t *cp; size_t len; if (!record || !sig || !siglen) { error_print(); return -1; } if (tls_record_get_handshake(record, &type, &cp, &len) != 1) { error_print(); return -1; } if (type != TLS_handshake_certificate_verify) { error_print(); return -1; } if (tls_uint16array_from_bytes(sig, siglen, &cp, &len) != 1 || tls_length_is_zero(len) != 1) { error_print(); return -1; } return 1; } int tls_record_set_handshake_finished(uint8_t *record, size_t *recordlen, const uint8_t *verify_data, size_t verify_data_len) { int type = TLS_handshake_finished; if (!record || !recordlen || !verify_data || !verify_data_len) { error_print(); return -1; } if (verify_data_len != 12 && verify_data_len != 32) { error_print(); return -1; } tls_record_set_handshake(record, recordlen, type, verify_data, verify_data_len); return 1; } // 这个应该改为只支持TLS 1.2的12字节长度判断 int tls_record_get_handshake_finished(const uint8_t *record, const uint8_t **verify_data, size_t *verify_data_len) { int type; if (!record || !verify_data || !verify_data_len) { error_print(); return -1; } if (tls_record_get_handshake(record, &type, verify_data, verify_data_len) != 1) { error_print(); return -1; } if (type != TLS_handshake_finished) { error_print(); return -1; } if (*verify_data == NULL || *verify_data_len == 0) { error_print(); return -1; } if (*verify_data_len != 12 && *verify_data_len != 32) { error_print(); return -1; } return 1; } int tls_record_set_alert(uint8_t *record, size_t *recordlen, int alert_level, int alert_description) { if (!record || !recordlen) { error_print(); return -1; } if (!tls_alert_level_name(alert_level)) { error_print(); return -1; } if (!tls_alert_description_text(alert_description)) { error_print(); return -1; } record[0] = TLS_record_alert; //record[1] = protocol.major should be set by others //record[2] = protocol.minor should be set by others record[3] = 0; // length record[4] = 2; // length record[5] = (uint8_t)alert_level; record[6] = (uint8_t)alert_description; *recordlen = TLS_RECORD_HEADER_SIZE + 2; return 1; } int tls_record_get_alert(const uint8_t *record, int *alert_level, int *alert_description) { if (!record || !alert_level || !alert_description) { error_print(); return -1; } if (tls_record_type(record) != TLS_record_alert) { error_print(); return -1; } if (record[3] != 0 || record[4] != 2) { error_print(); return -1; } *alert_level = record[5]; *alert_description = record[6]; if (!tls_alert_level_name(*alert_level)) { error_print(); return -1; } if (!tls_alert_description_text(*alert_description)) { error_puts("warning"); return -1; } return 1; } int tls_record_set_change_cipher_spec(uint8_t *record, size_t *recordlen) { if (!record || !recordlen) { error_print(); return -1; } record[0] = TLS_record_change_cipher_spec; record[3] = 0; record[4] = 1; record[5] = TLS_change_cipher_spec; *recordlen = TLS_RECORD_HEADER_SIZE + 1; return 1; } int tls_record_get_change_cipher_spec(const uint8_t *record) { if (!record) { error_print(); return -1; } if (tls_record_type(record) != TLS_record_change_cipher_spec) { error_print(); return -1; } if (record[3] != 0 || record[4] != 1) { error_print(); return -1; } if (record[5] != TLS_change_cipher_spec) { error_print(); return -1; } return 1; } int tls_record_set_application_data(uint8_t *record, size_t *recordlen, const uint8_t *data, size_t datalen) { if (!record || !recordlen || !data || !datalen) { error_print(); return -1; } record[0] = TLS_record_application_data; record[3] = (datalen >> 8) & 0xff; record[4] = datalen & 0xff; memcpy(tls_record_data(record), data, datalen); *recordlen = TLS_RECORD_HEADER_SIZE + datalen; return 1; } int tls_record_get_application_data(uint8_t *record, const uint8_t **data, size_t *datalen) { if (!record || !data || !datalen) { error_print(); return -1; } if (tls_record_type(record) != TLS_record_application_data) { error_print(); return -1; } *datalen = ((size_t)record[3] << 8) | record[4]; *data = *datalen ? record + TLS_RECORD_HEADER_SIZE : 0; return 1; } int tls_type_is_in_list(int type, const int *list, size_t list_count) { size_t i; for (i = 0; i < list_count; i++) { if (type == list[i]) { return 1; } } return 0; } static const int tlcp_ciphers[] = { TLS_cipher_ecc_sm4_cbc_sm3, TLS_cipher_ecc_sm4_gcm_sm3, TLS_cipher_ibc_sm4_cbc_sm3, TLS_cipher_ibc_sm4_gcm_sm3, }; static const int tls12_ciphers[] = { TLS_cipher_ecdhe_sm4_cbc_sm3, TLS_cipher_ecdhe_sm4_gcm_sm3, TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256, }; static const int tls13_ciphers[] = { TLS_cipher_sm4_gcm_sm3, }; int tls_cipher_suite_support_protocol(int cipher, int protocol) { const int *ciphers; size_t ciphers_cnt; switch (protocol) { case TLS_protocol_tlcp: ciphers = tlcp_ciphers; ciphers_cnt = sizeof(tlcp_ciphers)/sizeof(tlcp_ciphers[0]); break; case TLS_protocol_tls12: ciphers = tls12_ciphers; ciphers_cnt = sizeof(tls12_ciphers)/sizeof(tls12_ciphers[0]); break; case TLS_protocol_tls13: ciphers = tls13_ciphers; ciphers_cnt = sizeof(tls13_ciphers)/sizeof(tls13_ciphers[0]); break; default: error_print(); return -1; } /* if (!tls_cipher_suite_in_list(cipher, ciphers, ciphers_cnt)) { error_print(); return 0; } */ return 1; } /* 尽可能的发送数据,直到发送完整的报文,或者send 返回错误 如果send 返回EAGAIN,那么向上层返回WANT_WRITE 正常情况下,一方总是可以发送任意数量的数据,当发送方缓冲区已经满了的时候 send会返回EAGIN,那么如果底层没处理完,那就没有任何办法 如果这个函数在获得EAGAIN之后就返回给上层了,那么还需要标明到底发送出去了多少数据 */ int tls_record_send(const uint8_t *record, size_t recordlen, tls_socket_t sock) { tls_ret_t n; if (!record) { error_print(); return -1; } if (recordlen < TLS_RECORD_HEADER_SIZE) { error_print(); return -1; } if (tls_record_length(record) != recordlen) { error_print(); return -1; } while (recordlen) { if ((n = tls_socket_send(sock, record, recordlen, 0)) > 0) { record += n; recordlen -= n; } else if (n == 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) { tls_socket_wait(); } else { error_puts("TCP connection closed"); return 0; } } else { if (errno == EAGAIN || errno == EWOULDBLOCK) { tls_socket_wait(); } else { error_print(); return -1; } } } return 1; } int tls_record_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock) { uint8_t *p = record; size_t len; tls_ret_t n; len = 5; while (len) { if ((n = tls_socket_recv(sock, p, len, 0)) > 0) { p += n; len -= n; } else if (n == 0) { tls_trace("TCP connection closed"); *recordlen = 0; return 0; } else { if (errno == EAGAIN || errno == EWOULDBLOCK) { if (len == 5) { return -EAGAIN; } tls_socket_wait(); } else { perror("recv"); error_print(); return -1; } } } if (!tls_record_type_name(tls_record_type(record))) { error_print(); return -1; } if (!tls_protocol_name(tls_record_protocol(record))) { error_print(); return -1; } len = (size_t)record[3] << 8 | record[4]; *recordlen = 5 + len; if (*recordlen > TLS_MAX_RECORD_SIZE) { error_print(); return -1; } while (len) { if ((n = tls_socket_recv(sock, p, len, 0)) > 0) { p += n; len -= n; } else if (n == 0) { tls_trace("connection closed"); *recordlen = 0; return 0; } else { if (errno == EAGAIN || errno == EWOULDBLOCK) { tls_socket_wait(); } else { perror("recv"); error_print(); return -1; } } } return 1; } int tls_seq_num_incr(uint8_t seq_num[8]) { int i; for (i = 7; i > 0; i--) { seq_num[i]++; if (seq_num[i]) break; } // FIXME: check overflow return 1; } void tls_seq_num_reset(uint8_t seq_num[8]) { memset(seq_num, 0, 8); } int tls_compression_methods_has_null_compression(const uint8_t *meths, size_t methslen) { if (!meths || !methslen) { error_print(); return -1; } while (methslen--) { if (*meths++ == TLS_compression_null) { return 1; } } error_print(); return -1; } int tls_send_alert(TLS_CONNECT *conn, int alert) { uint8_t record[5 + 2]; size_t recordlen; if (!conn) { error_print(); return -1; } tls_record_set_protocol(record, conn->protocol == TLS_protocol_tls13 ? TLS_protocol_tls12 : conn->protocol); tls_record_set_alert(record, &recordlen, TLS_alert_level_fatal, alert); if (tls_record_send(record, sizeof(record), conn->sock) != 1) { error_print(); return -1; } tls_record_trace(stderr, record, sizeof(record), 0, 0); return 1; } int tls_alert_level(int alert) { switch (alert) { case TLS_alert_unexpected_message: case TLS_alert_bad_record_mac: case TLS_alert_record_overflow: case TLS_alert_decompression_failure: case TLS_alert_handshake_failure: case TLS_alert_illegal_parameter: case TLS_alert_unknown_ca: case TLS_alert_access_denied: case TLS_alert_decode_error: case TLS_alert_decrypt_error: case TLS_alert_protocol_version: case TLS_alert_insufficient_security: case TLS_alert_internal_error: case TLS_alert_unsupported_extension: return TLS_alert_level_fatal; case TLS_alert_user_canceled: case TLS_alert_no_renegotiation: return TLS_alert_level_warning; } return TLS_alert_level_undefined; } int tls_send_warning(TLS_CONNECT *conn, int alert) { uint8_t record[5 + 2]; size_t recordlen; if (!conn) { error_print(); return -1; } if (tls_alert_level(alert) == TLS_alert_level_fatal) { error_print(); return -1; } tls_record_set_protocol(record, conn->protocol == TLS_protocol_tls13 ? TLS_protocol_tls12 : conn->protocol); tls_record_set_alert(record, &recordlen, TLS_alert_level_warning, alert); if (tls_record_send(record, sizeof(record), conn->sock) != 1) { error_print(); return -1; } tls_record_trace(stderr, record, sizeof(record), 0, 0); return 1; } static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *in, size_t inlen, size_t *sentlen) { const SM3_HMAC_CTX *hmac_ctx; const SM4_KEY *enc_key; uint8_t *seq_num; size_t recordlen; if (!conn) { error_print(); return -1; } if (!in || !inlen || !sentlen) { error_print(); return -1; } if (inlen > TLS_MAX_PLAINTEXT_SIZE) { inlen = TLS_MAX_PLAINTEXT_SIZE; } if (conn->datalen) { error_puts("recv all buffered data before send"); return -1; } if (conn->is_client) { hmac_ctx = &conn->client_write_mac_ctx; enc_key = &conn->client_write_enc_key; seq_num = conn->client_seq_num; } else { hmac_ctx = &conn->server_write_mac_ctx; enc_key = &conn->server_write_enc_key; seq_num = conn->server_seq_num; } if (tls_record_set_type(conn->databuf, record_type) != 1 || tls_record_set_protocol(conn->databuf, conn->protocol) != 1 || tls_record_set_data(conn->databuf, in, inlen) != 1) { error_print(); return -1; } tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0); if (tls_record_encrypt(hmac_ctx, enc_key, seq_num, conn->databuf, tls_record_length(conn->databuf), conn->record, &recordlen) != 1) { error_print(); return -1; } tls_seq_num_incr(seq_num); if (tls_record_send(conn->record, recordlen, conn->sock) != 1) { error_print(); return -1; } tls_encrypted_record_trace(stderr, conn->record, recordlen, 0, 0); *sentlen = inlen; return 1; } int tls_decrypt_recv(TLS_CONNECT *conn) { int ret; const SM3_HMAC_CTX *hmac_ctx; const SM4_KEY *dec_key; uint8_t *seq_num; uint8_t *record = conn->record; size_t recordlen; if (conn->is_client) { hmac_ctx = &conn->server_write_mac_ctx; dec_key = &conn->server_write_enc_key; seq_num = conn->server_seq_num; } else { hmac_ctx = &conn->client_write_mac_ctx; dec_key = &conn->client_write_enc_key; seq_num = conn->client_seq_num; } tls_trace("recv Encrypted Record\n"); if ((ret = tls_record_recv(record, &recordlen, conn->sock)) != 1) { if (ret < 0 && ret != -EAGAIN) error_print(); return ret; } tls_encrypted_record_trace(stderr, record, recordlen, 0, 0); if (tls_record_decrypt(hmac_ctx, dec_key, seq_num, record, recordlen, conn->databuf, &conn->datalen) != 1) { error_print(); return -1; } tls_seq_num_incr(seq_num); conn->data = tls_record_data(conn->databuf); conn->datalen = tls_record_data_length(conn->databuf); tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0); return 1; } int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen) { tls_trace("send ApplicationData\n"); return tls_encrypt_send(conn, TLS_record_application_data, in, inlen, sentlen); } int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen) { if (!conn || !out || !outlen || !recvlen) { error_print(); return -1; } if (conn->datalen == 0) { int ret; if ((ret = tls_decrypt_recv(conn)) != 1) { if (ret < 0 && ret != -EAGAIN) error_print(); return ret; } switch (tls_record_type(conn->record)) { case TLS_record_application_data: break; case TLS_record_change_cipher_spec: error_print(); return -1; case TLS_record_alert: { // should call tls_process_alert() int level; int alert; tls_record_get_alert(conn->databuf, &level, &alert); if (alert == TLS_alert_close_notify) { tls_trace("recv Alert.close_notify\n"); return 0; } tls_trace("alert received\n"); return -1; } default: error_print(); return -1; } } *recvlen = outlen <= conn->datalen ? outlen : conn->datalen; memcpy(out, conn->data, *recvlen); conn->data += *recvlen; conn->datalen -= *recvlen; return 1; } int tls_shutdown(TLS_CONNECT *conn) { int ret; size_t recordlen; uint8_t alert[2]; alert[0] = TLS_alert_level_fatal; alert[1] = TLS_alert_close_notify; if (!conn) { error_print(); return -1; } tls_trace("send Alert.close_notify\n"); if (tls_encrypt_send(conn, TLS_record_alert, alert, sizeof(alert), &recordlen) != 1) { error_print(); return -1; } tls_trace("recv Alert.close_notify\n"); if ((ret = tls_decrypt_recv(conn)) != 1) { if (ret == 0) tls_trace("Connection closed by remote without close_notify\n"); else if (ret == -EAGAIN) tls_trace("-EAGAIN\n"); else error_print(); return -1; } return 1; } int tls_authorities_from_certs(uint8_t *names, size_t *nameslen, size_t maxlen, const uint8_t *certs, size_t certslen) { const uint8_t *cert; size_t certlen; const uint8_t *name; size_t namelen; *nameslen = 0; while (certslen) { size_t alen = 0; if (x509_cert_from_der(&cert, &certlen, &certs, &certslen) != 1 || x509_cert_get_subject(cert, certlen, &name, &namelen) != 1 || asn1_sequence_to_der(name, namelen, NULL, &alen) != 1) { error_print(); return -1; } if (tls_uint16_size() + alen > maxlen) { error_print(); return -1; } if (alen > UINT16_MAX) { error_print(); return -1; } tls_uint16_to_bytes((uint16_t)alen, &names, nameslen); if (asn1_sequence_to_der(name, namelen, &names, nameslen) != 1) { error_print(); return -1; } maxlen -= alen; } return 1; } int tls_authorities_issued_certificate(const uint8_t *ca_names, size_t ca_names_len, const uint8_t *certs, size_t certslen) { const uint8_t *cert; size_t certlen; const uint8_t *issuer; size_t issuer_len; if (x509_certs_get_last(certs, certslen, &cert, &certlen) != 1 || x509_cert_get_issuer(cert, certlen, &issuer, &issuer_len) != 1) { error_print(); return -1; } while (ca_names_len) { const uint8_t *p; size_t len; const uint8_t *name; size_t namelen; if (tls_uint16array_from_bytes(&p, &len, &ca_names, &ca_names_len) != 1) { error_print(); return -1; } if (asn1_sequence_from_der(&name, &namelen, &p, &len) != 1 || asn1_length_is_zero(len) != 1) { error_print(); return -1; } if (x509_name_equ(name, namelen, issuer, issuer_len) == 1) { return 1; } } error_print(); return 0; } int tls_cert_types_accepted(const uint8_t *types, size_t types_len, const uint8_t *client_certs, size_t client_certs_len) { const uint8_t *cert; size_t certlen; int sig_alg; size_t i; if (x509_certs_get_cert_by_index(client_certs, client_certs_len, 0, &cert, &certlen) != 1) { error_print(); return -1; } if ((sig_alg = tls_cert_type_from_oid(OID_sm2sign_with_sm3)) < 0) { error_print(); return -1; } for (i = 0; i < types_len; i++) { if (sig_alg == types[i]) { return 1; } } return 0; } int tls_client_verify_init(TLS_CLIENT_VERIFY_CTX *ctx) { if (!ctx) { error_print(); return -1; } memset(ctx, 0, sizeof(TLS_CLIENT_VERIFY_CTX)); return 1; } // FIXME: remove malloc! int tls_client_verify_update(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *handshake, size_t handshake_len) { uint8_t *buf; if (!ctx || !handshake || !handshake_len) { error_print(); return -1; } if (ctx->index < 0 || ctx->index > 7) { error_print(); return -1; } if (!(buf = malloc(handshake_len))) { error_print(); return -1; } memcpy(buf, handshake, handshake_len); ctx->handshake[ctx->index] = buf; ctx->handshake_len[ctx->index] = handshake_len; ctx->index++; return 1; } int tls_client_verify_finish(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *sig, size_t siglen, const SM2_KEY *public_key) { int ret; SM2_VERIFY_CTX verify_ctx; int i; if (!ctx || !sig || !siglen || !public_key) { error_print(); return -1; } if (ctx->index != 8) { error_print(); return -1; } // 这里的主要困难是,SM2的签名验证需要以Z作为输入,但是在没有拿到客户端的公钥之前,无法启动验证 if (sm2_verify_init(&verify_ctx, public_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1) { error_print(); return -1; } for (i = 0; i < 8; i++) { if (sm2_verify_update(&verify_ctx, ctx->handshake[i], ctx->handshake_len[i]) != 1) { error_print(); return -1; } } if ((ret = sm2_verify_finish(&verify_ctx, sig, siglen)) < 0) { error_print(); return -1; } return ret; } void tls_client_verify_cleanup(TLS_CLIENT_VERIFY_CTX *ctx) { if (ctx) { int i; for (i = 0; i< ctx->index; i++) { if (ctx->handshake[i]) { free(ctx->handshake[i]); ctx->handshake[i] = NULL; ctx->handshake_len[i] = 0; } } } } // 这个函数不对啊,应该以服务器优先来选择参数 int tls_cipher_suites_select(const uint8_t *client_ciphers, size_t client_ciphers_len, const int *server_ciphers, size_t server_ciphers_cnt, int *selected_cipher) { if (!client_ciphers || !client_ciphers_len || !server_ciphers || !server_ciphers_cnt || !selected_cipher) { error_print(); return -1; } while (server_ciphers_cnt--) { const uint8_t *p = client_ciphers; size_t len = client_ciphers_len; while (len) { uint16_t cipher; if (tls_uint16_from_bytes(&cipher, &p, &len) != 1) { error_print(); return -1; } if (cipher == *server_ciphers) { *selected_cipher = *server_ciphers; return 1; } } server_ciphers++; } return 0; } void tls_ctx_cleanup(TLS_CTX *ctx) { if (ctx) { gmssl_secure_clear(&ctx->signkey, sizeof(SM2_KEY)); gmssl_secure_clear(&ctx->kenckey, sizeof(SM2_KEY)); if (ctx->certs) free(ctx->certs); if (ctx->cacerts) free(ctx->cacerts); memset(ctx, 0, sizeof(TLS_CTX)); } } int tls_ctx_print(FILE *fp, int fmt, int ind, const char *label, const TLS_CTX *ctx) { return 0; } int tls_ctx_init(TLS_CTX *ctx, int protocol, int is_client) { const int supported_versions[] = { TLS_protocol_tls13, TLS_protocol_tls12, TLS_protocol_tlcp, }; if (!ctx) { error_print(); return -1; } memset(ctx, 0, sizeof(*ctx)); // protocol switch (protocol) { case TLS_protocol_tlcp: case TLS_protocol_tls12: case TLS_protocol_tls13: ctx->protocol = protocol; break; default: error_print(); return -1; } ctx->is_client = is_client ? 1 : 0; // supported_versions memcpy(ctx->supported_versions, supported_versions, sizeof(supported_versions)); ctx->supported_versions_cnt = sizeof(supported_versions)/sizeof(supported_versions[0]); ctx->verify_depth = 5; // key_share ctx->key_exchanges_cnt = 2; return 1; } int tls_ctx_set_supported_versions(TLS_CTX *ctx, const int *versions, size_t versions_cnt) { size_t i; if (!ctx || !versions || !versions_cnt) { error_print(); return -1; } if (versions_cnt > sizeof(ctx->supported_versions)/sizeof(ctx->supported_versions[0])) { error_print(); return -1; } for (i = 0; i < versions_cnt; i++) { switch (versions[i]) { case TLS_protocol_tls13: case TLS_protocol_tls12: case TLS_protocol_tlcp: break; default: error_print(); return -1; } ctx->supported_versions[i] = versions[i]; } ctx->supported_versions_cnt = versions_cnt; return 1; } int tls_ctx_set_cipher_suites(TLS_CTX *ctx, const int *cipher_suites, size_t cipher_suites_cnt) { size_t i; if (!ctx || !cipher_suites || !cipher_suites_cnt) { error_print(); return -1; } if (cipher_suites_cnt > sizeof(ctx->cipher_suites)/sizeof(ctx->cipher_suites[0])) { error_print(); return -1; } for (i = 0; i < cipher_suites_cnt; i++) { if (!tls_cipher_suite_name(cipher_suites[i])) { error_print(); return -1; } ctx->cipher_suites[i] = cipher_suites[i]; } ctx->cipher_suites_cnt = cipher_suites_cnt; return 1; } int tls_ctx_set_key_exchange_modes(TLS_CTX *ctx, int modes) { if (!ctx) { error_print(); return -1; } if (modes & ~(TLS_KE_CERT_DHE|TLS_KE_PSK_DHE|TLS_KE_PSK)) { error_print(); return -1; } ctx->key_exchange_modes = modes; return 1; } // 这个函数不是很好,直接提供的是一个文件名 int tls_ctx_set_ca_certificates(TLS_CTX *ctx, const char *cacertsfile, int depth) { if (!ctx || !cacertsfile) { error_print(); return -1; } if (depth < 0 || depth > TLS_MAX_VERIFY_DEPTH) { error_print(); return -1; } if (!tls_protocol_name(ctx->protocol)) { error_print(); return -1; } if (ctx->cacerts) { error_print(); return -1; } if (x509_certs_new_from_file(&ctx->cacerts, &ctx->cacertslen, cacertsfile) != 1) { error_print(); return -1; } if (ctx->cacertslen == 0) { error_print(); return -1; } ctx->verify_depth = depth; return 1; } // 这个函数需要设置一个默认的证书链 // 这个函数实际上是有问题的,没有给这个证书链提供status_request和sct_list // cert_chain的格式到底是什么呢? // 是单独的证书链,还是也包含扩展呢? int tls_ctx_add_certificate_chain_and_key(TLS_CTX *ctx, const char *chainfile, const char *keyfile, const char *keypass) { uint8_t *cert_chain; size_t cert_chain_len; FILE *certfp = NULL; const uint8_t *cert; size_t certlen; X509_KEY public_key; FILE *keyfp = NULL; if (!ctx || !chainfile || !keyfile || !keypass) { error_print(); return -1; } // no space in ctx->cert_chains[] if (sizeof(ctx->cert_chains) <= ctx->cert_chains_len + tls_uint24_size()) { error_print(); return -1; } // no space in ctx->x509_keys[] if (sizeof(ctx->x509_keys)/sizeof(ctx->x509_keys[0]) <= ctx->x509_keys_cnt) { error_print(); return -1; } if (!(certfp = fopen(chainfile, "r"))) { error_print(); return -1; } // read and save cert_chain as uint24array cert_chain = ctx->cert_chains + ctx->cert_chains_len; if (x509_certs_from_pem(cert_chain + tls_uint24_size(), &cert_chain_len, sizeof(ctx->cert_chains) - ctx->cert_chains_len - tls_uint24_size(), certfp) != 1) { error_print(); return -1; } tls_uint24_to_bytes(cert_chain_len, &cert_chain, &cert_chain_len); ctx->cert_chains_len += cert_chain_len; cert_chain_len -= tls_uint24_size(); if (x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 0, &cert, &certlen) != 1 || x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) { fclose(certfp); error_print(); return -1; } if (public_key.algor == OID_ec_public_key) { if (!(keyfp = fopen(keyfile, "r"))) { fclose(certfp); error_print(); return -1; } } else { if (!(keyfp = fopen(keyfile, "rb+"))) { fclose(certfp); error_print(); return -1; } } // read and save at most two keys as uint16array if (x509_private_key_from_file(&ctx->x509_keys[ctx->x509_keys_cnt], public_key.algor, keypass, keyfp) != 1) { fclose(certfp); fclose(keyfp); error_print(); return -1; } ctx->x509_keys_cnt++; fclose(certfp); fclose(keyfp); return 1; } // 保留这个函数,相当于是对证书链的初始化 int tls_ctx_set_certificate_and_key(TLS_CTX *ctx, const char *chainfile, const char *keyfile, const char *keypass) { int ret = -1; uint8_t *certs = NULL; size_t certslen; FILE *keyfp = NULL; const uint8_t *cert; size_t certlen; X509_KEY public_key; if (!ctx || !chainfile || !keyfile || !keypass) { error_print(); return -1; } if (!tls_protocol_name(ctx->protocol)) { error_print(); return -1; } if (ctx->certs) { error_print(); return -1; } if (x509_certs_new_from_file(&certs, &certslen, chainfile) != 1) { error_print(); goto end; } if (x509_certs_get_cert_by_index(certs, certslen, 0, &cert, &certlen) != 1 || x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) { error_print(); return -1; } if (public_key.algor == OID_ec_public_key) { if (!(keyfp = fopen(keyfile, "r"))) { error_print(); return -1; } } else { if (!(keyfp = fopen(keyfile, "rb+"))) { error_print(); return -1; } } if (x509_private_key_from_file(&ctx->signkey, public_key.algor, keypass, keyfp) != 1) { error_print(); return -1; } ctx->certs = certs; ctx->certslen = certslen; certs = NULL; ret = 1; end: if (certs) free(certs); if (keyfp) fclose(keyfp); return ret; } int tls_ctx_set_tlcp_server_certificate_and_keys(TLS_CTX *ctx, const char *chainfile, const char *signkeyfile, const char *signkeypass, const char *kenckeyfile, const char *kenckeypass) { int ret = -1; const int algor = OID_ec_public_key; const int algor_param = OID_sm2; uint8_t *certs = NULL; size_t certslen; FILE *signkeyfp = NULL; FILE *kenckeyfp = NULL; const uint8_t *cert; size_t certlen; X509_KEY public_key; if (!ctx || !chainfile || !signkeyfile || !signkeypass || !kenckeyfile || !kenckeypass) { error_print(); return -1; } if (!tls_protocol_name(ctx->protocol)) { error_print(); return -1; } if (ctx->certs) { error_print(); return -1; } if (x509_certs_new_from_file(&certs, &certslen, chainfile) != 1) { error_print(); return -1; } // load sign key if (!(signkeyfp = fopen(signkeyfile, "r"))) { error_print(); goto end; } if (x509_private_key_from_file(&ctx->signkey, algor, signkeypass, signkeyfp) != 1) { error_print(); goto end; } if (x509_certs_get_cert_by_index(certs, certslen, 0, &cert, &certlen) != 1 || x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) { error_print(); goto end; } if (x509_public_key_equ(&ctx->signkey, &public_key) != 1) { error_print(); goto end; } // load enc key if (!(kenckeyfp = fopen(kenckeyfile, "r"))) { error_print(); goto end; } if (x509_private_key_from_file(&ctx->kenckey, algor, kenckeypass, kenckeyfp) != 1) { error_print(); goto end; } if (x509_certs_get_cert_by_index(certs, certslen, 1, &cert, &certlen) != 1 || x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) { error_print(); goto end; } if (x509_public_key_equ(&ctx->kenckey, &public_key) != 1) { error_print(); goto end; } ctx->certs = certs; ctx->certslen = certslen; certs = NULL; ret = 1; end: if (ret != 1) x509_key_cleanup(&ctx->signkey); if (ret != 1) x509_key_cleanup(&ctx->kenckey); if (certs) free(certs); if (signkeyfp) fclose(signkeyfp); if (kenckeyfp) fclose(kenckeyfp); return ret; } int tls_ctx_set_supported_groups(TLS_CTX *ctx, const int *groups, size_t groups_cnt) { size_t i; if (!ctx || !groups || !groups_cnt) { error_print(); return -1; } if (groups_cnt > sizeof(ctx->supported_groups)/sizeof(ctx->supported_groups[0])) { error_print(); return -1; } for (i = 0; i < groups_cnt; i++) { switch (groups[i]) { case TLS_curve_sm2p256v1: case TLS_curve_secp256r1: break; default: error_print(); return -1; } ctx->supported_groups[i] = groups[i]; } ctx->supported_groups_cnt = groups_cnt; return 1; } int tls_ctx_set_signature_algorithms(TLS_CTX *ctx, const int *sig_algs, size_t sig_algs_cnt) { size_t i; if (!ctx || !sig_algs || !sig_algs_cnt) { error_print(); return -1; } if (sig_algs_cnt > sizeof(ctx->signature_algorithms)/sizeof(ctx->signature_algorithms[0])) { error_print(); return -1; } for (i = 0; i < sig_algs_cnt; i++) { switch (sig_algs[i]) { case TLS_sig_sm2sig_sm3: case TLS_sig_ecdsa_secp256r1_sha256: break; default: error_print(); return -1; } ctx->signature_algorithms[i] = sig_algs[i]; } ctx->signature_algorithms_cnt = sig_algs_cnt; return 1; } int tls13_ctx_set_client_hello_key_exchanges_cnt(TLS_CTX *ctx, size_t cnt) { if (!ctx) { error_print(); return -1; } if (cnt > sizeof(((TLS_CONNECT *)NULL)->key_exchanges)/sizeof(((TLS_CONNECT *)NULL)->key_exchanges[0])) { error_print(); return -1; } ctx->key_exchanges_cnt = cnt; return 1; } int tls_init(TLS_CONNECT *conn, TLS_CTX *ctx) { size_t i; memset(conn, 0, sizeof(*conn)); conn->is_client = ctx->is_client; conn->protocol = ctx->protocol; /* for (i = 0; i < ctx->cipher_suites_cnt; i++) { conn->cipher_suites[i] = ctx->cipher_suites[i]; } conn->cipher_suites_cnt = ctx->cipher_suites_cnt; */ if (ctx->certslen > TLS_MAX_CERTIFICATES_SIZE) { error_print(); return -1; } if (conn->is_client) { memcpy(conn->client_certs, ctx->certs, ctx->certslen); conn->client_certs_len = ctx->certslen; } else { memcpy(conn->server_certs, ctx->certs, ctx->certslen); conn->server_certs_len = ctx->certslen; } /* if (ctx->cacertslen > TLS_MAX_CERTIFICATES_SIZE) { error_print(); return -1; } memcpy(conn->ca_certs, ctx->cacerts, ctx->cacertslen); conn->ca_certs_len = ctx->cacertslen; */ conn->sign_key = ctx->signkey; conn->kenc_key = ctx->kenckey; conn->quiet = ctx->quiet; conn->ctx = ctx; conn->key_exchanges_cnt = ctx->key_exchanges_cnt; conn->new_session_ticket = ctx->new_session_ticket; conn->key_exchange_modes = ctx->key_exchange_modes; // early_data conn->early_data = ctx->early_data; conn->max_early_data_size = ctx->max_early_data_size; return 1; } void tls_cleanup(TLS_CONNECT *conn) { gmssl_secure_clear(conn, sizeof(TLS_CONNECT)); } /* int tls_set_hostname(TLS_CONNECT *conn, const char *hostname) { if (strlen(hostname) > 255) { error_print(); return -1; } conn->hostname = hostname; return 1; } */ int tls_set_socket(TLS_CONNECT *conn, tls_socket_t sock) { #ifdef WIN32 u_long flags = 0; // TODO: 0 == blocking, 1 == non-blocking if(ioctlsocket(sock, FIONBIO, &flags) != 0) { error_puts("socket in non-blocking mode"); //nginx will pass a socket in non-blocking mode //return -1; // FIXME } #else int flags = 0; if ((flags = fcntl(sock, F_GETFL)) == -1) { error_print(); perror("fcntl error"); return -1; } if (flags & O_NONBLOCK) { error_puts("socket in non-blocking mode"); //nginx will pass a socket in non-blocking mode //return -1; // FIXME } #endif conn->sock = sock; return 1; } int tls_do_handshake(TLS_CONNECT *conn) { switch (conn->protocol) { case TLS_protocol_tlcp: if (conn->is_client) return tlcp_do_connect(conn); else return tlcp_do_accept(conn); case TLS_protocol_tls12: if (conn->is_client) return tls12_do_connect(conn); else return tls12_do_accept(conn); case TLS_protocol_tls13: if (conn->is_client) return tls13_do_connect(conn); else return tls13_do_accept(conn); } error_print(); return -1; } int tls_get_verify_result(TLS_CONNECT *conn, int *result) { *result = conn->verify_result; return 1; } int tls_uint16array_from_file(uint8_t *arr, size_t *arrlen, size_t maxlen, FILE *fp) { uint16_t datalen; const uint8_t *cp; size_t len = 2; if (!arr || !arrlen || !fp) { error_print(); return -1; } if (maxlen < 2) { error_print(); return -1; } if (fread(arr, 1, 2, fp) != 2) { error_print(); return -1; } cp = arr; len = 2; if (tls_uint16_from_bytes(&datalen, &cp, &len) != 1 || tls_length_is_zero(len) != 1) { error_print(); return -1; } *arrlen = 2 + datalen; if (2 + datalen > maxlen) { error_print(); return 0; } if (fread(arr + 2, 1, datalen, fp) != datalen) { error_print(); return -1; } return 1; } int tls_set_server_name(TLS_CONNECT *conn, const uint8_t *host_name, size_t host_name_len) { if (!conn || !host_name || !host_name_len) { error_print(); return -1; } if (!conn->is_client) { error_print(); return -1; } if (host_name_len >= sizeof(conn->server_name)) { error_print(); return -1; } memcpy(conn->server_name, host_name, host_name_len); conn->server_name[host_name_len] = 0; conn->server_name_len = host_name_len; return 1; }