Esempio n. 1
0
    def _resolve_transforms(
            self,
            running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
        """Resolve and call the ``default_transforms`` hook for ``running_stage``.

        The attribute name to call is obtained from
        ``DataPipeline._resolve_function_hierarchy`` and looked up on ``self``.
        The hook is invoked inside a ``CurrentRunningStageFuncContext`` opened
        for ``running_stage`` and ``"default_transforms"``.

        Args:
            running_stage: The stage to resolve the hook for.

        Returns:
            Whatever the resolved hook returns.
        """
        # Imported locally, as in the rest of this code base (presumably to
        # avoid a circular import -- confirm).
        from flash.core.data.data_pipeline import DataPipeline

        hook_name = DataPipeline._resolve_function_hierarchy(
            "default_transforms", self, running_stage, Preprocess)
        resolved_function = getattr(self, hook_name)

        stage_context = CurrentRunningStageFuncContext(
            running_stage, "default_transforms", self)
        with stage_context:
            transforms: Optional[Dict[str, Callable]] = resolved_function()
        return transforms
Esempio n. 2
0
    def running_stage(self, running_stage: RunningStage) -> None:
        """Store ``running_stage`` and rebuild the ``load_sample`` helpers for it.

        NOTE(review): presumably the setter of a ``running_stage`` property --
        the decorator is not visible here; confirm.

        After storing the stage, two attributes are refreshed from
        ``self.data_source``:

        * ``_load_sample_context``: a ``CurrentRunningStageFuncContext`` for
          the ``"load_sample"`` function.
        * ``load_sample``: the data source attribute whose name is resolved by
          ``DataPipeline._resolve_function_hierarchy`` for the current stage.

        Args:
            running_stage: The stage to switch to.
        """
        from flash.core.data.data_pipeline import DataPipeline  # noqa F811
        from flash.core.data.data_source import DataSource  # noqa F811 # TODO: something better than this

        self._running_stage = running_stage

        # The context is assigned before the callable is resolved, so it is
        # already in place if the lookup below raises.
        self._load_sample_context = CurrentRunningStageFuncContext(
            self.running_stage, "load_sample", self.data_source)

        load_sample_name = DataPipeline._resolve_function_hierarchy(
            'load_sample',
            self.data_source,
            self.running_stage,
            DataSource,
        )
        self.load_sample: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(
            self.data_source, load_sample_name)
    def generate_dataset(
        self,
        data: Optional[DATA_TYPE],
        running_stage: RunningStage,
    ) -> Optional[Union[AutoDataset, IterableAutoDataset]]:
        """Generate a single dataset with the given input to :meth:`~flash.core.data.data_source.DataSource.load_data` for
        the given ``running_stage``.

        Args:
            data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use to create the dataset.
            running_stage: The running_stage for this dataset.

        Returns:
            The constructed :class:`~flash.core.data.auto_dataset.BaseAutoDataset`, or ``None`` when ``data`` is
            ``None`` or is a non-empty sequence whose first element is ``None``.
        """
        is_none = data is None

        # Only inspect the first element when there is one: an empty sequence is not ``None`` data, so it is handed
        # to ``load_data`` instead of failing here with an ``IndexError``.
        if isinstance(data, Sequence) and len(data) > 0:
            is_none = data[0] is None

        if is_none:
            return None

        from flash.core.data.data_pipeline import DataPipeline

        # Stand-in dataset passed to ``load_data``; its ``metadata`` is copied onto the real dataset below.
        mock_dataset = typing.cast(AutoDataset, MockDataset())
        with CurrentRunningStageFuncContext(running_stage, "load_data",
                                            self):
            resolved_func_name = DataPipeline._resolve_function_hierarchy(
                "load_data", self, running_stage, DataSource)
            load_data: Callable[[DATA_TYPE, Optional[Any]],
                                Any] = getattr(self, resolved_func_name)
            parameters = signature(load_data).parameters
            # Pass the mock dataset only to hooks that declare a ``dataset`` parameter.
            if len(
                    parameters
            ) > 1 and "dataset" in parameters:  # TODO: This was DATASET_KEY before
                data = load_data(data, mock_dataset)
            else:
                data = load_data(data)

        # Loaded data with a length becomes an ``AutoDataset``, anything else an ``IterableAutoDataset``.
        if has_len(data):
            dataset = AutoDataset(data, self, running_stage)
        else:
            dataset = IterableAutoDataset(data, self, running_stage)
        dataset.__dict__.update(mock_dataset.metadata)
        return dataset
def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir):
    """Check stage-specific hook resolution on a preprocess with selected overrides.

    ``CustomPreprocess`` defines one stage-prefixed variant for several hooks.
    The test first asserts the name returned by
    ``DataPipeline._resolve_function_hierarchy`` for every hook/stage pair,
    then asserts which functions are wired into the worker preprocessor built
    for each stage.
    """

    class CustomPreprocess(DefaultPreprocess):
        def val_pre_tensor_transform(self, *_, **__):
            pass

        def predict_to_tensor_transform(self, *_, **__):
            pass

        def train_post_tensor_transform(self, *_, **__):
            pass

        def test_collate(self, *_, **__):
            pass

        def val_per_sample_transform_on_device(self, *_, **__):
            pass

        def train_per_batch_transform_on_device(self, *_, **__):
            pass

        def test_per_batch_transform_on_device(self, *_, **__):
            pass

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess=preprocess)

    def resolve_names(stage):
        # Map every preprocess hook to the name resolved for ``stage``.
        return {
            hook: data_pipeline._resolve_function_hierarchy(
                hook, data_pipeline._preprocess_pipeline, stage, Preprocess)
            for hook in data_pipeline.PREPROCESS_FUNCS
        }

    train_func_names: Dict[str, str] = resolve_names(RunningStage.TRAINING)
    val_func_names: Dict[str, str] = resolve_names(RunningStage.VALIDATING)
    test_func_names: Dict[str, str] = resolve_names(RunningStage.TESTING)
    predict_func_names: Dict[str, str] = resolve_names(RunningStage.PREDICTING)

    # hook -> expected resolved name for (train, val, test, predict)
    expected_names = {
        "pre_tensor_transform": (
            "pre_tensor_transform",
            "val_pre_tensor_transform",
            "pre_tensor_transform",
            "pre_tensor_transform",
        ),
        "to_tensor_transform": (
            "to_tensor_transform",
            "to_tensor_transform",
            "to_tensor_transform",
            "predict_to_tensor_transform",
        ),
        "post_tensor_transform": (
            "train_post_tensor_transform",
            "post_tensor_transform",
            "post_tensor_transform",
            "post_tensor_transform",
        ),
        "collate": (
            "collate",
            "collate",
            "test_collate",
            "collate",
        ),
        "per_sample_transform_on_device": (
            "per_sample_transform_on_device",
            "val_per_sample_transform_on_device",
            "per_sample_transform_on_device",
            "per_sample_transform_on_device",
        ),
        "per_batch_transform_on_device": (
            "train_per_batch_transform_on_device",
            "per_batch_transform_on_device",
            "test_per_batch_transform_on_device",
            "per_batch_transform_on_device",
        ),
    }
    for hook, (train_name, val_name, test_name,
               predict_name) in expected_names.items():
        assert train_func_names[hook] == train_name
        assert val_func_names[hook] == val_name
        assert test_func_names[hook] == test_name
        assert predict_func_names[hook] == predict_name

    # Build all four worker preprocessors before checking any of them.
    train_worker_preprocessor = data_pipeline.worker_preprocessor(
        RunningStage.TRAINING)
    val_worker_preprocessor = data_pipeline.worker_preprocessor(
        RunningStage.VALIDATING)
    test_worker_preprocessor = data_pipeline.worker_preprocessor(
        RunningStage.TESTING)
    predict_worker_preprocessor = data_pipeline.worker_preprocessor(
        RunningStage.PREDICTING)

    # (worker preprocessor, expected pre / to / post tensor transform, expected collate)
    expected_funcs = (
        (
            train_worker_preprocessor,
            preprocess.pre_tensor_transform,
            preprocess.to_tensor_transform,
            preprocess.train_post_tensor_transform,
            preprocess.collate,
        ),
        (
            val_worker_preprocessor,
            preprocess.val_pre_tensor_transform,
            preprocess.to_tensor_transform,
            preprocess.post_tensor_transform,
            DataPipeline._identity,
        ),
        (
            test_worker_preprocessor,
            preprocess.pre_tensor_transform,
            preprocess.to_tensor_transform,
            preprocess.post_tensor_transform,
            preprocess.test_collate,
        ),
        (
            predict_worker_preprocessor,
            preprocess.pre_tensor_transform,
            preprocess.predict_to_tensor_transform,
            preprocess.post_tensor_transform,
            preprocess.collate,
        ),
    )
    for worker, pre_fn, to_fn, post_fn, collate_fn in expected_funcs:
        per_sample = worker.per_sample_transform
        assert per_sample.pre_tensor_transform.func == pre_fn
        assert per_sample.to_tensor_transform.func == to_fn
        assert per_sample.post_tensor_transform.func == post_fn
        assert worker.collate_fn.func == collate_fn
        # No stage overrides ``per_batch_transform``, so every worker uses the base hook.
        assert worker.per_batch_transform.func == preprocess.per_batch_transform
Esempio n. 5
0
    def __resolve_transforms(
            self,
            running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
        """Collect the transforms returned by this object's hooks for ``running_stage``.

        For every member of ``InputTransformPlacement`` the plain hook and, for
        every member of ``ApplyToKeyPrefix``, the key-prefixed hook are resolved
        to their stage-specific names via
        ``DataPipeline._resolve_function_hierarchy``. For each key prefix one
        of the two hooks is selected and called; the callable it returns is
        kept unless it is ``self._identity``. Callables coming from a
        key-prefixed hook are wrapped in ``ApplyToKeys`` with the key's value.

        Args:
            running_stage: The stage to resolve the hooks for.

        Returns:
            A dict mapping each placement value to the single collected callable,
            or to a ``Compose`` of them when there are several. Placements with
            no collected callable are omitted.

        Raises:
            MisconfigurationException: If the plain hook and a key-prefixed hook
                are both overridden and are either both stage-specialized or
                both not, or if a selected hook returns a non-callable.
            AttributeError: If calling a selected hook raises ``AttributeError``;
                re-raised with a hint appended to the message.
        """
        from flash.core.data.data_pipeline import DataPipeline

        transforms_out = {}
        stage = _STAGES_PREFIX[running_stage]

        # iterate over all transform hook names
        for placement in InputTransformPlacement:

            transforms = {}
            transform_name = placement.value

            # The plain hook does not depend on the key prefix: resolve its
            # stage-specific name and override status once per placement.
            resolved_name = DataPipeline._resolve_function_hierarchy(
                transform_name, self, running_stage, InputTransform)
            is_specialized_name = resolved_name.startswith(stage)
            resolve_name_overridden = DataPipeline._is_overridden(
                resolved_name, self, InputTransform)

            # iterate over all prefixes
            for key in ApplyToKeyPrefix:

                # get the resolved hook name for apply to key on the current stage
                resolved_apply_to_key_name = DataPipeline._resolve_function_hierarchy(
                    f"{key}_{transform_name}", self, running_stage,
                    InputTransform)
                # check if resolved hook name for apply to key is specialized
                is_specialized_apply_to_key_name = resolved_apply_to_key_name.startswith(
                    stage)
                # check if it is overridden by the user
                resolved_apply_to_key_name_overridden = DataPipeline._is_overridden(
                    resolved_apply_to_key_name, self, InputTransform)

                if resolve_name_overridden and resolved_apply_to_key_name_overridden:
                    # If both are specialized or both aren't, there is no way to
                    # pick one: raise an exception. Otherwise the
                    # stage-specialized hook takes priority.
                    if not (is_specialized_name
                            ^ is_specialized_apply_to_key_name):
                        raise MisconfigurationException(
                            f"Only one of {resolved_name} or {resolved_apply_to_key_name} can be overridden."
                        )

                    method_name = resolved_name if is_specialized_name else resolved_apply_to_key_name
                else:
                    method_name = resolved_apply_to_key_name if resolved_apply_to_key_name_overridden else resolved_name

                # get associated transform
                try:
                    fn = getattr(self, method_name)()
                except AttributeError as e:
                    # Chain explicitly so the original failure stays attached as the cause.
                    raise AttributeError(
                        str(e) +
                        ". Hint: Call super().__init__(...) after setting all attributes."
                    ) from e

                if not callable(fn):
                    raise MisconfigurationException(
                        f"The hook {method_name} should return a function.")

                # if the default hook is used, it should return identity, skip it.
                if fn is self._identity:
                    continue

                # wrap apply to key hook into `ApplyToKeys` with the associated key.
                if method_name == resolved_apply_to_key_name:
                    fn = ApplyToKeys(key.value, fn)

                # The plain hook can be selected for several keys; keep only its first result.
                if method_name not in transforms:
                    transforms[method_name] = fn

            # store the transforms.
            if transforms:
                collected = list(transforms.values())
                transforms_out[transform_name] = Compose(
                    collected) if len(collected) > 1 else collected[0]

        return transforms_out