Fix TLS Alert

Alert, ChangeCipherSpec record should be encrypted after handshake
This commit is contained in:
Zhi Guan
2024-02-06 20:57:27 +08:00
parent 24783e56ed
commit 69ffa88037
5 changed files with 196 additions and 180 deletions

164
src/tls.c
View File

@@ -1332,6 +1332,8 @@ int tls_record_set_alert(uint8_t *record, size_t *recordlen,
return -1;
}
record[0] = TLS_record_alert;
//record[1] = protocol.major should be set by others
//record[2] = protocol.minor should be set by others
record[3] = 0; // length
record[4] = 2; // length
record[5] = (uint8_t)alert_level;
@@ -1491,7 +1493,7 @@ int tls_record_send(const uint8_t *record, size_t recordlen, tls_socket_t sock)
return 1;
}
int tls_record_do_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock)
int tls_record_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock)
{
uint8_t *p = record;
size_t len;
@@ -1503,7 +1505,7 @@ int tls_record_do_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock)
p += n;
len -= n;
} else if (n == 0) {
error_puts("TCP connection closed");
tls_trace("TCP connection closed");
*recordlen = 0;
return 0;
} else {
@@ -1541,7 +1543,7 @@ int tls_record_do_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock)
p += n;
len -= n;
} else if (n == 0) {
error_puts("connection closed");
tls_trace("connection closed");
*recordlen = 0;
return 0;
} else {
@@ -1558,45 +1560,6 @@ int tls_record_do_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock)
return 1;
}
int tls_record_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock)
{
int ret;
if ((ret = tls_record_do_recv(record, recordlen, sock)) != 1) {
if (ret && ret != -EAGAIN) error_print();
return ret;
}
if (tls_record_type(record) == TLS_record_alert) {
int level;
int alert;
if (tls_record_get_alert(record, &level, &alert) != 1) {
error_print();
return -1;
}
tls_record_trace(stderr, record, *recordlen, 0, 0);
if (level == TLS_alert_level_fatal && alert == TLS_alert_close_notify) {
#if ENABLE_TLS_RESPOND_CLOSE_NOTIFY
tls_trace("send Alert close_notifiy\n");
tls_record_trace(stderr, record, *recordlen, 0, 0);
if (tls_record_send(record, *recordlen, sock) != 1) {
error_print();
return -1;
}
#endif
return 0;
} else {
error_print();
return -1;
}
}
return 1;
}
int tls_seq_num_incr(uint8_t seq_num[8])
{
int i;
@@ -1604,7 +1567,7 @@ int tls_seq_num_incr(uint8_t seq_num[8])
seq_num[i]++;
if (seq_num[i]) break;
}
// FIXME: 检查溢出
// FIXME: check overflow
return 1;
}
@@ -1632,6 +1595,7 @@ int tls_send_alert(TLS_CONNECT *conn, int alert)
error_print();
return -1;
}
tls_record_set_protocol(record, conn->protocol == TLS_protocol_tls13 ? TLS_protocol_tls12 : conn->protocol);
tls_record_set_alert(record, &recordlen, TLS_alert_level_fatal, alert);
@@ -1692,13 +1656,12 @@ int tls_send_warning(TLS_CONNECT *conn, int alert)
return 1;
}
int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen)
static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *in, size_t inlen, size_t *sentlen)
{
const SM3_HMAC_CTX *hmac_ctx;
const SM4_KEY *enc_key;
uint8_t *seq_num;
uint8_t *record;
size_t datalen;
size_t recordlen;
if (!conn) {
error_print();
@@ -1713,6 +1676,11 @@ int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen
inlen = TLS_MAX_PLAINTEXT_SIZE;
}
if (conn->datalen) {
error_puts("recv all buffered data before send");
return -1;
}
if (conn->is_client) {
hmac_ctx = &conn->client_write_mac_ctx;
enc_key = &conn->client_write_enc_key;
@@ -1722,37 +1690,34 @@ int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen
enc_key = &conn->server_write_enc_key;
seq_num = conn->server_seq_num;
}
record = conn->record;
tls_trace("send ApplicationData\n");
if (tls_record_set_type(record, TLS_record_application_data) != 1
|| tls_record_set_protocol(record, conn->protocol) != 1
|| tls_record_set_length(record, inlen) != 1) {
if (tls_record_set_type(conn->databuf, record_type) != 1
|| tls_record_set_protocol(conn->databuf, conn->protocol) != 1
|| tls_record_set_data(conn->databuf, in, inlen) != 1) {
error_print();
return -1;
}
tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0);
if (tls_cbc_encrypt(hmac_ctx, enc_key, seq_num, tls_record_header(record),
in, inlen, tls_record_data(record), &datalen) != 1) {
error_print();
return -1;
}
if (tls_record_set_length(record, datalen) != 1) {
if (tls_record_encrypt(hmac_ctx, enc_key, seq_num,
conn->databuf, tls_record_length(conn->databuf),
conn->record, &recordlen) != 1) {
error_print();
return -1;
}
tls_seq_num_incr(seq_num);
if (tls_record_send(record, tls_record_length(record), conn->sock) != 1) {
if (tls_record_send(conn->record, recordlen, conn->sock) != 1) {
error_print();
return -1;
}
tls_encrypted_record_trace(stderr, conn->record, recordlen, 0, 0);
*sentlen = inlen;
tls_record_trace(stderr, record, tls_record_length(record), 0, 0);
return 1;
}
int tls_do_recv(TLS_CONNECT *conn)
int tls_decrypt_recv(TLS_CONNECT *conn)
{
int ret;
const SM3_HMAC_CTX *hmac_ctx;
@@ -1772,68 +1737,111 @@ int tls_do_recv(TLS_CONNECT *conn)
seq_num = conn->client_seq_num;
}
tls_trace("recv ApplicationData\n");
tls_trace("recv Encrypted Record\n");
if ((ret = tls_record_recv(record, &recordlen, conn->sock)) != 1) {
if (ret < 0 && ret != -EAGAIN) error_print();
return ret;
}
tls_encrypted_record_trace(stderr, record, recordlen, 0, 0);
tls_record_trace(stderr, record, recordlen, 0, 0);
if (tls_cbc_decrypt(hmac_ctx, dec_key, seq_num, record,
tls_record_data(record), tls_record_data_length(record),
if (tls_record_decrypt(hmac_ctx, dec_key, seq_num,
record, recordlen,
conn->databuf, &conn->datalen) != 1) {
error_print();
return -1;
}
conn->data = conn->databuf;
tls_seq_num_incr(seq_num);
tls_record_set_data(record, conn->data, conn->datalen);
tls_trace("decrypt ApplicationData\n");
tls_record_trace(stderr, record, tls_record_length(record), 0, 0);
conn->data = tls_record_data(conn->databuf);
conn->datalen = tls_record_data_length(conn->databuf);
tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0);
return 1;
}
int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen)
{
tls_trace("send ApplicationData\n");
return tls_encrypt_send(conn, TLS_record_application_data, in, inlen, sentlen);
}
int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen)
{
if (!conn || !out || !outlen || !recvlen) {
error_print();
return -1;
}
if (conn->datalen == 0) {
int ret;
if ((ret = tls_do_recv(conn)) != 1) {
if ((ret = tls_decrypt_recv(conn)) != 1) {
if (ret < 0 && ret != -EAGAIN) error_print();
return ret;
}
switch (tls_record_type(conn->record)) {
case TLS_record_application_data:
break;
case TLS_record_change_cipher_spec:
error_print();
return -1;
case TLS_record_alert:
{
// should call tls_process_alert()
int level;
int alert;
tls_record_get_alert(conn->databuf, &level, &alert);
if (alert == TLS_alert_close_notify) {
tls_trace("recv Alert.close_notify\n");
return 0;
}
tls_trace("alert received\n");
return -1;
}
default:
error_print();
return -1;
}
}
*recvlen = outlen <= conn->datalen ? outlen : conn->datalen;
memcpy(out, conn->data, *recvlen);
conn->data += *recvlen;
conn->datalen -= *recvlen;
return 1;
}
int tls_shutdown(TLS_CONNECT *conn)
{
int ret;
size_t recordlen;
uint8_t alert[2];
alert[0] = TLS_alert_level_fatal;
alert[1] = TLS_alert_close_notify;
if (!conn) {
error_print();
return -1;
}
tls_trace("send Alert close_notify\n");
if (tls_send_alert(conn, TLS_alert_close_notify) != 1) {
tls_trace("send Alert.close_notify\n");
if (tls_encrypt_send(conn, TLS_record_alert, alert, sizeof(alert), &recordlen) != 1) {
error_print();
return -1;
}
#ifdef ENABLE_TLS_RESPOND_CLOSE_NOTIFY
tls_trace("recv Alert close_notify\n");
if (tls_record_do_recv(conn->record, &recordlen, conn->sock) != 1) {
error_print();
tls_trace("recv Alert.close_notify\n");
if ((ret = tls_decrypt_recv(conn)) != 1) {
if (ret == 0) tls_trace("Connection closed by remote without close_notify\n");
else if (ret == -EAGAIN) tls_trace("-EAGAIN\n");
else error_print();
return -1;
}
tls_record_trace(stderr, conn->record, recordlen, 0, 0);
#endif
return 1;
}