Ejemplo n.º 1
0
    def test_client():
        """Client side of the secure Conv2d communication test.

        Joins the communication group with the client rank, then runs the
        offline phase on the client's input share ``input_c`` and the online
        phase. Finally it sends the client's output share ``output_c`` under
        the blob name "output_c". All protocol parameters (``modulus``,
        ``fhe_builder``, shapes, ``padding``, ...) come from the enclosing scope.
        """
        init_communicate(Config.client_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        conv = Conv2dSecureClient(modulus, fhe_builder, data_range, img_hw,
                                  filter_hw, num_input_channel,
                                  num_output_channel, "test_conv2d_secure_comm",
                                  padding=padding)

        # torch_sync() is called inside each timer, so the reported time
        # includes it (presumably a synchronization point — confirm).
        phases = (
            ("Client Offline", lambda: conv.offline(input_c)),
            ("Client Online", lambda: conv.online()),
        )
        for timer_name, run_phase in phases:
            with NamedTimerInstance(timer_name):
                run_phase()
                torch_sync()

        out_blob = BlobTorch(conv.output_shape, torch.float, conv.comm_base,
                             "output_c")
        torch_sync()
        out_blob.send(conv.output_c)
        end_communicate()
Ejemplo n.º 2
0
    def test_server():
        """Server side of the end-to-end secure network inference test.

        Steps:

        - Redirect stdout to a ``Logger`` and create a ``TrafficRecord``.
        - Build the secure network and join the communication group with the
          server rank.
        - Synchronize the FHE builders, load truncation parameters, and load
          the weights read with ``torch.load(net_state_name)``.
        - Reset the MetaTruncRandomGenerator seed.
        - Run the offline and online phases, resetting the traffic record
          after each.
        - Call ``check_correctness`` and ``check_layers`` before closing
          communication.
        """
        sys.stdout = Logger()
        traffic = TrafficRecord()
        nn_proto = get_secure_nn()
        nn_proto.set_rank(Config.server_rank).init_communication(
            master_address=master_addr, master_port=master_port)
        warming_up_cuda()
        nn_proto.fhe_builder_sync()
        load_trunc_params(nn_proto, store_configs)

        trained_state = torch.load(net_state_name)
        load_weight_params(nn_proto, store_configs, trained_state)

        seed_gen = MetaTruncRandomGenerator()
        seed_gen.reset_seed()

        schedule = (
            ("Server Offline", "server-offline", lambda: nn_proto.offline()),
            ("Server Online", "server-online", lambda: nn_proto.online()),
        )
        for timer_name, traffic_tag, run_phase in schedule:
            with NamedTimerInstance(timer_name):
                run_phase()
                torch_sync()
            traffic.reset(traffic_tag)

        nn_proto.check_correctness(check_correctness)
        nn_proto.check_layers(get_plain_net, get_hooking_lst(model_name_base))
        nn_proto.end_communication()
Ejemplo n.º 3
0
    def test_client():
        """Client side of the 2x2 average-pooling protocol test.

        Runs ``Avgpool2x2Client``'s offline phase and then its online phase on
        the client's share ``img_c``, resetting the traffic record after each.
        It then sends the client's output share ``out_c`` (``num_elem // 4``
        elements) under the blob name "recon_res_c".
        """
        init_communicate(Config.client_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic = TrafficRecord()
        pool = Avgpool2x2Client(num_elem, q_23, q_16, work_bit, data_bit,
                                img_hw, fhe_builder_16, fhe_builder_23,
                                "avgpool")

        schedule = (
            ("Client Offline", "client-offline", lambda: pool.offline()),
            ("Client Online", "client-online", lambda: pool.online(img_c)),
        )
        for timer_name, traffic_tag, run_phase in schedule:
            with NamedTimerInstance(timer_name):
                run_phase()
                torch_sync()
            traffic.reset(traffic_tag)

        out_blob = BlobTorch(num_elem // 4, torch.float, pool.comm_base,
                             "recon_res_c")
        torch_sync()
        out_blob.send(pool.out_c)
        end_communicate()
Ejemplo n.º 4
0
    def test_server():
        """Server side of the 2x2 average-pooling protocol test.

        Runs ``Avgpool2x2Server``'s offline phase and its online phase on the
        server's share ``img_s``, resetting the traffic record after each.
        It then receives the client's output share "recon_res_c"
        (``num_elem // 4`` elements) and passes ``img``, the server share
        ``out_s`` and the client share to ``check_correctness_online``.
        """
        init_communicate(Config.server_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic = TrafficRecord()
        pool = Avgpool2x2Server(num_elem, q_23, q_16, work_bit, data_bit,
                                img_hw, fhe_builder_16, fhe_builder_23,
                                "avgpool")

        schedule = (
            ("Server Offline", "server-offline", lambda: pool.offline()),
            ("Server Online", "server-online", lambda: pool.online(img_s)),
        )
        for timer_name, traffic_tag, run_phase in schedule:
            with NamedTimerInstance(timer_name):
                run_phase()
                torch_sync()
            traffic.reset(traffic_tag)

        client_blob = BlobTorch(num_elem // 4, torch.float, pool.comm_base,
                                "recon_res_c")
        client_blob.prepare_recv()
        torch_sync()
        client_share = client_blob.get_recv()
        check_correctness_online(img, pool.out_s, client_share)

        end_communicate()
Ejemplo n.º 5
0
    def test_server():
        """Server side of the secure Conv2d communication test.

        Runs ``Conv2dSecureServer``'s offline phase with ``weight`` and
        ``bias``, and its online phase with the server's input share
        ``input_s``. It then receives the client's output share "output_c"
        and checks the reconstruction with ``check_correctness_online``.
        """
        init_communicate(Config.server_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        conv = Conv2dSecureServer(modulus, fhe_builder, data_range, img_hw,
                                  filter_hw, num_input_channel,
                                  num_output_channel, "test_conv2d_secure_comm",
                                  padding=padding)

        phases = (
            ("Server Offline", lambda: conv.offline(weight, bias=bias)),
            ("Server Online", lambda: conv.online(input_s)),
        )
        for timer_name, run_phase in phases:
            with NamedTimerInstance(timer_name):
                run_phase()
                torch_sync()

        client_blob = BlobTorch(conv.output_shape, torch.float,
                                conv.comm_base, "output_c")
        client_blob.prepare_recv()
        torch_sync()
        client_share = client_blob.get_recv()
        # NOTE(review): `input` must be the enclosing scope's plaintext input;
        # if no such variable exists it silently resolves to the builtin.
        check_correctness_online(input, weight, bias, conv.output_s,
                                 client_share)

        end_communicate()
Ejemplo n.º 6
0
def main(net_state_name, net, testset):
    """Evaluate a trained network on ``testset`` and print accuracy and loss.

    Restores the weights named by ``net_state_name`` into ``net`` via
    ``load_state_dict``. It then makes one pass over ``testset`` in eval mode
    with gradient tracking disabled, printing the top-1 accuracy, the
    per-sample mean cross-entropy loss, and the wall-clock time of the pass.

    Args:
        net_state_name: Identifier passed to ``load_state_dict`` to restore
            the weights.
        net: Network to evaluate; it is moved to the module-level ``device``.
        testset: Dataset whose batches are indexed as ``data[0]`` (inputs)
            and ``data[1]`` (class labels) by the loop below.
    """
    warming_up_cuda()

    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=minibatch,
                                             shuffle=False,
                                             num_workers=4)

    net.to(device)

    # Install the logger exactly once. Installing another Logger later would
    # be redundant (and, if Logger wraps the current sys.stdout, would chain
    # two loggers).
    sys.stdout = Logger()

    load_state_dict(net_state_name, net)

    criterion = nn.CrossEntropyLoss()

    correct = 0
    total = 0
    loss_sum = 0.0
    inference_start = time.time()
    net.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            batch_len = labels.size(0)
            # CrossEntropyLoss averages over the batch by default. Weight it by
            # the batch length so the final figure is a true per-sample mean
            # even when the last batch is short.
            loss_sum += criterion(outputs, labels).item() * batch_len
            correct += (predicted == labels).sum().item()
            total += batch_len

    if total == 0:
        print("Test set is empty; nothing to evaluate")
    else:
        print('Accuracy of the network on the %d test images: %f %%' %
              (total, 100 * correct / total))
        print("loss=", loss_sum / float(total))

    elapsed_time = time.time() - inference_start
    print("Elapsed time for Prediction", elapsed_time)
Ejemplo n.º 7
0
def test_conv2d_fhe_ntt_single_thread():
    """Compare the single-threaded FHE/NTT Conv2d with a plaintext ``F.conv2d``.

    A random signed weight ``w`` and a random input ``x`` (bounds 0 and
    ``modulus``) are generated in that order. ``x`` is encoded into FHE
    batches and convolved with ``w``; the output is masked and decoded, and
    ``prot.output_mask_s`` is subtracted mod ``modulus``. The result is
    compared against ``F.conv2d`` in double precision, reduced mod
    ``modulus``.
    """
    modulus = 786433
    img_hw, filter_hw, padding = 16, 3, 1
    num_input_channel, num_output_channel = 64, 128
    data_bit = 17
    data_range = 2**data_bit

    input_shape = [num_input_channel, img_hw, img_hw]
    weight_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # Weights first, then input: the draw order fixes which random values each gets.
    w_low, w_high = -data_range // 2 + 1, data_range // 2
    w = gen_unirand_int_grain(w_low, w_high,
                              get_prod(weight_shape)).reshape(weight_shape)
    x = gen_unirand_int_grain(0, modulus,
                              get_prod(input_shape)).reshape(input_shape)

    warming_up_cuda()
    prot = Conv2dFheNttSingleThread(modulus, data_range, img_hw, filter_hw,
                                    num_input_channel, num_output_channel,
                                    fhe_builder, "test_conv2d_fhe_ntt",
                                    padding)

    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_conv2d(w)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        decoded = prot.decode_output_from_fhe_batch()
    actual = pmod(decoded - prot.output_mask_s, modulus)

    plain_x = x.reshape([1] + input_shape).double()
    plain_w = w.reshape(weight_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        reference = F.conv2d(plain_x, plain_w, padding=padding)
        expected = pmod(reference.reshape(prot.output_shape), modulus)

    compare_expected_actual(expected,
                            actual,
                            name="test_conv2d_fhe_ntt_single_thread",
                            get_relative=True)
Ejemplo n.º 8
0
    def test_server():
        """Server side of the reconstruct-to-client test.

        Runs ``ReconToClientServer``'s offline phase and its online phase
        with the server's share ``x_s``. The server itself performs no
        correctness check.
        """
        init_communicate(Config.server_rank)
        warming_up_cuda()
        recon = ReconToClientServer(num_elem, modulus, test_name)

        phases = (
            ("Server Offline", lambda: recon.offline()),
            ("Server online", lambda: recon.online(x_s)),
        )
        for timer_name, run_phase in phases:
            with NamedTimerInstance(timer_name):
                run_phase()
                torch_sync()

        end_communicate()
Ejemplo n.º 9
0
def test_fc_fhe_single_thread():
    """Compare the single-threaded FHE fully-connected layer with ``torch.mm``.

    A random signed weight matrix ``w`` and a random input vector ``x``
    (bounds 0 and ``modulus``) are generated in that order. ``x`` is encoded
    into FHE batches and multiplied by ``w``; the output is masked and
    decoded, and ``prot.output_mask_s`` is subtracted mod ``modulus``. The
    result is compared against ``x @ w.T`` in double precision, reduced mod
    ``modulus``.
    """
    test_name = "test_fc_fhe_single_thread"
    print(f"\nTest for {test_name}: Start")
    modulus = Config.q_23
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # Weights first, then input: the draw order fixes which random values each gets.
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)

    warming_up_cuda()
    prot = FcFheSingleThread(modulus, data_range, num_input_unit,
                             num_output_unit, fhe_builder, test_name)

    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)
    print("prot.num_elem_in_piece", prot.num_elem_in_piece)

    # Timer labels describe the fully-connected computation being measured.
    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("fc with w"):
        prot.compute_with_weight(w)
    with NamedTimerInstance("fc masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    # Remove the output mask before comparing with the plaintext product.
    actual = pmod(y - prot.output_mask_s, modulus)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Fc Torch"):
        expected = torch.mm(torch_x, torch_w.t())
        expected = pmod(expected.reshape(prot.output_shape), modulus)
    compare_expected_actual(expected,
                            actual,
                            name=test_name,
                            get_relative=True)
    print(f"\nTest for {test_name}: End")
Ejemplo n.º 10
0
    def test_client():
        """Client side of the reconstruct-to-client test.

        Runs ``ReconToClientClient``'s offline phase and its online phase
        with the client's share ``x_c``. It then passes the protocol's
        ``output`` and both shares to ``check_correctness_online``.
        """
        init_communicate(Config.client_rank)
        warming_up_cuda()
        recon = ReconToClientClient(num_elem, modulus, test_name)

        phases = (
            ("Client Offline", lambda: recon.offline()),
            ("Client Online", lambda: recon.online(x_c)),
        )
        for timer_name, run_phase in phases:
            with NamedTimerInstance(timer_name):
                run_phase()
                torch_sync()

        check_correctness_online(recon.output, x_s, x_c)
        end_communicate()
Ejemplo n.º 11
0
    def test_client():
        """Client side of the swap-to-client-offline test.

        Runs ``SwapToClientOfflineClient``'s offline phase with ``y_c`` and
        its online phase with ``x_c``. It then sends the client's
        ``output_c`` share (``num_elem`` elements) under the blob name
        "output_c".
        """
        init_communicate(Config.client_rank)
        warming_up_cuda()
        swap = SwapToClientOfflineClient(num_elem, modulus, test_name)

        phases = (
            ("Client Offline", lambda: swap.offline(y_c)),
            ("Client Online", lambda: swap.online(x_c)),
        )
        for timer_name, run_phase in phases:
            with NamedTimerInstance(timer_name):
                run_phase()
                torch_sync()

        share_blob = BlobTorch(num_elem, torch.float, swap.comm_base,
                               "output_c")
        torch_sync()
        share_blob.send(swap.output_c)
        end_communicate()
Ejemplo n.º 12
0
def run_secure_nn_client_with_random_data(secure_nn, check_correctness, master_address, master_port):
    """Run the client side of a secure network on randomly filled input.

    Joins the communication group with the client rank, synchronizes the FHE
    builders and fills the network with random input. It then runs the
    offline and online phases, resetting the traffic record after each, and
    finally checks correctness and closes communication.

    Args:
        secure_nn: Secure network object. Its fluent methods are chained
            exactly as the calls below show.
        check_correctness: Forwarded unchanged to
            ``secure_nn.check_correctness``.
        master_address: Forwarded to ``init_communication``.
        master_port: Forwarded to ``init_communication``.
    """
    traffic = TrafficRecord()
    secure_nn.set_rank(Config.client_rank).init_communication(
        master_address=master_address, master_port=master_port)
    warming_up_cuda()
    secure_nn.fhe_builder_sync().fill_random_input()

    schedule = (
        ("Client Offline", "client-offline", lambda: secure_nn.offline()),
        ("Client Online", "client-online", lambda: secure_nn.online()),
    )
    for timer_name, traffic_tag, run_phase in schedule:
        with NamedTimerInstance(timer_name):
            run_phase()
            torch_sync()
        traffic.reset(traffic_tag)

    secure_nn.check_correctness(check_correctness).end_communication()
Ejemplo n.º 13
0
    def test_server():
        """Server side of the swap-to-client-offline test.

        Runs ``SwapToClientOfflineServer``'s offline phase and its online
        phase with the server's share ``x_s``. It then receives the client's
        "output_c" share (``num_elem`` elements) and checks it together with
        the server's ``output_s`` via ``check_correctness_online``.
        """
        init_communicate(Config.server_rank)
        warming_up_cuda()
        swap = SwapToClientOfflineServer(num_elem, modulus, test_name)

        phases = (
            ("Server Offline", lambda: swap.offline()),
            ("Server online", lambda: swap.online(x_s)),
        )
        for timer_name, run_phase in phases:
            with NamedTimerInstance(timer_name):
                run_phase()
                torch_sync()

        client_blob = BlobTorch(num_elem, torch.float, swap.comm_base,
                                "output_c")
        client_blob.prepare_recv()
        torch_sync()
        client_share = client_blob.get_recv()
        check_correctness_online(x_s, x_c, swap.output_s, client_share)

        end_communicate()
Ejemplo n.º 14
0
    def test_client():
        """Client side of the secure fully-connected layer test.

        Runs ``FcSecureClient``'s offline phase on the client's input share
        ``input_c`` and its online phase. It then sends the client's
        ``output_c`` share under the blob name "output_c".
        """
        init_communicate(Config.client_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        fc = FcSecureClient(modulus, data_range, num_input_unit,
                            num_output_unit, fhe_builder, test_name)

        phases = (
            ("Client Offline", lambda: fc.offline(input_c)),
            ("Client Online", lambda: fc.online()),
        )
        for timer_name, run_phase in phases:
            with NamedTimerInstance(timer_name):
                run_phase()
                torch_sync()

        out_blob = BlobTorch(fc.output_shape, torch.float, fc.comm_base,
                             "output_c")
        torch_sync()
        out_blob.send(fc.output_c)
        end_communicate()
Ejemplo n.º 15
0
    def test_server():
        """Server side of the DGK bit-comparison test.

        Steps:

        - Build the 16-bit and 23-bit FHE builders and receive the client's
          public keys.
        - Receive the plaintext ``x``, ``y`` and the server share
          ``y_sub_x_s`` from the client.
        - Run ``DgkBitServer``'s offline and online phases, resetting the
          traffic record after each.
        - Receive the client's result shares and check both the comparison
          result and the mod-div correction.
        """
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic = TrafficRecord()

        fhe_16 = FheBuilder(q_16, Config.n_16)
        fhe_23 = FheBuilder(q_23, Config.n_23)
        comm_16 = CommFheBuilder(rank, fhe_16, "fhe_builder_16")
        comm_23 = CommFheBuilder(rank, fhe_23, "fhe_builder_23")
        torch_sync()
        comm_16.recv_public_key()
        comm_23.recv_public_key()
        comm_16.wait_and_build_public_key()
        comm_23.wait_and_build_public_key()

        dgk = DgkBitServer(num_elem, q_23, q_16, work_bit, data_bit,
                           fhe_16, fhe_23, "DgkBitTest")

        def recv_named(*names):
            # Create one num_elem float blob per name, post every receive,
            # sync once, then collect the results in the same order.
            blobs = [BlobTorch(num_elem, torch.float, dgk.comm_base, name)
                     for name in names]
            for blob in blobs:
                blob.prepare_recv()
            torch_sync()
            return [blob.get_recv() for blob in blobs]

        x, y, y_sub_x_s = recv_named("x", "y", "y_sub_x_s")

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            dgk.offline()
        torch_sync()
        traffic.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            dgk.online(y_sub_x_s)
        traffic.reset("server-online")

        dgk_x_leq_y_c, correct_mod_div_work_c, z = recv_named(
            "dgk_x_leq_y_c", "correct_mod_div_work_c", "z")
        check_correctness(x, y, dgk.dgk_x_leq_y_s, dgk_x_leq_y_c)
        check_correctness_mod_div(dgk.r, z, dgk.correct_mod_div_work_s,
                                  correct_mod_div_work_c)
        end_communicate()
Ejemplo n.º 16
0
    def test_client():
        """Client side of the end-to-end secure network test on a real image batch.

        Steps:

        - Redirect stdout to a ``Logger`` and build the secure network.
        - Join the communication group with the client rank, synchronize the
          FHE builders and load the truncation parameters.
        - Draw one shuffled batch of ``batch_size`` test images (CIFAR-10 or
          CIFAR-100, chosen by ``model_name_base``) and feed it to the
          network.
        - Run the offline and online phases, resetting the traffic record
          after each.
        - Call ``check_correctness`` (with the batch labels) and
          ``check_layers``.

        Raises:
            ValueError: If ``model_name_base`` has no test dataset configured.
        """
        rank = Config.client_rank
        sys.stdout = Logger()
        traffic_record = TrafficRecord()
        secure_nn = get_secure_nn()
        secure_nn.set_rank(rank).init_communication(master_address=master_addr,
                                                    master_port=master_port)
        warming_up_cuda()
        secure_nn.fhe_builder_sync()

        load_trunc_params(secure_nn, store_configs)

        def input_shift(data):
            # Shift the normalized input by the stored "conv1ForwardX" config.
            first_layer_name = "conv1"
            return data_shift(data,
                              store_configs[first_layer_name + "ForwardX"])

        def testset():
            # Test dataset class for each supported model.
            dataset_by_model = {
                "vgg16_cifar100": torchvision.datasets.CIFAR100,
                "vgg16_cifar10": torchvision.datasets.CIFAR10,
                "minionn_maxpool": torchvision.datasets.CIFAR10,
            }
            dataset_cls = dataset_by_model.get(model_name_base)
            if dataset_cls is None:
                # Fail loudly instead of handing None to DataLoader.
                raise ValueError(
                    f"no test dataset configured for model {model_name_base!r}")
            # NOTE(review): CIFAR-100 uses the same normalization statistics as
            # CIFAR-10 here; kept unchanged since the trained weights presumably
            # expect it — confirm against the training pipeline.
            return dataset_cls(
                root='./data',
                train=False,
                download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010)),
                    input_shift
                ]))

        testloader = torch.utils.data.DataLoader(testset(),
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 num_workers=2)

        data_iter = iter(testloader)
        image, truth = next(data_iter)
        image = image.reshape(secure_nn.get_input_shape())
        secure_nn.fill_input(image)

        with NamedTimerInstance("Client Offline"):
            secure_nn.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            secure_nn.online()
            torch_sync()
        traffic_record.reset("client-online")

        secure_nn.check_correctness(check_correctness, truth=truth)
        secure_nn.check_layers(get_plain_net, get_hooking_lst(model_name_base))
        secure_nn.end_communication()
Ejemplo n.º 17
0
    def test_client():
        """Client side of the DGK bit-comparison test.

        Steps:

        - Generate 16-bit and 23-bit FHE keys and send the public keys.
        - Draw random ``x``, ``y`` and client shares ``x_c``, ``y_c`` in that
          order, and derive the server shares mod ``q_23``.
        - Send ``x``, ``y`` and ``y_sub_x_s`` to the server.
        - Run ``DgkBitClient``'s offline and online phases.
        - Send the client's result shares (``dgk_x_leq_y_c``,
          ``correct_mod_div_work_c``, ``z``) back for checking.
        """
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic = TrafficRecord()

        fhe_16 = FheBuilder(q_16, n_16)
        fhe_23 = FheBuilder(q_23, n_23)
        fhe_16.generate_keys()
        fhe_23.generate_keys()
        comm_16 = CommFheBuilder(rank, fhe_16, "fhe_builder_16")
        comm_23 = CommFheBuilder(rank, fhe_23, "fhe_builder_23")
        torch_sync()
        comm_16.send_public_key()
        comm_23.send_public_key()

        dgk = DgkBitClient(num_elem, q_23, q_16, work_bit, data_bit,
                           fhe_16, fhe_23, "DgkBitTest")

        # Four independent draws in a fixed order: x, y, x_c, y_c.
        low, high = -data_range // 2 + 1, data_range // 2
        x, y, x_c, y_c = (gen_unirand_int_grain(low, high, num_elem)
                          for _ in range(4))
        x_s = pmod(x - x_c, q_23)
        y_s = pmod(y - y_c, q_23)
        y_sub_x_s = pmod(y_s - x_s, q_23)

        def send_named(*pairs):
            # Create one num_elem float blob per (name, tensor), sync once,
            # then send every tensor in order.
            outgoing = [(BlobTorch(num_elem, torch.float, dgk.comm_base, name),
                         tensor) for name, tensor in pairs]
            torch_sync()
            for blob, tensor in outgoing:
                blob.send(tensor)

        send_named(("x", x), ("y", y), ("y_sub_x_s", y_sub_x_s))

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            dgk.offline()
        y_sub_x_c = pmod(y_c - x_c, q_23)
        traffic.reset("client-offline")
        torch_sync()

        with NamedTimerInstance("Client Online"):
            dgk.online(y_sub_x_c)
        traffic.reset("client-online")

        send_named(("dgk_x_leq_y_c", dgk.dgk_x_leq_y_c),
                   ("correct_mod_div_work_c", dgk.correct_mod_div_work_c),
                   ("z", dgk.z))
        end_communicate()