From 5eaab7033dca3f12a55fd0b86697b02b06830708 Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Sat, 11 Jun 2022 23:50:54 +0800 Subject: [PATCH] Update TLCP --- include/gmssl/tls.h | 327 ++++++------ include/gmssl/x509.h | 5 + src/asn1.c | 73 +-- src/tlcp.c | 163 ++---- src/tls.c | 1124 ++++++++++++++++++++++++++++++------------ src/tls_trace.c | 19 +- src/x509_cer.c | 65 +++ tools/tlcp_client.c | 123 +++-- tools/tlcp_server.c | 123 +++-- 9 files changed, 1263 insertions(+), 759 deletions(-) diff --git a/include/gmssl/tls.h b/include/gmssl/tls.h index a7f73394..71cedb6a 100644 --- a/include/gmssl/tls.h +++ b/include/gmssl/tls.h @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2020 - 2021 The GmSSL Project. All rights reserved. * * Redistribution and use in source and binary forms, with or without @@ -88,28 +88,29 @@ int tls_uint24array_from_bytes(const uint8_t **data, size_t *datalen, const uint typedef enum { - TLS_version_tls12_major = 3, - TLS_version_tls12_minor = 3, - TLS_version_tlcp = 0x0101, - TLS_version_ssl2 = 0x0200, - TLS_version_ssl3 = 0x0300, - TLS_version_tls1 = 0x0301, - TLS_version_tls11 = 0x0302, - TLS_version_tls12 = 0x0303, - TLS_version_tls13 = 0x0304, - TLS_version_dtls1 = 0xfeff, // {254, 255} - TLS_version_dtls12 = 0xfefd, // {254, 253} + TLS_version_tlcp = 0x0101, + TLS_version_ssl2 = 0x0200, + TLS_version_ssl3 = 0x0300, + TLS_version_tls1 = 0x0301, + TLS_version_tls11 = 0x0302, + TLS_version_tls12 = 0x0303, + TLS_version_tls13 = 0x0304, + TLS_version_dtls1 = 0xfeff, // {254, 255} + TLS_version_dtls12 = 0xfefd, // {254, 253} } TLS_VERSION; const char *tls_version_text(int version); -// 兼容GmSSL 2.5.4 typedef enum { TLS_cipher_null_with_null_null = 0x0000, + + // TLS 1.3, RFC 8998 TLS_cipher_sm4_gcm_sm3 = 0x00c6, TLS_cipher_sm4_ccm_sm3 = 0x00c7, - TLCP_cipher_ecdhe_sm4_cbc_sm3 = 0xe011, // TLCP, TLS 1.2 + + // TLCP, GB/T 38636-2020, GM/T 0024-2012 + TLCP_cipher_ecdhe_sm4_cbc_sm3 = 0xe011, TLCP_cipher_ecdhe_sm4_gcm_sm3 = 0xe051, TLCP_cipher_ecc_sm4_cbc_sm3 = 0xe013, TLCP_cipher_ecc_sm4_gcm_sm3 = 0xe053, @@ -121,20 +122,21 @@ typedef enum { TLCP_cipher_rsa_sm4_gcm_sm3 = 0xe059, TLCP_cipher_rsa_sm4_cbc_sha256 = 0xe01c, TLCP_cipher_rsa_sm4_gcm_sha256 = 0xe05a, + + // GmSSL v2.5 GMSSL_cipher_ecdhe_sm2_with_sm4_sm3 = 0xe102, GMSSL_cipher_ecdhe_sm2_with_sm4_gcm_sm3 = 0xe107, GMSSL_cipher_ecdhe_sm2_with_sm4_ccm_sm3 = 0xe108, GMSSL_cipher_ecdhe_sm2_with_zuc_sm3 = 0xe10d, - TLS_cipher_empty_renegotiation_info_scsv = 0x00ff, - - // TLS 1.3 ciphers (rfc 8446 p.133) - TLS_cipher_aes_128_gcm_sha256 = 0x1301, // mandatory-to-implement + // TLS 1.3 RFC 8446 + TLS_cipher_aes_128_gcm_sha256 = 0x1301, // Mandatory-to-implement TLS_cipher_aes_256_gcm_sha384 = 0x1302, // SHOULD implement TLS_cipher_chacha20_poly1305_sha256 = 0x1303, // SHOULD implement TLS_cipher_aes_128_ccm_sha256 = 0x1304, TLS_cipher_aes_128_ccm_8_sha256 = 0x1305, + TLS_cipher_empty_renegotiation_info_scsv = 0x00ff, } TLS_CIPHER_SUITE; const char *tls_cipher_suite_name(int cipher); @@ -152,7 +154,7 @@ const char *tls_compression_method_name(int meth); typedef enum { - TLS_record_invalid = 0, // TLS 1.3 + TLS_record_invalid = 0, // TLS 1.3 TLS_record_change_cipher_spec = 20, // 0x14 TLS_record_alert = 21, // 0x15 TLS_record_handshake = 22, // 0x16 @@ -163,6 +165,7 @@ typedef enum { const char *tls_record_type_name(int type); + typedef enum { TLS_handshake_hello_request = 0, TLS_handshake_client_hello = 1, @@ -191,8 +194,6 @@ typedef enum { const char *tls_handshake_type_name(int type); - -// 这里面应该有一个和OID类型的转换 typedef enum { TLS_cert_type_rsa_sign = 1, TLS_cert_type_dss_sign = 2, @@ -213,7 +214,7 @@ const char *tls_cert_type_name(int type); int tls_cert_type_from_oid(int oid); typedef enum { - TLS_extension_server_name = 0, // tls 1.3 mandatory-to-implement + TLS_extension_server_name = 0, TLS_extension_max_fragment_length = 1, TLS_extension_client_certificate_url = 2, TLS_extension_trusted_ca_keys = 3, @@ -222,11 +223,11 @@ typedef enum { TLS_extension_user_mapping = 6, TLS_extension_client_authz = 7, TLS_extension_server_authz = 8, - TLS_extension_cert_type = 9, // 这个是支持服务器证书的类型吗?仅仅用CIPHER_SUITE不够吗? - TLS_extension_supported_groups = 10, // 必须支持 - TLS_extension_ec_point_formats = 11, // 必须支持 + TLS_extension_cert_type = 9, + TLS_extension_supported_groups = 10, + TLS_extension_ec_point_formats = 11, TLS_extension_srp = 12, - TLS_extension_signature_algorithms = 13, // // tls 1.3 mandatory-to-implement + TLS_extension_signature_algorithms = 13, TLS_extension_use_srtp = 14, TLS_extension_heartbeat = 15, TLS_extension_application_layer_protocol_negotiation= 16, @@ -235,8 +236,8 @@ typedef enum { TLS_extension_client_certificate_type = 19, TLS_extension_server_certificate_type = 20, TLS_extension_padding = 21, - TLS_extension_encrypt_then_mac = 22, // 应该支持 - TLS_extension_extended_master_secret = 23, // 这个是什么意思? + TLS_extension_encrypt_then_mac = 22, + TLS_extension_extended_master_secret = 23, TLS_extension_token_binding = 24, TLS_extension_cached_info = 25, TLS_extension_tls_lts = 26, @@ -248,20 +249,20 @@ typedef enum { TLS_extension_ticket_pinning = 32, TLS_extension_tls_cert_with_extern_psk = 33, TLS_extension_delegated_credentials = 34, - TLS_extension_session_ticket = 35, // 应该支持 + TLS_extension_session_ticket = 35, TLS_extension_TLMSP = 36, TLS_extension_TLMSP_proxying = 37, TLS_extension_TLMSP_delegate = 38, TLS_extension_supported_ekt_ciphers = 39, TLS_extension_pre_shared_key = 41, TLS_extension_early_data = 42, - TLS_extension_supported_versions = 43, // tls 1.3 mandatory-to-implement - TLS_extension_cookie = 44, // tls 1.3 mandatory-to-implement + TLS_extension_supported_versions = 43, + TLS_extension_cookie = 44, TLS_extension_psk_key_exchange_modes = 46, TLS_extension_certificate_authorities = 47, TLS_extension_oid_filters = 48, TLS_extension_post_handshake_auth = 49, - TLS_extension_signature_algorithms_cert = 50, // tls 1.3 mandatory-to-implement + TLS_extension_signature_algorithms_cert = 50, TLS_extension_key_share = 51, TLS_extension_transparency_info = 52, TLS_extension_connection_id = 53, @@ -298,7 +299,7 @@ typedef enum { TLS_curve_brainpoolp384r1 = 27, TLS_curve_brainpoolp512r1 = 28, TLS_curve_x25519 = 29, - TLS_curve_x448 = 99, //30, + TLS_curve_x448 = 99, //30, 应该用一个宏来处理 TLS_curve_brainpoolp256r1tls13 = 31, TLS_curve_brainpoolp384r1tls13 = 32, TLS_curve_brainpoolp512r1tls13 = 33, @@ -307,6 +308,7 @@ typedef enum { const char *tls_named_curve_name(int curve); + typedef enum { TLS_sig_rsa_pkcs1_sha1 = 0x0201, TLS_sig_ecdsa_sha1 = 0x0203, @@ -335,41 +337,46 @@ typedef enum { const char *tls_signature_scheme_name(int scheme); + typedef enum { TLS_change_cipher_spec = 1, } TLS_CHANGE_CIPHER_SPEC_TYPE; + typedef enum { TLS_alert_level_warning = 1, TLS_alert_level_fatal = 2, } TLS_ALERT_LEVEL; +const char *tls_alert_level_name(int level); + + typedef enum { - TLS_alert_close_notify = 0, // Fatal - TLS_alert_unexpected_message = 10, // Fatal 和正确实现的对方交互时不应出现此错误 - TLS_alert_bad_record_mac = 20, // Fatal 密文Mac验证错误、CBC密文长度错误、CBC密文填充错误 - TLS_alert_decryption_failed = 21, // 作废 - TLS_alert_record_overflow = 22, // Fatal TLSCiphertext.length > 2^14 + 2048, TLSCompressed.length > 2^14 + 1024 - TLS_alert_decompression_failure = 30, // 本实现不支持压缩 - TLS_alert_handshake_failure = 40, // Fatal, 安全参数协商失败,TLCP服务器没有找到合适的套件 - TLS_alert_no_certificate = 41, // 作废 - TLS_alert_bad_certificate = 42, // Any 我们使用Fatal - TLS_alert_unsupported_certificate = 43, // Any 我们使用Fatal - TLS_alert_certificate_revoked = 44, // Any 我们使用Fatal - TLS_alert_certificate_expired = 45, // Any 我们使用Fatal - TLS_alert_certificate_unknown = 46, // Any 我们使用Fatal, 大概没有CA证书 - TLS_alert_illegal_parameter = 47, // Fatal, 似乎TLCP不会遇到此情况 - TLS_alert_unknown_ca = 48, // Fatal - TLS_alert_access_denied = 49, // Fatal,??? - TLS_alert_decode_error = 50, // Fatal, 正确实现不会出现此问题,可能由网络故障导致 - TLS_alert_decrypt_error = 51, // Fatal, 验签失败、Finished验证失败 - TLS_alert_export_restriction = 60, // 作废 - TLS_alert_protocol_version = 70, // Fatal - TLS_alert_insufficient_security = 71, // Fatal,如果客户端Ciphers均强度不足,则服务器返回此错误 - TLS_alert_internal_error = 80, // Fatal - TLS_alert_user_canceled = 90, // Warning, 一般后面要跟一个close_notify,似乎没有必要 - TLS_alert_no_renegotiation = 100, // Warning,客户端收到HelloRequest, 服务器在握手后收到ClientHello - TLS_alert_unsupported_extension = 110, // Fatal, 服务器ServerHello返回不在ClientHello范围内的扩展 + TLS_alert_close_notify = 0, + TLS_alert_unexpected_message = 10, + TLS_alert_bad_record_mac = 20, + TLS_alert_decryption_failed = 21, + TLS_alert_record_overflow = 22, + TLS_alert_decompression_failure = 30, + TLS_alert_handshake_failure = 40, + TLS_alert_no_certificate = 41, + TLS_alert_bad_certificate = 42, + TLS_alert_unsupported_certificate = 43, + TLS_alert_certificate_revoked = 44, + TLS_alert_certificate_expired = 45, + TLS_alert_certificate_unknown = 46, + TLS_alert_illegal_parameter = 47, + TLS_alert_unknown_ca = 48, + TLS_alert_access_denied = 49, + TLS_alert_decode_error = 50, + TLS_alert_decrypt_error = 51, + TLS_alert_export_restriction = 60, + TLS_alert_protocol_version = 70, + TLS_alert_insufficient_security = 71, + TLS_alert_internal_error = 80, + TLS_alert_user_canceled = 90, + TLS_alert_no_renegotiation = 100, + TLS_alert_unsupported_extension = 110, TLS_alert_unsupported_site2site = 200, TLS_alert_no_area = 201, TLS_alert_unsupported_areatype = 202, @@ -378,42 +385,25 @@ typedef enum { TLS_alert_identity_need = 205, } TLS_ALERT_DESCRIPTION; -/* -TLCP ServerCertificate - 如果不是SM2的证书,那么返回 unsupported_certificate - 如果证书链有问题,比如不是双证书(比如少一个加密证书),bad_certificate - 如果证书本身验证错误 bad_certificate - 如果证书链没有对应的ROOTCA证书,那么返回certificate_unknown - 如果证书过期 - 如果证书作废:必须要结合CRL等 - 如果证书扩展没有通过验证,返回bad_certificate,Warning/Fatal -*/ +const char *tls_alert_description_text(int description); -// Ciphers - -int tls_seq_num_incr(uint8_t seq_num[8]); int tls_prf(const uint8_t *secret, size_t secretlen, const char *label, const uint8_t *seed, size_t seedlen, const uint8_t *more, size_t morelen, size_t outlen, uint8_t *out); - int tls13_hkdf_extract(const DIGEST *digest, const uint8_t salt[32], const uint8_t in[32], uint8_t out[32]); int tls13_hkdf_expand_label(const DIGEST *digest, const uint8_t secret[32], const char *label, const uint8_t *context, size_t context_len, size_t outlen, uint8_t *out); int tls13_derive_secret(const uint8_t secret[32], const char *label, const DIGEST_CTX *dgst_ctx, uint8_t out[32]); -#define TLS_MAX_PADDING_SIZE (1 + 255) -#define TLS_MAC_SIZE 32 // 目前只支持SM3、SHA256 - int tls_cbc_encrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *enc_key, const uint8_t seq_num[8], const uint8_t header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); int tls_cbc_decrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *dec_key, const uint8_t seq_num[8], const uint8_t header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); - int tls_record_encrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key, const uint8_t seq_num[8], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); @@ -421,9 +411,9 @@ int tls_record_decrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key, const uint8_t seq_num[8], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); +int tls_seq_num_incr(uint8_t seq_num[8]); int tls_random_generate(uint8_t random[32]); int tls_random_print(FILE *fp, const uint8_t random[32], int format, int indent); - int tls_pre_master_secret_generate(uint8_t pre_master_secret[48], int version); int tls_pre_master_secret_print(FILE *fp, const uint8_t pre_master_secret[48], int format, int indent); @@ -434,35 +424,37 @@ int tls_secrets_print(FILE *fp, const uint8_t *key_block, size_t key_block_len, int format, int indent); -// Record typedef struct { uint8_t type; uint8_t version[2]; - uint8_t length[2]; + uint8_t data_length[2]; } TLS_RECORD_HEADER; -#define TLS_RECORD_HEADER_SIZE 5 -#define TLS_MAX_PLAINTEXT_SIZE (1 << 14) // 2^14 = 16384 -#define TLS_MAX_COMPRESSED_SIZE ((1 << 14) + 1024) // 17408 -#define TLS_MAX_CIPHERTEXT_SIZE ((1 << 14) + 2048) // 18432 -#define TLS_MAX_RECORD_SIZE (TLS_RECORD_HEADER_SIZE + TLS_MAX_CIPHERTEXT_SIZE) // 18437 +#define TLS_RECORD_HEADER_SIZE (1 + tls_uint16_size() + tls_uint16_size()) // 5 +#define TLS_MAX_PLAINTEXT_SIZE (1 << 14) // 16384 +#define TLS_MAX_COMPRESSED_SIZE ((1 << 14) + 1024) // 17408 +#define TLS_MAX_CIPHERTEXT_SIZE ((1 << 14) + 2048) // 18432 +#define TLS_MAX_RECORD_SIZE (TLS_RECORD_HEADER_SIZE + TLS_MAX_CIPHERTEXT_SIZE) // 18437 -int tls_record_type(const uint8_t *record); -int tls_record_version(const uint8_t *record); -int tls_record_length(const uint8_t *record); -#define tls_record_data(record) ((record)+4) +#define tls_record_type(record) ((record)[0]) +#define tls_record_header(record) ((record)+0) +#define tls_record_version(record) (((uint16_t)((record)[1]) << 8) | (record)[2]) +#define tls_record_data(record) ((record)+TLS_RECORD_HEADER_SIZE) +#define tls_record_data_length(record) (((uint16_t)((record)[3]) << 8) | (record)[4]) +#define tls_record_length(record) (TLS_RECORD_HEADER_SIZE + tls_record_data_length(record)) int tls_record_set_type(uint8_t *record, int type); int tls_record_set_version(uint8_t *record, int version); +int tls_record_set_data_length(uint8_t *record, size_t length); +int tls_record_set_data(uint8_t *record, const uint8_t *data, size_t datalen); -// format -// 0-7 比特表示通用输出格式 -// 8-16 比特表示密码套件 format |= (cipher_suite << 8) -// 因为握手消息ServerKeyExchange, ClientKeyExchange的解析依赖当前密码套件 +// 握手消息ServerKeyExchange, ClientKeyExchange的解析依赖当前密码套件 +#define tls_format_set_cipher_suite(fmt,cipher) do {(fmt)|=((cipher)<<8);} while (0) int tls_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent); int tlcp_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent); + int tls_record_send(const uint8_t *record, size_t recordlen, int sock); int tls_record_recv(uint8_t *record, size_t *recordlen, int sock); int tls12_record_recv(uint8_t *record, size_t *recordlen, int sock); @@ -478,20 +470,17 @@ typedef struct { #define TLS_HANDSHAKE_HEADER_SIZE 4 #define TLS_MAX_HANDSHAKE_DATA_SIZE (TLS_MAX_PLAINTEXT_SIZE - TLS_HANDSHAKE_HEADER_SIZE) -#define tls_handshake_data(p) ((p)+5) +#define tls_handshake_data(p) ((p) + TLS_HANDSHAKE_HEADER_SIZE) +//#define tls_handshake_data_length(p) + + int tls_record_set_handshake(uint8_t *record, size_t *recordlen, int type, const uint8_t *data, size_t datalen); int tls_record_get_handshake(const uint8_t *record, int *type, const uint8_t **data, size_t *datalen); int tls_handshake_print(FILE *fp, const uint8_t *handshake, size_t handshakelen, int format, int indent); -/* -HelloRequest -TLCP 服务器均不支持 HelloRequest 握手消息 -TLS 服务器 - 在握手阶段收到 HelloRequest 消息则中止握手 - 在建立连接之后收到,忽略或者同时返回一个no_renegotiation warning alert -*/ +// HelloRequest int tls_hello_request_print(FILE *fp, const uint8_t *data, size_t datalen, int format, int indent); // ClientHello, ServerHello @@ -523,7 +512,6 @@ int tls_server_hello_print(FILE *fp, const uint8_t *server_hello, size_t len, in int tls_ext_signature_algors_to_bytes(const int *algors, size_t algors_count, uint8_t **out, size_t *outlen); - // Certificate int tls_record_set_handshake_certificate(uint8_t *record, size_t *recordlen, const uint8_t *certs, size_t certslen); @@ -534,6 +522,7 @@ int tls_record_get_handshake_certificate(const uint8_t *record, uint8_t *certs, // ServerKeyExchange int tls_server_key_exchange_print(FILE *fp, const uint8_t *ske, size_t skelen, int format, int indent); +#define TLS_MAX_SIGNATURE_SIZE SM2_MAX_SIGNATURE_SIZE int tls_sign_server_ecdh_params(const SM2_KEY *server_sign_key, const uint8_t client_random[32], const uint8_t server_random[32], int curve, const SM2_POINT *point, uint8_t *sig, size_t *siglen); @@ -556,8 +545,8 @@ int tlcp_server_key_exchange_pke_print(FILE *fp, const uint8_t *sig, size_t sigl // CertificateRequest -#define TLS_MAX_CERTIFICATE_TYPES 16 -#define TLS_MAX_CA_NAMES_SIZE 512 +#define TLS_MAX_CERTIFICATE_TYPES 256 +#define TLS_MAX_CA_NAMES_SIZE (TLS_MAX_HANDSHAKE_DATA_SIZE - tls_uint8_size() - tls_uint16_size()) int tls_authorities_from_certs(uint8_t *ca_names, size_t *ca_names_len, size_t maxlen, const uint8_t *certs, size_t certslen); int tls_authorities_issued_certificate(const uint8_t *ca_names, size_t ca_namelen, const uint8_t *certs, size_t certslen); @@ -623,6 +612,8 @@ void tls_client_verify_cleanup(TLS_CLIENT_VERIFY_CTX *ctx); // Finished #define TLS_VERIFY_DATA_SIZE 12 // TLS 1.3或者其他版本支持更长的verify_data #define TLS_FINISHED_RECORD_SIZE (TLS_RECORD_HEADER_SIZE + TLS_HANDSHAKE_HEADER_SIZE + TLS_VERIFY_DATA_SIZE) // 21 +#define TLS_MAX_PADDING_SIZE (1 + 255) +#define TLS_MAC_SIZE SM3_HMAC_SIZE #define TLS_FINISHED_RECORD_BUF_SIZE (TLS_FINISHED_RECORD_SIZE + TLS_MAC_SIZE + TLS_MAX_PADDING_SIZE) // 309 int tls_record_set_handshake_finished(uint8_t *record, size_t *recordlen, @@ -640,12 +631,9 @@ typedef struct { #define TLS_ALERT_RECORD_SIZE (TLS_RECORD_HEADER_SIZE + 2) -const char *tls_alert_level_name(int level); -const char *tls_alert_description_text(int description); -int tls_alert_print(FILE *fp, const uint8_t *data, size_t datalen, int format, int indent); - int tls_record_set_alert(uint8_t *record, size_t *recordlen, int alert_level, int alert_description); int tls_record_get_alert(const uint8_t *record, int *alert_level, int *alert_description); +int tls_alert_print(FILE *fp, const uint8_t *data, size_t datalen, int format, int indent); // ChangeCipherSpec @@ -667,37 +655,56 @@ int tls_application_data_print(FILE *fp, const uint8_t *data, size_t datalen, in +enum { + TLS_server_mode = 0, + TLS_client_mode = 1, +}; + +#define TLS_MAX_CIPHER_SUITES_COUNT 64 + typedef struct { - int protocol_versions[4]; - size_t protocol_versions_cnt; - int cipher_suites[8]; - size_t cipher_suits_cnt; - uint8_t certs[4096]; - size_t certslen; - SM2_KEY key; - SM2_KEY ex_key; - uint8_t cacerts[2048]; + int protocol_version; + int is_client; + int cipher_suites[TLS_MAX_CIPHER_SUITES_COUNT]; + size_t cipher_suites_cnt; + uint8_t *cacerts; size_t cacertslen; - int shutdown_mode; + uint8_t *certs; + size_t certslen; + SM2_KEY signkey; + SM2_KEY kenckey; + int verify_depth; } TLS_CTX; -int tls_ctx_set_protocol_versions(TLS_CTX *ctx, const int *versions, size_t versions_cnt); -int tls_ctx_set_cipher_suites(TLS_CTX *ctx, const char *ciphers); -int tls_ctx_set_certificats_and_keys(TLS_CTX *ctx, FILE *certs_fp, FILE *key_fp, const char *pass, FILE *ex_key_fp, const char *ex_pass); -int tls_ctx_set_ca_certificates(TLS_CTX *ctx, FILE *fp, int depth); -int tls_ctx_set_crl(TLS_CTX *ctx, FILE *fp); -int tls_ctx_set_client_verify_ca_certificates(TLS_CTX *ctx, FILE *fp, int depth); -int tls_ctx_set_shutdown_mode(TLS_CTX *ctx, int mode); +int tls_ctx_init(TLS_CTX *ctx, int version, int is_client); +int tls_ctx_set_cipher_suites(TLS_CTX *ctx, const int *cipher_suites, size_t cipher_suites_cnt); +int tls_ctx_set_ca_certificates(TLS_CTX *ctx, const char *cacertsfile, int depth); +int tls_ctx_set_certificate_and_key(TLS_CTX *ctx, const char *chainfile, + const char *keyfile, const char *keypass); +int tls_ctx_set_tlcp_server_certificate_and_keys(TLS_CTX *ctx, const char *chainfile, + const char *signkeyfile, const char *signkeypass, + const char *kenckeyfile, const char *kenckeypass); +void tls_ctx_cleanup(TLS_CTX *ctx); #define TLS_MAX_CERTIFICATES_SIZE 2048 +#define TLS_DEFAULT_VERIFY_DEPTH 4 +#define TLS_MAX_VERIFY_DEPTH 5 + typedef struct { - int sock; - int is_client; - int version; + int is_client; + int cipher_suites[TLS_MAX_CIPHER_SUITES_COUNT]; + size_t cipher_suites_cnt; + + int sock; + + uint8_t record[TLS_MAX_RECORD_SIZE]; + uint8_t data[TLS_MAX_PLAINTEXT_SIZE]; + size_t datalen; + int cipher_suite; uint8_t session_id[32]; size_t session_id_len; @@ -708,7 +715,10 @@ typedef struct { uint8_t ca_certs[2048]; size_t ca_certs_len; - int client_cert_verify_result; + SM2_KEY sign_key; + SM2_KEY kenc_key; + + int verify_result; uint8_t master_secret[48]; uint8_t key_block[96]; @@ -720,60 +730,43 @@ typedef struct { uint8_t client_seq_num[8]; uint8_t server_seq_num[8]; - // 这个有点问题,我们暂时还是不支持 - BLOCK_CIPHER_KEY client_write_key; // used in tls13.c - BLOCK_CIPHER_KEY server_write_key; // used in tls13.c uint8_t client_write_iv[12]; // tls13 uint8_t server_write_iv[12]; // tls13 - } TLS_CONNECT; int tls_init(TLS_CONNECT *conn, const TLS_CTX *ctx); -int tls_set_fd(TLS_CONNECT *conn, int sock); -int tls_get_verify_result(TLS_CONNECT *conn, int *result); +int tls_set_socket(TLS_CONNECT *conn, int sock); +int tls_do_handshake(TLS_CONNECT *conn); +int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen); +int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen); +int tls_shutdown(TLS_CONNECT *conn); +void tls_cleanup(TLS_CONNECT *conn); -int tls_send_alert (TLS_CONNECT *conn, int alert); +int tlcp_do_connect(TLS_CONNECT *conn); +int tlcp_do_accept(TLS_CONNECT *conn); +int tls12_do_connect(TLS_CONNECT *conn); +int tls12_do_accept(TLS_CONNECT *conn); +int tls13_do_connect(TLS_CONNECT *conn); +int tls13_do_accept(TLS_CONNECT *conn); + +int tls_send_alert(TLS_CONNECT *conn, int alert); int tls_send_warning(TLS_CONNECT *conn, int alert); - -int tlcp_connect(TLS_CONNECT *conn, const char *hostname, int port, - FILE *ca_certs_fp, FILE *client_certs_fp, const SM2_KEY *client_sign_key); - -int tlcp_accept(TLS_CONNECT *conn, int port, - FILE *server_certs_fp, const SM2_KEY *server_sign_key, const SM2_KEY *server_enc_key, - FILE *client_cacerts_fp, uint8_t *client_cert_verify_buf, size_t client_cert_verify_buflen); - -int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, - FILE *ca_certs_fp, FILE *client_certs_fp, const SM2_KEY *client_sign_key); -int tls12_accept(TLS_CONNECT *conn, int port, - FILE *certs_fp, const SM2_KEY *server_sign_key, - FILE *client_cacerts_fp, uint8_t *handshakes_buf, size_t handshakes_buflen); - -int tls13_connect(TLS_CONNECT *conn, const char *hostname, int port, - FILE *ca_certs_fp, FILE *client_certs_fp, const SM2_KEY *client_sign_key); -int tls13_accept(TLS_CONNECT *conn, int port, - FILE *certs_fp, const SM2_KEY *server_sign_key, - FILE *client_cacerts_fp); - -int tls_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen); -int tls_recv(TLS_CONNECT *conn, uint8_t *data, size_t *datalen); - int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t padding_len); int tls13_recv(TLS_CONNECT *conn, uint8_t *data, size_t *datalen); -int tls_shutdown(TLS_CONNECT *conn); +#define TLS_DEBUG - -#if 1 -#define tls_trace(s) fprintf(stderr,(s)) -#define tls_record_trace(fp,rec,reclen,fmt,ind) tls_record_print(fp,rec,reclen,fmt,ind) -#define tlcp_record_trace(fp,rec,reclen,fmt,ind) tlcp_record_print(fp,rec,reclen,fmt,ind) +#ifdef TLS_DEBUG +# define tls_trace(s) fprintf(stderr,(s)) +# define tls_record_trace(fp,rec,reclen,fmt,ind) tls_record_print(fp,rec,reclen,fmt,ind) +# define tlcp_record_trace(fp,rec,reclen,fmt,ind) tlcp_record_print(fp,rec,reclen,fmt,ind) #else -#define tls_trace(s) -#define tls_record_trace(fp,rec,reclen,fmt,ind) -#define tlcp_record_trace(fp,rec,reclen,fmt,ind) +# define tls_trace(s) +# define tls_record_trace(fp,rec,reclen,fmt,ind) +# define tlcp_record_trace(fp,rec,reclen,fmt,ind) #endif diff --git a/include/gmssl/x509.h b/include/gmssl/x509.h index 1b95bc32..43e41d69 100644 --- a/include/gmssl/x509.h +++ b/include/gmssl/x509.h @@ -379,6 +379,11 @@ int x509_certs_get_subjects(const uint8_t *certs, size_t certslen, uint8_t *name int x509_certs_print(FILE *fp, int fmt, int ind, const char *label, const uint8_t *d, size_t dlen); +int x509_cert_new_from_file(uint8_t **out, size_t *outlen, const char *file); +int x509_certs_new_from_file(uint8_t **out, size_t *outlen, const char *file); + + + #ifdef __cplusplus } diff --git a/src/asn1.c b/src/asn1.c index 674f933f..7c0d0a91 100644 --- a/src/asn1.c +++ b/src/asn1.c @@ -173,7 +173,7 @@ int asn1_ia5_string_check(const char *a, size_t alen) int asn1_tag_to_der(int tag, uint8_t **out, size_t *outlen) { - if (out) { + if (out && *out) { *(*out)++ = (uint8_t)tag; } (*outlen)++; @@ -183,7 +183,7 @@ int asn1_tag_to_der(int tag, uint8_t **out, size_t *outlen) int asn1_length_to_der(size_t len, uint8_t **out, size_t *outlen) { if (len < 128) { - if (out) { + if (out && *out) { *(*out)++ = (uint8_t)len; } (*outlen)++; @@ -198,7 +198,7 @@ int asn1_length_to_der(size_t len, uint8_t **out, size_t *outlen) else if (len < (1 << 24)) i = 3; else i = 4; - if (out) { + if (out && *out) { *(*out)++ = 0x80 + i; memcpy(*out, buf + 4 - i, i); (*out) += i; @@ -211,7 +211,7 @@ int asn1_length_to_der(size_t len, uint8_t **out, size_t *outlen) // 提供返回值是为了和其他to_der函数一致 int asn1_data_to_der(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen) { - if (out) { + if (out && *out) { memcpy(*out, data, datalen); *out += datalen; } @@ -301,7 +301,7 @@ int asn1_data_from_der(const uint8_t **data, size_t datalen, const uint8_t **in, int asn1_header_to_der(int tag, size_t len, uint8_t **out, size_t *outlen) { - if ((out && !(*out)) || !outlen) { + if (!outlen) { error_print(); return -1; } @@ -429,7 +429,8 @@ int asn1_boolean_from_name(int *val, const char *name) int asn1_boolean_to_der_ex(int tag, int val, uint8_t **out, size_t *outlen) { - if ((out && !(*out)) || !outlen) { + if (!outlen) { + error_print(); return -1; } @@ -437,7 +438,7 @@ int asn1_boolean_to_der_ex(int tag, int val, uint8_t **out, size_t *outlen) return 0; } - if (out) { + if (out && *out) { *(*out)++ = tag; *(*out)++ = 0x01; *(*out)++ = val ? 0xff : 0x00; @@ -448,22 +449,20 @@ int asn1_boolean_to_der_ex(int tag, int val, uint8_t **out, size_t *outlen) int asn1_integer_to_der_ex(int tag, const uint8_t *a, size_t alen, uint8_t **out, size_t *outlen) { - if (!a) { - return 0; - } - - - - if (alen <= 0 || alen > INT_MAX || (out && !(*out)) || !outlen) { + if (!outlen) { error_print(); return -1; } + if (alen <= 0 || alen > INT_MAX) { + error_print(); + return -1; + } + if (!a) { + return 0; + } - - - - if (out) + if (out && *out) *(*out)++ = tag; (*outlen)++; @@ -474,7 +473,7 @@ int asn1_integer_to_der_ex(int tag, const uint8_t *a, size_t alen, uint8_t **out if (a[0] & 0x80) { asn1_length_to_der(alen + 1, out, outlen); - if (out) { + if (out && *out) { *(*out)++ = 0x00; memcpy(*out, a, alen); (*out) += alen; @@ -482,7 +481,7 @@ int asn1_integer_to_der_ex(int tag, const uint8_t *a, size_t alen, uint8_t **out (*outlen) += 1 + alen; } else { asn1_length_to_der(alen, out ,outlen); - if (out) { + if (out && *out) { memcpy(*out, a, alen); (*out) += alen; } @@ -571,11 +570,11 @@ const char *asn1_null_name(void) int asn1_null_to_der(uint8_t **out, size_t *outlen) { - if ((out && !(*out)) || !outlen) { + if (!outlen) { + error_print(); return -1; } - - if (out) { + if (out && *out) { *(*out)++ = ASN1_TAG_NULL; *(*out)++ = 0x00; } @@ -597,7 +596,7 @@ static void asn1_oid_node_to_base128(uint32_t a, uint8_t **out, size_t *outlen) } while (n--) { - if (out) + if (out && *out) *(*out)++ = buf[n]; (*outlen)++; } @@ -639,10 +638,14 @@ static int asn1_oid_node_from_base128(uint32_t *a, const uint8_t **in, size_t *i int asn1_object_identifier_to_octets(const uint32_t *nodes, size_t nodes_cnt, uint8_t *out, size_t *outlen) { + if (!outlen) { + error_print(); + return -1; + } if (nodes_cnt < 2 || nodes_cnt > 32) { return -1; } - if (out) + if (out && *out) *out++ = (uint8_t)(nodes[0] * 40 + nodes[1]); (*outlen) = 1; nodes += 2; @@ -705,11 +708,12 @@ int asn1_object_identifier_to_der_ex(int tag, const uint32_t *nodes, size_t node uint8_t octets[32]; size_t octetslen = 0; - if ((out && !(*out)) || !outlen) { + if (!outlen) { + error_print(); return -1; } - if (out) + if (out && *out) *(*out)++ = tag; (*outlen)++; @@ -717,7 +721,7 @@ int asn1_object_identifier_to_der_ex(int tag, const uint32_t *nodes, size_t node asn1_length_to_der(octetslen, out, outlen); - if (out) { + if (out && *out) { // 注意:If out == NULL, *out ==> Segment Fault memcpy(*out, octets, octetslen); *out += octetslen; @@ -824,18 +828,19 @@ int asn1_utc_time_to_der_ex(int tag, time_t a, uint8_t **out, size_t *outlen) struct tm tm_val; char buf[ASN1_UTC_TIME_LEN + 1]; - if ((out && !(*out)) || !outlen) { + if (!outlen) { + error_print(); return -1; } gmtime_r(&a, &tm_val); strftime(buf, sizeof(buf), "%y%m%d%H%M%SZ", &tm_val); - if (out) + if (out && *out) *(*out)++ = tag; (*outlen)++; asn1_length_to_der(sizeof(buf)-1, out, outlen); - if (out) { + if (out && *out) { memcpy(*out, buf, sizeof(buf)-1); (*out) += sizeof(buf)-1; } @@ -850,7 +855,7 @@ int asn1_generalized_time_to_der_ex(int tag, time_t a, uint8_t **out, size_t *ou struct tm tm_val; char buf[ASN1_GENERALIZED_TIME_LEN + 1]; - if ((out && !(*out)) || !outlen) { + if (!outlen) { error_print(); return -1; } @@ -859,11 +864,11 @@ int asn1_generalized_time_to_der_ex(int tag, time_t a, uint8_t **out, size_t *ou strftime(buf, sizeof(buf), "%Y%m%d%H%M%SZ", &tm_val); //printf("%s %d: generalized time : %s\n", __FILE__, __LINE__, buf); - if (out) + if (out && *out) *(*out)++ = tag; (*outlen)++; asn1_length_to_der(ASN1_GENERALIZED_TIME_LEN, out, outlen); - if (out) { + if (out && *out) { memcpy(*out, buf, ASN1_GENERALIZED_TIME_LEN); (*out) += ASN1_GENERALIZED_TIME_LEN; } diff --git a/src/tlcp.c b/src/tlcp.c index 239f2273..8a0b5038 100644 --- a/src/tlcp.c +++ b/src/tlcp.c @@ -159,11 +159,10 @@ int tlcp_server_key_exchange_pke_print(FILE *fp, const uint8_t *data, size_t dat return 1; } -int tlcp_connect(TLS_CONNECT *conn, const char *hostname, int port, - FILE *ca_certs_fp, FILE *client_certs_fp, const SM2_KEY *client_sign_key) +int tlcp_do_connect(TLS_CONNECT *conn) { int ret = -1; - uint8_t record[TLS_MAX_RECORD_SIZE]; + uint8_t *record = conn->record; uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE]; size_t recordlen, finished_record_len; @@ -205,29 +204,6 @@ int tlcp_connect(TLS_CONNECT *conn, const char *hostname, int port, int alert = 0; int verify_result; - struct sockaddr_in server; - - - memset(conn, 0, sizeof(*conn)); - conn->is_client = 1; - - // 设置CA证书(和客户端证书) - if (ca_certs_fp) { - if (x509_certs_from_pem(conn->ca_certs, &conn->ca_certs_len, 2048, ca_certs_fp) != 1) { - error_print(); - goto end; - } - } - if (client_sign_key) { - if (!client_certs_fp) { - error_print(); - goto end; - } - if (x509_certs_from_pem(conn->client_certs, &conn->client_certs_len, 2048, client_certs_fp) != 1) { - error_print(); - goto end; - } - } // 初始化记录缓冲 tls_record_set_version(record, TLS_version_tlcp); @@ -235,21 +211,8 @@ int tlcp_connect(TLS_CONNECT *conn, const char *hostname, int port, // 准备Finished Context(和ClientVerify) sm3_init(&sm3_ctx); - if (client_sign_key) - sm2_sign_init(&sign_ctx, client_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH); - - // 设置Socket - server.sin_addr.s_addr = inet_addr(hostname); - server.sin_family = AF_INET; - server.sin_port = htons(port); - if ((conn->sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - error_print(); - goto end; - } - if (connect(conn->sock, (struct sockaddr *)&server , sizeof(server)) < 0) { - error_print(); - goto end; - } + if (conn->client_certs_len) + sm2_sign_init(&sign_ctx, &conn->sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH); // send ClientHello @@ -267,7 +230,7 @@ int tlcp_connect(TLS_CONNECT *conn, const char *hostname, int port, goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_sign_key) + if (conn->client_certs_len) sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); // recv ServerHello @@ -309,7 +272,7 @@ int tlcp_connect(TLS_CONNECT *conn, const char *hostname, int port, memcpy(conn->session_id, session_id, session_id_len); conn->cipher_suite = cipher_suite; sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_sign_key) + if (conn->client_certs_len) sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); // recv ServerCertificate @@ -321,6 +284,7 @@ int tlcp_connect(TLS_CONNECT *conn, const char *hostname, int port, goto end; } tlcp_record_trace(stderr, record, recordlen, 0, 0); + if (tls_record_get_handshake_certificate(record, conn->server_certs, &conn->server_certs_len) != 1) { error_print(); @@ -328,7 +292,7 @@ int tlcp_connect(TLS_CONNECT *conn, const char *hostname, int port, goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_sign_key) + if (conn->client_certs_len) sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); // verify ServerCertificate @@ -354,7 +318,7 @@ int tlcp_connect(TLS_CONNECT *conn, const char *hostname, int port, goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_sign_key) + if (conn->client_certs_len) sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); // verify ServerKeyExchange @@ -406,7 +370,7 @@ int tlcp_connect(TLS_CONNECT *conn, const char *hostname, int port, tls_send_alert(conn, TLS_alert_unexpected_message); goto end; } - if(!client_sign_key) { + if(!conn->client_certs_len) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); goto end; @@ -428,7 +392,10 @@ int tlcp_connect(TLS_CONNECT *conn, const char *hostname, int port, goto end; } } else { - client_sign_key = NULL; + // 这个得处理一下 + conn->client_certs_len = 0; + gmssl_secure_clear(&conn->sign_key, sizeof(SM2_KEY)); + //client_sign_key = NULL; } tls_trace("recv ServerHelloDone\n"); tlcp_record_trace(stderr, record, recordlen, 0, 0); @@ -438,11 +405,11 @@ int tlcp_connect(TLS_CONNECT *conn, const char *hostname, int port, goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_sign_key) + if (conn->client_certs_len) sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); // send ClientCertificate - if (client_sign_key) { + if (conn->client_certs_len) { tls_trace("send ClientCertificate\n"); if (tls_record_set_handshake_certificate(record, &recordlen, conn->client_certs, conn->client_certs_len) != 1) { error_print(); @@ -475,15 +442,12 @@ int tlcp_connect(TLS_CONNECT *conn, const char *hostname, int port, sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32); sm4_set_encrypt_key(&conn->client_write_enc_key, conn->key_block + 64); sm4_set_decrypt_key(&conn->server_write_enc_key, conn->key_block + 80); - /* - format_bytes(stderr, 0, 4, "PRE_MASTER_SECRET", pre_master_secret, 48); - format_bytes(stderr, 0, 4, "MASTER_SECRET", conn->master_secret, 48); - format_bytes(stderr, 0, 4, "CLIENT_WRITE_MAC_KEY", conn->key_block, 32); - format_bytes(stderr, 0, 4, "SERVER_WRITE_MAC_KEY", conn->key_block + 32, 32); - format_bytes(stderr, 0, 4, "CLIENT_WRITE_ENC_KEY", conn->key_block + 64, 16); - format_bytes(stderr, 0, 4, "SERVER_WRITE_ENC_KEY", conn->key_block + 80, 16); - format_print(stderr, 0, 0, "\n"); - */ + tls_secrets_print(stderr, + pre_master_secret, 48, + client_random, server_random, + conn->master_secret, + conn->key_block, 96, + 0, 4); // send ClientKeyExchange tls_trace("send ClientKeyExchange\n"); @@ -501,11 +465,11 @@ int tlcp_connect(TLS_CONNECT *conn, const char *hostname, int port, goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_sign_key) + if (conn->client_certs_len) sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); // send CertificateVerify - if (client_sign_key) { + if (conn->client_certs_len) { tls_trace("send CertificateVerify\n"); uint8_t sigbuf[SM2_MAX_SIGNATURE_SIZE]; if (sm2_sign_finish(&sign_ctx, sigbuf, &siglen) != 1 @@ -639,21 +603,13 @@ end: return 1; } - -int tlcp_accept(TLS_CONNECT *conn, int port, - FILE *certs_fp, const SM2_KEY *server_sign_key, const SM2_KEY *server_enc_key, - FILE *client_cacerts_fp, uint8_t *handshakes_buf, size_t handshakes_buflen) +int tlcp_do_accept(TLS_CONNECT *conn) { int ret = -1; - int sock; - struct sockaddr_in server_addr; - struct sockaddr_in client_addr; - socklen_t client_addrlen; - int client_verify = 0; - uint8_t record[TLS_MAX_RECORD_SIZE]; + uint8_t *record = conn->record; uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE]; // 解密可能导致前面的record被覆盖 size_t recordlen, finished_record_len; const int server_ciphers[] = { TLCP_cipher_ecc_sm4_cbc_sm3 }; // 未来应该支持GCM/CBC两个套件 @@ -704,51 +660,9 @@ int tlcp_accept(TLS_CONNECT *conn, int port, size_t len; - - memset(conn, 0, sizeof(*conn)); - - - - // 设置服务器证书(客户端验证CA证书) - if (!certs_fp || !server_sign_key || !server_enc_key) { - error_print(); - goto end; - } - if (x509_certs_from_pem(conn->server_certs, &conn->server_certs_len, 2048, certs_fp) != 1) { - error_print(); - goto end; - } - if (client_cacerts_fp) { - if (x509_certs_from_pem(conn->ca_certs, &conn->ca_certs_len, 2048, client_cacerts_fp) != 1) { - error_print(); - goto end; - } + // 服务器端如果设置了CA + if (conn->ca_certs_len) client_verify = 1; - } - - // Socket - if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - error_print(); - goto end; - } - server_addr.sin_family = AF_INET; - server_addr.sin_addr.s_addr = INADDR_ANY; - server_addr.sin_port = htons(port); - if (bind(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) { - error_print(); - perror("tlcp_accept: bind: "); - goto end; - } - puts("start listen ...\n"); - listen(sock, 5); - - client_addrlen = sizeof(client_addr); - if ((conn->sock = accept(sock, (struct sockaddr *)&client_addr, &client_addrlen)) < 0) { - error_print(); - goto end; - } - puts("socket connected\n"); - // 初始化Finished和客户端验证环境 sm3_init(&sm3_ctx); @@ -846,7 +760,7 @@ int tlcp_accept(TLS_CONNECT *conn, int port, } p = server_enc_cert_lenbuf; len = 0; tls_uint24_to_bytes(server_enc_cert_len, &p, &len); - if (sm2_sign_init(&sign_ctx, server_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1 + if (sm2_sign_init(&sign_ctx, &conn->sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1 || sm2_sign_update(&sign_ctx, client_random, 32) != 1 || sm2_sign_update(&sign_ctx, server_random, 32) != 1 || sm2_sign_update(&sign_ctx, server_enc_cert_lenbuf, 3) != 1 @@ -910,7 +824,7 @@ int tlcp_accept(TLS_CONNECT *conn, int port, tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5); // recv ClientCertificate - if (client_cacerts_fp) { + if (conn->ca_certs_len) { tls_trace("recv ClientCertificate\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1 || tls_record_version(record) != TLS_version_tlcp) { @@ -948,7 +862,7 @@ int tlcp_accept(TLS_CONNECT *conn, int port, tls_send_alert(conn, TLS_alert_unexpected_message); goto end; } - if (sm2_decrypt(server_enc_key, enced_pms, enced_pms_len, + if (sm2_decrypt(&conn->kenc_key, enced_pms, enced_pms_len, pre_master_secret, &pre_master_secret_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_decrypt_error); @@ -1008,15 +922,12 @@ int tlcp_accept(TLS_CONNECT *conn, int port, sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32); sm4_set_decrypt_key(&conn->client_write_enc_key, conn->key_block + 64); sm4_set_encrypt_key(&conn->server_write_enc_key, conn->key_block + 80); - /* - format_bytes(stderr, 0, 4, "PRE_MASTER_SECRET", pre_master_secret, 48); - format_bytes(stderr, 0, 4, "MASTER_SECRET", conn->master_secret, 48); - format_bytes(stderr, 0, 4, "CLIENT_WRITE_MAC_KEY", conn->key_block, 32); - format_bytes(stderr, 0, 4, "SERVER_WRITE_MAC_KEY", conn->key_block + 32, 32); - format_bytes(stderr, 0, 4, "CLIENT_WRITE_ENC_KEY", conn->key_block + 64, 16); - format_bytes(stderr, 0, 4, "SERVER_WRITE_ENC_KEY", conn->key_block + 80, 16); - format_print(stderr, 0, 0, "\n"); - */ + tls_secrets_print(stderr, + pre_master_secret, 48, + client_random, server_random, + conn->master_secret, + conn->key_block, 96, + 0, 4); // recv [ChangeCipherSpec] tls_trace("recv [ChangeCipherSpec]\n"); diff --git a/src/tls.c b/src/tls.c index 647a2ee9..a9df2a22 100644 --- a/src/tls.c +++ b/src/tls.c @@ -49,9 +49,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -69,7 +71,7 @@ void tls_uint8_to_bytes(uint8_t a, uint8_t **out, size_t *outlen) { - if (out) { + if (out && *out) { *(*out)++ = a; } (*outlen)++; @@ -77,7 +79,7 @@ void tls_uint8_to_bytes(uint8_t a, uint8_t **out, size_t *outlen) void tls_uint16_to_bytes(uint16_t a, uint8_t **out, size_t *outlen) { - if (out) { + if (out && *out) { *(*out)++ = (uint8_t)(a >> 8); *(*out)++ = (uint8_t)a; } @@ -86,7 +88,7 @@ void tls_uint16_to_bytes(uint16_t a, uint8_t **out, size_t *outlen) void tls_uint24_to_bytes(uint24_t a, uint8_t **out, size_t *outlen) { - if (out) { + if (out && *out) { *(*out)++ = (uint8_t)(a >> 16); *(*out)++ = (uint8_t)(a >> 8); *(*out)++ = (uint8_t)(a); @@ -96,7 +98,7 @@ void tls_uint24_to_bytes(uint24_t a, uint8_t **out, size_t *outlen) void tls_uint32_to_bytes(uint32_t a, uint8_t **out, size_t *outlen) { - if (out) { + if (out && *out) { *(*out)++ = (uint8_t)(a >> 24); *(*out)++ = (uint8_t)(a >> 16); *(*out)++ = (uint8_t)(a >> 8); @@ -107,7 +109,7 @@ void tls_uint32_to_bytes(uint32_t a, uint8_t **out, size_t *outlen) void tls_array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen) { - if (*out) { + if (out && *out) { if (data) { memcpy(*out, data, datalen); } @@ -247,23 +249,13 @@ int tls_uint24array_from_bytes(const uint8_t **data, size_t *datalen, const uint return 1; } -// 获取记录基本信息,不做正确性检查,考虑实现为宏 -int tls_record_type(const uint8_t *record) +int tls_length_is_zero(size_t len) { - return record[0]; -} - -int tls_record_length(const uint8_t *record) -{ - int ret; - ret = ((uint16_t)record[3] << 8) | record[4]; - return ret; -} - -int tls_record_version(const uint8_t *record) -{ - int version = ((int)record[1] << 8) | record[2]; - return version; + if (len) { + error_print(); + return -1; + } + return 1; } int tls_record_set_type(uint8_t *record, int type) @@ -287,6 +279,29 @@ int tls_record_set_version(uint8_t *record, int version) return 1; } +int tls_record_set_length(uint8_t *record, size_t length) +{ + uint8_t *p = record + 3; + size_t len; + if (length > TLS_MAX_CIPHERTEXT_SIZE) { + error_print(); + return -1; + } + tls_uint16_to_bytes(length, &p, &len); + return 1; +} + +int tls_record_set_data(uint8_t *record, const uint8_t *data, size_t datalen) +{ + if (tls_record_set_length(record, datalen) != 1) { + error_print(); + return -1; + } + memcpy(tls_record_data(record), data, datalen); + return 1; +} + + int tls_cbc_encrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *enc_key, const uint8_t seq_num[8], const uint8_t header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) @@ -616,12 +631,21 @@ int tls_record_set_handshake(uint8_t *record, size_t *recordlen, error_print(); return -1; } - // 这个长度限制应该修改为宏 - if (datalen > (1 << 14) - 4) { - error_puts("gmssl does not support handshake longer than record"); + // 由于ServerHelloDone没有负载数据,因此允许 data,datalen = NULL,0 + if (datalen > TLS_MAX_PLAINTEXT_SIZE - TLS_HANDSHAKE_HEADER_SIZE) { + error_print(); return -1; } - handshakelen = 4 + datalen; + + if (!tls_version_text(tls_record_version(record))) { + error_print(); + return -1; + } + if (!tls_handshake_type_name(type)) { + error_print(); + return -1; + } + handshakelen = TLS_HANDSHAKE_HEADER_SIZE + datalen; record[0] = TLS_record_handshake; record[3] = handshakelen >> 8; record[4] = handshakelen; @@ -630,42 +654,62 @@ int tls_record_set_handshake(uint8_t *record, size_t *recordlen, record[7] = datalen >> 8; record[8] = datalen; if (data) { - memcpy(record + 5 + 4, data, datalen); + memcpy(tls_handshake_data(tls_record_data(record)), data, datalen); } - *recordlen = 5 + handshakelen; + *recordlen = TLS_RECORD_HEADER_SIZE + handshakelen; return 1; } -// 这个函数应该再仔细检查一下 int tls_record_get_handshake(const uint8_t *record, int *type, const uint8_t **data, size_t *datalen) { - size_t record_datalen; - + const uint8_t *handshake; + size_t handshake_len; + uint24_t handshake_datalen; if (!record || !type || !data || !datalen) { error_print(); return -1; } - if (record[0] != TLS_record_handshake) { + if (!tls_version_text(tls_record_version(record))) { error_print(); return -1; } - // 我们应该假定这个record是正确的,不再检查长度之类 - record_datalen = (size_t)record[3] << 8 | record[4]; - if (record_datalen > TLS_MAX_PLAINTEXT_SIZE - || record_datalen < 4) { + if (tls_record_type(record) != TLS_record_handshake) { error_print(); return -1; } - if (!tls_handshake_type_name(record[5])) { + handshake = tls_record_data(record); + handshake_len = tls_record_data_length(record); + + if (handshake_len < TLS_HANDSHAKE_HEADER_SIZE) { + error_print(); + return -1; + } + if (handshake_len > TLS_MAX_PLAINTEXT_SIZE) { + // 不支持证书长度超过记录长度的特殊情况 error_print(); return -1; } - *type = record[5]; - *datalen = ((size_t)record[6] << 16) | ((size_t)record[7] << 8) | record[8]; // FIXME:检查长度 - *data = record + 5 + 4; + if (!tls_handshake_type_name(handshake[0])) { + error_print(); + return -1; + } + *type = handshake[0]; + + handshake++; + handshake_len--; + if (tls_uint24_from_bytes(&handshake_datalen, &handshake, &handshake_len) != 1) { + error_print(); + return -1; + } + if (handshake_len != handshake_datalen) { + error_print(); + return -1; + } + *data = handshake; + *datalen = handshake_datalen; if (*datalen == 0) { *data = NULL; @@ -673,8 +717,6 @@ int tls_record_get_handshake(const uint8_t *record, return 1; } -// handshake messages - int tls_record_set_handshake_client_hello(uint8_t *record, size_t *recordlen, int version, const uint8_t random[32], const uint8_t *session_id, size_t session_id_len, @@ -682,37 +724,63 @@ int tls_record_set_handshake_client_hello(uint8_t *record, size_t *recordlen, const uint8_t *exts, size_t exts_len) { uint8_t type = TLS_handshake_client_hello; - uint8_t *p = record + 5 + 4; - size_t len = 0; + uint8_t *p; + size_t len; - if (!record || !recordlen || !random - || (!session_id && session_id_len) || session_id_len > 32 - || !cipher_suites || !cipher_suites_count || cipher_suites_count > 64 - || (!exts && exts_len) || exts_len > 512) { + if (!record || !recordlen || !random || !cipher_suites || !cipher_suites_count) { error_print(); return -1; } + if (session_id) { + if (!session_id_len + || session_id_len < TLS_MAX_SESSION_ID_SIZE + || session_id_len > TLS_MAX_SESSION_ID_SIZE) { + error_print(); + return -1; + } + } + if (cipher_suites_count > TLS_MAX_CIPHER_SUITES_COUNT) { + error_print(); + return -1; + } + if (exts && !exts_len) { + error_print(); + return -1; + } + + + p = tls_handshake_data(tls_record_data(record)); + len = 0; + if (!tls_version_text(version)) { error_print(); return -1; } - - tls_uint16_to_bytes((uint16_t)version, &p, &len); tls_array_to_bytes(random, 32, &p, &len); tls_uint8array_to_bytes(session_id, session_id_len, &p, &len); tls_uint16_to_bytes(cipher_suites_count * 2, &p, &len); while (cipher_suites_count--) { + if (!tls_cipher_suite_name(*cipher_suites)) { + error_print(); + return -1; + } tls_uint16_to_bytes((uint16_t)*cipher_suites, &p, &len); cipher_suites++; } tls_uint8_to_bytes(1, &p, &len); tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len); if (exts) { + size_t tmp_len = len; if (version < TLS_version_tls12) { error_print(); return -1; } + tls_uint16array_to_bytes(exts, exts_len, NULL, &tmp_len); + if (tmp_len > TLS_MAX_HANDSHAKE_DATA_SIZE) { + error_print(); + return -1; + } tls_uint16array_to_bytes(exts, exts_len, &p, &len); } if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) { @@ -722,7 +790,6 @@ int tls_record_set_handshake_client_hello(uint8_t *record, size_t *recordlen, return 1; } -// 这样的函数是否会出现内部错误或者消息解析错误呢? int tls_record_get_handshake_client_hello(const uint8_t *record, int *version, const uint8_t **random, const uint8_t **session_id, size_t *session_id_len, @@ -736,18 +803,18 @@ int tls_record_get_handshake_client_hello(const uint8_t *record, const uint8_t *comp_meths; size_t comp_meths_len; - if (!record || !random || !session_id || !session_id_len + if (!record || !version || !random + || !session_id || !session_id_len || !cipher_suites || !cipher_suites_len - || record[0] != TLS_record_handshake) { // record_type应该有一个独立的错误 + || !exts || !exts_len) { error_print(); return -1; } - if (tls_record_type(record) != TLS_record_handshake) { + if (tls_record_get_handshake(record, &type, &p, &len) != 1) { error_print(); return -1; } - if (tls_record_get_handshake(record, &type, &p, &len) != 1 - || type != TLS_handshake_client_hello) { + if (type != TLS_handshake_client_hello) { error_print(); return -1; } @@ -759,16 +826,22 @@ int tls_record_get_handshake_client_hello(const uint8_t *record, error_print(); return -1; } + if (!tls_version_text(ver)) { error_print(); return -1; } *version = ver; - if (*session_id_len > TLS_MAX_SESSION_ID_SIZE) { - error_print(); - return -1; + + if (*session_id) { + if (*session_id_len == 0 + || *session_id_len < TLS_MIN_SESSION_ID_SIZE + || *session_id_len > TLS_MAX_SESSION_ID_SIZE) { + error_print(); + return -1; + } } - // 是否允许未定义密码套件,留给调用方解析判断 + if (!cipher_suites) { error_print(); return -1; @@ -777,11 +850,16 @@ int tls_record_get_handshake_client_hello(const uint8_t *record, error_print(); return -1; } + if (len) { if (tls_uint16array_from_bytes(exts, exts_len, &p, &len) != 1) { error_print(); return -1; } + if (*exts == NULL) { + error_print(); + return -1; + } } else { *exts = NULL; *exts_len = 0; @@ -793,39 +871,27 @@ int tls_record_get_handshake_client_hello(const uint8_t *record, return 1; } -// 如果有错误,都是内部错误 int tls_record_set_handshake_server_hello(uint8_t *record, size_t *recordlen, int version, const uint8_t random[32], const uint8_t *session_id, size_t session_id_len, int cipher_suite, const uint8_t *exts, size_t exts_len) { uint8_t type = TLS_handshake_server_hello; - uint8_t *p = record + 5 + 4; - size_t len = 0; + uint8_t *p; + size_t len; - if (!record || !recordlen) { + if (!record || !recordlen || !random) { error_print(); return -1; } - if (tls_version_text(version) == NULL || random == NULL) { - error_print(); - return -1; - } - if (session_id != NULL) { - if (session_id_len <= 0 || session_id_len > 32) { + if (session_id) { + if (session_id_len == 0 + || session_id_len < TLS_MIN_SESSION_ID_SIZE + || session_id_len > TLS_MAX_SESSION_ID_SIZE) { error_print(); return -1; } } - if (exts && exts_len > 512) { - error_print(); - return -1; - } - - if (record[0] != TLS_record_handshake) { - error_print(); - return -1; - } if (!tls_version_text(version)) { error_print(); return -1; @@ -834,12 +900,10 @@ int tls_record_set_handshake_server_hello(uint8_t *record, size_t *recordlen, error_print(); return -1; } - /* - if (version < tls_record_version(record)) { - error_print(); - return -1; - } - */ + + p = tls_handshake_data(tls_record_data(record)); + len = 0; + tls_uint16_to_bytes((uint16_t)version, &p, &len); tls_array_to_bytes(random, 32, &p, &len); tls_uint8array_to_bytes(session_id, session_id_len, &p, &len); @@ -866,16 +930,12 @@ int tls_record_get_handshake_server_hello(const uint8_t *record, int type; const uint8_t *p; size_t len; - uint16_t ver; // 如果直接读取uint16到*version中,则*version的高16位没有初始化 - uint16_t cipher; // 同上 + uint16_t ver; + uint16_t cipher; uint8_t comp_meth; if (!record || !version || !random || !session_id || !session_id_len - || !cipher_suite) { - error_print(); - return -1; - } - if (record[0] != TLS_record_handshake) { + || !cipher_suite || !exts || !exts_len) { error_print(); return -1; } @@ -895,6 +955,7 @@ int tls_record_get_handshake_server_hello(const uint8_t *record, error_print(); return -1; } + if (!tls_version_text(ver)) { error_print(); return -1; @@ -904,24 +965,36 @@ int tls_record_get_handshake_server_hello(const uint8_t *record, return -1; } *version = ver; - if (*session_id_len > TLS_MAX_SESSION_ID_SIZE) { + + if (*session_id) { + if (*session_id == 0 + || *session_id_len < TLS_MIN_SESSION_ID_SIZE + || *session_id_len > TLS_MAX_SESSION_ID_SIZE) { + error_print(); + return -1; + } + } + + if (!tls_cipher_suite_name(cipher)) { error_print(); return -1; } - if (!tls_cipher_suite_name(cipher)) { - error_print_msg("unknown server cipher_suite 0x%04x", *cipher_suite); - return -1; - } *cipher_suite = cipher; + if (comp_meth != TLS_compression_null) { error_print(); return -1; } + if (len) { if (tls_uint16array_from_bytes(exts, exts_len, &p, &len) != 1) { error_print(); return -1; } + if (*exts == NULL) { + error_print(); + return -1; + } } else { *exts = NULL; *exts_len = 0; @@ -933,15 +1006,14 @@ int tls_record_get_handshake_server_hello(const uint8_t *record, return 1; } - int tls_record_set_handshake_certificate(uint8_t *record, size_t *recordlen, const uint8_t *certs, size_t certslen) { int type = TLS_handshake_certificate; - const size_t maxlen = TLS_MAX_HANDSHAKE_DATA_SIZE - tls_uint24_size(); - uint8_t *data, *p; - size_t datalen = 0; - size_t len = 0; + uint8_t *data; + size_t datalen; + uint8_t *p; + size_t len; if (!record || !recordlen || !certs || !certslen) { error_print(); @@ -949,8 +1021,9 @@ int tls_record_set_handshake_certificate(uint8_t *record, size_t *recordlen, } data = tls_handshake_data(tls_record_data(record)); p = data + tls_uint24_size(); + datalen = tls_uint24_size(); + len = 0; - // set (uint24 certlen, cert)* while (certslen) { const uint8_t *cert; size_t certlen; @@ -959,55 +1032,24 @@ int tls_record_set_handshake_certificate(uint8_t *record, size_t *recordlen, error_print(); return -1; } - // 如何防止溢出 - if (3 + certlen > maxlen) { + tls_uint24array_to_bytes(cert, certlen, NULL, &datalen); + if (datalen > TLS_MAX_HANDSHAKE_DATA_SIZE) { error_print(); return -1; } - tls_uint24array_to_bytes(cert, certlen, &p, &datalen); + tls_uint24array_to_bytes(cert, certlen, &p, &len); } - tls_uint24_to_bytes(datalen, &data, &datalen); + tls_uint24_to_bytes(len, &data, &len); tls_record_set_handshake(record, recordlen, type, NULL, datalen); return 1; } -/* -int tls_record_set_handshake_certificate_from_pem(uint8_t *record, size_t *recordlen, FILE *fp) -{ - int type = TLS_handshake_certificate; - uint8_t *data = record + 5 + 4; - uint8_t *certs = data + 3; - size_t datalen, certslen = 0; - - for (;;) { - int ret; - uint8_t cert[1024]; - size_t certlen; - - if ((ret = x509_cert_from_pem(cert, &certlen, sizeof(cert), fp)) < 0) { - error_print(); - return -1; - } else if (!ret) { - break; - } - tls_uint24array_to_bytes(cert, certlen, &certs, &certslen); - } - datalen = certslen; - tls_uint24_to_bytes((uint24_t)certslen, &data, &datalen); - tls_record_set_handshake(record, recordlen, type, NULL, datalen); - return 1; -} -*/ - -// 如果certs长度超过限制怎么办? -// 在调用这个函数之前,应该保证准备的缓冲区为 int tls_record_get_handshake_certificate(const uint8_t *record, uint8_t *certs, size_t *certslen) { int type; const uint8_t *data; size_t datalen; - uint8_t *out = certs; - const uint8_t *p; + const uint8_t *cp; size_t len; if (tls_record_get_handshake(record, &type, &data, &datalen) != 1) { @@ -1018,25 +1060,25 @@ int tls_record_get_handshake_certificate(const uint8_t *record, uint8_t *certs, error_print(); return -1; } - if (tls_uint24array_from_bytes(&p, &len, &data, &datalen) != 1) { + if (tls_uint24array_from_bytes(&cp, &len, &data, &datalen) != 1) { error_print(); return -1; } *certslen = 0; while (len) { - const uint8_t *d; - size_t dlen; + const uint8_t *a; + size_t alen; const uint8_t *cert; size_t certlen; - if (tls_uint24array_from_bytes(&d, &dlen, &p, &len) != 1) { + if (tls_uint24array_from_bytes(&a, &alen, &cp, &len) != 1) { error_print(); return -1; } - if (x509_cert_from_der(&cert, &certlen, &d, &dlen) != 1 - || asn1_length_is_zero(dlen) != 1 - || x509_cert_to_der(cert, certlen, &out, certslen) != 1) { + if (x509_cert_from_der(&cert, &certlen, &a, &alen) != 1 + || asn1_length_is_zero(alen) != 1 + || x509_cert_to_der(cert, certlen, &certs, certslen) != 1) { error_print(); return -1; } @@ -1049,19 +1091,36 @@ int tls_record_set_handshake_certificate_request(uint8_t *record, size_t *record const uint8_t *ca_names, size_t ca_names_len) { int type = TLS_handshake_certificate_request; - uint8_t *p = record + 5 + 4; - size_t len = 0; + uint8_t *p; + size_t len =0; + size_t datalen = 0; - if (!record || !recordlen - || !cert_types || !cert_types_len || cert_types_len > TLS_MAX_CERTIFICATE_TYPES - || (!ca_names && ca_names_len) || ca_names_len > TLS_MAX_CA_NAMES_SIZE) { + if (!record || !recordlen) { error_print(); return -1; } - // 对cert_types_len和ca_names_len的长度检查保证输出不会超过记录长度 + if (cert_types) { + if (cert_types_len == 0 || cert_types_len > TLS_MAX_CERTIFICATE_TYPES) { + error_print(); + return -1; + } + } + if (ca_names) { + if (ca_names_len == 0 || ca_names_len > TLS_MAX_CA_NAMES_SIZE) { + error_print(); + return -1; + } + } + tls_uint8array_to_bytes(cert_types, cert_types_len, NULL, &datalen); + tls_uint16array_to_bytes(ca_names, ca_names_len, NULL, &datalen); + if (datalen > TLS_MAX_HANDSHAKE_DATA_SIZE) { + error_print(); + return -1; + } + p = tls_handshake_data(tls_record_data(record)); tls_uint8array_to_bytes(cert_types, cert_types_len, &p, &len); tls_uint16array_to_bytes(ca_names, ca_names_len, &p, &len); - tls_record_set_handshake(record, recordlen, type, NULL, len); + tls_record_set_handshake(record, recordlen, type, NULL, datalen); return 1; } @@ -1072,22 +1131,47 @@ int tls_record_get_handshake_certificate_request(const uint8_t *record, int type; const uint8_t *cp; size_t len; - const uint8_t *types; - size_t count; + size_t i; - if (!record - || !cert_types || !cert_types_len || !ca_names || !ca_names_len - || record[0] != TLS_record_handshake) { + if (!record || !cert_types || !cert_types_len || !ca_names || !ca_names_len) { error_print(); return -1; } - if (tls_record_get_handshake(record, &type, &cp, &len) != 1 - || tls_uint8array_from_bytes(cert_types, cert_types_len, &cp, &len) != 1 - || tls_uint16array_from_bytes(ca_names, ca_names_len, &cp, &len) != 1 - || len > 0) { + if (tls_record_get_handshake(record, &type, &cp, &len) != 1) { error_print(); return -1; } + if (type != TLS_handshake_certificate_request) { + error_print(); + return -1; + } + if (tls_uint8array_from_bytes(cert_types, cert_types_len, &cp, &len) != 1 + || tls_uint16array_from_bytes(ca_names, ca_names_len, &cp, &len) != 1 + || tls_length_is_zero(len) != 1) { + error_print(); + return -1; + } + + if (*cert_types == NULL) { + error_print(); + return -1; + } + for (i = 0; i < *cert_types_len; i++) { + if (!tls_cert_type_name((*cert_types)[i])) { + error_print(); + return -1; + } + } + if (*ca_names) { + const uint8_t *names = *ca_names; + size_t nameslen = *ca_names_len; + while (nameslen) { + if (tls_uint16array_from_bytes(&cp, &len, &names, &nameslen) != 1) { + error_print(); + return -1; + } + } + } return 1; } @@ -1113,8 +1197,11 @@ int tls_record_get_handshake_server_hello_done(const uint8_t *record) return -1; } if (tls_record_get_handshake(record, &type, &p, &len) != 1 - || type != TLS_handshake_server_hello_done - || len != 0) { + || type != TLS_handshake_server_hello_done) { + error_print(); + return -1; + } + if (p != NULL || len != 0) { error_print(); return -1; } @@ -1125,13 +1212,18 @@ int tls_record_set_handshake_client_key_exchange_pke(uint8_t *record, size_t *re const uint8_t *enced_pms, size_t enced_pms_len) { int type = TLS_handshake_client_key_exchange; - uint8_t *p = record + 5 + 4; + uint8_t *p; size_t len = 0; - if (!record || !recordlen - || !enced_pms || !enced_pms_len || enced_pms_len > 65535) { + + if (!record || !recordlen || !enced_pms || !enced_pms_len) { error_print(); return -1; } + if (enced_pms_len > TLS_MAX_HANDSHAKE_DATA_SIZE - tls_uint16_size()) { + error_print(); + return -1; + } + p = tls_handshake_data(tls_record_data(record)); tls_uint16array_to_bytes(enced_pms, enced_pms_len, &p, &len); tls_record_set_handshake(record, recordlen, type, NULL, len); return 1; @@ -1141,21 +1233,23 @@ int tls_record_get_handshake_client_key_exchange_pke(const uint8_t *record, const uint8_t **enced_pms, size_t *enced_pms_len) { int type; - const uint8_t *p; + const uint8_t *cp; size_t len; - if (!record || !enced_pms || !enced_pms_len - || record[0] != TLS_record_handshake) { + if (!record || !enced_pms || !enced_pms_len) { error_print(); return -1; } - if (tls_record_get_handshake(record, &type, &p, &len) != 1 - || type != TLS_handshake_client_key_exchange) { + if (tls_record_get_handshake(record, &type, &cp, &len) != 1) { error_print(); return -1; } - if (tls_uint16array_from_bytes(enced_pms, enced_pms_len, &p, &len) != 1 - || len > 0) { + if (type != TLS_handshake_client_key_exchange) { + error_print(); + return -1; + } + if (tls_uint16array_from_bytes(enced_pms, enced_pms_len, &cp, &len) != 1 + || tls_length_is_zero(len) != 1) { error_print(); return -1; } @@ -1166,6 +1260,15 @@ int tls_record_set_handshake_certificate_verify(uint8_t *record, size_t *recordl const uint8_t *sig, size_t siglen) { int type = TLS_handshake_certificate_verify; + + if (!record || !recordlen || !sig || !siglen) { + error_print(); + return -1; + } + if (siglen > TLS_MAX_SIGNATURE_SIZE) { + error_print(); + return -1; + } tls_record_set_handshake(record, recordlen, type, sig, siglen); return 1; } @@ -1175,25 +1278,39 @@ int tls_record_get_handshake_certificate_verify(const uint8_t *record, { int type; - if (tls_record_get_handshake(record, &type, sig, siglen) != 1 - || type != TLS_handshake_certificate_verify) { + if (!record || !sig || !siglen) { error_print(); return -1; } - if (*sig == NULL) { + if (tls_record_get_handshake(record, &type, sig, siglen) != 1) { + error_print(); + return -1; + } + if (type != TLS_handshake_certificate_verify) { + error_print(); + return -1; + } + if (*sig == NULL || *siglen == 0) { + error_print(); + return -1; + } + if (*siglen > TLS_MAX_SIGNATURE_SIZE) { error_print(); return -1; } return 1; } -//FIXME: TLS 1.3 中的verify_data长度和hashLen一样,并且长度是不单独编码的, -// 因此这个函数应该改一下了 int tls_record_set_handshake_finished(uint8_t *record, size_t *recordlen, const uint8_t *verify_data, size_t verify_data_len) { int type = TLS_handshake_finished; - if (!record || !recordlen || !verify_data) { + + if (!record || !recordlen || !verify_data || !verify_data_len) { + error_print(); + return -1; + } + if (verify_data_len != 12 && verify_data_len != 32) { error_print(); return -1; } @@ -1205,6 +1322,10 @@ int tls_record_get_handshake_finished(const uint8_t *record, const uint8_t **ver { int type; + if (!record || !verify_data || !verify_data_len) { + error_print(); + return -1; + } if (tls_record_get_handshake(record, &type, verify_data, verify_data_len) != 1) { error_print(); return -1; @@ -1213,24 +1334,30 @@ int tls_record_get_handshake_finished(const uint8_t *record, const uint8_t **ver error_print(); return -1; } - if (*verify_data == NULL) { + if (*verify_data == NULL || *verify_data_len == 0) { + error_print(); + return -1; + } + if (*verify_data_len != 12 && *verify_data_len != 32) { error_print(); return -1; } return 1; } -// alert protocol - - -// 这个函数没有必要设置长度,因此Alert长度是固定的! int tls_record_set_alert(uint8_t *record, size_t *recordlen, int alert_level, int alert_description) { - if (!record || !recordlen - || !tls_alert_level_name(alert_level) - || !tls_alert_description_text(alert_description)) { + if (!record || !recordlen) { + error_print(); + return -1; + } + if (!tls_alert_level_name(alert_level)) { + error_print(); + return -1; + } + if (!tls_alert_description_text(alert_description)) { error_print(); return -1; } @@ -1239,7 +1366,7 @@ int tls_record_set_alert(uint8_t *record, size_t *recordlen, record[4] = 2; // length record[5] = (uint8_t)alert_level; record[6] = (uint8_t)alert_description; - *recordlen = 7; + *recordlen = TLS_RECORD_HEADER_SIZE + 2; return 1; } @@ -1251,7 +1378,7 @@ int tls_record_get_alert(const uint8_t *record, error_print(); return -1; } - if (record[0] != TLS_record_alert) { + if (tls_record_type(record) != TLS_record_alert) { error_print(); return -1; } @@ -1272,10 +1399,6 @@ int tls_record_get_alert(const uint8_t *record, return 1; } - -// change_cipher_spec protocol - - int tls_record_set_change_cipher_spec(uint8_t *record, size_t *recordlen) { if (!record || !recordlen) { @@ -1286,7 +1409,7 @@ int tls_record_set_change_cipher_spec(uint8_t *record, size_t *recordlen) record[3] = 0; record[4] = 1; record[5] = TLS_change_cipher_spec; - *recordlen = 6; + *recordlen = TLS_RECORD_HEADER_SIZE + 1; return 1; } @@ -1296,7 +1419,7 @@ int tls_record_get_change_cipher_spec(const uint8_t *record) error_print(); return -1; } - if (record[0] != TLS_record_change_cipher_spec) { + if (tls_record_type(record) != TLS_record_change_cipher_spec) { error_print(); return -1; } @@ -1305,7 +1428,7 @@ int tls_record_get_change_cipher_spec(const uint8_t *record) return -1; } if (record[5] != TLS_change_cipher_spec) { - error_print_msg("unknown ChangeCipherSpec value %d", record[5]); + error_print(); return -1; } return 1; @@ -1314,30 +1437,41 @@ int tls_record_get_change_cipher_spec(const uint8_t *record) int tls_record_set_application_data(uint8_t *record, size_t *recordlen, const uint8_t *data, size_t datalen) { + if (!record || !recordlen || !data || !datalen) { + error_print(); + return -1; + } record[0] = TLS_record_application_data; record[3] = (datalen >> 8) & 0xff; record[4] = datalen & 0xff; - memcpy(record + 5, data, datalen); - *recordlen = 5 + datalen; + memcpy(tls_record_data(record), data, datalen); + *recordlen = TLS_RECORD_HEADER_SIZE + datalen; return 1; } int tls_record_get_application_data(uint8_t *record, const uint8_t **data, size_t *datalen) { - if (record[0] != TLS_record_application_data) { + if (!record || !data || !datalen) { + error_print(); + return -1; + } + if (tls_record_type(record) != TLS_record_application_data) { error_print(); return -1; } *datalen = ((size_t)record[3] << 8) | record[4]; - *data = record + 5; + *data = *datalen ? record + TLS_RECORD_HEADER_SIZE : 0; return 1; } - int tls_cipher_suite_in_list(int cipher, const int *list, size_t list_count) { size_t i; + if (!list || !list_count) { + error_print(); + return -1; + } for (i = 0; i < list_count; i++) { if (cipher == list[i]) { return 1; @@ -1346,18 +1480,23 @@ int tls_cipher_suite_in_list(int cipher, const int *list, size_t list_count) return 0; } -// 两类错误,一种是输入的记录格式有问题,一种是网络问题 -// 显然输入格式是编译期的错误,不应该发生 -// 网络错误如果发生,那么也没有必要再发错误消息了 int tls_record_send(const uint8_t *record, size_t recordlen, int sock) { ssize_t r; - if (recordlen < 5 - || recordlen - 5 != (((size_t)record[3] << 8) | record[4])) { + if (!record) { + error_print(); + return -1; + } + if (recordlen < TLS_RECORD_HEADER_SIZE) { + error_print(); + return -1; + } + if (tls_record_length(record) != recordlen) { error_print(); return -1; } if ((r = send(sock, record, recordlen, 0)) < 0) { + perror(""); error_print(); return -1; } else if (r != recordlen) { @@ -1367,24 +1506,24 @@ int tls_record_send(const uint8_t *record, size_t recordlen, int sock) return 1; } -int tls_record_recv(uint8_t *record, size_t *recordlen, int sock) +int tls_record_do_recv(uint8_t *record, size_t *recordlen, int sock) { ssize_t r; int type; size_t len; -retry: // TODO:支持非租塞socket或针对可能的网络延迟重新recv if ((r = recv(sock, record, 5, 0)) < 0) { + perror(""); error_print(); return -1; } if (!tls_record_type_name(tls_record_type(record))) { - error_print_msg("Invalid record type: %d\n", record[0]); + error_print(); return -1; } if (!tls_version_text(tls_record_version(record))) { - error_print_msg("Invalid record version: %d.%d\n", record[1], record[2]); + error_print(); return -1; } len = (size_t)record[3] << 8 | record[4]; @@ -1403,6 +1542,17 @@ retry: return -1; } } + return 1; +} + +int tls_record_recv(uint8_t *record, size_t *recordlen, int sock) +{ +retry: + if (tls_record_do_recv(record, recordlen, sock) != 1) { + error_print(); + return -1; + } + if (tls_record_type(record) == TLS_record_alert) { int level; int alert; @@ -1423,11 +1573,12 @@ retry: tls_record_set_type(alert_record, TLS_record_alert); tls_record_set_version(alert_record, tls_record_version(record)); tls_record_set_alert(alert_record, &alert_record_len, TLS_alert_level_fatal, TLS_alert_close_notify); - tls_record_print(stderr, alert_record, alert_record_len, 0, 0); + + tls_trace("send Alert close_notifiy\n"); + tls_record_trace(stderr, alert_record, alert_record_len, 0, 0); tls_record_send(alert_record, alert_record_len, sock); } // 返回错误0通知调用方不再做任何处理(无需再发送Alert) - error_puts("Alert record received!\n"); return 0; } return 1; @@ -1455,6 +1606,7 @@ int tls_compression_methods_has_null_compression(const uint8_t *meths, size_t me return 1; } } + error_print(); return -1; } @@ -1463,9 +1615,17 @@ int tls_send_alert(TLS_CONNECT *conn, int alert) uint8_t record[5 + 2]; size_t recordlen; + if (!conn) { + error_print(); + return -1; + } tls_record_set_version(record, conn->version); tls_record_set_alert(record, &recordlen, TLS_alert_level_fatal, alert); - tls_record_send(record, sizeof(record), conn->sock); + + if (tls_record_send(record, sizeof(record), conn->sock) != 1) { + error_print(); + return -1; + } tls_record_trace(stderr, record, sizeof(record), 0, 0); return 1; } @@ -1488,39 +1648,52 @@ int tls_alert_level(int alert) return -1; } - int tls_send_warning(TLS_CONNECT *conn, int alert) { uint8_t record[5 + 2]; size_t recordlen; + if (!conn) { + error_print(); + return -1; + } if (tls_alert_level(alert) == TLS_alert_level_fatal) { error_print(); return -1; } tls_record_set_version(record, conn->version); tls_record_set_alert(record, &recordlen, TLS_alert_level_warning, alert); - tls_record_send(record, sizeof(record), conn->sock); + + if (tls_record_send(record, sizeof(record), conn->sock) != 1) { + error_print(); + return -1; + } tls_record_trace(stderr, record, sizeof(record), 0, 0); return 1; } - - - -// FIXME: 设定支持的最大输入长度 -// FIXME: 没回返回实际的发送长度 -int tls_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen) +int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen) { const SM3_HMAC_CTX *hmac_ctx; const SM4_KEY *enc_key; uint8_t *seq_num; - uint8_t mrec[1600]; - uint8_t crec[1600]; - size_t mlen = sizeof(mrec); - size_t clen = sizeof(crec); + uint8_t *record; + size_t recordlen; + uint8_t *data; + size_t datalen; - // FIXME: 检查datalen的长度 + if (!conn) { + error_print(); + return -1; + } + if (!in || !inlen || !sentlen) { + error_print(); + return -1; + } + + if (inlen > TLS_MAX_PLAINTEXT_SIZE) { + inlen = TLS_MAX_PLAINTEXT_SIZE; + } if (conn->is_client) { hmac_ctx = &conn->client_write_mac_ctx; @@ -1531,29 +1704,45 @@ int tls_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen) enc_key = &conn->server_write_enc_key; seq_num = conn->server_seq_num; } + record = conn->record; tls_trace("send ApplicationData\n"); - if (tls_record_set_version(mrec, conn->version) != 1 - || tls_record_set_application_data(mrec, &mlen, data, datalen) != 1 - || tls_record_encrypt(hmac_ctx, enc_key, seq_num, mrec, mlen, crec, &clen) != 1 - || tls_seq_num_incr(seq_num) != 1 - || tls_record_send(crec, clen, conn->sock) != 1) { + + if (tls_record_set_type(record, TLS_record_application_data) != 1 + || tls_record_set_version(record, conn->version) != 1 + || tls_record_set_length(record, inlen) != 1) { error_print(); return -1; } - tls_record_trace(stderr, crec, clen, 0, 0); + + if (tls_cbc_encrypt(hmac_ctx, enc_key, seq_num, tls_record_header(record), + in, inlen, tls_record_data(record), &datalen) != 1) { + error_print(); + return -1; + } + if (tls_record_set_length(record, datalen) != 1) { + error_print(); + return -1; + } + tls_seq_num_incr(seq_num); + if (tls_record_send(record, tls_record_length(record), conn->sock) != 1) { + error_print(); + return -1; + } + *sentlen = inlen; + tls_record_trace(stderr, record, tls_record_length(record), 0, 0); return 1; } -int tls_recv(TLS_CONNECT *conn, uint8_t *data, size_t *datalen) +int tls_do_recv(TLS_CONNECT *conn) { + int ret; const SM3_HMAC_CTX *hmac_ctx; const SM4_KEY *dec_key; uint8_t *seq_num; - uint8_t mrec[1600]; - uint8_t crec[1600]; - size_t mlen = sizeof(mrec); - size_t clen = sizeof(crec); + + uint8_t *record = conn->record; + size_t recordlen; if (conn->is_client) { hmac_ctx = &conn->server_write_mac_ctx; @@ -1566,91 +1755,69 @@ int tls_recv(TLS_CONNECT *conn, uint8_t *data, size_t *datalen) } tls_trace("recv ApplicationData\n"); - if (tls_record_recv(crec, &clen, conn->sock) != 1) { + if ((ret = tls_record_recv(record, &recordlen, conn->sock)) != 1) { + if (ret < 0) error_print(); + return ret; + } + tls_record_trace(stderr, record, recordlen, 0, 0); + if (tls_cbc_decrypt(hmac_ctx, dec_key, seq_num, record, + tls_record_data(record), tls_record_data_length(record), + conn->data, &conn->datalen) != 1) { error_print(); return -1; } + tls_seq_num_incr(seq_num); - if (crec[0] == TLS_record_alert) { - int level; - int alert; - - if (tls_record_get_alert(crec, &level, &alert) != 1) { - error_print(); - return -1; - } - if (alert == TLS_alert_close_notify) { - if (tls_record_send(crec, clen, conn->sock) != 1) { - error_print(); - return -1; - } - - } else { - error_print(); - return -1; - } - - if (level == TLS_alert_level_fatal) { - tls_trace("close Connection\n"); - return 0; - } - } - - - // FIXME: 检查版本号 - if (tls_record_decrypt(hmac_ctx, dec_key, seq_num, crec, clen, mrec, &mlen) != 1 - || tls_seq_num_incr(seq_num) != 1) { - error_print(); - return -1; - } - tls_record_trace(stderr, mrec, mlen, 0, 0); - memcpy(data, mrec + 5, mlen - 5); - *datalen = mlen - 5; + tls_record_set_data(record, conn->data, conn->datalen); + tls_trace("decrypt ApplicationData\n"); + tls_record_trace(stderr, record, tls_record_length(record), 0, 0); + return 1; +} + +int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen) +{ + if (!conn || !out || !outlen || !recvlen) { + error_print(); + return -1; + } + if (conn->datalen == 0) { + int ret; + if ((ret = tls_do_recv(conn)) != 1) { + if (ret) error_print(); + return ret; + } + } + *recvlen = outlen <= conn->datalen ? outlen : conn->datalen; + memcpy(out, conn->data, *recvlen); + conn->datalen -= *recvlen; return 1; } -//FIXME: any difference in TLS 1.2 and TLS 1.3? int tls_shutdown(TLS_CONNECT *conn) { - uint8_t alert[128]; - size_t len; - - tls_record_set_version(alert, conn->version); - tls_record_set_alert(alert, &len, TLS_alert_level_fatal, TLS_alert_close_notify); - - if (tls_record_send(alert, len, conn->sock) != 1) { + size_t recordlen; + if (!conn) { error_print(); return -1; } - - tls_trace("send Alert.close_notify\n"); - tls_record_trace(stderr, alert, len, 0, 0); - - - memset(alert, 0, sizeof(alert)); - // 这里接收实际上只是检查一下对方是否合规,不管怎么说我们都要结束了 - if (tls_record_recv(alert, &len, conn->sock) != 1) { + tls_trace("send Alert close_notify\n"); + if (tls_send_alert(conn, TLS_alert_close_notify) != 1) { error_print(); return -1; } - tls_trace("recv Alert.close_notify\n"); - tls_record_trace(stderr, alert, len, 0, 0); + tls_trace("recv Alert close_notify\n"); + + if (tls_record_do_recv(conn->record, &recordlen, conn->sock) != 1) { + error_print(); + return -1; + } + tls_record_trace(stderr, conn->record, recordlen, 0, 0); return 1; } -// 参考 man verify 的错误返回值 -int tls_get_verify_result(TLS_CONNECT *conn, int *result) -{ - *result = 0; - return 1; -} - -// 这里的输出是record,因此是有一个长度限制的 - int tls_authorities_from_certs(uint8_t *names, size_t *nameslen, size_t maxlen, const uint8_t *certs, size_t certslen) { - uint8_t *out = names; const uint8_t *cert; size_t certlen; const uint8_t *name; @@ -1666,16 +1833,11 @@ int tls_authorities_from_certs(uint8_t *names, size_t *nameslen, size_t maxlen, return -1; } if (tls_uint16_size() + alen > maxlen) { - - fprintf(stderr, "alen = %zu\n", alen); - fprintf(stderr, "maxlen = %zu\n", maxlen); - error_print(); return -1; } - // 这里要兼容names == NULL的情况 - tls_uint16_to_bytes(alen, &out, nameslen); - if (asn1_sequence_to_der(name, namelen, &out, nameslen) != 1) { + tls_uint16_to_bytes(alen, &names, nameslen); + if (asn1_sequence_to_der(name, namelen, &names, nameslen) != 1) { error_print(); return -1; } @@ -1742,9 +1904,12 @@ int tls_cert_types_accepted(const uint8_t *types, size_t types_len, const uint8_ return 0; } - int tls_client_verify_init(TLS_CLIENT_VERIFY_CTX *ctx) { + if (!ctx) { + error_print(); + return -1; + } memset(ctx, 0, sizeof(TLS_CLIENT_VERIFY_CTX)); return 1; } @@ -1752,6 +1917,10 @@ int tls_client_verify_init(TLS_CLIENT_VERIFY_CTX *ctx) int tls_client_verify_update(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *handshake, size_t handshake_len) { uint8_t *buf; + if (!ctx || !handshake || !handshake_len) { + error_print(); + return -1; + } if (ctx->index < 0 || ctx->index > 7) { error_print(); return -1; @@ -1773,6 +1942,11 @@ int tls_client_verify_finish(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *sig, siz SM2_SIGN_CTX sm2_ctx; int i; + if (!ctx || !sig || !siglen || !public_key) { + error_print(); + return -1; + } + if (ctx->index != 8) { error_print(); return -1; @@ -1796,21 +1970,27 @@ int tls_client_verify_finish(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *sig, siz void tls_client_verify_cleanup(TLS_CLIENT_VERIFY_CTX *ctx) { - int i; - for (i = 0; i< ctx->index; i++) { - if (ctx->handshake[i]) { - free(ctx->handshake[i]); - ctx->handshake[i] = NULL; - ctx->handshake_len[i] = 0; + if (ctx) { + int i; + for (i = 0; i< ctx->index; i++) { + if (ctx->handshake[i]) { + free(ctx->handshake[i]); + ctx->handshake[i] = NULL; + ctx->handshake_len[i] = 0; + } } } } - int tls_cipher_suites_select(const uint8_t *client_ciphers, size_t client_ciphers_len, const int *server_ciphers, size_t server_ciphers_cnt, int *selected_cipher) { + if (!client_ciphers || !client_ciphers_len + || !server_ciphers || !server_ciphers_cnt || !selected_cipher) { + error_print(); + return -1; + } while (server_ciphers_cnt--) { const uint8_t *p = client_ciphers; size_t len = client_ciphers_len; @@ -1830,5 +2010,317 @@ int tls_cipher_suites_select(const uint8_t *client_ciphers, size_t client_cipher return 0; } +void tls_ctx_cleanup(TLS_CTX *ctx) +{ + if (ctx) { + gmssl_secure_clear(&ctx->signkey, sizeof(SM2_KEY)); + gmssl_secure_clear(&ctx->kenckey, sizeof(SM2_KEY)); + if (ctx->certs) free(ctx->certs); + if (ctx->cacerts) free(ctx->cacerts); + memset(ctx, 0, sizeof(TLS_CTX)); + } +} + +int tls_ctx_init(TLS_CTX *ctx, int protocol_version, int is_client) +{ + if (!ctx) { + error_print(); + return -1; + } + memset(ctx, 0, sizeof(*ctx)); + + switch (protocol_version) { + case TLS_version_tlcp: + case TLS_version_tls12: + case TLS_version_tls13: + ctx->protocol_version = protocol_version; + break; + default: + error_print(); + return -1; + } + ctx->is_client = is_client ? 1 : 0; + return 1; +} + +int tls_ctx_set_cipher_suites(TLS_CTX *ctx, const int *cipher_suites, size_t cipher_suites_cnt) +{ + size_t i; + + if (!ctx || !cipher_suites || !cipher_suites_cnt) { + error_print(); + return -1; + } + if (cipher_suites_cnt < 1 || cipher_suites_cnt > TLS_MAX_CIPHER_SUITES_COUNT) { + error_print(); + return -1; + } + for (i = 0; i < cipher_suites_cnt; i++) { + if (!tls_cipher_suite_name(cipher_suites[i])) { + error_print(); + return -1; + } + } + for (i = 0; i < cipher_suites_cnt; i++) { + ctx->cipher_suites[i] = cipher_suites[i]; + } + ctx->cipher_suites_cnt = cipher_suites_cnt; + return 1; +} + +int tls_ctx_set_ca_certificates(TLS_CTX *ctx, const char *cacertsfile, int depth) +{ + if (!ctx || !cacertsfile) { + error_print(); + return -1; + } + if (depth < 0 || depth > TLS_MAX_VERIFY_DEPTH) { + error_print(); + return -1; + } + if (!tls_version_text(ctx->protocol_version)) { + error_print(); + return -1; + } + if (ctx->cacerts) { + error_print(); + return -1; + } + if (x509_certs_new_from_file(&ctx->cacerts, &ctx->cacertslen, cacertsfile) != 1) { + error_print(); + return -1; + } + if (ctx->cacertslen == 0) { + error_print(); + return -1; + } + + ctx->verify_depth = depth; + return 1; +} + +int tls_ctx_set_certificate_and_key(TLS_CTX *ctx, const char *chainfile, + const char *keyfile, const char *keypass) +{ + int ret = -1; + uint8_t *certs = NULL; + size_t certslen; + FILE *keyfp = NULL; + SM2_KEY key; + const uint8_t *cert; + size_t certlen; + SM2_KEY public_key; + + if (!ctx || !chainfile || !keyfile || !keypass) { + error_print(); + return -1; + } + if (!tls_version_text(ctx->protocol_version)) { + error_print(); + return -1; + } + if (ctx->certs) { + error_print(); + return -1; + } + + if (x509_certs_new_from_file(&certs, &certslen, chainfile) != 1) { + error_print(); + goto end; + } + if (!(keyfp = fopen(keyfile, "r"))) { + error_print(); + goto end; + } + if (sm2_private_key_info_decrypt_from_pem(&key, keypass, keyfp) != 1) { + error_print(); + goto end; + } + if (x509_certs_get_cert_by_index(certs, certslen, 0, &cert, &certlen) != 1 + || x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) { + error_print(); + return -1; + } + if (sm2_public_key_equ(&key, &public_key) != 1) { + error_print(); + return -1; + } + ctx->certs = certs; + ctx->certslen = certslen; + ctx->signkey = key; + certs = NULL; + ret = 1; + +end: + gmssl_secure_clear(&key, sizeof(key)); + if (certs) free(certs); + if (keyfp) fclose(keyfp); + return ret; +} + +int tls_ctx_set_tlcp_server_certificate_and_keys(TLS_CTX *ctx, const char *chainfile, + const char *signkeyfile, const char *signkeypass, + const char *kenckeyfile, const char *kenckeypass) +{ + int ret = -1; + uint8_t *certs = NULL; + size_t certslen; + FILE *signkeyfp = NULL; + FILE *kenckeyfp = NULL; + SM2_KEY signkey; + SM2_KEY kenckey; + + const uint8_t *cert; + size_t certlen; + SM2_KEY public_key; + + if (!ctx || !chainfile || !signkeyfile || !signkeypass || !kenckeyfile || !kenckeypass) { + error_print(); + return -1; + } + if (!tls_version_text(ctx->protocol_version)) { + error_print(); + return -1; + } + if (ctx->certs) { + error_print(); + return -1; + } + + if (x509_certs_new_from_file(&certs, &certslen, chainfile) != 1) { + error_print(); + return -1; + } + + if (!(signkeyfp = fopen(signkeyfile, "r"))) { + error_print(); + goto end; + } + if (sm2_private_key_info_decrypt_from_pem(&signkey, signkeypass, signkeyfp) != 1) { + error_print(); + goto end; + } + if (x509_certs_get_cert_by_index(certs, certslen, 0, &cert, &certlen) != 1 + || x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1 + || sm2_public_key_equ(&signkey, &public_key) != 1) { + error_print(); + goto end; + } + + if (!(kenckeyfp = fopen(kenckeyfile, "r"))) { + error_print(); + goto end; + } + if (sm2_private_key_info_decrypt_from_pem(&kenckey, kenckeypass, kenckeyfp) != 1) { + error_print(); + goto end; + } + if (x509_certs_get_cert_by_index(certs, certslen, 1, &cert, &certlen) != 1 + || x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1 + || sm2_public_key_equ(&kenckey, &public_key) != 1) { + error_print(); + goto end; + } + + ctx->certs = certs; + ctx->certslen = certslen; + ctx->signkey = signkey; + ctx->kenckey = kenckey; + certs = NULL; + ret = 1; + +end: + gmssl_secure_clear(&signkey, sizeof(signkey)); + gmssl_secure_clear(&kenckey, sizeof(kenckey)); + if (certs) free(certs); + if (signkeyfp) fclose(signkeyfp); + if (kenckeyfp) fclose(kenckeyfp); + return ret; +} + +int tls_init(TLS_CONNECT *conn, const TLS_CTX *ctx) +{ + size_t i; + memset(conn, 0, sizeof(*conn)); + + conn->version = ctx->protocol_version; + conn->is_client = ctx->is_client; + for (i = 0; i < ctx->cipher_suites_cnt; i++) { + conn->cipher_suites[i] = ctx->cipher_suites[i]; + } + conn->cipher_suites_cnt = ctx->cipher_suites_cnt; + if (ctx->certslen > TLS_MAX_CERTIFICATES_SIZE) { + error_print(); + return -1; + } + if (conn->is_client) { + memcpy(conn->client_certs, ctx->certs, ctx->certslen); + conn->client_certs_len = ctx->certslen; + } else { + memcpy(conn->server_certs, ctx->certs, ctx->certslen); + conn->server_certs_len = ctx->certslen; + } + + if (ctx->cacertslen > TLS_MAX_CERTIFICATES_SIZE) { + error_print(); + return -1; + } + memcpy(conn->ca_certs, ctx->cacerts, ctx->cacertslen); + conn->ca_certs_len = ctx->cacertslen; + + conn->sign_key = ctx->signkey; + conn->kenc_key = ctx->kenckey; + + return 1; +} + +void tls_cleanup(TLS_CONNECT *conn) +{ + gmssl_secure_clear(conn, sizeof(TLS_CONNECT)); +} + + +int tls_set_socket(TLS_CONNECT *conn, int sock) +{ + int opts; + + if ((opts = fcntl(sock, F_GETFL)) < 0) { + error_print(); + perror(""); + return -1; + } + opts &= ~O_NONBLOCK; + if (fcntl(sock, F_SETFL, opts) < 0) { + error_print(); + return -1; + } + conn->sock = sock; + return 1; +} + +int tls_do_handshake(TLS_CONNECT *conn) +{ + switch (conn->version) { + case TLS_version_tlcp: + if (conn->is_client) return tlcp_do_connect(conn); + else return tlcp_do_accept(conn); + /* + case TLS_version_tls12: + if (conn->is_client) return tls12_do_connect(conn); + else return tls12_do_accept(conn); + case TLS_version_tls13: + if (conn->is_client) return tls13_do_connect(conn); + else return tls13_do_accept(conn); + */ + } + error_print(); + return -1; +} + +int tls_get_verify_result(TLS_CONNECT *conn, int *result) +{ + *result = conn->verify_result; + return 1; +} diff --git a/src/tls_trace.c b/src/tls_trace.c index f69188fb..d600292d 100644 --- a/src/tls_trace.c +++ b/src/tls_trace.c @@ -899,6 +899,7 @@ int tls_application_data_print(FILE *fp, const uint8_t *data, size_t datalen, in // 当消息为ClientKeyExchange,ServerKeyExchange,需要密码套件中的密钥交换算法信息 // 当消息为加密的Finished,记录类型为Handshake,但是记录负载数据中没有Handshake头 +// 注意:这里的recordlen 是冗余的,要容忍recordlen的错误 int tls_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent) { const uint8_t *data; @@ -913,10 +914,15 @@ int tls_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int for format_print(fp, format, indent, "Record\n"); indent += 4; format_print(fp, format, indent, "ContentType: %s (%d)\n", tls_record_type_name(record[0]), record[0]); format_print(fp, format, indent, "Version: %s (%d.%d)\n", tls_version_text(version), version >> 8, version & 0xff); - format_print(fp, format, indent, "Length: %d\n", tls_record_length(record)); + format_print(fp, format, indent, "Length: %d\n", tls_record_data_length(record)); - data = record + 5; - datalen = recordlen - 5; + data = tls_record_data(record); + datalen = tls_record_data_length(record); + + if (recordlen < tls_record_length(record)) { + error_print(); + return -1; + } // 最高字节设置后强制打印记录原始数据 if (format >> 24) { @@ -954,6 +960,12 @@ int tls_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int for error_print(); return -1; } + + recordlen -= tls_record_length(record); + if (recordlen) { + format_print(fp, 0, 0, "DataLeftInRecord: %zu\n", recordlen); + } + fprintf(fp, "\n"); return 1; } @@ -965,6 +977,7 @@ int tls_secrets_print(FILE *fp, const uint8_t *key_block, size_t key_block_len, int format, int indent) { + // 应该检查一下key_block_len的值,判断是否支持,或者算法选择, 或者要求输入一个cipher_suite参数 format_bytes(stderr, format, indent, "pre_master_secret", pre_master_secret, pre_master_secret_len); format_bytes(stderr, format, indent, "client_random", client_random, 32); format_bytes(stderr, format, indent, "server_random", server_random, 32); diff --git a/src/x509_cer.c b/src/x509_cer.c index e27f0d0c..8afd09b5 100644 --- a/src/x509_cer.c +++ b/src/x509_cer.c @@ -1016,6 +1016,9 @@ int x509_certificate_print(FILE *fp, int fmt, int ind, const char *label, const size_t len; int val; + format_print(fp, fmt, ind, "%s\n", label); + ind += 4; + if (asn1_sequence_from_der(&p, &len, &d, &dlen) != 1) goto err; x509_tbs_cert_print(fp, fmt, ind, "tbsCertificate", p, len); if (x509_signature_algor_from_der(&val, &d, &dlen) != 1) goto err; @@ -1640,3 +1643,65 @@ int x509_certs_print(FILE *fp, int fmt, int ind, const char *label, const uint8_ } return 1; } + +#include +#include + +int x509_cert_new_from_file(uint8_t **out, size_t *outlen, const char *file) +{ + int ret = -1; + FILE *fp = NULL; + struct stat st; + uint8_t *buf = NULL; + size_t buflen; + + if (!(fp = fopen(file, "r")) + || fstat(fileno(fp), &st) < 0 + || (buflen = (st.st_size * 3)/4 + 1) < 0 + || (buf = malloc((st.st_size * 3)/4 + 1)) == NULL) { + error_print(); + goto end; + } + if (x509_cert_from_pem(buf, outlen, buflen, fp) != 1) { + error_print(); + goto end; + } + *out = buf; + buf = NULL; + ret = 1; +end: + if (fp) fclose(fp); + if (buf) free(buf); + return ret; +} + +int x509_certs_new_from_file(uint8_t **out, size_t *outlen, const char *file) +{ + int ret = -1; + FILE *fp = NULL; + struct stat st; + uint8_t *buf = NULL; + size_t buflen; + + if (!(fp = fopen(file, "r")) + || fstat(fileno(fp), &st) < 0 + || (buflen = (st.st_size * 3)/4 + 1) < 0 + || (buf = malloc((st.st_size * 3)/4 + 1)) == NULL) { + error_print(); + goto end; + } + if (x509_certs_from_pem(buf, outlen, buflen, fp) != 1) { + error_print(); + goto end; + } + *out = buf; + buf = NULL; + ret = 1; +end: + if (fp) fclose(fp); + if (buf) free(buf); + return ret; +} + + + diff --git a/tools/tlcp_client.c b/tools/tlcp_client.c index b89edd9c..33d32ffe 100644 --- a/tools/tlcp_client.c +++ b/tools/tlcp_client.c @@ -47,21 +47,27 @@ */ #include +#include #include #include + #include +#include +#include +#include +#include #include #include +static int client_ciphers[] = { TLCP_cipher_ecc_sm4_cbc_sm3, }; + static const char *http_get = "GET / HTTP/1.1\r\n" "Hostname: aaa\r\n" "\r\n\r\n"; - -// 虽然服务器可以用双证书,但是客户端只能使用一个证书,也就是签名证书 -static const char *options = "-host str [-port num] [-cacert file] [-cert file -key file [-pass str]]"; +static const char *options = "-host str [-port num] [-cacert file] [-cert file -key file -pass str]"; int tlcp_client_main(int argc, char *argv[]) { @@ -69,28 +75,25 @@ int tlcp_client_main(int argc, char *argv[]) char *prog = argv[0]; char *host = NULL; int port = 443; + char *cacertfile = NULL; + char *certfile = NULL; + char *keyfile = NULL; char *pass = NULL; + struct sockaddr_in server; + int sock; + TLS_CTX ctx; TLS_CONNECT conn; char buf[100] = {0}; size_t len = sizeof(buf); - char send_buf[1024] = {0}; size_t send_len; - char *file; - - FILE *cacertfp = NULL; - FILE *certfp = NULL; - FILE *keyfp = NULL; - SM2_KEY sign_key; - - if (argc < 2) { + argc--; + argv++; + if (argc < 1) { fprintf(stderr, "usage: %s %s\n", prog, options); return 1; } - - argc--; - argv++; while (argc >= 1) { if (!strcmp(*argv, "-help")) { printf("usage: %s %s\n", prog, options); @@ -103,25 +106,13 @@ int tlcp_client_main(int argc, char *argv[]) port = atoi(*(++argv)); } else if (!strcmp(*argv, "-cacert")) { if (--argc < 1) goto bad; - file = *(++argv); - if (!(cacertfp = fopen(file, "r"))) { - error_print(); - return -1; - } + cacertfile = *(++argv); } else if (!strcmp(*argv, "-cert")) { if (--argc < 1) goto bad; - file = *(++argv); - if (!(certfp = fopen(file, "r"))) { - error_print(); - return -1; - } + certfile = *(++argv); } else if (!strcmp(*argv, "-key")) { if (--argc < 1) goto bad; - file = *(++argv); - if (!(keyfp = fopen(file, "r"))) { - error_print(); - return -1; - } + keyfile = *(++argv); } else if (!strcmp(*argv, "-pass")) { if (--argc < 1) goto bad; pass = *(++argv); @@ -137,35 +128,45 @@ bad: } if (!host) { - error_print(); + fprintf(stderr, "%s: '-in' option required\n", prog); return -1; } - if (certfp) { - if (!keyfp) { - error_print(); - return -1; - } - if (!pass) { - pass = getpass("Password : "); - } - if (sm2_private_key_info_decrypt_from_pem(&sign_key, pass, keyfp) != 1) { - error_print(); - return -1; - } - } - + memset(&ctx, 0, sizeof(ctx)); memset(&conn, 0, sizeof(conn)); - if (tlcp_connect(&conn, host, port, cacertfp, certfp, &sign_key) != 1) { - error_print(); - return -1; + server.sin_addr.s_addr = inet_addr(host); + server.sin_family = AF_INET; + server.sin_port = htons(port); + + + if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + fprintf(stderr, "%s: open socket error : %s\n", prog, strerror(errno)); + goto end; + } + if (connect(sock, (struct sockaddr *)&server , sizeof(server)) < 0) { + fprintf(stderr, "%s: connect error : %s\n", prog, strerror(errno)); + goto end; + } + + if (tls_ctx_init(&ctx, TLS_version_tlcp, TLS_client_mode) != 1 + || tls_ctx_set_cipher_suites(&ctx, client_ciphers, sizeof(client_ciphers)/sizeof(client_ciphers[0])) != 1 + || tls_ctx_set_ca_certificates(&ctx, cacertfile, TLS_DEFAULT_VERIFY_DEPTH) != 1 + || tls_ctx_set_certificate_and_key(&ctx, certfile, keyfile, pass) != 1) { + fprintf(stderr, "%s: context init error\n", prog); + goto end; + } + if (tls_init(&conn, &ctx) != 1 + || tls_set_socket(&conn, sock) != 1 + || tls_do_handshake(&conn) != 1) { + fprintf(stderr, "%s: error\n", prog); + goto end; } for (;;) { + size_t sentlen; memset(send_buf, 0, sizeof(send_buf)); - if (!fgets(send_buf, sizeof(send_buf), stdin)) { if (feof(stdin)) { tls_shutdown(&conn); @@ -174,28 +175,26 @@ bad: continue; } } - - if (tls_send(&conn, (uint8_t *)send_buf, strlen(send_buf)) != 1) { - error_print(); - return -1; + if (tls_send(&conn, (uint8_t *)send_buf, strlen(send_buf), &sentlen) != 1) { + fprintf(stderr, "%s: send error\n", prog); + goto end; } - for (;;) { + { memset(buf, 0, sizeof(buf)); len = sizeof(buf); - if (tls_recv(&conn, (uint8_t *)buf, &len) != 1) { - error_print(); - return -1; - } - if (len > 0) { - printf("%s\n", buf); - break; + if (tls_recv(&conn, (uint8_t *)buf, sizeof(len), &len) != 1) { + goto end; } + buf[len] = 0; + printf("%s\n", buf); } - } end: + close(sock); + tls_ctx_cleanup(&ctx); + tls_cleanup(&conn); return 0; } diff --git a/tools/tlcp_server.c b/tools/tlcp_server.c index e6c352e7..b97829d8 100644 --- a/tools/tlcp_server.c +++ b/tools/tlcp_server.c @@ -51,9 +51,14 @@ #include #include #include +#include +#include +#include +#include #include #include #include +#include static const char *options = "[-port num] -cert file -key file [-pass str] -ex_key file [-ex_pass str] [-cacert file]"; @@ -70,19 +75,20 @@ int tlcp_server_main(int argc , char **argv) char *encpass = NULL; char *cacertfile = NULL; - FILE *certfp = NULL; - FILE *signkeyfp = NULL; - FILE *enckeyfp = NULL; - FILE *cacertfp = NULL; - SM2_KEY signkey; - SM2_KEY enckey; - + int server_ciphers[] = { TLCP_cipher_ecc_sm4_cbc_sm3, }; uint8_t verify_buf[4096]; + TLS_CTX ctx; TLS_CONNECT conn; char buf[1600] = {0}; size_t len = sizeof(buf); + int sock; + struct sockaddr_in server_addr; + struct sockaddr_in client_addr; + socklen_t client_addrlen; + int conn_sock; + argc--; argv++; @@ -102,37 +108,21 @@ int tlcp_server_main(int argc , char **argv) } else if (!strcmp(*argv, "-cert")) { if (--argc < 1) goto bad; certfile = *(++argv); - if (!(certfp = fopen(certfile, "r"))) { - fprintf(stderr, "%s: open '%s' failure : %s\n", prog, certfile, strerror(errno)); - goto end; - } } else if (!strcmp(*argv, "-key")) { if (--argc < 1) goto bad; signkeyfile = *(++argv); - if (!(signkeyfp = fopen(signkeyfile, "r"))) { - fprintf(stderr, "%s: open '%s' failure : %s\n", prog, signkeyfile, strerror(errno)); - goto end; - } } else if (!strcmp(*argv, "-pass")) { if (--argc < 1) goto bad; signpass = *(++argv); } else if (!strcmp(*argv, "-ex_key")) { if (--argc < 1) goto bad; enckeyfile = *(++argv); - if (!(enckeyfp = fopen(enckeyfile, "r"))) { - fprintf(stderr, "%s: open '%s' failure : %s\n", prog, enckeyfile, strerror(errno)); - goto end; - } } else if (!strcmp(*argv, "-ex_pass")) { if (--argc < 1) goto bad; encpass = *(++argv); } else if (!strcmp(*argv, "-cacert")) { if (--argc < 1) goto bad; cacertfile = *(++argv); - if (!(cacertfp = fopen(cacertfile, "r"))) { - fprintf(stderr, "%s: open '%s' failure : %s\n", prog, cacertfile, strerror(errno)); - goto end; - } } else { fprintf(stderr, "%s: invalid option '%s'\n", prog, *argv); return 1; @@ -145,60 +135,97 @@ bad: } if (!certfile) { fprintf(stderr, "%s: '-cert' option required\n", prog); - goto end; + return 1; } if (!signkeyfile) { fprintf(stderr, "%s: '-key' option required\n", prog); - goto end; + return 1; } if (!signpass) { fprintf(stderr, "%s: '-pass' option required\n", prog); - goto end; + return 1; } if (!enckeyfile) { fprintf(stderr, "%s: '-ex_key' option required\n", prog); - goto end; + return 1; } if (!encpass) { fprintf(stderr, "%s: '-ex_pass' option required\n", prog); - goto end; + return 1; } - if (sm2_private_key_info_decrypt_from_pem(&signkey, signpass, signkeyfp) != 1) { - fprintf(stderr, "%s: load private key failure\n", prog); - goto end; - } - if (sm2_private_key_info_decrypt_from_pem(&enckey, encpass, enckeyfp) != 1) { - fprintf(stderr, "%s: load private key failure\n", prog); - goto end; - } - - printf("start ...........\n"); - -restart: + memset(&ctx, 0, sizeof(ctx)); memset(&conn, 0, sizeof(conn)); - if (tlcp_accept(&conn, port, certfp, &signkey, &enckey, cacertfp, verify_buf, 4096) != 1) { - fprintf(stderr, "%s: tlcp accept failure\n", prog); + if (tls_ctx_init(&ctx, TLS_version_tlcp, TLS_server_mode) != 1 + || tls_ctx_set_cipher_suites(&ctx, server_ciphers, sizeof(server_ciphers)/sizeof(int)) != 1 + || tls_ctx_set_tlcp_server_certificate_and_keys(&ctx, certfile, signkeyfile, signpass, enckeyfile, encpass) != 1) { + error_print(); + return -1; + } + if (cacertfile) { + if (tls_ctx_set_ca_certificates(&ctx, cacertfile, TLS_DEFAULT_VERIFY_DEPTH) != 1) { + error_print(); + return -1; + } + } + + // Socket + if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + error_print(); + return 1; + } + server_addr.sin_family = AF_INET; + server_addr.sin_addr.s_addr = INADDR_ANY; + server_addr.sin_port = htons(port); + if (bind(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) { + error_print(); + perror("tlcp_accept: bind: "); goto end; } + puts("start listen ...\n"); + listen(sock, 1); + + + +restart: + + client_addrlen = sizeof(client_addr); + if ((conn_sock = accept(sock, (struct sockaddr *)&client_addr, &client_addrlen)) < 0) { + error_print(); + goto end; + } + puts("socket connected\n"); + + if (tls_init(&conn, &ctx) != 1 + || tls_set_socket(&conn, conn_sock) != 1) { + error_print(); + return -1; + } + + if (tls_do_handshake(&conn) != 1) { + error_print(); // 为什么这个会触发呢? + return -1; + } for (;;) { int rv; + size_t sentlen; do { len = sizeof(buf); - if ((rv = tls_recv(&conn, (uint8_t *)buf, &len)) != 1) { + if ((rv = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len)) != 1) { if (rv < 0) fprintf(stderr, "%s: recv failure\n", prog); else fprintf(stderr, "%s: Disconnected by remote\n", prog); - close(conn.sock); + //close(conn.sock); + tls_cleanup(&conn); goto restart; } } while (!len); - if (tls_send(&conn, (uint8_t *)buf, len) != 1) { + if (tls_send(&conn, (uint8_t *)buf, len, &sentlen) != 1) { fprintf(stderr, "%s: send failure, close connection\n", prog); close(conn.sock); goto end; @@ -207,11 +234,5 @@ restart: end: - gmssl_secure_clear(&signkey, sizeof(signkey)); - gmssl_secure_clear(&enckey, sizeof(enckey)); - if (certfp) fclose(certfp); - if (signkeyfp) fclose(signkeyfp); - if (enckeyfp) fclose(enckeyfp); - if (cacertfp) fclose(cacertfp); return ret; }