def generate_switching_key(self, new_key):
    """Generates a switching key for CKKS scheme.

    Follows KSGen from the CKKS paper. With P = params.big_modulus and all
    arithmetic modulo P**2, the key is (-a*s + e + P*new_key, a), where a is
    sampled uniformly from [0, P**2) and e is a fresh triangle-distributed
    error polynomial.

    Args:
        new_key (Polynomial): New key to generate switching key.

    Returns:
        A switching key, packaged as a PublicKey(sw0, sw1).
    """
    degree = self.params.poly_degree
    big_mod = self.params.big_modulus
    big_mod_sq = big_mod ** 2

    mask = Polynomial(degree, sample_uniform(0, big_mod_sq, degree))
    noise = Polynomial(degree, sample_triangle(degree))

    first = mask.multiply(self.secret_key.s, big_mod_sq)
    first = first.scalar_multiply(-1, big_mod_sq)
    first = first.add(noise, big_mod_sq)
    # Embed the target key scaled by P so it survives the later division by P.
    lifted_key = new_key.scalar_multiply(big_mod, big_mod_sq)
    first = first.add(lifted_key, big_mod_sq)

    return PublicKey(first, mask)
def generate_relin_key(self, params):
    """Generates a relinearization key for BFV scheme.

    The secret key squared is decomposed with base T = ceil(sqrt(q)). For each
    level i the pair (-(s*a_i + e_i) + T**i * s**2 mod q, a_i) is produced,
    with a_i uniform and e_i triangle-distributed. The result is stored in
    self.relin_key.

    Args:
        params (Parameters): Parameters including polynomial degree,
            plaintext, and ciphertext modulus.
    """
    q = params.ciph_modulus
    degree = params.poly_degree
    base = ceil(sqrt(q))
    num_levels = floor(log(q, base)) + 1

    secret = self.secret_key.s
    secret_sq = secret.multiply(secret, q)

    pairs = []
    factor = 1  # base**i reduced mod q
    for _ in range(num_levels):
        mask = Polynomial(degree, sample_uniform(0, q, degree))
        noise = Polynomial(degree, sample_triangle(degree))
        negated = secret.multiply(mask, q).add(noise, q).scalar_multiply(-1)
        body = negated.add(secret_sq.scalar_multiply(factor), q).mod(q)
        pairs.append((body, mask))
        factor = (factor * base) % q

    self.relin_key = BFVRelinKey(base, pairs)
def test_multiply_01(self):
    # Random degree-4 inputs: the FFT product must match the schoolbook product.
    lhs = Polynomial(4, sample_uniform(0, 30, 4))
    rhs = Polynomial(4, sample_uniform(0, 30, 4))
    via_fft = lhs.multiply_fft(rhs)
    via_naive = lhs.multiply_naive(rhs)
    self.assertEqual(via_fft.coeffs, via_naive.coeffs)
def test_multiply_fft(self):
    # Expected value is the exact product reduced mod x^4 + 1 over the integers.
    expected = [-29, -31, -9, 17]
    product = Polynomial(4, [0, 1, 4, 5]).multiply_fft(Polynomial(4, [1, 2, 4, 3]))
    self.assertEqual(product.coeffs, expected)
def run_test_multiply(self, message1, message2):
    # Homomorphic product (with relinearization) must decrypt to the
    # plaintext product taken mod plain_modulus.
    lhs = Polynomial(self.degree, message1)
    rhs = Polynomial(self.degree, message2)
    lhs_plain, rhs_plain = Plaintext(lhs), Plaintext(rhs)
    expected = Plaintext(lhs.multiply(rhs, self.plain_modulus))

    lhs_ciph = self.encryptor.encrypt(lhs_plain)
    rhs_ciph = self.encryptor.encrypt(rhs_plain)
    product = self.evaluator.multiply(lhs_ciph, rhs_ciph, self.relin_key)
    decrypted = self.decryptor.decrypt(product)

    self.assertEqual(str(expected), str(decrypted))
def run_test_add(self, message1, message2):
    # Homomorphic sum must decrypt to the plaintext sum taken mod plain_modulus.
    lhs = Polynomial(self.degree, message1)
    rhs = Polynomial(self.degree, message2)
    lhs_plain, rhs_plain = Plaintext(lhs), Plaintext(rhs)
    expected = Plaintext(lhs.add(rhs, self.plain_modulus))

    lhs_ciph = self.encryptor.encrypt(lhs_plain)
    rhs_ciph = self.encryptor.encrypt(rhs_plain)
    total = self.evaluator.add(lhs_ciph, rhs_ciph)
    decrypted = self.decryptor.decrypt(total)

    self.assertEqual(str(expected), str(decrypted))
def run_test_subtract(self, message1, message2):
    # CKKS subtraction: decoded result must approximate the slot-wise difference.
    half = self.degree // 2
    lhs = Polynomial(half, message1)
    rhs = Polynomial(half, message2)
    lhs_plain = self.encoder.encode(message1, self.scaling_factor)
    rhs_plain = self.encoder.encode(message2, self.scaling_factor)
    expected = lhs.subtract(rhs)

    difference = self.evaluator.subtract(self.encryptor.encrypt(lhs_plain),
                                         self.encryptor.encrypt(rhs_plain))
    decoded = self.encoder.decode(self.decryptor.decrypt(difference))

    check_complex_vector_approx_eq(expected.coeffs, decoded, error=0.005)
def run_test_secret_key_add(self, message1, message2):
    # Same as the public-key add test, but both inputs use secret-key encryption.
    half = self.degree // 2
    lhs = Polynomial(half, message1)
    rhs = Polynomial(half, message2)
    lhs_plain = self.encoder.encode(message1, self.scaling_factor)
    rhs_plain = self.encoder.encode(message2, self.scaling_factor)
    expected = lhs.add(rhs)

    total = self.evaluator.add(self.encryptor.encrypt_with_secret_key(lhs_plain),
                               self.encryptor.encrypt_with_secret_key(rhs_plain))
    decoded = self.encoder.decode(self.decryptor.decrypt(total))

    check_complex_vector_approx_eq(expected.coeffs, decoded, error=0.001)
def encode(self, values, scaling_factor):
    """Encodes complex numbers into a polynomial.

    Applies the inverse canonical embedding to ``values``, scales the result
    by ``scaling_factor`` and rounds to the nearest integer. Real parts fill
    the first half of the coefficient list and imaginary parts fill the
    second half.

    Args:
        values (list): List of complex numbers to encode.
        scaling_factor (float): Scaling factor to multiply by.

    Returns:
        A Plaintext of length 2 * len(values) which represents the encoded
        value, carrying ``scaling_factor``.
    """
    num_values = len(values)
    plain_len = num_values << 1

    # Canonical embedding inverse variant.
    to_scale = self.fft.embedding_inv(values)

    # Multiply by scaling factor, and split up real and imaginary parts.
    # round() is used because int(x + 0.5) truncates toward zero and so
    # mis-rounds negative inputs (e.g. -2.7 -> -2 instead of -3).
    message = [0] * plain_len
    for i in range(num_values):
        message[i] = round(to_scale[i].real * scaling_factor)
        message[i + num_values] = round(to_scale[i].imag * scaling_factor)

    return Plaintext(Polynomial(plain_len, message), scaling_factor)
def encrypt(self, plain):
    """Encrypts a message.

    With public key (p0, p1), a fresh triangle-sampled polynomial u and two
    fresh triangle-sampled errors e1 and e2, computes
    c0 = p0*u + e1 + m and c1 = p1*u + e2 (mod coeff_modulus). Each
    component is then passed through mod_small.

    Args:
        plain (Plaintext): Plaintext to be encrypted.

    Returns:
        A Ciphertext holding (c0, c1), the plaintext's scaling factor and
        the coefficient modulus.
    """
    q = self.coeff_modulus
    degree = self.poly_degree
    pk0, pk1 = self.public_key.p0, self.public_key.p1

    ephemeral = Polynomial(degree, sample_triangle(degree))
    noise0 = Polynomial(degree, sample_triangle(degree))
    noise1 = Polynomial(degree, sample_triangle(degree))

    first = noise0.add(pk0.multiply(ephemeral, q, crt=self.crt_context), q)
    first = first.add(plain.poly, q).mod_small(q)

    second = noise1.add(pk1.multiply(ephemeral, q, crt=self.crt_context), q)
    second = second.mod_small(q)

    return Ciphertext(first, second, plain.scaling_factor, q)
def generate_public_key(self, params):
    """Generates a public key for CKKS scheme.

    Stores PublicKey(-a*s + e, a) in self.public_key. The arithmetic is
    modulo self.params.big_modulus, a is uniform and e is a triangle-sampled
    error polynomial.

    Args:
        params (Parameters): Parameters including polynomial degree,
            plaintext, and ciphertext modulus.
    """
    modulus = self.params.big_modulus
    degree = params.poly_degree

    mask = Polynomial(degree, sample_uniform(0, modulus, degree))
    noise = Polynomial(degree, sample_triangle(degree))

    masked = mask.multiply(self.secret_key.s, modulus)
    body = masked.scalar_multiply(-1, modulus).add(noise, modulus)

    self.public_key = PublicKey(body, mask)
def generate_secret_key(self, params):
    """Generates a secret key for CKKS scheme.

    The key polynomial's coefficients come from
    sample_hamming_weight_vector with params.hamming_weight. The result is
    stored in self.secret_key.

    Args:
        params (Parameters): Parameters including polynomial degree,
            plaintext, and ciphertext modulus.
    """
    degree = params.poly_degree
    coeffs = sample_hamming_weight_vector(degree, params.hamming_weight)
    self.secret_key = SecretKey(Polynomial(degree, coeffs))
def generate_secret_key(self, params):
    """Generates a secret key for BFV scheme.

    The key polynomial's coefficients are drawn with sample_triangle. The
    result is stored in self.secret_key.

    Args:
        params (Parameters): Parameters including polynomial degree,
            plaintext, and ciphertext modulus.
    """
    degree = params.poly_degree
    key_poly = Polynomial(degree, sample_triangle(degree))
    self.secret_key = SecretKey(key_poly)
def test_multiply(self):
    # Ring product mod 73 must match the reference and be commutative.
    lhs = Polynomial(4, [0, 1, 4, 5])
    rhs = Polynomial(4, [1, 2, 4, 3])
    forward = lhs.multiply(rhs, 73)
    backward = rhs.multiply(lhs, 73)
    self.assertEqual(forward.coeffs, [44, 42, 64, 17])
    self.assertEqual(forward.coeffs, backward.coeffs)
def encrypt_with_secret_key(self, plain):
    """Encrypts a message with secret key encryption.

    With secret key s, a fresh triangle-sampled polynomial u and a fresh
    triangle-sampled error e, computes c0 = s*u + e + m and c1 = -u
    (mod coeff_modulus). Both components are then passed through mod_small.

    Args:
        plain (Plaintext): Plaintext to be encrypted.

    Returns:
        A Ciphertext holding (c0, c1), the plaintext's scaling factor and
        the coefficient modulus.

    Raises:
        AssertionError: If self.secret_key is None (only while assertions
            are enabled).
    """
    # Identity check: `!= None` would dispatch to the key object's __ne__.
    assert self.secret_key is not None, 'Secret key does not exist'

    sk = self.secret_key.s
    random_vec = Polynomial(self.poly_degree, sample_triangle(self.poly_degree))
    error = Polynomial(self.poly_degree, sample_triangle(self.poly_degree))

    c0 = sk.multiply(random_vec, self.coeff_modulus, crt=self.crt_context)
    c0 = error.add(c0, self.coeff_modulus)
    c0 = c0.add(plain.poly, self.coeff_modulus)
    c0 = c0.mod_small(self.coeff_modulus)

    c1 = random_vec.scalar_multiply(-1, self.coeff_modulus)
    c1 = c1.mod_small(self.coeff_modulus)

    return Ciphertext(c0, c1, plain.scaling_factor, self.coeff_modulus)
def test_embedding(self):
    """Checks that canonical embedding is correct.

    Compares self.fft.embedding against direct evaluation of the polynomial
    at the roots of unity exp(2*pi*i*k / (4*num_slots)) for k = 5**j mod
    (4*num_slots). Every such k is 1 mod 4. Fails, via
    check_complex_vector_approx_eq, if the two differ beyond 1e-5.
    """
    coeffs = [10, 34, 71, 31, 1, 2, 3, 4]
    poly = Polynomial(self.num_slots, coeffs)
    fft_length = self.num_slots * 4
    embedding = self.fft.embedding(coeffs)

    expected = []
    exponent = 1
    for _ in range(1, fft_length, 4):
        angle = 2 * pi * exponent / fft_length
        expected.append(poly.evaluate(complex(cos(angle), sin(angle))))
        exponent = exponent * 5 % fft_length

    check_complex_vector_approx_eq(embedding, expected, 0.00001)
def run_test_large_encrypt_decrypt(self, message):
    # Round trip with freshly generated large-parameter BFV keys.
    params = BFVParameters(poly_degree=self.large_degree,
                           plain_modulus=self.large_plain_modulus,
                           ciph_modulus=self.large_ciph_modulus)
    keygen = BFVKeyGenerator(params)
    pub, sec = keygen.public_key, keygen.secret_key
    encryptor = BFVEncryptor(params, pub)
    decryptor = BFVDecryptor(params, sec)

    plain = Plaintext(Polynomial(self.large_degree, message))
    round_trip = decryptor.decrypt(encryptor.encrypt(plain))

    self.assertEqual(str(plain), str(round_trip))
def run_test_conjugate(self, message):
    # Homomorphic conjugation must approximate slot-wise complex conjugation.
    poly = Polynomial(self.degree // 2, message)
    plain = self.encoder.encode(message, self.scaling_factor)
    expected = [value.conjugate() for value in poly.coeffs]

    ciph = self.encryptor.encrypt(plain)
    conj_key = self.key_generator.generate_conj_key()
    conjugated = self.evaluator.conjugate(ciph, conj_key)
    decoded = self.encoder.decode(self.decryptor.decrypt(conjugated))

    check_complex_vector_approx_eq(expected, decoded, error=0.005)
def create_constant_plain(self, const):
    """Creates a plaintext containing a constant value.

    Places round(const * self.scaling_factor) in the constant coefficient
    of a degree-``self.degree`` polynomial. All other coefficients are zero.

    Args:
        const (float): Constant to encode.

    Returns:
        Plaintext with constant value, carrying self.scaling_factor.
    """
    plain_vec = [0] * self.degree
    # Round to the nearest integer rather than truncate: int() turns
    # 0.29 * 100 == 28.999999999999996 into 28 instead of 29, and it
    # truncates negative constants toward zero.
    plain_vec[0] = round(const * self.scaling_factor)
    return Plaintext(Polynomial(self.degree, plain_vec), self.scaling_factor)
def run_test_rotate(self, message, r):
    # Homomorphic rotation by r must approximate a cyclic left shift of the slots.
    poly = Polynomial(self.degree // 2, message)
    plain = self.encoder.encode(message, self.scaling_factor)
    size = poly.ring_degree
    expected = [poly.coeffs[(idx + r) % size] for idx in range(size)]

    ciph = self.encryptor.encrypt(plain)
    rot_key = self.key_generator.generate_rot_key(r)
    rotated = self.evaluator.rotate(ciph, r, rot_key)
    decoded = self.encoder.decode(self.decryptor.decrypt(rotated))

    check_complex_vector_approx_eq(expected, decoded, error=0.005)
def encode(self, values):
    """Encodes a list of integers into a polynomial.

    Encodes an N-length list of integers (where N is the polynomial degree)
    into a polynomial using CRT batching, via self.ntt.ftt_inv.

    Args:
        values (list): Integers to encode.

    Returns:
        A Plaintext object which represents the encoded value.

    Raises:
        AssertionError: If len(values) != self.degree (only while assertions
            are enabled).
    """
    # Single literal: the previous backslash continuation inside the string
    # produced an invalid "\ " escape / embedded indentation in the message.
    assert len(values) == self.degree, (
        'Length of list does not equal polynomial degree.')
    coeffs = self.ntt.ftt_inv(values)
    return Plaintext(Polynomial(self.degree, coeffs))
def run_test_simple_rotate(self, message, rot):
    # Rotating both ciphertext parts (no key switching) must decrypt correctly
    # under the equally rotated secret key.
    poly = Polynomial(self.degree // 2, message)
    plain = self.encoder.encode(message, self.scaling_factor)
    size = poly.ring_degree
    expected = [poly.coeffs[(idx + rot) % size] for idx in range(size)]

    ciph = self.encryptor.encrypt(plain)
    part0 = ciph.c0.rotate(rot).mod_small(self.ciph_modulus)
    part1 = ciph.c1.rotate(rot).mod_small(self.ciph_modulus)
    rotated = Ciphertext(part0, part1, ciph.scaling_factor, self.ciph_modulus)

    rotated_key = SecretKey(self.secret_key.s.rotate(rot))
    decryptor = CKKSDecryptor(self.params, rotated_key)
    decoded = self.encoder.decode(decryptor.decrypt(rotated))

    check_complex_vector_approx_eq(expected, decoded, error=0.005)
def test_multiply_crt(self):
    # CRT-based product must agree with the naive product (mod 2^10) and be commutative.
    log_modulus = 10
    modulus = 1 << log_modulus
    prime_size = 59
    log_poly_degree = 2
    poly_degree = 1 << log_poly_degree
    # Ceiling of (2 + log_poly_degree + 4*log_modulus) / prime_size.
    num_primes = (2 + log_poly_degree + 4 * log_modulus + prime_size - 1) // prime_size
    crt = CRTContext(num_primes, prime_size, poly_degree)

    lhs = Polynomial(poly_degree, [0, 1, 4, 5])
    rhs = Polynomial(poly_degree, [1, 2, 4, 3])
    forward = lhs.multiply_crt(rhs, crt).mod_small(modulus)
    backward = rhs.multiply_crt(lhs, crt).mod_small(modulus)
    reference = lhs.multiply_naive(rhs, modulus).mod_small(modulus)

    self.assertEqual(forward.coeffs, reference.coeffs)
    self.assertEqual(forward.coeffs, backward.coeffs)
def test_rotate(self):
    rotated = Polynomial(4, [0, 1, 4, 59]).rotate(3)
    self.assertEqual(rotated.coeffs, [0, -1, 4, -59])
def test_evaluate(self):
    # 0 + 1*3 + 2*9 + 3*27 + 4*81 == 426
    value = Polynomial(self.degree, [0, 1, 2, 3, 4]).evaluate(3)
    self.assertEqual(value, 426)
class TestPolynomial(unittest.TestCase):
    """Unit tests for Polynomial arithmetic, reduction, decomposition and printing."""

    def setUp(self):
        degree, modulus = 5, 60
        self.degree = degree
        self.coeff_modulus = modulus
        self.poly1 = Polynomial(degree, [0, 1, 4, 5, 59])
        self.poly2 = Polynomial(degree, [1, 2, 4, 3, 2])

    def test_add(self):
        mod = self.coeff_modulus
        forward = self.poly1.add(self.poly2, mod)
        backward = self.poly2.add(self.poly1, mod)
        self.assertEqual(forward.coeffs, [1, 3, 8, 8, 1])
        self.assertEqual(forward.coeffs, backward.coeffs)

    def test_subtract(self):
        difference = self.poly1.subtract(self.poly2, self.coeff_modulus)
        self.assertEqual(difference.coeffs, [59, 59, 0, 2, 57])

    def test_multiply(self):
        lhs = Polynomial(4, [0, 1, 4, 5])
        rhs = Polynomial(4, [1, 2, 4, 3])
        forward = lhs.multiply(rhs, 73)
        backward = rhs.multiply(lhs, 73)
        self.assertEqual(forward.coeffs, [44, 42, 64, 17])
        self.assertEqual(forward.coeffs, backward.coeffs)

    def test_multiply_crt(self):
        log_modulus = 10
        modulus = 1 << log_modulus
        prime_size = 59
        log_poly_degree = 2
        poly_degree = 1 << log_poly_degree
        # Ceiling of (2 + log_poly_degree + 4*log_modulus) / prime_size.
        num_primes = (2 + log_poly_degree + 4 * log_modulus + prime_size - 1) // prime_size
        crt = CRTContext(num_primes, prime_size, poly_degree)

        lhs = Polynomial(poly_degree, [0, 1, 4, 5])
        rhs = Polynomial(poly_degree, [1, 2, 4, 3])
        forward = lhs.multiply_crt(rhs, crt).mod_small(modulus)
        backward = rhs.multiply_crt(lhs, crt).mod_small(modulus)
        reference = lhs.multiply_naive(rhs, modulus).mod_small(modulus)

        self.assertEqual(forward.coeffs, reference.coeffs)
        self.assertEqual(forward.coeffs, backward.coeffs)

    def test_multiply_fft(self):
        # Exact product reduced mod x^4 + 1 over the integers.
        expected = [-29, -31, -9, 17]
        product = Polynomial(4, [0, 1, 4, 5]).multiply_fft(Polynomial(4, [1, 2, 4, 3]))
        self.assertEqual(product.coeffs, expected)

    def test_multiply_naive(self):
        mod = self.coeff_modulus
        forward = self.poly1.multiply_naive(self.poly2, mod)
        backward = self.poly2.multiply_naive(self.poly1, mod)
        self.assertEqual(forward.coeffs, [28, 42, 59, 19, 28])
        self.assertEqual(forward.coeffs, backward.coeffs)

    def test_multiply_01(self):
        # Random inputs: FFT and schoolbook multiplication must agree.
        lhs = Polynomial(4, sample_uniform(0, 30, 4))
        rhs = Polynomial(4, sample_uniform(0, 30, 4))
        via_fft = lhs.multiply_fft(rhs)
        via_naive = lhs.multiply_naive(rhs)
        self.assertEqual(via_fft.coeffs, via_naive.coeffs)

    def test_scalar_multiply(self):
        negated = self.poly1.scalar_multiply(-1, self.coeff_modulus)
        self.assertEqual(negated.coeffs, [0, 59, 56, 55, 1])

    def test_rotate(self):
        rotated = Polynomial(4, [0, 1, 4, 59]).rotate(3)
        self.assertEqual(rotated.coeffs, [0, -1, 4, -59])

    def test_round(self):
        values = [0.51, -3.2, 54.666, 39.01, 0]
        rounded = Polynomial(self.degree, values).round()
        self.assertEqual(rounded.coeffs, [1, -3, 55, 39, 0])

    def test_mod(self):
        raw = Polynomial(self.degree, [57, -34, 100, 1000, -7999])
        reduced = raw.mod(self.coeff_modulus)
        self.assertEqual(reduced.coeffs, [57, 26, 40, 40, 41])

    def test_base_decompose(self):
        modulus = self.coeff_modulus
        base = ceil(sqrt(modulus))
        num_levels = floor(log(modulus, base)) + 1
        digits = self.poly1.base_decompose(base, num_levels)
        self.assertEqual(digits[0].coeffs, [0, 1, 4, 5, 3])
        self.assertEqual(digits[1].coeffs, [0, 0, 0, 0, 7])

    def test_evaluate(self):
        # 0 + 1*3 + 2*9 + 3*27 + 4*81 == 426
        value = Polynomial(self.degree, [0, 1, 2, 3, 4]).evaluate(3)
        self.assertEqual(value, 426)

    def test_str(self):
        first, second = str(self.poly1), str(self.poly2)
        self.assertEqual(first, '59x^4 + 5x^3 + 4x^2 + x')
        self.assertEqual(second, '2x^4 + 3x^3 + 4x^2 + 2x + 1')
def setUp(self):
    degree, modulus = 5, 60
    self.degree = degree
    self.coeff_modulus = modulus
    self.poly1 = Polynomial(degree, [0, 1, 4, 5, 59])
    self.poly2 = Polynomial(degree, [1, 2, 4, 3, 2])
def test_round(self):
    values = [0.51, -3.2, 54.666, 39.01, 0]
    rounded = Polynomial(self.degree, values).round()
    self.assertEqual(rounded.coeffs, [1, -3, 55, 39, 0])
def test_mod(self):
    raw = Polynomial(self.degree, [57, -34, 100, 1000, -7999])
    reduced = raw.mod(self.coeff_modulus)
    self.assertEqual(reduced.coeffs, [57, 26, 40, 40, 41])
def encrypt(self, message):
    """Encrypts a message.

    With public key (p0, p1), a fresh triangle-sampled polynomial u and two
    fresh triangle-sampled errors e1 and e2, computes
    c0 = p0*u + e1 + scaling_factor*m and c1 = p1*u + e2, all modulo
    coeff_modulus.

    Args:
        message (Plaintext): Plaintext to be encrypted.

    Returns:
        A ciphertext consisting of a pair of polynomials in the ciphertext
        space.
    """
    p0 = self.public_key.p0
    p1 = self.public_key.p1

    # NOTE(review): scaling_factor presumably is Delta = floor(q / t) for BFV.
    scaled_message = message.poly.scalar_multiply(self.scaling_factor, self.coeff_modulus)

    random_vec = Polynomial(self.poly_degree, sample_triangle(self.poly_degree))
    # Fresh small errors. These were previously overwritten with all-zero
    # polynomials (leftover debug code), which removed the e1/e2 noise
    # terms that RLWE-style encryption relies on.
    error1 = Polynomial(self.poly_degree, sample_triangle(self.poly_degree))
    error2 = Polynomial(self.poly_degree, sample_triangle(self.poly_degree))

    c0 = error1.add(p0.multiply(random_vec, self.coeff_modulus),
                    self.coeff_modulus).add(scaled_message, self.coeff_modulus)
    c1 = error2.add(p1.multiply(random_vec, self.coeff_modulus), self.coeff_modulus)
    return Ciphertext(c0, c1)