Update kyber.c

KEM passed
This commit is contained in:
Zhi Guan
2024-07-28 22:09:14 +08:00
parent ec23ce0677
commit 9f4dac228e

View File

@@ -54,6 +54,22 @@
#define KYBER_TEST #define KYBER_TEST
/*
CRYSTALS-Kyber Algorithm Specifications and Supporing Documentation (version 3.02)
FIPS-202 90s
XOF SHAKE-128 AES256-CTR MGF1-SM3
H SHA3-256 SHA256 SM3
G SHA3-512 SHA512 MGF1-SM3
PRF(s,b) SHAKE-256(s||b) AES256-CTR HKDF-SM3
KDF SHAKE-256 SHA256 HKDF-SM3
*/
typedef int16_t kyber_poly_t[256]; typedef int16_t kyber_poly_t[256];
typedef struct { typedef struct {
@@ -84,8 +100,6 @@ typedef KYBER_CPA_CIPHERTEXT KYBER_CIPHERTEXT;
void kyber_h_hash(const uint8_t *in, size_t inlen, uint8_t out[32]) void kyber_h_hash(const uint8_t *in, size_t inlen, uint8_t out[32])
{ {
SM3_CTX ctx; SM3_CTX ctx;
@@ -128,12 +142,18 @@ static int kyber_prf(const uint8_t seed[32], uint8_t N, size_t outlen, uint8_t *
return 1; return 1;
} }
static int kyber_kdf(const uint8_t in[64], uint8_t out[32]) static int kyber_kdf(const uint8_t in[64], uint8_t out[32])
{ {
return 0; uint8_t key[32];
sm3_hkdf_extract(NULL, 0, in, 64, key);
sm3_hkdf_expand(key, NULL, 0, 32, out);
gmssl_secure_clear(key, 32);
return 1;
} }
#define KYBER_FMT_POLY 1
#define KYBER_FMT_HEX 2
int kyber_poly_print(FILE *fp, int fmt, int ind, const char *label, const kyber_poly_t a) int kyber_poly_print(FILE *fp, int fmt, int ind, const char *label, const kyber_poly_t a)
{ {
int i; int i;
@@ -816,7 +836,7 @@ static int test_kyber_poly_ntt_mul(void)
return 1; return 1;
} }
static int test_kyber_poly_ops(void) static int test_kyber_poly_add(void)
{ {
kyber_poly_t a, b; kyber_poly_t a, b;
@@ -931,7 +951,7 @@ static int test_kyber_poly_compress(void)
//printf("compress(-, 1) bound = %d\n", bound); //printf("compress(-, 1) bound = %d\n", bound);
for (i = 0; i < 256; i++) { for (i = 0; i < 256; i++) {
if (b[i] < -bound || b[i] > bound) { if (b[i] < -bound || b[i] > bound) {
// 这块是有可能出现错误的 // FIXME: might failed
error_print(); error_print();
return -1; return -1;
} }
@@ -1058,7 +1078,6 @@ int kyber_cpa_keygen(KYBER_CPA_PUBLIC_KEY *pk, KYBER_CPA_PRIVATE_KEY *sk)
kyber_poly_t s[KYBER_K]; kyber_poly_t s[KYBER_K];
kyber_poly_t e[KYBER_K]; kyber_poly_t e[KYBER_K];
kyber_poly_t t[KYBER_K]; kyber_poly_t t[KYBER_K];
uint8_t d[64]; uint8_t d[64];
uint8_t *rho = d; uint8_t *rho = d;
uint8_t *sigma = d + 32; uint8_t *sigma = d + 32;
@@ -1072,43 +1091,32 @@ int kyber_cpa_keygen(KYBER_CPA_PUBLIC_KEY *pk, KYBER_CPA_PRIVATE_KEY *sk)
kyber_g_hash(d, 32, d); kyber_g_hash(d, 32, d);
format_bytes(stderr, 0, 0, "rho", rho, 32);
format_bytes(stderr, 0, 0, "sigma", sigma, 32);
// AHat[i][j] = Parse(XOR(rho, j, i)) // AHat[i][j] = Parse(XOR(rho, j, i))
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
for (j = 0; j < KYBER_K; j++) { for (j = 0; j < KYBER_K; j++) {
kyber_poly_uniform_sample(A[i][j], rho, j, i); kyber_poly_uniform_sample(A[i][j], rho, j, i);
kyber_poly_print(stderr, 0, 0, "A[i][j]", A[i][j]);
} }
} }
// s[i] = CBD_eta1(PRF(sigma, N++)) // s[i] = CBD_eta1(PRF(sigma, N++))
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
kyber_poly_cbd_sample(s[i], KYBER_ETA1, sigma, N); kyber_poly_cbd_sample(s[i], KYBER_ETA1, sigma, N);
//kyber_poly_set_all(s[i], 1);
kyber_poly_print(stderr, 0, 0, "s[i]", s[i]);
N++; N++;
} }
// e[i] = CBD_eta1(PRF(sigma, N++)) // e[i] = CBD_eta1(PRF(sigma, N++))
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
kyber_poly_cbd_sample(e[i], KYBER_ETA1, sigma, N); kyber_poly_cbd_sample(e[i], KYBER_ETA1, sigma, N);
//kyber_poly_set_all(e[i], 0);
kyber_poly_print(stderr, 0, 0, "e[i]", e[i]);
N++; N++;
} }
// sHat = NTT(s) // sHat = NTT(s)
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
kyber_poly_ntt(s[i]); kyber_poly_ntt(s[i]);
kyber_poly_print(stderr, 0, 0, "ntt(s[i])", s[i]);
} }
// eHat = NTT(e) // eHat = NTT(e)
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
kyber_poly_ntt(e[i]); kyber_poly_ntt(e[i]);
kyber_poly_print(stderr, 0, 0, "ntt(e[i])", e[i]);
} }
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
@@ -1123,14 +1131,8 @@ int kyber_cpa_keygen(KYBER_CPA_PUBLIC_KEY *pk, KYBER_CPA_PRIVATE_KEY *sk)
kyber_poly_add(t[i], t[i], tmp); kyber_poly_add(t[i], t[i], tmp);
} }
kyber_poly_add(t[i], t[i], e[i]); kyber_poly_add(t[i], t[i], e[i]);
kyber_poly_print(stderr, 0, 0, "ntt(t[i])", t[i]);
} }
// 这里实际上t没有压缩就是原来的值
// t - A^T * s 实际上就是很小的值
// output (pk, sk) // output (pk, sk)
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
kyber_poly_encode12(t[i], pk->t[i]); kyber_poly_encode12(t[i], pk->t[i]);
@@ -1138,8 +1140,9 @@ int kyber_cpa_keygen(KYBER_CPA_PUBLIC_KEY *pk, KYBER_CPA_PRIVATE_KEY *sk)
} }
memcpy(pk->rho, rho, 32); memcpy(pk->rho, rho, 32);
gmssl_secure_clear(d, sizeof(d));
fprintf(stderr, "\n"); gmssl_secure_clear(s, sizeof(s));
gmssl_secure_clear(e, sizeof(e));
return 1; return 1;
} }
@@ -1165,15 +1168,58 @@ int kyber_cpa_keygen(KYBER_CPA_PUBLIC_KEY *pk, KYBER_CPA_PRIVATE_KEY *sk)
*/ */
int kyber_cpa_ciphertext_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_CPA_CIPHERTEXT *c)
{
int i;
format_print(fp, fmt, ind, "%s\n", label);
ind += 4;
for (i = 0; i < KYBER_K; i++) {
format_print(fp, fmt, ind, "c1[%d] (Compress10(u[%d]))", i, i);
format_bytes(fp, fmt, 0, "", c->c1[i], KYBER_C1_SIZE);
}
format_bytes(fp, fmt, ind, "c2 (Compress4(v))", c->c2, KYBER_C2_SIZE);
return 1;
}
int kyber_ciphertext_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_CPA_CIPHERTEXT *c)
{
return kyber_cpa_ciphertext_print(fp, fmt, ind, label, c);
}
int kyber_cpa_public_key_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_CPA_PUBLIC_KEY *pk)
{
int i;
format_print(fp, fmt, ind, "%s\n", label);
ind += 4;
for (i = 0; i < KYBER_K; i++) {
format_print(fp, fmt, ind, "ntt(t[%d])", i);
format_bytes(fp, fmt, 0, "", pk->t[i], 384);
}
format_bytes(fp, fmt, ind, "rho", pk->rho, 32);
return 1;
}
int kyber_cpa_private_key_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_CPA_PRIVATE_KEY *sk)
{
int i;
format_print(fp, fmt, ind, "%s\n", label);
ind += 4;
for (i = 0; i < KYBER_K; i++) {
format_print(fp, fmt, ind, "ntt(s[%d])", i);
format_bytes(fp, fmt, 0, "", sk->s[i], 384);
}
return 1;
}
int kyber_cpa_encrypt(const KYBER_CPA_PUBLIC_KEY *pk, const uint8_t in[32], int kyber_cpa_encrypt(const KYBER_CPA_PUBLIC_KEY *pk, const uint8_t in[32],
const uint8_t rand[32], KYBER_CPA_CIPHERTEXT *out) const uint8_t rand[32], KYBER_CPA_CIPHERTEXT *out)
{ {
int i, j;
int N = 0;
kyber_poly_t A[KYBER_K][KYBER_K]; kyber_poly_t A[KYBER_K][KYBER_K];
kyber_poly_t t[KYBER_K]; kyber_poly_t t[KYBER_K];
kyber_poly_t r[KYBER_K]; kyber_poly_t r[KYBER_K];
kyber_poly_t u[KYBER_K]; kyber_poly_t u[KYBER_K];
@@ -1181,52 +1227,41 @@ int kyber_cpa_encrypt(const KYBER_CPA_PUBLIC_KEY *pk, const uint8_t in[32],
kyber_poly_t e2; kyber_poly_t e2;
kyber_poly_t v; kyber_poly_t v;
kyber_poly_t m; kyber_poly_t m;
int i, j;
printf("%s() ok\n", __FUNCTION__); int N = 0;
// tHat = Decode12(pk) // tHat = Decode12(pk)
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
kyber_poly_decode12(t[i], pk->t[i]); kyber_poly_decode12(t[i], pk->t[i]);
kyber_poly_print(stderr, 0, 0, "ntt(t[i])", t[i]);
} }
// AHat^T[i][j] = Parse(XOR(rho, i, j)) // AHat^T[i][j] = Parse(XOR(rho, i, j))
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
for (j = 0; j < KYBER_K; j++) { for (j = 0; j < KYBER_K; j++) {
kyber_poly_uniform_sample(A[i][j], pk->rho, i, j); kyber_poly_uniform_sample(A[i][j], pk->rho, i, j);
kyber_poly_print(stderr, 0, 0, "A[i][j]", A[i][j]);
} }
} }
// r[i] = CBD_eta1(PRF(rand, N++)) // r[i] = CBD_eta1(PRF(rand, N++))
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
kyber_poly_cbd_sample(r[i], KYBER_ETA1, rand, N); kyber_poly_cbd_sample(r[i], KYBER_ETA1, rand, N);
//kyber_poly_set_all(r[i], 2);
kyber_poly_print(stderr, 0, 0, "r[i]", r[i]);
N++; N++;
} }
// e1[i] = CBD_eta2(PRF(rand, N++)) // e1[i] = CBD_eta2(PRF(rand, N++))
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
kyber_poly_cbd_sample(e1[i], KYBER_ETA2, rand, N); kyber_poly_cbd_sample(e1[i], KYBER_ETA2, rand, N);
//kyber_poly_set_all(e1[i], 0);
kyber_poly_print(stderr, 0, 0, "e1[i]", e1[i]);
N++; N++;
} }
// e2 = CBD_eta2(PRF(rand, N)) // e2 = CBD_eta2(PRF(rand, N))
kyber_poly_cbd_sample(e2, KYBER_ETA2, rand, N); kyber_poly_cbd_sample(e2, KYBER_ETA2, rand, N);
//kyber_poly_set_all(e2, 0);
kyber_poly_print(stderr, 0, 0, "e2", e2);
// rHat = NTT(r) // rHat = NTT(r)
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
kyber_poly_ntt(r[i]); kyber_poly_ntt(r[i]);
kyber_poly_print(stderr, 0, 0, "ntt(r[i])", r[i]);
} }
// 实际上 u == A^T * r + e1
// u = NTT^-1(A^T * r) + e1 // u = NTT^-1(A^T * r) + e1
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
kyber_poly_set_zero(u[i]); kyber_poly_set_zero(u[i]);
@@ -1240,11 +1275,9 @@ int kyber_cpa_encrypt(const KYBER_CPA_PUBLIC_KEY *pk, const uint8_t in[32],
kyber_poly_inv_ntt(u[i]); kyber_poly_inv_ntt(u[i]);
kyber_poly_add(u[i], u[i], e1[i]); kyber_poly_add(u[i], u[i], e1[i]);
kyber_poly_print(stderr, 0, 0, "u[i] = (A^T * r)[i]", u[i]);
} }
// v = NTT^-1( t^T * r ) + e2 + round(q/2)*m // v = NTT^-1( t^T * r ) + e2 + round(q/2)*m
kyber_poly_set_zero(v); kyber_poly_set_zero(v);
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
kyber_poly_t tmp; kyber_poly_t tmp;
@@ -1253,15 +1286,6 @@ int kyber_cpa_encrypt(const KYBER_CPA_PUBLIC_KEY *pk, const uint8_t in[32],
} }
kyber_poly_inv_ntt(v); kyber_poly_inv_ntt(v);
kyber_poly_add(v, v, e2); kyber_poly_add(v, v, e2);
kyber_poly_print(stderr, 0, 0, "t^T * r + e2", v);
// check
// v = t^T * r + e2 == s^T * (A^T * r) == s^T * (u)
// 验证 v 和 s^T * u 大概是相等的
// 这里的主要问题是 s 的值是不知道的并且s 是ntt(s) 而不是原始s
if (0) { if (0) {
kyber_poly_t s[KYBER_K]; kyber_poly_t s[KYBER_K];
@@ -1279,26 +1303,15 @@ int kyber_cpa_encrypt(const KYBER_CPA_PUBLIC_KEY *pk, const uint8_t in[32],
kyber_poly_add(v_, v_, tmp); kyber_poly_add(v_, v_, tmp);
} }
kyber_poly_print(stderr, 0, 0, "test v", v);
kyber_poly_print(stderr, 0, 0, "test v", v_);
kyber_poly_sub(v_, v_, v); kyber_poly_sub(v_, v_, v);
kyber_poly_to_signed(v_, v_); kyber_poly_to_signed(v_, v_);
kyber_poly_print(stderr, 0, 0, "delta", v_); kyber_poly_print(stderr, 0, 0, "delta", v_);
} }
kyber_poly_decode1(m, in); kyber_poly_decode1(m, in);
kyber_poly_decompress(m, 1, m); kyber_poly_decompress(m, 1, m);
kyber_poly_add(v, v, m); kyber_poly_add(v, v, m);
// c1 = Encode10(Compress(u, 10)) // c1 = Encode10(Compress(u, 10))
for (i = 0; i < KYBER_K; i++) { for (i = 0; i < KYBER_K; i++) {
kyber_poly_compress(u[i], 10, u[i]); kyber_poly_compress(u[i], 10, u[i]);
@@ -1309,6 +1322,11 @@ int kyber_cpa_encrypt(const KYBER_CPA_PUBLIC_KEY *pk, const uint8_t in[32],
kyber_poly_compress(v, 4, v); kyber_poly_compress(v, 4, v);
kyber_poly_encode4(v, out->c2); kyber_poly_encode4(v, out->c2);
gmssl_secure_clear(m, sizeof(m));
gmssl_secure_clear(r, sizeof(r));
gmssl_secure_clear(e1, sizeof(e1));
gmssl_secure_clear(e2, sizeof(e2));
return 1; return 1;
} }
@@ -1354,45 +1372,75 @@ int kyber_cpa_decrypt(const KYBER_CPA_PRIVATE_KEY *sk, const KYBER_CPA_CIPHERTEX
kyber_poly_compress(m, 1, m); kyber_poly_compress(m, 1, m);
kyber_poly_encode1(m, out); kyber_poly_encode1(m, out);
gmssl_secure_clear(s, sizeof(s));
gmssl_secure_clear(m, sizeof(m));
return 1; return 1;
} }
int kyber_keygen(KYBER_PUBLIC_KEY *pk, KYBER_PRIVATE_KEY *sk) int kyber_keygen(KYBER_PUBLIC_KEY *pk, KYBER_PRIVATE_KEY *sk)
{ {
if (kyber_cpa_keygen(pk, &sk->sk) != 1) {
error_print();
return -1;
}
memcpy(&sk->pk, pk, sizeof(KYBER_PUBLIC_KEY));
kyber_h_hash((uint8_t *)pk, sizeof(KYBER_CPA_PUBLIC_KEY), sk->pk_hash);
if (rand_bytes(sk->z, 32) != 1) { if (rand_bytes(sk->z, 32) != 1) {
error_print(); error_print();
return -1; return -1;
} }
if (kyber_cpa_keygen(&sk->pk, &sk->sk) != 1) {
error_print();
return -1;
}
kyber_h_hash((uint8_t *)&sk->pk, sizeof(KYBER_CPA_PUBLIC_KEY), sk->pk_hash);
return 1; return 1;
} }
int kyber_private_key_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_PRIVATE_KEY *sk)
{
format_print(fp, fmt, ind, "%s\n", label);
ind += 4;
kyber_cpa_private_key_print(fp, fmt, ind, "privateKey", &sk->sk);
kyber_cpa_public_key_print(fp, fmt, ind, "publicKey", &sk->pk);
format_bytes(fp, fmt, ind, "publicKeyHash", sk->pk_hash, 32);
format_bytes(fp, fmt, ind, "z", sk->z, 32);
return 1;
}
int kyber_public_key_print(FILE *fp, int fmt, int ind, const char *label, const KYBER_PUBLIC_KEY *pk)
{
return kyber_cpa_public_key_print(fp, fmt, ind, label, pk);
}
int kyber_encap(const KYBER_PUBLIC_KEY *pk, KYBER_CIPHERTEXT *c, uint8_t K[32]) int kyber_encap(const KYBER_PUBLIC_KEY *pk, KYBER_CIPHERTEXT *c, uint8_t K[32])
{ {
uint8_t m[64]; uint8_t m_h[64];
uint8_t K_r[64]; uint8_t K_r[64];
uint8_t *m = m_h;
uint8_t *h = m_h + 32;
uint8_t *K_ = K_r;
uint8_t *r = K_r + 32; uint8_t *r = K_r + 32;
// m = rand(32)
if (rand_bytes(m, 32) != 1) { if (rand_bytes(m, 32) != 1) {
error_print(); error_print();
return -1; return -1;
} }
// m = H(rand(32)) // m = H(m)
kyber_h_hash(m, 32, m); kyber_h_hash(m, 32, m);
// h = H(pk)
kyber_h_hash((const uint8_t *)pk, sizeof(KYBER_PUBLIC_KEY), h);
// (K_, r) = G(m || H(pk)) // (K_, r) = G(m || H(pk))
kyber_h_hash((const uint8_t *)pk, sizeof(KYBER_PUBLIC_KEY), m + 32); kyber_g_hash(m_h, 64, K_r);
kyber_g_hash(m, 64, K_r);
// c = Kyber.CPA.Enc(pk, m, r) // c = Kyber.CPA.Enc(pk, m, r)
if (kyber_cpa_encrypt(pk, m, r, c) != 1) { if (kyber_cpa_encrypt(pk, m, r, c) != 1) {
@@ -1400,10 +1448,14 @@ int kyber_encap(const KYBER_PUBLIC_KEY *pk, KYBER_CIPHERTEXT *c, uint8_t K[32])
return -1; return -1;
} }
// H(c)
kyber_h_hash((uint8_t *)c, sizeof(KYBER_CIPHERTEXT), r);
// K = KDF(K_ || H(c)) // K = KDF(K_ || H(c))
kyber_h_hash((uint8_t *)c, sizeof(KYBER_CIPHERTEXT), K_r + 32);
kyber_kdf(K_r, K); kyber_kdf(K_r, K);
gmssl_secure_clear(m_h, sizeof(m_h));
gmssl_secure_clear(K_r, sizeof(K_r));
return 1; return 1;
} }
@@ -1411,77 +1463,130 @@ int kyber_decap(const KYBER_PRIVATE_KEY *sk, const KYBER_CIPHERTEXT *c, uint8_t
{ {
uint8_t m_h[64]; uint8_t m_h[64];
uint8_t K_r[64]; uint8_t K_r[64];
uint8_t *m = m_h;
uint8_t *h = m_h + 32;
uint8_t *K_ = K_r;
uint8_t *r = K_r + 32; uint8_t *r = K_r + 32;
KYBER_CIPHERTEXT c_; KYBER_CIPHERTEXT c_;
uint8_t c_hash[32]; uint8_t c_hash[32];
if (kyber_cpa_decrypt(&sk->sk, c, m_h) != 1) {
// m' = Dec(sk, c)
if (kyber_cpa_decrypt(&sk->sk, c, m) != 1) {
error_print(); error_print();
return -1; return -1;
} }
// (K, r) = G(m || H(pk)) // h = H(pk)
memcpy(m_h + 32, sk->pk_hash, 32); memcpy(h, sk->pk_hash, 32);
// (K_, r) = G(m || h)
kyber_g_hash(m_h, 64, K_r); kyber_g_hash(m_h, 64, K_r);
if (kyber_cpa_encrypt(&sk->pk, m_h, r, &c_) != 1) { // c_ = CPA.Enc(pk, m, r)
if (kyber_cpa_encrypt(&sk->pk, m, r, &c_) != 1) {
gmssl_secure_clear(m_h, sizeof(m_h));
gmssl_secure_clear(K_r, sizeof(K_r));
error_print(); error_print();
return -1; return -1;
} }
// H(c)
kyber_h_hash((uint8_t *)c, sizeof(KYBER_CIPHERTEXT), r); kyber_h_hash((uint8_t *)c, sizeof(KYBER_CIPHERTEXT), r);
if (memcmp(c, &c_, sizeof(KYBER_CIPHERTEXT)) == 0) { if (memcmp(c, &c_, sizeof(KYBER_CIPHERTEXT)) == 0) {
// K = KDF(K_||H(c))
kyber_kdf(K_r, K); kyber_kdf(K_r, K);
} else { } else {
memcpy(K_r, sk->z, 32); error_print();
memcpy(K_r, sk->z, 32); // TODO: const time
kyber_kdf(K_r, K); kyber_kdf(K_r, K);
} }
gmssl_secure_clear(m_h, sizeof(m_h));
gmssl_secure_clear(K_r, sizeof(K_r));
return 1; return 1;
} }
static int test_kyber_cpa_keygen(void) static int test_kyber_cpa(void)
{ {
KYBER_CPA_PUBLIC_KEY pk; KYBER_CPA_PUBLIC_KEY pk;
KYBER_CPA_PRIVATE_KEY sk; KYBER_CPA_PRIVATE_KEY sk;
KYBER_CPA_CIPHERTEXT c; KYBER_CPA_CIPHERTEXT c;
uint8_t m[32];
uint8_t r[32];
uint8_t m_[32];
uint8_t r[32] = {0};
uint8_t m[32] = {1,0,1,0};
uint8_t K[32] = {0};
if (rand_bytes(r, 32) != 1) {
error_print();
return -1;
}
if (rand_bytes(m, 32) != 1) { if (rand_bytes(m, 32) != 1) {
error_print(); error_print();
return -1; return -1;
} }
if (rand_bytes(r, 32) != 1) {
kyber_cpa_keygen(&pk, &sk);
kyber_cpa_encrypt(&pk, m, r, &c);
kyber_cpa_decrypt(&sk, &c, K);
format_bytes(stderr, 0, 0, "m", m, 32);
format_bytes(stderr, 0, 0, "out", K, 32);
if (memcmp(K, m, 32) != 0) {
error_print(); error_print();
return -1; return -1;
} }
if (kyber_cpa_keygen(&pk, &sk) != 1) {
error_print();
return -1;
}
kyber_cpa_public_key_print(stderr, 0, 0, "publicKey", &pk);
kyber_cpa_private_key_print(stderr, 0, 0, "privateKey", &sk);
if (kyber_cpa_encrypt(&pk, m, r, &c) != 1) {
error_print();
return -1;
}
kyber_cpa_ciphertext_print(stderr, 0, 0, "ciphertext", &c);
if (kyber_cpa_decrypt(&sk, &c, m_) != 1) {
error_print();
return -1;
}
if (memcmp(m_, m, 32) != 0) {
error_print();
return -1;
}
printf("%s() ok\n", __FUNCTION__);
return 1;
}
static int test_kyber_kem(void)
{
KYBER_PRIVATE_KEY sk;
KYBER_PUBLIC_KEY pk;
KYBER_CIPHERTEXT c;
uint8_t K[32];
uint8_t K_[32];
if (kyber_keygen(&pk, &sk) != 1) {
error_print();
return -1;
}
kyber_public_key_print(stderr, 0, 0, "pk", &pk);
kyber_private_key_print(stderr, 0, 0, "sk", &sk);
if (kyber_encap(&pk, &c, K) != 1) {
error_print();
return -1;
}
kyber_ciphertext_print(stderr, 0, 0, "ciphertext", &c);
format_bytes(stderr, 0, 0, "KEM_K", K, 32);
if (kyber_decap(&sk, &c, K_) != 1) {
error_print();
return -1;
}
format_bytes(stderr, 0, 0, "DEC_K", K_, 32);
if (memcmp(K_, K, 32) != 0) {
error_print();
return -1;
}
printf("%s() ok\n", __FUNCTION__); printf("%s() ok\n", __FUNCTION__);
return 1; return 1;
@@ -1490,28 +1595,23 @@ static int test_kyber_cpa_keygen(void)
int main(void) int main(void)
{ {
init_zeta(); init_zeta();
if (test_kyber_cpa_keygen() != 1) goto err;
return 1;
if (test_kyber_poly_ops() != 1) goto err;
if (test_kyber_poly_ntt_mul() != 1) goto err;
if (test_kyber_poly_ntt() != 1) goto err;
if (test_kyber_poly_uniform_sample() != 1) goto err; if (test_kyber_poly_uniform_sample() != 1) goto err;
if (test_kyber_poly_cbd_sample() != 1) goto err; if (test_kyber_poly_cbd_sample() != 1) goto err;
if (test_kyber_poly_to_signed() != 1) goto err; if (test_kyber_poly_to_signed() != 1) goto err;
if (test_kyber_poly_compress() != 1) goto err; if (test_kyber_poly_compress() != 1) goto err;
if (test_kyber_poly_encode12() != 1) goto err; if (test_kyber_poly_encode12() != 1) goto err;
if (test_kyber_poly_encode10() != 1) goto err; if (test_kyber_poly_encode10() != 1) goto err;
if (test_kyber_poly_encode4() != 1) goto err; if (test_kyber_poly_encode4() != 1) goto err;
if (test_kyber_poly_encode1() != 1) goto err; if (test_kyber_poly_encode1() != 1) goto err;
if (test_kyber_poly_add() != 1) goto err;
if (test_kyber_poly_ntt() != 1) goto err;
if (test_kyber_poly_ntt_mul() != 1) goto err;
if (test_kyber_cpa() != 1) goto err;
if (test_kyber_kem() != 1) goto err;
printf("%s all tests passed\n", __FILE__); printf("%s all tests passed\n", __FILE__);
return 0; return 0;