def kornia_color_jitter_numpy(img, setting):
    """Apply Kornia's ColorJitter to a channels-last numpy image.

    The same ``setting`` is used for brightness, contrast, saturation and hue.

    Args:
        img: numpy image, presumably laid out as (H, W, C) given the
            permutations below — TODO confirm with callers.
        setting: jitter strength. When ``setting * 255`` does not exceed 1 the
            image is returned untouched (presumably ``setting`` is on a 0-1
            intensity scale, so this skips sub-one-level jitter).

    Returns:
        The jittered image converted back to a channels-last numpy array, or
        ``img`` itself when the jitter is skipped.
    """
    if not setting * 255 > 1:
        return img
    # Kornia wants torch tensors in (batch, channels, height, width) order.
    batch = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
    # NOTE(review): ``setting`` is also used as the hue factor; Kornia bounds
    # the hue range, so large settings may be rejected — confirm the valid range.
    jitter = ColorJitter(setting, setting, setting, setting)
    return jitter(batch).squeeze(0).permute(1, 2, 0).numpy()
def test_param(self, brightness, contrast, saturation, hue, return_transform, same_on_batch, device, dtype):
    """Check that tensor-valued ColorJitter factors are trainable.

    Scalar (int/float) factors are passed through unchanged. Tensor factors are
    cloned onto ``device``/``dtype`` and wrapped in ``nn.Parameter``. After one
    SGD step (lr=0.1) on an MSE loss against a constant target of 2, every
    tensor factor must be exposed on the augmentation as a tensor whose value
    differs from the caller's original tensor.
    """

    def _as_param(value):
        # Scalars stay plain numbers; tensors become learnable parameters. The
        # clone keeps the caller's tensor untouched so it remains the "before"
        # value for the comparison below.
        if isinstance(value, (int, float)):
            return value
        return nn.Parameter(value.clone().to(device=device, dtype=dtype))

    initial_factors = {
        "brightness": brightness,
        "contrast": contrast,
        "saturation": saturation,
        "hue": hue,
    }

    aug_args = (_as_param(brightness), _as_param(contrast), _as_param(saturation), _as_param(hue))

    torch.manual_seed(0)
    inputs = torch.randint(255, (2, 3, 10, 10), device=device, dtype=dtype) / 255.0
    aug = ColorJitter(*aug_args, return_transform, same_on_batch)

    if return_transform:
        output, _ = aug(inputs)
    else:
        output = aug(inputs)

    if len(list(aug.parameters())) != 0:
        mse = nn.MSELoss()
        opt = torch.optim.SGD(aug.parameters(), lr=0.1)
        loss = mse(output, torch.ones_like(output) * 2)
        loss.backward()
        opt.step()

    for name, initial in initial_factors.items():
        if isinstance(initial, (int, float)):
            continue
        updated = getattr(aug, name)
        assert isinstance(updated, torch.Tensor)
        # Assert the parameter was updated. Sum absolute differences: a signed
        # sum could cancel to zero even though individual entries changed.
        assert (initial.to(device=device, dtype=dtype) - updated.data).abs().sum() != 0
def test_color_jitter_batch(self):
    """A default-constructed ColorJitter must leave a batch unchanged.

    Also checks that, with ``return_transform=True``, the reported transform is
    a 3 x 3 identity for every sample in the batch.
    """
    f = ColorJitter()
    f1 = ColorJitter(return_transform=True)
    inputs = torch.rand(2, 3, 5, 5)  # 2 x 3 x 5 x 5
    expected = inputs
    expected_transform = torch.eye(3).unsqueeze(0).expand((2, 3, 3))  # 2 x 3 x 3

    assert_allclose(f(inputs), expected, atol=1e-4, rtol=1e-5)

    # Run the transform-returning module once, so the image and the matrix
    # being checked come from the same forward pass.
    result = f1(inputs)
    assert_allclose(result[0], expected, atol=1e-4, rtol=1e-5)
    assert_allclose(result[1], expected_transform)
def test_sequential(self):
    """Two chained transform-returning ColorJitters keep the image and report identity.

    The input is a single unbatched 3 x 5 x 5 image; the expected transform is
    a 1 x 3 x 3 identity.
    """
    f = nn.Sequential(
        ColorJitter(return_transform=True),
        ColorJitter(return_transform=True),
    )
    inputs = torch.rand(3, 5, 5)  # 3 x 5 x 5
    expected = inputs
    expected_transform = torch.eye(3).unsqueeze(0)  # 1 x 3 x 3

    # Run the pipeline once, so image and transform come from the same pass.
    result = f(inputs)
    assert_allclose(result[0], expected, atol=1e-4, rtol=1e-5)
    assert_allclose(result[1], expected_transform)
def test_random_hue_tensor(self):
    """Seeded hue jitter with a tensor range [-0.2, 0.2] yields known outputs per sample."""
    # One 3-channel 3 x 3 image, duplicated into a batch of two.
    image = torch.tensor([
        [[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.]],
        [[1.0, 0.5, 0.6], [0.6, 0.3, 0.2], [0.8, 0.1, 0.2]],
        [[0.6, 0.8, 0.7], [0.9, 0.3, 0.2], [0.8, 0.4, .5]],
    ])  # 3 x 3 x 3
    batch = image.unsqueeze(0).repeat(2, 1, 1, 1)  # 2 x 3 x 3 x 3

    first_sample = torch.tensor([
        [[0.1000, 0.2000, 0.3000], [0.6000, 0.5000, 0.4000], [0.7000, 0.8000, 1.0000]],
        [[1.0000, 0.5251, 0.6167], [0.6126, 0.3000, 0.2000], [0.8000, 0.1000, 0.2000]],
        [[0.5623, 0.8000, 0.7000], [0.9000, 0.3084, 0.2084], [0.7958, 0.4293, 0.5335]],
    ])
    second_sample = torch.tensor([
        [[0.1000, 0.2000, 0.3000], [0.6116, 0.5000, 0.4000], [0.7000, 0.8000, 1.0000]],
        [[1.0000, 0.4769, 0.5846], [0.6000, 0.3077, 0.2077], [0.7961, 0.1000, 0.2000]],
        [[0.6347, 0.8000, 0.7000], [0.9000, 0.3000, 0.2000], [0.8000, 0.3730, 0.4692]],
    ])
    expected = torch.stack([first_sample, second_sample])  # 2 x 3 x 3 x 3

    torch.manual_seed(42)
    jitter = ColorJitter(hue=torch.tensor([-0.2, 0.2]))
    assert_allclose(jitter(batch), expected)
def test_random_saturation_tuple(self):
    """Seeded saturation jitter with a tuple range (0.8, 1.2) yields known outputs per sample."""
    # One 3-channel 3 x 3 image, duplicated into a batch of two.
    image = torch.tensor([
        [[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.]],
        [[1.0, 0.5, 0.6], [0.6, 0.3, 0.2], [0.8, 0.1, 0.2]],
        [[0.6, 0.8, 0.7], [0.9, 0.3, 0.2], [0.8, 0.4, .5]],
    ])  # 3 x 3 x 3
    batch = image.unsqueeze(0).repeat(2, 1, 1, 1)  # 2 x 3 x 3 x 3

    first_sample = torch.tensor([
        [[1.8763e-01, 2.5842e-01, 3.3895e-01],
         [6.2921e-01, 5.0000e-01, 4.0000e-01],
         [7.0974e-01, 8.0000e-01, 1.0000e+00]],
        [[1.0000e+00, 5.2921e-01, 6.0974e-01],
         [6.2921e-01, 3.1947e-01, 2.1947e-01],
         [8.0000e-01, 1.6816e-01, 2.7790e-01]],
        [[6.3895e-01, 8.0000e-01, 7.0000e-01],
         [9.0000e-01, 3.1947e-01, 2.1947e-01],
         [8.0000e-01, 4.3895e-01, 5.4869e-01]],
    ])
    second_sample = torch.tensor([
        [[1.1921e-07, 1.2953e-01, 2.5302e-01],
         [5.6476e-01, 5.0000e-01, 4.0000e-01],
         [6.8825e-01, 8.0000e-01, 1.0000e+00]],
        [[1.0000e+00, 4.6476e-01, 5.8825e-01],
         [5.6476e-01, 2.7651e-01, 1.7651e-01],
         [8.0000e-01, 1.7781e-02, 1.0603e-01]],
        [[5.5556e-01, 8.0000e-01, 7.0000e-01],
         [9.0000e-01, 2.7651e-01, 1.7651e-01],
         [8.0000e-01, 3.5302e-01, 4.4127e-01]],
    ])
    expected = torch.stack([first_sample, second_sample])  # 2 x 3 x 3 x 3

    torch.manual_seed(42)
    jitter = ColorJitter(saturation=(0.8, 1.2))
    assert_allclose(jitter(batch), expected)
def test_random_contrast_list(self):
    """Seeded contrast jitter with a list range [0.8, 1.2] yields known outputs per sample."""
    # A single 3 x 3 plane tiled into a batch of two 3-channel images.
    plane = torch.tensor([[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.]])  # 3 x 3
    batch = plane.repeat(2, 3, 1, 1)  # 2 x 3 x 3 x 3

    # All channels of a sample share the same expected plane.
    first_plane = torch.tensor([[0.0953, 0.1906, 0.2859],
                                [0.5719, 0.4766, 0.3813],
                                [0.6672, 0.7625, 0.9531]])
    second_plane = torch.tensor([[0.1184, 0.2367, 0.3551],
                                 [0.7102, 0.5919, 0.4735],
                                 [0.8286, 0.9470, 1.0000]])
    expected = torch.stack([first_plane.repeat(3, 1, 1), second_plane.repeat(3, 1, 1)])  # 2 x 3 x 3 x 3

    torch.manual_seed(42)
    jitter = ColorJitter(contrast=[0.8, 1.2])
    assert_allclose(jitter(batch), expected, atol=1e-4, rtol=1e-5)
def test_random_brightness_tuple(self):
    """Seeded brightness jitter with a tuple range (-0.2, 0.2) yields known outputs per sample."""
    # A single 3 x 3 plane tiled into a batch of two 3-channel images.
    plane = torch.tensor([[0.1, 0.2, 0.3], [0.6, 0.5, 0.4], [0.7, 0.8, 1.]])  # 3 x 3
    batch = plane.repeat(2, 3, 1, 1)  # 2 x 3 x 3 x 3

    # All channels of a sample share the same expected plane.
    first_plane = torch.tensor([[0.2529, 0.3529, 0.4529],
                                [0.7529, 0.6529, 0.5529],
                                [0.8529, 0.9529, 1.0000]])
    second_plane = torch.tensor([[0.2660, 0.3660, 0.4660],
                                 [0.7660, 0.6660, 0.5660],
                                 [0.8660, 0.9660, 1.0000]])
    expected = torch.stack([first_plane.repeat(3, 1, 1), second_plane.repeat(3, 1, 1)])  # 2 x 3 x 3 x 3

    torch.manual_seed(42)
    jitter = ColorJitter(brightness=(-0.2, 0.2))
    assert_allclose(jitter(batch), expected)
def smoke_test(self):
    """Check the string representation of a configured ColorJitter."""
    # NOTE(review): with pytest's default configuration only ``test_*``
    # functions are collected, so this check never runs automatically.
    # Consider renaming it to ``test_smoke``.
    jitter = ColorJitter(brightness=0.5, contrast=0.3, saturation=[0.2, 1.2], hue=0.1)
    expected_repr = "ColorJitter(brightness=0.5, contrast=0.3, saturation=[0.2, 1.2], hue=0.1, return_transform=False)"
    assert str(jitter) == expected_repr
def color_model(gaussian_p, solarize_p):
    """Build the crop/flip/colour augmentation pipeline as an ``nn.Sequential``.

    Args:
        gaussian_p: probability handed to ``RandomApply`` around the Gaussian blur.
        solarize_p: probability handed to ``RandomSolarize``.

    Returns:
        ``nn.Sequential`` applying, in order: resized crop, horizontal flip,
        colour jitter, grayscale, optional blur and solarize.
    """
    # NOTE(review): the crop size and blur kernel come from the module-level
    # ``image_size``, not from a parameter.
    stages = [
        RandomResizedCrop((image_size, image_size), interpolation="BICUBIC"),
        RandomHorizontalFlip(),
        ColorJitter(0.4, 0.4, 0.2, 0.1, p=0.8),
        RandomGrayscale(p=0.2),
        RandomApply(GaussianBlur2d(get_kernel_size(image_size), (0.1, 2.0)), p=gaussian_p),
        RandomSolarize(0, 0, p=solarize_p),
    ]
    return nn.Sequential(*stages)
def __init__(self, input_shape, s=1.0, apply_transforms=None):
    """Record the input geometry and build the default augmentation pipeline.

    Args:
        input_shape: ``(H, W, C)`` triple describing the images to augment.
        s: strength multiplier applied to the colour-jitter factors.
        apply_transforms: pre-built transform pipeline. When ``None``, a default
            pipeline is built: resized crop, horizontal flip, colour jitter,
            grayscale, then Gaussian blur.
    """
    assert len(input_shape) == 3, "input_shape should be (H, W, C)"
    self.input_shape = input_shape
    self.H, self.W, self.C = input_shape[0], input_shape[1], input_shape[2]
    self.s = s
    self.apply_transforms = apply_transforms
    if self.apply_transforms is None:
        # The blur kernel spans roughly 10% of the image height. Gaussian kernels
        # must have an odd, positive size (Kornia rejects even sizes), so an even
        # value is bumped to the next odd one. This also turns 0 (H < 10) into 1.
        kernel_size = int(0.1 * self.H)
        if kernel_size % 2 == 0:
            kernel_size += 1
        sigma = self._get_sigma()
        self.apply_transforms = KorniaCompose([
            RandomResizedCrop(size=(self.H, self.W), scale=(0.08, 1.0)),
            RandomHorizontalFlip(p=0.5),
            ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s),
            RandomGrayscale(p=0.2),
            GaussianBlur2d(kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma)),
        ])
def color_jitter(self, hue: float = 0.0, p: float = 1.0) -> TransformType:
    """Return a ColorJitter with brightness, contrast and saturation fixed at 0.4.

    Args:
        hue: hue jitter strength forwarded to ``ColorJitter``.
        p: application probability forwarded to ``ColorJitter``.
    """
    strength = 0.4
    return ColorJitter(
        brightness=strength,
        contrast=strength,
        saturation=strength,
        hue=hue,
        p=p,
    )