Example #1
0
def correctness_vgg(self, input_img, output, modulus, enabled=False):
    """Check a secure VGG forward pass against a plaintext re-computation.

    The reference pass replays layers 1, 3, 6, 8, 11, 13, 15 of
    ``self.layers`` as convolutions (padding 1) and layer 18 as a fully
    connected product. ReLU and 2x2 max-pooling sit in between, and every
    step is reduced with ``pmod`` modulo ``modulus``. The result is compared
    with ``output``.

    The check used to be short-circuited by an unconditional ``return``,
    which left the whole body unreachable. That default is kept: the method
    is a no-op unless ``enabled=True`` is passed.

    Args:
        input_img: input image without the batch dimension.
        output: result produced by the secure protocol.
        modulus: ring modulus used by ``pmod`` / ``nmod``.
        enabled: actually run the comparison (default ``False``, no-op).

    Returns:
        None.
    """
    if not enabled:
        return

    # MaxPool2d has no parameters, so one instance serves both pooling layers.
    pool = torch.nn.MaxPool2d(2)

    def conv(t, layer_idx):
        # Convolve with the plaintext weights, then reduce modulo ``modulus``.
        weight = self.layers[layer_idx].weight.cuda().double()
        return pmod(F.conv2d(t, weight, padding=1), modulus)

    def relu(t):
        # ReLU/pooling act on the centered (nmod) representation.
        return pmod(F.relu(nmod(t, modulus)), modulus)

    def maxpool(t):
        return pmod(pool(nmod(t, modulus)), modulus)

    x = input_img.cuda().double()
    x = x.reshape([1] + list(x.shape))
    x = relu(conv(x, 1))
    x = relu(maxpool(conv(x, 3)))
    x = relu(conv(x, 6))
    x = relu(maxpool(conv(x, 8)))
    x = relu(conv(x, 11))
    x = relu(conv(x, 13))
    x = relu(conv(x, 15))
    fc_weight = self.layers[18].weight.cuda().double()
    x = pmod(torch.mm(x.view(1, -1), fc_weight.t()).view(-1), modulus)

    expected = x
    actual = pmod(output, modulus)
    if len(expected.shape) == 4 and expected.shape[0] == 1:
        expected = expected.reshape(expected.shape[1:])
    compare_expected_actual(expected, actual, name="minionn_mnist", get_relative=True)
Example #2
0
 def check_correctness_online(output, input_s, input_c):
     """Check that the two shares reconstruct ``output`` modulo ``modulus``.

     ``modulus`` and ``test_name`` are read from the enclosing scope.
     """
     expected = pmod(output.cuda(), modulus)
     shares_sum = input_s.cuda() + input_c.cuda()
     compare_expected_actual(expected, pmod(shares_sum, modulus),
                             name=test_name + " online", get_relative=True)
Example #3
0
    def conv2d(self, x, w):
        """Convolve ``x`` with ``w`` via the loaded NTT path, with debug checks.

        Both operands are loaded and NTT-transformed. The first input channel
        and filter ``(0, 0)`` are then inverse-transformed, cropped and
        compared (mod ``self.modulus``) with their originals; the filter is
        compared in its 180-degree-rotated form. The returned value is
        ``self.conv2d_loaded()``.
        """
        self.load_and_ntt_w(w)
        self.load_and_ntt_x(x)
        mod = self.modulus

        # Round trip of the first input channel, cropped back to the image size.
        ntted_x0 = self.ntted_x[0]
        restored_x0 = self.ntt_matmul.intt2d(ntted_x0.double())[:self.img_hw, :self.img_hw]
        compare_expected_actual(pmod(x[0], mod),
                                pmod(restored_x0, mod),
                                get_relative=True,
                                name="sub_x")

        # Round trip of filter (0, 0); it is stored rotated by 180 degrees.
        ntted_w00 = self.ntted_w[0, 0]
        restored_w00 = self.ntt_matmul.intt2d(ntted_w00.double())[:self.filter_hw, :self.filter_hw]
        compare_expected_actual(pmod(w[0, 0], mod).rot90(2),
                                pmod(restored_w00, mod),
                                get_relative=True,
                                name="sub_w")

        # Single-channel pipeline is run for debugging only; its result is discarded.
        self.transform_y_single_channel(
            self.conv2d_ntted_single_channel(ntted_x0, ntted_w00))

        return self.conv2d_loaded()
Example #4
0
def test_torch_ntt():
    """Round-trip a random image through the matrix-based 2-D NTT.

    ``x`` is drawn with ``gen_unirand_int_grain`` over the signed
    ``data_bit``-bit range. It is transformed by ``ntt_mat2d`` with the
    forward matrix and then with the inverse matrix, and the result must
    match ``x`` modulo ``modulus``. Only the forward transform is timed.
    """
    modulus = 786433
    img_hw = 64
    data_bit = 17

    len_vector = img_hw
    data_range = 2**data_bit
    _root, mod, ntt_mat, inv_mat = generate_ntt_param(modulus, len_vector,
                                                      data_range)

    x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              img_hw**2).reshape([img_hw, img_hw]).double()

    ntt_mat = ntt_mat.double()
    inv_mat = inv_mat.double()

    with NamedTimerInstance("Mat NTT 2d"):
        ntted = ntt_mat2d(ntt_mat, mod, x)
    reved = ntt_mat2d(inv_mat, mod, ntted)
    expected = pmod(x, modulus).type(torch.int)
    actual = pmod(reved, modulus).type(torch.int)
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
Example #5
0
    def test_client():
        """Client side of the FHE key and ciphertext exchange test.

        The client generates keys and publishes the public key, then:
        - receives a ``BlobFheEnc`` tagged ``ori_tag`` and compares its
          decryption with ``ori``;
        - receives a ciphertext under ``pk_tag`` and compares its decryption
          with ``expected_tensor_pk``;
        - sends the secret key, followed by an encryption of
          ``expected_tensor_sk`` under ``sk_tag``.
        All of these names are read from the enclosing scope.
        """
        init_communicate(Config.client_rank)
        builder = FheBuilder(modulus, degree)
        comm_builder = CommFheBuilder(Config.client_rank, builder, "fhe_builder")
        builder.generate_keys()
        comm_builder.send_public_key()

        # Blob transfer round trip.
        blob = BlobFheEnc(num_elem, comm_builder, ori_tag)
        blob.prepare_recv()
        compare_expected_actual(ori, blob.get_recv_decrypt(), get_relative=True, name=ori_tag)

        # Ciphertext produced by the peer (presumably with our public key).
        received = builder.build_enc(num_elem)
        comm_builder.recv_enc(received, pk_tag)
        comm_builder.wait_enc(pk_tag)
        compare_expected_actual(expected_tensor_pk,
                                builder.decrypt_to_torch(received),
                                get_relative=True,
                                name="Recovering to test pk")

        # Hand over the secret key, then send data encrypted by this builder.
        comm_builder.send_secret_key()
        comm_builder.send_enc(builder.build_enc_from_torch(expected_tensor_sk), sk_tag)

        dist.destroy_process_group()
Example #6
0
def test_fhe_builder_rebuild():
    """Check that an FheBuilder can be rebuilt from another builder's key buffers.

    The "server" builder first loads the "client" public key and encrypts,
    and the client must decrypt the result. The server then loads the
    client secret key and must decrypt what the client encrypts.
    """
    print()
    print("Test for FheBuilder rebuild: start")
    modulus, degree = 12289, 2048
    num_elem = 2 ** 17
    print(f"modulus: {modulus}, degree: {degree}")

    client = FheBuilder(modulus, degree)
    server = FheBuilder(modulus, degree)
    client.generate_keys()

    # Public key copied to the server: server encrypts, client decrypts.
    plain_from_server = torch.ones(num_elem).float() + 5
    server.get_public_key_buffer().copy_(client.get_public_key_buffer())
    server.build_from_loaded_public_key()
    decrypted_by_client = client.decrypt_to_torch(server.build_enc_from_torch(plain_from_server))
    compare_expected_actual(plain_from_server, decrypted_by_client,
                            get_relative=True, name="Rebuilding public key")

    # Secret key copied to the server: client encrypts, server decrypts.
    plain_from_client = torch.ones(num_elem).float() + 12
    server.get_secret_key_buffer().copy_(client.get_secret_key_buffer())
    server.build_from_loaded_secret_key()
    decrypted_by_server = server.decrypt_to_torch(client.build_enc_from_torch(plain_from_client))
    compare_expected_actual(plain_from_client, decrypted_by_server,
                            get_relative=True, name="Rebuilding secret key")

    print("Test for FheBuilder rebuild: end")
    print()
Example #7
0
    def check_correctness(input_img, output):
        """Compare ``output`` with a plaintext pass of the small test network.

        The pipeline is conv1 → ReLU → 2x2 max-pool → floor-divide by
        ``2**pow_to_div`` → conv2 → fully connected ``fc1_w``. Each step is
        reduced with ``pmod`` modulo ``q_23``; ReLU and pooling act on the
        ``nmod`` (centered) values. The weights, ``pow_to_div`` and
        ``test_name`` come from the enclosing scope.
        """
        pool = torch.nn.MaxPool2d(2)
        dev = Config.device

        h = input_img.to(dev).double()
        h = h.reshape([1] + list(h.shape))
        h = pmod(F.conv2d(h, conv1_w.to(dev).double(), padding=1), q_23)
        h = pmod(F.relu(nmod(h, q_23)), q_23)
        h = pmod(pool(nmod(h, q_23)), q_23)
        h = pmod(h // (2**pow_to_div), q_23)
        h = pmod(F.conv2d(h, conv2_w.to(dev).double(), padding=1), q_23)
        fc_weight_t = fc1_w.to(dev).double().t()
        expected = pmod(torch.mm(h.view(-1).view(1, -1), fc_weight_t).view(-1), q_23)

        actual = pmod(output, q_23)
        if len(expected.shape) == 4 and expected.shape[0] == 1:
            expected = expected.reshape(expected.shape[1:])
        compare_expected_actual(expected, actual, name=test_name, get_relative=True)
Example #8
0
 def check_correctness_offline(u, v, z_s, z_c):
     """Check that offline shares ``z_s`` + ``z_c`` equal ``u * v`` modulo ``modulus``."""
     dev = Config.device
     product = u.double().to(dev) * v.double().to(dev)
     compare_expected_actual(pmod(product, modulus),
                             pmod(z_s + z_c, modulus),
                             name="shares_mult_offline",
                             get_relative=True)
Example #9
0
def correctness_fc(self, input_img, output, modulus):
    """Compare ``output`` with a plaintext fully connected layer modulo ``modulus``.

    ``input_img`` is flattened to one row and multiplied by the transposed
    weight of ``self.layers[1]``; the product is reduced with ``pmod``.
    """
    row = input_img.cuda().double().view(1, -1)
    weight_t = self.layers[1].weight.cuda().double().t()
    expected = pmod(torch.mm(row, weight_t).view(-1), modulus)

    actual = pmod(output, modulus)
    if len(expected.shape) == 4 and expected.shape[0] == 1:
        expected = expected.reshape(expected.shape[1:])
    compare_expected_actual(expected, actual, name="fc", get_relative=True)
Example #10
0
 def check_correctness_online(a, b, c_s, c_c):
     """Check that online shares ``c_s`` + ``c_c`` equal ``a * b`` modulo ``modulus``."""
     device = Config.device
     lhs = a.double().to(device)
     rhs = b.double().to(device)
     expected = pmod(lhs * rhs, modulus)
     compare_expected_actual(expected, pmod(c_s + c_c, modulus),
                             name="shares_mult_online", get_relative=True)
Example #11
0
 def check_correctness_offline(x, w, output_mask, output_c):
     """Check the offline convolution result against a plaintext ``F.conv2d``.

     ``output_c - output_mask`` (mod ``modulus``) must match the convolution
     of ``x`` reshaped to ``[1] + x_shape`` with ``w`` reshaped to
     ``w_shape``, using ``padding``. Those names and ``test_name`` come from
     the enclosing scope.
     """
     unmasked = pmod(output_c.cuda() - output_mask.cuda(), modulus)
     batch_x = x.reshape([1] + x_shape).cuda().double()
     kernel = w.reshape(w_shape).cuda().double()
     reference = F.conv2d(batch_x, kernel, padding=padding).reshape(output_mask.shape)
     compare_expected_actual(pmod(reference, modulus), unmasked,
                             name=test_name + " offline", get_relative=True)
Example #12
0
 def check_correctness(x, y, dgk_x_leq_y_s, dgk_x_leq_y_c):
     """Check the DGK comparison shares against a plaintext ``x <= y``.

     Before comparing, values at or above ``q_23 // 2`` are shifted down by
     ``q_23`` (signed view). The number of mismatching entries is printed.
     """
     def to_signed(t):
         return torch.where(t < q_23 // 2, t, t - q_23).to(Config.device)

     x_signed = to_signed(x)
     y_signed = to_signed(y)
     expected_x_leq_y = x_signed <= y_signed
     recon = pmod(dgk_x_leq_y_s + dgk_x_leq_y_c, q_23)
     compare_expected_actual(expected_x_leq_y, recon,
                             name="DGK x <= y", get_relative=True)
     print(torch.sum(expected_x_leq_y != recon))
Example #13
0
 def check_correctness_online(img, max_s, max_c):
     """Check that shares ``max_s`` + ``max_c`` equal a 2x2 max-pool of ``img``.

     Values of ``img`` at or above ``q_23 // 2`` are shifted down by ``q_23``
     before pooling, and both sides are reduced modulo ``q_23``.
     ``img_hw`` comes from the enclosing scope.
     """
     signed_img = torch.where(img < q_23 // 2, img, img - q_23).to(Config.device)
     planes = signed_img.reshape([-1, img_hw, img_hw])
     pooled = torch.nn.MaxPool2d(2)(planes).reshape(-1)
     compare_expected_actual(pmod(pooled, q_23),
                             pmod(max_s + max_c, q_23),
                             name="maxpool2x2_dgk_online",
                             get_relative=True)
Example #14
0
 def check_correctness_mod_div(r, z, correct_mod_div_work_s,
                               correct_mod_div_work_c):
     """Check the reconstructed mod-div correction term.

     The expected value is ``q_23 // work_range`` where ``r > z`` and 0
     elsewhere. The two shares are summed and reduced modulo ``q_23``.
     ``num_elem`` and ``work_range`` come from the enclosing scope.
     """
     zeros = torch.zeros(num_elem).to(Config.device)
     expected = torch.where(r > z, q_23 // work_range + zeros, zeros)
     reconstructed = pmod(correct_mod_div_work_s + correct_mod_div_work_c, q_23)
     compare_expected_actual(expected, reconstructed,
                             get_relative=True, name="mod_div_online")
Example #15
0
def correctness_relu_only_nn(self, input_img, output, modulus):
    """Compare ``output`` with a plaintext ReLU of ``input_img`` modulo ``modulus``.

    The ReLU is applied to the ``nmod`` (centered) values, then reduced with
    ``pmod``. The temporary batch dimension of size 1 is dropped before the
    comparison.
    """
    batch = input_img.cuda().double()
    batch = batch.reshape([1, *batch.shape])
    expected = pmod(F.relu(nmod(batch, modulus)), modulus)

    actual = pmod(output, modulus)
    if expected.dim() == 4 and expected.shape[0] == 1:
        expected = expected.reshape(expected.shape[1:])
    compare_expected_actual(expected, actual, name="relu_only_nn", get_relative=True)
Example #16
0
def correctness_conv2d(self, input_img, output, modulus):
    """Compare ``output`` with a plaintext convolution of ``input_img``.

    ``input_img`` gets a batch dimension of 1 and is convolved with
    ``self.layers[1].weight`` using padding 1. The result is reduced with
    ``pmod`` modulo ``modulus``, and the batch dimension is dropped before
    the comparison.
    """
    batch = input_img.cuda().double()
    batch = batch.reshape([1, *batch.shape])
    kernel = self.layers[1].weight.cuda().double()
    expected = pmod(F.conv2d(batch, kernel, padding=1), modulus)

    actual = pmod(output, modulus)
    if expected.dim() == 4 and expected.shape[0] == 1:
        expected = expected.reshape(expected.shape[1:])
    compare_expected_actual(expected, actual, name="conv2d", get_relative=True)
Example #17
0
def test_conv2d_fhe_ntt_single_thread():
    """Compare the single-threaded FHE/NTT convolution with ``F.conv2d``.

    ``w`` is drawn over the signed ``data_bit``-bit range and ``x`` over
    ``gen_unirand_int_grain(0, modulus, ...)``. The protocol encodes ``x``,
    convolves with ``w``, masks and decodes. The decoded output minus
    ``prot.output_mask_s``, reduced modulo ``modulus``, must match the
    plaintext convolution.
    """
    modulus = 786433
    img_hw, filter_hw, padding = 16, 3, 1
    num_input_channel, num_output_channel = 64, 128
    data_bit = 17
    data_range = 2**data_bit

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)

    warming_up_cuda()
    prot = Conv2dFheNttSingleThread(modulus, data_range, img_hw, filter_hw,
                                    num_input_channel, num_output_channel,
                                    fhe_builder, "test_conv2d_fhe_ntt",
                                    padding)

    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_conv2d(w)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        decoded = prot.decode_output_from_fhe_batch()
    # Strip the output mask to recover the plain convolution result.
    actual = pmod(decoded - prot.output_mask_s, modulus)

    batch_x = x.reshape([1] + x_shape).double()
    kernel = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        reference = F.conv2d(batch_x, kernel, padding=padding)
        expected = pmod(reference.reshape(prot.output_shape), modulus)
    compare_expected_actual(expected,
                            actual,
                            name="test_conv2d_fhe_ntt_single_thread",
                            get_relative=True)
Example #18
0
 def check_correctness_online(x, w, b, output_s, output_c):
     """Check that shares ``output_s`` + ``output_c`` equal ``conv2d(x, w, bias=b)``.

     Both sides are reduced modulo ``modulus``; ``b`` may be ``None``.
     ``x_shape``, ``w_shape``, ``padding`` and ``test_name`` come from the
     enclosing scope.
     """
     actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
     batch_x = x.reshape([1] + x_shape).cuda().double()
     kernel = w.reshape(w_shape).cuda().double()
     bias = None if b is None else b.cuda().double()
     reference = F.conv2d(batch_x, kernel, padding=padding, bias=bias)
     expected = pmod(reference.reshape(output_s.shape), modulus)
     compare_expected_actual(expected, actual,
                             name=test_name + " online", get_relative=True)
Example #19
0
 def check_correctness_online(img, max_s, max_c):
     """Check that shares ``max_s`` + ``max_c`` equal a 2x2 sum-pool of ``img``.

     The sum-pool is computed as ``AvgPool2d(2)`` times 4, on values where
     entries at or above ``q_23 // 2`` are shifted down by ``q_23``. Both
     sides are reduced modulo ``q_23``. Despite the parameter names, this is
     average/sum pooling, not max pooling.
     """
     signed_img = torch.where(img < q_23 // 2, img, img - q_23).cuda()
     planes = signed_img.double().reshape([-1, img_hw, img_hw])
     window_sums = torch.nn.AvgPool2d(2)(planes).reshape(-1) * 4
     compare_expected_actual(pmod(window_sums, q_23),
                             pmod(max_s + max_c, q_23),
                             name=test_name + "_online",
                             get_relative=True)
Example #20
0
    def test_1d_server():
        """Server side of the 1-D ciphertext refresh test.

        Encrypts a random tensor from ``gen_unirand_int_grain(0, modulus, num_elem)``.
        The ciphertext is sent through ``EncRefresherServer.request``, and the
        decrypted result is compared with the original tensor.
        ``fhe_builder``, ``modulus`` and ``num_elem`` come from the enclosing scope.
        """
        init_communicate(Config.server_rank)
        plain = gen_unirand_int_grain(0, modulus, num_elem)
        refresher = EncRefresherServer(num_elem, fhe_builder, "test_1d_refresher")
        refreshed = refresher.request(fhe_builder.build_enc_from_torch(plain))
        compare_expected_actual(plain, fhe_builder.decrypt_to_torch(refreshed),
                                get_relative=True, name="1d_refresh")

        end_communicate()
Example #21
0
def test_fc_fhe_single_thread():
    """Compare the single-threaded FHE fully connected layer with ``torch.mm``.

    ``w`` is drawn over the signed ``data_bit``-bit range and ``x`` via
    ``gen_unirand_int_grain(0, modulus, ...)``. The protocol encodes ``x``,
    multiplies by ``w``, masks and decodes. The decoded output minus
    ``prot.output_mask_s``, reduced modulo ``modulus``, must match
    ``x @ w.T``.
    """
    test_name = "test_fc_fhe_single_thread"
    print(f"\nTest for {test_name}: Start")
    modulus = Config.q_23
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)

    warming_up_cuda()
    prot = FcFheSingleThread(modulus, data_range, num_input_unit,
                             num_output_unit, fhe_builder, test_name)

    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)
    print("prot.num_elem_in_piece", prot.num_elem_in_piece)

    # Timer labels are kept unchanged so existing timing logs stay comparable.
    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_with_weight(w)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    # Strip the output mask to recover the plain product.
    actual = pmod(y - prot.output_mask_s, modulus)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        expected = torch.mm(torch_x, torch_w.t())
        expected = pmod(expected.reshape(prot.output_shape), modulus)
    compare_expected_actual(expected,
                            actual,
                            name=test_name,
                            get_relative=True)
    print(f"\nTest for {test_name}: End")
Example #22
0
def correctness_maxpool2x2(self, input_img, output, modulus):
    """Compare ``output`` with a plaintext 2x2 max-pool of ``input_img``.

    Pooling runs on the ``nmod`` (centered) values, and the result is
    reduced with ``pmod`` modulo ``modulus``. The temporary batch dimension
    of size 1 is dropped before the comparison.
    """
    batch = input_img.cuda().double()
    batch = batch.reshape([1, *batch.shape])
    expected = pmod(torch.nn.MaxPool2d(2)(nmod(batch, modulus)), modulus)

    actual = pmod(output, modulus)
    if expected.dim() == 4 and expected.shape[0] == 1:
        expected = expected.reshape(expected.shape[1:])
    compare_expected_actual(expected, actual, name="maxpool2x2", get_relative=True)
Example #23
0
 def check_correctness_online(x, w, output_s, output_c):
     """Check that shares ``output_s`` + ``output_c`` equal ``x @ w.T (+ bias)``.

     Both sides are reduced modulo ``modulus``. ``bias`` (possibly ``None``),
     ``x_shape``, ``w_shape`` and ``test_name`` come from the enclosing scope.
     """
     actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
     row = x.reshape([1] + x_shape).cuda().double()
     weight = w.reshape(w_shape).cuda().double()
     expected = row.mm(weight.t())
     if bias is not None:
         expected += bias.cuda().double()
     compare_expected_actual(pmod(expected.reshape(output_s.shape), modulus),
                             actual,
                             name=test_name + " online",
                             get_relative=True)
Example #24
0
    def comm_base_server():
        """Server half of the CommBase round-trip test and timing benchmark.

        Sends ``expect_int16`` and ``expect_uint8`` (moved to CUDA) and
        receives float and double tensors of ``num_elem`` elements, timing
        each exchange with ``NamedTimerInstance``. It then compares the
        received tensors with ``expect_float`` / ``expect_double``, and times
        ``NetworkSimulator.simulate`` on the sent tensors. ``comm_name``,
        ``num_elem``, the ``*_tag`` names and the ``expect_*`` tensors are
        read from the enclosing scope.

        NOTE(review): the order of recv/send/wait/barrier calls presumably
        has to mirror ``comm_base_client`` step for step; keep the two in sync.
        """
        init_communicate(Config.server_rank)
        comm_base = CommBase(Config.server_rank, comm_name)

        send_int16 = expect_int16.cuda()
        send_uint8 = expect_uint8.cuda()

        # First float/int16 exchange: the receive is posted inside the timed
        # region and no barrier precedes it.
        with NamedTimerInstance("Server float and int16"):
            comm_base.recv_torch(torch.zeros(num_elem).float(), float_tag)
            comm_base.send_torch(send_int16, int16_tag)
            comm_base.wait(float_tag)
            actual_float = comm_base.get_tensor(float_tag).cuda()

        # Repeat with the receive posted up front and both sides aligned by a
        # barrier; the timer covers the send, the wait and the fetch.
        comm_base.recv_torch(torch.zeros(num_elem).float(), float_tag)
        dist.barrier()
        with NamedTimerInstance("Server float and int16"):
            comm_base.send_torch(send_int16, int16_tag)
            comm_base.wait(float_tag)
            actual_float = comm_base.get_tensor(float_tag).cuda()

        # Double/uint8 exchange, run twice in the barrier-aligned form.
        comm_base.recv_torch(torch.zeros(num_elem).double(), double_tag)
        dist.barrier()
        with NamedTimerInstance("Server double and uint8"):
            comm_base.send_torch(send_uint8, uint8_tag)
            comm_base.wait(double_tag)
            actual_double = comm_base.get_tensor(double_tag).cuda()

        comm_base.recv_torch(torch.zeros(num_elem).double(), double_tag)
        dist.barrier()
        with NamedTimerInstance("Server double and uint8"):
            comm_base.send_torch(send_uint8, uint8_tag)
            comm_base.wait(double_tag)
            actual_double = comm_base.get_tensor(double_tag).cuda()

        dist.barrier()
        # Only the tensors from the last exchange of each kind are checked.
        compare_expected_actual(expect_float.cuda(),
                                actual_float,
                                name="float",
                                get_relative=True)
        compare_expected_actual(expect_double.cuda(),
                                actual_double,
                                name="double",
                                get_relative=True)

        # Simulated transfer times for the sent payloads (bandwidth/latency
        # units are defined by NetworkSimulator, not shown here).
        google_vm_simulator = NetworkSimulator(bandwidth=10 * (10**9),
                                               basis_latency=.001)
        with NamedTimerInstance("Simulate int16"):
            google_vm_simulator.simulate(send_int16.cpu().cuda())
        with NamedTimerInstance("Simulate uint8"):
            google_vm_simulator.simulate(send_uint8.cpu().cuda())

        dist.destroy_process_group()
Example #25
0
    def comm_base_client():
        """Client half of the CommBase round-trip test and timing benchmark.

        Sends ``expect_float`` and ``expect_double`` (moved to CUDA) and
        receives int16 and uint8 tensors of ``num_elem`` elements, timing
        each exchange with ``NamedTimerInstance``. It then compares the
        received tensors with ``expect_int16`` / ``expect_uint8``.
        ``comm_name``, ``num_elem``, the ``*_tag`` names and the
        ``expect_*`` tensors are read from the enclosing scope.

        NOTE(review): the order of recv/send/wait/barrier calls presumably
        has to mirror ``comm_base_server`` step for step; keep the two in sync.
        """
        init_communicate(Config.client_rank)
        comm_base = CommBase(Config.client_rank, comm_name)

        send_float = expect_float.cuda()
        send_double = expect_double.cuda()

        # First float/int16 exchange: the receive is posted inside the timed
        # region and no barrier precedes it.
        with NamedTimerInstance("Client float and int16"):
            comm_base.recv_torch(
                torch.zeros(num_elem).type(torch.int16), int16_tag)
            comm_base.send_torch(send_float, float_tag)
            comm_base.wait(int16_tag)
            actual_int16 = comm_base.get_tensor(int16_tag).cuda()

        # Repeat with the receive posted up front and both sides aligned by a
        # barrier; the timer covers the send, the wait and the fetch.
        comm_base.recv_torch(
            torch.zeros(num_elem).type(torch.int16), int16_tag)
        dist.barrier()
        with NamedTimerInstance("Client float and int16"):
            comm_base.send_torch(send_float, float_tag)
            comm_base.wait(int16_tag)
            actual_int16 = comm_base.get_tensor(int16_tag).cuda()

        # Double/uint8 exchange, run twice in the barrier-aligned form.
        comm_base.recv_torch(
            torch.zeros(num_elem).type(torch.uint8), uint8_tag)
        dist.barrier()
        with NamedTimerInstance("Client double and uint8"):
            comm_base.send_torch(send_double, double_tag)
            comm_base.wait(uint8_tag)
            actual_uint8 = comm_base.get_tensor(uint8_tag).cuda()

        comm_base.recv_torch(
            torch.zeros(num_elem).type(torch.uint8), uint8_tag)
        dist.barrier()
        with NamedTimerInstance("Client double and uint8"):
            comm_base.send_torch(send_double, double_tag)
            comm_base.wait(uint8_tag)
            actual_uint8 = comm_base.get_tensor(uint8_tag).cuda()

        dist.barrier()
        # Only the tensors from the last exchange of each kind are checked.
        compare_expected_actual(expect_int16.cuda(),
                                actual_int16,
                                name="int16",
                                get_relative=True)
        compare_expected_actual(expect_uint8.cuda(),
                                actual_uint8,
                                name="uint8",
                                get_relative=True)

        dist.destroy_process_group()
Example #26
0
def test_fhe_enc_copy():
    """Check that a copied FHE ciphertext decrypts to the original plaintext."""
    print()
    print("Test for FheEncTensor Copy: start")
    modulus, degree = 12289, 2048
    num_elem = 2 ** 17
    print(f"modulus: {modulus}, degree: {degree}")

    builder = FheBuilder(modulus, degree)
    builder.generate_keys()
    plain = torch.ones(num_elem).float() + 5
    duplicate = builder.build_enc_from_torch(plain).copy()
    compare_expected_actual(plain, builder.decrypt_to_torch(duplicate),
                            get_relative=True, name="fhe enc copy")

    print("Test for FheEncTensor Copy: end")
    print()
Example #27
0
def test_nnt_conv_single_channel():
    """Check a single-channel convolution computed via 2-D NTT against ``F.conv2d``.

    Steps:
    1. Pad the image to ``get_pow_2_ceil(img_hw + 2 * padding)``.
    2. Rotate the filter by 180 degrees and pad it to the same size.
    3. Transform both row/column-wise with ``ntt`` (prime ``modulus``).
    4. Multiply point-wise and inverse-transform with ``intt``.
    5. Compare the ``output_hw x output_hw`` window at ``output_offset``
       with the plaintext convolution reduced modulo ``modulus``.
    """
    modulus = 786433
    img_hw = 6
    filter_hw = 3
    padding = 1

    conv_hw = img_hw + 2 * padding
    padded_hw = get_pow_2_ceil(conv_hw)
    output_hw = img_hw + 2 * padding - (filter_hw - 1)
    output_offset = filter_hw - 2
    # Deterministic ramp inputs keep the printed intermediates readable.
    # ``np.int64`` is used because the ``np.int`` alias was removed in NumPy 1.24.
    x = np.arange(img_hw**2).reshape([img_hw, img_hw]).astype(np.int64)
    w = np.arange(filter_hw**2).reshape([filter_hw, filter_hw]).astype(np.int64)
    padded_x = pad_to_size(x, padded_hw)
    # The NTT product is a true convolution while F.conv2d is a
    # cross-correlation; rotating the filter by 180 degrees reconciles them.
    padded_w = pad_to_size(np.rot90(w, 2), padded_hw)
    print(padded_x)
    print(padded_w)
    with NamedTimerInstance("NTT2D, Sympy"):
        ntted_x = transform2d(
            padded_x, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
        ntted_w = transform2d(
            padded_w, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
    with NamedTimerInstance("Point-wise Dot"):
        doted = ntted_x * ntted_w
    with NamedTimerInstance("iNTT2D"):
        reved = transform2d(
            doted, lambda sub_img: intt(sub_img.tolist(), prime=modulus))
    actual = reved[output_offset:output_hw + output_offset,
                   output_offset:output_hw + output_offset]
    print("reved\n", reved)
    print("actual\n", actual)

    torch_x = torch.tensor(x).reshape([1, 1, img_hw, img_hw])
    torch_w = torch.tensor(w).reshape([1, 1, filter_hw, filter_hw])
    expected = F.conv2d(torch_x, torch_w, padding=padding)
    expected = pmod(expected.reshape(output_hw, output_hw), modulus)
    print("expected", expected)
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
Example #28
0
def test_ntt_conv():
    """Compare ``Conv2dNtt`` with a plaintext ``F.conv2d`` modulo ``modulus``.

    ``conv2d_ntt.conv2d(x, w)`` is called once up front and its result is
    discarded. The timed path then loads/transforms ``x`` and ``w``
    separately and calls ``conv2d_loaded()``; that output is the one
    compared.
    """
    modulus = 786433
    img_hw, filter_hw, padding = 16, 3, 1
    num_input_channel, num_output_channel = 64, 128
    data_bit = 17
    data_range = 2**data_bit
    low, high = -data_range // 2 + 1, data_range // 2

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    x = gen_unirand_int_grain(low, high, get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(low, high, get_prod(w_shape)).reshape(w_shape)

    conv2d_ntt = Conv2dNtt(modulus, data_range, img_hw, filter_hw,
                           num_input_channel, num_output_channel,
                           "test_conv2d_ntt", padding)
    conv2d_ntt.conv2d(x, w)  # result unused; only the loaded path below is checked

    with NamedTimerInstance("ntt x"):
        conv2d_ntt.load_and_ntt_x(x)
    with NamedTimerInstance("ntt w"):
        conv2d_ntt.load_and_ntt_w(w)
    with NamedTimerInstance("conv2d"):
        result = conv2d_ntt.conv2d_loaded()
    actual = pmod(result, modulus)

    batch_x = x.reshape([1] + x_shape).double()
    kernel = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        reference = F.conv2d(batch_x, kernel, padding=padding)
        expected = pmod(reference.reshape(conv2d_ntt.y_shape), modulus)
    compare_expected_actual(expected,
                            actual,
                            name="ntt",
                            get_relative=True,
                            show_where_err=False)
Example #29
0
    def check_correctness(self, input_img, output, modulus):
        """Compare the secure network output with the plaintext network.

        The plaintext prediction (batch dimension dropped) and ``output`` are
        both mapped with ``nmod`` modulo ``modulus`` and compared. The arg-max
        index of each along dimension 0 is then printed.
        """
        net = get_plain_net()
        batch = input_img.reshape([1] + list(input_img.shape)).cuda()
        plain_out = net(batch)
        expected = nmod(plain_out.reshape(plain_out.shape[1:]), modulus)
        actual = nmod(output, modulus).cuda()
        print("expected", expected)
        print("actual", actual)
        compare_expected_actual(expected, actual, name="secure_vgg", get_relative=True)

        expected_max = torch.max(expected, 0)[1]
        actual_max = torch.max(actual, 0)[1]
        print(
            f"expected_max: {expected_max}, actual_max: {actual_max}, Match: {expected_max == actual_max}"
        )
Example #30
0
def test_noise():
    """Exercise FheBuilder round trips and plaintext-ciphertext multiplication.

    Each of the following must give ``RelAvgDiff == 0``:
    - plaintext encode/export;
    - additive encryption followed by decryption;
    - element-wise product of a plaintext and a ciphertext.
    The noise budget is printed before and after the multiplication.
    """
    print()
    print("Test for FheBuilder: start")
    modulus, degree = 12289, 2048
    num_elem = 2 ** 14
    fhe_builder = FheBuilder(modulus, degree)
    fhe_builder.generate_keys()
    print(f"modulus: {modulus}, degree: {degree}")
    print()

    small_vals = gen_unirand_int_grain(0, 2, num_elem)
    print(small_vals)
    large_vals = gen_unirand_int_grain(0, modulus - 1, num_elem)

    def exact_match(expected, actual):
        # Note: called only inside ``assert``, so it is skipped under ``-O`` like the original.
        return compare_expected_actual(expected, actual, verbose=True, get_relative=True).RelAvgDiff == 0

    with NamedTimerInstance(f"build_plain_from_torch with num_elem: {num_elem}"):
        plain = fhe_builder.build_plain_from_torch(small_vals)
    with NamedTimerInstance(f"plain.export_as_torch_gpu() with num_elem: {num_elem}"):
        from_plain = plain.export_as_torch()
    print("Fhe Plain encrypt and decrypt: ", end="")
    assert exact_match(small_vals, from_plain)
    print()

    with NamedTimerInstance(f"fhe_builder.build_enc with num_elem: {num_elem}"):
        cipher = fhe_builder.build_enc(num_elem)
    with NamedTimerInstance(f"cipher.encrypt_additive with num_elem: {num_elem}"):
        cipher.encrypt_additive(small_vals)
    with NamedTimerInstance(f"fhe_builder.decrypt_to_torch with num_elem: {num_elem}"):
        from_cipher = fhe_builder.decrypt_to_torch(cipher)
    print("Fhe Enc encrypt and decrypt: ", end="")
    assert exact_match(small_vals, from_cipher)
    print()

    pt = fhe_builder.build_plain_from_torch(small_vals)
    ct = fhe_builder.build_enc_from_torch(large_vals)
    fhe_builder.noise_budget(ct, name="before ep")
    with NamedTimerInstance(f"EP Mult with num_elem: {num_elem}"):
        ct *= pt
    fhe_builder.noise_budget(ct, name="after ep")
    expected = pmod(small_vals.double() * large_vals.double(), modulus)
    actual = fhe_builder.decrypt_to_torch(ct)
    assert exact_match(expected, actual)
    print()

    print("Test for FheBuilder: Finish")