diff --git a/src/tls12.c b/src/tls12.c index da43d99c..2a1756b0 100644 --- a/src/tls12.c +++ b/src/tls12.c @@ -503,6 +503,111 @@ int tls_record_set_handshake_server_key_exchange(uint8_t *record, size_t *record return 1; } +static int tls12_server_ecdh_params_from_bytes(uint8_t *curve_type, uint16_t *named_curve, + const uint8_t **point_octets, size_t *point_octets_len, + const uint8_t **in, size_t *inlen) +{ + if (!curve_type || !named_curve || !point_octets || !point_octets_len + || !in || !(*in) || !inlen) { + error_print(); + return -1; + } + if (tls_uint8_from_bytes(curve_type, in, inlen) != 1 + || tls_uint16_from_bytes(named_curve, in, inlen) != 1 + || tls_uint8array_from_bytes(point_octets, point_octets_len, in, inlen) != 1) { + error_print(); + return -1; + } + if (*curve_type != TLS_curve_type_named_curve) { + error_print(); + return -1; + } + if (!tls_named_curve_name(*named_curve)) { + error_print(); + return -1; + } + if (!*point_octets || !*point_octets_len) { + error_print(); + return -1; + } + return 1; +} + +static int tls12_server_key_exchange_params_from_bytes(int cipher_suite, + const uint8_t **params, size_t *params_len, const uint8_t **in, size_t *inlen) +{ + const uint8_t *start; + + if (!params || !params_len || !in || !(*in) || !inlen) { + error_print(); + return -1; + } + + start = *in; + switch (cipher_suite) { + case TLS_cipher_ecdhe_sm4_cbc_sm3: + case TLS_cipher_ecdhe_sm4_gcm_sm3: + case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: + { + uint8_t curve_type; + uint16_t named_curve; + const uint8_t *point_octets; + size_t point_octets_len; + + if (tls12_server_ecdh_params_from_bytes(&curve_type, &named_curve, + &point_octets, &point_octets_len, in, inlen) != 1) { + error_print(); + return -1; + } + } + break; + default: + error_print(); + return -1; + } + + *params = start; + *params_len = *in - start; + return 1; +} + +static int tls12_record_get_handshake_server_key_exchange(const uint8_t *record, int cipher_suite, + const uint8_t **params, size_t *params_len, + uint16_t *sig_alg, const uint8_t **sig, size_t *siglen) +{ + int type; + const uint8_t *p; + size_t len; + + if (!record || !params || !params_len || !sig_alg || !sig || !siglen) { + error_print(); + return -1; + } + if (tls_record_get_handshake(record, &type, &p, &len) != 1) { + error_print(); + return -1; + } + if (type != TLS_handshake_server_key_exchange) { + error_print(); + return 0; + } + + if (tls12_server_key_exchange_params_from_bytes(cipher_suite, + params, params_len, &p, &len) != 1 + || tls_uint16_from_bytes(sig_alg, &p, &len) != 1 + || tls_uint16array_from_bytes(sig, siglen, &p, &len) != 1 + || tls_length_is_zero(len) != 1) { + error_print(); + return -1; + } + if (!tls_signature_scheme_name(*sig_alg)) { + error_print(); + return -1; + } + + return 1; +} + // 这个函数是有问题的,因为tlcp的格式和TLS不一样 int tls_record_get_handshake_server_key_exchange(const uint8_t *record, uint8_t *curve_type, uint16_t *named_curve, @@ -524,9 +629,8 @@ int tls_record_get_handshake_server_key_exchange(const uint8_t *record, } *server_ecdh_params = p; - if (tls_uint8_from_bytes(curve_type, &p, &len) != 1 - || tls_uint16_from_bytes(named_curve, &p, &len) != 1 - || tls_uint8array_from_bytes(point_octets, point_octets_len, &p, &len) != 1) { + if (tls12_server_ecdh_params_from_bytes(curve_type, named_curve, + point_octets, point_octets_len, &p, &len) != 1) { error_print(); return -1; } @@ -541,14 +645,6 @@ int tls_record_get_handshake_server_key_exchange(const uint8_t *record, error_print(); return -1; } - if (*curve_type != TLS_curve_type_named_curve) { - error_print(); - return -1; - } - if (!tls_named_curve_name(*named_curve)) { - error_print(); - return -1; - } if (!tls_signature_scheme_name(*sig_alg)) { error_print(); return -1; @@ -1975,12 +2071,12 @@ int tls_signature_scheme_match_cipher_suite(int sig_alg, int cipher_suite) int tls_recv_server_key_exchange(TLS_CONNECT *conn) { int ret; - uint8_t curve_type; - uint16_t named_curve; - const uint8_t *point_octets; - size_t point_octets_len; - const uint8_t *server_ecdh_params; - size_t server_ecdh_params_len; + uint8_t curve_type = 0; + uint16_t named_curve = 0; + const uint8_t *point_octets = NULL; + size_t point_octets_len = 0; + const uint8_t *server_ecdh_params = NULL; + size_t server_ecdh_params_len = 0; uint16_t sig_alg; const uint8_t *sig; size_t siglen; @@ -2012,10 +2108,8 @@ int tls_recv_server_key_exchange(TLS_CONNECT *conn) tls12_record_trace(stderr, conn->record, conn->recordlen, 0, 0); - // 这个函数可能是有问题的,如果cipher_suite不同,ServerKeyExchange可能也是不同的 - if ((ret = tls_record_get_handshake_server_key_exchange(conn->record, - &curve_type, &named_curve, &point_octets, &point_octets_len, - &server_ecdh_params, &server_ecdh_params_len, + if ((ret = tls12_record_get_handshake_server_key_exchange(conn->record, + conn->cipher_suite, &server_ecdh_params, &server_ecdh_params_len, &sig_alg, &sig, &siglen)) < 0) { error_print(); tls_send_alert(conn, TLS_alert_decode_error); @@ -2032,8 +2126,32 @@ int tls_recv_server_key_exchange(TLS_CONNECT *conn) } tls_handshake_digest_print(stderr, 0, 0, "ServerKeyExchange", &conn->dgst_ctx); + switch (conn->cipher_suite) { + case TLS_cipher_ecdhe_sm4_cbc_sm3: + case TLS_cipher_ecdhe_sm4_gcm_sm3: + case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: + { + const uint8_t *p = server_ecdh_params; + size_t len = server_ecdh_params_len; + + if (tls12_server_ecdh_params_from_bytes(&curve_type, &named_curve, + &point_octets, &point_octets_len, &p, &len) != 1 + || tls_length_is_zero(len) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_decode_error); + return -1; + } + } + break; + default: + error_print(); + tls_send_alert(conn, TLS_alert_unexpected_message); + return -1; + } + if (curve_type != TLS_curve_type_named_curve) { error_print(); + tls_send_alert(conn, TLS_alert_illegal_parameter); return -1; } // named_curve应该在supported_groups里面 @@ -2049,20 +2167,19 @@ int tls_recv_server_key_exchange(TLS_CONNECT *conn) if (point_octets_len != 65) { error_print(); + tls_send_alert(conn, TLS_alert_illegal_parameter); return -1; } if (tls_curve_match_cipher_suite(named_curve, conn->cipher_suite) != 1) { error_print(); - return -1; - } - if (point_octets_len != 65) { - error_print(); + tls_send_alert(conn, TLS_alert_illegal_parameter); return -1; } if (tls_signature_scheme_match_cipher_suite(sig_alg, conn->cipher_suite) != 1) { error_print(); + tls_send_alert(conn, TLS_alert_illegal_parameter); return -1; } @@ -2130,9 +2247,10 @@ int tls_recv_server_key_exchange(TLS_CONNECT *conn) if (x509_verify_init(&sign_ctx, &server_sign_key, sign_args, sign_argslen, sig, siglen) != 1 || x509_verify_update(&sign_ctx, conn->client_random, 32) != 1 || x509_verify_update(&sign_ctx, conn->server_random, 32) != 1 - || x509_verify_update(&sign_ctx, server_ecdh_params, 69) != 1 + || x509_verify_update(&sign_ctx, server_ecdh_params, server_ecdh_params_len) != 1 || x509_verify_finish(&sign_ctx) != 1) { error_print(); + tls_send_alert(conn, TLS_alert_decrypt_error); return -1; }