diff --git a/src/sm4_cbc.c b/src/sm4_cbc.c index f0858326..95e5e565 100644 --- a/src/sm4_cbc.c +++ b/src/sm4_cbc.c @@ -89,10 +89,14 @@ int sm4_cbc_encrypt_update(SM4_CBC_CTX *ctx, size_t nblocks; size_t len; - if (!ctx || !in || !out || !outlen) { + if (!ctx || !in || !outlen) { error_print(); return -1; } + if (!out) { + *outlen = 16 * ((inlen + 15)/16); + return 1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; @@ -132,10 +136,14 @@ int sm4_cbc_encrypt_update(SM4_CBC_CTX *ctx, int sm4_cbc_encrypt_finish(SM4_CBC_CTX *ctx, uint8_t *out, size_t *outlen) { - if (!ctx || !out || !outlen) { + if (!ctx || !outlen) { error_print(); return -1; } + if (!out) { + *outlen = SM4_BLOCK_SIZE; + return 1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; @@ -166,10 +174,14 @@ int sm4_cbc_decrypt_update(SM4_CBC_CTX *ctx, { size_t left, len, nblocks; - if (!ctx || !in || !out || !outlen) { + if (!ctx || !in || !outlen) { error_print(); return -1; } + if (!out) { + *outlen = 16 * ((inlen + 15)/16); + return 1; + } if (ctx->block_nbytes > SM4_BLOCK_SIZE) { error_print(); return -1; @@ -208,10 +220,14 @@ int sm4_cbc_decrypt_update(SM4_CBC_CTX *ctx, int sm4_cbc_decrypt_finish(SM4_CBC_CTX *ctx, uint8_t *out, size_t *outlen) { - if (!ctx || !out || !outlen) { + if (!ctx || !outlen) { error_print(); return -1; } + if (!out) { + *outlen = SM4_BLOCK_SIZE; + return 1; + } if (ctx->block_nbytes != SM4_BLOCK_SIZE) { error_print(); return -1; diff --git a/src/sm4_cfb.c b/src/sm4_cfb.c index f4bba174..e086b3b2 100644 --- a/src/sm4_cfb.c +++ b/src/sm4_cfb.c @@ -66,6 +66,10 @@ void sm4_cfb_decrypt(const SM4_KEY *key, size_t sbytes, uint8_t iv[16], int sm4_cfb_encrypt_init(SM4_CFB_CTX *ctx, size_t sbytes, const uint8_t key[SM4_BLOCK_SIZE], const uint8_t iv[SM4_BLOCK_SIZE]) { + if (!ctx || !key || !iv) { + error_print(); + return -1; + } if (sbytes < 1 || sbytes > 16) { error_print(); return -1; @@ -85,6 +89,14 @@ int sm4_cfb_encrypt_update(SM4_CFB_CTX *ctx, size_t nblocks; size_t len; + if (!ctx || !in || !outlen) { + error_print(); + return -1; + } + if (!out) { + *outlen = 16 * ((inlen + 15)/16); + return 1; + } if (ctx->block_nbytes >= ctx->sbytes) { error_print(); return -1; @@ -122,6 +134,14 @@ int sm4_cfb_encrypt_update(SM4_CFB_CTX *ctx, int sm4_cfb_encrypt_finish(SM4_CFB_CTX *ctx, uint8_t *out, size_t *outlen) { + if (!ctx || !outlen) { + error_print(); + return -1; + } + if (!out) { + *outlen = SM4_BLOCK_SIZE; + return 1; + } if (ctx->block_nbytes >= ctx->sbytes) { error_print(); return -1; @@ -134,6 +154,10 @@ int sm4_cfb_encrypt_finish(SM4_CFB_CTX *ctx, uint8_t *out, size_t *outlen) int sm4_cfb_decrypt_init(SM4_CFB_CTX *ctx, size_t sbytes, const uint8_t key[SM4_BLOCK_SIZE], const uint8_t iv[SM4_BLOCK_SIZE]) { + if (!ctx || !key || !iv) { + error_print(); + return -1; + } if (sbytes < 1 || sbytes > 16) { error_print(); return -1; @@ -153,6 +177,14 @@ int sm4_cfb_decrypt_update(SM4_CFB_CTX *ctx, size_t nblocks; size_t len; + if (!ctx || !in || !outlen) { + error_print(); + return -1; + } + if (!out) { + *outlen = 16 * ((inlen + 15)/16); + return 1; + } if (ctx->block_nbytes >= ctx->sbytes) { error_print(); return -1; @@ -190,6 +222,14 @@ int sm4_cfb_decrypt_update(SM4_CFB_CTX *ctx, int sm4_cfb_decrypt_finish(SM4_CFB_CTX *ctx, uint8_t *out, size_t *outlen) { + if (!ctx || !outlen) { + error_print(); + return -1; + } + if (!out) { + *outlen = SM4_BLOCK_SIZE; + return 1; + } if (ctx->block_nbytes >= ctx->sbytes) { error_print(); return -1; @@ -198,4 +238,3 @@ int sm4_cfb_decrypt_finish(SM4_CFB_CTX *ctx, uint8_t *out, size_t *outlen) *outlen = ctx->block_nbytes; return 1; } - diff --git a/src/sm4_ctr.c b/src/sm4_ctr.c index 0f03196e..d4a312a1 100644 --- a/src/sm4_ctr.c +++ b/src/sm4_ctr.c @@ -70,10 +70,14 @@ int sm4_ctr_encrypt_update(SM4_CTR_CTX *ctx, size_t nblocks; size_t len; - if (!ctx || !in || !out || !outlen) { + if (!ctx || !in || !outlen) { error_print(); return -1; } + if (!out) { + *outlen = 16 * ((inlen + 15)/16); + return 1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; @@ -111,10 +115,14 @@ int sm4_ctr_encrypt_update(SM4_CTR_CTX *ctx, int sm4_ctr_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen) { - if (!ctx || !out || !outlen) { + if (!ctx || !outlen) { error_print(); return -1; } + if (!out) { + *outlen = SM4_BLOCK_SIZE; + return 1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; @@ -146,6 +154,14 @@ int sm4_ctr32_encrypt_update(SM4_CTR_CTX *ctx, size_t nblocks; size_t len; + if (!ctx || !in || !outlen) { + error_print(); + return -1; + } + if (!out) { + *outlen = 16 * ((inlen + 15)/16); + return 1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; @@ -183,6 +199,14 @@ int sm4_ctr32_encrypt_update(SM4_CTR_CTX *ctx, int sm4_ctr32_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen) { + if (!ctx || !outlen) { + error_print(); + return -1; + } + if (!out) { + *outlen = SM4_BLOCK_SIZE; + return 1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; diff --git a/src/sm4_ecb.c b/src/sm4_ecb.c index 06f4b265..fb8592d9 100644 --- a/src/sm4_ecb.c +++ b/src/sm4_ecb.c @@ -32,10 +32,14 @@ int sm4_ecb_encrypt_update(SM4_ECB_CTX *ctx, size_t nblocks; size_t len; - if (!ctx || !in || !out || !outlen) { + if (!ctx || !in || !outlen) { error_print(); return -1; } + if (!out) { + *outlen = 16 * ((inlen + 15)/16); + return 1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; @@ -77,6 +81,10 @@ int sm4_ecb_encrypt_finish(SM4_ECB_CTX *ctx, uint8_t *out, size_t *outlen) error_print(); return -1; } + if (!out) { + *outlen = SM4_BLOCK_SIZE; // anyway, caller should prepare a block buffer to support any length input + return 1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; diff --git a/src/sm4_gcm.c b/src/sm4_gcm.c index 45cca1d8..d56b72e5 100644 --- a/src/sm4_gcm.c +++ b/src/sm4_gcm.c @@ -148,10 +148,14 @@ int sm4_gcm_encrypt_init(SM4_GCM_CTX *ctx, int sm4_gcm_encrypt_update(SM4_GCM_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) { - if (!ctx || !in || !out || !outlen) { + if (!ctx || !in || !outlen) { error_print(); return -1; } + if (!out) { + *outlen = 16 * ((inlen + 15)/16); + return 1; + } if (sm4_ctr32_encrypt_update(&ctx->enc_ctx, in, inlen, out, outlen) != 1) { error_print(); return -1; @@ -164,10 +168,14 @@ int sm4_gcm_encrypt_finish(SM4_GCM_CTX *ctx, uint8_t *out, size_t *outlen) { uint8_t mac[16]; - if (!ctx || !out || !outlen) { + if (!ctx || !outlen) { error_print(); return -1; } + if (!out) { + *outlen = SM4_BLOCK_SIZE * 2; // GCM output extra mac tag + return 1; + } if (sm4_ctr32_encrypt_finish(&ctx->enc_ctx, out, outlen) != 1) { error_print(); return -1; @@ -193,10 +201,14 @@ int sm4_gcm_decrypt_update(SM4_GCM_CTX *ctx, const uint8_t *in, size_t inlen, ui { size_t len; - if (!ctx || !in || !out || !outlen) { + if (!ctx || !in || !outlen) { error_print(); return -1; } + if (!out) { + *outlen = 16 * ((inlen + 15)/16); + return 1; + } if (ctx->maclen > ctx->taglen) { error_print(); return -1; @@ -251,10 +263,14 @@ int sm4_gcm_decrypt_finish(SM4_GCM_CTX *ctx, uint8_t *out, size_t *outlen) { uint8_t mac[GHASH_SIZE]; - if (!ctx || !out || !outlen) { + if (!ctx || !outlen) { error_print(); return -1; } + if (!out) { + *outlen = SM4_BLOCK_SIZE; + return 1; + } if (ctx->maclen != ctx->taglen) { error_print(); return -1; diff --git a/src/sm4_ofb.c b/src/sm4_ofb.c index 8e028769..cf82e236 100644 --- a/src/sm4_ofb.c +++ b/src/sm4_ofb.c @@ -30,6 +30,10 @@ void sm4_ofb_encrypt(const SM4_KEY *key, uint8_t iv[16], const uint8_t *in, size int sm4_ofb_encrypt_init(SM4_OFB_CTX *ctx, const uint8_t key[SM4_BLOCK_SIZE], const uint8_t iv[SM4_BLOCK_SIZE]) { + if (!ctx || !key || !iv) { + error_print(); + return -1; + } sm4_set_encrypt_key(&ctx->sm4_key, key); memcpy(ctx->iv, iv, SM4_BLOCK_SIZE); memset(ctx->block, 0, SM4_BLOCK_SIZE); @@ -44,6 +48,14 @@ int sm4_ofb_encrypt_update(SM4_OFB_CTX *ctx, size_t nblocks; size_t len; + if (!ctx || !in || !outlen) { + error_print(); + return -1; + } + if (!out) { + *outlen = 16 * ((inlen + 15)/16); + return 1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1; @@ -81,6 +93,14 @@ int sm4_ofb_encrypt_update(SM4_OFB_CTX *ctx, int sm4_ofb_encrypt_finish(SM4_OFB_CTX *ctx, uint8_t *out, size_t *outlen) { + if (!ctx || !outlen) { + error_print(); + return -1; + } + if (!out) { + *outlen = SM4_BLOCK_SIZE; + return 1; + } if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { error_print(); return -1;