Example 1
def kornia_list(MAGN: int = 4):
    """
    Return a standard list of Kornia transforms, each parameterized by magnitude `MAGN`.
    
    Args:
        MAGN (int): Magnitude of each transform in the returned list.
    """
    transform_list = [
        # SPATIAL
        K.RandomHorizontalFlip(p=1),
        K.RandomRotation(degrees=90., p=1),
        K.RandomAffine(degrees=MAGN * 5.,
                       shear=MAGN / 5,
                       translate=MAGN / 20,
                       p=1),
        K.RandomPerspective(distortion_scale=MAGN / 25, p=1),

        # PIXEL-LEVEL
        K.ColorJitter(brightness=MAGN / 30, p=1),  # brightness
        K.ColorJitter(saturation=MAGN / 30, p=1),  # saturation
        K.ColorJitter(contrast=MAGN / 30, p=1),  # contrast
        K.ColorJitter(hue=MAGN / 30, p=1),  # hue
        K.ColorJitter(p=0),  # identity
        K.RandomMotionBlur(kernel_size=2 * (MAGN // 3) + 1,
                           angle=MAGN,
                           direction=1.,
                           p=1),
        K.RandomErasing(scale=(MAGN / 100, MAGN / 50),
                        ratio=(MAGN / 20, MAGN),
                        p=1),
    ]
    return transform_list
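A minimal usage sketch (not part of the original example; assumes `import kornia.augmentation as K` as in the snippet): the returned transforms are ordinary `nn.Module`s, so they can be chained with `torch.nn.Sequential` or sampled RandAugment-style.

import random

import torch
import torch.nn as nn

transforms = kornia_list(MAGN=4)

# Apply every transform in order ...
full_pipeline = nn.Sequential(*transforms)

# ... or sample a random subset per step, RandAugment-style.
subset = nn.Sequential(*random.sample(transforms, k=2))

images = torch.rand(8, 3, 64, 64)  # (B, C, H, W), values in [0, 1]
augmented = subset(images)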
Example 2
    def test_random_erasing(self, device, dtype):
        fill_value = 0.5
        input = torch.randn(3, 3, 100, 100, device=device, dtype=dtype)
        aug = K.AugmentationSequential(
            K.RandomErasing(p=1., value=fill_value),
            data_keys=["input", "mask"],
        )

        reproducibility_test((input, input), aug)

        out = aug(input, input)
        assert torch.all(out[1][out[0] == fill_value] == 0.)
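Outside the test, the same `AugmentationSequential` pattern keeps an image and its mask aligned: the erasing box is sampled once and applied to both tensors (the mask is filled with 0 regardless of `value`, which is what the assertion above checks). A sketch, assuming a kornia version with `data_keys` support:

import torch
import kornia.augmentation as K

aug = K.AugmentationSequential(
    K.RandomErasing(p=1., value=0.5),
    data_keys=["input", "mask"],
)

image = torch.rand(1, 3, 100, 100)
mask = torch.ones(1, 1, 100, 100)

out_image, out_mask = aug(image, mask)  # same erasing box in both outputs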
Example 3
    def __init__(self,
                 brightness=(0.75, 1.25),
                 contrast=(0.75, 1.25),
                 saturation=(0., 2.),
                 translate=(0.125, 0.125),
                 normalized=True,
                 mean=0.5,
                 std=0.5,
                 device=None):
        if normalized:
            if isinstance(mean,
                          (tuple, list)) and isinstance(std, (tuple, list)):
                if device is None:
                    raise Exception(
                        'Please specify a torch.device() object when using per-channel mean and std'
                    )
                mean = torch.Tensor(mean).to(device)
                std = torch.Tensor(std).to(device)
            self.normalize = aug.Normalize(mean, std)
            self.denormalize = aug.Denormalize(mean, std)
        else:
            self.normalize, self.denormalize = None, None

        color_jitter = aug.ColorJitter(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            p=1.)  # rand_brightness, rand_contrast, rand_saturation
        affine = aug.RandomAffine(degrees=0,
                                  translate=translate,
                                  padding_mode=SamplePadding.BORDER,
                                  p=1.)  # rand_translate
        cutout = aug.RandomErasing(value=0.5, p=1.)  # rand_cutout

        self.augmentations = {
            'color': color_jitter,
            'translation': affine,
            'cutout': cutout
        }
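The class's apply method is not shown in this excerpt; a plausible sketch of how the pieces fit together (the method name and `policy` argument are assumptions): denormalize back into image space, run the selected augmentations, then renormalize.

    def apply(self, x, policy=('color', 'translation', 'cutout')):
        if self.denormalize is not None:
            x = self.denormalize(x)  # augment in [0, 1] image space
        for name in policy:
            x = self.augmentations[name](x)
        if self.normalize is not None:
            x = self.normalize(x)
        return x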
Example 4
		def __init__(self, cut_size, cutn, cut_pow=1.):
			super().__init__()
			self.cut_size = cut_size
			self.cutn = cutn
			self.cut_pow = cut_pow

			self.augs = nn.Sequential(
				# K.RandomHorizontalFlip(p=0.5),
				# K.RandomVerticalFlip(p=0.5),
				# K.RandomSolarize(0.01, 0.01, p=0.7),
				# K.RandomSharpness(0.3, p=0.4),
				# K.RandomResizedCrop(
				#	size=(self.cut_size, self.cut_size), 
				#	scale=(0.1, 1), ratio=(0.75, 1.333), 
				#	cropping_mode="resample", p=0.5
				# ),
				# K.RandomCrop(
				#	size=(self.cut_size, self.cut_size), p=0.5
				# ),
				K.RandomAffine(
					degrees=15, translate=0.1, p=0.7, 
					padding_mode="border"
				),
				K.RandomPerspective(0.7, p=0.7),
				K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
				K.RandomErasing(
					(.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7
				),
			)

			self.noise_fac = 0.1
			self.av_pool = nn.AdaptiveAvgPool2d(
				(self.cut_size, self.cut_size)
			)
			self.max_pool = nn.AdaptiveMaxPool2d(
				(self.cut_size, self.cut_size)
			)
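The `forward` method is omitted from this excerpt; in this family of cutout modules it typically blends the two pools into `cutn` cutouts, runs the augmentations, and adds per-cutout noise scaled by `noise_fac`. A sketch under those assumptions (module-level `torch` and `nn` imports assumed):

		def forward(self, input):
			cutouts = []
			for _ in range(self.cutn):
				# blend adaptive average- and max-pooling into one cutout
				cutout = (self.av_pool(input) + self.max_pool(input)) / 2
				cutouts.append(cutout)
			batch = self.augs(torch.cat(cutouts, dim=0))
			if self.noise_fac:
				facs = batch.new_empty([batch.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
				batch = batch + facs * torch.randn_like(batch)
			return batch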
Example 5
        args.n_mlp,
        channel_multiplier=args.channel_multiplier,
        constant_input=args.constant_input,
    ).to(device)
    g_ema.requires_grad_(False)
    g_ema.eval()
    accumulate(g_ema, generator, 0)

    augment_fn = nn.Sequential(
        nn.ReflectionPad2d(int((math.sqrt(2) - 1) * args.size / 4)),  # zoom out
        augs.RandomHorizontalFlip(),
        RandomApply(augs.RandomAffine(degrees=0, translate=(0.25, 0.25), shear=(15, 15)), p=0.2),
        RandomApply(augs.RandomRotation(180), p=0.2),
        augs.RandomResizedCrop(size=(args.size, args.size), scale=(1, 1), ratio=(1, 1)),
        RandomApply(augs.RandomResizedCrop(size=(args.size, args.size), scale=(0.5, 0.9)), p=0.1),  # zoom in
        RandomApply(augs.RandomErasing(), p=0.1),
    )
    contrast_learner = (
        ContrastiveLearner(discriminator, args.size, augment_fn=augment_fn, hidden_layer=(-1, 0))
        if args.contrastive > 0
        else None
    )

    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = th.optim.Adam(
        generator.parameters(), lr=args.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    d_optim = th.optim.Adam(
        discriminator.parameters(), lr=args.lr * d_reg_ratio, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )
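`RandomApply` above is a helper from the surrounding repository, not a kornia class; a minimal sketch of what it does, inferred from its usage here (gate a module behind a Bernoulli draw):

import random

import torch.nn as nn

class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p  # probability of applying `fn`

    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)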
Example 6
 def __init__(self, probability: float = 0.5):
     self._probability = probability
     self._operation = aug.RandomErasing(p=probability)
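Because `p=probability` is forwarded to `aug.RandomErasing`, the per-sample Bernoulli gating happens inside kornia rather than in the wrapper. A quick check (a sketch; the wrapper's class name and call method are not shown in the excerpt):

import torch
import kornia.augmentation as aug

op = aug.RandomErasing(p=0.5)
x = torch.rand(8, 3, 32, 32)
y = op(x)
changed = (x != y).flatten(1).any(dim=1)  # ~half the samples get an erased box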
Example 7
class TestVideoSequential:
    @pytest.mark.parametrize('shape', [(3, 4), (2, 3, 4), (2, 3, 5, 6),
                                       (2, 3, 4, 5, 6, 7)])
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_exception(self, shape, data_format, device, dtype):
        aug_list = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                                     data_format=data_format,
                                     same_on_frame=True)
        with pytest.raises(AssertionError):
            img = torch.randn(*shape, device=device, dtype=dtype)
            aug_list(img)

    @pytest.mark.parametrize(
        'augmentation',
        [
            K.RandomAffine(360, p=1.0),
            K.CenterCrop((3, 3), p=1.0),
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomCrop((5, 5), p=1.0),
            K.RandomErasing(p=1.0),
            K.RandomGrayscale(p=1.0),
            K.RandomHorizontalFlip(p=1.0),
            K.RandomVerticalFlip(p=1.0),
            K.RandomPerspective(p=1.0),
            K.RandomResizedCrop((5, 5), p=1.0),
            K.RandomRotation(360.0, p=1.0),
            K.RandomSolarize(p=1.0),
            K.RandomPosterize(p=1.0),
            K.RandomSharpness(p=1.0),
            K.RandomEqualize(p=1.0),
            K.RandomMotionBlur(3, 35.0, 0.5, p=1.0),
            K.Normalize(torch.tensor([0.5, 0.5, 0.5]),
                        torch.tensor([0.5, 0.5, 0.5]),
                        p=1.0),
            K.Denormalize(torch.tensor([0.5, 0.5, 0.5]),
                          torch.tensor([0.5, 0.5, 0.5]),
                          p=1.0),
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_augmentation(self, augmentation, data_format, device, dtype):
        input = torch.randint(255, (1, 3, 3, 5, 6), device=device,
                              dtype=dtype).repeat(2, 1, 1, 1, 1) / 255.0
        torch.manual_seed(21)
        aug_list = K.VideoSequential(augmentation,
                                     data_format=data_format,
                                     same_on_frame=True)
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0)
            ],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)
            ],
            [K.RandomAffine(360, p=1.0),
             kornia.color.BgrToRgb()],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0),
                K.RandomAffine(360, p=0.0)
            ],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0)],
            [K.RandomAffine(360, p=0.0)],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0),
                K.RandomMixUp(p=1.0)
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    @pytest.mark.parametrize('random_apply',
                             [1, (1, 1), (1, ), 10, True, False])
    def test_same_on_frame(self, augmentations, data_format, random_apply,
                           device, dtype):
        aug_list = K.VideoSequential(*augmentations,
                                     data_format=data_format,
                                     same_on_frame=True,
                                     random_apply=random_apply)

        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
            output = aug_list(input)
            if aug_list.return_label:
                output, _ = output
            assert (output[:, :, 0] == output[:, :, 1]).all()
            assert (output[:, :, 1] == output[:, :, 2]).all()
            assert (output[:, :, 2] == output[:, :, 3]).all()
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)
            output = aug_list(input)
            if aug_list.return_label:
                output, _ = output
            assert (output[:, 0] == output[:, 1]).all()
            assert (output[:, 1] == output[:, 2]).all()
            assert (output[:, 2] == output[:, 3]).all()
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [K.RandomAffine(360, p=1.0)],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)],
            [
                K.RandomAffine(360, p=0.0),
                K.ImageSequential(K.RandomAffine(360, p=0.0))
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_against_sequential(self, augmentations, data_format, device,
                                dtype):
        aug_list_1 = K.VideoSequential(*augmentations,
                                       data_format=data_format,
                                       same_on_frame=False)
        aug_list_2 = torch.nn.Sequential(*augmentations)

        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)

        torch.manual_seed(0)
        output_1 = aug_list_1(input)

        torch.manual_seed(0)
        if data_format == 'BCTHW':
            input = input.transpose(1, 2)
        output_2 = aug_list_2(input.reshape(-1, 3, 5, 6))
        output_2 = output_2.view(2, 4, 3, 5, 6)
        if data_format == 'BCTHW':
            output_2 = output_2.transpose(1, 2)
        assert (output_1 == output_2).all(), dict(aug_list_1._params)

    @pytest.mark.jit
    @pytest.mark.skip(reason="turn off due to Union Type")
    def test_jit(self, device, dtype):
        B, C, D, H, W = 2, 3, 5, 4, 4
        img = torch.ones(B, C, D, H, W, device=device, dtype=dtype)
        op = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                               same_on_frame=True)
        op_jit = torch.jit.script(op)
        assert_close(op(img), op_jit(img))
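Outside the test suite, basic `VideoSequential` usage looks like the sketch below: with `same_on_frame=True`, one set of parameters is sampled per clip and replayed on every frame.

import torch
import kornia.augmentation as K

video_aug = K.VideoSequential(
    K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
    K.RandomAffine(360, p=1.0),
    data_format="BCTHW",
    same_on_frame=True,
)

clip = torch.rand(2, 3, 4, 64, 64)  # (B, C, T, H, W): 2 clips of 4 frames
out = video_aug(clip)  # all 4 frames of a clip share the same parameters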
Example 8
def random_erase(data, opt):
    rec_er = K.RandomErasing(0.5, (.02, .4), (.3, 1 / .3))
    out = rec_er(data.view([-1] + list(data.shape[-3:])))

    return out.view(data.shape)
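The `view` trick folds every leading dimension into the batch axis so `RandomErasing` sees a standard `(N, C, H, W)` tensor, then restores the original shape; note the positional argument order (`p`, `scale`, `ratio`) matches older kornia releases. For example, with a 5D video batch (a sketch; `opt` is unused as shown):

import torch

video = torch.rand(2, 8, 3, 64, 64)  # (B, T, C, H, W)
erased = random_erase(video, opt=None)
assert erased.shape == video.shape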
Example 9
def main(argv=None):
    # CLI
    parser = argparse.ArgumentParser()
    parser.add_argument("name", help="Name of the experiment")
    parser.add_argument(
        "-a",
        "--augment",
        action="store_true",
        help="If True, we apply augmentations",
    )
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=16,
                        help="Batch size")
    parser.add_argument(
        "--b1",
        type=float,
        default=0.5,
        help="Adam optimizer hyperparamter",
    )
    parser.add_argument(
        "--b2",
        type=float,
        default=0.999,
        help="Adam optimizer hyperparamter",
    )
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to use",
    )
    parser.add_argument(
        "--eval-frequency",
        type=int,
        default=400,
        help="Generate generator images every `eval_frequency` epochs",
    )
    parser.add_argument(
        "--latent-dim",
        type=int,
        default=100,
        help="Dimensionality of the random noise",
    )
    parser.add_argument("--lr",
                        type=float,
                        default=0.0002,
                        help="Learning rate")
    parser.add_argument(
        "--ndf",
        type=int,
        default=32,
        help="Number of discriminator feature maps (after first convolution)",
    )
    parser.add_argument(
        "--ngf",
        type=int,
        default=32,
        help=
        "Number of generator feature maps (before last transposed convolution)",
    )
    parser.add_argument(
        "-n",
        "--n-epochs",
        type=int,
        default=200,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--mosaic-size",
        type=int,
        default=10,
        help="Size of the side of the rectangular mosaic",
    )
    parser.add_argument(
        "-p",
        "--prob",
        type=float,
        default=0.9,
        help="Probability of applying an augmentation",
    )

    args = parser.parse_args(argv)
    args_d = vars(args)
    print(args)

    img_size = 128

    # Additional parameters
    device = torch.device(args.device)
    mosaic_kwargs = {"nrow": args.mosaic_size, "normalize": True}
    n_mosaic_cells = args.mosaic_size * args.mosaic_size
    sample_showcase_ix = (
        0  # this one will be used to demonstrate the augmentations
    )

    augment_module = torch.nn.Sequential(
        K.RandomAffine(degrees=0, translate=(1 / 8, 1 / 8), p=args.prob),
        K.RandomErasing((0.0, 0.5), p=args.prob),
    )

    # Loss function
    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = Generator(latent_dim=args.latent_dim, ngf=args.ngf)
    discriminator = Discriminator(
        ndf=args.ndf, augment_module=augment_module if args.augment else None)

    generator.to(device)
    discriminator.to(device)

    # Initialize weights
    generator.apply(init_weights_)
    discriminator.apply(init_weights_)

    # Configure data loader
    data_path = pathlib.Path("data")
    tform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = DatasetImages(
        data_path,
        transform=tform,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
    )

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=args.lr,
                                   betas=(args.b1, args.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=args.lr,
                                   betas=(args.b1, args.b2))

    # Output path and metadata
    output_path = pathlib.Path("outputs") / args.name
    output_path.mkdir(exist_ok=True, parents=True)

    # Add other parameters (not included in CLI)
    args_d["time"] = datetime.now()
    args_d["kornia"] = str(augment_module)

    # Prepare tensorboard writer
    writer = SummaryWriter(output_path)

    # Log hyperparameters as text
    writer.add_text(
        "hyperparameter",
        pprint.pformat(args_d).replace(
            "\n", "  \n"),  # markdown needs 2 spaces before newline
        0,
    )
    # Log true data
    writer.add_image(
        "true_data",
        make_grid(torch.stack([dataset[i] for i in range(n_mosaic_cells)]),
                  **mosaic_kwargs),
        0,
    )
    # Log augmented data
    batch_showcase = dataset[sample_showcase_ix][None, ...].repeat(
        n_mosaic_cells, 1, 1, 1)
    batch_showcase_aug = discriminator.augment_module(batch_showcase)
    writer.add_image("augmentations",
                     make_grid(batch_showcase_aug, **mosaic_kwargs), 0)

    # Prepare evaluation noise
    z_eval = torch.randn(n_mosaic_cells, args.latent_dim).to(device)

    for epoch in tqdm(range(args.n_epochs)):
        for i, imgs in enumerate(dataloader):
            n_samples, *_ = imgs.shape
            batches_done = epoch * len(dataloader) + i

            # Adversarial ground truths
            valid = 0.9 * torch.ones(
                n_samples, 1, device=device, dtype=torch.float32)
            fake = torch.zeros(n_samples,
                               1,
                               device=device,
                               dtype=torch.float32)

            # D preparation
            optimizer_D.zero_grad()

            # D loss on reals
            real_imgs = imgs.to(device)
            d_x = discriminator(real_imgs)
            real_loss = adversarial_loss(d_x, valid)
            real_loss.backward()

            # D loss on fakes
            z = torch.randn(n_samples, args.latent_dim).to(device)
            gen_imgs = generator(z)
            d_g_z1 = discriminator(gen_imgs.detach())

            fake_loss = adversarial_loss(d_g_z1, fake)
            fake_loss.backward()

            optimizer_D.step()  # backward was called twice; gradients from both losses accumulate

            # G preparation
            optimizer_G.zero_grad()

            # G loss
            d_g_z2 = discriminator(gen_imgs)
            g_loss = adversarial_loss(d_g_z2, valid)

            g_loss.backward()
            optimizer_G.step()

            # Logging
            if batches_done % 50 == 0:
                writer.add_scalar("d_x", d_x.mean().item(), batches_done)
                writer.add_scalar("d_g_z1", d_g_z1.mean().item(), batches_done)
                writer.add_scalar("d_g_z2", d_g_z2.mean().item(), batches_done)
                writer.add_scalar("D_loss", (real_loss + fake_loss).item(),
                                  batches_done)
                writer.add_scalar("G_loss", g_loss.item(), batches_done)

            if epoch % args.eval_frequency == 0 and i == 0:
                generator.eval()
                discriminator.eval()

                # Generate fake images
                gen_imgs_eval = generator(z_eval)

                # Generate nice mosaic
                writer.add_image(
                    "fake",
                    make_grid(gen_imgs_eval.data, **mosaic_kwargs),
                    batches_done,
                )

                # Save checkpoint (and potentially overwrite an existing one)
                torch.save(generator, output_path / "model.pt")

                # Put generator and discriminator back in training mode
                generator.train()
                discriminator.train()
Example 10
                Ka.RandomAffine(degrees=1 * 5.0,
                                shear=1 / 5,
                                translate=1 / 20,
                                p=0.25),
                Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25),
                # PIXEL-LEVEL
                Ka.ColorJitter(brightness=1 / 30, p=0.25),  # brightness
                Ka.ColorJitter(saturation=1 / 30, p=0.25),  # saturation
                Ka.ColorJitter(contrast=1 / 30, p=0.25),  # contrast
                Ka.ColorJitter(hue=1 / 30, p=0.25),  # hue
                Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1,
                                    angle=1,
                                    direction=1.0,
                                    p=0.25),
                Ka.RandomErasing(scale=(1 / 100, 1 / 50),
                                 ratio=(1 / 20, 1),
                                 p=0.25),
            ),
        ),
        ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
    ),
    "collate":
    kornia_collate,
    "per_batch_transform_on_device":
    ApplyToKeys(
        DataKeys.INPUT,
        Ka.RandomHorizontalFlip(p=0.25),
    ),
}

# construct datamodule
Example 11
    def __init__(
        self,
        image_size,
        latent_dim=512,
        style_depth=8,
        network_capacity=16,
        transparent=False,
        fp16=False,
        cl_reg=False,
        augment_fn=None,
        steps=1,
        lr=1e-4,
        fq_layers=[],
        fq_dict_size=256,
        attn_layers=[],
    ):
        super().__init__()
        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        self.S = StyleVectorizer(latent_dim, style_depth)
        self.G = Generator(image_size,
                           latent_dim,
                           network_capacity,
                           transparent=transparent,
                           attn_layers=attn_layers)
        self.D = Discriminator(
            image_size,
            network_capacity,
            fq_layers=fq_layers,
            fq_dict_size=fq_dict_size,
            attn_layers=attn_layers,
            transparent=transparent,
        )

        self.SE = StyleVectorizer(latent_dim, style_depth)
        self.GE = Generator(image_size,
                            latent_dim,
                            network_capacity,
                            transparent=transparent,
                            attn_layers=attn_layers)

        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)

        generator_params = list(self.G.parameters()) + list(
            self.S.parameters())
        self.G_opt = DiffGrad(generator_params, lr=self.lr, betas=(0.5, 0.9))
        self.D_opt = DiffGrad(self.D.parameters(),
                              lr=self.lr,
                              betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()

        self.cuda()

        if fp16:
            (self.S, self.G, self.D, self.SE,
             self.GE), (self.G_opt, self.D_opt) = amp.initialize(
                 [self.S, self.G, self.D, self.SE, self.GE],
                 [self.G_opt, self.D_opt],
                 opt_level="O2")

        # experimental contrastive loss discriminator regularization
        if augment_fn is not None:
            self.augment_fn = augment_fn
        else:
            self.augment_fn = nn.Sequential(
                nn.ReflectionPad2d(int((sqrt(2) - 1) * image_size / 4)),
                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.7),
                augs.RandomGrayscale(p=0.2),
                augs.RandomHorizontalFlip(),
                RandomApply(augs.RandomAffine(degrees=0,
                                              translate=(0.25, 0.25),
                                              shear=(15, 15)),
                            p=0.3),
                RandomApply(nn.Sequential(
                    augs.RandomRotation(180),
                    augs.CenterCrop(size=(image_size, image_size))),
                            p=0.2),
                augs.RandomResizedCrop(size=(image_size, image_size)),
                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
                RandomApply(augs.RandomErasing(), p=0.1),
            )

        self.D_cl = (ContrastiveLearner(self.D,
                                        image_size,
                                        augment_fn=self.augment_fn,
                                        fp16=fp16,
                                        hidden_layer="flatten")
                     if cl_reg else None)
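Each call to the `augment_fn` pipeline draws fresh random parameters, so two calls on the same batch yield the two independently augmented views a contrastive regularizer compares. A sketch (`model` is a hypothetical instance of the class above; `image_size=128` chosen for illustration):

import torch

x = torch.rand(4, 3, 128, 128)
view_1 = model.augment_fn(x)  # independent random parameters per call
view_2 = model.augment_fn(x)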
Example 12
def get_augmenter(augmenter_type: str,
                  image_size: ImageSizeType,
                  dataset_mean: DatasetStatType,
                  dataset_std: DatasetStatType,
                  padding: PaddingInputType = 1. / 8.,
                  pad_if_needed: bool = False,
                  subset_size: int = 2) -> Union[Module, Callable]:
    """
    
    Args:
        augmenter_type: augmenter type
        image_size: (height, width) image size
        dataset_mean: dataset mean value in CHW
        dataset_std: dataset standard deviation in CHW
        padding: fraction of the image size to pad on each border of the image. If a sequence of length 4 is
            provided, it is used to pad the left, top, right, and bottom borders respectively. If a sequence of
            length 2 is provided, it is used to pad the left/right and top/bottom borders, respectively.
        pad_if_needed: bool flag for RandomCrop "pad_if_needed" option
        subset_size: number of augmentations used in subset

    Returns: nn.Module for Kornia augmentation or Callable for torchvision transform

    """
    if not isinstance(padding, tuple):
        assert isinstance(padding, float)
        padding = (padding, padding, padding, padding)

    assert len(padding) == 2 or len(padding) == 4
    if len(padding) == 2:
        # padding of length 2 is used to pad left/right, top/bottom borders, respectively
        # padding of length 4 is used to pad left, top, right, bottom borders respectively
        padding = (padding[0], padding[1], padding[0], padding[1])

    # image_size is (h, w); padding values are (left, top, right, bottom) borders
    padding = (int(image_size[1] * padding[0]), int(
        image_size[0] * padding[1]), int(image_size[1] * padding[2]),
               int(image_size[0] * padding[3]))

    augmenter_type = augmenter_type.strip().lower()

    if augmenter_type == "simple":
        return nn.Sequential(
            K.RandomCrop(size=image_size,
                         padding=padding,
                         pad_if_needed=pad_if_needed,
                         padding_mode='reflect'),
            K.RandomHorizontalFlip(p=0.5),
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)),
        )

    elif augmenter_type == "fixed":
        return nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            # K.RandomVerticalFlip(p=0.2),
            K.RandomResizedCrop(size=image_size,
                                scale=(0.8, 1.0),
                                ratio=(1., 1.)),
            RandomAugmentation(p=0.5,
                               augmentation=F.GaussianBlur2d(
                                   kernel_size=(3, 3),
                                   sigma=(1.5, 1.5),
                                   border_type='constant')),
            K.ColorJitter(contrast=(0.75, 1.5)),
            # additive Gaussian noise
            K.RandomErasing(p=0.1),
            # Multiply
            K.RandomAffine(degrees=(-25., 25.),
                           translate=(0.2, 0.2),
                           scale=(0.8, 1.2),
                           shear=(-8., 8.)),
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)),
        )

    elif augmenter_type in ["validation", "test"]:
        return nn.Sequential(
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)), )

    elif augmenter_type == "randaugment":
        return nn.Sequential(
            K.RandomCrop(size=image_size,
                         padding=padding,
                         pad_if_needed=pad_if_needed,
                         padding_mode='reflect'),
            K.RandomHorizontalFlip(p=0.5),
            RandAugmentNS(n=subset_size, m=10),
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)),
        )

    else:
        raise NotImplementedError(
            f"\"{augmenter_type}\" is not a supported augmenter type")