Example #1
    def test_server():
        rank = Config.server_rank
        init_communicate(rank)
        context.set_rank(rank)

        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        conv1.load_weight(conv1_w)
        conv2.load_weight(conv2_w)
        fc1.load_weight(fc1_w)
        trunc1.set_div_to_pow(pow_to_div)

        with NamedTimerInstance("Server Offline"):
            secure_nn.offline()
            torch_sync()

        with NamedTimerInstance("Server Online"):
            secure_nn.online()
            torch_sync()

        end_communicate()
Example #2
    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureServer(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm",
                                  padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(input, weight, bias, prot.output_s, output_c)

        end_communicate()
Example #3
    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureClient(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm",
                                  padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()
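The server and client above are the two halves of one protocol run. A minimal launcher sketch, assuming the test_server/test_client pair from Examples #2 and #3 and that init_communicate performs the torch.distributed rendezvous on the shared master_address/master_port; the spawning code itself is illustrative and not part of the original tests:

    import torch.multiprocessing as mp

    def launch_conv2d_secure_test():
        # Spawn both halves; each calls init_communicate with the same
        # master_address/master_port, so they rendezvous before the offline
        # phase starts. Assumes a fork-based start method so the nested test
        # functions defined above are inherited by the child processes.
        server = mp.Process(target=test_server)
        client = mp.Process(target=test_client)
        server.start()
        client.start()
        server.join()
        client.join()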
Example #4
    def test_client():
        init_communicate(Config.client_rank)
        prot = SharesMultClient(num_elem, modulus, fhe_builder,
                                "test_shares_mult")
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online(a_c, b_c)
            torch_sync()

        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_z_c")
        torch_sync()
        blob_u_c.send(prot.u_c)
        blob_v_c.send(prot.v_c)
        blob_z_c.send(prot.z_c)
        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        torch_sync()
        blob_c_c.send(prot.c_c)
        end_communicate()
Example #5
    def test_server():
        init_communicate(Config.server_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Server(num_elem, q_23, q_16, work_bit, data_bit,
                                img_hw, fhe_builder_16, fhe_builder_23,
                                "avgpool")

        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(img_s)
            torch_sync()
        traffic_record.reset("server-online")

        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_res_c")
        blob_out_c.prepare_recv()
        torch_sync()
        out_c = blob_out_c.get_recv()
        check_correctness_online(img, prot.out_s, out_c)

        end_communicate()
Example #6
    def test_client():
        init_communicate(Config.client_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Client(num_elem, q_23, q_16, work_bit, data_bit,
                                img_hw, fhe_builder_16, fhe_builder_23,
                                "avgpool")

        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(img_c)
            torch_sync()
        traffic_record.reset("client-online")

        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_res_c")
        torch_sync()
        blob_out_c.send(prot.out_c)
        end_communicate()
Example #7
    def test_server():
        init_communicate(Config.server_rank)
        prot = SharesMultServer(num_elem, modulus, fhe_builder,
                                "test_shares_mult")
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(a_s, b_s)
            torch_sync()

        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_z_c")
        blob_u_c.prepare_recv()
        blob_v_c.prepare_recv()
        blob_z_c.prepare_recv()
        torch_sync()
        u_c = blob_u_c.get_recv()
        v_c = blob_v_c.get_recv()
        z_c = blob_z_c.get_recv()
        u = pmod(prot.u_s + u_c, modulus)
        v = pmod(prot.v_s + v_c, modulus)
        check_correctness_online(u, v, prot.z_s, z_c)

        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        blob_c_c.prepare_recv()
        torch_sync()
        c_c = blob_c_c.get_recv()
        check_correctness_online(a, b, prot.c_s, c_c)
        end_communicate()
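The reconstruction above relies on a pmod helper to bring sums of additive shares back into the canonical residue range. A minimal sketch of such a helper, inferred from its call sites here rather than taken from the repository:

    def pmod(x, modulus):
        # Map x (a torch tensor or Python int) into [0, modulus), so that a
        # reconstruction such as pmod(prot.u_s + u_c, modulus) yields the
        # shared value even when the argument is negative.
        return (x % modulus + modulus) % modulus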
Example #8
    def sum_c_i_offline(self, delta_a, fhe_beta_i_c, fhe_alpha_beta_xor_c, s,
                        alpha_i, ci_mask_s, mult_mask_s, shuffle_order):
        # the last row of sum_xor is c_{-1}, which helps check the case with x == y
        fhe_builder = self.fhe_builder_16
        # fhe_sum_xor = [fhe_builder.build_enc(self.num_elem) for i in range(self.num_work_batch)]
        fhe_sum_xor = [None for i in range(self.num_work_batch)]
        fhe_sum_xor[self.work_bit - 1] = fhe_builder.build_enc(self.num_elem)
        for i in range(self.work_bit - 1)[::-1]:
            fhe_sum_xor[i] = fhe_sum_xor[i + 1].copy()
            fhe_sum_xor[i] += fhe_alpha_beta_xor_c[i + 1]
        fhe_delta_a = fhe_builder.build_plain_from_torch(delta_a)
        fhe_sum_xor[self.work_bit] = fhe_sum_xor[0].copy()
        fhe_sum_xor[self.work_bit] += fhe_alpha_beta_xor_c[0]
        fhe_sum_xor[self.work_bit] += fhe_delta_a
        del fhe_delta_a

        for i in range(self.work_bit)[::-1]:
            fhe_mult_3 = fhe_builder.build_plain_from_torch(
                pmod(3 * mult_mask_s[i].cpu(), self.q_16))
            fhe_mult_mask_s = fhe_builder.build_plain_from_torch(
                mult_mask_s[i])
            masked_s = pmod(
                s.type(torch.int64) * mult_mask_s[i].type(torch.int64),
                self.q_16).type(torch.float32)
            # print("s * mult_mask_s[i]", torch.max(masked_s))
            fhe_s = fhe_builder.build_plain_from_torch(masked_s)
            fhe_alpha_i = fhe_builder.build_plain_from_torch(alpha_i[i] *
                                                             mult_mask_s[i])
            fhe_ci_mask_s = fhe_builder.build_plain_from_torch(ci_mask_s[i])
            fhe_beta_i_c[i] *= fhe_mult_mask_s
            fhe_sum_xor[i] *= fhe_mult_3
            fhe_sum_xor[i] -= fhe_beta_i_c[i]
            fhe_sum_xor[i] += fhe_s
            fhe_sum_xor[i] += fhe_alpha_i
            fhe_sum_xor[i] += fhe_ci_mask_s

            del fhe_mult_3, fhe_mult_mask_s, fhe_s, fhe_alpha_i, fhe_ci_mask_s

        fhe_mult_mask_s = fhe_builder.build_plain_from_torch(
            mult_mask_s[self.work_bit])
        fhe_ci_mask_s = fhe_builder.build_plain_from_torch(
            ci_mask_s[self.work_bit])
        fhe_sum_xor[self.work_bit] *= fhe_mult_mask_s
        fhe_sum_xor[self.work_bit] += fhe_ci_mask_s

        del fhe_mult_mask_s, fhe_ci_mask_s

        if self.is_shuffle:
            with NamedTimerInstance("Shuffle"):
                refresher = EncRefresherServer(
                    self.sum_shape, fhe_builder,
                    self.sub_name("shuffle_refresher"))
                with NamedTimerInstance("refresh"):
                    new_fhe_sum_xor = refresher.request(fhe_sum_xor)
                del fhe_sum_xor
                fhe_sum_xor = self.generate_fhe_shuffled(
                    shuffle_order, new_fhe_sum_xor)
                del refresher

        return fhe_sum_xor
Example #9
def test_rotation():
    modulus = Config.q_23
    degree = Config.n_23
    fhe_builder = FheBuilder(modulus, degree)
    fhe_builder.generate_keys()
    fhe_builder.generate_galois_keys()

    # x = gen_unirand_int_grain(0, modulus-1, degree)
    x = torch.arange(degree)
    with NamedTimerInstance("Fhe Encrypt"):
        enc = fhe_builder.build_enc_from_torch(x)
    enc_less = fhe_builder.build_enc_from_torch(x)
    plain = fhe_builder.build_plain_from_torch(x)

    fhe_builder.noise_budget(enc, "before mul")
    with NamedTimerInstance("ep mult"):
        enc *= plain
        enc_less *= plain
    fhe_builder.noise_budget(enc, "after mul")
    with NamedTimerInstance("ee add"):
        for i in range(128):
            enc += enc_less
    fhe_builder.noise_budget(enc, "after add")
    with NamedTimerInstance("rot"):
        fhe_builder.evaluator.rotate_rows_inplace(enc.cts[0], 64,
                                                  fhe_builder.galois_keys)
    fhe_builder.noise_budget(enc, "after rot")
    print(fhe_builder.decrypt_to_torch(enc))
Example #10
    def test_server():
        rank = Config.server_rank
        sys.stdout = Logger()
        traffic_record = TrafficRecord()
        secure_nn = get_secure_nn()
        secure_nn.set_rank(rank).init_communication(master_address=master_addr,
                                                    master_port=master_port)
        warming_up_cuda()
        secure_nn.fhe_builder_sync()
        load_trunc_params(secure_nn, store_configs)

        net_state = torch.load(net_state_name)
        load_weight_params(secure_nn, store_configs, net_state)

        meta_rg = MetaTruncRandomGenerator()
        meta_rg.reset_seed()

        with NamedTimerInstance("Server Offline"):
            secure_nn.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            secure_nn.online()
            torch_sync()
        traffic_record.reset("server-online")

        secure_nn.check_correctness(check_correctness)
        secure_nn.check_layers(get_plain_net, get_hooking_lst(model_name_base))
        secure_nn.end_communication()
Example #11
    def test_client():
        rank = Config.client_rank
        init_communicate(rank)
        context.set_rank(rank)

        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        input_img = generate_random_data(input_shape)
        trunc1.set_div_to_pow(pow_to_div)
        secure_nn.feed_input(input_img)

        with NamedTimerInstance("Client Offline"):
            secure_nn.offline()
            torch_sync()

        with NamedTimerInstance("Client Online"):
            secure_nn.online()
            torch_sync()

        check_correctness(input_img, secure_nn.get_output())
        end_communicate()
Example #12
def transform2d(img, func, reverse=False):
    N = len(img)
    with NamedTimerInstance("Apply func ina loop"):
        tmp = [func(img[i, :]) for i in range(N)]
    with NamedTimerInstance("list to numpy"):
        img = np.array(tmp)
    img = np.array([func(img[:, i]) for i in range(N)])
    return img.transpose()
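A brief usage sketch for transform2d, mirroring how it is driven with sympy's ntt/intt in the NTT examples further down this page; the concrete sizes below are illustrative:

    import numpy as np
    from sympy.discrete.transforms import ntt, intt

    modulus = 786433  # NTT-friendly prime (3 * 2**18 + 1), as in the NTT tests
    img_hw = 64       # power-of-two length required by the radix-2 NTT
    img = np.arange(img_hw**2).reshape([img_hw, img_hw]) % modulus

    # Row/column NTT followed by the inverse transform recovers the input.
    ntted = transform2d(img, lambda vec: ntt(vec.tolist(), prime=modulus))
    recovered = transform2d(ntted, lambda vec: intt(vec.tolist(), prime=modulus))
    assert np.array_equal(recovered % modulus, img)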
Example #13
def test_conv2d_fhe_ntt_single_thread():
    modulus = 786433
    img_hw = 16
    filter_hw = 3
    padding = 1
    num_input_channel = 64
    num_output_channel = 128
    data_bit = 17
    data_range = 2**data_bit

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)
    # x = torch.arange(get_prod(x_shape)).reshape(x_shape)
    # w = torch.arange(get_prod(w_shape)).reshape(w_shape)

    warming_up_cuda()
    prot = Conv2dFheNttSingleThread(modulus, data_range, img_hw, filter_hw,
                                    num_input_channel, num_output_channel,
                                    fhe_builder, "test_conv2d_fhe_ntt",
                                    padding)

    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_conv2d(w)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    # actual = pmod(y, modulus)
    actual = pmod(y - prot.output_mask_s, modulus)
    # print("actual\n", actual)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        expected = F.conv2d(torch_x, torch_w, padding=padding)
        expected = pmod(expected.reshape(prot.output_shape), modulus)
    # print("expected", expected)
    compare_expected_actual(expected,
                            actual,
                            name="test_conv2d_fhe_ntt_single_thread",
                            get_relative=True)
Example #14
    def test_server():
        init_communicate(Config.server_rank)
        warming_up_cuda()
        prot = ReconToClientServer(num_elem, modulus, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Server online"):
            prot.online(x_s)
            torch_sync()

        end_communicate()
Example #15
def test_fc_fhe_single_thread():
    test_name = "test_fc_fhe_single_thread"
    print(f"\nTest for {test_name}: Start")
    modulus = Config.q_23
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)
    # w = gen_unirand_int_grain(0, modulus, get_prod(w_shape)).reshape(w_shape)
    # x = torch.arange(get_prod(x_shape)).reshape(x_shape)
    # w = torch.arange(get_prod(w_shape)).reshape(w_shape)

    warming_up_cuda()
    prot = FcFheSingleThread(modulus, data_range, num_input_unit,
                             num_output_unit, fhe_builder, test_name)

    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)
    print("prot.num_elem_in_piece", prot.num_elem_in_piece)

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_with_weight(w)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    # actual = pmod(y, modulus)
    actual = pmod(y - prot.output_mask_s, modulus)
    # print("actual\n", actual)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        expected = torch.mm(torch_x, torch_w.t())
        expected = pmod(expected.reshape(prot.output_shape), modulus)
    compare_expected_actual(expected,
                            actual,
                            name=test_name,
                            get_relative=True)
    print(f"\nTest for {test_name}: End")
Example #16
    def test_client():
        init_communicate(Config.client_rank)
        warming_up_cuda()
        prot = ReconToClientClient(num_elem, modulus, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online(x_c)
            torch_sync()

        check_correctness_online(prot.output, x_s, x_c)
        end_communicate()
Example #17
    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    num_elem)
        img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        img_c = pmod(img - img_s, q_23)

        prot = Maxpool2x2DgkServer(num_elem, q_23, q_16, work_bit, data_bit,
                                   img_hw, fhe_builder_16, fhe_builder_23,
                                   "max_dgk")

        blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                               "recon_max_c")
        torch_sync()
        blob_img_c.send(img_c)

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(img_s)
            torch_sync()
        traffic_record.reset("server-online")

        blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_max_c")
        blob_max_c.prepare_recv()
        torch_sync()
        max_c = blob_max_c.get_recv()
        check_correctness_online(img, prot.max_s, max_c)

        end_communicate()
Example #18
    def comm_base_client():
        init_communicate(Config.client_rank)
        comm_base = CommBase(Config.client_rank, comm_name)

        send_float = expect_float.cuda()
        send_double = expect_double.cuda()

        with NamedTimerInstance("Client float and int16"):
            comm_base.recv_torch(
                torch.zeros(num_elem).type(torch.int16), int16_tag)
            comm_base.send_torch(send_float, float_tag)
            comm_base.wait(int16_tag)
            actual_int16 = comm_base.get_tensor(int16_tag).cuda()

        comm_base.recv_torch(
            torch.zeros(num_elem).type(torch.int16), int16_tag)
        dist.barrier()
        with NamedTimerInstance("Client float and int16"):
            comm_base.send_torch(send_float, float_tag)
            comm_base.wait(int16_tag)
            actual_int16 = comm_base.get_tensor(int16_tag).cuda()

        comm_base.recv_torch(
            torch.zeros(num_elem).type(torch.uint8), uint8_tag)
        dist.barrier()
        with NamedTimerInstance("Client double and uint8"):
            comm_base.send_torch(send_double, double_tag)
            comm_base.wait(uint8_tag)
            actual_uint8 = comm_base.get_tensor(uint8_tag).cuda()

        comm_base.recv_torch(
            torch.zeros(num_elem).type(torch.uint8), uint8_tag)
        dist.barrier()
        with NamedTimerInstance("Client double and uint8"):
            comm_base.send_torch(send_double, double_tag)
            comm_base.wait(uint8_tag)
            actual_uint8 = comm_base.get_tensor(uint8_tag).cuda()

        dist.barrier()
        compare_expected_actual(expect_int16.cuda(),
                                actual_int16,
                                name="int16",
                                get_relative=True)
        compare_expected_actual(expect_uint8.cuda(),
                                actual_uint8,
                                name="uint8",
                                get_relative=True)

        dist.destroy_process_group()
Example #19
def test_torch_ntt():
    modulus = 786433
    img_hw = 64
    filter_hw = 3
    padding = 1
    data_bit = 17

    len_vector = img_hw
    data_range = 2**data_bit
    root, mod, ntt_mat, inv_mat = generate_ntt_param(modulus, len_vector,
                                                     data_range)

    x = np.arange(img_hw**2).reshape([img_hw, img_hw]).astype(int)
    w = np.arange(filter_hw**2).reshape([filter_hw, filter_hw]).astype(int)
    x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              img_hw**2).reshape([img_hw, img_hw]).double()
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              filter_hw**2).reshape([filter_hw,
                                                     filter_hw]).double()

    ntt_mat = ntt_mat.double()
    inv_mat = inv_mat.double()

    with NamedTimerInstance("Mat NTT 2d"):
        ntted = ntt_mat2d(ntt_mat, mod, x)
    reved = ntt_mat2d(inv_mat, mod, ntted)
    expected = pmod(x, modulus).type(torch.int)
    actual = pmod(reved, modulus).type(torch.int)
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
Example #20
    def response(self):
        self.prepare_recv()
        torch_sync()

        fhe_masked = self.common.masked.get_recv()

        with NamedTimerInstance("client refresh reencrypt"):
            if len(self.shape) == 2:
                for i in range(self.shape[0]):
                    sub_dec = self.fhe_builder.decrypt_to_torch(fhe_masked[i])
                    self.refreshed[i].encrypt_additive(sub_dec)
            else:
                self.refreshed.encrypt_additive(
                    self.fhe_builder.decrypt_to_torch(fhe_masked))

        # def sub_reencrypt(sub_enc):
        #     sub_dec = self.fhe_builder.decrypt_to_torch(sub_enc)
        #     sub_refreshed = self.fhe_builder.build_enc_from_torch(sub_dec)
        #     del sub_dec
        #     return sub_refreshed
        # with NamedTimerInstance("client refresh reencrypt"):
        #     self.refreshed = sub_handle(sub_reencrypt, fhe_masked)

        self.common.refreshed.send(self.refreshed)

        torch_sync()
        delete_fhe(fhe_masked)
        delete_fhe(self.refreshed)
Example #21
    def test_client():
        init_communicate(Config.client_rank)
        warming_up_cuda()
        prot = SwapToClientOfflineClient(num_elem, modulus, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(y_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online(x_c)
            torch_sync()

        blob_output_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                                  "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()
Example #22
def test_ntt_conv():
    modulus = 786433
    img_hw = 16
    filter_hw = 3
    padding = 1
    num_input_channel = 64
    num_output_channel = 128
    data_bit = 17
    data_range = 2**data_bit

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              get_prod(w_shape)).reshape(w_shape)
    # x = torch.arange(get_prod(x_shape)).reshape(x_shape)
    # w = torch.arange(get_prod(w_shape)).reshape(w_shape)

    conv2d_ntt = Conv2dNtt(modulus, data_range, img_hw, filter_hw,
                           num_input_channel, num_output_channel,
                           "test_conv2d_ntt", padding)
    y = conv2d_ntt.conv2d(x, w)

    with NamedTimerInstance("ntt x"):
        conv2d_ntt.load_and_ntt_x(x)
    with NamedTimerInstance("ntt w"):
        conv2d_ntt.load_and_ntt_w(w)
    with NamedTimerInstance("conv2d"):
        y = conv2d_ntt.conv2d_loaded()
    actual = pmod(y, modulus)
    # print("actual\n", actual)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        expected = F.conv2d(torch_x, torch_w, padding=padding)
        expected = pmod(expected.reshape(conv2d_ntt.y_shape), modulus)
    # print("expected", expected)
    compare_expected_actual(expected,
                            actual,
                            name="ntt",
                            get_relative=True,
                            show_where_err=False)
Example #23
def test_nnt_conv_single_channel():
    modulus = 786433
    img_hw = 6
    filter_hw = 3
    padding = 1
    data_bit = 17
    data_range = 2**data_bit

    conv_hw = img_hw + 2 * padding
    padded_hw = get_pow_2_ceil(conv_hw)
    output_hw = img_hw + 2 * padding - (filter_hw - 1)
    output_offset = filter_hw - 2
    x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              img_hw**2).reshape([img_hw, img_hw
                                                  ]).numpy().astype(int)
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              filter_hw**2).reshape([filter_hw, filter_hw
                                                     ]).numpy().astype(int)
    x = np.arange(img_hw**2).reshape([img_hw, img_hw]).astype(int)
    w = np.arange(filter_hw**2).reshape([filter_hw, filter_hw]).astype(int)
    padded_x = pad_to_size(x, padded_hw)
    padded_w = pad_to_size(np.rot90(w, 2), padded_hw)
    print(padded_x)
    print(padded_w)
    with NamedTimerInstance("NTT2D, Sympy"):
        ntted_x = transform2d(
            padded_x, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
        ntted_w = transform2d(
            padded_w, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
    with NamedTimerInstance("Point-wise Dot"):
        doted = ntted_x * ntted_w
    with NamedTimerInstance("iNTT2D"):
        reved = transform2d(
            doted, lambda sub_img: intt(sub_img.tolist(), prime=modulus))
    actual = reved[output_offset:output_hw + output_offset,
                   output_offset:output_hw + output_offset]
    print("reved\n", reved)
    print("actual\n", actual)

    torch_x = torch.tensor(x).reshape([1, 1, img_hw, img_hw])
    torch_w = torch.tensor(w).reshape([1, 1, filter_hw, filter_hw])
    expected = F.conv2d(torch_x, torch_w, padding=1)
    expected = pmod(expected.reshape(output_hw, output_hw), modulus)
    print("expected", expected)
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
Example #24
def run_secure_nn_client_with_random_data(secure_nn, check_correctness,
                                          master_address, master_port):
    rank = Config.client_rank
    traffic_record = TrafficRecord()
    secure_nn.set_rank(rank).init_communication(master_address=master_address,
                                                master_port=master_port)
    warming_up_cuda()
    secure_nn.fhe_builder_sync().fill_random_input()

    with NamedTimerInstance("Client Offline"):
        secure_nn.offline()
        torch_sync()
    traffic_record.reset("client-offline")

    with NamedTimerInstance("Client Online"):
        secure_nn.online()
        torch_sync()
    traffic_record.reset("client-online")

    secure_nn.check_correctness(check_correctness).end_communication()
Example #25
def test_basic_ntt():
    modulus = 786433
    img_hw = 34
    x = gen_unirand_int_grain(0, modulus - 1,
                              img_hw**2).reshape([img_hw, img_hw
                                                  ]).numpy().astype(int)
    padded = np.zeros([get_pow_2_ceil(img_hw),
                       get_pow_2_ceil(img_hw)]).astype(int)
    padded[:img_hw, :img_hw] = x
    x = padded
    expected = x[:, :]
    with NamedTimerInstance("NTT2D, Sympy"):
        ntted = transform2d(
            x, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
    with NamedTimerInstance("iNTT2D"):
        reved = transform2d(
            ntted, lambda sub_img: intt(sub_img.tolist(), prime=modulus))
    actual = reved
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
Example #26
    def test_server():
        init_communicate(Config.server_rank)
        warming_up_cuda()
        prot = SwapToClientOfflineServer(num_elem, modulus, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Server online"):
            prot.online(x_s)
            torch_sync()

        blob_output_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                                  "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(x_s, x_c, prot.output_s, output_c)

        end_communicate()
Example #27
    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        prot = Maxpool2x2DgkClient(num_elem, q_23, q_16, work_bit, data_bit,
                                   img_hw, fhe_builder_16, fhe_builder_23,
                                   "max_dgk")
        blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                               "recon_max_c")
        blob_img_c.prepare_recv()
        torch_sync()
        img_c = blob_img_c.get_recv()

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(img_c)
            torch_sync()
        traffic_record.reset("client-online")

        blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_max_c")
        torch_sync()
        blob_max_c.send(prot.max_c)
        end_communicate()
Example #28
    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()
        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        a = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  num_elem)
        a_c = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        a_s = pmod(a - a_c, q_23)

        prot = ReluDgkClient(num_elem, q_23, q_16, work_bit, data_bit,
                             fhe_builder_16, fhe_builder_23, "relu_dgk")

        blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
        blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
        torch_sync()
        blob_a_s.send(a_s)
        blob_max_s.prepare_recv()

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(a_c)
            torch_sync()
        traffic_record.reset("client-online")

        max_s = blob_max_s.get_recv()
        check_correctness_online(a, max_s, prot.max_c)

        torch.cuda.empty_cache()
        end_communicate()
Example #29
    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = FcSecureClient(modulus, data_range, num_input_unit,
                              num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()
Example #30
    def test_client():
        init_communicate(Config.client_rank)
        prot = FcFheClient(modulus, data_range, num_input_unit,
                           num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()