Example #1
def KL_of_encoded_G_output(generator, encoder, z, batch_N, session, alt_mix_z,
                           mix_N):
    KL_minimizer = KLN01Loss(direction=args.KL, minimize=True)
    #utils.populate_z(z, args.nz+args.n_label, args.noise, batch_N)

    assert (alt_mix_z is None
            or alt_mix_z.size()[0] == z.size()[0])  # Batch sizes must match
    mix_style_layer_begin = np.random.randint(
        low=1, high=(session.phase + 1)) if alt_mix_z is not None else -1
    z_intermediates = utils.gen_seq([(z, 0, mix_style_layer_begin),
                                     (alt_mix_z, mix_style_layer_begin, -1)],
                                    generator,
                                    session,
                                    retain_intermediate_results=True)

    assert z_intermediates[0].size()[0] == z.size()[0]  # Batch sizes must remain invariant
    fake = z_intermediates[1 if alt_mix_z is not None else 0]
    z_intermediate = z_intermediates[0][:mix_N, :]

    egz = encoder(fake, session.getResoPhase(), session.alpha, args.use_ALQ)

    if mix_style_layer_begin > -1:
        egz_intermediate = utils.gen_seq(
            [(egz[:mix_N, :], 0, mix_style_layer_begin)], generator,
            session)  # Or we could just call the generator directly, of course.
        assert (egz_intermediate.size()[0] == z_intermediate.size()[0])
    else:
        egz_intermediate = z_intermediate = None

    # KL_fake: \Delta( e(g(Z)) , Z ) -> min_g
    return egz, KL_minimizer(
        egz) * args.fake_G_KL_scale, egz_intermediate, z_intermediate
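
All four examples route generation through utils.gen_seq, whose definition is not listed here. The commented-out "unwrap" in Example #2 below suggests that it simply chains generator calls over consecutive style-layer ranges, feeding each stage's output into the next. A minimal sketch under that assumption (the generator signature, the skipping of None stages, and the retain_intermediate_results flag are inferred, not the actual library code):

def gen_seq(stages, generator, session, retain_intermediate_results=False):
    # Hedged sketch only; the real utils.gen_seq may differ.
    x = None
    intermediates = []
    for style, begin, end in stages:
        if style is None:  # e.g. alt_mix_z is None: nothing to mix in, skip the stage
            continue
        x = generator(x, None, session.phase, session.alpha,
                      style_input=style,
                      style_layer_begin=begin,
                      style_layer_end=end)
        intermediates.append(x)
    return intermediates if retain_intermediate_results else x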
Example #2
    def reconstruct(input_image, encoder, generator, session, style_i=-1, style_layer_begin=0, style_layer_end=-1):
        with torch.no_grad():
            style_input = encoder(Variable(input_image), session.getResoPhase(), session.alpha, args.use_ALQ).detach()

            # Broadcast the style of image #style_i to every image in the batch:
            z = style_input[style_i].repeat(style_input.size()[0], 1)

            #The call would unwrap as follows:
            #z_w = generator(None,None, session.phase, session.alpha, style_input = z,           style_layer_begin=0, style_layer_end=style_layer_begin).detach()
            #z_w = generator(z_w, None, session.phase, session.alpha, style_input = style_input, style_layer_begin=style_layer_begin, style_layer_end=style_layer_end)
            #gex = generator(z_w, None, session.phase, session.alpha, style_input = z,           style_layer_begin=style_layer_end, style_layer_end=-1)

            gex = utils.gen_seq([(z, 0, style_layer_begin),
                                 (style_input, style_layer_begin, style_layer_end),
                                 (z, style_layer_end, -1)],
                                generator, session).detach()

        return gex.data[:]
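
A hedged usage sketch for reconstruct (the owning class, the data loader, and the chosen layer indices are all hypothetical): each image keeps its own encoded style inside [style_layer_begin, style_layer_end), while image #0's style is applied everywhere else.

# Hypothetical call site; Trainer stands in for whatever class owns reconstruct().
batch = next(iter(train_loader))[0]  # a (N, C, H, W) image tensor in [-1, 1]
recon = Trainer.reconstruct(batch, session.encoder, session.generator, session,
                            style_i=0,  # broadcast the style of image #0
                            style_layer_begin=2, style_layer_end=5)
torchvision.utils.save_image(recon, 'recon_grid.png', normalize=True, range=(-1, 1))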
Example #3
    def generate_intermediate_samples(generator, global_i, session, writer=None, collateImages=True):
        generator.eval()
        with torch.no_grad():

            save_root = args.sample_dir if args.sample_dir is not None else args.save_dir

            utils.requires_grad(generator, False)
            
            # Total number is samplesRepeatN * colN * rowN
            # e.g. for 51200 samples, outcome is 5*80*128. Please only give multiples of 128 here.
            samplesRepeatN = 1 if collateImages else max(1, int(args.sample_N / 128))
            reso = session.getReso()

            if not collateImages:
                special_dir = '{}/sample/{}'.format(save_root, str(global_i).zfill(6))
                while os.path.exists(special_dir):
                    special_dir += '_'

                os.makedirs(special_dir)
            
            for outer_count in range(samplesRepeatN):
                
                colN = 1 if not collateImages else min(10, int(np.ceil(args.sample_N / 4.0)))
                rowN = 128 if not collateImages else min(5, int(np.ceil(args.sample_N / 4.0)))
                images = []
                myzs = []
                for _ in range(rowN):
                    myz = {}
                    for i in range(2):
                        myz0 = Variable(torch.randn(args.n_label * colN, args.nz + 1)).to(device=args.device)
                        myz[i] = utils.normalize(myz0)
                        #myz[i], input_class = utils.split_labels_out_of_latent(myz0)

                    if args.randomMixN <= 1:
                        new_imgs = generator(
                            myz[0],
                            None, # input_class,
                            session.phase,
                            session.alpha).detach() #.data.cpu()
                    else:
                        max_i = session.phase
                        style_layer_begin = 2  # np.random.randint(low=1, high=(max_i+1))
                        print("Cut at layer {} for the next {} random samples.".format(style_layer_begin, myz[0].shape[0]))
                        
                        new_imgs = utils.gen_seq([ (myz[0], 0, style_layer_begin),
                                                   (myz[1], style_layer_begin, -1)
                                                 ], generator, session).detach()
                        
                    images.append(new_imgs)
                    myzs.append(myz[0]) # TODO: Support mixed latents with rejection sampling

                if args.rejection_sampling > 0 and not collateImages:
                    filtered_images = []
                    batch_size = 8
                    for b in range(int(len(images) / batch_size)):
                        img_batch = images[b*batch_size:(b+1)*batch_size]
                        reco_z = session.encoder(Variable(torch.cat(img_batch,0)), session.getResoPhase(), session.alpha, args.use_ALQ).detach()
                        cost_z = utils.mismatchV(torch.cat(myzs[b*batch_size:(b+1)*batch_size],0), reco_z, 'cos')
                        _, ii = torch.sort(cost_z)
                        keeper = args.rejection_sampling / 100.0
                        keepN = max(1, int(len(ii)*keeper))
                        for iindex in ii[:keepN]:
                            filtered_images.append(img_batch[iindex])  # retain the best ones only
                    images = filtered_images

                if collateImages:
                    sample_dir = '{}/sample'.format(save_root)
                    if not os.path.exists(sample_dir):
                        os.makedirs(sample_dir)
                    
                    save_path = '{}/{}.png'.format(sample_dir, str(global_i + 1).zfill(6))
                    torchvision.utils.save_image(
                        torch.cat(images, 0),
                        save_path,
                        nrow=args.n_label * colN,
                        normalize=True,
                        range=(-1, 1),
                        padding=0)
                    # Hacky, but this is an easy way to rescale the images to a nice large format:
                    im = Image.open(save_path)
                    im2 = im.resize((1024, 512 if reso < 256 else 1024))
                    im2.save(save_path)

                    #if writer:
                    #    writer.add_image('samples_latest_{}'.format(session.phase), torch.cat(images, 0), session.phase)
                else:
                    print("Generating samples at {}...".format(special_dir))
                    for img_i, img in enumerate(images):
                        ipath = '{}/{}_{}.png'.format(special_dir, str(global_i + 1).zfill(6), img_i + outer_count * 128)
                        #print(ipath)
                        torchvision.utils.save_image(
                            img,
                            ipath,
                            nrow=args.n_label * colN,
                            normalize=True,
                            range=(-1, 1),
                            padding=0)

        generator.train()
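
The rejection-sampling branch above re-encodes each generated batch, scores it with utils.mismatchV(..., 'cos'), sorts the costs in ascending order, and keeps the best rejection_sampling percent. mismatchV is not listed here; a minimal sketch matching what the call site needs (one cost per sample, lower is better), with the cosine formulation as an assumption:

import torch.nn.functional as F

def mismatchV(z, reco_z, metric='cos'):
    # Hedged sketch; the repo's utils.mismatchV may differ.
    z = z.view(z.size(0), -1)
    reco_z = reco_z.view(reco_z.size(0), -1)
    if metric == 'cos':
        # 1 - cosine similarity: 0 for a perfect match, up to 2 for opposite vectors
        return 1.0 - F.cosine_similarity(z, reco_z, dim=1)
    raise ValueError('unsupported metric: {}'.format(metric))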
Example #4
def encoder_train(session, real_image, generatedImagePool, batch_N, match_x,
                  stats, kls, margin):
    encoder = session.encoder
    generator = session.generator

    encoder.zero_grad()
    generator.zero_grad()

    x = Variable(real_image).to(device=args.device)  #async=(args.gpu_count>1))
    KL_maximizer = KLN01Loss(direction=args.KL, minimize=False)
    KL_minimizer = KLN01Loss(direction=args.KL, minimize=True)

    e_losses = []

    flipInvarianceLayer = args.flip_invariance_layer
    flipX = flipInvarianceLayer > -1 and session.phase >= flipInvarianceLayer

    phiAdaCotrain = args.phi_ada_cotrain

    if flipX:
        phiAdaCotrain = True

    #global adaparams
    if phiAdaCotrain:
        for b in generator.module.adanorm_blocks:
            if b is not None and b.mod is not None:
                for param in b.mod.parameters():
                    param.requires_grad_(True)

    if flipX:
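        # Replace the second half of the batch with horizontal mirrors of the first
        # half, so that every image appears together with its flipped counterpart: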
        x_in = x[0:int(x.size()[0] / 2), :, :, :]
        x_mirror = x_in.clone().detach().requires_grad_(True).to(
            device=args.device).flip(dims=[3])
        x[int(x.size()[0] / 2):, :, :, :] = x_mirror

    real_z = encoder(x, session.getResoPhase(), session.alpha, args.use_ALQ)

    if args.use_real_x_KL:
        # KL_real: - \Delta( e(X) , Z ) -> max_e
        if not flipX:
            KL_real = KL_minimizer(real_z) * args.real_x_KL_scale
        else:  # Treat each flip direction of the data as a separate distribution
            z_in = real_z[0:int(x.size()[0] / 2), :]
            z_in_mirror = real_z[int(x.size()[0] / 2):, :]
            KL_real = (KL_minimizer(z_in) +
                       KL_minimizer(z_in_mirror)) * args.real_x_KL_scale / 2
        e_losses.append(KL_real)

        stats['real_mean'] = KL_minimizer.samples_mean.data.mean().item()
        stats['real_var'] = KL_minimizer.samples_var.data.mean().item()
        stats['KL_real'] = KL_real.data.item()
        kls = "{0:.3f}".format(stats['KL_real'])

    if flipX:
        x_reco_in = utils.gen_seq(
            [
                (
                    z_in, 0, flipInvarianceLayer
                ),  # Rotation of x_in, the ID of x_in_mirror (= the ID of x_in)
                (z_in_mirror, flipInvarianceLayer, -1)
            ],
            session.generator,
            session)

        x_reco_in_mirror = utils.gen_seq(
            [
                (z_in_mirror, 0, flipInvarianceLayer),  # Vice versa
                (z_in, flipInvarianceLayer, -1)
            ],
            session.generator,
            session)

        if args.match_x_metric == 'robust':
            loss_flip = torch.mean(session.adaptive_loss[session.getResoPhase()].lossfun((x_in - x_reco_in).view(-1, x_in.size()[1]*x_in.size()[2]*x_in.size()[3]) ))  * match_x * 0.2 + \
                        torch.mean(session.adaptive_loss[session.getResoPhase()].lossfun((x_mirror - x_reco_in_mirror).view(-1, x_mirror.size()[1]*x_mirror.size()[2]*x_mirror.size()[3]) ))  * match_x * 0.2
        else:
            loss_flip = (utils.mismatch(x_in, x_reco_in, args.match_x_metric) +
                         utils.mismatch(x_mirror, x_reco_in_mirror,
                                        args.match_x_metric)) * args.match_x

        loss_flip.backward(retain_graph=True)
        stats['loss_flip'] = loss_flip.data.item()
        stats['x_reconstruction_error'] = loss_flip.data.item()
        print('Flip loss: {}'.format(stats['loss_flip']))

    else:
        if args.use_loss_x_reco:
            recon_x = generator(real_z, None, session.phase, session.alpha)
            # match_x: E_x||g(e(x)) - x|| -> min_e

            if args.match_x_metric == 'robust':
                err_simple = utils.mismatch(recon_x, x, 'L1') * match_x
                err = torch.mean(
                    session.adaptive_loss[session.getResoPhase()].lossfun(
                        (recon_x - x).view(
                            -1,
                            x.size()[1] * x.size()[2] *
                            x.size()[3]))) * match_x * 0.2
                print("err vs. ROBUST err: {} / {}".format(err_simple, err))
            else:
                err_simple = utils.mismatch(recon_x, x,
                                            args.match_x_metric) * match_x
                err = err_simple

            if phiAdaCotrain:
                err.backward(retain_graph=True)
            else:
                e_losses.append(err)
            stats['x_reconstruction_error'] = err.data.item()

    if phiAdaCotrain:
        for b in session.generator.module.adanorm_blocks:
            if b is not None and b.mod is not None:
                for param in b.mod.parameters():
                    param.requires_grad_(False)

    if args.use_loss_fake_D_KL:
        # TODO: This block is essentially the same as the KL_minimizer part on the G side; unify them.

        mix_ratio = args.stylemix_E  #0.25
        mix_N = int(mix_ratio * batch_N)
        z = Variable(torch.FloatTensor(batch_N, args.nz, 1, 1)).to(
            device=args.device)  #async=(args.gpu_count>1))
        utils.populate_z(z, args.nz + args.n_label, args.noise, batch_N)

        if session.phase > 0 and mix_N > 0:
            alt_mix_z = Variable(torch.FloatTensor(mix_N, args.nz, 1, 1)).to(
                device=args.device)  #async=(args.gpu_count>1))
            utils.populate_z(alt_mix_z, args.nz + args.n_label, args.noise,
                             mix_N)
            alt_mix_z = torch.cat(
                (alt_mix_z,
                 z[mix_N:, :]), dim=0) if mix_N < z.size()[0] else alt_mix_z
        else:
            alt_mix_z = None

        with torch.no_grad():
            style_layer_begin = np.random.randint(
                low=1, high=(session.phase + 1)) if alt_mix_z is not None else -1
            fake = utils.gen_seq([(z, 0, style_layer_begin),
                                  (alt_mix_z, style_layer_begin, -1)],
                                 generator, session).detach()

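        # Only query the generated-image history pool once the current resolution
        # has fully faded in (alpha has reached 1):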
        if session.alpha >= 1.0:
            fake = generatedImagePool.query(fake.data)

        fake.requires_grad_()

        # e(g(Z))
        egz = encoder(fake, session.getResoPhase(), session.alpha,
                      args.use_ALQ)

        # KL_fake: \Delta( e(g(Z)) , Z ) -> max_e
        KL_fake = KL_maximizer(egz) * args.fake_D_KL_scale

        m = margin
        if m > 0.0:
            # KL_fake is always negative, with an absolute value typically larger
            # than KL_real. The sum is therefore negative and must be gapped so
            # that its minimum is the negative of the margin.
            KL_loss = torch.min(
                torch.ones_like(KL_fake) * m / 2,
                torch.max(-torch.ones_like(KL_fake) * m, KL_real + KL_fake))
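            # Worked example with hypothetical numbers: m = 1.0, KL_real = 0.3, KL_fake = -1.8
            # => max(-1.0, -1.5) = -1.0, then min(0.5, -1.0) = -1.0, i.e. gapped at -m.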
        else:
            KL_loss = KL_real + KL_fake
        e_losses.append(KL_loss)

        stats['fake_mean'] = KL_maximizer.samples_mean.data.mean().item()
        stats['fake_var'] = KL_maximizer.samples_var.data.mean().item()
        stats['KL_fake'] = -KL_fake.data.item()
        stats['KL_loss'] = KL_loss.data.item()

        kls = "{0}/{1:.3f}".format(kls, stats['KL_fake'])

    # Update e
    if len(e_losses) > 0:
        e_loss = sum(e_losses)
        e_loss.backward()
        stats['E_loss'] = np.float32(e_loss.data.cpu())

    session.optimizerD.step()

    if flipX:
        # The AdaNorm params need to be updated separately since they sit on the generator side
        session.optimizerA.step()

    return kls
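
Every example also depends on KLN01Loss, which is not listed either. A minimal sketch, assuming it compares the batch's empirical mean and variance against N(0, 1) with the closed-form Gaussian KL, and that minimize=False simply negates the loss; the real class presumably makes fuller use of the direction argument than shown here:

import torch
import torch.nn as nn

class KLN01Loss(nn.Module):
    # Hedged sketch; the actual class may differ, especially in how `direction` is used.
    def __init__(self, direction, minimize=True):
        super().__init__()
        self.direction = direction
        self.minimize = minimize

    def forward(self, samples):
        samples = samples.view(samples.size(0), -1)
        # Exposed as attributes because the training code reads them for stats:
        self.samples_mean = samples.mean(dim=0)
        self.samples_var = samples.var(dim=0)
        # KL( N(mu, var) || N(0, 1) ) = 0.5 * sum( mu^2 + var - log(var) - 1 )
        kl = 0.5 * (self.samples_mean.pow(2) + self.samples_var
                    - self.samples_var.log() - 1).sum()
        return kl if self.minimize else -kl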