コード例 #1
0
 def __init__(self, viz: bool = False):
     """Build the augmentation pipeline and the denormalisation transform.

     Args:
         viz: Flag stored as ``self.viz``; it is not used anywhere else
             inside this constructor.
     """
     super().__init__()
     self.viz = viz
     # Disabled alternative, kept for reference. It used to live in a bare
     # triple-quoted string, which is an evaluated expression statement
     # rather than a comment:
     # self.geometric = [
     #     K.augmentation.RandomAffine(60., p=0.75),
     # ]
     # Transforms are applied in the order listed below.
     self.augmentations = nn.Sequential(
         augmentation.RandomRotation(degrees=30.),
         augmentation.RandomPerspective(distortion_scale=0.4),
         augmentation.RandomResizedCrop((224, 224)),
         augmentation.RandomHorizontalFlip(p=0.5),
         augmentation.RandomVerticalFlip(p=0.5),
         # K.augmentation.GaussianBlur((3, 3), (0.1, 2.0), p=1.0),
         # K.augmentation.ColorJitter(0.01, 0.01, 0.01, 0.01, p=0.25),
     )
     # Denormalize transform parameterised by the module-level
     # DATASET_IMAGE_MEAN / DATASET_IMAGE_STD constants.
     self.denorm = augmentation.Denormalize(Tensor(DATASET_IMAGE_MEAN), Tensor(DATASET_IMAGE_STD))
コード例 #2
0
    def __init__(self,
                 brightness=(0.75, 1.25),
                 contrast=(0.75, 1.25),
                 saturation=(0., 2.),
                 translate=(0.125, 0.125),
                 normalized=True,
                 mean=0.5,
                 std=0.5,
                 device=None):
        """Create the normalisation transforms and the augmentation table.

        Args:
            brightness: Forwarded to ``aug.ColorJitter(brightness=...)``.
            contrast: Forwarded to ``aug.ColorJitter(contrast=...)``.
            saturation: Forwarded to ``aug.ColorJitter(saturation=...)``.
            translate: Forwarded to ``aug.RandomAffine(translate=...)``.
            normalized: When true, ``self.normalize`` / ``self.denormalize``
                are built from ``mean`` and ``std``; otherwise both are
                set to ``None``.
            mean: Scalar, or a tuple/list of values.
            std: Scalar, or a tuple/list of values.
            device: Device the ``mean``/``std`` tensors are moved to. Only
                required when both ``mean`` and ``std`` are tuples/lists.

        Raises:
            ValueError: If ``mean`` and ``std`` are both sequences and no
                ``device`` was given.
        """
        if normalized:
            per_channel_stats = (isinstance(mean, (tuple, list))
                                 and isinstance(std, (tuple, list)))
            if per_channel_stats:
                # Test identity with None instead of truthiness so that a
                # falsy but valid device spec (e.g. CUDA index 0) is accepted.
                if device is None:
                    raise ValueError(
                        'Please specify a torch.device() object when using mean and std for each channels'
                    )
                mean = torch.Tensor(mean).to(device)
                std = torch.Tensor(std).to(device)
            self.normalize = aug.Normalize(mean, std)
            self.denormalize = aug.Denormalize(mean, std)
        else:
            self.normalize, self.denormalize = None, None

        color_jitter = aug.ColorJitter(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            p=1.)  # rand_brightness, rand_contrast, rand_saturation
        affine = aug.RandomAffine(degrees=0,
                                  translate=translate,
                                  padding_mode=SamplePadding.BORDER,
                                  p=1.)  # rand_translate
        cutout = aug.RandomErasing(value=0.5, p=1.)  # rand_cutout

        # Augmentations are looked up by name through this mapping.
        self.augmentations = {
            'color': color_jitter,
            'translation': affine,
            'cutout': cutout
        }
コード例 #3
0
ファイル: test_container.py プロジェクト: shijianjian/kornia
class TestVideoSequential:
    """Tests for ``K.VideoSequential`` with both supported data formats."""

    @staticmethod
    def _repeated_frame_video(data_format, device, dtype):
        """Return a batch of two videos made of one random frame repeated 4 times.

        The frame axis is dim 2 for ``"BCTHW"`` and dim 1 for ``"BTCHW"``.

        Raises:
            ValueError: If ``data_format`` is neither of the two layouts.
        """
        if data_format == 'BCTHW':
            frame = torch.randn(2, 3, 1, 5, 6, device=device, dtype=dtype)
            return frame.repeat(1, 1, 4, 1, 1)
        if data_format == 'BTCHW':
            frame = torch.randn(2, 1, 3, 5, 6, device=device, dtype=dtype)
            return frame.repeat(1, 4, 1, 1, 1)
        raise ValueError(f"Unsupported data_format: {data_format!r}")

    @pytest.mark.parametrize('shape', [(3, 4), (2, 3, 4), (2, 3, 5, 6),
                                       (2, 3, 4, 5, 6, 7)])
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_exception(self, shape, data_format, device, dtype):
        """Inputs with 2, 3, 4 or 6 dimensions must raise ``AssertionError``."""
        aug_list = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                                     data_format=data_format,
                                     same_on_frame=True)
        # Build the tensor outside the ``raises`` block so that only the
        # forward call can satisfy the expectation.
        img = torch.randn(*shape, device=device, dtype=dtype)
        with pytest.raises(AssertionError):
            aug_list(img)

    @pytest.mark.parametrize(
        'augmentation',
        [
            K.RandomAffine(360, p=1.0),
            K.CenterCrop((3, 3), p=1.0),
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomCrop((5, 5), p=1.0),
            K.RandomErasing(p=1.0),
            K.RandomGrayscale(p=1.0),
            K.RandomHorizontalFlip(p=1.0),
            K.RandomVerticalFlip(p=1.0),
            K.RandomPerspective(p=1.0),
            K.RandomResizedCrop((5, 5), p=1.0),
            K.RandomRotation(360.0, p=1.0),
            K.RandomSolarize(p=1.0),
            K.RandomPosterize(p=1.0),
            K.RandomSharpness(p=1.0),
            K.RandomEqualize(p=1.0),
            K.RandomMotionBlur(3, 35.0, 0.5, p=1.0),
            K.Normalize(torch.tensor([0.5, 0.5, 0.5]),
                        torch.tensor([0.5, 0.5, 0.5]),
                        p=1.0),
            K.Denormalize(torch.tensor([0.5, 0.5, 0.5]),
                          torch.tensor([0.5, 0.5, 0.5]),
                          p=1.0),
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_augmentation(self, augmentation, data_format, device, dtype):
        """Run ``reproducibility_test`` on each augmentation wrapped on its own."""
        # One random clip duplicated across the batch, scaled by 1/255.
        video = torch.randint(255, (1, 3, 3, 5, 6), device=device,
                              dtype=dtype).repeat(2, 1, 1, 1, 1) / 255.0
        torch.manual_seed(21)
        aug_list = K.VideoSequential(augmentation,
                                     data_format=data_format,
                                     same_on_frame=True)
        reproducibility_test(video, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0)
            ],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)
            ],
            [K.RandomAffine(360, p=1.0),
             kornia.color.BgrToRgb()],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0),
                K.RandomAffine(360, p=0.0)
            ],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0)],
            [K.RandomAffine(360, p=0.0)],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0),
                K.RandomMixUp(p=1.0)
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    @pytest.mark.parametrize('random_apply',
                             [1, (1, 1), (1, ), 10, True, False])
    def test_same_on_frame(self, augmentations, data_format, random_apply,
                           device, dtype):
        """With ``same_on_frame=True`` identical input frames stay identical."""
        aug_list = K.VideoSequential(*augmentations,
                                     data_format=data_format,
                                     same_on_frame=True,
                                     random_apply=random_apply)

        video = self._repeated_frame_video(data_format, device, dtype)
        # Frame axis: dim 2 for "BCTHW", dim 1 for "BTCHW".
        frame_dim = 2 if data_format == 'BCTHW' else 1
        output = aug_list(video)
        if aug_list.return_label:
            # The container returned an (output, label) pair; keep the output.
            output, _ = output
        # Compare each of the four frames with its successor.
        for idx in range(3):
            current_frame = output.select(frame_dim, idx)
            next_frame = output.select(frame_dim, idx + 1)
            assert (current_frame == next_frame).all()
        reproducibility_test(video, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [K.RandomAffine(360, p=1.0)],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)],
            [
                K.RandomAffine(360, p=0.0),
                K.ImageSequential(K.RandomAffine(360, p=0.0))
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_against_sequential(self, augmentations, data_format, device,
                                dtype):
        """Per-frame ``VideoSequential`` must match ``nn.Sequential`` on flattened frames."""
        aug_list_1 = K.VideoSequential(*augmentations,
                                       data_format=data_format,
                                       same_on_frame=False)
        aug_list_2 = torch.nn.Sequential(*augmentations)

        video = self._repeated_frame_video(data_format, device, dtype)

        torch.manual_seed(0)
        output_1 = aug_list_1(video)

        # Same seed for the reference path so both draw the same parameters.
        torch.manual_seed(0)
        # Bring the frame axis next to the batch axis before flattening.
        frames = video.transpose(1, 2) if data_format == 'BCTHW' else video
        output_2 = aug_list_2(frames.reshape(-1, 3, 5, 6))
        output_2 = output_2.view(2, 4, 3, 5, 6)
        if data_format == 'BCTHW':
            output_2 = output_2.transpose(1, 2)
        assert (output_1 == output_2).all(), dict(aug_list_1._params)

    @pytest.mark.jit
    @pytest.mark.skip(reason="turn off due to Union Type")
    def test_jit(self, device, dtype):
        """Scripted and eager containers must agree (currently skipped)."""
        B, C, D, H, W = 2, 3, 5, 4, 4
        img = torch.ones(B, C, D, H, W, device=device, dtype=dtype)
        op = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                               same_on_frame=True)
        op_jit = torch.jit.script(op)
        assert_close(op(img), op_jit(img))