def _create_collate_preprocessors(
        self,
        stage: RunningStage,
        collate_fn: Optional[Callable] = None,
    ) -> Tuple[_PreProcessor, _PreProcessor]:
        """Build the worker-side and device-side ``_PreProcessor`` pair for ``stage``.

        Every hook named in ``self.PREPROCESS_FUNCS`` is resolved on ``self._preprocess_pipeline``
        through ``self._resolve_function_hierarchy``. The split of the collate function between the
        two processors comes from ``self._make_collates``, whose first argument is:

        * ``not _<prefix>_collate_in_worker_from_transform`` when that preprocess attribute is a ``bool``;
        * otherwise, whether ``per_sample_transform_on_device`` is overridden for this stage.

        NOTE(review): this chunk also contains an older method with the same name built against a
        different ``_PreProcessor``/``_Sequential`` signature; if both end up in one class, whichever is
        defined last wins — confirm which one is intended.

        Args:
            stage: Running stage whose (possibly stage-prefixed) hooks are resolved.
            collate_fn: Collate function used when the preprocess does not override ``collate``.
                Falls back to ``default_collate`` when ``None``. The value passed in (possibly ``None``)
                is stored on the worker processor as ``_original_collate_fn``.

        Returns:
            ``(worker_preprocessor, device_preprocessor)``.

        Raises:
            MisconfigurationException: if the transform does not set an explicit collate location and
                both ``per_batch_transform`` and ``per_sample_transform_on_device`` are overridden.
        """
        # Keep the caller's value (possibly None) so it can be recorded on the worker processor.
        original_collate_fn = collate_fn
        if collate_fn is None:
            collate_fn = default_collate

        preprocess: Preprocess = self._preprocess_pipeline
        prefix: str = _STAGES_PREFIX[stage]

        func_names: Dict[str, str] = {
            k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess)
            for k in self.PREPROCESS_FUNCS
        }

        # A ``collate`` override on the preprocess replaces the given/default collate function.
        # (No re-annotation here: ``collate_fn`` is already declared as a parameter.)
        if self._is_overriden_recursive("collate", preprocess, Preprocess, prefix=prefix):
            collate_fn = getattr(preprocess, func_names["collate"])

        per_batch_transform_overriden: bool = self._is_overriden_recursive(
            "per_batch_transform", preprocess, Preprocess, prefix=prefix
        )

        per_sample_transform_on_device_overriden: bool = self._is_overriden_recursive(
            "per_sample_transform_on_device", preprocess, Preprocess, prefix=prefix
        )

        # Explicit collate-location choice exposed by the preprocess; ``None`` when it is not set.
        collate_in_worker_from_transform: Optional[bool] = getattr(
            preprocess, f"_{prefix}_collate_in_worker_from_transform", None
        )

        # The two hooks conflict only when no explicit collate location was provided.
        if (
            collate_in_worker_from_transform is None
            and per_batch_transform_overriden
            and per_sample_transform_on_device_overriden
        ):
            raise MisconfigurationException(
                f'{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` '
                f'are mutual exclusive for stage {stage}'
            )

        # First argument to ``_make_collates``: presumably "collate on device" (it is the negation of
        # "collate in worker" when the transform specifies it explicitly).
        if isinstance(collate_in_worker_from_transform, bool):
            collate_on_device = not collate_in_worker_from_transform
        else:
            collate_on_device = per_sample_transform_on_device_overriden
        worker_collate_fn, device_collate_fn = self._make_collates(collate_on_device, collate_fn)

        # If ``_make_collates`` handed back a ``_PreProcessor`` for the worker, unwrap its collate function.
        if isinstance(worker_collate_fn, _PreProcessor):
            worker_collate_fn = worker_collate_fn.collate_fn

        assert_contains_tensor = self._is_overriden_recursive(
            "to_tensor_transform", preprocess, Preprocess, prefix=prefix
        )

        worker_preprocessor = _PreProcessor(
            preprocess,
            worker_collate_fn,
            _Sequential(
                preprocess,
                getattr(preprocess, func_names['pre_tensor_transform']),
                getattr(preprocess, func_names['to_tensor_transform']),
                getattr(preprocess, func_names['post_tensor_transform']),
                stage,
                assert_contains_tensor=assert_contains_tensor,
            ),
            getattr(preprocess, func_names['per_batch_transform']),
            stage,
        )
        worker_preprocessor._original_collate_fn = original_collate_fn

        device_preprocessor = _PreProcessor(
            preprocess,
            device_collate_fn,
            getattr(preprocess, func_names['per_sample_transform_on_device']),
            getattr(preprocess, func_names['per_batch_transform_on_device']),
            stage,
            # ``!=`` rather than ``is not``: if ``_identity`` is an ordinary method, each attribute
            # access creates a new bound-method object that compares equal but is not identical.
            apply_per_sample_transform=device_collate_fn != self._identity,
            on_device=True,
        )
        return worker_preprocessor, device_preprocessor
    def _create_collate_preprocessors(
        self,
        stage: RunningStage,
        collate_fn: Optional[Callable] = None,
    ) -> Tuple[_PreProcessor, _PreProcessor]:
        """Build the worker-side and device-side ``_PreProcessor`` pair for ``stage``.

        Every hook named in ``self.PREPROCESS_FUNCS`` is resolved on ``self._preprocess_pipeline``
        through ``self._resolve_function_hierarchy``. The collate function is given to the worker
        processor (and the device processor gets ``self._identity``) unless
        ``per_sample_transform_on_device`` is overridden for the stage, in which case the roles swap.

        NOTE(review): this chunk also contains a newer method with the same name built against a
        different ``_PreProcessor``/``_Sequential`` signature; if both end up in one class, whichever is
        defined last wins — confirm which one is intended.

        Args:
            stage: Running stage whose (possibly stage-prefixed) hooks are resolved.
            collate_fn: Collate function used when the preprocess does not override ``collate``.
                Falls back to ``default_collate`` when ``None``. The value passed in (possibly ``None``)
                is stored on the worker processor as ``_original_collate_fn``.

        Returns:
            ``(worker_preprocessor, device_preprocessor)``.

        Raises:
            MisconfigurationException: if both ``per_batch_transform`` and
                ``per_sample_transform_on_device`` are overridden for ``stage``.
        """
        # Keep the caller's value (possibly None) so it can be recorded on the worker processor.
        original_collate_fn = collate_fn
        if collate_fn is None:
            collate_fn = default_collate

        preprocess = self._preprocess_pipeline
        prefix = _STAGES_PREFIX[stage]

        func_names = {
            k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess)
            for k in self.PREPROCESS_FUNCS
        }

        # A ``collate`` override on the preprocess replaces the given/default collate function.
        if self._is_overriden_recursive("collate", preprocess, Preprocess, prefix=prefix):
            collate_fn = getattr(preprocess, func_names["collate"])

        per_batch_transform_overriden = self._is_overriden_recursive(
            "per_batch_transform", preprocess, Preprocess, prefix=prefix)

        per_sample_transform_on_device_overriden = self._is_overriden_recursive(
            "per_sample_transform_on_device", preprocess, Preprocess, prefix=prefix)

        # The message names the hook that is actually checked (it previously referred to a
        # non-existent `gpu_per_sample_transform`).
        if per_batch_transform_overriden and per_sample_transform_on_device_overriden:
            raise MisconfigurationException(
                f'{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` '
                f'are mutual exclusive for stage {stage}')

        if per_sample_transform_on_device_overriden:
            # Defer collation to the device processor, which then applies
            # ``per_sample_transform_on_device`` (see ``apply_per_sample_transform`` below).
            worker_collate_fn = self._identity
            device_collate_fn = collate_fn
        else:
            # Covers both "per_batch_transform overridden" and "neither overridden".
            worker_collate_fn = collate_fn
            device_collate_fn = self._identity

        # If the worker collate function is itself a ``_PreProcessor``, unwrap its collate function.
        if isinstance(worker_collate_fn, _PreProcessor):
            worker_collate_fn = worker_collate_fn.collate_fn

        assert_contains_tensor = self._is_overriden_recursive(
            "to_tensor_transform", preprocess, Preprocess, prefix=prefix)

        worker_preprocessor = _PreProcessor(
            worker_collate_fn,
            _Sequential(
                getattr(preprocess, func_names['pre_tensor_transform']),
                getattr(preprocess, func_names['to_tensor_transform']),
                getattr(preprocess, func_names['post_tensor_transform']),
                assert_contains_tensor=assert_contains_tensor,
            ),
            getattr(preprocess, func_names['per_batch_transform']),
            stage,
        )
        worker_preprocessor._original_collate_fn = original_collate_fn

        device_preprocessor = _PreProcessor(
            device_collate_fn,
            getattr(preprocess, func_names['per_sample_transform_on_device']),
            getattr(preprocess, func_names['per_batch_transform_on_device']),
            stage,
            # ``!=`` rather than ``is not``: if ``_identity`` is an ordinary method, each attribute
            # access creates a new bound-method object that compares equal but is not identical.
            apply_per_sample_transform=device_collate_fn != self._identity,
        )
        return worker_preprocessor, device_preprocessor