Ejemplo n.º 1
0
    def sync_final_splitinfo_host(self, splitinfo_host, federated_best_splitinfo_host, dep=-1, batch=-1):
        """Encode the federation-selected best host splits and send them to the guest.

        For the node at position ``i``, ``federated_best_splitinfo_host[i]`` is a
        ``(best_idx, best_gain)`` pair selecting a candidate in
        ``splitinfo_host[i]``.

        - If a candidate is selected, its ``sitename`` must equal
          ``self.sitename``. Its ``best_fid``, ``best_bid`` and ``missing_dir``
          are replaced in place by the values returned from ``self.encode``,
          and its ``gain`` is overwritten with ``best_gain``.
        - When ``best_idx == -1``, a placeholder ``SplitInfo`` with
          ``best_fid=-1`` and ``best_bid=-1`` is sent for that node instead.

        Args:
            splitinfo_host: per-node sequences of candidate split infos.
            federated_best_splitinfo_host: per-node ``(best_idx, best_gain)`` pairs.
            dep: depth value used in the log message and in the transfer suffix.
            batch: batch value used in the log message and in the transfer suffix.
        """
        LOGGER.info("send host final splitinfo of depth {}, batch {}".format(dep, batch))

        outgoing = []
        for node_pos, candidates in enumerate(splitinfo_host):
            best_idx, best_gain = federated_best_splitinfo_host[node_pos]

            if best_idx == -1:
                # No split was selected for this node: send a placeholder.
                outgoing.append(SplitInfo(sitename=self.sitename, best_fid=-1, best_bid=-1, gain=best_gain))
                continue

            chosen = candidates[best_idx]
            LOGGER.debug('sitename is {}, self.sitename is {}'
                         .format(chosen.sitename, self.sitename))
            # The selected candidate must have been produced by this site.
            assert chosen.sitename == self.sitename

            chosen.best_fid = self.encode("feature_idx", chosen.best_fid)
            assert chosen.best_fid is not None
            node_id = self.cur_to_split_nodes[node_pos].id
            chosen.best_bid = self.encode("feature_val", chosen.best_bid, node_id)
            chosen.missing_dir = self.encode("missing_dir", chosen.missing_dir, node_id)
            chosen.gain = best_gain
            outgoing.append(chosen)

        self.transfer_inst.final_splitinfo_host.remote(outgoing,
                                                       role=consts.GUEST,
                                                       idx=-1,
                                                       suffix=(dep, batch))
Ejemplo n.º 2
0
    def get_host_split_info(self, splitinfo_host,
                            federated_best_splitinfo_host):
        """Return the federation-selected best split info for every node.

        For the node at position ``i``, ``federated_best_splitinfo_host[i]`` is a
        ``(best_idx, best_gain)`` pair.

        - When ``best_idx != -1``, the candidate ``splitinfo_host[i][best_idx]``
          is returned. Its ``sitename`` must equal ``self.sitename`` and its
          ``best_fid`` must not be ``None``. It is mutated in place: only
          ``gain`` is overwritten with ``best_gain``, and ``best_fid`` /
          ``best_bid`` / ``missing_dir`` are left untouched.
        - When ``best_idx == -1``, a placeholder ``SplitInfo`` with
          ``best_fid=-1`` and ``best_bid=-1`` is returned instead.

        Args:
            splitinfo_host: per-node sequences of candidate split infos.
            federated_best_splitinfo_host: per-node ``(best_idx, best_gain)`` pairs.

        Returns:
            list with one split info per entry of ``splitinfo_host``.
        """
        final_splitinfos = []
        for i, candidates in enumerate(splitinfo_host):
            best_idx, best_gain = federated_best_splitinfo_host[i]
            if best_idx != -1:
                splitinfo = candidates[best_idx]
                LOGGER.debug('sitename is {}, self.sitename is {}'.format(
                    splitinfo.sitename, self.sitename))
                # The selected candidate must have been produced by this site.
                assert splitinfo.sitename == self.sitename
                assert splitinfo.best_fid is not None
                splitinfo.gain = best_gain
            else:
                splitinfo = SplitInfo(sitename=self.sitename,
                                      best_fid=-1,
                                      best_bid=-1,
                                      gain=best_gain)

            final_splitinfos.append(splitinfo)

        return final_splitinfos
Ejemplo n.º 3
0
    def collect_host_split_feat_importance(self):
        """Record feature importance for every non-leaf node split by this site.

        Nodes whose ``is_leaf`` is truthy, or whose ``sitename`` does not equal
        ``self.sitename``, are skipped. For each remaining node, the feature id
        is looked up in ``self.split_feature_dict`` by the node's ``id``. It is
        then wrapped in a ``SplitInfo`` and passed to
        ``self.update_feature_importance`` with ``False`` as the second argument.
        """
        for current in self.tree_node:
            split_here = not current.is_leaf and current.sitename == self.sitename
            if not split_here:
                continue

            LOGGER.debug('sitename are {} {}'.format(
                current.sitename, self.sitename))
            feature_id = self.split_feature_dict[current.id]
            self.update_feature_importance(
                SplitInfo(sitename=self.sitename, best_fid=feature_id), False)
Ejemplo n.º 4
0
    def add(self, split_info):
        """Add ``split_info`` to this package.

        ``split_info.sum_grad`` is forwarded to the parent class's ``add``.
        A copy built only from the fields listed below (``sum_grad`` and
        ``sum_hess`` are not passed) is appended to
        ``self._split_info_without_gh``, and ``self._cur_splitinfo_contains``
        is incremented.
        """
        # Fields carried over into the copy; order matches the keyword order
        # used when constructing the SplitInfo.
        kept_fields = ('sitename', 'best_bid', 'best_fid',
                       'missing_dir', 'mask_id', 'sample_count')
        stripped = SplitInfo(**{field: getattr(split_info, field) for field in kept_fields})

        super(SplitInfoPackage2, self).add(split_info.sum_grad)
        self._cur_splitinfo_contains += 1
        self._split_info_without_gh.append(stripped)
Ejemplo n.º 5
0
    def test_regression_cipher_compress(self):
        """Check that cipher-compressed regression split infos unpack correctly.

        Steps:

        1. Build ``self.split_info_test_num`` random ``SplitInfo`` objects whose
           ``sum_grad`` is the packed ciphertext returned by ``make_random_sum``.
        2. Compress all but the last with ``self.reg_compressor``, passing the
           last one as the second argument.
        3. Decompress and unpack the packages with ``packer``.
        4. For each unpacked result, compare ``sum_grad`` / ``sum_hess`` (after
           ``truncate``) with values obtained by raw-decrypting and unpacking
           the same ciphertext directly.
        """
        print('testing regression')
        collected_gh = self.reg_p_collected_gh
        en_g_l = self.reg_p_en_g_l
        en_h_l = self.reg_p_en_h_l
        packer = self.reg_p_packer
        en = self.p_en

        sp_list = []
        g_sum_list, h_sum_list = [], []
        pack_en_list = []

        for _ in range(self.split_info_test_num):
            # The 4th and 5th return values (en_g_sum / en_h_sum) are not
            # needed by this test; only the packed ciphertext en_sum is used.
            g_sum, h_sum, en_sum, _en_g_sum, _en_h_sum, sample_num = make_random_sum(
                collected_gh, self.g_reg, self.h_reg, en_g_l, en_h_l,
                self.max_sample_num)
            sp = SplitInfo(sum_grad=en_sum,
                           sum_hess=0,
                           sample_count=sample_num)
            sp_list.append(sp)
            g_sum_list.append(g_sum)
            h_sum_list.append(h_sum)
            pack_en_list.append(en_sum)

        print('generating split-info done')
        packages = self.reg_compressor.compress_split_info(
            sp_list[:-1], sp_list[-1])
        print('package length is {}'.format(len(packages)))
        unpack_rs = packer.decompress_and_unpack(packages)

        for case_id, (s, g, h, en_gh) in enumerate(
                zip(unpack_rs, g_sum_list, h_sum_list, pack_en_list)):
            print('*' * 10)
            print(case_id)
            # Reference value: decrypt the packed ciphertext directly, bypassing
            # the compress/decompress path under test.
            de_num = en.raw_decrypt(en_gh)
            unpack_num = packer.packer.unpack_an_int(
                de_num, packer.packer.bit_assignment[0])
            # Undo the fixed-point scaling, and for g also subtract the
            # sample_count * g_offset term.
            g_sum_ = unpack_num[0] / fix_point_precision - s.sample_count * packer.g_offset
            h_sum_ = unpack_num[1] / fix_point_precision

            print(s.sample_count)
            print(s.sum_grad, g_sum_, g)
            print(s.sum_hess, h_sum_, h)

            # assertEqual reports both values on failure (assertTrue(a == b)
            # only reports "False is not true").
            msg = 'case {}'.format(case_id)
            self.assertEqual(truncate(s.sum_grad), truncate(g_sum_), msg)
            self.assertEqual(truncate(s.sum_hess), truncate(h_sum_), msg)
        print('check passed')
Ejemplo n.º 6
0
    def random_split_info_generate(num=5, max_num=90000):
        """Return a list of ``num`` SplitInfo objects filled with random values.

        Field values:

        - ``sum_grad`` / ``sum_hess``: ``np.random.randint(max_num)`` plus
          ``np.random.random()``.
        - ``best_fid`` / ``best_bid``: ``np.random.randint(10)``.
        - ``missing_dir``: ``np.random.randint(10000)``.
        - ``sample_count``: always 0.

        All draws use the global ``np.random`` state, in a fixed order per item.
        """
        def _int_plus_fraction():
            return np.random.randint(max_num) + np.random.random()

        def _one_split_info():
            grad = _int_plus_fraction()
            hess = _int_plus_fraction()
            fid = np.random.randint(10)
            bid = np.random.randint(10)
            direction = np.random.randint(10000)
            return SplitInfo(sum_grad=grad,
                             sum_hess=hess,
                             best_fid=fid,
                             best_bid=bid,
                             missing_dir=direction,
                             sample_count=0)

        return [_one_split_info() for _ in range(num)]