def test_swap_to_client_comm():
    """Run the swap-to-client protocol with one server and one client worker.

    Three random tensors are drawn up front: ``x_s`` is fed to the server's
    online phase, ``x_c`` to the client's online phase and ``y_c`` to the
    client's offline phase.  After both phases the client sends its output
    to the server, which checks that ``output_s + output_c`` equals
    ``x_s + x_c`` modulo ``modulus``.
    """
    test_name = "test_swap_to_client_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_elem = 2**17
    print(f"Number of element: {num_elem}")

    def draw_cpu_tensor():
        # One random draw from the project generator, moved to the CPU.
        return gen_unirand_int_grain(0, modulus - 1, num_elem).cpu()

    x_s = draw_cpu_tensor()
    x_c = draw_cpu_tensor()
    y_c = draw_cpu_tensor()

    def check_correctness_online(input_s, input_c, output_s, output_c):
        """Compare the recombined inputs with the recombined outputs."""
        input_sum = input_s.cuda() + input_c.cuda()
        expected = pmod(input_sum, modulus)
        output_sum = output_s.cuda() + output_c.cuda()
        actual = pmod(output_sum, modulus)
        compare_expected_actual(
            expected,
            actual,
            name=f"{test_name} online",
            get_relative=True,
        )

    def test_server():
        """Server worker: run both phases, receive output_c, then verify."""
        init_communicate(Config.server_rank)
        warming_up_cuda()
        protocol = SwapToClientOfflineServer(num_elem, modulus, test_name)

        with NamedTimerInstance("Server Offline"):
            protocol.offline()
            torch_sync()
        with NamedTimerInstance("Server online"):
            protocol.online(x_s)
            torch_sync()

        # Receive the client's output so the result can be recombined here.
        receiver = BlobTorch(num_elem, torch.float, protocol.comm_base,
                             "output_c")
        receiver.prepare_recv()
        torch_sync()
        client_output = receiver.get_recv()

        check_correctness_online(x_s, x_c, protocol.output_s, client_output)
        end_communicate()

    def test_client():
        """Client worker: run both phases and send output_c to the server."""
        init_communicate(Config.client_rank)
        warming_up_cuda()
        protocol = SwapToClientOfflineClient(num_elem, modulus, test_name)

        with NamedTimerInstance("Client Offline"):
            protocol.offline(y_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            protocol.online(x_c)
            torch_sync()

        sender = BlobTorch(num_elem, torch.float, protocol.comm_base,
                           "output_c")
        torch_sync()
        sender.send(protocol.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
def test_fhe_refresh():
    """Check ciphertext refreshing for a batched shape and for a 1-D shape.

    Two server/client rounds are run one after the other.  In each round the
    server encrypts a random tensor, asks the refresher for a refreshed
    ciphertext, decrypts it and compares it with the original tensor; the
    client only answers the refresh request.
    """
    print()
    print("Test: FHE refresh: Start")
    modulus, degree = 12289, 2048
    num_batch = 12
    num_elem = 2 ** 15
    fhe_builder = FheBuilder(modulus, degree)
    fhe_builder.generate_keys()

    # ---- Round 1: a list of ``num_batch`` ciphertexts --------------------
    def test_batch_server():
        init_communicate(Config.server_rank)
        batch_shape = [num_batch, num_elem]
        plain = gen_unirand_int_grain(0, modulus, batch_shape)
        refresher = EncRefresherServer(batch_shape, fhe_builder,
                                       "test_batch_refresher")
        # One ciphertext per row of the batch.
        ciphertexts = []
        for row in range(num_batch):
            ciphertexts.append(fhe_builder.build_enc_from_torch(plain[row]))
        refreshed = refresher.request(ciphertexts)
        recovered = fhe_builder.decrypt_to_torch(refreshed)
        compare_expected_actual(plain, recovered, get_relative=True,
                                name="batch refresh")
        end_communicate()

    def test_batch_client():
        init_communicate(Config.client_rank)
        batch_shape = [num_batch, num_elem]
        refresher = EncRefresherClient(batch_shape, fhe_builder,
                                       "test_batch_refresher")
        refresher.prepare_recv()
        refresher.response()
        end_communicate()

    marshal_funcs([test_batch_server, test_batch_client])

    # ---- Round 2: a single ciphertext of ``num_elem`` values -------------
    def test_1d_server():
        init_communicate(Config.server_rank)
        flat_shape = num_elem
        plain = gen_unirand_int_grain(0, modulus, flat_shape)
        refresher = EncRefresherServer(flat_shape, fhe_builder,
                                       "test_1d_refresher")
        ciphertext = fhe_builder.build_enc_from_torch(plain)
        refreshed = refresher.request(ciphertext)
        recovered = fhe_builder.decrypt_to_torch(refreshed)
        compare_expected_actual(plain, recovered, get_relative=True,
                                name="1d_refresh")
        end_communicate()

    def test_1d_client():
        init_communicate(Config.client_rank)
        flat_shape = num_elem
        refresher = EncRefresherClient(flat_shape, fhe_builder,
                                       "test_1d_refresher")
        refresher.prepare_recv()
        refresher.response()
        end_communicate()

    marshal_funcs([test_1d_server, test_1d_client])
    print()
    print("Test: FHE refresh: End")
def test_comm_base():
    """Exchange float/double/int16/uint8 tensors through ``CommBase``.

    The server sends int16 and uint8 tensors and receives float and double
    tensors; the client does the opposite.  Each pairing is exchanged twice:
    the first float/int16 round posts its receive inside the timed region,
    the later rounds post the receive and hit a barrier before timing
    starts.  Both sides finally compare what they received with the
    expected constant tensors.
    """
    num_elem = 2**17
    expect_float = torch.ones(num_elem).float() + 3
    expect_double = torch.ones(num_elem).double() + 5
    expect_int16 = torch.ones(num_elem).type(torch.int16) + 324
    expect_uint8 = torch.ones(num_elem).type(torch.uint8) + 123

    comm_name = "comm_base"
    float_tag = "float_tag"
    double_tag = "double_tag"
    int16_tag = "int16_tag"
    uint8_tag = "uint8_tag"

    def comm_base_server():
        init_communicate(Config.server_rank)
        comm_base = CommBase(Config.server_rank, comm_name)
        send_int16 = expect_int16.cuda()
        send_uint8 = expect_uint8.cuda()

        def timed_round(label, outgoing, out_tag, in_tag):
            # Send, wait for the already-posted receive, fetch it onto CUDA.
            with NamedTimerInstance(label):
                comm_base.send_torch(outgoing, out_tag)
                comm_base.wait(in_tag)
                incoming = comm_base.get_tensor(in_tag).cuda()
            return incoming

        # First round: the receive is posted inside the timed region.
        with NamedTimerInstance("Server float and int16"):
            comm_base.recv_torch(torch.zeros(num_elem).float(), float_tag)
            comm_base.send_torch(send_int16, int16_tag)
            comm_base.wait(float_tag)
            actual_float = comm_base.get_tensor(float_tag).cuda()

        comm_base.recv_torch(torch.zeros(num_elem).float(), float_tag)
        dist.barrier()
        actual_float = timed_round("Server float and int16", send_int16,
                                   int16_tag, float_tag)

        comm_base.recv_torch(torch.zeros(num_elem).double(), double_tag)
        dist.barrier()
        actual_double = timed_round("Server double and uint8", send_uint8,
                                    uint8_tag, double_tag)

        comm_base.recv_torch(torch.zeros(num_elem).double(), double_tag)
        dist.barrier()
        actual_double = timed_round("Server double and uint8", send_uint8,
                                    uint8_tag, double_tag)

        dist.barrier()
        compare_expected_actual(expect_float.cuda(), actual_float,
                                name="float", get_relative=True)
        compare_expected_actual(expect_double.cuda(), actual_double,
                                name="double", get_relative=True)

        # Time the network simulator on the same payloads.
        google_vm_simulator = NetworkSimulator(bandwidth=10 * (10**9),
                                               basis_latency=.001)
        with NamedTimerInstance("Simulate int16"):
            google_vm_simulator.simulate(send_int16.cpu().cuda())
        with NamedTimerInstance("Simulate uint8"):
            google_vm_simulator.simulate(send_uint8.cpu().cuda())

        dist.destroy_process_group()

    def comm_base_client():
        init_communicate(Config.client_rank)
        comm_base = CommBase(Config.client_rank, comm_name)
        send_float = expect_float.cuda()
        send_double = expect_double.cuda()

        def timed_round(label, outgoing, out_tag, in_tag):
            # Send, wait for the already-posted receive, fetch it onto CUDA.
            with NamedTimerInstance(label):
                comm_base.send_torch(outgoing, out_tag)
                comm_base.wait(in_tag)
                incoming = comm_base.get_tensor(in_tag).cuda()
            return incoming

        # First round: the receive is posted inside the timed region.
        with NamedTimerInstance("Client float and int16"):
            comm_base.recv_torch(torch.zeros(num_elem).type(torch.int16),
                                 int16_tag)
            comm_base.send_torch(send_float, float_tag)
            comm_base.wait(int16_tag)
            actual_int16 = comm_base.get_tensor(int16_tag).cuda()

        comm_base.recv_torch(torch.zeros(num_elem).type(torch.int16),
                             int16_tag)
        dist.barrier()
        actual_int16 = timed_round("Client float and int16", send_float,
                                   float_tag, int16_tag)

        comm_base.recv_torch(torch.zeros(num_elem).type(torch.uint8),
                             uint8_tag)
        dist.barrier()
        actual_uint8 = timed_round("Client double and uint8", send_double,
                                   double_tag, uint8_tag)

        comm_base.recv_torch(torch.zeros(num_elem).type(torch.uint8),
                             uint8_tag)
        dist.barrier()
        actual_uint8 = timed_round("Client double and uint8", send_double,
                                   double_tag, uint8_tag)

        dist.barrier()
        compare_expected_actual(expect_int16.cuda(), actual_int16,
                                name="int16", get_relative=True)
        compare_expected_actual(expect_uint8.cuda(), actual_uint8,
                                name="uint8", get_relative=True)

        dist.destroy_process_group()

    marshal_funcs([comm_base_server, comm_base_client])
def test_comm_fhe_builder():
    """Exchange FHE keys and ciphertexts between a server and a client.

    The client generates the keys.  The sequence is:

    1. client sends the public key; server builds it;
    2. server sends the tensor ``ori`` through a ``BlobFheEnc`` and the
       client decrypts and checks it;
    3. server encrypts ``expected_tensor_pk`` and sends it; client decrypts
       and checks it;
    4. client sends the secret key and an encryption of
       ``expected_tensor_sk``; server decrypts and checks it.
    """
    num_elem = 2**17
    # Alternative parameters: modulus, degree = 12289, 2048
    modulus, degree = Config.q_16, Config.n_16
    expected_tensor_pk = torch.zeros(num_elem).float() + 23
    expected_tensor_sk = torch.zeros(num_elem).float() + 42
    pk_tag = "pk"
    sk_tag = "sk"
    ori_tag = "ori"
    ori = generate_random_mask(modulus, num_elem)

    def test_server():
        init_communicate(Config.server_rank)
        fhe_builder = FheBuilder(modulus, degree)
        comm_fhe_builder = CommFheBuilder(Config.server_rank, fhe_builder,
                                          "fhe_builder")

        # Step 1: obtain the client's public key.
        comm_fhe_builder.recv_public_key()
        comm_fhe_builder.wait_and_build_public_key()

        # Step 2: ship ``ori`` via the blob helper.
        ori_blob = BlobFheEnc(num_elem, comm_fhe_builder, ori_tag)
        ori_blob.send_from_torch(ori)

        # Step 3: encrypt under the received public key and send.
        pk_cipher = fhe_builder.build_enc_from_torch(expected_tensor_pk)
        comm_fhe_builder.send_enc(pk_cipher, pk_tag)

        # Step 4: obtain the secret key, then receive and decrypt.
        comm_fhe_builder.recv_secret_key()
        comm_fhe_builder.wait_and_build_secret_key()
        sk_cipher = fhe_builder.build_enc(num_elem)
        comm_fhe_builder.recv_enc(sk_cipher, sk_tag)
        comm_fhe_builder.wait_enc(sk_tag)
        actual_tensor_sk = fhe_builder.decrypt_to_torch(sk_cipher)
        compare_expected_actual(expected_tensor_sk, actual_tensor_sk,
                                get_relative=True,
                                name="Recovering to test sk")

        dist.destroy_process_group()

    def test_client():
        init_communicate(Config.client_rank)
        fhe_builder = FheBuilder(modulus, degree)
        comm_fhe_builder = CommFheBuilder(Config.client_rank, fhe_builder,
                                          "fhe_builder")

        # Step 1: generate keys locally and publish the public key.
        fhe_builder.generate_keys()
        comm_fhe_builder.send_public_key()

        # Step 2: receive ``ori`` and verify its decryption.
        ori_blob = BlobFheEnc(num_elem, comm_fhe_builder, ori_tag)
        ori_blob.prepare_recv()
        decrypted_ori = ori_blob.get_recv_decrypt()
        compare_expected_actual(ori, decrypted_ori, get_relative=True,
                                name=ori_tag)

        # Step 3: receive the server's ciphertext and verify it.
        pk_cipher = fhe_builder.build_enc(num_elem)
        comm_fhe_builder.recv_enc(pk_cipher, pk_tag)
        comm_fhe_builder.wait_enc(pk_tag)
        actual_tensor_pk = fhe_builder.decrypt_to_torch(pk_cipher)
        compare_expected_actual(expected_tensor_pk, actual_tensor_pk,
                                get_relative=True,
                                name="Recovering to test pk")

        # Step 4: hand over the secret key plus a fresh ciphertext.
        comm_fhe_builder.send_secret_key()
        sk_cipher = fhe_builder.build_enc_from_torch(expected_tensor_sk)
        comm_fhe_builder.send_enc(sk_cipher, sk_tag)

        dist.destroy_process_group()

    marshal_funcs([test_server, test_client])
def test_maxpool2x2_dgk(input_sid, master_address, master_port, num_elem=2**17):
    """Run the 2x2 max-pool DGK protocol and check it against torch MaxPool2d.

    Args:
        input_sid: which side(s) to launch; compared against
            ``Config.both_rank``, ``Config.server_rank`` and
            ``Config.client_rank``.  Any other value launches nothing.
        master_address: forwarded to ``init_communicate``.
        master_port: forwarded to ``init_communicate``.
        num_elem: number of input elements; the pooled output blob has
            ``num_elem // 4`` elements.

    The server draws the image, splits it as ``img_s`` / ``img_c`` with
    ``img_c = pmod(img - img_s, q_23)``, sends ``img_c`` to the client and,
    after the protocol, recombines ``max_s + max_c`` for the comparison.
    """
    test_name = "Maxpool2x2 Dgk"
    print(f"\nTest for {test_name}: Start")
    data_bit = 20
    work_bit = 20
    data_range = 2**data_bit
    q_16 = Config.q_16
    # Alternative: q_23 = 786433
    q_23 = Config.q_23
    img_hw = 4
    print(f"Number of element: {num_elem}")

    def check_correctness_online(img, max_s, max_c):
        """Compare the recombined protocol output with a plain 2x2 max-pool."""
        # Values at or above q_23 // 2 are shifted down by q_23 before pooling.
        signed_img = torch.where(img < q_23 // 2, img, img - q_23)
        signed_img = signed_img.to(Config.device)
        pool = torch.nn.MaxPool2d(2)
        pooled = pool(signed_img.reshape([-1, img_hw, img_hw])).reshape(-1)
        expected = pmod(pooled, q_23)
        actual = pmod(max_s + max_c, q_23)
        compare_expected_actual(expected, actual,
                                name="maxpool2x2_dgk_online",
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        lower = -data_range // 2 + 1
        upper = data_range // 2
        img = gen_unirand_int_grain(lower, upper, num_elem)
        img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        img_c = pmod(img - img_s, q_23)

        prot = Maxpool2x2DgkServer(num_elem, q_23, q_16, work_bit, data_bit,
                                   img_hw, fhe_builder_16, fhe_builder_23,
                                   "max_dgk")

        # Hand the client its part of the image before the protocol starts.
        blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                               "recon_max_c")
        torch_sync()
        blob_img_c.send(img_c)
        torch_sync()

        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(img_s)
            torch_sync()
        traffic_record.reset("server-online")

        blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_max_c")
        blob_max_c.prepare_recv()
        torch_sync()
        max_c = blob_max_c.get_recv()
        check_correctness_online(img, prot.max_s, max_c)

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        prot = Maxpool2x2DgkClient(num_elem, q_23, q_16, work_bit, data_bit,
                                   img_hw, fhe_builder_16, fhe_builder_23,
                                   "max_dgk")

        blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                               "recon_max_c")
        blob_img_c.prepare_recv()
        torch_sync()
        img_c = blob_img_c.get_recv()
        torch_sync()

        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(img_c)
            torch_sync()
        traffic_record.reset("client-online")

        blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_max_c")
        torch_sync()
        blob_max_c.send(prot.max_c)

        end_communicate()

    # Pick the workers for the requested side; unknown ids launch nothing.
    if input_sid == Config.both_rank:
        workers = [test_server, test_client]
    elif input_sid == Config.server_rank:
        workers = [test_server]
    elif input_sid == Config.client_rank:
        workers = [test_client]
    else:
        workers = None
    if workers is not None:
        marshal_funcs(workers)

    print(f"\nTest for {test_name}: End")
def test_shares_mult():
    """Run the shares-multiplication protocol and verify both of its phases.

    ``a`` and ``b`` are random tensors split as ``a = a_s + a_c`` and
    ``b = b_s + b_c`` modulo ``modulus``.  After the offline and online
    phases the client sends its ``u_c``, ``v_c``, ``z_c`` and ``c_c`` values
    to the server, which checks:

    * offline triple: ``(u_s + u_c) * (v_s + v_c) == z_s + z_c`` (mod modulus),
      reported under the name ``shares_mult_offline``;
    * online product: ``a * b == c_s + c_c`` (mod modulus), reported under
      the name ``shares_mult_online``.
    """
    print("\nTest for Shares Mult: Start")
    modulus = Config.q_23
    num_elem = 2**17
    print(f"Number of element: {num_elem}")
    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    a = gen_unirand_int_grain(0, modulus - 1, num_elem)
    a_s = gen_unirand_int_grain(0, modulus - 1, num_elem)
    a_c = pmod(a - a_s, modulus)
    b = gen_unirand_int_grain(0, modulus - 1, num_elem)
    b_s = gen_unirand_int_grain(0, modulus - 1, num_elem)
    b_c = pmod(b - b_s, modulus)

    def check_correctness_offline(u, v, z_s, z_c):
        """Check that the offline values satisfy u * v == z_s + z_c (mod modulus)."""
        expected = pmod(
            u.double().to(Config.device) * v.double().to(Config.device),
            modulus)
        actual = pmod(z_s + z_c, modulus)
        compare_expected_actual(expected, actual,
                                name="shares_mult_offline",
                                get_relative=True)

    def check_correctness_online(a, b, c_s, c_c):
        """Check that the online output satisfies a * b == c_s + c_c (mod modulus)."""
        expected = pmod(
            a.double().to(Config.device) * b.double().to(Config.device),
            modulus)
        actual = pmod(c_s + c_c, modulus)
        compare_expected_actual(expected, actual,
                                name="shares_mult_online",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = SharesMultServer(num_elem, modulus, fhe_builder,
                                "test_shares_mult")

        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(a_s, b_s)
            torch_sync()

        # Collect the client's offline values to rebuild u, v and z.
        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_z_c")
        blob_u_c.prepare_recv()
        blob_v_c.prepare_recv()
        blob_z_c.prepare_recv()
        torch_sync()
        u_c = blob_u_c.get_recv()
        v_c = blob_v_c.get_recv()
        z_c = blob_z_c.get_recv()
        u = pmod(prot.u_s + u_c, modulus)
        v = pmod(prot.v_s + v_c, modulus)
        # This verifies the offline triple, so use the offline checker: the
        # comparison must be reported as "shares_mult_offline", not "online".
        check_correctness_offline(u, v, prot.z_s, z_c)

        # Collect the client's online output and verify the product.
        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        blob_c_c.prepare_recv()
        torch_sync()
        c_c = blob_c_c.get_recv()
        check_correctness_online(a, b, prot.c_s, c_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = SharesMultClient(num_elem, modulus, fhe_builder,
                                "test_shares_mult")

        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online(a_c, b_c)
            torch_sync()

        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_z_c")
        torch_sync()
        blob_u_c.send(prot.u_c)
        blob_v_c.send(prot.v_c)
        blob_z_c.send(prot.z_c)

        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        torch_sync()
        blob_c_c.send(prot.c_c)

        end_communicate()

    marshal_funcs([test_server, test_client])
    print("\nTest for Shares Mult: End")
def test_conv2d_secure_comm(input_sid, master_address, master_port, setting=(16, 3, 128, 128)):
    """Run the secure conv2d protocol and compare it with ``F.conv2d``.

    Args:
        input_sid: which side(s) to run; compared against
            ``Config.both_rank``, ``Config.server_rank`` and
            ``Config.client_rank``.  Any other value runs nothing.
        master_address: forwarded to ``init_communicate``.
        master_port: forwarded to ``init_communicate``.
        setting: ``(img_hw, filter_hw, num_input_channel,
            num_output_channel)``.

    The plain input is split as ``input_s = pmod(input - input_c, modulus)``;
    the server feeds ``input_s`` to the online phase and the client feeds
    ``input_c`` to the offline phase.  The server recombines
    ``output_s + output_c`` for the comparison.
    """
    test_name = "Conv2d Secure Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    padding = 1
    img_hw, filter_hw, num_input_channel, num_output_channel = setting
    data_bit = 17
    data_range = 2**data_bit
    # Alternative: n_23 = 8192
    n_23 = 16384
    print(f"Setting covn2d: img_hw: {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]
    b_shape = [num_output_channel]

    fhe_builder = FheBuilder(modulus, n_23)
    fhe_builder.generate_keys()

    def draw_data(shape):
        # Random tensor from the project generator, reshaped to ``shape``.
        flat = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                     get_prod(shape))
        return flat.reshape(shape)

    weight = draw_data(w_shape)
    bias = draw_data(b_shape)
    plain_input = draw_data(x_shape)
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(plain_input - input_c, modulus)

    def check_correctness_online(x, w, b, output_s, output_c):
        """Compare the recombined output with a plain double-precision conv."""
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        if b is not None:
            torch_b = b.cuda().double()
        else:
            torch_b = None
        conv_out = F.conv2d(torch_x, torch_w, padding=padding, bias=torch_b)
        expected = pmod(conv_out.reshape(output_s.shape), modulus)
        compare_expected_actual(expected, actual,
                                name=f"{test_name} online",
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureServer(modulus, fhe_builder, data_range, img_hw,
                                  filter_hw, num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm", padding=padding)

        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()

        receiver = BlobTorch(prot.output_shape, torch.float, prot.comm_base,
                             "output_c")
        receiver.prepare_recv()
        torch_sync()
        output_c = receiver.get_recv()
        check_correctness_online(plain_input, weight, bias, prot.output_s,
                                 output_c)

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureClient(modulus, fhe_builder, data_range, img_hw,
                                  filter_hw, num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm", padding=padding)

        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        sender = BlobTorch(prot.output_shape, torch.float, prot.comm_base,
                           "output_c")
        torch_sync()
        sender.send(prot.output_c)

        end_communicate()

    # Both sides go through marshal_funcs; a single side runs in-process.
    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()

    print(f"\nTest for {test_name}: End")
def test_conv2d_fhe_ntt_comm():
    """Run the offline FHE/NTT conv2d protocol and compare with ``F.conv2d``.

    The client feeds ``input_mask`` to its offline phase and sends the
    resulting ``output_c`` to the server.  The server checks that
    ``output_c - output_mask_s`` equals the plain convolution of
    ``input_mask`` with ``weight`` modulo ``modulus``.
    """
    test_name = "Conv2d Fhe NTT Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    img_hw = 2
    filter_hw = 3
    padding = 1
    num_input_channel = 512
    num_output_channel = 512
    data_bit = 17
    data_range = 2**data_bit
    print(f"Setting: img_hw {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    flat_weight = gen_unirand_int_grain(-data_range // 2 + 1,
                                        data_range // 2, get_prod(w_shape))
    weight = flat_weight.reshape(w_shape)
    flat_mask = gen_unirand_int_grain(0, modulus - 1, get_prod(x_shape))
    input_mask = flat_mask.reshape(x_shape)
    # Deterministic alternative for debugging:
    # input_mask = torch.arange(get_prod(x_shape)).reshape(x_shape)

    def check_correctness_offline(x, w, output_mask, output_c):
        """Compare the unmasked client output with a plain convolution."""
        actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        conv_out = F.conv2d(torch_x, torch_w, padding=padding)
        expected = pmod(conv_out.reshape(output_mask.shape), modulus)
        compare_expected_actual(expected, actual,
                                name=f"{test_name} offline",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = Conv2dFheNttServer(modulus, fhe_builder, data_range, img_hw,
                                  filter_hw, num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_fhe_ntt_comm",
                                  padding=padding)

        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()

        receiver = BlobTorch(prot.output_shape, torch.float, prot.comm_base,
                             "output_c")
        receiver.prepare_recv()
        torch_sync()
        output_c = receiver.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s,
                                  output_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = Conv2dFheNttClient(modulus, fhe_builder, data_range, img_hw,
                                  filter_hw, num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_fhe_ntt_comm",
                                  padding=padding)

        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()

        sender = BlobTorch(prot.output_shape, torch.float, prot.comm_base,
                           "output_c")
        torch_sync()
        sender.send(prot.output_c)

        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
def test_fc_secure_comm(input_sid, master_address, master_port, setting=(512, 512)):
    """Run the secure fully-connected protocol and compare with ``torch.mm``.

    Args:
        input_sid: which side(s) to launch.  ``Config.both_rank`` launches
            server and client together, ``Config.server_rank`` only the
            server, ``Config.client_rank`` only the client.
        master_address: forwarded to ``init_communicate``.
        master_port: forwarded to ``init_communicate``.
        setting: ``(num_input_unit, num_output_unit)``.

    Raises:
        Exception: if ``input_sid`` is none of the three known ranks.

    The plain input is split as ``input_s = pmod(input - input_c, modulus)``;
    the server recombines ``output_s + output_c`` and compares it with
    ``x @ w.T + bias`` modulo ``modulus``.
    """
    test_name = "test_fc_secure_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit, num_output_unit = setting
    data_bit = 17
    print(f"Setting fc: "
          f"num_input_unit: {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")
    data_range = 2**data_bit

    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]
    b_shape = [num_output_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                 get_prod(b_shape)).reshape(b_shape)
    # Named ``plain_input`` rather than ``input`` to avoid shadowing the builtin.
    plain_input = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                        get_prod(x_shape)).reshape(x_shape)
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(plain_input - input_c, modulus)

    def check_correctness_online(x, w, output_s, output_c):
        """Compare the recombined output with a plain matrix product (+ bias)."""
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = torch.mm(torch_x, torch_w.t())
        if bias is not None:
            expected += bias.cuda().double()
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected, actual,
                                name=test_name + " online",
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = FcSecureServer(modulus, data_range, num_input_unit,
                              num_output_unit, fhe_builder, test_name)

        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(plain_input, weight, prot.output_s, output_c)

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = FcSecureClient(modulus, data_range, num_input_unit,
                              num_output_unit, fhe_builder, test_name)

        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)

        end_communicate()

    # Honour ``input_sid`` like the other *_secure_comm tests do.  Launching
    # both workers unconditionally breaks two-machine runs, where each host
    # must start only its own side.
    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        marshal_funcs([test_server])
    elif input_sid == Config.client_rank:
        marshal_funcs([test_client])
    else:
        raise Exception(f"Unknown input_sid: {input_sid}")

    print(f"\nTest for {test_name}: End")
def test_fc_fhe_comm():
    """Run the offline FHE fully-connected protocol and compare with ``torch.mm``.

    The client feeds ``input_mask`` to its offline phase and sends the
    resulting ``output_c`` to the server.  The server checks that
    ``output_c - output_mask_s`` equals ``input_mask @ weight.T`` modulo
    ``modulus``.
    """
    test_name = "test_fc_fhe_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17
    print(f"Setting: num_input_unit {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")
    data_range = 2**data_bit

    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    flat_weight = gen_unirand_int_grain(-data_range // 2 + 1,
                                        data_range // 2, get_prod(w_shape))
    weight = flat_weight.reshape(w_shape)
    flat_mask = gen_unirand_int_grain(0, modulus - 1, get_prod(x_shape))
    input_mask = flat_mask.reshape(x_shape)
    # Deterministic alternative for debugging:
    # input_mask = torch.arange(get_prod(x_shape)).reshape(x_shape)

    def check_correctness_offline(x, w, output_mask, output_c):
        """Compare the unmasked client output with a plain matrix product."""
        actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        product = torch.mm(torch_x, torch_w.t())
        expected = pmod(product.reshape(output_mask.shape), modulus)
        compare_expected_actual(expected, actual,
                                name=f"{test_name} offline",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = FcFheServer(modulus, data_range, num_input_unit,
                           num_output_unit, fhe_builder, test_name)

        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()

        receiver = BlobTorch(prot.output_shape, torch.float, prot.comm_base,
                             "output_c")
        receiver.prepare_recv()
        torch_sync()
        output_c = receiver.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s,
                                  output_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = FcFheClient(modulus, data_range, num_input_unit,
                           num_output_unit, fhe_builder, test_name)

        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()

        sender = BlobTorch(prot.output_shape, torch.float, prot.comm_base,
                           "output_c")
        torch_sync()
        sender.send(prot.output_c)

        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
def secure_vgg(input_sid, master_addr, master_port, model_name_base="vgg_swalp"):
    """Run secure inference for one image and compare it with the plain net.

    Args:
        input_sid: which side(s) to launch; ``Config.server_rank``,
            ``Config.client_rank`` or ``Config.both_rank``.
        master_addr: forwarded to ``init_communication`` as
            ``master_address``.
        master_port: forwarded to ``init_communication``.
        model_name_base: selects the plain network, the secure network and
            the test dataset.

    Raises:
        Exception: if ``input_sid`` is unknown, if ``model_name_base`` has no
            plain/secure network, or (client side) if no test dataset is
            configured for ``model_name_base``.
    """
    test_name = "secure inference"
    print(f"\nTest for {test_name}: Start")
    batch_size = 1

    net_state_name, config_name = get_net_config_name(model_name_base)
    print(f"net_state going to load: {net_state_name}")
    print(f"store_configs going to load: {config_name}")

    # NOTE(security): allow_pickle unpickles arbitrary objects; only load
    # config files that come from a trusted source.
    store_configs = np.load(config_name, allow_pickle=True).item()

    def get_plain_net():
        """Build the plaintext reference net, load its weights, move to cuda:0."""
        if model_name_base == "vgg_swalp":
            net = ModulusNet(store_configs)
        elif model_name_base == "vgg_idc_swalp":
            net = ModulusNet(store_configs, 2)
        elif model_name_base == "vgg_cifar100":
            net = ModulusNet(store_configs, 100)
        elif model_name_base == "vgg16_cifar100":
            net = ModulusNet_vgg16(store_configs, 100)
        elif model_name_base == "vgg16_cifar10":
            net = ModulusNet_vgg16(store_configs, 10)
        elif model_name_base == "minionn_maxpool":
            net = ModulusNet_MiniONN(store_configs)
        else:
            # f-string so the offending model name actually appears.
            raise Exception(f"Unknown: {model_name_base}")
        net_state = torch.load(net_state_name)
        device = torch.device("cuda:0")
        net.load_weight_bias(net_state)
        net.to(device)
        return net

    def get_secure_nn():
        """Build the secure network matching ``model_name_base``."""
        if model_name_base == "vgg_swalp":
            return generate_secure_vgg(10)
        elif model_name_base == "vgg_idc_swalp":
            return generate_secure_vgg(2)
        elif model_name_base == "vgg_cifar100":
            return generate_secure_vgg(100)
        elif model_name_base == "vgg16_cifar100":
            return generate_secure_vgg16(100)
        elif model_name_base == "vgg16_cifar10":
            return generate_secure_vgg16(10)
        elif model_name_base == "minionn_cifar10":
            return generate_secure_minionn("avgpool")
        elif model_name_base == "minionn_maxpool":
            return generate_secure_minionn("maxpool")
        else:
            # f-string so the offending model name actually appears.
            raise Exception(f"Unknown: {model_name_base}")

    def check_correctness(self, input_img, output, modulus):
        """Compare the secure output with the plain net's output.

        Passed as a callback to ``secure_nn.check_correctness``; ``self`` is
        accepted for that call convention and is not used here.
        """
        plain_net = get_plain_net()
        expected = plain_net(
            input_img.reshape([1] + list(input_img.shape)).cuda())
        expected = nmod(expected.reshape(expected.shape[1:]), modulus)
        actual = nmod(output, modulus).cuda()
        print("expected", expected)
        print("actual", actual)
        compare_expected_actual(expected, actual, name="secure_vgg",
                                get_relative=True)
        _, expected_max = torch.max(expected, 0)
        _, actual_max = torch.max(actual, 0)
        print(
            f"expected_max: {expected_max}, actual_max: {actual_max}, Match: {expected_max == actual_max}"
        )

    def test_server():
        rank = Config.server_rank
        sys.stdout = Logger()
        traffic_record = TrafficRecord()
        secure_nn = get_secure_nn()
        secure_nn.set_rank(rank).init_communication(
            master_address=master_addr, master_port=master_port)
        warming_up_cuda()
        secure_nn.fhe_builder_sync()
        load_trunc_params(secure_nn, store_configs)

        net_state = torch.load(net_state_name)
        load_weight_params(secure_nn, store_configs, net_state)

        meta_rg = MetaTruncRandomGenerator()
        meta_rg.reset_seed()

        with NamedTimerInstance("Server Offline"):
            secure_nn.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            secure_nn.online()
            torch_sync()
        traffic_record.reset("server-online")

        secure_nn.check_correctness(check_correctness)
        secure_nn.check_layers(get_plain_net,
                               get_hooking_lst(model_name_base))
        secure_nn.end_communication()

    def test_client():
        rank = Config.client_rank
        sys.stdout = Logger()
        traffic_record = TrafficRecord()
        secure_nn = get_secure_nn()
        secure_nn.set_rank(rank).init_communication(
            master_address=master_addr, master_port=master_port)
        warming_up_cuda()
        secure_nn.fhe_builder_sync()
        load_trunc_params(secure_nn, store_configs)

        def input_shift(data):
            first_layer_name = "conv1"
            return data_shift(data,
                              store_configs[first_layer_name + "ForwardX"])

        def testset():
            """Return the test dataset that matches ``model_name_base``."""
            if model_name_base in ["vgg16_cifar100"]:
                dataset_cls = torchvision.datasets.CIFAR100
            elif model_name_base in ["vgg16_cifar10", "minionn_maxpool"]:
                dataset_cls = torchvision.datasets.CIFAR10
            else:
                # Fail here with a clear message instead of handing ``None``
                # to DataLoader.
                raise Exception(
                    f"No test dataset configured for: {model_name_base}")
            preprocess = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
                input_shift,
            ])
            return dataset_cls(root='./data', train=False, download=True,
                               transform=preprocess)

        testloader = torch.utils.data.DataLoader(testset(),
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 num_workers=2)
        data_iter = iter(testloader)
        image, truth = next(data_iter)
        image = image.reshape(secure_nn.get_input_shape())
        secure_nn.fill_input(image)

        with NamedTimerInstance("Client Offline"):
            secure_nn.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            secure_nn.online()
            torch_sync()
        traffic_record.reset("client-online")

        secure_nn.check_correctness(check_correctness, truth=truth)
        secure_nn.check_layers(get_plain_net,
                               get_hooking_lst(model_name_base))
        secure_nn.end_communication()

    if input_sid == Config.server_rank:
        marshal_funcs([test_server])
    elif input_sid == Config.client_rank:
        marshal_funcs([test_client])
    elif input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    else:
        raise Exception(f"Unknown input_sid: {input_sid}")

    print(f"\nTest for {test_name}: End")
def test_secure_nn():
    """Run a small secure network end to end and check it against plaintext.

    The network is input -> conv1 -> pool1 -> relu1 -> trunc1 -> conv2 ->
    flatten -> fc1 -> output.  The server and client halves are handed to
    ``marshal_funcs``; the client feeds a random image and, once the online
    phase is done, compares ``secure_nn.get_output()`` with the same pipeline
    evaluated in the clear modulo ``q_23``.
    """
    test_name = "test_secure_nn"
    print(f"\nTest for {test_name}: Start")

    data_bit = 5
    work_bit = 17
    data_range = 2**data_bit
    q_16 = 12289
    # Alternative q_23 values that were left as commented-out options:
    # 786433, 8273921.
    q_23 = 7340033
    input_img_hw = 16
    input_channel = 3
    pow_to_div = 2

    fhe_builder_16 = FheBuilder(q_16, 2048)
    fhe_builder_23 = FheBuilder(q_23, 8192)

    input_shape = [input_channel, input_img_hw, input_img_hw]

    context = SecureLayerContext(work_bit, data_bit, q_16, q_23,
                                 fhe_builder_16, fhe_builder_23,
                                 test_name + "_context")

    input_layer = InputSecureLayer(input_shape, "input_layer")
    conv1 = Conv2dSecureLayer(3, 5, 3, "conv1", padding=1)
    relu1 = ReluSecureLayer("relu1")
    trunc1 = TruncSecureLayer("trunc1")
    pool1 = Maxpool2x2SecureLayer("pool1")
    conv2 = Conv2dSecureLayer(5, 10, 3, "conv2", padding=1)
    flatten = FlattenSecureLayer("flatten")
    fc1 = FcSecureLayer(32, "fc1")
    output_layer = OutputSecureLayer("output_layer")

    secure_nn = SecureNeuralNetwork("secure_nn")
    secure_nn.load_layers([
        input_layer,
        conv1,
        pool1,
        relu1,
        trunc1,
        conv2,
        flatten,
        fc1,
        output_layer,
    ])
    secure_nn.load_context(context)

    def generate_random_data(shape):
        """Return a tensor of ``shape`` drawn from the signed data range."""
        flat = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                     get_prod(shape))
        return flat.reshape(shape)

    conv1_w = generate_random_data(conv1.weight_shape)
    conv2_w = generate_random_data(conv2.weight_shape)
    fc1_w = generate_random_data(fc1.weight_shape)

    def check_correctness(input_img, output):
        """Recompute the network in the clear and compare it with ``output``."""
        device = Config.device
        torch_pool1 = torch.nn.MaxPool2d(2)

        x = input_img.to(device).double()
        # conv2d wants a leading batch axis.
        x = x.reshape([1] + list(x.shape))
        x = pmod(F.conv2d(x, conv1_w.to(device).double(), padding=1), q_23)
        # The secure network lists pool1 before relu1; ReLU and max-pooling
        # are both monotone, so applying them in either order gives the same
        # values.
        x = pmod(F.relu(nmod(x, q_23)), q_23)
        x = pmod(torch_pool1(nmod(x, q_23)), q_23)
        x = pmod(x // (2**pow_to_div), q_23)
        x = pmod(F.conv2d(x, conv2_w.to(device).double(), padding=1), q_23)

        # Flatten straight to a single row for the fully connected product;
        # the result is viewed back to 1-D, so no batch axis is left to strip.
        fc_out = torch.mm(x.view(1, -1), fc1_w.to(device).double().t())
        expected = pmod(fc_out.view(-1), q_23)
        actual = pmod(output, q_23)

        compare_expected_actual(expected, actual, name=test_name,
                                get_relative=True)

    def test_server():
        """Server half: receive public keys, load weights, run both phases."""
        rank = Config.server_rank
        init_communicate(rank)
        context.set_rank(rank)

        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        conv1.load_weight(conv1_w)
        conv2.load_weight(conv2_w)
        fc1.load_weight(fc1_w)
        trunc1.set_div_to_pow(pow_to_div)

        with NamedTimerInstance("Server Offline"):
            secure_nn.offline()
            torch_sync()

        with NamedTimerInstance("Server Online"):
            secure_nn.online()
            torch_sync()

        end_communicate()

    def test_client():
        """Client half: generate and send keys, feed the input, verify output."""
        rank = Config.client_rank
        init_communicate(rank)
        context.set_rank(rank)

        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        input_img = generate_random_data(input_shape)
        trunc1.set_div_to_pow(pow_to_div)
        secure_nn.feed_input(input_img)

        with NamedTimerInstance("Client Offline"):
            secure_nn.offline()
            torch_sync()

        with NamedTimerInstance("Client Online"):
            secure_nn.online()
            torch_sync()

        check_correctness(input_img, secure_nn.get_output())
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
def test_avgpool2x2_dgk(input_sid, master_address, master_port, num_elem=2**17):
    """Exercise the 2x2 average-pooling protocol on additively shared input.

    A random image ``img`` is split into ``img_s`` (server) and ``img_c``
    (client) with ``img_s + img_c == img (mod q_23)``.  After both parties run
    their offline and online phases, the client sends its output share to the
    server, which reconstructs the result and compares it with a plaintext
    reference.

    Args:
        input_sid: which side(s) to run; one of ``Config.both_rank``,
            ``Config.server_rank`` or ``Config.client_rank``.
        master_address: forwarded to ``init_communicate``.
        master_port: forwarded to ``init_communicate``.
        num_elem: number of input elements; the output blobs are sized
            ``num_elem // 4``.

    Raises:
        ValueError: if ``input_sid`` is none of the three known ranks.
    """
    # Reject an unknown id before any keys are generated; previously it fell
    # through every branch and the test "finished" without running anything.
    if input_sid not in (Config.both_rank, Config.server_rank, Config.client_rank):
        raise ValueError(f"Unknown input_sid: {input_sid}")

    test_name = "Avgpool2x2"
    print(f"\nTest for {test_name}: Start")

    data_bit = 20
    work_bit = 20
    data_range = 2**data_bit
    q_16 = 12289
    # Alternative q_23 that was left as a commented-out option: 786433.
    q_23 = 7340033
    img_hw = 4
    print(f"Number of element: {num_elem}")

    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_16.generate_keys()
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    fhe_builder_23.generate_keys()

    img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    img_c = pmod(img - img_s, q_23)

    def check_correctness_online(img, out_s, out_c):
        """Compare the reconstructed shares with a plaintext pooling result."""
        # Map residues at or above q_23 // 2 to their negative representatives.
        img = torch.where(img < q_23 // 2, img, img - q_23).cuda()
        pool = torch.nn.AvgPool2d(2)
        windows = img.double().reshape([-1, img_hw, img_hw])
        # AvgPool2d(2) averages each 2x2 window; multiplying by 4 turns the
        # mean back into the window sum, which is what the reference checks.
        expected = pmod(pool(windows).reshape(-1) * 4, q_23)
        actual = pmod(out_s + out_c, q_23)
        compare_expected_actual(expected, actual, name=test_name + "_online",
                                get_relative=True)

    def test_server():
        """Server half: run the protocol, receive the client share, verify."""
        init_communicate(Config.server_rank, master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Server(num_elem, q_23, q_16, work_bit, data_bit, img_hw,
                                fhe_builder_16, fhe_builder_23, "avgpool")

        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(img_s)
            torch_sync()
        traffic_record.reset("server-online")

        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_res_c")
        blob_out_c.prepare_recv()
        torch_sync()
        out_c = blob_out_c.get_recv()
        check_correctness_online(img, prot.out_s, out_c)

        end_communicate()

    def test_client():
        """Client half: run the protocol and ship the output share to the server."""
        init_communicate(Config.client_rank, master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Client(num_elem, q_23, q_16, work_bit, data_bit, img_hw,
                                fhe_builder_16, fhe_builder_23, "avgpool")

        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(img_c)
            torch_sync()
        traffic_record.reset("client-online")

        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_res_c")
        torch_sync()
        blob_out_c.send(prot.out_c)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()

    print(f"\nTest for {test_name}: End")
def test_relu_dgk(input_sid, master_address, master_port, num_elem=2**17):
    """Exercise the DGK-based ReLU protocol on additively shared input.

    The client draws a random ``a``, splits it into ``a_c`` and ``a_s`` with
    ``a_s + a_c == a (mod q_23)``, and sends ``a_s`` to the server.  After the
    offline and online phases the server sends back its output share and the
    client checks that the two shares reconstruct ``max(a, 0)`` modulo
    ``q_23``.

    Args:
        input_sid: which side(s) to run; one of ``Config.both_rank``,
            ``Config.server_rank`` or ``Config.client_rank``.
        master_address: forwarded to ``init_communicate``.
        master_port: forwarded to ``init_communicate``.
        num_elem: number of elements in the shared vector.

    Raises:
        ValueError: if ``input_sid`` is none of the three known ranks.
    """
    # An unrecognised id used to skip every branch and still print "End".
    if input_sid not in (Config.both_rank, Config.server_rank, Config.client_rank):
        raise ValueError(f"Unknown input_sid: {input_sid}")

    test_name = "Relu Dgk"
    print(f"\nTest for {test_name}: Start")

    data_bit = 20
    work_bit = 20
    data_range = 2 ** data_bit
    # Both moduli come from Config; hard-coded values (q_16 = 12289,
    # q_23 = 786433 / 8273921) were left as commented-out alternatives.
    q_16 = Config.q_16
    q_23 = Config.q_23

    print(f"Number of element: {num_elem}")

    def check_correctness_online(a, c_s, c_c):
        """Check that ``c_s + c_c`` reconstructs ``max(a, 0)`` modulo ``q_23``."""
        a = a.to(Config.device)
        expected = pmod(torch.max(a, torch.zeros_like(a)), q_23)
        actual = pmod(c_s + c_c, q_23)
        compare_expected_actual(expected, actual, name="relu_dgk_online",
                                get_relative=True)

    def test_server():
        """Server half: receive keys and the input share, run, return the result share."""
        rank = Config.server_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        prot = ReluDgkServer(num_elem, q_23, q_16, work_bit, data_bit,
                             fhe_builder_16, fhe_builder_23, "relu_dgk")

        blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
        blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
        torch_sync()
        blob_a_s.prepare_recv()
        a_s = blob_a_s.get_recv()
        torch_sync()

        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(a_s)
            torch_sync()
        traffic_record.reset("server-online")

        blob_max_s.send(prot.max_s)
        torch.cuda.empty_cache()
        end_communicate()

    def test_client():
        """Client half: send keys and the server's share, run, verify the result."""
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        a = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
        a_c = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        a_s = pmod(a - a_c, q_23)

        prot = ReluDgkClient(num_elem, q_23, q_16, work_bit, data_bit,
                             fhe_builder_16, fhe_builder_23, "relu_dgk")

        blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
        blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
        torch_sync()
        blob_a_s.send(a_s)
        blob_max_s.prepare_recv()
        torch_sync()

        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(a_c)
            torch_sync()
        traffic_record.reset("client-online")

        max_s = blob_max_s.get_recv()
        check_correctness_online(a, max_s, prot.max_c)

        torch.cuda.empty_cache()
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        marshal_funcs([test_server])
    elif input_sid == Config.client_rank:
        marshal_funcs([test_client])

    print(f"\nTest for {test_name}: End")
def test_dgk(input_sid, master_address, master_port, num_elem=2**17):
    """Exercise the basic DGK bit-comparison protocol.

    The client draws ``x`` and ``y``, builds additive shares of ``y - x``
    modulo ``q_23`` and sends ``x``, ``y`` and the server's share across.
    After the offline and online phases the client sends its result shares
    (and ``z``) to the server, which checks the reconstructed ``x <= y`` bit
    and the mod-div correction term.

    Args:
        input_sid: which side(s) to run; one of ``Config.both_rank``,
            ``Config.server_rank`` or ``Config.client_rank``.
        master_address: forwarded to ``init_communicate``.
        master_port: forwarded to ``init_communicate``.
        num_elem: number of elements compared.

    Raises:
        ValueError: if ``input_sid`` is none of the three known ranks.
    """
    # An unrecognised id used to skip every branch and still print "End".
    if input_sid not in (Config.both_rank, Config.server_rank, Config.client_rank):
        raise ValueError(f"Unknown input_sid: {input_sid}")

    print("\nTest for Dgk Basic: Start")
    data_bit = 20
    work_bit = 20
    # Parameters come from Config.  Hard-coded alternatives that were left
    # commented out: q_23 in {786433, 8273921, 4079617}, (n_23, q_23) =
    # (8192, 7340033), (n_16, q_16) in {(2048, 12289), (8192, 65537)}.
    n_23, q_23 = Config.n_23, Config.q_23
    n_16, q_16 = Config.n_16, Config.q_16
    print(f"Number of element: {num_elem}")

    data_range = 2**data_bit
    work_range = 2**work_bit

    def check_correctness(x, y, dgk_x_leq_y_s, dgk_x_leq_y_c):
        """Check that the two shares reconstruct the bit ``x <= y``."""
        # Map residues at or above q_23 // 2 to their negative representatives.
        x = torch.where(x < q_23 // 2, x, x - q_23).to(Config.device)
        y = torch.where(y < q_23 // 2, y, y - q_23).to(Config.device)
        expected_x_leq_y = (x <= y)
        dgk_x_leq_y_recon = pmod(dgk_x_leq_y_s + dgk_x_leq_y_c, q_23)
        compare_expected_actual(expected_x_leq_y, dgk_x_leq_y_recon,
                                name="DGK x <= y", get_relative=True)
        # Number of mismatching positions.
        print(torch.sum(expected_x_leq_y != dgk_x_leq_y_recon))

    def check_correctness_mod_div(r, z, correct_mod_div_work_s,
                                  correct_mod_div_work_c):
        """Check the correction shares: ``q_23 // work_range`` where ``r > z``, else 0."""
        elem_zeros = torch.zeros(num_elem).to(Config.device)
        expected = torch.where(r > z, q_23 // work_range + elem_zeros, elem_zeros)
        actual = pmod(correct_mod_div_work_s + correct_mod_div_work_c, q_23)
        compare_expected_actual(expected, actual, get_relative=True,
                                name="mod_div_online")

    def test_server():
        """Server half: receive keys and inputs, run, then verify both results."""
        rank = Config.server_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()

        # Use the same local degrees as the client so that editing n_16/n_23
        # above cannot leave the two parties with different parameters.
        fhe_builder_16 = FheBuilder(q_16, n_16)
        fhe_builder_23 = FheBuilder(q_23, n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        dgk = DgkBitServer(num_elem, q_23, q_16, work_bit, data_bit,
                           fhe_builder_16, fhe_builder_23, "DgkBitTest")

        x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
        y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
        y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                   "y_sub_x_s")
        x_blob.prepare_recv()
        y_blob.prepare_recv()
        y_sub_x_s_blob.prepare_recv()
        torch_sync()
        x = x_blob.get_recv()
        y = y_blob.get_recv()
        y_sub_x_s = y_sub_x_s_blob.get_recv()

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            dgk.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            dgk.online(y_sub_x_s)
        traffic_record.reset("server-online")

        dgk_x_leq_y_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                       "dgk_x_leq_y_c")
        correct_mod_div_work_c_blob = BlobTorch(num_elem, torch.float,
                                                dgk.comm_base,
                                                "correct_mod_div_work_c")
        z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
        dgk_x_leq_y_c_blob.prepare_recv()
        correct_mod_div_work_c_blob.prepare_recv()
        z_blob.prepare_recv()
        torch_sync()
        dgk_x_leq_y_c = dgk_x_leq_y_c_blob.get_recv()
        correct_mod_div_work_c = correct_mod_div_work_c_blob.get_recv()
        z = z_blob.get_recv()

        check_correctness(x, y, dgk.dgk_x_leq_y_s, dgk_x_leq_y_c)
        check_correctness_mod_div(dgk.r, z, dgk.correct_mod_div_work_s,
                                  correct_mod_div_work_c)
        end_communicate()

    def test_client():
        """Client half: send keys and inputs, run, then ship the result shares."""
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, n_16)
        fhe_builder_23 = FheBuilder(q_23, n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        dgk = DgkBitClient(num_elem, q_23, q_16, work_bit, data_bit,
                           fhe_builder_16, fhe_builder_23, "DgkBitTest")

        x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
        y = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
        x_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
        y_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
        x_s = pmod(x - x_c, q_23)
        y_s = pmod(y - y_c, q_23)
        y_sub_x_s = pmod(y_s - x_s, q_23)

        x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
        y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
        y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                   "y_sub_x_s")
        torch_sync()
        x_blob.send(x)
        y_blob.send(y)
        y_sub_x_s_blob.send(y_sub_x_s)

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            dgk.offline()
        # Client's share of y - x, fed to the online phase below.
        y_sub_x_c = pmod(y_c - x_c, q_23)
        traffic_record.reset("client-offline")
        torch_sync()

        with NamedTimerInstance("Client Online"):
            dgk.online(y_sub_x_c)
        traffic_record.reset("client-online")

        dgk_x_leq_y_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                       "dgk_x_leq_y_c")
        correct_mod_div_work_c_blob = BlobTorch(num_elem, torch.float,
                                                dgk.comm_base,
                                                "correct_mod_div_work_c")
        z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
        torch_sync()
        dgk_x_leq_y_c_blob.send(dgk.dgk_x_leq_y_c)
        correct_mod_div_work_c_blob.send(dgk.correct_mod_div_work_c)
        z_blob.send(dgk.z)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        marshal_funcs([test_server])
    elif input_sid == Config.client_rank:
        marshal_funcs([test_client])

    print("\nTest for Dgk Basic: End")