Example 1
def test_extended_summary_working(kwargs, capsys):
    state = TunerState('test', None, **kwargs)
    state.summary()
    summary_out = capsys.readouterr()
    state.summary(extended=True)
    extended_summary_out = capsys.readouterr()
    assert summary_out.out.count(":") < extended_summary_out.out.count(":")
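The tests in these examples all rely on a `kwargs` pytest fixture that is not included in the snippets. Judging from the keys read in Example 3 (results_dir, tmp_dir, export_dir), a minimal sketch of such a fixture could look like the following; the use of tmpdir_factory and the directory names are assumptions for illustration, not part of the original test suite.

import pytest


@pytest.fixture
def kwargs(tmpdir_factory):
    # Hypothetical fixture: provides the directory arguments TunerState expects.
    base = tmpdir_factory.mktemp('tuner_state')
    return {
        'results_dir': str(base.join('results')),
        'tmp_dir': str(base.join('tmp')),
        'export_dir': str(base.join('export')),
    }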
Example 2
    def __init__(self, model_fn, objective, name, distributions, **kwargs):
        """ Tuner abstract class

        Args:
            model_fn (function): Function that return a Keras model
            name (str): name of the tuner
            objective (str): Which objective the tuner optimize for
            distributions (Distributions): distributions object

        Notes:
            All meta data and varialbles are stored into self.state
            defined in ../states/tunerstate.py
        """

        # hypertuner state init
        self.state = TunerState(name, objective, **kwargs)
        self.stats = self.state.stats  # shorthand access
        self.cloudservice = CloudService()

        # check model function
        if not model_fn:
            fatal("Model function can't be empty")
        try:
            mdl = model_fn()
        except:
            traceback.print_exc()
            fatal("Invalid model function")

        if not isinstance(mdl, Model):
            t = "tensorflow.keras.models.Model"
            fatal("Invalid model function: Doesn't return a %s object" % t)

        # function is valid - recording it
        self.model_fn = model_fn

        # Initializing distributions
        hparams = config._DISTRIBUTIONS.get_hyperparameters_config()
        if len(hparams) == 0:
            warning("No hyperparameters used in model function. Are you sure?")

        # set global distribution object to the one requested by tuner
        # !MUST be after _eval_model_fn()
        config._DISTRIBUTIONS = distributions(hparams)

        # instances management
        self.max_fail_streak = 5  # how many failures before giving up
        self.instance_states = InstanceStatesCollection()

        # previous models
        print("Loading from %s" % self.state.host.results_dir)
        count = self.instance_states.load_from_dir(self.state.host.results_dir,
                                                   self.state.project,
                                                   self.state.architecture)
        self.stats.instance_states_previously_trained = count
        info("Tuner initialized")
Example 3
def test_summary(kwargs, capsys):
    state = TunerState('test', None, **kwargs)
    state.summary()
    captured = capsys.readouterr()
    to_test = [
        'results: %s' % kwargs.get('results_dir'),
        'tmp: %s' % kwargs.get('tmp_dir'),
        'export: %s' % kwargs.get('export_dir'),
    ]
    for s in to_test:
        assert s in captured.out
Example 4
def test_invalid_max_epochs(kwargs):
    with pytest.raises(ValueError):
        TunerState('test', None, max_epochs=[], **kwargs)
Example 5
def test_invalid_epoch_budget(kwargs):
    with pytest.raises(ValueError):
        TunerState('test', None, epoch_budget=[], **kwargs)
Example 6
def test_invalid_user_info(kwargs):
    with pytest.raises(ValueError):
        TunerState('test', None, user_info=[], **kwargs)

    with pytest.raises(ValueError):
        TunerState('test', None, user_info='bad', **kwargs)
Example 7
def test_is_serializable(kwargs):
    st = TunerState('test', None, **kwargs)
    is_serializable(st)