예제 #1
0
def test_hyperband_oracle_bracket_configs(tmp_dir):
    """Check bracket count, rounds, sizes and epochs for max_epochs=8, factor=2."""
    objective = kt.Objective("score", "max")
    oracle = hyperband_module.HyperbandOracle(
        objective=objective,
        hyperband_iterations=1,
        max_epochs=8,
        factor=2,
    )
    oracle._set_project_dir(tmp_dir, "untitled")

    # Starting epochs of 8, 4, 2 and 1 -> four brackets in total.
    assert oracle._get_num_brackets() == 4

    # Bracket 3 has four rounds; check its first and last round.
    assert oracle._get_num_rounds(bracket_num=3) == 4
    # (round_num, expected size, expected epochs)
    bracket3_expectations = [
        (0, 8, 1),
        (3, 1, 8),
    ]
    for round_idx, want_size, want_epochs in bracket3_expectations:
        assert oracle._get_size(bracket_num=3, round_num=round_idx) == want_size
        assert oracle._get_epochs(bracket_num=3, round_num=round_idx) == want_epochs

    # Bracket 0 consists of a single round run at the full epoch budget.
    assert oracle._get_num_rounds(bracket_num=0) == 1
    assert oracle._get_size(bracket_num=0, round_num=0) == 4
    assert oracle._get_epochs(bracket_num=0, round_num=0) == 8
예제 #2
0
def test_hyperband_oracle_one_sweep_single_thread(tmp_dir):
    """Run one full Hyperband iteration sequentially from a single tuner.

    Each trial is created, scored with a strictly increasing value and
    completed before the next one is requested. After the last bracket the
    oracle must report STOPPED and hold no brackets or ongoing trials.
    """
    hp = kt.HyperParameters()
    for hp_name in ("a", "b"):
        hp.Float(hp_name, -100, 100)

    oracle = hyperband_module.HyperbandOracle(
        hyperparameters=hp,
        objective=kt.Objective("score", "max"),
        hyperband_iterations=1,
        max_epochs=9,
        factor=3,
    )
    oracle._set_project_dir(tmp_dir, "untitled")

    score = 0
    bracket_count = oracle._get_num_brackets()
    # Walk the brackets from the highest index down to 0.
    for bracket_num in range(bracket_count - 1, -1, -1):
        for round_num in range(oracle._get_num_rounds(bracket_num)):
            for _ in range(oracle._get_size(bracket_num, round_num)):
                trial = oracle.create_trial("tuner0")
                assert trial.status == "RUNNING"
                score += 1
                oracle.update_trial(trial.trial_id, {"score": score})
                oracle.end_trial(trial.trial_id, status="COMPLETED")

            # The only live bracket must have recorded every trial of this round.
            recorded = oracle._brackets[0]["rounds"][round_num]
            assert len(recorded) == oracle._get_size(bracket_num, round_num)

        # With a single tuner, exactly one bracket is tracked at this point.
        assert len(oracle._brackets) == 1

    # Iteration should now be complete.
    final_trial = oracle.create_trial("tuner0")
    assert final_trial.status == "STOPPED", oracle.hyperband_iterations
    assert len(oracle.ongoing_trials) == 0

    # Brackets should all be finished and removed.
    assert len(oracle._brackets) == 0

    # Scores only ever increased, so the best trial carries the final score.
    top_trial = oracle.get_best_trials()[0]
    assert top_trial.score == score
예제 #3
0
def test_hyperband_oracle_one_sweep_parallel(tmp_dir):
    """Run one Hyperband iteration with several tuners requesting trials.

    Checks that round 0 trials are handed out concurrently, that later
    rounds return IDLE until the previous round's trials have completed,
    and that the oracle reports STOPPED once the last trial has ended.
    """
    hp = kt.HyperParameters()
    for hp_name in ("a", "b"):
        hp.Float(hp_name, -100, 100)

    oracle = hyperband_module.HyperbandOracle(
        hyperparameters=hp,
        objective=kt.Objective("score", "max"),
        hyperband_iterations=1,
        max_epochs=4,
        factor=2,
    )
    oracle._set_project_dir(tmp_dir, "untitled")

    # All round 0 trials from different brackets can be run
    # in parallel.
    first_round = []
    for worker in range(10):
        started = oracle.create_trial(f"tuner{worker}")
        assert started.status == "RUNNING"
        first_round.append(started)

    assert len(oracle._brackets) == 3

    # Round 1 can't be run until enough models from round 0
    # have completed.
    waiting = oracle.create_trial("tuner10")
    assert waiting.status == "IDLE"

    for finished in first_round:
        oracle.update_trial(finished.trial_id, {"score": 1})
        oracle.end_trial(finished.trial_id, "COMPLETED")

    second_round = []
    for worker in range(4):
        started = oracle.create_trial(f"tuner{worker}")
        assert started.status == "RUNNING"
        second_round.append(started)

    # Bracket 0 is complete as it only has round 0.
    assert len(oracle._brackets) == 2

    # Round 2 can't be run until enough models from round 1
    # have completed.
    waiting = oracle.create_trial("tuner10")
    assert waiting.status == "IDLE"

    for finished in second_round:
        oracle.update_trial(finished.trial_id, {"score": 1})
        oracle.end_trial(finished.trial_id, "COMPLETED")

    # Only one trial runs in round 2.
    last_trial = oracle.create_trial("tuner0")

    assert len(oracle._brackets) == 1

    # No more trials to run, but wait for existing brackets to end.
    waiting = oracle.create_trial("tuner10")
    assert waiting.status == "IDLE"

    oracle.update_trial(last_trial.trial_id, {"score": 1})
    oracle.end_trial(last_trial.trial_id, "COMPLETED")

    # Everything has finished, so the oracle now signals STOPPED.
    after_sweep = oracle.create_trial("tuner10")
    assert after_sweep.status == "STOPPED", oracle._current_sweep