Esempio n. 1
0
 def __init__(self, cont_wgts=None, style_wgts=None):
     """Build a perceptual-loss module on top of frozen VGG16 features.

     Args:
         cont_wgts: content weights handed to ``PerceptualLossModel``;
             ``None`` (default) means ``[0, 0, 0, 0]``.
         style_wgts: style weights handed to ``PerceptualLossModel``;
             ``None`` (default) means ``[0, 0, 0, 0]``.
     """
     super().__init__()
     # A list literal as a default is created once and shared by every
     # instance; build fresh lists per call instead.
     if cont_wgts is None:
         cont_wgts = [0, 0, 0, 0]
     if style_wgts is None:
         style_wgts = [0, 0, 0, 0]
     pretrained_model = torchvision.models.vgg16(True).features.eval()
     # Freeze the feature extractor; it is only used to compute features.
     requires_grad(pretrained_model, False)
     # Index of the layer immediately preceding each MaxPool2d, i.e. the
     # last layer of each VGG stage before down-sampling.
     blocks = [
         i - 1 for i, o in enumerate(children(pretrained_model))
         if isinstance(o, nn.MaxPool2d)
     ]
     # Only the first four stages are used.
     self.model = PerceptualLossModel(pretrained_model, blocks[:4],
                                      cont_wgts, style_wgts)
Esempio n. 2
0
    def dry_run(self, batch, phase, alpha):
        ''' dry run model on the batch

        Translate an ``(rgb, fir)`` batch through both mixcoders with the
        encoder and generator in eval mode and frozen, and return inputs and
        outputs interleaved per sample for display.

        Args:
            batch: pair ``(rgb, fir)`` of image tensors.
            phase, alpha: forwarded unchanged to the mixcoders.

        Returns:
            Tensor of shape ``(6 * batch_size, 1, H, W)`` holding, for each
            sample in order: ``rgb``, ``rgb_rgb``, ``rgb_fir``, ``fir``,
            ``fir_fir``, ``fir_rgb``. The ``view`` requires the six tensors
            to carry six channels in total (presumably one each) -- TODO
            confirm.

        The encoder and generator are switched back to train mode on exit,
        even if a forward pass raises. Their ``requires_grad`` flags are left
        ``False``.
        '''
        self.enc.eval()
        self.gen.eval()
        utils.requires_grad(self.enc, False)
        utils.requires_grad(self.gen, False)

        try:
            rgb, fir = batch
            batch_size = rgb.shape[0]

            # Only the first two outputs of each mixcoder are displayed.
            rgb_fir, rgb_rgb, _, _ = self.mixcod.rgb(rgb, fir, phase, alpha)
            fir_rgb, fir_fir, _, _ = self.mixcod.fir(fir, rgb, phase, alpha)

            # Concatenate along dim 1, then reshape so that each sample's six
            # images become consecutive single-channel entries of the batch.
            out_ims = torch.cat(
                (rgb, rgb_rgb, rgb_fir, fir, fir_fir, fir_rgb), 1).view(
                    6 * batch_size, 1, rgb.shape[-2], rgb.shape[-1])
        finally:
            # Restore training mode even if the forward passes fail, so a
            # caller that recovers does not keep training in eval mode.
            self.enc.train()
            self.gen.train()

        return out_ims
Esempio n. 3
0
    def dry_run(self, batch, phase, alpha):
        ''' dry run model on the batch

        Reconstruct ``batch[0]`` through encoder and generator (both in eval
        mode and frozen) and return source and reconstruction interleaved
        per sample for display.

        Args:
            batch: indexable; only ``batch[0]`` (the images ``x``) is used.
            phase, alpha: forwarded unchanged to encoder and generator.

        Returns:
            Tensor of shape ``(2 * batch_size, 1, H, W)`` holding, for each
            sample, ``x`` followed by its reconstruction. The ``view``
            requires ``x`` and the reconstruction to carry two channels in
            total (presumably one each) -- TODO confirm.

        The encoder and generator are switched back to train mode on exit,
        even if a forward pass raises. Their ``requires_grad`` flags are left
        ``False``.
        '''
        encoder, generator = self.encoder, self.generator
        generator.eval()
        encoder.eval()
        utils.requires_grad(generator, False)
        utils.requires_grad(encoder, False)

        try:
            x = batch[0]
            batch_size = x.shape[0]

            real_z = encoder(x, phase, alpha)
            reco_ims = generator(real_z, phase, alpha).data

            # Concatenate along dim 1, then reshape so that each sample's
            # source and reconstruction become consecutive batch entries.
            out_ims = torch.cat((x, reco_ims), 1).view(
                2 * batch_size, 1, reco_ims.shape[-2], reco_ims.shape[-1])
        finally:
            # Restore training mode even if the forward passes fail.
            encoder.train()
            generator.train()

        return out_ims
Esempio n. 4
0
def updateImages(input, session, num_steps=None, phase=5, alpha=1.0):
    """Optimise an image batch directly against frozen networks and plot it.

    The encoder, generator and critic from ``session`` are frozen. Only the
    image tensor ``x`` is updated with Adam. ``x`` is the smoothed input
    minus sample 1 of the batch, in absolute value. The loss minimised is
    ``mismatch(generator(encoder(x)), x) - mean(critic(generator(encoder(x)), x))``.
    After every step the first channel of sample 0 of ``x`` and of its
    reconstruction are shown with ``imshow``, and then ``clf`` is called.

    Args:
        input: image batch passed to ``GaussianSmoothing(1, ...)``; must have
            at least two samples because sample 1 is used as a reference.
            Presumably ``(N, 1, H, W)`` on the GPU -- TODO confirm.
        session: object exposing ``encoder``, ``generator`` and ``critic``.
        num_steps: number of optimisation steps; ``None`` (default) loops
            forever, which is the historical behaviour.
        phase, alpha: passed to the encoder, generator and critic alike.

    Returns:
        dict with the last step's ``'x_err'`` and ``'cls_loss'`` (``.data``
        of the respective tensors). It is only returned when ``num_steps``
        is not ``None``, and is empty if ``num_steps`` is 0.
    """
    encoder, generator, critic = session.encoder, session.generator, session.critic

    stats = {}
    # Freeze every network: gradients are only wanted w.r.t. the images.
    utils.requires_grad(encoder, False)
    utils.requires_grad(generator, False)
    utils.requires_grad(critic, False)

    smoothing = GaussianSmoothing(1, 11.0, 10.0).cuda()
    x = smoothing(input)
    # Subtract sample 1 (indexed with a list so it keeps a size-1 batch dim
    # and broadcasts) from every sample.
    x = torch.abs(x - x[[1]])
    # Make x a fresh leaf tensor that the optimiser can update
    # (replaces the deprecated ``Variable(x, requires_grad=True)``).
    x = x.detach().requires_grad_(True)

    optimizer = optim.Adam([x], 0.001, betas=(0.0, 0.99))

    step = 0
    while num_steps is None or step < num_steps:
        step += 1
        optimizer.zero_grad()

        real_z = encoder(x, phase, alpha)
        fake_x = generator(real_z, phase, alpha)

        err_x = utils.mismatch(fake_x, x, args.match_x_metric)
        stats['x_err'] = err_x.data

        # Use the same phase/alpha as the encoder and generator.
        # Previously the critic received session.phase/session.alpha while
        # the reconstruction was built at the hard-coded phase.
        cls_fake = critic(fake_x, x, phase, alpha)
        cls_loss = -cls_fake.mean()
        stats['cls_loss'] = cls_loss.data

        # Backpropagate into x only (all network weights are frozen).
        loss = err_x + cls_loss
        loss.backward()
        optimizer.step()

        idx = 0
        imshow(x[idx, 0].cpu().data)
        imshow(fake_x[idx, 0].cpu().data)
        clf()

    return stats
Esempio n. 5
0
    def update(self, batch, phase, alpha):
        """Run one two-phase training step for the RGB/FIR model pair.

        Phase 1 freezes the critics (``self.crt``). It sums the autoencoder
        losses (``self.autoenc.rgb`` / ``.fir``) and the mixcoder losses
        (``self.mixcod.rgb`` / ``.fir``), backpropagates them, and steps
        ``self.optimE`` and ``self.optimG``. Presumably ``autoenc``/``mixcod``
        wrap ``self.enc``/``self.gen`` -- TODO confirm.

        Phase 2 freezes ``self.enc``/``self.gen`` and unfreezes the critics.
        It sums the same-domain critic losses, the cross-domain critic losses
        and the gradient-penalty losses, backpropagates, and steps
        ``self.optimC``. The fake images produced in phase 1 are reused here.

        Args:
            batch: pair ``(rgb, fir)`` of image tensors.
            phase, alpha: forwarded unchanged to every ``Model.*`` helper;
                presumably the progressive-growing phase and fade-in weight
                -- TODO confirm.

        Returns:
            dict of logging entries. Values are taken straight from the
            ``Model.*`` helper outputs and are not detached here.
        """
        ##### Train Encoder and Generator #########
        # Critics are frozen so this backward pass accumulates no critic grads.
        utils.requires_grad(self.enc, True)
        utils.requires_grad(self.gen, True)
        utils.requires_grad(self.crt, False)
        self.enc.zero_grad()
        self.gen.zero_grad()
        stats, losses = {}, []

        rgb, fir = batch
        # rgb = rgb.cuda(); fir = fir.cuda()

        ##### Autoencoders ###########
        # Each autoenc_loss call returns a sequence of loss terms (all added to
        # `losses`; the first is logged) plus the autoencoded fake images,
        # which the critic phase below reuses.
        loss_rgb, auto_fake_rgb = Model.autoenc_loss(self.autoenc.rgb,
                                                     self.crt.rgb, rgb, phase,
                                                     alpha)
        stats.update({'L1_auto_rgb': loss_rgb[0]})
        losses += loss_rgb

        loss_fir, auto_fake_fir = Model.autoenc_loss(self.autoenc.fir,
                                                     self.crt.fir, fir, phase,
                                                     alpha)
        stats.update({'L1_auto_fir': loss_fir[0]})
        losses += loss_fir

        ##### Mixcoders #############
        # mix_loss takes the critics of both domains and both image batches.
        # It returns loss terms (the first two are logged), a same-domain fake
        # and a cross-domain fake, both reused by the critic phase.
        loss_rgb, mix_fake_rgb, mix_fake_rgb_fir = \
            Model.mix_loss(self.mixcod.rgb, self.crt.rgb, self.crt.fir, rgb, fir, phase, alpha)
        stats.update({'L1_mix_rgb': loss_rgb[0], 'L2z_mix_rgb': loss_rgb[1]})
        losses += loss_rgb

        loss_fir, mix_fake_fir, mix_fake_fir_rgb = \
            Model.mix_loss(self.mixcod.fir, self.crt.fir, self.crt.rgb, fir, rgb, phase, alpha)
        stats.update({'L1_mix_fir': loss_fir[0], 'L2z_mix_fir': loss_fir[1]})
        losses += loss_fir

        ##### Propagate gradients for encoders and generators
        loss = sum(losses)
        loss.backward()
        self.optimE.step()
        self.optimG.step()

        ##### Train Critics ##########
        # Now freeze encoder/generator and train only the critics.
        utils.requires_grad(self.enc, False)
        utils.requires_grad(self.gen, False)
        utils.requires_grad(self.crt, True)
        self.crt.zero_grad()

        losses = []

        # Same domain critics
        # Critic of each domain scores real images against that domain's
        # autoencoded and mixcoded fakes.
        loss_crt_rgb, crt_real_rgb, crt_auto_fake_rgb, crt_mix_fake_rgb = \
            Model.crt_loss_same_domain(self.crt.rgb,rgb,auto_fake_rgb,mix_fake_rgb,phase,alpha)
        stats.update({
            'crt_real_rgb': crt_real_rgb,
            'crt_auto_fake_rgb': crt_auto_fake_rgb,
            'crt_mix_fake_rgb': crt_mix_fake_rgb
        })
        losses += loss_crt_rgb

        loss_crt_fir, crt_real_fir, crt_auto_fake_fir, crt_mix_fake_fir = \
            Model.crt_loss_same_domain(self.crt.fir,fir,auto_fake_fir,mix_fake_fir, phase,alpha)
        stats.update({
            'crt_real_fir': crt_real_fir,
            'crt_auto_fake_fir': crt_auto_fake_fir,
            'crt_mix_fake_fir': crt_mix_fake_fir
        })
        losses += loss_crt_fir

        # Cross domain critics
        # Images produced by the other domain's mixcoder (e.g. the FIR
        # mixcoder's third output) are judged by this domain's critic. The
        # real-image scores from the same-domain step above are passed in.
        loss_crt_mix_fir_rgb, crt_mix_fake_fir_rgb = \
            Model.crt_loss_cross_domain(self.crt.rgb,crt_real_rgb,mix_fake_fir_rgb,phase,alpha)
        stats.update({'crt_mix_fake_fir_rgb': crt_mix_fake_fir_rgb})
        losses += loss_crt_mix_fir_rgb

        loss_crt_mix_rgb_fir, crt_mix_fake_rgb_fir = \
            Model.crt_loss_cross_domain(self.crt.fir,crt_real_fir,mix_fake_rgb_fir,phase,alpha)
        stats.update({'crt_mix_fake_rgb_fir': crt_mix_fake_rgb_fir})
        losses += loss_crt_mix_rgb_fir

        ######### Gradient regularization #############

        # measure gradient of autoencoder
        gnorm_auto_rgb = Model.autoenc_grad_norm(self.autoenc.rgb, rgb, phase,
                                                 alpha).mean()
        # equalize gradient between real and fake samples: the critic's
        # gradient penalty on real vs. autoencoded images, with target scale
        # args.grad_norm_fact * (mean autoencoder gradient norm)
        loss_grad_auto_rgb = Model.crt_grad_penalty(
            self.crt.rgb, rgb, auto_fake_rgb, phase, alpha,
            args.grad_norm_fact * gnorm_auto_rgb)
        stats.update({'loss_grad_auto_rgb': loss_grad_auto_rgb[0]})
        losses += loss_grad_auto_rgb

        gnorm_auto_fir = Model.autoenc_grad_norm(self.autoenc.fir, fir, phase,
                                                 alpha).mean()
        # equalize gradient between real and fake samples
        loss_grad_auto_fir = Model.crt_grad_penalty(
            self.crt.fir, fir, auto_fake_fir, phase, alpha,
            args.grad_norm_fact * gnorm_auto_fir)
        stats.update({'loss_grad_auto_fir': loss_grad_auto_fir[0]})
        losses += loss_grad_auto_fir

        # measure gradient of mixcoder
        # (same penalty as above, on real vs. mixcoded images)
        gnorm_mix_rgb = Model.mixcod_grad_norm(self.mixcod.rgb, rgb, fir,
                                               phase, alpha).mean()
        loss_grad_mix_rgb = Model.crt_grad_penalty(
            self.crt.rgb, rgb, mix_fake_rgb, phase, alpha,
            args.grad_norm_fact * gnorm_mix_rgb)
        stats.update({'loss_grad_mix_rgb': loss_grad_mix_rgb[0]})
        losses += loss_grad_mix_rgb

        gnorm_mix_fir = Model.mixcod_grad_norm(self.mixcod.fir, fir, rgb,
                                               phase, alpha).mean()
        loss_grad_mix_fir = Model.crt_grad_penalty(
            self.crt.fir, fir, mix_fake_fir, phase, alpha,
            args.grad_norm_fact * gnorm_mix_fir)
        stats.update({'loss_grad_mix_fir': loss_grad_mix_fir[0]})
        losses += loss_grad_mix_fir

        ##### Propagate gradients for critics
        loss = sum(losses)
        loss.backward()
        self.optimC.step()

        return stats
Esempio n. 6
0
    def update(self, batch, phase, alpha):
        """Run one training step: encoder/generator update, then critic update.

        Args:
            batch: indexable; only ``batch[0]`` (the images ``x``) is used.
            phase, alpha: forwarded to every network call. ``alpha`` also
                scales the generator's adversarial term.

        Returns:
            dict with ``'x_err'``, ``'z_err'``, ``'G_loss'``, ``'grad_loss'``,
            ``'cls_fake'``, ``'cls_real'`` and ``'C_loss'``. Only ``'C_loss'``
            is taken via ``.data``; the others are the tensors as computed
            (not explicitly detached).
        """
        encoder, generator, critic = self.encoder, self.generator, self.critic
        stats, losses = {}, []
        # Phase 1: train encoder + generator with the critic frozen.
        utils.requires_grad(encoder, True)
        utils.requires_grad(generator, True)
        utils.requires_grad(critic, False)
        encoder.zero_grad()
        generator.zero_grad()

        x = batch[0]
        batch_size = x.shape[0]

        real_z = encoder(x, phase, alpha)
        fake_x = generator(real_z, phase, alpha)

        # use no gradient propagation if no x metric is required
        # (the mismatch is then still computed, but only for logging)
        if args.use_x_metric:
            # match x: E_x||g(e(x)) - x|| -> min_e
            err_x = utils.mismatch(fake_x, x, args.match_x_metric)
            losses.append(err_x)
        else:
            with torch.no_grad():
                err_x = utils.mismatch(fake_x, x, args.match_x_metric)
        stats['x_err'] = err_x

        if args.use_z_metric:
            # cyclic match z E_x||e(g(e(x))) - e(x)||^2
            fake_z = encoder(fake_x, phase, alpha)
            err_z = utils.mismatch(real_z, fake_z, args.match_z_metric)
            losses.append(err_z)
        else:
            # logging only, as for x above
            with torch.no_grad():
                fake_z = encoder(fake_x, phase, alpha)
                err_z = utils.mismatch(real_z, fake_z, args.match_z_metric)
        stats['z_err'] = err_z

        cls_fake = critic(fake_x, x, phase, alpha)

        cls_real = critic(x, x, phase, alpha)

        # measure loss only where real score is higher than fake score:
        # elements where real <= fake are masked to 0 (the mask is built from
        # detached scores), and the mean is still taken over all elements
        G_loss = -(cls_fake *
                   (cls_real.detach() > cls_fake.detach()).float()).mean()

        # Gloss      = -torch.log(cls_fake).mean()
        stats['G_loss'] = G_loss
        # warm up critic loss to kick in with alpha
        losses.append(alpha * G_loss)

        # Propagate gradients for encoder and decoder
        loss = sum(losses)
        loss.backward()

        # Apply encoder and decoder gradients
        self.optimizerE.step()
        self.optimizerG.step()

        ###### Critic ########
        # Phase 2: train the critic with encoder and generator frozen.
        losses = []
        utils.requires_grad(critic, True)
        utils.requires_grad(encoder, False)
        utils.requires_grad(generator, False)
        critic.zero_grad()
        # Use fake_x, as fixed data here
        fake_x = fake_x.detach()

        cls_fake = critic(fake_x, x, phase, alpha)
        cls_real = critic(x, x, phase, alpha)

        # Mean fake score minus mean real score, plus |cf + cr|, which
        # additionally pushes the two means toward cf == -cr.
        cf, cr = cls_fake.mean(), cls_real.mean()
        C_loss = cf - cr + torch.abs(cf + cr)

        # Gradient penalty whose target is the mean gradient norm of the
        # encoder/generator pair, as computed by autoenc_grad_norm.
        grad_norm = autoenc_grad_norm(encoder, generator, x, phase,
                                      alpha).mean()
        grad_loss = critic_grad_penalty(critic, x, fake_x, batch_size, phase,
                                        alpha, grad_norm)
        stats['grad_loss'] = grad_loss
        losses.append(grad_loss)

        # C_loss      = -torch.log(1.0 - cls_fake).mean() - torch.log(cls_real).mean()

        stats['cls_fake'] = cls_fake.mean()
        stats['cls_real'] = cls_real.mean()
        stats['C_loss'] = C_loss.data

        # Propagate critic losses
        losses.append(C_loss)
        loss = sum(losses)
        loss.backward()

        # Apply critic gradient
        self.optimizerC.step()
        return stats
Esempio n. 7
0
def train(generator, discriminator, face_align_net, g_optim, d_optim,
          input_image, target_image, step, iteration, alpha):
    """Run one generator update followed by one discriminator update.

    Generator loss (the weights depend on ``step``):
      * the adversarial term ``-mean(D(fake))``;
      * the L1 pixel loss;
      * the VGG perceptual loss (sum of per-feature MSEs);
      * for ``step > 1`` only, a heat-map-weighted L1 "attention" loss plus
        an MSE between the heat maps returned by ``get_heat_map`` for the
        fake and target images.

    Discriminator loss:
    ``mean(D(fake)) - mean(D(real)) + 0.001 * mean(D(real)**2)`` plus the
    one-sided penalty ``10 * max(0, mean(||grad D(x_hat)|| - 1))``. Here
    ``x_hat`` are random interpolates between real and fake images.

    On rank 0 (``args.local_rank == 0``) the function also does the
    following: it writes a progress line to stdout, saves a sample grid
    every 100 iterations, and checkpoints both models every
    ``args.save_interval`` iterations.

    Args:
        generator, discriminator: networks called as ``net(img, step, alpha)``.
        face_align_net: network passed to ``get_heat_map`` (used for step > 1).
        g_optim, d_optim: optimisers stepped for generator / discriminator.
        input_image, target_image: generator input and ground truth. The
            ``0.5 * img + 0.5`` rescaling suggests images in [-1, 1] -- TODO
            confirm.
        step: training stage, must be >= 1.
        iteration: iteration counter, presumably within the current ``step``
            (checkpoints are numbered ``(step - 1) * 50000 + iteration``).
        alpha: forwarded to generator and discriminator.

    Raises:
        ValueError: if ``step < 1``. No generator loss is defined for that
            case; it previously surfaced as an ``UnboundLocalError``.
    """
    if step < 1:
        raise ValueError('step must be >= 1, got {}'.format(step))

    #generator training
    requires_grad(generator, True)
    requires_grad(discriminator, False)
    generator.zero_grad()
    fake_image = generator(input_image, step, alpha)
    pixel_loss = L1_Loss(fake_image, target_image)
    predict_fake = discriminator(fake_image, step, alpha)

    # Rescale images by 0.5 * img + 0.5 before the VGG feature extractor.
    feature_real = vgg(0.5 * target_image + 0.5)
    feature_fake = vgg(0.5 * fake_image + 0.5)

    perceptual_loss = 0
    for fr, ff in zip(feature_real, feature_fake):
        perceptual_loss += MSE_Loss(ff, fr)

    if step == 1:
        g_loss = -predict_fake.mean() + 10 * pixel_loss + 1e-2 * perceptual_loss
    else:
        hm_f, hm_r = get_heat_map(face_align_net, fake_image, target_image,
                                  False, 2**(4 - step))
        face_hmap_loss = MSE_Loss(hm_f, hm_r)
        # Weight the per-pixel L1 error by a map derived from the target's
        # heat maps.
        heat_map = operate_heatmap(hm_r)
        diff = abs(fake_image - target_image)
        attention_loss = torch.mean(heat_map * diff)
        if step == 2:
            g_loss = (10 * pixel_loss - predict_fake.mean()
                      + 1e-2 * perceptual_loss + 500 * attention_loss
                      + 50 * face_hmap_loss)
        else:
            g_loss = (10 * pixel_loss - 1e-1 * predict_fake.mean()
                      + 1e-3 * perceptual_loss + 50 * attention_loss
                      + 50 * face_hmap_loss)

    g_loss.backward()
    g_optim.step()

    #discriminator training
    requires_grad(generator, False)
    requires_grad(discriminator, True)
    discriminator.zero_grad()
    predict_real = discriminator(target_image, step, alpha)
    # Small penalty on the squared real scores (enters d_loss as
    # +0.001 * mean(D(real)**2)).
    predict_real = predict_real.mean() - 0.001 * (predict_real**2).mean()
    # Regenerate with the just-updated generator (its weights are frozen now).
    fake_image = generator(input_image, step, alpha)
    predict_fake = discriminator(fake_image, step, alpha)
    predict_fake = predict_fake.mean()

    #Gradient Penalty (GP) on random real/fake interpolates
    # eps is drawn on the CPU and then moved, to keep the original RNG stream.
    eps = torch.rand(input_image.size(0), 1, 1, 1).to(device)
    x_hat = eps * target_image.data + (1 - eps) * fake_image.data
    # Fresh leaf tensor so the gradient w.r.t. x_hat can be taken
    # (replaces the deprecated Variable(...); x_hat is already on eps' device).
    x_hat.requires_grad_(True)
    hat_predict = discriminator(x_hat, step, alpha)
    grad_x_hat = grad(outputs=hat_predict.sum(),
                      inputs=x_hat,
                      create_graph=True)[0]
    # One-sided penalty on the *mean* norm excess, not the usual squared
    # two-sided WGAN-GP term.
    grad_penalty = torch.max(
        torch.zeros(1).to(device),
        ((grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1)).mean())
    grad_penalty = 10 * grad_penalty

    d_loss = predict_fake - predict_real + grad_penalty

    d_loss.backward()
    d_optim.step()

    if args.local_rank == 0:
        if step == 1:
            sys.stdout.write('\r Step:%1d Iteration:%5d alpha:%6.5f Pixel_loss:%6.4f perceptual_loss:%6.4f Generator loss:%6.4f Discriminator loss:%6.4f'\
            %(step, iteration, alpha, pixel_loss.item(), perceptual_loss.item(), g_loss.item(), d_loss.item()))
        else:
            sys.stdout.write('\r Step:%1d Iteration:%5d alpha:%6.5f Pixel_loss:%6.4f perceptual_loss:%6.4f attention_loss:%6.4f face_hmap_loss:%6.4f Generator loss:%6.4f Discriminator loss:%6.4f'\
            %(step, iteration, alpha, pixel_loss.item(), perceptual_loss.item(), attention_loss.item(), face_hmap_loss.item(), g_loss.item(), d_loss.item()))

        #save predict sample
        if iteration % 100 == 0:
            imgs = torch.cat(
                [0.5 * fake_image + 0.5, 0.5 * target_image + 0.5], dim=0)
            utils.save_image(
                imgs,
                os.path.join(args.result_path,
                             'result_iteration{}.jpeg'.format(iteration)))

        if iteration % args.save_interval == 0:
            # Global checkpoint index; 50000 is presumably the number of
            # iterations per step -- TODO confirm.
            global_iteration = (step - 1) * 50000 + iteration
            # Both checkpoints now use os.path.join with a separate
            # component. The discriminator path used to be built by string
            # concatenation, which dropped the directory separator.
            for name, model in (('generator', generator),
                                ('discriminator', discriminator)):
                torch.save(
                    {
                        'step': step,
                        'alpha': alpha,
                        'iteration': iteration,
                        'model_state_dict': model.state_dict(),
                    },
                    os.path.join(
                        args.checkpoint_path,
                        'awgan_{}_checkpoint_{}.ckpt'.format(
                            name, global_iteration)))