예제 #1
0
    def test_fork_non_existing_trial(self):
        """Forking from a trial that was never added raises ``KeyError``."""
        lineages = Lineages()
        trial = TrialStub(id="stub")
        new_trial = TrialStub(id="fork")

        # ``trial`` was never added to ``lineages``, so forking from it must fail.
        with pytest.raises(KeyError):
            lineages.fork(trial, new_trial)
예제 #2
0
 def test_add_new_trial(self):
     """Adding an unseen trial creates one root lineage indexed by the trial id."""
     lineages = Lineages()
     assert len(lineages) == 0

     root = lineages.add(TrialStub(id="stub"))

     # Exactly one root exists, and it is reachable both positionally and by id.
     assert len(lineages) == 1
     assert lineages._lineage_roots[0] is root
     assert lineages._trial_to_lineages["stub"] is root
예제 #3
0
    def test_add_duplicate(self):
        """Adding a trial whose id is already known returns the existing lineage."""
        lineages = Lineages()
        assert len(lineages) == 0

        first = lineages.add(TrialStub(id="stub"))
        assert len(lineages) == 1

        # A distinct TrialStub object that shares the id of the first one.
        second = lineages.add(TrialStub(id="stub"))
        assert second is first
        assert len(lineages) == 1
예제 #4
0
    def test_register_existing_trial(self):
        """Re-registering an added trial returns its lineage without adding a root.

        After the trial gains an objective, the lineage item is expected to
        expose that objective.
        """
        lineages = Lineages()
        trial = TrialStub(id="my-id")

        root = lineages.add(trial)
        assert lineages._lineage_roots == [root]
        assert root.item.objective is None

        # Attach an objective, then register the same trial again.
        trial.objective = ObjectiveStub(1)
        registered = lineages.register(trial)
        assert registered is root
        assert lineages._lineage_roots == [root]
        assert root.item.objective.value == 1
예제 #5
0
 def test_fork_existing_trial(self, tmp_path):
     """Forking an added trial attaches the new trial as a child of its root.

     Forking also indexes the new trial's lineage by the new trial's id.
     """
     lineages = Lineages()
     trial = TrialStub(id="stub", working_dir=os.path.join(tmp_path, "stub"))
     # Create the parent's working directory on disk; presumably ``fork`` reads
     # from it (e.g. to copy it to the child) — confirm against ``Lineages.fork``.
     os.makedirs(trial.working_dir)
     lineage = lineages.add(trial)
     assert len(lineages) == 1
     assert lineages._lineage_roots[0] is lineage

     new_trial = TrialStub(id="fork", working_dir=os.path.join(tmp_path, "fork"))
     new_lineage = lineages.fork(trial, new_trial)
     # Forking extends the existing tree instead of creating another root.
     assert len(lineages) == 1
     assert lineages._lineage_roots[0].children[0] is new_lineage
     assert lineages._trial_to_lineages["fork"] is new_lineage
예제 #6
0
def build_lineages_for_exploit(space,
                               monkeypatch,
                               trials=None,
                               elites=None,
                               additional_trials=None,
                               seed=1,
                               num=10):
    """Build a ``Lineages`` whose depth and elite queries return canned trials.

    Args:
        space: Object providing ``sample(num, seed=...)``. It generates trials
            when ``trials`` or ``elites`` is not given.
        monkeypatch: pytest ``monkeypatch`` fixture. It replaces the
            ``get_trials_at_depth`` and ``get_elites`` methods on the new
            instance.
        trials: Trials returned by ``get_trials_at_depth``. When ``None``,
            ``num`` trials are sampled with ``seed`` and marked completed with
            objectives ``0, 1, ..., num - 1``.
        elites: Trials returned by ``get_elites``. When ``None``, ``num``
            trials are sampled with ``seed + 1`` and marked completed with
            objectives ``0, 2, ..., 2 * (num - 1)``.
        additional_trials: Optional extra trials appended after ``trials`` in
            the list returned by ``get_trials_at_depth``. A caller-supplied
            ``trials`` list is not modified.
        seed: Seed used when sampling default trials (``seed + 1`` for elites).
        num: Number of trials sampled for each default.

    Returns:
        A fresh ``Lineages`` instance with both query methods stubbed. The
        stubs ignore their arguments.
    """

    def _complete(sampled, scale):
        # Mark every sampled trial completed with objective ``index * scale``.
        for index, trial in enumerate(sampled):
            trial.status = "completed"
            trial._results.append(
                trial.Result(name="objective", type="objective",
                             value=index * scale))

    if trials is None:
        trials = space.sample(num, seed=seed)
        _complete(trials, 1)
    if elites is None:
        elites = space.sample(num, seed=seed + 1)
        _complete(elites, 2)

    if additional_trials:
        # Build a new list rather than ``trials += ...``, which would extend a
        # caller-supplied ``trials`` list in place.
        trials = list(trials) + list(additional_trials)

    def return_trials(*args, **kwargs):
        return trials

    def return_elites(*args, **kwargs):
        return elites

    lineages = Lineages()
    monkeypatch.setattr(lineages, "get_trials_at_depth", return_trials)
    monkeypatch.setattr(lineages, "get_elites", return_elites)

    return lineages
예제 #7
0
    def test_set_jump_existing_trial(self):
        """After ``set_jump``, a child forked from one root has the other root as base.

        The jump is recorded only on the target root; the original parent root
        keeps no jumps and no base.
        """
        lineages = Lineages()

        first_root = TrialStub(id="root-1")
        first_lineage = lineages.add(first_root)
        second_root = TrialStub(id="root-2")
        second_lineage = lineages.add(second_root)

        # Fork the child from the first root, then make it jump to the second.
        child = TrialStub(id="child")
        child_lineage = lineages.fork(first_root, child)
        lineages.set_jump(second_root, child)

        # The jump target and the child point at each other.
        assert child_lineage.base is second_lineage
        assert second_lineage.jumps == [child_lineage]
        assert child_lineage.jumps == []
        assert second_lineage.base is None
        # The original parent root is untouched by the jump.
        assert first_lineage.jumps == []
        assert first_lineage.base is None
예제 #8
0
    def test_get_lineage_existing_node_trial(self):
        """``get_lineage`` finds non-root nodes and reports their depth.

        The lookup is done by trial id, and the found node's ``root`` is the
        root of the chain it belongs to.
        """
        lineages = Lineages()
        # Build two independent chains:
        # lineage-<r>-0 -> lineage-<r>-1 -> ... -> lineage-<r>-9 for r in {0, 1}.
        for root_index in range(2):
            trial = TrialStub(id=f"lineage-{root_index}-0")
            lineages.add(trial)
            for depth in range(1, 10):
                new_trial = TrialStub(id=f"lineage-{root_index}-{depth}")
                lineages.fork(trial, new_trial)
                trial = new_trial

        # A new TrialStub carrying only the id is used for each lookup.
        lineage = lineages.get_lineage(TrialStub(id="lineage-0-2"))
        assert lineage.root is lineages._lineage_roots[0]
        assert lineage.node_depth == 2

        lineage = lineages.get_lineage(TrialStub(id="lineage-1-5"))
        assert lineage.root is lineages._lineage_roots[1]
        assert lineage.node_depth == 5
예제 #9
0
    def test_get_trials_at_depth_given_non_existing_trial(self):
        """``get_trials_at_depth`` on an unknown trial raises ``KeyError`` naming its id."""
        lineages = Lineages()
        missing = TrialStub(id="idontexist")
        with pytest.raises(KeyError, match="idontexist"):
            lineages.get_trials_at_depth(missing)
예제 #10
0
 def test_get_elites_none_completed(self):
     """Trials added in their default stub state yield no elites.

     The default state is presumably "not completed", per the test name.
     """
     lineages = Lineages()
     for trial_id in ("1", "2", "3"):
         lineages.add(TrialStub(id=trial_id))
     assert lineages.get_elites() == []
예제 #11
0
 def test_get_elites_empty(self):
     """A freshly created ``Lineages`` reports no elites."""
     elites = Lineages().get_elites()
     assert elites == []
예제 #12
0
 def test_register_new_trial(self):
     """Registering an unseen trial returns a lineage that becomes the only root."""
     lineages = Lineages()
     created = lineages.register(TrialStub(id="new"))
     assert lineages._lineage_roots == [created]
예제 #13
0
 def test_set_jump_non_existing_new_trial(self):
     """``set_jump`` raises ``KeyError`` naming the jumping trial when it was never added."""
     lineages = Lineages()
     base_trial = TrialStub(id="exists")
     lineages.add(base_trial)

     unknown = TrialStub(id="newtrialdontexist")
     with pytest.raises(KeyError, match="'newtrialdontexist'"):
         lineages.set_jump(base_trial, unknown)
예제 #14
0
 def test_set_jump_non_existing_base_trial(self):
     """``set_jump`` raises ``KeyError`` naming the base trial when it is unknown."""
     lineages = Lineages()
     base = TrialStub(id="dontexist")
     # Neither trial exists; the quoted pattern expects the error to name the
     # base trial, not this one.
     target = TrialStub(id="dontexistbutdoesntmatter")
     with pytest.raises(KeyError, match="'dontexist'"):
         lineages.set_jump(base, target)
예제 #15
0
    def test_get_lineage_non_existing_trial(self):
        """``get_lineage`` raises ``KeyError`` for a trial that was never added."""
        lineages = Lineages()
        unknown = TrialStub(id="id")
        with pytest.raises(KeyError):
            lineages.get_lineage(unknown)
예제 #16
0
 def test_get_lineage_existing_root_trial(self):
     """``get_lineage`` on a root trial returns the lineage created by ``add``."""
     lineages = Lineages()
     trial = TrialStub(id="stub")
     root = lineages.add(trial)
     found = lineages.get_lineage(trial)
     assert found is root