From 545e6a56f05ff4d2a4a2204d47f1e9c200cfa9a1 Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Sun, 14 Jun 2026 00:12:10 +0800 Subject: [PATCH] Add CCM cipher suites --- CMakeLists.txt | 11 +- include/gmssl/aes.h | 15 +++ include/gmssl/tls.h | 10 +- include/gmssl/version.h | 2 +- src/aes_modes.c | 262 ++++++++++++++++++++++++++++++++++++++++ src/tls.c | 145 +++++++++++++++++++++- src/tls12.c | 56 +++++++++ src/tls13.c | 236 +++++++++++++++++++++++++++++++----- src/tls_psk.c | 6 +- src/tls_trace.c | 4 + tests/aestest.c | 54 +++++++++ tests/tls13test.c | 53 ++++++++ tests/tlstest.c | 51 ++++++++ 13 files changed, 869 insertions(+), 36 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 15af1875..9931e2be 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,6 +67,7 @@ option(ENABLE_KYBER "Enable Kyber" OFF) option(ENABLE_SHA1 "Enable SHA1" OFF) option(ENABLE_SHA2 "Enable SHA2" ON) option(ENABLE_AES "Enable AES" ON) +option(ENABLE_AES_CCM "Enable AES CCM mode" OFF) option(ENABLE_CHACHA20 "Enable Chacha20" OFF) option(ENABLE_ZUC "Enable ZUC" ON) option(ENABLE_GHASH "Enable standalone GHASH command and test" OFF) @@ -528,6 +529,14 @@ if (ENABLE_AES) list(APPEND tests aes) endif() +if (ENABLE_AES_CCM) + if (NOT ENABLE_AES) + message(FATAL_ERROR "ENABLE_AES_CCM requires ENABLE_AES") + endif() + message(STATUS "ENABLE_AES_CCM is ON") + add_definitions(-DENABLE_AES_CCM) +endif() + if (ENABLE_CHACHA20) message(STATUS "ENABLE_CHACHA20 is ON") @@ -768,7 +777,7 @@ endif() # set(CPACK_PACKAGE_NAME "GmSSL") set(CPACK_PACKAGE_VENDOR "GmSSL develop team") -set(CPACK_PACKAGE_VERSION "3.2.0-dev.1037") +set(CPACK_PACKAGE_VERSION "3.2.0-dev.1038") set(CPACK_PACKAGE_DESCRIPTION_FILE ${PROJECT_SOURCE_DIR}/README.md) set(CPACK_NSIS_MODIFY_PATH ON) include(CPack) diff --git a/include/gmssl/aes.h b/include/gmssl/aes.h index e43b27ac..ba06ada0 100644 --- a/include/gmssl/aes.h +++ b/include/gmssl/aes.h @@ -83,6 +83,21 @@ int aes_gcm_decrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen, const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, const uint8_t *tag, size_t taglen, uint8_t *out); +#ifdef ENABLE_AES_CCM +#define AES_CCM_MIN_IV_SIZE 7 +#define AES_CCM_MAX_IV_SIZE 13 +#define AES_CCM_MIN_TAG_SIZE 4 +#define AES_CCM_MAX_TAG_SIZE 16 +#define AES_CCM_DEFAULT_TAG_SIZE 16 + +int aes_ccm_encrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen, + const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, + uint8_t *out, size_t taglen, uint8_t *tag); +int aes_ccm_decrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen, + const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, + const uint8_t *tag, size_t taglen, uint8_t *out); +#endif + #ifdef __cplusplus } diff --git a/include/gmssl/tls.h b/include/gmssl/tls.h index a2813eb0..01c6bb6d 100644 --- a/include/gmssl/tls.h +++ b/include/gmssl/tls.h @@ -419,6 +419,12 @@ int tls_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], int tls_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], const uint8_t seq_num[8], const uint8_t header[5], const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); +int tls_ccm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], + const uint8_t seq_num[8], const uint8_t header[5], + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); +int tls_ccm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], + const uint8_t seq_num[8], const uint8_t header[5], + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); int tls12_record_decrypt(int cipher_suite, const HMAC_CTX *hmac_ctx, const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], const uint8_t seq_num[8], const uint8_t *in, size_t inlen, @@ -1725,10 +1731,10 @@ int tls13_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], int tls13_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], const uint8_t seq_num[8], const uint8_t *in, size_t inlen, int *record_type, uint8_t *out, size_t *outlen); -int tls13_record_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], +int tls13_record_encrypt(int cipher_suite, const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], const uint8_t seq_num[8], const uint8_t *record, size_t recordlen, size_t padding_len, uint8_t *enced_record, size_t *enced_recordlen); -int tls13_record_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], +int tls13_record_decrypt(int cipher_suite, const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], const uint8_t seq_num[8], const uint8_t *enced_record, size_t enced_recordlen, uint8_t *record, size_t *recordlen); diff --git a/include/gmssl/version.h b/include/gmssl/version.h index 235067c0..7ed9e85a 100644 --- a/include/gmssl/version.h +++ b/include/gmssl/version.h @@ -19,7 +19,7 @@ extern "C" { // Also update CPACK_PACKAGE_VERSION in CMakeLists.txt #define GMSSL_VERSION_NUM 30200 -#define GMSSL_VERSION_STR "GmSSL 3.2.0-dev.1037" +#define GMSSL_VERSION_STR "GmSSL 3.2.0-dev.1038" int gmssl_version_num(void); const char *gmssl_version_str(void); diff --git a/src/aes_modes.c b/src/aes_modes.c index 08c1110a..be96cbf4 100644 --- a/src/aes_modes.c +++ b/src/aes_modes.c @@ -129,6 +129,268 @@ void aes_ctr_encrypt(const AES_KEY *key, uint8_t ctr[16], const uint8_t *in, siz } } +#ifdef ENABLE_AES_CCM +static void length_to_bytes(size_t len, size_t nbytes, uint8_t *out) +{ + uint8_t *p = out + nbytes - 1; + while (nbytes--) { + *p-- = len & 0xff; + len >>= 8; + } +} + +static void ctr_n_incr(uint8_t a[16], size_t n) +{ + size_t i; + for (i = 15; i >= 16 - n; i--) { + a[i]++; + if (a[i]) break; + } +} + +static void aes_ctr_n_encrypt(const AES_KEY *key, uint8_t ctr[16], size_t n, const uint8_t *in, size_t inlen, uint8_t *out) +{ + uint8_t block[16]; + size_t len; + + while (inlen) { + len = inlen < 16 ? inlen : 16; + aes_encrypt(key, ctr, block); + gmssl_memxor(out, in, block, len); + ctr_n_incr(ctr, n); + in += len; + out += len; + inlen -= len; + } +} + +typedef struct { + AES_KEY key; + uint8_t iv[16]; + size_t ivlen; +} AES_CBC_MAC_CTX; + +static int aes_cbc_mac_update(AES_CBC_MAC_CTX *ctx, const uint8_t *data, size_t datalen) +{ + if (!ctx || (!data && datalen)) { + error_print(); + return -1; + } + if (ctx->ivlen >= 16) { + error_print(); + return -1; + } + if (!data || !datalen) { + return 1; + } + while (datalen) { + size_t ivleft = 16 - ctx->ivlen; + size_t len = datalen < ivleft ? datalen : ivleft; + gmssl_memxor(ctx->iv + ctx->ivlen, ctx->iv + ctx->ivlen, data, len); + ctx->ivlen += len; + if (ctx->ivlen >= 16) { + aes_encrypt(&ctx->key, ctx->iv, ctx->iv); + ctx->ivlen = 0; + } + data += len; + datalen -= len; + } + return 1; +} + +static int aes_cbc_mac_finish(AES_CBC_MAC_CTX *ctx, uint8_t mac[16]) +{ + if (!ctx || !mac) { + error_print(); + return -1; + } + if (ctx->ivlen >= 16) { + error_print(); + return -1; + } + if (ctx->ivlen) { + aes_encrypt(&ctx->key, ctx->iv, ctx->iv); + ctx->ivlen = 0; + } + memcpy(mac, ctx->iv, 16); + return 1; +} + +int aes_ccm_encrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen, + const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, + uint8_t *out, size_t taglen, uint8_t *tag) +{ + AES_CBC_MAC_CTX mac_ctx; + const uint8_t zeros[16] = {0}; + uint8_t block[16] = {0}; + uint8_t ctr[16] = {0}; + uint8_t mac[16]; + size_t inlen_size; + + if (!key || !iv || (!aad && aadlen) || (!in && inlen) || !out || !tag) { + error_print(); + return -1; + } + if (ivlen < 7 || ivlen > 13) { + error_print(); + return -1; + } + if (taglen < 4 || taglen > 16 || taglen & 1) { + error_print(); + return -1; + } + + inlen_size = 15 - ivlen; + if (inlen_size < 8 && inlen >= ((size_t)1 << (inlen_size * 8))) { + error_print(); + return -1; + } + + memset(&mac_ctx, 0, sizeof(mac_ctx)); + mac_ctx.key = *key; + + block[0] |= ((aadlen > 0) & 0x1) << 6; + block[0] |= (((taglen - 2)/2) & 0x7) << 3; + block[0] |= (inlen_size - 1) & 0x7; + memcpy(block + 1, iv, ivlen); + length_to_bytes(inlen, inlen_size, block + 1 + ivlen); + aes_cbc_mac_update(&mac_ctx, block, 16); + + if (aad && aadlen) { + size_t alen; + + if (aadlen < ((1<<16) - (1<<8))) { + length_to_bytes(aadlen, 2, block); + alen = 2; + } else if ((uint64_t)aadlen < ((uint64_t)1<<32)) { + block[0] = 0xff; + block[1] = 0xfe; + length_to_bytes(aadlen, 4, block + 2); + alen = 6; + } else { + block[0] = 0xff; + block[1] = 0xff; + length_to_bytes(aadlen, 8, block + 2); + alen = 10; + } + aes_cbc_mac_update(&mac_ctx, block, alen); + aes_cbc_mac_update(&mac_ctx, aad, aadlen); + if ((alen + aadlen) % 16) { + aes_cbc_mac_update(&mac_ctx, zeros, 16 - (alen + aadlen)%16); + } + } + + ctr[0] = 0; + ctr[0] |= (inlen_size - 1) & 0x7; + memcpy(ctr + 1, iv, ivlen); + memset(ctr + 1 + ivlen, 0, 15 - ivlen); + aes_encrypt(key, ctr, block); + + ctr[15] = 1; + aes_ctr_n_encrypt(key, ctr, 15 - ivlen, in, inlen, out); + + aes_cbc_mac_update(&mac_ctx, in, inlen); + if (inlen % 16) { + aes_cbc_mac_update(&mac_ctx, zeros, 16 - inlen % 16); + } + aes_cbc_mac_finish(&mac_ctx, mac); + gmssl_memxor(tag, mac, block, taglen); + + gmssl_secure_clear(&mac_ctx, sizeof(mac_ctx)); + return 1; +} + +int aes_ccm_decrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen, + const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, + const uint8_t *tag, size_t taglen, uint8_t *out) +{ + AES_CBC_MAC_CTX mac_ctx; + const uint8_t zeros[16] = {0}; + uint8_t block[16] = {0}; + uint8_t ctr[16] = {0}; + uint8_t mac[16]; + size_t inlen_size; + + if (!key || !iv || (!aad && aadlen) || (!in && inlen) || !tag || !out) { + error_print(); + return -1; + } + if (ivlen < 7 || ivlen > 13) { + error_print(); + return -1; + } + if (taglen < 4 || taglen > 16 || taglen & 1) { + error_print(); + return -1; + } + + inlen_size = 15 - ivlen; + if (inlen_size < 8 && inlen >= ((size_t)1 << (inlen_size * 8))) { + error_print(); + return -1; + } + + memset(&mac_ctx, 0, sizeof(mac_ctx)); + mac_ctx.key = *key; + + block[0] |= ((aadlen > 0) & 0x1) << 6; + block[0] |= (((taglen - 2)/2) & 0x7) << 3; + block[0] |= (inlen_size - 1) & 0x7; + memcpy(block + 1, iv, ivlen); + length_to_bytes(inlen, inlen_size, block + 1 + ivlen); + aes_cbc_mac_update(&mac_ctx, block, 16); + + if (aad && aadlen) { + size_t alen; + + if (aadlen < ((1<<16) - (1<<8))) { + length_to_bytes(aadlen, 2, block); + alen = 2; + } else if ((uint64_t)aadlen < ((uint64_t)1<<32)) { + block[0] = 0xff; + block[1] = 0xfe; + length_to_bytes(aadlen, 4, block + 2); + alen = 6; + } else { + block[0] = 0xff; + block[1] = 0xff; + length_to_bytes(aadlen, 8, block + 2); + alen = 10; + } + aes_cbc_mac_update(&mac_ctx, block, alen); + aes_cbc_mac_update(&mac_ctx, aad, aadlen); + if ((alen + aadlen) % 16) { + aes_cbc_mac_update(&mac_ctx, zeros, 16 - (alen + aadlen)%16); + } + } + + ctr[0] = 0; + ctr[0] |= (inlen_size - 1) & 0x7; + memcpy(ctr + 1, iv, ivlen); + memset(ctr + 1 + ivlen, 0, 15 - ivlen); + aes_encrypt(key, ctr, block); + + ctr[15] = 1; + aes_ctr_n_encrypt(key, ctr, 15 - ivlen, in, inlen, out); + + aes_cbc_mac_update(&mac_ctx, out, inlen); + if (inlen % 16) { + aes_cbc_mac_update(&mac_ctx, zeros, 16 - inlen % 16); + } + aes_cbc_mac_finish(&mac_ctx, mac); + + gmssl_memxor(mac, mac, block, taglen); + if (gmssl_secure_memcmp(mac, tag, taglen) != 0) { + error_print(); + gmssl_secure_clear(&mac_ctx, sizeof(mac_ctx)); + return -1; + } + + gmssl_secure_clear(&mac_ctx, sizeof(mac_ctx)); + return 1; +} +#endif + static void ctr32_incr(uint8_t a[16]) { diff --git a/src/tls.c b/src/tls.c index e399eb07..1b494f9b 100644 --- a/src/tls.c +++ b/src/tls.c @@ -578,6 +578,130 @@ int tls_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], return 1; } +int tls_ccm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], + const uint8_t seq_num[8], const uint8_t header[5], + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +{ + uint8_t nonce[12]; + uint8_t aad[13]; + uint8_t *explicit_nonce; + uint8_t *tag; + + if (!key || !fixed_iv || !seq_num || !header || (!in && inlen) || !out || !outlen) { + error_print(); + return -1; + } + if (inlen > TLS_MAX_PLAINTEXT_SIZE) { + error_print(); + return -1; + } + if ((((size_t)header[3]) << 8) + header[4] != inlen) { + error_print(); + return -1; + } + + memcpy(nonce, fixed_iv, 4); + memcpy(nonce + 4, seq_num, 8); + + memcpy(aad, seq_num, 8); + memcpy(aad + 8, header, 5); + + explicit_nonce = out; + memcpy(explicit_nonce, seq_num, 8); + out += 8; + + tag = out + inlen; + + switch (key->cipher->oid) { +#ifdef ENABLE_SM4_CCM + case OID_sm4: + if (sm4_ccm_encrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad), + in, inlen, out, GHASH_SIZE, tag) != 1) { + error_print(); + return -1; + } + break; +#endif +#ifdef ENABLE_AES_CCM + case OID_aes128: + if (aes_ccm_encrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad), + in, inlen, out, GHASH_SIZE, tag) != 1) { + error_print(); + return -1; + } + break; +#endif + default: + error_print(); + return -1; + } + + *outlen = 8 + inlen + GHASH_SIZE; + return 1; +} + +int tls_ccm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], + const uint8_t seq_num[8], const uint8_t header[5], + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +{ + uint8_t nonce[12]; + uint8_t aad[13]; + const uint8_t *explicit_nonce; + const uint8_t *tag; + size_t mlen; + + if (inlen < 8 + GHASH_SIZE) { + error_print(); + return -1; + } + + explicit_nonce = in; + in += 8; + inlen -= 8; + + if (inlen < GHASH_SIZE) { + error_print(); + return -1; + } + mlen = inlen - GHASH_SIZE; + tag = in + mlen; + + memcpy(nonce, fixed_iv, 4); + memcpy(nonce + 4, explicit_nonce, 8); + + memcpy(aad, seq_num, 8); + memcpy(aad + 8, header, 5); + aad[11] = (uint8_t)(mlen >> 8); + aad[12] = (uint8_t)mlen; + + switch (key->cipher->oid) { +#ifdef ENABLE_SM4_CCM + case OID_sm4: + if (sm4_ccm_decrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad), + in, mlen, tag, GHASH_SIZE, out) != 1) { + error_print(); + return -1; + } + break; +#endif +#ifdef ENABLE_AES_CCM + case OID_aes128: + if (aes_ccm_decrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad), + in, mlen, tag, GHASH_SIZE, out) != 1) { + error_print(); + return -1; + } + break; +#endif + default: + error_print(); + return -1; + } + + *outlen = mlen; + return 1; +} + int tls_random_generate(uint8_t random[32]) { uint32_t gmt_unix_time = (uint32_t)time(NULL); @@ -1639,11 +1763,20 @@ static const int tls12_ciphers[] = { TLS_cipher_ecdhe_sm4_gcm_sm3, TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256, TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256, +#ifdef ENABLE_AES_CCM + TLS_cipher_aes_128_ccm_sha256, +#endif }; static const int tls13_ciphers[] = { TLS_cipher_sm4_gcm_sm3, +#ifdef ENABLE_SM4_CCM + TLS_cipher_sm4_ccm_sm3, +#endif TLS_cipher_aes_128_gcm_sha256, +#ifdef ENABLE_AES_CCM + TLS_cipher_aes_128_ccm_sha256, +#endif }; int tls_cipher_suite_match_protocol(int cipher, int protocol) @@ -1962,6 +2095,16 @@ static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *i return -1; } break; +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: + if (tls_ccm_encrypt(enc_key, fixed_iv, seq_num, conn->databuf, + conn->databuf + 5, tls_record_data_length(conn->databuf), + conn->record + 5, &recordlen) != 1) { + error_print(); + return -1; + } + break; +#endif case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: if (tls_cbc_encrypt(hmac_ctx, enc_key, seq_num, conn->databuf, @@ -2233,7 +2376,7 @@ static int tls13_send_close_notify(TLS_CONNECT *conn) tls_record_set_alert(conn->plain_record, &conn->plain_recordlen, TLS_alert_level_warning, TLS_alert_close_notify); tls13_padding_len_rand(&padding_len); - if (tls13_record_encrypt(key, iv, seq_num, conn->plain_record, conn->plain_recordlen, + if (tls13_record_encrypt(conn->cipher_suite, key, iv, seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->record, &conn->recordlen) != 1) { error_print(); return -1; diff --git a/src/tls12.c b/src/tls12.c index b76af9dc..d0a2b6b7 100644 --- a/src/tls12.c +++ b/src/tls12.c @@ -49,6 +49,9 @@ const int tls12_cipher_suites[] = { #if defined(ENABLE_AES) && defined(ENABLE_SHA2) && defined(ENABLE_SECP256R1) TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256, TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256, +#ifdef ENABLE_AES_CCM + TLS_cipher_aes_128_ccm_sha256, +#endif #endif }; const size_t tls12_cipher_suites_cnt = @@ -78,6 +81,16 @@ static int tls12_record_encrypt(int cipher_suite, return -1; } break; +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: + if (tls_ccm_encrypt(key, fixed_iv, seq_num, in, + in + 5, inlen - 5, + out + 5, outlen) != 1) { + error_print(); + return -1; + } + break; +#endif case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: if (tls_cbc_encrypt(hmac_ctx, key, seq_num, in, @@ -116,6 +129,16 @@ int tls12_record_decrypt(int cipher_suite, const HMAC_CTX *hmac_ctx, return -1; } break; +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: + if (tls_ccm_decrypt(key, fixed_iv, seq_num, in, + in + 5, inlen - 5, + out + 5, outlen) != 1) { + error_print(); + return -1; + } + break; +#endif case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: if (tls_cbc_decrypt(hmac_ctx, key, seq_num, in, @@ -418,6 +441,9 @@ static int tls12_server_key_exchange_params_from_bytes(int cipher_suite, case TLS_cipher_ecdhe_sm4_gcm_sm3: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: +#endif { uint8_t curve_type; uint16_t named_curve; @@ -857,6 +883,9 @@ static int tls12_cipher_suite_get(int cipher_suite, const BLOCK_CIPHER **cipher, #if defined(ENABLE_AES) && defined(ENABLE_SHA2) && defined(ENABLE_SECP256R1) case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: +#endif *cipher = BLOCK_CIPHER_aes128(); *digest = DIGEST_sha256(); break; @@ -877,6 +906,9 @@ static int tls12_cipher_suite_match_cert_group(int cipher_suite, int cert_group) #if defined(ENABLE_AES) && defined(ENABLE_SHA2) && defined(ENABLE_SECP256R1) case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: +#endif return cert_group == TLS_curve_secp256r1; #endif default: @@ -904,6 +936,9 @@ static int tls12_signature_scheme_match_cipher_suite(int sig_alg, int cipher_sui switch (cipher_suite) { case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: +#endif return 1; } #endif @@ -921,6 +956,9 @@ static int tls12_key_exchange_group_match_cipher_suite(int group, int cipher_sui #if defined(ENABLE_AES) && defined(ENABLE_SHA2) && defined(ENABLE_SECP256R1) case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: +#endif return group == TLS_curve_secp256r1; #endif default: @@ -1948,6 +1986,9 @@ int tls_recv_server_certificate(TLS_CONNECT *conn) break; case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: +#endif server_sig_alg = TLS_sig_ecdsa_secp256r1_sha256; break; default: @@ -2153,6 +2194,9 @@ int tls_curve_match_cipher_suite(int named_curve, int cipher_suite) switch (cipher_suite) { case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: +#endif break; default: error_print(); @@ -2185,6 +2229,9 @@ int tls_signature_scheme_match_cipher_suite(int sig_alg, int cipher_suite) switch (cipher_suite) { case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: +#endif break; default: error_print(); @@ -2261,6 +2308,9 @@ int tls_recv_server_key_exchange(TLS_CONNECT *conn) case TLS_cipher_ecdhe_sm4_gcm_sm3: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: +#endif { const uint8_t *p = server_ecdh_params; size_t len = server_ecdh_params_len; @@ -2716,6 +2766,9 @@ static int tls12_generate_key_block(TLS_CONNECT *conn) switch (conn->cipher_suite) { case TLS_cipher_ecdhe_sm4_gcm_sm3: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: +#endif { size_t keylen = conn->cipher->key_size; size_t key_block_len = keylen * 2 + 8; @@ -2771,6 +2824,9 @@ static int tls12_generate_record_keys(TLS_CONNECT *conn) switch (conn->cipher_suite) { case TLS_cipher_ecdhe_sm4_gcm_sm3: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: +#endif { size_t keylen = conn->cipher->key_size; diff --git a/src/tls13.c b/src/tls13.c index 1251b90a..56155e81 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -48,8 +48,14 @@ const size_t tls13_signature_algorithms_cnt = const int tls13_cipher_suites[] = { TLS_cipher_sm4_gcm_sm3, +#ifdef ENABLE_SM4_CCM + TLS_cipher_sm4_ccm_sm3, +#endif #if defined(ENABLE_AES) && defined(ENABLE_SHA2) TLS_cipher_aes_128_gcm_sha256, +#ifdef ENABLE_AES_CCM + TLS_cipher_aes_128_ccm_sha256, +#endif #endif }; const size_t tls13_cipher_suites_cnt = @@ -72,11 +78,17 @@ int tls13_cipher_suite_get(int cipher_suite, const BLOCK_CIPHER **cipher, const { switch (cipher_suite) { case TLS_cipher_sm4_gcm_sm3: +#ifdef ENABLE_SM4_CCM + case TLS_cipher_sm4_ccm_sm3: +#endif *digest = DIGEST_sm3(); *cipher = BLOCK_CIPHER_sm4(); break; #if defined(ENABLE_AES) && defined(ENABLE_SHA2) case TLS_cipher_aes_128_gcm_sha256: +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: +#endif *digest = DIGEST_sha256(); *cipher = BLOCK_CIPHER_aes128(); break; @@ -250,14 +262,163 @@ int tls13_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], return 1; } +static int tls13_ccm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], + const uint8_t seq_num[8], int record_type, + const uint8_t *in, size_t inlen, size_t padding_len, + uint8_t *out, size_t *outlen) +{ + uint8_t nonce[12]; + uint8_t aad[5]; + uint8_t *tag; + uint8_t *mbuf = NULL; + size_t mlen, clen; -int tls13_record_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], + if (!(mbuf = malloc(inlen + 256))) { + error_print(); + return -1; + } + + nonce[0] = nonce[1] = nonce[2] = nonce[3] = 0; + memcpy(nonce + 4, seq_num, 8); + gmssl_memxor(nonce, nonce, iv, 12); + + memcpy(mbuf, in, inlen); + mbuf[inlen] = record_type; + memset(mbuf + inlen + 1, 0, padding_len); + mlen = inlen + 1 + padding_len; + clen = mlen + GHASH_SIZE; + + aad[0] = TLS_record_application_data; + aad[1] = 0x03; + aad[2] = 0x03; + aad[3] = (uint8_t)(clen >> 8); + aad[4] = (uint8_t)(clen); + + tag = out + mlen; + + switch (key->cipher->oid) { +#ifdef ENABLE_SM4_CCM + case OID_sm4: + if (sm4_ccm_encrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad), + mbuf, mlen, out, GHASH_SIZE, tag) != 1) { + error_print(); + free(mbuf); + return -1; + } + break; +#endif +#ifdef ENABLE_AES_CCM + case OID_aes128: + if (aes_ccm_encrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad), + mbuf, mlen, out, GHASH_SIZE, tag) != 1) { + error_print(); + free(mbuf); + return -1; + } + break; +#endif + default: + error_print(); + free(mbuf); + return -1; + } + *outlen = clen; + free(mbuf); + return 1; +} + +static int tls13_ccm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], + const uint8_t seq_num[8], const uint8_t *in, size_t inlen, + int *record_type, uint8_t *out, size_t *outlen) +{ + uint8_t nonce[12]; + uint8_t aad[5]; + size_t mlen; + const uint8_t *tag; + + nonce[0] = nonce[1] = nonce[2] = nonce[3] = 0; + memcpy(nonce + 4, seq_num, 8); + gmssl_memxor(nonce, nonce, iv, 12); + + aad[0] = TLS_record_application_data; + aad[1] = 0x03; + aad[2] = 0x03; + aad[3] = (uint8_t)(inlen >> 8); + aad[4] = (uint8_t)(inlen); + + if (inlen < GHASH_SIZE) { + error_print(); + return -1; + } + mlen = inlen - GHASH_SIZE; + tag = in + mlen; + + switch (key->cipher->oid) { +#ifdef ENABLE_SM4_CCM + case OID_sm4: + if (sm4_ccm_decrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad), + in, mlen, tag, GHASH_SIZE, out) != 1) { + error_print(); + return -1; + } + break; +#endif +#ifdef ENABLE_AES_CCM + case OID_aes128: + if (aes_ccm_decrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad), + in, mlen, tag, GHASH_SIZE, out) != 1) { + error_print(); + return -1; + } + break; +#endif + default: + error_print(); + return -1; + } + + *record_type = 0; + while (mlen--) { + if (out[mlen] != 0) { + *record_type = out[mlen]; + break; + } + } + *outlen = mlen; + if (!tls_record_type_name(*record_type)) { + error_print(); + return -1; + } + return 1; +} + + +int tls13_record_encrypt(int cipher_suite, const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], const uint8_t seq_num[8], const uint8_t *record, size_t recordlen, size_t padding_len, uint8_t *enced_record, size_t *enced_recordlen) { - if (tls13_gcm_encrypt(key, iv, - seq_num, record[0], record + 5, recordlen - 5, padding_len, - enced_record + 5, enced_recordlen) != 1) { + switch (cipher_suite) { + case TLS_cipher_sm4_gcm_sm3: + case TLS_cipher_aes_128_gcm_sha256: + if (tls13_gcm_encrypt(key, iv, + seq_num, record[0], record + 5, recordlen - 5, padding_len, + enced_record + 5, enced_recordlen) != 1) { + error_print(); + return -1; + } + break; + case TLS_cipher_sm4_ccm_sm3: +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: +#endif + if (tls13_ccm_encrypt(key, iv, + seq_num, record[0], record + 5, recordlen - 5, padding_len, + enced_record + 5, enced_recordlen) != 1) { + error_print(); + return -1; + } + break; + default: error_print(); return -1; } @@ -273,15 +434,34 @@ int tls13_record_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], return 1; } -int tls13_record_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], +int tls13_record_decrypt(int cipher_suite, const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], const uint8_t seq_num[8], const uint8_t *enced_record, size_t enced_recordlen, uint8_t *record, size_t *recordlen) { int record_type; - if (tls13_gcm_decrypt(key, iv, - seq_num, enced_record + 5, enced_recordlen - 5, - &record_type, record + 5, recordlen) != 1) { + switch (cipher_suite) { + case TLS_cipher_sm4_gcm_sm3: + case TLS_cipher_aes_128_gcm_sha256: + if (tls13_gcm_decrypt(key, iv, + seq_num, enced_record + 5, enced_recordlen - 5, + &record_type, record + 5, recordlen) != 1) { + error_print(); + return -1; + } + break; + case TLS_cipher_sm4_ccm_sm3: +#ifdef ENABLE_AES_CCM + case TLS_cipher_aes_128_ccm_sha256: +#endif + if (tls13_ccm_decrypt(key, iv, + seq_num, enced_record + 5, enced_recordlen - 5, + &record_type, record + 5, recordlen) != 1) { + error_print(); + return -1; + } + break; + default: error_print(); return -1; } @@ -1251,7 +1431,7 @@ int tls13_do_recv(TLS_CONNECT *conn) //format_print(stderr, 0, 0, "\n"); } - if (tls13_record_decrypt(key, iv, seq_num, conn->record, conn->recordlen, + if (tls13_record_decrypt(conn->cipher_suite, key, iv, seq_num, conn->record, conn->recordlen, conn->plain_record, &conn->plain_recordlen) != 1) { error_print(); return -1; @@ -5261,7 +5441,7 @@ int tls13_recv_encrypted_extensions(TLS_CONNECT *conn) } tls13_record_print(stderr, 0, 0, conn->record, conn->recordlen); - if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv, + if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv, conn->server_seq_num, conn->record, conn->recordlen, conn->plain_record, &conn->plain_recordlen) != 1) { error_print(); @@ -5744,7 +5924,7 @@ int tls13_recv_certificate_request(TLS_CONNECT *conn) return ret; } - if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv, + if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv, conn->server_seq_num, conn->record, conn->recordlen, conn->plain_record, &conn->plain_recordlen) != 1) { error_print(); @@ -6029,7 +6209,7 @@ int tls13_recv_server_certificate(TLS_CONNECT *conn) // decrypt unless previous handshake is CertificateRequest if (!conn->plain_recordlen) { - if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv, + if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv, conn->server_seq_num, conn->record, conn->recordlen, conn->plain_record, &conn->plain_recordlen) != 1) { error_print(); @@ -6163,7 +6343,7 @@ int tls13_recv_server_certificate_verify(TLS_CONNECT *conn) return ret; } - if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv, + if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv, conn->server_seq_num, conn->record, conn->recordlen, conn->plain_record, &conn->plain_recordlen) != 1) { error_print(); @@ -6239,7 +6419,7 @@ int tls13_recv_client_certificate_verify(TLS_CONNECT *conn) return ret; } - if (tls13_record_decrypt(&conn->client_write_key, conn->client_write_iv, + if (tls13_record_decrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv, conn->client_seq_num, conn->record, conn->recordlen, conn->plain_record, &conn->plain_recordlen) != 1) { error_print(); @@ -6325,7 +6505,7 @@ int tls13_recv_server_finished(TLS_CONNECT *conn) return ret; } - if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv, + if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv, conn->server_seq_num, conn->record, conn->recordlen, conn->plain_record, &conn->plain_recordlen) != 1) { error_print(); @@ -6415,7 +6595,7 @@ int tls13_send_client_certificate(TLS_CONNECT *conn) tls13_padding_len_rand(&padding_len); - if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv, + if (tls13_record_encrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv, conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->record, &conn->recordlen) != 1) { error_print(); @@ -6468,7 +6648,7 @@ int tls13_send_client_certificate_verify(TLS_CONNECT *conn) if(conn->verbose) tls_handshake_digest_print(stderr, 0, 0, "after client CertificateVerify", &conn->dgst_ctx); tls13_padding_len_rand(&padding_len); - if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv, + if (tls13_record_encrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv, conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->record, &conn->recordlen) != 1) { error_print(); @@ -6516,7 +6696,7 @@ int tls13_send_client_finished(TLS_CONNECT *conn) //format_print(stderr, 0, 0, "client_seq_num: "PRIu64"\n", GETU64(conn->client_seq_num)); tls13_padding_len_rand(&padding_len); - if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv, + if (tls13_record_encrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv, conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->record, &conn->recordlen) != 1) { error_print(); @@ -7927,7 +8107,7 @@ int tls13_send_alert(TLS_CONNECT *conn, int alert) break; default: tls13_padding_len_rand(&padding_len); - if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, + if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->record, &conn->recordlen) != 1) { error_print(); @@ -8027,7 +8207,7 @@ int tls13_send_encrypted_extensions(TLS_CONNECT *conn) //format_print(stderr, 0, 0, "server_seq_num: "PRIu64"\n", GETU64(conn->server_seq_num)); - if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, + if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->record, &conn->recordlen) != 1) { error_print(); @@ -8151,7 +8331,7 @@ int tls13_send_certificate_request(TLS_CONNECT *conn) //format_print(stderr, 0, 0, "server_seq_num: "PRIu64"\n", GETU64(conn->server_seq_num)); tls13_padding_len_rand(&padding_len); - if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, + if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->record, &conn->recordlen) != 1) { error_print(); @@ -8218,7 +8398,7 @@ int tls13_send_server_certificate(TLS_CONNECT *conn) if(conn->verbose) tls_handshake_digest_print(stderr, 0, 0, "ServerCertificate", &conn->dgst_ctx); tls13_padding_len_rand(&padding_len); - if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, + if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->record, &conn->recordlen) != 1) { error_print(); @@ -8271,7 +8451,7 @@ int tls13_send_server_certificate_verify(TLS_CONNECT *conn) } tls13_padding_len_rand(&padding_len); - if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, + if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->record, &conn->recordlen) != 1) { error_print(); @@ -8320,7 +8500,7 @@ int tls13_send_server_finished(TLS_CONNECT *conn) //format_print(stderr, 0, 0, "server_seq_num: "PRIu64"\n", GETU64(conn->server_seq_num)); tls13_padding_len_rand(&padding_len); - if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, + if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->record, &conn->recordlen) != 1) { error_print(); @@ -8384,7 +8564,7 @@ int tls13_recv_client_certificate(TLS_CONNECT *conn) //format_print(stderr, 0, 0, "client_seq_num: "PRIu64"\n", GETU64(conn->client_seq_num)); - if (tls13_record_decrypt(&conn->client_write_key, conn->client_write_iv, + if (tls13_record_decrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv, conn->client_seq_num, conn->record, conn->recordlen, conn->plain_record, &conn->plain_recordlen) != 1) { error_print(); @@ -8521,7 +8701,7 @@ int tls13_recv_client_finished(TLS_CONNECT *conn) //format_print(stderr, 0, 0, "client_seq_num: "PRIu64"\n", GETU64(conn->client_seq_num)); - if (tls13_record_decrypt(&conn->client_write_key, conn->client_write_iv, + if (tls13_record_decrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv, conn->client_seq_num, conn->record, conn->recordlen, conn->plain_record, &conn->plain_recordlen) != 1) { error_print(); @@ -8620,7 +8800,7 @@ int tls13_send_client_key_update(TLS_CONNECT *conn, int request_update) tls13_padding_len_rand(&padding_len); - if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv, + if (tls13_record_encrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv, conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->record, &conn->recordlen) != 1) { error_print(); @@ -8675,7 +8855,7 @@ int tls13_send_server_key_update(TLS_CONNECT *conn, int request_update) tls13_record_print(stderr, 0, 0, conn->plain_record, conn->plain_recordlen); tls13_padding_len_rand(&padding_len); - if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, + if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->record, &conn->recordlen) != 1) { error_print(); diff --git a/src/tls_psk.c b/src/tls_psk.c index 3b8fb990..27a7e905 100644 --- a/src/tls_psk.c +++ b/src/tls_psk.c @@ -683,7 +683,7 @@ int tls13_send_new_session_ticket(TLS_CONNECT *conn) format_print(stderr, 0, 0, "\n"); tls13_padding_len_rand(&padding_len); - if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, + if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->record, &conn->recordlen) != 1) { error_print(); @@ -1766,7 +1766,7 @@ int tls13_send_end_of_early_data(TLS_CONNECT *conn) size_t padding_len; tls13_padding_len_rand(&padding_len); - if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv, + if (tls13_record_encrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv, conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->record, &conn->recordlen) != 1) { error_print(); @@ -1817,7 +1817,7 @@ int tls13_recv_end_of_early_data(TLS_CONNECT *conn) format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12); - if (tls13_record_decrypt(&conn->client_write_key, conn->client_write_iv, + if (tls13_record_decrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv, conn->client_seq_num, conn->record, conn->recordlen, conn->plain_record, &conn->plain_recordlen) != 1) { error_print(); diff --git a/src/tls_trace.c b/src/tls_trace.c index fea5a26c..00f856ca 100644 --- a/src/tls_trace.c +++ b/src/tls_trace.c @@ -91,8 +91,12 @@ int tls_cipher_suite_from_name(const char *name) { if (!strcmp(name, "TLS_SM4_GCM_SM3")) { return TLS_cipher_sm4_gcm_sm3; + } else if (!strcmp(name, "TLS_SM4_CCM_SM3")) { + return TLS_cipher_sm4_ccm_sm3; } else if (!strcmp(name, "TLS_AES_128_GCM_SHA256")) { return TLS_cipher_aes_128_gcm_sha256; + } else if (!strcmp(name, "TLS_AES_128_CCM_SHA256")) { + return TLS_cipher_aes_128_ccm_sha256; } else if (!strcmp(name, "TLS_ECDHE_SM4_CBC_SM3")) { return TLS_cipher_ecdhe_sm4_cbc_sm3; } else if (!strcmp(name, "TLS_ECDHE_SM4_GCM_SM3")) { diff --git a/tests/aestest.c b/tests/aestest.c index 6194410a..d9734099 100644 --- a/tests/aestest.c +++ b/tests/aestest.c @@ -387,6 +387,57 @@ int test_aes_gcm(void) return 1; } +#ifdef ENABLE_AES_CCM +int test_aes_ccm(void) +{ + AES_KEY aes_key; + uint8_t key[16] = { + 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, + 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, + }; + uint8_t iv[12] = { + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, + 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + }; + uint8_t aad[20] = { + 0x00, 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, 0x09, + 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, + 0x0f, 0x10, 0x11, 0x12, 0x13, + }; + uint8_t in[23] = { + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, + 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, + 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, + 0x32, 0x33, 0x34, 0x35, 0x36, + }; + uint8_t out[sizeof(in)]; + uint8_t buf[sizeof(in)]; + uint8_t tag[16]; + + aes_set_encrypt_key(&aes_key, key, sizeof(key)); + if (aes_ccm_encrypt(&aes_key, iv, sizeof(iv), aad, sizeof(aad), in, sizeof(in), + out, sizeof(tag), tag) != 1) { + error_print(); + return -1; + } + if (aes_ccm_decrypt(&aes_key, iv, sizeof(iv), aad, sizeof(aad), out, sizeof(out), + tag, sizeof(tag), buf) != 1 || memcmp(buf, in, sizeof(in)) != 0) { + error_print(); + return -1; + } + tag[0] ^= 0x01; + if (aes_ccm_decrypt(&aes_key, iv, sizeof(iv), aad, sizeof(aad), out, sizeof(out), + tag, sizeof(tag), buf) != -1) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} +#endif + int test_aes_cbc_pkcs5_wycheproof(void) { size_t i; @@ -465,6 +516,9 @@ int main(void) if (test_aes() != 1) goto err; if (test_aes_ctr() != 1) goto err; if (test_aes_gcm() != 1) goto err; +#ifdef ENABLE_AES_CCM + if (test_aes_ccm() != 1) goto err; +#endif if (test_aes_cbc_pkcs5_wycheproof() != 1) goto err; printf("%s all tests passed!\n", __FILE__); return 0; diff --git a/tests/tls13test.c b/tests/tls13test.c index 199f7673..e9c9aaf7 100644 --- a/tests/tls13test.c +++ b/tests/tls13test.c @@ -97,6 +97,56 @@ static int test_tls13_gcm(void) return 1; } +#ifdef ENABLE_AES_CCM +static int test_tls13_ccm(void) +{ + BLOCK_CIPHER_KEY block_key; + uint8_t key[16]; + uint8_t iv[12]; + uint8_t seq_num[8] = {0,0,0,0,0,0,0,1}; + uint8_t record[5 + 40]; + size_t recordlen; + size_t padding_len = 8; + uint8_t enced_record[256]; + size_t enced_recordlen; + uint8_t buf[256]; + size_t buflen; + + rand_bytes(key, sizeof(key)); + rand_bytes(iv, sizeof(iv)); + rand_bytes(record + 5, 40); + + record[0] = TLS_record_handshake; + record[1] = TLS_protocol_tls12 >> 8; + record[2] = TLS_protocol_tls12 & 0xff; + record[3] = 0; + record[4] = 40; + recordlen = 5 + 40; + + if (block_cipher_set_encrypt_key(&block_key, BLOCK_CIPHER_aes128(), key) != 1) { + error_print(); + return -1; + } + if (tls13_record_encrypt(TLS_cipher_aes_128_ccm_sha256, &block_key, iv, + seq_num, record, recordlen, padding_len, enced_record, &enced_recordlen) != 1) { + error_print(); + return -1; + } + if (tls13_record_decrypt(TLS_cipher_aes_128_ccm_sha256, &block_key, iv, + seq_num, enced_record, enced_recordlen, buf, &buflen) != 1) { + error_print(); + return -1; + } + if (buflen != recordlen || memcmp(buf, record, recordlen) != 0) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} +#endif + static int test_tls13_supported_versions_ext(void) { const int client_versions[] = { TLS_protocol_tls13, TLS_protocol_tls12, TLS_protocol_tlcp }; @@ -661,6 +711,9 @@ int main(void) { if (test_tls_ext() != 1) goto err; if (test_tls13_gcm() != 1) goto err; +#ifdef ENABLE_AES_CCM + if (test_tls13_ccm() != 1) goto err; +#endif if (test_tls13_supported_versions_ext() != 1) goto err; if (test_tls13_key_share_ext() != 1) goto err; if (test_tls_supported_groups_ext() != 1) goto err; diff --git a/tests/tlstest.c b/tests/tlstest.c index a371744b..2832f1c5 100644 --- a/tests/tlstest.c +++ b/tests/tlstest.c @@ -111,6 +111,54 @@ static int test_tls_cbc(void) return 1; } +#ifdef ENABLE_AES_CCM +static int test_tls_ccm(void) +{ + uint8_t key[16] = {0}; + BLOCK_CIPHER_KEY aes_key; + uint8_t fixed_iv[4] = {0x10, 0x11, 0x12, 0x13}; + uint8_t seq_num[8] = {0,0,0,0,0,0,0,1}; + uint8_t record[5 + 32]; + uint8_t enced_record[256]; + uint8_t buf[256]; + size_t recordlen; + size_t enced_recordlen; + size_t buflen; + + record[0] = TLS_record_handshake; + record[1] = TLS_protocol_tls12 >> 8; + record[2] = TLS_protocol_tls12 & 0xff; + record[3] = 0; + record[4] = 12; + memcpy(record + 5, "hello world", 12); + recordlen = 5 + 12; + + block_cipher_set_encrypt_key(&aes_key, BLOCK_CIPHER_aes128(), key); + if (tls_ccm_encrypt(&aes_key, fixed_iv, seq_num, record, + record + 5, 12, enced_record + 5, &enced_recordlen) != 1) { + error_print(); + return -1; + } + enced_record[0] = record[0]; + enced_record[1] = record[1]; + enced_record[2] = record[2]; + enced_record[3] = (uint8_t)(enced_recordlen >> 8); + enced_record[4] = (uint8_t)enced_recordlen; + enced_recordlen += 5; + + if (tls12_record_decrypt(TLS_cipher_aes_128_ccm_sha256, NULL, &aes_key, fixed_iv, seq_num, + enced_record, enced_recordlen, buf, &buflen) != 1 + || buflen != recordlen + || memcmp(buf, record, recordlen) != 0) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} +#endif + static int test_tls_random(void) { uint8_t random[32]; @@ -439,6 +487,9 @@ static int test_tls_trusted_ca_keys_ext(void) int main(void) { if (test_tls_null_to_bytes() != 1) goto err; +#ifdef ENABLE_AES_CCM + if (test_tls_ccm() != 1) goto err; +#endif /* if (test_tls_encode() != 1) goto err; if (test_tls_cbc() != 1) goto err;