diff --git a/include/gmssl/sm2.h b/include/gmssl/sm2.h index c9ad12cb..5fef50a5 100644 --- a/include/gmssl/sm2.h +++ b/include/gmssl/sm2.h @@ -365,7 +365,8 @@ int sm2_do_encrypt_fixlen(const SM2_KEY *key, const uint8_t *in, size_t inlen, i int sm2_encrypt_fixlen(const SM2_KEY *key, const uint8_t *in, size_t inlen, int point_size, uint8_t *out, size_t *outlen); -int sm2_ecdh(const SM2_KEY *key, const SM2_POINT *peer_public, SM2_POINT *out); +int sm2_do_ecdh(const SM2_KEY *key, const SM2_POINT *peer_public, SM2_POINT *out); +int sm2_ecdh(const SM2_KEY *key, const uint8_t *peer_public, size_t peer_public_len, SM2_POINT *out); typedef uint8_t sm2_bn_t[32]; diff --git a/src/sm2_lib.c b/src/sm2_lib.c index 4767763a..d98260b5 100644 --- a/src/sm2_lib.c +++ b/src/sm2_lib.c @@ -814,7 +814,7 @@ int sm2_decrypt(const SM2_KEY *key, const uint8_t *in, size_t inlen, uint8_t *ou return 1; } -int sm2_ecdh(const SM2_KEY *key, const SM2_POINT *peer_public, SM2_POINT *out) +int sm2_do_ecdh(const SM2_KEY *key, const SM2_POINT *peer_public, SM2_POINT *out) { if (!key || !peer_public || !out) { error_print(); @@ -827,6 +827,22 @@ int sm2_ecdh(const SM2_KEY *key, const SM2_POINT *peer_public, SM2_POINT *out) return 1; } +int sm2_ecdh(const SM2_KEY *key, const uint8_t *peer_public, size_t peer_public_len, SM2_POINT *out) +{ + SM2_POINT point; + + if (sm2_point_from_octets(&point, peer_public, peer_public_len) != 1) { + error_print(); + return -1; + } + if (sm2_do_ecdh(key, &point, out) != 1) { + error_print(); + return -1; + } + return 1; +} + + // (x1, y1) = k * G // r = e + x1 // s = (k - r * d)/(1 + d) = (k +r - r * d - r)/(1 + d) = (k + r - r(1 +d))/(1 + d) = (k + r)/(1 + d) - r diff --git a/src/tls12.c b/src/tls12.c index af9f148a..f9bf6f98 100644 --- a/src/tls12.c +++ b/src/tls12.c @@ -472,7 +472,7 @@ int tls12_do_connect(TLS_CONNECT *conn) tls_trace("generate secrets\n"); SM2_KEY client_ecdh; sm2_key_generate(&client_ecdh); - sm2_ecdh(&client_ecdh, &server_ecdhe_public, &server_ecdhe_public); + sm2_do_ecdh(&client_ecdh, &server_ecdhe_public, &server_ecdhe_public); memcpy(pre_master_secret, &server_ecdhe_public, 32); // 这个做法很不优雅 // ECDHE和ECC的PMS结构是不一样的吗? @@ -942,7 +942,7 @@ int tls12_do_accept(TLS_CONNECT *conn) // generate secrets tls_trace("generate secrets\n"); - sm2_ecdh(&server_ecdhe_key, &client_ecdhe_point, &client_ecdhe_point); + sm2_do_ecdh(&server_ecdhe_key, &client_ecdhe_point, &client_ecdhe_point); memcpy(pre_master_secret, (uint8_t *)&client_ecdhe_point, 32); // 这里应该修改一下表示方式,比如get_xy() tls_prf(pre_master_secret, 32, "master secret", client_random, 32, server_random, 32, diff --git a/src/tls13.c b/src/tls13.c index a347a6f7..4f8d3860 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -1564,7 +1564,7 @@ int tls13_do_connect(TLS_CONNECT *conn) uint8_t client_write_iv[12] uint8_t server_write_iv[12] */ - sm2_ecdh(&client_ecdhe, &server_ecdhe_public, &server_ecdhe_public); + sm2_do_ecdh(&client_ecdhe, &server_ecdhe_public, &server_ecdhe_public); /* [1] */ tls13_hkdf_extract(digest, zeros, psk, early_secret); /* [5] */ tls13_derive_secret(early_secret, "derived", &null_dgst_ctx, handshake_secret); /* [6] */ tls13_hkdf_extract(digest, handshake_secret, (uint8_t *)&server_ecdhe_public, handshake_secret); @@ -2042,7 +2042,7 @@ int tls13_do_accept(TLS_CONNECT *conn) digest_update(&dgst_ctx, record + 5, recordlen - 5); - sm2_ecdh(&server_ecdhe, &client_ecdhe_public, &client_ecdhe_public); + sm2_do_ecdh(&server_ecdhe, &client_ecdhe_public, &client_ecdhe_public); /* 1 */ tls13_hkdf_extract(digest, zeros, psk, early_secret); /* 5 */ tls13_derive_secret(early_secret, "derived", &null_dgst_ctx, handshake_secret); /* 6 */ tls13_hkdf_extract(digest, handshake_secret, (uint8_t *)&client_ecdhe_public, handshake_secret);