Example #1
    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    num_elem)
        img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        img_c = pmod(img - img_s, q_23)

        prot = Maxpool2x2DgkServer(num_elem, q_23, q_16, work_bit, data_bit,
                                   img_hw, fhe_builder_16, fhe_builder_23,
                                   "max_dgk")

        blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                               "recon_max_c")
        torch_sync()
        blob_img_c.send(img_c)

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(img_s)
            torch_sync()
        traffic_record.reset("server-online")

        blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_max_c")
        blob_max_c.prepare_recv()
        torch_sync()
        max_c = blob_max_c.get_recv()
        check_correctness_online(img, prot.max_s, max_c)

        end_communicate()
Example #2
def test_recon_to_client_comm():
    test_name = "test_recon_to_client_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_elem = 2**17

    print(f"Number of element: {num_elem}")

    x_s = gen_unirand_int_grain(0, modulus - 1, num_elem).cpu()
    x_c = gen_unirand_int_grain(0, modulus - 1, num_elem).cpu()

    def check_correctness_online(output, input_s, input_c):
        expected = pmod(output.cuda(), modulus)
        actual = pmod(input_s.cuda() + input_c.cuda(), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " online",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        warming_up_cuda()
        prot = ReconToClientServer(num_elem, modulus, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Server online"):
            prot.online(x_s)
            torch_sync()

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        warming_up_cuda()
        prot = ReconToClientClient(num_elem, modulus, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online(x_c)
            torch_sync()

        check_correctness_online(prot.output, x_s, x_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Example #3
    def request(self, enc):
        self.prepare_recv()
        torch_sync()
        self.r_s = gen_unirand_int_grain(0, self.modulus - 1, self.shape)

        if len(self.shape) == 2:
            pt = []
            for i in range(self.shape[0]):
                pt.append(self.fhe_builder.build_plain_from_torch(self.r_s[i]))
                enc[i] += pt[i]

            self.common.masked.send(enc)
            refreshed = self.common.refreshed.get_recv()

            for i in range(self.shape[0]):
                refreshed[i] -= pt[i]
            delete_fhe(enc)
            delete_fhe(pt)
            torch_sync()
            return refreshed
        else:
            pt = self.fhe_builder.build_plain_from_torch(self.r_s)
            enc += pt
            self.common.masked.send(enc)
            refreshed = self.common.refreshed.get_recv()
            refreshed -= pt
            delete_fhe(enc)
            delete_fhe(pt)
            torch_sync()
            return refreshed
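The request() method above blinds the ciphertext before handing it over for a refresh: a uniformly random mask r_s is added, the other side decrypts and re-encrypts the masked value with fresh noise, and the mask is subtracted from what comes back. The following sketch shows that blinding arithmetic in the clear with plain Python integers (no FHE); the modulus and value are illustrative assumptions.

# In-the-clear sketch of the mask-and-refresh idea behind request(); no FHE,
# all values are illustrative.
import random

modulus = 786433
value = 123456                    # the number hidden inside `enc`

r_s = random.randrange(modulus)   # uniform mask, analogous to self.r_s
masked = (value + r_s) % modulus  # what the other party sees: uniformly random

# The other party decrypts `masked` and re-encrypts it with fresh noise; the
# underlying number is unchanged, so it is simply passed through here.
refreshed = masked

recovered = (refreshed - r_s) % modulus
assert recovered == value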
Example #4
def test_ntt_conv():
    modulus = 786433
    img_hw = 16
    filter_hw = 3
    padding = 1
    num_input_channel = 64
    num_output_channel = 128
    data_bit = 17
    data_range = 2**data_bit

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              get_prod(w_shape)).reshape(w_shape)
    # x = torch.arange(get_prod(x_shape)).reshape(x_shape)
    # w = torch.arange(get_prod(w_shape)).reshape(w_shape)

    conv2d_ntt = Conv2dNtt(modulus, data_range, img_hw, filter_hw,
                           num_input_channel, num_output_channel,
                           "test_conv2d_ntt", padding)
    y = conv2d_ntt.conv2d(x, w)

    with NamedTimerInstance("ntt x"):
        conv2d_ntt.load_and_ntt_x(x)
    with NamedTimerInstance("ntt w"):
        conv2d_ntt.load_and_ntt_w(w)
    with NamedTimerInstance("conv2d"):
        y = conv2d_ntt.conv2d_loaded()
    actual = pmod(y, modulus)
    # print("actual\n", actual)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        expected = F.conv2d(torch_x, torch_w, padding=padding)
        expected = pmod(expected.reshape(conv2d_ntt.y_shape), modulus)
    # print("expected", expected)
    compare_expected_actual(expected,
                            actual,
                            name="ntt",
                            get_relative=True,
                            show_where_err=False)
Example #5
def test_nnt_conv_single_channel():
    modulus = 786433
    img_hw = 6
    filter_hw = 3
    padding = 1
    data_bit = 17
    data_range = 2**data_bit

    conv_hw = img_hw + 2 * padding
    padded_hw = get_pow_2_ceil(conv_hw)
    output_hw = img_hw + 2 * padding - (filter_hw - 1)
    output_offset = filter_hw - 2
    # The deterministic arange inputs below override the random ones, so the
    # random generation is kept only as a commented-out alternative data source.
    # x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
    #                           img_hw**2).reshape([img_hw, img_hw]).numpy().astype(int)
    # w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
    #                           filter_hw**2).reshape([filter_hw, filter_hw]).numpy().astype(int)
    x = np.arange(img_hw**2).reshape([img_hw, img_hw]).astype(int)
    w = np.arange(filter_hw**2).reshape([filter_hw, filter_hw]).astype(int)
    padded_x = pad_to_size(x, padded_hw)
    padded_w = pad_to_size(np.rot90(w, 2), padded_hw)
    print(padded_x)
    print(padded_w)
    with NamedTimerInstance("NTT2D, Sympy"):
        ntted_x = transform2d(
            padded_x, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
        ntted_w = transform2d(
            padded_w, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
    with NamedTimerInstance("Point-wise Dot"):
        doted = ntted_x * ntted_w
    with NamedTimerInstance("iNTT2D"):
        reved = transform2d(
            doted, lambda sub_img: intt(sub_img.tolist(), prime=modulus))
    actual = reved[output_offset:output_hw + output_offset,
                   output_offset:output_hw + output_offset]
    print("reved\n", reved)
    print("actual\n", actual)

    torch_x = torch.tensor(x).reshape([1, 1, img_hw, img_hw])
    torch_w = torch.tensor(w).reshape([1, 1, filter_hw, filter_hw])
    expected = F.conv2d(torch_x, torch_w, padding=1)
    expected = pmod(expected.reshape(output_hw, output_hw), modulus)
    print("expected", expected)
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
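Both NTT convolution tests rely on the convolution theorem: pointwise products in the NTT domain correspond to circular convolution modulo the prime. The 1D sanity check below uses sympy's ntt/intt exactly as above; the sequences and their length are arbitrary illustrative choices, and 786433 = 3 * 2**18 + 1 supports transforms of this size.

# 1D check of the convolution theorem with sympy's ntt/intt: pointwise
# multiplication in the NTT domain equals circular convolution mod the prime.
from sympy import intt, ntt

prime = 786433
n = 8
x = [3, 1, 4, 1, 5, 9, 2, 6]
w = [2, 7, 1, 8, 2, 8, 1, 8]

# NTT route: transform, multiply pointwise, inverse transform.
prod = [(a * b) % prime for a, b in zip(ntt(x, prime=prime), ntt(w, prime=prime))]
via_ntt = intt(prod, prime=prime)

# Direct circular convolution modulo the prime.
direct = [sum(x[j] * w[(i - j) % n] for j in range(n)) % prime for i in range(n)]

assert [int(v) for v in via_ntt] == direct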
Example #6
def test_noise():
    print()
    print("Test for FheBuilder: start")
    modulus, degree = 12289, 2048
    num_elem = 2 ** 14
    fhe_builder = FheBuilder(modulus, degree)
    fhe_builder.generate_keys()
    print(f"modulus: {modulus}, degree: {degree}")
    print()

    gpu_tensor = gen_unirand_int_grain(0, 2, num_elem)
    print(gpu_tensor)
    gpu_tensor_rev = gen_unirand_int_grain(0, modulus - 1, num_elem)
    with NamedTimerInstance(f"build_plain_from_torch with num_elem: {num_elem}"):
        plain = fhe_builder.build_plain_from_torch(gpu_tensor)
    with NamedTimerInstance(f"plain.export_as_torch_gpu() with num_elem: {num_elem}"):
        tensor_from_plain = plain.export_as_torch()
    print("Fhe Plain encrypt and decrypt: ", end="")
    assert(compare_expected_actual(gpu_tensor, tensor_from_plain, verbose=True, get_relative=True).RelAvgDiff == 0)
    print()

    with NamedTimerInstance(f"fhe_builder.build_enc with num_elem: {num_elem}"):
        cipher = fhe_builder.build_enc(num_elem)
    with NamedTimerInstance(f"cipher.encrypt_additive with num_elem: {num_elem}"):
        cipher.encrypt_additive(gpu_tensor)
    with NamedTimerInstance(f"fhe_builder.decrypt_to_torch with num_elem: {num_elem}"):
        tensor_from_cipher = fhe_builder.decrypt_to_torch(cipher)
    print("Fhe Enc encrypt and decrypt: ", end="")
    assert(compare_expected_actual(gpu_tensor, tensor_from_cipher, verbose=True, get_relative=True).RelAvgDiff == 0)
    print()

    pt = fhe_builder.build_plain_from_torch(gpu_tensor)
    ct = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    fhe_builder.noise_budget(ct, name="before ep")
    with NamedTimerInstance(f"EP Mult with num_elem: {num_elem}"):
        ct *= pt
    fhe_builder.noise_budget(ct, name="after ep")
    expected = pmod(gpu_tensor.double() * gpu_tensor_rev.double(), modulus)
    actual = fhe_builder.decrypt_to_torch(ct)
    assert(compare_expected_actual(expected, actual, verbose=True, get_relative=True).RelAvgDiff == 0)
    print()

    print("Test for FheBuilder: Finish")
Example #7
    def test_client():
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address, master_port=master_port)
        traffic_record = TrafficRecord()
        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        a = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
        a_c = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        a_s = pmod(a - a_c, q_23)

        prot = ReluDgkClient(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "relu_dgk")

        blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
        blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
        torch_sync()
        blob_a_s.send(a_s)
        blob_max_s.prepare_recv()

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(a_c)
            torch_sync()
        traffic_record.reset("client-online")

        max_s = blob_max_s.get_recv()
        check_correctness_online(a, max_s, prot.max_c)

        torch.cuda.empty_cache()
        end_communicate()
Example #8
    def offline(self):
        self.offline_recv()

        self.beta_i_c = gen_unirand_int_grain(
            0, self.q_16 - 1, self.decomp_bit_shape).to(Config.device)
        self.delta_b_c = gen_unirand_int_grain(0, self.q_23 - 1,
                                               self.num_elem).to(Config.device)
        self.z_work_c = gen_unirand_int_grain(0, self.q_23 - 1,
                                              self.num_elem).to(Config.device)
        self.beta_i_zeros = torch.zeros_like(self.beta_i_c)
        self.fast_ones = torch.ones(self.num_elem).to(Config.device)
        self.fast_zeros = torch.zeros(self.num_elem).to(Config.device)
        self.fast_ones_c_i = torch.ones(self.sum_shape).float().to(
            Config.device)
        self.fast_zeros_c_i = torch.zeros(self.sum_shape).float().to(
            Config.device)

        self.common.beta_i_c.send_from_torch(self.beta_i_c)
        self.common.delta_b_c.send_from_torch(self.delta_b_c)
        self.common.z_work_c.send_from_torch(self.z_work_c)

        if self.is_shuffle:
            self.sum_c_refresher = EncRefresherClient(
                self.sum_shape, self.fhe_builder_16,
                self.sub_name("shuffle_refresher"))

        self.mod_div_offline()

        refresher_ab_xor_c = EncRefresherClient(
            self.decomp_bit_shape, self.fhe_builder_16,
            self.common.sub_name("refresher_ab_xor_c"))
        refresher_ab_xor_c.response()
        self.sum_c_i_offline()

        self.c_i_c = self.common.c_i_c.get_recv_decrypt()
        self.delta_xor_c = self.common.delta_xor_c.get_recv_decrypt()
        self.dgk_x_leq_y_c = self.common.dgk_x_leq_y_c.get_recv_decrypt()
        self.dgk_x_leq_y_c = pmod(
            self.dgk_x_leq_y_c + self.correct_mod_div_work_c, self.q_23)

        self.online_recv()
        torch_sync()
Example #9
    def test_1d_server():
        init_communicate(Config.server_rank)
        shape = num_elem
        tensor = gen_unirand_int_grain(0, modulus, shape)
        refresher = EncRefresherServer(shape, fhe_builder, "test_1d_refresher")
        enc = fhe_builder.build_enc_from_torch(tensor)
        refreshed = refresher.request(enc)
        tensor_refreshed = fhe_builder.decrypt_to_torch(refreshed)
        compare_expected_actual(tensor, tensor_refreshed, get_relative=True, name="1d_refresh")

        end_communicate()
Example #10
    def mod_div_offline(self):
        fhe_builder = self.fhe_builder_23
        self.elem_zeros = torch.zeros(self.num_elem).to(Config.device)
        self.pre_mod_div_c = gen_unirand_int_grain(
            0, self.q_23 - 1, self.num_elem).to(Config.device)
        fhe_correct_mod_div_work = fhe_builder.build_enc_from_torch(
            self.pre_mod_div_c)

        self.common.fhe_pre_corr_mod.send(fhe_correct_mod_div_work)
        fhe_corr_mod_c = self.common.fhe_corr_mod_c.get_recv()
        self.correct_mod_div_work_c = fhe_builder.decrypt_to_torch(
            fhe_corr_mod_c)
Example #11
    def mod_div_offline(self):
        fhe_builder = self.fhe_builder_23

        self.elem_zeros = torch.zeros(self.num_elem).to(Config.device)
        self.correct_mod_div_work_mult = torch.where(
            (self.r < self.nullify_threshold), self.elem_zeros,
            self.elem_zeros + self.q_23 // self.work_range).double()
        self.correct_mod_div_work_mask_s = gen_unirand_int_grain(
            0, self.q_23 - 1, self.num_elem).to(Config.device)
        fhe_mult = fhe_builder.build_plain_from_torch(
            self.correct_mod_div_work_mult)
        fhe_bias = fhe_builder.build_plain_from_torch(
            self.correct_mod_div_work_mask_s)
        fhe_correct_mod_div_work = self.common.fhe_pre_corr_mod.get_recv()
        fhe_correct_mod_div_work *= fhe_mult
        fhe_correct_mod_div_work += fhe_bias
        del fhe_mult, fhe_bias

        self.common.fhe_corr_mod_c.send(fhe_correct_mod_div_work)
Example #12
def test_basic_ntt():
    modulus = 786433
    img_hw = 34
    x = gen_unirand_int_grain(0, modulus - 1,
                              img_hw**2).reshape([img_hw, img_hw
                                                  ]).numpy().astype(int)
    padded = np.zeros([get_pow_2_ceil(img_hw),
                       get_pow_2_ceil(img_hw)]).astype(int)
    padded[:img_hw, :img_hw] = x
    x = padded
    expected = x[:, :]
    with NamedTimerInstance("NTT2D, Sympy"):
        ntted = transform2d(
            x, lambda sub_img: ntt(sub_img.tolist(), prime=modulus))
    with NamedTimerInstance("iNTT2D"):
        reved = transform2d(
            ntted, lambda sub_img: intt(sub_img.tolist(), prime=modulus))
    actual = reved
    compare_expected_actual(expected, actual, name="ntt", get_relative=True)
Example #13
    def masking_output(self):
        self.output_mask_s = gen_unirand_int_grain(
            0, self.modulus - 1, get_prod(self.y_shape)).reshape(self.y_shape)
        # self.output_mask_s = torch.ones(self.y_shape)
        ntted_mask = self.conv2d_ntt.ntt_output_masking(self.output_mask_s)

        pod_vector = uIntVector()
        pt = Plaintext()
        for idx_output_batch in range(self.num_output_batch):
            encoding_tensor = torch.zeros(self.degree, dtype=torch.float)
            for index_piece in range(self.num_rotation):
                span = self.num_elem_in_padded
                start_piece = index_piece * span
                index_output_channel = self.index_output_piece_to_channel(
                    idx_output_batch, index_piece)
                if index_output_channel is False:
                    continue
                encoding_tensor[start_piece:start_piece + span] = ntted_mask[
                    index_output_channel].reshape(-1)
            encoding_tensor = pmod(encoding_tensor, self.modulus)
            pod_vector.from_np(encoding_tensor.numpy().astype(np.uint64))
            self.batch_encoder.encode(pod_vector, pt)
            self.evaluator.add_plain_inplace(self.output_cts[idx_output_batch],
                                             pt)
Example #14
    def generate_random_data(self, shape):
        return gen_unirand_int_grain(-self.data_range//2 + 1, self.data_range//2, get_prod(shape)).reshape(shape)
Example #15
    def generate_random(self):
        return gen_unirand_int_grain(0, self.modulus - 1, self.num_elem)
Example #16
def test_shares_mult():
    print("\nTest for Shares Mult: Start")
    modulus = Config.q_23
    num_elem = 2**17
    print(f"Number of element: {num_elem}")

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    a = gen_unirand_int_grain(0, modulus - 1, num_elem)
    a_s = gen_unirand_int_grain(0, modulus - 1, num_elem)
    a_c = pmod(a - a_s, modulus)
    b = gen_unirand_int_grain(0, modulus - 1, num_elem)
    b_s = gen_unirand_int_grain(0, modulus - 1, num_elem)
    b_c = pmod(b - b_s, modulus)

    def check_correctness_offline(u, v, z_s, z_c):
        expected = pmod(
            u.double().to(Config.device) * v.double().to(Config.device),
            modulus)
        actual = pmod(z_s + z_c, modulus)
        compare_expected_actual(expected,
                                actual,
                                name="shares_mult_offline",
                                get_relative=True)

    def check_correctness_online(a, b, c_s, c_c):
        expected = pmod(
            a.double().to(Config.device) * b.double().to(Config.device),
            modulus)
        actual = pmod(c_s + c_c, modulus)
        compare_expected_actual(expected,
                                actual,
                                name="shares_mult_online",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = SharesMultServer(num_elem, modulus, fhe_builder,
                                "test_shares_mult")
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(a_s, b_s)
            torch_sync()

        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_z_c")
        blob_u_c.prepare_recv()
        blob_v_c.prepare_recv()
        blob_z_c.prepare_recv()
        torch_sync()
        u_c = blob_u_c.get_recv()
        v_c = blob_v_c.get_recv()
        z_c = blob_z_c.get_recv()
        u = pmod(prot.u_s + u_c, modulus)
        v = pmod(prot.v_s + v_c, modulus)
        check_correctness_online(u, v, prot.z_s, z_c)

        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        blob_c_c.prepare_recv()
        torch_sync()
        c_c = blob_c_c.get_recv()
        check_correctness_online(a, b, prot.c_s, c_c)
        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = SharesMultClient(num_elem, modulus, fhe_builder,
                                "test_shares_mult")
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online(a_c, b_c)
            torch_sync()

        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_z_c")
        torch_sync()
        blob_u_c.send(prot.u_c)
        blob_v_c.send(prot.v_c)
        blob_z_c.send(prot.z_c)
        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        torch_sync()
        blob_c_c.send(prot.c_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print("\nTest for Shares Mult: End")
Example #17
def test_conv2d_secure_comm(input_sid,
                            master_address,
                            master_port,
                            setting=(16, 3, 128, 128)):
    test_name = "Conv2d Secure Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    padding = 1
    img_hw, filter_hw, num_input_channel, num_output_channel = setting
    data_bit = 17
    data_range = 2**data_bit
    # n_23 = 8192
    n_23 = 16384
    print(f"Setting covn2d: img_hw: {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]
    b_shape = [num_output_channel]

    fhe_builder = FheBuilder(modulus, n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                 get_prod(b_shape)).reshape(b_shape)
    input = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  get_prod(x_shape)).reshape(x_shape)
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(input - input_c, modulus)

    def check_correctness_online(x, w, b, output_s, output_c):
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        torch_b = b.cuda().double() if b is not None else None
        expected = F.conv2d(torch_x, torch_w, padding=padding, bias=torch_b)
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " online",
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureServer(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm",
                                  padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(input, weight, bias, prot.output_s, output_c)

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureClient(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm",
                                  padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()

    print(f"\nTest for {test_name}: End")
Example #18
def test_conv2d_fhe_ntt_comm():
    test_name = "Conv2d Fhe NTT Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    img_hw = 2
    filter_hw = 3
    padding = 1
    num_input_channel = 512
    num_output_channel = 512
    data_bit = 17
    data_range = 2**data_bit
    print(f"Setting: img_hw {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    input_mask = gen_unirand_int_grain(0, modulus - 1,
                                       get_prod(x_shape)).reshape(x_shape)

    # input_mask = torch.arange(get_prod(x_shape)).reshape(x_shape)

    def check_correctness_offline(x, w, output_mask, output_c):
        actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = F.conv2d(torch_x, torch_w, padding=padding)
        expected = pmod(expected.reshape(output_mask.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " offline",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = Conv2dFheNttServer(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_fhe_ntt_comm",
                                  padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s,
                                  output_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = Conv2dFheNttClient(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_fhe_ntt_comm",
                                  padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Example #19
    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, n_16)
        fhe_builder_23 = FheBuilder(q_23, n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        dgk = DgkBitClient(num_elem, q_23, q_16, work_bit, data_bit,
                           fhe_builder_16, fhe_builder_23, "DgkBitTest")

        x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  num_elem)
        y = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  num_elem)
        x_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    num_elem)
        y_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    num_elem)
        x_s = pmod(x - x_c, q_23)
        y_s = pmod(y - y_c, q_23)
        y_sub_x_s = pmod(y_s - x_s, q_23)

        x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
        y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
        y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                   "y_sub_x_s")
        torch_sync()
        x_blob.send(x)
        y_blob.send(y)
        y_sub_x_s_blob.send(y_sub_x_s)

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            dgk.offline()
        y_sub_x_c = pmod(y_c - x_c, q_23)
        traffic_record.reset("client-offline")
        torch_sync()

        with NamedTimerInstance("Client Online"):
            dgk.online(y_sub_x_c)
        traffic_record.reset("client-online")

        dgk_x_leq_y_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                       "dgk_x_leq_y_c")
        correct_mod_div_work_c_blob = BlobTorch(num_elem, torch.float,
                                                dgk.comm_base,
                                                "correct_mod_div_work_c")
        z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
        torch_sync()
        dgk_x_leq_y_c_blob.send(dgk.dgk_x_leq_y_c)
        correct_mod_div_work_c_blob.send(dgk.correct_mod_div_work_c)
        z_blob.send(dgk.z)
        end_communicate()
Example #20
def test_fc_secure_comm(input_sid,
                        master_address,
                        master_port,
                        setting=(512, 512)):
    test_name = "test_fc_secure_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit, num_output_unit = setting
    data_bit = 17

    print(f"Setting fc: "
          f"num_input_unit: {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]
    b_shape = [num_output_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                 get_prod(b_shape)).reshape(b_shape)
    input = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  get_prod(x_shape)).reshape(x_shape)
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(input - input_c, modulus)

    def check_correctness_online(x, w, output_s, output_c):
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = torch.mm(torch_x, torch_w.t())
        if bias is not None:
            expected += bias.cuda().double()
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " online",
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = FcSecureServer(modulus, data_range, num_input_unit,
                              num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(input, weight, prot.output_s, output_c)

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = FcSecureClient(modulus, data_range, num_input_unit,
                              num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Example #21
def test_fc_fhe_comm():
    test_name = "test_fc_fhe_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17

    print(f"Setting: num_input_unit {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    input_mask = gen_unirand_int_grain(0, modulus - 1,
                                       get_prod(x_shape)).reshape(x_shape)

    # input_mask = torch.arange(get_prod(x_shape)).reshape(x_shape)

    def check_correctness_offline(x, w, output_mask, output_c):
        actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = torch.mm(torch_x, torch_w.t())
        expected = pmod(expected.reshape(output_mask.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " offline",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = FcFheServer(modulus, data_range, num_input_unit,
                           num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s,
                                  output_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = FcFheClient(modulus, data_range, num_input_unit,
                           num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Example #22
    def offline(self):
        self.offline_recv()

        self.delta_a = gen_unirand_int_grain(0, 1,
                                             self.num_elem).to(Config.device)
        self.s = pmod(1 - 2 * self.delta_a, self.q_16)
        # self.r = gen_unirand_int_grain(0, 2 ** (self.work_bit + 1) - 1, self.num_elem).to(Config.device)
        self.r = gen_unirand_int_grain(0, self.q_23 - 1,
                                       self.num_elem).to(Config.device)
        self.alpha = pmod(self.r, self.work_range)
        self.alpha_i = self.common.decomp_to_bit(self.alpha).to(Config.device)
        self.beta_i_mask_s = gen_unirand_int_grain(
            0, self.q_16 - 1, self.decomp_bit_shape).to(Config.device)
        self.ci_mask_s = gen_unirand_int_grain(
            0, self.q_16 - 1,
            [self.work_bit + 1, self.num_elem]).to(Config.device)
        self.ci_mult_mask_s = gen_unirand_int_grain(
            1, self.q_16 - 1,
            [self.work_bit + 1, self.num_elem]).to(Config.device)
        self.shuffle_order = torch.rand([self.work_bit + 1, self.num_elem
                                         ]).argsort(dim=0).to(Config.device)
        self.delta_xor_mask_s = gen_unirand_int_grain(
            0, self.q_16 - 1, self.num_elem).to(Config.device)
        self.dgk_x_leq_y_mask_s = gen_unirand_int_grain(
            0, self.q_23 - 1, self.num_elem).to(Config.device)
        self.fast_zeros_sum_xor = torch.zeros(self.sum_shape).to(Config.device)

        self.mod_div_offline()

        refresher_ab_xor_c = EncRefresherServer(
            self.decomp_bit_shape, self.fhe_builder_16,
            self.common.sub_name("refresher_ab_xor_c"))

        fhe_beta_i_c = self.common.beta_i_c.get_recv()
        fhe_beta_i_c_for_sum_c = [
            fhe_beta_i_c[i].copy() for i in range(len(fhe_beta_i_c))
        ]
        fhe_alpha_beta_xor_c = self.xor_alpha_known_offline(
            self.alpha_i, fhe_beta_i_c, self.beta_i_mask_s)
        fhe_alpha_beta_xor_c = refresher_ab_xor_c.request(fhe_alpha_beta_xor_c)
        fhe_c_i_c = self.sum_c_i_offline(self.delta_a, fhe_beta_i_c_for_sum_c,
                                         fhe_alpha_beta_xor_c, self.s,
                                         self.alpha_i, self.ci_mask_s,
                                         self.ci_mult_mask_s,
                                         self.shuffle_order)
        self.common.c_i_c.send(fhe_c_i_c)

        fhe_delta_b_c = self.common.delta_b_c.get_recv()
        fhe_delta_xor_c = self.xor_delta_known_offline(self.delta_a,
                                                       fhe_delta_b_c,
                                                       self.delta_xor_mask_s)
        self.common.delta_xor_c.send(fhe_delta_xor_c)

        fhe_z_work_c = self.common.z_work_c.get_recv()
        fhe_z_work_c -= fhe_delta_xor_c
        fhe_z_work_c -= self.fhe_builder_23.build_plain_from_torch(
            self.dgk_x_leq_y_mask_s)
        self.common.dgk_x_leq_y_c.send(fhe_z_work_c)

        for ct in fhe_c_i_c + fhe_beta_i_c + fhe_beta_i_c_for_sum_c:
            del ct
        del fhe_beta_i_c, fhe_beta_i_c_for_sum_c, fhe_alpha_beta_xor_c, fhe_c_i_c, fhe_delta_b_c, fhe_delta_xor_c, fhe_z_work_c
        del refresher_ab_xor_c

        self.online_recv()
        torch_sync()
Example #23
def test_fhe_builder():
    print()
    print("Test for FheBuilder: start")
    modulus, degree = 12289, 2048
    # modulus, degree = 65537, 2048
    # modulus, degree = 786433, 4096
    # modulus, degree = 65537, 4096
    num_elem = 2 ** 17 - 1
    fhe_builder = FheBuilder(modulus, degree)
    fhe_builder.generate_keys()
    print(f"modulus: {modulus}, degree: {degree}")
    print()

    gpu_tensor = gen_unirand_int_grain(0, modulus - 1, num_elem)
    gpu_tensor_rev = gen_unirand_int_grain(0, modulus - 1, num_elem)
    with NamedTimerInstance(f"build_plain_from_torch with num_elem: {num_elem}"):
        plain = fhe_builder.build_plain_from_torch(gpu_tensor)
    with NamedTimerInstance(f"plain.export_as_torch_gpu() with num_elem: {num_elem}"):
        tensor_from_plain = plain.export_as_torch()
    print("Fhe Plain encrypt and decrypt: ", end="")
    assert(compare_expected_actual(gpu_tensor, tensor_from_plain, verbose=True, get_relative=True).RelAvgDiff == 0)
    print()

    with NamedTimerInstance(f"fhe_builder.build_enc with num_elem: {num_elem}"):
        cipher = fhe_builder.build_enc(num_elem)
    with NamedTimerInstance(f"cipher.encrypt_additive with num_elem: {num_elem}"):
        cipher.encrypt_additive(gpu_tensor)
    with NamedTimerInstance(f"fhe_builder.decrypt_to_torch with num_elem: {num_elem}"):
        tensor_from_cipher = fhe_builder.decrypt_to_torch(cipher)
    print("Fhe Enc encrypt and decrypt: ", end="")
    assert(compare_expected_actual(gpu_tensor, tensor_from_cipher, verbose=True, get_relative=True).RelAvgDiff == 0)
    print()

    pt = fhe_builder.build_plain_from_torch(gpu_tensor)
    ct = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    with NamedTimerInstance(f"EP add with num_elem: {num_elem}"):
        ct += pt
    expected = pmod(gpu_tensor + gpu_tensor_rev, modulus)
    actual = fhe_builder.decrypt_to_torch(ct)
    assert(compare_expected_actual(expected, actual, verbose=True, get_relative=True).RelAvgDiff == 0)
    print()

    ct1 = fhe_builder.build_enc_from_torch(gpu_tensor)
    ct2 = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    with NamedTimerInstance(f"EE add with num_elem: {num_elem}"):
        ct1 += ct2
    expected = pmod(gpu_tensor + gpu_tensor_rev, modulus)
    actual = fhe_builder.decrypt_to_torch(ct1)
    assert(compare_expected_actual(expected, actual, verbose=True, get_relative=True).RelAvgDiff == 0)
    print()

    pt = fhe_builder.build_plain_from_torch(gpu_tensor)
    ct = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    with NamedTimerInstance(f"EP sub with num_elem: {num_elem}"):
        ct -= pt
    expected = pmod(gpu_tensor_rev - gpu_tensor, modulus)
    actual = fhe_builder.decrypt_to_torch(ct)
    assert(compare_expected_actual(expected, actual, verbose=True, get_relative=True).RelAvgDiff == 0)
    print()

    ct1 = fhe_builder.build_enc_from_torch(gpu_tensor)
    ct2 = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    with NamedTimerInstance(f"EE add with num_elem: {num_elem}"):
        ct1 -= ct2
    expected = pmod(gpu_tensor - gpu_tensor_rev, modulus)
    actual = fhe_builder.decrypt_to_torch(ct1)
    assert(compare_expected_actual(expected, actual, verbose=True, get_relative=True).RelAvgDiff == 0)
    print()

    pt = fhe_builder.build_plain_from_torch(gpu_tensor)
    ct = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    fhe_builder.noise_budget(ct, name="before ep")
    with NamedTimerInstance(f"EP Mult with num_elem: {num_elem}"):
        ct *= pt
    fhe_builder.noise_budget(ct, name="after ep")
    expected = pmod(gpu_tensor.double() * gpu_tensor_rev.double(), modulus)
    actual = fhe_builder.decrypt_to_torch(ct)
    assert(compare_expected_actual(expected, actual, verbose=True, get_relative=True).RelAvgDiff == 0)
    print()

    ct1 = fhe_builder.build_enc_from_torch(gpu_tensor)
    ct2 = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    fhe_builder.noise_budget(ct1, name="before ee")
    with NamedTimerInstance(f"EE mult with num_elem: {num_elem}"):
        ct1 *= ct2
    fhe_builder.noise_budget(ct1, name="after ee")
    expected = pmod(gpu_tensor.double() * gpu_tensor_rev.double(), modulus).float()
    actual = fhe_builder.decrypt_to_torch(ct1)
    assert(compare_expected_actual(expected, actual, verbose=True, get_relative=True).RelAvgDiff == 0)
    print()

    ct1 = fhe_builder.build_enc_from_torch(gpu_tensor)
    ct2 = fhe_builder.build_enc_from_torch(gpu_tensor_rev)
    fhe_builder.noise_budget(ct1, name="before ee")
    with NamedTimerInstance(f"EE mult with num_elem: {num_elem}"):
        ct1 *= ct2
    fhe_builder.noise_budget(ct1, name="after ee")
    expected = pmod(gpu_tensor.double() * gpu_tensor_rev.double(), modulus)
    actual = fhe_builder.decrypt_to_torch(ct1)
    assert(compare_expected_actual(expected, actual, verbose=True, get_relative=True).RelAvgDiff == 0)
    print()

    ct = fhe_builder.build_enc_from_torch(gpu_tensor)
    fhe_builder.noise_budget(ct, name="before neg")
    with NamedTimerInstance(f"neg E with num_elem: {num_elem}"):
        ct = -ct
    fhe_builder.noise_budget(ct, name="after neg")
    expected = pmod(-gpu_tensor, modulus)
    actual = fhe_builder.decrypt_to_torch(ct)
    assert(compare_expected_actual(expected, actual, verbose=True, get_relative=True).RelAvgDiff == 0)
    print()

    print("Test for FheBuilder: Finish")
Example #24
def test_avgpool2x2_dgk(input_sid,
                        master_address,
                        master_port,
                        num_elem=2**17):
    test_name = "Avgpool2x2"
    print(f"\nTest for {test_name}: Start")
    data_bit = 20
    work_bit = 20
    data_range = 2**data_bit
    q_16 = 12289
    # q_23 = 786433
    q_23 = 7340033
    img_hw = 4
    print(f"Number of element: {num_elem}")

    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_16.generate_keys()
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    fhe_builder_23.generate_keys()

    img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                num_elem)
    img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    img_c = pmod(img - img_s, q_23)

    def check_correctness_online(img, max_s, max_c):
        img = torch.where(img < q_23 // 2, img, img - q_23).cuda()
        pool = torch.nn.AvgPool2d(2)
        expected = pool(img.double().reshape([-1, img_hw, img_hw
                                              ])).reshape(-1) * 4
        expected = pmod(expected, q_23)
        actual = pmod(max_s + max_c, q_23)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + "_online",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Server(num_elem, q_23, q_16, work_bit, data_bit,
                                img_hw, fhe_builder_16, fhe_builder_23,
                                "avgpool")

        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(img_s)
            torch_sync()
        traffic_record.reset("server-online")

        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_res_c")
        blob_out_c.prepare_recv()
        torch_sync()
        out_c = blob_out_c.get_recv()
        check_correctness_online(img, prot.out_s, out_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Client(num_elem, q_23, q_16, work_bit, data_bit,
                                img_hw, fhe_builder_16, fhe_builder_23,
                                "avgpool")

        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(img_c)
            torch_sync()
        traffic_record.reset("client-online")

        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_res_c")
        torch_sync()
        blob_out_c.send(prot.out_c)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()

    print(f"\nTest for {test_name}: End")