示例#1
0
    def find_split_single_histogram_host(self,
                                         histogram,
                                         valid_features,
                                         sitename,
                                         use_missing=False,
                                         zero_as_missing=False):
        """Enumerate candidate splits of one node on the host side.

        No gain is computed here: every threshold whose left/right sample
        counts both reach ``self.min_leaf_node`` is returned together with its
        left-child grad/hess sums (presumably encrypted and scored by the
        guest — confirm against caller).

        Args:
            histogram: per-feature list of bins, each indexed as ``[0]`` grad,
                ``[1]`` hess, ``[2]`` sample count. Bins are treated as
                cumulative: the last bin holds the node totals and the right
                child's count is total minus left.
            valid_features: per-feature flags; features whose flag ``is False``
                are skipped.
            sitename: site name stored in every produced SplitInfo.
            use_missing: when True the last bin is the missing-value bin and,
                for every threshold, a second candidate that sends missing
                values to the left child (``missing_dir=-1``) is produced in
                addition to the default right (``missing_dir=1``).
            zero_as_missing: accepted for interface compatibility; unused here.

        Returns:
            ``(node_splitinfo, node_grad_hess)``: parallel lists holding one
            SplitInfo and one ``(sum_grad_l, sum_hess_l)`` tuple per candidate.
        """
        node_splitinfo = []
        node_grad_hess = []

        missing_bin = 1 if use_missing else 0

        for fid in range(len(histogram)):
            if valid_features[fid] is False:
                continue
            bin_num = len(histogram[fid])
            if bin_num == 0:
                continue
            node_cnt = histogram[fid][bin_num - 1][2]

            # node_cnt is read from the cumulative last bin, i.e. the node's
            # total count, so once it is too small no feature can split.
            if node_cnt < self.min_sample_split:
                break

            if use_missing and bin_num >= 2:
                # Contribution of the missing-value bin (difference of the last
                # two cumulative bins). It does not depend on bid, so compute it
                # once per feature instead of once per threshold.
                miss_grad = histogram[fid][-1][0] - histogram[fid][-2][0]
                miss_hess = histogram[fid][-1][1] - histogram[fid][-2][1]
                miss_cnt = histogram[fid][-1][2] - histogram[fid][-2][2]

            # The last (total) bin and, with use_missing, the missing bin are
            # never used as thresholds.
            for bid in range(bin_num - missing_bin - 1):
                sum_grad_l = histogram[fid][bid][0]
                sum_hess_l = histogram[fid][bid][1]
                node_cnt_l = histogram[fid][bid][2]

                node_cnt_r = node_cnt - node_cnt_l

                # Candidate 1: missing values (if any) go to the right child.
                if node_cnt_l >= self.min_leaf_node and node_cnt_r >= self.min_leaf_node:
                    splitinfo = SplitInfo(sitename=sitename,
                                          best_fid=fid,
                                          best_bid=bid,
                                          sum_grad=sum_grad_l,
                                          sum_hess=sum_hess_l,
                                          missing_dir=1)

                    node_splitinfo.append(splitinfo)
                    node_grad_hess.append((sum_grad_l, sum_hess_l))

                # Candidate 2: missing values go to the left child.
                if use_missing:
                    sum_grad_l += miss_grad
                    sum_hess_l += miss_hess
                    node_cnt_l += miss_cnt
                    node_cnt_r = node_cnt - node_cnt_l

                    # The host is the only party holding sample counts, so the
                    # min_leaf_node constraint must be enforced here as well
                    # (the guest applies the same check to its own missing-left
                    # candidates).
                    if node_cnt_l >= self.min_leaf_node and node_cnt_r >= self.min_leaf_node:
                        splitinfo = SplitInfo(sitename=sitename,
                                              best_fid=fid,
                                              best_bid=bid,
                                              sum_grad=sum_grad_l,
                                              sum_hess=sum_hess_l,
                                              missing_dir=-1)

                        node_splitinfo.append(splitinfo)
                        node_grad_hess.append((sum_grad_l, sum_hess_l))

        return node_splitinfo, node_grad_hess
    def sync_final_splitinfo_host(self,
                                  splitinfo_host,
                                  federated_best_splitinfo_host,
                                  dep=-1,
                                  batch=-1):
        """Send the host's final split decision for each node to the guest.

        ``federated_best_splitinfo_host[i]`` is a ``(best_idx, best_gain)`` pair
        for node ``i``. A ``best_idx`` of -1 produces a placeholder SplitInfo
        (fid/bid -1). Otherwise the host candidate ``splitinfo_host[i][best_idx]``
        is chosen, and its fid/bid are replaced by ``self.encode`` results. The
        chosen SplitInfo objects are modified in place.

        The resulting list is sent with ``federation.remote`` to guest index 0,
        tagged by depth ``dep`` and ``batch``.
        """
        LOGGER.info("send host final splitinfo of depth {}, batch {}".format(
            dep, batch))
        final_splitinfos = []
        for node_pos, candidates in enumerate(splitinfo_host):
            best_idx, best_gain = federated_best_splitinfo_host[node_pos]

            if best_idx == -1:
                # no host split selected for this node: send a placeholder
                final_splitinfos.append(SplitInfo(sitename=consts.HOST,
                                                  best_fid=-1,
                                                  best_bid=-1,
                                                  gain=best_gain))
                continue

            chosen = candidates[best_idx]
            assert chosen.sitename == consts.HOST
            chosen.best_fid = self.encode("feature_idx", chosen.best_fid)
            assert chosen.best_fid is not None
            node_id = self.cur_split_nodes[node_pos].id
            chosen.best_bid = self.encode("feature_val", chosen.best_bid, node_id)
            chosen.gain = best_gain
            final_splitinfos.append(chosen)

        transfer_tag = self.transfer_inst.generate_transferid(
            self.transfer_inst.final_splitinfo_host, dep, batch)
        federation.remote(obj=final_splitinfos,
                          name=self.transfer_inst.final_splitinfo_host.name,
                          tag=transfer_tag,
                          role=consts.GUEST,
                          idx=0)
示例#3
0
文件: splitter.py 项目: 03040081/FATE
    def find_split_single_histogram_host(self, histogram, valid_features):
        """Enumerate candidate splits of one node on the host side.

        Args:
            histogram: per-feature list of bins, each indexed as ``[0]`` grad,
                ``[1]`` hess, ``[2]`` sample count. Bins are treated as
                cumulative: the last bin holds the node totals and the right
                child's count is total minus left.
            valid_features: per-feature flags; features whose flag ``is False``
                are skipped.

        Returns:
            ``(node_splitinfo, node_grad_hess)``: parallel lists holding one
            SplitInfo (sitename ``consts.HOST``) and one
            ``(sum_grad_l, sum_hess_l)`` tuple per candidate whose left and
            right counts both reach ``self.min_leaf_node``.
        """
        node_splitinfo = []
        node_grad_hess = []
        for fid in range(len(histogram)):
            if valid_features[fid] is False:
                continue
            bin_num = len(histogram[fid])
            if bin_num == 0:
                continue
            node_cnt = histogram[fid][bin_num - 1][2]

            # node_cnt comes from the cumulative last bin (node total), so once
            # it is too small no feature of this node can split.
            if node_cnt < self.min_sample_split:
                break

            # The last bin holds all samples; splitting there would leave the
            # right child empty, so it is never a valid threshold.
            for bid in range(bin_num - 1):
                sum_grad_l = histogram[fid][bid][0]
                sum_hess_l = histogram[fid][bid][1]
                node_cnt_l = histogram[fid][bid][2]

                node_cnt_r = node_cnt - node_cnt_l

                if node_cnt_l >= self.min_leaf_node and node_cnt_r >= self.min_leaf_node:
                    splitinfo = SplitInfo(sitename=consts.HOST,
                                          best_fid=fid,
                                          best_bid=bid,
                                          sum_grad=sum_grad_l,
                                          sum_hess=sum_hess_l)

                    node_splitinfo.append(splitinfo)
                    node_grad_hess.append((sum_grad_l, sum_hess_l))

        return node_splitinfo, node_grad_hess
示例#4
0
 def test_splitinfo(self):
     """SplitInfo should expose every constructor argument as an attribute."""
     param_dict = {"sitename": "testsplitinfo",
                   "best_fid": 23, "best_bid": 233,
                   "sum_grad": 2333, "sum_hess": 23333, "gain": 233333}
     # Build from the same dict the assertions use, so the expected values
     # and the constructor arguments cannot drift apart.
     splitinfo = SplitInfo(**param_dict)
     for key, expected in param_dict.items():
         # subTest reports each mismatching attribute separately.
         with self.subTest(key=key):
             self.assertEqual(expected, getattr(splitinfo, key))
示例#5
0
文件: splitter.py 项目: 03040081/FATE
    def find_split_single_histogram_guest(self, histogram, valid_features):
        """Find the best split of one node from the guest's histogram.

        Args:
            histogram: per-feature list of bins, each indexed as ``[0]`` grad,
                ``[1]`` hess, ``[2]`` sample count. Bins are treated as
                cumulative: the last bin holds the node totals and right-child
                stats are totals minus left.
            valid_features: per-feature flags; features whose flag ``is False``
                are skipped.

        Returns:
            SplitInfo (sitename ``consts.GUEST``) for the highest-gain split.
            The split must satisfy ``self.min_leaf_node`` on both sides, and
            its gain must exceed ``self.min_impurity_split``. If no split
            qualifies, ``best_fid``/``best_bid`` and the sums are None and
            ``gain`` is the initial threshold value.
        """
        best_fid = None
        # start just below the threshold so any qualifying gain replaces it
        best_gain = self.min_impurity_split - consts.FLOAT_ZERO
        best_bid = None
        best_sum_grad_l = None
        best_sum_hess_l = None
        for fid in range(len(histogram)):
            if valid_features[fid] is False:
                continue
            bin_num = len(histogram[fid])
            if bin_num == 0:
                continue
            sum_grad = histogram[fid][bin_num - 1][0]
            sum_hess = histogram[fid][bin_num - 1][1]
            node_cnt = histogram[fid][bin_num - 1][2]

            # node_cnt is the node total, identical for every feature, so no
            # later feature could pass this check either.
            if node_cnt < self.min_sample_split:
                break

            # The last bin holds all samples; a split there leaves an empty
            # right child (and a zero right hessian for split_gain).
            for bid in range(bin_num - 1):
                sum_grad_l = histogram[fid][bid][0]
                sum_hess_l = histogram[fid][bid][1]
                node_cnt_l = histogram[fid][bid][2]

                sum_grad_r = sum_grad - sum_grad_l
                sum_hess_r = sum_hess - sum_hess_l
                node_cnt_r = node_cnt - node_cnt_l

                if node_cnt_l >= self.min_leaf_node and node_cnt_r >= self.min_leaf_node:
                    gain = self.criterion.split_gain([sum_grad, sum_hess],
                                                     [sum_grad_l, sum_hess_l],
                                                     [sum_grad_r, sum_hess_r])

                    if gain > self.min_impurity_split and gain > best_gain:
                        best_gain = gain
                        best_fid = fid
                        best_bid = bid
                        best_sum_grad_l = sum_grad_l
                        best_sum_hess_l = sum_hess_l

        splitinfo = SplitInfo(sitename=consts.GUEST,
                              best_fid=best_fid,
                              best_bid=best_bid,
                              gain=best_gain,
                              sum_grad=best_sum_grad_l,
                              sum_hess=best_sum_hess_l)

        return splitinfo
示例#6
0
    def sync_final_splitinfo_host(self,
                                  splitinfo_host,
                                  federated_best_splitinfo_host,
                                  dep=-1,
                                  batch=-1):
        """Send this host's final split decision for each node to the guest.

        ``federated_best_splitinfo_host[i]`` is a ``(best_idx, best_gain)`` pair
        for node ``i``. A ``best_idx`` of -1 yields a placeholder SplitInfo
        (fid/bid -1). Otherwise the candidate ``splitinfo_host[i][best_idx]`` is
        chosen, and its fid, bid and missing_dir are replaced by ``self.encode``
        results. The chosen SplitInfo objects are modified in place.

        The list is sent through ``transfer_inst.final_splitinfo_host`` to the
        guest role with ``idx=-1`` and suffix ``(dep, batch)``.
        """
        LOGGER.info("send host final splitinfo of depth {}, batch {}".format(
            dep, batch))
        final_splitinfos = []
        for node_pos, candidates in enumerate(splitinfo_host):
            best_idx, best_gain = federated_best_splitinfo_host[node_pos]

            if best_idx == -1:
                # no split from this host was selected for the node
                final_splitinfos.append(SplitInfo(sitename=self.sitename,
                                                  best_fid=-1,
                                                  best_bid=-1,
                                                  gain=best_gain))
                continue

            chosen = candidates[best_idx]
            assert chosen.sitename == self.sitename
            chosen.best_fid = self.encode("feature_idx", chosen.best_fid)
            assert chosen.best_fid is not None
            node_id = self.cur_split_nodes[node_pos].id
            chosen.best_bid = self.encode("feature_val", chosen.best_bid, node_id)
            chosen.missing_dir = self.encode("missing_dir", chosen.missing_dir, node_id)
            chosen.gain = best_gain
            final_splitinfos.append(chosen)

        self.transfer_inst.final_splitinfo_host.remote(final_splitinfos,
                                                       role=consts.GUEST,
                                                       idx=-1,
                                                       suffix=(dep, batch))
        """
示例#7
0
    def find_split_host(self, histograms, valid_features):
        """Enumerate host-side candidate splits for several nodes.

        Args:
            histograms: one histogram per node. Each is a per-feature list of
                bins indexed as ``[0]`` grad, ``[1]`` hess, ``[2]`` sample count.
                Bins are treated as cumulative: the last bin holds the node
                totals.
            valid_features: per-feature flags; features whose flag ``is False``
                are skipped.

        Returns:
            ``(tree_node_splitinfo, encrypted_node_grad_hess)``: one list per
            node. Each list holds a SplitInfo (sitename ``consts.HOST``) and its
            ``(sum_grad_l, sum_hess_l)`` for every threshold whose left and
            right counts both reach ``self.min_leaf_node``.
        """
        LOGGER.info("splitter find split of host")
        tree_node_splitinfo = []
        encrypted_node_grad_hess = []
        for histogram in histograms:
            node_splitinfo = []
            node_grad_hess = []
            for fid in range(len(histogram)):
                if valid_features[fid] is False:
                    continue
                bin_num = len(histogram[fid])
                if bin_num == 0:
                    continue
                node_cnt = histogram[fid][bin_num - 1][2]

                # node_cnt is this node's total, identical for every feature,
                # so once it is too small no feature can split the node.
                if node_cnt < self.min_sample_split:
                    break

                # The last bin holds all samples; splitting there would leave
                # the right child empty, so it is never a valid threshold.
                for bid in range(bin_num - 1):
                    sum_grad_l = histogram[fid][bid][0]
                    sum_hess_l = histogram[fid][bid][1]
                    node_cnt_l = histogram[fid][bid][2]

                    node_cnt_r = node_cnt - node_cnt_l

                    if node_cnt_l >= self.min_leaf_node and node_cnt_r >= self.min_leaf_node:
                        splitinfo = SplitInfo(sitename=consts.HOST, best_fid=fid,
                                              best_bid=bid, sum_grad=sum_grad_l,
                                              sum_hess=sum_hess_l)

                        node_splitinfo.append(splitinfo)
                        node_grad_hess.append((sum_grad_l, sum_hess_l))

            tree_node_splitinfo.append(node_splitinfo)
            encrypted_node_grad_hess.append(node_grad_hess)

        return tree_node_splitinfo, encrypted_node_grad_hess
示例#8
0
    def find_split_single_histogram_guest(self, histogram, valid_features,
                                          sitename, use_missing,
                                          zero_as_missing):
        """Find the best split of one node from the guest's histogram.

        Args:
            histogram: per-feature list of bins, each indexed as ``[0]`` grad,
                ``[1]`` hess, ``[2]`` sample count. Bins are treated as
                cumulative: the last bin holds the node totals and right-child
                stats are totals minus left.
            valid_features: per-feature flags; features whose flag ``is False``
                are skipped.
            sitename: site name stored in the returned SplitInfo.
            use_missing: when True the last bin is the missing-value bin. For
                every threshold, sending missing values to the left child
                (``missing_dir=-1``) is evaluated in addition to the default
                right child (``missing_dir=1``).
            zero_as_missing: accepted for interface compatibility; unused here.

        Returns:
            SplitInfo for the highest-gain split. The split must satisfy
            ``self.min_leaf_node`` on both sides, and its gain must exceed
            ``self.min_impurity_split``. If none qualifies,
            ``best_fid``/``best_bid`` and the sums are None, ``gain`` is the
            initial threshold value and ``missing_dir`` is 1.
        """
        best_fid = None
        # start just below the threshold so any qualifying gain replaces it
        best_gain = self.min_impurity_split - consts.FLOAT_ZERO
        best_bid = None
        best_sum_grad_l = None
        best_sum_hess_l = None
        missing_bin = 1 if use_missing else 0

        # by default, missing values go to the right child
        missing_dir = 1

        for fid in range(len(histogram)):
            if valid_features[fid] is False:
                continue
            bin_num = len(histogram[fid])
            # Need at least one bin beyond the missing bin. An empty histogram
            # must be skipped here even with use_missing, otherwise the
            # indexing below raises IndexError.
            if bin_num <= missing_bin:
                continue
            sum_grad = histogram[fid][bin_num - 1][0]
            sum_hess = histogram[fid][bin_num - 1][1]
            node_cnt = histogram[fid][bin_num - 1][2]

            # node_cnt is the node total, identical for every feature, so no
            # later feature could pass this check either.
            if node_cnt < self.min_sample_split:
                break

            if use_missing:
                # Stats of the missing-value bin (difference of the last two
                # cumulative bins); independent of the threshold, so computed
                # once per feature. bin_num >= 2 is guaranteed above.
                miss_grad = histogram[fid][-1][0] - histogram[fid][-2][0]
                miss_hess = histogram[fid][-1][1] - histogram[fid][-2][1]
                miss_cnt = histogram[fid][-1][2] - histogram[fid][-2][2]

            # The total bin and, with use_missing, the missing bin are never
            # used as thresholds.
            for bid in range(bin_num - missing_bin - 1):
                sum_grad_l = histogram[fid][bid][0]
                sum_hess_l = histogram[fid][bid][1]
                node_cnt_l = histogram[fid][bid][2]

                sum_grad_r = sum_grad - sum_grad_l
                sum_hess_r = sum_hess - sum_hess_l
                node_cnt_r = node_cnt - node_cnt_l

                # Candidate 1: missing values go to the right child.
                if node_cnt_l >= self.min_leaf_node and node_cnt_r >= self.min_leaf_node:
                    gain = self.criterion.split_gain([sum_grad, sum_hess],
                                                     [sum_grad_l, sum_hess_l],
                                                     [sum_grad_r, sum_hess_r])

                    if gain > self.min_impurity_split and gain > best_gain:
                        best_gain = gain
                        best_fid = fid
                        best_bid = bid
                        best_sum_grad_l = sum_grad_l
                        best_sum_hess_l = sum_hess_l
                        missing_dir = 1

                # Candidate 2 (missing value handling): dispatch missing values
                # to the left child by moving the missing bin's stats from right
                # to left.
                if use_missing:
                    sum_grad_l += miss_grad
                    sum_hess_l += miss_hess
                    node_cnt_l += miss_cnt

                    sum_grad_r -= miss_grad
                    sum_hess_r -= miss_hess
                    node_cnt_r -= miss_cnt

                    if node_cnt_l >= self.min_leaf_node and node_cnt_r >= self.min_leaf_node:
                        gain = self.criterion.split_gain(
                            [sum_grad, sum_hess], [sum_grad_l, sum_hess_l],
                            [sum_grad_r, sum_hess_r])

                        if gain > self.min_impurity_split and gain > best_gain:
                            best_gain = gain
                            best_fid = fid
                            best_bid = bid
                            best_sum_grad_l = sum_grad_l
                            best_sum_hess_l = sum_hess_l
                            missing_dir = -1

        splitinfo = SplitInfo(sitename=sitename,
                              best_fid=best_fid,
                              best_bid=best_bid,
                              gain=best_gain,
                              sum_grad=best_sum_grad_l,
                              sum_hess=best_sum_hess_l,
                              missing_dir=missing_dir)

        return splitinfo