Example #1
def test_softmax_model(fake_dataset: DictsDataset):
    tf.logging.set_verbosity(tf.logging.DEBUG)

    features = fake_dataset.features
    steps = np.random.randint(2, 5)
    run_data = gen.run_data(model=MnistSoftmaxModel())
    estimator = training.create_estimator(run_data)

    estimator.train(input_fn=lambda: fake_dataset.as_tuple(), steps=steps)
    names = [gen.random_str() for _ in range(2)]
    eval_results = estimator.evaluate(input_fn=lambda: fake_dataset.as_tuple(),
                                      steps=1,
                                      name=names[0])
    estimator.evaluate(input_fn=lambda: fake_dataset.as_tuple(),
                       steps=1,
                       name=names[1])

    _assert_log_dirs(filenames.get_run_logs_data_dir(run_data), names)

    loss = eval_results['loss']
    global_step = eval_results['global_step']
    accuracy = eval_results['accuracy']
    assert loss.shape == ()
    assert global_step == steps
    assert accuracy.shape == ()

    predictions_generator = estimator.predict(input_fn=lambda: features)

    # One prediction is yielded per example in the input batch.
    for _ in range((list(features.values())[0]).shape[0]):
        predictions = next(predictions_generator)
        assert predictions['probabilities'].shape == (2, )
        assert predictions[consts.INFERENCE_CLASSES].shape == ()
Example #2
def test_should_not_throw_if_model_has_more_checkpoints():
    run_data = gen.run_data(with_model_dir=True)

    model_dir = filenames.get_run_logs_data_dir(run_data)
    empty_checkpoints = [(model_dir / ("model.ckpt-{}.foobar".format(x)))
                         for x in range(2)]
    for f in empty_checkpoints:
        f.write_text("this is sparta")
    inference.single_run_inference(run_data=run_data, show=False)
Example #3
def test_should_throw_if_model_has_only_0_step_checkpoints():
    run_data = gen.run_data(with_model_dir=True)

    model_dir = filenames.get_run_logs_data_dir(run_data)
    empty_checkpoints = [
        model_dir / ("model.ckpt-0.foobar{}".format(x)) for x in range(5)
    ]
    for f in empty_checkpoints:
        f.write_text("this is sparta")
    with pytest.raises(AssertionError):
        inference.single_run_inference(run_data=run_data, show=False)
Example #4
def run_data(model=None,
             launcher_name="launcher_name",
             runs_directory_name="runs_directory_name",
             is_experiment=False,
             run_no=1,
             models_count=1,
             with_model_dir=False):
    _run_data = RunData(
        model=model if model is not None else testing_classes.FakeModel(),
        launcher_name=launcher_name,
        runs_directory_name=runs_directory_name,
        is_experiment=is_experiment,
        run_no=run_no,
        models_count=models_count,
        launcher_params={})
    if with_model_dir:
        filenames.get_run_logs_data_dir(_run_data).mkdir(parents=True,
                                                         exist_ok=True)
    return _run_data
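# A minimal usage sketch (not part of the original suite): the call patterns
# below mirror Examples #2 and #12; the helper name is illustrative only.
def _run_data_usage_sketch():
    plain = run_data()  # defaults to testing_classes.FakeModel, nothing on disk
    with_dir = run_data(with_model_dir=True)  # pre-creates the logs directory
    assert filenames.get_run_logs_data_dir(with_dir).exists()
    return plain, with_dir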
Example #5
def create_estimator(run_data: RunData):
    model = run_data.model
    utils.log('Creating estimator from model: {}'.format(model.summary))
    model_dir = str(filenames.get_run_logs_data_dir(run_data))
    params = model.params
    params[consts.MODEL_DIR] = model_dir
    return tf.estimator.Estimator(model_fn=model.get_model_fn(),
                                  model_dir=model_dir,
                                  config=tf.estimator.RunConfig(
                                      keep_checkpoint_max=1,
                                      save_checkpoints_secs=60 * 30),
                                  params=params)
Example #6
def test_should_delete_existing_run_directory_for_default_launcher(
        when, rmtree_mock):
    run_data = gen.run_data(is_experiment=False)

    run_logs_dir = filenames.get_run_logs_data_dir(run_data)

    when(utils).check_filepath(filename=run_logs_dir,
                               exists=True,
                               is_directory=True,
                               is_empty=ANY).thenReturn(True)

    before_run._prepare_dirs(None, run_data)
    rmtree_mock.assert_called_once_with(str(run_logs_dir))
Example #7
def test_create_pair_summaries(patched_dataset_reading):
    provider = patched_dataset_reading.param
    run_data = gen.run_data(model=FakeModel(data_provider=provider()))
    dir_with_pair_summaries = filenames.get_run_logs_data_dir(
        run_data) / 'features'
    assert utils.check_filepath(dir_with_pair_summaries, exists=False)

    image_summaries.create_pair_summaries(run_data)

    assert utils.check_filepath(dir_with_pair_summaries,
                                is_directory=True,
                                is_empty=False)
    assert len(list(dir_with_pair_summaries.iterdir())) == 1
Example #8
def create_text_summary(run_data: RunData):
    tf.reset_default_graph()
    with tf.Session() as sess:
        txt_summary = tf.summary.text(
            'configuration',
            tf.constant(config.pretty_full_dict_summary(run_data)))

        log_dir = filenames.get_run_logs_data_dir(run_data)
        log_dir.mkdir(exist_ok=True, parents=True)
        writer = tf.summary.FileWriter(str(log_dir), sess.graph)

        sess.run(tf.global_variables_initializer())

        summary = sess.run(txt_summary)
        writer.add_summary(summary)
        writer.flush()
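# Hedged companion sketch: one way to check what create_text_summary wrote,
# assuming TF 1.x, where tf.train.summary_iterator reads event files. The
# helper name is illustrative and not part of the original module.
def _read_summary_tags(run_data: RunData):
    log_dir = filenames.get_run_logs_data_dir(run_data)
    tags = []
    for event_file in log_dir.glob('events.out.tfevents.*'):
        for event in tf.train.summary_iterator(str(event_file)):
            for value in event.summary.value:
                tags.append(value.tag)  # the 'configuration' summary lands here
    return tags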
Example #9
def _prepare_log_dir(run_data: RunData):
    log_dir = filenames.get_run_logs_data_dir(run_data)
    if utils.check_filepath(filename=log_dir,
                            exists=True,
                            is_directory=True,
                            is_empty=False):
        utils.log(
            'Found non-empty logs directory from previous runs: {}'.format(
                log_dir))
        if config[consts.REMOVE_OLD_MODEL_DIR]:
            utils.log('Deleting old model_dir: {}'.format(log_dir))
            shutil.rmtree(str(log_dir))
    else:
        utils.log(
            'No non-empty logs directory from previous runs found. '
            'Creating: {}'.format(log_dir))
        # exist_ok=True: this branch is also reached when the directory
        # exists but is empty, and mkdir must not fail in that case.
        log_dir.mkdir(exist_ok=True, parents=True)
Example #10
def create_pair_summaries(run_data: RunData):
    dataset_provider_cls = run_data.model.raw_data_provider
    tf.reset_default_graph()
    batch_size = 10
    utils.log('Creating {} sample features summaries'.format(batch_size))
    dataset: tf.data.Dataset = run_data.model.dataset_provider.supply_dataset(
        dataset_spec=DatasetSpec(
            dataset_provider_cls,
            DatasetType.TEST,
            with_excludes=False,
            encoding=run_data.model.dataset_provider.is_encoded()),
        shuffle_buffer_size=10000,
        batch_size=batch_size,
        prefetch=False)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with tf.Session() as sess:
        left = next_element[0][consts.LEFT_FEATURE_IMAGE]
        right = next_element[0][consts.RIGHT_FEATURE_IMAGE]
        pair_labels = next_element[1][consts.PAIR_LABEL]
        left_labels = next_element[1][consts.LEFT_FEATURE_LABEL]
        right_labels = next_element[1][consts.RIGHT_FEATURE_LABEL]
        pairs_imgs_summary = create_pair_summary(
            left, right, pair_labels, left_labels, right_labels,
            dataset_provider_cls.description)

        # tf.summary.image registers the op in the summaries collection,
        # so merge_all below picks it up; its return value is unused.
        tf.summary.image('paired_images',
                         pairs_imgs_summary,
                         max_outputs=batch_size)
        all_summaries = tf.summary.merge_all()

        features_dir = filenames.get_run_logs_data_dir(run_data) / 'features'
        features_dir.mkdir(exist_ok=True, parents=True)
        writer = tf.summary.FileWriter(str(features_dir), sess.graph)

        sess.run(tf.global_variables_initializer())

        summary = sess.run(all_summaries)
        writer.add_summary(summary)
        writer.flush()
Example #11
def _check_model_checkpoint_existence(run_data: RunData):
    strict: bool = config[consts.IS_INFER_CHECKPOINT_OBLIGATORY]
    if not strict:
        utils.log("Not checking checkpoint existence")
        return
    model_dir = filenames.get_run_logs_data_dir(run_data)
    assert model_dir.exists(), "{} does not exist - no model to load!".format(
        model_dir)

    checkpoints = model_dir.glob('*.ckpt-*')
    checkpoints_with_number = {
        x
        for y in checkpoints for x in str(y).split('.') if x.startswith("ckpt")
    }
    step_numbers = {int(x.split('-')[-1]) for x in checkpoints_with_number}

    assert bool(step_numbers), "No checkpoints exist!"
    assert len(
        step_numbers
    ) > 1 or 0 not in step_numbers, "Only the step-0 checkpoint exists!"
    utils.log("Checkpoint directory: ok, max checkpoint number: {}".format(
        max(step_numbers)))
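# Self-contained walk-through of the parsing above, using made-up file names:
# 'model.ckpt-42.data-...' splits on '.' into a part 'ckpt-42', whose suffix
# after the last '-' is the step number.
def _demo_checkpoint_parsing():
    names = ['model.ckpt-0.index', 'model.ckpt-42.data-00000-of-00001']
    with_number = {
        part
        for name in names for part in name.split('.')
        if part.startswith('ckpt')
    }
    steps = {int(part.split('-')[-1]) for part in with_number}
    assert steps == {0, 42}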
Example #12
def test_should_pass_model_dir_to_estimator():
    model = FakeModel()
    run_data = gen.run_data(model)
    estimator = training.create_estimator(run_data)
    model_dir = estimator.params[consts.MODEL_DIR]
    assert model_dir == str(filenames.get_run_logs_data_dir(run_data))
Example #13
def test_run_logging_dir(run_dir_mock):
    run_logging_dir = filenames.get_run_logs_data_dir(gen.run_data())

    assert_that(str(run_logging_dir), ends_with('baz/logs'))


def _check_for_existence(run_data, glob, quantity=None):
    files = list(filenames.get_run_logs_data_dir(run_data).glob(glob))
    assert len(files) != 0
    if quantity is not None:
        assert len(files) == quantity
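# Hedged usage sketch: the glob pattern and expected count are assumptions,
# chosen to match the TensorFlow event files written by the summary helpers.
def _check_for_existence_example(run_data):
    _check_for_existence(run_data, glob='events.out.tfevents.*', quantity=1)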