diff --git a/include/gmssl/hkdf.h b/include/gmssl/hkdf.h index 45ce05c7..1287448e 100644 --- a/include/gmssl/hkdf.h +++ b/include/gmssl/hkdf.h @@ -65,7 +65,7 @@ int hkdf_extract(const DIGEST *digest, const uint8_t *salt, size_t saltlen, uint8_t *prk, size_t *prklen); int hkdf_expand(const DIGEST *digest, const uint8_t *prk, size_t prklen, - const uint8_t *info, size_t infolen, + const uint8_t *opt_info, size_t opt_infolen, size_t L, uint8_t *okm); diff --git a/src/asn1.c b/src/asn1.c index ea91aa6f..5619685f 100644 --- a/src/asn1.c +++ b/src/asn1.c @@ -457,6 +457,11 @@ int asn1_integer_to_der_ex(int tag, const uint8_t *a, size_t alen, uint8_t **out *(*out)++ = tag; (*outlen)++; + while (*a == 0 && alen > 1) { + a++; + alen--; + } + if (a[0] & 0x80) { asn1_length_to_der(alen + 1, out, outlen); if (out) { @@ -466,11 +471,7 @@ int asn1_integer_to_der_ex(int tag, const uint8_t *a, size_t alen, uint8_t **out } (*outlen) += 1 + alen; } else { - while (*a == 0 && alen > 1) { - a++; - alen--; - } - asn1_length_to_der(alen, out, outlen); + asn1_length_to_der(alen, out ,outlen); if (out) { memcpy(*out, a, alen); (*out) += alen; diff --git a/src/digest.c b/src/digest.c index e1f1036c..8eac5582 100644 --- a/src/digest.c +++ b/src/digest.c @@ -51,6 +51,7 @@ #include #include #include +#include typedef struct { @@ -83,6 +84,10 @@ const char *digest_name(const DIGEST *digest) int digest_init(DIGEST_CTX *ctx, const DIGEST *algor) { memset(ctx, 0, sizeof(DIGEST_CTX)); + if (algor == NULL) { + error_print(); + return -1; + } ctx->digest = algor; ctx->digest->init(ctx); return 1; @@ -90,12 +95,19 @@ int digest_init(DIGEST_CTX *ctx, const DIGEST *algor) int digest_update(DIGEST_CTX *ctx, const uint8_t *data, size_t datalen) { + if (data == NULL || datalen == 0) { + return 0; + } ctx->digest->update(ctx, data, datalen); return 1; } int digest_finish(DIGEST_CTX *ctx, uint8_t *dgst, size_t *dgstlen) { + if (dgst == NULL || dgstlen == NULL) { + error_print(); + return -1; + } ctx->digest->finish(ctx, dgst); *dgstlen = ctx->digest->digest_size; return 1; @@ -105,10 +117,11 @@ int digest(const DIGEST *digest, const uint8_t *data, size_t datalen, uint8_t *dgst, size_t *dgstlen) { DIGEST_CTX ctx; - if (!digest_init(&ctx, digest) - || !digest_update(&ctx, data, datalen) - || !digest_finish(&ctx, dgst, dgstlen)) { - return 0; + if (digest_init(&ctx, digest) != 1 + || digest_update(&ctx, data, datalen) < 0 + || digest_finish(&ctx, dgst, dgstlen) != 1) { + error_print(); + return -1; } memset(&ctx, 0, sizeof(DIGEST_CTX)); return 1; @@ -141,7 +154,8 @@ const DIGEST *digest_from_name(const char *name) static int sm3_digest_init(DIGEST_CTX *ctx) { if (!ctx) { - return 0; + error_print(); + return -1; } sm3_init(&ctx->u.sm3_ctx); return 1; @@ -150,7 +164,8 @@ static int sm3_digest_init(DIGEST_CTX *ctx) static int sm3_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen) { if (!ctx || (!in && inlen != 0)) { - return 0; + error_print(); + return -1; } sm3_update(&ctx->u.sm3_ctx, in, inlen); return 1; @@ -159,7 +174,8 @@ static int sm3_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen) static int sm3_digest_finish(DIGEST_CTX *ctx, uint8_t *dgst) { if (!ctx || !dgst) { - return 0; + error_print(); + return -1; } sm3_finish(&ctx->u.sm3_ctx, dgst); return 1; @@ -186,7 +202,8 @@ const DIGEST *DIGEST_sm3(void) static int md5_digest_init(DIGEST_CTX *ctx) { if (!ctx) { - return 0; + error_print(); + return -1; } md5_init(&ctx->u.md5_ctx); return 1; @@ -195,7 +212,8 @@ static int md5_digest_init(DIGEST_CTX *ctx) static int md5_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen) { if (!ctx || (!in && inlen != 0)) { - return 0; + error_print(); + return -1; } md5_update(&ctx->u.md5_ctx, in, inlen); return 1; @@ -204,7 +222,8 @@ static int md5_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen) static int md5_digest_finish(DIGEST_CTX *ctx, uint8_t *dgst) { if (!ctx || !dgst) { - return 0; + error_print(); + return -1; } md5_finish(&ctx->u.md5_ctx, dgst); return 1; @@ -231,7 +250,8 @@ const DIGEST *DIGEST_md5(void) static int sha1_digest_init(DIGEST_CTX *ctx) { if (!ctx) { - return 0; + error_print(); + return -1; } sha1_init(&ctx->u.sha1_ctx); return 1; @@ -240,7 +260,8 @@ static int sha1_digest_init(DIGEST_CTX *ctx) static int sha1_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen) { if (!ctx || (!in && inlen != 0)) { - return 0; + error_print(); + return -1; } sha1_update(&ctx->u.sha1_ctx, in, inlen); return 1; @@ -249,7 +270,8 @@ static int sha1_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen) static int sha1_digest_finish(DIGEST_CTX *ctx, uint8_t *dgst) { if (!ctx || !dgst) { - return 0; + error_print(); + return -1; } sha1_finish(&ctx->u.sha1_ctx, dgst); return 1; @@ -276,7 +298,8 @@ const DIGEST *DIGEST_sha1(void) static int sha224_digest_init(DIGEST_CTX *ctx) { if (!ctx) { - return 0; + error_print(); + return -1; } sha224_init(&ctx->u.sha224_ctx); return 1; @@ -285,7 +308,8 @@ static int sha224_digest_init(DIGEST_CTX *ctx) static int sha224_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen) { if (!ctx || (!in && inlen != 0)) { - return 0; + error_print(); + return -1; } sha224_update(&ctx->u.sha224_ctx, in, inlen); return 1; @@ -294,7 +318,8 @@ static int sha224_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen static int sha224_digest_finish(DIGEST_CTX *ctx, uint8_t *dgst) { if (!ctx || !dgst) { - return 0; + error_print(); + return -1; } sha224_finish(&ctx->u.sha224_ctx, dgst); return 1; @@ -318,7 +343,8 @@ const DIGEST *DIGEST_sha224(void) static int sha256_digest_init(DIGEST_CTX *ctx) { if (!ctx) { - return 0; + error_print(); + return -1; } sha256_init(&ctx->u.sha256_ctx); return 1; @@ -327,7 +353,8 @@ static int sha256_digest_init(DIGEST_CTX *ctx) static int sha256_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen) { if (!ctx || (!in && inlen != 0)) { - return 0; + error_print(); + return -1; } sha256_update(&ctx->u.sha256_ctx, in, inlen); return 1; @@ -336,7 +363,8 @@ static int sha256_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen static int sha256_digest_finish(DIGEST_CTX *ctx, uint8_t *dgst) { if (!ctx || !dgst) { - return 0; + error_print(); + return -1; } sha256_finish(&ctx->u.sha256_ctx, dgst); return 1; @@ -361,7 +389,8 @@ const DIGEST *DIGEST_sha256(void) static int sha384_digest_init(DIGEST_CTX *ctx) { if (!ctx) { - return 0; + error_print(); + return -1; } sha384_init(&ctx->u.sha384_ctx); return 1; @@ -370,7 +399,8 @@ static int sha384_digest_init(DIGEST_CTX *ctx) static int sha384_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen) { if (!ctx || (!in && inlen != 0)) { - return 0; + error_print(); + return -1; } sha384_update(&ctx->u.sha384_ctx, in, inlen); return 1; @@ -379,7 +409,8 @@ static int sha384_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen static int sha384_digest_finish(DIGEST_CTX *ctx, uint8_t *dgst) { if (!ctx || !dgst) { - return 0; + error_print(); + return -1; } sha384_finish(&ctx->u.sha384_ctx, dgst); return 1; @@ -404,7 +435,8 @@ const DIGEST *DIGEST_sha384(void) static int sha512_digest_init(DIGEST_CTX *ctx) { if (!ctx) { - return 0; + error_print(); + return -1; } sha512_init(&ctx->u.sha512_ctx); return 1; @@ -413,7 +445,8 @@ static int sha512_digest_init(DIGEST_CTX *ctx) static int sha512_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen) { if (!ctx || (!in && inlen != 0)) { - return 0; + error_print(); + return -1; } sha512_update(&ctx->u.sha512_ctx, in, inlen); return 1; @@ -422,7 +455,8 @@ static int sha512_digest_update(DIGEST_CTX *ctx, const uint8_t *in, size_t inlen static int sha512_digest_finish(DIGEST_CTX *ctx, uint8_t *dgst) { if (!ctx || !dgst) { - return 0; + error_print(); + return -1; } sha512_finish(&ctx->u.sha512_ctx, dgst); return 1; @@ -448,7 +482,8 @@ static int sha512_224_digest_finish(DIGEST_CTX *ctx, uint8_t *dgst) { uint8_t buf[SHA512_DIGEST_SIZE]; if (!ctx || !dgst) { - return 0; + error_print(); + return -1; } sha512_finish(&ctx->u.sha512_ctx, buf); memcpy(dgst, buf, SHA224_DIGEST_SIZE); @@ -476,7 +511,8 @@ static int sha512_256_digest_finish(DIGEST_CTX *ctx, uint8_t *dgst) { uint8_t buf[SHA512_DIGEST_SIZE]; if (!ctx || !dgst) { - return 0; + error_print(); + return -1; } sha512_finish(&ctx->u.sha512_ctx, buf); memcpy(dgst, buf, SHA256_DIGEST_SIZE); diff --git a/src/hash_drbg.c b/src/hash_drbg.c index 20bcb214..3595254a 100644 --- a/src/hash_drbg.c +++ b/src/hash_drbg.c @@ -51,6 +51,7 @@ #include #include #include +#include #include "endian.h" static int hash_df(const DIGEST *digest, const uint8_t *in, size_t inlen, @@ -67,11 +68,11 @@ static int hash_df(const DIGEST *digest, const uint8_t *in, size_t inlen, PUTU32(outbits, (uint32_t)outlen << 3); while (outlen > 0) { - if (!digest_init(&ctx, digest) - || !digest_update(&ctx, &counter, sizeof(counter)) - || !digest_update(&ctx, outbits, sizeof(outbits)) - || !digest_update(&ctx, in, inlen) - || !digest_finish(&ctx, dgst, &len)) { + if (digest_init(&ctx, digest) != 1 + || digest_update(&ctx, &counter, sizeof(counter)) != 1 + || digest_update(&ctx, outbits, sizeof(outbits)) != 1 + || digest_update(&ctx, in, inlen) != 1 + || digest_finish(&ctx, dgst, &len) != 1) { goto end; } @@ -211,7 +212,7 @@ end: static void drbg_add(uint8_t *R, const uint8_t *A, size_t seedlen) { int temp = 0; - size_t i; + int i; for (i = seedlen - 1; i >= 0; i--) { temp += R[i] + A[i]; R[i] = temp & 0xff; @@ -222,7 +223,7 @@ static void drbg_add(uint8_t *R, const uint8_t *A, size_t seedlen) static void drbg_add1(uint8_t *R, size_t seedlen) { int temp = 1; - size_t i; + int i; for (i = seedlen - 1; i >= 0; i--) { temp += R[i]; R[i] = temp & 0xff; @@ -244,9 +245,9 @@ static int drbg_hashgen(HASH_DRBG *drbg, size_t outlen, uint8_t *out) while (outlen > 0) { /* output Hash(data) */ - if (!digest_init(&ctx, drbg->digest) - || !digest_update(&ctx, data, drbg->seedlen) - || !digest_finish(&ctx, dgst, &len)) { + if (digest_init(&ctx, drbg->digest) != 1 + || digest_update(&ctx, data, drbg->seedlen) != 1 + || digest_finish(&ctx, dgst, &len) != 1) { goto end; } @@ -288,11 +289,11 @@ int hash_drbg_generate(HASH_DRBG *drbg, if (additional) { /* w = Hash (0x02 || V || additional_input) */ prefix = 0x02; - if (!digest_init(&ctx, drbg->digest) - || !digest_update(&ctx, &prefix, 1) - || !digest_update(&ctx, drbg->V, drbg->seedlen) - || !digest_update(&ctx, additional, additional_len) - || !digest_finish(&ctx, dgst, &dgstlen)) { + if (digest_init(&ctx, drbg->digest) != 1 + || digest_update(&ctx, &prefix, 1) != 1 + || digest_update(&ctx, drbg->V, drbg->seedlen) != 1 + || digest_update(&ctx, additional, additional_len) != 1 + || digest_finish(&ctx, dgst, &dgstlen) != 1) { goto end; } @@ -307,10 +308,10 @@ int hash_drbg_generate(HASH_DRBG *drbg, /* H = Hash (0x03 || V). */ prefix = 0x03; - if (!digest_init(&ctx, drbg->digest) - || !digest_update(&ctx, &prefix, 1) - || !digest_update(&ctx, drbg->V, drbg->seedlen) - || !digest_finish(&ctx, dgst, &dgstlen)) { + if (digest_init(&ctx, drbg->digest) != 1 + || digest_update(&ctx, &prefix, 1) != 1 + || digest_update(&ctx, drbg->V, drbg->seedlen) != 1 + || digest_finish(&ctx, dgst, &dgstlen) != 1) { goto end; } diff --git a/src/hkdf.c b/src/hkdf.c index b44027ba..b7e7c87c 100644 --- a/src/hkdf.c +++ b/src/hkdf.c @@ -113,7 +113,7 @@ int hkdf_extract(const DIGEST *digest, const uint8_t *salt, size_t saltlen, } int hkdf_expand(const DIGEST *digest, const uint8_t *prk, size_t prklen, - const uint8_t *info, size_t infolen, + const uint8_t *opt_info, size_t opt_infolen, size_t L, uint8_t *okm) { HMAC_CTX hmac_ctx; @@ -123,7 +123,7 @@ int hkdf_expand(const DIGEST *digest, const uint8_t *prk, size_t prklen, if (L > 0) { if (hmac_init(&hmac_ctx, digest, prk, prklen) != 1 - || hmac_update(&hmac_ctx, info, infolen) != 1 + || hmac_update(&hmac_ctx, opt_info, opt_infolen) < 0 || hmac_update(&hmac_ctx, &counter, 1) != 1 || hmac_finish(&hmac_ctx, T, &len) != 1) { error_print(); @@ -144,7 +144,7 @@ int hkdf_expand(const DIGEST *digest, const uint8_t *prk, size_t prklen, } if (hmac_init(&hmac_ctx, digest, prk, prklen) != 1 || hmac_update(&hmac_ctx, T, len) != 1 - || hmac_update(&hmac_ctx, info, infolen) != 1 + || hmac_update(&hmac_ctx, opt_info, opt_infolen) < 0 || hmac_update(&hmac_ctx, &counter, 1) != 1 || hmac_finish(&hmac_ctx, T, &len) != 1) { error_print(); diff --git a/src/hmac.c b/src/hmac.c index d8b3f9cb..479fe2b6 100644 --- a/src/hmac.c +++ b/src/hmac.c @@ -97,10 +97,13 @@ int hmac_init(HMAC_CTX *ctx, const DIGEST *digest, const uint8_t *key, size_t ke int hmac_update(HMAC_CTX *ctx, const uint8_t *data, size_t datalen) { - if (!ctx || (!data && datalen != 0)) { + if (ctx == NULL) { error_print(); return -1; } + if (data == NULL || datalen == 0) { + return 0; + } if (digest_update(&ctx->digest_ctx, data, datalen) != 1) { error_print(); return -1; @@ -110,6 +113,10 @@ int hmac_update(HMAC_CTX *ctx, const uint8_t *data, size_t datalen) int hmac_finish(HMAC_CTX *ctx, uint8_t *mac, size_t *maclen) { + if (ctx == NULL || maclen == NULL) { + error_print(); + return -1; + } if (digest_finish(&ctx->digest_ctx, mac, maclen) != 1) { error_print(); return -1; diff --git a/src/tlcp.c b/src/tlcp.c index c7414e17..4929a9d8 100644 --- a/src/tlcp.c +++ b/src/tlcp.c @@ -764,7 +764,7 @@ int tlcp_accept(TLS_CONNECT *conn, int port, handshakeslen += recordlen - 5; } - if (handshakes) { + if (client_cacerts_fp) { tls_trace("<<<< ClientCertificate\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1 || tls_record_version(record) != TLS_version_tlcp) { @@ -812,7 +812,7 @@ int tlcp_accept(TLS_CONNECT *conn, int port, return -1; } - if (handshakes) { + if (client_cacerts_fp) { tls_trace("<<<< CertificateVerify\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1 || tls_record_version(record) != TLS_version_tlcp) { diff --git a/src/tls.c b/src/tls.c index 23328b26..729d911c 100644 --- a/src/tls.c +++ b/src/tls.c @@ -106,7 +106,7 @@ void tls_uint32_to_bytes(uint32_t a, uint8_t **out, size_t *outlen) void tls_array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen) { - if (out) { + if (*out) { memcpy(*out, data, datalen); *out += datalen; } @@ -810,16 +810,44 @@ int tls_record_set_handshake_server_hello(uint8_t *record, size_t *recordlen, uint8_t *p = record + 5 + 4; size_t len = 0; - if (!record || !recordlen || !tls_version_text(version) || !random - || (!session_id && session_id_len) || session_id_len > 32 - || (!exts && exts_len) || exts_len > 512) { + if (!record || !recordlen) { error_print(); return -1; } - if (record[0] != TLS_record_handshake - || !tls_version_text(version) - || !tls_cipher_suite_name(cipher_suite) - || version < tls_record_version(record)) { + if (tls_version_text(version) == NULL || random == NULL) { + error_print(); + return -1; + } + if (session_id != NULL) { + if (session_id_len <= 0 || session_id_len > 32) { + error_print(); + return -1; + } + } + if (exts && exts_len > 512) { + error_print(); + return -1; + } + + + if (record[0] != TLS_record_handshake) { + error_print(); + return -1; + } + if (!tls_version_text(version)) { + error_print(); + return -1; + } + if (!tls_cipher_suite_name(cipher_suite)) { + error_print(); + return -1; + } + if (version < tls_record_version(record)) { + + + printf("version = %d\n", version); + printf("version = %d\n", tls_record_version(record)); + error_print(); return -1; } @@ -1087,11 +1115,10 @@ int tls_certificate_chain_verify(const uint8_t *certs, size_t certslen, FILE *ca size_t certlen; const uint8_t *cacert; size_t cacertlen; - const uint8_t *subject; - size_t subject_len; - uint8_t rootcacert[1024]; - size_t rootcacertlen; - const char *signer_id = SM2_DEFAULT_ID; + const uint8_t *issuer; + size_t issuer_len; + uint8_t rootcert[4096]; + size_t rootcertlen; if (tls_uint24array_from_bytes(&cert, &certlen, &certs, &certslen) != 1) { error_print(); @@ -1102,19 +1129,19 @@ int tls_certificate_chain_verify(const uint8_t *certs, size_t certslen, FILE *ca error_print(); return -1; } - if (x509_cert_verify_by_ca_cert(cert, certlen, cacert, cacertlen, signer_id, strlen(signer_id)) != 1) { + if (x509_cert_verify_by_ca_cert(cert, certlen, cacert, cacertlen, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1) { error_print(); return -1; } cert = cacert; certlen = cacertlen; } - if (x509_cert_get_subject(cacert, cacertlen, &subject, &subject_len) != 1) { + if (x509_cert_get_issuer(cert, certlen, &issuer, &issuer_len) != 1) { error_print(); return -1; } - if (x509_cert_from_pem_by_subject(rootcacert, &rootcacertlen, sizeof(rootcacert), subject, subject_len, ca_certs_fp) != 1 - || x509_cert_verify_by_ca_cert(cert, certlen, cacert, cacertlen, signer_id, strlen(signer_id)) != 1) { + if (x509_cert_from_pem_by_subject(rootcert, &rootcertlen, sizeof(rootcert), issuer, issuer_len, ca_certs_fp) != 1 + || x509_cert_verify_by_ca_cert(cert, certlen, rootcert, rootcertlen, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1) { error_print(); return -1; } @@ -1539,7 +1566,7 @@ int tls_send(TLS_CONNECT *conn, const uint8_t *data, size_t datalen) seq_num = conn->server_seq_num; } - tls_trace(">>>> ApplicationData\n"); + tls_trace("send ApplicationData\n"); if (tls_record_set_version(mrec, conn->version) != 1 || tls_record_set_application_data(mrec, &mlen, data, datalen) != 1 || tls_record_encrypt(hmac_ctx, enc_key, seq_num, mrec, mlen, crec, &clen) != 1 @@ -1572,7 +1599,7 @@ int tls_recv(TLS_CONNECT *conn, uint8_t *data, size_t *datalen) seq_num = conn->client_seq_num; } - tls_trace("<<<< ApplicationData\n"); + tls_trace("recv ApplicationData\n"); if (tls_record_recv(crec, &clen, conn->sock) != 1 // FIXME: 检查版本号 || tls_record_decrypt(hmac_ctx, dec_key, seq_num, crec, clen, mrec, &mlen) != 1 diff --git a/src/tls12.c b/src/tls12.c index 4fbc7439..c0474a8f 100644 --- a/src/tls12.c +++ b/src/tls12.c @@ -80,12 +80,12 @@ static const uint8_t tls12_exts[] = { /* signature_algors */ 0x00,0x0D, 0x00,0x04, 0x00,0x02, 0x07,0x07,//0x08, // sm2sig_sm3 }; -/* int tls_server_extensions_check(const uint8_t *exts, size_t extslen) { - return -1; + return 1; } +/* int tls_construct_server_extensions(const uint8_t *client_exts, size_t client_exts_len, uint8_t *server_exts, size_t *server_exts_len) { @@ -228,63 +228,75 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, uint8_t finished[256]; size_t finishedlen; + // handshake int type; const uint8_t *data; size_t datalen; + // client_hello, server_hello uint8_t client_random[32]; uint8_t server_random[32]; - uint8_t exts[TLS_MAX_EXTENSIONS_SIZE]; - size_t exts_len; + uint8_t server_exts[TLS_MAX_EXTENSIONS_SIZE]; + size_t server_exts_len; - SM2_KEY server_sign_key; + SM2_KEY server_pub_key; SM2_SIGN_CTX verify_ctx; - SM2_SIGN_CTX sign_ctx; // for certificate_verify signature generation + SM2_SIGN_CTX sign_ctx; uint8_t sig[TLS_MAX_SIGNATURE_SIZE]; size_t siglen = sizeof(sig); - + // key_exchange int curve; SM2_KEY client_ecdh; SM2_POINT server_ecdh_public; - uint8_t pre_master_secret[64]; + // finished verify_data SM3_CTX sm3_ctx; SM3_CTX tmp_sm3_ctx; uint8_t sm3_hash[32]; uint8_t verify_data[12]; - uint8_t local_verify_data[12]; + uint8_t remote_verify_data[12]; struct sockaddr_in server; - server.sin_addr.s_addr = inet_addr(hostname); - server.sin_family = AF_INET; - server.sin_port = htons(port); sm3_init(&sm3_ctx); - if (client_sign_key) - sm2_sign_init(&sign_ctx, client_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH); + + + + server.sin_addr.s_addr = inet_addr(hostname); + server.sin_family = AF_INET; + server.sin_port = htons(port); + if ((conn->sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + error_print(); + return -1; + } + if (connect(conn->sock, (struct sockaddr *)&server , sizeof(server)) < 0) { + error_print(); + return -1; + } + conn->is_client = 1; + + + if (client_certs_fp) { + if (!client_sign_key) { + error_print(); + return -1; + } + if (sm2_sign_init(&sign_ctx, client_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1) { + error_print(); + return -1; + } + } + tls_record_set_version(record, TLS_version_tls1); tls_record_set_version(finished, TLS_version_tls12); - if ((conn->sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - error_print(); - return -1; - } - - if (connect(conn->sock, (struct sockaddr *)&server , sizeof(server)) < 0) { - error_print(); - return -1; - } - - conn->is_client = 1; - - - tls_trace(">>>> ClientHello\n"); + tls_trace("send ClientHello\n"); tls_random_generate(client_random); if (tls_record_set_handshake_client_hello(record, &recordlen, TLS_version_tls12, client_random, NULL, 0, @@ -301,7 +313,7 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, if (client_sign_key) sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); - tls_trace("<<<< ServerHello\n"); + tls_trace("recv ServerHello\n"); if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { error_print(); return -1; @@ -309,7 +321,7 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, tls_record_print(stderr, record, recordlen, 0, 0); if (tls_record_get_handshake_server_hello(record, &conn->version, server_random, conn->session_id, &conn->session_id_len, - &conn->cipher_suite, exts, &exts_len) != 1) { + &conn->cipher_suite, server_exts, &server_exts_len) != 1) { error_print(); return -1; } @@ -321,12 +333,17 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, error_print(); return -1; } - // FIXME: check extensions + if (tls_server_extensions_check(server_exts, server_exts_len) != 1) { + error_print(); + return -1; + } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_sign_key) + if (client_certs_fp) { sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); + } - tls_trace("<<<< ServerCertificate\n"); + + tls_trace("recv ServerCertificate\n"); if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { error_print(); return -1; @@ -338,30 +355,32 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, } /* - // FIXME: review cert chain verification + // FIXME: Segmentation fault! if (tls_certificate_chain_verify(conn->server_certs, conn->server_certs_len, ca_certs_fp, 5) != 1) { error_print(); return -1; } */ if (tls_certificate_get_public_keys(conn->server_certs, conn->server_certs_len, - &server_sign_key, NULL) != 1) { + &server_pub_key, NULL) != 1) { error_print(); return -1; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_sign_key) + if (client_certs_fp) { sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); + } - tls_trace("<<<< ServerKeyExchange\n"); + tls_trace("recv ServerKeyExchange\n"); if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { error_print(); return -1; } tls_record_print(stderr, record, recordlen, conn->cipher_suite << 8, 0); sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_sign_key) + if (client_certs_fp) { sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); + } if (tls_record_get_handshake_server_key_exchange_ecdhe(record, &curve, &server_ecdh_public, sig, &siglen) != 1) { error_print(); @@ -371,49 +390,36 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, error_print(); return -1; } - if (tls_verify_server_ecdh_params(&server_sign_key, + if (tls_verify_server_ecdh_params(&server_pub_key, client_random, server_random, curve, &server_ecdh_public, sig, siglen) != 1) { error_print(); return -1; } - tls_trace("++++ generate secrets\n"); - sm2_key_generate(&client_ecdh); - sm2_ecdh(&client_ecdh, &server_ecdh_public, &server_ecdh_public); - memcpy(pre_master_secret, &server_ecdh_public, 32); - - - tls_prf(pre_master_secret, 32, "master secret", - client_random, 32, - server_random, 32, - 48, conn->master_secret); - tls_prf(conn->master_secret, 48, "key expansion", - server_random, 32, - client_random, 32, - 96, conn->key_block); - sm3_hmac_init(&conn->client_write_mac_ctx, conn->key_block, 32); - sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32); - sm4_set_encrypt_key(&conn->client_write_enc_key, conn->key_block + 64); - sm4_set_decrypt_key(&conn->server_write_enc_key, conn->key_block + 80); - tls_secrets_print(stderr, pre_master_secret, 32, client_random, server_random, - conn->master_secret, conn->key_block, 96, 0, 0); if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { error_print(); return -1; } - tls_record_print(stderr, record, recordlen, 0, 0); if (tls_record_get_handshake(record, &type, &data, &datalen) != 1) { error_print(); return -1; } if (type == TLS_handshake_certificate_request) { - tls_trace("<<<< CertificateRequest\n"); int cert_types[TLS_MAX_CERTIFICATE_TYPES]; size_t cert_types_count;; uint8_t ca_names[TLS_MAX_CA_NAMES_SIZE]; size_t ca_names_len; + + tls_trace("recv CertificateRequest\n"); + tls_record_print(stderr, record, recordlen, 0, 0); + if (!client_certs_fp) { + // 这里应该响应一个Alert吧? + error_print(); + return -1; + } + if (tls_record_get_handshake_certificate_request(record, cert_types, &cert_types_count, ca_names, &ca_names_len) != 1) { @@ -421,8 +427,7 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, return -1; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_sign_key) - sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); + sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); if (tls_record_recv(record, &recordlen, conn->sock) != 1) { error_print(); @@ -432,19 +437,20 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, memset(&sign_ctx, 0, sizeof(SM2_SIGN_CTX)); client_sign_key = NULL; } - tls_trace("<<<< ServerHelloDone\n"); + tls_trace("recv ServerHelloDone\n"); tls_record_print(stderr, record, recordlen, 0, 0); if (tls_record_get_handshake_server_hello_done(record) != 1) { error_print(); return -1; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - - - if (client_sign_key) { + if (client_certs_fp) { sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); + } - tls_trace(">>>> ClientCertificate\n"); + + if (client_certs_fp) { + tls_trace("send ClientCertificate\n"); if (tls_record_set_handshake_certificate_from_pem(record, &recordlen, client_certs_fp) != 1) { error_print(); return -1; @@ -459,10 +465,28 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, } - tls_trace(">>>> ClientKeyExchange\n"); + tls_trace("generate secrets\n"); + sm2_key_generate(&client_ecdh); + sm2_ecdh(&client_ecdh, &server_ecdh_public, &server_ecdh_public); + memcpy(pre_master_secret, &server_ecdh_public, 32); + + tls_prf(pre_master_secret, 32, "master secret", + client_random, 32, + server_random, 32, + 48, conn->master_secret); + tls_prf(conn->master_secret, 48, "key expansion", + server_random, 32, + client_random, 32, + 96, conn->key_block); + sm3_hmac_init(&conn->client_write_mac_ctx, conn->key_block, 32); + sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32); + sm4_set_encrypt_key(&conn->client_write_enc_key, conn->key_block + 64); + sm4_set_decrypt_key(&conn->server_write_enc_key, conn->key_block + 80); + tls_secrets_print(stderr, pre_master_secret, 32, client_random, server_random, + conn->master_secret, conn->key_block, 96, 0, 4); - // 客户端的临时公钥 + tls_trace("send ClientKeyExchange\n"); if (tls_record_set_handshake_client_key_exchange_ecdhe(record, &recordlen, &client_ecdh.public_key) != 1) { error_print(); @@ -474,12 +498,13 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, return -1; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (client_sign_key) - sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); - if (client_sign_key) { - tls_trace(">>>> CertificateVerify\n"); + + if (client_certs_fp) { + tls_trace("send CertificateVerify\n"); + sm2_sign_update(&sign_ctx, record + 5, recordlen - 5); sm2_sign_finish(&sign_ctx, sig, &siglen); + memset(&sign_ctx, 0, sizeof(sign_ctx)); if (tls_record_set_handshake_certificate_verify(record, &recordlen, sig, siglen) != 1) { error_print(); return -1; @@ -492,7 +517,8 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, sm3_update(&sm3_ctx, record + 5, recordlen - 5); } - tls_trace(">>>> [ChangeCipherSpec]\n"); + + tls_trace("send [ChangeCipherSpec]\n"); if (tls_record_set_change_cipher_spec(record, &recordlen) !=1) { error_print(); return -1; @@ -503,7 +529,8 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, } tls_record_print(stderr, record, recordlen, 0, 0); - tls_trace(">>>> Finished\n"); + + tls_trace("send Finished\n"); memcpy(&tmp_sm3_ctx, &sm3_ctx, sizeof(sm3_ctx)); sm3_finish(&tmp_sm3_ctx, sm3_hash); @@ -528,7 +555,7 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, return -1; } - tls_trace("<<<< [ChangeCipherSpec]\n"); + tls_trace("recv [ChangeCipherSpec]\n"); if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { error_print(); return -1; @@ -539,7 +566,7 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, return -1; } - tls_trace("<<<< Finished\n"); + tls_trace("recv Finished\n"); if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { error_print(); return -1; @@ -551,31 +578,38 @@ int tls12_connect(TLS_CONNECT *conn, const char *hostname, int port, } tls_record_print(stderr, finished, finishedlen, 0, 0); tls_seq_num_incr(conn->server_seq_num); - if (tls_record_get_handshake_finished(finished, verify_data) != 1) { + if (tls_record_get_handshake_finished(finished, remote_verify_data) != 1) { error_print(); return -1; } sm3_finish(&sm3_ctx, sm3_hash); tls_prf(conn->master_secret, 48, "server finished", sm3_hash, 32, NULL, 0, - 12, local_verify_data); - if (memcmp(local_verify_data, verify_data, 12) != 0) { + 12, verify_data); + if (memcmp(verify_data, remote_verify_data, 12) != 0) { error_puts("server_finished.verify_data verification failure"); return -1; } - tls_trace("++++ Connection established\n"); + tls_trace("SSL Connection Established\n\n"); // 这里应该把协商的参数打印出来 return 1; } // 实际上我们需要好几个比较大的buffer // 一个是记录的buffer -// 还有就是server端需要一个握手buffer +// 还有就是server端需要一个握手buffer,这是啥? + +/* +常规情况下服务器和客户端对所有的握手消息计算哈希值,最后用于Finished消息 +但是如果服务器要求客户端提供客户端证书,那么就必须要验证客户端证书 +*/ int tls12_accept(TLS_CONNECT *conn, int port, FILE *server_certs_fp, const SM2_KEY *server_sign_key, - FILE *client_cacerts_fp, uint8_t *handshakes_buf, size_t handshakes_buflen) + FILE *client_cacerts_fp, + uint8_t *handshakes_buf, + size_t handshakes_buflen) { uint8_t *handshakes = handshakes_buf; size_t handshakeslen = 0; @@ -645,7 +679,7 @@ int tls12_accept(TLS_CONNECT *conn, int port, sm3_init(&sm3_ctx); - tls_trace("<<<< ClientHello\n"); + tls_trace("recv ClientHello\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1) { error_print(); return -1; @@ -678,15 +712,9 @@ int tls12_accept(TLS_CONNECT *conn, int port, } sm3_update(&sm3_ctx, record + 5, recordlen - 5); tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); - if (handshakes) { - /* - memcpy(handshakes, record + 5, recordlen - 5); - handshakes += recordlen - 5; - handshakeslen += recordlen - 5; - */ - } - tls_trace(">>>> ServerHello\n"); + + tls_trace("send ServerHello\n"); tls_random_generate(server_random); tls_record_set_version(record, conn->version); if (tls_record_set_handshake_server_hello(record, &recordlen, @@ -701,13 +729,10 @@ int tls12_accept(TLS_CONNECT *conn, int port, return -1; } sm3_update(&sm3_ctx, record + 5, recordlen - 5); - if (handshakes) { - memcpy(handshakes, record + 5, recordlen - 5); - handshakes += recordlen - 5; - handshakeslen += recordlen - 5; - } + tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); - tls_trace(">>>> ServerCertificate\n"); + + tls_trace("send ServerCertificate\n"); if (tls_record_set_handshake_certificate_from_pem(record, &recordlen, server_certs_fp) != 1) { error_print(); return -1; @@ -723,15 +748,9 @@ int tls12_accept(TLS_CONNECT *conn, int port, } sm3_update(&sm3_ctx, record + 5, recordlen - 5); tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); - if (handshakes) { - /* - memcpy(handshakes, record + 5, recordlen - 5); - handshakes += recordlen - 5; - handshakeslen += recordlen - 5; - */ - } - tls_trace(">>>> ServerKeyExchange\n"); + + tls_trace("send ServerKeyExchange\n"); sm2_key_generate(&server_ecdh); if (tls_sign_server_ecdh_params(server_sign_key, client_random, server_random, @@ -751,16 +770,11 @@ int tls12_accept(TLS_CONNECT *conn, int port, } sm3_update(&sm3_ctx, record + 5, recordlen - 5); tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); - if (handshakes) { - /* - memcpy(handshakes, record + 5, recordlen - 5); - handshakes += recordlen - 5; - handshakeslen += recordlen - 5; - */ - } + + if (client_cacerts_fp) { - tls_trace(">>>> CertificateRequest\n"); + tls_trace("send CertificateRequest\n"); const int cert_types[] = { TLS_cert_type_ecdsa_sign, }; uint8_t ca_names[TLS_MAX_CA_NAMES_SIZE] = {0}; size_t cert_types_count = sizeof(cert_types)/sizeof(cert_types[0]); @@ -778,16 +792,9 @@ int tls12_accept(TLS_CONNECT *conn, int port, } sm3_update(&sm3_ctx, record + 5, recordlen - 5); tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); - if (handshakes) { - /* - memcpy(handshakes, record + 5, recordlen - 5); - handshakes += recordlen - 5; - handshakeslen += recordlen - 5; - */ - } } - tls_trace(">>>> ServerHelloDone\n"); + tls_trace("send ServerHelloDone\n"); if (tls_record_set_handshake_server_hello_done(record, &recordlen) != 1) { error_print(); return -1; @@ -799,16 +806,10 @@ int tls12_accept(TLS_CONNECT *conn, int port, } sm3_update(&sm3_ctx, record + 5, recordlen - 5); tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); - if (handshakes) { - /* - memcpy(handshakes, record + 5, recordlen - 5); - handshakes += recordlen - 5; - handshakeslen += recordlen - 5; - */ - } - if (handshakes) { - tls_trace("<<<< ClientCertificate\n"); + + if (client_cacerts_fp) { + tls_trace("recv ClientCertificate\n"); if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { error_print(); return -1; @@ -824,6 +825,7 @@ int tls12_accept(TLS_CONNECT *conn, int port, return -1; } // FIXME: verify client's certificate with ca certs + // 拿到客户端公钥之后就可以开始准备sm2_verify_init 了 if (tls_certificate_get_public_keys(conn->client_certs, conn->client_certs_len, &client_sign_key, NULL) != 1) { error_print(); @@ -831,14 +833,9 @@ int tls12_accept(TLS_CONNECT *conn, int port, } sm3_update(&sm3_ctx, record + 5, recordlen - 5); tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); - /* - memcpy(handshakes, record + 5, recordlen - 5); - handshakes += recordlen - 5; - handshakeslen += recordlen - 5; - */ } - tls_trace("<<<< ClientKeyExchange\n"); + tls_trace("recv ClientKeyExchange\n"); if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { error_print(); return -1; @@ -850,15 +847,10 @@ int tls12_accept(TLS_CONNECT *conn, int port, } sm3_update(&sm3_ctx, record + 5, recordlen - 5); tls_array_to_bytes(record + 5, recordlen - 5, &handshakes, &handshakeslen); - if (handshakes) { - /* - memcpy(handshakes, record + 5, recordlen - 5); - handshakes += recordlen - 5; - handshakeslen += recordlen - 5; - */ - } - tls_trace("++++ generate secrets\n"); + + + tls_trace("generate secrets\n"); sm2_ecdh(&server_ecdh, &client_ecdh_public, (SM2_POINT *)pre_master_secret); tls_prf(pre_master_secret, 32, "master secret", client_random, 32, server_random, 32, @@ -871,11 +863,11 @@ int tls12_accept(TLS_CONNECT *conn, int port, sm4_set_decrypt_key(&conn->client_write_enc_key, conn->key_block + 64); sm4_set_encrypt_key(&conn->server_write_enc_key, conn->key_block + 80); tls_secrets_print(stderr, pre_master_secret, 32, client_random, server_random, - conn->master_secret, conn->key_block, 96, 0, 0); + conn->master_secret, conn->key_block, 96, 0, 4); - if (handshakes) { - tls_trace("<<<< CertificateVerify\n"); + if (client_cacerts_fp) { + tls_trace("recv CertificateVerify\n"); if (tls_record_recv(record, &recordlen, conn->sock) != 1 || tls_record_version(record) != TLS_version_tls12) { error_print(); @@ -895,7 +887,7 @@ int tls12_accept(TLS_CONNECT *conn, int port, } } - tls_trace("<<<< [ChangeCipherSpec]\n"); + tls_trace("recv [ChangeCipherSpec]\n"); if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { error_print(); return -1; @@ -906,7 +898,7 @@ int tls12_accept(TLS_CONNECT *conn, int port, return -1; } - tls_trace("<<<< ClientFinished\n"); + tls_trace("recv ClientFinished\n"); if (tls12_record_recv(record, &recordlen, conn->sock) != 1) { error_print(); return -1; @@ -934,7 +926,7 @@ int tls12_accept(TLS_CONNECT *conn, int port, return -1; } - tls_trace(">>>> [ChangeCipherSpec]\n"); + tls_trace("send [ChangeCipherSpec]\n"); if (tls_record_set_change_cipher_spec(record, &recordlen) != 1) { error_print(); return -1; @@ -945,7 +937,7 @@ int tls12_accept(TLS_CONNECT *conn, int port, return -1; } - tls_trace(">>>> ServerFinished\n"); + tls_trace("send ServerFinished\n"); sm3_finish(&sm3_ctx, sm3_hash); tls_prf(conn->master_secret, 48, "server finished", sm3_hash, 32, NULL, 0, @@ -966,6 +958,6 @@ int tls12_accept(TLS_CONNECT *conn, int port, return -1; } - tls_trace("Connection Established!\n\n"); + tls_trace("SSL Connection Established\n\n"); return 1; } diff --git a/src/tls_trace.c b/src/tls_trace.c index 7aeb0361..699bfc56 100644 --- a/src/tls_trace.c +++ b/src/tls_trace.c @@ -355,7 +355,7 @@ int tls_random_print(FILE *fp, const uint8_t random[32], int format, int indent) format_print(fp, format, indent, "Random\n"); indent += 4; format_print(fp, format, indent, "gmt_unix_time : %s", ctime(&gmt_unix_time)); - format_bytes(fp, format, indent, "random : ", random + 4, 28); + format_bytes(fp, format, indent, "random", random + 4, 28); return 1; } @@ -365,7 +365,7 @@ int tls_pre_master_secret_print(FILE *fp, const uint8_t pre_master_secret[48], i format_print(fp, format, indent, "PreMasterSecret\n"); indent += 4; format_print(fp, format, indent, "version : %s\n", tls_version_text(version)); - format_bytes(fp, format, indent, "pre_master_secret : ", pre_master_secret, 48); + format_bytes(fp, format, indent, "pre_master_secret", pre_master_secret, 48); return 1; } @@ -435,13 +435,13 @@ int tls_extension_print(FILE *fp, int type, const uint8_t *data, size_t datalen, error_print(); return -1; } - format_print(fp, format, indent, "group : %s\n", tls_named_curve_name(group)); - format_bytes(fp, format, indent, "key_exchange : ", key_exch, key_exch_len); + format_print(fp, format, indent, "group: %s\n", tls_named_curve_name(group)); + format_bytes(fp, format, indent, "key_exchange", key_exch, key_exch_len); } break; default: - format_bytes(fp, format, indent, "raw_data : ", data, datalen); + format_bytes(fp, format, indent, "raw_data", data, datalen); } return 1; } @@ -482,12 +482,12 @@ int tls_client_hello_print(FILE *fp, const uint8_t *data, size_t datalen, int fo format_print(fp, format, indent, "ClientHello\n"); indent += 4; if (tls_uint16_from_bytes((uint16_t *)&version, &data, &datalen) != 1) goto end; - format_print(fp, format, indent, "Version : %s (%d.%d)\n", + format_print(fp, format, indent, "Version: %s (%d.%d)\n", tls_version_text(version), version >> 8, version & 0xff); if (tls_array_from_bytes(&random, 32, &data, &datalen) != 1) goto end; tls_random_print(fp, random, format, indent); if (tls_uint8array_from_bytes(&session_id, &session_id_len, &data, &datalen) != 1) goto end; - format_bytes(fp, format, indent, "SessionID : ", session_id, session_id_len); + format_bytes(fp, format, indent, "SessionID", session_id, session_id_len); if (tls_uint16array_from_bytes(&cipher_suites, &cipher_suites_len, &data, &datalen) != 1) goto end; format_print(fp, format, indent, "CipherSuites\n"); while (cipher_suites_len >= 2) { @@ -533,17 +533,17 @@ int tls_server_hello_print(FILE *fp, const uint8_t *data, size_t datalen, int fo format_print(fp, format, indent, "ServerHello\n"); indent += 4; if (tls_uint16_from_bytes(&version, &data, &datalen) != 1) goto bad; - format_print(fp, format, indent, "Version : %s (%d.%d)\n", + format_print(fp, format, indent, "Version: %s (%d.%d)\n", tls_version_text(version), version >> 8, version & 0xff); if (tls_array_from_bytes(&random, 32, &data, &datalen) != 1) goto bad; tls_random_print(fp, random, format, indent); if (tls_uint8array_from_bytes(&session_id, &session_id_len, &data, &datalen) != 1) goto bad; - format_bytes(fp, format, indent, "SessionID : ", session_id, session_id_len); + format_bytes(fp, format, indent, "SessionID", session_id, session_id_len); if (tls_uint16_from_bytes(&cipher_suite, &data, &datalen) != 1) goto bad; - format_print(fp, format, indent, "CipherSuite : %s (0x%04x)\n", + format_print(fp, format, indent, "CipherSuite: %s (0x%04x)\n", tls_cipher_suite_name(cipher_suite), cipher_suite); if (tls_uint8_from_bytes(&comp_meth, &data, &datalen) != 1) goto bad; - format_print(fp, format, indent, "CompressionMethod : %s (%d)\n", + format_print(fp, format, indent, "CompressionMethod: %s (%d)\n", tls_compression_method_name(comp_meth), comp_meth); if (datalen > 0) { if (tls_uint16array_from_bytes(&exts, &exts_len, &data, &datalen) != 1) goto bad; @@ -590,9 +590,9 @@ int tlcp_server_key_exchange_pke_print(FILE *fp, const uint8_t *data, size_t dat error_print(); } */ - format_print(fp, format, indent, "ServerKeyExchange:\n"); + format_print(fp, format, indent, "ServerKeyExchange\n"); indent += 4; - format_bytes(fp, format, indent, "signature : ", sig, siglen); + format_bytes(fp, format, indent, "signature", sig, siglen); return 1; } @@ -615,30 +615,30 @@ int tls_server_key_exchange_ecdhe_print(FILE *fp, const uint8_t *data, size_t da error_print(); return -1; } - format_print(fp, format, indent + 8, "curve_type : %s (%d)\n", + format_print(fp, format, indent + 8, "curve_type: %s (%d)\n", tls_curve_type_name(curve_type), curve_type); if (tls_uint16_from_bytes(&curve, &data, &datalen) != 1) { error_print(); return -1; } - format_print(fp, format, indent + 8, "named_curve : %s (04%04x)\n", + format_print(fp, format, indent + 8, "named_curve: %s (04%04x)\n", tls_named_curve_name(curve), curve); if (tls_uint8array_from_bytes(&octets, &octetslen, &data, &datalen) != 1) { error_print(); return -1; } - format_bytes(fp, format, indent + 4, "point : ", octets, octetslen); + format_bytes(fp, format, indent + 4, "point", octets, octetslen); if (tls_uint16_from_bytes(&sig_alg, &data, &datalen) != 1) { error_print(); return -1; } - format_print(fp, format, indent, "SignatureScheme : %s (04%04x)\n", + format_print(fp, format, indent, "SignatureScheme: %s (04%04x)\n", tls_signature_scheme_name(sig_alg), sig_alg); if (tls_uint16array_from_bytes(&sig, &siglen, &data, &datalen) != 1) { error_print(); return -1; } - format_bytes(fp, format, indent, "Siganture : ", sig, siglen); + format_bytes(fp, format, indent, "Siganture", sig, siglen); if (datalen > 0) { error_print(); return -1; @@ -649,6 +649,7 @@ int tls_server_key_exchange_ecdhe_print(FILE *fp, const uint8_t *data, size_t da int tls_server_key_exchange_print(FILE *fp, const uint8_t *data, size_t datalen, int format, int indent) { int cipher_suite = (format >> 8) & 0xffff; + switch (cipher_suite) { case TLCP_cipher_ecc_sm4_cbc_sm3: case TLCP_cipher_ecc_sm4_gcm_sm3: @@ -687,7 +688,7 @@ int tls_certificate_request_print(FILE *fp, const uint8_t *data, size_t datalen, format_print(fp, format, indent + 4, "%s\n", tls_cert_type_name(*cert_types++)); } if (tls_uint16array_from_bytes(&ca_names, &ca_names_len, &data, &datalen) != 1) goto bad; - format_bytes(fp, format, indent, "CAnames : ", ca_names, ca_names_len); + format_bytes(fp, format, indent, "CAnames", ca_names, ca_names_len); return 1; bad: error_print(); @@ -712,7 +713,7 @@ int tls_client_key_exchange_pke_print(FILE *fp, const uint8_t *data, size_t data error_print(); return -1; } - format_bytes(fp, format, indent, "EncryptedPreMasterSecret : ", enced_pms, enced_pms_len); + format_bytes(fp, format, indent, "EncryptedPreMasterSecret", enced_pms, enced_pms_len); return 1; } @@ -728,7 +729,7 @@ int tls_client_key_exchange_ecdhe_print(FILE *fp, const uint8_t *data, size_t da error_print(); return -1; } - format_bytes(fp, format, indent, "ecdh_Yc : ", octets, octetslen); + format_bytes(fp, format, indent, "ecdh_Yc", octets, octetslen); if (datalen > 0) { error_print(); return -1; @@ -766,8 +767,8 @@ int tls_client_key_exchange_print(FILE *fp, const uint8_t *data, size_t datalen, int tls_certificate_verify_print(FILE *fp, const uint8_t *data, size_t datalen, int format, int indent) { - format_print(fp, format, indent, "CertificateVerify :\n"); - format_bytes(fp, format, indent + 4, "Signature : ", data, datalen); + format_print(fp, format, indent, "CertificateVerify\n"); + format_bytes(fp, format, indent + 4, "Signature", data, datalen); return 1; } @@ -775,7 +776,7 @@ int tls_finished_print(FILE *fp, const uint8_t *data, size_t datalen, int format { format_print(fp, format, indent, "Finished\n"); indent += 4; - format_bytes(fp, format, indent, "verify_data : ", data, datalen); + format_bytes(fp, format, indent, "verify_data", data, datalen); return 1; } @@ -793,12 +794,12 @@ int tls_handshake_print(FILE *fp, const uint8_t *handshake, size_t handshakelen, error_print(); return -1; } - format_print(fp, format, indent, "Type : %s (%d)\n", tls_handshake_type_name(type), type); + format_print(fp, format, indent, "Type: %s (%d)\n", tls_handshake_type_name(type), type); if (tls_uint24_from_bytes((uint24_t *)&datalen, &cp, &handshakelen) != 1) { error_print(); return -1; } - format_print(fp, format, indent, "Length : %zu\n", datalen); + format_print(fp, format, indent, "Length: %zu\n", datalen); if (tls_array_from_bytes(&data, datalen, &cp, &handshakelen) != 1) { error_print(); @@ -845,10 +846,10 @@ int tls_alert_print(FILE *fp, const uint8_t *data, size_t datalen, int format, i error_print(); return -1; } - format_print(fp, format, indent, "Alert :\n"); + format_print(fp, format, indent, "Alert:\n"); indent += 4; - format_print(fp, format, indent, "Level : %s (%d)\n", tls_alert_level_name(data[0]), data[0]); - format_print(fp, format, indent, "Reason : %s (%d)\n", tls_alert_description_text(data[1]), data[1]); + format_print(fp, format, indent, "Level: %s (%d)\n", tls_alert_level_name(data[0]), data[0]); + format_print(fp, format, indent, "Reason: %s (%d)\n", tls_alert_description_text(data[1]), data[1]); return 1; } @@ -866,7 +867,7 @@ int tls_change_cipher_spec_print(FILE *fp, const uint8_t *data, size_t datalen, int tls_application_data_print(FILE *fp, const uint8_t *data, size_t datalen, int format, int indent) { - format_bytes(fp, format, indent, "ApplicationData : ", data, datalen); + format_bytes(fp, format, indent, "ApplicationData", data, datalen); return 1; } @@ -882,9 +883,9 @@ int tls_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int for } version = tls_record_version(record); format_print(fp, format, indent, "Record\n"); indent += 4; - format_print(fp, format, indent, "ContentType : %s (%d)\n", tls_record_type_name(record[0]), record[0]); - format_print(fp, format, indent, "Version : %s (%d.%d)\n", tls_version_text(version), version >> 8, version & 0xff); - format_print(fp, format, indent, "Length : %d\n", tls_record_length(record)); + format_print(fp, format, indent, "ContentType: %s (%d)\n", tls_record_type_name(record[0]), record[0]); + format_print(fp, format, indent, "Version: %s (%d.%d)\n", tls_version_text(version), version >> 8, version & 0xff); + format_print(fp, format, indent, "Length: %d\n", tls_record_length(record)); data = record + 5; datalen = recordlen - 5; @@ -929,14 +930,14 @@ int tls_secrets_print(FILE *fp, const uint8_t *key_block, size_t key_block_len, int format, int indent) { - format_bytes(stderr, format, indent, "pre_master_secret : ", pre_master_secret, pre_master_secret_len); - format_bytes(stderr, format, indent, "client_random : ", client_random, 32); - format_bytes(stderr, format, indent, "server_random : ", server_random, 32); - format_bytes(stderr, format, indent, "master_secret : ", master_secret, 48); - format_bytes(stderr, format, indent, "client_write_mac_key : ", key_block, 32); - format_bytes(stderr, format, indent, "server_write_mac_key : ", key_block + 32, 32); - format_bytes(stderr, format, indent, "client_write_enc_key : ", key_block + 64, 16); - format_bytes(stderr, format, indent, "server_write_enc_key : ", key_block + 80, 16); + format_bytes(stderr, format, indent, "pre_master_secret", pre_master_secret, pre_master_secret_len); + format_bytes(stderr, format, indent, "client_random", client_random, 32); + format_bytes(stderr, format, indent, "server_random", server_random, 32); + format_bytes(stderr, format, indent, "master_secret", master_secret, 48); + format_bytes(stderr, format, indent, "client_write_mac_key", key_block, 32); + format_bytes(stderr, format, indent, "server_write_mac_key", key_block + 32, 32); + format_bytes(stderr, format, indent, "client_write_enc_key", key_block + 64, 16); + format_bytes(stderr, format, indent, "server_write_enc_key", key_block + 80, 16); format_print(stderr, format, indent, "\n"); return 1; } diff --git a/src/x509_ext.c b/src/x509_ext.c index 3029c9c4..96684833 100644 --- a/src/x509_ext.c +++ b/src/x509_ext.c @@ -717,7 +717,7 @@ int x509_key_usage_from_name(int *flag, const char *name) int i; for (i = 0; i < x509_key_usages_count; i++) { if (strcmp(name, x509_key_usages[i]) == 0) { - *flag = i; + *flag = 1 << i; return 1; } } diff --git a/tests/hash_drbgtest.c b/tests/hash_drbgtest.c index 3a303bd2..23adde47 100644 --- a/tests/hash_drbgtest.c +++ b/tests/hash_drbgtest.c @@ -53,6 +53,7 @@ #include #include #include +#include #define EntropyInput "212956390783381dbfc6362dd0da9a09" @@ -114,10 +115,10 @@ int main(void) printf("ok\n"); } - hash_drbg_reseed(&drbg, pr1, pr1_len, NULL, 0); hash_drbg_generate(&drbg, NULL, 0, 640/8, out); + hash_drbg_reseed(&drbg, pr2, pr2_len, NULL, 0); hash_drbg_generate(&drbg, NULL, 0, 640/8, out); diff --git a/tests/tlstest.c b/tests/tlstest.c index 67ccbf60..a6d09dc8 100644 --- a/tests/tlstest.c +++ b/tests/tlstest.c @@ -97,6 +97,7 @@ static int test_tls_encode(void) return 1; } + printf("%s() ok\n", __FUNCTION__); return 0; } @@ -128,9 +129,7 @@ static int test_tls_cbc(void) tls_cbc_decrypt(&hmac_ctx, &sm4_key, seq_num, header, out, len, buf, &buflen); printf("%s\n", buf); - - - return 1; + return 0; } static int test_tls_random(void) @@ -138,6 +137,8 @@ static int test_tls_random(void) uint8_t random[32]; tls_random_generate(random); tls_random_print(stdout, random, 0, 0); + + printf("%s() ok\n", __FUNCTION__); return 0; } @@ -148,7 +149,7 @@ static int test_tls_client_hello(void) int version = TLS_version_tlcp; uint8_t random[32]; - uint16_t cipher_suites[] = { + int cipher_suites[] = { TLCP_cipher_ecc_sm4_cbc_sm3, TLCP_cipher_ecc_sm4_gcm_sm3, TLCP_cipher_ecdhe_sm4_cbc_sm3, @@ -162,16 +163,20 @@ static int test_tls_client_hello(void) TLCP_cipher_rsa_sm4_cbc_sha256, TLCP_cipher_rsa_sm4_gcm_sha256, }; - uint8_t comp_meths[] = {0}; + int comp_meths[] = {0}; - tls_record_set_handshake_client_hello(record, &recordlen, + if (tls_record_set_handshake_client_hello(record, &recordlen, version, random, NULL, 0, - cipher_suites, sizeof(cipher_suites)/2, - NULL, 0); - + cipher_suites, sizeof(cipher_suites)/sizeof(cipher_suites[0]), + NULL, 0) != 1) { + error_print(); + return -1; + } tls_client_hello_print(stdout, record + 5 + 4, recordlen - 5 -4, 0, 4); + + printf("%s() ok\n", __FUNCTION__); return 0; } @@ -180,20 +185,23 @@ static int test_tls_server_hello(void) uint8_t record[512]; size_t recordlen = 0; - - uint8_t version[2] = {1,1}; uint8_t random[32]; uint16_t cipher_suite = TLCP_cipher_ecdhe_sm4_cbc_sm3; - tls_record_set_handshake_server_hello(record, &recordlen, - version, + + tls_record_set_version(record, TLS_version_tlcp); + if (tls_record_set_handshake_server_hello(record, &recordlen, + TLS_version_tlcp, random, NULL, 0, cipher_suite, - NULL, 0); - + NULL, 0) != 1) { + error_print(); + return -1; + } tls_server_hello_print(stdout, record + 5 + 4, recordlen - 5 -4, 0, 0); + printf("%s() ok\n", __FUNCTION__); return 0; } @@ -203,7 +211,10 @@ static int test_tls_certificate(void) size_t recordlen = 0; FILE *fp = NULL; - if (!(fp = fopen("cacerts.pem", "r"))) { + // 测试函数不要有外部的依赖 + + /* + if (!(fp = fopen("cacert.pem", "r"))) { error_print(); return -1; } @@ -212,6 +223,9 @@ static int test_tls_certificate(void) return -1; } tls_certificate_print(stdout, record + 9, recordlen - 9, 0, 0); + */ + + printf("%s() ok\n", __FUNCTION__); return 0; } @@ -219,11 +233,10 @@ static int test_tls_server_key_exchange(void) { uint8_t record[1024]; size_t recordlen = 0; - const uint8_t version[] = {1,1}; - uint8_t sig[77]; + uint8_t sig[77] = {0xAA, 0xBB}; size_t siglen; - tls_record_set_version(record, version); + tls_record_set_version(record, TLS_version_tlcp); if (tlcp_record_set_handshake_server_key_exchange_pke(record, &recordlen, sig, sizeof(sig)) != 1) { error_print(); return -1; @@ -232,19 +245,20 @@ static int test_tls_server_key_exchange(void) error_print(); return -1; } - tls_server_key_exchange_print(stdout, sig, siglen, 0, 0); - return 1; + tls_server_key_exchange_print(stdout, sig, siglen, TLCP_cipher_ecc_sm4_gcm_sm3 << 8, 0); + + printf("%s() ok\n", __FUNCTION__); + return 0; } static int test_tls_certificate_verify(void) { uint8_t record[1024]; size_t recordlen = 0; - const uint8_t version[] = {1,1}; uint8_t sig[77]; size_t siglen; - tls_record_set_version(record, version); + tls_record_set_version(record, TLS_version_tls12); if (tls_record_set_handshake_certificate_verify(record, &recordlen, sig, sizeof(sig)) != 1) { error_print(); return -1; @@ -254,7 +268,9 @@ static int test_tls_certificate_verify(void) return -1; } tls_certificate_verify_print(stdout, sig, siglen, 0, 0); - return 1; + + printf("%s() ok\n", __FUNCTION__); + return 0; } static int test_tls_finished(void) @@ -272,7 +288,9 @@ static int test_tls_finished(void) return -1; } tls_finished_print(stdout, verify_data, 12, 0, 0); - return 1; + + printf("%s() ok\n", __FUNCTION__); + return 0; } static int test_tls_alert(void) @@ -291,7 +309,9 @@ static int test_tls_alert(void) return -1; } tls_alert_print(stdout, record + 5, recordlen - 5, 0, 0); - return 1; + + printf("%s() ok\n", __FUNCTION__); + return 0; } static int test_tls_change_cipher_spec(void) @@ -308,7 +328,9 @@ static int test_tls_change_cipher_spec(void) return -1; } tls_change_cipher_spec_print(stdout, record + 5, recordlen - 5, 0, 0); - return 1; + + printf("%s() ok\n", __FUNCTION__); + return 0; } static int test_tls_application_data(void) @@ -328,7 +350,9 @@ static int test_tls_application_data(void) return -1; } tls_application_data_print(stdout, p, len, 0, 0); - return 1; + + printf("%s() ok\n", __FUNCTION__); + return 0; } int main(void) @@ -346,5 +370,6 @@ int main(void) err += test_tls_alert(); err += test_tls_change_cipher_spec(); err += test_tls_application_data(); + if (err == 0) printf("%s all tests passed\n", __FILE__); return err; } diff --git a/tests/toolstest.sh b/tests/toolstest.sh deleted file mode 100755 index 7c4f3786..00000000 --- a/tests/toolstest.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash -x - -rm -fr *.pem -rm -fr *.der - -# generate sm2 keypair and encrypt with password -sm2keygen -pass 123456 -out cakey.pem -pubout capubkey.pem - -# generate a self-signed certificate -certgen -C CN -ST Beijing -L Haidian -O PKU -OU CS -CN CA -days 365 -key cakey.pem -pass 123456 -out cacert.pem -certparse -in cacert.pem - -# generate a req and sign by ca certificate -sm2keygen -pass 123456 -out key.pem -pubout pubkey.pem -reqgen -C CN -ST Beijing -L Haidian -O PKU -OU CS -CN Alice -days 365 -key key.pem -pass 123456 -out req.pem -reqparse -in req.pem -reqsign -in req.pem -days 365 -cacert cacert.pem -key cakey.pem -pass 123456 -out cert.pem -certparse -in cert.pem - -# hash and hmac -echo -n "abc" | sm3 -echo -n "abc" | sm3hmac -keyhex 1122334455667788 - -# encrypt with public key -echo hello | sm2encrypt -pubkey pubkey.pem -out ciphertext.der -sm2decrypt -in ciphertext.der -key key.pem -pass 123456 - -# encrypt with certificate -echo hello | sm2encrypt -cert cert.pem -out ciphertext.der -sm2decrypt -in ciphertext.der -key key.pem -pass 123456 - -# sign and verify with public key and certificate -echo hello | sm2sign -key key.pem -pass 123456 -out signature.der -echo hello | sm2verify -pubkey pubkey.pem -sig signature.der -echo hello | sm2verify -cert cert.pem -sig signature.der - diff --git a/tools/certgen.c b/tools/certgen.c index a933c5ab..00a1ceb8 100644 --- a/tools/certgen.c +++ b/tools/certgen.c @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2021 - 2021 The GmSSL Project. All rights reserved. * * Redistribution and use in source and binary forms, with or without @@ -51,17 +51,32 @@ #include #include #include +#include #include #include #include +static int ext_key_usage_set(int *usages, const char *usage_name) +{ + int flag; + if (x509_key_usage_from_name(&flag, usage_name) != 1) { + error_print(); + return -1; + } + *usages |= flag; + return 1; +} + #ifndef WIN32 #include #include #endif -static const char *options = "[-C str] [-ST str] [-L str] [-O str] [-OU str] -CN str -days num -key file [-pass pass]"; +static const char *options = + "[-C str] [-ST str] [-L str] [-O str] [-OU str] -CN str -days num " + "-key file [-pass pass] " + "[-key_usage str]*"; int main(int argc, char **argv) @@ -75,11 +90,11 @@ int main(int argc, char **argv) char *org_unit = NULL; char *common_name = NULL; int days = 0; - char *keyfile = NULL; + int key_usage = 0; + char *file = NULL; + FILE *outfp = stdout; FILE *keyfp = NULL; char *pass = NULL; - char *outfile = NULL; - FILE *outfp = stdout; SM2_KEY sm2_key; uint8_t serial[12]; @@ -88,6 +103,8 @@ int main(int argc, char **argv) time_t not_before; time_t not_after; uint8_t uniq_id[32]; + uint8_t exts[512]; + size_t extslen = 0; uint8_t cert[1024]; size_t certlen; @@ -117,18 +134,32 @@ help: } else if (!strcmp(*argv, "-L")) { if (--argc < 1) goto bad; locality = *(++argv); - } else if (!strcmp(*argv, "-key")) { - if (--argc < 1) goto bad; - keyfile = *(++argv); } else if (!strcmp(*argv, "-days")) { if (--argc < 1) goto bad; days = atoi(*(++argv)); + } else if (!strcmp(*argv, "-key_usage")) { + if (--argc < 1) goto bad; + if (ext_key_usage_set(&key_usage, *(++argv)) != 1) { + error_print(); + return -1; + } + } else if (!strcmp(*argv, "-key")) { + if (--argc < 1) goto bad; + file = *(++argv); + if (!(keyfp = fopen(file, "r"))) { + error_print(); + return -1; + } } else if (!strcmp(*argv, "-pass")) { if (--argc < 1) goto bad; pass = *(++argv); } else if (!strcmp(*argv, "-out")) { if (--argc < 1) goto bad; - outfile = *(++argv); + file = *(++argv); + if (!(outfp = fopen(file, "w"))) { + error_print(); + return -1; + } } else { bad: fprintf(stderr, "%s: illegal option '%s'\n", prog, *argv); @@ -140,7 +171,7 @@ bad: argv++; } - if (!common_name || days <= 0 || !keyfile) { + if (!common_name || days <= 0) { fprintf(stderr, "%s: missing options\n", prog); fprintf(stderr, "usage: %s %s\n", prog, options); return 1; @@ -154,24 +185,22 @@ bad: error_print(); return -1; } - if (!(keyfp = fopen(keyfile, "r")) - || sm2_private_key_info_decrypt_from_pem(&sm2_key, pass, keyfp) != 1) { + + if (keyfp == NULL) { + error_print(); + return -1; + } + if (sm2_private_key_info_decrypt_from_pem(&sm2_key, pass, keyfp) != 1) { error_print(); goto end; } - if (outfile) { - if (!(outfp = fopen(outfile, "wb"))) { - error_print(); - return -1; - } - } - time(¬_before); if (rand_bytes(serial, sizeof(serial)) != 1 || x509_name_set(name, &namelen, sizeof(name), country, state, locality, org, org_unit, common_name) != 1 || x509_validity_add_days(¬_after, not_before, days) != 1 + || x509_exts_add_key_usage(exts, &extslen, sizeof(exts), 1, key_usage) != 1 || x509_cert_sign( cert, &certlen, sizeof(cert), X509_version_v3, @@ -183,7 +212,7 @@ bad: &sm2_key, NULL, 0, NULL, 0, - NULL, 0, + exts, extslen, &sm2_key, SM2_DEFAULT_ID, strlen(SM2_DEFAULT_ID)) != 1 || x509_cert_to_pem(cert, certlen, outfp) != 1) { error_print(); diff --git a/tools/reqgen.c b/tools/reqgen.c index f83d8877..c0d7cdc1 100644 --- a/tools/reqgen.c +++ b/tools/reqgen.c @@ -76,9 +76,9 @@ int main(int argc, char **argv) char *org = NULL; char *org_unit = NULL; char *common_name = NULL; - char *keyfile = NULL; + + char *file = NULL; char *pass = NULL; - char *outfile = NULL; int days = 0; FILE *keyfp = NULL; @@ -124,7 +124,11 @@ int main(int argc, char **argv) common_name = *(++argv); } else if (!strcmp(*argv, "-key")) { if (--argc < 1) goto bad; - keyfile = *(++argv); + file = *(++argv); + if (!(keyfp = fopen(file, "r"))) { + error_print(); + return -1; + } } else if (!strcmp(*argv, "-pass")) { if (--argc < 1) goto bad; pass = *(++argv); @@ -133,7 +137,11 @@ int main(int argc, char **argv) days = atoi(*(++argv)); } else if (!strcmp(*argv, "-out")) { if (--argc < 1) goto bad; - outfile = *(++argv); + file = *(++argv); + if (!(outfp = fopen(file, "w"))) { + error_print(); + return -1; + } } else { bad: fprintf(stderr, "usage: %s %s\n", prog, options); @@ -143,19 +151,11 @@ bad: argv++; } - if (!common_name || days <= 0 || !keyfile) { + if (!common_name || days <= 0 || !keyfp) { fprintf(stderr, "%s: missing options\n", prog); fprintf(stderr, "usage: %s %s\n", prog, options); return 1; } - - if (outfile) { - if (!(outfp = fopen(outfile, "wb"))) { - error_print(); - return -1; - } - } - if (!pass) { pass = getpass("Encryption Password : "); } @@ -164,8 +164,7 @@ bad: error_print(); return -1; } - if (!(keyfp = fopen(keyfile, "r")) - || sm2_private_key_info_decrypt_from_pem(&sm2_key, pass, keyfp) != 1) { + if (sm2_private_key_info_decrypt_from_pem(&sm2_key, pass, keyfp) != 1) { error_print(); return -1; } diff --git a/tools/reqsign.c b/tools/reqsign.c index c394b1ac..0f859aaa 100644 --- a/tools/reqsign.c +++ b/tools/reqsign.c @@ -52,24 +52,39 @@ #include #include #include +#include #include #include #include #include +static int ext_key_usage_set(int *usages, const char *usage_name) +{ + int flag = 0; + if (x509_key_usage_from_name(&flag, usage_name) != 1) { + error_print(); + return -1; + } + *usages |= flag; + + printf("flag = %08x", flag); + printf("usage = %08x", *usages); + return 1; +} + + static const char *usage = "usage: %s [-in file] -days num -cacert file -key file [-pass str] [-out file]\n"; int main(int argc, char **argv) { char *prog = argv[0]; - char *infile = NULL; - char *outfile = NULL; - char *cacertfile = NULL; - char *keyfile = NULL; + char *file; char *pass = NULL; int days = 0; FILE *infp = stdin; + + uint8_t req[512]; size_t reqlen; const uint8_t *subject; @@ -92,6 +107,9 @@ int main(int argc, char **argv) size_t certlen; uint8_t serial[12]; time_t not_before, not_after; + uint8_t exts[512]; + size_t extslen = 0; + int key_usage = 0; if (argc < 2) { @@ -108,22 +126,45 @@ help: return 0; } else if (!strcmp(*argv, "-in")) { if (--argc < 1) goto bad; - infile = *(++argv); - } else if (!strcmp(*argv, "-cacert")) { - if (--argc < 1) goto bad; - cacertfile = *(++argv); - } else if (!strcmp(*argv, "-key")) { - if (--argc < 1) goto bad; - keyfile = *(++argv); - } else if (!strcmp(*argv, "-pass")) { - if (--argc < 1) goto bad; - pass = *(++argv); + file = *(++argv); + if (!(infp = fopen(file, "r"))) { + error_print(); + return -1; + } } else if (!strcmp(*argv, "-days")) { if (--argc < 1) goto bad; days = atoi(*(++argv)); + } else if (!strcmp(*argv, "-key_usage")) { + if (--argc < 1) goto bad; + if (ext_key_usage_set(&key_usage, *(++argv)) != 1) { + error_print(); + return -1; + } + } else if (!strcmp(*argv, "-cacert")) { + if (--argc < 1) goto bad; + file = *(++argv); + if (!(cacertfp = fopen(file, "r"))) { + error_print(); + return -1; + } + } else if (!strcmp(*argv, "-key")) { + if (--argc < 1) goto bad; + file = *(++argv); + if (!(keyfp = fopen(file, "r"))) { + error_print(); + return -1; + } + } else if (!strcmp(*argv, "-pass")) { + if (--argc < 1) goto bad; + pass = *(++argv); + } else if (!strcmp(*argv, "-out")) { if (--argc < 1) goto bad; - outfile = *(++argv); + file = *(++argv); + if (!(outfp = fopen(file, "w"))) { + error_print(); + return -1; + } } else { bad: error_print(); @@ -134,19 +175,13 @@ bad: argv++; } if (days <= 0 - || !cacertfile - || !keyfile) { + || !infp + || !cacertfp + || !keyfp) { error_print(); return -1; } - - if (infile) { - if (!(infp = fopen(infile, "r"))) { - error_print(); - return -1; - } - } if (x509_req_from_pem(req, &reqlen, sizeof(req), infp) != 1 || x509_req_get_details(req, reqlen, NULL, &subject, &subject_len, &subject_public_key, @@ -155,21 +190,13 @@ bad: return -1; } - if (!(cacertfp = fopen(cacertfile, "r")) - || x509_cert_from_pem(cacert, &cacertlen, sizeof(cacert), cacertfp) != 1 + if (x509_cert_from_pem(cacert, &cacertlen, sizeof(cacert), cacertfp) != 1 || x509_cert_get_subject(cacert, cacertlen, &issuer, &issuer_len) != 1 || x509_cert_get_subject_public_key(cacert, cacertlen, &issuer_public_key) != 1) { error_print(); return -1; } - if (outfile) { - if (!(outfp = fopen(outfile, "w"))) { - error_print(); - return -1; - } - } - if (!pass) { pass = getpass("Password : "); } @@ -177,8 +204,7 @@ bad: error_print(); return -1; } - if (!(keyfp = fopen(keyfile, "r")) - || sm2_private_key_info_decrypt_from_pem(&sm2_key, pass, keyfp) != 1 + if (sm2_private_key_info_decrypt_from_pem(&sm2_key, pass, keyfp) != 1 || sm2_public_key_equ(&sm2_key, &issuer_public_key) != 1) { error_print(); memset(&sm2_key, 0, sizeof(SM2_KEY)); @@ -189,6 +215,7 @@ bad: time(¬_before); if (x509_validity_add_days(¬_after, not_before, days) != 1 + || x509_exts_add_key_usage(exts, &extslen, sizeof(exts), 1, key_usage) != 1 || x509_cert_sign( cert, &certlen, sizeof(cert), X509_version_v3, @@ -200,7 +227,7 @@ bad: &subject_public_key, NULL, 0, NULL, 0, - NULL, 0, + exts, extslen, &sm2_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1 || x509_cert_to_pem(cert, certlen, outfp) != 1) { memset(&sm2_key, 0, sizeof(SM2_KEY)); @@ -208,6 +235,7 @@ bad: return -1; } + // FIXME: fclose() .... memset(&sm2_key, 0, sizeof(SM2_KEY)); return 0; } diff --git a/tools/tlcp_client.c b/tools/tlcp_client.c index fb18660d..af8def40 100644 --- a/tools/tlcp_client.c +++ b/tools/tlcp_client.c @@ -49,6 +49,7 @@ #include #include #include +#include #include #include @@ -58,100 +59,94 @@ const char *http_get = "Hostname: aaa\r\n" "\r\n\r\n"; -void print_usage(const char *prog) -{ - printf("Usage: %s [options]\n", prog); - printf(" -host \n"); - printf(" -port \n"); - printf(" -cacerts \n"); - printf(" -cert \n"); - printf(" -key \n"); -} -int main(int argc , char *argv[]) +// 虽然服务器可以用双证书,但是客户端只能使用一个证书,也就是签名证书 +static const char *options = "-host str [-port num] [-cacert file] [-cert file -key file [-pass str]]"; + +int main(int argc, char *argv[]) { int ret = -1; char *prog = argv[0]; char *host = NULL; int port = 443; + char *pass = NULL; TLS_CONNECT conn; char buf[100] = {0}; size_t len = sizeof(buf); + char *file; - char *cacertsfile = NULL; - char *certfile = NULL; - char *keyfile = NULL; - - FILE *cacertsfp = NULL; + FILE *cacertfp = NULL; FILE *certfp = NULL; FILE *keyfp = NULL; SM2_KEY sign_key; - if (argc < 2) { - print_usage(prog); - return 0; + fprintf(stderr, "usage: %s %s\n", prog, options); + return 1; } argc--; argv++; while (argc >= 1) { if (!strcmp(*argv, "-help")) { - print_usage(prog); + printf("usage: %s %s\n", prog, options); return 0; - } else if (!strcmp(*argv, "-host")) { if (--argc < 1) goto bad; host = *(++argv); - } else if (!strcmp(*argv, "-port")) { if (--argc < 1) goto bad; port = atoi(*(++argv)); - - } else if (!strcmp(*argv, "-cacerts")) { + } else if (!strcmp(*argv, "-cacert")) { if (--argc < 1) goto bad; - cacertsfile = *(++argv); - + file = *(++argv); + if (!(cacertfp = fopen(file, "r"))) { + error_print(); + return -1; + } } else if (!strcmp(*argv, "-cert")) { if (--argc < 1) goto bad; - certfile = *(++argv); - + file = *(++argv); + if (!(certfp = fopen(file, "r"))) { + error_print(); + return -1; + } } else if (!strcmp(*argv, "-key")) { if (--argc < 1) goto bad; - keyfile = *(++argv); - + file = *(++argv); + if (!(keyfp = fopen(file, "r"))) { + error_print(); + return -1; + } + } else if (!strcmp(*argv, "-pass")) { + if (--argc < 1) goto bad; + pass = *(++argv); } else { - print_usage(prog); + fprintf(stderr, "%s: invalid option '%s'\n", prog, *argv); + return 1; +bad: + fprintf(stderr, "%s: option '%s' argument required\n", prog, *argv); return 0; } argc--; argv++; } - if (!host || !certfile || !keyfile) { - print_usage(prog); + if (!host) { + error_print(); return -1; } - if (cacertsfile) { - if (!(cacertsfp = fopen(cacertsfile, "r"))) { + if (certfp) { + if (!keyfp) { error_print(); return -1; } - } - if (certfile) { - if (!(certfp = fopen(certfile, "r"))) { - error_print(); - return -1; + if (!pass) { + pass = getpass("Password : "); } - } - if (keyfile) { - if (!(keyfp = fopen(keyfile, "r"))) { - error_print(); - return -1; - } - if (sm2_private_key_from_pem(&sign_key, keyfp) != 1) { + if (sm2_private_key_info_decrypt_from_pem(&sign_key, pass, keyfp) != 1) { error_print(); return -1; } @@ -159,7 +154,7 @@ int main(int argc , char *argv[]) memset(&conn, 0, sizeof(conn)); - if (tlcp_connect(&conn, host, port, cacertsfp, certfp, &sign_key) != 1) { + if (tlcp_connect(&conn, host, port, cacertfp, certfp, &sign_key) != 1) { error_print(); return -1; } @@ -184,11 +179,5 @@ int main(int argc , char *argv[]) } } - return 1; -bad: - fprintf(stderr, "%s: command error\n", prog); - return 0; } - - diff --git a/tools/tlcp_server.c b/tools/tlcp_server.c index 94ba1355..dcd6df00 100644 --- a/tools/tlcp_server.c +++ b/tools/tlcp_server.c @@ -55,31 +55,25 @@ #include -void print_usage(const char *prog) -{ - printf("Usage: %s [options]\n", prog); - printf(" -port \n"); - printf(" -cert \n"); - printf(" -signkey \n"); - printf(" -enckey \n"); -} +static const char *options = "[-port num] -cert file -key file [-pass str] -ex_key file [-ex_pass str] [-cacert file]"; - -int main(int argc , char *argv[]) +int main(int argc , char **argv) { int ret = -1; char *prog = argv[0]; int port = 443; - char *certfile = NULL; - char *signkeyfile = NULL; - char *enckeyfile = NULL; + char *file = NULL; + FILE *certfp = NULL; FILE *signkeyfp = NULL; FILE *enckeyfp = NULL; SM2_KEY signkey; SM2_KEY enckey; + char *pass = NULL; + char *ex_pass = NULL; + uint8_t verify_buf[4096]; @@ -88,73 +82,89 @@ int main(int argc , char *argv[]) size_t len = sizeof(buf); if (argc < 2) { - print_usage(prog); - return 0; + fprintf(stderr, "usage: %s %s\n", prog, options); + return 1; } argc--; argv++; while (argc >= 1) { if (!strcmp(*argv, "-help")) { - print_usage(prog); + printf("usage: %s %s\n", prog, options); return 0; - } else if (!strcmp(*argv, "-port")) { if (--argc < 1) goto bad; port = atoi(*(++argv)); - } else if (!strcmp(*argv, "-cert")) { if (--argc < 1) goto bad; - certfile = *(++argv); - - } else if (!strcmp(*argv, "-signkey")) { + file = *(++argv); + if (!(certfp = fopen(file, "r"))) { + error_print(); + return -1; + } + } else if (!strcmp(*argv, "-key")) { if (--argc < 1) goto bad; - signkeyfile = *(++argv); - - } else if (!strcmp(*argv, "-enckey")) { + file = *(++argv); + if (!(signkeyfp = fopen(file, "r"))) { + error_print(); + return -1; + } + } else if (!strcmp(*argv, "-pass")) { if (--argc < 1) goto bad; - enckeyfile = *(++argv); - + pass = *(++argv); + } else if (!strcmp(*argv, "-ex_key")) { + if (--argc < 1) goto bad; + file = *(++argv); + if (!(enckeyfp = fopen(file, "r"))) { + error_print(); + return -1; + } + } else if (!strcmp(*argv, "-ex_pass")) { + if (--argc < 1) goto bad; + ex_pass = *(++argv); } else { - print_usage(prog); - return 0; + fprintf(stderr, "%s: invalid option '%s'\n", prog, *argv); + return 1; +bad: + fprintf(stderr, "%s: option '%s' argument required\n", prog, *argv); + return 1; } argc--; argv++; } - if (!certfile || !signkeyfile || !enckeyfile) { - print_usage(prog); + if (!certfp) { + error_print(); return -1; } - - if (!(certfp = fopen(certfile, "r"))) { + if (!signkeyfp) { + error_print(); + return -1; + } + if (!enckeyfp) { error_print(); return -1; } - - if (!(signkeyfp = fopen(signkeyfile, "r"))) { - error_print(); - return -1; + if (!pass) { + pass = getpass("Sign Key Password : "); } - if (sm2_private_key_from_pem(&signkey, signkeyfp) != 1) { + if (sm2_private_key_info_decrypt_from_pem(&signkey, pass, signkeyfp) != 1) { error_print(); return -1; } - if (!(enckeyfp = fopen(enckeyfile, "r"))) { - error_print(); - return -1; + if (!ex_pass) { + ex_pass = getpass("Encryption Key Password : "); } - if (sm2_private_key_from_pem(&enckey, enckeyfp) != 1) { + if (sm2_private_key_info_decrypt_from_pem(&enckey, ex_pass, enckeyfp) != 1) { error_print(); return -1; } memset(&conn, 0, sizeof(conn)); if (tlcp_accept(&conn, port, certfp, &signkey, &enckey, - certfp, verify_buf, 4096) != 1) { + NULL, verify_buf, 4096) != 1) { error_print(); return -1; } @@ -184,10 +194,5 @@ int main(int argc , char *argv[]) } - - return 1; -bad: - fprintf(stderr, "%s: command error\n", prog); - return 0; } diff --git a/tools/tls12_client.c b/tools/tls12_client.c index 9855deb5..cef77afb 100644 --- a/tools/tls12_client.c +++ b/tools/tls12_client.c @@ -59,70 +59,62 @@ const char *http_get = "Hostname: aaa\r\n" "\r\n\r\n"; -void print_usage(const char *prog) -{ - printf("Usage: %s [options]\n", prog); - printf(" -host \n"); - printf(" -port \n"); - printf(" -cacerts \n"); - printf(" -cert \n"); - printf(" -key \n"); -} + +static const char *options = "-host str [-port num] [-cacert file] [-cert file -key file [-pass str]]"; int main(int argc , char *argv[]) { - int ret = -1; char *prog = argv[0]; char *host = NULL; int port = 443; + char *cacertfile = NULL; + char *certfile = NULL; + char *keyfile = NULL; + char *pass = NULL; + + FILE *cacertfp = NULL; + FILE *certfp = NULL; + FILE *keyfp = NULL; + SM2_KEY sm2_key; + TLS_CONNECT conn; char buf[100] = {0}; size_t len = sizeof(buf); - char *cacertsfile = NULL; - char *certfile = NULL; - char *keyfile = NULL; - - FILE *cacertsfp = NULL; - FILE *certfp = NULL; - FILE *keyfp = NULL; - SM2_KEY sign_key; - - if (argc < 2) { - print_usage(prog); - return 0; + fprintf(stderr, "usage: %s %s\n", prog, options); + return 1; } argc--; argv++; - while (argc >= 1) { + while (argc > 0) { if (!strcmp(*argv, "-help")) { - print_usage(prog); + printf("usage: %s %s\n", prog, options); return 0; - } else if (!strcmp(*argv, "-host")) { if (--argc < 1) goto bad; host = *(++argv); - } else if (!strcmp(*argv, "-port")) { if (--argc < 1) goto bad; port = atoi(*(++argv)); - - } else if (!strcmp(*argv, "-cacerts")) { + } else if (!strcmp(*argv, "-cacert")) { if (--argc < 1) goto bad; - cacertsfile = *(++argv); - + cacertfile = *(++argv); } else if (!strcmp(*argv, "-cert")) { if (--argc < 1) goto bad; certfile = *(++argv); - } else if (!strcmp(*argv, "-key")) { if (--argc < 1) goto bad; keyfile = *(++argv); - + } else if (!strcmp(*argv, "-pass")) { + if (--argc < 1) goto bad; + pass = *(++argv); } else { - print_usage(prog); + fprintf(stderr, "%s: invalid option '%s'\n", prog, *argv); + return 1; +bad: + fprintf(stderr, "%s: option '%s' argument required\n", prog, *argv); return 0; } argc--; @@ -130,28 +122,34 @@ int main(int argc , char *argv[]) } if (!host) { - print_usage(prog); - return -1; + error_print(); + return 1; } - if (cacertsfile) { - if (!(cacertsfp = fopen(cacertsfile, "r"))) { + if (cacertfile) { + if (!(cacertfp = fopen(cacertfile, "r"))) { error_print(); - return -1; + return 1; } } + if (certfile) { if (!(certfp = fopen(certfile, "r"))) { error_print(); - return -1; + return 1; + } + if (!pass) { + pass = getpass("Password : "); + } + if (!keyfile) { + error_print(); + return 1; } - } - if (keyfile) { if (!(keyfp = fopen(keyfile, "r"))) { error_print(); return -1; } - if (sm2_private_key_from_pem(&sign_key, keyfp) != 1) { + if (sm2_private_key_info_decrypt_from_pem(&sm2_key, pass, keyfp) != 1) { error_print(); return -1; } @@ -159,12 +157,11 @@ int main(int argc , char *argv[]) memset(&conn, 0, sizeof(conn)); - if (tls12_connect(&conn, host, port, cacertsfp, certfp, &sign_key) != 1) { + if (tls12_connect(&conn, host, port, cacertfp, certfp, &sm2_key) != 1) { error_print(); return -1; } - // 这个client 发收了一个消息就结束了 if (tls_send(&conn, (uint8_t *)"12345\n", 6) != 1) { error_print(); return -1; @@ -182,12 +179,5 @@ int main(int argc , char *argv[]) break; } } - - return 1; -bad: - fprintf(stderr, "%s: command error\n", prog); - return 0; } - - diff --git a/tools/tls12_server.c b/tools/tls12_server.c index ce79ec0f..f1bd7795 100644 --- a/tools/tls12_server.c +++ b/tools/tls12_server.c @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2021 - 2021 The GmSSL Project. All rights reserved. * * Redistribution and use in source and binary forms, with or without @@ -54,25 +54,25 @@ #include #include - -void print_usage(const char *prog) -{ - printf("Usage: %s [options]\n", prog); - printf(" -port \n"); - printf(" -cert \n"); - printf(" -signkey \n"); -} +// [-cacert file] 如果服务器需要客户端提供证书,那么自己必须准备可以验证客户端证书的CA证书 +// 因此如果提供了CA证书,那么等同于要求客户端验证 +static const char *options = " [-port num] -cert file -key file [-pass str] [-cacert file]"; int main(int argc , char *argv[]) { int ret = -1; char *prog = argv[0]; + int port = 443; char *certfile = NULL; - char *signkeyfile = NULL; + char *keyfile = NULL; + char *pass = NULL; + char *cacertfile = NULL; + FILE *certfp = NULL; - FILE *signkeyfp = NULL; - SM2_KEY signkey; + FILE *keyfp = NULL; + FILE *cacertfp = NULL; + SM2_KEY sm2_key; uint8_t verify_buf[4096]; @@ -82,40 +82,52 @@ int main(int argc , char *argv[]) size_t len = sizeof(buf); if (argc < 2) { - print_usage(prog); - return 0; + fprintf(stderr, "usage: %s %s\n", prog, options); + return 1; } argc--; argv++; while (argc >= 1) { if (!strcmp(*argv, "-help")) { - print_usage(prog); + printf("usage: %s %s\n", prog, options); return 0; - } else if (!strcmp(*argv, "-port")) { if (--argc < 1) goto bad; port = atoi(*(++argv)); - } else if (!strcmp(*argv, "-cert")) { if (--argc < 1) goto bad; certfile = *(++argv); - - } else if (!strcmp(*argv, "-signkey")) { + } else if (!strcmp(*argv, "-key")) { if (--argc < 1) goto bad; - signkeyfile = *(++argv); - + keyfile = *(++argv); + } else if (!strcmp(*argv, "-pass")) { + if (--argc < 1) goto bad; + pass = *(++argv); + } else if (!strcmp(*argv, "-cacert")) { + if (--argc < 1) goto bad; + cacertfile = *(++argv); } else { - print_usage(prog); - return 0; + fprintf(stderr, "%s: invalid option '%s'\n", prog, *argv); + return 1; +bad: + fprintf(stderr, "%s: option '%s' argument required\n", prog, *argv); + return 1; } argc--; argv++; } - if (!certfile || !signkeyfile) { - print_usage(prog); - return -1; + if (!certfile || !keyfile) { + error_print(); + return 1; + } + + if (cacertfile) { + if (!(cacertfp = fopen(cacertfile, "r"))) { + error_print(); + return -1; + } } if (!(certfp = fopen(certfile, "r"))) { @@ -123,19 +135,21 @@ int main(int argc , char *argv[]) return -1; } - - if (!(signkeyfp = fopen(signkeyfile, "r"))) { + if (!pass) { + pass = getpass("Password : "); + } + if (!(keyfp = fopen(keyfile, "r"))) { error_print(); return -1; } - if (sm2_private_key_from_pem(&signkey, signkeyfp) != 1) { + if (sm2_private_key_info_decrypt_from_pem(&sm2_key, pass, keyfp) != 1) { error_print(); return -1; } memset(&conn, 0, sizeof(conn)); - if (tls12_accept(&conn, port, certfp, &signkey, - NULL /* certfp */, verify_buf, 4096) != 1) { + if (tls12_accept(&conn, port, certfp, &sm2_key, cacertfp, verify_buf, 4096) != 1) { + //if (tls12_accept(&conn, port, certfp, &sm2_key, NULL, NULL, 0) != 1) { error_print(); return -1; } @@ -163,12 +177,5 @@ int main(int argc , char *argv[]) fprintf(stderr, "-----------------\n\n\n\n\n\n"); } - - - - return 1; -bad: - fprintf(stderr, "%s: command error\n", prog); - return 0; }