def test_client():
    """Client side of the end-to-end secure-NN test.

    Generates both FHE key pairs, publishes the public keys, feeds a random
    input image into ``secure_nn``, times the offline and online phases and
    finally checks the network output against the input.
    """
    client_rank = Config.client_rank
    init_communicate(client_rank)
    context.set_rank(client_rank)

    # The client owns the secret keys; only public keys leave this process.
    for builder in (fhe_builder_16, fhe_builder_23):
        builder.generate_keys()
    key_channels = [
        CommFheBuilder(client_rank, fhe_builder_16, "fhe_builder_16"),
        CommFheBuilder(client_rank, fhe_builder_23, "fhe_builder_23"),
    ]
    for channel in key_channels:
        channel.send_public_key()

    img = generate_random_data(input_shape)
    trunc1.set_div_to_pow(pow_to_div)
    secure_nn.feed_input(img)

    with NamedTimerInstance("Client Offline"):
        secure_nn.offline()
        torch_sync()
    with NamedTimerInstance("Client Online"):
        secure_nn.online()
        torch_sync()

    check_correctness(img, secure_nn.get_output())
    end_communicate()
def test_server():
    """Server side of the full secure-NN benchmark.

    Redirects stdout to a ``Logger``, joins the session, receives the FHE
    public keys, loads truncation and weight parameters, times both phases
    (recording traffic after each), then verifies the output and every
    hooked layer.
    """
    server_rank = Config.server_rank
    sys.stdout = Logger()
    traffic = TrafficRecord()

    net = get_secure_nn()
    net.set_rank(server_rank).init_communication(master_address=master_addr, master_port=master_port)
    warming_up_cuda()
    net.fhe_builder_sync()

    load_trunc_params(net, store_configs)
    weights = torch.load(net_state_name)
    load_weight_params(net, store_configs, weights)

    trunc_rng = MetaTruncRandomGenerator()
    trunc_rng.reset_seed()

    with NamedTimerInstance("Server Offline"):
        net.offline()
        torch_sync()
    traffic.reset("server-offline")

    with NamedTimerInstance("Server Online"):
        net.online()
        torch_sync()
    traffic.reset("server-online")

    net.check_correctness(check_correctness)
    net.check_layers(get_plain_net, get_hooking_lst(model_name_base))
    net.end_communication()
def check_correctness(self, verify_func, is_argmax=False, truth=None):
    """Send the client's input and output to the server and verify them there.

    Both parties must call this method. The client sends its input image, its
    network output (``get_argmax_output()`` when ``is_argmax`` is set,
    otherwise ``get_output()``) and the ground-truth label. The server receives
    them, calls ``verify_func(self, input_img, actual_output, self.q_23)`` and
    prints the received label next to the index of the largest output entry.

    Args:
        verify_func: server-side verification callback (signature above).
        is_argmax: client only; selects which output getter is used.
        truth: client only; the ground-truth label, or ``None`` when no label
            is available (e.g. random input). ``None`` is transmitted as the
            placeholder ``-1`` and is reported as ``None`` by the server.

    Returns:
        ``self``, so calls can be chained.
    """
    # Placeholder sent when the client has no label. Assumes real labels are
    # non-negative class indices.
    missing_truth = -1

    blob_input_img = BlobTorch(self.get_input_shape(), torch.float, self.comm_base, "input_img")
    blob_actual_output = BlobTorch(self.get_output_shape(), torch.float, self.comm_base, "actual_output")
    blob_truth = BlobTorch(1, torch.float, self.comm_base, "truth")

    if self.is_server():
        blob_input_img.prepare_recv()
        blob_actual_output.prepare_recv()
        blob_truth.prepare_recv()
        torch_sync()
        input_img = blob_input_img.get_recv()
        actual_output = blob_actual_output.get_recv()
        received_truth = int(blob_truth.get_recv().item())
        truth = None if received_truth == missing_truth else received_truth
        verify_func(self, input_img, actual_output, self.q_23)

        actual_output = nmod(actual_output, self.q_23).cuda()
        _, actual_max = torch.max(actual_output, 0)
        print(f"truth: {truth}, actual: {actual_max}, MatchTruth: {truth == actual_max}")

    if self.is_client():
        torch_sync()
        if is_argmax:
            actual_output = self.secure_nn_core.get_argmax_output()
        else:
            actual_output = self.secure_nn_core.get_output()
        blob_input_img.send(self.input_img)
        blob_actual_output.send(actual_output)
        # torch.tensor(None) raises, so substitute the placeholder when no
        # label was supplied (the default).
        blob_truth.send(torch.tensor(missing_truth if truth is None else truth))

    return self
def test_server():
    """Server side of the small secure-NN test.

    Receives the client's FHE public keys, loads the plaintext weights into
    the layers, and times the offline and online phases.
    """
    server_rank = Config.server_rank
    init_communicate(server_rank)
    context.set_rank(server_rank)

    channel_16 = CommFheBuilder(server_rank, fhe_builder_16, "fhe_builder_16")
    channel_23 = CommFheBuilder(server_rank, fhe_builder_23, "fhe_builder_23")
    # Both receives are issued before waiting on either key.
    channel_16.recv_public_key()
    channel_23.recv_public_key()
    channel_16.wait_and_build_public_key()
    channel_23.wait_and_build_public_key()

    for layer, layer_weight in ((conv1, conv1_w), (conv2, conv2_w), (fc1, fc1_w)):
        layer.load_weight(layer_weight)
    trunc1.set_div_to_pow(pow_to_div)

    with NamedTimerInstance("Server Offline"):
        secure_nn.offline()
        torch_sync()
    with NamedTimerInstance("Server Online"):
        secure_nn.online()
        torch_sync()

    end_communicate()
def response(self):
    """Client half of a ciphertext refresh.

    Receives the masked ciphertext(s) from the other party, decrypts them and
    passes each plaintext to ``encrypt_additive`` on ``self.refreshed``. It
    then sends ``self.refreshed`` back and frees both the received and the
    sent ciphertexts with ``delete_fhe``.
    """
    self.prepare_recv()
    torch_sync()
    masked = self.common.masked.get_recv()

    with NamedTimerInstance("client refresh reencrypt"):
        if len(self.shape) != 2:
            self.refreshed.encrypt_additive(self.fhe_builder.decrypt_to_torch(masked))
        else:
            # Two-dimensional shapes are handled row by row.
            for row in range(self.shape[0]):
                plain_row = self.fhe_builder.decrypt_to_torch(masked[row])
                self.refreshed[row].encrypt_additive(plain_row)

    self.common.refreshed.send(self.refreshed)
    torch_sync()
    delete_fhe(masked)
    delete_fhe(self.refreshed)
def test_server():
    """Server side of the reconstruct-to-client test; only times the phases."""
    init_communicate(Config.server_rank)
    warming_up_cuda()
    protocol = ReconToClientServer(num_elem, modulus, test_name)

    with NamedTimerInstance("Server Offline"):
        protocol.offline()
        torch_sync()
    with NamedTimerInstance("Server online"):
        protocol.online(x_s)
        torch_sync()

    end_communicate()
def test_client():
    """Client side of the reconstruct-to-client test.

    Times both phases and then checks the value the client reconstructed
    against the two input shares.
    """
    init_communicate(Config.client_rank)
    warming_up_cuda()
    protocol = ReconToClientClient(num_elem, modulus, test_name)

    with NamedTimerInstance("Client Offline"):
        protocol.offline()
        torch_sync()
    with NamedTimerInstance("Client Online"):
        protocol.online(x_c)
        torch_sync()

    check_correctness_online(protocol.output, x_s, x_c)
    end_communicate()
def test_client():
    """Client side of the fully-connected-over-FHE offline test.

    Runs the offline phase with the input mask, then ships the client's
    output share to the server for verification.
    """
    init_communicate(Config.client_rank)
    protocol = FcFheClient(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)

    with NamedTimerInstance("Client Offline"):
        protocol.offline(input_mask)
        torch_sync()

    output_blob = BlobTorch(protocol.output_shape, torch.float, protocol.comm_base, "output_c")
    torch_sync()
    output_blob.send(protocol.output_c)
    end_communicate()
def test_client():
    """Client side of the share-multiplication test.

    Times both phases, then sends the client's shares of the multiplication
    triple (u, v, z) and, in a second round, its share of the product c.
    """
    init_communicate(Config.client_rank)
    protocol = SharesMultClient(num_elem, modulus, fhe_builder, "test_shares_mult")

    with NamedTimerInstance("Client Offline"):
        protocol.offline()
        torch_sync()
    with NamedTimerInstance("Client Online"):
        protocol.online(a_c, b_c)
        torch_sync()

    triple_blobs = [
        BlobTorch(num_elem, torch.float, protocol.comm_base, tag)
        for tag in ("recon_u_c", "recon_v_c", "recon_z_c")
    ]
    torch_sync()
    for blob, share in zip(triple_blobs, (protocol.u_c, protocol.v_c, protocol.z_c)):
        blob.send(share)

    product_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "c_c")
    torch_sync()
    product_blob.send(protocol.c_c)
    end_communicate()
def test_server():
    """Server side of the share-multiplication test.

    Times both phases, receives the client's triple shares (u, v, z) and
    checks the reconstructed triple, then receives the client's product share
    and checks it against the plaintext operands ``a`` and ``b``.
    """
    init_communicate(Config.server_rank)
    protocol = SharesMultServer(num_elem, modulus, fhe_builder, "test_shares_mult")

    with NamedTimerInstance("Server Offline"):
        protocol.offline()
        torch_sync()
    with NamedTimerInstance("Server Online"):
        protocol.online(a_s, b_s)
        torch_sync()

    triple_blobs = [
        BlobTorch(num_elem, torch.float, protocol.comm_base, tag)
        for tag in ("recon_u_c", "recon_v_c", "recon_z_c")
    ]
    for blob in triple_blobs:
        blob.prepare_recv()
    torch_sync()
    u_c, v_c, z_c = [blob.get_recv() for blob in triple_blobs]

    u = pmod(protocol.u_s + u_c, modulus)
    v = pmod(protocol.v_s + v_c, modulus)
    check_correctness_online(u, v, protocol.z_s, z_c)

    product_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "c_c")
    product_blob.prepare_recv()
    torch_sync()
    c_c = product_blob.get_recv()
    check_correctness_online(a, b, protocol.c_s, c_c)
    end_communicate()
def fhe_builder_sync(self):
    """Share the client's FHE public keys with the server.

    The client generates both key pairs and sends the public keys. The server
    receives them and builds its public-key builders. ``torch_sync()`` is
    called before and after the exchange on both sides.

    Returns:
        ``self``, for chaining.
    """
    channels = [
        CommFheBuilder(self.rank, self.fhe_builder_16, self.sub_name("fhe_builder_16")),
        CommFheBuilder(self.rank, self.fhe_builder_23, self.sub_name("fhe_builder_23")),
    ]
    torch_sync()
    if self.is_server():
        for channel in channels:
            channel.recv_public_key()
        for channel in channels:
            channel.wait_and_build_public_key()
    elif self.is_client():
        self.fhe_builder_16.generate_keys()
        self.fhe_builder_23.generate_keys()
        for channel in channels:
            channel.send_public_key()
    torch_sync()
    return self
def online(self, a_s, b_s):
    """Server's online step of the share multiplication.

    Opens ``e = a - u`` and ``f = b - v`` by exchanging masked shares with the
    client, then stores the server's share of the product in ``self.c_s``::

        c_s = a_s * f + e * b_s + z_s - e * f   (mod self.modulus)
    """
    a_s = a_s.to(Config.device)
    b_s = b_s.to(Config.device)
    masked_a = self.mod_to_modulus(a_s - self.u_s)
    masked_b = self.mod_to_modulus(b_s - self.v_s)

    # NOTE(review): two consecutive barriers come before the sends here. The
    # client-side ``online`` places one barrier before and one after its
    # sends. Confirm this pairing is intended.
    torch_sync()
    torch_sync()
    self.blob_e_s.send(masked_a)
    self.blob_f_s.send(masked_b)

    masked_a = masked_a.to(Config.device).double()
    masked_b = masked_b.to(Config.device).double()
    e = self.mod_to_modulus(masked_a + self.blob_e_c.get_recv()).to(Config.device).double()
    f = self.mod_to_modulus(masked_b + self.blob_f_c.get_recv()).to(Config.device).double()

    # Only the server subtracts the public e * f term.
    self.c_s = pmod(a_s * f + e * b_s + self.z_s - e * f, self.modulus)
def test_server():
    """Server side of the fully-connected-over-FHE offline test.

    Runs the offline phase with the plaintext weight, receives the client's
    output share and checks both shares against the input mask and weight.
    """
    init_communicate(Config.server_rank)
    protocol = FcFheServer(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)

    with NamedTimerInstance("Server Offline"):
        protocol.offline(weight)
        torch_sync()

    output_blob = BlobTorch(protocol.output_shape, torch.float, protocol.comm_base, "output_c")
    output_blob.prepare_recv()
    torch_sync()
    client_output = output_blob.get_recv()
    check_correctness_offline(input_mask, weight, protocol.output_mask_s, client_output)
    end_communicate()
def reconstructed_to_server(self, comm_base: CommBase, modulus):
    """Reconstruct this layer's output on the server from both parties' shares.

    The client sends its output share. The server adds it to its own share,
    reduces the sum with ``nmod(..., modulus)`` and stores it in
    ``self.reconstructed_output``. The client stores nothing.
    """
    share_blob = BlobTorch(self.get_output_shape(), torch.float, comm_base, self.name + "_output_share")
    if self.is_server():
        share_blob.prepare_recv()
        torch_sync()
        client_share = share_blob.get_recv()
        self.reconstructed_output = nmod(self.get_output_share() + client_share, modulus)
    if self.is_client():
        torch_sync()
        share_blob.send(self.get_output_share())
def online(self, a_c, b_c):
    """Client's online step of the share multiplication.

    Opens ``e = a - u`` and ``f = b - v`` by exchanging masked shares with the
    server, then stores the client's share of the product in ``self.c_c``::

        c_c = a_c * f + e * b_c + z_c   (mod self.modulus)
    """
    a_c = a_c.to(Config.device)
    b_c = b_c.to(Config.device)
    masked_a = self.mod_to_modulus(a_c - self.u_c)
    masked_b = self.mod_to_modulus(b_c - self.v_c)

    torch_sync()
    self.blob_e_c.send(masked_a)
    self.blob_f_c.send(masked_b)
    torch_sync()

    masked_a = masked_a.to(Config.device).double()
    masked_b = masked_b.to(Config.device).double()
    e = self.mod_to_modulus(self.blob_e_s.get_recv() + masked_a).double().to(Config.device)
    f = self.mod_to_modulus(self.blob_f_s.get_recv() + masked_b).double().to(Config.device)

    # The public e * f term is left out of the client's share.
    self.c_c = pmod(a_c * f + e * b_c + self.z_c, self.modulus)
def run_secure_nn_client_with_random_data(secure_nn, check_correctness, master_address, master_port):
    """Run the client role of ``secure_nn`` on a random input.

    Joins the session, syncs FHE keys, fills a random input, times the
    offline and online phases (resetting the traffic record after each), then
    verifies via ``secure_nn.check_correctness`` and ends communication.
    """
    client_rank = Config.client_rank
    traffic = TrafficRecord()
    secure_nn.set_rank(client_rank).init_communication(master_address=master_address, master_port=master_port)
    warming_up_cuda()
    secure_nn.fhe_builder_sync().fill_random_input()

    phases = (
        ("offline", "Client Offline", "client-offline"),
        ("online", "Client Online", "client-online"),
    )
    for method_name, timer_label, traffic_tag in phases:
        with NamedTimerInstance(timer_label):
            getattr(secure_nn, method_name)()
            torch_sync()
        traffic.reset(traffic_tag)

    secure_nn.check_correctness(check_correctness).end_communication()
def test_client():
    """Client side of the DGK-based ReLU test.

    Generates and publishes FHE keys, draws a random plaintext and splits it
    into additive shares mod ``q_23``, sends the server its share, times both
    phases, then checks the ReLU output shares.
    """
    client_rank = Config.client_rank
    init_communicate(client_rank, master_address=master_address, master_port=master_port)
    traffic = TrafficRecord()

    builder_16 = FheBuilder(q_16, Config.n_16)
    builder_23 = FheBuilder(q_23, Config.n_23)
    builder_16.generate_keys()
    builder_23.generate_keys()
    key_channel_16 = CommFheBuilder(client_rank, builder_16, "fhe_builder_16")
    key_channel_23 = CommFheBuilder(client_rank, builder_23, "fhe_builder_23")
    torch_sync()
    key_channel_16.send_public_key()
    key_channel_23.send_public_key()

    plain = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    share_c = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    share_s = pmod(plain - share_c, q_23)

    protocol = ReluDgkClient(num_elem, q_23, q_16, work_bit, data_bit, builder_16, builder_23, "relu_dgk")
    share_s_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "a")
    max_s_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "max_s")
    torch_sync()
    share_s_blob.send(share_s)
    max_s_blob.prepare_recv()
    torch_sync()

    with NamedTimerInstance("Client Offline"):
        protocol.offline()
        torch_sync()
    traffic.reset("client-offline")
    with NamedTimerInstance("Client Online"):
        protocol.online(share_c)
        torch_sync()
    traffic.reset("client-online")

    server_max = max_s_blob.get_recv()
    check_correctness_online(plain, server_max, protocol.max_c)
    torch.cuda.empty_cache()
    end_communicate()
def offline(self):
    """Client side of the offline phase of the DGK-based protocol.

    Samples the client's random values, sends them through ``self.common``,
    runs the sub-protocol offline steps, answers one ciphertext-refresh
    request, and decrypts the values the server sends back.

    NOTE(review): the sequence of sends/receives/barriers below presumably
    has to mirror the server-side ``offline`` exactly — do not reorder without
    checking it.
    """
    self.offline_recv()
    # Client-side random values: beta_i_c over [0, q_16 - 1] with shape
    # ``decomp_bit_shape``; delta_b_c and z_work_c over [0, q_23 - 1], one per element.
    self.beta_i_c = gen_unirand_int_grain(
        0, self.q_16 - 1, self.decomp_bit_shape).to(Config.device)
    self.delta_b_c = gen_unirand_int_grain(0, self.q_23 - 1, self.num_elem).to(Config.device)
    self.z_work_c = gen_unirand_int_grain(0, self.q_23 - 1, self.num_elem).to(Config.device)
    # Constant tensors kept on the instance for later use.
    self.beta_i_zeros = torch.zeros_like(self.beta_i_c)
    self.fast_ones = torch.ones(self.num_elem).to(Config.device)
    self.fast_zeros = torch.zeros(self.num_elem).to(Config.device)
    self.fast_ones_c_i = torch.ones(self.sum_shape).float().to(
        Config.device)
    self.fast_zeros_c_i = torch.zeros(self.sum_shape).float().to(
        Config.device)

    # send_from_torch presumably encrypts before sending — confirm in the comm class.
    self.common.beta_i_c.send_from_torch(self.beta_i_c)
    self.common.delta_b_c.send_from_torch(self.delta_b_c)
    self.common.z_work_c.send_from_torch(self.z_work_c)

    if self.is_shuffle:
        self.sum_c_refresher = EncRefresherClient(
            self.sum_shape, self.fhe_builder_16, self.sub_name("shuffle_refresher"))

    self.mod_div_offline()

    # Answer the other party's refresh request for the ab-xor ciphertexts.
    refresher_ab_xor_c = EncRefresherClient(
        self.decomp_bit_shape, self.fhe_builder_16, self.common.sub_name("refresher_ab_xor_c"))
    refresher_ab_xor_c.response()

    self.sum_c_i_offline()

    # Receive and decrypt the client's shares.
    self.c_i_c = self.common.c_i_c.get_recv_decrypt()
    self.delta_xor_c = self.common.delta_xor_c.get_recv_decrypt()
    self.dgk_x_leq_y_c = self.common.dgk_x_leq_y_c.get_recv_decrypt()
    # Fold in correct_mod_div_work_c (presumably set by mod_div_offline —
    # verify) into the comparison share, mod q_23.
    self.dgk_x_leq_y_c = pmod(
        self.dgk_x_leq_y_c + self.correct_mod_div_work_c, self.q_23)

    self.online_recv()
    torch_sync()
def test_client():
    """Client side of the Conv2d-over-FHE (NTT) offline test.

    Runs the offline phase with the input mask and ships the client's output
    share to the server.
    """
    init_communicate(Config.client_rank)
    protocol = Conv2dFheNttClient(
        modulus, fhe_builder, data_range, img_hw, filter_hw,
        num_input_channel, num_output_channel,
        "test_conv2d_fhe_ntt_comm", padding=padding,
    )

    with NamedTimerInstance("Client Offline"):
        protocol.offline(input_mask)
        torch_sync()

    output_blob = BlobTorch(protocol.output_shape, torch.float, protocol.comm_base, "output_c")
    torch_sync()
    output_blob.send(protocol.output_c)
    end_communicate()
def test_server():
    """Server side of the secure Conv2d test.

    Times the offline (weight and bias) and online phases, receives the
    client's output share and checks the combined result.
    """
    server_rank = Config.server_rank
    init_communicate(server_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    protocol = Conv2dSecureServer(
        modulus, fhe_builder, data_range, img_hw, filter_hw,
        num_input_channel, num_output_channel,
        "test_conv2d_secure_comm", padding=padding,
    )

    with NamedTimerInstance("Server Offline"):
        protocol.offline(weight, bias=bias)
        torch_sync()
    with NamedTimerInstance("Server Online"):
        protocol.online(input_s)
        torch_sync()

    output_blob = BlobTorch(protocol.output_shape, torch.float, protocol.comm_base, "output_c")
    output_blob.prepare_recv()
    torch_sync()
    client_output = output_blob.get_recv()
    check_correctness_online(input, weight, bias, protocol.output_s, client_output)
    end_communicate()
def test_client():
    """Client side of the 2x2 average-pooling test.

    Times both phases (resetting the traffic record after each) and sends the
    client's pooled output share to the server.
    """
    init_communicate(Config.client_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    traffic = TrafficRecord()
    protocol = Avgpool2x2Client(
        num_elem, q_23, q_16, work_bit, data_bit, img_hw,
        fhe_builder_16, fhe_builder_23, "avgpool",
    )

    with NamedTimerInstance("Client Offline"):
        protocol.offline()
        torch_sync()
    traffic.reset("client-offline")
    with NamedTimerInstance("Client Online"):
        protocol.online(img_c)
        torch_sync()
    traffic.reset("client-online")

    # 2x2 pooling leaves a quarter of the input elements.
    pooled_blob = BlobTorch(num_elem // 4, torch.float, protocol.comm_base, "recon_res_c")
    torch_sync()
    pooled_blob.send(protocol.out_c)
    end_communicate()
def request(self, enc):
    """Server half of a ciphertext refresh.

    Masks ``enc`` with fresh random plaintext ``self.r_s`` (one plaintext per
    row when ``self.shape`` is two-dimensional), sends it for re-encryption,
    removes the mask from the returned ciphertext and returns it. ``enc`` and
    the mask plaintexts are freed with ``delete_fhe``.
    """
    self.prepare_recv()
    torch_sync()
    self.r_s = gen_unirand_int_grain(0, self.modulus - 1, self.shape)
    per_row = len(self.shape) == 2

    if per_row:
        masks = []
        for row in range(self.shape[0]):
            masks.append(self.fhe_builder.build_plain_from_torch(self.r_s[row]))
            enc[row] += masks[row]
    else:
        masks = self.fhe_builder.build_plain_from_torch(self.r_s)
        enc += masks

    self.common.masked.send(enc)
    refreshed = self.common.refreshed.get_recv()

    if per_row:
        for row in range(self.shape[0]):
            refreshed[row] -= masks[row]
    else:
        refreshed -= masks

    delete_fhe(enc)
    delete_fhe(masks)
    torch_sync()
    return refreshed
def test_client():
    """Client side of the secure Conv2d test.

    Times the offline (input share) and online phases, then sends the
    client's output share to the server.
    """
    client_rank = Config.client_rank
    init_communicate(client_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    protocol = Conv2dSecureClient(
        modulus, fhe_builder, data_range, img_hw, filter_hw,
        num_input_channel, num_output_channel,
        "test_conv2d_secure_comm", padding=padding,
    )

    with NamedTimerInstance("Client Offline"):
        protocol.offline(input_c)
        torch_sync()
    with NamedTimerInstance("Client Online"):
        protocol.online()
        torch_sync()

    output_blob = BlobTorch(protocol.output_shape, torch.float, protocol.comm_base, "output_c")
    torch_sync()
    output_blob.send(protocol.output_c)
    end_communicate()
def test_server():
    """Server side of the 2x2 average-pooling test.

    Times both phases (resetting the traffic record after each), receives the
    client's pooled share and checks the result against the plaintext image.
    """
    init_communicate(Config.server_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    traffic = TrafficRecord()
    protocol = Avgpool2x2Server(
        num_elem, q_23, q_16, work_bit, data_bit, img_hw,
        fhe_builder_16, fhe_builder_23, "avgpool",
    )

    with NamedTimerInstance("Server Offline"):
        protocol.offline()
        torch_sync()
    traffic.reset("server-offline")
    with NamedTimerInstance("Server Online"):
        protocol.online(img_s)
        torch_sync()
    traffic.reset("server-online")

    pooled_blob = BlobTorch(num_elem // 4, torch.float, protocol.comm_base, "recon_res_c")
    pooled_blob.prepare_recv()
    torch_sync()
    client_pooled = pooled_blob.get_recv()
    check_correctness_online(img, protocol.out_s, client_pooled)
    end_communicate()
def test_server():
    """Server side of the DGK-based ReLU test.

    Receives the client's FHE public keys and the server's input share, times
    both phases, then sends the server's ReLU output share back to the client.
    """
    server_rank = Config.server_rank
    init_communicate(server_rank, master_address=master_address, master_port=master_port)
    traffic = TrafficRecord()

    builder_16 = FheBuilder(q_16, Config.n_16)
    builder_23 = FheBuilder(q_23, Config.n_23)
    key_channel_16 = CommFheBuilder(server_rank, builder_16, "fhe_builder_16")
    key_channel_23 = CommFheBuilder(server_rank, builder_23, "fhe_builder_23")
    torch_sync()
    key_channel_16.recv_public_key()
    key_channel_23.recv_public_key()
    key_channel_16.wait_and_build_public_key()
    key_channel_23.wait_and_build_public_key()

    protocol = ReluDgkServer(num_elem, q_23, q_16, work_bit, data_bit, builder_16, builder_23, "relu_dgk")
    share_s_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "a")
    max_s_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "max_s")
    torch_sync()
    share_s_blob.prepare_recv()
    share_s = share_s_blob.get_recv()
    torch_sync()

    with NamedTimerInstance("Server Offline"):
        protocol.offline()
        torch_sync()
    traffic.reset("server-offline")
    with NamedTimerInstance("Server Online"):
        protocol.online(share_s)
        torch_sync()
    traffic.reset("server-online")

    max_s_blob.send(protocol.max_s)
    torch.cuda.empty_cache()
    end_communicate()
def test_server():
    """Server side of the Conv2d-over-FHE (NTT) offline test.

    Runs the offline phase with the plaintext weight, receives the client's
    output share and checks both shares against the input mask and weight.
    """
    init_communicate(Config.server_rank)
    protocol = Conv2dFheNttServer(
        modulus, fhe_builder, data_range, img_hw, filter_hw,
        num_input_channel, num_output_channel,
        "test_conv2d_fhe_ntt_comm", padding=padding,
    )

    with NamedTimerInstance("Server Offline"):
        protocol.offline(weight)
        torch_sync()

    output_blob = BlobTorch(protocol.output_shape, torch.float, protocol.comm_base, "output_c")
    output_blob.prepare_recv()
    torch_sync()
    client_output = output_blob.get_recv()
    check_correctness_offline(input_mask, weight, protocol.output_mask_s, client_output)
    end_communicate()
def test_client():
    """Client side of the swap-to-client-offline test.

    Times the offline (``y_c``) and online (``x_c``) phases, then sends the
    client's output share to the server.
    """
    init_communicate(Config.client_rank)
    warming_up_cuda()
    protocol = SwapToClientOfflineClient(num_elem, modulus, test_name)

    with NamedTimerInstance("Client Offline"):
        protocol.offline(y_c)
        torch_sync()
    with NamedTimerInstance("Client Online"):
        protocol.online(x_c)
        torch_sync()

    output_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "output_c")
    torch_sync()
    output_blob.send(protocol.output_c)
    end_communicate()
def test_server():
    """Server side of the swap-to-client-offline test.

    Times both phases, receives the client's output share and checks it
    together with the server's share against the inputs.
    """
    init_communicate(Config.server_rank)
    warming_up_cuda()
    protocol = SwapToClientOfflineServer(num_elem, modulus, test_name)

    with NamedTimerInstance("Server Offline"):
        protocol.offline()
        torch_sync()
    with NamedTimerInstance("Server online"):
        protocol.online(x_s)
        torch_sync()

    output_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "output_c")
    output_blob.prepare_recv()
    torch_sync()
    client_output = output_blob.get_recv()
    check_correctness_online(x_s, x_c, protocol.output_s, client_output)
    end_communicate()
def test_client():
    """Client side of the secure fully-connected test.

    Times the offline (input share) and online phases, then sends the
    client's output share to the server.
    """
    client_rank = Config.client_rank
    init_communicate(client_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    protocol = FcSecureClient(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)

    with NamedTimerInstance("Client Offline"):
        protocol.offline(input_c)
        torch_sync()
    with NamedTimerInstance("Client Online"):
        protocol.online()
        torch_sync()

    output_blob = BlobTorch(protocol.output_shape, torch.float, protocol.comm_base, "output_c")
    torch_sync()
    output_blob.send(protocol.output_c)
    end_communicate()
def test_client():
    """Client side of the DGK-based 2x2 max-pooling test.

    Generates and publishes FHE keys, receives the client's image share,
    times both phases, then sends the client's pooled-max share to the server.
    """
    client_rank = Config.client_rank
    init_communicate(client_rank, master_address=master_address, master_port=master_port)
    traffic = TrafficRecord()

    builder_16 = FheBuilder(q_16, Config.n_16)
    builder_23 = FheBuilder(q_23, Config.n_23)
    builder_16.generate_keys()
    builder_23.generate_keys()
    key_channel_16 = CommFheBuilder(client_rank, builder_16, "fhe_builder_16")
    key_channel_23 = CommFheBuilder(client_rank, builder_23, "fhe_builder_23")
    torch_sync()
    key_channel_16.send_public_key()
    key_channel_23.send_public_key()

    protocol = Maxpool2x2DgkClient(
        num_elem, q_23, q_16, work_bit, data_bit, img_hw,
        builder_16, builder_23, "max_dgk",
    )

    # NOTE(review): the incoming image-share blob reuses the tag "recon_max_c",
    # which the outgoing pooled-max blob below also uses. The tag has to match
    # the server's, so it is left unchanged here. Confirm the reuse is harmless.
    img_c_blob = BlobTorch(num_elem, torch.float, protocol.comm_base, "recon_max_c")
    img_c_blob.prepare_recv()
    torch_sync()
    img_c = img_c_blob.get_recv()
    torch_sync()

    with NamedTimerInstance("Client Offline"):
        protocol.offline()
        torch_sync()
    traffic.reset("client-offline")
    with NamedTimerInstance("Client Online"):
        protocol.online(img_c)
        torch_sync()
    traffic.reset("client-online")

    # 2x2 pooling leaves a quarter of the input elements.
    max_c_blob = BlobTorch(num_elem // 4, torch.float, protocol.comm_base, "recon_max_c")
    torch_sync()
    max_c_blob.send(protocol.max_c)
    end_communicate()