def update_tree(self, cur_to_split: List[Node], split_info: List[SplitInfo]):
    """Apply the best splits to the current layer and build the next one.

    A node with no usable split (no candidate feature, or gain below
    ``min_impurity_split`` plus the float tolerance) is finalized as a
    leaf.  Every other node records its split (feature, bin, missing
    direction) and produces a left/right child pair; the right child's
    g/h statistics are the parent's totals minus the left child's.

    Returns the list of newly created child nodes (the next layer).
    """
    LOGGER.debug('updating tree_node, cur layer has {} node'.format(len(cur_to_split)))
    assert len(cur_to_split) == len(split_info)

    next_layer_node = []
    for node, info in zip(cur_to_split, split_info):
        parent_g, parent_h = node.sum_grad, node.sum_hess

        # No valid split found, or gain too small: finalize as a leaf.
        if info.best_fid is None or info.gain <= self.min_impurity_split + consts.FLOAT_ZERO:
            node.is_leaf = True
            self.tree_node.append(node)
            continue

        # Record the winning split on the parent node.
        node.fid = info.best_fid
        node.bid = info.best_bid
        node.missing_dir = info.missing_dir

        # Allocate ids for the two children.
        left_id = self.tree_node_num + 1
        right_id = self.tree_node_num + 2
        node.left_nodeid, node.right_nodeid = left_id, right_id
        self.tree_node_num += 2

        # Left child takes the split's g/h; the right child gets the rest.
        left_g, left_h = info.sum_grad, info.sum_hess
        right_g, right_h = parent_g - left_g, parent_h - left_h

        next_layer_node.append(
            Node(id=left_id, sitename=self.sitename,
                 sum_grad=left_g, sum_hess=left_h,
                 weight=self.splitter.node_weight(left_g, left_h),
                 parent_nodeid=node.id, sibling_nodeid=right_id,
                 is_left_node=True))
        next_layer_node.append(
            Node(id=right_id, sitename=self.sitename,
                 sum_grad=right_g, sum_hess=right_h,
                 weight=self.splitter.node_weight(right_g, right_h),
                 parent_nodeid=node.id, sibling_nodeid=left_id,
                 is_left_node=False))

        self.tree_node.append(node)
        self.update_feature_importance(info, record_site_name=False)

    return next_layer_node
def test_node(self):
    """Node must store every constructor keyword argument verbatim."""
    expected = {"id": 5, "sitename": "test", "fid": 55, "bid": 555,
                "weight": -1, "is_leaf": True, "sum_grad": 2, "sum_hess": 3,
                "left_nodeid": 6, "right_nodeid": 7}
    node = Node(**expected)
    for name, value in expected.items():
        self.assertTrue(value == getattr(node, name))
def init_root_node_and_gh_sum(self):
    """Aggregate g/h sums across parties and install the root node.

    Sends the locally computed gradient/hessian sums to the aggregator,
    receives the global totals back, and seeds ``cur_layer_node`` with
    a single root node weighted from the aggregated statistics.
    """
    # Local gradient/hessian totals over this party's data.
    local_g, local_h = self.get_grad_hess_sum(self.g_h)

    # Two-step federation round: push local sums, pull aggregated sums.
    self.aggregator.send_local_root_node_info(
        local_g, local_h, suffix=('root_node_sync1', self.epoch_idx))
    aggregated = self.aggregator.get_aggregated_root_info(
        suffix=('root_node_sync2', self.epoch_idx))
    g_sum, h_sum = aggregated['g_sum'], aggregated['h_sum']

    # The root covers all samples; its weight comes from the global sums.
    self.cur_layer_node = [
        Node(id=0, sitename=consts.GUEST,
             sum_grad=g_sum, sum_hess=h_sum,
             weight=self.splitter.node_weight(g_sum, h_sum))]
def set_model_param(self, model_param):
    """Rebuild the in-memory node list from a serialized model parameter.

    Each entry of ``model_param.tree_`` is converted back into a Node;
    the previous contents of ``self.tree_node`` are discarded.
    """
    self.tree_node = [
        Node(id=param.id,
             sitename=param.sitename,
             fid=param.fid,
             bid=param.bid,
             weight=param.weight,
             is_leaf=param.is_leaf,
             left_nodeid=param.left_nodeid,
             right_nodeid=param.right_nodeid,
             missing_dir=param.missing_dir)
        for param in model_param.tree_
    ]
def fit(self):
    """
    Fit one decision tree for the current boosting epoch.

    Layer-by-layer procedure:
      1. Aggregate local g/h sums into a global root node (two sync
         rounds with the aggregator).
      2. For each depth, batch the layer's nodes, compute local
         histograms, send them for federated split finding, then apply
         the best splits received from the arbiter.
      3. At the final layer, mark all remaining nodes as leaves and
         collect their sample weights.

    Side effects: populates self.tree_node, self.cur_layer_node,
    self.inst2node_idx and self.sample_weights; exchanges messages via
    self.aggregator and the sync_* helpers.
    """
    LOGGER.info(
        'begin to fit h**o decision tree, epoch {}, tree idx {}'.format(
            self.epoch_idx, self.tree_idx))

    # compute local g_sum and h_sum
    g_sum, h_sum = self.get_grad_hess_sum(self.g_h)

    # get aggregated root info (send local sums, receive global sums)
    self.aggregator.send_local_root_node_info(g_sum, h_sum,
                                              suffix=('root_node_sync1', self.epoch_idx))
    g_h_dict = self.aggregator.get_aggregated_root_info(
        suffix=('root_node_sync2', self.epoch_idx))
    global_g_sum, global_h_sum = g_h_dict['g_sum'], g_h_dict['h_sum']

    # initialize root node from the aggregated statistics
    root_node = Node(id=0, sitename=consts.GUEST, sum_grad=global_g_sum,
                     sum_hess=global_h_sum,
                     weight=self.splitter.node_weight(
                         global_g_sum, global_h_sum))
    self.cur_layer_node = [root_node]

    LOGGER.debug('assign samples to root node')
    self.inst2node_idx = self.assign_instance_to_root_node(
        self.data_bin, 0)

    tree_height = self.max_depth + 1  # non-leaf node height + 1 layer leaf
    for dep in range(tree_height):

        # Last layer: everything still open becomes a leaf.
        if dep + 1 == tree_height:
            for node in self.cur_layer_node:
                node.is_leaf = True
                self.tree_node.append(node)
            rest_sample_weights = self.get_node_sample_weights(
                self.inst2node_idx, self.tree_node)
            # union accumulates weights across layers; first assignment
            # happens when no earlier layer produced leaves
            if self.sample_weights is None:
                self.sample_weights = rest_sample_weights
            else:
                self.sample_weights = self.sample_weights.union(
                    rest_sample_weights)
            # stop fitting
            break

        LOGGER.debug('start to fit layer {}'.format(dep))
        table_with_assignment = self.update_instances_node_positions()

        # send current layer node number:
        self.sync_cur_layer_node_num(len(self.cur_layer_node),
                                     suffix=(dep, self.epoch_idx, self.tree_idx))

        split_info, agg_histograms = [], []
        # Process the layer in batches of at most max_split_nodes nodes.
        for batch_id, i in enumerate(
                range(0, len(self.cur_layer_node), self.max_split_nodes)):
            cur_to_split = self.cur_layer_node[i:i + self.max_split_nodes]
            node_map = self.get_node_map(nodes=cur_to_split)
            LOGGER.debug('node map is {}'.format(node_map))
            LOGGER.debug(
                'computing histogram for batch{} at depth{}'.format(
                    batch_id, dep))
            local_histogram = self.get_left_node_local_histogram(
                cur_nodes=cur_to_split,
                tree=self.tree_node,
                g_h=self.g_h,
                table_with_assign=table_with_assignment,
                split_points=self.bin_split_points,
                sparse_point=self.bin_sparse_points,
                valid_feature=self.valid_features)
            LOGGER.debug(
                'federated finding best splits for batch{} at layer {}'.
                format(batch_id, dep))
            # ship this batch's histograms for federated split finding
            self.sync_local_node_histogram(local_histogram,
                                           suffix=(batch_id, dep, self.epoch_idx, self.tree_idx))
            agg_histograms += local_histogram

        # one answer per layer, covering all batches
        split_info = self.sync_best_splits(suffix=(dep, self.epoch_idx))
        LOGGER.debug('got best splits from arbiter')

        new_layer_node = self.update_tree(self.cur_layer_node, split_info)
        self.cur_layer_node = new_layer_node
        self.inst2node_idx, leaf_val = self.assign_instances_to_new_node(
            table_with_assignment, self.tree_node)

        # record leaf val
        if self.sample_weights is None:
            self.sample_weights = leaf_val
        else:
            self.sample_weights = self.sample_weights.union(leaf_val)

        LOGGER.debug('assigning instance to new nodes done')

    self.convert_bin_to_real()
    LOGGER.debug('fitting tree done')