コード例 #1
0
ファイル: test_asha.py プロジェクト: breuleux/orion
    def test_suggest_in_finite_cardinality(self):
        """Test that suggest returns an empty list once a finite space is exhausted.

        The test observes ``yolo1`` values 0..5 at epoch 1. This is assumed to
        be every value ``Integer("yolo1", "uniform", 0, 5)`` can take; TODO
        confirm. The two best of those points (objectives 0 and 1) are then
        also observed at epoch 3. After that, ``suggest(1)`` must return ``[]``
        (an empty list, not ``None``).
        """
        space = Space()
        space.register(Integer("yolo1", "uniform", 0, 5))
        space.register(Fidelity("epoch", 1, 9, 3))

        asha = ASHA(space)

        # Observe every yolo1 value at epoch 1 first, then the two best-scoring
        # ones (objective == yolo1 value, lower is better) at epoch 3.
        # Presumably nothing remains to sample or promote afterwards.
        for epoch, n_values in ((1, 6), (3, 2)):
            for i in range(n_values):
                force_observe(
                    asha,
                    create_trial(
                        (epoch, i),
                        names=("epoch", "yolo1"),
                        types=("fidelity", "integer"),
                        results={"objective": i},
                    ),
                )

        assert asha.suggest(1) == []
コード例 #2
0
ファイル: test_asha.py プロジェクト: breuleux/orion
    def test_register_bracket_multi_fidelity(self, space, b_config):
        """Check that a point is registered in the same bracket at each fidelity.

        A point first observed at fidelity 1 lands in rung 0 of bracket 0.
        Observing the same point again at fidelity 3 must register it in rung 1
        of that same bracket. The rung-0 entry must not be replaced by the
        fidelity-3 trial.
        """
        asha = ASHA(space, num_brackets=3)

        value = 50
        fidelity = 1
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = asha.get_id(trial, ignore_fidelity=True)

        force_observe(asha, trial)

        bracket = asha.brackets[0]

        # Count the rung's registered results, not the rung mapping itself.
        # The rung always holds a "results" key, so len(rung) is never 0 and
        # the previous check could not fail.
        assert len(bracket.rungs[0]["results"]) == 1
        assert trial_id in bracket.rungs[0]["results"]
        assert bracket.rungs[0]["results"][trial_id][0] == 0.0
        assert bracket.rungs[0]["results"][trial_id][1].params == trial.params

        fidelity = 3
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = asha.get_id(trial, ignore_fidelity=True)

        force_observe(asha, trial)

        assert len(bracket.rungs[1]["results"]) == 1
        assert trial_id in bracket.rungs[1]["results"]
        # Rung 0 must still hold the fidelity-1 trial, not the new one.
        assert bracket.rungs[0]["results"][trial_id][1].params != trial.params
        assert bracket.rungs[1]["results"][trial_id][0] == 0.0
        assert bracket.rungs[1]["results"][trial_id][1].params == trial.params
コード例 #3
0
ファイル: test_asha.py プロジェクト: breuleux/orion
    def test_register_corrupted_db(self, caplog, space, b_config):
        """Check that observing a point at a lower fidelity after a higher one is flagged.

        The point is first observed at fidelity 3, and the wrong-bracket message
        must not appear. Observing the same point afterwards at fidelity 1 must
        log that message.
        """
        wrong_bracket_msg = "Trial registered to wrong bracket"
        asha = ASHA(space, num_brackets=3)
        point_value = 50

        high_fidelity_trial = create_trial_for_hb((3, point_value))
        force_observe(asha, high_fidelity_trial)
        assert wrong_bracket_msg not in caplog.text

        low_fidelity_trial = create_trial_for_hb((1, point_value), objective=0.0)
        # Drop earlier log records so only the second observation is inspected.
        caplog.clear()
        force_observe(asha, low_fidelity_trial)
        assert wrong_bracket_msg in caplog.text
コード例 #4
0
ファイル: test_asha.py プロジェクト: breuleux/orion
    def test_register_next_bracket(self, space, b_config):
        """Check that higher-fidelity points are registered in the matching bracket.

        A point observed at fidelity 3 must land in rung 0 of bracket 1. A
        second point observed at fidelity 9 must land in rung 0 of bracket 2.
        Bracket 0 stays empty throughout.
        """
        asha = ASHA(space, num_brackets=3)

        def registered_per_bracket():
            # Total number of results over all rungs, for brackets 0, 1 and 2.
            return [
                sum(len(rung["results"]) for rung in asha.brackets[idx].rungs)
                for idx in range(3)
            ]

        # (point value, fidelity, target bracket, expected per-bracket counts)
        scenarios = (
            (50, 3, 1, [0, 1, 0]),
            (51, 9, 2, [0, 1, 1]),
        )
        for point_value, fidelity, bracket_idx, expected_counts in scenarios:
            trial = create_trial_for_hb((fidelity, point_value), 0.0)
            trial_id = asha.get_id(trial, ignore_fidelity=True)

            force_observe(asha, trial)

            assert registered_per_bracket() == expected_counts
            first_rung_results = asha.brackets[bracket_idx].rungs[0]["results"]
            assert trial_id in first_rung_results
            compare_registered_trial(first_rung_results[trial_id], trial)