コード例 #1
0
 def run_test_add(self, message1, message2):
     """Check that adding two ciphertexts matches adding their plaintexts.

     Builds a polynomial from each message, encrypts both, adds the
     ciphertexts with the evaluator, decrypts the result, and asserts that
     its string form equals that of the polynomial sum taken with
     ``self.plain_modulus``.

     Args:
         message1: Second argument passed to ``Polynomial`` for the first
             operand.
         message2: Second argument passed to ``Polynomial`` for the second
             operand.
     """
     first_poly = Polynomial(self.degree, message1)
     second_poly = Polynomial(self.degree, message2)

     # Reference result, computed directly on the polynomials.
     expected = Plaintext(first_poly.add(second_poly, self.plain_modulus))

     # Same operation carried out on the encrypted operands.
     first_ciph = self.encryptor.encrypt(Plaintext(first_poly))
     second_ciph = self.encryptor.encrypt(Plaintext(second_poly))
     actual = self.decryptor.decrypt(
         self.evaluator.add(first_ciph, second_ciph))

     self.assertEqual(str(expected), str(actual))
コード例 #2
0
 def run_test_multiply(self, message1, message2):
     """Check that multiplying two ciphertexts matches the plaintext product.

     Builds a polynomial from each message, encrypts both, multiplies the
     ciphertexts with the evaluator (passing ``self.relin_key``), decrypts
     the result, and asserts that its string form equals that of the
     polynomial product taken with ``self.plain_modulus``.

     Args:
         message1: Second argument passed to ``Polynomial`` for the first
             operand.
         message2: Second argument passed to ``Polynomial`` for the second
             operand.
     """
     first_poly = Polynomial(self.degree, message1)
     second_poly = Polynomial(self.degree, message2)

     # Reference result, computed directly on the polynomials.
     expected = Plaintext(
         first_poly.multiply(second_poly, self.plain_modulus))

     # Same operation carried out on the encrypted operands.
     first_ciph = self.encryptor.encrypt(Plaintext(first_poly))
     second_ciph = self.encryptor.encrypt(Plaintext(second_poly))
     product_ciph = self.evaluator.multiply(first_ciph, second_ciph,
                                            self.relin_key)
     actual = self.decryptor.decrypt(product_ciph)

     self.assertEqual(str(expected), str(actual))
コード例 #3
0
    def decrypt(self, ciphertext, c2=None):
        """Decrypts a ciphertext.

        Computes c0 + c1 * s modulo the ciphertext modulus (plus c2 * s^2
        when c2 is supplied), multiplies the result by 1 / scaling_factor,
        rounds it, and reduces it modulo the plaintext modulus.

        Args:
            ciphertext (Ciphertext): Ciphertext to be decrypted.
            c2 (Polynomial): Optional extra ciphertext component. When
                given, c2 * s^2 is added to the intermediate message before
                scaling.

        Returns:
            The plaintext corresponding to the decrypted ciphertext.
        """
        (c0, c1) = (ciphertext.c0, ciphertext.c1)
        intermed_message = c0.add(
            c1.multiply(self.secret_key.s, self.ciph_modulus),
            self.ciph_modulus)

        # Compare against None explicitly: a supplied component must never be
        # skipped just because the object happens to evaluate as falsy.
        if c2 is not None:
            secret_key_squared = self.secret_key.s.multiply(
                self.secret_key.s, self.ciph_modulus)
            intermed_message = intermed_message.add(
                c2.multiply(secret_key_squared, self.ciph_modulus),
                self.ciph_modulus)

        intermed_message = intermed_message.scalar_multiply(
            1 / self.scaling_factor)
        intermed_message = intermed_message.round()
        intermed_message = intermed_message.mod(self.plain_modulus)
        return Plaintext(intermed_message)
コード例 #4
0
    def run_test_multiply(self, vec1, vec2):
        """Checks that encoding is compatible with multiplication.

        The component-wise product of the first ``self.degree // 2`` entries
        of the inputs is compared with the value obtained by encoding both
        vectors, multiplying the resulting polynomials with
        ``multiply_naive``, and decoding the product at the squared scaling
        factor.

        Args:
            vec1 (list (complex)): First vector.
            vec2 (list (complex)): Second vector.

        Raises:
            ValueError: An error if test fails.
        """
        num_slots = self.degree // 2
        slotwise_prod = [vec1[i] * vec2[i] for i in range(num_slots)]

        encoded1 = self.encoder.encode(vec1, self.scaling_factor)
        encoded2 = self.encoder.encode(vec2, self.scaling_factor)

        # Multiplying two encodings multiplies their scaling factors as well.
        poly_prod = encoded1.poly.multiply_naive(encoded2.poly)
        encoded_prod = Plaintext(poly_prod,
                                 scaling_factor=self.scaling_factor**2)
        decoded_prod = self.encoder.decode(encoded_prod)

        check_complex_vector_approx_eq(decoded_prod, slotwise_prod, error=0.1)
コード例 #5
0
    def decrypt(self, ciphertext, c2=None):
        """Decrypts a ciphertext.

        Computes c0 + c1 * s modulo the ciphertext's modulus (plus c2 * s^2
        when c2 is supplied), reduces the result with ``mod_small``, and
        wraps it in a Plaintext carrying the ciphertext's scaling factor.

        Args:
            ciphertext (Ciphertext): Ciphertext to be decrypted.
            c2 (Polynomial): Optional additional parameter for a ciphertext that
                has not been relinearized.

        Returns:
            The plaintext corresponding to the decrypted ciphertext.
        """
        (c0, c1) = (ciphertext.c0, ciphertext.c1)
        modulus = ciphertext.modulus

        message = c1.multiply(self.secret_key.s,
                              modulus,
                              crt=self.crt_context)
        message = c0.add(message, modulus)

        # Compare against None explicitly: a supplied component must never be
        # skipped just because the object happens to evaluate as falsy.
        if c2 is not None:
            # NOTE(review): unlike the other products here, s * s is computed
            # without crt= -- confirm that this is intentional.
            secret_key_squared = self.secret_key.s.multiply(
                self.secret_key.s, modulus)
            c2_message = c2.multiply(secret_key_squared,
                                     modulus,
                                     crt=self.crt_context)
            message = message.add(c2_message, modulus)

        message = message.mod_small(modulus)
        return Plaintext(message, ciphertext.scaling_factor)
コード例 #6
0
ファイル: ckks_encoder.py プロジェクト: sarojaerabelli/py-fhe
    def encode(self, values, scaling_factor):
        """Encodes complex numbers into a polynomial.

        Applies the inverse canonical embedding to the input, multiplies each
        result by the scaling factor, and rounds the real and imaginary parts
        to the nearest integer. Real parts fill the first half of the
        coefficient list and imaginary parts the second half.

        Args:
            values (list): List of complex numbers to encode.
            scaling_factor (float): Scaling factor to multiply by.

        Returns:
            A Plaintext object which represents the encoded value.
        """
        import math  # function-scope import keeps this method self-contained

        num_values = len(values)
        plain_len = num_values << 1

        # Canonical embedding inverse variant.
        to_scale = self.fft.embedding_inv(values)

        # Multiply by scaling factor, and split up real and imaginary parts.
        # floor(x + 0.5) rounds to nearest for both signs; int(x + 0.5)
        # truncates toward zero and therefore mis-rounds negative values
        # (for example -1.7 would become -1 instead of -2).
        message = [0] * plain_len
        for i in range(num_values):
            scaled_real = to_scale[i].real * scaling_factor
            scaled_imag = to_scale[i].imag * scaling_factor
            message[i] = int(math.floor(scaled_real + 0.5))
            message[i + num_values] = int(math.floor(scaled_imag + 0.5))

        return Plaintext(Polynomial(plain_len, message), scaling_factor)
コード例 #7
0
 def run_test_encode_multiply(self, inp1, inp2):
     """Check that encoding turns slot-wise products into polynomial products.

     The element-wise product of the inputs, reduced modulo
     ``self.plain_modulus``, must equal the decoding of the product of the
     two encoded polynomials.

     Args:
         inp1: First input sequence.
         inp2: Second input sequence; indexed at every position of inp1.
     """
     modulus = self.plain_modulus
     expected = [(left * inp2[i]) % modulus for i, left in enumerate(inp1)]

     encoded1 = self.encoder.encode(inp1)
     encoded2 = self.encoder.encode(inp2)
     poly_prod = encoded1.poly.multiply(encoded2.poly, self.plain_modulus)
     actual = self.encoder.decode(Plaintext(poly_prod))

     self.assertEqual(expected, actual)
コード例 #8
0
ファイル: decryptor.py プロジェクト: Kate-Lin/flask_MPC
 def decrypt(self, ciphertext, c2=None):
     """Decrypt a ciphertext and return the corresponding plaintext.

     Computes c0 + c1 * secret_key modulo the ciphertext's modulus (plus
     c2 * secret_key^2 when c2 is supplied), reduces the result with
     ``mod_small``, and wraps it in a Plaintext carrying the ciphertext's
     scaling factor.

     Args:
         ciphertext: Object exposing ``c0``, ``c1``, ``modulus`` and
             ``scaling_factor``.
         c2: Optional extra ciphertext component. When given,
             c2 * secret_key^2 is added to the message.

     Returns:
         A Plaintext built from the reduced message and the ciphertext's
         scaling factor.
     """
     c0, c1 = ciphertext.c0, ciphertext.c1
     modulus = ciphertext.modulus

     message = c1.multiply(self.secret_key, modulus, crt=self.crt_context)
     message = c0.add(message, modulus)

     # Compare against None explicitly: a supplied component must never be
     # skipped just because the object happens to evaluate as falsy.
     if c2 is not None:
         secret_key_squared = self.secret_key.multiply(self.secret_key,
                                                       modulus)
         c2_message = c2.multiply(secret_key_squared, modulus,
                                  crt=self.crt_context)
         message = message.add(c2_message, modulus)

     message = message.mod_small(modulus)
     return Plaintext(message, ciphertext.scaling_factor)
コード例 #9
0
 def run_test_large_encrypt_decrypt(self, message):
     """Check an encrypt/decrypt round trip with the large parameter set.

     Builds fresh BFV parameters from ``self.large_degree``,
     ``self.large_plain_modulus`` and ``self.large_ciph_modulus``, generates
     a key pair, and asserts that decrypting the encryption of the message
     gives back a plaintext with the same string form.

     Args:
         message: Second argument passed to ``Polynomial`` to build the
             plaintext.
     """
     params = BFVParameters(poly_degree=self.large_degree,
                            plain_modulus=self.large_plain_modulus,
                            ciph_modulus=self.large_ciph_modulus)
     keygen = BFVKeyGenerator(params)
     encryptor = BFVEncryptor(params, keygen.public_key)
     decryptor = BFVDecryptor(params, keygen.secret_key)

     plain = Plaintext(Polynomial(self.large_degree, message))
     recovered = decryptor.decrypt(encryptor.encrypt(plain))

     self.assertEqual(str(plain), str(recovered))
コード例 #10
0
    def encode(self, z, scaling_factor: float) -> Plaintext:
        """Encode a vector as a plaintext polynomial.

        The vector is passed through ``pi_inverse``, multiplied by the
        scaling factor, passed through ``sigma_R_discretization`` and then
        through ``sigma_inverse``. The real parts of the resulting
        coefficients are rounded to integers and wrapped in a Plaintext
        that records the scaling factor.
        """
        expanded = self.pi_inverse(z)
        on_lattice = self.sigma_R_discretization(scaling_factor * expanded)
        raw_poly = self.sigma_inverse(on_lattice)

        # Round only at the end, to absorb numerical imprecision.
        int_coeffs = np.round(np.real(raw_poly.coef)).astype(int).tolist()
        encoded_poly = Poly(self.degree, int_coeffs)

        return Plaintext(encoded_poly, scaling_factor=scaling_factor)
コード例 #11
0
    def create_constant_plain(self, const):
        """Creates a plaintext holding a single constant.

        The constant is multiplied by ``self.scaling_factor``, converted
        with ``int``, and stored as coefficient 0 of a polynomial whose
        remaining coefficients are zero.

        Args:
            const (float): Constant to encode.

        Returns:
            Plaintext with constant value.
        """
        constant_term = int(const * self.scaling_factor)

        coeffs = [0] * self.degree
        coeffs[0] = constant_term

        constant_poly = Polynomial(self.degree, coeffs)
        return Plaintext(constant_poly, self.scaling_factor)
コード例 #12
0
    def encode(self, values):
        """Encodes a list of integers into a polynomial.

        Encodes a N-length list of integers (where N is the polynomial degree)
        into a polynomial using CRT batching.

        Args:
            values (list): Integers to encode. Its length must equal
                ``self.degree``.

        Returns:
            A Plaintext object which represents the encoded value.

        Raises:
            AssertionError: If the list length differs from the polynomial
                degree (only while assertions are enabled).
        """
        # A parenthesised literal keeps the message on one logical line; a
        # backslash continuation inside the quotes would embed the source
        # indentation in the message text.
        assert len(values) == self.degree, (
            'Length of list does not equal polynomial degree.')

        coeffs = self.ntt.ftt_inv(values)
        return Plaintext(Polynomial(self.degree, coeffs))
コード例 #13
0
 def test_decode(self):
     """Decoding the coefficients 1, 0, 1, 0, 1 (zero-padded) must give 21."""
     leading = [1, 0, 1, 0, 1]
     padding = [0] * (self.degree - 5)
     encoded = Plaintext(Polynomial(self.degree, leading + padding))

     decoded = self.encoder.decode(encoded)

     self.assertEqual(decoded, 21)
コード例 #14
0
    def run_test_bootstrap_steps(self, message):
        """Runs bootstrapping step by step, checking each intermediate value.

        The message is encoded and encrypted, the ciphertext modulus is
        raised, and then the stages coeff-to-slot, exponentiation, sine,
        sine scaling and slot-to-coeff are applied through the evaluator.
        After each stage the ciphertexts are decrypted and decoded, and the
        result is compared with a value computed directly on the decrypted
        coefficients. Diagnostic values are printed along the way.

        Args:
            message: Values passed to ``self.encoder.encode``; the final
                decoded result is compared against them.

        Raises:
            Whatever ``check_complex_vector_approx_eq`` raises when a
            comparison fails (its definition is not visible here).
        """

        # ------------------- SETUP -------------------- #
        num_slots = self.degree // 2
        plain = self.encoder.encode(message, self.scaling_factor)
        ciph = self.encryptor.encrypt(plain)

        rot_keys = {}
        for i in range(num_slots):
            rot_keys[i] = self.key_generator.generate_rot_key(i)

        conj_key = self.key_generator.generate_conj_key()

        # Raise modulus.
        old_modulus = ciph.modulus
        old_scaling_factor = self.scaling_factor
        self.evaluator.raise_modulus(ciph)

        print(message)
        print("-----------------------")
        print(plain.poly.coeffs)
        # From here on `plain` is the decryption of the raised ciphertext.
        plain = self.decryptor.decrypt(ciph)
        test_plain = Plaintext(plain.poly.mod_small(self.ciph_modulus),
                               self.scaling_factor)
        print("-------- TEST --------")
        print(test_plain.poly.coeffs)
        print(self.encoder.decode(test_plain))

        print("---------- BIT SIZE ------------")
        print(math.log(self.scaling_factor, 2))
        print(math.log(self.ciph_modulus, 2))
        print(math.log(abs(plain.poly.coeffs[0]), 2))
        print("---------- MOD ------------")
        print(plain.poly.coeffs[0])
        print(plain.poly.coeffs[0] > self.ciph_modulus / 2)
        print(math.sin(2 * math.pi * plain.poly.coeffs[0] / self.ciph_modulus))
        # Uses `.poly`, like every other coefficient access in this method.
        print(2 * math.pi * (plain.poly.coeffs[0] % self.ciph_modulus) /
              self.ciph_modulus)
        print(
            math.sin(2 * math.pi * plain.poly.coeffs[0] / self.ciph_modulus) *
            self.ciph_modulus / 2 / math.pi)
        print(plain.poly.coeffs[0] % self.ciph_modulus)

        # Coeff to slot.
        ciph0, ciph1 = self.evaluator.coeff_to_slot(ciph, rot_keys, conj_key,
                                                    self.encoder)
        # Reference values: first and second half of the decrypted
        # coefficients, divided by the evaluator's scaling factor.
        plain_slots0 = [
            plain.poly.coeffs[i] / self.evaluator.scaling_factor
            for i in range(num_slots)
        ]
        plain_slots1 = [
            plain.poly.coeffs[i] / self.evaluator.scaling_factor
            for i in range(num_slots, 2 * num_slots)
        ]
        print("----- COEFF TO SLOT -------")
        print(plain_slots0)
        print(plain_slots1)
        decrypted0 = self.decryptor.decrypt(ciph0)
        decoded0 = self.encoder.decode(decrypted0)
        decrypted1 = self.decryptor.decrypt(ciph1)
        decoded1 = self.encoder.decode(decrypted1)
        check_complex_vector_approx_eq(decoded0,
                                       plain_slots0,
                                       error_message="COEFF TO SLOT FAILED")
        check_complex_vector_approx_eq(decoded1,
                                       plain_slots1,
                                       error_message="COEFF TO SLOT FAILED")

        # Exponentiate.
        const = self.evaluator.scaling_factor / old_modulus * 2 * math.pi * 1j
        ciph_exp0 = self.evaluator.exp(ciph0, const, self.relin_key,
                                       self.encoder)
        ciph_neg_exp0 = self.evaluator.conjugate(ciph_exp0, conj_key)
        ciph_exp1 = self.evaluator.exp(ciph1, const, self.relin_key,
                                       self.encoder)
        ciph_neg_exp1 = self.evaluator.conjugate(ciph_exp1, conj_key)

        pre_exp0 = [plain_slots0[i] * const for i in range(num_slots)]
        pre_exp1 = [plain_slots1[i] * const for i in range(num_slots)]
        exp0 = [cmath.exp(pre_exp0[i]) for i in range(num_slots)]
        exp1 = [cmath.exp(pre_exp1[i]) for i in range(num_slots)]
        taylor_exp0 = taylor_exp(plain_slots0,
                                 const,
                                 num_iterations=self.num_taylor_exp_iterations)
        taylor_exp1 = taylor_exp(plain_slots1,
                                 const,
                                 num_iterations=self.num_taylor_exp_iterations)
        neg_exp0 = [cmath.exp(-pre_exp0[i]) for i in range(num_slots)]
        neg_exp1 = [cmath.exp(-pre_exp1[i]) for i in range(num_slots)]
        print("----- EXP -------")
        print("----- argument -----")
        print(pre_exp0)
        print(pre_exp1)
        print("---- actual exp ------")
        print(exp0)
        print(exp1)
        print("---- taylor series exp ----")
        print(taylor_exp0)
        print(taylor_exp1)
        print("---- actual negative exp -----")
        print(neg_exp0)
        print(neg_exp1)

        decrypted_exp0 = self.decryptor.decrypt(ciph_exp0)
        decoded_exp0 = self.encoder.decode(decrypted_exp0)
        decrypted_neg_exp0 = self.decryptor.decrypt(ciph_neg_exp0)
        decoded_neg_exp0 = self.encoder.decode(decrypted_neg_exp0)
        decrypted_exp1 = self.decryptor.decrypt(ciph_exp1)
        decoded_exp1 = self.encoder.decode(decrypted_exp1)
        decrypted_neg_exp1 = self.decryptor.decrypt(ciph_neg_exp1)
        decoded_neg_exp1 = self.encoder.decode(decrypted_neg_exp1)
        check_complex_vector_approx_eq(decoded_exp0,
                                       exp0,
                                       error=0.001,
                                       error_message="EXP FAILED")
        check_complex_vector_approx_eq(decoded_exp1,
                                       exp1,
                                       error=0.001,
                                       error_message="EXP FAILED")
        check_complex_vector_approx_eq(decoded_neg_exp0,
                                       neg_exp0,
                                       error=0.001,
                                       error_message="EXP FAILED")
        check_complex_vector_approx_eq(decoded_neg_exp1,
                                       neg_exp1,
                                       error=0.001,
                                       error_message="EXP FAILED")

        # Compute sine.
        ciph_sin0 = self.evaluator.subtract(ciph_exp0, ciph_neg_exp0)
        ciph_sin1 = self.evaluator.subtract(ciph_exp1, ciph_neg_exp1)

        sin0 = [(exp0[i] - neg_exp0[i]) / 2 / 1j for i in range(num_slots)]
        sin1 = [(exp1[i] - neg_exp1[i]) / 2 / 1j for i in range(num_slots)]

        # Scale sine.
        const = self.evaluator.create_complex_constant_plain(
            old_modulus / self.evaluator.scaling_factor * 0.25 / math.pi / 1j,
            self.encoder)
        ciph0 = self.evaluator.multiply_plain(ciph_sin0, const)
        ciph1 = self.evaluator.multiply_plain(ciph_sin1, const)
        ciph0 = self.evaluator.rescale(ciph0, self.evaluator.scaling_factor)
        ciph1 = self.evaluator.rescale(ciph1, self.evaluator.scaling_factor)

        print("----- SIN -------")
        print(sin0)
        print(sin1)
        sin_check0 = [cmath.sin(pre_exp0[i]) for i in range(num_slots)]
        sin_check1 = [cmath.sin(pre_exp1[i]) for i in range(num_slots)]
        print(sin_check0)
        print(sin_check1)

        scaled_sin0 = [
            sin0[i] * self.ciph_modulus / self.evaluator.scaling_factor / 2 /
            math.pi for i in range(num_slots)
        ]
        scaled_sin1 = [
            sin1[i] * self.ciph_modulus / self.evaluator.scaling_factor / 2 /
            math.pi for i in range(num_slots)
        ]
        print("----- SCALED SIN -------")
        print(scaled_sin0)
        print(scaled_sin1)
        expected_slots0 = [(plain.poly.coeffs[i] % self.ciph_modulus) /
                           self.evaluator.scaling_factor
                           for i in range(num_slots)]
        expected_slots1 = [(plain.poly.coeffs[i] % self.ciph_modulus) /
                           self.evaluator.scaling_factor
                           for i in range(num_slots, 2 * num_slots)]
        print(expected_slots0)
        print(expected_slots1)

        decrypted0 = self.decryptor.decrypt(ciph0)
        decoded0 = self.encoder.decode(decrypted0)
        decrypted1 = self.decryptor.decrypt(ciph1)
        decoded1 = self.encoder.decode(decrypted1)
        check_complex_vector_approx_eq(decoded0,
                                       scaled_sin0,
                                       error=0.1,
                                       error_message="SIN FAILED")
        check_complex_vector_approx_eq(decoded1,
                                       scaled_sin1,
                                       error=0.1,
                                       error_message="SIN FAILED")

        # Slot to coeff.
        ciph = self.evaluator.slot_to_coeff(ciph0, ciph1, rot_keys,
                                            self.encoder)

        # Reset scaling factor.
        self.scaling_factor = old_scaling_factor
        ciph.scaling_factor = self.scaling_factor

        new_plain = self.decryptor.decrypt(ciph)
        new_plain = self.encoder.decode(new_plain)

        print("-------- ANSWER -------")
        print(new_plain)
        check_complex_vector_approx_eq(message,
                                       new_plain,
                                       error=0.05,
                                       error_message="FINAL CHECK FAILED")

        print("------------ BOOTSTRAPPING MODULUS CHANGES -------------")
        print("Old modulus q: %d bits" % (int(math.log(old_modulus, 2))))
        print("Raised modulus Q_0: %d bits" %
              (int(math.log(self.big_modulus, 2))))
        print("Final modulus Q_1: %d bits" % (int(math.log(ciph.modulus, 2))))