def test_hyperband_oracle_bracket_configs(tmp_dir):
    """Check bracket/round sizing for Hyperband with max_epochs=8, factor=2.

    Brackets 0, 1, 2, 3 start training at 8, 4, 2 and 1 epochs respectively,
    so bracket 3 has the most rounds and bracket 0 has a single round.
    """
    hb_oracle = hyperband_module.HyperbandOracle(
        objective=kt.Objective("score", "max"),
        hyperband_iterations=1,
        max_epochs=8,
        factor=2,
    )
    hb_oracle._set_project_dir(tmp_dir, "untitled")

    assert hb_oracle._get_num_brackets() == 4

    # bracket_num -> expected number of rounds in that bracket.
    expected_rounds = {3: 4, 0: 1}
    for bracket, num_rounds in expected_rounds.items():
        assert hb_oracle._get_num_rounds(bracket_num=bracket) == num_rounds

    # (bracket_num, round_num) -> (number of models, epochs per model).
    # Bracket 3 goes from 8 models at 1 epoch down to 1 model at 8 epochs;
    # bracket 0 trains 4 models at the full 8 epochs straight away.
    expected_shapes = {
        (3, 0): (8, 1),
        (3, 3): (1, 8),
        (0, 0): (4, 8),
    }
    for (bracket, rnd), (size, epochs) in expected_shapes.items():
        assert hb_oracle._get_size(bracket_num=bracket, round_num=rnd) == size
        assert hb_oracle._get_epochs(bracket_num=bracket, round_num=rnd) == epochs
def test_hyperband_oracle_one_sweep_single_thread(tmp_dir):
    """Run one full Hyperband sweep from a single tuner and check bookkeeping.

    Every trial reports a strictly increasing score, so with a "max" objective
    the best trial at the end must carry the last score that was reported.
    """
    search_space = kt.HyperParameters()
    search_space.Float("a", -100, 100)
    search_space.Float("b", -100, 100)
    oracle = hyperband_module.HyperbandOracle(
        hyperparameters=search_space,
        objective=kt.Objective("score", "max"),
        hyperband_iterations=1,
        max_epochs=9,
        factor=3,
    )
    oracle._set_project_dir(tmp_dir, "untitled")

    last_score = 0
    # Walk the brackets from the highest bracket number down to 0.
    for bracket in range(oracle._get_num_brackets() - 1, -1, -1):
        for rnd in range(oracle._get_num_rounds(bracket)):
            for _ in range(oracle._get_size(bracket, rnd)):
                trial = oracle.create_trial("tuner0")
                assert trial.status == "RUNNING"
                last_score += 1
                oracle.update_trial(trial.trial_id, {"score": last_score})
                oracle.end_trial(trial.trial_id, status="COMPLETED")
            # After the round is done, the tracked bracket holds exactly as
            # many trials for this round as the round was sized for.
            round_trials = oracle._brackets[0]["rounds"][rnd]
            assert len(round_trials) == oracle._get_size(bracket, rnd)
        # A single tuner never has more than one bracket in flight.
        assert len(oracle._brackets) == 1

    # The single Hyperband iteration is exhausted, so the oracle stops.
    stop_trial = oracle.create_trial("tuner0")
    assert stop_trial.status == "STOPPED", oracle.hyperband_iterations
    assert len(oracle.ongoing_trials) == 0

    # Finished brackets have been cleaned up.
    assert len(oracle._brackets) == 0

    top_trial = oracle.get_best_trials()[0]
    assert top_trial.score == last_score
def test_hyperband_oracle_one_sweep_parallel(tmp_dir):
    """Run one Hyperband sweep with many concurrent tuners.

    With max_epochs=4 and factor=2 the oracle opens 3 brackets. Round 0 of
    every bracket can be scheduled at once (10 trials in total). Later rounds
    must wait until the previous round of their bracket has finished. A tuner
    that has to wait gets an "IDLE" trial, and once the sweep is over every
    tuner gets a "STOPPED" trial.
    """
    hp = kt.HyperParameters()
    hp.Float("a", -100, 100)
    hp.Float("b", -100, 100)
    oracle = hyperband_module.HyperbandOracle(
        hyperparameters=hp,
        objective=kt.Objective("score", "max"),
        hyperband_iterations=1,
        max_epochs=4,
        factor=2,
    )
    oracle._set_project_dir(tmp_dir, "untitled")

    # All round 0 trials from different brackets can be run in parallel.
    round0_trials = []
    for i in range(10):
        trial = oracle.create_trial(f"tuner{i}")
        assert trial.status == "RUNNING"
        round0_trials.append(trial)

    assert len(oracle._brackets) == 3

    # Round 1 can't be run until enough models from round 0 have completed.
    waiting_trial = oracle.create_trial("tuner10")
    assert waiting_trial.status == "IDLE"

    for trial in round0_trials:
        oracle.update_trial(trial.trial_id, {"score": 1})
        oracle.end_trial(trial.trial_id, status="COMPLETED")

    round1_trials = []
    for i in range(4):
        trial = oracle.create_trial(f"tuner{i}")
        assert trial.status == "RUNNING"
        round1_trials.append(trial)

    # Bracket 0 is complete as it only has round 0.
    assert len(oracle._brackets) == 2

    # Round 2 can't be run until enough models from round 1 have completed.
    waiting_trial = oracle.create_trial("tuner10")
    assert waiting_trial.status == "IDLE"

    for trial in round1_trials:
        oracle.update_trial(trial.trial_id, {"score": 1})
        oracle.end_trial(trial.trial_id, status="COMPLETED")

    # Only one trial runs in round 2. Check that it is actually running;
    # the IDLE/STOPPED checks below depend on it.
    round2_trial = oracle.create_trial("tuner0")
    assert round2_trial.status == "RUNNING"
    assert len(oracle._brackets) == 1

    # No more trials to run, but wait for existing brackets to end.
    waiting_trial = oracle.create_trial("tuner10")
    assert waiting_trial.status == "IDLE"

    oracle.update_trial(round2_trial.trial_id, {"score": 1})
    oracle.end_trial(round2_trial.trial_id, status="COMPLETED")

    final_trial = oracle.create_trial("tuner10")
    assert final_trial.status == "STOPPED", oracle._current_sweep