def test_saving_with_serializers(tmpdir):

    checkpoint_file = os.path.join(tmpdir, 'tmp.ckpt')

    class CustomModel(Task):

        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())

    serializer = Labels(["a", "b"])
    model = CustomModel()
    trainer = Trainer(fast_dev_run=True)
    data_pipeline = DataPipeline(DefaultPreprocess(), serializer=serializer)
    data_pipeline.initialize()
    model.data_pipeline = data_pipeline
    assert isinstance(model.preprocess, DefaultPreprocess)
    dummy_data = DataLoader(list(zip(torch.arange(10, dtype=torch.float), torch.arange(10, dtype=torch.float))))
    trainer.fit(model, train_dataloader=dummy_data)
    trainer.save_checkpoint(checkpoint_file)
    model = CustomModel.load_from_checkpoint(checkpoint_file)
    assert isinstance(model.preprocess._data_pipeline_state, DataPipelineState)
    assert model.preprocess._data_pipeline_state._state[ClassificationState] == ClassificationState(['a', 'b'])
def test_predict_numpy():
    img = np.ones((1, 3, 10, 20))
    model = SemanticSegmentation(2)
    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess())
    out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
    assert isinstance(out[0], torch.Tensor)
    assert out[0].shape == (196, 196)
def test_serialization_data_pipeline(tmpdir):
    model = CustomModel()

    checkpoint_file = os.path.join(tmpdir, 'tmp.ckpt')
    checkpoint = ModelCheckpoint(tmpdir, 'test.ckpt')
    trainer = Trainer(callbacks=[checkpoint], max_epochs=1)
    dummy_data = DataLoader(list(zip(torch.arange(10, dtype=torch.float), torch.arange(10, dtype=torch.float))))
    trainer.fit(model, dummy_data)

    assert model.data_pipeline is None
    trainer.save_checkpoint(checkpoint_file)

    loaded_model = CustomModel.load_from_checkpoint(checkpoint_file)
    assert loaded_model.data_pipeline is None

    model.data_pipeline = DataPipeline(CustomPreprocess())
    trainer.fit(model, dummy_data)

    assert model.data_pipeline
    assert isinstance(model.preprocess, CustomPreprocess)

    trainer.save_checkpoint(checkpoint_file)
    loaded_model = CustomModel.load_from_checkpoint(checkpoint_file)
    assert loaded_model.data_pipeline
    assert isinstance(loaded_model.preprocess, CustomPreprocess)

    for file in os.listdir(tmpdir):
        if file.endswith('.ckpt'):
            os.remove(os.path.join(tmpdir, file))
def test_preprocessing_data_pipeline_no_running_stage(with_dataset):
    pipe = DataPipeline(_AutoDatasetTestPreprocess(with_dataset))

    dataset = pipe._generate_auto_dataset(range(10), running_stage=None)

    with pytest.raises(RuntimeError, match='`__len__` for `load_sample`'):
        for idx in range(len(dataset)):
            dataset[idx]  # will be triggered when running stage is set

    if with_dataset:
        assert not hasattr(dataset, 'load_sample_was_called')
        assert not hasattr(dataset, 'load_data_was_called')
        assert pipe._preprocess_pipeline.load_sample_with_dataset_count == 0
        assert pipe._preprocess_pipeline.load_data_with_dataset_count == 0
    else:
        assert pipe._preprocess_pipeline.load_sample_count == 0
        assert pipe._preprocess_pipeline.load_data_count == 0

    dataset.running_stage = RunningStage.TRAINING

    if with_dataset:
        assert pipe._preprocess_pipeline.train_load_data_with_dataset_count == 1
        assert dataset.train_load_data_was_called
    else:
        assert pipe._preprocess_pipeline.train_load_data_count == 1
def test_autodataset_warning():
    with pytest.warns(
        UserWarning,
        match="``datapipeline`` is specified but load_sample and/or load_data are also specified"
    ):
        AutoDataset(range(10), load_data=lambda x: x, load_sample=lambda x: x, data_pipeline=DataPipeline())
@classmethod
def from_load_data_inputs(
    cls,
    train_load_data_input: Optional[Any] = None,
    val_load_data_input: Optional[Any] = None,
    test_load_data_input: Optional[Any] = None,
    predict_load_data_input: Optional[Any] = None,
    preprocess: Optional[Preprocess] = None,
    postprocess: Optional[Postprocess] = None,
    **kwargs,
) -> 'DataModule':
    """
    This function is a helper to generate a ``DataModule`` from a ``DataPipeline``.

    Args:
        cls: ``DataModule`` subclass
        train_load_data_input: Data to be received by the ``train_load_data`` function from this ``Preprocess``
        val_load_data_input: Data to be received by the ``val_load_data`` function from this ``Preprocess``
        test_load_data_input: Data to be received by the ``test_load_data`` function from this ``Preprocess``
        predict_load_data_input: Data to be received by the ``predict_load_data`` function from this ``Preprocess``
        kwargs: Any extra arguments to instantiate the provided ``DataModule``
    """
    # trick to get data_pipeline from empty DataModule
    if preprocess or postprocess:
        data_pipeline = DataPipeline(
            preprocess or cls(**kwargs).preprocess,
            postprocess or cls(**kwargs).postprocess,
        )
    else:
        data_pipeline = cls(**kwargs).data_pipeline

    train_dataset = cls._generate_dataset_if_possible(
        train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline
    )
    val_dataset = cls._generate_dataset_if_possible(
        val_load_data_input, running_stage=RunningStage.VALIDATING, data_pipeline=data_pipeline
    )
    test_dataset = cls._generate_dataset_if_possible(
        test_load_data_input, running_stage=RunningStage.TESTING, data_pipeline=data_pipeline
    )
    predict_dataset = cls._generate_dataset_if_possible(
        predict_load_data_input, running_stage=RunningStage.PREDICTING, data_pipeline=data_pipeline
    )
    datamodule = cls(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        predict_dataset=predict_dataset,
        **kwargs,
    )
    datamodule._preprocess = data_pipeline._preprocess_pipeline
    datamodule._postprocess = data_pipeline._postprocess_pipeline
    return datamodule
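# Hedged usage sketch for ``from_load_data_inputs`` above. ``ExamplePreprocess``
# is a hypothetical ``Preprocess`` subclass defined only for illustration, and a
# bare ``DataModule`` is assumed constructible (the helper itself relies on
# ``cls(**kwargs)``). Each ``*_load_data_input`` is handed to the matching
# stage-specific ``*_load_data`` hook, falling back to ``load_data``.
def example_from_load_data_inputs():

    class ExamplePreprocess(Preprocess):

        def load_data(self, data: Any) -> Any:
            # Receives the raw ``*_load_data_input`` for the current running stage.
            return list(data)

    datamodule = DataModule.from_load_data_inputs(
        train_load_data_input=range(100),
        val_load_data_input=range(20),
        preprocess=ExamplePreprocess(),
    )
    assert datamodule._preprocess is not None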
def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor():

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess)

    data_pipeline.worker_preprocessor(RunningStage.TRAINING)
    with pytest.raises(MisconfigurationException, match="are mutual exclusive"):
        data_pipeline.worker_preprocessor(RunningStage.VALIDATING)
    with pytest.raises(MisconfigurationException, match="are mutual exclusive"):
        data_pipeline.worker_preprocessor(RunningStage.TESTING)
    data_pipeline.worker_preprocessor(RunningStage.PREDICTING)
def test_detach_preprocessing_from_model(tmpdir):

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess)
    model = CustomModel()
    model.data_pipeline = data_pipeline

    assert model.train_dataloader().collate_fn == default_collate
    assert model.transfer_batch_to_device.__self__ == model
    model.on_train_dataloader()
    assert isinstance(model.train_dataloader().collate_fn, _PreProcessor)
    assert isinstance(model.transfer_batch_to_device, _StageOrchestrator)
    model.on_fit_end()
    assert model.transfer_batch_to_device.__self__ == model
    assert model.train_dataloader().collate_fn == default_collate
def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir):

    class SubPreprocess(Preprocess):
        pass

    class SubPostprocess(Postprocess):
        pass

    data_pipeline = DataPipeline(
        SubPreprocess() if use_preprocess else None,
        SubPostprocess() if use_postprocess else None,
    )
    assert isinstance(data_pipeline._preprocess_pipeline, SubPreprocess if use_preprocess else Preprocess)
    assert isinstance(data_pipeline._postprocess_pipeline, SubPostprocess if use_postprocess else Postprocess)

    model = CustomModel(Postprocess())
    model.data_pipeline = data_pipeline
    assert isinstance(model._preprocess, SubPreprocess if use_preprocess else Preprocess)
    assert isinstance(model._postprocess, SubPostprocess if use_postprocess else Postprocess)
@property
def data_pipeline(self) -> Optional[DataPipeline]:
    if self._data_pipeline is not None:
        return self._data_pipeline

    elif self.preprocess is not None or self.postprocess is not None:
        # use direct attributes here to avoid recursion with properties that also check the data_pipeline property
        return DataPipeline(self.preprocess, self.postprocess)

    elif self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None:
        return self.datamodule.data_pipeline

    elif self.trainer is not None and hasattr(self.trainer, 'datamodule') and getattr(
        self.trainer.datamodule, 'data_pipeline', None
    ) is not None:
        return self.trainer.datamodule.data_pipeline

    return self._data_pipeline
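# A minimal sketch of the resolution order encoded by the property above, using
# the module-level ``CustomModel`` and ``DefaultPreprocess`` from the tests in
# this section (assumption: a bare task starts with no pipeline attached, as
# ``test_serialization_data_pipeline`` also asserts).
def example_data_pipeline_resolution():
    model = CustomModel()
    # No explicit pipeline, no preprocess/postprocess, no datamodule or trainer:
    # the property falls through every branch and returns ``self._data_pipeline``.
    assert model.data_pipeline is None

    model.data_pipeline = DataPipeline(DefaultPreprocess())
    # An explicitly assigned pipeline is returned by the first branch and takes
    # priority over every other source.
    assert model.data_pipeline is not None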
def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir):

    class CustomModel(Task):

        def __init__(self, postprocess: Optional[Postprocess] = None):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
            self._postprocess = postprocess

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    class SubPreprocess(DefaultPreprocess):
        pass

    class SubPostprocess(Postprocess):
        pass

    data_pipeline = DataPipeline(
        preprocess=SubPreprocess() if use_preprocess else None,
        postprocess=SubPostprocess() if use_postprocess else None,
    )
    assert isinstance(data_pipeline._preprocess_pipeline, SubPreprocess if use_preprocess else DefaultPreprocess)
    assert isinstance(data_pipeline._postprocess_pipeline, SubPostprocess if use_postprocess else Postprocess)

    model = CustomModel(postprocess=Postprocess())
    model.data_pipeline = data_pipeline
    # TODO: the line below should have the same effect, but it doesn't
    # data_pipeline._attach_to_model(model)

    if use_preprocess:
        assert isinstance(model._preprocess, SubPreprocess)
    else:
        assert model._preprocess is None or isinstance(model._preprocess, Preprocess)

    if use_postprocess:
        assert isinstance(model._postprocess, SubPostprocess)
    else:
        assert model._postprocess is None or isinstance(model._postprocess, Postprocess)
def test_preprocessing_data_pipeline_with_running_stage(with_dataset):
    pipe = DataPipeline(_AutoDatasetTestPreprocess(with_dataset))

    running_stage = RunningStage.TRAINING
    dataset = pipe._generate_auto_dataset(range(10), running_stage=running_stage)

    assert len(dataset) == 10

    for idx in range(len(dataset)):
        dataset[idx]

    if with_dataset:
        assert dataset.train_load_sample_was_called
        assert dataset.train_load_data_was_called
        assert pipe._preprocess_pipeline.train_load_sample_with_dataset_count == len(dataset)
        assert pipe._preprocess_pipeline.train_load_data_with_dataset_count == 1
    else:
        assert pipe._preprocess_pipeline.train_load_sample_count == len(dataset)
        assert pipe._preprocess_pipeline.train_load_data_count == 1
def test_detach_preprocessing_from_model(tmpdir):

    class CustomModel(Task):

        def __init__(self, postprocess: Optional[Postprocess] = None):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
            self._postprocess = postprocess

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess=preprocess)
    model = CustomModel()
    model.data_pipeline = data_pipeline

    assert model.train_dataloader().collate_fn == default_collate
    assert model.transfer_batch_to_device.__self__ == model
    model.on_train_dataloader()
    assert isinstance(model.train_dataloader().collate_fn, _PreProcessor)
    assert isinstance(model.transfer_batch_to_device, _StageOrchestrator)
    model.on_fit_end()
    assert model.transfer_batch_to_device.__self__ == model
    assert model.train_dataloader().collate_fn == default_collate
@preprocess.setter
def preprocess(self, preprocess: Preprocess) -> None:
    self._preprocess = preprocess
    self.data_pipeline = DataPipeline(preprocess, self.postprocess)
@property
def data_pipeline(self) -> DataPipeline:
    return DataPipeline(self.preprocess, self.postprocess)
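# Hedged sketch of the setter/property pair above: assigning ``preprocess``
# stores it and rebuilds ``data_pipeline``, while reading ``data_pipeline``
# composes a fresh ``DataPipeline`` from the current preprocess/postprocess pair
# on every access. The two definitions together imply a ``data_pipeline`` setter
# exists alongside the property shown here; ``dm`` stands for the hosting
# ``DataModule``:
#
#     dm.preprocess = DefaultPreprocess()   # triggers the setter above
#     pipeline = dm.data_pipeline           # fresh DataPipeline on each access
#     assert pipeline._preprocess_pipeline is dm.preprocess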
def test_attaching_datapipeline_to_model(tmpdir):

    class SubPreprocess(DefaultPreprocess):
        pass

    preprocess = SubPreprocess()
    data_pipeline = DataPipeline(preprocess=preprocess)

    class CustomModel(Task):

        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
            self._postprocess = Postprocess()

        def training_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def validation_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def test_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def val_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def test_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def predict_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    class TestModel(CustomModel):

        stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]
        on_train_start_called = False
        on_val_start_called = False
        on_test_start_called = False
        on_predict_start_called = False

        def on_fit_start(self):
            assert self.predict_step.__self__ == self
            self._saved_predict_step = self.predict_step

        def _compare_pre_processor(self, p1, p2):
            p1_seq = p1.per_sample_transform
            p2_seq = p2.per_sample_transform
            assert p1_seq.pre_tensor_transform.func == p2_seq.pre_tensor_transform.func
            assert p1_seq.to_tensor_transform.func == p2_seq.to_tensor_transform.func
            assert p1_seq.post_tensor_transform.func == p2_seq.post_tensor_transform.func
            assert p1.collate_fn.func == p2.collate_fn.func
            assert p1.per_batch_transform.func == p2.per_batch_transform.func

        def _assert_stage_orchestrator_state(
            self, stage_mapping: Dict, current_running_stage: RunningStage, cls=_PreProcessor
        ):
            assert isinstance(stage_mapping[current_running_stage], cls)
            assert stage_mapping[current_running_stage]

        def on_train_dataloader(self) -> None:
            current_running_stage = RunningStage.TRAINING
            self.on_train_dataloader_called = True
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_train_dataloader()
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_val_dataloader(self) -> None:
            current_running_stage = RunningStage.VALIDATING
            self.on_val_dataloader_called = True
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_val_dataloader()
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_test_dataloader(self) -> None:
            current_running_stage = RunningStage.TESTING
            self.on_test_dataloader_called = True
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_test_dataloader()
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_predict_dataloader(self) -> None:
            current_running_stage = RunningStage.PREDICTING
            self.on_predict_dataloader_called = True
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step
            super().on_predict_dataloader()
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            assert isinstance(self.predict_step, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)
            self._assert_stage_orchestrator_state(
                self.predict_step._stage_mapping, current_running_stage, cls=_PostProcessor
            )

        def on_fit_end(self) -> None:
            super().on_fit_end()
            assert self.train_dataloader().collate_fn == default_collate
            assert self.val_dataloader().collate_fn == default_collate
            assert self.test_dataloader().collate_fn == default_collate
            assert self.predict_dataloader().collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step

    model = TestModel()
    model.data_pipeline = data_pipeline
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model)

    assert model.on_train_dataloader_called
    assert model.on_val_dataloader_called
    assert model.on_test_dataloader_called
    assert model.on_predict_dataloader_called
def test_preprocess_transforms(tmpdir):
    """
    This test makes sure that when transforms are provided to a preprocess as dictionaries,
    the checking is done properly, and ``collate_in_worker_from_transform`` is properly extracted.
    """

    with pytest.raises(MisconfigurationException, match="Transform should be a dict."):
        DefaultPreprocess(train_transform="choco")

    with pytest.raises(MisconfigurationException, match="train_transform contains {'choco'}. Only"):
        DefaultPreprocess(train_transform={"choco": None})

    preprocess = DefaultPreprocess(train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)})
    # keep is None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is None

    with pytest.raises(
        MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        preprocess = DefaultPreprocess(
            train_transform={
                "per_batch_transform": torch.nn.Linear(1, 1),
                "per_sample_transform_on_device": torch.nn.Linear(1, 1)
            }
        )

    preprocess = DefaultPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)}
    )
    # keep is None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    train_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TRAINING)
    val_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.VALIDATING)
    test_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TESTING)
    predict_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == default_collate
    assert val_preprocessor.collate_fn.func == default_collate
    assert test_preprocessor.collate_fn.func == default_collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity

    class CustomPreprocess(DefaultPreprocess):

        def per_sample_transform_on_device(self, sample: Any) -> Any:
            return super().per_sample_transform_on_device(sample)

        def per_batch_transform(self, batch: Any) -> Any:
            return super().per_batch_transform(batch)

    preprocess = CustomPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)}
    )
    # keep is None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    data_pipeline = DataPipeline(preprocess=preprocess)

    train_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TRAINING)
    with pytest.raises(
        MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        val_preprocessor = data_pipeline.worker_preprocessor(RunningStage.VALIDATING)
    with pytest.raises(
        MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        test_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TESTING)
    predict_preprocessor = data_pipeline.worker_preprocessor(RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == default_collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity
def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir):

    class CustomPreprocess(Preprocess):

        def load_data(self, *_, **__):
            pass

        def test_load_data(self, *_, **__):
            pass

        def predict_load_data(self, *_, **__):
            pass

        def predict_load_sample(self, *_, **__):
            pass

        def val_load_sample(self, *_, **__):
            pass

        def val_pre_tensor_transform(self, *_, **__):
            pass

        def predict_to_tensor_transform(self, *_, **__):
            pass

        def train_post_tensor_transform(self, *_, **__):
            pass

        def test_collate(self, *_, **__):
            pass

        def val_per_sample_transform_on_device(self, *_, **__):
            pass

        def train_per_batch_transform_on_device(self, *_, **__):
            pass

        def test_per_batch_transform_on_device(self, *_, **__):
            pass

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess)

    train_func_names = {
        k: data_pipeline._resolve_function_hierarchy(
            k, data_pipeline._preprocess_pipeline, RunningStage.TRAINING, Preprocess
        )
        for k in data_pipeline.PREPROCESS_FUNCS
    }
    val_func_names = {
        k: data_pipeline._resolve_function_hierarchy(
            k, data_pipeline._preprocess_pipeline, RunningStage.VALIDATING, Preprocess
        )
        for k in data_pipeline.PREPROCESS_FUNCS
    }
    test_func_names = {
        k: data_pipeline._resolve_function_hierarchy(
            k, data_pipeline._preprocess_pipeline, RunningStage.TESTING, Preprocess
        )
        for k in data_pipeline.PREPROCESS_FUNCS
    }
    predict_func_names = {
        k: data_pipeline._resolve_function_hierarchy(
            k, data_pipeline._preprocess_pipeline, RunningStage.PREDICTING, Preprocess
        )
        for k in data_pipeline.PREPROCESS_FUNCS
    }

    # load_data
    assert train_func_names["load_data"] == "load_data"
    assert val_func_names["load_data"] == "load_data"
    assert test_func_names["load_data"] == "test_load_data"
    assert predict_func_names["load_data"] == "predict_load_data"

    # load_sample
    assert train_func_names["load_sample"] == "load_sample"
    assert val_func_names["load_sample"] == "val_load_sample"
    assert test_func_names["load_sample"] == "load_sample"
    assert predict_func_names["load_sample"] == "predict_load_sample"

    # pre_tensor_transform
    assert train_func_names["pre_tensor_transform"] == "pre_tensor_transform"
    assert val_func_names["pre_tensor_transform"] == "val_pre_tensor_transform"
    assert test_func_names["pre_tensor_transform"] == "pre_tensor_transform"
    assert predict_func_names["pre_tensor_transform"] == "pre_tensor_transform"

    # to_tensor_transform
    assert train_func_names["to_tensor_transform"] == "to_tensor_transform"
    assert val_func_names["to_tensor_transform"] == "to_tensor_transform"
    assert test_func_names["to_tensor_transform"] == "to_tensor_transform"
    assert predict_func_names["to_tensor_transform"] == "predict_to_tensor_transform"

    # post_tensor_transform
    assert train_func_names["post_tensor_transform"] == "train_post_tensor_transform"
    assert val_func_names["post_tensor_transform"] == "post_tensor_transform"
    assert test_func_names["post_tensor_transform"] == "post_tensor_transform"
    assert predict_func_names["post_tensor_transform"] == "post_tensor_transform"

    # collate
    assert train_func_names["collate"] == "collate"
    assert val_func_names["collate"] == "collate"
    assert test_func_names["collate"] == "test_collate"
    assert predict_func_names["collate"] == "collate"

    # per_sample_transform_on_device
    assert train_func_names["per_sample_transform_on_device"] == "per_sample_transform_on_device"
    assert val_func_names["per_sample_transform_on_device"] == "val_per_sample_transform_on_device"
    assert test_func_names["per_sample_transform_on_device"] == "per_sample_transform_on_device"
    assert predict_func_names["per_sample_transform_on_device"] == "per_sample_transform_on_device"

    # per_batch_transform_on_device
    assert train_func_names["per_batch_transform_on_device"] == "train_per_batch_transform_on_device"
    assert val_func_names["per_batch_transform_on_device"] == "per_batch_transform_on_device"
    assert test_func_names["per_batch_transform_on_device"] == "test_per_batch_transform_on_device"
    assert predict_func_names["per_batch_transform_on_device"] == "per_batch_transform_on_device"

    train_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TRAINING)
    val_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.VALIDATING)
    test_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TESTING)
    predict_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.PREDICTING)

    _seq = train_worker_preprocessor.per_sample_transform
    assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform
    assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform
    assert _seq.post_tensor_transform.func == preprocess.train_post_tensor_transform
    assert train_worker_preprocessor.collate_fn.func == default_collate
    assert train_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform

    _seq = val_worker_preprocessor.per_sample_transform
    assert _seq.pre_tensor_transform.func == preprocess.val_pre_tensor_transform
    assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform
    assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform
    assert val_worker_preprocessor.collate_fn.func == data_pipeline._identity
    assert val_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform

    _seq = test_worker_preprocessor.per_sample_transform
    assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform
    assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform
    assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform
    assert test_worker_preprocessor.collate_fn.func == preprocess.test_collate
    assert test_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform

    _seq = predict_worker_preprocessor.per_sample_transform
    assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform
    assert _seq.to_tensor_transform.func == preprocess.predict_to_tensor_transform
    assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform
    assert predict_worker_preprocessor.collate_fn.func == default_collate
    assert predict_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform
def test_attaching_datapipeline_to_model(tmpdir):

    preprocess = TestPreprocess()
    data_pipeline = DataPipeline(preprocess)

    class TestModel(CustomModel):

        stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]
        on_train_start_called = False
        on_val_start_called = False
        on_test_start_called = False
        on_predict_start_called = False

        def on_fit_start(self):
            assert self.predict_step.__self__ == self
            self._saved_predict_step = self.predict_step

        def _compare_pre_processor(self, p1, p2):
            p1_seq = p1.per_sample_transform
            p2_seq = p2.per_sample_transform
            assert p1_seq.pre_tensor_transform.func == p2_seq.pre_tensor_transform.func
            assert p1_seq.to_tensor_transform.func == p2_seq.to_tensor_transform.func
            assert p1_seq.post_tensor_transform.func == p2_seq.post_tensor_transform.func
            assert p1.collate_fn.func == p2.collate_fn.func
            assert p1.per_batch_transform.func == p2.per_batch_transform.func

        def _assert_stage_orchestrator_state(
            self, stage_mapping: Dict, current_running_stage: RunningStage, cls=_PreProcessor
        ):
            assert isinstance(stage_mapping[current_running_stage], cls)
            assert stage_mapping[current_running_stage]

        def on_train_dataloader(self) -> None:
            current_running_stage = RunningStage.TRAINING
            self.on_train_dataloader_called = True
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_train_dataloader()
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_val_dataloader(self) -> None:
            current_running_stage = RunningStage.VALIDATING
            self.on_val_dataloader_called = True
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_val_dataloader()
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_test_dataloader(self) -> None:
            current_running_stage = RunningStage.TESTING
            self.on_test_dataloader_called = True
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_test_dataloader()
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_predict_dataloader(self) -> None:
            current_running_stage = RunningStage.PREDICTING
            self.on_predict_dataloader_called = True
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step
            super().on_predict_dataloader()
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            assert isinstance(self.predict_step, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)
            self._assert_stage_orchestrator_state(
                self.predict_step._stage_mapping, current_running_stage, cls=_PostProcessor
            )

        def on_fit_end(self) -> None:
            super().on_fit_end()
            assert self.train_dataloader().collate_fn == default_collate
            assert self.val_dataloader().collate_fn == default_collate
            assert self.test_dataloader().collate_fn == default_collate
            assert self.predict_dataloader().collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step

    datamodule = CustomDataModule()
    datamodule._data_pipeline = data_pipeline
    model = TestModel()
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model)
    trainer.predict(model)

    assert model.on_train_dataloader_called
    assert model.on_val_dataloader_called
    assert model.on_test_dataloader_called
    assert model.on_predict_dataloader_called
def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
    """Build a :class:`.DataPipeline` incorporating available
    :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
    objects. These will be overridden in the following resolution order (lowest priority first):

    - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
    - :class:`.Task` defaults given to ``Task.__init__``.
    - :class:`.Task` manual overrides by setting :py:attr:`~data_pipeline`.
    - :class:`.DataPipeline` passed to this method.

    Args:
        data_pipeline: Optional highest priority source of
            :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.

    Returns:
        The fully resolved :class:`.DataPipeline`.
    """
    preprocess, postprocess, serializer = None, None, None

    # Datamodule
    if self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None:
        preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None)
        postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None)
        serializer = getattr(self.datamodule.data_pipeline, '_serializer', None)

    elif self.trainer is not None and hasattr(self.trainer, 'datamodule') and getattr(
        self.trainer.datamodule, 'data_pipeline', None
    ) is not None:
        preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None)
        postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None)
        serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None)

    # Defaults / task attributes
    preprocess, postprocess, serializer = Task._resolve(
        preprocess,
        postprocess,
        serializer,
        self._preprocess,
        self._postprocess,
        self.serializer,
    )

    # Datapipeline
    if data_pipeline is not None:
        preprocess, postprocess, serializer = Task._resolve(
            preprocess,
            postprocess,
            serializer,
            getattr(data_pipeline, '_preprocess_pipeline', None),
            getattr(data_pipeline, '_postprocess_pipeline', None),
            getattr(data_pipeline, '_serializer', None),
        )

    data_pipeline = DataPipeline(preprocess, postprocess, serializer)
    data_pipeline.initialize()
    return data_pipeline
@classmethod
def from_load_data_inputs(
    cls,
    train_load_data_input: Optional[Any] = None,
    val_load_data_input: Optional[Any] = None,
    test_load_data_input: Optional[Any] = None,
    predict_load_data_input: Optional[Any] = None,
    data_fetcher: Optional[BaseDataFetcher] = None,
    preprocess: Optional[Preprocess] = None,
    postprocess: Optional[Postprocess] = None,
    use_iterable_auto_dataset: bool = False,
    seed: int = 42,
    val_split: Optional[float] = None,
    **kwargs,
) -> 'DataModule':
    """
    This function is a helper to generate a ``DataModule`` from a ``DataPipeline``.

    Args:
        cls: ``DataModule`` subclass
        train_load_data_input: Data to be received by the ``train_load_data`` function
            from this :class:`~flash.data.process.Preprocess`
        val_load_data_input: Data to be received by the ``val_load_data`` function
            from this :class:`~flash.data.process.Preprocess`
        test_load_data_input: Data to be received by the ``test_load_data`` function
            from this :class:`~flash.data.process.Preprocess`
        predict_load_data_input: Data to be received by the ``predict_load_data`` function
            from this :class:`~flash.data.process.Preprocess`
        kwargs: Any extra arguments to instantiate the provided ``DataModule``
    """
    # trick to get data_pipeline from empty DataModule
    if preprocess or postprocess:
        data_pipeline = DataPipeline(
            preprocess or cls(**kwargs).preprocess,
            postprocess or cls(**kwargs).postprocess,
        )
    else:
        data_pipeline = cls(**kwargs).data_pipeline

    data_fetcher: BaseDataFetcher = data_fetcher or cls.configure_data_fetcher()
    data_fetcher.attach_to_preprocess(data_pipeline._preprocess_pipeline)

    train_dataset = cls._generate_dataset_if_possible(
        train_load_data_input,
        running_stage=RunningStage.TRAINING,
        data_pipeline=data_pipeline,
        use_iterable_auto_dataset=use_iterable_auto_dataset,
    )
    val_dataset = cls._generate_dataset_if_possible(
        val_load_data_input,
        running_stage=RunningStage.VALIDATING,
        data_pipeline=data_pipeline,
        use_iterable_auto_dataset=use_iterable_auto_dataset,
    )
    test_dataset = cls._generate_dataset_if_possible(
        test_load_data_input,
        running_stage=RunningStage.TESTING,
        data_pipeline=data_pipeline,
        use_iterable_auto_dataset=use_iterable_auto_dataset,
    )
    predict_dataset = cls._generate_dataset_if_possible(
        predict_load_data_input,
        running_stage=RunningStage.PREDICTING,
        data_pipeline=data_pipeline,
        use_iterable_auto_dataset=use_iterable_auto_dataset,
    )

    if train_dataset is not None and (val_split is not None and val_dataset is None):
        train_dataset, val_dataset = cls._split_train_val(train_dataset, val_split)

    datamodule = cls(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        predict_dataset=predict_dataset,
        **kwargs,
    )
    datamodule._preprocess = data_pipeline._preprocess_pipeline
    datamodule._postprocess = data_pipeline._postprocess_pipeline
    data_fetcher.attach_to_datamodule(datamodule)
    return datamodule
def build_data_pipeline(
    self,
    data_source: Optional[str] = None,
    data_pipeline: Optional[DataPipeline] = None,
) -> Optional[DataPipeline]:
    """Build a :class:`.DataPipeline` incorporating available
    :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
    objects. These will be overridden in the following resolution order (lowest priority first):

    - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
    - :class:`.Task` defaults given to ``Task.__init__``.
    - :class:`.Task` manual overrides by setting :py:attr:`~data_pipeline`.
    - :class:`.DataPipeline` passed to this method.

    Args:
        data_source: Optional name of the data source to resolve on the final preprocess.
        data_pipeline: Optional highest priority source of
            :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.

    Returns:
        The fully resolved :class:`.DataPipeline`.
    """
    old_data_source, preprocess, postprocess, serializer = None, None, None, None

    # Datamodule
    if self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None:
        old_data_source = getattr(self.datamodule.data_pipeline, 'data_source', None)
        preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None)
        postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None)
        serializer = getattr(self.datamodule.data_pipeline, '_serializer', None)

    elif self.trainer is not None and hasattr(self.trainer, 'datamodule') and getattr(
        self.trainer.datamodule, 'data_pipeline', None
    ) is not None:
        old_data_source = getattr(self.trainer.datamodule.data_pipeline, 'data_source', None)
        preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None)
        postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None)
        serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None)
    else:
        # TODO: we should log with low severity level that we use defaults to create
        # `preprocess`, `postprocess` and `serializer`.
        pass

    # Defaults / task attributes
    preprocess, postprocess, serializer = Task._resolve(
        preprocess,
        postprocess,
        serializer,
        self._preprocess,
        self._postprocess,
        self.serializer,
    )

    # Datapipeline
    if data_pipeline is not None:
        preprocess, postprocess, serializer = Task._resolve(
            preprocess,
            postprocess,
            serializer,
            getattr(data_pipeline, '_preprocess_pipeline', None),
            getattr(data_pipeline, '_postprocess_pipeline', None),
            getattr(data_pipeline, '_serializer', None),
        )

    data_source = data_source or old_data_source
    if isinstance(data_source, str):
        if preprocess is None:
            data_source = DataSource()  # TODO: warn the user that we are not using the specified data source
        else:
            data_source = preprocess.data_source_of_name(data_source)

    data_pipeline = DataPipeline(data_source, preprocess, postprocess, serializer)
    self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
    return data_pipeline
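# Hedged usage sketch for ``build_data_pipeline`` above: an explicit
# ``DataPipeline`` argument sits at the top of the resolution order documented
# in the docstring, and the ``"numpy"`` data source name mirrors
# ``test_predict_numpy`` earlier in this section, which also supplies the task
# and preprocess classes used here.
def example_build_data_pipeline():
    model = SemanticSegmentation(2)
    explicit = DataPipeline(preprocess=SemanticSegmentationPreprocess())
    # ``data_source`` is looked up by name on the resolved preprocess via
    # ``data_source_of_name``; ``explicit`` overrides any other pipeline source.
    data_pipeline = model.build_data_pipeline(data_source="numpy", data_pipeline=explicit)
    # The returned pipeline is fully resolved and already initialized.
    assert data_pipeline._preprocess_pipeline is not None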
@postprocess.setter
def postprocess(self, postprocess: Postprocess) -> None:
    self.data_pipeline = DataPipeline(self.preprocess, postprocess)
    self._postprocess = postprocess