Example #1
    def test_server():
        init_communicate(Config.server_rank)
        fhe_builder = FheBuilder(modulus, degree)
        comm_fhe_builder = CommFheBuilder(Config.server_rank, fhe_builder,
                                          "fhe_builder")

        comm_fhe_builder.recv_public_key()
        comm_fhe_builder.wait_and_build_public_key()

        blob_ori = BlobFheEnc(num_elem, comm_fhe_builder, ori_tag)
        blob_ori.send_from_torch(ori)

        enc_pk = fhe_builder.build_enc_from_torch(expected_tensor_pk)
        comm_fhe_builder.send_enc(enc_pk, pk_tag)

        comm_fhe_builder.recv_secret_key()
        comm_fhe_builder.wait_and_build_secret_key()

        enc_sk = fhe_builder.build_enc(num_elem)
        comm_fhe_builder.recv_enc(enc_sk, sk_tag)
        comm_fhe_builder.wait_enc(sk_tag)
        actual_tensor_sk = fhe_builder.decrypt_to_torch(enc_sk)
        compare_expected_actual(expected_tensor_sk,
                                actual_tensor_sk,
                                get_relative=True,
                                name="Recovering to test sk")

        dist.destroy_process_group()
Example #2
def test_fhe_refresh():
    print()
    print("Test: FHE refresh: Start")
    modulus, degree = 12289, 2048
    num_batch = 12
    num_elem = 2 ** 15
    fhe_builder = FheBuilder(modulus, degree)
    fhe_builder.generate_keys()

    def test_batch_server():
        init_communicate(Config.server_rank)
        shape = [num_batch, num_elem]
        tensor = gen_unirand_int_grain(0, modulus, shape)
        refresher = EncRefresherServer(shape, fhe_builder, "test_batch_refresher")
        enc = [fhe_builder.build_enc_from_torch(tensor[i]) for i in range(num_batch)]
        refreshed = refresher.request(enc)
        tensor_refreshed = fhe_builder.decrypt_to_torch(refreshed)
        compare_expected_actual(tensor, tensor_refreshed, get_relative=True, name="batch refresh")

        end_communicate()

    def test_batch_client():
        init_communicate(Config.client_rank)
        shape = [num_batch, num_elem]
        refresher = EncRefresherClient(shape, fhe_builder, "test_batch_refresher")
        refresher.prepare_recv()
        refresher.response()

        end_communicate()

    marshal_funcs([test_batch_server, test_batch_client])

    def test_1d_server():
        init_communicate(Config.server_rank)
        shape = num_elem
        tensor = gen_unirand_int_grain(0, modulus, shape)
        refresher = EncRefresherServer(shape, fhe_builder, "test_1d_refresher")
        enc = fhe_builder.build_enc_from_torch(tensor)
        refreshed = refresher.request(enc)
        tensor_refreshed = fhe_builder.decrypt_to_torch(refreshed)
        compare_expected_actual(tensor, tensor_refreshed, get_relative=True, name="1d_refresh")

        end_communicate()

    def test_1d_client():
        init_communicate(Config.client_rank)
        shape = num_elem
        refresher = EncRefresherClient(shape, fhe_builder, "test_1d_refresher")
        refresher.prepare_recv()
        refresher.response()

        end_communicate()

    marshal_funcs([test_1d_server, test_1d_client])

    print()
    print("Test: FHE refresh: End")
Example #3
def test_conv2d_fhe_ntt_single_thread():
    modulus = 786433
    img_hw = 16
    filter_hw = 3
    padding = 1
    num_input_channel = 64
    num_output_channel = 128
    data_bit = 17
    data_range = 2**data_bit

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)
    # x = torch.arange(get_prod(x_shape)).reshape(x_shape)
    # w = torch.arange(get_prod(w_shape)).reshape(w_shape)

    warming_up_cuda()
    prot = Conv2dFheNttSingleThread(modulus, data_range, img_hw, filter_hw,
                                    num_input_channel, num_output_channel,
                                    fhe_builder, "test_conv2d_fhe_ntt",
                                    padding)

    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_conv2d(w)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    # actual = pmod(y, modulus)
    actual = pmod(y - prot.output_mask_s, modulus)
    # print("actual\n", actual)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        expected = F.conv2d(torch_x, torch_w, padding=padding)
        expected = pmod(expected.reshape(prot.output_shape), modulus)
    # print("expected", expected)
    compare_expected_actual(expected,
                            actual,
                            name="test_conv2d_fhe_ntt_single_thread",
                            get_relative=True)
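
Note on the check above: the decoded output y still carries the random mask output_mask_s, so the test subtracts it before reducing modulo the plaintext modulus, and the reference is a plain F.conv2d in double precision followed by the same reduction. A minimal self-contained sketch of that reference computation, assuming pmod(x, q) is the positive reduction x % q (shapes here are small and purely illustrative):

import torch
import torch.nn.functional as F

modulus = 786433
x = torch.randint(0, modulus, (1, 2, 4, 4)).double()      # stand-in for the reshaped input
w = torch.randint(-8, 8, (3, 2, 3, 3)).double()           # stand-in for the weight tensor
expected = torch.remainder(F.conv2d(x, w, padding=1), modulus)  # i.e. pmod(F.conv2d(...), modulus)
print(expected.shape)  # torch.Size([1, 3, 4, 4])
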
Example #4
def test_fc_fhe_single_thread():
    test_name = "test_fc_fhe_single_thread"
    print(f"\nTest for {test_name}: Start")
    modulus = Config.q_23
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, get_prod(x_shape)).reshape(x_shape)
    w = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                              get_prod(w_shape)).reshape(w_shape)
    x = gen_unirand_int_grain(0, modulus, get_prod(x_shape)).reshape(x_shape)
    # w = gen_unirand_int_grain(0, modulus, get_prod(w_shape)).reshape(w_shape)
    # x = torch.arange(get_prod(x_shape)).reshape(x_shape)
    # w = torch.arange(get_prod(w_shape)).reshape(w_shape)

    warming_up_cuda()
    prot = FcFheSingleThread(modulus, data_range, num_input_unit,
                             num_output_unit, fhe_builder, test_name)

    print("prot.num_input_batch", prot.num_input_batch)
    print("prot.num_output_batch", prot.num_output_batch)
    print("prot.num_elem_in_piece", prot.num_elem_in_piece)

    with NamedTimerInstance("encoding x"):
        prot.encode_input_to_fhe_batch(x)
    with NamedTimerInstance("conv2d with w"):
        prot.compute_with_weight(w)
    with NamedTimerInstance("conv2d masking output"):
        prot.masking_output()
    with NamedTimerInstance("decoding output"):
        y = prot.decode_output_from_fhe_batch()
    actual = pmod(y - prot.output_mask_s, modulus)
    # print("actual\n", actual)

    torch_x = x.reshape([1] + x_shape).double()
    torch_w = w.reshape(w_shape).double()
    with NamedTimerInstance("Conv2d Torch"):
        expected = torch.mm(torch_x, torch_w.t())
        expected = pmod(expected.reshape(prot.output_shape), modulus)
    compare_expected_actual(expected,
                            actual,
                            name=test_name,
                            get_relative=True)
    print(f"\nTest for {test_name}: End")
Example #5
    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    num_elem)
        img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        img_c = pmod(img - img_s, q_23)

        prot = Maxpool2x2DgkServer(num_elem, q_23, q_16, work_bit, data_bit,
                                   img_hw, fhe_builder_16, fhe_builder_23,
                                   "max_dgk")

        blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                               "recon_max_c")
        torch_sync()
        blob_img_c.send(img_c)

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(img_s)
            torch_sync()
        traffic_record.reset("server-online")

        blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_max_c")
        blob_max_c.prepare_recv()
        torch_sync()
        max_c = blob_max_c.get_recv()
        check_correctness_online(img, prot.max_s, max_c)

        end_communicate()
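
The img_s / img_c pair above is a standard additive sharing over Z_q: the client share is pmod(img - img_s, q_23), so the two shares reconstruct img modulo q_23. A self-contained sketch of that invariant, assuming pmod(x, q) is the positive reduction x % q:

import torch

q_23 = 7340033   # value used in the other examples; assumed here
num_elem = 16

def pmod(x, q):
    return torch.remainder(x, q)

img = torch.randint(-2 ** 19, 2 ** 19, (num_elem,))
img_s = torch.randint(0, q_23, (num_elem,))   # server share
img_c = pmod(img - img_s, q_23)               # client share
assert torch.equal(pmod(img_s + img_c, q_23), pmod(img, q_23))
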
Example #6
    def __init__(self, name, data_bit=5, work_bit=18,
                 n_16=Config.n_16, n_23=Config.n_23, q_16=Config.q_16, q_23=Config.q_23,
                 input_hw=32, input_channel=3):

        super().__init__(name)
        self.data_bit = data_bit
        self.work_bit = work_bit
        self.q_16 = q_16
        self.q_23 = q_23
        self.input_hw = input_hw
        self.input_channel = input_channel

        self.data_range = 2 ** data_bit
        self.fhe_builder_16 = FheBuilder(q_16, n_16)
        self.fhe_builder_23 = FheBuilder(q_23, n_23)

        self.context = SecureLayerContext(work_bit, data_bit, q_16, q_23, self.fhe_builder_16, self.fhe_builder_23,
                                          self.sub_name("context"))

        self.secure_nn_core = SecureNeuralNetwork(self.sub_name("secure_nn"))

        self.layer_dict = {}
Example #7
    def test_server():
        rank = Config.server_rank
        init_communicate(Config.server_rank, master_address=master_address, master_port=master_port)
        traffic_record = TrafficRecord()
        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        prot = ReluDgkServer(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "relu_dgk")

        blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
        blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
        torch_sync()
        blob_a_s.prepare_recv()
        a_s = blob_a_s.get_recv()

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(a_s)
            torch_sync()
        traffic_record.reset("server-online")

        blob_max_s.send(prot.max_s)
        torch.cuda.empty_cache()
        end_communicate()
Example #8
    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        prot = Maxpool2x2DgkClient(num_elem, q_23, q_16, work_bit, data_bit,
                                   img_hw, fhe_builder_16, fhe_builder_23,
                                   "max_dgk")
        blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                               "recon_max_c")
        blob_img_c.prepare_recv()
        torch_sync()
        img_c = blob_img_c.get_recv()

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(img_c)
            torch_sync()
        traffic_record.reset("client-online")

        blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_max_c")
        torch_sync()
        blob_max_c.send(prot.max_c)
        end_communicate()
Example #9
    def test_client():
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address, master_port=master_port)
        traffic_record = TrafficRecord()
        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        a = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
        a_c = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        a_s = pmod(a - a_c, q_23)

        prot = ReluDgkClient(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "relu_dgk")

        blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
        blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
        torch_sync()
        blob_a_s.send(a_s)
        blob_max_s.prepare_recv()

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(a_c)
            torch_sync()
        traffic_record.reset("client-online")

        max_s = blob_max_s.get_recv()
        check_correctness_online(a, max_s, prot.max_c)

        torch.cuda.empty_cache()
        end_communicate()
Example #10
class SecureNnFramework(NamedBase, ContextRankBase):
    class_name = "SecureNnFramework"
    layers: List[SecureLayerBase]
    layer_dict: Dict[str, SecureLayerBase]
    input_img: torch.Tensor
    output_res: torch.Tensor
    comm_base: CommBase

    # def __init__(self, name, data_bit=5, work_bit=18, q_16=12289, q_23=7340033, input_hw=32, input_channel=3):
    def __init__(self, name, data_bit=5, work_bit=18,
                 n_16=Config.n_16, n_23=Config.n_23, q_16=Config.q_16, q_23=Config.q_23,
                 input_hw=32, input_channel=3):

        super().__init__(name)
        self.data_bit = data_bit
        self.work_bit = work_bit
        self.q_16 = q_16
        self.q_23 = q_23
        self.input_hw = input_hw
        self.input_channel = input_channel

        self.data_range = 2 ** data_bit
        self.fhe_builder_16 = FheBuilder(q_16, n_16)
        self.fhe_builder_23 = FheBuilder(q_23, n_23)

        self.context = SecureLayerContext(work_bit, data_bit, q_16, q_23, self.fhe_builder_16, self.fhe_builder_23,
                                          self.sub_name("context"))

        self.secure_nn_core = SecureNeuralNetwork(self.sub_name("secure_nn"))

        self.layer_dict = {}

    def generate_random_data(self, shape):
        return gen_unirand_int_grain(-self.data_range//2 + 1, self.data_range//2, get_prod(shape)).reshape(shape)

    def load_layers(self, layers):
        self.layers = layers
        for layer in layers:
            self.layer_dict[layer.name] = layer

        self.secure_nn_core.load_layers(layers)
        self.secure_nn_core.load_context(self.context)
        return self

    def get_input_shape(self):
        return self.secure_nn_core.input_layer.get_output_shape()

    def get_output_shape(self):
        return self.secure_nn_core.output_layer.get_output_shape()

    def set_rank(self, rank):
        self.rank = rank
        self.context.set_rank(rank)
        return self

    def init_communication(self, **kwargs):
        init_communicate(self.rank, **kwargs)
        self.comm_base = CommBase(self.rank, self.sub_name("comm_base"))
        return self

    def fhe_builder_sync(self):
        comm_fhe_16 = CommFheBuilder(self.rank, self.fhe_builder_16, self.sub_name("fhe_builder_16"))
        comm_fhe_23 = CommFheBuilder(self.rank, self.fhe_builder_23, self.sub_name("fhe_builder_23"))
        torch_sync()
        if self.is_server():
            comm_fhe_16.recv_public_key()
            comm_fhe_23.recv_public_key()
            comm_fhe_16.wait_and_build_public_key()
            comm_fhe_23.wait_and_build_public_key()
        elif self.is_client():
            self.fhe_builder_16.generate_keys()
            self.fhe_builder_23.generate_keys()
            comm_fhe_16.send_public_key()
            comm_fhe_23.send_public_key()
        torch_sync()
        return self

    def fill_random_weight(self):
        assert(self.is_server())
        for layer in self.layers:
            if not layer.has_weight: continue
            layer.load_weight(self.generate_random_data(layer.weight_shape))
        return self

    def fill_input(self, input_img):
        assert(self.is_client())
        self.input_img = input_img
        self.secure_nn_core.feed_input(input_img)
        return self

    def fill_random_input(self):
        assert(self.is_client())
        self.fill_input(self.generate_random_data(self.get_input_shape()))
        return self

    def offline(self):
        self.secure_nn_core.offline()
        return self

    def online(self):
        self.secure_nn_core.online()
        return self

    def check_layers(self, get_plain_net_func, all_layer_pair):
        secure_input_layer = self.layers[0]
        secure_input_layer.reconstructed_to_server(self.comm_base, self.q_23)

        for secure_layer_name, plain_layer_name in all_layer_pair:
            secure_layer = self.layer_dict[secure_layer_name]
            secure_layer.reconstructed_to_server(self.comm_base, self.q_23)

        if self.is_server():
            plain_net = get_plain_net_func()
            plain_output_dict = {}

            def hook_generator(name):
                def hook(module, input, output):
                    plain_output_dict[name] = output.data.detach().clone()
                return hook

            for secure_layer_name, plain_layer_name in all_layer_pair:
                # plain_output_dict[plain_layer_name] = torch.zeros()

                plain_layer = getattr(plain_net, plain_layer_name)
                plain_layer.register_forward_hook(hook_generator(plain_layer_name))

            input_img = secure_input_layer.get_reconstructed_output()
            input_img = input_img.reshape([1] + list(input_img.shape)).cuda()

            meta_rg = MetaTruncRandomGenerator()
            meta_rg.reset_rg("plain")
            plain_output = plain_net(input_img)

            for secure_layer_name, plain_layer_name in all_layer_pair:
                print("secure_layer_name, plain_layer_name: %s, %s"%(secure_layer_name, plain_layer_name))
                plain_output = plain_output_dict[plain_layer_name]
                secure_layer = self.layer_dict[secure_layer_name]
                secure_output = secure_layer.get_reconstructed_output()
                compare_expected_actual(plain_output, secure_output,
                                        name=f"compare secure-plain: {secure_layer_name}, {plain_layer_name}", get_relative=True)
                # print("secure", secure_output.shape, secure_output)
                # print("plain", plain_output.shape, plain_output)

        return self

    def check_correctness(self, verify_func, is_argmax=False, truth=None):
        blob_input_img = BlobTorch(self.get_input_shape(), torch.float, self.comm_base, "input_img")
        blob_actual_output = BlobTorch(self.get_output_shape(), torch.float, self.comm_base, "actual_output")
        blob_truth = BlobTorch(1, torch.float, self.comm_base, "truth")

        if self.is_server():
            blob_input_img.prepare_recv()
            blob_actual_output.prepare_recv()
            blob_truth.prepare_recv()
            torch_sync()
            input_img = blob_input_img.get_recv()
            actual_output = blob_actual_output.get_recv()
            truth = int(blob_truth.get_recv().item())
            verify_func(self, input_img, actual_output, self.q_23)

            actual_output = nmod(actual_output, self.q_23).cuda()
            _, actual_max = torch.max(actual_output, 0)
            print(f"truth: {truth}, actual: {actual_max}, MatchTruth: {truth == actual_max}")

        if self.is_client():
            torch_sync()
            actual_output = self.secure_nn_core.get_argmax_output() if is_argmax else self.secure_nn_core.get_output()
            blob_input_img.send(self.input_img)
            blob_actual_output.send(actual_output)
            blob_truth.send(torch.tensor(truth))

        return self

    def end_communication(self):
        end_communicate()
        return self
Example #11
def test_shares_mult():
    print("\nTest for Shares Mult: Start")
    modulus = Config.q_23
    num_elem = 2**17
    print(f"Number of element: {num_elem}")

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    a = gen_unirand_int_grain(0, modulus - 1, num_elem)
    a_s = gen_unirand_int_grain(0, modulus - 1, num_elem)
    a_c = pmod(a - a_s, modulus)
    b = gen_unirand_int_grain(0, modulus - 1, num_elem)
    b_s = gen_unirand_int_grain(0, modulus - 1, num_elem)
    b_c = pmod(b - b_s, modulus)

    def check_correctness_offline(u, v, z_s, z_c):
        expected = pmod(
            u.double().to(Config.device) * v.double().to(Config.device),
            modulus)
        actual = pmod(z_s + z_c, modulus)
        compare_expected_actual(expected,
                                actual,
                                name="shares_mult_offline",
                                get_relative=True)

    def check_correctness_online(a, b, c_s, c_c):
        expected = pmod(
            a.double().to(Config.device) * b.double().to(Config.device),
            modulus)
        actual = pmod(c_s + c_c, modulus)
        compare_expected_actual(expected,
                                actual,
                                name="shares_mult_online",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = SharesMultServer(num_elem, modulus, fhe_builder,
                                "test_shares_mult")
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(a_s, b_s)
            torch_sync()

        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_z_c")
        blob_u_c.prepare_recv()
        blob_v_c.prepare_recv()
        blob_z_c.prepare_recv()
        torch_sync()
        u_c = blob_u_c.get_recv()
        v_c = blob_v_c.get_recv()
        z_c = blob_z_c.get_recv()
        u = pmod(prot.u_s + u_c, modulus)
        v = pmod(prot.v_s + v_c, modulus)
        check_correctness_offline(u, v, prot.z_s, z_c)

        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        blob_c_c.prepare_recv()
        torch_sync()
        c_c = blob_c_c.get_recv()
        check_correctness_online(a, b, prot.c_s, c_c)
        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = SharesMultClient(num_elem, modulus, fhe_builder,
                                "test_shares_mult")
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online(a_c, b_c)
            torch_sync()

        blob_u_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_u_c")
        blob_v_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_v_c")
        blob_z_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                             "recon_z_c")
        torch_sync()
        blob_u_c.send(prot.u_c)
        blob_v_c.send(prot.v_c)
        blob_z_c.send(prot.z_c)
        blob_c_c = BlobTorch(num_elem, torch.float, prot.comm_base, "c_c")
        torch_sync()
        blob_c_c.send(prot.c_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print("\nTest for Shares Mult: End")
Example #12
def test_conv2d_secure_comm(input_sid,
                            master_address,
                            master_port,
                            setting=(16, 3, 128, 128)):
    test_name = "Conv2d Secure Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    padding = 1
    img_hw, filter_hw, num_input_channel, num_output_channel = setting
    data_bit = 17
    data_range = 2**data_bit
    # n_23 = 8192
    n_23 = 16384
    print(f"Setting covn2d: img_hw: {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]
    b_shape = [num_output_channel]

    fhe_builder = FheBuilder(modulus, n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                 get_prod(b_shape)).reshape(b_shape)
    input = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  get_prod(x_shape)).reshape(x_shape)
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(input - input_c, modulus)

    def check_correctness_online(x, w, b, output_s, output_c):
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        torch_b = b.cuda().double() if b is not None else None
        expected = F.conv2d(torch_x, torch_w, padding=padding, bias=torch_b)
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " online",
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureServer(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm",
                                  padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(input, weight, bias, prot.output_s, output_c)

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureClient(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm",
                                  padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()

    print(f"\nTest for {test_name}: End")
Example #13
def test_conv2d_fhe_ntt_comm():
    test_name = "Conv2d Fhe NTT Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    img_hw = 2
    filter_hw = 3
    padding = 1
    num_input_channel = 512
    num_output_channel = 512
    data_bit = 17
    data_range = 2**data_bit
    print(f"Setting: img_hw {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    input_mask = gen_unirand_int_grain(0, modulus - 1,
                                       get_prod(x_shape)).reshape(x_shape)

    # input_mask = torch.arange(get_prod(x_shape)).reshape(x_shape)

    def check_correctness_offline(x, w, output_mask, output_c):
        actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = F.conv2d(torch_x, torch_w, padding=padding)
        expected = pmod(expected.reshape(output_mask.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " offline",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = Conv2dFheNttServer(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_fhe_ntt_comm",
                                  padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s,
                                  output_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = Conv2dFheNttClient(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_fhe_ntt_comm",
                                  padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Example #14
    def test_server():
        rank = Config.server_rank
        init_communicate(Config.server_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        dgk = DgkBitServer(num_elem, q_23, q_16, work_bit, data_bit,
                           fhe_builder_16, fhe_builder_23, "DgkBitTest")

        x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
        y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
        y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                   "y_sub_x_s")
        x_blob.prepare_recv()
        y_blob.prepare_recv()
        y_sub_x_s_blob.prepare_recv()
        torch_sync()
        x = x_blob.get_recv()
        y = y_blob.get_recv()
        y_sub_x_s = y_sub_x_s_blob.get_recv()

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            dgk.offline()
        # y_sub_x_s = pmod(y_s.to(Config.device) - x_s.to(Config.device), q_23)
        torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            dgk.online(y_sub_x_s)
        traffic_record.reset("server-online")

        dgk_x_leq_y_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                       "dgk_x_leq_y_c")
        correct_mod_div_work_c_blob = BlobTorch(num_elem, torch.float,
                                                dgk.comm_base,
                                                "correct_mod_div_work_c")
        z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
        dgk_x_leq_y_c_blob.prepare_recv()
        correct_mod_div_work_c_blob.prepare_recv()
        z_blob.prepare_recv()
        torch_sync()
        dgk_x_leq_y_c = dgk_x_leq_y_c_blob.get_recv()
        correct_mod_div_work_c = correct_mod_div_work_c_blob.get_recv()
        z = z_blob.get_recv()
        check_correctness(x, y, dgk.dgk_x_leq_y_s, dgk_x_leq_y_c)
        check_correctness_mod_div(dgk.r, z, dgk.correct_mod_div_work_s,
                                  correct_mod_div_work_c)
        end_communicate()
Example #15
def test_fc_secure_comm(input_sid,
                        master_address,
                        master_port,
                        setting=(512, 512)):
    test_name = "test_fc_secure_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit, num_output_unit = setting
    data_bit = 17

    print(f"Setting fc: "
          f"num_input_unit: {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]
    b_shape = [num_output_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                 get_prod(b_shape)).reshape(b_shape)
    input = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  get_prod(x_shape)).reshape(x_shape)
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(input - input_c, modulus)

    def check_correctness_online(x, w, output_s, output_c):
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = torch.mm(torch_x, torch_w.t())
        if bias is not None:
            expected += bias.cuda().double()
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " online",
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = FcSecureServer(modulus, data_range, num_input_unit,
                              num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(input, weight, prot.output_s, output_c)

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = FcSecureClient(modulus, data_range, num_input_unit,
                              num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Example #16
def test_fc_fhe_comm():
    test_name = "test_fc_fhe_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit = 512
    num_output_unit = 512
    data_bit = 17

    print(f"Setting: num_input_unit {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    input_mask = gen_unirand_int_grain(0, modulus - 1,
                                       get_prod(x_shape)).reshape(x_shape)

    # input_mask = torch.arange(get_prod(x_shape)).reshape(x_shape)

    def check_correctness_offline(x, w, output_mask, output_c):
        actual = pmod(output_c.cuda() - output_mask.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = torch.mm(torch_x, torch_w.t())
        expected = pmod(expected.reshape(output_mask.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " offline",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank)
        prot = FcFheServer(modulus, data_range, num_input_unit,
                           num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_offline(input_mask, weight, prot.output_mask_s,
                                  output_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank)
        prot = FcFheClient(modulus, data_range, num_input_unit,
                           num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_mask)
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Example #17
def test_secure_nn():
    test_name = "test_secure_nn"
    print(f"\nTest for {test_name}: Start")
    data_bit = 5
    work_bit = 17
    data_range = 2**data_bit
    q_16 = 12289
    # q_23 = 786433
    q_23 = 7340033
    # q_23 = 8273921
    input_img_hw = 16
    input_channel = 3
    pow_to_div = 2

    fhe_builder_16 = FheBuilder(q_16, 2048)
    fhe_builder_23 = FheBuilder(q_23, 8192)

    input_shape = [input_channel, input_img_hw, input_img_hw]

    context = SecureLayerContext(work_bit, data_bit, q_16, q_23,
                                 fhe_builder_16, fhe_builder_23,
                                 test_name + "_context")

    input_layer = InputSecureLayer(input_shape, "input_layer")
    conv1 = Conv2dSecureLayer(3, 5, 3, "conv1", padding=1)
    relu1 = ReluSecureLayer("relu1")
    trunc1 = TruncSecureLayer("trunc1")
    pool1 = Maxpool2x2SecureLayer("pool1")
    conv2 = Conv2dSecureLayer(5, 10, 3, "conv2", padding=1)
    flatten = FlattenSecureLayer("flatten")
    fc1 = FcSecureLayer(32, "fc1")
    output_layer = OutputSecureLayer("output_layer")

    secure_nn = SecureNeuralNetwork("secure_nn")
    secure_nn.load_layers([
        input_layer, conv1, pool1, relu1, trunc1, conv2, flatten, fc1,
        output_layer
    ])
    # secure_nn.load_layers([input_layer, relu1, trunc1, output_layer])
    secure_nn.load_context(context)

    def generate_random_data(shape):
        return gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                     get_prod(shape)).reshape(shape)

    conv1_w = generate_random_data(conv1.weight_shape)
    conv2_w = generate_random_data(conv2.weight_shape)
    fc1_w = generate_random_data(fc1.weight_shape)

    def check_correctness(input_img, output):
        torch_pool1 = torch.nn.MaxPool2d(2)

        x = input_img.to(Config.device).double()
        x = x.reshape([1] + list(x.shape))
        x = pmod(F.conv2d(x, conv1_w.to(Config.device).double(), padding=1),
                 q_23)
        x = pmod(F.relu(nmod(x, q_23)), q_23)
        x = pmod(torch_pool1(nmod(x, q_23)), q_23)
        x = pmod(x // (2**pow_to_div), q_23)
        x = pmod(F.conv2d(x, conv2_w.to(Config.device).double(), padding=1),
                 q_23)
        x = x.view(-1)
        x = pmod(
            torch.mm(x.view(1, -1),
                     fc1_w.to(Config.device).double().t()).view(-1), q_23)

        expected = x
        actual = pmod(output, q_23)
        if len(expected.shape) == 4 and expected.shape[0] == 1:
            expected = expected.reshape(expected.shape[1:])
        compare_expected_actual(expected,
                                actual,
                                name=test_name,
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank)
        context.set_rank(rank)

        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        conv1.load_weight(conv1_w)
        conv2.load_weight(conv2_w)
        fc1.load_weight(fc1_w)
        trunc1.set_div_to_pow(pow_to_div)

        with NamedTimerInstance("Server Offline"):
            secure_nn.offline()
            torch_sync()

        with NamedTimerInstance("Server Online"):
            secure_nn.online()
            torch_sync()

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank)
        context.set_rank(rank)

        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        input_img = generate_random_data(input_shape)
        trunc1.set_div_to_pow(pow_to_div)
        secure_nn.feed_input(input_img)

        with NamedTimerInstance("Client Offline"):
            secure_nn.offline()
            torch_sync()

        with NamedTimerInstance("Client Online"):
            secure_nn.online()
            torch_sync()

        check_correctness(input_img, secure_nn.get_output())
        end_communicate()

    marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")
Example #18
def test_avgpool2x2_dgk(input_sid,
                        master_address,
                        master_port,
                        num_elem=2**17):
    test_name = "Avgpool2x2"
    print(f"\nTest for {test_name}: Start")
    data_bit = 20
    work_bit = 20
    data_range = 2**data_bit
    q_16 = 12289
    # q_23 = 786433
    q_23 = 7340033
    img_hw = 4
    print(f"Number of element: {num_elem}")

    fhe_builder_16 = FheBuilder(q_16, Config.n_16)
    fhe_builder_16.generate_keys()
    fhe_builder_23 = FheBuilder(q_23, Config.n_23)
    fhe_builder_23.generate_keys()

    img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                num_elem)
    img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
    img_c = pmod(img - img_s, q_23)

    def check_correctness_online(img, max_s, max_c):
        img = torch.where(img < q_23 // 2, img, img - q_23).cuda()
        pool = torch.nn.AvgPool2d(2)
        expected = pool(img.double().reshape([-1, img_hw, img_hw])).reshape(-1) * 4
        expected = pmod(expected, q_23)
        actual = pmod(max_s + max_c, q_23)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + "_online",
                                get_relative=True)

    def test_server():
        init_communicate(Config.server_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Server(num_elem, q_23, q_16, work_bit, data_bit,
                                img_hw, fhe_builder_16, fhe_builder_23,
                                "avgpool")

        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(img_s)
            torch_sync()
        traffic_record.reset("server-online")

        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_res_c")
        blob_out_c.prepare_recv()
        torch_sync()
        out_c = blob_out_c.get_recv()
        check_correctness_online(img, prot.out_s, out_c)

        end_communicate()

    def test_client():
        init_communicate(Config.client_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()
        prot = Avgpool2x2Client(num_elem, q_23, q_16, work_bit, data_bit,
                                img_hw, fhe_builder_16, fhe_builder_23,
                                "avgpool")

        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(img_c)
            torch_sync()
        traffic_record.reset("client-online")

        blob_out_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_res_c")
        torch_sync()
        blob_out_c.send(prot.out_c)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()

    print(f"\nTest for {test_name}: End")
Example #19
def test_rotation():
    modulus = Config.q_23
    degree = Config.n_23
    fhe_builder = FheBuilder(modulus, degree)
    fhe_builder.generate_keys()
    fhe_builder.generate_galois_keys()

    # x = gen_unirand_int_grain(0, modulus-1, degree)
    x = torch.arange(degree)
    with NamedTimerInstance("Fhe Encrypt"):
        enc = fhe_builder.build_enc_from_torch(x)
    enc_less = fhe_builder.build_enc_from_torch(x)
    plain = fhe_builder.build_plain_from_torch(x)

    fhe_builder.noise_budget(enc, "before mul")
    with NamedTimerInstance("ep mult"):
        enc *= plain
        enc_less *= plain
    fhe_builder.noise_budget(enc, "after mul")
    with NamedTimerInstance("ee add"):
        for i in range(128):
            enc += enc_less
    fhe_builder.noise_budget(enc, "after add")
    with NamedTimerInstance("rot"):
        fhe_builder.evaluator.rotate_rows_inplace(enc.cts[0], 64,
                                                  fhe_builder.galois_keys)
    fhe_builder.noise_budget(enc, "after rot")
    print(fhe_builder.decrypt_to_torch(enc))
Example #20
    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, n_16)
        fhe_builder_23 = FheBuilder(q_23, n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        dgk = DgkBitClient(num_elem, q_23, q_16, work_bit, data_bit,
                           fhe_builder_16, fhe_builder_23, "DgkBitTest")

        x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  num_elem)
        y = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  num_elem)
        x_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    num_elem)
        y_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    num_elem)
        x_s = pmod(x - x_c, q_23)
        y_s = pmod(y - y_c, q_23)
        y_sub_x_s = pmod(y_s - x_s, q_23)

        x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
        y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
        y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                   "y_sub_x_s")
        torch_sync()
        x_blob.send(x)
        y_blob.send(y)
        y_sub_x_s_blob.send(y_sub_x_s)

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            dgk.offline()
        y_sub_x_c = pmod(y_c - x_c, q_23)
        traffic_record.reset("client-offline")
        torch_sync()

        with NamedTimerInstance("Client Online"):
            dgk.online(y_sub_x_c)
        traffic_record.reset("client-online")

        dgk_x_leq_y_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                       "dgk_x_leq_y_c")
        correct_mod_div_work_c_blob = BlobTorch(num_elem, torch.float,
                                                dgk.comm_base,
                                                "correct_mod_div_work_c")
        z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
        torch_sync()
        dgk_x_leq_y_c_blob.send(dgk.dgk_x_leq_y_c)
        correct_mod_div_work_c_blob.send(dgk.correct_mod_div_work_c)
        z_blob.send(dgk.z)
        end_communicate()
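
The share bookkeeping above follows the same additive-sharing rule as the other DGK tests: y_sub_x_s = y_s - x_s (mod q) on the server and y_sub_x_c = y_c - x_c (mod q) on the client together form a sharing of y - x. A self-contained sketch, assuming pmod(x, q) is x % q:

import torch

q_23 = 7340033
n = 16

def pmod(x, q):
    return torch.remainder(x, q)

x = torch.randint(0, q_23, (n,))
y = torch.randint(0, q_23, (n,))
x_c = torch.randint(0, q_23, (n,)); x_s = pmod(x - x_c, q_23)
y_c = torch.randint(0, q_23, (n,)); y_s = pmod(y - y_c, q_23)
y_sub_x_s = pmod(y_s - x_s, q_23)
y_sub_x_c = pmod(y_c - x_c, q_23)
assert torch.equal(pmod(y_sub_x_s + y_sub_x_c, q_23), pmod(y - x, q_23))
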