Fix TLS Alert

Alert, ChangeCipherSpec record should be encrypted after handshake
This commit is contained in:
Zhi Guan
2024-02-06 20:57:27 +08:00
parent 24783e56ed
commit 69ffa88037
5 changed files with 196 additions and 180 deletions

View File

@@ -718,7 +718,6 @@ typedef struct {
uint8_t record[TLS_MAX_RECORD_SIZE];
// 其实这个就不太对了,还是应该有一个完整的密文记录
uint8_t databuf[TLS_MAX_PLAINTEXT_SIZE];
uint8_t *data;
size_t datalen;
@@ -828,17 +827,20 @@ int tls13_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12],
#ifdef ENABLE_TLS_DEBUG
# define tls_trace(s) fprintf(stderr,(s))
# define tls_record_trace(fp,rec,reclen,fmt,ind) tls_record_print(fp,rec,reclen,fmt,ind)
# define tls_encrypted_record_trace(fp,rec,reclen,fmt,ind) tls_encrypted_record_print(fp,rec,reclen,fmt,ind)
# define tlcp_record_trace(fp,rec,reclen,fmt,ind) tlcp_record_print(fp,rec,reclen,fmt,ind)
# define tls12_record_trace(fp,rec,reclen,fmt,ind) tls12_record_print(fp,rec,reclen,fmt,ind)
# define tls13_record_trace(fp,rec,reclen,fmt,ind) tls13_record_print(fp,fmt,ind,rec,reclen)
#else
# define tls_trace(s)
# define tls_record_trace(fp,rec,reclen,fmt,ind)
# define tls_encrypted_record_trace(fp,rec,reclen,fmt,ind)
# define tlcp_record_trace(fp,rec,reclen,fmt,ind)
# define tls12_record_trace(fp,rec,reclen,fmt,ind)
# define tls13_record_trace(fp,rec,reclen,fmt,ind)
#endif
int tls_encrypted_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent);
#ifdef __cplusplus
}

View File

@@ -484,14 +484,13 @@ int tlcp_do_connect(TLS_CONNECT *conn)
sm3_update(&sm3_ctx, finished_record + 5, finished_record_len - 5);
// encrypt Client Finished
tls_trace("encrypt Finished\n");
if (tls_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key,
conn->client_seq_num, finished_record, finished_record_len, record, &recordlen) != 1) {
error_print();
tls_send_alert(conn, TLS_alert_internal_error);
goto end;
}
tlcp_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
tls_encrypted_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
tls_seq_num_incr(conn->client_seq_num);
if (tls_record_send(record, recordlen, conn->sock) != 1) {
error_print();
@@ -526,8 +525,7 @@ int tlcp_do_connect(TLS_CONNECT *conn)
tls_send_alert(conn, TLS_alert_bad_record_mac);
goto end;
}
tlcp_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
tls_trace("decrypt Finished\n");
tls_encrypted_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
if (tls_record_decrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key,
conn->server_seq_num, record, recordlen, finished_record, &finished_record_len) != 1) {
error_print();
@@ -920,10 +918,10 @@ int tlcp_do_accept(TLS_CONNECT *conn)
tls_send_alert(conn, TLS_alert_unexpected_message);
goto end;
}
tlcp_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
tls_encrypted_record_trace(stderr, record, recordlen, 0, 0);
// decrypt ClientFinished
tls_trace("decrypt Finished\n");
//tls_trace("decrypt Finished\n");
if (tls_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key,
conn->client_seq_num, record, recordlen, finished_record, &finished_record_len) != 1) {
error_print();
@@ -990,8 +988,7 @@ int tlcp_do_accept(TLS_CONNECT *conn)
tls_send_alert(conn, TLS_alert_internal_error);
goto end;
}
tls_trace("encrypt Finished\n");
tlcp_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
tls_encrypted_record_trace(stderr, record, recordlen, 0, 0);
tls_seq_num_incr(conn->server_seq_num);
if (tls_record_send(record, recordlen, conn->sock) != 1) {
error_print();

164
src/tls.c
View File

@@ -1332,6 +1332,8 @@ int tls_record_set_alert(uint8_t *record, size_t *recordlen,
return -1;
}
record[0] = TLS_record_alert;
//record[1] = protocol.major should be set by others
//record[2] = protocol.minor should be set by others
record[3] = 0; // length
record[4] = 2; // length
record[5] = (uint8_t)alert_level;
@@ -1491,7 +1493,7 @@ int tls_record_send(const uint8_t *record, size_t recordlen, tls_socket_t sock)
return 1;
}
int tls_record_do_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock)
int tls_record_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock)
{
uint8_t *p = record;
size_t len;
@@ -1503,7 +1505,7 @@ int tls_record_do_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock)
p += n;
len -= n;
} else if (n == 0) {
error_puts("TCP connection closed");
tls_trace("TCP connection closed");
*recordlen = 0;
return 0;
} else {
@@ -1541,7 +1543,7 @@ int tls_record_do_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock)
p += n;
len -= n;
} else if (n == 0) {
error_puts("connection closed");
tls_trace("connection closed");
*recordlen = 0;
return 0;
} else {
@@ -1558,45 +1560,6 @@ int tls_record_do_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock)
return 1;
}
int tls_record_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock)
{
int ret;
if ((ret = tls_record_do_recv(record, recordlen, sock)) != 1) {
if (ret && ret != -EAGAIN) error_print();
return ret;
}
if (tls_record_type(record) == TLS_record_alert) {
int level;
int alert;
if (tls_record_get_alert(record, &level, &alert) != 1) {
error_print();
return -1;
}
tls_record_trace(stderr, record, *recordlen, 0, 0);
if (level == TLS_alert_level_fatal && alert == TLS_alert_close_notify) {
#if ENABLE_TLS_RESPOND_CLOSE_NOTIFY
tls_trace("send Alert close_notifiy\n");
tls_record_trace(stderr, record, *recordlen, 0, 0);
if (tls_record_send(record, *recordlen, sock) != 1) {
error_print();
return -1;
}
#endif
return 0;
} else {
error_print();
return -1;
}
}
return 1;
}
int tls_seq_num_incr(uint8_t seq_num[8])
{
int i;
@@ -1604,7 +1567,7 @@ int tls_seq_num_incr(uint8_t seq_num[8])
seq_num[i]++;
if (seq_num[i]) break;
}
// FIXME: 检查溢出
// FIXME: check overflow
return 1;
}
@@ -1632,6 +1595,7 @@ int tls_send_alert(TLS_CONNECT *conn, int alert)
error_print();
return -1;
}
tls_record_set_protocol(record, conn->protocol == TLS_protocol_tls13 ? TLS_protocol_tls12 : conn->protocol);
tls_record_set_alert(record, &recordlen, TLS_alert_level_fatal, alert);
@@ -1692,13 +1656,12 @@ int tls_send_warning(TLS_CONNECT *conn, int alert)
return 1;
}
int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen)
static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *in, size_t inlen, size_t *sentlen)
{
const SM3_HMAC_CTX *hmac_ctx;
const SM4_KEY *enc_key;
uint8_t *seq_num;
uint8_t *record;
size_t datalen;
size_t recordlen;
if (!conn) {
error_print();
@@ -1713,6 +1676,11 @@ int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen
inlen = TLS_MAX_PLAINTEXT_SIZE;
}
if (conn->datalen) {
error_puts("recv all buffered data before send");
return -1;
}
if (conn->is_client) {
hmac_ctx = &conn->client_write_mac_ctx;
enc_key = &conn->client_write_enc_key;
@@ -1722,37 +1690,34 @@ int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen
enc_key = &conn->server_write_enc_key;
seq_num = conn->server_seq_num;
}
record = conn->record;
tls_trace("send ApplicationData\n");
if (tls_record_set_type(record, TLS_record_application_data) != 1
|| tls_record_set_protocol(record, conn->protocol) != 1
|| tls_record_set_length(record, inlen) != 1) {
if (tls_record_set_type(conn->databuf, record_type) != 1
|| tls_record_set_protocol(conn->databuf, conn->protocol) != 1
|| tls_record_set_data(conn->databuf, in, inlen) != 1) {
error_print();
return -1;
}
tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0);
if (tls_cbc_encrypt(hmac_ctx, enc_key, seq_num, tls_record_header(record),
in, inlen, tls_record_data(record), &datalen) != 1) {
error_print();
return -1;
}
if (tls_record_set_length(record, datalen) != 1) {
if (tls_record_encrypt(hmac_ctx, enc_key, seq_num,
conn->databuf, tls_record_length(conn->databuf),
conn->record, &recordlen) != 1) {
error_print();
return -1;
}
tls_seq_num_incr(seq_num);
if (tls_record_send(record, tls_record_length(record), conn->sock) != 1) {
if (tls_record_send(conn->record, recordlen, conn->sock) != 1) {
error_print();
return -1;
}
tls_encrypted_record_trace(stderr, conn->record, recordlen, 0, 0);
*sentlen = inlen;
tls_record_trace(stderr, record, tls_record_length(record), 0, 0);
return 1;
}
int tls_do_recv(TLS_CONNECT *conn)
int tls_decrypt_recv(TLS_CONNECT *conn)
{
int ret;
const SM3_HMAC_CTX *hmac_ctx;
@@ -1772,68 +1737,111 @@ int tls_do_recv(TLS_CONNECT *conn)
seq_num = conn->client_seq_num;
}
tls_trace("recv ApplicationData\n");
tls_trace("recv Encrypted Record\n");
if ((ret = tls_record_recv(record, &recordlen, conn->sock)) != 1) {
if (ret < 0 && ret != -EAGAIN) error_print();
return ret;
}
tls_encrypted_record_trace(stderr, record, recordlen, 0, 0);
tls_record_trace(stderr, record, recordlen, 0, 0);
if (tls_cbc_decrypt(hmac_ctx, dec_key, seq_num, record,
tls_record_data(record), tls_record_data_length(record),
if (tls_record_decrypt(hmac_ctx, dec_key, seq_num,
record, recordlen,
conn->databuf, &conn->datalen) != 1) {
error_print();
return -1;
}
conn->data = conn->databuf;
tls_seq_num_incr(seq_num);
tls_record_set_data(record, conn->data, conn->datalen);
tls_trace("decrypt ApplicationData\n");
tls_record_trace(stderr, record, tls_record_length(record), 0, 0);
conn->data = tls_record_data(conn->databuf);
conn->datalen = tls_record_data_length(conn->databuf);
tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0);
return 1;
}
int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen)
{
tls_trace("send ApplicationData\n");
return tls_encrypt_send(conn, TLS_record_application_data, in, inlen, sentlen);
}
int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen)
{
if (!conn || !out || !outlen || !recvlen) {
error_print();
return -1;
}
if (conn->datalen == 0) {
int ret;
if ((ret = tls_do_recv(conn)) != 1) {
if ((ret = tls_decrypt_recv(conn)) != 1) {
if (ret < 0 && ret != -EAGAIN) error_print();
return ret;
}
switch (tls_record_type(conn->record)) {
case TLS_record_application_data:
break;
case TLS_record_change_cipher_spec:
error_print();
return -1;
case TLS_record_alert:
{
// should call tls_process_alert()
int level;
int alert;
tls_record_get_alert(conn->databuf, &level, &alert);
if (alert == TLS_alert_close_notify) {
tls_trace("recv Alert.close_notify\n");
return 0;
}
tls_trace("alert received\n");
return -1;
}
default:
error_print();
return -1;
}
}
*recvlen = outlen <= conn->datalen ? outlen : conn->datalen;
memcpy(out, conn->data, *recvlen);
conn->data += *recvlen;
conn->datalen -= *recvlen;
return 1;
}
int tls_shutdown(TLS_CONNECT *conn)
{
int ret;
size_t recordlen;
uint8_t alert[2];
alert[0] = TLS_alert_level_fatal;
alert[1] = TLS_alert_close_notify;
if (!conn) {
error_print();
return -1;
}
tls_trace("send Alert close_notify\n");
if (tls_send_alert(conn, TLS_alert_close_notify) != 1) {
tls_trace("send Alert.close_notify\n");
if (tls_encrypt_send(conn, TLS_record_alert, alert, sizeof(alert), &recordlen) != 1) {
error_print();
return -1;
}
#ifdef ENABLE_TLS_RESPOND_CLOSE_NOTIFY
tls_trace("recv Alert close_notify\n");
if (tls_record_do_recv(conn->record, &recordlen, conn->sock) != 1) {
error_print();
tls_trace("recv Alert.close_notify\n");
if ((ret = tls_decrypt_recv(conn)) != 1) {
if (ret == 0) tls_trace("Connection closed by remote without close_notify\n");
else if (ret == -EAGAIN) tls_trace("-EAGAIN\n");
else error_print();
return -1;
}
tls_record_trace(stderr, conn->record, recordlen, 0, 0);
#endif
return 1;
}

View File

@@ -1070,7 +1070,7 @@ int tls13_record_print(FILE *fp, int format, int indent, const uint8_t *record,
}
// FIXME: 需要根据RFC来考虑这个函数的参数,从底向上逐步修改每个函数的接口参数
// FIXME: 根据RFC来考虑这个函数的参数,从底向上逐步修改每个函数的接口参数
// 仅从record数据是不能判断这个record是TLS 1.2还是TLS 1.3
// 不同协议上,同名的握手消息,其格式也是不一样的。这真是太恶心了!!!!
@@ -1105,13 +1105,6 @@ int tls_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int for
return -1;
}
// 最高字节设置后强制打印记录原始数据
if (format >> 24) {
format_bytes(fp, format, indent, "Data", data, datalen);
fprintf(fp, "\n");
return 1;
}
switch (record[0]) {
case TLS_record_handshake:
if (tls_handshake_print(fp, data, datalen, format, indent) != 1) {
@@ -1173,3 +1166,24 @@ int tls_secrets_print(FILE *fp,
format_print(stderr, format, indent, "\n");
return 1;
}
int tls_encrypted_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent)
{
int protocol;
if (!fp || !record || recordlen < 5) {
error_print();
return -1;
}
protocol = tls_record_protocol(record);
format_print(fp, format, indent, "EncryptedRecord\n"); indent += 4;
format_print(fp, format, indent, "ContentType: %s (%d)\n", tls_record_type_name(record[0]), record[0]);
format_print(fp, format, indent, "Version: %s (%d.%d)\n", tls_protocol_name(protocol), protocol >> 8, protocol & 0xff);
format_print(fp, format, indent, "Length: %d\n", tls_record_data_length(record));
format_bytes(fp, format, indent, "EncryptedData", tls_record_data(record), tls_record_data_length(record));
fprintf(fp, "\n");
return 1;
}

View File

@@ -70,6 +70,7 @@ int tlcp_client_main(int argc, char *argv[])
size_t len = sizeof(buf);
char send_buf[1024] = {0};
size_t sentlen;
int read_stdin = 1;
argc--;
argv++;
@@ -130,24 +131,19 @@ bad:
return -1;
}
if (!(hp = gethostbyname(host))) {
fprintf(stderr, "%s: invalid hostname '%s'\n", prog, host);
goto end;
}
memset(&ctx, 0, sizeof(ctx));
memset(&conn, 0, sizeof(conn));
server.sin_addr = *((struct in_addr *)hp->h_addr_list[0]);
server.sin_family = AF_INET;
server.sin_port = htons(port);
if (tls_socket_create(&sock, AF_INET, SOCK_STREAM, 0) != 1) {
fprintf(stderr, "%s: open socket error\n", prog);
goto end;
}
if (!(hp = gethostbyname(host))) {
fprintf(stderr, "%s: invalid hostname '%s'\n", prog, host);
goto end;
}
server.sin_addr = *((struct in_addr *)hp->h_addr_list[0]);
server.sin_family = AF_INET;
server.sin_port = htons(port);
if (tls_socket_connect(sock, &server) != 1) {
fprintf(stderr, "%s: socket connect error\n", prog);
goto end;
@@ -158,19 +154,30 @@ bad:
fprintf(stderr, "%s: context init error\n", prog);
goto end;
}
if (cacertfile) {
if (tls_ctx_set_ca_certificates(&ctx, cacertfile, TLS_DEFAULT_VERIFY_DEPTH) != 1) {
fprintf(stderr, "%s: context init error\n", prog);
goto end;
}
}
if (certfile) {
if (!keyfile) {
fprintf(stderr, "%s: option '-key' should be assigned with '-cert'\n", prog);
goto end;
}
if (!pass) {
fprintf(stderr, "%s: option '-pass' should be assigned with '-pass'\n", prog);
goto end;
}
if (tls_ctx_set_certificate_and_key(&ctx, certfile, keyfile, pass) != 1) {
fprintf(stderr, "%s: context init error\n", prog);
goto end;
}
}
if (quiet || get) {
if (quiet) {
ctx.quiet = 1;
}
@@ -196,6 +203,9 @@ bad:
fclose(outcertsfp);
}
// tls_shutdown(&conn);
// return 0;
if (get) {
struct timeval timeout;
timeout.tv_sec = TIMEOUT_SECONDS;
@@ -208,6 +218,7 @@ bad:
goto end;
}
// use timeout to close the HTTP connection
if (setsockopt(conn.sock, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof(timeout)) != 0) {
perror("setsockopt");
fprintf(stderr, "%s: set socket timeout error\n", prog);
@@ -215,100 +226,84 @@ bad:
}
for (;;) {
int ret;
if ((ret = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len)) != 1) {
if (ret == 0) {
fprintf(stderr, "%s: TLCP connection is closed by remote host\n", prog);
} else if (ret != -EAGAIN) {
fprintf(stderr, "%s: recv error\n", prog);
}
break;
int rv;
rv = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len);
if (rv == 1) {
fwrite(buf, 1, len, stdout);
fflush(stdout);
} else if (rv == 0) {
fprintf(stderr, "%s: TLCP connection is closed by remote host\n", prog);
goto end;
} else if (rv == -EAGAIN) {
// when timeout, tls_recv return -EAGAIN (-11)
tls_shutdown(&conn);
ret = 0;
goto end;
} else {
fprintf(stderr, "%s: tls_recv error\n", prog);
goto end;
}
fwrite(buf, 1, len, stdout);
fflush(stdout);
}
tls_shutdown(&conn);
goto end;
}
for (;;) {
fd_set fds;
if (!fgets(send_buf, sizeof(send_buf), stdin)) {
if (feof(stdin)) {
tls_shutdown(&conn);
goto end;
} else {
continue;
}
}
if (tls_send(&conn, (uint8_t *)send_buf, strlen(send_buf), &sentlen) != 1) {
fprintf(stderr, "%s: send error\n", prog);
goto end;
}
FD_ZERO(&fds);
FD_SET(conn.sock, &fds);
#ifdef WIN32
#else
FD_SET(fileno(stdin), &fds); //FD_SET(STDIN_FILENO, &fds); // NOT allowed in winsock2 !!!
#endif
if (read_stdin)
FD_SET(STDIN_FILENO, &fds);
if (select((int)(conn.sock + 1), // WinSock2 select() ignore this arg
&fds, NULL, NULL, NULL) < 0) {
fprintf(stderr, "%s: select failed\n", prog);
#ifdef WIN32
fprintf(stderr, "WSAGetLastError = %u\n", WSAGetLastError());
#endif
if (select(conn.sock + 1, &fds, NULL, NULL, NULL) < 0) {
fprintf(stderr, "%s: select error\n", prog);
goto end;
}
if (read_stdin && FD_ISSET(STDIN_FILENO, &fds)) {
if (fgets(buf, sizeof(buf), stdin)) {
if (tls_send(&conn, (uint8_t *)buf, strlen(buf), &len) != 1) {
fprintf(stderr, "%s: send error\n", prog);
goto end;
}
} else {
if (!feof(stdin)) {
fprintf(stderr, "%s: length of input line exceeds buffer size\n", prog);
goto end;
}
read_stdin = 0;
}
}
if (FD_ISSET(conn.sock, &fds)) {
for (;;) {
memset(buf, 0, sizeof(buf));
if (tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len) != 1) {
goto end;
}
int rv;
rv = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len);
if (rv == 1) {
fwrite(buf, 1, len, stdout);
fflush(stdout);
// 应该调整tls_recv 逻辑、API或者其他方式
if (conn.datalen == 0) {
break;
}
}
}
#ifdef WIN32
#else
if (FD_ISSET(fileno(stdin), &fds)) {
fprintf(stderr, "recv from stdin\n");
memset(send_buf, 0, sizeof(send_buf));
if (!fgets(send_buf, sizeof(send_buf), stdin)) {
if (feof(stdin)) {
tls_shutdown(&conn);
goto end;
} else {
continue;
}
}
if (tls_send(&conn, (uint8_t *)send_buf, strlen(send_buf), &sentlen) != 1) {
fprintf(stderr, "%s: send error\n", prog);
} else if (rv == 0) {
fprintf(stderr, "Connection closed by remote host\n");
goto end;
} else if (rv == -EAGAIN) {
// should not happen
error_print();
goto end;
} else {
error_print();
fprintf(stderr, "%s: tls_recv error\n", prog);
goto end;
}
}
#endif
fprintf(stderr, "end of this round\n");
}
end:
// FIXME: clean ctx and connection ASAP, as Ctrl-C is not handled
if (sock != -1) tls_socket_close(sock);
tls_ctx_cleanup(&ctx);
tls_cleanup(&conn);
return 0;
return ret;
}