예제 #1
0
    def test_invert(self):
        """RandomApply with p=1 always applies and must stay invertible."""
        transform = T.RandomApply(
            transforms=[T.Identity(), T.ToTensor()],
            p=1,
        )

        # forward pass converts the PIL image to a tensor ...
        out = transform(self.img_pil)
        self.assertIsInstance(out, torch.Tensor)
        # ... and invert() recovers a PIL image again
        self.assertIsInstance(transform.invert(out), Image.Image)
예제 #2
0
 def test_replay(self):
     """Replaying a composed random pipeline twice gives identical outputs."""
     pipeline = T.Compose([
         T.RandomCrop(self.crop_size),
         T.RandomVerticalFlip(),
         T.RandomHorizontalFlip(),
         T.ToTensor(),
     ])
     # replay() should reuse the random parameters recorded on the first call
     first = pipeline.replay(self.img_pil)
     second = pipeline.replay(self.img_pil)
     self.assertTrue(torch.allclose(first, second))
    def setUp(self) -> None:
        """Build a reproducible PIL/tensor image pair shared by the tests."""
        self.img_size = (256, 320)  # (height, width)
        self.h, self.w = self.img_size
        self.crop_size = (64, 128)
        # random single-channel image, clamped into the valid [0, 1] range
        self.img_tensor = torch.randn((1, ) + self.img_size).clamp(0, 1)

        self.img_pil = T.ToPILImage()(self.img_tensor)
        # round-trip through PIL so tensor and image hold identical data
        self.img_tensor = T.ToTensor()(self.img_pil)

        # BUG FIX: random.randint requires integer bounds; the float literal
        # 1e9 was deprecated in Python 3.10 and raises TypeError on >= 3.12.
        self.n = random.randint(0, 10**9)
예제 #4
0
 def __randaug_mnist(self, img_pil, to_tensor, norm):
     """Run the MNIST random-augmentation pipeline on *img_pil*.

     Tensor conversion happens only when *to_tensor* is true, and
     normalization only when both *to_tensor* and *norm* are true.
     """
     # NOTE(review): `resample` is the legacy Pillow-constant keyword;
     # newer torchvision RandomAffine takes `interpolation` — confirm
     # the pinned library versions before upgrading.
     steps = [
         T.RandomAffine(
             degrees=10,
             translate=(0.17, 0.17),
             scale=(0.85, 1.05),
             shear=(-10, 10, -10, 10),
             resample=PIL.Image.BILINEAR,
         ),
         T.ColorJitter(0.5, 0.5, 0.5, 0.25),
         T.TransformIf(T.ToTensor(), to_tensor),
         T.TransformIf(T.Normalize(mean=(0.1307, ), std=(0.3081, )),
                       to_tensor and norm),
     ]
     return T.Compose(steps)(img_pil)
예제 #5
0
    def test_track(self):
        """Tracked random transforms must be individually invertible.

        After tf.track(...), indexing tf by run should give the inverse of
        that run; the image center (always inside the crop) must round-trip.
        """
        pipeline = T.Compose([
            T.ToPILImage(),
            T.RandomVerticalFlip(p=0.5),
            # the crop will include the center pixels
            T.RandomCrop(size=tuple(
                int(0.8 * self.img_size[dim]) for dim in range(2))),
            T.RandomHorizontalFlip(p=0.5),
            T.ToTensor(),
        ])

        outputs = [pipeline.track(self.img_tensor) for _ in range(10)]

        # slice selecting a window around the image center (loop-invariant)
        margin = min(self.img_size) // 10
        center_pixels = (0, ) + tuple(
            slice(self.img_size[dim] // 2 - margin,
                  self.img_size[dim] // 2 + margin)
            for dim in range(2))

        for run, tracked in enumerate(outputs):
            self.assertTrue(
                torch.allclose(pipeline[run](tracked)[center_pixels],
                               self.img_tensor[center_pixels]))
예제 #6
0
    def __randaug_imagenet(self, img_pil, to_tensor, norm):
        """Apply a randomized ImageNet augmentation policy to ``img_pil``.

        Args:
            img_pil: input PIL image; converted to RGB before the optional
                tensor conversion.
            to_tensor: when True, append a ToTensor step.
            norm: when True (together with ``to_tensor``), also normalize
                with the standard ImageNet mean/std.

        Returns:
            The augmented image — a tensor when ``to_tensor`` is True,
            otherwise a PIL image.
        """
        # TODO: turn this code less cryptic...

        # Base flips/rotation plus one randomly drawn policy; every entry is
        # a (name, probability, level) triple.  NOTE(review): assumes
        # self._all_policies holds lists of such triples — confirm.
        policy = [('FlipLR', 0.5, random.randint(1, 9)),
                  ('FlipUD', 0.5, random.randint(1, 9)),
                  ('Rotate', 0.5, random.randint(1, 9))] \
                 + random.choice(self._all_policies)

        # PIL size is (width, height): reverse to (height, width) and assume
        # a 3-channel image.
        img_shape = img_pil.size[::-1] + (3, )

        # Instantiate and apply each policy step in order.
        for xform in policy:
            assert len(xform) == 3
            name, probability, level = xform
            xform_fn = NAME_TO_TRANSFORM[name].pil_transformer(
                probability, level, img_shape)
            img_pil = xform_fn(img_pil)

        return T.Compose([
            T.Lambda(lambda img: img.convert('RGB')),
            T.TransformIf(T.ToTensor(), to_tensor),
            T.TransformIf(
                T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]), to_tensor and norm),
        ])(img_pil)
예제 #7
0
 def test_invert(self):
     """A RandomChoice over ToTensor instances must invert to ToPILImage."""
     transform = T.RandomChoice([T.ToTensor()] * 10)
     result = transform(self.img_pil)
     self.assertIsInstance(result, torch.Tensor)
     self.assertIsInstance(transform.inverse(), T.ToPILImage)
예제 #8
0
 def test_invert(self):
     """The inverse of a conditional ToTensor is a ToPILImage transform."""
     transform = T.TransformIf(transform=T.ToTensor(), condition=True)
     inverse = transform.inverse()
     self.assertIsInstance(inverse, T.ToPILImage)
예제 #9
0
def pil_unwrap2(img_pil, mean_std=None):
    """Convert a PIL image to an RGB float tensor, optionally normalized.

    Args:
        img_pil: input PIL image.
        mean_std: optional sequence unpacked into ``T.Normalize`` (typically
            ``(mean, std)``); when None, no normalization is applied.

    Returns:
        The converted (and possibly normalized) tensor.
    """
    tensor = T.ToTensor()(img_pil.convert('RGB'))
    if mean_std is None:
        return tensor
    return T.Normalize(*mean_std)(tensor)
예제 #10
0
def main(**kwargs):
    """Train and evaluate SimpLeNet on MNIST or CIFAR-10.

    Expects at least these ``kwargs`` keys: ``dataset_name`` ('mnist' or
    'cifar10'), ``val_percentage``, ``tng_batch_size``, ``val_batch_size``
    and ``log_dir``; all entries are also forwarded to the strategy and
    the executor.
    """
    ########################
    # [DATA] some datasets #
    ########################

    dataset_name = kwargs['dataset_name']

    # per-dataset normalization, applied conditionally after ToTensor
    tf = T.Compose([
        T.ToTensor(),
        T.TransformIf(T.Normalize(mean=[0.13], std=[0.31]),
                      dataset_name == 'mnist'),
        T.TransformIf(T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
                      dataset_name == 'cifar10'),
    ])

    # lazily build the (train, test) pair; unknown names raise KeyError
    dataset, tst_dataset = {
        'mnist':
        lambda: (
            MNIST(root='/data/', train=True, transform=tf, download=True),
            MNIST(root='/data/', train=False, transform=tf, download=True),
        ),
        'cifar10':
        lambda: (
            CIFAR10(root='/data/', train=True, transform=tf, download=True),
            CIFAR10(root='/data/', train=False, transform=tf, download=True),
        )
    }[dataset_name]()

    tng_dataset, val_dataset = split(dataset,
                                     percentage=kwargs['val_percentage'])
    # cap each training epoch at the validation-set size
    sampler = RandomSampler(tng_dataset,
                            num_samples=len(val_dataset),
                            replacement=False)
    tng_dataloader = DataLoader(dataset=tng_dataset,
                                batch_size=kwargs['tng_batch_size'],
                                sampler=sampler,
                                num_workers=4)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=kwargs['val_batch_size'],
                                shuffle=False,
                                num_workers=4)
    tst_dataloader = DataLoader(dataset=tst_dataset,
                                batch_size=kwargs['val_batch_size'],
                                shuffle=False,
                                num_workers=4)

    ###########################
    # [MODEL] a pytorch model #
    ###########################

    # infer the model's input size from a real (transformed) sample
    sample_input, _ = dataset[0]
    net = SimpLeNet(
        input_size=sample_input.size(),
        n_classes=10,
    )

    ########################################
    # [STRATEGY] it describes the training #
    ########################################

    exp_name = f'{net.__class__.__name__.lower()}/{dataset_name}'
    # NOTE: mutates the caller-visible kwargs dict in place
    kwargs['log_dir'] = Path(kwargs['log_dir']) / exp_name

    classifier = ClassifierStrategy(net=net, **kwargs)

    ##################################
    # [EXECUTOR] it handles the rest #
    ##################################

    executor = Executor(
        tng_dataloader=tng_dataloader,
        val_dataloader=val_dataloader,
        tst_dataloader=tst_dataloader,
        exp_name=exp_name,
        **kwargs,
    )

    executor.train_test(strategy=classifier, **kwargs)