def test_ra(self):
    """Apply a single RandAugment operation to a random CIFAR10 test image and check it.

    A random sample is taken from the CIFAR10 test split, located at
    ``../../data/CIFAR10`` relative to the working directory and loaded with
    ``download=True``. A ``RandAugment`` restricted to ``RAND_AUGMENT_DEFAULT_POOL[2:3]``
    is applied once, with a constant magnitude.

    Assertions:
        - magnitude 0: the tensors before and after augmentation are equal everywhere;
        - otherwise: at least one tensor value differs.

    Both images are then plotted (non-blocking) for visual inspection.
    """
    cifar_root = osp.join('../..', 'data', 'CIFAR10')
    dataset = CIFAR10(root=cifar_root, train=False, download=True, transform=None)

    # The drawn index is printed so that a failing run can be reproduced.
    n_samples = len(dataset)
    sample_idx = torch.randint(low=0, high=n_samples, size=()).item()  # 908
    img, label = dataset[sample_idx]
    print(f'Image index: {sample_idx}.')
    print(f'Label: {label}.')

    magnitude = 1.0
    rand_augment = RandAugment(
        n_augm_apply=1,
        magnitude=magnitude,
        augm_pool=RAND_AUGMENT_DEFAULT_POOL[2:3],
        magnitude_policy='constant',
    )
    img_ra = rand_augment(img)

    img_tens = ToTensor()(img)
    img_ra_tens = ToTensor()(img_ra)

    # Print a few values along dim 0 at pixel (0, 0) for a quick eyeball comparison.
    for tens in (img_tens, img_ra_tens):
        print(tens[0:5, 0, 0])
    print('Img : ', img_tens.shape)
    print('Img RA : ', img_ra_tens.shape)

    if magnitude == 0:
        # A zero-magnitude augmentation is expected to leave the image untouched.
        unchanged = img_tens.eq(img_ra_tens).all()
        self.assertTrue(unchanged)
    else:
        # A non-zero magnitude is expected to modify at least one value.
        changed = img_tens.ne(img_ra_tens).any()
        self.assertTrue(changed)

    for shown in (img, img_ra):
        plt.figure()
        plt.imshow(shown)
    plt.show(block=False)
def __init__(self, pil_transform: ImageTransform, p: float = 1.0):
    """Wrap a tensor transform so that it can be applied to PIL images.

    The input PIL image is converted to a tensor with ``ToTensor()``. The wrapped
    transform is then applied, and the result is converted back to a PIL image with
    ``ToPIL(mode=None)``.

    The intermediate tensor images use the layout produced by ``ToTensor``:
    (channels, height, width), the same as torchvision's ``ToTensor``.

    :param pil_transform: The tensor transform to wrap. Despite its name, it is applied
        to tensors; the name is kept for backward compatibility with callers.
    :param p: The probability to apply the transform. (default: 1.0)
    """
    super().__init__(
        transform=pil_transform,
        pre_convert=ToTensor(),
        post_convert=ToPIL(mode=None),
        p=p,
    )
def __init__(self, pil_transform: ImageTransform, mode: Optional[str] = 'RGB', p: float = 1.0):
    """Wrap a PIL transform so that it can be applied to tensor images.

    The input tensor is converted to a PIL image with ``ToPIL(mode=mode)``. The wrapped
    transform is then applied, and the result is converted back to a tensor with
    ``ToTensor()``.

    Input tensors should use the (channels, height, width) layout that ``ToTensor``
    produces (the same layout as torchvision's ``ToTensor``). Output tensors use that
    same layout.

    :param pil_transform: The PIL transform to wrap.
    :param mode: The PIL image mode forwarded to ``ToPIL`` when converting the input
        tensor to a PIL image. (default: 'RGB')
    :param p: The probability to apply the transform. (default: 1.0)
    """
    super().__init__(
        transform=pil_transform,
        pre_convert=ToPIL(mode=mode),
        post_convert=ToTensor(),
        p=p,
    )
def test_to_tensor(self):
    """Check that the project ``ToTensor`` matches ``torchvision.transforms.ToTensor``.

    Both converters are applied to the same black RGB PIL image. The resulting tensors
    must have identical shapes and identical values.
    """
    ours = ToTensor()
    reference = torchvision.transforms.ToTensor()

    # Black RGB PIL image, 32 pixels wide and 64 pixels high (PIL size is (width, height)).
    black_img = Image.new('RGB', (32, 64), color='black')
    ours_out = ours(black_img)
    reference_out = reference(black_img)

    conversion = f'{ours.__class__.__name__}(x) == {reference.__class__.__name__}(x)'
    self.assertEqual(
        ours_out.shape,
        reference_out.shape,
        f'Mismatch shapes for conversion {conversion}',
    )
    values_match = ours_out.eq(reference_out).all()
    self.assertTrue(values_match, f'Mismatch values for conversion {conversion}')
def test_pil_conversions(self):
    """Round-trip a PIL image through every converter and back with ``ToPIL``.

    For each converter ``to``, ``ToPIL()(to(x))`` must return an image with the same
    size as ``x`` that compares equal to ``x``.
    """
    to_ten = ToTensor()
    to_num = ToNumpy()
    to_lis = ToList()
    to_pil = ToPIL()

    # Black RGB image, 32 pixels wide and 64 pixels high (PIL size is (width, height)).
    image = Image.new('RGB', (32, 64), color='black')

    to_base = to_pil
    base_name = to_base.__class__.__name__
    for converter in (to_ten, to_num, to_lis, to_pil):
        roundtrip = to_base(converter(image))
        conversion = f'{base_name}({converter.__class__.__name__}(x)) == x'
        self.assertEqual(image.size, roundtrip.size, f'Mismatch shapes for conversion {conversion}')
        self.assertTrue(image == roundtrip, f'Mismatch values for conversion {conversion}')
def test_numpy_conversions(self):
    """Round-trip a numpy array through every converter and back with ``ToNumpy``.

    For each converter ``to``, ``ToNumpy()(to(x))`` must reproduce ``x`` exactly:
    same shape and element-wise equal values.
    """
    # All-zero (black) image of shape (32, 64, 3); axes labelled (width, height, channel)
    # by the original author.
    zeros_img = np.zeros((32, 64, 3))

    to_ten = ToTensor()
    to_num = ToNumpy()
    to_lis = ToList()
    to_pil = ToPIL()

    to_base = to_num
    base_name = to_base.__class__.__name__
    for converter in (to_ten, to_num, to_lis, to_pil):
        roundtrip = to_base(converter(zeros_img))
        conversion = f'{base_name}({converter.__class__.__name__}(x)) == x'
        self.assertEqual(zeros_img.shape, roundtrip.shape, f'Mismatch shapes for conversion {conversion}')
        all_equal = (zeros_img == roundtrip).all()
        self.assertTrue(all_equal, f'Mismatch values for conversion {conversion}')
def test_tensor_conversions(self):
    """Round-trip a tensor through every converter and back with ``ToTensor``.

    For each converter ``to``, ``ToTensor(device=x.device)(to(x))`` must reproduce ``x``:
    same shape and element-wise equal values.
    """
    # All-zero (black) image laid out as (channel, height, width).
    zeros_img = torch.zeros(3, 64, 32)

    to_ten = ToTensor(device=zeros_img.device)
    to_num = ToNumpy()
    to_lis = ToList()
    to_pil = ToPIL()

    to_base = to_ten
    base_name = to_base.__class__.__name__
    for converter in (to_ten, to_num, to_lis, to_pil):
        roundtrip = to_base(converter(zeros_img))
        conversion = f'{base_name}({converter.__class__.__name__}(x)) == x'
        self.assertEqual(zeros_img.shape, roundtrip.shape, f'Mismatch shapes for conversion {conversion}')
        all_equal = zeros_img.eq(roundtrip).all()
        self.assertTrue(all_equal, f'Mismatch values for conversion {conversion}')
def test_list_conversions(self):
    """Round-trip a nested-list image through every converter and back with ``ToList``.

    For each converter ``to``, ``ToList()(to(x))`` must reproduce the nested list ``x``.
    The nested shape reported by ``_get_list_shape`` must match, and the contents must
    compare equal.
    """
    # All-zero (black) image as nested lists of shape 32 x 64 x 3; axes labelled
    # (width, height, channel) by the original author.
    data = [[[0] * 3 for _ in range(64)] for _ in range(32)]

    to_ten = ToTensor()
    to_num = ToNumpy()
    to_lis = ToList()
    to_pil = ToPIL()

    to_base = to_lis
    # The reference shape does not depend on the converter under test: compute it once
    # instead of on every loop iteration.
    data_shape = _get_list_shape(data)
    for to in [to_ten, to_num, to_lis, to_pil]:
        other = to_base(to(data))
        other_shape = _get_list_shape(other)
        self.assertEqual(
            data_shape,
            other_shape,
            f'Mismatch shapes for conversion {to_base.__class__.__name__}({to.__class__.__name__}(x)) == x'
        )
        self.assertTrue(
            data == other,
            f'Mismatch values for conversion {to_base.__class__.__name__}({to.__class__.__name__}(x)) == x'
        )