Esempio n. 1
0
def generate_kornia_transforms(image_size=224, resize=256, mean=[], std=[], include_jitter=False):
    """Build the training and validation augmentation pipelines.

    Args:
        image_size: Side length of the square crop produced at the end of
            both pipelines.
        resize: Side length of the square the input is resized to first.
        mean: Per-channel means for normalization. When empty (falsy),
            ``[0.5, 0.5, 0.5]`` is used.
        std: Per-channel standard deviations for normalization. When empty
            (falsy), ``[0.1, 0.1, 0.1]`` is used.
        include_jitter: When true, a colour jitter stage is inserted right
            after the resize in the training pipeline.

    Returns:
        A dict with keys ``"train"`` and ``"val"``, each an
        ``nn.Sequential``. When CUDA is available the normalization
        statistics and both pipelines are moved to the GPU.
    """
    on_gpu = torch.cuda.is_available()

    mean = torch.tensor(mean) if mean else torch.tensor([0.5, 0.5, 0.5])
    std = torch.tensor(std) if std else torch.tensor([0.1, 0.1, 0.1])
    if on_gpu:
        mean, std = mean.cuda(), std.cuda()

    resize_hw = (resize, resize)
    crop_hw = (image_size, image_size)

    # Training: resize -> (optional jitter) -> flips/rotation -> random crop -> normalize.
    train_stages = [G.Resize(resize_hw)]
    if include_jitter:
        train_stages += [
            K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        ]
    train_stages += [
        K.RandomHorizontalFlip(p=0.5),
        K.RandomVerticalFlip(p=0.5),
        K.RandomRotation(90),
        K.RandomResizedCrop(crop_hw),
        K.Normalize(mean, std),
    ]

    # Validation: deterministic resize -> centre crop -> normalize.
    val_stages = [
        G.Resize(resize_hw),
        K.CenterCrop(crop_hw),
        K.Normalize(mean, std),
    ]

    pipelines = {
        "train": nn.Sequential(*train_stages),
        "val": nn.Sequential(*val_stages),
    }
    if on_gpu:
        pipelines = {split: stage.cuda() for split, stage in pipelines.items()}
    return pipelines
Esempio n. 2
0
 def __init__(self, modulation, h, w):
     """Set up a pad -> rotate -> centre-crop pipeline factory.

     Args:
         modulation: Forwarded unchanged to the parent constructor.
         h: Height of the final centre crop.
         w: Width of the final centre crop.
     """
     # worst case rotation brings sqrt(2) * max_side_length out-of-frame pixels into frame
     # padding should cover that exactly
     pad = int(max(h, w) * (1 - math.sqrt(2) / 2))

     def build_pipeline(b):
         # `b` is handed straight to kT.Rotate (presumably the angle -- confirm).
         return th.nn.Sequential(
             th.nn.ReflectionPad2d(pad),
             kT.Rotate(b),
             kA.CenterCrop((h, w)),
         )

     super(Rotate, self).__init__(build_pipeline, modulation)
Esempio n. 3
0
 def __init__(self, modulation, h, w, noise):
     """Set up a pad -> noise -> translate -> centre-crop pipeline factory.

     Args:
         modulation: Forwarded unchanged to the parent constructor.
         h: Height of the final centre crop.
         w: Width of the final centre crop.
         noise: Argument passed to ``AddNoise``.
     """
     half_w = int(w / 2)

     def build_pipeline(b):
         # Horizontal-only reflection padding applied in three successive
         # stages; top/bottom padding is always 0.
         # NOTE(review): presumably staged because a single reflection pad
         # must be smaller than the input width -- confirm.
         return th.nn.Sequential(
             th.nn.ReflectionPad2d((half_w, half_w, 0, 0)),
             th.nn.ReflectionPad2d((w, w, 0, 0)),
             th.nn.ReflectionPad2d((w, 0, 0, 0)),
             AddNoise(noise),
             kT.Translate(b),
             kA.CenterCrop((h, w)),
         )

     super(Translate, self).__init__(build_pipeline, modulation)
Esempio n. 4
0
 def __init__(self,resize,image_size,mean,std,include_jitter=False,Set="train"):
     """Create paired image/mask transforms.

     Every geometric stage exists twice: one version for images and one
     for masks that uses nearest-neighbour resampling.

     Args:
         resize: Side length of the square both image and mask are resized to.
         image_size: Side length of the square centre crop.
         mean: Mean passed to ``K.Normalize``.
         std: Standard deviation passed to ``K.Normalize``.
         include_jitter: When true ``self.jit`` is a colour jitter,
             otherwise it is an identity function.
         Set: Stored as ``self.Set`` (presumably selects train/eval
             behaviour elsewhere in the class -- confirm).
     """
     super().__init__()
     resize_hw = (resize, resize)
     crop_hw = (image_size, image_size)

     def resize_mask(x):
         # Nearest-neighbour interpolation for the mask resize.
         return torch.nn.functional.interpolate(
             x, size=resize_hw, mode='nearest', align_corners=None)

     self.resize = G.Resize(resize_hw, align_corners=False)
     self.mask_resize = resize_mask

     if include_jitter:
         self.jit = K.ColorJitter(brightness=0.4, contrast=0.4,
                                  saturation=0.4, hue=0.1)
     else:
         self.jit = lambda x: x

     # Image and mask affine share the same ranges; the scale argument is None.
     self.affine = K.augmentation.RandomAffine(
         [-90., 90.], [0., 0.15], None, [0., 0.15])
     self.affine_mask = K.augmentation.RandomAffine(
         [-90., 90.], [0., 0.15], None, [0., 0.15],
         resample="NEAREST", align_corners=False)

     self.normalize = K.Normalize(mean, std)

     self.crop = K.CenterCrop(crop_hw)
     self.mask_crop = K.CenterCrop(crop_hw, resample="NEAREST")

     self.Set = Set
Esempio n. 5
0
    def __init__(self,
                 N_TFMS: int,
                 MAGN: int,
                 mean: Union[tuple, list, torch.Tensor],
                 std: Union[tuple, list, torch.Tensor],
                 transform_list: Union[list, None] = None,
                 use_resize: Union[int, None] = None,
                 image_size: Union[tuple, None] = None,
                 use_mix: Union[int, None] = None,
                 mix_p: float = .5):
        """Configure the augmentation module.

        Args:
            N_TFMS: Stored as ``self.N_TFMS``.
            MAGN: Stored as ``self.MAGN``; also passed to ``kornia_list``
                when ``transform_list`` is not supplied.
            mean: Normalization mean; non-tensors are converted with
                ``torch.Tensor``.
            std: Normalization standard deviation; converted like ``mean``.
            transform_list: Transforms to use. When ``None``,
                ``kornia_list(MAGN)`` is used.
            use_resize: When not ``None``, ``self.resize_list`` is built as
                ``[RandomResizedCrop, RandomCrop, CenterCrop]``. If the
                value is below 3 it is also used as an index into that
                list to set ``self.resize``.
            image_size: ``(h, w)`` tuple. Needed whenever ``use_resize`` or
                ``use_mix`` is given.
            use_mix: When not ``None``, ``self.mix_list`` is built with
                ``RandomCutMix`` and ``RandomMixUp``.
            mix_p: Stored as ``self.mix_p``.

        Raises:
            AssertionError: If ``use_resize`` is given and ``image_size``
                does not have exactly two elements.
        """
        super().__init__()

        self.N_TFMS, self.MAGN = N_TFMS, MAGN
        self.use_mix, self.mix_p = use_mix, mix_p
        self.image_size = image_size

        if not isinstance(mean, torch.Tensor):
            mean = torch.Tensor(mean)
        if not isinstance(std, torch.Tensor):
            std = torch.Tensor(std)

        if self.use_mix is not None:
            # NOTE(review): indexes image_size, so image_size must not be
            # None when use_mix is set.
            self.mix_list = [
                K.RandomCutMix(self.image_size[0], self.image_size[1], p=1),
                K.RandomMixUp(p=1)
            ]

        self.use_resize = use_resize
        if use_resize is not None:
            assert len(
                image_size
            ) == 2, 'Invalid `image_size`. Must be a tuple of form (h, w)'
            self.resize_list = [
                K.RandomResizedCrop(image_size),
                K.RandomCrop(image_size),
                K.CenterCrop(image_size)
            ]
            # Values >= 3 leave self.resize unset; only resize_list exists.
            if self.use_resize < 3:
                self.resize = self.resize_list[use_resize]

        self.normalize = K.Normalize(mean, std)

        self.transform_list = (kornia_list(MAGN)
                               if transform_list is None else transform_list)
Esempio n. 6
0
class TestVideoSequential:
    """Tests for ``K.VideoSequential``.

    Covers rejection of inputs that are not 5-D, per-augmentation
    reproducibility, frame consistency with ``same_on_frame=True``,
    equivalence with ``torch.nn.Sequential`` when ``same_on_frame=False``,
    and a (currently skipped) TorchScript check.

    ``device`` and ``dtype`` are test arguments supplied from outside this
    class (presumably fixtures from the suite's conftest -- confirm).
    """

    # None of these shapes is 5-D (ranks 2, 3, 4 and 6).
    @pytest.mark.parametrize('shape', [(3, 4), (2, 3, 4), (2, 3, 5, 6),
                                       (2, 3, 4, 5, 6, 7)])
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_exception(self, shape, data_format, device, dtype):
        """Expect an ``AssertionError`` for inputs whose rank is not 5."""
        aug_list = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                                     data_format=data_format,
                                     same_on_frame=True)
        with pytest.raises(AssertionError):
            img = torch.randn(*shape, device=device, dtype=dtype)
            aug_list(img)

    # Every augmentation is forced to apply (p=1.0).
    @pytest.mark.parametrize(
        'augmentation',
        [
            K.RandomAffine(360, p=1.0),
            K.CenterCrop((3, 3), p=1.0),
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomCrop((5, 5), p=1.0),
            K.RandomErasing(p=1.0),
            K.RandomGrayscale(p=1.0),
            K.RandomHorizontalFlip(p=1.0),
            K.RandomVerticalFlip(p=1.0),
            K.RandomPerspective(p=1.0),
            K.RandomResizedCrop((5, 5), p=1.0),
            K.RandomRotation(360.0, p=1.0),
            K.RandomSolarize(p=1.0),
            K.RandomPosterize(p=1.0),
            K.RandomSharpness(p=1.0),
            K.RandomEqualize(p=1.0),
            K.RandomMotionBlur(3, 35.0, 0.5, p=1.0),
            K.Normalize(torch.tensor([0.5, 0.5, 0.5]),
                        torch.tensor([0.5, 0.5, 0.5]),
                        p=1.0),
            K.Denormalize(torch.tensor([0.5, 0.5, 0.5]),
                          torch.tensor([0.5, 0.5, 0.5]),
                          p=1.0),
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_augmentation(self, augmentation, data_format, device, dtype):
        """Run ``reproducibility_test`` on each single augmentation."""
        # One random sample scaled by 1/255, duplicated along the batch axis.
        input = torch.randint(255, (1, 3, 3, 5, 6), device=device,
                              dtype=dtype).repeat(2, 1, 1, 1, 1) / 255.0
        torch.manual_seed(21)
        aug_list = K.VideoSequential(augmentation,
                                     data_format=data_format,
                                     same_on_frame=True)
        # reproducibility_test is a helper defined outside this class.
        reproducibility_test(input, aug_list)

    # Mixes of always-applied (p=1.0) and never-applied (p=0.0) augmentations,
    # including a non-augmentation module and a mix augmentation.
    @pytest.mark.parametrize(
        'augmentations',
        [
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0)
            ],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)
            ],
            [K.RandomAffine(360, p=1.0),
             kornia.color.BgrToRgb()],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0),
                K.RandomAffine(360, p=0.0)
            ],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0)],
            [K.RandomAffine(360, p=0.0)],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0),
                K.RandomMixUp(p=1.0)
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    # random_apply is exercised with int, tuple and bool values.
    @pytest.mark.parametrize('random_apply',
                             [1, (1, 1), (1, ), 10, True, False])
    def test_same_on_frame(self, augmentations, data_format, random_apply,
                           device, dtype):
        """Check that identical input frames give identical output frames."""
        aug_list = K.VideoSequential(*augmentations,
                                     data_format=data_format,
                                     same_on_frame=True,
                                     random_apply=random_apply)

        if data_format == 'BCTHW':
            # A single frame repeated 4 times along the time axis (dim 2).
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
            output = aug_list(input)
            # When return_label is set the call returns a pair; keep only
            # the first element.
            if aug_list.return_label:
                output, _ = output
            assert (output[:, :, 0] == output[:, :, 1]).all()
            assert (output[:, :, 1] == output[:, :, 2]).all()
            assert (output[:, :, 2] == output[:, :, 3]).all()
        if data_format == 'BTCHW':
            # A single frame repeated 4 times along the time axis (dim 1).
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)
            output = aug_list(input)
            if aug_list.return_label:
                output, _ = output
            assert (output[:, 0] == output[:, 1]).all()
            assert (output[:, 1] == output[:, 2]).all()
            assert (output[:, 2] == output[:, 3]).all()
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [K.RandomAffine(360, p=1.0)],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)],
            [
                K.RandomAffine(360, p=0.0),
                K.ImageSequential(K.RandomAffine(360, p=0.0))
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_against_sequential(self, augmentations, data_format, device,
                                dtype):
        """Compare ``same_on_frame=False`` output with ``torch.nn.Sequential``.

        The reference result applies the same augmentations to the frames
        flattened into the batch dimension, using the same RNG seed.
        """
        aug_list_1 = K.VideoSequential(*augmentations,
                                       data_format=data_format,
                                       same_on_frame=False)
        aug_list_2 = torch.nn.Sequential(*augmentations)

        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)

        # Seed identically before each pipeline runs.
        torch.manual_seed(0)
        output_1 = aug_list_1(input)

        torch.manual_seed(0)
        if data_format == 'BCTHW':
            # Swap to (B, T, C, H, W) so the reshape below merges B and T.
            input = input.transpose(1, 2)
        output_2 = aug_list_2(input.reshape(-1, 3, 5, 6))
        output_2 = output_2.view(2, 4, 3, 5, 6)
        if data_format == 'BCTHW':
            # Restore the original (B, C, T, H, W) layout for comparison.
            output_2 = output_2.transpose(1, 2)
        # On failure the assertion message shows aug_list_1._params.
        assert (output_1 == output_2).all(), dict(aug_list_1._params)

    @pytest.mark.jit
    @pytest.mark.skip(reason="turn off due to Union Type")
    def test_jit(self, device, dtype):
        """Compare eager and TorchScript outputs (currently skipped)."""
        B, C, D, H, W = 2, 3, 5, 4, 4
        img = torch.ones(B, C, D, H, W, device=device, dtype=dtype)
        op = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                               same_on_frame=True)
        op_jit = torch.jit.script(op)
        assert_close(op(img), op_jit(img))
Esempio n. 7
0
 def __init__(self, modulation, h, w):
     """Set up a pad -> scale -> centre-crop pipeline factory.

     Args:
         modulation: Forwarded unchanged to the parent constructor.
         h: Height of the final centre crop.
         w: Width of the final centre crop.
     """
     # Reflection padding of one less than the longer side.
     pad = int(max(h, w)) - 1

     def build_pipeline(b):
         # `b` is handed straight to kT.Scale (presumably the scale factor -- confirm).
         return th.nn.Sequential(
             th.nn.ReflectionPad2d(pad),
             kT.Scale(b),
             kA.CenterCrop((h, w)),
         )

     super(Zoom, self).__init__(build_pipeline, modulation)
Esempio n. 8
0
    def __init__(
        self,
        image_size,
        latent_dim=512,
        style_depth=8,
        network_capacity=16,
        transparent=False,
        fp16=False,
        cl_reg=False,
        augment_fn=None,
        steps=1,
        lr=1e-4,
        fq_layers=None,
        fq_dict_size=256,
        attn_layers=None,
    ):
        """Build the networks, optimizers and augmentation pipeline.

        Args:
            image_size: Passed to the generator, discriminator and the
                default augmentation pipeline (padding and crop sizes).
            latent_dim: Passed to the style vectorizers and generators.
            style_depth: Passed to the style vectorizers.
            network_capacity: Passed to the generators and discriminator.
            transparent: Passed to the generators and discriminator.
            fp16: When true, networks and optimizers are wrapped with
                ``amp.initialize(..., opt_level="O2")``.
            cl_reg: When true, ``self.D_cl`` is a ``ContrastiveLearner``
                around the discriminator; otherwise it is ``None``.
            augment_fn: Augmentation module to use. When ``None`` a default
                ``nn.Sequential`` pipeline is built.
            steps: Stored as ``self.steps``.
            lr: Learning rate for both optimizers.
            fq_layers: Passed to the discriminator. ``None`` (default)
                means an empty list.
            fq_dict_size: Passed to the discriminator.
            attn_layers: Passed to the generators and discriminator.
                ``None`` (default) means an empty list.
        """
        super().__init__()
        # Use a fresh list per instance instead of a shared mutable default.
        fq_layers = [] if fq_layers is None else fq_layers
        attn_layers = [] if attn_layers is None else attn_layers

        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        self.S = StyleVectorizer(latent_dim, style_depth)
        self.G = Generator(image_size,
                           latent_dim,
                           network_capacity,
                           transparent=transparent,
                           attn_layers=attn_layers)
        self.D = Discriminator(
            image_size,
            network_capacity,
            fq_layers=fq_layers,
            fq_dict_size=fq_dict_size,
            attn_layers=attn_layers,
            transparent=transparent,
        )

        # Second style vectorizer / generator pair with gradients disabled
        # (presumably the EMA copies maintained via ema_updater -- confirm).
        self.SE = StyleVectorizer(latent_dim, style_depth)
        self.GE = Generator(image_size,
                            latent_dim,
                            network_capacity,
                            transparent=transparent,
                            attn_layers=attn_layers)

        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)

        # One optimizer updates G and S together; D has its own.
        generator_params = list(self.G.parameters()) + list(
            self.S.parameters())
        self.G_opt = DiffGrad(generator_params, lr=self.lr, betas=(0.5, 0.9))
        self.D_opt = DiffGrad(self.D.parameters(),
                              lr=self.lr,
                              betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()

        # Unconditionally moves the module to the GPU.
        self.cuda()

        if fp16:
            (self.S, self.G, self.D, self.SE,
             self.GE), (self.G_opt, self.D_opt) = amp.initialize(
                 [self.S, self.G, self.D, self.SE, self.GE],
                 [self.G_opt, self.D_opt],
                 opt_level="O2")

        # experimental contrastive loss discriminator regularization
        if augment_fn is not None:
            self.augment_fn = augment_fn
        else:
            self.augment_fn = nn.Sequential(
                nn.ReflectionPad2d(int((sqrt(2) - 1) * image_size / 4)),
                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.7),
                augs.RandomGrayscale(p=0.2),
                augs.RandomHorizontalFlip(),
                RandomApply(augs.RandomAffine(degrees=0,
                                              translate=(0.25, 0.25),
                                              shear=(15, 15)),
                            p=0.3),
                RandomApply(nn.Sequential(
                    augs.RandomRotation(180),
                    augs.CenterCrop(size=(image_size, image_size))),
                            p=0.2),
                augs.RandomResizedCrop(size=(image_size, image_size)),
                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
                RandomApply(augs.RandomErasing(), p=0.1),
            )

        self.D_cl = (ContrastiveLearner(self.D,
                                        image_size,
                                        augment_fn=self.augment_fn,
                                        fp16=fp16,
                                        hidden_layer="flatten")
                     if cl_reg else None)