Example #1
0
def make_vanilla_tensor_images(*args, **kwargs):
    """Yield the ``.data`` attribute of each low-rank image from ``make_images``.

    All positional and keyword arguments are forwarded unchanged to
    ``make_images``. Images whose ``ndim`` is greater than 3 are skipped;
    for every remaining image only its ``.data`` attribute is yielded
    (presumably the underlying plain tensor -- confirm against the feature
    type returned by ``make_images``).
    """
    yield from (
        candidate.data
        for candidate in make_images(*args, **kwargs)
        if candidate.ndim <= 3
    )
Example #2
0
class TestSmoke:
    """Smoke tests that only apply a transform to a sample input.

    No test here asserts anything about the result: a case passes as long as
    calling the transform does not raise. How the ``(transform, inputs)``
    pairs given to ``parametrize`` / ``parametrize_from_transforms`` are
    expanded into individual test cases is decided by those helpers, which
    are defined outside this class.
    """

    @parametrize_from_transforms(
        transforms.RandomErasing(p=1.0),
        transforms.HorizontalFlip(),
        transforms.Resize([16, 16]),
        transforms.CenterCrop([16, 16]),
        transforms.ConvertImageDtype(),
    )
    def test_common(self, transform, input):
        """Call one of the general-purpose transforms on ``input``."""
        transform(input)

    @parametrize(
        [
            (
                mixing_transform,
                # Each sample is a dict holding an image and a one-hot label,
                # both rebuilt via ``new_like`` from the ``unsqueeze(0)``-ed
                # original with a float dtype.
                [
                    {
                        "image": features.Image.new_like(
                            img,
                            img.unsqueeze(0),
                            dtype=torch.float,
                        ),
                        "one_hot_label": features.OneHotLabel.new_like(
                            label,
                            label.unsqueeze(0),
                            dtype=torch.float,
                        ),
                    }
                    for img, label in itertools.product(
                        make_images(),
                        make_one_hot_labels(),
                    )
                ],
            )
            for mixing_transform in (
                transforms.RandomMixup(alpha=1.0),
                transforms.RandomCutmix(alpha=1.0),
            )
        ]
    )
    def test_mixup_cutmix(self, transform, input):
        """Call ``RandomMixup`` / ``RandomCutmix`` on image + one-hot-label samples."""
        transform(input)

    @parametrize(
        [
            (
                augmentation,
                # Inputs come from all three image factories, each asked for
                # uint8 GRAY and RGB images with ``extra_dims=[(4,)]``.
                itertools.chain.from_iterable(
                    factory(
                        color_spaces=[
                            features.ColorSpace.GRAY,
                            features.ColorSpace.RGB,
                        ],
                        dtypes=[torch.uint8],
                        extra_dims=[(4,)],
                    )
                    for factory in (
                        make_images,
                        make_vanilla_tensor_images,
                        make_pil_images,
                    )
                ),
            )
            for augmentation in (
                transforms.RandAugment(),
                transforms.TrivialAugmentWide(),
                transforms.AutoAugment(),
                transforms.AugMix(),
            )
        ]
    )
    def test_auto_augment(self, transform, input):
        """Call each auto-augmentation transform on uint8 GRAY/RGB inputs."""
        transform(input)

    @parametrize(
        [
            (
                transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
                # Only float32 RGB inputs from the feature-image and
                # vanilla-tensor factories are used here.
                itertools.chain.from_iterable(
                    factory(
                        color_spaces=[features.ColorSpace.RGB],
                        dtypes=[torch.float32],
                    )
                    for factory in (
                        make_images,
                        make_vanilla_tensor_images,
                    )
                ),
            ),
        ]
    )
    def test_normalize(self, transform, input):
        """Call ``Normalize`` (zero mean, unit std) on float32 RGB inputs."""
        transform(input)

    @parametrize(
        [
            (
                transforms.RandomResizedCrop([16, 16]),
                itertools.chain(
                    make_images(extra_dims=[(4,)]),
                    make_vanilla_tensor_images(),
                    make_pil_images(),
                ),
            ),
        ]
    )
    def test_random_resized_crop(self, transform, input):
        """Call ``RandomResizedCrop([16, 16])`` on inputs from all three factories."""
        transform(input)

    @parametrize(
        [
            (
                transforms.ConvertImageColorSpace(
                    color_space=target_space,
                    old_color_space=source_space,
                ),
                # Inputs are generated in the *source* color space so they
                # match the ``old_color_space`` given to the transform.
                itertools.chain.from_iterable(
                    [
                        factory(color_spaces=[source_space])
                        for factory in (
                            make_images,
                            make_vanilla_tensor_images,
                            make_pil_images,
                        )
                    ]
                ),
            )
            # Every ordered (source, target) pair, including identical ones.
            for source_space, target_space in itertools.product(
                [
                    features.ColorSpace.GRAY,
                    features.ColorSpace.GRAY_ALPHA,
                    features.ColorSpace.RGB,
                    features.ColorSpace.RGB_ALPHA,
                ],
                repeat=2,
            )
        ]
    )
    def test_convert_image_color_space(self, transform, input):
        """Call ``ConvertImageColorSpace`` for every source/target color-space pair."""
        transform(input)