Example #1
    def reconstruction_dryrun(generator, encoder, loader, session):
        generator.eval()
        encoder.eval()

        utils.requires_grad(generator, False)
        utils.requires_grad(encoder, False)

        reso = 4 * 2**session.phase

        warmup_rounds = 200
        print('Warm-up rounds: {}'.format(warmup_rounds))

        if session.phase < 1:
            dataset = data.Utils.sample_data(loader, 4, reso)
        else:
            dataset = data.Utils.sample_data2(loader, 4, reso, session)
        real_image, _ = next(dataset)
        x = utils.conditional_to_cuda(Variable(real_image))

        for i in range(warmup_rounds):
            ex = encoder(x, session.phase, session.alpha,
                         args.use_ALQ).detach()
            ex, label = utils.split_labels_out_of_latent(ex)
            gex = generator(ex, label, session.phase, session.alpha).detach()

        encoder.train()
        generator.train()
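
The dry run above leans on utils.split_labels_out_of_latent, which peels the class label off the tail of each latent vector before the generator is called. A minimal sketch of such a helper, assuming the label occupies the final n_label entries (the name comes from the calls above; the repo's actual layout may differ):

    import torch

    def split_labels_out_of_latent(z, n_label=1):
        # z: [batch, nz + n_label]; the trailing n_label entries carry the class label
        return z[:, :-n_label], z[:, -n_label:]
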
Example #2
    def reconstruct_images(session):
        encoder, generator = session.encoder, session.generator
        batch, alpha = session.cur_batch(), session.alpha
        res, phase = session.cur_res(), session.phase
        generator.eval()
        encoder.eval()

        utils.requires_grad(generator, False)
        utils.requires_grad(encoder, False)

        nsamples = args.test_cols * args.test_rows
        session.test_data.init_epoch(nsamples, alpha, res, phase)
        input_ims = session.get_next_test_batch()  #test_data.next_batch()

        real_z = encoder(input_ims, session.phase, session.alpha)
        reco_ims = generator(real_z, session.phase, session.alpha).data

        # reco_ims    = Utils.reconstruct(input_ims, session)

        # join source and reconstructed images side by side
        out_ims = torch.cat((input_ims, reco_ims),
                            1).view(2 * nsamples, 1, reco_ims.shape[-2],
                                    reco_ims.shape[-1])
        sample_dir = '{}/recon'.format(args.save_dir)
        save_path = '{}/{}.png'.format(sample_dir,
                                       str(session.sample_i + 1).zfill(6))
        mkdir_assure(sample_dir)

        print('\nSaving a new collage ...')
        torchvision.utils.save_image(out_ims,
                                     save_path,
                                     nrow=args.test_cols,
                                     normalize=True,
                                     padding=0,
                                     scale_each=False)
        # utils.requires_grad(generator, True)
        # utils.requires_grad(encoder, True)
        encoder.train()
        generator.train()
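
The cat(...).view(...) line above is what interleaves each source image with its reconstruction, so the saved grid alternates originals and reconstructions (the view(..., 1, ...) implies single-channel images). A toy demonstration of the same reshaping:

    import torch

    a = torch.zeros(3, 1, 4, 4)   # stand-in "originals"
    b = torch.ones(3, 1, 4, 4)    # stand-in "reconstructions"
    out = torch.cat((a, b), 1).view(6, 1, 4, 4)
    # Entries now alternate: out[0] == a[0], out[1] == b[0], out[2] == a[1], ...
    print(out[:, 0, 0, 0])  # tensor([0., 1., 0., 1., 0., 1.])
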
Example #3
    def make(generator, session, writer):
        import bcolz
        if not os.path.exists(args.x_outpath):
            os.makedirs(args.x_outpath)

        utils.requires_grad(generator, False)

        z = bcolz.carray(rootdir=args.z_inpath)
        print("Loading external z from {}, shape:".format(args.z_inpath))
        print(np.shape(z))
        N = np.shape(z)[0]

        samplesRepeatN = (N + 7) // 8  # ceil(N / 8); assume large files
        samplesDone = 0
        for outer_count in range(samplesRepeatN):
            samplesN = min(8, N - samplesDone)
            if samplesN <= 0:
                break
            zrange = z[samplesDone:samplesDone+samplesN, :]
            myz, input_class = utils.split_labels_out_of_latent(torch.from_numpy(zrange.astype(np.float32)))
            new_imgs = generator(
                myz,
                input_class,
                session.phase,
                session.alpha).detach() #.data.cpu()

            for ii, img in enumerate(new_imgs):
                torchvision.utils.save_image(
                    img,
                    '{}/{}.png'.format(args.x_outpath, str(ii + outer_count*8)), #.zfill(6)
                    nrow=args.n_label,
                    normalize=True,
                    range=(-1, 1),
                    padding=0)

            samplesDone += samplesN
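
Example #3 expects a pre-computed latent array on disk at args.z_inpath. For reference, a sketch of how such an array could be produced with bcolz (the path and latent size here are hypothetical):

    import bcolz
    import numpy as np

    z = np.random.randn(1000, 512).astype(np.float32)  # hypothetical latents, [N, nz + n_label]
    carr = bcolz.carray(z, rootdir='/tmp/z_latents', mode='w')
    carr.flush()  # persist; later reopened via bcolz.carray(rootdir='/tmp/z_latents')
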
Example #4
def train(generator, encoder, g_running, train_data_loader, test_data_loader,
          session, total_steps, train_mode):
    pbar = tqdm(initial=session.sample_i, total=total_steps)

    benchmarking = False

    match_x = args.match_x
    generatedImagePool = None

    refresh_dataset = True
    refresh_imagePool = True

    # After the Loading stage, we cycle through successive Fade-in and Stabilization stages

    batch_count = 0

    reset_optimizers_on_phase_start = False

    # TODO Unhack this (only affects the episode count statistics anyway):
    if args.data != 'celebaHQ':
        epoch_len = len(train_data_loader(1, 4).dataset)
    else:
        epoch_len = train_data_loader._len['data4x4']

    if args.step_offset != 0:
        if args.step_offset == -1:
            args.step_offset = session.sample_i
        print("Step offset is {}".format(args.step_offset))
        session.phase += args.phase_offset
        session.alpha = 0.0

    while session.sample_i < total_steps:
        #######################  Phase Maintenance #######################

        steps_in_previous_phases = max(session.phase * args.images_per_stage,
                                       args.step_offset)

        sample_i_current_stage = session.sample_i - steps_in_previous_phases

        # If we can move to the next phase
        if sample_i_current_stage >= args.images_per_stage:
            if session.phase < args.max_phase:  # If any phases left
                iteration_levels = int(sample_i_current_stage /
                                       args.images_per_stage)
                session.phase += iteration_levels
                sample_i_current_stage -= iteration_levels * args.images_per_stage
                match_x = args.match_x  # Reset to non-matching phase
                print(
                    "Phase advance: stage-local iteration {} at phase {}; "
                    "alpha restarts from 0 and fades back to 1.".format(
                        sample_i_current_stage, session.phase))

                refresh_dataset = True
                refresh_imagePool = True  # Reset the pool to avoid images of 2 different resolutions in the pool

                if reset_optimizers_on_phase_start:
                    utils.requires_grad(generator)
                    utils.requires_grad(encoder)
                    generator.zero_grad()
                    encoder.zero_grad()
                    session.reset_opt()
                    print("Optimizers have been reset.")

        reso = 4 * 2**session.phase

        # If we can switch from fade-training to stable-training
        if sample_i_current_stage >= args.images_per_stage / 2:
            if session.alpha < 1.0:
                refresh_dataset = True  # refresh dataset generator since no longer have to fade
            match_x = args.match_x * args.matching_phase_x
        else:
            match_x = args.match_x

        session.alpha = min(
            1, sample_i_current_stage * 2.0 / args.images_per_stage
        )  # For 100k, it was 0.00002 = 2.0 / args.images_per_stage

        if refresh_dataset:
            train_dataset = data.Utils.sample_data2(train_data_loader,
                                                    batch_size(reso), reso,
                                                    session)
            refresh_dataset = False
            print("Refreshed dataset. Alpha={} and iteration={}".format(
                session.alpha, sample_i_current_stage))
        if refresh_imagePool:
            imagePoolSize = 200 if reso < 256 else 100
            generatedImagePool = utils.ImagePool(
                imagePoolSize
            )  #Reset the pool to avoid images of 2 different resolutions in the pool
            refresh_imagePool = False
            print('Image pool created with size {} because reso is {}'.format(
                imagePoolSize, reso))

        ####################### Training init #######################

        z = utils.conditional_to_cuda(
            Variable(torch.FloatTensor(batch_size(reso), args.nz, 1, 1)),
            args.gpu_count > 1)
        KL_minimizer = KLN01Loss(direction=args.KL, minimize=True)
        KL_maximizer = KLN01Loss(direction=args.KL, minimize=False)

        stats = {}

        #one = torch.FloatTensor([1]).cuda(async=(args.gpu_count>1))
        one = torch.FloatTensor([1]).cuda(non_blocking=(args.gpu_count > 1))

        try:
            real_image, _ = next(train_dataset)
        except (OSError, StopIteration):
            train_dataset = data.Utils.sample_data2(train_data_loader,
                                                    batch_size(reso), reso,
                                                    session)
            real_image, _ = next(train_dataset)

        ####################### DISCRIMINATOR / ENCODER ###########################

        utils.switch_grad_updates_to_first_of(encoder, generator)
        encoder.zero_grad()

        #x = Variable(real_image).cuda(async=(args.gpu_count>1))
        x = Variable(real_image).cuda(non_blocking=(args.gpu_count > 1))
        kls = ""
        if train_mode == config.MODE_GAN:

            # Discriminator for real samples
            real_predict, _ = encoder(x, session.phase, session.alpha,
                                      args.use_ALQ)
            real_predict = real_predict.mean() \
                - 0.001 * (real_predict ** 2).mean()
            real_predict.backward(-one)  # Towards 1

            # (1) Generator => D. Identical to (2) see below

            fake_predict, fake_image = D_prediction_of_G_output(
                generator, encoder, session.phase, session.alpha)
            fake_predict.backward(one)

            # Grad penalty

            grad_penalty = get_grad_penalty(encoder, x, fake_image,
                                            session.phase, session.alpha)
            grad_penalty.backward()

        elif train_mode == config.MODE_CYCLIC:
            e_losses = []

            # e(X)

            real_z = encoder(x, session.phase, session.alpha, args.use_ALQ)
            if args.use_real_x_KL:
                # KL_real: - \Delta( e(X) , Z ) -> max_e
                KL_real = KL_minimizer(real_z) * args.real_x_KL_scale
                e_losses.append(KL_real)

                stats['real_mean'] = KL_minimizer.samples_mean.data.mean()
                stats['real_var'] = KL_minimizer.samples_var.data.mean()
                stats['KL_real'] = KL_real  #FIXME .data[0]
                kls = "{0:.3f}".format(stats['KL_real'])

            # The final entries are the label. Normal case, just 1. Extract it/them, and make it [b x 1]:

            real_z, label = utils.split_labels_out_of_latent(real_z)
            recon_x = generator(real_z, label, session.phase, session.alpha)
            if args.use_loss_x_reco:
                # match_x: E_x||g(e(x)) - x|| -> min_e
                err = utils.mismatch(recon_x, x, args.match_x_metric) * match_x
                e_losses.append(err)
                stats['x_reconstruction_error'] = err  #FIXME .data[0]

            args.use_wpgan_grad_penalty = False
            grad_penalty = 0.0

            if args.use_loss_fake_D_KL:
                # TODO: The following codeblock is essentially the same as the KL_minimizer part on G side. Unify
                utils.populate_z(z, args.nz + args.n_label, args.noise,
                                 batch_size(reso))
                z, label = utils.split_labels_out_of_latent(z)
                fake = generator(z, label, session.phase,
                                 session.alpha).detach()

                if session.alpha >= 1.0:
                    fake = generatedImagePool.query(fake.data)

                # e(g(Z))
                egz = encoder(fake, session.phase, session.alpha, args.use_ALQ)

                # KL_fake: \Delta( e(g(Z)) , Z ) -> max_e
                KL_fake = KL_maximizer(egz) * args.fake_D_KL_scale

                # Added by Igor
                m = args.kl_margin
                if m > 0.0 and session.phase >= 5:
                    KL_loss = torch.min(
                        torch.ones_like(KL_fake) * m / 2,
                        torch.max(-torch.ones_like(KL_fake) * m,
                                  KL_real + KL_fake))
                    # KL_fake is always negative with abs value typically larger than KL_real.
                    # Hence, the sum is negative, and must be gapped so that the minimum is the negative of the margin.
                else:
                    KL_loss = KL_real + KL_fake
                e_losses.append(KL_loss)
                #e_losses.append(KL_fake)

                stats['fake_mean'] = KL_maximizer.samples_mean.data.mean()
                stats['fake_var'] = KL_maximizer.samples_var.data.mean()
                stats['KL_fake'] = -KL_fake  #FIXME .data[0]
                kls = "{0}/{1:.3f}".format(kls, stats['KL_fake'])

                if args.use_wpgan_grad_penalty:
                    grad_penalty = get_grad_penalty(encoder, x, fake,
                                                    session.phase,
                                                    session.alpha)

            # Update e
            if len(e_losses) > 0:
                e_loss = sum(e_losses)
                stats['E_loss'] = e_loss.detach().cpu().float().numpy(
                )  #np.float32(e_loss.data)
                e_loss.backward()

                if args.use_wpgan_grad_penalty:
                    grad_penalty.backward()
                    stats['Grad_penalty'] = grad_penalty.data

                #book-keeping
                disc_loss_val = e_loss  #FIXME .data[0]

        session.optimizerD.step()

        torch.cuda.empty_cache()

        ######################## GENERATOR / DECODER #############################

        if (batch_count + 1) % args.n_critic == 0:
            utils.switch_grad_updates_to_first_of(generator, encoder)

            for _ in range(args.n_generator):
                generator.zero_grad()
                g_losses = []

                if train_mode == config.MODE_GAN:
                    fake_predict, _ = D_prediction_of_G_output(
                        generator, encoder, session.phase, session.alpha)
                    loss = -fake_predict
                    g_losses.append(loss)

                elif train_mode == config.MODE_CYCLIC:  #TODO We push the z variable around here like idiots

                    def KL_of_encoded_G_output(generator, z):
                        utils.populate_z(z, args.nz + args.n_label, args.noise,
                                         batch_size(reso))
                        z, label = utils.split_labels_out_of_latent(z)
                        fake = generator(z, label, session.phase,
                                         session.alpha)

                        egz = encoder(fake, session.phase, session.alpha,
                                      args.use_ALQ)
                        # KL_fake: \Delta( e(g(Z)) , Z ) -> min_g
                        return egz, label, KL_minimizer(
                            egz) * args.fake_G_KL_scale, z

                    egz, label, kl, z = KL_of_encoded_G_output(generator, z)

                    if args.use_loss_KL_z:
                        g_losses.append(kl)  # G minimizes this KL
                        stats['KL(Phi(G))'] = kl  #FIXME .data[0]
                        kls = "{0}/{1:.3f}".format(kls, stats['KL(Phi(G))'])

                    if args.use_loss_z_reco:
                        z = torch.cat((z, label), 1)
                        z_diff = utils.mismatch(
                            egz, z, args.match_z_metric
                        ) * args.match_z  # G tries to make the original z and encoded z match
                        g_losses.append(z_diff)

                if len(g_losses) > 0:
                    loss = sum(g_losses)
                    stats['G_loss'] = loss.detach().cpu().float().numpy(
                    )  #np.float32(loss.data)
                    loss.backward()

                    # Book-keeping only:
                    gen_loss_val = loss  #FIXME .data[0]

                session.optimizerG.step()

                torch.cuda.empty_cache()

                if train_mode == config.MODE_CYCLIC:
                    if args.use_loss_z_reco:
                        stats[
                            'z_reconstruction_error'] = z_diff  #FIXME .data[0]

            accumulate(g_running, generator)

            # Free intermediates; these names are only all defined in MODE_CYCLIC
            if train_mode == config.MODE_CYCLIC:
                del z, x, one, real_image, real_z, KL_real, label, recon_x, fake, egz, KL_fake, kl, z_diff

            if train_mode == config.MODE_CYCLIC:
                if args.use_TB:
                    for key, val in stats.items():
                        writer.add_scalar(key, val, session.sample_i)
                elif batch_count % 100 == 0:
                    print(stats)

            if args.use_TB:
                writer.add_scalar('LOD', session.phase + session.alpha,
                                  session.sample_i)

        ########################  Statistics ########################

        b = batch_size_by_phase(session.phase)
        zr, xr = (stats['z_reconstruction_error'],
                  stats['x_reconstruction_error']
                  ) if train_mode == config.MODE_CYCLIC else (0.0, 0.0)
        e = (session.sample_i / float(epoch_len))
        pbar.set_description((
            '{0}; it: {1}; phase: {2}; b: {3:.1f}; Alpha: {4:.3f}; Reso: {5}; E: {6:.2f}; KL(real/fake/fakeG): {7}; z-reco: {8:.2f}; x-reco {9:.3f}; real_var {10:.4f}'
        ).format(batch_count + 1, session.sample_i + 1, session.phase, b,
                 session.alpha, reso, e, kls, zr, xr, stats['real_var']))
        #(f'{i + 1}; it: {iteration+1}; b: {b:.1f}; G: {gen_loss_val:.5f}; D: {disc_loss_val:.5f};'
        # f' Grad: {grad_loss_val:.5f}; Alpha: {alpha:.3f}; Reso: {reso}; S-mean: {real_mean:.3f}; KL(real/fake/fakeG): {kls}; z-reco: {zr:.2f}'))

        pbar.update(batch_size(reso))
        session.sample_i += batch_size(reso)  # if not benchmarking else 100
        batch_count += 1

        ########################  Saving ########################

        if batch_count % args.checkpoint_cycle == 0:
            for postfix in {'latest', str(session.sample_i).zfill(6)}:
                session.save_all('{}/{}_state'.format(args.checkpoint_dir,
                                                      postfix))

            print("Checkpointed to {}".format(session.sample_i))

        ########################  Tests ########################

        try:
            evaluate.tests_run(
                g_running,
                encoder,
                test_data_loader,
                session,
                writer,
                reconstruction=(batch_count % 800 == 0),
                interpolation=(batch_count % 800 == 0),
                collated_sampling=(batch_count % 800 == 0),
                individual_sampling=(
                    batch_count %
                    (args.images_per_stage / batch_size(reso) / 4) == 0))
        except (OSError, StopIteration):
            print("Skipped periodic tests due to an exception.")

    pbar.close()
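
Both sides of the training loop revolve around KLN01Loss, the KL divergence between a diagonal Gaussian fitted to the latent batch and the unit Gaussian N(0, I), whose closed form per dimension is KL = 1/2 (mu^2 + var - log var - 1). A minimal sketch consistent with the attribute accesses above (samples_mean, samples_var); the repo's version also interprets a direction argument that this sketch ignores:

    import torch
    import torch.nn as nn

    class KLN01Loss(nn.Module):
        # Sketch: KL( N(mu, var) || N(0, I) ) of a latent batch, fitted per dimension.
        def __init__(self, direction='qp', minimize=True):
            super().__init__()
            self.minimize = minimize

        def forward(self, z):
            self.samples_mean = z.mean(0)
            self.samples_var = z.var(0)
            kl = 0.5 * (self.samples_mean ** 2 + self.samples_var
                        - torch.log(self.samples_var) - 1.0).sum()
            return kl if self.minimize else -kl
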
Example #5
    def interpolate_images(generator,
                           encoder,
                           loader,
                           epoch,
                           prefix,
                           session,
                           writer=None):
        generator.eval()
        encoder.eval()

        utils.requires_grad(generator, False)
        utils.requires_grad(encoder, False)

        nr_of_imgs = 4  # "Corners"
        reso = 4 * 2**session.phase
        if True:
            #if Utils.interpolation_set_x is None or Utils.interpolation_set_x.size(2) != reso or (phase >= 1 and alpha < 1.0):
            if session.phase < 1:
                dataset = data.Utils.sample_data(loader, nr_of_imgs, reso)
            else:
                dataset = data.Utils.sample_data2(loader, nr_of_imgs, reso,
                                                  session)
            real_image, _ = next(dataset)
            Utils.interpolation_set_x = utils.conditional_to_cuda(
                Variable(real_image, volatile=True))

        latent_reso_hor = 8
        latent_reso_ver = 8

        x = Utils.interpolation_set_x

        z0 = encoder(Variable(x), session.phase, session.alpha,
                     args.use_ALQ).detach()

        t = torch.FloatTensor(
            latent_reso_hor * (latent_reso_ver + 1) + nr_of_imgs, x.size(1),
            x.size(2), x.size(3))
        t[0:nr_of_imgs] = x.data[:]

        special_dir = args.save_dir if not args.aux_outpath else args.aux_outpath

        if not os.path.exists(special_dir):
            os.makedirs(special_dir)

        for o_i in range(nr_of_imgs):
            single_save_path = '{}{}/interpolations_{}_{}_{}_orig_{}.png'.format(
                special_dir, prefix, session.phase, epoch, session.alpha, o_i)
            grid = torchvision.utils.save_image(
                x.data[o_i] / 2 + 0.5, single_save_path, nrow=1, padding=0
            )  #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?

        # Origs on the first row here
        # Corners are: z0[0] ... z0[1]
        #                .
        #                .
        #              z0[2] ... z0[3]

        delta_z_ver0 = ((z0[2] - z0[0]) / (latent_reso_ver - 1))
        delta_z_verN = ((z0[3] - z0[1]) / (latent_reso_ver - 1))
        for y_i in range(latent_reso_ver):
            if False:  #Linear interpolation
                z0_x0 = z0[0] + y_i * delta_z_ver0
                z0_xN = z0[1] + y_i * delta_z_verN
                delta_z_hor = (z0_xN - z0_x0) / (latent_reso_hor - 1)
                z0_x = Variable(
                    torch.FloatTensor(latent_reso_hor, z0_x0.size(0)))

                for x_i in range(latent_reso_hor):
                    z0_x[x_i] = z0_x0 + x_i * delta_z_hor

            if True:  #Spherical
                t_y = float(y_i) / (latent_reso_ver - 1)
                #z0_y = Variable(torch.FloatTensor(latent_reso_ver, z0.size(0)))
                z0_y1 = Utils.slerp(z0[0].data, z0[2].data, t_y)
                z0_y2 = Utils.slerp(z0[1].data, z0[3].data, t_y)
                z0_x = Variable(
                    torch.FloatTensor(latent_reso_hor, z0[0].size(0)))
                for x_i in range(latent_reso_hor):
                    t_x = float(x_i) / (latent_reso_hor - 1)
                    z0_x[x_i] = Utils.slerp(z0_y1, z0_y2, t_x)

            z0_x, label = utils.split_labels_out_of_latent(z0_x)
            gex = generator(z0_x, label, session.phase, session.alpha).detach()

            # Recall that yi=0 is the original's row:
            t[(y_i + 1) * latent_reso_ver:(y_i + 2) *
              latent_reso_ver] = gex.data[:]

            for x_i in range(latent_reso_hor):
                single_save_path = '{}{}/interpolations_{}_{}_{}_{}x{}.png'.format(
                    special_dir, prefix, session.phase, epoch, session.alpha,
                    y_i, x_i)
                grid = torchvision.utils.save_image(
                    gex.data[x_i] / 2 + 0.5,
                    single_save_path,
                    nrow=1,
                    padding=0
                )  #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?

        save_path = '{}{}/interpolations_{}_{}_{}.png'.format(
            special_dir, prefix, session.phase, epoch, session.alpha)
        grid = torchvision.utils.save_image(
            t / 2 + 0.5, save_path, nrow=latent_reso_ver, padding=0
        )  #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?
        # Hacky but this is an easy way to rescale the images to nice big lego format:
        if session.phase < 4:
            im = Image.open(save_path)
            im2 = im.resize((1024, 1024))
            im2.save(save_path)

        if writer:
            writer.add_images('interpolation_latest_{}'.format(session.phase),
                              t / 2 + 0.5, session.phase)
            # Igor: changed add_image to add_images

        generator.train()
        encoder.train()
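
The spherical branch delegates to Utils.slerp, i.e. standard spherical linear interpolation: slerp(p0, p1, t) = sin((1-t)*Omega)/sin(Omega) * p0 + sin(t*Omega)/sin(Omega) * p1, where Omega is the angle between the two latents. A sketch in tensor form (the repo's version may operate on numpy arrays, as Example #9 suggests):

    import torch

    def slerp(p0, p1, t):
        # Spherical linear interpolation between two 1-D latent tensors.
        omega = torch.acos(torch.clamp(
            torch.dot(p0 / p0.norm(), p1 / p1.norm()), -1.0, 1.0))
        so = torch.sin(omega)
        if so.abs() < 1e-6:  # nearly parallel: fall back to linear interpolation
            return (1.0 - t) * p0 + t * p1
        return (torch.sin((1.0 - t) * omega) / so) * p0 \
             + (torch.sin(t * omega) / so) * p1
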
Example #6
    def generate_intermediate_samples(generator,
                                      global_i,
                                      session,
                                      writer=None,
                                      collateImages=True):
        generator.eval()

        utils.requires_grad(generator, False)

        # Total number is samplesRepeatN * colN * rowN
        # e.g. for 51200 samples, outcome is 5*80*128. Please only give multiples of 128 here.
        samplesRepeatN = int(args.sample_N / 128) if not collateImages else 1
        reso = 4 * 2**session.phase

        if not collateImages:
            special_dir = '../metrics/{}/{}/{}'.format(args.data, reso,
                                                       str(global_i).zfill(6))
            while os.path.exists(special_dir):
                special_dir += '_'

            os.makedirs(special_dir)

        for outer_count in range(samplesRepeatN):

            colN = 1 if not collateImages else min(
                10, int(np.ceil(args.sample_N / 4.0)))
            rowN = 128 if not collateImages else min(
                5, int(np.ceil(args.sample_N / 4.0)))
            images = []
            for _ in range(rowN):
                myz = utils.conditional_to_cuda(
                    Variable(torch.randn(args.n_label * colN, args.nz)))
                myz = utils.normalize(myz)
                myz, input_class = utils.split_labels_out_of_latent(myz)

                new_imgs = generator(myz, input_class, session.phase,
                                     session.alpha).detach().data.cpu()

                images.append(new_imgs)

            if collateImages:
                sample_dir = '{}/sample'.format(args.save_dir)
                if not os.path.exists(sample_dir):
                    os.makedirs(sample_dir)

                save_path = '{}/{}.png'.format(sample_dir,
                                               str(global_i + 1).zfill(6))
                torchvision.utils.save_image(torch.cat(images, 0),
                                             save_path,
                                             nrow=args.n_label * colN,
                                             normalize=True,
                                             range=(-1, 1),
                                             padding=0)
                # Hacky but this is an easy way to rescale the images to nice big lego format:
                im = Image.open(save_path)
                im2 = im.resize((1024, 512 if reso < 256 else 1024))
                im2.save(save_path)

                if writer:
                    writer.add_images(
                        'samples_latest_{}'.format(session.phase),
                        torch.cat(images, 0), session.phase)
                    # Igor: changed add_image to add_images
            else:
                for ii, img in enumerate(images):
                    torchvision.utils.save_image(
                        img,
                        '{}/{}_{}.png'.format(special_dir,
                                              str(global_i + 1).zfill(6),
                                              ii + outer_count * 128),
                        nrow=args.n_label * colN,
                        normalize=True,
                        range=(-1, 1),
                        padding=0)

        generator.train()
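
Sampling draws Gaussian latents and immediately pushes them through utils.normalize, presumably projecting each latent onto the unit hypersphere, the convention these progressive autoencoder models train with. A one-line sketch of that projection (the exact scaling in the repo may differ):

    import torch

    def normalize(z, eps=1e-8):
        # Project each latent row onto the unit hypersphere: z / ||z||
        return z / (z.norm(dim=1, keepdim=True) + eps)
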
Example #7
    def reconstruct_images(generator,
                           encoder,
                           loader,
                           global_i,
                           nr_of_imgs,
                           prefix,
                           reals,
                           reconstructions,
                           session,
                           writer=None):  # prefix of the form "/[dir]"
        generator.eval()
        encoder.eval()

        utils.requires_grad(generator, False)
        utils.requires_grad(encoder, False)

        if reconstructions and nr_of_imgs > 0:
            reso = 4 * 2**session.phase

            # First, create the single grid

            if Utils.reconstruction_set_x is None or Utils.reconstruction_set_x.size(
                    2) != reso or (session.phase >= 1 and session.alpha < 1.0):
                if session.phase < 1:
                    dataset = data.Utils.sample_data(loader,
                                                     min(nr_of_imgs, 16), reso)
                else:
                    dataset = data.Utils.sample_data2(loader,
                                                      min(nr_of_imgs, 16),
                                                      reso, session)
                Utils.reconstruction_set_x, _ = next(dataset)

            reco_image = Utils.reconstruct(Utils.reconstruction_set_x, encoder,
                                           generator, session)

            t = torch.FloatTensor(
                Utils.reconstruction_set_x.size(0) * 2,
                Utils.reconstruction_set_x.size(1),
                Utils.reconstruction_set_x.size(2),
                Utils.reconstruction_set_x.size(3))

            t[0::2] = Utils.reconstruction_set_x[:]
            t[1::2] = reco_image

            save_path = '{}{}/reconstructions_{}_{}_{}.png'.format(
                args.save_dir, prefix, session.phase, global_i, session.alpha)
            grid = torchvision.utils.save_image(t[:nr_of_imgs] / 2 + 0.5,
                                                save_path,
                                                padding=0)

            # Hacky but this is an easy way to rescale the images to nice big lego format:
            if session.phase < 4:
                h = np.ceil(nr_of_imgs / 8)
                h_scale = min(1.0, h / 8.0)
                im = Image.open(save_path)
                im2 = im.resize((1024, int(1024 * h_scale)))
                im2.save(save_path)

            if writer:
                writer.add_images(
                    'reconstruction_latest_{}'.format(session.phase),
                    t[:nr_of_imgs] / 2 + 0.5, session.phase)
                # Igor: changed add_image to add_images

            # Second, create the Individual images:
            if session.phase < 1:
                dataset = data.Utils.sample_data(loader, 1, reso)
            else:
                dataset = data.Utils.sample_data2(loader, 1, reso, session)

            special_dir = '{}/{}'.format(
                args.save_dir if not args.aux_outpath else args.aux_outpath,
                str(global_i).zfill(6))

            if not os.path.exists(special_dir):
                os.makedirs(special_dir)

            print("Save images: Alpha={}, phase={}, images={}, at {}".format(
                session.alpha, session.phase, nr_of_imgs, special_dir))
            for o in range(nr_of_imgs):
                if o % 500 == 0:
                    print(o)

                real_image, _ = next(dataset)
                reco_image = Utils.reconstruct(real_image, encoder, generator,
                                               session)

                t = torch.FloatTensor(
                    real_image.size(0) * 2, real_image.size(1),
                    real_image.size(2), real_image.size(3))

                save_path_A = '{}/{}_orig.png'.format(special_dir, o)
                save_path_B = '{}/{}_pine.png'.format(special_dir, o)

                torchvision.utils.save_image(real_image[0] / 2 + 0.5,
                                             save_path_A,
                                             padding=0)
                torchvision.utils.save_image(reco_image[0] / 2 + 0.5,
                                             save_path_B,
                                             padding=0)

        encoder.train()
        generator.train()
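
Utils.reconstruct is the encode-then-decode round trip used for both the grid and the individual images. Consistent with the warm-up loop in Example #1, a sketch is simply the following (relying on the same module-level args and utils the snippets assume):

    def reconstruct(x, encoder, generator, session):
        # Encode, peel the label off the latent, then decode again.
        z = encoder(x, session.phase, session.alpha, args.use_ALQ).detach()
        z, label = utils.split_labels_out_of_latent(z)
        return generator(z, label, session.phase, session.alpha).detach()
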
Example #8
    def generate_intermediate_samples(generator, global_i, session, writer=None, collateImages = True):
        generator.eval()
        with torch.no_grad():        

            save_root = args.sample_dir if args.sample_dir is not None else args.save_dir

            utils.requires_grad(generator, False)
            
            # Total number is samplesRepeatN * colN * rowN
            # e.g. for 51200 samples, outcome is 5*80*128. Please only give multiples of 128 here.
            samplesRepeatN = max(1, int(args.sample_N / 128)) if not collateImages else 1
            reso = session.getReso()

            if not collateImages:
                special_dir = '{}/sample/{}'.format(save_root, str(global_i).zfill(6))
                while os.path.exists(special_dir):
                    special_dir += '_'

                os.makedirs(special_dir)
            
            for outer_count in range(samplesRepeatN):
                
                colN = 1 if not collateImages else min(10, int(np.ceil(args.sample_N / 4.0)))
                rowN = 128 if not collateImages else min(5, int(np.ceil(args.sample_N / 4.0)))
                images = []
                myzs = []
                for _ in range(rowN):
                    myz = {}
                    for i in range(2):
                        myz0 = Variable(torch.randn(args.n_label * colN, args.nz+1)).to(device=args.device)
                        myz[i] = utils.normalize(myz0)
                        #myz[i], input_class = utils.split_labels_out_of_latent(myz0)

                    if args.randomMixN <= 1:
                        new_imgs = generator(
                            myz[0],
                            None, # input_class,
                            session.phase,
                            session.alpha).detach() #.data.cpu()
                    else:
                        max_i = session.phase
                        style_layer_begin = 2 #np.random.randint(low=1, high=(max_i+1))
                        print("Cut at layer {} for the next {} random samples.".format(style_layer_begin, (myz[0].shape)))
                        
                        new_imgs = utils.gen_seq([ (myz[0], 0, style_layer_begin),
                                                   (myz[1], style_layer_begin, -1)
                                                 ], generator, session).detach()
                        
                    images.append(new_imgs)
                    myzs.append(myz[0]) # TODO: Support mixed latents with rejection sampling

                if args.rejection_sampling > 0 and not collateImages:
                    filtered_images = []
                    batch_size = 8
                    for b in range(int(len(images) / batch_size)):
                        img_batch = images[b*batch_size:(b+1)*batch_size]
                        reco_z = session.encoder(Variable(torch.cat(img_batch,0)), session.getResoPhase(), session.alpha, args.use_ALQ).detach()
                        cost_z = utils.mismatchV(torch.cat(myzs[b*batch_size:(b+1)*batch_size],0), reco_z, 'cos')
                        _, ii = torch.sort(cost_z)
                        keeper = args.rejection_sampling / 100.0
                        keepN = max(1, int(len(ii)*keeper))
                        for iindex in ii[:keepN]:  # retain only the best-matching samples (assumes one image per list element)
                            filtered_images.append(img_batch[iindex])
                    images = filtered_images

                if collateImages:
                    sample_dir = '{}/sample'.format(save_root)
                    if not os.path.exists(sample_dir):
                        os.makedirs(sample_dir)
                    
                    save_path = '{}/{}.png'.format(sample_dir,str(global_i + 1).zfill(6))
                    torchvision.utils.save_image(
                        torch.cat(images, 0),
                        save_path,
                        nrow=args.n_label * colN,
                        normalize=True,
                        range=(-1, 1),
                        padding=0)
                    # Hacky but this is an easy way to rescale the images to nice big lego format:
                    im = Image.open(save_path)
                    im2 = im.resize((1024, 512 if reso < 256 else 1024))
                    im2.save(save_path)

                    #if writer:
                    #    writer.add_image('samples_latest_{}'.format(session.phase), torch.cat(images, 0), session.phase)
                else:
                    print("Generating samples at {}...".format(special_dir))
                    for ii, img in enumerate(images):
                        ipath =  '{}/{}_{}.png'.format(special_dir, str(global_i + 1).zfill(6), ii+outer_count*128)
                        #print(ipath)
                        torchvision.utils.save_image(
                            img,
                            ipath,
                            nrow=args.n_label * colN,
                            normalize=True,
                            range=(-1, 1),
                            padding=0)

        generator.train()
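
The rejection-sampling branch keeps only the fraction of samples whose re-encoded latent lies closest, in cosine distance, to the latent that generated them. Stripped of the list batching, the selection rule amounts to the following sketch (mismatchV with 'cos' is assumed to return such a per-sample distance):

    import torch
    import torch.nn.functional as F

    def select_best(images, z_orig, z_reco, keep_percent):
        # images: [N, C, H, W]; z_orig, z_reco: [N, nz]
        cost = 1.0 - F.cosine_similarity(z_orig, z_reco, dim=1)
        _, order = torch.sort(cost)                 # ascending: best matches first
        keep_n = max(1, int(len(order) * keep_percent / 100.0))
        return images[order[:keep_n]]
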
Example #9
    def interpolate_images(generator, encoder, loader, epoch, prefix, session, writer=None):
        generator.eval()
        encoder.eval()

        with torch.no_grad():
            utils.requires_grad(generator, False)
            utils.requires_grad(encoder, False)

            nr_of_imgs = 4 if not args.hexmode else 6 # "Corners"
            reso = session.getReso()
            if True:
            #if Utils.interpolation_set_x is None or Utils.interpolation_set_x.size(2) != reso or (phase >= 1 and alpha < 1.0):
                if session.getResoPhase() < 1:
                    dataset = data.Utils.sample_data(loader, nr_of_imgs, reso)
                else:
                    dataset = data.Utils.sample_data2(loader, nr_of_imgs, reso, session)
                real_image, _ = next(dataset)
                Utils.interpolation_set_x = Variable(real_image, volatile=True).to(device=args.device)

            latent_reso_hor = 8
            latent_reso_ver = 8

            x = Utils.interpolation_set_x

            if args.hexmode:
                x = x[:nr_of_imgs] #Corners
                z0 = encoder(Variable(x), session.getResoPhase(), session.alpha, args.use_ALQ).detach()

                X = np.array([-1.7321,0.0000 ,1.7321 ,-2.5981,-0.8660,0.8660 ,2.5981 ,-3.4641,-1.7321,0.0000 ,1.7321 ,3.4641 ,-2.5981,-0.8660,0.8660 ,2.5981 ,-1.7321,0.0000,1.7321])
                Y = np.array([-3.0000,-3.0000,-3.0000,-1.5000,-1.5000,-1.5000,-1.5000,0.0000,0.0000,0.0000,0.0000,0.0000,1.5000,1.5000,1.5000,1.5000,3.0000,3.0000,3.0000])
                corner_indices = np.array([0, 2, 7, 11, 16, 18])
                inter_indices = [i for i in range(19) if not i in corner_indices]
                edge_indices = np.array([1, 3, 6, 12, 15, 17])
                #distances_to_corners = np.sqrt(np.power(X - X[corner_indices[0]], 2) + np.power(Y - Y[corner_indices[0]], 2))
                distances_to_corners = [np.sqrt(np.power(X - X[corner_indices[i]], 2) + np.power(Y - Y[corner_indices[i]], 2)) for i in range(6)]
                weights = 1.0 - distances_to_corners / np.max(distances_to_corners)
                z0_x = torch.zeros((19,args.nz+args.n_label)).to(device=args.device)
                z0_x[corner_indices,:] = z0
                z0_x[edge_indices[0], :] = Utils.slerp(z0[0], z0[1], 0.5)
                z0_x[edge_indices[1], :] = Utils.slerp(z0[0], z0[2], 0.5)
                z0_x[edge_indices[2], :] = Utils.slerp(z0[1], z0[3], 0.5)
                z0_x[edge_indices[3], :] = Utils.slerp(z0[2], z0[4], 0.5)
                z0_x[edge_indices[4], :] = Utils.slerp(z0[3], z0[5], 0.5)
                z0_x[edge_indices[5], :] = Utils.slerp(z0[4], z0[5], 0.5)
                # Linear:
                #z0x[inter_indices,:] = (weights.T @ z0)[inter_indices,:]
                z0_x[inter_indices,:] = (torch.from_numpy(weights.T.astype(np.float32)).to(device=args.device) @ z0)[inter_indices,:]
                z0_x = utils.normalize(z0_x)
            else:
                z0 = encoder(Variable(x), session.getResoPhase(), session.alpha, args.use_ALQ).detach()

            t = torch.FloatTensor(latent_reso_hor * (latent_reso_ver+1) + nr_of_imgs, x.size(1),
                                x.size(2), x.size(3))
            t[0:nr_of_imgs] = x.data[:]

            save_root = args.sample_dir if args.sample_dir is not None else args.save_dir
            special_dir = save_root if not args.aux_outpath else args.aux_outpath

            if not os.path.exists(special_dir):
                os.makedirs(special_dir)

            for o_i in range(nr_of_imgs):
                single_save_path = '{}{}/hex_interpolations_{}_{}_{}_orig_{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha, o_i)
                grid = torchvision.utils.save_image(x.data[o_i] / 2 + 0.5, single_save_path, nrow=1, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?

            if args.hexmode:
                z0_x, label = utils.split_labels_out_of_latent(z0_x)
                gex = generator(z0_x, label, session.phase, session.alpha).detach()                

                for x_i in range(19):
                    single_save_path = '{}{}/hex_interpolations_{}_{}_{}x{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha, x_i)
                    grid = torchvision.utils.save_image(gex.data[x_i] / 2 + 0.5, single_save_path, nrow=1, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?

                return

            # Origs on the first row here
            # Corners are: z0[0] ... z0[1]
            #                .   
            #                .
            #              z0[2] ... z0[3]                

            delta_z_ver0 = ((z0[2] - z0[0]) / (latent_reso_ver - 1))
            delta_z_verN = ((z0[3] - z0[1]) / (latent_reso_ver - 1))

            for y_i in range(latent_reso_ver):
                if False: #Linear interpolation
                    z0_x0 = z0[0] + y_i * delta_z_ver0
                    z0_xN = z0[1] + y_i * delta_z_verN
                    delta_z_hor = (z0_xN - z0_x0) / (latent_reso_hor - 1)
                    z0_x = Variable(torch.FloatTensor(latent_reso_hor, z0_x0.size(0)))

                    for x_i in range(latent_reso_hor):
                        z0_x[x_i] = z0_x0 + x_i * delta_z_hor

                if True: #Spherical
                    t_y = float(y_i) / (latent_reso_ver-1)
                    #z0_y = Variable(torch.FloatTensor(latent_reso_ver, z0.size(0)))
                    
                    z0_y1 = Utils.slerp(z0[0].data.cpu().numpy(), z0[2].data.cpu().numpy(), t_y)
                    z0_y2 = Utils.slerp(z0[1].data.cpu().numpy(), z0[3].data.cpu().numpy(), t_y)
                    z0_x = Variable(torch.FloatTensor(latent_reso_hor, z0[0].size(0)))
                    for x_i in range(latent_reso_hor):
                        t_x = float(x_i) / (latent_reso_hor-1)
                        z0_x[x_i] = torch.from_numpy( Utils.slerp(z0_y1, z0_y2, t_x) )

                z0_x, label = utils.split_labels_out_of_latent(z0_x)
                gex = generator(z0_x, label, session.phase, session.alpha).detach()                
                
                # Recall that yi=0 is the original's row:
                t[(y_i+1) * latent_reso_ver:(y_i+2)* latent_reso_ver] = gex.data[:]

                for x_i in range(latent_reso_hor):
                    single_save_path = '{}{}/interpolations_{}_{}_{}_{}x{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha, y_i, x_i)
                    grid = torchvision.utils.save_image(gex.data[x_i] / 2 + 0.5, single_save_path, nrow=1, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?
            
            save_path = '{}{}/interpolations_{}_{}_{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha)
            grid = torchvision.utils.save_image(t / 2 + 0.5, save_path, nrow=latent_reso_ver, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?
            # Hacky but this is an easy way to rescale the images to nice big lego format:
            if session.getResoPhase() < 4:
                im = Image.open(save_path)
                im2 = im.resize((1024, 1024))
                im2.save(save_path)        

            #if writer:
            #    writer.add_image('interpolation_latest_{}'.format(session.phase), t / 2 + 0.5, session.phase)

        generator.train()
        encoder.train()
Example #10
    def eval_metrics(generator, encoder, loader, global_i, nr_of_imgs, prefix, reals, reconstructions, session,
                           writer=None):  # prefix of the form "/[dir]"
        generator.eval()
        encoder.eval()

        with torch.no_grad():
            utils.requires_grad(generator, False)
            utils.requires_grad(encoder, False)

            if reconstructions and nr_of_imgs > 0:
                reso = session.getReso()

                batch_size = 16
                n_batches = (nr_of_imgs+batch_size-1) // batch_size
                n_used_imgs = n_batches * batch_size


                #fid_score = 0
                lpips_score = 0
                lpips_model = lpips.PerceptualLoss(model='net-lin', net='alex', use_gpu=False)
                real_images = []
                reco_images = []
                rand_images = []

                FIDm = 3 # FID multiplier

                if session.getResoPhase() >= 3:
                    for o in range(n_batches*FIDm):
                        myz = Variable(torch.randn(args.n_label * batch_size, args.nz)).to(device=args.device)
                        myz = utils.normalize(myz)
                        myz, input_class = utils.split_labels_out_of_latent(myz)

                        random_image = generator(
                                myz,
                                input_class,
                                session.phase,
                                session.alpha).detach().data.cpu()

                        rand_images.append(random_image)

                        if o >= n_batches:  # Reconstructions only up to n_batches, FIDs will take n_batches * 3 samples
                            continue

                        dataset = data.Utils.sample_data2(loader, batch_size, reso, session)

                        real_image, _ = next(dataset)
                        reco_image = Utils.reconstruct(real_image, encoder, generator, session)

                        t = torch.FloatTensor(real_image.size(0) * 2, real_image.size(1),
                                              real_image.size(2), real_image.size(3))

                        # Compute metrics and write them to TensorBoard (if the reso >= 32)

                        crop_needed = (args.data == 'celeba' or args.data == 'celebaHQ' or args.data == 'ffhq')

                        if session.getResoPhase() >= 4 or not crop_needed: # 32x32 is minimum for LPIPS
                            if crop_needed:
                                real_image = Utils.face_as_cropped(real_image)
                                reco_image = Utils.face_as_cropped(reco_image)

                        real_images.append(real_image)
                        reco_images.append(reco_image.detach().data.cpu())

                    real_images = torch.cat(real_images, 0)
                    reco_images = torch.cat(reco_images, 0)
                    rand_images = torch.cat(rand_images, 0)


                    lpips_dist = lpips_model.forward(real_images, reco_images)
                    lpips_score = torch.mean(lpips_dist)

                    fid_score = calculate_fid_given_images(real_images, rand_images, batch_size=batch_size, cuda=False, dims=2048)

                    if writer:
                        writer.add_scalar('LPIPS', lpips_score, session.sample_i)
                        writer.add_scalar('FID', fid_score, session.sample_i)

                    print("{}: FID = {}, LPIPS = {}".format(session.sample_i, fid_score, lpips_score))

        encoder.train()
        generator.train()
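
calculate_fid_given_images presumably computes the standard Fréchet Inception Distance, FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^(1/2)), over Inception activations. The Fréchet part itself is only a few lines of numpy; a sketch of the formula (not of the library call) given two activation matrices:

    import numpy as np
    from scipy import linalg

    def frechet_distance(act_real, act_gen):
        # act_*: [N, dims] activation matrices
        mu1, mu2 = act_real.mean(axis=0), act_gen.mean(axis=0)
        s1 = np.cov(act_real, rowvar=False)
        s2 = np.cov(act_gen, rowvar=False)
        covmean = linalg.sqrtm(s1.dot(s2))
        if np.iscomplexobj(covmean):  # numerical noise can leave tiny imaginary parts
            covmean = covmean.real
        diff = mu1 - mu2
        return diff.dot(diff) + np.trace(s1 + s2 - 2.0 * covmean)
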
Example #11
    def reconstruct_images(generator, encoder, loader, global_i, nr_of_imgs, prefix, reals, reconstructions, session, writer=None):  # prefix of the form "/[dir]"
        generator.eval()
        encoder.eval()

        imgs_per_style_mixture_run = min(nr_of_imgs, 8)

        with torch.no_grad():
            utils.requires_grad(generator, False)
            utils.requires_grad(encoder, False)

            save_root = args.sample_dir if args.sample_dir is not None else args.save_dir

            if reconstructions and nr_of_imgs > 0:
                reso = session.getReso()

                # First, create the single grid

                if Utils.reconstruction_set_x is None or Utils.reconstruction_set_x.size(2) != reso or (session.getResoPhase() >= 1 and session.alpha < 1.0):
                    if session.getResoPhase() < 1:
                        dataset = data.Utils.sample_data(loader, min(nr_of_imgs, 16), reso)
                    else:
                        dataset = data.Utils.sample_data2(loader, min(nr_of_imgs, 16), reso, session)
                    Utils.reconstruction_set_x, _ = next(dataset)

                unstyledVersionDone = False
                # For 64x64: (0,1), (2,3) and (4,-1)
                # For 128x128: (2,3) and (3,4) fail (collapse to static image). Find out why.
                for sc in [(0,1), (0,2), (2,-1), (2,4), (4,5), (5,-1)]:
                    for style_i in range(imgs_per_style_mixture_run):

                        if not unstyledVersionDone:
                            scs = [(0, -1), sc]
                        else:
                            scs = [sc]                        

                        for (style_layer_begin, style_layer_end) in scs:
                            reco_image = Utils.reconstruct( Utils.reconstruction_set_x, encoder, generator, session,
                                                            style_i = style_i, style_layer_begin=style_layer_begin, style_layer_end=style_layer_end)


                            t = torch.FloatTensor(Utils.reconstruction_set_x.size(0) * 2, Utils.reconstruction_set_x.size(1),
                                                Utils.reconstruction_set_x.size(2), Utils.reconstruction_set_x.size(3))

                            NN = int(Utils.reconstruction_set_x.size(0) / 2)
                            t[0:NN,:,:,:] = Utils.reconstruction_set_x[:NN]
                            t[NN:NN*2,:,:,:] = reco_image[:NN, :, :,:]
                        
                            from pioneer.model import AdaNorm                       

                            save_path = '{}{}/reconstructions_{}_{}_{}_{}_{}_{}-{}.png'.format(save_root, prefix, session.phase, global_i, session.alpha, 'X' if AdaNorm.disable else 'ADA', style_i,style_layer_begin,style_layer_end)
                            grid = torchvision.utils.save_image(t[:nr_of_imgs] / 2 + 0.5, save_path, padding=0)
                        unstyledVersionDone = True

                # Rescale the images to nice big lego format:
                if session.getResoPhase() < 4:
                    h = np.ceil(nr_of_imgs / 8)
                    h_scale = min(1.0, h/8.0)
                    im = Image.open(save_path)
                    im2 = im.resize((1024, int(1024 * h_scale)))
                    im2.save(save_path)

                #if writer:
                #    writer.add_image('reconstruction_latest_{}'.format(session.phase), t[:nr_of_imgs] / 2 + 0.5, session.phase)

                # Second, create the Individual images:
                if session.getResoPhase() < 1:
                    dataset = data.Utils.sample_data(loader, 1, reso)
                else:
                    dataset = data.Utils.sample_data2(loader, 1, reso, session)

                special_dir = '{}/{}'.format(save_root if not args.aux_outpath else args.aux_outpath, str(global_i).zfill(6))

                if not os.path.exists(special_dir):
                    os.makedirs(special_dir)                

                print("Save images: Alpha={}, phase={}, images={}, at {}".format(session.alpha, session.getResoPhase(), nr_of_imgs, special_dir))

                lpips_dist = 0

                inf_time = []
                n_time = 0

                for o in range(nr_of_imgs):
                    if o % 500 == 0:
                        print(o)
                
                    real_image, _ = next(dataset)
                    start = time.time()
                    reco_image = Utils.reconstruct(real_image, encoder, generator, session)

                    end = time.time()

                    inf_time +=  [(end-start)]

                    t = torch.FloatTensor(real_image.size(0) * 2, real_image.size(1),
                                        real_image.size(2), real_image.size(3))      
                           
                    save_path_A = '{}/{}_orig.png'.format(special_dir, o)
                    save_path_B = '{}/{}_pine.png'.format(special_dir, o)
                    torchvision.utils.save_image(real_image[0] / 2 + 0.5, save_path_A, padding=0)
                    torchvision.utils.save_image(reco_image[0] / 2 + 0.5, save_path_B, padding=0)

                print("Mean elaped: {} or {} for bs={} with std={}".format(np.mean(inf_time), np.mean(inf_time)/len(real_image), len(real_image), np.std(inf_time)  ))

        encoder.train()
        generator.train()
Example #12
def train(generator, encoder, g_running, train_data_loader, test_data_loader,
          session, total_steps, train_mode, sched):
    pbar = tqdm(initial=session.sample_i, total=total_steps)

    benchmarking = False

    match_x = args.match_x
    generatedImagePool = None

    refresh_dataset = True
    refresh_imagePool = True
    refresh_adaptiveLoss = False

    # After the Loading stage, we cycle through successive Fade-in and Stabilization stages

    batch_count = 0

    reset_optimizers_on_phase_start = False

    # TODO Unhack this (only affects the episode count statistics anyway):
    if args.data != 'celebaHQ':
        epoch_len = len(train_data_loader(1, 4).dataset)
    else:
        epoch_len = train_data_loader._len['data4x4']

    if args.step_offset != 0:
        if args.step_offset == -1:
            args.step_offset = session.sample_i
        print("Step offset is {}".format(args.step_offset))
        session.phase += args.phase_offset
        session.alpha = 0.0

    last_fade_done_at_reso = -1

    while session.sample_i < total_steps:
        #######################  Phase Maintenance #######################
        sched.update(session.sample_i)
        sample_i_current_stage = sched.get_iteration_of_current_phase(
            session.sample_i)

        if sched.phaseChangedOnLastUpdate:
            match_x = args.match_x  # Reset to non-matching phase
            refresh_dataset = True
            refresh_imagePool = True  # Reset the pool to avoid images of 2 different resolutions in the pool
            refresh_adaptiveLoss = True
            if reset_optimizers_on_phase_start:
                utils.requires_grad(generator)
                utils.requires_grad(encoder)
                generator.zero_grad()
                encoder.zero_grad()
                session.reset_opt()
                print("Optimizers have been reset.")

        reso = session.getReso()

        # If we can switch from fade-training to stable-training
        if sample_i_current_stage >= args.images_per_stage / 2:
            if session.alpha < 1.0:
                refresh_dataset = True  # refresh dataset generator since no longer have to fade
                last_fade_done_at_reso = reso
                session.alpha = 1
            match_x = args.match_x * args.matching_phase_x
        else:
            match_x = args.match_x

        # We track whether this resolution was already present in the previous stage, which means that it was already faded once.
        if last_fade_done_at_reso != reso:
            session.alpha = min(
                1, sample_i_current_stage * 2.0 / args.images_per_stage
            )  # For 100k, it was 0.00002 = 2.0 / args.images_per_stage

        if refresh_adaptiveLoss:
            session.prepareAdaptiveLossForNewPhase()
            refresh_adaptiveLoss = False
        if refresh_dataset:
            train_dataset = data.Utils.sample_data2(train_data_loader,
                                                    batch_size(reso), reso,
                                                    session)
            refresh_dataset = False
            print("Refreshed dataset. Alpha={} and iteration={}".format(
                session.alpha, sample_i_current_stage))
        if refresh_imagePool:
            imagePoolSize = 200 if reso < 256 else 100
            generatedImagePool = utils.ImagePool(
                imagePoolSize
            )  #Reset the pool to avoid images of 2 different resolutions in the pool
            refresh_imagePool = False
            print('Image pool created with size {} because reso is {}'.format(
                imagePoolSize, reso))
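        # The pool buffers previously generated fakes so the encoder can also
        # train against a history of samples rather than only the newest
        # batch (the usual role of an ImagePool).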

        ####################### Training init #######################

        stats = {}
        stats['z_mix_reconstruction_error'] = 0

        try:
            real_image, _ = next(train_dataset)
        except (OSError, StopIteration):
            train_dataset = data.Utils.sample_data2(train_data_loader,
                                                    batch_size(reso), reso,
                                                    session)
            real_image, _ = next(train_dataset)

        ####################### DISCRIMINATOR / ENCODER ###########################

        utils.switch_grad_updates_to_first_of(encoder, generator)
        kls = encoder_train(session,
                            real_image,
                            generatedImagePool,
                            batch_size(reso),
                            match_x,
                            stats,
                            "",
                            margin=sched.m)

        ######################## GENERATOR / DECODER #############################

        if (batch_count + 1) % args.n_critic == 0:
            utils.switch_grad_updates_to_first_of(generator, encoder)

            for _ in range(args.n_generator):
                kls = decoder_train(session, batch_size(reso), stats, kls,
                                    real_image.data)

            accumulate(g_running, generator)
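            # accumulate() folds the live generator weights into g_running;
            # in comparable progressive-GAN codebases this is an exponential
            # moving average (e.g. decay 0.999), and the periodic tests below
            # sample from g_running for smoother outputs.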

        del real_image

        ########################  Statistics ########################

        if args.use_TB:
            for key, val in stats.items():
                writer.add_scalar(key, val, session.sample_i)
        elif batch_count % 100 == 0:
            print(stats)

        if args.use_TB:
            writer.add_scalar('LOD',
                              session.getResoPhase() + session.alpha,
                              session.sample_i)

        b = session.getBatchSize()
        if 'z_reconstruction_error' in stats and 'x_reconstruction_error' in stats:
            zr, xr = stats['z_reconstruction_error'], stats['x_reconstruction_error']
        else:
            zr, xr = 0.0, 0.0
        e = (session.sample_i / float(epoch_len))

        pbar.set_description((
            '{0}; it: {1}; phase: {2}; b: {3:.1f}; Alpha: {4:.3f}; Reso: {5}; E: {6:.2f}; KL(Phi(x)/Phi(G0)/Phi(G1)/Phi(G2)): {7}; z-reco: {8:.2f}; x-reco {9:.3f}; real_var {10:.6f}; fake_var {11:.6f}; z-mix: {12:.4f};'
        ).format(batch_count + 1, session.sample_i + 1, session.phase, b,
                 session.alpha, reso, e, kls, zr, xr,
                 stats['real_var'], stats['fake_var'],
                 float(stats['z_mix_reconstruction_error'])))

        pbar.update(batch_size(reso))
        session.sample_i += batch_size(reso)  # if not benchmarking else 100
        batch_count += 1

        ########################  Saving ########################

        if batch_count % args.checkpoint_cycle == 0 or session.sample_i >= total_steps:
            for postfix in {str(session.sample_i).zfill(6)}:  # 'latest' could be added to this set
                session.save_all('{}/{}_state'.format(args.checkpoint_dir,
                                                      postfix))
                saveSNU('{}/{}_SNU'.format(args.checkpoint_dir, postfix))

            print("Checkpointed to {}".format(session.sample_i))

        ########################  Tests ########################

        try:
            evaluate.tests_run(
                g_running,
                encoder,
                test_data_loader,
                session,
                writer,
                reconstruction=(batch_count % 2400 == 0),
                interpolation=(batch_count % 2400 == 0),
                collated_sampling=(batch_count % 800 == 0),
                individual_sampling=False,  # disabled; was (batch_count % (args.images_per_stage / batch_size(reso) / 4) == 0)
                metric_eval=(batch_count % 2500 == 0))
        except (OSError, StopIteration):
            print("Skipped periodic tests due to an exception.")

    pbar.close()
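
The fade/stabilize bookkeeping in the loop above reduces to one pure function of the stage-local sample counter; a minimal sketch under the same convention (the fade occupies the first half of a stage of images_per_stage samples; the name fade_alpha is ours):

def fade_alpha(sample_in_stage, images_per_stage):
    # Linear ramp from 0 to 1 over the first half of the stage, clamped at 1.
    return min(1.0, sample_in_stage * 2.0 / images_per_stage)

# With images_per_stage = 100000:
#   fade_alpha(0, 100000)     -> 0.0  (fade just started)
#   fade_alpha(25000, 100000) -> 0.5  (quarter of the stage)
#   fade_alpha(50000, 100000) -> 1.0  (stabilization takes over)
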
Example no. 13
0
    def get_latent_presentation(encoder,
                                loader,
                                global_i,
                                nr_of_imgs,
                                prefix,
                                reals,
                                session,
                                writer=None):  # of the form "/[dir]"
        if args.latent_outpath is None:
            print(
                "Output folder for latent presentations is not defined - no output generated."
            )
            return

        #generator.eval()
        encoder.eval()

        #utils.requires_grad(generator, False)
        utils.requires_grad(encoder, False)

        # if reconstructions and nr_of_imgs > 0:
        reso = 4 * 2**session.phase

        # First, create the single grid

        if Utils.reconstruction_set_x is None or Utils.reconstruction_set_x.size(
                2) != reso or (session.phase >= 1 and session.alpha < 1.0):
            if session.phase < 1:
                dataset = data.Utils.sample_data(loader, min(nr_of_imgs, 16),
                                                 reso)
            else:
                dataset = data.Utils.sample_data2(loader, min(nr_of_imgs, 16),
                                                  reso, session)
            Utils.reconstruction_set_x, _ = next(dataset)

        # reco_image = Utils.reconstruct(Utils.reconstruction_set_x, encoder, generator, session)
        late_image = Utils.extract_latent_presentation_from_image(
            Utils.reconstruction_set_x, encoder, session)

        # Second, create the Individual images:
        if session.phase < 1:
            dataset = data.Utils.sample_data(loader, 1, reso)
        else:
            dataset = data.Utils.sample_data2(loader, 1, reso, session)

        # Define output directory
        special_dir = args.latent_outpath

        if not os.path.exists(special_dir):
            os.makedirs(special_dir)

        print("Save images: Alpha={}, phase={}, images={}, at {}".format(
            session.alpha, session.phase, nr_of_imgs, special_dir))
        for o in range(nr_of_imgs):
            if o % 500 == 0:
                print(o)

            real_image, _ = next(dataset)
            #reco_image = Utils.reconstruct(real_image, encoder, generator, session)
            late_image = Utils.extract_latent_presentation_from_image(
                real_image, encoder, session)

            # t = torch.FloatTensor(real_image.size(0) * 2, real_image.size(1),
            # real_image.size(2), real_image.size(3))

            save_path_A = '{}/{}_late_orig.png'.format(special_dir, o)
            #save_path_B = '{}/{}_pine.png'.format(special_dir, o)
            save_path_C = '{}/{}_late'.format(special_dir, o)
            # print("\nSave path is ")
            # print(save_path_C)
            torchvision.utils.save_image(real_image[0] / 2 + 0.5,
                                         save_path_A,
                                         padding=0)
            #torchvision.utils.save_image(reco_image[0] / 2 + 0.5, save_path_B, padding=0)
            # print(late_image.cpu().numpy())
            #np.save(save_path_C, late_image.cpu().numpy()) # this works
            torch.save(late_image, save_path_C)
            # with open(save_path_C, 'w') as file:
            # file.write(str(late_image.cpu().numpy()))
            #torchvision.utils.save_image(late_image[0], save_path_C, padding=0)

        encoder.train()
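
Each latent is serialized with torch.save to a file whose name ends in _late; the next example reads those files back with torch.load. A minimal round trip under that naming convention (the latent width 512 is a placeholder, not taken from this code):

import torch

z = torch.randn(1, 512)        # placeholder latent
torch.save(z, '/tmp/0_late')   # same '<index>_late' naming as above

z2 = torch.load('/tmp/0_late')
assert torch.equal(z, z2)      # lossless round trip
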
Example no. 14
0
    def reconstruct_images_from_latent(generator,
                                       encoder,
                                       global_i,
                                       nr_of_imgs,
                                       prefix,
                                       reals,
                                       session,
                                       writer=None):  # of the form "/[dir]"
        if args.latent_inpath is None:
            print(
                "Input folder for latent presentations is not defined - no output generated."
            )
            return
        generator.eval()
        encoder.eval()

        utils.requires_grad(generator, False)
        utils.requires_grad(encoder, False)

        print("Loading images")
        load_folder = args.latent_inpath
        latent_reconstructions = []
        filenames = []
        for file_path in os.listdir(load_folder):
            if file_path.endswith("_late"):
                # with open(os.path.join(load_folder,file_path), 'r') as file:
                ex = torch.load(os.path.join(load_folder, file_path))
                # print("Shape of ex is "+str(len(ex)))
                # ex = torch.from_numpy(ex)
                lare_image = Utils.reconstruct_from_latent_presentation(
                    ex, generator, session)
                latent_reconstructions.append(lare_image)
                for i in range(args.latent_reconstructions_N - 1):  # reconstruct N times
                    print("Reconstructing, round " + str(i) + " started")
                    lare_image = Utils.reconstruct(lare_image, encoder,
                                                   generator, session)
                    latent_reconstructions.append(lare_image)
                filenames.append(os.path.join(load_folder, file_path))

        # Define output directory
        special_dir = args.latent_inpath
        print("Special directory is:\n" + special_dir)

        if not os.path.exists(special_dir):
            os.makedirs(special_dir)

        print("Save images: Alpha={}, phase={}, images={}, at {}".format(
            session.alpha, session.phase, nr_of_imgs, special_dir))
        for o in range(len(latent_reconstructions)):
            if o % 500 == 0:
                print(o)

            # real_image, _ = next(dataset)
            # reco_image = Utils.reconstruct(real_image, encoder, generator, session)

            # t = torch.FloatTensor(real_image.size(0) * 2, real_image.size(1),
            # real_image.size(2), real_image.size(3))

            # save_path_A = '{}/{}_orig.png'.format(special_dir, o)
            # save_path_B = '{}/{}_pine.png'.format(special_dir, o)
            if args.latent_reconstructions_N > 1:
                save_path_C = '{}/{}_late_reconstructed.png'.format(
                    special_dir, o)
            else:
                save_path_C = filenames[o] + "_reconstructed.png"

            # torchvision.utils.save_image(real_image[0] / 2 + 0.5, save_path_A, padding=0)
            # torchvision.utils.save_image(reco_image[0] / 2 + 0.5, save_path_B, padding=0)
            torchvision.utils.save_image(latent_reconstructions[o] / 2 + 0.5,
                                         save_path_C,
                                         padding=0)

        generator.train()
        encoder.train()
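
With args.latent_reconstructions_N > 1, the loop above feeds each reconstruction back through the encoder and generator. In isolation the iteration looks like the sketch below (assuming a reconstruct(x) callable that wraps one encode-decode pass, as Utils.reconstruct does here):

def iterate_reconstruction(x, reconstruct, n_rounds):
    # Repeatedly push an image through encode -> decode and keep every
    # intermediate; useful for checking how quickly the autoencoder's
    # output settles toward a fixed point.
    outputs = [x]
    for _ in range(n_rounds - 1):
        x = reconstruct(x)
        outputs.append(x)
    return outputs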