def test_hyperband(tmpdir):
    """Exercise Hyperband-based Keras model suggestion.

    Passing the raw search space (which still holds ``RandInt`` expressions)
    is expected to raise ``TuneCompileError``. After sampling 10
    configurations from it, asking for ``top_n=2`` must return exactly two
    reports.
    """

    def _plans():
        # Return a brand-new plan list on each call, mirroring the original
        # test, which spelled out a separate literal for every invocation.
        return [
            [(2.0, 4), (4.0, 2)],
            [(4.0, 2), (2.0, 4)],
        ]

    TUNE_OBJECT_FACTORY.set_temp_path(str(tmpdir))
    search_space = keras_space(MockSpec, l1=RandInt(8, 16), l2=RandInt(8, 24))

    with raises(TuneCompileError):
        suggest_keras_models_by_hyperband(search_space, plans=_plans())

    # NOTE(review): the second argument to sample() looks like a seed — confirm.
    sampled_space = search_space.sample(10, 0)
    reports = suggest_keras_models_by_hyperband(
        sampled_space,
        plans=_plans(),
        top_n=2,
    )
    for report in reports:
        print(report)
    assert len(reports) == 2
def test_sha(tmpdir):
    """Exercise successive-halving (SHA) Keras model suggestion.

    Passing the raw search space (which still holds ``RandInt`` expressions)
    is expected to raise ``TuneCompileError``. After sampling 6
    configurations from it, asking for ``top_n=2`` must return exactly two
    reports.
    """

    def _plan():
        # Return a brand-new plan list on each call, mirroring the original
        # test, which spelled out a separate literal for every invocation.
        return [(2.0, 4), (4.0, 2)]

    TUNE_OBJECT_FACTORY.set_temp_path(str(tmpdir))
    search_space = keras_space(MockSpec, l1=RandInt(8, 16), l2=RandInt(8, 24))

    with raises(TuneCompileError):
        suggest_keras_models_by_sha(search_space, plan=_plan())

    # NOTE(review): the second argument to sample() looks like a seed — confirm.
    sampled_space = search_space.sample(6, 0)
    reports = suggest_keras_models_by_sha(sampled_space, plan=_plan(), top_n=2)
    for report in reports:
        print(report.jsondict)
    assert len(reports) == 2
def test_objective():
    """Run ``KerasObjective`` iteratively in both continuous and non-continuous modes.

    Each run uses budgets ``[3, 3, 4]`` on the first configuration of a
    fixed (non-random) search space. The run must finish with a final
    reported metric below 15.
    """

    def _check_last_metric(reports):
        # Only the final report is checked; earlier ones may be higher.
        assert reports[-1].metric < 15

    search_space = keras_space(MockSpec, l1=16, l2=16)
    objective = KerasObjective(_TYPE_DICT)

    for continuous in (True, False):
        # A fresh Trial is built on every pass, as in the original loop.
        first_params = list(search_space)[0]
        validate_iterative_objective(
            objective,
            Trial("a", params=first_params),
            budgets=[3, 3, 4],
            continuous=continuous,
            validator=_check_last_metric,
        )