def test_hyperband_save_load_at_begining(tmp_dir):
    """Save an untouched oracle, reload it, and run its first batch.

    A freshly built oracle is saved under ``tmp_dir`` and reloaded into a
    second instance. That instance is asked for ``_model_sequence[0]``
    trials, each expected to be RUNNING; one further request is expected
    to come back IDLE.
    """
    space = hp_module.HyperParameters()
    for name, choices in (('a', [1, 2]), ('b', [3, 4]), ('c', [5, 6]),
                          ('d', [7, 8]), ('e', [9, 0])):
        # Every parameter defaults to its first listed choice.
        space.Choice(name, choices, default=choices[0])

    oracle_kwargs = dict(objective='score', max_trials=50,
                         hyperparameters=space)

    fresh = hyperband_module.HyperbandOracle(**oracle_kwargs)
    fresh._set_project_dir(tmp_dir, 'untitled')
    fresh.save()

    restored = hyperband_module.HyperbandOracle(**oracle_kwargs)
    restored._set_project_dir(tmp_dir, 'untitled')
    restored.reload()

    running = []
    for tuner_id in range(restored._model_sequence[0]):
        current = restored.create_trial(tuner_id)
        running.append(current)
        assert current.status == 'RUNNING'
        restored.update_trial(current.trial_id, {'score': 1})

    extra = restored.create_trial('idle0')
    assert extra.status == 'IDLE'

    for finished in running:
        restored.end_trial(finished.trial_id, 'COMPLETED')
def test_hyperband_save_load_middle_of_bracket(tmp_dir):
    """Save the oracle part-way through its first batch and resume it.

    Three configurations are requested and two results are reported
    before the oracle is saved under ``tmp_dir`` and reloaded into a new
    instance. The reloaded oracle is then asked for
    ``_model_sequence[0] - 2`` more configurations, each expected to have
    status 'RUN'; the next request is expected to have status 'IDLE'.

    NOTE(review): a second function with this exact name appears further
    down in this file; if both live in one module the later definition
    replaces this one.
    """
    hp_list = [
        hp_module.Choice(name, choices, default=choices[0])
        for name, choices in (('a', [1, 2]), ('b', [3, 4]), ('c', [5, 6]),
                              ('d', [7, 8]), ('e', [9, 0]))
    ]

    before_save = hyperband_module.HyperbandOracle()
    for index in range(3):
        before_save.populate_space('0_{}'.format(index), hp_list)
    for index in range(2):
        before_save.result('0_{}'.format(index), index)

    checkpoint = os.path.join(tmp_dir, 'oracle')
    before_save.save(checkpoint)

    resumed = hyperband_module.HyperbandOracle()
    resumed.reload(checkpoint)

    for index in range(resumed._model_sequence[0] - 2):
        response = resumed.populate_space('1_{}'.format(index), hp_list)
        assert response['status'] == 'RUN'

    overflow = resumed.populate_space('idle', hp_list)
    assert overflow['status'] == 'IDLE'

    for index in range(resumed._model_sequence[0] - 2):
        resumed.result('1_{}'.format(index), index)
def test_hyperband_save_load_at_the_end_of_bandit(tmp_dir):
    """Run every bracket, save, reload, and start again from bracket 0.

    For each bracket the test checks that trials are RUNNING, that
    'tuner/epochs' matches ``_epoch_sequence`` for that bracket, and that
    'tuner/trial_id' is present only for brackets after the first. After
    saving and reloading, the same first-bracket expectations are checked
    on the new instance.

    NOTE(review): a second function with this exact name appears further
    down in this file; if both live in one module the later definition
    replaces this one.
    """
    space = hp_module.HyperParameters()
    for name, choices in (('a', [1, 2]), ('b', [3, 4]), ('c', [5, 6]),
                          ('d', [7, 8]), ('e', [9, 0])):
        # Every parameter defaults to its first listed choice.
        space.Choice(name, choices, default=choices[0])

    oracle_kwargs = dict(objective='score', max_trials=50,
                         hyperparameters=space)

    oracle = hyperband_module.HyperbandOracle(**oracle_kwargs)
    oracle._set_project_dir(tmp_dir, 'untitled')

    for bracket in range(oracle._num_brackets):
        in_flight = []
        for tuner_id in range(oracle._model_sequence[bracket]):
            current = oracle.create_trial(tuner_id)
            in_flight.append(current)
            values = current.hyperparameters.values
            assert current.status == 'RUNNING'
            assert values['tuner/epochs'] == oracle._epoch_sequence[bracket]
            # Only brackets after the first carry a 'tuner/trial_id'.
            assert ('tuner/trial_id' in values) == (bracket > 0)

        # Asking for more trials when bracket is not yet complete.
        extra = oracle.create_trial('idle0')
        assert extra.status == 'IDLE'

        for finished in in_flight:
            oracle.update_trial(finished.trial_id, {'score': 1.0})
            oracle.end_trial(finished.trial_id, 'COMPLETED')

    oracle.save()

    restored = hyperband_module.HyperbandOracle(**oracle_kwargs)
    restored._set_project_dir(tmp_dir, 'untitled')
    restored.reload()

    in_flight = []
    for tuner_id in range(restored._model_sequence[0]):
        current = restored.create_trial(tuner_id)
        in_flight.append(current)
        values = current.hyperparameters.values
        assert current.status == 'RUNNING'
        assert values['tuner/epochs'] == restored._epoch_sequence[0]
        assert 'tuner/trial_id' not in values

    # Asking for more trials when bracket is not yet complete.
    extra = restored.create_trial('idle0')
    assert extra.status == 'IDLE'

    for finished in in_flight:
        restored.update_trial(finished.trial_id, {'score': 1.0})
        restored.end_trial(finished.trial_id, 'COMPLETED')
def test_hyperband_save_load_at_the_end_of_bandit(tmp_dir):
    """Exhaust three brackets, save, reload, and restart from bracket 0.

    Brackets 0 and 1 only check the returned status. Bracket 2 also
    checks 'tuner/epochs' and that 'tuner/trial_id' is present. After a
    save/reload round trip through ``tmp_dir`` the new instance is
    expected to serve bracket 0 again, without 'tuner/trial_id'.
    """
    hp_list = [
        hp_module.Choice(name, choices, default=choices[0])
        for name, choices in (('a', [1, 2]), ('b', [3, 4]), ('c', [5, 6]),
                              ('d', [7, 8]), ('e', [9, 0]))
    ]
    oracle = hyperband_module.HyperbandOracle()

    # Brackets 0 and 1: only the status of each reply is checked.
    for bracket, idle_id in enumerate(('idle', 'idle1')):
        for index in range(oracle._model_sequence[bracket]):
            reply = oracle.populate_space(
                '{}_{}'.format(bracket, index), hp_list)
            assert reply['status'] == 'RUN'
        assert oracle.populate_space(idle_id, hp_list)['status'] == 'IDLE'
        for index in range(oracle._model_sequence[bracket]):
            oracle.result('{}_{}'.format(bracket, index), index)

    # Bracket 2: epochs and the parent trial id are checked as well.
    for index in range(oracle._model_sequence[2]):
        reply = oracle.populate_space('2_{}'.format(index), hp_list)
        assert reply['status'] == 'RUN'
        assert reply['values']['tuner/epochs'] == oracle._epoch_sequence[2]
        assert 'tuner/trial_id' in reply['values']
    assert oracle.populate_space('idle2', hp_list)['status'] == 'IDLE'
    for index in range(oracle._model_sequence[2]):
        oracle.result('2_{}'.format(index), index)

    checkpoint = os.path.join(tmp_dir, 'oracle')
    oracle.save(checkpoint)

    resumed = hyperband_module.HyperbandOracle()
    resumed.reload(checkpoint)

    for index in range(resumed._model_sequence[0]):
        reply = resumed.populate_space('3_{}'.format(index), hp_list)
        assert reply['status'] == 'RUN'
        assert reply['values']['tuner/epochs'] == resumed._epoch_sequence[0]
        assert 'tuner/trial_id' not in reply['values']
    assert resumed.populate_space('idle3', hp_list)['status'] == 'IDLE'
def test_hyperband_save_load_middle_of_bracket(tmp_dir):
    """Save with one trial still open in the first batch, then resume.

    Three trials are created and the first two are completed before the
    oracle is saved. A reloaded instance is asked for the remaining
    ``_model_sequence[0] - 2`` trials, each expected to be RUNNING; the
    request after that is expected to be IDLE.
    """
    space = hp_module.HyperParameters()
    for name, choices in (('a', [1, 2]), ('b', [3, 4]), ('c', [5, 6]),
                          ('d', [7, 8]), ('e', [9, 0])):
        # Every parameter defaults to its first listed choice.
        space.Choice(name, choices, default=choices[0])

    oracle_kwargs = dict(objective='score', max_trials=50,
                         hyperparameters=space)

    oracle = hyperband_module.HyperbandOracle(**oracle_kwargs)
    oracle._set_project_dir(tmp_dir, 'untitled')

    started = [oracle.create_trial(tuner_id) for tuner_id in range(3)]

    # Only the first two of the three started trials are completed.
    for finished in started[:2]:
        oracle.update_trial(finished.trial_id, {'score': 1.0})
        oracle.end_trial(finished.trial_id, 'COMPLETED')

    oracle.save()

    restored = hyperband_module.HyperbandOracle(**oracle_kwargs)
    restored._set_project_dir(tmp_dir, 'untitled')
    restored.reload()

    resumed = []
    # Tuner ids continue from 2, after the two completed before saving.
    for tuner_id in range(2, restored._model_sequence[0]):
        current = restored.create_trial(tuner_id)
        resumed.append(current)
        assert current.status == 'RUNNING'

    # Asking for more trials when bracket is not yet complete.
    extra = restored.create_trial('idle0')
    assert extra.status == 'IDLE'

    for finished in resumed:
        restored.update_trial(finished.trial_id, {'score': 1.0})
        restored.end_trial(finished.trial_id, 'COMPLETED')
def test_hyperband_dynamic_space(tmp_dir):
    """Check that newly appended hyperparameters show up in the values.

    The list passed to ``populate_space`` grows by one ``Choice`` before
    each call, and each call's 'values' is expected to contain the newest
    name.

    NOTE(review): a second function with this exact name appears further
    down in this file; if both live in one module the later definition
    replaces this one.
    """
    hp_list = [hp_module.Choice('a', [1, 2], default=1)]
    oracle = hyperband_module.HyperbandOracle()

    hp_list.append(hp_module.Choice('b', [3, 4], default=3))
    first_values = oracle.populate_space('0', hp_list)['values']
    assert 'b' in first_values

    # update_space is called once, right after the first populate_space.
    oracle.update_space(hp_list)

    later_additions = (
        ('1', 'c', [5, 6]),
        ('2', 'd', [7, 8]),
        ('3', 'e', [9, 0]),
    )
    for trial_id, name, choices in later_additions:
        hp_list.append(hp_module.Choice(name, choices, default=choices[0]))
        assert name in oracle.populate_space(trial_id, hp_list)['values']
def test_hyperband_oracle(tmp_dir):
    """Walk the oracle through three brackets and back into bracket 0.

    Each pass requests ``_model_sequence[bracket]`` configurations and
    checks their status, 'tuner/epochs', and whether 'tuner/trial_id' is
    present (expected only for brackets after the first). It then checks
    that one more request is 'IDLE' and reports a result for each
    configuration. A final request after the fourth pass is expected to
    have status 'RUN'.
    """
    hp_list = [
        hp_module.Choice(name, choices, default=choices[0])
        for name, choices in (('a', [1, 2]), ('b', [3, 4]), ('c', [5, 6]),
                              ('d', [7, 8]), ('e', [9, 0]))
    ]
    oracle = hyperband_module.HyperbandOracle()
    assert oracle._num_brackets == 3

    # (trial-id prefix, bracket index): the fourth pass reuses bracket 0.
    passes = ((0, 0), (1, 1), (2, 2), (3, 0))
    for prefix, bracket in passes:
        for index in range(oracle._model_sequence[bracket]):
            reply = oracle.populate_space(
                '{}_{}'.format(prefix, index), hp_list)
            assert reply['status'] == 'RUN'
            assert (reply['values']['tuner/epochs']
                    == oracle._epoch_sequence[bracket])
            assert ('tuner/trial_id' in reply['values']) == (bracket > 0)

        idle_reply = oracle.populate_space('idle{}'.format(prefix), hp_list)
        assert idle_reply['status'] == 'IDLE'

        for index in range(oracle._model_sequence[bracket]):
            oracle.result('{}_{}'.format(prefix, index), index)

    assert oracle.populate_space('last', hp_list)['status'] == 'RUN'
def test_hyperband_oracle_bracket_configs(tmp_dir):
    """Check bracket, round, size and epoch numbers for max_epochs=8, factor=2."""
    oracle = hyperband_module.HyperbandOracle(
        objective='score',
        hyperband_iterations=1,
        max_epochs=8,
        factor=2)
    oracle._set_project_dir(tmp_dir, 'untitled')

    # 8, 4, 2, 1 starting epochs.
    assert oracle._get_num_brackets() == 4

    # bracket -> (number of rounds, ((round, size, epochs), ...))
    expectations = (
        (3, 4, ((0, 8, 1), (3, 1, 8))),
        (0, 1, ((0, 4, 8),)),
    )
    for bracket, num_rounds, per_round in expectations:
        assert oracle._get_num_rounds(bracket_num=bracket) == num_rounds
        for round_index, size, epochs in per_round:
            assert oracle._get_size(
                bracket_num=bracket, round_num=round_index) == size
            assert oracle._get_epochs(
                bracket_num=bracket, round_num=round_index) == epochs
def test_hyperband_dynamic_space(tmp_dir):
    """Check that hyperparameters added after construction are populated.

    'b' is registered on the original ``HyperParameters`` object after
    the oracle was built. 'c', 'd' and 'e' are registered one at a time
    on a second object that is passed to ``update_space``. After each
    addition the newest name is expected in ``_populate_space(...)``'s
    'values'.
    """
    initial_space = hp_module.HyperParameters()
    initial_space.Choice('a', [1, 2], default=1)

    oracle = hyperband_module.HyperbandOracle(
        objective='score', max_trials=50, hyperparameters=initial_space)
    oracle._set_project_dir(tmp_dir, 'untitled')

    # Registered on the object the oracle was constructed with.
    initial_space.Choice('b', [3, 4], default=3)
    first_values = oracle._populate_space('0')['values']
    assert 'b' in first_values

    extra_space = hp_module.HyperParameters()
    additions = (('c', [5, 6]), ('d', [7, 8]), ('e', [9, 0]))
    for trial_number, (name, choices) in enumerate(additions, 1):
        extra_space.Choice(name, choices, default=choices[0])
        oracle.update_space(extra_space)
        assert name in oracle._populate_space(str(trial_number))['values']
def test_hyperband_oracle_one_sweep_single_thread(tmp_dir):
    """Drive one full Hyperband iteration using a single tuner id.

    Brackets are visited from the highest index down to 0. Within each
    round every trial is created, scored with a strictly increasing
    value and completed before the next one is requested. Afterwards the
    oracle is expected to answer STOPPED, hold no ongoing trials or
    brackets, and report the last assigned score as the best.

    NOTE(review): a second function with this exact name appears further
    down in this file; if both live in one module the later definition
    replaces this one.
    """
    space = kt.HyperParameters()
    space.Float("a", -100, 100)
    space.Float("b", -100, 100)

    oracle = hyperband_module.HyperbandOracle(
        hyperparameters=space,
        objective=kt.Objective("score", "max"),
        hyperband_iterations=1,
        max_epochs=9,
        factor=3,
    )
    oracle._set_project_dir(tmp_dir, "untitled")

    score = 0
    highest_bracket = oracle._get_num_brackets() - 1
    for bracket_num in range(highest_bracket, -1, -1):
        for round_num in range(oracle._get_num_rounds(bracket_num)):
            for _ in range(oracle._get_size(bracket_num, round_num)):
                current = oracle.create_trial("tuner0")
                assert current.status == "RUNNING"
                score += 1
                oracle.update_trial(current.trial_id, {"score": score})
                oracle.end_trial(current.trial_id, status="COMPLETED")
            round_trials = oracle._brackets[0]["rounds"][round_num]
            assert len(round_trials) == oracle._get_size(
                bracket_num, round_num)
        assert len(oracle._brackets) == 1

    # Iteration should now be complete.
    after_sweep = oracle.create_trial("tuner0")
    assert after_sweep.status == "STOPPED", oracle.hyperband_iterations
    assert len(oracle.ongoing_trials) == 0

    # Brackets should all be finished and removed.
    assert len(oracle._brackets) == 0

    top_trial = oracle.get_best_trials()[0]
    assert top_trial.score == score
def test_hyperband_oracle_one_sweep_single_thread(tmp_dir):
    """Drive one full Hyperband iteration using a single tuner id.

    Brackets are visited from the highest index down to 0. Within each
    round every trial is created, scored with a strictly increasing
    value and completed before the next one is requested. Afterwards the
    oracle is expected to answer STOPPED, hold no ongoing trials or
    brackets, and report the last assigned score as the best.
    """
    space = kt.HyperParameters()
    space.Float('a', -100, 100)
    space.Float('b', -100, 100)

    oracle = hyperband_module.HyperbandOracle(
        hyperparameters=space,
        objective=kt.Objective('score', 'max'),
        hyperband_iterations=1,
        max_epochs=9,
        factor=3)
    oracle._set_project_dir(tmp_dir, 'untitled')

    score = 0
    highest_bracket = oracle._get_num_brackets() - 1
    for bracket_num in range(highest_bracket, -1, -1):
        for round_num in range(oracle._get_num_rounds(bracket_num)):
            for _ in range(oracle._get_size(bracket_num, round_num)):
                current = oracle.create_trial('tuner0')
                assert current.status == 'RUNNING'
                score += 1
                oracle.update_trial(current.trial_id, {'score': score})
                oracle.end_trial(current.trial_id, status='COMPLETED')
            round_trials = oracle._brackets[0]['rounds'][round_num]
            assert len(round_trials) == oracle._get_size(
                bracket_num, round_num)
        assert len(oracle._brackets) == 1

    # Iteration should now be complete.
    after_sweep = oracle.create_trial('tuner0')
    assert after_sweep.status == 'STOPPED', oracle.hyperband_iterations
    assert len(oracle.ongoing_trials) == 0

    # Brackets should all be finished and removed.
    assert len(oracle._brackets) == 0

    top_trial = oracle.get_best_trials()[0]
    assert top_trial.score == score
def test_hyperband_oracle_one_sweep_parallel(tmp_dir):
    """Drive one Hyperband iteration with several trials open at once.

    Ten round-0 trials are opened together, then four round-1 trials,
    then a single round-2 trial. Between rounds an extra request is
    expected to be IDLE, and the number of live brackets is expected to
    drop from 3 to 2 to 1. Once the last trial completes the oracle is
    expected to answer STOPPED.

    NOTE(review): a second function with this exact name appears further
    down in this file; if both live in one module the later definition
    replaces this one.
    """
    space = kt.HyperParameters()
    space.Float('a', -100, 100)
    space.Float('b', -100, 100)

    oracle = hyperband_module.HyperbandOracle(
        hyperparameters=space,
        objective='score',
        hyperband_iterations=1,
        max_epochs=4,
        factor=2)
    oracle._set_project_dir(tmp_dir, 'untitled')

    def launch(count):
        # Open `count` trials under distinct tuner ids; each must be RUNNING.
        opened = []
        for index in range(count):
            started = oracle.create_trial('tuner{}'.format(index))
            assert started.status == 'RUNNING'
            opened.append(started)
        return opened

    def complete(batch):
        # Report a score of 1 for every trial in `batch` and close it.
        for done in batch:
            oracle.update_trial(done.trial_id, {'score': 1})
            oracle.end_trial(done.trial_id, 'COMPLETED')

    def expect_idle():
        # A further request from 'tuner10' must come back IDLE.
        waiting = oracle.create_trial('tuner10')
        assert waiting.status == 'IDLE'

    # All round 0 trials from different brackets can be run
    # in parallel.
    round0_trials = launch(10)
    assert len(oracle._brackets) == 3

    # Round 1 can't be run until enough models from round 0
    # have completed.
    expect_idle()
    complete(round0_trials)

    round1_trials = launch(4)

    # Bracket 0 is complete as it only has round 0.
    assert len(oracle._brackets) == 2

    # Round 2 can't be run until enough models from round 1
    # have completed.
    expect_idle()
    complete(round1_trials)

    # Only one trial runs in round 2.
    round2_trial = oracle.create_trial('tuner0')
    assert len(oracle._brackets) == 1

    # No more trials to run, but wait for existing brackets to end.
    expect_idle()
    complete([round2_trial])

    final = oracle.create_trial('tuner10')
    assert final.status == 'STOPPED', oracle._current_sweep
def test_hyperband_oracle_one_sweep_parallel(tmp_dir):
    """Drive one Hyperband iteration with several trials open at once.

    Ten round-0 trials are opened together, then four round-1 trials,
    then a single round-2 trial. Between rounds an extra request is
    expected to be IDLE, and the number of live brackets is expected to
    drop from 3 to 2 to 1. Once the last trial completes the oracle is
    expected to answer STOPPED.
    """
    space = kt.HyperParameters()
    space.Float("a", -100, 100)
    space.Float("b", -100, 100)

    oracle = hyperband_module.HyperbandOracle(
        hyperparameters=space,
        objective=kt.Objective("score", "max"),
        hyperband_iterations=1,
        max_epochs=4,
        factor=2,
    )
    oracle._set_project_dir(tmp_dir, "untitled")

    def launch(count):
        # Open `count` trials under distinct tuner ids; each must be RUNNING.
        opened = []
        for index in range(count):
            started = oracle.create_trial("tuner{}".format(index))
            assert started.status == "RUNNING"
            opened.append(started)
        return opened

    def complete(batch):
        # Report a score of 1 for every trial in `batch` and close it.
        for done in batch:
            oracle.update_trial(done.trial_id, {"score": 1})
            oracle.end_trial(done.trial_id, "COMPLETED")

    def expect_idle():
        # A further request from "tuner10" must come back IDLE.
        waiting = oracle.create_trial("tuner10")
        assert waiting.status == "IDLE"

    # All round 0 trials from different brackets can be run
    # in parallel.
    round0_trials = launch(10)
    assert len(oracle._brackets) == 3

    # Round 1 can't be run until enough models from round 0
    # have completed.
    expect_idle()
    complete(round0_trials)

    round1_trials = launch(4)

    # Bracket 0 is complete as it only has round 0.
    assert len(oracle._brackets) == 2

    # Round 2 can't be run until enough models from round 1
    # have completed.
    expect_idle()
    complete(round1_trials)

    # Only one trial runs in round 2.
    round2_trial = oracle.create_trial("tuner0")
    assert len(oracle._brackets) == 1

    # No more trials to run, but wait for existing brackets to end.
    expect_idle()
    complete([round2_trial])

    final = oracle.create_trial("tuner10")
    assert final.status == "STOPPED", oracle._current_sweep