Esempio n. 1
0
 def _test_get_best_trials():
     """Check that the oracle ranks ended trials by their objective score.

     The process whose ``KERASTUNER_TUNER_ID`` contains 'chief' starts the
     oracle server. Any other process acts as a worker. It creates ten
     trials, reports scores 0..9 for them, ends each one, and then asks for
     the three best trials.
     """
     hps = kt.HyperParameters()
     hps.Int('a', 0, 100, default=5)
     hps.Int('b', 0, 100, default=6)
     oracle = randomsearch.RandomSearchOracle(
         objective=kt.Objective('score', direction='max'),
         max_trials=10,
         hyperparameters=hps)
     oracle._set_project_dir(tmp_dir, 'untitled')
     tuner_id = os.environ['KERASTUNER_TUNER_ID']
     if 'chief' in tuner_id:
         oracle_chief.start_server(oracle)
     else:
         client = oracle_client.OracleClient(oracle)
         # Records which score was reported for each trial_id.
         trial_scores = {}
         for score in range(10):
             trial = client.create_trial(tuner_id)
             assert trial.status == "RUNNING"
             assert 'a' in trial.hyperparameters.values
             assert 'b' in trial.hyperparameters.values
             trial_id = trial.trial_id
             client.update_trial(trial_id, {'score': score})
             client.end_trial(trial_id)
             trial_scores[trial_id] = score
         # The objective direction is 'max' and the scores were 0..9,
         # so the three best trials must carry scores 9, 8 and 7.
         best_trials = client.get_best_trials(3)
         best_scores = [t.score for t in best_trials]
         assert best_scores == [9, 8, 7]
         # Check that trial_ids are correctly mapped to scores.
         for t in best_trials:
             assert trial_scores[t.trial_id] == t.score
Esempio n. 2
0
 def _test_get_space():
     """Check that a worker can fetch the oracle's hyperparameter space.

     When ``KERASTUNER_TUNER_ID`` contains 'chief', this process only starts
     the oracle server. Otherwise it builds an ``OracleClient``. It then
     verifies that the retrieved space holds the single ``a`` parameter
     with its default value 3.
     """
     search_space = kt.HyperParameters()
     search_space.Int('a', 0, 10, default=3)
     oracle = randomsearch.RandomSearchOracle(
         objective=kt.Objective('score', direction='max'),
         max_trials=10,
         hyperparameters=search_space)
     oracle._set_project_dir(tmp_dir, 'untitled')

     tuner_id = os.environ['KERASTUNER_TUNER_ID']
     if 'chief' in tuner_id:
         oracle_chief.start_server(oracle)
         return

     worker = oracle_client.OracleClient(oracle)
     fetched = worker.get_space()
     assert fetched.values == {'a': 3}
     assert len(fetched.space) == 1
Esempio n. 3
0
 def _test_end_trial():
     """Check that ending a trial with an explicit status is persisted.

     When ``KERASTUNER_TUNER_ID`` contains 'chief', this process only starts
     the oracle server. Otherwise it creates one trial and reports a score
     at step 2. It then ends the trial as "INVALID" and confirms that
     re-fetching the trial reports that status.
     """
     search_space = kt.HyperParameters()
     search_space.Int('a', 0, 10, default=5)
     oracle = randomsearch.RandomSearchOracle(
         objective=kt.Objective('score', direction='max'),
         max_trials=10,
         hyperparameters=search_space)
     oracle._set_project_dir(tmp_dir, 'untitled')

     tuner_id = os.environ['KERASTUNER_TUNER_ID']
     if 'chief' in tuner_id:
         oracle_chief.start_server(oracle)
         return

     worker = oracle_client.OracleClient(oracle)
     tid = worker.create_trial(tuner_id).trial_id
     worker.update_trial(tid, {'score': 1}, step=2)
     worker.end_trial(tid, "INVALID")
     refreshed = worker.get_trial(tid)
     assert refreshed.status == "INVALID"
Esempio n. 4
0
 def _test_create_trial():
     """Check that a newly created trial is running and samples valid values.

     When ``KERASTUNER_TUNER_ID`` contains 'chief', this process only starts
     the oracle server. Otherwise it creates a trial and asserts that its
     status is "RUNNING". It also checks that ``a`` lies within [0, 10] and
     that ``b`` is one of the declared choices.
     """
     search_space = kt.HyperParameters()
     search_space.Int('a', 0, 10, default=5)
     search_space.Choice('b', [1, 2, 3])
     oracle = randomsearch.RandomSearchOracle(
         objective=kt.Objective('score', direction='max'),
         max_trials=10,
         hyperparameters=search_space)
     oracle._set_project_dir(tmp_dir, 'untitled')

     tuner_id = os.environ['KERASTUNER_TUNER_ID']
     if 'chief' in tuner_id:
         oracle_chief.start_server(oracle)
         return

     worker = oracle_client.OracleClient(oracle)
     new_trial = worker.create_trial(tuner_id)
     assert new_trial.status == "RUNNING"
     sampled = new_trial.hyperparameters
     a_value = sampled.get('a')
     assert 0 <= a_value <= 10
     b_value = sampled.get('b')
     assert b_value in {1, 2, 3}
Esempio n. 5
0
    def _test_update_space():
        """Check that a worker can push a new space to the oracle and read it back.

        The oracle starts without hyperparameters. When
        ``KERASTUNER_TUNER_ID`` contains 'chief', this process only starts the
        oracle server. Otherwise it sends a space with ``a`` (default 5) and
        ``b`` (choices 1/2/3). It then verifies that both parameters and their
        default values are returned.
        """
        oracle = randomsearch.RandomSearchOracle(
            objective=kt.Objective('score', direction='max'),
            max_trials=10)
        oracle._set_project_dir(tmp_dir, 'untitled')

        tuner_id = os.environ['KERASTUNER_TUNER_ID']
        if 'chief' in tuner_id:
            oracle_chief.start_server(oracle)
            return

        worker = oracle_client.OracleClient(oracle)

        new_space = kt.HyperParameters()
        new_space.Int('a', 0, 10, default=5)
        new_space.Choice('b', [1, 2, 3])
        worker.update_space(new_space)

        fetched = worker.get_space()
        assert len(fetched.space) == 2
        # 'b' has no explicit default here; the test expects the first choice.
        defaults = fetched.values
        assert defaults['a'] == 5
        assert defaults['b'] == 1
Esempio n. 6
0
 def _test_update_trial():
     """Check that reported metrics are recorded in the trial's history.

     When ``KERASTUNER_TUNER_ID`` contains 'chief', this process only starts
     the oracle server. Otherwise it creates a trial and confirms it has no
     'score' metric yet. It then reports score 1 at step 2 and checks that
     the re-fetched trial's 'score' history holds exactly that observation.
     """
     search_space = kt.HyperParameters()
     search_space.Int('a', 0, 10, default=5)
     oracle = randomsearch.RandomSearchOracle(
         objective='score', max_trials=10, hyperparameters=search_space)
     oracle._set_project_dir(tmp_dir, 'untitled')

     tuner_id = os.environ['KERASTUNER_TUNER_ID']
     if 'chief' in tuner_id:
         oracle_chief.start_server(oracle)
         return

     worker = oracle_client.OracleClient(oracle)
     new_trial = worker.create_trial(tuner_id)
     assert 'score' not in new_trial.metrics.metrics
     tid = new_trial.trial_id
     worker.update_trial(tid, {'score': 1}, step=2)
     refreshed = worker.get_trial(tid)
     history = refreshed.metrics.get_history('score')
     assert history == [metrics_tracking.MetricObservation([1], step=2)]