Example #1
0
    def sample(self, x_real, txt_src2trg, txt_lens):
        """Generate visualization samples for a batch, one image at a time.

        For every image in ``x_real`` this produces:
          * a reconstruction decoded with the image's own style code,
          * a translation decoded with the text-conditioned style code
            returned by ``self.gen.encode_txt``,
          * a random translation decoded with a style drawn by
            ``dist_sampling_split`` around the per-class sign pattern of the
            text style, then combined with the real style through
            ``self.style_replace``.

        The module is switched to eval mode and run under ``torch.no_grad()``.
        The previous train/eval mode is restored afterwards, even if an
        exception is raised.

        Args:
            x_real: batch of input images, sliced along dim 0
                (assumed (N, C, H, W) -- TODO confirm).
            txt_src2trg: per-sample text inputs, sliced along dim 0.
            txt_lens: per-sample text lengths, sliced along dim 0.

        Returns:
            list: ``[x_real, x_real_recon, x_ab, x_sam]``. When
            ``self.use_attention`` is set, the target attention maps are
            also appended. They are repeated to 3 channels and mapped with
            ``(a - 0.5) / 0.5``, i.e. [0, 1] -> [-1, 1] if the maps lie in
            [0, 1].
        """

        def decode_with(content, style, image):
            # Attend the content with the style, viewed as one c_dim-sized
            # vector per class (the same layout gen_update uses). Then decode
            # and, if enabled, blend with the source image via the attention map.
            out, att = self.gen.decode(
                self.two_sided_attention(
                    content.clone(),
                    style.reshape(-1, self.num_cls, self.c_dim)), style)
            if self.use_attention:
                out = out * att + image * (1 - att)
            return out, att

        def block_signs(style):
            # +1 / -1 for each class block style[0, k*c_dim:(k+1)*c_dim],
            # chosen by the sign of that block's mean (ties/NaN -> +1).
            n = self.num_cls * self.c_dim
            means = style[0, :n].reshape(self.num_cls, self.c_dim).mean(dim=1)
            signs = 1.0 - 2.0 * (means < 0.0).float()
            return signs.view(1, -1).to(device=self.device,
                                        dtype=torch.float32)

        was_training = self.training
        self.eval()
        try:
            recons, translations, randoms, attentions = [], [], [], []
            # Visualization only: no autograd graph is needed.
            with torch.no_grad():
                for i in range(x_real.size(0)):
                    image = x_real[i:i + 1]
                    content, style, _ = self.gen.encode(image)
                    style = torch.cat(style, dim=1)
                    style_txt, _ = self.gen.encode_txt(
                        style, txt_src2trg[i:i + 1], txt_lens[i:i + 1])
                    style_txt = torch.cat(style_txt, dim=1)

                    x_rec, _ = decode_with(content, style, image)
                    x_trg, x_trg_att = decode_with(content, style_txt, image)

                    mus_real = block_signs(style)
                    mus_txt = block_signs(style_txt)
                    z_sample = dist_sampling_split(mus_txt, self.c_dim,
                                                   self.stddev, self.device)
                    z_sample = self.style_replace(mus_real, mus_txt, style,
                                                  z_sample)
                    x_sample, _ = decode_with(content, z_sample, image)

                    recons.append(x_rec)
                    translations.append(x_trg)
                    randoms.append(x_sample)
                    if self.use_attention:
                        # Repeat the attention map to 3 channels for display.
                        attentions.append(
                            torch.cat([x_trg_att, x_trg_att, x_trg_att],
                                      dim=1))

                outputs = [
                    x_real,
                    torch.cat(recons),
                    torch.cat(translations),
                    torch.cat(randoms),
                ]
                if self.use_attention:
                    outputs.append((torch.cat(attentions) - 0.5) / 0.5)
        finally:
            self.train(was_training)
        return outputs
Example #2
0
    def dis_update(self, x_real, c_src, c_trg, txt_src2trg, txt_lens,
                   label_src, label_trg, configs, iters):
        """Run one discriminator optimisation step.

        Two fakes are synthesised from ``x_real``:
          * ``x_fake`` is decoded with the text-conditioned style from
            ``self.gen.encode_txt``;
          * ``x_fake1`` is decoded with a style sampled by
            ``dist_sampling_split(c_trg, ...)``.
        When ``self.two_sided`` is set, each fake's content is attended with
        its own style, mirroring how ``gen_update`` builds the same images.
        The discriminator loss on both fakes is combined with optional
        penalties and back-propagated, and ``self.dis_opt`` is stepped.

        Args:
            x_real: real image batch (assumed (B, C, H, W) -- TODO confirm).
                It is not modified.
            c_src: unused here; kept for signature parity with gen_update.
            c_trg: target code passed to ``dist_sampling_split``; its last
                dimension is taken as the number of style blocks
                (presumably num_cls -- confirm against caller).
            txt_src2trg, txt_lens: text inputs for ``self.gen.encode_txt``.
            label_src, label_trg: labels forwarded to ``calc_dis_loss``.
            configs: dict with keys 'gan_w', 'cls_w', 'gp_w', 'use_r1'.
            iters: current iteration; the R1 penalty is applied when
                ``(iters + 1) % self.d_reg_every == 0``.

        Side effects:
            Sets ``self.loss_dis`` (adversarial/classification part only),
            ``self.loss_dis_all`` (total), and, when active, ``self.loss_gp``
            and ``self.loss_r1``.
        """
        self.dis_opt.zero_grad()
        content_real, style_real, _ = self.gen.encode(x_real)
        style_real = torch.cat(style_real, dim=1)

        style1 = dist_sampling_split(c_trg, self.c_dim, self.stddev,
                                     self.device)
        style_txt, _ = self.gen.encode_txt(style_real, txt_src2trg, txt_lens)

        # Each fake gets content attended with its own style (as in
        # gen_update). Previously x_fake1 reused the text-attended content.
        content_txt = content_real
        content1 = content_real
        if self.two_sided:
            content_txt = self.two_sided_attention(content_real.clone(),
                                                   style_txt)
            content1 = self.two_sided_attention(
                content_real.clone(),
                style1.reshape(-1, c_trg.shape[-1], self.c_dim))
        style_txt = torch.cat(style_txt, dim=1)

        # NOTE(review): the fakes are not detached here. Unless calc_dis_loss
        # detaches internally, backward also runs through the generator
        # (its grads are cleared by gen_update's zero_grad).
        x_fake, x_fake_att = self.gen.decode(content_txt, style_txt)
        x_fake1, x_fake_att1 = self.gen.decode(content1, style1)
        if self.use_attention:
            x_fake = x_fake * x_fake_att + x_real * (1 - x_fake_att)
            x_fake1 = x_fake1 * x_fake_att1 + x_real * (1 - x_fake_att1)

        gan_w, cls_w = configs['gan_w'], configs['cls_w']
        self.loss_dis = (
            self.dis.calc_dis_loss(x_fake, x_real, label_trg, label_src,
                                   gan_w, cls_w) +
            self.dis.calc_dis_loss(x_fake1, x_real, label_trg, label_src,
                                   gan_w, cls_w))
        # Accumulate out-of-place: an in-place `+=` would alias and modify
        # self.loss_dis as well.
        loss_total = self.loss_dis

        # Gradient penalty on random interpolates between real and fake.
        if configs['gp_w'] > 0.0:
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.detach() +
                     (1 - alpha) * x_fake.detach()).requires_grad_(True)
            out_src, _ = self.dis(x_hat, False)[0]
            self.loss_gp = self.gradient_penalty(out_src,
                                                 x_hat) * configs['gp_w']
            loss_total = loss_total + self.loss_gp

        # R1 penalty, applied every d_reg_every iterations. It uses a
        # detached leaf copy so the caller's x_real is not switched to
        # requires_grad.
        if configs['use_r1'] and (iters + 1) % self.d_reg_every == 0:
            x_real_r1 = x_real.detach().requires_grad_(True)
            output, _ = self.dis(x_real_r1, False)[0]
            # Weight gamma / 2 with gamma = 10 (hard-coded). It is not
            # rescaled by d_reg_every.
            self.loss_r1 = self.r1_penalty(output, x_real_r1) * 10. / 2
            loss_total = loss_total + self.loss_r1

        self.loss_dis_all = loss_total
        self.loss_dis_all.backward()
        self.dis_opt.step()
Example #3
0
    def gen_update(self, x_real, c_src, c_trg, txt_src2trg, txt_lens,
                   label_src, label_trg, configs, iters):
        """Run one generator optimisation step.

        Builds the following images from ``x_real``:
          * a within-domain reconstruction using the image's own style;
          * ``x_fake``, decoded with the text-conditioned style from
            ``self.gen.encode_txt``;
          * ``x_fake1`` / ``x_fake2``, decoded with two independent styles
            drawn by ``dist_sampling_split(c_trg, ...)``;
          * optionally a cycle reconstruction of ``x_fake``.
        These feed reconstruction, cycle, adversarial, distribution (KL or
        earth-mover), perceptual and diversity terms. The weighted total is
        back-propagated and ``self.gen_opt`` is stepped.

        Args:
            x_real: real image batch (assumed (B, C, H, W) -- TODO confirm).
            c_src: source code used by the distribution loss on the real
                style (presumably per-class means -- confirm against caller).
            c_trg: target code used for style sampling and the distribution
                loss on the text style; its last dim is used as the number of
                style blocks in the two-sided reshape.
            txt_src2trg, txt_lens: text inputs for ``self.gen.encode_txt``.
            label_src: not used in this method.
            label_trg: labels forwarded to ``self.dis.calc_gen_loss``.
            configs: dict of loss weights ('recon_x_w', 'recon_c_w',
                'recon_s_w', 'recon_x_cyc_w', 'kl_w', 'vgg_w', 'gan_w',
                'cls_w').
            iters: not used (only referenced by the disabled guard below).

        Side effects:
            Sets the ``self.loss_*`` attributes read elsewhere for logging,
            and decays ``self.init_ds_w``.
        """
        self.gen_opt.zero_grad()
        # encode: content code, per-block style codes (list), and their
        # log-variances (used by the 'kls' distribution loss).
        content_real, style_real, logvar = self.gen.encode(x_real)

        # decode (within domain)
        # NOTE(review): two-sided attention is applied here unconditionally,
        # while the sampled-style paths below gate it on self.two_sided.
        x_real_rec, x_real_rec_att = self.gen.decode(
            self.two_sided_attention(content_real.clone(), style_real),
            torch.cat(style_real, dim=1))
        if self.use_attention:
            # Keep the decoded pixels where attention is high and fall back
            # to the input image elsewhere.
            x_real_rec = x_real_rec * x_real_rec_att + x_real * (
                1 - x_real_rec_att)
        content_real_rec, style_real_rec, _ = self.gen.encode(x_real_rec)

        # decode (cross domain) with the text-conditioned target style
        style_txt, logvar_txt = self.gen.encode_txt(
            torch.cat(style_real, dim=1), txt_src2trg, txt_lens)
        x_fake, x_fake_att = self.gen.decode(
            self.two_sided_attention(content_real.clone(), style_txt),
            torch.cat(style_txt, dim=1))
        if self.use_attention:
            x_fake = x_fake * x_fake_att + x_real * (1 - x_fake_att)

        # Diversity term between two images decoded from two independently
        # sampled styles. The guard below is disabled, so this always runs.
        #self.loss_ds = 0.0
        #if self.stddev > 0 and iters > self.ds_iter:
        style1 = dist_sampling_split(c_trg, self.c_dim, self.stddev,
                                     self.device)
        content_fake1 = content_real.clone()
        if self.two_sided:
            content_fake1 = self.two_sided_attention(
                content_fake1, style1.reshape(-1, c_trg.shape[-1], self.c_dim))
        x_fake1, x_fake_att1 = self.gen.decode(content_fake1, style1)
        style2 = dist_sampling_split(c_trg, self.c_dim, self.stddev,
                                     self.device)
        content_fake2 = content_real.clone()
        if self.two_sided:
            content_fake2 = self.two_sided_attention(
                content_fake2, style2.reshape(-1, c_trg.shape[-1], self.c_dim))
        x_fake2, x_fake_att2 = self.gen.decode(content_fake2, style2)
        if self.use_attention:
            x_fake1 = x_fake1 * x_fake_att1 + x_real * (1 - x_fake_att1)
            x_fake2 = x_fake2 * x_fake_att2 + x_real * (1 - x_fake_att2)
        # x_fake2 is detached, so the gradient of this term flows only
        # through x_fake1.
        self.loss_ds = torch.mean(torch.abs(x_fake1 - x_fake2.detach()))
        content_rand, style_rand, _ = self.gen.encode(x_fake1)
        # Linear decay of the diversity weight by 1e-5 per call, floored at 0.
        self.init_ds_w = max(self.init_ds_w - 1 / 1e5, 0.0)

        # encode again
        content_fake_rec, style_fake_rec, _ = self.gen.encode(x_fake)
        # decode again (if needed): cycle back using the original real style
        if configs['recon_x_cyc_w'] > 0:
            x_cycle, x_cycle_att = self.gen.decode(
                self.two_sided_attention(content_fake_rec.clone(),
                                         style_fake_rec),
                torch.cat(style_real, dim=1))
            if self.use_attention:
                x_cycle = x_cycle * x_cycle_att + x_real * (1 - x_cycle_att)

        # reconstruction loss: image, content (real / text-fake / sampled
        # fake vs. the original content), and style (L1 against the style
        # each image was decoded with)
        self.loss_gen_recon_x = self.recon_criterion(x_real_rec, x_real)
        self.loss_gen_recon_c_real = self.recon_criterion(
            content_real_rec, content_real)
        self.loss_gen_recon_c_fake = self.recon_criterion(
            content_fake_rec, content_real)
        self.loss_gen_recon_c_rand = self.recon_criterion(
            content_rand, content_real)
        self.loss_gen_recon_s_real = self.criterion_l1(style_real_rec,
                                                       style_real)
        self.loss_gen_recon_s_fake = self.criterion_l1(style_fake_rec,
                                                       style_txt)
        self.loss_gen_recon_s_rand = self.criterion_l1(style_rand, style1)

        # x_cycle only exists when recon_x_cyc_w > 0.
        self.loss_gen_cycrecon_x = 0
        if configs['recon_x_cyc_w'] > 0:
            self.loss_gen_cycrecon_x = self.recon_criterion(x_cycle, x_real)

        # GAN loss on the text-style fake and the first sampled-style fake
        self.loss_gen_adv = self.dis.calc_gen_loss(x_fake, label_trg, configs['gan_w'], configs['cls_w']) + \
            self.dis.calc_gen_loss(x_fake1, label_trg, configs['gan_w'], configs['cls_w'])

        # KL loss: 'kls' uses the KL helper with log-variances and sigma;
        # any other dist_mode falls through to the earth-mover helper.
        self.loss_kl_x, self.loss_kl_trg = 0.0, 0.0
        if self.dist_mode == 'kls':
            self.loss_kl_x = gmm_kl_distance_sp(style_real, logvar, c_src,
                                                self.sigma)
            self.loss_kl_trg = gmm_kl_distance_sp(style_txt, logvar_txt, c_trg,
                                                  self.sigma)
        else:  #  self.dist_mode == 'em':
            self.loss_kl_x = gmm_earth_mover_distance_sp(style_real, c_src)
            self.loss_kl_trg = gmm_earth_mover_distance_sp(style_txt, c_trg)

        # domain-invariant perceptual loss (requires x_cycle, hence the
        # recon_x_cyc_w guard)
        self.loss_gen_vgg = 0
        if configs['recon_x_cyc_w'] > 0 and configs['vgg_w'] > 0:
            self.loss_gen_vgg = self.compute_vgg_loss(self.vgg, x_real,
                                                      x_cycle)

        # total loss. The diversity term is SUBTRACTED, so minimising the
        # total pushes x_fake1 away from x_fake2 (weighted by the decaying
        # init_ds_w).
        self.loss_gen_total = self.loss_gen_adv + \
                              configs['recon_x_w'] * self.loss_gen_recon_x + \
                              configs['recon_c_w'] * self.loss_gen_recon_c_real + \
                              configs['recon_c_w'] * self.loss_gen_recon_c_fake + \
                              configs['recon_c_w'] * self.loss_gen_recon_c_rand + \
                              configs['recon_s_w'] * self.loss_gen_recon_s_real + \
                              configs['recon_s_w'] * self.loss_gen_recon_s_fake + \
                              configs['recon_s_w'] * self.loss_gen_recon_s_rand + \
                              configs['recon_x_cyc_w'] * self.loss_gen_cycrecon_x + \
                              configs['kl_w'] * self.loss_kl_x + \
                              configs['kl_w'] * self.loss_kl_trg + \
                              configs['vgg_w'] * self.loss_gen_vgg - \
                              self.init_ds_w * self.loss_ds
        self.loss_gen_total.backward()
        self.gen_opt.step()