Example 1
def decoder_train(session, batch_N, stats, kls, x):
    session.generator.zero_grad()

    if session.phase > 0:
        for p in session.generator.module.to_rgb[session.phase-1].parameters():
            p.requires_grad_(False)
 
    g_losses = []

    mix_ratio = args.stylemix_D
    mix_N = int(mix_ratio*batch_N) if session.phase > 0 else 0
    z = Variable( torch.FloatTensor(batch_N, args.nz, 1, 1) ).to(device=args.device)
    utils.populate_z(z, args.nz+args.n_label, args.noise, batch_N)

    if mix_N > 0:
        alt_mix_z = Variable( torch.FloatTensor(mix_N, args.nz, 1, 1) ).to(device=args.device)
        utils.populate_z(alt_mix_z, args.nz+args.n_label, args.noise, mix_N)
        alt_mix_z = torch.cat((alt_mix_z, z[mix_N:,:]), dim=0) if mix_N < z.size()[0] else alt_mix_z
    else:
        alt_mix_z = None   

    # KL is calculated from the distribution of z's re-encoded from z and the mixed z
    egz, kl, egz_intermediate, z_intermediate = KL_of_encoded_G_output(session.generator, session.encoder, z, batch_N, session, alt_mix_z, mix_N)

    if args.use_loss_KL_z:
        g_losses.append(kl) # G minimizes this KL
        stats['KL(Phi(G))'] = kl.data.item()
        kls = "{0}/{1:.3f}".format(kls, stats['KL(Phi(G))'])

    # z_diff is calculated only from the regular z (not mixed z)
    if args.use_loss_z_reco: #and (mix_N == 0 or session.phase == 0):
        z_diff = utils.mismatch(egz[mix_N:,:], z[mix_N:,:], args.match_z_metric) * args.match_z # G tries to make the original z and encoded z match #Alternative: [mix_N:,:]
        z_mix_diff = utils.mismatch(egz_intermediate.view([mix_N,-1]), z_intermediate.view([mix_N,-1]), 'L2') if mix_N>0 else torch.zeros(1).cuda()
        if args.intermediate_zreco > 0:
            g_losses.append(z_mix_diff)
        g_losses.append(z_diff)
        stats['z_reconstruction_error'] = z_diff.data.item()
        if args.intermediate_zreco > 0:
            stats['z_mix_reconstruction_error'] = z_mix_diff.data.item()

    if len(g_losses) > 0:
        loss = sum(g_losses)
        stats['G_loss'] = np.float32(loss.data.cpu())

        loss.backward()

        if False: #For debugging the adanorm blocks
            from model import AdaNorm
            adaparams0 = list(session.generator.adanorm_blocks[0].mod.parameters())
            param_scale  = np.linalg.norm(adaparams0[0].detach().cpu().numpy().ravel())
            print("Ada norm: {} / {}".format(np.linalg.norm(adaparams0[0].grad.detach().cpu().numpy().ravel()), param_scale ))

    session.optimizerG.step()

    return kls
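
The helpers from utils used above (populate_z, mismatch, ...) are not part of this listing. As a rough reference, here is a minimal sketch of what utils.mismatch(a, b, metric) might look like, assuming it is simply a batch-averaged elementwise distance. The 'L1' and 'L2' metric names appear in the calls above; the cosine branch is an extra assumption, and none of this is the repository's actual implementation.

# Hypothetical sketch of utils.mismatch -- not the repository's implementation.
import torch.nn.functional as F

def mismatch(a, b, metric='L2'):
    """Batch-averaged distance between two tensors of identical shape."""
    a_flat = a.reshape(a.size(0), -1)
    b_flat = b.reshape(b.size(0), -1)
    if metric == 'L1':
        return (a_flat - b_flat).abs().mean()
    if metric == 'L2':
        return ((a_flat - b_flat) ** 2).mean()
    if metric == 'cos':
        # 1 - cosine similarity, averaged over the batch
        return (1.0 - F.cosine_similarity(a_flat, b_flat, dim=1)).mean()
    raise ValueError('Unknown metric: {}'.format(metric))
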
Example 2
# Nested helper from inside the training loop (see Example 3); it closes over
# session, encoder, args, reso, batch_size and KL_minimizer from that scope.
def KL_of_encoded_G_output(generator, z):
    utils.populate_z(z, args.nz + args.n_label, args.noise, batch_size(reso))
    z, label = utils.split_labels_out_of_latent(z)
    fake = generator(z, label, session.phase, session.alpha)

    egz = encoder(fake, session.phase, session.alpha, args.use_ALQ)
    # KL_fake: \Delta( e(g(Z)) , Z ) -> min_g
    return egz, label, KL_minimizer(egz) * args.fake_G_KL_scale, z
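
KL_minimizer here is a KLN01Loss instance created in the surrounding training loop (Examples 3 and 4). Below is a minimal sketch of such a criterion, assuming it fits a diagonal Gaussian to the batch of codes and measures its KL divergence from N(0, I); the samples_mean / samples_var attributes mirror what the training loops read for logging. This is an illustration of the idea, not the repository's code, and the handling of the direction argument is simplified.

# Hypothetical sketch of KLN01Loss -- an illustration, not the repository's code.
import torch
import torch.nn as nn

class KLN01Loss(nn.Module):
    """KL divergence between the per-batch Gaussian fit of a set of latent
    codes and the unit Gaussian N(0, I)."""

    def __init__(self, direction='qp', minimize=True):
        super().__init__()
        self.direction = direction   # kept only for interface compatibility here
        self.minimize = minimize
        self.samples_mean = None
        self.samples_var = None

    def forward(self, samples):
        samples = samples.view(samples.size(0), -1)
        self.samples_mean = samples.mean(dim=0)
        self.samples_var = samples.var(dim=0, unbiased=False)
        mu, var = self.samples_mean, self.samples_var
        # KL( N(mu, var) || N(0, 1) ), averaged over latent dimensions
        kl = 0.5 * (var + mu ** 2 - 1.0 - torch.log(var + 1e-8)).mean()
        # The maximizer variant returns the negated value so that both
        # directions can simply be summed into a loss that is minimized.
        return kl if self.minimize else -kl
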
Example 3
def train(generator, encoder, g_running, train_data_loader, test_data_loader,
          session, total_steps, train_mode):
    pbar = tqdm(initial=session.sample_i, total=total_steps)

    benchmarking = False

    match_x = args.match_x
    generatedImagePool = None

    refresh_dataset = True
    refresh_imagePool = True

    # After the Loading stage, we cycle through successive Fade-in and Stabilization stages

    batch_count = 0

    reset_optimizers_on_phase_start = False

    # TODO Unhack this (only affects the episode count statistics anyway):
    if args.data != 'celebaHQ':
        epoch_len = len(train_data_loader(1, 4).dataset)
    else:
        epoch_len = train_data_loader._len['data4x4']

    if args.step_offset != 0:
        if args.step_offset == -1:
            args.step_offset = session.sample_i
        print("Step offset is {}".format(args.step_offset))
        session.phase += args.phase_offset
        session.alpha = 0.0

    while session.sample_i < total_steps:
        #######################  Phase Maintenance #######################

        steps_in_previous_phases = max(session.phase * args.images_per_stage,
                                       args.step_offset)

        sample_i_current_stage = session.sample_i - steps_in_previous_phases

        # If we can move to the next phase
        if sample_i_current_stage >= args.images_per_stage:
            if session.phase < args.max_phase:  # If any phases left
                iteration_levels = int(sample_i_current_stage /
                                       args.images_per_stage)
                session.phase += iteration_levels
                sample_i_current_stage -= iteration_levels * args.images_per_stage
                match_x = args.match_x  # Reset to non-matching phase
                print(
                    "iteration B alpha={} phase {} will be reduced to 1 and [max]"
                    .format(sample_i_current_stage, session.phase))

                refresh_dataset = True
                refresh_imagePool = True  # Reset the pool to avoid images of 2 different resolutions in the pool

                if reset_optimizers_on_phase_start:
                    utils.requires_grad(generator)
                    utils.requires_grad(encoder)
                    generator.zero_grad()
                    encoder.zero_grad()
                    session.reset_opt()
                    print("Optimizers have been reset.")

        reso = 4 * 2**session.phase
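        # (progressive growing: 4x4 at phase 0, with the resolution doubling at each phase)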

        # If we can switch from fade-training to stable-training
        if sample_i_current_stage >= args.images_per_stage / 2:
            if session.alpha < 1.0:
                refresh_dataset = True  # refresh dataset generator since no longer have to fade
            match_x = args.match_x * args.matching_phase_x
        else:
            match_x = args.match_x

        session.alpha = min(
            1, sample_i_current_stage * 2.0 / args.images_per_stage
        )  # For 100k, it was 0.00002 = 2.0 / args.images_per_stage

        if refresh_dataset:
            train_dataset = data.Utils.sample_data2(train_data_loader,
                                                    batch_size(reso), reso,
                                                    session)
            refresh_dataset = False
            print("Refreshed dataset. Alpha={} and iteration={}".format(
                session.alpha, sample_i_current_stage))
        if refresh_imagePool:
            imagePoolSize = 200 if reso < 256 else 100
            generatedImagePool = utils.ImagePool(
                imagePoolSize
            )  #Reset the pool to avoid images of 2 different resolutions in the pool
            refresh_imagePool = False
            print('Image pool created with size {} because reso is {}'.format(
                imagePoolSize, reso))

        ####################### Training init #######################

        z = utils.conditional_to_cuda(
            Variable(torch.FloatTensor(batch_size(reso), args.nz, 1, 1)),
            args.gpu_count > 1)
        KL_minimizer = KLN01Loss(direction=args.KL, minimize=True)
        KL_maximizer = KLN01Loss(direction=args.KL, minimize=False)

        stats = {}

        #one = torch.FloatTensor([1]).cuda(async=(args.gpu_count>1))
        one = torch.FloatTensor([1]).cuda(non_blocking=(args.gpu_count > 1))

        try:
            real_image, _ = next(train_dataset)
        except (OSError, StopIteration):
            train_dataset = data.Utils.sample_data2(train_data_loader,
                                                    batch_size(reso), reso,
                                                    session)
            real_image, _ = next(train_dataset)

        ####################### DISCRIMINATOR / ENCODER ###########################

        utils.switch_grad_updates_to_first_of(encoder, generator)
        encoder.zero_grad()

        #x = Variable(real_image).cuda(async=(args.gpu_count>1))
        x = Variable(real_image).cuda(non_blocking=(args.gpu_count > 1))
        kls = ""
        if train_mode == config.MODE_GAN:

            # Discriminator for real samples
            real_predict, _ = encoder(x, session.phase, session.alpha,
                                      args.use_ALQ)
            real_predict = real_predict.mean() \
                - 0.001 * (real_predict ** 2).mean()
            real_predict.backward(-one)  # Towards 1

            # (1) Generator => D. Identical to (2) see below

            fake_predict, fake_image = D_prediction_of_G_output(
                generator, encoder, session.phase, session.alpha)
            fake_predict.backward(one)

            # Grad penalty

            grad_penalty = get_grad_penalty(encoder, x, fake_image,
                                            session.phase, session.alpha)
            grad_penalty.backward()

        elif train_mode == config.MODE_CYCLIC:
            e_losses = []

            # e(X)

            real_z = encoder(x, session.phase, session.alpha, args.use_ALQ)
            if args.use_real_x_KL:
                # KL_real: - \Delta( e(X) , Z ) -> max_e
                KL_real = KL_minimizer(real_z) * args.real_x_KL_scale
                e_losses.append(KL_real)

                stats['real_mean'] = KL_minimizer.samples_mean.data.mean()
                stats['real_var'] = KL_minimizer.samples_var.data.mean()
                stats['KL_real'] = KL_real.item()
                kls = "{0:.3f}".format(stats['KL_real'])

            # The final entries are the label. Normal case, just 1. Extract it/them, and make it [b x 1]:

            real_z, label = utils.split_labels_out_of_latent(real_z)
            recon_x = generator(real_z, label, session.phase, session.alpha)
            if args.use_loss_x_reco:
                # match_x: E_x||g(e(x)) - x|| -> min_e
                err = utils.mismatch(recon_x, x, args.match_x_metric) * match_x
                e_losses.append(err)
                stats['x_reconstruction_error'] = err.item()

            args.use_wpgan_grad_penalty = False
            grad_penalty = 0.0

            if args.use_loss_fake_D_KL:
                # TODO: The following codeblock is essentially the same as the KL_minimizer part on G side. Unify
                utils.populate_z(z, args.nz + args.n_label, args.noise,
                                 batch_size(reso))
                z, label = utils.split_labels_out_of_latent(z)
                fake = generator(z, label, session.phase,
                                 session.alpha).detach()

                if session.alpha >= 1.0:
                    fake = generatedImagePool.query(fake.data)

                # e(g(Z))
                egz = encoder(fake, session.phase, session.alpha, args.use_ALQ)

                # KL_fake: \Delta( e(g(Z)) , Z ) -> max_e
                KL_fake = KL_maximizer(egz) * args.fake_D_KL_scale

                # Added by Igor
                m = args.kl_margin
                if m > 0.0 and session.phase >= 5:
                    KL_loss = torch.min(
                        torch.ones_like(KL_fake) * m / 2,
                        torch.max(-torch.ones_like(KL_fake) * m,
                                  KL_real + KL_fake))
                    # KL_fake is always negative with abs value typically larger than KL_real.
                    # Hence, the sum is negative, and must be gapped so that the minimum is the negative of the margin.
                else:
                    KL_loss = KL_real + KL_fake
                e_losses.append(KL_loss)
                #e_losses.append(KL_fake)

                stats['fake_mean'] = KL_maximizer.samples_mean.data.mean()
                stats['fake_var'] = KL_maximizer.samples_var.data.mean()
                stats['KL_fake'] = -KL_fake.item()
                kls = "{0}/{1:.3f}".format(kls, stats['KL_fake'])

                if args.use_wpgan_grad_penalty:
                    grad_penalty = get_grad_penalty(encoder, x, fake,
                                                    session.phase,
                                                    session.alpha)

            # Update e
            if len(e_losses) > 0:
                e_loss = sum(e_losses)
                stats['E_loss'] = e_loss.detach().cpu().float().numpy()  # np.float32(e_loss.data)
                e_loss.backward()

                if args.use_wpgan_grad_penalty:
                    grad_penalty.backward()
                    stats['Grad_penalty'] = grad_penalty.data

                #book-keeping
                disc_loss_val = e_loss.item()

        session.optimizerD.step()

        torch.cuda.empty_cache()

        ######################## GENERATOR / DECODER #############################

        if (batch_count + 1) % args.n_critic == 0:
            utils.switch_grad_updates_to_first_of(generator, encoder)

            for _ in range(args.n_generator):
                generator.zero_grad()
                g_losses = []

                if train_mode == config.MODE_GAN:
                    fake_predict, _ = D_prediction_of_G_output(
                        generator, encoder, session.phase, session.alpha)
                    loss = -fake_predict
                    g_losses.append(loss)

                elif train_mode == config.MODE_CYCLIC:  #TODO We push the z variable around here like idiots

                    def KL_of_encoded_G_output(generator, z):
                        utils.populate_z(z, args.nz + args.n_label, args.noise,
                                         batch_size(reso))
                        z, label = utils.split_labels_out_of_latent(z)
                        fake = generator(z, label, session.phase,
                                         session.alpha)

                        egz = encoder(fake, session.phase, session.alpha,
                                      args.use_ALQ)
                        # KL_fake: \Delta( e(g(Z)) , Z ) -> min_g
                        return egz, label, KL_minimizer(
                            egz) * args.fake_G_KL_scale, z

                    egz, label, kl, z = KL_of_encoded_G_output(generator, z)

                    if args.use_loss_KL_z:
                        g_losses.append(kl)  # G minimizes this KL
                        stats['KL(Phi(G))'] = kl.item()
                        kls = "{0}/{1:.3f}".format(kls, stats['KL(Phi(G))'])

                    if args.use_loss_z_reco:
                        z = torch.cat((z, label), 1)
                        z_diff = utils.mismatch(
                            egz, z, args.match_z_metric
                        ) * args.match_z  # G tries to make the original z and encoded z match
                        g_losses.append(z_diff)

                if len(g_losses) > 0:
                    loss = sum(g_losses)
                    stats['G_loss'] = loss.detach().cpu().float().numpy()  # np.float32(loss.data)
                    loss.backward()

                    # Book-keeping only:
                    gen_loss_val = loss.item()

                session.optimizerG.step()

                torch.cuda.empty_cache()

                if train_mode == config.MODE_CYCLIC:
                    if args.use_loss_z_reco:
                        stats['z_reconstruction_error'] = z_diff.item()

            accumulate(g_running, generator)
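            # g_running presumably tracks an exponential-moving-average copy of the
            # generator used for evaluation/sampling; see the accumulate() sketch after this example.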

            del z, x, one, real_image, real_z, KL_real, label, recon_x, fake, egz, KL_fake, kl, z_diff

            if train_mode == config.MODE_CYCLIC:
                if args.use_TB:
                    for key, val in stats.items():
                        writer.add_scalar(key, val, session.sample_i)
                elif batch_count % 100 == 0:
                    print(stats)

            if args.use_TB:
                writer.add_scalar('LOD', session.phase + session.alpha,
                                  session.sample_i)

        ########################  Statistics ########################

        b = batch_size_by_phase(session.phase)
        zr, xr = (stats['z_reconstruction_error'],
                  stats['x_reconstruction_error']
                  ) if train_mode == config.MODE_CYCLIC else (0.0, 0.0)
        e = (session.sample_i / float(epoch_len))
        pbar.set_description((
            '{0}; it: {1}; phase: {2}; b: {3:.1f}; Alpha: {4:.3f}; Reso: {5}; E: {6:.2f}; KL(real/fake/fakeG): {7}; z-reco: {8:.2f}; x-reco {9:.3f}; real_var {10:.4f}'
        ).format(batch_count + 1, session.sample_i + 1, session.phase, b,
                 session.alpha, reso, e, kls, zr, xr, stats['real_var']))
        #(f'{i + 1}; it: {iteration+1}; b: {b:.1f}; G: {gen_loss_val:.5f}; D: {disc_loss_val:.5f};'
        # f' Grad: {grad_loss_val:.5f}; Alpha: {alpha:.3f}; Reso: {reso}; S-mean: {real_mean:.3f}; KL(real/fake/fakeG): {kls}; z-reco: {zr:.2f}'))

        pbar.update(batch_size(reso))
        session.sample_i += batch_size(reso)  # if not benchmarking else 100
        batch_count += 1

        ########################  Saving ########################

        if batch_count % args.checkpoint_cycle == 0:
            for postfix in {'latest', str(session.sample_i).zfill(6)}:
                session.save_all('{}/{}_state'.format(args.checkpoint_dir,
                                                      postfix))

            print("Checkpointed to {}".format(session.sample_i))

        ########################  Tests ########################

        try:
            evaluate.tests_run(
                g_running,
                encoder,
                test_data_loader,
                session,
                writer,
                reconstruction=(batch_count % 800 == 0),
                interpolation=(batch_count % 800 == 0),
                collated_sampling=(batch_count % 800 == 0),
                individual_sampling=(
                    batch_count %
                    (args.images_per_stage / batch_size(reso) / 4) == 0))
        except (OSError, StopIteration):
            print("Skipped periodic tests due to an exception.")

    pbar.close()
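
accumulate(g_running, generator) above is not defined in this listing. In progressive-growing GAN code this helper usually keeps g_running as an exponential moving average of the live generator's weights, and g_running is what gets passed to the periodic tests. A minimal sketch under that assumption follows; the decay value and the requirement that both models expose identically named parameters are assumptions.

# Hypothetical sketch of accumulate() -- assuming the usual EMA-of-weights helper.
def accumulate(model_ema, model, decay=0.999):
    """Update model_ema's parameters in place as an exponential moving
    average of model's parameters (both models must share parameter names)."""
    ema_params = dict(model_ema.named_parameters())
    live_params = dict(model.named_parameters())
    for name, param in live_params.items():
        ema_params[name].data.mul_(decay).add_(param.data, alpha=1 - decay)
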
Example 4
def encoder_train(session, real_image, generatedImagePool, batch_N, match_x,
                  stats, kls, margin):
    encoder = session.encoder
    generator = session.generator

    encoder.zero_grad()
    generator.zero_grad()

    x = Variable(real_image).to(device=args.device)  #async=(args.gpu_count>1))
    KL_maximizer = KLN01Loss(direction=args.KL, minimize=False)
    KL_minimizer = KLN01Loss(direction=args.KL, minimize=True)

    e_losses = []

    flipInvarianceLayer = args.flip_invariance_layer
    flipX = flipInvarianceLayer > -1 and session.phase >= flipInvarianceLayer

    phiAdaCotrain = args.phi_ada_cotrain

    if flipX:
        phiAdaCotrain = True

    #global adaparams
    if phiAdaCotrain:
        for b in generator.module.adanorm_blocks:
            if b is not None and b.mod is not None:
                for param in b.mod.parameters():
                    param.requires_grad_(True)

    if flipX:
        x_in = x[0:int(x.size()[0] / 2), :, :, :]
        x_mirror = x_in.clone().detach().requires_grad_(True).to(
            device=args.device).flip(dims=[3])
        x[int(x.size()[0] / 2):, :, :, :] = x_mirror

    real_z = encoder(x, session.getResoPhase(), session.alpha, args.use_ALQ)

    if args.use_real_x_KL:
        # KL_real: - \Delta( e(X) , Z ) -> max_e
        if not flipX:
            KL_real = KL_minimizer(real_z) * args.real_x_KL_scale
        else:  # Treat the KL div of each direction of the data as separate distributions
            z_in = real_z[0:int(x.size()[0] / 2), :]
            z_in_mirror = real_z[int(x.size()[0] / 2):, :]
            KL_real = (KL_minimizer(z_in) +
                       KL_minimizer(z_in_mirror)) * args.real_x_KL_scale / 2
        e_losses.append(KL_real)

        stats['real_mean'] = KL_minimizer.samples_mean.data.mean().item()
        stats['real_var'] = KL_minimizer.samples_var.data.mean().item()
        stats['KL_real'] = KL_real.data.item()
        kls = "{0:.3f}".format(stats['KL_real'])

    if flipX:
        x_reco_in = utils.gen_seq(
            [
                (
                    z_in, 0, flipInvarianceLayer
                ),  # Rotation of x_in, the ID of x_in_mirror (= the ID of x_in)
                (z_in_mirror, flipInvarianceLayer, -1)
            ],
            session.generator,
            session)

        x_reco_in_mirror = utils.gen_seq(
            [
                (z_in_mirror, 0, flipInvarianceLayer),  # Vice versa
                (z_in, flipInvarianceLayer, -1)
            ],
            session.generator,
            session)

        if args.match_x_metric == 'robust':
            loss_flip = torch.mean(session.adaptive_loss[session.getResoPhase()].lossfun((x_in - x_reco_in).view(-1, x_in.size()[1]*x_in.size()[2]*x_in.size()[3]) ))  * match_x * 0.2 + \
                        torch.mean(session.adaptive_loss[session.getResoPhase()].lossfun((x_mirror - x_reco_in_mirror).view(-1, x_mirror.size()[1]*x_mirror.size()[2]*x_mirror.size()[3]) ))  * match_x * 0.2
        else:
            loss_flip = (utils.mismatch(x_in, x_reco_in, args.match_x_metric) +
                         utils.mismatch(x_mirror, x_reco_in_mirror,
                                        args.match_x_metric)) * args.match_x

        loss_flip.backward(retain_graph=True)
        stats['loss_flip'] = loss_flip.data.item()
        stats['x_reconstruction_error'] = loss_flip.data.item()
        print('Flip loss: {}'.format(stats['loss_flip']))

    else:
        if args.use_loss_x_reco:
            recon_x = generator(real_z, None, session.phase, session.alpha)
            # match_x: E_x||g(e(x)) - x|| -> min_e

            if args.match_x_metric == 'robust':
                err_simple = utils.mismatch(recon_x, x, 'L1') * match_x
                err = torch.mean(
                    session.adaptive_loss[session.getResoPhase()].lossfun(
                        (recon_x - x).view(
                            -1,
                            x.size()[1] * x.size()[2] *
                            x.size()[3]))) * match_x * 0.2
                print("err vs. ROBUST err: {} / {}".format(err_simple, err))
            else:
                err_simple = utils.mismatch(recon_x, x,
                                            args.match_x_metric) * match_x
                err = err_simple

            if phiAdaCotrain:
                err.backward(retain_graph=True)
            else:
                e_losses.append(err)
            stats['x_reconstruction_error'] = err.data.item()

    if phiAdaCotrain:
        for b in session.generator.module.adanorm_blocks:
            if b is not None and b.mod is not None:
                for param in b.mod.parameters():
                    param.requires_grad_(False)

    if args.use_loss_fake_D_KL:
        # TODO: The following codeblock is essentially the same as the KL_minimizer part on G side. Unify

        mix_ratio = args.stylemix_E  #0.25
        mix_N = int(mix_ratio * batch_N)
        z = Variable(torch.FloatTensor(batch_N, args.nz, 1, 1)).to(
            device=args.device)  #async=(args.gpu_count>1))
        utils.populate_z(z, args.nz + args.n_label, args.noise, batch_N)

        if session.phase > 0 and mix_N > 0:
            alt_mix_z = Variable(torch.FloatTensor(mix_N, args.nz, 1, 1)).to(
                device=args.device)  #async=(args.gpu_count>1))
            utils.populate_z(alt_mix_z, args.nz + args.n_label, args.noise,
                             mix_N)
            alt_mix_z = torch.cat(
                (alt_mix_z,
                 z[mix_N:, :]), dim=0) if mix_N < z.size()[0] else alt_mix_z
        else:
            alt_mix_z = None

        with torch.no_grad():
            style_layer_begin = np.random.randint(
                low=1, high=session.phase + 1) if alt_mix_z is not None else -1
            fake = utils.gen_seq([(z, 0, style_layer_begin),
                                  (alt_mix_z, style_layer_begin, -1)],
                                 generator, session).detach()

        if session.alpha >= 1.0:
            fake = generatedImagePool.query(fake.data)

        fake.requires_grad_()

        # e(g(Z))
        egz = encoder(fake, session.getResoPhase(), session.alpha,
                      args.use_ALQ)

        # KL_fake: \Delta( e(g(Z)) , Z ) -> max_e
        KL_fake = KL_maximizer(egz) * args.fake_D_KL_scale

        m = margin
        if m > 0.0:
            KL_loss = torch.min(
                torch.ones_like(KL_fake) * m / 2,
                torch.max(-torch.ones_like(KL_fake) * m, KL_real + KL_fake)
            )  # KL_fake is always negative with abs value typically larger than KL_real. Hence, the sum is negative, and must be gapped so that the minimum is the negative of the margin.
        else:
            KL_loss = KL_real + KL_fake
        e_losses.append(KL_loss)

        stats['fake_mean'] = KL_maximizer.samples_mean.data.mean()
        stats['fake_var'] = KL_maximizer.samples_var.data.mean()
        stats['KL_fake'] = -KL_fake.data.item()
        stats['KL_loss'] = KL_loss.data.item()

        kls = "{0}/{1:.3f}".format(kls, stats['KL_fake'])

    # Update e
    if len(e_losses) > 0:
        e_loss = sum(e_losses)
        e_loss.backward()
        stats['E_loss'] = np.float32(e_loss.data.cpu())

    session.optimizerD.step()

    if flipX:
        # The AdaNorm params need to be updated separately since they are on the "generator side"
        session.optimizerA.step()

    return kls
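
Both training loops route detached fakes through generatedImagePool.query(...) once alpha reaches 1.0. utils.ImagePool is not shown in this listing; below is a minimal sketch in the spirit of the CycleGAN-style history buffer that such pools usually are, so the query() calls above have something concrete to read against. The class name matches the call sites, but the 50/50 replacement policy and the internals are assumptions.

# Hypothetical sketch of utils.ImagePool -- a CycleGAN-style history buffer.
import random
import torch

class ImagePool:
    """Keep up to pool_size previously generated images and, with probability
    0.5, answer a query with an older image instead of the newest one, so the
    encoder/discriminator also sees samples from earlier generator states."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        out = []
        for img in images:
            img = img.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(img)
                out.append(img)
            elif random.random() > 0.5:
                idx = random.randrange(len(self.images))
                out.append(self.images[idx].clone())
                self.images[idx] = img
            else:
                out.append(img)
        return torch.cat(out, dim=0)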