コード例 #1
0
def test_preprocessing_data_pipeline_no_running_stage(with_dataset):
    """An auto dataset built without a running stage cannot be indexed, and loads nothing until a stage is set.

    Assigning ``RunningStage.TRAINING`` afterwards must trigger the training ``load_data`` hook exactly once.
    """
    pipeline = DataPipeline(_AutoDatasetTestPreprocess(with_dataset))
    auto_dataset = pipeline._generate_auto_dataset(range(10), running_stage=None)

    # Iterating needs ``__len__``, which the dataset refuses to provide without a running stage.
    with pytest.raises(RuntimeError, match='`__len__` for `load_sample`'):
        for position in range(len(auto_dataset)):
            auto_dataset[position]

    # Nothing has been loaded yet: loading is triggered once the running stage is set.
    if not with_dataset:
        assert pipeline._preprocess_pipeline.load_sample_count == 0
        assert pipeline._preprocess_pipeline.load_data_count == 0
    else:
        assert not hasattr(auto_dataset, 'load_sample_was_called')
        assert not hasattr(auto_dataset, 'load_data_was_called')
        assert pipeline._preprocess_pipeline.load_sample_with_dataset_count == 0
        assert pipeline._preprocess_pipeline.load_data_with_dataset_count == 0

    auto_dataset.running_stage = RunningStage.TRAINING

    if not with_dataset:
        assert pipeline._preprocess_pipeline.train_load_data_count == 1
    else:
        assert pipeline._preprocess_pipeline.train_load_data_with_dataset_count == 1
        assert auto_dataset.train_load_data_was_called
コード例 #2
0
    def autogenerate_dataset(
        cls,
        data: Any,
        running_stage: RunningStage,
        whole_data_load_fn: Optional[Callable] = None,
        per_sample_load_fn: Optional[Callable] = None,
        data_pipeline: Optional[DataPipeline] = None,
    ) -> AutoDataset:
        """Build an ``AutoDataset`` for ``running_stage``.

        Any loading function that is not supplied is looked up on ``cls.preprocess_cls``. The name to look up is
        resolved by ``DataPipeline._resolve_function_hierarchy`` for the given stage.

        Args:
            data: Raw input handed to the dataset.
            running_stage: Stage the dataset is built for.
            whole_data_load_fn: Optional replacement for the resolved ``load_data`` hook.
            per_sample_load_fn: Optional replacement for the resolved ``load_sample`` hook.
            data_pipeline: Optional pipeline forwarded to the ``AutoDataset``.

        Returns:
            The constructed ``AutoDataset``.
        """

        def _lookup(hook_name: str) -> Callable:
            # ``cls.preprocess_cls`` is only touched when a hook actually has to be resolved.
            preprocess_cls = cls.preprocess_cls
            resolved_name = DataPipeline._resolve_function_hierarchy(
                hook_name, preprocess_cls, running_stage, Preprocess)
            return getattr(preprocess_cls, resolved_name)

        if whole_data_load_fn is None:
            whole_data_load_fn = _lookup('load_data')
        if per_sample_load_fn is None:
            per_sample_load_fn = _lookup('load_sample')

        return AutoDataset(
            data,
            whole_data_load_fn,
            per_sample_load_fn,
            data_pipeline,
            running_stage=running_stage,
        )
コード例 #3
0
def test_saving_with_serializers(tmpdir):
    """Serializer state attached through a ``DataPipeline`` survives a checkpoint round trip.

    After fitting with a ``Labels(["a", "b"])`` serializer, the reloaded model's preprocess carries a
    ``DataPipelineState`` whose ``ClassificationState`` holds the same labels.
    """
    ckpt_path = os.path.join(tmpdir, 'tmp.ckpt')

    class CustomModel(Task):

        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())

    labels_serializer = Labels(["a", "b"])
    model = CustomModel()
    trainer = Trainer(fast_dev_run=True)

    pipeline = DataPipeline(DefaultPreprocess(), serializer=labels_serializer)
    pipeline.initialize()
    model.data_pipeline = pipeline
    assert isinstance(model.preprocess, DefaultPreprocess)

    inputs = torch.arange(10, dtype=torch.float)
    targets = torch.arange(10, dtype=torch.float)
    trainer.fit(model, train_dataloader=DataLoader(list(zip(inputs, targets))))
    trainer.save_checkpoint(ckpt_path)

    restored = CustomModel.load_from_checkpoint(ckpt_path)
    pipeline_state = restored.preprocess._data_pipeline_state
    assert isinstance(pipeline_state, DataPipelineState)
    assert pipeline_state._state[ClassificationState] == ClassificationState(['a', 'b'])
コード例 #4
0
def test_is_overriden_recursive(tmpdir):
    """``DataPipeline._is_overriden_recursive`` handles prefixed and plain hook names and rejects unknown ones."""

    class TestPreprocess(DefaultPreprocess):

        def collate(self, *_):
            pass

        def val_collate(self, *_):
            pass

    preprocess = TestPreprocess()
    is_overriden = DataPipeline._is_overriden_recursive

    # ``val`` has its own override; ``train`` has none, yet is still reported as overridden
    # because the unprefixed ``collate`` is.
    for prefix in ("val", "train"):
        assert is_overriden("collate", preprocess, Preprocess, prefix=prefix)

    # Neither the prefixed nor the plain on-device hook is overridden.
    assert not is_overriden("per_batch_transform_on_device", preprocess, Preprocess, prefix="train")
    assert not is_overriden("per_batch_transform_on_device", preprocess, Preprocess)

    # Names that ``Preprocess`` does not define are rejected outright.
    with pytest.raises(MisconfigurationException, match="This function doesn't belong to the parent class"):
        is_overriden("chocolate", preprocess, Preprocess)
コード例 #5
0
def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor():
    """With ``CustomPreprocess``, requesting a worker preprocessor for training first makes validation and
    testing requests fail as mutually exclusive, while prediction is still accepted."""
    data_pipeline = DataPipeline(CustomPreprocess())

    data_pipeline.worker_preprocessor(RunningStage.TRAINING)

    for conflicting_stage in (RunningStage.VALIDATING, RunningStage.TESTING):
        with pytest.raises(MisconfigurationException, match="are mutual exclusive"):
            data_pipeline.worker_preprocessor(conflicting_stage)

    data_pipeline.worker_preprocessor(RunningStage.PREDICTING)
コード例 #6
0
    def generate_dataset(
        self,
        data: Optional[DATA_TYPE],
        running_stage: RunningStage,
    ) -> Optional[Union[AutoDataset, IterableAutoDataset]]:
        """Load ``data`` for ``running_stage`` and wrap the result in an auto dataset.

        ``data`` is treated as missing when it is ``None``, or when it is a non-empty sequence whose first
        element is ``None`` (e.g. ``(None, None)``). In that case nothing is loaded and ``None`` is returned.
        An empty sequence is not treated as missing; it is passed on to ``load_data`` like any other input.

        Args:
            data: Input handed to the (stage-resolved) ``load_data`` hook.
            running_stage: Stage the dataset is generated for.

        Returns:
            An ``AutoDataset`` when the loaded data has a length, otherwise an ``IterableAutoDataset``;
            ``None`` when ``data`` is treated as missing.
        """
        is_none = data is None

        # Guard against empty sequences: ``data[0]`` would raise ``IndexError`` on ``[]``, ``()`` or ``""``.
        if isinstance(data, Sequence) and len(data) > 0:
            is_none = data[0] is None

        if is_none:
            return None

        from flash.data.data_pipeline import DataPipeline

        mock_dataset = typing.cast(AutoDataset, MockDataset())
        with CurrentRunningStageFuncContext(running_stage, "load_data", self):
            # Use a stage-specific override (as resolved by ``_resolve_function_hierarchy``) when one exists.
            load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(
                self, DataPipeline._resolve_function_hierarchy(
                    "load_data",
                    self,
                    running_stage,
                    DataSource,
                )
            )
            parameters = signature(load_data).parameters
            # Only pass the mock dataset to hooks that accept a ``dataset`` argument.
            if len(parameters) > 1 and "dataset" in parameters:  # TODO: This was DATASET_KEY before
                data = load_data(data, mock_dataset)
            else:
                data = load_data(data)

        # Map-style dataset when the loaded data has a length, iterable-style otherwise.
        if has_len(data):
            dataset = AutoDataset(data, self, running_stage)
        else:
            dataset = IterableAutoDataset(data, self, running_stage)
        # Whatever ``mock_dataset.metadata`` holds after loading is copied onto the real dataset.
        dataset.__dict__.update(mock_dataset.metadata)
        return dataset
コード例 #7
0
def test_predict_numpy():
    """Predicting from a raw NumPy batch via the ``numpy`` data source yields a (196, 196) tensor first."""
    image_batch = np.ones((1, 3, 10, 20))
    model = SemanticSegmentation(2)
    pipeline = DataPipeline(preprocess=SemanticSegmentationPreprocess())

    predictions = model.predict(image_batch, data_source="numpy", data_pipeline=pipeline)

    first_prediction = predictions[0]
    assert isinstance(first_prediction, torch.Tensor)
    assert first_prediction.shape == (196, 196)
コード例 #8
0
def test_serialization_data_pipeline(tmpdir):
    """A model's ``data_pipeline`` round-trips through a checkpoint only once one has been attached."""
    model = CustomModel()

    ckpt_path = os.path.join(tmpdir, 'tmp.ckpt')
    checkpoint_callback = ModelCheckpoint(tmpdir, 'test.ckpt')
    trainer = Trainer(callbacks=[checkpoint_callback], max_epochs=1)
    inputs = torch.arange(10, dtype=torch.float)
    targets = torch.arange(10, dtype=torch.float)
    train_loader = DataLoader(list(zip(inputs, targets)))

    # Without a pipeline, nothing is stored and nothing comes back.
    trainer.fit(model, train_loader)
    assert model.data_pipeline is None
    trainer.save_checkpoint(ckpt_path)
    assert CustomModel.load_from_checkpoint(ckpt_path).data_pipeline is None

    # With a pipeline attached, it is restored on load.
    model.data_pipeline = DataPipeline(CustomPreprocess())
    trainer.fit(model, train_loader)
    assert model.data_pipeline
    assert isinstance(model.preprocess, CustomPreprocess)
    trainer.save_checkpoint(ckpt_path)
    restored = CustomModel.load_from_checkpoint(ckpt_path)
    assert restored.data_pipeline
    assert isinstance(restored.preprocess, CustomPreprocess)

    # Remove every checkpoint written into ``tmpdir``.
    ckpt_names = [name for name in os.listdir(tmpdir) if name.endswith('.ckpt')]
    for ckpt_name in ckpt_names:
        os.remove(os.path.join(tmpdir, ckpt_name))
コード例 #9
0
def test_autodataset_warning():
    """Giving ``AutoDataset`` explicit loaders together with a ``data_pipeline`` emits a ``UserWarning``."""
    expected_message = "``datapipeline`` is specified but load_sample and/or load_data are also specified"

    with pytest.warns(UserWarning, match=expected_message):
        AutoDataset(
            range(10),
            load_data=lambda x: x,
            load_sample=lambda x: x,
            data_pipeline=DataPipeline(),
        )
コード例 #10
0
def test_preprocessing_data_pipeline_with_running_stage(with_dataset):
    """With ``RunningStage.TRAINING`` set up front, ``train_load_data`` runs once and
    ``train_load_sample`` runs once per indexed item."""
    pipeline = DataPipeline(_AutoDatasetTestPreprocess(with_dataset))
    auto_dataset = pipeline._generate_auto_dataset(range(10), running_stage=RunningStage.TRAINING)

    assert len(auto_dataset) == 10

    # Touch every item so that the per-sample hook fires for each of them.
    for position in range(len(auto_dataset)):
        auto_dataset[position]

    if not with_dataset:
        assert pipeline._preprocess_pipeline.train_load_sample_count == len(auto_dataset)
        assert pipeline._preprocess_pipeline.train_load_data_count == 1
    else:
        assert auto_dataset.train_load_sample_was_called
        assert auto_dataset.train_load_data_was_called
        assert pipeline._preprocess_pipeline.train_load_sample_with_dataset_count == len(auto_dataset)
        assert pipeline._preprocess_pipeline.train_load_data_with_dataset_count == 1
コード例 #11
0
    def from_load_data_inputs(
        cls,
        train_load_data_input: Optional[Any] = None,
        val_load_data_input: Optional[Any] = None,
        test_load_data_input: Optional[Any] = None,
        predict_load_data_input: Optional[Any] = None,
        preprocess: Optional[Preprocess] = None,
        postprocess: Optional[Postprocess] = None,
        **kwargs,
    ) -> 'DataModule':
        """Generate a ``DataModule`` from per-stage ``load_data`` inputs through a ``DataPipeline``.

        Args:
            cls: ``DataModule`` subclass
            train_load_data_input: Data to be received by the ``train_load_data`` function from this ``Preprocess``
            val_load_data_input: Data to be received by the ``val_load_data`` function from this ``Preprocess``
            test_load_data_input: Data to be received by the ``test_load_data`` function from this ``Preprocess``
            predict_load_data_input: Data to be received by the ``predict_load_data`` function from this
                ``Preprocess``
            preprocess: ``Preprocess`` used to build the pipeline. When ``None``, the ``preprocess`` of a
                dataset-less ``cls(**kwargs)`` is used.
            postprocess: ``Postprocess`` used to build the pipeline. When ``None``, the ``postprocess`` of a
                dataset-less ``cls(**kwargs)`` is used.
            kwargs: Any extra arguments to instantiate the provided ``DataModule``

        Returns:
            A ``cls`` instance holding the generated datasets. Its ``_preprocess`` / ``_postprocess`` are set to
            the pipeline's components.
        """
        # Trick: a dataset-less ``cls(**kwargs)`` supplies whichever components were not passed in.
        if preprocess is not None or postprocess is not None:
            data_pipeline = DataPipeline(
                preprocess if preprocess is not None else cls(**kwargs).preprocess,
                postprocess if postprocess is not None else cls(**kwargs).postprocess,
            )
        else:
            data_pipeline = cls(**kwargs).data_pipeline

        # One dataset per running stage, generated in train / val / test / predict order.
        stage_inputs = (
            (RunningStage.TRAINING, train_load_data_input),
            (RunningStage.VALIDATING, val_load_data_input),
            (RunningStage.TESTING, test_load_data_input),
            (RunningStage.PREDICTING, predict_load_data_input),
        )
        train_dataset, val_dataset, test_dataset, predict_dataset = (
            cls._generate_dataset_if_possible(
                load_data_input,
                running_stage=stage,
                data_pipeline=data_pipeline,
            ) for stage, load_data_input in stage_inputs
        )

        datamodule = cls(train_dataset=train_dataset,
                         val_dataset=val_dataset,
                         test_dataset=test_dataset,
                         predict_dataset=predict_dataset,
                         **kwargs)
        datamodule._preprocess = data_pipeline._preprocess_pipeline
        datamodule._postprocess = data_pipeline._postprocess_pipeline
        return datamodule
コード例 #12
0
    def autogenerate_dataset(
        cls,
        data: Any,
        running_stage: RunningStage,
        whole_data_load_fn: Optional[Callable] = None,
        per_sample_load_fn: Optional[Callable] = None,
        data_pipeline: Optional[DataPipeline] = None,
        use_iterable_auto_dataset: bool = False,
    ) -> BaseAutoDataset:
        """Build a dataset for ``running_stage`` from ``data``.

        Loading functions that are not supplied are resolved on the preprocess of ``data_pipeline``, picking
        the stage-specific override through ``DataPipeline._resolve_function_hierarchy``.

        Args:
            data: Raw input handed to the dataset.
            running_stage: Stage the dataset is built for.
            whole_data_load_fn: Optional replacement for the resolved ``load_data`` hook.
            per_sample_load_fn: Optional replacement for the resolved ``load_sample`` hook.
            data_pipeline: Pipeline whose ``_preprocess_pipeline`` supplies the default hooks.
            use_iterable_auto_dataset: Build an ``IterableAutoDataset`` instead of a ``BaseAutoDataset``.

        Returns:
            The constructed ``IterableAutoDataset`` or ``BaseAutoDataset``.
        """
        preprocess = getattr(data_pipeline, '_preprocess_pipeline', None)

        def _resolve(hook_name: str) -> Callable:
            resolved_name = DataPipeline._resolve_function_hierarchy(
                hook_name, preprocess, running_stage, Preprocess)
            return getattr(preprocess, resolved_name)

        if whole_data_load_fn is None:
            whole_data_load_fn = _resolve('load_data')
        if per_sample_load_fn is None:
            per_sample_load_fn = _resolve('load_sample')

        dataset_cls = IterableAutoDataset if use_iterable_auto_dataset else BaseAutoDataset
        return dataset_cls(
            data,
            whole_data_load_fn,
            per_sample_load_fn,
            data_pipeline,
            running_stage=running_stage,
        )
コード例 #13
0
def test_detach_preprocessing_from_model(tmpdir):
    """Hooks are wrapped by ``on_train_dataloader`` and restored by ``on_fit_end``."""
    data_pipeline = DataPipeline(CustomPreprocess())
    model = CustomModel()
    model.data_pipeline = data_pipeline

    # Before any hook runs: the default collate and the model's own bound transfer method.
    assert model.train_dataloader().collate_fn == default_collate
    assert model.transfer_batch_to_device.__self__ == model

    model.on_train_dataloader()
    wrapped_collate = model.train_dataloader().collate_fn
    assert isinstance(wrapped_collate, _PreProcessor)
    assert isinstance(model.transfer_batch_to_device, _StageOrchestrator)

    # ``on_fit_end`` undoes the wrapping.
    model.on_fit_end()
    assert model.transfer_batch_to_device.__self__ == model
    assert model.train_dataloader().collate_fn == default_collate
コード例 #14
0
    def generate_dataset(
        self,
        data: Optional[DATA_TYPE],
        running_stage: RunningStage,
    ) -> Optional[Union[AutoDataset, IterableAutoDataset]]:
        """Generate a single dataset with the given input to :meth:`~flash.data.data_source.DataSource.load_data` for
        the given ``running_stage``.

        ``data`` is treated as missing when it is ``None``, or when it is a non-empty sequence whose first element
        is ``None`` (e.g. ``(None, None)``). In that case nothing is loaded and ``None`` is returned. An empty
        sequence is not treated as missing; it is passed on to ``load_data`` like any other input.

        Args:
            data: The input to :meth:`~flash.data.data_source.DataSource.load_data` to use to create the dataset.
            running_stage: The running_stage for this dataset.

        Returns:
            An :class:`~flash.data.auto_dataset.AutoDataset` when the loaded data has a length, otherwise an
            :class:`~flash.data.auto_dataset.IterableAutoDataset`; ``None`` when ``data`` is treated as missing.
        """
        is_none = data is None

        # Guard against empty sequences: ``data[0]`` would raise ``IndexError`` on ``[]``, ``()`` or ``""``.
        if isinstance(data, Sequence) and len(data) > 0:
            is_none = data[0] is None

        if is_none:
            return None

        from flash.data.data_pipeline import DataPipeline

        mock_dataset = typing.cast(AutoDataset, MockDataset())
        with CurrentRunningStageFuncContext(running_stage, "load_data", self):
            # Use a stage-specific override (as resolved by ``_resolve_function_hierarchy``) when one exists.
            load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(
                self,
                DataPipeline._resolve_function_hierarchy(
                    "load_data",
                    self,
                    running_stage,
                    DataSource,
                ))
            parameters = signature(load_data).parameters
            # Only pass the mock dataset to hooks that accept a ``dataset`` argument.
            if len(parameters) > 1 and "dataset" in parameters:  # TODO: This was DATASET_KEY before
                data = load_data(data, mock_dataset)
            else:
                data = load_data(data)

        # Map-style dataset when the loaded data has a length, iterable-style otherwise.
        if has_len(data):
            dataset = AutoDataset(data, self, running_stage)
        else:
            dataset = IterableAutoDataset(data, self, running_stage)
        # Whatever ``mock_dataset.metadata`` holds after loading is copied onto the real dataset.
        dataset.__dict__.update(mock_dataset.metadata)
        return dataset
コード例 #15
0
    def running_stage(self, running_stage: RunningStage) -> None:
        # Local imports; NOTE(review): presumably kept out of module scope to avoid a circular import.
        from flash.data.data_pipeline import DataPipeline  # noqa F811
        from flash.data.data_source import DataSource  # noqa F811 # TODO: something better than this

        self._running_stage = running_stage
        stage = self.running_stage
        source = self.data_source

        # Context used while loading individual samples for this stage.
        self._load_sample_context = CurrentRunningStageFuncContext(stage, "load_sample", source)

        # Rebind ``load_sample`` to the data source's hook for this stage (stage-specific override if present).
        hook_name = DataPipeline._resolve_function_hierarchy(
            'load_sample',
            source,
            stage,
            DataSource,
        )
        self.load_sample: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(source, hook_name)
コード例 #16
0
def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir):
    """Missing components fall back to the base ``Preprocess`` / ``Postprocess``, and assigning the pipeline
    to a model propagates whichever components it holds."""

    class SubPreprocess(Preprocess):
        pass

    class SubPostprocess(Postprocess):
        pass

    expected_pre_cls = SubPreprocess if use_preprocess else Preprocess
    expected_post_cls = SubPostprocess if use_postprocess else Postprocess

    data_pipeline = DataPipeline(
        SubPreprocess() if use_preprocess else None,
        SubPostprocess() if use_postprocess else None,
    )
    assert isinstance(data_pipeline._preprocess_pipeline, expected_pre_cls)
    assert isinstance(data_pipeline._postprocess_pipeline, expected_post_cls)

    model = CustomModel(Postprocess())
    model.data_pipeline = data_pipeline
    assert isinstance(model._preprocess, expected_pre_cls)
    assert isinstance(model._postprocess, expected_post_cls)
コード例 #17
0
    def data_pipeline(self) -> Optional[DataPipeline]:
        # 1. An explicitly attached pipeline always wins.
        if self._data_pipeline is not None:
            return self._data_pipeline

        # 2. Build one from the model's own components.
        # NOTE(review): this reads the ``preprocess`` / ``postprocess`` *properties*, not the ``_preprocess`` /
        # ``_postprocess`` attributes; confirm those properties never read ``data_pipeline``, or this recurses.
        if self.preprocess is not None or self.postprocess is not None:
            return DataPipeline(self.preprocess, self.postprocess)

        # 3. Fall back to the pipeline of the datamodule attached to the model.
        if self.datamodule is not None:
            datamodule_pipeline = getattr(self.datamodule, 'data_pipeline', None)
            if datamodule_pipeline is not None:
                return datamodule_pipeline

        # 4. Fall back to the pipeline of the trainer's datamodule.
        trainer = self.trainer
        if trainer is not None and hasattr(trainer, 'datamodule'):
            trainer_pipeline = getattr(trainer.datamodule, 'data_pipeline', None)
            if trainer_pipeline is not None:
                return trainer_pipeline

        # Nothing found: ``_data_pipeline`` was ``None`` at step 1.
        return self._data_pipeline
コード例 #18
0
def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir):
    """Missing components default to ``DefaultPreprocess`` / ``Postprocess``; custom ones reach the model."""

    class CustomModel(Task):

        def __init__(self, postprocess: Optional[Postprocess] = None):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
            self._postprocess = postprocess

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    class SubPreprocess(DefaultPreprocess):
        pass

    class SubPostprocess(Postprocess):
        pass

    data_pipeline = DataPipeline(
        preprocess=SubPreprocess() if use_preprocess else None,
        postprocess=SubPostprocess() if use_postprocess else None,
    )
    expected_pre_cls = SubPreprocess if use_preprocess else DefaultPreprocess
    expected_post_cls = SubPostprocess if use_postprocess else Postprocess
    assert isinstance(data_pipeline._preprocess_pipeline, expected_pre_cls)
    assert isinstance(data_pipeline._postprocess_pipeline, expected_post_cls)

    model = CustomModel(postprocess=Postprocess())
    model.data_pipeline = data_pipeline
    # TODO: ``data_pipeline._attach_to_model(model)`` should have the same effect as the assignment above,
    # but it currently does not.

    # Without a custom component, the model may hold either ``None`` or a base-class instance.
    model_pre = model._preprocess
    assert isinstance(model_pre, SubPreprocess) if use_preprocess else (
        model_pre is None or isinstance(model_pre, Preprocess))

    model_post = model._postprocess
    assert isinstance(model_post, SubPostprocess) if use_postprocess else (
        model_post is None or isinstance(model_post, Postprocess))
コード例 #19
0
def test_detach_preprocessing_from_model(tmpdir):
    """Hooks are wrapped by ``on_train_dataloader`` and restored by ``on_fit_end``."""

    class CustomModel(Task):

        def __init__(self, postprocess: Optional[Postprocess] = None):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
            self._postprocess = postprocess

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    data_pipeline = DataPipeline(preprocess=CustomPreprocess())
    model = CustomModel()
    model.data_pipeline = data_pipeline

    # Before any hook runs: the default collate and the model's own bound transfer method.
    assert model.train_dataloader().collate_fn == default_collate
    assert model.transfer_batch_to_device.__self__ == model

    model.on_train_dataloader()
    wrapped_collate = model.train_dataloader().collate_fn
    assert isinstance(wrapped_collate, _PreProcessor)
    assert isinstance(model.transfer_batch_to_device, _StageOrchestrator)

    # ``on_fit_end`` undoes the wrapping.
    model.on_fit_end()
    assert model.transfer_batch_to_device.__self__ == model
    assert model.train_dataloader().collate_fn == default_collate
コード例 #20
0
def test_attaching_datapipeline_to_model(tmpdir):
    """Check how a ``DataPipeline`` attaches to a model during fit, test and predict.

    Each ``on_<stage>_dataloader`` hook asserts the model's state before and after the parent hook runs:
    - afterwards, the dataloader's ``collate_fn`` is a ``_PreProcessor`` for that stage;
    - ``transfer_batch_to_device`` is wrapped in a ``_StageOrchestrator``;
    - for prediction, ``predict_step`` is wrapped too, mapped to a ``_PostProcessor``.
    ``on_fit_end`` checks that the collate functions, ``transfer_batch_to_device`` and ``predict_step`` are
    restored.
    """

    class SubPreprocess(DefaultPreprocess):
        pass

    preprocess = SubPreprocess()
    data_pipeline = DataPipeline(preprocess=preprocess)

    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1),
                             loss_fn=torch.nn.MSELoss())
            self._postprocess = Postprocess()

        def training_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def validation_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def test_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def val_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def test_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def predict_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    class TestModel(CustomModel):

        # Defaults for the flags set by the ``on_<stage>_dataloader`` hooks below. The names must match what
        # the hooks assign. Otherwise a hook that never runs shows up as an ``AttributeError`` at the end of
        # the test instead of a failed assertion.
        on_train_dataloader_called = False
        on_val_dataloader_called = False
        on_test_dataloader_called = False
        on_predict_dataloader_called = False

        def on_fit_start(self):
            # Remember the unwrapped ``predict_step`` so later hooks can check it is restored.
            assert self.predict_step.__self__ == self
            self._saved_predict_step = self.predict_step

        def _compare_pre_processor(self, p1, p2):
            p1_seq = p1.per_sample_transform
            p2_seq = p2.per_sample_transform
            assert p1_seq.pre_tensor_transform.func == p2_seq.pre_tensor_transform.func
            assert p1_seq.to_tensor_transform.func == p2_seq.to_tensor_transform.func
            assert p1_seq.post_tensor_transform.func == p2_seq.post_tensor_transform.func
            assert p1.collate_fn.func == p2.collate_fn.func
            assert p1.per_batch_transform.func == p2.per_batch_transform.func

        def _assert_stage_orchestrator_state(
                self,
                stage_mapping: Dict,
                current_running_stage: RunningStage,
                cls=_PreProcessor):
            assert isinstance(stage_mapping[current_running_stage], cls)
            assert stage_mapping[current_running_stage]

        def on_train_dataloader(self) -> None:
            current_running_stage = RunningStage.TRAINING
            self.on_train_dataloader_called = True
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device,
                                  _StageOrchestrator)
            super().on_train_dataloader()
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(
                collate_fn,
                self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            self._assert_stage_orchestrator_state(
                self.transfer_batch_to_device._stage_mapping,
                current_running_stage)

        def on_val_dataloader(self) -> None:
            current_running_stage = RunningStage.VALIDATING
            self.on_val_dataloader_called = True
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            # Unlike train/test, the orchestrator is expected to be in place already here.
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            super().on_val_dataloader()
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(
                collate_fn,
                self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            self._assert_stage_orchestrator_state(
                self.transfer_batch_to_device._stage_mapping,
                current_running_stage)

        def on_test_dataloader(self) -> None:
            current_running_stage = RunningStage.TESTING
            self.on_test_dataloader_called = True
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device,
                                  _StageOrchestrator)
            super().on_test_dataloader()
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(
                collate_fn,
                self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            self._assert_stage_orchestrator_state(
                self.transfer_batch_to_device._stage_mapping,
                current_running_stage)

        def on_predict_dataloader(self) -> None:
            current_running_stage = RunningStage.PREDICTING
            self.on_predict_dataloader_called = True
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step
            super().on_predict_dataloader()
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(
                collate_fn,
                self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            assert isinstance(self.predict_step, _StageOrchestrator)
            self._assert_stage_orchestrator_state(
                self.transfer_batch_to_device._stage_mapping,
                current_running_stage)
            self._assert_stage_orchestrator_state(
                self.predict_step._stage_mapping,
                current_running_stage,
                cls=_PostProcessor)

        def on_fit_end(self) -> None:
            super().on_fit_end()
            assert self.train_dataloader().collate_fn == default_collate
            assert self.val_dataloader().collate_fn == default_collate
            assert self.test_dataloader().collate_fn == default_collate
            assert self.predict_dataloader().collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device,
                                  _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step

    model = TestModel()
    model.data_pipeline = data_pipeline
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model)

    assert model.on_train_dataloader_called
    assert model.on_val_dataloader_called
    assert model.on_test_dataloader_called
    assert model.on_predict_dataloader_called
コード例 #21
0
    def from_load_data_inputs(
        cls,
        train_load_data_input: Optional[Any] = None,
        val_load_data_input: Optional[Any] = None,
        test_load_data_input: Optional[Any] = None,
        predict_load_data_input: Optional[Any] = None,
        data_fetcher: Optional[BaseDataFetcher] = None,
        preprocess: Optional[Preprocess] = None,
        postprocess: Optional[Postprocess] = None,
        use_iterable_auto_dataset: bool = False,
        seed: int = 42,
        val_split: Optional[float] = None,
        **kwargs,
    ) -> 'DataModule':
        """Generate a ``DataModule`` from per-stage ``load_data`` inputs through a ``DataPipeline``.

        Args:
            cls: ``DataModule`` subclass
            train_load_data_input: Data to be received by the ``train_load_data`` function
                from this :class:`~flash.data.process.Preprocess`
            val_load_data_input: Data to be received by the ``val_load_data`` function
                from this :class:`~flash.data.process.Preprocess`
            test_load_data_input: Data to be received by the ``test_load_data`` function
                from this :class:`~flash.data.process.Preprocess`
            predict_load_data_input: Data to be received by the ``predict_load_data`` function
                from this :class:`~flash.data.process.Preprocess`
            data_fetcher: Fetcher attached to the pipeline's preprocess and to the resulting datamodule.
                Defaults to ``cls.configure_data_fetcher()``.
            preprocess: ``Preprocess`` used to build the pipeline. When ``None``, the ``preprocess`` of a
                dataset-less ``cls(**kwargs)`` is used.
            postprocess: ``Postprocess`` used to build the pipeline. When ``None``, the ``postprocess`` of a
                dataset-less ``cls(**kwargs)`` is used.
            use_iterable_auto_dataset: Forwarded to ``_generate_dataset_if_possible`` for every stage.
            seed: NOTE(review): accepted but not used anywhere in this method. Confirm whether it is meant
                to reach ``_split_train_val``.
            val_split: When set and no validation dataset was generated, ``cls._split_train_val`` derives the
                train and validation datasets from the training dataset using this value.
            kwargs: Any extra arguments to instantiate the provided ``DataModule``

        Returns:
            A ``cls`` instance holding the generated datasets. Its ``_preprocess`` / ``_postprocess`` are set to
            the pipeline's components, and the data fetcher is attached to it.
        """
        # Trick: a dataset-less ``cls(**kwargs)`` supplies whichever components were not passed in.
        if preprocess is not None or postprocess is not None:
            data_pipeline = DataPipeline(
                preprocess if preprocess is not None else cls(**kwargs).preprocess,
                postprocess if postprocess is not None else cls(**kwargs).postprocess,
            )
        else:
            data_pipeline = cls(**kwargs).data_pipeline

        if data_fetcher is None:
            data_fetcher = cls.configure_data_fetcher()

        data_fetcher.attach_to_preprocess(data_pipeline._preprocess_pipeline)

        # One dataset per running stage, generated in train / val / test / predict order.
        stage_inputs = (
            (RunningStage.TRAINING, train_load_data_input),
            (RunningStage.VALIDATING, val_load_data_input),
            (RunningStage.TESTING, test_load_data_input),
            (RunningStage.PREDICTING, predict_load_data_input),
        )
        train_dataset, val_dataset, test_dataset, predict_dataset = (
            cls._generate_dataset_if_possible(
                load_data_input,
                running_stage=stage,
                data_pipeline=data_pipeline,
                use_iterable_auto_dataset=use_iterable_auto_dataset,
            ) for stage, load_data_input in stage_inputs
        )

        # Split a validation set off the training data only when none was generated explicitly.
        if train_dataset is not None and val_split is not None and val_dataset is None:
            train_dataset, val_dataset = cls._split_train_val(train_dataset, val_split)

        datamodule = cls(train_dataset=train_dataset,
                         val_dataset=val_dataset,
                         test_dataset=test_dataset,
                         predict_dataset=predict_dataset,
                         **kwargs)
        datamodule._preprocess = data_pipeline._preprocess_pipeline
        datamodule._postprocess = data_pipeline._postprocess_pipeline
        data_fetcher.attach_to_datamodule(datamodule)
        return datamodule
コード例 #22
0
class Task(LightningModule):
    """A general Task.

    Args:
        model: Model to use for the task.
        loss_fn: Loss function for training
        optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
        metrics: Metrics to compute for training and evaluation.
        learning_rate: Learning rate to use for training, defaults to `5e-5`
    """
    def __init__(
        self,
        model: Optional[nn.Module] = None,
        loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
        metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
        learning_rate: float = 5e-5,
    ):
        super().__init__()
        if model is not None:
            self.model = model
        # Both are normalized by ``get_callable_dict`` into dicts that ``step`` iterates as
        # ``name -> callable``.
        self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
        self.optimizer_cls = optimizer
        self.metrics = nn.ModuleDict(
            {} if metrics is None else get_callable_dict(metrics))
        self.learning_rate = learning_rate
        # TODO: should we save more? Bug on some regarding yaml if we save metrics
        self.save_hyperparameters("learning_rate", "optimizer")

        self._data_pipeline = None
        self._preprocess = None
        self._postprocess = None

    def step(self, batch: Any, batch_idx: int) -> Any:
        """
        The training/validation/test step. Override for custom behavior.

        Returns:
            A dict with keys ``"y_hat"`` (model output), ``"loss"``, ``"logs"`` and ``"y"``.
            When several loss functions are configured, ``"loss"`` is their sum, which is also
            logged as ``"total_loss"``. With no loss function configured this raises
            ``IndexError``.
        """
        x, y = batch
        y_hat = self(x)
        output = {"y_hat": y_hat}
        losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
        logs = {}
        for name, metric in self.metrics.items():
            if isinstance(metric, torchmetrics.metric.Metric):
                metric(y_hat, y)
                # log the metric itself if it is of type Metric
                logs[name] = metric
            else:
                logs[name] = metric(y_hat, y)
        logs.update(losses)
        # Always return the same dict shape: the *_step methods index ``output["loss"]`` and
        # ``output["logs"]``, so a bare ``(loss, logs)`` tuple would break them.
        if len(losses) > 1:
            logs["total_loss"] = sum(losses.values())
            output["loss"] = logs["total_loss"]
        else:
            output["loss"] = list(losses.values())[0]
        output["logs"] = logs
        output["y"] = y
        return output

    def forward(self, x: Any) -> Any:
        return self.model(x)

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        """Run ``step``, log its entries with a ``train_`` prefix and return the loss."""
        output = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v
                       for k, v in output["logs"].items()},
                      on_step=True,
                      on_epoch=True,
                      prog_bar=True)
        return output["loss"]

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        """Run ``step`` and log its entries per epoch with a ``val_`` prefix."""
        output = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v
                       for k, v in output["logs"].items()},
                      on_step=False,
                      on_epoch=True,
                      prog_bar=True)

    def test_step(self, batch: Any, batch_idx: int) -> None:
        """Run ``step`` and log its entries per epoch with a ``test_`` prefix."""
        output = self.step(batch, batch_idx)
        self.log_dict({f"test_{k}": v
                       for k, v in output["logs"].items()},
                      on_step=False,
                      on_epoch=True,
                      prog_bar=True)

    @predict_context
    def predict(
        self,
        x: Any,
        data_pipeline: Optional[DataPipeline] = None,
    ) -> Any:
        """
        Predict function for raw data or processed data

        Args:
            x: Input to predict. Can be raw data or processed data. If str, assumed to be a folder of data.

            data_pipeline: Use this to override the current data pipeline

        Returns:
            The post-processed model predictions
        """
        running_stage = RunningStage.PREDICTING
        data_pipeline = data_pipeline or self.data_pipeline
        # Iterate rather than call ``list(...)``: ``list`` would query ``__len__`` as a
        # length hint, which an auto-generated dataset may not support.
        x = [sample for sample in data_pipeline._generate_auto_dataset(x, running_stage)]
        x = data_pipeline.worker_preprocessor(running_stage)(x)
        x = self.transfer_batch_to_device(x, self.device)
        x = data_pipeline.device_preprocessor(running_stage)(x)
        predictions = self.predict_step(
            x, 0)  # batch_idx is always 0 when running with `model.predict`
        predictions = data_pipeline.postprocessor(predictions)
        return predictions

    def predict_step(self,
                     batch: Any,
                     batch_idx: int,
                     dataloader_idx: int = 0) -> Any:
        """Forward the inputs of ``batch``: first element of a tuple, or a stacked list."""
        if isinstance(batch, tuple):
            batch = batch[0]
        elif isinstance(batch, list):
            # Todo: Understand why stack is needed
            batch = torch.stack(batch)
        return self(batch)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Build ``optimizer_cls`` over trainable parameters only, using ``learning_rate``."""
        return self.optimizer_cls(filter(lambda p: p.requires_grad,
                                         self.parameters()),
                                  lr=self.learning_rate)

    def configure_finetune_callback(self) -> List[Callback]:
        return []

    @property
    def preprocess(self) -> Optional[Preprocess]:
        """The data pipeline's preprocess if it has one, else the locally stored one."""
        return getattr(self._data_pipeline, '_preprocess_pipeline',
                       None) or self._preprocess

    @preprocess.setter
    def preprocess(self, preprocess: Preprocess) -> None:
        self._preprocess = preprocess
        self.data_pipeline = DataPipeline(preprocess, self.postprocess)

    @property
    def postprocess(self) -> Postprocess:
        """The data pipeline's postprocess if it has one, else the locally stored one."""
        return getattr(self._data_pipeline, '_postprocess_pipeline',
                       None) or self._postprocess

    @postprocess.setter
    def postprocess(self, postprocess: Postprocess) -> None:
        self.data_pipeline = DataPipeline(self.preprocess, postprocess)
        self._postprocess = postprocess

    @property
    def data_pipeline(self) -> Optional[DataPipeline]:
        """Resolve the data pipeline, in priority order: explicitly set pipeline, one built
        from this task's preprocess/postprocess, the attached datamodule's pipeline, then the
        trainer's datamodule pipeline. Returns ``None`` if none is available."""
        if self._data_pipeline is not None:
            return self._data_pipeline

        elif self.preprocess is not None or self.postprocess is not None:
            # ``preprocess``/``postprocess`` read ``self._data_pipeline`` directly rather than
            # this property, so this does not recurse.
            return DataPipeline(self.preprocess, self.postprocess)

        elif self.datamodule is not None and getattr(
                self.datamodule, 'data_pipeline', None) is not None:
            return self.datamodule.data_pipeline

        elif self.trainer is not None and hasattr(
                self.trainer, 'datamodule') and getattr(
                    self.trainer.datamodule, 'data_pipeline',
                    None) is not None:
            return self.trainer.datamodule.data_pipeline

        return self._data_pipeline

    @data_pipeline.setter
    def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None:
        self._data_pipeline = data_pipeline
        if data_pipeline is not None and getattr(
                data_pipeline, '_preprocess_pipeline', None) is not None:
            self._preprocess = data_pipeline._preprocess_pipeline

        if data_pipeline is not None and getattr(
                data_pipeline, '_postprocess_pipeline', None) is not None:
            # Only keep a postprocess that is a subclass instance, not the bare base class.
            if type(data_pipeline._postprocess_pipeline) is not Postprocess:
                self._postprocess = data_pipeline._postprocess_pipeline

    def on_train_dataloader(self) -> None:
        if self.data_pipeline is not None:
            self.data_pipeline._detach_from_model(self, RunningStage.TRAINING)
            self.data_pipeline._attach_to_model(self, RunningStage.TRAINING)
        super().on_train_dataloader()

    def on_val_dataloader(self) -> None:
        if self.data_pipeline is not None:
            self.data_pipeline._detach_from_model(self,
                                                  RunningStage.VALIDATING)
            self.data_pipeline._attach_to_model(self, RunningStage.VALIDATING)
        super().on_val_dataloader()

    def on_test_dataloader(self, *_) -> None:
        if self.data_pipeline is not None:
            self.data_pipeline._detach_from_model(self, RunningStage.TESTING)
            self.data_pipeline._attach_to_model(self, RunningStage.TESTING)
        super().on_test_dataloader()

    def on_predict_dataloader(self) -> None:
        if self.data_pipeline is not None:
            self.data_pipeline._detach_from_model(self,
                                                  RunningStage.PREDICTING)
            self.data_pipeline._attach_to_model(self, RunningStage.PREDICTING)
        super().on_predict_dataloader()

    def on_predict_end(self) -> None:
        if self.data_pipeline is not None:
            self.data_pipeline._detach_from_model(self)
        super().on_predict_end()

    def on_fit_end(self) -> None:
        if self.data_pipeline is not None:
            self.data_pipeline._detach_from_model(self)
        super().on_fit_end()

    @staticmethod
    def _sanetize_funcs(obj: Any) -> Any:
        """Replace each callable in ``obj.__dict__`` with ``inspect.unwrap`` of it (following
        ``__wrapped__`` chains) and return ``obj``, mutated in place."""
        if hasattr(obj, "__dict__"):
            for k, v in obj.__dict__.items():
                if isinstance(v, Callable):
                    obj.__dict__[k] = inspect.unwrap(v)
        return obj

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # TODO: Is this the best way to do this? or should we also use some kind of hparams here?
        # This may be an issue since here we create the same problems with pickle as in
        # https://pytorch.org/docs/stable/notes/serialization.html
        if self.data_pipeline is not None and 'data_pipeline' not in checkpoint:
            self._preprocess = self._sanetize_funcs(self._preprocess)
            checkpoint['data_pipeline'] = self.data_pipeline
            # todo (tchaton) re-wrap visualization
        super().on_save_checkpoint(checkpoint)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        super().on_load_checkpoint(checkpoint)
        if 'data_pipeline' in checkpoint:
            self.data_pipeline = checkpoint['data_pipeline']
コード例 #23
0
 def preprocess(self, preprocess: Preprocess) -> None:
     """Store ``preprocess`` and rebuild ``data_pipeline`` around it.

     The current ``postprocess`` is carried over into the new pipeline.
     """
     self._preprocess = preprocess
     rebuilt_pipeline = DataPipeline(preprocess, self.postprocess)
     self.data_pipeline = rebuilt_pipeline
コード例 #24
0
 def data_pipeline(self) -> DataPipeline:
     """Return a new ``DataPipeline`` built from the current preprocess and postprocess."""
     pre = self.preprocess
     post = self.postprocess
     return DataPipeline(pre, post)
コード例 #25
0
def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir):
    """Stage-prefixed preprocess hooks must take precedence over the generic ones.

    ``CustomPreprocess`` defines a mix of generic and ``train_``/``val_``/``test_``/``predict_``
    hooks. The test checks the hook names returned by
    ``DataPipeline._resolve_function_hierarchy`` for every running stage, then the callables
    wired into each stage's worker preprocessor.
    """

    class CustomPreprocess(Preprocess):

        def load_data(self, *_, **__):
            pass

        def test_load_data(self, *_, **__):
            pass

        def predict_load_data(self, *_, **__):
            pass

        def predict_load_sample(self, *_, **__):
            pass

        def val_load_sample(self, *_, **__):
            pass

        def val_pre_tensor_transform(self, *_, **__):
            pass

        def predict_to_tensor_transform(self, *_, **__):
            pass

        def train_post_tensor_transform(self, *_, **__):
            pass

        def test_collate(self, *_, **__):
            pass

        def val_per_sample_transform_on_device(self, *_, **__):
            pass

        def train_per_batch_transform_on_device(self, *_, **__):
            pass

        def test_per_batch_transform_on_device(self, *_, **__):
            pass

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess)

    def resolve_names(stage):
        # hook name -> method name resolved on the preprocess for ``stage``
        return {
            hook: data_pipeline._resolve_function_hierarchy(
                hook, data_pipeline._preprocess_pipeline, stage, Preprocess
            )
            for hook in data_pipeline.PREPROCESS_FUNCS
        }

    stages = (
        RunningStage.TRAINING,
        RunningStage.VALIDATING,
        RunningStage.TESTING,
        RunningStage.PREDICTING,
    )
    resolved_per_stage = [resolve_names(stage) for stage in stages]

    # Expected resolved name for each hook, ordered (train, val, test, predict).
    expected_names = {
        "load_data": (
            "load_data", "load_data", "test_load_data", "predict_load_data",
        ),
        "load_sample": (
            "load_sample", "val_load_sample", "load_sample", "predict_load_sample",
        ),
        "pre_tensor_transform": (
            "pre_tensor_transform", "val_pre_tensor_transform",
            "pre_tensor_transform", "pre_tensor_transform",
        ),
        "to_tensor_transform": (
            "to_tensor_transform", "to_tensor_transform",
            "to_tensor_transform", "predict_to_tensor_transform",
        ),
        "post_tensor_transform": (
            "train_post_tensor_transform", "post_tensor_transform",
            "post_tensor_transform", "post_tensor_transform",
        ),
        "collate": (
            "collate", "collate", "test_collate", "collate",
        ),
        "per_sample_transform_on_device": (
            "per_sample_transform_on_device", "val_per_sample_transform_on_device",
            "per_sample_transform_on_device", "per_sample_transform_on_device",
        ),
        "per_batch_transform_on_device": (
            "train_per_batch_transform_on_device", "per_batch_transform_on_device",
            "test_per_batch_transform_on_device", "per_batch_transform_on_device",
        ),
    }
    for hook, per_stage in expected_names.items():
        for stage_names, expected in zip(resolved_per_stage, per_stage):
            assert stage_names[hook] == expected

    train_worker, val_worker, test_worker, predict_worker = (
        data_pipeline.worker_preprocessor(stage) for stage in stages
    )

    # (worker, pre_tensor_transform, to_tensor_transform, post_tensor_transform, collate_fn)
    expected_workers = (
        (train_worker, preprocess.pre_tensor_transform, preprocess.to_tensor_transform,
         preprocess.train_post_tensor_transform, default_collate),
        (val_worker, preprocess.val_pre_tensor_transform, preprocess.to_tensor_transform,
         preprocess.post_tensor_transform, data_pipeline._identity),
        (test_worker, preprocess.pre_tensor_transform, preprocess.to_tensor_transform,
         preprocess.post_tensor_transform, preprocess.test_collate),
        (predict_worker, preprocess.pre_tensor_transform, preprocess.predict_to_tensor_transform,
         preprocess.post_tensor_transform, default_collate),
    )
    for worker, pre_fn, to_fn, post_fn, collate in expected_workers:
        per_sample = worker.per_sample_transform
        assert per_sample.pre_tensor_transform.func == pre_fn
        assert per_sample.to_tensor_transform.func == to_fn
        assert per_sample.post_tensor_transform.func == post_fn
        assert worker.collate_fn.func == collate
        assert worker.per_batch_transform.func == preprocess.per_batch_transform
コード例 #26
0
def test_attaching_datapipeline_to_model(tmpdir):
    """The data pipeline is attached to the model per stage and detached at the end.

    Each ``on_*_dataloader`` hook checks the dataloader ``collate_fn`` and the model's
    ``transfer_batch_to_device`` / ``predict_step`` before and after the parent hook runs.
    ``on_fit_end`` checks that detaching restores the originals.
    """

    preprocess = TestPreprocess()
    data_pipeline = DataPipeline(preprocess)

    class TestModel(CustomModel):

        stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]
        # Defaults for the flags the hooks below set and the test asserts at the end. A hook
        # that never runs then fails an assertion instead of raising AttributeError.
        on_train_dataloader_called = False
        on_val_dataloader_called = False
        on_test_dataloader_called = False
        on_predict_dataloader_called = False

        def on_fit_start(self):
            assert self.predict_step.__self__ == self
            self._saved_predict_step = self.predict_step

        def _compare_pre_processor(self, p1, p2):
            # Both preprocessors must wrap the same per-sample, collate and per-batch callables.
            p1_seq = p1.per_sample_transform
            p2_seq = p2.per_sample_transform
            assert p1_seq.pre_tensor_transform.func == p2_seq.pre_tensor_transform.func
            assert p1_seq.to_tensor_transform.func == p2_seq.to_tensor_transform.func
            assert p1_seq.post_tensor_transform.func == p2_seq.post_tensor_transform.func
            assert p1.collate_fn.func == p2.collate_fn.func
            assert p1.per_batch_transform.func == p2.per_batch_transform.func

        def _assert_stage_orchestrator_state(
            self, stage_mapping: Dict, current_running_stage: RunningStage, cls=_PreProcessor
        ):
            assert isinstance(stage_mapping[current_running_stage], cls)
            assert stage_mapping[current_running_stage]

        def on_train_dataloader(self) -> None:
            current_running_stage = RunningStage.TRAINING
            self.on_train_dataloader_called = True
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_train_dataloader()
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_val_dataloader(self) -> None:
            current_running_stage = RunningStage.VALIDATING
            self.on_val_dataloader_called = True
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            # still wrapped from the training stage at this point
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_val_dataloader()
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_test_dataloader(self) -> None:
            current_running_stage = RunningStage.TESTING
            self.on_test_dataloader_called = True
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            # ``on_fit_end`` detached the pipeline before testing starts
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            super().on_test_dataloader()
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)

        def on_predict_dataloader(self) -> None:
            current_running_stage = RunningStage.PREDICTING
            self.on_predict_dataloader_called = True
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            # still wrapped from the testing stage at this point
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step
            super().on_predict_dataloader()
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(collate_fn, self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            assert isinstance(self.predict_step, _StageOrchestrator)
            self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)
            self._assert_stage_orchestrator_state(
                self.predict_step._stage_mapping, current_running_stage, cls=_PostProcessor
            )

        def on_fit_end(self) -> None:
            super().on_fit_end()
            assert self.train_dataloader().collate_fn == default_collate
            assert self.val_dataloader().collate_fn == default_collate
            assert self.test_dataloader().collate_fn == default_collate
            assert self.predict_dataloader().collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device, _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step

    datamodule = CustomDataModule()
    datamodule._data_pipeline = data_pipeline
    model = TestModel()
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model)
    trainer.predict(model)

    assert model.on_train_dataloader_called
    assert model.on_val_dataloader_called
    assert model.on_test_dataloader_called
    assert model.on_predict_dataloader_called
コード例 #27
0
 def postprocess(self, postprocess: Postprocess) -> None:
     """Rebuild ``data_pipeline`` with ``postprocess``, then store it locally.

     The current ``preprocess`` is carried over into the new pipeline.
     """
     rebuilt_pipeline = DataPipeline(self.preprocess, postprocess)
     self.data_pipeline = rebuilt_pipeline
     self._postprocess = postprocess
コード例 #28
0
    def build_data_pipeline(
        self,
        data_source: Optional[str] = None,
        data_pipeline: Optional[DataPipeline] = None,
    ) -> Optional[DataPipeline]:
        """Build a :class:`.DataPipeline` incorporating available
        :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
        objects. These will be overridden in the following resolution order (lowest priority first):

        - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
        - :class:`.Task` defaults given to ``.Task.__init__``.
        - :class:`.Task` manual overrides by setting :py:attr:`~data_pipeline`.
        - :class:`.DataPipeline` passed to this method.

        Args:
            data_source: Optional data source to use. A ``str`` is resolved by name through
                ``preprocess.data_source_of_name``. If no preprocess is available it is replaced
                by a plain ``DataSource()``. When omitted (or falsy), the datamodule
                pipeline's ``data_source`` is used, if any.
            data_pipeline: Optional highest priority source of
                :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.

        Returns:
            The fully resolved :class:`.DataPipeline`. As a side effect,
            ``self._data_pipeline_state`` is replaced by the result of
            ``data_pipeline.initialize(self._data_pipeline_state)``.
        """

        def _components(pipeline):
            # (data_source, preprocess, postprocess, serializer) of ``pipeline``; the getattr
            # defaults also make this return all ``None`` when ``pipeline`` is ``None``.
            return (
                getattr(pipeline, 'data_source', None),
                getattr(pipeline, '_preprocess_pipeline', None),
                getattr(pipeline, '_postprocess_pipeline', None),
                getattr(pipeline, '_serializer', None),
            )

        # Datamodule: the one attached to the Task wins over the one attached to the Trainer.
        # ``data_pipeline`` is read once so every component comes from the same object.
        datamodule_pipeline = None
        if self.datamodule is not None:
            datamodule_pipeline = getattr(self.datamodule, 'data_pipeline', None)
        if (datamodule_pipeline is None and self.trainer is not None
                and hasattr(self.trainer, 'datamodule')):
            datamodule_pipeline = getattr(self.trainer.datamodule, 'data_pipeline', None)
        # TODO: when no datamodule pipeline is found, we should log with low severity level
        # that we use defaults to create `preprocess`, `postprocess` and `serializer`.
        old_data_source, preprocess, postprocess, serializer = _components(datamodule_pipeline)

        # Defaults / task attributes
        preprocess, postprocess, serializer = Task._resolve(
            preprocess,
            postprocess,
            serializer,
            self._preprocess,
            self._postprocess,
            self.serializer,
        )

        # Datapipeline passed to this method (highest priority)
        if data_pipeline is not None:
            preprocess, postprocess, serializer = Task._resolve(
                preprocess,
                postprocess,
                serializer,
                getattr(data_pipeline, '_preprocess_pipeline', None),
                getattr(data_pipeline, '_postprocess_pipeline', None),
                getattr(data_pipeline, '_serializer', None),
            )

        data_source = data_source or old_data_source

        if isinstance(data_source, str):
            if preprocess is None:
                # TODO: warn the user that we are not using the specified data source
                data_source = DataSource()
            else:
                data_source = preprocess.data_source_of_name(data_source)

        data_pipeline = DataPipeline(data_source, preprocess, postprocess,
                                     serializer)
        self._data_pipeline_state = data_pipeline.initialize(
            self._data_pipeline_state)
        return data_pipeline
コード例 #29
0
def test_preprocess_transforms(tmpdir):
    """Check how a ``Preprocess`` validates and interprets dict-based transforms.

    Covers:

    * malformed ``*_transform`` arguments (not a dict, or a dict with an
      unknown hook name) raise ``MisconfigurationException``;
    * ``_<stage>_collate_in_worker_from_transform`` is derived from the
      transform dict given for that stage, and stays ``None`` for stages
      that received no dict;
    * ``DataPipeline.worker_preprocessor`` selects ``default_collate`` or
      ``DataPipeline._identity`` as the worker collate function accordingly,
      and raises for stages where ``per_batch_transform`` and
      ``per_sample_transform_on_device`` conflict.

    ``tmpdir`` is requested but not used by this test.
    """

    # A transform argument must be a dict ...
    with pytest.raises(MisconfigurationException,
                       match="Transform should be a dict."):
        DefaultPreprocess(train_transform="choco")

    # ... and its keys must be known hook names.
    with pytest.raises(MisconfigurationException,
                       match="train_transform contains {'choco'}. Only"):
        DefaultPreprocess(train_transform={"choco": None})

    preprocess = DefaultPreprocess(
        train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)})
    # Only the stage that received a transform dict gets a flag; every other
    # stage is left undecided (None).
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is None

    # A single stage may not define both of these hooks at once. The
    # constructor is expected to raise, so its result is not bound.
    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        DefaultPreprocess(
            train_transform={
                "per_batch_transform": torch.nn.Linear(1, 1),
                "per_sample_transform_on_device": torch.nn.Linear(1, 1)
            })

    preprocess = DefaultPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={
            "per_sample_transform_on_device": torch.nn.Linear(1, 1)
        })
    # `per_batch_transform` yields True, `per_sample_transform_on_device`
    # yields False; stages without a transform dict stay None.
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    train_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.TRAINING)
    val_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(
        RunningStage.VALIDATING)
    test_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.TESTING)
    predict_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.PREDICTING)

    # Collation stays in the worker except for the stage whose flag is False,
    # where the worker collate function becomes the identity.
    assert train_preprocessor.collate_fn.func == default_collate
    assert val_preprocessor.collate_fn.func == default_collate
    assert test_preprocessor.collate_fn.func == default_collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity

    class CustomPreprocess(DefaultPreprocess):
        def per_sample_transform_on_device(self, sample: Any) -> Any:
            return super().per_sample_transform_on_device(sample)

        def per_batch_transform(self, batch: Any) -> Any:
            return super().per_batch_transform(batch)

    preprocess = CustomPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={
            "per_sample_transform_on_device": torch.nn.Linear(1, 1)
        })
    # Overriding both hooks on the class does not change the flags derived
    # from the transform dicts.
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    data_pipeline = DataPipeline(preprocess=preprocess)

    train_preprocessor = data_pipeline.worker_preprocessor(
        RunningStage.TRAINING)
    # Stages with no transform dict to disambiguate (val/test) presumably
    # see both overridden hooks as a conflict, so building their worker
    # preprocessor is expected to raise; the results are not bound.
    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        data_pipeline.worker_preprocessor(RunningStage.VALIDATING)
    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        data_pipeline.worker_preprocessor(RunningStage.TESTING)
    predict_preprocessor = data_pipeline.worker_preprocessor(
        RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == default_collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity
コード例 #30
0
ファイル: model.py プロジェクト: hahaxun/lightning-flash
    def build_data_pipeline(
        self,
        data_pipeline: Optional[DataPipeline] = None
    ) -> Optional[DataPipeline]:
        """Assemble and initialize a :class:`.DataPipeline` for this task.

        The :class:`~flash.data.process.Preprocess`,
        :class:`~flash.data.process.Postprocess` and serializer are gathered
        from the following sources and combined with ``Task._resolve``,
        listed from lowest to highest priority:

        - Lightning ``Datamodule``: the one attached to the :class:`.Task` is
          consulted first; the one attached to the :class:`.Trainer` is used
          only if the task's own datamodule has no ``data_pipeline``.
        - :class:`.Task` defaults given to ``.Task.__init__``.
        - :class:`.Task` manual overrides by setting :py:attr:`~data_pipeline`.
        - :class:`.DataPipeline` passed to this method.

        Args:
            data_pipeline: Optional highest priority source of
                :class:`~flash.data.process.Preprocess`,
                :class:`~flash.data.process.Postprocess` and serializer.

        Returns:
            A newly constructed :class:`.DataPipeline` on which
            ``initialize()`` has already been called (never ``None``).
        """
        # Pick the datamodule that exposes a usable ``data_pipeline``, if any.
        source = None
        if (self.datamodule is not None
                and getattr(self.datamodule, 'data_pipeline', None) is not None):
            source = self.datamodule
        elif (self.trainer is not None
              and hasattr(self.trainer, 'datamodule')
              and getattr(self.trainer.datamodule, 'data_pipeline',
                          None) is not None):
            source = self.trainer.datamodule

        if source is None:
            preprocess = postprocess = serializer = None
        else:
            preprocess = getattr(source.data_pipeline,
                                 '_preprocess_pipeline', None)
            postprocess = getattr(source.data_pipeline,
                                  '_postprocess_pipeline', None)
            serializer = getattr(source.data_pipeline, '_serializer', None)

        # Task-level attributes take precedence over the datamodule's.
        preprocess, postprocess, serializer = Task._resolve(
            preprocess, postprocess, serializer,
            self._preprocess, self._postprocess, self.serializer,
        )

        # An explicitly passed pipeline has the final say.
        if data_pipeline is not None:
            preprocess, postprocess, serializer = Task._resolve(
                preprocess, postprocess, serializer,
                getattr(data_pipeline, '_preprocess_pipeline', None),
                getattr(data_pipeline, '_postprocess_pipeline', None),
                getattr(data_pipeline, '_serializer', None),
            )

        resolved_pipeline = DataPipeline(preprocess, postprocess, serializer)
        resolved_pipeline.initialize()
        return resolved_pipeline