def __init__(self, num_elem, q_23, q_16, work_bit, data_bit,
             fhe_builder_16: FheBuilder, fhe_builder_23: FheBuilder,
             name: str, is_shuffle=None):
    """Initialise the client side of the DGK bit protocol.

    Forwards the numeric parameters to the parent initializer, keeps both
    FHE builders, creates the ``CommBase`` / ``CommFheBuilder`` objects for
    ``Config.client_rank`` and builds the ``DgkBitCommon`` helper from them.

    ``is_shuffle`` falls back to ``Config.is_shuffle`` when it is ``None``.
    """
    super(DgkBitClient, self).__init__(num_elem, q_23, q_16, work_bit, data_bit, name=name)

    self.fhe_builder_16 = fhe_builder_16
    self.fhe_builder_23 = fhe_builder_23

    client_rank = Config.client_rank
    self.comm_base = CommBase(client_rank, name)
    self.comm_fhe_16 = CommFheBuilder(
        client_rank, self.fhe_builder_16, f"{name}_comm_fhe_16")
    self.comm_fhe_23 = CommFheBuilder(
        client_rank, self.fhe_builder_23, f"{name}_comm_fhe_23")

    self.common = DgkBitCommon(
        num_elem, q_23, q_16, work_bit, data_bit, name,
        self.comm_base, self.comm_fhe_16, self.comm_fhe_23,
    )

    # An explicit argument wins; None means "use the global setting".
    if is_shuffle is None:
        is_shuffle = Config.is_shuffle
    self.is_shuffle = is_shuffle
def test_client():
    """Run the client half of the secure-network test.

    Generates keys on both FHE builders, sends the public keys, feeds a
    random input to ``secure_nn``, runs its offline and online phases and
    finally checks the output against the input.

    Uses names defined outside this function (``fhe_builder_16``,
    ``fhe_builder_23``, ``context``, ``secure_nn``, ``trunc1``,
    ``input_shape``, ``pow_to_div``).
    """
    client_rank = Config.client_rank
    init_communicate(client_rank)
    context.set_rank(client_rank)

    for builder in (fhe_builder_16, fhe_builder_23):
        builder.generate_keys()

    key_comms = (
        CommFheBuilder(client_rank, fhe_builder_16, "fhe_builder_16"),
        CommFheBuilder(client_rank, fhe_builder_23, "fhe_builder_23"),
    )
    for key_comm in key_comms:
        key_comm.send_public_key()

    input_img = generate_random_data(input_shape)
    trunc1.set_div_to_pow(pow_to_div)
    secure_nn.feed_input(input_img)

    # NOTE(review): block boundaries were reconstructed; torch_sync() is
    # assumed to sit inside each timed block - confirm.
    with NamedTimerInstance("Client Offline"):
        secure_nn.offline()
        torch_sync()

    with NamedTimerInstance("Client Online"):
        secure_nn.online()
        torch_sync()

    check_correctness(input_img, secure_nn.get_output())
    end_communicate()
def __init__(self, shape, fhe_builder: FheBuilder, name: str):
    """Initialise the client-rank object around ``EncRefresherCommon``.

    Records ``shape``, the FHE builder and its modulus, then builds an
    ``EncRefresherCommon`` on a ``CommFheBuilder`` created for
    ``Config.client_rank``. The object starts with ``is_prepared`` False.
    """
    self.is_prepared = False
    self.shape = shape
    self.fhe_builder = fhe_builder
    self.modulus = self.fhe_builder.modulus
    self.common = EncRefresherCommon(
        shape,
        CommFheBuilder(Config.client_rank, fhe_builder, name),
        name,
    )
def __init__(self, modulus, data_range, num_input_unit, num_output_unit,
             fhe_builder: FheBuilder, name, rank):
    """Initialise the fully connected FHE layer for one party.

    Creates the communication objects for ``rank``, an
    ``FcFheSingleThread`` compute core, and the raw-ciphertext blobs for
    the input and output batches. The ``*_send`` lists record which blobs
    each party sends in each phase.
    """
    super().__init__(modulus, data_range, num_input_unit, num_output_unit, name)

    self.fhe_builder = fhe_builder
    self.rank = rank

    tag = self.sub_name
    self.comm_base = CommBase(rank, tag("comm_base"))
    self.comm_fhe = CommFheBuilder(rank, fhe_builder, tag("comm_fhe"))

    self.compute_core = FcFheSingleThread(
        modulus, data_range, num_input_unit, num_output_unit, fhe_builder, name)

    # num_input_batch / num_output_batch are not set here; presumably the
    # parent initializer provides them - confirm.
    self.blob_input_cts = BlobFheRawCts(
        self.num_input_batch, self.comm_fhe, tag("input_cts"))
    self.blob_output_cts = BlobFheRawCts(
        self.num_output_batch, self.comm_fhe, tag("output_cts"))

    # Offline: client sends input ciphertexts, server sends output ones.
    self.offline_server_send = [self.blob_output_cts]
    self.offline_client_send = [self.blob_input_cts]
    # Nothing is exchanged through these lists in the online phase.
    self.online_server_send = []
    self.online_client_send = []
def __init__(self, num_elem, modulus, fhe_builder: FheBuilder, name: str, rank):
    """Initialise the protocol state and its communication blobs.

    Stores the parameters, creates ``CommBase`` / ``CommFheBuilder`` for
    ``rank``, asserts that ``modulus`` equals the builder's modulus, and
    allocates the encrypted (``u``, ``v``, ``z_c``) and plain torch
    (``e_s``, ``f_s``, ``e_c``, ``f_c``) blobs. The ``*_send`` lists record
    which blobs each party sends in each phase.
    """
    self.num_elem = num_elem
    self.modulus = modulus
    self.name = name
    self.fhe_builder = fhe_builder
    self.rank = rank

    self.comm_base = CommBase(rank, name)
    self.comm_fhe = CommFheBuilder(rank, fhe_builder, self.sub_name("comm_fhe"))

    assert self.modulus == self.fhe_builder.modulus

    def plain_blob(tag):
        # Float tensor blob of num_elem entries on the base channel.
        return BlobTorch(num_elem, torch.float, self.comm_base, tag)

    self.blob_fhe_u = BlobFheEnc(num_elem, self.comm_fhe, self.sub_name("fhe_u"))
    self.blob_fhe_v = BlobFheEnc(num_elem, self.comm_fhe, self.sub_name("fhe_v"))
    self.blob_fhe_z_c = BlobFheEnc(
        num_elem, self.comm_fhe, self.sub_name("fhe_z_c"), ee_mult_time=1)

    self.blob_e_s = plain_blob("e_s")
    self.blob_f_s = plain_blob("f_s")
    self.blob_e_c = plain_blob("e_c")
    self.blob_f_c = plain_blob("f_c")

    self.offline_server_send = [self.blob_fhe_z_c]
    self.offline_client_send = [self.blob_fhe_u, self.blob_fhe_v]
    self.online_server_send = [self.blob_e_s, self.blob_f_s]
    self.online_client_send = [self.blob_e_c, self.blob_f_c]
def __init__(self, modulus, fhe_builder: FheBuilder, data_range, img_hw, filter_hw,
             num_input_channel, num_output_channel, name: str, rank, padding):
    """Initialise the NTT-based FHE 2-D convolution layer for one party.

    Creates the communication objects for ``rank``, a
    ``Conv2dFheNttSingleThread`` compute core, and the raw-ciphertext blobs
    for input and output batches. Asserts that the layer modulus equals the
    builder's modulus. The ``*_send`` lists record which blobs each party
    sends in each phase.
    """
    super(Conv2dFheNttBase, self).__init__(
        modulus, data_range, img_hw, filter_hw,
        num_input_channel, num_output_channel, name, padding)

    self.fhe_builder = fhe_builder
    self.rank = rank

    tag = self.sub_name
    self.comm_base = CommBase(rank, tag("comm_base"))
    self.comm_fhe = CommFheBuilder(rank, fhe_builder, tag("comm_fhe"))

    self.compute_core = Conv2dFheNttSingleThread(
        modulus, data_range, img_hw, filter_hw,
        num_input_channel, num_output_channel,
        fhe_builder, name, padding,
    )

    assert self.modulus == self.fhe_builder.modulus

    # num_input_batch / num_output_batch are not set here; presumably the
    # parent initializer provides them - confirm.
    self.blob_input_cts = BlobFheRawCts(
        self.num_input_batch, self.comm_fhe, tag("input_cts"))
    self.blob_output_cts = BlobFheRawCts(
        self.num_output_batch, self.comm_fhe, tag("output_cts"))

    self.offline_server_send = [self.blob_output_cts]
    self.offline_client_send = [self.blob_input_cts]
    self.online_server_send = []
    self.online_client_send = []
def __init__(self, num_elem, q_23, q_16, work_bit, data_bit,
             fhe_builder_16: FheBuilder, fhe_builder_23: FheBuilder,
             name: str, rank: int, class_name: str):
    """Initialise the shared communication state for a DGK party.

    Forwards the numeric parameters to the parent initializer, stores the
    identifying fields and both FHE builders, and creates one ``CommBase``
    plus one ``CommFheBuilder`` per builder for ``rank``. Finally asserts
    that each builder's modulus equals the corresponding ``q_16`` /
    ``q_23`` attribute.
    """
    super(DgkCommBase, self).__init__(num_elem, q_23, q_16, work_bit, data_bit)

    # class_name and name are assigned before sub_name() is first called.
    self.class_name = class_name
    self.name = name
    self.rank = rank
    self.fhe_builder_16 = fhe_builder_16
    self.fhe_builder_23 = fhe_builder_23

    self.comm_base = CommBase(rank, name)
    self.comm_fhe_16 = CommFheBuilder(
        rank, fhe_builder_16, self.sub_name("comm_fhe_16"))
    self.comm_fhe_23 = CommFheBuilder(
        rank, fhe_builder_23, self.sub_name("comm_fhe_23"))

    assert self.fhe_builder_16.modulus == self.q_16
    assert self.fhe_builder_23.modulus == self.q_23
def __init__(self, modulus, fhe_builder: FheBuilder, data_range, img_hw, filter_hw,
             num_input_channel, num_output_channel, name: str, rank, padding):
    """Initialise the communication scaffolding of a 2-D convolution layer.

    Forwards the layer geometry to the parent initializer, stores the FHE
    builder and ``rank``, creates the ``CommBase`` / ``CommFheBuilder``
    objects, and starts with four empty per-phase send lists.
    """
    super().__init__(
        modulus, data_range, img_hw, filter_hw,
        num_input_channel, num_output_channel, name, padding)

    self.fhe_builder = fhe_builder
    self.rank = rank

    tag = self.sub_name
    self.comm_base = CommBase(rank, tag("comm_base"))
    self.comm_fhe = CommFheBuilder(rank, fhe_builder, tag("comm_fhe"))

    # Each list is a separate object so it can be filled independently.
    self.offline_server_send = []
    self.offline_client_send = []
    self.online_server_send = []
    self.online_client_send = []
def __init__(self, modulus, data_range, num_input_unit, num_output_unit,
             fhe_builder: FheBuilder, name, rank):
    """Initialise the communication scaffolding of a fully connected layer.

    Forwards the layer dimensions to the parent initializer, stores the FHE
    builder and ``rank``, creates the ``CommBase`` / ``CommFheBuilder``
    objects, asserts that the builder's degree equals ``self.degree``, and
    starts with four empty per-phase send lists.
    """
    super().__init__(modulus, data_range, num_input_unit, num_output_unit, name)

    self.fhe_builder = fhe_builder
    self.rank = rank

    tag = self.sub_name
    self.comm_base = CommBase(rank, tag("comm_base"))
    self.comm_fhe = CommFheBuilder(rank, fhe_builder, tag("comm_fhe"))

    # self.degree is not set here; presumably the parent provides it - confirm.
    assert self.fhe_builder.degree == self.degree

    # Each list is a separate object so it can be filled independently.
    self.offline_server_send = []
    self.offline_client_send = []
    self.online_server_send = []
    self.online_client_send = []
def test_client():
    """Run the client half of the 2x2 max-pool DGK protocol test.

    Generates keys and sends the public keys, receives the client's input
    share, runs the offline and online phases of ``Maxpool2x2DgkClient``,
    and sends the resulting ``max_c`` back.

    Uses names defined outside this function (``master_address``,
    ``master_port``, ``q_16``, ``q_23``, ``num_elem``, ``work_bit``,
    ``data_bit``, ``img_hw``).
    """
    rank = Config.client_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    traffic_record = TrafficRecord()

    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    for builder in (fhe_builder_16, fhe_builder_23):
        builder.generate_keys()

    key_comms = (
        CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16"),
        CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23"),
    )
    torch_sync()
    for key_comm in key_comms:
        key_comm.send_public_key()

    prot = Maxpool2x2DgkClient(
        num_elem, q_23, q_16, work_bit, data_bit, img_hw,
        fhe_builder_16, fhe_builder_23, "max_dgk")

    # Receive this party's share of the input image.
    blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_max_c")
    blob_img_c.prepare_recv()
    torch_sync()
    img_c = blob_img_c.get_recv()
    torch_sync()

    # NOTE(review): block boundaries were reconstructed; torch_sync() is
    # assumed to sit inside each timed block - confirm.
    with NamedTimerInstance("Client Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("client-offline")

    with NamedTimerInstance("Client Online"):
        prot.online(img_c)
        torch_sync()
    traffic_record.reset("client-online")

    # NOTE(review): this blob reuses the tag "recon_max_c" of the input
    # blob above - confirm this is intended.
    blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base, "recon_max_c")
    torch_sync()
    blob_max_c.send(prot.max_c)

    end_communicate()
def test_client():
    """Run the client half of the ReLU DGK protocol test.

    Generates keys and sends the public keys, draws a random input ``a``
    and splits it into ``a_c`` (kept) and ``a_s`` (sent to the peer), runs
    the offline and online phases of ``ReluDgkClient``, then receives
    ``max_s`` and checks the combined result against ``a``.

    Uses names defined outside this function (``master_address``,
    ``master_port``, ``q_16``, ``q_23``, ``num_elem``, ``data_range``,
    ``work_bit``, ``data_bit``).
    """
    rank = Config.client_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    traffic_record = TrafficRecord()

    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    for builder in (fhe_builder_16, fhe_builder_23):
        builder.generate_keys()

    key_comms = (
        CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16"),
        CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23"),
    )
    torch_sync()
    for key_comm in key_comms:
        key_comm.send_public_key()

    # a = a_c + a_s (mod q_23): the client keeps a_c and sends a_s.
    a = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    a_c = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    a_s = pmod(a - a_c, q_23)

    prot = ReluDgkClient(
        num_elem, q_23, q_16, work_bit, data_bit,
        fhe_builder_16, fhe_builder_23, "relu_dgk")

    blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
    blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
    torch_sync()
    blob_a_s.send(a_s)
    blob_max_s.prepare_recv()
    torch_sync()

    # NOTE(review): block boundaries were reconstructed; torch_sync() is
    # assumed to sit inside each timed block - confirm.
    with NamedTimerInstance("Client Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("client-offline")

    with NamedTimerInstance("Client Online"):
        prot.online(a_c)
        torch_sync()
    traffic_record.reset("client-online")

    max_s = blob_max_s.get_recv()
    check_correctness_online(a, max_s, prot.max_c)

    torch.cuda.empty_cache()
    end_communicate()
def fhe_builder_sync(self):
    """Exchange the public keys of both FHE builders between the parties.

    On the client the keys are generated and the public keys are sent; on
    the server they are received and built. A ``torch_sync()`` call
    brackets the exchange. Returns ``self`` so calls can be chained.
    """
    key_comms = (
        CommFheBuilder(self.rank, self.fhe_builder_16, self.sub_name("fhe_builder_16")),
        CommFheBuilder(self.rank, self.fhe_builder_23, self.sub_name("fhe_builder_23")),
    )

    torch_sync()
    if self.is_server():
        # Post both receives first, then wait for and build each key.
        for key_comm in key_comms:
            key_comm.recv_public_key()
        for key_comm in key_comms:
            key_comm.wait_and_build_public_key()
    elif self.is_client():
        for builder in (self.fhe_builder_16, self.fhe_builder_23):
            builder.generate_keys()
        for key_comm in key_comms:
            key_comm.send_public_key()
    # NOTE(review): block boundaries were reconstructed; the final sync and
    # return are assumed to run for both roles - confirm.
    torch_sync()

    return self
def test_server():
    """Run the server half of the 2x2 max-pool DGK protocol test.

    Receives and builds the peer's public keys, draws a random image and
    splits it into ``img_s`` (kept) and ``img_c`` (sent to the peer), runs
    the offline and online phases of ``Maxpool2x2DgkServer``, then receives
    ``max_c`` and checks the combined result against the image.

    Uses names defined outside this function (``master_address``,
    ``master_port``, ``q_16``, ``q_23``, ``num_elem``, ``data_range``,
    ``work_bit``, ``data_bit``, ``img_hw``).
    """
    rank = Config.server_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    traffic_record = TrafficRecord()

    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    key_comms = (
        CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16"),
        CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23"),
    )
    torch_sync()
    # Post both receives first, then wait for and build each key.
    for key_comm in key_comms:
        key_comm.recv_public_key()
    for key_comm in key_comms:
        key_comm.wait_and_build_public_key()

    # img = img_s + img_c (mod q_23): the server keeps img_s and sends img_c.
    img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    img_c = pmod(img - img_s, q_23)

    prot = Maxpool2x2DgkServer(
        num_elem, q_23, q_16, work_bit, data_bit, img_hw,
        fhe_builder_16, fhe_builder_23, "max_dgk")

    blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_max_c")
    torch_sync()
    blob_img_c.send(img_c)
    torch_sync()

    # NOTE(review): block boundaries were reconstructed; torch_sync() is
    # assumed to sit inside each timed block - confirm.
    with NamedTimerInstance("Server Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("server-offline")

    with NamedTimerInstance("Server Online"):
        prot.online(img_s)
        torch_sync()
    traffic_record.reset("server-online")

    # NOTE(review): this blob reuses the tag "recon_max_c" of the input
    # blob above - confirm this is intended.
    blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base, "recon_max_c")
    blob_max_c.prepare_recv()
    torch_sync()
    max_c = blob_max_c.get_recv()

    check_correctness_online(img, prot.max_s, max_c)
    end_communicate()
def test_server():
    """Run the server half of the secure-network test.

    Receives and builds the peer's public keys, loads the layer weights,
    and runs the offline and online phases of ``secure_nn``.

    Uses names defined outside this function (``fhe_builder_16``,
    ``fhe_builder_23``, ``context``, ``secure_nn``, ``conv1``, ``conv2``,
    ``fc1``, ``trunc1`` and their weights, ``pow_to_div``).
    """
    server_rank = Config.server_rank
    init_communicate(server_rank)
    context.set_rank(server_rank)

    key_comms = (
        CommFheBuilder(server_rank, fhe_builder_16, "fhe_builder_16"),
        CommFheBuilder(server_rank, fhe_builder_23, "fhe_builder_23"),
    )
    # Post both receives first, then wait for and build each key.
    for key_comm in key_comms:
        key_comm.recv_public_key()
    for key_comm in key_comms:
        key_comm.wait_and_build_public_key()

    for layer, weight in ((conv1, conv1_w), (conv2, conv2_w), (fc1, fc1_w)):
        layer.load_weight(weight)
    trunc1.set_div_to_pow(pow_to_div)

    # NOTE(review): block boundaries were reconstructed; torch_sync() is
    # assumed to sit inside each timed block - confirm.
    with NamedTimerInstance("Server Offline"):
        secure_nn.offline()
        torch_sync()

    with NamedTimerInstance("Server Online"):
        secure_nn.online()
        torch_sync()

    end_communicate()
def test_server():
    """Run the server half of the ReLU DGK protocol test.

    Receives and builds the peer's public keys, receives the input share
    ``a_s``, runs the offline and online phases of ``ReluDgkServer``, and
    sends the resulting ``max_s`` to the peer.

    Uses names defined outside this function (``master_address``,
    ``master_port``, ``q_16``, ``q_23``, ``num_elem``, ``work_bit``,
    ``data_bit``).
    """
    rank = Config.server_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    traffic_record = TrafficRecord()

    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    key_comms = (
        CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16"),
        CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23"),
    )
    torch_sync()
    # Post both receives first, then wait for and build each key.
    for key_comm in key_comms:
        key_comm.recv_public_key()
    for key_comm in key_comms:
        key_comm.wait_and_build_public_key()

    prot = ReluDgkServer(
        num_elem, q_23, q_16, work_bit, data_bit,
        fhe_builder_16, fhe_builder_23, "relu_dgk")

    blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
    blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
    torch_sync()
    blob_a_s.prepare_recv()
    a_s = blob_a_s.get_recv()
    torch_sync()

    # NOTE(review): block boundaries were reconstructed; torch_sync() is
    # assumed to sit inside each timed block - confirm.
    with NamedTimerInstance("Server Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("server-offline")

    with NamedTimerInstance("Server Online"):
        prot.online(a_s)
        torch_sync()
    traffic_record.reset("server-online")

    blob_max_s.send(prot.max_s)

    torch.cuda.empty_cache()
    end_communicate()
def test_client():
    """Run the client half of the DGK bit-comparison protocol test.

    Generates keys and sends the public keys, draws random ``x`` / ``y``
    and client shares, sends ``x``, ``y`` and the peer's share of
    ``y - x``, runs the offline and online phases of ``DgkBitClient``, and
    finally sends the client-side results for checking.

    Uses names defined outside this function (``master_address``,
    ``master_port``, ``q_16``, ``q_23``, ``n_16``, ``n_23``, ``num_elem``,
    ``data_range``, ``work_bit``, ``data_bit``).
    """
    rank = Config.client_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    traffic_record = TrafficRecord()

    fhe_builder_16 = FheBuilder(q_16, n_16)
    fhe_builder_23 = FheBuilder(q_23, n_23)
    for builder in (fhe_builder_16, fhe_builder_23):
        builder.generate_keys()

    key_comms = (
        CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16"),
        CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23"),
    )
    torch_sync()
    for key_comm in key_comms:
        key_comm.send_public_key()

    dgk = DgkBitClient(
        num_elem, q_23, q_16, work_bit, data_bit,
        fhe_builder_16, fhe_builder_23, "DgkBitTest")

    # All four tensors use the same bounds; draw order is x, y, x_c, y_c.
    low = -data_range // 2 + 1
    high = data_range // 2
    x = gen_unirand_int_grain(low, high, num_elem)
    y = gen_unirand_int_grain(low, high, num_elem)
    x_c = gen_unirand_int_grain(low, high, num_elem)
    y_c = gen_unirand_int_grain(low, high, num_elem)

    # Peer shares: value minus client share, reduced mod q_23.
    x_s = pmod(x - x_c, q_23)
    y_s = pmod(y - y_c, q_23)
    y_sub_x_s = pmod(y_s - x_s, q_23)

    x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
    y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
    y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y_sub_x_s")
    torch_sync()
    x_blob.send(x)
    y_blob.send(y)
    y_sub_x_s_blob.send(y_sub_x_s)
    torch_sync()

    # NOTE(review): block boundaries were reconstructed; only the protocol
    # call is assumed to sit inside each timed block - confirm.
    with NamedTimerInstance("Client Offline"):
        dgk.offline()
    y_sub_x_c = pmod(y_c - x_c, q_23)
    traffic_record.reset("client-offline")
    torch_sync()

    with NamedTimerInstance("Client Online"):
        dgk.online(y_sub_x_c)
    traffic_record.reset("client-online")

    dgk_x_leq_y_c_blob = BlobTorch(
        num_elem, torch.float, dgk.comm_base, "dgk_x_leq_y_c")
    correct_mod_div_work_c_blob = BlobTorch(
        num_elem, torch.float, dgk.comm_base, "correct_mod_div_work_c")
    z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
    torch_sync()
    dgk_x_leq_y_c_blob.send(dgk.dgk_x_leq_y_c)
    correct_mod_div_work_c_blob.send(dgk.correct_mod_div_work_c)
    z_blob.send(dgk.z)

    end_communicate()
def test_server():
    """Run the server half of the DGK bit-comparison protocol test.

    Receives and builds the peer's public keys, receives ``x``, ``y`` and
    the server share of ``y - x``, runs the offline and online phases of
    ``DgkBitServer``, then receives the client-side results and runs both
    correctness checks.

    Uses names defined outside this function (``master_address``,
    ``master_port``, ``q_16``, ``q_23``, ``num_elem``, ``work_bit``,
    ``data_bit``).
    """
    rank = Config.server_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    traffic_record = TrafficRecord()

    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    key_comms = (
        CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16"),
        CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23"),
    )
    torch_sync()
    # Post both receives first, then wait for and build each key.
    for key_comm in key_comms:
        key_comm.recv_public_key()
    for key_comm in key_comms:
        key_comm.wait_and_build_public_key()

    dgk = DgkBitServer(
        num_elem, q_23, q_16, work_bit, data_bit,
        fhe_builder_16, fhe_builder_23, "DgkBitTest")

    x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
    y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
    y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y_sub_x_s")
    for incoming in (x_blob, y_blob, y_sub_x_s_blob):
        incoming.prepare_recv()
    torch_sync()
    x = x_blob.get_recv()
    y = y_blob.get_recv()
    y_sub_x_s = y_sub_x_s_blob.get_recv()
    torch_sync()

    # NOTE(review): block boundaries were reconstructed; only the protocol
    # call is assumed to sit inside each timed block - confirm.
    with NamedTimerInstance("Server Offline"):
        dgk.offline()
    torch_sync()
    traffic_record.reset("server-offline")

    with NamedTimerInstance("Server Online"):
        dgk.online(y_sub_x_s)
    traffic_record.reset("server-online")

    dgk_x_leq_y_c_blob = BlobTorch(
        num_elem, torch.float, dgk.comm_base, "dgk_x_leq_y_c")
    correct_mod_div_work_c_blob = BlobTorch(
        num_elem, torch.float, dgk.comm_base, "correct_mod_div_work_c")
    z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
    for incoming in (dgk_x_leq_y_c_blob, correct_mod_div_work_c_blob, z_blob):
        incoming.prepare_recv()
    torch_sync()
    dgk_x_leq_y_c = dgk_x_leq_y_c_blob.get_recv()
    correct_mod_div_work_c = correct_mod_div_work_c_blob.get_recv()
    z = z_blob.get_recv()

    check_correctness(x, y, dgk.dgk_x_leq_y_s, dgk_x_leq_y_c)
    check_correctness_mod_div(dgk.r, z, dgk.correct_mod_div_work_s, correct_mod_div_work_c)

    end_communicate()