def test_suggest_duplicates(self, monkeypatch, asha, bracket, rung_0, rung_1, rung_2):
    """Test that sampling collisions are handled.

    The mocked sampler yields an already-registered trial first and a fresh
    one second; ``suggest`` must skip the known one and return the fresh one.
    """
    # Link the algorithm and its single bracket in both directions.
    asha.brackets = [bracket]
    bracket.asha = asha

    fidelity = 1
    already_seen = create_trial_for_hb((fidelity, 0.0))
    fresh = create_trial_for_hb((fidelity, 0.5))
    seen_key = asha.get_id(already_seen, ignore_fidelity=True)

    # Rung 0 already holds the trial the sampler will hand back first.
    bracket.rungs[0] = dict(
        n_trials=2,
        resources=1,
        results={seen_key: (0.0, already_seen)},
    )
    asha.trial_to_brackets[seen_key] = 0
    asha.register(already_seen)

    sampled = [already_seen, fresh]
    # Stub sampler: always returns both trials regardless of ``num``/``seed``.
    monkeypatch.setattr(asha.space, "sample", lambda num=1, seed=None: sampled)

    assert asha.suggest(1)[0].params == fresh.params
def test_register_bracket_multi_fidelity(self, space, b_config):
    """Check that a point is registered inside the same bracket for diff fidelity.

    The same point (identical apart from fidelity) is observed at fidelity 1
    and then at fidelity 3. It must land in rung 0 and then rung 1 of the
    first bracket, and the rung-0 entry must keep the fidelity-1 version.
    """
    asha = ASHA(space, num_brackets=3)

    value = 50
    fidelity = 1
    trial = create_trial_for_hb((fidelity, value), 0.0)
    trial_id = asha.get_id(trial, ignore_fidelity=True)
    force_observe(asha, trial)

    bracket = asha.brackets[0]

    # A rung is a dict with fixed keys (n_trials, resources, results), so
    # ``len(rung)`` is always truthy. Check the results mapping instead.
    assert len(bracket.rungs[0]["results"])
    assert trial_id in bracket.rungs[0]["results"]
    assert bracket.rungs[0]["results"][trial_id][0] == 0.0
    assert bracket.rungs[0]["results"][trial_id][1].params == trial.params

    fidelity = 3
    trial = create_trial_for_hb((fidelity, value), 0.0)
    trial_id = asha.get_id(trial, ignore_fidelity=True)
    force_observe(asha, trial)

    assert len(bracket.rungs[1]["results"])
    assert trial_id in bracket.rungs[1]["results"]
    # The id ignores fidelity, so rung 0 still holds the fidelity-1 version.
    assert bracket.rungs[0]["results"][trial_id][1].params != trial.params
    assert bracket.rungs[1]["results"][trial_id][0] == 0.0
    assert bracket.rungs[1]["results"][trial_id][1].params == trial.params
def test_promotion_with_rung_1_hit(self, asha, bracket, rung_0):
    """Test that get_candidate gives us the next best thing if point is already in rung 1.

    The trial with params (1, 0.0) is placed in rung 1 beforehand, so the
    first promotion candidate from rung 0 is expected to be (1, 1.0).
    """
    already_promoted = create_trial_for_hb((1, 0.0), None)
    bracket.asha = asha
    bracket.rungs[0] = rung_0

    entry = (already_promoted.objective.value, already_promoted)
    key = asha.get_id(already_promoted, ignore_fidelity=True)
    bracket.rungs[1]["results"][key] = entry

    candidate = bracket.get_candidates(0)[0]
    expected = create_trial_for_hb((1, 1.0), 0.0)
    assert candidate.params == expected.params
def test_suggest_promote_identic_objectives(self, asha, bracket, big_rung_0, big_rung_1):
    """Test that identic objectives are handled properly.

    Rung 0 is filled with nine trials that all share objective 0. Asking for
    two suggestions must yield two trials, both at fidelity 3.
    """
    asha.brackets = [bracket]
    bracket.asha = asha

    n_trials = 9
    resources = 1

    tied_trials = [
        create_trial_for_hb((resources, lr), objective=0)
        for lr in np.linspace(0, 8, 9)
    ]
    results = {
        t.compute_trial_hash(
            t,
            ignore_fidelity=True,
            ignore_experiment=True,
        ): (t.objective.value, t)
        for t in tied_trials
    }

    bracket.rungs[0] = dict(n_trials=n_trials, resources=resources, results=results)

    candidates = asha.suggest(2)

    assert len(candidates) == 2
    at_fidelity_3 = [c for c in candidates if c.params[asha.fidelity_index] == 3]
    assert len(at_fidelity_3) == 2
def test_bad_register(self, asha, bracket):
    """Check that a non-valid point is not registered.

    Registering a trial at fidelity 55 must raise ``IndexError`` whose
    message names the bad fidelity level.
    """
    bracket.asha = asha

    with pytest.raises(IndexError) as exc_info:
        bad_trial = create_trial_for_hb((55, 0.0), 0.0)
        bracket.register(bad_trial)

    assert "Bad fidelity level 55" in str(exc_info.value)
def test_register_corrupted_db(self, caplog, space, b_config):
    """Check that a point cannot be registered if passed in an order different from its fidelity.

    Observing fidelity 3 first must not log the wrong-bracket message.
    Observing the same point afterwards at fidelity 1 must log it.
    """
    wrong_bracket_msg = "Trial registered to wrong bracket"
    asha = ASHA(space, num_brackets=3)
    value = 50

    first = create_trial_for_hb((3, value))
    force_observe(asha, first)
    assert wrong_bracket_msg not in caplog.text

    late = create_trial_for_hb((1, value), objective=0.0)
    caplog.clear()
    force_observe(asha, late)
    assert wrong_bracket_msg in caplog.text
def test_candidate_promotion(self, asha, bracket, rung_0):
    """Test that correct point is promoted.

    With ``rung_0`` installed, the first promotion candidate from rung 0 must
    have the params of trial (1, 0.0).
    """
    bracket.asha = asha
    bracket.rungs[0] = rung_0

    candidates = bracket.get_candidates(0)
    best = candidates[0]

    expected_params = create_trial_for_hb((1, 0.0), 0.0).params
    assert best.params == expected_params
def test_candidate_promotion(self, asha: ASHA, bracket: ASHABracket, rung_0: RungDict):
    """Test that correct point is promoted.

    The fixture-provided bracket must already be owned by ``asha``. After
    installing ``rung_0``, the first rung-0 candidate must match trial (1, 0.0).
    """
    assert bracket.owner is asha
    bracket.rungs[0] = rung_0

    candidates = bracket.get_candidates(0)
    best = candidates[0]

    expected_params = create_trial_for_hb((1, 0.0), 0.0).params
    assert best.params == expected_params
def test_register_invalid_fidelity(self, space, b_config):
    """Check that a point cannot be registered if its fidelity is invalid.

    A trial at fidelity 2 is observed directly; afterwards the algorithm must
    report it as neither suggested nor observed.
    """
    asha = ASHA(space, num_brackets=3)
    value = 50
    fidelity = 2
    bogus = create_trial_for_hb((fidelity, value))

    asha.observe([bogus])

    assert not asha.has_suggested(bogus)
    assert not asha.has_observed(bogus)
def test_update_rungs_return_candidate(self, asha, bracket, rung_1):
    """Check if a valid modified candidate is returned by update_rungs.

    Promoting from rung 1 must leave the (3, 0.0) trial in rung 1's results
    and return a candidate whose ``epoch`` is 9.
    """
    bracket.asha = asha
    bracket.rungs[1] = rung_1
    reference = create_trial_for_hb((3, 0.0), 0.0)

    promoted = bracket.promote(1)
    candidate = promoted[0]

    key = asha.get_id(reference, ignore_fidelity=True)
    rung_results = bracket.rungs[1]["results"]
    assert key in rung_results
    assert rung_results[key][1].params == reference.params
    assert candidate.params["epoch"] == 9
def test_register_not_sampled(self, space, b_config, caplog):
    """Check that a point cannot be registered if it was not sampled.

    Observing a trial the algorithm never suggested must emit exactly one
    record on the ``orion.algo.hyperband`` logger, saying it is ignored.
    """
    asha = ASHA(space, num_brackets=3)
    value = 50
    fidelity = 2
    unsolicited = create_trial_for_hb((fidelity, value))

    with caplog.at_level(logging.DEBUG, logger="orion.algo.hyperband"):
        asha.observe([unsolicited])

    records = caplog.records
    assert len(records) == 1
    assert "Ignoring trial" in records[0].msg
def test_register(self, asha, bracket):
    """Check that a point is correctly registered inside a bracket.

    Registering a fidelity-1 trial directly on the bracket must store
    ``(objective, trial)`` under the fidelity-agnostic id in rung 0.
    """
    bracket.asha = asha
    trial = create_trial_for_hb((1, 0.0), 0.0)
    trial_id = asha.get_id(trial, ignore_fidelity=True)

    bracket.register(trial)

    results = bracket.rungs[0]["results"]
    # A rung is a dict with fixed keys (n_trials, resources, results), so
    # ``len(rung)`` is always truthy. The meaningful check is on ``results``.
    assert len(results)
    assert trial_id in results
    assert results[trial_id][0] == trial.objective.value
    assert results[trial_id][1].to_dict() == trial.to_dict()
def test_register(self, asha, bracket, rung_0, rung_1):
    """Check that a point is registered inside the bracket.

    Observing a fidelity-1 trial through ``asha.observe`` must store it in
    rung 0 of the only bracket, keyed by its fidelity-agnostic id.
    """
    asha.brackets = [bracket]
    bracket.asha = asha
    bracket.rungs = [rung_0, rung_1]

    trial = create_trial_for_hb((1, 0.0), 0.0)
    trial_id = asha.get_id(trial, ignore_fidelity=True)

    asha.observe([trial])

    results = bracket.rungs[0]["results"]
    # ``len(rung)`` counts the rung dict's fixed keys and is always truthy.
    # Check that the results mapping is non-empty instead.
    assert len(results)
    assert trial_id in results
    assert results[trial_id][0] == 0.0
    assert results[trial_id][1].params == trial.params
def test_register_next_bracket(self, space, b_config):
    """Check that a point is registered inside the good bracket when higher fidelity.

    A fidelity-3 trial must go to bracket 1 and a fidelity-9 trial to
    bracket 2, each in that bracket's rung 0, with bracket 0 left empty.
    """
    asha = ASHA(space, num_brackets=3)

    def total_results(bkt):
        # Number of results stored across every rung of one bracket.
        return sum(len(rung["results"]) for rung in bkt.rungs)

    # (value, fidelity, bracket that should receive it, expected per-bracket counts)
    scenarios = (
        (50, 3, 1, [0, 1, 0]),
        (51, 9, 2, [0, 1, 1]),
    )
    for value, fidelity, target, expected_counts in scenarios:
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = asha.get_id(trial, ignore_fidelity=True)
        force_observe(asha, trial)

        brackets = asha.brackets
        for idx, expected in enumerate(expected_counts):
            assert total_results(brackets[idx]) == expected

        target_results = brackets[target].rungs[0]["results"]
        assert trial_id in target_results
        compare_registered_trial(target_results[trial_id], trial)
def test_register(self, evolution, bracket, rung_0, rung_1):
    """Check that a point is registered inside the bracket.

    Observing a fidelity-1 trial through ``evolution.observe`` must store it
    in rung 0 of the only bracket, keyed by its fidelity-agnostic id.
    """
    evolution.brackets = [bracket]
    # The bracket is linked to the algorithm under both attribute names this
    # test sets.
    bracket.hyperband = evolution
    bracket.eves = evolution
    bracket.rungs = [rung_0, rung_1]

    trial = create_trial_for_hb((1, 0.0), objective=0.0)
    trial_id = evolution.get_id(trial, ignore_fidelity=True)

    evolution.observe([trial])

    results = bracket.rungs[0]["results"]
    # ``len(rung)`` counts the rung dict's fixed keys and is always truthy.
    # Check that the results mapping is non-empty instead.
    assert len(results)
    assert trial_id in results
    assert results[trial_id][0] == 0.0
    assert results[trial_id][1].params == trial.params
def test_suggest_inf_duplicates(self, monkeypatch, asha, bracket, rung_0, rung_1, rung_2):
    """Test that sampling inf collisions returns an empty list.

    The stub sampler only ever yields a trial already mapped to a bracket,
    so ``suggest`` can produce nothing new.
    """
    asha.brackets = [bracket]
    bracket.asha = asha

    fidelity = 1
    known_trial = create_trial_for_hb((fidelity, 0.0))
    asha.trial_to_brackets[asha.get_id(known_trial, ignore_fidelity=True)] = 0

    # Every call returns a new one-element list holding the same known trial.
    monkeypatch.setattr(asha.space, "sample", lambda num=1, seed=None: [known_trial])

    assert asha.suggest(1) == []
def sample(num=1, seed=None):
    """Stub sampler: return a one-element list, ignoring ``num`` and ``seed``."""
    # NOTE(review): the first coordinate is the literal string "fidelity"
    # rather than a numeric fidelity. Confirm this is intended by the
    # enclosing test, which is not visible here.
    only_trial = create_trial_for_hb(("fidelity", 0.5))
    return [only_trial]