def run_test_multiply_matrix(self, message, mat):
    """Checks homomorphic matrix-vector multiplication.

    Encrypts ``message``, multiplies the ciphertext by the plaintext matrix
    ``mat`` with ``evaluator.multiply_matrix``, and compares the decoded
    result with the plaintext product ``matrix_vector_multiply(mat, message)``.

    Args:
        message (list (complex)): Vector to encrypt.
        mat (list (list (complex))): Matrix to multiply by. Its length is used
            as the number of rotation keys to generate. Note that this
            parameter shadows any module-level ``mat`` alias inside this method.

    Raises:
        ValueError: An error if test fails.
    """
    matrix_prod_message = matrix_vector_multiply(mat, message)
    plain = self.encoder.encode(message, self.scaling_factor)
    ciph = self.encryptor.encrypt(plain)

    # Generate a key for every shift in [0, len(mat)). Every baby-step
    # (1 .. n1-1) and giant-step (n1 * j, j < len(mat) // n1) shift is below
    # len(mat), so this one loop already covers them. Generating them
    # separately first only duplicated key generation, and the n1/n2
    # factorisation divided by zero for an empty matrix.
    rot_keys = {}
    for shift in range(len(mat)):
        rot_keys[shift] = self.key_generator.generate_rot_key(shift)

    ciph_prod = self.evaluator.multiply_matrix(ciph, mat, rot_keys, self.encoder)
    decrypted_prod = self.decryptor.decrypt(ciph_prod)
    decoded_prod = self.encoder.decode(decrypted_prod)
    check_complex_vector_approx_eq(matrix_prod_message, decoded_prod, error=0.01)
def run_test_multiply(self, vec1, vec2):
    """Checks that encoding is homomorphic with respect to multiplication.

    The slot-wise product of the inputs is compared with the decoding of the
    product of their encodings. Before encoding the product is taken
    component-wise; after encoding it is a polynomial product, because the
    encoding includes an inverse FFT.

    Args:
        vec1 (list (complex)): First vector.
        vec2 (list (complex)): Second vector.

    Raises:
        ValueError: An error if test fails.
    """
    half_degree = self.degree // 2
    orig_prod = [vec1[idx] * vec2[idx] for idx in range(half_degree)]

    encoded1 = self.encoder.encode(vec1, self.scaling_factor)
    encoded2 = self.encoder.encode(vec2, self.scaling_factor)

    # Multiplying two plaintexts at scale S yields a plaintext at scale S**2.
    product_poly = encoded1.poly.multiply_naive(encoded2.poly)
    plain_prod = Plaintext(product_poly, scaling_factor=self.scaling_factor**2)

    decoded_prod = self.encoder.decode(plain_prod)
    check_complex_vector_approx_eq(decoded_prod, orig_prod, error=0.1)
def run_test_secret_key_encrypt_decrypt(self, message):
    """Checks that secret-key encryption followed by decryption round-trips.

    Args:
        message (list (complex)): Vector to encode and encrypt.

    Raises:
        ValueError: An error if test fails.
    """
    encoded = self.encoder.encode(message, self.scaling_factor)
    ciphertext = self.encryptor.encrypt_with_secret_key(encoded)
    round_trip = self.encoder.decode(self.decryptor.decrypt(ciphertext))
    check_complex_vector_approx_eq(message, round_trip, 0.001)
def run_test_exp(self, message):
    """Checks homomorphic exponentiation against a plaintext Taylor series.

    The expected values are ``taylor_exp(message, 2*pi, num_iterations=5)``.
    They are compared with the decoded output of ``evaluator.exp`` applied
    with the same constant.

    Args:
        message (list (complex)): Vector to encrypt.

    Raises:
        ValueError: An error if test fails.
    """
    encoded = self.encoder.encode(message, self.scaling_factor)
    exponent_const = 2 * math.pi
    expected = taylor_exp(message, exponent_const, num_iterations=5)

    ciph = self.encryptor.encrypt(encoded)
    ciph_exp = self.evaluator.exp(ciph, exponent_const, self.relin_key, self.encoder)
    decoded_exp = self.encoder.decode(self.decryptor.decrypt(ciph_exp))
    check_complex_vector_approx_eq(expected, decoded_exp, error=0.1)
def run_test_secret_key_add(self, message1, message2):
    """Checks homomorphic addition of two secret-key ciphertexts.

    The expected sum is ``Polynomial.add`` of the raw messages, each wrapped
    as a Polynomial of ring degree ``self.degree // 2``.

    Args:
        message1 (list (complex)): First summand.
        message2 (list (complex)): Second summand.

    Raises:
        ValueError: An error if test fails.
    """
    half_degree = self.degree // 2
    lhs_poly = Polynomial(half_degree, message1)
    rhs_poly = Polynomial(half_degree, message2)

    encoded1 = self.encoder.encode(message1, self.scaling_factor)
    encoded2 = self.encoder.encode(message2, self.scaling_factor)
    expected = lhs_poly.add(rhs_poly)

    ciph1 = self.encryptor.encrypt_with_secret_key(encoded1)
    ciph2 = self.encryptor.encrypt_with_secret_key(encoded2)
    sum_ciph = self.evaluator.add(ciph1, ciph2)

    decoded_sum = self.encoder.decode(self.decryptor.decrypt(sum_ciph))
    check_complex_vector_approx_eq(expected.coeffs, decoded_sum, error=0.001)
def run_test_subtract(self, message1, message2):
    """Checks homomorphic subtraction of two public-key ciphertexts.

    The expected difference is ``Polynomial.subtract`` of the raw messages,
    each wrapped as a Polynomial of ring degree ``self.degree // 2``.

    Args:
        message1 (list (complex)): Minuend.
        message2 (list (complex)): Subtrahend.

    Raises:
        ValueError: An error if test fails.
    """
    half_degree = self.degree // 2
    lhs_poly = Polynomial(half_degree, message1)
    rhs_poly = Polynomial(half_degree, message2)

    encoded1 = self.encoder.encode(message1, self.scaling_factor)
    encoded2 = self.encoder.encode(message2, self.scaling_factor)
    expected = lhs_poly.subtract(rhs_poly)

    ciph1 = self.encryptor.encrypt(encoded1)
    ciph2 = self.encryptor.encrypt(encoded2)
    diff_ciph = self.evaluator.subtract(ciph1, ciph2)

    decoded_diff = self.encoder.decode(self.decryptor.decrypt(diff_ciph))
    check_complex_vector_approx_eq(expected.coeffs, decoded_diff, error=0.005)
def run_test_multiply(self, message1, message2):
    """Checks homomorphic ciphertext-ciphertext multiplication.

    The expected result is the slot-wise product over ``len(message1)``
    slots. It is compared with the decoded output of ``evaluator.multiply``,
    which is given the relinearisation key.

    Args:
        message1 (list (complex)): First factor.
        message2 (list (complex)): Second factor.

    Raises:
        ValueError: An error if test fails.
    """
    encoded1 = self.encoder.encode(message1, self.scaling_factor)
    encoded2 = self.encoder.encode(message2, self.scaling_factor)
    expected = [message1[slot] * message2[slot] for slot in range(len(message1))]

    ciph1 = self.encryptor.encrypt(encoded1)
    ciph2 = self.encryptor.encrypt(encoded2)
    prod_ciph = self.evaluator.multiply(ciph1, ciph2, self.relin_key)

    decoded_prod = self.encoder.decode(self.decryptor.decrypt(prod_ciph))
    check_complex_vector_approx_eq(expected, decoded_prod, error=0.01)
def run_test_add_plain(self, message1, message2):
    """Checks adding a plaintext to a ciphertext.

    Args:
        message1 (list (complex)): Vector that is encrypted.
        message2 (list (complex)): Vector that stays a plaintext.

    Raises:
        ValueError: An error if test fails.
    """
    encoded1 = self.encoder.encode(message1, self.scaling_factor)
    encoded2 = self.encoder.encode(message2, self.scaling_factor)
    expected = [message1[idx] + message2[idx] for idx in range(self.degree // 2)]

    sum_ciph = self.evaluator.add_plain(self.encryptor.encrypt(encoded1), encoded2)
    decoded_sum = self.encoder.decode(self.decryptor.decrypt(sum_ciph))
    check_complex_vector_approx_eq(expected, decoded_sum, error=0.001)
def run_test_multiply_plain(self, message1, message2):
    """Checks multiplying a ciphertext by a plaintext.

    Args:
        message1 (list (complex)): Vector that is encrypted.
        message2 (list (complex)): Vector that stays a plaintext.

    Raises:
        ValueError: An error if test fails.
    """
    encoded1 = self.encoder.encode(message1, self.scaling_factor)
    encoded2 = self.encoder.encode(message2, self.scaling_factor)
    expected = [message1[idx] * message2[idx] for idx in range(self.degree // 2)]

    prod_ciph = self.evaluator.multiply_plain(self.encryptor.encrypt(encoded1), encoded2)
    decoded_prod = self.encoder.decode(self.decryptor.decrypt(prod_ciph))
    check_complex_vector_approx_eq(expected, decoded_prod, error=0.001)
def run_test_conjugate(self, message):
    """Checks homomorphic complex conjugation using a conjugation key.

    Args:
        message (list (complex)): Vector to encrypt.

    Raises:
        ValueError: An error if test fails.
    """
    poly = Polynomial(self.degree // 2, message)
    encoded = self.encoder.encode(message, self.scaling_factor)
    expected = [coeff.conjugate() for coeff in poly.coeffs]

    ciph = self.encryptor.encrypt(encoded)
    conj_key = self.key_generator.generate_conj_key()
    conj_ciph = self.evaluator.conjugate(ciph, conj_key)

    decoded_conj = self.encoder.decode(self.decryptor.decrypt(conj_ciph))
    check_complex_vector_approx_eq(expected, decoded_conj, error=0.005)
def run_test_encode_decode(self, vec):
    """Checks that decode inverts encode.

    The input is encoded at ``self.scaling_factor``, decoded again, and
    compared with the original within an error of 0.1.

    Args:
        vec (list (complex)): Vector of complex numbers to encode.

    Raises:
        ValueError: An error if test fails.
    """
    round_trip = self.encoder.decode(self.encoder.encode(vec, self.scaling_factor))
    check_complex_vector_approx_eq(vec, round_trip, error=0.1)
def test_fft_inverses(self):
    """Checks that fft_inv undoes fft_fwd.

    A random vector of ``self.num_slots`` entries drawn with
    ``sample_uniform(0, 7, ...)`` is passed through the forward and then the
    inverse FFT and must come back unchanged up to 1e-6.

    Raises:
        ValueError: An error if test fails.
    """
    vec = sample_uniform(0, 7, self.num_slots)
    round_trip = self.fft.fft_inv(self.fft.fft_fwd(vec))
    check_complex_vector_approx_eq(vec, round_trip, 0.000001, "fft_inv is not the inverse of fft_fwd")
def run_test_bootstrap(self, message):
    """Checks that bootstrapping preserves the encrypted message.

    Rotation keys for shifts ``0 .. len(message) - 1`` and a conjugation key
    are generated. The second ciphertext returned by ``evaluator.bootstrap``
    must decrypt and decode to ``message`` within an error of 0.05.

    Args:
        message (list (complex)): Vector to encrypt.

    Raises:
        ValueError: An error if test fails.
    """
    ciph = self.encryptor.encrypt(self.encoder.encode(message, self.scaling_factor))
    rot_keys = {shift: self.key_generator.generate_rot_key(shift) for shift in range(len(message))}
    conj_key = self.key_generator.generate_conj_key()

    _, bootstrapped = self.evaluator.bootstrap(ciph, rot_keys, conj_key, self.relin_key, self.encoder)
    decoded = self.encoder.decode(self.decryptor.decrypt(bootstrapped))
    check_complex_vector_approx_eq(message, decoded, error=0.05)
def run_test_rotate(self, message, r):
    """Checks homomorphic rotation by ``r`` using a generated rotation key.

    The expected output is the message's coefficient list cyclically shifted
    so that position ``i`` holds the entry at ``(i + r) mod ring_degree``.

    Args:
        message (list (complex)): Vector to encrypt.
        r (int): Rotation amount.

    Raises:
        ValueError: An error if test fails.
    """
    poly = Polynomial(self.degree // 2, message)
    encoded = self.encoder.encode(message, self.scaling_factor)
    size = poly.ring_degree
    expected = [poly.coeffs[(pos + r) % size] for pos in range(size)]

    ciph = self.encryptor.encrypt(encoded)
    rot_key = self.key_generator.generate_rot_key(r)
    rotated = self.evaluator.rotate(ciph, r, rot_key)

    decoded_rot = self.encoder.decode(self.decryptor.decrypt(rotated))
    check_complex_vector_approx_eq(expected, decoded_rot, error=0.005)
def test_embedding_inverses(self):
    """Checks that embedding_inv undoes the canonical embedding.

    A fresh FFTContext with ``fft_length = 4 * 32`` embeds a random vector
    of 32 entries and maps it back. The result must match the input to 1e-6.

    Raises:
        ValueError: An error if test fails.
    """
    slot_count = 1 << 5
    context = FFTContext(fft_length=4 * slot_count)
    vec = sample_uniform(0, 7, slot_count)
    round_trip = context.embedding_inv(context.embedding(vec))
    check_complex_vector_approx_eq(vec, round_trip, 0.000001, "embedding_inv is not the inverse of embedding")
def run_test_simple_rotate(self, message, rot):
    """Checks rotation applied directly to the ciphertext polynomials.

    Both ciphertext components are rotated by ``rot`` and reduced with
    ``mod_small(self.ciph_modulus)``. The result is decrypted with a
    secret key whose ``s`` polynomial is rotated by the same amount, so no
    rotation key is involved.

    Args:
        message (list (complex)): Vector to encrypt.
        rot (int): Rotation amount.

    Raises:
        ValueError: An error if test fails.
    """
    poly = Polynomial(self.degree // 2, message)
    encoded = self.encoder.encode(message, self.scaling_factor)
    size = poly.ring_degree
    expected = [poly.coeffs[(pos + rot) % size] for pos in range(size)]

    ciph = self.encryptor.encrypt(encoded)
    rotated = Ciphertext(
        ciph.c0.rotate(rot).mod_small(self.ciph_modulus),
        ciph.c1.rotate(rot).mod_small(self.ciph_modulus),
        ciph.scaling_factor,
        self.ciph_modulus,
    )

    rotated_key = SecretKey(self.secret_key.s.rotate(rot))
    rotated_decryptor = CKKSDecryptor(self.params, rotated_key)
    decoded_rot = self.encoder.decode(rotated_decryptor.decrypt(rotated))
    check_complex_vector_approx_eq(expected, decoded_rot, error=0.005)
def test_embedding(self):
    """Checks that the canonical embedding is correct.

    The embedding of a fixed polynomial must equal that polynomial evaluated
    at ``zeta ** (5 ** j mod fft_length)`` for ``j = 0 .. num_slots - 1``,
    where ``zeta = exp(2*pi*i / fft_length)`` and
    ``fft_length = 4 * num_slots``. Each such exponent is 1 (mod 4).

    The coefficient list has 8 entries. This presumably assumes
    ``self.num_slots`` matches that size (TODO confirm for other sizes).

    Raises:
        ValueError: An error if test fails.
    """
    coeffs = [10, 34, 71, 31, 1, 2, 3, 4]
    poly = Polynomial(self.num_slots, coeffs)
    fft_length = self.num_slots * 4
    embedding = self.fft.embedding(coeffs)

    evals = []
    power = 1
    # Exponents are successive powers of 5 (mod fft_length), not 1, 5, 9, ...,
    # so the loop only counts iterations (fft_length // 4 == num_slots).
    for _ in range(fft_length // 4):
        angle = 2 * pi * power / fft_length
        root_of_unity = complex(cos(angle), sin(angle))
        evals.append(poly.evaluate(root_of_unity))
        power = (power * 5) % fft_length

    check_complex_vector_approx_eq(embedding, evals, 0.00001)
def run_test_coeff_to_slot(self, message):
    """Checks that coeff_to_slot moves plaintext coefficients into slots.

    The encoded plaintext's coefficient vector is split into a first and a
    second half of ``len(message)`` entries each, both divided by the scaling
    factor. ``coeff_to_slot`` must produce two ciphertexts whose decoded
    slots match those halves.

    Args:
        message (list (complex)): Vector to encrypt.

    Raises:
        ValueError: An error if test fails.
    """
    num_slots = len(message)
    plain = self.encoder.encode(message, self.scaling_factor)
    inv_scale = 1 / self.scaling_factor
    expected_low = mat.scalar_multiply(plain.poly.coeffs[:num_slots], inv_scale)
    expected_high = mat.scalar_multiply(plain.poly.coeffs[num_slots:], inv_scale)

    ciph = self.encryptor.encrypt(plain)
    rot_keys = {shift: self.key_generator.generate_rot_key(shift) for shift in range(num_slots)}
    conj_key = self.key_generator.generate_conj_key()

    ciph_low, ciph_high = self.evaluator.coeff_to_slot(ciph, rot_keys, conj_key, self.encoder)
    decoded_low = self.encoder.decode(self.decryptor.decrypt(ciph_low))
    decoded_high = self.encoder.decode(self.decryptor.decrypt(ciph_high))

    check_complex_vector_approx_eq(expected_low, decoded_low, error=0.01)
    check_complex_vector_approx_eq(expected_high, decoded_high, error=0.01)
def run_test_slot_to_coeff(self, message):
    """Checks that slot_to_coeff recombines the output of coeff_to_slot.

    ``coeff_to_slot`` turns the encrypted message into two ciphertexts whose
    slots hold the two halves of the plaintext coefficient vector.
    ``slot_to_coeff`` must put those values back into the coefficients. The
    result is cross-checked against a plaintext reconstruction: the two
    decoded halves are multiplied by the matrices of powers of the decoding
    roots, and the sum is re-encoded.

    Args:
        message (list (complex)): Vector to encrypt. Its length is used as the
            number of slots.

    Raises:
        ValueError: An error if test fails.
    """
    num_slots = len(message)
    plain = self.encoder.encode(message, self.scaling_factor)
    ciph = self.encryptor.encrypt(plain)
    rot_keys = {}
    for i in range(num_slots):
        rot_keys[i] = self.key_generator.generate_rot_key(i)
    conj_key = self.key_generator.generate_conj_key()

    # Coefficients -> slots: decode both halves.
    ciph1, ciph2 = self.evaluator.coeff_to_slot(ciph, rot_keys, conj_key, self.encoder)
    decoded_1 = self.encoder.decode(self.decryptor.decrypt(ciph1))
    decoded_2 = self.encoder.decode(self.decryptor.decrypt(ciph2))

    # Slots -> coefficients: split the result's coefficients back into
    # halves, normalised by the result's own scaling factor.
    ciph_ans = self.evaluator.slot_to_coeff(ciph1, ciph2, rot_keys, self.encoder)
    decrypted = self.decryptor.decrypt(ciph_ans)
    plain_ans1 = mat.scalar_multiply(decrypted.poly.coeffs[:num_slots], 1 / decrypted.scaling_factor)
    plain_ans2 = mat.scalar_multiply(decrypted.poly.coeffs[num_slots:], 1 / decrypted.scaling_factor)

    # Decoding roots: root_i = prim_root ** (5 ** i), where
    # prim_root = exp(pi * 1j / (2 * num_slots)).
    prim_root = math.e**(math.pi * 1j / 2 / num_slots)
    primitive_roots = [prim_root] * num_slots
    for i in range(1, num_slots):
        primitive_roots[i] = primitive_roots[i - 1]**5

    # Row i of [mat_0 | mat_1] holds root_i ** k for k = 0 .. 2*num_slots - 1.
    # mat_0 multiplies the first coefficient half and mat_1 the second.
    mat_0 = []
    mat_1 = []
    for root in primitive_roots:
        row = [1] * (2 * num_slots)
        for k in range(1, 2 * num_slots):
            row[k] = row[k - 1] * root
        mat_0.append(row[:num_slots])
        mat_1.append(row[num_slots:])

    plain_1 = mat.matrix_vector_multiply(mat_0, decoded_1)
    plain_2 = mat.matrix_vector_multiply(mat_1, decoded_2)
    new_plain = [plain_1[i] + plain_2[i] for i in range(num_slots)]

    # Re-encode the reconstruction. Normalise by the scale it was encoded at
    # (self.scaling_factor), not by the slot_to_coeff result's scale, which
    # may differ after rescaling.
    encoded = self.encoder.encode(new_plain, self.scaling_factor)
    plain_check1 = mat.scalar_multiply(encoded.poly.coeffs[:num_slots], 1 / self.scaling_factor)
    plain_check2 = mat.scalar_multiply(encoded.poly.coeffs[num_slots:], 1 / self.scaling_factor)

    decoded = self.encoder.decode(decrypted)
    check_complex_vector_approx_eq(decoded, new_plain, error=0.001)
    check_complex_vector_approx_eq(plain_check1, plain_ans1, error=0.001)
    check_complex_vector_approx_eq(plain_check2, plain_ans2, error=0.001)
    check_complex_vector_approx_eq(decoded_1, plain_ans1, error=0.001)
    check_complex_vector_approx_eq(decoded_2, plain_ans2, error=0.001)
    check_complex_vector_approx_eq(message, new_plain, error=0.001)
def test_fft(self):
    """Checks fft_fwd against fixed expected values for the input [0, 1, 4, 5].

    Raises:
        ValueError: An error if test fails.
    """
    expected = [10, -4 - 4j, -2, -4 + 4j]
    actual = self.fft.fft_fwd(coeffs=[0, 1, 4, 5])
    check_complex_vector_approx_eq(actual, expected)
def run_test_bootstrap_steps(self, message):
    """Runs bootstrapping stage by stage, printing and checking each stage.

    Stages, in order:

    1. Raise the modulus.
    2. Coeff-to-slot.
    3. Exponentiate: exp(c*x), plus its conjugate.
    4. Sine from the difference of the two exponentials, then scaling.
    5. Slot-to-coeff.

    Intermediate values are printed for debugging. Each stage is compared
    against a plaintext computation.

    ``self.evaluator.raise_modulus(ciph)`` is called for its effect on
    ``ciph`` (its return value is unused). ``self.scaling_factor`` is
    reassigned to its value at entry before the final decryption.

    Args:
        message (list (complex)): Vector to encrypt and bootstrap.

    Raises:
        ValueError: An error if any stage's check fails.
    """
    # ------------------- SETUP -------------------- #
    num_slots = self.degree // 2
    plain = self.encoder.encode(message, self.scaling_factor)
    ciph = self.encryptor.encrypt(plain)
    rot_keys = {}
    for i in range(num_slots):
        rot_keys[i] = self.key_generator.generate_rot_key(i)
    conj_key = self.key_generator.generate_conj_key()

    # Raise modulus.
    old_modulus = ciph.modulus
    old_scaling_factor = self.scaling_factor
    self.evaluator.raise_modulus(ciph)

    print(message)
    print("-----------------------")
    print(plain.poly.coeffs)
    plain = self.decryptor.decrypt(ciph)
    test_plain = Plaintext(plain.poly.mod_small(self.ciph_modulus), self.scaling_factor)
    print("-------- TEST --------")
    print(test_plain.poly.coeffs)
    print(self.encoder.decode(test_plain))
    print("---------- BIT SIZE ------------")
    print(math.log(self.scaling_factor, 2))
    print(math.log(self.ciph_modulus, 2))
    print(math.log(abs(plain.poly.coeffs[0]), 2))
    print("---------- MOD ------------")
    print(plain.poly.coeffs[0])
    print(plain.poly.coeffs[0] > self.ciph_modulus / 2)
    print(math.sin(2 * math.pi * plain.poly.coeffs[0] / self.ciph_modulus))
    print(2 * math.pi * (plain.poly.coeffs[0] % self.ciph_modulus) / self.ciph_modulus)
    print(
        math.sin(2 * math.pi * plain.poly.coeffs[0] / self.ciph_modulus) * self.ciph_modulus / 2
        / math.pi)
    print(plain.poly.coeffs[0] % self.ciph_modulus)

    # Coeff to slot.
    ciph0, ciph1 = self.evaluator.coeff_to_slot(ciph, rot_keys, conj_key, self.encoder)
    plain_slots0 = [
        plain.poly.coeffs[i] / self.evaluator.scaling_factor for i in range(num_slots)
    ]
    plain_slots1 = [
        plain.poly.coeffs[i] / self.evaluator.scaling_factor
        for i in range(num_slots, 2 * num_slots)
    ]
    print("----- COEFF TO SLOT -------")
    print(plain_slots0)
    print(plain_slots1)
    decoded0 = self.encoder.decode(self.decryptor.decrypt(ciph0))
    decoded1 = self.encoder.decode(self.decryptor.decrypt(ciph1))
    check_complex_vector_approx_eq(decoded0, plain_slots0, error_message="COEFF TO SLOT FAILED")
    check_complex_vector_approx_eq(decoded1, plain_slots1, error_message="COEFF TO SLOT FAILED")

    # Exponentiate with exp_const = 2*pi*i * evaluator.scaling_factor / old_modulus.
    # The "negative" exponential is obtained by conjugation. That matches
    # exp(-exp_const * x) only when the slot values x are real, which the
    # checks below presumably assume.
    exp_const = self.evaluator.scaling_factor / old_modulus * 2 * math.pi * 1j
    ciph_exp0 = self.evaluator.exp(ciph0, exp_const, self.relin_key, self.encoder)
    ciph_neg_exp0 = self.evaluator.conjugate(ciph_exp0, conj_key)
    ciph_exp1 = self.evaluator.exp(ciph1, exp_const, self.relin_key, self.encoder)
    ciph_neg_exp1 = self.evaluator.conjugate(ciph_exp1, conj_key)
    pre_exp0 = [plain_slots0[i] * exp_const for i in range(num_slots)]
    pre_exp1 = [plain_slots1[i] * exp_const for i in range(num_slots)]
    exp0 = [cmath.exp(pre_exp0[i]) for i in range(num_slots)]
    exp1 = [cmath.exp(pre_exp1[i]) for i in range(num_slots)]
    taylor_exp0 = taylor_exp(plain_slots0, exp_const, num_iterations=self.num_taylor_exp_iterations)
    taylor_exp1 = taylor_exp(plain_slots1, exp_const, num_iterations=self.num_taylor_exp_iterations)
    neg_exp0 = [cmath.exp(-pre_exp0[i]) for i in range(num_slots)]
    neg_exp1 = [cmath.exp(-pre_exp1[i]) for i in range(num_slots)]
    print("----- EXP -------")
    print("----- argument -----")
    print(pre_exp0)
    print(pre_exp1)
    print("---- actual exp ------")
    print(exp0)
    print(exp1)
    print("---- taylor series exp ----")
    print(taylor_exp0)
    print(taylor_exp1)
    print("---- actual negative exp -----")
    print(neg_exp0)
    print(neg_exp1)
    decoded_exp0 = self.encoder.decode(self.decryptor.decrypt(ciph_exp0))
    decoded_neg_exp0 = self.encoder.decode(self.decryptor.decrypt(ciph_neg_exp0))
    decoded_exp1 = self.encoder.decode(self.decryptor.decrypt(ciph_exp1))
    decoded_neg_exp1 = self.encoder.decode(self.decryptor.decrypt(ciph_neg_exp1))
    check_complex_vector_approx_eq(decoded_exp0, exp0, error=0.001, error_message="EXP FAILED")
    check_complex_vector_approx_eq(decoded_exp1, exp1, error=0.001, error_message="EXP FAILED")
    check_complex_vector_approx_eq(decoded_neg_exp0, neg_exp0, error=0.001, error_message="EXP FAILED")
    check_complex_vector_approx_eq(decoded_neg_exp1, neg_exp1, error=0.001, error_message="EXP FAILED")

    # Compute sine. The ciphertext difference is exp - neg_exp = 2i * sin.
    # The division by 2i is folded into the scaling constant below.
    ciph_sin0 = self.evaluator.subtract(ciph_exp0, ciph_neg_exp0)
    ciph_sin1 = self.evaluator.subtract(ciph_exp1, ciph_neg_exp1)
    sin0 = [(exp0[i] - neg_exp0[i]) / 2 / 1j for i in range(num_slots)]
    sin1 = [(exp1[i] - neg_exp1[i]) / 2 / 1j for i in range(num_slots)]

    # Scale sine by old_modulus / (2*pi * scaling_factor) * 1 / (2i).
    sin_scale_plain = self.evaluator.create_complex_constant_plain(
        old_modulus / self.evaluator.scaling_factor * 0.25 / math.pi / 1j, self.encoder)
    ciph0 = self.evaluator.multiply_plain(ciph_sin0, sin_scale_plain)
    ciph1 = self.evaluator.multiply_plain(ciph_sin1, sin_scale_plain)
    ciph0 = self.evaluator.rescale(ciph0, self.evaluator.scaling_factor)
    ciph1 = self.evaluator.rescale(ciph1, self.evaluator.scaling_factor)
    print("----- SIN -------")
    print(sin0)
    print(sin1)
    sin_check0 = [cmath.sin(pre_exp0[i]) for i in range(num_slots)]
    sin_check1 = [cmath.sin(pre_exp1[i]) for i in range(num_slots)]
    print(sin_check0)
    print(sin_check1)
    scaled_sin0 = [
        sin0[i] * self.ciph_modulus / self.evaluator.scaling_factor / 2 / math.pi
        for i in range(num_slots)
    ]
    scaled_sin1 = [
        sin1[i] * self.ciph_modulus / self.evaluator.scaling_factor / 2 / math.pi
        for i in range(num_slots)
    ]
    print("----- SCALED SIN -------")
    print(scaled_sin0)
    print(scaled_sin1)
    expected_slots0 = [(plain.poly.coeffs[i] % self.ciph_modulus) / self.evaluator.scaling_factor
                       for i in range(num_slots)]
    expected_slots1 = [(plain.poly.coeffs[i] % self.ciph_modulus) / self.evaluator.scaling_factor
                       for i in range(num_slots, 2 * num_slots)]
    print(expected_slots0)
    print(expected_slots1)
    decoded0 = self.encoder.decode(self.decryptor.decrypt(ciph0))
    decoded1 = self.encoder.decode(self.decryptor.decrypt(ciph1))
    check_complex_vector_approx_eq(decoded0, scaled_sin0, error=0.1, error_message="SIN FAILED")
    check_complex_vector_approx_eq(decoded1, scaled_sin1, error=0.1, error_message="SIN FAILED")

    # Slot to coeff.
    ciph = self.evaluator.slot_to_coeff(ciph0, ciph1, rot_keys, self.encoder)

    # Reset scaling factor.
    self.scaling_factor = old_scaling_factor
    ciph.scaling_factor = self.scaling_factor
    new_plain = self.decryptor.decrypt(ciph)
    new_plain = self.encoder.decode(new_plain)
    print("-------- ANSWER -------")
    print(new_plain)
    check_complex_vector_approx_eq(message, new_plain, error=0.05,
                                   error_message="FINAL CHECK FAILED")

    print("------------ BOOTSTRAPPING MODULUS CHANGES -------------")
    print("Old modulus q: %d bits" % (int(math.log(old_modulus, 2))))
    print("Raised modulus Q_0: %d bits" % (int(math.log(self.big_modulus, 2))))
    print("Final modulus Q_1: %d bits" % (int(math.log(ciph.modulus, 2))))