def test_server():
    """Server side of the secure-NN test.

    Receives both public keys from the client, loads the layer weights and
    truncation parameter, then times the offline and online phases.
    """
    rank = Config.server_rank
    init_communicate(rank)
    context.set_rank(rank)

    fhe_comms = (
        CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16"),
        CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23"),
    )
    # Start both key receptions first, then wait on each one in turn.
    for fhe_comm in fhe_comms:
        fhe_comm.recv_public_key()
    for fhe_comm in fhe_comms:
        fhe_comm.wait_and_build_public_key()

    for layer, layer_weight in ((conv1, conv1_w), (conv2, conv2_w), (fc1, fc1_w)):
        layer.load_weight(layer_weight)
    trunc1.set_div_to_pow(pow_to_div)

    with NamedTimerInstance("Server Offline"):
        secure_nn.offline()
        torch_sync()
    with NamedTimerInstance("Server Online"):
        secure_nn.online()
        torch_sync()

    end_communicate()
def test_server():
    """Server side of the secure Conv2d test.

    Runs offline (with weight/bias) and online (with the server input share).
    It then receives the client's output share and checks the reconstructed
    result.
    """
    init_communicate(Config.server_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    prot = Conv2dSecureServer(
        modulus, fhe_builder, data_range, img_hw, filter_hw,
        num_input_channel, num_output_channel, "test_conv2d_secure_comm",
        padding=padding,
    )

    with NamedTimerInstance("Server Offline"):
        prot.offline(weight, bias=bias)
        torch_sync()
    with NamedTimerInstance("Server Online"):
        prot.online(input_s)
        torch_sync()

    # Collect the client's output share so both shares can be checked together.
    output_c_blob = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
    output_c_blob.prepare_recv()
    torch_sync()
    output_c = output_c_blob.get_recv()
    check_correctness_online(input, weight, bias, prot.output_s, output_c)

    end_communicate()
def test_client():
    """Client side of the secure Conv2d test.

    Runs offline (with the client input share) and online, then sends
    ``prot.output_c`` to the server under the tag ``"output_c"``.
    """
    init_communicate(Config.client_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    prot = Conv2dSecureClient(
        modulus, fhe_builder, data_range, img_hw, filter_hw,
        num_input_channel, num_output_channel, "test_conv2d_secure_comm",
        padding=padding,
    )

    with NamedTimerInstance("Client Offline"):
        prot.offline(input_c)
        torch_sync()
    with NamedTimerInstance("Client Online"):
        prot.online()
        torch_sync()

    output_c_blob = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
    torch_sync()
    output_c_blob.send(prot.output_c)

    end_communicate()
def test_client():
    """Client side of the SharesMult test.

    Runs both phases with the client shares ``a_c``/``b_c``. It then sends
    ``u_c``, ``v_c``, ``z_c`` in one round and ``c_c`` in a second round.
    """
    init_communicate(Config.client_rank)
    prot = SharesMultClient(num_elem, modulus, fhe_builder, "test_shares_mult")

    with NamedTimerInstance("Client Offline"):
        prot.offline()
        torch_sync()
    with NamedTimerInstance("Client Online"):
        prot.online(a_c, b_c)
        torch_sync()

    # Round 1: the three intermediate shares, in u, v, z order.
    recon_blobs = [
        BlobTorch(num_elem, torch.float, prot.comm_base, tag)
        for tag in ("recon_u_c", "recon_v_c", "recon_z_c")
    ]
    torch_sync()
    for blob, share in zip(recon_blobs, (prot.u_c, prot.v_c, prot.z_c)):
        blob.send(share)

    # Round 2: the client's output share.
    c_c_blob = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
    torch_sync()
    c_c_blob.send(prot.c_c)

    end_communicate()
def test_server():
    """Server side of the 2x2 average-pool test.

    Times both phases and records traffic after each. It then receives the
    client's ``num_elem // 4`` output share and checks it against ``img``.
    """
    init_communicate(Config.server_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    traffic_record = TrafficRecord()
    prot = Avgpool2x2Server(
        num_elem, q_23, q_16, work_bit, data_bit, img_hw,
        fhe_builder_16, fhe_builder_23, "avgpool",
    )

    with NamedTimerInstance("Server Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("server-offline")

    with NamedTimerInstance("Server Online"):
        prot.online(img_s)
        torch_sync()
    traffic_record.reset("server-online")

    out_c_blob = BlobTorch(num_elem // 4, torch.float, prot.comm_base, "recon_res_c")
    out_c_blob.prepare_recv()
    torch_sync()
    out_c = out_c_blob.get_recv()
    check_correctness_online(img, prot.out_s, out_c)

    end_communicate()
def test_client():
    """Client side of the 2x2 average-pool test.

    Times both phases with traffic records, then sends ``prot.out_c`` under
    the tag ``"recon_res_c"``.
    """
    init_communicate(Config.client_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    traffic_record = TrafficRecord()
    prot = Avgpool2x2Client(
        num_elem, q_23, q_16, work_bit, data_bit, img_hw,
        fhe_builder_16, fhe_builder_23, "avgpool",
    )

    with NamedTimerInstance("Client Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("client-offline")

    with NamedTimerInstance("Client Online"):
        prot.online(img_c)
        torch_sync()
    traffic_record.reset("client-online")

    out_c_blob = BlobTorch(num_elem // 4, torch.float, prot.comm_base, "recon_res_c")
    torch_sync()
    out_c_blob.send(prot.out_c)

    end_communicate()
def test_server():
    """Server side of the SharesMult test.

    Round 1 receives the client's u/v/z shares. It reconstructs u and v
    mod ``modulus`` and checks them against the z shares. Round 2 receives
    ``c_c`` and checks the output against ``a`` and ``b``.
    """
    init_communicate(Config.server_rank)
    prot = SharesMultServer(num_elem, modulus, fhe_builder, "test_shares_mult")

    with NamedTimerInstance("Server Offline"):
        prot.offline()
        torch_sync()
    with NamedTimerInstance("Server Online"):
        prot.online(a_s, b_s)
        torch_sync()

    recon_blobs = [
        BlobTorch(num_elem, torch.float, prot.comm_base, tag)
        for tag in ("recon_u_c", "recon_v_c", "recon_z_c")
    ]
    for blob in recon_blobs:
        blob.prepare_recv()
    torch_sync()
    u_c, v_c, z_c = [blob.get_recv() for blob in recon_blobs]

    u = pmod(prot.u_s + u_c, modulus)
    v = pmod(prot.v_s + v_c, modulus)
    check_correctness_online(u, v, prot.z_s, z_c)

    c_c_blob = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
    c_c_blob.prepare_recv()
    torch_sync()
    c_c = c_c_blob.get_recv()
    check_correctness_online(a, b, prot.c_s, c_c)

    end_communicate()
def sum_c_i_offline(self, delta_a, fhe_beta_i_c, fhe_alpha_beta_xor_c, s, alpha_i, ci_mask_s, mult_mask_s, shuffle_order):
    # Build the per-bit encrypted comparison terms c_i, plus one extra row c_{-1}.
    #
    # Steps visible below:
    #   1. Suffix sums of the encrypted xor terms:
    #        sum_xor[i] = base + sum_{j=i+1}^{work_bit-1} fhe_alpha_beta_xor_c[j]
    #      where base is the ciphertext from fhe_builder.build_enc(num_elem)
    #      (presumably an all-zero encryption -- TODO confirm).
    #   2. Extra row: sum_xor[work_bit] = sum_xor[0] + xor_c[0] + delta_a.
    #   3. For each bit i, with m_i = mult_mask_s[i]:
    #        sum_xor[i] = 3*m_i*sum_xor[i] - m_i*beta_i_c[i]
    #                     + (s*m_i mod q_16) + alpha_i[i]*m_i + ci_mask_s[i]
    #   4. Extra row: sum_xor[work_bit] = m*sum_xor[work_bit] + ci_mask_s[work_bit].
    #   5. Optionally refresh and shuffle the rows (self.is_shuffle).
    #
    # Side effect: fhe_beta_i_c[i] is multiplied in place by the plaintext mask,
    # so the caller's ciphertext list is modified.
    # Returns the list of ciphertexts (refreshed and shuffled if self.is_shuffle).
    # the last row of sum_xor is c_{-1}, which helps check the case with x == y
    fhe_builder = self.fhe_builder_16
    # fhe_sum_xor = [fhe_builder.build_enc(self.num_elem) for i in range(self.num_work_batch)]
    # Index self.work_bit is written below, so num_work_batch must be at least work_bit + 1.
    fhe_sum_xor = [None for i in range(self.num_work_batch)]
    fhe_sum_xor[self.work_bit - 1] = fhe_builder.build_enc(self.num_elem)
    # Walk from the most significant bit downwards, accumulating the suffix sum.
    for i in range(self.work_bit - 1)[::-1]:
        fhe_sum_xor[i] = fhe_sum_xor[i + 1].copy()
        fhe_sum_xor[i] += fhe_alpha_beta_xor_c[i + 1]
    fhe_delta_a = fhe_builder.build_plain_from_torch(delta_a)
    # Extra row: the full xor sum (including bit 0) plus delta_a.
    fhe_sum_xor[self.work_bit] = fhe_sum_xor[0].copy()
    fhe_sum_xor[self.work_bit] += fhe_alpha_beta_xor_c[0]
    fhe_sum_xor[self.work_bit] += fhe_delta_a
    del fhe_delta_a
    for i in range(self.work_bit)[::-1]:
        # 3 * mask, reduced mod q_16 on the CPU before plaintext encoding.
        fhe_mult_3 = fhe_builder.build_plain_from_torch(
            pmod(3 * mult_mask_s[i].cpu(), self.q_16))
        fhe_mult_mask_s = fhe_builder.build_plain_from_torch(
            mult_mask_s[i])
        # Multiply in int64 and reduce mod q_16 before going back to float32;
        # presumably this avoids float32 precision loss in the product.
        masked_s = pmod(
            s.type(torch.int64) * mult_mask_s[i].type(torch.int64), self.q_16).type(torch.float32)
        # print("s * mult_mask_s[i]", torch.max(masked_s))
        fhe_s = fhe_builder.build_plain_from_torch(masked_s)
        # NOTE(review): unlike masked_s, this product is not reduced mod q_16;
        # presumably alpha_i[i] is small (e.g. a bit) so it stays in range -- confirm.
        fhe_alpha_i = fhe_builder.build_plain_from_torch(alpha_i[i] * mult_mask_s[i])
        fhe_ci_mask_s = fhe_builder.build_plain_from_torch(ci_mask_s[i])
        # In-place: mutates the caller's fhe_beta_i_c[i].
        fhe_beta_i_c[i] *= fhe_mult_mask_s
        fhe_sum_xor[i] *= fhe_mult_3
        fhe_sum_xor[i] -= fhe_beta_i_c[i]
        fhe_sum_xor[i] += fhe_s
        fhe_sum_xor[i] += fhe_alpha_i
        fhe_sum_xor[i] += fhe_ci_mask_s
        # Release plaintext temporaries eagerly.
        del fhe_mult_3, fhe_mult_mask_s, fhe_s, fhe_alpha_i, fhe_ci_mask_s
    # Extra row (c_{-1}): mask multiplicatively, then additively.
    fhe_mult_mask_s = fhe_builder.build_plain_from_torch(
        mult_mask_s[self.work_bit])
    fhe_ci_mask_s = fhe_builder.build_plain_from_torch(
        ci_mask_s[self.work_bit])
    fhe_sum_xor[self.work_bit] *= fhe_mult_mask_s
    fhe_sum_xor[self.work_bit] += fhe_ci_mask_s
    del fhe_mult_mask_s, fhe_ci_mask_s
    if self.is_shuffle:
        with NamedTimerInstance("Shuffle"):
            # Round-trip the ciphertexts through the refresher, then reorder them by shuffle_order.
            refresher = EncRefresherServer(
                self.sum_shape, fhe_builder, self.sub_name("shuffle_refresher"))
            with NamedTimerInstance("refresh"):
                new_fhe_sum_xor = refresher.request(fhe_sum_xor)
            del fhe_sum_xor
            fhe_sum_xor = self.generate_fhe_shuffled(
                shuffle_order, new_fhe_sum_xor)
            del refresher
    return fhe_sum_xor
def test_rotation():
    """Benchmark FHE ops on an ``arange`` vector and print noise budgets.

    Times encryption, one ciphertext-plaintext multiply, 128 ciphertext
    additions and one row rotation by 64. It prints the noise budget after
    each step and finally the decrypted result.
    """
    degree = Config.n_23
    fhe_builder = FheBuilder(Config.q_23, degree)
    fhe_builder.generate_keys()
    fhe_builder.generate_galois_keys()

    values = torch.arange(degree)
    with NamedTimerInstance("Fhe Encrypt"):
        enc = fhe_builder.build_enc_from_torch(values)
        enc_less = fhe_builder.build_enc_from_torch(values)
        plain = fhe_builder.build_plain_from_torch(values)
    fhe_builder.noise_budget(enc, "before mul")

    with NamedTimerInstance("ep mult"):
        enc *= plain
        enc_less *= plain
    fhe_builder.noise_budget(enc, "after mul")

    with NamedTimerInstance("ee add"):
        for _ in range(128):
            enc += enc_less
    fhe_builder.noise_budget(enc, "after add")

    with NamedTimerInstance("rot"):
        fhe_builder.evaluator.rotate_rows_inplace(enc.cts[0], 64, fhe_builder.galois_keys)
    fhe_builder.noise_budget(enc, "after rot")

    print(fhe_builder.decrypt_to_torch(enc))
def test_server():
    """Server side of the full secure-NN run with stored weights.

    Redirects stdout to ``Logger()`` and loads truncation and weight
    parameters. It resets the MetaTruncRandomGenerator seed, times both
    phases with traffic records, and finally checks the output and the
    per-layer results.
    """
    rank = Config.server_rank
    # Note: replaces sys.stdout for the rest of the process; it is not restored.
    sys.stdout = Logger()
    traffic_record = TrafficRecord()

    secure_nn = get_secure_nn()
    secure_nn.set_rank(rank).init_communication(master_address=master_addr, master_port=master_port)
    warming_up_cuda()
    secure_nn.fhe_builder_sync()

    load_trunc_params(secure_nn, store_configs)
    load_weight_params(secure_nn, store_configs, torch.load(net_state_name))

    meta_rg = MetaTruncRandomGenerator()
    meta_rg.reset_seed()

    with NamedTimerInstance("Server Offline"):
        secure_nn.offline()
        torch_sync()
    traffic_record.reset("server-offline")

    with NamedTimerInstance("Server Online"):
        secure_nn.online()
        torch_sync()
    traffic_record.reset("server-online")

    secure_nn.check_correctness(check_correctness)
    secure_nn.check_layers(get_plain_net, get_hooking_lst(model_name_base))
    secure_nn.end_communication()
def test_client():
    """Client side of the secure-NN test.

    Generates and sends both public keys, feeds a random input, runs both
    phases and checks the final output against that input.
    """
    rank = Config.client_rank
    init_communicate(rank)
    context.set_rank(rank)

    for builder in (fhe_builder_16, fhe_builder_23):
        builder.generate_keys()
    fhe_comms = (
        CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16"),
        CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23"),
    )
    for fhe_comm in fhe_comms:
        fhe_comm.send_public_key()

    input_img = generate_random_data(input_shape)
    trunc1.set_div_to_pow(pow_to_div)
    secure_nn.feed_input(input_img)

    with NamedTimerInstance("Client Offline"):
        secure_nn.offline()
        torch_sync()
    with NamedTimerInstance("Client Online"):
        secure_nn.online()
        torch_sync()

    check_correctness(input_img, secure_nn.get_output())
    end_communicate()
def transform2d(img, func, reverse=False):
    """Apply a 1-D transform ``func`` separably to the rows, then the columns, of ``img``.

    ``func`` is applied to every row of ``img``. It is then applied to every
    column of that row-transformed array, and the result is transposed back so
    the output keeps the input's row/column orientation.

    Args:
        img: 2-D array indexable as ``img[i, :]``.
        func: callable mapping a 1-D array to a 1-D sequence.
        reverse: unused; kept for interface compatibility.

    Returns:
        A 2-D ``numpy.ndarray`` (an empty array when ``img`` has no rows).
    """
    num_rows = len(img)
    with NamedTimerInstance("Apply func ina loop"):
        row_transformed = [func(img[i, :]) for i in range(num_rows)]
    with NamedTimerInstance("list to numpy"):
        row_transformed = np.array(row_transformed)
    # Iterate over the real column count. The previous version reused the row
    # count here, which was only correct for square inputs.
    num_cols = row_transformed.shape[1] if num_rows else 0
    col_transformed = np.array([func(row_transformed[:, j]) for j in range(num_cols)])
    return col_transformed.transpose()
def test_conv2d_fhe_ntt_single_thread():
    """Check Conv2dFheNttSingleThread against torch ``F.conv2d`` mod ``modulus``.

    The input ``x`` is uniform residues in ``[0, modulus - 1]`` and the weight
    ``w`` is signed ``data_bit``-bit integers. The protocol's decoded output,
    minus its output mask and reduced mod ``modulus``, must equal the torch
    convolution reduced mod ``modulus``.
    """
    modulus = 786433
    img_hw = 16
    filter_hw = 3
    padding = 1
    num_input_channel = 64
    num_output_channel = 128
    data_bit = 17
    data_range = 2**data_bit

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    # The generator's upper bound appears inclusive (cf. the symmetric signed
    # range above and the `modulus - 1` bounds used for shares elsewhere). Using
    # `modulus` could yield a value equal to the plaintext modulus, which is not
    # a canonical residue and may be rejected by the FHE encoder.
    x = gen_unirand_int_grain(0, modulus - 1, get_prod(x_shape)).reshape(x_shape)

    warming_up_cuda()

    prot = Conv2dFheNttSingleThread(modulus, data_range, img_hw, filter_hw, num_input_channel, num_output_channel,
                                    fhe_builder, "test_conv2d_fhe_ntt", padding)

    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_conv2d(w)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    # Remove the server-side output mask before comparing.
    actual = pmod(y - prot.output_mask_s, modulus)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        expected = F.conv2d(torch_x, torch_w, padding=padding)
    expected = pmod(expected.reshape(prot.output_shape), modulus)

    compare_expected_actual(expected, actual, name="test_conv2d_fhe_ntt_single_thread", get_relative=True)
def test_server():
    """Server side of the ReconToClient test: time both phases with the server share ``x_s``."""
    init_communicate(Config.server_rank)
    warming_up_cuda()
    prot = ReconToClientServer(num_elem, modulus, test_name)

    with NamedTimerInstance("Server Offline"):
        prot.offline()
        torch_sync()
    with NamedTimerInstance("Server online"):
        prot.online(x_s)
        torch_sync()

    end_communicate()
def test_fc_fhe_single_thread():
    """Check FcFheSingleThread against a plaintext matrix product mod ``Config.q_23``.

    The input ``x`` is uniform residues in ``[0, modulus - 1]`` and the weight
    ``w`` is signed ``data_bit``-bit integers. The protocol's decoded output,
    minus its output mask, must equal ``x @ w.T`` reduced mod ``modulus``.
    """
    test_name = "test_fc_fhe_single_thread"
    print(f"\nTest for {test_name}: Start")
    modulus = Config.q_23
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17
    data_range = 2**data_bit

    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    # The generator's upper bound appears inclusive, so cap at modulus - 1 to
    # keep every input a canonical residue (matching the share generation
    # used elsewhere).
    x = gen_unirand_int_grain(0, modulus - 1, get_prod(x_shape)).reshape(x_shape)

    warming_up_cuda()

    prot = FcFheSingleThread(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)

    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)
    print("prot.num_elem_in_piece", prot.num_elem_in_piece)

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_with_weight(w)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    # Remove the output mask before comparing (the earlier unmasked pmod was dead code).
    actual = pmod(y - prot.output_mask_s, modulus)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        expected = torch.mm(torch_x, torch_w.t())
    expected = pmod(expected.reshape(prot.output_shape), modulus)

    compare_expected_actual(expected, actual, name=test_name, get_relative=True)
    print(f"\nTest for {test_name}: End")
def test_client():
    """Client side of the ReconToClient test.

    Times both phases with the client share ``x_c``, then checks
    ``prot.output`` against the two shares.
    """
    init_communicate(Config.client_rank)
    warming_up_cuda()
    prot = ReconToClientClient(num_elem, modulus, test_name)

    with NamedTimerInstance("Client Offline"):
        prot.offline()
        torch_sync()
    with NamedTimerInstance("Client Online"):
        prot.online(x_c)
        torch_sync()

    check_correctness_online(prot.output, x_s, x_c)
    end_communicate()
def test_server():
    """Server side of the DGK 2x2 max-pool test.

    Receives both public keys and secret-shares a random signed image mod
    ``q_23``, sending the client its share. It times both phases, then
    receives the client's ``num_elem // 4`` max share and checks it against
    ``img``.
    """
    rank = Config.server_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    traffic_record = TrafficRecord()

    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    fhe_comms = (
        CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16"),
        CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23"),
    )
    torch_sync()
    for fhe_comm in fhe_comms:
        fhe_comm.recv_public_key()
    for fhe_comm in fhe_comms:
        fhe_comm.wait_and_build_public_key()

    # img == img_s + img_c (mod q_23)
    img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    img_c = pmod(img - img_s, q_23)

    prot = Maxpool2x2DgkServer(
        num_elem, q_23, q_16, work_bit, data_bit, img_hw,
        fhe_builder_16, fhe_builder_23, "max_dgk",
    )

    img_c_blob = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_max_c")
    torch_sync()
    img_c_blob.send(img_c)
    torch_sync()

    with NamedTimerInstance("Server Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("server-offline")

    with NamedTimerInstance("Server Online"):
        prot.online(img_s)
        torch_sync()
    traffic_record.reset("server-online")

    max_c_blob = BlobTorch(num_elem // 4, torch.float, prot.comm_base, "recon_max_c")
    max_c_blob.prepare_recv()
    torch_sync()
    max_c = max_c_blob.get_recv()
    check_correctness_online(img, prot.max_s, max_c)

    end_communicate()
def comm_base_client():
    """Client half of the CommBase exchange benchmark.

    Sends float tensors while receiving int16 tensors, then sends double
    tensors while receiving uint8 tensors. Each exchange is timed twice,
    separated by barriers. Finally the last received tensors are compared
    with the expected values.
    """
    init_communicate(Config.client_rank)
    comm_base = CommBase(Config.client_rank, comm_name)
    send_float = expect_float.cuda()
    send_double = expect_double.cuda()

    def post_recv(dtype, tag):
        # Register a num_elem receive buffer of the given dtype under `tag`.
        comm_base.recv_torch(torch.zeros(num_elem).type(dtype), tag)

    def send_and_collect(payload, payload_tag, recv_tag):
        # Send `payload`, block until `recv_tag` arrives, return it on the GPU.
        comm_base.send_torch(payload, payload_tag)
        comm_base.wait(recv_tag)
        return comm_base.get_tensor(recv_tag).cuda()

    with NamedTimerInstance("Client float and int16"):
        post_recv(torch.int16, int16_tag)
        actual_int16 = send_and_collect(send_float, float_tag, int16_tag)

    post_recv(torch.int16, int16_tag)
    dist.barrier()
    with NamedTimerInstance("Client float and int16"):
        actual_int16 = send_and_collect(send_float, float_tag, int16_tag)

    post_recv(torch.uint8, uint8_tag)
    dist.barrier()
    with NamedTimerInstance("Client double and uint8"):
        actual_uint8 = send_and_collect(send_double, double_tag, uint8_tag)

    post_recv(torch.uint8, uint8_tag)
    dist.barrier()
    with NamedTimerInstance("Client double and uint8"):
        actual_uint8 = send_and_collect(send_double, double_tag, uint8_tag)

    dist.barrier()
    compare_expected_actual(expect_int16.cuda(), actual_int16, name="int16", get_relative=True)
    compare_expected_actual(expect_uint8.cuda(), actual_uint8, name="uint8", get_relative=True)
    dist.destroy_process_group()
def test_torch_ntt():
    """Round-trip a random signed matrix through the matrix-form 2-D NTT and its inverse.

    ``x`` must come back unchanged mod ``modulus``.
    """
    modulus = 786433
    img_hw = 64
    data_bit = 17
    len_vector = img_hw
    data_range = 2**data_bit

    _root, mod, ntt_mat, inv_mat = generate_ntt_param(modulus, len_vector, data_range)

    # The dead `np.arange(...).astype(np.int)` inputs were removed: they were
    # immediately overwritten, and `np.int` no longer exists in NumPy >= 1.24.
    # The filter `w` was unused, so it was dropped as well.
    x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, img_hw**2).reshape([img_hw, img_hw]).double()

    ntt_mat = ntt_mat.double()
    inv_mat = inv_mat.double()

    with NamedTimerInstance("Mat NTT 2d"):
        ntted = ntt_mat2d(ntt_mat, mod, x)
        reved = ntt_mat2d(inv_mat, mod, ntted)

    expected = pmod(x, modulus).type(torch.int)
    actual = pmod(reved, modulus).type(torch.int)
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
def response(self):
    """Client step of the ciphertext refresh.

    Receives the server's masked ciphertexts, decrypts them and calls
    ``encrypt_additive`` on the pre-allocated ``self.refreshed`` ciphertexts.
    For a 2-D ``self.shape`` this is done one row at a time. It then sends
    ``self.refreshed`` back and frees both ciphertext sets.
    """
    self.prepare_recv()
    torch_sync()
    fhe_masked = self.common.masked.get_recv()

    with NamedTimerInstance("client refresh reencrypt"):
        if len(self.shape) != 2:
            self.refreshed.encrypt_additive(self.fhe_builder.decrypt_to_torch(fhe_masked))
        else:
            for row in range(self.shape[0]):
                decrypted_row = self.fhe_builder.decrypt_to_torch(fhe_masked[row])
                self.refreshed[row].encrypt_additive(decrypted_row)

    self.common.refreshed.send(self.refreshed)
    torch_sync()
    delete_fhe(fhe_masked)
    delete_fhe(self.refreshed)
def test_client():
    """Client side of the SwapToClientOffline test.

    Offline phase with ``y_c`` and online phase with ``x_c``. Then sends
    ``prot.output_c`` to the server under the tag ``"output_c"``.
    """
    init_communicate(Config.client_rank)
    warming_up_cuda()
    prot = SwapToClientOfflineClient(num_elem, modulus, test_name)

    with NamedTimerInstance("Client Offline"):
        prot.offline(y_c)
        torch_sync()
    with NamedTimerInstance("Client Online"):
        prot.online(x_c)
        torch_sync()

    output_c_blob = BlobTorch(num_elem, torch.float, prot.comm_base, "output_c")
    torch_sync()
    output_c_blob.send(prot.output_c)

    end_communicate()
def test_ntt_conv():
    """Compare Conv2dNtt's load-then-convolve path with torch ``F.conv2d`` mod ``modulus``.

    ``Conv2dNtt.conv2d(x, w)`` is also called once up front; its result is
    discarded, and only the ``conv2d_loaded`` output is compared.
    """
    modulus = 786433
    img_hw = 16
    filter_hw = 3
    padding = 1
    num_input_channel = 64
    num_output_channel = 128
    data_bit = 17
    data_range = 2**data_bit

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]
    low, high = -data_range // 2 + 1, data_range // 2
    x = gen_unirand_int_grain(low, high, get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(low, high, get_prod(w_shape)).reshape(w_shape)

    conv2d_ntt = Conv2dNtt(modulus, data_range, img_hw, filter_hw, num_input_channel, num_output_channel,
                           "test_conv2d_ntt", padding)
    y = conv2d_ntt.conv2d(x, w)

    with NamedTimerInstance("ntt x"):
        conv2d_ntt.load_and_ntt_x(x)
    with NamedTimerInstance("ntt w"):
        conv2d_ntt.load_and_ntt_w(w)
    with NamedTimerInstance("conv2d"):
        y = conv2d_ntt.conv2d_loaded()
    actual = pmod(y, modulus)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        expected = F.conv2d(torch_x, torch_w, padding=padding)
    expected = pmod(expected.reshape(conv2d_ntt.y_shape), modulus)

    compare_expected_actual(expected, actual, name="ntt", get_relative=True, show_where_err=False)
def test_nnt_conv_single_channel():
    """Single-channel 2-D convolution via sympy NTT, checked against torch ``F.conv2d``.

    ``x`` and the 180-degree-rotated ``w`` are zero-padded to a power-of-two
    side and transformed with ``ntt``. They are multiplied pointwise and
    inverse-transformed, and the valid output window is compared with
    ``F.conv2d`` reduced mod ``modulus``.
    """
    modulus = 786433
    img_hw = 6
    filter_hw = 3
    padding = 1
    conv_hw = img_hw + 2 * padding
    padded_hw = get_pow_2_ceil(conv_hw)
    output_hw = img_hw + 2 * padding - (filter_hw - 1)
    output_offset = filter_hw - 2

    # Deterministic inputs (these were already the effective ones; the random
    # inputs drawn before them were overwritten). The `np.int` alias was
    # removed in NumPy 1.24; the builtin `int` gives the same dtype.
    x = np.arange(img_hw**2).reshape([img_hw, img_hw]).astype(int)
    w = np.arange(filter_hw**2).reshape([filter_hw, filter_hw]).astype(int)

    padded_x = pad_to_size(x, padded_hw)
    padded_w = pad_to_size(np.rot90(w, 2), padded_hw)
    print(padded_x)
    print(padded_w)

    with NamedTimerInstance("NTT2D, Sympy"):
        ntted_x = transform2d(
            padded_x, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
        ntted_w = transform2d(
            padded_w, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
    with NamedTimerInstance("Point-wise Dot"):
        doted = ntted_x * ntted_w
    with NamedTimerInstance("iNTT2D"):
        reved = transform2d(
            doted, lambda sub_img: intt(sub_img.tolist(), prime=modulus))
    actual = reved[output_offset:output_hw + output_offset, output_offset:output_hw + output_offset]
    print("reved\n", reved)
    print("actual\n", actual)

    # CPU conv2d has no int64 kernel, so compute the reference in float64 like
    # the other NTT tests. The small integer values are represented exactly.
    torch_x = torch.tensor(x).reshape([1, 1, img_hw, img_hw]).double()
    torch_w = torch.tensor(w).reshape([1, 1, filter_hw, filter_hw]).double()
    expected = F.conv2d(torch_x, torch_w, padding=padding)
    expected = pmod(expected.reshape(output_hw, output_hw), modulus)
    print("expected", expected)
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
def run_secure_nn_client_with_random_data(secure_nn, check_correctness, master_address, master_port):
    """Drive the client side of ``secure_nn`` on random input.

    Connects to the given master address and port, syncs FHE builders and
    fills a random input. It times both phases with traffic records, then
    runs ``check_correctness`` and ends communication.
    """
    traffic_record = TrafficRecord()
    secure_nn.set_rank(Config.client_rank).init_communication(master_address=master_address, master_port=master_port)
    warming_up_cuda()
    secure_nn.fhe_builder_sync().fill_random_input()

    with NamedTimerInstance("Client Offline"):
        secure_nn.offline()
        torch_sync()
    traffic_record.reset("client-offline")

    with NamedTimerInstance("Client Online"):
        secure_nn.online()
        torch_sync()
    traffic_record.reset("client-online")

    secure_nn.check_correctness(check_correctness).end_communication()
def test_basic_ntt():
    """Round-trip a random residue matrix through sympy's 2-D NTT and inverse NTT.

    The matrix is zero-padded to a power-of-two side before transforming.
    """
    modulus = 786433
    img_hw = 34
    padded_hw = get_pow_2_ceil(img_hw)

    # `np.int` was removed in NumPy 1.24; the builtin `int` selects the same dtype.
    x = gen_unirand_int_grain(0, modulus - 1, img_hw**2).reshape([img_hw, img_hw]).numpy().astype(int)
    padded = np.zeros([padded_hw, padded_hw], dtype=int)
    padded[:img_hw, :img_hw] = x
    x = padded
    expected = x.copy()

    with NamedTimerInstance("NTT2D, Sympy"):
        ntted = transform2d(
            x, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
    with NamedTimerInstance("iNTT2D"):
        reved = transform2d(
            ntted, lambda sub_img: intt(sub_img.tolist(), prime=modulus))

    actual = reved
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
def test_server():
    """Server side of the SwapToClientOffline test.

    Times both phases with ``x_s``, then receives the client's ``output_c``
    and checks both output shares against ``x_s``/``x_c``.
    """
    init_communicate(Config.server_rank)
    warming_up_cuda()
    prot = SwapToClientOfflineServer(num_elem, modulus, test_name)

    with NamedTimerInstance("Server Offline"):
        prot.offline()
        torch_sync()
    with NamedTimerInstance("Server online"):
        prot.online(x_s)
        torch_sync()

    output_c_blob = BlobTorch(num_elem, torch.float, prot.comm_base, "output_c")
    output_c_blob.prepare_recv()
    torch_sync()
    output_c = output_c_blob.get_recv()
    check_correctness_online(x_s, x_c, prot.output_s, output_c)

    end_communicate()
def test_client():
    """Client side of the DGK 2x2 max-pool test.

    Generates and sends both public keys and receives its input share
    ``img_c``. It times both phases, then sends its ``num_elem // 4`` max
    share back to the server.
    """
    rank = Config.client_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    traffic_record = TrafficRecord()

    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    for builder in (fhe_builder_16, fhe_builder_23):
        builder.generate_keys()
    fhe_comms = (
        CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16"),
        CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23"),
    )
    torch_sync()
    for fhe_comm in fhe_comms:
        fhe_comm.send_public_key()

    prot = Maxpool2x2DgkClient(
        num_elem, q_23, q_16, work_bit, data_bit, img_hw,
        fhe_builder_16, fhe_builder_23, "max_dgk",
    )

    img_c_blob = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_max_c")
    img_c_blob.prepare_recv()
    torch_sync()
    img_c = img_c_blob.get_recv()
    torch_sync()

    with NamedTimerInstance("Client Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("client-offline")

    with NamedTimerInstance("Client Online"):
        prot.online(img_c)
        torch_sync()
    traffic_record.reset("client-online")

    max_c_blob = BlobTorch(num_elem // 4, torch.float, prot.comm_base, "recon_max_c")
    torch_sync()
    max_c_blob.send(prot.max_c)

    end_communicate()
def test_client():
    """Client side of the DGK ReLU test.

    Generates and sends both public keys and secret-shares a random signed
    vector ``a`` mod ``q_23``, sending the server its share. It times both
    phases, receives ``max_s`` and checks it against ``a`` together with its
    own share.
    """
    rank = Config.client_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    traffic_record = TrafficRecord()

    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    for builder in (fhe_builder_16, fhe_builder_23):
        builder.generate_keys()
    fhe_comms = (
        CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16"),
        CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23"),
    )
    torch_sync()
    for fhe_comm in fhe_comms:
        fhe_comm.send_public_key()

    # a == a_c + a_s (mod q_23)
    a = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    a_c = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    a_s = pmod(a - a_c, q_23)

    prot = ReluDgkClient(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "relu_dgk")

    a_s_blob = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
    max_s_blob = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
    torch_sync()
    a_s_blob.send(a_s)
    max_s_blob.prepare_recv()
    torch_sync()

    with NamedTimerInstance("Client Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("client-offline")

    with NamedTimerInstance("Client Online"):
        prot.online(a_c)
        torch_sync()
    traffic_record.reset("client-online")

    max_s = max_s_blob.get_recv()
    check_correctness_online(a, max_s, prot.max_c)

    torch.cuda.empty_cache()
    end_communicate()
def test_client():
    """Client side of the secure FC test.

    Runs offline (with ``input_c``) and online, then sends ``prot.output_c``
    under the tag ``"output_c"``.
    """
    init_communicate(Config.client_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    prot = FcSecureClient(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)

    with NamedTimerInstance("Client Offline"):
        prot.offline(input_c)
        torch_sync()
    with NamedTimerInstance("Client Online"):
        prot.online()
        torch_sync()

    output_c_blob = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
    torch_sync()
    output_c_blob.send(prot.output_c)

    end_communicate()
def test_client():
    """Client side of the FC-FHE test.

    Runs only the offline phase (with ``input_mask``), then sends
    ``prot.output_c`` under the tag ``"output_c"``.
    """
    init_communicate(Config.client_rank)
    prot = FcFheClient(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)

    with NamedTimerInstance("Client Offline"):
        prot.offline(input_mask)
        torch_sync()

    output_c_blob = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
    torch_sync()
    output_c_blob.send(prot.output_c)

    end_communicate()