Beispiel #1
0
    def masking_output(self):
        """Add a random additive mask to every output ciphertext.

        A mask of shape ``[self.num_output_batch, self.degree]`` is drawn with
        ``generate_random_mask(self.modulus, ...)``. For each output batch, the
        pieces mapped to output units by ``index_output_batch_to_units`` are
        copied into a slot vector. The vector is reduced with ``pmod``,
        batch-encoded and added in place to ``self.output_cts[idx_output_batch]``.

        For every mapped output unit, the sum of its piece's mask slots is stored
        in ``self.output_mask_s``. That vector is reduced with ``pmod`` at the end.
        NOTE(review): presumably this lets the caller strip the mask from the
        per-unit slot sums later; confirm against the consumer.
        """
        spread_mask = generate_random_mask(
            self.modulus, [self.num_output_batch, self.degree])
        self.output_mask_s = torch.zeros(self.num_output_unit).double()

        pod_vector = uIntVector()
        pt = Plaintext()
        padded_span = self.num_elem_in_piece
        for idx_output_batch in range(self.num_output_batch):
            # float64 rather than float32: float32 holds integers exactly only
            # up to 2**24. Larger mask values would be rounded here, and the
            # encoded mask would no longer match output_mask_s (summed below
            # from the unrounded values).
            encoding_tensor = torch.zeros(self.degree, dtype=torch.double)
            for idx_piece in range(self.num_piece_in_batch):
                idx_output_unit = self.index_output_batch_to_units(
                    idx_output_batch, idx_piece)
                # A ``False`` result stops this batch: no later piece is used.
                # ``is False`` (not a falsy test) because unit 0 is a valid index.
                if idx_output_unit is False:
                    break
                start_piece = idx_piece * padded_span
                end_piece = start_piece + padded_span
                arr = spread_mask[idx_output_batch, start_piece:end_piece]
                encoding_tensor[start_piece:end_piece] = arr
                self.output_mask_s[idx_output_unit] = arr.double().sum()
            encoding_tensor = pmod(encoding_tensor, self.modulus)
            pod_vector.from_np(encoding_tensor.numpy().astype(np.uint64))
            self.batch_encoder.encode(pod_vector, pt)
            self.evaluator.add_plain_inplace(self.output_cts[idx_output_batch],
                                             pt)

        self.output_mask_s = pmod(self.output_mask_s, self.modulus)
Beispiel #2
0
 def compute_with_weight(self, weight_tensor):
     """Accumulate plaintext-weight x encrypted-input products into ``self.output_cts``.

     ``self.output_cts`` is first reset to fresh encryptions of zero (via
     ``encrypt_zeros``). Then, for every (output batch, input batch) pair:

     - the weight row segments mapped to that pair by
       ``index_weight_batch_to_units`` are packed piece by piece into a slot
       vector;
     - the vector is reduced with ``pmod`` and batch-encoded;
     - it is multiplied with a ``Ciphertext`` built from
       ``self.input_cts[idx_input_batch]``;
     - the product is added into ``self.output_cts[idx_output_batch]``.

     Pairs with no mapped weight segment are skipped.

     Args:
         weight_tensor: tensor of shape ``self.weight_shape``, indexed as
             ``weight_tensor[row, col_start:col_end]``.
     """
     assert (weight_tensor.shape == self.weight_shape)
     pod_vector = uIntVector()
     pt = Plaintext()
     padded_span = self.num_elem_in_piece
     self.output_cts = encrypt_zeros(self.num_output_batch,
                                     self.batch_encoder, self.encryptor,
                                     self.degree)
     for idx_output_batch, idx_input_batch in product(
             range(self.num_output_batch), range(self.num_input_batch)):
         # float64 rather than torch's default float32, which holds integers
         # exactly only up to 2**24.
         encoding_tensor = torch.zeros(self.degree, dtype=torch.double)
         is_w_changed = False
         for idx_piece in range(self.num_piece_in_batch):
             idx_row, idx_col_start, idx_col_end = \
                 self.index_weight_batch_to_units(idx_output_batch, idx_input_batch, idx_piece)
             # ``is False`` (not a falsy test) because row 0 is a valid index.
             if idx_row is False:
                 continue
             is_w_changed = True
             data_span = idx_col_end - idx_col_start
             start_piece = idx_piece * padded_span
             # Only data_span slots are written; the rest of the padded piece
             # stays zero.
             encoding_tensor[start_piece:start_piece + data_span] = \
                 weight_tensor[idx_row, idx_col_start:idx_col_end]
         if not is_w_changed:
             # No weights map onto this pair, so it contributes nothing.
             continue
         encoding_tensor = pmod(encoding_tensor, self.modulus)
         pod_vector.from_np(encoding_tensor.numpy().astype(np.uint64))
         self.batch_encoder.encode(pod_vector, pt)
         sub_dotted = Ciphertext(self.input_cts[idx_input_batch])
         self.evaluator.multiply_plain_inplace(sub_dotted, pt)
         self.evaluator.add_inplace(self.output_cts[idx_output_batch],
                                    sub_dotted)
Beispiel #3
0
def encrypt_zeros(num_batch, batch_encoder, encryptor, degree):
    """Return ``num_batch`` ciphertexts, each an encryption of the all-zero slot vector.

    The zero vector is batch-encoded once. ``encryptor.encrypt`` is then called
    separately for each ciphertext.

    Args:
        num_batch: number of ciphertexts to produce.
        batch_encoder: encoder used to batch-encode ``degree`` zero slots.
        encryptor: encryptor that fills each ``Ciphertext`` in place.
        degree: number of slots in the encoded vector.

    Returns:
        list of ``num_batch`` ``Ciphertext`` objects.
    """
    pod_vector = uIntVector()
    pt = Plaintext()
    # uint64, matching the arrays passed to ``from_np`` elsewhere
    # (``.astype(np.uint64)``); the previous int64 array needed a conversion.
    pod_vector.from_np(np.zeros(degree, dtype=np.uint64))
    batch_encoder.encode(pod_vector, pt)
    cts = [Ciphertext() for _ in range(num_batch)]
    for ct in cts:
        encryptor.encrypt(pt, ct)
    return cts
Beispiel #4
0
    def __init__(self, num_elem, modulus, degree, context, evaluator, batch_encoder, encryptor, is_cheap_init=False):
        """Initialise the base class and allocate ciphertext storage.

        Args:
            num_elem, modulus, degree, context, evaluator, batch_encoder,
                encryptor: forwarded unchanged to the base-class ``__init__``.
            is_cheap_init: when True, ``self.encrypt_zeros()`` is not called,
                so ``self.cts`` stays default-constructed and ``self.all_zeros``
                is not assigned here. When False, ``encrypt_zeros`` runs and
                ``self.all_zeros`` is set to True.
        """
        super().__init__(num_elem, modulus, degree, context=context,
                         evaluator=evaluator, batch_encoder=batch_encoder,
                         encryptor=encryptor)
        self.pod_vector = uIntVector()
        self.pt = Plaintext()
        # NOTE(review): self.num_batch is presumably set by the base-class
        # __init__ — confirm.
        self.cts = [Ciphertext() for _ in range(self.num_batch)]
        # Presumably a per-kind multiplication limit ("quota") and a usage
        # counter ("times") — verify against the methods that read them.
        self.multiply_enc_quota, self.multiply_enc_times = 1, 0
        self.multiply_plain_quota, self.multiply_plain_times = 1, 0

        if is_cheap_init:
            return
        self.encrypt_zeros()
        self.all_zeros = True
Beispiel #5
0
 def encode_input_to_fhe_batch(self, input_tensor):
     """Encrypt ``input_tensor`` into ``self.input_cts``, one ciphertext per input batch.

     For each input batch, the unit range ``[start_unit, end_unit)`` returned by
     ``index_input_batch_to_units`` is sliced from ``input_tensor``. That slice
     is copied into each of the ``self.num_piece_in_batch`` pieces of the slot
     vector; pieces start every ``self.num_elem_in_piece`` slots. The vector is
     then reduced with ``pmod``, batch-encoded and encrypted.

     Args:
         input_tensor: tensor of shape ``self.input_shape``, sliced along its
             first axis.
     """
     assert (input_tensor.shape == self.input_shape)
     self.input_cts = [Ciphertext() for _ in range(self.num_input_batch)]
     pod_vector = uIntVector()
     pt = Plaintext()
     piece_span = self.num_elem_in_piece
     for index_batch in range(self.num_input_batch):
         # float64 rather than torch's default float32, which holds integers
         # exactly only up to 2**24.
         encoding_tensor = torch.zeros(self.degree, dtype=torch.double)
         start_unit, end_unit = self.index_input_batch_to_units(index_batch)
         input_span = end_unit - start_unit
         # The same slice is written into every piece, so take it once.
         input_slice = input_tensor[start_unit:end_unit]
         for i in range(self.num_piece_in_batch):
             start_piece = i * piece_span
             encoding_tensor[start_piece:start_piece + input_span] = input_slice
         encoding_tensor = pmod(encoding_tensor, self.modulus)
         pod_vector.from_np(encoding_tensor.numpy().astype(np.uint64))
         self.batch_encoder.encode(pod_vector, pt)
         self.encryptor.encrypt(pt, self.input_cts[index_batch])
Beispiel #6
0
 def encode_input_to_fhe_batch(self, input_tensor):
     """Encrypt the NTT-transformed ``input_tensor`` into ``self.input_cts``.

     ``self.conv2d_ntt.load_and_ntt_x`` transforms the input, and the result is
     read from ``self.conv2d_ntt.ntted_x``. For each input batch, every rotation
     piece mapped to an input channel by ``index_input_piece_to_channel``
     receives that channel's flattened data in its own span of
     ``self.num_elem_in_padded`` slots; unmapped pieces stay zero. The slot
     vector is then reduced with ``pmod``, batch-encoded and encrypted.

     Args:
         input_tensor: tensor of shape ``self.input_shape``.
     """
     assert (input_tensor.shape == self.input_shape)
     self.conv2d_ntt.load_and_ntt_x(input_tensor)
     ntted_input = self.conv2d_ntt.ntted_x
     self.input_cts = [Ciphertext() for _ in range(self.num_input_batch)]
     pod_vector = uIntVector()
     pt = Plaintext()
     span = self.num_elem_in_padded
     for index_batch in range(self.num_input_batch):
         # float64 rather than float32, which holds integers exactly only up
         # to 2**24.
         encoding_tensor = torch.zeros(self.degree, dtype=torch.double)
         for index_piece in range(self.num_rotation):
             index_input_channel = self.index_input_piece_to_channel(
                 index_batch, index_piece)
             # ``is False`` (not a falsy test) because channel 0 is valid.
             if index_input_channel is False:
                 continue
             start_piece = index_piece * span
             encoding_tensor[start_piece:start_piece + span] = \
                 ntted_input[index_input_channel].reshape(-1)
         encoding_tensor = pmod(encoding_tensor, self.modulus)
         pod_vector.from_np(encoding_tensor.numpy().astype(np.uint64))
         self.batch_encoder.encode(pod_vector, pt)
         self.encryptor.encrypt(pt, self.input_cts[index_batch])
Beispiel #7
0
    def decode_output_from_fhe_batch(self):
        """Decrypt ``self.output_cts`` and map the slots back onto output channels.

        Each output ciphertext is decrypted and batch-decoded. Then, for every
        rotation piece that ``index_output_piece_to_channel`` assigns to an
        output channel:

        - its ``self.num_elem_in_padded`` slots are reshaped to
          ``[self.padded_hw, self.padded_hw]``;
        - the result is passed through
          ``self.conv2d_ntt.transform_y_single_channel``;
        - the output is copied into that channel.

        Returns:
            tensor of shape ``self.output_shape`` created by ``torch.zeros``
            with the default dtype.
        """
        output_tensor = torch.zeros(self.output_shape)
        pod_vector = uIntVector()
        pt = Plaintext()
        cts = self.output_cts
        span = self.num_elem_in_padded
        for index_output_batch in range(self.num_output_batch):
            self.decryptor.decrypt(cts[index_output_batch], pt)
            self.batch_encoder.decode(pt, pod_vector)
            # np.float64 replaces the ``np.float`` alias removed in NumPy 1.24
            # (it was the builtin float, i.e. float64). np.asarray replaces
            # np.array(..., copy=False), which raises ValueError under NumPy 2
            # whenever a copy would be needed.
            arr = torch.from_numpy(np.asarray(pod_vector).astype(np.float64))
            for index_rot in range(self.num_rotation):
                index_output_channel = self.index_output_piece_to_channel(
                    index_output_batch, index_rot)
                # ``is False`` (not a falsy test) because channel 0 is valid.
                if index_output_channel is False:
                    continue
                start_piece = index_rot * span
                sub_ntted_y = arr[start_piece:start_piece + span].reshape(
                    [self.padded_hw, self.padded_hw])
                sub_y = self.conv2d_ntt.transform_y_single_channel(sub_ntted_y)
                output_tensor[index_output_channel].copy_(sub_y)

        return output_tensor
Beispiel #8
0
    def decode_output_from_fhe_batch(self):
        """Decrypt ``self.output_cts`` and sum each piece's slots into its output unit.

        For every output batch, the ciphertext is decrypted and batch-decoded.
        Each piece mapped to an output unit by ``index_output_batch_to_units``
        contributes the sum of its ``self.num_elem_in_piece`` slots to that
        unit. The first ``False`` mapping ends the batch.

        Returns:
            ``pmod(output, self.modulus)``, where ``output`` is a float64 tensor
            of shape ``self.output_shape``.
        """
        output_tensor = torch.zeros(self.output_shape).double()
        pod_vector = uIntVector()
        pt = Plaintext()
        cts = self.output_cts
        padded_span = self.num_elem_in_piece
        for idx_output_batch in range(self.num_output_batch):
            self.decryptor.decrypt(cts[idx_output_batch], pt)
            self.batch_encoder.decode(pt, pod_vector)
            # np.asarray instead of np.array(..., copy=False): under NumPy 2
            # the latter raises ValueError whenever a copy would be needed.
            arr = torch.from_numpy(np.asarray(pod_vector).astype(np.double))
            for idx_piece in range(self.num_piece_in_batch):
                idx_output_unit = self.index_output_batch_to_units(
                    idx_output_batch, idx_piece)
                # ``is False`` (not a falsy test) because unit 0 is valid.
                if idx_output_unit is False:
                    break
                start_piece = idx_piece * padded_span
                # arr is already float64, so no further cast is needed.
                output_tensor[idx_output_unit] = \
                    arr[start_piece:start_piece + padded_span].sum()

        output_tensor = pmod(output_tensor, self.modulus)
        return output_tensor
Beispiel #9
0
    def masking_output(self):
        """Draw a random output mask and add its NTT form to every output ciphertext.

        ``self.output_mask_s`` is drawn with
        ``gen_unirand_int_grain(0, self.modulus - 1, ...)`` and reshaped to
        ``self.y_shape``. It is transformed by
        ``self.conv2d_ntt.ntt_output_masking``. For each output batch, every
        rotation piece mapped to an output channel by
        ``index_output_piece_to_channel`` receives that channel's flattened
        transformed mask. The slot vector is then reduced with ``pmod``,
        batch-encoded and added in place to the matching ``self.output_cts``
        entry.
        """
        self.output_mask_s = gen_unirand_int_grain(
            0, self.modulus - 1, get_prod(self.y_shape)).reshape(self.y_shape)
        ntted_mask = self.conv2d_ntt.ntt_output_masking(self.output_mask_s)

        pod_vector = uIntVector()
        pt = Plaintext()
        span = self.num_elem_in_padded
        for idx_output_batch in range(self.num_output_batch):
            # float64 rather than float32. float32 holds integers exactly only
            # up to 2**24, so larger mask values would be rounded before
            # encoding.
            encoding_tensor = torch.zeros(self.degree, dtype=torch.double)
            for index_piece in range(self.num_rotation):
                index_output_channel = self.index_output_piece_to_channel(
                    idx_output_batch, index_piece)
                # ``is False`` (not a falsy test) because channel 0 is valid.
                if index_output_channel is False:
                    continue
                start_piece = index_piece * span
                encoding_tensor[start_piece:start_piece + span] = \
                    ntted_mask[index_output_channel].reshape(-1)
            encoding_tensor = pmod(encoding_tensor, self.modulus)
            pod_vector.from_np(encoding_tensor.numpy().astype(np.uint64))
            self.batch_encoder.encode(pod_vector, pt)
            self.evaluator.add_plain_inplace(self.output_cts[idx_output_batch],
                                             pt)
Beispiel #10
0
 def compute_conv2d(self, weight_tensor):
     """Accumulate NTT-domain weight x encrypted-input products into ``self.output_cts``.

     The weight is transformed by ``self.conv2d_ntt.load_and_ntt_w`` and read
     back from ``self.conv2d_ntt.ntted_w``. ``self.output_cts`` is reset to
     fresh encryptions of zero (via ``encrypt_zeros``). Then, for each
     (output batch, input batch) pair:

     - every rotation piece mapped by ``index_weight_piece_to_channel`` to an
       (input channel, output channel) pair receives the flattened
       ``ntted_weight[out_ch, in_ch]`` in its span of
       ``self.num_elem_in_padded`` slots;
     - the slot vector is reduced with ``pmod`` and encoded;
     - it is multiplied with a ``Ciphertext`` built from
       ``self.input_cts[idx_input_batch]``;
     - the product is added into ``self.output_cts[idx_output_batch]``.

     Pairs with no mapped piece are skipped.

     Args:
         weight_tensor: tensor of shape ``self.weight_shape``.
     """
     assert (weight_tensor.shape == self.weight_shape)
     self.conv2d_ntt.load_and_ntt_w(weight_tensor)
     ntted_weight = self.conv2d_ntt.ntted_w
     pod_vector = uIntVector()
     pt_w = Plaintext()
     span = self.num_elem_in_padded
     self.output_cts = encrypt_zeros(self.num_output_batch,
                                     self.batch_encoder, self.encryptor,
                                     self.degree)
     for idx_output_batch, idx_input_batch in product(
             range(self.num_output_batch), range(self.num_input_batch)):
         # float64 rather than float32, which holds integers exactly only up
         # to 2**24.
         encoding_tensor = torch.zeros(self.degree, dtype=torch.double)
         is_w_changed = False
         for index_piece in range(self.num_rotation):
             index_input_channel, index_output_channel = \
                 self.index_weight_piece_to_channel(idx_output_batch, idx_input_batch, index_piece)
             # ``is False`` (not falsy tests) because channel 0 is valid.
             if index_input_channel is False or index_output_channel is False:
                 continue
             is_w_changed = True
             start_piece = index_piece * span
             encoding_tensor[start_piece:start_piece + span] = \
                 ntted_weight[index_output_channel, index_input_channel].reshape(-1)
         if not is_w_changed:
             # Nothing maps onto this pair, so it contributes nothing.
             continue
         encoding_tensor = pmod(encoding_tensor, self.modulus)
         pod_vector.from_np(encoding_tensor.numpy().astype(np.uint64))
         self.batch_encoder.encode(pod_vector, pt_w)
         sub_dotted = Ciphertext(self.input_cts[idx_input_batch])
         self.evaluator.multiply_plain_inplace(sub_dotted, pt_w)
         self.evaluator.add_inplace(self.output_cts[idx_output_batch],
                                    sub_dotted)
         # Release the temporary now. Otherwise it would stay alive until
         # the next iteration's ``Ciphertext(...)`` replaces it, so two
         # temporaries would coexist.
         del sub_dotted
Beispiel #11
0
 def __init__(self, num_elem, modulus, degree, context, evaluator, batch_encoder):
     """Initialise the base class and allocate reusable plaintext buffers.

     NOTE(review): this snippet may continue past the visible lines; only the
     visible statements are documented here.
     """
     # Forward the FHE objects to the base class; no encryptor is taken here.
     super().__init__(num_elem, modulus, degree, context=context, evaluator=evaluator, batch_encoder=batch_encoder)
     # Reusable slot-vector buffer.
     self.pod_vector = uIntVector()
     # One Plaintext per batch. NOTE(review): self.num_batch is presumably set
     # by the base-class __init__, and Plaintext(self.degree, 0) presumably
     # means `degree` zero coefficients — confirm against the binding.
     self.batched_pts = [Plaintext(self.degree, 0) for _ in range(self.num_batch)]