def test_tuner_not_call_super_search_with_overwrite(
    _, final_fit, super_search, tmp_path
):
    """A tuner re-created on an already searched and saved directory skips the base search.

    The first tuner runs a search and saves its state into ``tmp_path``. The
    recorded calls on ``super_search`` are then cleared. A second tuner built
    on the same directory runs the same search, and ``super_search`` must not
    be called again.

    The mock arguments (``_``, ``final_fit``, ``super_search``) are presumably
    injected by ``mock.patch`` decorators not visible here -- TODO confirm.
    """

    def make_tuner():
        # Each tuner gets a freshly built graph but shares the same directory.
        return greedy.Greedy(
            hypermodel=test_utils.build_graph(), directory=tmp_path
        )

    search_kwargs = dict(x=None, epochs=10, validation_data=None)

    tuner = make_tuner()
    final_fit.return_value = mock.Mock(), mock.Mock(), mock.Mock()
    tuner.search(**search_kwargs)
    tuner.save()

    # Only calls made by the re-created tuner should be checked below.
    super_search.reset_mock()

    tuner = make_tuner()
    tuner.search(**search_kwargs)
    super_search.assert_not_called()
def test_greedy_oracle_populate_different_values(get_best_trials):
    """Two populated trials should not receive identical hyperparameter values.

    ``get_best_trials`` is presumably a ``mock.patch`` of the oracle's
    best-trial lookup (decorator not visible here). It is configured to return
    a single mock trial whose hyperparameters come from the built test graph.
    The oracle then populates trials "a" and "b". At least one hyperparameter
    value must differ between them.
    """
    hp = keras_tuner.HyperParameters()
    test_utils.build_graph().build(hp)
    oracle = greedy.GreedyOracle(objective="val_loss", seed=test_utils.SEED)
    trial = mock.Mock()
    trial.hyperparameters = hp
    get_best_trials.return_value = [trial]
    oracle.update_space(hp)
    values_a = oracle.populate_space("a")["values"]
    values_b = oracle.populate_space("b")["values"]
    # Use a generator inside any() rather than materialising a list for
    # all(): it reads directly as "some value differs" and stops at the
    # first mismatch.
    assert any(values_a[key] != values_b[key] for key in values_a)
def test_greedy_oracle_stop_reach_max_collision(
    get_best_trials, compute_values_hash
):
    """A constant values-hash should make the oracle report a STOPPED trial.

    ``compute_values_hash`` is forced to return 1, so every proposed set of
    values hashes identically. Trial "a" is populated first, and populating
    trial "b" is then expected to return status STOPPED (presumably because
    the oracle runs out of collision retries -- confirm against GreedyOracle).

    Both arguments are presumably injected by ``mock.patch`` decorators not
    visible here.
    """
    hp = keras_tuner.HyperParameters()
    test_utils.build_graph().build(hp)
    oracle = greedy.GreedyOracle(objective="val_loss", seed=test_utils.SEED)

    best_trial = mock.Mock()
    best_trial.hyperparameters = hp
    get_best_trials.return_value = [best_trial]
    compute_values_hash.return_value = 1

    oracle.update_space(hp)
    # The first population is kept only for its side effect on the oracle.
    _ = oracle.populate_space("a")["values"]

    second_status = oracle.populate_space("b")["status"]
    assert second_status == keras_tuner.engine.trial.TrialStatus.STOPPED
def test_tuner_call_super_with_early_stopping(
    _, final_fit, super_search, tmp_path
):
    """Greedy.search should invoke the base search with early stopping.

    ``called_with_early_stopping`` is a helper that is not visible in this
    chunk. It inspects the calls recorded on the ``super_search`` mock.
    """
    graph = test_utils.build_graph()
    tuner = greedy.Greedy(hypermodel=graph, directory=tmp_path)
    final_fit.return_value = (mock.Mock(), mock.Mock(), mock.Mock())

    tuner.search(x=None, epochs=10, validation_data=None)

    assert called_with_early_stopping(super_search)
def test_final_fit_with_specified_epochs(_, final_fit, super_search, tmp_path):
    """An explicit ``epochs`` passed to search reaches the mocked final fit.

    The first call recorded on ``final_fit`` must carry the same ``epochs``
    keyword that was given to ``tuner.search``.
    """
    requested_epochs = 10
    tuner = greedy.Greedy(hypermodel=test_utils.build_graph(), directory=tmp_path)
    final_fit.return_value = (mock.Mock(), mock.Mock(), mock.Mock())

    tuner.search(x=None, epochs=requested_epochs, validation_data=None)

    # Each recorded call unpacks as (positional_args, keyword_args).
    _, first_call_kwargs = final_fit.call_args_list[0]
    assert first_call_kwargs["epochs"] == requested_epochs
def test_tuner_does_not_crash_with_distribution_strategy(tmp_path):
    """Building the hypermodel succeeds when a MirroredStrategy is supplied.

    This is a smoke test with no assertion. It passes as long as building the
    hypermodel from the oracle's hyperparameters does not raise.
    """
    graph = test_utils.build_graph()
    strategy = tf.distribute.MirroredStrategy()
    tuner = greedy.Greedy(
        hypermodel=graph,
        directory=tmp_path,
        distribution_strategy=strategy,
    )

    oracle_hp = tuner.oracle.hyperparameters
    tuner.hypermodel.build(oracle_hp)
def test_no_final_fit_without_epochs_and_fov(
    _, _1, _2, get_best_models, final_fit, super_search, tmp_path
):
    """No final fit runs when search gets neither epochs nor validation data.

    The leading mock arguments are presumably injected by ``mock.patch``
    decorators not visible here; only ``final_fit`` is inspected.
    """
    tuner = greedy.Greedy(hypermodel=test_utils.build_graph(), directory=tmp_path)

    search_kwargs = {"x": None, "epochs": None, "validation_data": None}
    tuner.search(**search_kwargs)

    final_fit.assert_not_called()
def test_final_fit_best_epochs_if_epoch_unspecified(
    _, best_epochs, final_fit, super_search, tmp_path
):
    """With ``epochs=None``, the final fit uses the best epoch count found.

    ``best_epochs`` is presumably a ``mock.patch`` configured (by a decorator
    not visible here) to report 2, which matches the expectation below --
    TODO confirm against the decorator stack.
    """
    expected_epochs = 2
    tuner = greedy.Greedy(hypermodel=test_utils.build_graph(), directory=tmp_path)
    final_fit.return_value = (mock.Mock(), mock.Mock(), mock.Mock())

    tuner.search(
        x=mock.Mock(),
        epochs=None,
        validation_split=0.2,
        validation_data=mock.Mock(),
    )

    first_call = final_fit.call_args_list[0]
    assert first_call[1]["epochs"] == expected_epochs