Add CCM cipher suites

This commit is contained in:
Zhi Guan
2026-06-14 00:12:10 +08:00
parent 5d12858d41
commit 545e6a56f0
13 changed files with 869 additions and 36 deletions

View File

@@ -67,6 +67,7 @@ option(ENABLE_KYBER "Enable Kyber" OFF)
option(ENABLE_SHA1 "Enable SHA1" OFF) option(ENABLE_SHA1 "Enable SHA1" OFF)
option(ENABLE_SHA2 "Enable SHA2" ON) option(ENABLE_SHA2 "Enable SHA2" ON)
option(ENABLE_AES "Enable AES" ON) option(ENABLE_AES "Enable AES" ON)
option(ENABLE_AES_CCM "Enable AES CCM mode" OFF)
option(ENABLE_CHACHA20 "Enable Chacha20" OFF) option(ENABLE_CHACHA20 "Enable Chacha20" OFF)
option(ENABLE_ZUC "Enable ZUC" ON) option(ENABLE_ZUC "Enable ZUC" ON)
option(ENABLE_GHASH "Enable standalone GHASH command and test" OFF) option(ENABLE_GHASH "Enable standalone GHASH command and test" OFF)
@@ -528,6 +529,14 @@ if (ENABLE_AES)
list(APPEND tests aes) list(APPEND tests aes)
endif() endif()
if (ENABLE_AES_CCM)
if (NOT ENABLE_AES)
message(FATAL_ERROR "ENABLE_AES_CCM requires ENABLE_AES")
endif()
message(STATUS "ENABLE_AES_CCM is ON")
add_definitions(-DENABLE_AES_CCM)
endif()
if (ENABLE_CHACHA20) if (ENABLE_CHACHA20)
message(STATUS "ENABLE_CHACHA20 is ON") message(STATUS "ENABLE_CHACHA20 is ON")
@@ -768,7 +777,7 @@ endif()
# #
set(CPACK_PACKAGE_NAME "GmSSL") set(CPACK_PACKAGE_NAME "GmSSL")
set(CPACK_PACKAGE_VENDOR "GmSSL develop team") set(CPACK_PACKAGE_VENDOR "GmSSL develop team")
set(CPACK_PACKAGE_VERSION "3.2.0-dev.1037") set(CPACK_PACKAGE_VERSION "3.2.0-dev.1038")
set(CPACK_PACKAGE_DESCRIPTION_FILE ${PROJECT_SOURCE_DIR}/README.md) set(CPACK_PACKAGE_DESCRIPTION_FILE ${PROJECT_SOURCE_DIR}/README.md)
set(CPACK_NSIS_MODIFY_PATH ON) set(CPACK_NSIS_MODIFY_PATH ON)
include(CPack) include(CPack)

View File

@@ -83,6 +83,21 @@ int aes_gcm_decrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen,
const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen,
const uint8_t *tag, size_t taglen, uint8_t *out); const uint8_t *tag, size_t taglen, uint8_t *out);
#ifdef ENABLE_AES_CCM
#define AES_CCM_MIN_IV_SIZE 7
#define AES_CCM_MAX_IV_SIZE 13
#define AES_CCM_MIN_TAG_SIZE 4
#define AES_CCM_MAX_TAG_SIZE 16
#define AES_CCM_DEFAULT_TAG_SIZE 16
int aes_ccm_encrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen,
const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen,
uint8_t *out, size_t taglen, uint8_t *tag);
int aes_ccm_decrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen,
const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen,
const uint8_t *tag, size_t taglen, uint8_t *out);
#endif
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@@ -419,6 +419,12 @@ int tls_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4],
int tls_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], int tls_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4],
const uint8_t seq_num[8], const uint8_t header[5], const uint8_t seq_num[8], const uint8_t header[5],
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
int tls_ccm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4],
const uint8_t seq_num[8], const uint8_t header[5],
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
int tls_ccm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4],
const uint8_t seq_num[8], const uint8_t header[5],
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
int tls12_record_decrypt(int cipher_suite, const HMAC_CTX *hmac_ctx, int tls12_record_decrypt(int cipher_suite, const HMAC_CTX *hmac_ctx,
const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4], const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4],
const uint8_t seq_num[8], const uint8_t *in, size_t inlen, const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
@@ -1725,10 +1731,10 @@ int tls13_gcm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
int tls13_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], int tls13_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
const uint8_t seq_num[8], const uint8_t *in, size_t inlen, const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
int *record_type, uint8_t *out, size_t *outlen); int *record_type, uint8_t *out, size_t *outlen);
int tls13_record_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], int tls13_record_encrypt(int cipher_suite, const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
const uint8_t seq_num[8], const uint8_t *record, size_t recordlen, size_t padding_len, const uint8_t seq_num[8], const uint8_t *record, size_t recordlen, size_t padding_len,
uint8_t *enced_record, size_t *enced_recordlen); uint8_t *enced_record, size_t *enced_recordlen);
int tls13_record_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], int tls13_record_decrypt(int cipher_suite, const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
const uint8_t seq_num[8], const uint8_t *enced_record, size_t enced_recordlen, const uint8_t seq_num[8], const uint8_t *enced_record, size_t enced_recordlen,
uint8_t *record, size_t *recordlen); uint8_t *record, size_t *recordlen);

View File

@@ -19,7 +19,7 @@ extern "C" {
// Also update CPACK_PACKAGE_VERSION in CMakeLists.txt // Also update CPACK_PACKAGE_VERSION in CMakeLists.txt
#define GMSSL_VERSION_NUM 30200 #define GMSSL_VERSION_NUM 30200
#define GMSSL_VERSION_STR "GmSSL 3.2.0-dev.1037" #define GMSSL_VERSION_STR "GmSSL 3.2.0-dev.1038"
int gmssl_version_num(void); int gmssl_version_num(void);
const char *gmssl_version_str(void); const char *gmssl_version_str(void);

View File

@@ -129,6 +129,268 @@ void aes_ctr_encrypt(const AES_KEY *key, uint8_t ctr[16], const uint8_t *in, siz
} }
} }
#ifdef ENABLE_AES_CCM
static void length_to_bytes(size_t len, size_t nbytes, uint8_t *out)
{
uint8_t *p = out + nbytes - 1;
while (nbytes--) {
*p-- = len & 0xff;
len >>= 8;
}
}
static void ctr_n_incr(uint8_t a[16], size_t n)
{
size_t i;
for (i = 15; i >= 16 - n; i--) {
a[i]++;
if (a[i]) break;
}
}
static void aes_ctr_n_encrypt(const AES_KEY *key, uint8_t ctr[16], size_t n, const uint8_t *in, size_t inlen, uint8_t *out)
{
uint8_t block[16];
size_t len;
while (inlen) {
len = inlen < 16 ? inlen : 16;
aes_encrypt(key, ctr, block);
gmssl_memxor(out, in, block, len);
ctr_n_incr(ctr, n);
in += len;
out += len;
inlen -= len;
}
}
typedef struct {
AES_KEY key;
uint8_t iv[16];
size_t ivlen;
} AES_CBC_MAC_CTX;
static int aes_cbc_mac_update(AES_CBC_MAC_CTX *ctx, const uint8_t *data, size_t datalen)
{
if (!ctx || (!data && datalen)) {
error_print();
return -1;
}
if (ctx->ivlen >= 16) {
error_print();
return -1;
}
if (!data || !datalen) {
return 1;
}
while (datalen) {
size_t ivleft = 16 - ctx->ivlen;
size_t len = datalen < ivleft ? datalen : ivleft;
gmssl_memxor(ctx->iv + ctx->ivlen, ctx->iv + ctx->ivlen, data, len);
ctx->ivlen += len;
if (ctx->ivlen >= 16) {
aes_encrypt(&ctx->key, ctx->iv, ctx->iv);
ctx->ivlen = 0;
}
data += len;
datalen -= len;
}
return 1;
}
static int aes_cbc_mac_finish(AES_CBC_MAC_CTX *ctx, uint8_t mac[16])
{
if (!ctx || !mac) {
error_print();
return -1;
}
if (ctx->ivlen >= 16) {
error_print();
return -1;
}
if (ctx->ivlen) {
aes_encrypt(&ctx->key, ctx->iv, ctx->iv);
ctx->ivlen = 0;
}
memcpy(mac, ctx->iv, 16);
return 1;
}
int aes_ccm_encrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen,
const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen,
uint8_t *out, size_t taglen, uint8_t *tag)
{
AES_CBC_MAC_CTX mac_ctx;
const uint8_t zeros[16] = {0};
uint8_t block[16] = {0};
uint8_t ctr[16] = {0};
uint8_t mac[16];
size_t inlen_size;
if (!key || !iv || (!aad && aadlen) || (!in && inlen) || !out || !tag) {
error_print();
return -1;
}
if (ivlen < 7 || ivlen > 13) {
error_print();
return -1;
}
if (taglen < 4 || taglen > 16 || taglen & 1) {
error_print();
return -1;
}
inlen_size = 15 - ivlen;
if (inlen_size < 8 && inlen >= ((size_t)1 << (inlen_size * 8))) {
error_print();
return -1;
}
memset(&mac_ctx, 0, sizeof(mac_ctx));
mac_ctx.key = *key;
block[0] |= ((aadlen > 0) & 0x1) << 6;
block[0] |= (((taglen - 2)/2) & 0x7) << 3;
block[0] |= (inlen_size - 1) & 0x7;
memcpy(block + 1, iv, ivlen);
length_to_bytes(inlen, inlen_size, block + 1 + ivlen);
aes_cbc_mac_update(&mac_ctx, block, 16);
if (aad && aadlen) {
size_t alen;
if (aadlen < ((1<<16) - (1<<8))) {
length_to_bytes(aadlen, 2, block);
alen = 2;
} else if ((uint64_t)aadlen < ((uint64_t)1<<32)) {
block[0] = 0xff;
block[1] = 0xfe;
length_to_bytes(aadlen, 4, block + 2);
alen = 6;
} else {
block[0] = 0xff;
block[1] = 0xff;
length_to_bytes(aadlen, 8, block + 2);
alen = 10;
}
aes_cbc_mac_update(&mac_ctx, block, alen);
aes_cbc_mac_update(&mac_ctx, aad, aadlen);
if ((alen + aadlen) % 16) {
aes_cbc_mac_update(&mac_ctx, zeros, 16 - (alen + aadlen)%16);
}
}
ctr[0] = 0;
ctr[0] |= (inlen_size - 1) & 0x7;
memcpy(ctr + 1, iv, ivlen);
memset(ctr + 1 + ivlen, 0, 15 - ivlen);
aes_encrypt(key, ctr, block);
ctr[15] = 1;
aes_ctr_n_encrypt(key, ctr, 15 - ivlen, in, inlen, out);
aes_cbc_mac_update(&mac_ctx, in, inlen);
if (inlen % 16) {
aes_cbc_mac_update(&mac_ctx, zeros, 16 - inlen % 16);
}
aes_cbc_mac_finish(&mac_ctx, mac);
gmssl_memxor(tag, mac, block, taglen);
gmssl_secure_clear(&mac_ctx, sizeof(mac_ctx));
return 1;
}
int aes_ccm_decrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen,
const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen,
const uint8_t *tag, size_t taglen, uint8_t *out)
{
AES_CBC_MAC_CTX mac_ctx;
const uint8_t zeros[16] = {0};
uint8_t block[16] = {0};
uint8_t ctr[16] = {0};
uint8_t mac[16];
size_t inlen_size;
if (!key || !iv || (!aad && aadlen) || (!in && inlen) || !tag || !out) {
error_print();
return -1;
}
if (ivlen < 7 || ivlen > 13) {
error_print();
return -1;
}
if (taglen < 4 || taglen > 16 || taglen & 1) {
error_print();
return -1;
}
inlen_size = 15 - ivlen;
if (inlen_size < 8 && inlen >= ((size_t)1 << (inlen_size * 8))) {
error_print();
return -1;
}
memset(&mac_ctx, 0, sizeof(mac_ctx));
mac_ctx.key = *key;
block[0] |= ((aadlen > 0) & 0x1) << 6;
block[0] |= (((taglen - 2)/2) & 0x7) << 3;
block[0] |= (inlen_size - 1) & 0x7;
memcpy(block + 1, iv, ivlen);
length_to_bytes(inlen, inlen_size, block + 1 + ivlen);
aes_cbc_mac_update(&mac_ctx, block, 16);
if (aad && aadlen) {
size_t alen;
if (aadlen < ((1<<16) - (1<<8))) {
length_to_bytes(aadlen, 2, block);
alen = 2;
} else if ((uint64_t)aadlen < ((uint64_t)1<<32)) {
block[0] = 0xff;
block[1] = 0xfe;
length_to_bytes(aadlen, 4, block + 2);
alen = 6;
} else {
block[0] = 0xff;
block[1] = 0xff;
length_to_bytes(aadlen, 8, block + 2);
alen = 10;
}
aes_cbc_mac_update(&mac_ctx, block, alen);
aes_cbc_mac_update(&mac_ctx, aad, aadlen);
if ((alen + aadlen) % 16) {
aes_cbc_mac_update(&mac_ctx, zeros, 16 - (alen + aadlen)%16);
}
}
ctr[0] = 0;
ctr[0] |= (inlen_size - 1) & 0x7;
memcpy(ctr + 1, iv, ivlen);
memset(ctr + 1 + ivlen, 0, 15 - ivlen);
aes_encrypt(key, ctr, block);
ctr[15] = 1;
aes_ctr_n_encrypt(key, ctr, 15 - ivlen, in, inlen, out);
aes_cbc_mac_update(&mac_ctx, out, inlen);
if (inlen % 16) {
aes_cbc_mac_update(&mac_ctx, zeros, 16 - inlen % 16);
}
aes_cbc_mac_finish(&mac_ctx, mac);
gmssl_memxor(mac, mac, block, taglen);
if (gmssl_secure_memcmp(mac, tag, taglen) != 0) {
error_print();
gmssl_secure_clear(&mac_ctx, sizeof(mac_ctx));
return -1;
}
gmssl_secure_clear(&mac_ctx, sizeof(mac_ctx));
return 1;
}
#endif
static void ctr32_incr(uint8_t a[16]) static void ctr32_incr(uint8_t a[16])
{ {

145
src/tls.c
View File

@@ -578,6 +578,130 @@ int tls_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4],
return 1; return 1;
} }
int tls_ccm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4],
const uint8_t seq_num[8], const uint8_t header[5],
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
uint8_t nonce[12];
uint8_t aad[13];
uint8_t *explicit_nonce;
uint8_t *tag;
if (!key || !fixed_iv || !seq_num || !header || (!in && inlen) || !out || !outlen) {
error_print();
return -1;
}
if (inlen > TLS_MAX_PLAINTEXT_SIZE) {
error_print();
return -1;
}
if ((((size_t)header[3]) << 8) + header[4] != inlen) {
error_print();
return -1;
}
memcpy(nonce, fixed_iv, 4);
memcpy(nonce + 4, seq_num, 8);
memcpy(aad, seq_num, 8);
memcpy(aad + 8, header, 5);
explicit_nonce = out;
memcpy(explicit_nonce, seq_num, 8);
out += 8;
tag = out + inlen;
switch (key->cipher->oid) {
#ifdef ENABLE_SM4_CCM
case OID_sm4:
if (sm4_ccm_encrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad),
in, inlen, out, GHASH_SIZE, tag) != 1) {
error_print();
return -1;
}
break;
#endif
#ifdef ENABLE_AES_CCM
case OID_aes128:
if (aes_ccm_encrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad),
in, inlen, out, GHASH_SIZE, tag) != 1) {
error_print();
return -1;
}
break;
#endif
default:
error_print();
return -1;
}
*outlen = 8 + inlen + GHASH_SIZE;
return 1;
}
int tls_ccm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t fixed_iv[4],
const uint8_t seq_num[8], const uint8_t header[5],
const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
uint8_t nonce[12];
uint8_t aad[13];
const uint8_t *explicit_nonce;
const uint8_t *tag;
size_t mlen;
if (inlen < 8 + GHASH_SIZE) {
error_print();
return -1;
}
explicit_nonce = in;
in += 8;
inlen -= 8;
if (inlen < GHASH_SIZE) {
error_print();
return -1;
}
mlen = inlen - GHASH_SIZE;
tag = in + mlen;
memcpy(nonce, fixed_iv, 4);
memcpy(nonce + 4, explicit_nonce, 8);
memcpy(aad, seq_num, 8);
memcpy(aad + 8, header, 5);
aad[11] = (uint8_t)(mlen >> 8);
aad[12] = (uint8_t)mlen;
switch (key->cipher->oid) {
#ifdef ENABLE_SM4_CCM
case OID_sm4:
if (sm4_ccm_decrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad),
in, mlen, tag, GHASH_SIZE, out) != 1) {
error_print();
return -1;
}
break;
#endif
#ifdef ENABLE_AES_CCM
case OID_aes128:
if (aes_ccm_decrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad),
in, mlen, tag, GHASH_SIZE, out) != 1) {
error_print();
return -1;
}
break;
#endif
default:
error_print();
return -1;
}
*outlen = mlen;
return 1;
}
int tls_random_generate(uint8_t random[32]) int tls_random_generate(uint8_t random[32])
{ {
uint32_t gmt_unix_time = (uint32_t)time(NULL); uint32_t gmt_unix_time = (uint32_t)time(NULL);
@@ -1639,11 +1763,20 @@ static const int tls12_ciphers[] = {
TLS_cipher_ecdhe_sm4_gcm_sm3, TLS_cipher_ecdhe_sm4_gcm_sm3,
TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256, TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256,
TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256, TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256,
#ifdef ENABLE_AES_CCM
TLS_cipher_aes_128_ccm_sha256,
#endif
}; };
static const int tls13_ciphers[] = { static const int tls13_ciphers[] = {
TLS_cipher_sm4_gcm_sm3, TLS_cipher_sm4_gcm_sm3,
#ifdef ENABLE_SM4_CCM
TLS_cipher_sm4_ccm_sm3,
#endif
TLS_cipher_aes_128_gcm_sha256, TLS_cipher_aes_128_gcm_sha256,
#ifdef ENABLE_AES_CCM
TLS_cipher_aes_128_ccm_sha256,
#endif
}; };
int tls_cipher_suite_match_protocol(int cipher, int protocol) int tls_cipher_suite_match_protocol(int cipher, int protocol)
@@ -1962,6 +2095,16 @@ static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *i
return -1; return -1;
} }
break; break;
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
if (tls_ccm_encrypt(enc_key, fixed_iv, seq_num, conn->databuf,
conn->databuf + 5, tls_record_data_length(conn->databuf),
conn->record + 5, &recordlen) != 1) {
error_print();
return -1;
}
break;
#endif
case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
if (tls_cbc_encrypt(hmac_ctx, enc_key, seq_num, conn->databuf, if (tls_cbc_encrypt(hmac_ctx, enc_key, seq_num, conn->databuf,
@@ -2233,7 +2376,7 @@ static int tls13_send_close_notify(TLS_CONNECT *conn)
tls_record_set_alert(conn->plain_record, &conn->plain_recordlen, tls_record_set_alert(conn->plain_record, &conn->plain_recordlen,
TLS_alert_level_warning, TLS_alert_close_notify); TLS_alert_level_warning, TLS_alert_close_notify);
tls13_padding_len_rand(&padding_len); tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(key, iv, seq_num, conn->plain_record, conn->plain_recordlen, if (tls13_record_encrypt(conn->cipher_suite, key, iv, seq_num, conn->plain_record, conn->plain_recordlen,
padding_len, conn->record, &conn->recordlen) != 1) { padding_len, conn->record, &conn->recordlen) != 1) {
error_print(); error_print();
return -1; return -1;

View File

@@ -49,6 +49,9 @@ const int tls12_cipher_suites[] = {
#if defined(ENABLE_AES) && defined(ENABLE_SHA2) && defined(ENABLE_SECP256R1) #if defined(ENABLE_AES) && defined(ENABLE_SHA2) && defined(ENABLE_SECP256R1)
TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256, TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256,
TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256, TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256,
#ifdef ENABLE_AES_CCM
TLS_cipher_aes_128_ccm_sha256,
#endif
#endif #endif
}; };
const size_t tls12_cipher_suites_cnt = const size_t tls12_cipher_suites_cnt =
@@ -78,6 +81,16 @@ static int tls12_record_encrypt(int cipher_suite,
return -1; return -1;
} }
break; break;
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
if (tls_ccm_encrypt(key, fixed_iv, seq_num, in,
in + 5, inlen - 5,
out + 5, outlen) != 1) {
error_print();
return -1;
}
break;
#endif
case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
if (tls_cbc_encrypt(hmac_ctx, key, seq_num, in, if (tls_cbc_encrypt(hmac_ctx, key, seq_num, in,
@@ -116,6 +129,16 @@ int tls12_record_decrypt(int cipher_suite, const HMAC_CTX *hmac_ctx,
return -1; return -1;
} }
break; break;
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
if (tls_ccm_decrypt(key, fixed_iv, seq_num, in,
in + 5, inlen - 5,
out + 5, outlen) != 1) {
error_print();
return -1;
}
break;
#endif
case TLS_cipher_ecdhe_sm4_cbc_sm3: case TLS_cipher_ecdhe_sm4_cbc_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
if (tls_cbc_decrypt(hmac_ctx, key, seq_num, in, if (tls_cbc_decrypt(hmac_ctx, key, seq_num, in,
@@ -418,6 +441,9 @@ static int tls12_server_key_exchange_params_from_bytes(int cipher_suite,
case TLS_cipher_ecdhe_sm4_gcm_sm3: case TLS_cipher_ecdhe_sm4_gcm_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
{ {
uint8_t curve_type; uint8_t curve_type;
uint16_t named_curve; uint16_t named_curve;
@@ -857,6 +883,9 @@ static int tls12_cipher_suite_get(int cipher_suite, const BLOCK_CIPHER **cipher,
#if defined(ENABLE_AES) && defined(ENABLE_SHA2) && defined(ENABLE_SECP256R1) #if defined(ENABLE_AES) && defined(ENABLE_SHA2) && defined(ENABLE_SECP256R1)
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
*cipher = BLOCK_CIPHER_aes128(); *cipher = BLOCK_CIPHER_aes128();
*digest = DIGEST_sha256(); *digest = DIGEST_sha256();
break; break;
@@ -877,6 +906,9 @@ static int tls12_cipher_suite_match_cert_group(int cipher_suite, int cert_group)
#if defined(ENABLE_AES) && defined(ENABLE_SHA2) && defined(ENABLE_SECP256R1) #if defined(ENABLE_AES) && defined(ENABLE_SHA2) && defined(ENABLE_SECP256R1)
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
return cert_group == TLS_curve_secp256r1; return cert_group == TLS_curve_secp256r1;
#endif #endif
default: default:
@@ -904,6 +936,9 @@ static int tls12_signature_scheme_match_cipher_suite(int sig_alg, int cipher_sui
switch (cipher_suite) { switch (cipher_suite) {
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
return 1; return 1;
} }
#endif #endif
@@ -921,6 +956,9 @@ static int tls12_key_exchange_group_match_cipher_suite(int group, int cipher_sui
#if defined(ENABLE_AES) && defined(ENABLE_SHA2) && defined(ENABLE_SECP256R1) #if defined(ENABLE_AES) && defined(ENABLE_SHA2) && defined(ENABLE_SECP256R1)
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
return group == TLS_curve_secp256r1; return group == TLS_curve_secp256r1;
#endif #endif
default: default:
@@ -1948,6 +1986,9 @@ int tls_recv_server_certificate(TLS_CONNECT *conn)
break; break;
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
server_sig_alg = TLS_sig_ecdsa_secp256r1_sha256; server_sig_alg = TLS_sig_ecdsa_secp256r1_sha256;
break; break;
default: default:
@@ -2153,6 +2194,9 @@ int tls_curve_match_cipher_suite(int named_curve, int cipher_suite)
switch (cipher_suite) { switch (cipher_suite) {
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
break; break;
default: default:
error_print(); error_print();
@@ -2185,6 +2229,9 @@ int tls_signature_scheme_match_cipher_suite(int sig_alg, int cipher_suite)
switch (cipher_suite) { switch (cipher_suite) {
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
break; break;
default: default:
error_print(); error_print();
@@ -2261,6 +2308,9 @@ int tls_recv_server_key_exchange(TLS_CONNECT *conn)
case TLS_cipher_ecdhe_sm4_gcm_sm3: case TLS_cipher_ecdhe_sm4_gcm_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_cbc_sha256:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
{ {
const uint8_t *p = server_ecdh_params; const uint8_t *p = server_ecdh_params;
size_t len = server_ecdh_params_len; size_t len = server_ecdh_params_len;
@@ -2716,6 +2766,9 @@ static int tls12_generate_key_block(TLS_CONNECT *conn)
switch (conn->cipher_suite) { switch (conn->cipher_suite) {
case TLS_cipher_ecdhe_sm4_gcm_sm3: case TLS_cipher_ecdhe_sm4_gcm_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
{ {
size_t keylen = conn->cipher->key_size; size_t keylen = conn->cipher->key_size;
size_t key_block_len = keylen * 2 + 8; size_t key_block_len = keylen * 2 + 8;
@@ -2771,6 +2824,9 @@ static int tls12_generate_record_keys(TLS_CONNECT *conn)
switch (conn->cipher_suite) { switch (conn->cipher_suite) {
case TLS_cipher_ecdhe_sm4_gcm_sm3: case TLS_cipher_ecdhe_sm4_gcm_sm3:
case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256: case TLS_cipher_ecdhe_ecdsa_with_aes_128_gcm_sha256:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
{ {
size_t keylen = conn->cipher->key_size; size_t keylen = conn->cipher->key_size;

View File

@@ -48,8 +48,14 @@ const size_t tls13_signature_algorithms_cnt =
const int tls13_cipher_suites[] = { const int tls13_cipher_suites[] = {
TLS_cipher_sm4_gcm_sm3, TLS_cipher_sm4_gcm_sm3,
#ifdef ENABLE_SM4_CCM
TLS_cipher_sm4_ccm_sm3,
#endif
#if defined(ENABLE_AES) && defined(ENABLE_SHA2) #if defined(ENABLE_AES) && defined(ENABLE_SHA2)
TLS_cipher_aes_128_gcm_sha256, TLS_cipher_aes_128_gcm_sha256,
#ifdef ENABLE_AES_CCM
TLS_cipher_aes_128_ccm_sha256,
#endif
#endif #endif
}; };
const size_t tls13_cipher_suites_cnt = const size_t tls13_cipher_suites_cnt =
@@ -72,11 +78,17 @@ int tls13_cipher_suite_get(int cipher_suite, const BLOCK_CIPHER **cipher, const
{ {
switch (cipher_suite) { switch (cipher_suite) {
case TLS_cipher_sm4_gcm_sm3: case TLS_cipher_sm4_gcm_sm3:
#ifdef ENABLE_SM4_CCM
case TLS_cipher_sm4_ccm_sm3:
#endif
*digest = DIGEST_sm3(); *digest = DIGEST_sm3();
*cipher = BLOCK_CIPHER_sm4(); *cipher = BLOCK_CIPHER_sm4();
break; break;
#if defined(ENABLE_AES) && defined(ENABLE_SHA2) #if defined(ENABLE_AES) && defined(ENABLE_SHA2)
case TLS_cipher_aes_128_gcm_sha256: case TLS_cipher_aes_128_gcm_sha256:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
*digest = DIGEST_sha256(); *digest = DIGEST_sha256();
*cipher = BLOCK_CIPHER_aes128(); *cipher = BLOCK_CIPHER_aes128();
break; break;
@@ -250,17 +262,166 @@ int tls13_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
return 1; return 1;
} }
static int tls13_ccm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
const uint8_t seq_num[8], int record_type,
const uint8_t *in, size_t inlen, size_t padding_len,
uint8_t *out, size_t *outlen)
{
uint8_t nonce[12];
uint8_t aad[5];
uint8_t *tag;
uint8_t *mbuf = NULL;
size_t mlen, clen;
int tls13_record_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], if (!(mbuf = malloc(inlen + 256))) {
error_print();
return -1;
}
nonce[0] = nonce[1] = nonce[2] = nonce[3] = 0;
memcpy(nonce + 4, seq_num, 8);
gmssl_memxor(nonce, nonce, iv, 12);
memcpy(mbuf, in, inlen);
mbuf[inlen] = record_type;
memset(mbuf + inlen + 1, 0, padding_len);
mlen = inlen + 1 + padding_len;
clen = mlen + GHASH_SIZE;
aad[0] = TLS_record_application_data;
aad[1] = 0x03;
aad[2] = 0x03;
aad[3] = (uint8_t)(clen >> 8);
aad[4] = (uint8_t)(clen);
tag = out + mlen;
switch (key->cipher->oid) {
#ifdef ENABLE_SM4_CCM
case OID_sm4:
if (sm4_ccm_encrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad),
mbuf, mlen, out, GHASH_SIZE, tag) != 1) {
error_print();
free(mbuf);
return -1;
}
break;
#endif
#ifdef ENABLE_AES_CCM
case OID_aes128:
if (aes_ccm_encrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad),
mbuf, mlen, out, GHASH_SIZE, tag) != 1) {
error_print();
free(mbuf);
return -1;
}
break;
#endif
default:
error_print();
free(mbuf);
return -1;
}
*outlen = clen;
free(mbuf);
return 1;
}
static int tls13_ccm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
int *record_type, uint8_t *out, size_t *outlen)
{
uint8_t nonce[12];
uint8_t aad[5];
size_t mlen;
const uint8_t *tag;
nonce[0] = nonce[1] = nonce[2] = nonce[3] = 0;
memcpy(nonce + 4, seq_num, 8);
gmssl_memxor(nonce, nonce, iv, 12);
aad[0] = TLS_record_application_data;
aad[1] = 0x03;
aad[2] = 0x03;
aad[3] = (uint8_t)(inlen >> 8);
aad[4] = (uint8_t)(inlen);
if (inlen < GHASH_SIZE) {
error_print();
return -1;
}
mlen = inlen - GHASH_SIZE;
tag = in + mlen;
switch (key->cipher->oid) {
#ifdef ENABLE_SM4_CCM
case OID_sm4:
if (sm4_ccm_decrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad),
in, mlen, tag, GHASH_SIZE, out) != 1) {
error_print();
return -1;
}
break;
#endif
#ifdef ENABLE_AES_CCM
case OID_aes128:
if (aes_ccm_decrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad),
in, mlen, tag, GHASH_SIZE, out) != 1) {
error_print();
return -1;
}
break;
#endif
default:
error_print();
return -1;
}
*record_type = 0;
while (mlen--) {
if (out[mlen] != 0) {
*record_type = out[mlen];
break;
}
}
*outlen = mlen;
if (!tls_record_type_name(*record_type)) {
error_print();
return -1;
}
return 1;
}
int tls13_record_encrypt(int cipher_suite, const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
const uint8_t seq_num[8], const uint8_t *record, size_t recordlen, size_t padding_len, const uint8_t seq_num[8], const uint8_t *record, size_t recordlen, size_t padding_len,
uint8_t *enced_record, size_t *enced_recordlen) uint8_t *enced_record, size_t *enced_recordlen)
{ {
switch (cipher_suite) {
case TLS_cipher_sm4_gcm_sm3:
case TLS_cipher_aes_128_gcm_sha256:
if (tls13_gcm_encrypt(key, iv, if (tls13_gcm_encrypt(key, iv,
seq_num, record[0], record + 5, recordlen - 5, padding_len, seq_num, record[0], record + 5, recordlen - 5, padding_len,
enced_record + 5, enced_recordlen) != 1) { enced_record + 5, enced_recordlen) != 1) {
error_print(); error_print();
return -1; return -1;
} }
break;
case TLS_cipher_sm4_ccm_sm3:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
if (tls13_ccm_encrypt(key, iv,
seq_num, record[0], record + 5, recordlen - 5, padding_len,
enced_record + 5, enced_recordlen) != 1) {
error_print();
return -1;
}
break;
default:
error_print();
return -1;
}
// in tls1.3, type of encrypted records must be application_data // in tls1.3, type of encrypted records must be application_data
enced_record[0] = TLS_record_application_data; enced_record[0] = TLS_record_application_data;
@@ -273,18 +434,37 @@ int tls13_record_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
return 1; return 1;
} }
int tls13_record_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], int tls13_record_decrypt(int cipher_suite, const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
const uint8_t seq_num[8], const uint8_t *enced_record, size_t enced_recordlen, const uint8_t seq_num[8], const uint8_t *enced_record, size_t enced_recordlen,
uint8_t *record, size_t *recordlen) uint8_t *record, size_t *recordlen)
{ {
int record_type; int record_type;
switch (cipher_suite) {
case TLS_cipher_sm4_gcm_sm3:
case TLS_cipher_aes_128_gcm_sha256:
if (tls13_gcm_decrypt(key, iv, if (tls13_gcm_decrypt(key, iv,
seq_num, enced_record + 5, enced_recordlen - 5, seq_num, enced_record + 5, enced_recordlen - 5,
&record_type, record + 5, recordlen) != 1) { &record_type, record + 5, recordlen) != 1) {
error_print(); error_print();
return -1; return -1;
} }
break;
case TLS_cipher_sm4_ccm_sm3:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
if (tls13_ccm_decrypt(key, iv,
seq_num, enced_record + 5, enced_recordlen - 5,
&record_type, record + 5, recordlen) != 1) {
error_print();
return -1;
}
break;
default:
error_print();
return -1;
}
record[0] = record_type; record[0] = record_type;
record[1] = 0x03; //TLS_protocol_tls12_major; record[1] = 0x03; //TLS_protocol_tls12_major;
record[2] = 0x03; //TLS_protocol_tls12_minor; record[2] = 0x03; //TLS_protocol_tls12_minor;
@@ -1251,7 +1431,7 @@ int tls13_do_recv(TLS_CONNECT *conn)
//format_print(stderr, 0, 0, "\n"); //format_print(stderr, 0, 0, "\n");
} }
if (tls13_record_decrypt(key, iv, seq_num, conn->record, conn->recordlen, if (tls13_record_decrypt(conn->cipher_suite, key, iv, seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) { conn->plain_record, &conn->plain_recordlen) != 1) {
error_print(); error_print();
return -1; return -1;
@@ -5261,7 +5441,7 @@ int tls13_recv_encrypted_extensions(TLS_CONNECT *conn)
} }
tls13_record_print(stderr, 0, 0, conn->record, conn->recordlen); tls13_record_print(stderr, 0, 0, conn->record, conn->recordlen);
if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv, if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->record, conn->recordlen, conn->server_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) { conn->plain_record, &conn->plain_recordlen) != 1) {
error_print(); error_print();
@@ -5744,7 +5924,7 @@ int tls13_recv_certificate_request(TLS_CONNECT *conn)
return ret; return ret;
} }
if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv, if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->record, conn->recordlen, conn->server_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) { conn->plain_record, &conn->plain_recordlen) != 1) {
error_print(); error_print();
@@ -6029,7 +6209,7 @@ int tls13_recv_server_certificate(TLS_CONNECT *conn)
// decrypt unless previous handshake is CertificateRequest // decrypt unless previous handshake is CertificateRequest
if (!conn->plain_recordlen) { if (!conn->plain_recordlen) {
if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv, if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->record, conn->recordlen, conn->server_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) { conn->plain_record, &conn->plain_recordlen) != 1) {
error_print(); error_print();
@@ -6163,7 +6343,7 @@ int tls13_recv_server_certificate_verify(TLS_CONNECT *conn)
return ret; return ret;
} }
if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv, if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->record, conn->recordlen, conn->server_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) { conn->plain_record, &conn->plain_recordlen) != 1) {
error_print(); error_print();
@@ -6239,7 +6419,7 @@ int tls13_recv_client_certificate_verify(TLS_CONNECT *conn)
return ret; return ret;
} }
if (tls13_record_decrypt(&conn->client_write_key, conn->client_write_iv, if (tls13_record_decrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->record, conn->recordlen, conn->client_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) { conn->plain_record, &conn->plain_recordlen) != 1) {
error_print(); error_print();
@@ -6325,7 +6505,7 @@ int tls13_recv_server_finished(TLS_CONNECT *conn)
return ret; return ret;
} }
if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv, if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->record, conn->recordlen, conn->server_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) { conn->plain_record, &conn->plain_recordlen) != 1) {
error_print(); error_print();
@@ -6415,7 +6595,7 @@ int tls13_send_client_certificate(TLS_CONNECT *conn)
tls13_padding_len_rand(&padding_len); tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv, if (tls13_record_encrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) { conn->record, &conn->recordlen) != 1) {
error_print(); error_print();
@@ -6468,7 +6648,7 @@ int tls13_send_client_certificate_verify(TLS_CONNECT *conn)
if(conn->verbose) tls_handshake_digest_print(stderr, 0, 0, "after client CertificateVerify", &conn->dgst_ctx); if(conn->verbose) tls_handshake_digest_print(stderr, 0, 0, "after client CertificateVerify", &conn->dgst_ctx);
tls13_padding_len_rand(&padding_len); tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv, if (tls13_record_encrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) { conn->record, &conn->recordlen) != 1) {
error_print(); error_print();
@@ -6516,7 +6696,7 @@ int tls13_send_client_finished(TLS_CONNECT *conn)
//format_print(stderr, 0, 0, "client_seq_num: "PRIu64"\n", GETU64(conn->client_seq_num)); //format_print(stderr, 0, 0, "client_seq_num: "PRIu64"\n", GETU64(conn->client_seq_num));
tls13_padding_len_rand(&padding_len); tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv, if (tls13_record_encrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) { conn->record, &conn->recordlen) != 1) {
error_print(); error_print();
@@ -7927,7 +8107,7 @@ int tls13_send_alert(TLS_CONNECT *conn, int alert)
break; break;
default: default:
tls13_padding_len_rand(&padding_len); tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) { conn->record, &conn->recordlen) != 1) {
error_print(); error_print();
@@ -8027,7 +8207,7 @@ int tls13_send_encrypted_extensions(TLS_CONNECT *conn)
//format_print(stderr, 0, 0, "server_seq_num: "PRIu64"\n", GETU64(conn->server_seq_num)); //format_print(stderr, 0, 0, "server_seq_num: "PRIu64"\n", GETU64(conn->server_seq_num));
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) { conn->record, &conn->recordlen) != 1) {
error_print(); error_print();
@@ -8151,7 +8331,7 @@ int tls13_send_certificate_request(TLS_CONNECT *conn)
//format_print(stderr, 0, 0, "server_seq_num: "PRIu64"\n", GETU64(conn->server_seq_num)); //format_print(stderr, 0, 0, "server_seq_num: "PRIu64"\n", GETU64(conn->server_seq_num));
tls13_padding_len_rand(&padding_len); tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) { conn->record, &conn->recordlen) != 1) {
error_print(); error_print();
@@ -8218,7 +8398,7 @@ int tls13_send_server_certificate(TLS_CONNECT *conn)
if(conn->verbose) tls_handshake_digest_print(stderr, 0, 0, "ServerCertificate", &conn->dgst_ctx); if(conn->verbose) tls_handshake_digest_print(stderr, 0, 0, "ServerCertificate", &conn->dgst_ctx);
tls13_padding_len_rand(&padding_len); tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) { conn->record, &conn->recordlen) != 1) {
error_print(); error_print();
@@ -8271,7 +8451,7 @@ int tls13_send_server_certificate_verify(TLS_CONNECT *conn)
} }
tls13_padding_len_rand(&padding_len); tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) { conn->record, &conn->recordlen) != 1) {
error_print(); error_print();
@@ -8320,7 +8500,7 @@ int tls13_send_server_finished(TLS_CONNECT *conn)
//format_print(stderr, 0, 0, "server_seq_num: "PRIu64"\n", GETU64(conn->server_seq_num)); //format_print(stderr, 0, 0, "server_seq_num: "PRIu64"\n", GETU64(conn->server_seq_num));
tls13_padding_len_rand(&padding_len); tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) { conn->record, &conn->recordlen) != 1) {
error_print(); error_print();
@@ -8384,7 +8564,7 @@ int tls13_recv_client_certificate(TLS_CONNECT *conn)
//format_print(stderr, 0, 0, "client_seq_num: "PRIu64"\n", GETU64(conn->client_seq_num)); //format_print(stderr, 0, 0, "client_seq_num: "PRIu64"\n", GETU64(conn->client_seq_num));
if (tls13_record_decrypt(&conn->client_write_key, conn->client_write_iv, if (tls13_record_decrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->record, conn->recordlen, conn->client_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) { conn->plain_record, &conn->plain_recordlen) != 1) {
error_print(); error_print();
@@ -8521,7 +8701,7 @@ int tls13_recv_client_finished(TLS_CONNECT *conn)
//format_print(stderr, 0, 0, "client_seq_num: "PRIu64"\n", GETU64(conn->client_seq_num)); //format_print(stderr, 0, 0, "client_seq_num: "PRIu64"\n", GETU64(conn->client_seq_num));
if (tls13_record_decrypt(&conn->client_write_key, conn->client_write_iv, if (tls13_record_decrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->record, conn->recordlen, conn->client_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) { conn->plain_record, &conn->plain_recordlen) != 1) {
error_print(); error_print();
@@ -8620,7 +8800,7 @@ int tls13_send_client_key_update(TLS_CONNECT *conn, int request_update)
tls13_padding_len_rand(&padding_len); tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv, if (tls13_record_encrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) { conn->record, &conn->recordlen) != 1) {
error_print(); error_print();
@@ -8675,7 +8855,7 @@ int tls13_send_server_key_update(TLS_CONNECT *conn, int request_update)
tls13_record_print(stderr, 0, 0, conn->plain_record, conn->plain_recordlen); tls13_record_print(stderr, 0, 0, conn->plain_record, conn->plain_recordlen);
tls13_padding_len_rand(&padding_len); tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) { conn->record, &conn->recordlen) != 1) {
error_print(); error_print();

View File

@@ -683,7 +683,7 @@ int tls13_send_new_session_ticket(TLS_CONNECT *conn)
format_print(stderr, 0, 0, "\n"); format_print(stderr, 0, 0, "\n");
tls13_padding_len_rand(&padding_len); tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv, if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) { conn->record, &conn->recordlen) != 1) {
error_print(); error_print();
@@ -1766,7 +1766,7 @@ int tls13_send_end_of_early_data(TLS_CONNECT *conn)
size_t padding_len; size_t padding_len;
tls13_padding_len_rand(&padding_len); tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv, if (tls13_record_encrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len, conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) { conn->record, &conn->recordlen) != 1) {
error_print(); error_print();
@@ -1817,7 +1817,7 @@ int tls13_recv_end_of_early_data(TLS_CONNECT *conn)
format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12); format_bytes(stderr, 0, 4, "client_write_iv", conn->client_write_iv, 12);
if (tls13_record_decrypt(&conn->client_write_key, conn->client_write_iv, if (tls13_record_decrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->record, conn->recordlen, conn->client_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) { conn->plain_record, &conn->plain_recordlen) != 1) {
error_print(); error_print();

View File

@@ -91,8 +91,12 @@ int tls_cipher_suite_from_name(const char *name)
{ {
if (!strcmp(name, "TLS_SM4_GCM_SM3")) { if (!strcmp(name, "TLS_SM4_GCM_SM3")) {
return TLS_cipher_sm4_gcm_sm3; return TLS_cipher_sm4_gcm_sm3;
} else if (!strcmp(name, "TLS_SM4_CCM_SM3")) {
return TLS_cipher_sm4_ccm_sm3;
} else if (!strcmp(name, "TLS_AES_128_GCM_SHA256")) { } else if (!strcmp(name, "TLS_AES_128_GCM_SHA256")) {
return TLS_cipher_aes_128_gcm_sha256; return TLS_cipher_aes_128_gcm_sha256;
} else if (!strcmp(name, "TLS_AES_128_CCM_SHA256")) {
return TLS_cipher_aes_128_ccm_sha256;
} else if (!strcmp(name, "TLS_ECDHE_SM4_CBC_SM3")) { } else if (!strcmp(name, "TLS_ECDHE_SM4_CBC_SM3")) {
return TLS_cipher_ecdhe_sm4_cbc_sm3; return TLS_cipher_ecdhe_sm4_cbc_sm3;
} else if (!strcmp(name, "TLS_ECDHE_SM4_GCM_SM3")) { } else if (!strcmp(name, "TLS_ECDHE_SM4_GCM_SM3")) {

View File

@@ -387,6 +387,57 @@ int test_aes_gcm(void)
return 1; return 1;
} }
#ifdef ENABLE_AES_CCM
int test_aes_ccm(void)
{
AES_KEY aes_key;
uint8_t key[16] = {
0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47,
0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f,
};
uint8_t iv[12] = {
0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b,
};
uint8_t aad[20] = {
0x00, 0x01, 0x02, 0x03, 0x04,
0x05, 0x06, 0x07, 0x08, 0x09,
0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
0x0f, 0x10, 0x11, 0x12, 0x13,
};
uint8_t in[23] = {
0x20, 0x21, 0x22, 0x23, 0x24, 0x25,
0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b,
0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31,
0x32, 0x33, 0x34, 0x35, 0x36,
};
uint8_t out[sizeof(in)];
uint8_t buf[sizeof(in)];
uint8_t tag[16];
aes_set_encrypt_key(&aes_key, key, sizeof(key));
if (aes_ccm_encrypt(&aes_key, iv, sizeof(iv), aad, sizeof(aad), in, sizeof(in),
out, sizeof(tag), tag) != 1) {
error_print();
return -1;
}
if (aes_ccm_decrypt(&aes_key, iv, sizeof(iv), aad, sizeof(aad), out, sizeof(out),
tag, sizeof(tag), buf) != 1 || memcmp(buf, in, sizeof(in)) != 0) {
error_print();
return -1;
}
tag[0] ^= 0x01;
if (aes_ccm_decrypt(&aes_key, iv, sizeof(iv), aad, sizeof(aad), out, sizeof(out),
tag, sizeof(tag), buf) != -1) {
error_print();
return -1;
}
printf("%s() ok\n", __FUNCTION__);
return 1;
}
#endif
int test_aes_cbc_pkcs5_wycheproof(void) int test_aes_cbc_pkcs5_wycheproof(void)
{ {
size_t i; size_t i;
@@ -465,6 +516,9 @@ int main(void)
if (test_aes() != 1) goto err; if (test_aes() != 1) goto err;
if (test_aes_ctr() != 1) goto err; if (test_aes_ctr() != 1) goto err;
if (test_aes_gcm() != 1) goto err; if (test_aes_gcm() != 1) goto err;
#ifdef ENABLE_AES_CCM
if (test_aes_ccm() != 1) goto err;
#endif
if (test_aes_cbc_pkcs5_wycheproof() != 1) goto err; if (test_aes_cbc_pkcs5_wycheproof() != 1) goto err;
printf("%s all tests passed!\n", __FILE__); printf("%s all tests passed!\n", __FILE__);
return 0; return 0;

View File

@@ -97,6 +97,56 @@ static int test_tls13_gcm(void)
return 1; return 1;
} }
#ifdef ENABLE_AES_CCM
static int test_tls13_ccm(void)
{
BLOCK_CIPHER_KEY block_key;
uint8_t key[16];
uint8_t iv[12];
uint8_t seq_num[8] = {0,0,0,0,0,0,0,1};
uint8_t record[5 + 40];
size_t recordlen;
size_t padding_len = 8;
uint8_t enced_record[256];
size_t enced_recordlen;
uint8_t buf[256];
size_t buflen;
rand_bytes(key, sizeof(key));
rand_bytes(iv, sizeof(iv));
rand_bytes(record + 5, 40);
record[0] = TLS_record_handshake;
record[1] = TLS_protocol_tls12 >> 8;
record[2] = TLS_protocol_tls12 & 0xff;
record[3] = 0;
record[4] = 40;
recordlen = 5 + 40;
if (block_cipher_set_encrypt_key(&block_key, BLOCK_CIPHER_aes128(), key) != 1) {
error_print();
return -1;
}
if (tls13_record_encrypt(TLS_cipher_aes_128_ccm_sha256, &block_key, iv,
seq_num, record, recordlen, padding_len, enced_record, &enced_recordlen) != 1) {
error_print();
return -1;
}
if (tls13_record_decrypt(TLS_cipher_aes_128_ccm_sha256, &block_key, iv,
seq_num, enced_record, enced_recordlen, buf, &buflen) != 1) {
error_print();
return -1;
}
if (buflen != recordlen || memcmp(buf, record, recordlen) != 0) {
error_print();
return -1;
}
printf("%s() ok\n", __FUNCTION__);
return 1;
}
#endif
static int test_tls13_supported_versions_ext(void) static int test_tls13_supported_versions_ext(void)
{ {
const int client_versions[] = { TLS_protocol_tls13, TLS_protocol_tls12, TLS_protocol_tlcp }; const int client_versions[] = { TLS_protocol_tls13, TLS_protocol_tls12, TLS_protocol_tlcp };
@@ -661,6 +711,9 @@ int main(void)
{ {
if (test_tls_ext() != 1) goto err; if (test_tls_ext() != 1) goto err;
if (test_tls13_gcm() != 1) goto err; if (test_tls13_gcm() != 1) goto err;
#ifdef ENABLE_AES_CCM
if (test_tls13_ccm() != 1) goto err;
#endif
if (test_tls13_supported_versions_ext() != 1) goto err; if (test_tls13_supported_versions_ext() != 1) goto err;
if (test_tls13_key_share_ext() != 1) goto err; if (test_tls13_key_share_ext() != 1) goto err;
if (test_tls_supported_groups_ext() != 1) goto err; if (test_tls_supported_groups_ext() != 1) goto err;

View File

@@ -111,6 +111,54 @@ static int test_tls_cbc(void)
return 1; return 1;
} }
#ifdef ENABLE_AES_CCM
static int test_tls_ccm(void)
{
uint8_t key[16] = {0};
BLOCK_CIPHER_KEY aes_key;
uint8_t fixed_iv[4] = {0x10, 0x11, 0x12, 0x13};
uint8_t seq_num[8] = {0,0,0,0,0,0,0,1};
uint8_t record[5 + 32];
uint8_t enced_record[256];
uint8_t buf[256];
size_t recordlen;
size_t enced_recordlen;
size_t buflen;
record[0] = TLS_record_handshake;
record[1] = TLS_protocol_tls12 >> 8;
record[2] = TLS_protocol_tls12 & 0xff;
record[3] = 0;
record[4] = 12;
memcpy(record + 5, "hello world", 12);
recordlen = 5 + 12;
block_cipher_set_encrypt_key(&aes_key, BLOCK_CIPHER_aes128(), key);
if (tls_ccm_encrypt(&aes_key, fixed_iv, seq_num, record,
record + 5, 12, enced_record + 5, &enced_recordlen) != 1) {
error_print();
return -1;
}
enced_record[0] = record[0];
enced_record[1] = record[1];
enced_record[2] = record[2];
enced_record[3] = (uint8_t)(enced_recordlen >> 8);
enced_record[4] = (uint8_t)enced_recordlen;
enced_recordlen += 5;
if (tls12_record_decrypt(TLS_cipher_aes_128_ccm_sha256, NULL, &aes_key, fixed_iv, seq_num,
enced_record, enced_recordlen, buf, &buflen) != 1
|| buflen != recordlen
|| memcmp(buf, record, recordlen) != 0) {
error_print();
return -1;
}
printf("%s() ok\n", __FUNCTION__);
return 1;
}
#endif
static int test_tls_random(void) static int test_tls_random(void)
{ {
uint8_t random[32]; uint8_t random[32];
@@ -439,6 +487,9 @@ static int test_tls_trusted_ca_keys_ext(void)
int main(void) int main(void)
{ {
if (test_tls_null_to_bytes() != 1) goto err; if (test_tls_null_to_bytes() != 1) goto err;
#ifdef ENABLE_AES_CCM
if (test_tls_ccm() != 1) goto err;
#endif
/* /*
if (test_tls_encode() != 1) goto err; if (test_tls_encode() != 1) goto err;
if (test_tls_cbc() != 1) goto err; if (test_tls_cbc() != 1) goto err;