Beispiel #1
0
    def test_suggest_duplicates(self, monkeypatch, asha, bracket, rung_0,
                                rung_1, rung_2):
        """Check that ``suggest`` skips a sampled trial that is already known.

        The space sampler is patched to return a trial that is already
        registered in rung 0, followed by a fresh one. Only the fresh trial
        must come back from ``suggest``.
        """
        asha.brackets = [bracket]
        bracket.asha = asha

        fidelity = 1
        existing = create_trial_for_hb((fidelity, 0.0))
        fresh = create_trial_for_hb((fidelity, 0.5))

        existing_key = asha.get_id(existing, ignore_fidelity=True)
        bracket.rungs[0] = {
            "n_trials": 2,
            "resources": 1,
            "results": {existing_key: (0.0, existing)},
        }
        asha.trial_to_brackets[existing_key] = 0

        asha.register(existing)

        # The same list object is handed back on every call, like a sampler
        # that keeps colliding on the first point.
        sampled = [existing, fresh]

        def fake_sample(num=1, seed=None):
            return sampled

        monkeypatch.setattr(asha.space, "sample", fake_sample)

        suggested = asha.suggest(1)
        assert suggested[0].params == fresh.params
Beispiel #2
0
    def test_register_bracket_multi_fidelity(self, space, b_config):
        """Check that a point is registered inside the same bracket for diff fidelity.

        The point is first observed at fidelity 1 and lands in rung 0 of the
        first bracket. The same point observed again at fidelity 3 lands in
        rung 1 of that bracket, and the rung-0 entry keeps the fidelity-1
        trial.
        """
        asha = ASHA(space, num_brackets=3)

        value = 50
        fidelity = 1
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = asha.get_id(trial, ignore_fidelity=True)

        force_observe(asha, trial)

        bracket = asha.brackets[0]

        # A rung is a dict with n_trials/resources/results keys, so len() of
        # the rung itself is always truthy. Check the results mapping instead.
        assert len(bracket.rungs[0]["results"])
        assert trial_id in bracket.rungs[0]["results"]
        assert bracket.rungs[0]["results"][trial_id][0] == 0.0
        assert bracket.rungs[0]["results"][trial_id][1].params == trial.params

        fidelity = 3
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = asha.get_id(trial, ignore_fidelity=True)

        force_observe(asha, trial)

        assert len(bracket.rungs[1]["results"])
        assert trial_id in bracket.rungs[1]["results"]
        # The fidelity-3 observation must not overwrite the rung-0 entry.
        assert bracket.rungs[0]["results"][trial_id][1].params != trial.params
        assert bracket.rungs[1]["results"][trial_id][0] == 0.0
        assert bracket.rungs[1]["results"][trial_id][1].params == trial.params
Beispiel #3
0
    def test_promotion_with_rung_1_hit(self, asha, bracket, rung_0):
        """Check that get_candidates skips a point already present in rung 1.

        The point that would normally be promoted first from rung 0 is placed
        in rung 1 beforehand. The candidate returned must therefore be the
        next one.
        """
        already_promoted = create_trial_for_hb((1, 0.0), None)
        bracket.asha = asha
        bracket.rungs[0] = rung_0

        promoted_key = asha.get_id(already_promoted, ignore_fidelity=True)
        promoted_entry = (already_promoted.objective.value, already_promoted)
        bracket.rungs[1]["results"][promoted_key] = promoted_entry

        candidate = bracket.get_candidates(0)[0]

        expected = create_trial_for_hb((1, 1.0), 0.0)
        assert candidate.params == expected.params
Beispiel #4
0
    def test_suggest_promote_identic_objectives(self, asha, bracket,
                                                big_rung_0, big_rung_1):
        """Check that trials sharing the same objective are promoted correctly.

        Rung 0 is filled with nine trials that all have objective 0. Asking
        for two suggestions must yield two points at fidelity 3.
        """
        asha.brackets = [bracket]
        bracket.asha = asha

        n_trials = 9
        resources = 1

        def make_entry(param):
            # Build one (key, (objective, trial)) pair for the rung results.
            trial = create_trial_for_hb((resources, param), objective=0)
            key = trial.compute_trial_hash(
                trial,
                ignore_fidelity=True,
                ignore_experiment=True,
            )
            return key, (trial.objective.value, trial)

        results = dict(make_entry(param) for param in np.linspace(0, 8, 9))

        bracket.rungs[0] = {
            "n_trials": n_trials,
            "resources": resources,
            "results": results,
        }

        candidates = asha.suggest(2)

        assert len(candidates) == 2
        promoted = [
            candidate for candidate in candidates
            if candidate.params[asha.fidelity_index] == 3
        ]
        assert len(promoted) == 2
Beispiel #5
0
    def test_bad_register(self, asha, bracket):
        """Check that registering a trial at an unknown fidelity is rejected.

        Fidelity 55 must make ``bracket.register`` raise an ``IndexError``
        whose message names the bad level.
        """
        bracket.asha = asha

        with pytest.raises(IndexError) as exc_info:
            bad_trial = create_trial_for_hb((55, 0.0), 0.0)
            bracket.register(bad_trial)

        message = str(exc_info.value)
        assert "Bad fidelity level 55" in message
Beispiel #6
0
    def test_register_corrupted_db(self, caplog, space, b_config):
        """Check that a point cannot be registered out of fidelity order.

        The point is observed at fidelity 3 first and then at fidelity 1.
        Only the second observation must log the "wrong bracket" message.
        """
        asha = ASHA(space, num_brackets=3)
        expected_log = "Trial registered to wrong bracket"
        value = 50

        high_fidelity_trial = create_trial_for_hb((3, value))
        force_observe(asha, high_fidelity_trial)
        assert expected_log not in caplog.text

        low_fidelity_trial = create_trial_for_hb((1, value), objective=0.0)
        caplog.clear()
        force_observe(asha, low_fidelity_trial)
        assert expected_log in caplog.text
Beispiel #7
0
    def test_candidate_promotion(self, asha, bracket, rung_0):
        """Check that the expected rung-0 point is the first promotion candidate."""
        bracket.asha = asha
        bracket.rungs[0] = rung_0

        first_candidate = bracket.get_candidates(0)[0]
        expected_params = create_trial_for_hb((1, 0.0), 0.0).params

        assert first_candidate.params == expected_params
Beispiel #8
0
    def test_candidate_promotion(self, asha: ASHA, bracket: ASHABracket,
                                 rung_0: RungDict):
        """Check that the expected rung-0 point is the first promotion candidate.

        The bracket fixture is expected to already be owned by ``asha``.
        """
        assert bracket.owner is asha
        bracket.rungs[0] = rung_0

        first_candidate = bracket.get_candidates(0)[0]
        expected_params = create_trial_for_hb((1, 0.0), 0.0).params

        assert first_candidate.params == expected_params
Beispiel #9
0
    def test_register_invalid_fidelity(self, space, b_config):
        """Check that a trial with an invalid fidelity is not registered.

        After ``observe`` is called with the trial, the algorithm must report
        it as neither suggested nor observed.
        """
        asha = ASHA(space, num_brackets=3)

        invalid_trial = create_trial_for_hb((2, 50))
        asha.observe([invalid_trial])

        assert not asha.has_suggested(invalid_trial)
        assert not asha.has_observed(invalid_trial)
Beispiel #10
0
    def test_update_rungs_return_candidate(self, asha, bracket, rung_1):
        """Check that promoting from rung 1 yields a candidate at epoch 9.

        The source trial must stay registered in rung 1 after the promotion.
        """
        bracket.asha = asha
        bracket.rungs[1] = rung_1
        source = create_trial_for_hb((3, 0.0), 0.0)

        promoted = bracket.promote(1)[0]

        source_id = asha.get_id(source, ignore_fidelity=True)
        rung_results = bracket.rungs[1]["results"]
        assert source_id in rung_results
        assert rung_results[source_id][1].params == source.params
        assert promoted.params["epoch"] == 9
Beispiel #11
0
    def test_register_not_sampled(self, space, b_config, caplog):
        """Check that observing a trial that was never sampled is ignored.

        Exactly one log record is expected, and it must say the trial is
        being ignored.
        """
        asha = ASHA(space, num_brackets=3)
        unsampled = create_trial_for_hb((2, 50))

        with caplog.at_level(logging.DEBUG, logger="orion.algo.hyperband"):
            asha.observe([unsampled])

        records = caplog.records
        assert len(records) == 1
        assert "Ignoring trial" in records[0].msg
Beispiel #12
0
    def test_register(self, asha, bracket):
        """Check that a point is correctly registered inside a bracket.

        Registering a fidelity-1 trial must store ``(objective, trial)`` under
        the trial's fidelity-agnostic id in rung 0.
        """
        bracket.asha = asha
        trial = create_trial_for_hb((1, 0.0), 0.0)
        trial_id = asha.get_id(trial, ignore_fidelity=True)

        bracket.register(trial)

        results = bracket.rungs[0]["results"]
        # A rung is a dict with n_trials/resources/results keys, so len() of
        # the rung itself is always truthy. Check the results mapping instead.
        assert len(results)
        assert trial_id in results
        entry = results[trial_id]
        assert entry[0] == trial.objective.value
        assert entry[1].to_dict() == trial.to_dict()
Beispiel #13
0
    def test_register(self, asha, bracket, rung_0, rung_1):
        """Check that a point is registered inside the bracket.

        Observing a fidelity-1 trial through ``asha.observe`` must store
        ``(objective, trial)`` under its fidelity-agnostic id in rung 0.
        """
        asha.brackets = [bracket]
        bracket.asha = asha
        bracket.rungs = [rung_0, rung_1]
        trial = create_trial_for_hb((1, 0.0), 0.0)
        trial_id = asha.get_id(trial, ignore_fidelity=True)

        asha.observe([trial])

        results = bracket.rungs[0]["results"]
        # A rung is a dict with n_trials/resources/results keys, so len() of
        # the rung itself is always truthy. Check the results mapping instead.
        assert len(results)
        assert trial_id in results
        assert results[trial_id][0] == 0.0
        assert results[trial_id][1].params == trial.params
Beispiel #14
0
    def test_register_next_bracket(self, space, b_config):
        """Check that a point is registered in the bracket matching its fidelity.

        A fidelity-3 trial must land in rung 0 of bracket 1, and a fidelity-9
        trial in rung 0 of bracket 2. Bracket 0 stays empty throughout.
        """
        asha = ASHA(space, num_brackets=3)

        def n_registered(bracket_index):
            # Total number of results across all rungs of one bracket.
            rungs = asha.brackets[bracket_index].rungs
            return sum(len(rung["results"]) for rung in rungs)

        def observe_and_check(fidelity, value, expected_counts, target):
            # Observe one trial, then verify per-bracket counts and that the
            # trial sits in rung 0 of the ``target`` bracket.
            trial = create_trial_for_hb((fidelity, value), 0.0)
            trial_id = asha.get_id(trial, ignore_fidelity=True)

            force_observe(asha, trial)

            assert [n_registered(i) for i in range(3)] == expected_counts
            target_results = asha.brackets[target].rungs[0]["results"]
            assert trial_id in target_results
            compare_registered_trial(target_results[trial_id], trial)

        observe_and_check(3, 50, [0, 1, 0], 1)
        observe_and_check(9, 51, [0, 1, 1], 2)
Beispiel #15
0
    def test_register(self, evolution, bracket, rung_0, rung_1):
        """Check that a point is registered inside the bracket.

        Observing a fidelity-1 trial through ``evolution.observe`` must store
        ``(objective, trial)`` under its fidelity-agnostic id in rung 0.
        """
        evolution.brackets = [bracket]
        bracket.hyperband = evolution
        bracket.eves = evolution
        bracket.rungs = [rung_0, rung_1]
        trial = create_trial_for_hb((1, 0.0), objective=0.0)
        trial_id = evolution.get_id(trial, ignore_fidelity=True)

        evolution.observe([trial])

        results = bracket.rungs[0]["results"]
        # A rung is a dict with n_trials/resources/results keys, so len() of
        # the rung itself is always truthy. Check the results mapping instead.
        assert len(results)
        assert trial_id in results
        assert results[trial_id][0] == 0.0
        assert results[trial_id][1].params == trial.params
Beispiel #16
0
    def test_suggest_inf_duplicates(self, monkeypatch, asha, bracket, rung_0,
                                    rung_1, rung_2):
        """Check that endless sampling collisions make ``suggest`` return ``[]``.

        The space sampler is patched to always return the same trial, which is
        already mapped to a bracket, so no new point can be suggested.
        """
        asha.brackets = [bracket]
        bracket.asha = asha

        known = create_trial_for_hb((1, 0.0))
        known_id = asha.get_id(known, ignore_fidelity=True)
        asha.trial_to_brackets[known_id] = 0

        monkeypatch.setattr(asha.space, "sample",
                            lambda num=1, seed=None: [known])

        assert asha.suggest(1) == []
Beispiel #17
0
 def sample(num=1, seed=None):
     """Stub sampler: ignore the arguments and return one fixed trial."""
     fixed_trial = create_trial_for_hb(("fidelity", 0.5))
     return [fixed_trial]