def test_suggest_promote_identic_objectives(self, hyperband: Hyperband, bracket: HyperbandBracket):
    """Check that a rung whose trials all share the same objective still promotes.

    Nine trials with identical objectives (0) fill the first rung; asking for two
    suggestions must return two trials whose fidelity is 3.
    """
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband

    base_resources = 1
    rung_results = {}
    for lr in np.linspace(0, 8, 9):
        candidate = create_trial_for_hb((base_resources, lr), objective=0)
        key = candidate.compute_trial_hash(
            candidate,
            ignore_fidelity=True,
            ignore_experiment=True,
        )
        assert candidate.objective is not None
        rung_results[key] = (candidate.objective.value, candidate)

    bracket.rungs[0] = RungDict(
        n_trials=9, resources=base_resources, results=rung_results
    )

    promoted = hyperband.suggest(2)
    assert promoted is not None
    assert len(promoted) == 2
    at_fidelity_3 = [
        t for t in promoted if t.params[hyperband.fidelity_index] == 3
    ]
    assert len(at_fidelity_3) == 2
def test_suggest_duplicates_between_calls(self, monkeypatch, hyperband: Hyperband, bracket: HyperbandBracket):
    """Check that a trial already known to this hyperband execution is not
    suggested again by a later suggest call.
    """
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband

    epoch = 1
    already_seen = create_trial_for_hb((epoch, 0.0))
    fresh = create_trial_for_hb((epoch, 0.5))

    seen_id = hyperband.get_id(already_seen, ignore_fidelity=True)
    # Pretend ``already_seen`` was registered in rung 0 of bracket 0 earlier.
    bracket.rungs[0]["results"] = {seen_id: (0.0, already_seen)}  # type: ignore
    hyperband.trial_to_brackets[seen_id] = 0

    # The mocked sampler yields the duplicate first, then the fresh trial, then filler.
    filler = [create_trial_for_hb((epoch, i)) for i in range(8)]
    mock_samples(hyperband, [already_seen, fresh] + filler)

    suggested = hyperband.suggest(100)
    assert suggested is not None
    assert suggested[0].params == fresh.params
def test_register_bracket_multi_fidelity(self, space: Space):
    """Check that a point is registered inside the same bracket for different fidelities.

    The same ``lr`` observed at fidelity 1 and then at fidelity 3 must land in
    rung 0 and rung 1, respectively, of the first bracket.
    """
    hyperband = Hyperband(space)
    value = 50

    fidelity = 1
    trial = create_trial_for_hb((fidelity, value), 0.0)
    trial_id = hyperband.get_id(trial, ignore_fidelity=True)
    force_observe(hyperband, trial)

    assert hyperband.brackets is not None
    bracket = hyperband.brackets[0]

    # Check the results mapping: ``len`` of the RungDict itself only counts its keys.
    assert len(bracket.rungs[0]["results"])
    assert trial_id in bracket.rungs[0]["results"]
    assert bracket.rungs[0]["results"][trial_id][0] == 0.0
    assert bracket.rungs[0]["results"][trial_id][1].params == trial.params

    fidelity = 3
    trial = create_trial_for_hb((fidelity, value), 0.0)
    trial_id = hyperband.get_id(trial, ignore_fidelity=True)
    force_observe(hyperband, trial)

    assert len(bracket.rungs[1]["results"])
    assert trial_id in bracket.rungs[1]["results"]
    # Rung 0 still holds the fidelity-1 trial, not the new one.
    assert bracket.rungs[0]["results"][trial_id][1].params != trial.params
    assert bracket.rungs[1]["results"][trial_id][0] == 0.0
    assert bracket.rungs[1]["results"][trial_id][1].params == trial.params
def test_suggest_in_finite_cardinality(self):
    """Check that suggest returns None once the finite search space is exhausted."""
    space = Space()
    space.register(Integer("yolo1", "uniform", 0, 5))
    space.register(Fidelity("epoch", 1, 9, 3))
    hyperband = Hyperband(space, repetitions=1)

    # Observe yolo1 = 0..5 at fidelity 1 (presumably the whole space — TODO confirm).
    for yolo in range(6):
        force_observe(hyperband, (1, yolo), {"objective": yolo})

    assert hyperband.suggest() is None
def test_suggest_in_finite_cardinality(self):
    """Check that suggest returns None once the finite search space is exhausted."""
    space = Space()
    space.register(Integer('yolo1', 'uniform', 0, 6))
    space.register(Fidelity('epoch', 1, 9, 3))
    hyperband = Hyperband(space, repetitions=1)

    # Observe yolo1 = 0..5 at fidelity 1 through the public observe API.
    for yolo in range(6):
        hyperband.observe([(1, yolo)], [{'objective': yolo}])

    assert hyperband.suggest() is None
def test_suggest_duplicates_between_execution(self, monkeypatch, hyperband: Hyperband, budgets: list[BudgetTuple]):
    """Check that sampling collisions are handled across hyperband executions.

    A first repetition is observed, then the mocked sampler feeds the second
    repetition samples that partly collide with earlier ones; only the
    non-colliding samples must be suggested.
    """
    hyperband.repetitions = 2
    bracket = HyperbandBracket(hyperband, budgets, 1)
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband

    # First repetition: 9 trials at epoch 1, 3 at epoch 3 and 1 at epoch 9.
    for lr in range(9):
        force_observe(hyperband, create_trial_for_hb((1, lr), objective=lr))
    for lr in range(3):
        force_observe(hyperband, create_trial_for_hb((3, lr), objective=lr))
    force_observe(hyperband, create_trial_for_hb((9, 0), objective=0))

    assert not hyperband.is_done

    # lr:7 and lr:8 are already sampled in first repetition, they should not be present
    # in second repetition. Samples with lr:7 and lr:8 will be ignored.

    # (9, 0) already exists
    epoch_9_candidates = [(9, 0), (9, 2), (9, 3), (9, 10)]
    # (9, 1) -> (3, 1) already promoted in last repetition
    # (9, 3) sampled for previous bracket
    epoch_3_candidates = [(9, 1), (9, 3), (9, 4), (9, 5), (9, 11)]
    # (9, 0) -> (1, 0) already sampled in last repetition
    # (9, 8) -> (1, 8) already sampled in last repetition
    epoch_1_candidates = [(9, 0), (9, 8), (9, 12), (9, 13)]

    mocked_trials = [
        create_trial_for_hb(point)
        for point in epoch_9_candidates + epoch_3_candidates + epoch_1_candidates
    ]

    hyperband._refresh_brackets()
    mock_samples(hyperband, mocked_trials)

    suggested = hyperband.suggest(100)
    assert suggested is not None
    assert len(suggested) == 8
    assert [t.params for t in suggested] == [
        {"epoch": 9, "lr": 2},
        {"epoch": 9, "lr": 3},
        {"epoch": 9, "lr": 10},
        {"epoch": 3, "lr": 4},
        {"epoch": 3, "lr": 5},
        {"epoch": 3, "lr": 11},
        {"epoch": 1, "lr": 12},
        {"epoch": 1, "lr": 13},
    ]
def test_register_invalid_fidelity(self, space):
    """Check that observing a point whose fidelity matches no bracket raises ValueError."""
    hyperband = Hyperband(space)
    invalid_point = (2, 50)  # (fidelity, value)

    with pytest.raises(ValueError) as excinfo:
        hyperband.observe([invalid_point], [{'objective': 0.0}])

    assert 'No bracket found for point' in str(excinfo.value)
def test_register_not_sampled(self, space, caplog):
    """Check that observing a never-sampled point is ignored with a single log record."""
    hyperband = Hyperband(space)
    unsampled = (2, 50)  # (fidelity, value)

    with caplog.at_level(logging.INFO, logger="orion.algo.hyperband"):
        hyperband.observe([unsampled], [{"objective": 0.0}])

    records = caplog.records
    assert len(records) == 1
    assert "Ignoring point" in records[0].msg
def test_register_not_sampled(self, space, caplog):
    """Check that observing a never-sampled trial is ignored with a single log record."""
    hyperband = Hyperband(space)
    unsampled = create_trial_for_hb((2, 50))  # (fidelity, value)

    with caplog.at_level(logging.DEBUG, logger="orion.algo.hyperband"):
        hyperband.observe([unsampled])

    records = caplog.records
    assert len(records) == 1
    assert "Ignoring trial" in records[0].msg
def test_suggest_promote(self, hyperband: Hyperband, bracket: HyperbandBracket, rung_0: RungDict):
    """Check that the expected rung-0 trials are promoted to epoch 3 and returned."""
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband
    bracket.rungs[0] = rung_0

    promoted = hyperband.suggest(100)
    assert promoted is not None
    assert len(promoted) == 3
    expected = [{"epoch": 3, "lr": lr} for lr in range(3)]
    assert [p.params for p in promoted] == expected
def test_register_bracket_multi_fidelity(self, space):
    """Check that a point is registered inside the same bracket for different fidelities.

    The same value observed at fidelity 1 and then at fidelity 3 must land in
    rung 0 and rung 1, respectively, of the first bracket.
    """
    hyperband = Hyperband(space)
    value = 50
    # The id is built from the value only, so both observations share this hash.
    point_hash = hashlib.md5(str([value]).encode("utf-8")).hexdigest()

    fidelity = 1
    point = (fidelity, value)
    force_observe(hyperband, point, {"objective": 0.0})

    bracket = hyperband.brackets[0]
    # Check the results mapping: ``len`` of the rung dict itself only counts its keys.
    assert len(bracket.rungs[0]["results"])
    assert point_hash in bracket.rungs[0]["results"]
    assert (0.0, point) == bracket.rungs[0]["results"][point_hash]

    fidelity = 3
    point = [fidelity, value]
    force_observe(hyperband, point, {"objective": 0.0})

    # The higher-fidelity observation goes to rung 1 of the same bracket.
    assert len(bracket.rungs[1]["results"])
    assert point_hash in bracket.rungs[1]["results"]
    assert (0.0, point) != bracket.rungs[0]["results"][point_hash]
    assert (0.0, point) == bracket.rungs[1]["results"][point_hash]
def test_register_next_bracket(self, space):
    """Check that a point is registered in the bracket matching its higher fidelity."""
    hyperband = Hyperband(space)

    def results_in(bracket_idx):
        # Total number of registered results across all rungs of one bracket.
        return sum(
            len(rung["results"]) for rung in hyperband.brackets[bracket_idx].rungs
        )

    def hash_of(val):
        return hashlib.md5(str([val]).encode("utf-8")).hexdigest()

    # Fidelity 3 goes to the second bracket.
    point = (3, 50)
    point_hash = hash_of(50)
    force_observe(hyperband, point, {"objective": 0.0})

    assert [results_in(i) for i in range(3)] == [0, 1, 0]
    assert point_hash in hyperband.brackets[1].rungs[0]["results"]
    assert (0.0, point) == hyperband.brackets[1].rungs[0]["results"][point_hash]

    # Fidelity 9 goes to the third bracket.
    point = (9, 51)
    point_hash = hash_of(51)
    force_observe(hyperband, point, {"objective": 0.0})

    assert [results_in(i) for i in range(3)] == [0, 1, 1]
    assert point_hash in hyperband.brackets[2].rungs[0]["results"]
    assert (0.0, point) == hyperband.brackets[2].rungs[0]["results"][point_hash]
def test_suggest_duplicates_one_call(self, monkeypatch, hyperband: Hyperband, bracket: HyperbandBracket):
    """Test that same points are not allowed in the same suggest call of
    the same hyperband execution.
    """
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband

    # lr=1.0 appears twice, and the whole list is fed twice to the mocked sampler.
    zhe_point = [
        create_trial_for_hb(point)
        for point in [(1, 0.0), (1, 1.0), (1, 1.0), (1, 2.0)]
    ]
    mock_samples(hyperband, zhe_point * 2)

    zhe_samples = hyperband.suggest(100)
    assert zhe_samples is not None
    assert zhe_samples[0].params["lr"] == 0.0
    assert zhe_samples[1].params["lr"] == 1.0
    assert zhe_samples[2].params["lr"] == 2.0

    # Second call: lr 0.0, 1.0 and 2.0 were suggested above, so the first two
    # suggestions must be the unseen values 5.0 and 4.0.
    mock_samples(
        hyperband,
        [
            create_trial_for_hb(point)
            for point in [
                (3, 0.0),
                (3, 1.0),
                (3, 1.0),
                (3, 2.0),
                (3, 5.0),
                (3, 4.0),
            ]
        ],
    )
    hyperband.trial_to_brackets[
        hyperband.get_id(create_trial_for_hb((1, 0.0)), ignore_fidelity=True)
    ] = 0

    zhe_samples = hyperband.suggest(100)
    assert zhe_samples is not None
    assert zhe_samples[0].params["lr"] == 5.0
    assert zhe_samples[1].params["lr"] == 4.0
def test_suggest_opt_out(
    self,
    hyperband: Hyperband,
    bracket: HyperbandBracket,
    rung_0: RungDict,
    rung_1: RungDict,
    rung_2: RungDict,
):
    """Check that Hyperband suggests nothing while a rung is not ready."""
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband
    bracket.rungs[0] = rung_0

    # Clear the objective of one rung-0 trial so that the rung has a pending result.
    first_id = next(iter(rung_0["results"]))
    _, pending_trial = rung_0["results"][first_id]
    rung_0["results"][first_id] = (None, pending_trial)

    assert hyperband.suggest(100) == []
def test_suggest_in_finite_cardinality(self):
    """Test that suggest returns an empty list once the search space is exhausted.

    Values 0..5 of ``yolo1`` are observed at epoch 1 first (assumed to cover the
    whole space — TODO confirm ``Integer`` bound semantics).
    """
    space = Space()
    space.register(Integer("yolo1", "uniform", 0, 5))
    space.register(Fidelity("epoch", 1, 9, 3))
    hyperband = Hyperband(space, repetitions=1)

    for i in range(6):
        force_observe(
            hyperband,
            create_trial(
                (1, i),
                names=("epoch", "yolo1"),
                types=("fidelity", "integer"),
                results={"objective": i},
            ),
        )

    assert hyperband.suggest(100) == []
def test_suggest_inf_duplicates(
    self,
    monkeypatch,
    hyperband: Hyperband,
    bracket: HyperbandBracket,
    rung_0: RungDict,
    rung_1: RungDict,
    rung_2: RungDict,
):
    """Test that repeated sampling collisions make suggest return an empty list.

    The mocked sampler only yields a trial that is already attributed to a
    bracket, so no new trial can be suggested.
    """
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband

    zhe_trial = create_trial_for_hb(("fidelity", 0.0))
    hyperband.trial_to_brackets[hyperband.get_id(zhe_trial, ignore_fidelity=True)] = 0

    mock_samples(hyperband, [zhe_trial] * 2)

    assert hyperband.suggest(100) == []
def test_suggest_new(
    self,
    monkeypatch,
    hyperband: Hyperband,
    bracket: HyperbandBracket,
    rung_0: RungDict,
    rung_1: RungDict,
    rung_2: RungDict,
):
    """Check that freshly sampled trials are returned at the lowest fidelity (epoch 1)."""
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband

    sampled = [create_trial_for_hb(("fidelity", lr)) for lr in range(10)]
    mock_samples(hyperband, sampled)

    suggested = hyperband.suggest(100)
    assert suggested is not None
    assert suggested[0].params == {"epoch": 1.0, "lr": 0}
    assert suggested[1].params == {"epoch": 1.0, "lr": 1}
def test_register(
    self,
    hyperband: Hyperband,
    bracket: HyperbandBracket,
    rung_0: RungDict,
    rung_1: RungDict,
):
    """Check that an observed trial is registered inside rung 0 of the bracket."""
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband
    bracket.rungs = [rung_0, rung_1]

    trial = create_trial_for_hb((1, 0.0), 0.0)
    trial_id = hyperband.get_id(trial, ignore_fidelity=True)

    hyperband.observe([trial])

    results = bracket.rungs[0]["results"]
    # Check the results mapping: ``len`` of the RungDict itself only counts its keys.
    assert len(results)
    assert trial_id in results
    assert results[trial_id][0] == 0.0
    assert results[trial_id][1].params == trial.params
def test_register_invalid_fidelity(self, space):
    """Check that force-observing a point whose fidelity matches no bracket raises ValueError."""
    hyperband = Hyperband(space)
    invalid_point = (2, 50)  # (fidelity, value)

    with pytest.raises(ValueError) as excinfo:
        force_observe(hyperband, invalid_point, {"objective": 0.0})

    assert "No bracket found for point" in str(excinfo.value)
def test_update_rungs_return_candidate(self, hyperband: Hyperband, bracket: HyperbandBracket, rung_1: RungDict):
    """Check that promoting rung 1 returns a candidate at epoch 9 and leaves the
    original rung-1 trial in place.
    """
    assert bracket.owner is hyperband
    bracket.rungs[1] = rung_1
    expected_trial = create_trial_for_hb((3, 0.0), 0.0)

    promoted = bracket.promote(1)

    expected_id = hyperband.get_id(expected_trial, ignore_fidelity=True)
    rung_1_results = bracket.rungs[1]["results"]
    assert expected_id in rung_1_results
    assert rung_1_results[expected_id][1].params == expected_trial.params
    assert promoted[0].params["epoch"] == 9
def test_register_next_bracket(self, space: Space):
    """Check that a trial is registered in the bracket matching its higher fidelity."""
    hyperband = Hyperband(space)

    def results_per_bracket():
        # Number of registered results in each of the first three brackets.
        return [
            sum(len(rung["results"]) for rung in hyperband.brackets[i].rungs)
            for i in range(3)
        ]

    # Fidelity 3 goes to the second bracket.
    first = create_trial_for_hb((3, 50), 0.0)
    first_id = hyperband.get_id(first, ignore_fidelity=True)
    force_observe(hyperband, first)

    assert hyperband.brackets is not None
    assert results_per_bracket() == [0, 1, 0]
    assert first_id in hyperband.brackets[1].rungs[0]["results"]
    compare_registered_trial(
        hyperband.brackets[1].rungs[0]["results"][first_id], first
    )

    # Fidelity 9 goes to the third bracket.
    second = create_trial_for_hb((9, 51), 0.0)
    second_id = hyperband.get_id(second, ignore_fidelity=True)
    force_observe(hyperband, second)

    assert results_per_bracket() == [0, 1, 1]
    assert second_id in hyperband.brackets[2].rungs[0]["results"]
    compare_registered_trial(
        hyperband.brackets[2].rungs[0]["results"][second_id], second
    )
def test_register(self, hyperband: Hyperband, bracket: HyperbandBracket):
    """Check that a trial is correctly registered inside rung 0 of a bracket."""
    assert bracket.owner is hyperband
    trial = create_trial_for_hb((1, 0.0), 0.0)
    trial_id = hyperband.get_id(trial, ignore_fidelity=True)

    bracket.register(trial)

    results = bracket.rungs[0]["results"]
    # Check the results mapping: ``len`` of the RungDict itself only counts its keys.
    assert len(results)
    assert trial_id in results
    assert trial.objective is not None
    assert results[trial_id][0] == trial.objective.value
    assert results[trial_id][1].to_dict() == trial.to_dict()
def test_register_corrupted_db(self, caplog, space):
    """Check that observing a trial at fidelity 1 after fidelity 3 logs a wrong-bracket message."""
    hyperband = Hyperband(space)
    value = 50
    message = "Trial registered to wrong bracket"

    force_observe(hyperband, create_trial_for_hb((3, value)))
    assert message not in caplog.text

    lower_fidelity_trial = create_trial_for_hb((1, value))
    caplog.clear()
    force_observe(hyperband, lower_fidelity_trial)
    assert message in caplog.text
def test_register_corrupted_db(self, caplog, space):
    """Check that force-observing a point at fidelity 1 after fidelity 3 logs a wrong-bracket message."""
    hyperband = Hyperband(space)
    value = 50
    message = "Point registered to wrong bracket"

    force_observe(hyperband, (3, value), {"objective": 0.0})
    assert message not in caplog.text

    lower_fidelity_point = [1, value]
    caplog.clear()
    force_observe(hyperband, lower_fidelity_point, {"objective": 0.0})
    assert message in caplog.text
def test_promotion_with_rung_1_hit(self, hyperband: Hyperband, bracket: HyperbandBracket, rung_0: RungDict):
    """Check that get_candidates skips a trial already present in rung 1 and
    returns the next one instead.
    """
    already_promoted = create_trial_for_hb((1, 0.0), None)
    assert bracket.owner is hyperband
    bracket.rungs[0] = rung_0

    assert already_promoted.objective is not None
    promoted_id = hyperband.get_id(already_promoted, ignore_fidelity=True)
    bracket.rungs[1]["results"][promoted_id] = (
        already_promoted.objective.value,
        already_promoted,
    )

    candidates = bracket.get_candidates(0)
    assert candidates[0].params == create_trial_for_hb((1, 1), 0.0).params
def test_register_invalid_fidelity(self, space: Space):
    """Check that a trial whose fidelity matches no bracket is neither suggested nor observed."""
    hyperband = Hyperband(space)
    invalid_trial = create_trial_for_hb((2, 50))  # (fidelity, value)

    hyperband.observe([invalid_trial])

    assert not hyperband.has_suggested(invalid_trial)
    assert not hyperband.has_observed(invalid_trial)
def test_register_corrupted_db(self, caplog, space):
    """Check that observing a point at fidelity 1 after fidelity 3 logs a wrong-bracket message."""
    hyperband = Hyperband(space)
    value = 50
    message = 'Point registered to wrong bracket'

    hyperband.observe([(3, value)], [{'objective': 0.0}])
    assert message not in caplog.text

    lower_fidelity_point = [1, value]
    caplog.clear()
    hyperband.observe([lower_fidelity_point], [{'objective': 0.0}])
    assert message in caplog.text
def force_observe(hyperband: Hyperband, trial: Trial) -> None:
    """Register ``trial`` on ``hyperband`` and observe it, attributing it to a bracket first.

    If the trial (ignoring fidelity) is not yet mapped in
    ``hyperband.trial_to_brackets``, it is assigned to the first bracket whose
    rung 0 ``resources`` equals the trial's fidelity.

    Raises:
        IndexError: if no bracket's rung 0 has ``resources`` equal to the trial's fidelity.
    """
    hyperband.register(trial)

    id_wo_fidelity = hyperband.get_id(trial, ignore_fidelity=True)
    bracket_index = hyperband.trial_to_brackets.get(id_wo_fidelity)

    if bracket_index is None:
        fidelity = flatten(trial.params)[hyperband.fidelity_index]
        assert hyperband.brackets is not None
        matching = [
            i
            for i, bracket in enumerate(hyperband.brackets)
            if bracket.rungs[0]["resources"] == fidelity
        ]
        if not matching:
            # Same exception type as before, but with an explicit diagnostic.
            raise IndexError(
                f"No bracket has a rung 0 with resources={fidelity!r}"
            )
        bracket_index = matching[0]

    hyperband.trial_to_brackets[id_wo_fidelity] = bracket_index

    hyperband.observe([trial])
def test_get_id_multidim(self):
    """Check ids of points whose non-fidelity dimension has shape > 1."""
    space = Space()
    space.register(Fidelity('epoch', 1, 9, 3))
    space.register(Real('lr', 'uniform', 0, 1, shape=2))
    hyperband = Hyperband(space)
    get_id = hyperband.get_id

    # By default only the non-fidelity coordinates contribute to the id.
    assert get_id(['whatever', [1, 1]]) == get_id(['is here', [1, 1]])
    assert get_id(['whatever', [1, 1]]) != get_id(['is here', [2, 2]])

    # With ignore_fidelity=False the fidelity coordinate contributes as well.
    assert (get_id(['whatever', [1, 1]], ignore_fidelity=False)
            != get_id(['is here', [1, 1]], ignore_fidelity=False))
    assert (get_id(['whatever', [1, 1]], ignore_fidelity=False)
            != get_id(['is here', [2, 2]], ignore_fidelity=False))
    assert (get_id(['same', [1, 1]], ignore_fidelity=False)
            == get_id(['same', [1, 1]], ignore_fidelity=False))

    # The two modes never produce the same id for the same point.
    assert (get_id(['same', [1, 1]], ignore_fidelity=False)
            != get_id(['same', [1, 1]]))
def test_get_id(self, space):
    """Check ids of points with and without the fidelity coordinate."""
    hyperband = Hyperband(space)
    get_id = hyperband.get_id

    # By default only the non-fidelity coordinate contributes to the id.
    assert get_id(['whatever', 1]) == get_id(['is here', 1])
    assert get_id(['whatever', 1]) != get_id(['is here', 2])

    # With ignore_fidelity=False the fidelity coordinate contributes as well.
    assert (get_id(['whatever', 1], ignore_fidelity=False)
            != get_id(['is here', 1], ignore_fidelity=False))
    assert (get_id(['whatever', 1], ignore_fidelity=False)
            != get_id(['is here', 2], ignore_fidelity=False))
    assert (get_id(['same', 1], ignore_fidelity=False)
            == get_id(['same', 1], ignore_fidelity=False))

    # The two modes never produce the same id for the same point.
    assert (get_id(['same', 1], ignore_fidelity=False)
            != get_id(['same', 1]))