def fit(
    self,
    attribute_list,
    class_list,
    row_sampler,
    col_sampler,
    bin_structure,
):
    """Grow one tree from the current per-sample gradient/hessian state.

    Steps performed, in order:

    1. Reshuffle the column sampler, then the row sampler, and pass the
       row sampler's ``row_mask`` to ``class_list.sampling``.
    2. Create the root ``TreeNode`` (name 1, depth 1). Its gradient and
       hessian are set to the sums of ``class_list.grad`` and
       ``class_list.hess``.
    3. Store the root in ``self.root``, index it in ``self.name_to_node``,
       and append it to ``self.alive_nodes``. Then point every sample's
       ``class_list.corresponding_tree_node`` entry at it.
    4. Call ``self.build`` to grow the tree, then ``self.clean_up``.

    Args:
        attribute_list: Provides ``feature_dim`` for the root node. It is
            also forwarded to ``self.build``.
        class_list: Per-sample state. Its ``grad``/``hess`` are summed for
            the root. The first ``dataset_size`` entries of
            ``corresponding_tree_node`` are overwritten with the root.
        row_sampler: Shuffled here. Afterwards only its ``row_mask`` is used;
            it is not forwarded to ``self.build``.
        col_sampler: Shuffled here and forwarded to ``self.build``.
        bin_structure: Forwarded unchanged to ``self.build``.
    """
    # Row/column sampling happens before the root statistics are computed.
    # Keep the shuffle order as-is: the two samplers may draw from a shared
    # RNG, so swapping them could change the sampled subsets.
    col_sampler.shuffle()
    row_sampler.shuffle()
    class_list.sampling(row_sampler.row_mask)

    # Root node seeded with the total gradient and hessian.
    # NOTE(review): whether these sums exclude rows dropped by the mask
    # depends on what ``class_list.sampling`` does to ``grad``/``hess``.
    # Confirm.
    root = TreeNode(name=1, depth=1, feature_dim=attribute_list.feature_dim)
    grad_total = class_list.grad.sum()
    root.Grad_setter(grad_total)
    hess_total = class_list.hess.sum()
    root.Hess_setter(hess_total)

    # Every node created during fitting is registered by its name.
    self.root = root
    self.name_to_node[root.name] = root

    # Initially the root is the only splittable node, and every sample
    # belongs to it.
    # NOTE(review): ``alive_nodes`` is appended to, not reset. This assumes
    # it is empty here (e.g. emptied by a previous ``clean_up``). Confirm.
    self.alive_nodes.append(root)
    for sample_idx in range(class_list.dataset_size):
        class_list.corresponding_tree_node[sample_idx] = root

    # Grow the tree. Per the original note, ``build`` keeps splitting until
    # no alive node is left. Then run the post-fit cleanup.
    self.build(attribute_list, class_list, col_sampler, bin_structure)
    self.clean_up()