diff --git a/include/gmssl/sm3.h b/include/gmssl/sm3.h index 9b9be2b0..37d18cc4 100644 --- a/include/gmssl/sm3.h +++ b/include/gmssl/sm3.h @@ -58,6 +58,7 @@ void sm3_update(SM3_CTX *ctx, const uint8_t *data, size_t datalen); void sm3_finish(SM3_CTX *ctx, uint8_t dgst[SM3_DIGEST_SIZE]); void sm3_digest(const uint8_t *data, size_t datalen, uint8_t dgst[SM3_DIGEST_SIZE]); +void sm3_compress_blocks(uint32_t digest[8], const uint8_t *data, size_t blocks); typedef struct { SM3_CTX sm3_ctx; diff --git a/src/sm3.c b/src/sm3.c index 1853b62c..eb2d2e64 100644 --- a/src/sm3.c +++ b/src/sm3.c @@ -294,7 +294,6 @@ void sm3_compress_blocks(uint32_t digest[8], const uint8_t *data, size_t blocks) } } - void sm3_init(SM3_CTX *ctx) { memset(ctx, 0, sizeof(*ctx)); @@ -321,7 +320,6 @@ void sm3_update(SM3_CTX *ctx, const uint8_t *data, size_t data_len) ctx->num += data_len; return; } else { - memcpy(ctx->block + ctx->num, data, left); sm3_compress_blocks(ctx->digest, ctx->block, 1); ctx->nblocks++; data += left; @@ -330,10 +328,12 @@ void sm3_update(SM3_CTX *ctx, const uint8_t *data, size_t data_len) } blocks = data_len / SM3_BLOCK_SIZE; - sm3_compress_blocks(ctx->digest, data, blocks); - ctx->nblocks += blocks; - data += SM3_BLOCK_SIZE * blocks; - data_len -= SM3_BLOCK_SIZE * blocks; + if (blocks) { + sm3_compress_blocks(ctx->digest, data, blocks); + ctx->nblocks += blocks; + data += SM3_BLOCK_SIZE * blocks; + data_len -= SM3_BLOCK_SIZE * blocks; + } ctx->num = data_len; if (data_len) {