diff --git a/CMakeLists.txt b/CMakeLists.txt index f495087b..a68dea58 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,6 +61,7 @@ option(ENABLE_SHA1 "Enable SHA1" ON) option(ENABLE_SHA2 "Enable SHA2" ON) option(ENABLE_AES "Enable AES" ON) option(ENABLE_CHACHA20 "Enable Chacha20" ON) +option(ENABLE_KYBER "Enable Kyber" ON) option(ENABLE_SKF "Enable SKF module" OFF) option(ENABLE_SDF "Enable SDF module" ON) @@ -477,6 +478,13 @@ if (ENABLE_CHACHA20) endif() +if (ENABLE_KYBER) + message(STATUS "ENABLE_KYBER is ON") + list(APPEND src src/kyber.c) + list(APPEND tests kyber) +endif() + + if (ENABLE_INTEL_RDRAND) include(CheckSourceCompiles) set(CMAKE_REQUIRED_FLAGS "-rdrand") diff --git a/include/gmssl/kyber.h b/include/gmssl/kyber.h new file mode 100644 index 00000000..ced52a5c --- /dev/null +++ b/include/gmssl/kyber.h @@ -0,0 +1,184 @@ +/* + * Copyright 2014-2025 The GmSSL Project. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +#ifndef GMSSL_KYBER_H +#define GMSSL_KYBER_H + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#ifdef __cplusplus +extern "C" { +#endif + + + +#define KYBER_Q 3329 +#define KYBER_ZETA 17 +#define KYBER_N 256 +#define KYBER_ETA2 2 +#define KYBER_POLY_NBYTES (256 * 12 / 8) + +#define KYBER512_K 2 +#define KYBER768_K 3 +#define KYBER1024_K 4 + +#define KYBER512_ETA1 3 +#define KYBER768_ETA1 2 +#define KYBER1024_ETA1 2 + +#define KYBER512_DU 10 +#define KYBER768_DU 10 +#define KYBER1024_DU 11 + +#define KYBER512_DV 4 +#define KYBER769_DV 4 +#define KYBER1024_DV 5 + +#define KYBER_K KYBER512_K +#define KYBER_ETA1 KYBER512_ETA1 +#define KYBER_DU KYBER512_DU +#define KYBER_DV KYBER512_DV + + +#define KYBER_C1_SIZE ((256 * KYBER_DU)/8) +#define KYBER_C2_SIZE ((256 * KYBER_DV)/8) + + +#define KYBER_TEST + + +/* +CRYSTALS-Kyber Algorithm Specifications and Supporing Documentation (version 3.02) + + + FIPS-202 90s + + XOF SHAKE-128 AES256-CTR MGF1-SM3 + H SHA3-256 SHA256 SM3 + G SHA3-512 SHA512 MGF1-SM3 + PRF(s,b) SHAKE-256(s||b) AES256-CTR HKDF-SM3 + KDF SHAKE-256 SHA256 HKDF-SM3 + +*/ + + + +typedef int16_t kyber_poly_t[256]; + +typedef struct { + uint8_t t[KYBER_K][384]; + uint8_t rho[32]; +} KYBER_CPA_PUBLIC_KEY; + +typedef struct { + uint8_t s[KYBER_K][384]; +} KYBER_CPA_PRIVATE_KEY; + +typedef struct { + uint8_t c1[KYBER_K][KYBER_C1_SIZE]; + uint8_t c2[KYBER_C2_SIZE]; +} KYBER_CPA_CIPHERTEXT; + + +typedef KYBER_CPA_PUBLIC_KEY KYBER_PUBLIC_KEY; + +typedef struct { + KYBER_CPA_PRIVATE_KEY sk; + KYBER_CPA_PUBLIC_KEY pk; + uint8_t pk_hash[32]; + uint8_t z[32]; +} KYBER_PRIVATE_KEY; + +typedef KYBER_CPA_CIPHERTEXT KYBER_CIPHERTEXT; + + + +void kyber_h_hash(const uint8_t *in, size_t inlen, uint8_t out[32]); +void kyber_g_hash(const uint8_t *in, size_t inlen, uint8_t out[64]); + + +#define KYBER_FMT_POLY 1 +#define KYBER_FMT_HEX 2 + +int kyber_poly_print(FILE *fp, int fmt, int ind, const char *label, const kyber_poly_t a); +void kyber_poly_set_zero(kyber_poly_t r); + + +int kyber_poly_rand(kyber_poly_t r); +int kyber_poly_uniform_sample(kyber_poly_t r, const uint8_t rho[32], uint8_t j, uint8_t i); +int kyber_poly_cbd_sample(kyber_poly_t r, int eta, const uint8_t secret[32], uint8_t n); +int kyber_poly_equ(const kyber_poly_t a, const kyber_poly_t b); +void kyber_poly_add(kyber_poly_t r, const kyber_poly_t a, const kyber_poly_t b); +void kyber_poly_sub(kyber_poly_t r, const kyber_poly_t a, const kyber_poly_t b); + + +int16_t zeta[256]; + +void init_zeta(void); + +int kyber_poly_ntt(int16_t a[256]); +int kyber_poly_inv_ntt(int16_t a[256]); + +int kyber_poly_ntt_mul(kyber_poly_t r, const kyber_poly_t a, const kyber_poly_t b); +void kyber_poly_copy(kyber_poly_t r, const kyber_poly_t a); + +int kyber_poly_mul(kyber_poly_t r, const kyber_poly_t a, const kyber_poly_t b); + +void kyber_poly_ntt_mul_scalar(kyber_poly_t r, int scalar, const kyber_poly_t a); + +int kyber_poly_to_signed(const kyber_poly_t a, kyber_poly_t r); + +int kyber_poly_from_signed(kyber_poly_t r, const kyber_poly_t a); +int kyber_poly_compress(const kyber_poly_t a, int dbits, kyber_poly_t z); +int kyber_poly_decompress(kyber_poly_t r, int dbits, const kyber_poly_t z); +int kyber_poly_encode12(const kyber_poly_t a, uint8_t out[384]); +int kyber_poly_decode12(kyber_poly_t r, const uint8_t in[384]); + +int kyber_poly_encode10(const kyber_poly_t a, uint8_t out[320]); + +int kyber_poly_decode10(kyber_poly_t r, const uint8_t in[320]); + +int kyber_poly_encode4(const kyber_poly_t a, uint8_t out[128]); + +void kyber_poly_decode4(kyber_poly_t r, const uint8_t in[128]); +void kyber_poly_decode1(kyber_poly_t r, const uint8_t in[32]); +int kyber_poly_encode1(const kyber_poly_t a, uint8_t out[32]); + + + +int kyber_cpa_keygen(KYBER_CPA_PUBLIC_KEY *pk, KYBER_CPA_PRIVATE_KEY *sk); +int kyber_cpa_ciphertext_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_CPA_CIPHERTEXT *c); +int kyber_ciphertext_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_CPA_CIPHERTEXT *c); +int kyber_cpa_public_key_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_CPA_PUBLIC_KEY *pk); +int kyber_cpa_private_key_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_CPA_PRIVATE_KEY *sk); +int kyber_cpa_encrypt(const KYBER_CPA_PUBLIC_KEY *pk, const uint8_t in[32], + const uint8_t rand[32], KYBER_CPA_CIPHERTEXT *out); +int kyber_cpa_decrypt(const KYBER_CPA_PRIVATE_KEY *sk, const KYBER_CPA_CIPHERTEXT *in, uint8_t out[32]); +int kyber_keygen(KYBER_PUBLIC_KEY *pk, KYBER_PRIVATE_KEY *sk); +int kyber_private_key_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_PRIVATE_KEY *sk); +int kyber_public_key_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_PUBLIC_KEY *pk); +int kyber_encap(const KYBER_PUBLIC_KEY *pk, KYBER_CIPHERTEXT *c, uint8_t K[32]); +int kyber_decap(const KYBER_PRIVATE_KEY *sk, const KYBER_CIPHERTEXT *c, uint8_t K[32]); + + +#ifdef __cplusplus +} +#endif +#endif diff --git a/src/kyber.c b/src/kyber.c index ecbe1209..540a78f9 100644 --- a/src/kyber.c +++ b/src/kyber.c @@ -666,411 +666,7 @@ int kyber_poly_encode1(const kyber_poly_t a, uint8_t out[32]) -static int test_kyber_poly_uniform_sample(void) -{ - kyber_poly_t a; - uint8_t rho[32]; - rand_bytes(rho, sizeof(rho)); - - - kyber_poly_uniform_sample(a, rho, 0, 0); - kyber_poly_to_signed(a, a); - - //kyber_poly_print(stderr, 0, 0, "a from uniform sampling", a); - - return 1; -} - -static int test_kyber_poly_cbd_sample(void) -{ - kyber_poly_t a; - uint8_t seed[32]; - - - rand_bytes(seed, sizeof(seed)); - kyber_poly_cbd_sample(a, 2, seed, 0); - kyber_poly_to_signed(a, a); - //kyber_poly_print(stderr, 0, 0, "cbd(eta=2)", a); - - kyber_poly_cbd_sample(a, 3, seed, 0); - kyber_poly_to_signed(a, a); - //kyber_poly_print(stderr, 0, 0, "cbd(eta=3)", a); - - return 1; -} - -static int test_kyber_poly_to_signed(void) -{ - kyber_poly_t a, b; - int i; - - - if (kyber_poly_rand(a) != 1) { - error_print(); - return -1; - } - if (kyber_poly_to_signed(a, b) != 1) { - error_print(); - return -1; - } - for (i = 0; i < 256; i++) { - if (b[i] < -(KYBER_Q - 1)/2 || b[i] > (KYBER_Q - 1)/2) { - error_print(); - return -1; - } - } - - if (kyber_poly_from_signed(b, b) != 1) { - error_print(); - return -1; - } - if (kyber_poly_equ(a, b) != 1) { - error_print(); - return -1; - } - - printf("%s() ok\n", __FUNCTION__); - return 1; -} - -static int test_kyber_poly_ntt(void) -{ - kyber_poly_t a, b; - int i; - - - if (kyber_poly_rand(a) != 1) { - error_print(); - return -1; - } - - memcpy(b, a, sizeof(kyber_poly_t)); - if (kyber_poly_ntt(b) != 1) { - error_print(); - return -1; - } - if (kyber_poly_inv_ntt(b) != 1) { - error_print(); - return -1; - } - - if (kyber_poly_equ(a, b) != 1) { - error_print(); - return -1; - } - - printf("%s() ok\n", __FUNCTION__); - return 1; -} - -/* -#!/bin/sage - - q = 3329 - n = 256 - - R. = PolynomialRing(Integers(q)) - Rq = R.quotient(x^n + 1, 'x') - - # a = 1 + 2*x + 3*x^2 + ... + 256*x^255 - coefficients = list(range(1, n+1)) - a = sum(coeff * x^i for i, coeff in enumerate(coefficients)) - a = Rq(a) - - # b = 256 + 255*x + ... + 1*x^255 - coefficients = list(range(n, 0, -1)) - b = sum(coeff * x^i for i, coeff in enumerate(coefficients)) - b = Rq(b) - - r = a * b - - r = r.lift() # Quotient ring element back to a polynomial - r = r.coefficients(sparse=False) - for i in range(0, n, 16): - print(r[i:i+16]) - -*/ -static int test_kyber_poly_ntt_mul(void) -{ - const kyber_poly_t r = { - 656, 772, 1140, 1758, 2624, 407, 1763, 32, 1870, 617, 2929, 2146, 1595, 1274, 1181, 1314, - 1671, 2250, 3049, 737, 1970, 88, 1747, 287, 2364, 1318, 476, 3165, 2725, 2483, 2437, 2585, - 2925, 126, 844, 1748, 2836, 777, 2227, 526, 2330, 979, 3129, 2120, 1279, 604, 93, 3073, - 2884, 2853, 2978, 3257, 359, 940, 1669, 2544, 234, 1395, 2696, 806, 2381, 761, 2602, 1244, - 14, 2239, 1259, 401, 2992, 2372, 1868, 1478, 1200, 1032, 972, 1018, 1168, 1420, 1772, 2222, - 2768, 79, 811, 1633, 2543, 210, 1290, 2452, 365, 1685, 3081, 1222, 2764, 1047, 2727, 1144, - 2954, 1497, 100, 2090, 807, 2907, 1730, 603, 2853, 1820, 831, 3213, 2306, 1437, 604, 3134, - 2367, 1630, 921, 238, 2908, 2271, 1654, 1055, 472, 3232, 2675, 2128, 1589, 1056, 527, 0, - 2802, 2273, 1740, 1201, 654, 97, 2857, 2274, 1675, 1058, 421, 3091, 2408, 1699, 962, 195, - 2725, 1892, 1023, 116, 2498, 1509, 476, 2726, 1599, 422, 2522, 1239, 3229, 1832, 375, 2185, - 602, 2282, 565, 2107, 248, 1644, 2964, 877, 2039, 3119, 786, 1696, 2518, 3250, 561, 1107, - 1557, 1909, 2161, 2311, 2357, 2297, 2129, 1851, 1461, 957, 337, 2928, 2070, 1090, 3315, 2085, - 727, 2568, 948, 2523, 633, 1934, 3095, 785, 1660, 2389, 2970, 72, 351, 476, 445, 256, - 3236, 2725, 2050, 1209, 200, 2350, 999, 2803, 1102, 2552, 493, 1581, 2485, 3203, 404, 744, - 892, 846, 604, 164, 2853, 2011, 965, 3042, 1582, 3241, 1359, 2592, 280, 1079, 1658, 2015, - 2148, 2055, 1734, 1183, 400, 2712, 1459, 3297, 1566, 2922, 705, 1571, 2189, 2557, 2673, 2535, - }; - kyber_poly_t a; // [1, 2, 3, ..., 256] - kyber_poly_t b; // [256, 255, ..., 1] - kyber_poly_t r_; - int i; - - for (i = 0; i < 256; i++) { - a[i] = i + 1; - b[i] = 256 - i; - } - - kyber_poly_ntt(a); - kyber_poly_ntt(b); - - kyber_poly_ntt_mul(r_, a, b); - kyber_poly_inv_ntt(r_); - - if (kyber_poly_equ(r_, r) != 1) { - error_print(); - return -1; - } - - printf("%s() ok\n", __FUNCTION__); - return 1; -} - -static int test_kyber_poly_add(void) -{ - kyber_poly_t a, b; - - if (kyber_poly_rand(a) != 1) { - error_print(); - return -1; - } - - // (a + a) - a =?= a - kyber_poly_add(b, a, a); - kyber_poly_sub(b, b, a); - if (kyber_poly_equ(a, b) != 1) { - error_print(); - return -1; - } - - // (a + a) + (a + a) =?= 4*a - kyber_poly_add(b, a, a); - kyber_poly_add(b, b, b); - kyber_poly_ntt_mul_scalar(a, 4, a); - - if (kyber_poly_equ(a, b) != 1) { - error_print(); - return -1; - } - - printf("%s() ok\n", __FUNCTION__); - return 1; -} - - - - -static int round_div(int a, int b) -{ - return (a + (b + 1)/2)/b; -} - -// a' = Decompress(Compress(a, d), d), check |a - a' mod+- q| <= round(q/2^(d + 1)) -static int test_kyber_poly_compress(void) -{ - kyber_poly_t a, b; - int16_t bound; - int i; - - if (kyber_poly_rand(a) != 1) { - error_print(); - return -1; - } - - // compress(a, 10) - if (kyber_poly_compress(a, 10, b) != 1) { - error_print(); - return -1; - } - if (kyber_poly_decompress(b, 10, b) != 1) { - error_print(); - return -1; - } - kyber_poly_sub(b, a, b); - if (kyber_poly_to_signed(b, b) != 1) { - error_print(); - return -1; - } - bound = round_div(KYBER_Q, 1 << (10 + 1)); - //printf("compress(-, 10) bound = %d\n", bound); - for (i = 0; i < 256; i++) { - if (b[i] < -bound || b[i] > bound) { - error_print(); - return -1; - } - } - - // compress(a, 4) - if (kyber_poly_compress(a, 4, b) != 1) { - error_print(); - return -1; - } - if (kyber_poly_decompress(b, 4, b) != 1) { - error_print(); - return -1; - } - kyber_poly_sub(b, a, b); - if (kyber_poly_to_signed(b, b) != 1) { - error_print(); - return -1; - } - bound = round_div(KYBER_Q, 1 << (4 + 1)); - //printf("compress(-, 4) bound = %d\n", bound); - for (i = 0; i < 256; i++) { - if (b[i] < -bound || b[i] > bound) { - error_print(); - return -1; - } - } - - // compress(a, 1) - if (kyber_poly_compress(a, 1, b) != 1) { - error_print(); - return -1; - } - if (kyber_poly_decompress(b, 1, b) != 1) { - error_print(); - return -1; - } - kyber_poly_sub(b, a, b); - if (kyber_poly_to_signed(b, b) != 1) { - error_print(); - return -1; - } - bound = round_div(KYBER_Q, 1 << (1 + 1)); - //printf("compress(-, 1) bound = %d\n", bound); - for (i = 0; i < 256; i++) { - if (b[i] < -bound || b[i] > bound) { - // FIXME: might failed - error_print(); - return -1; - } - } - - printf("%s() ok\n", __FUNCTION__); - return 1; -} - -static int test_kyber_poly_encode12(void) -{ - kyber_poly_t a; - kyber_poly_t b; - uint8_t bytes[384]; - - if (kyber_poly_rand(a) != 1) { - error_print(); - return -1; - } - if (kyber_poly_encode12(a, bytes) != 1) { - error_print(); - return -1; - } - if (kyber_poly_decode12(b, bytes) != 1) { - error_print(); - return -1; - } - if (kyber_poly_equ(a, b) != 1) { - error_print(); - return -1; - } - - printf("%s() ok\n", __FUNCTION__); - return 1; -} - -static int test_kyber_poly_encode10(void) -{ - kyber_poly_t a; - kyber_poly_t b; - uint8_t bytes[320]; - - if (kyber_poly_rand(a) != 1) { - error_print(); - return -1; - } - if (kyber_poly_compress(a, 10, a) != 1) { - error_print(); - return -1; - } - if (kyber_poly_encode10(a, bytes) != 1) { - error_print(); - return -1; - } - kyber_poly_decode10(b, bytes); - if (kyber_poly_equ(a, b) != 1) { - error_print(); - return -1; - } - - printf("%s() ok\n", __FUNCTION__); - return 1; -} - -static int test_kyber_poly_encode4(void) -{ - kyber_poly_t a; - kyber_poly_t b; - uint8_t bytes[128]; - - if (kyber_poly_rand(a) != 1) { - error_print(); - return -1; - } - if (kyber_poly_compress(a, 4, a) != 1) { - error_print(); - return -1; - } - if (kyber_poly_encode4(a, bytes) != 1) { - error_print(); - return -1; - } - kyber_poly_decode4(b, bytes); - if (kyber_poly_equ(a, b) != 1) { - error_print(); - return -1; - } - - printf("%s() ok\n", __FUNCTION__); - return 1; -} - -static int test_kyber_poly_encode1(void) -{ - kyber_poly_t a; - kyber_poly_t b; - uint8_t bytes[32]; - - if (kyber_poly_rand(a) != 1) { - error_print(); - return -1; - } - if (kyber_poly_compress(a, 1, a) != 1) { - error_print(); - return -1; - } - if (kyber_poly_encode1(a, bytes) != 1) { - error_print(); - return -1; - } - kyber_poly_decode1(b, bytes); - if (kyber_poly_equ(a, b) != 1) { - error_print(); - return -1; - } - - printf("%s() ok\n", __FUNCTION__); - return 1; -} int kyber_cpa_keygen(KYBER_CPA_PUBLIC_KEY *pk, KYBER_CPA_PRIVATE_KEY *sk) { @@ -1507,115 +1103,3 @@ int kyber_decap(const KYBER_PRIVATE_KEY *sk, const KYBER_CIPHERTEXT *c, uint8_t gmssl_secure_clear(K_r, sizeof(K_r)); return 1; } - -static int test_kyber_cpa(void) -{ - KYBER_CPA_PUBLIC_KEY pk; - KYBER_CPA_PRIVATE_KEY sk; - KYBER_CPA_CIPHERTEXT c; - uint8_t m[32]; - uint8_t r[32]; - uint8_t m_[32]; - - if (rand_bytes(m, 32) != 1) { - error_print(); - return -1; - } - if (rand_bytes(r, 32) != 1) { - error_print(); - return -1; - } - - if (kyber_cpa_keygen(&pk, &sk) != 1) { - error_print(); - return -1; - } - kyber_cpa_public_key_print(stderr, 0, 0, "publicKey", &pk); - kyber_cpa_private_key_print(stderr, 0, 0, "privateKey", &sk); - - if (kyber_cpa_encrypt(&pk, m, r, &c) != 1) { - error_print(); - return -1; - } - kyber_cpa_ciphertext_print(stderr, 0, 0, "ciphertext", &c); - - if (kyber_cpa_decrypt(&sk, &c, m_) != 1) { - error_print(); - return -1; - } - if (memcmp(m_, m, 32) != 0) { - error_print(); - return -1; - } - - printf("%s() ok\n", __FUNCTION__); - return 1; -} - -static int test_kyber_kem(void) -{ - KYBER_PRIVATE_KEY sk; - KYBER_PUBLIC_KEY pk; - KYBER_CIPHERTEXT c; - uint8_t K[32]; - uint8_t K_[32]; - - if (kyber_keygen(&pk, &sk) != 1) { - error_print(); - return -1; - } - - kyber_public_key_print(stderr, 0, 0, "pk", &pk); - kyber_private_key_print(stderr, 0, 0, "sk", &sk); - - - if (kyber_encap(&pk, &c, K) != 1) { - error_print(); - return -1; - } - kyber_ciphertext_print(stderr, 0, 0, "ciphertext", &c); - format_bytes(stderr, 0, 0, "KEM_K", K, 32); - - if (kyber_decap(&sk, &c, K_) != 1) { - error_print(); - return -1; - } - format_bytes(stderr, 0, 0, "DEC_K", K_, 32); - - - if (memcmp(K_, K, 32) != 0) { - error_print(); - return -1; - } - - printf("%s() ok\n", __FUNCTION__); - return 1; -} - - - - -int main(void) -{ - init_zeta(); - - if (test_kyber_poly_uniform_sample() != 1) goto err; - if (test_kyber_poly_cbd_sample() != 1) goto err; - if (test_kyber_poly_to_signed() != 1) goto err; - if (test_kyber_poly_compress() != 1) goto err; - if (test_kyber_poly_encode12() != 1) goto err; - if (test_kyber_poly_encode10() != 1) goto err; - if (test_kyber_poly_encode4() != 1) goto err; - if (test_kyber_poly_encode1() != 1) goto err; - if (test_kyber_poly_add() != 1) goto err; - if (test_kyber_poly_ntt() != 1) goto err; - if (test_kyber_poly_ntt_mul() != 1) goto err; - if (test_kyber_cpa() != 1) goto err; - if (test_kyber_kem() != 1) goto err; - - printf("%s all tests passed\n", __FILE__); - return 0; -err: - error_print(); - return 1; -} diff --git a/tests/kybertest.c b/tests/kybertest.c new file mode 100644 index 00000000..06bcf7f4 --- /dev/null +++ b/tests/kybertest.c @@ -0,0 +1,537 @@ +/* + * Copyright 2014-2025 The GmSSL Project. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +#include +#include +#include +#include +#include +#include +#include + + + + + +static int test_kyber_poly_uniform_sample(void) +{ + kyber_poly_t a; + uint8_t rho[32]; + + rand_bytes(rho, sizeof(rho)); + + + kyber_poly_uniform_sample(a, rho, 0, 0); + kyber_poly_to_signed(a, a); + + //kyber_poly_print(stderr, 0, 0, "a from uniform sampling", a); + + return 1; +} + +static int test_kyber_poly_cbd_sample(void) +{ + kyber_poly_t a; + uint8_t seed[32]; + + + rand_bytes(seed, sizeof(seed)); + kyber_poly_cbd_sample(a, 2, seed, 0); + kyber_poly_to_signed(a, a); + //kyber_poly_print(stderr, 0, 0, "cbd(eta=2)", a); + + kyber_poly_cbd_sample(a, 3, seed, 0); + kyber_poly_to_signed(a, a); + //kyber_poly_print(stderr, 0, 0, "cbd(eta=3)", a); + + return 1; +} + +static int test_kyber_poly_to_signed(void) +{ + kyber_poly_t a, b; + int i; + + + if (kyber_poly_rand(a) != 1) { + error_print(); + return -1; + } + if (kyber_poly_to_signed(a, b) != 1) { + error_print(); + return -1; + } + for (i = 0; i < 256; i++) { + if (b[i] < -(KYBER_Q - 1)/2 || b[i] > (KYBER_Q - 1)/2) { + error_print(); + return -1; + } + } + + if (kyber_poly_from_signed(b, b) != 1) { + error_print(); + return -1; + } + if (kyber_poly_equ(a, b) != 1) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +static int test_kyber_poly_ntt(void) +{ + kyber_poly_t a, b; + int i; + + + if (kyber_poly_rand(a) != 1) { + error_print(); + return -1; + } + + memcpy(b, a, sizeof(kyber_poly_t)); + if (kyber_poly_ntt(b) != 1) { + error_print(); + return -1; + } + if (kyber_poly_inv_ntt(b) != 1) { + error_print(); + return -1; + } + + if (kyber_poly_equ(a, b) != 1) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +/* +#!/bin/sage + + q = 3329 + n = 256 + + R. = PolynomialRing(Integers(q)) + Rq = R.quotient(x^n + 1, 'x') + + # a = 1 + 2*x + 3*x^2 + ... + 256*x^255 + coefficients = list(range(1, n+1)) + a = sum(coeff * x^i for i, coeff in enumerate(coefficients)) + a = Rq(a) + + # b = 256 + 255*x + ... + 1*x^255 + coefficients = list(range(n, 0, -1)) + b = sum(coeff * x^i for i, coeff in enumerate(coefficients)) + b = Rq(b) + + r = a * b + + r = r.lift() # Quotient ring element back to a polynomial + r = r.coefficients(sparse=False) + for i in range(0, n, 16): + print(r[i:i+16]) + +*/ +static int test_kyber_poly_ntt_mul(void) +{ + const kyber_poly_t r = { + 656, 772, 1140, 1758, 2624, 407, 1763, 32, 1870, 617, 2929, 2146, 1595, 1274, 1181, 1314, + 1671, 2250, 3049, 737, 1970, 88, 1747, 287, 2364, 1318, 476, 3165, 2725, 2483, 2437, 2585, + 2925, 126, 844, 1748, 2836, 777, 2227, 526, 2330, 979, 3129, 2120, 1279, 604, 93, 3073, + 2884, 2853, 2978, 3257, 359, 940, 1669, 2544, 234, 1395, 2696, 806, 2381, 761, 2602, 1244, + 14, 2239, 1259, 401, 2992, 2372, 1868, 1478, 1200, 1032, 972, 1018, 1168, 1420, 1772, 2222, + 2768, 79, 811, 1633, 2543, 210, 1290, 2452, 365, 1685, 3081, 1222, 2764, 1047, 2727, 1144, + 2954, 1497, 100, 2090, 807, 2907, 1730, 603, 2853, 1820, 831, 3213, 2306, 1437, 604, 3134, + 2367, 1630, 921, 238, 2908, 2271, 1654, 1055, 472, 3232, 2675, 2128, 1589, 1056, 527, 0, + 2802, 2273, 1740, 1201, 654, 97, 2857, 2274, 1675, 1058, 421, 3091, 2408, 1699, 962, 195, + 2725, 1892, 1023, 116, 2498, 1509, 476, 2726, 1599, 422, 2522, 1239, 3229, 1832, 375, 2185, + 602, 2282, 565, 2107, 248, 1644, 2964, 877, 2039, 3119, 786, 1696, 2518, 3250, 561, 1107, + 1557, 1909, 2161, 2311, 2357, 2297, 2129, 1851, 1461, 957, 337, 2928, 2070, 1090, 3315, 2085, + 727, 2568, 948, 2523, 633, 1934, 3095, 785, 1660, 2389, 2970, 72, 351, 476, 445, 256, + 3236, 2725, 2050, 1209, 200, 2350, 999, 2803, 1102, 2552, 493, 1581, 2485, 3203, 404, 744, + 892, 846, 604, 164, 2853, 2011, 965, 3042, 1582, 3241, 1359, 2592, 280, 1079, 1658, 2015, + 2148, 2055, 1734, 1183, 400, 2712, 1459, 3297, 1566, 2922, 705, 1571, 2189, 2557, 2673, 2535, + }; + kyber_poly_t a; // [1, 2, 3, ..., 256] + kyber_poly_t b; // [256, 255, ..., 1] + kyber_poly_t r_; + int i; + + for (i = 0; i < 256; i++) { + a[i] = i + 1; + b[i] = 256 - i; + } + + kyber_poly_ntt(a); + kyber_poly_ntt(b); + + kyber_poly_ntt_mul(r_, a, b); + kyber_poly_inv_ntt(r_); + + if (kyber_poly_equ(r_, r) != 1) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +static int test_kyber_poly_add(void) +{ + kyber_poly_t a, b; + + if (kyber_poly_rand(a) != 1) { + error_print(); + return -1; + } + + // (a + a) - a =?= a + kyber_poly_add(b, a, a); + kyber_poly_sub(b, b, a); + if (kyber_poly_equ(a, b) != 1) { + error_print(); + return -1; + } + + // (a + a) + (a + a) =?= 4*a + kyber_poly_add(b, a, a); + kyber_poly_add(b, b, b); + kyber_poly_ntt_mul_scalar(a, 4, a); + + if (kyber_poly_equ(a, b) != 1) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + + + + +static int round_div(int a, int b) +{ + return (a + (b + 1)/2)/b; +} + +// a' = Decompress(Compress(a, d), d), check |a - a' mod+- q| <= round(q/2^(d + 1)) +static int test_kyber_poly_compress(void) +{ + kyber_poly_t a, b; + int16_t bound; + int i; + + if (kyber_poly_rand(a) != 1) { + error_print(); + return -1; + } + + // compress(a, 10) + if (kyber_poly_compress(a, 10, b) != 1) { + error_print(); + return -1; + } + if (kyber_poly_decompress(b, 10, b) != 1) { + error_print(); + return -1; + } + kyber_poly_sub(b, a, b); + if (kyber_poly_to_signed(b, b) != 1) { + error_print(); + return -1; + } + bound = round_div(KYBER_Q, 1 << (10 + 1)); + //printf("compress(-, 10) bound = %d\n", bound); + for (i = 0; i < 256; i++) { + if (b[i] < -bound || b[i] > bound) { + error_print(); + return -1; + } + } + + // compress(a, 4) + if (kyber_poly_compress(a, 4, b) != 1) { + error_print(); + return -1; + } + if (kyber_poly_decompress(b, 4, b) != 1) { + error_print(); + return -1; + } + kyber_poly_sub(b, a, b); + if (kyber_poly_to_signed(b, b) != 1) { + error_print(); + return -1; + } + bound = round_div(KYBER_Q, 1 << (4 + 1)); + //printf("compress(-, 4) bound = %d\n", bound); + for (i = 0; i < 256; i++) { + if (b[i] < -bound || b[i] > bound) { + error_print(); + return -1; + } + } + + // compress(a, 1) + if (kyber_poly_compress(a, 1, b) != 1) { + error_print(); + return -1; + } + if (kyber_poly_decompress(b, 1, b) != 1) { + error_print(); + return -1; + } + kyber_poly_sub(b, a, b); + if (kyber_poly_to_signed(b, b) != 1) { + error_print(); + return -1; + } + bound = round_div(KYBER_Q, 1 << (1 + 1)); + //printf("compress(-, 1) bound = %d\n", bound); + for (i = 0; i < 256; i++) { + if (b[i] < -bound || b[i] > bound) { + // FIXME: might failed + error_print(); + return -1; + } + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +static int test_kyber_poly_encode12(void) +{ + kyber_poly_t a; + kyber_poly_t b; + uint8_t bytes[384]; + + if (kyber_poly_rand(a) != 1) { + error_print(); + return -1; + } + if (kyber_poly_encode12(a, bytes) != 1) { + error_print(); + return -1; + } + if (kyber_poly_decode12(b, bytes) != 1) { + error_print(); + return -1; + } + if (kyber_poly_equ(a, b) != 1) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +static int test_kyber_poly_encode10(void) +{ + kyber_poly_t a; + kyber_poly_t b; + uint8_t bytes[320]; + + if (kyber_poly_rand(a) != 1) { + error_print(); + return -1; + } + if (kyber_poly_compress(a, 10, a) != 1) { + error_print(); + return -1; + } + if (kyber_poly_encode10(a, bytes) != 1) { + error_print(); + return -1; + } + kyber_poly_decode10(b, bytes); + if (kyber_poly_equ(a, b) != 1) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +static int test_kyber_poly_encode4(void) +{ + kyber_poly_t a; + kyber_poly_t b; + uint8_t bytes[128]; + + if (kyber_poly_rand(a) != 1) { + error_print(); + return -1; + } + if (kyber_poly_compress(a, 4, a) != 1) { + error_print(); + return -1; + } + if (kyber_poly_encode4(a, bytes) != 1) { + error_print(); + return -1; + } + kyber_poly_decode4(b, bytes); + if (kyber_poly_equ(a, b) != 1) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +static int test_kyber_poly_encode1(void) +{ + kyber_poly_t a; + kyber_poly_t b; + uint8_t bytes[32]; + + if (kyber_poly_rand(a) != 1) { + error_print(); + return -1; + } + if (kyber_poly_compress(a, 1, a) != 1) { + error_print(); + return -1; + } + if (kyber_poly_encode1(a, bytes) != 1) { + error_print(); + return -1; + } + kyber_poly_decode1(b, bytes); + if (kyber_poly_equ(a, b) != 1) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + + +static int test_kyber_cpa(void) +{ + KYBER_CPA_PUBLIC_KEY pk; + KYBER_CPA_PRIVATE_KEY sk; + KYBER_CPA_CIPHERTEXT c; + uint8_t m[32]; + uint8_t r[32]; + uint8_t m_[32]; + + if (rand_bytes(m, 32) != 1) { + error_print(); + return -1; + } + if (rand_bytes(r, 32) != 1) { + error_print(); + return -1; + } + + if (kyber_cpa_keygen(&pk, &sk) != 1) { + error_print(); + return -1; + } + kyber_cpa_public_key_print(stderr, 0, 0, "publicKey", &pk); + kyber_cpa_private_key_print(stderr, 0, 0, "privateKey", &sk); + + if (kyber_cpa_encrypt(&pk, m, r, &c) != 1) { + error_print(); + return -1; + } + kyber_cpa_ciphertext_print(stderr, 0, 0, "ciphertext", &c); + + if (kyber_cpa_decrypt(&sk, &c, m_) != 1) { + error_print(); + return -1; + } + if (memcmp(m_, m, 32) != 0) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +static int test_kyber_kem(void) +{ + KYBER_PRIVATE_KEY sk; + KYBER_PUBLIC_KEY pk; + KYBER_CIPHERTEXT c; + uint8_t K[32]; + uint8_t K_[32]; + + if (kyber_keygen(&pk, &sk) != 1) { + error_print(); + return -1; + } + + kyber_public_key_print(stderr, 0, 0, "pk", &pk); + kyber_private_key_print(stderr, 0, 0, "sk", &sk); + + + if (kyber_encap(&pk, &c, K) != 1) { + error_print(); + return -1; + } + kyber_ciphertext_print(stderr, 0, 0, "ciphertext", &c); + format_bytes(stderr, 0, 0, "KEM_K", K, 32); + + if (kyber_decap(&sk, &c, K_) != 1) { + error_print(); + return -1; + } + format_bytes(stderr, 0, 0, "DEC_K", K_, 32); + + + if (memcmp(K_, K, 32) != 0) { + error_print(); + return -1; + } + + printf("%s() ok\n", __FUNCTION__); + return 1; +} + +int main(void) +{ + init_zeta(); + + if (test_kyber_poly_uniform_sample() != 1) goto err; + if (test_kyber_poly_cbd_sample() != 1) goto err; + if (test_kyber_poly_to_signed() != 1) goto err; + if (test_kyber_poly_compress() != 1) goto err; + if (test_kyber_poly_encode12() != 1) goto err; + if (test_kyber_poly_encode10() != 1) goto err; + if (test_kyber_poly_encode4() != 1) goto err; + if (test_kyber_poly_encode1() != 1) goto err; + if (test_kyber_poly_add() != 1) goto err; + if (test_kyber_poly_ntt() != 1) goto err; + if (test_kyber_poly_ntt_mul() != 1) goto err; + if (test_kyber_cpa() != 1) goto err; + if (test_kyber_kem() != 1) goto err; + + printf("%s all tests passed\n", __FILE__); + return 0; +err: + error_print(); + return 1; +} +