def test_preprocessing_data_pipeline_no_running_stage(with_dataset):
    pipe = DataPipeline(_AutoDatasetTestPreprocess(with_dataset))
    dataset = pipe._generate_auto_dataset(range(10), running_stage=None)

    with pytest.raises(RuntimeError, match='`__len__` for `load_sample`'):
        for idx in range(len(dataset)):
            dataset[idx]

    # load_data will only be triggered once the running stage is set
    if with_dataset:
        assert not hasattr(dataset, 'load_sample_was_called')
        assert not hasattr(dataset, 'load_data_was_called')
        assert pipe._preprocess_pipeline.load_sample_with_dataset_count == 0
        assert pipe._preprocess_pipeline.load_data_with_dataset_count == 0
    else:
        assert pipe._preprocess_pipeline.load_sample_count == 0
        assert pipe._preprocess_pipeline.load_data_count == 0

    dataset.running_stage = RunningStage.TRAINING

    if with_dataset:
        assert pipe._preprocess_pipeline.train_load_data_with_dataset_count == 1
        assert dataset.train_load_data_was_called
    else:
        assert pipe._preprocess_pipeline.train_load_data_count == 1
def autogenerate_dataset(
    cls,
    data: Any,
    running_stage: RunningStage,
    whole_data_load_fn: Optional[Callable] = None,
    per_sample_load_fn: Optional[Callable] = None,
    data_pipeline: Optional[DataPipeline] = None,
) -> AutoDataset:
    """
    This function is used to generate an ``AutoDataset`` from a ``DataPipeline`` if provided,
    or directly from the given ``whole_data_load_fn`` and ``per_sample_load_fn`` functions.
    """
    if whole_data_load_fn is None:
        whole_data_load_fn = getattr(
            cls.preprocess_cls,
            DataPipeline._resolve_function_hierarchy('load_data', cls.preprocess_cls, running_stage, Preprocess)
        )

    if per_sample_load_fn is None:
        per_sample_load_fn = getattr(
            cls.preprocess_cls,
            DataPipeline._resolve_function_hierarchy('load_sample', cls.preprocess_cls, running_stage, Preprocess)
        )
    return AutoDataset(data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage)
def test_saving_with_serializers(tmpdir):

    checkpoint_file = os.path.join(tmpdir, 'tmp.ckpt')

    class CustomModel(Task):

        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())

    serializer = Labels(["a", "b"])
    model = CustomModel()
    trainer = Trainer(fast_dev_run=True)
    data_pipeline = DataPipeline(DefaultPreprocess(), serializer=serializer)
    data_pipeline.initialize()
    model.data_pipeline = data_pipeline
    assert isinstance(model.preprocess, DefaultPreprocess)
    dummy_data = DataLoader(
        list(zip(
            torch.arange(10, dtype=torch.float),
            torch.arange(10, dtype=torch.float),
        ))
    )
    trainer.fit(model, train_dataloader=dummy_data)
    trainer.save_checkpoint(checkpoint_file)
    model = CustomModel.load_from_checkpoint(checkpoint_file)
    assert isinstance(model.preprocess._data_pipeline_state, DataPipelineState)
    assert model.preprocess._data_pipeline_state._state[ClassificationState] == ClassificationState(['a', 'b'])
def test_is_overriden_recursive(tmpdir):

    class TestPreprocess(DefaultPreprocess):

        def collate(self, *_):
            pass

        def val_collate(self, *_):
            pass

    preprocess = TestPreprocess()
    assert DataPipeline._is_overriden_recursive("collate", preprocess, Preprocess, prefix="val")
    assert DataPipeline._is_overriden_recursive("collate", preprocess, Preprocess, prefix="train")
    assert not DataPipeline._is_overriden_recursive(
        "per_batch_transform_on_device", preprocess, Preprocess, prefix="train"
    )
    assert not DataPipeline._is_overriden_recursive("per_batch_transform_on_device", preprocess, Preprocess)
    with pytest.raises(MisconfigurationException, match="This function doesn't belong to the parent class"):
        assert not DataPipeline._is_overriden_recursive("chocolate", preprocess, Preprocess)
def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor():

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess)

    data_pipeline.worker_preprocessor(RunningStage.TRAINING)
    with pytest.raises(MisconfigurationException, match="are mutual exclusive"):
        data_pipeline.worker_preprocessor(RunningStage.VALIDATING)
    with pytest.raises(MisconfigurationException, match="are mutual exclusive"):
        data_pipeline.worker_preprocessor(RunningStage.TESTING)
    data_pipeline.worker_preprocessor(RunningStage.PREDICTING)
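# The ``CustomPreprocess`` used above is defined elsewhere; the hypothetical sketch below shows,
# under that assumption, the kind of configuration that makes ``worker_preprocessor`` raise the
# "are mutual exclusive" error for a stage: overriding both ``per_batch_transform`` and
# ``per_sample_transform_on_device`` for the same running stage (the same conflict exercised in
# ``test_preprocess_transforms`` further down). ``_ConflictingPreprocess`` is illustrative only.
class _ConflictingPreprocess(DefaultPreprocess):

    def val_per_batch_transform(self, batch: Any) -> Any:
        return batch

    def val_per_sample_transform_on_device(self, sample: Any) -> Any:
        return sample


pipeline = DataPipeline(_ConflictingPreprocess())
pipeline.worker_preprocessor(RunningStage.TRAINING)  # fine: no conflict for the training stage
with pytest.raises(MisconfigurationException, match="are mutual exclusive"):
    pipeline.worker_preprocessor(RunningStage.VALIDATING)  # both hooks overridden for validation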
def generate_dataset(
    self,
    data: Optional[DATA_TYPE],
    running_stage: RunningStage,
) -> Optional[Union[AutoDataset, IterableAutoDataset]]:
    is_none = data is None
    if isinstance(data, Sequence):
        is_none = data[0] is None

    if not is_none:
        from flash.data.data_pipeline import DataPipeline

        mock_dataset = typing.cast(AutoDataset, MockDataset())
        with CurrentRunningStageFuncContext(running_stage, "load_data", self):
            load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(
                self,
                DataPipeline._resolve_function_hierarchy(
                    "load_data",
                    self,
                    running_stage,
                    DataSource,
                )
            )
            parameters = signature(load_data).parameters
            if len(parameters) > 1 and "dataset" in parameters:  # TODO: This was DATASET_KEY before
                data = load_data(data, mock_dataset)
            else:
                data = load_data(data)

        if has_len(data):
            dataset = AutoDataset(data, self, running_stage)
        else:
            dataset = IterableAutoDataset(data, self, running_stage)
        dataset.__dict__.update(mock_dataset.metadata)
        return dataset
def test_predict_numpy():
    img = np.ones((1, 3, 10, 20))
    model = SemanticSegmentation(2)
    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess())
    out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
    assert isinstance(out[0], torch.Tensor)
    assert out[0].shape == (196, 196)
def test_serialization_data_pipeline(tmpdir):
    model = CustomModel()

    checkpoint_file = os.path.join(tmpdir, 'tmp.ckpt')
    checkpoint = ModelCheckpoint(tmpdir, 'test.ckpt')
    trainer = Trainer(callbacks=[checkpoint], max_epochs=1)
    dummy_data = DataLoader(
        list(zip(
            torch.arange(10, dtype=torch.float),
            torch.arange(10, dtype=torch.float),
        ))
    )
    trainer.fit(model, dummy_data)

    assert model.data_pipeline is None
    trainer.save_checkpoint(checkpoint_file)

    loaded_model = CustomModel.load_from_checkpoint(checkpoint_file)
    assert loaded_model.data_pipeline is None

    model.data_pipeline = DataPipeline(CustomPreprocess())
    trainer.fit(model, dummy_data)

    assert model.data_pipeline
    assert isinstance(model.preprocess, CustomPreprocess)
    trainer.save_checkpoint(checkpoint_file)

    loaded_model = CustomModel.load_from_checkpoint(checkpoint_file)
    assert loaded_model.data_pipeline
    assert isinstance(loaded_model.preprocess, CustomPreprocess)
    for file in os.listdir(tmpdir):
        if file.endswith('.ckpt'):
            os.remove(os.path.join(tmpdir, file))
def test_autodataset_warning():
    with pytest.warns(
        UserWarning,
        match="``datapipeline`` is specified but load_sample and/or load_data are also specified"
    ):
        AutoDataset(range(10), load_data=lambda x: x, load_sample=lambda x: x, data_pipeline=DataPipeline())
def test_preprocessing_data_pipeline_with_running_stage(with_dataset):
    pipe = DataPipeline(_AutoDatasetTestPreprocess(with_dataset))
    running_stage = RunningStage.TRAINING
    dataset = pipe._generate_auto_dataset(range(10), running_stage=running_stage)

    assert len(dataset) == 10

    for idx in range(len(dataset)):
        dataset[idx]

    if with_dataset:
        assert dataset.train_load_sample_was_called
        assert dataset.train_load_data_was_called
        assert pipe._preprocess_pipeline.train_load_sample_with_dataset_count == len(dataset)
        assert pipe._preprocess_pipeline.train_load_data_with_dataset_count == 1
    else:
        assert pipe._preprocess_pipeline.train_load_sample_count == len(dataset)
        assert pipe._preprocess_pipeline.train_load_data_count == 1
def from_load_data_inputs(
    cls,
    train_load_data_input: Optional[Any] = None,
    val_load_data_input: Optional[Any] = None,
    test_load_data_input: Optional[Any] = None,
    predict_load_data_input: Optional[Any] = None,
    preprocess: Optional[Preprocess] = None,
    postprocess: Optional[Postprocess] = None,
    **kwargs,
) -> 'DataModule':
    """
    This function is a helper to generate a ``DataModule`` from a ``DataPipeline``.

    Args:
        cls: ``DataModule`` subclass
        train_load_data_input: Data to be received by the ``train_load_data`` function from this ``Preprocess``
        val_load_data_input: Data to be received by the ``val_load_data`` function from this ``Preprocess``
        test_load_data_input: Data to be received by the ``test_load_data`` function from this ``Preprocess``
        predict_load_data_input: Data to be received by the ``predict_load_data`` function from this ``Preprocess``
        kwargs: Any extra arguments to instantiate the provided ``DataModule``
    """
    # trick to get data_pipeline from empty DataModule
    if preprocess or postprocess:
        data_pipeline = DataPipeline(
            preprocess or cls(**kwargs).preprocess,
            postprocess or cls(**kwargs).postprocess,
        )
    else:
        data_pipeline = cls(**kwargs).data_pipeline

    train_dataset = cls._generate_dataset_if_possible(
        train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline
    )
    val_dataset = cls._generate_dataset_if_possible(
        val_load_data_input, running_stage=RunningStage.VALIDATING, data_pipeline=data_pipeline
    )
    test_dataset = cls._generate_dataset_if_possible(
        test_load_data_input, running_stage=RunningStage.TESTING, data_pipeline=data_pipeline
    )
    predict_dataset = cls._generate_dataset_if_possible(
        predict_load_data_input, running_stage=RunningStage.PREDICTING, data_pipeline=data_pipeline
    )
    datamodule = cls(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        predict_dataset=predict_dataset,
        **kwargs,
    )
    datamodule._preprocess = data_pipeline._preprocess_pipeline
    datamodule._postprocess = data_pipeline._postprocess_pipeline
    return datamodule
def autogenerate_dataset(
    cls,
    data: Any,
    running_stage: RunningStage,
    whole_data_load_fn: Optional[Callable] = None,
    per_sample_load_fn: Optional[Callable] = None,
    data_pipeline: Optional[DataPipeline] = None,
    use_iterable_auto_dataset: bool = False,
) -> BaseAutoDataset:
    """
    This function is used to generate a ``BaseAutoDataset`` from a ``DataPipeline`` if provided,
    or directly from the given ``whole_data_load_fn`` and ``per_sample_load_fn`` functions.
    """
    preprocess = getattr(data_pipeline, '_preprocess_pipeline', None)

    if whole_data_load_fn is None:
        whole_data_load_fn = getattr(
            preprocess,
            DataPipeline._resolve_function_hierarchy('load_data', preprocess, running_stage, Preprocess)
        )

    if per_sample_load_fn is None:
        per_sample_load_fn = getattr(
            preprocess,
            DataPipeline._resolve_function_hierarchy('load_sample', preprocess, running_stage, Preprocess)
        )
    if use_iterable_auto_dataset:
        return IterableAutoDataset(
            data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage
        )
    return BaseAutoDataset(data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage)
def test_detach_preprocessing_from_model(tmpdir):
    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess)
    model = CustomModel()
    model.data_pipeline = data_pipeline

    assert model.train_dataloader().collate_fn == default_collate
    assert model.transfer_batch_to_device.__self__ == model
    model.on_train_dataloader()
    assert isinstance(model.train_dataloader().collate_fn, _PreProcessor)
    assert isinstance(model.transfer_batch_to_device, _StageOrchestrator)
    model.on_fit_end()
    assert model.transfer_batch_to_device.__self__ == model
    assert model.train_dataloader().collate_fn == default_collate
def generate_dataset(
    self,
    data: Optional[DATA_TYPE],
    running_stage: RunningStage,
) -> Optional[Union[AutoDataset, IterableAutoDataset]]:
    """Generate a single dataset with the given input to
    :meth:`~flash.data.data_source.DataSource.load_data` for the given ``running_stage``.

    Args:
        data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use to create the dataset.
        running_stage: The running_stage for this dataset.

    Returns:
        The constructed :class:`~flash.data.auto_dataset.BaseAutoDataset`.
    """
    is_none = data is None
    if isinstance(data, Sequence):
        is_none = data[0] is None

    if not is_none:
        from flash.data.data_pipeline import DataPipeline

        mock_dataset = typing.cast(AutoDataset, MockDataset())
        with CurrentRunningStageFuncContext(running_stage, "load_data", self):
            load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(
                self,
                DataPipeline._resolve_function_hierarchy(
                    "load_data",
                    self,
                    running_stage,
                    DataSource,
                )
            )
            parameters = signature(load_data).parameters
            if len(parameters) > 1 and "dataset" in parameters:  # TODO: This was DATASET_KEY before
                data = load_data(data, mock_dataset)
            else:
                data = load_data(data)

        if has_len(data):
            dataset = AutoDataset(data, self, running_stage)
        else:
            dataset = IterableAutoDataset(data, self, running_stage)
        dataset.__dict__.update(mock_dataset.metadata)
        return dataset
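# Hypothetical sketch of the ``dataset`` argument handling above: when a ``load_data`` signature
# accepts a second ``dataset`` parameter, ``generate_dataset`` passes a ``MockDataset`` whose
# recorded attributes are later copied onto the real dataset via ``mock_dataset.metadata`` (as the
# attribute copy at the end of the method suggests). ``_ListDataSource`` and ``num_classes`` are
# illustrative names, not part of the snippets in this section.
class _ListDataSource(DataSource):

    def load_data(self, data: Any, dataset: Optional[Any] = None) -> Any:
        samples = list(data)
        if dataset is not None:
            # Attributes set here are captured by the MockDataset and re-applied to the AutoDataset.
            dataset.num_classes = len(set(samples))
        return samples

    def load_sample(self, sample: Any) -> Any:
        return sample


data_source = _ListDataSource()
train_dataset = data_source.generate_dataset([0, 1, 1, 0], RunningStage.TRAINING)
assert train_dataset.num_classes == 2  # propagated from load_data through mock_dataset.metadata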
def running_stage(self, running_stage: RunningStage) -> None:
    from flash.data.data_pipeline import DataPipeline  # noqa F811
    from flash.data.data_source import DataSource  # noqa F811

    # TODO: something better than this
    self._running_stage = running_stage

    self._load_sample_context = CurrentRunningStageFuncContext(self.running_stage, "load_sample", self.data_source)

    self.load_sample: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(
        self.data_source,
        DataPipeline._resolve_function_hierarchy(
            'load_sample',
            self.data_source,
            self.running_stage,
            DataSource,
        )
    )
def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir):

    class SubPreprocess(Preprocess):
        pass

    class SubPostprocess(Postprocess):
        pass

    data_pipeline = DataPipeline(
        SubPreprocess() if use_preprocess else None,
        SubPostprocess() if use_postprocess else None,
    )
    assert isinstance(data_pipeline._preprocess_pipeline, SubPreprocess if use_preprocess else Preprocess)
    assert isinstance(data_pipeline._postprocess_pipeline, SubPostprocess if use_postprocess else Postprocess)

    model = CustomModel(Postprocess())
    model.data_pipeline = data_pipeline
    assert isinstance(model._preprocess, SubPreprocess if use_preprocess else Preprocess)
    assert isinstance(model._postprocess, SubPostprocess if use_postprocess else Postprocess)
def data_pipeline(self) -> Optional[DataPipeline]:
    if self._data_pipeline is not None:
        return self._data_pipeline

    elif self.preprocess is not None or self.postprocess is not None:
        # use direct attributes here to avoid recursion with properties that also check the data_pipeline property
        return DataPipeline(self.preprocess, self.postprocess)

    elif self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None:
        return self.datamodule.data_pipeline

    elif self.trainer is not None and hasattr(self.trainer, 'datamodule') and getattr(
        self.trainer.datamodule, 'data_pipeline', None
    ) is not None:
        return self.trainer.datamodule.data_pipeline
    return self._data_pipeline
def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir):

    class CustomModel(Task):

        def __init__(self, postprocess: Optional[Postprocess] = None):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
            self._postprocess = postprocess

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    class SubPreprocess(DefaultPreprocess):
        pass

    class SubPostprocess(Postprocess):
        pass

    data_pipeline = DataPipeline(
        preprocess=SubPreprocess() if use_preprocess else None,
        postprocess=SubPostprocess() if use_postprocess else None,
    )
    assert isinstance(data_pipeline._preprocess_pipeline, SubPreprocess if use_preprocess else DefaultPreprocess)
    assert isinstance(data_pipeline._postprocess_pipeline, SubPostprocess if use_postprocess else Postprocess)

    model = CustomModel(postprocess=Postprocess())
    model.data_pipeline = data_pipeline
    # TODO: the line below should have the same effect, but it does not
    # data_pipeline._attach_to_model(model)

    if use_preprocess:
        assert isinstance(model._preprocess, SubPreprocess)
    else:
        assert model._preprocess is None or isinstance(model._preprocess, Preprocess)

    if use_postprocess:
        assert isinstance(model._postprocess, SubPostprocess)
    else:
        assert model._postprocess is None or isinstance(model._postprocess, Postprocess)
def test_detach_preprocessing_from_model(tmpdir):

    class CustomModel(Task):

        def __init__(self, postprocess: Optional[Postprocess] = None):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
            self._postprocess = postprocess

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess=preprocess)
    model = CustomModel()
    model.data_pipeline = data_pipeline

    assert model.train_dataloader().collate_fn == default_collate
    assert model.transfer_batch_to_device.__self__ == model
    model.on_train_dataloader()
    assert isinstance(model.train_dataloader().collate_fn, _PreProcessor)
    assert isinstance(model.transfer_batch_to_device, _StageOrchestrator)
    model.on_fit_end()
    assert model.transfer_batch_to_device.__self__ == model
    assert model.train_dataloader().collate_fn == default_collate
def test_attaching_datapipeline_to_model(tmpdir): class SubPreprocess(DefaultPreprocess): pass preprocess = SubPreprocess() data_pipeline = DataPipeline(preprocess=preprocess) class CustomModel(Task): def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) self._postprocess = Postprocess() def training_step(self, batch: Any, batch_idx: int) -> Any: pass def validation_step(self, batch: Any, batch_idx: int) -> Any: pass def test_step(self, batch: Any, batch_idx: int) -> Any: pass def train_dataloader(self) -> Any: return DataLoader(DummyDataset()) def val_dataloader(self) -> Any: return DataLoader(DummyDataset()) def test_dataloader(self) -> Any: return DataLoader(DummyDataset()) def predict_dataloader(self) -> Any: return DataLoader(DummyDataset()) class TestModel(CustomModel): stages = [ RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING ] on_train_start_called = False on_val_start_called = False on_test_start_called = False on_predict_start_called = False def on_fit_start(self): assert self.predict_step.__self__ == self self._saved_predict_step = self.predict_step def _compare_pre_processor(self, p1, p2): p1_seq = p1.per_sample_transform p2_seq = p2.per_sample_transform assert p1_seq.pre_tensor_transform.func == p2_seq.pre_tensor_transform.func assert p1_seq.to_tensor_transform.func == p2_seq.to_tensor_transform.func assert p1_seq.post_tensor_transform.func == p2_seq.post_tensor_transform.func assert p1.collate_fn.func == p2.collate_fn.func assert p1.per_batch_transform.func == p2.per_batch_transform.func def _assert_stage_orchestrator_state( self, stage_mapping: Dict, current_running_stage: RunningStage, cls=_PreProcessor): assert isinstance(stage_mapping[current_running_stage], cls) assert stage_mapping[current_running_stage] def on_train_dataloader(self) -> None: current_running_stage = RunningStage.TRAINING self.on_train_dataloader_called = True collate_fn = self.train_dataloader().collate_fn # noqa F811 assert collate_fn == default_collate assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator) super().on_train_dataloader() collate_fn = self.train_dataloader().collate_fn # noqa F811 assert collate_fn.stage == current_running_stage self._compare_pre_processor( collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage)) assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) self._assert_stage_orchestrator_state( self.transfer_batch_to_device._stage_mapping, current_running_stage) def on_val_dataloader(self) -> None: current_running_stage = RunningStage.VALIDATING self.on_val_dataloader_called = True collate_fn = self.val_dataloader().collate_fn # noqa F811 assert collate_fn == default_collate assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) super().on_val_dataloader() collate_fn = self.val_dataloader().collate_fn # noqa F811 assert collate_fn.stage == current_running_stage self._compare_pre_processor( collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage)) assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) self._assert_stage_orchestrator_state( self.transfer_batch_to_device._stage_mapping, current_running_stage) def on_test_dataloader(self) -> None: current_running_stage = RunningStage.TESTING self.on_test_dataloader_called = True collate_fn = self.test_dataloader().collate_fn # noqa F811 assert collate_fn == default_collate assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator) 
super().on_test_dataloader() collate_fn = self.test_dataloader().collate_fn # noqa F811 assert collate_fn.stage == current_running_stage self._compare_pre_processor( collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage)) assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) self._assert_stage_orchestrator_state( self.transfer_batch_to_device._stage_mapping, current_running_stage) def on_predict_dataloader(self) -> None: current_running_stage = RunningStage.PREDICTING self.on_predict_dataloader_called = True collate_fn = self.predict_dataloader().collate_fn # noqa F811 assert collate_fn == default_collate assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) assert self.predict_step == self._saved_predict_step super().on_predict_dataloader() collate_fn = self.predict_dataloader().collate_fn # noqa F811 assert collate_fn.stage == current_running_stage self._compare_pre_processor( collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage)) assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) assert isinstance(self.predict_step, _StageOrchestrator) self._assert_stage_orchestrator_state( self.transfer_batch_to_device._stage_mapping, current_running_stage) self._assert_stage_orchestrator_state( self.predict_step._stage_mapping, current_running_stage, cls=_PostProcessor) def on_fit_end(self) -> None: super().on_fit_end() assert self.train_dataloader().collate_fn == default_collate assert self.val_dataloader().collate_fn == default_collate assert self.test_dataloader().collate_fn == default_collate assert self.predict_dataloader().collate_fn == default_collate assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator) assert self.predict_step == self._saved_predict_step model = TestModel() model.data_pipeline = data_pipeline trainer = Trainer(fast_dev_run=True) trainer.fit(model) trainer.test(model) trainer.predict(model) assert model.on_train_dataloader_called assert model.on_val_dataloader_called assert model.on_test_dataloader_called assert model.on_predict_dataloader_called
def from_load_data_inputs(
    cls,
    train_load_data_input: Optional[Any] = None,
    val_load_data_input: Optional[Any] = None,
    test_load_data_input: Optional[Any] = None,
    predict_load_data_input: Optional[Any] = None,
    data_fetcher: BaseDataFetcher = None,
    preprocess: Optional[Preprocess] = None,
    postprocess: Optional[Postprocess] = None,
    use_iterable_auto_dataset: bool = False,
    seed: int = 42,
    val_split: Optional[float] = None,
    **kwargs,
) -> 'DataModule':
    """
    This function is a helper to generate a ``DataModule`` from a ``DataPipeline``.

    Args:
        cls: ``DataModule`` subclass
        train_load_data_input: Data to be received by the ``train_load_data`` function
            from this :class:`~flash.data.process.Preprocess`
        val_load_data_input: Data to be received by the ``val_load_data`` function
            from this :class:`~flash.data.process.Preprocess`
        test_load_data_input: Data to be received by the ``test_load_data`` function
            from this :class:`~flash.data.process.Preprocess`
        predict_load_data_input: Data to be received by the ``predict_load_data`` function
            from this :class:`~flash.data.process.Preprocess`
        kwargs: Any extra arguments to instantiate the provided ``DataModule``
    """
    # trick to get data_pipeline from empty DataModule
    if preprocess or postprocess:
        data_pipeline = DataPipeline(
            preprocess or cls(**kwargs).preprocess,
            postprocess or cls(**kwargs).postprocess,
        )
    else:
        data_pipeline = cls(**kwargs).data_pipeline

    data_fetcher: BaseDataFetcher = data_fetcher or cls.configure_data_fetcher()

    data_fetcher.attach_to_preprocess(data_pipeline._preprocess_pipeline)

    train_dataset = cls._generate_dataset_if_possible(
        train_load_data_input,
        running_stage=RunningStage.TRAINING,
        data_pipeline=data_pipeline,
        use_iterable_auto_dataset=use_iterable_auto_dataset,
    )
    val_dataset = cls._generate_dataset_if_possible(
        val_load_data_input,
        running_stage=RunningStage.VALIDATING,
        data_pipeline=data_pipeline,
        use_iterable_auto_dataset=use_iterable_auto_dataset,
    )
    test_dataset = cls._generate_dataset_if_possible(
        test_load_data_input,
        running_stage=RunningStage.TESTING,
        data_pipeline=data_pipeline,
        use_iterable_auto_dataset=use_iterable_auto_dataset,
    )
    predict_dataset = cls._generate_dataset_if_possible(
        predict_load_data_input,
        running_stage=RunningStage.PREDICTING,
        data_pipeline=data_pipeline,
        use_iterable_auto_dataset=use_iterable_auto_dataset,
    )

    if train_dataset is not None and (val_split is not None and val_dataset is None):
        train_dataset, val_dataset = cls._split_train_val(train_dataset, val_split)

    datamodule = cls(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        predict_dataset=predict_dataset,
        **kwargs,
    )
    datamodule._preprocess = data_pipeline._preprocess_pipeline
    datamodule._postprocess = data_pipeline._postprocess_pipeline
    data_fetcher.attach_to_datamodule(datamodule)
    return datamodule
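# Hedged usage sketch for the helper above. Concrete ``DataModule`` subclasses call
# ``from_load_data_inputs`` internally with their own inputs; the call below assumes the
# pass-through ``load_data``/``load_sample`` defaults of ``DefaultPreprocess`` and uses plain
# lists of samples, which is how the surrounding tests feed ``range(10)``-style data.
datamodule = DataModule.from_load_data_inputs(
    train_load_data_input=list(range(10)),   # forwarded to ``train_load_data`` of the Preprocess
    val_load_data_input=list(range(5)),      # forwarded to ``val_load_data``
    predict_load_data_input=list(range(3)),  # forwarded to ``predict_load_data``
    preprocess=DefaultPreprocess(),
)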
class Task(LightningModule): """A general Task. Args: model: Model to use for the task. loss_fn: Loss function for training optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. metrics: Metrics to compute for training and evaluation. learning_rate: Learning rate to use for training, defaults to `5e-5` """ def __init__( self, model: Optional[nn.Module] = None, loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, learning_rate: float = 5e-5, ): super().__init__() if model is not None: self.model = model self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn) self.optimizer_cls = optimizer self.metrics = nn.ModuleDict( {} if metrics is None else get_callable_dict(metrics)) self.learning_rate = learning_rate # TODO: should we save more? Bug on some regarding yaml if we save metrics self.save_hyperparameters("learning_rate", "optimizer") self._data_pipeline = None self._preprocess = None self._postprocess = None def step(self, batch: Any, batch_idx: int) -> Any: """ The training/validation/test step. Override for custom behavior. """ x, y = batch y_hat = self(x) output = {"y_hat": y_hat} losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()} logs = {} for name, metric in self.metrics.items(): if isinstance(metric, torchmetrics.metric.Metric): metric(y_hat, y) logs[ name] = metric # log the metric itself if it is of type Metric else: logs[name] = metric(y_hat, y) logs.update(losses) if len(losses.values()) > 1: logs["total_loss"] = sum(losses.values()) return logs["total_loss"], logs output["loss"] = list(losses.values())[0] output["logs"] = logs output["y"] = y return output def forward(self, x: Any) -> Any: return self.model(x) def training_step(self, batch: Any, batch_idx: int) -> Any: output = self.step(batch, batch_idx) self.log_dict({f"train_{k}": v for k, v in output["logs"].items()}, on_step=True, on_epoch=True, prog_bar=True) return output["loss"] def validation_step(self, batch: Any, batch_idx: int) -> None: output = self.step(batch, batch_idx) self.log_dict({f"val_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True) def test_step(self, batch: Any, batch_idx: int) -> None: output = self.step(batch, batch_idx) self.log_dict({f"test_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True) @predict_context def predict( self, x: Any, data_pipeline: Optional[DataPipeline] = None, ) -> Any: """ Predict function for raw data or processed data Args: x: Input to predict. Can be raw data or processed data. If str, assumed to be a folder of data. 
data_pipeline: Use this to override the current data pipeline Returns: The post-processed model predictions """ running_stage = RunningStage.PREDICTING data_pipeline = data_pipeline or self.data_pipeline x = [x for x in data_pipeline._generate_auto_dataset(x, running_stage)] x = data_pipeline.worker_preprocessor(running_stage)(x) x = self.transfer_batch_to_device(x, self.device) x = data_pipeline.device_preprocessor(running_stage)(x) predictions = self.predict_step( x, 0) # batch_idx is always 0 when running with `model.predict` predictions = data_pipeline.postprocessor(predictions) return predictions def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: if isinstance(batch, tuple): batch = batch[0] elif isinstance(batch, list): # Todo: Understand why stack is needed batch = torch.stack(batch) return self(batch) def configure_optimizers(self) -> torch.optim.Optimizer: return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate) def configure_finetune_callback(self) -> List[Callback]: return [] @property def preprocess(self) -> Optional[Preprocess]: return getattr(self._data_pipeline, '_preprocess_pipeline', None) or self._preprocess @preprocess.setter def preprocess(self, preprocess: Preprocess) -> None: self._preprocess = preprocess self.data_pipeline = DataPipeline(preprocess, self.postprocess) @property def postprocess(self) -> Postprocess: return getattr(self._data_pipeline, '_postprocess_pipeline', None) or self._postprocess @postprocess.setter def postprocess(self, postprocess: Postprocess) -> None: self.data_pipeline = DataPipeline(self.preprocess, postprocess) self._postprocess = postprocess @property def data_pipeline(self) -> Optional[DataPipeline]: if self._data_pipeline is not None: return self._data_pipeline elif self.preprocess is not None or self.postprocess is not None: # use direct attributes here to avoid recursion with properties that also check the data_pipeline property return DataPipeline(self.preprocess, self.postprocess) elif self.datamodule is not None and getattr( self.datamodule, 'data_pipeline', None) is not None: return self.datamodule.data_pipeline elif self.trainer is not None and hasattr( self.trainer, 'datamodule') and getattr( self.trainer.datamodule, 'data_pipeline', None) is not None: return self.trainer.datamodule.data_pipeline return self._data_pipeline @data_pipeline.setter def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None: self._data_pipeline = data_pipeline if data_pipeline is not None and getattr( data_pipeline, '_preprocess_pipeline', None) is not None: self._preprocess = data_pipeline._preprocess_pipeline if data_pipeline is not None and getattr( data_pipeline, '_postprocess_pipeline', None) is not None: if type(data_pipeline._postprocess_pipeline) != Postprocess: self._postprocess = data_pipeline._postprocess_pipeline def on_train_dataloader(self) -> None: if self.data_pipeline is not None: self.data_pipeline._detach_from_model(self, RunningStage.TRAINING) self.data_pipeline._attach_to_model(self, RunningStage.TRAINING) super().on_train_dataloader() def on_val_dataloader(self) -> None: if self.data_pipeline is not None: self.data_pipeline._detach_from_model(self, RunningStage.VALIDATING) self.data_pipeline._attach_to_model(self, RunningStage.VALIDATING) super().on_val_dataloader() def on_test_dataloader(self, *_) -> None: if self.data_pipeline is not None: self.data_pipeline._detach_from_model(self, RunningStage.TESTING) 
self.data_pipeline._attach_to_model(self, RunningStage.TESTING) super().on_test_dataloader() def on_predict_dataloader(self) -> None: if self.data_pipeline is not None: self.data_pipeline._detach_from_model(self, RunningStage.PREDICTING) self.data_pipeline._attach_to_model(self, RunningStage.PREDICTING) super().on_predict_dataloader() def on_predict_end(self) -> None: if self.data_pipeline is not None: self.data_pipeline._detach_from_model(self) super().on_predict_end() def on_fit_end(self) -> None: if self.data_pipeline is not None: self.data_pipeline._detach_from_model(self) super().on_fit_end() @staticmethod def _sanetize_funcs(obj: Any) -> Any: if hasattr(obj, "__dict__"): for k, v in obj.__dict__.items(): if isinstance(v, Callable): obj.__dict__[k] = inspect.unwrap(v) return obj def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # TODO: Is this the best way to do this? or should we also use some kind of hparams here? # This may be an issue since here we create the same problems with pickle as in # https://pytorch.org/docs/stable/notes/serialization.html if self.data_pipeline is not None and 'data_pipeline' not in checkpoint: self._preprocess = self._sanetize_funcs(self._preprocess) checkpoint['data_pipeline'] = self.data_pipeline # todo (tchaton) re-wrap visualization super().on_save_checkpoint(checkpoint) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: super().on_load_checkpoint(checkpoint) if 'data_pipeline' in checkpoint: self.data_pipeline = checkpoint['data_pipeline']
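# Illustrative trace of the ``predict`` flow defined above, spelled out step by step for a
# hypothetical fitted ``model`` and raw input ``x``. This mirrors the method body rather than
# adding new behaviour; ``manual_predict`` itself is not part of the Task API.
def manual_predict(model: Task, x: Any) -> Any:
    running_stage = RunningStage.PREDICTING
    data_pipeline = model.data_pipeline

    samples = [sample for sample in data_pipeline._generate_auto_dataset(x, running_stage)]  # load_data / load_sample
    batch = data_pipeline.worker_preprocessor(running_stage)(samples)  # per-sample transforms + collate
    batch = model.transfer_batch_to_device(batch, model.device)  # move to the model's device
    batch = data_pipeline.device_preprocessor(running_stage)(batch)  # on-device transforms
    predictions = model.predict_step(batch, 0)  # batch_idx is always 0 when running with `model.predict`
    return data_pipeline.postprocessor(predictions)  # uncollate and post-process per sample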
def preprocess(self, preprocess: Preprocess) -> None:
    self._preprocess = preprocess
    self.data_pipeline = DataPipeline(preprocess, self.postprocess)
def data_pipeline(self) -> DataPipeline:
    return DataPipeline(self.preprocess, self.postprocess)
def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir): class CustomPreprocess(Preprocess): def load_data(self, *_, **__): pass def test_load_data(self, *_, **__): pass def predict_load_data(self, *_, **__): pass def predict_load_sample(self, *_, **__): pass def val_load_sample(self, *_, **__): pass def val_pre_tensor_transform(self, *_, **__): pass def predict_to_tensor_transform(self, *_, **__): pass def train_post_tensor_transform(self, *_, **__): pass def test_collate(self, *_, **__): pass def val_per_sample_transform_on_device(self, *_, **__): pass def train_per_batch_transform_on_device(self, *_, **__): pass def test_per_batch_transform_on_device(self, *_, **__): pass preprocess = CustomPreprocess() data_pipeline = DataPipeline(preprocess) train_func_names = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.TRAINING, Preprocess ) for k in data_pipeline.PREPROCESS_FUNCS } val_func_names = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.VALIDATING, Preprocess ) for k in data_pipeline.PREPROCESS_FUNCS } test_func_names = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.TESTING, Preprocess ) for k in data_pipeline.PREPROCESS_FUNCS } predict_func_names = { k: data_pipeline._resolve_function_hierarchy( k, data_pipeline._preprocess_pipeline, RunningStage.PREDICTING, Preprocess ) for k in data_pipeline.PREPROCESS_FUNCS } # load_data assert train_func_names["load_data"] == "load_data" assert val_func_names["load_data"] == "load_data" assert test_func_names["load_data"] == "test_load_data" assert predict_func_names["load_data"] == "predict_load_data" # load_sample assert train_func_names["load_sample"] == "load_sample" assert val_func_names["load_sample"] == "val_load_sample" assert test_func_names["load_sample"] == "load_sample" assert predict_func_names["load_sample"] == "predict_load_sample" # pre_tensor_transform assert train_func_names["pre_tensor_transform"] == "pre_tensor_transform" assert val_func_names["pre_tensor_transform"] == "val_pre_tensor_transform" assert test_func_names["pre_tensor_transform"] == "pre_tensor_transform" assert predict_func_names["pre_tensor_transform"] == "pre_tensor_transform" # to_tensor_transform assert train_func_names["to_tensor_transform"] == "to_tensor_transform" assert val_func_names["to_tensor_transform"] == "to_tensor_transform" assert test_func_names["to_tensor_transform"] == "to_tensor_transform" assert predict_func_names["to_tensor_transform"] == "predict_to_tensor_transform" # post_tensor_transform assert train_func_names["post_tensor_transform"] == "train_post_tensor_transform" assert val_func_names["post_tensor_transform"] == "post_tensor_transform" assert test_func_names["post_tensor_transform"] == "post_tensor_transform" assert predict_func_names["post_tensor_transform"] == "post_tensor_transform" # collate assert train_func_names["collate"] == "collate" assert val_func_names["collate"] == "collate" assert test_func_names["collate"] == "test_collate" assert predict_func_names["collate"] == "collate" # per_sample_transform_on_device assert train_func_names["per_sample_transform_on_device"] == "per_sample_transform_on_device" assert val_func_names["per_sample_transform_on_device"] == "val_per_sample_transform_on_device" assert test_func_names["per_sample_transform_on_device"] == "per_sample_transform_on_device" assert predict_func_names["per_sample_transform_on_device"] == 
"per_sample_transform_on_device" # per_batch_transform_on_device assert train_func_names["per_batch_transform_on_device"] == "train_per_batch_transform_on_device" assert val_func_names["per_batch_transform_on_device"] == "per_batch_transform_on_device" assert test_func_names["per_batch_transform_on_device"] == "test_per_batch_transform_on_device" assert predict_func_names["per_batch_transform_on_device"] == "per_batch_transform_on_device" train_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TRAINING) val_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.VALIDATING) test_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TESTING) predict_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.PREDICTING) _seq = train_worker_preprocessor.per_sample_transform assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform assert _seq.post_tensor_transform.func == preprocess.train_post_tensor_transform assert train_worker_preprocessor.collate_fn.func == default_collate assert train_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform _seq = val_worker_preprocessor.per_sample_transform assert _seq.pre_tensor_transform.func == preprocess.val_pre_tensor_transform assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform assert val_worker_preprocessor.collate_fn.func == data_pipeline._identity assert val_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform _seq = test_worker_preprocessor.per_sample_transform assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform assert test_worker_preprocessor.collate_fn.func == preprocess.test_collate assert test_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform _seq = predict_worker_preprocessor.per_sample_transform assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform assert _seq.to_tensor_transform.func == preprocess.predict_to_tensor_transform assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform assert predict_worker_preprocessor.collate_fn.func == default_collate assert predict_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform
def test_attaching_datapipeline_to_model(tmpdir): preprocess = TestPreprocess() data_pipeline = DataPipeline(preprocess) class TestModel(CustomModel): stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] on_train_start_called = False on_val_start_called = False on_test_start_called = False on_predict_start_called = False def on_fit_start(self): assert self.predict_step.__self__ == self self._saved_predict_step = self.predict_step def _compare_pre_processor(self, p1, p2): p1_seq = p1.per_sample_transform p2_seq = p2.per_sample_transform assert p1_seq.pre_tensor_transform.func == p2_seq.pre_tensor_transform.func assert p1_seq.to_tensor_transform.func == p2_seq.to_tensor_transform.func assert p1_seq.post_tensor_transform.func == p2_seq.post_tensor_transform.func assert p1.collate_fn.func == p2.collate_fn.func assert p1.per_batch_transform.func == p2.per_batch_transform.func def _assert_stage_orchestrator_state( self, stage_mapping: Dict, current_running_stage: RunningStage, cls=_PreProcessor ): assert isinstance(stage_mapping[current_running_stage], cls) assert stage_mapping[current_running_stage] def on_train_dataloader(self) -> None: current_running_stage = RunningStage.TRAINING self.on_train_dataloader_called = True collate_fn = self.train_dataloader().collate_fn # noqa F811 assert collate_fn == default_collate assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator) super().on_train_dataloader() collate_fn = self.train_dataloader().collate_fn # noqa F811 assert collate_fn.stage == current_running_stage self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage)) assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage) def on_val_dataloader(self) -> None: current_running_stage = RunningStage.VALIDATING self.on_val_dataloader_called = True collate_fn = self.val_dataloader().collate_fn # noqa F811 assert collate_fn == default_collate assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) super().on_val_dataloader() collate_fn = self.val_dataloader().collate_fn # noqa F811 assert collate_fn.stage == current_running_stage self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage)) assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage) def on_test_dataloader(self) -> None: current_running_stage = RunningStage.TESTING self.on_test_dataloader_called = True collate_fn = self.test_dataloader().collate_fn # noqa F811 assert collate_fn == default_collate assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator) super().on_test_dataloader() collate_fn = self.test_dataloader().collate_fn # noqa F811 assert collate_fn.stage == current_running_stage self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage)) assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage) def on_predict_dataloader(self) -> None: current_running_stage = RunningStage.PREDICTING self.on_predict_dataloader_called = True collate_fn = self.predict_dataloader().collate_fn # noqa F811 assert collate_fn == default_collate assert 
isinstance(self.transfer_batch_to_device, _StageOrchestrator) assert self.predict_step == self._saved_predict_step super().on_predict_dataloader() collate_fn = self.predict_dataloader().collate_fn # noqa F811 assert collate_fn.stage == current_running_stage self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage)) assert isinstance(self.transfer_batch_to_device, _StageOrchestrator) assert isinstance(self.predict_step, _StageOrchestrator) self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage) self._assert_stage_orchestrator_state( self.predict_step._stage_mapping, current_running_stage, cls=_PostProcessor ) def on_fit_end(self) -> None: super().on_fit_end() assert self.train_dataloader().collate_fn == default_collate assert self.val_dataloader().collate_fn == default_collate assert self.test_dataloader().collate_fn == default_collate assert self.predict_dataloader().collate_fn == default_collate assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator) assert self.predict_step == self._saved_predict_step datamodule = CustomDataModule() datamodule._data_pipeline = data_pipeline model = TestModel() trainer = Trainer(fast_dev_run=True) trainer.fit(model, datamodule=datamodule) trainer.test(model) trainer.predict(model) assert model.on_train_dataloader_called assert model.on_val_dataloader_called assert model.on_test_dataloader_called assert model.on_predict_dataloader_called
def postprocess(self, postprocess: Postprocess) -> None:
    self.data_pipeline = DataPipeline(self.preprocess, postprocess)
    self._postprocess = postprocess
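# Minimal sketch of the setters above in use, assuming they live on ``Task`` as the surrounding
# snippets suggest: assigning a new ``Preprocess`` or ``Postprocess`` rebuilds ``data_pipeline``
# so the two halves stay in sync.
model = Task(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
model.preprocess = DefaultPreprocess()  # rebuilds DataPipeline(preprocess, self.postprocess)
model.postprocess = Postprocess()       # rebuilds DataPipeline(self.preprocess, postprocess)
assert isinstance(model.data_pipeline, DataPipeline)
assert isinstance(model.preprocess, DefaultPreprocess)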
def build_data_pipeline( self, data_source: Optional[str] = None, data_pipeline: Optional[DataPipeline] = None, ) -> Optional[DataPipeline]: """Build a :class:`.DataPipeline` incorporating available :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` objects. These will be overridden in the following resolution order (lowest priority first): - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`. - :class:`.Task` defaults given to ``.Task.__init__``. - :class:`.Task` manual overrides by setting :py:attr:`~data_pipeline`. - :class:`.DataPipeline` passed to this method. Args: data_pipeline: Optional highest priority source of :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`. Returns: The fully resolved :class:`.DataPipeline`. """ old_data_source, preprocess, postprocess, serializer = None, None, None, None # Datamodule if self.datamodule is not None and getattr( self.datamodule, 'data_pipeline', None) is not None: old_data_source = getattr(self.datamodule.data_pipeline, 'data_source', None) preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None) postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None) serializer = getattr(self.datamodule.data_pipeline, '_serializer', None) elif self.trainer is not None and hasattr( self.trainer, 'datamodule') and getattr( self.trainer.datamodule, 'data_pipeline', None) is not None: old_data_source = getattr(self.trainer.datamodule.data_pipeline, 'data_source', None) preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None) postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None) serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None) else: # TODO: we should log with low severity level that we use defaults to create # `preprocess`, `postprocess` and `serializer`. pass # Defaults / task attributes preprocess, postprocess, serializer = Task._resolve( preprocess, postprocess, serializer, self._preprocess, self._postprocess, self.serializer, ) # Datapipeline if data_pipeline is not None: preprocess, postprocess, serializer = Task._resolve( preprocess, postprocess, serializer, getattr(data_pipeline, '_preprocess_pipeline', None), getattr(data_pipeline, '_postprocess_pipeline', None), getattr(data_pipeline, '_serializer', None), ) data_source = data_source or old_data_source if isinstance(data_source, str): if preprocess is None: data_source = DataSource( ) # TODO: warn the user that we are not using the specified data source else: data_source = preprocess.data_source_of_name(data_source) data_pipeline = DataPipeline(data_source, preprocess, postprocess, serializer) self._data_pipeline_state = data_pipeline.initialize( self._data_pipeline_state) return data_pipeline
def test_preprocess_transforms(tmpdir):
    """
    This test makes sure that when a preprocess is provided transforms as dictionaries,
    the checking is done properly and ``collate_in_worker_from_transform`` is properly extracted.
    """

    with pytest.raises(MisconfigurationException, match="Transform should be a dict."):
        DefaultPreprocess(train_transform="choco")

    with pytest.raises(MisconfigurationException, match="train_transform contains {'choco'}. Only"):
        DefaultPreprocess(train_transform={"choco": None})

    preprocess = DefaultPreprocess(train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)})
    # keep is None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is None

    with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"):
        preprocess = DefaultPreprocess(
            train_transform={
                "per_batch_transform": torch.nn.Linear(1, 1),
                "per_sample_transform_on_device": torch.nn.Linear(1, 1)
            }
        )

    preprocess = DefaultPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)}
    )
    # keep is None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    train_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TRAINING)
    val_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.VALIDATING)
    test_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TESTING)
    predict_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == default_collate
    assert val_preprocessor.collate_fn.func == default_collate
    assert test_preprocessor.collate_fn.func == default_collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity

    class CustomPreprocess(DefaultPreprocess):

        def per_sample_transform_on_device(self, sample: Any) -> Any:
            return super().per_sample_transform_on_device(sample)

        def per_batch_transform(self, batch: Any) -> Any:
            return super().per_batch_transform(batch)

    preprocess = CustomPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)}
    )
    # keep is None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    data_pipeline = DataPipeline(preprocess=preprocess)

    train_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TRAINING)
    with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"):
        val_preprocessor = data_pipeline.worker_preprocessor(RunningStage.VALIDATING)
    with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"):
        test_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TESTING)
    predict_preprocessor = data_pipeline.worker_preprocessor(RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == default_collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity
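# A short example of the dictionary-based transform API validated above: each key names a
# Preprocess hook and the value is applied at that point in the pipeline. ``torch.nn.Linear(1, 1)``
# is only a stand-in for a real transform, mirroring the test.
preprocess = DefaultPreprocess(
    train_transform={
        "to_tensor_transform": torch.nn.Linear(1, 1),  # runs per sample, before collation
        "per_batch_transform": torch.nn.Linear(1, 1),  # runs on the collated batch in the worker
    },
    predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)},
)
train_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TRAINING)
assert train_preprocessor.collate_fn.func == default_collate  # collation stays in the worker for training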
def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
    """Build a :class:`.DataPipeline` incorporating available
    :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
    objects. These will be overridden in the following resolution order (lowest priority first):

    - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
    - :class:`.Task` defaults given to ``.Task.__init__``.
    - :class:`.Task` manual overrides by setting :py:attr:`~data_pipeline`.
    - :class:`.DataPipeline` passed to this method.

    Args:
        data_pipeline: Optional highest priority source of
            :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.

    Returns:
        The fully resolved :class:`.DataPipeline`.
    """
    preprocess, postprocess, serializer = None, None, None

    # Datamodule
    if self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None:
        preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None)
        postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None)
        serializer = getattr(self.datamodule.data_pipeline, '_serializer', None)

    elif self.trainer is not None and hasattr(self.trainer, 'datamodule') and getattr(
        self.trainer.datamodule, 'data_pipeline', None
    ) is not None:
        preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None)
        postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None)
        serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None)

    # Defaults / task attributes
    preprocess, postprocess, serializer = Task._resolve(
        preprocess,
        postprocess,
        serializer,
        self._preprocess,
        self._postprocess,
        self.serializer,
    )

    # Datapipeline
    if data_pipeline is not None:
        preprocess, postprocess, serializer = Task._resolve(
            preprocess,
            postprocess,
            serializer,
            getattr(data_pipeline, '_preprocess_pipeline', None),
            getattr(data_pipeline, '_postprocess_pipeline', None),
            getattr(data_pipeline, '_serializer', None),
        )

    data_pipeline = DataPipeline(preprocess, postprocess, serializer)
    data_pipeline.initialize()
    return data_pipeline
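# Hedged usage sketch for ``build_data_pipeline``: an explicitly passed ``DataPipeline`` has the
# highest priority and overrides anything found on the datamodule or the task defaults. It assumes
# a ``Task`` that exposes the ``serializer`` attribute referenced above; ``MyPreprocess`` is an
# illustrative subclass, defined only to make the override visible.
class MyPreprocess(DefaultPreprocess):
    pass


task = Task(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
override = DataPipeline(preprocess=MyPreprocess())
resolved = task.build_data_pipeline(data_pipeline=override)
assert isinstance(resolved._preprocess_pipeline, MyPreprocess)  # the explicit pipeline wins the resolution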