def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor():
    """Check which running stages can build a worker preprocessor.

    With the module-level ``CustomPreprocess``, building the worker
    preprocessor for training and for predicting is expected to succeed,
    whereas validating and testing must raise a ``MisconfigurationException``
    whose message matches "are mutual exclusive".
    """
    pipeline = DataPipeline(CustomPreprocess())

    # Training is requested first and must not raise.
    pipeline.worker_preprocessor(RunningStage.TRAINING)

    # Validation and testing are both expected to be refused, in that order.
    for rejected_stage in (RunningStage.VALIDATING, RunningStage.TESTING):
        with pytest.raises(MisconfigurationException, match="are mutual exclusive"):
            pipeline.worker_preprocessor(rejected_stage)

    # Prediction is requested last and must not raise either.
    pipeline.worker_preprocessor(RunningStage.PREDICTING)
def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir):
    """Verify per-stage hook resolution and the resulting worker preprocessors.

    A local ``Preprocess`` subclass overrides a mix of un-prefixed and
    stage-prefixed hooks. The test first checks the name returned by
    ``DataPipeline._resolve_function_hierarchy`` for a fixed set of hooks under
    each running stage, then checks which callables end up wired into the
    worker preprocessor built for each stage.
    """

    class CustomPreprocess(Preprocess):
        """Preprocess overriding a selection of generic and stage-specific hooks."""

        def load_data(self, *_, **__):
            pass

        def test_load_data(self, *_, **__):
            pass

        def predict_load_data(self, *_, **__):
            pass

        def predict_load_sample(self, *_, **__):
            pass

        def val_load_sample(self, *_, **__):
            pass

        def val_pre_tensor_transform(self, *_, **__):
            pass

        def predict_to_tensor_transform(self, *_, **__):
            pass

        def train_post_tensor_transform(self, *_, **__):
            pass

        def test_collate(self, *_, **__):
            pass

        def val_per_sample_transform_on_device(self, *_, **__):
            pass

        def train_per_batch_transform_on_device(self, *_, **__):
            pass

        def test_per_batch_transform_on_device(self, *_, **__):
            pass

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess)

    # Every table below lists its entries in this stage order.
    stages = (
        RunningStage.TRAINING,
        RunningStage.VALIDATING,
        RunningStage.TESTING,
        RunningStage.PREDICTING,
    )

    def resolve_for(stage):
        # Map each known preprocess hook to the name resolved for ``stage``.
        return {
            hook: data_pipeline._resolve_function_hierarchy(
                hook, data_pipeline._preprocess_pipeline, stage, Preprocess
            )
            for hook in data_pipeline.PREPROCESS_FUNCS
        }

    resolved_per_stage = [resolve_for(stage) for stage in stages]

    # hook name -> expected resolved name for (train, val, test, predict)
    expected_names = (
        (
            "load_data",
            ("load_data", "load_data", "test_load_data", "predict_load_data"),
        ),
        (
            "load_sample",
            ("load_sample", "val_load_sample", "load_sample", "predict_load_sample"),
        ),
        (
            "pre_tensor_transform",
            (
                "pre_tensor_transform",
                "val_pre_tensor_transform",
                "pre_tensor_transform",
                "pre_tensor_transform",
            ),
        ),
        (
            "to_tensor_transform",
            (
                "to_tensor_transform",
                "to_tensor_transform",
                "to_tensor_transform",
                "predict_to_tensor_transform",
            ),
        ),
        (
            "post_tensor_transform",
            (
                "train_post_tensor_transform",
                "post_tensor_transform",
                "post_tensor_transform",
                "post_tensor_transform",
            ),
        ),
        (
            "collate",
            ("collate", "collate", "test_collate", "collate"),
        ),
        (
            "per_sample_transform_on_device",
            (
                "per_sample_transform_on_device",
                "val_per_sample_transform_on_device",
                "per_sample_transform_on_device",
                "per_sample_transform_on_device",
            ),
        ),
        (
            "per_batch_transform_on_device",
            (
                "train_per_batch_transform_on_device",
                "per_batch_transform_on_device",
                "test_per_batch_transform_on_device",
                "per_batch_transform_on_device",
            ),
        ),
    )

    for hook, wanted_per_stage in expected_names:
        for stage, resolved, wanted in zip(stages, resolved_per_stage, wanted_per_stage):
            assert resolved[hook] == wanted, f"{hook} resolved incorrectly for {stage}"

    # Build all four worker preprocessors before inspecting any of them.
    workers = [data_pipeline.worker_preprocessor(stage) for stage in stages]

    # Expected (pre_tensor, to_tensor, post_tensor, collate) callables per stage.
    expected_callables = (
        (
            preprocess.pre_tensor_transform,
            preprocess.to_tensor_transform,
            preprocess.train_post_tensor_transform,
            default_collate,
        ),
        (
            preprocess.val_pre_tensor_transform,
            preprocess.to_tensor_transform,
            preprocess.post_tensor_transform,
            data_pipeline._identity,
        ),
        (
            preprocess.pre_tensor_transform,
            preprocess.to_tensor_transform,
            preprocess.post_tensor_transform,
            preprocess.test_collate,
        ),
        (
            preprocess.pre_tensor_transform,
            preprocess.predict_to_tensor_transform,
            preprocess.post_tensor_transform,
            default_collate,
        ),
    )

    for worker, (pre_fn, to_fn, post_fn, collate_fn) in zip(workers, expected_callables):
        sample_transform = worker.per_sample_transform
        assert sample_transform.pre_tensor_transform.func == pre_fn
        assert sample_transform.to_tensor_transform.func == to_fn
        assert sample_transform.post_tensor_transform.func == post_fn
        assert worker.collate_fn.func == collate_fn
        # No stage overrides per_batch_transform, so all share the generic hook.
        assert worker.per_batch_transform.func == preprocess.per_batch_transform
def test_preprocess_transforms(tmpdir):
    """Exercise validation of transform dictionaries handed to a preprocess.

    Covers rejection of a non-dict transform and of an unknown transform key,
    the per-stage ``_<stage>_collate_in_worker_from_transform`` flags set on the
    preprocess, and the collate function selected by the worker preprocessor
    of each running stage.
    """
    conflict_pattern = "`per_batch_transform` and `per_sample_transform_on_device`"

    def check_collate_flags(candidate, train, val, test, predict):
        # Identity comparison on purpose: True, False and None are distinct outcomes.
        assert candidate._train_collate_in_worker_from_transform is train
        assert candidate._val_collate_in_worker_from_transform is val
        assert candidate._test_collate_in_worker_from_transform is test
        assert candidate._predict_collate_in_worker_from_transform is predict

    # A transform that is not a dict is refused.
    with pytest.raises(MisconfigurationException, match="Transform should be a dict."):
        DefaultPreprocess(train_transform="choco")

    # So is a dict carrying an unknown key.
    with pytest.raises(MisconfigurationException, match="train_transform contains {'choco'}. Only"):
        DefaultPreprocess(train_transform={"choco": None})

    preprocess = DefaultPreprocess(
        train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)}
    )
    # Stages without any transform keep their flag at None.
    check_collate_flags(preprocess, train=True, val=None, test=None, predict=None)

    # Both conflicting keys within a single stage are refused at construction.
    with pytest.raises(MisconfigurationException, match=conflict_pattern):
        DefaultPreprocess(
            train_transform={
                "per_batch_transform": torch.nn.Linear(1, 1),
                "per_sample_transform_on_device": torch.nn.Linear(1, 1),
            }
        )

    preprocess = DefaultPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)},
    )
    check_collate_flags(preprocess, train=True, val=None, test=None, predict=False)

    # A fresh DataPipeline is built for every stage here.
    stage_order = (
        RunningStage.TRAINING,
        RunningStage.VALIDATING,
        RunningStage.TESTING,
        RunningStage.PREDICTING,
    )
    workers = [
        DataPipeline(preprocess=preprocess).worker_preprocessor(stage)
        for stage in stage_order
    ]
    train_worker, val_worker, test_worker, predict_worker = workers

    for worker in (train_worker, val_worker, test_worker):
        assert worker.collate_fn.func == default_collate
    assert predict_worker.collate_fn.func == DataPipeline._identity

    class CustomPreprocess(DefaultPreprocess):
        """DefaultPreprocess subclass overriding both hooks by delegating to the parent."""

        def per_sample_transform_on_device(self, sample: Any) -> Any:
            return super().per_sample_transform_on_device(sample)

        def per_batch_transform(self, batch: Any) -> Any:
            return super().per_batch_transform(batch)

    preprocess = CustomPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)},
    )
    check_collate_flags(preprocess, train=True, val=None, test=None, predict=False)

    # This time a single pipeline is shared across all stages.
    data_pipeline = DataPipeline(preprocess=preprocess)

    train_worker = data_pipeline.worker_preprocessor(RunningStage.TRAINING)

    # Validation and testing are expected to be refused with the conflict message.
    for rejected_stage in (RunningStage.VALIDATING, RunningStage.TESTING):
        with pytest.raises(MisconfigurationException, match=conflict_pattern):
            data_pipeline.worker_preprocessor(rejected_stage)

    predict_worker = data_pipeline.worker_preprocessor(RunningStage.PREDICTING)

    assert train_worker.collate_fn.func == default_collate
    assert predict_worker.collate_fn.func == DataPipeline._identity