Exemplo n.º 1
0
def test_randaugment(device, num_ops, magnitude, fill):
    """Compare eager RandAugment with its ``torch.jit.script``-ed counterpart.

    A random uint8 image of shape (3, 44, 56) and a batch of four such images
    are built on ``device``. Each is passed repeatedly through the shared
    ``_test_transform_vs_scripted`` / ``_test_transform_vs_scripted_on_batch``
    helpers. ``num_ops``, ``magnitude`` and ``fill`` are forwarded unchanged
    to ``T.RandAugment``.
    """
    img_shape = (3, 44, 56)
    # Keep the draw order (single image first, then the batch) so the
    # random generator is consumed exactly as before.
    single_img = torch.randint(0, 256, size=img_shape, dtype=torch.uint8, device=device)
    img_batch = torch.randint(0, 256, size=(4,) + img_shape, dtype=torch.uint8, device=device)

    eager_aug = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
    scripted_aug = torch.jit.script(eager_aug)

    # Repeat many times. The transform presumably samples its ops at call
    # time, so each round may exercise a different combination.
    n_trials = 25
    for _trial in range(n_trials):
        _test_transform_vs_scripted(eager_aug, scripted_aug, single_img)
        _test_transform_vs_scripted_on_batch(eager_aug, scripted_aug, img_batch)
Exemplo n.º 2
0
# The :class:`~torchvision.transforms.AutoAugment` transform automatically augments data based on a given auto-augmentation policy.
# See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies.
policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
# One AutoAugment transform per policy. Each row of the plot shows four
# outputs of one policy applied to the same source image.
augmenters = [T.AutoAugment(policy) for policy in policies]
imgs = [
    [augmenter(orig_img) for _ in range(4)]
    for augmenter in augmenters
]
# Assumes AutoAugmentPolicy is an enum.Enum (as in torchvision). ``.name``
# gives the member identifier (e.g. "CIFAR10") directly, instead of parsing
# it out of ``str(policy)``, whose format is an Enum implementation detail.
row_title = [policy.name for policy in policies]
plot(imgs, row_title=row_title)

####################################
# RandAugment
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandAugment` transform automatically augments the data.
augmenter = T.RandAugment()
# map() calls the transform once per list element, in order. Every element
# is the same original image, so this yields four augmented versions of it.
imgs = list(map(augmenter, [orig_img] * 4))
plot(imgs)

####################################
# TrivialAugmentWide
# ~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.TrivialAugmentWide` transform automatically augments the data.
augmenter = T.TrivialAugmentWide()
# Apply the transform four times to the same source image and plot the results.
imgs = [augmenter(img) for img in (orig_img,) * 4]
plot(imgs)

####################################
# Randomly-applied transforms
# ---------------------------
#