예제 #1
0
 def __init__(self, num_elem, q_23, q_16, work_bit, data_bit,
              fhe_builder_16: FheBuilder, fhe_builder_23: FheBuilder,
              name: str, is_shuffle=None):
     """Set up the client side of the DgkBit protocol.

     Creates one plain communicator and two FHE communicators (one per FHE
     builder), all bound to ``Config.client_rank``, and hands them to a
     ``DgkBitCommon`` helper stored as ``self.common``.

     Args:
         num_elem, q_23, q_16, work_bit, data_bit: forwarded unchanged to
             the parent constructor and to ``DgkBitCommon``.
         fhe_builder_16: builder wrapped by the ``<name>_comm_fhe_16`` channel.
         fhe_builder_23: builder wrapped by the ``<name>_comm_fhe_23`` channel.
         name: base name for the parent, the communicators and the helper.
         is_shuffle: explicit shuffle flag; ``None`` falls back to
             ``Config.is_shuffle``.
     """
     super(DgkBitClient, self).__init__(
         num_elem, q_23, q_16, work_bit, data_bit, name=name)
     self.fhe_builder_16 = fhe_builder_16
     self.fhe_builder_23 = fhe_builder_23

     client_rank = Config.client_rank
     self.comm_base = CommBase(client_rank, name)
     self.comm_fhe_16 = CommFheBuilder(
         client_rank, self.fhe_builder_16, name + "_comm_fhe_16")
     self.comm_fhe_23 = CommFheBuilder(
         client_rank, self.fhe_builder_23, name + "_comm_fhe_23")

     self.common = DgkBitCommon(
         num_elem, q_23, q_16, work_bit, data_bit, name,
         self.comm_base, self.comm_fhe_16, self.comm_fhe_23)

     # Only consult the global config when the caller did not decide.
     if is_shuffle is None:
         is_shuffle = Config.is_shuffle
     self.is_shuffle = is_shuffle
예제 #2
0
    def test_client():
        """Run the client side of the end-to-end ``secure_nn`` test.

        Generates both FHE key sets and sends the public keys. Then feeds a
        random input image and runs the offline and online phases under
        timers. Finally checks the network output with ``check_correctness``.

        Relies on names from the enclosing scope: ``fhe_builder_16``,
        ``fhe_builder_23``, ``input_shape``, ``trunc1``, ``pow_to_div`` and
        ``secure_nn``.
        """
        rank = Config.client_rank
        init_communicate(rank)
        context.set_rank(rank)

        # The client generates the key material for both FHE builders and
        # sends the public keys to the peer.
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        # The client owns the plaintext input. Also configure the truncation
        # layer's divisor before running the network.
        input_img = generate_random_data(input_shape)
        trunc1.set_div_to_pow(pow_to_div)
        secure_nn.feed_input(input_img)

        # torch_sync() sits inside each timer, presumably so the measured
        # time covers both parties finishing the phase (confirm it is a
        # cross-process barrier).
        with NamedTimerInstance("Client Offline"):
            secure_nn.offline()
            torch_sync()

        with NamedTimerInstance("Client Online"):
            secure_nn.online()
            torch_sync()

        check_correctness(input_img, secure_nn.get_output())
        end_communicate()
예제 #3
0
 def __init__(self, shape, fhe_builder: FheBuilder, name: str):
     """Create client-side state for refreshing encrypted data.

     Args:
         shape: shape forwarded to ``EncRefresherCommon``.
         fhe_builder: FHE builder; its ``modulus`` is cached on the instance.
         name: name used for the FHE communicator and the common helper.
     """
     self.shape = shape
     self.fhe_builder = fhe_builder
     self.modulus = fhe_builder.modulus

     fhe_channel = CommFheBuilder(Config.client_rank, fhe_builder, name)
     self.common = EncRefresherCommon(shape, fhe_channel, name)

     # Presumably set to True by a later preparation step. TODO: confirm
     # against the rest of the class.
     self.is_prepared = False
예제 #4
0
    def __init__(self, modulus, data_range, num_input_unit, num_output_unit,
                 fhe_builder: FheBuilder, name, rank):
        """Wire up communicators, the FHE compute core and ciphertext blobs.

        Args:
            modulus, data_range, num_input_unit, num_output_unit, name:
                forwarded to the parent constructor and to
                ``FcFheSingleThread``.
            fhe_builder: FHE builder shared by the FHE communicator and the
                compute core.
            rank: rank used for both communicators.
        """
        super().__init__(modulus, data_range, num_input_unit,
                         num_output_unit, name)

        self.fhe_builder = fhe_builder
        self.rank = rank
        self.comm_base = CommBase(rank, self.sub_name("comm_base"))
        self.comm_fhe = CommFheBuilder(
            rank, fhe_builder, self.sub_name("comm_fhe"))
        self.compute_core = FcFheSingleThread(
            modulus, data_range, num_input_unit, num_output_unit,
            fhe_builder, name)

        fhe_channel = self.comm_fhe
        self.blob_input_cts = BlobFheRawCts(
            self.num_input_batch, fhe_channel, self.sub_name("input_cts"))
        self.blob_output_cts = BlobFheRawCts(
            self.num_output_batch, fhe_channel, self.sub_name("output_cts"))

        # Per-phase send lists: ciphertext blobs are only exchanged offline.
        self.offline_server_send = [self.blob_output_cts]
        self.offline_client_send = [self.blob_input_cts]
        self.online_server_send, self.online_client_send = [], []
예제 #5
0
    def __init__(self, num_elem, modulus, fhe_builder: FheBuilder, name: str,
                 rank):
        """Create communicators and the FHE / plain blobs for this protocol.

        Args:
            num_elem: number of elements carried by every blob.
            modulus: share modulus; must equal ``fhe_builder.modulus``.
            fhe_builder: FHE builder wrapped by the FHE communicator.
            name: base name for the communicators and FHE blob tags.
            rank: rank used for both communicators.

        Raises:
            AssertionError: if the builder's modulus differs from ``modulus``
                (only when assertions are enabled).
        """
        self.num_elem = num_elem
        self.modulus = modulus
        self.name = name
        self.fhe_builder = fhe_builder
        self.rank = rank
        self.comm_base = CommBase(rank, name)
        self.comm_fhe = CommFheBuilder(
            rank, fhe_builder, self.sub_name("comm_fhe"))

        assert self.modulus == self.fhe_builder.modulus

        fhe_channel = self.comm_fhe
        self.blob_fhe_u = BlobFheEnc(num_elem, fhe_channel,
                                     self.sub_name("fhe_u"))
        self.blob_fhe_v = BlobFheEnc(num_elem, fhe_channel,
                                     self.sub_name("fhe_v"))
        self.blob_fhe_z_c = BlobFheEnc(num_elem, fhe_channel,
                                       self.sub_name("fhe_z_c"),
                                       ee_mult_time=1)

        def plain_blob(tag):
            # Float tensors over comm_base. Unlike the FHE blobs above, these
            # tags are not passed through sub_name.
            return BlobTorch(num_elem, torch.float, self.comm_base, tag)

        self.blob_e_s = plain_blob("e_s")
        self.blob_f_s = plain_blob("f_s")
        self.blob_e_c = plain_blob("e_c")
        self.blob_f_c = plain_blob("f_c")

        # Which blobs each party sends in each phase.
        self.offline_server_send = [self.blob_fhe_z_c]
        self.offline_client_send = [self.blob_fhe_u, self.blob_fhe_v]
        self.online_server_send = [self.blob_e_s, self.blob_f_s]
        self.online_client_send = [self.blob_e_c, self.blob_f_c]
예제 #6
0
    def __init__(self, modulus, fhe_builder: FheBuilder, data_range, img_hw,
                 filter_hw, num_input_channel, num_output_channel, name: str,
                 rank, padding):
        """Wire up communicators, the Ntt conv compute core and ciphertext blobs.

        Args:
            modulus, data_range, img_hw, filter_hw, num_input_channel,
            num_output_channel, name, padding: forwarded to the parent
                constructor and to ``Conv2dFheNttSingleThread``.
            fhe_builder: FHE builder; its modulus must match ``self.modulus``.
            rank: rank used for both communicators.

        Raises:
            AssertionError: if the builder's modulus differs from the stored
                modulus (only when assertions are enabled).
        """
        super(Conv2dFheNttBase, self).__init__(
            modulus, data_range, img_hw, filter_hw, num_input_channel,
            num_output_channel, name, padding)
        self.fhe_builder = fhe_builder
        self.rank = rank
        self.comm_base = CommBase(rank, self.sub_name("comm_base"))
        self.comm_fhe = CommFheBuilder(
            rank, fhe_builder, self.sub_name("comm_fhe"))
        self.compute_core = Conv2dFheNttSingleThread(
            modulus, data_range, img_hw, filter_hw, num_input_channel,
            num_output_channel, fhe_builder, name, padding)

        assert self.modulus == self.fhe_builder.modulus

        fhe_channel = self.comm_fhe
        self.blob_input_cts = BlobFheRawCts(
            self.num_input_batch, fhe_channel, self.sub_name("input_cts"))
        self.blob_output_cts = BlobFheRawCts(
            self.num_output_batch, fhe_channel, self.sub_name("output_cts"))

        # Per-phase send lists: ciphertext blobs are only exchanged offline.
        self.offline_server_send = [self.blob_output_cts]
        self.offline_client_send = [self.blob_input_cts]
        self.online_server_send, self.online_client_send = [], []
예제 #7
0
    def __init__(self, num_elem, q_23, q_16, work_bit, data_bit,
                 fhe_builder_16: FheBuilder, fhe_builder_23: FheBuilder,
                 name: str, rank: int, class_name: str):
        """Store DGK parameters and create the shared communicators.

        Args:
            num_elem, q_23, q_16, work_bit, data_bit: forwarded to the parent
                constructor.
            fhe_builder_16: builder whose modulus must equal ``self.q_16``.
            fhe_builder_23: builder whose modulus must equal ``self.q_23``.
            name: base name for the communicators.
            rank: rank used for all three communicators.
            class_name: stored as-is on the instance.

        Raises:
            AssertionError: if either builder's modulus does not match
                (only when assertions are enabled).
        """
        super(DgkCommBase, self).__init__(
            num_elem, q_23, q_16, work_bit, data_bit)
        self.class_name = class_name
        self.name = name
        self.rank = rank
        self.fhe_builder_16 = fhe_builder_16
        self.fhe_builder_23 = fhe_builder_23

        self.comm_base = CommBase(rank, name)
        self.comm_fhe_16 = CommFheBuilder(
            rank, fhe_builder_16, self.sub_name("comm_fhe_16"))
        self.comm_fhe_23 = CommFheBuilder(
            rank, fhe_builder_23, self.sub_name("comm_fhe_23"))

        # Each builder must work over the modulus it is paired with.
        assert self.fhe_builder_16.modulus == self.q_16
        assert self.fhe_builder_23.modulus == self.q_23
예제 #8
0
    def __init__(self, modulus, fhe_builder: FheBuilder, data_range, img_hw,
                 filter_hw, num_input_channel, num_output_channel, name: str,
                 rank, padding):
        """Create the communicators for an FHE-backed 2-D convolution.

        Args:
            modulus, data_range, img_hw, filter_hw, num_input_channel,
            num_output_channel, name, padding: forwarded to the parent
                constructor.
            fhe_builder: FHE builder wrapped by the FHE communicator.
            rank: rank used for both communicators.
        """
        super().__init__(modulus, data_range, img_hw, filter_hw,
                         num_input_channel, num_output_channel, name,
                         padding)
        self.fhe_builder, self.rank = fhe_builder, rank
        self.comm_base = CommBase(rank, self.sub_name("comm_base"))
        self.comm_fhe = CommFheBuilder(
            rank, fhe_builder, self.sub_name("comm_fhe"))

        # No blobs are registered here; all four per-phase send lists start
        # empty (presumably filled in by subclasses; TODO confirm).
        self.offline_server_send = []
        self.offline_client_send = []
        self.online_server_send = []
        self.online_client_send = []
예제 #9
0
    def __init__(self, modulus, data_range, num_input_unit, num_output_unit,
                 fhe_builder: FheBuilder, name, rank):
        """Create the communicators for an FHE-backed fully connected layer.

        Args:
            modulus, data_range, num_input_unit, num_output_unit, name:
                forwarded to the parent constructor.
            fhe_builder: FHE builder; its ``degree`` must equal ``self.degree``.
            rank: rank used for both communicators.

        Raises:
            AssertionError: if the builder's degree does not match (only when
                assertions are enabled).
        """
        super().__init__(modulus, data_range, num_input_unit,
                         num_output_unit, name)
        self.fhe_builder, self.rank = fhe_builder, rank
        self.comm_base = CommBase(rank, self.sub_name("comm_base"))
        self.comm_fhe = CommFheBuilder(
            rank, fhe_builder, self.sub_name("comm_fhe"))

        assert self.fhe_builder.degree == self.degree

        # All four per-phase send lists start empty.
        self.offline_server_send = []
        self.offline_client_send = []
        self.online_server_send = []
        self.online_client_send = []
예제 #10
0
    def test_client():
        """Run the client side of the Maxpool2x2Dgk protocol test.

        Steps:
        - Generate FHE keys and send the public keys.
        - Receive this party's image share ``img_c`` from the peer.
        - Run the offline and online phases, timed, with traffic recorded
          per phase.
        - Send the resulting ``prot.max_c`` back to the peer.

        Relies on enclosing-scope names: ``q_16``, ``q_23``, ``num_elem``,
        ``work_bit``, ``data_bit``, ``img_hw``, ``master_address`` and
        ``master_port``.
        """
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()

        # The client generates the key material for both moduli and sends
        # the public keys.
        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        prot = Maxpool2x2DgkClient(num_elem, q_23, q_16, work_bit, data_bit,
                                   img_hw, fhe_builder_16, fhe_builder_23,
                                   "max_dgk")
        # NOTE(review): the input-share blob reuses the tag "recon_max_c",
        # which is also used for the output blob below. It works only because
        # the peer uses the same pairing. A distinct tag would be clearer, but
        # it must be changed on both sides together.
        blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                               "recon_max_c")
        blob_img_c.prepare_recv()
        torch_sync()
        img_c = blob_img_c.get_recv()

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(img_c)
            torch_sync()
        traffic_record.reset("client-online")

        # 2x2 max-pooling leaves one output per four inputs, hence num_elem // 4.
        blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_max_c")
        torch_sync()
        blob_max_c.send(prot.max_c)
        end_communicate()
예제 #11
0
    def test_client():
        """Run the client side of the ReluDgk protocol test.

        Steps:
        - Generate FHE keys and send the public keys.
        - Secret-share a random input ``a`` as ``a_c`` (kept) and
          ``a_s = a - a_c mod q_23`` (sent to the peer).
        - Run the offline and online phases with per-phase traffic recording.
        - Receive ``max_s`` from the peer and check it together with
          ``prot.max_c``.
        """
        rank = Config.client_rank
        init_communicate(rank, master_address=master_address, master_port=master_port)
        traffic_record = TrafficRecord()
        # The client generates the key material and sends the public keys.
        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        # Random plaintext in the signed data range, split into two additive
        # shares modulo q_23.
        a = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2, num_elem)
        a_c = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        a_s = pmod(a - a_c, q_23)

        prot = ReluDgkClient(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "relu_dgk")

        blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
        blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
        torch_sync()
        blob_a_s.send(a_s)
        # Post the receive for the peer's result before running the protocol.
        blob_max_s.prepare_recv()

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("client-offline")

        with NamedTimerInstance("Client Online"):
            prot.online(a_c)
            torch_sync()
        traffic_record.reset("client-online")

        max_s = blob_max_s.get_recv()
        check_correctness_online(a, max_s, prot.max_c)

        torch.cuda.empty_cache()
        end_communicate()
예제 #12
0
 def fhe_builder_sync(self):
     """Exchange the public keys of both FHE builders between the parties.

     On the server, receive both public keys and then wait for them to be
     built. On the client, generate keys for ``fhe_builder_16`` and
     ``fhe_builder_23`` and then send both public keys. ``torch_sync()`` is
     called before and after the exchange on either side.

     Returns:
         ``self``, so the call can be chained.
     """
     channels = (
         CommFheBuilder(self.rank, self.fhe_builder_16,
                        self.sub_name("fhe_builder_16")),
         CommFheBuilder(self.rank, self.fhe_builder_23,
                        self.sub_name("fhe_builder_23")),
     )
     torch_sync()
     if self.is_server():
         # Post both receives first, then block until both keys are built.
         for channel in channels:
             channel.recv_public_key()
         for channel in channels:
             channel.wait_and_build_public_key()
     elif self.is_client():
         for builder in (self.fhe_builder_16, self.fhe_builder_23):
             builder.generate_keys()
         for channel in channels:
             channel.send_public_key()
     torch_sync()
     return self
예제 #13
0
    def test_server():
        """Run the server side of the Maxpool2x2Dgk protocol test.

        Steps:
        - Receive the peer's FHE public keys.
        - Split a random image into ``img_s`` (kept) and
          ``img_c = img - img_s mod q_23`` (sent to the peer).
        - Run the offline and online phases with per-phase traffic recording.
        - Receive the peer's ``max_c`` and check it against ``prot.max_s``.
        """
        rank = Config.server_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        traffic_record = TrafficRecord()

        # The server does not generate keys; it receives and builds the
        # peer's public keys.
        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        # Random image in the signed data range, split into two additive
        # shares modulo q_23.
        img = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    num_elem)
        img_s = gen_unirand_int_grain(0, q_23 - 1, num_elem)
        img_c = pmod(img - img_s, q_23)

        prot = Maxpool2x2DgkServer(num_elem, q_23, q_16, work_bit, data_bit,
                                   img_hw, fhe_builder_16, fhe_builder_23,
                                   "max_dgk")

        # NOTE(review): the tag "recon_max_c" is used here for the input
        # share and again below for the output. This mirrors the peer's
        # pairing; rename on both sides together if changing it.
        blob_img_c = BlobTorch(num_elem, torch.float, prot.comm_base,
                               "recon_max_c")
        torch_sync()
        blob_img_c.send(img_c)

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(img_s)
            torch_sync()
        traffic_record.reset("server-online")

        # 2x2 max-pooling leaves one output per four inputs.
        blob_max_c = BlobTorch(num_elem // 4, torch.float, prot.comm_base,
                               "recon_max_c")
        blob_max_c.prepare_recv()
        torch_sync()
        max_c = blob_max_c.get_recv()
        check_correctness_online(img, prot.max_s, max_c)

        end_communicate()
예제 #14
0
    def test_server():
        """Run the server side of the end-to-end ``secure_nn`` test.

        Receives and builds the peer's FHE public keys and loads the layer
        weights (``conv1_w``, ``conv2_w``, ``fc1_w`` from the enclosing
        scope). Then runs the offline and online phases under timers. Output
        correctness is not checked in this function.
        """
        rank = Config.server_rank
        init_communicate(rank)
        context.set_rank(rank)

        # Post both receives first, then block until both keys are built.
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        # The server holds the model weights.
        conv1.load_weight(conv1_w)
        conv2.load_weight(conv2_w)
        fc1.load_weight(fc1_w)
        trunc1.set_div_to_pow(pow_to_div)

        with NamedTimerInstance("Server Offline"):
            secure_nn.offline()
            torch_sync()

        with NamedTimerInstance("Server Online"):
            secure_nn.online()
            torch_sync()

        end_communicate()
예제 #15
0
    def test_server():
        """Run the server side of the ReluDgk protocol test.

        Steps:
        - Receive and build the peer's FHE public keys.
        - Receive this party's input share ``a_s`` from the peer.
        - Run the offline and online phases with traffic recording.
        - Send ``prot.max_s`` back to the peer.
        """
        rank = Config.server_rank
        # Config.server_rank here is the same value as ``rank``.
        init_communicate(Config.server_rank, master_address=master_address, master_port=master_port)
        traffic_record = TrafficRecord()
        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        prot = ReluDgkServer(num_elem, q_23, q_16, work_bit, data_bit, fhe_builder_16, fhe_builder_23, "relu_dgk")

        blob_a_s = BlobTorch(num_elem, torch.float, prot.comm_base, "a")
        blob_max_s = BlobTorch(num_elem, torch.float, prot.comm_base, "max_s")
        torch_sync()
        blob_a_s.prepare_recv()
        a_s = blob_a_s.get_recv()

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            prot.offline()
            torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            prot.online(a_s)
            torch_sync()
        traffic_record.reset("server-online")

        # Return this party's output share so the peer can check correctness.
        blob_max_s.send(prot.max_s)
        torch.cuda.empty_cache()
        end_communicate()
예제 #16
0
    def test_client():
        """Run the client side of the DgkBit protocol test.

        Steps:
        - Generate FHE keys and send the public keys.
        - Secret-share random ``x`` and ``y``. Send the plaintexts and the
          peer's share of ``y - x`` to the peer.
        - Run the offline and online phases with per-phase traffic recording.
        - Send ``dgk.dgk_x_leq_y_c``, ``dgk.correct_mod_div_work_c`` and
          ``dgk.z`` to the peer, presumably for its correctness checks.
        """
        rank = Config.client_rank
        init_communicate(rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()

        # NOTE(review): the other tests build these with Config.n_16 /
        # Config.n_23, but here the bare enclosing-scope names n_16 / n_23
        # are used. Confirm they equal the values the server passes.
        # Otherwise the two parties build FHE builders with different
        # parameters.
        fhe_builder_16 = FheBuilder(q_16, n_16)
        fhe_builder_23 = FheBuilder(q_23, n_23)
        fhe_builder_16.generate_keys()
        fhe_builder_23.generate_keys()
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.send_public_key()
        comm_fhe_23.send_public_key()

        dgk = DgkBitClient(num_elem, q_23, q_16, work_bit, data_bit,
                           fhe_builder_16, fhe_builder_23, "DgkBitTest")

        # NOTE(review): the client shares x_c / y_c are drawn from the signed
        # data range, not from [0, q_23) as in the other tests. Confirm this
        # is intended.
        x = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  num_elem)
        y = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                  num_elem)
        x_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    num_elem)
        y_c = gen_unirand_int_grain(-data_range // 2 + 1, data_range // 2,
                                    num_elem)
        x_s = pmod(x - x_c, q_23)
        y_s = pmod(y - y_c, q_23)
        y_sub_x_s = pmod(y_s - x_s, q_23)

        # Test-only: plaintexts x and y are sent alongside the peer's share
        # so that the peer can verify the result.
        x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
        y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
        y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                   "y_sub_x_s")
        torch_sync()
        x_blob.send(x)
        y_blob.send(y)
        y_sub_x_s_blob.send(y_sub_x_s)

        torch_sync()
        with NamedTimerInstance("Client Offline"):
            dgk.offline()
        # This party's share of y - x is computed outside the timed region.
        y_sub_x_c = pmod(y_c - x_c, q_23)
        traffic_record.reset("client-offline")
        torch_sync()

        with NamedTimerInstance("Client Online"):
            dgk.online(y_sub_x_c)
        traffic_record.reset("client-online")

        dgk_x_leq_y_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                       "dgk_x_leq_y_c")
        correct_mod_div_work_c_blob = BlobTorch(num_elem, torch.float,
                                                dgk.comm_base,
                                                "correct_mod_div_work_c")
        z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
        torch_sync()
        dgk_x_leq_y_c_blob.send(dgk.dgk_x_leq_y_c)
        correct_mod_div_work_c_blob.send(dgk.correct_mod_div_work_c)
        z_blob.send(dgk.z)
        end_communicate()
예제 #17
0
    def test_server():
        """Run the server side of the DgkBit protocol test.

        Steps:
        - Receive and build the peer's FHE public keys.
        - Receive ``x``, ``y`` and this party's share ``y_sub_x_s``.
        - Run the offline and online phases with traffic recording.
        - Receive the peer's output shares and ``z``.
        - Check the comparison result and the mod-div correction with
          ``check_correctness`` / ``check_correctness_mod_div``. These are
          defined elsewhere and presumably recombine the two parties' shares.
        """
        rank = Config.server_rank
        init_communicate(Config.server_rank,
                         master_address=master_address,
                         master_port=master_port)
        warming_up_cuda()
        traffic_record = TrafficRecord()

        fhe_builder_16 = FheBuilder(q_16, Config.n_16)
        fhe_builder_23 = FheBuilder(q_23, Config.n_23)
        comm_fhe_16 = CommFheBuilder(rank, fhe_builder_16, "fhe_builder_16")
        comm_fhe_23 = CommFheBuilder(rank, fhe_builder_23, "fhe_builder_23")
        torch_sync()
        comm_fhe_16.recv_public_key()
        comm_fhe_23.recv_public_key()
        comm_fhe_16.wait_and_build_public_key()
        comm_fhe_23.wait_and_build_public_key()

        dgk = DgkBitServer(num_elem, q_23, q_16, work_bit, data_bit,
                           fhe_builder_16, fhe_builder_23, "DgkBitTest")

        x_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "x")
        y_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "y")
        y_sub_x_s_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                   "y_sub_x_s")
        x_blob.prepare_recv()
        y_blob.prepare_recv()
        y_sub_x_s_blob.prepare_recv()
        torch_sync()
        x = x_blob.get_recv()
        y = y_blob.get_recv()
        y_sub_x_s = y_sub_x_s_blob.get_recv()

        torch_sync()
        with NamedTimerInstance("Server Offline"):
            dgk.offline()
        # NOTE(review): stale commented-out code; y_sub_x_s is now received
        # from the peer above. Consider deleting this line.
        # y_sub_x_s = pmod(y_s.to(Config.device) - x_s.to(Config.device), q_23)
        torch_sync()
        traffic_record.reset("server-offline")

        with NamedTimerInstance("Server Online"):
            dgk.online(y_sub_x_s)
        traffic_record.reset("server-online")

        dgk_x_leq_y_c_blob = BlobTorch(num_elem, torch.float, dgk.comm_base,
                                       "dgk_x_leq_y_c")
        correct_mod_div_work_c_blob = BlobTorch(num_elem, torch.float,
                                                dgk.comm_base,
                                                "correct_mod_div_work_c")
        z_blob = BlobTorch(num_elem, torch.float, dgk.comm_base, "z")
        dgk_x_leq_y_c_blob.prepare_recv()
        correct_mod_div_work_c_blob.prepare_recv()
        z_blob.prepare_recv()
        torch_sync()
        dgk_x_leq_y_c = dgk_x_leq_y_c_blob.get_recv()
        correct_mod_div_work_c = correct_mod_div_work_c_blob.get_recv()
        z = z_blob.get_recv()
        check_correctness(x, y, dgk.dgk_x_leq_y_s, dgk_x_leq_y_c)
        check_correctness_mod_div(dgk.r, z, dgk.correct_mod_div_work_s,
                                  correct_mod_div_work_c)
        end_communicate()