예제 #1
0
    def _extract_image(
        self,
        sample: Any,
        unsupported_types: Tuple[Type, ...] = (features.BoundingBox,
                                               features.SegmentationMask),
    ) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor,
                                      features.Image]]:
        """Locate the single image contained in ``sample``.

        Traversal of ``sample`` is delegated to ``query_recursively``. Each
        visited object is either reported as an image, rejected with an
        error, or skipped.

        Args:
            sample: Input to search for an image.
            unsupported_types: Types whose presence anywhere in ``sample``
                makes it invalid for this transform.

        Returns:
            The ``(id, image)`` pair reported for the only image found.

        Raises:
            TypeError: If a visited object is an instance of
                ``unsupported_types``, if no image is found, or if more than
                one image is found.
        """
        # Exact-type membership rather than ``isinstance``: other subclasses of
        # ``torch.Tensor`` / ``features.Image`` do not match this test
        # (presumably so tensor-based features are not mistaken for images).
        exact_image_types = {torch.Tensor, features.Image}

        def fn(
            id: Tuple[Any, ...], input: Any
        ) -> Optional[Tuple[Tuple[Any, ...], Union[
                PIL.Image.Image, torch.Tensor, features.Image]]]:
            is_image = (isinstance(input, PIL.Image.Image)
                        or type(input) in exact_image_types)
            if is_image:
                return id, input
            if not isinstance(input, unsupported_types):
                # Neither an image nor a forbidden type: skip it.
                return None
            raise TypeError(
                f"Inputs of type {type(input).__name__} are not supported by {type(self).__name__}()"
            )

        # Consume the whole traversal (not just the first hit) so that an
        # unsupported input raises even if an image was already seen.
        found = list(query_recursively(fn, sample))
        num_found = len(found)
        if num_found == 0:
            raise TypeError("Found no image in the sample.")
        if num_found > 1:
            raise TypeError(
                f"Auto augment transformations are only properly defined for a single image, but found {num_found}."
            )
        return found[0]
예제 #2
0
def query_image(
        sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
    """Return the first image found in ``sample``.

    Traversal of ``sample`` is delegated to ``query_recursively``. An object
    counts as an image if it is an instance of ``PIL.Image.Image`` or if its
    exact type is ``torch.Tensor`` or ``features.Image``.

    Args:
        sample: Input to search for an image.

    Returns:
        The first item ``query_recursively`` yields for the image visitor.

    Raises:
        TypeError: If no image was found in ``sample``.
    """
    def fn(
        input: Any
    ) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]:
        # Exact type check, not ``isinstance``: other subclasses of
        # ``torch.Tensor`` / ``features.Image`` are not treated as images.
        if type(input) in {torch.Tensor, features.Image} or isinstance(
                input, PIL.Image.Image):
            return input

        return None

    try:
        return next(query_recursively(fn, sample))
    except StopIteration:
        # ``from None`` keeps the internal StopIteration out of the caller's
        # traceback; the only relevant fact is that no image was present.
        raise TypeError("No image was found in the sample") from None
예제 #3
0
def _extract_types(sample: Any) -> Iterator[Type]:
    """Return the iterator from ``query_recursively`` over the types of the
    objects it visits in ``sample``.

    Args:
        sample: Input whose contained objects' types are collected.

    Returns:
        Whatever ``query_recursively`` returns when given a visitor that maps
        each visited object to ``type(object)``.
    """
    def type_of(id: Any, input: Any) -> Type:
        # ``id`` is unused; it is accepted only to match the two-argument
        # visitor signature that ``query_recursively`` calls with.
        return type(input)

    return query_recursively(type_of, sample)