Exemplo n.º 1
0
    def masking_output(self):
        """Add a random mask to every output ciphertext, in place.

        A mask of shape ``[num_output_batch, degree]`` is requested from
        ``generate_random_mask``.  For each batch, the slice belonging to each
        piece is copied into a slot tensor of length ``self.degree``; that
        tensor is reduced with ``pmod``, encoded with ``self.batch_encoder``
        and added to ``self.output_cts[idx]`` via ``add_plain_inplace``.

        ``self.output_mask_s`` ends up holding, per output unit, the sum of
        the mask slice of its piece, reduced with ``pmod``.
        """
        mask_rows = generate_random_mask(
            self.modulus, [self.num_output_batch, self.degree])
        self.output_mask_s = torch.zeros(self.num_output_unit).double()
        piece_len = self.num_elem_in_piece

        # One slot vector and one plaintext are reused across all batches.
        slots = uIntVector()
        plain = Plaintext()
        for batch_idx in range(self.num_output_batch):
            # Slots that no piece writes to keep a zero mask.
            batch_mask = torch.zeros(self.degree, dtype=torch.float)
            for piece_idx in range(self.num_piece_in_batch):
                unit_idx = self.index_output_batch_to_units(
                    batch_idx, piece_idx)
                # Identity test on purpose: only the literal False stops the
                # batch, an index of 0 does not.
                if unit_idx is False:
                    break
                lo = piece_idx * piece_len
                hi = lo + piece_len
                piece_mask = mask_rows[batch_idx, lo:hi]
                batch_mask[lo:hi] = piece_mask
                self.output_mask_s[unit_idx] = piece_mask.double().sum()
            reduced = pmod(batch_mask, self.modulus)
            slots.from_np(reduced.numpy().astype(np.uint64))
            self.batch_encoder.encode(slots, plain)
            self.evaluator.add_plain_inplace(
                self.output_cts[batch_idx], plain)

        self.output_mask_s = pmod(self.output_mask_s, self.modulus)
Exemplo n.º 2
0
    def offline(self, weight, bias=None):
        """Store the layer parameters, run the FHE offline step and warm up.

        Args:
            weight: Weight tensor.  Stored unchanged in ``self.weight``, passed
                to ``self.fhe_ntt.offline`` and kept as a CUDA double tensor
                reshaped to ``self.weight_shape`` in ``self.torch_w``.
            bias: Optional bias tensor; stored as a CUDA double tensor, or
                ``None`` when omitted.
        """
        self.weight = weight
        if bias is None:
            self.bias = None
        else:
            self.bias = bias.cuda().double()

        self.fhe_ntt.offline(weight)
        # Read after the offline call above (presumably what fills it in).
        ntt_mask = self.fhe_ntt.output_mask_s
        self.output_mask_s = ntt_mask.cuda().double()
        self.torch_w = weight.reshape(self.weight_shape).cuda().double()

        # warming up: one conv2d call on a batch-of-one tensor
        warm_up_shape = [1] + list(self.input_shape)
        warm_up_x = generate_random_mask(self.modulus, warm_up_shape)
        warm_up_x = warm_up_x.cuda().double()
        warm_up_y = F.conv2d(warm_up_x, self.torch_w, padding=self.padding)
Exemplo n.º 3
0
    def offline(self, weight, bias):
        """Store the layer parameters, run the offline core and warm up.

        Args:
            weight: Weight tensor.  Kept as a CUDA double tensor in
                ``self.weight``, passed as given to
                ``self.offline_core.offline`` and stored reshaped to
                ``self.weight_shape`` in ``self.torch_w``.
            bias: Bias tensor or ``None``; stored as a CUDA double tensor when
                present.
        """
        self.weight = weight.cuda().double()
        if bias is None:
            self.bias = None
        else:
            self.bias = bias.cuda().double()

        # The offline core receives the weight exactly as passed in.
        self.offline_core.offline(weight)
        core_mask = self.offline_core.output_mask_s
        self.output_mask_s = core_mask.cuda().double()
        self.torch_w = weight.reshape(self.weight_shape).cuda().double()

        # warming up: one matrix product on a batch-of-one tensor
        warm_up_shape = [1] + list(self.input_shape)
        warm_up_x = generate_random_mask(self.modulus, warm_up_shape)
        warm_up_x = warm_up_x.cuda().double()
        warm_up_y = torch.mm(warm_up_x, self.weight.t())
Exemplo n.º 4
0
 def offline(self):
     """Set up the swap-to-client sub-protocol for this party's role.

     Server: creates a ``SwapToClientOfflineServer``, stores a zero tensor of
     ``self.input_shape`` in ``self.dummy_input_s`` and runs the
     sub-protocol's ``offline()``.

     Client: creates a ``SwapToClientOfflineClient``, obtains a mask from
     ``generate_random_mask``, passes it flattened to the sub-protocol's
     ``offline()`` and keeps it in ``self.output_share`` converted to
     ``self.next_input_device`` / ``self.next_input_dtype``.

     Nothing is created when the instance is neither server nor client.
     """
     target_device = self.next_input_device
     target_dtype = self.next_input_dtype
     prot_name = self.sub_name("swap_prot")
     q = self.context.q_23

     if self.is_server():
         num_elem = get_prod(self.input_shape)
         self.swap_prot = SwapToClientOfflineServer(num_elem, q, prot_name)
         placeholder = torch.zeros(self.input_shape)
         self.dummy_input_s = placeholder.to(target_device).type(target_dtype)
         self.swap_prot.offline()
     elif self.is_client():
         num_elem = get_prod(self.input_shape)
         self.swap_prot = SwapToClientOfflineClient(num_elem, q, prot_name)
         share = generate_random_mask(q, self.input_shape)
         self.output_share = share
         # The sub-protocol takes the share as a flat vector.
         self.swap_prot.offline(share.reshape(-1))
         self.output_share = share.to(target_device).type(target_dtype)
Exemplo n.º 5
0
    def offline(self):
        """Prepare ``self.swapped_input_c`` and, if needed, the swap protocol.

        Without a swap (``self.is_need_swap`` false) the client simply takes
        ``self.get_input_share()`` moved to ``Config.device``; the server does
        nothing.

        With a swap, the server creates and runs a
        ``SwapToClientOfflineServer``.  The client creates a
        ``SwapToClientOfflineClient``, feeds it a flattened mask from
        ``generate_random_mask`` and keeps that mask on ``Config.device``
        reshaped to ``self.input_shape``.
        """
        q = self.context.q_23
        prot_name = self.sub_name("swap_prot")

        if not self.is_need_swap:
            # No swap: only the client has something to record.
            if self.is_client():
                share = self.get_input_share()
                self.swapped_input_c = share.to(Config.device)
            return

        if self.is_server():
            num_elem = get_prod(self.input_shape)
            self.swap_prot = SwapToClientOfflineServer(num_elem, q, prot_name)
            self.swap_prot.offline()
        elif self.is_client():
            num_elem = get_prod(self.input_shape)
            self.swap_prot = SwapToClientOfflineClient(num_elem, q, prot_name)
            mask = generate_random_mask(q, self.input_shape)
            self.swapped_input_c = mask
            # The sub-protocol takes the mask as a flat vector.
            self.swap_prot.offline(mask.reshape(-1))
            on_device = mask.to(Config.device)
            self.swapped_input_c = on_device.reshape(self.input_shape)
Exemplo n.º 6
0
def test_comm_fhe_builder():
    """Exercise key and ciphertext exchange through ``CommFheBuilder``.

    ``test_server`` and ``test_client`` are handed to ``marshal_funcs``.  The
    client generates the keys.  The exchange then runs in this order:

    1. client sends the public key, server receives and builds it;
    2. server sends ``ori`` through a ``BlobFheEnc``; client decrypts and
       compares;
    3. server sends an encryption of the constant 23; client decrypts and
       compares;
    4. client sends the secret key and an encryption of the constant 42;
       server decrypts and compares.
    """
    num_elem = 2**17
    # Alternative parameters tried earlier: modulus 12289, degree 2048.
    modulus = Config.q_16
    degree = Config.n_16
    expected_tensor_pk = torch.full((num_elem,), 23.0, dtype=torch.float)
    expected_tensor_sk = torch.full((num_elem,), 42.0, dtype=torch.float)
    pk_tag, sk_tag, ori_tag = "pk", "sk", "ori"

    ori = generate_random_mask(modulus, num_elem)

    def connect(rank):
        """Call ``init_communicate(rank)`` and build both FHE helpers."""
        init_communicate(rank)
        builder = FheBuilder(modulus, degree)
        return builder, CommFheBuilder(rank, builder, "fhe_builder")

    def receive_and_decrypt(builder, comm, tag):
        """Receive the ciphertext tagged ``tag`` and return its decryption."""
        enc = builder.build_enc(num_elem)
        comm.recv_enc(enc, tag)
        comm.wait_enc(tag)
        return builder.decrypt_to_torch(enc)

    def test_server():
        fhe_builder, comm_fhe_builder = connect(Config.server_rank)

        # The server has no keys of its own; it starts from the public key.
        comm_fhe_builder.recv_public_key()
        comm_fhe_builder.wait_and_build_public_key()

        BlobFheEnc(num_elem, comm_fhe_builder, ori_tag).send_from_torch(ori)

        comm_fhe_builder.send_enc(
            fhe_builder.build_enc_from_torch(expected_tensor_pk), pk_tag)

        comm_fhe_builder.recv_secret_key()
        comm_fhe_builder.wait_and_build_secret_key()

        actual_tensor_sk = receive_and_decrypt(fhe_builder, comm_fhe_builder,
                                               sk_tag)
        compare_expected_actual(expected_tensor_sk,
                                actual_tensor_sk,
                                get_relative=True,
                                name="Recovering to test sk")

        dist.destroy_process_group()

    def test_client():
        fhe_builder, comm_fhe_builder = connect(Config.client_rank)
        fhe_builder.generate_keys()

        comm_fhe_builder.send_public_key()

        incoming_ori = BlobFheEnc(num_elem, comm_fhe_builder, ori_tag)
        incoming_ori.prepare_recv()
        decrypted_ori = incoming_ori.get_recv_decrypt()
        compare_expected_actual(ori,
                                decrypted_ori,
                                get_relative=True,
                                name=ori_tag)

        actual_tensor_pk = receive_and_decrypt(fhe_builder, comm_fhe_builder,
                                               pk_tag)
        compare_expected_actual(expected_tensor_pk,
                                actual_tensor_pk,
                                get_relative=True,
                                name="Recovering to test pk")

        comm_fhe_builder.send_secret_key()

        comm_fhe_builder.send_enc(
            fhe_builder.build_enc_from_torch(expected_tensor_sk), sk_tag)

        dist.destroy_process_group()

    marshal_funcs([test_server, test_client])
Exemplo n.º 7
0
 def offline(self):
     """Run the base-class offline step, then create ``self.input_mask_s``.

     The mask comes from ``generate_random_mask(self.modulus, self.num_elem)``
     and is moved to ``self.comp_device``.
     """
     PhaseProtocolServer.offline(self)
     mask = generate_random_mask(self.modulus, self.num_elem)
     self.input_mask_s = mask.to(self.comp_device)
Exemplo n.º 8
0
def test_conv2d_secure_comm(input_sid,
                            master_address,
                            master_port,
                            setting=(16, 3, 128, 128)):
    """Run the secure conv2d protocol and check it against ``F.conv2d``.

    Weight, bias and input are drawn with ``gen_unirand_int_grain``; the input
    is split into ``input_c`` (a mask from ``generate_random_mask``) and
    ``input_s = pmod(input - input_c, modulus)``.  The server runs
    ``Conv2dSecureServer``, the client runs ``Conv2dSecureClient`` and sends
    its ``output_c`` to the server, which compares
    ``pmod(output_s + output_c, modulus)`` with the plain convolution.

    Args:
        input_sid: Which party to run in this process: ``Config.both_rank``
            (both, through ``marshal_funcs``), ``Config.server_rank`` or
            ``Config.client_rank``.
        master_address: Forwarded to ``init_communicate``.
        master_port: Forwarded to ``init_communicate``.
        setting: ``(img_hw, filter_hw, num_input_channel,
            num_output_channel)``.

    Raises:
        ValueError: If ``input_sid`` is none of the three known ranks.
            Previously such a value ran nothing yet still printed the
            Start/End banner, so the test looked as if it had passed.
    """
    test_name = "Conv2d Secure Comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    padding = 1
    img_hw, filter_hw, num_input_channel, num_output_channel = setting
    data_bit = 17
    data_range = 2**data_bit
    # Degree handed to FheBuilder; 8192 was the earlier setting.
    n_23 = 16384
    print(f"Setting covn2d: img_hw: {img_hw}, "
          f"filter_hw: {filter_hw}, "
          f"num_input_channel: {num_input_channel}, "
          f"num_output_channel: {num_output_channel}")

    x_shape = [num_input_channel, img_hw, img_hw]
    w_shape = [num_output_channel, num_input_channel, filter_hw, filter_hw]
    b_shape = [num_output_channel]

    fhe_builder = FheBuilder(modulus, n_23)
    fhe_builder.generate_keys()

    # Generation order (weight, bias, input, mask) is kept as in the original
    # so a seeded run still produces the same data.
    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                 get_prod(b_shape)).reshape(b_shape)
    # Named plain_input rather than `input` to avoid shadowing the builtin.
    plain_input = gen_unirand_int_grain(-data_range // 2 + 1,
                                        data_range // 2,
                                        get_prod(x_shape)).reshape(x_shape)
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(plain_input - input_c, modulus)

    def check_correctness_online(x, w, b, output_s, output_c):
        """Compare the recombined shares with a plain ``F.conv2d`` result."""
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        torch_b = b.cuda().double() if b is not None else None
        expected = F.conv2d(torch_x, torch_w, padding=padding, bias=torch_b)
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " online",
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureServer(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm",
                                  padding=padding)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()

        # Only the server verifies, so it collects the client's share.
        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(plain_input, weight, bias, prot.output_s,
                                 output_c)

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = Conv2dSecureClient(modulus,
                                  fhe_builder,
                                  data_range,
                                  img_hw,
                                  filter_hw,
                                  num_input_channel,
                                  num_output_channel,
                                  "test_conv2d_secure_comm",
                                  padding=padding)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    if input_sid == Config.both_rank:
        marshal_funcs([test_server, test_client])
    elif input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()
    else:
        # Fail loudly instead of reporting a test run that never happened.
        raise ValueError(f"Unknown input_sid: {input_sid!r}")

    print(f"\nTest for {test_name}: End")
Exemplo n.º 9
0
def test_fc_secure_comm(input_sid,
                        master_address,
                        master_port,
                        setting=(512, 512)):
    """Run the secure fully-connected protocol and check it against torch.

    Weight, bias and input are drawn with ``gen_unirand_int_grain``; the input
    is split into ``input_c`` (a mask from ``generate_random_mask``) and
    ``input_s = pmod(input - input_c, modulus)``.  The server runs
    ``FcSecureServer``, the client runs ``FcSecureClient`` and sends its
    ``output_c`` to the server, which compares
    ``pmod(output_s + output_c, modulus)`` with ``x @ w.T + bias``.

    Args:
        input_sid: Which party to run in this process.
            ``Config.server_rank`` runs only the server and
            ``Config.client_rank`` only the client.  Any other value runs
            both through ``marshal_funcs``, which is what every value did
            before this argument was honoured.
        master_address: Forwarded to ``init_communicate``.
        master_port: Forwarded to ``init_communicate``.
        setting: ``(num_input_unit, num_output_unit)``.
    """
    test_name = "test_fc_secure_comm"
    print(f"\nTest for {test_name}: Start")
    modulus = 786433
    num_input_unit, num_output_unit = setting
    data_bit = 17

    print(f"Setting fc: "
          f"num_input_unit: {num_input_unit}, "
          f"num_output_unit: {num_output_unit}")

    data_range = 2**data_bit
    x_shape = [num_input_unit]
    w_shape = [num_output_unit, num_input_unit]
    b_shape = [num_output_unit]

    fhe_builder = FheBuilder(modulus, Config.n_23)
    fhe_builder.generate_keys()

    # Generation order (weight, bias, input, mask) is kept as in the original
    # so a seeded run still produces the same data.
    weight = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                   get_prod(w_shape)).reshape(w_shape)
    bias = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                 get_prod(b_shape)).reshape(b_shape)
    # Named plain_input rather than `input` to avoid shadowing the builtin.
    plain_input = gen_unirand_int_grain(-data_range // 2 + 1,
                                        data_range // 2,
                                        get_prod(x_shape)).reshape(x_shape)
    input_c = generate_random_mask(modulus, x_shape)
    input_s = pmod(plain_input - input_c, modulus)

    def check_correctness_online(x, w, output_s, output_c):
        """Compare the recombined shares with a plain matrix product.

        ``bias`` is taken from the enclosing scope, not from a parameter.
        """
        actual = pmod(output_s.cuda() + output_c.cuda(), modulus)
        torch_x = x.reshape([1] + x_shape).cuda().double()
        torch_w = w.reshape(w_shape).cuda().double()
        expected = torch.mm(torch_x, torch_w.t())
        if bias is not None:
            expected += bias.cuda().double()
        expected = pmod(expected.reshape(output_s.shape), modulus)
        compare_expected_actual(expected,
                                actual,
                                name=test_name + " online",
                                get_relative=True)

    def test_server():
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = FcSecureServer(modulus, data_range, num_input_unit,
                              num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Server Offline"):
            prot.offline(weight, bias=bias)
            torch_sync()
        with NamedTimerInstance("Server Online"):
            prot.online(input_s)
            torch_sync()

        # Only the server verifies, so it collects the client's share.
        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        blob_output_c.prepare_recv()
        torch_sync()
        output_c = blob_output_c.get_recv()
        check_correctness_online(plain_input, weight, prot.output_s, output_c)

        end_communicate()

    def test_client():
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        prot = FcSecureClient(modulus, data_range, num_input_unit,
                              num_output_unit, fhe_builder, test_name)
        with NamedTimerInstance("Client Offline"):
            prot.offline(input_c)
            torch_sync()
        with NamedTimerInstance("Client Online"):
            prot.online()
            torch_sync()

        blob_output_c = BlobTorch(prot.output_shape, torch.float,
                                  prot.comm_base, "output_c")
        torch_sync()
        blob_output_c.send(prot.output_c)
        end_communicate()

    # Honour input_sid so each party can be started in its own process.
    # Any value other than the two single-party ranks keeps the old
    # behaviour of launching both.
    if input_sid == Config.server_rank:
        test_server()
    elif input_sid == Config.client_rank:
        test_client()
    else:
        marshal_funcs([test_server, test_client])
    print(f"\nTest for {test_name}: End")