def test_extended_summary_working(kwargs, capsys):
    """summary(extended=True) should print more ':' characters than summary()."""
    state = TunerState('test', None, **kwargs)

    def printed_colons(**summary_kwargs):
        # Print the summary, then drain captured stdout and count its colons.
        state.summary(**summary_kwargs)
        return capsys.readouterr().out.count(":")

    # Left operand is evaluated first, so the plain summary is printed first.
    assert printed_colons() < printed_colons(extended=True)
def __init__(self, model_fn, objective, name, distributions, **kwargs):
    """Tuner abstract class.

    Args:
        model_fn (function): Function that returns a Keras model.
        objective (str): Which objective the tuner optimizes for.
        name (str): Name of the tuner.
        distributions (callable): Distributions class/factory; it is called
            with the hyperparameter config read from
            ``config._DISTRIBUTIONS`` and the result replaces that global.
        **kwargs: Forwarded to ``TunerState``.

    Notes:
        All metadata and variables are stored in ``self.state``, defined in
        ../states/tunerstate.py.
    """
    # hypertuner state init
    self.state = TunerState(name, objective, **kwargs)
    self.stats = self.state.stats  # shorthand access
    self.cloudservice = CloudService()

    # check model function
    if not model_fn:
        fatal("Model function can't be empty")
    try:
        mdl = model_fn()
    except Exception:
        # Only errors raised while building the model are reported as an
        # invalid model function; KeyboardInterrupt/SystemExit propagate.
        traceback.print_exc()
        fatal("Invalid model function")

    if not isinstance(mdl, Model):
        t = "tensorflow.keras.models.Model"
        fatal("Invalid model function: Doesn't return a %s object" % t)

    # function is valid - recording it
    self.model_fn = model_fn

    # Initializing distributions. Presumably the model_fn() call above is
    # what recorded these hyperparameters on config._DISTRIBUTIONS
    # (TODO confirm).
    hparams = config._DISTRIBUTIONS.get_hyperparameters_config()
    if len(hparams) == 0:
        warning("No hyperparameters used in model function. Are you sure?")

    # Set the global distribution object to the one requested by the tuner.
    # This MUST stay after the model_fn() call and the hparams read above,
    # because it replaces the object hparams was read from.
    config._DISTRIBUTIONS = distributions(hparams)

    # instances management
    self.max_fail_streak = 5  # how many failures before giving up
    self.instance_states = InstanceStatesCollection()

    # previous models
    print("Loading from %s" % self.state.host.results_dir)
    count = self.instance_states.load_from_dir(
        self.state.host.results_dir,
        self.state.project,
        self.state.architecture,
    )
    self.stats.instance_states_previously_trained = count
    info("Tuner initialized")
def test_summary(kwargs, capsys):
    """summary() should print the results, tmp and export directories."""
    TunerState('test', None, **kwargs).summary()
    printed = capsys.readouterr().out

    expected_lines = (
        ('results: %s', 'results_dir'),
        ('tmp: %s', 'tmp_dir'),
        ('export: %s', 'export_dir'),
    )
    for template, key in expected_lines:
        assert template % kwargs.get(key) in printed
def test_invalid_max_epochs(kwargs):
    """Passing a list as max_epochs must raise ValueError."""
    bad_max_epochs = []
    with pytest.raises(ValueError):
        TunerState('test', None, max_epochs=bad_max_epochs, **kwargs)
def test_invalid_epoch_budget(kwargs):
    """Passing a list as epoch_budget must raise ValueError."""
    bad_epoch_budget = []
    with pytest.raises(ValueError):
        TunerState('test', None, epoch_budget=bad_epoch_budget, **kwargs)
def test_invalid_user_info(kwargs):
    """Both a list and a plain string for user_info must raise ValueError."""
    # Checked in order; the first value that fails to raise ends the test.
    for bad_user_info in ([], 'bad'):
        with pytest.raises(ValueError):
            TunerState('test', None, user_info=bad_user_info, **kwargs)
def test_is_serializable(kwargs):
    """A freshly constructed TunerState should pass is_serializable."""
    # NOTE(review): the helper's return value is discarded, so this test only
    # fails if is_serializable itself raises/asserts on failure — confirm.
    is_serializable(TunerState('test', None, **kwargs))