def test_client():
    """Client half of the Conv2dSecure communication test.

    Runs the offline phase on the client's input share ``input_c``, then the
    online phase, and finally sends the client's output share to the server
    under the blob name "output_c".
    """
    client_rank = Config.client_rank
    init_communicate(client_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    protocol = Conv2dSecureClient(
        modulus, fhe_builder, data_range, img_hw, filter_hw,
        num_input_channel, num_output_channel,
        "test_conv2d_secure_comm", padding=padding,
    )

    with NamedTimerInstance("Client Offline"):
        protocol.offline(input_c)
        torch_sync()

    with NamedTimerInstance("Client Online"):
        protocol.online()
        torch_sync()

    # The server prepares the matching receive before its own torch_sync().
    share_blob = BlobTorch(protocol.output_shape, torch.float, protocol.comm_base, "output_c")
    torch_sync()
    share_blob.send(protocol.output_c)
    end_communicate()
def test_server():
    """Server half of the end-to-end SecureNn inference test.

    Loads the truncation parameters and the trained weights (from
    ``net_state_name``) into the server's SecureNn, runs the offline and
    online phases while recording traffic, then checks the final output and
    the per-layer results against the plaintext network.
    """
    sys.stdout = Logger()
    traffic = TrafficRecord()
    secure_nn = get_secure_nn()
    secure_nn.set_rank(Config.server_rank).init_communication(
        master_address=master_addr, master_port=master_port)
    warming_up_cuda()
    secure_nn.fhe_builder_sync()

    # Model parameters: truncation configs first, then the stored weights.
    load_trunc_params(secure_nn, store_configs)
    state = torch.load(net_state_name)
    load_weight_params(secure_nn, store_configs, state)

    trunc_rng = MetaTruncRandomGenerator()
    trunc_rng.reset_seed()

    with NamedTimerInstance("Server Offline"):
        secure_nn.offline()
        torch_sync()
    traffic.reset("server-offline")

    with NamedTimerInstance("Server Online"):
        secure_nn.online()
        torch_sync()
    traffic.reset("server-online")

    secure_nn.check_correctness(check_correctness)
    secure_nn.check_layers(get_plain_net, get_hooking_lst(model_name_base))
    secure_nn.end_communication()
def test_client():
    """Client half of the 2x2 average-pooling protocol test.

    Runs offline and online phases (online on the client's image share
    ``img_c``), records traffic per phase, then sends the client's pooled
    output share to the server as "recon_res_c".
    """
    init_communicate(Config.client_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    traffic = TrafficRecord()
    pool = Avgpool2x2Client(num_elem, q_23, q_16, work_bit, data_bit, img_hw,
                            fhe_builder_16, fhe_builder_23, "avgpool")

    with NamedTimerInstance("Client Offline"):
        pool.offline()
        torch_sync()
    traffic.reset("client-offline")

    with NamedTimerInstance("Client Online"):
        pool.online(img_c)
        torch_sync()
    traffic.reset("client-online")

    # The pooled output holds num_elem // 4 elements.
    result_blob = BlobTorch(num_elem // 4, torch.float, pool.comm_base, "recon_res_c")
    torch_sync()
    result_blob.send(pool.out_c)
    end_communicate()
def test_server():
    """Server half of the 2x2 average-pooling protocol test.

    Runs offline and online phases (online on the server's image share
    ``img_s``), records traffic per phase, receives the client's output share
    "recon_res_c" and passes both shares to ``check_correctness_online``.
    """
    init_communicate(Config.server_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    traffic = TrafficRecord()
    pool = Avgpool2x2Server(num_elem, q_23, q_16, work_bit, data_bit, img_hw,
                            fhe_builder_16, fhe_builder_23, "avgpool")

    with NamedTimerInstance("Server Offline"):
        pool.offline()
        torch_sync()
    traffic.reset("server-offline")

    with NamedTimerInstance("Server Online"):
        pool.online(img_s)
        torch_sync()
    traffic.reset("server-online")

    # The receive is prepared before the barrier; the client sends after it.
    result_blob = BlobTorch(num_elem // 4, torch.float, pool.comm_base, "recon_res_c")
    result_blob.prepare_recv()
    torch_sync()
    client_share = result_blob.get_recv()
    check_correctness_online(img, pool.out_s, client_share)
    end_communicate()
def test_server():
    """Server half of the Conv2dSecure communication test.

    Runs the offline phase with the plaintext ``weight``/``bias``, the online
    phase on the server's input share ``input_s``, then receives the client's
    output share and passes everything to ``check_correctness_online``.
    """
    server_rank = Config.server_rank
    init_communicate(server_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    protocol = Conv2dSecureServer(
        modulus, fhe_builder, data_range, img_hw, filter_hw,
        num_input_channel, num_output_channel,
        "test_conv2d_secure_comm", padding=padding,
    )

    with NamedTimerInstance("Server Offline"):
        protocol.offline(weight, bias=bias)
        torch_sync()

    with NamedTimerInstance("Server Online"):
        protocol.online(input_s)
        torch_sync()

    # The receive is prepared before the barrier; the client sends after it.
    share_blob = BlobTorch(protocol.output_shape, torch.float, protocol.comm_base, "output_c")
    share_blob.prepare_recv()
    torch_sync()
    client_output = share_blob.get_recv()
    check_correctness_online(input, weight, bias, protocol.output_s, client_output)
    end_communicate()
def main(net_state_name, net, testset):
    """Evaluate a trained network on ``testset`` and print accuracy, loss and time.

    Args:
        net_state_name: checkpoint identifier passed to ``load_state_dict``.
        net: the network to evaluate; moved to the module-level ``device``.
        testset: dataset batched with the module-level ``minibatch`` size.

    Prints the top-1 accuracy, the per-sample mean cross-entropy loss and the
    wall-clock inference time. If the test set is empty, prints a notice and
    returns instead of dividing by zero.
    """
    warming_up_cuda()
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=minibatch, shuffle=False, num_workers=4)
    net.to(device)
    # Redirect stdout exactly once; installing a second Logger on top of the
    # first would route every print through two loggers.
    sys.stdout = Logger()
    load_state_dict(net_state_name, net)

    criterion = nn.CrossEntropyLoss()
    correct = 0
    total = 0
    loss_sum = 0.0

    inference_start = time.time()
    net.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            batch_size = labels.size(0)
            total += batch_size
            loss = criterion(outputs, labels)
            # The default reduction averages over the batch; scale back to a sum
            # so the final mean is weighted per sample, not per batch.
            loss_sum += loss.item() * batch_size
            correct += (predicted == labels).sum().item()

    if total == 0:
        print("No test images were evaluated")
        return
    print('Accuracy of the network on the %d test images: %f %%' % (total, 100 * correct / total))
    print("loss=", loss_sum / float(total))
    elapsed_time = time.time() - inference_start
    print("Elapsed time for Prediction", elapsed_time)
def test_conv2d_fhe_ntt_single_thread():
    """Check Conv2dFheNttSingleThread's homomorphic conv2d against F.conv2d.

    Encodes a random input into FHE batches, convolves it with random signed
    weights, masks and decodes the output, removes the mask mod ``modulus``,
    and compares against the plaintext convolution reduced mod ``modulus``.
    """
    modulus = 786433
    img_hw, filter_hw, padding = 16, 3, 1
    num_input_channel, num_output_channel = 64, 128
    data_bit = 17
    data_range = 2**data_bit
    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # Signed weights bounded by data_range; inputs drawn between 0 and modulus.
    half_range = data_range // 2
    weights = gen_unirand_int_grain(-half_range + 1, half_range, get_prod(w_shape)).reshape(w_shape)
    inputs = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)

    warming_up_cuda()
    prot = Conv2dFheNttSingleThread(modulus, data_range, img_hw, filter_hw,
                                    num_input_channel, num_output_channel,
                                    fhe_builder, "test_conv2d_fhe_ntt", padding)
    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(inputs)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_conv2d(weights)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        decoded = prot.decode_output_from_fhe_batch()
    # Subtract the server-side output mask to recover the plain result.
    actual = pmod(decoded - prot.output_mask_s, modulus)

    input_batch = inputs.reshape([1] + x_shape).double()
    weight_tensor = weights.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        expected = F.conv2d(input_batch, weight_tensor, padding=padding)
    expected = pmod(expected.reshape(prot.output_shape), modulus)
    compare_expected_actual(expected, actual, name="test_conv2d_fhe_ntt_single_thread",
                            get_relative=True)
def test_server():
    """Server half of the ReconToClient test.

    Runs the offline phase, then the online phase on the server share ``x_s``.
    """
    init_communicate(Config.server_rank)
    warming_up_cuda()
    recon = ReconToClientServer(num_elem, modulus, test_name)

    with NamedTimerInstance("Server Offline"):
        recon.offline()
        torch_sync()

    with NamedTimerInstance("Server online"):
        recon.online(x_s)
        torch_sync()

    end_communicate()
def test_fc_fhe_single_thread():
    """Check FcFheSingleThread's homomorphic fully-connected layer against torch.mm.

    Encodes a random input vector into FHE batches, multiplies by random signed
    weights, masks and decodes the output, removes the mask mod ``Config.q_23``,
    and compares against the plaintext product reduced mod the same modulus.
    """
    test_name = "test_fc_fhe_single_thread"
    print(f"\nTest for {test_name}: Start")
    modulus = Config.q_23
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17
    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # Signed weights bounded by data_range; inputs drawn between 0 and modulus.
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)

    warming_up_cuda()
    prot = FcFheSingleThread(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)
    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)
    print("prot.num_elem_in_piece", prot.num_elem_in_piece)

    # Timer labels name the FC operations (they previously said "conv2d").
    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("fc with w"):
        prot.compute_with_weight(w)
    with NamedTimerInstance("fc masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    # Subtract the server-side output mask to recover the plain result.
    actual = pmod(y - prot.output_mask_s, modulus)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Fc Torch"):
        expected = torch.mm(torch_x, torch_w.t())
    expected = pmod(expected.reshape(prot.output_shape), modulus)
    compare_expected_actual(expected, actual, name=test_name, get_relative=True)
    print(f"\nTest for {test_name}: End")
def test_client():
    """Client half of the ReconToClient test.

    Runs the offline phase, then the online phase on the client share ``x_c``,
    and passes the protocol output and both shares to
    ``check_correctness_online``.
    """
    init_communicate(Config.client_rank)
    warming_up_cuda()
    recon = ReconToClientClient(num_elem, modulus, test_name)

    with NamedTimerInstance("Client Offline"):
        recon.offline()
        torch_sync()

    with NamedTimerInstance("Client Online"):
        recon.online(x_c)
        torch_sync()

    check_correctness_online(recon.output, x_s, x_c)
    end_communicate()
def test_client():
    """Client half of the SwapToClientOffline test.

    Runs the offline phase with ``y_c`` and the online phase with ``x_c``,
    then sends the client's output share to the server as "output_c".
    """
    init_communicate(Config.client_rank)
    warming_up_cuda()
    swap = SwapToClientOfflineClient(num_elem, modulus, test_name)

    with NamedTimerInstance("Client Offline"):
        swap.offline(y_c)
        torch_sync()

    with NamedTimerInstance("Client Online"):
        swap.online(x_c)
        torch_sync()

    share_blob = BlobTorch(num_elem, torch.float, swap.comm_base, "output_c")
    torch_sync()
    share_blob.send(swap.output_c)
    end_communicate()
def run_secure_nn_client_with_random_data(secure_nn, check_correctness, master_address, master_port):
    """Drive the client side of a SecureNn inference on randomly filled input.

    Args:
        secure_nn: the SecureNn instance; its rank is set to ``Config.client_rank``.
        check_correctness: forwarded to ``secure_nn.check_correctness``.
        master_address: rendezvous address for ``init_communication``.
        master_port: rendezvous port for ``init_communication``.
    """
    traffic = TrafficRecord()
    secure_nn.set_rank(Config.client_rank).init_communication(
        master_address=master_address, master_port=master_port)
    warming_up_cuda()
    secure_nn.fhe_builder_sync().fill_random_input()

    with NamedTimerInstance("Client Offline"):
        secure_nn.offline()
        torch_sync()
    traffic.reset("client-offline")

    with NamedTimerInstance("Client Online"):
        secure_nn.online()
        torch_sync()
    traffic.reset("client-online")

    secure_nn.check_correctness(check_correctness).end_communication()
def test_server():
    """Server half of the SwapToClientOffline test.

    Runs offline, then online with ``x_s``; receives the client's output share
    "output_c" and passes all shares to ``check_correctness_online``.
    """
    init_communicate(Config.server_rank)
    warming_up_cuda()
    swap = SwapToClientOfflineServer(num_elem, modulus, test_name)

    with NamedTimerInstance("Server Offline"):
        swap.offline()
        torch_sync()

    with NamedTimerInstance("Server online"):
        swap.online(x_s)
        torch_sync()

    # The receive is prepared before the barrier; the client sends after it.
    share_blob = BlobTorch(num_elem, torch.float, swap.comm_base, "output_c")
    share_blob.prepare_recv()
    torch_sync()
    client_output = share_blob.get_recv()
    check_correctness_online(x_s, x_c, swap.output_s, client_output)
    end_communicate()
def test_client():
    """Client half of the FcSecure communication test.

    Runs the offline phase on the client's input share ``input_c``, then the
    online phase, and sends the client's output share as "output_c".
    """
    client_rank = Config.client_rank
    init_communicate(client_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    fc = FcSecureClient(modulus, data_range, num_input_unit, num_output_unit, fhe_builder, test_name)

    with NamedTimerInstance("Client Offline"):
        fc.offline(input_c)
        torch_sync()

    with NamedTimerInstance("Client Online"):
        fc.online()
        torch_sync()

    share_blob = BlobTorch(fc.output_shape, torch.float, fc.comm_base, "output_c")
    torch_sync()
    share_blob.send(fc.output_c)
    end_communicate()
def test_server():
    """Server half of the DGK bit-comparison protocol test.

    The server generates no FHE keys; it receives the client's public keys.
    It then receives the plaintext test inputs ``x``, ``y`` and its share of
    ``y - x`` (generated by the client), runs the DGK offline/online phases,
    receives the client's result shares and ``z``, and checks both the
    comparison result and the mod-div correction result.

    NOTE(review): the source this was rebuilt from had every statement on one
    line, so the extent of each timed ``with`` region is a reconstruction —
    confirm against version control.
    """
    rank = Config.server_rank
    init_communicate(Config.server_rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    traffic_record = TrafficRecord()
    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
    comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
    # presumably a barrier paired with the client's torch_sync() — TODO confirm
    torch_sync()
    comm_fhe_16.recv_public_key()
    comm_fhe_23.recv_public_key()
    comm_fhe_16.wait_and_build_public_key()
    comm_fhe_23.wait_and_build_public_key()
    dgk = DgkBitServer(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "DgkBitTest")
    # Receive the test inputs and the server's share of (y - x); receives are
    # prepared before the barrier, after which the client sends.
    x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
    y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
    y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y_sub_x_s")
    x_blob.prepare_recv()
    y_blob.prepare_recv()
    y_sub_x_s_blob.prepare_recv()
    torch_sync()
    x = x_blob.get_recv()
    y = y_blob.get_recv()
    y_sub_x_s = y_sub_x_s_blob.get_recv()
    torch_sync()
    with NamedTimerInstance("Server Offline"):
        dgk.offline()
        # y_sub_x_s = pmod(y_s.to(Config.device) - x_s.to(Config.device), q_23)
        torch_sync()
    traffic_record.reset("server-offline")
    with NamedTimerInstance("Server Online"):
        dgk.online(y_sub_x_s)
    traffic_record.reset("server-online")
    # Collect the client's result shares (and z) for verification.
    dgk_x_leq_y_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "dgk_x_leq_y_c")
    correct_mod_div_work_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "correct_mod_div_work_c")
    z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
    dgk_x_leq_y_c_blob.prepare_recv()
    correct_mod_div_work_c_blob.prepare_recv()
    z_blob.prepare_recv()
    torch_sync()
    dgk_x_leq_y_c = dgk_x_leq_y_c_blob.get_recv()
    correct_mod_div_work_c = correct_mod_div_work_c_blob.get_recv()
    z = z_blob.get_recv()
    check_correctness(x, y, dgk.dgk_x_leq_y_s, dgk_x_leq_y_c)
    check_correctness_mod_div(dgk.r, z, dgk.correct_mod_div_work_s, correct_mod_div_work_c)
    end_communicate()
def test_client():
    """Client half of the end-to-end SecureNn inference test on a real image.

    Loads the truncation parameters, draws one batch from the CIFAR test set
    selected by ``model_name_base`` (shifted via the "conv1ForwardX" entry of
    ``store_configs``), feeds it to SecureNn, and runs the offline/online phases
    with traffic recording. It then checks the output against the batch's
    ground-truth labels and the per-layer results against the plaintext network.

    Raises:
        ValueError: if ``model_name_base`` has no associated test set.
    """
    rank = Config.client_rank
    sys.stdout = Logger()
    traffic_record = TrafficRecord()
    secure_nn = get_secure_nn()
    secure_nn.set_rank(rank).init_communication(master_address=master_addr, master_port=master_port)
    warming_up_cuda()
    secure_nn.fhe_builder_sync()
    load_trunc_params(secure_nn, store_configs)

    def input_shift(data):
        first_layer_name = "conv1"
        return data_shift(data, store_configs[first_layer_name + "ForwardX"])

    def testset():
        # Both datasets use the same preprocessing pipeline.
        # NOTE(review): the same normalization statistics are applied to CIFAR-100
        # and CIFAR-10 — confirm they match the training preprocessing.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            input_shift,
        ])
        if model_name_base in ["vgg16_cifar100"]:
            dataset_cls = torchvision.datasets.CIFAR100
        elif model_name_base in ["vgg16_cifar10", "minionn_maxpool"]:
            dataset_cls = torchvision.datasets.CIFAR10
        else:
            # Fail with a clear message instead of returning None, which
            # DataLoader would reject with an unrelated error.
            raise ValueError(f"No test set configured for model {model_name_base!r}")
        return dataset_cls(root='./data', train=False, download=True, transform=transform)

    testloader = torch.utils.data.DataLoader(testset(), batch_size=batch_size, shuffle=True, num_workers=2)
    data_iter = iter(testloader)
    image, truth = next(data_iter)
    image = image.reshape(secure_nn.get_input_shape())
    secure_nn.fill_input(image)

    with NamedTimerInstance("Client Offline"):
        secure_nn.offline()
        torch_sync()
    traffic_record.reset("client-offline")

    with NamedTimerInstance("Client Online"):
        secure_nn.online()
        torch_sync()
    traffic_record.reset("client-online")

    secure_nn.check_correctness(check_correctness, truth=truth)
    secure_nn.check_layers(get_plain_net, get_hooking_lst(model_name_base))
    secure_nn.end_communication()
def test_client():
    """Client half of the DGK bit-comparison protocol test.

    Generates both FHE key pairs and sends the public keys to the server. It
    draws random test inputs ``x`` and ``y``, splits each into additive shares
    mod ``q_23`` and sends the server ``x``, ``y`` and its share of ``y - x``.
    It then runs the DGK offline/online phases and sends its result shares and
    ``z`` back for checking.

    NOTE(review): the source this was rebuilt from had every statement on one
    line, so the extent of each timed ``with`` region is a reconstruction —
    confirm against version control.
    """
    rank = Config.client_rank
    init_communicate(rank, master_address=master_address, master_port=master_port)
    warming_up_cuda()
    traffic_record = TrafficRecord()
    # The client owns the secret keys; only the public keys are sent.
    fhe_builder_16 = FheBuilder(q_16, n_16)
    fhe_builder_23 = FheBuilder(q_23, n_23)
    fhe_builder_16.generate_keys()
    fhe_builder_23.generate_keys()
    comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
    comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
    # presumably a barrier paired with the server's torch_sync() — TODO confirm
    torch_sync()
    comm_fhe_16.send_public_key()
    comm_fhe_23.send_public_key()
    dgk = DgkBitClient(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "DgkBitTest")
    # Random test inputs and client shares; the server shares are chosen so
    # that x = x_s + x_c and y = y_s + y_c (mod q_23).
    x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    y = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    x_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    y_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
    x_s = pmod(x - x_c, q_23)
    y_s = pmod(y - y_c, q_23)
    y_sub_x_s = pmod(y_s - x_s, q_23)
    x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
    y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
    y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y_sub_x_s")
    torch_sync()
    x_blob.send(x)
    y_blob.send(y)
    y_sub_x_s_blob.send(y_sub_x_s)
    torch_sync()
    with NamedTimerInstance("Client Offline"):
        dgk.offline()
    # Client's share of (y - x), consumed by the online phase.
    y_sub_x_c = pmod(y_c - x_c, q_23)
    traffic_record.reset("client-offline")
    torch_sync()
    with NamedTimerInstance("Client Online"):
        dgk.online(y_sub_x_c)
    traffic_record.reset("client-online")
    # Send the client's result shares (and z) to the server for verification.
    dgk_x_leq_y_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "dgk_x_leq_y_c")
    correct_mod_div_work_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "correct_mod_div_work_c")
    z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
    torch_sync()
    dgk_x_leq_y_c_blob.send(dgk.dgk_x_leq_y_c)
    correct_mod_div_work_c_blob.send(dgk.correct_mod_div_work_c)
    z_blob.send(dgk.z)
    end_communicate()