/* * Copyright 2014-2026 The GmSSL Project. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the License); you may * not use this file except in compliance with the License. * * http://www.apache.org/licenses/LICENSE-2.0 */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include static const int tls12_ciphers[] = { TLS_cipher_ecdhe_sm4_cbc_sm3, TLS_cipher_ecdhe_sm4_gcm_sm3, TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256, }; int tls12_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent) { // 目前只支持TLCP的ECC公钥加密套件,因此不论用哪个套件解析都是一样的 // 如果未来支持ECDHE套件,可以将函数改为宏,直接传入 (conn->cipher_suite << 8) format |= tls12_ciphers[0] << 8; // 应该是KeyExchange需要这个参数 return tls_record_print(fp, record, recordlen, format, indent); } // 这里主要的问题是我们没有 cbc_encrypt_blocks 这个函数啊 void cbc_encrypt_blocks(const BLOCK_CIPHER_KEY *key, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) { const uint8_t *piv = iv; while (nblocks--) { size_t i; for (i = 0; i < 16; i++) { out[i] = in[i] ^ piv[i]; } block_cipher_encrypt(key, out, out); piv = out; in += 16; out += 16; } memcpy(iv, piv, 16); } void cbc_decrypt_blocks(const BLOCK_CIPHER_KEY *key, uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) { const uint8_t *piv = iv; while (nblocks--) { size_t i; block_cipher_decrypt(key, in, out); for (i = 0; i < 16; i++) { out[i] ^= piv[i]; } piv = in; in += 16; out += 16; } memcpy(iv, piv, 16); } // 这个函数只有在哈希函数为HASH256时才是正确的 int tls12_cbc_encrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *enc_key, const uint8_t seq_num[8], const uint8_t header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { HMAC_CTX hmac_ctx; uint8_t last_blocks[32 + 16] = {0}; uint8_t iv[16]; uint8_t *mac, *padding; size_t maclen; int rem, padding_len; int i; if (!inited_hmac_ctx || !enc_key || !seq_num || !header || (!in && inlen) || !out || !outlen) { error_print(); return -1; } if (inlen > (1 << 14)) { error_print(); return -1; } if ((((size_t)header[3]) << 8) + header[4] != inlen) { error_print(); return -1; } rem = (inlen + 32) % 16; memcpy(last_blocks, in + inlen - rem, rem); mac = last_blocks + rem; memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(HMAC_CTX)); hmac_update(&hmac_ctx, seq_num, 8); hmac_update(&hmac_ctx, header, 5); hmac_update(&hmac_ctx, in, inlen); hmac_finish(&hmac_ctx, mac, &maclen); padding = mac + 32; padding_len = 16 - rem - 1; for (i = 0; i <= padding_len; i++) { padding[i] = (uint8_t)padding_len; } if (rand_bytes(iv, 16) != 1) { error_print(); return -1; } memcpy(out, iv, 16); out += 16; if (inlen >= 16) { cbc_encrypt_blocks(enc_key, iv, in, inlen/16, out); out += inlen - rem; } cbc_encrypt_blocks(enc_key, iv, last_blocks, sizeof(last_blocks)/16, out); *outlen = 16 + inlen - rem + sizeof(last_blocks); return 1; } int tls12_cbc_decrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *dec_key, const uint8_t seq_num[8], const uint8_t enced_header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { HMAC_CTX hmac_ctx; uint8_t iv[16]; const uint8_t *padding; const uint8_t *mac; uint8_t header[5]; int padding_len; uint8_t hmac[32]; size_t hmaclen; int i; if (!inited_hmac_ctx || !dec_key || !seq_num || !enced_header || !in || !inlen || !out || !outlen) { error_print(); return -1; } if (inlen % 16 || inlen < (16 + 0 + 32 + 16) // iv + data + mac + padding || inlen > (16 + (1<<14) + 32 + 256)) { error_print_msg("invalid tls cbc ciphertext length %zu\n", inlen); return -1; } memcpy(iv, in, 16); format_bytes(stderr, 0, 0, "itls12_cbc_decrypt: iv", iv, 16); in += 16; inlen -= 16; cbc_decrypt_blocks(dec_key, iv, in, inlen/16, out); format_bytes(stderr, 0, 0, "cbc_decrypt out", out, inlen); padding_len = out[inlen - 1]; padding = out + inlen - padding_len - 1; if (padding < out + 32) { error_print(); return -1; } for (i = 0; i < padding_len; i++) { if (padding[i] != padding_len) { error_puts("tls ciphertext cbc-padding check failure"); return -1; } } *outlen = inlen - 32 - padding_len - 1; header[0] = enced_header[0]; header[1] = enced_header[1]; header[2] = enced_header[2]; header[3] = (uint8_t)((*outlen) >> 8); header[4] = (uint8_t)(*outlen); mac = padding - 32; memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(HMAC_CTX)); hmac_update(&hmac_ctx, seq_num, 8); hmac_update(&hmac_ctx, header, 5); hmac_update(&hmac_ctx, out, *outlen); hmac_finish(&hmac_ctx, hmac, &hmaclen); if (gmssl_secure_memcmp(mac, hmac, sizeof(hmac)) != 0) { error_puts("tls ciphertext mac check failure\n"); return -1; } return 1; } int tls12_record_encrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key, const uint8_t seq_num[8], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { if (tls12_cbc_encrypt(hmac_ctx, cbc_key, seq_num, in, in + 5, inlen - 5, out + 5, outlen) != 1) { error_print(); return -1; } out[0] = in[0]; out[1] = in[1]; out[2] = in[2]; out[3] = (uint8_t)((*outlen) >> 8); out[4] = (uint8_t)(*outlen); (*outlen) += 5; return 1; } int tls12_record_decrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key, const uint8_t seq_num[8], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { if (tls12_cbc_decrypt(hmac_ctx, cbc_key, seq_num, in, in + 5, inlen - 5, out + 5, outlen) != 1) { error_print(); return -1; } out[0] = in[0]; out[1] = in[1]; out[2] = in[2]; out[3] = (uint8_t)((*outlen) >> 8); out[4] = (uint8_t)(*outlen); (*outlen) += 5; return 1; } int tls12_prf(const DIGEST *digest, const uint8_t *secret, size_t secretlen, const char *label, const uint8_t *seed, size_t seedlen, const uint8_t *more, size_t morelen, size_t outlen, uint8_t *out) { HMAC_CTX inited_hmac_ctx; HMAC_CTX hmac_ctx; uint8_t A[32]; uint8_t hmac[32]; size_t len; if (!secret || !secretlen || !label || !seed || !seedlen || (!more && morelen) || !outlen || !out) { error_print(); return -1; } hmac_init(&inited_hmac_ctx, digest, secret, secretlen); memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); hmac_update(&hmac_ctx, seed, seedlen); hmac_update(&hmac_ctx, more, morelen); hmac_finish(&hmac_ctx, A, &len); // 检查或者使用长度len memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); hmac_update(&hmac_ctx, A, sizeof(A)); hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); hmac_update(&hmac_ctx, seed, seedlen); hmac_update(&hmac_ctx, more, morelen); hmac_finish(&hmac_ctx, hmac, &len); len = outlen < sizeof(hmac) ? outlen : sizeof(hmac); memcpy(out, hmac, len); out += len; outlen -= len; while (outlen) { memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); hmac_update(&hmac_ctx, A, sizeof(A)); hmac_finish(&hmac_ctx, A, &len); memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); hmac_update(&hmac_ctx, A, sizeof(A)); hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); hmac_update(&hmac_ctx, seed, seedlen); hmac_update(&hmac_ctx, more, morelen); hmac_finish(&hmac_ctx, hmac, &len); len = outlen < sizeof(hmac) ? outlen : sizeof(hmac); memcpy(out, hmac, len); out += len; outlen -= len; } return 1; } // modify: conn->record_offset int tls_send_record(TLS_CONNECT *conn) { size_t left; tls_ret_t n; left = tls_record_length(conn->record) - conn->record_offset; while (left) { n = tls_socket_send(conn->sock, conn->record + conn->record_offset, left, 0); if (n < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) { return TLS_ERROR_SEND_AGAIN; } else if (errno == EINTR) { continue; } else { fprintf(stderr, "%s %d: send() error: %s\n", __FILE__, __LINE__, strerror(errno)); error_print(); return -1; } } conn->record_offset += n; left -= n; } return 1; } int tls_recv_record(TLS_CONNECT *conn) { size_t left; tls_ret_t n; if (conn->recordlen) { return 1; } if (conn->record_offset < 5) { left = 5 - conn->record_offset; while (left) { n = tls_socket_recv(conn->sock, conn->record + conn->record_offset, left, 0); if (n < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) { return TLS_ERROR_RECV_AGAIN; } else if (errno == EINTR) { continue; } else { error_print(); // TODO: check the usage of OpenSSL SSL_ERR_SYSCALL // if applications such as Nginx, HTTPD do not use this error, we just return -1 return TLS_ERROR_SYSCALL; } } else if (n == 0) { error_print(); return TLS_ERROR_TCP_CLOSED; } conn->record_offset += n; left -= n; } } if (conn->record_offset == 5) { if (!tls_record_type_name(tls_record_type(conn->record))) { error_print(); return -1; } if (!tls_protocol_name(tls_record_protocol(conn->record))) { error_print(); return -1; } if (tls_record_length(conn->record) > TLS_MAX_RECORD_SIZE) { error_print(); return -1; } } if (conn->record_offset >= tls_record_length(conn->record)) { error_print(); return -1; } left = tls_record_length(conn->record) - conn->record_offset; while (left) { n = tls_socket_recv(conn->sock, conn->record + conn->record_offset, left, 0); if (n < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) { return TLS_ERROR_RECV_AGAIN; } else if (errno == EINTR) { continue; } else { error_print(); return TLS_ERROR_SYSCALL; } } else if (n == 0) { error_print(); return TLS_ERROR_TCP_CLOSED; } conn->record_offset += n; left -= n; } conn->recordlen = conn->record_offset; // 应该判断是否为Alert这种异常状况 return 1; } int tls_named_curve_oid(int named_curve) { switch (named_curve) { case TLS_curve_secp256r1: return OID_secp256r1; case TLS_curve_sm2p256v1: return OID_sm2; } return OID_undef; } int tls_named_curve_from_oid(int oid) { switch (oid) { case OID_secp256r1: return TLS_curve_secp256r1; case OID_sm2: return TLS_curve_sm2p256v1; } return 0; } // 这个是必选的 // 服务器通常推荐返回这个值 const int supported_groups[] = { TLS_curve_sm2p256v1, TLS_curve_secp256r1, }; size_t supported_groups_cnt = sizeof(supported_groups)/sizeof(supported_groups[0]); // 仍旧是不可设置的 const int signature_algors[] = { TLS_sig_sm2sig_sm3, TLS_sig_ecdsa_secp256r1_sha256, }; size_t signature_algors_cnt = sizeof(signature_algors)/sizeof(signature_algors[0]); int tls_record_set_handshake_server_key_exchange(uint8_t *record, size_t *recordlen, const uint8_t *server_ecdh_params, size_t server_ecdh_params_len, uint16_t sig_alg, const uint8_t *sig, size_t siglen) { const int type = TLS_handshake_server_key_exchange; uint8_t *p = tls_handshake_data(tls_record_data(record)); size_t len = 0; if (server_ecdh_params_len != 69) { error_print(); return -1; } if (siglen > TLS_MAX_SIGNATURE_SIZE) { error_print(); return -1; } tls_array_to_bytes(server_ecdh_params, server_ecdh_params_len, &p, &len); tls_uint16_to_bytes(sig_alg, &p, &len); tls_uint16array_to_bytes(sig, siglen, &p, &len); tls_record_set_handshake(record, recordlen, type, NULL, len); return 1; } // 这个函数是有问题的,因为tlcp的格式和TLS不一样 int tls_record_get_handshake_server_key_exchange(const uint8_t *record, uint8_t *curve_type, uint16_t *named_curve, const uint8_t **point_octets, size_t *point_octets_len, const uint8_t **server_ecdh_params, size_t *server_ecdh_params_len, uint16_t *sig_alg, const uint8_t **sig, size_t *siglen) { int type; const uint8_t *p; size_t len; if (tls_record_get_handshake(record, &type, &p, &len) != 1) { error_print(); return -1; } if (type != TLS_handshake_server_key_exchange) { error_print(); return -1; } *server_ecdh_params = p; if (tls_uint8_from_bytes(curve_type, &p, &len) != 1 || tls_uint16_from_bytes(named_curve, &p, &len) != 1 || tls_uint8array_from_bytes(point_octets, point_octets_len, &p, &len) != 1) { error_print(); return -1; } *server_ecdh_params_len = p - *server_ecdh_params; if (*server_ecdh_params_len != 69) { error_print(); return -1; } if (tls_uint16_from_bytes(sig_alg, &p, &len) != 1 || tls_uint16array_from_bytes(sig, siglen, &p, &len) != 1 || tls_length_is_zero(len) != 1) { error_print(); return -1; } if (*curve_type != TLS_curve_type_named_curve) { error_print(); return -1; } if (!tls_named_curve_name(*named_curve)) { error_print(); return -1; } if (!tls_signature_scheme_name(*sig_alg)) { error_print(); return -1; } return 1; } int tls_record_set_handshake_client_key_exchange(uint8_t *record, size_t *recordlen, const uint8_t *point_octets, size_t point_octets_len) { int type = TLS_handshake_client_key_exchange; uint8_t *p = tls_handshake_data(tls_record_data(record)); size_t len = 0; if (point_octets_len != 65) { error_print(); return -1; } tls_uint8array_to_bytes(point_octets, (uint8_t)point_octets_len, &p, &len); tls_record_set_handshake(record, recordlen, type, NULL, len); return 1; } int tls_record_get_handshake_client_key_exchange(const uint8_t *record, const uint8_t **point_octets, size_t *point_octets_len) { int type; const uint8_t *p; size_t len; if (tls_record_get_handshake(record, &type, &p, &len) != 1 || type != TLS_handshake_client_key_exchange) { error_print(); return -1; } if (tls_uint8array_from_bytes(point_octets, point_octets_len, &p, &len) != 1) { error_print(); return -1; } if (*point_octets_len != 65) { error_print(); return -1; } if (len) { error_print(); return -1; } return 1; } int tls12_cert_chains_select(const uint8_t *cert_chains, size_t cert_chains_len, const int *supported_groups, size_t supported_groups_cnt, // optional const int *signature_algorithms, size_t signature_algorithms_cnt, // optional const uint8_t *ca_names, size_t ca_names_len, // certificate_authorities optional const uint8_t *host_name, size_t host_name_len, // optional, only in ClientHello const uint8_t **certs, size_t *certs_len, size_t *certs_idx, int *prefered_sig_alg) // optional { size_t i; if (!cert_chains || !cert_chains_len) { error_print(); return -1; } for (i = 1; cert_chains_len; i++) { const uint8_t *cert_chain; size_t cert_chain_len; int sig_alg; int ret; if (tls_uint24array_from_bytes(&cert_chain, &cert_chain_len, &cert_chains, &cert_chains_len) != 1) { error_print(); return -1; } if (certs) *certs = cert_chain; if (certs_len) *certs_len = cert_chain_len; if (certs_idx) *certs_idx = i; if (prefered_sig_alg) *prefered_sig_alg = sig_alg; return 1; } return 0; } void tls_clean_record(TLS_CONNECT *conn) { conn->record_offset = 0; conn->recordlen = 0; } int tls_handshake_init(TLS_CONNECT *conn) { //sm3_init(&conn->sm3_ctx); digest_init(&conn->dgst_ctx, DIGEST_sm3()); if (conn->client_certs_len) { //sm2_sign_init(&conn->sign_ctx, &conn->sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH); } return 1; } const int ec_point_formats[] = { TLS_point_uncompressed }; size_t ec_point_formats_cnt = sizeof(ec_point_formats)/sizeof(ec_point_formats[0]); // 有可能需要支持SNI int tls_send_client_hello(TLS_CONNECT *conn) { int ret; if (!conn->recordlen) { uint8_t exts[TLS_MAX_EXTENSIONS_SIZE]; uint8_t *pexts = exts; size_t extslen = 0; tls_trace("send ClientHello\n"); tls_record_set_protocol(conn->record, TLS_protocol_tls1); if (tls_random_generate(conn->client_random) != 1) { error_print(); return -1; } // ec_point_formats if (tls_ec_point_formats_ext_to_bytes( ec_point_formats, ec_point_formats_cnt, &pexts, &extslen) != 1) { error_print(); return -1; } // supported_groups if (conn->ctx->supported_groups_cnt) { if (tls_supported_groups_ext_to_bytes(conn->ctx->supported_groups, conn->ctx->supported_groups_cnt, &pexts, &extslen) != 1) { error_print(); return -1; } } // signature_algorithms if (conn->ctx->signature_algorithms_cnt) { if (tls_signature_algorithms_ext_to_bytes(conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt, &pexts, &extslen) != 1) { error_print(); return -1; } } if (tls_record_set_handshake_client_hello(conn->record, &conn->recordlen, conn->protocol, conn->client_random, NULL, 0, conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt, exts, extslen) != 1) { error_print(); return -1; } tls12_record_print(stderr, conn->record, conn->recordlen, 0, 0); // backup ClientHello memcpy(conn->plain_record, conn->record, conn->recordlen); conn->plain_recordlen = conn->recordlen; } /* if (conn->client_certificate_verify) { sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); } */ if ((ret = tls_send_record(conn)) != 1) { if (ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; } tls_clean_record(conn); return 1; } /* const int server_ciphers[] = { TLS_cipher_ecdhe_sm4_cbc_sm3 }; const size_t server_ciphers_cnt = 1; */ const int curve = TLS_curve_sm2p256v1; static int tls12_cipher_suite_get(int cipher_suite, const BLOCK_CIPHER **cipher, const DIGEST **digest) { switch (cipher_suite) { case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_sm4_gcm_sm3: *cipher = BLOCK_CIPHER_sm4(); *digest = DIGEST_sm3(); break; case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: *cipher = BLOCK_CIPHER_aes128(); *digest = DIGEST_sha256(); break; default: error_print(); return -1; } return 1; } static int tls12_cipher_suite_match_cert_group(int cipher_suite, int cert_group) { switch (cipher_suite) { case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_sm4_gcm_sm3: return cert_group == TLS_curve_sm2p256v1; case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: return cert_group == TLS_curve_secp256r1; default: return 0; } } static int tls12_signature_scheme_match_cert_group(int sig_alg, int cert_group) { return tls_signature_scheme_group_oid(sig_alg) == tls_named_curve_oid(cert_group); } static int tls12_signature_scheme_match_cipher_suite(int sig_alg, int cipher_suite) { switch (sig_alg) { case TLS_sig_sm2sig_sm3: switch (cipher_suite) { case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_sm4_gcm_sm3: return 1; } break; case TLS_sig_ecdsa_secp256r1_sha256: if (cipher_suite == TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256) { return 1; } break; } return 0; } static int tls12_key_exchange_group_match_cipher_suite(int group, int cipher_suite) { switch (cipher_suite) { case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_sm4_gcm_sm3: return group == TLS_curve_sm2p256v1; case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: return group == TLS_curve_secp256r1; default: return 0; } } static int tls12_select_common_cipher_suites(const uint8_t *client_ciphers, size_t client_ciphers_len, const int *server_ciphers, size_t server_ciphers_cnt, int *common_ciphers, size_t *common_ciphers_cnt, size_t max_cnt) { size_t i; if (!client_ciphers || !client_ciphers_len || !server_ciphers || !server_ciphers_cnt || !common_ciphers || !common_ciphers_cnt || !max_cnt) { error_print(); return -1; } *common_ciphers_cnt = 0; for (i = 0; i < server_ciphers_cnt && *common_ciphers_cnt < max_cnt; i++) { const uint8_t *p = client_ciphers; size_t len = client_ciphers_len; while (len) { uint16_t cipher; if (tls_uint16_from_bytes(&cipher, &p, &len) != 1) { error_print(); return -1; } if (cipher == server_ciphers[i]) { common_ciphers[(*common_ciphers_cnt)++] = server_ciphers[i]; break; } } } return *common_ciphers_cnt ? 1 : 0; } // support_uncompressed static int tls_ec_point_formats_support_uncompressed(const uint8_t *ext_data, size_t ext_datalen) { const uint8_t *formats; size_t formats_len; int uncompressed = 0; if (tls_uint8array_from_bytes(&formats, &formats_len, &ext_data, &ext_datalen) != 1 || tls_length_is_zero(ext_datalen) != 1) { error_print(); return -1; } if (!formats_len) { error_print(); return -1; } while (formats_len) { uint8_t format; if (tls_uint8_from_bytes(&format, &formats, &formats_len) != 1) { error_print(); return -1; } if (!tls_ec_point_format_name(format)) { error_print(); return -1; } if (format == TLS_point_uncompressed) { uncompressed = 1; } } if (!uncompressed) { error_print(); return 0; } return 1; } static int tls12_cert_chain_get_end_entity_group(const uint8_t *cert_chain, size_t cert_chain_len, int *group) { const uint8_t *cert; size_t certlen; X509_KEY public_key; if (!cert_chain || !cert_chain_len || !group) { error_print(); return -1; } if (x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 0, &cert, &certlen) != 1 || x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) { error_print(); return -1; } if (public_key.algor != OID_ec_public_key) { error_print(); return -1; } if ((*group = tls_named_curve_from_oid(public_key.algor_param)) == 0) { error_print(); return -1; } return 1; } static int tls12_public_key_get_group(const X509_KEY *public_key, int *group) { if (!public_key || !group) { error_print(); return -1; } if (public_key->algor != OID_ec_public_key) { error_print(); return -1; } if ((*group = tls_named_curve_from_oid(public_key->algor_param)) == 0) { error_print(); return -1; } return 1; } static int tls12_select_key_exchange_group(const int *groups, size_t groups_cnt, int cipher_suite, int *selected_group) { size_t i; if (!groups || !groups_cnt || !selected_group) { error_print(); return -1; } for (i = 0; i < groups_cnt; i++) { if (tls12_key_exchange_group_match_cipher_suite(groups[i], cipher_suite)) { *selected_group = groups[i]; return 1; } } return 0; } // 这个函数的名字最好换一下 static int tls12_select_parameters(TLS_CONNECT *conn, const int *common_cipher_suites, size_t common_cipher_suites_cnt, const int *common_supported_groups, size_t common_supported_groups_cnt, const int *common_signature_algorithms, size_t common_signature_algorithms_cnt, const int *signature_algorithms_cert, size_t signature_algorithms_cert_cnt, const uint8_t *host_name, size_t host_name_len) { const uint8_t *cert_chains = conn->ctx->cert_chains; size_t cert_chains_len = conn->ctx->cert_chains_len; size_t cert_chain_idx; if (!conn || !common_cipher_suites || !common_cipher_suites_cnt || !common_supported_groups || !common_supported_groups_cnt || !common_signature_algorithms || !common_signature_algorithms_cnt) { error_print(); return -1; } if (!cert_chains || !cert_chains_len) { error_print(); return -1; } for (cert_chain_idx = 1; cert_chains_len; cert_chain_idx++) { const uint8_t *cert_chain; size_t cert_chain_len; const uint8_t *cert; size_t certlen; int cert_group; size_t i; int ret; if (tls_uint24array_from_bytes(&cert_chain, &cert_chain_len, &cert_chains, &cert_chains_len) != 1) { error_print(); return -1; } if (tls12_cert_chain_get_end_entity_group(cert_chain, cert_chain_len, &cert_group) != 1) { error_print(); return -1; } if (!tls_type_is_in_list(cert_group, common_supported_groups, common_supported_groups_cnt)) { continue; } if (x509_certs_get_cert_by_index(cert_chain, cert_chain_len, 0, &cert, &certlen) != 1) { error_print(); return -1; } if (host_name && host_name_len) { if ((ret = tls_cert_match_server_name(cert, certlen, host_name, host_name_len)) < 0) { error_print(); return -1; } else if (ret == 0) { continue; } } if (signature_algorithms_cert && signature_algorithms_cert_cnt) { if ((ret = tls_cert_chain_match_signature_algorithms_cert(cert_chain, cert_chain_len, signature_algorithms_cert, signature_algorithms_cert_cnt)) < 0) { error_print(); return -1; } else if (ret == 0) { continue; } } for (i = 0; i < common_cipher_suites_cnt; i++) { size_t j; int cipher_suite = common_cipher_suites[i]; int key_exchange_group; if (!tls12_cipher_suite_match_cert_group(cipher_suite, cert_group)) { continue; } if ((ret = tls12_select_key_exchange_group(common_supported_groups, common_supported_groups_cnt, cipher_suite, &key_exchange_group)) < 0) { error_print(); return -1; } else if (ret == 0) { continue; } for (j = 0; j < common_signature_algorithms_cnt; j++) { int sig_alg = common_signature_algorithms[j]; if (!tls12_signature_scheme_match_cert_group(sig_alg, cert_group)) { continue; } if (!tls12_signature_scheme_match_cipher_suite(sig_alg, cipher_suite)) { continue; } conn->cipher_suite = cipher_suite; conn->cert_chain = cert_chain; conn->cert_chain_len = cert_chain_len; conn->cert_chain_idx = cert_chain_idx; conn->sig_alg = sig_alg; conn->key_exchange_group = key_exchange_group; return 1; } } } warning_print(); return 0; } int tls_recv_client_hello(TLS_CONNECT *conn) { int ret; int client_verify = 0; int protocol; const uint8_t *client_random; const uint8_t *session_id; size_t session_id_len; const uint8_t *cipher_suites; size_t cipher_suites_len; const uint8_t *exts; size_t extslen; const uint8_t *ec_point_formats = NULL; size_t ec_point_formats_len = 0; const uint8_t *supported_groups = NULL; size_t supported_groups_len = 0; const uint8_t *signature_algorithms = NULL; size_t signature_algorithms_len = 0; const uint8_t *signature_algorithms_cert = NULL; size_t signature_algorithms_cert_len = 0; const uint8_t *server_name = NULL; size_t server_name_len = 0; int common_cipher_suites[TLS_MAX_CIPHER_SUITES_COUNT]; size_t common_cipher_suites_cnt = 0; int common_supported_groups[32]; size_t common_supported_groups_cnt = 0; int common_signature_algorithms[32]; size_t common_signature_algorithms_cnt = 0; int common_signature_algorithms_cert[32]; size_t common_signature_algorithms_cert_cnt = 0; const int *cert_signature_algorithms = NULL; size_t cert_signature_algorithms_cnt = 0; const uint8_t *host_name = NULL; size_t host_name_len = 0; /* if (client_verify) tls_client_verify_init(&conn->client_verify_ctx); */ tls_trace("recv ClientHello\n"); if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { error_print(); } return ret; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); if (tls_record_protocol(conn->record) != TLS_protocol_tls1) { error_print(); tls_send_alert(conn, TLS_alert_protocol_version); return -1; } if ((ret = tls_record_get_handshake_client_hello(conn->record, &protocol, &client_random, &session_id, &session_id_len, &cipher_suites, &cipher_suites_len, &exts, &extslen)) < 0) { error_print(); tls13_send_alert(conn, TLS_alert_decode_error); return -1; } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } if (protocol != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_protocol_version); return -1; } memcpy(conn->client_random, client_random, 32); while (extslen) { int ext_type; const uint8_t *ext_data; size_t ext_datalen; if (tls_ext_from_bytes(&ext_type, &ext_data, &ext_datalen, &exts, &extslen) != 1) { error_print(); tls13_send_alert(conn, TLS_alert_decode_error); return -1; } switch (ext_type) { case TLS_extension_ec_point_formats: case TLS_extension_supported_groups: case TLS_extension_signature_algorithms: case TLS_extension_signature_algorithms_cert: case TLS_extension_server_name: if (!ext_data) { error_print(); tls_send_alert(conn, TLS_alert_decode_error); return -1; } break; } switch (ext_type) { case TLS_extension_ec_point_formats: if (ec_point_formats) { error_print(); tls_send_alert(conn, TLS_alert_illegal_parameter); return -1; } ec_point_formats = ext_data; ec_point_formats_len = ext_datalen; break; case TLS_extension_supported_groups: if (supported_groups) { error_print(); tls_send_alert(conn, TLS_alert_illegal_parameter); return -1; } supported_groups = ext_data; supported_groups_len = ext_datalen; break; case TLS_extension_signature_algorithms: if (signature_algorithms) { error_print(); tls_send_alert(conn, TLS_alert_illegal_parameter); return -1; } signature_algorithms = ext_data; signature_algorithms_len = ext_datalen; break; case TLS_extension_signature_algorithms_cert: if (signature_algorithms_cert) { error_print(); tls_send_alert(conn, TLS_alert_illegal_parameter); return -1; } signature_algorithms_cert = ext_data; signature_algorithms_cert_len = ext_datalen; break; case TLS_extension_server_name: if (server_name) { error_print(); tls_send_alert(conn, TLS_alert_illegal_parameter); return -1; } server_name = ext_data; server_name_len = ext_datalen; break; default: warning_print(); } } if (ec_point_formats) { if ((ret = tls_ec_point_formats_support_uncompressed(ec_point_formats, ec_point_formats_len)) < 0) { error_print(); tls_send_alert(conn, TLS_alert_decode_error); return -1; } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_illegal_parameter); return -1; } conn->ec_point_formats = 1; } if ((ret = tls12_select_common_cipher_suites(cipher_suites, cipher_suites_len, conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt, common_cipher_suites, &common_cipher_suites_cnt, sizeof(common_cipher_suites)/sizeof(common_cipher_suites[0]))) < 0) { error_print(); tls_send_alert(conn, TLS_alert_decode_error); return -1; } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_handshake_failure); return -1; } if (supported_groups) { if ((ret = tls_process_supported_groups(supported_groups, supported_groups_len, conn->ctx->supported_groups, conn->ctx->supported_groups_cnt, common_supported_groups, &common_supported_groups_cnt, sizeof(common_supported_groups)/sizeof(common_supported_groups[0]))) < 0) { error_print(); tls_send_alert(conn, TLS_alert_decode_error); return -1; } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_handshake_failure); return -1; } } else { if (!conn->ctx->supported_groups_cnt) { error_print(); tls_send_alert(conn, TLS_alert_handshake_failure); return -1; } memcpy(common_supported_groups, conn->ctx->supported_groups, conn->ctx->supported_groups_cnt * sizeof(conn->ctx->supported_groups[0])); common_supported_groups_cnt = conn->ctx->supported_groups_cnt; } if (signature_algorithms) { if ((ret = tls_process_signature_algorithms(signature_algorithms, signature_algorithms_len, conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt, common_signature_algorithms, &common_signature_algorithms_cnt, sizeof(common_signature_algorithms)/sizeof(common_signature_algorithms[0]))) < 0) { error_print(); tls_send_alert(conn, TLS_alert_decode_error); return -1; } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_handshake_failure); return -1; } } else { if (!conn->ctx->signature_algorithms_cnt) { error_print(); tls13_send_alert(conn, TLS_alert_handshake_failure); return -1; } memcpy(common_signature_algorithms, conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt * sizeof(conn->ctx->signature_algorithms[0])); common_signature_algorithms_cnt = conn->ctx->signature_algorithms_cnt; } if (signature_algorithms_cert) { if ((ret = tls_process_signature_algorithms(signature_algorithms_cert, signature_algorithms_cert_len, conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt, common_signature_algorithms_cert, &common_signature_algorithms_cert_cnt, sizeof(common_signature_algorithms_cert)/sizeof(common_signature_algorithms_cert[0]))) < 0) { error_print(); tls_send_alert(conn, TLS_alert_decode_error); return -1; } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_handshake_failure); return -1; } cert_signature_algorithms = common_signature_algorithms_cert; cert_signature_algorithms_cnt = common_signature_algorithms_cert_cnt; } else if (signature_algorithms) { cert_signature_algorithms = common_signature_algorithms; cert_signature_algorithms_cnt = common_signature_algorithms_cnt; } if (server_name) { if (tls_server_name_from_bytes(&host_name, &host_name_len, server_name, server_name_len) != 1) { error_print(); tls13_send_alert(conn, TLS_alert_decode_error); return -1; } conn->server_name = 1; } if ((ret = tls12_select_parameters(conn, common_cipher_suites, common_cipher_suites_cnt, common_supported_groups, common_supported_groups_cnt, common_signature_algorithms, common_signature_algorithms_cnt, cert_signature_algorithms, cert_signature_algorithms_cnt, host_name, host_name_len)) < 0) { error_print(); tls_send_alert(conn, TLS_alert_decode_error); return -1; } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_handshake_failure); return -1; } if (tls12_cipher_suite_get(conn->cipher_suite, &conn->cipher, &conn->digest) != 1) { error_print(); tls13_send_alert(conn, TLS_alert_internal_error); return -1; } if (digest_init(&conn->dgst_ctx, conn->digest) != 1) { error_print(); return -1; } if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { error_print(); return -1; } tls_handshake_digest_print(stderr, 0, 0, "ClientHello", &conn->dgst_ctx); /* if (client_verify) tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); */ fprintf(stderr, "end of recv_client_hello\n"); tls_clean_record(conn); return 1; } int tls_send_server_hello(TLS_CONNECT *conn) { int ret; tls_trace("send ServerHello\n"); if (conn->recordlen == 0) { uint8_t exts[512]; uint8_t *pexts = exts; size_t extslen = 0; tls_record_set_protocol(conn->record, conn->protocol); if (tls_random_generate(conn->server_random) != 1) { error_print(); return -1; } // extensions in ServerHello // ec_point_formats if (conn->ec_point_formats) { if (tls_ec_point_formats_ext_to_bytes(ec_point_formats, ec_point_formats_cnt, &pexts, &extslen) != 1) { error_print(); return -1; } } if (tls_record_set_handshake_server_hello(conn->record, &conn->recordlen, conn->protocol, conn->server_random, NULL, 0, conn->cipher_suite, exts, extslen) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { error_print(); return -1; } tls_handshake_digest_print(stderr, 0, 0, "ServerHello", &conn->dgst_ctx); } if ((ret = tls_send_record(conn)) != 1) { if (ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; } if (conn->ctx->cacertslen) { tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); } tls_clean_record(conn); return 1; } int tls_recv_server_hello(TLS_CONNECT *conn) { int ret; int protocol; int cipher_suite; const uint8_t *server_random; const uint8_t *session_id; size_t session_id_len; const uint8_t *exts; size_t extslen; const uint8_t *ec_point_formats = NULL; size_t ec_point_formats_len = 0; tls_trace("recv ServerHello\n"); if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { error_print(); } return ret; } tls12_record_print(stderr, conn->record, conn->recordlen, 0, 0); if (tls_record_protocol(conn->record) != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_protocol_version); return -1; } if ((ret = tls_record_get_handshake_server_hello(conn->record, &protocol, &server_random, &session_id, &session_id_len, &cipher_suite, &exts, &extslen)) < 0) { error_print(); tls_send_alert(conn, TLS_alert_decode_error); return -1; } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } // version if (protocol != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_protocol_version); return -1; } // random memcpy(conn->server_random, server_random, 32); // session_id memcpy(conn->session_id, session_id, session_id_len); // cipher_suite if (tls_type_is_in_list(cipher_suite, conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt) != 1) { error_print(); tls_send_alert(conn, TLS_alert_handshake_failure); return -1; } conn->cipher_suite = cipher_suite; if (tls12_cipher_suite_get(conn->cipher_suite, &conn->cipher, &conn->digest) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } if (digest_init(&conn->dgst_ctx, conn->digest) != 1) { error_print(); return -1; } while (extslen) { int ext_type; const uint8_t *ext_data; size_t ext_datalen; if (tls_ext_from_bytes(&ext_type, &ext_data, &ext_datalen, &exts, &extslen) != 1) { error_print(); tls_send_alert(conn, TLS_alert_decode_error); return -1; } switch (ext_type) { case TLS_extension_ec_point_formats: if (ec_point_formats) { error_print(); tls_send_alert(conn, TLS_alert_illegal_parameter); return -1; } ec_point_formats = ext_data; ec_point_formats_len = ext_datalen; break; default: error_print(); tls_send_alert(conn, TLS_alert_illegal_parameter); return -1; } } if (ec_point_formats) { if ((ret = tls_ec_point_formats_support_uncompressed(ec_point_formats, ec_point_formats_len)) < 0) { error_print(); tls_send_alert(conn, TLS_alert_decode_error); return -1; } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_illegal_parameter); return -1; } } if (digest_update(&conn->dgst_ctx, conn->plain_record + 5, conn->plain_recordlen - 5) != 1) { error_print(); return -1; } tls_handshake_digest_print(stderr, 0, 0, "ClientHello", &conn->dgst_ctx); if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { error_print(); return -1; } tls_handshake_digest_print(stderr, 0, 0, "ServerHello", &conn->dgst_ctx); //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); if (conn->client_certs_len) { sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); } return 1; } // TLS12 发送的是常规的证书链 // TLCP SM2 发送的是SM2的双证书链,但是在数据格式上没有区别 // TLCP SM9 发送的是服务器的ID和SM9公开参数(这个格式是不同的),但是存储上可能也是一样的 // 我不确定SM2和SM9的格式是否是相容的 int tls_send_server_certificate(TLS_CONNECT *conn) { int ret; tls_trace("send ServerCertificate\n"); if (conn->recordlen == 0) { if (tls_record_set_handshake_certificate(conn->record, &conn->recordlen, conn->cert_chain, conn->cert_chain_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { error_print(); return -1; } tls_handshake_digest_print(stderr, 0, 0, "Certificate", &conn->dgst_ctx); } if ((ret = tls_send_record(conn)) != 1) { if (ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; } if (conn->client_certificate_verify) { tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); } return 1; } int tls_recv_server_certificate(TLS_CONNECT *conn) { int ret; int verify_result = 0; const uint8_t *server_cert; size_t server_cert_len; X509_KEY server_sign_key; int server_sig_alg = 0; int server_group; int cert_sig_alg = 0; const int *signature_algorithms_cert = NULL; size_t signature_algorithms_cert_cnt = 0; tls_trace("recv server Certificate\n"); if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { error_print(); } return ret; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); if (tls_record_protocol(conn->record) != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } if ((ret = tls_record_get_handshake_certificate(conn->record, conn->peer_cert_chain, &conn->peer_cert_chain_len)) < 0) { error_print(); tls_send_alert(conn, TLS_alert_decode_error); return -1; } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return 0; } if (!conn->peer_cert_chain_len) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { error_print(); return -1; } tls_handshake_digest_print(stderr, 0, 0, "Certificate", &conn->dgst_ctx); // server_sign_key if (x509_certs_get_cert_by_index(conn->peer_cert_chain, conn->peer_cert_chain_len, 0, &server_cert, &server_cert_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } if (x509_cert_get_subject_public_key(server_cert, server_cert_len, &server_sign_key) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } if (tls12_public_key_get_group(&server_sign_key, &server_group) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } // check server certificate matches negotiated cipher_suite if (!tls12_cipher_suite_match_cert_group(conn->cipher_suite, server_group)) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } switch (conn->cipher_suite) { case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_sm4_gcm_sm3: case TLS_cipher_ecc_sm4_cbc_sm3: case TLS_cipher_ecc_sm4_gcm_sm3: server_sig_alg = TLS_sig_sm2sig_sm3; break; case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: server_sig_alg = TLS_sig_ecdsa_secp256r1_sha256; break; default: error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } // check server certificate matches ClientHello.supported_groups if (conn->ctx->supported_groups_cnt) { if (!tls_type_is_in_list(server_group, conn->ctx->supported_groups, conn->ctx->supported_groups_cnt)) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } } // check server certificate matches ClientHello.signature_algorithms if (conn->ctx->signature_algorithms_cnt) { if ((ret = tls_cert_match_signature_algorithms(server_cert, server_cert_len, conn->ctx->signature_algorithms, conn->ctx->signature_algorithms_cnt, &cert_sig_alg)) < 0) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } if (!tls12_signature_scheme_match_cert_group(cert_sig_alg, server_group) || !tls12_signature_scheme_match_cipher_suite(cert_sig_alg, conn->cipher_suite)) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } server_sig_alg = cert_sig_alg; } // check certificate-chain signatures match ClientHello.signature_algorithms_cert if (conn->signature_algorithms_cert) { signature_algorithms_cert = conn->ctx->signature_algorithms; signature_algorithms_cert_cnt = conn->ctx->signature_algorithms_cnt; } else if (conn->ctx->signature_algorithms_cnt) { signature_algorithms_cert = conn->ctx->signature_algorithms; signature_algorithms_cert_cnt = conn->ctx->signature_algorithms_cnt; } if (signature_algorithms_cert && signature_algorithms_cert_cnt) { if ((ret = tls_cert_chain_match_signature_algorithms_cert( conn->peer_cert_chain, conn->peer_cert_chain_len, signature_algorithms_cert, signature_algorithms_cert_cnt)) < 0) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } } // check server certificate matches ClientHello.server_name if (conn->server_name) { if ((ret = tls_cert_match_server_name(server_cert, server_cert_len, conn->host_name, conn->host_name_len)) < 0) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } } conn->signature_algorithms[0] = server_sig_alg; conn->signature_algorithms_cnt = 1; if (conn->client_certs_len) { sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); } assert(conn->ctx->verify_depth > 0 && conn->ctx->verify_depth < 10); // verify server Certificate if (conn->ctx->cacertslen) { if (x509_certs_verify(conn->peer_cert_chain, conn->peer_cert_chain_len, X509_cert_chain_server, conn->ctx->cacerts, conn->ctx->cacertslen, conn->ctx->verify_depth, &verify_result) != 1) { error_print(); conn->verify_result = verify_result; tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } } conn->verify_result = verify_result; return 1; } int tls_send_server_key_exchange(TLS_CONNECT *conn) { int ret; uint8_t server_ecdh_params[69]; uint8_t *p = server_ecdh_params + 4; size_t len = 0; X509_SIGN_CTX sign_ctx; const void *sign_args = NULL; size_t sign_argslen = 0; uint8_t sig[X509_SIGNATURE_MAX_SIZE]; size_t siglen; tls_trace("send ServerKeyExchange\n"); if (conn->recordlen == 0) { int curve_oid = tls_named_curve_oid(conn->key_exchange_group); // generate server ecdh_key if (x509_key_generate(&conn->key_exchanges[0], OID_ec_public_key, &curve_oid, sizeof(curve_oid)) != 1) { error_print(); return -1; } // build server_ecdh_params server_ecdh_params[0] = TLS_curve_type_named_curve; server_ecdh_params[1] = conn->key_exchange_group >> 8; server_ecdh_params[2] = (uint8_t)conn->key_exchange_group; server_ecdh_params[3] = 65; if (x509_public_key_to_bytes(&conn->key_exchanges[0], &p, &len) != 1) { error_print(); return -1; } if (len != 65) { error_print(); return -1; } X509_KEY *sign_key = &conn->ctx->x509_keys[conn->cert_chain_idx - 1]; // sign server_ecdh_params if (sign_key->algor == OID_ec_public_key && sign_key->algor_param == OID_sm2) { sign_args = SM2_DEFAULT_ID; sign_argslen = SM2_DEFAULT_ID_LENGTH; } if (x509_sign_init(&sign_ctx, sign_key, sign_args, sign_argslen) != 1 || x509_sign_update(&sign_ctx, conn->client_random, 32) != 1 || x509_sign_update(&sign_ctx, conn->server_random, 32) != 1 || x509_sign_update(&sign_ctx, server_ecdh_params, 69) != 1 || x509_sign_finish(&sign_ctx, sig, &siglen) != 1) { x509_sign_ctx_cleanup(&sign_ctx); error_print(); return -1; } x509_sign_ctx_cleanup(&sign_ctx); if (tls_record_set_handshake_server_key_exchange(conn->record, &conn->recordlen, server_ecdh_params, sizeof(server_ecdh_params), conn->sig_alg, sig, siglen) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { error_print(); return -1; } tls_handshake_digest_print(stderr, 0, 0, "ServerKeyExchange", &conn->dgst_ctx); } if ((ret = tls_send_record(conn)) != 1) { if (ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; } //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); if (conn->client_certificate_verify) { tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); } return 1; } // match the ecdhe of cipher_suite int tls_curve_match_cipher_suite(int named_curve, int cipher_suite) { switch (named_curve) { case TLS_curve_sm2p256v1: switch (cipher_suite) { case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_sm4_gcm_sm3: break; default: error_print(); return -1; } break; case TLS_curve_secp256r1: if (cipher_suite != TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256) { error_print(); return -1; } break; default: error_print(); return -1; } return 1; } int tls_signature_scheme_match_cipher_suite(int sig_alg, int cipher_suite) { switch (sig_alg) { case TLS_sig_sm2sig_sm3: switch (cipher_suite) { case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_sm4_gcm_sm3: case TLS_cipher_ecc_sm4_cbc_sm3: case TLS_cipher_ecc_sm4_gcm_sm3: break; default: error_print(); return -1; } break; case TLS_sig_ecdsa_secp256r1_sha256: switch (cipher_suite) { case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: break; default: error_print(); return -1; } break; default: error_print(); return -1; } return 1; } int tls_recv_server_key_exchange(TLS_CONNECT *conn) { int ret; uint8_t curve_type; uint16_t named_curve; const uint8_t *point_octets; size_t point_octets_len; const uint8_t *server_ecdh_params; size_t server_ecdh_params_len; uint16_t sig_alg; const uint8_t *sig; size_t siglen; // verify ServerKeyExchange X509_KEY server_sign_key; int server_cert_index = 0; const uint8_t *server_cert; size_t server_cert_len; X509_SIGN_CTX sign_ctx; const void *sign_args = NULL; size_t sign_argslen = 0; tls_trace("recv ServerKeyExchange\n"); if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { error_print(); } return ret; } if (tls_record_protocol(conn->record) != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); // 这个函数可能是有问题的,如果cipher_suite不同,ServerKeyExchange可能也是不同的 if ((ret = tls_record_get_handshake_server_key_exchange(conn->record, &curve_type, &named_curve, &point_octets, &point_octets_len, &server_ecdh_params, &server_ecdh_params_len, &sig_alg, &sig, &siglen)) < 0) { error_print(); tls_send_alert(conn, TLS_alert_decode_error); return -1; } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return 0; } if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { error_print(); return -1; } tls_handshake_digest_print(stderr, 0, 0, "ServerKeyExchange", &conn->dgst_ctx); if (curve_type != TLS_curve_type_named_curve) { error_print(); return -1; } // named_curve应该在supported_groups里面 //conn->ecdh_named_curve = named_curve; conn->key_exchange_group = named_curve; memcpy(conn->peer_key_exchange, point_octets, point_octets_len); conn->peer_key_exchange_len = point_octets_len; if (point_octets_len != 65) { error_print(); return -1; } if (tls_curve_match_cipher_suite(named_curve, conn->cipher_suite) != 1) { error_print(); return -1; } if (point_octets_len != 65) { error_print(); return -1; } if (tls_signature_scheme_match_cipher_suite(sig_alg, conn->cipher_suite) != 1) { error_print(); return -1; } // 解析server_key_exchange, curve_type, curve_name, point 这三个信息 // 判断curve_type == named_curve // 判断curve_name在supported_groups中并记录这个信息 // 验证point确实在curve_name的group中 //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); if (conn->client_certs_len) sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); if (x509_certs_get_cert_by_index(conn->peer_cert_chain, conn->peer_cert_chain_len, server_cert_index, &server_cert, &server_cert_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } if (x509_cert_get_subject_public_key(server_cert, server_cert_len, &server_sign_key) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } // 这个检查是否是多余的? // 这个值是签名算法和椭圆曲线名字的结合 // cipher_suite只能决定签名算法类型 // 公钥证书里面的公钥实际上只包含曲线的类型(而不决定签名算法,因为一个椭圆曲线本质上支持多种不同的签名算法) switch (sig_alg) { case TLS_sig_sm2sig_sm3: if (server_sign_key.algor != OID_ec_public_key || server_sign_key.algor_param != OID_sm2) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } break; case TLS_sig_ecdsa_secp256r1_sha256: if (server_sign_key.algor != OID_ec_public_key || server_sign_key.algor_param != OID_secp256r1) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } break; default: error_print(); return -1; } if (server_sign_key.algor == OID_ec_public_key && server_sign_key.algor_param == OID_sm2) { sign_args = SM2_DEFAULT_ID; sign_argslen = SM2_DEFAULT_ID_LENGTH; } // 这里应该是SM2的签名和验证 if (x509_verify_init(&sign_ctx, &server_sign_key, sign_args, sign_argslen, sig, siglen) != 1 || x509_verify_update(&sign_ctx, conn->client_random, 32) != 1 || x509_verify_update(&sign_ctx, conn->server_random, 32) != 1 || x509_verify_update(&sign_ctx, server_ecdh_params, 69) != 1 || x509_verify_finish(&sign_ctx) != 1) { error_print(); return -1; } fprintf(stderr, ">>>>>> ServerKeyExchange verify success\n"); // xxxx // 这里的签名错了,肯定是sign_ctx就是不对的,因此是不可能正确的 // 现在要做的是,必须确定server_key_exchange中都包括了哪些被签名的消息 return 1; } int tls_send_certificate_request(TLS_CONNECT *conn) { int ret; // 如果要进行客户端证书验证,服务器要提供验证的证书,但是所有证书的 const uint8_t cert_types[] = { TLS_cert_type_ecdsa_sign }; uint8_t ca_names[TLS_MAX_CA_NAMES_SIZE] = {0}; // TODO: 根据客户端验证CA证书列计算缓冲大小,或直接输出到record缓冲 size_t ca_names_len = 0; if (!conn->client_certificate_verify) { error_print(); return -1; } if (conn->recordlen == 0) { tls_trace("send CertificateRequest\n"); if (tls_authorities_from_certs(ca_names, &ca_names_len, sizeof(ca_names), conn->ctx->cacerts, conn->ctx->cacertslen) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } if (tls_record_set_handshake_certificate_request(conn->record, &conn->recordlen, cert_types, sizeof(cert_types), ca_names, ca_names_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); } if ((ret = tls_send_record(conn)) != 1) { if (ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; } //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); return 1; } int tls_recv_certificate_request(TLS_CONNECT *conn) { int ret; uint8_t *record = conn->record; const uint8_t *cp; size_t len; int handshake_type; const uint8_t *cert_types; size_t cert_types_len; const uint8_t *ca_names; size_t ca_names_len; tls_trace("recv CertificateRequest*\n"); if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { error_print(); } return ret; } if (tls_record_protocol(conn->record) != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } if (tls_record_get_handshake(record, &handshake_type, &cp, &len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } if (handshake_type != TLS_handshake_certificate_request) { tls_trace(" no CertificateRequest\n"); return 0; // 表明对方没有发送预期的报文 } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); if (tls_record_get_handshake_certificate_request(conn->record, &cert_types, &cert_types_len, &ca_names, &ca_names_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { error_print(); return -1; } // 这里要检查一下服务器发送的,和本地的是否保持一致 /* if(!conn->client_certs_len) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } if (tls_cert_types_accepted(cert_types, cert_types_len, conn->client_certs, conn->client_certs_len) != 1 || tls_authorities_issued_certificate(ca_names, ca_names_len, conn->client_certs, conn->client_certs_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_unsupported_certificate); return -1; } */ sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); conn->recordlen = 0; return 1; } int tls_send_server_hello_done(TLS_CONNECT *conn) { int ret; tls_trace("send ServerHelloDone\n"); if (conn->recordlen == 0) { tls_record_set_handshake_server_hello_done(conn->record, &conn->recordlen); tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { error_print(); return -1; } tls_handshake_digest_print(stderr, 0, 0, "ServerHelloDone", &conn->dgst_ctx); } if ((ret = tls_send_record(conn)) != 1) { if (ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; } if (conn->client_certs_len) { tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); } return 1; } // 这是一个非常特殊的状态,其他的所有recv状态都是要读取的 // 但是这个状态在大多数情况下,之前已经读取完了,但是我们无法判断这个信息 int tls_recv_server_hello_done(TLS_CONNECT *conn) { int ret; tls_trace("recv ServerHelloDone\n"); if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { error_print(); } return ret; } if (tls_record_protocol(conn->record) != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); if (tls_record_get_handshake_server_hello_done(conn->record) != 1) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { error_print(); return -1; } tls_handshake_digest_print(stderr, 0, 0, "ServerHelloDone", &conn->dgst_ctx); if (conn->client_certs_len) sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); return 1; } int tls_send_client_certificate(TLS_CONNECT *conn) { int ret; tls_trace("send ClientCertificate\n"); if (conn->client_certs_len == 0) { error_print(); return -1; } if (conn->recordlen == 0) { if (tls_record_set_handshake_certificate(conn->record, &conn->recordlen, conn->client_certs, conn->client_certs_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); } if ((ret = tls_send_record(conn)) != 1) { if (ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; } //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); return 1; } // 只有在需要验证客户端证书的时候这个函数才执行,是否内部要判断一下 int tls_recv_client_certificate(TLS_CONNECT *conn) { int ret; const int verify_depth = 5; int verify_result; tls_trace("recv ClientCertificate\n"); if (conn->ctx->cacertslen == 0) { error_print(); return -1; } if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { error_print(); } return ret; } if (tls_record_protocol(conn->record) != conn->protocol) { // protocol检查应该在trace之后 error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); if (tls_record_get_handshake_certificate(conn->record, conn->client_certs, &conn->client_certs_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } if (x509_certs_verify(conn->client_certs, conn->client_certs_len, X509_cert_chain_client, conn->ctx->cacerts, conn->ctx->cacertslen, verify_depth, &verify_result) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); return 1; } int tls_generate_keys(TLS_CONNECT *conn) { uint8_t pre_master_secret[32]; size_t pre_master_secret_len; // 这里密钥是完全用ECDHE生成的 if (x509_key_exchange(&conn->key_exchanges[0], conn->peer_key_exchange, conn->peer_key_exchange_len, pre_master_secret, &pre_master_secret_len) != 1) { error_print(); return -1; } if (pre_master_secret_len != sizeof(pre_master_secret)) { error_print(); return -1; } format_bytes(stderr, 0, 0, "pre_master_secret", pre_master_secret, pre_master_secret_len); // master_secret和transcript_hash没有任何关系 if (tls12_prf(conn->digest, pre_master_secret, 32, "master secret", conn->client_random, 32, conn->server_random, 32, 48, conn->master_secret) != 1) { error_print(); return -1; } format_bytes(stderr, 0, 0, "master_secret", conn->master_secret, 48); // OpenSSL tls1_prf 中,这里生成的是128字节,也就是把IV也生成了 // 为什么生成IV呢? if (tls12_prf(conn->digest, conn->master_secret, 48, "key expansion", conn->server_random, 32, conn->client_random, 32, 96, conn->key_block) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } /* 如果这里导出了IV,并且用这个IV去加密数据 被加密的数据中包含了一个随机的IV,那么这个随机的IV是干什么用的呢? */ format_bytes(stderr, 0, 0, "key_blocks", conn->key_block, 96); if (hmac_init(&conn->client_write_mac_ctx, conn->digest, conn->key_block, 32) != 1) { error_print(); return -1; } if (hmac_init(&conn->server_write_mac_ctx, conn->digest, conn->key_block + 32, 32) != 1) { error_print(); return -1; } format_bytes(stderr, 0, 0, "client_write_mac_key", conn->key_block, 32); format_bytes(stderr, 0, 0, "server_write_mac_key", conn->key_block + 32, 32); format_bytes(stderr, 0, 0, "client_write_key", conn->key_block + 64, 16); format_bytes(stderr, 0, 0, "server_write_key", conn->key_block + 80, 16); if (conn->is_client) { block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, conn->key_block + 64); block_cipher_set_decrypt_key(&conn->server_write_key, conn->cipher, conn->key_block + 80); } else { block_cipher_set_decrypt_key(&conn->client_write_key, conn->cipher, conn->key_block + 64); block_cipher_set_encrypt_key(&conn->server_write_key, conn->cipher, conn->key_block + 80); } tls_seq_num_reset(conn->client_seq_num); tls_seq_num_reset(conn->server_seq_num); /* tls_secrets_print(stderr, pre_master_secret, 32, conn->client_random, conn->server_random, conn->master_secret, conn->key_block, 96, 0, 4); */ return 1; } int tls_send_client_key_exchange(TLS_CONNECT *conn) { int ret; // 客户端的ECDHE的公钥肯定和服务器是保持一致的 // 因此在接收到服务器的公钥之后,应该保存这个信息 // 客户端是怎么确定密钥交换的group的?大概是从ServerKeyExchange中确定的 if (conn->recordlen == 0) { uint8_t point_octets[65]; uint8_t *p = point_octets; size_t len = 0; int curve_oid = tls_named_curve_oid(conn->key_exchange_group); if (x509_key_generate(&conn->key_exchanges[0], OID_ec_public_key, &curve_oid, sizeof(curve_oid)) != 1) { error_print(); return -1; } if (x509_public_key_to_bytes(&conn->key_exchanges[0], &p, &len) != 1) { error_print(); return -1; } if (len != sizeof(point_octets)) { error_print(); return -1; } tls_trace("send ClientKeyExchange\n"); if (tls_record_set_handshake_client_key_exchange(conn->record, &conn->recordlen, point_octets, len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { error_print(); return -1; } tls_handshake_digest_print(stderr, 0, 0, "ClientKeyExchange", &conn->dgst_ctx); } if ((ret = tls_send_record(conn)) != 1) { if (ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; } if (conn->client_certs_len) sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); return 1; } int tls_recv_client_key_exchange(TLS_CONNECT *conn) { int ret; const uint8_t *point_octets; size_t point_octets_len; tls_trace("recv ClientKeyExchange\n"); if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { error_print(); } return ret; } if (tls_record_protocol(conn->record) != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); if (tls_record_get_handshake_client_key_exchange(conn->record, &point_octets, &point_octets_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } if (point_octets_len != 65) { error_print(); return -1; } memcpy(conn->peer_key_exchange, point_octets, point_octets_len); conn->peer_key_exchange_len = point_octets_len; if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { error_print(); return -1; } tls_handshake_digest_print(stderr, 0, 0, "ClientKeyExchange", &conn->dgst_ctx); if (conn->ctx->cacertslen) tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); return 1; } int tls_send_certificate_verify(TLS_CONNECT *conn) { int ret; uint8_t sig[SM2_MAX_SIGNATURE_SIZE]; size_t siglen; tls_trace("send CertificateVerify\n"); if (!conn->client_certificate_verify) { error_print(); return -1; } if (conn->recordlen == 0) { if (sm2_sign_finish(&conn->sign_ctx, sig, &siglen) != 1) { error_print(); return -1; } if (tls_record_set_handshake_certificate_verify(conn->record, &conn->recordlen, sig, siglen) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); } if ((ret = tls_send_record(conn)) != 1) { if (ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; } //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); return 1; } int tls_recv_certificate_verify(TLS_CONNECT *conn) { int ret; X509_KEY client_sign_key; const uint8_t *sig; size_t siglen; const uint8_t *client_cert; size_t client_cert_len; if (!conn->client_certificate_verify) { error_print(); return -1; } tls_trace("recv CertificateVerify\n"); if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { error_print(); } return ret; } if (tls_record_protocol(conn->record) != conn->protocol) { tls_send_alert(conn, TLS_alert_unexpected_message); error_print(); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); // get signature from certificate_verify if (tls_record_get_handshake_certificate_verify(conn->record, &sig, &siglen) != 1) { tls_send_alert(conn, TLS_alert_unexpected_message); error_print(); return -1; } // get sign_key from client certificate if (x509_certs_get_cert_by_index(conn->client_certs, conn->client_certs_len, 0, &client_cert, &client_cert_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } if (x509_cert_get_subject_public_key(client_cert, client_cert_len, &client_sign_key) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } // 这里是否要验证证书的类型呢?我们现在还不支持其他签名算法 if (client_sign_key.algor != OID_ec_public_key || client_sign_key.algor_param != OID_sm2) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } if (tls_client_verify_finish(&conn->client_verify_ctx, sig, siglen, &client_sign_key.u.sm2_key) != 1) { error_print(); tls_send_alert(conn, TLS_alert_decrypt_error); return -1; } //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); return 1; } int tls_send_change_cipher_spec(TLS_CONNECT *conn) { int ret; if (conn->recordlen == 0) { tls_trace("send [ChangeCipherSpec]\n"); if (tls_record_set_change_cipher_spec(conn->record, &conn->recordlen) !=1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); } if ((ret = tls_send_record(conn)) != 1) { if (ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; } return 1; } int tls_recv_change_cipher_spec(TLS_CONNECT *conn) { int ret; tls_trace("recv [ChangeCipherSpec]\n"); if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { error_print(); } return ret; } if (tls_record_protocol(conn->record) != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); if (tls_record_get_change_cipher_spec(conn->record) != 1) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } return 1; } int tls_send_client_finished(TLS_CONNECT *conn) { int ret; if (conn->recordlen == 0) { tls_trace("send client {Finished}\n"); uint8_t local_verify_data[12]; DIGEST_CTX tmp_ctx; uint8_t dgst[32]; size_t dgstlen; tmp_ctx = conn->dgst_ctx; digest_finish(&tmp_ctx, dgst, &dgstlen); if (tls12_prf(conn->digest, conn->master_secret, 48, "client finished", dgst, dgstlen, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } tls_record_set_protocol(conn->plain_record, conn->protocol); if (tls_record_set_handshake_finished(conn->plain_record, &conn->plain_recordlen, local_verify_data, sizeof(local_verify_data)) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } tls12_record_trace(stderr, conn->plain_record, conn->plain_recordlen, 0, 0); if (digest_update(&conn->dgst_ctx, conn->plain_record + 5, conn->plain_recordlen - 5) != 1) { error_print(); return -1; } tls_handshake_digest_print(stderr, 0, 0, "Finished", &conn->dgst_ctx); if (tls12_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_key, conn->client_seq_num, conn->plain_record, conn->plain_recordlen, conn->record, &conn->recordlen) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } tls_seq_num_incr(conn->client_seq_num); format_bytes(stderr, 0, 0, "encrypted finsished ..... ", conn->record, conn->recordlen); } if ((ret = tls_send_record(conn)) != 1) { if (ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; } return 1; } int tls_recv_client_finished(TLS_CONNECT *conn) { int ret; const uint8_t *verify_data; size_t verify_data_len; uint8_t local_verify_data[12]; DIGEST_CTX tmp_ctx; uint8_t dgst[32]; size_t dgstlen; tmp_ctx = conn->dgst_ctx; if (digest_finish(&tmp_ctx, dgst, &dgstlen) != 1) { error_print(); return -1; } if (tls12_prf(conn->digest, conn->master_secret, 48, "client finished", dgst, dgstlen, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } format_bytes(stderr, 0, 0, "verify_data", local_verify_data, 12); // recv ClientFinished tls_trace("recv client {Finished}\n"); if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { error_print(); } return ret; } //tls12_record_print(stderr, conn->record, conn->recordlen, 0, 0); format_bytes(stderr, 0, 0, "Finished", conn->record, conn->recordlen); if (tls_record_protocol(conn->record) != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } // decrypt ClientFinished tls_trace(">>>>>>>decrypt Finished\n"); format_bytes(stderr, 0, 0, "client_seq_num", conn->client_seq_num, 8); if (tls12_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_key, conn->client_seq_num, conn->record, conn->recordlen, conn->plain_record, &conn->plain_recordlen) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_record_mac); return -1; } tls_seq_num_incr(conn->client_seq_num); tls12_record_trace(stderr, conn->plain_record, conn->plain_recordlen, 0, 0); if (tls_record_get_handshake_finished(conn->plain_record, &verify_data, &verify_data_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_record_mac); return -1; } if (verify_data_len != sizeof(local_verify_data)) { error_print(); tls_send_alert(conn, TLS_alert_bad_record_mac); return -1; } if (digest_update(&conn->dgst_ctx, conn->plain_record + 5, conn->plain_recordlen - 5) != 1) { error_print(); return -1; } tls_handshake_digest_print(stderr, 0, 0, "client Finished", &conn->dgst_ctx); // verify ClientFinished if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) { error_puts("client_finished.verify_data verification failure"); tls_send_alert(conn, TLS_alert_decrypt_error); return -1; } return 1; } int tls_send_server_finished(TLS_CONNECT *conn) { int ret; uint8_t *record = conn->record; size_t recordlen; uint8_t local_verify_data[12]; tls_record_set_protocol(conn->plain_record, conn->protocol); if (conn->recordlen == 0) { tls_trace("send server Finished\n"); uint8_t dgst[32]; size_t dgstlen; digest_finish(&conn->dgst_ctx, dgst, &dgstlen); if (tls12_prf(conn->digest, conn->master_secret, 48, "server finished", dgst, dgstlen, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) { error_print(); return -1; } format_bytes(stderr, 0, 0, "server verify_data", local_verify_data, 12); if (tls_record_set_handshake_finished(conn->plain_record, &conn->plain_recordlen, local_verify_data, sizeof(local_verify_data)) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } tls12_record_trace(stderr, conn->plain_record, conn->plain_recordlen, 0, 0); if (tls12_record_encrypt(&conn->server_write_mac_ctx, &conn->server_write_key, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, conn->record, &conn->recordlen) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } tls_seq_num_incr(conn->server_seq_num); } if ((ret = tls_send_record(conn)) != 1) { if (ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; } return 1; } int tls_recv_server_finished(TLS_CONNECT *conn) { int ret; uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE]; size_t finished_record_len; uint8_t dgst[32]; size_t dgstlen; const uint8_t *verify_data; size_t verify_data_len; uint8_t local_verify_data[12]; if (digest_finish(&conn->dgst_ctx, dgst, &dgstlen) != 1) { error_print(); return -1; } if (tls12_prf(conn->digest, conn->master_secret, 48, "server finished", dgst, dgstlen, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } format_bytes(stderr, 0, 0, ">>> verify_data", local_verify_data, 12); // Finished tls_trace("recv server Finished\n"); if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { error_print(); } return ret; } if (tls_record_protocol(conn->record) != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } tls_trace("decrypt Finished\n"); format_bytes(stderr, 0, 0, "server_seq_num", conn->server_seq_num, 8); if (tls12_record_decrypt(&conn->server_write_mac_ctx, &conn->server_write_key, conn->server_seq_num, conn->record, conn->recordlen, conn->plain_record, &conn->plain_recordlen) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_record_mac); return -1; } tls12_record_print(stderr, conn->plain_record, conn->plain_recordlen, 0, 0); tls_seq_num_incr(conn->server_seq_num); if (tls_record_get_handshake_finished(conn->plain_record, &verify_data, &verify_data_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } if (verify_data_len != sizeof(local_verify_data)) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) { error_puts("server_finished.verify_data verification failure"); tls_send_alert(conn, TLS_alert_decrypt_error); return -1; } if (!conn->ctx->quiet) fprintf(stderr, "Connection established!\n"); return 1; } /* Client Server ClientHello --------> ServerHello Certificate ServerKeyExchange CertificateRequest* <-------- ServerHelloDone Certificate* ClientKeyExchange CertificateVerify* [ChangeCipherSpec] Finished --------> [ChangeCipherSpec] <-------- Finished Application Data <-------> Application Data */ int tls12_do_client_handshake(TLS_CONNECT *conn) { int ret; int next_state; switch (conn->state) { case TLS_state_client_hello: ret = tls_send_client_hello(conn); next_state = TLS_state_server_hello; break; case TLS_state_server_hello: ret = tls_recv_server_hello(conn); next_state = TLS_state_server_certificate; break; case TLS_state_server_certificate: ret = tls_recv_server_certificate(conn); next_state = TLS_state_server_key_exchange; break; case TLS_state_server_key_exchange: ret = tls_recv_server_key_exchange(conn); next_state = TLS_state_certificate_request; break; // the only optional state case TLS_state_certificate_request: fprintf(stderr, "TLS_state_certificate_request\n"); ret = tls_recv_certificate_request(conn); fprintf(stderr, " ret = %d\n", ret); if (ret == 1) conn->client_certificate_verify = 1; next_state = TLS_state_server_hello_done; break; case TLS_state_server_hello_done: fprintf(stderr, "TLS_state_server_hello_done\n"); ret = tls_recv_server_hello_done(conn); if (conn->client_certificate_verify) next_state = TLS_state_client_certificate; else next_state = TLS_state_client_key_exchange; break; case TLS_state_client_certificate: ret = tls_send_client_certificate(conn); next_state = TLS_state_client_key_exchange; break; case TLS_state_client_key_exchange: ret = tls_send_client_key_exchange(conn); next_state = TLS_state_generate_keys; break; case TLS_state_generate_keys: ret = tls_generate_keys(conn); if (conn->client_certificate_verify) next_state = TLS_state_certificate_verify; else next_state = TLS_state_client_change_cipher_spec; break; case TLS_state_certificate_verify: ret = tls_send_certificate_verify(conn); next_state = TLS_state_client_change_cipher_spec; case TLS_state_client_change_cipher_spec: ret = tls_send_change_cipher_spec(conn); next_state = TLS_state_client_finished; break; case TLS_state_client_finished: ret = tls_send_client_finished(conn); next_state = TLS_state_server_change_cipher_spec; break; case TLS_state_server_change_cipher_spec: ret = tls_recv_change_cipher_spec(conn); next_state = TLS_state_server_finished; break; case TLS_state_server_finished: ret = tls_recv_server_finished(conn); next_state = TLS_state_handshake_over; break; default: error_print(); return -1; } if (ret < 0) { if (ret == TLS_ERROR_RECV_AGAIN || ret == TLS_ERROR_SEND_AGAIN) { return ret; } else { error_print(); return ret; } } conn->state = next_state; // ret == 0 means this step is bypassed if (ret == 1) { tls_clean_record(conn); } return 1; } int tls12_do_server_handshake(TLS_CONNECT *conn) { int ret; int next_state; switch (conn->state) { case TLS_state_client_hello: ret = tls_recv_client_hello(conn); next_state = TLS_state_server_hello; break; case TLS_state_server_hello: ret = tls_send_server_hello(conn); next_state = TLS_state_server_certificate; break; case TLS_state_server_certificate: ret = tls_send_server_certificate(conn); next_state = TLS_state_server_key_exchange; break; case TLS_state_server_key_exchange: ret = tls_send_server_key_exchange(conn); if (conn->client_certificate_verify) next_state = TLS_state_certificate_request; else next_state = TLS_state_server_hello_done; break; case TLS_state_certificate_request: ret = tls_send_certificate_request(conn); next_state = TLS_state_server_hello_done; break; case TLS_state_server_hello_done: ret = tls_send_server_hello_done(conn); if (conn->client_certificate_verify) next_state = TLS_state_client_certificate; else next_state = TLS_state_client_key_exchange; break; case TLS_state_client_certificate: ret = tls_recv_client_certificate(conn); next_state = TLS_state_client_key_exchange; break; case TLS_state_client_key_exchange: ret = tls_recv_client_key_exchange(conn); if (conn->client_certificate_verify) next_state = TLS_state_certificate_verify; else next_state = TLS_state_generate_keys; break; case TLS_state_certificate_verify: ret = tls_recv_certificate_verify(conn); next_state = TLS_state_generate_keys; break; case TLS_state_generate_keys: ret = tls_generate_keys(conn); next_state = TLS_state_client_change_cipher_spec; break; case TLS_state_client_change_cipher_spec: ret = tls_recv_change_cipher_spec(conn); next_state = TLS_state_client_finished; break; case TLS_state_client_finished: ret = tls_recv_client_finished(conn); next_state = TLS_state_server_change_cipher_spec; break; case TLS_state_server_change_cipher_spec: ret = tls_send_change_cipher_spec(conn); next_state = TLS_state_server_finished; break; case TLS_state_server_finished: ret = tls_send_server_finished(conn); next_state = TLS_state_handshake_over; break; default: error_print(); return -1; } if (ret != 1) { if (ret == TLS_ERROR_RECV_AGAIN || ret == TLS_ERROR_SEND_AGAIN) { return ret; } else { error_print(); return ret; } } conn->state = next_state; tls_clean_record(conn); return 1; } // 这个函数显然是不对的,因为这个函数就是一个重入的函数,重入函数不应该自己设置状态啊 int tls12_client_handshake(TLS_CONNECT *conn) { int ret; while (conn->state != TLS_state_handshake_over) { ret = tls12_do_client_handshake(conn); if (ret != 1) { if (ret != TLS_ERROR_RECV_AGAIN && ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; } } // TODO: cleanup conn? return 1; } int tls12_server_handshake(TLS_CONNECT *conn) { int ret; while (conn->state != TLS_state_handshake_over) { ret = tls12_do_server_handshake(conn); if (ret != 1) { if (ret != TLS_ERROR_RECV_AGAIN && ret != TLS_ERROR_SEND_AGAIN) { error_print(); } return ret; } } // TODO: cleanup conn? return 1; } int tls12_do_connect(TLS_CONNECT *conn) { int ret; fd_set rfds; fd_set wfds; conn->state = TLS_state_client_hello; //sm3_init(&conn->sm3_ctx); digest_init(&conn->dgst_ctx, DIGEST_sm3()); while (1) { ret = tls12_client_handshake(conn); if (ret == 1) { break; } else if (ret == TLS_ERROR_SEND_AGAIN) { FD_ZERO(&rfds); FD_ZERO(&wfds); FD_SET(conn->sock, &wfds); select(conn->sock + 1, &rfds, &wfds, NULL, NULL); } else if (ret == TLS_ERROR_RECV_AGAIN) { FD_ZERO(&rfds); FD_ZERO(&wfds); FD_SET(conn->sock, &rfds); select(conn->sock + 1, &rfds, &wfds, NULL, NULL); } else { error_print(); return -1; } } return 1; } int tls12_do_accept(TLS_CONNECT *conn) { int ret; fd_set rfds; fd_set wfds; conn->state = TLS_state_client_hello; //sm3_init(&conn->sm3_ctx); digest_init(&conn->dgst_ctx, DIGEST_sm3()); while (1) { ret = tls12_server_handshake(conn); if (ret == 1) { break; } else if (ret == TLS_ERROR_SEND_AGAIN) { FD_ZERO(&rfds); FD_ZERO(&wfds); FD_SET(conn->sock, &rfds); select(conn->sock + 1, &rfds, &wfds, NULL, NULL); } else if (ret == TLS_ERROR_RECV_AGAIN) { FD_ZERO(&rfds); FD_ZERO(&wfds); FD_SET(conn->sock, &wfds); select(conn->sock + 1, &rfds, &wfds, NULL, NULL); } else { error_print(); return -1; } } return 1; }