示例#1
0
 def forward(self, sample: Any):
     """Run ``sample`` through the pre-tensor, to-tensor and post-tensor transforms.

     Args:
         sample: The raw sample to transform.

     Returns:
         Whatever ``post_tensor_transform`` returns for the to-tensor output.

     Raises:
         MisconfigurationException: If ``self.assert_contains_tensor`` is truthy
             and the output of ``to_tensor_transform`` contains no tensor.
     """
     prepared = self.pre_tensor_transform(sample)
     tensorised = self.to_tensor_transform(prepared)
     # The tensor check is opt-in; it runs only between the to-tensor and
     # post-tensor steps.
     if self.assert_contains_tensor and not _contains_any_tensor(tensorised):
         raise MisconfigurationException(
             "When ``to_tensor_transform`` is overriden, "
             "``DataPipeline`` expects the outputs to be ``tensors``")
     return self.post_tensor_transform(tensorised)
示例#2
0
    def pre_tensor_transform(self, samples: Any) -> Any:
        """Load path inputs with ``pil_loader``; pass tensor-bearing input through.

        Args:
            samples: Either something for which ``_contains_any_tensor`` is true
                (returned as-is), a single path string, or a list/tuple whose
                elements are all strings.

        Returns:
            ``samples`` unchanged when it already contains a tensor. Otherwise a
            new list with one ``pil_loader(path)`` result per path, in input
            order. A single string yields a one-element list, and an empty
            list/tuple yields ``[]``.

        Raises:
            MisconfigurationException: If ``samples`` is none of the above.
        """
        # Already-tensorised input needs no loading.
        if _contains_any_tensor(samples):
            return samples

        # Treat a lone path as a batch of one so both cases share the code below.
        if isinstance(samples, str):
            samples = [samples]

        if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples):
            return [pil_loader(path) for path in samples]

        raise MisconfigurationException(
            "The samples should either be a tensor, a list of paths or a path."
        )
示例#3
0
    def forward(self, sample: Any) -> Any:
        """Apply the per-sample transform chain for the current stage.

        The sample goes through ``pre_tensor_transform``, ``to_tensor_transform``
        and ``post_tensor_transform`` in that order. Each step runs inside its own
        context manager, all nested in ``self._current_stage_context``. Each step
        is followed by the matching ``self.callback.on_*`` hook, which receives
        that step's output and ``self.stage``.

        Args:
            sample: The raw sample to transform.

        Returns:
            The output of ``post_tensor_transform``.

        Raises:
            MisconfigurationException: If ``self.assert_contains_tensor`` is truthy
                and the output of ``to_tensor_transform`` contains no tensor.
        """
        # Seed the working value with the input so it still holds the last good
        # value if a step's context manager swallows an exception.
        staged = sample
        with self._current_stage_context:
            with self._pre_tensor_transform_context:
                staged = self.pre_tensor_transform(staged)
                self.callback.on_pre_tensor_transform(staged, self.stage)

            with self._to_tensor_transform_context:
                staged = self.to_tensor_transform(staged)
                self.callback.on_to_tensor_transform(staged, self.stage)

            # Opt-in check. The error text suggests the flag is set when
            # ``to_tensor_transform`` has been overridden.
            if self.assert_contains_tensor and not _contains_any_tensor(staged):
                raise MisconfigurationException(
                    "When ``to_tensor_transform`` is overriden, "
                    "``DataPipeline`` expects the outputs to be ``tensors``"
                )

            with self._post_tensor_transform_context:
                staged = self.post_tensor_transform(staged)
                self.callback.on_post_tensor_transform(staged, self.stage)

            # Returning inside the stage context matches the original if that
            # context suppresses an exception (the method then returns None).
            return staged