예제 #1
0
def test_hyperband(tmpdir):
    """Exercise ``suggest_keras_models_by_hyperband`` on a keras search space.

    First, the un-sampled space (containing ``RandInt`` expressions) must be
    rejected with ``TuneCompileError``. Then, after sampling the space,
    hyperband with two plans and ``top_n=2`` must yield exactly two reports.
    """
    TUNE_OBJECT_FACTORY.set_temp_path(str(tmpdir))

    # Two hyperband brackets; the same plan list is used for both calls below.
    hyperband_plans = [
        [(2.0, 4), (4.0, 2)],
        [(4.0, 2), (2.0, 4)],
    ]

    raw_space = keras_space(MockSpec, l1=RandInt(8, 16), l2=RandInt(8, 24))

    # Presumably fails because RandInt-based spaces must be sampled first
    # (confirm against the compiler's rules).
    with raises(TuneCompileError):
        suggest_keras_models_by_hyperband(raw_space, plans=hyperband_plans)

    # NOTE(review): sample(10, 0) likely means 10 draws with seed 0 — confirm.
    sampled_space = raw_space.sample(10, 0)
    best = suggest_keras_models_by_hyperband(
        sampled_space,
        plans=hyperband_plans,
        top_n=2,
    )
    for report in best:
        print(report)
    assert len(best) == 2
예제 #2
0
def test_sha(tmpdir):
    """Exercise ``suggest_keras_models_by_sha`` on a keras search space.

    The un-sampled space must be rejected with ``TuneCompileError``; once the
    space is sampled, a single SHA plan with ``top_n=2`` must return exactly
    two reports.
    """
    TUNE_OBJECT_FACTORY.set_temp_path(str(tmpdir))

    sha_plan = [(2.0, 4), (4.0, 2)]
    raw_space = keras_space(MockSpec, l1=RandInt(8, 16), l2=RandInt(8, 24))

    with raises(TuneCompileError):
        suggest_keras_models_by_sha(raw_space, plan=sha_plan)

    # NOTE(review): sample(6, 0) likely means 6 draws with seed 0 — confirm.
    sampled_space = raw_space.sample(6, 0)
    best = suggest_keras_models_by_sha(sampled_space, plan=sha_plan, top_n=2)
    for report in best:
        print(report.jsondict)
    assert len(best) == 2
예제 #3
0
def test_objective():
    """Run ``KerasObjective`` iteratively in continuous and non-continuous mode.

    For each mode, the objective is driven with budgets ``[3, 3, 4]`` on the
    first configuration of a fixed keras space. The final report's metric must
    be below 15.
    """

    def _last_metric_below_threshold(reports):
        # Only the last report is checked.
        assert reports[-1].metric < 15

    fixed_space = keras_space(MockSpec, l1=16, l2=16)
    objective = KerasObjective(_TYPE_DICT)

    for continuous_mode in (True, False):
        # list(fixed_space) is re-evaluated per iteration on purpose, so each
        # Trial receives freshly produced params, as in the original test.
        trial = Trial("a", params=list(fixed_space)[0])
        validate_iterative_objective(
            objective,
            trial,
            budgets=[3, 3, 4],
            continuous=continuous_mode,
            validator=_last_metric_below_threshold,
        )