Exemple #1
0
    def log_tree_ratio_prune(model: Model, tree: Tree,
                             proposal: PruneMutation):
        """Return the log prior ratio for a prune proposal (pruned / current).

        ``proposal.existing_node`` is the decision node being collapsed and
        ``proposal.updated_node`` is the node that replaces it. The terms used
        depend on ``model.prior_name``:

        * ``"poly_splits"`` / ``"exponential_splits"``: the numerator is the
          log probability that ``updated_node`` is not split. The denominator
          sums the log probabilities that both children of ``existing_node``
          are not split, that ``existing_node`` is split, and that the chosen
          split rule was selected.
        * ``"cond_unif"``: the denominator is
          ``log(model.lam / (4 * K - 6)) + log p(chosen split)``, where ``K``
          is the number of leaves in ``tree``. The numerator is 0.
        * ``"exp_prior"``: the denominator is ``-model.c + log p(chosen split)``.
          The numerator is 0.

        Args:
            model: Supplies ``prior_name`` and the prior's hyperparameters.
            tree: The current tree. Only used by the ``"cond_unif"`` prior.
            proposal: The prune move being evaluated.

        Returns:
            ``numerator - denominator``, or ``None`` if ``model.prior_name`` is
            not one of the names listed above.
        """
        # Log probability of picking this particular split rule. It is shared
        # by every prior below, so compute it once.
        prob_chosen_split = log_probability_split_within_node(
            GrowMutation(proposal.updated_node, proposal.existing_node))

        if model.prior_name in ["poly_splits", "exponential_splits"]:
            numerator = log_probability_node_not_split(model,
                                                       proposal.updated_node)

            # Both children of the collapsed decision node contribute their
            # own not-split probability; each side must be read separately.
            prob_left_not_split = log_probability_node_not_split(
                model, proposal.existing_node.left_child)
            prob_right_not_split = log_probability_node_not_split(
                model, proposal.existing_node.right_child)
            prob_updated_node_split = log_probability_node_split(
                model, proposal.existing_node)
            denominator = (prob_left_not_split + prob_right_not_split
                           + prob_updated_node_split + prob_chosen_split)
        elif model.prior_name == "cond_unif":
            K = len(tree.leaf_nodes)
            # With K == 1, 4 * K - 6 is negative and the log is undefined.
            # A single-node tree has nothing to prune, so a prune should not
            # be proposed there.
            denominator = np.log(model.lam / (4 * K - 6)) + prob_chosen_split
            numerator = 0
        elif model.prior_name == "exp_prior":
            denominator = -model.c + prob_chosen_split
            numerator = 0
        else:
            # Unknown prior: keep the original "no ratio" signal for callers.
            return None
        return numerator - denominator
Exemple #2
0
    def search(index: int = 0):
        """Replay the sklearn split at ``index`` onto ``bartpy_tree``, then recurse.

        ``nodes`` (from the enclosing scope) maps sklearn node indices to the
        corresponding bartpy nodes. It is updated as each split is replayed.
        Children are visited left subtree first, then right subtree.
        """
        left_index = sklearn_tree.children_left[index]
        right_index = sklearn_tree.children_right[index]

        # sklearn marks a leaf with child index -1. Splits are binary, so
        # testing the left child alone is enough.
        if left_index == -1:
            return

        leaf: LeafNode = nodes[index]
        decision = split_node(
            leaf, map_sklearn_split_into_bartpy_split_conditions(sklearn_tree, index))

        # Copy sklearn's fitted leaf values onto the two new bartpy leaves.
        new_leaves = ((left_index, decision.left_child),
                      (right_index, decision.right_child))
        for child_index, child_leaf in new_leaves:
            child_leaf.set_value(sklearn_tree.value[child_index][0][0])

        mutate(bartpy_tree, GrowMutation(leaf, decision))

        # Record the new nodes so the recursive calls can find their leaves.
        nodes[index] = decision
        nodes[left_index] = decision.left_child
        nodes[right_index] = decision.right_child

        for child_index in (left_index, right_index):
            search(child_index)
Exemple #3
0
    def test_growing_decision_node(self):
        """Growing from a DecisionNode (instead of a leaf) raises TypeError."""
        first_leaf, second_leaf, third_leaf = (
            LeafNode(Split(self.data)) for _ in range(3))
        inner_node = DecisionNode(Split(self.data), first_leaf, second_leaf)
        outer_node = DecisionNode(Split(self.data), third_leaf, inner_node)

        with self.assertRaises(TypeError):
            GrowMutation(inner_node, first_leaf)
Exemple #4
0
    def log_tree_ratio_prune(model: Model, proposal: PruneMutation):
        """Return the log prior ratio for a prune proposal (pruned / current).

        Numerator: log probability that ``proposal.updated_node`` is not split.
        Denominator: sum of the log probabilities that the left and right
        children of ``proposal.existing_node`` are not split, that
        ``existing_node`` is split, and that the chosen split rule was picked.

        Returns:
            ``numerator - denominator``.
        """
        numerator = log_probability_node_not_split(model, proposal.updated_node)

        # Each child of the collapsed decision node contributes its own term.
        prob_left_not_split = log_probability_node_not_split(model, proposal.existing_node.left_child)
        prob_right_not_split = log_probability_node_not_split(model, proposal.existing_node.right_child)
        prob_updated_node_split = log_probability_node_split(model, proposal.existing_node)
        prob_chosen_split = log_probability_split_within_node(GrowMutation(proposal.updated_node, proposal.existing_node))
        denominator = prob_left_not_split + prob_right_not_split + prob_updated_node_split + prob_chosen_split

        return numerator - denominator
Exemple #5
0
    def log_prune_transition_ratio(self, tree: Tree, mutation: PruneMutation):
        """Log transition ratio (reverse grow / forward prune) for a prune move.

        The reverse-grow term is the log probability of choosing the grow node,
        based on ``n_splittable_leaf_nodes(tree) - 1``, plus the log
        probability of the chosen split. The prune term uses
        ``n_prunable_decision_nodes(tree)``. Both node counts go through
        ``safe_negative_log``. The result adds
        ``log(self.prob_method[0] / self.prob_method[1])``.
        """
        log_node_choice = safe_negative_log(n_splittable_leaf_nodes(tree) - 1)
        log_grow = log_node_choice + log_probability_split_within_node(
            GrowMutation(mutation.updated_node, mutation.existing_node))

        log_prune = safe_negative_log(n_prunable_decision_nodes(tree))

        log_method_ratio = np.log(self.prob_method[0] / self.prob_method[1])
        return log_method_ratio + (log_grow - log_prune)
Exemple #6
0
    def log_prune_transition_ratio(self, tree: Tree, mutation: PruneMutation):
        """Log transition ratio (reverse grow / forward prune) for a prune move.

        The reverse-grow term is the log probability of choosing one of the
        ``n_splittable_leaf_nodes(tree) - 1`` leaves, plus the log probability
        of the chosen split. The prune term is ``-log`` of the number of
        prunable decision nodes. The result adds
        ``log(self.prob_method[0] / self.prob_method[1])``. Presumably these
        are the grow and prune move probabilities; confirm against the
        sampler that sets ``prob_method``.

        Returns:
            The log ratio as a float. It is ``-inf`` when exactly one
            splittable leaf exists.
        """
        # Traverse the tree once rather than twice.
        n_splittable = n_splittable_leaf_nodes(tree)
        if n_splittable == 1:
            prob_grow_node_selected = -np.inf  # Infinitely unlikely to be able to grow a null tree
        else:
            prob_grow_node_selected = -np.log(n_splittable - 1)
        prob_split = log_probability_split_within_node(
            GrowMutation(mutation.updated_node, mutation.existing_node))
        prob_grow_selected = prob_grow_node_selected + prob_split

        prob_prune_selected = -np.log(n_prunable_decision_nodes(tree))

        prob_selection_ratio = prob_grow_selected - prob_prune_selected
        grow_prune_ratio = np.log(self.prob_method[0] / self.prob_method[1])

        return grow_prune_ratio + prob_selection_ratio
Exemple #7
0
def uniformly_sample_grow_mutation(tree: Tree) -> TreeMutation:
    """Build a grow proposal for a random splittable leaf of ``tree``.

    The leaf comes from ``random_splittable_leaf_node`` and its split from
    ``sample_split_node``. Both are returned wrapped in a ``GrowMutation``.
    """
    leaf = random_splittable_leaf_node(tree)
    return GrowMutation(leaf, sample_split_node(leaf))
Exemple #8
0
    def log_prune_transition_ratio(self, tree: Tree, mutation: PruneMutation):
        """Log transition ratio for a prune move, using only the split choice.

        The result is the log probability of the split that the reverse grow
        would choose, plus ``log(self.prob_method[0] / self.prob_method[1])``.
        ``tree`` is accepted for interface compatibility but is not read here.
        """
        reverse_grow = GrowMutation(mutation.updated_node, mutation.existing_node)
        log_split_choice = log_probability_split_within_node(reverse_grow)
        log_method_ratio = np.log(self.prob_method[0] / self.prob_method[1])
        return log_method_ratio + log_split_choice
Exemple #9
0
def grow_mutations(tree: Tree) -> List[TreeMutation]:
    """Return one grow proposal per leaf of ``tree``, in ``tree.leaf_nodes`` order.

    Each proposal splits its leaf with a split drawn by ``sample_split_node``.
    """
    return [
        GrowMutation(leaf, sample_split_node(leaf))
        for leaf in tree.leaf_nodes
    ]