Example 1
 def train(self):
     data_iter = iter(self.dataloader)
     for i in tqdm(range(self.iterations)):
         # Fetch the next real batch, restarting the iterator when the
         # dataloader is exhausted; re-creating it every step would always
         # yield the first batch.
         try:
             real_imgs = next(data_iter)
         except StopIteration:
             data_iter = iter(self.dataloader)
             real_imgs = next(data_iter)
         real_imgs = real_imgs.to(self.device, non_blocking=True)
         real_imgs = DiffAugment(real_imgs, policy=self.policy)

         cur_batch_size = real_imgs.shape[0]
         noise = torch.randn(cur_batch_size, self.latent_dim, 1, 1, device=self.device)
         gen_imgs = self.G(noise)

         fake_imgs = DiffAugment(gen_imgs, policy=self.policy)

         # Discriminator step: detach the fakes so this backward pass does
         # not consume the graph the generator step needs below.
         self.D.zero_grad()
         self.train_discriminator(real_imgs, label='real')
         self.train_discriminator(fake_imgs.detach(), label='fake')
         self.D_optim.step()

         # Generator step: push D's score on the (augmented) fakes up.
         self.G.zero_grad()
         pred = self.D(fake_imgs, label='fake')
         g_loss = -pred.mean()
         g_loss.backward()
         self.G_optim.step()

         if i % 5000 == 0:
             model_path = 'model' + str(i) + '.pth'
             torch.save(self.G.state_dict(), model_path)

     # Show one sample after training; assumes G outputs in [-1, 1].
     noise = torch.randn(cur_batch_size, self.latent_dim, 1, 1, device=self.device)
     gen_imgs = self.G(noise)
     img = gen_imgs[0].detach().cpu().numpy()
     plt.imshow((img.transpose(1, 2, 0) + 1) / 2)
     plt.show()
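
All of these examples revolve around the same call. For reference, here is a minimal self-contained sketch of it, assuming the reference PyTorch implementation (DiffAugment_pytorch.py from the data-efficient-gans repository) is importable; the policy string is the standard 'color,translation,cutout' combination:

import torch
from DiffAugment_pytorch import DiffAugment  # assumed import path

x = torch.randn(8, 3, 64, 64, requires_grad=True)  # a batch of NCHW images
x_aug = DiffAugment(x, policy='color,translation,cutout')
assert x_aug.shape == x.shape   # every policy is shape-preserving
x_aug.sum().backward()          # and differentiable, so G receives gradients
assert x.grad is not None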
Example 2
    def forward(self,
                z,
                gy,
                x=None,
                dy=None,
                train_G=False,
                return_G_z=False,
                policy=False,
                CR=False,
                CR_augment=None):
        if z is not None:
            # If training G, enable grad tape
            with torch.set_grad_enabled(train_G):
                # Get Generator output given noise
                G_z = self.G(z, self.G.shared(gy))
                # Cast as necessary
                if self.G.fp16 and not self.D.fp16:
                    G_z = G_z.float()
                if self.D.fp16 and not self.G.fp16:
                    G_z = G_z.half()
        else:
            G_z = None

        D_input = torch.cat([img for img in [G_z, x] if img is not None], 0)
        D_class = torch.cat([label for label in [gy, dy] if label is not None],
                            0)
        D_input = DiffAugment(D_input, policy=policy)
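        # CR: consistency regularization. Append a second copy of the reals
        # (augmented differently when CR_augment is set) so its outputs can
        # be compared with the first copy's.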
        if CR:
            if CR_augment:
                x_CR_aug = torch.split(D_input, [G_z.shape[0], x.shape[0]])[1]
                if CR_augment.startswith('flip,'):
                    x_CR_aug = torch.where(
                        torch.randint(0,
                                      2,
                                      size=[x_CR_aug.size(0), 1, 1, 1],
                                      device=x_CR_aug.device) > 0,
                        x_CR_aug.flip(3), x_CR_aug)
                x_CR_aug = DiffAugment(x_CR_aug,
                                       policy=CR_augment.replace('flip,', ''))
                D_input = torch.cat([D_input, x_CR_aug], 0)
            else:
                D_input = torch.cat([D_input, x], 0)
            D_class = torch.cat([D_class, dy], 0)
        # Get Discriminator output
        D_out = self.D(D_input, D_class)
        if G_z is None:
            return D_out
        elif x is not None:
            if CR:
                return torch.split(D_out,
                                   [G_z.shape[0], x.shape[0], x.shape[0]])
            else:
                return torch.split(D_out, [G_z.shape[0], x.shape[0]])
        else:
            if return_G_z:
                return D_out, G_z
            else:
                return D_out
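
With CR enabled and real data provided, the single discriminator pass therefore returns three chunks split along the batch dimension (fake outputs, real outputs, and outputs for the second, separately augmented copy of the reals), which is what a consistency-regularization term compares.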
Example 3
 def run_D(self, img, c, sync, phase):
     if self.diffaugment and phase in self.diffaugment_placement.split(','):
         img = DiffAugment(img, policy=self.diffaugment)
     if self.augment_pipe is not None:
         img = self.augment_pipe(img)
     with misc.ddp_sync(self.D, sync):
         logits = self.D(img, c)
     return logits
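
Note that this run_D is phase-aware: diffaugment_placement is a comma-separated list of loss phases (presumably the StyleGAN2-ADA phase names such as 'Gmain' and 'Dmain'), so the augmentation can be limited to, for example, only the discriminator updates.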
Example 4
 def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False,
             split_D=False):
   # If training G, enable grad tape
   with torch.set_grad_enabled(train_G):
     # Get Generator output given noise
     G_z = self.G(z, self.G.shared(gy))
     # Cast as necessary
     if self.G.fp16 and not self.D.fp16:
       G_z = G_z.float()
     if self.D.fp16 and not self.G.fp16:
       G_z = G_z.half()
   # Split_D means to run D once with real data and once with fake,
   # rather than concatenating along the batch dimension.
   if split_D:
     if self.DiffAugment_policy != "None":
       D_fake = self.D(DiffAugment(G_z, policy=self.DiffAugment_policy), gy)
     else:
       D_fake = self.D(G_z, gy)
     
     if x is not None:
       if self.DiffAugment_policy != "None":
         D_real = self.D(DiffAugment(x, policy=self.DiffAugment_policy), dy)
       else:
         D_real = self.D(x, dy)
       return D_fake, D_real
     else:
       if return_G_z:
         return D_fake, G_z
       else:
         return D_fake
   # If real data is provided, concatenate it with the Generator's output
   # along the batch dimension for improved efficiency.
   else:
     D_input = torch.cat([G_z, x], 0) if x is not None else G_z
     D_class = torch.cat([gy, dy], 0) if dy is not None else gy
     # Get Discriminator output
     if self.DiffAugment_policy != "None":
       D_input = DiffAugment(D_input, policy=self.DiffAugment_policy)
     D_out = self.D(D_input, D_class)
     if x is not None:
       return torch.split(D_out, [G_z.shape[0], x.shape[0]]) # D_fake, D_real
     else:
       if return_G_z:
         return D_out, G_z
       else:
         return D_out
Example 5
 def run_D(self, img, c, sync):
     if self.diffaugment:
         img = DiffAugment(img, policy=self.diffaugment)
     if self.augment_pipe is not None:
         img = self.augment_pipe(img)
     with misc.ddp_sync(self.D, sync):
         logits = self.D(img, c)
     return logits
Example 6
def compute_accuracy(opts, batch_size=32, diff_aug=False):
    D = copy.deepcopy(opts.D).eval().requires_grad_(False).to(opts.device)
    train_dataset = opts.train_dataset
    train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=batch_size)

    train_correct = 0
    train_all = 0

    for i, (train_img, train_c) in enumerate(tqdm(train_dataloader)):
        train_img = train_img.to(opts.device).to(torch.float32) / 127.5 - 1
        train_c = train_c.to(opts.device)
        if diff_aug and 'diffaugment' in opts.loss_kwargs:
            train_img = DiffAugment(train_img,
                                    policy=opts.loss_kwargs.diffaugment)
        gen_logits = D(train_img, train_c)
        train_all += train_img.shape[0]
        train_correct += torch.sum(gen_logits > 0).detach().item()

    train_accuracy = train_correct / train_all

    if opts.validation_dataset_kwargs != {}:
        validation_dataset = opts.validation_dataset
        validation_dataloader = torch.utils.data.DataLoader(
            dataset=validation_dataset, batch_size=batch_size)

        validation_correct = 0
        validation_all = 0

        for i, (validation_img,
                validation_c) in enumerate(tqdm(validation_dataloader)):
            validation_img = validation_img.to(opts.device).to(
                torch.float32) / 127.5 - 1
            validation_c = validation_c.to(opts.device)
            if diff_aug and 'diffaugment' in opts.loss_kwargs:
                validation_img = DiffAugment(
                    validation_img, policy=opts.loss_kwargs.diffaugment)
            gen_logits = D(validation_img, validation_c)
            validation_all += validation_img.shape[0]
            validation_correct += torch.sum(gen_logits > 0).detach().item()
        validation_accuracy = validation_correct / validation_all
    else:
        validation_accuracy = None

    return train_accuracy, validation_accuracy
Example 7
    def forward(self, x, feature=False, test=False):

        h = x
        # DiffAugment is applied here.
        if not test:
            h = DiffAugment(h, policy=self.policy)

        h = self.block1(h)
        h = self.block2(h)
        h = self.block3(h)
        h = self.block4(h)
        h = self.block5(h)
        h = self.activation(h)

        # Global sum pooling over the spatial dimensions
        h = h.sum(2).sum(2)
        output = self.l5(h)
        if feature:
            return h, output
        return output
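
Applying DiffAugment inside the discriminator's forward pass, as this example does, guarantees that real and fake batches go through the same augmentation code path; passing test=True bypasses it, e.g. when the discriminator is used for evaluation.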
Example 8
def compute_accuracy_generated(opts, batch_size=32, diff_aug=False):
    D = copy.deepcopy(opts.D).eval().requires_grad_(False).to(opts.device)
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)

    train_correct = 0
    train_all = 0
    if opts.validation_dataset_kwargs != {}:
        all_z = torch.randn([len(opts.validation_dataset), G.z_dim],
                            device=opts.device)
    else:
        all_z = torch.randn([10000, G.z_dim], device=opts.device)
    z_loader = torch.utils.data.DataLoader(dataset=all_z,
                                           batch_size=batch_size)
    for i, z in enumerate(tqdm(z_loader)):
        # Use z.shape[0], not batch_size: the last batch may be smaller.
        c = torch.empty([z.shape[0], 0], device=opts.device)
        fake_img = G(z, c)
        if diff_aug and 'diffaugment' in opts.loss_kwargs:
            fake_img = DiffAugment(fake_img,
                                   policy=opts.loss_kwargs.diffaugment)
        logits = D(fake_img, c)
        train_all += fake_img.shape[0]
        train_correct += torch.sum(logits <= 0).detach().item()
    result = train_correct / train_all
    return result
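
Here a generated sample counts as correct when its logit is non-positive, i.e. when the discriminator rejects it, so this is the complement of compute_accuracy above: together the two quantify how well D separates training reals from G's samples.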
Example 9
def train_cgan_concat(images,
                      labels,
                      netG,
                      netD,
                      save_images_folder,
                      save_models_folder=None):

    netG = netG.cuda()
    netD = netD.cuda()

    optimizerG = torch.optim.Adam(netG.parameters(),
                                  lr=lr_g,
                                  betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(),
                                  lr=lr_d,
                                  betas=(0.5, 0.999))

    trainset = IMGs_dataset(images, labels, normalize=True)
    train_dataloader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=num_workers)
    unique_labels = np.sort(np.array(list(set(labels)))).astype(int)

    if save_models_folder is not None and resume_niters > 0:
        save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(
            gan_arch, num_D_steps, resume_niters)
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
    #end if

    # preview grid: images with labels between the 5th and 95th quantiles of the training labels
    n_row = 10
    n_col = n_row
    z_fixed = torch.randn(n_row * n_col, dim_gan, dtype=torch.float).cuda()
    start_label = np.quantile(labels, 0.05)
    end_label = np.quantile(labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    y_fixed = np.zeros(n_row * n_col)
    for i in range(n_row):
        curr_label = selected_labels[i]
        for j in range(n_col):
            y_fixed[i * n_col + j] = curr_label
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1, 1).cuda()

    batch_idx = 0
    dataloader_iter = iter(train_dataloader)

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        if batch_idx + 1 == len(train_dataloader):
            dataloader_iter = iter(train_dataloader)
            batch_idx = 0
        '''

        Train Generator: maximize log(D(G(z)))

        '''

        netG.train()

        # get training images
        _, batch_train_labels = next(dataloader_iter)
        assert batch_size == batch_train_labels.shape[0]
        batch_train_labels = batch_train_labels.type(torch.long).cuda()
        batch_idx += 1

        # Sample noise and labels as generator input
        z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()

        #generate fake images
        batch_fake_images = netG(z, batch_train_labels)

        # Loss measures generator's ability to fool the discriminator
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy),
                           batch_train_labels)
        else:
            dis_out = netD(batch_fake_images, batch_train_labels)

        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = -torch.mean(torch.log(dis_out + 1e-20))
        elif loss_type == "hinge":
            g_loss = -torch.mean(dis_out)

        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()
        '''

        Train Discriminator: maximize log(D(x)) + log(1 - D(G(z)))

        '''

        for _ in range(num_D_steps):

            if batch_idx + 1 == len(train_dataloader):
                dataloader_iter = iter(train_dataloader)
                batch_idx = 0

            # get training images
            batch_train_images, batch_train_labels = next(dataloader_iter)
            assert batch_size == batch_train_images.shape[0]
            batch_train_images = batch_train_images.type(torch.float).cuda()
            batch_train_labels = batch_train_labels.type(torch.long).cuda()
            batch_idx += 1

            # Re-sample fakes conditioned on this batch's labels so the
            # image/label pairs fed to the conditional D are consistent.
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            batch_fake_images = netG(z, batch_train_labels)

            # Measure discriminator's ability to classify real from generated samples
            if use_DiffAugment:
                real_dis_out = netD(
                    DiffAugment(batch_train_images, policy=policy),
                    batch_train_labels)
                fake_dis_out = netD(
                    DiffAugment(batch_fake_images.detach(), policy=policy),
                    batch_train_labels)
            else:
                real_dis_out = netD(batch_train_images, batch_train_labels)
                fake_dis_out = netD(batch_fake_images.detach(),
                                    batch_train_labels)

            if loss_type == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = -torch.log(real_dis_out + 1e-20)
                d_loss_fake = -torch.log(1 - fake_dis_out + 1e-20)
            elif loss_type == "hinge":
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
            d_loss = (d_loss_real + d_loss_fake).mean()

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        if (niter + 1) % 20 == 0:
            print(
                "cGAN(concat)-%s: [Iter %d/%d] [D loss: %.4f] [G loss: %.4f] [D out real:%.4f] [D out fake:%.4f] [Time: %.4f]"
                % (gan_arch, niter + 1, niters, d_loss.item(), g_loss.item(),
                   real_dis_out.mean().item(), fake_dis_out.mean().item(),
                   timeit.default_timer() - start_time))

        if (niter + 1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, y_fixed)
                gen_imgs = gen_imgs.detach()
            save_image(gen_imgs.data,
                       save_images_folder + '/{}.png'.format(niter + 1),
                       nrow=n_row,
                       normalize=True)

        if save_models_folder is not None and (
            (niter + 1) % save_niters_freq == 0 or (niter + 1) == niters):
            save_file = save_models_folder + "/cGAN_{}_nDsteps_{}_checkpoint_intrain/cGAN_checkpoint_niters_{}.pth".format(
                gan_arch, num_D_steps, niter + 1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save(
                {
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
                }, save_file)
    #end for niter

    return netG, netD
Example 10
def main():
    args = get_args()
    weight_path, img_path = directory_path(args)

    # CUDA setting
    if not torch.cuda.is_available():
        raise ValueError("Should buy GPU!")

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    dev = torch.device('cuda')  # assumed device handle; used by the sampling helpers below

    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.backends.cudnn.benchmark = True

    # dataloading
    train_loader, s_dlen, _n_cls = data_loader2(args)

    fixed_z = torch.randn(200, 10, 128)
    fixed_img_list, fixed_label_list = pick_fixed_img(args, train_loader, 200)

    # initialize model
    gen, dis = select_model(args, _n_cls)

    opt_gen = optim.Adam(gen.parameters(), args.lr, (args.beta1, args.beta2))
    opt_dis = optim.Adam(dis.parameters(), args.lr, (args.beta1, args.beta2))

    gen_criterion = L.GenLoss(args.loss_type, args.relativistic_loss)
    dis_criterion = L.DisLoss(args.loss_type, args.relativistic_loss)

    criterion = nn.CrossEntropyLoss()

    # Training loop
    for n_iter in tqdm.tqdm(range(0, args.max_iteration)):

        if n_iter >= args.lr_decay_start:
            decay_lr(opt_gen, args.max_iteration, args.lr_decay_start, args.lr)
            decay_lr(opt_dis, args.max_iteration, args.lr_decay_start, args.lr)

        # ==================== Beginning of 1 iteration. ====================
        _l_g = 0.0
        cumulative_loss_dis = 0.0
        for i in range(args.n_dis):
            if i == 0:
                fake, pseudo_y, _ = sample_from_gen(args, dev, _n_cls, gen)
                fake = DiffAugment(fake, policy=policy)
                dis_fake, dis_mi, dis_c = dis(fake, pseudo_y)
                dis_real = None

                loss_gen = gen_criterion(dis_fake, dis_real)

                ##################################################
                loss_mi = criterion(dis_mi, pseudo_y)
                loss_c = criterion(dis_c, pseudo_y)

                loss_gen = loss_gen + args.lambda_c * (loss_c - loss_mi)
                ##################################################

                gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()
                _l_g += loss_gen.item()

            fake, pseudo_y, _ = sample_from_gen(args, dev, _n_cls, gen)
            real, y = sample_from_data(args, dev, train_loader)

            fake = DiffAugment(fake, policy=policy)
            real = DiffAugment(real, policy=policy)

            dis_fake, dis_fake_mi, dis_fake_c = dis(fake, pseudo_y)
            dis_real, dis_real_mi, dis_real_c = dis(real, y)

            ######################################################
            loss_dis_mi = criterion(dis_fake_mi, pseudo_y)
            loss_dis_c = criterion(dis_real_c, y)
            ######################################################

            loss_dis = dis_criterion(dis_fake, dis_real)
            loss_dis = loss_dis + args.lambda_c * (loss_dis_mi + loss_dis_c)

            dis.zero_grad()
            loss_dis.backward()
            opt_dis.step()

            cumulative_loss_dis += loss_dis.item()
        # ==================== End of 1 iteration. ====================

        if n_iter % args.log_interval == 0:
            tqdm.tqdm.write(
                'iteration: {:07d}/{:07d}, loss gen: {:05f}, loss dis: {:05f},'
                ' loss mi: {:05f}, loss c: {:05f}'.format(
                    n_iter, args.max_iteration, _l_g, cumulative_loss_dis,
                    args.lambda_c * loss_dis_mi, args.lambda_c * loss_dis_c))

        if n_iter % args.checkpoint_interval == 0:
            #Save checkpoints!
            utils.save_checkpoints(args, n_iter, gen, opt_gen, dis, opt_dis, weight_path)
            if args.dataset == "omniglot":
                utils.save_img(fixed_img_list, fixed_label_list, fixed_z, gen,
                               32, 28, img_path, n_iter, device=dev)
            elif args.dataset == "vgg" or args.dataset == "animal":
                utils.save_img(fixed_img_list, fixed_label_list, fixed_z, gen,
                               84, 64, img_path, n_iter, device=dev)
            elif args.dataset == "cub":
                utils.save_img(fixed_img_list, fixed_label_list, fixed_z, gen,
                               72, 64, img_path, n_iter, device=dev)
            else:
                raise Exception("Enter model omniglot or vgg or animal or cub")

    if args.test:
        shutil.rmtree(args.results_root)
Example 11
def train_ccgan(kernel_sigma, kappa, train_images, train_labels, netG, netD, net_y2h, save_images_folder, save_models_folder = None, clip_label=False):
    
    '''
    Note that train_images are not normalized to [-1,1]
    '''

    netG = netG.cuda()
    netD = netD.cuda()
    net_y2h = net_y2h.cuda()
    net_y2h.eval()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    if save_models_folder is not None and resume_niters>0:
        save_file = save_models_folder + "/CcGAN_{}_{}_nDsteps_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(gan_arch, threshold_type, num_D_steps, resume_niters)
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
    #end if

    #################
    unique_train_labels = np.sort(np.array(list(set(train_labels))))

    # preview grid: images with labels between the 5th and 95th quantiles of the training labels
    n_row=10; n_col = n_row
    z_fixed = torch.randn(n_row*n_col, dim_gan, dtype=torch.float).cuda()
    start_label = np.quantile(train_labels, 0.05)
    end_label = np.quantile(train_labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    y_fixed = np.zeros(n_row*n_col)
    for i in range(n_row):
        curr_label = selected_labels[i]
        for j in range(n_col):
            y_fixed[i*n_col+j] = curr_label
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1,1).cuda()


    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        '''  Train Discriminator   '''
        for _ in range(num_D_steps):

            ## randomly draw batch_size_disc y's from unique_train_labels
            batch_target_labels_in_dataset = np.random.choice(unique_train_labels, size=batch_size_disc, replace=True)
            ## add Gaussian noise; we estimate image distribution conditional on these labels
            batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_disc)
            batch_target_labels = batch_target_labels_in_dataset + batch_epsilons

            ## find index of real images with labels in the vicinity of batch_target_labels
            ## generate labels for fake image generation; these labels are also in the vicinity of batch_target_labels
            batch_real_indx = np.zeros(batch_size_disc, dtype=int)  # indices of images in the dataset whose labels are in the vicinity
            batch_fake_labels = np.zeros(batch_size_disc)

            for j in range(batch_size_disc):
                ## index for real images
                if threshold_type == "hard":
                    indx_real_in_vicinity = np.where(np.abs(train_labels-batch_target_labels[j])<= kappa)[0]
                else:
                    # reverse the weight function for SVDL
                    indx_real_in_vicinity = np.where((train_labels-batch_target_labels[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]

                ## if the max gap between two consecutive ordered unique labels is large, it is possible that len(indx_real_in_vicinity)<1
                while len(indx_real_in_vicinity)<1:
                    batch_epsilons_j = np.random.normal(0, kernel_sigma, 1)
                    batch_target_labels[j] = batch_target_labels_in_dataset[j] + batch_epsilons_j
                    if clip_label:
                        batch_target_labels = np.clip(batch_target_labels, 0.0, 1.0)
                    ## index for real images
                    if threshold_type == "hard":
                        indx_real_in_vicinity = np.where(np.abs(train_labels-batch_target_labels[j])<= kappa)[0]
                    else:
                        # reverse the weight function for SVDL
                        indx_real_in_vicinity = np.where((train_labels-batch_target_labels[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]
                #end while len(indx_real_in_vicinity)<1

                assert len(indx_real_in_vicinity)>=1

                batch_real_indx[j] = np.random.choice(indx_real_in_vicinity, size=1)[0]

                ## labels for fake images generation
                if threshold_type == "hard":
                    lb = batch_target_labels[j] - kappa
                    ub = batch_target_labels[j] + kappa
                else:
                    lb = batch_target_labels[j] - np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
                    ub = batch_target_labels[j] + np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
                lb = max(0.0, lb); ub = min(ub, 1.0)
                assert lb<=ub
                assert lb>=0 and ub>=0
                assert lb<=1 and ub<=1
                batch_fake_labels[j] = np.random.uniform(lb, ub, size=1)[0]
            #end for j

            ## draw real image/label batch from the training set
            batch_real_images = torch.from_numpy(normalize_images(train_images[batch_real_indx]))
            batch_real_images = batch_real_images.type(torch.float).cuda()
            batch_real_labels = train_labels[batch_real_indx]
            batch_real_labels = torch.from_numpy(batch_real_labels).type(torch.float).cuda()


            ## generate the fake image batch
            batch_fake_labels = torch.from_numpy(batch_fake_labels).type(torch.float).cuda()
            z = torch.randn(batch_size_disc, dim_gan, dtype=torch.float).cuda()
            batch_fake_images = netG(z, net_y2h(batch_fake_labels))

            ## target labels on gpu
            batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).cuda()

            ## weight vector
            if threshold_type == "soft":
                real_weights = torch.exp(-kappa*(batch_real_labels-batch_target_labels)**2).cuda()
                fake_weights = torch.exp(-kappa*(batch_fake_labels-batch_target_labels)**2).cuda()
            else:
                real_weights = torch.ones(batch_size_disc, dtype=torch.float).cuda()
                fake_weights = torch.ones(batch_size_disc, dtype=torch.float).cuda()
            #end if threshold type

            # forward pass
            if use_DiffAugment:
                real_dis_out = netD(DiffAugment(batch_real_images, policy=policy), net_y2h(batch_target_labels))
                fake_dis_out = netD(DiffAugment(batch_fake_images.detach(), policy=policy), net_y2h(batch_target_labels))
            else:
                real_dis_out = netD(batch_real_images, net_y2h(batch_target_labels))
                fake_dis_out = netD(batch_fake_images.detach(), net_y2h(batch_target_labels))

            if loss_type == "vanilla":
                real_dis_out = torch.nn.Sigmoid()(real_dis_out)
                fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
                d_loss_real = - torch.log(real_dis_out+1e-20)
                d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
            elif loss_type == "hinge":
                d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
                d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
            else:
                raise ValueError('Not supported loss type!!!')

            d_loss = torch.mean(real_weights.view(-1) * d_loss_real.view(-1)) + torch.mean(fake_weights.view(-1) * d_loss_fake.view(-1))

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

        #end for step_D_index



        '''  Train Generator   '''
        netG.train()

        # generate fake images
        ## randomly draw batch_size_gene y's from unique_train_labels
        batch_target_labels_in_dataset = np.random.choice(unique_train_labels, size=batch_size_gene, replace=True)
        ## add Gaussian noise; we estimate image distribution conditional on these labels
        batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_gene)
        batch_target_labels = batch_target_labels_in_dataset + batch_epsilons
        batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).cuda()

        z = torch.randn(batch_size_gene, dim_gan, dtype=torch.float).cuda()
        batch_fake_images = netG(z, net_y2h(batch_target_labels))

        # loss
        if use_DiffAugment:
            dis_out = netD(DiffAugment(batch_fake_images, policy=policy), net_y2h(batch_target_labels))
        else:
            dis_out = netD(batch_fake_images, net_y2h(batch_target_labels))
        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = - torch.mean(torch.log(dis_out+1e-20))
        elif loss_type == "hinge":
            g_loss = - dis_out.mean()

        # backward
        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # print loss
        if (niter+1) % 20 == 0:
        print("CcGAN,%s: [Iter %d/%d] [D loss: %.4e] [G loss: %.4e] [real prob: %.3f] [fake prob: %.3f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(), fake_dis_out.mean().item(), timeit.default_timer()-start_time))

        if (niter+1) % visualize_freq == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, net_y2h(y_fixed))
                gen_imgs = gen_imgs.detach().cpu()
                save_image(gen_imgs.data, save_images_folder + '/{}.png'.format(niter+1), nrow=n_row, normalize=True)

        if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
            save_file = save_models_folder + "/CcGAN_{}_{}_nDsteps_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(gan_arch, threshold_type, num_D_steps, niter+1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save({
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
            }, save_file)
    #end for niter
    return netG, netD
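
The vicinity sampling above reduces to two rules. A small illustrative sketch (kappa, the cutoff, and the labels are made-up numbers, not values from the source) of the hard and soft vicinities used both for picking real samples and for weighting the discriminator losses:

import numpy as np

kappa = 0.02                           # illustrative half-width / weight scale
nonzero_soft_weight_threshold = 1e-3   # illustrative cutoff
y_target = 0.50
train_labels = np.array([0.47, 0.49, 0.52, 0.90])

# hard vicinity: labels inside an interval of half-width kappa
hard_idx = np.where(np.abs(train_labels - y_target) <= kappa)[0]

# soft vicinity: keep labels whose weight exp(-kappa*(y - y_target)**2)
# stays above the cutoff, i.e. (y - y_target)**2 <= -log(cutoff)/kappa
radius_sq = -np.log(nonzero_soft_weight_threshold) / kappa
soft_idx = np.where((train_labels - y_target) ** 2 <= radius_sq)[0]
soft_weights = np.exp(-kappa * (train_labels[soft_idx] - y_target) ** 2)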
Example 12
    with torch.no_grad():
        embed, _ = arcface(
            F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                          mode='bilinear',
                          align_corners=True))
    same_person = same_person.to(device)
    Xt.requires_grad = True
    embed.requires_grad = True

    # train G
    D.requires_grad_(False)
    opt_G.zero_grad()
    with autocast():
        Y, Xt_attr = G(Xt, embed)

        Di = D(DiffAugment(Y, policy=policy))
        L_adv = 0

        for di in Di:
            #L_adv += hinge_loss(di[0], True)
            L_adv -= di[0].mean()
        L_adv /= len(Di)

        Y_aligned = Y[:, :, 19:237, 19:237]
        ZY, _ = arcface(
            F.interpolate(Y_aligned, [112, 112],
                          mode='bilinear',
                          align_corners=True))
        L_id = (1 - torch.cosine_similarity(embed, ZY, dim=1)).mean()

        Y_attr = G.get_attr(Y)