Example #1
0
    def __init__(
        self,
        preprocess: "Preprocess",
        pre_tensor_transform: Optional[Callable],
        to_tensor_transform: Optional[Callable],
        post_tensor_transform: Callable,
        stage: RunningStage,
        assert_contains_tensor: bool = False,
    ):
        """Bundle the three per-sample tensor transforms of ``preprocess`` for ``stage``.

        Args:
            preprocess: The ``Preprocess`` object whose transforms are wrapped.
            pre_tensor_transform: Transform run before tensor conversion (may be ``None``).
            to_tensor_transform: Transform converting a sample to tensors (may be ``None``).
            post_tensor_transform: Transform run after tensor conversion.
            stage: The ``RunningStage`` these transforms belong to.
            assert_contains_tensor: Flag stored for later validation of transformed samples.
        """
        super().__init__()
        self.preprocess = preprocess
        self.callback = ControlFlow(self.preprocess.callbacks)
        # Each raw callable is wrapped via ``convert_to_modules`` before being stored.
        self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
        self.to_tensor_transform = convert_to_modules(to_tensor_transform)
        self.post_tensor_transform = convert_to_modules(post_tensor_transform)
        self.stage = stage
        self.assert_contains_tensor = assert_contains_tensor

        # Context managers recording the active stage / hook on the preprocess object.
        self._current_stage_context = CurrentRunningStageContext(stage, preprocess, reset=False)
        for hook in ("pre_tensor_transform", "to_tensor_transform", "post_tensor_transform"):
            setattr(self, f"_{hook}_context", CurrentFuncContext(hook, preprocess))
Example #2
0
    def __init__(
        self,
        preprocess: "Preprocess",
        collate_fn: Callable,
        per_sample_transform: Union[Callable, _Sequential],
        per_batch_transform: Callable,
        stage: RunningStage,
        apply_per_sample_transform: bool = True,
        on_device: bool = False,
    ):
        """Store the collate, per-sample and per-batch transforms used for ``stage``.

        Args:
            preprocess: The ``Preprocess`` object whose hooks are wrapped.
            collate_fn: Function merging samples into a batch.
            per_sample_transform: Transform applied to each individual sample.
            per_batch_transform: Transform applied to a whole batch.
            stage: The ``RunningStage`` these transforms belong to.
            apply_per_sample_transform: Whether the per-sample step should run at all.
            on_device: Whether these transforms run on-device (suffixes hook names).
        """
        super().__init__()
        self.preprocess = preprocess
        self.callback = ControlFlow(self.preprocess.callbacks)
        # Wrap the raw callables via ``convert_to_modules`` before storing them.
        self.collate_fn = convert_to_modules(collate_fn)
        self.per_sample_transform = convert_to_modules(per_sample_transform)
        self.per_batch_transform = convert_to_modules(per_batch_transform)
        self.apply_per_sample_transform = apply_per_sample_transform
        self.stage = stage
        self.on_device = on_device

        # Hook names get an "_on_device" suffix when ``on_device`` is set.
        suffix = "_on_device" if self.on_device else ""
        self._current_stage_context = CurrentRunningStageContext(stage, preprocess)
        self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform{suffix}", preprocess)
        self._collate_context = CurrentFuncContext("collate", preprocess)
        self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{suffix}", preprocess)
Example #3
0
 def __init__(
     self,
     input_transform: InputTransform,
     collate_fn: Callable,
     per_sample_transform: Callable,
     per_batch_transform: Callable,
     stage: RunningStage,
     apply_per_sample_transform: bool = True,
     on_device: bool = False,
 ):
     """Record the transforms and the callback dispatcher for ``stage``.

     Args:
         input_transform: The ``InputTransform`` whose hooks are wrapped.
         collate_fn: Function merging samples into a batch.
         per_sample_transform: Transform applied to each individual sample.
         per_batch_transform: Transform applied to a whole batch.
         stage: The ``RunningStage`` these transforms belong to.
         apply_per_sample_transform: Whether the per-sample step should run at all.
         on_device: Whether these transforms run on-device.
     """
     super().__init__()
     self.input_transform = input_transform
     # ``callbacks`` may be falsy; normalise to an empty list for ControlFlow.
     self.callback = ControlFlow(self.input_transform.callbacks or [])
     self.collate_fn = collate_fn
     self.per_sample_transform = per_sample_transform
     self.per_batch_transform = per_batch_transform
     self.stage = stage
     self.apply_per_sample_transform = apply_per_sample_transform
     self.on_device = on_device
Example #4
0
    def __init__(
        self,
        deserializer: "Deserializer",
        preprocess: "Preprocess",
        pre_tensor_transform: Callable,
        to_tensor_transform: Callable,
    ):
        """Chain deserialization with the two prediction-time tensor transforms.

        Args:
            deserializer: Callable turning a raw payload into a sample.
            preprocess: The ``Preprocess`` object whose hooks are wrapped.
            pre_tensor_transform: Transform run before tensor conversion.
            to_tensor_transform: Transform converting a sample to tensors.
        """
        super().__init__()
        self.preprocess = preprocess
        self.callback = ControlFlow(self.preprocess.callbacks)
        # Wrap the raw callables via ``convert_to_modules`` before storing them.
        self.deserializer = convert_to_modules(deserializer)
        self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
        self.to_tensor_transform = convert_to_modules(to_tensor_transform)

        # Deserialization is prediction-only, so the stage context is fixed.
        self._current_stage_context = CurrentRunningStageContext(RunningStage.PREDICTING, preprocess, reset=False)
        for hook in ("pre_tensor_transform", "to_tensor_transform"):
            setattr(self, f"_{hook}_context", CurrentFuncContext(hook, preprocess))
Example #5
0
class _DeserializeProcessor(torch.nn.Module):
    """Deserialize a raw sample and run the prediction-time tensor transforms.

    ``forward`` applies, in order: ``deserializer`` -> ``pre_tensor_transform``
    -> ``to_tensor_transform``, reporting each transform step to the registered
    callbacks with the ``PREDICTING`` stage.
    """

    def __init__(
        self,
        deserializer: "Deserializer",
        preprocess: "Preprocess",
        pre_tensor_transform: Callable,
        to_tensor_transform: Callable,
    ):
        super().__init__()
        self.preprocess = preprocess
        self.callback = ControlFlow(self.preprocess.callbacks)
        # Wrap the raw callables via ``convert_to_modules`` before storing them.
        self.deserializer = convert_to_modules(deserializer)
        self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
        self.to_tensor_transform = convert_to_modules(to_tensor_transform)

        # Deserialization is prediction-only, so the stage context is fixed.
        self._current_stage_context = CurrentRunningStageContext(RunningStage.PREDICTING, preprocess, reset=False)
        for hook in ("pre_tensor_transform", "to_tensor_transform"):
            setattr(self, f"_{hook}_context", CurrentFuncContext(hook, preprocess))

    def forward(self, sample: str):
        """Deserialize ``sample`` and return its tensor-transformed version."""
        # Turn the serialized payload into a sample the transforms understand.
        sample = self.deserializer(sample)

        with self._current_stage_context:
            with self._pre_tensor_transform_context:
                sample = self.pre_tensor_transform(sample)
                self.callback.on_pre_tensor_transform(sample, RunningStage.PREDICTING)

            with self._to_tensor_transform_context:
                sample = self.to_tensor_transform(sample)
                self.callback.on_to_tensor_transform(sample, RunningStage.PREDICTING)

        return sample
Example #6
0
class _Sequential(torch.nn.Module):
    """Chain the three per-sample transforms used by the ``_Preprocessor``
    ``per_sample_transform`` function.

    Applied in order:

    1. ``pre_tensor_transform``
    2. ``to_tensor_transform``
    3. ``post_tensor_transform``
    """

    def __init__(
        self,
        preprocess: "Preprocess",
        pre_tensor_transform: Optional[Callable],
        to_tensor_transform: Optional[Callable],
        post_tensor_transform: Callable,
        stage: RunningStage,
        assert_contains_tensor: bool = False,
    ):
        super().__init__()
        self.preprocess = preprocess
        self.callback = ControlFlow(self.preprocess.callbacks)
        # Wrap the raw callables via ``convert_to_modules`` before storing them.
        self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
        self.to_tensor_transform = convert_to_modules(to_tensor_transform)
        self.post_tensor_transform = convert_to_modules(post_tensor_transform)
        self.stage = stage
        self.assert_contains_tensor = assert_contains_tensor

        # Context managers recording the active stage / hook on the preprocess object.
        self._current_stage_context = CurrentRunningStageContext(stage, preprocess, reset=False)
        for hook in ("pre_tensor_transform", "to_tensor_transform", "post_tensor_transform"):
            setattr(self, f"_{hook}_context", CurrentFuncContext(hook, preprocess))

    def forward(self, sample: Any) -> Any:
        """Apply the three transforms to ``sample``, firing callbacks after each."""
        self.callback.on_load_sample(sample, self.stage)

        with self._current_stage_context:
            if self.pre_tensor_transform is not None:
                with self._pre_tensor_transform_context:
                    sample = self.pre_tensor_transform(sample)
                    self.callback.on_pre_tensor_transform(sample, self.stage)

            if self.to_tensor_transform is not None:
                with self._to_tensor_transform_context:
                    sample = self.to_tensor_transform(sample)
                    self.callback.on_to_tensor_transform(sample, self.stage)

                # Optionally verify that tensor conversion actually produced tensors.
                if self.assert_contains_tensor and not _contains_any_tensor(sample):
                    raise MisconfigurationException(
                        "When ``to_tensor_transform`` is overriden, "
                        "``DataPipeline`` expects the outputs to be ``tensors``"
                    )

            with self._post_tensor_transform_context:
                sample = self.post_tensor_transform(sample)
                self.callback.on_post_tensor_transform(sample, self.stage)

        return sample

    def __str__(self) -> str:
        parts = [
            f"{self.__class__.__name__}:",
            f"\t(pre_tensor_transform): {str(self.pre_tensor_transform)}",
            f"\t(to_tensor_transform): {str(self.to_tensor_transform)}",
            f"\t(post_tensor_transform): {str(self.post_tensor_transform)}",
            f"\t(assert_contains_tensor): {str(self.assert_contains_tensor)}",
            f"\t(stage): {str(self.stage)}",
        ]
        return "\n".join(parts)
Example #7
0
class _Preprocessor(torch.nn.Module):
    """Encapsulate the transform hooks of a ``Preprocess`` object.

    Inside a worker:
        per_sample_transform: transforms an individual sample; composed of
            ``pre_tensor_transform``, ``to_tensor_transform`` and
            ``post_tensor_transform``.
        collate: merges samples into a batch.
        per_batch_transform: transforms an individual batch.

    Inside the main process:
        per_sample_transform: ``per_sample_transform_on_device``
        collate: merges samples into a batch.
        per_batch_transform: ``per_batch_transform_on_device``
    """

    def __init__(
        self,
        preprocess: "Preprocess",
        collate_fn: Callable,
        per_sample_transform: Union[Callable, _Sequential],
        per_batch_transform: Callable,
        stage: RunningStage,
        apply_per_sample_transform: bool = True,
        on_device: bool = False,
    ):
        super().__init__()
        self.preprocess = preprocess
        self.callback = ControlFlow(self.preprocess.callbacks)
        # Wrap the raw callables via ``convert_to_modules`` before storing them.
        self.collate_fn = convert_to_modules(collate_fn)
        self.per_sample_transform = convert_to_modules(per_sample_transform)
        self.per_batch_transform = convert_to_modules(per_batch_transform)
        self.apply_per_sample_transform = apply_per_sample_transform
        self.stage = stage
        self.on_device = on_device

        # Hook names get an "_on_device" suffix when ``on_device`` is set.
        suffix = "_on_device" if self.on_device else ""
        self._current_stage_context = CurrentRunningStageContext(stage, preprocess)
        self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform{suffix}", preprocess)
        self._collate_context = CurrentFuncContext("collate", preprocess)
        self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{suffix}", preprocess)

    @staticmethod
    def _extract_metadata(
        samples: List[Dict[str, Any]],
    ) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
        """Pop ``METADATA`` out of every mapping sample.

        Returns the (mutated) samples and a parallel metadata list, or ``None``
        in place of the list when no sample carried any metadata.
        """
        metadata = []
        for sample in samples:
            entry = sample.pop(DefaultDataKeys.METADATA, None) if isinstance(sample, Mapping) else None
            metadata.append(entry)
        if all(entry is None for entry in metadata):
            return samples, None
        return samples, metadata

    def forward(self, samples: Sequence[Any]) -> Any:
        """Run per-sample transforms, collation and per-batch transforms on ``samples``."""
        # Shallow-copy dict inputs so downstream mutation cannot leak back to the
        # caller's stored dictionary.
        if isinstance(samples, dict):
            samples = dict(samples.items())

        with self._current_stage_context:

            if self.apply_per_sample_transform:
                with self._per_sample_transform_context:
                    # A single mapping is treated as a one-sample batch.
                    if isinstance(samples, Mapping):
                        samples = [samples]

                    transformed = []
                    for sample in samples:
                        sample = self.per_sample_transform(sample)
                        if self.on_device:
                            self.callback.on_per_sample_transform_on_device(sample, self.stage)
                        transformed.append(sample)

                samples = list(transformed)

                with self._collate_context:
                    samples, metadata = self._extract_metadata(samples)
                    # Prefer the metadata-aware signature; fall back for plain
                    # collate functions that accept only the samples.
                    try:
                        samples = self.collate_fn(samples, metadata)
                    except TypeError:
                        samples = self.collate_fn(samples)
                    # Re-attach metadata onto dict batches so it survives collation.
                    if metadata and isinstance(samples, dict):
                        samples[DefaultDataKeys.METADATA] = metadata
                    self.callback.on_collate(samples, self.stage)

            with self._per_batch_transform_context:
                samples = self.per_batch_transform(samples)
                if self.on_device:
                    self.callback.on_per_batch_transform_on_device(samples, self.stage)
                else:
                    self.callback.on_per_batch_transform(samples, self.stage)

            return samples

    def __str__(self) -> str:
        # todo: define repr function which would take object and string attributes to be shown
        parts = [
            "_Preprocessor:",
            f"\t(per_sample_transform): {str(self.per_sample_transform)}",
            f"\t(collate_fn): {str(self.collate_fn)}",
            f"\t(per_batch_transform): {str(self.per_batch_transform)}",
            f"\t(apply_per_sample_transform): {str(self.apply_per_sample_transform)}",
            f"\t(on_device): {str(self.on_device)}",
            f"\t(stage): {str(self.stage)}",
        ]
        return "\n".join(parts)
Example #8
0
class _InputTransformProcessor:
    """Encapsulate the transform hooks of an ``InputTransform`` object.

    Inside a worker:
        per_sample_transform: transforms an individual sample.
        collate: merges samples into a batch.
        per_batch_transform: transforms an individual batch.

    Inside the main process:
        per_sample_transform_on_device: transforms an individual sample.
        collate: merges samples into a batch.
        per_batch_transform_on_device: transforms an individual batch.
    """

    def __init__(
        self,
        input_transform: InputTransform,
        collate_fn: Callable,
        per_sample_transform: Callable,
        per_batch_transform: Callable,
        stage: RunningStage,
        apply_per_sample_transform: bool = True,
        on_device: bool = False,
    ):
        super().__init__()
        self.input_transform = input_transform
        # ``callbacks`` may be falsy; normalise to an empty list for ControlFlow.
        self.callback = ControlFlow(self.input_transform.callbacks or [])
        self.collate_fn = collate_fn
        self.per_sample_transform = per_sample_transform
        self.per_batch_transform = per_batch_transform
        self.stage = stage
        self.apply_per_sample_transform = apply_per_sample_transform
        self.on_device = on_device

    def __call__(self, samples: Sequence[Any]) -> Any:
        """Transform, collate and batch-transform ``samples``, firing callbacks throughout."""
        # Load-sample callbacks fire only on the worker side.
        if not self.on_device:
            for sample in samples:
                self.callback.on_load_sample(sample, self.stage)

        if self.apply_per_sample_transform:
            # Normalise a single sample into a one-element list.
            list_samples = samples if isinstance(samples, list) else [samples]

            transformed_samples = [self.per_sample_transform(s, self.stage) for s in list_samples]

            # Pick the callback matching where the transform ran.
            on_transform = (
                self.callback.on_per_sample_transform_on_device
                if self.on_device
                else self.callback.on_per_sample_transform
            )
            for sample in transformed_samples:
                on_transform(sample, self.stage)

            collated_samples = self.collate_fn(transformed_samples, self.stage)
            self.callback.on_collate(collated_samples, self.stage)
        else:
            # Per-sample step disabled: the input is passed straight to batching.
            collated_samples = samples

        batch = self.per_batch_transform(collated_samples, self.stage)
        if self.on_device:
            self.callback.on_per_batch_transform_on_device(batch, self.stage)
        else:
            self.callback.on_per_batch_transform(batch, self.stage)
        return batch

    def __str__(self) -> str:
        # todo: define repr function which would take object and string attributes to be shown
        parts = [
            "_InputTransformProcessor:",
            f"\t(per_sample_transform): {str(self.per_sample_transform)}",
            f"\t(collate_fn): {str(self.collate_fn)}",
            f"\t(per_batch_transform): {str(self.per_batch_transform)}",
            f"\t(apply_per_sample_transform): {str(self.apply_per_sample_transform)}",
            f"\t(on_device): {str(self.on_device)}",
            f"\t(stage): {str(self.stage)}",
        ]
        return "\n".join(parts)