예제 #1
0
def test_preprocessing_data_pipeline_no_running_stage(with_dataset):
    """Check an auto dataset created with ``running_stage=None``.

    The test expects that length/sample access fails while no stage is set,
    that no load hook has been counted so far, and that assigning
    ``RunningStage.TRAINING`` afterwards results in one train ``load_data`` call.

    Args:
        with_dataset: Flag forwarded to ``_AutoDatasetTestPreprocess``; it selects
            which counters (and dataset marker attributes) are inspected below.
    """
    data_pipeline = DataPipeline(_AutoDatasetTestPreprocess(with_dataset))
    auto_dataset = data_pipeline._generate_auto_dataset(range(10), running_stage=None)

    # Without a running stage, walking over the dataset is expected to raise.
    with pytest.raises(RuntimeError, match='`__len__` for `load_sample`'):
        for sample_index in range(len(auto_dataset)):
            auto_dataset[sample_index]

    # The load hooks will only be triggered once a running stage is set,
    # so every counter must still be at zero here.
    preprocess = data_pipeline._preprocess_pipeline
    if not with_dataset:
        assert preprocess.load_sample_count == 0
        assert preprocess.load_data_count == 0
    else:
        assert not hasattr(auto_dataset, 'load_sample_was_called')
        assert not hasattr(auto_dataset, 'load_data_was_called')
        assert preprocess.load_sample_with_dataset_count == 0
        assert preprocess.load_data_with_dataset_count == 0

    auto_dataset.running_stage = RunningStage.TRAINING

    # Re-read the preprocess object after the stage change before checking
    # that the train ``load_data`` hook was counted exactly once.
    preprocess = data_pipeline._preprocess_pipeline
    if not with_dataset:
        assert preprocess.train_load_data_count == 1
    else:
        assert preprocess.train_load_data_with_dataset_count == 1
        assert auto_dataset.train_load_data_was_called
예제 #2
0
def test_preprocessing_data_pipeline_with_running_stage(with_dataset):
    """Check an auto dataset created directly with ``RunningStage.TRAINING``.

    The test expects the dataset to report 10 items, every item to be
    indexable, the train ``load_sample`` hook to be counted once per item and
    the train ``load_data`` hook to be counted once.

    Args:
        with_dataset: Flag forwarded to ``_AutoDatasetTestPreprocess``; it selects
            which counters (and dataset marker attributes) are inspected below.
    """
    data_pipeline = DataPipeline(_AutoDatasetTestPreprocess(with_dataset))
    stage = RunningStage.TRAINING
    auto_dataset = data_pipeline._generate_auto_dataset(range(10), running_stage=stage)

    assert len(auto_dataset) == 10

    # Touch every sample so the per-sample hook is exercised for each index.
    for sample_index in range(len(auto_dataset)):
        auto_dataset[sample_index]

    preprocess = data_pipeline._preprocess_pipeline
    if not with_dataset:
        assert preprocess.train_load_sample_count == len(auto_dataset)
        assert preprocess.train_load_data_count == 1
    else:
        assert auto_dataset.train_load_sample_was_called
        assert auto_dataset.train_load_data_was_called
        assert preprocess.train_load_sample_with_dataset_count == len(auto_dataset)
        assert preprocess.train_load_data_with_dataset_count == 1