def test_randaugment(device, num_ops, magnitude, fill):
    """Compare eager ``RandAugment`` against its ``torch.jit.script`` version.

    One random uint8 image and one random batch of four uint8 images are
    created on ``device``. The eager transform and its scripted counterpart
    are then passed to the ``_test_transform_vs_scripted*`` helpers on both
    inputs, 25 times over. Because the transform is random, repeating the
    comparison presumably exercises different sampled ops — confirm against
    the helpers.

    ``device``, ``num_ops``, ``magnitude`` and ``fill`` are presumably
    supplied by pytest parametrization defined outside this view.
    """
    img_shape = (3, 44, 56)
    # Draw the single image first and the batch second, so the global RNG is
    # consumed in a fixed order.
    single_img = torch.randint(0, 256, size=img_shape, dtype=torch.uint8, device=device)
    batched_imgs = torch.randint(0, 256, size=(4, *img_shape), dtype=torch.uint8, device=device)

    eager = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
    scripted = torch.jit.script(eager)

    n_trials = 25
    for _trial in range(n_trials):
        _test_transform_vs_scripted(eager, scripted, single_img)
        _test_transform_vs_scripted_on_batch(eager, scripted, batched_imgs)
# automatically augments data based on a given auto-augmentation policy.
# See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies.
policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
augmenters = [T.AutoAugment(policy) for policy in policies]
# One row per policy, four independent random augmentations of the same image.
imgs = [
    [augmenter(orig_img) for _ in range(4)]
    for augmenter in augmenters
]
# Label each row with the member name (e.g. "CIFAR10"). ``.name`` does not
# depend on how ``str()`` formats enum members, unlike parsing
# ``str(policy)``. Assumes AutoAugmentPolicy is an ``enum.Enum``, as the
# "Class.MEMBER" format that the string parsing relied on suggests.
row_title = [policy.name for policy in policies]
plot(imgs, row_title=row_title)

####################################
# RandAugment
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandAugment` transform automatically augments the data.
augmenter = T.RandAugment()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

####################################
# TrivialAugmentWide
# ~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.TrivialAugmentWide` transform automatically augments the data.
augmenter = T.TrivialAugmentWide()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

####################################
# Randomly-applied transforms
# ---------------------------
#