diff --git a/CMakeLists.txt b/CMakeLists.txt index 1c14ca90..c98887fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,8 +63,9 @@ add_library( src/sm9_lib.c src/tlcp.c src/tls.c -# src/tls12.c + src/tls12.c # src/tls13.c + src/tls_ext.c src/tls_trace.c src/version.c src/x509_alg.c @@ -135,8 +136,8 @@ add_executable( tools/reqsign.c tools/tlcp_client.c tools/tlcp_server.c -# tools/tls12_client.c -# tools/tls12_server.c + tools/tls12_client.c + tools/tls12_server.c # tools/tls13_client.c # tools/tls13_server.c tools/sdfutil.c diff --git a/include/gmssl/tls.h b/include/gmssl/tls.h index 71cedb6a..8596e2a9 100644 --- a/include/gmssl/tls.h +++ b/include/gmssl/tls.h @@ -63,6 +63,40 @@ extern "C" { #endif + +/* +TLS Public API + + TLS_PROTOCOL + TLS_protocol_tlcp + TLS_protocol_tls12 + TLS_protocol_tls13 + + TLS_CIPHER_SUITE + TLS_cipher_ecc_sm4_cbc_sm3 + TLS_cipher_ecc_sm4_gcm_sm3 + TLS_cipher_ecdhe_sm4_cbc_sm3 + TLS_cipher_ecdhe_sm4_gcm_sm3 + TLS_cipher_sm4_gcm_sm3 + + TLS_CTX + tls_ctx_init + tls_ctx_set_cipher_suites + tls_ctx_set_ca_certificates + tls_ctx_set_certificate_and_key + tls_ctx_set_tlcp_server_certificate_and_keys + tls_ctx_cleanup + + TLS_CONNECT + tls_init + tls_set_socket + tls_do_handshake + tls_send + tls_recv + tls_shutdown + tls_cleanup +*/ + typedef uint32_t uint24_t; #define tls_uint8_size() 1 @@ -85,21 +119,22 @@ int tls_array_from_bytes(const uint8_t **data, size_t datalen, const uint8_t **i int tls_uint8array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen); int tls_uint16array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen); int tls_uint24array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen); +int tls_length_is_zero(size_t len); typedef enum { - TLS_version_tlcp = 0x0101, - TLS_version_ssl2 = 0x0200, - TLS_version_ssl3 = 0x0300, - TLS_version_tls1 = 0x0301, - TLS_version_tls11 = 0x0302, - TLS_version_tls12 = 0x0303, - TLS_version_tls13 = 0x0304, - TLS_version_dtls1 = 0xfeff, // {254, 255} - TLS_version_dtls12 = 0xfefd, // {254, 253} -} TLS_VERSION; + TLS_protocol_tlcp = 0x0101, + TLS_protocol_ssl2 = 0x0200, + TLS_protocol_ssl3 = 0x0300, + TLS_protocol_tls1 = 0x0301, + TLS_protocol_tls11 = 0x0302, + TLS_protocol_tls12 = 0x0303, + TLS_protocol_tls13 = 0x0304, + TLS_protocol_dtls1 = 0xfeff, // {254, 255} + TLS_protocol_dtls12 = 0xfefd, // {254, 253} +} TLS_PROTOCOL; -const char *tls_version_text(int version); +const char *tls_protocol_name(int proto); typedef enum { @@ -110,24 +145,18 @@ typedef enum { TLS_cipher_sm4_ccm_sm3 = 0x00c7, // TLCP, GB/T 38636-2020, GM/T 0024-2012 - TLCP_cipher_ecdhe_sm4_cbc_sm3 = 0xe011, - TLCP_cipher_ecdhe_sm4_gcm_sm3 = 0xe051, - TLCP_cipher_ecc_sm4_cbc_sm3 = 0xe013, - TLCP_cipher_ecc_sm4_gcm_sm3 = 0xe053, - TLCP_cipher_ibsdh_sm4_cbc_sm3 = 0xe015, - TLCP_cipher_ibsdh_sm4_gcm_sm3 = 0xe055, - TLCP_cipher_ibc_sm4_cbc_sm3 = 0xe017, - TLCP_cipher_ibc_sm4_gcm_sm3 = 0xe057, - TLCP_cipher_rsa_sm4_cbc_sm3 = 0xe019, - TLCP_cipher_rsa_sm4_gcm_sm3 = 0xe059, - TLCP_cipher_rsa_sm4_cbc_sha256 = 0xe01c, - TLCP_cipher_rsa_sm4_gcm_sha256 = 0xe05a, - - // GmSSL v2.5 - GMSSL_cipher_ecdhe_sm2_with_sm4_sm3 = 0xe102, - GMSSL_cipher_ecdhe_sm2_with_sm4_gcm_sm3 = 0xe107, - GMSSL_cipher_ecdhe_sm2_with_sm4_ccm_sm3 = 0xe108, - GMSSL_cipher_ecdhe_sm2_with_zuc_sm3 = 0xe10d, + TLS_cipher_ecdhe_sm4_cbc_sm3 = 0xe011, // 可以让TLSv1.2使用这个 + TLS_cipher_ecdhe_sm4_gcm_sm3 = 0xe051, + TLS_cipher_ecc_sm4_cbc_sm3 = 0xe013, + TLS_cipher_ecc_sm4_gcm_sm3 = 0xe053, + TLS_cipher_ibsdh_sm4_cbc_sm3 = 0xe015, + TLS_cipher_ibsdh_sm4_gcm_sm3 = 0xe055, + TLS_cipher_ibc_sm4_cbc_sm3 = 0xe017, + TLS_cipher_ibc_sm4_gcm_sm3 = 0xe057, + TLS_cipher_rsa_sm4_cbc_sm3 = 0xe019, + TLS_cipher_rsa_sm4_gcm_sm3 = 0xe059, + TLS_cipher_rsa_sm4_cbc_sha256 = 0xe01c, + TLS_cipher_rsa_sm4_gcm_sha256 = 0xe05a, // TLS 1.3 RFC 8446 TLS_cipher_aes_128_gcm_sha256 = 0x1301, // Mandatory-to-implement @@ -256,7 +285,7 @@ typedef enum { TLS_extension_supported_ekt_ciphers = 39, TLS_extension_pre_shared_key = 41, TLS_extension_early_data = 42, - TLS_extension_supported_versions = 43, + TLS_extension_supported_protocols = 43, TLS_extension_cookie = 44, TLS_extension_psk_key_exchange_modes = 46, TLS_extension_certificate_authorities = 47, @@ -273,6 +302,8 @@ typedef enum { TLS_extension_renegotiation_info = 65281, } TLS_EXTENSION_TYPE; +const char *tls_extension_name(int ext); + typedef enum { TLS_point_uncompressed = 0, @@ -282,6 +313,7 @@ typedef enum { const char *tls_ec_point_format_name(int format); + typedef enum { TLS_curve_type_explicit_prime = 1, TLS_curve_type_explicit_char2 = 2, @@ -290,6 +322,9 @@ typedef enum { const char *tls_curve_type_name(int type); + +// 与其支持v2,还不如直接修改v2,让v2和v3兼容 + typedef enum { TLS_curve_secp256k1 = 22, TLS_curve_secp256r1 = 23, @@ -299,11 +334,11 @@ typedef enum { TLS_curve_brainpoolp384r1 = 27, TLS_curve_brainpoolp512r1 = 28, TLS_curve_x25519 = 29, - TLS_curve_x448 = 99, //30, 应该用一个宏来处理 + TLS_curve_x448 = 30, TLS_curve_brainpoolp256r1tls13 = 31, TLS_curve_brainpoolp384r1tls13 = 32, TLS_curve_brainpoolp512r1tls13 = 33, - TLS_curve_sm2p256v1 = 30,//41, // in gmssl v2, is 30 + TLS_curve_sm2p256v1 = 41, // GmSSLv2: 30 } TLS_NAMED_CURVE; const char *tls_named_curve_name(int curve); @@ -321,7 +356,7 @@ typedef enum { TLS_sig_rsa_pkcs1_sha512 = 0x0601, TLS_sig_ecdsa_secp521r1_sha512 = 0x0603, TLS_sig_rsa_pkcs1_sha512_legacy = 0x0620, - TLS_sig_sm2sig_sm3 = 0x0707,//0x0708, // is 0707 in gmsslv2 + TLS_sig_sm2sig_sm3 = 0x0708, // GmSSLv2: 0x0707 TLS_sig_rsa_pss_rsae_sha256 = 0x0804, TLS_sig_rsa_pss_rsae_sha384 = 0x0805, TLS_sig_rsa_pss_rsae_sha512 = 0x0806, @@ -414,7 +449,7 @@ int tls_record_decrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key, int tls_seq_num_incr(uint8_t seq_num[8]); int tls_random_generate(uint8_t random[32]); int tls_random_print(FILE *fp, const uint8_t random[32], int format, int indent); -int tls_pre_master_secret_generate(uint8_t pre_master_secret[48], int version); +int tls_pre_master_secret_generate(uint8_t pre_master_secret[48], int protocol); int tls_pre_master_secret_print(FILE *fp, const uint8_t pre_master_secret[48], int format, int indent); int tls_secrets_print(FILE *fp, @@ -427,7 +462,7 @@ int tls_secrets_print(FILE *fp, typedef struct { uint8_t type; - uint8_t version[2]; + uint8_t protocol[2]; uint8_t data_length[2]; } TLS_RECORD_HEADER; @@ -439,13 +474,13 @@ typedef struct { #define tls_record_type(record) ((record)[0]) #define tls_record_header(record) ((record)+0) -#define tls_record_version(record) (((uint16_t)((record)[1]) << 8) | (record)[2]) +#define tls_record_protocol(record) (((uint16_t)((record)[1]) << 8) | (record)[2]) #define tls_record_data(record) ((record)+TLS_RECORD_HEADER_SIZE) #define tls_record_data_length(record) (((uint16_t)((record)[3]) << 8) | (record)[4]) #define tls_record_length(record) (TLS_RECORD_HEADER_SIZE + tls_record_data_length(record)) int tls_record_set_type(uint8_t *record, int type); -int tls_record_set_version(uint8_t *record, int version); +int tls_record_set_protocol(uint8_t *record, int protocol); int tls_record_set_data_length(uint8_t *record, size_t length); int tls_record_set_data(uint8_t *record, const uint8_t *data, size_t datalen); @@ -488,23 +523,23 @@ int tls_hello_request_print(FILE *fp, const uint8_t *data, size_t datalen, int f #define TLS_MAX_SESSION_ID_SIZE 32 int tls_record_set_handshake_client_hello(uint8_t *record, size_t *recordlen, - int client_version, const uint8_t random[32], + int client_protocol, const uint8_t random[32], const uint8_t *session_id, size_t session_id_len, const int *cipher_suites, size_t cipher_suites_count, const uint8_t *exts, size_t exts_len); int tls_record_get_handshake_client_hello(const uint8_t *record, - int *client_version, const uint8_t **random, + int *client_protocol, const uint8_t **random, const uint8_t **session_id, size_t *session_id_len, const uint8_t **cipher_suites, size_t *cipher_suites_len, const uint8_t **exts, size_t *exts_len); int tls_client_hello_print(FILE *fp, const uint8_t *data, size_t datalen, int format, int indent); int tls_record_set_handshake_server_hello(uint8_t *record, size_t *recordlen, - int server_version, const uint8_t random[32], + int server_protocol, const uint8_t random[32], const uint8_t *session_id, size_t session_id_len, int cipher_suite, const uint8_t *exts, size_t exts_len); int tls_record_get_handshake_server_hello(const uint8_t *record, - int *version, const uint8_t **random, const uint8_t **session_id, size_t *session_id_len, + int *protocol, const uint8_t **random, const uint8_t **session_id, size_t *session_id_len, int *cipher_suite, const uint8_t **exts, size_t *exts_len); int tls_server_hello_print(FILE *fp, const uint8_t *server_hello, size_t len, int format, int indent); @@ -512,6 +547,18 @@ int tls_server_hello_print(FILE *fp, const uint8_t *server_hello, size_t len, in int tls_ext_signature_algors_to_bytes(const int *algors, size_t algors_count, uint8_t **out, size_t *outlen); +int tls_exts_add_ec_point_formats(uint8_t *exts, size_t *extslen, size_t maxlen, const int *formats, size_t formats_cnt); +int tls_exts_add_supported_groups(uint8_t *exts, size_t *extslen, size_t maxlen, const int *curves, size_t curves_cnt); +int tls_exts_add_signature_algors(uint8_t *exts, size_t *extslen, size_t maxlen, const int *algs, size_t algs_cnt); + +int tls_ext_from_bytes(int *type, const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen); +int tls_process_client_ec_point_formats(const uint8_t *data, size_t datalen, uint8_t *exts, size_t *extslen, size_t maxlen); +int tls_process_client_signature_algorithms(const uint8_t *data, size_t datalen, uint8_t *exts, size_t *extslen, size_t maxlen); +int tls_process_client_supported_groups(const uint8_t *data, size_t datalen, uint8_t *exts, size_t *extslen, size_t maxlen); +int tls_process_client_exts(const uint8_t *exts, size_t extslen, uint8_t *out, size_t *outlen, size_t maxlen); + +int tls_process_server_exts(const uint8_t *exts, size_t extslen, int *ec_point_format, int *supported_group, int *signature_algor); + // Certificate int tls_record_set_handshake_certificate(uint8_t *record, size_t *recordlen, const uint8_t *certs, size_t certslen); @@ -532,7 +579,7 @@ int tls_verify_server_ecdh_params(const SM2_KEY *server_sign_key, int tls_record_set_handshake_server_key_exchange_ecdhe(uint8_t *record, size_t *recordlen, int curve, const SM2_POINT *point, const uint8_t *sig, size_t siglen); int tls_record_get_handshake_server_key_exchange_ecdhe(const uint8_t *record, - int *curve, SM2_POINT *point, uint8_t *sig, size_t *siglen); + int *curve, SM2_POINT *point, const uint8_t **sig, size_t *siglen); int tls_server_key_exchange_ecdhe_print(FILE *fp, const uint8_t *data, size_t datalen, int format, int indent); @@ -575,8 +622,8 @@ int tls_client_key_exchange_pke_print(FILE *fp, const uint8_t *cke, size_t ckele int tls_client_key_exchange_print(FILE *fp, const uint8_t *cke, size_t ckelen, int format, int indent); int tls_record_set_handshake_client_key_exchange_ecdhe(uint8_t *record, size_t *recordlen, - const SM2_POINT *point); -int tls_record_get_handshake_client_key_exchange_ecdhe(const uint8_t *record, SM2_POINT *point); + const SM2_POINT *point); // 这里不应该支持SM2_POINT类型 +int tls_record_get_handshake_client_key_exchange_ecdhe(const uint8_t *record, SM2_POINT *point); int tls_client_key_exchange_ecdhe_print(FILE *fp, const uint8_t *data, size_t datalen, int format, int indent); @@ -610,12 +657,14 @@ int tls_client_verify_finish(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *sig, siz void tls_client_verify_cleanup(TLS_CLIENT_VERIFY_CTX *ctx); // Finished +// FIXME: 支持TLS 1.3 提供MIN, MAX或TLS12, TLS13, TLCP... #define TLS_VERIFY_DATA_SIZE 12 // TLS 1.3或者其他版本支持更长的verify_data #define TLS_FINISHED_RECORD_SIZE (TLS_RECORD_HEADER_SIZE + TLS_HANDSHAKE_HEADER_SIZE + TLS_VERIFY_DATA_SIZE) // 21 #define TLS_MAX_PADDING_SIZE (1 + 255) #define TLS_MAC_SIZE SM3_HMAC_SIZE #define TLS_FINISHED_RECORD_BUF_SIZE (TLS_FINISHED_RECORD_SIZE + TLS_MAC_SIZE + TLS_MAX_PADDING_SIZE) // 309 + int tls_record_set_handshake_finished(uint8_t *record, size_t *recordlen, const uint8_t *verify_data, size_t verify_data_len); int tls_record_get_handshake_finished(const uint8_t *record, @@ -663,7 +712,7 @@ enum { #define TLS_MAX_CIPHER_SUITES_COUNT 64 typedef struct { - int protocol_version; + int protocol; int is_client; int cipher_suites[TLS_MAX_CIPHER_SUITES_COUNT]; size_t cipher_suites_cnt; @@ -676,7 +725,7 @@ typedef struct { int verify_depth; } TLS_CTX; -int tls_ctx_init(TLS_CTX *ctx, int version, int is_client); +int tls_ctx_init(TLS_CTX *ctx, int protocol, int is_client); int tls_ctx_set_cipher_suites(TLS_CTX *ctx, const int *cipher_suites, size_t cipher_suites_cnt); int tls_ctx_set_ca_certificates(TLS_CTX *ctx, const char *cacertsfile, int depth); int tls_ctx_set_certificate_and_key(TLS_CTX *ctx, const char *chainfile, @@ -694,7 +743,7 @@ void tls_ctx_cleanup(TLS_CTX *ctx); typedef struct { - int version; + int protocol; int is_client; int cipher_suites[TLS_MAX_CIPHER_SUITES_COUNT]; size_t cipher_suites_cnt; @@ -735,6 +784,9 @@ typedef struct { } TLS_CONNECT; +#define TLS_MAX_EXTENSIONS_SIZE 512 // 这个应该再考虑一下数值,是否可以用其他的缓冲区装载? + + int tls_init(TLS_CONNECT *conn, const TLS_CTX *ctx); int tls_set_socket(TLS_CONNECT *conn, int sock); int tls_do_handshake(TLS_CONNECT *conn); @@ -757,16 +809,22 @@ int tls13_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen, size_t pa int tls13_recv(TLS_CONNECT *conn, uint8_t *data, size_t *datalen); + + + + #define TLS_DEBUG #ifdef TLS_DEBUG # define tls_trace(s) fprintf(stderr,(s)) # define tls_record_trace(fp,rec,reclen,fmt,ind) tls_record_print(fp,rec,reclen,fmt,ind) # define tlcp_record_trace(fp,rec,reclen,fmt,ind) tlcp_record_print(fp,rec,reclen,fmt,ind) +# define tls12_record_trace(fp,rec,reclen,fmt,ind) tls12_record_print(fp,rec,reclen,fmt,ind) #else # define tls_trace(s) # define tls_record_trace(fp,rec,reclen,fmt,ind) # define tlcp_record_trace(fp,rec,reclen,fmt,ind) +# define tls12_record_trace(fp,rec,reclen,fmt,ind) #endif diff --git a/src/tlcp.c b/src/tlcp.c index 8a0b5038..743549a9 100644 --- a/src/tlcp.c +++ b/src/tlcp.c @@ -68,7 +68,7 @@ #include -static const int tlcp_ciphers[] = { TLCP_cipher_ecc_sm4_cbc_sm3 }; +static const int tlcp_ciphers[] = { TLS_cipher_ecc_sm4_cbc_sm3 }; static const size_t tlcp_ciphers_count = sizeof(tlcp_ciphers)/sizeof(tlcp_ciphers[0]); int tlcp_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent) @@ -94,7 +94,7 @@ int tlcp_record_set_handshake_server_key_exchange_pke(uint8_t *record, size_t *r error_print(); return -1; } - if (tls_record_version(record) != TLS_version_tlcp) { + if (tls_record_protocol(record) != TLS_protocol_tlcp) { error_print(); return -1; } @@ -125,7 +125,7 @@ int tlcp_record_get_handshake_server_key_exchange_pke(const uint8_t *record, error_print(); return -1; } - if (tls_record_version(record) != TLS_version_tlcp) { + if (tls_record_protocol(record) != TLS_protocol_tlcp) { error_print(); return -1; } @@ -168,7 +168,7 @@ int tlcp_do_connect(TLS_CONNECT *conn) uint8_t client_random[32]; uint8_t server_random[32]; - int version; + int protocol; int cipher_suite; const uint8_t *random; const uint8_t *session_id; @@ -206,8 +206,8 @@ int tlcp_do_connect(TLS_CONNECT *conn) // 初始化记录缓冲 - tls_record_set_version(record, TLS_version_tlcp); - tls_record_set_version(finished_record, TLS_version_tlcp); + tls_record_set_protocol(record, TLS_protocol_tlcp); + tls_record_set_protocol(finished_record, TLS_protocol_tlcp); // 准备Finished Context(和ClientVerify) sm3_init(&sm3_ctx); @@ -218,7 +218,7 @@ int tlcp_do_connect(TLS_CONNECT *conn) // send ClientHello tls_random_generate(client_random); if (tls_record_set_handshake_client_hello(record, &recordlen, - TLS_version_tlcp, client_random, NULL, 0, + TLS_protocol_tlcp, client_random, NULL, 0, tlcp_ciphers, tlcp_ciphers_count, NULL, 0) != 1) { error_print(); goto end; @@ -241,19 +241,19 @@ int tlcp_do_connect(TLS_CONNECT *conn) goto end; } tlcp_record_trace(stderr, record, recordlen, 0, 0); - if (tls_record_version(record) != TLS_version_tlcp) { + if (tls_record_protocol(record) != TLS_protocol_tlcp) { error_print(); tls_send_alert(conn, TLS_alert_protocol_version); goto end; } if (tls_record_get_handshake_server_hello(record, - &version, &random, &session_id, &session_id_len, &cipher_suite, + &protocol, &random, &session_id, &session_id_len, &cipher_suite, &exts, &exts_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); goto end; } - if (version != TLS_version_tlcp) { + if (protocol != TLS_protocol_tlcp) { tls_send_alert(conn, TLS_alert_protocol_version); error_print(); goto end; @@ -278,7 +278,7 @@ int tlcp_do_connect(TLS_CONNECT *conn) // recv ServerCertificate tls_trace("recv ServerCertificate\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1 - || tls_record_version(record) != TLS_version_tlcp) { + || tls_record_protocol(record) != TLS_protocol_tlcp) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); goto end; @@ -306,15 +306,15 @@ int tlcp_do_connect(TLS_CONNECT *conn) // recv ServerKeyExchange tls_trace("recv ServerKeyExchange\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1 - || tls_record_version(record) != TLS_version_tlcp) { + || tls_record_protocol(record) != TLS_protocol_tlcp) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); goto end; } tlcp_record_trace(stderr, record, recordlen, 0, 0); if (tlcp_record_get_handshake_server_key_exchange_pke(record, &sig, &siglen) != 1) { - tls_send_alert(conn, TLS_alert_unexpected_message); error_print(); + tls_send_alert(conn, TLS_alert_unexpected_message); goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); @@ -349,7 +349,7 @@ int tlcp_do_connect(TLS_CONNECT *conn) // recv CertificateRequest or ServerHelloDone if (tls_record_recv(record, &recordlen, conn->sock) != 1 - || tls_record_version(record) != TLS_version_tlcp + || tls_record_protocol(record) != TLS_protocol_tlcp || tls_record_get_handshake(record, &handshake_type, &cp, &len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); @@ -386,7 +386,7 @@ int tlcp_do_connect(TLS_CONNECT *conn) // recv ServerHelloDone if (tls_record_recv(record, &recordlen, conn->sock) != 1 - || tls_record_version(record) != TLS_version_tlcp) { + || tls_record_protocol(record) != TLS_protocol_tlcp) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); goto end; @@ -427,7 +427,7 @@ int tlcp_do_connect(TLS_CONNECT *conn) // generate MASTER_SECRET tls_trace("generate secrets\n"); - if (tls_pre_master_secret_generate(pre_master_secret, TLS_version_tlcp) != 1 + if (tls_pre_master_secret_generate(pre_master_secret, TLS_protocol_tlcp) != 1 || tls_prf(pre_master_secret, 48, "master secret", client_random, 32, server_random, 32, 48, conn->master_secret) != 1 @@ -532,7 +532,7 @@ int tlcp_do_connect(TLS_CONNECT *conn) // [ChangeCipherSpec] tls_trace("recv [ChangeCipherSpec]\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1 - || tls_record_version(record) != TLS_version_tlcp) { + || tls_record_protocol(record) != TLS_protocol_tlcp) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); goto end; @@ -547,7 +547,7 @@ int tlcp_do_connect(TLS_CONNECT *conn) // Finished tls_trace("recv Finished\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1 - || tls_record_version(record) != TLS_version_tlcp) { + || tls_record_protocol(record) != TLS_protocol_tlcp) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); goto end; @@ -592,7 +592,7 @@ int tlcp_do_connect(TLS_CONNECT *conn) tls_trace("Connection established!\n"); - conn->version = TLS_version_tlcp; + conn->protocol = TLS_protocol_tlcp; conn->cipher_suite = cipher_suite; ret = 1; @@ -612,12 +612,12 @@ int tlcp_do_accept(TLS_CONNECT *conn) uint8_t *record = conn->record; uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE]; // 解密可能导致前面的record被覆盖 size_t recordlen, finished_record_len; - const int server_ciphers[] = { TLCP_cipher_ecc_sm4_cbc_sm3 }; // 未来应该支持GCM/CBC两个套件 + const int server_ciphers[] = { TLS_cipher_ecc_sm4_cbc_sm3 }; // 未来应该支持GCM/CBC两个套件 // ClientHello, ServerHello uint8_t client_random[32]; uint8_t server_random[32]; - int version; + int protocol; const uint8_t *random; const uint8_t *session_id; // TLCP服务器忽略客户端SessionID,也不主动设置SessionID size_t session_id_len; @@ -678,20 +678,20 @@ int tlcp_do_accept(TLS_CONNECT *conn) goto end; } tlcp_record_trace(stderr, record, recordlen, 0, 0); - if (tls_record_version(record) != TLS_version_tlcp) { + if (tls_record_protocol(record) != TLS_protocol_tlcp) { error_print(); tls_send_alert(conn, TLS_alert_protocol_version); goto end; } if (tls_record_get_handshake_client_hello(record, - &version, &random, &session_id, &session_id_len, + &protocol, &random, &session_id, &session_id_len, &client_ciphers, &client_ciphers_len, &exts, &exts_len) != 1) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); goto end; } - if (version != TLS_version_tlcp) { + if (protocol != TLS_protocol_tlcp) { error_print(); tls_send_alert(conn, TLS_alert_protocol_version); goto end; @@ -719,7 +719,7 @@ int tlcp_do_accept(TLS_CONNECT *conn) tls_trace("send ServerHello\n"); tls_random_generate(server_random); if (tls_record_set_handshake_server_hello(record, &recordlen, - TLS_version_tlcp, server_random, NULL, 0, + TLS_protocol_tlcp, server_random, NULL, 0, conn->cipher_suite, NULL, 0) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); @@ -827,7 +827,7 @@ int tlcp_do_accept(TLS_CONNECT *conn) if (conn->ca_certs_len) { tls_trace("recv ClientCertificate\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1 - || tls_record_version(record) != TLS_version_tlcp) { + || tls_record_protocol(record) != TLS_protocol_tlcp) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); goto end; @@ -851,7 +851,7 @@ int tlcp_do_accept(TLS_CONNECT *conn) // ClientKeyExchange tls_trace("recv ClientKeyExchange\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1 - || tls_record_version(record) != TLS_version_tlcp) { + || tls_record_protocol(record) != TLS_protocol_tlcp) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); goto end; @@ -881,7 +881,7 @@ int tlcp_do_accept(TLS_CONNECT *conn) if (client_verify) { tls_trace("recv CertificateVerify\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1 - || tls_record_version(record) != TLS_version_tlcp) { + || tls_record_protocol(record) != TLS_protocol_tlcp) { tls_send_alert(conn, TLS_alert_unexpected_message); error_print(); goto end; @@ -932,7 +932,7 @@ int tlcp_do_accept(TLS_CONNECT *conn) // recv [ChangeCipherSpec] tls_trace("recv [ChangeCipherSpec]\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1 - || tls_record_version(record) != TLS_version_tlcp) { + || tls_record_protocol(record) != TLS_protocol_tlcp) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); goto end; @@ -947,7 +947,7 @@ int tlcp_do_accept(TLS_CONNECT *conn) // recv ClientFinished tls_trace("recv Finished\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1 - || tls_record_version(record) != TLS_version_tlcp) { + || tls_record_protocol(record) != TLS_protocol_tlcp) { error_print(); tls_send_alert(conn, TLS_alert_unexpected_message); goto end; @@ -1035,7 +1035,7 @@ int tlcp_do_accept(TLS_CONNECT *conn) goto end; } - conn->version = TLS_version_tlcp; + conn->protocol = TLS_protocol_tlcp; tls_trace("Connection Established!\n\n"); ret = 1; diff --git a/src/tls.c b/src/tls.c index a9df2a22..be7ab744 100644 --- a/src/tls.c +++ b/src/tls.c @@ -268,14 +268,14 @@ int tls_record_set_type(uint8_t *record, int type) return 1; } -int tls_record_set_version(uint8_t *record, int version) +int tls_record_set_protocol(uint8_t *record, int protocol) { - if (!tls_version_text(version)) { + if (!tls_protocol_name(protocol)) { error_print(); return -1; } - record[1] = version >> 8; - record[2] = version; + record[1] = protocol >> 8; + record[2] = protocol; return 1; } @@ -528,14 +528,14 @@ int tls_prf(const uint8_t *secret, size_t secretlen, const char *label, return 1; } -int tls_pre_master_secret_generate(uint8_t pre_master_secret[48], int version) +int tls_pre_master_secret_generate(uint8_t pre_master_secret[48], int protocol) { - if (!tls_version_text(version)) { + if (!tls_protocol_name(protocol)) { error_print(); return -1; } - pre_master_secret[0] = version >> 8; - pre_master_secret[1] = version; + pre_master_secret[0] = protocol >> 8; + pre_master_secret[1] = protocol; if (rand_bytes(pre_master_secret + 2, 46) != 1) { error_print(); return -1; @@ -566,6 +566,7 @@ int tls_cert_type_from_oid(int oid) return 0; } +// 这两个函数没有对应的TLCP版本 int tls_sign_server_ecdh_params(const SM2_KEY *server_sign_key, const uint8_t client_random[32], const uint8_t server_random[32], int curve, const SM2_POINT *point, uint8_t *sig, size_t *siglen) @@ -637,7 +638,7 @@ int tls_record_set_handshake(uint8_t *record, size_t *recordlen, return -1; } - if (!tls_version_text(tls_record_version(record))) { + if (!tls_protocol_name(tls_record_protocol(record))) { error_print(); return -1; } @@ -671,7 +672,7 @@ int tls_record_get_handshake(const uint8_t *record, error_print(); return -1; } - if (!tls_version_text(tls_record_version(record))) { + if (!tls_protocol_name(tls_record_protocol(record))) { error_print(); return -1; } @@ -718,7 +719,7 @@ int tls_record_get_handshake(const uint8_t *record, } int tls_record_set_handshake_client_hello(uint8_t *record, size_t *recordlen, - int version, const uint8_t random[32], + int protocol, const uint8_t random[32], const uint8_t *session_id, size_t session_id_len, const int *cipher_suites, size_t cipher_suites_count, const uint8_t *exts, size_t exts_len) @@ -752,11 +753,11 @@ int tls_record_set_handshake_client_hello(uint8_t *record, size_t *recordlen, p = tls_handshake_data(tls_record_data(record)); len = 0; - if (!tls_version_text(version)) { + if (!tls_protocol_name(protocol)) { error_print(); return -1; } - tls_uint16_to_bytes((uint16_t)version, &p, &len); + tls_uint16_to_bytes((uint16_t)protocol, &p, &len); tls_array_to_bytes(random, 32, &p, &len); tls_uint8array_to_bytes(session_id, session_id_len, &p, &len); tls_uint16_to_bytes(cipher_suites_count * 2, &p, &len); @@ -772,7 +773,7 @@ int tls_record_set_handshake_client_hello(uint8_t *record, size_t *recordlen, tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len); if (exts) { size_t tmp_len = len; - if (version < TLS_version_tls12) { + if (protocol < TLS_protocol_tls12) { error_print(); return -1; } @@ -791,7 +792,7 @@ int tls_record_set_handshake_client_hello(uint8_t *record, size_t *recordlen, } int tls_record_get_handshake_client_hello(const uint8_t *record, - int *version, const uint8_t **random, + int *protocol, const uint8_t **random, const uint8_t **session_id, size_t *session_id_len, const uint8_t **cipher_suites, size_t *cipher_suites_len, const uint8_t **exts, size_t *exts_len) @@ -803,7 +804,7 @@ int tls_record_get_handshake_client_hello(const uint8_t *record, const uint8_t *comp_meths; size_t comp_meths_len; - if (!record || !version || !random + if (!record || !protocol || !random || !session_id || !session_id_len || !cipher_suites || !cipher_suites_len || !exts || !exts_len) { @@ -827,11 +828,11 @@ int tls_record_get_handshake_client_hello(const uint8_t *record, return -1; } - if (!tls_version_text(ver)) { + if (!tls_protocol_name(ver)) { error_print(); return -1; } - *version = ver; + *protocol = ver; if (*session_id) { if (*session_id_len == 0 @@ -872,7 +873,7 @@ int tls_record_get_handshake_client_hello(const uint8_t *record, } int tls_record_set_handshake_server_hello(uint8_t *record, size_t *recordlen, - int version, const uint8_t random[32], + int protocol, const uint8_t random[32], const uint8_t *session_id, size_t session_id_len, int cipher_suite, const uint8_t *exts, size_t exts_len) { @@ -892,7 +893,7 @@ int tls_record_set_handshake_server_hello(uint8_t *record, size_t *recordlen, return -1; } } - if (!tls_version_text(version)) { + if (!tls_protocol_name(protocol)) { error_print(); return -1; } @@ -904,13 +905,13 @@ int tls_record_set_handshake_server_hello(uint8_t *record, size_t *recordlen, p = tls_handshake_data(tls_record_data(record)); len = 0; - tls_uint16_to_bytes((uint16_t)version, &p, &len); + tls_uint16_to_bytes((uint16_t)protocol, &p, &len); tls_array_to_bytes(random, 32, &p, &len); tls_uint8array_to_bytes(session_id, session_id_len, &p, &len); tls_uint16_to_bytes((uint16_t)cipher_suite, &p, &len); tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len); if (exts) { - if (version < TLS_version_tls12) { + if (protocol < TLS_protocol_tls12) { error_print(); return -1; } @@ -924,7 +925,7 @@ int tls_record_set_handshake_server_hello(uint8_t *record, size_t *recordlen, } int tls_record_get_handshake_server_hello(const uint8_t *record, - int *version, const uint8_t **random, const uint8_t **session_id, size_t *session_id_len, + int *protocol, const uint8_t **random, const uint8_t **session_id, size_t *session_id_len, int *cipher_suite, const uint8_t **exts, size_t *exts_len) { int type; @@ -934,7 +935,7 @@ int tls_record_get_handshake_server_hello(const uint8_t *record, uint16_t cipher; uint8_t comp_meth; - if (!record || !version || !random || !session_id || !session_id_len + if (!record || !protocol || !random || !session_id || !session_id_len || !cipher_suite || !exts || !exts_len) { error_print(); return -1; @@ -956,15 +957,15 @@ int tls_record_get_handshake_server_hello(const uint8_t *record, return -1; } - if (!tls_version_text(ver)) { + if (!tls_protocol_name(ver)) { error_print(); return -1; } - if (ver < tls_record_version(record)) { + if (ver < tls_record_protocol(record)) { error_print(); return -1; } - *version = ver; + *protocol = ver; if (*session_id) { if (*session_id == 0 @@ -1522,7 +1523,7 @@ int tls_record_do_recv(uint8_t *record, size_t *recordlen, int sock) error_print(); return -1; } - if (!tls_version_text(tls_record_version(record))) { + if (!tls_protocol_name(tls_record_protocol(record))) { error_print(); return -1; } @@ -1571,7 +1572,7 @@ retry: uint8_t alert_record[TLS_ALERT_RECORD_SIZE]; size_t alert_record_len; tls_record_set_type(alert_record, TLS_record_alert); - tls_record_set_version(alert_record, tls_record_version(record)); + tls_record_set_protocol(alert_record, tls_record_protocol(record)); tls_record_set_alert(alert_record, &alert_record_len, TLS_alert_level_fatal, TLS_alert_close_notify); tls_trace("send Alert close_notifiy\n"); @@ -1619,7 +1620,7 @@ int tls_send_alert(TLS_CONNECT *conn, int alert) error_print(); return -1; } - tls_record_set_version(record, conn->version); + tls_record_set_protocol(record, conn->protocol); tls_record_set_alert(record, &recordlen, TLS_alert_level_fatal, alert); if (tls_record_send(record, sizeof(record), conn->sock) != 1) { @@ -1661,7 +1662,7 @@ int tls_send_warning(TLS_CONNECT *conn, int alert) error_print(); return -1; } - tls_record_set_version(record, conn->version); + tls_record_set_protocol(record, conn->protocol); tls_record_set_alert(record, &recordlen, TLS_alert_level_warning, alert); if (tls_record_send(record, sizeof(record), conn->sock) != 1) { @@ -1709,7 +1710,7 @@ int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen tls_trace("send ApplicationData\n"); if (tls_record_set_type(record, TLS_record_application_data) != 1 - || tls_record_set_version(record, conn->version) != 1 + || tls_record_set_protocol(record, conn->protocol) != 1 || tls_record_set_length(record, inlen) != 1) { error_print(); return -1; @@ -2021,7 +2022,7 @@ void tls_ctx_cleanup(TLS_CTX *ctx) } } -int tls_ctx_init(TLS_CTX *ctx, int protocol_version, int is_client) +int tls_ctx_init(TLS_CTX *ctx, int protocol, int is_client) { if (!ctx) { error_print(); @@ -2029,11 +2030,11 @@ int tls_ctx_init(TLS_CTX *ctx, int protocol_version, int is_client) } memset(ctx, 0, sizeof(*ctx)); - switch (protocol_version) { - case TLS_version_tlcp: - case TLS_version_tls12: - case TLS_version_tls13: - ctx->protocol_version = protocol_version; + switch (protocol) { + case TLS_protocol_tlcp: + case TLS_protocol_tls12: + case TLS_protocol_tls13: + ctx->protocol = protocol; break; default: error_print(); @@ -2078,7 +2079,7 @@ int tls_ctx_set_ca_certificates(TLS_CTX *ctx, const char *cacertsfile, int depth error_print(); return -1; } - if (!tls_version_text(ctx->protocol_version)) { + if (!tls_protocol_name(ctx->protocol)) { error_print(); return -1; } @@ -2115,7 +2116,7 @@ int tls_ctx_set_certificate_and_key(TLS_CTX *ctx, const char *chainfile, error_print(); return -1; } - if (!tls_version_text(ctx->protocol_version)) { + if (!tls_protocol_name(ctx->protocol)) { error_print(); return -1; } @@ -2178,7 +2179,7 @@ int tls_ctx_set_tlcp_server_certificate_and_keys(TLS_CTX *ctx, const char *chain error_print(); return -1; } - if (!tls_version_text(ctx->protocol_version)) { + if (!tls_protocol_name(ctx->protocol)) { error_print(); return -1; } @@ -2243,7 +2244,7 @@ int tls_init(TLS_CONNECT *conn, const TLS_CTX *ctx) size_t i; memset(conn, 0, sizeof(*conn)); - conn->version = ctx->protocol_version; + conn->protocol = ctx->protocol; conn->is_client = ctx->is_client; for (i = 0; i < ctx->cipher_suites_cnt; i++) { conn->cipher_suites[i] = ctx->cipher_suites[i]; @@ -2302,15 +2303,15 @@ int tls_set_socket(TLS_CONNECT *conn, int sock) int tls_do_handshake(TLS_CONNECT *conn) { - switch (conn->version) { - case TLS_version_tlcp: + switch (conn->protocol) { + case TLS_protocol_tlcp: if (conn->is_client) return tlcp_do_connect(conn); else return tlcp_do_accept(conn); - /* - case TLS_version_tls12: + case TLS_protocol_tls12: if (conn->is_client) return tls12_do_connect(conn); else return tls12_do_accept(conn); - case TLS_version_tls13: + /* + case TLS_protocol_tls13: if (conn->is_client) return tls13_do_connect(conn); else return tls13_do_accept(conn); */ diff --git a/src/tls12.c b/src/tls12.c index 04b27f97..a598155b 100644 --- a/src/tls12.c +++ b/src/tls12.c @@ -64,12 +64,13 @@ #include #include #include +#include #include static const int tls12_ciphers[] = { - GMSSL_cipher_ecdhe_sm2_with_sm4_sm3, + TLS_cipher_ecdhe_sm4_cbc_sm3, }; static const size_t tls12_ciphers_count = sizeof(tls12_ciphers)/sizeof(tls12_ciphers[0]); @@ -80,28 +81,15 @@ static const uint8_t tls12_exts[] = { /* signature_algors */ 0x00,0x0D, 0x00,0x04, 0x00,0x02, 0x07,0x07,//0x08, // sm2sig_sm3 }; -int tls_server_extensions_check(const uint8_t *exts, size_t extslen) + +int tls12_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent) { - return 1; + // 目前只支持TLCP的ECC公钥加密套件,因此不论用哪个套件解析都是一样的 + // 如果未来支持ECDHE套件,可以将函数改为宏,直接传入 (conn->cipher_suite << 8) + format |= tls12_ciphers[0] << 8; + return tls_record_print(fp, record, recordlen, format, indent); } -/* -int tls_construct_server_extensions(const uint8_t *client_exts, size_t client_exts_len, - uint8_t *server_exts, size_t *server_exts_len) -{ - uint16_t type; - const uint8_t *p; - size_t len; - - while (client_exts_len) { - if (tls_uint16_from_bytes(&ext_type, &client_exts, &client_exts_len) != 1 - || tls_uint16array_from_bytes(&ext_data, &ext_datalen, &client_exts, &client_exts_len) != 1) { - error_print(); - return -1; - } - } -} -*/ int tls_record_set_handshake_server_key_exchange_ecdhe(uint8_t *record, size_t *recordlen, int curve, const SM2_POINT *point, const uint8_t *sig, size_t siglen) @@ -127,13 +115,15 @@ int tls_record_set_handshake_server_key_exchange_ecdhe(uint8_t *record, size_t * return 1; } +// 这里返回的应该是一个SM2_POINT吗? int tls_record_get_handshake_server_key_exchange_ecdhe(const uint8_t *record, - int *curve, SM2_POINT *point, uint8_t *sig, size_t *siglen) + int *curve, SM2_POINT *point, const uint8_t **sig, size_t *siglen) { int type; const uint8_t *p; size_t len; uint8_t curve_type; + uint16_t named_curve; const uint8_t *octets; size_t octetslen; uint16_t sig_alg; @@ -147,19 +137,25 @@ int tls_record_get_handshake_server_key_exchange_ecdhe(const uint8_t *record, error_print(); return -1; } - *curve = 0; if (tls_uint8_from_bytes(&curve_type, &p, &len) != 1 - || tls_uint16_from_bytes((uint16_t *)curve, &p, &len) != 1 + || tls_uint16_from_bytes(&named_curve, &p, &len) != 1 || tls_uint8array_from_bytes(&octets, &octetslen, &p, &len) != 1 || tls_uint16_from_bytes(&sig_alg, &p, &len) != 1 - || tls_uint16array_copy_from_bytes(sig, siglen, TLS_MAX_SIGNATURE_SIZE, &p, &len) != 1 - || len > 0) { + || tls_uint16array_from_bytes(sig, siglen, &p, &len) != 1 + || tls_length_is_zero(len) != 1) { error_print(); return -1; } - if (curve_type != TLS_curve_type_named_curve - || *curve != TLS_curve_sm2p256v1 - || octetslen != 65 + if (curve_type != TLS_curve_type_named_curve) { + error_print(); + return -1; + } + if (named_curve != TLS_curve_sm2p256v1) { + error_print(); + return -1; + } + *curve = named_curve; + if (octetslen != 65 || sm2_point_from_octets(point, octets, octetslen) != 1) { error_print(); return -1; @@ -207,662 +203,771 @@ int tls_record_get_handshake_client_key_exchange_ecdhe(const uint8_t *record, SM return 1; } -int tls12_record_recv(uint8_t *record, size_t *recordlen, int sock) -{ - if (tls_record_recv(record, recordlen, sock) != 1) { - error_print(); - return -1; - } - if (tls_record_version(record) != TLS_version_tls12) { - error_print(); - return -1; - } - return 1; -} - -int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, - FILE *ca_certs_fp, FILE *client_certs_fp, const SM2_KEY *client_sign_key) +int tls12_do_connect(TLS_CONNECT *conn) { + int ret = -1; uint8_t *record = conn->record; - size_t recordlen; - uint8_t finished[256]; - size_t finishedlen; + uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE]; + size_t recordlen, finished_record_len; - // handshake - int type; - const uint8_t *data; - size_t datalen; - - // client_hello, server_hello uint8_t client_random[32]; uint8_t server_random[32]; - uint8_t server_exts[TLS_MAX_EXTENSIONS_SIZE]; + int protocol; + int cipher_suite; + const uint8_t *random; + const uint8_t *session_id; + size_t session_id_len; + + uint8_t client_exts[TLS_MAX_EXTENSIONS_SIZE]; + size_t client_exts_len = 0; + const uint8_t *server_exts; size_t server_exts_len; - SM2_KEY server_pub_key; + // 扩展的协商结果,-1 表示服务器不支持该扩展(未给出响应) + int ec_point_format = -1; + int supported_group = -1; + int signature_algor = -1; + + + SM2_KEY server_sign_key; SM2_SIGN_CTX verify_ctx; SM2_SIGN_CTX sign_ctx; - uint8_t sig[TLS_MAX_SIGNATURE_SIZE]; - size_t siglen = sizeof(sig); - - // key_exchange - int curve; - SM2_KEY client_ecdh; - SM2_POINT server_ecdh_public; - uint8_t pre_master_secret[64]; - - // finished verify_data + const uint8_t *sig; + size_t siglen; + uint8_t pre_master_secret[48]; SM3_CTX sm3_ctx; SM3_CTX tmp_sm3_ctx; uint8_t sm3_hash[32]; - uint8_t verify_data[12]; - uint8_t remote_verify_data[12]; + const uint8_t *verify_data; + size_t verify_data_len; + uint8_t local_verify_data[12]; - if (conn->sock <= 0) { - int sock; - struct sockaddr_in server; + int handshake_type; + const uint8_t *server_enc_cert; // 这几个值也是不需要的 + size_t server_enc_cert_len; + uint8_t server_enc_cert_lenbuf[3]; + const uint8_t *cp; + uint8_t *p; + size_t len; - server.sin_addr.s_addr = inet_addr(hostname); - server.sin_family = AF_INET; - server.sin_port = htons(port); + int depth = 5; + int alert = 0; + int verify_result; - if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - error_print(); - return -1; - } - if (connect(sock, (struct sockaddr *)&server , sizeof(server)) < 0) { - error_print(); - return -1; - } - - conn->sock = sock; - conn->is_client = 1; - } + // 初始化记录缓冲 + tls_record_set_protocol(record, TLS_protocol_tls1); // ClientHello的记录层协议版本是TLSv1.0 + tls_record_set_protocol(finished_record, conn->protocol); + // 准备Finished Context(和ClientVerify) sm3_init(&sm3_ctx); - - if (client_certs_fp) { - if (!client_sign_key) { - error_print(); - return -1; - } - if (sm2_sign_init(&sign_ctx, client_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1) { - error_print(); - return -1; - } - } - - tls_record_set_version(record, TLS_version_tls1); - tls_record_set_version(finished, TLS_version_tls12); + if (conn->client_certs_len) + sm2_sign_init(&sign_ctx, &conn->sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH); - tls_trace("send ClientHello\n"); + // send ClientHello tls_random_generate(client_random); + int ec_point_formats[] = { TLS_point_uncompressed }; + size_t ec_point_formats_cnt = 1; + int supported_groups[] = { TLS_curve_sm2p256v1 }; + size_t supported_groups_cnt = 1; + int signature_algors[] = { TLS_sig_sm2sig_sm3 }; + size_t signature_algors_cnt = 1; + + client_exts_len = 0; + tls_exts_add_ec_point_formats(client_exts, &client_exts_len, sizeof(client_exts), ec_point_formats, ec_point_formats_cnt); + tls_exts_add_supported_groups(client_exts, &client_exts_len, sizeof(client_exts), supported_groups, supported_groups_cnt); + tls_exts_add_signature_algors(client_exts, &client_exts_len, sizeof(client_exts), signature_algors, signature_algors_cnt); + if (tls_record_set_handshake_client_hello(record, &recordlen, - TLS_version_tls12, client_random, NULL, 0, - tls12_ciphers, tls12_ciphers_count, tls12_exts, sizeof(tls12_exts)) != 1) { + conn->protocol, client_random, NULL, 0, + tls12_ciphers, tls12_ciphers_count, + client_exts, client_exts_len) != 1) { error_print(); - return -1; + goto end; } - tls_record_print(stderr, record, recordlen, 0, 0); + tls_trace("send ClientHello\n"); + tls12_record_trace(stderr, record, recordlen, 0, 0); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); - return -1; + goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_sign_key) + if (conn->client_certs_len) sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); + // recv ServerHello tls_trace("recv ServerHello\n"); - if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { + if (tls_record_recv(record, &recordlen, conn->sock) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; + } + tls12_record_trace(stderr, record, recordlen, 0, 0); + if (tls_record_protocol(record) != conn->protocol) { + error_print(); + tls_send_alert(conn, TLS_alert_protocol_version); + goto end; } - tls_record_print(stderr, record, recordlen, 0, 0); if (tls_record_get_handshake_server_hello(record, - &conn->version, server_random, conn->session_id, &conn->session_id_len, - &conn->cipher_suite, server_exts, &server_exts_len) != 1) { + &protocol, &random, &session_id, &session_id_len, &cipher_suite, + &server_exts, &server_exts_len) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } - if (conn->version != TLS_version_tls12) { + if (protocol != conn->protocol) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_protocol_version); + goto end; } - if (tls_cipher_suite_in_list(conn->cipher_suite, tls12_ciphers, tls12_ciphers_count) != 1) { + // tls12_ciphers 应该改为conn的内部变量 + if (tls_cipher_suite_in_list(cipher_suite, tls12_ciphers, tls12_ciphers_count) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_handshake_failure); + goto end; } - if (tls_server_extensions_check(server_exts, server_exts_len) != 1) { + if (!server_exts) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } + if (tls_process_server_exts(server_exts, server_exts_len, &ec_point_format, &supported_group, &signature_algor) != 1 + || ec_point_format < 0 + || supported_group < 0 + || signature_algor < 0) { + error_print(); + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; + } + memcpy(server_random, random, 32); + memcpy(conn->session_id, session_id, session_id_len); + conn->cipher_suite = cipher_suite; sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_certs_fp) { + if (conn->client_certs_len) sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); - } - + // recv ServerCertificate tls_trace("recv ServerCertificate\n"); - if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { + if (tls_record_recv(record, &recordlen, conn->sock) != 1 + || tls_record_protocol(record) != conn->protocol) { error_print(); - return -1; - } - tls_record_print(stderr, record, recordlen, 0, 0); - if (tls_record_get_handshake_certificate(record, conn->server_certs, &conn->server_certs_len) != 1) { - error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } + tls12_record_trace(stderr, record, recordlen, 0, 0); - /* - // FIXME: Segmentation fault! - if (tls_certificate_chain_verify(conn->server_certs, conn->server_certs_len, ca_certs_fp, 5) != 1) { + if (tls_record_get_handshake_certificate(record, + conn->server_certs, &conn->server_certs_len) != 1) { error_print(); - return -1; - } - */ - if (tls_certificate_get_public_keys(conn->server_certs, conn->server_certs_len, - &server_pub_key, NULL) != 1) { - error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_certs_fp) { + if (conn->client_certs_len) sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); + + // verify ServerCertificate + if (x509_certs_verify(conn->server_certs, conn->server_certs_len, + conn->ca_certs, conn->ca_certs_len, depth, &verify_result) != 1) { + error_print(); + tls_send_alert(conn, alert); + goto end; } + // recv ServerKeyExchange tls_trace("recv ServerKeyExchange\n"); - if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { + if (tls_record_recv(record, &recordlen, conn->sock) != 1 + || tls_record_protocol(record) != conn->protocol) { error_print(); - return -1; - } - tls_record_print(stderr, record, recordlen, conn->cipher_suite << 8, 0); - sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_certs_fp) { - sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } + tls12_record_trace(stderr, record, recordlen, 0, 0); - if (tls_record_get_handshake_server_key_exchange_ecdhe(record, &curve, &server_ecdh_public, sig, &siglen) != 1) { + int curve; + SM2_POINT server_ecdhe_public; + if (tls_record_get_handshake_server_key_exchange_ecdhe(record, &curve, &server_ecdhe_public, &sig, &siglen) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } if (curve != TLS_curve_sm2p256v1) { error_print(); - return -1; - } - if (tls_verify_server_ecdh_params(&server_pub_key, - client_random, server_random, curve, &server_ecdh_public, sig, siglen) != 1) { - error_print(); - return -1; - } - - - - if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { - error_print(); - return -1; - } - if (tls_record_get_handshake(record, &type, &data, &datalen) != 1) { - error_print(); - return -1; - } - if (type == TLS_handshake_certificate_request) { - int cert_types[TLS_MAX_CERTIFICATE_TYPES]; - size_t cert_types_count;; - uint8_t ca_names[TLS_MAX_CA_NAMES_SIZE]; - size_t ca_names_len; - - tls_trace("recv CertificateRequest\n"); - tls_record_print(stderr, record, recordlen, 0, 0); - if (!client_certs_fp) { - // 这里应该响应一个Alert吧? - error_print(); - return -1; - } - - if (tls_record_get_handshake_certificate_request(record, - cert_types, &cert_types_count, - ca_names, &ca_names_len) != 1) { - error_print(); - return -1; - } - sm3_update(&sm3_ctx, record + 5, recordlen - 5); - sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); - - if (tls_record_recv(record, &recordlen, conn->sock) != 1) { - error_print(); - return -1; - } - } else { - memset(&sign_ctx, 0, sizeof(SM2_SIGN_CTX)); - client_sign_key = NULL; - } - tls_trace("recv ServerHelloDone\n"); - tls_record_print(stderr, record, recordlen, 0, 0); - if (tls_record_get_handshake_server_hello_done(record) != 1) { - error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_certs_fp) { + if (conn->client_certs_len) sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); + + // verify ServerKeyExchange + if (x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, 0, &cp, &len) != 1 + || x509_cert_get_subject_public_key(cp, len, &server_sign_key) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_bad_certificate); + goto end; + } + if (tls_verify_server_ecdh_params(&server_sign_key, // 这应该是签名公钥 + client_random, server_random, curve, &server_ecdhe_public, sig, siglen) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + goto end; } + // recv CertificateRequest or ServerHelloDone + if (tls_record_recv(record, &recordlen, conn->sock) != 1 + || tls_record_protocol(record) != conn->protocol + || tls_record_get_handshake(record, &handshake_type, &cp, &len) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; + } + if (handshake_type == TLS_handshake_certificate_request) { + const uint8_t *cert_types; + size_t cert_types_len; + const uint8_t *ca_names; + size_t ca_names_len; - if (client_certs_fp) { - tls_trace("send ClientCertificate\n"); - if (tls_record_set_handshake_certificate_from_pem(record, &recordlen, client_certs_fp) != 1) { + // recv CertificateRequest + tls_trace("recv CertificateRequest\n"); + tls12_record_trace(stderr, record, recordlen, 0, 0); + if (tls_record_get_handshake_certificate_request(record, + &cert_types, &cert_types_len, &ca_names, &ca_names_len) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } - tls_record_print(stderr, record, recordlen, 0, 0); + if(!conn->client_certs_len) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + goto end; + } + if (tls_cert_types_accepted(cert_types, cert_types_len, conn->client_certs, conn->client_certs_len) != 1 + || tls_authorities_issued_certificate(ca_names, ca_names_len, conn->client_certs, conn->client_certs_len) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_unsupported_certificate); + goto end; + } + sm3_update(&sm3_ctx, record + 5, recordlen - 5); + sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); + + // recv ServerHelloDone + if (tls_record_recv(record, &recordlen, conn->sock) != 1 + || tls_record_protocol(record) != conn->protocol) { + error_print(); + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; + } + } else { + // 这个得处理一下 + conn->client_certs_len = 0; + gmssl_secure_clear(&conn->sign_key, sizeof(SM2_KEY)); + } + tls_trace("recv ServerHelloDone\n"); + tls12_record_trace(stderr, record, recordlen, 0, 0); + if (tls_record_get_handshake_server_hello_done(record) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; + } + sm3_update(&sm3_ctx, record + 5, recordlen - 5); + if (conn->client_certs_len) + sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); + + // send ClientCertificate + if (conn->client_certs_len) { + tls_trace("send ClientCertificate\n"); + if (tls_record_set_handshake_certificate(record, &recordlen, conn->client_certs, conn->client_certs_len) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + goto end; + } + tls12_record_trace(stderr, record, recordlen, 0, 0); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); - return -1; + goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); } - + // generate MASTER_SECRET tls_trace("generate secrets\n"); + SM2_KEY client_ecdh; sm2_key_generate(&client_ecdh); - sm2_ecdh(&client_ecdh, &server_ecdh_public, &server_ecdh_public); - memcpy(pre_master_secret, &server_ecdh_public, 32); + sm2_ecdh(&client_ecdh, &server_ecdhe_public, &server_ecdhe_public); + memcpy(pre_master_secret, &server_ecdhe_public, 32); // 这个做法很不优雅 + // ECDHE和ECC的PMS结构是不一样的吗? - tls_prf(pre_master_secret, 32, "master secret", - client_random, 32, - server_random, 32, - 48, conn->master_secret); - tls_prf(conn->master_secret, 48, "key expansion", - server_random, 32, - client_random, 32, - 96, conn->key_block); + if (tls_prf(pre_master_secret, 32, "master secret", + client_random, 32, server_random, 32, + 48, conn->master_secret) != 1 + || tls_prf(conn->master_secret, 48, "key expansion", + server_random, 32, client_random, 32, + 96, conn->key_block) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + goto end; + } sm3_hmac_init(&conn->client_write_mac_ctx, conn->key_block, 32); sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32); sm4_set_encrypt_key(&conn->client_write_enc_key, conn->key_block + 64); sm4_set_decrypt_key(&conn->server_write_enc_key, conn->key_block + 80); - tls_secrets_print(stderr, pre_master_secret, 32, client_random, server_random, - conn->master_secret, conn->key_block, 96, 0, 4); - + tls_secrets_print(stderr, + pre_master_secret, 48, + client_random, server_random, + conn->master_secret, + conn->key_block, 96, + 0, 4); + // send ClientKeyExchange tls_trace("send ClientKeyExchange\n"); - if (tls_record_set_handshake_client_key_exchange_ecdhe(record, &recordlen, - &client_ecdh.public_key) != 1) { + if (tls_record_set_handshake_client_key_exchange_ecdhe(record, &recordlen, &client_ecdh.public_key) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_internal_error); + goto end; } - tls_record_print(stderr, record, recordlen, conn->cipher_suite << 8, 0); + tls12_record_trace(stderr, record, recordlen, 0, 0); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); - return -1; + goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - - - if (client_certs_fp) { - tls_trace("send CertificateVerify\n"); + if (conn->client_certs_len) sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); - sm2_sign_finish(&sign_ctx, sig, &siglen); - memset(&sign_ctx, 0, sizeof(sign_ctx)); - if (tls_record_set_handshake_certificate_verify(record, &recordlen, sig, siglen) != 1) { + + // send CertificateVerify + if (conn->client_certs_len) { + tls_trace("send CertificateVerify\n"); + uint8_t sigbuf[SM2_MAX_SIGNATURE_SIZE]; + if (sm2_sign_finish(&sign_ctx, sigbuf, &siglen) != 1 + || tls_record_set_handshake_certificate_verify(record, &recordlen, sigbuf, siglen) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_internal_error); + goto end; } - tls_record_print(stderr, record, recordlen, 0, 0); + tls12_record_trace(stderr, record, recordlen, 0, 0); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); - return -1; + goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); } - + // send [ChangeCipherSpec] tls_trace("send [ChangeCipherSpec]\n"); if (tls_record_set_change_cipher_spec(record, &recordlen) !=1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_internal_error); + goto end; } + tls12_record_trace(stderr, record, recordlen, 0, 0); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); - return -1; + goto end; } - tls_record_print(stderr, record, recordlen, 0, 0); - + // send Client Finished tls_trace("send Finished\n"); memcpy(&tmp_sm3_ctx, &sm3_ctx, sizeof(sm3_ctx)); sm3_finish(&tmp_sm3_ctx, sm3_hash); - - tls_prf(conn->master_secret, 48, "client finished", - sm3_hash, 32, NULL, 0, - sizeof(verify_data), verify_data); - if (tls_record_set_handshake_finished(finished, &finishedlen, verify_data) != 1) { + if (tls_prf(conn->master_secret, 48, "client finished", + sm3_hash, 32, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1 + || tls_record_set_handshake_finished(finished_record, &finished_record_len, + local_verify_data, sizeof(local_verify_data)) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_internal_error); + goto end; } - tls_record_print(stderr, finished, finishedlen, 0, 0); - sm3_update(&sm3_ctx, finished + 5, finishedlen - 5); + tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0); + sm3_update(&sm3_ctx, finished_record + 5, finished_record_len - 5); + // encrypt Client Finished + tls_trace("encrypt Finished\n"); if (tls_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key, - conn->client_seq_num, finished, finishedlen, record, &recordlen) != 1) { + conn->client_seq_num, finished_record, finished_record_len, record, &recordlen) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_internal_error); + goto end; } + tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据 tls_seq_num_incr(conn->client_seq_num); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); - return -1; + goto end; } + // [ChangeCipherSpec] tls_trace("recv [ChangeCipherSpec]\n"); - if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { + if (tls_record_recv(record, &recordlen, conn->sock) != 1 + || tls_record_protocol(record) != conn->protocol) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } - tls_record_print(stderr, record, recordlen, 0, 0); + tls12_record_trace(stderr, record, recordlen, 0, 0); if (tls_record_get_change_cipher_spec(record) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } + // Finished tls_trace("recv Finished\n"); - if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { + if (tls_record_recv(record, &recordlen, conn->sock) != 1 + || tls_record_protocol(record) != conn->protocol) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } + if (recordlen > sizeof(finished_record)) { + error_print(); // 解密可能导致 finished_record 溢出 + tls_send_alert(conn, TLS_alert_bad_record_mac); + goto end; + } + tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据 + tls_trace("decrypt Finished\n"); if (tls_record_decrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key, - conn->server_seq_num, record, recordlen, finished, &finishedlen) != 1) { + conn->server_seq_num, record, recordlen, finished_record, &finished_record_len) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_bad_record_mac); + goto end; } - tls_record_print(stderr, finished, finishedlen, 0, 0); + tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0); tls_seq_num_incr(conn->server_seq_num); - if (tls_record_get_handshake_finished(finished, remote_verify_data) != 1) { + if (tls_record_get_handshake_finished(finished_record, &verify_data, &verify_data_len) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; + } + if (verify_data_len != sizeof(local_verify_data)) { + error_print(); + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } sm3_finish(&sm3_ctx, sm3_hash); - tls_prf(conn->master_secret, 48, "server finished", - sm3_hash, 32, NULL, 0, - 12, verify_data); - if (memcmp(verify_data, remote_verify_data, 12) != 0) { - error_puts("server_finished.verify_data verification failure"); - return -1; + if (tls_prf(conn->master_secret, 48, "server finished", + sm3_hash, 32, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + goto end; } + if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) { + error_print(); + tls_send_alert(conn, TLS_alert_decrypt_error); + goto end; + } + tls_trace("Connection established!\n"); - tls_trace("SSL Connection Established\n\n"); // 这里应该把协商的参数打印出来 + + conn->protocol = conn->protocol; + conn->cipher_suite = cipher_suite; + + ret = 1; + +end: + gmssl_secure_clear(&sign_ctx, sizeof(sign_ctx)); + gmssl_secure_clear(pre_master_secret, sizeof(pre_master_secret)); return 1; } - -int tls_set_fd(TLS_CONNECT *conn, int sock) +int tls12_do_accept(TLS_CONNECT *conn) { - int opts; + int ret = -1; - if ((opts = fcntl(sock, F_GETFL)) < 0) { - error_print(); - return -1; - } - opts &= ~O_NONBLOCK; - if (fcntl(sock, F_SETFL, opts) < 0) { - error_print(); - return -1; - } + int client_verify = 0; - conn->sock = sock; - return 1; -} - -int tls12_accept(TLS_CONNECT *conn, int port, - FILE *server_certs_fp, const SM2_KEY *server_sign_key, - FILE *client_cacerts_fp, - uint8_t *handshakes_buf, - size_t handshakes_buflen) -{ - uint8_t *handshakes = handshakes_buf; - size_t handshakeslen = 0; uint8_t *record = conn->record; - size_t recordlen; - uint8_t finished[256]; - size_t finishedlen = sizeof(finished); + uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE]; // 解密可能导致前面的record被覆盖 + size_t recordlen, finished_record_len; + // 这个ciphers不是应该在CTX中设置的吗 + const int server_ciphers[] = { TLS_cipher_ecdhe_sm4_cbc_sm3 }; // 未来应该支持GCM/CBC两个套件 + + // ClientHello, ServerHello uint8_t client_random[32]; uint8_t server_random[32]; - uint8_t session_id[32]; + int protocol; + const uint8_t *random; + const uint8_t *session_id; // TLCP服务器忽略客户端SessionID,也不主动设置SessionID size_t session_id_len; - int client_ciphers[12] = {0}; - size_t client_ciphers_count = sizeof(client_ciphers)/sizeof(client_ciphers[0]); - uint8_t exts[TLS_MAX_EXTENSIONS_SIZE]; - size_t exts_len; + const uint8_t *client_ciphers; + size_t client_ciphers_len; + const uint8_t *client_exts; + size_t client_exts_len; + uint8_t server_exts[TLS_MAX_EXTENSIONS_SIZE]; + size_t server_exts_len; + int curve = TLS_curve_sm2p256v1; // 这个是否应该在conn中设置? - SM2_KEY client_sign_key; - SM2_KEY server_ecdh; - SM2_POINT client_ecdh_public; + // ServerKeyExchange + SM2_KEY server_ecdhe_key; SM2_SIGN_CTX sign_ctx; - uint8_t sig[TLS_MAX_SIGNATURE_SIZE]; - size_t siglen = sizeof(sig); - uint8_t pre_master_secret[64]; + uint8_t sigbuf[SM2_MAX_SIGNATURE_SIZE]; + size_t siglen; + + // ClientCertificate, CertificateVerify + TLS_CLIENT_VERIFY_CTX client_verify_ctx; + SM2_KEY client_sign_key; + const uint8_t *sig; + const int verify_depth = 5; + int verify_result; + + // ClientKeyExchange + SM2_POINT client_ecdhe_point; + uint8_t pre_master_secret[SM2_MAX_PLAINTEXT_SIZE]; // sm2_decrypt 保证输出不会溢出 + size_t pre_master_secret_len; + + // Finished SM3_CTX sm3_ctx; SM3_CTX tmp_sm3_ctx; uint8_t sm3_hash[32]; - uint8_t verify_data[12]; uint8_t local_verify_data[12]; - size_t i; + const uint8_t *verify_data; + size_t verify_data_len; - if (conn->sock <= 0) { - int sock; - struct sockaddr_in server_addr; - struct sockaddr_in client_addr; - socklen_t client_addrlen; - - if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - error_print(); - return -1; - } - server_addr.sin_family = AF_INET; - server_addr.sin_addr.s_addr = INADDR_ANY; - server_addr.sin_port = htons(port); - - if (bind(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) { - error_print(); - return -1; - } - - error_puts("start listen ..."); - listen(sock, 5); - - client_addrlen = sizeof(client_addr); - if ((conn->sock = accept(sock, (struct sockaddr *)&client_addr, &client_addrlen)) < 0) { - error_print(); - return -1; - } - - error_puts("connected\n"); - conn->sock = sock; - } + uint8_t *p; + const uint8_t *cp; + size_t len; + // 服务器端如果设置了CA + if (conn->ca_certs_len) + client_verify = 1; + + // 初始化Finished和客户端验证环境 sm3_init(&sm3_ctx); + if (client_verify) + tls_client_verify_init(&client_verify_ctx); + + // recv ClientHello tls_trace("recv ClientHello\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } - tls_record_print(stderr, record, recordlen, 0, 0); - if (tls_record_version(record) != TLS_version_tls1 - && tls_record_version(record) != TLS_version_tls12) { + tls12_record_trace(stderr, record, recordlen, 0, 0); + if (tls_record_protocol(record) != conn->protocol + && tls_record_protocol(record) != TLS_protocol_tls1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_protocol_version); + goto end; } if (tls_record_get_handshake_client_hello(record, - &conn->version, client_random, session_id, &session_id_len, - client_ciphers, &client_ciphers_count, exts, &exts_len) != 1) { + &protocol, &random, &session_id, &session_id_len, + &client_ciphers, &client_ciphers_len, + &client_exts, &client_exts_len) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } - if (conn->version != TLS_version_tls12) { + if (protocol != conn->protocol) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_protocol_version); + goto end; } - for (i = 0; i < tls12_ciphers_count; i++) { - if (tls_cipher_suite_in_list(tls12_ciphers[i], client_ciphers, client_ciphers_count) == 1) { - conn->cipher_suite = tls12_ciphers[i]; - break; - } + memcpy(client_random, random, 32); + if (tls_cipher_suites_select(client_ciphers, client_ciphers_len, + server_ciphers, sizeof(server_ciphers)/sizeof(server_ciphers[0]), + &conn->cipher_suite) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_insufficient_security); + goto end; } - if (conn->cipher_suite == 0) { - error_puts("no common cipher_suite"); - return -1; + if (client_exts) { + server_exts_len = 0; + curve = TLS_curve_sm2p256v1; + + tls_process_client_exts(client_exts, client_exts_len, server_exts, &server_exts_len, sizeof(server_exts)); + + + } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); + if (client_verify) + tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5); + // send ServerHello tls_trace("send ServerHello\n"); tls_random_generate(server_random); - tls_record_set_version(record, conn->version); + tls_record_set_protocol(record, conn->protocol); if (tls_record_set_handshake_server_hello(record, &recordlen, - conn->version, server_random, NULL, 0, - conn->cipher_suite, exts, exts_len) != 1) { + conn->protocol, server_random, NULL, 0, + conn->cipher_suite, server_exts, server_exts_len) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_internal_error); + goto end; } - tls_record_print(stderr, record, recordlen, 0, 0); + tls12_record_trace(stderr, record, recordlen, 0, 0); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); - return -1; + goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); - + if (client_verify) + tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5); + // send ServerCertificate tls_trace("send ServerCertificate\n"); - if (tls_record_set_handshake_certificate_from_pem(record, &recordlen, server_certs_fp) != 1) { + if (tls_record_set_handshake_certificate(record, &recordlen, + conn->server_certs, conn->server_certs_len) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_internal_error); + goto end; } - tls_record_print(stderr, record, recordlen, 0, 0); + tls12_record_trace(stderr, record, recordlen, 0, 0); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); - return -1; - } - if (tls_record_get_handshake_certificate(record, conn->server_certs, &conn->server_certs_len) != 1) { - error_print(); - return -1; + goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); - + if (client_verify) + tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5); + // send ServerKeyExchange tls_trace("send ServerKeyExchange\n"); - sm2_key_generate(&server_ecdh); - if (tls_sign_server_ecdh_params(server_sign_key, - client_random, server_random, - TLS_curve_sm2p256v1, &server_ecdh.public_key, sig, &siglen) != 1) { + sm2_key_generate(&server_ecdhe_key); + if (tls_sign_server_ecdh_params(&conn->sign_key, + client_random, server_random, TLS_curve_sm2p256v1, &server_ecdhe_key.public_key, + sigbuf, &siglen) != 1) { error_print(); + tls_send_alert(conn, TLS_alert_internal_error); return -1; } if (tls_record_set_handshake_server_key_exchange_ecdhe(record, &recordlen, - TLS_curve_sm2p256v1, &server_ecdh.public_key, sig, siglen) != 1) { + curve, &server_ecdhe_key.public_key, sigbuf, siglen) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_internal_error); + goto end; } - tls_record_print(stderr, record, recordlen, conn->cipher_suite << 8, 0); + tls12_record_trace(stderr, record, recordlen, 0, 0); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); - return -1; + goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); + if (client_verify) + tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5); - - - if (client_cacerts_fp) { - tls_trace("send CertificateRequest\n"); - const int cert_types[] = { TLS_cert_type_ecdsa_sign, }; - size_t cert_types_count = sizeof(cert_types)/sizeof(cert_types[0]); - uint8_t ca_names[TLS_MAX_CA_NAMES_SIZE] = {0}; + // send CertificateRequest + if (client_verify) { + const uint8_t cert_types[] = { TLS_cert_type_ecdsa_sign }; + uint8_t ca_names[TLS_MAX_CA_NAMES_SIZE] = {0}; // TODO: 根据客户端验证CA证书列计算缓冲大小,或直接输出到record缓冲 size_t ca_names_len = 0; - // FIXME: 没有设置ca_names + tls_trace("send CertificateRequest\n"); + if (tls_authorities_from_certs(ca_names, &ca_names_len, sizeof(ca_names), + conn->ca_certs, conn->ca_certs_len) != 1) { + error_print(); + goto end; + } if (tls_record_set_handshake_certificate_request(record, &recordlen, - cert_types, cert_types_count, + cert_types, sizeof(cert_types), ca_names, ca_names_len) != 1) { error_print(); - return -1; + goto end; } - tls_record_print(stderr, record, recordlen, 0, 0); + tls12_record_trace(stderr, record, recordlen, 0, 0); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); - return -1; + goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); + tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5); } + // send ServerHelloDone tls_trace("send ServerHelloDone\n"); - if (tls_record_set_handshake_server_hello_done(record, &recordlen) != 1) { - error_print(); - return -1; - } - tls_record_print(stderr, record, recordlen, 0, 0); + tls_record_set_handshake_server_hello_done(record, &recordlen); + tls12_record_trace(stderr, record, recordlen, 0, 0); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); - return -1; + goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); + if (client_verify) + tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5); - - if (client_cacerts_fp) { + // recv ClientCertificate + if (conn->ca_certs_len) { tls_trace("recv ClientCertificate\n"); - if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { + if (tls_record_recv(record, &recordlen, conn->sock) != 1 + || tls_record_protocol(record) != conn->protocol) { // protocol检查应该在trace之后 error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } - tls_record_print(stderr, record, recordlen, 0, 0); - if (tls_record_version(record) != TLS_version_tls12) { + tls12_record_trace(stderr, record, recordlen, 0, 0); + if (tls_record_get_handshake_certificate(record, conn->client_certs, &conn->client_certs_len) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } - if (tls_record_get_handshake_certificate(record, - conn->client_certs, &conn->client_certs_len) != 1) { + if (x509_certs_verify(conn->client_certs, conn->client_certs_len, + conn->ca_certs, conn->ca_certs_len, verify_depth, &verify_result) != 1) { error_print(); - return -1; - } - // FIXME: verify client's certificate with ca certs - // 拿到客户端公钥之后就可以开始准备sm2_verify_init 了 - if (tls_certificate_get_public_keys(conn->client_certs, conn->client_certs_len, - &client_sign_key, NULL) != 1) { - error_print(); - return -1; + tls_send_alert(conn, TLS_alert_bad_certificate); + goto end; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); + tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5); } - //sleep(1); - + // recv ClientKeyExchange tls_trace("recv ClientKeyExchange\n"); - if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { + if (tls_record_recv(record, &recordlen, conn->sock) != 1 + || tls_record_protocol(record) != conn->protocol) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } - tls_record_print(stderr, record, recordlen, conn->cipher_suite << 8, 0); - if (tls_record_get_handshake_client_key_exchange_ecdhe(record, &client_ecdh_public) != 1) { + tls12_record_trace(stderr, record, recordlen, 0, 0); // 应该给tls12一个独立的trace + if (tls_record_get_handshake_client_key_exchange_ecdhe(record, &client_ecdhe_point) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } + sm3_update(&sm3_ctx, record + 5, recordlen - 5); - tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); - + if (client_verify) + tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5); + // recv CertificateVerify + if (client_verify) { + tls_trace("recv CertificateVerify\n"); + if (tls_record_recv(record, &recordlen, conn->sock) != 1 + || tls_record_protocol(record) != conn->protocol) { + tls_send_alert(conn, TLS_alert_unexpected_message); + error_print(); + goto end; + } + tls12_record_trace(stderr, record, recordlen, 0, 0); + if (tls_record_get_handshake_certificate_verify(record, &sig, &siglen) != 1) { + tls_send_alert(conn, TLS_alert_unexpected_message); + error_print(); + goto end; + } + if (x509_certs_get_cert_by_index(conn->client_certs, conn->client_certs_len, 0, &cp, &len) != 1 + || x509_cert_get_subject_public_key(cp, len, &client_sign_key) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_bad_certificate); + goto end; + } + if (tls_client_verify_finish(&client_verify_ctx, sig, siglen, &client_sign_key) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_decrypt_error); + goto end; + } + sm3_update(&sm3_ctx, record + 5, recordlen - 5); + } + // generate secrets tls_trace("generate secrets\n"); - sm2_ecdh(&server_ecdh, &client_ecdh_public, (SM2_POINT *)pre_master_secret); + sm2_ecdh(&server_ecdhe_key, &client_ecdhe_point, &client_ecdhe_point); + memcpy(pre_master_secret, (uint8_t *)&client_ecdhe_point, 32); // 这里应该修改一下表示方式,比如get_xy() tls_prf(pre_master_secret, 32, "master secret", client_random, 32, server_random, 32, 48, conn->master_secret); @@ -876,99 +981,120 @@ int tls12_accept(TLS_CONNECT *conn, int port, tls_secrets_print(stderr, pre_master_secret, 32, client_random, server_random, conn->master_secret, conn->key_block, 96, 0, 4); - - if (client_cacerts_fp) { - tls_trace("recv CertificateVerify\n"); - if (tls_record_recv(record, &recordlen, conn->sock) != 1 - || tls_record_version(record) != TLS_version_tls12) { - error_print(); - return -1; - } - tls_record_print(stderr, record, recordlen, 0, 0); - if (tls_record_get_handshake_certificate_verify(record, sig, &siglen) != 1) { - error_print(); - return -1; - } - sm3_update(&sm3_ctx, record + 5, recordlen - 5); - sm2_verify_init(&sign_ctx, &client_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH); - sm2_verify_update(&sign_ctx, handshakes_buf, handshakeslen); - if (sm2_verify_finish(&sign_ctx, sig, siglen) != 1) { - error_print(); - return -1; - } - } - + // recv [ChangeCipherSpec] tls_trace("recv [ChangeCipherSpec]\n"); - if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { + if (tls_record_recv(record, &recordlen, conn->sock) != 1 + || tls_record_protocol(record) != conn->protocol) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } - tls_record_print(stderr, record, recordlen, 0, 0); + tls12_record_trace(stderr, record, recordlen, 0, 0); if (tls_record_get_change_cipher_spec(record) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } - tls_trace("recv ClientFinished\n"); - if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { + // recv ClientFinished + tls_trace("recv Finished\n"); + if (tls_record_recv(record, &recordlen, conn->sock) != 1 + || tls_record_protocol(record) != conn->protocol) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; } + if (recordlen > sizeof(finished_record)) { + error_print(); + tls_send_alert(conn, TLS_alert_unexpected_message); + goto end; + } + tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据 + + // decrypt ClientFinished + tls_trace("decrypt Finished\n"); if (tls_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key, - conn->client_seq_num, record, recordlen, finished, &finishedlen) != 1) { + conn->client_seq_num, record, recordlen, finished_record, &finished_record_len) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_bad_record_mac); + goto end; } + tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0); tls_seq_num_incr(conn->client_seq_num); - if (tls_record_get_handshake_finished(finished, verify_data) != 1) { + if (tls_record_get_handshake_finished(finished_record, &verify_data, &verify_data_len) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_bad_record_mac); + goto end; } - tls_record_print(stderr, finished, finishedlen, 0, 0); + if (verify_data_len != sizeof(local_verify_data)) { + error_print(); + tls_send_alert(conn, TLS_alert_bad_record_mac); + goto end; + } + + // verify ClientFinished memcpy(&tmp_sm3_ctx, &sm3_ctx, sizeof(SM3_CTX)); - sm3_update(&sm3_ctx, finished + 5, finishedlen - 5); - + sm3_update(&sm3_ctx, finished_record + 5, finished_record_len - 5); sm3_finish(&tmp_sm3_ctx, sm3_hash); - tls_prf(conn->master_secret, 48, "client finished", - sm3_hash, 32, NULL, 0, - 12, local_verify_data); - if (memcmp(local_verify_data, verify_data, 12) != 0) { + if (tls_prf(conn->master_secret, 48, "client finished", sm3_hash, 32, NULL, 0, + sizeof(local_verify_data), local_verify_data) != 1) { + error_print(); + tls_send_alert(conn, TLS_alert_internal_error); + goto end; + } + if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) { error_puts("client_finished.verify_data verification failure"); - return -1; + tls_send_alert(conn, TLS_alert_decrypt_error); + goto end; } + // send [ChangeCipherSpec] tls_trace("send [ChangeCipherSpec]\n"); if (tls_record_set_change_cipher_spec(record, &recordlen) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_internal_error); + goto end; } - tls_record_print(stderr, record, recordlen, 0, 0); + tls12_record_trace(stderr, record, recordlen, 0, 0); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); - return -1; + goto end; } - tls_trace("send ServerFinished\n"); + // send ServerFinished + tls_trace("send Finished\n"); sm3_finish(&sm3_ctx, sm3_hash); - tls_prf(conn->master_secret, 48, "server finished", - sm3_hash, 32, NULL, 0, - 12, verify_data); - if (tls_record_set_handshake_finished(finished, &finishedlen, verify_data) != 1) { + if (tls_prf(conn->master_secret, 48, "server finished", sm3_hash, 32, NULL, 0, + sizeof(local_verify_data), local_verify_data) != 1 + || tls_record_set_handshake_finished(finished_record, &finished_record_len, + local_verify_data, sizeof(local_verify_data)) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_internal_error); + goto end; } - tls_record_print(stderr, finished, finishedlen, 0, 0); + tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0); if (tls_record_encrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key, - conn->server_seq_num, finished, finishedlen, record, &recordlen) != 1) { + conn->server_seq_num, finished_record, finished_record_len, record, &recordlen) != 1) { error_print(); - return -1; + tls_send_alert(conn, TLS_alert_internal_error); + goto end; } + tls_trace("encrypt Finished\n"); + tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据 tls_seq_num_incr(conn->server_seq_num); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); - return -1; + goto end; } - tls_trace("SSL Connection Established\n\n"); - return 1; + conn->protocol = conn->protocol; + + tls_trace("Connection Established!\n\n"); + ret = 1; + +end: + gmssl_secure_clear(&sign_ctx, sizeof(sign_ctx)); + gmssl_secure_clear(pre_master_secret, sizeof(pre_master_secret)); + if (client_verify) tls_client_verify_cleanup(&client_verify_ctx); + return ret; } diff --git a/src/tls_ext.c b/src/tls_ext.c new file mode 100644 index 00000000..45a1fc49 --- /dev/null +++ b/src/tls_ext.c @@ -0,0 +1,470 @@ +/* + * Copyright (c) 2021 - 2021 The GmSSL Project. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * + * 3. All advertising materials mentioning features or use of this + * software must display the following acknowledgment: + * "This product includes software developed by the GmSSL Project. + * (http://gmssl.org/)" + * + * 4. The name "GmSSL Project" must not be used to endorse or promote + * products derived from this software without prior written + * permission. For written permission, please contact + * guanzhi1980@gmail.com. + * + * 5. Products derived from this software may not be called "GmSSL" + * nor may "GmSSL" appear in their names without prior written + * permission of the GmSSL Project. + * + * 6. Redistributions of any form whatsoever must retain the following + * acknowledgment: + * "This product includes software developed by the GmSSL Project + * (http://gmssl.org/)" + * + * THIS SOFTWARE IS PROVIDED BY THE GmSSL PROJECT ``AS IS'' AND ANY + * EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE GmSSL PROJECT OR + * ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT + * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED + * OF THE POSSIBILITY OF SUCH DAMAGE. + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#define TLS_EXTENSION_HEADER_SIZE 4 + +#if 0 + +int tls_exts_add(uint8_t *exts, size_t *extslen, size_t maxlen, + int type, const uint8_t *data, size_t datalen) +{ + if (!exts || !extslen) { + error_print(); + return -1; + } + if (datalen > TLS_MAX_PLAINTEXT_SIZE + || *extslen + TLS_EXTENSION_HEADER_SIZE + datalen > maxlen) { + error_print(); + return -1; + } + exts += *extslen; + tls_uint16_to_bytes(type, &exts, extslen); + tls_uint16array_to_bytes(data, datalen, &exts, extslen); + return 1; +} +#endif + +int tls_exts_add_ec_point_formats(uint8_t *exts, size_t *extslen, size_t maxlen, + const int *formats, size_t formats_cnt) +{ + int type = TLS_extension_ec_point_formats; + size_t datalen = tls_uint8_size() + tls_uint8_size() * formats_cnt; + size_t i; + + if (!exts || !extslen || !formats || !formats_cnt) { + error_print(); + return -1; + } + if (formats_cnt > 256) { + error_print(); + return -1; + } + if (*extslen + TLS_EXTENSION_HEADER_SIZE + datalen > maxlen) { + error_print(); + return -1; + } + exts += *extslen; + tls_uint16_to_bytes(type, &exts, extslen); + tls_uint16_to_bytes(datalen, &exts, extslen); + tls_uint8_to_bytes(tls_uint8_size() * formats_cnt, &exts, extslen); + for (i = 0; i < formats_cnt; i++) { + if (!tls_ec_point_format_name(formats[i])) { + error_print(); + return -1; + } + tls_uint8_to_bytes(formats[i], &exts, extslen); + } + return 1; +} + +#define TLS_MAX_SUPPORTED_GROUPS_COUNT 64 + +int tls_exts_add_supported_groups(uint8_t *exts, size_t *extslen, size_t maxlen, + const int *curves, size_t curves_cnt) +{ + int type = TLS_extension_supported_groups; + size_t datalen = tls_uint16_size() + tls_uint16_size() * curves_cnt; + size_t i; + + if (!exts || !extslen || !curves || !curves_cnt) { + error_print(); + return -1; + } + if (curves_cnt > TLS_MAX_SUPPORTED_GROUPS_COUNT) { + error_print(); + return -1; + } + if (*extslen + TLS_EXTENSION_HEADER_SIZE + datalen > maxlen) { + error_print(); + return -1; + } + exts += *extslen; + tls_uint16_to_bytes(type, &exts, extslen); + tls_uint16_to_bytes(datalen, &exts, extslen); + tls_uint16_to_bytes(tls_uint16_size() * curves_cnt, &exts, extslen); + for (i = 0; i < curves_cnt; i++) { + tls_uint16_to_bytes(curves[i], &exts, extslen); + } + return 1; +} + +#define TLS_MAX_SIGNATURE_ALGORS_COUNT 64 + +int tls_exts_add_signature_algors(uint8_t *exts, size_t *extslen, size_t maxlen, + const int *algs, size_t algs_cnt) +{ + int type = TLS_extension_signature_algorithms; + size_t datalen = tls_uint16_size() + tls_uint16_size() * algs_cnt; + size_t i; + + if (!exts || !extslen || !algs || !algs_cnt) { + error_print(); + return -1; + } + if (algs_cnt > TLS_MAX_SIGNATURE_ALGORS_COUNT) { + error_print(); + return -1; + } + if (*extslen + TLS_EXTENSION_HEADER_SIZE + datalen > maxlen) { + error_print(); + return -1; + } + exts += *extslen; + tls_uint16_to_bytes(type, &exts, extslen); + tls_uint16_to_bytes(datalen, &exts, extslen); + tls_uint16_to_bytes(tls_uint16_size() * algs_cnt, &exts, extslen); + for (i = 0; i < algs_cnt; i++) { + tls_uint16_to_bytes(algs[i], &exts, extslen); + } + return 1; +} + +int tls_process_client_ec_point_formats(const uint8_t *data, size_t datalen, + uint8_t *exts, size_t *extslen, size_t maxlen) +{ + int shared_formats[] = { TLS_point_uncompressed }; + size_t shared_formats_cnt = 0; + const uint8_t *p; + size_t len; + + if (!data || !datalen || !exts || !extslen) { + error_print(); + return -1; + } + if (tls_uint8array_from_bytes(&p, &len, &data, &datalen) != 1 + || tls_length_is_zero(datalen) != 1) { + error_print(); + return -1; + } + while (len) { + uint8_t format; + if (tls_uint8_from_bytes(&format, &p, &len) != 1) { + error_print(); + return -1; + } + if (!tls_ec_point_format_name(format)) { + error_print(); + return -1; + } + if (format == shared_formats[0]) { + shared_formats_cnt = 1; + } + } + if (tls_exts_add_ec_point_formats(exts, extslen, maxlen, shared_formats, shared_formats_cnt) != 1) { + error_print(); + return -1; + } + return 1; +} + +int tls_process_server_ec_point_formats(const uint8_t *data, size_t datalen) +{ + const uint8_t *p; + size_t len; + uint8_t format; + + if (tls_uint8array_from_bytes(&p, &len, &data, &datalen) != 1 + || tls_length_is_zero(datalen) != 1) { + error_print(); + return -1; + } + if (tls_uint8_from_bytes(&format, &p, &len) != 1 + || tls_length_is_zero(len) != 1) { + error_print(); + return -1; + } + if (format != TLS_point_uncompressed) { + error_print(); + return -1; + } + return 1; +} + +int tls_process_client_signature_algors(const uint8_t *data, size_t datalen, + uint8_t *exts, size_t *extslen, size_t maxlen) +{ + int shared_algs[1] = { TLS_sig_sm2sig_sm3 }; + size_t shared_algs_cnt = 0; + const uint8_t *p; + size_t len; + + if (!data || !datalen || !exts || !extslen) { + error_print(); + return -1; + } + if (tls_uint16array_from_bytes(&p, &len, &data, &datalen) != 1 + || tls_length_is_zero(datalen) != 1) { + error_print(); + return -1; + } + while (len) { + uint16_t alg; + if (tls_uint16_from_bytes(&alg, &p, &len) != 1) { + error_print(); + return -1; + } + if (!tls_signature_scheme_name(alg)) { + error_print(); + return -1; + } + if (alg == shared_algs[0]) { + shared_algs_cnt = 1; + } + } + if (tls_exts_add_signature_algors(exts, extslen, maxlen, shared_algs, shared_algs_cnt) != 1) { + error_print(); + return -1; + } + return 1; +} + +int tls_process_server_signature_algors(const uint8_t *data, size_t datalen) +{ + const uint8_t *p; + size_t len; + uint16_t alg; + + if (tls_uint16array_from_bytes(&p, &len, &data, &datalen) != 1 + || tls_length_is_zero(datalen) != 1) { + error_print(); + return -1; + } + if (tls_uint16_from_bytes(&alg, &p, &len) != 1 + || tls_length_is_zero(len) != 1) { + error_print(); + return -1; + } + if (alg != TLS_sig_sm2sig_sm3) { + error_print(); + return -1; + } + return 1; +} + +int tls_process_client_supported_groups(const uint8_t *data, size_t datalen, uint8_t *exts, size_t *extslen, size_t maxlen) +{ + int shared_curves[1] = { TLS_curve_sm2p256v1 }; + size_t shared_curves_cnt = 0; + const uint8_t *p; + size_t len; + + if (!data || !datalen || !exts || !extslen) { + error_print(); + return -1; + } + if (tls_uint16array_from_bytes(&p, &len, &data, &datalen) != 1 + || tls_length_is_zero(datalen) != 1) { + error_print(); + return -1; + } + while (len) { + uint16_t curve; + if (tls_uint16_from_bytes(&curve, &p, &len) != 1) { + error_print(); + return -1; + } + if (!tls_named_curve_name(curve)) { + error_print(); + return -1; + } + if (curve == shared_curves[0]) { + shared_curves_cnt = 1; + } + } + if (tls_exts_add_supported_groups(exts, extslen, maxlen, shared_curves, shared_curves_cnt) != 1) { + error_print(); + return -1; + } + return 1; +} + +int tls_process_server_supported_groups(const uint8_t *data, size_t datalen) +{ + const uint8_t *p; + size_t len; + uint16_t curve; + + if (tls_uint16array_from_bytes(&p, &len, &data, &datalen) != 1 + || tls_length_is_zero(datalen) != 1) { + error_print(); + return -1; + } + if (tls_uint16_from_bytes(&curve, &p, &len) != 1 + || tls_length_is_zero(len) != 1) { + error_print(); + return -1; + } + if (curve != TLS_curve_sm2p256v1) { + error_print(); + return -1; + } + return 1; +} + +int tls_ext_from_bytes(int *type, const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen) +{ + uint16_t ext_type; + if (tls_uint16_from_bytes(&ext_type, in, inlen) != 1 + || tls_uint16array_from_bytes(data, datalen, in, inlen) != 1) { + error_print(); + return -1; + } + *type = ext_type; + if (!tls_extension_name(ext_type)) { + error_print(); + return -1; + } + return 1; +} + +int tls_process_client_exts(const uint8_t *exts, size_t extslen, uint8_t *out, size_t *outlen, size_t maxlen) +{ + int type; + const uint8_t *data; + size_t datalen; + + while (extslen) { + if (tls_ext_from_bytes(&type, &data, &datalen, &exts, &extslen) != 1) { + error_print(); + return -1; + } + + switch (type) { + case TLS_extension_ec_point_formats: + if (tls_process_client_ec_point_formats(data, datalen, out, outlen, maxlen) != 1) { + error_print(); + return -1; + } + break; + case TLS_extension_signature_algorithms: + if (tls_process_client_signature_algors(data, datalen, out, outlen, maxlen) != 1) { + error_print(); + return -1; + } + break; + case TLS_extension_supported_groups: + if (tls_process_client_supported_groups(data, datalen, out, outlen, maxlen) != 1) { + error_print(); + return -1; + } + break; + default: + error_print(); + return -1; + } + } + return 1; +} + +int tls_process_server_exts(const uint8_t *exts, size_t extslen, + int *ec_point_format, int *supported_group, int *signature_algor) +{ + int type; + const uint8_t *data; + size_t datalen; + + *ec_point_format = -1; + *supported_group = -1; + *signature_algor = -1; + + while (extslen) { + if (tls_ext_from_bytes(&type, &data, &datalen, &exts, &extslen) != 1) { + error_print(); + return -1; + } + + switch (type) { + case TLS_extension_ec_point_formats: + if (tls_process_server_ec_point_formats(data, datalen) != 1) { + error_print(); + return -1; + } + *ec_point_format = TLS_point_uncompressed; + break; + case TLS_extension_signature_algorithms: + if (tls_process_server_signature_algors(data, datalen) != 1) { + error_print(); + return -1; + } + *supported_group = TLS_curve_sm2p256v1; + break; + case TLS_extension_supported_groups: + if (tls_process_server_supported_groups(data, datalen) != 1) { + error_print(); + return -1; + } + *signature_algor = TLS_sig_sm2sig_sm3; + break; + default: + error_print(); + return -1; + } + } + return 1; +} diff --git a/src/tls_trace.c b/src/tls_trace.c index d600292d..07495122 100644 --- a/src/tls_trace.c +++ b/src/tls_trace.c @@ -67,18 +67,18 @@ const char *tls_record_type_name(int type) return NULL; } -const char *tls_version_text(int version) +const char *tls_protocol_name(int protocol) { - switch(version) { - case TLS_version_tlcp: return "TLCP"; - case TLS_version_ssl2: return "SSL 2.0"; - case TLS_version_ssl3: return "SSL 3.0"; - case TLS_version_tls1: return "TLS 1.0"; - case TLS_version_tls11: return "TLS 1.1"; - case TLS_version_tls12: return "TLS 1.2"; - case TLS_version_tls13: return "TLS 1.3"; - case TLS_version_dtls1: return "DTLS 1.0"; - case TLS_version_dtls12: return "DTLS 1.2"; + switch(protocol) { + case TLS_protocol_tlcp: return "TLCP"; + case TLS_protocol_ssl2: return "SSL2.0"; + case TLS_protocol_ssl3: return "SSL3.0"; + case TLS_protocol_tls1: return "TLS1.0"; + case TLS_protocol_tls11: return "TLS1.1"; + case TLS_protocol_tls12: return "TLS1.2"; + case TLS_protocol_tls13: return "TLS1.3"; + case TLS_protocol_dtls1: return "DTLS1.0"; + case TLS_protocol_dtls12: return "DTLS1.2"; } return NULL; } @@ -86,23 +86,19 @@ const char *tls_version_text(int version) const char *tls_cipher_suite_name(int cipher) { switch (cipher) { - case TLCP_cipher_ecdhe_sm4_cbc_sm3: return "TLCP_ECDHE_SM4_CBC_SM3"; - case TLCP_cipher_ecdhe_sm4_gcm_sm3: return "TLCP_ECDHE_SM4_GCM_SM3"; - case TLCP_cipher_ecc_sm4_cbc_sm3: return "TLCP_ECC_SM4_CBC_SM3"; - case TLCP_cipher_ecc_sm4_gcm_sm3: return "TLCP_ECC_SM4_GCM_SM3"; - case TLCP_cipher_ibsdh_sm4_cbc_sm3: return "TLCP_IBSDH_SM4_CBC_SM3"; - case TLCP_cipher_ibsdh_sm4_gcm_sm3: return "TLCP_IBSDH_SM4_GCM_SM3"; - case TLCP_cipher_ibc_sm4_cbc_sm3: return "TLCP_IBC_SM4_CBC_SM3"; - case TLCP_cipher_ibc_sm4_gcm_sm3: return "TLCP_IBC_SM4_GCM_SM3"; - case TLCP_cipher_rsa_sm4_cbc_sm3: return "TLCP_RSA_SM4_CBC_SM3"; - case TLCP_cipher_rsa_sm4_gcm_sm3: return "TLCP_RSA_SM4_GCM_SM3"; - case TLCP_cipher_rsa_sm4_cbc_sha256: return "TLCP_RSA_SM4_CBC_SHA256"; - case TLCP_cipher_rsa_sm4_gcm_sha256: return "TLCP_RSA_SM4_GCM_SHA256"; - case GMSSL_cipher_ecdhe_sm2_with_sm4_sm3: return "GMSSL_ECDHE_SM2_WITH_SM4_SM3"; - case GMSSL_cipher_ecdhe_sm2_with_sm4_gcm_sm3: return "GMSSL_ECDHE_SM2_WITH_SM4_GCM_SM3"; - case GMSSL_cipher_ecdhe_sm2_with_sm4_ccm_sm3: return "GMSSL_ECDHE_SM2_WITH_SM4_CCM_SM3"; - case GMSSL_cipher_ecdhe_sm2_with_zuc_sm3: return "GMSSL_ECDHE_SM2_WITH_ZUC_SM3"; - case TLS_cipher_empty_renegotiation_info_scsv: return "TLS_EMPTY_RENEGOTIATION_INFO_SCSV"; + case TLS_cipher_ecdhe_sm4_cbc_sm3: return "ECDHE_SM4_CBC_SM3"; + case TLS_cipher_ecdhe_sm4_gcm_sm3: return "ECDHE_SM4_GCM_SM3"; + case TLS_cipher_ecc_sm4_cbc_sm3: return "ECC_SM4_CBC_SM3"; + case TLS_cipher_ecc_sm4_gcm_sm3: return "ECC_SM4_GCM_SM3"; + case TLS_cipher_ibsdh_sm4_cbc_sm3: return "IBSDH_SM4_CBC_SM3"; + case TLS_cipher_ibsdh_sm4_gcm_sm3: return "IBSDH_SM4_GCM_SM3"; + case TLS_cipher_ibc_sm4_cbc_sm3: return "IBC_SM4_CBC_SM3"; + case TLS_cipher_ibc_sm4_gcm_sm3: return "IBC_SM4_GCM_SM3"; + case TLS_cipher_rsa_sm4_cbc_sm3: return "RSA_SM4_CBC_SM3"; + case TLS_cipher_rsa_sm4_gcm_sm3: return "RSA_SM4_GCM_SM3"; + case TLS_cipher_rsa_sm4_cbc_sha256: return "RSA_SM4_CBC_SHA256"; + case TLS_cipher_rsa_sm4_gcm_sha256: return "RSA_SM4_GCM_SHA256"; + case TLS_cipher_empty_renegotiation_info_scsv: return "EMPTY_RENEGOTIATION_INFO_SCSV"; } return NULL; } @@ -160,7 +156,7 @@ const char *tls_extension_name(int ext) case TLS_extension_supported_ekt_ciphers: return "supported_ekt_ciphers"; case TLS_extension_pre_shared_key: return "pre_shared_key"; case TLS_extension_early_data: return "early_data"; - case TLS_extension_supported_versions: return "supported_versions"; + case TLS_extension_supported_protocols: return "supported_protocols"; case TLS_extension_cookie: return "cookie"; case TLS_extension_psk_key_exchange_modes: return "psk_key_exchange_modes"; case TLS_extension_certificate_authorities: return "certificate_authorities"; @@ -362,10 +358,10 @@ int tls_random_print(FILE *fp, const uint8_t random[32], int format, int indent) int tls_pre_master_secret_print(FILE *fp, const uint8_t pre_master_secret[48], int format, int indent) { - int version = ((int)pre_master_secret[0] << 8) | pre_master_secret[1]; + int protocol = ((int)pre_master_secret[0] << 8) | pre_master_secret[1]; format_print(fp, format, indent, "PreMasterSecret\n"); indent += 4; - format_print(fp, format, indent, "version : %s\n", tls_version_text(version)); + format_print(fp, format, indent, "protocol : %s\n", tls_protocol_name(protocol)); format_bytes(fp, format, indent, "pre_master_secret", pre_master_secret, 48); return 1; } @@ -389,7 +385,7 @@ int tls_extension_print(FILE *fp, int type, const uint8_t *data, size_t datalen, while (len) { uint16_t curve; tls_uint16_from_bytes(&curve, &p, &len); - format_print(fp, format, indent, "%s (0x%04x)\n", + format_print(fp, format, indent, "%s (%d)\n", tls_named_curve_name(curve), curve); } break; @@ -436,7 +432,7 @@ int tls_extension_print(FILE *fp, int type, const uint8_t *data, size_t datalen, error_print(); return -1; } - format_print(fp, format, indent, "group: %s\n", tls_named_curve_name(group)); + format_print(fp, format, indent, "group: %s (%d)\n", tls_named_curve_name(group), group); format_bytes(fp, format, indent, "key_exchange", key_exch, key_exch_len); } break; @@ -483,7 +479,7 @@ int tls_hello_request_print(FILE *fp, const uint8_t *data, size_t datalen, int f int tls_client_hello_print(FILE *fp, const uint8_t *data, size_t datalen, int format, int indent) { int ret = -1; - uint16_t version; + uint16_t protocol; const uint8_t *random; const uint8_t *session_id; const uint8_t *cipher_suites; @@ -493,9 +489,9 @@ int tls_client_hello_print(FILE *fp, const uint8_t *data, size_t datalen, int fo size_t i; format_print(fp, format, indent, "ClientHello\n"); indent += 4; - if (tls_uint16_from_bytes((uint16_t *)&version, &data, &datalen) != 1) goto end; + if (tls_uint16_from_bytes((uint16_t *)&protocol, &data, &datalen) != 1) goto end; format_print(fp, format, indent, "Version: %s (%d.%d)\n", - tls_version_text(version), version >> 8, version & 0xff); + tls_protocol_name(protocol), protocol >> 8, protocol & 0xff); if (tls_array_from_bytes(&random, 32, &data, &datalen) != 1) goto end; tls_random_print(fp, random, format, indent); if (tls_uint8array_from_bytes(&session_id, &session_id_len, &data, &datalen) != 1) goto end; @@ -534,7 +530,7 @@ end: int tls_server_hello_print(FILE *fp, const uint8_t *data, size_t datalen, int format, int indent) { int ret = -1; - uint16_t version; + uint16_t protocol; const uint8_t *random; const uint8_t *session_id; uint16_t cipher_suite; @@ -544,9 +540,9 @@ int tls_server_hello_print(FILE *fp, const uint8_t *data, size_t datalen, int fo size_t i; format_print(fp, format, indent, "ServerHello\n"); indent += 4; - if (tls_uint16_from_bytes(&version, &data, &datalen) != 1) goto bad; + if (tls_uint16_from_bytes(&protocol, &data, &datalen) != 1) goto bad; format_print(fp, format, indent, "Version: %s (%d.%d)\n", - tls_version_text(version), version >> 8, version & 0xff); + tls_protocol_name(protocol), protocol >> 8, protocol & 0xff); if (tls_array_from_bytes(&random, 32, &data, &datalen) != 1) goto bad; tls_random_print(fp, random, format, indent); if (tls_uint8array_from_bytes(&session_id, &session_id_len, &data, &datalen) != 1) goto bad; @@ -617,7 +613,7 @@ int tls_server_key_exchange_ecdhe_print(FILE *fp, const uint8_t *data, size_t da error_print(); return -1; } - format_print(fp, format, indent + 8, "named_curve: %s (04%04x)\n", + format_print(fp, format, indent + 8, "named_curve: %s (%d)\n", tls_named_curve_name(curve), curve); if (tls_uint8array_from_bytes(&octets, &octetslen, &data, &datalen) != 1) { error_print(); @@ -628,7 +624,7 @@ int tls_server_key_exchange_ecdhe_print(FILE *fp, const uint8_t *data, size_t da error_print(); return -1; } - format_print(fp, format, indent, "SignatureScheme: %s (04%04x)\n", + format_print(fp, format, indent, "SignatureScheme: %s (0x%04x)\n", tls_signature_scheme_name(sig_alg), sig_alg); if (tls_uint16array_from_bytes(&sig, &siglen, &data, &datalen) != 1) { error_print(); @@ -647,18 +643,15 @@ int tls_server_key_exchange_print(FILE *fp, const uint8_t *data, size_t datalen, int cipher_suite = (format >> 8) & 0xffff; switch (cipher_suite) { - case TLCP_cipher_ecc_sm4_cbc_sm3: - case TLCP_cipher_ecc_sm4_gcm_sm3: + case TLS_cipher_ecc_sm4_cbc_sm3: + case TLS_cipher_ecc_sm4_gcm_sm3: if (tlcp_server_key_exchange_pke_print(fp, data, datalen, format, indent) != 1) { error_print(); return -1; } break; - case TLCP_cipher_ecdhe_sm4_cbc_sm3: - case TLCP_cipher_ecdhe_sm4_gcm_sm3: - case GMSSL_cipher_ecdhe_sm2_with_sm4_sm3: - case GMSSL_cipher_ecdhe_sm2_with_sm4_gcm_sm3: - case GMSSL_cipher_ecdhe_sm2_with_sm4_ccm_sm3: + case TLS_cipher_ecdhe_sm4_cbc_sm3: + case TLS_cipher_ecdhe_sm4_gcm_sm3: if (tls_server_key_exchange_ecdhe_print(fp, data, datalen, format, indent) != 1) { error_print(); return -1; @@ -707,7 +700,8 @@ int tls_certificate_request_print(FILE *fp, const uint8_t *data, size_t datalen, if (tls_uint8array_from_bytes(&cert_types, &cert_types_len, &data, &datalen) != 1) goto bad; format_print(fp, format, indent, "cert_types\n"); while (cert_types_len--) { - format_print(fp, format, indent + 4, "%s\n", tls_cert_type_name(*cert_types++)); + int cert_type = *cert_types++; + format_print(fp, format, indent + 4, "%s (%d)\n", tls_cert_type_name(cert_type), cert_type); } if (tls_uint16array_from_bytes(&ca_names, &ca_names_len, &data, &datalen) != 1) goto bad; tls_certificate_subjects_print(fp, format, indent, "CAnames", ca_names, ca_names_len); @@ -764,18 +758,15 @@ int tls_client_key_exchange_print(FILE *fp, const uint8_t *data, size_t datalen, { int cipher_suite = (format >> 8) & 0xffff; switch (cipher_suite) { - case TLCP_cipher_ecc_sm4_cbc_sm3: - case TLCP_cipher_ecc_sm4_gcm_sm3: + case TLS_cipher_ecc_sm4_cbc_sm3: + case TLS_cipher_ecc_sm4_gcm_sm3: if (tls_client_key_exchange_pke_print(fp, data, datalen, format, indent) != 1) { error_print(); return -1; } break; - case TLCP_cipher_ecdhe_sm4_cbc_sm3: - case TLCP_cipher_ecdhe_sm4_gcm_sm3: - case GMSSL_cipher_ecdhe_sm2_with_sm4_sm3: - case GMSSL_cipher_ecdhe_sm2_with_sm4_gcm_sm3: - case GMSSL_cipher_ecdhe_sm2_with_sm4_ccm_sm3: + case TLS_cipher_ecdhe_sm4_cbc_sm3: + case TLS_cipher_ecdhe_sm4_gcm_sm3: if (tls_client_key_exchange_ecdhe_print(fp, data, datalen, format, indent) != 1) { error_print(); return -1; @@ -904,16 +895,16 @@ int tls_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int for { const uint8_t *data; size_t datalen; - int version; + int protocol; if (!fp || !record || recordlen < 5) { error_print(); return -1; } - version = tls_record_version(record); + protocol = tls_record_protocol(record); format_print(fp, format, indent, "Record\n"); indent += 4; format_print(fp, format, indent, "ContentType: %s (%d)\n", tls_record_type_name(record[0]), record[0]); - format_print(fp, format, indent, "Version: %s (%d.%d)\n", tls_version_text(version), version >> 8, version & 0xff); + format_print(fp, format, indent, "Version: %s (%d.%d)\n", tls_protocol_name(protocol), protocol >> 8, protocol & 0xff); format_print(fp, format, indent, "Length: %d\n", tls_record_data_length(record)); data = tls_record_data(record); diff --git a/tests/sm2test.c b/tests/sm2test.c index edf71c1c..47f24c45 100644 --- a/tests/sm2test.c +++ b/tests/sm2test.c @@ -102,84 +102,84 @@ int test_sm2_bn(void) sm2_bn_from_hex(r, hex_v); ok = (sm2_bn_cmp(r, v) == 0); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; // fp tests sm2_fp_add(r, x, y); ok = sm2_bn_equ_hex(r, hex_fp_add_x_y); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; sm2_fp_sub(r, x, y); ok = sm2_bn_equ_hex(r, hex_fp_sub_x_y); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; sm2_fp_mul(r, x, y); ok = sm2_bn_equ_hex(r, hex_fp_mul_x_y); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; sm2_fp_exp(r, x, y); ok = sm2_bn_equ_hex(r, hex_fp_exp_x_y); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; sm2_fp_inv(r, x); ok = sm2_bn_equ_hex(r, hex_fp_inv_x); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; sm2_fp_neg(r, x); ok = sm2_bn_equ_hex(r, hex_fp_neg_x); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; // fn tests sm2_fn_add(r, x, y); ok = sm2_bn_equ_hex(r, hex_fn_add_x_y); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; sm2_fn_sub(r, x, y); ok = sm2_bn_equ_hex(r, hex_fn_sub_x_y); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; sm2_fn_sub(r, y, x); ok = sm2_bn_equ_hex(r, hex_fn_sub_y_x); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; sm2_fn_neg(r, x); ok = sm2_bn_equ_hex(r, hex_fn_neg_x); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; sm2_fn_mul(r, x, y); ok = sm2_bn_equ_hex(r, hex_fn_mul_x_y); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; sm2_fn_mul(r, x, v); ok = sm2_bn_equ_hex(r, hex_fn_mul_x_v); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; sm2_fn_sqr(r, x); ok = sm2_bn_equ_hex(r, hex_fn_sqr_x); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; sm2_fn_exp(r, x, y); ok = sm2_bn_equ_hex(r, hex_fn_exp_x_y); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; sm2_fn_inv(r, x); ok = sm2_bn_equ_hex(r, hex_fn_inv_x); printf("sm2 bn test %d %s\n", i++, ok ? "ok" : "failed"); - if (!ok) return 1; + if (!ok) return -1; SM2_BN tv = { 0x2b94b325, 0x5da17313, 0x28d356b1, 0xa4f7fa5e, @@ -187,11 +187,11 @@ int test_sm2_bn(void) }; sm2_bn_from_hex(t, hex_t); ok = (sm2_bn_cmp(t, tv) == 0); - if (!ok) return 1; + if (!ok) return -1; sm2_bn_to_hex(t, hex); - return 0; + return 1; } @@ -231,7 +231,7 @@ int test_sm2_jacobian_point(void) SM2_JACOBIAN_POINT _P, *P = &_P; SM2_JACOBIAN_POINT _G, *G = &_G; SM2_BN k; - int err = 0, i = 1, ok; + int i = 1, ok; uint8_t buf[64]; @@ -239,40 +239,48 @@ int test_sm2_jacobian_point(void) sm2_jacobian_point_copy(G, SM2_G); ok = sm2_jacobian_point_equ_hex(G, hex_G); - printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); err += ok ^ 1; + printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); + if (!ok) return -1; ok = sm2_jacobian_point_is_on_curve(G); - printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); err += ok ^ 1; + printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); + if (!ok) return -1; sm2_jacobian_point_dbl(P, G); ok = sm2_jacobian_point_equ_hex(P, hex_2G); - printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); err += ok ^ 1; + printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); + if (!ok) return -1; sm2_jacobian_point_add(P, P, G); ok = sm2_jacobian_point_equ_hex(P, hex_3G); - printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); err += ok ^ 1; + printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); + if (!ok) return -1; sm2_jacobian_point_sub(P, P, G); ok = sm2_jacobian_point_equ_hex(P, hex_2G); - printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); err += ok ^ 1; + printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); + if (!ok) return -1; sm2_jacobian_point_neg(P, G); ok = sm2_jacobian_point_equ_hex(P, hex_negG); - printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); err += ok ^ 1; + printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); + if (!ok) return -1; sm2_bn_set_word(k, 10); sm2_jacobian_point_mul(P, k, G); ok = sm2_jacobian_point_equ_hex(P, hex_10G); - printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); err += ok ^ 1; + printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); + if (!ok) return -1; sm2_jacobian_point_mul_generator(P, SM2_B); ok = sm2_jacobian_point_equ_hex(P, hex_bG); - printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); err += ok ^ 1; + printf("sm2 point test %d %s\n", i++, ok ? "ok" : "failed"); + if (!ok) return -1; sm2_jacobian_point_to_bytes(P, buf); sm2_jacobian_point_from_hex(P, hex_P); - return err; + return 1; } #define hex_d "5aebdfd947543b713bc0df2c65baaecc5dadd2cab39c6971402daf92c263fad2" @@ -330,7 +338,7 @@ static int test_sm2_point(void) } printf("%s() ok\n", __FUNCTION__); - return 0; + return 1; } static int test_sm2_point_der(void) @@ -366,7 +374,7 @@ static int test_sm2_point_der(void) } printf("%s() ok\n", __FUNCTION__); - return 0; + return 1; } static int test_sm2_point_octets(void) @@ -404,7 +412,7 @@ static int test_sm2_point_octets(void) } printf("%s() ok\n", __FUNCTION__); - return 0; + return 1; } static int test_sm2_point_from_x(void) @@ -441,7 +449,7 @@ static int test_sm2_point_from_x(void) } printf("%s() ok\n", __FUNCTION__); - return 0; + return 1; } static int test_sm2_signature(void) @@ -495,7 +503,7 @@ static int test_sm2_signature(void) printf("%s() ok\n", __FUNCTION__); - return 0; + return 1; } static int test_sm2_sign(void) @@ -534,7 +542,7 @@ static int test_sm2_sign(void) // FIXME: 还应该增加验证不通过的测试 // 还应该增加底层的参数 printf("%s() ok\n", __FUNCTION__); - return 0; + return 1; } static int test_sm2_ciphertext(void) @@ -618,7 +626,7 @@ static int test_sm2_ciphertext(void) } printf("%s() ok\n", __FUNCTION__); - return 0; + return 1; } @@ -650,7 +658,7 @@ static int test_sm2_do_encrypt(void) } printf("%s() ok\n", __FUNCTION__); - return 0; + return 1; } @@ -728,7 +736,7 @@ test_sm2_do_encrypt() ok } printf("%s() ok\n", __FUNCTION__); - return 0; + return 1; } @@ -783,7 +791,7 @@ static int test_sm2_private_key(void) sm2_private_key_print(stderr, 0, 4, "ECPrivateKey", d, dlen); printf("%s() ok\n", __FUNCTION__); - return 0; + return 1; } static int test_sm2_private_key_info(void) @@ -832,7 +840,7 @@ static int test_sm2_private_key_info(void) } printf("%s() ok\n", __FUNCTION__); - return 0; + return 1; } static int test_sm2_enced_private_key_info(void) @@ -883,26 +891,29 @@ static int test_sm2_enced_private_key_info(void) } printf("%s() ok\n", __FUNCTION__); - return 0; + return 1; } + int main(void) { - int err = 0; - err += test_sm2_bn(); - err += test_sm2_jacobian_point(); - err += test_sm2_point(); - err += test_sm2_point_octets(); - err += test_sm2_point_from_x(); - err += test_sm2_point_der(); - err += test_sm2_private_key(); - err += test_sm2_private_key_info(); - err += test_sm2_enced_private_key_info(); - err += test_sm2_signature(); - err += test_sm2_sign(); -// err += test_sm2_ciphertext(); - err += test_sm2_do_encrypt(); - err += test_sm2_encrypt(); - if (!err) printf("%s all tests passed\n", __FILE__); - return err; + if (test_sm2_bn() != 1) goto err; + if (test_sm2_jacobian_point() != 1) goto err; + if (test_sm2_point() != 1) goto err; + if (test_sm2_point_octets() != 1) goto err; + if (test_sm2_point_from_x() != 1) goto err; + if (test_sm2_point_der() != 1) goto err; + if (test_sm2_private_key() != 1) goto err; + if (test_sm2_private_key_info() != 1) goto err; + if (test_sm2_enced_private_key_info() != 1) goto err; + if (test_sm2_signature() != 1) goto err; + if (test_sm2_sign() != 1) goto err; +// if (test_sm2_ciphertext() != 1) goto err; + if (test_sm2_do_encrypt() != 1) goto err; + if (test_sm2_encrypt() != 1) goto err; + printf("%s all tests passed\n", __FILE__); + return 0; +err: + error_print(); + return -1; } diff --git a/tools/gmssl.c b/tools/gmssl.c index 7df3afce..02de94cf 100644 --- a/tools/gmssl.c +++ b/tools/gmssl.c @@ -222,11 +222,11 @@ int main(int argc, char **argv) return tlcp_client_main(argc, argv); } else if (!strcmp(*argv, "tlcp_server")) { return tlcp_server_main(argc, argv); -/* } else if (!strcmp(*argv, "tls12_client")) { return tls12_client_main(argc, argv); } else if (!strcmp(*argv, "tls12_server")) { return tls12_server_main(argc, argv); +/* } else if (!strcmp(*argv, "tls13_client")) { return tls13_client_main(argc, argv); } else if (!strcmp(*argv, "tls13_server")) { diff --git a/tools/tlcp_client.c b/tools/tlcp_client.c index 33d32ffe..c715e4aa 100644 --- a/tools/tlcp_client.c +++ b/tools/tlcp_client.c @@ -60,7 +60,7 @@ #include -static int client_ciphers[] = { TLCP_cipher_ecc_sm4_cbc_sm3, }; +static int client_ciphers[] = { TLS_cipher_ecc_sm4_cbc_sm3, }; static const char *http_get = "GET / HTTP/1.1\r\n" @@ -149,7 +149,7 @@ bad: goto end; } - if (tls_ctx_init(&ctx, TLS_version_tlcp, TLS_client_mode) != 1 + if (tls_ctx_init(&ctx, TLS_protocol_tlcp, TLS_client_mode) != 1 || tls_ctx_set_cipher_suites(&ctx, client_ciphers, sizeof(client_ciphers)/sizeof(client_ciphers[0])) != 1 || tls_ctx_set_ca_certificates(&ctx, cacertfile, TLS_DEFAULT_VERIFY_DEPTH) != 1 || tls_ctx_set_certificate_and_key(&ctx, certfile, keyfile, pass) != 1) { diff --git a/tools/tlcp_server.c b/tools/tlcp_server.c index b97829d8..a0b2d99b 100644 --- a/tools/tlcp_server.c +++ b/tools/tlcp_server.c @@ -75,7 +75,7 @@ int tlcp_server_main(int argc , char **argv) char *encpass = NULL; char *cacertfile = NULL; - int server_ciphers[] = { TLCP_cipher_ecc_sm4_cbc_sm3, }; + int server_ciphers[] = { TLS_cipher_ecc_sm4_cbc_sm3, }; uint8_t verify_buf[4096]; TLS_CTX ctx; @@ -157,7 +157,7 @@ bad: memset(&ctx, 0, sizeof(ctx)); memset(&conn, 0, sizeof(conn)); - if (tls_ctx_init(&ctx, TLS_version_tlcp, TLS_server_mode) != 1 + if (tls_ctx_init(&ctx, TLS_protocol_tlcp, TLS_server_mode) != 1 || tls_ctx_set_cipher_suites(&ctx, server_ciphers, sizeof(server_ciphers)/sizeof(int)) != 1 || tls_ctx_set_tlcp_server_certificate_and_keys(&ctx, certfile, signkeyfile, signpass, enckeyfile, encpass) != 1) { error_print(); diff --git a/tools/tls12_client.c b/tools/tls12_client.c index 3bcb08a1..3069f057 100644 --- a/tools/tls12_client.c +++ b/tools/tls12_client.c @@ -47,23 +47,33 @@ */ #include +#include #include #include + #include +#include +#include +#include +#include #include #include -const char *http_get = +// TLSv1.2客户单和TLCP客户端可能没有什么区别 + +static int client_ciphers[] = { TLS_cipher_ecdhe_sm4_cbc_sm3 }; + +static const char *http_get = "GET / HTTP/1.1\r\n" "Hostname: aaa\r\n" "\r\n\r\n"; +static const char *options = "-host str [-port num] [-cacert file] [-cert file -key file -pass str]"; -static const char *options = "-host str [-port num] [-cacert file] [-cert file -key file [-pass str]]"; - -int tls12_client_main(int argc , char *argv[]) +int tls12_client_main(int argc, char *argv[]) { + int ret = -1; char *prog = argv[0]; char *host = NULL; int port = 443; @@ -71,24 +81,22 @@ int tls12_client_main(int argc , char *argv[]) char *certfile = NULL; char *keyfile = NULL; char *pass = NULL; - - FILE *cacertfp = NULL; - FILE *certfp = NULL; - FILE *keyfp = NULL; - SM2_KEY sm2_key; - + struct sockaddr_in server; + int sock; + TLS_CTX ctx; TLS_CONNECT conn; char buf[100] = {0}; size_t len = sizeof(buf); - - if (argc < 2) { - fprintf(stderr, "usage: %s %s\n", prog, options); - return 1; - } + char send_buf[1024] = {0}; + size_t send_len; argc--; argv++; - while (argc > 0) { + if (argc < 1) { + fprintf(stderr, "usage: %s %s\n", prog, options); + return 1; + } + while (argc >= 1) { if (!strcmp(*argv, "-help")) { printf("usage: %s %s\n", prog, options); return 0; @@ -122,62 +130,73 @@ bad: } if (!host) { - error_print(); - return 1; - } - - if (cacertfile) { - if (!(cacertfp = fopen(cacertfile, "r"))) { - error_print(); - return 1; - } - } - - if (certfile) { - if (!(certfp = fopen(certfile, "r"))) { - error_print(); - return 1; - } - if (!pass) { - pass = getpass("Password : "); - } - if (!keyfile) { - error_print(); - return 1; - } - if (!(keyfp = fopen(keyfile, "r"))) { - error_print(); - return -1; - } - if (sm2_private_key_info_decrypt_from_pem(&sm2_key, pass, keyfp) != 1) { - error_print(); - return -1; - } + fprintf(stderr, "%s: '-in' option required\n", prog); + return -1; } + memset(&ctx, 0, sizeof(ctx)); memset(&conn, 0, sizeof(conn)); - if (tls12_connect(&conn, host, port, cacertfp, certfp, &sm2_key) != 1) { - error_print(); - return -1; + server.sin_addr.s_addr = inet_addr(host); + server.sin_family = AF_INET; + server.sin_port = htons(port); + + + if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + fprintf(stderr, "%s: open socket error : %s\n", prog, strerror(errno)); + goto end; + } + if (connect(sock, (struct sockaddr *)&server , sizeof(server)) < 0) { + fprintf(stderr, "%s: connect error : %s\n", prog, strerror(errno)); + goto end; } - if (tls_send(&conn, (uint8_t *)"12345\n", 6) != 1) { - error_print(); - return -1; + if (tls_ctx_init(&ctx, TLS_protocol_tls12, TLS_client_mode) != 1 + || tls_ctx_set_cipher_suites(&ctx, client_ciphers, sizeof(client_ciphers)/sizeof(client_ciphers[0])) != 1 + || tls_ctx_set_ca_certificates(&ctx, cacertfile, TLS_DEFAULT_VERIFY_DEPTH) != 1 + || tls_ctx_set_certificate_and_key(&ctx, certfile, keyfile, pass) != 1) { + fprintf(stderr, "%s: context init error\n", prog); + goto end; + } + if (tls_init(&conn, &ctx) != 1 + || tls_set_socket(&conn, sock) != 1 + || tls_do_handshake(&conn) != 1) { + fprintf(stderr, "%s: error\n", prog); + goto end; } for (;;) { - memset(buf, 0, sizeof(buf)); - len = sizeof(buf); - if (tls_recv(&conn, (uint8_t *)buf, &len) != 1) { - error_print(); - return -1; + size_t sentlen; + + memset(send_buf, 0, sizeof(send_buf)); + if (!fgets(send_buf, sizeof(send_buf), stdin)) { + if (feof(stdin)) { + tls_shutdown(&conn); + goto end; + } else { + continue; + } } - if (len > 0) { + if (tls_send(&conn, (uint8_t *)send_buf, strlen(send_buf), &sentlen) != 1) { + fprintf(stderr, "%s: send error\n", prog); + goto end; + } + + { + memset(buf, 0, sizeof(buf)); + len = sizeof(buf); + if (tls_recv(&conn, (uint8_t *)buf, sizeof(len), &len) != 1) { + goto end; + } + buf[len] = 0; printf("%s\n", buf); - break; } } + + +end: + close(sock); + tls_ctx_cleanup(&ctx); + tls_cleanup(&conn); return 0; } diff --git a/tools/tls12_server.c b/tools/tls12_server.c index d5df670d..7331661c 100644 --- a/tools/tls12_server.c +++ b/tools/tls12_server.c @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2021 - 2021 The GmSSL Project. All rights reserved. * * Redistribution and use in source and binary forms, with or without @@ -47,48 +47,56 @@ */ #include +#include #include #include #include +#include +#include +#include +#include +#include #include #include #include -// [-cacert file] 如果服务器需要客户端提供证书,那么自己必须准备可以验证客户端证书的CA证书 -// 因此如果提供了CA证书,那么等同于要求客户端验证 -static const char *options = " [-port num] -cert file -key file [-pass str] [-cacert file]"; -int tls12_server_main(int argc , char *argv[]) +static const char *options = "[-port num] -cert file -key file -pass str [-cacert file]"; + +int tls12_server_main(int argc , char **argv) { - int ret = -1; + int ret = 1; char *prog = argv[0]; - int port = 443; char *certfile = NULL; char *keyfile = NULL; char *pass = NULL; char *cacertfile = NULL; - FILE *certfp = NULL; - FILE *keyfp = NULL; - FILE *cacertfp = NULL; - SM2_KEY sm2_key; - + int server_ciphers[] = { TLS_cipher_ecdhe_sm4_cbc_sm3, }; uint8_t verify_buf[4096]; - + TLS_CTX ctx; TLS_CONNECT conn; char buf[1600] = {0}; size_t len = sizeof(buf); - if (argc < 2) { + int sock; + struct sockaddr_in server_addr; + struct sockaddr_in client_addr; + socklen_t client_addrlen; + int conn_sock; + + + argc--; + argv++; + + if (argc < 1) { fprintf(stderr, "usage: %s %s\n", prog, options); return 1; } - argc--; - argv++; - while (argc >= 1) { + while (argc > 0) { if (!strcmp(*argv, "-help")) { printf("usage: %s %s\n", prog, options); return 0; @@ -117,65 +125,98 @@ bad: argc--; argv++; } - - if (!certfile || !keyfile) { - error_print(); + if (!certfile) { + fprintf(stderr, "%s: '-cert' option required\n", prog); + return 1; + } + if (!keyfile) { + fprintf(stderr, "%s: '-key' option required\n", prog); + return 1; + } + if (!pass) { + fprintf(stderr, "%s: '-pass' option required\n", prog); return 1; } + memset(&ctx, 0, sizeof(ctx)); + memset(&conn, 0, sizeof(conn)); + + if (tls_ctx_init(&ctx, TLS_protocol_tls12, TLS_server_mode) != 1 + || tls_ctx_set_cipher_suites(&ctx, server_ciphers, sizeof(server_ciphers)/sizeof(int)) != 1 + || tls_ctx_set_certificate_and_key(&ctx, certfile, keyfile, pass) != 1) { + error_print(); + return -1; + } if (cacertfile) { - if (!(cacertfp = fopen(cacertfile, "r"))) { + if (tls_ctx_set_ca_certificates(&ctx, cacertfile, TLS_DEFAULT_VERIFY_DEPTH) != 1) { error_print(); return -1; } } - if (!(certfp = fopen(certfile, "r"))) { + // Socket + if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + error_print(); + return 1; + } + server_addr.sin_family = AF_INET; + server_addr.sin_addr.s_addr = INADDR_ANY; + server_addr.sin_port = htons(port); + if (bind(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) { + error_print(); + perror("tlcp_accept: bind: "); + goto end; + } + puts("start listen ...\n"); + listen(sock, 1); + + + +restart: + + client_addrlen = sizeof(client_addr); + if ((conn_sock = accept(sock, (struct sockaddr *)&client_addr, &client_addrlen)) < 0) { + error_print(); + goto end; + } + puts("socket connected\n"); + + if (tls_init(&conn, &ctx) != 1 + || tls_set_socket(&conn, conn_sock) != 1) { error_print(); return -1; } - if (!pass) { - pass = getpass("Password : "); - } - if (!(keyfp = fopen(keyfile, "r"))) { - error_print(); - return -1; - } - if (sm2_private_key_info_decrypt_from_pem(&sm2_key, pass, keyfp) != 1) { - error_print(); + if (tls_do_handshake(&conn) != 1) { + error_print(); // 为什么这个会触发呢? return -1; } - memset(&conn, 0, sizeof(conn)); - if (tls12_accept(&conn, port, certfp, &sm2_key, cacertfp, verify_buf, 4096) != 1) { - //if (tls12_accept(&conn, port, certfp, &sm2_key, NULL, NULL, 0) != 1) { - error_print(); - return -1; - } - - // 我要做一个反射的服务器,接收到用户的输入之后,再反射回去 for (;;) { - // 接收一个消息 - // 按道理说第二次执行的时候是不可能成功的了,因此客户端没有数据发过来 + int rv; + size_t sentlen; + do { len = sizeof(buf); - if (tls_recv(&conn, (uint8_t *)buf, &len) != 1) { - error_print(); - return -1; + if ((rv = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len)) != 1) { + if (rv < 0) fprintf(stderr, "%s: recv failure\n", prog); + else fprintf(stderr, "%s: Disconnected by remote\n", prog); + + //close(conn.sock); + tls_cleanup(&conn); + goto restart; } } while (!len); - - // 把这个消息再发回去 - if (tls_send(&conn, (uint8_t *)buf, len) != 1) { - error_print(); - return -1; + if (tls_send(&conn, (uint8_t *)buf, len, &sentlen) != 1) { + fprintf(stderr, "%s: send failure, close connection\n", prog); + close(conn.sock); + goto end; } - - fprintf(stderr, "-----------------\n\n\n\n\n\n"); - } - return 0; + + +end: + return ret; }