Example #1
def test_saving_with_serializers(tmpdir):

    checkpoint_file = os.path.join(tmpdir, 'tmp.ckpt')

    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1),
                             loss_fn=torch.nn.MSELoss())

    serializer = Labels(["a", "b"])
    model = CustomModel()
    trainer = Trainer(fast_dev_run=True)
    data_pipeline = DataPipeline(DefaultPreprocess(), serializer=serializer)
    data_pipeline.initialize()
    model.data_pipeline = data_pipeline
    assert isinstance(model.preprocess, DefaultPreprocess)
    dummy_data = DataLoader(
        list(
            zip(torch.arange(10, dtype=torch.float),
                torch.arange(10, dtype=torch.float))))
    trainer.fit(model, train_dataloader=dummy_data)
    trainer.save_checkpoint(checkpoint_file)
    model = CustomModel.load_from_checkpoint(checkpoint_file)
    assert isinstance(model.preprocess._data_pipeline_state, DataPipelineState)
    assert model.preprocess._data_pipeline_state._state[
        ClassificationState] == ClassificationState(['a', 'b'])
Example #2
def test_predict_numpy():
    img = np.ones((1, 3, 10, 20))
    model = SemanticSegmentation(2)
    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess())
    out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
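    # Assumption: SemanticSegmentationPreprocess appears to resize inputs to a
    # (196, 196) default, which would explain the predicted mask shape below.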
    assert isinstance(out[0], torch.Tensor)
    assert out[0].shape == (196, 196)
Example #3
def test_serialization_data_pipeline(tmpdir):
    model = CustomModel()

    checkpoint_file = os.path.join(tmpdir, 'tmp.ckpt')
    checkpoint = ModelCheckpoint(tmpdir, 'test.ckpt')
    trainer = Trainer(callbacks=[checkpoint], max_epochs=1)
    dummy_data = DataLoader(
        list(
            zip(torch.arange(10, dtype=torch.float),
                torch.arange(10, dtype=torch.float))))
    trainer.fit(model, dummy_data)

    assert model.data_pipeline is None
    trainer.save_checkpoint(checkpoint_file)

    loaded_model = CustomModel.load_from_checkpoint(checkpoint_file)
    assert loaded_model.data_pipeline is None

    model.data_pipeline = DataPipeline(CustomPreprocess())

    trainer.fit(model, dummy_data)
    assert model.data_pipeline
    assert isinstance(model.preprocess, CustomPreprocess)
    trainer.save_checkpoint(checkpoint_file)
    loaded_model = CustomModel.load_from_checkpoint(checkpoint_file)
    assert loaded_model.data_pipeline
    assert isinstance(loaded_model.preprocess, CustomPreprocess)
    for file in os.listdir(tmpdir):
        if file.endswith('.ckpt'):
            os.remove(os.path.join(tmpdir, file))
Example #4
def test_preprocessing_data_pipeline_no_running_stage(with_dataset):
    pipe = DataPipeline(_AutoDatasetTestPreprocess(with_dataset))

    dataset = pipe._generate_auto_dataset(range(10), running_stage=None)

    with pytest.raises(RuntimeError, match='`__len__` for `load_sample`'):
        for idx in range(len(dataset)):
            dataset[idx]

    # will be triggered when running stage is set
    if with_dataset:
        assert not hasattr(dataset, 'load_sample_was_called')
        assert not hasattr(dataset, 'load_data_was_called')
        assert pipe._preprocess_pipeline.load_sample_with_dataset_count == 0
        assert pipe._preprocess_pipeline.load_data_with_dataset_count == 0
    else:
        assert pipe._preprocess_pipeline.load_sample_count == 0
        assert pipe._preprocess_pipeline.load_data_count == 0

    dataset.running_stage = RunningStage.TRAINING

    if with_dataset:
        assert pipe._preprocess_pipeline.train_load_data_with_dataset_count == 1
        assert dataset.train_load_data_was_called
    else:
        assert pipe._preprocess_pipeline.train_load_data_count == 1

def test_autodataset_warning():
    with pytest.warns(
            UserWarning,
            match="``datapipeline`` is specified but load_sample and/or load_data are also specified",
    ):
        AutoDataset(range(10),
                    load_data=lambda x: x,
                    load_sample=lambda x: x,
                    data_pipeline=DataPipeline())

    @classmethod
    def from_load_data_inputs(
        cls,
        train_load_data_input: Optional[Any] = None,
        val_load_data_input: Optional[Any] = None,
        test_load_data_input: Optional[Any] = None,
        predict_load_data_input: Optional[Any] = None,
        preprocess: Optional[Preprocess] = None,
        postprocess: Optional[Postprocess] = None,
        **kwargs,
    ) -> 'DataModule':
        """
        This function is a helper to generate a ``DataModule`` from a ``DataPipeline``.

        Args:
            cls: ``DataModule`` subclass
            train_load_data_input: Data to be received by the ``train_load_data`` function from this ``Preprocess``
            val_load_data_input: Data to be received by the ``val_load_data`` function from this ``Preprocess``
            test_load_data_input: Data to be received by the ``test_load_data`` function from this ``Preprocess``
            predict_load_data_input: Data to be received by the ``predict_load_data`` function from this ``Preprocess``
            kwargs: Any extra arguments to instantiate the provided ``DataModule``
        """
        # trick to get data_pipeline from empty DataModule
        if preprocess or postprocess:
            data_pipeline = DataPipeline(
                preprocess or cls(**kwargs).preprocess,
                postprocess or cls(**kwargs).postprocess,
            )
        else:
            data_pipeline = cls(**kwargs).data_pipeline

        train_dataset = cls._generate_dataset_if_possible(
            train_load_data_input,
            running_stage=RunningStage.TRAINING,
            data_pipeline=data_pipeline)
        val_dataset = cls._generate_dataset_if_possible(
            val_load_data_input,
            running_stage=RunningStage.VALIDATING,
            data_pipeline=data_pipeline)
        test_dataset = cls._generate_dataset_if_possible(
            test_load_data_input,
            running_stage=RunningStage.TESTING,
            data_pipeline=data_pipeline)
        predict_dataset = cls._generate_dataset_if_possible(
            predict_load_data_input,
            running_stage=RunningStage.PREDICTING,
            data_pipeline=data_pipeline)
        datamodule = cls(train_dataset=train_dataset,
                         val_dataset=val_dataset,
                         test_dataset=test_dataset,
                         predict_dataset=predict_dataset,
                         **kwargs)
        datamodule._preprocess = data_pipeline._preprocess_pipeline
        datamodule._postprocess = data_pipeline._postprocess_pipeline
        return datamodule
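
A minimal usage sketch for the helper above. ``MyDataModule``, ``MyPreprocess``, ``train_files`` and ``val_files`` are hypothetical placeholders; each ``*_load_data_input`` is routed to the matching ``*_load_data`` hook of the resolved ``Preprocess``.

# Hypothetical usage sketch; MyDataModule, MyPreprocess, train_files and
# val_files are placeholders, not defined in the snippets above.
datamodule = MyDataModule.from_load_data_inputs(
    train_load_data_input=train_files,  # consumed by train_load_data
    val_load_data_input=val_files,      # consumed by val_load_data
    preprocess=MyPreprocess(),          # overrides the DataModule default
)
assert datamodule._preprocess is not None  # set from the pipeline at the end of the helper
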
def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor():

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess)

    data_pipeline.worker_preprocessor(RunningStage.TRAINING)
    with pytest.raises(MisconfigurationException, match="are mutual exclusive"):
        data_pipeline.worker_preprocessor(RunningStage.VALIDATING)
    with pytest.raises(MisconfigurationException, match="are mutual exclusive"):
        data_pipeline.worker_preprocessor(RunningStage.TESTING)
    data_pipeline.worker_preprocessor(RunningStage.PREDICTING)

def test_detach_preprocessing_from_model(tmpdir):

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess)
    model = CustomModel()
    model.data_pipeline = data_pipeline

    assert model.train_dataloader().collate_fn == default_collate
    assert model.transfer_batch_to_device.__self__ == model
    model.on_train_dataloader()
    assert isinstance(model.train_dataloader().collate_fn, _PreProcessor)
    assert isinstance(model.transfer_batch_to_device, _StageOrchestrator)
    model.on_fit_end()
    assert model.transfer_batch_to_device.__self__ == model
    assert model.train_dataloader().collate_fn == default_collate

def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir):

    class SubPreprocess(Preprocess):
        pass

    class SubPostprocess(Postprocess):
        pass

    data_pipeline = DataPipeline(
        SubPreprocess() if use_preprocess else None,
        SubPostprocess() if use_postprocess else None,
    )
    assert isinstance(data_pipeline._preprocess_pipeline, SubPreprocess if use_preprocess else Preprocess)
    assert isinstance(data_pipeline._postprocess_pipeline, SubPostprocess if use_postprocess else Postprocess)

    model = CustomModel(Postprocess())
    model.data_pipeline = data_pipeline
    assert isinstance(model._preprocess, SubPreprocess if use_preprocess else Preprocess)
    assert isinstance(model._postprocess, SubPostprocess if use_postprocess else Postprocess)
Example #10
    @property
    def data_pipeline(self) -> Optional[DataPipeline]:
        if self._data_pipeline is not None:
            return self._data_pipeline

        elif self.preprocess is not None or self.postprocess is not None:
            # use direct attributes here to avoid recursion with properties that also check the data_pipeline property
            return DataPipeline(self.preprocess, self.postprocess)

        elif self.datamodule is not None and getattr(
                self.datamodule, 'data_pipeline', None) is not None:
            return self.datamodule.data_pipeline

        elif self.trainer is not None and hasattr(
                self.trainer, 'datamodule') and getattr(
                    self.trainer.datamodule, 'data_pipeline',
                    None) is not None:
            return self.trainer.datamodule.data_pipeline

        return self._data_pipeline
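
A hedged sketch of the fallback chain in the property above, reusing a ``Task``-like model such as the ``CustomModel`` of Example #11. That ``self.preprocess`` is backed by a ``_preprocess`` attribute is an assumption.

# Illustrative only: assumes Task.preprocess reads from self._preprocess.
task = CustomModel()
assert task.data_pipeline is None                    # nothing set anywhere yet

task._preprocess = DefaultPreprocess()
assert isinstance(task.data_pipeline, DataPipeline)  # built from preprocess/postprocess

task._data_pipeline = DataPipeline(DefaultPreprocess())
assert task.data_pipeline is task._data_pipeline     # explicit assignment wins
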
Example #11
def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess,
                                            tmpdir):
    class CustomModel(Task):
        def __init__(self, postprocess: Optional[Postprocess] = None):
            super().__init__(model=torch.nn.Linear(1, 1),
                             loss_fn=torch.nn.MSELoss())
            self._postprocess = postprocess

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    class SubPreprocess(DefaultPreprocess):
        pass

    class SubPostprocess(Postprocess):
        pass

    data_pipeline = DataPipeline(
        preprocess=SubPreprocess() if use_preprocess else None,
        postprocess=SubPostprocess() if use_postprocess else None,
    )
    assert isinstance(data_pipeline._preprocess_pipeline,
                      SubPreprocess if use_preprocess else DefaultPreprocess)
    assert isinstance(data_pipeline._postprocess_pipeline,
                      SubPostprocess if use_postprocess else Postprocess)

    model = CustomModel(postprocess=Postprocess())
    model.data_pipeline = data_pipeline
    # TODO: the line below should have the same effect, but it doesn't
    # data_pipeline._attach_to_model(model)

    if use_preprocess:
        assert isinstance(model._preprocess, SubPreprocess)
    else:
        assert model._preprocess is None or isinstance(model._preprocess,
                                                       Preprocess)

    if use_postprocess:
        assert isinstance(model._postprocess, SubPostprocess)
    else:
        assert model._postprocess is None or isinstance(
            model._postprocess, Postprocess)
Example #12
def test_preprocessing_data_pipeline_with_running_stage(with_dataset):
    pipe = DataPipeline(_AutoDatasetTestPreprocess(with_dataset))

    running_stage = RunningStage.TRAINING

    dataset = pipe._generate_auto_dataset(range(10), running_stage=running_stage)

    assert len(dataset) == 10

    for idx in range(len(dataset)):
        dataset[idx]

    if with_dataset:
        assert dataset.train_load_sample_was_called
        assert dataset.train_load_data_was_called
        assert pipe._preprocess_pipeline.train_load_sample_with_dataset_count == len(dataset)
        assert pipe._preprocess_pipeline.train_load_data_with_dataset_count == 1
    else:
        assert pipe._preprocess_pipeline.train_load_sample_count == len(dataset)
        assert pipe._preprocess_pipeline.train_load_data_count == 1
Example #13
def test_detach_preprocessing_from_model(tmpdir):
    class CustomModel(Task):
        def __init__(self, postprocess: Optional[Postprocess] = None):
            super().__init__(model=torch.nn.Linear(1, 1),
                             loss_fn=torch.nn.MSELoss())
            self._postprocess = postprocess

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess=preprocess)
    model = CustomModel()
    model.data_pipeline = data_pipeline

    assert model.train_dataloader().collate_fn == default_collate
    assert model.transfer_batch_to_device.__self__ == model
    model.on_train_dataloader()
    assert isinstance(model.train_dataloader().collate_fn, _PreProcessor)
    assert isinstance(model.transfer_batch_to_device, _StageOrchestrator)
    model.on_fit_end()
    assert model.transfer_batch_to_device.__self__ == model
    assert model.train_dataloader().collate_fn == default_collate
Example #14
def preprocess(self, preprocess: Preprocess) -> None:
    self._preprocess = preprocess
    self.data_pipeline = DataPipeline(preprocess, self.postprocess)

def data_pipeline(self) -> DataPipeline:
    return DataPipeline(self.preprocess, self.postprocess)
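
These two methods read like a ``preprocess`` setter and a ``data_pipeline`` property on a ``DataModule``-style class; the decorators are not shown above, so that framing is an assumption. A short usage sketch:

# Assumes the methods above are a @preprocess.setter and a data_pipeline @property.
datamodule.preprocess = CustomPreprocess()  # setter also rebuilds data_pipeline
assert isinstance(datamodule.data_pipeline, DataPipeline)
assert isinstance(datamodule.data_pipeline._preprocess_pipeline, CustomPreprocess)

Note that the property constructs a fresh ``DataPipeline`` on every access, so callers should not rely on identity between calls.
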
Example #16
def test_attaching_datapipeline_to_model(tmpdir):
    class SubPreprocess(DefaultPreprocess):
        pass

    preprocess = SubPreprocess()
    data_pipeline = DataPipeline(preprocess=preprocess)

    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1),
                             loss_fn=torch.nn.MSELoss())
            self._postprocess = Postprocess()

        def training_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def validation_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def test_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def val_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def test_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def predict_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    class TestModel(CustomModel):

        stages = [
            RunningStage.TRAINING, RunningStage.VALIDATING,
            RunningStage.TESTING, RunningStage.PREDICTING
        ]
        on_train_start_called = False
        on_val_start_called = False
        on_test_start_called = False
        on_predict_start_called = False

        def on_fit_start(self):
            assert self.predict_step.__self__ == self
            self._saved_predict_step = self.predict_step

        def _compare_pre_processor(self, p1, p2):
            p1_seq = p1.per_sample_transform
            p2_seq = p2.per_sample_transform
            assert p1_seq.pre_tensor_transform.func == p2_seq.pre_tensor_transform.func
            assert p1_seq.to_tensor_transform.func == p2_seq.to_tensor_transform.func
            assert p1_seq.post_tensor_transform.func == p2_seq.post_tensor_transform.func
            assert p1.collate_fn.func == p2.collate_fn.func
            assert p1.per_batch_transform.func == p2.per_batch_transform.func

        def _assert_stage_orchestrator_state(
                self,
                stage_mapping: Dict,
                current_running_stage: RunningStage,
                cls=_PreProcessor):
            assert isinstance(stage_mapping[current_running_stage], cls)
            assert stage_mapping[current_running_stage]

        def on_train_dataloader(self) -> None:
            current_running_stage = RunningStage.TRAINING
            self.on_train_dataloader_called = True
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device,
                                  _StageOrchestrator)
            super().on_train_dataloader()
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(
                collate_fn,
                self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            self._assert_stage_orchestrator_state(
                self.transfer_batch_to_device._stage_mapping,
                current_running_stage)

        def on_val_dataloader(self) -> None:
            current_running_stage = RunningStage.VALIDATING
            self.on_val_dataloader_called = True
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            super().on_val_dataloader()
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(
                collate_fn,
                self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            self._assert_stage_orchestrator_state(
                self.transfer_batch_to_device._stage_mapping,
                current_running_stage)

        def on_test_dataloader(self) -> None:
            current_running_stage = RunningStage.TESTING
            self.on_test_dataloader_called = True
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device,
                                  _StageOrchestrator)
            super().on_test_dataloader()
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(
                collate_fn,
                self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            self._assert_stage_orchestrator_state(
                self.transfer_batch_to_device._stage_mapping,
                current_running_stage)

        def on_predict_dataloader(self) -> None:
            current_running_stage = RunningStage.PREDICTING
            self.on_predict_dataloader_called = True
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step
            super().on_predict_dataloader()
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(
                collate_fn,
                self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            assert isinstance(self.predict_step, _StageOrchestrator)
            self._assert_stage_orchestrator_state(
                self.transfer_batch_to_device._stage_mapping,
                current_running_stage)
            self._assert_stage_orchestrator_state(
                self.predict_step._stage_mapping,
                current_running_stage,
                cls=_PostProcessor)

        def on_fit_end(self) -> None:
            super().on_fit_end()
            assert self.train_dataloader().collate_fn == default_collate
            assert self.val_dataloader().collate_fn == default_collate
            assert self.test_dataloader().collate_fn == default_collate
            assert self.predict_dataloader().collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device,
                                  _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step

    model = TestModel()
    model.data_pipeline = data_pipeline
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model)

    assert model.on_train_dataloader_called
    assert model.on_val_dataloader_called
    assert model.on_test_dataloader_called
    assert model.on_predict_dataloader_called
Example #17
def test_preprocess_transforms(tmpdir):
    """
    This test makes sure that when a preprocess is being provided transforms as dictionaries,
    checking is done properly, and collate_in_worker_from_transform is properly extracted.
    """

    with pytest.raises(MisconfigurationException,
                       match="Transform should be a dict."):
        DefaultPreprocess(train_transform="choco")

    with pytest.raises(MisconfigurationException,
                       match="train_transform contains {'choco'}. Only"):
        DefaultPreprocess(train_transform={"choco": None})

    preprocess = DefaultPreprocess(
        train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)})
    # the other stages keep None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is None

    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        preprocess = DefaultPreprocess(
            train_transform={
                "per_batch_transform": torch.nn.Linear(1, 1),
                "per_sample_transform_on_device": torch.nn.Linear(1, 1)
            })

    preprocess = DefaultPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={
            "per_sample_transform_on_device": torch.nn.Linear(1, 1)
        })
    # the other stages keep None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    train_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.TRAINING)
    val_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(
        RunningStage.VALIDATING)
    test_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.TESTING)
    predict_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == default_collate
    assert val_preprocessor.collate_fn.func == default_collate
    assert test_preprocessor.collate_fn.func == default_collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity

    class CustomPreprocess(DefaultPreprocess):
        def per_sample_transform_on_device(self, sample: Any) -> Any:
            return super().per_sample_transform_on_device(sample)

        def per_batch_transform(self, batch: Any) -> Any:
            return super().per_batch_transform(batch)

    preprocess = CustomPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={
            "per_sample_transform_on_device": torch.nn.Linear(1, 1)
        })
    # the other stages keep None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    data_pipeline = DataPipeline(preprocess=preprocess)

    train_preprocessor = data_pipeline.worker_preprocessor(
        RunningStage.TRAINING)
    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        val_preprocessor = data_pipeline.worker_preprocessor(
            RunningStage.VALIDATING)
    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        test_preprocessor = data_pipeline.worker_preprocessor(
            RunningStage.TESTING)
    predict_preprocessor = data_pipeline.worker_preprocessor(
        RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == default_collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity

def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir):

    class CustomPreprocess(Preprocess):

        def load_data(self, *_, **__):
            pass

        def test_load_data(self, *_, **__):
            pass

        def predict_load_data(self, *_, **__):
            pass

        def predict_load_sample(self, *_, **__):
            pass

        def val_load_sample(self, *_, **__):
            pass

        def val_pre_tensor_transform(self, *_, **__):
            pass

        def predict_to_tensor_transform(self, *_, **__):
            pass

        def train_post_tensor_transform(self, *_, **__):
            pass

        def test_collate(self, *_, **__):
            pass

        def val_per_sample_transform_on_device(self, *_, **__):
            pass

        def train_per_batch_transform_on_device(self, *_, **__):
            pass

        def test_per_batch_transform_on_device(self, *_, **__):
            pass

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess)
    train_func_names = {
        k: data_pipeline._resolve_function_hierarchy(
            k, data_pipeline._preprocess_pipeline, RunningStage.TRAINING, Preprocess
        )
        for k in data_pipeline.PREPROCESS_FUNCS
    }
    val_func_names = {
        k: data_pipeline._resolve_function_hierarchy(
            k, data_pipeline._preprocess_pipeline, RunningStage.VALIDATING, Preprocess
        )
        for k in data_pipeline.PREPROCESS_FUNCS
    }
    test_func_names = {
        k: data_pipeline._resolve_function_hierarchy(
            k, data_pipeline._preprocess_pipeline, RunningStage.TESTING, Preprocess
        )
        for k in data_pipeline.PREPROCESS_FUNCS
    }
    predict_func_names = {
        k: data_pipeline._resolve_function_hierarchy(
            k, data_pipeline._preprocess_pipeline, RunningStage.PREDICTING, Preprocess
        )
        for k in data_pipeline.PREPROCESS_FUNCS
    }
    # load_data
    assert train_func_names["load_data"] == "load_data"
    assert val_func_names["load_data"] == "load_data"
    assert test_func_names["load_data"] == "test_load_data"
    assert predict_func_names["load_data"] == "predict_load_data"

    # load_sample
    assert train_func_names["load_sample"] == "load_sample"
    assert val_func_names["load_sample"] == "val_load_sample"
    assert test_func_names["load_sample"] == "load_sample"
    assert predict_func_names["load_sample"] == "predict_load_sample"

    # pre_tensor_transform
    assert train_func_names["pre_tensor_transform"] == "pre_tensor_transform"
    assert val_func_names["pre_tensor_transform"] == "val_pre_tensor_transform"
    assert test_func_names["pre_tensor_transform"] == "pre_tensor_transform"
    assert predict_func_names["pre_tensor_transform"] == "pre_tensor_transform"

    # to_tensor_transform
    assert train_func_names["to_tensor_transform"] == "to_tensor_transform"
    assert val_func_names["to_tensor_transform"] == "to_tensor_transform"
    assert test_func_names["to_tensor_transform"] == "to_tensor_transform"
    assert predict_func_names["to_tensor_transform"] == "predict_to_tensor_transform"

    # post_tensor_transform
    assert train_func_names["post_tensor_transform"] == "train_post_tensor_transform"
    assert val_func_names["post_tensor_transform"] == "post_tensor_transform"
    assert test_func_names["post_tensor_transform"] == "post_tensor_transform"
    assert predict_func_names["post_tensor_transform"] == "post_tensor_transform"

    # collate
    assert train_func_names["collate"] == "collate"
    assert val_func_names["collate"] == "collate"
    assert test_func_names["collate"] == "test_collate"
    assert predict_func_names["collate"] == "collate"

    # per_sample_transform_on_device
    assert train_func_names["per_sample_transform_on_device"] == "per_sample_transform_on_device"
    assert val_func_names["per_sample_transform_on_device"] == "val_per_sample_transform_on_device"
    assert test_func_names["per_sample_transform_on_device"] == "per_sample_transform_on_device"
    assert predict_func_names["per_sample_transform_on_device"] == "per_sample_transform_on_device"

    # per_batch_transform_on_device
    assert train_func_names["per_batch_transform_on_device"] == "train_per_batch_transform_on_device"
    assert val_func_names["per_batch_transform_on_device"] == "per_batch_transform_on_device"
    assert test_func_names["per_batch_transform_on_device"] == "test_per_batch_transform_on_device"
    assert predict_func_names["per_batch_transform_on_device"] == "per_batch_transform_on_device"

    train_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TRAINING)
    val_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.VALIDATING)
    test_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TESTING)
    predict_worker_preprocessor = data_pipeline.worker_preprocessor(RunningStage.PREDICTING)

    _seq = train_worker_preprocessor.per_sample_transform
    assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform
    assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform
    assert _seq.post_tensor_transform.func == preprocess.train_post_tensor_transform
    assert train_worker_preprocessor.collate_fn.func == default_collate
    assert train_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform

    _seq = val_worker_preprocessor.per_sample_transform
    assert _seq.pre_tensor_transform.func == preprocess.val_pre_tensor_transform
    assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform
    assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform
    assert val_worker_preprocessor.collate_fn.func == data_pipeline._identity
    assert val_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform

    _seq = test_worker_preprocessor.per_sample_transform
    assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform
    assert _seq.to_tensor_transform.func == preprocess.to_tensor_transform
    assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform
    assert test_worker_preprocessor.collate_fn.func == preprocess.test_collate
    assert test_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform

    _seq = predict_worker_preprocessor.per_sample_transform
    assert _seq.pre_tensor_transform.func == preprocess.pre_tensor_transform
    assert _seq.to_tensor_transform.func == preprocess.predict_to_tensor_transform
    assert _seq.post_tensor_transform.func == preprocess.post_tensor_transform
    assert predict_worker_preprocessor.collate_fn.func == default_collate
    assert predict_worker_preprocessor.per_batch_transform.func == preprocess.per_batch_transform

def test_attaching_datapipeline_to_model(tmpdir):

    preprocess = TestPreprocess()
    data_pipeline = DataPipeline(preprocess)

    class TestModel(CustomModel):

        stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]
        on_train_start_called = False
        on_val_start_called = False
        on_test_start_called = False
        on_predict_start_called = False

        def on_fit_start(self):
            assert self.predict_step.__self__ == self
            self._saved_predict_step = self.predict_step

        def _compare_pre_processor(self, p1, p2):
            p1_seq = p1.per_sample_transform
            p2_seq = p2.per_sample_transform
            assert p1_seq.pre_tensor_transform.func == p2_seq.pre_tensor_transform.func
            assert p1_seq.to_tensor_transform.func == p2_seq.to_tensor_transform.func
            assert p1_seq.post_tensor_transform.func == p2_seq.post_tensor_transform.func
            assert p1.collate_fn.func == p2.collate_fn.func
            assert p1.per_batch_transform.func == p2.per_batch_transform.func

        def _assert_stage_orchestrator_state(
            self, stage_mapping: Dict, current_running_stage: RunningStage, cls=_PreProcessor
        ):
            assert isinstance(stage_mapping[current_running_stage], cls)
            assert stage_mapping[current_running_stage]

        def on_train_dataloader(self) -> None:
            current_running_stage = RunningStage.TRAINING
            self.on_train_dataloader_called = True
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_train_dataloader()
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_val_dataloader(self) -> None:
            current_running_stage = RunningStage.VALIDATING
            self.on_val_dataloader_called = True
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_val_dataloader()
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_test_dataloader(self) -> None:
            current_running_stage = RunningStage.TESTING
            self.on_test_dataloader_called = True
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_test_dataloader()
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_predict_dataloader(self) -> None:
            current_running_stage = RunningStage.PREDICTING
            self.on_predict_dataloader_called = True
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step
            super().on_predict_dataloader()
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            assert isinstance(self.predict_step, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)
            self._assert_stage_orchestrator_state(
                self.predict_step._stage_mapping, current_running_stage, cls=_PostProcessor
            )

        def on_fit_end(self) -> None:
            super().on_fit_end()
            assert self.train_dataloader().collate_fn == default_collate
            assert self.val_dataloader().collate_fn == default_collate
            assert self.test_dataloader().collate_fn == default_collate
            assert self.predict_dataloader().collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step

    datamodule = CustomDataModule()
    datamodule._data_pipeline = data_pipeline
    model = TestModel()
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model)
    trainer.predict(model)

    assert model.on_train_dataloader_called
    assert model.on_val_dataloader_called
    assert model.on_test_dataloader_called
    assert model.on_predict_dataloader_called
Example #20
    def build_data_pipeline(
        self,
        data_pipeline: Optional[DataPipeline] = None
    ) -> Optional[DataPipeline]:
        """Build a :class:`.DataPipeline` incorporating available
        :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
        objects. These will be overridden in the following resolution order (lowest priority first):

        - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
        - :class:`.Task` defaults given to ``.Task.__init__``.
        - :class:`.Task` manual overrides by setting :py:attr:`~data_pipeline`.
        - :class:`.DataPipeline` passed to this method.

        Args:
            data_pipeline: Optional highest priority source of
                :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.

        Returns:
            The fully resolved :class:`.DataPipeline`.
        """
        preprocess, postprocess, serializer = None, None, None

        # Datamodule
        if self.datamodule is not None and getattr(
                self.datamodule, 'data_pipeline', None) is not None:
            preprocess = getattr(self.datamodule.data_pipeline,
                                 '_preprocess_pipeline', None)
            postprocess = getattr(self.datamodule.data_pipeline,
                                  '_postprocess_pipeline', None)
            serializer = getattr(self.datamodule.data_pipeline, '_serializer',
                                 None)

        elif self.trainer is not None and hasattr(
                self.trainer, 'datamodule') and getattr(
                    self.trainer.datamodule, 'data_pipeline',
                    None) is not None:
            preprocess = getattr(self.trainer.datamodule.data_pipeline,
                                 '_preprocess_pipeline', None)
            postprocess = getattr(self.trainer.datamodule.data_pipeline,
                                  '_postprocess_pipeline', None)
            serializer = getattr(self.trainer.datamodule.data_pipeline,
                                 '_serializer', None)

        # Defaults / task attributes
        preprocess, postprocess, serializer = Task._resolve(
            preprocess,
            postprocess,
            serializer,
            self._preprocess,
            self._postprocess,
            self.serializer,
        )

        # Datapipeline
        if data_pipeline is not None:
            preprocess, postprocess, serializer = Task._resolve(
                preprocess,
                postprocess,
                serializer,
                getattr(data_pipeline, '_preprocess_pipeline', None),
                getattr(data_pipeline, '_postprocess_pipeline', None),
                getattr(data_pipeline, '_serializer', None),
            )

        data_pipeline = DataPipeline(preprocess, postprocess, serializer)
        data_pipeline.initialize()
        return data_pipeline
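
A hedged sketch of the precedence described in the docstring, reusing ``CustomModel`` and ``SubPreprocess`` from the earlier examples; attribute names follow the snippet above.

# Illustrative precedence check: an explicitly passed DataPipeline outranks the
# Task's own preprocess, which in turn outranks any attached datamodule.
task = CustomModel()
task._preprocess = DefaultPreprocess()
assert isinstance(task.build_data_pipeline()._preprocess_pipeline, DefaultPreprocess)

override = DataPipeline(SubPreprocess())
built = task.build_data_pipeline(data_pipeline=override)
assert isinstance(built._preprocess_pipeline, SubPreprocess)
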
Example #21
    @classmethod
    def from_load_data_inputs(
        cls,
        train_load_data_input: Optional[Any] = None,
        val_load_data_input: Optional[Any] = None,
        test_load_data_input: Optional[Any] = None,
        predict_load_data_input: Optional[Any] = None,
        data_fetcher: Optional[BaseDataFetcher] = None,
        preprocess: Optional[Preprocess] = None,
        postprocess: Optional[Postprocess] = None,
        use_iterable_auto_dataset: bool = False,
        seed: int = 42,
        val_split: Optional[float] = None,
        **kwargs,
    ) -> 'DataModule':
        """
        This function is a helper to generate a ``DataModule`` from a ``DataPipeline``.

        Args:
            cls: ``DataModule`` subclass
            train_load_data_input: Data to be received by the ``train_load_data`` function
                from this :class:`~flash.data.process.Preprocess`
            val_load_data_input: Data to be received by the ``val_load_data`` function
                from this :class:`~flash.data.process.Preprocess`
            test_load_data_input: Data to be received by the ``test_load_data`` function
                from this :class:`~flash.data.process.Preprocess`
            predict_load_data_input: Data to be received by the ``predict_load_data`` function
                from this :class:`~flash.data.process.Preprocess`
            kwargs: Any extra arguments to instantiate the provided ``DataModule``
        """
        # trick to get data_pipeline from empty DataModule
        if preprocess or postprocess:
            data_pipeline = DataPipeline(
                preprocess or cls(**kwargs).preprocess,
                postprocess or cls(**kwargs).postprocess,
            )
        else:
            data_pipeline = cls(**kwargs).data_pipeline

        data_fetcher: BaseDataFetcher = data_fetcher or cls.configure_data_fetcher()

        data_fetcher.attach_to_preprocess(data_pipeline._preprocess_pipeline)

        train_dataset = cls._generate_dataset_if_possible(
            train_load_data_input,
            running_stage=RunningStage.TRAINING,
            data_pipeline=data_pipeline,
            use_iterable_auto_dataset=use_iterable_auto_dataset,
        )
        val_dataset = cls._generate_dataset_if_possible(
            val_load_data_input,
            running_stage=RunningStage.VALIDATING,
            data_pipeline=data_pipeline,
            use_iterable_auto_dataset=use_iterable_auto_dataset,
        )
        test_dataset = cls._generate_dataset_if_possible(
            test_load_data_input,
            running_stage=RunningStage.TESTING,
            data_pipeline=data_pipeline,
            use_iterable_auto_dataset=use_iterable_auto_dataset,
        )
        predict_dataset = cls._generate_dataset_if_possible(
            predict_load_data_input,
            running_stage=RunningStage.PREDICTING,
            data_pipeline=data_pipeline,
            use_iterable_auto_dataset=use_iterable_auto_dataset,
        )

        if train_dataset is not None and (val_split is not None
                                          and val_dataset is None):
            train_dataset, val_dataset = cls._split_train_val(
                train_dataset, val_split)

        datamodule = cls(train_dataset=train_dataset,
                         val_dataset=val_dataset,
                         test_dataset=test_dataset,
                         predict_dataset=predict_dataset,
                         **kwargs)
        datamodule._preprocess = data_pipeline._preprocess_pipeline
        datamodule._postprocess = data_pipeline._postprocess_pipeline
        data_fetcher.attach_to_datamodule(datamodule)
        return datamodule
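
The ``val_split`` handling above only kicks in when a train input is given and no validation input is provided; a hypothetical call (``MyDataModule`` and ``train_files`` are placeholders):

# Hypothetical: with val_split set and no val input, part of the train dataset
# is split off via cls._split_train_val.
datamodule = MyDataModule.from_load_data_inputs(
    train_load_data_input=train_files,
    val_split=0.2,
    seed=42,
)
assert datamodule._preprocess is not None
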
Example #22
    def build_data_pipeline(
        self,
        data_source: Optional[str] = None,
        data_pipeline: Optional[DataPipeline] = None,
    ) -> Optional[DataPipeline]:
        """Build a :class:`.DataPipeline` incorporating available
        :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
        objects. These will be overridden in the following resolution order (lowest priority first):

        - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
        - :class:`.Task` defaults given to ``.Task.__init__``.
        - :class:`.Task` manual overrides by setting :py:attr:`~data_pipeline`.
        - :class:`.DataPipeline` passed to this method.

        Args:
            data_pipeline: Optional highest priority source of
                :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.

        Returns:
            The fully resolved :class:`.DataPipeline`.
        """
        old_data_source, preprocess, postprocess, serializer = None, None, None, None

        # Datamodule
        if self.datamodule is not None and getattr(
                self.datamodule, 'data_pipeline', None) is not None:
            old_data_source = getattr(self.datamodule.data_pipeline,
                                      'data_source', None)
            preprocess = getattr(self.datamodule.data_pipeline,
                                 '_preprocess_pipeline', None)
            postprocess = getattr(self.datamodule.data_pipeline,
                                  '_postprocess_pipeline', None)
            serializer = getattr(self.datamodule.data_pipeline, '_serializer',
                                 None)

        elif self.trainer is not None and hasattr(
                self.trainer, 'datamodule') and getattr(
                    self.trainer.datamodule, 'data_pipeline',
                    None) is not None:
            old_data_source = getattr(self.trainer.datamodule.data_pipeline,
                                      'data_source', None)
            preprocess = getattr(self.trainer.datamodule.data_pipeline,
                                 '_preprocess_pipeline', None)
            postprocess = getattr(self.trainer.datamodule.data_pipeline,
                                  '_postprocess_pipeline', None)
            serializer = getattr(self.trainer.datamodule.data_pipeline,
                                 '_serializer', None)
        else:
            # TODO: we should log with low severity level that we use defaults to create
            # `preprocess`, `postprocess` and `serializer`.
            pass

        # Defaults / task attributes
        preprocess, postprocess, serializer = Task._resolve(
            preprocess,
            postprocess,
            serializer,
            self._preprocess,
            self._postprocess,
            self.serializer,
        )

        # Datapipeline
        if data_pipeline is not None:
            preprocess, postprocess, serializer = Task._resolve(
                preprocess,
                postprocess,
                serializer,
                getattr(data_pipeline, '_preprocess_pipeline', None),
                getattr(data_pipeline, '_postprocess_pipeline', None),
                getattr(data_pipeline, '_serializer', None),
            )

        data_source = data_source or old_data_source

        if isinstance(data_source, str):
            if preprocess is None:
                # TODO: warn the user that we are not using the specified data source
                data_source = DataSource()
            else:
                data_source = preprocess.data_source_of_name(data_source)

        data_pipeline = DataPipeline(data_source, preprocess, postprocess,
                                     serializer)
        self._data_pipeline_state = data_pipeline.initialize(
            self._data_pipeline_state)
        return data_pipeline
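
In this newer variant, a string ``data_source`` is resolved by name through the preprocess; a hedged sketch using the "numpy" source seen in Example #2:

# Illustrative: assumes `model` is a task (e.g. SemanticSegmentation from
# Example #2) whose resolved preprocess registers a "numpy" data source.
pipeline = model.build_data_pipeline(data_source="numpy")
assert pipeline.data_source is not None  # looked up via preprocess.data_source_of_name
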
Example #23
def postprocess(self, postprocess: Postprocess) -> None:
    self.data_pipeline = DataPipeline(self.preprocess, postprocess)
    self._postprocess = postprocess