Exemplo n.º 1
0
    def test_client():
        """Client side of the end-to-end secure-NN test.

        Generates keys for both FHE builders, sends their public keys to the
        peer, feeds a random input image into ``secure_nn``, times the offline
        and online phases, and finally calls
        ``check_correctness(input_img, output)`` on the result.
        """
        rank = Config.client_rank
        init_communicate(rank)
        context.set_rank(rank)

        # Keys are generated on this (client) side; only the public keys are
        # sent to the peer.
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        input_img = generate_random_data(input_shape)
        # NOTE(review): presumably must match the truncation divisor set on the
        # server side; confirm.
        trunc1.set_div_to_pow(pow_to_div)
        secure_nn.feed_input(input_img)

        # torch_sync() inside each timer presumably acts as a cross-party
        # barrier so the timing covers the whole phase; confirm.
        with NamedTimerInstance("Client Offline"):
            secure_nn.offline()
            torch_sync()

        with NamedTimerInstance("Client Online"):
            secure_nn.online()
            torch_sync()

        check_correctness(input_img, secure_nn.get_output())
        end_communicate()
Exemplo n.º 2
0
    def test_server():
        """Server side of the end-to-end secure-NN test with stored weights.

        Sets up communication, syncs the FHE public keys, loads truncation and
        weight parameters (weights from ``net_state_name``), times the offline
        and online phases while recording traffic per phase, then runs the
        output correctness check and the per-layer check before closing
        communication.
        """
        rank = Config.server_rank
        # NOTE(review): stdout is redirected for the rest of the process and is
        # never restored here.
        sys.stdout = Logger()
        traffic_record = TrafficRecord()
        secure_nn = get_secure_nn()
        secure_nn.set_rank(rank).init_communication(master_address=master_addr,
                                                    master_port=master_port)
        warming_up_cuda()
        secure_nn.fhe_builder_sync()
        load_trunc_params(secure_nn, store_configs)

        # torch.load unpickles the file; only load trusted checkpoints.
        net_state = torch.load(net_state_name)
        load_weight_params(secure_nn, store_configs, net_state)

        # Reset the seed of the truncation random generator before running the
        # protocol.
        meta_rg = MetaTruncRandomGenerator()
        meta_rg.reset_seed()

        with NamedTimerInstance("Server Offline"):
            secure_nn.offline()
            torch_sync()
        # Presumably reports and clears the traffic counters under this label.
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            secure_nn.online()
            torch_sync()
        traffic_record.reset("server-online")

        secure_nn.check_correctness(check_correctness)
        secure_nn.check_layers(get_plain_net, get_hooking_lst(model_name_base))
        secure_nn.end_communication()
Exemplo n.º 3
0
    def check_correctness(self, verify_func, is_argmax=False, truth=None):
        """Send the client's input and output to the server and verify them there.

        Both parties must call this method. The client sends its input image,
        its output (optionally the arg-max output) and the ground-truth label.
        The server receives them, calls ``verify_func`` and prints whether the
        arg-max index of the output matches the label.

        Args:
            verify_func: Called on the server as
                ``verify_func(self, input_img, actual_output, self.q_23)``.
            is_argmax: Client only. If True, send
                ``secure_nn_core.get_argmax_output()`` instead of
                ``secure_nn_core.get_output()``.
            truth: Client only. Ground-truth label, or ``None`` when unknown.
                ``None`` is transmitted as a sentinel and shown as ``None`` on
                the server.

        Returns:
            self, to allow call chaining.
        """
        # torch.tensor(None) raises, so an unknown label is sent as an index
        # that torch.max can never return.
        no_truth = -1

        blob_input_img = BlobTorch(self.get_input_shape(), torch.float, self.comm_base, "input_img")
        blob_actual_output = BlobTorch(self.get_output_shape(), torch.float, self.comm_base, "actual_output")
        blob_truth = BlobTorch(1, torch.float, self.comm_base, "truth")

        if self.is_server():
            blob_input_img.prepare_recv()
            blob_actual_output.prepare_recv()
            blob_truth.prepare_recv()
            torch_sync()
            input_img = blob_input_img.get_recv()
            actual_output = blob_actual_output.get_recv()
            received_truth = int(blob_truth.get_recv().item())
            truth = None if received_truth == no_truth else received_truth
            verify_func(self, input_img, actual_output, self.q_23)

            actual_output = nmod(actual_output, self.q_23).cuda()
            _, actual_max = torch.max(actual_output, 0)
            # Plain Python int so the log shows the index, not a tensor repr.
            actual_max = int(actual_max.item())
            print(f"truth: {truth}, actual: {actual_max}, MatchTruth: {truth == actual_max}")

        if self.is_client():
            torch_sync()
            actual_output = self.secure_nn_core.get_argmax_output() if is_argmax else self.secure_nn_core.get_output()
            blob_input_img.send(self.input_img)
            blob_actual_output.send(actual_output)
            blob_truth.send(torch.tensor(no_truth if truth is None else truth))

        return self
Exemplo n.º 4
0
    def test_server():
        """Server side of the end-to-end secure-NN test with in-memory weights.

        Receives the client's FHE public keys, loads the layer weights and the
        truncation divisor, then times the offline and online phases.
        """
        rank = Config.server_rank
        init_communicate(rank)
        context.set_rank(rank)

        # The server never generates keys: it only receives the public keys
        # and builds its public-key objects from them.
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        conv1.load_weight(conv1_w)
        conv2.load_weight(conv2_w)
        fc1.load_weight(fc1_w)
        trunc1.set_div_to_pow(pow_to_div)

        # torch_sync() presumably acts as a cross-party barrier so the timer
        # covers the whole phase; confirm.
        with NamedTimerInstance("Server Offline"):
            secure_nn.offline()
            torch_sync()

        with NamedTimerInstance("Server Online"):
            secure_nn.online()
            torch_sync()

        end_communicate()
Exemplo n.º 5
0
    def response(self):
        """Client half of the ciphertext refresh.

        Receives the masked ciphertext(s) from the peer, decrypts them and
        encrypts the plaintext into ``self.refreshed``, then sends
        ``self.refreshed`` back. A 2-D ``self.shape`` is processed row by row;
        any other shape as a single object. The received ciphertext and
        ``self.refreshed`` are released with ``delete_fhe`` afterwards.
        """
        self.prepare_recv()
        torch_sync()

        masked = self.common.masked.get_recv()
        builder = self.fhe_builder

        with NamedTimerInstance("client refresh reencrypt"):
            if len(self.shape) != 2:
                self.refreshed.encrypt_additive(builder.decrypt_to_torch(masked))
            else:
                for row in range(self.shape[0]):
                    plain_row = builder.decrypt_to_torch(masked[row])
                    self.refreshed[row].encrypt_additive(plain_row)

        self.common.refreshed.send(self.refreshed)

        torch_sync()
        delete_fhe(masked)
        delete_fhe(self.refreshed)
Exemplo n.º 6
0
    def test_server():
        """Server side of the ReconToClient test.

        Runs the protocol's offline phase, then its online phase with the
        server's share ``x_s``, timing each phase.
        """
        init_communicate(Config.server_rank)
        warming_up_cuda()
        prot = ReconToClientServer(num_elem, modulus, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        # NOTE(review): label casing ("online") differs from "Server Offline"
        # above; harmless unless timer logs are parsed.
        with NamedTimerInstance("Server online"):
            prot.online(x_s)
            torch_sync()

        end_communicate()
Exemplo n.º 7
0
    def test_client():
        """Client side of the ReconToClient test.

        Runs the offline and online phases with the client's share ``x_c``,
        then checks ``prot.output`` against both shares locally.
        """
        init_communicate(Config.client_rank)
        warming_up_cuda()
        prot = ReconToClientClient(num_elem, modulus, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online(x_c)
            torch_sync()

        # The reconstructed value is checked on this side, using both shares
        # (x_s is presumably known here because this is a test harness).
        check_correctness_online(prot.output, x_s, x_c)
        end_communicate()
Exemplo n.º 8
0
    def test_client():
        """Client side of the FC-over-FHE offline test.

        Runs the offline phase with ``input_mask`` and then sends the client's
        output share ``prot.output_c`` to the peer for checking.
        """
        init_communicate(Config.client_rank)
        prot = FcFheClient(modulus, data_range, num_input_unit,
                           num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()

        # The blob name "output_c" presumably has to match the receiver's blob.
        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        # Barrier before sending; the receiver presumably posts its recv
        # before the same barrier.
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()
Exemplo n.º 9
0
    def test_client():
        """Client side of the shares-multiplication test.

        Runs the offline phase, then the online phase with the client's shares
        ``a_c`` and ``b_c``. Afterwards it sends the client's ``u_c``, ``v_c``
        and ``z_c`` shares, and then the product share ``c_c``, to the peer
        for checking.
        """
        init_communicate(Config.client_rank)
        protocol = SharesMultClient(num_elem, modulus, fhe_builder,
                                    "test_shares_mult")
        with NamedTimerInstance("Client Offline"):
            protocol.offline()
            torch_sync()
        with NamedTimerInstance("Client Online"):
            protocol.online(a_c, b_c)
            torch_sync()

        # Blob names presumably must match the ones the peer receives on.
        triple_blobs = [
            BlobTorch(num_elem, torch.float, protocol.comm_base, tag)
            for tag in ("recon_u_c", "recon_v_c", "recon_z_c")
        ]
        torch_sync()
        for blob, share in zip(triple_blobs,
                               (protocol.u_c, protocol.v_c, protocol.z_c)):
            blob.send(share)

        blob_c_c = BlobTorch(num_elem, torch.float, protocol.comm_base, "c_c")
        torch_sync()
        blob_c_c.send(protocol.c_c)
        end_communicate()
Exemplo n.º 10
0
    def test_server():
        """Server side of the shares-multiplication test.

        Runs the offline phase, then the online phase with the server's shares
        ``a_s`` and ``b_s``. It then receives the client's ``u_c``, ``v_c`` and
        ``z_c`` shares and checks the reconstructed u, v together with both z
        shares. Finally it receives ``c_c`` and checks the product shares
        against ``a`` and ``b``.
        """
        init_communicate(Config.server_rank)
        protocol = SharesMultServer(num_elem, modulus, fhe_builder,
                                    "test_shares_mult")
        with NamedTimerInstance("Server Offline"):
            protocol.offline()
            torch_sync()
        with NamedTimerInstance("Server Online"):
            protocol.online(a_s, b_s)
            torch_sync()

        triple_blobs = [
            BlobTorch(num_elem, torch.float, protocol.comm_base, tag)
            for tag in ("recon_u_c", "recon_v_c", "recon_z_c")
        ]
        # Post every receive before the barrier the sender waits on.
        for blob in triple_blobs:
            blob.prepare_recv()
        torch_sync()
        u_c, v_c, z_c = [blob.get_recv() for blob in triple_blobs]
        u = pmod(protocol.u_s + u_c, modulus)
        v = pmod(protocol.v_s + v_c, modulus)
        check_correctness_online(u, v, protocol.z_s, z_c)

        blob_c_c = BlobTorch(num_elem, torch.float, protocol.comm_base, "c_c")
        blob_c_c.prepare_recv()
        torch_sync()
        c_c = blob_c_c.get_recv()
        check_correctness_online(a, b, protocol.c_s, c_c)
        end_communicate()
Exemplo n.º 11
0
 def fhe_builder_sync(self):
     """Share the FHE public keys of both builders between the parties.

     On the client, keys are generated for ``fhe_builder_16`` and
     ``fhe_builder_23`` and their public keys are sent. On the server, the
     public keys are received and built into the local builders. The exchange
     is bracketed by ``torch_sync()`` on both sides.

     Returns:
         self, for call chaining.
     """
     comm_builders = [
         CommFheBuilder(self.rank, builder, self.sub_name(tag))
         for builder, tag in ((self.fhe_builder_16, "fhe_builder_16"),
                              (self.fhe_builder_23, "fhe_builder_23"))
     ]
     torch_sync()
     if self.is_server():
         # Post both receives first, then wait for and build each key.
         for comm in comm_builders:
             comm.recv_public_key()
         for comm in comm_builders:
             comm.wait_and_build_public_key()
     elif self.is_client():
         for builder in (self.fhe_builder_16, self.fhe_builder_23):
             builder.generate_keys()
         for comm in comm_builders:
             comm.send_public_key()
     torch_sync()
     return self
Exemplo n.º 12
0
    def online(self, a_s, b_s):
        """Server side of the online multiplication of secret-shared values.

        Masks the server's input shares with the precomputed ``u_s`` / ``v_s``
        (e_s = a_s - u_s, f_s = b_s - v_s, reduced mod ``modulus``), exchanges
        them with the peer to open e and f, and stores the server's product
        share in ``self.c_s``:

            c_s = a_s * f + e * b_s + z_s - e * f   (mod modulus)

        The ``- e * f`` correction is applied on this side only; the peer's
        formula is expected to omit it (confirm against the client's online()).

        Args:
            a_s: server share of the first operand.
            b_s: server share of the second operand.
        """
        a_s = a_s.to(Config.device)
        b_s = b_s.to(Config.device)
        e_s = self.mod_to_modulus(a_s - self.u_s)
        f_s = self.mod_to_modulus(b_s - self.v_s)
        # One barrier, send both masked shares, one barrier. The peer uses the
        # same sync/send/sync pattern, so the barrier count still matches (2).
        torch_sync()
        self.blob_e_s.send(e_s)
        self.blob_f_s.send(f_s)
        torch_sync()
        e_s = e_s.to(Config.device).double()
        f_s = f_s.to(Config.device).double()

        # No prepare_recv() here: the receives are expected to have been posted
        # earlier (e.g. during offline); confirm.
        e_c = self.blob_e_c.get_recv()
        e = self.mod_to_modulus(e_s + e_c).to(Config.device).double()
        f_c = self.blob_f_c.get_recv()
        f = self.mod_to_modulus(f_s + f_c).to(Config.device).double()
        self.c_s = pmod(a_s * f + e * b_s + self.z_s - e * f, self.modulus)
Exemplo n.º 13
0
    def test_server():
        """Server side of the FC-over-FHE offline test.

        Runs the offline phase with ``weight``, receives the client's output
        share and checks it together with ``prot.output_mask_s`` against
        ``input_mask`` and ``weight``.
        """
        init_communicate(Config.server_rank)
        prot = FcFheServer(modulus, data_range, num_input_unit,
                           num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        # Post the receive before the barrier; the sender sends after it.
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s,
                                  output_c)

        end_communicate()
Exemplo n.º 14
0
    def reconstructed_to_server(self, comm_base: CommBase, modulus):
        """Reconstruct this layer's output on the server from both shares.

        The client sends its output share. The server adds it to its own share
        and stores ``nmod(own_share + peer_share, modulus)`` in
        ``self.reconstructed_output``. On the client this method does not set
        ``self.reconstructed_output``.
        """
        blob_share = BlobTorch(self.get_output_shape(), torch.float,
                               comm_base, self.name + "_output_share")

        if self.is_server():
            blob_share.prepare_recv()
            torch_sync()
            peer_share = blob_share.get_recv()
            own_share = self.get_output_share()
            self.reconstructed_output = nmod(own_share + peer_share, modulus)

        if self.is_client():
            torch_sync()
            blob_share.send(self.get_output_share())
Exemplo n.º 15
0
    def online(self, a_c, b_c):
        """Client side of the online multiplication of secret-shared values.

        Masks the client's input shares with the precomputed ``u_c`` / ``v_c``
        (e_c = a_c - u_c, f_c = b_c - v_c, reduced mod ``modulus``), sends them,
        receives the peer's e_s / f_s to open e and f, and stores the client's
        product share in ``self.c_c``:

            c_c = a_c * f + e * b_c + z_c   (mod modulus)

        There is no ``- e * f`` term here; it is presumably applied by the other
        party only (confirm against the server's online()).

        Args:
            a_c: client share of the first operand.
            b_c: client share of the second operand.
        """
        a_c = a_c.to(Config.device)
        b_c = b_c.to(Config.device)
        e_c = self.mod_to_modulus(a_c - self.u_c)
        f_c = self.mod_to_modulus(b_c - self.v_c)
        torch_sync()
        self.blob_e_c.send(e_c)
        self.blob_f_c.send(f_c)
        torch_sync()
        e_c = e_c.to(Config.device).double()
        f_c = f_c.to(Config.device).double()

        # No prepare_recv() here: the receives are expected to have been posted
        # earlier (e.g. during offline); confirm.
        e_s = self.blob_e_s.get_recv()
        e = self.mod_to_modulus(e_s + e_c).double().to(Config.device)
        f_s = self.blob_f_s.get_recv()
        f = self.mod_to_modulus(f_s + f_c).double().to(Config.device)
        self.c_c = pmod(a_c * f + e * b_c + self.z_c, self.modulus)
Exemplo n.º 16
0
def run_secure_nn_client_with_random_data(secure_nn, check_correctness, master_address, master_port):
    """Run the client side of a secure-NN inference on a random input.

    Sets the client rank, connects to ``master_address:master_port``, syncs
    FHE keys, fills a random input, times the offline and online phases
    (recording traffic per phase), then runs the correctness check and ends
    communication.

    Args:
        secure_nn: the secure network object; its methods are chained, so
            ``set_rank``, ``fhe_builder_sync`` and ``check_correctness`` are
            expected to return the object itself.
        check_correctness: verification callable forwarded to
            ``secure_nn.check_correctness``.
        master_address: address passed to ``init_communication``.
        master_port: port passed to ``init_communication``.
    """
    rank = Config.client_rank
    traffic_record = TrafficRecord()
    secure_nn.set_rank(rank).init_communication(master_address=master_address, master_port=master_port)
    warming_up_cuda()
    secure_nn.fhe_builder_sync().fill_random_input()

    with NamedTimerInstance("Client Offline"):
        secure_nn.offline()
        torch_sync()
    # Presumably reports and clears the traffic counters under this label.
    traffic_record.reset("client-offline")

    with NamedTimerInstance("Client Online"):
        secure_nn.online()
        torch_sync()
    traffic_record.reset("client-online")

    secure_nn.check_correctness(check_correctness).end_communication()
Exemplo n.º 17
0
    def test_client():
        """Client side of the DGK-based ReLU test.

        Generates FHE keys and sends the public keys to the peer. It then
        splits a random input ``a`` into shares ``a_c`` / ``a_s`` (mod q_23) and
        sends ``a_s`` to the peer. Next it runs the offline and online phases
        with ``a_c``, receives the peer's ``max_s`` and checks
        ``(max_s, prot.max_c)`` against ``a``.
        """
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address, master_port=master_port)
        traffic_record = TrafficRecord()
        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        # Test input in the signed data range, additively split mod q_23:
        # a_c uniform, a_s = a - a_c (mod q_23).
        a = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
        a_c = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        a_s = pmod(a - a_c, q_23)

        prot = ReluDgkClient(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "relu_dgk")

        blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
        blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
        torch_sync()
        blob_a_s.send(a_s)
        # Posted now; the peer's max_s is only read after the online phase.
        blob_max_s.prepare_recv()

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(a_c)
            torch_sync()
        traffic_record.reset("client-online")

        max_s = blob_max_s.get_recv()
        check_correctness_online(a, max_s, prot.max_c)

        torch.cuda.empty_cache()
        end_communicate()
Exemplo n.º 18
0
    def offline(self):
        """Client-side offline phase of what appears to be a DGK-style
        comparison (see ``dgk_x_leq_y_c``).

        Samples the client's random values, sends them to the peer with
        ``send_from_torch``, runs the mod-division and sum sub-protocols,
        answers a ciphertext-refresh request, and receives and decrypts the
        client's shares ``c_i_c``, ``delta_xor_c`` and ``dgk_x_leq_y_c``.
        Receives for the online phase are set up at the end.
        """
        # Presumably posts the receives this phase needs; confirm.
        self.offline_recv()

        # Client random values: beta_i_c mod q_16 over the bit-decomposition
        # shape; delta_b_c and z_work_c mod q_23 per element.
        self.beta_i_c = gen_unirand_int_grain(
            0, self.q_16 - 1, self.decomp_bit_shape).to(Config.device)
        self.delta_b_c = gen_unirand_int_grain(0, self.q_23 - 1,
                                               self.num_elem).to(Config.device)
        self.z_work_c = gen_unirand_int_grain(0, self.q_23 - 1,
                                              self.num_elem).to(Config.device)
        # Constant tensors cached on the device (used elsewhere in this class,
        # presumably during online).
        self.beta_i_zeros = torch.zeros_like(self.beta_i_c)
        self.fast_ones = torch.ones(self.num_elem).to(Config.device)
        self.fast_zeros = torch.zeros(self.num_elem).to(Config.device)
        self.fast_ones_c_i = torch.ones(self.sum_shape).float().to(
            Config.device)
        self.fast_zeros_c_i = torch.zeros(self.sum_shape).float().to(
            Config.device)

        # NOTE(review): send_from_torch presumably encrypts before sending;
        # confirm in the blob implementation.
        self.common.beta_i_c.send_from_torch(self.beta_i_c)
        self.common.delta_b_c.send_from_torch(self.delta_b_c)
        self.common.z_work_c.send_from_torch(self.z_work_c)

        if self.is_shuffle:
            self.sum_c_refresher = EncRefresherClient(
                self.sum_shape, self.fhe_builder_16,
                self.sub_name("shuffle_refresher"))

        self.mod_div_offline()

        # Answer the peer's refresh request for the decomposed-bit XOR values.
        refresher_ab_xor_c = EncRefresherClient(
            self.decomp_bit_shape, self.fhe_builder_16,
            self.common.sub_name("refresher_ab_xor_c"))
        refresher_ab_xor_c.response()
        self.sum_c_i_offline()

        # Receive the encrypted shares prepared by the peer and decrypt them.
        self.c_i_c = self.common.c_i_c.get_recv_decrypt()
        self.delta_xor_c = self.common.delta_xor_c.get_recv_decrypt()
        self.dgk_x_leq_y_c = self.common.dgk_x_leq_y_c.get_recv_decrypt()
        # Fold in the correction term from the mod-division step (mod q_23).
        self.dgk_x_leq_y_c = pmod(
            self.dgk_x_leq_y_c + self.correct_mod_div_work_c, self.q_23)

        self.online_recv()
        torch_sync()
Exemplo n.º 19
0
    def test_client():
        """Client side of the NTT-based Conv2d-over-FHE offline test.

        Runs the offline phase with ``input_mask`` and then sends the client's
        output share ``prot.output_c`` to the peer for checking.
        """
        init_communicate(Config.client_rank)
        prot = Conv2dFheNttClient(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_fhe_ntt_comm",
                                  padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        # Barrier before sending; the receiver presumably posts its recv
        # before the same barrier.
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()
Exemplo n.º 20
0
    def test_server():
        """Server side of the secure Conv2d test.

        Runs the offline phase with ``weight``/``bias`` and the online phase
        with the server's input share ``input_s``. It then receives the
        client's output share and checks both shares against the plain
        convolution inputs.
        """
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureServer(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm",
                                  padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        # NOTE(review): `input` must be defined in the enclosing scope;
        # otherwise this silently passes the builtin input() function.
        check_correctness_online(input, weight, bias, prot.output_s, output_c)

        end_communicate()
Exemplo n.º 21
0
    def test_client():
        """Client side of the 2x2 average-pooling test.

        Runs the offline phase and the online phase with the client's image
        share ``img_c`` (recording traffic per phase), then sends the client's
        pooled output share to the peer for checking.
        """
        init_communicate(Config.client_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Client(num_elem, q_23, q_16, work_bit, data_bit,
                                img_hw, fhe_builder_16, fhe_builder_23,
                                "avgpool")

        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(img_c)
            torch_sync()
        traffic_record.reset("client-online")

        # 2x2 pooling leaves a quarter of the elements.
        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_res_c")
        torch_sync()
        blob_out_c.send(prot.out_c)
        end_communicate()
Exemplo n.º 22
0
    def request(self, enc):
        """Server half of the ciphertext refresh.

        Adds a fresh random plaintext mask ``self.r_s`` (mod ``self.modulus``)
        to ``enc`` and sends the masked ciphertext(s) to the peer. It then
        receives the refreshed ciphertext(s) back and subtracts the same mask
        again. A 2-D ``self.shape`` is processed row by row. ``enc`` and the
        mask plaintext(s) are released with ``delete_fhe`` before returning.

        Returns:
            The refreshed ciphertext(s) with the mask removed.
        """
        self.prepare_recv()
        torch_sync()
        self.r_s = gen_unirand_int_grain(0, self.modulus - 1, self.shape)
        builder = self.fhe_builder
        row_wise = len(self.shape) == 2

        # Mask the outgoing ciphertext(s) with the random plaintext(s).
        if row_wise:
            pt_mask = []
            for row in range(self.shape[0]):
                pt_mask.append(builder.build_plain_from_torch(self.r_s[row]))
                enc[row] += pt_mask[row]
        else:
            pt_mask = builder.build_plain_from_torch(self.r_s)
            enc += pt_mask

        self.common.masked.send(enc)
        refreshed = self.common.refreshed.get_recv()

        # Strip the same mask from the refreshed ciphertext(s).
        if row_wise:
            for row in range(self.shape[0]):
                refreshed[row] -= pt_mask[row]
        else:
            refreshed -= pt_mask

        delete_fhe(enc)
        delete_fhe(pt_mask)
        torch_sync()
        return refreshed
Exemplo n.º 23
0
    def test_client():
        """Client side of the secure Conv2d test.

        Runs the offline phase with the client's input share ``input_c``, then
        the online phase, and sends the client's output share to the peer for
        checking.
        """
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureClient(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm",
                                  padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        # Barrier before sending; the receiver presumably posts its recv
        # before the same barrier.
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()
Exemplo n.º 24
0
    def test_server():
        """Server side of the 2x2 average-pooling test.

        Runs the offline phase and the online phase with the server's image
        share ``img_s`` (recording traffic per phase). It then receives the
        client's pooled output share and checks both shares against ``img``.
        """
        init_communicate(Config.server_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Server(num_elem, q_23, q_16, work_bit, data_bit,
                                img_hw, fhe_builder_16, fhe_builder_23,
                                "avgpool")

        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(img_s)
            torch_sync()
        traffic_record.reset("server-online")

        # 2x2 pooling leaves a quarter of the elements.
        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_res_c")
        blob_out_c.prepare_recv()
        torch_sync()
        out_c = blob_out_c.get_recv()
        check_correctness_online(img, prot.out_s, out_c)

        end_communicate()
Exemplo n.º 25
0
    def test_server():
        """Server side of the DGK-based ReLU test.

        Receives the client's FHE public keys and the server input share
        ``a_s``, runs the offline and online phases (recording traffic per
        phase), then sends ``prot.max_s`` back to the peer for checking.
        """
        rank = Config.server_rank
        init_communicate(Config.server_rank, master_address=master_address, master_port=master_port)
        traffic_record = TrafficRecord()
        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        prot = ReluDgkServer(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "relu_dgk")

        blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
        blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
        torch_sync()
        # NOTE(review): the receive is posted after the barrier, while the
        # sender sends right after the same barrier. Other tests here post
        # prepare_recv() before the barrier; this presumably still works if
        # sends are buffered/asynchronous — confirm.
        blob_a_s.prepare_recv()
        a_s = blob_a_s.get_recv()

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(a_s)
            torch_sync()
        traffic_record.reset("server-online")

        blob_max_s.send(prot.max_s)
        torch.cuda.empty_cache()
        end_communicate()
Exemplo n.º 26
0
    def test_server():
        """Server side of the NTT-based Conv2d-over-FHE offline test.

        Runs the offline phase with ``weight``, receives the client's output
        share and checks it together with ``prot.output_mask_s`` against
        ``input_mask`` and ``weight``.
        """
        init_communicate(Config.server_rank)
        prot = Conv2dFheNttServer(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_fhe_ntt_comm",
                                  padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        # Post the receive before the barrier; the sender sends after it.
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s,
                                  output_c)

        end_communicate()
Exemplo n.º 27
0
    def test_client():
        """Client side of the SwapToClientOffline test.

        Runs the offline phase with ``y_c`` and the online phase with ``x_c``,
        then sends the client's output share to the peer for checking.
        """
        init_communicate(Config.client_rank)
        warming_up_cuda()
        prot = SwapToClientOfflineClient(num_elem, modulus, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(y_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online(x_c)
            torch_sync()

        blob_output_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                                  "output_c")
        # Barrier before sending; the receiver presumably posts its recv
        # before the same barrier.
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()
Exemplo n.º 28
0
    def test_server():
        """Server side of the SwapToClientOffline test.

        Runs the offline phase and the online phase with ``x_s``, receives the
        client's output share and checks both output shares against
        ``x_s``/``x_c``.
        """
        init_communicate(Config.server_rank)
        warming_up_cuda()
        prot = SwapToClientOfflineServer(num_elem, modulus, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        # NOTE(review): label casing ("online") differs from "Server Offline"
        # above; harmless unless timer logs are parsed.
        with NamedTimerInstance("Server online"):
            prot.online(x_s)
            torch_sync()

        blob_output_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                                  "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(x_s, x_c, prot.output_s, output_c)

        end_communicate()
Exemplo n.º 29
0
    def test_client():
        """Client side of the secure fully-connected layer test.

        Runs the offline phase with the client's input share ``input_c``, then
        the online phase, and sends the client's output share to the peer for
        checking.
        """
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = FcSecureClient(modulus, data_range, num_input_unit,
                              num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        # Barrier before sending; the receiver presumably posts its recv
        # before the same barrier.
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()
Exemplo n.º 30
0
    def test_client():
        """Client side of the DGK-based 2x2 max-pooling test.

        Generates FHE keys and sends the public keys to the peer, then receives
        the client's image share from the peer. It runs the offline and online
        phases (recording traffic per phase) and sends the pooled output share
        ``prot.max_c`` back for checking.
        """
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        prot = Maxpool2x2DgkClient(num_elem, q_23, q_16, work_bit, data_bit,
                                   img_hw, fhe_builder_16, fhe_builder_23,
                                   "max_dgk")
        # NOTE(review): this blob reuses the name "recon_max_c", which is also
        # used below for blob_max_c (different size and direction). It cannot
        # be renamed here alone because the peer must use the same name;
        # confirm the two transfers cannot be confused.
        blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                               "recon_max_c")
        blob_img_c.prepare_recv()
        torch_sync()
        img_c = blob_img_c.get_recv()

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(img_c)
            torch_sync()
        traffic_record.reset("client-online")

        # 2x2 pooling leaves a quarter of the elements.
        blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_max_c")
        torch_sync()
        blob_max_c.send(prot.max_c)
        end_communicate()