Exemple #1
0
def train(args):
    """Train a FastGAN-style generator/discriminator pair.

    Reads images from ``args.path``, optionally resumes from the
    checkpoint file ``args.ckpt``, and runs ``args.iter`` iterations.
    Sample images are written every ``save_interval`` iterations and
    full resumable checkpoints every ``save_interval * 50`` iterations
    (and at the final iteration).

    Args:
        args: Parsed CLI namespace providing ``path``, ``iter``,
            ``ckpt``, ``batch_size`` and ``im_size``.
    """
    data_root = args.path
    total_iterations = args.iter
    checkpoint = args.ckpt
    batch_size = args.batch_size
    im_size = args.im_size
    ndf = 64          # base discriminator channel width
    ngf = 64          # base generator channel width
    nz = 256          # latent dimension
    nlr = 0.0002      # Adam learning rate for both networks
    nbeta1 = 0.5      # Adam beta1 (standard for GAN training)
    multi_gpu = False           # NOTE(review): unused — DataParallel is applied whenever CUDA is available
    dataloader_workers = 8      # NOTE(review): unused — the DataLoader below hard-codes num_workers=0
    current_iteration = 0
    save_interval = 100

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda:0")

    transform_list = [
        transforms.Resize((int(im_size), int(im_size))),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ]
    trans = transforms.Compose(transform_list)

    if 'lmdb' in data_root:
        from operation import MultiResolutionDataset
        dataset = MultiResolutionDataset(data_root, trans, 1024)
    else:
        dataset = ImageFolder(root=data_root, transform=trans)

    # InfiniteSamplerWrapper makes the loader an endless stream, so
    # batches are drawn with next() instead of an epoch loop.
    dataloader = iter(
        DataLoader(dataset,
                   batch_size=batch_size,
                   shuffle=False,
                   sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=0,
                   pin_memory=True))
    '''
    loader = MultiEpochsDataLoader(dataset, batch_size=batch_size, 
                               shuffle=True, num_workers=dataloader_workers, 
                               pin_memory=True)
    dataloader = CudaDataLoader(loader, 'cuda')
    '''

    #from model_s import Generator, Discriminator
    netG = Generator(ngf=ngf, nz=nz, im_size=im_size)
    netG.apply(weights_init)

    netD = Discriminator(ndf=ndf, im_size=im_size)
    netD.apply(weights_init)

    netG.to(device)
    netD.to(device)

    # Exponential moving average of the generator weights; used for
    # sample images and the EMA checkpoint.
    avg_param_G = copy_G_params(netG)

    fixed_noise = torch.FloatTensor(8, nz).normal_(0, 1).to(device)

    if torch.cuda.is_available():
        netG = nn.DataParallel(netG)
        netD = nn.DataParallel(netD)

    optimizerG = optim.Adam(netG.parameters(), lr=nlr, betas=(nbeta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=nlr, betas=(nbeta1, 0.999))

    if checkpoint != 'None':
        ckpt = torch.load(checkpoint)
        netG.load_state_dict(ckpt['g'])
        netD.load_state_dict(ckpt['d'])
        avg_param_G = ckpt['g_ema']
        optimizerG.load_state_dict(ckpt['opt_g'])
        optimizerD.load_state_dict(ckpt['opt_d'])
        # The resume iteration is encoded in the checkpoint filename,
        # e.g. ".../all_5000.pth" -> 5000.
        current_iteration = int(checkpoint.split('_')[-1].split('.')[0])
        del ckpt

    # BUGFIX: resolve the output folders once, before the loop.  The
    # original assigned saved_model_folder only inside the
    # "iteration % save_interval == 0" image branch, so the final
    # checkpoint save raised NameError whenever total_iterations was
    # not a multiple of save_interval.
    saved_model_folder, saved_image_folder = get_dir(args)

    for iteration in range(current_iteration, total_iterations + 1):
        real_image = next(dataloader)
        if torch.cuda.is_available():
            real_image = real_image.cuda(non_blocking=True)
        current_batch_size = real_image.size(0)
        noise = torch.Tensor(current_batch_size, nz).normal_(0, 1).to(device)

        fake_images = netG(noise)

        # Differentiable augmentation applied to both real and fake
        # batches so the discriminator never sees un-augmented images.
        real_image = DiffAugment(real_image, policy=policy)
        fake_images = [
            DiffAugment(fake, policy=policy) for fake in fake_images
        ]

        ## 2. train Discriminator
        netD.zero_grad()

        err_dr, rec_img_all, rec_img_small, rec_img_part = train_d(
            netD, real_image, label="real")
        train_d(netD, [fi.detach() for fi in fake_images], label="fake")
        optimizerD.step()

        ## 3. train Generator
        netG.zero_grad()
        pred_g = netD(fake_images, "fake")
        err_g = -pred_g.mean()

        err_g.backward()
        optimizerG.step()

        # Update the EMA copy of the generator parameters.
        for p, avg_p in zip(netG.parameters(), avg_param_G):
            avg_p.mul_(0.999).add_(0.001 * p.data)

        # The original had two separate "iteration % save_interval == 0"
        # checks (print and image dump); merged into one branch since
        # both fire on exactly the same iterations.
        if iteration % save_interval == 0:
            print("GAN: loss d: %.5f    loss g: %.5f" %
                  (err_dr.item(), -err_g.item()))
            # Generate samples with the EMA weights, then restore the
            # live weights before training continues.
            backup_para = copy_G_params(netG)
            load_params(netG, avg_param_G)
            with torch.no_grad():
                vutils.save_image(netG(fixed_noise)[0].add(1).mul(0.5),
                                  saved_image_folder + '/%d.jpg' % iteration,
                                  nrow=4)
                vutils.save_image(
                    torch.cat([
                        F.interpolate(real_image, 128), rec_img_all,
                        rec_img_small, rec_img_part
                    ]).add(1).mul(0.5),
                    saved_image_folder + '/rec_%d.jpg' % iteration)
            load_params(netG, backup_para)

        if iteration % (save_interval *
                        50) == 0 or iteration == total_iterations:
            # Save an EMA-weight snapshot of G/D...
            backup_para = copy_G_params(netG)
            load_params(netG, avg_param_G)
            torch.save({
                'g': netG.state_dict(),
                'd': netD.state_dict()
            }, saved_model_folder + '/%d.pth' % iteration)
            load_params(netG, backup_para)
            # ...and a full resumable checkpoint with the live weights
            # plus the EMA parameters and optimizer states.
            torch.save(
                {
                    'g': netG.state_dict(),
                    'd': netD.state_dict(),
                    'g_ema': avg_param_G,
                    'opt_g': optimizerG.state_dict(),
                    'opt_d': optimizerD.state_dict()
                }, saved_model_folder + '/all_%d.pth' % iteration)
Exemple #2
0
    args = parser.parse_args()

    ngf = 64
    nz = 256 # latent dimension
    use_cuda = True
    device = 'cuda:0'
    im_size = 256

    netG = Generator(ngf=ngf, nz=nz, im_size=im_size, \
                infogan_latent_dim=args.latent_dim, spatial_code_dim=args.spatial_code_dim)
    netG.apply(weights_init)

    netG.to(device)

    avg_param_G = copy_G_params(netG)


    dir_name = args.dir_name
    checkpoint = "./train_results/"+dir_name+"/models/all_%d.pth"%(args.ckpt_iter)
    ckpt = torch.load(checkpoint)
    netG.load_state_dict(ckpt['g'])
    avg_param_G = ckpt['g_ema']
    load_params(netG, avg_param_G)
    netG.eval()

    if args.method == 'latent_traversal':
        latent_traversal(netG, args.num_steps-1)
    elif args.method == 'sample_fixed_latent':
        sample_fixed_latent(netG, args.num_samples, args.latent_dim)
    elif args.method == 'sample_fixed_noise':
Exemple #3
0
def train(args):
    """Train a FastGAN variant with optional InfoGAN latent regression.

    When ``args.use_infogan`` is set, the latent vector is the
    concatenation of a Gaussian noise part (``nz`` dims) and a uniform
    InfoGAN code (``infogan_latent_dim`` plus, optionally,
    ``spatial_infogan_size * spatial_code_dim`` spatial dims), and an
    extra Q-head optimizer is trained to regress the code from fakes.
    Sample grids and latent-traversal images are saved every
    ``save_interval * 10`` iterations; a resumable checkpoint every
    ``save_interval * 50`` iterations and at the final iteration.
    """

    data_root = args.path
    total_iterations = args.iter
    checkpoint = args.ckpt
    batch_size = args.batch_size
    im_size = args.im_size
    # NOTE(review): this local is shadowed — the loss terms below read
    # args.info_lambda directly, so info_lambda itself is unused.
    info_lambda = args.info_lambda

    ndf = 64          # base discriminator channel width
    ngf = 64          # base generator channel width
    nz = args.nz # latent dimension
    nlr = 0.0002      # Adam learning rate
    nbeta1 = 0.5      # Adam beta1
    use_cuda = not args.use_cpu
    multi_gpu = False
    dataloader_workers = 8
    current_iteration = 0
    save_interval = 100
    saved_model_folder, saved_image_folder = get_dir(args)
    
    device = torch.device("cpu")
    if use_cuda:
        device = torch.device("cuda:%d"%(args.cuda))

    transform_list = [
            transforms.Resize((int(im_size),int(im_size))),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ]
    trans = transforms.Compose(transform_list)
    
    dataset = ImageFolder(root=data_root, transform=trans)
    # Infinite sampler: batches are drawn with next(), not an epoch loop.
    dataloader = iter(DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True,
                      sampler=InfiniteSamplerWrapper(dataset), num_workers=dataloader_workers, pin_memory=True))


    netG = Generator(ngf=ngf, nz=nz, im_size=im_size, sle=(not args.no_sle), \
                    big=args.big, infogan_latent_dim=args.infogan_latent_dim, spatial_code_dim=args.spatial_code_dim)
    netG.apply(weights_init)

    netD = Discriminator(ndf=ndf, im_size=im_size, sle=(not args.no_sle), decode=(not args.no_decode), \
                        big=args.big, infogan_latent_dim=args.infogan_latent_dim, spatial_code_dim=args.spatial_code_dim)
    netD.apply(weights_init)

    netG.to(device)
    netD.to(device)

    from pytorch_model_summary import summary

    # print(summary( netG, torch.zeros((1, 256+args.infogan_latent_dim+spatial_infogan_size*args.spatial_code_dim)).to(device), show_input=False))
    # print(summary( netD, torch.zeros((1, 3, im_size, im_size)).to(device), 'True', show_input=False))

    # Exponential moving average of generator weights, used for
    # sample images and the 'g_ema' checkpoint entry.
    avg_param_G = copy_G_params(netG)

    if args.use_infogan:
        # Fixed latents for periodic sample grids: Gaussian noise part
        # plus uniform InfoGAN code part, concatenated along dim 1.
        # NOTE(review): spatial_infogan_size is a module-level name
        # defined outside this view — presumably the number of spatial
        # grid cells (the loops below use 4); verify against the module.
        fixed_noise = torch.FloatTensor(64, nz).normal_(0, 1)
        latent = torch.FloatTensor(64, args.infogan_latent_dim+spatial_infogan_size*args.spatial_code_dim).uniform_(-1, 1)
        fixed_noise = torch.cat([fixed_noise, latent], dim=1).to(device)
        
        if not args.spatial_code_dim:
            # Latent-traversal grid: for each InfoGAN dimension i, vary
            # z[i] over num_steps+1 evenly spaced values in [-upper, upper]
            # while keeping all other coordinates fixed.
            num_steps = 7
            upper = 2
            z = torch.FloatTensor(1, 1, args.infogan_latent_dim).uniform_(-1, 1)
            z = z.expand(args.infogan_latent_dim, num_steps+1, -1).clone()
            intervals = [-upper+i*2*upper/num_steps for i in range(num_steps+1)]
            for i in range(args.infogan_latent_dim):
                for j in range(num_steps+1):
                    z[i, j, i] = intervals[j]
            z = z.reshape(-1, args.infogan_latent_dim)
            # Two traversal sets differing only in the (repeated) noise part.
            fixed_noise_1 = torch.FloatTensor(1, nz).normal_(0, 1).repeat(z.shape[0], 1)
            traversal_z_1 = torch.cat([fixed_noise_1, z], dim=1).to(device)
            fixed_noise_2 = torch.FloatTensor(1, nz).normal_(0, 1).repeat(z.shape[0], 1)
            traversal_z_2 = torch.cat([fixed_noise_2, z], dim=1).to(device)
        else:
            num_steps = 7
            upper = 2

            # same spatial code, traverse latent
            z = torch.FloatTensor(1, 1, args.infogan_latent_dim).uniform_(-1, 1)
            z = z.expand(args.infogan_latent_dim, num_steps+1, -1).clone()
            intervals = [-upper+i*2*upper/num_steps for i in range(num_steps+1)]
            for i in range(args.infogan_latent_dim):
                for j in range(num_steps+1):
                    z[i, j, i] = intervals[j]
            z = z.reshape(-1, args.infogan_latent_dim)

            sz = torch.FloatTensor(1, args.spatial_code_dim).uniform_(-1, 1).repeat(1, spatial_infogan_size)
            sz = sz.repeat(z.shape[0], 1)
            fixed_noise_1 = torch.FloatTensor(1, nz).normal_(0, 1).repeat(z.shape[0], 1)
            traversal_z_1 = torch.cat([fixed_noise_1, z, sz], dim=1).to(device)
            fixed_noise_2 = torch.FloatTensor(1, nz).normal_(0, 1).repeat(z.shape[0], 1)
            traversal_z_2 = torch.cat([fixed_noise_2, z, sz], dim=1).to(device)

            # traverse lower right
            sz = torch.FloatTensor(1, 1, args.spatial_code_dim).uniform_(-1, 1)
            sz = sz.expand(args.spatial_code_dim, num_steps+1, -1).clone()
            for i in range(args.spatial_code_dim):
                for j in range(num_steps+1):
                    sz[i, j, i] = intervals[j]
            # corner latent
            sz = sz.reshape(-1, args.spatial_code_dim)
            # entire latent
            z = torch.FloatTensor(1, args.infogan_latent_dim).uniform_(-1, 1).repeat(sz.shape[0], 1)
            # spatial latent except corner
            sz_all = torch.FloatTensor(1, args.spatial_code_dim*(spatial_infogan_size-1)).uniform_(-1, 1).repeat(sz.shape[0], 1)
            fixed_noise_3 = torch.FloatTensor(1, nz).normal_(0, 1).repeat(sz.shape[0], 1)
            traversal_corner = torch.cat([fixed_noise_3, z, sz_all, sz], dim=1).to(device)

    else:
        fixed_noise = torch.FloatTensor(64, nz).normal_(0, 1).to(device)

    if multi_gpu:
        netG = nn.DataParallel(netG.cuda())
        netD = nn.DataParallel(netD.cuda())

    optimizerG = optim.Adam(netG.parameters(), lr=nlr, betas=(nbeta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=nlr, betas=(nbeta1, 0.999))
    if args.use_infogan:
        # Separate optimizer for the Q-head (latent regression) layers.
        optimizerQ = optim.Adam([{'params': netD.latent_from_128.parameters()}, \
                                {'params': netD.conv_q.parameters()}], \
                                lr=args.q_lr, betas=(nbeta1, 0.999))
    if args.spatial_code_dim:
        # With spatial codes the Q optimizer is rebuilt to also cover
        # the spatial latent head (replaces the one created above).
        optimizerQ = optim.Adam([{'params': netD.latent_from_128.parameters()}, \
                                {'params': netD.conv_q.parameters()},
                                {'params': netD.spatial_latent_from_128.parameters()}],
                                # {'params': netD.spatial_conv_q.parameters()}], \
                                lr=args.q_lr, betas=(nbeta1, 0.999))
    
    if checkpoint != 'None':
        ckpt = torch.load(checkpoint)
        netG.load_state_dict(ckpt['g'])
        netD.load_state_dict(ckpt['d'])
        avg_param_G = ckpt['g_ema']
        optimizerG.load_state_dict(ckpt['opt_g'])
        optimizerD.load_state_dict(ckpt['opt_d'])
        if args.use_infogan:
            optimizerQ.load_state_dict(ckpt['opt_q'])
        # Resume iteration is encoded in the filename, e.g. "all_5000.pth".
        current_iteration = int(checkpoint.split('_')[-1].split('.')[0])
        del ckpt
    
    # Dump one real batch for visual reference ([-1,1] -> [0,1]).
    vutils.save_image( next(dataloader).add(1).mul(0.5), saved_image_folder+'/real_image.jpg' )
    
    for iteration in tqdm(range(current_iteration, total_iterations+1)):
        real_image = next(dataloader)
        real_image = real_image.to(device)
        current_batch_size = real_image.size(0)

        if args.use_infogan:
            # Keep the uniform code part ("latent") around separately:
            # it is the regression target for the Q-head loss below.
            noise = torch.Tensor(current_batch_size, nz).normal_(0, 1)
            latent = torch.Tensor(current_batch_size, args.infogan_latent_dim+spatial_infogan_size*args.spatial_code_dim).uniform_(-1, 1)
            noise = torch.cat([noise, latent], dim=1).to(device)
            latent = latent.to(device)
        else:
            noise = torch.Tensor(current_batch_size, nz).normal_(0, 1).to(device)

        fake_images = netG(noise)

        # Differentiable augmentation on both real and fake batches.
        real_image = DiffAugment(real_image, policy=policy)
        fake_images = [DiffAugment(fake, policy=policy) for fake in fake_images]
        
        ## 2. train Discriminator
        netD.zero_grad()

    
        err_dr = train_d( netD, real_image, label="real", decode=(not args.no_decode) ) #err on real data backproped
        if not args.no_decode:
            # With the decoder enabled, train_d also returns the
            # self-supervised reconstructions for the image dump below.
            err_dr, rec_img_all, rec_img_small, rec_img_part = err_dr
            

        train_d(netD, [fi.detach() for fi in fake_images], label="fake", use_infogan=args.use_infogan)
        optimizerD.step()
        
        ## 3. train Generator
        netG.zero_grad()
        if args.use_infogan:
            netD.zero_grad()
            pred_g, q_pred = netD(fake_images, "fake")
            if not args.spatial_code_dim:
                # Q-head predicts a Gaussian over the code: first half
                # is the mean, second half the log-variance.
                q_mu, q_logvar = q_pred[:, :args.infogan_latent_dim], q_pred[:, args.infogan_latent_dim:]
                info_total_loss = criterionQ_con(latent, q_mu, q_logvar.exp())
            else:
                q_pred, s_list = q_pred
                q_mu, q_logvar = q_pred[:, :args.infogan_latent_dim], q_pred[:, args.infogan_latent_dim:]
                info_loss = criterionQ_con(latent[:, :args.infogan_latent_dim], q_mu, q_logvar.exp())
                s_info_loss = 0
                # One regression loss per spatial quadrant; each slice of
                # "latent" after the global code belongs to one part.
                for part in range(4):
                    sq_mu, sq_logvar = s_list[part][:, :args.spatial_code_dim], s_list[part][:, args.spatial_code_dim:]
                    s_info_loss += criterionQ_con(latent[:, \
                        args.infogan_latent_dim+part*args.spatial_code_dim : \
                        args.infogan_latent_dim+(part+1)*args.spatial_code_dim], sq_mu, sq_logvar.exp())

                info_total_loss = s_info_loss/4 + info_loss

            # Generator loss: adversarial term plus weighted mutual-
            # information (latent regression) term; one backward pass
            # updates both G and the Q-head.
            err_g = info_total_loss*args.info_lambda - pred_g.mean()
            err_g.backward()
            optimizerG.step()
            optimizerQ.step()
        else:
            pred_g = netD(fake_images, "fake")
            err_g = -pred_g.mean()

            err_g.backward()
            optimizerG.step()

        # Update the EMA copy of the generator parameters.
        for p, avg_p in zip(netG.parameters(), avg_param_G):
            avg_p.mul_(0.999).add_(0.001 * p.data)

        if iteration % 100 == 0:
            if args.spatial_code_dim:
                print("GAN: loss d: %.5f    loss g: %.5f    loss info: %.5f    loss s info: %.5f"%(err_dr, -err_g.item(), -info_loss*args.info_lambda, -s_info_loss*args.info_lambda/4))
            elif args.infogan_latent_dim:
                print("GAN: loss d: %.5f    loss g: %.5f    loss info: %.5f"%(err_dr, -err_g.item(), -info_total_loss*args.info_lambda))
            else:
                print("GAN: loss d: %.5f    loss g: %.5f"%(err_dr, -err_g.item()))

        if iteration % (save_interval*10) == 0:
            # Generate samples with the EMA weights, then restore the
            # live weights before training continues.
            backup_para = copy_G_params(netG)
            load_params(netG, avg_param_G)
            with torch.no_grad():
                vutils.save_image(netG(fixed_noise)[0].add(1).mul(0.5), saved_image_folder+'/%d.jpg'%iteration, nrow=8)
                if args.use_infogan:
                    vutils.save_image(netG(traversal_z_1)[0].add(1).mul(0.5), saved_image_folder+'/trav1_%d.jpg'%iteration, nrow=num_steps+1)
                    vutils.save_image(netG(traversal_z_2)[0].add(1).mul(0.5), saved_image_folder+'/trav2_%d.jpg'%iteration, nrow=num_steps+1)
                    if args.spatial_code_dim:
                        vutils.save_image(netG(traversal_corner)[0].add(1).mul(0.5), saved_image_folder+'/trav_c_%d.jpg'%iteration, nrow=num_steps+1)
                    
            load_params(netG, backup_para)

        if iteration % (save_interval*50) == 0 or iteration == total_iterations:
            backup_para = copy_G_params(netG)
            # NOTE(review): the next two calls cancel each other out
            # (load EMA, immediately restore backup) — the checkpoint
            # below therefore saves the live weights, with the EMA kept
            # only in 'g_ema'. Looks like dead code; confirm intent.
            load_params(netG, avg_param_G)
            load_params(netG, backup_para)
            states = {'g':netG.state_dict(),
                        'd':netD.state_dict(),
                        'g_ema': avg_param_G,
                        'opt_g': optimizerG.state_dict(),
                        'opt_d': optimizerD.state_dict()}
            if args.use_infogan:
                states['opt_q'] = optimizerQ.state_dict()
            torch.save(states, saved_model_folder+'/all_%d.pth'%iteration)