From 5d12858d411931ad00df108109d34d4623c51d6b Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Sat, 13 Jun 2026 23:52:29 +0800 Subject: [PATCH] Clean TLS code --- CMakeLists.txt | 2 +- include/gmssl/tls.h | 23 +----- include/gmssl/version.h | 2 +- src/tlcp.c | 30 ++++++- src/tls.c | 169 +++++++++++++++++++++++++++++++--------- src/tls12.c | 157 +++++++------------------------------ src/tls13.c | 90 ++++++++++----------- tests/ghashtest.c | 66 ---------------- 8 files changed, 231 insertions(+), 308 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 30b33c6f..15af1875 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -768,7 +768,7 @@ endif() # set(CPACK_PACKAGE_NAME "GmSSL") set(CPACK_PACKAGE_VENDOR "GmSSL develop team") -set(CPACK_PACKAGE_VERSION "3.2.0-dev.1036") +set(CPACK_PACKAGE_VERSION "3.2.0-dev.1037") set(CPACK_PACKAGE_DESCRIPTION_FILE ${PROJECT_SOURCE_DIR}/README.md) set(CPACK_NSIS_MODIFY_PATH ON) include(CPack) diff --git a/include/gmssl/tls.h b/include/gmssl/tls.h index c483f1da..a2813eb0 100644 --- a/include/gmssl/tls.h +++ b/include/gmssl/tls.h @@ -413,21 +413,12 @@ int tls_cbc_encrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *enc_key, int tls_cbc_decrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *dec_key, const uint8_t seq_num[8], const uint8_t header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); -int tls_record_cbc_encrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key, - const uint8_t seq_num[8], const uint8_t *in, size_t inlen, - uint8_t *out, size_t *outlen); -int tls_record_cbc_decrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key, - const uint8_t seq_num[8], const uint8_t *in, size_t inlen, - uint8_t *out, size_t *outlen); -int tls12_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], +int tls_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], + const uint8_t seq_num[8], const uint8_t header[5], + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); +int tls_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], const uint8_t seq_num[8], const uint8_t header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); -int tls12_record_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], - const uint8_t seq_num[8], const uint8_t *in, size_t inlen, - uint8_t *out, size_t *outlen); -int tls12_record_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], - const uint8_t seq_num[8], const uint8_t *in, size_t inlen, - uint8_t *out, size_t *outlen); int tls12_record_decrypt(int cipher_suite, const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], const uint8_t seq_num[8], const uint8_t *in, size_t inlen, @@ -1727,12 +1718,6 @@ int tls13_random_generate(uint8_t random[32]); int tls13_cipher_suite_get(int cipher_suite, const BLOCK_CIPHER **cipher, const DIGEST **digest); int tls13_padding_len_rand(size_t *padding_len); -int gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t *iv, size_t ivlen, - const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, - uint8_t *out, size_t taglen, uint8_t *tag); -int gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t *iv, size_t ivlen, - const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, - const uint8_t *tag, size_t taglen, uint8_t *out); int tls13_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], const uint8_t seq_num[8], int record_type, const uint8_t *in, size_t inlen, size_t padding_len, // TLSInnerPlaintext.content diff --git a/include/gmssl/version.h b/include/gmssl/version.h index 170c87f6..235067c0 100644 --- a/include/gmssl/version.h +++ b/include/gmssl/version.h @@ -19,7 +19,7 @@ extern "C" { // Also update CPACK_PACKAGE_VERSION in CMakeLists.txt #define GMSSL_VERSION_NUM 30200 -#define GMSSL_VERSION_STR "GmSSL 3.2.0-dev.1036" +#define GMSSL_VERSION_STR "GmSSL 3.2.0-dev.1037" int gmssl_version_num(void); const char *gmssl_version_str(void); diff --git a/src/tlcp.c b/src/tlcp.c index d9e58185..5abcbbc9 100644 --- a/src/tlcp.c +++ b/src/tlcp.c @@ -74,13 +74,17 @@ int tlcp_record_encrypt(int cipher_suite, { switch (cipher_suite) { case TLS_cipher_ecc_sm4_cbc_sm3: - if (tls_record_cbc_encrypt(hmac_ctx, key, seq_num, in, inlen, out, outlen) != 1) { + if (tls_cbc_encrypt(hmac_ctx, key, seq_num, in, + in + 5, inlen - 5, + out + 5, outlen) != 1) { error_print(); return -1; } break; case TLS_cipher_ecc_sm4_gcm_sm3: - if (tls12_record_gcm_encrypt(key, fixed_iv, seq_num, in, inlen, out, outlen) != 1) { + if (tls_gcm_encrypt(key, fixed_iv, seq_num, in, + in + 5, inlen - 5, + out + 5, outlen) != 1) { error_print(); return -1; } @@ -89,6 +93,13 @@ int tlcp_record_encrypt(int cipher_suite, error_print(); return -1; } + + out[0] = in[0]; + out[1] = in[1]; + out[2] = in[2]; + out[3] = (uint8_t)((*outlen) >> 8); + out[4] = (uint8_t)(*outlen); + (*outlen) += 5; return 1; } @@ -99,13 +110,17 @@ int tlcp_record_decrypt(int cipher_suite, { switch (cipher_suite) { case TLS_cipher_ecc_sm4_cbc_sm3: - if (tls_record_cbc_decrypt(hmac_ctx, key, seq_num, in, inlen, out, outlen) != 1) { + if (tls_cbc_decrypt(hmac_ctx, key, seq_num, in, + in + 5, inlen - 5, + out + 5, outlen) != 1) { error_print(); return -1; } break; case TLS_cipher_ecc_sm4_gcm_sm3: - if (tls12_record_gcm_decrypt(key, fixed_iv, seq_num, in, inlen, out, outlen) != 1) { + if (tls_gcm_decrypt(key, fixed_iv, seq_num, in, + in + 5, inlen - 5, + out + 5, outlen) != 1) { error_print(); return -1; } @@ -114,6 +129,13 @@ int tlcp_record_decrypt(int cipher_suite, error_print(); return -1; } + + out[0] = in[0]; + out[1] = in[1]; + out[2] = in[2]; + out[3] = (uint8_t)((*outlen) >> 8); + out[4] = (uint8_t)(*outlen); + (*outlen) += 5; return 1; } diff --git a/src/tls.c b/src/tls.c index bad7b01b..e399eb07 100644 --- a/src/tls.c +++ b/src/tls.c @@ -458,44 +458,123 @@ int tls_cbc_decrypt(const HMAC_CTX *inited_hmac_ctx, const BLOCK_CIPHER_KEY *dec return 1; } -int tls_record_cbc_encrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key, - const uint8_t seq_num[8], const uint8_t *in, size_t inlen, - uint8_t *out, size_t *outlen) +int tls_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], + const uint8_t seq_num[8], const uint8_t header[5], + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { - if (tls_cbc_encrypt(hmac_ctx, cbc_key, seq_num, in, - in + 5, inlen - 5, - out + 5, outlen) != 1) { + uint8_t nonce[12]; + uint8_t aad[13]; + uint8_t *explicit_nonce; + uint8_t *gmac; + + if (!key || !fixed_iv || !seq_num || !header || (!in && inlen) || !out || !outlen) { + error_print(); + return -1; + } + if (inlen > TLS_MAX_PLAINTEXT_SIZE) { + error_print(); + return -1; + } + if ((((size_t)header[3]) << 8) + header[4] != inlen) { error_print(); return -1; } - out[0] = in[0]; - out[1] = in[1]; - out[2] = in[2]; - out[3] = (uint8_t)((*outlen) >> 8); - out[4] = (uint8_t)(*outlen); - (*outlen) += 5; + memcpy(nonce, fixed_iv, 4); + memcpy(nonce + 4, seq_num, 8); + + memcpy(aad, seq_num, 8); + memcpy(aad + 8, header, 5); + + explicit_nonce = out; + memcpy(explicit_nonce, seq_num, 8); + out += 8; + + gmac = out + inlen; + + switch (key->cipher->oid) { + case OID_sm4: + if (sm4_gcm_encrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad), + in, inlen, out, GHASH_SIZE, gmac) != 1) { + error_print(); + return -1; + } + break; +#ifdef ENABLE_AES + case OID_aes128: + if (aes_gcm_encrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad), + in, inlen, out, GHASH_SIZE, gmac) != 1) { + error_print(); + return -1; + } + break; +#endif + default: + error_print(); + return -1; + } + + *outlen = 8 + inlen + GHASH_SIZE; return 1; } -int tls_record_cbc_decrypt(const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *cbc_key, - const uint8_t seq_num[8], const uint8_t *in, size_t inlen, - uint8_t *out, size_t *outlen) +int tls_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], + const uint8_t seq_num[8], const uint8_t header[5], + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { - if (tls_cbc_decrypt(hmac_ctx, cbc_key, seq_num, in, - in + 5, inlen - 5, - out + 5, outlen) != 1) { + uint8_t nonce[12]; + uint8_t aad[13]; + const uint8_t *explicit_nonce; + const uint8_t *gmac; + size_t mlen; + + if (inlen < 8 + GHASH_SIZE) { error_print(); return -1; } - out[0] = in[0]; - out[1] = in[1]; - out[2] = in[2]; - out[3] = (uint8_t)((*outlen) >> 8); - out[4] = (uint8_t)(*outlen); - (*outlen) += 5; + explicit_nonce = in; + in += 8; + inlen -= 8; + if (inlen < GHASH_SIZE) { + error_print(); + return -1; + } + mlen = inlen - GHASH_SIZE; + gmac = in + mlen; + + memcpy(nonce, fixed_iv, 4); + memcpy(nonce + 4, explicit_nonce, 8); + + memcpy(aad, seq_num, 8); + memcpy(aad + 8, header, 5); + aad[11] = (uint8_t)(mlen >> 8); + aad[12] = (uint8_t)mlen; + + switch (key->cipher->oid) { + case OID_sm4: + if (sm4_gcm_decrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad), + in, mlen, gmac, GHASH_SIZE, out) != 1) { + error_print(); + return -1; + } + break; +#ifdef ENABLE_AES + case OID_aes128: + if (aes_gcm_decrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad), + in, mlen, gmac, GHASH_SIZE, out) != 1) { + error_print(); + return -1; + } + break; +#endif + default: + error_print(); + return -1; + } + + *outlen = mlen; return 1; } @@ -1876,18 +1955,18 @@ static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *i switch (conn->cipher_suite) { case TLS_cipher_ecdhe_sm4_gcm_sm3: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: - if (tls12_record_gcm_encrypt(enc_key, fixed_iv, seq_num, - conn->databuf, tls_record_length(conn->databuf), - conn->record, &recordlen) != 1) { + if (tls_gcm_encrypt(enc_key, fixed_iv, seq_num, conn->databuf, + conn->databuf + 5, tls_record_data_length(conn->databuf), + conn->record + 5, &recordlen) != 1) { error_print(); return -1; } break; case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: - if (tls_record_cbc_encrypt(hmac_ctx, enc_key, seq_num, - conn->databuf, tls_record_length(conn->databuf), - conn->record, &recordlen) != 1) { + if (tls_cbc_encrypt(hmac_ctx, enc_key, seq_num, conn->databuf, + conn->databuf + 5, tls_record_data_length(conn->databuf), + conn->record + 5, &recordlen) != 1) { error_print(); return -1; } @@ -1896,6 +1975,12 @@ static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *i error_print(); return -1; } + conn->record[0] = conn->databuf[0]; + conn->record[1] = conn->databuf[1]; + conn->record[2] = conn->databuf[2]; + conn->record[3] = (uint8_t)(recordlen >> 8); + conn->record[4] = (uint8_t)(recordlen); + recordlen += 5; } else if (conn->protocol == TLS_protocol_tlcp) { if (tlcp_record_encrypt(conn->cipher_suite, hmac_ctx, enc_key, fixed_iv, seq_num, conn->databuf, tls_record_length(conn->databuf), @@ -1904,12 +1989,18 @@ static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *i return -1; } } else { - if (tls_record_cbc_encrypt(hmac_ctx, enc_key, seq_num, - conn->databuf, tls_record_length(conn->databuf), - conn->record, &recordlen) != 1) { + if (tls_cbc_encrypt(hmac_ctx, enc_key, seq_num, conn->databuf, + conn->databuf + 5, tls_record_data_length(conn->databuf), + conn->record + 5, &recordlen) != 1) { error_print(); return -1; } + conn->record[0] = conn->databuf[0]; + conn->record[1] = conn->databuf[1]; + conn->record[2] = conn->databuf[2]; + conn->record[3] = (uint8_t)(recordlen >> 8); + conn->record[4] = (uint8_t)(recordlen); + recordlen += 5; } tls_seq_num_incr(seq_num); @@ -1987,12 +2078,18 @@ int tls_decrypt_recv(TLS_CONNECT *conn) return -1; } } else { - if (tls_record_cbc_decrypt(hmac_ctx, dec_key, seq_num, - record, recordlen, - conn->databuf, &conn->datalen) != 1) { + if (tls_cbc_decrypt(hmac_ctx, dec_key, seq_num, record, + record + 5, recordlen - 5, + conn->databuf + 5, &conn->datalen) != 1) { error_print(); return -1; } + conn->databuf[0] = record[0]; + conn->databuf[1] = record[1]; + conn->databuf[2] = record[2]; + conn->databuf[3] = (uint8_t)(conn->datalen >> 8); + conn->databuf[4] = (uint8_t)(conn->datalen); + conn->datalen += 5; } tls_seq_num_incr(seq_num); diff --git a/src/tls12.c b/src/tls12.c index e5342ff1..b76af9dc 100644 --- a/src/tls12.c +++ b/src/tls12.c @@ -63,133 +63,6 @@ int tls12_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int f return tls_record_print(fp, record, recordlen, format, indent); } -int tls12_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], - const uint8_t seq_num[8], const uint8_t header[5], - const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) -{ - uint8_t nonce[12]; - uint8_t aad[13]; - uint8_t *explicit_nonce; - uint8_t *gmac; - - if (!key || !fixed_iv || !seq_num || !header || (!in && inlen) || !out || !outlen) { - error_print(); - return -1; - } - if (inlen > TLS_MAX_PLAINTEXT_SIZE) { - error_print(); - return -1; - } - if ((((size_t)header[3]) << 8) + header[4] != inlen) { - error_print(); - return -1; - } - - memcpy(nonce, fixed_iv, 4); - memcpy(nonce + 4, seq_num, 8); - - memcpy(aad, seq_num, 8); - memcpy(aad + 8, header, 5); - - explicit_nonce = out; - memcpy(explicit_nonce, seq_num, 8); - out += 8; - - gmac = out + inlen; - if (gcm_encrypt(key, nonce, sizeof(nonce), aad, sizeof(aad), in, inlen, out, GHASH_SIZE, gmac) != 1) { - error_print(); - return -1; - } - - *outlen = 8 + inlen + GHASH_SIZE; - return 1; -} - -int tls12_record_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], - const uint8_t seq_num[8], const uint8_t *in, size_t inlen, - uint8_t *out, size_t *outlen) -{ - if (tls12_gcm_encrypt(key, fixed_iv, seq_num, in, - in + 5, inlen - 5, - out + 5, outlen) != 1) { - error_print(); - return -1; - } - - out[0] = in[0]; - out[1] = in[1]; - out[2] = in[2]; - out[3] = (uint8_t)((*outlen) >> 8); - out[4] = (uint8_t)(*outlen); - (*outlen) += 5; - return 1; -} - -static int tls12_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], - const uint8_t seq_num[8], const uint8_t header[5], - const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) -{ - uint8_t nonce[12]; - uint8_t aad[13]; - const uint8_t *explicit_nonce; - const uint8_t *gmac; - size_t mlen; - - if (inlen < 8 + GHASH_SIZE) { - error_print(); - return -1; - } - - explicit_nonce = in; - in += 8; - inlen -= 8; - - if (inlen < GHASH_SIZE) { - error_print(); - return -1; - } - mlen = inlen - GHASH_SIZE; - gmac = in + mlen; - - memcpy(nonce, fixed_iv, 4); - memcpy(nonce + 4, explicit_nonce, 8); - - memcpy(aad, seq_num, 8); - memcpy(aad + 8, header, 5); - aad[11] = (uint8_t)(mlen >> 8); - aad[12] = (uint8_t)mlen; - - if (gcm_decrypt(key, nonce, sizeof(nonce), aad, sizeof(aad), - in, mlen, gmac, GHASH_SIZE, out) != 1) { - error_print(); - return -1; - } - - *outlen = mlen; - return 1; -} - -int tls12_record_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], - const uint8_t seq_num[8], const uint8_t *in, size_t inlen, - uint8_t *out, size_t *outlen) -{ - if (tls12_gcm_decrypt(key, fixed_iv, seq_num, in, - in + 5, inlen - 5, - out + 5, outlen) != 1) { - error_print(); - return -1; - } - - out[0] = in[0]; - out[1] = in[1]; - out[2] = in[2]; - out[3] = (uint8_t)((*outlen) >> 8); - out[4] = (uint8_t)(*outlen); - (*outlen) += 5; - - return 1; -} - static int tls12_record_encrypt(int cipher_suite, const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], const uint8_t seq_num[8], const uint8_t *in, size_t inlen, @@ -198,14 +71,18 @@ static int tls12_record_encrypt(int cipher_suite, switch (cipher_suite) { case TLS_cipher_ecdhe_sm4_gcm_sm3: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: - if (tls12_record_gcm_encrypt(key, fixed_iv, seq_num, in, inlen, out, outlen) != 1) { + if (tls_gcm_encrypt(key, fixed_iv, seq_num, in, + in + 5, inlen - 5, + out + 5, outlen) != 1) { error_print(); return -1; } break; case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: - if (tls_record_cbc_encrypt(hmac_ctx, key, seq_num, in, inlen, out, outlen) != 1) { + if (tls_cbc_encrypt(hmac_ctx, key, seq_num, in, + in + 5, inlen - 5, + out + 5, outlen) != 1) { error_print(); return -1; } @@ -214,6 +91,13 @@ static int tls12_record_encrypt(int cipher_suite, error_print(); return -1; } + + out[0] = in[0]; + out[1] = in[1]; + out[2] = in[2]; + out[3] = (uint8_t)((*outlen) >> 8); + out[4] = (uint8_t)(*outlen); + (*outlen) += 5; return 1; } @@ -225,14 +109,18 @@ int tls12_record_decrypt(int cipher_suite, const HMAC_CTX *hmac_ctx, switch (cipher_suite) { case TLS_cipher_ecdhe_sm4_gcm_sm3: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: - if (tls12_record_gcm_decrypt(key, fixed_iv, seq_num, in, inlen, out, outlen) != 1) { + if (tls_gcm_decrypt(key, fixed_iv, seq_num, in, + in + 5, inlen - 5, + out + 5, outlen) != 1) { error_print(); return -1; } break; case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: - if (tls_record_cbc_decrypt(hmac_ctx, key, seq_num, in, inlen, out, outlen) != 1) { + if (tls_cbc_decrypt(hmac_ctx, key, seq_num, in, + in + 5, inlen - 5, + out + 5, outlen) != 1) { error_print(); return -1; } @@ -241,6 +129,13 @@ int tls12_record_decrypt(int cipher_suite, const HMAC_CTX *hmac_ctx, error_print(); return -1; } + + out[0] = in[0]; + out[1] = in[1]; + out[2] = in[2]; + out[3] = (uint8_t)((*outlen) >> 8); + out[4] = (uint8_t)(*outlen); + (*outlen) += 5; return 1; } diff --git a/src/tls13.c b/src/tls13.c index 92d415ea..1251b90a 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -104,53 +104,6 @@ int tls13_padding_len_rand(size_t *padding_len) -int gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t *iv, size_t ivlen, - const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, - uint8_t *out, size_t taglen, uint8_t *tag) -{ - if (key->cipher == BLOCK_CIPHER_sm4()) { - if (sm4_gcm_encrypt(&(key->u.sm4_key), iv, ivlen, aad, aadlen, in, inlen, out, taglen, tag) != 1) { - error_print(); - return -1; - } -// 避免在tls13.c中引入宏 -#ifdef ENABLE_AES - } else if (key->cipher == BLOCK_CIPHER_aes128()) { - if (aes_gcm_encrypt(&(key->u.aes_key), iv, ivlen, aad, aadlen, in, inlen, out, taglen, tag) != 1) { - error_print(); - return -1; - } -#endif - } else { - error_print(); - return -1; - } - return 1; -} - -int gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t *iv, size_t ivlen, - const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, - const uint8_t *tag, size_t taglen, uint8_t *out) -{ - if (key->cipher == BLOCK_CIPHER_sm4()) { - if (sm4_gcm_decrypt(&(key->u.sm4_key), iv, ivlen, aad, aadlen, in, inlen, tag, taglen, out) != 1) { - error_print(); - return -1; - } -#ifdef ENABLE_AES - } else if (key->cipher == BLOCK_CIPHER_aes128()) { - if (aes_gcm_decrypt(&(key->u.aes_key), iv, ivlen, aad, aadlen, in, inlen, tag, taglen, out) != 1) { - error_print(); - return -1; - } -#endif - } else { - error_print(); - return -1; - } - return 1; -} - /* struct { opaque content[TLSPlaintext.length]; @@ -173,7 +126,7 @@ int tls13_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], uint8_t nonce[12]; uint8_t aad[5]; uint8_t *gmac; - uint8_t *mbuf = NULL; // FIXME: update gcm_encrypt API + uint8_t *mbuf = NULL; size_t mlen, clen; if (!(mbuf = malloc(inlen + 256))) { @@ -201,7 +154,27 @@ int tls13_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], aad[4] = (uint8_t)(clen); gmac = out + mlen; - if (gcm_encrypt(key, nonce, sizeof(nonce), aad, sizeof(aad), mbuf, mlen, out, 16, gmac) != 1) { + + switch (key->cipher->oid) { + case OID_sm4: + if (sm4_gcm_encrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad), + mbuf, mlen, out, 16, gmac) != 1) { + error_print(); + free(mbuf); + return -1; + } + break; +#ifdef ENABLE_AES + case OID_aes128: + if (aes_gcm_encrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad), + mbuf, mlen, out, 16, gmac) != 1) { + error_print(); + free(mbuf); + return -1; + } + break; +#endif + default: error_print(); free(mbuf); return -1; @@ -240,7 +213,24 @@ int tls13_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], mlen = inlen - GHASH_SIZE; gmac = in + mlen; - if (gcm_decrypt(key, nonce, 12, aad, 5, in, mlen, gmac, GHASH_SIZE, out) != 1) { + switch (key->cipher->oid) { + case OID_sm4: + if (sm4_gcm_decrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad), + in, mlen, gmac, GHASH_SIZE, out) != 1) { + error_print(); + return -1; + } + break; + case OID_aes128: +#ifdef ENABLE_AES + if (aes_gcm_decrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad), + in, mlen, gmac, GHASH_SIZE, out) != 1) { + error_print(); + return -1; + } + break; +#endif + default: error_print(); return -1; } diff --git a/tests/ghashtest.c b/tests/ghashtest.c index e3a2cab9..6c74e92f 100644 --- a/tests/ghashtest.c +++ b/tests/ghashtest.c @@ -115,72 +115,6 @@ int test_ghash(void) return 1; } -#if 0 -int test_gcm(void) -{ - BLOCK_CIPHER_KEY block_key; - uint8_t key[16]; - uint8_t iv[12]; - uint8_t aad[64]; - uint8_t in[100]; - uint8_t out[sizeof(in)]; - uint8_t buf[sizeof(in)]; - uint8_t tag[16]; - - rand_bytes(key, sizeof(key)); - rand_bytes(iv, sizeof(iv)); - rand_bytes(aad, sizeof(aad)); - rand_bytes(in, sizeof(in)); - -#ifdef ENABLE_AES - memset(out, 0, sizeof(out)); - memset(buf, 0, sizeof(buf)); - memset(tag, 0, sizeof(tag)); - - if (block_cipher_set_encrypt_key(&block_key, BLOCK_CIPHER_aes128(), key) != 1) { - error_print(); - return -1; - } - if (gcm_encrypt(&block_key, iv, sizeof(iv), aad, sizeof(aad), in, sizeof(in), out, sizeof(tag), tag) != 1) { - error_print(); - return -1; - } - if (gcm_decrypt(&block_key, iv, sizeof(iv), aad, sizeof(aad), out, sizeof(out), tag, sizeof(tag), buf) != 1) { - error_print(); - return -1; - } - if (memcmp(buf, in, sizeof(in)) != 0) { - error_print(); - return -1; - } -#endif // ENABLE_AES - - memset(out, 0, sizeof(out)); - memset(buf, 0, sizeof(buf)); - memset(tag, 0, sizeof(tag)); - - if (block_cipher_set_encrypt_key(&block_key, BLOCK_CIPHER_sm4(), key) != 1) { - error_print(); - return -1; - } - if (gcm_encrypt(&block_key, iv, sizeof(iv), aad, sizeof(aad), in, sizeof(in), out, sizeof(tag), tag) != 1) { - error_print(); - return -1; - } - if (gcm_decrypt(&block_key, iv, sizeof(iv), aad, sizeof(aad), out, sizeof(out), tag, sizeof(tag), buf) != 1) { - error_print(); - return -1; - } - if (memcmp(buf, in, sizeof(in)) != 0) { - error_print(); - return -1; - } - - printf("%s() ok\n", __FUNCTION__); - return 1; -} -#endif - static int speed_ghash(void) { GHASH_CTX ghash_ctx;