예제 #1
0
    def test_get_best_trial_across_jumps(self):
        """Best-trial lookup must follow a jump between two separate trees.

        Tree A is built with ``starting_objective=1`` and tree B with
        ``starting_objective=10``. A jump is registered from the first leaf of
        tree A onto the depth-2 parent of tree B's first leaf. Afterwards, the
        lookups from that parent and from the leaf below it must reach tree A's
        root. The node one level higher in tree B (not behind the jump) must
        still only see tree B's own root.
        """
        tree_a = build_full_tree(4, starting_objective=1)
        tree_b = build_full_tree(4, starting_objective=10)

        leaf_in_a = tree_a.get_nodes_at_depth(3)[0]
        leaf_in_b = tree_b.get_nodes_at_depth(3)[0]

        # Before any jump, tree B's leaf only knows its own lineage.
        assert leaf_in_b.get_best_trial() == tree_b.item

        jump_target = leaf_in_b.parent
        leaf_in_a.set_jump(jump_target)

        # The leaf reaches tree A through the jump attached to its parent.
        assert leaf_in_b.get_best_trial() == tree_a.item
        # The jump target itself goes through the jump directly.
        assert jump_target.get_best_trial() == tree_a.item
        # Above the jump target there is no jump on the path to tree B's root.
        assert jump_target.parent.get_best_trial() == tree_b.item
예제 #2
0
    def test_get_best_trial_equality(self):
        """On equal objectives, ``get_best_trial`` keeps the earliest trial.

        Follows the lineage of the first leaf of ``build_full_tree(4)``. The
        root's first child is tied with the root, and then the root is made
        worse. The root must win the tie, and the child must win once the root
        is worse.
        """
        root = build_full_tree(4)
        first_leaf = root.get_nodes_at_depth(3)[0]

        assert first_leaf.item.id == "id-8"
        assert first_leaf.get_best_trial() == root.item

        first_child = root.children[0]
        # Tie between the root and its first child: if they are equally good,
        # the earliest trial (the root) is expected.
        first_child.item.objective.value = root.item.objective.value
        assert first_leaf.get_best_trial() == root.item

        # Once the root is no longer as good, the first child must be returned.
        root.item.objective.value += 1
        assert first_leaf.get_best_trial() == first_child.item
예제 #3
0
    def test_get_best_trial_straigth_lineage(self):
        """Lowering a node's objective makes it the best trial of its subtree.

        Starting from ``build_full_tree(4)``, objectives are lowered level by
        level (depth 3, then 2, then 1), each below the previous record. After
        every change, the lineages passing through the improved node must
        report it as their best trial. Lineages on other branches must keep
        reporting the root.
        """
        root = build_full_tree(4)

        # Depth 3 (leaves): initially the root is the best trial.
        leaves = root.get_nodes_at_depth(3)
        assert leaves[0].item.id == "id-8"
        for i in (0, 1):
            assert leaves[i].get_best_trial() == root.item

        leaves[0].item.objective.value = -1
        # The improved leaf is now the best trial on the first branch...
        assert leaves[0].get_best_trial() == leaves[0].item
        # ...while the second branch still peaks at the root.
        assert leaves[1].get_best_trial() == root.item

        # Depth 2.
        depth_two = root.get_nodes_at_depth(2)
        assert depth_two[0].item.id == "id-4"
        for i in (0, 1):
            assert depth_two[i].get_best_trial() == root.item

        depth_two[0].item.objective.value = -2
        assert depth_two[0].get_best_trial() == depth_two[0].item
        assert depth_two[1].get_best_trial() == root.item
        # The first and second leaf branches both pass through the improved node.
        for i in (0, 1):
            assert leaves[i].get_best_trial() == depth_two[0].item
        # The third leaf branch does not.
        assert leaves[2].get_best_trial() == root.item

        # Depth 1.
        depth_one = root.get_nodes_at_depth(1)
        assert depth_one[0].item.id == "id-2"
        for i in (0, 1):
            assert depth_one[i].get_best_trial() == root.item

        depth_one[0].item.objective.value = -3
        assert depth_one[0].get_best_trial() == depth_one[0].item
        assert depth_one[1].get_best_trial() == root.item
        # The first four leaf branches pass through the improved depth-1 node...
        for i in range(4):
            assert leaves[i].get_best_trial() == depth_one[0].item
        # ...the fifth one does not.
        assert leaves[4].get_best_trial() == root.item
예제 #4
0
    def test_get_best_trial_broken_leaf(self):
        """A leaf whose trial has no objective must not break best-trial lookup.

        The leaf's objective is set to ``None``. ``get_best_trial`` must then
        return the best trial among the remaining nodes of its lineage rather
        than fail or return the broken leaf.
        """
        root = build_full_tree(4, starting_objective=1)

        leafs = root.get_nodes_at_depth(3)
        leafs[0].item.objective = None
        assert leafs[0].get_best_trial() == root.item

        # The root is already the best trial in this lineage, so the check above
        # alone does not show that the ancestors past the broken leaf are still
        # searched. Make the parent strictly better than the root and confirm
        # it is the one returned.
        leafs[0].parent.item.objective.value = -1
        assert leafs[0].get_best_trial() == leafs[0].parent.item