def forward(self, sample: Any):
    """Push a single sample through the three per-sample transform stages.

    Stages run in this order: ``pre_tensor_transform``,
    ``to_tensor_transform``, then ``post_tensor_transform``. When
    ``self.assert_contains_tensor`` is truthy, the ``to_tensor_transform``
    output is checked with ``_contains_any_tensor`` before the last stage.

    Args:
        sample: The raw sample to transform.

    Returns:
        Whatever ``post_tensor_transform`` returns.

    Raises:
        MisconfigurationException: If ``assert_contains_tensor`` is set and
            the ``to_tensor_transform`` output contains no tensor.
    """
    staged = self.pre_tensor_transform(sample)
    staged = self.to_tensor_transform(staged)
    # The tensor scan only happens when the check is enabled (short-circuit).
    if self.assert_contains_tensor and not _contains_any_tensor(staged):
        raise MisconfigurationException(
            "When ``to_tensor_transform`` is overriden, "
            "``DataPipeline`` expects the outputs to be ``tensors``")
    return self.post_tensor_transform(staged)
def pre_tensor_transform(self, samples: Any) -> Any:
    """Load image paths into PIL images; pass tensor-bearing inputs through.

    Args:
        samples: Either something containing a tensor (returned unchanged),
            a single path, or a list/tuple of paths. A path may be a ``str``
            or any ``os.PathLike`` object (e.g. ``pathlib.Path``).

    Returns:
        ``samples`` itself when it contains a tensor; otherwise a list holding
        one ``pil_loader`` result per path. A single path yields a
        one-element list.

    Raises:
        MisconfigurationException: If ``samples`` is not tensor-bearing, not a
            path, and not a list/tuple made up only of paths.
    """
    import os  # local import keeps this method self-contained

    if _contains_any_tensor(samples):
        return samples

    path_types = (str, os.PathLike)

    # Normalise a lone path to a one-element batch so both cases share the
    # loading code below.
    if isinstance(samples, path_types):
        samples = [samples]

    # NOTE: an empty list/tuple satisfies ``all`` vacuously and yields ``[]``.
    if isinstance(samples, (list, tuple)) and all(
        isinstance(p, path_types) for p in samples
    ):
        return [pil_loader(p) for p in samples]

    raise MisconfigurationException(
        "The samples should either be a tensor, a list of paths or a path."
    )
def forward(self, sample: Any) -> Any:
    """Apply the per-sample transforms within the current-stage context.

    ``pre_tensor_transform``, ``to_tensor_transform`` and
    ``post_tensor_transform`` each run inside their own context manager.
    Each stage is followed by the matching ``self.callback.on_*`` hook,
    which receives the stage output and ``self.stage``. When
    ``self.assert_contains_tensor`` is truthy, the ``to_tensor_transform``
    output is checked with ``_contains_any_tensor`` before the post stage.

    Args:
        sample: The raw sample to transform.

    Returns:
        The output of ``post_tensor_transform``.

    Raises:
        MisconfigurationException: If ``assert_contains_tensor`` is set and
            the ``to_tensor_transform`` output contains no tensor.
    """
    with self._current_stage_context:
        with self._pre_tensor_transform_context:
            out = self.pre_tensor_transform(sample)
            self.callback.on_pre_tensor_transform(out, self.stage)

        with self._to_tensor_transform_context:
            out = self.to_tensor_transform(out)
            self.callback.on_to_tensor_transform(out, self.stage)

        # Only scan for tensors when the check is enabled (short-circuit).
        if self.assert_contains_tensor and not _contains_any_tensor(out):
            raise MisconfigurationException(
                "When ``to_tensor_transform`` is overriden, "
                "``DataPipeline`` expects the outputs to be ``tensors``"
            )

        with self._post_tensor_transform_context:
            out = self.post_tensor_transform(out)
            self.callback.on_post_tensor_transform(out, self.stage)

    return out