Exemplo n.º 1
0
def test_augments():
    """Check that ``Augments`` preserves batch shapes for several configs.

    A single random batch of 4 images (3x32x32) with integer labels drawn
    from ``[0, 10)`` is pushed through an ``Augments`` pipeline built from
    each config below. For every config the returned images must keep the
    input shape and the returned labels must have shape ``(4, 10)``,
    matching ``num_classes=10``.
    """
    imgs = torch.randn(4, 3, 32, 32)
    labels = torch.randint(0, 10, (4, ))

    expected_img_shape = torch.Size((4, 3, 32, 32))
    expected_label_shape = torch.Size((4, 10))

    cases = [
        # Test cutmix: a single augment given as a plain dict.
        {'type': 'BatchCutMix', 'alpha': 1., 'num_classes': 10, 'prob': 1.},
        # Test mixup: a single augment given as a plain dict.
        {'type': 'BatchMixup', 'alpha': 1., 'num_classes': 10, 'prob': 1.},
        # Test cutmixup: listed probabilities sum to 0.8.
        [
            {'type': 'BatchCutMix', 'alpha': 1., 'num_classes': 10,
             'prob': 0.5},
            {'type': 'BatchMixup', 'alpha': 1., 'num_classes': 10,
             'prob': 0.3},
        ],
        # Test cutmixup: listed probabilities sum to 1.0.
        [
            {'type': 'BatchCutMix', 'alpha': 1., 'num_classes': 10,
             'prob': 0.5},
            {'type': 'BatchMixup', 'alpha': 1., 'num_classes': 10,
             'prob': 0.5},
        ],
        # Test cutmixup with an explicit Identity entry taking the rest.
        [
            {'type': 'BatchCutMix', 'alpha': 1., 'num_classes': 10,
             'prob': 0.5},
            {'type': 'BatchMixup', 'alpha': 1., 'num_classes': 10,
             'prob': 0.3},
            {'type': 'Identity', 'num_classes': 10, 'prob': 0.2},
        ],
    ]

    # Pipelines are built and applied strictly in the listed order so the
    # sequence of constructions and calls matches the spelled-out version.
    for augments_cfg in cases:
        augs = Augments(augments_cfg)
        mixed_imgs, mixed_labels = augs(imgs, labels)
        assert mixed_imgs.shape == expected_img_shape
        assert mixed_labels.shape == expected_label_shape
Exemplo n.º 2
0
def test_multilabel_augment(cfg):
    """Check that ``Augments`` preserves shapes for multi-label targets.

    Args:
        cfg (dict): Augment settings without ``num_classes``; the test adds
            ``num_classes=10`` itself, so passing that key here raises
            ``TypeError`` from the ``dict`` call.

    The batch is 4 random images (3x32x32) with float labels of shape
    ``(4, 10)`` whose entries are 0 or 1. After augmentation both images
    and labels must keep their original shapes.
    """
    batch_size, num_classes = 4, 10

    cfg_ = dict(num_classes=num_classes, **cfg)
    augs = Augments(cfg_)

    imgs = torch.randn(batch_size, 3, 32, 32)
    # randint's upper bound is exclusive, so every entry is 0 or 1.
    labels = torch.randint(0, 2, (batch_size, num_classes)).float()

    mixed_imgs, mixed_labels = augs(imgs, labels)

    expected_img_shape = torch.Size((batch_size, 3, 32, 32))
    expected_label_shape = torch.Size((batch_size, num_classes))
    assert mixed_imgs.shape == expected_img_shape
    assert mixed_labels.shape == expected_label_shape