Example #1
def get_train_augmentation_transforms(
    p_flip_vertical=0.5,
    p_flip_horizontal=0.5,
    max_rotation=10.0,
    max_zoom=1.1,
    max_warp=0.2,
    p_affine=0.75,
    max_lighting=0.2,
    p_lighting=0.75,
):
    """
    Build a set of pytorch image Transforms to use during training:
        p_flip_vertical: probability of a vertical flip
        p_flip_horizontal: probability of a horizontal flip
        max_rotation: maximum rotation angle in degrees
        max_zoom: maximum zoom level
        max_warp: perspective warping scale (from 0.0 to 1.0)
        p_affine: probility of rotation, zoom and perspective warping
        max_lighting: maximum scaling of brightness and contrast
    """
    return [
        kaug.RandomVerticalFlip(p=p_flip_vertical),
        kaug.RandomHorizontalFlip(p=p_flip_horizontal),
        kaug.RandomAffine(p=p_affine, degrees=max_rotation, scale=(1.0, max_zoom)),
        kaug.RandomPerspective(p=p_affine, distortion_scale=max_warp),
        kaug.ColorJitter(p=p_lighting, brightness=max_lighting, contrast=max_lighting),
        # TODO: the kornia transforms work on batches and so by default they
        # add a batch dimension. Until I work out how to apply transforms
        # batch-wise (and only while training) I will just keep this here to
        # remove the batch dimension again
        GetItemTransform(),
    ]
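A minimal usage sketch (an assumption, not part of the original snippet): kaug is taken to be kornia.augmentation, and GetItemTransform is re-created here as a hypothetical stand-in for the batch-dimension remover described in the TODO above. The returned list composes with torch.nn.Sequential and applies to a single CHW image.

import torch
import torch.nn as nn
import kornia.augmentation as kaug

class GetItemTransform(nn.Module):
    # Hypothetical stand-in: kornia adds a batch dimension to single
    # images, so index it away again at the end of the pipeline.
    def forward(self, x):
        return x[0]

tfms = nn.Sequential(*get_train_augmentation_transforms())
img = torch.rand(3, 224, 224)      # a single CHW image
out = tfms(img)                    # kornia batches internally ...
assert out.shape == img.shape      # ... and GetItemTransform un-batches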
Example #2
    def test_inverse_and_forward_return_transform(self, device, dtype):
        inp = torch.randn(1, 3, 1000, 500, device=device, dtype=dtype)
        bbox = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]],
                            device=device,
                            dtype=dtype)
        keypoints = torch.tensor([[[465, 115], [545, 116]]],
                                 device=device,
                                 dtype=dtype)
        mask = bbox_to_mask(
            torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]],
                         device=device,
                         dtype=dtype), 1000, 500)[:, None].float()
        aug = K.AugmentationSequential(
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0, return_transform=True),
            K.RandomAffine(360, p=1.0, return_transform=True),
            data_keys=["input", "mask", "bbox", "keypoints"],
        )

        out_inv = aug.inverse(inp, mask, bbox, keypoints)
        assert out_inv[0].shape == inp.shape
        assert out_inv[1].shape == mask.shape
        assert out_inv[2].shape == bbox.shape
        assert out_inv[3].shape == keypoints.shape

        out = aug(inp, mask, bbox, keypoints)
        assert out[0][0].shape == inp.shape
        assert out[1].shape == mask.shape
        assert out[2].shape == bbox.shape
        assert out[3].shape == keypoints.shape
Example #3
    def test_inverse_and_forward_return_transform(self, random_apply, device,
                                                  dtype):
        inp = torch.randn(1, 3, 1000, 500, device=device, dtype=dtype)
        bbox = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]],
                            device=device,
                            dtype=dtype)
        keypoints = torch.tensor([[[465, 115], [545, 116]]],
                                 device=device,
                                 dtype=dtype)
        mask = bbox_to_mask(
            torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]],
                         device=device,
                         dtype=dtype), 1000, 500)[:, None].float()
        aug = K.AugmentationSequential(
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0, return_transform=True),
            K.RandomAffine(360, p=1.0, return_transform=True),
            data_keys=["input", "mask", "bbox", "keypoints"],
            random_apply=random_apply,
        )
        with pytest.raises(
                Exception):  # No parameters available for inversion.
            aug.inverse(inp, mask, bbox, keypoints)

        out = aug(inp, mask, bbox, keypoints)
        assert out[0][0].shape == inp.shape
        assert out[1].shape == mask.shape
        assert out[2].shape == bbox.shape
        assert out[3].shape == keypoints.shape

        reproducibility_test((inp, mask, bbox, keypoints), aug)
Example #4
    def test_forward_and_inverse(self, random_apply, return_transform, device,
                                 dtype):
        inp = torch.randn(1, 3, 1000, 500, device=device, dtype=dtype)
        bbox = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]],
                            device=device,
                            dtype=dtype)
        keypoints = torch.tensor([[[465, 115], [545, 116]]],
                                 device=device,
                                 dtype=dtype)
        mask = bbox_to_mask(
            torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]],
                         device=device,
                         dtype=dtype), 1000, 500)[:, None].float()
        aug = K.AugmentationSequential(
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomAffine(360, p=1.0),
            data_keys=["input", "mask", "bbox", "keypoints"],
            random_apply=random_apply,
            return_transform=return_transform,
        )
        out = aug(inp, mask, bbox, keypoints)
        if return_transform and isinstance(out, (tuple, list)):
            assert out[0][0].shape == inp.shape
        else:
            assert out[0].shape == inp.shape
        assert out[1].shape == mask.shape
        assert out[2].shape == bbox.shape
        assert out[3].shape == keypoints.shape
        reproducibility_test((inp, mask, bbox, keypoints), aug)

        out_inv = aug.inverse(*out)
        assert out_inv[0].shape == inp.shape
        assert out_inv[1].shape == mask.shape
        assert out_inv[2].shape == bbox.shape
        assert out_inv[3].shape == keypoints.shape
Example #5
 def __init__(self, mean, std, scale=(0.9, 1.1), max_degrees=0) -> None:
     super(Transform, self).__init__()
     self.max_degrees = max_degrees
     self.aff = k.RandomAffine(max_degrees,
                               resample=k.Resample.NEAREST,
                               scale=scale)
     self.norm = k.Normalize(mean, std)
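The snippet above only defines __init__; a plausible forward (an assumption, not shown in the original) would chain the random affine and the normalization:

 def forward(self, x):
     # Hypothetical forward pass (not in the original snippet):
     # spatial jitter first, then channel-wise normalization.
     return self.norm(self.aff(x))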
Example #6
 def test_mixup(self, inp, random_apply, same_on_batch, device, dtype):
     inp = torch.as_tensor(inp, device=device, dtype=dtype)
     aug = K.AugmentationSequential(
         K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
                           K.RandomAffine(360, p=1.0)),
         K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
         K.RandomAffine(360, p=1.0),
         K.RandomMixUp(p=1.0),
         data_keys=["input"],
         random_apply=random_apply,
         same_on_batch=same_on_batch,
     )
     out = aug(inp)
     if aug.return_label:
         out, _ = out
     assert out.shape[-3:] == inp.shape[-3:]
     reproducibility_test(inp, aug)
Example #7
 def test_construction(self, same_on_batch, return_transform, keepdim):
     K.ImageSequential(
         K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
         K.RandomAffine(360, p=1.0),
         same_on_batch=same_on_batch,
         return_transform=return_transform,
         keepdim=keepdim,
     )
Example #8
 def test_jit(self, device, dtype):
     B, C, H, W = 2, 3, 4, 4
     img = torch.ones(B, C, H, W, device=device, dtype=dtype)
     op = K.AugmentationSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                                   K.RandomAffine(360, p=1.0),
                                   same_on_batch=True)
     op_jit = torch.jit.script(op)
     assert_close(op(img), op_jit(img))
Example #9
 def test_forward(self, random_apply, device, dtype):
     inp = torch.randn(1, 3, 30, 30, device=device, dtype=dtype)
     aug = K.ImageSequential(
         K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
         kornia.filters.MedianBlur((3, 3)),
         K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
         K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0)),
         K.ImageSequential(K.RandomAffine(360, p=1.0)),
         K.RandomAffine(360, p=1.0),
         K.RandomMixUp(p=1.0),
         random_apply=random_apply,
     )
     out = aug(inp)
     if aug.return_label:
         out, _ = out
     assert out.shape == inp.shape
     aug.inverse(inp)
     reproducibility_test(inp, aug)
Example #10
 def __init__(self, cutn):
     super().__init__()
     self.cutn = cutn
     self.augs = nn.Sequential(
         K.RandomHorizontalFlip(p=0.5),
         K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
         #K.RandomSolarize(0.01, 0.01, p=0.7),
         K.RandomSharpness(0.3, p=0.4),
         K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
         K.RandomPerspective(0.2, p=0.4), )
     self.noise_fac = 0.1
Example #11
def test_transform_kornia():
    """Run few epochs to check ``BatchTransformCallback`` callback."""
    model = torch.nn.Linear(28 * 28, 10)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

    loaders = {
        "train":
        DataLoader(
            MnistDataset(
                MNIST(os.getcwd(),
                      train=False,
                      download=True,
                      transform=ToTensor())),
            batch_size=32,
        ),
        "valid":
        DataLoader(
            MnistDataset(
                MNIST(os.getcwd(),
                      train=False,
                      download=True,
                      transform=ToTensor())),
            batch_size=32,
        ),
    }

    transforms_list = [
        augmentation.RandomAffine(degrees=(-15, 20), scale=(0.75, 1.25)),
    ]

    runner = CustomRunner()

    # model training
    runner.train(
        model=model,
        optimizer=optimizer,
        loaders=loaders,
        logdir="./logs",
        num_epochs=5,
        verbose=False,
        load_best_on_end=True,
        check=True,
        callbacks=[
            BatchTransformCallback(transform=transforms_list,
                                   scope="on_batch_start",
                                   input_key="features")
        ],
    )

    # model inference
    for prediction in runner.predict_loader(loader=loaders["train"]):
        assert prediction.detach().cpu().numpy().shape[-1] == 10
Example #12
 def test_forward(self, return_transform, random_apply, device, dtype):
     inp = torch.randn(1, 3, 30, 30, device=device, dtype=dtype)
     aug = K.ImageSequential(
         K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
         kornia.filters.MedianBlur((3, 3)),
         K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0, return_transform=True),
         K.ImageSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)),
         K.ImageSequential(K.RandomAffine(360, p=1.0)),
         K.RandomAffine(360, p=1.0),
         K.RandomMixUp(p=1.0),
         return_transform=return_transform,
         random_apply=random_apply,
     )
     out = aug(inp)
     if aug.return_label:
         out, _ = out
     if isinstance(out, (tuple, )):
         assert out[0].shape == inp.shape
     else:
         assert out.shape == inp.shape
     aug.inverse(inp)
     reproducibility_test(inp, aug)
Example #13
 def test_mixup(self, inp, return_transform, random_apply, same_on_batch,
                device, dtype):
     inp = torch.as_tensor(inp, device=device, dtype=dtype)
     aug = K.AugmentationSequential(
         K.ImageSequential(
             K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
             K.RandomAffine(360, p=1.0, return_transform=True)),
         K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
         K.RandomAffine(360, p=1.0),
         K.RandomMixUp(p=1.0),
         data_keys=["input"],
         random_apply=random_apply,
         return_transform=return_transform,
         same_on_batch=same_on_batch,
     )
     out = aug(inp)
     if aug.return_label:
         out, _ = out
     if return_transform and isinstance(out, (tuple, list)):
         out = out[0]
     assert out.shape[-3:] == inp.shape[-3:]
     reproducibility_test(inp, aug)
Example #14
    def test_individual_forward_and_inverse(self, device, dtype):
        inp = torch.randn(1, 3, 1000, 500, device=device, dtype=dtype)
        bbox = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]],
                            device=device,
                            dtype=dtype)
        keypoints = torch.tensor([[[465, 115], [545, 116]]],
                                 device=device,
                                 dtype=dtype)
        mask = bbox_to_mask(
            torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]],
                         device=device,
                         dtype=dtype), 1000, 500)[:, None].float()

        aug = K.AugmentationSequential(
            K.ImageSequential(
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0, return_transform=True)),
            K.AugmentationSequential(
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0, return_transform=True)),
            K.RandomAffine(360, p=1.0, return_transform=False),
            data_keys=['input', 'mask', 'bbox', 'keypoints'],
        )
        reproducibility_test((inp, mask, bbox, keypoints), aug)

        aug = K.AugmentationSequential(
            K.RandomAffine(360, p=1.0, return_transform=True))
        assert aug(inp, data_keys=['input'])[0].shape == inp.shape
        aug = K.AugmentationSequential(
            K.RandomAffine(360, p=1.0, return_transform=False))
        assert aug(inp, data_keys=['input']).shape == inp.shape
        assert aug(mask, data_keys=['mask'],
                   params=aug._params).shape == mask.shape

        assert aug.inverse(inp, data_keys=['input']).shape == inp.shape
        assert aug.inverse(bbox, data_keys=['bbox']).shape == bbox.shape
        assert aug.inverse(keypoints,
                           data_keys=['keypoints']).shape == keypoints.shape
        assert aug.inverse(mask, data_keys=['mask']).shape == mask.shape
Example #15
    def test_forward_and_inverse_return_transform(self, random_apply, device,
                                                  dtype):
        inp = torch.randn(1, 3, 1000, 500, device=device, dtype=dtype)
        bbox = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]],
                            device=device,
                            dtype=dtype)
        keypoints = torch.tensor([[[465, 115], [545, 116]]],
                                 device=device,
                                 dtype=dtype)
        mask = bbox_to_mask(
            torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]],
                         device=device,
                         dtype=dtype), 1000, 500)[:, None].float()
        aug = K.AugmentationSequential(
            K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
                              K.RandomAffine(360, p=1.0)),
            K.AugmentationSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
                                     K.RandomAffine(360, p=1.0)),
            K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomAffine(360, p=1.0),
            data_keys=["input", "mask", "bbox", "keypoints"],
            random_apply=random_apply,
        )
        out = aug(inp, mask, bbox, keypoints)
        assert out[0].shape == inp.shape
        assert out[1].shape == mask.shape
        assert out[2].shape == bbox.shape
        assert out[3].shape == keypoints.shape

        reproducibility_test((inp, mask, bbox, keypoints), aug)

        # TODO(jian): we sometimes throw the following error
        # AttributeError: 'tuple' object has no attribute 'shape'
        out_inv = aug.inverse(*out)
        assert out_inv[0].shape == inp.shape
        assert out_inv[1].shape == mask.shape
        assert out_inv[2].shape == bbox.shape
        assert out_inv[3].shape == keypoints.shape
Example #16
 def test_forward(self, return_transform, device, dtype):
     inp = torch.randn(1, 3, 30, 30, device=device, dtype=dtype)
     aug = K.ImageSequential(
         K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
         kornia.filters.MedianBlur((3, 3)),
         K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0, return_transform=True),
         K.RandomAffine(360, p=1.0),
         return_transform=return_transform,
     )
     out = aug(inp)
     if isinstance(out, (tuple, )):
         assert out[0].shape == inp.shape
     else:
         assert out.shape == inp.shape
Example #17
 def test_construction(self, same_on_batch, keepdim, random_apply):
     aug = K.ImageSequential(
         K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
         K.RandomAffine(360, p=1.0),
         K.RandomMixUp(p=1.0),
         same_on_batch=same_on_batch,
         keepdim=keepdim,
         random_apply=random_apply,
     )
     c = 0
     for a in aug.get_forward_sequence():
         if isinstance(a, (MixAugmentationBase, )):
             c += 1
     assert c < 2
     aug.same_on_batch = True
     aug.keepdim = True
     for m in aug.children():
         assert m.same_on_batch is True, m.same_on_batch
         assert m.keepdim is True, m.keepdim
Example #18
    def __init__(self, utes, mask, ct, length, opt):
        super(TrainDataset, self).__init__()

        self.utes = utes
        self.label = ct
        self.mask = mask
        self.length = length
        self.num_vols = utes.shape[0]

        self.batch_size = opt.trainBatchSize

        self.spatial = nn.Sequential(
            ka.RandomAffine(45,
                            translate=(0.1, 0.1),
                            scale=(0.85, 1.15),
                            shear=(0.1, 0.1),
                            same_on_batch=True),
            ka.RandomVerticalFlip(same_on_batch=True),
            ka.RandomHorizontalFlip(same_on_batch=True))
        self.dim = 2
        self.counter = 0
Example #19
    def __init__(self,
                 brightness=(0.75, 1.25),
                 contrast=(0.75, 1.25),
                 saturation=(0., 2.),
                 translate=(0.125, 0.125),
                 normalized=True,
                 mean=0.5,
                 std=0.5,
                 device=None):
        if normalized:
            if isinstance(mean,
                          (tuple, list)) and isinstance(std, (tuple, list)):
                if not device:
                    raise Exception(
                        'Please specify a torch.device() object when using per-channel mean and std values'
                    )
                mean = torch.Tensor(mean).to(device)
                std = torch.Tensor(std).to(device)
            self.normalize = aug.Normalize(mean, std)
            self.denormalize = aug.Denormalize(mean, std)
        else:
            self.normalize, self.denormalize = None, None

        color_jitter = aug.ColorJitter(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            p=1.)  # rand_brightness, rand_contrast, rand_saturation
        affine = aug.RandomAffine(degrees=0,
                                  translate=translate,
                                  padding_mode=SamplePadding.BORDER,
                                  p=1.)  # rand_translate
        cutout = aug.RandomErasing(value=0.5, p=1.)  # rand_cutout

        self.augmentations = {
            'color': color_jitter,
            'translation': affine,
            'cutout': cutout
        }
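A hedged sketch of how such a policy dict might be applied, DiffAugment-style; the helper name and the policy tuple are hypothetical, not from the original:

def apply_augmentations(x, augmentations,
                        policy=('color', 'translation', 'cutout')):
    # Apply each named kornia augmentation in sequence; every entry was
    # built with p=1.0, so each one fires on every batch.
    for name in policy:
        x = augmentations[name](x)
    return x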
Example #20
		def __init__(self, cut_size, cutn, cut_pow=1.):
			super().__init__()
			self.cut_size = cut_size
			self.cutn = cutn
			self.cut_pow = cut_pow

			self.augs = nn.Sequential(
				# K.RandomHorizontalFlip(p=0.5),
				# K.RandomVerticalFlip(p=0.5),
				# K.RandomSolarize(0.01, 0.01, p=0.7),
				# K.RandomSharpness(0.3, p=0.4),
				# K.RandomResizedCrop(
				#	size=(self.cut_size, self.cut_size), 
				#	scale=(0.1, 1), ratio=(0.75, 1.333), 
				#	cropping_mode="resample", p=0.5
				# ),
				# K.RandomCrop(
				#	size=(self.cut_size, self.cut_size), p=0.5
				# ),
				K.RandomAffine(
					degrees=15, translate=0.1, p=0.7, 
					padding_mode="border"
				),
				K.RandomPerspective(0.7, p=0.7),
				K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
				K.RandomErasing(
					(.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7
				),
			)

			self.noise_fac = 0.1
			self.av_pool = nn.AdaptiveAvgPool2d(
				(self.cut_size, self.cut_size)
			)
			self.max_pool = nn.AdaptiveMaxPool2d(
				(self.cut_size, self.cut_size)
			)
                                               mode="validation",
                                               download=True)

transform = {
    "per_sample_transform":
    nn.Sequential(
        ApplyToKeys(
            DataKeys.INPUT,
            nn.Sequential(
                torchvision.transforms.ToTensor(),
                Kg.Resize((196, 196)),
                # SPATIAL
                Ka.RandomHorizontalFlip(p=0.25),
                Ka.RandomRotation(degrees=90.0, p=0.25),
                Ka.RandomAffine(degrees=1 * 5.0,
                                shear=1 / 5,
                                translate=1 / 20,
                                p=0.25),
                Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25),
                # PIXEL-LEVEL
                Ka.ColorJitter(brightness=1 / 30, p=0.25),  # brightness
                Ka.ColorJitter(saturation=1 / 30, p=0.25),  # saturation
                Ka.ColorJitter(contrast=1 / 30, p=0.25),  # contrast
                Ka.ColorJitter(hue=1 / 30, p=0.25),  # hue
                Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1,
                                    angle=1,
                                    direction=1.0,
                                    p=0.25),
                Ka.RandomErasing(scale=(1 / 100, 1 / 50),
                                 ratio=(1 / 20, 1),
                                 p=0.25),
            ),
Example #22
class TestVideoSequential:
    @pytest.mark.parametrize('shape', [(3, 4), (2, 3, 4), (2, 3, 5, 6),
                                       (2, 3, 4, 5, 6, 7)])
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_exception(self, shape, data_format, device, dtype):
        aug_list = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                                     data_format=data_format,
                                     same_on_frame=True)
        with pytest.raises(AssertionError):
            img = torch.randn(*shape, device=device, dtype=dtype)
            aug_list(img)

    @pytest.mark.parametrize(
        'augmentation',
        [
            K.RandomAffine(360, p=1.0),
            K.CenterCrop((3, 3), p=1.0),
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomCrop((5, 5), p=1.0),
            K.RandomErasing(p=1.0),
            K.RandomGrayscale(p=1.0),
            K.RandomHorizontalFlip(p=1.0),
            K.RandomVerticalFlip(p=1.0),
            K.RandomPerspective(p=1.0),
            K.RandomResizedCrop((5, 5), p=1.0),
            K.RandomRotation(360.0, p=1.0),
            K.RandomSolarize(p=1.0),
            K.RandomPosterize(p=1.0),
            K.RandomSharpness(p=1.0),
            K.RandomEqualize(p=1.0),
            K.RandomMotionBlur(3, 35.0, 0.5, p=1.0),
            K.Normalize(torch.tensor([0.5, 0.5, 0.5]),
                        torch.tensor([0.5, 0.5, 0.5]),
                        p=1.0),
            K.Denormalize(torch.tensor([0.5, 0.5, 0.5]),
                          torch.tensor([0.5, 0.5, 0.5]),
                          p=1.0),
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_augmentation(self, augmentation, data_format, device, dtype):
        input = torch.randint(255, (1, 3, 3, 5, 6), device=device,
                              dtype=dtype).repeat(2, 1, 1, 1, 1) / 255.0
        torch.manual_seed(21)
        aug_list = K.VideoSequential(augmentation,
                                     data_format=data_format,
                                     same_on_frame=True)
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0)
            ],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)
            ],
            [K.RandomAffine(360, p=1.0),
             kornia.color.BgrToRgb()],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0),
                K.RandomAffine(360, p=0.0)
            ],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0)],
            [K.RandomAffine(360, p=0.0)],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0),
                K.RandomMixUp(p=1.0)
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    @pytest.mark.parametrize('random_apply',
                             [1, (1, 1), (1, ), 10, True, False])
    def test_same_on_frame(self, augmentations, data_format, random_apply,
                           device, dtype):
        aug_list = K.VideoSequential(*augmentations,
                                     data_format=data_format,
                                     same_on_frame=True,
                                     random_apply=random_apply)

        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
            output = aug_list(input)
            if aug_list.return_label:
                output, _ = output
            assert (output[:, :, 0] == output[:, :, 1]).all()
            assert (output[:, :, 1] == output[:, :, 2]).all()
            assert (output[:, :, 2] == output[:, :, 3]).all()
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)
            output = aug_list(input)
            if aug_list.return_label:
                output, _ = output
            assert (output[:, 0] == output[:, 1]).all()
            assert (output[:, 1] == output[:, 2]).all()
            assert (output[:, 2] == output[:, 3]).all()
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [K.RandomAffine(360, p=1.0)],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)],
            [
                K.RandomAffine(360, p=0.0),
                K.ImageSequential(K.RandomAffine(360, p=0.0))
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_against_sequential(self, augmentations, data_format, device,
                                dtype):
        aug_list_1 = K.VideoSequential(*augmentations,
                                       data_format=data_format,
                                       same_on_frame=False)
        aug_list_2 = torch.nn.Sequential(*augmentations)

        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)

        torch.manual_seed(0)
        output_1 = aug_list_1(input)

        torch.manual_seed(0)
        if data_format == 'BCTHW':
            input = input.transpose(1, 2)
        output_2 = aug_list_2(input.reshape(-1, 3, 5, 6))
        output_2 = output_2.view(2, 4, 3, 5, 6)
        if data_format == 'BCTHW':
            output_2 = output_2.transpose(1, 2)
        assert (output_1 == output_2).all(), dict(aug_list_1._params)

    @pytest.mark.jit
    @pytest.mark.skip(reason="turned off due to Union type")
    def test_jit(self, device, dtype):
        B, C, D, H, W = 2, 3, 5, 4, 4
        img = torch.ones(B, C, D, H, W, device=device, dtype=dtype)
        op = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                               same_on_frame=True)
        op_jit = torch.jit.script(op)
        assert_close(op(img), op_jit(img))
Example #23
def get_augmenter(augmenter_type: str,
                  image_size: ImageSizeType,
                  dataset_mean: DatasetStatType,
                  dataset_std: DatasetStatType,
                  padding: PaddingInputType = 1. / 8.,
                  pad_if_needed: bool = False,
                  subset_size: int = 2) -> Union[Module, Callable]:
    """
    
    Args:
        augmenter_type: augmenter type
        image_size: (height, width) image size
        dataset_mean: dataset mean value in CHW
        dataset_std: dataset standard deviation in CHW
        padding: percent of image size to pad on each border of the image. If a sequence of length 4 is provided,
            it is used to pad left, top, right, bottom borders respectively. If a sequence of length 2 is provided, it is
            used to pad left/right, top/bottom borders, respectively.
        pad_if_needed: bool flag for RandomCrop "pad_if_needed" option
        subset_size: number of augmentations used in subset

    Returns: nn.Module for Kornia augmentation or Callable for torchvision transform

    """
    if not isinstance(padding, tuple):
        assert isinstance(padding, float)
        padding = (padding, padding, padding, padding)

    assert len(padding) == 2 or len(padding) == 4
    if len(padding) == 2:
        # padding of length 2 is used to pad left/right, top/bottom borders, respectively
        # padding of length 4 is used to pad left, top, right, bottom borders respectively
        padding = (padding[0], padding[1], padding[0], padding[1])

    # image_size is of shape (h,w); padding values is [left, top, right, bottom] borders
    padding = (int(image_size[1] * padding[0]), int(
        image_size[0] * padding[1]), int(image_size[1] * padding[2]),
               int(image_size[0] * padding[3]))

    augmenter_type = augmenter_type.strip().lower()

    if augmenter_type == "simple":
        return nn.Sequential(
            K.RandomCrop(size=image_size,
                         padding=padding,
                         pad_if_needed=pad_if_needed,
                         padding_mode='reflect'),
            K.RandomHorizontalFlip(p=0.5),
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)),
        )

    elif augmenter_type == "fixed":
        return nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            # K.RandomVerticalFlip(p=0.2),
            K.RandomResizedCrop(size=image_size,
                                scale=(0.8, 1.0),
                                ratio=(1., 1.)),
            RandomAugmentation(p=0.5,
                               augmentation=F.GaussianBlur2d(
                                   kernel_size=(3, 3),
                                   sigma=(1.5, 1.5),
                                   border_type='constant')),
            K.ColorJitter(contrast=(0.75, 1.5)),
            # additive Gaussian noise
            K.RandomErasing(p=0.1),
            # Multiply
            K.RandomAffine(degrees=(-25., 25.),
                           translate=(0.2, 0.2),
                           scale=(0.8, 1.2),
                           shear=(-8., 8.)),
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)),
        )

    elif augmenter_type in ["validation", "test"]:
        return nn.Sequential(
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)), )

    elif augmenter_type == "randaugment":
        return nn.Sequential(
            K.RandomCrop(size=image_size,
                         padding=padding,
                         pad_if_needed=pad_if_needed,
                         padding_mode='reflect'),
            K.RandomHorizontalFlip(p=0.5),
            RandAugmentNS(n=subset_size, m=10),
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)),
        )

    else:
        raise NotImplementedError(
            f"\"{augmenter_type}\" is not a supported augmenter type")
Example #24
def shearX(data, opt):
    ra = K.RandomAffine(degrees=0, shear=(-20, 20))
    out = ra(data.view([-1] + list(data.shape[-3:])))

    return out.view(data.shape)
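A quick check of the reshape trick above (assuming K is kornia.augmentation): RandomAffine expects 4D (B, C, H, W) input, so shearX folds any extra leading dimensions into the batch axis and restores them on the way out; opt is accepted but unused in the snippet.

import torch

video = torch.rand(2, 5, 3, 32, 32)   # (batch, time, C, H, W)
out = shearX(video, opt=None)
assert out.shape == video.shape       # leading dims survive the round trip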
Example #25
    def test_random_crops(self, device, dtype):
        torch.manual_seed(233)
        input = torch.randn(3, 3, 3, 3, device=device, dtype=dtype)
        bbox = torch.tensor([[[1.0, 1.0, 2.0, 2.0], [0.0, 0.0, 1.0, 2.0],
                              [0.0, 0.0, 2.0, 1.0]]],
                            device=device,
                            dtype=dtype).expand(3, -1, -1)
        points = torch.tensor([[[0.0, 0.0], [1.0, 1.0]]],
                              device=device,
                              dtype=dtype).expand(3, -1, -1)
        aug = K.AugmentationSequential(
            K.RandomCrop((3, 3), padding=1, cropping_mode='resample', fill=0),
            K.RandomAffine((360., 360.), p=1.),
            data_keys=["input", "mask", "bbox_xyxy", "keypoints"],
        )

        reproducibility_test((input, input, bbox, points), aug)

        _params = aug.forward_parameters(input.shape)
        # specifying the crops allows us to compute by hand the expected outputs
        _params[0].data['src'] = torch.tensor(
            [
                [[1.0, 2.0], [3.0, 2.0], [3.0, 4.0], [1.0, 4.0]],
                [[1.0, 1.0], [3.0, 1.0], [3.0, 3.0], [1.0, 3.0]],
                [[2.0, 0.0], [4.0, 0.0], [4.0, 2.0], [2.0, 2.0]],
            ],
            device=_params[0].data['src'].device,
            dtype=_params[0].data['src'].dtype,
        )

        expected_out_bbox = torch.tensor(
            [
                [[1.0, 0.0, 2.0, 1.0], [0.0, -1.0, 1.0, 1.0],
                 [0.0, -1.0, 2.0, 0.0]],
                [[1.0, 1.0, 2.0, 2.0], [0.0, 0.0, 1.0, 2.0],
                 [0.0, 0.0, 2.0, 1.0]],
                [[0.0, 2.0, 1.0, 3.0], [-1.0, 1.0, 0.0, 3.0],
                 [-1.0, 1.0, 1.0, 2.0]],
            ],
            device=device,
            dtype=dtype,
        )
        expected_out_points = torch.tensor(
            [[[0.0, -1.0], [1.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]],
             [[-1.0, 1.0], [0.0, 2.0]]],
            device=device,
            dtype=dtype)

        out = aug(input, input, bbox, points, params=_params)
        assert out[0].shape == (3, 3, 3, 3)
        assert_close(out[0], out[1], atol=1e-4, rtol=1e-4)
        assert out[2].shape == bbox.shape
        assert_close(out[2], expected_out_bbox, atol=1e-4, rtol=1e-4)
        assert out[3].shape == points.shape
        assert_close(out[3], expected_out_points, atol=1e-4, rtol=1e-4)

        out_inv = aug.inverse(*out)
        assert out_inv[0].shape == input.shape
        assert_close(out_inv[0], out_inv[1], atol=1e-4, rtol=1e-4)
        assert out_inv[2].shape == bbox.shape
        assert_close(out_inv[2], bbox, atol=1e-4, rtol=1e-4)
        assert out_inv[3].shape == points.shape
        assert_close(out_inv[3], points, atol=1e-4, rtol=1e-4)
Example #26
    def __init__(
        self,
        image_size,
        latent_dim=512,
        style_depth=8,
        network_capacity=16,
        transparent=False,
        fp16=False,
        cl_reg=False,
        augment_fn=None,
        steps=1,
        lr=1e-4,
        fq_layers=[],
        fq_dict_size=256,
        attn_layers=[],
    ):
        super().__init__()
        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        self.S = StyleVectorizer(latent_dim, style_depth)
        self.G = Generator(image_size,
                           latent_dim,
                           network_capacity,
                           transparent=transparent,
                           attn_layers=attn_layers)
        self.D = Discriminator(
            image_size,
            network_capacity,
            fq_layers=fq_layers,
            fq_dict_size=fq_dict_size,
            attn_layers=attn_layers,
            transparent=transparent,
        )

        self.SE = StyleVectorizer(latent_dim, style_depth)
        self.GE = Generator(image_size,
                            latent_dim,
                            network_capacity,
                            transparent=transparent,
                            attn_layers=attn_layers)

        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)

        generator_params = list(self.G.parameters()) + list(
            self.S.parameters())
        self.G_opt = DiffGrad(generator_params, lr=self.lr, betas=(0.5, 0.9))
        self.D_opt = DiffGrad(self.D.parameters(),
                              lr=self.lr,
                              betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()

        self.cuda()

        if fp16:
            (self.S, self.G, self.D, self.SE,
             self.GE), (self.G_opt, self.D_opt) = amp.initialize(
                 [self.S, self.G, self.D, self.SE, self.GE],
                 [self.G_opt, self.D_opt],
                 opt_level="O2")

        # experimental contrastive loss discriminator regularization
        if augment_fn is not None:
            self.augment_fn = augment_fn
        else:
            self.augment_fn = nn.Sequential(
                nn.ReflectionPad2d(int((sqrt(2) - 1) * image_size / 4)),
                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.7),
                augs.RandomGrayscale(p=0.2),
                augs.RandomHorizontalFlip(),
                RandomApply(augs.RandomAffine(degrees=0,
                                              translate=(0.25, 0.25),
                                              shear=(15, 15)),
                            p=0.3),
                RandomApply(nn.Sequential(
                    augs.RandomRotation(180),
                    augs.CenterCrop(size=(image_size, image_size))),
                            p=0.2),
                augs.RandomResizedCrop(size=(image_size, image_size)),
                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
                RandomApply(augs.RandomErasing(), p=0.1),
            )

        self.D_cl = (ContrastiveLearner(self.D,
                                        image_size,
                                        augment_fn=self.augment_fn,
                                        fp16=fp16,
                                        hidden_layer="flatten")
                     if cl_reg else None)
Example #27
    opt = optim.Adam({}, clip_args={"clip_norm": 10.0})
    elbo = infer.Trace_ELBO()

    svi = infer.SVI(
        transforming_template_mnist, transforming_template_encoder, opt, loss=elbo
    )
    files = glob.glob("runs/*")
    for f in files:
        os.remove(f)
    tb = SummaryWriter("runs/")

    x_train = x_train.expand(516, 1, 28, 28)
    x_train = F.pad(x_train, (6, 6, 6, 6))
    for i in range(20000):
        x_rot = augmentation.RandomAffine(40.0)(
            x_train
        )  # randomly rotate the image by up to 40 degrees

        l = svi.step(x_rot, transforms, cond=True, grid_size=40, encoder=encoder)
        tb.add_scalar("loss", l, i)
        if (i % 100) == 0:
            print(i, l)
            x_rot = x_rot[:32]
            tb.add_image("originals", torchvision.utils.make_grid(x_rot), i)
            bwd_trace = poutine.trace(transforming_template_encoder).get_trace(
                x_rot, transforms, cond=True, grid_size=40, encoder=encoder
            )
            fwd_trace = poutine.trace(
                poutine.replay(transforming_template_mnist, trace=bwd_trace)
            ).get_trace(x_rot, transforms, cond=False, grid_size=40, encoder=encoder)
            recon = fwd_trace.nodes["pixels"]["fn"].mean
Example #28
    g_ema = Generator(
        args.size,
        args.latent_size,
        args.n_mlp,
        channel_multiplier=args.channel_multiplier,
        constant_input=args.constant_input,
    ).to(device)
    g_ema.requires_grad_(False)
    g_ema.eval()
    accumulate(g_ema, generator, 0)

    augment_fn = nn.Sequential(
        nn.ReflectionPad2d(int((math.sqrt(2) - 1) * args.size / 4)),  # zoom out
        augs.RandomHorizontalFlip(),
        RandomApply(augs.RandomAffine(degrees=0, translate=(0.25, 0.25), shear=(15, 15)), p=0.2),
        RandomApply(augs.RandomRotation(180), p=0.2),
        augs.RandomResizedCrop(size=(args.size, args.size), scale=(1, 1), ratio=(1, 1)),
        RandomApply(augs.RandomResizedCrop(size=(args.size, args.size), scale=(0.5, 0.9)), p=0.1),  # zoom in
        RandomApply(augs.RandomErasing(), p=0.1),
    )
    contrast_learner = (
        ContrastiveLearner(discriminator, args.size, augment_fn=augment_fn, hidden_layer=(-1, 0))
        if args.contrastive > 0
        else None
    )

    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = th.optim.Adam(
Example #29
def main(argv=None):
    # CLI
    parser = argparse.ArgumentParser()
    parser.add_argument("name", help="Name of the experiment")
    parser.add_argument(
        "-a",
        "--augment",
        action="store_true",
        help="If True, we apply augmentations",
    )
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=16,
                        help="Batch size")
    parser.add_argument(
        "--b1",
        type=float,
        default=0.5,
        help="Adam optimizer hyperparamter",
    )
    parser.add_argument(
        "--b2",
        type=float,
        default=0.999,
        help="Adam optimizer hyperparamter",
    )
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to use",
    )
    parser.add_argument(
        "--eval-frequency",
        type=int,
        default=400,
        help="Generate generator images every `eval_frequency` epochs",
    )
    parser.add_argument(
        "--latent-dim",
        type=int,
        default=100,
        help="Dimensionality of the random noise",
    )
    parser.add_argument("--lr",
                        type=float,
                        default=0.0002,
                        help="Learning rate")
    parser.add_argument(
        "--ndf",
        type=int,
        default=32,
        help="Number of discriminator feature maps (after first convolution)",
    )
    parser.add_argument(
        "--ngf",
        type=int,
        default=32,
        help=
        "Number of generator feature maps (before last transposed convolution)",
    )
    parser.add_argument(
        "-n",
        "--n-epochs",
        type=int,
        default=200,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--mosaic-size",
        type=int,
        default=10,
        help="Size of the side of the rectangular mosaic",
    )
    parser.add_argument(
        "-p",
        "--prob",
        type=float,
        default=0.9,
        help="Probability of applying an augmentation",
    )

    args = parser.parse_args(argv)
    args_d = vars(args)
    print(args)

    img_size = 128

    # Additional parameters
    device = torch.device(args.device)
    mosaic_kwargs = {"nrow": args.mosaic_size, "normalize": True}
    n_mosaic_cells = args.mosaic_size * args.mosaic_size
    sample_showcase_ix = 0  # this one will be used to demonstrate the augmentations

    augment_module = torch.nn.Sequential(
        K.RandomAffine(degrees=0, translate=(1 / 8, 1 / 8), p=args.prob),
        K.RandomErasing((0.0, 0.5), p=args.prob),
    )

    # Loss function
    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = Generator(latent_dim=args.latent_dim, ngf=args.ngf)
    discriminator = Discriminator(
        ndf=args.ndf, augment_module=augment_module if args.augment else None)

    generator.to(device)
    discriminator.to(device)

    # Initialize weights
    generator.apply(init_weights_)
    discriminator.apply(init_weights_)

    # Configure data loader
    data_path = pathlib.Path("data")
    tform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = DatasetImages(
        data_path,
        transform=tform,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
    )

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=args.lr,
                                   betas=(args.b1, args.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=args.lr,
                                   betas=(args.b1, args.b2))

    # Output path and metadata
    output_path = pathlib.Path("outputs") / args.name
    output_path.mkdir(exist_ok=True, parents=True)

    # Add other parameters (not included in CLI)
    args_d["time"] = datetime.now()
    args_d["kornia"] = str(augment_module)

    # Prepare tensorboard writer
    writer = SummaryWriter(output_path)

    # Log hyperparameters as text
    writer.add_text(
        "hyperparameter",
        pprint.pformat(args_d).replace(
            "\n", "  \n"),  # markdown needs 2 spaces before newline
        0,
    )
    # Log true data
    writer.add_image(
        "true_data",
        make_grid(torch.stack([dataset[i] for i in range(n_mosaic_cells)]),
                  **mosaic_kwargs),
        0,
    )
    # Log augmented data
    batch_showcase = dataset[sample_showcase_ix][None, ...].repeat(
        n_mosaic_cells, 1, 1, 1)
    batch_showcase_aug = discriminator.augment_module(batch_showcase)
    writer.add_image("augmentations",
                     make_grid(batch_showcase_aug, **mosaic_kwargs), 0)

    # Prepare evaluation noise
    z_eval = torch.randn(n_mosaic_cells, args.latent_dim).to(device)

    for epoch in tqdm(range(args.n_epochs)):
        for i, imgs in enumerate(dataloader):
            n_samples, *_ = imgs.shape
            batches_done = epoch * len(dataloader) + i

            # Adversarial ground truths
            valid = 0.9 * torch.ones(
                n_samples, 1, device=device, dtype=torch.float32)
            fake = torch.zeros(n_samples,
                               1,
                               device=device,
                               dtype=torch.float32)

            # D preparation
            optimizer_D.zero_grad()

            # D loss on reals
            real_imgs = imgs.to(device)
            d_x = discriminator(real_imgs)
            real_loss = adversarial_loss(d_x, valid)
            real_loss.backward()

            # D loss on fakes
            z = torch.randn(n_samples, args.latent_dim).to(device)
            gen_imgs = generator(z)
            d_g_z1 = discriminator(gen_imgs.detach())

            fake_loss = adversarial_loss(d_g_z1, fake)
            fake_loss.backward()

            optimizer_D.step()  # backward ran twice, so real and fake gradients sum

            # G preparation
            optimizer_G.zero_grad()

            # G loss
            d_g_z2 = discriminator(gen_imgs)
            g_loss = adversarial_loss(d_g_z2, valid)

            g_loss.backward()
            optimizer_G.step()

            # Logging
            if batches_done % 50 == 0:
                writer.add_scalar("d_x", d_x.mean().item(), batches_done)
                writer.add_scalar("d_g_z1", d_g_z1.mean().item(), batches_done)
                writer.add_scalar("d_g_z2", d_g_z2.mean().item(), batches_done)
                writer.add_scalar("D_loss", (real_loss + fake_loss).item(),
                                  batches_done)
                writer.add_scalar("G_loss", g_loss.item(), batches_done)

            if epoch % args.eval_frequency == 0 and i == 0:
                generator.eval()
                discriminator.eval()

                # Generate fake images
                gen_imgs_eval = generator(z_eval)

                # Generate nice mosaic
                writer.add_image(
                    "fake",
                    make_grid(gen_imgs_eval.data, **mosaic_kwargs),
                    batches_done,
                )

                # Save checkpoint (and potentially overwrite an existing one)
                torch.save(generator, output_path / "model.pt")

                # Make sure generator and discriminator are back in training mode
                generator.train()
                discriminator.train()