コード例 #1
0
    def test_client():
        """Run the client side of the secure Conv2d protocol test.

        Joins the communication group with the client rank, times the
        offline phase on ``input_c`` and the online phase, then sends the
        protocol's ``output_c`` to the server under the name "output_c".
        """
        init_communicate(Config.client_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        protocol = Conv2dSecureClient(
            modulus, fhe_builder, data_range, img_hw, filter_hw,
            num_input_channel, num_output_channel,
            "test_conv2d_secure_comm", padding=padding)

        with NamedTimerInstance("Client Offline"):
            protocol.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            protocol.online()
            torch_sync()

        # Hand the client's output to the server, which checks it.
        output_share_blob = BlobTorch(protocol.output_shape, torch.float,
                                      protocol.comm_base, "output_c")
        torch_sync()
        output_share_blob.send(protocol.output_c)
        end_communicate()
コード例 #2
0
    def test_server():
        """Run the server side of the Avgpool2x2 protocol test.

        Times the offline phase and the online phase on ``img_s``, resetting
        the traffic record after each. It then receives the client's
        "recon_res_c" blob and checks it against ``img`` and the server's
        ``out_s``.
        """
        init_communicate(Config.server_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic = TrafficRecord()
        protocol = Avgpool2x2Server(
            num_elem, q_23, q_16, work_bit, data_bit, img_hw,
            fhe_builder_16, fhe_builder_23, "avgpool")

        with NamedTimerInstance("Server Offline"):
            protocol.offline()
            torch_sync()
        traffic.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            protocol.online(img_s)
            torch_sync()
        traffic.reset("server-online")

        # The pooled result has a quarter of the input elements.
        client_out_blob = BlobTorch(num_elem // 4, torch.float,
                                    protocol.comm_base, "recon_res_c")
        client_out_blob.prepare_recv()
        torch_sync()
        client_out = client_out_blob.get_recv()
        check_correctness_online(img, protocol.out_s, client_out)

        end_communicate()
コード例 #3
0
    def test_server():
        """Run the server side of the secure Conv2d protocol test.

        Times the offline phase with ``weight``/``bias`` and the online phase
        with ``input_s``. It then receives the client's "output_c" blob and
        verifies it with ``check_correctness_online``.
        """
        init_communicate(Config.server_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        protocol = Conv2dSecureServer(
            modulus, fhe_builder, data_range, img_hw, filter_hw,
            num_input_channel, num_output_channel,
            "test_conv2d_secure_comm", padding=padding)

        with NamedTimerInstance("Server Offline"):
            protocol.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            protocol.online(input_s)
            torch_sync()

        client_share_blob = BlobTorch(protocol.output_shape, torch.float,
                                      protocol.comm_base, "output_c")
        # The receive is prepared before the barrier; the client sends only
        # after its matching torch_sync().
        client_share_blob.prepare_recv()
        torch_sync()
        client_share = client_share_blob.get_recv()
        check_correctness_online(input, weight, bias, protocol.output_s,
                                 client_share)

        end_communicate()
コード例 #4
0
    def test_client():
        """Run the client side of the Avgpool2x2 protocol test.

        Times the offline phase and the online phase on ``img_c``, resetting
        the traffic record after each. It then sends the protocol's ``out_c``
        to the server as "recon_res_c".
        """
        init_communicate(Config.client_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic = TrafficRecord()
        protocol = Avgpool2x2Client(
            num_elem, q_23, q_16, work_bit, data_bit, img_hw,
            fhe_builder_16, fhe_builder_23, "avgpool")

        with NamedTimerInstance("Client Offline"):
            protocol.offline()
            torch_sync()
        traffic.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            protocol.online(img_c)
            torch_sync()
        traffic.reset("client-online")

        pooled_blob = BlobTorch(num_elem // 4, torch.float,
                                protocol.comm_base, "recon_res_c")
        torch_sync()
        pooled_blob.send(protocol.out_c)
        end_communicate()
コード例 #5
0
    def __init__(self, num_elem, modulus, fhe_builder: FheBuilder, name: str,
                 rank):
        """Set up the state and communication blobs shared by both parties.

        Args:
            num_elem: Number of elements carried by every blob.
            modulus: Arithmetic modulus; must equal ``fhe_builder.modulus``.
            fhe_builder: FHE context used for the encrypted blobs.
            name: Name for the ``CommBase``; also the prefix (via
                ``sub_name``) for the FHE channel and encrypted blobs.
            rank: This party's rank, passed to the communication objects.

        Raises:
            ValueError: If ``modulus`` differs from ``fhe_builder.modulus``.
        """
        # Check the precondition before any communication objects are
        # created. Raise explicitly: an ``assert`` is stripped under -O.
        if modulus != fhe_builder.modulus:
            raise ValueError(
                f"modulus ({modulus}) does not match fhe_builder.modulus "
                f"({fhe_builder.modulus})")

        self.num_elem = num_elem
        self.modulus = modulus
        self.name = name
        self.fhe_builder = fhe_builder
        self.rank = rank
        self.comm_base = CommBase(rank, name)
        self.comm_fhe = CommFheBuilder(rank, fhe_builder,
                                       self.sub_name("comm_fhe"))

        # Encrypted blobs exchanged over the FHE channel.
        self.blob_fhe_u = BlobFheEnc(num_elem, self.comm_fhe,
                                     self.sub_name("fhe_u"))
        self.blob_fhe_v = BlobFheEnc(num_elem, self.comm_fhe,
                                     self.sub_name("fhe_v"))
        self.blob_fhe_z_c = BlobFheEnc(num_elem,
                                       self.comm_fhe,
                                       self.sub_name("fhe_z_c"),
                                       ee_mult_time=1)

        # Plaintext blobs exchanged over the base channel.
        self.blob_e_s = BlobTorch(num_elem, torch.float, self.comm_base, "e_s")
        self.blob_f_s = BlobTorch(num_elem, torch.float, self.comm_base, "f_s")
        self.blob_e_c = BlobTorch(num_elem, torch.float, self.comm_base, "e_c")
        self.blob_f_c = BlobTorch(num_elem, torch.float, self.comm_base, "f_c")

        # Blobs grouped by phase and sending party (per the list names).
        self.offline_server_send = [self.blob_fhe_z_c]
        self.offline_client_send = [self.blob_fhe_u, self.blob_fhe_v]
        self.online_server_send = [self.blob_e_s, self.blob_f_s]
        self.online_client_send = [self.blob_e_c, self.blob_f_c]
コード例 #6
0
    def test_client():
        """Run the client side of the FC-over-FHE offline test.

        Only the offline phase is run, on ``input_mask``. Afterwards the
        protocol's ``output_c`` is sent to the server as "output_c".
        """
        init_communicate(Config.client_rank)
        protocol = FcFheClient(modulus, data_range, num_input_unit,
                               num_output_unit, fhe_builder, test_name)

        with NamedTimerInstance("Client Offline"):
            protocol.offline(input_mask)
            torch_sync()

        result_blob = BlobTorch(protocol.output_shape, torch.float,
                                protocol.comm_base, "output_c")
        torch_sync()
        result_blob.send(protocol.output_c)
        end_communicate()
コード例 #7
0
    def test_client():
        """Run the client side of the SwapToClientOffline protocol test.

        Times the offline phase on ``y_c`` and the online phase on ``x_c``,
        then sends the protocol's ``output_c`` to the server as "output_c".
        """
        init_communicate(Config.client_rank)
        warming_up_cuda()
        protocol = SwapToClientOfflineClient(num_elem, modulus, test_name)

        with NamedTimerInstance("Client Offline"):
            protocol.offline(y_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            protocol.online(x_c)
            torch_sync()

        result_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                                "output_c")
        torch_sync()
        result_blob.send(protocol.output_c)
        end_communicate()
コード例 #8
0
    def test_server():
        """Run the server side of the FC-over-FHE offline test.

        Runs the offline phase with ``weight``, then receives the client's
        "output_c" blob. It checks that blob with
        ``check_correctness_offline`` against ``input_mask``, ``weight`` and
        the server's ``output_mask_s``.
        """
        init_communicate(Config.server_rank)
        protocol = FcFheServer(modulus, data_range, num_input_unit,
                               num_output_unit, fhe_builder, test_name)

        with NamedTimerInstance("Server Offline"):
            protocol.offline(weight)
            torch_sync()

        client_blob = BlobTorch(protocol.output_shape, torch.float,
                                protocol.comm_base, "output_c")
        client_blob.prepare_recv()
        torch_sync()
        client_output = client_blob.get_recv()
        check_correctness_offline(input_mask, weight,
                                  protocol.output_mask_s, client_output)

        end_communicate()
コード例 #9
0
    def test_server():
        """Run the server side of the DGK-based Maxpool2x2 protocol test.

        Steps:
        1. Receive and rebuild the client's FHE public keys.
        2. Split a random image into ``img_s``/``img_c`` modulo ``q_23`` and
           send the client part.
        3. Time the offline and online phases, resetting the traffic record
           after each.
        4. Receive the client's pooled output and check it.
        """
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        traffic = TrafficRecord()

        # The client generates the keys; here we only receive and rebuild
        # the public keys.
        builder_16 = FheBuilder(q_16, Config.n_16)
        builder_23 = FheBuilder(q_23, Config.n_23)
        comm_16 = CommFheBuilder(rank, builder_16, "fhe_builder_16")
        comm_23 = CommFheBuilder(rank, builder_23, "fhe_builder_23")
        torch_sync()
        comm_16.recv_public_key()
        comm_23.recv_public_key()
        comm_16.wait_and_build_public_key()
        comm_23.wait_and_build_public_key()

        # share_c = plain_img - share_s (mod q_23); share_c goes to the client.
        plain_img = gen_unirand_int_grain(-data_range // 2 + 1,
                                          data_range // 2, num_elem)
        share_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        share_c = pmod(plain_img - share_s, q_23)

        protocol = Maxpool2x2DgkServer(num_elem, q_23, q_16, work_bit,
                                       data_bit, img_hw, builder_16,
                                       builder_23, "max_dgk")

        # NOTE(review): this blob and the result blob below share the name
        # "recon_max_c" (as on the client side); confirm this is intended.
        share_c_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                                 "recon_max_c")
        torch_sync()
        share_c_blob.send(share_c)

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            protocol.offline()
            torch_sync()
        traffic.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            protocol.online(share_s)
            torch_sync()
        traffic.reset("server-online")

        pooled_c_blob = BlobTorch(num_elem // 4, torch.float,
                                  protocol.comm_base, "recon_max_c")
        pooled_c_blob.prepare_recv()
        torch_sync()
        pooled_c = pooled_c_blob.get_recv()
        check_correctness_online(plain_img, protocol.max_s, pooled_c)

        end_communicate()
コード例 #10
0
    def test_server():
        """Run the server side of the SwapToClientOffline protocol test.

        Times the offline phase and the online phase on ``x_s``. It then
        receives the client's "output_c" blob and verifies it with
        ``check_correctness_online``.
        """
        init_communicate(Config.server_rank)
        warming_up_cuda()
        prot = SwapToClientOfflineServer(num_elem, modulus, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        # Capitalized like every other timer label ("Server Online",
        # "Client Online", ...) so timing output is consistent.
        with NamedTimerInstance("Server Online"):
            prot.online(x_s)
            torch_sync()

        blob_output_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                                  "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(x_s, x_c, prot.output_s, output_c)

        end_communicate()
コード例 #11
0
    def test_client():
        """Run the client side of the Conv2d FHE-NTT offline test.

        Only the offline phase is run, on ``input_mask``. Afterwards the
        protocol's ``output_c`` is sent to the server as "output_c".
        """
        init_communicate(Config.client_rank)
        protocol = Conv2dFheNttClient(
            modulus, fhe_builder, data_range, img_hw, filter_hw,
            num_input_channel, num_output_channel,
            "test_conv2d_fhe_ntt_comm", padding=padding)

        with NamedTimerInstance("Client Offline"):
            protocol.offline(input_mask)
            torch_sync()

        result_blob = BlobTorch(protocol.output_shape, torch.float,
                                protocol.comm_base, "output_c")
        torch_sync()
        result_blob.send(protocol.output_c)
        end_communicate()
コード例 #12
0
    def test_client():
        """Run the client side of the secure fully-connected protocol test.

        Times the offline phase on ``input_c`` and the online phase, then
        sends the protocol's ``output_c`` to the server as "output_c".
        """
        init_communicate(Config.client_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        protocol = FcSecureClient(modulus, data_range, num_input_unit,
                                  num_output_unit, fhe_builder, test_name)

        with NamedTimerInstance("Client Offline"):
            protocol.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            protocol.online()
            torch_sync()

        result_blob = BlobTorch(protocol.output_shape, torch.float,
                                protocol.comm_base, "output_c")
        torch_sync()
        result_blob.send(protocol.output_c)
        end_communicate()
コード例 #13
0
    def __init__(self, num_elem, q_23, q_16, work_bit, data_bit, name: str,
                 comm_base: CommBase, comm_fhe_16: CommFheBuilder,
                 comm_fhe_23: CommFheBuilder):
        """Create the encrypted and plaintext blobs of the DGK bit protocol.

        Args:
            num_elem, q_23, q_16, work_bit, data_bit: Forwarded to the parent
                constructor.
            name: Prefix applied (via ``sub_name``) to every blob name.
            comm_base: Channel for the plaintext ``BlobTorch`` transfers.
            comm_fhe_16: FHE channel for the 2-D blobs ``beta_i_c`` and
                ``c_i_c``.
            comm_fhe_23: FHE channel for the other encrypted blobs.
        """
        super(DgkBitCommon, self).__init__(num_elem, q_23, q_16, work_bit,
                                           data_bit)
        self.comm_base = comm_base
        self.comm_fhe_16 = comm_fhe_16
        self.comm_fhe_23 = comm_fhe_23
        self.name = name

        def fhe_2d(shape, tag):
            # 2-D encrypted blob on the 16-bit FHE channel.
            return BlobFheEnc2D(shape, comm_fhe_16, self.sub_name(tag))

        def fhe_flat(tag):
            # num_elem-sized encrypted blob on the 23-bit FHE channel.
            return BlobFheEnc(self.num_elem, comm_fhe_23, self.sub_name(tag))

        def plain(shape, tag, **extra):
            # Float plaintext blob on the base channel.
            return BlobTorch(shape, torch.float, self.comm_base,
                             self.sub_name(tag), **extra)

        # Encrypted blobs.
        self.beta_i_c = fhe_2d(self.decomp_bit_shape, "beta_i_c")
        self.delta_b_c = fhe_flat("delta_b_c")
        self.z_work_c = fhe_flat("z_work_c")
        self.c_i_c = fhe_2d(self.sum_shape, "c_i_c")
        self.dgk_x_leq_y_c = fhe_flat("dgk_x_leq_y_c")
        self.delta_xor_c = fhe_flat("delta_xor_c")
        self.fhe_pre_corr_mod = fhe_flat("fhe_pre_corr_mod")
        self.fhe_corr_mod_c = fhe_flat("fhe_corr_mod_c")

        # Plaintext blobs. NOTE: an alternative sends beta_i_s and c_i_s
        # with dtype torch.int16 (comp_dtype=torch.float); torch.float is
        # used for both here.
        self.z_s = plain(self.num_elem, "z_s")
        self.beta_i_s = plain(self.decomp_bit_shape, "beta_i_s",
                              comp_dtype=torch.float)
        self.c_i_s = plain(self.sum_shape, "c_i_s", comp_dtype=torch.float)
        self.delta_b_s = plain(self.num_elem, "delta_b_s")
        self.z_work_s = plain(self.num_elem, "z_work_s")
        self.pre_corr_mod_s = plain(self.num_elem, "pre_corr_mod_s")

        # Blobs grouped by phase and sending party (per the list names).
        self.offline_server_send = [
            self.c_i_c, self.dgk_x_leq_y_c, self.delta_xor_c,
            self.fhe_corr_mod_c,
        ]
        self.offline_client_send = [
            self.beta_i_c, self.delta_b_c, self.z_work_c,
            self.fhe_pre_corr_mod,
        ]
        self.online_server_send = [self.z_s, self.c_i_s]
        self.online_client_send = [
            self.beta_i_s, self.delta_b_s, self.z_work_s,
            self.pre_corr_mod_s,
        ]
コード例 #14
0
    def __init__(self, num_elem, q_23, q_16, work_bit, data_bit, img_hw,
                 fhe_builder_16: FheBuilder, fhe_builder_23: FheBuilder,
                 name: str, rank):
        """Set up the state shared by both parties of 2x2 average pooling.

        Args:
            num_elem: Number of input elements; must be divisible by 4.
            q_23, q_16, work_bit, data_bit: Forwarded to the parent constructor.
            img_hw: Image height/width; must be divisible by 2.
            fhe_builder_16, fhe_builder_23: FHE contexts forwarded to the
                parent. This protocol uses ``fhe_builder_23`` and ``q_23``
                as its builder and modulus.
            name: Protocol name forwarded to the parent.
            rank: This party's rank forwarded to the parent.

        Raises:
            ValueError: If ``num_elem`` is not divisible by 4 or ``img_hw``
                is not divisible by 2. ``ValueError`` subclasses
                ``Exception``, so existing ``except Exception`` handlers
                still catch it.
        """
        # Validate before calling the parent constructor, which provides the
        # communication state (e.g. self.comm_base), so bad arguments fail fast.
        if num_elem % 4 != 0:
            raise ValueError(
                f"num_elem should be divisible by 4, but got {num_elem}")
        if img_hw % 2 != 0:
            raise ValueError(
                f"img_hw should be divisible by 2, but got {img_hw}")

        super(Avgpool2x2Common,
              self).__init__(num_elem, q_23, q_16, work_bit, data_bit,
                             fhe_builder_16, fhe_builder_23, name, rank,
                             "Avgpool2x2Comm")
        self.img_hw = img_hw
        self.fhe_builder = self.fhe_builder_23
        self.modulus = self.q_23
        # divisor_override=1 makes AvgPool2d sum each 2x2 window instead of
        # averaging it.
        self.pool = torch.nn.AvgPool2d(2, divisor_override=1)

        self.blob_offline_input = BlobTorch(num_elem, torch.float,
                                            self.comm_base, "offline_input")
        # Pooling 2x2 windows leaves a quarter of the elements.
        self.blob_online_input = BlobTorch(num_elem // 4, torch.float,
                                           self.comm_base, "online_input")
コード例 #15
0
    def test_client():
        """Run the client side of the DGK-based ReLU protocol test.

        Steps:
        1. Generate both FHE key pairs and send the public keys.
        2. Split a random ``a`` into a kept client part and a server part
           (sent as "a").
        3. Time the offline and online phases.
        4. Check the server's ``max_s`` against the client's ``max_c``.
        """
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        traffic = TrafficRecord()
        builder_16 = FheBuilder(q_16, Config.n_16)
        builder_23 = FheBuilder(q_23, Config.n_23)
        builder_16.generate_keys()
        builder_23.generate_keys()
        comm_16 = CommFheBuilder(rank, builder_16, "fhe_builder_16")
        comm_23 = CommFheBuilder(rank, builder_23, "fhe_builder_23")
        torch_sync()
        comm_16.send_public_key()
        comm_23.send_public_key()

        # share_s = plain_a - share_c (mod q_23).
        plain_a = gen_unirand_int_grain(-data_range // 2 + 1,
                                        data_range // 2, num_elem)
        share_c = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        share_s = pmod(plain_a - share_c, q_23)

        protocol = ReluDgkClient(num_elem, q_23, q_16, work_bit, data_bit,
                                 builder_16, builder_23, "relu_dgk")

        server_share_blob = BlobTorch(num_elem, torch.float,
                                      protocol.comm_base, "a")
        max_s_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                               "max_s")
        torch_sync()
        server_share_blob.send(share_s)
        # Post the receive for the server's result now; it is collected
        # after the online phase.
        max_s_blob.prepare_recv()

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            protocol.offline()
            torch_sync()
        traffic.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            protocol.online(share_c)
            torch_sync()
        traffic.reset("client-online")

        server_max = max_s_blob.get_recv()
        check_correctness_online(plain_a, server_max, protocol.max_c)

        torch.cuda.empty_cache()
        end_communicate()
コード例 #16
0
    def test_client():
        """Run the client side of the DGK-based Maxpool2x2 protocol test.

        Steps:
        1. Generate both FHE key pairs and send the public keys.
        2. Receive this party's image part from the server.
        3. Time the offline and online phases, resetting the traffic record
           after each.
        4. Send the protocol's ``max_c`` back to the server.
        """
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        traffic = TrafficRecord()

        builder_16 = FheBuilder(q_16, Config.n_16)
        builder_23 = FheBuilder(q_23, Config.n_23)
        builder_16.generate_keys()
        builder_23.generate_keys()
        comm_16 = CommFheBuilder(rank, builder_16, "fhe_builder_16")
        comm_23 = CommFheBuilder(rank, builder_23, "fhe_builder_23")
        torch_sync()
        comm_16.send_public_key()
        comm_23.send_public_key()

        protocol = Maxpool2x2DgkClient(num_elem, q_23, q_16, work_bit,
                                       data_bit, img_hw, builder_16,
                                       builder_23, "max_dgk")
        # NOTE(review): the name "recon_max_c" is reused for the pooled
        # result below (matching the server); confirm this is intended.
        share_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                               "recon_max_c")
        share_blob.prepare_recv()
        torch_sync()
        img_share = share_blob.get_recv()

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            protocol.offline()
            torch_sync()
        traffic.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            protocol.online(img_share)
            torch_sync()
        traffic.reset("client-online")

        pooled_blob = BlobTorch(num_elem // 4, torch.float,
                                protocol.comm_base, "recon_max_c")
        torch_sync()
        pooled_blob.send(protocol.max_c)
        end_communicate()
コード例 #17
0
    def test_server():
        """Run the server side of the Conv2d FHE-NTT offline test.

        Runs the offline phase with ``weight``, then receives the client's
        "output_c" blob. It checks that blob with
        ``check_correctness_offline`` against ``input_mask``, ``weight`` and
        the server's ``output_mask_s``.
        """
        init_communicate(Config.server_rank)
        protocol = Conv2dFheNttServer(
            modulus, fhe_builder, data_range, img_hw, filter_hw,
            num_input_channel, num_output_channel,
            "test_conv2d_fhe_ntt_comm", padding=padding)

        with NamedTimerInstance("Server Offline"):
            protocol.offline(weight)
            torch_sync()

        client_blob = BlobTorch(protocol.output_shape, torch.float,
                                protocol.comm_base, "output_c")
        client_blob.prepare_recv()
        torch_sync()
        client_output = client_blob.get_recv()
        check_correctness_offline(input_mask, weight,
                                  protocol.output_mask_s, client_output)

        end_communicate()
コード例 #18
0
    def __init__(self, num_elem, modulus, name: str, rank):
        """Initialise protocol state and the two masked-value blobs.

        Args:
            num_elem: Number of elements in each transferred tensor.
            modulus: Stored as ``self.modulus``.
            name: Protocol name; also forwarded to the parent constructor.
            rank: This party's rank, used for the ``CommBase``.
        """
        super().__init__(name)
        self.rank = rank
        self.name = name
        self.comm_base = CommBase(rank, self.sub_name("comm_base"))
        self.num_elem = num_elem
        self.modulus = modulus
        self.comp_device = torch.device("cuda")

        def cuda_blob(tag):
            # dst_device presumably controls where received data lands —
            # here the CUDA compute device.
            return BlobTorch(self.num_elem, torch.float, self.comm_base, tag,
                             dst_device=self.comp_device)

        self.blob_masked_input_s = cuda_blob("masked_input_s")
        self.blob_masked_output_s = cuda_blob("masked_output_s")

        # No offline traffic. Online, per the list names, the server sends
        # the masked input and the client sends the masked output.
        self.offline_server_send = []
        self.offline_client_send = []
        self.online_server_send = [self.blob_masked_input_s]
        self.online_client_send = [self.blob_masked_output_s]
コード例 #19
0
    def test_server():
        """Run the server side of the DGK-based ReLU protocol test.

        Receives the client's FHE public keys and this party's input part
        ("a"). It then times the offline and online phases (resetting the
        traffic record after each) and sends ``max_s`` back to the client
        for checking.
        """
        rank = Config.server_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        traffic = TrafficRecord()
        builder_16 = FheBuilder(q_16, Config.n_16)
        builder_23 = FheBuilder(q_23, Config.n_23)
        comm_16 = CommFheBuilder(rank, builder_16, "fhe_builder_16")
        comm_23 = CommFheBuilder(rank, builder_23, "fhe_builder_23")
        torch_sync()
        comm_16.recv_public_key()
        comm_23.recv_public_key()
        comm_16.wait_and_build_public_key()
        comm_23.wait_and_build_public_key()

        protocol = ReluDgkServer(num_elem, q_23, q_16, work_bit, data_bit,
                                 builder_16, builder_23, "relu_dgk")

        share_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "a")
        result_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                                "max_s")
        torch_sync()
        share_blob.prepare_recv()
        own_share = share_blob.get_recv()

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            protocol.offline()
            torch_sync()
        traffic.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            protocol.online(own_share)
            torch_sync()
        traffic.reset("server-online")

        # The client posted its receive for "max_s" before the offline phase.
        result_blob.send(protocol.max_s)
        torch.cuda.empty_cache()
        end_communicate()
コード例 #20
0
    def reconstructed_to_server(self, comm_base: CommBase, modulus):
        """Reconstruct this object's output on the server from both shares.

        The client sends ``get_output_share()``. The server adds it to its
        own share and stores ``nmod(sum, modulus)`` in
        ``self.reconstructed_output``. Nothing is stored on the client.

        Args:
            comm_base: Channel used for the share transfer.
            modulus: Modulus applied to the reconstructed sum.
        """
        share_blob = BlobTorch(self.get_output_shape(), torch.float,
                               comm_base, self.name + "_output_share")

        if self.is_server():
            share_blob.prepare_recv()
            torch_sync()
            client_share = share_blob.get_recv()
            combined = self.get_output_share() + client_share
            self.reconstructed_output = nmod(combined, modulus)

        if self.is_client():
            torch_sync()
            share_blob.send(self.get_output_share())
コード例 #21
0
    def check_correctness(self, verify_func, is_argmax=False, truth=None):
        """Collect the client's input and output on the server and verify them.

        The client sends three things: its input image, its output (the
        argmax output when ``is_argmax`` is true) and the label ``truth``.
        The server calls ``verify_func(self, input_img, actual_output,
        self.q_23)``. It then prints the argmax of the received output
        next to the label.

        Args:
            verify_func: Server-side verification callback (see above).
            is_argmax: Client side only; send ``get_argmax_output()`` instead
                of ``get_output()``.
            truth: Client side only; expected class index, or ``None`` when
                unknown. ``None`` is transmitted as -1 and reported as
                "unknown" by the server.

        Returns:
            ``self``, to allow chaining.
        """
        # Sentinel for an unknown label; argmax indices are never negative.
        unknown_truth = -1

        blob_input_img = BlobTorch(self.get_input_shape(), torch.float, self.comm_base, "input_img")
        blob_actual_output = BlobTorch(self.get_output_shape(), torch.float, self.comm_base, "actual_output")
        blob_truth = BlobTorch(1, torch.float, self.comm_base, "truth")

        if self.is_server():
            blob_input_img.prepare_recv()
            blob_actual_output.prepare_recv()
            blob_truth.prepare_recv()
            torch_sync()
            input_img = blob_input_img.get_recv()
            actual_output = blob_actual_output.get_recv()
            truth = int(blob_truth.get_recv().item())
            verify_func(self, input_img, actual_output, self.q_23)

            actual_output = nmod(actual_output, self.q_23).cuda()
            _, actual_max = torch.max(actual_output, 0)
            if truth == unknown_truth:
                print(f"truth: unknown, actual: {actual_max}")
            else:
                print(f"truth: {truth}, actual: {actual_max}, MatchTruth: {truth == actual_max}")

        if self.is_client():
            torch_sync()
            actual_output = self.secure_nn_core.get_argmax_output() if is_argmax else self.secure_nn_core.get_output()
            blob_input_img.send(self.input_img)
            blob_actual_output.send(actual_output)
            # torch.tensor(None) raises, so the default truth=None used to
            # crash here; send the sentinel instead.
            blob_truth.send(torch.tensor(unknown_truth if truth is None else truth))

        return self
コード例 #22
0
    def test_client():
        """Run the client side of the DGK bit-comparison test.

        Steps:
        1. Generate both FHE key pairs and send the public keys.
        2. Sample ``x``, ``y`` and their client parts.
        3. Send the plaintexts and the server part of ``y - x`` for checking.
        4. Time the offline and online phases.
        5. Send the client's outputs (``dgk_x_leq_y_c``,
           ``correct_mod_div_work_c``, ``z``) to the server.
        """
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic = TrafficRecord()

        builder_16 = FheBuilder(q_16, n_16)
        builder_23 = FheBuilder(q_23, n_23)
        builder_16.generate_keys()
        builder_23.generate_keys()
        comm_16 = CommFheBuilder(rank, builder_16, "fhe_builder_16")
        comm_23 = CommFheBuilder(rank, builder_23, "fhe_builder_23")
        torch_sync()
        comm_16.send_public_key()
        comm_23.send_public_key()

        protocol = DgkBitClient(num_elem, q_23, q_16, work_bit, data_bit,
                                builder_16, builder_23, "DgkBitTest")

        # x, y, x_c and y_c are drawn from the same range, in this order.
        low, high = -data_range // 2 + 1, data_range // 2
        x = gen_unirand_int_grain(low, high, num_elem)
        y = gen_unirand_int_grain(low, high, num_elem)
        x_c = gen_unirand_int_grain(low, high, num_elem)
        y_c = gen_unirand_int_grain(low, high, num_elem)
        x_s = pmod(x - x_c, q_23)
        y_s = pmod(y - y_c, q_23)
        server_diff = pmod(y_s - x_s, q_23)

        x_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "x")
        y_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "y")
        diff_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                              "y_sub_x_s")
        torch_sync()
        x_blob.send(x)
        y_blob.send(y)
        diff_blob.send(server_diff)

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            protocol.offline()
        client_diff = pmod(y_c - x_c, q_23)
        traffic.reset("client-offline")
        torch_sync()

        with NamedTimerInstance("Client Online"):
            protocol.online(client_diff)
        traffic.reset("client-online")

        # Ship the client's outputs so the server can verify them.
        leq_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                             "dgk_x_leq_y_c")
        mod_div_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                                 "correct_mod_div_work_c")
        z_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "z")
        torch_sync()
        leq_blob.send(protocol.dgk_x_leq_y_c)
        mod_div_blob.send(protocol.correct_mod_div_work_c)
        z_blob.send(protocol.z)
        end_communicate()
コード例 #23
0
    def test_server():
        """Run the server side of the DGK bit-comparison test.

        Steps:
        1. Receive the client's FHE public keys.
        2. Receive the plaintexts ``x``/``y`` and this party's part of
           ``y - x``.
        3. Time the offline and online phases.
        4. Receive the client's outputs and check both the comparison and
           the mod-div correction.
        """
        rank = Config.server_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic = TrafficRecord()

        builder_16 = FheBuilder(q_16, Config.n_16)
        builder_23 = FheBuilder(q_23, Config.n_23)
        comm_16 = CommFheBuilder(rank, builder_16, "fhe_builder_16")
        comm_23 = CommFheBuilder(rank, builder_23, "fhe_builder_23")
        torch_sync()
        comm_16.recv_public_key()
        comm_23.recv_public_key()
        comm_16.wait_and_build_public_key()
        comm_23.wait_and_build_public_key()

        protocol = DgkBitServer(num_elem, q_23, q_16, work_bit, data_bit,
                                builder_16, builder_23, "DgkBitTest")

        x_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "x")
        y_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "y")
        diff_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                              "y_sub_x_s")
        x_blob.prepare_recv()
        y_blob.prepare_recv()
        diff_blob.prepare_recv()
        torch_sync()
        x = x_blob.get_recv()
        y = y_blob.get_recv()
        server_diff = diff_blob.get_recv()

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            protocol.offline()
        torch_sync()
        traffic.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            protocol.online(server_diff)
        traffic.reset("server-online")

        leq_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                             "dgk_x_leq_y_c")
        mod_div_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                                 "correct_mod_div_work_c")
        z_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "z")
        leq_blob.prepare_recv()
        mod_div_blob.prepare_recv()
        z_blob.prepare_recv()
        torch_sync()
        leq_c = leq_blob.get_recv()
        mod_div_c = mod_div_blob.get_recv()
        z = z_blob.get_recv()
        check_correctness(x, y, protocol.dgk_x_leq_y_s, leq_c)
        check_correctness_mod_div(protocol.r, z,
                                  protocol.correct_mod_div_work_s, mod_div_c)
        end_communicate()
コード例 #24
0
    def test_client():
        """Run the client side of the shares-multiplication test.

        Times the offline phase and the online phase on ``a_c``/``b_c``.
        It then sends ``u_c``, ``v_c`` and ``z_c``, followed (after another
        barrier) by ``c_c``, for the server to check.
        """
        init_communicate(Config.client_rank)
        protocol = SharesMultClient(num_elem, modulus, fhe_builder,
                                    "test_shares_mult")
        with NamedTimerInstance("Client Offline"):
            protocol.offline()
            torch_sync()
        with NamedTimerInstance("Client Online"):
            protocol.online(a_c, b_c)
            torch_sync()

        u_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                           "recon_u_c")
        v_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                           "recon_v_c")
        z_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                           "recon_z_c")
        torch_sync()
        u_blob.send(protocol.u_c)
        v_blob.send(protocol.v_c)
        z_blob.send(protocol.z_c)

        product_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                                 "c_c")
        torch_sync()
        product_blob.send(protocol.c_c)
        end_communicate()
コード例 #25
0
    def test_server():
        """Run the server side of the shares-multiplication test.

        After the offline phase and the online phase on ``a_s``/``b_s``, the
        server does two checks:
        1. Receive ``u_c``, ``v_c`` and ``z_c``, rebuild ``u`` and ``v``
           modulo ``modulus``, and check them with ``z_s``.
        2. Receive ``c_c`` and check it with ``c_s`` against ``a`` and ``b``.
        """
        init_communicate(Config.server_rank)
        protocol = SharesMultServer(num_elem, modulus, fhe_builder,
                                    "test_shares_mult")
        with NamedTimerInstance("Server Offline"):
            protocol.offline()
            torch_sync()
        with NamedTimerInstance("Server Online"):
            protocol.online(a_s, b_s)
            torch_sync()

        u_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                           "recon_u_c")
        v_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                           "recon_v_c")
        z_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                           "recon_z_c")
        u_blob.prepare_recv()
        v_blob.prepare_recv()
        z_blob.prepare_recv()
        torch_sync()
        client_u = u_blob.get_recv()
        client_v = v_blob.get_recv()
        client_z = z_blob.get_recv()
        full_u = pmod(protocol.u_s + client_u, modulus)
        full_v = pmod(protocol.v_s + client_v, modulus)
        check_correctness_online(full_u, full_v, protocol.z_s, client_z)

        product_blob = BlobTorch(num_elem, torch.float, protocol.comm_base,
                                 "c_c")
        product_blob.prepare_recv()
        torch_sync()
        client_product = product_blob.get_recv()
        check_correctness_online(a, b, protocol.c_s, client_product)
        end_communicate()