def test_register_next_bracket(self, space):
    """Check that a point is registered inside the right bracket for a higher fidelity.

    A point at fidelity 3 must land in the first rung of bracket 1, and a
    second point at fidelity 9 in the first rung of bracket 2; bracket 0 stays
    empty throughout.
    """
    hyperband = Hyperband(space)

    observations = (
        # (value, fidelity, target bracket, expected result count per bracket)
        (50, 3, 1, (0, 1, 0)),
        (51, 9, 2, (0, 1, 1)),
    )
    for value, fidelity, target, expected_counts in observations:
        point = (fidelity, value)
        point_hash = hashlib.md5(str([value]).encode('utf-8')).hexdigest()
        hyperband.observe([point], [{'objective': 0.0}])

        for index, expected in enumerate(expected_counts):
            registered = sum(
                len(rung['results']) for rung in hyperband.brackets[index].rungs
            )
            assert registered == expected

        first_rung_results = hyperband.brackets[target].rungs[0]['results']
        assert point_hash in first_rung_results
        assert (0.0, point) == first_rung_results[point_hash]
def test_register_bracket_multi_fidelity(self, space):
    """Check that a point is registered inside the same bracket for diff fidelity.

    The same ``value`` is observed first at fidelity 1 and then at fidelity 3.
    Both observations must land in bracket 0 -- in rungs 0 and 1 respectively --
    under the same hash, which is computed from ``value`` only.
    """
    hyperband = Hyperband(space)

    value = 50
    # The hash covers only ``value`` (not the fidelity), so it is the same for
    # both observations below.
    point_hash = hashlib.md5(str([value]).encode('utf-8')).hexdigest()

    fidelity = 1
    point = (fidelity, value)
    hyperband.observe([point], [{'objective': 0.0}])

    bracket = hyperband.brackets[0]
    # Check the rung's results, not the rung dict itself: the rung has at
    # least a 'results' key, so ``len(rung)`` would always be truthy.
    assert len(bracket.rungs[0]['results'])
    assert point_hash in bracket.rungs[0]['results']
    assert (0.0, point) == bracket.rungs[0]['results'][point_hash]

    fidelity = 3
    point = [fidelity, value]
    hyperband.observe([point], [{'objective': 0.0}])

    assert len(bracket.rungs[0]['results'])
    assert point_hash in bracket.rungs[1]['results']
    # Rung 0 still holds the fidelity-1 observation; rung 1 holds the new one.
    assert (0.0, point) != bracket.rungs[0]['results'][point_hash]
    assert (0.0, point) == bracket.rungs[1]['results'][point_hash]
def test_suggest_in_finite_cardinality(self):
    """Check that ``suggest`` returns None once the finite search space is exhausted.

    The space has an integer dimension ``yolo1`` and a fidelity ``epoch``.
    After the values 0..5 of ``yolo1`` are observed at fidelity 1, nothing
    is left to suggest.
    """
    space = Space()
    for dimension in (
        Integer('yolo1', 'uniform', 0, 6),
        Fidelity('epoch', 1, 9, 3),
    ):
        space.register(dimension)

    hyperband = Hyperband(space, repetitions=1)
    for value in range(6):
        hyperband.observe([(1, value)], [{'objective': value}])

    assert hyperband.suggest() is None
def test_register_invalid_fidelity(self, space):
    """Check that a point cannot be registered if its fidelity is invalid.

    Observing a point at fidelity 2 must raise a ``ValueError`` whose message
    mentions that no bracket was found for the point.

    NOTE(review): another ``test_register_invalid_fidelity`` (trial-based API)
    is defined right after this one. If both live in the same class, this
    definition is shadowed and never collected by pytest -- confirm and delete
    the stale one.
    """
    hyperband = Hyperband(space)
    invalid_point = (2, 50)  # (fidelity, value)

    with pytest.raises(ValueError) as excinfo:
        hyperband.observe([invalid_point], [{'objective': 0.0}])

    assert 'No bracket found for point' in str(excinfo.value)
def test_register_invalid_fidelity(self, space: Space):
    """Check that a trial with an invalid fidelity cannot be registered.

    ``observe`` must not raise, and afterwards the trial must be reported as
    neither suggested nor observed.
    """
    hyperband = Hyperband(space)
    bad_trial = create_trial_for_hb((2, 50))  # (fidelity, value)

    hyperband.observe([bad_trial])

    assert not hyperband.has_suggested(bad_trial)
    assert not hyperband.has_observed(bad_trial)
def test_register_not_sampled(self, space, caplog):
    """Check that a point cannot be registered if it was never sampled.

    While the ``orion.algo.hyperband`` logger is set to INFO, observing the
    point must produce exactly one captured record mentioning that the point
    is ignored.

    NOTE(review): another ``test_register_not_sampled`` (trial-based API) is
    defined right after this one. If both live in the same class, this
    definition is shadowed and never collected by pytest -- confirm and delete
    the stale one.
    """
    hyperband = Hyperband(space)
    unsampled_point = (2, 50)  # (fidelity, value)

    with caplog.at_level(logging.INFO, logger="orion.algo.hyperband"):
        hyperband.observe([unsampled_point], [{"objective": 0.0}])

    records = caplog.records
    assert len(records) == 1
    assert "Ignoring point" in records[0].msg
def test_register_not_sampled(self, space, caplog):
    """Check that a trial cannot be registered if it was never sampled.

    While the ``orion.algo.hyperband`` logger is set to DEBUG, observing the
    trial must produce exactly one captured record mentioning that the trial
    is ignored.
    """
    hyperband = Hyperband(space)
    unsampled_trial = create_trial_for_hb((2, 50))  # (fidelity, value)

    with caplog.at_level(logging.DEBUG, logger="orion.algo.hyperband"):
        hyperband.observe([unsampled_trial])

    records = caplog.records
    assert len(records) == 1
    assert "Ignoring trial" in records[0].msg
def test_register_corrupted_db(self, caplog, space):
    """Check that a point cannot be registered if passed in an order different from its fidelity.

    ``value`` is observed at fidelity 3 first and then at fidelity 1. Only the
    second observation may log the wrong-bracket message.
    """
    hyperband = Hyperband(space)
    value = 50
    expected_msg = 'Point registered to wrong bracket'

    hyperband.observe([(3, value)], [{'objective': 0.0}])
    assert expected_msg not in caplog.text

    # Drop earlier output so only the second observation is inspected.
    caplog.clear()
    hyperband.observe([[1, value]], [{'objective': 0.0}])
    assert expected_msg in caplog.text
def test_register(
    self,
    hyperband: Hyperband,
    bracket: HyperbandBracket,
    rung_0: RungDict,
    rung_1: RungDict,
):
    """Check that a point is registered inside the bracket.

    After ``observe``, the first rung's results must contain the trial under
    its fidelity-agnostic id. The entry's first element must be 0.0
    (presumably the objective passed to ``create_trial_for_hb``), and its
    second element must carry the trial's params.
    """
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband
    bracket.rungs = [rung_0, rung_1]

    trial = create_trial_for_hb((1, 0.0), 0.0)
    trial_id = hyperband.get_id(trial, ignore_fidelity=True)

    hyperband.observe([trial])

    results = bracket.rungs[0]["results"]
    # ``len(bracket.rungs[0])`` would count the rung dict's keys and is always
    # truthy; check the registered results instead.
    assert len(results)
    assert trial_id in results
    assert results[trial_id][0] == 0.0
    assert results[trial_id][1].params == trial.params
def force_observe(hyperband: Hyperband, trial: Trial) -> None:
    """Register ``trial`` with ``hyperband``, map it to a bracket if needed, and observe it.

    If the trial's fidelity-agnostic id is not yet in
    ``hyperband.trial_to_brackets``, it is mapped to the first bracket whose
    first rung's ``resources`` equals the trial's fidelity. ``observe`` is
    called afterwards.

    Raises:
        AssertionError: if ``hyperband.brackets`` is None, or if no bracket's
            first rung matches the trial's fidelity.
    """
    hyperband.register(trial)

    id_wo_fidelity = hyperband.get_id(trial, ignore_fidelity=True)
    bracket_index = hyperband.trial_to_brackets.get(id_wo_fidelity)
    if bracket_index is None:
        fidelity = flatten(trial.params)[hyperband.fidelity_index]
        assert hyperband.brackets is not None
        # Stop at the first bracket whose base rung runs at this fidelity.
        bracket_index = next(
            (
                i
                for i, bracket in enumerate(hyperband.brackets)
                if bracket.rungs[0]["resources"] == fidelity
            ),
            None,
        )
        assert bracket_index is not None, (
            f"No bracket has a first rung with resources == {fidelity!r}"
        )
        hyperband.trial_to_brackets[id_wo_fidelity] = bracket_index

    hyperband.observe([trial])
def test_full_process(self, monkeypatch, hyperband: Hyperband):
    """Test Hyperband full process.

    The test runs in two phases:

    1. Three brackets, whose first rungs are suggested at fidelities 1, 3 and
       9, go through suggest/observe/promote cycles until Hyperband reports
       ``is_done``.
    2. ``repetitions`` is raised to 2 and the brackets are refreshed. The test
       then checks that a new round of first-rung trials is suggested.
    """

    def mark_completed(trials, first_objective):
        # Mark every trial completed; the i-th trial gets objective
        # ``first_objective - i``.
        for offset, trial in enumerate(trials):
            trial.status = "completed"
            trial._results.append(
                Trial.Result(
                    name="objective",
                    type="objective",
                    value=first_objective - offset,
                )
            )

    sample_trials = [create_trial_for_hb(("fidelity", i)) for i in range(100)]

    hyperband._refresh_brackets()
    mock_samples(hyperband, copy.deepcopy(sample_trials))

    # Fill all brackets' first rung
    first_rung = hyperband.suggest(100)
    assert first_rung is not None
    first_bracket_first_rung = first_rung[6:]
    second_bracket_first_rung = first_rung[3:6]
    third_bracket_first_rung = first_rung[:3]

    compare_trials(
        first_bracket_first_rung,
        [create_trial_for_hb((1, i)) for i in range(6, 15)],
    )
    compare_trials(
        second_bracket_first_rung,
        [create_trial_for_hb((3, i)) for i in range(3, 6)],
    )
    compare_trials(
        third_bracket_first_rung,
        [create_trial_for_hb((9, i)) for i in range(3)],
    )

    assert hyperband.brackets is not None
    assert hyperband.brackets[0].has_rung_filled(0)
    assert not hyperband.brackets[0].is_ready()
    # No further trials are expected before any result is observed.
    assert hyperband.suggest(100) == []
    assert hyperband.suggest(100) == []

    # Observe first bracket first rung
    mark_completed(first_bracket_first_rung, 16)
    hyperband.observe(first_bracket_first_rung)

    assert hyperband.brackets[0].is_ready()
    assert not hyperband.brackets[1].is_ready()
    assert not hyperband.brackets[2].is_ready()

    # Promote first bracket first rung
    first_bracket_second_rung = hyperband.suggest(100)
    # Check for None before first use, so a None result fails here with a
    # clear message (this also narrows the type for the checker).
    assert first_bracket_second_rung is not None
    compare_trials(
        first_bracket_second_rung,
        [create_trial_for_hb((3, 3 + 3 + 9 - 1 - i)) for i in range(3)],
    )

    assert hyperband.brackets[0].has_rung_filled(1)
    assert not hyperband.brackets[0].is_ready()
    assert not hyperband.brackets[1].is_ready()
    assert not hyperband.brackets[2].is_ready()

    # Observe first bracket second rung
    mark_completed(first_bracket_second_rung, 8)
    hyperband.observe(first_bracket_second_rung)

    assert hyperband.brackets[0].is_ready()
    assert not hyperband.brackets[1].is_ready()
    assert not hyperband.brackets[2].is_ready()

    # Promote first bracket second rung
    first_bracket_third_rung = hyperband.suggest(100)
    assert first_bracket_third_rung is not None
    compare_trials(first_bracket_third_rung, [create_trial_for_hb((9, 12))])

    assert hyperband.brackets[0].has_rung_filled(2)
    assert not hyperband.brackets[0].is_ready()
    assert not hyperband.brackets[1].is_ready()
    assert not hyperband.brackets[2].is_ready()

    # Observe second bracket first rung
    mark_completed(second_bracket_first_rung, 8)
    hyperband.observe(second_bracket_first_rung)

    assert not hyperband.brackets[0].is_ready()
    assert hyperband.brackets[1].is_ready()
    assert not hyperband.brackets[2].is_ready()

    # Promote second bracket first rung
    second_bracket_second_rung = hyperband.suggest(100)
    assert second_bracket_second_rung is not None
    compare_trials(second_bracket_second_rung, [create_trial_for_hb((9, 5))])

    assert not hyperband.brackets[0].is_ready()
    assert hyperband.brackets[1].has_rung_filled(1)
    assert not hyperband.brackets[1].is_ready()
    assert not hyperband.brackets[2].is_ready()

    # Observe third bracket first rung
    mark_completed(third_bracket_first_rung, 3)
    hyperband.observe(third_bracket_first_rung)

    assert not hyperband.brackets[0].is_ready(2)
    assert not hyperband.brackets[1].is_ready(1)
    assert hyperband.brackets[2].is_ready(0)
    assert hyperband.brackets[2].is_done

    # Observe second bracket second rung
    mark_completed(second_bracket_second_rung, 5)
    hyperband.observe(second_bracket_second_rung)

    assert not hyperband.brackets[0].is_ready(2)
    assert hyperband.brackets[1].is_ready(1)
    assert hyperband.brackets[1].is_done

    # Observe first bracket third rung.
    # NOTE(review): unlike every other batch, these trials are observed
    # without being marked completed or given an objective -- confirm this
    # is intended.
    hyperband.observe(first_bracket_third_rung)

    assert hyperband.is_done
    assert hyperband.brackets[0].is_done
    assert hyperband.suggest(100) == []

    # Refresh repeat and execution times property
    monkeypatch.setattr(hyperband, "repetitions", 2)
    assert len(hyperband.brackets) == 3
    hyperband._refresh_brackets()
    assert len(hyperband.brackets) == 6
    mock_samples(hyperband, copy.deepcopy(sample_trials[:3] + sample_trials))
    trials = hyperband.suggest(100)
    assert not hyperband.is_done
    assert not hyperband.brackets[3].is_ready(2)
    assert not hyperband.brackets[3].is_done
    assert trials is not None
    compare_trials(trials[:3], map(create_trial_for_hb, [(9, 3), (9, 4), (9, 6)]))
    compare_trials(trials[3:6], map(create_trial_for_hb, [(3, 7), (3, 8), (3, 9)]))
    compare_trials(trials[6:], [create_trial_for_hb((1, i)) for i in range(15, 24)])