def log_tree_ratio_prune(model: Model, tree: Tree, proposal: PruneMutation):
    """Return the log tree-prior ratio term for a prune proposal.

    The result is ``numerator - denominator``. Which terms go into each side
    depends on ``model.prior_name``:

    * ``"poly_splits"`` / ``"exponential_splits"``: the numerator is the log
      probability that the pruned node (``proposal.updated_node``) is not
      split. The denominator sums the log probabilities of the removed
      structure: ``proposal.existing_node`` being split, each of its two
      children not being split, and the chosen split rule.
    * ``"cond_unif"``: the numerator is 0. The denominator is
      ``log(model.lam / (4 * K - 6))`` plus the chosen-split term, where
      ``K`` is the number of leaves in ``tree``.
    * ``"exp_prior"``: the numerator is 0. The denominator is ``-model.c``
      plus the chosen-split term.

    Args:
        model: Provides ``prior_name`` and the hyperparameters read above.
        tree: The tree the prune is proposed on. It is only read for its
            leaf count under ``"cond_unif"``.
        proposal: The prune mutation. ``existing_node`` is the decision node
            being collapsed and ``updated_node`` is the leaf replacing it.

    Returns:
        The log ratio as a float, or ``None`` if ``model.prior_name`` is not
        one of the recognised priors.
    """
    # Log probability of the split rule that the reverse grow move would use
    # to recreate ``existing_node`` from ``updated_node``.
    prob_chosen_split = log_probability_split_within_node(
        GrowMutation(proposal.updated_node, proposal.existing_node))

    if model.prior_name in ["poly_splits", "exponential_splits"]:
        numerator = log_probability_node_not_split(model, proposal.updated_node)
        prob_left_not_split = log_probability_node_not_split(
            model, proposal.existing_node.left_child)
        # Bug fix: this previously read ``left_child`` a second time, so the
        # right child's contribution to the prior was never included.
        prob_right_not_split = log_probability_node_not_split(
            model, proposal.existing_node.right_child)
        prob_updated_node_split = log_probability_node_split(
            model, proposal.existing_node)
        denominator = (prob_left_not_split + prob_right_not_split
                       + prob_updated_node_split + prob_chosen_split)
    elif model.prior_name == "cond_unif":
        K = len(tree.leaf_nodes)
        # For K == 1 the argument of the log is negative. A prune proposal
        # needs at least one decision node (K >= 2), so a single-node tree
        # must never reach this branch.
        denominator = np.log(model.lam / (4 * K - 6)) + prob_chosen_split
        numerator = 0
    elif model.prior_name == "exp_prior":
        denominator = -model.c + prob_chosen_split
        numerator = 0
    else:
        # Unrecognised prior: keep the historical "return nothing" behaviour.
        return None
    return numerator - denominator
def search(index: int = 0):
    """Replay the sklearn tree rooted at ``index`` onto ``bartpy_tree``.

    Nodes of ``sklearn_tree`` are visited in pre-order: a node, then its
    whole left subtree, then its right subtree. For each internal node this
    does the following:

    * The BART leaf stored at ``nodes[index]`` is split using the sklearn
      split conditions.
    * The two new BART leaves receive ``sklearn_tree.value[child][0][0]``.
    * The split is applied to ``bartpy_tree`` via a ``GrowMutation``.
    * ``nodes`` is updated so that the sklearn indices map to the new BART
      nodes.

    ``sklearn_tree``, ``nodes`` and ``bartpy_tree`` are not defined in this
    function; they are resolved from the enclosing scope.

    An explicit stack replaces the original recursion so that very deep
    sklearn trees cannot exceed Python's recursion limit. Visit order and
    side effects are unchanged.
    """
    pending = [index]
    while pending:
        current = pending.pop()
        left_child_index = sklearn_tree.children_left[current]
        right_child_index = sklearn_tree.children_right[current]
        if left_child_index == -1:
            # Trees are binary splits, so only need to check left tree:
            # -1 (sklearn's leaf sentinel) means this node has no children.
            continue

        searched_node: LeafNode = nodes[current]
        split_conditions = map_sklearn_split_into_bartpy_split_conditions(
            sklearn_tree, current)
        decision_node = split_node(searched_node, split_conditions)

        left_child: LeafNode = decision_node.left_child
        right_child: LeafNode = decision_node.right_child
        left_child.set_value(sklearn_tree.value[left_child_index][0][0])
        right_child.set_value(sklearn_tree.value[right_child_index][0][0])

        mutation = GrowMutation(searched_node, decision_node)
        mutate(bartpy_tree, mutation)

        nodes[current] = decision_node
        nodes[left_child_index] = decision_node.left_child
        nodes[right_child_index] = decision_node.right_child

        # Push right before left so the entire left subtree is processed
        # first, reproducing the recursive pre-order exactly.
        pending.append(right_child_index)
        pending.append(left_child_index)
def test_growing_decision_node(self):
    """A GrowMutation whose existing node is a decision node raises TypeError."""
    leaf_a = LeafNode(Split(self.data))
    leaf_b = LeafNode(Split(self.data))
    leaf_c = LeafNode(Split(self.data))
    inner = DecisionNode(Split(self.data), leaf_a, leaf_b)
    # ``inner`` is also attached under a second decision node. The object is
    # unused afterwards, but it is still built in case construction wires
    # node links (not verified here).
    outer = DecisionNode(Split(self.data), leaf_c, inner)

    with self.assertRaises(TypeError):
        GrowMutation(inner, leaf_a)
def log_tree_ratio_prune(model: Model, proposal: PruneMutation):
    """Return the log tree-prior ratio term for a prune proposal.

    Computed as ``numerator - denominator``:

    * The numerator is the log probability that the pruned node
      (``proposal.updated_node``) is not split.
    * The denominator sums the log probabilities of the structure the prune
      removes: ``proposal.existing_node`` being split, each of its two
      children not being split, and the chosen split rule.

    Args:
        model: Passed through to the node split / not-split probability
            helpers.
        proposal: The prune mutation. ``existing_node`` is the decision node
            being collapsed and ``updated_node`` is the leaf replacing it.

    Returns:
        The log ratio.
    """
    numerator = log_probability_node_not_split(model, proposal.updated_node)

    prob_left_not_split = log_probability_node_not_split(
        model, proposal.existing_node.left_child)
    # Bug fix: this previously read ``left_child`` a second time, so the
    # right child's contribution to the prior was never included.
    prob_right_not_split = log_probability_node_not_split(
        model, proposal.existing_node.right_child)
    prob_updated_node_split = log_probability_node_split(
        model, proposal.existing_node)
    # Log probability of the split rule the reverse grow move would use.
    prob_chosen_split = log_probability_split_within_node(
        GrowMutation(proposal.updated_node, proposal.existing_node))

    denominator = (prob_left_not_split + prob_right_not_split
                   + prob_updated_node_split + prob_chosen_split)
    return numerator - denominator
def log_prune_transition_ratio(self, tree: Tree, mutation: PruneMutation):
    """Return the log proposal-probability ratio for a prune move.

    The result is the log grow/prune ratio ``log(prob_method[0] /
    prob_method[1])`` plus two selection terms. The first is the log
    probability of selecting the reverse grow, i.e. picking the node and then
    the split that would undo this prune. The second, subtracted, is the log
    probability of selecting this prune.

    ``safe_negative_log`` is used for the node-selection counts; how it treats
    non-positive counts is defined elsewhere.
    """
    log_pick_grow_node = safe_negative_log(n_splittable_leaf_nodes(tree) - 1)
    reverse_grow = GrowMutation(mutation.updated_node, mutation.existing_node)
    log_grow = log_pick_grow_node + log_probability_split_within_node(reverse_grow)

    log_prune = safe_negative_log(n_prunable_decision_nodes(tree))
    log_selection = log_grow - log_prune

    log_method = np.log(self.prob_method[0] / self.prob_method[1])
    return log_method + log_selection
def log_prune_transition_ratio(self, tree: Tree, mutation: PruneMutation):
    """Return the log proposal-probability ratio for a prune move.

    Args:
        tree: The tree the prune is proposed on. Node counts are taken from
            it.
        mutation: The proposed prune. ``existing_node`` is the decision node
            being removed and ``updated_node`` is the leaf that replaces it.

    Returns:
        The sum of three terms:

        * ``log(prob_method[0] / prob_method[1])``;
        * the log probability of selecting the reverse grow (node, then
          split);
        * minus the log probability of selecting this prune.

        Returns ``-inf`` when the reverse grow has no node to pick from.
    """
    n_splittable = n_splittable_leaf_nodes(tree)
    # Count of leaves the reverse grow could pick. The original code uses
    # (splittable leaves - 1); presumably this accounts for the two children
    # collapsing into one leaf, but that is not verified here.
    if n_splittable <= 1:
        # No candidate leaf: the reverse grow is impossible. Using ``<= 1``
        # rather than ``== 1`` avoids ``-np.log(-1)``, which produced NaN
        # (with a RuntimeWarning) when no leaves were splittable.
        prob_grow_node_selected = -np.inf  # Infinitely unlikely to be able to grow a null tree
    else:
        prob_grow_node_selected = -np.log(n_splittable - 1)

    prob_split = log_probability_split_within_node(
        GrowMutation(mutation.updated_node, mutation.existing_node))
    prob_grow_selected = prob_grow_node_selected + prob_split

    prob_prune_selected = -np.log(n_prunable_decision_nodes(tree))

    prob_selection_ratio = prob_grow_selected - prob_prune_selected
    grow_prune_ratio = np.log(self.prob_method[0] / self.prob_method[1])
    return grow_prune_ratio + prob_selection_ratio
def uniformly_sample_grow_mutation(tree: Tree) -> TreeMutation:
    """Propose a grow move on one leaf of ``tree``.

    The leaf is chosen by ``random_splittable_leaf_node``. A split for it is
    drawn with ``sample_split_node``, and the pair is wrapped in a
    ``GrowMutation``.
    """
    leaf = random_splittable_leaf_node(tree)
    return GrowMutation(leaf, sample_split_node(leaf))
def log_prune_transition_ratio(self, tree: Tree, mutation: PruneMutation):
    """Return a simplified log proposal ratio for a prune move.

    The result is ``log(prob_method[0] / prob_method[1])`` plus the log
    probability of the split that the reverse grow would choose. No
    node-selection counts are included, and ``tree`` is accepted for
    interface compatibility but not read.
    """
    reverse_grow = GrowMutation(mutation.updated_node, mutation.existing_node)
    log_split = log_probability_split_within_node(reverse_grow)
    log_method = np.log(self.prob_method[0] / self.prob_method[1])
    return log_method + log_split
def grow_mutations(tree: Tree) -> List[TreeMutation]:
    """Return one grow mutation per leaf of ``tree``.

    Each mutation splits its leaf using a split drawn by
    ``sample_split_node``. Leaves are taken in ``tree.leaf_nodes`` order.
    """
    return [
        GrowMutation(leaf, sample_split_node(leaf))
        for leaf in tree.leaf_nodes
    ]