From cffee1dd9f9f77bf5f88129b1e28de81fdf016b5 Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Sun, 2 Jun 2024 10:19:24 +0800 Subject: [PATCH] Fix tls12, tls13 bug MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 目前TLS 1.2, 1.3的握手过程中使用了SM2_Z256_POINT,应该改为使用SM2_POINT,可以兼容其他曲线类型,只在做ECDH的时候才判断点的正确性。 --- include/gmssl/sm2_z256.h | 1 - src/tls12.c | 23 ++++++++++++++++------- src/tls13.c | 13 +++++++++---- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/include/gmssl/sm2_z256.h b/include/gmssl/sm2_z256.h index c038c26e..2fcb9936 100644 --- a/include/gmssl/sm2_z256.h +++ b/include/gmssl/sm2_z256.h @@ -138,7 +138,6 @@ int sm2_z256_point_from_octets(SM2_Z256_POINT *P, const uint8_t *in, size_t inle int sm2_z256_point_to_uncompressed_octets(const SM2_Z256_POINT *P, uint8_t out[65]); int sm2_z256_point_to_compressed_octets(const SM2_Z256_POINT *P, uint8_t out[33]); -int sm2_z256_point_from_octets(SM2_Z256_POINT *P, const uint8_t *in, size_t inlen); /* RFC 5480 Elliptic Curve Cryptography Subject Public Key Information diff --git a/src/tls12.c b/src/tls12.c index f48fa7aa..713b0ff3 100644 --- a/src/tls12.c +++ b/src/tls12.c @@ -66,6 +66,8 @@ int tls_record_set_handshake_server_key_exchange_ecdhe(uint8_t *record, size_t * } // 这里返回的应该是一个SM2_Z256_POINT吗? +// 首先得问题就是这里面我们是否有必要去检查这个值是否是椭圆曲线上的一个点,还是把这个值传递给后面 +// 如果在这里直接检查,那么意味着这个函数不兼容其他的椭圆曲线点 int tls_record_get_handshake_server_key_exchange_ecdhe(const uint8_t *record, int *curve, SM2_Z256_POINT *point, const uint8_t **sig, size_t *siglen) { @@ -467,7 +469,12 @@ int tls12_do_connect(TLS_CONNECT *conn) SM2_KEY client_ecdh; sm2_key_generate(&client_ecdh); sm2_do_ecdh(&client_ecdh, &server_ecdhe_public, &server_ecdhe_public); - memcpy(pre_master_secret, &server_ecdhe_public, 32); // 这个做法很不优雅 + + // 需要重新考虑在TLS中是用sm2_do_ecdh还是sm2_ecdh,sm2_ecdh对nistp256的兼容性更好 + uint8_t point_bytes[64]; + sm2_z256_point_to_bytes(&server_ecdhe_public, point_bytes); + + memcpy(pre_master_secret, point_bytes, 32); // 这个做法很不优雅 // ECDHE和ECC的PMS结构是不一样的吗? if (tls_prf(pre_master_secret, 32, "master secret", @@ -484,14 +491,13 @@ int tls12_do_connect(TLS_CONNECT *conn) sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32); sm4_set_encrypt_key(&conn->client_write_enc_key, conn->key_block + 64); sm4_set_decrypt_key(&conn->server_write_enc_key, conn->key_block + 80); - /* + tls_secrets_print(stderr, pre_master_secret, 48, client_random, server_random, conn->master_secret, conn->key_block, 96, 0, 4); - */ // send ClientKeyExchange tls_trace("send ClientKeyExchange\n"); @@ -938,7 +944,11 @@ int tls12_do_accept(TLS_CONNECT *conn) // generate secrets tls_trace("generate secrets\n"); sm2_do_ecdh(&server_ecdhe_key, &client_ecdhe_point, &client_ecdhe_point); - memcpy(pre_master_secret, (uint8_t *)&client_ecdhe_point, 32); // 这里应该修改一下表示方式,比如get_xy() + + uint8_t point_bytes[64]; + sm2_z256_point_to_bytes(&client_ecdhe_point, point_bytes); + memcpy(pre_master_secret, point_bytes, 32); // 这里应该修改一下表示方式,比如get_xy() + tls_prf(pre_master_secret, 32, "master secret", client_random, 32, server_random, 32, 48, conn->master_secret); @@ -949,10 +959,9 @@ int tls12_do_accept(TLS_CONNECT *conn) sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32); sm4_set_decrypt_key(&conn->client_write_enc_key, conn->key_block + 64); sm4_set_encrypt_key(&conn->server_write_enc_key, conn->key_block + 80); - /* - tls_secrets_print(stderr, pre_master_secret, 32, client_random, server_random, + + tls_secrets_print(stderr, pre_master_secret, 48, client_random, server_random, conn->master_secret, conn->key_block, 96, 0, 4); - */ // recv [ChangeCipherSpec] tls_trace("recv [ChangeCipherSpec]\n"); diff --git a/src/tls13.c b/src/tls13.c index 8ef99b3d..5af1dbe6 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -1600,10 +1600,12 @@ int tls13_do_connect(TLS_CONNECT *conn) uint8_t client_write_iv[12] uint8_t server_write_iv[12] */ - sm2_do_ecdh(&client_ecdhe, &server_ecdhe_public, &server_ecdhe_public); + sm2_do_ecdh(&client_ecdhe, &server_ecdhe_public, &server_ecdhe_public); + uint8_t share_point[64]; + sm2_z256_point_to_bytes(&server_ecdhe_public, share_point); /* [1] */ tls13_hkdf_extract(digest, zeros, psk, early_secret); /* [5] */ tls13_derive_secret(early_secret, "derived", &null_dgst_ctx, handshake_secret); - /* [6] */ tls13_hkdf_extract(digest, handshake_secret, (uint8_t *)&server_ecdhe_public, handshake_secret); + /* [6] */ tls13_hkdf_extract(digest, handshake_secret, share_point, handshake_secret); /* [7] */ tls13_derive_secret(handshake_secret, "c hs traffic", &dgst_ctx, client_handshake_traffic_secret); /* [8] */ tls13_derive_secret(handshake_secret, "s hs traffic", &dgst_ctx, server_handshake_traffic_secret); /* [9] */ tls13_derive_secret(handshake_secret, "derived", &null_dgst_ctx, master_secret); @@ -2081,10 +2083,13 @@ int tls13_do_accept(TLS_CONNECT *conn) digest_update(&dgst_ctx, record + 5, recordlen - 5); - sm2_do_ecdh(&server_ecdhe, &client_ecdhe_public, &client_ecdhe_public); + sm2_do_ecdh(&server_ecdhe, &client_ecdhe_public, &client_ecdhe_public); + uint8_t share_point[64];//FIXME: 应该重新考虑TLS中如何使用sm2_do_ecdh还是sm2_ecdh + sm2_z256_point_to_bytes(&client_ecdhe_public, share_point); + /* 1 */ tls13_hkdf_extract(digest, zeros, psk, early_secret); /* 5 */ tls13_derive_secret(early_secret, "derived", &null_dgst_ctx, handshake_secret); - /* 6 */ tls13_hkdf_extract(digest, handshake_secret, (uint8_t *)&client_ecdhe_public, handshake_secret); + /* 6 */ tls13_hkdf_extract(digest, handshake_secret, share_point, handshake_secret); /* 7 */ tls13_derive_secret(handshake_secret, "c hs traffic", &dgst_ctx, client_handshake_traffic_secret); /* 8 */ tls13_derive_secret(handshake_secret, "s hs traffic", &dgst_ctx, server_handshake_traffic_secret); /* 9 */ tls13_derive_secret(handshake_secret, "derived", &null_dgst_ctx, master_secret);