def test_server():
    """Server side of an FHE key round-trip test.

    Flow (matches the call order below):
      1. Receive and build the peer's public key.
      2. Send ``ori`` encrypted via ``BlobFheEnc`` and send ``enc_pk``, a
         ciphertext built from ``expected_tensor_pk`` under that key.
      3. Receive and build the peer's secret key, receive ciphertext
         ``enc_sk`` and check that it decrypts to ``expected_tensor_sk``.

    NOTE(review): ``modulus``, ``degree``, ``num_elem``, ``ori``, ``ori_tag``,
    ``expected_tensor_pk``, ``pk_tag``, ``sk_tag`` and ``expected_tensor_sk``
    are free names here -- presumably bound by an enclosing test; confirm.
    NOTE(review): transferring a secret key is presumably test-only.
    """
    init_communicate(Config.server_rank)
    fhe_builder = FheBuilder(modulus, degree)
    comm_fhe_builder = CommFheBuilder(Config.server_rank, fhe_builder, "fhe_builder")
    # Public key must be built before anything can be encrypted for the peer.
    comm_fhe_builder.recv_public_key()
    comm_fhe_builder.wait_and_build_public_key()
    blob_ori = BlobFheEnc(num_elem, comm_fhe_builder, ori_tag)
    blob_ori.send_from_torch(ori)
    enc_pk = fhe_builder.build_enc_from_torch(expected_tensor_pk)
    comm_fhe_builder.send_enc(enc_pk, pk_tag)
    # Secret key is needed for decrypt_to_torch below.
    comm_fhe_builder.recv_secret_key()
    comm_fhe_builder.wait_and_build_secret_key()
    enc_sk = fhe_builder.build_enc(num_elem)
    comm_fhe_builder.recv_enc(enc_sk, sk_tag)
    comm_fhe_builder.wait_enc(sk_tag)
    actual_tensor_sk = fhe_builder.decrypt_to_torch(enc_sk)
    compare_expected_actual(expected_tensor_sk, actual_tensor_sk, get_relative=True, name="Recovering to test sk")
    dist.destroy_process_group()
def test_fhe_refresh():
    """Test ciphertext refreshing with EncRefresherServer / EncRefresherClient.

    Two server/client pairs are run through ``marshal_funcs``:
      * batch: ``num_batch`` ciphertexts, plaintext shape ``[num_batch, num_elem]``;
      * 1-D: a single ciphertext of ``num_elem`` elements.
    In both cases the server encrypts random data, calls ``refresher.request``,
    decrypts the result and compares it with the original plaintext. The
    client only serves the refresh (``prepare_recv`` + ``response``).
    """
    print()
    print("Test: FHE refresh: Start")
    modulus, degree = 12289, 2048
    num_batch = 12
    num_elem = 2 ** 15
    # One key set, captured by every nested server/client function below.
    fhe_builder = FheBuilder(modulus, degree)
    fhe_builder.generate_keys()

    def test_batch_server():
        init_communicate(Config.server_rank)
        shape = [num_batch, num_elem]
        tensor = gen_unirand_int_grain(0, modulus, shape)
        refresher = EncRefresherServer(shape, fhe_builder, "test_batch_refresher")
        # One ciphertext per row of the batch.
        enc = [fhe_builder.build_enc_from_torch(tensor[i]) for i in range(num_batch)]
        refreshed = refresher.request(enc)
        tensor_refreshed = fhe_builder.decrypt_to_torch(refreshed)
        compare_expected_actual(tensor, tensor_refreshed, get_relative=True, name="batch refresh")
        end_communicate()

    def test_batch_client():
        init_communicate(Config.client_rank)
        shape = [num_batch, num_elem]
        # Tag must match the server's refresher name.
        refresher = EncRefresherClient(shape, fhe_builder, "test_batch_refresher")
        refresher.prepare_recv()
        refresher.response()
        end_communicate()

    marshal_funcs([test_batch_server, test_batch_client])

    def test_1d_server():
        init_communicate(Config.server_rank)
        shape = num_elem
        tensor = gen_unirand_int_grain(0, modulus, shape)
        refresher = EncRefresherServer(shape, fhe_builder, "test_1d_refresher")
        enc = fhe_builder.build_enc_from_torch(tensor)
        refreshed = refresher.request(enc)
        tensor_refreshed = fhe_builder.decrypt_to_torch(refreshed)
        compare_expected_actual(tensor, tensor_refreshed, get_relative=True, name="1d_refresh")
        end_communicate()

    def test_1d_client():
        init_communicate(Config.client_rank)
        shape = num_elem
        refresher = EncRefresherClient(shape, fhe_builder, "test_1d_refresher")
        refresher.prepare_recv()
        refresher.response()
        end_communicate()

    marshal_funcs([test_1d_server, test_1d_client])
    print()
    print("Test: FHE refresh: End")
def test_conv2d_fhe_ntt_single_thread():
    """Compare single-threaded FHE/NTT 2-D convolution with ``F.conv2d``.

    A random image is encoded into FHE batches and convolved with random
    signed weights. The output is masked and decoded, ``prot.output_mask_s``
    is subtracted, and the result is checked (mod ``q``) against a plaintext
    PyTorch convolution of the same data.
    """
    q = 786433
    img_hw, filter_hw, padding = 16, 3, 1
    in_channels, out_channels = 64, 128
    data_bit = 17
    data_range = 2 ** data_bit
    img_shape = [in_channels, img_hw, img_hw]
    kernel_shape = [out_channels, in_channels, filter_hw, filter_hw]

    builder = FheBuilder(q, Config.n_23)
    builder.generate_keys()

    # Weights are signed and bounded by data_range; the image is drawn from [0, q].
    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(kernel_shape)).reshape(kernel_shape)
    image = gen_unirand_int_grain(0, q, get_prod(img_shape)).reshape(img_shape)

    warming_up_cuda()
    prot = Conv2dFheNttSingleThread(q, data_range, img_hw, filter_hw, in_channels, out_channels,
                                    builder, "test_conv2d_fhe_ntt", padding)
    for attr in ("num_input_batch", "num_output_batch"):
        print(f"prot.{attr}", getattr(prot, attr))

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(image)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_conv2d(weight)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        decoded = prot.decode_output_from_fhe_batch()
    # Remove the server-side mask before comparing with the plaintext result.
    actual = pmod(decoded - prot.output_mask_s, q)

    batched_image = image.reshape([1] + img_shape).double()
    kernel = weight.reshape(kernel_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        reference = F.conv2d(batched_image, kernel, padding=padding)
    expected = pmod(reference.reshape(prot.output_shape), q)
    compare_expected_actual(expected, actual, name="test_conv2d_fhe_ntt_single_thread", get_relative=True)
def test_fc_fhe_single_thread():
    """Compare the single-threaded FHE fully-connected layer with ``torch.mm``.

    A random input vector is encoded into FHE batches and multiplied by a
    random signed weight matrix. The output is masked and decoded,
    ``prot.output_mask_s`` is subtracted, and the result is checked
    (mod ``modulus``) against ``x @ w.T`` computed in plaintext.
    """
    test_name = "test_fc_fhe_single_thread"
    print(f"\nTest for {test_name}: Start")
    modulus = Config.q_23
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17
    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]
    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # Weights are signed and bounded by data_range; the input is drawn from [0, modulus].
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)

    warming_up_cuda()
    prot = FcFheSingleThread(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)
    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)
    print("prot.num_elem_in_piece", prot.num_elem_in_piece)

    # Timer labels previously said "conv2d" (copied from the conv2d test).
    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("fc with w"):
        prot.compute_with_weight(w)
    with NamedTimerInstance("fc masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    # Remove the server-side mask. (A dead `actual = pmod(y, modulus)` store was removed.)
    actual = pmod(y - prot.output_mask_s, modulus)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Fc Torch"):
        expected = torch.mm(torch_x, torch_w.t())
    expected = pmod(expected.reshape(prot.output_shape), modulus)
    compare_expected_actual(expected, actual, name=test_name, get_relative=True)
    print(f"\nTest for {test_name}: End")
def test_server():
    """Server side of the Maxpool2x2DgkServer / Maxpool2x2DgkClient test.

    The server waits for the client's two public keys (q_16 and q_23
    builders). It generates a random image, secret-shares it
    (img = img_s + img_c mod q_23) and sends img_c to the client. It then
    runs the offline and online phases, recording traffic for each.
    Finally it receives the client's output share and checks it with
    ``check_correctness_online``.

    NOTE(review): ``master_address``, ``master_port``, ``q_16``, ``q_23``,
    ``data_range``, ``num_elem``, ``work_bit``, ``data_bit``, ``img_hw`` and
    ``check_correctness_online`` are free names -- presumably module/enclosing
    scope; confirm.
    """
    rank = Config.server_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    traffic_record = TrafficRecord()
    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
    comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
    torch_sync()
    # Keys are generated on the client; the server only receives public keys.
    comm_fhe_16.recv_public_key()
    comm_fhe_23.recv_public_key()
    comm_fhe_16.wait_and_build_public_key()
    comm_fhe_23.wait_and_build_public_key()

    # Secret-share the plaintext image: img ≡ img_s + img_c (mod q_23).
    img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    img_c = pmod(img - img_s, q_23)
    prot = Maxpool2x2DgkServer(num_elem, q_23, q_16, work_bit, data_bit, img_hw, fhe_builder_16, fhe_builder_23, "max_dgk")
    # NOTE(review): this blob reuses tag "recon_max_c", which is used again for
    # the output share below. The client uses the same tags, so the two sides
    # agree, but a distinct tag would be clearer.
    blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_max_c")
    torch_sync()
    blob_img_c.send(img_c)
    torch_sync()
    with NamedTimerInstance("Server Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("server-offline")
    with NamedTimerInstance("Server Online"):
        prot.online(img_s)
        torch_sync()
    traffic_record.reset("server-online")

    # The output has num_elem // 4 entries (presumably one per 2x2 window).
    blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base, "recon_max_c")
    blob_max_c.prepare_recv()
    torch_sync()
    max_c = blob_max_c.get_recv()
    check_correctness_online(img, prot.max_s, max_c)
    end_communicate()
def __init__(self, name, data_bit=5, work_bit=18, n_16=Config.n_16, n_23=Config.n_23,
             q_16=Config.q_16, q_23=Config.q_23, input_hw=32, input_channel=3):
    """Initialise the secure-NN framework configuration.

    Args:
        name: forwarded to the parent initializer; child component names are
            derived from it with ``self.sub_name``.
        data_bit: data bit width; ``data_range`` is set to ``2 ** data_bit``.
        work_bit: passed through to ``SecureLayerContext``.
        n_16, n_23: polynomial degrees for the two ``FheBuilder`` instances.
        q_16, q_23: moduli for the two ``FheBuilder`` instances.
        input_hw, input_channel: stored unchanged.
    """
    super().__init__(name)
    # Plain configuration values.
    self.data_bit, self.work_bit = data_bit, work_bit
    self.q_16, self.q_23 = q_16, q_23
    self.input_hw, self.input_channel = input_hw, input_channel
    self.data_range = 2 ** data_bit

    # Two FHE builders (one per modulus), shared with the layer context.
    self.fhe_builder_16 = FheBuilder(q_16, n_16)
    self.fhe_builder_23 = FheBuilder(q_23, n_23)
    context_name = self.sub_name("context")
    self.context = SecureLayerContext(work_bit, data_bit, q_16, q_23,
                                      self.fhe_builder_16, self.fhe_builder_23, context_name)
    self.secure_nn_core = SecureNeuralNetwork(self.sub_name("secure_nn"))
    self.layer_dict = {}
def test_server():
    """Server side of the ReluDgkServer / ReluDgkClient test.

    The server receives the client's public keys and then receives its input
    share ``a_s`` from the client (the client generates the plaintext). It
    runs the offline and online phases, recording traffic for each, and sends
    ``prot.max_s`` back so the client can check correctness.

    NOTE(review): ``master_address``, ``master_port``, ``q_16``, ``q_23``,
    ``num_elem``, ``work_bit`` and ``data_bit`` are free names -- presumably
    module/enclosing scope; confirm.
    """
    rank = Config.server_rank
    init_communicate(Config.server_rank, master_address=master_address, master_port=master_port)
    traffic_record = TrafficRecord()
    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
    comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
    torch_sync()
    # Keys are generated on the client; the server only receives public keys.
    comm_fhe_16.recv_public_key()
    comm_fhe_23.recv_public_key()
    comm_fhe_16.wait_and_build_public_key()
    comm_fhe_23.wait_and_build_public_key()
    prot = ReluDgkServer(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "relu_dgk")
    blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
    blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
    torch_sync()
    blob_a_s.prepare_recv()
    a_s = blob_a_s.get_recv()
    torch_sync()
    with NamedTimerInstance("Server Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("server-offline")
    with NamedTimerInstance("Server Online"):
        prot.online(a_s)
        torch_sync()
    traffic_record.reset("server-online")
    # Send the server's output share; the client reconstructs and verifies.
    blob_max_s.send(prot.max_s)
    torch.cuda.empty_cache()
    end_communicate()
def test_client():
    """Client side of the Maxpool2x2DgkServer / Maxpool2x2DgkClient test.

    The client generates both FHE key sets and sends the public keys to the
    server. It receives its image share ``img_c`` and runs the offline and
    online phases, recording traffic for each. Finally it sends its output
    share ``prot.max_c`` to the server for verification.

    NOTE(review): ``master_address``, ``master_port``, ``q_16``, ``q_23``,
    ``num_elem``, ``work_bit``, ``data_bit`` and ``img_hw`` are free names --
    presumably module/enclosing scope; confirm.
    """
    rank = Config.client_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    traffic_record = TrafficRecord()
    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    # The client owns the keys; only the public keys are sent to the server.
    fhe_builder_16.generate_keys()
    fhe_builder_23.generate_keys()
    comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
    comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
    torch_sync()
    comm_fhe_16.send_public_key()
    comm_fhe_23.send_public_key()
    prot = Maxpool2x2DgkClient(num_elem, q_23, q_16, work_bit, data_bit, img_hw, fhe_builder_16, fhe_builder_23, "max_dgk")
    # Tag must match the server's blob ("recon_max_c" is reused for the output too).
    blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_max_c")
    blob_img_c.prepare_recv()
    torch_sync()
    img_c = blob_img_c.get_recv()
    torch_sync()
    with NamedTimerInstance("Client Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("client-offline")
    with NamedTimerInstance("Client Online"):
        prot.online(img_c)
        torch_sync()
    traffic_record.reset("client-online")
    blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base, "recon_max_c")
    torch_sync()
    blob_max_c.send(prot.max_c)
    end_communicate()
def test_client():
    """Client side of the ReluDgkServer / ReluDgkClient test.

    The client generates both FHE key sets and sends the public keys. It
    generates a random plaintext ``a``, secret-shares it
    (a ≡ a_s + a_c mod q_23) and sends ``a_s`` to the server. It then runs
    the offline and online phases with ``a_c``, receives the server's output
    share and checks correctness with ``check_correctness_online``.

    NOTE(review): ``master_address``, ``master_port``, ``q_16``, ``q_23``,
    ``data_range``, ``num_elem``, ``work_bit``, ``data_bit`` and
    ``check_correctness_online`` are free names -- presumably
    module/enclosing scope; confirm.
    """
    rank = Config.client_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    traffic_record = TrafficRecord()
    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    # The client owns the keys; only the public keys are sent to the server.
    fhe_builder_16.generate_keys()
    fhe_builder_23.generate_keys()
    comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
    comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
    torch_sync()
    comm_fhe_16.send_public_key()
    comm_fhe_23.send_public_key()
    # Secret-share the plaintext: a ≡ a_s + a_c (mod q_23).
    a = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    a_c = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    a_s = pmod(a - a_c, q_23)
    prot = ReluDgkClient(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "relu_dgk")
    blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
    blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
    torch_sync()
    blob_a_s.send(a_s)
    # Post the receive early; the server sends max_s after its online phase.
    blob_max_s.prepare_recv()
    torch_sync()
    with NamedTimerInstance("Client Offline"):
        prot.offline()
        torch_sync()
    traffic_record.reset("client-offline")
    with NamedTimerInstance("Client Online"):
        prot.online(a_c)
        torch_sync()
    traffic_record.reset("client-online")
    max_s = blob_max_s.get_recv()
    check_correctness_online(a, max_s, prot.max_c)
    torch.cuda.empty_cache()
    end_communicate()
class SecureNnFramework(NamedBase, ContextRankBase):
    """Driver that wires secure layers, FHE builders and communication together.

    Typical fluent use: ``load_layers`` -> ``set_rank`` ->
    ``init_communication`` -> ``fhe_builder_sync`` -> fill weights (server)
    or input (client) -> ``offline`` -> ``online`` -> checks ->
    ``end_communication``. Most methods return ``self``.
    """
    class_name = "SecureNnFramework"
    layers: List[SecureLayerBase]
    layer_dict: Dict[str, SecureLayerBase]
    input_img: torch.Tensor
    output_res: torch.Tensor
    comm_base: CommBase
    # Sent in place of the ground-truth label when the client has none
    # (class labels are non-negative, so -1 cannot collide with a real label).
    _NO_TRUTH = -1

    def __init__(self, name, data_bit=5, work_bit=18,
                 n_16=Config.n_16, n_23=Config.n_23, q_16=Config.q_16, q_23=Config.q_23,
                 input_hw=32, input_channel=3):
        """Store configuration and build the FHE builders, layer context and core network.

        ``n_16``/``n_23`` are the polynomial degrees and ``q_16``/``q_23`` the
        moduli of the two ``FheBuilder`` instances; ``data_range`` is
        ``2 ** data_bit``.
        """
        super().__init__(name)
        self.data_bit = data_bit
        self.work_bit = work_bit
        self.q_16 = q_16
        self.q_23 = q_23
        self.input_hw = input_hw
        self.input_channel = input_channel
        self.data_range = 2 ** data_bit
        self.fhe_builder_16 = FheBuilder(q_16, n_16)
        self.fhe_builder_23 = FheBuilder(q_23, n_23)
        self.context = SecureLayerContext(work_bit, data_bit, q_16, q_23, self.fhe_builder_16, self.fhe_builder_23,
                                          self.sub_name("context"))
        self.secure_nn_core = SecureNeuralNetwork(self.sub_name("secure_nn"))
        self.layer_dict = {}

    def generate_random_data(self, shape):
        """Return random integers bounded by ``(-data_range//2 + 1, data_range//2)``, reshaped to ``shape``."""
        return gen_unirand_int_grain(-self.data_range//2 + 1, self.data_range//2, get_prod(shape)).reshape(shape)

    def load_layers(self, layers):
        """Register ``layers`` (indexed by ``layer.name``) and load them plus the context into the core network."""
        self.layers = layers
        for layer in layers:
            self.layer_dict[layer.name] = layer
        self.secure_nn_core.load_layers(layers)
        self.secure_nn_core.load_context(self.context)
        return self

    def get_input_shape(self):
        """Output shape of the core network's input layer."""
        return self.secure_nn_core.input_layer.get_output_shape()

    def get_output_shape(self):
        """Output shape of the core network's output layer."""
        return self.secure_nn_core.output_layer.get_output_shape()

    def set_rank(self, rank):
        """Set this party's rank on the framework and on the layer context."""
        self.rank = rank
        self.context.set_rank(rank)
        return self

    def init_communication(self, **kwargs):
        """Initialise communication for ``self.rank`` (kwargs forwarded) and create ``comm_base``."""
        init_communicate(self.rank, **kwargs)
        self.comm_base = CommBase(self.rank, self.sub_name("comm_base"))
        return self

    def fhe_builder_sync(self):
        """Exchange FHE public keys: the client generates and sends them, the server receives and builds them."""
        comm_fhe_16 = CommFheBuilder(self.rank, self.fhe_builder_16, self.sub_name("fhe_builder_16"))
        comm_fhe_23 = CommFheBuilder(self.rank, self.fhe_builder_23, self.sub_name("fhe_builder_23"))
        torch_sync()
        if self.is_server():
            comm_fhe_16.recv_public_key()
            comm_fhe_23.recv_public_key()
            comm_fhe_16.wait_and_build_public_key()
            comm_fhe_23.wait_and_build_public_key()
        elif self.is_client():
            self.fhe_builder_16.generate_keys()
            self.fhe_builder_23.generate_keys()
            comm_fhe_16.send_public_key()
            comm_fhe_23.send_public_key()
        torch_sync()
        return self

    def fill_random_weight(self):
        """Server only: load random weights into every layer whose ``has_weight`` is true."""
        assert(self.is_server())
        for layer in self.layers:
            if not layer.has_weight:
                continue
            layer.load_weight(self.generate_random_data(layer.weight_shape))
        return self

    def fill_input(self, input_img):
        """Client only: remember ``input_img`` and feed it to the core network."""
        assert(self.is_client())
        self.input_img = input_img
        self.secure_nn_core.feed_input(input_img)
        return self

    def fill_random_input(self):
        """Client only: feed a random input of the network's input shape."""
        assert(self.is_client())
        self.fill_input(self.generate_random_data(self.get_input_shape()))
        return self

    def offline(self):
        """Run the core network's offline phase."""
        self.secure_nn_core.offline()
        return self

    def online(self):
        """Run the core network's online phase."""
        self.secure_nn_core.online()
        return self

    def check_layers(self, get_plain_net_func, all_layer_pair):
        """Compare each secure layer's reconstructed output with a plaintext network.

        Both parties reconstruct, on the server, the outputs of the input layer
        and of every secure layer named in ``all_layer_pair`` (pairs of
        ``(secure_layer_name, plain_layer_name)``). The server then runs
        ``get_plain_net_func()`` on the reconstructed input, captures the
        named plain layers' outputs with forward hooks, and compares them.
        """
        secure_input_layer = self.layers[0]
        secure_input_layer.reconstructed_to_server(self.comm_base, self.q_23)
        for secure_layer_name, plain_layer_name in all_layer_pair:
            secure_layer = self.layer_dict[secure_layer_name]
            secure_layer.reconstructed_to_server(self.comm_base, self.q_23)

        if self.is_server():
            plain_net = get_plain_net_func()
            plain_output_dict = {}

            def hook_generator(name):
                def hook(module, input, output):
                    plain_output_dict[name] = output.data.detach().clone()
                return hook

            hook_handles = []
            for _, plain_layer_name in all_layer_pair:
                plain_layer = getattr(plain_net, plain_layer_name)
                hook_handles.append(plain_layer.register_forward_hook(hook_generator(plain_layer_name)))

            input_img = secure_input_layer.get_reconstructed_output()
            input_img = input_img.reshape([1] + list(input_img.shape)).cuda()
            meta_rg = MetaTruncRandomGenerator()
            meta_rg.reset_rg("plain")
            try:
                # The forward pass populates plain_output_dict through the hooks.
                plain_net(input_img)
            finally:
                # Detach hooks so repeated calls on a reused plain net do not pile them up.
                for handle in hook_handles:
                    handle.remove()

            for secure_layer_name, plain_layer_name in all_layer_pair:
                print("secure_layer_name, plain_layer_name: %s, %s"%(secure_layer_name, plain_layer_name))
                plain_output = plain_output_dict[plain_layer_name]
                secure_layer = self.layer_dict[secure_layer_name]
                secure_output = secure_layer.get_reconstructed_output()
                compare_expected_actual(plain_output, secure_output,
                                        name=f"compare secure-plain: {secure_layer_name}, {plain_layer_name}",
                                        get_relative=True)
        return self

    def check_correctness(self, verify_func, is_argmax=False, truth=None):
        """Send the client's input, output and label to the server and verify there.

        The client sends its input image, the network output (argmax output
        if ``is_argmax``) and ``truth``. When ``truth`` is None a sentinel is
        sent instead of crashing in ``torch.tensor(None)``. The server calls
        ``verify_func(self, input_img, actual_output, q_23)`` and prints the
        argmax of the output next to the received label.
        """
        blob_input_img = BlobTorch(self.get_input_shape(), torch.float, self.comm_base, "input_img")
        blob_actual_output = BlobTorch(self.get_output_shape(), torch.float, self.comm_base, "actual_output")
        blob_truth = BlobTorch(1, torch.float, self.comm_base, "truth")
        if self.is_server():
            blob_input_img.prepare_recv()
            blob_actual_output.prepare_recv()
            blob_truth.prepare_recv()
            torch_sync()
            input_img = blob_input_img.get_recv()
            actual_output = blob_actual_output.get_recv()
            truth = int(blob_truth.get_recv().item())
            if truth == self._NO_TRUTH:
                truth = None
            verify_func(self, input_img, actual_output, self.q_23)
            actual_output = nmod(actual_output, self.q_23).cuda()
            _, actual_max = torch.max(actual_output, 0)
            match_truth = (truth == actual_max) if truth is not None else None
            print(f"truth: {truth}, actual: {actual_max}, MatchTruth: {match_truth}")
        if self.is_client():
            torch_sync()
            actual_output = self.secure_nn_core.get_argmax_output() if is_argmax else self.secure_nn_core.get_output()
            blob_input_img.send(self.input_img)
            blob_actual_output.send(actual_output)
            sent_truth = self._NO_TRUTH if truth is None else truth
            blob_truth.send(torch.tensor(sent_truth))
        return self

    def end_communication(self):
        """Tear down communication."""
        end_communicate()
        return self
def test_shares_mult():
    """Test SharesMultServer / SharesMultClient (multiplication of secret shares).

    Two random vectors ``a`` and ``b`` are secret-shared mod ``Config.q_23``.
    After the offline and online phases the server verifies two things:
      * offline triple: (u_s + u_c) * (v_s + v_c) ≡ z_s + z_c;
      * online result: a * b ≡ c_s + c_c.
    Previously the offline triple was checked with the online helper and
    logged as "shares_mult_online"; it now uses the "shares_mult_offline" label.
    """
    print("\nTest for Shares Mult: Start")
    modulus = Config.q_23
    num_elem = 2**17
    print(f"Number of element: {num_elem}")
    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # a ≡ a_s + a_c and b ≡ b_s + b_c (mod modulus).
    a = gen_unirand_int_grain(0, modulus - 1, num_elem)
    a_s = gen_unirand_int_grain(0, modulus - 1, num_elem)
    a_c = pmod(a - a_s, modulus)
    b = gen_unirand_int_grain(0, modulus - 1, num_elem)
    b_s = gen_unirand_int_grain(0, modulus - 1, num_elem)
    b_c = pmod(b - b_s, modulus)

    def check_product(x, y, share_s, share_c, name):
        """Check x * y ≡ share_s + share_c (mod modulus)."""
        expected = pmod(x.double().to(Config.device) * y.double().to(Config.device), modulus)
        actual = pmod(share_s + share_c, modulus)
        compare_expected_actual(expected, actual, name=name, get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = SharesMultServer(num_elem, modulus, fhe_builder, "test_shares_mult")
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(a_s, b_s)
            torch_sync()

        # Collect the client's shares of the offline triple (u, v, z).
        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_z_c")
        blob_u_c.prepare_recv()
        blob_v_c.prepare_recv()
        blob_z_c.prepare_recv()
        torch_sync()
        u_c = blob_u_c.get_recv()
        v_c = blob_v_c.get_recv()
        z_c = blob_z_c.get_recv()
        u = pmod(prot.u_s + u_c, modulus)
        v = pmod(prot.v_s + v_c, modulus)
        check_product(u, v, prot.z_s, z_c, "shares_mult_offline")

        # Collect the client's share of the online product.
        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        blob_c_c.prepare_recv()
        torch_sync()
        c_c = blob_c_c.get_recv()
        check_product(a, b, prot.c_s, c_c, "shares_mult_online")
        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = SharesMultClient(num_elem, modulus, fhe_builder, "test_shares_mult")
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online(a_c, b_c)
            torch_sync()
        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base, "recon_z_c")
        torch_sync()
        blob_u_c.send(prot.u_c)
        blob_v_c.send(prot.v_c)
        blob_z_c.send(prot.z_c)
        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        torch_sync()
        blob_c_c.send(prot.c_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print("\nTest for Shares Mult: End")
def test_conv2d_secure_comm(input_sid, master_address, master_port, setting=(16, 3, 128, 128)):
    """End-to-end test of Conv2dSecureServer / Conv2dSecureClient.

    Args:
        input_sid: which side(s) to run in this process. ``Config.both_rank``
            runs both through ``marshal_funcs``; ``Config.server_rank`` or
            ``Config.client_rank`` runs only that side. Any other value runs
            nothing (only the Start/End banners are printed).
        master_address, master_port: forwarded to ``init_communicate``.
        setting: ``(img_hw, filter_hw, num_input_channel, num_output_channel)``.

    The server checks that ``output_s + output_c`` ≡ conv2d(input, weight, bias)
    (mod ``modulus``).
    """
    test_name = "Conv2d Secure Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    padding = 1
    img_hw, filter_hw, num_input_channel, num_output_channel = setting
    data_bit = 17
    data_range = 2**data_bit
    n_23 = 16384
    print(f"Setting conv2d: img_hw: {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")
    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]
    b_shape = [num_output_channel]
    fhe_builder = FheBuilder(modulus, n_23)
    fhe_builder.generate_keys()

    # NOTE(review): keys, weights and input are generated independently in each
    # process that calls this function. When server and client run as separate
    # processes they only agree if the random generators are seeded
    # identically -- confirm.
    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(b_shape)).reshape(b_shape)
    # Named plain_input (not `input`) to avoid shadowing the builtin.
    plain_input = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(x_shape)).reshape(x_shape)
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(plain_input - input_c, modulus)

    def check_correctness_online(x, w, b, output_s, output_c):
        """Check output_s + output_c ≡ conv2d(x, w, b) (mod modulus)."""
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        torch_b = b.cuda().double() if b is not None else None
        expected = F.conv2d(torch_x, torch_w, padding=padding, bias=torch_b)
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected, actual, name=test_name + " online", get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank, master_address=master_address, master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureServer(modulus, fhe_builder, data_range, img_hw, filter_hw,
                                  num_input_channel, num_output_channel, "test_conv2d_secure_comm", padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(plain_input, weight, bias, prot.output_s, output_c)
        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address, master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureClient(modulus, fhe_builder, data_range, img_hw, filter_hw,
                                  num_input_channel, num_output_channel, "test_conv2d_secure_comm", padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()
    print(f"\nTest for {test_name}: End")
def test_conv2d_fhe_ntt_comm():
    """Offline-phase test of Conv2dFheNttServer / Conv2dFheNttClient.

    The client runs its offline phase on ``input_mask`` and the server on
    ``weight``. The server then receives ``prot.output_c`` from the client
    and checks ``output_c - output_mask_s`` ≡ conv2d(input_mask, weight)
    (mod ``modulus``).
    """
    test_name = "Conv2d Fhe NTT Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    img_hw = 2
    filter_hw = 3
    padding = 1
    num_input_channel = 512
    num_output_channel = 512
    data_bit = 17
    data_range = 2**data_bit
    print(f"Setting: img_hw {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")
    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]
    # One key set, captured by both nested sides below.
    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()
    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    input_mask = gen_unirand_int_grain(0, modulus - 1, get_prod(x_shape)).reshape(x_shape)
    # input_mask = torch.arange(get_prod(x_shape)).reshape(x_shape)

    def check_correctness_offline(x, w, output_mask, output_c):
        # Check output_c - output_mask ≡ conv2d(x, w) (mod modulus).
        actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = F.conv2d(torch_x, torch_w, padding=padding)
        expected = pmod(expected.reshape(output_mask.shape), modulus)
        compare_expected_actual(expected, actual, name=test_name + " offline", get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = Conv2dFheNttServer(modulus, fhe_builder, data_range, img_hw, filter_hw, num_input_channel, num_output_channel, "test_conv2d_fhe_ntt_comm", padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s, output_c)
        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = Conv2dFheNttClient(modulus, fhe_builder, data_range, img_hw, filter_hw, num_input_channel, num_output_channel, "test_conv2d_fhe_ntt_comm", padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
def test_server():
    """Server side of the DgkBitServer test.

    The server receives the client's public keys. For verification it also
    receives the plaintext ``x`` and ``y`` and its share ``y_sub_x_s`` from
    the client. It runs the offline and online phases with traffic
    accounting, then receives the client's result shares and ``z`` and
    checks them with ``check_correctness`` and ``check_correctness_mod_div``.

    NOTE(review): ``master_address``, ``master_port``, ``q_16``, ``q_23``,
    ``num_elem``, ``work_bit``, ``data_bit``, ``check_correctness`` and
    ``check_correctness_mod_div`` are free names -- presumably
    module/enclosing scope; confirm.
    """
    rank = Config.server_rank
    init_communicate(Config.server_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    traffic_record = TrafficRecord()
    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
    comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
    torch_sync()
    # Keys are generated on the client; the server only receives public keys.
    comm_fhe_16.recv_public_key()
    comm_fhe_23.recv_public_key()
    comm_fhe_16.wait_and_build_public_key()
    comm_fhe_23.wait_and_build_public_key()
    dgk = DgkBitServer(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "DgkBitTest")
    # x and y are plaintexts used only for the correctness check below.
    x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
    y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
    y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y_sub_x_s")
    x_blob.prepare_recv()
    y_blob.prepare_recv()
    y_sub_x_s_blob.prepare_recv()
    torch_sync()
    x = x_blob.get_recv()
    y = y_blob.get_recv()
    y_sub_x_s = y_sub_x_s_blob.get_recv()
    torch_sync()
    with NamedTimerInstance("Server Offline"):
        dgk.offline()
    # y_sub_x_s = pmod(y_s.to(Config.device) - x_s.to(Config.device), q_23)
    torch_sync()
    traffic_record.reset("server-offline")
    # NOTE(review): unlike the offline phase, there is no torch_sync() around
    # the online phase, so its timing may not include outstanding GPU work.
    with NamedTimerInstance("Server Online"):
        dgk.online(y_sub_x_s)
    traffic_record.reset("server-online")
    dgk_x_leq_y_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "dgk_x_leq_y_c")
    correct_mod_div_work_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "correct_mod_div_work_c")
    z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
    dgk_x_leq_y_c_blob.prepare_recv()
    correct_mod_div_work_c_blob.prepare_recv()
    z_blob.prepare_recv()
    torch_sync()
    dgk_x_leq_y_c = dgk_x_leq_y_c_blob.get_recv()
    correct_mod_div_work_c = correct_mod_div_work_c_blob.get_recv()
    z = z_blob.get_recv()
    check_correctness(x, y, dgk.dgk_x_leq_y_s, dgk_x_leq_y_c)
    check_correctness_mod_div(dgk.r, z, dgk.correct_mod_div_work_s, correct_mod_div_work_c)
    end_communicate()
def test_fc_secure_comm(input_sid, master_address, master_port, setting=(512, 512)):
    """End-to-end test of FcSecureServer / FcSecureClient.

    Args:
        input_sid: which side(s) to run in this process. ``Config.both_rank``
            runs both through ``marshal_funcs``; ``Config.server_rank`` or
            ``Config.client_rank`` runs only that side. Any other value runs
            nothing. This matches ``test_conv2d_secure_comm``; previously the
            argument was ignored and both sides always ran.
        master_address, master_port: forwarded to ``init_communicate``.
        setting: ``(num_input_unit, num_output_unit)``.

    The server checks that ``output_s + output_c`` ≡ input @ weight.T + bias
    (mod ``modulus``).
    """
    test_name = "test_fc_secure_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit, num_output_unit = setting
    data_bit = 17
    print(f"Setting fc: "
          f"num_input_unit: {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")
    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]
    b_shape = [num_output_unit]
    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(b_shape)).reshape(b_shape)
    # Named plain_input (not `input`) to avoid shadowing the builtin.
    plain_input = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(x_shape)).reshape(x_shape)
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(plain_input - input_c, modulus)

    def check_correctness_online(x, w, b, output_s, output_c):
        """Check output_s + output_c ≡ x @ w.T (+ b) (mod modulus)."""
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = torch.mm(torch_x, torch_w.t())
        if b is not None:
            expected += b.cuda().double()
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected, actual, name=test_name + " online", get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank, master_address=master_address, master_port=master_port)
        warming_up_cuda()
        prot = FcSecureServer(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(plain_input, weight, bias, prot.output_s, output_c)
        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address, master_port=master_port)
        warming_up_cuda()
        prot = FcSecureClient(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()
    print(f"\nTest for {test_name}: End")
def test_fc_fhe_comm(setting=(512, 512)):
    """Run the FHE-based offline phase of the fully-connected layer and check it.

    The server runs ``FcFheServer.offline`` with its weight and the client
    runs ``FcFheClient.offline`` with a random input mask. The client then
    sends its ``output_c`` to the server, which checks
    ``output_c - output_mask_s == input_mask @ weight.T`` (mod ``modulus``).

    Args:
        setting: ``(num_input_unit, num_output_unit)`` of the layer. The
            default is the previously hard-coded ``(512, 512)``, matching
            ``test_fc_secure_comm``.
    """
    test_name = "test_fc_fhe_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit, num_output_unit = setting
    data_bit = 17

    print(f"Setting: num_input_unit {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # Weight in the signed data range; the input mask is drawn from [0, modulus - 1].
    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    input_mask = gen_unirand_int_grain(0, modulus - 1, get_prod(x_shape)).reshape(x_shape)

    def check_correctness_offline(x, w, output_mask, output_c):
        """Check that the client share minus the server mask equals x @ w.T mod modulus."""
        actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = torch.mm(torch_x, torch_w.t())
        expected = pmod(expected.reshape(output_mask.shape), modulus)
        compare_expected_actual(expected, actual, name=test_name + " offline", get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = FcFheServer(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()

        # Receive the client's output share and verify it against the server mask.
        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s, output_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = FcFheClient(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float, prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)

        end_communicate()

    marshal_funcs([test_server, test_client])

    print(f"\nTest for {test_name}: End")
def test_secure_nn():
    """End-to-end test of a small secure network run by a server and a client.

    Builds the layer pipeline
    input -> conv1 -> pool1 -> relu1 -> trunc1 -> conv2 -> flatten -> fc1 -> output.
    The server loads random weights and the client feeds a random input
    image. After the offline and online phases, the client compares
    ``secure_nn.get_output()`` with a plaintext reference computed modulo
    ``q_23``.
    """
    test_name = "test_secure_nn"
    print(f"\nTest for {test_name}: Start")
    data_bit = 5
    work_bit = 17
    data_range = 2**data_bit
    q_16 = 12289
    # q_23 = 786433
    q_23 = 7340033
    # q_23 = 8273921
    input_img_hw = 16
    input_channel = 3
    # trunc1 divides by 2**pow_to_div; the reference below does the same with `//`.
    pow_to_div = 2

    fhe_builder_16 = FheBuilder(q_16, 2048)
    fhe_builder_23 = FheBuilder(q_23, 8192)

    input_shape = [input_channel, input_img_hw, input_img_hw]

    context = SecureLayerContext(work_bit, data_bit, q_16, q_23, fhe_builder_16, fhe_builder_23, test_name + "_context")

    input_layer = InputSecureLayer(input_shape, "input_layer")
    conv1 = Conv2dSecureLayer(3, 5, 3, "conv1", padding=1)
    relu1 = ReluSecureLayer("relu1")
    trunc1 = TruncSecureLayer("trunc1")
    pool1 = Maxpool2x2SecureLayer("pool1")
    conv2 = Conv2dSecureLayer(5, 10, 3, "conv2", padding=1)
    flatten = FlattenSecureLayer("flatten")
    fc1 = FcSecureLayer(32, "fc1")
    output_layer = OutputSecureLayer("output_layer")

    secure_nn = SecureNeuralNetwork("secure_nn")
    # Note: pool1 comes before relu1 here, while check_correctness applies ReLU
    # before max-pooling. The two orders give the same result because ReLU is
    # monotone non-decreasing, so it commutes with max.
    secure_nn.load_layers([
        input_layer, conv1, pool1, relu1, trunc1, conv2, flatten, fc1, output_layer
    ])
    # secure_nn.load_layers([input_layer, relu1, trunc1, output_layer])
    secure_nn.load_context(context)

    def generate_random_data(shape):
        # Uniform integers in the signed data range, reshaped to `shape`.
        return gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(shape)).reshape(shape)

    conv1_w = generate_random_data(conv1.weight_shape)
    conv2_w = generate_random_data(conv2.weight_shape)
    fc1_w = generate_random_data(fc1.weight_shape)

    def check_correctness(input_img, output):
        """Recompute the network in plaintext (mod q_23) and compare with `output`."""
        torch_pool1 = torch.nn.MaxPool2d(2)

        x = input_img.to(Config.device).double()
        # Add a batch dimension of 1 for F.conv2d.
        x = x.reshape([1] + list(x.shape))
        x = pmod(F.conv2d(x, conv1_w.to(Config.device).double(), padding=1), q_23)
        # nmod maps to the signed representative so ReLU / max see signed values.
        x = pmod(F.relu(nmod(x, q_23)), q_23)
        x = pmod(torch_pool1(nmod(x, q_23)), q_23)
        x = pmod(x // (2**pow_to_div), q_23)
        x = pmod(F.conv2d(x, conv2_w.to(Config.device).double(), padding=1), q_23)
        x = x.view(-1)
        x = pmod(
            torch.mm(x.view(1, -1), fc1_w.to(Config.device).double().t()).view(-1),
            q_23)

        expected = x
        actual = pmod(output, q_23)
        # NOTE(review): `expected` is 1-D after `.view(-1)` above, so this branch
        # never runs; it looks like a leftover from a conv-only variant.
        if len(expected.shape) == 4 and expected.shape[0] == 1:
            expected = expected.reshape(expected.shape[1:])
        compare_expected_actual(expected, actual, name=test_name, get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank)
        context.set_rank(rank)

        # The server does not generate keys; it receives the client's public keys.
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        # The server side owns the model weights.
        conv1.load_weight(conv1_w)
        conv2.load_weight(conv2_w)
        fc1.load_weight(fc1_w)
        trunc1.set_div_to_pow(pow_to_div)

        with NamedTimerInstance("Server Offline"):
            secure_nn.offline()
            torch_sync()

        with NamedTimerInstance("Server Online"):
            secure_nn.online()
            torch_sync()

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank)
        context.set_rank(rank)

        # The client generates both key pairs and sends the public keys to the server.
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        # The client side owns the input image.
        input_img = generate_random_data(input_shape)
        trunc1.set_div_to_pow(pow_to_div)
        secure_nn.feed_input(input_img)

        with NamedTimerInstance("Client Offline"):
            secure_nn.offline()
            torch_sync()

        with NamedTimerInstance("Client Online"):
            secure_nn.online()
            torch_sync()

        check_correctness(input_img, secure_nn.get_output())
        end_communicate()

    marshal_funcs([test_server, test_client])

    print(f"\nTest for {test_name}: End")
def test_avgpool2x2_dgk(input_sid, master_address, master_port, num_elem=2**17):
    """Run the secure 2x2 average-pool protocol and check the reconstructed output.

    The plaintext image ``img`` is additively shared as
    ``img_s + img_c == img (mod q_23)``. The server runs ``Avgpool2x2Server``
    on ``img_s`` and the client runs ``Avgpool2x2Client`` on ``img_c``. The
    client sends its output share (``num_elem // 4`` values) to the server,
    which compares the sum of the shares with the 2x2 window *sum*
    (AvgPool2d * 4) mod ``q_23``.

    Args:
        input_sid: ``Config.both_rank`` runs both parties via ``marshal_funcs``;
            ``Config.server_rank`` / ``Config.client_rank`` run only that party.
        master_address: Address forwarded to ``init_communicate``.
        master_port: Port forwarded to ``init_communicate``.
        num_elem: Total number of input elements. The correctness check
            reshapes them to ``[-1, img_hw, img_hw]``, so this presumably must
            be a multiple of ``img_hw ** 2`` (16) -- confirm.
    """
    test_name = "Avgpool2x2"
    print(f"\nTest for {test_name}: Start")
    data_bit = 20
    work_bit = 20
    data_range = 2**data_bit
    q_16 = 12289
    # q_23 = 786433
    q_23 = 7340033
    img_hw = 4

    print(f"Number of element: {num_elem}")

    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_16.generate_keys()
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    fhe_builder_23.generate_keys()

    img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    # Additive sharing of img mod q_23: img_s goes to the server, img_c to the client.
    img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    img_c = pmod(img - img_s, q_23)

    def check_correctness_online(img, max_s, max_c):
        # `max_s` / `max_c` are the server/client *avgpool* output shares
        # (the names look inherited from the maxpool test).
        # With data_bit = 20 every img value is below q_23 // 2, so this
        # torch.where keeps img unchanged; it effectively only moves it to CUDA.
        img = torch.where(img < q_23 // 2, img, img - q_23).cuda()
        pool = torch.nn.AvgPool2d(2)
        # AvgPool2d(2) averages 4 values; * 4 turns it into the 2x2 window sum.
        expected = pool(img.double().reshape([-1, img_hw, img_hw])).reshape(-1) * 4
        expected = pmod(expected, q_23)
        actual = pmod(max_s + max_c, q_23)
        compare_expected_actual(expected, actual, name=test_name + "_online", get_relative=True)

    def test_server():
        init_communicate(Config.server_rank, master_address=master_address, master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Server(num_elem, q_23, q_16, work_bit, data_bit, img_hw, fhe_builder_16, fhe_builder_23, "avgpool")

        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(img_s)
            torch_sync()
        traffic_record.reset("server-online")

        # Receive the client's output share (one value per 2x2 window).
        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base, "recon_res_c")
        blob_out_c.prepare_recv()
        torch_sync()
        out_c = blob_out_c.get_recv()
        check_correctness_online(img, prot.out_s, out_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank, master_address=master_address, master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Client(num_elem, q_23, q_16, work_bit, data_bit, img_hw, fhe_builder_16, fhe_builder_23, "avgpool")

        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(img_c)
            torch_sync()
        traffic_record.reset("client-online")

        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base, "recon_res_c")
        torch_sync()
        blob_out_c.send(prot.out_c)
        end_communicate()

    # NOTE(review): an input_sid matching none of the three ranks runs nothing
    # and the test silently "passes"; consider raising ValueError instead.
    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()

    print(f"\nTest for {test_name}: End")
def test_rotation():
    """Time basic FHE operations and trace the noise budget up to a row rotation.

    Steps:
        1. Encrypt ``0 .. degree-1`` twice.
        2. Multiply both ciphertexts by the same plaintext.
        3. Add the second ciphertext into the first 128 times.
        4. Rotate the rows of the first ciphertext by 64 slots.

    The noise budget is printed after each step, and the final decryption is
    printed at the end.
    """
    degree = Config.n_23
    builder = FheBuilder(Config.q_23, degree)
    builder.generate_keys()
    builder.generate_galois_keys()

    # Deterministic payload (arange) rather than random values, so the decrypted
    # output is easy to eyeball.
    payload = torch.arange(degree)

    with NamedTimerInstance("Fhe Encrypt"):
        acc = builder.build_enc_from_torch(payload)
    addend = builder.build_enc_from_torch(payload)
    factor = builder.build_plain_from_torch(payload)
    builder.noise_budget(acc, "before mul")

    with NamedTimerInstance("ep mult"):
        acc *= factor
    addend *= factor
    builder.noise_budget(acc, "after mul")

    with NamedTimerInstance("ee add"):
        for _ in range(128):
            acc += addend
    builder.noise_budget(acc, "after add")

    with NamedTimerInstance("rot"):
        builder.evaluator.rotate_rows_inplace(acc.cts[0], 64, builder.galois_keys)
    builder.noise_budget(acc, "after rot")

    decrypted = builder.decrypt_to_torch(acc)
    print(decrypted)
def test_client():
    """Client side of the DGK bit-comparison test (``DgkBitClient``).

    Generates both FHE key pairs and sends the public keys to the server. It
    also generates the plaintext test inputs ``x``/``y`` and the shares, and
    sends ``x``, ``y`` and the server share of ``y - x`` to the server. It
    then runs the DGK offline/online phases on its share ``y_sub_x_c`` and
    finally sends its result shares so the server can check them.

    NOTE(review): ``master_address``, ``master_port``, ``q_16``, ``n_16``,
    ``q_23``, ``n_23``, ``num_elem``, ``work_bit``, ``data_bit`` and
    ``data_range`` are not defined in this block; they presumably come from an
    enclosing test function or module scope -- confirm.
    """
    rank = Config.client_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    traffic_record = TrafficRecord()

    # The client owns the secret keys; only the public keys go to the server.
    fhe_builder_16 = FheBuilder(q_16, n_16)
    fhe_builder_23 = FheBuilder(q_23, n_23)
    fhe_builder_16.generate_keys()
    fhe_builder_23.generate_keys()
    comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
    comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
    torch_sync()
    comm_fhe_16.send_public_key()
    comm_fhe_23.send_public_key()

    dgk = DgkBitClient(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "DgkBitTest")

    # Test data: plaintext operands x, y and client shares x_c, y_c.
    # NOTE(review): x_c / y_c are drawn from the small signed data range, not
    # uniformly mod q_23 -- acceptable for a correctness test, but they do not
    # hide x / y.
    x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    y = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    x_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    y_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    x_s = pmod(x - x_c, q_23)
    y_s = pmod(y - y_c, q_23)
    y_sub_x_s = pmod(y_s - x_s, q_23)

    # Test scaffolding: hand the server the plaintexts (for checking) and its
    # share of y - x (its protocol input).
    x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
    y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
    y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y_sub_x_s")
    torch_sync()
    x_blob.send(x)
    y_blob.send(y)
    y_sub_x_s_blob.send(y_sub_x_s)

    torch_sync()
    with NamedTimerInstance("Client Offline"):
        dgk.offline()
    # The client's share of y - x is its online input.
    y_sub_x_c = pmod(y_c - x_c, q_23)
    traffic_record.reset("client-offline")
    torch_sync()

    with NamedTimerInstance("Client Online"):
        dgk.online(y_sub_x_c)
    traffic_record.reset("client-online")

    # Send the client's result shares to the server for the correctness check.
    dgk_x_leq_y_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "dgk_x_leq_y_c")
    correct_mod_div_work_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "correct_mod_div_work_c")
    z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
    torch_sync()
    dgk_x_leq_y_c_blob.send(dgk.dgk_x_leq_y_c)
    correct_mod_div_work_c_blob.send(dgk.correct_mod_div_work_c)
    z_blob.send(dgk.z)

    end_communicate()