예제 #1
0
    def __init__(self, encryption_param):
        """Derive RNS constants from the encryption parameters.

        Reads ``poly_modulus`` (n), ``coeff_modulus`` (the list of q moduli) and
        ``plain_modulus`` (t) from ``encryption_param``, builds the RNS base over q
        and the two-modulus base {t, gamma}, and precomputes:
          * ``prod_t_gamma_mod_q`` -- (t * gamma) mod each q_i;
          * ``_inv_gamma_mod_t``   -- gamma^(-1) mod t;
          * ``neg_inv_q_mod_t_gamma`` -- -prod(q)^(-1) mod t and mod gamma.

        NOTE(review): ``gamma`` is not defined in this method. It is presumably a
        module-level name defined elsewhere; confirm, otherwise this raises NameError.
        """
        poly_degree = encryption_param.poly_modulus
        coeff_moduli = encryption_param.coeff_modulus
        plain_modulus = encryption_param.plain_modulus

        self._coeff_count = poly_degree
        self.base_q = RNSBase(coeff_moduli)
        self.base_q_size = len(coeff_moduli)
        self._t = plain_modulus

        # Auxiliary base made of exactly two moduli: t and gamma.
        self._base_t_gamma = RNSBase([plain_modulus, gamma])
        self._base_t_gamma_size = 2

        t_times_gamma = plain_modulus * gamma
        self.prod_t_gamma_mod_q = [t_times_gamma % modulus for modulus in self.base_q.base]
        self._inv_gamma_mod_t = invert_mod(gamma, self._t)

        # -prod(q)^(-1) reduced modulo each member of {t, gamma}.
        prod_q = self.base_q.base_prod
        self.neg_inv_q_mod_t_gamma = [
            negate_mod(invert_mod(prod_q % aux_modulus, aux_modulus), aux_modulus)
            for aux_modulus in self._base_t_gamma.base
        ]
예제 #2
0
    def __init__(self, base):
        """Build a residue number system (RNS) base from pairwise-coprime moduli.

        Precomputes, for every modulus ``base[i]``:
          * ``punctured_prod_list[i]`` -- the product of all moduli except ``base[i]``;
          * ``inv_punctured_prod_mod_base_list[i]`` -- the inverse of that punctured
            product modulo ``base[i]``;
        and ``base_prod``, the product of all moduli.

        Args:
            base: list of nonzero, pairwise-coprime integer moduli. The list object is
                stored as ``self.base`` without copying.

        Raises:
            ValueError: if ``base`` is empty, contains a zero, or contains two moduli
                that are not coprime.
        """
        self.size = len(base)

        # An empty base has no meaningful product. Previously this case fell through
        # to ``self.base[0]`` below and raised a bare IndexError.
        if self.size == 0:
            raise ValueError("rns_base is invalid")

        for i, modulus in enumerate(base):
            if modulus == 0:
                raise ValueError("rns_base is invalid")

            # The base must be coprime: check each modulus against all earlier ones,
            # which covers every pair exactly once.
            if any(gcd(modulus, earlier) != 1 for earlier in base[:i]):
                raise ValueError("rns_base is invalid")

        self.base = base

        if self.size == 1:
            # With one modulus the punctured product is the empty product 1, whose
            # inverse modulo anything is also 1.
            self.base_prod = base[0]
            self.punctured_prod_list = [1]
            self.inv_punctured_prod_mod_base_list = [1]
            return

        # Punctured products: product of every modulus except the i-th.
        self.punctured_prod_list = [
            multiply_many_except(base, self.size, i) for i in range(self.size)
        ]

        # Full product = any punctured product times the modulus it omitted.
        self.base_prod = self.punctured_prod_list[0] * base[0]

        # Inverse of each punctured product modulo its own modulus (exists because
        # the moduli are pairwise coprime).
        self.inv_punctured_prod_mod_base_list = [
            invert_mod(punctured % modulus, modulus)
            for punctured, modulus in zip(self.punctured_prod_list, base)
        ]
예제 #3
0
def test_invert_mod(input, modulus, result):
    """Check that ``invert_mod(input, modulus)`` yields the expected ``result``.

    The parameter name ``input`` shadows the builtin. It is kept because the arguments
    are presumably supplied by name from a pytest parametrization outside this view.
    """
    computed_inverse = invert_mod(input, modulus)
    assert result == computed_inverse
예제 #4
0
파일: rns_tool.py 프로젝트: ribhu97/PySyft
    def initialize(self):
        """Precompute the RNS bases, base converters and constants used by this tool.

        Reads ``self.q`` (coefficient-modulus primes), ``self.t`` (plain modulus) and
        ``self.coeff_count``, and populates the attributes below on ``self``.

        When ``self.t == 0``, every {t, gamma}-related value is skipped, matching the
        guards the original already applied to ``base_t_gamma``,
        ``base_q_to_t_gamma_conv`` and ``neg_inv_q_mod_t_gamma``. Those skipped values
        are ``base_t_gamma(_size)``, ``base_q_to_t_gamma_conv``,
        ``prod_t_gamma_mod_q``, ``inv_gamma_mod_t`` and ``neg_inv_q_mod_t_gamma``.
        Previously ``invert_mod(gamma, t)`` was attempted even when t == 0.
        """
        # Bit size requested from get_primes for the auxiliary primes (B, m_sk,
        # gamma). The B-size bound below estimates log2(prod(B) * m_sk) as this bit
        # count times the number of primes, so both must use the same value.
        # Previously the bound assumed 61-bit primes while 50-bit primes were drawn.
        aux_prime_bit_count = 50

        base_q_size = len(self.q)
        # Built once and reused (it used to be constructed twice).
        self.base_q = RNSBase(self.q)
        prod_q = self.base_q.base_prod

        # We require K*n*t*q^2 < q*prod(B)*m_sk. K accounts for cross terms when
        # larger ciphertexts are used, and n is the ring's "delta factor"; 32 bits
        # are reserved for K*n. B starts with as many primes as q and gets one extra
        # prime when the bound might not hold with that many.
        total_coeff_bit_count = prod_q.bit_length()
        base_B_size = base_q_size
        if (32 + self.t.bit_length() + total_coeff_bit_count
                >= aux_prime_bit_count * (base_q_size + 1)):
            base_B_size += 1

        # Sample primes for B and two more primes: m_sk and gamma.
        baseconv_primes = get_primes(self.coeff_count, aux_prime_bit_count, base_B_size + 2)
        self.m_sk = baseconv_primes[0]
        self.gamma = baseconv_primes[1]
        base_B_primes = baseconv_primes[2:]

        # m_tilde is deliberately a non-prime value (2^32).
        self.m_tilde = 1 << 32

        # Populate the base arrays.
        self.base_B = RNSBase(base_B_primes)
        self.base_Bsk = RNSBase(base_B_primes + [self.m_sk])
        self.base_Bsk_m_tilde = RNSBase(base_B_primes + [self.m_sk, self.m_tilde])

        # Base converters: q -> Bsk, q -> {m_tilde}, B -> q, B -> {m_sk}.
        self.base_q_to_Bsk_conv = BaseConvertor(self.base_q, self.base_Bsk)
        self.base_q_to_m_tilde_conv = BaseConvertor(self.base_q, RNSBase([self.m_tilde]))
        self.base_B_to_q_conv = BaseConvertor(self.base_B, self.base_q)
        self.base_B_to_m_sk_conv = BaseConvertor(self.base_B, RNSBase([self.m_sk]))

        bsk_moduli = self.base_Bsk.base
        prod_B = self.base_B.base_prod

        # prod(q)^(-1) mod m_tilde
        self.inv_prod_q_mod_m_tilde = invert_mod(prod_q % self.m_tilde, self.m_tilde)

        # m_tilde^(-1) mod each modulus of Bsk
        self.inv_m_tilde_mod_Bsk = [invert_mod(self.m_tilde % m, m) for m in bsk_moduli]

        # prod(q) mod each modulus of Bsk
        self.prod_q_mod_Bsk = [prod_q % m for m in bsk_moduli]

        # prod(B)^(-1) mod m_sk
        self.inv_prod_B_mod_m_sk = invert_mod(prod_B % self.m_sk, self.m_sk)

        # prod(B) mod each q_i
        self.prod_B_mod_q = [prod_B % q_i for q_i in self.base_q.base]

        # prod(q)^(-1) mod each modulus of Bsk (reuses the residues computed above)
        self.inv_prod_q_mod_Bsk = [
            invert_mod(residue, m) for residue, m in zip(self.prod_q_mod_Bsk, bsk_moduli)
        ]

        if self.t != 0:
            self.base_t_gamma_size = 2
            self.base_t_gamma = RNSBase([self.t, self.gamma])

            # Base conversion: convert from q to {t, gamma}
            self.base_q_to_t_gamma_conv = BaseConvertor(self.base_q, self.base_t_gamma)

            # (t * gamma) mod each q_i, and gamma^(-1) mod t
            self.prod_t_gamma_mod_q = [(self.t * self.gamma) % q_i for q_i in self.q]
            self.inv_gamma_mod_t = invert_mod(self.gamma, self.t)

            # -prod(q)^(-1) mod each of {t, gamma}
            self.neg_inv_q_mod_t_gamma = [
                negate_mod(invert_mod(prod_q % m, m), m) for m in self.base_t_gamma.base
            ]

        # q[last]^(-1) mod q[i] for i = 0..last-1, used by modulus switching and
        # rescaling.
        self.inv_q_last_mod_q = [invert_mod(self.q[-1], q_i) for q_i in self.q[:-1]]