def test_update_trial(tmp_dir):
    """An oracle that stops trials at step 3 should cap each metric history.

    A custom oracle marks the trial "STOPPED" when ``update_trial`` is called
    with ``step == 3`` (without delegating to the base class). After a
    2-trial search, every trial is expected to hold exactly three
    ``val_accuracy`` entries.
    """

    class MyOracle(kerastuner.Oracle):

        def _populate_space(self, _):
            # Draw one random value for every hyperparameter in the space.
            sampled = {
                hp.name: hp.random_sample()
                for hp in self.hyperparameters.space
            }
            return {'values': sampled, 'status': 'RUNNING'}

        def update_trial(self, trial_id, metrics, step=0):
            if step != 3:
                return super(MyOracle, self).update_trial(
                    trial_id, metrics, step)
            # Step 3 is intercepted: flag the trial as stopped instead of
            # forwarding the metrics to the base oracle.
            stopped_trial = self.trials[trial_id]
            stopped_trial.status = "STOPPED"
            return stopped_trial.status

    oracle = MyOracle(objective='val_accuracy', max_trials=2)
    tuner = kerastuner.Tuner(
        oracle=oracle, hypermodel=build_model, directory=tmp_dir)
    tuner.search(
        x=TRAIN_INPUTS,
        y=TRAIN_TARGETS,
        epochs=5,
        validation_data=(VAL_INPUTS, VAL_TARGETS))

    assert len(oracle.trials) == 2
    # Test that early stopping worked.
    for finished_trial in oracle.trials.values():
        history = finished_trial.metrics.get_history('val_accuracy')
        assert len(history) == 3
def test_checkpoint_removal(tmp_dir):
    """Older per-epoch checkpoints should be cleaned up during a search.

    After a single 21-epoch trial, the epoch-20 checkpoint is expected to
    exist while the epoch-10 checkpoint is expected to be gone.
    """

    def build_model(hp):
        hidden = keras.layers.Dense(hp.Int('size', 5, 10))
        model = keras.Sequential([hidden, keras.layers.Dense(1)])
        model.compile('sgd', 'mse', metrics=['accuracy'])
        return model

    oracle = kerastuner.tuners.randomsearch.RandomSearchOracle(
        objective='val_accuracy', max_trials=1, seed=1337)
    tuner = kerastuner.Tuner(
        oracle=oracle,
        hypermodel=build_model,
        directory=tmp_dir,
    )

    features = np.ones((1, 5))
    targets = np.ones((1, 1))
    tuner.search(
        features, targets, validation_data=(features, targets), epochs=21)

    all_trials = list(tuner.oracle.trials.values())
    tid = all_trials[0].trial_id
    assert tf.io.gfile.exists(tuner._get_checkpoint_fname(tid, 20))
    assert not tf.io.gfile.exists(tuner._get_checkpoint_fname(tid, 10))
def test_checkpoint_fname_no_tpu(tmp_dir):
    """Without a distribution strategy, checkpoint names must not end in '.h5'."""

    def build_model(hp):
        stack = [
            keras.layers.Dense(hp.Int('size', 5, 10)),
            keras.layers.Dense(1),
        ]
        model = keras.Sequential(stack)
        model.compile('sgd', 'mse', metrics=['accuracy'])
        return model

    tuner = kerastuner.Tuner(
        oracle=kerastuner.tuners.randomsearch.RandomSearchOracle(
            objective='val_accuracy', max_trials=1, seed=1337),
        hypermodel=build_model,
        directory=tmp_dir,
    )

    checkpoint_name = tuner._get_checkpoint_fname(trial_id=0, epoch=20)
    assert not checkpoint_name.endswith('.h5')
def test_checkpoint_fname_tpu(tmp_dir):
    """With a (mocked) TPUStrategy, checkpoint names should end in ".h5"."""

    def build_model(hp):
        stack = [
            keras.layers.Dense(hp.Int("size", 5, 10)),
            keras.layers.Dense(1),
        ]
        model = keras.Sequential(stack)
        model.compile("sgd", "mse", metrics=["accuracy"])
        return model

    # A MagicMock specced to TPUStrategy stands in for a real TPU setup.
    tpu_strategy = mock.MagicMock(spec=tf.distribute.TPUStrategy)
    tuner = kerastuner.Tuner(
        oracle=kerastuner.tuners.randomsearch.RandomSearchOracle(
            objective="val_accuracy", max_trials=1, seed=1337),
        hypermodel=build_model,
        directory=tmp_dir,
        distribution_strategy=tpu_strategy,
    )

    checkpoint_name = tuner._get_checkpoint_fname(trial_id=0, epoch=20)
    assert checkpoint_name.endswith(".h5")
def test_tuning_correctness(tmp_dir):
    """Recorded trial scores should match the mock hypermodel's losses.

    Runs a 2-trial random search over ``MockHyperModel`` with ``loss`` as the
    objective. It then checks two things:

    * each trial's score equals the minimum per-epoch average loss of the
      corresponding mock mode;
    * ``get_best_trials(1)`` returns the lowest-scoring trial.
    """
    tuner = kerastuner.Tuner(
        oracle=kerastuner.tuners.randomsearch.RandomSearchOracle(
            objective="loss", max_trials=2, seed=1337),
        hypermodel=MockHyperModel(),
        directory=tmp_dir,
    )
    tuner.search()
    assert len(tuner.oracle.trials) == 2

    m0_epochs = [float(np.average(x)) for x in MockHyperModel.mode_0]
    m1_epochs = [float(np.average(x)) for x in MockHyperModel.mode_1]

    # Score tracking correctness.
    # NOTE(review): pairing the lower-scoring trial with mode_0 assumes
    # min(mode_0) < min(mode_1) — confirm against MockHyperModel.
    first_trial, second_trial = sorted(
        tuner.oracle.trials.values(), key=lambda t: t.score)
    # Compare with a tolerance: the recorded score may be accumulated in a
    # different float precision than the float64 NumPy average computed here.
    assert np.isclose(first_trial.score, min(m0_epochs))
    assert np.isclose(second_trial.score, min(m1_epochs))
    assert tuner.oracle.get_best_trials(1)[0].trial_id == first_trial.trial_id
def test_report_status_to_oracle(tmp_dir):
    """The tuner should report status to the oracle and honor STOP responses.

    A custom oracle records every ``report_status`` call per trial. It
    answers ``OracleResponse.STOP`` once ``t == 2`` is reported. The test then
    checks two things:

    * the oracle and the tuner agree on the set of trial ids;
    * every trial received exactly three reports.

    NOTE(review): this test uses the ``populate_space`` / ``report_status`` /
    ``OracleResponse`` hooks. ``test_update_trial`` in this file uses the
    ``_populate_space`` / ``update_trial`` hooks instead. Confirm which Oracle
    API is current; this test may be stale.
    """

    class MyOracle(kerastuner.Oracle):

        def __init__(self):
            super(MyOracle, self).__init__()
            # Replaces the base attribute: maps trial_id -> list of
            # (score, t) tuples received via report_status.
            self.trials = collections.defaultdict(list)

        def populate_space(self, trial_id, space):
            values = {p.name: p.random_sample() for p in space}
            return {'values': values, 'status': 'RUN'}

        def report_status(self, trial_id, status, score=None, t=None):
            self.trials[trial_id].append((score, t))
            # Request early stopping once the report with t == 2 arrives.
            if t == 2:
                return kerastuner.engine.oracle.OracleResponse.STOP
            return kerastuner.engine.oracle.OracleResponse.OK

        def save(self, fname):
            return {}

    my_oracle = MyOracle()
    tuner = kerastuner.Tuner(
        oracle=my_oracle,
        hypermodel=build_model,
        objective='val_accuracy',
        max_trials=2,
        executions_per_trial=1,
        directory=tmp_dir)
    tuner.search(
        x=TRAIN_INPUTS,
        y=TRAIN_TARGETS,
        epochs=5,
        validation_data=(VAL_INPUTS, VAL_TARGETS))

    oracle_trial_ids = set(my_oracle.trials)
    tuner_trial_ids = {trial.trial_id for trial in tuner.trials}
    assert oracle_trial_ids == tuner_trial_ids

    # Test that early stopping worked: presumably one report each for
    # t = 0, 1, 2 before the STOP response took effect.
    for scores in my_oracle.trials.values():
        assert len(scores) == 3