def update_tree(self, cur_to_split: List[Node],
                    split_info: List[SplitInfo]):
        """
        Apply one layer of split decisions to the tree.

        Each node in ``cur_to_split`` is paired positionally with the
        ``SplitInfo`` at the same index of ``split_info``. A node whose split
        has no feature (``best_fid is None``) or whose gain does not exceed
        ``self.min_impurity_split + consts.FLOAT_ZERO`` is marked as a leaf.
        Every other node records the chosen feature, bin and missing-value
        direction, and gets two freshly numbered children. It also contributes
        to feature importance. Every node of ``cur_to_split`` (leaf or split)
        is appended to ``self.tree_node``.

        Parameters
        ----------
        cur_to_split : List[Node]
            Nodes of the current layer, ordered to match ``split_info``.
        split_info : List[SplitInfo]
            Best split for each node. Its ``sum_grad`` / ``sum_hess`` are the
            gradient / hessian sums of the left child.

        Returns
        -------
        List[Node]
            The newly created children (left, right, left, right, ...) forming
            the next layer to split.

        Raises
        ------
        ValueError
            If ``cur_to_split`` and ``split_info`` differ in length.
        """
        LOGGER.debug('updating tree_node, cur layer has {} node'.format(
            len(cur_to_split)))
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and ``zip`` below would then silently drop nodes.
        if len(cur_to_split) != len(split_info):
            raise ValueError(
                'cur_to_split and split_info must have the same length, '
                'got {} and {}'.format(len(cur_to_split), len(split_info)))

        next_layer_node = []
        for node, info in zip(cur_to_split, split_info):
            if (info.best_fid is None
                    or info.gain <= self.min_impurity_split + consts.FLOAT_ZERO):
                # no usable split for this node: it becomes a leaf
                node.is_leaf = True
                self.tree_node.append(node)
                continue

            sum_grad, sum_hess = node.sum_grad, node.sum_hess

            node.fid = info.best_fid
            node.bid = info.best_bid
            node.missing_dir = info.missing_dir

            p_id = node.id
            l_id, r_id = self.tree_node_num + 1, self.tree_node_num + 2
            node.left_nodeid, node.right_nodeid = l_id, r_id
            self.tree_node_num += 2

            # split_info carries the left child's sums; the right child's sums
            # are the parent's totals minus the left child's.
            l_g, l_h = info.sum_grad, info.sum_hess
            r_g, r_h = sum_grad - l_g, sum_hess - l_h

            # create new left node and new right node
            left_node = Node(id=l_id,
                             sitename=self.sitename,
                             sum_grad=l_g,
                             sum_hess=l_h,
                             weight=self.splitter.node_weight(l_g, l_h),
                             parent_nodeid=p_id,
                             sibling_nodeid=r_id,
                             is_left_node=True)
            right_node = Node(id=r_id,
                              sitename=self.sitename,
                              sum_grad=r_g,
                              sum_hess=r_h,
                              weight=self.splitter.node_weight(r_g, r_h),
                              parent_nodeid=p_id,
                              sibling_nodeid=l_id,
                              is_left_node=False)

            next_layer_node.append(left_node)
            next_layer_node.append(right_node)
            self.tree_node.append(node)

            self.update_feature_importance(info, record_site_name=False)

        return next_layer_node
Example #2
0
 def test_node(self):
     """Node must expose each constructor keyword argument as an attribute."""
     param_dict = {"id": 5, "sitename": "test", "fid": 55, "bid": 555,
                   "weight": -1, "is_leaf": True, "sum_grad": 2, "sum_hess": 3,
                   "left_nodeid": 6, "right_nodeid": 7}
     # Build from the same dict so the expected values and the constructor
     # arguments cannot drift apart.
     node = Node(**param_dict)
     for key, expected in param_dict.items():
         # assertEqual reports both values (and the key) on failure, unlike
         # assertTrue(a == b), which only says "False is not true".
         self.assertEqual(getattr(node, key), expected, msg=key)
Example #3
0
 def init_root_node_and_gh_sum(self):
     """Create the root node from the aggregated gradient/hessian sums.

     The local sums of ``self.g_h`` are sent to ``self.aggregator``. The
     aggregated ``g_sum`` / ``h_sum`` are then received back. A single root
     node (id 0, sitename ``consts.GUEST``) built from them becomes
     ``self.cur_layer_node``.
     """
     local_g, local_h = self.get_grad_hess_sum(self.g_h)

     # round trip with the aggregator to obtain the global sums
     self.aggregator.send_local_root_node_info(
         local_g, local_h, suffix=('root_node_sync1', self.epoch_idx))
     aggregated = self.aggregator.get_aggregated_root_info(
         suffix=('root_node_sync2', self.epoch_idx))
     total_g = aggregated['g_sum']
     total_h = aggregated['h_sum']

     root_weight = self.splitter.node_weight(total_g, total_h)
     root = Node(id=0,
                 sitename=consts.GUEST,
                 sum_grad=total_g,
                 sum_hess=total_h,
                 weight=root_weight)
     self.cur_layer_node = [root]
Example #4
0
    def set_model_param(self, model_param):
        """Rebuild ``self.tree_node`` from a serialized model parameter.

        One ``Node`` is created per entry of ``model_param.tree_``. Each copies
        the entry's id, sitename, fid, bid, weight, is_leaf, left/right child
        ids and missing_dir. Previously held nodes are discarded.
        """
        self.tree_node = []
        self.tree_node.extend(
            Node(id=param.id,
                 sitename=param.sitename,
                 fid=param.fid,
                 bid=param.bid,
                 weight=param.weight,
                 is_leaf=param.is_leaf,
                 left_nodeid=param.left_nodeid,
                 right_nodeid=param.right_nodeid,
                 missing_dir=param.missing_dir)
            for param in model_param.tree_)
Example #5
0
    def fit(self):
        """
        Fit one decision tree layer by layer.

        1. Send the local grad/hess sums to ``self.aggregator`` and build the
           root node from the aggregated sums.
        2. For each layer below ``self.max_depth``:
           - compute local histograms for at most ``self.max_split_nodes``
             nodes at a time and sync them;
           - receive the best splits back and grow the tree with
             ``update_tree``;
           - reassign instances to the new nodes.
        3. On the final layer, mark every remaining node as a leaf.

        Leaf values produced along the way are merged into
        ``self.sample_weights``. Finally, ``self.convert_bin_to_real()`` is
        called.
        """
        LOGGER.info(
            'begin to fit h**o decision tree, epoch {}, tree idx {}'.format(
                self.epoch_idx, self.tree_idx))

        # compute local g_sum and h_sum
        g_sum, h_sum = self.get_grad_hess_sum(self.g_h)

        # get aggregated root info
        self.aggregator.send_local_root_node_info(g_sum,
                                                  h_sum,
                                                  suffix=('root_node_sync1',
                                                          self.epoch_idx))
        g_h_dict = self.aggregator.get_aggregated_root_info(
            suffix=('root_node_sync2', self.epoch_idx))
        global_g_sum, global_h_sum = g_h_dict['g_sum'], g_h_dict['h_sum']

        # initialize node
        root_node = Node(id=0,
                         sitename=consts.GUEST,
                         sum_grad=global_g_sum,
                         sum_hess=global_h_sum,
                         weight=self.splitter.node_weight(
                             global_g_sum, global_h_sum))

        self.cur_layer_node = [root_node]
        LOGGER.debug('assign samples to root node')
        self.inst2node_idx = self.assign_instance_to_root_node(
            self.data_bin, 0)

        tree_height = self.max_depth + 1  # non-leaf node height + 1 layer leaf

        for dep in range(tree_height):

            if dep + 1 == tree_height:
                # last layer: every remaining node becomes a leaf
                for node in self.cur_layer_node:
                    node.is_leaf = True
                    self.tree_node.append(node)

                self._merge_sample_weights(self.get_node_sample_weights(
                    self.inst2node_idx, self.tree_node))

                # stop fitting
                break

            LOGGER.debug('start to fit layer {}'.format(dep))

            table_with_assignment = self.update_instances_node_positions()

            # send current layer node number:
            self.sync_cur_layer_node_num(len(self.cur_layer_node),
                                         suffix=(dep, self.epoch_idx,
                                                 self.tree_idx))

            # histograms are computed and synced for at most
            # self.max_split_nodes nodes per batch
            for batch_id, i in enumerate(
                    range(0, len(self.cur_layer_node), self.max_split_nodes)):
                cur_to_split = self.cur_layer_node[i:i + self.max_split_nodes]

                node_map = self.get_node_map(nodes=cur_to_split)
                LOGGER.debug('node map is {}'.format(node_map))
                LOGGER.debug(
                    'computing histogram for batch{} at depth{}'.format(
                        batch_id, dep))
                local_histogram = self.get_left_node_local_histogram(
                    cur_nodes=cur_to_split,
                    tree=self.tree_node,
                    g_h=self.g_h,
                    table_with_assign=table_with_assignment,
                    split_points=self.bin_split_points,
                    sparse_point=self.bin_sparse_points,
                    valid_feature=self.valid_features)

                LOGGER.debug(
                    'federated finding best splits for batch{} at layer {}'.
                    format(batch_id, dep))
                self.sync_local_node_histogram(local_histogram,
                                               suffix=(batch_id, dep,
                                                       self.epoch_idx,
                                                       self.tree_idx))

            # NOTE(review): unlike the two syncs above, this suffix omits
            # self.tree_idx. It presumably has to match the receiving side,
            # so confirm there before changing it.
            split_info = self.sync_best_splits(suffix=(dep, self.epoch_idx))
            LOGGER.debug('got best splits from arbiter')

            new_layer_node = self.update_tree(self.cur_layer_node, split_info)
            self.cur_layer_node = new_layer_node

            self.inst2node_idx, leaf_val = self.assign_instances_to_new_node(
                table_with_assignment, self.tree_node)

            # record leaf val
            self._merge_sample_weights(leaf_val)

            LOGGER.debug('assigning instance to new nodes done')

        self.convert_bin_to_real()
        LOGGER.debug('fitting tree done')

    def _merge_sample_weights(self, weights):
        """Set ``self.sample_weights`` to ``weights``, or union it in if already set."""
        if self.sample_weights is None:
            self.sample_weights = weights
        else:
            self.sample_weights = self.sample_weights.union(weights)