Code example #1
def test_swap_to_client_comm():
    test_name = "test_swap_to_client_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_elem = 2**17

    print(f"Number of element: {num_elem}")

    x_s = gen_unirand_int_grain(0, modulus - 1, num_elem).cpu()
    x_c = gen_unirand_int_grain(0, modulus - 1, num_elem).cpu()
    y_c = gen_unirand_int_grain(0, modulus - 1, num_elem).cpu()

    def check_correctness_online(input_s, input_c, output_s, output_c):
        expected = pmod(input_s.cuda() + input_c.cuda(), modulus)
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " online",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        warming_up_cuda()
        prot = SwapToClientOfflineServer(num_elem, modulus, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Server online"):
            prot.online(x_s)
            torch_sync()

        blob_output_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                                  "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(x_s, x_c, prot.output_s, output_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        warming_up_cuda()
        prot = SwapToClientOfflineClient(num_elem, modulus, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(y_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online(x_c)
            torch_sync()

        blob_output_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                                  "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Code example #2
def test_fhe_refresh():
    print()
    print("Test: FHE refresh: Start")
    modulus, degree = 12289, 2048
    num_batch = 12
    num_elem = 2 ** 15
    fhe_builder = FheBuilder(modulus, degree)
    fhe_builder.generate_keys()

    def test_batch_server():
        init_communicate(Config.server_rank)
        shape = [num_batch, num_elem]
        tensor = gen_unirand_int_grain(0, modulus, shape)
        refresher = EncRefresherServer(shape, fhe_builder, "test_batch_refresher")
        enc = [fhe_builder.build_enc_from_torch(tensor[i]) for i in range(num_batch)]
        refreshed = refresher.request(enc)
        tensor_refreshed = fhe_builder.decrypt_to_torch(refreshed)
        compare_expected_actual(tensor, tensor_refreshed, get_relative=True, name="batch refresh")

        end_communicate()

    def test_batch_client():
        init_communicate(Config.client_rank)
        shape = [num_batch, num_elem]
        refresher = EncRefresherClient(shape, fhe_builder, "test_batch_refresher")
        refresher.prepare_recv()
        refresher.response()

        end_communicate()

    marshal_funcs([test_batch_server, test_batch_client])

    def test_1d_server():
        init_communicate(Config.server_rank)
        shape = num_elem
        tensor = gen_unirand_int_grain(0, modulus, shape)
        refresher = EncRefresherServer(shape, fhe_builder, "test_1d_refresher")
        enc = fhe_builder.build_enc_from_torch(tensor)
        refreshed = refresher.request(enc)
        tensor_refreshed = fhe_builder.decrypt_to_torch(refreshed)
        compare_expected_actual(tensor, tensor_refreshed, get_relative=True, name="1d_refresh")

        end_communicate()

    def test_1d_client():
        init_communicate(Config.client_rank)
        shape = num_elem
        refresher = EncRefresherClient(shape, fhe_builder, "test_1d_refresher")
        refresher.prepare_recv()
        refresher.response()

        end_communicate()

    marshal_funcs([test_1d_server, test_1d_client])

    print()
    print("Test: FHE refresh: End")
Code example #3
def test_comm_base():
    num_elem = 2**17
    expect_float = torch.ones(num_elem).float() + 3
    expect_double = torch.ones(num_elem).double() + 5
    expect_int16 = torch.ones(num_elem).type(torch.int16) + 324
    expect_uint8 = torch.ones(num_elem).type(torch.uint8) + 123
    comm_name = "comm_base"
    float_tag = "float_tag"
    double_tag = "double_tag"
    int16_tag = "int16_tag"
    uint8_tag = "uint8_tag"

    def comm_base_server():
        init_communicate(Config.server_rank)
        comm_base = CommBase(Config.server_rank, comm_name)

        send_int16 = expect_int16.cuda()
        send_uint8 = expect_uint8.cuda()

        with NamedTimerInstance("Server float and int16"):
            comm_base.recv_torch(torch.zeros(num_elem).float(), float_tag)
            comm_base.send_torch(send_int16, int16_tag)
            comm_base.wait(float_tag)
            actual_float = comm_base.get_tensor(float_tag).cuda()

        comm_base.recv_torch(torch.zeros(num_elem).float(), float_tag)
        dist.barrier()
        with NamedTimerInstance("Server float and int16"):
            comm_base.send_torch(send_int16, int16_tag)
            comm_base.wait(float_tag)
            actual_float = comm_base.get_tensor(float_tag).cuda()

        comm_base.recv_torch(torch.zeros(num_elem).double(), double_tag)
        dist.barrier()
        with NamedTimerInstance("Server double and uint8"):
            comm_base.send_torch(send_uint8, uint8_tag)
            comm_base.wait(double_tag)
            actual_double = comm_base.get_tensor(double_tag).cuda()

        comm_base.recv_torch(torch.zeros(num_elem).double(), double_tag)
        dist.barrier()
        with NamedTimerInstance("Server double and uint8"):
            comm_base.send_torch(send_uint8, uint8_tag)
            comm_base.wait(double_tag)
            actual_double = comm_base.get_tensor(double_tag).cuda()

        dist.barrier()
        compare_expected_actual(expect_float.cuda(),
                                actual_float,
                                name="float",
                                get_relative=True)
        compare_expected_actual(expect_double.cuda(),
                                actual_double,
                                name="double",
                                get_relative=True)

        google_vm_simulator = NetworkSimulator(bandwidth=10 * (10**9),
                                               basis_latency=.001)
        with NamedTimerInstance("Simulate int16"):
            google_vm_simulator.simulate(send_int16.cpu().cuda())
        with NamedTimerInstance("Simulate uint8"):
            google_vm_simulator.simulate(send_uint8.cpu().cuda())

        dist.destroy_process_group()

    def comm_base_client():
        init_communicate(Config.client_rank)
        comm_base = CommBase(Config.client_rank, comm_name)

        send_float = expect_float.cuda()
        send_double = expect_double.cuda()

        with NamedTimerInstance("Client float and int16"):
            comm_base.recv_torch(
                torch.zeros(num_elem).type(torch.int16), int16_tag)
            comm_base.send_torch(send_float, float_tag)
            comm_base.wait(int16_tag)
            actual_int16 = comm_base.get_tensor(int16_tag).cuda()

        comm_base.recv_torch(
            torch.zeros(num_elem).type(torch.int16), int16_tag)
        dist.barrier()
        with NamedTimerInstance("Client float and int16"):
            comm_base.send_torch(send_float, float_tag)
            comm_base.wait(int16_tag)
            actual_int16 = comm_base.get_tensor(int16_tag).cuda()

        comm_base.recv_torch(
            torch.zeros(num_elem).type(torch.uint8), uint8_tag)
        dist.barrier()
        with NamedTimerInstance("Client double and uint8"):
            comm_base.send_torch(send_double, double_tag)
            comm_base.wait(uint8_tag)
            actual_uint8 = comm_base.get_tensor(uint8_tag).cuda()

        comm_base.recv_torch(
            torch.zeros(num_elem).type(torch.uint8), uint8_tag)
        dist.barrier()
        with NamedTimerInstance("Client double and uint8"):
            comm_base.send_torch(send_double, double_tag)
            comm_base.wait(uint8_tag)
            actual_uint8 = comm_base.get_tensor(uint8_tag).cuda()

        dist.barrier()
        compare_expected_actual(expect_int16.cuda(),
                                actual_int16,
                                name="int16",
                                get_relative=True)
        compare_expected_actual(expect_uint8.cuda(),
                                actual_uint8,
                                name="uint8",
                                get_relative=True)

        dist.destroy_process_group()

    marshal_funcs([comm_base_server, comm_base_client])
Code example #4
def test_comm_fhe_builder():
    num_elem = 2**17
    # modulus, degree = 12289, 2048
    modulus, degree = Config.q_16, Config.n_16
    expected_tensor_pk = torch.zeros(num_elem).float() + 23
    expected_tensor_sk = torch.zeros(num_elem).float() + 42
    pk_tag = "pk"
    sk_tag = "sk"
    ori_tag = "ori"

    ori = generate_random_mask(modulus, num_elem)

    def test_server():
        init_communicate(Config.server_rank)
        fhe_builder = FheBuilder(modulus, degree)
        comm_fhe_builder = CommFheBuilder(Config.server_rank, fhe_builder,
                                          "fhe_builder")

        comm_fhe_builder.recv_public_key()
        comm_fhe_builder.wait_and_build_public_key()

        blob_ori = BlobFheEnc(num_elem, comm_fhe_builder, ori_tag)
        blob_ori.send_from_torch(ori)

        enc_pk = fhe_builder.build_enc_from_torch(expected_tensor_pk)
        comm_fhe_builder.send_enc(enc_pk, pk_tag)

        comm_fhe_builder.recv_secret_key()
        comm_fhe_builder.wait_and_build_secret_key()

        enc_sk = fhe_builder.build_enc(num_elem)
        comm_fhe_builder.recv_enc(enc_sk, sk_tag)
        comm_fhe_builder.wait_enc(sk_tag)
        actual_tensor_sk = fhe_builder.decrypt_to_torch(enc_sk)
        compare_expected_actual(expected_tensor_sk,
                                actual_tensor_sk,
                                get_relative=True,
                                name="Recovering to test sk")

        dist.destroy_process_group()

    def test_client():
        init_communicate(Config.client_rank)
        fhe_builder = FheBuilder(modulus, degree)
        comm_fhe_builder = CommFheBuilder(Config.client_rank, fhe_builder,
                                          "fhe_builder")
        fhe_builder.generate_keys()

        comm_fhe_builder.send_public_key()

        blob_ori = BlobFheEnc(num_elem, comm_fhe_builder, ori_tag)
        blob_ori.prepare_recv()
        dec = blob_ori.get_recv_decrypt()
        compare_expected_actual(ori, dec, get_relative=True, name=ori_tag)

        enc_pk = fhe_builder.build_enc(num_elem)
        comm_fhe_builder.recv_enc(enc_pk, pk_tag)
        comm_fhe_builder.wait_enc(pk_tag)
        actual_tensor_pk = fhe_builder.decrypt_to_torch(enc_pk)
        compare_expected_actual(expected_tensor_pk,
                                actual_tensor_pk,
                                get_relative=True,
                                name="Recovering to test pk")

        comm_fhe_builder.send_secret_key()

        enc_sk = fhe_builder.build_enc_from_torch(expected_tensor_sk)
        comm_fhe_builder.send_enc(enc_sk, sk_tag)

        dist.destroy_process_group()

    marshal_funcs([test_server, test_client])
Code example #5
def test_maxpool2x2_dgk(input_sid,
                        master_address,
                        master_port,
                        num_elem=2**17):
    test_name = "Maxpool2x2 Dgk"
    print(f"\nTest for {test_name}: Start")
    data_bit = 20
    work_bit = 20
    data_range = 2**data_bit
    q_16 = Config.q_16
    # q_23 = 786433
    q_23 = Config.q_23
    img_hw = 4
    print(f"Number of element: {num_elem}")

    def check_correctness_online(img, max_s, max_c):
        img = torch.where(img < q_23 // 2, img, img - q_23).to(Config.device)
        pool = torch.nn.MaxPool2d(2)
        expected = pool(img.reshape([-1, img_hw, img_hw])).reshape(-1)
        expected = pmod(expected, q_23)
        actual = pmod(max_s + max_c, q_23)
        compare_expected_actual(expected,
                                actual,
                                name="maxpool2x2_dgk_online",
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    num_elem)
        img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        img_c = pmod(img - img_s, q_23)

        prot = Maxpool2x2DgkServer(num_elem, q_23, q_16, work_bit, data_bit,
                                   img_hw, fhe_builder_16, fhe_builder_23,
                                   "max_dgk")

        blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                               "recon_max_c")
        torch_sync()
        blob_img_c.send(img_c)

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(img_s)
            torch_sync()
        traffic_record.reset("server-online")

        blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_max_c")
        blob_max_c.prepare_recv()
        torch_sync()
        max_c = blob_max_c.get_recv()
        check_correctness_online(img, prot.max_s, max_c)

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        prot = Maxpool2x2DgkClient(num_elem, q_23, q_16, work_bit, data_bit,
                                   img_hw, fhe_builder_16, fhe_builder_23,
                                   "max_dgk")
        blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                               "recon_max_c")
        blob_img_c.prepare_recv()
        torch_sync()
        img_c = blob_img_c.get_recv()

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(img_c)
            torch_sync()
        traffic_record.reset("client-online")

        blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_max_c")
        torch_sync()
        blob_max_c.send(prot.max_c)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        marshal_funcs([test_server])
    elif input_sid == Config.client_rank:
        marshal_funcs([test_client])

    print(f"\nTest for {test_name}: End")
Code example #6
def test_shares_mult():
    print("\nTest for Shares Mult: Start")
    modulus = Config.q_23
    num_elem = 2**17
    print(f"Number of element: {num_elem}")

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    a = gen_unirand_int_grain(0, modulus - 1, num_elem)
    a_s = gen_unirand_int_grain(0, modulus - 1, num_elem)
    a_c = pmod(a - a_s, modulus)
    b = gen_unirand_int_grain(0, modulus - 1, num_elem)
    b_s = gen_unirand_int_grain(0, modulus - 1, num_elem)
    b_c = pmod(b - b_s, modulus)

    def check_correctness_offline(u, v, z_s, z_c):
        expected = pmod(
            u.double().to(Config.device) * v.double().to(Config.device),
            modulus)
        actual = pmod(z_s + z_c, modulus)
        compare_expected_actual(expected,
                                actual,
                                name="shares_mult_offline",
                                get_relative=True)

    def check_correctness_online(a, b, c_s, c_c):
        expected = pmod(
            a.double().to(Config.device) * b.double().to(Config.device),
            modulus)
        actual = pmod(c_s + c_c, modulus)
        compare_expected_actual(expected,
                                actual,
                                name="shares_mult_online",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = SharesMultServer(num_elem, modulus, fhe_builder,
                                "test_shares_mult")
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(a_s, b_s)
            torch_sync()

        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_z_c")
        blob_u_c.prepare_recv()
        blob_v_c.prepare_recv()
        blob_z_c.prepare_recv()
        torch_sync()
        u_c = blob_u_c.get_recv()
        v_c = blob_v_c.get_recv()
        z_c = blob_z_c.get_recv()
        u = pmod(prot.u_s + u_c, modulus)
        v = pmod(prot.v_s + v_c, modulus)
        check_correctness_online(u, v, prot.z_s, z_c)

        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        blob_c_c.prepare_recv()
        torch_sync()
        c_c = blob_c_c.get_recv()
        check_correctness_online(a, b, prot.c_s, c_c)
        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = SharesMultClient(num_elem, modulus, fhe_builder,
                                "test_shares_mult")
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online(a_c, b_c)
            torch_sync()

        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_z_c")
        torch_sync()
        blob_u_c.send(prot.u_c)
        blob_v_c.send(prot.v_c)
        blob_z_c.send(prot.z_c)
        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        torch_sync()
        blob_c_c.send(prot.c_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print("\nTest for Shares Mult: End")
Code example #7
def test_conv2d_secure_comm(input_sid,
                            master_address,
                            master_port,
                            setting=(16, 3, 128, 128)):
    test_name = "Conv2d Secure Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    padding = 1
    img_hw, filter_hw, num_input_channel, num_output_channel = setting
    data_bit = 17
    data_range = 2**data_bit
    # n_23 = 8192
    n_23 = 16384
    print(f"Setting covn2d: img_hw: {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]
    b_shape = [num_output_channel]

    fhe_builder = FheBuilder(modulus, n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                 get_prod(b_shape)).reshape(b_shape)
    input = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  get_prod(x_shape)).reshape(x_shape)
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(input - input_c, modulus)

    def check_correctness_online(x, w, b, output_s, output_c):
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        torch_b = b.cuda().double() if b is not None else None
        expected = F.conv2d(torch_x, torch_w, padding=padding, bias=torch_b)
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " online",
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureServer(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm",
                                  padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(input, weight, bias, prot.output_s, output_c)

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureClient(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm",
                                  padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()

    print(f"\nTest for {test_name}: End")
Code example #8
def test_conv2d_fhe_ntt_comm():
    test_name = "Conv2d Fhe NTT Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    img_hw = 2
    filter_hw = 3
    padding = 1
    num_input_channel = 512
    num_output_channel = 512
    data_bit = 17
    data_range = 2**data_bit
    print(f"Setting: img_hw {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    input_mask = gen_unirand_int_grain(0, modulus - 1,
                                       get_prod(x_shape)).reshape(x_shape)

    # input_mask = torch.arange(get_prod(x_shape)).reshape(x_shape)

    def check_correctness_offline(x, w, output_mask, output_c):
        actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = F.conv2d(torch_x, torch_w, padding=padding)
        expected = pmod(expected.reshape(output_mask.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " offline",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = Conv2dFheNttServer(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_fhe_ntt_comm",
                                  padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s,
                                  output_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = Conv2dFheNttClient(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_fhe_ntt_comm",
                                  padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Code example #9
def test_fc_secure_comm(input_sid,
                        master_address,
                        master_port,
                        setting=(512, 512)):
    test_name = "test_fc_secure_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit, num_output_unit = setting
    data_bit = 17

    print(f"Setting fc: "
          f"num_input_unit: {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]
    b_shape = [num_output_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                 get_prod(b_shape)).reshape(b_shape)
    input = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  get_prod(x_shape)).reshape(x_shape)
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(input - input_c, modulus)

    def check_correctness_online(x, w, output_s, output_c):
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = torch.mm(torch_x, torch_w.t())
        if bias is not None:
            expected += bias.cuda().double()
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " online",
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = FcSecureServer(modulus, data_range, num_input_unit,
                              num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(input, weight, prot.output_s, output_c)

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = FcSecureClient(modulus, data_range, num_input_unit,
                              num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Code example #10
def test_fc_fhe_comm():
    test_name = "test_fc_fhe_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17

    print(f"Setting: num_input_unit {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    input_mask = gen_unirand_int_grain(0, modulus - 1,
                                       get_prod(x_shape)).reshape(x_shape)

    # input_mask = torch.arange(get_prod(x_shape)).reshape(x_shape)

    def check_correctness_offline(x, w, output_mask, output_c):
        actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = torch.mm(torch_x, torch_w.t())
        expected = pmod(expected.reshape(output_mask.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " offline",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = FcFheServer(modulus, data_range, num_input_unit,
                           num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s,
                                  output_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = FcFheClient(modulus, data_range, num_input_unit,
                           num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Code example #11
def secure_vgg(input_sid,
               master_addr,
               master_port,
               model_name_base="vgg_swalp"):
    test_name = "secure inference"
    print(f"\nTest for {test_name}: Start")

    #normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    batch_size = 1

    net_state_name, config_name = get_net_config_name(model_name_base)
    print(f"net_state going to load: {net_state_name}")
    print(f"store_configs going to load: {config_name}")

    store_configs = np.load(config_name, allow_pickle="TRUE").item()

    def get_plain_net():
        if model_name_base == "vgg_swalp":
            net = ModulusNet(store_configs)
        elif model_name_base == "vgg_idc_swalp":
            net = ModulusNet(store_configs, 2)
        elif model_name_base == "vgg_cifar100":
            net = ModulusNet(store_configs, 100)
        elif model_name_base == "vgg16_cifar100":
            net = ModulusNet_vgg16(store_configs, 100)
        elif model_name_base == "vgg16_cifar10":
            net = ModulusNet_vgg16(store_configs, 10)
        elif model_name_base == "minionn_maxpool":
            net = ModulusNet_MiniONN(store_configs)
        else:
            raise Exception("Unknown: {model_name_base}")
        net_state = torch.load(net_state_name)
        device = torch.device("cuda:0")
        net.load_weight_bias(net_state)
        net.to(device)
        return net

    def get_secure_nn():
        if model_name_base == "vgg_swalp":
            return generate_secure_vgg(10)
        elif model_name_base == "vgg_idc_swalp":
            return generate_secure_vgg(2)
        elif model_name_base == "vgg_cifar100":
            return generate_secure_vgg(100)
        elif model_name_base == "vgg16_cifar100":
            return generate_secure_vgg16(100)
        elif model_name_base == "vgg16_cifar10":
            return generate_secure_vgg16(10)
        elif model_name_base == "minionn_cifar10":
            return generate_secure_minionn("avgpool")
        elif model_name_base == "minionn_maxpool":
            return generate_secure_minionn("maxpool")
        else:
            raise Exception("Unknown: {model_name_base}")

    def check_correctness(self, input_img, output, modulus):
        plain_net = get_plain_net()
        expected = plain_net(
            input_img.reshape([1] + list(input_img.shape)).cuda())
        expected = nmod(expected.reshape(expected.shape[1:]), modulus)
        actual = nmod(output, modulus).cuda()
        print("expected", expected)
        print("actual", actual)
        compare_expected_actual(expected,
                                actual,
                                name="secure_vgg",
                                get_relative=True)

        _, expected_max = torch.max(expected, 0)
        _, actual_max = torch.max(actual, 0)
        print(
            f"expected_max: {expected_max}, actual_max: {actual_max}, Match: {expected_max == actual_max}"
        )

    # check_correctness(None, torch.zeros([3, 32, 32]) + 1, torch.zeros(10), secure_nn.q_23)

    def test_server():
        rank = Config.server_rank
        sys.stdout = Logger()
        traffic_record = TrafficRecord()
        secure_nn = get_secure_nn()
        secure_nn.set_rank(rank).init_communication(master_address=master_addr,
                                                    master_port=master_port)
        warming_up_cuda()
        secure_nn.fhe_builder_sync()
        load_trunc_params(secure_nn, store_configs)

        net_state = torch.load(net_state_name)
        load_weight_params(secure_nn, store_configs, net_state)

        meta_rg = MetaTruncRandomGenerator()
        meta_rg.reset_seed()

        with NamedTimerInstance("Server Offline"):
            secure_nn.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            secure_nn.online()
            torch_sync()
        traffic_record.reset("server-online")

        secure_nn.check_correctness(check_correctness)
        secure_nn.check_layers(get_plain_net, get_hooking_lst(model_name_base))
        secure_nn.end_communication()

    def test_client():
        rank = Config.client_rank
        sys.stdout = Logger()
        traffic_record = TrafficRecord()
        secure_nn = get_secure_nn()
        secure_nn.set_rank(rank).init_communication(master_address=master_addr,
                                                    master_port=master_port)
        warming_up_cuda()
        secure_nn.fhe_builder_sync()

        load_trunc_params(secure_nn, store_configs)

        def input_shift(data):
            first_layer_name = "conv1"
            return data_shift(data,
                              store_configs[first_layer_name + "ForwardX"])

        def testset():
            if model_name_base in ["vgg16_cifar100"]:
                return torchvision.datasets.CIFAR100(
                    root='./data',
                    train=False,
                    download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.4914, 0.4822, 0.4465),
                                             (0.2023, 0.1994, 0.2010)),
                        input_shift
                    ]))
            elif model_name_base in ["vgg16_cifar10", "minionn_maxpool"]:
                return torchvision.datasets.CIFAR10(
                    root='./data',
                    train=False,
                    download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.4914, 0.4822, 0.4465),
                                             (0.2023, 0.1994, 0.2010)),
                        input_shift
                    ]))
            else:
                raise Exception(f"Unknown: {model_name_base}")

        testloader = torch.utils.data.DataLoader(testset(),
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 num_workers=2)

        data_iter = iter(testloader)
        image, truth = next(data_iter)
        image = image.reshape(secure_nn.get_input_shape())
        secure_nn.fill_input(image)

        with NamedTimerInstance("Client Offline"):
            secure_nn.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            secure_nn.online()
            torch_sync()
        traffic_record.reset("client-online")

        secure_nn.check_correctness(check_correctness, truth=truth)
        secure_nn.check_layers(get_plain_net, get_hooking_lst(model_name_base))
        secure_nn.end_communication()

    if input_sid == Config.server_rank:
        # test_server()
        marshal_funcs([test_server])
    elif input_sid == Config.client_rank:
        # test_client()
        marshal_funcs([test_client])
    elif input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    else:
        raise Exception(f"Unknown input_sid: {input_sid}")

    print(f"\nTest for {test_name}: End")
Code example #12
def test_secure_nn():
    test_name = "test_secure_nn"
    print(f"\nTest for {test_name}: Start")
    data_bit = 5
    work_bit = 17
    data_range = 2**data_bit
    q_16 = 12289
    # q_23 = 786433
    q_23 = 7340033
    # q_23 = 8273921
    input_img_hw = 16
    input_channel = 3
    pow_to_div = 2

    fhe_builder_16 = FheBuilder(q_16, 2048)
    fhe_builder_23 = FheBuilder(q_23, 8192)

    input_shape = [input_channel, input_img_hw, input_img_hw]

    context = SecureLayerContext(work_bit, data_bit, q_16, q_23,
                                 fhe_builder_16, fhe_builder_23,
                                 test_name + "_context")

    input_layer = InputSecureLayer(input_shape, "input_layer")
    conv1 = Conv2dSecureLayer(3, 5, 3, "conv1", padding=1)
    relu1 = ReluSecureLayer("relu1")
    trunc1 = TruncSecureLayer("trunc1")
    pool1 = Maxpool2x2SecureLayer("pool1")
    conv2 = Conv2dSecureLayer(5, 10, 3, "conv2", padding=1)
    flatten = FlattenSecureLayer("flatten")
    fc1 = FcSecureLayer(32, "fc1")
    output_layer = OutputSecureLayer("output_layer")

    secure_nn = SecureNeuralNetwork("secure_nn")
    secure_nn.load_layers([
        input_layer, conv1, pool1, relu1, trunc1, conv2, flatten, fc1,
        output_layer
    ])
    # secure_nn.load_layers([input_layer, relu1, trunc1, output_layer])
    secure_nn.load_context(context)

    def generate_random_data(shape):
        return gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                     get_prod(shape)).reshape(shape)

    conv1_w = generate_random_data(conv1.weight_shape)
    conv2_w = generate_random_data(conv2.weight_shape)
    fc1_w = generate_random_data(fc1.weight_shape)

    def check_correctness(input_img, output):
        torch_pool1 = torch.nn.MaxPool2d(2)

        x = input_img.to(Config.device).double()
        x = x.reshape([1] + list(x.shape))
        x = pmod(F.conv2d(x, conv1_w.to(Config.device).double(), padding=1),
                 q_23)
        x = pmod(F.relu(nmod(x, q_23)), q_23)
        x = pmod(torch_pool1(nmod(x, q_23)), q_23)
        x = pmod(x // (2**pow_to_div), q_23)
        x = pmod(F.conv2d(x, conv2_w.to(Config.device).double(), padding=1),
                 q_23)
        x = x.view(-1)
        x = pmod(
            torch.mm(x.view(1, -1),
                     fc1_w.to(Config.device).double().t()).view(-1), q_23)

        expected = x
        actual = pmod(output, q_23)
        if len(expected.shape) == 4 and expected.shape[0] == 1:
            expected = expected.reshape(expected.shape[1:])
        compare_expected_actual(expected,
                                actual,
                                name=test_name,
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank)
        context.set_rank(rank)

        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        conv1.load_weight(conv1_w)
        conv2.load_weight(conv2_w)
        fc1.load_weight(fc1_w)
        trunc1.set_div_to_pow(pow_to_div)

        with NamedTimerInstance("Server Offline"):
            secure_nn.offline()
            torch_sync()

        with NamedTimerInstance("Server Online"):
            secure_nn.online()
            torch_sync()

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank)
        context.set_rank(rank)

        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        input_img = generate_random_data(input_shape)
        trunc1.set_div_to_pow(pow_to_div)
        secure_nn.feed_input(input_img)

        with NamedTimerInstance("Client Offline"):
            secure_nn.offline()
            torch_sync()

        with NamedTimerInstance("Client Online"):
            secure_nn.online()
            torch_sync()

        check_correctness(input_img, secure_nn.get_output())
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Code example #13
def test_avgpool2x2_dgk(input_sid,
                        master_address,
                        master_port,
                        num_elem=2**17):
    test_name = "Avgpool2x2"
    print(f"\nTest for {test_name}: Start")
    data_bit = 20
    work_bit = 20
    data_range = 2**data_bit
    q_16 = 12289
    # q_23 = 786433
    q_23 = 7340033
    img_hw = 4
    print(f"Number of element: {num_elem}")

    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_16.generate_keys()
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    fhe_builder_23.generate_keys()

    img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                num_elem)
    img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    img_c = pmod(img - img_s, q_23)

    def check_correctness_online(img, max_s, max_c):
        img = torch.where(img < q_23 // 2, img, img - q_23).cuda()
        pool = torch.nn.AvgPool2d(2)
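        # AvgPool2d output * 4 is the sum over each 2x2 window, which is the
        # value the two shares should reconstruct to modulo q_23.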
        expected = pool(img.double().reshape([-1, img_hw, img_hw
                                              ])).reshape(-1) * 4
        expected = pmod(expected, q_23)
        actual = pmod(max_s + max_c, q_23)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + "_online",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Server(num_elem, q_23, q_16, work_bit, data_bit,
                                img_hw, fhe_builder_16, fhe_builder_23,
                                "avgpool")

        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(img_s)
            torch_sync()
        traffic_record.reset("server-online")

        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_res_c")
        blob_out_c.prepare_recv()
        torch_sync()
        out_c = blob_out_c.get_recv()
        check_correctness_online(img, prot.out_s, out_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Client(num_elem, q_23, q_16, work_bit, data_bit,
                                img_hw, fhe_builder_16, fhe_builder_23,
                                "avgpool")

        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(img_c)
            torch_sync()
        traffic_record.reset("client-online")

        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_res_c")
        torch_sync()
        blob_out_c.send(prot.out_c)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()

    print(f"\nTest for {test_name}: End")
Code example #14
def test_relu_dgk(input_sid, master_address, master_port, num_elem=2**17):
    test_name = "Relu Dgk"
    print(f"\nTest for {test_name}: Start")
    data_bit = 20
    work_bit = 20
    data_range = 2 ** data_bit
    # q_16 = 12289
    q_16 = Config.q_16
    # q_23 = 786433
    q_23 = Config.q_23
    # q_23 = 8273921
    print(f"Number of element: {num_elem}")

    def check_correctness_online(a, c_s, c_c):
        a = a.to(Config.device)
        expected = pmod(torch.max(a, torch.zeros_like(a)), q_23)
        actual = pmod(c_s + c_c, q_23)
        compare_expected_actual(expected, actual, name="relu_dgk_online", get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(Config.server_rank, master_address=master_address, master_port=master_port)
        traffic_record = TrafficRecord()
        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        prot = ReluDgkServer(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "relu_dgk")

        blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
        blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
        torch_sync()
        blob_a_s.prepare_recv()
        a_s = blob_a_s.get_recv()

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(a_s)
            torch_sync()
        traffic_record.reset("server-online")

        blob_max_s.send(prot.max_s)
        torch.cuda.empty_cache()
        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address, master_port=master_port)
        traffic_record = TrafficRecord()
        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        a = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
        a_c = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        a_s = pmod(a - a_c, q_23)

        prot = ReluDgkClient(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "relu_dgk")

        blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
        blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
        torch_sync()
        blob_a_s.send(a_s)
        blob_max_s.prepare_recv()

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(a_c)
            torch_sync()
        traffic_record.reset("client-online")

        max_s = blob_max_s.get_recv()
        check_correctness_online(a, max_s, prot.max_c)

        torch.cuda.empty_cache()
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        marshal_funcs([test_server])
        # test_server()
    elif input_sid == Config.client_rank:
        marshal_funcs([test_client])
        # test_client()

    print(f"\nTest for {test_name}: End")
Code example #15
def test_dgk(input_sid, master_address, master_port, num_elem=2**17):
    print("\nTest for Dgk Basic: Start")
    data_bit = 20
    work_bit = 20
    # q_23 = 786433
    # q_23 = 8273921
    # q_23 = 4079617
    # n_23, q_23 = 8192, 7340033
    n_23, q_23 = Config.n_23, Config.q_23
    # n_16, q_16 = 2048, 12289
    # n_16, q_16 = 8192, 65537
    n_16, q_16 = Config.n_16, Config.q_16
    print(f"Number of element: {num_elem}")

    data_range = 2**data_bit
    work_range = 2**work_bit

    def check_correctness(x, y, dgk_x_leq_y_s, dgk_x_leq_y_c):
        x = torch.where(x < q_23 // 2, x, x - q_23).to(Config.device)
        y = torch.where(y < q_23 // 2, y, y - q_23).to(Config.device)
        expected_x_leq_y = (x <= y)
        dgk_x_leq_y_recon = pmod(dgk_x_leq_y_s + dgk_x_leq_y_c, q_23)
        compare_expected_actual(expected_x_leq_y,
                                dgk_x_leq_y_recon,
                                name="DGK x <= y",
                                get_relative=True)
        print(torch.sum(expected_x_leq_y != dgk_x_leq_y_recon))

    def check_correctness_mod_div(r, z, correct_mod_div_work_s,
                                  correct_mod_div_work_c):
        elem_zeros = torch.zeros(num_elem).to(Config.device)
        expected = torch.where(r > z, q_23 // work_range + elem_zeros,
                               elem_zeros)
        actual = pmod(correct_mod_div_work_s + correct_mod_div_work_c, q_23)
        compare_expected_actual(expected,
                                actual,
                                get_relative=True,
                                name="mod_div_online")

    def test_server():
        rank = Config.server_rank
        init_communicate(Config.server_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        dgk = DgkBitServer(num_elem, q_23, q_16, work_bit, data_bit,
                           fhe_builder_16, fhe_builder_23, "DgkBitTest")

        x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
        y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
        y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                   "y_sub_x_s")
        x_blob.prepare_recv()
        y_blob.prepare_recv()
        y_sub_x_s_blob.prepare_recv()
        torch_sync()
        x = x_blob.get_recv()
        y = y_blob.get_recv()
        y_sub_x_s = y_sub_x_s_blob.get_recv()

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            dgk.offline()
        # y_sub_x_s = pmod(y_s.to(Config.device) - x_s.to(Config.device), q_23)
        torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            dgk.online(y_sub_x_s)
        traffic_record.reset("server-online")

        dgk_x_leq_y_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                       "dgk_x_leq_y_c")
        correct_mod_div_work_c_blob = BlobTorch(num_elem, torch.float,
                                                dgk.comm_base,
                                                "correct_mod_div_work_c")
        z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
        dgk_x_leq_y_c_blob.prepare_recv()
        correct_mod_div_work_c_blob.prepare_recv()
        z_blob.prepare_recv()
        torch_sync()
        dgk_x_leq_y_c = dgk_x_leq_y_c_blob.get_recv()
        correct_mod_div_work_c = correct_mod_div_work_c_blob.get_recv()
        z = z_blob.get_recv()
        check_correctness(x, y, dgk.dgk_x_leq_y_s, dgk_x_leq_y_c)
        check_correctness_mod_div(dgk.r, z, dgk.correct_mod_div_work_s,
                                  correct_mod_div_work_c)
        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, n_16)
        fhe_builder_23 = FheBuilder(q_23, n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        dgk = DgkBitClient(num_elem, q_23, q_16, work_bit, data_bit,
                           fhe_builder_16, fhe_builder_23, "DgkBitTest")

        x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  num_elem)
        y = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  num_elem)
        x_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    num_elem)
        y_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    num_elem)
        x_s = pmod(x - x_c, q_23)
        y_s = pmod(y - y_c, q_23)
        y_sub_x_s = pmod(y_s - x_s, q_23)

        x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
        y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
        y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                   "y_sub_x_s")
        torch_sync()
        x_blob.send(x)
        y_blob.send(y)
        y_sub_x_s_blob.send(y_sub_x_s)

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            dgk.offline()
        y_sub_x_c = pmod(y_c - x_c, q_23)
        traffic_record.reset("client-offline")
        torch_sync()

        with NamedTimerInstance("Client Online"):
            dgk.online(y_sub_x_c)
        traffic_record.reset("client-online")

        dgk_x_leq_y_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                       "dgk_x_leq_y_c")
        correct_mod_div_work_c_blob = BlobTorch(num_elem, torch.float,
                                                dgk.comm_base,
                                                "correct_mod_div_work_c")
        z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
        torch_sync()
        dgk_x_leq_y_c_blob.send(dgk.dgk_x_leq_y_c)
        correct_mod_div_work_c_blob.send(dgk.correct_mod_div_work_c)
        z_blob.send(dgk.z)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        marshal_funcs([test_server])
    elif input_sid == Config.client_rank:
        marshal_funcs([test_client])

    print("\nTest for Dgk Basic: End")