def test_candidate_promotion(
    self, hyperband: Hyperband, bracket: HyperbandBracket, rung_0: RungDict
):
    """Check that `get_candidates(0)` offers the lr=0.0 epoch-1 trial first."""
    assert bracket.owner is hyperband
    bracket.rungs[0] = rung_0

    candidates = bracket.get_candidates(0)

    expected = create_trial_for_hb((1, 0.0), 0.0)
    assert candidates[0].params == expected.params
def test_bad_register(self, hyperband: Hyperband, bracket: HyperbandBracket):
    """Registering a trial at an unknown fidelity level must raise IndexError."""
    assert bracket.owner is hyperband

    with pytest.raises(IndexError) as exc_info:
        bracket.register(create_trial_for_hb((55, 0.0), 0.0))

    # The error message must name the offending fidelity level.
    assert "Bad fidelity level 55" in str(exc_info.value)
def test_get_trial_max_resource(
    self,
    hyperband: Hyperband,
    bracket: HyperbandBracket,
    rung_0: RungDict,
    rung_1: RungDict,
    rung_2: RungDict,
):
    """Check `get_trial_max_resource` while rungs are filled one at a time.

    The lr=0.0 trial reaches resource 1, then 3, then 9 as rungs 0, 1 and 2
    are installed; the lr=8.0 trial stays at resource 1 throughout.
    """
    assert bracket.owner is hyperband

    def max_resource_for(lr):
        # Every probe uses an epoch-1 trial; only the lr varies.
        return bracket.get_trial_max_resource(trial=create_trial_for_hb((1, lr)))

    bracket.rungs[0] = rung_0
    assert max_resource_for(0.0) == 1
    assert max_resource_for(8.0) == 1

    bracket.rungs[1] = rung_1
    assert max_resource_for(0.0) == 3
    assert max_resource_for(8.0) == 1

    bracket.rungs[2] = rung_2
    assert max_resource_for(0.0) == 9
    assert max_resource_for(8.0) == 1
def test_suggest_duplicates_between_execution(self, monkeypatch, hyperband, budgets):
    """Samples already seen in one repetition must not be re-suggested in the next.

    The first repetition is fed through `force_observe`; the mocked sampler then
    returns candidates for each bracket, some of which collide with earlier
    samples and must be skipped.
    """
    hyperband.repetitions = 2
    bracket = HyperbandBracket(hyperband, budgets, 1)
    hyperband.brackets = [bracket]
    # NOTE(review): the other tests rely on `bracket.owner`; confirm this
    # attribute is still read anywhere.
    bracket.hyperband = hyperband

    # Observed results for the first repetition: 9 trials at epoch 1,
    # 3 at epoch 3 and 1 at epoch 9.
    for lr in range(9):
        force_observe(hyperband, create_trial_for_hb((1, lr), objective=lr))
    for lr in range(3):
        force_observe(hyperband, create_trial_for_hb((3, lr), objective=lr))
    force_observe(hyperband, create_trial_for_hb((9, 0), objective=0))

    assert not hyperband.is_done

    # Epoch-9 bracket: (9, 0) was already observed above.
    epoch_9_candidates = [(9, 0), (9, 2), (9, 3), (9, 10)]
    # Epoch-3 bracket: lr=1 already reached epoch 3 in the previous
    # repetition; lr=3 has just been sampled for the epoch-9 bracket.
    epoch_3_candidates = [(9, 1), (9, 3), (9, 4), (9, 5), (9, 11)]
    # Epoch-1 bracket: lr=0 and lr=8 were already sampled at epoch 1.
    epoch_1_candidates = [(9, 0), (9, 8), (9, 12), (9, 13)]

    mocked_trials = [
        create_trial_for_hb(point)
        for point in epoch_9_candidates + epoch_3_candidates + epoch_1_candidates
    ]

    hyperband._refresh_brackets()
    mock_samples(hyperband, mocked_trials)
    suggested = hyperband.suggest(100)

    expected_params = [
        {"epoch": 9, "lr": 2},
        {"epoch": 9, "lr": 3},
        {"epoch": 9, "lr": 10},
        {"epoch": 3, "lr": 4},
        {"epoch": 3, "lr": 5},
        {"epoch": 3, "lr": 11},
        {"epoch": 1, "lr": 12},
        {"epoch": 1, "lr": 13},
    ]
    assert len(suggested) == 8
    for trial, params in zip(suggested, expected_params):
        assert trial.params == params
def test_repr(
    self,
    bracket: HyperbandBracket,
    rung_0: RungDict,
    rung_1: RungDict,
    rung_2: RungDict,
):
    """`str()` of a bracket shows its rung resources and its repetition id."""
    for index, rung in enumerate((rung_0, rung_1, rung_2)):
        bracket.rungs[index] = rung

    expected = "HyperbandBracket(resource=[1, 3, 9], repetition id=1)"
    assert str(bracket) == expected
def test_update_rungs_return_candidate(
    self, hyperband: Hyperband, bracket: HyperbandBracket, rung_1: RungDict
):
    """Check that `promote(1)` returns a candidate aimed at the next rung.

    After promotion, the epoch-3 trial with lr=0.0 must still be recorded in
    rung 1 with unchanged params. The first returned candidate must carry
    epoch 9.
    """
    assert bracket.owner is hyperband
    bracket.rungs[1] = rung_1
    trial = create_trial_for_hb((3, 0.0), 0.0)

    candidates = bracket.promote(1)

    trial_id = hyperband.get_id(trial, ignore_fidelity=True)
    assert trial_id in bracket.rungs[1]["results"]
    assert bracket.rungs[1]["results"][trial_id][1].params == trial.params
    assert candidates[0].params["epoch"] == 9
def test_register(self, hyperband: Hyperband, bracket: HyperbandBracket):
    """Check that a point is correctly registered inside a bracket.

    Registering an epoch-1 trial must add exactly one entry to rung 0's
    results, keyed by the fidelity-agnostic id and holding
    ``(objective value, trial)``.
    """
    assert bracket.owner is hyperband
    trial = create_trial_for_hb((1, 0.0), 0.0)
    trial_id = hyperband.get_id(trial, ignore_fidelity=True)
    n_before = len(bracket.rungs[0]["results"])

    bracket.register(trial)

    # Measure the results mapping, not the RungDict itself: the RungDict
    # always has its fixed keys, so its length can never be falsy.
    results = bracket.rungs[0]["results"]
    assert len(results) == n_before + 1
    assert trial_id in results
    assert trial.objective is not None
    assert results[trial_id][0] == trial.objective.value
    assert results[trial_id][1].to_dict() == trial.to_dict()
def test_no_promotion_if_not_completed(
    self, hyperband: Hyperband, bracket: HyperbandBracket, rung_0: RungDict
):
    """Test that `get_candidates(0)` raises TypeError when no objective is set.

    Every entry of rung 0 has its objective replaced by ``None`` while the
    trial is kept. The test pins the current behavior: a TypeError is raised
    rather than an empty candidate list being returned.
    """
    assert bracket.owner is hyperband
    bracket.rungs[0] = rung_0
    rung = bracket.rungs[0]["results"]

    # Only values of existing keys are reassigned (the dict size never
    # changes), so updating while iterating is safe.
    for trial_id, (_, trial) in rung.items():
        rung[trial_id] = (None, trial)

    with pytest.raises(TypeError):
        bracket.get_candidates(0)
def test_no_promotion_when_rung_full(
    self,
    hyperband: Hyperband,
    bracket: HyperbandBracket,
    rung_0: RungDict,
    rung_1: RungDict,
):
    """Test that `get_candidates(0)` returns an empty list when rung 1 is full."""
    assert bracket.owner is hyperband
    bracket.rungs[0] = rung_0
    bracket.rungs[1] = rung_1

    points = bracket.get_candidates(0)

    assert points == []
def test_suggest_promote_identic_objectives(
    self, hyperband: Hyperband, bracket: HyperbandBracket
):
    """Trials that share the same objective value can still be promoted.

    Rung 0 is filled with nine epoch-1 trials (lr 0..8), all with objective 0.
    Asking for two suggestions must yield two trials at fidelity 3.
    """
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband

    n_trials = 9
    resources = 1
    rung_results = {}
    for lr in np.linspace(0, 8, 9):
        trial = create_trial_for_hb((resources, lr), objective=0)
        key = trial.compute_trial_hash(
            trial,
            ignore_fidelity=True,
            ignore_experiment=True,
        )
        assert trial.objective is not None
        rung_results[key] = (trial.objective.value, trial)

    bracket.rungs[0] = RungDict(
        n_trials=n_trials, resources=resources, results=rung_results
    )

    candidates = hyperband.suggest(2)

    assert candidates is not None
    assert len(candidates) == 2
    promoted = [c for c in candidates if c.params[hyperband.fidelity_index] == 3]
    assert len(promoted) == 2
def test_suggest_duplicates_between_calls(
    self, monkeypatch, hyperband: Hyperband, bracket: HyperbandBracket
):
    """A trial already known in this execution is not suggested again.

    The mocked sampler first yields a trial that is already present in rung 0,
    so the first suggestion must be the next, genuinely new trial.
    """
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband

    fidelity = 1
    duplicate_trial = create_trial_for_hb((fidelity, 0.0))
    new_trial = create_trial_for_hb((fidelity, 0.5))

    duplicate_id = hyperband.get_id(duplicate_trial, ignore_fidelity=True)
    bracket.rungs[0]["results"] = {duplicate_id: (0.0, duplicate_trial)}  # type: ignore
    hyperband.trial_to_brackets[duplicate_id] = 0

    # Pad the mocked samples with filler trials so there are 10 in total.
    head = [duplicate_trial, new_trial]
    filler = [create_trial_for_hb((fidelity, lr)) for lr in range(10 - 2)]
    mock_samples(hyperband, head + filler)

    suggested = hyperband.suggest(100)

    assert suggested is not None
    assert suggested[0].params == new_trial.params
def test_promotion_with_rung_1_hit(
    self, hyperband: Hyperband, bracket: HyperbandBracket, rung_0: RungDict
):
    """If the lr=0.0 trial already sits in rung 1, the lr=1 trial is offered first."""
    already_promoted = create_trial_for_hb((1, 0.0), None)
    assert bracket.owner is hyperband
    bracket.rungs[0] = rung_0
    assert already_promoted.objective is not None

    entry = (already_promoted.objective.value, already_promoted)
    promoted_id = hyperband.get_id(already_promoted, ignore_fidelity=True)
    bracket.rungs[1]["results"][promoted_id] = entry

    candidates = bracket.get_candidates(0)

    assert candidates[0].params == create_trial_for_hb((1, 1), 0.0).params
def test_update_rungs_return_no_candidate(
    self, hyperband: Hyperband, bracket: HyperbandBracket, rung_1: RungDict
):
    """Check that `promote(1)` returns no candidate on an untouched bracket.

    NOTE(review): the `rung_1` fixture is requested but never installed in the
    bracket; confirm whether that is intentional.
    """
    assert bracket.owner is hyperband

    candidates = bracket.promote(1)

    assert candidates == []
def test_is_done(self, bracket: HyperbandBracket, rung_0: RungDict):
    """`is_done` turns true once rung 2 holds three results."""
    assert not bracket.is_done

    # Only the presence of the entries matters; their contents are arbitrary.
    bracket.rungs[2]["results"] = {key: (1, 0.0) for key in ("1", "2", "3")}  # type: ignore

    assert bracket.is_done
def test_suggest_promote(
    self, hyperband: Hyperband, bracket: HyperbandBracket, rung_0: RungDict
):
    """With the `rung_0` fixture installed, `suggest` promotes lr 0, 1, 2 to epoch 3."""
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband
    bracket.rungs[0] = rung_0

    suggested = hyperband.suggest(100)

    assert suggested is not None
    assert len(suggested) == 3
    expected_params = [{"epoch": 3, "lr": lr} for lr in range(3)]
    for trial, params in zip(suggested, expected_params):
        assert trial.params == params
def test_suggest_opt_out(
    self,
    hyperband: Hyperband,
    bracket: HyperbandBracket,
    rung_0: RungDict,
    rung_1: RungDict,
    rung_2: RungDict,
):
    """`suggest` returns nothing while a rung-0 trial still lacks an objective."""
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband
    bracket.rungs[0] = rung_0

    # Clear the objective of one trial so that rung 0 is no longer ready.
    results = rung_0["results"]
    first_id = next(iter(results))
    _, pending_trial = results[first_id]
    results[first_id] = (None, pending_trial)

    assert hyperband.suggest(100) == []
def test_register(
    self,
    hyperband: Hyperband,
    bracket: HyperbandBracket,
    rung_0: RungDict,
    rung_1: RungDict,
):
    """Check that a point observed by Hyperband is registered inside the bracket.

    After `observe`, rung 0's results must map the trial's fidelity-agnostic id
    to ``(0.0, trial)`` with matching params.
    """
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband
    bracket.rungs = [rung_0, rung_1]
    trial = create_trial_for_hb((1, 0.0), 0.0)
    trial_id = hyperband.get_id(trial, ignore_fidelity=True)

    hyperband.observe([trial])

    # Inspect the results mapping directly; `len()` of the RungDict itself only
    # counts its fixed keys and therefore could never fail.
    results = bracket.rungs[0]["results"]
    assert trial_id in results
    objective, registered = results[trial_id]
    assert objective == 0.0
    assert registered.params == trial.params
def bracket(budgets: list[BudgetTuple], hyperband: Hyperband) -> HyperbandBracket:
    """Return a `HyperbandBracket` owned by `hyperband` and built from `budgets`.

    The bracket is created with repetition id 1.
    """
    return HyperbandBracket(hyperband, budgets, 1)
def test_is_ready(
    self,
    hyperband: Hyperband,
    bracket: HyperbandBracket,
    rung_0: RungDict,
    rung_1: RungDict,
    rung_2: RungDict,
):
    """Hyperband bracket readiness tracks missing objectives rung by rung.

    Rungs are installed one at a time. Each time, one objective is cleared and
    then restored. The bracket-level `is_ready()` must follow the last rung
    that contains trials, and each lower rung keeps its own readiness.
    """
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband

    def blank_first_objective(results):
        # Null out the first entry's objective; return what is needed to undo it.
        first_id = next(iter(results))
        objective, trial = results[first_id]
        results[first_id] = (None, trial)
        return first_id, (objective, trial)

    def expect_ready(overall, *per_rung):
        # Check the bracket-level flag first, then rungs 0, 1, ... in order.
        assert bool(bracket.is_ready()) == overall
        for index, expected in enumerate(per_rung):
            assert bool(bracket.is_ready(index)) == expected

    # Rung 0 alone.
    bracket.rungs[0] = rung_0
    results = bracket.rungs[0]["results"]
    first_id, saved_entry = blank_first_objective(results)
    expect_ready(False, False)
    results[first_id] = saved_entry
    expect_ready(True, True, False, False)

    # Rung 1 added: overall readiness now depends on rung 1.
    bracket.rungs[1] = rung_1
    results = bracket.rungs[1]["results"]
    first_id, saved_entry = blank_first_objective(results)
    expect_ready(False, True, False, False)
    results[first_id] = saved_entry
    expect_ready(True, True, True, False)

    # Rung 2 added: overall readiness now depends on rung 2.
    bracket.rungs[2] = rung_2
    results = bracket.rungs[2]["results"]
    first_id, saved_entry = blank_first_objective(results)
    expect_ready(False, True, True, False)
    results[first_id] = saved_entry
    expect_ready(True, True, True, True)
def test_is_filled(
    self,
    hyperband: Hyperband,
    bracket: HyperbandBracket,
    rung_0: RungDict,
    rung_1: RungDict,
    rung_2: RungDict,
):
    """Hyperband bracket fill state tracks missing trials rung by rung.

    Rungs are installed one at a time. Each time, one entry is removed and then
    re-inserted. `is_filled` must depend on the first rung only, while
    `has_rung_filled(i)` reports each rung individually.
    """
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband

    def remove_first_entry(results):
        # Pop the first entry; return what is needed to put it back.
        first_id = next(iter(results))
        objective, trial = results.pop(first_id)
        return first_id, (objective, trial)

    def expect_filled(overall, *per_rung):
        # Check the bracket-level property first, then rungs 0, 1, ... in order.
        assert bool(bracket.is_filled) == overall
        for index, expected in enumerate(per_rung):
            assert bool(bracket.has_rung_filled(index)) == expected

    # Rung 0 alone: the bracket is filled exactly when rung 0 is.
    bracket.rungs[0] = rung_0
    results = bracket.rungs[0]["results"]
    first_id, removed = remove_first_entry(results)
    expect_filled(False, False)
    results[first_id] = removed
    expect_filled(True, True, False, False)

    # Rung 1 added: the overall flag still depends on the first rung only.
    bracket.rungs[1] = rung_1
    results = bracket.rungs[1]["results"]
    first_id, removed = remove_first_entry(results)
    expect_filled(True, True, False)
    results[first_id] = removed
    expect_filled(True, True, True, False)

    # Rung 2 added: the overall flag still depends on the first rung only.
    bracket.rungs[2] = rung_2
    results = bracket.rungs[2]["results"]
    first_id, removed = remove_first_entry(results)
    expect_filled(True, True, True, False)
    results[first_id] = removed
    expect_filled(True, True, True, True)