def test_preprocessing_data_pipeline_no_running_stage(with_dataset):
    """Check that a dataset generated without a running stage defers loading.

    While ``running_stage`` is ``None``, taking the dataset's length must raise
    a ``RuntimeError`` whose message matches ``__len__`` / ``load_sample``, and
    none of the preprocess load counters may have moved. Assigning
    ``RunningStage.TRAINING`` afterwards must run the training ``load_data``
    hook exactly once.

    Args:
        with_dataset: forwarded to ``_AutoDatasetTestPreprocess``. When truthy,
            the ``*_with_dataset_count`` counters and the ``*_was_called``
            attributes on the dataset are checked; otherwise the plain
            ``*_count`` counters are checked.
    """
    data_pipeline = DataPipeline(_AutoDatasetTestPreprocess(with_dataset))
    dataset = data_pipeline._generate_auto_dataset(range(10), running_stage=None)

    # With no running stage, ``len(dataset)`` itself is expected to raise, so
    # the loop body never gets to index a sample.
    with pytest.raises(RuntimeError, match='`__len__` for `load_sample`'):
        for index in range(len(dataset)):
            dataset[index]

    # Nothing has been loaded yet: no hook attributes on the dataset, and every
    # load counter on the preprocess object is still zero.
    if with_dataset:
        assert not hasattr(dataset, 'load_sample_was_called')
        assert not hasattr(dataset, 'load_data_was_called')
    preprocess = data_pipeline._preprocess_pipeline
    if with_dataset:
        assert preprocess.load_sample_with_dataset_count == 0
        assert preprocess.load_data_with_dataset_count == 0
    else:
        assert preprocess.load_sample_count == 0
        assert preprocess.load_data_count == 0

    # Setting the stage is what triggers the (training) ``load_data`` hook.
    dataset.running_stage = RunningStage.TRAINING

    if with_dataset:
        assert preprocess.train_load_data_with_dataset_count == 1
        assert dataset.train_load_data_was_called
    else:
        assert preprocess.train_load_data_count == 1
def test_preprocessing_data_pipeline_with_running_stage(with_dataset):
    """Check the load hooks of a dataset generated in the training stage.

    The dataset built from ``range(10)`` with ``RunningStage.TRAINING`` must
    report a length of 10. After every index has been read once, the training
    ``load_sample`` counter must equal the dataset length and the training
    ``load_data`` counter must be exactly 1.

    Args:
        with_dataset: forwarded to ``_AutoDatasetTestPreprocess``. When truthy,
            the ``*_with_dataset_count`` counters and the ``*_was_called``
            attributes on the dataset are checked; otherwise the plain
            ``*_count`` counters are checked.
    """
    data_pipeline = DataPipeline(_AutoDatasetTestPreprocess(with_dataset))
    dataset = data_pipeline._generate_auto_dataset(
        range(10), running_stage=RunningStage.TRAINING
    )
    assert len(dataset) == 10

    # Read each sample once; the assertions below expect one training
    # ``load_sample`` call per index.
    for index in range(len(dataset)):
        dataset[index]

    preprocess = data_pipeline._preprocess_pipeline
    if with_dataset:
        assert dataset.train_load_sample_was_called
        assert dataset.train_load_data_was_called
        assert preprocess.train_load_sample_with_dataset_count == len(dataset)
        assert preprocess.train_load_data_with_dataset_count == 1
    else:
        assert preprocess.train_load_sample_count == len(dataset)
        assert preprocess.train_load_data_count == 1