diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c6332bb..385cfe9c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -65,6 +65,7 @@ set(src src/skf/skf_ext.c src/skf/skf_prn.c src/skf/skf_wisec.c + src/socket.c src/tls.c src/tls_ext.c src/tls_trace.c diff --git a/demos/scripts/certdemo.sh b/demos/scripts/certdemo.sh index 08b178fb..0bf7ee38 100755 --- a/demos/scripts/certdemo.sh +++ b/demos/scripts/certdemo.sh @@ -1,5 +1,5 @@ #!/bin/bash - +set -e gmssl sm2keygen -pass 1234 -out rootcakey.pem gmssl certgen -C CN -ST Beijing -L Haidian -O PKU -OU CS -CN ROOTCA -days 3650 \ diff --git a/demos/scripts/tlcpdemo.sh b/demos/scripts/tlcpdemo.sh index d3642b40..36bdf8e5 100755 --- a/demos/scripts/tlcpdemo.sh +++ b/demos/scripts/tlcpdemo.sh @@ -1,5 +1,6 @@ #!/bin/bash -x +set -e gmssl sm2keygen -pass 1234 -out rootcakey.pem gmssl certgen -C CN -ST Beijing -L Haidian -O PKU -OU CS -CN ROOTCA -days 3650 -key rootcakey.pem -pass 1234 -out rootcacert.pem -key_usage keyCertSign -key_usage cRLSign -ca @@ -26,7 +27,11 @@ cat cacert.pem >> double_certs.pem # If port is already in use, `gmssl` will fail, use `ps aux | grep gmssl` and `sudo kill -9` to kill existing proc # TODO: check if `gmssl` is failed -sudo gmssl tlcp_server -port 443 -cert double_certs.pem -key signkey.pem -pass 1234 -ex_key enckey.pem -ex_pass 1234 -cacert cacert.pem & # 1>/dev/null 2>/dev/null & +which sudo +if [ $? -eq 0 ]; then + SUDO=sudo +fi +$SUDO gmssl tlcp_server -port 443 -cert double_certs.pem -key signkey.pem -pass 1234 -ex_key enckey.pem -ex_pass 1234 -cacert cacert.pem & 1>/dev/null 2>/dev/null & sleep 3 gmssl sm2keygen -pass 1234 -out clientkey.pem @@ -34,5 +39,5 @@ gmssl reqgen -C CN -ST Beijing -L Haidian -O PKU -OU CS -CN Client -key clientke gmssl reqsign -in clientreq.pem -days 365 -key_usage digitalSignature -cacert cacert.pem -key cakey.pem -pass 1234 -out clientcert.pem gmssl certparse -in clientcert.pem -#gmssl tlcp_client -host 127.0.0.1 -cacert rootcacert.pem -cert clientcert.pem -key clientkey.pem -pass 1234 +gmssl tlcp_client -host 127.0.0.1 -cacert rootcacert.pem -cert clientcert.pem -key clientkey.pem -pass 1234 diff --git a/demos/scripts/tls12demo.sh b/demos/scripts/tls12demo.sh index b9028456..21418b45 100755 --- a/demos/scripts/tls12demo.sh +++ b/demos/scripts/tls12demo.sh @@ -20,7 +20,11 @@ cat cacert.pem >> certs.pem # If port is already in use, `gmssl` will fail, use `ps aux | grep gmssl` and `sudo kill -9` to kill existing proc # TODO: check if `gmssl` is failed -sudo gmssl tls12_server -port 443 -cert certs.pem -key signkey.pem -pass 1234 -cacert cacert.pem & #1>/dev/null 2>/dev/null & +which sudo +if [ $? -eq 0 ]; then + SUDO=sudo +fi +$SUDO gmssl tls12_server -port 4430 -cert certs.pem -key signkey.pem -pass 1234 -cacert cacert.pem & #1>/dev/null 2>/dev/null & sleep 3 gmssl sm2keygen -pass 1234 -out clientkey.pem @@ -28,5 +32,5 @@ gmssl reqgen -C CN -ST Beijing -L Haidian -O PKU -OU CS -CN Client -key clientke gmssl reqsign -in clientreq.pem -days 365 -key_usage digitalSignature -cacert cacert.pem -key cakey.pem -pass 1234 -out clientcert.pem gmssl certparse -in clientcert.pem -gmssl tls12_client -host 127.0.0.1 -cacert rootcacert.pem -cert clientcert.pem -key clientkey.pem -pass 1234 +gmssl tls12_client -host 127.0.0.1 -port 4430 -cacert rootcacert.pem -cert clientcert.pem -key clientkey.pem -pass 1234 diff --git a/demos/scripts/tls13demo.sh b/demos/scripts/tls13demo.sh index 5107c61f..0c555270 100755 --- a/demos/scripts/tls13demo.sh +++ b/demos/scripts/tls13demo.sh @@ -20,7 +20,11 @@ cat cacert.pem >> certs.pem # If port is already in use, `gmssl` will fail, use `ps aux | grep gmssl` and `sudo kill -9` to kill existing proc # TODO: check if `gmssl` is failed -sudo gmssl tls13_server -port 443 -cert certs.pem -key signkey.pem -pass 1234 -cacert cacert.pem & # 1>/dev/null 2>/dev/null & +which sudo +if [ $? -eq 0 ]; then + SUDO=sudo +fi +$SUDO gmssl tls13_server -port 4433 -cert certs.pem -key signkey.pem -pass 1234 -cacert cacert.pem & # 1>/dev/null 2>/dev/null & sleep 3 gmssl sm2keygen -pass 1234 -out clientkey.pem @@ -28,5 +32,5 @@ gmssl reqgen -C CN -ST Beijing -L Haidian -O PKU -OU CS -CN Client -key clientke gmssl reqsign -in clientreq.pem -days 365 -key_usage digitalSignature -cacert cacert.pem -key cakey.pem -pass 1234 -out clientcert.pem gmssl certparse -in clientcert.pem -gmssl tls13_client -host 127.0.0.1 -cacert rootcacert.pem -cert clientcert.pem -key clientkey.pem -pass 1234 +gmssl tls13_client -host 127.0.0.1 -port 4433 -cacert rootcacert.pem -cert clientcert.pem -key clientkey.pem -pass 1234 diff --git a/include/gmssl/socket.h b/include/gmssl/socket.h index b0a34569..e5c5d926 100644 --- a/include/gmssl/socket.h +++ b/include/gmssl/socket.h @@ -19,6 +19,7 @@ extern "C" { #endif + #ifdef WIN32 #pragma comment (lib, "Ws2_32.lib") #pragma comment (lib, "Mswsock.lib") @@ -55,13 +56,15 @@ typedef socklen_t tls_socklen_t; #define tls_socket_recv(sock,buf,len,flags) recv(sock,buf,len,flags) #define tls_socket_close(sock) close(sock) - - #endif - - - +int tls_socket_lib_init(void); +int tls_socket_lib_cleanup(void); +int tls_socket_create(tls_socket_t *sock, int af, int type, int protocl); +int tls_socket_connect(tls_socket_t sock, const struct sockaddr_in *addr); +int tls_socket_bind(tls_socket_t sock, const struct sockaddr_in *addr); +int tls_socket_listen(tls_socket_t sock, int backlog); +int tls_socket_accept(tls_socket_t sock, struct sockaddr_in *addr, tls_socket_t *conn_sock); #ifdef __cplusplus diff --git a/src/socket.c b/src/socket.c new file mode 100644 index 00000000..c4f910be --- /dev/null +++ b/src/socket.c @@ -0,0 +1,169 @@ +/* + * Copyright 2014-2023 The GmSSL Project. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ + + +#include +#include +#include +#include +#include +#include + + +#ifdef WIN32 +int tls_socket_lib_init(void) +{ + WORD wVersion = MAKEWORD(2, 2); + WSADATA wsaData; + int err; + + if ((err = WSAStartup(wVersion, &wsaData)) != 0) { + fprintf(stderr, "WSAStartup() return error %d\n", err); + error_print(); + return -1; + } + return 1; +} + +int tls_socket_lib_cleanup(void) +{ + if (WSACleanup() != 0) { + fprintf(stderr, "WSACleanup() return error %d\n", WSAGetLastError()); + error_print(); + return -1; + } + return 1; +} + +int tls_socket_create(tls_socket_t *sock, int af, int type, int protocl) +{ + if (!sock) { + error_print(); + return -1; + } + if ((*sock = socket(af, type, protocol)) == INVALID_SOCKET) { + fprintf(stderr, "%s %d: socket error: %d\n", __FILE__, __LINE__, WSAGetLastError()); + error_print(); + return -1; + } + return 1; +} + +int tls_socket_connect(tls_socket_t sock, const struct sockaddr_in *addr) +{ + int addr_len = (int)sizeof(struct sockaddr_in); + if (connect(sock, (const struct sockaddr *)addr, addr_len) == SOCKET_ERROR) { + fprintf(stderr, "%s %d: socket error: %d\n", __FILE__, __LINE__, WSAGetLastError()); + error_print(); + return -1; + } + return 1; +} + +int tls_socket_bind(tls_socket_t sock, const struct sockaddr_in *addr) +{ + int addr_len = (int)sizeof(struct sockaddr_in); + if (bind(sock, (const struct sockaddr *)addr, addr_len) == SOCKET_ERROR) { + fprintf(stderr, "%s %d: socket bind error: %u\n", __FILE__, __LINE__, WSAGetLastError()); + error_print(); + return -1; + } + return 1; +} + +int tls_socket_listen(tls_socket_t sock, int backlog) +{ + if (listen(sock, backlog) == SOCKET_ERROR) { + fprintf(stderr, "%s %d: socket listen error: %u\n", __FILE__, __LINE__, WSAGetLastError()); + error_print(); + return -1; + } + return 1; +} + +int tls_socket_accept(tls_socket_t sock, struct sockaddr_in *addr, tls_socket_t *conn_sock) +{ + int addr_len = (int)sizeof(struct sockaddr_in_); + if ((*conn_sock = accept(sock, (struct sockaddr *)addr, &addr_len)) == INVALID_SOCKET) { + fprintf(stderr, "%s %d: accept error: %u\n", __FILE__, __LINE__, WSAGetLastError()); + error_print(); + return -1; + } + return 1; +} + +#else + +int tls_socket_lib_init(void) +{ + return 1; +} + +int tls_socket_lib_cleanup(void) +{ + return 1; +} + +int tls_socket_create(tls_socket_t *sock, int af, int type, int protocol) +{ + if (!sock) { + error_print(); + return -1; + } + if ((*sock = socket(af, type, protocol)) == -1) { + fprintf(stderr, "%s %d: socket error: %s\n", __FILE__, __LINE__, strerror(errno)); + error_print(); + return -1; + } + return 1; +} + +int tls_socket_connect(tls_socket_t sock, const struct sockaddr_in *addr) +{ + socklen_t addr_len = sizeof(struct sockaddr_in); + if (connect(sock, (const struct sockaddr *)addr, addr_len) == -1) { + fprintf(stderr, "%s %d: socket error: %s\n", __FILE__, __LINE__, strerror(errno)); + error_print(); + return -1; + } + return 1; +} + +int tls_socket_bind(tls_socket_t sock, const struct sockaddr_in *addr) +{ + socklen_t addr_len = (socklen_t)sizeof(struct sockaddr_in); + if (bind(sock, (const struct sockaddr *)addr, addr_len) == -1) { + fprintf(stderr, "%s %d: socket bind error: %s\n", __FILE__, __LINE__, strerror(errno)); + error_print(); + return -1; + } + return 1; +} + +int tls_socket_listen(tls_socket_t sock, int backlog) +{ + if (listen(sock, backlog) == -1) { + fprintf(stderr, "%s %d: socket listen error: %s\n", __FILE__, __LINE__, strerror(errno)); + error_print(); + return -1; + } + return 1; +} + +int tls_socket_accept(tls_socket_t sock, struct sockaddr_in *addr, tls_socket_t *conn_sock) +{ + socklen_t addr_len = (socklen_t)sizeof(struct sockaddr_in); + if ((*conn_sock = accept(sock, (struct sockaddr *)addr, &addr_len)) == -1) { + fprintf(stderr, "%s %d: accept: %s\n", __FILE__, __LINE__, strerror(errno)); + error_print(); + return -1; + } + return 1; +} +#endif diff --git a/tools/tlcp_client.c b/tools/tlcp_client.c index c48504d6..a4aecc23 100644 --- a/tools/tlcp_client.c +++ b/tools/tlcp_client.c @@ -89,16 +89,10 @@ bad: return -1; } -#ifdef WIN32 - WORD wVersion; - WSADATA wsaData; - wVersion = MAKEWORD(2, 2); - int err; - if ((err = WSAStartup(wVersion, &wsaData)) != 0) { - fprintf(stderr, "WSAStartup error %d\n", err); + if (tls_socket_lib_init() != 1) { + error_print(); return -1; } -#endif if (!(hp = gethostbyname(host))) { //herror("tlcp_client: '-host' invalid"); @@ -112,27 +106,17 @@ bad: server.sin_family = AF_INET; server.sin_port = htons(port); -#ifdef WIN32 - if ((sock = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) { - fprintf(stderr, "%s: open socket error : %u\n", prog, WSAGetLastError()); + + if (tls_socket_create(&sock, AF_INET, SOCK_STREAM, 0) != 1) { + fprintf(stderr, "%s: open socket error\n", prog); goto end; } sock_inited = 1; - if (connect(sock, (struct sockaddr *)&server , sizeof(server)) == SOCKET_ERROR) { - fprintf(stderr, "%s: connect error : %u\n", prog, WSAGetLastError()); + + if (tls_socket_connect(sock, &server) != 1) { + fprintf(stderr, "%s: socket connect error\n", prog); goto end; } -#else - if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - fprintf(stderr, "%s: open socket error : %s\n", prog, strerror(errno)); - goto end; - } - sock_inited = 1; - if (connect(sock, (struct sockaddr *)&server , sizeof(server)) < 0) { - fprintf(stderr, "%s: connect error : %s\n", prog, strerror(errno)); - goto end; - } -#endif if (tls_ctx_init(&ctx, TLS_protocol_tlcp, TLS_client_mode) != 1 || tls_ctx_set_cipher_suites(&ctx, client_ciphers, sizeof(client_ciphers)/sizeof(client_ciphers[0])) != 1) { diff --git a/tools/tlcp_server.c b/tools/tlcp_server.c index ab3a1acf..e0abfff7 100644 --- a/tools/tlcp_server.c +++ b/tools/tlcp_server.c @@ -124,49 +124,35 @@ bad: } } -#ifdef WIN32 - WORD wVersion; - WSADATA wsaData; - wVersion = MAKEWORD(2, 2); - int err; - if ((err = WSAStartup(wVersion, &wsaData)) != 0) { - fprintf(stderr, "WSAStartup error %d\n", err); + + if (tls_socket_lib_init() != 1) { + error_print(); return -1; } -#endif - - - // Socket - if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + if (tls_socket_create(&sock, AF_INET, SOCK_STREAM, 0) != 1) { error_print(); return 1; } server_addr.sin_family = AF_INET; server_addr.sin_addr.s_addr = INADDR_ANY; server_addr.sin_port = htons(port); -#ifdef WIN32 - if (bind(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) == SOCKET_ERROR) { - fprintf(stderr, "bind error %u\n", WSAGetLastError()); - goto end; - } -#else - if (bind(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) { - error_print(); - perror("tlcp_accept: bind: "); - goto end; - } -#endif - puts("start listen ...\n"); - listen(sock, 1); + if (tls_socket_bind(sock, &server_addr) != 1) { + fprintf(stderr, "%s: socket bind error\n", prog); + goto end; + } + + puts("start listen ...\n"); + tls_socket_listen(sock, 1); restart: client_addrlen = sizeof(client_addr); - if ((conn_sock = accept(sock, (struct sockaddr *)&client_addr, &client_addrlen)) < 0) { - error_print(); + + if (tls_socket_accept(sock, &client_addr, &conn_sock) != 1) { + fprintf(stderr, "%s: socket accept error\n", prog); goto end; } puts("socket connected\n"); diff --git a/tools/tls12_client.c b/tools/tls12_client.c index ee0a1a78..01fb10d8 100644 --- a/tools/tls12_client.c +++ b/tools/tls12_client.c @@ -89,6 +89,12 @@ bad: fprintf(stderr, "%s: '-in' option required\n", prog); return -1; } + + if (tls_socket_lib_init() != 1) { + error_print(); + return -1; + } + if (!(hp = gethostbyname(host))) { //herror("tls12_client: '-host' invalid"); // herror() not in winsock2, use WSAGetLastError() instead goto end; @@ -102,12 +108,12 @@ bad: server.sin_port = htons(port); - if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - //fprintf(stderr, "%s: open socket error : %s\n", prog, strerror(errno)); //FIXME: WIN32 use WSAGetLastError() + if (tls_socket_create(&sock, AF_INET, SOCK_STREAM, 0) != 1) { + fprintf(stderr, "%s: create socket error\n", prog); goto end; } - if (connect(sock, (struct sockaddr *)&server , sizeof(server)) < 0) { - //fprintf(stderr, "%s: connect error : %s\n", prog, strerror(errno)); // + if (tls_socket_connect(sock, &server) != 1) { + fprintf(stderr, "%s: socket connect error\n", prog); goto end; } diff --git a/tools/tls12_server.c b/tools/tls12_server.c index ccd60e40..55aa45a0 100644 --- a/tools/tls12_server.c +++ b/tools/tls12_server.c @@ -95,6 +95,11 @@ bad: memset(&ctx, 0, sizeof(ctx)); memset(&conn, 0, sizeof(conn)); + if (tls_socket_lib_init() != 1) { + error_print(); + return -1; + } + if (tls_ctx_init(&ctx, TLS_protocol_tls12, TLS_server_mode) != 1 || tls_ctx_set_cipher_suites(&ctx, server_ciphers, sizeof(server_ciphers)/sizeof(int)) != 1 || tls_ctx_set_certificate_and_key(&ctx, certfile, keyfile, pass) != 1) { @@ -109,28 +114,30 @@ bad: } // Socket - if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - error_print(); - return 1; + + if (tls_socket_create(&sock, AF_INET, SOCK_STREAM, 0) != 1) { + fprintf(stderr, "%s: create socket error\n", prog); + goto end; } + server_addr.sin_family = AF_INET; server_addr.sin_addr.s_addr = INADDR_ANY; server_addr.sin_port = htons(port); - if (bind(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) { - error_print(); - perror("tlcp_accept: bind: "); + + if (tls_socket_bind(sock, &server_addr) != 1) { + fprintf(stderr, "%s: socket bind error\n", prog); goto end; } + puts("start listen ...\n"); - listen(sock, 1); - - + tls_socket_listen(sock, 1); restart: - client_addrlen = sizeof(client_addr); - if ((conn_sock = accept(sock, (struct sockaddr *)&client_addr, &client_addrlen)) < 0) { - error_print(); + //client_addrlen = sizeof(client_addr); + + if (tls_socket_accept(sock, &client_addr, &conn_sock) != 1) { + fprintf(stderr, "%s: socket accept error\n", prog); goto end; } puts("socket connected\n"); diff --git a/tools/tls13_client.c b/tools/tls13_client.c index 50c17735..ea268651 100644 --- a/tools/tls13_client.c +++ b/tools/tls13_client.c @@ -89,6 +89,11 @@ bad: fprintf(stderr, "%s: '-in' option required\n", prog); return -1; } + + if (tls_socket_lib_init() != 1) { + error_print(); + return -1; + } if (!(hp = gethostbyname(host))) { //herror("tls13_client: '-host' invalid"); goto end; @@ -101,13 +106,12 @@ bad: server.sin_family = AF_INET; server.sin_port = htons(port); - - if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - fprintf(stderr, "%s: open socket error : %s\n", prog, strerror(errno)); + if (tls_socket_create(&sock, AF_INET, SOCK_STREAM, 0) != 1) { + fprintf(stderr, "%s: socket create error\n", prog); goto end; } - if (connect(sock, (struct sockaddr *)&server , sizeof(server)) < 0) { - fprintf(stderr, "%s: connect error : %s\n", prog, strerror(errno)); + if (tls_socket_connect(sock, &server) != 1) { + fprintf(stderr, "%s: socket connect error\n", prog); goto end; } diff --git a/tools/tls13_server.c b/tools/tls13_server.c index 09308ac2..a86ac5cd 100644 --- a/tools/tls13_server.c +++ b/tools/tls13_server.c @@ -89,6 +89,10 @@ bad: fprintf(stderr, "%s: '-pass' option required\n", prog); return 1; } + if (tls_socket_lib_init() != 1) { + error_print(); + return -1; + } memset(&ctx, 0, sizeof(ctx)); memset(&conn, 0, sizeof(conn)); @@ -106,29 +110,28 @@ bad: } } - // Socket - if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - error_print(); - return 1; + + if (tls_socket_create(&sock, AF_INET, SOCK_STREAM, 0) != 1) { + fprintf(stderr, "%s: socket create error\n", prog); + goto end; } server_addr.sin_family = AF_INET; server_addr.sin_addr.s_addr = INADDR_ANY; server_addr.sin_port = htons(port); - if (bind(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) { - error_print(); - perror("tlcp_accept: bind: "); + if (tls_socket_bind(sock, &server_addr) != 1) { + fprintf(stderr, "%s: socket bind error\n", prog); goto end; } puts("start listen ...\n"); - listen(sock, 1); + tls_socket_listen(sock, 1); restart: - client_addrlen = sizeof(client_addr); - if ((conn_sock = accept(sock, (struct sockaddr *)&client_addr, &client_addrlen)) < 0) { - error_print(); + //client_addrlen = sizeof(client_addr); + if (tls_socket_accept(sock, &client_addr, &conn_sock) != 1) { + fprintf(stderr, "%s: socket accept error\n", prog); goto end; } puts("socket connected\n");