Exemplo n.º 1
0
def test_conv2d_fhe_ntt_single_thread():
    """Check the single-threaded FHE+NTT conv2d protocol against torch.

    The weight is drawn between ``-data_range // 2 + 1`` and
    ``data_range // 2`` and the input between ``0`` and ``modulus`` (as
    passed to ``gen_unirand_int_grain``). The decoded protocol output,
    minus ``prot.output_mask_s``, is compared modulo ``modulus`` with
    ``F.conv2d`` on the same data.
    """
    modulus = 786433
    img_hw, filter_hw, padding = 16, 3, 1
    num_input_channel, num_output_channel = 64, 128
    data_bit = 17
    data_range = 2**data_bit

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # The weight is sampled before the input; keep this order so the random
    # stream is consumed identically.
    weight_low, weight_high = -data_range // 2 + 1, data_range // 2
    w = gen_unirand_int_grain(weight_low, weight_high,
                              get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)

    warming_up_cuda()
    prot = Conv2dFheNttSingleThread(modulus, data_range, img_hw, filter_hw,
                                    num_input_channel, num_output_channel,
                                    fhe_builder, "test_conv2d_fhe_ntt",
                                    padding)

    for attr_name in ("num_input_batch", "num_output_batch"):
        print(f"prot.{attr_name}", getattr(prot, attr_name))

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_conv2d(w)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        decoded = prot.decode_output_from_fhe_batch()
    # Remove the mask applied by masking_output() before comparing.
    actual = pmod(decoded - prot.output_mask_s, modulus)

    batched_x = x.reshape([1] + x_shape).double()
    dense_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        reference = F.conv2d(batched_x, dense_w, padding=padding)
        expected = pmod(reference.reshape(prot.output_shape), modulus)

    compare_expected_actual(expected,
                            actual,
                            name="test_conv2d_fhe_ntt_single_thread",
                            get_relative=True)
Exemplo n.º 2
0
def test_fc_fhe_single_thread():
    """Check the single-threaded FHE fully-connected protocol against torch.

    The weight is drawn between ``-data_range // 2 + 1`` and
    ``data_range // 2`` and the input between ``0`` and ``modulus``. The
    decoded output, minus ``prot.output_mask_s``, is compared modulo
    ``modulus`` with ``x @ w.T`` computed by ``torch.mm``.
    """
    test_name = "test_fc_fhe_single_thread"
    print(f"\nTest for {test_name}: Start")
    modulus = Config.q_23
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # The weight is sampled before the input (random-stream order).
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)

    warming_up_cuda()
    prot = FcFheSingleThread(modulus, data_range, num_input_unit,
                             num_output_unit, fhe_builder, test_name)

    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)
    print("prot.num_elem_in_piece", prot.num_elem_in_piece)

    # Timer labels name the FC protocol this test actually measures.
    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("fc with w"):
        prot.compute_with_weight(w)
    with NamedTimerInstance("fc masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    # Remove the mask applied by masking_output() before comparing.
    actual = pmod(y - prot.output_mask_s, modulus)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Fc Torch"):
        expected = torch.mm(torch_x, torch_w.t())
        expected = pmod(expected.reshape(prot.output_shape), modulus)
    compare_expected_actual(expected,
                            actual,
                            name=test_name,
                            get_relative=True)
    print(f"\nTest for {test_name}: End")
Exemplo n.º 3
0
 def offline(self):
     """Run the offline phase of the swap-to-client sub-protocol.

     Server: builds ``SwapToClientOfflineServer``, allocates an all-zero
     ``dummy_input_s`` of ``self.input_shape`` on the next layer's
     device/dtype, then runs the sub-protocol's offline phase.
     Client: builds ``SwapToClientOfflineClient``, draws a random
     ``output_share`` modulo ``context.q_23``, feeds it flattened to the
     sub-protocol, then moves it to the next layer's device/dtype.
     """
     target_device = self.next_input_device
     target_dtype = self.next_input_dtype
     prot_name = self.sub_name("swap_prot")
     q = self.context.q_23

     if self.is_server():
         self.swap_prot = SwapToClientOfflineServer(
             get_prod(self.input_shape), q, prot_name)
         zeros = torch.zeros(self.input_shape)
         self.dummy_input_s = zeros.to(target_device).type(target_dtype)
         self.swap_prot.offline()
     elif self.is_client():
         self.swap_prot = SwapToClientOfflineClient(
             get_prod(self.input_shape), q, prot_name)
         self.output_share = generate_random_mask(q, self.input_shape)
         flat_share = self.output_share.reshape(-1)
         self.swap_prot.offline(flat_share)
         moved = self.output_share.to(target_device)
         self.output_share = moved.type(target_dtype)
Exemplo n.º 4
0
def test_ntt_conv():
    """Check the plaintext NTT conv2d (``Conv2dNtt``) against ``F.conv2d``.

    Input and weight are both drawn between ``-data_range // 2 + 1`` and
    ``data_range // 2``. The result of the timed load/NTT/conv sequence is
    compared modulo ``modulus`` with torch's conv2d.
    """
    modulus = 786433
    img_hw, filter_hw, padding = 16, 3, 1
    num_input_channel, num_output_channel = 64, 128
    data_bit = 17
    data_range = 2**data_bit

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    low, high = -data_range // 2 + 1, data_range // 2
    x = gen_unirand_int_grain(low, high, get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(low, high, get_prod(w_shape)).reshape(w_shape)

    conv2d_ntt = Conv2dNtt(modulus, data_range, img_hw, filter_hw,
                           num_input_channel, num_output_channel,
                           "test_conv2d_ntt", padding)
    # One-shot conv whose result is discarded; presumably an untimed warm-up
    # before the timed sequence below.
    conv2d_ntt.conv2d(x, w)

    with NamedTimerInstance("ntt x"):
        conv2d_ntt.load_and_ntt_x(x)
    with NamedTimerInstance("ntt w"):
        conv2d_ntt.load_and_ntt_w(w)
    with NamedTimerInstance("conv2d"):
        raw = conv2d_ntt.conv2d_loaded()
    actual = pmod(raw, modulus)

    batched_x = x.reshape([1] + x_shape).double()
    dense_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        reference = F.conv2d(batched_x, dense_w, padding=padding)
        expected = pmod(reference.reshape(conv2d_ntt.y_shape), modulus)

    compare_expected_actual(expected,
                            actual,
                            name="ntt",
                            get_relative=True,
                            show_where_err=False)
Exemplo n.º 5
0
    def offline(self):
        """Prepare the client's input share, swapping it if required.

        If ``self.is_need_swap`` is false, only the client acts: it places
        ``self.get_input_share()`` on ``Config.device`` as
        ``swapped_input_c``. Otherwise a swap-to-client sub-protocol is built
        and its offline phase run. The client draws a random
        ``swapped_input_c`` modulo ``context.q_23``, feeds it flattened to
        the sub-protocol, and keeps it on ``Config.device`` in
        ``self.input_shape``.
        """
        q = self.context.q_23
        prot_name = self.sub_name("swap_prot")

        if not self.is_need_swap:
            if self.is_client():
                self.swapped_input_c = self.get_input_share().to(Config.device)
            return

        if self.is_server():
            self.swap_prot = SwapToClientOfflineServer(
                get_prod(self.input_shape), q, prot_name)
            self.swap_prot.offline()
        elif self.is_client():
            self.swap_prot = SwapToClientOfflineClient(
                get_prod(self.input_shape), q, prot_name)
            self.swapped_input_c = generate_random_mask(q, self.input_shape)
            self.swap_prot.offline(self.swapped_input_c.reshape(-1))
            on_device = self.swapped_input_c.to(Config.device)
            self.swapped_input_c = on_device.reshape(self.input_shape)
Exemplo n.º 6
0
    def offline(self):
        """Build this party's reconstruct-to-client protocol and run its offline phase.

        The protocol covers ``get_prod(self.input_shape)`` elements modulo
        ``context.q_23`` and is named ``self.class_name + self.name``. If the
        party is neither server nor client, ``self.prot`` is not reassigned
        before ``offline()`` is called on it.
        """
        num_elem = get_prod(self.input_shape)
        modulus = self.context.q_23
        name = self.class_name + self.name

        prot_cls = None
        if self.is_server():
            prot_cls = ReconToClientServer
        elif self.is_client():
            prot_cls = ReconToClientClient

        if prot_cls is not None:
            self.prot = prot_cls(num_elem, modulus, name)
        self.prot.offline()
Exemplo n.º 7
0
    def offline(self):
        """Build this party's truncation protocol and run its offline phase.

        Both roles receive the element count, ``context.q_23``,
        ``self.div_to_pow`` and ``context.fhe_builder_23``. If the party is
        neither server nor client, ``self.prot`` is not reassigned before
        ``offline()`` is called on it.
        """
        num_elem = get_prod(self.input_shape)
        name = self.sub_name("trunc_prot")

        prot_cls = None
        if self.is_server():
            prot_cls = TruncServer
        elif self.is_client():
            prot_cls = TruncClient

        if prot_cls is not None:
            self.prot = prot_cls(num_elem, self.context.q_23,
                                 self.div_to_pow,
                                 self.context.fhe_builder_23, name)
        self.prot.offline()
Exemplo n.º 8
0
    def offline(self):
        """Build this party's 2x2 max-pool (DGK) protocol and run its offline phase.

        Both roles receive the same arguments: element count, ``q_23``,
        ``q_16``, ``work_bit``, ``data_bit``, ``self.input_hw``, both FHE
        builders, and the sub-protocol name. If the party is neither server
        nor client, ``self.prot`` is not reassigned before ``offline()``.
        """
        num_elem = get_prod(self.input_shape)
        name = self.sub_name("maxpool2x2_dgk_prot")

        prot_cls = None
        if self.is_server():
            prot_cls = Maxpool2x2DgkServer
        elif self.is_client():
            prot_cls = Maxpool2x2DgkClient

        if prot_cls is not None:
            ctx = self.context
            self.prot = prot_cls(
                num_elem, ctx.q_23, ctx.q_16,
                ctx.work_bit, ctx.data_bit, self.input_hw,
                ctx.fhe_builder_16, ctx.fhe_builder_23, name)
        self.prot.offline()
Exemplo n.º 9
0
    def masking_output(self):
        """Add a fresh random mask to every output ciphertext.

        ``self.output_mask_s`` is drawn with
        ``gen_unirand_int_grain(0, self.modulus - 1, ...)`` in
        ``self.y_shape`` and transformed by
        ``self.conv2d_ntt.ntt_output_masking``. For each output batch, the
        transformed mask of each mapped output channel is written into its
        slot piece; pieces with no channel stay zero. The buffer is then
        reduced modulo ``self.modulus``, batch-encoded, and added in place to
        ``self.output_cts[idx_output_batch]``.
        """
        self.output_mask_s = gen_unirand_int_grain(
            0, self.modulus - 1, get_prod(self.y_shape)).reshape(self.y_shape)
        ntted_mask = self.conv2d_ntt.ntt_output_masking(self.output_mask_s)

        pod_vector = uIntVector()
        pt = Plaintext()
        # Every piece occupies the same number of slots.
        span = self.num_elem_in_padded
        for idx_output_batch in range(self.num_output_batch):
            # float64 holds every integer below 2**53 exactly; float32 only up
            # to 2**24, which would silently round mask values for larger
            # moduli (or unreduced NTT values) before pmod.
            encoding_tensor = torch.zeros(self.degree, dtype=torch.float64)
            for index_piece in range(self.num_rotation):
                index_output_channel = self.index_output_piece_to_channel(
                    idx_output_batch, index_piece)
                # Identity check against False, so that a channel index of 0
                # is not mistaken for "no channel".
                if index_output_channel is False:
                    continue
                start_piece = index_piece * span
                encoding_tensor[start_piece:start_piece + span] = ntted_mask[
                    index_output_channel].reshape(-1)
            encoding_tensor = pmod(encoding_tensor, self.modulus)
            pod_vector.from_np(encoding_tensor.numpy().astype(np.uint64))
            self.batch_encoder.encode(pod_vector, pt)
            self.evaluator.add_plain_inplace(self.output_cts[idx_output_batch],
                                             pt)
Exemplo n.º 10
0
def test_conv2d_fhe_ntt_comm():
    """Two-party test of the offline phase of the FHE+NTT conv2d protocol.

    The client runs its offline phase with ``input_mask`` and the server with
    ``weight``. The client then sends ``output_c`` to the server, which checks
    ``output_c - output_mask_s`` against ``conv2d(input_mask, weight)``
    modulo ``modulus``.
    """
    test_name = "Conv2d Fhe NTT Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    img_hw, filter_hw, padding = 2, 3, 1
    num_input_channel = num_output_channel = 512
    data_bit = 17
    data_range = 2**data_bit
    print(f"Setting: img_hw {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    input_mask = gen_unirand_int_grain(0, modulus - 1,
                                       get_prod(x_shape)).reshape(x_shape)

    # Positional constructor arguments shared by both parties.
    prot_args = (modulus, fhe_builder, data_range, img_hw, filter_hw,
                 num_input_channel, num_output_channel,
                 "test_conv2d_fhe_ntt_comm")

    def make_output_blob(prot):
        # Channel the client uses to ship its output share to the server.
        return BlobTorch(prot.output_shape, torch.float, prot.comm_base,
                         "output_c")

    def check_correctness_offline(x, w, output_mask, output_c):
        actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
        batched_x = x.reshape([1] + x_shape).cuda().double()
        dense_w = w.reshape(w_shape).cuda().double()
        reference = F.conv2d(batched_x, dense_w, padding=padding)
        expected = pmod(reference.reshape(output_mask.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " offline",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = Conv2dFheNttServer(*prot_args, padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()

        blob = make_output_blob(prot)
        blob.prepare_recv()
        torch_sync()
        received_c = blob.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s,
                                  received_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = Conv2dFheNttClient(*prot_args, padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()

        blob = make_output_blob(prot)
        torch_sync()
        blob.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Exemplo n.º 11
0
def test_fc_secure_comm(input_sid,
                        master_address,
                        master_port,
                        setting=(512, 512)):
    """Two-party end-to-end test of the secure fully-connected layer.

    Args:
        input_sid: Party to run in this process. ``Config.server_rank`` runs
            only the server and ``Config.client_rank`` only the client. Any
            other value (e.g. ``Config.both_rank``) runs both via
            ``marshal_funcs``.
        master_address: Rendezvous address passed to ``init_communicate``.
        master_port: Rendezvous port passed to ``init_communicate``.
        setting: ``(num_input_unit, num_output_unit)``.

    The server receives ``output_c`` from the client and checks
    ``output_s + output_c`` against ``x @ w.T + bias`` modulo ``modulus``.
    """
    test_name = "test_fc_secure_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit, num_output_unit = setting
    data_bit = 17

    print(f"Setting fc: "
          f"num_input_unit: {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]
    b_shape = [num_output_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                 get_prod(b_shape)).reshape(b_shape)
    # Named input_x so that the builtin `input` is not shadowed.
    input_x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    get_prod(x_shape)).reshape(x_shape)
    # Additive secret sharing of the input: input_s + input_c == input_x (mod q).
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(input_x - input_c, modulus)

    def check_correctness_online(x, w, b, output_s, output_c):
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = torch.mm(torch_x, torch_w.t())
        if b is not None:
            expected += b.cuda().double()
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " online",
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = FcSecureServer(modulus, data_range, num_input_unit,
                              num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(input_x, weight, bias, prot.output_s,
                                 output_c)

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = FcSecureClient(modulus, data_range, num_input_unit,
                              num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    # Run only the requested party. Any rank other than server/client runs
    # both in-process, so callers passing Config.both_rank keep working.
    if input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()
    else:
        marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Exemplo n.º 12
0
def test_fc_fhe_comm():
    """Two-party test of the offline phase of the FHE fully-connected protocol.

    The client runs its offline phase with ``input_mask`` and the server with
    ``weight``. The client then sends ``output_c`` to the server, which checks
    ``output_c - output_mask_s`` against ``input_mask @ weight.T`` modulo
    ``modulus``.
    """
    test_name = "test_fc_fhe_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit = num_output_unit = 512
    data_bit = 17

    print(f"Setting: num_input_unit {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    input_mask = gen_unirand_int_grain(0, modulus - 1,
                                       get_prod(x_shape)).reshape(x_shape)

    # Positional constructor arguments shared by both parties.
    prot_args = (modulus, data_range, num_input_unit, num_output_unit,
                 fhe_builder, test_name)

    def make_output_blob(prot):
        # Channel the client uses to ship its output share to the server.
        return BlobTorch(prot.output_shape, torch.float, prot.comm_base,
                         "output_c")

    def check_correctness_offline(x, w, output_mask, output_c):
        actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
        row_x = x.reshape([1] + x_shape).cuda().double()
        dense_w = w.reshape(w_shape).cuda().double()
        product = torch.mm(row_x, dense_w.t())
        expected = pmod(product.reshape(output_mask.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " offline",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = FcFheServer(*prot_args)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()

        blob = make_output_blob(prot)
        blob.prepare_recv()
        torch_sync()
        received_c = blob.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s,
                                  received_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = FcFheClient(*prot_args)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()

        blob = make_output_blob(prot)
        torch_sync()
        blob.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Exemplo n.º 13
0
 def build_plain_from_torch(self, torch_tensor) -> FhePlainTensor:
     """Allocate a plaintext sized to ``torch_tensor`` and load its values.

     The size is ``get_prod(torch_tensor.size())``; values are copied with
     ``load_from_torch``.
     """
     num_elem = get_prod(torch_tensor.size())
     plain = self.build_plain(num_elem)
     plain.load_from_torch(torch_tensor)
     return plain
Exemplo n.º 14
0
 def register_prev_layer(self, layer: SecureLayerBase):
     """Register ``layer`` as the predecessor and flatten the output shape.

     After the base-class registration (which presumably sets
     ``self.input_shape``), the output shape is
     ``get_torch_size(get_prod(self.input_shape))``. This is the element
     count of the input, presumably as a 1-D size.
     """
     SecureLayerBase.register_prev_layer(self, layer)
     flat_len = get_prod(self.input_shape)
     self.output_shape = get_torch_size(flat_len)
Exemplo n.º 15
0
def test_conv2d_secure_comm(input_sid,
                            master_address,
                            master_port,
                            setting=(16, 3, 128, 128)):
    """Two-party end-to-end test of the secure conv2d layer.

    Args:
        input_sid: ``Config.both_rank`` runs both parties via
            ``marshal_funcs``. ``Config.server_rank`` runs only the server and
            ``Config.client_rank`` only the client. Any other value runs
            nothing.
        master_address: Rendezvous address passed to ``init_communicate``.
        master_port: Rendezvous port passed to ``init_communicate``.
        setting: ``(img_hw, filter_hw, num_input_channel, num_output_channel)``.

    The server receives ``output_c`` from the client and checks
    ``output_s + output_c`` against ``conv2d(x, w, bias)`` modulo
    ``modulus``.
    """
    test_name = "Conv2d Secure Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    padding = 1
    img_hw, filter_hw, num_input_channel, num_output_channel = setting
    data_bit = 17
    data_range = 2**data_bit
    # Polynomial degree used for this test (instead of Config.n_23).
    n_23 = 16384
    print(f"Setting conv2d: img_hw: {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]
    b_shape = [num_output_channel]

    fhe_builder = FheBuilder(modulus, n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                 get_prod(b_shape)).reshape(b_shape)
    # Named input_x so that the builtin `input` is not shadowed.
    input_x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    get_prod(x_shape)).reshape(x_shape)
    # Additive secret sharing of the input: input_s + input_c == input_x (mod q).
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(input_x - input_c, modulus)

    def check_correctness_online(x, w, b, output_s, output_c):
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        torch_b = b.cuda().double() if b is not None else None
        expected = F.conv2d(torch_x, torch_w, padding=padding, bias=torch_b)
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " online",
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureServer(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm",
                                  padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(input_x, weight, bias, prot.output_s,
                                 output_c)

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureClient(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm",
                                  padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()

    print(f"\nTest for {test_name}: End")
Exemplo n.º 16
0
 def generate_random_data(self, shape):
     """Draw random integers for a tensor of ``shape``.

     The bounds passed to ``gen_unirand_int_grain`` are
     ``-self.data_range // 2 + 1`` and ``self.data_range // 2``. The flat
     sample of ``get_prod(shape)`` values is reshaped to ``shape``.
     """
     low = -self.data_range // 2 + 1
     high = self.data_range // 2
     flat = gen_unirand_int_grain(low, high, get_prod(shape))
     return flat.reshape(shape)
Exemplo n.º 17
0
 def build_enc_from_torch(self, torch_tensor) -> FheEncTensor:
     """Allocate an encrypted tensor sized to ``torch_tensor`` and fill it.

     The size is ``get_prod(torch_tensor.size())``; values are written with
     ``encrypt_additive``.
     """
     num_elem = get_prod(torch_tensor.size())
     enc = self.build_enc(num_elem)
     enc.encrypt_additive(torch_tensor)
     return enc