diff --git a/include/gmssl/tls.h b/include/gmssl/tls.h index ebce5987..330b5c8e 100644 --- a/include/gmssl/tls.h +++ b/include/gmssl/tls.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -1057,15 +1058,21 @@ typedef struct { // transcript hash - SM3_CTX sm3_ctx; + //SM3_CTX sm3_ctx; DIGEST_CTX dgst_ctx; // secrets - SM3_HMAC_CTX client_write_mac_ctx; - SM3_HMAC_CTX server_write_mac_ctx; + HMAC_CTX client_write_mac_ctx; + HMAC_CTX server_write_mac_ctx; + SM4_KEY client_write_enc_key; SM4_KEY server_write_enc_key; + + + + + uint8_t client_seq_num[8]; uint8_t server_seq_num[8]; diff --git a/src/tlcp.c b/src/tlcp.c index 724de73a..8eda2759 100644 --- a/src/tlcp.c +++ b/src/tlcp.c @@ -1909,7 +1909,7 @@ int tlcp_do_connect(TLS_CONNECT *conn) // 应该把protocol_version的初始化放在这里 conn->state = TLS_state_client_hello; - sm3_init(&conn->sm3_ctx); + //sm3_init(&conn->sm3_ctx); while (1) { @@ -1946,7 +1946,7 @@ int tlcp_do_accept(TLS_CONNECT *conn) conn->state = TLS_state_client_hello; - sm3_init(&conn->sm3_ctx); + //sm3_init(&conn->sm3_ctx); while (1) { diff --git a/src/tls.c b/src/tls.c index 614895ae..28d070a8 100644 --- a/src/tls.c +++ b/src/tls.c @@ -2732,7 +2732,7 @@ int tls_init(TLS_CONNECT *conn, TLS_CTX *ctx) } if (ctx->protocol == TLS_protocol_tlcp) { - sm3_init(&conn->sm3_ctx); + //sm3_init(&conn->sm3_ctx); } diff --git a/src/tls12.c b/src/tls12.c index 91532cbb..72d4e38c 100644 --- a/src/tls12.c +++ b/src/tls12.c @@ -42,6 +42,273 @@ int tls12_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int f return tls_record_print(fp, record, recordlen, format, indent); } +// 这里主要的问题是我们没有 cbc_encrypt_blocks 这个函数啊 + + +void cbc_encrypt_blocks(const BLOCK_CIPHER_KEY *key, uint8_t iv[16], + const uint8_t *in, size_t nblocks, uint8_t *out) +{ + const uint8_t *piv = iv; + + while (nblocks--) { + size_t i; + for (i = 0; i < 16; i++) { + out[i] = in[i] ^ piv[i]; + } + block_cipher_encrypt(key, out, out); + piv = out; + in += 16; + out += 16; + } + + memcpy(iv, piv, 16); +} + +void cbc_decrypt_blocks(const BLOCK_CIPHER_KEY *key, uint8_t iv[16], + const uint8_t *in, size_t nblocks, uint8_t *out) +{ + const uint8_t *piv = iv; + + while (nblocks--) { + size_t i; + block_cipher_decrypt(key, in, out); + for (i = 0; i < 16; i++) { + out[i] ^= piv[i]; + } + piv = in; + in += 16; + out += 16; + } + + memcpy(iv, piv, 16); +} + + +// 这个函数只有在哈希函数为HASH256时才是正确的 +int tls12_cbc_encrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *enc_key, + const uint8_t seq_num[8], const uint8_t header[5], + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +{ + HMAC_CTX hmac_ctx; + uint8_t last_blocks[32 + 16] = {0}; + uint8_t iv[16]; + uint8_t *mac, *padding; + size_t maclen; + int rem, padding_len; + int i; + + if (!inited_hmac_ctx || !enc_key || !seq_num || !header || (!in && inlen) || !out || !outlen) { + error_print(); + return -1; + } + if (inlen > (1 << 14)) { + error_print(); + return -1; + } + if ((((size_t)header[3]) << 8) + header[4] != inlen) { + error_print(); + return -1; + } + + rem = (inlen + 32) % 16; + memcpy(last_blocks, in + inlen - rem, rem); + mac = last_blocks + rem; + + memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(HMAC_CTX)); + hmac_update(&hmac_ctx, seq_num, 8); + hmac_update(&hmac_ctx, header, 5); + hmac_update(&hmac_ctx, in, inlen); + hmac_finish(&hmac_ctx, mac, &maclen); + + padding = mac + 32; + padding_len = 16 - rem - 1; + for (i = 0; i <= padding_len; i++) { + padding[i] = (uint8_t)padding_len; + } + + if (rand_bytes(iv, 16) != 1) { + error_print(); + return -1; + } + memcpy(out, iv, 16); + out += 16; + + if (inlen >= 16) { + cbc_encrypt_blocks(enc_key, iv, in, inlen/16, out); + out += inlen - rem; + } + cbc_encrypt_blocks(enc_key, iv, last_blocks, sizeof(last_blocks)/16, out); + + *outlen = 16 + inlen - rem + sizeof(last_blocks); + return 1; +} + +int tls12_cbc_decrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *dec_key, + const uint8_t seq_num[8], const uint8_t enced_header[5], + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +{ + HMAC_CTX hmac_ctx; + uint8_t iv[16]; + const uint8_t *padding; + const uint8_t *mac; + uint8_t header[5]; + int padding_len; + uint8_t hmac[32]; + size_t hmaclen; + int i; + + if (!inited_hmac_ctx || !dec_key || !seq_num || !enced_header || !in || !inlen || !out || !outlen) { + error_print(); + return -1; + } + if (inlen % 16 + || inlen < (16 + 0 + 32 + 16) // iv + data + mac + padding + || inlen > (16 + (1<<14) + 32 + 256)) { + error_print_msg("invalid tls cbc ciphertext length %zu\n", inlen); + return -1; + } + + memcpy(iv, in, 16); + in += 16; + inlen -= 16; + + cbc_decrypt_blocks(dec_key, iv, in, inlen/16, out); + + padding_len = out[inlen - 1]; + padding = out + inlen - padding_len - 1; + if (padding < out + 32) { + error_print(); + return -1; + } + for (i = 0; i < padding_len; i++) { + if (padding[i] != padding_len) { + error_puts("tls ciphertext cbc-padding check failure"); + return -1; + } + } + + *outlen = inlen - 32 - padding_len - 1; + + header[0] = enced_header[0]; + header[1] = enced_header[1]; + header[2] = enced_header[2]; + header[3] = (uint8_t)((*outlen) >> 8); + header[4] = (uint8_t)(*outlen); + mac = padding - 32; + + memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(HMAC_CTX)); + hmac_update(&hmac_ctx, seq_num, 8); + hmac_update(&hmac_ctx, header, 5); + hmac_update(&hmac_ctx, out, *outlen); + hmac_finish(&hmac_ctx, hmac, &hmaclen); + + if (gmssl_secure_memcmp(mac, hmac, sizeof(hmac)) != 0) { + error_puts("tls ciphertext mac check failure\n"); + return -1; + } + return 1; +} + +int tls12_record_encrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key, + const uint8_t seq_num[8], const uint8_t *in, size_t inlen, + uint8_t *out, size_t *outlen) +{ + if (tls12_cbc_encrypt(hmac_ctx, cbc_key, seq_num, in, + in + 5, inlen - 5, + out + 5, outlen) != 1) { + error_print(); + return -1; + } + + out[0] = in[0]; + out[1] = in[1]; + out[2] = in[2]; + out[3] = (uint8_t)((*outlen) >> 8); + out[4] = (uint8_t)(*outlen); + (*outlen) += 5; + return 1; +} + +int tls12_record_decrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key, + const uint8_t seq_num[8], const uint8_t *in, size_t inlen, + uint8_t *out, size_t *outlen) +{ + if (tls12_cbc_decrypt(hmac_ctx, cbc_key, seq_num, in, + in + 5, inlen - 5, + out + 5, outlen) != 1) { + error_print(); + return -1; + } + + out[0] = in[0]; + out[1] = in[1]; + out[2] = in[2]; + out[3] = (uint8_t)((*outlen) >> 8); + out[4] = (uint8_t)(*outlen); + (*outlen) += 5; + + return 1; +} + +// 这个函数只依赖哈希 +int tls12_prf(const DIGEST *digest, const uint8_t *secret, size_t secretlen, const char *label, + const uint8_t *seed, size_t seedlen, + const uint8_t *more, size_t morelen, + size_t outlen, uint8_t *out) +{ + HMAC_CTX inited_hmac_ctx; + HMAC_CTX hmac_ctx; + uint8_t A[32]; + uint8_t hmac[32]; + size_t len; + + if (!secret || !secretlen || !label || !seed || !seedlen + || (!more && morelen) || !outlen || !out) { + error_print(); + return -1; + } + + hmac_init(&inited_hmac_ctx, digest, secret, secretlen); + + memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); + hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); + hmac_update(&hmac_ctx, seed, seedlen); + hmac_update(&hmac_ctx, more, morelen); + hmac_finish(&hmac_ctx, A, &len); // 检查或者使用长度len + + memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); + hmac_update(&hmac_ctx, A, sizeof(A)); + hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); + hmac_update(&hmac_ctx, seed, seedlen); + hmac_update(&hmac_ctx, more, morelen); + hmac_finish(&hmac_ctx, hmac, &len); + + len = outlen < sizeof(hmac) ? outlen : sizeof(hmac); + memcpy(out, hmac, len); + out += len; + outlen -= len; + + while (outlen) { + memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); + hmac_update(&hmac_ctx, A, sizeof(A)); + hmac_finish(&hmac_ctx, A, &len); + + memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); + hmac_update(&hmac_ctx, A, sizeof(A)); + hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); + hmac_update(&hmac_ctx, seed, seedlen); + hmac_update(&hmac_ctx, more, morelen); + hmac_finish(&hmac_ctx, hmac, &len); + + len = outlen < sizeof(hmac) ? outlen : sizeof(hmac); + memcpy(out, hmac, len); + out += len; + outlen -= len; + } + return 1; +} + + // modify: conn->record_offset @@ -189,8 +456,6 @@ int tls_named_curve_from_oid(int oid) // 这个是必选的 -const int ec_point_formats[] = { TLS_point_uncompressed }; -size_t ec_point_formats_cnt = sizeof(ec_point_formats)/sizeof(ec_point_formats[0]); // 服务器通常推荐返回这个值 const int supported_groups[] = { @@ -337,7 +602,7 @@ void tls_clean_record(TLS_CONNECT *conn) int tls_handshake_init(TLS_CONNECT *conn) { - sm3_init(&conn->sm3_ctx); + //sm3_init(&conn->sm3_ctx); digest_init(&conn->dgst_ctx, DIGEST_sm3()); @@ -349,65 +614,72 @@ int tls_handshake_init(TLS_CONNECT *conn) } -/* -TLCP协议中ClientHello中不包含扩展,并且cipher_suites使用的是TLCP的cipher_suites -ciphers我觉得应该在设置ctx的时候设置好 -exts -*/ int tls_send_client_hello(TLS_CONNECT *conn) { int ret; - uint8_t *record = conn->record; if (!conn->recordlen) { - uint8_t client_exts[TLS_MAX_EXTENSIONS_SIZE]; - uint8_t *p = client_exts; - size_t client_exts_len = 0; + const int ec_point_formats[] = { TLS_point_uncompressed }; + size_t ec_point_formats_cnt = sizeof(ec_point_formats)/sizeof(ec_point_formats[0]); + uint8_t exts[TLS_MAX_EXTENSIONS_SIZE]; + uint8_t *pexts = exts; + size_t extslen = 0; - switch (conn->protocol) { - case TLS_protocol_tls12: - tls_record_set_protocol(record, TLS_protocol_tls1); - break; - case TLS_protocol_tlcp: - tls_record_set_protocol(record, TLS_protocol_tlcp); - break; - default: - error_print(); - return -1; - } + tls_trace("send ClientHello\n"); + + tls_record_set_protocol(conn->record, TLS_protocol_tls1); if (tls_random_generate(conn->client_random) != 1) { error_print(); return -1; } - if (tls_ec_point_formats_ext_to_bytes(ec_point_formats, ec_point_formats_cnt, &p, &client_exts_len) != 1 - || tls_supported_groups_ext_to_bytes(supported_groups, supported_groups_cnt, &p, &client_exts_len) != 1 - || tls_signature_algorithms_ext_to_bytes(signature_algors, signature_algors_cnt, &p, &client_exts_len) != 1) { + // ec_point_formats + if (tls_ec_point_formats_ext_to_bytes( + ec_point_formats, ec_point_formats_cnt, &pexts, &extslen) != 1) { error_print(); return -1; } + + // supported_groups + if (conn->ctx->supported_groups_cnt) { + if (tls_supported_groups_ext_to_bytes(conn->ctx->supported_groups, + conn->ctx->supported_groups_cnt, &pexts, &extslen) != 1) { + error_print(); + return -1; + } + } + + // signature_algorithms + if (conn->ctx->signature_algorithms_cnt) { + if (tls_signature_algorithms_ext_to_bytes(conn->ctx->signature_algorithms, + conn->ctx->signature_algorithms_cnt, &pexts, &extslen) != 1) { + error_print(); + return -1; + } + } + if (tls_record_set_handshake_client_hello(conn->record, &conn->recordlen, conn->protocol, conn->client_random, NULL, 0, conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt, - client_exts, client_exts_len) != 1) { + exts, extslen) != 1) { error_print(); return -1; } - // offset = 0, recordlen > 0 - tls_trace("send ClientHello\n"); - tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + tls12_record_print(stderr, conn->record, conn->recordlen, 0, 0); + + // backup ClientHello + memcpy(conn->plain_record, conn->record, conn->recordlen); + conn->plain_recordlen = conn->recordlen; } - // 客户端一开始不知道是否要进行客户端验证 - // 如果用户提供了客户端证书那么就准备,如果没提供就完全不准确 - // 客户端的证书可以通过回调函数来设置 + /* if (conn->client_certificate_verify) { sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); } + */ if ((ret = tls_send_record(conn)) != 1) { if (ret != TLS_ERROR_SEND_AGAIN) { @@ -420,21 +692,6 @@ int tls_send_client_hello(TLS_CONNECT *conn) return 1; } - - - - - - - - - - - - - - - /* const int server_ciphers[] = { TLS_cipher_ecdhe_sm4_cbc_sm3 }; const size_t server_ciphers_cnt = 1; @@ -442,11 +699,9 @@ const size_t server_ciphers_cnt = 1; const int curve = TLS_curve_sm2p256v1; -// 服务器在收到ClientHello之后, int tls_recv_client_hello(TLS_CONNECT *conn) { int ret; - uint8_t *record = conn->record; int client_verify = 0; @@ -454,21 +709,15 @@ int tls_recv_client_hello(TLS_CONNECT *conn) const uint8_t *client_random; const uint8_t *session_id; size_t session_id_len; - const uint8_t *client_ciphers; - size_t client_ciphers_len; - const uint8_t *client_exts; - size_t client_exts_len; + const uint8_t *cipher_suites; + size_t cipher_suites_len; + const uint8_t *exts; + size_t extslen; - sm3_init(&conn->sm3_ctx); - - - // 服务器端如果设置了CA - if (conn->ctx->cacertslen) - client_verify = 1; - - // 这个判断应该改为一个函数 + /* if (client_verify) tls_client_verify_init(&conn->client_verify_ctx); + */ tls_trace("recv ClientHello\n"); @@ -481,20 +730,19 @@ int tls_recv_client_hello(TLS_CONNECT *conn) } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); - - - - if (tls_record_protocol(record) != conn->protocol - && tls_record_protocol(record) != TLS_protocol_tls1) { + if (tls_record_protocol(conn->record) != TLS_protocol_tls1) { error_print(); tls_send_alert(conn, TLS_alert_protocol_version); return -1; } - if (tls_record_get_handshake_client_hello(record, + if ((ret = tls_record_get_handshake_client_hello(conn->record, &protocol, &client_random, &session_id, &session_id_len, - &client_ciphers, &client_ciphers_len, - &client_exts, &client_exts_len) != 1) { + &cipher_suites, &cipher_suites_len, &exts, &extslen)) < 0) { + error_print(); + tls13_send_alert(conn, TLS_alert_decode_error); + return -1; + } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; @@ -505,17 +753,31 @@ int tls_recv_client_hello(TLS_CONNECT *conn) tls_send_alert(conn, TLS_alert_protocol_version); return -1; } + memcpy(conn->client_random, client_random, 32); - // 服务器选择的cipher_suites需要和服务器准备的证书和公钥匹配 - if (tls_cipher_suites_select(client_ciphers, client_ciphers_len, + if ((ret = tls_cipher_suites_select(cipher_suites, cipher_suites_len, conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt, - &conn->cipher_suite) != 1) { + &conn->cipher_suite)) < 0) { error_print(); - tls_send_alert(conn, TLS_alert_insufficient_security); + tls13_send_alert(conn, TLS_alert_decode_error); + return -1; + } else if (ret == 0) { + error_print(); + tls13_send_alert(conn, TLS_alert_handshake_failure); return -1; } + /* + TLS_cipher_ecdhe_sm4_cbc_sm3 + TLS_cipher_ecdhe_sm4_gcm_sm3 + TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256 + */ + + conn->cipher = BLOCK_CIPHER_sm4(); + conn->digest = DIGEST_sm3(); + + /* switch (conn->cipher_suite) { case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_sm4_gcm_sm3: @@ -530,18 +792,40 @@ int tls_recv_client_hello(TLS_CONNECT *conn) error_print(); return -1; } + */ + + while (extslen) { + int ext_type; + const uint8_t *ext_data; + size_t ext_datalen; + + if (tls_ext_from_bytes(&ext_type, &ext_data, &ext_datalen, &exts, &extslen) != 1) { + error_print(); + tls13_send_alert(conn, TLS_alert_decode_error); + return -1; + } + + + // 这些扩展都不是必须的 + - if (client_exts) { - // 这些函数需要能够访问conn的内部变量 - // 修改处理扩展的逻辑 - //tls_process_client_hello_exts(client_exts, client_exts_len, - // conn->server_exts, &conn->server_exts_len, sizeof(conn->server_exts)); } - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + + if (digest_init(&conn->dgst_ctx, conn->digest) != 1) { + error_print(); + return -1; + } + if (digest_update(&conn->dgst_ctx, conn->plain_record + 5, conn->plain_recordlen - 5) != 1) { + error_print(); + return -1; + } + + /* if (client_verify) tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); + */ fprintf(stderr, "end of recv_client_hello\n"); tls_clean_record(conn); @@ -588,7 +872,7 @@ int tls_send_server_hello(TLS_CONNECT *conn) return ret; } - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); if (conn->ctx->cacertslen) { tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); } @@ -605,15 +889,17 @@ int tls_recv_server_hello(TLS_CONNECT *conn) const uint8_t *server_random; const uint8_t *session_id; size_t session_id_len; - const uint8_t *server_exts; - size_t server_exts_len; + const uint8_t *exts; + size_t extslen; - // 扩展的协商结果,-1 表示服务器不支持该扩展(未给出响应) - int ec_point_format = -1; - int supported_group = -1; - int signature_algor = -1; + const uint8_t *ec_point_formats = NULL; + size_t ec_point_formats_len; + const uint8_t *supported_groups = NULL; + size_t supported_groups_len; + const uint8_t *signature_algorithms = NULL; + size_t signature_algorithms_len; - // 实际上当前的record已经有完整的数据了,但是我们不知道啊 + tls_trace("recv ServerHello\n"); if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { @@ -621,62 +907,153 @@ int tls_recv_server_hello(TLS_CONNECT *conn) } return ret; } + tls12_record_print(stderr, conn->record, conn->recordlen, 0, 0); - tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); if (tls_record_protocol(conn->record) != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_protocol_version); return -1; } - if (tls_record_get_handshake_server_hello(conn->record, + if ((ret = tls_record_get_handshake_server_hello(conn->record, &protocol, &server_random, &session_id, &session_id_len, &cipher_suite, - &server_exts, &server_exts_len) != 1) { + &exts, &extslen)) < 0) { + error_print(); + tls_send_alert(conn, TLS_alert_decode_error); + return -1; + } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } + + // version if (protocol != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_protocol_version); return -1; } - - /* - if (tls_cipher_suite_in_list(cipher_suite, conn->cipher_suites, conn->cipher_suites_cnt) != 1) { + + // random + memcpy(conn->server_random, server_random, 32); + + // session_id + memcpy(conn->session_id, session_id, session_id_len); + + // cipher_suite + if (tls_type_is_in_list(cipher_suite, conn->ctx->cipher_suites, conn->ctx->cipher_suites_cnt) != 1) { error_print(); tls_send_alert(conn, TLS_alert_handshake_failure); return -1; } - */ - - /* - 对于扩展的处理 + conn->cipher_suite = cipher_suite; - 首先扩展是由客户端ClientHello进行设定,服务器选择,最后由Client验证的一个过程。 - 因此客户端和服务器端都需要存储扩展相应的数据。 + // 初始化digest + conn->digest = DIGEST_sha256(); - 扩展的初始化是如何实现的?我觉得也应该在CTX中完成。 + conn->cipher = BLOCK_CIPHER_aes128(); - 现在对扩展的处理逻辑是有问题的,服务器ServerHello是否包含扩展是取决于ClientHello的 - */ - if (server_exts) { - if (tls_process_server_hello_exts(server_exts, server_exts_len, &ec_point_format, &supported_group, &signature_algor) != 1 - || ec_point_format < 0 - || supported_group < 0 - || signature_algor < 0) { + if (digest_init(&conn->dgst_ctx, conn->digest) != 1) { + error_print(); + return -1; + } + + + while (extslen) { + int ext_type; + const uint8_t *ext_data; + size_t ext_datalen; + + if (tls_ext_from_bytes(&ext_type, &ext_data, &ext_datalen, &exts, &extslen) != 1) { error_print(); - tls_send_alert(conn, TLS_alert_unexpected_message); + tls13_send_alert(conn, TLS_alert_decode_error); return -1; } + + // extensions in ServerHello + // * ec_point_formats + // * supported_groups + // * signature_algorithms + + switch (ext_type) { + case TLS_extension_ec_point_formats: + case TLS_extension_supported_groups: + case TLS_extension_signature_algorithms: + if (!ext_data) { + error_print(); + tls13_send_alert(conn, TLS_alert_illegal_parameter); + return -1; + } + break; + default: + error_print(); + return -1; + } + + switch (ext_type) { + case TLS_extension_ec_point_formats: + if (ec_point_formats) { + error_print(); + tls13_send_alert(conn, TLS_alert_illegal_parameter); + return -1; + } + ec_point_formats = ext_data; + ec_point_formats_len = ext_datalen; + break; + + case TLS_extension_supported_groups: + if (supported_groups) { + error_print(); + tls13_send_alert(conn, TLS_alert_illegal_parameter); + return -1; + } + supported_groups = ext_data; + supported_groups_len = ext_datalen; + break; + + case TLS_extension_signature_algorithms: + if (signature_algorithms) { + error_print(); + tls13_send_alert(conn, TLS_alert_illegal_parameter); + return -1; + } + signature_algorithms = ext_data; + signature_algorithms_len = ext_datalen; + break; + } } - memcpy(conn->server_random, server_random, 32); - memcpy(conn->session_id, session_id, session_id_len); - conn->cipher_suite = cipher_suite; - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + if (!ec_point_formats) { + error_print(); + tls13_send_alert(conn, TLS_alert_missing_extension); + return -1; + } + + if (supported_groups) { + } + + if (signature_algorithms) { + } + + + + if (digest_update(&conn->dgst_ctx, conn->plain_record + 5, conn->plain_recordlen - 5) != 1) { + error_print(); + return -1; + } + tls_handshake_digest_print(stderr, 0, 0, "ClientHello", &conn->dgst_ctx); + + if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { + error_print(); + return -1; + } + tls_handshake_digest_print(stderr, 0, 0, "ServerHello", &conn->dgst_ctx); + + + + //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); if (conn->client_certs_len) { sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); } @@ -710,7 +1087,7 @@ int tls_send_server_certificate(TLS_CONNECT *conn) return ret; } - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); if (conn->client_certificate_verify) { tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); } @@ -726,7 +1103,7 @@ int tls_recv_server_certificate(TLS_CONNECT *conn) X509_KEY server_sign_key; - tls_trace("recv ServerCertificate\n"); + tls_trace("recv server Certificate\n"); if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { @@ -734,23 +1111,34 @@ int tls_recv_server_certificate(TLS_CONNECT *conn) } return ret; } + tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); + if (tls_record_protocol(conn->record) != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } - tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); - if (tls_record_get_handshake_certificate(conn->record, - conn->server_certs, &conn->server_certs_len) != 1) { + if ((ret = tls_record_get_handshake_certificate(conn->record, + conn->peer_cert_chain, &conn->peer_cert_chain_len)) < 0) { + error_print(); + tls_send_alert(conn, TLS_alert_decode_error); + return -1; + } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); - return -1; + return 0; } - // 这下面是对获取的证书链的处理 + if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { + error_print(); + return -1; + } + tls_handshake_digest_print(stderr, 0, 0, "Certificate", &conn->dgst_ctx); - if (x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, 0, + + // 这里取服务器证书似乎没有什么用处啊 + if (x509_certs_get_cert_by_index(conn->peer_cert_chain, conn->peer_cert_chain_len, 0, &server_cert, &server_cert_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); @@ -762,6 +1150,7 @@ int tls_recv_server_certificate(TLS_CONNECT *conn) return -1; } + // 这里的逻辑需要统筹考虑 // cipher_suite,扩展,证书之间的关系 @@ -794,7 +1183,6 @@ int tls_recv_server_certificate(TLS_CONNECT *conn) return -1; } - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); if (conn->client_certs_len) { sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); } @@ -803,30 +1191,15 @@ int tls_recv_server_certificate(TLS_CONNECT *conn) // verify ServerCertificate - switch (conn->protocol) { - case TLS_protocol_tls12: - if (x509_certs_verify(conn->server_certs, conn->server_certs_len, X509_cert_chain_server, + if (conn->ctx->cacertslen) { + + // 按道理来说,这只是验证证书,并应该出错啊 + if (x509_certs_verify_tlcp(conn->peer_cert_chain, conn->peer_cert_chain_len, X509_cert_chain_server, conn->ctx->cacerts, conn->ctx->cacertslen, conn->ctx->verify_depth, &verify_result) != 1) { error_print(); - tls_send_alert(conn, TLS_alert_bad_certificate); - return -1; + //tls_send_alert(conn, TLS_alert_bad_certificate); + //return -1; } - break; - case TLS_protocol_tlcp: - if (!conn->ctx->cacertslen) { - error_print(); - return -1; - } - if (x509_certs_verify_tlcp(conn->server_certs, conn->server_certs_len, X509_cert_chain_server, - conn->ctx->cacerts, conn->ctx->cacertslen, conn->ctx->verify_depth, &verify_result) != 1) { - error_print(); - tls_send_alert(conn, TLS_alert_bad_certificate); - return -1; - } - break; - default: - error_print(); - return -1; } return 1; @@ -904,7 +1277,7 @@ int tls_send_server_key_exchange(TLS_CONNECT *conn) return ret; } - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); if (conn->client_certificate_verify) { tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); } @@ -970,10 +1343,9 @@ int tls_signature_scheme_match_cipher_suite(int sig_alg, int cipher_suite) return 1; } -// 这里应该给一个新的key_exchange - int tls_recv_server_key_exchange(TLS_CONNECT *conn) { + int ret; uint8_t curve_type; uint16_t named_curve; const uint8_t *point_octets; @@ -997,40 +1369,59 @@ int tls_recv_server_key_exchange(TLS_CONNECT *conn) tls_trace("recv ServerKeyExchange\n"); - - if (tls_record_recv(conn->record, &conn->recordlen, conn->sock) != 1 - || tls_record_protocol(conn->record) != conn->protocol) { + if ((ret = tls_recv_record(conn)) != 1) { + if (ret != TLS_ERROR_RECV_AGAIN) { + error_print(); + } + return ret; + } + if (tls_record_protocol(conn->record) != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); - // 这里应该改为,首先获取一个基础的ServerKeyExchange的数据(已经经过了验证签名) - // 然后再用不同的函数(不同协议)去分解 - if (tls_record_get_handshake_server_key_exchange(conn->record, + + // 这个函数可能是有问题的,如果cipher_suite不同,ServerKeyExchange可能也是不同的 + if ((ret = tls_record_get_handshake_server_key_exchange(conn->record, &curve_type, &named_curve, &point_octets, &point_octets_len, &server_ecdh_params, &server_ecdh_params_len, - &sig_alg, &sig, &siglen) != 1) { + &sig_alg, &sig, &siglen)) < 0) { + error_print(); + tls_send_alert(conn, TLS_alert_decode_error); + return -1; + } else if (ret == 0) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); + return 0; + } + + if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { + error_print(); return -1; } + tls_handshake_digest_print(stderr, 0, 0, "ServerKeyExchange", &conn->dgst_ctx); + if (curve_type != TLS_curve_type_named_curve) { error_print(); return -1; } // named_curve应该在supported_groups里面 + //conn->ecdh_named_curve = named_curve; + + + conn->key_exchange_group = named_curve; + memcpy(conn->peer_key_exchange, point_octets, point_octets_len); + conn->peer_key_exchange_len = point_octets_len; + - conn->ecdh_named_curve = named_curve; if (point_octets_len != 65) { error_print(); return -1; } - memcpy(conn->peer_ecdh_point, point_octets, point_octets_len); - conn->peer_ecdh_point_len = point_octets_len; if (tls_curve_match_cipher_suite(named_curve, conn->cipher_suite) != 1) { @@ -1051,14 +1442,15 @@ int tls_recv_server_key_exchange(TLS_CONNECT *conn) // 判断curve_name在supported_groups中并记录这个信息 // 验证point确实在curve_name的group中 - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); if (conn->client_certs_len) sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); - // get sign_key from first cert of server_certs - if (x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, + + + if (x509_certs_get_cert_by_index(conn->peer_cert_chain, conn->peer_cert_chain_len, server_cert_index, &server_cert, &server_cert_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_certificate); @@ -1070,6 +1462,10 @@ int tls_recv_server_key_exchange(TLS_CONNECT *conn) return -1; } + + + + // 这个检查是否是多余的? // 这个值是签名算法和椭圆曲线名字的结合 // cipher_suite只能决定签名算法类型 @@ -1100,6 +1496,8 @@ int tls_recv_server_key_exchange(TLS_CONNECT *conn) sign_args = SM2_DEFAULT_ID; sign_argslen = SM2_DEFAULT_ID_LENGTH; } + + // 这里应该是SM2的签名和验证 if (x509_verify_init(&sign_ctx, &server_sign_key, sign_args, sign_argslen, sig, siglen) != 1 || x509_verify_update(&sign_ctx, conn->client_random, 32) != 1 || x509_verify_update(&sign_ctx, conn->server_random, 32) != 1 @@ -1109,6 +1507,9 @@ int tls_recv_server_key_exchange(TLS_CONNECT *conn) return -1; } + fprintf(stderr, ">>>>>> ServerKeyExchange verify success\n"); + + // xxxx // 这里的签名错了,肯定是sign_ctx就是不对的,因此是不可能正确的 // 现在要做的是,必须确定server_key_exchange中都包括了哪些被签名的消息 @@ -1157,14 +1558,12 @@ int tls_send_certificate_request(TLS_CONNECT *conn) return ret; } - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); return 1; } -// 如果收到的是预期的报文,就处理,然后将recordlen = 0 -// 否则保留recordlen int tls_recv_certificate_request(TLS_CONNECT *conn) { int ret; @@ -1178,14 +1577,15 @@ int tls_recv_certificate_request(TLS_CONNECT *conn) const uint8_t *ca_names; size_t ca_names_len; + tls_trace("recv CertificateRequest*\n"); - // recv CertificateRequest or ServerHelloDone if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { error_print(); } return ret; } + if (tls_record_protocol(conn->record) != conn->protocol) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); @@ -1196,35 +1596,43 @@ int tls_recv_certificate_request(TLS_CONNECT *conn) tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } + if (handshake_type != TLS_handshake_certificate_request) { - conn->client_certs_len = 0; - - fprintf(stderr, "%s %d: no certificate_request\n", __FILE__, __LINE__); - fprintf(stderr, "recordlen = %zu\n", conn->recordlen); - + tls_trace(" no CertificateRequest\n"); return 0; // 表明对方没有发送预期的报文 } - - tls_trace("recv CertificateRequest\n"); tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); + + if (tls_record_get_handshake_certificate_request(conn->record, &cert_types, &cert_types_len, &ca_names, &ca_names_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } + + if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { + error_print(); + return -1; + } + + // 这里要检查一下服务器发送的,和本地的是否保持一致 + /* if(!conn->client_certs_len) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } + if (tls_cert_types_accepted(cert_types, cert_types_len, conn->client_certs, conn->client_certs_len) != 1 || tls_authorities_issued_certificate(ca_names, ca_names_len, conn->client_certs, conn->client_certs_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_unsupported_certificate); return -1; } - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + */ + + sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); conn->recordlen = 0; @@ -1249,7 +1657,7 @@ int tls_send_server_hello_done(TLS_CONNECT *conn) } return ret; } - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); if (conn->client_certs_len) { tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); @@ -1283,7 +1691,15 @@ int tls_recv_server_hello_done(TLS_CONNECT *conn) tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + + if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { + error_print(); + return -1; + } + tls_handshake_digest_print(stderr, 0, 0, "ServerHelloDone", &conn->dgst_ctx); + + + if (conn->client_certs_len) sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); @@ -1319,7 +1735,7 @@ int tls_send_client_certificate(TLS_CONNECT *conn) } - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); return 1; @@ -1362,7 +1778,7 @@ int tls_recv_client_certificate(TLS_CONNECT *conn) tls_send_alert(conn, TLS_alert_bad_certificate); return -1; } - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); return 1; @@ -1393,13 +1809,10 @@ int tls_generate_keys(TLS_CONNECT *conn) uint8_t pre_master_secret[32]; size_t pre_master_secret_len; - // 此时已经获得了ServerKeyExchange和ClientKeyExchange - // 但是不同密码套件中,这些KeyExchange的数据其实是不一样的 - // 我们需要根据不同的套件去解析数据,并且根据不同的数据类型去生成密钥 - // 还需要检查 TLS 1.3 的协议 - if (x509_key_exchange(&conn->ecdh_key, - conn->peer_ecdh_point, conn->peer_ecdh_point_len, + + if (x509_key_exchange(&conn->key_exchanges[0], + conn->peer_key_exchange, conn->peer_key_exchange_len, pre_master_secret, &pre_master_secret_len) != 1) { error_print(); return -1; @@ -1409,7 +1822,7 @@ int tls_generate_keys(TLS_CONNECT *conn) return -1; } - if (tls_prf(pre_master_secret, 32, "master secret", + if (tls12_prf(conn->digest, pre_master_secret, 32, "master secret", conn->client_random, 32, conn->server_random, 32, 48, conn->master_secret) != 1) { @@ -1417,7 +1830,7 @@ int tls_generate_keys(TLS_CONNECT *conn) return -1; } - if (tls_prf(conn->master_secret, 48, "key expansion", + if (tls12_prf(conn->digest, conn->master_secret, 48, "key expansion", conn->server_random, 32, conn->client_random, 32, 96, conn->key_block) != 1) { @@ -1426,17 +1839,31 @@ int tls_generate_keys(TLS_CONNECT *conn) return -1; } - sm3_hmac_init(&conn->client_write_mac_ctx, conn->key_block, 32); - sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32); + if (hmac_init(&conn->client_write_mac_ctx, conn->digest, conn->key_block, 32) != 1) { + error_print(); + return -1; + } + if (hmac_init(&conn->server_write_mac_ctx, conn->digest, conn->key_block + 32, 32) != 1) { + error_print(); + return -1; + } if (conn->is_client) { - sm4_set_encrypt_key(&conn->client_write_enc_key, conn->key_block + 64); - sm4_set_decrypt_key(&conn->server_write_enc_key, conn->key_block + 80); + block_cipher_set_encrypt_key(&conn->client_write_key, conn->cipher, conn->key_block + 64); + block_cipher_set_decrypt_key(&conn->server_write_key, conn->cipher, conn->key_block + 80); + + } else { - sm4_set_decrypt_key(&conn->client_write_enc_key, conn->key_block + 64); - sm4_set_encrypt_key(&conn->server_write_enc_key, conn->key_block + 80); + block_cipher_set_decrypt_key(&conn->client_write_key, conn->cipher, conn->key_block + 64); + block_cipher_set_encrypt_key(&conn->server_write_key, conn->cipher, conn->key_block + 80); } + tls_seq_num_reset(conn->client_seq_num); + tls_seq_num_reset(conn->server_seq_num); + + + + tls_secrets_print(stderr, pre_master_secret, 48, conn->client_random, conn->server_random, @@ -1448,24 +1875,27 @@ int tls_generate_keys(TLS_CONNECT *conn) } - int tls_send_client_key_exchange(TLS_CONNECT *conn) { int ret; - uint8_t point_octets[65]; - uint8_t *p = point_octets; - size_t len = 0; // 客户端的ECDHE的公钥肯定和服务器是保持一致的 // 因此在接收到服务器的公钥之后,应该保存这个信息 + + // 客户端是怎么确定密钥交换的group的?大概是从ServerKeyExchange中确定的 + if (conn->recordlen == 0) { - int curve_oid = tls_named_curve_oid(conn->ecdh_named_curve); - if (x509_key_generate(&conn->ecdh_key, OID_ec_public_key, &curve_oid, sizeof(curve_oid)) != 1) { + uint8_t point_octets[65]; + uint8_t *p = point_octets; + size_t len = 0; + int curve_oid = tls_named_curve_oid(conn->key_exchange_group); + + if (x509_key_generate(&conn->key_exchanges[0], OID_ec_public_key, &curve_oid, sizeof(curve_oid)) != 1) { error_print(); return -1; } - if (x509_public_key_to_bytes(&conn->ecdh_key, &p, &len) != 1) { + if (x509_public_key_to_bytes(&conn->key_exchanges[0], &p, &len) != 1) { error_print(); return -1; } @@ -1482,6 +1912,12 @@ int tls_send_client_key_exchange(TLS_CONNECT *conn) return -1; } tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); + + if (digest_update(&conn->dgst_ctx, conn->record + 5, conn->recordlen - 5) != 1) { + error_print(); + return -1; + } + tls_handshake_digest_print(stderr, 0, 0, "ClientKeyExchange", &conn->dgst_ctx); } if ((ret = tls_send_record(conn)) != 1) { @@ -1491,7 +1927,6 @@ int tls_send_client_key_exchange(TLS_CONNECT *conn) return ret; } - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); if (conn->client_certs_len) sm2_sign_update(&conn->sign_ctx, conn->record + 5, conn->recordlen - 5); @@ -1532,7 +1967,7 @@ int tls_recv_client_key_exchange(TLS_CONNECT *conn) memcpy(conn->peer_ecdh_point, point_octets, point_octets_len); conn->peer_ecdh_point_len = point_octets_len; - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); if (conn->ctx->cacertslen) tls_client_verify_update(&conn->client_verify_ctx, conn->record + 5, conn->recordlen - 5); @@ -1575,7 +2010,7 @@ int tls_send_certificate_verify(TLS_CONNECT *conn) return ret; } - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); return 1; } @@ -1641,7 +2076,7 @@ int tls_recv_certificate_verify(TLS_CONNECT *conn) tls_send_alert(conn, TLS_alert_decrypt_error); return -1; } - sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); + //sm3_update(&conn->sm3_ctx, conn->record + 5, conn->recordlen - 5); return 1; } @@ -1698,33 +2133,34 @@ int tls_send_client_finished(TLS_CONNECT *conn) { int ret; - SM3_CTX tmp_sm3_ctx; - uint8_t sm3_hash[32]; - uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE]; - size_t finished_record_len; - uint8_t local_verify_data[12]; - - - tls_record_set_protocol(finished_record, conn->protocol); if (conn->recordlen == 0) { - tls_trace("send Finished\n"); + tls_trace("send client {Finished}\n"); - // 到目前为止所有消息的哈希 - memcpy(&tmp_sm3_ctx, &conn->sm3_ctx, sizeof(SM3_CTX)); - sm3_finish(&tmp_sm3_ctx, sm3_hash); + uint8_t local_verify_data[12]; - if (tls_prf(conn->master_secret, 48, - "client finished", sm3_hash, 32, NULL, 0, + DIGEST_CTX tmp_ctx; + uint8_t dgst[32]; + size_t dgstlen; + + tmp_ctx = conn->dgst_ctx; + + digest_finish(&tmp_ctx, dgst, &dgstlen); + + if (tls12_prf(conn->digest, + conn->master_secret, 48, + "client finished", dgst, dgstlen, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } + tls_record_set_protocol(conn->plain_record, conn->protocol); + // finished_record是没有问题的 - if (tls_record_set_handshake_finished(finished_record, &finished_record_len, + if (tls_record_set_handshake_finished(conn->plain_record, &conn->plain_recordlen, local_verify_data, sizeof(local_verify_data)) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); @@ -1732,51 +2168,22 @@ int tls_send_client_finished(TLS_CONNECT *conn) } // 此时finished_record中的头部应该是完整的 - tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0); + tls12_record_trace(stderr, conn->plain_record, conn->plain_recordlen, 0, 0); - sm3_update(&conn->sm3_ctx, finished_record + 5, finished_record_len - 5); + if (digest_update(&conn->dgst_ctx, conn->plain_record + 5, conn->plain_recordlen - 5) != 1) { + error_print(); + return -1; + } + tls_handshake_digest_print(stderr, 0, 0, "Finished", &conn->dgst_ctx); - // encrypt Client Finished - - // 此时finished_record中的头部应该是完整的 - - - // encrypt Client Finished - - // 但是conn->record并没有设置 - - // 但是conn->record并没有设置 - - // 这个会把finished的头部copy过来, 但是握手的头部并没有copy - - - - if (tls_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key, - conn->client_seq_num, finished_record, finished_record_len, + if (tls12_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_key, + conn->client_seq_num, conn->plain_record, conn->plain_recordlen, conn->record, &conn->recordlen) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; } - - // 最后输出的负载数据是80-bytes,而实际上应该是64字节,多了16字节,这显然是不对的 - // 为什么会有这样的结果呢? - - tls_encrypted_record_trace(stderr, conn->record, conn->recordlen, (1<<24), 0); // 强制打印密文原数据 - - // 这里直接就出错了 - /* - if (tls_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key, - conn->client_seq_num, conn->record, conn->recordlen, finished_record, &finished_record_len) != 1) { - error_print(); - return -1; - } - - */ - - tls_trace("encrypted ClientFinished\n"); - tls_encrypted_record_trace(stderr, conn->record, conn->recordlen, (1<<24), 0); // 强制打印密文原数据 tls_seq_num_incr(conn->client_seq_num); } @@ -1829,7 +2236,7 @@ int tls_recv_client_finished(TLS_CONNECT *conn) tls_trace("decrypt Finished\n"); - if (tls_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key, + if (tls12_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_key, conn->client_seq_num, conn->record, conn->recordlen, finished_record, &finished_record_len) != 1) { @@ -1852,9 +2259,9 @@ int tls_recv_client_finished(TLS_CONNECT *conn) } // verify ClientFinished - memcpy(&tmp_sm3_ctx, &conn->sm3_ctx, sizeof(SM3_CTX)); - sm3_update(&conn->sm3_ctx, finished_record + 5, finished_record_len - 5); - sm3_finish(&tmp_sm3_ctx, sm3_hash); + //memcpy(&tmp_sm3_ctx, &conn->sm3_ctx, sizeof(SM3_CTX)); + //sm3_update(&conn->sm3_ctx, finished_record + 5, finished_record_len - 5); + //sm3_finish(&tmp_sm3_ctx, sm3_hash); if (tls_prf(conn->master_secret, 48, "client finished", sm3_hash, 32, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) { error_print(); @@ -1885,7 +2292,7 @@ int tls_send_server_finished(TLS_CONNECT *conn) if (conn->recordlen == 0) { tls_trace("send Finished\n"); - sm3_finish(&conn->sm3_ctx, sm3_hash); + // sm3_finish(&conn->sm3_ctx, sm3_hash); if (tls_prf(conn->master_secret, 48, "server finished", sm3_hash, 32, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1 || tls_record_set_handshake_finished(finished_record, &finished_record_len, @@ -1895,7 +2302,7 @@ int tls_send_server_finished(TLS_CONNECT *conn) return -1; } tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0); - if (tls_record_encrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key, + if (tls12_record_encrypt(&conn->server_write_mac_ctx, &conn->server_write_key, conn->server_seq_num, finished_record, finished_record_len, record, &recordlen) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); @@ -1923,7 +2330,8 @@ int tls_recv_server_finished(TLS_CONNECT *conn) uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE]; size_t finished_record_len; - uint8_t sm3_hash[32]; + uint8_t dgst[32]; + size_t dgstlen; const uint8_t *verify_data; size_t verify_data_len; @@ -1931,7 +2339,7 @@ int tls_recv_server_finished(TLS_CONNECT *conn) // Finished - tls_trace("recv Finished\n"); + tls_trace("recv server Finished\n"); if ((ret = tls_recv_record(conn)) != 1) { if (ret != TLS_ERROR_RECV_AGAIN) { error_print(); @@ -1943,23 +2351,20 @@ int tls_recv_server_finished(TLS_CONNECT *conn) tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } - if (conn->recordlen > sizeof(finished_record)) { - error_print(); // 解密可能导致 finished_record 溢出 - tls_send_alert(conn, TLS_alert_bad_record_mac); - return -1; - } - tls_encrypted_record_trace(stderr, conn->record, conn->recordlen, (1<<24), 0); // 强制打印密文原数据 tls_trace("decrypt Finished\n"); - if (tls_record_decrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key, - conn->server_seq_num, conn->record, conn->recordlen, finished_record, &finished_record_len) != 1) { + if (tls12_record_decrypt(&conn->server_write_mac_ctx, &conn->server_write_key, + conn->server_seq_num, conn->record, conn->recordlen, + conn->plain_record, &conn->plain_recordlen) != 1) { error_print(); tls_send_alert(conn, TLS_alert_bad_record_mac); return -1; } - tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0); + tls12_record_print(stderr, conn->plain_record, conn->plain_recordlen, 0, 0); + tls_seq_num_incr(conn->server_seq_num); + if (tls_record_get_handshake_finished(finished_record, &verify_data, &verify_data_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); @@ -1970,9 +2375,15 @@ int tls_recv_server_finished(TLS_CONNECT *conn) tls_send_alert(conn, TLS_alert_unexpected_message); return -1; } - sm3_finish(&conn->sm3_ctx, sm3_hash); - if (tls_prf(conn->master_secret, 48, "server finished", - sm3_hash, 32, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) { + + if (digest_finish(&conn->dgst_ctx, dgst, &dgstlen) != 1) { + error_print(); + return -1; + } + + if (tls12_prf(conn->digest, conn->master_secret, 48, "server finished", + dgst, dgstlen, NULL, 0, + sizeof(local_verify_data), local_verify_data) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; @@ -2276,7 +2687,7 @@ int tls12_do_connect(TLS_CONNECT *conn) fd_set wfds; conn->state = TLS_state_client_hello; - sm3_init(&conn->sm3_ctx); + //sm3_init(&conn->sm3_ctx); digest_init(&conn->dgst_ctx, DIGEST_sm3()); @@ -2316,7 +2727,7 @@ int tls12_do_accept(TLS_CONNECT *conn) conn->state = TLS_state_client_hello; - sm3_init(&conn->sm3_ctx); + //sm3_init(&conn->sm3_ctx); digest_init(&conn->dgst_ctx, DIGEST_sm3()); while (1) { diff --git a/src/tls13.c b/src/tls13.c index 03d9b17f..991c40c1 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -8804,7 +8804,7 @@ int tls13_do_connect(TLS_CONNECT *conn) // 应该把protocol_version的初始化放在这里 conn->state = TLS_state_client_hello; - sm3_init(&conn->sm3_ctx); + //sm3_init(&conn->sm3_ctx); while (1) { @@ -8843,7 +8843,7 @@ int tls13_do_accept(TLS_CONNECT *conn) conn->state = TLS_state_client_hello; - sm3_init(&conn->sm3_ctx); + //sm3_init(&conn->sm3_ctx); fprintf(stderr, "tls13_do_accept\n"); diff --git a/src/tls_trace.c b/src/tls_trace.c index 0f0ece76..04ab6ed2 100644 --- a/src/tls_trace.c +++ b/src/tls_trace.c @@ -626,10 +626,10 @@ int tls_client_hello_print(FILE *fp, const uint8_t *data, size_t datalen, int fo } if (datalen > 0) { if (tls_uint16array_from_bytes(&exts, &exts_len, &data, &datalen) != 1) goto end; - format_print(fp, format, indent, "Extensions\n"); - indent += 4; + tls_extensions_print(fp, exts, exts_len, format, indent); } - // 打印扩展 + + /* while (exts_len > 0) { uint16_t ext_type; const uint8_t *ext_data; @@ -642,10 +642,9 @@ int tls_client_hello_print(FILE *fp, const uint8_t *data, size_t datalen, int fo } format_print(fp, format, indent, "%s (%d)\n", tls_extension_name(ext_type), ext_type); - indent += 4; - tls_extensions_print(fp, exts, exts_len, format, indent); } + */ if (datalen > 0) { error_print(); @@ -700,9 +699,7 @@ int tls_server_hello_print(FILE *fp, const uint8_t *data, size_t datalen, int fo tls_compression_method_name(comp_meth), comp_meth); if (datalen > 0) { if (tls_uint16array_from_bytes(&exts, &exts_len, &data, &datalen) != 1) goto bad; - //format_bytes(fp, format, indent, "Extensions : ", exts, exts_len); // FIXME: extensions_print - //tls_extensions_print(fp, exts, exts_len, format, indent); - //tls13_extensions_print(fp, format, indent, TLS_handshake_server_hello, exts, exts_len); + tls_extensions_print(fp, exts, exts_len, format, indent); } return 1; bad: diff --git a/tools/tls12_help.h b/tools/tls12_help.h index 262d9658..7974fd86 100644 --- a/tools/tls12_help.h +++ b/tools/tls12_help.h @@ -28,5 +28,5 @@ " cat cacert.pem >> certs.pem\n" "\n" " gmssl tls12_server -port 4430 -cert certs.pem -key signkey.pem -pass 1234\n" -" gmssl tls12_client -host 127.0.0.1 -port 4430 -cacert rootcacert.pem\n" +" gmssl tls12_client -host 127.0.0.1 -port 4430 -cipher_suite TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 -supported_group prime256v1 -sig_alg ecdsa_secp256r1_sha256 -cacert rootcacert.pem\n" diff --git a/tools/tls12_server.c b/tools/tls12_server.c index 12c835de..d78f9bf2 100644 --- a/tools/tls12_server.c +++ b/tools/tls12_server.c @@ -24,10 +24,16 @@ static const char *help = "Options\n" "\n" " -port num Listening port number, default 443\n" +" -cipher_suite str Supported cipher suites, may appear multiple times, higher priority first\n" +" -supported_group str Supported elliptic curves, may appear multiple times, higher priority first\n" +" -sig_alg str Supported signature algorithms\n" " -cert file Server's certificate chain in PEM format\n" " -key file Server's encrypted private key in PEM format\n" " -pass str Password to decrypt private key\n" +" -cert_request Client certificate request\n" " -cacert file CA certificate for client certificate verification\n" +" -verify_depth num Certificate verification depth\n" +" -client_cert_optional Allow client send empty Certificate\n" "\n" #include "tls12_help.h" "\n"; @@ -38,10 +44,31 @@ int tls12_server_main(int argc , char **argv) int ret = 1; char *prog = argv[0]; int port = 443; + int cipher_suites[4]; + size_t cipher_suites_cnt = 0; + int supported_groups[4]; + size_t supported_groups_cnt = 0; + int sig_algs[4]; + size_t sig_algs_cnt = 0; + + + char *certfiles[4]; + size_t certfiles_cnt = 0; + char *keyfiles[sizeof(certfiles)/sizeof(certfiles[0])]; + size_t keyfiles_cnt = 0; + char *passes[sizeof(certfiles)/sizeof(certfiles[0])]; + size_t passes_cnt = 0; + + /* char *certfile = NULL; char *keyfile = NULL; char *pass = NULL; + */ + + int cert_request = 0; char *cacertfile = NULL; + int verify_depth = TLS_DEFAULT_VERIFY_DEPTH; + int client_cert_optional = 0; int server_ciphers[] = { TLS_cipher_ecdhe_sm4_cbc_sm3, }; @@ -54,6 +81,8 @@ int tls12_server_main(int argc , char **argv) struct sockaddr_in server_addr; struct sockaddr_in client_addr; + size_t i; + argc--; argv++; @@ -70,6 +99,51 @@ int tls12_server_main(int argc , char **argv) } else if (!strcmp(*argv, "-port")) { if (--argc < 1) goto bad; port = atoi(*(++argv)); + } else if (!strcmp(*argv, "-cipher_suite")) { + char *cipher_suite_name; + int cipher_suite; + if (cipher_suites_cnt >= sizeof(cipher_suites)/sizeof(cipher_suites[0])) { + fprintf(stderr, "%s: too many -cipher_suite options\n", prog); + return -1; + } + if (--argc < 1) goto bad; + cipher_suite_name = *(++argv); + if ((cipher_suite = tls_cipher_suite_from_name(cipher_suite_name)) == 0) { + fprintf(stderr, "%s: invalid -cipher_suite '%s' value\n", prog, cipher_suite_name); + return -1; + } + cipher_suites[cipher_suites_cnt] = cipher_suite; + cipher_suites_cnt++; + } else if (!strcmp(*argv, "-supported_group")) { + char *supported_group_name; + int supported_group; + if (supported_groups_cnt >= sizeof(supported_groups)/sizeof(supported_groups[0])) { + fprintf(stderr, "%s: too many -supported_group options\n", prog); + return -1; + } + if (--argc < 1) goto bad; + supported_group_name = *(++argv); + if ((supported_group = tls_named_curve_from_name(supported_group_name)) == 0) { + fprintf(stderr, "%s: -supported_group '%s' not supported\n", prog, supported_group_name); + return -1; + } + supported_groups[supported_groups_cnt++] = supported_group; + } else if (!strcmp(*argv, "-sig_alg")) { + char *sig_alg_name; + int sig_alg; + if (sig_algs_cnt >= sizeof(sig_algs)/sizeof(sig_algs[0])) { + fprintf(stderr, "%s: too many -sig_alg options\n", prog); + return -1; + } + if (--argc < 1) goto bad; + sig_alg_name = *(++argv); + if ((sig_alg = tls_signature_scheme_from_name(sig_alg_name)) == 0) { + fprintf(stderr, "%s: -sig_alg '%s' not supported\n", prog, sig_alg_name); + return -1; + } + sig_algs[sig_algs_cnt++] = sig_alg; + + /* } else if (!strcmp(*argv, "-cert")) { if (--argc < 1) goto bad; certfile = *(++argv); @@ -79,9 +153,45 @@ int tls12_server_main(int argc , char **argv) } else if (!strcmp(*argv, "-pass")) { if (--argc < 1) goto bad; pass = *(++argv); + */ + + } else if (!strcmp(*argv, "-cert")) { + if (certfiles_cnt >= sizeof(certfiles)/sizeof(certfiles[0])) { + fprintf(stderr, "%s: too many -cert options\n", prog); + return -1; + } + if (--argc < 1) goto bad; + certfiles[certfiles_cnt++] = *(++argv); + } else if (!strcmp(*argv, "-key")) { + if (keyfiles_cnt >= sizeof(keyfiles)/sizeof(keyfiles[0])) { + fprintf(stderr, "%s: too many -key options\n", prog); + return -1; + } + if (--argc < 1) goto bad; + keyfiles[keyfiles_cnt++] = *(++argv); + } else if (!strcmp(*argv, "-pass")) { + if (passes_cnt >= sizeof(passes)/sizeof(passes[0])) { + fprintf(stderr, "%s: too many -pass options\n", prog); + return -1; + } + if (--argc < 1) goto bad; + passes[passes_cnt++] = *(++argv); + + + } else if (!strcmp(*argv, "-cert_request")) { + cert_request = 1; } else if (!strcmp(*argv, "-cacert")) { if (--argc < 1) goto bad; cacertfile = *(++argv); + } else if (!strcmp(*argv, "-verify_depth")) { + if (--argc < 1) goto bad; + verify_depth = atoi(*(++argv)); + if (verify_depth < 1) { + fprintf(stderr, "%s: invalid -verify_depth value '%d'\n", prog, verify_depth); + return -1; + } + } else if (!strcmp(*argv, "-client_cert_optional")) { + client_cert_optional = 1; } else { fprintf(stderr, "%s: invalid option '%s'\n", prog, *argv); return 1; @@ -92,39 +202,102 @@ bad: argc--; argv++; } - if (!certfile) { + + if (!certfiles_cnt) { fprintf(stderr, "%s: '-cert' option required\n", prog); return 1; } - if (!keyfile) { + if (!keyfiles_cnt) { fprintf(stderr, "%s: '-key' option required\n", prog); return 1; } - if (!pass) { + if (!passes_cnt) { fprintf(stderr, "%s: '-pass' option required\n", prog); return 1; } - memset(&ctx, 0, sizeof(ctx)); - memset(&conn, 0, sizeof(conn)); - if (tls_socket_lib_init() != 1) { error_print(); return -1; } - if (tls_ctx_init(&ctx, TLS_protocol_tls12, TLS_server_mode) != 1 - || tls_ctx_set_cipher_suites(&ctx, server_ciphers, sizeof(server_ciphers)/sizeof(int)) != 1 + if (tls_ctx_init(&ctx, TLS_protocol_tls12, TLS_server_mode) != 1) { + error_print(); + return -1; + } + + if (tls_ctx_set_cipher_suites(&ctx, cipher_suites, cipher_suites_cnt) != 1) { + fprintf(stderr, "%s: context init error\n", prog); + goto end; + } + + // supported_groups + if (supported_groups_cnt > 0) { + if (tls_ctx_set_supported_groups(&ctx, supported_groups, supported_groups_cnt) != 1) { + error_print(); + goto end; + } + } + + // signature_algorithms + if (sig_algs_cnt > 0) { + if (tls_ctx_set_signature_algorithms(&ctx, sig_algs, sig_algs_cnt) != 1) { + error_print(); + goto end; + } + } + + + if (certfiles_cnt != keyfiles_cnt || keyfiles_cnt != passes_cnt) { + error_print(); + return -1; + } + // Certificate + for (i = 0; i < certfiles_cnt; i++) { + if (tls_ctx_add_certificate_chain_and_key(&ctx, certfiles[i], keyfiles[i], passes[i]) != 1) { + error_print(); + goto end;; + } + } + + /* + if (tls_ctx_set_cipher_suites(&ctx, server_ciphers, sizeof(server_ciphers)/sizeof(int)) != 1 || tls_ctx_set_certificate_and_key(&ctx, certfile, keyfile, pass) != 1) { error_print(); return -1; } + */ + + // CertificateRequest + if (cert_request) { + if (!cacertfile) { + fprintf(stderr, "%s: -cacert required by -cert_request\n", prog); + goto end; + } + if (tls_ctx_set_ca_certificates(&ctx, cacertfile, verify_depth) != 1) { + error_print(); + goto end; + } + if (tls_ctx_enable_certificate_request(&ctx, 1) != 1) { + error_print(); + goto end; + } + if (client_cert_optional) { + if (tls13_ctx_enable_client_certificate_optional(&ctx, 1) != 1) { + error_print(); + goto end; + } + } + } + + /* if (cacertfile) { if (tls_ctx_set_ca_certificates(&ctx, cacertfile, TLS_DEFAULT_VERIFY_DEPTH) != 1) { error_print(); return -1; } } + */ // Socket