Example no. 1
0
class _Sequential(torch.nn.Module):
    """Chain three sample-level transforms for the ``Preprocess`` ``per_sample_transform`` hook.

    The transforms are applied in order:

    1. ``pre_tensor_transform``
    2. ``to_tensor_transform``
    3. ``post_tensor_transform``

    Args:
        preprocess: The owning ``Preprocess`` object; its callbacks are wrapped in a
            ``ControlFlow`` and notified after each transform.
        pre_tensor_transform: Optional transform applied to the raw sample.
        to_tensor_transform: Optional transform expected to produce tensor(s).
        post_tensor_transform: Transform applied to the tensor sample.
        stage: The ``RunningStage`` this chain is bound to.
        assert_contains_tensor: When ``True``, raise if the output of
            ``to_tensor_transform`` contains no tensor.
    """

    def __init__(
        self,
        preprocess: "Preprocess",
        pre_tensor_transform: Optional[Callable],
        to_tensor_transform: Optional[Callable],
        post_tensor_transform: Callable,
        stage: RunningStage,
        assert_contains_tensor: bool = False,
    ):
        super().__init__()
        self.preprocess = preprocess
        self.callback = ControlFlow(self.preprocess.callbacks)
        # ``convert_to_modules`` wraps the callables so they register as submodules.
        self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
        self.to_tensor_transform = convert_to_modules(to_tensor_transform)
        self.post_tensor_transform = convert_to_modules(post_tensor_transform)
        self.stage = stage
        self.assert_contains_tensor = assert_contains_tensor

        # Context managers that record the current stage / current hook name on the
        # preprocess object while each transform runs.
        self._current_stage_context = CurrentRunningStageContext(stage,
                                                                 preprocess,
                                                                 reset=False)
        self._pre_tensor_transform_context = CurrentFuncContext(
            "pre_tensor_transform", preprocess)
        self._to_tensor_transform_context = CurrentFuncContext(
            "to_tensor_transform", preprocess)
        self._post_tensor_transform_context = CurrentFuncContext(
            "post_tensor_transform", preprocess)

    def forward(self, sample: Any) -> Any:
        """Apply the three transforms to ``sample``, firing the matching callbacks.

        Raises:
            MisconfigurationException: If ``assert_contains_tensor`` is set and the
                output of ``to_tensor_transform`` contains no tensor.
        """
        self.callback.on_load_sample(sample, self.stage)

        with self._current_stage_context:
            if self.pre_tensor_transform is not None:
                with self._pre_tensor_transform_context:
                    sample = self.pre_tensor_transform(sample)
                    self.callback.on_pre_tensor_transform(sample, self.stage)

            if self.to_tensor_transform is not None:
                with self._to_tensor_transform_context:
                    sample = self.to_tensor_transform(sample)
                    self.callback.on_to_tensor_transform(sample, self.stage)

                # Fixed typo in the error message: "overriden" -> "overridden".
                if self.assert_contains_tensor and not _contains_any_tensor(sample):
                    raise MisconfigurationException(
                        "When ``to_tensor_transform`` is overridden, "
                        "``DataPipeline`` expects the outputs to be ``tensors``"
                    )

            with self._post_tensor_transform_context:
                sample = self.post_tensor_transform(sample)
                self.callback.on_post_tensor_transform(sample, self.stage)

            return sample

    def __str__(self) -> str:
        return (
            f"{self.__class__.__name__}:\n"
            f"\t(pre_tensor_transform): {str(self.pre_tensor_transform)}\n"
            f"\t(to_tensor_transform): {str(self.to_tensor_transform)}\n"
            f"\t(post_tensor_transform): {str(self.post_tensor_transform)}\n"
            f"\t(assert_contains_tensor): {str(self.assert_contains_tensor)}\n"
            f"\t(stage): {str(self.stage)}")
Example no. 2
0
class _InputTransformProcessor:
    """
    This class is used to encapsulate the following functions of an `InputTransform` Object:
    Inside a worker:
        per_sample_transform: Function to transform an individual sample
        collate: Function to merge sample into a batch
        per_batch_transform: Function to transform an individual batch

    Inside main process:
        per_sample_transform_on_device: Function to transform an individual sample
        collate: Function to merge sample into a batch
        per_batch_transform_on_device: Function to transform an individual batch
    """

    def __init__(
        self,
        input_transform: InputTransform,
        collate_fn: Callable,
        per_sample_transform: Callable,
        per_batch_transform: Callable,
        stage: RunningStage,
        apply_per_sample_transform: bool = True,
        on_device: bool = False,
    ):
        super().__init__()
        self.input_transform = input_transform
        self.callback = ControlFlow(self.input_transform.callbacks or [])
        self.stage = stage
        self.on_device = on_device
        self.apply_per_sample_transform = apply_per_sample_transform
        # Stage-aware callables supplied by the input transform.
        self.collate_fn = collate_fn
        self.per_sample_transform = per_sample_transform
        self.per_batch_transform = per_batch_transform

    def __call__(self, samples: Sequence[Any]) -> Any:
        """Run per-sample transform, collate, then per-batch transform over ``samples``."""
        stage = self.stage

        # Worker-side processing notifies the load-sample callback for each raw sample.
        if not self.on_device:
            for raw in samples:
                self.callback.on_load_sample(raw, stage)

        if not self.apply_per_sample_transform:
            # Samples are assumed to already be collated.
            batch = samples
        else:
            as_list = samples if isinstance(samples, list) else [samples]

            transformed = [self.per_sample_transform(raw, stage) for raw in as_list]

            per_sample_hook = (
                self.callback.on_per_sample_transform_on_device
                if self.on_device
                else self.callback.on_per_sample_transform
            )
            for item in transformed:
                per_sample_hook(item, stage)

            batch = self.collate_fn(transformed, stage)
            self.callback.on_collate(batch, stage)

        out = self.per_batch_transform(batch, stage)
        per_batch_hook = (
            self.callback.on_per_batch_transform_on_device
            if self.on_device
            else self.callback.on_per_batch_transform
        )
        per_batch_hook(out, stage)
        return out

    def __str__(self) -> str:
        # todo: define repr function which would take object and string attributes to be shown
        lines = [
            "_InputTransformProcessor:",
            f"\t(per_sample_transform): {str(self.per_sample_transform)}",
            f"\t(collate_fn): {str(self.collate_fn)}",
            f"\t(per_batch_transform): {str(self.per_batch_transform)}",
            f"\t(apply_per_sample_transform): {str(self.apply_per_sample_transform)}",
            f"\t(on_device): {str(self.on_device)}",
            f"\t(stage): {str(self.stage)}",
        ]
        return "\n".join(lines)