Пример #1
0
def kornia_color_jitter_numpy(img, setting):
    """Run Kornia's ``ColorJitter`` on a NumPy image.

    Args:
        img: NumPy array. Axis 2 is moved to the front before jittering,
            so it is presumably laid out as (h, w, c) -- confirm with callers.
        setting: Strength passed as each of the four positional arguments
            of ``ColorJitter``.

    Returns:
        ``img`` itself when ``setting * 255`` is not greater than 1;
        otherwise the jittered image converted back to a NumPy array with
        the original axis order.
    """
    # Strengths at or below 1/255 skip the jitter entirely.
    if not setting * 255 > 1:
        return img

    # Kornia's ColorJitter wants torch tensors in (b, c, h, w) layout.
    batch = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
    jitter = ColorJitter(setting, setting, setting, setting)
    jittered = jitter(batch)

    # Undo the batch dimension and the axis permutation.
    return jittered.squeeze(0).permute(1, 2, 0).numpy()
Пример #2
0
    def test_param(self, brightness, contrast, saturation, hue,
                   return_transform, same_on_batch, device, dtype):
        """Check that tensor-valued ColorJitter settings are trainable.

        Each of ``brightness``, ``contrast``, ``saturation`` and ``hue`` is
        either a plain number (passed through as-is) or a tensor (wrapped in
        ``nn.Parameter``). After one SGD step on an MSE loss, every
        tensor-valued setting must differ from its initial value.

        The arguments are presumably supplied by pytest parametrization and
        fixtures -- the decorators are not visible from here.
        """
        def as_setting(value):
            # Plain numbers are forwarded untouched; tensors are cloned onto
            # the target device/dtype and made trainable.
            if isinstance(value, (int, float)):
                return value
            return nn.Parameter(value.clone().to(device=device, dtype=dtype))

        initial = {
            "brightness": brightness,
            "contrast": contrast,
            "saturation": saturation,
            "hue": hue,
        }

        _brightness = as_setting(brightness)
        _contrast = as_setting(contrast)
        _saturation = as_setting(saturation)
        _hue = as_setting(hue)

        torch.manual_seed(0)
        images = torch.randint(255, (2, 3, 10, 10), device=device,
                               dtype=dtype) / 255.0
        aug = ColorJitter(_brightness, _contrast, _saturation, _hue,
                          return_transform, same_on_batch)

        if return_transform:
            output, _ = aug(images)
        else:
            output = aug(images)

        # Only optimise when at least one setting was registered as a
        # parameter; an optimizer cannot be built from an empty list.
        if len(list(aug.parameters())) != 0:
            mse = nn.MSELoss()
            opt = torch.optim.SGD(aug.parameters(), lr=0.1)
            # A target of 2 is never matched exactly by the output, so the
            # loss is non-zero.
            loss = mse(output, torch.ones_like(output) * 2)
            loss.backward()
            opt.step()

        for name, original in initial.items():
            if isinstance(original, (int, float)):
                continue
            current = getattr(aug, name)
            assert isinstance(current, torch.Tensor)
            # Compare absolute differences: a signed sum can cancel to zero
            # when elements move in opposite directions, which would wrongly
            # report the parameter as not updated.
            delta = original.to(device=device, dtype=dtype) - current.data
            assert delta.abs().sum() != 0, f"{name} was not updated"
Пример #3
0
    def test_color_jitter_batch(self):
        """Assert a default ColorJitter returns a batch as-is.

        With ``return_transform=True`` the second output is additionally
        compared against a batch of 3 x 3 identity matrices.
        """
        plain = ColorJitter()
        with_transform = ColorJitter(return_transform=True)

        batch = torch.rand(2, 3, 5, 5)  # 2 x 3 x 5 x 5
        identity = torch.eye(3)[None].expand(2, 3, 3)  # 2 x 3 x 3

        # The expected image output is the input itself.
        assert_allclose(plain(batch), batch, atol=1e-4, rtol=1e-5)
        assert_allclose(with_transform(batch)[0], batch, atol=1e-4, rtol=1e-5)
        assert_allclose(with_transform(batch)[1], identity)
Пример #4
0
    def test_sequential(self):
        """Assert two chained default ColorJitters return the image as-is.

        Both stages are built with ``return_transform=True``; the second
        output of the pipeline is compared against a single 3 x 3 identity.
        """
        pipeline = nn.Sequential(
            *(ColorJitter(return_transform=True) for _ in range(2))
        )

        image = torch.rand(3, 5, 5)  # 3 x 5 x 5
        identity = torch.eye(3)[None]  # 1 x 3 x 3

        # The expected image output is the input itself.
        assert_allclose(pipeline(image)[0], image, atol=1e-4, rtol=1e-5)
        assert_allclose(pipeline(image)[1], identity)
Пример #5
0
    def test_random_hue_tensor(self):
        """Golden-value check for ColorJitter with a tensor hue range.

        The RNG is seeded with 42 before the augmentation is built, and the
        output for a two-sample batch is compared against stored values.
        """
        torch.manual_seed(42)
        jitter = ColorJitter(hue=torch.tensor([-0.2, 0.2]))

        channels = [
            [[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.]],
            [[1.0, 0.5, 0.6], [0.6, 0.3, 0.2], [0.8, 0.1, 0.2]],
            [[0.6, 0.8, 0.7], [0.9, 0.3, 0.2], [0.8, 0.4, .5]],
        ]
        # 1 x 3 x 3 x 3 image duplicated along the batch axis -> 2 x 3 x 3 x 3
        batch = torch.tensor([channels]).repeat(2, 1, 1, 1)

        sample_0 = [
            [[0.1000, 0.2000, 0.3000],
             [0.6000, 0.5000, 0.4000],
             [0.7000, 0.8000, 1.0000]],
            [[1.0000, 0.5251, 0.6167],
             [0.6126, 0.3000, 0.2000],
             [0.8000, 0.1000, 0.2000]],
            [[0.5623, 0.8000, 0.7000],
             [0.9000, 0.3084, 0.2084],
             [0.7958, 0.4293, 0.5335]],
        ]
        sample_1 = [
            [[0.1000, 0.2000, 0.3000],
             [0.6116, 0.5000, 0.4000],
             [0.7000, 0.8000, 1.0000]],
            [[1.0000, 0.4769, 0.5846],
             [0.6000, 0.3077, 0.2077],
             [0.7961, 0.1000, 0.2000]],
            [[0.6347, 0.8000, 0.7000],
             [0.9000, 0.3000, 0.2000],
             [0.8000, 0.3730, 0.4692]],
        ]
        expected = torch.tensor([sample_0, sample_1])

        assert_allclose(jitter(batch), expected)
Пример #6
0
    def test_random_saturation_tuple(self):
        """Golden-value check for ColorJitter with a tuple saturation range.

        The RNG is seeded with 42 before the augmentation is built, and the
        output for a two-sample batch is compared against stored values.
        """
        torch.manual_seed(42)
        jitter = ColorJitter(saturation=(0.8, 1.2))

        channels = [
            [[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.]],
            [[1.0, 0.5, 0.6], [0.6, 0.3, 0.2], [0.8, 0.1, 0.2]],
            [[0.6, 0.8, 0.7], [0.9, 0.3, 0.2], [0.8, 0.4, .5]],
        ]
        # 1 x 3 x 3 x 3 image duplicated along the batch axis -> 2 x 3 x 3 x 3
        batch = torch.tensor([channels]).repeat(2, 1, 1, 1)

        sample_0 = [
            [[1.8763e-01, 2.5842e-01, 3.3895e-01],
             [6.2921e-01, 5.0000e-01, 4.0000e-01],
             [7.0974e-01, 8.0000e-01, 1.0000e+00]],
            [[1.0000e+00, 5.2921e-01, 6.0974e-01],
             [6.2921e-01, 3.1947e-01, 2.1947e-01],
             [8.0000e-01, 1.6816e-01, 2.7790e-01]],
            [[6.3895e-01, 8.0000e-01, 7.0000e-01],
             [9.0000e-01, 3.1947e-01, 2.1947e-01],
             [8.0000e-01, 4.3895e-01, 5.4869e-01]],
        ]
        sample_1 = [
            [[1.1921e-07, 1.2953e-01, 2.5302e-01],
             [5.6476e-01, 5.0000e-01, 4.0000e-01],
             [6.8825e-01, 8.0000e-01, 1.0000e+00]],
            [[1.0000e+00, 4.6476e-01, 5.8825e-01],
             [5.6476e-01, 2.7651e-01, 1.7651e-01],
             [8.0000e-01, 1.7781e-02, 1.0603e-01]],
            [[5.5556e-01, 8.0000e-01, 7.0000e-01],
             [9.0000e-01, 2.7651e-01, 1.7651e-01],
             [8.0000e-01, 3.5302e-01, 4.4127e-01]],
        ]
        expected = torch.tensor([sample_0, sample_1])

        assert_allclose(jitter(batch), expected)
Пример #7
0
    def test_random_contrast_list(self):
        """Golden-value check for ColorJitter with a list contrast range.

        The RNG is seeded with 42 before the augmentation is built. The
        input repeats one 3 x 3 plane over three channels and two samples,
        and the stored expectation has three identical channels per sample.
        """
        torch.manual_seed(42)
        jitter = ColorJitter(contrast=[0.8, 1.2])

        plane = [[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.]]
        # 1 x 1 x 3 x 3 plane tiled over batch and channels -> 2 x 3 x 3 x 3
        batch = torch.tensor([[plane]]).repeat(2, 3, 1, 1)

        plane_0 = [[0.0953, 0.1906, 0.2859],
                   [0.5719, 0.4766, 0.3813],
                   [0.6672, 0.7625, 0.9531]]
        plane_1 = [[0.1184, 0.2367, 0.3551],
                   [0.7102, 0.5919, 0.4735],
                   [0.8286, 0.9470, 1.0000]]
        # Every channel of a sample carries the same expected plane.
        expected = torch.tensor([[plane_0] * 3, [plane_1] * 3])

        assert_allclose(jitter(batch), expected, atol=1e-4, rtol=1e-5)
Пример #8
0
    def test_random_brightness_tuple(self):
        """Golden-value check for ColorJitter with a tuple brightness range.

        The RNG is seeded with 42 before the augmentation is built. The
        input repeats one 3 x 3 plane over three channels and two samples,
        and the stored expectation has three identical channels per sample.
        """
        torch.manual_seed(42)
        jitter = ColorJitter(brightness=(-0.2, 0.2))

        plane = [[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.]]
        # 1 x 1 x 3 x 3 plane tiled over batch and channels -> 2 x 3 x 3 x 3
        batch = torch.tensor([[plane]]).repeat(2, 3, 1, 1)

        plane_0 = [[0.2529, 0.3529, 0.4529],
                   [0.7529, 0.6529, 0.5529],
                   [0.8529, 0.9529, 1.0000]]
        plane_1 = [[0.2660, 0.3660, 0.4660],
                   [0.7660, 0.6660, 0.5660],
                   [0.8660, 0.9660, 1.0000]]
        # Every channel of a sample carries the same expected plane.
        expected = torch.tensor([[plane_0] * 3, [plane_1] * 3])  # 2 x 3 x 3 x 3

        assert_allclose(jitter(batch), expected)
Пример #9
0
 def smoke_test(self):
     """Check the string form of a ColorJitter built with explicit settings."""
     expected_text = "ColorJitter(brightness=0.5, contrast=0.3, saturation=[0.2, 1.2], hue=0.1, return_transform=False)"
     settings = {
         "brightness": 0.5,
         "contrast": 0.3,
         "saturation": [0.2, 1.2],
         "hue": 0.1,
     }
     jitter = ColorJitter(**settings)
     assert str(jitter) == expected_text
Пример #10
0
 def color_model(gaussian_p, solarize_p):
     """Build the augmentation pipeline as an ``nn.Sequential``.

     Args:
         gaussian_p: Probability passed to the ``RandomApply`` wrapper around
             the Gaussian blur.
         solarize_p: Probability passed to ``RandomSolarize``.

     ``image_size`` and ``get_kernel_size`` are taken from the enclosing
     scope.
     """
     crop_size = (image_size, image_size)
     stages = [
         RandomResizedCrop(crop_size, interpolation="BICUBIC"),
         RandomHorizontalFlip(),
         ColorJitter(0.4, 0.4, 0.2, 0.1, p=0.8),
         RandomGrayscale(p=0.2),
         RandomApply(
             GaussianBlur2d(get_kernel_size(image_size), (0.1, 2.0)),
             p=gaussian_p,
         ),
         RandomSolarize(0, 0, p=solarize_p),
     ]
     return nn.Sequential(*stages)
Пример #11
0
    def __init__(self, input_shape, s=1.0, apply_transforms=None):
        """Store the input geometry and set up the augmentation pipeline.

        Args:
            input_shape: Three-element shape, read as (H, W, C).
            s: Multiplier applied to the colour-jitter strengths of the
                default pipeline.
            apply_transforms: Pipeline to use as-is. When ``None``, a default
                pipeline of crop, flip, colour jitter, grayscale and Gaussian
                blur is built.
        """
        assert len(input_shape) == 3, "input_shape should be (H, W, C)"

        self.input_shape = input_shape
        self.H = input_shape[0]
        self.W = input_shape[1]
        self.C = input_shape[2]
        self.s = s
        self.apply_transforms = apply_transforms

        # A caller-supplied pipeline is kept untouched.
        if self.apply_transforms is not None:
            return

        # Blur kernel side is 10% of the image height, truncated to an int.
        # NOTE(review): this can be even or zero for some heights -- confirm
        # GaussianBlur2d accepts that.
        blur_side = int(0.1 * self.H)
        blur_sigma = self._get_sigma()

        strong = 0.8 * self.s  # brightness / contrast / saturation strength
        mild = 0.2 * self.s  # hue strength

        default_pipeline = [
            RandomResizedCrop(size=(self.H, self.W), scale=(0.08, 1.0)),
            RandomHorizontalFlip(p=0.5),
            ColorJitter(strong, strong, strong, mild),
            RandomGrayscale(p=0.2),
            GaussianBlur2d(kernel_size=(blur_side, blur_side),
                           sigma=(blur_sigma, blur_sigma)),
        ]
        self.apply_transforms = KorniaCompose(default_pipeline)
Пример #12
0
 def color_jitter(self, hue: float = 0.0, p: float = 1.0) -> TransformType:
     """Return a ``ColorJitter`` with brightness, contrast and saturation at 0.4.

     Args:
         hue: Value forwarded as the ``hue`` argument.
         p: Value forwarded as the ``p`` argument.
     """
     strength = 0.4  # shared by brightness, contrast and saturation
     return ColorJitter(
         brightness=strength,
         contrast=strength,
         saturation=strength,
         hue=hue,
         p=p,
     )