Exemple #1
0
    def reconstruction_dryrun(generator, encoder, loader, session):
        """Run repeated no-grad encode/decode passes on one batch to warm up.

        Puts both networks into eval mode with gradients disabled, pushes the
        same 4-image batch through encoder -> generator 200 times, and then
        restores both networks to train mode. No results are kept.
        """
        generator.eval()
        encoder.eval()

        utils.requires_grad(generator, False)
        utils.requires_grad(encoder, False)

        reso = 4 * 2**session.phase

        warmup_rounds = 200
        print('Warm-up rounds: {}'.format(warmup_rounds))

        # Early phases use the plain sampler; later phases need session state.
        if session.phase < 1:
            dataset = data.Utils.sample_data(loader, 4, reso)
        else:
            dataset = data.Utils.sample_data2(loader, 4, reso, session)
        real_image, _ = next(dataset)
        x = utils.conditional_to_cuda(Variable(real_image))

        for _ in range(warmup_rounds):
            latent = encoder(x, session.phase, session.alpha,
                             args.use_ALQ).detach()
            latent, label = utils.split_labels_out_of_latent(latent)
            generator(latent, label, session.phase, session.alpha).detach()

        encoder.train()
        generator.train()
Exemple #2
0
 def reconstruct(input_image, encoder, generator, session):
     """Encode an image and decode it back; return the reconstruction tensor."""
     with torch.no_grad():
         latent = encoder(Variable(input_image), session.phase, session.alpha,
                          args.use_ALQ).detach()
         latent, label = utils.split_labels_out_of_latent(latent)
         recon = generator(latent, label, session.phase,
                           session.alpha).detach()
     return recon.data[:]
Exemple #3
0
 def extract_latent_presentation_from_image(input_image, encoder, session):
     """Encode an image into its latent representation (label entries stripped).

     Fix: `Variable(..., volatile=True)` has been a no-op since PyTorch 0.4;
     use `torch.no_grad()` instead (as the sibling `reconstruct` helper in
     this file already does).
     """
     with torch.no_grad():
         ex = encoder(Variable(input_image), session.phase,
                      session.alpha, args.use_ALQ).detach()
         ex, label = utils.split_labels_out_of_latent(ex)
     return ex
Exemple #4
0
                    def KL_of_encoded_G_output(generator, z):
                        """Sample z, generate fakes, re-encode them, and return
                        (encoded latents, label, scaled KL loss, z)."""
                        utils.populate_z(z, args.nz + args.n_label, args.noise,
                                         batch_size(reso))
                        z, label = utils.split_labels_out_of_latent(z)
                        fake_imgs = generator(z, label, session.phase,
                                              session.alpha)

                        egz = encoder(fake_imgs, session.phase, session.alpha,
                                      args.use_ALQ)
                        # KL_fake: \Delta( e(g(Z)) , Z ) -> min_g
                        kl = KL_minimizer(egz) * args.fake_G_KL_scale
                        return egz, label, kl, z
Exemple #5
0
    def reconstruct_from_latent_presentation(input, generator, session):
        """Decode a latent (with appended label entries) into a CPU image tensor.

        Note: the parameter is named `input` (shadowing the builtin) to stay
        keyword-compatible with existing callers.
        """
        z = Variable(input).cuda()
        z = utils.normalize(z)
        z, input_class = utils.split_labels_out_of_latent(z)

        decoded = generator(z, input_class, session.phase,
                            session.alpha).detach().data.cpu()
        return decoded
Exemple #6
0
def D_prediction_of_G_output(generator, encoder, step, alpha):
    """Generate a random fake batch and return the encoder's mean score on it.

    Returns:
        (loss, fake_image) where loss is the mean discriminator prediction
        for the generated images and fake_image is the generated batch.
    """
    # Fix: `.cuda(async=...)` is a SyntaxError on Python 3.7+ (`async` became
    # a reserved word); `non_blocking=` is the supported spelling, already
    # used elsewhere in this file.
    myz = Variable(torch.randn(batch_size_by_phase(step),
                               args.nz)).cuda(non_blocking=(args.gpu_count > 1))
    myz = utils.normalize(myz)
    myz, label = utils.split_labels_out_of_latent(myz)

    fake_image = generator(myz, label, step, alpha)
    fake_predict, _ = encoder(fake_image, step, alpha, args.use_ALQ)

    loss = fake_predict.mean()
    return loss, fake_image
Exemple #7
0
    def make(generator, session, writer):
        """Decode latents loaded from a bcolz array into PNG files on disk.

        Reads latent vectors from args.z_inpath, decodes them with the
        (gradient-frozen) generator in chunks of 8, and writes one PNG per
        sample into args.x_outpath.
        """
        import bcolz
        if not os.path.exists(args.x_outpath):
            os.makedirs(args.x_outpath)

        utils.requires_grad(generator, False)

        z = bcolz.carray(rootdir=args.z_inpath)
        print("Loading external z from {}, shape:".format(args.z_inpath))
        print(np.shape(z))
        N = np.shape(z)[0]

        # Assume large files: process at most 8 samples per pass.
        chunk = 8
        n_chunks = 1 + int(N / chunk)
        done = 0
        for outer_count in range(n_chunks):
            take = min(chunk, N - done)
            if take <= 0:
                break
            batch = z[done:done + take, :]
            myz, input_class = utils.split_labels_out_of_latent(
                torch.from_numpy(batch.astype(np.float32)))
            new_imgs = generator(myz, input_class, session.phase,
                                 session.alpha).detach()

            for ii, img in enumerate(new_imgs):
                torchvision.utils.save_image(
                    img,
                    '{}/{}.png'.format(args.x_outpath,
                                       str(ii + outer_count * chunk)),
                    nrow=args.n_label,
                    normalize=True,
                    range=(-1, 1),
                    padding=0)

            done += take
Exemple #8
0
def train(generator, encoder, g_running, train_data_loader, test_data_loader,
          session, total_steps, train_mode):
    """Main progressive-growing training loop.

    Cycles through fade-in and stabilization stages (phase/alpha maintenance),
    alternating encoder ("discriminator") updates and generator updates, with
    periodic checkpointing and evaluation runs. `train_mode` selects between
    a plain GAN objective (config.MODE_GAN) and the cyclic/autoencoding
    objective (config.MODE_CYCLIC). Loops until `session.sample_i` reaches
    `total_steps`.
    """
    pbar = tqdm(initial=session.sample_i, total=total_steps)

    benchmarking = False

    match_x = args.match_x
    generatedImagePool = None

    refresh_dataset = True
    refresh_imagePool = True

    # After the Loading stage, we cycle through successive Fade-in and Stabilization stages

    batch_count = 0

    reset_optimizers_on_phase_start = False

    # TODO Unhack this (only affects the episode count statistics anyway):
    if args.data != 'celebaHQ':
        epoch_len = len(train_data_loader(1, 4).dataset)
    else:
        epoch_len = train_data_loader._len['data4x4']

    if args.step_offset != 0:
        if args.step_offset == -1:
            args.step_offset = session.sample_i
        print("Step offset is {}".format(args.step_offset))
        session.phase += args.phase_offset
        session.alpha = 0.0

    while session.sample_i < total_steps:
        #######################  Phase Maintenance #######################

        steps_in_previous_phases = max(session.phase * args.images_per_stage,
                                       args.step_offset)

        sample_i_current_stage = session.sample_i - steps_in_previous_phases

        # If we can move to the next phase
        if sample_i_current_stage >= args.images_per_stage:
            if session.phase < args.max_phase:  # If any phases left
                iteration_levels = int(sample_i_current_stage /
                                       args.images_per_stage)
                session.phase += iteration_levels
                sample_i_current_stage -= iteration_levels * args.images_per_stage
                match_x = args.match_x  # Reset to non-matching phase
                print(
                    "iteration B alpha={} phase {} will be reduced to 1 and [max]"
                    .format(sample_i_current_stage, session.phase))

                refresh_dataset = True
                refresh_imagePool = True  # Reset the pool to avoid images of 2 different resolutions in the pool

                if reset_optimizers_on_phase_start:
                    utils.requires_grad(generator)
                    utils.requires_grad(encoder)
                    generator.zero_grad()
                    encoder.zero_grad()
                    session.reset_opt()
                    print("Optimizers have been reset.")

        # Current training resolution; doubles with every phase.
        reso = 4 * 2**session.phase

        # If we can switch from fade-training to stable-training
        if sample_i_current_stage >= args.images_per_stage / 2:
            if session.alpha < 1.0:
                refresh_dataset = True  # refresh dataset generator since no longer have to fade
            match_x = args.match_x * args.matching_phase_x
        else:
            match_x = args.match_x

        session.alpha = min(
            1, sample_i_current_stage * 2.0 / args.images_per_stage
        )  # For 100k, it was 0.00002 = 2.0 / args.images_per_stage

        if refresh_dataset:
            train_dataset = data.Utils.sample_data2(train_data_loader,
                                                    batch_size(reso), reso,
                                                    session)
            refresh_dataset = False
            print("Refreshed dataset. Alpha={} and iteration={}".format(
                session.alpha, sample_i_current_stage))
        if refresh_imagePool:
            imagePoolSize = 200 if reso < 256 else 100
            generatedImagePool = utils.ImagePool(
                imagePoolSize
            )  #Reset the pool to avoid images of 2 different resolutions in the pool
            refresh_imagePool = False
            print('Image pool created with size {} because reso is {}'.format(
                imagePoolSize, reso))

        ####################### Training init #######################

        z = utils.conditional_to_cuda(
            Variable(torch.FloatTensor(batch_size(reso), args.nz, 1, 1)),
            args.gpu_count > 1)
        KL_minimizer = KLN01Loss(direction=args.KL, minimize=True)
        KL_maximizer = KLN01Loss(direction=args.KL, minimize=False)

        stats = {}

        #one = torch.FloatTensor([1]).cuda(async=(args.gpu_count>1))
        one = torch.FloatTensor([1]).cuda(non_blocking=(args.gpu_count > 1))

        try:
            real_image, _ = next(train_dataset)
        except (OSError, StopIteration):
            # Iterator exhausted (or transient I/O failure): rebuild and retry once.
            train_dataset = data.Utils.sample_data2(train_data_loader,
                                                    batch_size(reso), reso,
                                                    session)
            real_image, _ = next(train_dataset)

        ####################### DISCRIMINATOR / ENCODER ###########################

        utils.switch_grad_updates_to_first_of(encoder, generator)
        encoder.zero_grad()

        #x = Variable(real_image).cuda(async=(args.gpu_count>1))
        x = Variable(real_image).cuda(non_blocking=(args.gpu_count > 1))
        kls = ""
        if train_mode == config.MODE_GAN:

            # Discriminator for real samples
            real_predict, _ = encoder(x, session.phase, session.alpha,
                                      args.use_ALQ)
            real_predict = real_predict.mean() \
                - 0.001 * (real_predict ** 2).mean()
            real_predict.backward(-one)  # Towards 1

            # (1) Generator => D. Identical to (2) see below

            fake_predict, fake_image = D_prediction_of_G_output(
                generator, encoder, session.phase, session.alpha)
            fake_predict.backward(one)

            # Grad penalty

            grad_penalty = get_grad_penalty(encoder, x, fake_image,
                                            session.phase, session.alpha)
            grad_penalty.backward()

        elif train_mode == config.MODE_CYCLIC:
            e_losses = []

            # e(X)

            real_z = encoder(x, session.phase, session.alpha, args.use_ALQ)
            if args.use_real_x_KL:
                # KL_real: - \Delta( e(X) , Z ) -> max_e
                KL_real = KL_minimizer(real_z) * args.real_x_KL_scale
                e_losses.append(KL_real)

                stats['real_mean'] = KL_minimizer.samples_mean.data.mean()
                stats['real_var'] = KL_minimizer.samples_var.data.mean()
                stats['KL_real'] = KL_real  #FIXME .data[0]
                kls = "{0:.3f}".format(stats['KL_real'])

            # The final entries are the label. Normal case, just 1. Extract it/them, and make it [b x 1]:

            real_z, label = utils.split_labels_out_of_latent(real_z)
            recon_x = generator(real_z, label, session.phase, session.alpha)
            if args.use_loss_x_reco:
                # match_x: E_x||g(e(x)) - x|| -> min_e
                err = utils.mismatch(recon_x, x, args.match_x_metric) * match_x
                e_losses.append(err)
                stats['x_reconstruction_error'] = err  #FIXME .data[0]

            # NOTE(review): WGAN-GP penalty is force-disabled here every
            # iteration, making the `if args.use_wpgan_grad_penalty:` branches
            # below dead code — confirm this is intentional.
            args.use_wpgan_grad_penalty = False
            grad_penalty = 0.0

            if args.use_loss_fake_D_KL:
                # TODO: The following codeblock is essentially the same as the KL_minimizer part on G side. Unify
                utils.populate_z(z, args.nz + args.n_label, args.noise,
                                 batch_size(reso))
                z, label = utils.split_labels_out_of_latent(z)
                fake = generator(z, label, session.phase,
                                 session.alpha).detach()

                if session.alpha >= 1.0:
                    fake = generatedImagePool.query(fake.data)

                # e(g(Z))
                egz = encoder(fake, session.phase, session.alpha, args.use_ALQ)

                # KL_fake: \Delta( e(g(Z)) , Z ) -> max_e
                KL_fake = KL_maximizer(egz) * args.fake_D_KL_scale

                # Added by Igor
                m = args.kl_margin
                if m > 0.0 and session.phase >= 5:
                    KL_loss = torch.min(
                        torch.ones_like(KL_fake) * m / 2,
                        torch.max(-torch.ones_like(KL_fake) * m,
                                  KL_real + KL_fake))
                    # KL_fake is always negative with abs value typically larger than KL_real.
                    # Hence, the sum is negative, and must be gapped so that the minimum is the negative of the margin.
                else:
                    KL_loss = KL_real + KL_fake
                e_losses.append(KL_loss)
                #e_losses.append(KL_fake)

                stats['fake_mean'] = KL_maximizer.samples_mean.data.mean()
                stats['fake_var'] = KL_maximizer.samples_var.data.mean()
                stats['KL_fake'] = -KL_fake  #FIXME .data[0]
                kls = "{0}/{1:.3f}".format(kls, stats['KL_fake'])

                if args.use_wpgan_grad_penalty:
                    grad_penalty = get_grad_penalty(encoder, x, fake,
                                                    session.phase,
                                                    session.alpha)

            # Update e
            if len(e_losses) > 0:
                e_loss = sum(e_losses)
                stats['E_loss'] = e_loss.detach().cpu().float().numpy(
                )  #np.float32(e_loss.data)
                e_loss.backward()

                if args.use_wpgan_grad_penalty:
                    grad_penalty.backward()
                    stats['Grad_penalty'] = grad_penalty.data

                #book-keeping
                disc_loss_val = e_loss  #FIXME .data[0]

        session.optimizerD.step()

        torch.cuda.empty_cache()

        ######################## GENERATOR / DECODER #############################

        # The generator is updated only every n_critic-th batch.
        if (batch_count + 1) % args.n_critic == 0:
            utils.switch_grad_updates_to_first_of(generator, encoder)

            for _ in range(args.n_generator):
                generator.zero_grad()
                g_losses = []

                if train_mode == config.MODE_GAN:
                    fake_predict, _ = D_prediction_of_G_output(
                        generator, encoder, session.phase, session.alpha)
                    loss = -fake_predict
                    g_losses.append(loss)

                elif train_mode == config.MODE_CYCLIC:  #TODO We push the z variable around here like idiots

                    def KL_of_encoded_G_output(generator, z):
                        # Sample z, generate fakes, re-encode and return the
                        # (encoded, label, scaled KL, z) tuple for the G loss.
                        utils.populate_z(z, args.nz + args.n_label, args.noise,
                                         batch_size(reso))
                        z, label = utils.split_labels_out_of_latent(z)
                        fake = generator(z, label, session.phase,
                                         session.alpha)

                        egz = encoder(fake, session.phase, session.alpha,
                                      args.use_ALQ)
                        # KL_fake: \Delta( e(g(Z)) , Z ) -> min_g
                        return egz, label, KL_minimizer(
                            egz) * args.fake_G_KL_scale, z

                    egz, label, kl, z = KL_of_encoded_G_output(generator, z)

                    if args.use_loss_KL_z:
                        g_losses.append(kl)  # G minimizes this KL
                        stats['KL(Phi(G))'] = kl  #FIXME .data[0]
                        kls = "{0}/{1:.3f}".format(kls, stats['KL(Phi(G))'])

                    if args.use_loss_z_reco:
                        z = torch.cat((z, label), 1)
                        z_diff = utils.mismatch(
                            egz, z, args.match_z_metric
                        ) * args.match_z  # G tries to make the original z and encoded z match
                        g_losses.append(z_diff)

                if len(g_losses) > 0:
                    loss = sum(g_losses)
                    stats['G_loss'] = loss.detach().cpu().float().numpy(
                    )  #np.float32(loss.data)
                    loss.backward()

                    # Book-keeping only:
                    gen_loss_val = loss  #FIXME .data[0]

                session.optimizerG.step()

                torch.cuda.empty_cache()

                if train_mode == config.MODE_CYCLIC:
                    if args.use_loss_z_reco:
                        stats[
                            'z_reconstruction_error'] = z_diff  #FIXME .data[0]

            # Update the exponential moving average copy of the generator.
            accumulate(g_running, generator)

            # Free per-iteration tensors before logging/statistics.
            del z, x, one, real_image, real_z, KL_real, label, recon_x, fake, egz, KL_fake, kl, z_diff

            if train_mode == config.MODE_CYCLIC:
                if args.use_TB:
                    for key, val in stats.items():
                        writer.add_scalar(key, val, session.sample_i)
                elif batch_count % 100 == 0:
                    print(stats)

            if args.use_TB:
                writer.add_scalar('LOD', session.phase + session.alpha,
                                  session.sample_i)

        ########################  Statistics ########################

        b = batch_size_by_phase(session.phase)
        zr, xr = (stats['z_reconstruction_error'],
                  stats['x_reconstruction_error']
                  ) if train_mode == config.MODE_CYCLIC else (0.0, 0.0)
        e = (session.sample_i / float(epoch_len))
        pbar.set_description((
            '{0}; it: {1}; phase: {2}; b: {3:.1f}; Alpha: {4:.3f}; Reso: {5}; E: {6:.2f}; KL(real/fake/fakeG): {7}; z-reco: {8:.2f}; x-reco {9:.3f}; real_var {10:.4f}'
        ).format(batch_count + 1, session.sample_i + 1, session.phase, b,
                 session.alpha, reso, e, kls, zr, xr, stats['real_var']))
        #(f'{i + 1}; it: {iteration+1}; b: {b:.1f}; G: {gen_loss_val:.5f}; D: {disc_loss_val:.5f};'
        # f' Grad: {grad_loss_val:.5f}; Alpha: {alpha:.3f}; Reso: {reso}; S-mean: {real_mean:.3f}; KL(real/fake/fakeG): {kls}; z-reco: {zr:.2f}'))

        pbar.update(batch_size(reso))
        session.sample_i += batch_size(reso)  # if not benchmarking else 100
        batch_count += 1

        ########################  Saving ########################

        if batch_count % args.checkpoint_cycle == 0:
            for postfix in {'latest', str(session.sample_i).zfill(6)}:
                session.save_all('{}/{}_state'.format(args.checkpoint_dir,
                                                      postfix))

            print("Checkpointed to {}".format(session.sample_i))

        ########################  Tests ########################

        try:
            evaluate.tests_run(
                g_running,
                encoder,
                test_data_loader,
                session,
                writer,
                reconstruction=(batch_count % 800 == 0),
                interpolation=(batch_count % 800 == 0),
                collated_sampling=(batch_count % 800 == 0),
                individual_sampling=(
                    batch_count %
                    (args.images_per_stage / batch_size(reso) / 4) == 0))
        except (OSError, StopIteration):
            print("Skipped periodic tests due to an exception.")

    pbar.close()
Exemple #9
0
    def interpolate_images(generator,
                           encoder,
                           loader,
                           epoch,
                           prefix,
                           session,
                           writer=None):
        """Encode 4 real "corner" images and save an interpolation grid.

        Spherically interpolates (slerp) between the four encoded latents on
        an 8x8 grid, decodes every grid point with the generator, saves each
        tile and the collated grid to disk, and optionally logs the grid to
        TensorBoard. Restores both networks to train mode before returning.

        Fix: `Variable(..., volatile=True)` has been a no-op since
        PyTorch 0.4 and is dropped; no gradients flow here anyway since both
        networks are frozen above and the encoder output is detached.
        """
        generator.eval()
        encoder.eval()

        utils.requires_grad(generator, False)
        utils.requires_grad(encoder, False)

        nr_of_imgs = 4  # "Corners"
        reso = 4 * 2**session.phase
        if True:
            #if Utils.interpolation_set_x is None or Utils.interpolation_set_x.size(2) != reso or (phase >= 1 and alpha < 1.0):
            if session.phase < 1:
                dataset = data.Utils.sample_data(loader, nr_of_imgs, reso)
            else:
                dataset = data.Utils.sample_data2(loader, nr_of_imgs, reso,
                                                  session)
            real_image, _ = next(dataset)
            Utils.interpolation_set_x = utils.conditional_to_cuda(
                Variable(real_image))

        latent_reso_hor = 8
        latent_reso_ver = 8

        x = Utils.interpolation_set_x

        z0 = encoder(Variable(x), session.phase, session.alpha,
                     args.use_ALQ).detach()

        # Room for the decoded grid plus the row of original corner images.
        t = torch.FloatTensor(
            latent_reso_hor * (latent_reso_ver + 1) + nr_of_imgs, x.size(1),
            x.size(2), x.size(3))
        t[0:nr_of_imgs] = x.data[:]

        special_dir = args.save_dir if not args.aux_outpath else args.aux_outpath

        if not os.path.exists(special_dir):
            os.makedirs(special_dir)

        # Save the original corner images individually.
        for o_i in range(nr_of_imgs):
            single_save_path = '{}{}/interpolations_{}_{}_{}_orig_{}.png'.format(
                special_dir, prefix, session.phase, epoch, session.alpha, o_i)
            grid = torchvision.utils.save_image(
                x.data[o_i] / 2 + 0.5, single_save_path, nrow=1, padding=0
            )  #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?

        # Origs on the first row here
        # Corners are: z0[0] ... z0[1]
        #                .
        #                .
        #              z0[2] ... z0[3]

        delta_z_ver0 = ((z0[2] - z0[0]) / (latent_reso_ver - 1))
        delta_z_verN = ((z0[3] - z0[1]) / (latent_reso_ver - 1))
        for y_i in range(latent_reso_ver):
            if False:  #Linear interpolation
                z0_x0 = z0[0] + y_i * delta_z_ver0
                z0_xN = z0[1] + y_i * delta_z_verN
                delta_z_hor = (z0_xN - z0_x0) / (latent_reso_hor - 1)
                z0_x = Variable(
                    torch.FloatTensor(latent_reso_hor, z0_x0.size(0)))

                for x_i in range(latent_reso_hor):
                    z0_x[x_i] = z0_x0 + x_i * delta_z_hor

            if True:  #Spherical
                t_y = float(y_i) / (latent_reso_ver - 1)
                #z0_y = Variable(torch.FloatTensor(latent_reso_ver, z0.size(0)))
                z0_y1 = Utils.slerp(z0[0].data, z0[2].data, t_y)
                z0_y2 = Utils.slerp(z0[1].data, z0[3].data, t_y)
                z0_x = Variable(
                    torch.FloatTensor(latent_reso_hor, z0[0].size(0)))
                for x_i in range(latent_reso_hor):
                    t_x = float(x_i) / (latent_reso_hor - 1)
                    z0_x[x_i] = Utils.slerp(z0_y1, z0_y2, t_x)

            z0_x, label = utils.split_labels_out_of_latent(z0_x)
            gex = generator(z0_x, label, session.phase, session.alpha).detach()

            # Recall that yi=0 is the original's row:
            t[(y_i + 1) * latent_reso_ver:(y_i + 2) *
              latent_reso_ver] = gex.data[:]

            for x_i in range(latent_reso_hor):
                single_save_path = '{}{}/interpolations_{}_{}_{}_{}x{}.png'.format(
                    special_dir, prefix, session.phase, epoch, session.alpha,
                    y_i, x_i)
                grid = torchvision.utils.save_image(
                    gex.data[x_i] / 2 + 0.5,
                    single_save_path,
                    nrow=1,
                    padding=0
                )  #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?

        save_path = '{}{}/interpolations_{}_{}_{}.png'.format(
            special_dir, prefix, session.phase, epoch, session.alpha)
        grid = torchvision.utils.save_image(
            t / 2 + 0.5, save_path, nrow=latent_reso_ver, padding=0
        )  #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?
        # Hacky but this is an easy way to rescale the images to nice big lego format:
        if session.phase < 4:
            im = Image.open(save_path)
            im2 = im.resize((1024, 1024))
            im2.save(save_path)

        if writer:
            writer.add_images('interpolation_latest_{}'.format(session.phase),
                              t / 2 + 0.5, session.phase)
            # Igor: changed add_image to add_images

        generator.train()
        encoder.train()
Exemple #10
0
    def generate_intermediate_samples(generator,
                                      global_i,
                                      session,
                                      writer=None,
                                      collateImages=True):
        """Sample random latents with the frozen generator and save images.

        When `collateImages` is True, writes a single collated grid PNG under
        args.save_dir (resized to a fixed large format) and optionally logs
        it to TensorBoard. Otherwise dumps individual image batches as PNGs
        under ../metrics for offline metric computation. Restores the
        generator to train mode before returning.
        """
        generator.eval()

        utils.requires_grad(generator, False)

        # Total number is samplesRepeatN * colN * rowN
        # e.g. for 51200 samples, outcome is 5*80*128. Please only give multiples of 128 here.
        samplesRepeatN = int(args.sample_N / 128) if not collateImages else 1
        reso = 4 * 2**session.phase

        if not collateImages:
            special_dir = '../metrics/{}/{}/{}'.format(args.data, reso,
                                                       str(global_i).zfill(6))
            # Never overwrite an existing dump; append '_' until the path is free.
            while os.path.exists(special_dir):
                special_dir += '_'

            os.makedirs(special_dir)

        for outer_count in range(samplesRepeatN):

            colN = 1 if not collateImages else min(
                10, int(np.ceil(args.sample_N / 4.0)))
            rowN = 128 if not collateImages else min(
                5, int(np.ceil(args.sample_N / 4.0)))
            images = []
            for _ in range(rowN):
                myz = utils.conditional_to_cuda(
                    Variable(torch.randn(args.n_label * colN, args.nz)))
                myz = utils.normalize(myz)
                myz, input_class = utils.split_labels_out_of_latent(myz)

                new_imgs = generator(myz, input_class, session.phase,
                                     session.alpha).detach().data.cpu()

                images.append(new_imgs)

            if collateImages:
                sample_dir = '{}/sample'.format(args.save_dir)
                if not os.path.exists(sample_dir):
                    os.makedirs(sample_dir)

                save_path = '{}/{}.png'.format(sample_dir,
                                               str(global_i + 1).zfill(6))
                torchvision.utils.save_image(torch.cat(images, 0),
                                             save_path,
                                             nrow=args.n_label * colN,
                                             normalize=True,
                                             range=(-1, 1),
                                             padding=0)
                # Hacky but this is an easy way to rescale the images to nice big lego format:
                im = Image.open(save_path)
                im2 = im.resize((1024, 512 if reso < 256 else 1024))
                im2.save(save_path)

                if writer:
                    writer.add_images(
                        'samples_latest_{}'.format(session.phase),
                        torch.cat(images, 0), session.phase)
                    # Igor: changed add_image to add_images
            else:
                for ii, img in enumerate(images):
                    torchvision.utils.save_image(
                        img,
                        '{}/{}_{}.png'.format(special_dir,
                                              str(global_i + 1).zfill(6),
                                              ii + outer_count * 128),
                        nrow=args.n_label * colN,
                        normalize=True,
                        range=(-1, 1),
                        padding=0)

        generator.train()
Exemple #11
0
    def interpolate_images(generator, encoder, loader, epoch, prefix, session, writer=None):
        generator.eval()
        encoder.eval()

        with torch.no_grad():
            utils.requires_grad(generator, False)
            utils.requires_grad(encoder, False)

            nr_of_imgs = 4 if not args.hexmode else 6 # "Corners"
            reso = session.getReso()
            if True:
            #if Utils.interpolation_set_x is None or Utils.interpolation_set_x.size(2) != reso or (phase >= 1 and alpha < 1.0):
                if session.getResoPhase() < 1:
                    dataset = data.Utils.sample_data(loader, nr_of_imgs, reso)
                else:
                    dataset = data.Utils.sample_data2(loader, nr_of_imgs, reso, session)
                real_image, _ = next(dataset)
                Utils.interpolation_set_x = Variable(real_image, volatile=True).to(device=args.device)

            latent_reso_hor = 8
            latent_reso_ver = 8

            x = Utils.interpolation_set_x

            if args.hexmode:
                x = x[:nr_of_imgs] #Corners
                z0 = encoder(Variable(x), session.getResoPhase(), session.alpha, args.use_ALQ).detach()

                X = np.array([-1.7321,0.0000 ,1.7321 ,-2.5981,-0.8660,0.8660 ,2.5981 ,-3.4641,-1.7321,0.0000 ,1.7321 ,3.4641 ,-2.5981,-0.8660,0.8660 ,2.5981 ,-1.7321,0.0000,1.7321])
                Y = np.array([-3.0000,-3.0000,-3.0000,-1.5000,-1.5000,-1.5000,-1.5000,0.0000,0.0000,0.0000,0.0000,0.0000,1.5000,1.5000,1.5000,1.5000,3.0000,3.0000,3.0000])
                corner_indices = np.array([0, 2, 7, 11, 16, 18])
                inter_indices = [i for i in range(19) if not i in corner_indices]
                edge_indices = np.array([1, 3, 6, 12, 15, 17])
                #distances_to_corners = np.sqrt(np.power(X - X[corner_indices[0]], 2) + np.power(Y - Y[corner_indices[0]], 2))
                distances_to_corners = [np.sqrt(np.power(X - X[corner_indices[i]], 2) + np.power(Y - Y[corner_indices[i]], 2)) for i in range(6)]
                weights = 1.0 - distances_to_corners / np.max(distances_to_corners)
                z0_x = torch.zeros((19,args.nz+args.n_label)).to(device=args.device)
                z0_x[corner_indices,:] = z0
                z0_x[edge_indices[0], :] = Utils.slerp(z0[0], z0[1], 0.5)
                z0_x[edge_indices[1], :] = Utils.slerp(z0[0], z0[2], 0.5)
                z0_x[edge_indices[2], :] = Utils.slerp(z0[1], z0[3], 0.5)
                z0_x[edge_indices[3], :] = Utils.slerp(z0[2], z0[4], 0.5)
                z0_x[edge_indices[4], :] = Utils.slerp(z0[3], z0[5], 0.5)
                z0_x[edge_indices[5], :] = Utils.slerp(z0[4], z0[5], 0.5)
                # Linear:
                #z0x[inter_indices,:] = (weights.T @ z0)[inter_indices,:]
                z0_x[inter_indices,:] = (torch.from_numpy(weights.T.astype(np.float32)).to(device=args.device) @ z0)[inter_indices,:]
                z0_x = utils.normalize(z0_x)
            else:
                z0 = encoder(Variable(x), session.getResoPhase(), session.alpha, args.use_ALQ).detach()

            t = torch.FloatTensor(latent_reso_hor * (latent_reso_ver+1) + nr_of_imgs, x.size(1),
                                x.size(2), x.size(3))
            t[0:nr_of_imgs] = x.data[:]

            save_root = args.sample_dir if args.sample_dir != None else args.save_dir
            special_dir = save_root if not args.aux_outpath else args.aux_outpath

            if not os.path.exists(special_dir):
                os.makedirs(special_dir)

            for o_i in range(nr_of_imgs):
                single_save_path = '{}{}/hex_interpolations_{}_{}_{}_orig_{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha, o_i)
                grid = torchvision.utils.save_image(x.data[o_i] / 2 + 0.5, single_save_path, nrow=1, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?

            if args.hexmode:
                z0_x, label = utils.split_labels_out_of_latent(z0_x)
                gex = generator(z0_x, label, session.phase, session.alpha).detach()                

                for x_i in range(19):
                    single_save_path = '{}{}/hex_interpolations_{}_{}_{}x{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha, x_i)
                    grid = torchvision.utils.save_image(gex.data[x_i] / 2 + 0.5, single_save_path, nrow=1, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?

                return

            # Origs on the first row here
            # Corners are: z0[0] ... z0[1]
            #                .   
            #                .
            #              z0[2] ... z0[3]                

            delta_z_ver0 = ((z0[2] - z0[0]) / (latent_reso_ver - 1))
            delta_z_verN = ((z0[3] - z0[1]) / (latent_reso_ver - 1))

            for y_i in range(latent_reso_ver):
                if False: #Linear interpolation
                    z0_x0 = z0[0] + y_i * delta_z_ver0
                    z0_xN = z0[1] + y_i * delta_z_verN
                    delta_z_hor = (z0_xN - z0_x0) / (latent_reso_hor - 1)
                    z0_x = Variable(torch.FloatTensor(latent_reso_hor, z0_x0.size(0)))

                    for x_i in range(latent_reso_hor):
                        z0_x[x_i] = z0_x0 + x_i * delta_z_hor

                if True: #Spherical
                    t_y = float(y_i) / (latent_reso_ver-1)
                    #z0_y = Variable(torch.FloatTensor(latent_reso_ver, z0.size(0)))
                    
                    z0_y1 = Utils.slerp(z0[0].data.cpu().numpy(), z0[2].data.cpu().numpy(), t_y)
                    z0_y2 = Utils.slerp(z0[1].data.cpu().numpy(), z0[3].data.cpu().numpy(), t_y)
                    z0_x = Variable(torch.FloatTensor(latent_reso_hor, z0[0].size(0)))
                    for x_i in range(latent_reso_hor):
                        t_x = float(x_i) / (latent_reso_hor-1)
                        z0_x[x_i] = torch.from_numpy( Utils.slerp(z0_y1, z0_y2, t_x) )

                z0_x, label = utils.split_labels_out_of_latent(z0_x)
                gex = generator(z0_x, label, session.phase, session.alpha).detach()                
                
                # Recall that yi=0 is the original's row:
                t[(y_i+1) * latent_reso_ver:(y_i+2)* latent_reso_ver] = gex.data[:]

                for x_i in range(latent_reso_hor):
                    single_save_path = '{}{}/interpolations_{}_{}_{}_{}x{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha, y_i, x_i)
                    grid = torchvision.utils.save_image(gex.data[x_i] / 2 + 0.5, single_save_path, nrow=1, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?
            
            save_path = '{}{}/interpolations_{}_{}_{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha)
            grid = torchvision.utils.save_image(t / 2 + 0.5, save_path, nrow=latent_reso_ver, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?
            # Hacky but this is an easy way to rescale the images to nice big lego format:
            if session.getResoPhase() < 4:
                im = Image.open(save_path)
                im2 = im.resize((1024, 1024))
                im2.save(save_path)        

            #if writer:
            #    writer.add_image('interpolation_latest_{}'.format(session.phase), t / 2 + 0.5, session.phase)

        generator.train()
        encoder.train()
Exemple #12
0
    def eval_metrics(generator, encoder, loader, global_i, nr_of_imgs, prefix, reals, reconstructions, session,
                           writer=None):  # of the form"/[dir]"
        """Compute LPIPS and FID quality metrics for the current session state.

        LPIPS is measured between real images and their encoder->generator
        reconstructions; FID between real images and random generator samples.
        Both are printed and, when a ``writer`` is supplied, logged to
        TensorBoard under 'LPIPS' and 'FID' at step ``session.sample_i``.
        Metrics only run once ``session.getResoPhase() >= 3``.

        Args:
            generator, encoder: networks under evaluation; switched to eval
                mode here and restored to train mode on exit.
            loader: data loader, sampled via ``data.Utils.sample_data2``.
            global_i, prefix, reals: unused here; kept for caller
                interface compatibility.
            nr_of_imgs: number of reconstruction samples requested; metrics
                are skipped when this is 0.
            reconstructions: must be truthy for any metric computation.
            session: training session (phase, alpha, resolution, sample_i).
            writer: optional TensorBoard SummaryWriter (may be None).
        """
        generator.eval()
        encoder.eval()

        with torch.no_grad():
            utils.requires_grad(generator, False)
            utils.requires_grad(encoder, False)

            if reconstructions and nr_of_imgs > 0:
                reso = session.getReso()

                batch_size = 16
                # Round up so at least nr_of_imgs samples are reconstructed.
                n_batches = (nr_of_imgs + batch_size - 1) // batch_size

                FIDm = 3  # FID multiplier: FID uses FIDm * n_batches random-sample batches

                if session.getResoPhase() >= 3:
                    # Build the (expensive) LPIPS network only when it will be used.
                    lpips_model = lpips.PerceptualLoss(model='net-lin', net='alex', use_gpu=False)

                    real_images = []
                    reco_images = []
                    rand_images = []

                    for o in range(n_batches * FIDm):
                        # Random generator samples — used for FID only.
                        myz = Variable(torch.randn(args.n_label * batch_size, args.nz)).to(device=args.device)
                        myz = utils.normalize(myz)
                        myz, input_class = utils.split_labels_out_of_latent(myz)

                        random_image = generator(
                                myz,
                                input_class,
                                session.phase,
                                session.alpha).detach().data.cpu()

                        rand_images.append(random_image)

                        # Reconstructions only up to n_batches; FID keeps taking
                        # random samples for the remaining (FIDm - 1) * n_batches.
                        if o >= n_batches:
                            continue

                        dataset = data.Utils.sample_data2(loader, batch_size, reso, session)
                        real_image, _ = next(dataset)
                        reco_image = Utils.reconstruct(real_image, encoder, generator, session)

                        # Face datasets are cropped before LPIPS; 32x32
                        # (phase >= 4) is the minimum resolution for that.
                        crop_needed = (args.data == 'celeba' or args.data == 'celebaHQ' or args.data == 'ffhq')
                        if session.getResoPhase() >= 4 or not crop_needed:
                            if crop_needed:
                                real_image = Utils.face_as_cropped(real_image)
                                reco_image = Utils.face_as_cropped(reco_image)

                        real_images.append(real_image)
                        reco_images.append(reco_image.detach().data.cpu())

                    real_images = torch.cat(real_images, 0)
                    reco_images = torch.cat(reco_images, 0)
                    rand_images = torch.cat(rand_images, 0)

                    lpips_dist = lpips_model.forward(real_images, reco_images)
                    lpips_score = torch.mean(lpips_dist)

                    fid_score = calculate_fid_given_images(real_images, rand_images, batch_size=batch_size, cuda=False, dims=2048)

                    # writer defaults to None — only log when one was supplied
                    # (the original crashed here with AttributeError otherwise).
                    if writer is not None:
                        writer.add_scalar('LPIPS', lpips_score, session.sample_i)
                        writer.add_scalar('FID', fid_score, session.sample_i)

                    print("{}: FID = {}, LPIPS = {}".format(session.sample_i, fid_score, lpips_score))

        encoder.train()
        generator.train()