Add encrypt/decrypt to sm2_ctx.c

This commit is contained in:
Zhi Guan
2024-01-07 17:26:29 +08:00
parent 2dab02f76a
commit 31efcb5d87
4 changed files with 212 additions and 4 deletions

View File

@@ -341,6 +341,20 @@ int sm2_do_ecdh(const SM2_KEY *key, const SM2_POINT *peer_public, SM2_POINT *out
_gmssl_export int sm2_ecdh(const SM2_KEY *key, const uint8_t *peer_public, size_t peer_public_len, SM2_POINT *out);
typedef struct {
SM2_KEY sm2_key;
uint8_t buf[SM2_MAX_CIPHERTEXT_SIZE];
size_t buf_size;
} SM2_ENC_CTX;
_gmssl_export int sm2_encrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key);
_gmssl_export int sm2_encrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
_gmssl_export int sm2_encrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
_gmssl_export int sm2_decrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key);
_gmssl_export int sm2_decrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
_gmssl_export int sm2_decrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen);
#ifdef __cplusplus
}
#endif

View File

@@ -135,3 +135,182 @@ int sm2_verify_finish(SM2_SIGN_CTX *ctx, const uint8_t *sig, size_t siglen)
return 1;
}
int sm2_encrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key)
{
if (!ctx || !sm2_key) {
error_print();
return -1;
}
memset(ctx, 0, sizeof(*ctx));
ctx->sm2_key = *sm2_key;
return 1;
}
int sm2_encrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
if (!ctx || !outlen) {
error_print();
return -1;
}
if (ctx->buf_size > SM2_MAX_PLAINTEXT_SIZE) {
error_print();
return -1;
}
if (!out) {
*outlen = 0;
return 1;
}
if (in) {
if (inlen > SM2_MAX_PLAINTEXT_SIZE - ctx->buf_size) {
error_print();
return -1;
}
memcpy(ctx->buf + ctx->buf_size, in, inlen);
ctx->buf_size += inlen;
}
*outlen = 0;
return 1;
}
int sm2_encrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
if (!ctx || !outlen) {
error_print();
return -1;
}
if (ctx->buf_size > SM2_MAX_PLAINTEXT_SIZE) {
error_print();
return -1;
}
if (!out) {
*outlen = SM2_MAX_CIPHERTEXT_SIZE;
return 1;
}
if (ctx->buf_size) {
if (in) {
if (inlen > SM2_MAX_PLAINTEXT_SIZE - ctx->buf_size) {
error_print();
return -1;
}
memcpy(ctx->buf + ctx->buf_size, in, inlen);
ctx->buf_size += inlen;
}
if (sm2_encrypt(&ctx->sm2_key, ctx->buf, ctx->buf_size, out, outlen) != 1) {
error_print();
return -1;
}
} else {
if (!in || !inlen || inlen > SM2_MAX_PLAINTEXT_SIZE) {
error_print();
return -1;
}
if (sm2_encrypt(&ctx->sm2_key, in, inlen, out, outlen) != 1) {
error_print();
return -1;
}
}
return 1;
}
int sm2_decrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key)
{
if (!ctx || !sm2_key) {
error_print();
return -1;
}
memset(ctx, 0, sizeof(*ctx));
ctx->sm2_key = *sm2_key;
return 1;
}
int sm2_decrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
if (!ctx || !outlen) {
error_print();
return -1;
}
if (ctx->buf_size > SM2_MAX_CIPHERTEXT_SIZE) {
error_print();
return -1;
}
if (!out) {
*outlen = 0;
return 1;
}
if (in) {
if (inlen > SM2_MAX_CIPHERTEXT_SIZE - ctx->buf_size) {
error_print();
return -1;
}
memcpy(ctx->buf + ctx->buf_size, in, inlen);
ctx->buf_size += inlen;
}
*outlen = 0;
return 1;
}
int sm2_decrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
if (!ctx || !outlen) {
error_print();
return -1;
}
if (ctx->buf_size > SM2_MAX_CIPHERTEXT_SIZE) {
error_print();
return -1;
}
if (!out) {
*outlen = SM2_MAX_PLAINTEXT_SIZE;
return 1;
}
if (ctx->buf_size) {
if (in) {
if (inlen > SM2_MAX_CIPHERTEXT_SIZE - ctx->buf_size) {
error_print();
return -1;
}
memcpy(ctx->buf + ctx->buf_size, in, inlen);
ctx->buf_size += inlen;
}
if (sm2_decrypt(&ctx->sm2_key, ctx->buf, ctx->buf_size, out, outlen) != 1) {
error_print();
return -1;
}
} else {
if (!in || !inlen || inlen > SM2_MAX_CIPHERTEXT_SIZE) {
error_print();
return -1;
}
if (sm2_decrypt(&ctx->sm2_key, in, inlen, out, outlen) != 1) {
error_print();
return -1;
}
}
return 1;
}

View File

@@ -30,6 +30,7 @@ int sm2decrypt_main(int argc, char **argv)
FILE *infp = stdin;
FILE *outfp = stdout;
SM2_KEY key;
SM2_ENC_CTX ctx;
uint8_t inbuf[SM2_MAX_CIPHERTEXT_SIZE];
uint8_t outbuf[SM2_MAX_CIPHERTEXT_SIZE];
size_t inlen, outlen;
@@ -101,7 +102,12 @@ bad:
fprintf(stderr, "%s: read input failed : %s\n", prog, strerror(errno));
goto end;
}
if (sm2_decrypt(&key, inbuf, inlen, outbuf, &outlen) != 1) {
if (sm2_decrypt_init(&ctx, &key) != 1) {
fprintf(stderr, "%s: sm2_decrypt_init failed\n", prog);
goto end;
}
if (sm2_decrypt_finish(&ctx, inbuf, inlen, outbuf, &outlen) != 1) {
fprintf(stderr, "%s: decryption failure\n", prog);
goto end;
}
@@ -112,6 +118,8 @@ bad:
ret = 0;
end:
gmssl_secure_clear(&key, sizeof(key));
gmssl_secure_clear(&ctx, sizeof(ctx));
gmssl_secure_clear(outbuf, sizeof(outbuf));
if (keyfp) fclose(keyfp);
if (infile && infp) fclose(infp);
if (outfile && outfp) fclose(outfp);

View File

@@ -12,10 +12,10 @@
#include <errno.h>
#include <string.h>
#include <stdlib.h>
#include <gmssl/mem.h>
#include <gmssl/sm2.h>
#include <gmssl/x509.h>
static const char *options = "(-pubkey pem | -cert pem) [-in file] [-out file]";
int sm2encrypt_main(int argc, char **argv)
@@ -33,6 +33,7 @@ int sm2encrypt_main(int argc, char **argv)
uint8_t cert[1024];
size_t certlen;
SM2_KEY key;
SM2_ENC_CTX ctx;
uint8_t inbuf[SM2_MAX_PLAINTEXT_SIZE + 1];
uint8_t outbuf[SM2_MAX_CIPHERTEXT_SIZE];
size_t inlen, outlen = sizeof(outbuf);
@@ -124,8 +125,12 @@ bad:
goto end;
}
if (sm2_encrypt(&key, inbuf, inlen, outbuf, &outlen) != 1) {
fprintf(stderr, "%s: inner error\n", prog);
if (sm2_encrypt_init(&ctx, &key) != 1) {
fprintf(stderr, "%s: sm2_encrypt_init failed\n", prog);
goto end;
}
if (sm2_encrypt_finish(&ctx, inbuf, inlen, outbuf, &outlen) != 1) {
fprintf(stderr, "%s: sm2_encrypt_finish error\n", prog);
goto end;
}
@@ -136,6 +141,8 @@ bad:
ret = 0;
end:
gmssl_secure_clear(&ctx, sizeof(ctx));
gmssl_secure_clear(inbuf, sizeof(inbuf));
if (infile && infp) fclose(infp);
if (outfile && outfp) fclose(outfp);
if (pubkeyfp) fclose(pubkeyfp);