diff --git a/include/gmssl/base64.h b/include/gmssl/base64.h index d11d3f5d..acf3aea4 100644 --- a/include/gmssl/base64.h +++ b/include/gmssl/base64.h @@ -47,6 +47,8 @@ void base64_encode_finish(BASE64_CTX *ctx, uint8_t *out, int *outlen); void base64_decode_init(BASE64_CTX *ctx); int base64_decode_update(BASE64_CTX *ctx, const uint8_t *in, int inlen, uint8_t *out, int *outlen); int base64_decode_finish(BASE64_CTX *ctx, uint8_t *out, int *outlen); +int base64_decode_update_ex(BASE64_CTX *ctx, const uint8_t *in, int inlen, uint8_t *out, int *outlen, size_t maxout); +int base64_decode_finish_ex(BASE64_CTX *ctx, uint8_t *out, int *outlen, size_t maxout); int base64_encode_block(unsigned char *t, const unsigned char *f, int dlen); diff --git a/src/base64.c b/src/base64.c index 4957747d..66afd56c 100644 --- a/src/base64.c +++ b/src/base64.c @@ -211,10 +211,12 @@ void base64_decode_init(BASE64_CTX *ctx) * - There is extra trailing padding, or data after padding. * - B64_EOF is detected after an incomplete base64 block. */ -int base64_decode_update(BASE64_CTX *ctx, const uint8_t *in, int inl, uint8_t *out, int *outl) +int base64_decode_update_ex(BASE64_CTX *ctx, const uint8_t *in, int inl, uint8_t *out, int *outl, size_t maxout) { - int seof = 0, eof = 0, rv = -1, ret = 0, i, v, tmp, n, decoded_len; + int seof = 0, eof = 0, rv = -1, ret = 0, i, v, tmp, n, decoded_len, block_len, output_len; unsigned char *d; + unsigned char block[BIN_PER_LINE]; + unsigned char *p; n = ctx->num; d = ctx->enc_data; @@ -277,14 +279,27 @@ int base64_decode_update(BASE64_CTX *ctx, const uint8_t *in, int inl, uint8_t *o } if (n == 64) { - decoded_len = base64_decode_block(out, d, n); + block_len = n / 4 * 3; + output_len = block_len - eof; + if (eof > block_len + || (size_t)ret > maxout + || (size_t)output_len > maxout - (size_t)ret) { + error_print(); + rv = -1; + goto end; + } + p = eof ? block : out; + decoded_len = base64_decode_block(p, d, n); n = 0; if (decoded_len < 0 || eof > decoded_len) { rv = -1; goto end; } - ret += decoded_len - eof; - out += decoded_len - eof; + if (eof && output_len > 0) { + memcpy(out, block, output_len); + } + ret += output_len; + out += output_len; } } @@ -296,14 +311,28 @@ int base64_decode_update(BASE64_CTX *ctx, const uint8_t *in, int inl, uint8_t *o tail: if (n > 0) { if ((n & 3) == 0) { - decoded_len = base64_decode_block(out, d, n); + block_len = n / 4 * 3; + output_len = block_len - eof; + if (eof > block_len + || (size_t)ret > maxout + || (size_t)output_len > maxout - (size_t)ret) { + error_print(); + rv = -1; + goto end; + } + p = eof ? block : out; + decoded_len = base64_decode_block(p, d, n); n = 0; if (decoded_len < 0 || eof > decoded_len) { error_print(); rv = -1; goto end; } - ret += (decoded_len - eof); + if (eof && output_len > 0) { + memcpy(out, block, output_len); + } + ret += output_len; + out += output_len; } else if (seof) { /* EOF in the middle of a base64 block. */ error_print(); @@ -320,6 +349,11 @@ end: return (rv); } +int base64_decode_update(BASE64_CTX *ctx, const uint8_t *in, int inl, uint8_t *out, int *outl) +{ + return base64_decode_update_ex(ctx, in, inl, out, outl, (size_t)-1); +} + int base64_decode_block(unsigned char *t, const unsigned char *f, int n) { int i, ret = 0, a, b, c, d; @@ -359,12 +393,17 @@ int base64_decode_block(unsigned char *t, const unsigned char *f, int n) return (ret); } -int base64_decode_finish(BASE64_CTX *ctx, uint8_t *out, int *outl) +int base64_decode_finish_ex(BASE64_CTX *ctx, uint8_t *out, int *outl, size_t maxout) { int i; *outl = 0; if (ctx->num != 0) { + if ((ctx->num & 3) != 0 + || (size_t)(ctx->num / 4 * 3) > maxout) { + error_print(); + return (-1); + } i = base64_decode_block(out, ctx->enc_data, ctx->num); if (i < 0) { error_print(); @@ -376,3 +415,8 @@ int base64_decode_finish(BASE64_CTX *ctx, uint8_t *out, int *outl) } else return (1); } + +int base64_decode_finish(BASE64_CTX *ctx, uint8_t *out, int *outl) +{ + return base64_decode_finish_ex(ctx, out, outl, (size_t)-1); +} diff --git a/src/pem.c b/src/pem.c index a53fba40..eb522d7e 100644 --- a/src/pem.c +++ b/src/pem.c @@ -72,8 +72,10 @@ int pem_read(FILE *fp, const char *name, uint8_t *data, size_t *datalen, size_t char line[80]; char begin_line[80]; char end_line[80]; - int len; BASE64_CTX ctx; + size_t linelen; + int len; + int ret; snprintf(begin_line, sizeof(begin_line), "-----BEGIN %s-----", name); snprintf(end_line, sizeof(end_line), "-----END %s-----", name); @@ -116,12 +118,21 @@ int pem_read(FILE *fp, const char *name, uint8_t *data, size_t *datalen, size_t break; } - base64_decode_update(&ctx, (uint8_t *)line, (int)strlen(line), data, &len); - data += len; + linelen = strlen(line); + ret = base64_decode_update_ex(&ctx, (uint8_t *)line, (int)linelen, + data + *datalen, &len, maxlen - *datalen); + if (ret < 0 || len < 0 || (size_t)len > maxlen - *datalen) { + error_print(); + return -1; + } *datalen += len; } - base64_decode_finish(&ctx, data, &len); + if (base64_decode_finish_ex(&ctx, data + *datalen, &len, maxlen - *datalen) != 1 + || len < 0 || (size_t)len > maxlen - *datalen) { + error_print(); + return -1; + } *datalen += len; return 1; } diff --git a/tests/base64test.c b/tests/base64test.c index ef2ae234..14b744b8 100644 --- a/tests/base64test.c +++ b/tests/base64test.c @@ -61,9 +61,78 @@ static int test_base64(void) return 1; } +struct base64_one_byte_buffer { + uint8_t data[1]; + uint8_t canary[8]; +}; + +static int test_base64_decode_update_ex_padding(void) +{ + const uint8_t in[] = "AA=="; + struct base64_one_byte_buffer buf; + uint8_t original_canary[sizeof(buf.canary)]; + BASE64_CTX ctx; + int len = 0; + int ret; + + memset(&buf, 0xff, sizeof(buf)); + memset(buf.canary, 0xcc, sizeof(buf.canary)); + memcpy(original_canary, buf.canary, sizeof(original_canary)); + + base64_decode_init(&ctx); + ret = base64_decode_update_ex(&ctx, in, sizeof(in) - 1, + buf.data, &len, sizeof(buf.data)); + if (ret != 0 || len != 1 || buf.data[0] != 0) { + error_print(); + return -1; + } + if (memcmp(buf.canary, original_canary, sizeof(original_canary)) != 0) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +struct base64_two_byte_buffer { + uint8_t data[2]; + uint8_t canary[8]; +}; + +static int test_base64_decode_update_ex_maxout(void) +{ + const uint8_t in[] = "AAAA"; + struct base64_two_byte_buffer buf; + uint8_t original_canary[sizeof(buf.canary)]; + BASE64_CTX ctx; + int len = 0; + int ret; + + memset(&buf, 0xcc, sizeof(buf)); + memcpy(original_canary, buf.canary, sizeof(original_canary)); + + base64_decode_init(&ctx); + ret = base64_decode_update_ex(&ctx, in, sizeof(in) - 1, + buf.data, &len, sizeof(buf.data)); + if (ret != -1 || len != 0) { + error_print(); + return -1; + } + if (memcmp(buf.canary, original_canary, sizeof(original_canary)) != 0) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + int main(void) { if (test_base64() != 1) goto err; + if (test_base64_decode_update_ex_padding() != 1) goto err; + if (test_base64_decode_update_ex_maxout() != 1) goto err; printf("%s all tests passed\n", __FILE__); return 0; err: diff --git a/tests/pemtest.c b/tests/pemtest.c index 3c00b4dd..e249a03d 100644 --- a/tests/pemtest.c +++ b/tests/pemtest.c @@ -235,11 +235,114 @@ static int test_pem_windows_style_without_last_newline(void) return 1; } +struct pem_bounded_buffer { + uint8_t data[16]; + uint8_t canary[32]; +}; + +static int test_pem_read_maxlen(void) +{ + FILE *fp; + const char *file = "test_pem_read_maxlen.pem"; + struct pem_bounded_buffer buf; + uint8_t original_canary[sizeof(buf.canary)]; + size_t len = 0; + int ret; + int i; + + if (!(fp = fopen(file, "wb"))) { + error_print(); + return -1; + } + fputs("-----BEGIN TEST-----\n", fp); + for (i = 0; i < 64; i++) { + fputc('A', fp); + } + fputs("\n-----END TEST-----\n", fp); + fclose(fp); + + memset(&buf, 0, sizeof(buf)); + memset(buf.canary, 0xcc, sizeof(buf.canary)); + memcpy(original_canary, buf.canary, sizeof(original_canary)); + + if (!(fp = fopen(file, "rb"))) { + error_print(); + return -1; + } + ret = pem_read(fp, "TEST", buf.data, &len, sizeof(buf.data)); + fclose(fp); + + if (ret == 1) { + error_print(); + return -1; + } + if (len > sizeof(buf.data)) { + error_print(); + return -1; + } + if (memcmp(buf.canary, original_canary, sizeof(original_canary)) != 0) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +struct pem_one_byte_buffer { + uint8_t data[1]; + uint8_t canary[8]; +}; + +static int test_pem_read_padding_maxlen(void) +{ + FILE *fp; + const char *file = "test_pem_read_padding_maxlen.pem"; + struct pem_one_byte_buffer buf; + uint8_t original_canary[sizeof(buf.canary)]; + size_t len = 0; + int ret; + + if (!(fp = fopen(file, "wb"))) { + error_print(); + return -1; + } + fputs("-----BEGIN TEST-----\n", fp); + fputs("AA==\n", fp); + fputs("-----END TEST-----\n", fp); + fclose(fp); + + memset(&buf, 0xff, sizeof(buf)); + memset(buf.canary, 0xcc, sizeof(buf.canary)); + memcpy(original_canary, buf.canary, sizeof(original_canary)); + + if (!(fp = fopen(file, "rb"))) { + error_print(); + return -1; + } + ret = pem_read(fp, "TEST", buf.data, &len, sizeof(buf.data)); + fclose(fp); + + if (ret != 1 || len != sizeof(buf.data) || buf.data[0] != 0) { + error_print(); + return -1; + } + if (memcmp(buf.canary, original_canary, sizeof(original_canary)) != 0) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + int main(void) { if (test_pem_unix_style() != 1) { error_print(); return 1; } if (test_pem_unix_style_without_last_newline() != 1) { error_print(); return 1; } if (test_pem_windows_style() != 1) { error_print(); return 1; } if (test_pem_windows_style_without_last_newline() != 1) { error_print(); return 1; } + if (test_pem_read_maxlen() != 1) { error_print(); return 1; } + if (test_pem_read_padding_maxlen() != 1) { error_print(); return 1; } return 0; }