Exemple #1
0
class _Preprocessor(torch.nn.Module):
    """Wrap the transform hooks of a ``Preprocess`` object into a single module.

    Calling the module runs, in order, the per-sample transform, the collate
    function and the per-batch transform.

    Inside a worker:
        per_sample_transform: function transforming an individual sample.
            Inside a worker it is actually made of 3 functions:
                * pre_tensor_transform
                * to_tensor_transform
                * post_tensor_transform
        collate: function merging samples into a batch
        per_batch_transform: function transforming an individual batch
            * per_batch_transform

    Inside the main process:
        per_sample_transform: function transforming an individual sample
            * per_sample_transform_on_device
        collate: function merging samples into a batch
        per_batch_transform: function transforming an individual batch
            * per_batch_transform_on_device
    """

    def __init__(
        self,
        preprocess: "Preprocess",
        collate_fn: Callable,
        per_sample_transform: Union[Callable, _Sequential],
        per_batch_transform: Callable,
        stage: RunningStage,
        apply_per_sample_transform: bool = True,
        on_device: bool = False,
    ):
        """Store the transforms and build the context objects used by ``forward``.

        Args:
            preprocess: Object whose ``callbacks`` are wrapped in a
                ``ControlFlow`` and which is handed to every context object.
            collate_fn: Callable merging a list of samples into a batch.
            per_sample_transform: Callable applied to each sample.
            per_batch_transform: Callable applied to the batch.
            stage: Running stage forwarded to every callback.
            apply_per_sample_transform: When ``False``, ``forward`` skips the
                per-sample transform and the collate step.
            on_device: Selects the ``*_on_device`` context names and callbacks.
        """
        super().__init__()
        # The assignment order is kept on purpose: ``torch.nn.Module`` records
        # sub-modules in the order in which they are assigned.
        self.preprocess = preprocess
        self.callback = ControlFlow(self.preprocess.callbacks)
        self.collate_fn = convert_to_modules(collate_fn)
        self.per_sample_transform = convert_to_modules(per_sample_transform)
        self.per_batch_transform = convert_to_modules(per_batch_transform)
        self.apply_per_sample_transform = apply_per_sample_transform
        self.stage = stage
        self.on_device = on_device

        suffix = "_on_device" if self.on_device else ""
        self._current_stage_context = CurrentRunningStageContext(stage, preprocess)
        self._per_sample_transform_context = CurrentFuncContext(
            f"per_sample_transform{suffix}", preprocess
        )
        self._collate_context = CurrentFuncContext("collate", preprocess)
        self._per_batch_transform_context = CurrentFuncContext(
            f"per_batch_transform{suffix}", preprocess
        )

    @staticmethod
    def _extract_metadata(
        samples: List[Dict[str, Any]],
    ) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
        """Pop the metadata entry out of every mapping sample.

        The samples are modified in place. Non-mapping samples contribute
        ``None``.

        Returns:
            The samples and the list of popped metadata, or ``None`` in place
            of the list when no sample carried metadata.
        """
        collected = []
        for sample in samples:
            if isinstance(sample, Mapping):
                collected.append(sample.pop(DefaultDataKeys.METADATA, None))
            else:
                collected.append(None)

        if all(entry is None for entry in collected):
            return samples, None
        return samples, collected

    def forward(self, samples: Sequence[Any]) -> Any:
        """Run per-sample transform, collate and per-batch transform on ``samples``.

        Args:
            samples: A sequence of samples, or a single mapping sample.

        Returns:
            The output of ``per_batch_transform``.
        """
        # Work on a shallow copy of a dict input, so that a dictionary which is
        # stored elsewhere is not the object the transforms operate on.
        if isinstance(samples, dict):
            samples = {key: value for key, value in samples.items()}

        with self._current_stage_context:

            if self.apply_per_sample_transform:
                with self._per_sample_transform_context:
                    # A lone mapping is treated as a batch of one sample.
                    pending = [samples] if isinstance(samples, Mapping) else samples

                    transformed = []
                    for item in pending:
                        item = self.per_sample_transform(item)
                        # Only the on-device callback is fired from here.
                        if self.on_device:
                            self.callback.on_per_sample_transform_on_device(
                                item, self.stage
                            )
                        transformed.append(item)

                with self._collate_context:
                    transformed, metadata = self._extract_metadata(transformed)
                    # Collate functions may or may not accept the metadata as a
                    # second argument; fall back to the single-argument call.
                    try:
                        samples = self.collate_fn(transformed, metadata)
                    except TypeError:
                        samples = self.collate_fn(transformed)
                    if metadata and isinstance(samples, dict):
                        samples[DefaultDataKeys.METADATA] = metadata
                    self.callback.on_collate(samples, self.stage)

            with self._per_batch_transform_context:
                batch = self.per_batch_transform(samples)
                notify = (
                    self.callback.on_per_batch_transform_on_device
                    if self.on_device
                    else self.callback.on_per_batch_transform
                )
                notify(batch, self.stage)
            return batch

    def __str__(self) -> str:
        """Return a multi-line summary of the configured transforms and flags."""
        # todo: define repr function which would take object and string attributes to be shown
        shown = (
            ("per_sample_transform", self.per_sample_transform),
            ("collate_fn", self.collate_fn),
            ("per_batch_transform", self.per_batch_transform),
            ("apply_per_sample_transform", self.apply_per_sample_transform),
            ("on_device", self.on_device),
            ("stage", self.stage),
        )
        # ``str`` is called explicitly: formatting and ``str`` may differ.
        rows = [f"\t({label}): {str(value)}" for label, value in shown]
        return "\n".join(["_Preprocessor:", *rows])
Exemple #2
0
class _InputTransformProcessor:
    """Chain the transform hooks of an `InputTransform` object into one callable.

    Inside a worker:
        per_sample_transform: function transforming an individual sample
        collate: function merging samples into a batch
        per_batch_transform: function transforming an individual batch

    Inside main process:
        per_sample_transform_on_device: function transforming an individual sample
        collate: function merging samples into a batch
        per_batch_transform_on_device: function transforming an individual batch
    """

    def __init__(
        self,
        input_transform: InputTransform,
        collate_fn: Callable,
        per_sample_transform: Callable,
        per_batch_transform: Callable,
        stage: RunningStage,
        apply_per_sample_transform: bool = True,
        on_device: bool = False,
    ):
        """Store the transform callables and wrap the callbacks.

        Args:
            input_transform: Object whose ``callbacks`` (or an empty list when
                falsy) are wrapped in a ``ControlFlow``.
            collate_fn: Called as ``collate_fn(samples, stage)``.
            per_sample_transform: Called as ``per_sample_transform(sample, stage)``.
            per_batch_transform: Called as ``per_batch_transform(batch, stage)``.
            stage: Running stage forwarded to every transform and callback.
            apply_per_sample_transform: When ``False``, the per-sample transform
                and the collate step are skipped.
            on_device: Selects the ``*_on_device`` callbacks and disables the
                ``on_load_sample`` callback.
        """
        super().__init__()
        self.input_transform = input_transform
        self.callback = ControlFlow(self.input_transform.callbacks or [])
        self.collate_fn = collate_fn
        self.per_sample_transform = per_sample_transform
        self.per_batch_transform = per_batch_transform
        self.apply_per_sample_transform = apply_per_sample_transform
        self.stage = stage
        self.on_device = on_device

    def __call__(self, samples: Sequence[Any]) -> Any:
        """Apply the configured transforms to ``samples`` and fire the callbacks.

        Args:
            samples: A list of samples; any non-list value is handled as one
                single sample by the per-sample transform.

        Returns:
            The output of ``per_batch_transform``.
        """
        # NOTE(review): this loop iterates ``samples`` directly, i.e. before a
        # non-list input is wrapped into a one-element list below.
        if not self.on_device:
            for loaded in samples:
                self.callback.on_load_sample(loaded, self.stage)

        if not self.apply_per_sample_transform:
            # The input is passed to the per-batch transform unchanged.
            batch = samples
        else:
            per_sample_inputs = samples if isinstance(samples, list) else [samples]

            # Every sample is transformed before the first callback is fired.
            outputs = [
                self.per_sample_transform(item, self.stage)
                for item in per_sample_inputs
            ]
            for output in outputs:
                notify = (
                    self.callback.on_per_sample_transform_on_device
                    if self.on_device
                    else self.callback.on_per_sample_transform
                )
                notify(output, self.stage)

            batch = self.collate_fn(outputs, self.stage)
            self.callback.on_collate(batch, self.stage)

        result = self.per_batch_transform(batch, self.stage)
        notify_batch = (
            self.callback.on_per_batch_transform_on_device
            if self.on_device
            else self.callback.on_per_batch_transform
        )
        notify_batch(result, self.stage)
        return result

    def __str__(self) -> str:
        """Return a multi-line summary of the configured transforms and flags."""
        # todo: define repr function which would take object and string attributes to be shown
        shown = (
            ("per_sample_transform", self.per_sample_transform),
            ("collate_fn", self.collate_fn),
            ("per_batch_transform", self.per_batch_transform),
            ("apply_per_sample_transform", self.apply_per_sample_transform),
            ("on_device", self.on_device),
            ("stage", self.stage),
        )
        # ``str`` is called explicitly: formatting and ``str`` may differ.
        rows = [f"\t({label}): {str(value)}" for label, value in shown]
        return "\n".join(["_InputTransformProcessor:", *rows])