Ejemplo n.º 1
0
def test_tuner_not_call_super_search_with_overwrite(final_fit, super_search,
                                                    tmp_path):
    """A second tuner on an already-searched directory skips the parent search.

    A first tuner searches and saves into ``tmp_path``. A second tuner built
    on the same directory then searches again, and the patched parent
    ``search`` must not be invoked for it.

    NOTE(review): ``final_fit`` / ``super_search`` are presumably injected by
    ``mock.patch`` decorators not visible here, and no ``overwrite`` argument
    is passed despite the test name (relies on the tuner default) -- confirm.
    """
    def fresh_tuner():
        # Same graph factory and same directory for both tuners, so the
        # second one sees whatever the first one saved.
        return greedy.Greedy(hypermodel=utils.build_graph(),
                             directory=tmp_path)

    first = fresh_tuner()
    first.search(epochs=10)
    first.save()
    # Discard calls recorded during the first search; only the second counts.
    super_search.reset_mock()

    second = fresh_tuner()
    second.search(epochs=10)

    super_search.assert_not_called()
Ejemplo n.º 2
0
def test_tuner_not_call_super_search_with_overwrite(_, final_fit, super_search,
                                                    tmp_path):
    """Re-searching a saved tuner directory must not call the parent search.

    ``final_fit`` is stubbed to return a 3-tuple of mocks so that ``search``
    can unpack its result.

    NOTE(review): the leading ``_`` and the mock arguments are presumably
    injected by ``mock.patch`` decorators not visible here -- confirm.
    """
    search_kwargs = dict(x=None, epochs=10, validation_data=None)

    tuner = greedy.Greedy(hypermodel=utils.build_graph(), directory=tmp_path)
    final_fit.return_value = tuple(mock.Mock() for _slot in range(3))

    tuner.search(**search_kwargs)
    tuner.save()
    # Only calls made by the second tuner should be checked below.
    super_search.reset_mock()

    tuner = greedy.Greedy(hypermodel=utils.build_graph(), directory=tmp_path)
    tuner.search(**search_kwargs)

    super_search.assert_not_called()
Ejemplo n.º 3
0
def test_no_final_fit_without_epochs_and_fov(get_best_models, final_fit,
                                             super_search, tmp_path):
    """No final fit happens when epochs is None and fit_on_val_data is False.

    NOTE(review): "fov" presumably abbreviates ``fit_on_val_data``; the mock
    arguments are presumably injected by ``mock.patch`` decorators -- confirm.
    """
    graph = utils.build_graph()
    tuner = greedy.Greedy(hypermodel=graph, directory=tmp_path)

    tuner.search(fit_on_val_data=False, epochs=None, x=None)

    final_fit.assert_not_called()
Ejemplo n.º 4
0
def test_tuner_call_super_with_early_stopping(final_fit, super_search,
                                              tmp_path):
    """With an explicit epoch count, the parent search gets early stopping.

    The check itself is delegated to the ``called_with_early_stopping``
    helper defined elsewhere in the test module.
    """
    graph = utils.build_graph()
    tuner = greedy.Greedy(hypermodel=graph, directory=tmp_path)

    tuner.search(epochs=10, x=None)

    used_early_stopping = called_with_early_stopping(super_search)
    assert used_early_stopping
Ejemplo n.º 5
0
def test_tuner_does_not_crash_with_distribution_strategy(tmp_path):
    """Smoke test: building the hypermodel under MirroredStrategy must not raise.

    There are no explicit assertions; the test passes if construction and
    ``hypermodel.build`` complete without an exception.
    """
    # Keep the original evaluation order: graph first, then the strategy.
    graph = utils.build_graph()
    strategy = tf.distribute.MirroredStrategy()
    tuner = greedy.Greedy(
        hypermodel=graph,
        directory=tmp_path,
        distribution_strategy=strategy,
    )

    hyperparameters = tuner.oracle.hyperparameters
    tuner.hypermodel.build(hyperparameters)
Ejemplo n.º 6
0
def test_final_fit_with_specified_epochs(_, final_fit, super_search, tmp_path):
    """An explicit ``epochs`` value is forwarded unchanged to the final fit.

    ``final_fit`` is stubbed with a 3-tuple of mocks so ``search`` can unpack
    its result.
    """
    tuner = greedy.Greedy(hypermodel=utils.build_graph(), directory=tmp_path)
    final_fit.return_value = tuple(mock.Mock() for _slot in range(3))

    tuner.search(validation_data=None, epochs=10, x=None)

    # Index 1 of a recorded call holds its keyword arguments.
    first_call_kwargs = final_fit.call_args_list[0][1]
    assert first_call_kwargs["epochs"] == 10
Ejemplo n.º 7
0
def test_tuner_call_super_with_early_stopping(_, final_fit, super_search, tmp_path):
    """The parent search is invoked with an early-stopping callback.

    ``final_fit`` is stubbed with a 3-tuple of mocks so ``search`` can unpack
    its result; the early-stopping check lives in a module-level helper.
    """
    graph = test_utils.build_graph()
    tuner = greedy.Greedy(hypermodel=graph, directory=tmp_path)
    final_fit.return_value = tuple(mock.Mock() for _slot in range(3))

    tuner.search(validation_data=None, epochs=10, x=None)

    used_early_stopping = called_with_early_stopping(super_search)
    assert used_early_stopping
Ejemplo n.º 8
0
def test_final_fit_with_specified_epochs(
        final_fit, super_search, tmp_path):
    """The final fit receives exactly the ``epochs`` value passed to search."""
    graph = utils.build_graph()
    tuner = greedy.Greedy(hypermodel=graph, directory=tmp_path)

    tuner.search(epochs=10, x=None)

    # Keyword arguments of the first recorded call to ``final_fit``.
    first_call_kwargs = final_fit.call_args_list[0][1]
    assert first_call_kwargs['epochs'] == 10
Ejemplo n.º 9
0
def test_final_fit_best_epochs_if_epoch_unspecified(best_epochs, final_fit,
                                                    super_search, tmp_path):
    """Without an explicit epoch count, the final fit uses the best epochs.

    NOTE(review): the expected value 2 presumably comes from the patched
    ``best_epochs`` mock configured by a decorator not shown here -- confirm.
    """
    tuner = greedy.Greedy(hypermodel=utils.build_graph(), directory=tmp_path)

    search_kwargs = {
        "x": mock.Mock(),
        "epochs": None,
        "fit_on_val_data": True,
        "validation_data": mock.Mock(),
    }
    tuner.search(**search_kwargs)

    first_call_kwargs = final_fit.call_args_list[0][1]
    assert first_call_kwargs["epochs"] == 2
Ejemplo n.º 10
0
def test_super_with_1k_epochs_if_epoch_unspecified(best_epochs, final_fit,
                                                   super_search, tmp_path):
    """With ``epochs=None`` the parent search runs 1000 epochs plus early stop."""
    tuner = greedy.Greedy(hypermodel=utils.build_graph(), directory=tmp_path)

    search_kwargs = {
        "x": mock.Mock(),
        "epochs": None,
        "fit_on_val_data": True,
        "validation_data": mock.Mock(),
    }
    tuner.search(**search_kwargs)

    first_super_kwargs = super_search.call_args_list[0][1]
    assert first_super_kwargs['epochs'] == 1000
    # A large epoch budget is only sensible if early stopping is attached.
    assert called_with_early_stopping(super_search)
Ejemplo n.º 11
0
def test_super_with_1k_epochs_if_epoch_unspecified(
    _, best_epochs, final_fit, super_search, tmp_path
):
    """With ``epochs=None`` the parent search gets 1000 epochs and early stop.

    ``final_fit`` is stubbed with a 3-tuple of mocks so ``search`` can unpack
    its result.
    """
    graph = test_utils.build_graph()
    tuner = greedy.Greedy(hypermodel=graph, directory=tmp_path)
    final_fit.return_value = tuple(mock.Mock() for _slot in range(3))

    search_kwargs = {
        "x": mock.Mock(),
        "epochs": None,
        "validation_split": 0.2,
        "validation_data": mock.Mock(),
    }
    tuner.search(**search_kwargs)

    first_super_kwargs = super_search.call_args_list[0][1]
    assert first_super_kwargs["epochs"] == 1000
    assert called_with_early_stopping(super_search)