def kornia_list(MAGN: int = 4):
    """Build the standard pool of Kornia transforms scaled by ``MAGN``.

    Every transform is created with ``p=1`` except the identity entry
    (``ColorJitter(p=0)``). Most parameters scale with ``MAGN``. The
    horizontal flip, the 90-degree rotation and the identity entry do not
    depend on it.

    Args:
        MAGN (int): Magnitude used to scale the transform parameters.

    Returns:
        list: 4 spatial transforms followed by 7 pixel-level transforms.
    """
    jitter = MAGN / 30

    spatial = [
        K.RandomHorizontalFlip(p=1),
        K.RandomRotation(degrees=90., p=1),
        K.RandomAffine(degrees=MAGN * 5., shear=MAGN / 5, translate=MAGN / 20, p=1),
        K.RandomPerspective(distortion_scale=MAGN / 25, p=1),
    ]

    # One ColorJitter per property, in brightness / saturation / contrast / hue order.
    pixel_level = [
        K.ColorJitter(**{prop: jitter}, p=1)
        for prop in ("brightness", "saturation", "contrast", "hue")
    ]
    pixel_level.append(K.ColorJitter(p=0))  # identity: p=0, never applied
    pixel_level.append(
        K.RandomMotionBlur(kernel_size=2 * (MAGN // 3) + 1, angle=MAGN, direction=1., p=1)
    )
    pixel_level.append(
        K.RandomErasing(scale=(MAGN / 100, MAGN / 50), ratio=(MAGN / 20, MAGN), p=1)
    )

    return spatial + pixel_level
def test_random_erasing(self, device, dtype):
    """Pixels filled with ``fill_value`` in the image output must be 0 in the mask output."""
    fill_value = 0.5
    img = torch.randn(3, 3, 100, 100, device=device, dtype=dtype)
    pipeline = K.AugmentationSequential(
        K.RandomErasing(p=1., value=fill_value),
        data_keys=["input", "mask"],
    )
    reproducibility_test((img, img), pipeline)

    outputs = pipeline(img, img)
    out_img, out_mask = outputs[0], outputs[1]
    erased_region = out_img == fill_value
    assert torch.all(out_mask[erased_region] == 0.)
def __init__(self, brightness=(0.75, 1.25), contrast=(0.75, 1.25), saturation=(0., 2.),
             translate=(0.125, 0.125), normalized=True, mean=0.5, std=0.5, device=None):
    """Build color / translation / cutout augmentations and optional (de)normalization.

    Args:
        brightness, contrast, saturation: ranges forwarded to ``aug.ColorJitter``.
        translate: translation forwarded to ``aug.RandomAffine`` (degrees fixed to 0,
            border padding).
        normalized: when True, ``self.normalize`` / ``self.denormalize`` are built
            from ``mean`` and ``std``; otherwise both are set to None.
        mean, std: scalars, or per-channel tuples/lists. When both are sequences,
            they are converted to tensors and moved to ``device``.
        device: target passed to ``Tensor.to`` for per-channel ``mean`` / ``std``.

    Raises:
        ValueError: if ``mean`` and ``std`` are both sequences and ``device`` is None.
    """
    if normalized:
        if isinstance(mean, (tuple, list)) and isinstance(std, (tuple, list)):
            # Compare against None explicitly: device=0 (CUDA ordinal) is valid but falsy.
            if device is None:
                raise ValueError(
                    'Please specify a torch.device() object when using mean and std for each channels'
                )
            mean = torch.Tensor(mean).to(device)
            std = torch.Tensor(std).to(device)
        self.normalize = aug.Normalize(mean, std)
        self.denormalize = aug.Denormalize(mean, std)
    else:
        self.normalize, self.denormalize = None, None

    # rand_brightness, rand_contrast, rand_saturation
    color_jitter = aug.ColorJitter(
        brightness=brightness, contrast=contrast, saturation=saturation, p=1.)
    # rand_translate
    affine = aug.RandomAffine(degrees=0, translate=translate,
                              padding_mode=SamplePadding.BORDER, p=1.)
    # rand_cutout
    cutout = aug.RandomErasing(value=0.5, p=1.)

    self.augmentations = {
        'color': color_jitter,
        'translation': affine,
        'cutout': cutout
    }
def __init__(self, cut_size, cutn, cut_pow=1.):
    """Store cutout settings and build the augmentation and pooling modules.

    Args:
        cut_size: side length of the square output of both adaptive pools.
        cutn: stored as ``self.cutn`` (presumably the number of cutouts; not used here).
        cut_pow: stored as ``self.cut_pow``; not used in this method.
    """
    super().__init__()
    self.cut_size = cut_size
    self.cutn = cutn
    self.cut_pow = cut_pow

    # Every augmentation below is configured with p=0.7.
    affine = K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode="border")
    perspective = K.RandomPerspective(distortion_scale=0.7, p=0.7)
    jitter = K.ColorJitter(hue=0.1, saturation=0.1, p=0.7)
    erasing = K.RandomErasing(scale=(.1, .4), ratio=(.3, 1 / .3), same_on_batch=True, p=0.7)
    self.augs = nn.Sequential(affine, perspective, jitter, erasing)

    self.noise_fac = 0.1

    pool_size = (self.cut_size, self.cut_size)
    self.av_pool = nn.AdaptiveAvgPool2d(pool_size)
    self.max_pool = nn.AdaptiveMaxPool2d(pool_size)
args.n_mlp, channel_multiplier=args.channel_multiplier, constant_input=args.constant_input, ).to(device) g_ema.requires_grad_(False) g_ema.eval() accumulate(g_ema, generator, 0) augment_fn = nn.Sequential( nn.ReflectionPad2d(int((math.sqrt(2) - 1) * args.size / 4)), # zoom out augs.RandomHorizontalFlip(), RandomApply(augs.RandomAffine(degrees=0, translate=(0.25, 0.25), shear=(15, 15)), p=0.2), RandomApply(augs.RandomRotation(180), p=0.2), augs.RandomResizedCrop(size=(args.size, args.size), scale=(1, 1), ratio=(1, 1)), RandomApply(augs.RandomResizedCrop(size=(args.size, args.size), scale=(0.5, 0.9)), p=0.1), # zoom in RandomApply(augs.RandomErasing(), p=0.1), ) contrast_learner = ( ContrastiveLearner(discriminator, args.size, augment_fn=augment_fn, hidden_layer=(-1, 0)) if args.contrastive > 0 else None ) g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1) d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) g_optim = th.optim.Adam( generator.parameters(), lr=args.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), ) d_optim = th.optim.Adam( discriminator.parameters(), lr=args.lr * d_reg_ratio, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
def __init__(self, probability: float = 0.5):
    """Wrap a Kornia ``RandomErasing`` applied with probability ``probability``.

    Args:
        probability: value passed as ``p`` to ``aug.RandomErasing``; also kept
            in ``self._probability``.
    """
    self._probability = probability
    erase_op = aug.RandomErasing(p=probability)
    self._operation = erase_op
class TestVideoSequential:
    """Tests for ``K.VideoSequential`` on 5-D video tensors in BCTHW / BTCHW layouts."""

    # Every parametrized shape has a rank other than 5; VideoSequential is expected
    # to reject such inputs with an AssertionError.
    @pytest.mark.parametrize('shape', [(3, 4), (2, 3, 4), (2, 3, 5, 6), (2, 3, 4, 5, 6, 7)])
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_exception(self, shape, data_format, device, dtype):
        aug_list = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1), data_format=data_format, same_on_frame=True)
        with pytest.raises(AssertionError):
            img = torch.randn(*shape, device=device, dtype=dtype)
            aug_list(img)

    # Each single augmentation wrapped in VideoSequential must be reproducible
    # (checked by the shared `reproducibility_test` helper).
    @pytest.mark.parametrize(
        'augmentation',
        [
            K.RandomAffine(360, p=1.0),
            K.CenterCrop((3, 3), p=1.0),
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomCrop((5, 5), p=1.0),
            K.RandomErasing(p=1.0),
            K.RandomGrayscale(p=1.0),
            K.RandomHorizontalFlip(p=1.0),
            K.RandomVerticalFlip(p=1.0),
            K.RandomPerspective(p=1.0),
            K.RandomResizedCrop((5, 5), p=1.0),
            K.RandomRotation(360.0, p=1.0),
            K.RandomSolarize(p=1.0),
            K.RandomPosterize(p=1.0),
            K.RandomSharpness(p=1.0),
            K.RandomEqualize(p=1.0),
            K.RandomMotionBlur(3, 35.0, 0.5, p=1.0),
            K.Normalize(torch.tensor([0.5, 0.5, 0.5]), torch.tensor([0.5, 0.5, 0.5]), p=1.0),
            K.Denormalize(torch.tensor([0.5, 0.5, 0.5]), torch.tensor([0.5, 0.5, 0.5]), p=1.0),
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_augmentation(self, augmentation, data_format, device, dtype):
        # Integer values 0..254 scaled into [0, 1); the batch of 1 is repeated to a batch of 2.
        input = torch.randint(255, (1, 3, 3, 5, 6), device=device, dtype=dtype).repeat(2, 1, 1, 1, 1) / 255.0
        # The seed is set after the input is drawn; keep this order so results stay stable.
        torch.manual_seed(21)
        aug_list = K.VideoSequential(augmentation, data_format=data_format, same_on_frame=True)
        reproducibility_test(input, aug_list)

    # With same_on_frame=True, an input whose frames are identical copies must come
    # out with identical frames, for every combination of augmentations / random_apply.
    @pytest.mark.parametrize(
        'augmentations',
        [
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0)
            ],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)
            ],
            [K.RandomAffine(360, p=1.0), kornia.color.BgrToRgb()],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0),
                K.RandomAffine(360, p=0.0)
            ],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0)],
            [K.RandomAffine(360, p=0.0)],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0),
                K.RandomMixUp(p=1.0)
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    @pytest.mark.parametrize('random_apply', [1, (1, 1), (1, ), 10, True, False])
    def test_same_on_frame(self, augmentations, data_format, random_apply, device, dtype):
        aug_list = K.VideoSequential(*augmentations, data_format=data_format, same_on_frame=True, random_apply=random_apply)

        if data_format == 'BCTHW':
            # A single frame (T=1) repeated 4 times along the time axis (dim 2).
            input = torch.randn(2, 3, 1, 5, 6, device=device, dtype=dtype).repeat(1, 1, 4, 1, 1)
            output = aug_list(input)
            # NOTE(review): presumably return_label is set when a mix augmentation
            # (e.g. RandomMixUp) is in the pipeline, so the output is (images, labels).
            if aug_list.return_label:
                output, _ = output
            assert (output[:, :, 0] == output[:, :, 1]).all()
            assert (output[:, :, 1] == output[:, :, 2]).all()
            assert (output[:, :, 2] == output[:, :, 3]).all()
        if data_format == 'BTCHW':
            # A single frame (T=1) repeated 4 times along the time axis (dim 1).
            input = torch.randn(2, 1, 3, 5, 6, device=device, dtype=dtype).repeat(1, 4, 1, 1, 1)
            output = aug_list(input)
            if aug_list.return_label:
                output, _ = output
            assert (output[:, 0] == output[:, 1]).all()
            assert (output[:, 1] == output[:, 2]).all()
            assert (output[:, 2] == output[:, 3]).all()
        reproducibility_test(input, aug_list)

    # With same_on_frame=False and the same seed, VideoSequential must match a plain
    # nn.Sequential applied to all frames flattened into the batch dimension.
    @pytest.mark.parametrize(
        'augmentations',
        [
            [K.RandomAffine(360, p=1.0)],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)],
            [
                K.RandomAffine(360, p=0.0),
                K.ImageSequential(K.RandomAffine(360, p=0.0))
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_against_sequential(self, augmentations, data_format, device, dtype):
        aug_list_1 = K.VideoSequential(*augmentations, data_format=data_format, same_on_frame=False)
        aug_list_2 = torch.nn.Sequential(*augmentations)

        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device, dtype=dtype).repeat(1, 1, 4, 1, 1)
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device, dtype=dtype).repeat(1, 4, 1, 1, 1)

        # Both pipelines start from the same seed so they draw identical parameters.
        torch.manual_seed(0)
        output_1 = aug_list_1(input)

        torch.manual_seed(0)
        # Bring the time axis next to the batch axis, then flatten (B, T) into one batch.
        if data_format == 'BCTHW':
            input = input.transpose(1, 2)
        output_2 = aug_list_2(input.reshape(-1, 3, 5, 6))
        output_2 = output_2.view(2, 4, 3, 5, 6)
        if data_format == 'BCTHW':
            output_2 = output_2.transpose(1, 2)
        assert (output_1 == output_2).all(), dict(aug_list_1._params)

    @pytest.mark.jit
    @pytest.mark.skip(reason="turn off due to Union Type")
    def test_jit(self, device, dtype):
        # Scripted and eager VideoSequential should agree (currently skipped).
        B, C, D, H, W = 2, 3, 5, 4, 4
        img = torch.ones(B, C, D, H, W, device=device, dtype=dtype)
        op = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1), same_on_frame=True)
        op_jit = torch.jit.script(op)
        assert_close(op(img), op_jit(img))
def random_erase(data, opt):
    """Apply Kornia ``RandomErasing`` to every image in ``data``.

    All leading dimensions are flattened into one batch dimension, keeping
    the last three dimensions as the per-image shape. Erasing is applied, and
    the result is reshaped back to ``data.shape``.

    Args:
        data: tensor whose last three dimensions form one image. Assumed to be
            (..., C, H, W); TODO confirm against callers.
        opt: not used by this function; kept for interface compatibility.

    Returns:
        Tensor with the same shape as ``data``.
    """
    # Keyword arguments on purpose: RandomErasing's positional order differs
    # between Kornia releases (``p`` first in older ones, ``scale`` first in
    # newer ones), so a positional 0.5 would bind to ``scale`` there.
    rec_er = K.RandomErasing(p=0.5, scale=(.02, .4), ratio=(.3, 1 / .3))
    # reshape (not view) so non-contiguous inputs/outputs are accepted too.
    out = rec_er(data.reshape([-1] + list(data.shape[-3:])))
    return out.reshape(data.shape)
def main(argv=None):
    """Train a GAN on images from ``data/`` and log progress to TensorBoard.

    The function does the following:

    - Parses the command line.
    - Builds the generator and discriminator. The discriminator receives the
      Kornia ``augment_module`` only when ``--augment`` is set.
    - Logs hyperparameters plus real and augmented sample mosaics.
    - Runs the adversarial training loop.

    Scalars are logged every 50 batches. A generated mosaic and a generator
    checkpoint (``outputs/<name>/model.pt``) are written on the first batch of
    every epoch divisible by ``--eval-frequency``.

    Args:
        argv: argument list for ``argparse``; ``None`` means ``sys.argv[1:]``.
    """
    # CLI
    parser = argparse.ArgumentParser()
    parser.add_argument("name", help="Name of the experiment")
    parser.add_argument(
        "-a",
        "--augment",
        action="store_true",
        help="If True, we apply augmentations",
    )
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch size")
    parser.add_argument(
        "--b1",
        type=float,
        default=0.5,
        help="Adam optimizer hyperparameter",
    )
    parser.add_argument(
        "--b2",
        type=float,
        default=0.999,
        help="Adam optimizer hyperparameter",
    )
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to use",
    )
    parser.add_argument(
        "--eval-frequency",
        type=int,
        default=400,
        help="Generate generator images every `eval_frequency` epochs",
    )
    parser.add_argument(
        "--latent-dim",
        type=int,
        default=100,
        help="Dimensionality of the random noise",
    )
    parser.add_argument("--lr", type=float, default=0.0002, help="Learning rate")
    parser.add_argument(
        "--ndf",
        type=int,
        default=32,
        help="Number of discriminator feature maps (after first convolution)",
    )
    parser.add_argument(
        "--ngf",
        type=int,
        default=32,
        help="Number of generator feature maps (before last transposed convolution)",
    )
    parser.add_argument(
        "-n",
        "--n-epochs",
        type=int,
        default=200,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--mosaic-size",
        type=int,
        default=10,
        help="Size of the side of the rectangular mosaic",
    )
    parser.add_argument(
        "-p",
        "--prob",
        type=float,
        default=0.9,
        help="Probability of applying an augmentation",
    )

    args = parser.parse_args(argv)
    args_d = vars(args)

    print(args)

    img_size = 128

    # Additional parameters
    device = torch.device(args.device)
    mosaic_kwargs = {"nrow": args.mosaic_size, "normalize": True}
    n_mosaic_cells = args.mosaic_size * args.mosaic_size
    sample_showcase_ix = 0  # this one will be used to demonstrate the augmentations

    augment_module = torch.nn.Sequential(
        K.RandomAffine(degrees=0, translate=(1 / 8, 1 / 8), p=args.prob),
        K.RandomErasing((0.0, 0.5), p=args.prob),
    )

    # Loss function
    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = Generator(latent_dim=args.latent_dim, ngf=args.ngf)
    discriminator = Discriminator(
        ndf=args.ndf, augment_module=augment_module if args.augment else None)

    generator.to(device)
    discriminator.to(device)

    # Initialize weights
    generator.apply(init_weights_)
    discriminator.apply(init_weights_)

    # Configure data loader
    data_path = pathlib.Path("data")
    tform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = DatasetImages(
        data_path,
        transform=tform,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
    )

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.b1, args.b2))

    # Output path and metadata
    output_path = pathlib.Path("outputs") / args.name
    output_path.mkdir(exist_ok=True, parents=True)

    # Add other parameters (not included in CLI)
    args_d["time"] = datetime.now()
    args_d["kornia"] = str(augment_module)

    # Prepare tensorboard writer; closed in `finally` so pending events are flushed.
    writer = SummaryWriter(output_path)
    try:
        # Log hyperparameters as text
        writer.add_text(
            "hyperparameter",
            pprint.pformat(args_d).replace("\n", "  \n"),  # markdown needs 2 spaces before newline
            0,
        )

        # Log true data (capped at the dataset size so small datasets do not overflow)
        n_true = min(n_mosaic_cells, len(dataset))
        writer.add_image(
            "true_data",
            make_grid(torch.stack([dataset[i] for i in range(n_true)]), **mosaic_kwargs),
            0,
        )

        # Log augmented data. Use the local module: when --augment is off the
        # discriminator was constructed with augment_module=None.
        batch_showcase = dataset[sample_showcase_ix][None, ...].repeat(n_mosaic_cells, 1, 1, 1)
        batch_showcase_aug = augment_module(batch_showcase)
        writer.add_image("augmentations", make_grid(batch_showcase_aug, **mosaic_kwargs), 0)

        # Prepare evaluation noise (fixed across epochs so mosaics are comparable)
        z_eval = torch.randn(n_mosaic_cells, args.latent_dim).to(device)

        for epoch in tqdm(range(args.n_epochs)):
            for i, imgs in enumerate(dataloader):
                n_samples, *_ = imgs.shape
                batches_done = epoch * len(dataloader) + i

                # Adversarial ground truths (real labels smoothed to 0.9)
                valid = 0.9 * torch.ones(n_samples, 1, device=device, dtype=torch.float32)
                fake = torch.zeros(n_samples, 1, device=device, dtype=torch.float32)

                # D preparation
                optimizer_D.zero_grad()

                # D loss on reals
                real_imgs = imgs.to(device)
                d_x = discriminator(real_imgs)
                real_loss = adversarial_loss(d_x, valid)
                real_loss.backward()

                # D loss on fakes
                z = torch.randn(n_samples, args.latent_dim).to(device)
                gen_imgs = generator(z)
                d_g_z1 = discriminator(gen_imgs.detach())

                fake_loss = adversarial_loss(d_g_z1, fake)
                fake_loss.backward()

                optimizer_D.step()  # we called backward twice, the result is a sum

                # G preparation
                optimizer_G.zero_grad()

                # G loss
                d_g_z2 = discriminator(gen_imgs)
                g_loss = adversarial_loss(d_g_z2, valid)

                g_loss.backward()
                optimizer_G.step()

                # Logging
                if batches_done % 50 == 0:
                    writer.add_scalar("d_x", d_x.mean().item(), batches_done)
                    writer.add_scalar("d_g_z1", d_g_z1.mean().item(), batches_done)
                    writer.add_scalar("d_g_z2", d_g_z2.mean().item(), batches_done)
                    writer.add_scalar("D_loss", (real_loss + fake_loss).item(), batches_done)
                    writer.add_scalar("G_loss", g_loss.item(), batches_done)

                if epoch % args.eval_frequency == 0 and i == 0:
                    generator.eval()
                    discriminator.eval()

                    # Generate fake images; no autograd graph is needed for evaluation
                    with torch.no_grad():
                        gen_imgs_eval = generator(z_eval)

                    # Generate nice mosaic
                    writer.add_image(
                        "fake",
                        make_grid(gen_imgs_eval, **mosaic_kwargs),
                        batches_done,
                    )

                    # Save checkpoint (and potentially overwrite an existing one)
                    torch.save(generator, output_path / "model.pt")

                    # Make sure generator and discriminator in the training mode
                    generator.train()
                    discriminator.train()
    finally:
        writer.close()
Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25), Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25), # PIXEL-LEVEL Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast Ka.ColorJitter(hue=1 / 30, p=0.25), # hue Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25), Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25), ), ), ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ), "collate": kornia_collate, "per_batch_transform_on_device": ApplyToKeys( DataKeys.INPUT, Ka.RandomHorizontalFlip(p=0.25), ), } # construct datamodule
def __init__(
    self,
    image_size,
    latent_dim=512,
    style_depth=8,
    network_capacity=16,
    transparent=False,
    fp16=False,
    cl_reg=False,
    augment_fn=None,
    steps=1,
    lr=1e-4,
    fq_layers=None,
    fq_dict_size=256,
    attn_layers=None,
):
    """Build style vectorizers, generators, discriminator, optimizers and augmentations.

    Args:
        image_size: image side length. It is passed to the generators,
            discriminator and contrastive learner, and it sizes the default
            augmentation pipeline.
        latent_dim: latent size for the style vectorizers and generators.
        style_depth: depth passed to ``StyleVectorizer``.
        network_capacity: capacity passed to the generators and discriminator.
        transparent: forwarded to the generators and discriminator.
        fp16: if True, models and optimizers are wrapped with ``amp.initialize``
            (opt level "O2"). It is also forwarded to ``ContrastiveLearner``.
        cl_reg: if True, ``self.D_cl`` is a ``ContrastiveLearner`` on ``self.D``;
            otherwise it is None.
        augment_fn: augmentation module for the contrastive learner. A default
            Kornia / ``RandomApply`` pipeline is built when None.
        steps: stored as ``self.steps``.
        lr: learning rate for both ``DiffGrad`` optimizers.
        fq_layers: forwarded to ``Discriminator``; None means an empty list.
        fq_dict_size: forwarded to ``Discriminator``.
        attn_layers: forwarded to the generators and discriminator; None means an empty list.
    """
    super().__init__()
    # None defaults instead of []: a list default is created once at definition
    # time and would be shared (and mutable) across every instance.
    if fq_layers is None:
        fq_layers = []
    if attn_layers is None:
        attn_layers = []

    self.lr = lr
    self.steps = steps
    self.ema_updater = EMA(0.995)

    self.S = StyleVectorizer(latent_dim, style_depth)
    self.G = Generator(image_size, latent_dim, network_capacity, transparent=transparent, attn_layers=attn_layers)
    self.D = Discriminator(
        image_size,
        network_capacity,
        fq_layers=fq_layers,
        fq_dict_size=fq_dict_size,
        attn_layers=attn_layers,
        transparent=transparent,
    )

    # Second vectorizer/generator with gradients disabled (presumably the EMA
    # copies updated via self.ema_updater; TODO confirm).
    self.SE = StyleVectorizer(latent_dim, style_depth)
    self.GE = Generator(image_size, latent_dim, network_capacity, transparent=transparent, attn_layers=attn_layers)

    set_requires_grad(self.SE, False)
    set_requires_grad(self.GE, False)

    generator_params = list(self.G.parameters()) + list(self.S.parameters())
    self.G_opt = DiffGrad(generator_params, lr=self.lr, betas=(0.5, 0.9))
    self.D_opt = DiffGrad(self.D.parameters(), lr=self.lr, betas=(0.5, 0.9))

    self._init_weights()
    self.reset_parameter_averaging()

    # Moves every submodule to the GPU; this class requires CUDA.
    self.cuda()

    if fp16:
        (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize(
            [self.S, self.G, self.D, self.SE, self.GE],
            [self.G_opt, self.D_opt],
            opt_level="O2")

    # experimental contrastive loss discriminator regularization
    if augment_fn is not None:
        self.augment_fn = augment_fn
    else:
        self.augment_fn = nn.Sequential(
            nn.ReflectionPad2d(int((sqrt(2) - 1) * image_size / 4)),
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.7),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(augs.RandomAffine(degrees=0, translate=(0.25, 0.25), shear=(15, 15)), p=0.3),
            RandomApply(nn.Sequential(
                augs.RandomRotation(180),
                augs.CenterCrop(size=(image_size, image_size))), p=0.2),
            augs.RandomResizedCrop(size=(image_size, image_size)),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            RandomApply(augs.RandomErasing(), p=0.1),
        )

    self.D_cl = (ContrastiveLearner(self.D, image_size, augment_fn=self.augment_fn, fp16=fp16, hidden_layer="flatten")
                 if cl_reg else None)
def get_augmenter(augmenter_type: str,
                  image_size: ImageSizeType,
                  dataset_mean: DatasetStatType,
                  dataset_std: DatasetStatType,
                  padding: PaddingInputType = 1. / 8.,
                  pad_if_needed: bool = False,
                  subset_size: int = 2) -> Union[Module, Callable]:
    """Build a Kornia augmentation pipeline by name.

    Args:
        augmenter_type: augmenter type; one of "simple", "fixed", "validation",
            "test" or "randaugment" (case-insensitive, surrounding whitespace ignored)
        image_size: (height, width) image size
        dataset_mean: dataset mean value in CHW
        dataset_std: dataset standard deviation in CHW
        padding: fraction of image size to pad on each border of the image. A single
            number pads all four borders. If a sequence (tuple or list) of length 4 is
            provided, it is used to pad left, top, right, bottom borders respectively.
            If a sequence of length 2 is provided, it is used to pad left/right,
            top/bottom borders, respectively.
        pad_if_needed: bool flag for RandomCrop "pad_if_needed" option
        subset_size: number of augmentations used in subset (RandAugment ``n``)

    Returns:
        nn.Module for Kornia augmentation or Callable for torchvision transform

    Raises:
        NotImplementedError: if ``augmenter_type`` is not supported.
        AssertionError: if ``padding`` is neither a number nor a sequence of length 2 or 4.
    """
    if isinstance(padding, list):
        # The contract is "a sequence"; lists used to fail the tuple check below.
        padding = tuple(padding)
    if not isinstance(padding, tuple):
        assert isinstance(padding, (int, float))
        padding = (padding, padding, padding, padding)

    assert len(padding) == 2 or len(padding) == 4
    if len(padding) == 2:
        # (left/right, top/bottom) -> (left, top, right, bottom)
        padding = (padding[0], padding[1], padding[0], padding[1])

    # image_size is (h, w) and padding is [left, top, right, bottom], so horizontal
    # borders scale with the width and vertical borders with the height.
    height, width = image_size[0], image_size[1]
    padding = (int(width * padding[0]), int(height * padding[1]),
               int(width * padding[2]), int(height * padding[3]))

    def make_normalize():
        # Fresh Normalize module built from the dataset statistics.
        return K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                           std=torch.tensor(dataset_std, dtype=torch.float32))

    augmenter_type = augmenter_type.strip().lower()

    if augmenter_type == "simple":
        return nn.Sequential(
            K.RandomCrop(size=image_size, padding=padding, pad_if_needed=pad_if_needed,
                         padding_mode='reflect'),
            K.RandomHorizontalFlip(p=0.5),
            make_normalize(),
        )
    elif augmenter_type == "fixed":
        return nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            K.RandomResizedCrop(size=image_size, scale=(0.8, 1.0), ratio=(1., 1.)),
            # Gaussian blur applied with probability 0.5
            RandomAugmentation(p=0.5, augmentation=F.GaussianBlur2d(
                kernel_size=(3, 3), sigma=(1.5, 1.5), border_type='constant')),
            K.ColorJitter(contrast=(0.75, 1.5)),
            # random rectangular cutout with probability 0.1
            K.RandomErasing(p=0.1),
            # rotation / translation / scale / shear
            K.RandomAffine(degrees=(-25., 25.), translate=(0.2, 0.2), scale=(0.8, 1.2),
                           shear=(-8., 8.)),
            make_normalize(),
        )
    elif augmenter_type in ["validation", "test"]:
        return nn.Sequential(
            make_normalize(),
        )
    elif augmenter_type == "randaugment":
        return nn.Sequential(
            K.RandomCrop(size=image_size, padding=padding, pad_if_needed=pad_if_needed,
                         padding_mode='reflect'),
            K.RandomHorizontalFlip(p=0.5),
            RandAugmentNS(n=subset_size, m=10),
            make_normalize(),
        )
    else:
        raise NotImplementedError(
            f"\"{augmenter_type}\" is not a supported augmenter type")