Add CCM cipher suites

This commit is contained in:
Zhi Guan
2026-06-14 00:12:10 +08:00
parent 5d12858d41
commit 545e6a56f0
13 changed files with 869 additions and 36 deletions

View File

@@ -48,8 +48,14 @@ const size_t tls13_signature_algorithms_cnt =
const int tls13_cipher_suites[] = {
TLS_cipher_sm4_gcm_sm3,
#ifdef ENABLE_SM4_CCM
TLS_cipher_sm4_ccm_sm3,
#endif
#if defined(ENABLE_AES) && defined(ENABLE_SHA2)
TLS_cipher_aes_128_gcm_sha256,
#ifdef ENABLE_AES_CCM
TLS_cipher_aes_128_ccm_sha256,
#endif
#endif
};
const size_t tls13_cipher_suites_cnt =
@@ -72,11 +78,17 @@ int tls13_cipher_suite_get(int cipher_suite, const BLOCK_CIPHER **cipher, const
{
switch (cipher_suite) {
case TLS_cipher_sm4_gcm_sm3:
#ifdef ENABLE_SM4_CCM
case TLS_cipher_sm4_ccm_sm3:
#endif
*digest = DIGEST_sm3();
*cipher = BLOCK_CIPHER_sm4();
break;
#if defined(ENABLE_AES) && defined(ENABLE_SHA2)
case TLS_cipher_aes_128_gcm_sha256:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
*digest = DIGEST_sha256();
*cipher = BLOCK_CIPHER_aes128();
break;
@@ -250,14 +262,163 @@ int tls13_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
return 1;
}
static int tls13_ccm_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
const uint8_t seq_num[8], int record_type,
const uint8_t *in, size_t inlen, size_t padding_len,
uint8_t *out, size_t *outlen)
{
uint8_t nonce[12];
uint8_t aad[5];
uint8_t *tag;
uint8_t *mbuf = NULL;
size_t mlen, clen;
int tls13_record_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
if (!(mbuf = malloc(inlen + 256))) {
error_print();
return -1;
}
nonce[0] = nonce[1] = nonce[2] = nonce[3] = 0;
memcpy(nonce + 4, seq_num, 8);
gmssl_memxor(nonce, nonce, iv, 12);
memcpy(mbuf, in, inlen);
mbuf[inlen] = record_type;
memset(mbuf + inlen + 1, 0, padding_len);
mlen = inlen + 1 + padding_len;
clen = mlen + GHASH_SIZE;
aad[0] = TLS_record_application_data;
aad[1] = 0x03;
aad[2] = 0x03;
aad[3] = (uint8_t)(clen >> 8);
aad[4] = (uint8_t)(clen);
tag = out + mlen;
switch (key->cipher->oid) {
#ifdef ENABLE_SM4_CCM
case OID_sm4:
if (sm4_ccm_encrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad),
mbuf, mlen, out, GHASH_SIZE, tag) != 1) {
error_print();
free(mbuf);
return -1;
}
break;
#endif
#ifdef ENABLE_AES_CCM
case OID_aes128:
if (aes_ccm_encrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad),
mbuf, mlen, out, GHASH_SIZE, tag) != 1) {
error_print();
free(mbuf);
return -1;
}
break;
#endif
default:
error_print();
free(mbuf);
return -1;
}
*outlen = clen;
free(mbuf);
return 1;
}
static int tls13_ccm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
int *record_type, uint8_t *out, size_t *outlen)
{
uint8_t nonce[12];
uint8_t aad[5];
size_t mlen;
const uint8_t *tag;
nonce[0] = nonce[1] = nonce[2] = nonce[3] = 0;
memcpy(nonce + 4, seq_num, 8);
gmssl_memxor(nonce, nonce, iv, 12);
aad[0] = TLS_record_application_data;
aad[1] = 0x03;
aad[2] = 0x03;
aad[3] = (uint8_t)(inlen >> 8);
aad[4] = (uint8_t)(inlen);
if (inlen < GHASH_SIZE) {
error_print();
return -1;
}
mlen = inlen - GHASH_SIZE;
tag = in + mlen;
switch (key->cipher->oid) {
#ifdef ENABLE_SM4_CCM
case OID_sm4:
if (sm4_ccm_decrypt(&(key->u.sm4_key), nonce, sizeof(nonce), aad, sizeof(aad),
in, mlen, tag, GHASH_SIZE, out) != 1) {
error_print();
return -1;
}
break;
#endif
#ifdef ENABLE_AES_CCM
case OID_aes128:
if (aes_ccm_decrypt(&(key->u.aes_key), nonce, sizeof(nonce), aad, sizeof(aad),
in, mlen, tag, GHASH_SIZE, out) != 1) {
error_print();
return -1;
}
break;
#endif
default:
error_print();
return -1;
}
*record_type = 0;
while (mlen--) {
if (out[mlen] != 0) {
*record_type = out[mlen];
break;
}
}
*outlen = mlen;
if (!tls_record_type_name(*record_type)) {
error_print();
return -1;
}
return 1;
}
int tls13_record_encrypt(int cipher_suite, const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
const uint8_t seq_num[8], const uint8_t *record, size_t recordlen, size_t padding_len,
uint8_t *enced_record, size_t *enced_recordlen)
{
if (tls13_gcm_encrypt(key, iv,
seq_num, record[0], record + 5, recordlen - 5, padding_len,
enced_record + 5, enced_recordlen) != 1) {
switch (cipher_suite) {
case TLS_cipher_sm4_gcm_sm3:
case TLS_cipher_aes_128_gcm_sha256:
if (tls13_gcm_encrypt(key, iv,
seq_num, record[0], record + 5, recordlen - 5, padding_len,
enced_record + 5, enced_recordlen) != 1) {
error_print();
return -1;
}
break;
case TLS_cipher_sm4_ccm_sm3:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
if (tls13_ccm_encrypt(key, iv,
seq_num, record[0], record + 5, recordlen - 5, padding_len,
enced_record + 5, enced_recordlen) != 1) {
error_print();
return -1;
}
break;
default:
error_print();
return -1;
}
@@ -273,15 +434,34 @@ int tls13_record_encrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
return 1;
}
int tls13_record_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
int tls13_record_decrypt(int cipher_suite, const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
const uint8_t seq_num[8], const uint8_t *enced_record, size_t enced_recordlen,
uint8_t *record, size_t *recordlen)
{
int record_type;
if (tls13_gcm_decrypt(key, iv,
seq_num, enced_record + 5, enced_recordlen - 5,
&record_type, record + 5, recordlen) != 1) {
switch (cipher_suite) {
case TLS_cipher_sm4_gcm_sm3:
case TLS_cipher_aes_128_gcm_sha256:
if (tls13_gcm_decrypt(key, iv,
seq_num, enced_record + 5, enced_recordlen - 5,
&record_type, record + 5, recordlen) != 1) {
error_print();
return -1;
}
break;
case TLS_cipher_sm4_ccm_sm3:
#ifdef ENABLE_AES_CCM
case TLS_cipher_aes_128_ccm_sha256:
#endif
if (tls13_ccm_decrypt(key, iv,
seq_num, enced_record + 5, enced_recordlen - 5,
&record_type, record + 5, recordlen) != 1) {
error_print();
return -1;
}
break;
default:
error_print();
return -1;
}
@@ -1251,7 +1431,7 @@ int tls13_do_recv(TLS_CONNECT *conn)
//format_print(stderr, 0, 0, "\n");
}
if (tls13_record_decrypt(key, iv, seq_num, conn->record, conn->recordlen,
if (tls13_record_decrypt(conn->cipher_suite, key, iv, seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) {
error_print();
return -1;
@@ -5261,7 +5441,7 @@ int tls13_recv_encrypted_extensions(TLS_CONNECT *conn)
}
tls13_record_print(stderr, 0, 0, conn->record, conn->recordlen);
if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv,
if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) {
error_print();
@@ -5744,7 +5924,7 @@ int tls13_recv_certificate_request(TLS_CONNECT *conn)
return ret;
}
if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv,
if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) {
error_print();
@@ -6029,7 +6209,7 @@ int tls13_recv_server_certificate(TLS_CONNECT *conn)
// decrypt unless previous handshake is CertificateRequest
if (!conn->plain_recordlen) {
if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv,
if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) {
error_print();
@@ -6163,7 +6343,7 @@ int tls13_recv_server_certificate_verify(TLS_CONNECT *conn)
return ret;
}
if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv,
if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) {
error_print();
@@ -6239,7 +6419,7 @@ int tls13_recv_client_certificate_verify(TLS_CONNECT *conn)
return ret;
}
if (tls13_record_decrypt(&conn->client_write_key, conn->client_write_iv,
if (tls13_record_decrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) {
error_print();
@@ -6325,7 +6505,7 @@ int tls13_recv_server_finished(TLS_CONNECT *conn)
return ret;
}
if (tls13_record_decrypt(&conn->server_write_key, conn->server_write_iv,
if (tls13_record_decrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) {
error_print();
@@ -6415,7 +6595,7 @@ int tls13_send_client_certificate(TLS_CONNECT *conn)
tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv,
if (tls13_record_encrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) {
error_print();
@@ -6468,7 +6648,7 @@ int tls13_send_client_certificate_verify(TLS_CONNECT *conn)
if(conn->verbose) tls_handshake_digest_print(stderr, 0, 0, "after client CertificateVerify", &conn->dgst_ctx);
tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv,
if (tls13_record_encrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) {
error_print();
@@ -6516,7 +6696,7 @@ int tls13_send_client_finished(TLS_CONNECT *conn)
//format_print(stderr, 0, 0, "client_seq_num: "PRIu64"\n", GETU64(conn->client_seq_num));
tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv,
if (tls13_record_encrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) {
error_print();
@@ -7927,7 +8107,7 @@ int tls13_send_alert(TLS_CONNECT *conn, int alert)
break;
default:
tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv,
if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) {
error_print();
@@ -8027,7 +8207,7 @@ int tls13_send_encrypted_extensions(TLS_CONNECT *conn)
//format_print(stderr, 0, 0, "server_seq_num: "PRIu64"\n", GETU64(conn->server_seq_num));
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv,
if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) {
error_print();
@@ -8151,7 +8331,7 @@ int tls13_send_certificate_request(TLS_CONNECT *conn)
//format_print(stderr, 0, 0, "server_seq_num: "PRIu64"\n", GETU64(conn->server_seq_num));
tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv,
if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) {
error_print();
@@ -8218,7 +8398,7 @@ int tls13_send_server_certificate(TLS_CONNECT *conn)
if(conn->verbose) tls_handshake_digest_print(stderr, 0, 0, "ServerCertificate", &conn->dgst_ctx);
tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv,
if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) {
error_print();
@@ -8271,7 +8451,7 @@ int tls13_send_server_certificate_verify(TLS_CONNECT *conn)
}
tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv,
if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) {
error_print();
@@ -8320,7 +8500,7 @@ int tls13_send_server_finished(TLS_CONNECT *conn)
//format_print(stderr, 0, 0, "server_seq_num: "PRIu64"\n", GETU64(conn->server_seq_num));
tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv,
if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) {
error_print();
@@ -8384,7 +8564,7 @@ int tls13_recv_client_certificate(TLS_CONNECT *conn)
//format_print(stderr, 0, 0, "client_seq_num: "PRIu64"\n", GETU64(conn->client_seq_num));
if (tls13_record_decrypt(&conn->client_write_key, conn->client_write_iv,
if (tls13_record_decrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) {
error_print();
@@ -8521,7 +8701,7 @@ int tls13_recv_client_finished(TLS_CONNECT *conn)
//format_print(stderr, 0, 0, "client_seq_num: "PRIu64"\n", GETU64(conn->client_seq_num));
if (tls13_record_decrypt(&conn->client_write_key, conn->client_write_iv,
if (tls13_record_decrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->record, conn->recordlen,
conn->plain_record, &conn->plain_recordlen) != 1) {
error_print();
@@ -8620,7 +8800,7 @@ int tls13_send_client_key_update(TLS_CONNECT *conn, int request_update)
tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->client_write_key, conn->client_write_iv,
if (tls13_record_encrypt(conn->cipher_suite, &conn->client_write_key, conn->client_write_iv,
conn->client_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) {
error_print();
@@ -8675,7 +8855,7 @@ int tls13_send_server_key_update(TLS_CONNECT *conn, int request_update)
tls13_record_print(stderr, 0, 0, conn->plain_record, conn->plain_recordlen);
tls13_padding_len_rand(&padding_len);
if (tls13_record_encrypt(&conn->server_write_key, conn->server_write_iv,
if (tls13_record_encrypt(conn->cipher_suite, &conn->server_write_key, conn->server_write_iv,
conn->server_seq_num, conn->plain_record, conn->plain_recordlen, padding_len,
conn->record, &conn->recordlen) != 1) {
error_print();