def _create_collate_preprocessors(
    self,
    stage: RunningStage,
    collate_fn: Optional[Callable] = None,
) -> Tuple[_PreProcessor, _PreProcessor]:
    original_collate_fn = collate_fn
    if collate_fn is None:
        collate_fn = default_collate

    preprocess: Preprocess = self._preprocess_pipeline
    prefix: str = _STAGES_PREFIX[stage]

    # Resolve the stage-specific override (e.g. ``train_collate``) for every
    # preprocessing hook, falling back to the generic hook name.
    func_names: Dict[str, str] = {
        k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess)
        for k in self.PREPROCESS_FUNCS
    }

    # A user-overridden ``collate`` takes precedence over the provided or default one.
    if self._is_overriden_recursive("collate", preprocess, Preprocess, prefix=prefix):
        collate_fn = getattr(preprocess, func_names["collate"])

    per_batch_transform_overriden: bool = self._is_overriden_recursive(
        "per_batch_transform", preprocess, Preprocess, prefix=prefix
    )

    per_sample_transform_on_device_overriden: bool = self._is_overriden_recursive(
        "per_sample_transform_on_device", preprocess, Preprocess, prefix=prefix
    )

    collate_in_worker_from_transform: Optional[bool] = getattr(
        preprocess, f"_{prefix}_collate_in_worker_from_transform", None
    )

    # Without an explicit placement flag, overriding both hooks is ambiguous:
    # ``per_batch_transform`` requires collation in the worker, while
    # ``per_sample_transform_on_device`` requires uncollated samples on the
    # device, so the collation point cannot be inferred.
    if (
        collate_in_worker_from_transform is None
        and per_batch_transform_overriden
        and per_sample_transform_on_device_overriden
    ):
        raise MisconfigurationException(
            f'{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` '
            f'are mutually exclusive for stage {stage}'
        )

    # Decide where collation happens: an explicit flag wins; otherwise collate
    # on the device whenever the per-sample device transform is overridden.
    if isinstance(collate_in_worker_from_transform, bool):
        worker_collate_fn, device_collate_fn = self._make_collates(
            not collate_in_worker_from_transform, collate_fn
        )
    else:
        worker_collate_fn, device_collate_fn = self._make_collates(
            per_sample_transform_on_device_overriden, collate_fn
        )

    # Unwrap a previously attached preprocessor so collates don't nest.
    worker_collate_fn = (
        worker_collate_fn.collate_fn
        if isinstance(worker_collate_fn, _PreProcessor)
        else worker_collate_fn
    )

    assert_contains_tensor = self._is_overriden_recursive(
        "to_tensor_transform", preprocess, Preprocess, prefix=prefix
    )

    # Worker side: per-sample transforms (pre/to/post tensor), then collate,
    # then ``per_batch_transform``.
    worker_preprocessor = _PreProcessor(
        preprocess,
        worker_collate_fn,
        _Sequential(
            preprocess,
            getattr(preprocess, func_names['pre_tensor_transform']),
            getattr(preprocess, func_names['to_tensor_transform']),
            getattr(preprocess, func_names['post_tensor_transform']),
            stage,
            assert_contains_tensor=assert_contains_tensor,
        ),
        getattr(preprocess, func_names['per_batch_transform']),
        stage,
    )
    worker_preprocessor._original_collate_fn = original_collate_fn

    # Device side: ``per_sample_transform_on_device`` (applied only when
    # collation was deferred to the device), then collate, then
    # ``per_batch_transform_on_device``.
    device_preprocessor = _PreProcessor(
        preprocess,
        device_collate_fn,
        getattr(preprocess, func_names['per_sample_transform_on_device']),
        getattr(preprocess, func_names['per_batch_transform_on_device']),
        stage,
        apply_per_sample_transform=device_collate_fn != self._identity,
        on_device=True,
    )
    return worker_preprocessor, device_preprocessor
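
# ---------------------------------------------------------------------------
# Hedged sketch: ``_make_collates`` is called above but not shown in this
# section. The variant of the function below inlines equivalent branching
# (the worker keeps ``collate_fn`` unless the per-sample device transform
# moves collation to the device), which suggests the helper simply factors
# that logic out. This is an inference, not the confirmed implementation.
# ---------------------------------------------------------------------------
def _make_collates(self, on_device: bool, collate: Callable) -> Tuple[Callable, Callable]:
    """Return ``(worker_collate_fn, device_collate_fn)``.

    When ``on_device`` is True, the worker-side collate becomes a no-op and
    collation is deferred to the device-side preprocessor.
    """
    if on_device:
        return self._identity, collate
    return collate, self._identity
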
def _create_collate_preprocessors(
    self,
    stage: RunningStage,
    collate_fn: Optional[Callable] = None,
) -> Tuple[_PreProcessor, _PreProcessor]:
    original_collate_fn = collate_fn
    if collate_fn is None:
        collate_fn = default_collate

    # Resolve the stage-specific override for every preprocessing hook.
    func_names = {
        k: self._resolve_function_hierarchy(k, self._preprocess_pipeline, stage, Preprocess)
        for k in self.PREPROCESS_FUNCS
    }

    # A user-overridden ``collate`` takes precedence over the provided or default one.
    if self._is_overriden_recursive(
        "collate", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage]
    ):
        collate_fn = getattr(self._preprocess_pipeline, func_names["collate"])

    per_batch_transform_overriden = self._is_overriden_recursive(
        "per_batch_transform", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage]
    )

    per_sample_transform_on_device_overriden = self._is_overriden_recursive(
        "per_sample_transform_on_device", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage]
    )

    if per_batch_transform_overriden and per_sample_transform_on_device_overriden:
        raise MisconfigurationException(
            f'{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` '
            f'are mutually exclusive for stage {stage}'
        )

    # Collation moves to the device only when the per-sample device transform
    # is overridden; otherwise it stays in the worker (the default).
    if per_sample_transform_on_device_overriden:
        worker_collate_fn = self._identity
        device_collate_fn = collate_fn
    else:
        worker_collate_fn = collate_fn
        device_collate_fn = self._identity

    # Unwrap a previously attached preprocessor so collates don't nest.
    worker_collate_fn = (
        worker_collate_fn.collate_fn
        if isinstance(worker_collate_fn, _PreProcessor)
        else worker_collate_fn
    )

    assert_contains_tensor = self._is_overriden_recursive(
        "to_tensor_transform", self._preprocess_pipeline, Preprocess, prefix=_STAGES_PREFIX[stage]
    )

    # Worker side: per-sample transforms, then collate, then ``per_batch_transform``.
    worker_preprocessor = _PreProcessor(
        worker_collate_fn,
        _Sequential(
            getattr(self._preprocess_pipeline, func_names['pre_tensor_transform']),
            getattr(self._preprocess_pipeline, func_names['to_tensor_transform']),
            getattr(self._preprocess_pipeline, func_names['post_tensor_transform']),
            assert_contains_tensor=assert_contains_tensor,
        ),
        getattr(self._preprocess_pipeline, func_names['per_batch_transform']),
        stage,
    )
    worker_preprocessor._original_collate_fn = original_collate_fn

    # Device side: ``per_sample_transform_on_device`` (only when collation was
    # deferred), then collate, then ``per_batch_transform_on_device``.
    device_preprocessor = _PreProcessor(
        device_collate_fn,
        getattr(self._preprocess_pipeline, func_names['per_sample_transform_on_device']),
        getattr(self._preprocess_pipeline, func_names['per_batch_transform_on_device']),
        stage,
        apply_per_sample_transform=device_collate_fn != self._identity,
    )
    return worker_preprocessor, device_preprocessor
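
# ---------------------------------------------------------------------------
# Hedged usage sketch: ``pipeline``, ``dataset``, and the manual loop are
# illustrative placeholders, not part of this section. The two returned
# preprocessors are meant to straddle the worker/device boundary: the
# worker-side one acts as the DataLoader ``collate_fn`` inside worker
# processes, while the device-side one runs after the batch reaches the
# accelerator. In practice the framework attaches these via model hooks
# rather than a hand-written loop like this.
# ---------------------------------------------------------------------------
from torch.utils.data import DataLoader

worker_preprocessor, device_preprocessor = pipeline._create_collate_preprocessors(
    stage=RunningStage.TRAINING
)

loader = DataLoader(dataset, batch_size=32, collate_fn=worker_preprocessor)
for batch in loader:
    # After transfer to the accelerator: per-sample-on-device transform
    # (if collation was deferred), collate, then per-batch-on-device transform.
    batch = device_preprocessor(batch)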