def test_is_overridden_recursive(tmpdir):
    """``DataPipeline._is_overridden_recursive`` should see both plain and stage-prefixed overrides."""

    class TestInputTransform(InputTransform):
        @staticmethod
        def custom_transform(x):
            return x

        def collate(self):
            return self.custom_transform

        def val_collate(self):
            return self.custom_transform

    transform = TestInputTransform()

    # ``collate`` is overridden generically and ``val_collate`` per-stage: both prefixes resolve.
    for stage_prefix in ("val", "train"):
        assert DataPipeline._is_overridden_recursive("collate", transform, InputTransform, prefix=stage_prefix)

    # ``per_batch_transform_on_device`` was never overridden, with or without a prefix.
    assert not DataPipeline._is_overridden_recursive(
        "per_batch_transform_on_device", transform, InputTransform, prefix="train"
    )
    assert not DataPipeline._is_overridden_recursive("per_batch_transform_on_device", transform, InputTransform)

    # Unknown hook names are rejected with an explicit error.
    with pytest.raises(MisconfigurationException, match="This function doesn't belong to the parent class"):
        assert not DataPipeline._is_overridden_recursive("chocolate", transform, InputTransform)
def test_is_overriden_recursive(tmpdir):
    """``DataPipeline._is_overriden_recursive`` detects plain and stage-prefixed overrides on a ``Preprocess``."""

    class TestPreprocess(DefaultPreprocess):
        def collate(self, *_):
            pass

        def val_collate(self, *_):
            pass

    pre = TestPreprocess()

    # Both the generic and the stage-specialized collate overrides are visible.
    for stage_prefix in ("val", "train"):
        assert DataPipeline._is_overriden_recursive("collate", pre, Preprocess, prefix=stage_prefix)

    # A hook that was never overridden is reported as such.
    assert not DataPipeline._is_overriden_recursive("per_batch_transform_on_device", pre, Preprocess, prefix="train")
    assert not DataPipeline._is_overriden_recursive("per_batch_transform_on_device", pre, Preprocess)

    # Hooks that do not exist on the parent class raise.
    with pytest.raises(MisconfigurationException, match="This function doesn't belong to the parent class"):
        assert not DataPipeline._is_overriden_recursive("chocolate", pre, Preprocess)
def test_saving_with_serializers(tmpdir):
    """Serializer state (``LabelsState``) must survive a checkpoint save/load round-trip."""
    checkpoint_file = os.path.join(tmpdir, 'tmp.ckpt')

    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())

    model = CustomModel()
    trainer = Trainer(fast_dev_run=True)

    pipeline = DataPipeline(preprocess=DefaultPreprocess(), serializer=Labels(["a", "b"]))
    pipeline.initialize()
    model.data_pipeline = pipeline
    assert isinstance(model.preprocess, DefaultPreprocess)

    # A trivial (x, y) regression dataset is enough to drive one fast-dev-run step.
    values = torch.arange(10, dtype=torch.float)
    dummy_data = DataLoader(list(zip(values, values)))
    trainer.fit(model, train_dataloader=dummy_data)
    trainer.save_checkpoint(checkpoint_file)

    # Reloading must restore the serializer's label state.
    restored = CustomModel.load_from_checkpoint(checkpoint_file)
    assert isinstance(restored._data_pipeline_state, DataPipelineState)
    assert restored._data_pipeline_state._state[LabelsState] == LabelsState(["a", "b"])
def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor():
    """Only the TRAIN and PREDICT stages can build a worker preprocessor for ``CustomPreprocess``."""
    pipeline = DataPipeline(preprocess=CustomPreprocess())

    # Training and predicting resolve cleanly.
    pipeline.worker_preprocessor(RunningStage.TRAINING)

    # Validation and testing hit the mutually-exclusive-hooks misconfiguration.
    for stage in (RunningStage.VALIDATING, RunningStage.TESTING):
        with pytest.raises(MisconfigurationException, match="are mutually exclusive"):
            pipeline.worker_preprocessor(stage)

    pipeline.worker_preprocessor(RunningStage.PREDICTING)
def test_predict_sklearn():
    """Predictions can be generated directly from a scikit-learn ``Bunch``."""
    bunch = datasets.load_iris()
    classifier = TemplateSKLearnClassifier(
        num_features=DummyDataset.num_features,
        num_classes=DummyDataset.num_classes,
    )
    pipeline = DataPipeline(preprocess=TemplatePreprocess())
    predictions = classifier.predict(bunch, data_source="sklearn", data_pipeline=pipeline)
    # Class predictions come back as plain ints.
    assert isinstance(predictions[0], int)
def test_predict_dataset(tmpdir):
    """Predictions can be generated from a PyTorch Geometric dataset."""
    tudataset = datasets.TUDataset(root=tmpdir, name="KKI")
    classifier = GraphClassifier(
        num_features=tudataset.num_features,
        num_classes=tudataset.num_classes,
    )
    pipeline = DataPipeline(preprocess=GraphClassificationPreprocess())
    predictions = classifier.predict(tudataset, data_source="datasets", data_pipeline=pipeline)
    # Graph-level class predictions come back as plain ints.
    assert isinstance(predictions[0], int)
def test_serialization_data_pipeline(tmpdir):
    """Checkpointing preserves the data pipeline, including a preprocess with a changed version.

    Fix over the original: ``CustomPreprocess.version`` was monkey-patched without being
    restored, leaking into every later test that touches the class, and the ``.ckpt``
    cleanup loop never ran when an assertion failed. Both now happen in a ``finally``.
    """
    model = CustomModel()

    checkpoint_file = os.path.join(tmpdir, "tmp.ckpt")
    checkpoint = ModelCheckpoint(tmpdir, "test.ckpt")
    trainer = Trainer(callbacks=[checkpoint], max_epochs=1)
    dummy_data = DataLoader(list(zip(torch.arange(10, dtype=torch.float), torch.arange(10, dtype=torch.float))))

    original_version = CustomPreprocess.version
    try:
        trainer.fit(model, dummy_data)

        assert model.data_pipeline
        trainer.save_checkpoint(checkpoint_file)

        # A default pipeline round-trips through the checkpoint.
        loaded_model = CustomModel.load_from_checkpoint(checkpoint_file)
        assert loaded_model.data_pipeline

        # Attaching a custom preprocess replaces the model's pipeline.
        model.data_pipeline = DataPipeline(preprocess=CustomPreprocess())
        assert isinstance(model.preprocess, CustomPreprocess)

        trainer.fit(model, dummy_data)
        assert model.data_pipeline
        assert isinstance(model.preprocess, CustomPreprocess)
        trainer.save_checkpoint(checkpoint_file)

        def fn(*args, **kwargs):
            return "0.0.2"

        # Bump the preprocess version to check the checkpoint still loads.
        CustomPreprocess.version = fn

        loaded_model = CustomModel.load_from_checkpoint(checkpoint_file)
        assert loaded_model.data_pipeline
        assert isinstance(loaded_model.preprocess, CustomPreprocess)
    finally:
        # Undo the class-level monkey-patch and remove checkpoints even on failure.
        CustomPreprocess.version = original_version
        for file in os.listdir(tmpdir):
            if file.endswith(".ckpt"):
                os.remove(os.path.join(tmpdir, file))
def test_predict_numpy():
    """Predictions can be generated from a single-row numpy array."""
    row = np.random.rand(1, DummyDataset.num_features)
    classifier = TemplateSKLearnClassifier(
        num_features=DummyDataset.num_features,
        num_classes=DummyDataset.num_classes,
    )
    pipeline = DataPipeline(preprocess=TemplatePreprocess())
    predictions = classifier.predict(row, data_pipeline=pipeline)
    # Class predictions come back as plain ints.
    assert isinstance(predictions[0], int)
def test_predict_numpy():
    """Segmentation prediction on a numpy image returns a mask tensor of the input spatial size."""
    image = np.ones((1, 3, 10, 20))
    segmenter = SemanticSegmentation(2)
    pipeline = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    predictions = segmenter.predict(image, data_source="numpy", data_pipeline=pipeline)
    mask = predictions[0]
    # The mask is a tensor matching the (H, W) of the input image.
    assert isinstance(mask, torch.Tensor)
    assert mask.shape == (10, 20)
def test_test(tmpdir):
    """The graph classifier can be tested on a PyTorch Geometric dataset."""
    tudataset = datasets.TUDataset(root=tmpdir, name="KKI")
    classifier = GraphClassifier(
        num_features=tudataset.num_features,
        num_classes=tudataset.num_classes,
    )
    classifier.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess())
    loader = torch.utils.data.DataLoader(tudataset, batch_size=4)
    Trainer(default_root_dir=tmpdir, fast_dev_run=True).test(classifier, loader)
def test_predict_numpy():
    """Segmentation with a mobilenetv3 backbone returns a nested list matching the input size."""
    image = np.ones((1, 3, 64, 64))
    segmenter = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
    pipeline = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    predictions = segmenter.predict(image, data_source="numpy", data_pipeline=pipeline)
    mask = predictions[0]
    # Row-major nested list: 64 rows of 64 columns, mirroring the 64x64 input.
    assert isinstance(mask, list)
    assert len(mask) == 64
    assert len(mask[0]) == 64
def test_predict_tensor():
    """Segmentation prediction on a raw tensor returns a nested list matching the input size."""
    image = torch.rand(1, 3, 10, 20)
    segmenter = SemanticSegmentation(2)
    pipeline = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    predictions = segmenter.predict(image, data_source="tensors", data_pipeline=pipeline)
    mask = predictions[0]
    # Row-major nested list: 10 rows of 20 columns, mirroring the 10x20 input.
    assert isinstance(mask, list)
    assert len(mask) == 10
    assert len(mask[0]) == 20
def test_predict_numpy():
    """Segmentation prediction on a numpy image returns a nested list matching the input size."""
    image = np.ones((1, 3, 10, 20))
    segmenter = SemanticSegmentation(2)
    pipeline = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    predictions = segmenter.predict(image, data_source="numpy", data_pipeline=pipeline)
    mask = predictions[0]
    # Row-major nested list: 10 rows of 20 columns, mirroring the 10x20 input.
    assert isinstance(mask, list)
    assert len(mask) == 10
    assert len(mask[0]) == 20
def test_data_pipeline_str():
    """``str(DataPipeline)`` lists its five components in a fixed, documented order."""
    pipeline = DataPipeline(
        data_source=cast(DataSource, "data_source"),
        preprocess=cast(Preprocess, "preprocess"),
        postprocess=cast(Postprocess, "postprocess"),
        serializer=cast(Serializer, "serializer"),
        deserializer=cast(Deserializer, "deserializer"),
    )
    expected = (
        "data_source=data_source, deserializer=deserializer, "
        "preprocess=preprocess, postprocess=postprocess, serializer=serializer"
    )
    assert str(pipeline) == f"DataPipeline({expected})"
def __configure_worker_and_device_collate_fn(
        running_stage: RunningStage,
        input_transform: InputTransform) -> Tuple[Callable, Callable]:
    """Decide where collation happens (worker vs device) for ``running_stage``.

    Returns:
        A ``(worker_collate_fn, device_collate_fn)`` pair built from the hooks
        the user overrode on ``input_transform`` for this stage.

    Raises:
        MisconfigurationException: if both ``per_batch_transform`` (worker-side)
            and ``per_sample_transform_on_device`` (device-side) are overridden
            for this stage without an explicit
            ``collate_in_worker_from_transform`` choice to disambiguate them.
    """
    from flash.core.data.data_pipeline import DataPipeline

    prefix: str = _STAGES_PREFIX[running_stage]
    transform_for_stage: _InputTransformPerStage = input_transform._transform[running_stage]

    # Did the user override either side of the worker/device split for this stage?
    per_batch_transform_overridden: bool = DataPipeline._is_overridden_recursive(
        "per_batch_transform", input_transform, InputTransform, prefix=prefix)
    per_sample_transform_on_device_overridden: bool = DataPipeline._is_overridden_recursive(
        "per_sample_transform_on_device", input_transform, InputTransform, prefix=prefix)

    # Overriding both is ambiguous unless collate_in_worker_from_transform says where to collate.
    is_per_overridden = per_batch_transform_overridden and per_sample_transform_on_device_overridden
    if transform_for_stage.collate_in_worker_from_transform is None and is_per_overridden:
        raise MisconfigurationException(
            f"{input_transform.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` "
            f"are mutually exclusive for stage {running_stage}")

    # An explicit boolean choice wins; otherwise collate on the device exactly when
    # the per-sample device hook was overridden.
    if isinstance(transform_for_stage.collate_in_worker_from_transform, bool):
        worker_collate_fn, device_collate_fn = __make_collates(
            input_transform, not transform_for_stage.collate_in_worker_from_transform, input_transform._collate)
    else:
        worker_collate_fn, device_collate_fn = __make_collates(
            input_transform, per_sample_transform_on_device_overridden, input_transform._collate)

    # Unwrap an already-wrapped worker collate so the processor is not applied twice.
    worker_collate_fn = (worker_collate_fn.collate_fn if isinstance(
        worker_collate_fn, _InputTransformProcessor) else worker_collate_fn)

    return worker_collate_fn, device_collate_fn
def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
    """Call the stage-specific ``default_transforms`` hook and return its transform dict."""
    from flash.core.data.data_pipeline import DataPipeline

    hook_name = DataPipeline._resolve_function_hierarchy("default_transforms", self, running_stage, Preprocess)
    default_transforms_fn = getattr(self, hook_name)
    with CurrentRunningStageFuncContext(running_stage, "default_transforms", self):
        transforms: Optional[Dict[str, Callable]] = default_transforms_fn()
    return transforms
def running_stage(self, running_stage: RunningStage) -> None:
    """Store the running stage and re-bind ``load_sample`` to the stage-specific hook."""
    from flash.core.data.data_pipeline import DataPipeline  # noqa F811
    from flash.core.data.data_source import DataSource  # noqa F811

    # TODO: something better than this
    self._running_stage = running_stage
    self._load_sample_context = CurrentRunningStageFuncContext(self.running_stage, "load_sample", self.data_source)
    resolved_name = DataPipeline._resolve_function_hierarchy(
        'load_sample',
        self.data_source,
        self.running_stage,
        DataSource,
    )
    self.load_sample: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(self.data_source, resolved_name)
def generate_dataset(
    self,
    data: Optional[DATA_TYPE],
    running_stage: RunningStage,
) -> Optional[Union[AutoDataset, IterableAutoDataset]]:
    """Generate a single dataset with the given input to
    :meth:`~flash.core.data.data_source.DataSource.load_data` for the given ``running_stage``.

    Args:
        data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use to create the dataset.
        running_stage: The running_stage for this dataset.

    Returns:
        The constructed :class:`~flash.core.data.auto_dataset.BaseAutoDataset`,
        or ``None`` when ``data`` (or its first element) is ``None``.
    """
    is_none = data is None
    # A sequence counts as "no data" when its first element is None.
    # NOTE(review): an empty Sequence would raise IndexError here — presumably
    # callers never pass one; confirm before relying on it.
    if isinstance(data, Sequence):
        is_none = data[0] is None
    if not is_none:
        from flash.core.data.data_pipeline import DataPipeline

        # MockDataset records metadata attributes set by ``load_data`` so they
        # can be copied onto the real dataset once it is constructed below.
        mock_dataset = typing.cast(AutoDataset, MockDataset())
        with CurrentRunningStageFuncContext(running_stage, "load_data", self):
            resolved_func_name = DataPipeline._resolve_function_hierarchy("load_data", self, running_stage, DataSource)
            load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(self, resolved_func_name)
            # Pass the mock dataset only when the hook's signature asks for one.
            parameters = signature(load_data).parameters
            if len(parameters) > 1 and "dataset" in parameters:  # TODO: This was DATASET_KEY before
                data = load_data(data, mock_dataset)
            else:
                data = load_data(data)

        # Sized input gets an indexable dataset; everything else an iterable one.
        if has_len(data):
            dataset = AutoDataset(data, self, running_stage)
        else:
            dataset = IterableAutoDataset(data, self, running_stage)
        dataset.__dict__.update(mock_dataset.metadata)
        return dataset
def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir):
    """``DataPipeline`` falls back to defaults, and attaching it sets matching model attributes."""

    class CustomModel(Task):
        def __init__(self, postprocess: Optional[Postprocess] = None):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
            self._postprocess = postprocess

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    class SubPreprocess(DefaultPreprocess):
        pass

    class SubPostprocess(Postprocess):
        pass

    pipeline = DataPipeline(
        preprocess=SubPreprocess() if use_preprocess else None,
        postprocess=SubPostprocess() if use_postprocess else None,
    )

    # Missing components fall back to the framework defaults.
    expected_preprocess_cls = SubPreprocess if use_preprocess else DefaultPreprocess
    expected_postprocess_cls = SubPostprocess if use_postprocess else Postprocess
    assert isinstance(pipeline._preprocess_pipeline, expected_preprocess_cls)
    assert isinstance(pipeline._postprocess_pipeline, expected_postprocess_cls)

    model = CustomModel(postprocess=Postprocess())
    model.data_pipeline = pipeline
    # TODO: the line below should make the same effect but it's not
    # data_pipeline._attach_to_model(model)

    if use_preprocess:
        assert isinstance(model._preprocess, SubPreprocess)
    else:
        assert model._preprocess is None or isinstance(model._preprocess, Preprocess)

    if use_postprocess:
        assert isinstance(model._postprocess, SubPostprocess)
    else:
        assert model._postprocess is None or isinstance(model._postprocess, Postprocess)
def test_detach_preprocessing_from_model(tmpdir):
    """Attaching a pipeline wraps the collate/transfer hooks; fit teardown restores them."""

    class CustomModel(Task):
        def __init__(self, postprocess: Optional[Postprocess] = None):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
            self._postprocess = postprocess

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    model = CustomModel()
    model.data_pipeline = DataPipeline(preprocess=CustomPreprocess())

    # Before any dataloader hook runs, everything is still vanilla PyTorch.
    assert model.train_dataloader().collate_fn == default_collate
    assert model.transfer_batch_to_device.__self__ == model

    # Entering the train-dataloader hook attaches the pipeline machinery.
    model.on_train_dataloader()
    assert isinstance(model.train_dataloader().collate_fn, _Preprocessor)
    assert isinstance(model.transfer_batch_to_device, _StageOrchestrator)

    # Fit teardown detaches it again.
    model.on_fit_end()
    assert model.transfer_batch_to_device.__self__ == model
    assert model.train_dataloader().collate_fn == default_collate
def test_attaching_datapipeline_to_model(tmpdir):
    """End-to-end check that attaching a ``DataPipeline`` wraps, and detaching unwraps, every
    stage's collate function, ``transfer_batch_to_device`` and ``predict_step``.

    Each ``on_*_dataloader`` override asserts the state BEFORE calling ``super()`` (which
    performs the attachment) and again AFTER, so hook ordering is exercised explicitly.
    """

    class SubPreprocess(DefaultPreprocess):
        pass

    preprocess = SubPreprocess()
    data_pipeline = DataPipeline(preprocess=preprocess)

    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
            self._postprocess = Postprocess()

        def training_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def validation_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def test_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def val_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def test_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def predict_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    class TestModel(CustomModel):
        stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]
        on_train_start_called = False
        on_val_start_called = False
        on_test_start_called = False
        on_predict_start_called = False

        def on_fit_start(self):
            # Keep the unwrapped bound method to compare against after detachment.
            assert self.predict_step.__self__ == self
            self._saved_predict_step = self.predict_step

        def _compare_pre_processor(self, p1, p2):
            # Two preprocessors match when each wrapped hook resolves to the same function.
            p1_seq = p1.per_sample_transform
            p2_seq = p2.per_sample_transform
            assert p1_seq.pre_tensor_transform.func == p2_seq.pre_tensor_transform.func
            assert p1_seq.to_tensor_transform.func == p2_seq.to_tensor_transform.func
            assert p1_seq.post_tensor_transform.func == p2_seq.post_tensor_transform.func
            assert p1.collate_fn.func == p2.collate_fn.func
            assert p1.per_batch_transform.func == p2.per_batch_transform.func

        def _assert_stage_orchestrator_state(
                self, stage_mapping: Dict, current_running_stage: RunningStage, cls=_Preprocessor):
            # The orchestrator must hold a truthy processor of the expected type for the stage.
            assert isinstance(stage_mapping[current_running_stage], cls)
            assert stage_mapping[current_running_stage]

        def on_train_dataloader(self) -> None:
            current_running_stage = RunningStage.TRAINING
            self.on_train_dataloader_called = True
            # Before super(): still the vanilla collate and an unwrapped transfer hook.
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_train_dataloader()
            # After super(): the stage-specific preprocessor is attached.
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_val_dataloader(self) -> None:
            current_running_stage = RunningStage.VALIDATING
            self.on_val_dataloader_called = True
            # The transfer hook is already wrapped here because training attached it first.
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_val_dataloader()
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_test_dataloader(self) -> None:
            current_running_stage = RunningStage.TESTING
            self.on_test_dataloader_called = True
            # Testing runs after fit teardown, so the transfer hook is unwrapped again.
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_test_dataloader()
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_predict_dataloader(self) -> None:
            current_running_stage = RunningStage.PREDICTING
            self.on_predict_dataloader_called = True
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            # predict_step is still the plain bound method saved in on_fit_start.
            assert self.predict_step == self._saved_predict_step
            super().on_predict_dataloader()
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            # Predicting additionally wraps predict_step with a postprocessor orchestrator.
            assert isinstance(self.predict_step, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)
            self._assert_stage_orchestrator_state(
                self.predict_step._stage_mapping, current_running_stage, cls=_Postprocessor)

        def on_fit_end(self) -> None:
            super().on_fit_end()
            # Everything must be detached again: vanilla collates and unwrapped hooks.
            assert self.train_dataloader().collate_fn == default_collate
            assert self.val_dataloader().collate_fn == default_collate
            assert self.test_dataloader().collate_fn == default_collate
            assert self.predict_dataloader().collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step

    model = TestModel()
    model.data_pipeline = data_pipeline
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model)

    assert model.on_train_dataloader_called
    assert model.on_val_dataloader_called
    assert model.on_test_dataloader_called
    assert model.on_predict_dataloader_called
def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir):
    """Exhaustively check ``_resolve_function_hierarchy``: a stage-prefixed override (e.g.
    ``val_pre_tensor_transform``) wins for that stage only, while un-prefixed hooks resolve
    to the generic name; then verify the built worker preprocessors bind those same functions.
    """

    class CustomPreprocess(DefaultPreprocess):
        def val_pre_tensor_transform(self, *_, **__):
            pass

        def predict_to_tensor_transform(self, *_, **__):
            pass

        def train_post_tensor_transform(self, *_, **__):
            pass

        def test_collate(self, *_, **__):
            pass

        def val_per_sample_transform_on_device(self, *_, **__):
            pass

        def train_per_batch_transform_on_device(self, *_, **__):
            pass

        def test_per_batch_transform_on_device(self, *_, **__):
            pass

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess=preprocess)

    # Resolve every preprocess hook name once per stage.
    train_func_names: Dict[str, str] = {
        k: data_pipeline._resolve_function_hierarchy(
            k, data_pipeline._preprocess_pipeline, RunningStage.TRAINING, Preprocess)
        for k in data_pipeline.PREPROCESS_FUNCS
    }
    val_func_names: Dict[str, str] = {
        k: data_pipeline._resolve_function_hierarchy(
            k, data_pipeline._preprocess_pipeline, RunningStage.VALIDATING, Preprocess)
        for k in data_pipeline.PREPROCESS_FUNCS
    }
    test_func_names: Dict[str, str] = {
        k: data_pipeline._resolve_function_hierarchy(
            k, data_pipeline._preprocess_pipeline, RunningStage.TESTING, Preprocess)
        for k in data_pipeline.PREPROCESS_FUNCS
    }
    predict_func_names: Dict[str, str] = {
        k: data_pipeline._resolve_function_hierarchy(
            k, data_pipeline._preprocess_pipeline, RunningStage.PREDICTING, Preprocess)
        for k in data_pipeline.PREPROCESS_FUNCS
    }

    # pre_tensor_transform: only the val-prefixed override exists.
    assert train_func_names["pre_tensor_transform"] == "pre_tensor_transform"
    assert val_func_names["pre_tensor_transform"] == "val_pre_tensor_transform"
    assert test_func_names["pre_tensor_transform"] == "pre_tensor_transform"
    assert predict_func_names["pre_tensor_transform"] == "pre_tensor_transform"

    # to_tensor_transform: only the predict-prefixed override exists.
    assert train_func_names["to_tensor_transform"] == "to_tensor_transform"
    assert val_func_names["to_tensor_transform"] == "to_tensor_transform"
    assert test_func_names["to_tensor_transform"] == "to_tensor_transform"
    assert predict_func_names["to_tensor_transform"] == "predict_to_tensor_transform"

    # post_tensor_transform: only the train-prefixed override exists.
    assert train_func_names["post_tensor_transform"] == "train_post_tensor_transform"
    assert val_func_names["post_tensor_transform"] == "post_tensor_transform"
    assert test_func_names["post_tensor_transform"] == "post_tensor_transform"
    assert predict_func_names["post_tensor_transform"] == "post_tensor_transform"

    # collate: only the test-prefixed override exists.
    assert train_func_names["collate"] == "collate"
    assert val_func_names["collate"] == "collate"
    assert test_func_names["collate"] == "test_collate"
    assert predict_func_names["collate"] == "collate"

    # per_sample_transform_on_device: only the val-prefixed override exists.
    assert train_func_names["per_sample_transform_on_device"] == "per_sample_transform_on_device"
    assert val_func_names["per_sample_transform_on_device"] == "val_per_sample_transform_on_device"
    assert test_func_names["per_sample_transform_on_device"] == "per_sample_transform_on_device"
    assert predict_func_names["per_sample_transform_on_device"] == "per_sample_transform_on_device"

    # per_batch_transform_on_device: train- and test-prefixed overrides exist.
    assert train_func_names["per_batch_transform_on_device"] == "train_per_batch_transform_on_device"
    assert val_func_names["per_batch_transform_on_device"] == "per_batch_transform_on_device"
    assert test_func_names["per_batch_transform_on_device"] == "test_per_batch_transform_on_device"
    assert predict_func_names["per_batch_transform_on_device"] == "per_batch_transform_on_device"

    # The worker preprocessors must bind exactly the resolved functions per stage.
    train_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TRAINING)
    val_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.VALIDATING)
    test_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TESTING)
    predict_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.PREDICTING)

    _seq = train_worker_preprocessor.per_sample_transform
    assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform
    assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform
    assert _seq.post_tensor_transform.func == preprocess.train_post_tensor_transform
    assert train_worker_preprocessor.collate_fn.func == preprocess.collate
    assert train_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform

    _seq = val_worker_preprocessor.per_sample_transform
    assert _seq.pre_tensor_transform.func == preprocess.val_pre_tensor_transform
    assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform
    assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform
    # Validation collates on the device (val_per_sample_transform_on_device overridden),
    # so the worker-side collate is the identity.
    assert val_worker_preprocessor.collate_fn.func == DataPipeline._identity
    assert val_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform

    _seq = test_worker_preprocessor.per_sample_transform
    assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform
    assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform
    assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform
    assert test_worker_preprocessor.collate_fn.func == preprocess.test_collate
    assert test_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform

    _seq = predict_worker_preprocessor.per_sample_transform
    assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform
    assert _seq.to_tensor_transform.func == preprocess.predict_to_tensor_transform
    assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform
    assert predict_worker_preprocessor.collate_fn.func == preprocess.collate
    assert predict_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform
def build_data_pipeline(
    self,
    data_source: Optional[str] = None,
    data_pipeline: Optional[DataPipeline] = None,
) -> Optional[DataPipeline]:
    """Build a :class:`.DataPipeline` incorporating available
    :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess` objects. These
    will be overridden in the following resolution order (lowest priority first):

    - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
    - :class:`.Task` defaults given to :meth:`.Task.__init__`.
    - :class:`.Task` manual overrides by setting :py:attr:`~data_pipeline`.
    - :class:`.DataPipeline` passed to this method.

    Args:
        data_source: Name of the data source to look up on the resolved preprocess; falls back
            to the data source of the datamodule's pipeline when omitted.
        data_pipeline: Optional highest priority source of
            :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess`.

    Returns:
        The fully resolved :class:`.DataPipeline`.
    """
    old_data_source, preprocess, postprocess, serializer = None, None, None, None

    # Datamodule: lowest-priority source of components (task-attached first, then trainer-attached).
    if self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None:
        old_data_source = getattr(self.datamodule.data_pipeline, 'data_source', None)
        preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None)
        postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None)
        serializer = getattr(self.datamodule.data_pipeline, '_serializer', None)
    elif self.trainer is not None and hasattr(self.trainer, 'datamodule') and getattr(
            self.trainer.datamodule, 'data_pipeline', None) is not None:
        old_data_source = getattr(self.trainer.datamodule.data_pipeline, 'data_source', None)
        preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None)
        postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None)
        serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None)
    else:
        # TODO: we should log with low severity level that we use defaults to create
        # `preprocess`, `postprocess` and `serializer`.
        pass

    # Defaults / task attributes override the datamodule's components.
    preprocess, postprocess, serializer = Task._resolve(
        preprocess,
        postprocess,
        serializer,
        self._preprocess,
        self._postprocess,
        self.serializer,
    )

    # Datapipeline: an explicitly passed pipeline has the highest priority.
    if data_pipeline is not None:
        preprocess, postprocess, serializer = Task._resolve(
            preprocess,
            postprocess,
            serializer,
            getattr(data_pipeline, '_preprocess_pipeline', None),
            getattr(data_pipeline, '_postprocess_pipeline', None),
            getattr(data_pipeline, '_serializer', None),
        )

    data_source = data_source or old_data_source

    # A string data source must be looked up on the preprocess; without one,
    # fall back to a bare DataSource.
    if isinstance(data_source, str):
        if preprocess is None:
            data_source = DataSource()  # TODO: warn the user that we are not using the specified data source
        else:
            data_source = preprocess.data_source_of_name(data_source)

    data_pipeline = DataPipeline(data_source, preprocess, postprocess, serializer)
    # Carry forward any existing pipeline state through initialization.
    self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
    return data_pipeline
def test_preprocess_transforms(tmpdir):
    """Check that dict-based transforms passed to a preprocess are validated and that
    ``collate_in_worker_from_transform`` is extracted correctly per stage — including the
    mutually-exclusive ``per_batch_transform`` / ``per_sample_transform_on_device`` rule.
    """
    # Non-dict transforms are rejected outright.
    with pytest.raises(MisconfigurationException, match="Transform should be a dict."):
        DefaultPreprocess(train_transform="choco")

    # Unknown hook keys inside the dict are rejected.
    with pytest.raises(MisconfigurationException, match="train_transform contains {'choco'}. Only"):
        DefaultPreprocess(train_transform={"choco": None})

    preprocess = DefaultPreprocess(train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)})
    # keep is None: only the configured (train) stage gets an explicit flag.
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is None

    # Supplying both sides of the worker/device split in ONE stage dict is a misconfiguration.
    with pytest.raises(
            MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"):
        preprocess = DefaultPreprocess(
            train_transform={
                "per_batch_transform": torch.nn.Linear(1, 1),
                "per_sample_transform_on_device": torch.nn.Linear(1, 1)
            })

    # Splitting them across DIFFERENT stages is fine.
    preprocess = DefaultPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)})
    # keep is None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    train_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TRAINING)
    val_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.VALIDATING)
    test_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TESTING)
    predict_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == preprocess.collate
    assert val_preprocessor.collate_fn.func == preprocess.collate
    assert test_preprocessor.collate_fn.func == preprocess.collate
    # Predicting collates on the device, so the worker-side collate is the identity.
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity

    class CustomPreprocess(DefaultPreprocess):
        def per_sample_transform_on_device(self, sample: Any) -> Any:
            return super().per_sample_transform_on_device(sample)

        def per_batch_transform(self, batch: Any) -> Any:
            return super().per_batch_transform(batch)

    preprocess = CustomPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)})
    # keep is None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    data_pipeline = DataPipeline(preprocess=preprocess)

    train_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TRAINING)
    # For stages with no explicit flag, overriding BOTH method hooks is ambiguous and raises.
    with pytest.raises(
            MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"):
        val_preprocessor = data_pipeline.worker_preprocessor(RunningStage.VALIDATING)
    with pytest.raises(
            MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"):
        test_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TESTING)
    predict_preprocessor = data_pipeline.worker_preprocessor(RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == preprocess.collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity
def __resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
    """Collect the user-overridden transform hooks for ``running_stage``.

    For every transform placement (e.g. ``per_sample_transform``), resolve both the plain
    hook and each apply-to-key variant, pick the winner (a stage-specialized name beats a
    generic one; overriding both with the same specialization is an error), wrap key-scoped
    hooks in ``ApplyToKeys``, and compose multiple hooks per placement into one callable.

    Raises:
        MisconfigurationException: if both the plain hook and an apply-to-key hook are
            overridden at the same specialization level, or a hook does not return a callable.
    """
    from flash.core.data.data_pipeline import DataPipeline

    transforms_out = {}
    stage = _STAGES_PREFIX[running_stage]

    # iterate over all transforms hook name
    for transform_name in InputTransformPlacement:

        transforms = {}
        transform_name = transform_name.value

        # iterate over all prefixes
        for key in ApplyToKeyPrefix:

            # get the resolved hook name based on the current stage
            resolved_name = DataPipeline._resolve_function_hierarchy(
                transform_name, self, running_stage, InputTransform)
            # check if the hook name is specialized
            is_specialized_name = resolved_name.startswith(stage)

            # get the resolved hook name for apply to key on the current stage
            resolved_apply_to_key_name = DataPipeline._resolve_function_hierarchy(
                f"{key}_{transform_name}", self, running_stage, InputTransform)
            # check if resolved hook name for apply to key is specialized
            is_specialized_apply_to_key_name = resolved_apply_to_key_name.startswith(stage)

            # check if they are overridden by the user
            resolve_name_overridden = DataPipeline._is_overridden(resolved_name, self, InputTransform)
            resolved_apply_to_key_name_overridden = DataPipeline._is_overridden(
                resolved_apply_to_key_name, self, InputTransform)

            if resolve_name_overridden and resolved_apply_to_key_name_overridden:
                # if both are specialized or both aren't specialized, raise a exception
                # It means there is priority to specialize hooks name.
                if not (is_specialized_name ^ is_specialized_apply_to_key_name):
                    raise MisconfigurationException(
                        f"Only one of {resolved_name} or {resolved_apply_to_key_name} can be overridden.")

                # The more specialized (stage-prefixed) of the two wins.
                method_name = resolved_name if is_specialized_name else resolved_apply_to_key_name
            else:
                method_name = resolved_apply_to_key_name if resolved_apply_to_key_name_overridden else resolved_name

            # get associated transform
            try:
                fn = getattr(self, method_name)()
            except AttributeError as e:
                raise AttributeError(str(e) + ". Hint: Call super().__init__(...) after setting all attributes.")

            if not callable(fn):
                raise MisconfigurationException(f"The hook {method_name} should return a function.")

            # if the default hook is used, it should return identity, skip it.
            if fn is self._identity:
                continue

            # wrap apply to key hook into `ApplyToKeys` with the associated key.
            if method_name == resolved_apply_to_key_name:
                fn = ApplyToKeys(key.value, fn)

            # Deduplicate: the same plain hook resolves for every key prefix.
            if method_name not in transforms:
                transforms[method_name] = fn

        # store the transforms.
        if transforms:
            transforms = list(transforms.values())
            transforms_out[transform_name] = Compose(transforms) if len(transforms) > 1 else transforms[0]

    return transforms_out