def test_is_overridden_recursive(tmpdir):
    """Check ``DataPipeline._is_overridden_recursive`` on an ``InputTransform`` subclass.

    ``TestInputTransform`` overrides only ``collate`` and ``val_collate``. The test
    expects ``collate`` to count as overridden for both the ``val`` and ``train``
    prefixes, ``per_batch_transform_on_device`` to not count as overridden (with
    or without a prefix), and an unknown hook name to raise
    ``MisconfigurationException``.
    """

    class TestInputTransform(InputTransform):
        @staticmethod
        def custom_transform(x):
            return x

        def collate(self):
            return self.custom_transform

        def val_collate(self):
            return self.custom_transform

    input_transform = TestInputTransform()
    assert DataPipeline._is_overridden_recursive("collate",
                                                 input_transform,
                                                 InputTransform,
                                                 prefix="val")
    assert DataPipeline._is_overridden_recursive("collate",
                                                 input_transform,
                                                 InputTransform,
                                                 prefix="train")
    assert not DataPipeline._is_overridden_recursive(
        "per_batch_transform_on_device",
        input_transform,
        InputTransform,
        prefix="train")
    assert not DataPipeline._is_overridden_recursive(
        "per_batch_transform_on_device", input_transform, InputTransform)
    # The call below is expected to raise, so it is not wrapped in an `assert`:
    # such an assert is dead code when the call raises, and if the call wrongly
    # returned True it would surface as an AssertionError instead of pytest's
    # "DID NOT RAISE" failure.
    with pytest.raises(
            MisconfigurationException,
            match="This function doesn't belong to the parent class"):
        DataPipeline._is_overridden_recursive("chocolate", input_transform,
                                              InputTransform)
def test_is_overriden_recursive(tmpdir):
    """Check ``DataPipeline._is_overriden_recursive`` on a ``DefaultPreprocess`` subclass.

    ``TestPreprocess`` overrides only ``collate`` and ``val_collate``. The test
    expects ``collate`` to count as overridden for both the ``val`` and ``train``
    prefixes, ``per_batch_transform_on_device`` to not count as overridden (with
    or without a prefix), and an unknown hook name to raise
    ``MisconfigurationException``.
    """

    class TestPreprocess(DefaultPreprocess):
        def collate(self, *_):
            pass

        def val_collate(self, *_):
            pass

    preprocess = TestPreprocess()
    assert DataPipeline._is_overriden_recursive("collate",
                                                preprocess,
                                                Preprocess,
                                                prefix="val")
    assert DataPipeline._is_overriden_recursive("collate",
                                                preprocess,
                                                Preprocess,
                                                prefix="train")
    assert not DataPipeline._is_overriden_recursive(
        "per_batch_transform_on_device",
        preprocess,
        Preprocess,
        prefix="train")
    assert not DataPipeline._is_overriden_recursive(
        "per_batch_transform_on_device", preprocess, Preprocess)
    # The call below is expected to raise, so it is not wrapped in an `assert`:
    # such an assert is dead code when the call raises, and if the call wrongly
    # returned True it would surface as an AssertionError instead of pytest's
    # "DID NOT RAISE" failure.
    with pytest.raises(
            MisconfigurationException,
            match="This function doesn't belong to the parent class"):
        DataPipeline._is_overriden_recursive("chocolate", preprocess,
                                             Preprocess)
Example no. 3
0
def test_saving_with_serializers(tmpdir):
    """Round-trip a model whose DataPipeline carries a ``Labels`` serializer.

    After one ``fast_dev_run`` fit the checkpoint is saved and reloaded. The
    reloaded model must expose a ``DataPipelineState`` whose ``LabelsState``
    entry equals ``LabelsState(["a", "b"])``.
    """
    ckpt_path = os.path.join(tmpdir, 'tmp.ckpt')

    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1),
                             loss_fn=torch.nn.MSELoss())

    labels_serializer = Labels(["a", "b"])
    model = CustomModel()
    trainer = Trainer(fast_dev_run=True)

    pipeline = DataPipeline(preprocess=DefaultPreprocess(),
                            serializer=labels_serializer)
    pipeline.initialize()
    model.data_pipeline = pipeline
    assert isinstance(model.preprocess, DefaultPreprocess)

    inputs = torch.arange(10, dtype=torch.float)
    targets = torch.arange(10, dtype=torch.float)
    loader = DataLoader(list(zip(inputs, targets)))
    trainer.fit(model, train_dataloader=loader)
    trainer.save_checkpoint(ckpt_path)

    restored = CustomModel.load_from_checkpoint(ckpt_path)
    assert isinstance(restored._data_pipeline_state, DataPipelineState)
    saved_labels_state = restored._data_pipeline_state._state[LabelsState]
    assert saved_labels_state == LabelsState(["a", "b"])
Example no. 4
0
def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor():
    """Worker preprocessors build for TRAINING and PREDICTING only.

    With ``CustomPreprocess``, the test expects VALIDATING and TESTING to raise a
    ``MisconfigurationException`` mentioning "are mutually exclusive".
    """
    data_pipeline = DataPipeline(preprocess=CustomPreprocess())

    data_pipeline.worker_preprocessor(RunningStage.TRAINING)

    for conflicting_stage in (RunningStage.VALIDATING, RunningStage.TESTING):
        with pytest.raises(MisconfigurationException, match="are mutually exclusive"):
            data_pipeline.worker_preprocessor(conflicting_stage)

    data_pipeline.worker_preprocessor(RunningStage.PREDICTING)
Example no. 5
0
def test_predict_sklearn():
    """Predict from the scikit-learn iris ``Bunch`` via the ``"sklearn"`` data source.

    The first prediction is expected to be a plain ``int``.
    """
    iris_bunch = datasets.load_iris()
    classifier = TemplateSKLearnClassifier(
        num_features=DummyDataset.num_features,
        num_classes=DummyDataset.num_classes,
    )
    pipeline = DataPipeline(preprocess=TemplatePreprocess())
    predictions = classifier.predict(iris_bunch,
                                     data_source="sklearn",
                                     data_pipeline=pipeline)
    assert isinstance(predictions[0], int)
Example no. 6
0
def test_predict_dataset(tmpdir):
    """Predict from the ``KKI`` ``TUDataset`` via the ``"datasets"`` data source.

    The first prediction is expected to be a plain ``int``.
    """
    kki = datasets.TUDataset(root=tmpdir, name="KKI")
    classifier = GraphClassifier(num_features=kki.num_features,
                                 num_classes=kki.num_classes)
    pipeline = DataPipeline(preprocess=GraphClassificationPreprocess())
    predictions = classifier.predict(kki,
                                     data_source="datasets",
                                     data_pipeline=pipeline)
    assert isinstance(predictions[0], int)
def test_serialization_data_pipeline(tmpdir):
    """Check that a model's DataPipeline survives checkpoint save/load.

    Three situations are covered:

    1. the pipeline the model has after a plain ``fit``;
    2. an explicitly assigned pipeline with ``CustomPreprocess``;
    3. loading that checkpoint after ``CustomPreprocess.version`` has been
       replaced by a function returning ``"0.0.2"``.

    The ``version`` replacement is undone in a ``finally`` block. Otherwise the
    patched class would leak into every test that runs afterwards in the same
    process. Checkpoint files are also removed even if an assertion fails.
    """
    model = CustomModel()

    checkpoint_file = os.path.join(tmpdir, "tmp.ckpt")
    checkpoint = ModelCheckpoint(tmpdir, "test.ckpt")
    trainer = Trainer(callbacks=[checkpoint], max_epochs=1)
    dummy_data = DataLoader(list(zip(torch.arange(10, dtype=torch.float), torch.arange(10, dtype=torch.float))))
    trainer.fit(model, dummy_data)

    assert model.data_pipeline
    trainer.save_checkpoint(checkpoint_file)

    loaded_model = CustomModel.load_from_checkpoint(checkpoint_file)
    assert loaded_model.data_pipeline

    model.data_pipeline = DataPipeline(preprocess=CustomPreprocess())
    assert isinstance(model.preprocess, CustomPreprocess)

    trainer.fit(model, dummy_data)
    assert model.data_pipeline
    assert isinstance(model.preprocess, CustomPreprocess)
    trainer.save_checkpoint(checkpoint_file)

    def fn(*args, **kwargs):
        return "0.0.2"

    # `version` may be defined on CustomPreprocess itself or inherited from a
    # parent; record which, so the original state can be restored exactly
    # (including any staticmethod/classmethod wrapper stored in the class dict).
    missing = object()
    original_version = vars(CustomPreprocess).get("version", missing)
    CustomPreprocess.version = fn
    try:
        loaded_model = CustomModel.load_from_checkpoint(checkpoint_file)
        assert loaded_model.data_pipeline
        assert isinstance(loaded_model.preprocess, CustomPreprocess)
    finally:
        if original_version is missing:
            del CustomPreprocess.version
        else:
            CustomPreprocess.version = original_version
        for file in os.listdir(tmpdir):
            if file.endswith(".ckpt"):
                os.remove(os.path.join(tmpdir, file))
Example no. 8
0
def test_predict_numpy():
    """Predict from a single random numpy row of ``DummyDataset.num_features`` values.

    The first prediction is expected to be a plain ``int``.
    """
    sample_row = np.random.rand(1, DummyDataset.num_features)
    classifier = TemplateSKLearnClassifier(
        num_features=DummyDataset.num_features,
        num_classes=DummyDataset.num_classes,
    )
    pipeline = DataPipeline(preprocess=TemplatePreprocess())
    predictions = classifier.predict(sample_row, data_pipeline=pipeline)
    assert isinstance(predictions[0], int)
Example no. 9
0
def test_predict_numpy():
    """Segment a 1x3x10x20 numpy array of ones through the ``"numpy"`` data source.

    The first prediction is expected to be a ``torch.Tensor`` of shape (10, 20).
    """
    image_batch = np.ones((1, 3, 10, 20))
    segmenter = SemanticSegmentation(2)
    pipeline = DataPipeline(
        preprocess=SemanticSegmentationPreprocess(num_classes=1))
    first = segmenter.predict(image_batch,
                              data_source="numpy",
                              data_pipeline=pipeline)[0]
    assert isinstance(first, torch.Tensor)
    assert first.shape == (10, 20)
Example no. 10
0
def test_test(tmpdir):
    """Run ``Trainer.test`` (``fast_dev_run``) on the ``KKI`` ``TUDataset``."""
    kki = datasets.TUDataset(root=tmpdir, name="KKI")
    classifier = GraphClassifier(num_features=kki.num_features,
                                 num_classes=kki.num_classes)
    classifier.data_pipeline = DataPipeline(
        preprocess=GraphClassificationPreprocess())
    loader = torch.utils.data.DataLoader(kki, batch_size=4)
    Trainer(default_root_dir=tmpdir, fast_dev_run=True).test(classifier, loader)
Example no. 11
0
def test_predict_numpy():
    """Segment a 1x3x64x64 numpy array of ones with a ``mobilenetv3_large_100`` backbone.

    The first prediction is expected to be a nested list with 64 rows of 64 entries.
    """
    image_batch = np.ones((1, 3, 64, 64))
    segmenter = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
    pipeline = DataPipeline(
        preprocess=SemanticSegmentationPreprocess(num_classes=1))
    first = segmenter.predict(image_batch,
                              data_source="numpy",
                              data_pipeline=pipeline)[0]
    assert isinstance(first, list)
    assert len(first) == 64
    assert len(first[0]) == 64
Example no. 12
0
def test_predict_tensor():
    """Segment a random 1x3x10x20 tensor through the ``"tensors"`` data source.

    The first prediction is expected to be a nested list with 10 rows of 20 entries.
    """
    image_batch = torch.rand(1, 3, 10, 20)
    segmenter = SemanticSegmentation(2)
    pipeline = DataPipeline(
        preprocess=SemanticSegmentationPreprocess(num_classes=1))
    first = segmenter.predict(image_batch,
                              data_source="tensors",
                              data_pipeline=pipeline)[0]
    assert isinstance(first, list)
    assert len(first) == 10
    assert len(first[0]) == 20
Example no. 13
0
def test_predict_numpy():
    """Segment a 1x3x10x20 numpy array of ones through the ``"numpy"`` data source.

    The first prediction is expected to be a nested list with 10 rows of 20 entries.
    """
    image_batch = np.ones((1, 3, 10, 20))
    segmenter = SemanticSegmentation(2)
    pipeline = DataPipeline(
        preprocess=SemanticSegmentationPreprocess(num_classes=1))
    first = segmenter.predict(image_batch,
                              data_source="numpy",
                              data_pipeline=pipeline)[0]
    assert isinstance(first, list)
    assert len(first) == 10
    assert len(first[0]) == 20
def test_data_pipeline_str():
    """``str(DataPipeline)`` lists its components in a fixed order.

    Plain strings are cast to the component types so that the rendered
    representation is predictable.
    """
    data_pipeline = DataPipeline(
        data_source=cast(DataSource, "data_source"),
        preprocess=cast(Preprocess, "preprocess"),
        postprocess=cast(Postprocess, "postprocess"),
        serializer=cast(Serializer, "serializer"),
        deserializer=cast(Deserializer, "deserializer"),
    )

    component_order = [
        "data_source",
        "deserializer",
        "preprocess",
        "postprocess",
        "serializer",
    ]
    rendered_fields = ", ".join(f"{name}={name}" for name in component_order)
    assert str(data_pipeline) == "DataPipeline(" + rendered_fields + ")"
Example no. 15
0
def __configure_worker_and_device_collate_fn(
        running_stage: RunningStage,
        input_transform: InputTransform) -> Tuple[Callable, Callable]:
    """Choose the worker-side and device-side collate functions for a stage.

    Args:
        running_stage: Stage whose prefix (via ``_STAGES_PREFIX``) and per-stage
            transform configuration are used.
        input_transform: The transform whose hooks are inspected.

    Returns:
        ``(worker_collate_fn, device_collate_fn)`` as produced by ``__make_collates``.
        When the worker function is an ``_InputTransformProcessor``, its
        ``collate_fn`` is returned instead of the processor itself.

    Raises:
        MisconfigurationException: if the stage's transform does not set
            ``collate_in_worker_from_transform`` (it is ``None``) while both
            ``per_batch_transform`` and ``per_sample_transform_on_device`` are
            overridden.
    """
    from flash.core.data.data_pipeline import DataPipeline

    stage_prefix: str = _STAGES_PREFIX[running_stage]
    stage_transform: _InputTransformPerStage = input_transform._transform[
        running_stage]

    def hook_overridden(hook_name: str) -> bool:
        # Overrides count whether defined as `<hook>` or `<prefix>_<hook>`.
        return DataPipeline._is_overridden_recursive(hook_name,
                                                     input_transform,
                                                     InputTransform,
                                                     prefix=stage_prefix)

    batch_hook_overridden: bool = hook_overridden("per_batch_transform")
    device_sample_hook_overridden: bool = hook_overridden(
        "per_sample_transform_on_device")

    collate_in_worker = stage_transform.collate_in_worker_from_transform

    if (collate_in_worker is None and batch_hook_overridden
            and device_sample_hook_overridden):
        raise MisconfigurationException(
            f"{input_transform.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` "
            f"are mutually exclusive for stage {running_stage}")

    # An explicit boolean from the transform wins; otherwise collation is moved
    # to the device exactly when `per_sample_transform_on_device` is overridden.
    if isinstance(collate_in_worker, bool):
        collate_on_device = not collate_in_worker
    else:
        collate_on_device = device_sample_hook_overridden

    worker_collate_fn, device_collate_fn = __make_collates(
        input_transform, collate_on_device, input_transform._collate)

    if isinstance(worker_collate_fn, _InputTransformProcessor):
        worker_collate_fn = worker_collate_fn.collate_fn

    return worker_collate_fn, device_collate_fn
Example no. 16
0
    def _resolve_transforms(
            self,
            running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
        """Call the stage-appropriate ``default_transforms`` hook and return its result.

        The hook name is resolved through
        ``DataPipeline._resolve_function_hierarchy`` against ``Preprocess``. The
        call runs inside a ``CurrentRunningStageFuncContext`` for
        ``"default_transforms"``.

        Args:
            running_stage: Stage used to pick the hook.

        Returns:
            Whatever the resolved hook returns (a mapping of transform names to
            callables, or ``None``).
        """
        from flash.core.data.data_pipeline import DataPipeline

        hook_name = DataPipeline._resolve_function_hierarchy(
            "default_transforms", self, running_stage, Preprocess)
        default_transforms_hook = getattr(self, hook_name)

        with CurrentRunningStageFuncContext(running_stage,
                                            "default_transforms", self):
            stage_transforms: Optional[Dict[str, Callable]] = default_transforms_hook()
        return stage_transforms
Example no. 17
0
    def running_stage(self, running_stage: RunningStage) -> None:
        """Store ``running_stage`` and rebind ``load_sample`` for that stage.

        This also rebuilds ``self._load_sample_context`` for ``"load_sample"`` on
        the data source. It resolves the stage-specific ``load_sample`` method of
        ``self.data_source`` (via ``DataPipeline._resolve_function_hierarchy``
        against ``DataSource``) and assigns it to ``self.load_sample``.
        """
        from flash.core.data.data_pipeline import DataPipeline  # noqa F811
        from flash.core.data.data_source import DataSource  # noqa F811 # TODO: something better than this

        self._running_stage = running_stage

        source = self.data_source
        self._load_sample_context = CurrentRunningStageFuncContext(
            self.running_stage, "load_sample", source)

        loader_name = DataPipeline._resolve_function_hierarchy(
            'load_sample',
            source,
            self.running_stage,
            DataSource,
        )
        self.load_sample: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(
            source, loader_name)
Example no. 18
0
    def generate_dataset(
        self,
        data: Optional[DATA_TYPE],
        running_stage: RunningStage,
    ) -> Optional[Union[AutoDataset, IterableAutoDataset]]:
        """Generate a single dataset with the given input to :meth:`~flash.core.data.data_source.DataSource.load_data` for
        the given ``running_stage``.

        Args:
            data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use to create the dataset.
                ``None``, or a non-empty sequence whose first element is ``None``, means "no data".
            running_stage: The running_stage for this dataset.

        Returns:
            The constructed :class:`~flash.core.data.auto_dataset.BaseAutoDataset`
            (an ``AutoDataset`` when the loaded data has a length, otherwise an
            ``IterableAutoDataset``), or ``None`` when there is no data.
        """
        # A sequence such as ``(None, None)`` counts as "no data". The length
        # check keeps an empty sequence (e.g. ``[]`` or ``""``) from raising
        # IndexError; such input is passed on to ``load_data`` like any other.
        is_none = data is None
        if isinstance(data, Sequence) and len(data) > 0:
            is_none = data[0] is None

        if is_none:
            return None

        from flash.core.data.data_pipeline import DataPipeline

        # `load_data` implementations that accept a `dataset` argument can attach
        # metadata to this mock; the metadata is copied onto the real dataset below.
        mock_dataset = typing.cast(AutoDataset, MockDataset())
        with CurrentRunningStageFuncContext(running_stage, "load_data", self):
            resolved_func_name = DataPipeline._resolve_function_hierarchy(
                "load_data", self, running_stage, DataSource)
            load_data: Callable[[DATA_TYPE, Optional[Any]],
                                Any] = getattr(self, resolved_func_name)
            parameters = signature(load_data).parameters
            # TODO: This was DATASET_KEY before
            if len(parameters) > 1 and "dataset" in parameters:
                data = load_data(data, mock_dataset)
            else:
                data = load_data(data)

        if has_len(data):
            dataset = AutoDataset(data, self, running_stage)
        else:
            dataset = IterableAutoDataset(data, self, running_stage)
        dataset.__dict__.update(mock_dataset.metadata)
        return dataset
def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess,
                                            tmpdir):
    """DataPipeline falls back to default components and propagates to the model.

    ``use_preprocess`` / ``use_postprocess`` decide whether custom subclasses are
    handed to the pipeline. Without them, the pipeline is expected to hold a
    ``DefaultPreprocess`` / ``Postprocess``. After the pipeline is assigned to a
    model, the model's ``_preprocess`` / ``_postprocess`` are checked accordingly.
    """

    class CustomModel(Task):
        def __init__(self, postprocess: Optional[Postprocess] = None):
            super().__init__(model=torch.nn.Linear(1, 1),
                             loss_fn=torch.nn.MSELoss())
            self._postprocess = postprocess

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    class SubPreprocess(DefaultPreprocess):
        pass

    class SubPostprocess(Postprocess):
        pass

    given_preprocess = SubPreprocess() if use_preprocess else None
    given_postprocess = SubPostprocess() if use_postprocess else None
    data_pipeline = DataPipeline(preprocess=given_preprocess,
                                 postprocess=given_postprocess)

    expected_preprocess_cls = SubPreprocess if use_preprocess else DefaultPreprocess
    expected_postprocess_cls = SubPostprocess if use_postprocess else Postprocess
    assert isinstance(data_pipeline._preprocess_pipeline, expected_preprocess_cls)
    assert isinstance(data_pipeline._postprocess_pipeline, expected_postprocess_cls)

    model = CustomModel(postprocess=Postprocess())
    model.data_pipeline = data_pipeline
    # TODO: the line below should make the same effect but it's not
    # data_pipeline._attach_to_model(model)

    preprocess_ok = (isinstance(model._preprocess, SubPreprocess)
                     if use_preprocess else
                     model._preprocess is None
                     or isinstance(model._preprocess, Preprocess))
    assert preprocess_ok

    postprocess_ok = (isinstance(model._postprocess, SubPostprocess)
                      if use_postprocess else
                      model._postprocess is None
                      or isinstance(model._postprocess, Postprocess))
    assert postprocess_ok
def test_detach_preprocessing_from_model(tmpdir):
    """Attaching a pipeline in ``on_train_dataloader`` is undone by ``on_fit_end``.

    Before the hook, and again after ``on_fit_end``, the train dataloader uses
    ``default_collate`` and ``transfer_batch_to_device`` is the model's own bound
    method. Between the two calls, the collate function is a ``_Preprocessor``
    and ``transfer_batch_to_device`` is a ``_StageOrchestrator``.
    """

    class CustomModel(Task):
        def __init__(self, postprocess: Optional[Postprocess] = None):
            super().__init__(model=torch.nn.Linear(1, 1),
                             loss_fn=torch.nn.MSELoss())
            self._postprocess = postprocess

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    pipeline = DataPipeline(preprocess=CustomPreprocess())
    model = CustomModel()
    model.data_pipeline = pipeline

    def assert_detached():
        assert model.train_dataloader().collate_fn == default_collate
        assert model.transfer_batch_to_device.__self__ == model

    assert_detached()

    model.on_train_dataloader()
    assert isinstance(model.train_dataloader().collate_fn, _Preprocessor)
    assert isinstance(model.transfer_batch_to_device, _StageOrchestrator)

    model.on_fit_end()
    assert_detached()
def test_attaching_datapipeline_to_model(tmpdir):
    """Check how a DataPipeline is attached to and detached from a model per stage.

    Each ``on_*_dataloader`` hook of ``TestModel`` asserts the model state before
    and after calling ``super()``. Afterwards the stage's dataloader ``collate_fn``
    is a preprocessor for that stage and ``transfer_batch_to_device`` is a
    ``_StageOrchestrator``; when predicting, so is ``predict_step``.
    ``on_fit_end`` asserts that ``default_collate`` and the original
    ``predict_step`` / ``transfer_batch_to_device`` are back in place.
    """

    class SubPreprocess(DefaultPreprocess):
        pass

    preprocess = SubPreprocess()
    data_pipeline = DataPipeline(preprocess=preprocess)

    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1),
                             loss_fn=torch.nn.MSELoss())
            self._postprocess = Postprocess()

        def training_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def validation_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def test_step(self, batch: Any, batch_idx: int) -> Any:
            pass

        def train_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def val_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def test_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

        def predict_dataloader(self) -> Any:
            return DataLoader(DummyDataset())

    class TestModel(CustomModel):

        stages = [
            RunningStage.TRAINING, RunningStage.VALIDATING,
            RunningStage.TESTING, RunningStage.PREDICTING
        ]
        # Each flag is set to True by the matching `on_*_dataloader` hook below.
        # Declaring the defaults here means a hook that never runs makes the
        # final assertions fail cleanly instead of raising AttributeError.
        on_train_dataloader_called = False
        on_val_dataloader_called = False
        on_test_dataloader_called = False
        on_predict_dataloader_called = False

        def on_fit_start(self):
            # Remember the original bound `predict_step` so later hooks can
            # verify it is restored.
            assert self.predict_step.__self__ == self
            self._saved_predict_step = self.predict_step

        def _compare_pre_processor(self, p1, p2):
            # Two preprocessors match when every wrapped stage function does.
            p1_seq = p1.per_sample_transform
            p2_seq = p2.per_sample_transform
            assert p1_seq.pre_tensor_transform.func == p2_seq.pre_tensor_transform.func
            assert p1_seq.to_tensor_transform.func == p2_seq.to_tensor_transform.func
            assert p1_seq.post_tensor_transform.func == p2_seq.post_tensor_transform.func
            assert p1.collate_fn.func == p2.collate_fn.func
            assert p1.per_batch_transform.func == p2.per_batch_transform.func

        def _assert_stage_orchestrator_state(
                self,
                stage_mapping: Dict,
                current_running_stage: RunningStage,
                cls=_Preprocessor):
            assert isinstance(stage_mapping[current_running_stage], cls)
            assert stage_mapping[current_running_stage]

        def on_train_dataloader(self) -> None:
            current_running_stage = RunningStage.TRAINING
            self.on_train_dataloader_called = True
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            # Training is the first stage to attach, so no orchestrator exists yet.
            assert not isinstance(self.transfer_batch_to_device,
                                  _StageOrchestrator)
            super().on_train_dataloader()
            collate_fn = self.train_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(
                collate_fn,
                self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            self._assert_stage_orchestrator_state(
                self.transfer_batch_to_device._stage_mapping,
                current_running_stage)

        def on_val_dataloader(self) -> None:
            current_running_stage = RunningStage.VALIDATING
            self.on_val_dataloader_called = True
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            # Validation runs inside `fit` after training attached an orchestrator.
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            super().on_val_dataloader()
            collate_fn = self.val_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(
                collate_fn,
                self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            self._assert_stage_orchestrator_state(
                self.transfer_batch_to_device._stage_mapping,
                current_running_stage)

        def on_test_dataloader(self) -> None:
            current_running_stage = RunningStage.TESTING
            self.on_test_dataloader_called = True
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            # `on_fit_end` detached everything before `trainer.test` starts.
            assert not isinstance(self.transfer_batch_to_device,
                                  _StageOrchestrator)
            super().on_test_dataloader()
            collate_fn = self.test_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(
                collate_fn,
                self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            self._assert_stage_orchestrator_state(
                self.transfer_batch_to_device._stage_mapping,
                current_running_stage)

        def on_predict_dataloader(self) -> None:
            current_running_stage = RunningStage.PREDICTING
            self.on_predict_dataloader_called = True
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn == default_collate
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step
            super().on_predict_dataloader()
            collate_fn = self.predict_dataloader().collate_fn  # noqa F811
            assert collate_fn.stage == current_running_stage
            self._compare_pre_processor(
                collate_fn,
                self.data_pipeline.worker_preprocessor(current_running_stage))
            assert isinstance(self.transfer_batch_to_device,
                              _StageOrchestrator)
            # Predicting additionally wraps `predict_step` for postprocessing.
            assert isinstance(self.predict_step, _StageOrchestrator)
            self._assert_stage_orchestrator_state(
                self.transfer_batch_to_device._stage_mapping,
                current_running_stage)
            self._assert_stage_orchestrator_state(
                self.predict_step._stage_mapping,
                current_running_stage,
                cls=_Postprocessor)

        def on_fit_end(self) -> None:
            super().on_fit_end()
            assert self.train_dataloader().collate_fn == default_collate
            assert self.val_dataloader().collate_fn == default_collate
            assert self.test_dataloader().collate_fn == default_collate
            assert self.predict_dataloader().collate_fn == default_collate
            assert not isinstance(self.transfer_batch_to_device,
                                  _StageOrchestrator)
            assert self.predict_step == self._saved_predict_step

    model = TestModel()
    model.data_pipeline = data_pipeline
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model)

    assert model.on_train_dataloader_called
    assert model.on_val_dataloader_called
    assert model.on_test_dataloader_called
    assert model.on_predict_dataloader_called
def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir):
    """Stage-prefixed hooks take precedence over generic ones.

    ``CustomPreprocess`` overrides a scattering of ``<stage>_<hook>`` methods.
    For every stage, the resolved hook names are checked against an expectation
    table. The worker preprocessors must then wrap exactly those methods.
    """

    class CustomPreprocess(DefaultPreprocess):
        def val_pre_tensor_transform(self, *_, **__):
            pass

        def predict_to_tensor_transform(self, *_, **__):
            pass

        def train_post_tensor_transform(self, *_, **__):
            pass

        def test_collate(self, *_, **__):
            pass

        def val_per_sample_transform_on_device(self, *_, **__):
            pass

        def train_per_batch_transform_on_device(self, *_, **__):
            pass

        def test_per_batch_transform_on_device(self, *_, **__):
            pass

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess=preprocess)

    stages = (
        RunningStage.TRAINING,
        RunningStage.VALIDATING,
        RunningStage.TESTING,
        RunningStage.PREDICTING,
    )

    resolved: Dict[RunningStage, Dict[str, str]] = {
        stage: {
            hook: data_pipeline._resolve_function_hierarchy(
                hook, data_pipeline._preprocess_pipeline, stage, Preprocess)
            for hook in data_pipeline.PREPROCESS_FUNCS
        }
        for stage in stages
    }

    # Expected resolved method name per hook, in `stages` order
    # (train, val, test, predict).
    expected_names = {
        "pre_tensor_transform": (
            "pre_tensor_transform", "val_pre_tensor_transform",
            "pre_tensor_transform", "pre_tensor_transform"),
        "to_tensor_transform": (
            "to_tensor_transform", "to_tensor_transform",
            "to_tensor_transform", "predict_to_tensor_transform"),
        "post_tensor_transform": (
            "train_post_tensor_transform", "post_tensor_transform",
            "post_tensor_transform", "post_tensor_transform"),
        "collate": ("collate", "collate", "test_collate", "collate"),
        "per_sample_transform_on_device": (
            "per_sample_transform_on_device",
            "val_per_sample_transform_on_device",
            "per_sample_transform_on_device",
            "per_sample_transform_on_device"),
        "per_batch_transform_on_device": (
            "train_per_batch_transform_on_device",
            "per_batch_transform_on_device",
            "test_per_batch_transform_on_device",
            "per_batch_transform_on_device"),
    }
    for hook, names_per_stage in expected_names.items():
        for stage, expected_name in zip(stages, names_per_stage):
            assert resolved[stage][hook] == expected_name

    workers = {stage: data_pipeline.worker_preprocessor(stage) for stage in stages}

    # (pre_tensor, to_tensor, post_tensor, collate) functions each worker wraps.
    expected_funcs = {
        RunningStage.TRAINING: (
            preprocess.pre_tensor_transform, preprocess.to_tensor_transform,
            preprocess.train_post_tensor_transform, preprocess.collate),
        RunningStage.VALIDATING: (
            preprocess.val_pre_tensor_transform, preprocess.to_tensor_transform,
            preprocess.post_tensor_transform, DataPipeline._identity),
        RunningStage.TESTING: (
            preprocess.pre_tensor_transform, preprocess.to_tensor_transform,
            preprocess.post_tensor_transform, preprocess.test_collate),
        RunningStage.PREDICTING: (
            preprocess.pre_tensor_transform,
            preprocess.predict_to_tensor_transform,
            preprocess.post_tensor_transform, preprocess.collate),
    }
    for stage in stages:
        worker = workers[stage]
        pre_fn, to_fn, post_fn, collate_fn = expected_funcs[stage]
        sample_seq = worker.per_sample_transform
        assert sample_seq.pre_tensor_transform.func == pre_fn
        assert sample_seq.to_tensor_transform.func == to_fn
        assert sample_seq.post_tensor_transform.func == post_fn
        assert worker.collate_fn.func == collate_fn
        assert worker.per_batch_transform.func == preprocess.per_batch_transform
Example no. 23
0
    def build_data_pipeline(
        self,
        data_source: Optional[str] = None,
        data_pipeline: Optional[DataPipeline] = None,
    ) -> Optional[DataPipeline]:
        """Build a :class:`.DataPipeline` incorporating available
        :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess`
        objects. These will be overridden in the following resolution order (lowest priority first):

        - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
        - :class:`.Task` defaults given to :meth:`.Task.__init__`.
        - :class:`.Task` manual overrides by setting :py:attr:`~data_pipeline`.
        - :class:`.DataPipeline` passed to this method.

        Args:
            data_source: Optional name of the data source to use. When falsy, the data source of the
                datamodule's pipeline (if any) is reused. A string is looked up with
                ``preprocess.data_source_of_name``; if no preprocess could be resolved, a bare
                :class:`~flash.core.data.data_source.DataSource` is used instead.
            data_pipeline: Optional highest priority source of
                :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess`.

        Returns:
            The fully resolved :class:`.DataPipeline`. Its initialized state is stored in
            ``self._data_pipeline_state``.
        """
        old_data_source, preprocess, postprocess, serializer = None, None, None, None

        # Datamodule: the one attached to the task takes precedence over the
        # one attached to the trainer.
        datamodule_pipeline = None
        if self.datamodule is not None and getattr(
                self.datamodule, 'data_pipeline', None) is not None:
            datamodule_pipeline = self.datamodule.data_pipeline
        elif self.trainer is not None and hasattr(
                self.trainer, 'datamodule') and getattr(
                    self.trainer.datamodule, 'data_pipeline',
                    None) is not None:
            datamodule_pipeline = self.trainer.datamodule.data_pipeline
        # TODO: when neither datamodule provides a pipeline, we should log with low
        # severity level that we use defaults to create `preprocess`, `postprocess`
        # and `serializer`.

        if datamodule_pipeline is not None:
            old_data_source = getattr(datamodule_pipeline, 'data_source', None)
            preprocess = getattr(datamodule_pipeline, '_preprocess_pipeline',
                                 None)
            postprocess = getattr(datamodule_pipeline, '_postprocess_pipeline',
                                  None)
            serializer = getattr(datamodule_pipeline, '_serializer', None)

        # Defaults / task attributes
        preprocess, postprocess, serializer = Task._resolve(
            preprocess,
            postprocess,
            serializer,
            self._preprocess,
            self._postprocess,
            self.serializer,
        )

        # Datapipeline
        if data_pipeline is not None:
            preprocess, postprocess, serializer = Task._resolve(
                preprocess,
                postprocess,
                serializer,
                getattr(data_pipeline, '_preprocess_pipeline', None),
                getattr(data_pipeline, '_postprocess_pipeline', None),
                getattr(data_pipeline, '_serializer', None),
            )

        data_source = data_source or old_data_source

        if isinstance(data_source, str):
            if preprocess is None:
                # TODO: warn the user that we are not using the specified data source
                data_source = DataSource()
            else:
                data_source = preprocess.data_source_of_name(data_source)

        data_pipeline = DataPipeline(data_source, preprocess, postprocess,
                                     serializer)
        self._data_pipeline_state = data_pipeline.initialize(
            self._data_pipeline_state)
        return data_pipeline
def test_preprocess_transforms(tmpdir):
    """
    Check how ``DefaultPreprocess`` validates transforms given as dictionaries.

    Verifies that:

    * non-dict transforms and unknown hook names are rejected;
    * ``_<stage>_collate_in_worker_from_transform`` is derived from the
      transforms given for each stage, and stays ``None`` for stages that
      received no transform;
    * ``per_batch_transform`` and ``per_sample_transform_on_device`` cannot be
      combined in one transform dict;
    * when a subclass overrides both of those hooks, building the worker
      preprocessor for the validating and testing stages raises;
    * the worker preprocessor's ``collate_fn`` wraps ``preprocess.collate``,
      except for a stage whose flag is ``False``, where it wraps
      ``DataPipeline._identity``.
    """

    def assert_collate_flags(preprocess, train, val, test, predict):
        # One assertion per stage so a failure points at the offending stage.
        assert preprocess._train_collate_in_worker_from_transform is train
        assert preprocess._val_collate_in_worker_from_transform is val
        assert preprocess._test_collate_in_worker_from_transform is test
        assert preprocess._predict_collate_in_worker_from_transform is predict

    with pytest.raises(MisconfigurationException,
                       match="Transform should be a dict."):
        DefaultPreprocess(train_transform="choco")

    with pytest.raises(MisconfigurationException,
                       match="train_transform contains {'choco'}. Only"):
        DefaultPreprocess(train_transform={"choco": None})

    preprocess = DefaultPreprocess(
        train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)})
    # Stages without a transform keep the flag at None.
    assert_collate_flags(preprocess, True, None, None, None)

    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        DefaultPreprocess(
            train_transform={
                "per_batch_transform": torch.nn.Linear(1, 1),
                "per_sample_transform_on_device": torch.nn.Linear(1, 1)
            })

    preprocess = DefaultPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={
            "per_sample_transform_on_device": torch.nn.Linear(1, 1)
        })
    # Stages without a transform keep the flag at None.
    assert_collate_flags(preprocess, True, None, None, False)

    train_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.TRAINING)
    val_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(
        RunningStage.VALIDATING)
    test_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.TESTING)
    predict_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == preprocess.collate
    assert val_preprocessor.collate_fn.func == preprocess.collate
    assert test_preprocessor.collate_fn.func == preprocess.collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity

    class CustomPreprocess(DefaultPreprocess):
        # Overrides both conflicting hooks; the conflict only surfaces when a
        # stage's worker preprocessor is resolved.
        def per_sample_transform_on_device(self, sample: Any) -> Any:
            return super().per_sample_transform_on_device(sample)

        def per_batch_transform(self, batch: Any) -> Any:
            return super().per_batch_transform(batch)

    preprocess = CustomPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={
            "per_sample_transform_on_device": torch.nn.Linear(1, 1)
        })
    # Stages without a transform keep the flag at None.
    assert_collate_flags(preprocess, True, None, None, False)

    data_pipeline = DataPipeline(preprocess=preprocess)

    train_preprocessor = data_pipeline.worker_preprocessor(
        RunningStage.TRAINING)
    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        data_pipeline.worker_preprocessor(RunningStage.VALIDATING)
    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        data_pipeline.worker_preprocessor(RunningStage.TESTING)
    predict_preprocessor = data_pipeline.worker_preprocessor(
        RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == preprocess.collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity
# Example no. 25
# 0
    def __resolve_transforms(
            self,
            running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
        """Collect the transforms this input transform defines for a stage.

        For each hook in ``InputTransformPlacement`` and each key in
        ``ApplyToKeyPrefix``, choose between the plain hook and its
        ``"<key>_<hook>"`` variant. Both are resolved for the stage by
        ``DataPipeline._resolve_function_hierarchy``. The chosen hook method
        is then called to obtain the transform, and the distinct results are
        gathered.

        Args:
            running_stage: The stage whose hooks are resolved.

        Returns:
            A dict mapping each hook name (``InputTransformPlacement`` value)
            to a single callable, or to a ``Compose`` of them when several were
            collected. Hook methods that return ``self._identity`` are left
            out, so a hook with no transform has no entry.

        Raises:
            MisconfigurationException: If the plain hook and its per-key
                variant are both overridden and both (or neither) carry the
                stage prefix, or if a hook method does not return a callable.
            AttributeError: If calling a hook method raises ``AttributeError``.
                It is re-raised with a hint about ``super().__init__`` ordering.
        """
        from flash.core.data.data_pipeline import DataPipeline

        transforms_out = {}
        stage = _STAGES_PREFIX[running_stage]

        # iterate over all transforms hook name
        for placement in InputTransformPlacement:
            transform_name = placement.value

            # The plain hook's resolution does not depend on the key, so it is
            # computed once per hook rather than once per key.
            resolved_name = DataPipeline._resolve_function_hierarchy(
                transform_name, self, running_stage, InputTransform)
            # check if the hook name is specialized
            is_specialized_name = resolved_name.startswith(stage)
            resolve_name_overridden = DataPipeline._is_overridden(
                resolved_name, self, InputTransform)

            # Keyed by hook method name (insertion-ordered), so the plain hook,
            # which every key may fall back to, is only collected once.
            transforms = {}

            # iterate over all prefixes
            for key in ApplyToKeyPrefix:

                # get the resolved hook name for apply to key on the current stage
                resolved_apply_to_key_name = DataPipeline._resolve_function_hierarchy(
                    f"{key}_{transform_name}", self, running_stage,
                    InputTransform)
                # check if resolved hook name for apply to key is specialized
                is_specialized_apply_to_key_name = resolved_apply_to_key_name.startswith(
                    stage)
                resolved_apply_to_key_name_overridden = DataPipeline._is_overridden(
                    resolved_apply_to_key_name, self, InputTransform)

                if resolve_name_overridden and resolved_apply_to_key_name_overridden:
                    # Both are overridden: the stage-specialized one wins. If
                    # both or neither are specialized the choice is ambiguous.
                    if is_specialized_name == is_specialized_apply_to_key_name:
                        raise MisconfigurationException(
                            f"Only one of {resolved_name} or {resolved_apply_to_key_name} can be overridden."
                        )

                    method_name = resolved_name if is_specialized_name else resolved_apply_to_key_name
                elif resolved_apply_to_key_name_overridden:
                    method_name = resolved_apply_to_key_name
                else:
                    method_name = resolved_name

                # Already collected for an earlier key: only the first result
                # is ever kept, so do not invoke the hook again.
                if method_name in transforms:
                    continue

                # get associated transform
                try:
                    fn = getattr(self, method_name)()
                except AttributeError as e:
                    raise AttributeError(
                        str(e) +
                        ". Hint: Call super().__init__(...) after setting all attributes."
                    ) from e

                if not callable(fn):
                    raise MisconfigurationException(
                        f"The hook {method_name} should return a function.")

                # if the default hook is used, it should return identity, skip it.
                if fn is self._identity:
                    continue

                # wrap apply to key hook into `ApplyToKeys` with the associated key.
                if method_name == resolved_apply_to_key_name:
                    fn = ApplyToKeys(key.value, fn)

                transforms[method_name] = fn

            # store the transforms.
            if transforms:
                collected = list(transforms.values())
                transforms_out[transform_name] = Compose(
                    collected) if len(collected) > 1 else collected[0]

        return transforms_out