Exemple #1
0
    def test_ra(self):
        """Apply one RandAugment operation to a random CIFAR10 test image and show both.

        With a non-zero magnitude the augmented image must differ from the source
        in at least one value; with a zero magnitude it must be identical.
        """
        data_root = osp.join('../..', 'data', 'CIFAR10')
        # transform=None: samples come back untransformed (expected to be PIL images).
        dataset = CIFAR10(root=data_root, train=False, download=True, transform=None)

        sample_idx = torch.randint(low=0, high=len(dataset), size=()).item()  # 908
        img, label = dataset[sample_idx]

        print(f'Image index: {sample_idx}.')
        print(f'Label: {label}.')

        # Restrict the pool to a single operation so exactly that op is exercised.
        magnitude = 1.0
        ra = RandAugment(
            n_augm_apply=1,
            magnitude=magnitude,
            augm_pool=RAND_AUGMENT_DEFAULT_POOL[2:3],
            magnitude_policy='constant',
        )
        img_ra = ra(img)

        img_tens, img_ra_tens = (ToTensor()(picture) for picture in (img, img_ra))

        print(img_tens[0:5, 0, 0])
        print(img_ra_tens[0:5, 0, 0])

        print('Img    : ', img_tens.shape)
        print('Img RA : ', img_ra_tens.shape)

        if magnitude != 0:
            # Any non-zero magnitude is expected to change at least one value.
            self.assertTrue(img_tens.ne(img_ra_tens).any())
        else:
            # Zero magnitude must leave the image untouched.
            self.assertTrue(img_tens.eq(img_ra_tens).all())

        # One figure per image (original first, then augmented); non-blocking display.
        for picture in (img, img_ra):
            plt.figure()
            plt.imshow(picture)

        plt.show(block=False)
Exemple #2
0
    def __init__(self, pil_transform: ImageTransform, p: float = 1.0):
        """
        Wrap a tensor transform so it can be applied to PIL images.

        The PIL image is converted to a tensor (``ToTensor()``) before the wrapped
        transform runs, and converted back to a PIL image (``ToPIL(mode=None)``)
        afterwards. The intermediate tensors are channel-first, i.e. shaped
        (channels, height, width), which is the layout ``ToTensor`` produces for PIL
        images (same as torchvision's ``ToTensor``).

        :param pil_transform: The tensor transform to wrap. Despite the parameter
            name, this transform receives and returns tensors.
        :param p: The probability to apply the transform. (default: 1.0)
        """
        super().__init__(
            transform=pil_transform,
            pre_convert=ToTensor(),
            post_convert=ToPIL(mode=None),
            p=p,
        )
Exemple #3
0
    def __init__(self,
                 pil_transform: ImageTransform,
                 mode: Optional[str] = 'RGB',
                 p: float = 1.0):
        """
        Wrap a PIL transform so it can be applied to tensor images.

        The tensor is converted to a PIL image (``ToPIL(mode=mode)``) before the
        wrapped transform runs, and converted back to a tensor (``ToTensor()``)
        afterwards. Input tensors are expected to be channel-first, i.e. shaped
        (channels, height, width), matching what ``ToTensor`` produces.

        :param pil_transform: The PIL transform to wrap.
        :param mode: The PIL image mode used when converting the tensor to a PIL
            image. (default: 'RGB')
        :param p: The probability to apply the transform. (default: 1.0)
        """
        super().__init__(
            transform=pil_transform,
            pre_convert=ToPIL(mode=mode),
            post_convert=ToTensor(),
            p=p,
        )
Exemple #4
0
    def test_to_tensor(self):
        """The project's ToTensor must match torchvision's ToTensor on a PIL image."""
        ours = ToTensor()
        reference = torchvision.transforms.ToTensor()

        # Black RGB image; PIL sizes are (width, height), so this is 32 wide, 64 tall.
        image = Image.new('RGB', (32, 64), color='black')

        ours_out = ours(image)
        reference_out = reference(image)

        conversion = f'{ours.__class__.__name__}(x) == {reference.__class__.__name__}(x)'
        self.assertEqual(
            ours_out.shape,
            reference_out.shape,
            f'Mismatch shapes for conversion {conversion}',
        )
        self.assertTrue(
            ours_out.eq(reference_out).all(),
            f'Mismatch values for conversion {conversion}',
        )
Exemple #5
0
    def test_pil_conversions(self):
        """Round-trip a PIL image through every converter and back with ToPIL."""
        converters = [ToTensor(), ToNumpy(), ToList(), ToPIL()]
        back_to_pil = converters[-1]  # the same ToPIL instance is also tested as a source

        # Black RGB image, 32 wide and 64 tall (PIL sizes are (width, height)).
        data = Image.new('RGB', (32, 64), color='black')
        base_name = back_to_pil.__class__.__name__

        for conv in converters:
            round_trip = back_to_pil(conv(data))
            label = f'{base_name}({conv.__class__.__name__}(x)) == x'

            self.assertEqual(
                data.size,
                round_trip.size,
                f'Mismatch shapes for conversion {label}',
            )
            self.assertTrue(
                data == round_trip,
                f'Mismatch values for conversion {label}',
            )
Exemple #6
0
    def test_numpy_conversions(self):
        """Round-trip a NumPy image through every converter and back with ToNumpy."""
        # All-zero (black) array of shape (32, 64, 3).
        data = np.zeros((32, 64, 3))

        converters = [ToTensor(), ToNumpy(), ToList(), ToPIL()]
        back_to_numpy = converters[1]  # the same ToNumpy instance is also tested as a source
        base_name = back_to_numpy.__class__.__name__

        for conv in converters:
            round_trip = back_to_numpy(conv(data))
            label = f'{base_name}({conv.__class__.__name__}(x)) == x'

            self.assertEqual(
                data.shape,
                round_trip.shape,
                f'Mismatch shapes for conversion {label}',
            )
            same_values = (data == round_trip).all()
            self.assertTrue(same_values, f'Mismatch values for conversion {label}')
Exemple #7
0
    def test_tensor_conversions(self):
        """Round-trip a tensor image through every converter and back with ToTensor."""
        # Black image laid out as (channel, height, width).
        data = torch.zeros(3, 64, 32)

        # ToTensor targets the source tensor's device so the final comparison is valid.
        converters = [ToTensor(device=data.device), ToNumpy(), ToList(), ToPIL()]
        back_to_tensor = converters[0]  # the same ToTensor instance is also tested as a source
        base_name = back_to_tensor.__class__.__name__

        for conv in converters:
            round_trip = back_to_tensor(conv(data))
            label = f'{base_name}({conv.__class__.__name__}(x)) == x'

            self.assertEqual(
                data.shape,
                round_trip.shape,
                f'Mismatch shapes for conversion {label}',
            )
            self.assertTrue(
                data.eq(round_trip).all(),
                f'Mismatch values for conversion {label}',
            )
Exemple #8
0
    def test_list_conversions(self):
        """Round-trip a nested-list image through every converter and back with ToList.

        Each converter's output is converted back with ``ToList`` and compared
        with the original nested list, first by shape and then by value.
        """
        # Black image as a 32 x 64 x 3 nested list of zeros (outer to inner).
        data = [[[0] * 3 for _ in range(64)] for _ in range(32)]
        # The reference shape does not depend on the converter: compute it once.
        data_shape = _get_list_shape(data)

        to_ten = ToTensor()
        to_num = ToNumpy()
        to_lis = ToList()
        to_pil = ToPIL()

        to_base = to_lis

        for to in [to_ten, to_num, to_lis, to_pil]:
            other = to_base(to(data))
            other_shape = _get_list_shape(other)

            self.assertEqual(
                data_shape, other_shape,
                f'Mismatch shapes for conversion {to_base.__class__.__name__}({to.__class__.__name__}(x)) == x'
            )
            self.assertTrue(
                data == other,
                f'Mismatch values for conversion {to_base.__class__.__name__}({to.__class__.__name__}(x)) == x'
            )