def test_get_best_trial_across_jumps(self):
    """Best-trial lookup follows a jump from one tree into another tree."""
    # Tree A is built from a lower starting objective than tree B.
    tree_a = build_full_tree(4, starting_objective=1)
    tree_b = build_full_tree(4, starting_objective=10)
    deepest_a = tree_a.get_nodes_at_depth(3)
    deepest_b = tree_b.get_nodes_at_depth(3)
    leaf_b = deepest_b[0]

    # Before any jump is registered, the leaf only sees its own tree.
    assert leaf_b.get_best_trial() == tree_b.item

    parent_b = leaf_b.parent
    deepest_a[0].set_jump(parent_b)

    # The leaf reaches the jump through its parent and continues into tree A.
    assert leaf_b.get_best_trial() == tree_a.item
    # The parent is itself the jump target, so it crosses into tree A directly.
    assert parent_b.get_best_trial() == tree_a.item
    # The grandparent sits above the jump target: no jump on its path to the
    # root, so only tree B's root is visible.
    assert parent_b.parent.get_best_trial() == tree_b.item
def test_get_best_trial_equality(self):
    """On equal objectives, the trial closest to the root must be returned."""
    root = build_full_tree(4)
    leafs = root.get_nodes_at_depth(3)
    assert leafs[0].item.id == "id-8"
    assert leafs[0].get_best_trial() == root.item

    # Return the ancestor closest to the root in case of equality: if they
    # are all as good, we want the earliest one.
    root.children[0].item.objective.value = root.item.objective.value
    assert leafs[0].get_best_trial() == root.item

    # Make sure the second one is returned if root is not as good.
    root.item.objective.value += 1
    assert leafs[0].get_best_trial() == root.children[0].item
def test_get_best_trial_straigth_lineage(self):
    """Improving one node only changes the best trial for that node's own subtree.

    NOTE(review): "straigth" in the method name is a typo for "straight"; the
    name is left unchanged so the test id stays stable.
    """
    root = build_full_tree(4)

    # --- Improve a single leaf -------------------------------------------
    leaves = root.get_nodes_at_depth(3)
    assert leaves[0].item.id == "id-8"
    for leaf in leaves[:2]:
        assert leaf.get_best_trial() == root.item

    leaves[0].item.objective.value = -1
    # The improved leaf is now its own best trial...
    assert leaves[0].get_best_trial() == leaves[0].item
    # ...while its neighbour still resolves to the root.
    assert leaves[1].get_best_trial() == root.item

    # --- Improve a depth-2 node ------------------------------------------
    depth_two = root.get_nodes_at_depth(2)
    assert depth_two[0].item.id == "id-4"
    for node in depth_two[:2]:
        assert node.get_best_trial() == root.item

    depth_two[0].item.objective.value = -2
    assert depth_two[0].get_best_trial() == depth_two[0].item
    assert depth_two[1].get_best_trial() == root.item
    # The first two leaves now resolve to the improved depth-2 node...
    for leaf in leaves[:2]:
        assert leaf.get_best_trial() == depth_two[0].item
    # ...but the third leaf does not.
    assert leaves[2].get_best_trial() == root.item

    # --- Improve a depth-1 node ------------------------------------------
    depth_one = root.get_nodes_at_depth(1)
    assert depth_one[0].item.id == "id-2"
    for node in depth_one[:2]:
        assert node.get_best_trial() == root.item

    depth_one[0].item.objective.value = -3
    assert depth_one[0].get_best_trial() == depth_one[0].item
    assert depth_one[1].get_best_trial() == root.item
    # The first four leaves now resolve to the improved depth-1 node...
    for leaf in leaves[:4]:
        assert leaf.get_best_trial() == depth_one[0].item
    # ...but the fifth leaf does not.
    assert leaves[4].get_best_trial() == root.item
def test_get_best_trial_broken_leaf(self):
    """A leaf whose objective is missing (None) still resolves to the root's trial."""
    tree = build_full_tree(4, starting_objective=1)
    broken_leaf = tree.get_nodes_at_depth(3)[0]

    # Drop the leaf's objective entirely.
    broken_leaf.item.objective = None

    assert broken_leaf.get_best_trial() == tree.item