def test_server():
    """Server half of the Maxpool2x2Dgk two-party test.

    Receives the client's FHE public keys (the server never calls
    ``generate_keys``), secret-shares a random image, ships the client share,
    runs the offline and online phases of ``Maxpool2x2DgkServer`` and finally
    reconstructs the pooled result from ``prot.max_s`` and the client's share.

    Free names (``master_address``, ``q_16``, ``q_23``, ``num_elem``, ...) come
    from the enclosing scope, which is not visible here.
    """
    rank = Config.server_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    traffic_record = TrafficRecord()
    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
    comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
    torch_sync()
    # Post the receives for both public keys before waiting on either.
    comm_fhe_16.recv_public_key()
    comm_fhe_23.recv_public_key()
    comm_fhe_16.wait_and_build_public_key()
    comm_fhe_23.wait_and_build_public_key()

    # Plain image in the signed data range, split into additive shares mod q_23.
    img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    img_c = pmod(img - img_s, q_23)
    prot = Maxpool2x2DgkServer(num_elem, q_23, q_16, work_bit, data_bit, img_hw, fhe_builder_16, fhe_builder_23, "max_dgk")
    # NOTE(review): this blob uses the same tag "recon_max_c" as blob_max_c
    # below; confirm the matching client blob name and that tags may be reused.
    blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_max_c")
    torch_sync()
    # Presumably the client uses img_c as its input share for online().
    blob_img_c.send(img_c)
    torch_sync()
    with NamedTimerInstance("Server Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("server-offline")
    with NamedTimerInstance("Server Online"):
        prot.online(img_s)
        torch_sync()
    traffic_record.reset("server-online")

    # 2x2 pooling leaves a quarter of the elements; receive the client's share.
    blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base, "recon_max_c")
    blob_max_c.prepare_recv()
    torch_sync()
    max_c = blob_max_c.get_recv()
    check_correctness_online(img, prot.max_s, max_c)
    end_communicate()
def test_recon_to_client_comm():
    """Two-party test of ReconToClient: the client must end up with x_s + x_c mod p.

    Both shares are drawn uniformly from ``[0, modulus - 1]``. The server
    feeds ``x_s`` to ``ReconToClientServer.online``. The client feeds ``x_c``
    to ``ReconToClientClient.online`` and then checks ``prot.output`` against
    ``x_s + x_c`` (mod ``modulus``).
    """
    test_name = "test_recon_to_client_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_elem = 2**17
    print(f"Number of element: {num_elem}")
    x_s = gen_unirand_int_grain(0, modulus - 1, num_elem).cpu()
    x_c = gen_unirand_int_grain(0, modulus - 1, num_elem).cpu()

    def check_correctness_online(output, input_s, input_c):
        # The reconstructed output must equal the sum of the two shares mod p.
        expected = pmod(output.cuda(), modulus)
        actual = pmod(input_s.cuda() + input_c.cuda(), modulus)
        compare_expected_actual(expected, actual, name=test_name + " online", get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        warming_up_cuda()
        prot = ReconToClientServer(num_elem, modulus, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Server online"):
            prot.online(x_s)
            torch_sync()
        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        warming_up_cuda()
        prot = ReconToClientClient(num_elem, modulus, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online(x_c)
            torch_sync()
        # x_s is read from the enclosing scope rather than received; presumably
        # marshal_funcs runs both roles from this same closure — confirm.
        check_correctness_online(prot.output, x_s, x_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
def request(self, enc):
    """Have the peer refresh ``enc`` and return the refreshed ciphertext(s).

    A random plaintext ``self.r_s`` (values in ``[0, modulus - 1]``, shape
    ``self.shape``) is added to ``enc`` before it is sent on
    ``self.common.masked``. The same plaintext is subtracted from what comes
    back on ``self.common.refreshed``.

    If ``self.shape`` is 2-D, ``enc`` is indexed row by row with one mask
    plaintext per row. Otherwise a single plaintext masks the whole ciphertext.
    """
    self.prepare_recv()
    torch_sync()
    self.r_s = gen_unirand_int_grain(0, self.modulus - 1, self.shape)
    build_plain = self.fhe_builder.build_plain_from_torch

    if len(self.shape) != 2:
        mask_pt = build_plain(self.r_s)
        enc += mask_pt
        self.common.masked.send(enc)
        refreshed = self.common.refreshed.get_recv()
        refreshed -= mask_pt
    else:
        num_rows = self.shape[0]
        mask_pt = []
        for row in range(num_rows):
            mask_pt.append(build_plain(self.r_s[row]))
            enc[row] += mask_pt[row]
        self.common.masked.send(enc)
        refreshed = self.common.refreshed.get_recv()
        for row in range(num_rows):
            refreshed[row] -= mask_pt[row]

    # Release the masked input and the mask plaintext(s) explicitly.
    delete_fhe(enc)
    delete_fhe(mask_pt)
    torch_sync()
    return refreshed
def test_ntt_conv():
    """Compare Conv2dNtt with ``F.conv2d`` on random data, modulo a prime.

    Times the NTT of the input, the NTT of the weights, and the convolution on
    the pre-loaded operands. The result is checked against a double-precision
    torch convolution reduced mod ``modulus``.
    """
    modulus = 786433
    img_hw, filter_hw, padding = 16, 3, 1
    num_input_channel, num_output_channel = 64, 128
    data_bit = 17
    data_range = 2**data_bit
    low, high = -data_range // 2 + 1, data_range // 2

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]
    x = gen_unirand_int_grain(low, high, get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(low, high, get_prod(w_shape)).reshape(w_shape)

    conv2d_ntt = Conv2dNtt(modulus, data_range, img_hw, filter_hw,
                           num_input_channel, num_output_channel,
                           "test_conv2d_ntt", padding)
    # Untimed full convolution first; its result is superseded below.
    conv2d_ntt.conv2d(x, w)

    with NamedTimerInstance("ntt x"):
        conv2d_ntt.load_and_ntt_x(x)
    with NamedTimerInstance("ntt w"):
        conv2d_ntt.load_and_ntt_w(w)
    with NamedTimerInstance("conv2d"):
        y = conv2d_ntt.conv2d_loaded()
    actual = pmod(y, modulus)

    reference_x = x.reshape([1] + x_shape).double()
    reference_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        reference = F.conv2d(reference_x, reference_w, padding=padding)
    expected = pmod(reference.reshape(conv2d_ntt.y_shape), modulus)

    compare_expected_actual(expected, actual, name="ntt", get_relative=True, show_where_err=False)
def test_nnt_conv_single_channel():
    """Check a single-channel 2-D convolution computed via sympy NTTs against F.conv2d.

    Steps:
    - Zero-pad the image and the 180-degree-rotated filter to the next
      power-of-two size.
    - Take the 2-D NTT of each and multiply them point-wise.
    - Apply the 2-D inverse NTT and take the valid window.
    - Compare that window with ``F.conv2d`` reduced mod ``modulus``.
    """
    modulus = 786433
    img_hw = 6
    filter_hw = 3
    padding = 1
    conv_hw = img_hw + 2 * padding
    padded_hw = get_pow_2_ceil(conv_hw)
    output_hw = img_hw + 2 * padding - (filter_hw - 1)
    output_offset = filter_hw - 2

    # Deterministic ramp inputs; these are the values the test has always
    # compared (random draws that preceded them were immediately overwritten).
    # ``np.int`` was removed in NumPy 1.24; it was an alias of builtin ``int``.
    x = np.arange(img_hw**2).reshape([img_hw, img_hw]).astype(int)
    w = np.arange(filter_hw**2).reshape([filter_hw, filter_hw]).astype(int)

    padded_x = pad_to_size(x, padded_hw)
    padded_w = pad_to_size(np.rot90(w, 2), padded_hw)
    print(padded_x)
    print(padded_w)

    with NamedTimerInstance("NTT2D, Sympy"):
        ntted_x = transform2d(padded_x, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
        ntted_w = transform2d(padded_w, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
    with NamedTimerInstance("Point-wise Dot"):
        doted = ntted_x * ntted_w
    with NamedTimerInstance("iNTT2D"):
        reved = transform2d(doted, lambda sub_img: intt(sub_img.tolist(), prime=modulus))
    actual = reved[output_offset:output_hw + output_offset, output_offset:output_hw + output_offset]
    print("reved\n", reved)
    print("actual\n", actual)

    # Integer conv2d is not implemented on CPU; use doubles (exact for these
    # magnitudes), matching test_ntt_conv.
    torch_x = torch.tensor(x).double().reshape([1, 1, img_hw, img_hw])
    torch_w = torch.tensor(w).double().reshape([1, 1, filter_hw, filter_hw])
    expected = F.conv2d(torch_x, torch_w, padding=padding)
    expected = pmod(expected.reshape(output_hw, output_hw), modulus)
    print("expected", expected)
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
def test_noise():
    """Exercise FheBuilder round trips and report the noise budget around a multiply.

    The checks run in this order:
    - plaintext encode/decode;
    - encrypt/decrypt;
    - a ciphertext-times-plaintext product, which must decrypt to the element-wise
      product mod ``modulus``. The noise budget is printed before and after the
      multiply.
    """
    print()
    print("Test for FheBuilder: start")
    modulus, degree = 12289, 2048
    num_elem = 2 ** 14
    builder = FheBuilder(modulus, degree)
    builder.generate_keys()
    print(f"modulus: {modulus}, degree: {degree}")
    print()

    small_vals = gen_unirand_int_grain(0, 2, num_elem)
    print(small_vals)
    full_vals = gen_unirand_int_grain(0, modulus - 1, num_elem)

    with NamedTimerInstance(f"build_plain_from_torch with num_elem: {num_elem}"):
        plain = builder.build_plain_from_torch(small_vals)
    with NamedTimerInstance(f"plain.export_as_torch_gpu() with num_elem: {num_elem}"):
        decoded = plain.export_as_torch()
    print("Fhe Plain encrypt and decrypt: ", end="")
    assert compare_expected_actual(small_vals, decoded, verbose=True, get_relative=True).RelAvgDiff == 0
    print()

    with NamedTimerInstance(f"fhe_builder.build_enc with num_elem: {num_elem}"):
        cipher = builder.build_enc(num_elem)
    with NamedTimerInstance(f"cipher.encrypt_additive with num_elem: {num_elem}"):
        cipher.encrypt_additive(small_vals)
    with NamedTimerInstance(f"fhe_builder.decrypt_to_torch with num_elem: {num_elem}"):
        decrypted = builder.decrypt_to_torch(cipher)
    print("Fhe Enc encrypt and decrypt: ", end="")
    assert compare_expected_actual(small_vals, decrypted, verbose=True, get_relative=True).RelAvgDiff == 0
    print()

    multiplier_pt = builder.build_plain_from_torch(small_vals)
    product_ct = builder.build_enc_from_torch(full_vals)
    builder.noise_budget(product_ct, name="before ep")
    with NamedTimerInstance(f"EP Mult with num_elem: {num_elem}"):
        product_ct *= multiplier_pt
    builder.noise_budget(product_ct, name="after ep")
    expected = pmod(small_vals.double() * full_vals.double(), modulus)
    actual = builder.decrypt_to_torch(product_ct)
    assert compare_expected_actual(expected, actual, verbose=True, get_relative=True).RelAvgDiff == 0
    print()
    print("Test for FheBuilder: Finish")
def test_client():
    """Client half of the ReluDgk two-party test.

    Generates both FHE key pairs locally and sends only the public keys.
    Secret-shares a random input ``a`` mod ``q_23``, ships the server share
    ``a_s``, runs ``ReluDgkClient`` offline and online on ``a_c``, then checks
    ``max_s`` (received) plus ``prot.max_c`` against ``a``.

    Free names (``master_address``, ``q_16``, ``q_23``, ``num_elem``, ...) come
    from the enclosing scope, which is not visible here.
    """
    rank = Config.client_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    traffic_record = TrafficRecord()
    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    # Keys are generated on the client; only public keys leave this process.
    fhe_builder_16.generate_keys()
    fhe_builder_23.generate_keys()
    comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
    comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
    torch_sync()
    comm_fhe_16.send_public_key()
    comm_fhe_23.send_public_key()

    # Plain input in the signed data range, split into additive shares mod q_23.
    a = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    a_c = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    a_s = pmod(a - a_c, q_23)
    prot = ReluDgkClient(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "relu_dgk")
    blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
    blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
    torch_sync()
    blob_a_s.send(a_s)
    # Post the receive for the server's result share before the protocol runs.
    blob_max_s.prepare_recv()
    torch_sync()
    with NamedTimerInstance("Client Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("client-offline")
    with NamedTimerInstance("Client Online"):
        prot.online(a_c)
        torch_sync()
    traffic_record.reset("client-online")
    max_s = blob_max_s.get_recv()
    check_correctness_online(a, max_s, prot.max_c)
    torch.cuda.empty_cache()
    end_communicate()
def offline(self):
    """Client-side offline phase (presumably of the DGK comparison; the class is not visible).

    Steps:
    - Draw the client's random values ``beta_i_c``, ``delta_b_c`` and ``z_work_c``.
    - Send each with ``send_from_torch`` on the same-named ``self.common`` channel.
    - Run the mod-div and sum-c_i offline sub-steps and answer one
      ``EncRefresherClient`` request.
    - Decrypt what the peer returns on ``c_i_c``, ``delta_xor_c`` and
      ``dgk_x_leq_y_c``.
    - Add ``correct_mod_div_work_c`` (presumably set by ``mod_div_offline``) to
      ``dgk_x_leq_y_c``, mod ``q_23``.

    The order of sends/receives must mirror the peer's offline phase.
    """
    self.offline_recv()
    self.beta_i_c = gen_unirand_int_grain(
        0, self.q_16 - 1, self.decomp_bit_shape).to(Config.device)
    self.delta_b_c = gen_unirand_int_grain(0, self.q_23 - 1, self.num_elem).to(Config.device)
    self.z_work_c = gen_unirand_int_grain(0, self.q_23 - 1, self.num_elem).to(Config.device)
    # Constant helper tensors kept on the instance for later phases.
    self.beta_i_zeros = torch.zeros_like(self.beta_i_c)
    self.fast_ones = torch.ones(self.num_elem).to(Config.device)
    self.fast_zeros = torch.zeros(self.num_elem).to(Config.device)
    self.fast_ones_c_i = torch.ones(self.sum_shape).float().to(
        Config.device)
    self.fast_zeros_c_i = torch.zeros(self.sum_shape).float().to(
        Config.device)
    self.common.beta_i_c.send_from_torch(self.beta_i_c)
    self.common.delta_b_c.send_from_torch(self.delta_b_c)
    self.common.z_work_c.send_from_torch(self.z_work_c)
    if self.is_shuffle:
        # NOTE(review): uses self.sub_name while the refresher below uses
        # self.common.sub_name — confirm both resolve to the intended name.
        self.sum_c_refresher = EncRefresherClient(
            self.sum_shape, self.fhe_builder_16, self.sub_name("shuffle_refresher"))
    self.mod_div_offline()
    refresher_ab_xor_c = EncRefresherClient(
        self.decomp_bit_shape, self.fhe_builder_16, self.common.sub_name("refresher_ab_xor_c"))
    refresher_ab_xor_c.response()
    self.sum_c_i_offline()
    self.c_i_c = self.common.c_i_c.get_recv_decrypt()
    self.delta_xor_c = self.common.delta_xor_c.get_recv_decrypt()
    self.dgk_x_leq_y_c = self.common.dgk_x_leq_y_c.get_recv_decrypt()
    self.dgk_x_leq_y_c = pmod(
        self.dgk_x_leq_y_c + self.correct_mod_div_work_c, self.q_23)
    self.online_recv()
    torch_sync()
def test_1d_server():
    """Server side of the 1-D EncRefresher round-trip test.

    Encrypts a random tensor with values in ``[0, modulus - 1]``. The ciphertext
    is then refreshed through ``EncRefresherServer.request``, and the decryption
    must still equal the original tensor.
    """
    init_communicate(Config.server_rank)
    shape = num_elem
    # gen_unirand_int_grain's upper bound is inclusive (as the other
    # [0, modulus - 1] draws in this file assume); drawing up to ``modulus``
    # could produce a value that decrypts to 0 and spuriously fails the check.
    tensor = gen_unirand_int_grain(0, modulus - 1, shape)
    refresher = EncRefresherServer(shape, fhe_builder, "test_1d_refresher")
    enc = fhe_builder.build_enc_from_torch(tensor)
    refreshed = refresher.request(enc)
    tensor_refreshed = fhe_builder.decrypt_to_torch(refreshed)
    compare_expected_actual(tensor, tensor_refreshed, get_relative=True, name="1d_refresh")
    end_communicate()
def mod_div_offline(self):
    """Client side of the offline mod-div correction exchange.

    Encrypts a fresh random tensor ``pre_mod_div_c`` (values in
    ``[0, q_23 - 1]``) under ``fhe_builder_23`` and sends it on
    ``fhe_pre_corr_mod``. The reply from ``fhe_corr_mod_c`` is decrypted into
    ``correct_mod_div_work_c``.
    """
    builder = self.fhe_builder_23
    self.elem_zeros = torch.zeros(self.num_elem).to(Config.device)
    self.pre_mod_div_c = gen_unirand_int_grain(0, self.q_23 - 1, self.num_elem).to(Config.device)

    # Keep the ciphertext bound to a local for the rest of the call, as before.
    outgoing_ct = builder.build_enc_from_torch(self.pre_mod_div_c)
    self.common.fhe_pre_corr_mod.send(outgoing_ct)

    reply_ct = self.common.fhe_corr_mod_c.get_recv()
    self.correct_mod_div_work_c = builder.decrypt_to_torch(reply_ct)
def mod_div_offline(self):
    """Server side of the offline mod-div correction exchange.

    The received ciphertext from ``fhe_pre_corr_mod`` is transformed in three
    steps:
    - multiply element-wise by ``q_23 // work_range`` where
      ``r >= nullify_threshold``, and by 0 where ``r < nullify_threshold``;
    - add the random mask ``correct_mod_div_work_mask_s``;
    - send the result back on ``fhe_corr_mod_c``.
    """
    builder = self.fhe_builder_23
    self.elem_zeros = torch.zeros(self.num_elem).to(Config.device)

    below_threshold = self.r < self.nullify_threshold
    scaled = self.elem_zeros + self.q_23 // self.work_range
    self.correct_mod_div_work_mult = torch.where(below_threshold, self.elem_zeros, scaled).double()
    self.correct_mod_div_work_mask_s = gen_unirand_int_grain(
        0, self.q_23 - 1, self.num_elem).to(Config.device)

    mult_pt = builder.build_plain_from_torch(self.correct_mod_div_work_mult)
    bias_pt = builder.build_plain_from_torch(self.correct_mod_div_work_mask_s)

    work_ct = self.common.fhe_pre_corr_mod.get_recv()
    work_ct *= mult_pt
    work_ct += bias_pt
    del mult_pt, bias_pt
    self.common.fhe_corr_mod_c.send(work_ct)
def test_basic_ntt():
    """Round-trip a random zero-padded square image through sympy's 2-D NTT and iNTT.

    The image is padded to the next power-of-two side length, transformed
    forward and back, and must come back unchanged.
    """
    modulus = 786433
    img_hw = 34
    padded_hw = get_pow_2_ceil(img_hw)
    # ``np.int`` was removed in NumPy 1.24; it was only an alias of builtin ``int``.
    x = gen_unirand_int_grain(0, modulus - 1, img_hw**2).reshape([img_hw, img_hw]).numpy().astype(int)
    padded = np.zeros([padded_hw, padded_hw]).astype(int)
    padded[:img_hw, :img_hw] = x
    x = padded
    # Snapshot the input (a copy, not a view) so the comparison cannot be
    # affected by anything that touches ``x`` during the transforms.
    expected = x.copy()
    with NamedTimerInstance("NTT2D, Sympy"):
        ntted = transform2d(x, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
    with NamedTimerInstance("iNTT2D"):
        reved = transform2d(ntted, lambda sub_img: intt(sub_img.tolist(), prime=modulus))
    actual = reved
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
def masking_output(self):
    """Add a fresh random output mask ``output_mask_s`` to every output ciphertext.

    The mask (shape ``self.y_shape``, values in ``[0, modulus - 1]``) is passed
    through ``conv2d_ntt.ntt_output_masking``. For each output batch:
    - the per-channel pieces are written into ``num_elem_in_padded``-sized slots
      of a ``degree``-long vector; slots whose channel lookup returns ``False``
      stay zero;
    - the vector is reduced mod ``modulus`` and batch-encoded;
    - the encoding is added in place to ``self.output_cts[batch]``.
    """
    self.output_mask_s = gen_unirand_int_grain(
        0, self.modulus - 1, get_prod(self.y_shape)).reshape(self.y_shape)
    ntt_mask = self.conv2d_ntt.ntt_output_masking(self.output_mask_s)

    pod_vector = uIntVector()
    pt = Plaintext()
    for batch_idx in range(self.num_output_batch):
        slots = torch.zeros(self.degree, dtype=torch.float)
        for piece_idx in range(self.num_rotation):
            slot_len = self.num_elem_in_padded
            offset = piece_idx * slot_len
            channel = self.index_output_piece_to_channel(batch_idx, piece_idx)
            # Identity test against False (not truthiness), presumably so a
            # valid channel index 0 is still written.
            if channel is not False:
                slots[offset:offset + slot_len] = ntt_mask[channel].reshape(-1)
        slots = pmod(slots, self.modulus)
        pod_vector.from_np(slots.numpy().astype(np.uint64))
        self.batch_encoder.encode(pod_vector, pt)
        self.evaluator.add_plain_inplace(self.output_cts[batch_idx], pt)
def generate_random_data(self, shape):
    """Return ``gen_unirand_int_grain`` integers in ``[-data_range//2 + 1, data_range//2]`` shaped as ``shape``."""
    low = -self.data_range // 2 + 1
    high = self.data_range // 2
    num_values = get_prod(shape)
    return gen_unirand_int_grain(low, high, num_values).reshape(shape)
def generate_random(self):
    """Return ``self.num_elem`` integers from ``gen_unirand_int_grain`` in ``[0, self.modulus - 1]``."""
    upper = self.modulus - 1
    return gen_unirand_int_grain(0, upper, self.num_elem)
def test_shares_mult():
    """Two-party test of SharesMult on additively shared random operands mod q_23.

    The server reconstructs and verifies two things:
    - the offline triple: u * v == z;
    - the online product: a * b == c.
    Both are rebuilt from the server's shares and the client's shares.
    """
    print("\nTest for Shares Mult: Start")
    modulus = Config.q_23
    num_elem = 2**17
    print(f"Number of element: {num_elem}")
    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()
    a = gen_unirand_int_grain(0, modulus - 1, num_elem)
    a_s = gen_unirand_int_grain(0, modulus - 1, num_elem)
    a_c = pmod(a - a_s, modulus)
    b = gen_unirand_int_grain(0, modulus - 1, num_elem)
    b_s = gen_unirand_int_grain(0, modulus - 1, num_elem)
    b_c = pmod(b - b_s, modulus)

    def check_correctness_offline(u, v, z_s, z_c):
        # Offline triple: z_s + z_c must equal u * v mod p.
        expected = pmod(u.double().to(Config.device) * v.double().to(Config.device), modulus)
        actual = pmod(z_s + z_c, modulus)
        compare_expected_actual(expected, actual, name="shares_mult_offline", get_relative=True)

    def check_correctness_online(a, b, c_s, c_c):
        # Online product: c_s + c_c must equal a * b mod p.
        expected = pmod(a.double().to(Config.device) * b.double().to(Config.device), modulus)
        actual = pmod(c_s + c_c, modulus)
        compare_expected_actual(expected, actual, name="shares_mult_online", get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = SharesMultServer(num_elem, modulus, fhe_builder, "test_shares_mult")
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(a_s, b_s)
            torch_sync()
        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_z_c")
        blob_u_c.prepare_recv()
        blob_v_c.prepare_recv()
        blob_z_c.prepare_recv()
        torch_sync()
        u_c = blob_u_c.get_recv()
        v_c = blob_v_c.get_recv()
        z_c = blob_z_c.get_recv()
        u = pmod(prot.u_s + u_c, modulus)
        v = pmod(prot.v_s + v_c, modulus)
        # The (u, v, z) triple is produced offline: verify it with the offline
        # checker so it is reported as "shares_mult_offline".
        check_correctness_offline(u, v, prot.z_s, z_c)
        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        blob_c_c.prepare_recv()
        torch_sync()
        c_c = blob_c_c.get_recv()
        check_correctness_online(a, b, prot.c_s, c_c)
        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = SharesMultClient(num_elem, modulus, fhe_builder, "test_shares_mult")
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online(a_c, b_c)
            torch_sync()
        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_z_c")
        torch_sync()
        blob_u_c.send(prot.u_c)
        blob_v_c.send(prot.v_c)
        blob_z_c.send(prot.z_c)
        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        torch_sync()
        blob_c_c.send(prot.c_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print("\nTest for Shares Mult: End")
def test_conv2d_secure_comm(input_sid, master_address, master_port, setting=(16, 3, 128, 128)):
    """Two-party test of Conv2dSecureServer / Conv2dSecureClient.

    The server holds the weight and bias; the client holds a random mask
    ``input_c`` of the input. After the offline and online phases, the server
    adds its output share to the client's share. The sum must equal
    ``F.conv2d(input, weight, bias)`` mod ``modulus``.

    Args:
        input_sid: which role(s) this process runs.
            ``Config.both_rank`` runs both sides via ``marshal_funcs``.
            ``Config.server_rank`` or ``Config.client_rank`` runs only that side.
        master_address, master_port: forwarded to ``init_communicate``.
        setting: ``(img_hw, filter_hw, num_input_channel, num_output_channel)``.

    Raises:
        ValueError: if ``input_sid`` is none of the ranks above (previously the
            test silently did nothing in that case).
    """
    valid_sids = (Config.both_rank, Config.server_rank, Config.client_rank)
    if input_sid not in valid_sids:
        raise ValueError(f"Unknown input_sid {input_sid!r}; expected one of {valid_sids}")

    test_name = "Conv2d Secure Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    padding = 1
    img_hw, filter_hw, num_input_channel, num_output_channel = setting
    data_bit = 17
    data_range = 2**data_bit
    n_23 = 16384
    print(f"Setting conv2d: img_hw: {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")
    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]
    b_shape = [num_output_channel]
    fhe_builder = FheBuilder(modulus, n_23)
    fhe_builder.generate_keys()
    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(b_shape)).reshape(b_shape)
    # Named input_x to avoid shadowing the builtin ``input``.
    input_x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(x_shape)).reshape(x_shape)
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(input_x - input_c, modulus)

    def check_correctness_online(x, w, b, output_s, output_c):
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        torch_b = b.cuda().double() if b is not None else None
        expected = F.conv2d(torch_x, torch_w, padding=padding, bias=torch_b)
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected, actual, name=test_name + " online", get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank, master_address=master_address, master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureServer(modulus, fhe_builder, data_range, img_hw, filter_hw,
                                  num_input_channel, num_output_channel,
                                  "test_conv2d_secure_comm", padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(input_x, weight, bias, prot.output_s, output_c)
        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address, master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureClient(modulus, fhe_builder, data_range, img_hw, filter_hw,
                                  num_input_channel, num_output_channel,
                                  "test_conv2d_secure_comm", padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()
    print(f"\nTest for {test_name}: End")
def test_conv2d_fhe_ntt_comm():
    """Two-party offline test of Conv2dFheNttServer / Conv2dFheNttClient.

    The server holds a random weight; the client holds a random input mask.
    After the offline phase, ``output_c - output_mask_s`` must equal
    ``F.conv2d(input_mask, weight, padding=padding)`` mod ``modulus``. Here
    ``output_c`` is the client's share and ``output_mask_s`` is the server's mask.
    """
    test_name = "Conv2d Fhe NTT Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    img_hw = 2
    filter_hw = 3
    padding = 1
    num_input_channel = 512
    num_output_channel = 512
    data_bit = 17
    data_range = 2**data_bit
    print(f"Setting: img_hw {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")
    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]
    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()
    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    input_mask = gen_unirand_int_grain(0, modulus - 1, get_prod(x_shape)).reshape(x_shape)
    # input_mask = torch.arange(get_prod(x_shape)).reshape(x_shape)

    def check_correctness_offline(x, w, output_mask, output_c):
        # Removing the server's mask from the client's share must leave conv2d(x, w) mod p.
        actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = F.conv2d(torch_x, torch_w, padding=padding)
        expected = pmod(expected.reshape(output_mask.shape), modulus)
        compare_expected_actual(expected, actual, name=test_name + " offline", get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = Conv2dFheNttServer(modulus, fhe_builder, data_range, img_hw, filter_hw, num_input_channel, num_output_channel, "test_conv2d_fhe_ntt_comm", padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()
        # Receive the client's output share for verification only.
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s, output_c)
        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = Conv2dFheNttClient(modulus, fhe_builder, data_range, img_hw, filter_hw, num_input_channel,
                                  num_output_channel, "test_conv2d_fhe_ntt_comm", padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
def test_client():
    """Client half of the DgkBit comparison test.

    Generates both FHE key pairs and sends the public keys, then builds random
    ``x``, ``y`` and client shares ``x_c``, ``y_c``. It ships ``x``, ``y`` and
    the server share of ``y - x`` for the peer's checks, and runs
    ``DgkBitClient`` offline and online on ``y_c - x_c`` mod ``q_23``.
    Finally it sends ``dgk_x_leq_y_c``, ``correct_mod_div_work_c`` and ``z``
    back for verification.

    Free names (``master_address``, ``q_16``, ``n_16``, ``num_elem``, ...) come
    from the enclosing scope, which is not visible here.
    """
    rank = Config.client_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    traffic_record = TrafficRecord()
    fhe_builder_16 = FheBuilder(q_16, n_16)
    fhe_builder_23 = FheBuilder(q_23, n_23)
    fhe_builder_16.generate_keys()
    fhe_builder_23.generate_keys()
    comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
    comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
    torch_sync()
    comm_fhe_16.send_public_key()
    comm_fhe_23.send_public_key()

    dgk = DgkBitClient(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "DgkBitTest")

    x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    y = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    # NOTE(review): the client shares are drawn from the data range rather than
    # [0, q_23 - 1]; fine for a correctness test, but not uniform masks.
    x_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    y_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    x_s = pmod(x - x_c, q_23)
    y_s = pmod(y - y_c, q_23)
    y_sub_x_s = pmod(y_s - x_s, q_23)

    # Test harness only: ship the plaintexts and the server's share to the peer.
    x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
    y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
    y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y_sub_x_s")
    torch_sync()
    x_blob.send(x)
    y_blob.send(y)
    y_sub_x_s_blob.send(y_sub_x_s)

    torch_sync()
    with NamedTimerInstance("Client Offline"):
        dgk.offline()
    y_sub_x_c = pmod(y_c - x_c, q_23)
    traffic_record.reset("client-offline")
    torch_sync()

    with NamedTimerInstance("Client Online"):
        dgk.online(y_sub_x_c)
    traffic_record.reset("client-online")

    # Send the client-side outputs so the peer can reconstruct and check them.
    dgk_x_leq_y_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "dgk_x_leq_y_c")
    correct_mod_div_work_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "correct_mod_div_work_c")
    z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
    torch_sync()
    dgk_x_leq_y_c_blob.send(dgk.dgk_x_leq_y_c)
    correct_mod_div_work_c_blob.send(dgk.correct_mod_div_work_c)
    z_blob.send(dgk.z)
    end_communicate()
def test_fc_secure_comm(input_sid, master_address, master_port, setting=(512, 512)):
    """Two-party test of FcSecureServer / FcSecureClient.

    The server holds the weight and bias; the client holds a random mask
    ``input_c`` of the input. After both phases, the server adds its output share
    to the client's share. The sum must equal ``input @ weight.T + bias`` mod
    ``modulus``.

    Args:
        input_sid: which role(s) this process runs.
            ``Config.both_rank`` runs both sides via ``marshal_funcs``.
            ``Config.server_rank`` or ``Config.client_rank`` runs only that side,
            as in ``test_conv2d_secure_comm``. (Previously this argument was
            ignored and both sides always ran.)
        master_address, master_port: forwarded to ``init_communicate``.
        setting: ``(num_input_unit, num_output_unit)``.

    Raises:
        ValueError: if ``input_sid`` is none of the ranks above.
    """
    valid_sids = (Config.both_rank, Config.server_rank, Config.client_rank)
    if input_sid not in valid_sids:
        raise ValueError(f"Unknown input_sid {input_sid!r}; expected one of {valid_sids}")

    test_name = "test_fc_secure_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit, num_output_unit = setting
    data_bit = 17
    print(f"Setting fc: "
          f"num_input_unit: {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")
    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]
    b_shape = [num_output_unit]
    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()
    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(b_shape)).reshape(b_shape)
    # Named input_x to avoid shadowing the builtin ``input``.
    input_x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(x_shape)).reshape(x_shape)
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(input_x - input_c, modulus)

    def check_correctness_online(x, w, output_s, output_c):
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = torch.mm(torch_x, torch_w.t())
        if bias is not None:
            expected += bias.cuda().double()
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected, actual, name=test_name + " online", get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank, master_address=master_address, master_port=master_port)
        warming_up_cuda()
        prot = FcSecureServer(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(input_x, weight, prot.output_s, output_c)
        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address, master_port=master_port)
        warming_up_cuda()
        prot = FcSecureClient(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()
    print(f"\nTest for {test_name}: End")
def test_fc_fhe_comm():
    """Two-party offline test of FcFheServer / FcFheClient (512 -> 512 fully connected).

    The server holds a random weight; the client holds a random input mask.
    After the offline phase, ``output_c - output_mask_s`` must equal
    ``input_mask @ weight.T`` mod ``modulus``. Here ``output_c`` is the client's
    share and ``output_mask_s`` is the server's mask.
    """
    test_name = "test_fc_fhe_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17
    print(f"Setting: num_input_unit {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")
    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]
    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()
    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    input_mask = gen_unirand_int_grain(0, modulus - 1, get_prod(x_shape)).reshape(x_shape)
    # input_mask = torch.arange(get_prod(x_shape)).reshape(x_shape)

    def check_correctness_offline(x, w, output_mask, output_c):
        # Removing the server's mask from the client's share must leave x @ w.T mod p.
        actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = torch.mm(torch_x, torch_w.t())
        expected = pmod(expected.reshape(output_mask.shape), modulus)
        compare_expected_actual(expected, actual, name=test_name + " offline", get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = FcFheServer(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()
        # Receive the client's output share for verification only.
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s, output_c)
        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = FcFheClient(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server,
                   test_client])
    print(f"\nTest for {test_name}: End")
def offline(self):
    """Server-side offline phase (presumably of the DGK comparison; the class is not visible).

    Steps:
    - Draw the server's random values: the bit ``delta_a``, ``s = 1 - 2 * delta_a``
      mod ``q_16``, ``r`` in ``[0, q_23 - 1]`` and its bit decomposition
      ``alpha_i`` of ``r mod work_range``, plus several masks and a per-column
      shuffle order.
    - Run the mod-div correction and one ``EncRefresherServer`` round.
    - Process the client's encrypted inputs received on ``beta_i_c``,
      ``delta_b_c`` and ``z_work_c``.
    - Send the masked results on ``c_i_c``, ``delta_xor_c`` and ``dgk_x_leq_y_c``.

    The order of sends/receives must mirror the peer's offline phase.
    """
    self.offline_recv()
    self.delta_a = gen_unirand_int_grain(0, 1, self.num_elem).to(Config.device)
    self.s = pmod(1 - 2 * self.delta_a, self.q_16)
    self.r = gen_unirand_int_grain(0, self.q_23 - 1, self.num_elem).to(Config.device)
    self.alpha = pmod(self.r, self.work_range)
    self.alpha_i = self.common.decomp_to_bit(self.alpha).to(Config.device)
    self.beta_i_mask_s = gen_unirand_int_grain(
        0, self.q_16 - 1, self.decomp_bit_shape).to(Config.device)
    self.ci_mask_s = gen_unirand_int_grain(
        0, self.q_16 - 1, [self.work_bit + 1, self.num_elem]).to(Config.device)
    # Multiplicative mask: drawn from [1, q_16 - 1] so it is never zero.
    self.ci_mult_mask_s = gen_unirand_int_grain(
        1, self.q_16 - 1, [self.work_bit + 1, self.num_elem]).to(Config.device)
    # Independent random permutation of the work_bit + 1 rows for every column.
    self.shuffle_order = torch.rand([self.work_bit + 1, self.num_elem]).argsort(dim=0).to(Config.device)
    self.delta_xor_mask_s = gen_unirand_int_grain(
        0, self.q_16 - 1, self.num_elem).to(Config.device)
    self.dgk_x_leq_y_mask_s = gen_unirand_int_grain(
        0, self.q_23 - 1, self.num_elem).to(Config.device)
    self.fast_zeros_sum_xor = torch.zeros(self.sum_shape).to(Config.device)

    self.mod_div_offline()

    refresher_ab_xor_c = EncRefresherServer(
        self.decomp_bit_shape, self.fhe_builder_16, self.common.sub_name("refresher_ab_xor_c"))
    fhe_beta_i_c = self.common.beta_i_c.get_recv()
    # Independent copies, because xor_alpha_known_offline consumes fhe_beta_i_c.
    fhe_beta_i_c_for_sum_c = [
        fhe_beta_i_c[i].copy() for i in range(len(fhe_beta_i_c))
    ]
    fhe_alpha_beta_xor_c = self.xor_alpha_known_offline(
        self.alpha_i, fhe_beta_i_c, self.beta_i_mask_s)
    fhe_alpha_beta_xor_c = refresher_ab_xor_c.request(fhe_alpha_beta_xor_c)
    fhe_c_i_c = self.sum_c_i_offline(self.delta_a, fhe_beta_i_c_for_sum_c, fhe_alpha_beta_xor_c,
                                     self.s, self.alpha_i, self.ci_mask_s, self.ci_mult_mask_s,
                                     self.shuffle_order)
    self.common.c_i_c.send(fhe_c_i_c)

    fhe_delta_b_c = self.common.delta_b_c.get_recv()
    fhe_delta_xor_c = self.xor_delta_known_offline(self.delta_a, fhe_delta_b_c,
                                                   self.delta_xor_mask_s)
    self.common.delta_xor_c.send(fhe_delta_xor_c)

    fhe_z_work_c = self.common.z_work_c.get_recv()
    fhe_z_work_c -= fhe_delta_xor_c
    fhe_z_work_c -= self.fhe_builder_23.build_plain_from_torch(
        self.dgk_x_leq_y_mask_s)
    self.common.dgk_x_leq_y_c.send(fhe_z_work_c)

    # Drop this frame's references to the ciphertexts. (A former
    # ``for ct in ...: del ct`` loop only unbound the loop variable and freed
    # nothing, so it was removed.)
    del fhe_beta_i_c, fhe_beta_i_c_for_sum_c, fhe_alpha_beta_xor_c, fhe_c_i_c, fhe_delta_b_c, fhe_delta_xor_c, fhe_z_work_c
    del refresher_ab_xor_c
    self.online_recv()
    torch_sync()
def test_fhe_builder():
    """Exercise FheBuilder encoding, encryption and homomorphic operators end to end.

    Generates keys for a fixed (modulus, degree) pair. It then checks the
    following against the same arithmetic done on tensors with ``pmod(..., modulus)``:

    * plaintext build/export;
    * encrypt/decrypt;
    * ciphertext-plaintext (EP) add, sub and mult;
    * ciphertext-ciphertext (EE) add, sub and mult;
    * ciphertext negation.

    Each operation is timed with NamedTimerInstance, and every check asserts a
    zero relative average difference.
    """
    print()
    print("Test for FheBuilder: start")

    # Other parameter sets that have been used with this test:
    #   (65537, 2048), (786433, 4096), (65537, 4096)
    modulus, degree = 12289, 2048
    num_elem = 2 ** 17 - 1

    fhe_builder = FheBuilder(modulus, degree)
    fhe_builder.generate_keys()
    print(f"modulus: {modulus}, degree: {degree}")
    print()

    gpu_tensor = gen_unirand_int_grain(0, modulus - 1, num_elem)
    gpu_tensor_rev = gen_unirand_int_grain(0, modulus - 1, num_elem)

    def _assert_exact(expected, actual):
        # Compare and require an exact match (zero relative average difference).
        assert compare_expected_actual(expected, actual, verbose=True,
                                       get_relative=True).RelAvgDiff == 0

    # Plaintext build -> export round trip.
    with NamedTimerInstance(f"build_plain_from_torch with num_elem: {num_elem}"):
        plain = fhe_builder.build_plain_from_torch(gpu_tensor)
    with NamedTimerInstance(f"plain.export_as_torch() with num_elem: {num_elem}"):
        tensor_from_plain = plain.export_as_torch()
    print("Fhe Plain encrypt and decrypt: ", end="")
    _assert_exact(gpu_tensor, tensor_from_plain)
    print()

    # Encrypt -> decrypt round trip.
    with NamedTimerInstance(f"fhe_builder.build_enc with num_elem: {num_elem}"):
        cipher = fhe_builder.build_enc(num_elem)
    with NamedTimerInstance(f"cipher.encrypt_additive with num_elem: {num_elem}"):
        cipher.encrypt_additive(gpu_tensor)
    with NamedTimerInstance(f"fhe_builder.decrypt_to_torch with num_elem: {num_elem}"):
        tensor_from_cipher = fhe_builder.decrypt_to_torch(cipher)
    print("Fhe Enc encrypt and decrypt: ", end="")
    _assert_exact(gpu_tensor, tensor_from_cipher)
    print()

    # EP add: ciphertext += plaintext.
    pt = fhe_builder.build_plain_from_torch(gpu_tensor)
    ct = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    with NamedTimerInstance(f"EP add with num_elem: {num_elem}"):
        ct += pt
    expected = pmod(gpu_tensor + gpu_tensor_rev, modulus)
    actual = fhe_builder.decrypt_to_torch(ct)
    _assert_exact(expected, actual)
    print()

    # EE add: ciphertext += ciphertext.
    ct1 = fhe_builder.build_enc_from_torch(gpu_tensor)
    ct2 = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    with NamedTimerInstance(f"EE add with num_elem: {num_elem}"):
        ct1 += ct2
    expected = pmod(gpu_tensor + gpu_tensor_rev, modulus)
    actual = fhe_builder.decrypt_to_torch(ct1)
    _assert_exact(expected, actual)
    print()

    # EP sub: ciphertext -= plaintext.
    pt = fhe_builder.build_plain_from_torch(gpu_tensor)
    ct = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    with NamedTimerInstance(f"EP sub with num_elem: {num_elem}"):
        ct -= pt
    expected = pmod(gpu_tensor_rev - gpu_tensor, modulus)
    actual = fhe_builder.decrypt_to_torch(ct)
    _assert_exact(expected, actual)
    print()

    # EE sub: ciphertext -= ciphertext (was mislabelled "EE add").
    ct1 = fhe_builder.build_enc_from_torch(gpu_tensor)
    ct2 = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    with NamedTimerInstance(f"EE sub with num_elem: {num_elem}"):
        ct1 -= ct2
    expected = pmod(gpu_tensor - gpu_tensor_rev, modulus)
    actual = fhe_builder.decrypt_to_torch(ct1)
    _assert_exact(expected, actual)
    print()

    # EP mult. Products are formed in double precision so they stay exact
    # before reduction.
    pt = fhe_builder.build_plain_from_torch(gpu_tensor)
    ct = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    fhe_builder.noise_budget(ct, name="before ep")
    with NamedTimerInstance(f"EP Mult with num_elem: {num_elem}"):
        ct *= pt
    fhe_builder.noise_budget(ct, name="after ep")
    expected = pmod(gpu_tensor.double() * gpu_tensor_rev.double(), modulus)
    actual = fhe_builder.decrypt_to_torch(ct)
    _assert_exact(expected, actual)
    print()

    # EE mult, compared against a float32 expected tensor.
    ct1 = fhe_builder.build_enc_from_torch(gpu_tensor)
    ct2 = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    fhe_builder.noise_budget(ct1, name="before ee")
    with NamedTimerInstance(f"EE mult with num_elem: {num_elem}"):
        ct1 *= ct2
    fhe_builder.noise_budget(ct1, name="after ee")
    expected = pmod(gpu_tensor.double() * gpu_tensor_rev.double(), modulus).float()
    actual = fhe_builder.decrypt_to_torch(ct1)
    _assert_exact(expected, actual)
    print()

    # EE mult again, compared against a float64 expected tensor. This block was
    # previously mislabelled "EE Add".
    ct1 = fhe_builder.build_enc_from_torch(gpu_tensor)
    ct2 = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    fhe_builder.noise_budget(ct1, name="before ee")
    with NamedTimerInstance(f"EE mult (double expected) with num_elem: {num_elem}"):
        ct1 *= ct2
    fhe_builder.noise_budget(ct1, name="after ee")
    expected = pmod(gpu_tensor.double() * gpu_tensor_rev.double(), modulus)
    actual = fhe_builder.decrypt_to_torch(ct1)
    _assert_exact(expected, actual)
    print()

    # Negation of a ciphertext.
    ct = fhe_builder.build_enc_from_torch(gpu_tensor)
    fhe_builder.noise_budget(ct, name="before neg")
    with NamedTimerInstance(f"neg E with num_elem: {num_elem}"):
        ct = -ct
    fhe_builder.noise_budget(ct, name="after neg")
    expected = pmod(-gpu_tensor, modulus)
    actual = fhe_builder.decrypt_to_torch(ct)
    _assert_exact(expected, actual)
    print()

    print("Test for FheBuilder: Finish")
def test_avgpool2x2_dgk(input_sid, master_address, master_port, num_elem=2**17):
    """Run the 2x2 average-pooling protocol (server and/or client) and check its output.

    A random signed image tensor is split into additive shares modulo ``q_23``.
    The server's share is ``img_s`` and the client's share is ``img_c``.
    Avgpool2x2Server / Avgpool2x2Client then run their offline and online phases.

    The client sends its output share to the server over the ``"recon_res_c"``
    blob. The server checks that ``out_s + out_c`` (mod q_23) equals 4 times the
    2x2 average pool of the image, i.e. the sum over each window.

    Args:
        input_sid: ``Config.both_rank`` runs both sides via ``marshal_funcs``.
            ``Config.server_rank`` or ``Config.client_rank`` runs only that side.
        master_address: Address passed to ``init_communicate``.
        master_port: Port passed to ``init_communicate``.
        num_elem: Number of input elements. They are reshaped into
            ``img_hw x img_hw`` (4x4) images, so this presumably must be a
            multiple of 16.

    Raises:
        ValueError: If ``input_sid`` is not one of the three ranks above.
    """
    test_name = "Avgpool2x2"
    print(f"\nTest for {test_name}: Start")

    # Fail fast, before the (expensive) key generation, on an unknown side id.
    # Previously such an id was silently ignored.
    if input_sid not in (Config.both_rank, Config.server_rank, Config.client_rank):
        raise ValueError(f"Unknown input_sid for {test_name}: {input_sid!r}")

    data_bit = 20
    work_bit = 20
    data_range = 2**data_bit
    q_16 = 12289
    # Smaller alternative modulus that has also been used: q_23 = 786433
    q_23 = 7340033
    img_hw = 4
    print(f"Number of element: {num_elem}")

    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_16.generate_keys()
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    fhe_builder_23.generate_keys()

    # Signed plaintext image and its additive sharing modulo q_23.
    img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    img_c = pmod(img - img_s, q_23)

    def check_correctness_online(img, out_s, out_c):
        # Map values in the upper half of [0, q_23) to negatives. For the signed
        # img sampled above this leaves the values unchanged.
        img = torch.where(img < q_23 // 2, img, img - q_23).cuda()
        pool = torch.nn.AvgPool2d(2)
        # The 2x2 average times 4 is the window sum. The reconstructed output is
        # compared against that sum.
        expected = pool(img.double().reshape([-1, img_hw, img_hw])).reshape(-1) * 4
        expected = pmod(expected, q_23)
        actual = pmod(out_s + out_c, q_23)
        compare_expected_actual(expected, actual, name=test_name + "_online", get_relative=True)

    def test_server():
        init_communicate(Config.server_rank, master_address=master_address, master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Server(num_elem, q_23, q_16, work_bit, data_bit, img_hw,
                                fhe_builder_16, fhe_builder_23, "avgpool")

        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(img_s)
            torch_sync()
        traffic_record.reset("server-online")

        # Receive the client's output share and verify the reconstruction.
        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base, "recon_res_c")
        blob_out_c.prepare_recv()
        torch_sync()
        out_c = blob_out_c.get_recv()
        check_correctness_online(img, prot.out_s, out_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank, master_address=master_address, master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Client(num_elem, q_23, q_16, work_bit, data_bit, img_hw,
                                fhe_builder_16, fhe_builder_23, "avgpool")

        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(img_c)
            torch_sync()
        traffic_record.reset("client-online")

        # Send this side's output share to the server for checking.
        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base, "recon_res_c")
        torch_sync()
        blob_out_c.send(prot.out_c)

        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()

    print(f"\nTest for {test_name}: End")