Example #1
0
    def reconstruct_from_latent_presentation(input, generator, session):
        """Decode a batch of latent vectors back into images on the CPU.

        The latent batch is moved to the GPU, normalized, and its label
        component split off before being fed to the generator at the
        session's current phase/alpha. The result is detached from the
        autograd graph and returned as a CPU tensor.
        """
        latent = Variable(input).cuda()
        latent = utils.normalize(latent)
        latent, input_class = utils.split_labels_out_of_latent(latent)

        generated = generator(latent, input_class, session.phase,
                              session.alpha)
        # Detach from the graph and move to host memory before returning.
        return generated.detach().data.cpu()
Example #2
0
def D_prediction_of_G_output(generator, encoder, step, alpha):
    """Mean discriminator (encoder) score on a batch of generated images.

    Draws a fresh batch of random latents sized for the current phase,
    normalizes them, splits off the label component, generates fake images,
    and scores them with the encoder.

    Args:
        generator: network mapping (z, label, step, alpha) -> images.
        encoder: discriminator/encoder returning (prediction, extra).
        step: current progressive-growing phase (controls batch size).
        alpha: blending factor for the current phase.

    Returns:
        (loss, fake_image): mean of the encoder's predictions over the
        batch, and the generated images themselves.
    """
    # To use labels, enable here and elsewhere:
    #label = Variable(torch.ones(batch_size_by_phase(step), args.n_label)).cuda()
    #               label = Variable(
    #                    torch.multinomial(
    #                        torch.ones(args.n_label), args.batch_size, replacement=True)).cuda()

    # BUGFIX: `async` became a reserved keyword in Python 3.7 (SyntaxError
    # here); the Tensor.cuda() argument was renamed to `non_blocking` with
    # identical semantics.
    myz = Variable(torch.randn(batch_size_by_phase(step),
                               args.nz)).cuda(non_blocking=(args.gpu_count > 1))
    myz = utils.normalize(myz)
    myz, label = utils.split_labels_out_of_latent(myz)

    fake_image = generator(myz, label, step, alpha)
    fake_predict, _ = encoder(fake_image, step, alpha, args.use_ALQ)

    loss = fake_predict.mean()
    return loss, fake_image
Example #3
0
    def generate_intermediate_samples(generator,
                                      global_i,
                                      session,
                                      writer=None,
                                      collateImages=True):
        """Sample random latents from the generator and save PNGs to disk.

        With collateImages=True a single collated grid is written to
        '<args.save_dir>/sample/<global_i+1>.png' (and optionally logged to
        tensorboard), then rescaled to a fixed display size. Otherwise,
        args.sample_N individual images are written in batches of 128 under
        '../metrics/<data>/<reso>/<global_i>'.

        Args:
            generator: generator network; put into eval() for sampling and
                restored to train() before returning.
            global_i: global sample counter, used in output file names.
            session: training session providing phase and alpha.
            writer: optional tensorboard writer for the collated grid.
            collateImages: True -> one grid image; False -> many single files.
        """
        generator.eval()

        # Sampling only: no gradients needed through the generator.
        utils.requires_grad(generator, False)

        # Total number is samplesRepeatN * colN * rowN
        # e.g. for 51200 samples, outcome is 5*80*128. Please only give multiples of 128 here.
        samplesRepeatN = int(args.sample_N / 128) if not collateImages else 1
        reso = 4 * 2**session.phase  # output resolution of the current phase

        if not collateImages:
            special_dir = '../metrics/{}/{}/{}'.format(args.data, reso,
                                                       str(global_i).zfill(6))
            # Never overwrite an earlier run: append '_' until the path is free.
            while os.path.exists(special_dir):
                special_dir += '_'

            os.makedirs(special_dir)

        for outer_count in range(samplesRepeatN):

            # Grid shape: fixed 128x1 in file mode, otherwise capped by sample_N.
            colN = 1 if not collateImages else min(
                10, int(np.ceil(args.sample_N / 4.0)))
            rowN = 128 if not collateImages else min(
                5, int(np.ceil(args.sample_N / 4.0)))
            images = []
            for _ in range(rowN):
                # Fresh normalized latent per row; label split off separately.
                myz = utils.conditional_to_cuda(
                    Variable(torch.randn(args.n_label * colN, args.nz)))
                myz = utils.normalize(myz)
                myz, input_class = utils.split_labels_out_of_latent(myz)

                new_imgs = generator(myz, input_class, session.phase,
                                     session.alpha).detach().data.cpu()

                images.append(new_imgs)

            if collateImages:
                sample_dir = '{}/sample'.format(args.save_dir)
                if not os.path.exists(sample_dir):
                    os.makedirs(sample_dir)

                save_path = '{}/{}.png'.format(sample_dir,
                                               str(global_i + 1).zfill(6))
                # Generator outputs are presumably in [-1, 1] (range arg) —
                # normalize=True maps them to displayable [0, 1].
                torchvision.utils.save_image(torch.cat(images, 0),
                                             save_path,
                                             nrow=args.n_label * colN,
                                             normalize=True,
                                             range=(-1, 1),
                                             padding=0)
                # Hacky but this is an easy way to rescale the images to nice big lego format:
                im = Image.open(save_path)
                im2 = im.resize((1024, 512 if reso < 256 else 1024))
                im2.save(save_path)

                if writer:
                    writer.add_images(
                        'samples_latest_{}'.format(session.phase),
                        torch.cat(images, 0), session.phase)
                    # Igor: changed add_image to add_images
            else:
                # One PNG per generated row batch; index offset by the outer pass.
                for ii, img in enumerate(images):
                    torchvision.utils.save_image(
                        img,
                        '{}/{}_{}.png'.format(special_dir,
                                              str(global_i + 1).zfill(6),
                                              ii + outer_count * 128),
                        nrow=args.n_label * colN,
                        normalize=True,
                        range=(-1, 1),
                        padding=0)

        generator.train()
Example #4
0
    def generate_intermediate_samples(generator, global_i, session, writer=None, collateImages = True):
        """Sample random latents and save the generated images as PNGs.

        With collateImages=True, a single collated grid is written to
        '<save_root>/sample/<global_i+1>.png' and then rescaled to a fixed
        display size. Otherwise, individual PNGs are written in batches of
        128 under '<save_root>/sample/<global_i>', optionally filtered by
        encoder-based rejection sampling. Supports 2-latent style mixing
        when args.randomMixN > 1.

        Args:
            generator: generator network; eval() during sampling, restored
                to train() before returning.
            global_i: global sample counter used in output file names.
            session: training session (phase, alpha, getReso(), encoder).
            writer: unused here; the tensorboard call is commented out.
            collateImages: True -> one grid image; False -> many single files.
        """
        generator.eval()
        with torch.no_grad():

            # Prefer an explicit sample dir if given, else fall back to save dir.
            save_root = args.sample_dir if args.sample_dir != None else args.save_dir

            utils.requires_grad(generator, False)

            # Total number is samplesRepeatN * colN * rowN
            # e.g. for 51200 samples, outcome is 5*80*128. Please only give multiples of 128 here.
            samplesRepeatN = max(1, int(args.sample_N / 128)) if not collateImages else 1
            reso = session.getReso()

            if not collateImages:
                special_dir = '{}/sample/{}'.format(save_root, str(global_i).zfill(6))
                # Never overwrite an earlier run: append '_' until the path is free.
                while os.path.exists(special_dir):
                    special_dir += '_'

                os.makedirs(special_dir)

            for outer_count in range(samplesRepeatN):

                colN = 1 if not collateImages else min(10, int(np.ceil(args.sample_N / 4.0)))
                rowN = 128 if not collateImages else min(5, int(np.ceil(args.sample_N / 4.0)))
                images = []
                myzs = []
                for _ in range(rowN):
                    # Two independent normalized latents per row; the second is
                    # only used when style mixing (randomMixN > 1) is active.
                    myz = {}
                    for i in range(2):
                       myz0 = Variable(torch.randn(args.n_label * colN, args.nz+1)).to(device=args.device)
                       myz[i] = utils.normalize(myz0)
                       #myz[i], input_class = utils.split_labels_out_of_latent(myz0)

                    if args.randomMixN <= 1:
                        new_imgs = generator(
                            myz[0],
                            None, # input_class,
                            session.phase,
                            session.alpha).detach() #.data.cpu()
                    else:
                        # Style mixing: myz[0] drives layers [0, cut), myz[1]
                        # drives layers from the cut point onward.
                        max_i = session.phase
                        style_layer_begin = 2 #np.random.randint(low=1, high=(max_i+1))
                        print("Cut at layer {} for the next {} random samples.".format(style_layer_begin, (myz[0].shape)))

                        new_imgs = utils.gen_seq([ (myz[0], 0, style_layer_begin),
                                                   (myz[1], style_layer_begin, -1)
                                                 ], generator, session).detach()

                    images.append(new_imgs)
                    myzs.append(myz[0]) # TODO: Support mixed latents with rejection sampling

                if args.rejection_sampling > 0 and not collateImages:
                    # Keep only the top args.rejection_sampling percent of images
                    # whose re-encoded latent best matches (cosine) the input latent.
                    filtered_images = []
                    batch_size = 8
                    for b in range(int(len(images) / batch_size)):
                        img_batch = images[b*batch_size:(b+1)*batch_size]
                        reco_z = session.encoder(Variable(torch.cat(img_batch,0)), session.getResoPhase(), session.alpha, args.use_ALQ).detach()
                        cost_z = utils.mismatchV(torch.cat(myzs[b*batch_size:(b+1)*batch_size],0), reco_z, 'cos')
                        # Ascending sort: lowest mismatch first.
                        _, ii = torch.sort(cost_z)
                        keeper = args.rejection_sampling / 100.0
                        keepN = max(1, int(len(ii)*keeper))
                        for iindex in ii[:keepN]:
                            filtered_images.append(img_batch[iindex])
#                        filtered_images.append(img_batch[ii[:keepN]] ) #retain the best ones only
                    images = filtered_images

                if collateImages:
                    sample_dir = '{}/sample'.format(save_root)
                    if not os.path.exists(sample_dir):
                        os.makedirs(sample_dir)

                    save_path = '{}/{}.png'.format(sample_dir,str(global_i + 1).zfill(6))
                    torchvision.utils.save_image(
                        torch.cat(images, 0),
                        save_path,
                        nrow=args.n_label * colN,
                        normalize=True,
                        range=(-1, 1),
                        padding=0)
                    # Hacky but this is an easy way to rescale the images to nice big lego format:
                    im = Image.open(save_path)
                    im2 = im.resize((1024, 512 if reso < 256 else 1024))
                    im2.save(save_path)

                    #if writer:
                    #    writer.add_image('samples_latest_{}'.format(session.phase), torch.cat(images, 0), session.phase)
                else:
                    print("Generating samples at {}...".format(special_dir))
                    for ii, img in enumerate(images):
                        ipath =  '{}/{}_{}.png'.format(special_dir, str(global_i + 1).zfill(6), ii+outer_count*128)
                        #print(ipath)
                        torchvision.utils.save_image(
                            img,
                            ipath,
                            nrow=args.n_label * colN,
                            normalize=True,
                            range=(-1, 1),
                            padding=0)

        generator.train()
Example #5
0
    def interpolate_images(generator, encoder, loader, epoch, prefix, session, writer=None):
        """Encode real images and save latent-space interpolations between them.

        Default mode: four real "corner" images are encoded, and an 8x8 grid
        of spherical interpolations (slerp) between the four corner latents
        is generated; each cell is saved individually plus one collated grid
        whose first row holds the originals. In hexmode (args.hexmode), six
        corners are interpolated over a fixed 19-point hexagonal layout and
        the function returns early after saving those images.

        Args:
            generator: generator network; eval() during the run, train() after.
            encoder: encoder network; eval() during the run, train() after.
            loader: data loader supplying the real corner images.
            epoch: current epoch, used in output file names.
            prefix: path fragment appended to the output dir in file names.
            session: training session (phase, alpha, getReso(), ...).
            writer: unused; the tensorboard call is commented out.
        """
        generator.eval()
        encoder.eval()

        with torch.no_grad():
            utils.requires_grad(generator, False)
            utils.requires_grad(encoder, False)

            nr_of_imgs = 4 if not args.hexmode else 6 # "Corners"
            reso = session.getReso()
            if True:
            #if Utils.interpolation_set_x is None or Utils.interpolation_set_x.size(2) != reso or (phase >= 1 and alpha < 1.0):
                if session.getResoPhase() < 1:
                    dataset = data.Utils.sample_data(loader, nr_of_imgs, reso)
                else:
                    dataset = data.Utils.sample_data2(loader, nr_of_imgs, reso, session)
                real_image, _ = next(dataset)
                # NOTE(review): volatile=True is a pre-0.4 PyTorch flag; inside
                # torch.no_grad() it is presumably redundant — confirm torch version.
                Utils.interpolation_set_x = Variable(real_image, volatile=True).to(device=args.device)

            # Interpolation grid dimensions (cells per axis).
            latent_reso_hor = 8
            latent_reso_ver = 8

            x = Utils.interpolation_set_x

            if args.hexmode:
                x = x[:nr_of_imgs] #Corners
                z0 = encoder(Variable(x), session.getResoPhase(), session.alpha, args.use_ALQ).detach()

                # Fixed 19-point hexagonal layout: 6 corners, 6 edge midpoints,
                # remaining 7 interior points weighted by distance to corners.
                X = np.array([-1.7321,0.0000 ,1.7321 ,-2.5981,-0.8660,0.8660 ,2.5981 ,-3.4641,-1.7321,0.0000 ,1.7321 ,3.4641 ,-2.5981,-0.8660,0.8660 ,2.5981 ,-1.7321,0.0000,1.7321])
                Y = np.array([-3.0000,-3.0000,-3.0000,-1.5000,-1.5000,-1.5000,-1.5000,0.0000,0.0000,0.0000,0.0000,0.0000,1.5000,1.5000,1.5000,1.5000,3.0000,3.0000,3.0000])
                corner_indices = np.array([0, 2, 7, 11, 16, 18])
                inter_indices = [i for i in range(19) if not i in corner_indices]
                edge_indices = np.array([1, 3, 6, 12, 15, 17])
                #distances_to_corners = np.sqrt(np.power(X - X[corner_indices[0]], 2) + np.power(Y - Y[corner_indices[0]], 2))
                distances_to_corners = [np.sqrt(np.power(X - X[corner_indices[i]], 2) + np.power(Y - Y[corner_indices[i]], 2)) for i in range(6)]
                weights = 1.0 - distances_to_corners / np.max(distances_to_corners)
                z0_x = torch.zeros((19,args.nz+args.n_label)).to(device=args.device)
                z0_x[corner_indices,:] = z0
                # Edge midpoints: spherical interpolation between adjacent corners.
                z0_x[edge_indices[0], :] = Utils.slerp(z0[0], z0[1], 0.5)
                z0_x[edge_indices[1], :] = Utils.slerp(z0[0], z0[2], 0.5)
                z0_x[edge_indices[2], :] = Utils.slerp(z0[1], z0[3], 0.5)
                z0_x[edge_indices[3], :] = Utils.slerp(z0[2], z0[4], 0.5)
                z0_x[edge_indices[4], :] = Utils.slerp(z0[3], z0[5], 0.5)
                z0_x[edge_indices[5], :] = Utils.slerp(z0[4], z0[5], 0.5)
                # Linear:
                #z0x[inter_indices,:] = (weights.T @ z0)[inter_indices,:]
                z0_x[inter_indices,:] = (torch.from_numpy(weights.T.astype(np.float32)).to(device=args.device) @ z0)[inter_indices,:]
                z0_x = utils.normalize(z0_x)
            else:
                z0 = encoder(Variable(x), session.getResoPhase(), session.alpha, args.use_ALQ).detach()

            # Collated output tensor: originals first, then the full grid.
            t = torch.FloatTensor(latent_reso_hor * (latent_reso_ver+1) + nr_of_imgs, x.size(1),
                                x.size(2), x.size(3))
            t[0:nr_of_imgs] = x.data[:]

            save_root = args.sample_dir if args.sample_dir != None else args.save_dir
            special_dir = save_root if not args.aux_outpath else args.aux_outpath

            if not os.path.exists(special_dir):
                os.makedirs(special_dir)

            # Save the original (un-interpolated) corner images individually.
            for o_i in range(nr_of_imgs):
                single_save_path = '{}{}/hex_interpolations_{}_{}_{}_orig_{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha, o_i)
                grid = torchvision.utils.save_image(x.data[o_i] / 2 + 0.5, single_save_path, nrow=1, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?

            if args.hexmode:
                z0_x, label = utils.split_labels_out_of_latent(z0_x)
                gex = generator(z0_x, label, session.phase, session.alpha).detach()

                for x_i in range(19):
                    single_save_path = '{}{}/hex_interpolations_{}_{}_{}x{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha, x_i)
                    grid = torchvision.utils.save_image(gex.data[x_i] / 2 + 0.5, single_save_path, nrow=1, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?

                return

            # Origs on the first row here
            # Corners are: z0[0] ... z0[1]
            #                .   
            #                .
            #              z0[2] ... z0[3]                

            delta_z_ver0 = ((z0[2] - z0[0]) / (latent_reso_ver - 1))
            delta_z_verN = ((z0[3] - z0[1]) / (latent_reso_ver - 1))

            for y_i in range(latent_reso_ver):
                if False: #Linear interpolation
                    z0_x0 = z0[0] + y_i * delta_z_ver0
                    z0_xN = z0[1] + y_i * delta_z_verN
                    delta_z_hor = (z0_xN - z0_x0) / (latent_reso_hor - 1)
                    z0_x = Variable(torch.FloatTensor(latent_reso_hor, z0_x0.size(0)))

                    for x_i in range(latent_reso_hor):
                        z0_x[x_i] = z0_x0 + x_i * delta_z_hor

                if True: #Spherical
                    # Slerp vertically between the left and right corner pairs,
                    # then slerp horizontally along the resulting row endpoints.
                    t_y = float(y_i) / (latent_reso_ver-1)
                    #z0_y = Variable(torch.FloatTensor(latent_reso_ver, z0.size(0)))

                    z0_y1 = Utils.slerp(z0[0].data.cpu().numpy(), z0[2].data.cpu().numpy(), t_y)
                    z0_y2 = Utils.slerp(z0[1].data.cpu().numpy(), z0[3].data.cpu().numpy(), t_y)
                    z0_x = Variable(torch.FloatTensor(latent_reso_hor, z0[0].size(0)))
                    for x_i in range(latent_reso_hor):
                        t_x = float(x_i) / (latent_reso_hor-1)
                        z0_x[x_i] = torch.from_numpy( Utils.slerp(z0_y1, z0_y2, t_x) )

                z0_x, label = utils.split_labels_out_of_latent(z0_x)
                gex = generator(z0_x, label, session.phase, session.alpha).detach()

                # Recall that yi=0 is the original's row:
                t[(y_i+1) * latent_reso_ver:(y_i+2)* latent_reso_ver] = gex.data[:]

                for x_i in range(latent_reso_hor):
                    single_save_path = '{}{}/interpolations_{}_{}_{}_{}x{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha, y_i, x_i)
                    grid = torchvision.utils.save_image(gex.data[x_i] / 2 + 0.5, single_save_path, nrow=1, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?

            save_path = '{}{}/interpolations_{}_{}_{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha)
            grid = torchvision.utils.save_image(t / 2 + 0.5, save_path, nrow=latent_reso_ver, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?
            # Hacky but this is an easy way to rescale the images to nice big lego format:
            if session.getResoPhase() < 4:
                im = Image.open(save_path)
                im2 = im.resize((1024, 1024))
                im2.save(save_path)

            #if writer:
            #    writer.add_image('interpolation_latest_{}'.format(session.phase), t / 2 + 0.5, session.phase)

        generator.train()
        encoder.train()
Example #6
0
    def eval_metrics(generator, encoder, loader, global_i, nr_of_imgs, prefix, reals, reconstructions, session,
                           writer=None):  # of the form"/[dir]"
        """Evaluate LPIPS (real vs. reconstruction) and FID (real vs. random).

        Only runs the metric computation when reconstructions is truthy,
        nr_of_imgs > 0, and session.getResoPhase() >= 3. LPIPS is computed
        over n_batches of real/reconstructed pairs; FID uses FIDm (=3) times
        as many random generator samples as real batches. Both scores are
        written to tensorboard and printed.

        Args:
            generator: generator network; eval() during the run, train() after.
            encoder: encoder network; eval() during the run, train() after.
            loader: data loader for real images.
            global_i: global sample counter (unused in this body).
            nr_of_imgs: number of images to evaluate on.
            prefix: unused in this body.
            reals: unused in this body.
            reconstructions: gate flag for the whole metric computation.
            session: training session (phase, alpha, reso, sample_i).
            writer: tensorboard writer; required when the metrics branch runs.
        """
        generator.eval()
        encoder.eval()

        with torch.no_grad():
            utils.requires_grad(generator, False)
            utils.requires_grad(encoder, False)

            if reconstructions and nr_of_imgs > 0:
                reso = session.getReso()

                # Round image count up to a whole number of batches.
                batch_size = 16
                n_batches = (nr_of_imgs+batch_size-1) // batch_size
                n_used_imgs = n_batches * batch_size


                #fid_score = 0
                lpips_score = 0
                lpips_model = lpips.PerceptualLoss(model='net-lin', net='alex', use_gpu=False)
                real_images = []
                reco_images = []
                rand_images = []

                FIDm = 3 # FID multiplier

                if session.getResoPhase() >= 3:
                    # Random samples run for n_batches*FIDm iterations; real
                    # and reconstructed batches only for the first n_batches.
                    for o in range(n_batches*FIDm):
                        myz = Variable(torch.randn(args.n_label * batch_size, args.nz)).to(device=args.device)
                        myz = utils.normalize(myz)
                        myz, input_class = utils.split_labels_out_of_latent(myz)

                        random_image = generator(
                                myz,
                                input_class,
                                session.phase,
                                session.alpha).detach().data.cpu()

                        rand_images.append(random_image)

                        if o >= n_batches: #Reconstructions only up to n_batches, FIDs will take n_batches * 3 samples
                             continue

                        dataset = data.Utils.sample_data2(loader, batch_size, reso, session)

                        real_image, _ = next(dataset)
                        reco_image = Utils.reconstruct(real_image, encoder, generator, session)

                        t = torch.FloatTensor(real_image.size(0) * 2, real_image.size(1),
                                              real_image.size(2), real_image.size(3))

                        # compute metrics and write it to tensorboard (if the reso >= 32)


                        #lpips_score = "?"

                        # Face datasets are center-cropped before LPIPS comparison.
                        crop_needed = (args.data == 'celeba' or args.data == 'celebaHQ' or args.data == 'ffhq')

                        if session.getResoPhase() >= 4 or not crop_needed: # 32x32 is minimum for LPIPS
                            if crop_needed:
                                real_image = Utils.face_as_cropped(real_image)
                                reco_image = Utils.face_as_cropped(reco_image)

                        real_images.append(real_image)
                        reco_images.append(reco_image.detach().data.cpu())

                    real_images = torch.cat(real_images, 0)
                    reco_images = torch.cat(reco_images, 0)
                    rand_images = torch.cat(rand_images, 0)


                    # LPIPS: perceptual distance between reals and reconstructions.
                    lpips_dist = lpips_model.forward(real_images, reco_images)
                    lpips_score = torch.mean(lpips_dist)

                    # FID: distribution distance between reals and random samples.
                    fid_score = calculate_fid_given_images(real_images, rand_images, batch_size=batch_size, cuda=False, dims=2048)

                    writer.add_scalar('LPIPS', lpips_score, session.sample_i)
                    writer.add_scalar('FID', fid_score, session.sample_i)

                    print("{}: FID = {}, LPIPS = {}".format(session.sample_i, fid_score, lpips_score))

        encoder.train()
        generator.train()