Example #1
def test_softmax_model(fake_dataset: DictsDataset):
    tf.logging.set_verbosity(tf.logging.DEBUG)

    features = fake_dataset.features
    steps = np.random.randint(2, 5)
    run_data = gen.run_data(model=MnistSoftmaxModel())
    estimator = training.create_estimator(run_data)

    estimator.train(input_fn=lambda: fake_dataset.as_tuple(), steps=steps)
    names = [gen.random_str() for _ in range(2)]
    eval_results = estimator.evaluate(input_fn=lambda: fake_dataset.as_tuple(),
                                      steps=1,
                                      name=names[0])
    estimator.evaluate(input_fn=lambda: fake_dataset.as_tuple(),
                       steps=1,
                       name=names[1])

    _assert_log_dirs(filenames.get_run_logs_data_dir(run_data), names)

    loss = eval_results['loss']
    global_step = eval_results['global_step']
    accuracy = eval_results['accuracy']
    assert loss.shape == ()
    assert global_step == steps
    assert accuracy.shape == ()

    predictions_generator = estimator.predict(input_fn=lambda: features)
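    # predict() returns a generator yielding one predictions dict per input example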

    for _ in range((list(features.values())[0]).shape[0]):
        predictions = next(predictions_generator)
        assert predictions['probabilities'].shape == (2, )
        assert predictions[consts.INFERENCE_CLASSES].shape == ()
Example #2
def test_should_train_with_each_model(injected_raw_data_provider):
    tf.logging.set_verbosity(tf.logging.DEBUG)
    run_data = gen.run_data(model=injected_raw_data_provider())
    before_run.prepare_env([], run_data)
    training.train(run_data)

    verify_log_directory(run_data, config[consts.EXCLUDED_KEYS])
Example #3
def test_should_create_correct_infer_directory_for_single_model_launcher(
        patched_params):
    run_data = gen.run_data()
    result = filenames.get_infer_dir(run_data)
    expected = filenames._get_home_infer_dir() / run_data.model.summary

    assert result == expected
Example #4
def test_should_not_override_old_inference_log(patched_dataset_reading,
                                               patched_params):
    run_data = gen.run_data(model=MnistContrastiveModel())

    inference.single_run_inference(run_data=run_data, show=False)
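    # the first inference run writes a single log file; record its creation time below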
    infer_results_dir_path = filenames.get_infer_dir(run_data)

    old_cr_times = [
        x.stat().st_ctime
        for x in list(Path(infer_results_dir_path).iterdir())
        if x.suffix == consts.LOG
    ]

    inference.single_run_inference(run_data=run_data, show=False)

    new_cr_times = [
        x.stat().st_ctime
        for x in list(Path(infer_results_dir_path).iterdir())
        if x.suffix == consts.LOG
    ]

    assert len(old_cr_times) == 1
    assert len(new_cr_times) == 2

    assert old_cr_times[0] in new_cr_times
Example #5
def test_should_not_delete_existing_launcher_directory_for_experiment_launcher_when_its_not_first_run(
        when, rmtree_mock):
    run_data = gen.run_data(is_experiment=True, run_no=2)
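    # run_no=2 means an earlier run already prepared the launcher directory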

    result = before_run._prepare_launcher_dir(run_data)

    assert result is None
    rmtree_mock.assert_not_called()
Example #6
def test_should_not_delete_existing_launcher_directory_for_default_launcher(
        when, rmtree_mock):
    run_data = gen.run_data(is_experiment=False, run_no=1)

    result = before_run._prepare_launcher_dir(run_data)

    assert result is None
    rmtree_mock.assert_not_called()
Example #7
def test_should_create_correct_infer_directory_for_experiment_launcher(
        patched_params):
    run_data = gen.run_data(is_experiment=True)
    result = filenames.get_infer_dir(run_data)
    expected = (filenames._get_home_infer_dir() / run_data.launcher_name /
                run_data.model.summary)

    assert result == expected
Example #8
def test_should_not_throw_if_model_has_more_checkpoints():
    run_data = gen.run_data(with_model_dir=True)

    model_dir = filenames.get_run_logs_data_dir(run_data)
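    # write dummy checkpoint files for global steps 0 and 1; a step above 0 is enough for inference to proceed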
    empty_checkpoints = [(model_dir / ("model.ckpt-{}.foobar".format(x)))
                         for x in range(2)]
    for checkpoint_file in empty_checkpoints:
        checkpoint_file.write_text("this is sparta")
    inference.single_run_inference(run_data=run_data, show=False)
Example #9
def test_should_provide_single_run_data_for_experiment_launcher(mocker):
    launcher = FakeExperimentLauncher([FakeModel(), FakeModel()])
    run_data = gen.run_data()
    mocker.patch('src.estimator.launcher.providing_launcher.provide_launcher',
                 return_value=launcher)
    mocker.patch('src.utils.utils.user_run_selection', return_value=run_data)
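    # user_run_selection is stubbed so it returns the prepared run_data directly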

    result = providing_launcher.provide_single_run_data()
    assert result == run_data
Example #10
def test_should_throw_if_model_has_only_0_step_checkpoints():
    run_data = gen.run_data(with_model_dir=True)

    model_dir = filenames.get_run_logs_data_dir(run_data)
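    # create dummy checkpoint files that only reference global step 0, which inference rejects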
    empty_checkpoints = [
        model_dir / ("model.ckpt-0.foobar{}".format(x)) for x in range(5)
    ]
    for checkpoint_file in empty_checkpoints:
        checkpoint_file.write_text("this is sparta")
    with pytest.raises(AssertionError):
        inference.single_run_inference(run_data=run_data, show=False)
Example #11
def test_should_create_summaries_for_different_models(patched_dataset_reading,
                                                      patched_params):
    model = patched_dataset_reading.param
    run_data = gen.run_data(model=model())

    inference.single_run_inference(run_data=run_data, show=False)

    infer_results_dir_path = filenames.get_infer_dir(run_data)
    assert utils.check_filepath(infer_results_dir_path,
                                is_directory=True,
                                is_empty=False)
Example #12
def test_should_delete_existing_run_directory_for_default_launcher(
        when, rmtree_mock):
    run_data = gen.run_data(is_experiment=False)

    run_logs_dir = filenames.get_run_logs_data_dir(run_data)

    when(utils).check_filepath(filename=run_logs_dir,
                               exists=True,
                               is_directory=True,
                               is_empty=ANY).thenReturn(True)

    before_run._prepare_dirs(None, run_data)
    rmtree_mock.assert_called_once_with(str(run_logs_dir))
Example #13
def test_create_pair_summaries(patched_dataset_reading):
    provider = patched_dataset_reading.param
    run_data = gen.run_data(model=FakeModel(data_provider=provider()))
    dir_with_pair_summaries = filenames.get_run_logs_data_dir(
        run_data) / 'features'
    assert utils.check_filepath(dir_with_pair_summaries, exists=False)

    image_summaries.create_pair_summaries(run_data)

    assert utils.check_filepath(dir_with_pair_summaries,
                                is_directory=True,
                                is_empty=False)
    assert len(list(dir_with_pair_summaries.iterdir())) == 1
Example #14
def test_should_delete_existing_launcher_directory_for_experiment_launcher_during_first_run(
        when, rmtree_mock):
    run_data = gen.run_data(is_experiment=True, run_no=1)
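    # the first run (run_no=1) of an experiment launcher may wipe an existing launcher directory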
    launcher_dir = filenames.get_launcher_dir(run_data)
    when(utils).check_filepath(filename=launcher_dir,
                               exists=True,
                               is_directory=True,
                               is_empty=ANY).thenReturn(True)

    result = before_run._prepare_launcher_dir(run_data)

    assert result == launcher_dir
    rmtree_mock.assert_called_once_with(str(launcher_dir))
Example #15
def test_should_create_correct_number_of_inference_files(
        patched_dataset_reading, patched_params):
    model = patched_dataset_reading.param()
    run_data = gen.run_data(model=model)

    inference.single_run_inference(run_data=run_data, show=False)

    infer_results_dir_path = filenames.get_infer_dir(run_data)
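    # a model that produces a 2D embedding is expected to write four inference files instead of two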
    expected_file_count = 4 if model.produces_2d_embedding else 2
    assert utils.check_filepath(infer_results_dir_path,
                                is_directory=True,
                                is_empty=False,
                                expected_len=expected_file_count)
Example #16
def test_should_call_in_memory_evaluator_hooks(input_fn_spies,
                                               patched_dataset_reading,
                                               patched_excluded):
    (train_input_fn_spy, eval_input_fn_spy,
     eval_with_excludes_fn_spy) = input_fn_spies
    run_data = gen.run_data(model=patched_dataset_reading.param())
    before_run.prepare_env([], run_data)
    training.train(run_data)

    train_input_fn_spy.assert_called_once()
    eval_input_fn_spy.assert_called_once()
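    # the eval input_fn that excludes keys is only exercised when EXCLUDED_KEYS is configured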
    assert eval_with_excludes_fn_spy.call_count == (
        1 if config[consts.EXCLUDED_KEYS] else 0)
    verify_log_directory(run_data, config[consts.EXCLUDED_KEYS])
    assert run_data.model.model_fn_calls == (3 if config[consts.EXCLUDED_KEYS]
                                             else 2)
Example #17
def test_should_override_plots_with_newer_inference(patched_dataset_reading,
                                                    patched_params):
    run_data = gen.run_data(model=MnistContrastiveModel())

    inference.single_run_inference(run_data=run_data, show=False)
    infer_results_dir_path = filenames.get_infer_dir(run_data)
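    # record creation times of the non-log artifacts (the generated plots) from the first run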

    old_cr_times = [
        x.stat().st_ctime
        for x in list(Path(infer_results_dir_path).iterdir())
        if x.suffix != consts.LOG
    ]

    inference.single_run_inference(run_data=run_data, show=False)

    new_cr_times = [
        x.stat().st_ctime
        for x in list(Path(infer_results_dir_path).iterdir())
        if x.suffix != consts.LOG
    ]

    # every plot written by the first run should have been replaced by the second run
    plots_were_overridden = [x not in new_cr_times for x in old_cr_times]
    assert_that(plots_were_overridden, only_contains(True))
Example #18
def test_run_text_logs_dir(run_dir_mock):
    run_logging_dir = filenames.get_run_text_logs_dir(gen.run_data())

    assert_that(str(run_logging_dir), ends_with('baz/text_logs'))
Example #19
def test_should_pass_model_dir_to_estimator():
    model = FakeModel()
    run_data = gen.run_data(model)
    estimator = training.create_estimator(run_data)
    model_dir = estimator.params[consts.MODEL_DIR]
    assert model_dir == str(filenames.get_run_logs_data_dir(run_data))
Example #20
def test_should_throw_if_model_has_no_checkpoints():
    run_data = gen.run_data(with_model_dir=True)

    with pytest.raises(AssertionError):
        inference.single_run_inference(run_data=run_data, show=False)
Example #21
def test_should_throw_if_model_dir_not_exists():
    run_data = gen.run_data(with_model_dir=False)

    with pytest.raises(AssertionError):
        inference.single_run_inference(run_data=run_data, show=False)