diff --git a/CMakeLists.txt b/CMakeLists.txt index 9931e2be..23c122ce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -777,7 +777,7 @@ endif() # set(CPACK_PACKAGE_NAME "GmSSL") set(CPACK_PACKAGE_VENDOR "GmSSL develop team") -set(CPACK_PACKAGE_VERSION "3.2.0-dev.1038") +set(CPACK_PACKAGE_VERSION "3.2.0-dev.1039") set(CPACK_PACKAGE_DESCRIPTION_FILE ${PROJECT_SOURCE_DIR}/README.md) set(CPACK_NSIS_MODIFY_PATH ON) include(CPack) diff --git a/include/gmssl/aes.h b/include/gmssl/aes.h index ba06ada0..b61ebb36 100644 --- a/include/gmssl/aes.h +++ b/include/gmssl/aes.h @@ -47,9 +47,9 @@ void aes_encrypt(const AES_KEY *key, const uint8_t in[AES_BLOCK_SIZE], uint8_t o void aes_decrypt(const AES_KEY *key, const uint8_t in[AES_BLOCK_SIZE], uint8_t out[AES_BLOCK_SIZE]); -void aes_cbc_encrypt(const AES_KEY *key, const uint8_t iv[AES_BLOCK_SIZE], +void aes_cbc_encrypt_blocks(const AES_KEY *key, const uint8_t iv[AES_BLOCK_SIZE], const uint8_t *in, size_t nblocks, uint8_t *out); -void aes_cbc_decrypt(const AES_KEY *key, const uint8_t iv[AES_BLOCK_SIZE], +void aes_cbc_decrypt_blocks(const AES_KEY *key, const uint8_t iv[AES_BLOCK_SIZE], const uint8_t *in, size_t nblocks, uint8_t *out); int aes_cbc_padding_encrypt(const AES_KEY *key, const uint8_t iv[AES_BLOCK_SIZE], const uint8_t *in, size_t inlen, diff --git a/include/gmssl/tls.h b/include/gmssl/tls.h index 01c6bb6d..6b676d3d 100644 --- a/include/gmssl/tls.h +++ b/include/gmssl/tls.h @@ -402,7 +402,7 @@ const char *tls_alert_description_text(int description); // Key and Crypto -int tls_prf(const uint8_t *secret, size_t secretlen, const char *label, +int tls_prf(const DIGEST *digest, const uint8_t *secret, size_t secretlen, const char *label, const uint8_t *seed, size_t seedlen, const uint8_t *more, size_t morelen, size_t outlen, uint8_t *out); @@ -1420,7 +1420,7 @@ int tls13_ctx_enable_change_cipher_spec(TLS_CTX *ctx, int enable); int tls_generate_keys(TLS_CONNECT *conn); -int tls_compute_verify_data(const uint8_t master_secret[48], +int tls_compute_verify_data(const DIGEST *digest, const uint8_t master_secret[48], const char *label, const DIGEST_CTX *dgst_ctx, uint8_t verify_data[12]); diff --git a/include/gmssl/version.h b/include/gmssl/version.h index 7ed9e85a..af0a7fe2 100644 --- a/include/gmssl/version.h +++ b/include/gmssl/version.h @@ -19,7 +19,7 @@ extern "C" { // Also update CPACK_PACKAGE_VERSION in CMakeLists.txt #define GMSSL_VERSION_NUM 30200 -#define GMSSL_VERSION_STR "GmSSL 3.2.0-dev.1038" +#define GMSSL_VERSION_STR "GmSSL 3.2.0-dev.1039" int gmssl_version_num(void); const char *gmssl_version_str(void); diff --git a/src/aes_modes.c b/src/aes_modes.c index be96cbf4..d4c265b6 100644 --- a/src/aes_modes.c +++ b/src/aes_modes.c @@ -18,7 +18,7 @@ #include -void aes_cbc_encrypt(const AES_KEY *key, const uint8_t iv[16], +void aes_cbc_encrypt_blocks(const AES_KEY *key, const uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) { while (nblocks--) { @@ -30,7 +30,7 @@ void aes_cbc_encrypt(const AES_KEY *key, const uint8_t iv[16], } } -void aes_cbc_decrypt(const AES_KEY *key, const uint8_t iv[16], +void aes_cbc_decrypt_blocks(const AES_KEY *key, const uint8_t iv[16], const uint8_t *in, size_t nblocks, uint8_t *out) { while (nblocks--) { @@ -55,11 +55,11 @@ int aes_cbc_padding_encrypt(const AES_KEY *key, const uint8_t iv[16], } memset(block + rem, padding, padding); if (inlen/16) { - aes_cbc_encrypt(key, iv, in, inlen/16, out); + aes_cbc_encrypt_blocks(key, iv, in, inlen/16, out); out += inlen - rem; iv = out - 16; } - aes_cbc_encrypt(key, iv, block, 1, out); + aes_cbc_encrypt_blocks(key, iv, block, 1, out); *outlen = inlen - rem + 16; return 1; } @@ -82,10 +82,10 @@ int aes_cbc_padding_decrypt(const AES_KEY *key, const uint8_t iv[16], return -1; } if (inlen > 16) { - aes_cbc_decrypt(key, iv, in, inlen/16 - 1, out); + aes_cbc_decrypt_blocks(key, iv, in, inlen/16 - 1, out); iv = in + inlen - 32; } - aes_cbc_decrypt(key, iv, in + inlen - 16, 1, block); + aes_cbc_decrypt_blocks(key, iv, in + inlen - 16, 1, block); padding = block[15]; if (padding < 1 || padding > 16) { error_print(); diff --git a/src/tlcp.c b/src/tlcp.c index 5abcbbc9..47c8028c 100644 --- a/src/tlcp.c +++ b/src/tlcp.c @@ -1122,7 +1122,7 @@ int tlcp_send_client_finished(TLS_CONNECT *conn) if(conn->verbose) tls_trace("send client {Finished}\n"); - if (tls_compute_verify_data(conn->master_secret, "client finished", &conn->dgst_ctx, verify_data) != 1) { + if (tls_compute_verify_data(conn->digest, conn->master_secret, "client finished", &conn->dgst_ctx, verify_data) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; @@ -1223,7 +1223,7 @@ int tlcp_recv_server_finished(TLS_CONNECT *conn) return -1; } - if (tls_compute_verify_data(conn->master_secret, "server finished", &conn->dgst_ctx, local_verify_data) != 1) { + if (tls_compute_verify_data(conn->digest, conn->master_secret, "server finished", &conn->dgst_ctx, local_verify_data) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); return -1; @@ -1885,7 +1885,7 @@ static int tlcp_generate_master_secret(TLS_CONNECT *conn) error_print(); return -1; } - if (tls_prf(conn->pre_master_secret, 48, "master secret", + if (tls_prf(conn->digest, conn->pre_master_secret, 48, "master secret", conn->client_random, 32, conn->server_random, 32, 48, conn->master_secret) != 1) { @@ -1919,7 +1919,7 @@ static int tlcp_generate_key_block(TLS_CONNECT *conn) error_print(); return -1; } - if (tls_prf(conn->master_secret, 48, "key expansion", + if (tls_prf(conn->digest, conn->master_secret, 48, "key expansion", conn->server_random, 32, conn->client_random, 32, key_block_len, conn->key_block) != 1) { @@ -2184,7 +2184,7 @@ int tlcp_recv_client_finished(TLS_CONNECT *conn) size_t verify_data_len; uint8_t local_verify_data[12]; - if (tls_compute_verify_data(conn->master_secret, "client finished", + if (tls_compute_verify_data(conn->digest, conn->master_secret, "client finished", &conn->dgst_ctx, local_verify_data) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); @@ -2245,7 +2245,7 @@ int tlcp_send_server_finished(TLS_CONNECT *conn) if (conn->recordlen == 0) { if(conn->verbose) tls_trace("send server {Finished}\n"); - if (tls_compute_verify_data(conn->master_secret, "server finished", + if (tls_compute_verify_data(conn->digest, conn->master_secret, "server finished", &conn->dgst_ctx, verify_data) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); diff --git a/src/tls.c b/src/tls.c index 1b494f9b..4234d833 100644 --- a/src/tls.c +++ b/src/tls.c @@ -298,44 +298,6 @@ int tls_record_set_data(uint8_t *record, const uint8_t *data, size_t datalen) return 1; } -static void tls_cbc_encrypt_blocks(const BLOCK_CIPHER_KEY *key, uint8_t iv[16], - const uint8_t *in, size_t nblocks, uint8_t *out) -{ - const uint8_t *piv = iv; - - while (nblocks--) { - size_t i; - for (i = 0; i < 16; i++) { - out[i] = in[i] ^ piv[i]; - } - block_cipher_encrypt(key, out, out); - piv = out; - in += 16; - out += 16; - } - - memcpy(iv, piv, 16); -} - -static void tls_cbc_decrypt_blocks(const BLOCK_CIPHER_KEY *key, uint8_t iv[16], - const uint8_t *in, size_t nblocks, uint8_t *out) -{ - const uint8_t *piv = iv; - - while (nblocks--) { - size_t i; - block_cipher_decrypt(key, in, out); - for (i = 0; i < 16; i++) { - out[i] ^= piv[i]; - } - piv = in; - in += 16; - out += 16; - } - - memcpy(iv, piv, 16); -} - int tls_cbc_encrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *enc_key, const uint8_t seq_num[8], const uint8_t header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) @@ -348,7 +310,8 @@ int tls_cbc_encrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *enc int rem, padding_len; int i; - if (!inited_hmac_ctx || !enc_key || !seq_num || !header || (!in && inlen) || !out || !outlen) { + if (!inited_hmac_ctx || !enc_key || !enc_key->cipher + || !seq_num || !header || (!in && inlen) || !out || !outlen) { error_print(); return -1; } @@ -385,10 +348,40 @@ int tls_cbc_encrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *enc out += 16; if (inlen >= 16) { - tls_cbc_encrypt_blocks(enc_key, iv, in, inlen/16, out); + size_t nblocks = inlen/16; + + switch (enc_key->cipher->oid) { + case OID_sm4: + sm4_cbc_encrypt_blocks(&enc_key->u.sm4_key, iv, in, nblocks, out); + break; +#ifdef ENABLE_AES + case OID_aes128: + case OID_aes256: + aes_cbc_encrypt_blocks(&enc_key->u.aes_key, iv, in, nblocks, out); + break; +#endif + default: + error_print(); + return -1; + } out += inlen - rem; + memcpy(iv, out - 16, 16); + } + switch (enc_key->cipher->oid) { + case OID_sm4: + sm4_cbc_encrypt_blocks(&enc_key->u.sm4_key, iv, last_blocks, sizeof(last_blocks)/16, out); + break; +#ifdef ENABLE_AES + case OID_aes128: + case OID_aes192: + case OID_aes256: + aes_cbc_encrypt_blocks(&enc_key->u.aes_key, iv, last_blocks, sizeof(last_blocks)/16, out); + break; +#endif + default: + error_print(); + return -1; } - tls_cbc_encrypt_blocks(enc_key, iv, last_blocks, sizeof(last_blocks)/16, out); *outlen = 16 + inlen - rem + sizeof(last_blocks); return 1; } @@ -407,14 +400,15 @@ int tls_cbc_decrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *dec size_t hmaclen; int i; - if (!inited_hmac_ctx || !dec_key || !seq_num || !enced_header || !in || !inlen || !out || !outlen) { + if (!inited_hmac_ctx || !dec_key || !dec_key->cipher + || !seq_num || !enced_header || !in || !inlen || !out || !outlen) { error_print(); return -1; } if (inlen % 16 || inlen < (16 + 0 + 32 + 16) // iv + data + mac + padding || inlen > (16 + (1<<14) + 32 + 256)) { - error_print_msg("invalid tls cbc ciphertext length %zu\n", inlen); + error_print(); return -1; } @@ -422,7 +416,21 @@ int tls_cbc_decrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *dec in += 16; inlen -= 16; - tls_cbc_decrypt_blocks(dec_key, iv, in, inlen/16, out); + switch (dec_key->cipher->oid) { + case OID_sm4: + sm4_cbc_decrypt_blocks(&dec_key->u.sm4_key, iv, in, inlen/16, out); + break; +#ifdef ENABLE_AES + case OID_aes128: + case OID_aes192: + case OID_aes256: + aes_cbc_decrypt_blocks(&dec_key->u.aes_key, iv, in, inlen/16, out); + break; +#endif + default: + error_print(); + return -1; + } padding_len = out[inlen - 1]; padding = out + inlen - padding_len - 1; @@ -432,7 +440,7 @@ int tls_cbc_decrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *dec } for (i = 0; i < padding_len; i++) { if (padding[i] != padding_len) { - error_puts("tls ciphertext cbc-padding check failure"); + error_print(); return -1; } } @@ -452,7 +460,7 @@ int tls_cbc_decrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *dec hmac_update(&hmac_ctx, out, *outlen); hmac_finish(&hmac_ctx, hmac, &hmaclen); if (gmssl_secure_memcmp(mac, hmac, sizeof(hmac)) != 0) { - error_puts("tls ciphertext mac check failure\n"); + error_print(); return -1; } return 1; @@ -502,6 +510,8 @@ int tls_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], break; #ifdef ENABLE_AES case OID_aes128: + case OID_aes192: + case OID_aes256: if (aes_gcm_encrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad), in, inlen, out, GHASH_SIZE, gmac) != 1) { error_print(); @@ -562,6 +572,8 @@ int tls_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], break; #ifdef ENABLE_AES case OID_aes128: + case OID_aes192: + case OID_aes256: if (aes_gcm_decrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad), in, mlen, gmac, GHASH_SIZE, out) != 1) { error_print(); @@ -624,6 +636,8 @@ int tls_ccm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], #endif #ifdef ENABLE_AES_CCM case OID_aes128: + case OID_aes192: + case OID_aes256: if (aes_ccm_encrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad), in, inlen, out, GHASH_SIZE, tag) != 1) { error_print(); @@ -686,6 +700,8 @@ int tls_ccm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], #endif #ifdef ENABLE_AES_CCM case OID_aes128: + case OID_aes192: + case OID_aes256: if (aes_ccm_decrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad), in, mlen, tag, GHASH_SIZE, out) != 1) { error_print(); @@ -715,24 +731,30 @@ int tls_random_generate(uint8_t random[32]) return 1; } -int tls_prf(const uint8_t *secret, size_t secretlen, const char *label, +int tls_prf(const DIGEST *digest, const uint8_t *secret, size_t secretlen, const char *label, const uint8_t *seed, size_t seedlen, const uint8_t *more, size_t morelen, size_t outlen, uint8_t *out) { HMAC_CTX inited_hmac_ctx; HMAC_CTX hmac_ctx; - uint8_t A[32]; - uint8_t hmac[32]; + uint8_t A[DIGEST_MAX_SIZE]; + uint8_t hmac[DIGEST_MAX_SIZE]; size_t len; + size_t hmaclen; - if (!secret || !secretlen || !label || !seed || !seedlen + if (!digest || !secret || !secretlen || !label || !seed || !seedlen || (!more && morelen) || !outlen || !out) { error_print(); return -1; } + if (digest->digest_size > sizeof(hmac) || !digest->digest_size) { + error_print(); + return -1; + } + hmaclen = digest->digest_size; - hmac_init(&inited_hmac_ctx, DIGEST_sm3(), secret, secretlen); + hmac_init(&inited_hmac_ctx, digest, secret, secretlen); memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); @@ -741,30 +763,30 @@ int tls_prf(const uint8_t *secret, size_t secretlen, const char *label, hmac_finish(&hmac_ctx, A, &len); memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); - hmac_update(&hmac_ctx, A, sizeof(A)); + hmac_update(&hmac_ctx, A, hmaclen); hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); hmac_update(&hmac_ctx, seed, seedlen); hmac_update(&hmac_ctx, more, morelen); hmac_finish(&hmac_ctx, hmac, &len); - len = outlen < sizeof(hmac) ? outlen : sizeof(hmac); + len = outlen < hmaclen ? outlen : hmaclen; memcpy(out, hmac, len); out += len; outlen -= len; while (outlen) { memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); - hmac_update(&hmac_ctx, A, sizeof(A)); + hmac_update(&hmac_ctx, A, hmaclen); hmac_finish(&hmac_ctx, A, &len); memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); - hmac_update(&hmac_ctx, A, sizeof(A)); + hmac_update(&hmac_ctx, A, hmaclen); hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); hmac_update(&hmac_ctx, seed, seedlen); hmac_update(&hmac_ctx, more, morelen); hmac_finish(&hmac_ctx, hmac, &len); - len = outlen < sizeof(hmac) ? outlen : sizeof(hmac); + len = outlen < hmaclen ? outlen : hmaclen; memcpy(out, hmac, len); out += len; outlen -= len; @@ -3928,7 +3950,7 @@ int tls_handshake_digest_print(FILE *fp, int fmt, int ind, const char *label, co return 1; } -int tls_compute_verify_data(const uint8_t master_secret[48], +int tls_compute_verify_data(const DIGEST *digest, const uint8_t master_secret[48], const char *label, const DIGEST_CTX *dgst_ctx, uint8_t verify_data[12]) { const size_t master_secret_len = 48; @@ -3937,7 +3959,7 @@ int tls_compute_verify_data(const uint8_t master_secret[48], uint8_t dgst[64]; size_t dgstlen; - if (!master_secret || !dgst_ctx || !verify_data) { + if (!digest || !master_secret || !dgst_ctx || !verify_data) { error_print(); return -1; } @@ -3947,7 +3969,7 @@ int tls_compute_verify_data(const uint8_t master_secret[48], error_print(); return -1; } - if (tls_prf(master_secret, master_secret_len, + if (tls_prf(digest, master_secret, master_secret_len, label, // "client finished" or "server finished", dgst, dgstlen, NULL, 0, verify_data_len, verify_data) != 1) { diff --git a/src/tls12.c b/src/tls12.c index d0a2b6b7..3d59a6de 100644 --- a/src/tls12.c +++ b/src/tls12.c @@ -162,67 +162,6 @@ int tls12_record_decrypt(int cipher_suite, const HMAC_CTX *hmac_ctx, return 1; } -int tls12_prf(const DIGEST *digest, const uint8_t *secret, size_t secretlen, const char *label, - const uint8_t *seed, size_t seedlen, - const uint8_t *more, size_t morelen, - size_t outlen, uint8_t *out) -{ - HMAC_CTX inited_hmac_ctx; - HMAC_CTX hmac_ctx; - uint8_t A[32]; - uint8_t hmac[32]; - size_t len; - - - if (!secret || !secretlen || !label || !seed || !seedlen - || (!more && morelen) || !outlen || !out) { - error_print(); - return -1; - } - - hmac_init(&inited_hmac_ctx, digest, secret, secretlen); - - memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); - hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); - hmac_update(&hmac_ctx, seed, seedlen); - hmac_update(&hmac_ctx, more, morelen); - hmac_finish(&hmac_ctx, A, &len); // 检查或者使用长度len - - memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); - hmac_update(&hmac_ctx, A, sizeof(A)); - hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); - hmac_update(&hmac_ctx, seed, seedlen); - hmac_update(&hmac_ctx, more, morelen); - hmac_finish(&hmac_ctx, hmac, &len); - - len = outlen < sizeof(hmac) ? outlen : sizeof(hmac); - memcpy(out, hmac, len); - out += len; - outlen -= len; - - while (outlen) { - memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); - hmac_update(&hmac_ctx, A, sizeof(A)); - hmac_finish(&hmac_ctx, A, &len); - - memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(HMAC_CTX)); - hmac_update(&hmac_ctx, A, sizeof(A)); - hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label)); - hmac_update(&hmac_ctx, seed, seedlen); - hmac_update(&hmac_ctx, more, morelen); - hmac_finish(&hmac_ctx, hmac, &len); - - len = outlen < sizeof(hmac) ? outlen : sizeof(hmac); - memcpy(out, hmac, len); - out += len; - outlen -= len; - } - return 1; -} - - - - // modify: conn->record_offset int tls_send_record(TLS_CONNECT *conn) { @@ -2746,7 +2685,7 @@ static int tls12_generate_master_secret(TLS_CONNECT *conn, error_print(); return -1; } - if (tls12_prf(conn->digest, pre_master_secret, pre_master_secret_len, + if (tls_prf(conn->digest, pre_master_secret, pre_master_secret_len, "master secret", conn->client_random, 32, conn->server_random, 32, @@ -2773,7 +2712,7 @@ static int tls12_generate_key_block(TLS_CONNECT *conn) size_t keylen = conn->cipher->key_size; size_t key_block_len = keylen * 2 + 8; - if (tls12_prf(conn->digest, conn->master_secret, 48, "key expansion", + if (tls_prf(conn->digest, conn->master_secret, 48, "key expansion", conn->server_random, 32, conn->client_random, 32, key_block_len, conn->key_block) != 1) { @@ -2792,7 +2731,7 @@ static int tls12_generate_key_block(TLS_CONNECT *conn) // OpenSSL tls1_prf 中,这里生成的是128字节,也就是把IV也生成了 // 为什么生成IV呢? - if (tls12_prf(conn->digest, conn->master_secret, 48, "key expansion", + if (tls_prf(conn->digest, conn->master_secret, 48, "key expansion", conn->server_random, 32, conn->client_random, 32, 96, conn->key_block) != 1) { @@ -3206,7 +3145,7 @@ int tls_send_client_finished(TLS_CONNECT *conn) digest_finish(&tmp_ctx, dgst, &dgstlen); - if (tls12_prf(conn->digest, + if (tls_prf(conn->digest, conn->master_secret, 48, "client finished", dgst, dgstlen, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) { @@ -3277,7 +3216,7 @@ int tls_recv_client_finished(TLS_CONNECT *conn) error_print(); return -1; } - if (tls12_prf(conn->digest, conn->master_secret, 48, "client finished", dgst, dgstlen, NULL, 0, + if (tls_prf(conn->digest, conn->master_secret, 48, "client finished", dgst, dgstlen, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); @@ -3378,7 +3317,7 @@ int tls_send_server_finished(TLS_CONNECT *conn) digest_finish(&conn->dgst_ctx, dgst, &dgstlen); - if (tls12_prf(conn->digest, conn->master_secret, 48, "server finished", dgst, dgstlen, NULL, 0, + if (tls_prf(conn->digest, conn->master_secret, 48, "server finished", dgst, dgstlen, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) { error_print(); return -1; @@ -3437,7 +3376,7 @@ int tls_recv_server_finished(TLS_CONNECT *conn) error_print(); return -1; } - if (tls12_prf(conn->digest, conn->master_secret, 48, "server finished", + if (tls_prf(conn->digest, conn->master_secret, 48, "server finished", dgst, dgstlen, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) { error_print(); diff --git a/tests/tlstest.c b/tests/tlstest.c index 2832f1c5..2cdbd661 100644 --- a/tests/tlstest.c +++ b/tests/tlstest.c @@ -84,6 +84,13 @@ static int test_tls_cbc(void) uint8_t key[32] = {0}; HMAC_CTX hmac_ctx; BLOCK_CIPHER_KEY sm4_key; +#if defined(ENABLE_AES) && defined(ENABLE_SHA2) + BLOCK_CIPHER_KEY aes_key; + uint8_t aes_out[256]; + uint8_t aes_buf[256] = {0}; + size_t aes_len; + size_t aes_buflen; +#endif uint8_t seq_num[8] = { 0,0,0,0,0,0,0,1 }; uint8_t header[5]; uint8_t in[] = "hello world"; @@ -106,6 +113,30 @@ static int test_tls_cbc(void) block_cipher_set_decrypt_key(&sm4_key, BLOCK_CIPHER_sm4(), key); tls_cbc_decrypt(&hmac_ctx, &sm4_key, seq_num, header, out, len, buf, &buflen); + if (buflen != sizeof(in) || memcmp(buf, in, sizeof(in)) != 0) { + error_print(); + return -1; + } + +#if defined(ENABLE_AES) && defined(ENABLE_SHA2) + hmac_init(&hmac_ctx, DIGEST_sha256(), key, sizeof(key)); + block_cipher_set_encrypt_key(&aes_key, BLOCK_CIPHER_aes128(), key); + if (tls_cbc_encrypt(&hmac_ctx, &aes_key, seq_num, header, in, sizeof(in), aes_out, &aes_len) != 1) { + error_print(); + return -1; + } + + hmac_init(&hmac_ctx, DIGEST_sha256(), key, sizeof(key)); + block_cipher_set_decrypt_key(&aes_key, BLOCK_CIPHER_aes128(), key); + if (tls_cbc_decrypt(&hmac_ctx, &aes_key, seq_num, header, aes_out, aes_len, aes_buf, &aes_buflen) != 1) { + error_print(); + return -1; + } + if (aes_buflen != sizeof(in) || memcmp(aes_buf, in, sizeof(in)) != 0) { + error_print(); + return -1; + } +#endif printf("%s() ok\n", __FUNCTION__); return 1;